From a4e52008c7e7d35c0fd2fe9ad9ecba2afcc8fb18 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=88=99=E6=96=87?= Date: Thu, 15 Jan 2026 23:20:55 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=87=E4=BB=B6=E8=87=B3?= =?UTF-8?q?=20src?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/streamlit_tweet_app.py | 353 +++++++++++++++++++++++++++++++++++++ 1 file changed, 353 insertions(+) create mode 100644 src/streamlit_tweet_app.py diff --git a/src/streamlit_tweet_app.py b/src/streamlit_tweet_app.py new file mode 100644 index 0000000..aa4ecdf --- /dev/null +++ b/src/streamlit_tweet_app.py @@ -0,0 +1,353 @@ +"""Streamlit 演示应用 - 推文情感分析 + +航空推文情感分析 AI 助手 - 支持情感分类、解释和处置方案生成。 +""" + +import os +import sys + +import streamlit as st + +# Ensure project root is in path +sys.path.append(os.getcwd()) + +from dotenv import load_dotenv + +from src.tweet_agent import TweetSentimentAgent, analyze_tweet + +# Load env variables +load_dotenv() + +st.set_page_config(page_title="航空推文情感分析", page_icon="✈️", layout="wide") + +# Sidebar Configuration +st.sidebar.header("🔧 配置") +st.sidebar.markdown("### 模型信息") +st.sidebar.info( + """ + **模型**: VotingClassifier (5个基学习器) + - Logistic Regression + - Multinomial Naive Bayes + - Random Forest + - LightGBM + - XGBoost + + **性能**: Macro-F1 = 0.7533 ✅ + """ +) + +st.sidebar.markdown("---") +# Mode Selection +mode = st.sidebar.radio("功能选择", ["📝 单条分析", "📊 批量分析", "📈 数据概览"]) + +# Initialize session state +if "agent" not in st.session_state: + with st.spinner("🔄 加载模型..."): + st.session_state.agent = TweetSentimentAgent() + +if "batch_results" not in st.session_state: + st.session_state.batch_results = [] + + +# --- Helper Functions --- + + +def get_sentiment_emoji(sentiment: str) -> str: + """获取情感对应的表情符号""" + emoji_map = { + "negative": "😠", + "neutral": "😐", + "positive": "😊", + } + return emoji_map.get(sentiment, "❓") + + +def get_sentiment_color(sentiment: str) -> str: + """获取情感对应的颜色""" + color_map = { + "negative": "#ff6b6b", + "neutral": "#ffd93d", + "positive": "#6bcb77", + } + return color_map.get(sentiment, "#e0e0e0") + + +def get_priority_color(priority: str) -> str: + """获取优先级对应的颜色""" + color_map = { + "high": "#ff4757", + "medium": "#ffa502", + "low": "#2ed573", + } + return color_map.get(priority, "#e0e0e0") + + +# --- Main Views --- + +if mode == "📝 单条分析": + st.title("✈️ 航空推文情感分析") + st.markdown("输入推文文本,获取 AI 驱动的情感分析、解释和处置方案。") + + # Input form + with st.form("tweet_analysis_form"): + col1, col2 = st.columns([3, 1]) + + with col1: + tweet_text = st.text_area( + "推文内容", + placeholder="@United This is the worst airline ever! My flight was delayed for 5 hours...", + height=100, + ) + + with col2: + airline = st.selectbox( + "航空公司", + ["United", "US Airways", "American", "Southwest", "Delta", "Virgin America"], + ) + + submitted = st.form_submit_button("🔍 分析", type="primary") + + if submitted and tweet_text: + with st.spinner("🤖 AI 正在分析..."): + try: + result = analyze_tweet(tweet_text, airline) + + # Display results + st.divider() + + # Header with sentiment + sentiment_emoji = get_sentiment_emoji(result.classification.sentiment) + sentiment_color = get_sentiment_color(result.classification.sentiment) + + st.markdown( + f""" +
+

{sentiment_emoji} {result.classification.sentiment.upper()}

+

置信度: {result.classification.confidence:.1%}

+
+ """, + unsafe_allow_html=True, + ) + + st.divider() + + # Original tweet + st.subheader("📝 原始推文") + st.info(f"**航空公司**: {result.airline}\n\n**内容**: {result.tweet_text}") + + # Explanation + st.subheader("🔍 情感解释") + st.markdown("**关键因素:**") + for factor in result.explanation.key_factors: + st.write(f"- {factor}") + + st.markdown("**推理过程:**") + st.write(result.explanation.reasoning) + + # Disposal plan + st.subheader("📋 处置方案") + + priority_color = get_priority_color(result.disposal_plan.priority) + st.markdown( + f""" +
+ 优先级: {result.disposal_plan.priority.upper()} +
+

