# diff --git a/src/__init__.py b/src/__init__.py
# new file mode 100644
# index 0000000..8a464bf
# --- /dev/null
# +++ b/src/__init__.py
# @@ -0,0 +1,12 @@
"""Tweet sentiment analysis package."""

from src.tweet_agent import TweetSentimentAgent, analyze_tweet
from src.tweet_data import load_cleaned_tweets
from src.train_tweet_ultimate import load_model

__all__ = [
    "TweetSentimentAgent",
    "analyze_tweet",
    "load_cleaned_tweets",
    "load_model",
]

# diff --git a/src/streamlit_tweet_app.py b/src/streamlit_tweet_app.py
# new file mode 100644
# index 0000000..8971129
# --- /dev/null
# +++ b/src/streamlit_tweet_app.py
# @@ -0,0 +1,350 @@
"""Streamlit demo application for tweet sentiment analysis.

Airline tweet sentiment assistant: sentiment classification, explanation
and disposal-plan generation.
"""

import os
import sys

import streamlit as st

from dotenv import load_dotenv

from src.tweet_agent import TweetSentimentAgent, analyze_tweet

# Pull configuration from a local .env file, if one exists.
load_dotenv()

st.set_page_config(page_title="航空推文情感分析", page_icon="✈️", layout="wide")

# --- Sidebar ---
sidebar = st.sidebar
sidebar.header("🔧 配置")
sidebar.markdown("### 模型信息")
sidebar.info(
    """
    **模型**: VotingClassifier (5个基学习器)
    - Logistic Regression
    - Multinomial Naive Bayes
    - Random Forest
    - LightGBM
    - XGBoost

    **性能**: Macro-F1 = 0.7533 ✅
    """
)

sidebar.markdown("---")
# The selected entry decides which of the main views is rendered below.
mode = sidebar.radio("功能选择", ["📝 单条分析", "📊 批量分析", "📈 数据概览"])

# --- Session state ---
# Build the agent once per browser session; Streamlit re-runs this script
# on every interaction.
if "agent" not in st.session_state:
    with st.spinner("🔄 加载模型..."):
        st.session_state.agent = TweetSentimentAgent()

# Results of the latest batch run survive re-runs until cleared.
if "batch_results" not in st.session_state:
    st.session_state.batch_results = []


# --- Helper Functions ---


def get_sentiment_emoji(sentiment: str) -> str:
    """Return the emoji shown for *sentiment*; "❓" for unknown labels."""
    return {
        "negative": "😠",
        "neutral": "😐",
        "positive": "😊",
    }.get(sentiment, "❓")


def get_sentiment_color(sentiment: str) -> str:
    """Return the hex colour shown for *sentiment*; grey for unknown labels."""
    return {
        "negative": "#ff6b6b",
        "neutral": "#ffd93d",
        "positive": "#6bcb77",
    }.get(sentiment, "#e0e0e0")
# Airlines offered in the UI; also used by the mention heuristic below.
AIRLINES = ["United", "US Airways", "American", "Southwest", "Delta", "Virgin America"]


def get_priority_color(priority: str) -> str:
    """Return the hex colour shown for a disposal *priority*; grey if unknown."""
    color_map = {
        "high": "#ff4757",
        "medium": "#ffa502",
        "low": "#2ed573",
    }
    return color_map.get(priority, "#e0e0e0")


def _detect_airline(text: str, default: str = "United") -> str:
    """Return the first known airline named in *text* (case-insensitive)."""
    lowered = text.lower()
    for name in AIRLINES:
        if name.lower() in lowered:
            return name
    return default


# --- Main Views ---

# Reuse the agent cached in the session instead of building (and training)
# a new model for every analysed tweet.
agent = st.session_state.agent

if mode == "📝 单条分析":
    st.title("✈️ 航空推文情感分析")
    st.markdown("输入推文文本,获取 AI 驱动的情感分析、解释和处置方案。")

    # Input form
    with st.form("tweet_analysis_form"):
        col1, col2 = st.columns([3, 1])

        with col1:
            tweet_text = st.text_area(
                "推文内容",
                placeholder="@United This is the worst airline ever! My flight was delayed for 5 hours...",
                height=100,
            )

        with col2:
            airline = st.selectbox("航空公司", AIRLINES)

        submitted = st.form_submit_button("🔍 分析", type="primary")

    if submitted and tweet_text and tweet_text.strip():
        with st.spinner("🤖 AI 正在分析..."):
            try:
                result = agent.analyze(tweet_text, airline)

                st.divider()

                # Header banner coloured by sentiment.
                # NOTE(review): the HTML below was reconstructed; the values
                # interpolated are model labels/colours, not user text.
                sentiment_emoji = get_sentiment_emoji(result.classification.sentiment)
                sentiment_color = get_sentiment_color(result.classification.sentiment)

                st.markdown(
                    f"""
                    <div style="background-color: {sentiment_color}; padding: 20px; border-radius: 10px; text-align: center;">
                        <h2 style="margin: 0;">{sentiment_emoji} {result.classification.sentiment.upper()}</h2>
                        <p style="margin: 0;">置信度: {result.classification.confidence:.1%}</p>
                    </div>
                    """,
                    unsafe_allow_html=True,
                )

                st.divider()

                # Original tweet
                st.subheader("📝 原始推文")
                st.info(f"**航空公司**: {result.airline}\n\n**内容**: {result.tweet_text}")

                # Explanation
                st.subheader("🔍 情感解释")
                st.markdown("**关键因素:**")
                for factor in result.explanation.key_factors:
                    st.write(f"- {factor}")

                st.markdown("**推理过程:**")
                st.write(result.explanation.reasoning)

                # Disposal plan
                st.subheader("📋 处置方案")

                priority_color = get_priority_color(result.disposal_plan.priority)
                st.markdown(
                    f"""
                    <div style="background-color: {priority_color}; color: white; padding: 10px; border-radius: 5px; display: inline-block;">
                        优先级: {result.disposal_plan.priority.upper()}
                    </div>
                    """,
                    unsafe_allow_html=True,
                )
                # Separate call so the Markdown bold is not swallowed by the
                # preceding HTML block.
                st.markdown(f"**行动类型**: {result.disposal_plan.action_type}")

                if result.disposal_plan.suggested_response:
                    st.markdown("**建议回复:**")
                    st.success(result.disposal_plan.suggested_response)

                st.markdown("**后续行动:**")
                for action in result.disposal_plan.follow_up_actions:
                    st.write(f"- {action}")

            except Exception as e:  # UI boundary: report instead of crashing
                st.error(f"分析失败: {e!s}")

