"""Streamlit demo app for airline tweet sentiment analysis.

AI assistant for airline tweets: sentiment classification, explanation,
and disposal-plan generation.
"""

import html
import os
import sys

import streamlit as st

# Ensure project root is in path
sys.path.append(os.getcwd())

from dotenv import load_dotenv

from src.tweet_agent import TweetSentimentAgent, analyze_tweet

# Load env variables
load_dotenv()

st.set_page_config(page_title="航空推文情感分析", page_icon="✈️", layout="wide")

# --- Constants ---

# Sidebar mode labels; also used to dispatch to the matching view below.
MODE_SINGLE = "📝 单条分析"
MODE_BATCH = "📊 批量分析"
MODE_OVERVIEW = "📈 数据概览"

# Airlines offered in the selector and searched for in free-text tweets.
AIRLINES = ["United", "US Airways", "American", "Southwest", "Delta", "Virgin America"]
DEFAULT_AIRLINE = "United"

SENTIMENT_EMOJIS = {
    "negative": "😠",
    "neutral": "😐",
    "positive": "😊",
}
SENTIMENT_COLORS = {
    "negative": "#ff6b6b",
    "neutral": "#ffd93d",
    "positive": "#6bcb77",
}
PRIORITY_COLORS = {
    "high": "#ff4757",
    "medium": "#ffa502",
    "low": "#2ed573",
}
FALLBACK_COLOR = "#e0e0e0"


# --- Helper Functions ---


def get_sentiment_emoji(sentiment: str) -> str:
    """Return the emoji for a sentiment label, or "❓" for unknown labels."""
    return SENTIMENT_EMOJIS.get(sentiment, "❓")


def get_sentiment_color(sentiment: str) -> str:
    """Return the hex color for a sentiment label (grey for unknown labels)."""
    return SENTIMENT_COLORS.get(sentiment, FALLBACK_COLOR)


def get_priority_color(priority: str) -> str:
    """Return the hex color for a priority level (grey for unknown levels)."""
    return PRIORITY_COLORS.get(priority, FALLBACK_COLOR)


def _detect_airline(text: str) -> str:
    """Return the first known airline named in ``text`` (case-insensitive).

    Simple substring heuristic; falls back to ``DEFAULT_AIRLINE`` when no
    known airline name occurs in the text.
    """
    lowered = text.lower()
    return next((name for name in AIRLINES if name.lower() in lowered), DEFAULT_AIRLINE)


def _store_batch_results(results: list) -> None:
    """Save a non-empty batch of results in session state and report success."""
    if results:
        st.session_state.batch_results = results
        st.success(f"✅ 成功分析 {len(results)} 条推文")


def _count_by(df, column: str):
    """Return a pandas frame of row counts per value of ``column``, descending.

    ``df`` is the frame returned by ``load_cleaned_tweets``; it is used here
    through the polars ``group_by``/``agg``/``sort``/``to_pandas`` calls.
    """
    import polars as pl

    counts = df.group_by(column).agg(pl.col(column).count().alias("count"))
    return counts.sort("count", descending=True).to_pandas()


# --- Main Views ---


def _show_single_result(result) -> None:
    """Render one analysis result: sentiment header, tweet, explanation, plan.

    ``result`` is whatever ``analyze_tweet`` returns; only the attributes read
    below (``classification``, ``explanation``, ``disposal_plan``, ``airline``,
    ``tweet_text``) are relied on.
    """
    st.divider()

    # Header with sentiment
    sentiment = result.classification.sentiment
    sentiment_emoji = get_sentiment_emoji(sentiment)
    sentiment_color = get_sentiment_color(sentiment)
    # Values are HTML-escaped because this markdown is rendered as raw HTML.
    st.markdown(
        f"""
<div style="border-left: 8px solid {sentiment_color}; padding: 0.5rem 1rem;">
<h2>{sentiment_emoji} {html.escape(str(sentiment).upper())}</h2>
<p>置信度: {result.classification.confidence:.1%}</p>
</div>
""",
        unsafe_allow_html=True,
    )
    st.divider()

    # Original tweet
    st.subheader("📝 原始推文")
    st.info(f"**航空公司**: {result.airline}\n\n**内容**: {result.tweet_text}")

    # Explanation
    st.subheader("🔍 情感解释")
    st.markdown("**关键因素:**")
    for factor in result.explanation.key_factors:
        st.write(f"- {factor}")
    st.markdown("**推理过程:**")
    st.write(result.explanation.reasoning)

    # Disposal plan
    st.subheader("📋 处置方案")
    plan = result.disposal_plan
    priority_color = get_priority_color(plan.priority)
    st.markdown(
        f"""
<span style="color: {priority_color}; font-weight: bold;">优先级: {html.escape(str(plan.priority).upper())}</span>

**行动类型**: {html.escape(str(plan.action_type))}
""",
        unsafe_allow_html=True,
    )

    if plan.suggested_response:
        st.markdown("**建议回复:**")
        st.success(plan.suggested_response)

    st.markdown("**后续行动:**")
    for action in plan.follow_up_actions:
        st.write(f"- {action}")


