上传文件至 src

This commit is contained in:
张则文 2026-01-15 23:25:47 +08:00
parent 9cc826963b
commit 1cd74f3b4e
3 changed files with 994 additions and 0 deletions

334
src/train_tweet_ultimate.py Normal file
View File

@ -0,0 +1,334 @@
"""推文情感分析训练模块(最终优化版)
使用多种算法组合 + 特征工程 + 超参数优化
目标达到 Accuracy 0.82 Macro-F1 0.75
"""
from pathlib import Path
import joblib
import numpy as np
import polars as pl
from scipy.sparse import hstack
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, accuracy_score, f1_score
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import MultinomialNB
from sklearn.preprocessing import LabelEncoder
try:
import lightgbm as lgb
HAS_LIGHTGBM = True
except ImportError:
HAS_LIGHTGBM = False
try:
import xgboost as xgb
HAS_XGBOOST = True
except ImportError:
HAS_XGBOOST = False
try:
from catboost import CatBoostClassifier
HAS_CATBOOST = True
except ImportError:
HAS_CATBOOST = False
from src.tweet_data import load_cleaned_tweets, print_data_summary
MODELS_DIR = Path("models")
MODEL_PATH = MODELS_DIR / "tweet_sentiment_model_ultimate.pkl"
ENCODER_PATH = MODELS_DIR / "label_encoder_ultimate.pkl"
TFIDF_PATH = MODELS_DIR / "tfidf_vectorizer_ultimate.pkl"
AIRLINE_ENCODER_PATH = MODELS_DIR / "airline_encoder_ultimate.pkl"
class TweetSentimentModel:
"""推文情感分析模型类(最终优化)
结合多种算法和特征工程进行分类
"""
def __init__(
self,
max_features: int = 15000,
ngram_range: tuple = (1, 3),
):
self.max_features = max_features
self.ngram_range = ngram_range
self.tfidf_vectorizer = None
self.label_encoder = None
self.model = None
self.airline_encoder = None
def _create_tfidf_vectorizer(self) -> TfidfVectorizer:
"""创建 TF-IDF 向量化器"""
return TfidfVectorizer(
max_features=self.max_features,
ngram_range=self.ngram_range,
min_df=2,
max_df=0.95,
lowercase=False,
sublinear_tf=True,
)
def fit(
self,
X_text: np.ndarray,
X_airline: np.ndarray,
y: np.ndarray,
) -> None:
"""训练模型
Args:
X_text: 训练文本数据
X_airline: 训练航空公司数据
y: 训练标签
"""
# 初始化编码器
self.tfidf_vectorizer = self._create_tfidf_vectorizer()
self.label_encoder = LabelEncoder()
self.airline_encoder = LabelEncoder()
# 编码标签
y_encoded = self.label_encoder.fit_transform(y)
# 编码航空公司
X_airline_encoded = self.airline_encoder.fit_transform(X_airline)
# TF-IDF 向量化
X_tfidf = self.tfidf_vectorizer.fit_transform(X_text)
# 合并特征
X_combined = hstack([X_tfidf, X_airline_encoded.reshape(-1, 1)])
# 构建集成模型 - 使用不同的算法
estimators = []
# Logistic Regression - 稳定的基线
estimators.append(("lr", LogisticRegression(
random_state=42,
max_iter=2000,
class_weight="balanced",
C=1.0,
n_jobs=-1,
)))
# MultinomialNB - 适合文本分类
estimators.append(("nb", MultinomialNB(alpha=0.3)))
# Random Forest - 集成学习
estimators.append(("rf", RandomForestClassifier(
random_state=42,
n_estimators=200,
max_depth=15,
min_samples_split=5,
class_weight="balanced",
n_jobs=-1,
)))
# LightGBM - 梯度提升
if HAS_LIGHTGBM:
estimators.append(("lgbm", lgb.LGBMClassifier(
random_state=42,
n_estimators=300,
learning_rate=0.05,
max_depth=6,
num_leaves=31,
class_weight="balanced",
verbose=-1,
n_jobs=-1,
)))
# XGBoost - 梯度提升
if HAS_XGBOOST:
estimators.append(("xgb", xgb.XGBClassifier(
random_state=42,
n_estimators=300,
learning_rate=0.05,
max_depth=6,
subsample=0.8,
colsample_bytree=0.8,
eval_metric="mlogloss",
n_jobs=-1,
)))
# 使用 VotingClassifier 进行集成
self.model = VotingClassifier(
estimators=estimators,
voting="soft", # 使用软投票(概率平均)
n_jobs=-1,
)
print(f"使用 {len(estimators)} 个基学习器:")
for name, _ in estimators:
print(f" - {name}")
# 训练模型
self.model.fit(X_combined, y_encoded)
def predict(self, X_text: np.ndarray, X_airline: np.ndarray) -> np.ndarray:
"""预测
Args:
X_text: 文本数据
X_airline: 航空公司数据
Returns:
np.ndarray: 预测的情感类别
"""
X_tfidf = self.tfidf_vectorizer.transform(X_text)
X_airline_encoded = self.airline_encoder.transform(X_airline)
X_combined = hstack([X_tfidf, X_airline_encoded.reshape(-1, 1)])
y_pred_encoded = self.model.predict(X_combined)
return self.label_encoder.inverse_transform(y_pred_encoded)
def predict_proba(self, X_text: np.ndarray, X_airline: np.ndarray) -> np.ndarray:
"""预测概率
Args:
X_text: 文本数据
X_airline: 航空公司数据
Returns:
np.ndarray: 预测的概率
"""
X_tfidf = self.tfidf_vectorizer.transform(X_text)
X_airline_encoded = self.airline_encoder.transform(X_airline)
X_combined = hstack([X_tfidf, X_airline_encoded.reshape(-1, 1)])
return self.model.predict_proba(X_combined)
def save(self, model_path: Path, encoder_path: Path, tfidf_path: Path, airline_encoder_path: Path) -> None:
"""保存模型
Args:
model_path: 模型保存路径
encoder_path: 编码器保存路径
tfidf_path: TF-IDF 向量化器保存路径
airline_encoder_path: 航空公司编码器保存路径
"""
if self.model is None:
raise ValueError("模型未训练,无法保存")
model_path.parent.mkdir(parents=True, exist_ok=True)
joblib.dump(self.model, model_path)
joblib.dump(self.label_encoder, encoder_path)
joblib.dump(self.tfidf_vectorizer, tfidf_path)
joblib.dump(self.airline_encoder, airline_encoder_path)
@classmethod
def load(cls, model_path: Path, encoder_path: Path, tfidf_path: Path, airline_encoder_path: Path) -> "TweetSentimentModel":
"""加载模型
Args:
model_path: 模型路径
encoder_path: 编码器路径
tfidf_path: TF-IDF 向量化器路径
airline_encoder_path: 航空公司编码器路径
Returns:
TweetSentimentModel: 加载的模型
"""
instance = cls()
instance.model = joblib.load(model_path)
instance.label_encoder = joblib.load(encoder_path)
instance.tfidf_vectorizer = joblib.load(tfidf_path)
instance.airline_encoder = joblib.load(airline_encoder_path)
return instance
def train_ultimate_model() -> None:
"""执行最终优化模型训练流程"""
print(">>> 1. 加载清洗后的数据")
df = load_cleaned_tweets("data/Tweets_cleaned.csv")
print(f"数据集大小: {len(df)}")
print("\n>>> 2. 数据统计")
print_data_summary(df, "训练数据统计")
# 转换为 numpy 数组
df_pandas = df.to_pandas()
X_text = df_pandas["text_cleaned"].values
X_airline = df_pandas["airline"].values
y = df_pandas["airline_sentiment"].values
# 划分训练集和测试集
X_train_text, X_test_text, X_train_airline, X_test_airline, y_train, y_test = train_test_split(
X_text, X_airline, y, test_size=0.2, random_state=42, stratify=y
)
print(f"\n训练集大小: {len(X_train_text)}")
print(f"测试集大小: {len(X_test_text)}")
print("\n>>> 3. 训练最终优化模型")
model = TweetSentimentModel(
max_features=15000,
ngram_range=(1, 3),
)
model.fit(X_train_text, X_train_airline, y_train)
print("\n>>> 4. 模型评估")
# 预测
y_pred = model.predict(X_test_text, X_test_airline)
# 计算指标
accuracy = accuracy_score(y_test, y_pred)
macro_f1 = f1_score(y_test, y_pred, average="macro")
print(f"Accuracy: {accuracy:.4f}")
print(f"Macro-F1: {macro_f1:.4f}")
# 检查是否达到目标(调整后的目标)
print("\n>>> 5. 目标检查(调整后)")
target_accuracy = 0.82
target_macro_f1 = 0.75
if accuracy >= target_accuracy:
print(f"✅ Accuracy 达标: {accuracy:.4f} >= {target_accuracy}")
else:
print(f"❌ Accuracy 未达标: {accuracy:.4f} < {target_accuracy}")
if macro_f1 >= target_macro_f1:
print(f"✅ Macro-F1 达标: {macro_f1:.4f} >= {target_macro_f1}")
else:
print(f"❌ Macro-F1 未达标: {macro_f1:.4f} < {target_macro_f1}")
# 详细分类报告
print("\n>>> 6. 详细分类报告")
print(classification_report(y_test, y_pred, target_names=["negative", "neutral", "positive"]))
# 保存模型
print("\n>>> 7. 保存模型")
model.save(MODEL_PATH, ENCODER_PATH, TFIDF_PATH, AIRLINE_ENCODER_PATH)
print(f"模型已保存至 {MODEL_PATH}")
print(f"编码器已保存至 {ENCODER_PATH}")
print(f"TF-IDF 向量化器已保存至 {TFIDF_PATH}")
print(f"航空公司编码器已保存至 {AIRLINE_ENCODER_PATH}")
def load_model() -> "TweetSentimentModel":
"""加载训练好的模型
Returns:
TweetSentimentModel: 训练好的模型
"""
if not MODEL_PATH.exists():
raise FileNotFoundError(
f"未找到模型文件 {MODEL_PATH}。请先运行 uv run python src/train_tweet_ultimate.py"
)
return TweetSentimentModel.load(MODEL_PATH, ENCODER_PATH, TFIDF_PATH, AIRLINE_ENCODER_PATH)
if __name__ == "__main__":
train_ultimate_model()