elif mode == "📊 批量分析":
    st.title("📊 批量推文分析")
    st.markdown("上传 CSV 文件或输入多条推文,进行批量情感分析。")

    # Input method selection
    input_method = st.radio("输入方式", ["手动输入", "CSV 上传"], horizontal=True)

    if input_method == "手动输入":
        st.markdown("### 输入推文(每行一条)")
        tweets_input = st.text_area(
            "推文列表",
            placeholder="@United Flight delayed again!\n@Southwest Great service!\n@American Baggage policy?",
            height=200,
        )

        if st.button("🔍 批量分析", type="primary") and tweets_input:
            lines = [line.strip() for line in tweets_input.split("\n") if line.strip()]

            if lines:
                with st.spinner("🤖 AI 正在分析..."):
                    results = []
                    for line in lines:
                        try:
                            # Simple heuristic: airline named in the tweet.
                            results.append(agent.analyze(line, _detect_airline(line)))
                        except Exception as e:  # keep going with other rows
                            st.warning(f"分析失败: {line[:50]}... - {e}")

                if results:
                    st.session_state.batch_results = results
                    st.success(f"✅ 成功分析 {len(results)} 条推文")

    else:  # CSV upload
        st.markdown("### 上传 CSV 文件")
        st.info("CSV 文件应包含以下列: `text` (推文内容), `airline` (航空公司)")

        uploaded_file = st.file_uploader("选择文件", type=["csv"])

        if uploaded_file and st.button("🔍 分析上传文件", type="primary"):
            import pandas as pd

            try:
                df = pd.read_csv(uploaded_file)
            except Exception as e:
                st.error(f"文件读取失败: {e!s}")
            else:
                if "text" not in df.columns:
                    st.error("CSV 文件必须包含 'text' 列")
                else:
                    has_airline = "airline" in df.columns
                    with st.spinner("🤖 AI 正在分析..."):
                        results = []
                        for _, row in df.iterrows():
                            raw_text = row["text"]
                            # Empty cells are read as NaN (a float): skip
                            # them instead of failing on slicing later.
                            if pd.isna(raw_text) or not str(raw_text).strip():
                                continue
                            text = str(raw_text)

                            raw_airline = row["airline"] if has_airline else None
                            if raw_airline is None or pd.isna(raw_airline):
                                airline = "United"
                            else:
                                airline = str(raw_airline)

                            try:
                                results.append(agent.analyze(text, airline))
                            except Exception as e:  # keep going with other rows
                                st.warning(f"分析失败: {text[:50]}... - {e}")

                    if results:
                        st.session_state.batch_results = results
                        st.success(f"✅ 成功分析 {len(results)} 条推文")

    # Display batch results
    batch_results = st.session_state.batch_results
    if batch_results:
        st.divider()
        st.subheader(f"📊 分析结果 ({len(batch_results)} 条)")

        # Summary statistics
        sentiments = [r.classification.sentiment for r in batch_results]

        col1, col2, col3 = st.columns(3)
        col1.metric("😠 负面", sentiments.count("negative"))
        col2.metric("😐 中性", sentiments.count("neutral"))
        col3.metric("😊 正面", sentiments.count("positive"))

        # Detailed results table
        st.markdown("### 详细结果")

        results_data = [
            {
                # Truncate long tweets so the table stays readable.
                "推文": r.tweet_text[:50] + "..." if len(r.tweet_text) > 50 else r.tweet_text,
                "航空公司": r.airline,
                "情感": f"{get_sentiment_emoji(r.classification.sentiment)} {r.classification.sentiment}",
                "置信度": f"{r.classification.confidence:.1%}",
                "优先级": r.disposal_plan.priority.upper(),
                "行动类型": r.disposal_plan.action_type,
            }
            for r in batch_results
        ]

        st.dataframe(results_data, use_container_width=True)

        # Clear button
        if st.button("🗑️ 清除结果"):
            st.session_state.batch_results = []
            st.rerun()

