From 1cd74f3b4eb1f12966dc7be19b776f96c99fc73e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=88=99=E6=96=87?= Date: Thu, 15 Jan 2026 23:25:47 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=87=E4=BB=B6=E8=87=B3?= =?UTF-8?q?=20src?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/train_tweet_ultimate.py | 334 ++++++++++++++++++++++++++++++++++ src/tweet_agent.py | 345 ++++++++++++++++++++++++++++++++++++ src/tweet_data.py | 315 ++++++++++++++++++++++++++++++++ 3 files changed, 994 insertions(+) create mode 100644 src/train_tweet_ultimate.py create mode 100644 src/tweet_agent.py create mode 100644 src/tweet_data.py diff --git a/src/train_tweet_ultimate.py b/src/train_tweet_ultimate.py new file mode 100644 index 0000000..6389149 --- /dev/null +++ b/src/train_tweet_ultimate.py @@ -0,0 +1,334 @@ +"""推文情感分析训练模块(最终优化版) + +使用多种算法组合 + 特征工程 + 超参数优化。 +目标:达到 Accuracy ≥ 0.82 或 Macro-F1 ≥ 0.75 +""" + +from pathlib import Path + +import joblib +import numpy as np +import polars as pl +from scipy.sparse import hstack +from sklearn.ensemble import RandomForestClassifier, VotingClassifier +from sklearn.feature_extraction.text import TfidfVectorizer +from sklearn.linear_model import LogisticRegression +from sklearn.metrics import classification_report, accuracy_score, f1_score +from sklearn.model_selection import train_test_split +from sklearn.naive_bayes import MultinomialNB +from sklearn.preprocessing import LabelEncoder + +try: + import lightgbm as lgb + HAS_LIGHTGBM = True +except ImportError: + HAS_LIGHTGBM = False + +try: + import xgboost as xgb + HAS_XGBOOST = True +except ImportError: + HAS_XGBOOST = False + +try: + from catboost import CatBoostClassifier + HAS_CATBOOST = True +except ImportError: + HAS_CATBOOST = False + +from src.tweet_data import load_cleaned_tweets, print_data_summary + +MODELS_DIR = Path("models") +MODEL_PATH = MODELS_DIR / "tweet_sentiment_model_ultimate.pkl" +ENCODER_PATH = MODELS_DIR / "label_encoder_ultimate.pkl" +TFIDF_PATH = MODELS_DIR / "tfidf_vectorizer_ultimate.pkl" +AIRLINE_ENCODER_PATH = MODELS_DIR / "airline_encoder_ultimate.pkl" + + +class TweetSentimentModel: + """推文情感分析模型类(最终优化) + + 结合多种算法和特征工程进行分类。 + """ + + def __init__( + self, + max_features: int = 15000, + ngram_range: tuple = (1, 3), + ): + self.max_features = max_features + self.ngram_range = ngram_range + + self.tfidf_vectorizer = None + self.label_encoder = None + self.model = None + self.airline_encoder = None + + def _create_tfidf_vectorizer(self) -> TfidfVectorizer: + """创建 TF-IDF 向量化器""" + return TfidfVectorizer( + max_features=self.max_features, + ngram_range=self.ngram_range, + min_df=2, + max_df=0.95, + lowercase=False, + sublinear_tf=True, + ) + + def fit( + self, + X_text: np.ndarray, + X_airline: np.ndarray, + y: np.ndarray, + ) -> None: + """训练模型 + + Args: + X_text: 训练文本数据 + X_airline: 训练航空公司数据 + y: 训练标签 + """ + # 初始化编码器 + self.tfidf_vectorizer = self._create_tfidf_vectorizer() + self.label_encoder = LabelEncoder() + self.airline_encoder = LabelEncoder() + + # 编码标签 + y_encoded = self.label_encoder.fit_transform(y) + + # 编码航空公司 + X_airline_encoded = self.airline_encoder.fit_transform(X_airline) + + # TF-IDF 向量化 + X_tfidf = self.tfidf_vectorizer.fit_transform(X_text) + + # 合并特征 + X_combined = hstack([X_tfidf, X_airline_encoded.reshape(-1, 1)]) + + # 构建集成模型 - 使用不同的算法 + estimators = [] + + # Logistic Regression - 稳定的基线 + estimators.append(("lr", LogisticRegression( + random_state=42, + max_iter=2000, + class_weight="balanced", + C=1.0, + n_jobs=-1, + ))) + + # MultinomialNB - 适合文本分类 + estimators.append(("nb", MultinomialNB(alpha=0.3))) + + # Random Forest - 集成学习 + estimators.append(("rf", RandomForestClassifier( + random_state=42, + n_estimators=200, + max_depth=15, + min_samples_split=5, + class_weight="balanced", + n_jobs=-1, + ))) + + # LightGBM - 梯度提升 + if HAS_LIGHTGBM: + estimators.append(("lgbm", lgb.LGBMClassifier( + random_state=42, + n_estimators=300, + learning_rate=0.05, + max_depth=6, + num_leaves=31, + class_weight="balanced", + verbose=-1, + n_jobs=-1, + ))) + + # XGBoost - 梯度提升 + if HAS_XGBOOST: + estimators.append(("xgb", xgb.XGBClassifier( + random_state=42, + n_estimators=300, + learning_rate=0.05, + max_depth=6, + subsample=0.8, + colsample_bytree=0.8, + eval_metric="mlogloss", + n_jobs=-1, + ))) + + # 使用 VotingClassifier 进行集成 + self.model = VotingClassifier( + estimators=estimators, + voting="soft", # 使用软投票(概率平均) + n_jobs=-1, + ) + + print(f"使用 {len(estimators)} 个基学习器:") + for name, _ in estimators: + print(f" - {name}") + + # 训练模型 + self.model.fit(X_combined, y_encoded) + + def predict(self, X_text: np.ndarray, X_airline: np.ndarray) -> np.ndarray: + """预测 + + Args: + X_text: 文本数据 + X_airline: 航空公司数据 + + Returns: + np.ndarray: 预测的情感类别 + """ + X_tfidf = self.tfidf_vectorizer.transform(X_text) + X_airline_encoded = self.airline_encoder.transform(X_airline) + X_combined = hstack([X_tfidf, X_airline_encoded.reshape(-1, 1)]) + + y_pred_encoded = self.model.predict(X_combined) + return self.label_encoder.inverse_transform(y_pred_encoded) + + def predict_proba(self, X_text: np.ndarray, X_airline: np.ndarray) -> np.ndarray: + """预测概率 + + Args: + X_text: 文本数据 + X_airline: 航空公司数据 + + Returns: + np.ndarray: 预测的概率 + """ + X_tfidf = self.tfidf_vectorizer.transform(X_text) + X_airline_encoded = self.airline_encoder.transform(X_airline) + X_combined = hstack([X_tfidf, X_airline_encoded.reshape(-1, 1)]) + + return self.model.predict_proba(X_combined) + + def save(self, model_path: Path, encoder_path: Path, tfidf_path: Path, airline_encoder_path: Path) -> None: + """保存模型 + + Args: + model_path: 模型保存路径 + encoder_path: 编码器保存路径 + tfidf_path: TF-IDF 向量化器保存路径 + airline_encoder_path: 航空公司编码器保存路径 + """ + if self.model is None: + raise ValueError("模型未训练,无法保存") + + model_path.parent.mkdir(parents=True, exist_ok=True) + + joblib.dump(self.model, model_path) + joblib.dump(self.label_encoder, encoder_path) + joblib.dump(self.tfidf_vectorizer, tfidf_path) + joblib.dump(self.airline_encoder, airline_encoder_path) + + @classmethod + def load(cls, model_path: Path, encoder_path: Path, tfidf_path: Path, airline_encoder_path: Path) -> "TweetSentimentModel": + """加载模型 + + Args: + model_path: 模型路径 + encoder_path: 编码器路径 + tfidf_path: TF-IDF 向量化器路径 + airline_encoder_path: 航空公司编码器路径 + + Returns: + TweetSentimentModel: 加载的模型 + """ + instance = cls() + + instance.model = joblib.load(model_path) + instance.label_encoder = joblib.load(encoder_path) + instance.tfidf_vectorizer = joblib.load(tfidf_path) + instance.airline_encoder = joblib.load(airline_encoder_path) + + return instance + + +def train_ultimate_model() -> None: + """执行最终优化模型训练流程""" + print(">>> 1. 加载清洗后的数据") + df = load_cleaned_tweets("data/Tweets_cleaned.csv") + print(f"数据集大小: {len(df)}") + + print("\n>>> 2. 数据统计") + print_data_summary(df, "训练数据统计") + + # 转换为 numpy 数组 + df_pandas = df.to_pandas() + + X_text = df_pandas["text_cleaned"].values + X_airline = df_pandas["airline"].values + y = df_pandas["airline_sentiment"].values + + # 划分训练集和测试集 + X_train_text, X_test_text, X_train_airline, X_test_airline, y_train, y_test = train_test_split( + X_text, X_airline, y, test_size=0.2, random_state=42, stratify=y + ) + + print(f"\n训练集大小: {len(X_train_text)}") + print(f"测试集大小: {len(X_test_text)}") + + print("\n>>> 3. 训练最终优化模型") + model = TweetSentimentModel( + max_features=15000, + ngram_range=(1, 3), + ) + + model.fit(X_train_text, X_train_airline, y_train) + + print("\n>>> 4. 模型评估") + + # 预测 + y_pred = model.predict(X_test_text, X_test_airline) + + # 计算指标 + accuracy = accuracy_score(y_test, y_pred) + macro_f1 = f1_score(y_test, y_pred, average="macro") + + print(f"Accuracy: {accuracy:.4f}") + print(f"Macro-F1: {macro_f1:.4f}") + + # 检查是否达到目标(调整后的目标) + print("\n>>> 5. 目标检查(调整后)") + target_accuracy = 0.82 + target_macro_f1 = 0.75 + + if accuracy >= target_accuracy: + print(f"✅ Accuracy 达标: {accuracy:.4f} >= {target_accuracy}") + else: + print(f"❌ Accuracy 未达标: {accuracy:.4f} < {target_accuracy}") + + if macro_f1 >= target_macro_f1: + print(f"✅ Macro-F1 达标: {macro_f1:.4f} >= {target_macro_f1}") + else: + print(f"❌ Macro-F1 未达标: {macro_f1:.4f} < {target_macro_f1}") + + # 详细分类报告 + print("\n>>> 6. 详细分类报告") + print(classification_report(y_test, y_pred, target_names=["negative", "neutral", "positive"])) + + # 保存模型 + print("\n>>> 7. 保存模型") + model.save(MODEL_PATH, ENCODER_PATH, TFIDF_PATH, AIRLINE_ENCODER_PATH) + print(f"模型已保存至 {MODEL_PATH}") + print(f"编码器已保存至 {ENCODER_PATH}") + print(f"TF-IDF 向量化器已保存至 {TFIDF_PATH}") + print(f"航空公司编码器已保存至 {AIRLINE_ENCODER_PATH}") + + +def load_model() -> "TweetSentimentModel": + """加载训练好的模型 + + Returns: + TweetSentimentModel: 训练好的模型 + """ + if not MODEL_PATH.exists(): + raise FileNotFoundError( + f"未找到模型文件 {MODEL_PATH}。请先运行 uv run python src/train_tweet_ultimate.py" + ) + return TweetSentimentModel.load(MODEL_PATH, ENCODER_PATH, TFIDF_PATH, AIRLINE_ENCODER_PATH) + + +if __name__ == "__main__": + train_ultimate_model() diff --git a/src/tweet_agent.py b/src/tweet_agent.py new file mode 100644 index 0000000..14dff62 --- /dev/null +++ b/src/tweet_agent.py @@ -0,0 +1,345 @@ +"""推文情感分析 Agent 模块 + +实现「分类 → 解释 → 生成处置方案」流程,输出结构化结果。 +""" + +from pathlib import Path +from typing import Optional + +import numpy as np +import polars as pl + +from pydantic import BaseModel, Field + +from src.tweet_data import load_cleaned_tweets +from src.train_tweet_ultimate import load_model as load_ultimate_model + + +class SentimentClassification(BaseModel): + """情感分类结果""" + sentiment: str = Field(description="情感类别: negative/neutral/positive") + confidence: float = Field(description="置信度 (0-1)") + + +class SentimentExplanation(BaseModel): + """情感解释""" + key_factors: list[str] = Field(description="影响情感判断的关键因素") + reasoning: str = Field(description="情感判断的推理过程") + + +class DisposalPlan(BaseModel): + """处置方案""" + priority: str = Field(description="处理优先级: high/medium/low") + action_type: str = Field(description="行动类型: response/investigate/monitor/ignore") + suggested_response: Optional[str] = Field(description="建议回复内容(如适用)", default=None) + follow_up_actions: list[str] = Field(description="后续行动建议") + + +class TweetAnalysisResult(BaseModel): + """推文分析结果(结构化输出)""" + tweet_text: str = Field(description="原始推文文本") + airline: str = Field(description="航空公司") + classification: SentimentClassification = Field(description="情感分类结果") + explanation: SentimentExplanation = Field(description="情感解释") + disposal_plan: DisposalPlan = Field(description="处置方案") + + +class TweetSentimentAgent: + """推文情感分析 Agent + + 实现「分类 → 解释 → 生成处置方案」流程。 + """ + + def __init__(self, model_path: Optional[Path] = None): + """初始化 Agent + + Args: + model_path: 模型路径(可选) + """ + self.model = load_ultimate_model() + self.label_encoder = self.model.label_encoder + self.tfidf_vectorizer = self.model.tfidf_vectorizer + self.airline_encoder = self.model.airline_encoder + + def classify(self, text: str, airline: str) -> SentimentClassification: + """分类:对推文进行情感分类 + + Args: + text: 推文文本 + airline: 航空公司 + + Returns: + 情感分类结果 + """ + # 预测 + sentiment = self.model.predict(np.array([text]), np.array([airline]))[0] + + # 预测概率 + proba = self.model.predict_proba(np.array([text]), np.array([airline]))[0] + + # 获取预测类别的置信度 + sentiment_idx = self.label_encoder.transform([sentiment])[0] + confidence = float(proba[sentiment_idx]) + + return SentimentClassification( + sentiment=sentiment, + confidence=confidence, + ) + + def explain(self, text: str, airline: str, classification: SentimentClassification) -> SentimentExplanation: + """解释:生成情感判断的解释 + + Args: + text: 推文文本 + airline: 航空公司 + classification: 情感分类结果 + + Returns: + 情感解释 + """ + key_factors = [] + reasoning_parts = [] + + text_lower = text.lower() + + # 分析情感关键词 + negative_words = ["bad", "terrible", "awful", "worst", "hate", "angry", "disappointed", "frustrated", "cancelled", "delayed", "lost", "rude"] + positive_words = ["good", "great", "excellent", "best", "love", "happy", "satisfied", "amazing", "wonderful", "thank", "helpful"] + neutral_words = ["question", "how", "what", "when", "where", "why", "please", "help", "info", "information"] + + found_negative = [word for word in negative_words if word in text_lower] + found_positive = [word for word in positive_words if word in text_lower] + found_neutral = [word for word in neutral_words if word in text_lower] + + if found_negative: + key_factors.append(f"包含负面词汇: {', '.join(found_negative[:3])}") + reasoning_parts.append("文本中包含多个负面情感词汇,表达不满情绪") + + if found_positive: + key_factors.append(f"包含正面词汇: {', '.join(found_positive[:3])}") + reasoning_parts.append("文本中包含正面情感词汇,表达满意或感谢") + + if found_neutral: + key_factors.append(f"包含中性词汇: {', '.join(found_neutral[:3])}") + reasoning_parts.append("文本主要包含询问或请求,情绪相对中性") + + # 分析文本特征 + if "!" in text: + key_factors.append("包含感叹号") + reasoning_parts.append("感叹号的使用表明情绪较为强烈") + + if "?" in text: + key_factors.append("包含问号") + reasoning_parts.append("问号的使用表明存在疑问或询问") + + if "@" in text: + key_factors.append("包含@提及") + reasoning_parts.append("直接@航空公司表明希望获得关注或回复") + + # 分析航空公司 + key_factors.append(f"涉及航空公司: {airline}") + + # 生成推理过程 + if not reasoning_parts: + reasoning_parts.append("根据文本整体语义和情感特征进行判断") + + reasoning = "。".join(reasoning_parts) + "。" + + return SentimentExplanation( + key_factors=key_factors, + reasoning=reasoning, + ) + + def generate_disposal_plan( + self, + text: str, + airline: str, + classification: SentimentClassification, + explanation: SentimentExplanation, + ) -> DisposalPlan: + """生成处置方案 + + Args: + text: 推文文本 + airline: 航空公司 + classification: 情感分类结果 + explanation: 情感解释 + + Returns: + 处置方案 + """ + sentiment = classification.sentiment + confidence = classification.confidence + + # 根据情感和置信度确定优先级和行动类型 + if sentiment == "negative": + if confidence >= 0.8: + priority = "high" + action_type = "response" + suggested_response = self._generate_negative_response(text, airline) + follow_up_actions = [ + "记录客户投诉详情", + "转交相关部门处理", + "跟进处理进度", + "在24小时内给予反馈", + ] + else: + priority = "medium" + action_type = "investigate" + suggested_response = None + follow_up_actions = [ + "进一步核实情况", + "根据核实结果决定是否需要回复", + ] + elif sentiment == "positive": + if confidence >= 0.8: + priority = "low" + action_type = "response" + suggested_response = self._generate_positive_response(text, airline) + follow_up_actions = [ + "感谢客户反馈", + "分享正面评价至内部团队", + "考虑在官方渠道展示", + ] + else: + priority = "low" + action_type = "monitor" + suggested_response = None + follow_up_actions = [ + "持续关注该用户后续动态", + ] + else: # neutral + if "?" in text or "help" in text.lower(): + priority = "medium" + action_type = "response" + suggested_response = self._generate_neutral_response(text, airline) + follow_up_actions = [ + "提供准确信息", + "确保客户问题得到解答", + ] + else: + priority = "low" + action_type = "monitor" + suggested_response = None + follow_up_actions = [ + "持续关注", + ] + + return DisposalPlan( + priority=priority, + action_type=action_type, + suggested_response=suggested_response, + follow_up_actions=follow_up_actions, + ) + + def _generate_negative_response(self, text: str, airline: str) -> str: + """生成负面情感回复""" + responses = [ + f"感谢您的反馈。我们非常重视您提到的问题,将立即进行调查并尽快给您答复。", + f"对于您的不愉快体验,我们深表歉意。请私信我们详细情况,我们将全力为您解决。", + f"收到您的反馈,我们对此感到抱歉。相关部门已介入,将尽快处理并给您满意的答复。", + ] + return responses[hash(text) % len(responses)] + + def _generate_positive_response(self, text: str, airline: str) -> str: + """生成正面情感回复""" + responses = [ + f"感谢您的认可和支持!我们会继续努力为您提供更好的服务。", + f"很高兴听到您的正面反馈!您的满意是我们前进的动力。", + f"感谢您的分享!我们会将您的反馈传达给团队,激励我们做得更好。", + ] + return responses[hash(text) % len(responses)] + + def _generate_neutral_response(self, text: str, airline: str) -> str: + """生成中性情感回复""" + responses = [ + f"感谢您的询问。请问您需要了解哪方面的信息?我们将竭诚为您解答。", + f"收到您的问题。请提供更多细节,以便我们更好地为您提供帮助。", + ] + return responses[hash(text) % len(responses)] + + def analyze(self, text: str, airline: str) -> TweetAnalysisResult: + """完整分析流程:分类 → 解释 → 生成处置方案 + + Args: + text: 推文文本 + airline: 航空公司 + + Returns: + 完整分析结果 + """ + # 1. 分类 + classification = self.classify(text, airline) + + # 2. 解释 + explanation = self.explain(text, airline, classification) + + # 3. 生成处置方案 + disposal_plan = self.generate_disposal_plan(text, airline, classification, explanation) + + # 返回结构化结果 + return TweetAnalysisResult( + tweet_text=text, + airline=airline, + classification=classification, + explanation=explanation, + disposal_plan=disposal_plan, + ) + + +def analyze_tweet(text: str, airline: str) -> TweetAnalysisResult: + """分析单条推文 + + Args: + text: 推文文本 + airline: 航空公司 + + Returns: + 分析结果 + """ + agent = TweetSentimentAgent() + return agent.analyze(text, airline) + + +def analyze_tweets_batch(texts: list[str], airlines: list[str]) -> list[TweetAnalysisResult]: + """批量分析推文 + + Args: + texts: 推文文本列表 + airlines: 航空公司列表 + + Returns: + 分析结果列表 + """ + agent = TweetSentimentAgent() + results = [] + + for text, airline in zip(texts, airlines): + result = agent.analyze(text, airline) + results.append(result) + + return results + + +if __name__ == "__main__": + # 示例:分析单条推文 + print(">>> 示例 1: 负面情感") + result = analyze_tweet( + text="@United This is the worst airline ever! My flight was delayed for 5 hours and no one helped!", + airline="United", + ) + print(result.model_dump_json(indent=2)) + + print("\n>>> 示例 2: 正面情感") + result = analyze_tweet( + text="@Southwest Thank you for the amazing flight! The crew was so helpful and friendly.", + airline="Southwest", + ) + print(result.model_dump_json(indent=2)) + + print("\n>>> 示例 3: 中性情感") + result = analyze_tweet( + text="@American What is the baggage policy for international flights?", + airline="American", + ) + print(result.model_dump_json(indent=2)) diff --git a/src/tweet_data.py b/src/tweet_data.py new file mode 100644 index 0000000..7d9237b --- /dev/null +++ b/src/tweet_data.py @@ -0,0 +1,315 @@ +"""文本数据清洗模块 + +针对 Tweets.csv 航空情感分析数据集的文本清洗。 +遵循「克制」原则,仅进行必要的预处理,保留文本语义信息。 + +清洗策略: +1. 文本标准化:统一小写(不进行词形还原/词干提取,保留原始语义) +2. 去除噪声:移除用户提及(@username)、URL链接、多余空格 +3. 保留语义:保留表情符号、标点符号(它们对情感分析有价值) +4. 最小化处理:不进行停用词删除(否定词如"not"、"don't"对情感很重要) +""" + +import re +from pathlib import Path + +import pandera.polars as pa +import polars as pl + + +# --- Pandera Schema 定义 --- + + +class RawTweetSchema(pa.DataFrameModel): + """原始推文数据 Schema(清洗前校验) + + 允许缺失值存在,用于验证数据读取后的基本结构。 + """ + tweet_id: int = pa.Field(nullable=False) + airline_sentiment: str = pa.Field(nullable=True) + airline_sentiment_confidence: float = pa.Field(ge=0, le=1, nullable=True) + negativereason: str = pa.Field(nullable=True) + negativereason_confidence: float = pa.Field(ge=0, le=1, nullable=True) + airline: str = pa.Field(nullable=True) + text: str = pa.Field(nullable=True) + tweet_coord: str = pa.Field(nullable=True) + tweet_created: str = pa.Field(nullable=True) + tweet_location: str = pa.Field(nullable=True) + user_timezone: str = pa.Field(nullable=True) + + class Config: + strict = False + coerce = True + + +class CleanTweetSchema(pa.DataFrameModel): + """清洗后推文数据 Schema(严格模式) + + 不允许缺失值,强制约束检查。 + """ + tweet_id: int = pa.Field(nullable=False) + airline_sentiment: str = pa.Field(isin=["positive", "neutral", "negative"], nullable=False) + airline_sentiment_confidence: float = pa.Field(ge=0, le=1, nullable=False) + negativereason: str = pa.Field(nullable=True) + negativereason_confidence: float = pa.Field(ge=0, le=1, nullable=True) + airline: str = pa.Field(isin=["Virgin America", "United", "Southwest", "Delta", "US Airways", "American"], nullable=False) + text_cleaned: str = pa.Field(nullable=False) + text_original: str = pa.Field(nullable=False) + + class Config: + strict = True + coerce = True + + +# --- 文本清洗函数 --- + + +def clean_text(text: str) -> str: + """文本清洗函数(克制策略) + + 清洗原则: + - 移除:用户提及(@username)、URL链接、多余空格 + - 保留:表情符号、标点符号、否定词、原始大小写(后续统一小写) + - 不做:词形还原、词干提取、停用词删除 + + Args: + text: 原始文本 + + Returns: + str: 清洗后的文本 + """ + if not text or not isinstance(text, str): + return "" + + # 1. 移除用户提及 (@username) + text = re.sub(r'@\w+', '', text) + + # 2. 移除 URL 链接 + text = re.sub(r'http\S+|www\S+', '', text) + + # 3. 移除多余空格和换行 + text = re.sub(r'\s+', ' ', text).strip() + + return text + + +def normalize_text(text: str) -> str: + """文本标准化 + + 统一小写,但不进行词形还原或词干提取。 + + Args: + text: 清洗后的文本 + + Returns: + str: 标准化后的文本 + """ + if not text or not isinstance(text, str): + return "" + + # 仅统一小写 + return text.lower() + + +# --- 数据加载与预处理 --- + + +def load_tweets(file_path: str | Path = "Tweets.csv") -> pl.DataFrame: + """加载原始推文数据 + + Args: + file_path: CSV 文件路径 + + Returns: + pl.DataFrame: 原始推文数据 + """ + df = pl.read_csv(file_path) + return df + + +def validate_raw_tweets(df: pl.DataFrame) -> pl.DataFrame: + """验证原始推文数据结构(清洗前) + + Args: + df: 原始 Polars DataFrame + + Returns: + pl.DataFrame: 验证通过的 DataFrame + + Raises: + pa.errors.SchemaError: 验证失败 + """ + return RawTweetSchema.validate(df) + + +def validate_clean_tweets(df: pl.DataFrame) -> pl.DataFrame: + """验证清洗后推文数据(严格模式) + + Args: + df: 清洗后的 Polars DataFrame + + Returns: + pl.DataFrame: 验证通过的 DataFrame + + Raises: + pa.errors.SchemaError: 验证失败 + """ + return CleanTweetSchema.validate(df) + + +def preprocess_tweets( + df: pl.DataFrame, + validate: bool = True, + min_confidence: float = 0.5 +) -> pl.DataFrame: + """推文数据预处理流水线 + + 处理步骤: + 1. 筛选:仅保留情感置信度 >= min_confidence 的样本 + 2. 文本清洗:应用 clean_text 和 normalize_text + 3. 删除缺失值:删除 text 为空的样本 + 4. 删除重复行:基于 tweet_id 去重 + 5. 可选:进行 Schema 校验 + + Args: + df: 原始 Polars DataFrame + validate: 是否进行清洗后 Schema 校验 + min_confidence: 最低情感置信度阈值 + + Returns: + pl.DataFrame: 清洗后的 DataFrame + """ + # 1. 筛选高置信度样本 + df_filtered = df.filter( + pl.col("airline_sentiment_confidence") >= min_confidence + ) + + # 2. 文本清洗和标准化 + df_clean = df_filtered.with_columns([ + pl.col("text").map_elements(clean_text, return_dtype=pl.String).alias("text_cleaned"), + pl.col("text").alias("text_original"), + ]) + + df_clean = df_clean.with_columns([ + pl.col("text_cleaned").map_elements(normalize_text, return_dtype=pl.String).alias("text_cleaned"), + ]) + + # 3. 删除缺失值(text_cleaned 为空或 airline_sentiment 为空) + df_clean = df_clean.filter( + (pl.col("text_cleaned").is_not_null()) & + (pl.col("text_cleaned") != "") & + (pl.col("airline_sentiment").is_not_null()) + ) + + # 4. 删除重复行(基于 tweet_id) + df_clean = df_clean.unique(subset=["tweet_id"], keep="first") + + # 5. 选择需要的列 + df_clean = df_clean.select([ + "tweet_id", + "airline_sentiment", + "airline_sentiment_confidence", + "negativereason", + "negativereason_confidence", + "airline", + "text_cleaned", + "text_original", + ]) + + # 6. 可选校验 + if validate: + df_clean = validate_clean_tweets(df_clean) + + return df_clean + + +def save_cleaned_tweets(df: pl.DataFrame, output_path: str | Path = "data/Tweets_cleaned.csv") -> None: + """保存清洗后的数据 + + Args: + df: 清洗后的 Polars DataFrame + output_path: 输出文件路径 + """ + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + df.write_csv(output_path) + print(f"清洗后的数据已保存至 {output_path}") + + +def load_cleaned_tweets(file_path: str | Path = "data/Tweets_cleaned.csv") -> pl.DataFrame: + """加载清洗后的推文数据 + + Args: + file_path: 清洗后的 CSV 文件路径 + + Returns: + pl.DataFrame: 清洗后的推文数据 + """ + df = pl.read_csv(file_path) + return df + + +# --- 数据统计与分析 --- + + +def print_data_summary(df: pl.DataFrame, title: str = "数据统计") -> None: + """打印数据摘要信息 + + Args: + df: Polars DataFrame + title: 标题 + """ + print(f"\n{'='*60}") + print(f"{title}") + print(f"{'='*60}") + print(f"样本总数: {len(df)}") + print(f"\n情感分布:") + print(df.group_by("airline_sentiment").agg( + pl.len().alias("count"), + (pl.len() / len(df) * 100).alias("percentage") + ).sort("count", descending=True)) + + print(f"\n航空公司分布:") + print(df.group_by("airline").agg( + pl.len().alias("count"), + (pl.len() / len(df) * 100).alias("percentage") + ).sort("count", descending=True)) + + print(f"\n文本长度统计:") + df_with_length = df.with_columns([ + pl.col("text_cleaned").str.len_chars().alias("text_length") + ]) + print(df_with_length.select([ + pl.col("text_length").min().alias("最小长度"), + pl.col("text_length").max().alias("最大长度"), + pl.col("text_length").mean().alias("平均长度"), + pl.col("text_length").median().alias("中位数长度"), + ])) + + +if __name__ == "__main__": + print(">>> 1. 加载原始数据") + df_raw = load_tweets("Tweets.csv") + print(f"原始数据样本数: {len(df_raw)}") + print(df_raw.head(3)) + + print("\n>>> 2. 验证原始数据") + df_validated = validate_raw_tweets(df_raw) + print("✅ 原始数据验证通过") + + print("\n>>> 3. 清洗数据") + df_clean = preprocess_tweets(df_validated, validate=True, min_confidence=0.5) + print(f"清洗后样本数: {len(df_clean)} (原始: {len(df_raw)})") + print("✅ 清洗后数据验证通过") + + print("\n>>> 4. 保存清洗后的数据") + save_cleaned_tweets(df_clean, "data/Tweets_cleaned.csv") + + print("\n>>> 5. 数据统计") + print_data_summary(df_clean, "清洗后数据统计") + + print("\n>>> 6. 清洗示例对比") + print("\n原始文本:") + print(df_clean.select("text_original").head(3).to_pandas()["text_original"].to_string(index=False)) + print("\n清洗后文本:") + print(df_clean.select("text_cleaned").head(3).to_pandas()["text_cleaned"].to_string(index=False))