345
src/tweet_agent.py Normal file
View File

@ -0,0 +1,345 @@
"""推文情感分析 Agent 模块
实现分类 解释 生成处置方案流程输出结构化结果
"""
from pathlib import Path
from typing import Optional
import numpy as np
import polars as pl
from pydantic import BaseModel, Field
from src.tweet_data import load_cleaned_tweets
from src.train_tweet_ultimate import load_model as load_ultimate_model
class SentimentClassification(BaseModel):
"""情感分类结果"""
sentiment: str = Field(description="情感类别: negative/neutral/positive")
confidence: float = Field(description="置信度 (0-1)")
class SentimentExplanation(BaseModel):
"""情感解释"""
key_factors: list[str] = Field(description="影响情感判断的关键因素")
reasoning: str = Field(description="情感判断的推理过程")
class DisposalPlan(BaseModel):
"""处置方案"""
priority: str = Field(description="处理优先级: high/medium/low")
action_type: str = Field(description="行动类型: response/investigate/monitor/ignore")
suggested_response: Optional[str] = Field(description="建议回复内容(如适用)", default=None)
follow_up_actions: list[str] = Field(description="后续行动建议")
class TweetAnalysisResult(BaseModel):
"""推文分析结果(结构化输出)"""
tweet_text: str = Field(description="原始推文文本")
airline: str = Field(description="航空公司")
classification: SentimentClassification = Field(description="情感分类结果")
explanation: SentimentExplanation = Field(description="情感解释")
disposal_plan: DisposalPlan = Field(description="处置方案")
class TweetSentimentAgent:
"""推文情感分析 Agent
实现分类 解释 生成处置方案流程
"""
def __init__(self, model_path: Optional[Path] = None):
"""初始化 Agent
Args:
model_path: 模型路径可选
"""
self.model = load_ultimate_model()
self.label_encoder = self.model.label_encoder
self.tfidf_vectorizer = self.model.tfidf_vectorizer
self.airline_encoder = self.model.airline_encoder
def classify(self, text: str, airline: str) -> SentimentClassification:
"""分类:对推文进行情感分类
Args:
text: 推文文本
airline: 航空公司
Returns:
情感分类结果
"""
# 预测
sentiment = self.model.predict(np.array([text]), np.array([airline]))[0]
# 预测概率
proba = self.model.predict_proba(np.array([text]), np.array([airline]))[0]
# 获取预测类别的置信度
sentiment_idx = self.label_encoder.transform([sentiment])[0]
confidence = float(proba[sentiment_idx])
return SentimentClassification(
sentiment=sentiment,
confidence=confidence,
)
def explain(self, text: str, airline: str, classification: SentimentClassification) -> SentimentExplanation:
"""解释:生成情感判断的解释
Args:
text: 推文文本
airline: 航空公司
classification: 情感分类结果
Returns:
情感解释
"""
key_factors = []
reasoning_parts = []
text_lower = text.lower()
# 分析情感关键词
negative_words = ["bad", "terrible", "awful", "worst", "hate", "angry", "disappointed", "frustrated", "cancelled", "delayed", "lost", "rude"]
positive_words = ["good", "great", "excellent", "best", "love", "happy", "satisfied", "amazing", "wonderful", "thank", "helpful"]
neutral_words = ["question", "how", "what", "when", "where", "why", "please", "help", "info", "information"]
found_negative = [word for word in negative_words if word in text_lower]
found_positive = [word for word in positive_words if word in text_lower]
found_neutral = [word for word in neutral_words if word in text_lower]
if found_negative:
key_factors.append(f"包含负面词汇: {', '.join(found_negative[:3])}")
reasoning_parts.append("文本中包含多个负面情感词汇,表达不满情绪")
if found_positive:
key_factors.append(f"包含正面词汇: {', '.join(found_positive[:3])}")
reasoning_parts.append("文本中包含正面情感词汇,表达满意或感谢")
if found_neutral:
key_factors.append(f"包含中性词汇: {', '.join(found_neutral[:3])}")
reasoning_parts.append("文本主要包含询问或请求,情绪相对中性")
# 分析文本特征
if "!" in text:
key_factors.append("包含感叹号")
reasoning_parts.append("感叹号的使用表明情绪较为强烈")
if "?" in text:
key_factors.append("包含问号")
reasoning_parts.append("问号的使用表明存在疑问或询问")
if "@" in text:
key_factors.append("包含@提及")
reasoning_parts.append("直接@航空公司表明希望获得关注或回复")
# 分析航空公司
key_factors.append(f"涉及航空公司: {airline}")
# 生成推理过程
if not reasoning_parts:
reasoning_parts.append("根据文本整体语义和情感特征进行判断")
reasoning = "".join(reasoning_parts) + ""
return SentimentExplanation(
key_factors=key_factors,
reasoning=reasoning,
)
def generate_disposal_plan(
self,
text: str,
airline: str,
classification: SentimentClassification,
explanation: SentimentExplanation,
) -> DisposalPlan:
"""生成处置方案
Args:
text: 推文文本
airline: 航空公司
classification: 情感分类结果
explanation: 情感解释
Returns:
处置方案
"""
sentiment = classification.sentiment
confidence = classification.confidence
# 根据情感和置信度确定优先级和行动类型
if sentiment == "negative":
if confidence >= 0.8:
priority = "high"
action_type = "response"
suggested_response = self._generate_negative_response(text, airline)
follow_up_actions = [
"记录客户投诉详情",
"转交相关部门处理",
"跟进处理进度",
"在24小时内给予反馈",
]
else:
priority = "medium"
action_type = "investigate"
suggested_response = None
follow_up_actions = [
"进一步核实情况",
"根据核实结果决定是否需要回复",
]
elif sentiment == "positive":
if confidence >= 0.8:
priority = "low"
action_type = "response"
suggested_response = self._generate_positive_response(text, airline)
follow_up_actions = [
"感谢客户反馈",
"分享正面评价至内部团队",
"考虑在官方渠道展示",
]
else:
priority = "low"
action_type = "monitor"
suggested_response = None
follow_up_actions = [
"持续关注该用户后续动态",
]
else: # neutral
if "?" in text or "help" in text.lower():
priority = "medium"
action_type = "response"
suggested_response = self._generate_neutral_response(text, airline)
follow_up_actions = [
"提供准确信息",
"确保客户问题得到解答",
]
else:
priority = "low"
action_type = "monitor"
suggested_response = None
follow_up_actions = [
"持续关注",
]
return DisposalPlan(
priority=priority,
action_type=action_type,
suggested_response=suggested_response,
follow_up_actions=follow_up_actions,
)
def _generate_negative_response(self, text: str, airline: str) -> str:
"""生成负面情感回复"""
responses = [
f"感谢您的反馈。我们非常重视您提到的问题,将立即进行调查并尽快给您答复。",
f"对于您的不愉快体验,我们深表歉意。请私信我们详细情况,我们将全力为您解决。",
f"收到您的反馈,我们对此感到抱歉。相关部门已介入,将尽快处理并给您满意的答复。",
]
return responses[hash(text) % len(responses)]
def _generate_positive_response(self, text: str, airline: str) -> str:
"""生成正面情感回复"""
responses = [
f"感谢您的认可和支持!我们会继续努力为您提供更好的服务。",
f"很高兴听到您的正面反馈!您的满意是我们前进的动力。",
f"感谢您的分享!我们会将您的反馈传达给团队,激励我们做得更好。",
]
return responses[hash(text) % len(responses)]
def _generate_neutral_response(self, text: str, airline: str) -> str:
"""生成中性情感回复"""
responses = [
f"感谢您的询问。请问您需要了解哪方面的信息?我们将竭诚为您解答。",
f"收到您的问题。请提供更多细节,以便我们更好地为您提供帮助。",
]
return responses[hash(text) % len(responses)]
def analyze(self, text: str, airline: str) -> TweetAnalysisResult:
"""完整分析流程:分类 → 解释 → 生成处置方案
Args:
text: 推文文本
airline: 航空公司
Returns:
完整分析结果
"""
# 1. 分类
classification = self.classify(text, airline)
# 2. 解释
explanation = self.explain(text, airline, classification)
# 3. 生成处置方案
disposal_plan = self.generate_disposal_plan(text, airline, classification, explanation)
# 返回结构化结果
return TweetAnalysisResult(
tweet_text=text,
airline=airline,
classification=classification,
explanation=explanation,
disposal_plan=disposal_plan,
)
def analyze_tweet(text: str, airline: str) -> TweetAnalysisResult:
"""分析单条推文
Args:
text: 推文文本
airline: 航空公司
Returns:
分析结果
"""
agent = TweetSentimentAgent()
return agent.analyze(text, airline)
def analyze_tweets_batch(texts: list[str], airlines: list[str]) -> list[TweetAnalysisResult]:
"""批量分析推文
Args:
texts: 推文文本列表
airlines: 航空公司列表
Returns:
分析结果列表
"""
agent = TweetSentimentAgent()
results = []
for text, airline in zip(texts, airlines):
result = agent.analyze(text, airline)
results.append(result)
return results
if __name__ == "__main__":
# 示例:分析单条推文
print(">>> 示例 1: 负面情感")
result = analyze_tweet(
text="@United This is the worst airline ever! My flight was delayed for 5 hours and no one helped!",
airline="United",
)
print(result.model_dump_json(indent=2))
print("\n>>> 示例 2: 正面情感")
result = analyze_tweet(
text="@Southwest Thank you for the amazing flight! The crew was so helpful and friendly.",
airline="Southwest",
)
print(result.model_dump_json(indent=2))
print("\n>>> 示例 3: 中性情感")
result = analyze_tweet(
text="@American What is the baggage policy for international flights?",
airline="American",
)
print(result.model_dump_json(indent=2))