def render_single_view() -> None:
    """Single-tweet view: input form, then the full analysis of the tweet."""
    st.title("✈️ 航空推文情感分析")
    st.markdown("输入推文文本,获取 AI 驱动的情感分析、解释和处置方案。")

    # Input form
    with st.form("tweet_analysis_form"):
        col1, col2 = st.columns([3, 1])
        with col1:
            tweet_text = st.text_area(
                "推文内容",
                placeholder="@United This is the worst airline ever! My flight was delayed for 5 hours...",
                height=100,
            )
        with col2:
            airline = st.selectbox("航空公司", AIRLINES)
        submitted = st.form_submit_button("🔍 分析", type="primary")

    if not (submitted and tweet_text):
        return

    with st.spinner("🤖 AI 正在分析..."):
        # Top-level UI boundary: any failure is shown to the user, not raised.
        try:
            result = analyze_tweet(tweet_text, airline)
            _show_single_result(result)
        except Exception as e:
            st.error(f"分析失败: {e!s}")


def _analyze_manual_input(tweets_input: str) -> None:
    """Analyze one tweet per non-blank line of ``tweets_input``."""
    lines = [line.strip() for line in tweets_input.split("\n") if line.strip()]
    if not lines:
        return

    with st.spinner("🤖 AI 正在分析..."):
        results = []
        for line in lines:
            # A failing tweet is reported and skipped so the batch continues.
            try:
                results.append(analyze_tweet(line, _detect_airline(line)))
            except Exception as e:
                st.warning(f"分析失败: {line[:50]}... - {e}")
        _store_batch_results(results)


def _analyze_uploaded_csv(uploaded_file) -> None:
    """Analyze every non-blank ``text`` cell of an uploaded CSV file.

    Blank/NaN ``text`` cells are skipped and counted instead of being sent to
    the analyzer, and a missing/NaN ``airline`` falls back to
    ``DEFAULT_AIRLINE``.
    """
    import pandas as pd

    try:
        df = pd.read_csv(uploaded_file)
    except Exception as e:
        st.error(f"文件读取失败: {e!s}")
        return

    if "text" not in df.columns:
        st.error("CSV 文件必须包含 'text' 列")
        return

    with st.spinner("🤖 AI 正在分析..."):
        results = []
        skipped = 0
        for _, row in df.iterrows():
            raw_text = row["text"]
            # NaN (empty cell) is a float: slicing it in the warning below
            # would raise and abort the whole batch, so filter it out first.
            if pd.isna(raw_text) or not str(raw_text).strip():
                skipped += 1
                continue
            text = str(raw_text)

            airline = row.get("airline", DEFAULT_AIRLINE)
            if pd.isna(airline):
                airline = DEFAULT_AIRLINE

            try:
                results.append(analyze_tweet(text, str(airline)))
            except Exception as e:
                st.warning(f"分析失败: {text[:50]}... - {e}")

        if skipped:
            st.warning(f"已跳过 {skipped} 条空白推文")
        _store_batch_results(results)


def _show_batch_results() -> None:
    """Render summary metrics and the detail table for stored batch results."""
    batch = st.session_state.batch_results
    st.divider()
    st.subheader(f"📊 分析结果 ({len(batch)} 条)")

    # Summary statistics
    sentiments = [r.classification.sentiment for r in batch]
    col1, col2, col3 = st.columns(3)
    col1.metric("😠 负面", sentiments.count("negative"))
    col2.metric("😐 中性", sentiments.count("neutral"))
    col3.metric("😊 正面", sentiments.count("positive"))

    # Detailed results table
    st.markdown("### 详细结果")
    results_data = [
        {
            "推文": r.tweet_text[:50] + "..." if len(r.tweet_text) > 50 else r.tweet_text,
            "航空公司": r.airline,
            "情感": f"{get_sentiment_emoji(r.classification.sentiment)} {r.classification.sentiment}",
            "置信度": f"{r.classification.confidence:.1%}",
            "优先级": r.disposal_plan.priority.upper(),
            "行动类型": r.disposal_plan.action_type,
        }
        for r in batch
    ]
    st.dataframe(results_data, use_container_width=True)

    # Clear button
    if st.button("🗑️ 清除结果"):
        st.session_state.batch_results = []
        st.rerun()


