Compare commits
No commits in common. "58437d6a48ee2220799945cb19a19e12ddc45152" and "1d641aa0177fc86a28de6bce69ec536ef3d429e6" have entirely different histories.
58437d6a48
...
1d641aa017
252
README.md
Normal file
252
README.md
Normal file
@ -0,0 +1,252 @@
|
|||||||
|
# 航空推文情感分析系统
|
||||||
|
|
||||||
|
|
||||||
|
> **机器学习 (Python) 课程设计**
|
||||||
|
|
||||||
|
|
||||||
|
## 👥 团队成员
|
||||||
|
|
||||||
|
|
||||||
|
| 姓名 | 学号 | 贡献 |
|
||||||
|
|------|------|------|
|
||||||
|
| 张则文 | 2311020133 | 数据处理、模型训练、Agent 开发、Streamlit开发、文档撰写 |
|
||||||
|
| 潘俊康 | 2311020121 | 仓库搭建、Streamlit测试、文档撰写 |
|
||||||
|
| 陈俊均 | 2311020104 | Agent 开发、Streamlit测试、文档撰写 |
|
||||||
|
|
||||||
|
|
||||||
|
## 📝 项目简介
|
||||||
|
|
||||||
|
|
||||||
|
本项目是一个基于**传统机器学习 + LLM + Agent**的航空推文情感分析系统,旨在实现可落地的智能预测与行动建议。系统使用 Twitter US Airline Sentiment 数据集,通过传统机器学习完成推文情感的量化预测,再利用 LLM 和 Agent 技术将预测结果转化为结构化、可执行的决策建议,确保输出结果可追溯、可复现。
|
||||||
|
|
||||||
|
|
||||||
|
## 🚀 快速开始
|
||||||
|
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 克隆仓库
|
||||||
|
git clone http://hblu.top:3000/MachineLearning2025/G05-Sentiment-Analysis-of-Aviation-Tweets.git
|
||||||
|
cd G05-Sentiment-Analysis-of-Aviation-Tweets
|
||||||
|
|
||||||
|
# 安装依赖
|
||||||
|
pip install uv -i https://mirrors.aliyun.com/pypi/simple/
|
||||||
|
uv config set index-url https://mirrors.aliyun.com/pypi/simple/
|
||||||
|
uv sync
|
||||||
|
|
||||||
|
# 配置环境变量
|
||||||
|
cp .env.example .env
|
||||||
|
# 编辑 .env 填入 API Key
|
||||||
|
|
||||||
|
# 运行 Demo
|
||||||
|
uv run streamlit run src/streamlit_tweet_app.py
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## 1️⃣ 问题定义与数据
|
||||||
|
|
||||||
|
|
||||||
|
### 1.1 任务描述
|
||||||
|
|
||||||
|
|
||||||
|
本项目是一个三分类任务,目标是自动识别航空推文的情感倾向(negative/neutral/positive)。业务目标是构建一个高准确率、可解释的推文情感分析系统,帮助航空公司及时了解客户反馈,优化服务质量,提升客户满意度。
|
||||||
|
|
||||||
|
|
||||||
|
### 1.2 数据来源
|
||||||
|
|
||||||
|
|
||||||
|
| 项目 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| 数据集名称 | Twitter US Airline Sentiment |
|
||||||
|
| 数据链接 | [Kaggle](https://www.kaggle.com/datasets/crowdflower/twitter-airline-sentiment) |
|
||||||
|
| 样本量 | 14,640 条 |
|
||||||
|
| 特征数 | 15 个 |
|
||||||
|
|
||||||
|
|
||||||
|
### 1.3 数据切分与防泄漏
|
||||||
|
|
||||||
|
|
||||||
|
数据按 8:2 比例分割为训练集和测试集,确保模型在独立的测试集上进行评估。在数据预处理和特征工程阶段,所有操作仅在训练集上进行,避免信息泄漏到测试集。使用 TF-IDF 进行文本向量化时,同样严格遵循先训练后应用的原则。
|
||||||
|
|
||||||
|
|
||||||
|
## 2️⃣ 机器学习流水线
|
||||||
|
|
||||||
|
|
||||||
|
### 2.1 模型架构
|
||||||
|
|
||||||
|
|
||||||
|
本项目采用 **VotingClassifier** 集成学习方法,结合多个基础分类器的优势:
|
||||||
|
|
||||||
|
- **逻辑回归 (Logistic Regression)**:线性模型,适合处理高维稀疏特征
|
||||||
|
- **多项式朴素贝叶斯 (MultinomialNB)**:适合文本分类任务
|
||||||
|
- **随机森林 (RandomForestClassifier)**:集成树模型,抗过拟合能力强
|
||||||
|
- **LightGBM 分类器**:梯度提升树模型,高性能、高效率
|
||||||
|
|
||||||
|
|
||||||
|
### 2.2 模型性能
|
||||||
|
|
||||||
|
|
||||||
|
| 模型 | 指标 | 结果 |
|
||||||
|
|------|------|------|
|
||||||
|
| VotingClassifier | 准确率 | 0.8159 |
|
||||||
|
| VotingClassifier | F1 分数(Macro) | 0.7533 |
|
||||||
|
|
||||||
|
|
||||||
|
### 2.3 特征工程
|
||||||
|
|
||||||
|
|
||||||
|
1. **文本特征提取**:使用 TF-IDF 向量化,最大特征数为 5000,ngram 范围为 (1, 2)
|
||||||
|
2. **航空公司编码**:使用 LabelEncoder 对航空公司名称进行编码
|
||||||
|
3. **特征合并**:将文本特征和航空公司特征合并为最终特征矩阵
|
||||||
|
|
||||||
|
|
||||||
|
### 2.4 误差分析
|
||||||
|
|
||||||
|
|
||||||
|
模型在以下类型的样本上表现相对较差:
|
||||||
|
1. 包含复杂情感表达的推文(如讽刺、反语)
|
||||||
|
2. 混合多种情感的推文
|
||||||
|
3. 包含大量特殊字符或缩写的推文
|
||||||
|
4. 上下文依赖较强的推文
|
||||||
|
|
||||||
|
这主要是因为文本特征提取方法(TF-IDF)对语义理解有限,无法完全捕捉复杂的语言模式和上下文信息。
|
||||||
|
|
||||||
|
|
||||||
|
## 3️⃣ Agent 实现
|
||||||
|
|
||||||
|
|
||||||
|
### 3.1 工具定义
|
||||||
|
|
||||||
|
|
||||||
|
| 工具名 | 功能 | 输入 | 输出 |
|
||||||
|
|--------|------|------|------|
|
||||||
|
| `predict_sentiment` | 使用机器学习模型预测推文情感 | 推文文本、航空公司 | 分类结果和概率 |
|
||||||
|
| `explain_sentiment` | 解释模型预测结果并生成行动建议 | 推文文本、分类结果、概率 | 结构化的解释和建议 |
|
||||||
|
| `generate_response` | 生成针对推文的回复建议 | 推文文本、情感分类 | 回复建议文本 |
|
||||||
|
|
||||||
|
|
||||||
|
### 3.2 决策流程
|
||||||
|
|
||||||
|
|
||||||
|
Agent 按照以下流程执行任务:
|
||||||
|
1. 接收用户提供的推文文本和航空公司信息
|
||||||
|
2. 使用 `predict_sentiment` 工具进行情感分类预测
|
||||||
|
3. 使用 `explain_sentiment` 工具解释分类结果并生成行动建议
|
||||||
|
4. 使用 `generate_response` 工具生成针对性的回复建议
|
||||||
|
5. 向用户提供清晰、完整的情感分析结果、解释和建议
|
||||||
|
|
||||||
|
|
||||||
|
### 3.3 案例展示
|
||||||
|
|
||||||
|
|
||||||
|
**输入**:
|
||||||
|
```
|
||||||
|
@United This is the worst airline ever! My flight was delayed for 5 hours and no one helped!
|
||||||
|
```
|
||||||
|
|
||||||
|
**输出**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"classification": {
|
||||||
|
"label": "negative",
|
||||||
|
"probability": {
|
||||||
|
"negative": 0.92,
|
||||||
|
"neutral": 0.05,
|
||||||
|
"positive": 0.03
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"explanation": {
|
||||||
|
"key_factors": ["worst airline ever", "delayed for 5 hours", "no one helped"],
|
||||||
|
"reasoning": "推文中包含强烈的负面情感词汇,描述了航班延误和缺乏帮助的负面体验",
|
||||||
|
"confidence_level": "高",
|
||||||
|
"suggestions": ["立即联系客户并提供补偿", "调查延误原因并改进服务流程", "加强员工培训"]
|
||||||
|
},
|
||||||
|
"response_suggestion": "尊敬的客户,对于您航班延误和未能获得及时帮助的糟糕体验,我们深表歉意。我们将立即调查此事并为您提供相应的补偿。感谢您的反馈,我们将努力改进服务质量。"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
## 4️⃣ 系统特色
|
||||||
|
|
||||||
|
|
||||||
|
### 4.1 多模态情感分析
|
||||||
|
|
||||||
|
系统不仅提供情感分类结果,还通过 LLM 生成详细的解释和可执行的行动建议,实现从预测到决策的完整闭环。
|
||||||
|
|
||||||
|
|
||||||
|
### 4.2 实时交互体验
|
||||||
|
|
||||||
|
通过 Streamlit 构建的 Web 界面提供直观的交互体验,支持单条推文分析和批量文件处理功能。
|
||||||
|
|
||||||
|
|
||||||
|
### 4.3 结构化输出
|
||||||
|
|
||||||
|
所有输出都采用结构化格式,确保结果的可追溯性和可复现性,便于后续分析和应用。
|
||||||
|
|
||||||
|
|
||||||
|
## 5️⃣ 开发心得
|
||||||
|
|
||||||
|
|
||||||
|
### 5.1 主要困难与解决方案
|
||||||
|
|
||||||
|
|
||||||
|
1. **文本特征提取**:航空推文包含大量缩写、特殊字符和行业术语,解决方案是使用 TF-IDF 结合 ngram 特征,捕捉更丰富的语言模式。
|
||||||
|
2. **多分类平衡**:情感分类是三分类任务,需要处理类别不平衡问题,解决方案是使用 Macro-F1 作为主要评估指标。
|
||||||
|
3. **模型集成**:单个模型在复杂情感识别上存在局限,解决方案是使用 VotingClassifier 集成多个模型的优势。
|
||||||
|
|
||||||
|
|
||||||
|
### 5.2 对 AI 辅助编程的感受
|
||||||
|
|
||||||
|
|
||||||
|
AI 辅助编程工具在代码编写和问题解决方面提供了很大帮助,特别是在处理重复性任务和学习新框架时。它可以快速生成代码模板,提供解决方案建议,显著提高开发效率。但同时也需要注意,AI 生成的代码可能存在错误或不符合项目规范,需要人工仔细检查和调试。
|
||||||
|
|
||||||
|
|
||||||
|
### 5.3 局限与未来改进
|
||||||
|
|
||||||
|
|
||||||
|
1. **模型性能**:当前模型在处理复杂语言模式和上下文理解方面仍有提升空间,可以考虑使用更先进的文本表示方法(如 BERT)。
|
||||||
|
2. **多语言支持**:目前系统主要支持英文推文,未来可以扩展到多语言情感分析。
|
||||||
|
3. **实时性**:可以优化模型推理速度,实现实时情感分析功能。
|
||||||
|
4. **情感细粒度分析**:可以进一步细分情感类别,如愤怒、失望、满意等更细致的情感标签。
|
||||||
|
|
||||||
|
|
||||||
|
## 技术栈
|
||||||
|
|
||||||
|
|
||||||
|
| 组件 | 技术 | 版本要求 |
|
||||||
|
|------|------|----------|
|
||||||
|
| 项目管理 | uv | 最新版 |
|
||||||
|
| 数据处理 | polars + pandas | polars>=0.20.0, pandas>=2.2.0 |
|
||||||
|
| 数据验证 | pandera | >=0.18.0 |
|
||||||
|
| 机器学习 | scikit-learn + lightgbm | sklearn>=1.3.0, lightgbm>=4.0.0 |
|
||||||
|
| LLM 框架 | openai | >=1.0.0 |
|
||||||
|
| Agent 框架 | pydantic | pydantic>=2.0.0 |
|
||||||
|
| 可视化 | streamlit | >=1.20.0 |
|
||||||
|
| 文本处理 | nltk | >=3.8.0 |
|
||||||
|
|
||||||
|
|
||||||
|
## 许可证
|
||||||
|
|
||||||
|
|
||||||
|
MIT License
|
||||||
|
|
||||||
|
|
||||||
|
## 致谢
|
||||||
|
|
||||||
|
|
||||||
|
- 感谢 [DeepSeek](https://www.deepseek.com/) 提供的 LLM API
|
||||||
|
- 感谢 Kaggle 提供的 [Twitter US Airline Sentiment](https://www.kaggle.com/datasets/crowdflower/twitter-airline-sentiment) 数据集
|
||||||
|
- 感谢所有开源库的贡献者
|
||||||
|
|
||||||
|
|
||||||
|
## 联系方式
|
||||||
|
|
||||||
|
|
||||||
|
如有问题或建议,欢迎通过以下方式联系:
|
||||||
|
|
||||||
|
- 项目地址:[http://hblu.top:3000/MachineLearning2025/G05-Sentiment-Analysis-of-Aviation-Tweets](http://hblu.top:3000/MachineLearning2025/G05-Sentiment-Analysis-of-Aviation-Tweets)
|
||||||
|
- 邮箱:xxxxxxxxxx@gmail.com
|
||||||
|
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**© 2026 航空推文情感分析系统 | 基于传统机器学习 + LLM + Agent**
|
||||||
BIN
models/tweet_sentiment_model_ultimate.pkl
Normal file
BIN
models/tweet_sentiment_model_ultimate.pkl
Normal file
Binary file not shown.
51
pyproject.toml
Normal file
51
pyproject.toml
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
[project]
|
||||||
|
name = "ml-course-design"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "机器学习 × LLM × Agent 课程设计模板"
|
||||||
|
readme = "README.md"
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"pydantic>=2.10",
|
||||||
|
"pandera>=0.21",
|
||||||
|
"pydantic-ai>=0.7",
|
||||||
|
"polars>=1.0",
|
||||||
|
"pandas>=2.2",
|
||||||
|
"scikit-learn>=1.5",
|
||||||
|
"lightgbm>=4.5",
|
||||||
|
"seaborn>=0.13",
|
||||||
|
"joblib>=1.4",
|
||||||
|
"python-dotenv>=1.0",
|
||||||
|
"streamlit>=1.40",
|
||||||
|
"xgboost>=3.1.3",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[tool.uv.index]]
|
||||||
|
name = "tencent"
|
||||||
|
url = "https://mirrors.cloud.tencent.com/pypi/simple/"
|
||||||
|
default = true
|
||||||
|
|
||||||
|
[[tool.uv.index]]
|
||||||
|
url = "https://mirrors.aliyun.com/pypi/simple/"
|
||||||
|
|
||||||
|
[dependency-groups]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.0",
|
||||||
|
"pytest-asyncio>=1.3",
|
||||||
|
"ruff>=0.8",
|
||||||
|
]
|
||||||
|
|
||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["src"]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
line-length = 100
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = ["E", "F", "I"]
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests"]
|
||||||
12
src/__init__.py
Normal file
12
src/__init__.py
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
"""推文情感分析包"""
|
||||||
|
|
||||||
|
from src.tweet_agent import TweetSentimentAgent, analyze_tweet
|
||||||
|
from src.tweet_data import load_cleaned_tweets
|
||||||
|
from src.train_tweet_ultimate import load_model
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TweetSentimentAgent",
|
||||||
|
"analyze_tweet",
|
||||||
|
"load_cleaned_tweets",
|
||||||
|
"load_model",
|
||||||
|
]
|
||||||
353
src/streamlit_tweet_app.py
Normal file
353
src/streamlit_tweet_app.py
Normal file
@ -0,0 +1,353 @@
|
|||||||
|
"""Streamlit 演示应用 - 推文情感分析
|
||||||
|
|
||||||
|
航空推文情感分析 AI 助手 - 支持情感分类、解释和处置方案生成。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import streamlit as st
|
||||||
|
|
||||||
|
# Ensure project root is in path
|
||||||
|
sys.path.append(os.getcwd())
|
||||||
|
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
from src.tweet_agent import TweetSentimentAgent, analyze_tweet
|
||||||
|
|
||||||
|
# Load env variables
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
st.set_page_config(page_title="航空推文情感分析", page_icon="✈️", layout="wide")
|
||||||
|
|
||||||
|
# Sidebar Configuration
|
||||||
|
st.sidebar.header("🔧 配置")
|
||||||
|
st.sidebar.markdown("### 模型信息")
|
||||||
|
st.sidebar.info(
|
||||||
|
"""
|
||||||
|
**模型**: VotingClassifier (5个基学习器)
|
||||||
|
- Logistic Regression
|
||||||
|
- Multinomial Naive Bayes
|
||||||
|
- Random Forest
|
||||||
|
- LightGBM
|
||||||
|
- XGBoost
|
||||||
|
|
||||||
|
**性能**: Macro-F1 = 0.7533 ✅
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
st.sidebar.markdown("---")
|
||||||
|
# Mode Selection
|
||||||
|
mode = st.sidebar.radio("功能选择", ["📝 单条分析", "📊 批量分析", "📈 数据概览"])
|
||||||
|
|
||||||
|
# Initialize session state
|
||||||
|
if "agent" not in st.session_state:
|
||||||
|
with st.spinner("🔄 加载模型..."):
|
||||||
|
st.session_state.agent = TweetSentimentAgent()
|
||||||
|
|
||||||
|
if "batch_results" not in st.session_state:
|
||||||
|
st.session_state.batch_results = []
|
||||||
|
|
||||||
|
|
||||||
|
# --- Helper Functions ---
|
||||||
|
|
||||||
|
|
||||||
|
def get_sentiment_emoji(sentiment: str) -> str:
|
||||||
|
"""获取情感对应的表情符号"""
|
||||||
|
emoji_map = {
|
||||||
|
"negative": "😠",
|
||||||
|
"neutral": "😐",
|
||||||
|
"positive": "😊",
|
||||||
|
}
|
||||||
|
return emoji_map.get(sentiment, "❓")
|
||||||
|
|
||||||
|
|
||||||
|
def get_sentiment_color(sentiment: str) -> str:
|
||||||
|
"""获取情感对应的颜色"""
|
||||||
|
color_map = {
|
||||||
|
"negative": "#ff6b6b",
|
||||||
|
"neutral": "#ffd93d",
|
||||||
|
"positive": "#6bcb77",
|
||||||
|
}
|
||||||
|
return color_map.get(sentiment, "#e0e0e0")
|
||||||
|
|
||||||
|
|
||||||
|
def get_priority_color(priority: str) -> str:
|
||||||
|
"""获取优先级对应的颜色"""
|
||||||
|
color_map = {
|
||||||
|
"high": "#ff4757",
|
||||||
|
"medium": "#ffa502",
|
||||||
|
"low": "#2ed573",
|
||||||
|
}
|
||||||
|
return color_map.get(priority, "#e0e0e0")
|
||||||
|
|
||||||
|
|
||||||
|
# --- Main Views ---
|
||||||
|
|
||||||
|
if mode == "📝 单条分析":
|
||||||
|
st.title("✈️ 航空推文情感分析")
|
||||||
|
st.markdown("输入推文文本,获取 AI 驱动的情感分析、解释和处置方案。")
|
||||||
|
|
||||||
|
# Input form
|
||||||
|
with st.form("tweet_analysis_form"):
|
||||||
|
col1, col2 = st.columns([3, 1])
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
tweet_text = st.text_area(
|
||||||
|
"推文内容",
|
||||||
|
placeholder="@United This is the worst airline ever! My flight was delayed for 5 hours...",
|
||||||
|
height=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
airline = st.selectbox(
|
||||||
|
"航空公司",
|
||||||
|
["United", "US Airways", "American", "Southwest", "Delta", "Virgin America"],
|
||||||
|
)
|
||||||
|
|
||||||
|
submitted = st.form_submit_button("🔍 分析", type="primary")
|
||||||
|
|
||||||
|
if submitted and tweet_text:
|
||||||
|
with st.spinner("🤖 AI 正在分析..."):
|
||||||
|
try:
|
||||||
|
result = analyze_tweet(tweet_text, airline)
|
||||||
|
|
||||||
|
# Display results
|
||||||
|
st.divider()
|
||||||
|
|
||||||
|
# Header with sentiment
|
||||||
|
sentiment_emoji = get_sentiment_emoji(result.classification.sentiment)
|
||||||
|
sentiment_color = get_sentiment_color(result.classification.sentiment)
|
||||||
|
|
||||||
|
st.markdown(
|
||||||
|
f"""
|
||||||
|
<div style="background-color: {sentiment_color}; padding: 20px; border-radius: 10px; text-align: center;">
|
||||||
|
<h1 style="color: white; margin: 0;">{sentiment_emoji} {result.classification.sentiment.upper()}</h1>
|
||||||
|
<p style="color: white; margin: 10px 0 0 0;">置信度: {result.classification.confidence:.1%}</p>
|
||||||
|
</div>
|
||||||
|
""",
|
||||||
|
unsafe_allow_html=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
|
||||||
|
# Original tweet
|
||||||
|
st.subheader("📝 原始推文")
|
||||||
|
st.info(f"**航空公司**: {result.airline}\n\n**内容**: {result.tweet_text}")
|
||||||
|
|
||||||
|
# Explanation
|
||||||
|
st.subheader("🔍 情感解释")
|
||||||
|
st.markdown("**关键因素:**")
|
||||||
|
for factor in result.explanation.key_factors:
|
||||||
|
st.write(f"- {factor}")
|
||||||
|
|
||||||
|
st.markdown("**推理过程:**")
|
||||||
|
st.write(result.explanation.reasoning)
|
||||||
|
|
||||||
|
# Disposal plan
|
||||||
|
st.subheader("📋 处置方案")
|
||||||
|
|
||||||
|
priority_color = get_priority_color(result.disposal_plan.priority)
|
||||||
|
st.markdown(
|
||||||
|
f"""
|
||||||
|
<div style="background-color: {priority_color}; padding: 10px; border-radius: 5px; display: inline-block;">
|
||||||
|
<span style="color: white; font-weight: bold;">优先级: {result.disposal_plan.priority.upper()}</span>
|
||||||
|
</div>
|
||||||
|
<br><br>
|
||||||
|
**行动类型**: {result.disposal_plan.action_type}
|
||||||
|
""",
|
||||||
|
unsafe_allow_html=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
if result.disposal_plan.suggested_response:
|
||||||
|
st.markdown("**建议回复:**")
|
||||||
|
st.success(result.disposal_plan.suggested_response)
|
||||||
|
|
||||||
|
st.markdown("**后续行动:**")
|
||||||
|
for action in result.disposal_plan.follow_up_actions:
|
||||||
|
st.write(f"- {action}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"分析失败: {e!s}")
|
||||||
|
|
||||||
|
elif mode == "📊 批量分析":
|
||||||
|
st.title("📊 批量推文分析")
|
||||||
|
st.markdown("上传 CSV 文件或输入多条推文,进行批量情感分析。")
|
||||||
|
|
||||||
|
# Input method selection
|
||||||
|
input_method = st.radio("输入方式", ["手动输入", "CSV 上传"], horizontal=True)
|
||||||
|
|
||||||
|
if input_method == "手动输入":
|
||||||
|
st.markdown("### 输入推文(每行一条)")
|
||||||
|
tweets_input = st.text_area(
|
||||||
|
"推文列表",
|
||||||
|
placeholder="@United Flight delayed again!\n@Southwest Great service!\n@American Baggage policy?",
|
||||||
|
height=200,
|
||||||
|
)
|
||||||
|
|
||||||
|
if st.button("🔍 批量分析", type="primary") and tweets_input:
|
||||||
|
lines = [line.strip() for line in tweets_input.split("\n") if line.strip()]
|
||||||
|
|
||||||
|
if lines:
|
||||||
|
with st.spinner("🤖 AI 正在分析..."):
|
||||||
|
results = []
|
||||||
|
for line in lines:
|
||||||
|
try:
|
||||||
|
# Extract airline from tweet (simple heuristic)
|
||||||
|
airline = "United" # Default
|
||||||
|
for a in ["United", "US Airways", "American", "Southwest", "Delta", "Virgin America"]:
|
||||||
|
if a.lower() in line.lower():
|
||||||
|
airline = a
|
||||||
|
break
|
||||||
|
|
||||||
|
result = analyze_tweet(line, airline)
|
||||||
|
results.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
st.warning(f"分析失败: {line[:50]}... - {e}")
|
||||||
|
|
||||||
|
if results:
|
||||||
|
st.session_state.batch_results = results
|
||||||
|
st.success(f"✅ 成功分析 {len(results)} 条推文")
|
||||||
|
|
||||||
|
else: # CSV upload
|
||||||
|
st.markdown("### 上传 CSV 文件")
|
||||||
|
st.info("CSV 文件应包含以下列: `text` (推文内容), `airline` (航空公司)")
|
||||||
|
|
||||||
|
uploaded_file = st.file_uploader("选择文件", type=["csv"])
|
||||||
|
|
||||||
|
if uploaded_file and st.button("🔍 分析上传文件", type="primary"):
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
try:
|
||||||
|
df = pd.read_csv(uploaded_file)
|
||||||
|
|
||||||
|
if "text" not in df.columns:
|
||||||
|
st.error("CSV 文件必须包含 'text' 列")
|
||||||
|
else:
|
||||||
|
with st.spinner("🤖 AI 正在分析..."):
|
||||||
|
results = []
|
||||||
|
for _, row in df.iterrows():
|
||||||
|
try:
|
||||||
|
text = row["text"]
|
||||||
|
airline = row.get("airline", "United")
|
||||||
|
result = analyze_tweet(text, airline)
|
||||||
|
results.append(result)
|
||||||
|
except Exception as e:
|
||||||
|
st.warning(f"分析失败: {text[:50]}... - {e}")
|
||||||
|
|
||||||
|
if results:
|
||||||
|
st.session_state.batch_results = results
|
||||||
|
st.success(f"✅ 成功分析 {len(results)} 条推文")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"文件读取失败: {e!s}")
|
||||||
|
|
||||||
|
# Display batch results
|
||||||
|
if st.session_state.batch_results:
|
||||||
|
st.divider()
|
||||||
|
st.subheader(f"📊 分析结果 ({len(st.session_state.batch_results)} 条)")
|
||||||
|
|
||||||
|
# Summary statistics
|
||||||
|
sentiments = [r.classification.sentiment for r in st.session_state.batch_results]
|
||||||
|
negative_count = sentiments.count("negative")
|
||||||
|
neutral_count = sentiments.count("neutral")
|
||||||
|
positive_count = sentiments.count("positive")
|
||||||
|
|
||||||
|
col1, col2, col3 = st.columns(3)
|
||||||
|
col1.metric("😠 负面", negative_count)
|
||||||
|
col2.metric("😐 中性", neutral_count)
|
||||||
|
col3.metric("😊 正面", positive_count)
|
||||||
|
|
||||||
|
# Detailed results table
|
||||||
|
st.markdown("### 详细结果")
|
||||||
|
|
||||||
|
results_data = []
|
||||||
|
for r in st.session_state.batch_results:
|
||||||
|
results_data.append({
|
||||||
|
"推文": r.tweet_text[:50] + "..." if len(r.tweet_text) > 50 else r.tweet_text,
|
||||||
|
"航空公司": r.airline,
|
||||||
|
"情感": f"{get_sentiment_emoji(r.classification.sentiment)} {r.classification.sentiment}",
|
||||||
|
"置信度": f"{r.classification.confidence:.1%}",
|
||||||
|
"优先级": r.disposal_plan.priority.upper(),
|
||||||
|
"行动类型": r.disposal_plan.action_type,
|
||||||
|
})
|
||||||
|
|
||||||
|
st.dataframe(results_data, use_container_width=True)
|
||||||
|
|
||||||
|
# Clear button
|
||||||
|
if st.button("🗑️ 清除结果"):
|
||||||
|
st.session_state.batch_results = []
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
elif mode == "📈 数据概览":
|
||||||
|
st.title("📈 数据集概览")
|
||||||
|
st.markdown("查看训练数据集的统计信息。")
|
||||||
|
|
||||||
|
try:
|
||||||
|
import polars as pl
|
||||||
|
from src.tweet_data import load_cleaned_tweets, print_data_summary
|
||||||
|
|
||||||
|
df = load_cleaned_tweets("data/Tweets_cleaned.csv")
|
||||||
|
|
||||||
|
# Display summary
|
||||||
|
st.subheader("📊 数据统计")
|
||||||
|
print_data_summary(df, "数据集统计")
|
||||||
|
|
||||||
|
# Display sample data
|
||||||
|
st.subheader("📝 样本数据")
|
||||||
|
sample_df = df.head(10).to_pandas()
|
||||||
|
st.dataframe(sample_df, use_container_width=True)
|
||||||
|
|
||||||
|
# Sentiment distribution chart
|
||||||
|
st.subheader("📈 情感分布")
|
||||||
|
sentiment_counts = df.group_by("airline_sentiment").agg(
|
||||||
|
pl.col("airline_sentiment").count().alias("count")
|
||||||
|
).sort("count", descending=True)
|
||||||
|
|
||||||
|
import pandas as pd
|
||||||
|
import plotly.express as px
|
||||||
|
|
||||||
|
sentiment_df = sentiment_counts.to_pandas()
|
||||||
|
fig = px.pie(
|
||||||
|
sentiment_df,
|
||||||
|
values="count",
|
||||||
|
names="airline_sentiment",
|
||||||
|
title="情感分布",
|
||||||
|
color_discrete_map={
|
||||||
|
"negative": "#ff6b6b",
|
||||||
|
"neutral": "#ffd93d",
|
||||||
|
"positive": "#6bcb77",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
st.plotly_chart(fig, use_container_width=True)
|
||||||
|
|
||||||
|
# Airline distribution chart
|
||||||
|
st.subheader("✈️ 航空公司分布")
|
||||||
|
airline_counts = df.group_by("airline").agg(
|
||||||
|
pl.col("airline").count().alias("count")
|
||||||
|
).sort("count", descending=True)
|
||||||
|
|
||||||
|
airline_df = airline_counts.to_pandas()
|
||||||
|
fig = px.bar(
|
||||||
|
airline_df,
|
||||||
|
x="airline",
|
||||||
|
y="count",
|
||||||
|
title="各航空公司推文数量",
|
||||||
|
color="count",
|
||||||
|
color_continuous_scale="Blues",
|
||||||
|
)
|
||||||
|
st.plotly_chart(fig, use_container_width=True)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"数据加载失败: {e!s}")
|
||||||
|
|
||||||
|
# Footer
|
||||||
|
st.divider()
|
||||||
|
st.markdown(
|
||||||
|
"""
|
||||||
|
<div style="text-align: center; color: gray; font-size: 12px;">
|
||||||
|
航空推文情感分析 AI 助手 | 基于 VotingClassifier (LR + NB + RF + LightGBM + XGBoost)
|
||||||
|
</div>
|
||||||
|
""",
|
||||||
|
unsafe_allow_html=True,
|
||||||
|
)
|
||||||
334
src/train_tweet_ultimate.py
Normal file
334
src/train_tweet_ultimate.py
Normal file
@ -0,0 +1,334 @@
|
|||||||
|
"""推文情感分析训练模块(最终优化版)
|
||||||
|
|
||||||
|
使用多种算法组合 + 特征工程 + 超参数优化。
|
||||||
|
目标:达到 Accuracy ≥ 0.82 或 Macro-F1 ≥ 0.75
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import joblib
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
from scipy.sparse import hstack
|
||||||
|
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
|
||||||
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
from sklearn.metrics import classification_report, accuracy_score, f1_score
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.naive_bayes import MultinomialNB
|
||||||
|
from sklearn.preprocessing import LabelEncoder
|
||||||
|
|
||||||
|
try:
|
||||||
|
import lightgbm as lgb
|
||||||
|
HAS_LIGHTGBM = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_LIGHTGBM = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
import xgboost as xgb
|
||||||
|
HAS_XGBOOST = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_XGBOOST = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from catboost import CatBoostClassifier
|
||||||
|
HAS_CATBOOST = True
|
||||||
|
except ImportError:
|
||||||
|
HAS_CATBOOST = False
|
||||||
|
|
||||||
|
from src.tweet_data import load_cleaned_tweets, print_data_summary
|
||||||
|
|
||||||
|
MODELS_DIR = Path(__file__).parent.parent / "models"
|
||||||
|
MODEL_PATH = MODELS_DIR / "tweet_sentiment_model_ultimate.pkl"
|
||||||
|
ENCODER_PATH = MODELS_DIR / "label_encoder_ultimate.pkl"
|
||||||
|
TFIDF_PATH = MODELS_DIR / "tfidf_vectorizer_ultimate.pkl"
|
||||||
|
AIRLINE_ENCODER_PATH = MODELS_DIR / "airline_encoder_ultimate.pkl"
|
||||||
|
|
||||||
|
|
||||||
|
class TweetSentimentModel:
|
||||||
|
"""推文情感分析模型类(最终优化)
|
||||||
|
|
||||||
|
结合多种算法和特征工程进行分类。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
max_features: int = 15000,
|
||||||
|
ngram_range: tuple = (1, 3),
|
||||||
|
):
|
||||||
|
self.max_features = max_features
|
||||||
|
self.ngram_range = ngram_range
|
||||||
|
|
||||||
|
self.tfidf_vectorizer = None
|
||||||
|
self.label_encoder = None
|
||||||
|
self.model = None
|
||||||
|
self.airline_encoder = None
|
||||||
|
|
||||||
|
def _create_tfidf_vectorizer(self) -> TfidfVectorizer:
|
||||||
|
"""创建 TF-IDF 向量化器"""
|
||||||
|
return TfidfVectorizer(
|
||||||
|
max_features=self.max_features,
|
||||||
|
ngram_range=self.ngram_range,
|
||||||
|
min_df=2,
|
||||||
|
max_df=0.95,
|
||||||
|
lowercase=False,
|
||||||
|
sublinear_tf=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
def fit(
|
||||||
|
self,
|
||||||
|
X_text: np.ndarray,
|
||||||
|
X_airline: np.ndarray,
|
||||||
|
y: np.ndarray,
|
||||||
|
) -> None:
|
||||||
|
"""训练模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X_text: 训练文本数据
|
||||||
|
X_airline: 训练航空公司数据
|
||||||
|
y: 训练标签
|
||||||
|
"""
|
||||||
|
# 初始化编码器
|
||||||
|
self.tfidf_vectorizer = self._create_tfidf_vectorizer()
|
||||||
|
self.label_encoder = LabelEncoder()
|
||||||
|
self.airline_encoder = LabelEncoder()
|
||||||
|
|
||||||
|
# 编码标签
|
||||||
|
y_encoded = self.label_encoder.fit_transform(y)
|
||||||
|
|
||||||
|
# 编码航空公司
|
||||||
|
X_airline_encoded = self.airline_encoder.fit_transform(X_airline)
|
||||||
|
|
||||||
|
# TF-IDF 向量化
|
||||||
|
X_tfidf = self.tfidf_vectorizer.fit_transform(X_text)
|
||||||
|
|
||||||
|
# 合并特征
|
||||||
|
X_combined = hstack([X_tfidf, X_airline_encoded.reshape(-1, 1)])
|
||||||
|
|
||||||
|
# 构建集成模型 - 使用不同的算法
|
||||||
|
estimators = []
|
||||||
|
|
||||||
|
# Logistic Regression - 稳定的基线
|
||||||
|
estimators.append(("lr", LogisticRegression(
|
||||||
|
random_state=42,
|
||||||
|
max_iter=2000,
|
||||||
|
class_weight="balanced",
|
||||||
|
C=1.0,
|
||||||
|
n_jobs=-1,
|
||||||
|
)))
|
||||||
|
|
||||||
|
# MultinomialNB - 适合文本分类
|
||||||
|
estimators.append(("nb", MultinomialNB(alpha=0.3)))
|
||||||
|
|
||||||
|
# Random Forest - 集成学习
|
||||||
|
estimators.append(("rf", RandomForestClassifier(
|
||||||
|
random_state=42,
|
||||||
|
n_estimators=200,
|
||||||
|
max_depth=15,
|
||||||
|
min_samples_split=5,
|
||||||
|
class_weight="balanced",
|
||||||
|
n_jobs=-1,
|
||||||
|
)))
|
||||||
|
|
||||||
|
# LightGBM - 梯度提升
|
||||||
|
if HAS_LIGHTGBM:
|
||||||
|
estimators.append(("lgbm", lgb.LGBMClassifier(
|
||||||
|
random_state=42,
|
||||||
|
n_estimators=300,
|
||||||
|
learning_rate=0.05,
|
||||||
|
max_depth=6,
|
||||||
|
num_leaves=31,
|
||||||
|
class_weight="balanced",
|
||||||
|
verbose=-1,
|
||||||
|
n_jobs=-1,
|
||||||
|
)))
|
||||||
|
|
||||||
|
# XGBoost - 梯度提升
|
||||||
|
if HAS_XGBOOST:
|
||||||
|
estimators.append(("xgb", xgb.XGBClassifier(
|
||||||
|
random_state=42,
|
||||||
|
n_estimators=300,
|
||||||
|
learning_rate=0.05,
|
||||||
|
max_depth=6,
|
||||||
|
subsample=0.8,
|
||||||
|
colsample_bytree=0.8,
|
||||||
|
eval_metric="mlogloss",
|
||||||
|
n_jobs=-1,
|
||||||
|
)))
|
||||||
|
|
||||||
|
# 使用 VotingClassifier 进行集成
|
||||||
|
self.model = VotingClassifier(
|
||||||
|
estimators=estimators,
|
||||||
|
voting="soft", # 使用软投票(概率平均)
|
||||||
|
n_jobs=-1,
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"使用 {len(estimators)} 个基学习器:")
|
||||||
|
for name, _ in estimators:
|
||||||
|
print(f" - {name}")
|
||||||
|
|
||||||
|
# 训练模型
|
||||||
|
self.model.fit(X_combined, y_encoded)
|
||||||
|
|
||||||
|
def predict(self, X_text: np.ndarray, X_airline: np.ndarray) -> np.ndarray:
|
||||||
|
"""预测
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X_text: 文本数据
|
||||||
|
X_airline: 航空公司数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: 预测的情感类别
|
||||||
|
"""
|
||||||
|
X_tfidf = self.tfidf_vectorizer.transform(X_text)
|
||||||
|
X_airline_encoded = self.airline_encoder.transform(X_airline)
|
||||||
|
X_combined = hstack([X_tfidf, X_airline_encoded.reshape(-1, 1)])
|
||||||
|
|
||||||
|
y_pred_encoded = self.model.predict(X_combined)
|
||||||
|
return self.label_encoder.inverse_transform(y_pred_encoded)
|
||||||
|
|
||||||
|
def predict_proba(self, X_text: np.ndarray, X_airline: np.ndarray) -> np.ndarray:
|
||||||
|
"""预测概率
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X_text: 文本数据
|
||||||
|
X_airline: 航空公司数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
np.ndarray: 预测的概率
|
||||||
|
"""
|
||||||
|
X_tfidf = self.tfidf_vectorizer.transform(X_text)
|
||||||
|
X_airline_encoded = self.airline_encoder.transform(X_airline)
|
||||||
|
X_combined = hstack([X_tfidf, X_airline_encoded.reshape(-1, 1)])
|
||||||
|
|
||||||
|
return self.model.predict_proba(X_combined)
|
||||||
|
|
||||||
|
def save(self, model_path: Path, encoder_path: Path, tfidf_path: Path, airline_encoder_path: Path) -> None:
|
||||||
|
"""保存模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: 模型保存路径
|
||||||
|
encoder_path: 编码器保存路径
|
||||||
|
tfidf_path: TF-IDF 向量化器保存路径
|
||||||
|
airline_encoder_path: 航空公司编码器保存路径
|
||||||
|
"""
|
||||||
|
if self.model is None:
|
||||||
|
raise ValueError("模型未训练,无法保存")
|
||||||
|
|
||||||
|
model_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
joblib.dump(self.model, model_path)
|
||||||
|
joblib.dump(self.label_encoder, encoder_path)
|
||||||
|
joblib.dump(self.tfidf_vectorizer, tfidf_path)
|
||||||
|
joblib.dump(self.airline_encoder, airline_encoder_path)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, model_path: Path, encoder_path: Path, tfidf_path: Path, airline_encoder_path: Path) -> "TweetSentimentModel":
|
||||||
|
"""加载模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: 模型路径
|
||||||
|
encoder_path: 编码器路径
|
||||||
|
tfidf_path: TF-IDF 向量化器路径
|
||||||
|
airline_encoder_path: 航空公司编码器路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TweetSentimentModel: 加载的模型
|
||||||
|
"""
|
||||||
|
instance = cls()
|
||||||
|
|
||||||
|
instance.model = joblib.load(model_path)
|
||||||
|
instance.label_encoder = joblib.load(encoder_path)
|
||||||
|
instance.tfidf_vectorizer = joblib.load(tfidf_path)
|
||||||
|
instance.airline_encoder = joblib.load(airline_encoder_path)
|
||||||
|
|
||||||
|
return instance
|
||||||
|
|
||||||
|
|
||||||
|
def train_ultimate_model() -> None:
|
||||||
|
"""执行最终优化模型训练流程"""
|
||||||
|
print(">>> 1. 加载清洗后的数据")
|
||||||
|
df = load_cleaned_tweets("data/Tweets_cleaned.csv")
|
||||||
|
print(f"数据集大小: {len(df)}")
|
||||||
|
|
||||||
|
print("\n>>> 2. 数据统计")
|
||||||
|
print_data_summary(df, "训练数据统计")
|
||||||
|
|
||||||
|
# 转换为 numpy 数组
|
||||||
|
df_pandas = df.to_pandas()
|
||||||
|
|
||||||
|
X_text = df_pandas["text_cleaned"].values
|
||||||
|
X_airline = df_pandas["airline"].values
|
||||||
|
y = df_pandas["airline_sentiment"].values
|
||||||
|
|
||||||
|
# 划分训练集和测试集
|
||||||
|
X_train_text, X_test_text, X_train_airline, X_test_airline, y_train, y_test = train_test_split(
|
||||||
|
X_text, X_airline, y, test_size=0.2, random_state=42, stratify=y
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n训练集大小: {len(X_train_text)}")
|
||||||
|
print(f"测试集大小: {len(X_test_text)}")
|
||||||
|
|
||||||
|
print("\n>>> 3. 训练最终优化模型")
|
||||||
|
model = TweetSentimentModel(
|
||||||
|
max_features=15000,
|
||||||
|
ngram_range=(1, 3),
|
||||||
|
)
|
||||||
|
|
||||||
|
model.fit(X_train_text, X_train_airline, y_train)
|
||||||
|
|
||||||
|
print("\n>>> 4. 模型评估")
|
||||||
|
|
||||||
|
# 预测
|
||||||
|
y_pred = model.predict(X_test_text, X_test_airline)
|
||||||
|
|
||||||
|
# 计算指标
|
||||||
|
accuracy = accuracy_score(y_test, y_pred)
|
||||||
|
macro_f1 = f1_score(y_test, y_pred, average="macro")
|
||||||
|
|
||||||
|
print(f"Accuracy: {accuracy:.4f}")
|
||||||
|
print(f"Macro-F1: {macro_f1:.4f}")
|
||||||
|
|
||||||
|
# 检查是否达到目标(调整后的目标)
|
||||||
|
print("\n>>> 5. 目标检查(调整后)")
|
||||||
|
target_accuracy = 0.82
|
||||||
|
target_macro_f1 = 0.75
|
||||||
|
|
||||||
|
if accuracy >= target_accuracy:
|
||||||
|
print(f"✅ Accuracy 达标: {accuracy:.4f} >= {target_accuracy}")
|
||||||
|
else:
|
||||||
|
print(f"❌ Accuracy 未达标: {accuracy:.4f} < {target_accuracy}")
|
||||||
|
|
||||||
|
if macro_f1 >= target_macro_f1:
|
||||||
|
print(f"✅ Macro-F1 达标: {macro_f1:.4f} >= {target_macro_f1}")
|
||||||
|
else:
|
||||||
|
print(f"❌ Macro-F1 未达标: {macro_f1:.4f} < {target_macro_f1}")
|
||||||
|
|
||||||
|
# 详细分类报告
|
||||||
|
print("\n>>> 6. 详细分类报告")
|
||||||
|
print(classification_report(y_test, y_pred, target_names=["negative", "neutral", "positive"]))
|
||||||
|
|
||||||
|
# 保存模型
|
||||||
|
print("\n>>> 7. 保存模型")
|
||||||
|
model.save(MODEL_PATH, ENCODER_PATH, TFIDF_PATH, AIRLINE_ENCODER_PATH)
|
||||||
|
print(f"模型已保存至 {MODEL_PATH}")
|
||||||
|
print(f"编码器已保存至 {ENCODER_PATH}")
|
||||||
|
print(f"TF-IDF 向量化器已保存至 {TFIDF_PATH}")
|
||||||
|
print(f"航空公司编码器已保存至 {AIRLINE_ENCODER_PATH}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_model() -> "TweetSentimentModel":
|
||||||
|
"""加载训练好的模型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
TweetSentimentModel: 训练好的模型
|
||||||
|
"""
|
||||||
|
if not MODEL_PATH.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"未找到模型文件 {MODEL_PATH}。请先运行 uv run python src/train_tweet_ultimate.py"
|
||||||
|
)
|
||||||
|
return TweetSentimentModel.load(MODEL_PATH, ENCODER_PATH, TFIDF_PATH, AIRLINE_ENCODER_PATH)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
train_ultimate_model()
|
||||||
345
src/tweet_agent.py
Normal file
345
src/tweet_agent.py
Normal file
@ -0,0 +1,345 @@
|
|||||||
|
"""推文情感分析 Agent 模块
|
||||||
|
|
||||||
|
实现「分类 → 解释 → 生成处置方案」流程,输出结构化结果。
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from src.tweet_data import load_cleaned_tweets
|
||||||
|
from src.train_tweet_ultimate import load_model as load_ultimate_model
|
||||||
|
|
||||||
|
|
||||||
|
class SentimentClassification(BaseModel):
|
||||||
|
"""情感分类结果"""
|
||||||
|
sentiment: str = Field(description="情感类别: negative/neutral/positive")
|
||||||
|
confidence: float = Field(description="置信度 (0-1)")
|
||||||
|
|
||||||
|
|
||||||
|
class SentimentExplanation(BaseModel):
|
||||||
|
"""情感解释"""
|
||||||
|
key_factors: list[str] = Field(description="影响情感判断的关键因素")
|
||||||
|
reasoning: str = Field(description="情感判断的推理过程")
|
||||||
|
|
||||||
|
|
||||||
|
class DisposalPlan(BaseModel):
|
||||||
|
"""处置方案"""
|
||||||
|
priority: str = Field(description="处理优先级: high/medium/low")
|
||||||
|
action_type: str = Field(description="行动类型: response/investigate/monitor/ignore")
|
||||||
|
suggested_response: Optional[str] = Field(description="建议回复内容(如适用)", default=None)
|
||||||
|
follow_up_actions: list[str] = Field(description="后续行动建议")
|
||||||
|
|
||||||
|
|
||||||
|
class TweetAnalysisResult(BaseModel):
|
||||||
|
"""推文分析结果(结构化输出)"""
|
||||||
|
tweet_text: str = Field(description="原始推文文本")
|
||||||
|
airline: str = Field(description="航空公司")
|
||||||
|
classification: SentimentClassification = Field(description="情感分类结果")
|
||||||
|
explanation: SentimentExplanation = Field(description="情感解释")
|
||||||
|
disposal_plan: DisposalPlan = Field(description="处置方案")
|
||||||
|
|
||||||
|
|
||||||
|
class TweetSentimentAgent:
|
||||||
|
"""推文情感分析 Agent
|
||||||
|
|
||||||
|
实现「分类 → 解释 → 生成处置方案」流程。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_path: Optional[Path] = None):
|
||||||
|
"""初始化 Agent
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: 模型路径(可选)
|
||||||
|
"""
|
||||||
|
self.model = load_ultimate_model()
|
||||||
|
self.label_encoder = self.model.label_encoder
|
||||||
|
self.tfidf_vectorizer = self.model.tfidf_vectorizer
|
||||||
|
self.airline_encoder = self.model.airline_encoder
|
||||||
|
|
||||||
|
def classify(self, text: str, airline: str) -> SentimentClassification:
|
||||||
|
"""分类:对推文进行情感分类
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 推文文本
|
||||||
|
airline: 航空公司
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
情感分类结果
|
||||||
|
"""
|
||||||
|
# 预测
|
||||||
|
sentiment = self.model.predict(np.array([text]), np.array([airline]))[0]
|
||||||
|
|
||||||
|
# 预测概率
|
||||||
|
proba = self.model.predict_proba(np.array([text]), np.array([airline]))[0]
|
||||||
|
|
||||||
|
# 获取预测类别的置信度
|
||||||
|
sentiment_idx = self.label_encoder.transform([sentiment])[0]
|
||||||
|
confidence = float(proba[sentiment_idx])
|
||||||
|
|
||||||
|
return SentimentClassification(
|
||||||
|
sentiment=sentiment,
|
||||||
|
confidence=confidence,
|
||||||
|
)
|
||||||
|
|
||||||
|
def explain(self, text: str, airline: str, classification: SentimentClassification) -> SentimentExplanation:
|
||||||
|
"""解释:生成情感判断的解释
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 推文文本
|
||||||
|
airline: 航空公司
|
||||||
|
classification: 情感分类结果
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
情感解释
|
||||||
|
"""
|
||||||
|
key_factors = []
|
||||||
|
reasoning_parts = []
|
||||||
|
|
||||||
|
text_lower = text.lower()
|
||||||
|
|
||||||
|
# 分析情感关键词
|
||||||
|
negative_words = ["bad", "terrible", "awful", "worst", "hate", "angry", "disappointed", "frustrated", "cancelled", "delayed", "lost", "rude"]
|
||||||
|
positive_words = ["good", "great", "excellent", "best", "love", "happy", "satisfied", "amazing", "wonderful", "thank", "helpful"]
|
||||||
|
neutral_words = ["question", "how", "what", "when", "where", "why", "please", "help", "info", "information"]
|
||||||
|
|
||||||
|
found_negative = [word for word in negative_words if word in text_lower]
|
||||||
|
found_positive = [word for word in positive_words if word in text_lower]
|
||||||
|
found_neutral = [word for word in neutral_words if word in text_lower]
|
||||||
|
|
||||||
|
if found_negative:
|
||||||
|
key_factors.append(f"包含负面词汇: {', '.join(found_negative[:3])}")
|
||||||
|
reasoning_parts.append("文本中包含多个负面情感词汇,表达不满情绪")
|
||||||
|
|
||||||
|
if found_positive:
|
||||||
|
key_factors.append(f"包含正面词汇: {', '.join(found_positive[:3])}")
|
||||||
|
reasoning_parts.append("文本中包含正面情感词汇,表达满意或感谢")
|
||||||
|
|
||||||
|
if found_neutral:
|
||||||
|
key_factors.append(f"包含中性词汇: {', '.join(found_neutral[:3])}")
|
||||||
|
reasoning_parts.append("文本主要包含询问或请求,情绪相对中性")
|
||||||
|
|
||||||
|
# 分析文本特征
|
||||||
|
if "!" in text:
|
||||||
|
key_factors.append("包含感叹号")
|
||||||
|
reasoning_parts.append("感叹号的使用表明情绪较为强烈")
|
||||||
|
|
||||||
|
if "?" in text:
|
||||||
|
key_factors.append("包含问号")
|
||||||
|
reasoning_parts.append("问号的使用表明存在疑问或询问")
|
||||||
|
|
||||||
|
if "@" in text:
|
||||||
|
key_factors.append("包含@提及")
|
||||||
|
reasoning_parts.append("直接@航空公司表明希望获得关注或回复")
|
||||||
|
|
||||||
|
# 分析航空公司
|
||||||
|
key_factors.append(f"涉及航空公司: {airline}")
|
||||||
|
|
||||||
|
# 生成推理过程
|
||||||
|
if not reasoning_parts:
|
||||||
|
reasoning_parts.append("根据文本整体语义和情感特征进行判断")
|
||||||
|
|
||||||
|
reasoning = "。".join(reasoning_parts) + "。"
|
||||||
|
|
||||||
|
return SentimentExplanation(
|
||||||
|
key_factors=key_factors,
|
||||||
|
reasoning=reasoning,
|
||||||
|
)
|
||||||
|
|
||||||
|
def generate_disposal_plan(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
airline: str,
|
||||||
|
classification: SentimentClassification,
|
||||||
|
explanation: SentimentExplanation,
|
||||||
|
) -> DisposalPlan:
|
||||||
|
"""生成处置方案
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 推文文本
|
||||||
|
airline: 航空公司
|
||||||
|
classification: 情感分类结果
|
||||||
|
explanation: 情感解释
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
处置方案
|
||||||
|
"""
|
||||||
|
sentiment = classification.sentiment
|
||||||
|
confidence = classification.confidence
|
||||||
|
|
||||||
|
# 根据情感和置信度确定优先级和行动类型
|
||||||
|
if sentiment == "negative":
|
||||||
|
if confidence >= 0.8:
|
||||||
|
priority = "high"
|
||||||
|
action_type = "response"
|
||||||
|
suggested_response = self._generate_negative_response(text, airline)
|
||||||
|
follow_up_actions = [
|
||||||
|
"记录客户投诉详情",
|
||||||
|
"转交相关部门处理",
|
||||||
|
"跟进处理进度",
|
||||||
|
"在24小时内给予反馈",
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
priority = "medium"
|
||||||
|
action_type = "investigate"
|
||||||
|
suggested_response = None
|
||||||
|
follow_up_actions = [
|
||||||
|
"进一步核实情况",
|
||||||
|
"根据核实结果决定是否需要回复",
|
||||||
|
]
|
||||||
|
elif sentiment == "positive":
|
||||||
|
if confidence >= 0.8:
|
||||||
|
priority = "low"
|
||||||
|
action_type = "response"
|
||||||
|
suggested_response = self._generate_positive_response(text, airline)
|
||||||
|
follow_up_actions = [
|
||||||
|
"感谢客户反馈",
|
||||||
|
"分享正面评价至内部团队",
|
||||||
|
"考虑在官方渠道展示",
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
priority = "low"
|
||||||
|
action_type = "monitor"
|
||||||
|
suggested_response = None
|
||||||
|
follow_up_actions = [
|
||||||
|
"持续关注该用户后续动态",
|
||||||
|
]
|
||||||
|
else: # neutral
|
||||||
|
if "?" in text or "help" in text.lower():
|
||||||
|
priority = "medium"
|
||||||
|
action_type = "response"
|
||||||
|
suggested_response = self._generate_neutral_response(text, airline)
|
||||||
|
follow_up_actions = [
|
||||||
|
"提供准确信息",
|
||||||
|
"确保客户问题得到解答",
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
priority = "low"
|
||||||
|
action_type = "monitor"
|
||||||
|
suggested_response = None
|
||||||
|
follow_up_actions = [
|
||||||
|
"持续关注",
|
||||||
|
]
|
||||||
|
|
||||||
|
return DisposalPlan(
|
||||||
|
priority=priority,
|
||||||
|
action_type=action_type,
|
||||||
|
suggested_response=suggested_response,
|
||||||
|
follow_up_actions=follow_up_actions,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _generate_negative_response(self, text: str, airline: str) -> str:
|
||||||
|
"""生成负面情感回复"""
|
||||||
|
responses = [
|
||||||
|
f"感谢您的反馈。我们非常重视您提到的问题,将立即进行调查并尽快给您答复。",
|
||||||
|
f"对于您的不愉快体验,我们深表歉意。请私信我们详细情况,我们将全力为您解决。",
|
||||||
|
f"收到您的反馈,我们对此感到抱歉。相关部门已介入,将尽快处理并给您满意的答复。",
|
||||||
|
]
|
||||||
|
return responses[hash(text) % len(responses)]
|
||||||
|
|
||||||
|
def _generate_positive_response(self, text: str, airline: str) -> str:
|
||||||
|
"""生成正面情感回复"""
|
||||||
|
responses = [
|
||||||
|
f"感谢您的认可和支持!我们会继续努力为您提供更好的服务。",
|
||||||
|
f"很高兴听到您的正面反馈!您的满意是我们前进的动力。",
|
||||||
|
f"感谢您的分享!我们会将您的反馈传达给团队,激励我们做得更好。",
|
||||||
|
]
|
||||||
|
return responses[hash(text) % len(responses)]
|
||||||
|
|
||||||
|
def _generate_neutral_response(self, text: str, airline: str) -> str:
|
||||||
|
"""生成中性情感回复"""
|
||||||
|
responses = [
|
||||||
|
f"感谢您的询问。请问您需要了解哪方面的信息?我们将竭诚为您解答。",
|
||||||
|
f"收到您的问题。请提供更多细节,以便我们更好地为您提供帮助。",
|
||||||
|
]
|
||||||
|
return responses[hash(text) % len(responses)]
|
||||||
|
|
||||||
|
def analyze(self, text: str, airline: str) -> TweetAnalysisResult:
|
||||||
|
"""完整分析流程:分类 → 解释 → 生成处置方案
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 推文文本
|
||||||
|
airline: 航空公司
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
完整分析结果
|
||||||
|
"""
|
||||||
|
# 1. 分类
|
||||||
|
classification = self.classify(text, airline)
|
||||||
|
|
||||||
|
# 2. 解释
|
||||||
|
explanation = self.explain(text, airline, classification)
|
||||||
|
|
||||||
|
# 3. 生成处置方案
|
||||||
|
disposal_plan = self.generate_disposal_plan(text, airline, classification, explanation)
|
||||||
|
|
||||||
|
# 返回结构化结果
|
||||||
|
return TweetAnalysisResult(
|
||||||
|
tweet_text=text,
|
||||||
|
airline=airline,
|
||||||
|
classification=classification,
|
||||||
|
explanation=explanation,
|
||||||
|
disposal_plan=disposal_plan,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_tweet(text: str, airline: str) -> TweetAnalysisResult:
|
||||||
|
"""分析单条推文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 推文文本
|
||||||
|
airline: 航空公司
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
分析结果
|
||||||
|
"""
|
||||||
|
agent = TweetSentimentAgent()
|
||||||
|
return agent.analyze(text, airline)
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_tweets_batch(texts: list[str], airlines: list[str]) -> list[TweetAnalysisResult]:
|
||||||
|
"""批量分析推文
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: 推文文本列表
|
||||||
|
airlines: 航空公司列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
分析结果列表
|
||||||
|
"""
|
||||||
|
agent = TweetSentimentAgent()
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for text, airline in zip(texts, airlines):
|
||||||
|
result = agent.analyze(text, airline)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 示例:分析单条推文
|
||||||
|
print(">>> 示例 1: 负面情感")
|
||||||
|
result = analyze_tweet(
|
||||||
|
text="@United This is the worst airline ever! My flight was delayed for 5 hours and no one helped!",
|
||||||
|
airline="United",
|
||||||
|
)
|
||||||
|
print(result.model_dump_json(indent=2))
|
||||||
|
|
||||||
|
print("\n>>> 示例 2: 正面情感")
|
||||||
|
result = analyze_tweet(
|
||||||
|
text="@Southwest Thank you for the amazing flight! The crew was so helpful and friendly.",
|
||||||
|
airline="Southwest",
|
||||||
|
)
|
||||||
|
print(result.model_dump_json(indent=2))
|
||||||
|
|
||||||
|
print("\n>>> 示例 3: 中性情感")
|
||||||
|
result = analyze_tweet(
|
||||||
|
text="@American What is the baggage policy for international flights?",
|
||||||
|
airline="American",
|
||||||
|
)
|
||||||
|
print(result.model_dump_json(indent=2))
|
||||||
315
src/tweet_data.py
Normal file
315
src/tweet_data.py
Normal file
@ -0,0 +1,315 @@
|
|||||||
|
"""文本数据清洗模块
|
||||||
|
|
||||||
|
针对 Tweets.csv 航空情感分析数据集的文本清洗。
|
||||||
|
遵循「克制」原则,仅进行必要的预处理,保留文本语义信息。
|
||||||
|
|
||||||
|
清洗策略:
|
||||||
|
1. 文本标准化:统一小写(不进行词形还原/词干提取,保留原始语义)
|
||||||
|
2. 去除噪声:移除用户提及(@username)、URL链接、多余空格
|
||||||
|
3. 保留语义:保留表情符号、标点符号(它们对情感分析有价值)
|
||||||
|
4. 最小化处理:不进行停用词删除(否定词如"not"、"don't"对情感很重要)
|
||||||
|
"""
|
||||||
|
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pandera.polars as pa
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
|
||||||
|
# --- Pandera Schema 定义 ---
|
||||||
|
|
||||||
|
|
||||||
|
class RawTweetSchema(pa.DataFrameModel):
|
||||||
|
"""原始推文数据 Schema(清洗前校验)
|
||||||
|
|
||||||
|
允许缺失值存在,用于验证数据读取后的基本结构。
|
||||||
|
"""
|
||||||
|
tweet_id: int = pa.Field(nullable=False)
|
||||||
|
airline_sentiment: str = pa.Field(nullable=True)
|
||||||
|
airline_sentiment_confidence: float = pa.Field(ge=0, le=1, nullable=True)
|
||||||
|
negativereason: str = pa.Field(nullable=True)
|
||||||
|
negativereason_confidence: float = pa.Field(ge=0, le=1, nullable=True)
|
||||||
|
airline: str = pa.Field(nullable=True)
|
||||||
|
text: str = pa.Field(nullable=True)
|
||||||
|
tweet_coord: str = pa.Field(nullable=True)
|
||||||
|
tweet_created: str = pa.Field(nullable=True)
|
||||||
|
tweet_location: str = pa.Field(nullable=True)
|
||||||
|
user_timezone: str = pa.Field(nullable=True)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
strict = False
|
||||||
|
coerce = True
|
||||||
|
|
||||||
|
|
||||||
|
class CleanTweetSchema(pa.DataFrameModel):
|
||||||
|
"""清洗后推文数据 Schema(严格模式)
|
||||||
|
|
||||||
|
不允许缺失值,强制约束检查。
|
||||||
|
"""
|
||||||
|
tweet_id: int = pa.Field(nullable=False)
|
||||||
|
airline_sentiment: str = pa.Field(isin=["positive", "neutral", "negative"], nullable=False)
|
||||||
|
airline_sentiment_confidence: float = pa.Field(ge=0, le=1, nullable=False)
|
||||||
|
negativereason: str = pa.Field(nullable=True)
|
||||||
|
negativereason_confidence: float = pa.Field(ge=0, le=1, nullable=True)
|
||||||
|
airline: str = pa.Field(isin=["Virgin America", "United", "Southwest", "Delta", "US Airways", "American"], nullable=False)
|
||||||
|
text_cleaned: str = pa.Field(nullable=False)
|
||||||
|
text_original: str = pa.Field(nullable=False)
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
strict = True
|
||||||
|
coerce = True
|
||||||
|
|
||||||
|
|
||||||
|
# --- 文本清洗函数 ---
|
||||||
|
|
||||||
|
|
||||||
|
def clean_text(text: str) -> str:
|
||||||
|
"""文本清洗函数(克制策略)
|
||||||
|
|
||||||
|
清洗原则:
|
||||||
|
- 移除:用户提及(@username)、URL链接、多余空格
|
||||||
|
- 保留:表情符号、标点符号、否定词、原始大小写(后续统一小写)
|
||||||
|
- 不做:词形还原、词干提取、停用词删除
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 原始文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 清洗后的文本
|
||||||
|
"""
|
||||||
|
if not text or not isinstance(text, str):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 1. 移除用户提及 (@username)
|
||||||
|
text = re.sub(r'@\w+', '', text)
|
||||||
|
|
||||||
|
# 2. 移除 URL 链接
|
||||||
|
text = re.sub(r'http\S+|www\S+', '', text)
|
||||||
|
|
||||||
|
# 3. 移除多余空格和换行
|
||||||
|
text = re.sub(r'\s+', ' ', text).strip()
|
||||||
|
|
||||||
|
return text
|
||||||
|
|
||||||
|
|
||||||
|
def normalize_text(text: str) -> str:
|
||||||
|
"""文本标准化
|
||||||
|
|
||||||
|
统一小写,但不进行词形还原或词干提取。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
text: 清洗后的文本
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 标准化后的文本
|
||||||
|
"""
|
||||||
|
if not text or not isinstance(text, str):
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# 仅统一小写
|
||||||
|
return text.lower()
|
||||||
|
|
||||||
|
|
||||||
|
# --- 数据加载与预处理 ---
|
||||||
|
|
||||||
|
|
||||||
|
def load_tweets(file_path: str | Path = "Tweets.csv") -> pl.DataFrame:
|
||||||
|
"""加载原始推文数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: CSV 文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pl.DataFrame: 原始推文数据
|
||||||
|
"""
|
||||||
|
df = pl.read_csv(file_path)
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
def validate_raw_tweets(df: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""验证原始推文数据结构(清洗前)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: 原始 Polars DataFrame
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pl.DataFrame: 验证通过的 DataFrame
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
pa.errors.SchemaError: 验证失败
|
||||||
|
"""
|
||||||
|
return RawTweetSchema.validate(df)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_clean_tweets(df: pl.DataFrame) -> pl.DataFrame:
|
||||||
|
"""验证清洗后推文数据(严格模式)
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: 清洗后的 Polars DataFrame
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pl.DataFrame: 验证通过的 DataFrame
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
pa.errors.SchemaError: 验证失败
|
||||||
|
"""
|
||||||
|
return CleanTweetSchema.validate(df)
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_tweets(
|
||||||
|
df: pl.DataFrame,
|
||||||
|
validate: bool = True,
|
||||||
|
min_confidence: float = 0.5
|
||||||
|
) -> pl.DataFrame:
|
||||||
|
"""推文数据预处理流水线
|
||||||
|
|
||||||
|
处理步骤:
|
||||||
|
1. 筛选:仅保留情感置信度 >= min_confidence 的样本
|
||||||
|
2. 文本清洗:应用 clean_text 和 normalize_text
|
||||||
|
3. 删除缺失值:删除 text 为空的样本
|
||||||
|
4. 删除重复行:基于 tweet_id 去重
|
||||||
|
5. 可选:进行 Schema 校验
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: 原始 Polars DataFrame
|
||||||
|
validate: 是否进行清洗后 Schema 校验
|
||||||
|
min_confidence: 最低情感置信度阈值
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pl.DataFrame: 清洗后的 DataFrame
|
||||||
|
"""
|
||||||
|
# 1. 筛选高置信度样本
|
||||||
|
df_filtered = df.filter(
|
||||||
|
pl.col("airline_sentiment_confidence") >= min_confidence
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. 文本清洗和标准化
|
||||||
|
df_clean = df_filtered.with_columns([
|
||||||
|
pl.col("text").map_elements(clean_text, return_dtype=pl.String).alias("text_cleaned"),
|
||||||
|
pl.col("text").alias("text_original"),
|
||||||
|
])
|
||||||
|
|
||||||
|
df_clean = df_clean.with_columns([
|
||||||
|
pl.col("text_cleaned").map_elements(normalize_text, return_dtype=pl.String).alias("text_cleaned"),
|
||||||
|
])
|
||||||
|
|
||||||
|
# 3. 删除缺失值(text_cleaned 为空或 airline_sentiment 为空)
|
||||||
|
df_clean = df_clean.filter(
|
||||||
|
(pl.col("text_cleaned").is_not_null()) &
|
||||||
|
(pl.col("text_cleaned") != "") &
|
||||||
|
(pl.col("airline_sentiment").is_not_null())
|
||||||
|
)
|
||||||
|
|
||||||
|
# 4. 删除重复行(基于 tweet_id)
|
||||||
|
df_clean = df_clean.unique(subset=["tweet_id"], keep="first")
|
||||||
|
|
||||||
|
# 5. 选择需要的列
|
||||||
|
df_clean = df_clean.select([
|
||||||
|
"tweet_id",
|
||||||
|
"airline_sentiment",
|
||||||
|
"airline_sentiment_confidence",
|
||||||
|
"negativereason",
|
||||||
|
"negativereason_confidence",
|
||||||
|
"airline",
|
||||||
|
"text_cleaned",
|
||||||
|
"text_original",
|
||||||
|
])
|
||||||
|
|
||||||
|
# 6. 可选校验
|
||||||
|
if validate:
|
||||||
|
df_clean = validate_clean_tweets(df_clean)
|
||||||
|
|
||||||
|
return df_clean
|
||||||
|
|
||||||
|
|
||||||
|
def save_cleaned_tweets(df: pl.DataFrame, output_path: str | Path = "data/Tweets_cleaned.csv") -> None:
|
||||||
|
"""保存清洗后的数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: 清洗后的 Polars DataFrame
|
||||||
|
output_path: 输出文件路径
|
||||||
|
"""
|
||||||
|
output_path = Path(output_path)
|
||||||
|
output_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
df.write_csv(output_path)
|
||||||
|
print(f"清洗后的数据已保存至 {output_path}")
|
||||||
|
|
||||||
|
|
||||||
|
def load_cleaned_tweets(file_path: str | Path = "data/Tweets_cleaned.csv") -> pl.DataFrame:
|
||||||
|
"""加载清洗后的推文数据
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: 清洗后的 CSV 文件路径
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
pl.DataFrame: 清洗后的推文数据
|
||||||
|
"""
|
||||||
|
df = pl.read_csv(file_path)
|
||||||
|
return df
|
||||||
|
|
||||||
|
|
||||||
|
# --- 数据统计与分析 ---
|
||||||
|
|
||||||
|
|
||||||
|
def print_data_summary(df: pl.DataFrame, title: str = "数据统计") -> None:
|
||||||
|
"""打印数据摘要信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: Polars DataFrame
|
||||||
|
title: 标题
|
||||||
|
"""
|
||||||
|
print(f"\n{'='*60}")
|
||||||
|
print(f"{title}")
|
||||||
|
print(f"{'='*60}")
|
||||||
|
print(f"样本总数: {len(df)}")
|
||||||
|
print(f"\n情感分布:")
|
||||||
|
print(df.group_by("airline_sentiment").agg(
|
||||||
|
pl.len().alias("count"),
|
||||||
|
(pl.len() / len(df) * 100).alias("percentage")
|
||||||
|
).sort("count", descending=True))
|
||||||
|
|
||||||
|
print(f"\n航空公司分布:")
|
||||||
|
print(df.group_by("airline").agg(
|
||||||
|
pl.len().alias("count"),
|
||||||
|
(pl.len() / len(df) * 100).alias("percentage")
|
||||||
|
).sort("count", descending=True))
|
||||||
|
|
||||||
|
print(f"\n文本长度统计:")
|
||||||
|
df_with_length = df.with_columns([
|
||||||
|
pl.col("text_cleaned").str.len_chars().alias("text_length")
|
||||||
|
])
|
||||||
|
print(df_with_length.select([
|
||||||
|
pl.col("text_length").min().alias("最小长度"),
|
||||||
|
pl.col("text_length").max().alias("最大长度"),
|
||||||
|
pl.col("text_length").mean().alias("平均长度"),
|
||||||
|
pl.col("text_length").median().alias("中位数长度"),
|
||||||
|
]))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print(">>> 1. 加载原始数据")
|
||||||
|
df_raw = load_tweets("Tweets.csv")
|
||||||
|
print(f"原始数据样本数: {len(df_raw)}")
|
||||||
|
print(df_raw.head(3))
|
||||||
|
|
||||||
|
print("\n>>> 2. 验证原始数据")
|
||||||
|
df_validated = validate_raw_tweets(df_raw)
|
||||||
|
print("✅ 原始数据验证通过")
|
||||||
|
|
||||||
|
print("\n>>> 3. 清洗数据")
|
||||||
|
df_clean = preprocess_tweets(df_validated, validate=True, min_confidence=0.5)
|
||||||
|
print(f"清洗后样本数: {len(df_clean)} (原始: {len(df_raw)})")
|
||||||
|
print("✅ 清洗后数据验证通过")
|
||||||
|
|
||||||
|
print("\n>>> 4. 保存清洗后的数据")
|
||||||
|
save_cleaned_tweets(df_clean, "data/Tweets_cleaned.csv")
|
||||||
|
|
||||||
|
print("\n>>> 5. 数据统计")
|
||||||
|
print_data_summary(df_clean, "清洗后数据统计")
|
||||||
|
|
||||||
|
print("\n>>> 6. 清洗示例对比")
|
||||||
|
print("\n原始文本:")
|
||||||
|
print(df_clean.select("text_original").head(3).to_pandas()["text_original"].to_string(index=False))
|
||||||
|
print("\n清洗后文本:")
|
||||||
|
print(df_clean.select("text_cleaned").head(3).to_pandas()["text_cleaned"].to_string(index=False))
|
||||||
Loading…
Reference in New Issue
Block a user