elif mode == "📈 数据概览":
    st.title("📈 数据集概览")
    st.markdown("查看训练数据集的统计信息。")

    try:
        import contextlib
        import io

        import polars as pl
        from src.tweet_data import load_cleaned_tweets, print_data_summary

        df = load_cleaned_tweets("data/Tweets_cleaned.csv")

        # print_data_summary() writes to stdout; capture it so the summary
        # is shown on the page rather than only in the server console.
        st.subheader("📊 数据统计")
        summary_buffer = io.StringIO()
        with contextlib.redirect_stdout(summary_buffer):
            print_data_summary(df, "数据集统计")
        st.text(summary_buffer.getvalue())

        # Display sample data
        st.subheader("📝 样本数据")
        sample_df = df.head(10).to_pandas()
        st.dataframe(sample_df, use_container_width=True)

        # Sentiment distribution chart
        st.subheader("📈 情感分布")
        sentiment_counts = df.group_by("airline_sentiment").agg(
            pl.col("airline_sentiment").count().alias("count")
        ).sort("count", descending=True)

        import plotly.express as px

        sentiment_df = sentiment_counts.to_pandas()
        fig = px.pie(
            sentiment_df,
            values="count",
            names="airline_sentiment",
            title="情感分布",
            color_discrete_map={
                "negative": "#ff6b6b",
                "neutral": "#ffd93d",
                "positive": "#6bcb77",
            },
        )
        st.plotly_chart(fig, use_container_width=True)

        # Airline distribution chart
        st.subheader("✈️ 航空公司分布")
        airline_counts = df.group_by("airline").agg(
            pl.col("airline").count().alias("count")
        ).sort("count", descending=True)

        airline_df = airline_counts.to_pandas()
        fig = px.bar(
            airline_df,
            x="airline",
            y="count",
            title="各航空公司推文数量",
            color="count",
            color_continuous_scale="Blues",
        )
        st.plotly_chart(fig, use_container_width=True)

    except Exception as e:  # UI boundary: missing file, missing plotly, ...
        st.error(f"数据加载失败: {e!s}")

# Footer
st.divider()
st.markdown(
    """
    <div style="text-align: center; color: gray;">
        航空推文情感分析 AI 助手 | 基于 VotingClassifier (LR + NB + RF + LightGBM + XGBoost)
    </div>
    """,
    unsafe_allow_html=True,
)

# diff --git a/src/train_tweet_ultimate.py b/src/train_tweet_ultimate.py
# new file mode 100644
# index 0000000..3cb98d9
# --- /dev/null
# +++ b/src/train_tweet_ultimate.py
# @@ -0,0 +1,287 @@
"""Training and loading of the tweet sentiment model.

Implements a sentiment classifier based on TF-IDF features and LightGBM.
"""

import warnings
from pathlib import Path
from typing import Optional

import numpy as np
import polars as pl
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import LabelEncoder
import lightgbm as lgb
import joblib

# Feature value used for airlines that were not present at training time.
UNKNOWN_AIRLINE_CODE = -1