def render_batch_view() -> None:
    """Batch view: analyze typed-in lines or an uploaded CSV, then show results."""
    st.title("📊 批量推文分析")
    st.markdown("上传 CSV 文件或输入多条推文,进行批量情感分析。")

    # Input method selection
    input_method = st.radio("输入方式", ["手动输入", "CSV 上传"], horizontal=True)

    if input_method == "手动输入":
        st.markdown("### 输入推文(每行一条)")
        tweets_input = st.text_area(
            "推文列表",
            placeholder="@United Flight delayed again!\n@Southwest Great service!\n@American Baggage policy?",
            height=200,
        )
        if st.button("🔍 批量分析", type="primary") and tweets_input:
            _analyze_manual_input(tweets_input)
    else:
        # CSV upload
        st.markdown("### 上传 CSV 文件")
        st.info("CSV 文件应包含以下列: `text` (推文内容), `airline` (航空公司)")
        uploaded_file = st.file_uploader("选择文件", type=["csv"])
        if uploaded_file and st.button("🔍 分析上传文件", type="primary"):
            _analyze_uploaded_csv(uploaded_file)

    # Display batch results
    if st.session_state.batch_results:
        _show_batch_results()


def render_overview_view() -> None:
    """Dataset overview: summary, sample rows, sentiment and airline charts."""
    st.title("📈 数据集概览")
    st.markdown("查看训练数据集的统计信息。")

    # Top-level UI boundary: missing data or libraries become an error box.
    try:
        from src.tweet_data import load_cleaned_tweets, print_data_summary

        df = load_cleaned_tweets("data/Tweets_cleaned.csv")

        # Display summary
        st.subheader("📊 数据统计")
        # NOTE(review): print_data_summary is defined elsewhere; judging by its
        # name it may only print to stdout, in which case nothing appears
        # under this subheader in the page - confirm.
        print_data_summary(df, "数据集统计")

        # Display sample data
        st.subheader("📝 样本数据")
        st.dataframe(df.head(10).to_pandas(), use_container_width=True)

        # Plotly is imported only here, after the sections above have rendered.
        import plotly.express as px

        # Sentiment distribution chart
        st.subheader("📈 情感分布")
        fig = px.pie(
            _count_by(df, "airline_sentiment"),
            values="count",
            names="airline_sentiment",
            # color_discrete_map only takes effect when `color` is given.
            color="airline_sentiment",
            title="情感分布",
            color_discrete_map=SENTIMENT_COLORS,
        )
        st.plotly_chart(fig, use_container_width=True)

        # Airline distribution chart
        st.subheader("✈️ 航空公司分布")
        fig = px.bar(
            _count_by(df, "airline"),
            x="airline",
            y="count",
            title="各航空公司推文数量",
            color="count",
            color_continuous_scale="Blues",
        )
        st.plotly_chart(fig, use_container_width=True)
    except Exception as e:
        st.error(f"数据加载失败: {e!s}")


VIEWS = {
    MODE_SINGLE: render_single_view,
    MODE_BATCH: render_batch_view,
    MODE_OVERVIEW: render_overview_view,
}

# --- Page ---

# Sidebar Configuration
st.sidebar.header("🔧 配置")
st.sidebar.markdown("### 模型信息")
st.sidebar.info(
    """
**模型**: VotingClassifier (5个基学习器)
- Logistic Regression
- Multinomial Naive Bayes
- Random Forest
- LightGBM
- XGBoost

**性能**: Macro-F1 = 0.7533 ✅
"""
)
st.sidebar.markdown("---")

# Mode Selection
mode = st.sidebar.radio("功能选择", [MODE_SINGLE, MODE_BATCH, MODE_OVERVIEW])

# Initialize session state
# NOTE(review): the views call analyze_tweet() and never read this agent;
# presumably it pre-loads the model once per session - confirm.
if "agent" not in st.session_state:
    with st.spinner("🔄 加载模型..."):
        st.session_state.agent = TweetSentimentAgent()

if "batch_results" not in st.session_state:
    st.session_state.batch_results = []

# Dispatch to the selected view; an unknown mode renders nothing.
render_view = VIEWS.get(mode)
if render_view is not None:
    render_view()

# Footer
st.divider()
st.markdown(
    """
航空推文情感分析 AI 助手 | 基于 VotingClassifier (LR + NB + RF + LightGBM + XGBoost)
""",
    unsafe_allow_html=True,
)