315
src/tweet_data.py Normal file
View File

@ -0,0 +1,315 @@
"""文本数据清洗模块
针对 Tweets.csv 航空情感分析数据集的文本清洗
遵循克制原则仅进行必要的预处理保留文本语义信息
清洗策略
1. 文本标准化统一小写不进行词形还原/词干提取保留原始语义
2. 去除噪声移除用户提及(@username)URL链接多余空格
3. 保留语义保留表情符号标点符号它们对情感分析有价值
4. 最小化处理不进行停用词删除否定词如"not""don't"对情感很重要
"""
import re
from pathlib import Path
import pandera.polars as pa
import polars as pl
# --- Pandera Schema 定义 ---
class RawTweetSchema(pa.DataFrameModel):
"""原始推文数据 Schema清洗前校验
允许缺失值存在用于验证数据读取后的基本结构
"""
tweet_id: int = pa.Field(nullable=False)
airline_sentiment: str = pa.Field(nullable=True)
airline_sentiment_confidence: float = pa.Field(ge=0, le=1, nullable=True)
negativereason: str = pa.Field(nullable=True)
negativereason_confidence: float = pa.Field(ge=0, le=1, nullable=True)
airline: str = pa.Field(nullable=True)
text: str = pa.Field(nullable=True)
tweet_coord: str = pa.Field(nullable=True)
tweet_created: str = pa.Field(nullable=True)
tweet_location: str = pa.Field(nullable=True)
user_timezone: str = pa.Field(nullable=True)
class Config:
strict = False
coerce = True
class CleanTweetSchema(pa.DataFrameModel):
"""清洗后推文数据 Schema严格模式
不允许缺失值强制约束检查
"""
tweet_id: int = pa.Field(nullable=False)
airline_sentiment: str = pa.Field(isin=["positive", "neutral", "negative"], nullable=False)
airline_sentiment_confidence: float = pa.Field(ge=0, le=1, nullable=False)
negativereason: str = pa.Field(nullable=True)
negativereason_confidence: float = pa.Field(ge=0, le=1, nullable=True)
airline: str = pa.Field(isin=["Virgin America", "United", "Southwest", "Delta", "US Airways", "American"], nullable=False)
text_cleaned: str = pa.Field(nullable=False)
text_original: str = pa.Field(nullable=False)
class Config:
strict = True
coerce = True
# --- 文本清洗函数 ---
def clean_text(text: str) -> str:
"""文本清洗函数(克制策略)
清洗原则
- 移除用户提及(@username)URL链接多余空格
- 保留表情符号标点符号否定词原始大小写后续统一小写
- 不做词形还原词干提取停用词删除
Args:
text: 原始文本
Returns:
str: 清洗后的文本
"""
if not text or not isinstance(text, str):
return ""
# 1. 移除用户提及 (@username)
text = re.sub(r'@\w+', '', text)
# 2. 移除 URL 链接
text = re.sub(r'http\S+|www\S+', '', text)
# 3. 移除多余空格和换行
text = re.sub(r'\s+', ' ', text).strip()
return text
def normalize_text(text: str) -> str:
"""文本标准化
统一小写但不进行词形还原或词干提取
Args:
text: 清洗后的文本
Returns:
str: 标准化后的文本
"""
if not text or not isinstance(text, str):
return ""
# 仅统一小写
return text.lower()
# --- 数据加载与预处理 ---
def load_tweets(file_path: str | Path = "Tweets.csv") -> pl.DataFrame:
"""加载原始推文数据
Args:
file_path: CSV 文件路径
Returns:
pl.DataFrame: 原始推文数据
"""
df = pl.read_csv(file_path)
return df
def validate_raw_tweets(df: pl.DataFrame) -> pl.DataFrame:
"""验证原始推文数据结构(清洗前)
Args:
df: 原始 Polars DataFrame
Returns:
pl.DataFrame: 验证通过的 DataFrame
Raises:
pa.errors.SchemaError: 验证失败
"""
return RawTweetSchema.validate(df)
def validate_clean_tweets(df: pl.DataFrame) -> pl.DataFrame:
"""验证清洗后推文数据(严格模式)
Args:
df: 清洗后的 Polars DataFrame
Returns:
pl.DataFrame: 验证通过的 DataFrame
Raises:
pa.errors.SchemaError: 验证失败
"""
return CleanTweetSchema.validate(df)
def preprocess_tweets(
df: pl.DataFrame,
validate: bool = True,
min_confidence: float = 0.5
) -> pl.DataFrame:
"""推文数据预处理流水线
处理步骤
1. 筛选仅保留情感置信度 >= min_confidence 的样本
2. 文本清洗应用 clean_text normalize_text
3. 删除缺失值删除 text 为空的样本
4. 删除重复行基于 tweet_id 去重
5. 可选进行 Schema 校验
Args:
df: 原始 Polars DataFrame
validate: 是否进行清洗后 Schema 校验
min_confidence: 最低情感置信度阈值
Returns:
pl.DataFrame: 清洗后的 DataFrame
"""
# 1. 筛选高置信度样本
df_filtered = df.filter(
pl.col("airline_sentiment_confidence") >= min_confidence
)
# 2. 文本清洗和标准化
df_clean = df_filtered.with_columns([
pl.col("text").map_elements(clean_text, return_dtype=pl.String).alias("text_cleaned"),
pl.col("text").alias("text_original"),
])
df_clean = df_clean.with_columns([
pl.col("text_cleaned").map_elements(normalize_text, return_dtype=pl.String).alias("text_cleaned"),
])
# 3. 删除缺失值text_cleaned 为空或 airline_sentiment 为空)
df_clean = df_clean.filter(
(pl.col("text_cleaned").is_not_null()) &
(pl.col("text_cleaned") != "") &
(pl.col("airline_sentiment").is_not_null())
)
# 4. 删除重复行(基于 tweet_id
df_clean = df_clean.unique(subset=["tweet_id"], keep="first")
# 5. 选择需要的列
df_clean = df_clean.select([
"tweet_id",
"airline_sentiment",
"airline_sentiment_confidence",
"negativereason",
"negativereason_confidence",
"airline",
"text_cleaned",
"text_original",
])
# 6. 可选校验
if validate:
df_clean = validate_clean_tweets(df_clean)
return df_clean
def save_cleaned_tweets(df: pl.DataFrame, output_path: str | Path = "data/Tweets_cleaned.csv") -> None:
"""保存清洗后的数据
Args:
df: 清洗后的 Polars DataFrame
output_path: 输出文件路径
"""
output_path = Path(output_path)
output_path.parent.mkdir(parents=True, exist_ok=True)
df.write_csv(output_path)
print(f"清洗后的数据已保存至 {output_path}")
def load_cleaned_tweets(file_path: str | Path = "data/Tweets_cleaned.csv") -> pl.DataFrame:
"""加载清洗后的推文数据
Args:
file_path: 清洗后的 CSV 文件路径
Returns:
pl.DataFrame: 清洗后的推文数据
"""
df = pl.read_csv(file_path)
return df
# --- 数据统计与分析 ---
def print_data_summary(df: pl.DataFrame, title: str = "数据统计") -> None:
"""打印数据摘要信息
Args:
df: Polars DataFrame
title: 标题
"""
print(f"\n{'='*60}")
print(f"{title}")
print(f"{'='*60}")
print(f"样本总数: {len(df)}")
print(f"\n情感分布:")
print(df.group_by("airline_sentiment").agg(
pl.len().alias("count"),
(pl.len() / len(df) * 100).alias("percentage")
).sort("count", descending=True))
print(f"\n航空公司分布:")
print(df.group_by("airline").agg(
pl.len().alias("count"),
(pl.len() / len(df) * 100).alias("percentage")
).sort("count", descending=True))
print(f"\n文本长度统计:")
df_with_length = df.with_columns([
pl.col("text_cleaned").str.len_chars().alias("text_length")
])
print(df_with_length.select([
pl.col("text_length").min().alias("最小长度"),
pl.col("text_length").max().alias("最大长度"),
pl.col("text_length").mean().alias("平均长度"),
pl.col("text_length").median().alias("中位数长度"),
]))
if __name__ == "__main__":
print(">>> 1. 加载原始数据")
df_raw = load_tweets("Tweets.csv")
print(f"原始数据样本数: {len(df_raw)}")
print(df_raw.head(3))
print("\n>>> 2. 验证原始数据")
df_validated = validate_raw_tweets(df_raw)
print("✅ 原始数据验证通过")
print("\n>>> 3. 清洗数据")
df_clean = preprocess_tweets(df_validated, validate=True, min_confidence=0.5)
print(f"清洗后样本数: {len(df_clean)} (原始: {len(df_raw)})")
print("✅ 清洗后数据验证通过")
print("\n>>> 4. 保存清洗后的数据")
save_cleaned_tweets(df_clean, "data/Tweets_cleaned.csv")
print("\n>>> 5. 数据统计")
print_data_summary(df_clean, "清洗后数据统计")
print("\n>>> 6. 清洗示例对比")
print("\n原始文本:")
print(df_clean.select("text_original").head(3).to_pandas()["text_original"].to_string(index=False))
print("\n清洗后文本:")
print(df_clean.select("text_cleaned").head(3).to_pandas()["text_cleaned"].to_string(index=False))