class TweetSentimentModel:
    """Tweet sentiment classifier.

    Text is vectorised with TF-IDF, the airline is label-encoded into one
    extra numeric column, and a LightGBM classifier predicts the sentiment.
    """

    def __init__(
        self,
        tfidf_vectorizer: Optional[TfidfVectorizer] = None,
        label_encoder: Optional[LabelEncoder] = None,
        airline_encoder: Optional[LabelEncoder] = None,
        classifier: Optional[lgb.LGBMClassifier] = None,
    ):
        """Initialise the model, creating default components where omitted.

        Args:
            tfidf_vectorizer: TF-IDF vectoriser for the tweet text.
            label_encoder: Encoder for the sentiment labels.
            airline_encoder: Encoder for the airline names.
            classifier: LightGBM classifier.
        """
        self.tfidf_vectorizer = tfidf_vectorizer or TfidfVectorizer(
            max_features=5000,
            ngram_range=(1, 2),
            min_df=2,
            max_df=0.95,
        )
        self.label_encoder = label_encoder or LabelEncoder()
        self.airline_encoder = airline_encoder or LabelEncoder()
        self.classifier = classifier or lgb.LGBMClassifier(
            n_estimators=100,
            learning_rate=0.1,
            max_depth=6,
            random_state=42,
            verbose=-1,
        )
        self._is_fitted = False

    def fit(self, texts: np.ndarray, airlines: np.ndarray, sentiments: np.ndarray) -> "TweetSentimentModel":
        """Fit the encoders, the vectoriser and the classifier.

        Args:
            texts: Tweet texts.
            airlines: Airline name for each tweet.
            sentiments: Sentiment label for each tweet.

        Returns:
            The fitted model (``self``).
        """
        y = self.label_encoder.fit_transform(sentiments)
        airline_codes = self.airline_encoder.fit_transform(airlines)
        X_text = self.tfidf_vectorizer.fit_transform(texts)

        X = self._combine_features(X_text, airline_codes.reshape(-1, 1))
        self.classifier.fit(X, y)

        self._is_fitted = True
        return self

    def predict(self, texts: np.ndarray, airlines: np.ndarray) -> np.ndarray:
        """Predict the sentiment label of each tweet.

        Args:
            texts: Tweet texts.
            airlines: Airline name for each tweet.

        Returns:
            Array of decoded sentiment labels.

        Raises:
            ValueError: If the model has not been fitted.
        """
        y_pred = self.classifier.predict(self._build_features(texts, airlines))
        return self.label_encoder.inverse_transform(y_pred)

    def predict_proba(self, texts: np.ndarray, airlines: np.ndarray) -> np.ndarray:
        """Predict class probabilities for each tweet.

        Args:
            texts: Tweet texts.
            airlines: Airline name for each tweet.

        Returns:
            Probability array of shape (n_samples, n_classes).

        Raises:
            ValueError: If the model has not been fitted.
        """
        return self.classifier.predict_proba(self._build_features(texts, airlines))

    def _build_features(self, texts: np.ndarray, airlines: np.ndarray):
        """Build the inference feature matrix shared by both predict methods."""
        if not self._is_fitted:
            raise ValueError("模型尚未训练,请先调用 fit() 方法")

        X_text = self.tfidf_vectorizer.transform(texts)
        airline_features = self._encode_airlines(airlines).reshape(-1, 1)
        return self._combine_features(X_text, airline_features)

    def _encode_airlines(self, airlines: np.ndarray) -> np.ndarray:
        """Encode airline names, tolerating names unseen during training.

        ``LabelEncoder.transform`` raises on unseen labels, which would make
        inference fail for any airline missing from the training data.
        Unknown names are mapped to ``UNKNOWN_AIRLINE_CODE`` instead.
        """
        code_by_name = {
            name: code for code, name in enumerate(self.airline_encoder.classes_)
        }
        return np.array(
            [code_by_name.get(name, UNKNOWN_AIRLINE_CODE) for name in airlines],
            dtype=np.int64,
        )

    def _combine_features(self, text_features, airline_features):
        """Stack the text features and the airline column side by side.

        Args:
            text_features: Sparse TF-IDF matrix.
            airline_features: Airline codes as a single column.

        Returns:
            Sparse feature matrix in CSR format.
        """
        from scipy.sparse import hstack

        # hstack defaults to COO; request CSR, which supports row slicing.
        return hstack([text_features, airline_features], format="csr")

    def save(self, path: Path) -> None:
        """Persist all fitted components to *path* with joblib.

        Args:
            path: Destination file; parent directories are created.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        model_data = {
            "tfidf_vectorizer": self.tfidf_vectorizer,
            "label_encoder": self.label_encoder,
            "airline_encoder": self.airline_encoder,
            "classifier": self.classifier,
            "is_fitted": self._is_fitted,
        }

        joblib.dump(model_data, path)

    @classmethod
    def load(cls, path: Path) -> "TweetSentimentModel":
        """Load a model previously written by :meth:`save`.

        Args:
            path: Path of the saved model file.

        Returns:
            The restored model.
        """
        # SECURITY: joblib uses pickle; only load files from trusted sources.
        model_data = joblib.load(path)

        model = cls(
            tfidf_vectorizer=model_data["tfidf_vectorizer"],
            label_encoder=model_data["label_encoder"],
            airline_encoder=model_data["airline_encoder"],
            classifier=model_data["classifier"],
        )
        model._is_fitted = model_data["is_fitted"]

        return model


def load_model(model_path: Optional[Path] = None) -> TweetSentimentModel:
    """Load a saved model, or fall back to the built-in example model.

    Args:
        model_path: Path of a saved model (``str`` or ``Path``). When omitted
            or missing on disk, a small example model is trained instead.

    Returns:
        A fitted model.
    """
    if model_path is not None:
        model_path = Path(model_path)
        if model_path.exists():
            return TweetSentimentModel.load(model_path)
        # Do not hide a wrong path: the fallback model is only a demo.
        warnings.warn(
            f"Model file not found: {model_path}; using the built-in example model.",
            RuntimeWarning,
            stacklevel=2,
        )

    return _create_example_model()


def _create_example_model() -> TweetSentimentModel:
    """Train a tiny demonstration model on ten hard-coded tweets.

    Returns:
        The fitted example model.
    """
    texts = np.array([
        "@United This is the worst airline ever! My flight was delayed for 5 hours and no one helped!",
        "@Southwest Thank you for the amazing flight! The crew was so helpful and friendly.",
        "@American What is the baggage policy for international flights?",
        "@Delta Terrible service! Lost my luggage and no response from customer support.",
        "@JetBlue Great experience! On time departure and friendly staff.",
        "@United Why is my flight cancelled again? This is unacceptable!",
        "@Southwest Love the free snacks and great customer service!",
        "@American Can you help me with my booking?",
        "@Delta Worst experience ever! Will never fly again!",
        "@JetBlue Thank you for the smooth flight and excellent service!",
    ])

    airlines = np.array([
        "United",
        "Southwest",
        "American",
        "Delta",
        "JetBlue",
        "United",
        "Southwest",
        "American",
        "Delta",
        "JetBlue",
    ])

    sentiments = np.array([
        "negative",
        "positive",
        "neutral",
        "negative",
        "positive",
        "negative",
        "positive",
        "neutral",
        "negative",
        "positive",
    ])

    # LightGBM's defaults (min_child_samples=20, min_data_in_bin=3) forbid
    # any split on ten rows, so the model would only output class priors.
    # Relax them for this tiny data set.
    classifier = lgb.LGBMClassifier(
        n_estimators=100,
        learning_rate=0.1,
        max_depth=6,
        random_state=42,
        verbose=-1,
        min_child_samples=1,
        min_data_in_bin=1,
    )

    model = TweetSentimentModel(classifier=classifier)
    model.fit(texts, airlines, sentiments)

    return model


if __name__ == "__main__":
    # Example: load the model and run a few predictions.
    print("加载模型...")
    model = load_model()

    print("\n测试预测...")
    test_texts = np.array([
        "@United This is terrible!",
        "@Southwest Thank you so much!",
        "@American How do I check in?",
    ])
    test_airlines = np.array(["United", "Southwest", "American"])

    predictions = model.predict(test_texts, test_airlines)
    probabilities = model.predict_proba(test_texts, test_airlines)

    for text, airline, pred, prob in zip(test_texts, test_airlines, predictions, probabilities):
        print(f"\n文本: {text}")
        print(f"航空公司: {airline}")
        print(f"预测: {pred}")
        print(f"概率: {prob}")

# diff --git a/src/tweet_agent.py b/src/tweet_agent.py
# new file mode 100644
# index 0000000..7ba3589
# --- /dev/null
# +++ b/src/tweet_agent.py
# @@ -0,0 +1,345 @@
"""Tweet sentiment analysis agent.