+ **行动类型**: {result.disposal_plan.action_type} + """, + unsafe_allow_html=True, + ) + + if result.disposal_plan.suggested_response: + st.markdown("**建议回复:**") + st.success(result.disposal_plan.suggested_response) + + st.markdown("**后续行动:**") + for action in result.disposal_plan.follow_up_actions: + st.write(f"- {action}") + + except Exception as e: + st.error(f"分析失败: {e!s}") + +elif mode == "📊 批量分析": + st.title("📊 批量推文分析") + st.markdown("上传 CSV 文件或输入多条推文,进行批量情感分析。") + + # Input method selection + input_method = st.radio("输入方式", ["手动输入", "CSV 上传"], horizontal=True) + + if input_method == "手动输入": + st.markdown("### 输入推文(每行一条)") + tweets_input = st.text_area( + "推文列表", + placeholder="@United Flight delayed again!\n@Southwest Great service!\n@American Baggage policy?", + height=200, + ) + + if st.button("🔍 批量分析", type="primary") and tweets_input: + lines = [line.strip() for line in tweets_input.split("\n") if line.strip()] + + if lines: + with st.spinner("🤖 AI 正在分析..."): + results = [] + for line in lines: + try: + # Extract airline from tweet (simple heuristic) + airline = "United" # Default + for a in ["United", "US Airways", "American", "Southwest", "Delta", "Virgin America"]: + if a.lower() in line.lower(): + airline = a + break + + result = analyze_tweet(line, airline) + results.append(result) + except Exception as e: + st.warning(f"分析失败: {line[:50]}... - {e}") + + if results: + st.session_state.batch_results = results + st.success(f"✅ 成功分析 {len(results)} 条推文") + + else: # CSV upload + st.markdown("### 上传 CSV 文件") + st.info("CSV 文件应包含以下列: `text` (推文内容), `airline` (航空公司)") + + uploaded_file = st.file_uploader("选择文件", type=["csv"]) + + if uploaded_file and st.button("🔍 分析上传文件", type="primary"): + import pandas as pd + + try: + df = pd.read_csv(uploaded_file) + + if "text" not in df.columns: + st.error("CSV 文件必须包含 'text' 列") + else: + with st.spinner("🤖 AI 正在分析..."): + results = [] + for _, row in df.iterrows(): + try: + text = row["text"] + airline = row.get("airline", "United") + result = analyze_tweet(text, airline) + results.append(result) + except Exception as e: + st.warning(f"分析失败: {text[:50]}... - {e}") + + if results: + st.session_state.batch_results = results + st.success(f"✅ 成功分析 {len(results)} 条推文") + + except Exception as e: + st.error(f"文件读取失败: {e!s}") + + # Display batch results + if st.session_state.batch_results: + st.divider() + st.subheader(f"📊 分析结果 ({len(st.session_state.batch_results)} 条)") + + # Summary statistics + sentiments = [r.classification.sentiment for r in st.session_state.batch_results] + negative_count = sentiments.count("negative") + neutral_count = sentiments.count("neutral") + positive_count = sentiments.count("positive") + + col1, col2, col3 = st.columns(3) + col1.metric("😠 负面", negative_count) + col2.metric("😐 中性", neutral_count) + col3.metric("😊 正面", positive_count) + + # Detailed results table + st.markdown("### 详细结果") + + results_data = [] + for r in st.session_state.batch_results: + results_data.append({ + "推文": r.tweet_text[:50] + "..." if len(r.tweet_text) > 50 else r.tweet_text, + "航空公司": r.airline, + "情感": f"{get_sentiment_emoji(r.classification.sentiment)} {r.classification.sentiment}", + "置信度": f"{r.classification.confidence:.1%}", + "优先级": r.disposal_plan.priority.upper(), + "行动类型": r.disposal_plan.action_type, + }) + + st.dataframe(results_data, use_container_width=True) + + # Clear button + if st.button("🗑️ 清除结果"): + st.session_state.batch_results = [] + st.rerun() + +elif mode == "📈 数据概览": + st.title("📈 数据集概览") + st.markdown("查看训练数据集的统计信息。") + + try: + import polars as pl + from src.tweet_data import load_cleaned_tweets, print_data_summary + + df = load_cleaned_tweets("data/Tweets_cleaned.csv") + + # Display summary + st.subheader("📊 数据统计") + print_data_summary(df, "数据集统计") + + # Display sample data + st.subheader("📝 样本数据") + sample_df = df.head(10).to_pandas() + st.dataframe(sample_df, use_container_width=True) + + # Sentiment distribution chart + st.subheader("📈 情感分布") + sentiment_counts = df.group_by("airline_sentiment").agg( + pl.col("airline_sentiment").count().alias("count") + ).sort("count", descending=True) + + import pandas as pd + import plotly.express as px + + sentiment_df = sentiment_counts.to_pandas() + fig = px.pie( + sentiment_df, + values="count", + names="airline_sentiment", + title="情感分布", + color_discrete_map={ + "negative": "#ff6b6b", + "neutral": "#ffd93d", + "positive": "#6bcb77", + }, + ) + st.plotly_chart(fig, use_container_width=True) + + # Airline distribution chart + st.subheader("✈️ 航空公司分布") + airline_counts = df.group_by("airline").agg( + pl.col("airline").count().alias("count") + ).sort("count", descending=True) + + airline_df = airline_counts.to_pandas() + fig = px.bar( + airline_df, + x="airline", + y="count", + title="各航空公司推文数量", + color="count", + color_continuous_scale="Blues", + ) + st.plotly_chart(fig, use_container_width=True) + + except Exception as e: + st.error(f"数据加载失败: {e!s}") + +# Footer +st.divider() +st.markdown( + """ +
+ 航空推文情感分析 AI 助手 | 基于 VotingClassifier (LR + NB + RF + LightGBM + XGBoost) +
+ """, + unsafe_allow_html=True, +)