Implements the "classify -> explain -> generate disposal plan" pipeline
and returns structured results.
"""

from pathlib import Path
from typing import Optional

import numpy as np
import polars as pl

from pydantic import BaseModel, Field

from src.tweet_data import load_cleaned_tweets
from src.train_tweet_ultimate import load_model as load_ultimate_model


class SentimentClassification(BaseModel):
    """Sentiment classification result."""
    sentiment: str = Field(description="情感类别: negative/neutral/positive")
    confidence: float = Field(description="置信度 (0-1)")


class SentimentExplanation(BaseModel):
    """Explanation of a sentiment decision."""
    key_factors: list[str] = Field(description="影响情感判断的关键因素")
    reasoning: str = Field(description="情感判断的推理过程")
class DisposalPlan(BaseModel):
    """Disposal plan for a tweet."""
    priority: str = Field(description="处理优先级: high/medium/low")
    action_type: str = Field(description="行动类型: response/investigate/monitor/ignore")
    suggested_response: Optional[str] = Field(description="建议回复内容(如适用)", default=None)
    follow_up_actions: list[str] = Field(description="后续行动建议")


class TweetAnalysisResult(BaseModel):
    """Structured result of a full tweet analysis."""
    tweet_text: str = Field(description="原始推文文本")
    airline: str = Field(description="航空公司")
    classification: SentimentClassification = Field(description="情感分类结果")
    explanation: SentimentExplanation = Field(description="情感解释")
    disposal_plan: DisposalPlan = Field(description="处置方案")


class TweetSentimentAgent:
    """Tweet sentiment analysis agent.

    Runs the "classify -> explain -> generate disposal plan" pipeline.
    """

    # Keyword lists that drive the rule-based explanation.
    _NEGATIVE_WORDS = (
        "bad", "terrible", "awful", "worst", "hate", "angry", "disappointed",
        "frustrated", "cancelled", "delayed", "lost", "rude",
    )
    _POSITIVE_WORDS = (
        "good", "great", "excellent", "best", "love", "happy", "satisfied",
        "amazing", "wonderful", "thank", "helpful",
    )
    _NEUTRAL_WORDS = (
        "question", "how", "what", "when", "where", "why", "please", "help",
        "info", "information",
    )

    def __init__(self, model_path: Optional[Path] = None):
        """Initialise the agent.

        Args:
            model_path: Optional path of a saved model; it is forwarded to
                the model loader.
        """
        self.model = load_ultimate_model(model_path)
        self.label_encoder = self.model.label_encoder
        self.tfidf_vectorizer = self.model.tfidf_vectorizer
        self.airline_encoder = self.model.airline_encoder

    def classify(self, text: str, airline: str) -> SentimentClassification:
        """Classify the sentiment of a tweet.

        Args:
            text: Tweet text.
            airline: Airline name.

        Returns:
            The predicted label and its probability.
        """
        # One model pass: the label is the most probable class.
        # NOTE(review): assumes the probability columns follow the label
        # encoder's class order, as the original lookup did.
        proba = self.model.predict_proba(np.array([text]), np.array([airline]))[0]
        best_idx = int(np.argmax(proba))
        sentiment = str(self.label_encoder.inverse_transform([best_idx])[0])

        return SentimentClassification(
            sentiment=sentiment,
            confidence=float(proba[best_idx]),
        )

    @staticmethod
    def _find_keywords(text_lower: str, keywords) -> list[str]:
        """Return the keywords occurring in *text_lower* as whole words.

        Common inflections (s, d, ed, ing, ly) are accepted, but a keyword
        embedded in another word is not: "unhappy" is not "happy" and
        "show" is not "how".
        """
        import re

        return [
            word
            for word in keywords
            if re.search(rf"\b{re.escape(word)}(?:s|d|ed|ing|ly)?\b", text_lower)
        ]

    def explain(self, text: str, airline: str, classification: SentimentClassification) -> SentimentExplanation:
        """Build a rule-based explanation for a sentiment decision.

        Args:
            text: Tweet text.
            airline: Airline name.
            classification: Classification result (not used by the rules).

        Returns:
            Key factors and a reasoning sentence.
        """
        key_factors = []
        reasoning_parts = []

        text_lower = text.lower()

        found_negative = self._find_keywords(text_lower, self._NEGATIVE_WORDS)
        found_positive = self._find_keywords(text_lower, self._POSITIVE_WORDS)
        found_neutral = self._find_keywords(text_lower, self._NEUTRAL_WORDS)

        # At most three keywords per category are reported.
        if found_negative:
            key_factors.append(f"包含负面词汇: {', '.join(found_negative[:3])}")
            reasoning_parts.append("文本中包含多个负面情感词汇,表达不满情绪")

        if found_positive:
            key_factors.append(f"包含正面词汇: {', '.join(found_positive[:3])}")
            reasoning_parts.append("文本中包含正面情感词汇,表达满意或感谢")

        if found_neutral:
            key_factors.append(f"包含中性词汇: {', '.join(found_neutral[:3])}")
            reasoning_parts.append("文本主要包含询问或请求,情绪相对中性")

        # Punctuation and mention cues.
        if "!" in text:
            key_factors.append("包含感叹号")
            reasoning_parts.append("感叹号的使用表明情绪较为强烈")

        if "?" in text:
            key_factors.append("包含问号")
            reasoning_parts.append("问号的使用表明存在疑问或询问")

        if "@" in text:
            key_factors.append("包含@提及")
            reasoning_parts.append("直接@航空公司表明希望获得关注或回复")

        key_factors.append(f"涉及航空公司: {airline}")

        # Generic fallback when no rule fired.
        if not reasoning_parts:
            reasoning_parts.append("根据文本整体语义和情感特征进行判断")

        reasoning = "。".join(reasoning_parts) + "。"

        return SentimentExplanation(
            key_factors=key_factors,
            reasoning=reasoning,
        )

    def generate_disposal_plan(
        self,
        text: str,
        airline: str,
        classification: SentimentClassification,
        explanation: SentimentExplanation,
    ) -> DisposalPlan:
        """Derive a disposal plan from the sentiment and its confidence.

        Args:
            text: Tweet text.
            airline: Airline name.
            classification: Classification result.
            explanation: Explanation (not used by the rules).

        Returns:
            Priority, action type, optional reply and follow-up actions.
        """
        sentiment = classification.sentiment
        confidence = classification.confidence

        if sentiment == "negative":
            if confidence >= 0.8:
                priority = "high"
                action_type = "response"
                suggested_response = self._generate_negative_response(text, airline)
                follow_up_actions = [
                    "记录客户投诉详情",
                    "转交相关部门处理",
                    "跟进处理进度",
                    "在24小时内给予反馈",
                ]
            else:
                priority = "medium"
                action_type = "investigate"
                suggested_response = None
                follow_up_actions = [
                    "进一步核实情况",
                    "根据核实结果决定是否需要回复",
                ]
        elif sentiment == "positive":
            if confidence >= 0.8:
                priority = "low"
                action_type = "response"
                suggested_response = self._generate_positive_response(text, airline)
                follow_up_actions = [
                    "感谢客户反馈",
                    "分享正面评价至内部团队",
                    "考虑在官方渠道展示",
                ]
            else:
                priority = "low"
                action_type = "monitor"
                suggested_response = None
                follow_up_actions = [
                    "持续关注该用户后续动态",
                ]
        else:  # neutral
            # Questions and help requests deserve an answer.
            if "?" in text or "help" in text.lower():
                priority = "medium"
                action_type = "response"
                suggested_response = self._generate_neutral_response(text, airline)
                follow_up_actions = [
                    "提供准确信息",
                    "确保客户问题得到解答",
                ]
            else:
                priority = "low"
                action_type = "monitor"
                suggested_response = None
                follow_up_actions = [
                    "持续关注",
                ]

        return DisposalPlan(
            priority=priority,
            action_type=action_type,
            suggested_response=suggested_response,
            follow_up_actions=follow_up_actions,
        )

    @staticmethod
    def _pick_response(text: str, responses: list[str]) -> str:
        """Pick a reply template for *text*, stable across processes.

        The built-in ``hash()`` of a string is salted per process, so the
        same tweet could get a different reply after a restart. CRC32 of
        the text is deterministic.
        """
        import zlib

        checksum = zlib.crc32(text.encode("utf-8", errors="replace"))
        return responses[checksum % len(responses)]

    def _generate_negative_response(self, text: str, airline: str) -> str:
        """Return a canned reply for a negative tweet."""
        responses = [
            "感谢您的反馈。我们非常重视您提到的问题,将立即进行调查并尽快给您答复。",
            "对于您的不愉快体验,我们深表歉意。请私信我们详细情况,我们将全力为您解决。",
            "收到您的反馈,我们对此感到抱歉。相关部门已介入,将尽快处理并给您满意的答复。",
        ]
        return self._pick_response(text, responses)

    def _generate_positive_response(self, text: str, airline: str) -> str:
        """Return a canned reply for a positive tweet."""
        responses = [
            "感谢您的认可和支持!我们会继续努力为您提供更好的服务。",
            "很高兴听到您的正面反馈!您的满意是我们前进的动力。",
            "感谢您的分享!我们会将您的反馈传达给团队,激励我们做得更好。",
        ]
        return self._pick_response(text, responses)

    def _generate_neutral_response(self, text: str, airline: str) -> str:
        """Return a canned reply for a neutral tweet."""
        responses = [
            "感谢您的询问。请问您需要了解哪方面的信息?我们将竭诚为您解答。",
            "收到您的问题。请提供更多细节,以便我们更好地为您提供帮助。",
        ]
        return self._pick_response(text, responses)

    def analyze(self, text: str, airline: str) -> TweetAnalysisResult:
        """Run the full pipeline: classify, explain, generate a plan.

        Args:
            text: Tweet text.
            airline: Airline name.

        Returns:
            The complete structured result.
        """
        classification = self.classify(text, airline)
        explanation = self.explain(text, airline, classification)
        disposal_plan = self.generate_disposal_plan(text, airline, classification, explanation)

        return TweetAnalysisResult(
            tweet_text=text,
            airline=airline,
            classification=classification,
            explanation=explanation,
            disposal_plan=disposal_plan,
        )


# Agent shared by analyze_tweet(); created on first use.
_DEFAULT_AGENT: Optional[TweetSentimentAgent] = None


def _get_default_agent() -> TweetSentimentAgent:
    """Return the shared agent, building it on first use."""
    global _DEFAULT_AGENT
    if _DEFAULT_AGENT is None:
        _DEFAULT_AGENT = TweetSentimentAgent()
    return _DEFAULT_AGENT


def analyze_tweet(text: str, airline: str) -> TweetAnalysisResult:
    """Analyse a single tweet.

    The agent is built once and reused, so repeated calls do not reload
    or retrain the model.

    Args:
        text: Tweet text.
        airline: Airline name.

    Returns:
        The analysis result.
    """
    return _get_default_agent().analyze(text, airline)


def analyze_tweets_batch(texts: list[str], airlines: list[str]) -> list[TweetAnalysisResult]:
    """Analyse several tweets with one agent.

    Args:
        texts: Tweet texts.
        airlines: Airline name for each tweet, same length as *texts*.

    Returns:
        One result per tweet, in input order.

    Raises:
        ValueError: If *texts* and *airlines* differ in length.
    """
    # zip() would silently drop the unmatched tail.
    if len(texts) != len(airlines):
        raise ValueError(
            f"texts and airlines must have the same length "
            f"({len(texts)} != {len(airlines)})"
        )

    agent = TweetSentimentAgent()
    return [agent.analyze(text, airline) for text, airline in zip(texts, airlines)]


if __name__ == "__main__":
    # Example: analyse one tweet of each sentiment.
    print(">>> 示例 1: 负面情感")
    result = analyze_tweet(
        text="@United This is the worst airline ever! My flight was delayed for 5 hours and no one helped!",
        airline="United",
    )
    print(result.model_dump_json(indent=2))

    print("\n>>> 示例 2: 正面情感")
    result = analyze_tweet(
        text="@Southwest Thank you for the amazing flight! The crew was so helpful and friendly.",
        airline="Southwest",
    )
    print(result.model_dump_json(indent=2))

    print("\n>>> 示例 3: 中性情感")
    result = analyze_tweet(
        text="@American What is the baggage policy for international flights?",
        airline="American",
    )
    print(result.model_dump_json(indent=2))

# diff --git a/src/tweet_data.py b/src/tweet_data.py
# new file mode 100644
# index 0000000..72d44b7
# --- /dev/null
# +++ b/src/tweet_data.py
# @@ -0,0 +1,315 @@
"""Text cleaning for the Tweets.csv airline sentiment data set.

The cleaning is deliberately restrained: only necessary preprocessing is
done so that the meaning of the text is preserved.

Strategy:
1. Normalisation: lower-case only (no lemmatisation or stemming).
2. Noise removal: user mentions (@username), URLs, redundant whitespace.
3. Kept as is: emoji and punctuation, which carry sentiment.
4. No stop-word removal: negations such as "not" and "don't" matter.
"""

import re
from pathlib import Path

import pandera.polars as pa
import polars as pl

# "@name" not preceded by a word character, so e-mail addresses are kept.
_MENTION_RE = re.compile(r"(?<!\w)@\w+")
# URLs starting a token; the word boundary keeps words like "awwww" intact.
_URL_RE = re.compile(r"\bhttp\S+|\bwww\.\S+")
_WHITESPACE_RE = re.compile(r"\s+")


# --- Pandera schema definitions ---


class RawTweetSchema(pa.DataFrameModel):
    """Schema of the raw tweets, checked before cleaning.

    Missing values are allowed; this only validates the basic structure
    of the data as read from disk.
    """
    tweet_id: int = pa.Field(nullable=False)
    airline_sentiment: str = pa.Field(nullable=True)
    airline_sentiment_confidence: float = pa.Field(ge=0, le=1, nullable=True)
    negativereason: str = pa.Field(nullable=True)
    negativereason_confidence: float = pa.Field(ge=0, le=1, nullable=True)
    airline: str = pa.Field(nullable=True)
    text: str = pa.Field(nullable=True)
    tweet_coord: str = pa.Field(nullable=True)
    tweet_created: str = pa.Field(nullable=True)
    tweet_location: str = pa.Field(nullable=True)
    user_timezone: str = pa.Field(nullable=True)

    class Config:
        strict = False
        coerce = True


class CleanTweetSchema(pa.DataFrameModel):
    """Schema of the cleaned tweets (strict mode).

    Required columns must not contain missing values and categorical
    columns are restricted to their known values.
    """
    tweet_id: int = pa.Field(nullable=False)
    airline_sentiment: str = pa.Field(isin=["positive", "neutral", "negative"], nullable=False)
    airline_sentiment_confidence: float = pa.Field(ge=0, le=1, nullable=False)
    negativereason: str = pa.Field(nullable=True)
    negativereason_confidence: float = pa.Field(ge=0, le=1, nullable=True)
    airline: str = pa.Field(isin=["Virgin America", "United", "Southwest", "Delta", "US Airways", "American"], nullable=False)
    text_cleaned: str = pa.Field(nullable=False)
    text_original: str = pa.Field(nullable=False)

    class Config:
        strict = True
        coerce = True


# --- Text cleaning functions ---


def clean_text(text: str) -> str:
    """Remove mentions, URLs and redundant whitespace from *text*.

    Emoji, punctuation, negations and the original casing are kept; no
    lemmatisation, stemming or stop-word removal is performed.

    Args:
        text: Raw text.

    Returns:
        The cleaned text; "" for empty or non-string input.
    """
    if not text or not isinstance(text, str):
        return ""

    text = _MENTION_RE.sub("", text)
    text = _URL_RE.sub("", text)
    # Collapse whitespace runs (including newlines) into single spaces.
    return _WHITESPACE_RE.sub(" ", text).strip()


def normalize_text(text: str) -> str:
    """Lower-case *text* without lemmatisation or stemming.

    Args:
        text: Cleaned text.

    Returns:
        The lower-cased text; "" for empty or non-string input.
    """
    if not text or not isinstance(text, str):
        return ""

    return text.lower()


# --- Data loading and preprocessing ---


def load_tweets(file_path: str | Path = "Tweets.csv") -> pl.DataFrame:
    """Load the raw tweets.

    Args:
        file_path: Path of the CSV file.

    Returns:
        The raw tweets.
    """
    return pl.read_csv(file_path)


def validate_raw_tweets(df: pl.DataFrame) -> pl.DataFrame:
    """Validate the structure of the raw tweets (before cleaning).

    Args:
        df: Raw Polars DataFrame.

    Returns:
        The validated DataFrame.

    Raises:
        pa.errors.SchemaError: If validation fails.
    """
    return RawTweetSchema.validate(df)


def validate_clean_tweets(df: pl.DataFrame) -> pl.DataFrame:
    """Validate the cleaned tweets (strict mode).

    Args:
        df: Cleaned Polars DataFrame.

    Returns:
        The validated DataFrame.

    Raises:
        pa.errors.SchemaError: If validation fails.
    """
    return CleanTweetSchema.validate(df)


def preprocess_tweets(
    df: pl.DataFrame,
    validate: bool = True,
    min_confidence: float = 0.5
) -> pl.DataFrame:
    """Run the tweet preprocessing pipeline.

    Steps:
    1. Keep rows whose sentiment confidence is >= *min_confidence*.
    2. Clean and lower-case the text.
    3. Drop rows with empty text or a missing sentiment.
    4. Drop duplicate tweets (by tweet_id), keeping the first occurrence.
    5. Select the output columns.
    6. Optionally validate against the clean schema.

    Args:
        df: Raw Polars DataFrame.
        validate: Whether to validate the result against the clean schema.
        min_confidence: Minimum sentiment confidence.

    Returns:
        The cleaned DataFrame, in the original row order.
    """
    # 1. Keep high-confidence rows.
    df_filtered = df.filter(
        pl.col("airline_sentiment_confidence") >= min_confidence
    )

    # 2. Clean and normalise the text in a single pass.
    df_clean = df_filtered.with_columns([
        pl.col("text")
        .map_elements(lambda raw: normalize_text(clean_text(raw)), return_dtype=pl.String)
        .alias("text_cleaned"),
        pl.col("text").alias("text_original"),
    ])

    # 3. Drop rows with empty text or a missing sentiment.
    df_clean = df_clean.filter(
        (pl.col("text_cleaned").is_not_null()) &
        (pl.col("text_cleaned") != "") &
        (pl.col("airline_sentiment").is_not_null())
    )

    # 4. Deduplicate; maintain_order makes the output reproducible.
    df_clean = df_clean.unique(subset=["tweet_id"], keep="first", maintain_order=True)

    # 5. Select the output columns.
    df_clean = df_clean.select([
        "tweet_id",
        "airline_sentiment",
        "airline_sentiment_confidence",
        "negativereason",
        "negativereason_confidence",
        "airline",
        "text_cleaned",
        "text_original",
    ])

    # 6. Optional validation.
    if validate:
        df_clean = validate_clean_tweets(df_clean)

    return df_clean


def save_cleaned_tweets(df: pl.DataFrame, output_path: str | Path = "data/Tweets_cleaned.csv") -> None:
    """Write the cleaned tweets to CSV, creating parent directories.

    Args:
        df: Cleaned Polars DataFrame.
        output_path: Destination file path.
    """
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    df.write_csv(output_path)
    print(f"清洗后的数据已保存至 {output_path}")


def load_cleaned_tweets(file_path: str | Path = "data/Tweets_cleaned.csv") -> pl.DataFrame:
    """Load the cleaned tweets.

    Args:
        file_path: Path of the cleaned CSV file.

    Returns:
        The cleaned tweets.
    """
    return pl.read_csv(file_path)


# --- Statistics ---


def print_data_summary(df: pl.DataFrame, title: str = "数据统计") -> None:
    """Print sample count, label distributions and text-length statistics.

    Args:
        df: Polars DataFrame with airline_sentiment, airline and
            text_cleaned columns.
        title: Heading of the report.
    """
    total = len(df)
    separator = "=" * 60

    print(f"\n{separator}")
    print(f"{title}")
    print(f"{separator}")
    print(f"样本总数: {total}")

    print("\n情感分布:")
    print(df.group_by("airline_sentiment").agg(
        pl.len().alias("count"),
        (pl.len() / total * 100).alias("percentage")
    ).sort("count", descending=True))

    print("\n航空公司分布:")
    print(df.group_by("airline").agg(
        pl.len().alias("count"),
        (pl.len() / total * 100).alias("percentage")
    ).sort("count", descending=True))

    print("\n文本长度统计:")
    df_with_length = df.with_columns([
        pl.col("text_cleaned").str.len_chars().alias("text_length")
    ])
    print(df_with_length.select([
        pl.col("text_length").min().alias("最小长度"),
        pl.col("text_length").max().alias("最大长度"),
        pl.col("text_length").mean().alias("平均长度"),
        pl.col("text_length").median().alias("中位数长度"),
    ]))


if __name__ == "__main__":
    print(">>> 1. 加载原始数据")
    df_raw = load_tweets("Tweets.csv")
    print(f"原始数据样本数: {len(df_raw)}")
    print(df_raw.head(3))

    print("\n>>> 2. 验证原始数据")
    df_validated = validate_raw_tweets(df_raw)
    print("✅ 原始数据验证通过")

    print("\n>>> 3. 清洗数据")
    df_clean = preprocess_tweets(df_validated, validate=True, min_confidence=0.5)
    print(f"清洗后样本数: {len(df_clean)} (原始: {len(df_raw)})")
    print("✅ 清洗后数据验证通过")

    print("\n>>> 4. 保存清洗后的数据")
    save_cleaned_tweets(df_clean, "data/Tweets_cleaned.csv")

    print("\n>>> 5. 数据统计")
    print_data_summary(df_clean, "清洗后数据统计")

    print("\n>>> 6. 清洗示例对比")
    # Print straight from Polars; no pandas round trip is needed.
    print("\n原始文本:")
    print("\n".join(df_clean["text_original"].head(3).to_list()))
    print("\n清洗后文本:")
    print("\n".join(df_clean["text_cleaned"].head(3).to_list()))