docs: 更新README文档和.env示例文件
更新README文档,添加团队成员信息、项目简介、技术栈等内容,并优化.env示例文件的格式
This commit is contained in:
parent
72ffdf1647
commit
73490e6799
@ -1,10 +1,10 @@
|
||||
# DeepSeek API Configuration
|
||||
DEEPSEEK_API_KEY="your-deepseek-api-key-here"
|
||||
DEEPSEEK_BASE_URL="https://api.deepseek.com"
|
||||
DEEPSEEK_API_KEY=your-deepseek-api-key-here
|
||||
DEEPSEEK_BASE_URL=https://api.deepseek.com
|
||||
|
||||
# OpenAI Compatibility (for DeepSeek)
|
||||
OPENAI_API_KEY="${DEEPSEEK_API_KEY}"
|
||||
OPENAI_BASE_URL="${DEEPSEEK_BASE_URL}"
|
||||
OPENAI_API_KEY=your-deepseek-api-key-here
|
||||
OPENAI_BASE_URL=https://api.deepseek.com
|
||||
|
||||
# Project Configuration
|
||||
PROJECT_NAME="spam-classification-system"
|
||||
|
||||
353
README.md
353
README.md
@ -1,11 +1,188 @@
|
||||
# 垃圾短信分类系统
|
||||
|
||||
## 项目概述
|
||||
|
||||
本项目是一个基于**传统机器学习 + LLM + Agent**的垃圾短信分类系统,旨在实现可落地的智能预测与行动建议。系统使用传统机器学习完成可量化的垃圾短信预测任务,再用 LLM + Agent 把预测结果变成可执行的决策/建议,并保证输出结构化、可追溯、可复现。
|
||||
> **机器学习 (Python) 课程设计**
|
||||
|
||||
|
||||
## 👥 团队成员
|
||||
|
||||
|
||||
| 姓名 | 学号 | 贡献 |
|
||||
|------|------|------|
|
||||
| 朱指乐 | 2311020135 | 数据处理、模型训练 |
|
||||
| 肖康 | 2311020125 | Agent 开发、LLM 服务 |
|
||||
| 龙思富 | 2311020114 | 可视化、优化streamlit应用、文档撰写|
|
||||
|
||||
|
||||
## 📝 项目简介
|
||||
|
||||
|
||||
本项目是一个基于**传统机器学习 + LLM + Agent**的垃圾短信分类系统,旨在实现可落地的智能预测与行动建议。系统使用 SMS Spam Collection 数据集,通过传统机器学习完成垃圾短信的量化预测,再利用 LLM 和 Agent 技术将预测结果转化为结构化、可执行的决策建议,确保输出结果可追溯、可复现。
|
||||
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
|
||||
```bash
|
||||
# 克隆仓库
|
||||
git clone http://hblu.top:3000/MachineLearning2025/CourseDesign.git
|
||||
cd CourseDesign
|
||||
|
||||
# 安装依赖
|
||||
pip install uv -i https://mirrors.aliyun.com/pypi/simple/
|
||||
uv config set index-url https://mirrors.aliyun.com/pypi/simple/
|
||||
uv sync
|
||||
|
||||
# 配置环境变量
|
||||
cp .env.example .env
|
||||
# 编辑 .env 填入 API Key
|
||||
|
||||
# 运行 Demo
|
||||
uv run streamlit run src/streamlit_app.py
|
||||
```
|
||||
|
||||
|
||||
## 1️⃣ 问题定义与数据
|
||||
|
||||
|
||||
### 1.1 任务描述
|
||||
|
||||
|
||||
本项目是一个二分类任务,目标是自动识别垃圾短信(spam)和正常短信(ham)。业务目标是构建一个高准确率、可解释的垃圾短信分类系统,帮助用户有效过滤垃圾信息,提升信息安全和用户体验。
|
||||
|
||||
|
||||
### 1.2 数据来源
|
||||
|
||||
|
||||
| 项目 | 说明 |
|
||||
|------|------|
|
||||
| 数据集名称 | SMS Spam Collection |
|
||||
| 数据链接 | `https://www.kaggle.com/datasets/uciml/sms-spam-collection-dataset` |
|
||||
| 样本量 | 5,572 条 |
|
||||
| 特征数 | 1 个(短信文本) |
|
||||
|
||||
|
||||
### 1.3 数据切分与防泄漏
|
||||
|
||||
|
||||
数据按 8:2 比例分割为训练集和测试集,确保模型在独立的测试集上进行评估。在数据预处理和特征工程阶段,所有操作仅在训练集上进行,避免信息泄漏到测试集。使用 TF-IDF 进行文本向量化时,同样严格遵循先训练后应用的原则。
|
||||
|
||||
|
||||
## 2️⃣ 机器学习流水线
|
||||
|
||||
|
||||
### 2.1 基线模型
|
||||
|
||||
|
||||
| 模型 | 指标 | 结果 |
|
||||
|------|------|------|
|
||||
| Logistic Regression | 准确率 | 0.978 |
|
||||
| Logistic Regression | F1 分数(Macro) | 0.959 |
|
||||
|
||||
|
||||
### 2.2 进阶模型
|
||||
|
||||
|
||||
| 模型 | 指标 | 结果 |
|
||||
|------|------|------|
|
||||
| LightGBM | 准确率 | 0.985 |
|
||||
| LightGBM | F1 分数(Macro) | 0.971 |
|
||||
|
||||
|
||||
### 2.3 误差分析
|
||||
|
||||
|
||||
模型在以下类型的样本上表现相对较差:
|
||||
1. 包含大量特殊字符或缩写的短信
|
||||
2. 内容模糊、边界不清的促销短信
|
||||
3. 混合中英文的短信
|
||||
4. 模仿正常短信格式的垃圾短信
|
||||
|
||||
这主要是因为文本特征提取方法(TF-IDF)对语义理解有限,无法完全捕捉复杂的语言模式和上下文信息。
|
||||
|
||||
|
||||
## 3️⃣ Agent 实现
|
||||
|
||||
|
||||
### 3.1 工具定义
|
||||
|
||||
|
||||
| 工具名 | 功能 | 输入 | 输出 |
|
||||
|--------|------|------|------|
|
||||
| `predict_spam` | 使用机器学习模型预测短信是否为垃圾短信 | 短信文本 | 分类结果(spam/ham)和概率 |
|
||||
| `explain_prediction` | 解释模型预测结果并生成行动建议 | 短信文本、分类结果、概率 | 结构化的解释和建议 |
|
||||
| `translate_text` | 将文本翻译成目标语言 | 文本、目标语言 | 翻译后的文本 |
|
||||
|
||||
|
||||
### 3.2 决策流程
|
||||
|
||||
|
||||
Agent 按照以下流程执行任务:
|
||||
1. 接收用户提供的短信文本
|
||||
2. 使用 `predict_spam` 工具进行分类预测
|
||||
3. 使用 `explain_prediction` 工具解释分类结果并生成行动建议
|
||||
4. 如果短信为英文,可选择使用 `translate_text` 工具翻译成中文
|
||||
5. 向用户提供清晰、完整的分类结果、解释和建议
|
||||
|
||||
|
||||
### 3.3 案例展示
|
||||
|
||||
|
||||
**输入**:
|
||||
```
|
||||
Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's
|
||||
```
|
||||
|
||||
**输出**:
|
||||
```json
|
||||
{
|
||||
"classification": {
|
||||
"label": "spam",
|
||||
"probability": {
|
||||
"spam": 0.98,
|
||||
"ham": 0.02
|
||||
}
|
||||
},
|
||||
"explanation": {
|
||||
"content_summary": "这是一条关于免费赢取足总杯决赛门票的竞赛广告短信",
|
||||
"classification_reason": "短信包含'Free entry'、'win'、'comp'等典型的垃圾短信关键词,且提供了需要用户回复的电话号码,符合垃圾短信的特征",
|
||||
"confidence_level": "高",
|
||||
"confidence_explanation": "模型以98%的概率将其分类为垃圾短信,基于文本中的垃圾短信特征词汇和结构",
|
||||
"suggestions": ["不要回复此短信,避免产生额外费用", "将此号码加入黑名单", "删除该短信"]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## 4️⃣ 开发心得
|
||||
|
||||
|
||||
### 4.1 主要困难与解决方案
|
||||
|
||||
|
||||
1. **文本特征提取**:原始文本数据难以直接用于机器学习模型,解决方案是使用 TF-IDF 进行文本向量化,将文本转化为数值特征。
|
||||
2. **模型可解释性**:传统机器学习模型的预测结果缺乏可解释性,解决方案是集成 LLM 服务,对模型预测结果进行自然语言解释。
|
||||
3. **API 集成与错误处理**:LLM API 调用可能会遇到各种错误,解决方案是实现完善的错误处理机制,确保系统稳定性。
|
||||
|
||||
|
||||
### 4.2 对 AI 辅助编程的感受
|
||||
|
||||
|
||||
AI 辅助编程工具(如 GitHub Copilot)在代码编写和问题解决方面提供了很大帮助,特别是在处理重复性任务和学习新框架时。它可以快速生成代码模板,提供解决方案建议,显著提高开发效率。但同时也需要注意,AI 生成的代码可能存在错误或不符合项目规范,需要人工仔细检查和调试。
|
||||
|
||||
|
||||
### 4.3 局限与未来改进
|
||||
|
||||
|
||||
1. **模型性能**:当前模型在处理复杂语言模式和上下文理解方面仍有提升空间,可以考虑使用更先进的文本表示方法(如 BERT)。
|
||||
2. **多语言支持**:目前系统主要支持中英文短信,未来可以扩展到更多语言。
|
||||
3. **实时性**:可以优化模型推理速度,实现实时分类功能。
|
||||
4. **用户界面**:可以进一步改进 Streamlit 应用的用户体验,增加更多交互功能和可视化效果。
|
||||
|
||||
|
||||
## 技术栈
|
||||
|
||||
|
||||
| 组件 | 技术 | 版本要求 |
|
||||
|------|------|----------|
|
||||
| 项目管理 | uv | 最新版 |
|
||||
@ -13,189 +190,33 @@
|
||||
| 数据验证 | pandera | >=0.18.0 |
|
||||
| 机器学习 | scikit-learn + lightgbm | sklearn>=1.3.0, lightgbm>=4.0.0 |
|
||||
| LLM 框架 | openai | >=1.0.0 |
|
||||
| Agent 框架 | pydantic + pydantic-ai | pydantic>=2.0.0 |
|
||||
| Agent 框架 | pydantic | pydantic>=2.0.0 |
|
||||
| 可视化 | streamlit | >=1.20.0 |
|
||||
| 文本处理 | nltk | >=3.8.0 |
|
||||
|
||||
## 项目结构
|
||||
|
||||
```
|
||||
├── src/
|
||||
│ ├── data/ # 数据处理模块
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── preprocess.py # 数据预处理
|
||||
│ │ └── validation.py # 数据验证
|
||||
│ ├── models/ # 机器学习模型
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── train.py # 模型训练
|
||||
│ │ ├── evaluate.py # 模型评估
|
||||
│ │ └── predict.py # 模型预测
|
||||
│ ├── agent/ # Agent 模块
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── agent.py # Agent 核心逻辑
|
||||
│ │ └── tools.py # Agent 工具定义
|
||||
│ ├── llm/ # LLM 服务
|
||||
│ │ ├── __init__.py
|
||||
│ │ └── llm_service.py # DeepSeek API 集成
|
||||
│ ├── main.py # 主程序入口
|
||||
│ └── streamlit_app.py # Streamlit 可视化应用
|
||||
├── data/ # 数据集目录
|
||||
│ └── spam.csv # 垃圾短信数据集
|
||||
├── models/ # 模型保存目录
|
||||
├── .env.example # 环境变量示例
|
||||
├── .gitignore # Git 忽略文件
|
||||
├── pyproject.toml # 项目配置
|
||||
└── README.md # 项目说明文档
|
||||
```
|
||||
|
||||
## 快速开始
|
||||
|
||||
### 1. 安装依赖
|
||||
|
||||
```bash
|
||||
# 安装 uv(如果尚未安装)
|
||||
pip install uv -i https://mirrors.aliyun.com/pypi/simple/
|
||||
|
||||
# 配置 PyPI 镜像
|
||||
uv config set index-url https://mirrors.aliyun.com/pypi/simple/
|
||||
|
||||
# 同步依赖
|
||||
uv sync
|
||||
```
|
||||
|
||||
### 2. 配置 API 密钥
|
||||
|
||||
```bash
|
||||
# 复制环境变量示例文件
|
||||
cp .env.example .env
|
||||
|
||||
# 编辑 .env 文件,填入你的 DeepSeek API Key
|
||||
# DEEPSEEK_API_KEY="your-key-here"
|
||||
```
|
||||
|
||||
### 3. 运行应用
|
||||
|
||||
#### 方式 A:Streamlit 可视化应用(推荐)
|
||||
|
||||
```bash
|
||||
uv run streamlit run src/streamlit_app.py
|
||||
```
|
||||
|
||||
#### 方式 B:命令行 Agent Demo
|
||||
|
||||
```bash
|
||||
uv run python src/agent_app.py
|
||||
```
|
||||
|
||||
#### 方式 C:运行模型训练
|
||||
|
||||
```bash
|
||||
uv run python src/models/train.py
|
||||
```
|
||||
|
||||
## 使用说明
|
||||
|
||||
### 1. Streamlit 应用
|
||||
|
||||
启动 Streamlit 应用后,你可以:
|
||||
|
||||
- **单条短信分类**:在左侧输入框中输入短信内容,点击「开始分类」按钮,系统将返回分类结果、解释和建议
|
||||
- **批量分类**:上传包含 `text` 列的 CSV 文件,系统将自动对所有短信进行分类
|
||||
- **模型选择**:在侧边栏选择使用 LightGBM 或 Logistic Regression 模型
|
||||
- **语言选择**:在侧边栏选择输出结果的语言(中文/英文)
|
||||
- **重新训练模型**:在「模型训练」展开栏中点击「重新训练模型」按钮
|
||||
|
||||
### 2. 命令行使用
|
||||
|
||||
```python
|
||||
from src.agent import agent
|
||||
|
||||
# 分类并解释短信
|
||||
result = agent.classify_and_explain("Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005.")
|
||||
|
||||
print(f"分类结果: {result['classification']['label']}")
|
||||
print(f"分类概率: {result['classification']['probability']}")
|
||||
print(f"解释: {result['explanation']['classification_reason']}")
|
||||
print(f"建议: {result['explanation']['suggestions']}")
|
||||
```
|
||||
|
||||
## 模型训练
|
||||
|
||||
### 训练流程
|
||||
|
||||
1. **数据加载**:使用 Polars 加载 `data/spam.csv` 数据集
|
||||
2. **数据预处理**:
|
||||
- 清洗文本(去除 HTML 实体、特殊字符、多余空格)
|
||||
- 转换标签(spam → 1,ham → 0)
|
||||
- 数据验证(使用 Pandera 验证数据质量)
|
||||
3. **数据分割**:按 8:2 比例分割训练集和测试集
|
||||
4. **特征工程**:使用 TF-IDF 进行文本向量化
|
||||
5. **模型训练**:
|
||||
- 基线模型:Logistic Regression
|
||||
- 强模型:LightGBM
|
||||
6. **模型评估**:计算准确率、精确率、召回率、F1 分数
|
||||
7. **模型保存**:将训练好的模型保存到 `models/` 目录
|
||||
|
||||
### 评估结果
|
||||
|
||||
| 模型 | 准确率 | 精确率(Macro) | 召回率(Macro) | F1 分数(Macro) |
|
||||
|------|--------|----------------|----------------|------------------|
|
||||
| Logistic Regression | 0.978 | 0.964 | 0.954 | 0.959 |
|
||||
| LightGBM | 0.985 | 0.974 | 0.968 | 0.971 |
|
||||
|
||||
## 系统架构
|
||||
|
||||
### 1. 数据层
|
||||
|
||||
- **数据加载**:使用 Polars 高效加载和处理大规模数据
|
||||
- **数据清洗**:实现可复现的数据清洗流水线
|
||||
- **数据验证**:使用 Pandera 定义数据 Schema,确保数据质量
|
||||
|
||||
### 2. 模型层
|
||||
|
||||
- **特征工程**:TF-IDF 文本向量化
|
||||
- **模型训练**:支持多种模型(LightGBM、Logistic Regression)
|
||||
- **模型评估**:全面的评估指标和可视化
|
||||
- **模型管理**:模型保存、加载和版本控制
|
||||
|
||||
### 3. LLM 层
|
||||
|
||||
- **文本翻译**:将英文短信翻译成中文
|
||||
- **结果解释**:解释模型分类结果的原因
|
||||
- **行动建议**:根据分类结果生成可执行的建议
|
||||
- **模式分析**:分析垃圾短信的常见模式
|
||||
|
||||
### 4. Agent 层
|
||||
|
||||
- **工具定义**:
|
||||
- `predict_spam`:使用机器学习模型预测短信是否为垃圾短信
|
||||
- `explain_prediction`:使用 LLM 解释分类结果并生成行动建议
|
||||
- `translate_text`:将文本翻译成目标语言
|
||||
- **Agent 逻辑**:实现工具调用、结果整合和自然语言生成
|
||||
- **结构化输出**:确保输出结果可追溯、可复现
|
||||
|
||||
## 贡献指南
|
||||
|
||||
1. 朱指乐-仓库创建-项目初始化-数据获取-数据预处理-模型训练-模型评估-模型保存
|
||||
2. 肖康-LLM服务-实现文本解释和建议生成功能-agent实现-工具调用-结构化输出
|
||||
3. 龙思富-可视化-交互式可视化应用-文档编写
|
||||
|
||||
## 许可证
|
||||
|
||||
|
||||
MIT License
|
||||
|
||||
|
||||
## 致谢
|
||||
|
||||
|
||||
- 感谢 [DeepSeek](https://www.deepseek.com/) 提供的 LLM API
|
||||
- 感谢 Kaggle 提供的 [SMS Spam Collection](https://www.kaggle.com/datasets/uciml/sms-spam-collection-dataset) 数据集
|
||||
- 感谢所有开源库的贡献者
|
||||
|
||||
|
||||
## 联系方式
|
||||
|
||||
|
||||
如有问题或建议,欢迎通过以下方式联系:
|
||||
|
||||
- 项目地址:[http://hblu.top:3000/MachineLearning2025/CourseDesign](http://hblu.top:3000/MachineLearning2025/CourseDesign)
|
||||
- 邮箱:your-email@example.com
|
||||
- 邮箱:xxxxxxxxxx@gmail.com
|
||||
|
||||
|
||||
---
|
||||
|
||||
|
||||
@ -50,10 +50,20 @@ class SpamClassificationAgent:
|
||||
self.api_key = env_vars.get("DEEPSEEK_API_KEY") or os.getenv("DEEPSEEK_API_KEY")
|
||||
self.base_url = env_vars.get("DEEPSEEK_BASE_URL") or os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
|
||||
|
||||
# 验证并修复base_url,确保包含协议前缀
|
||||
if self.base_url and not (self.base_url.startswith("http://") or self.base_url.startswith("https://")):
|
||||
print(f"警告: base_url '{self.base_url}' 缺少协议前缀,将添加 'https://'")
|
||||
self.base_url = f"https://{self.base_url}"
|
||||
# 清理并修复base_url,确保包含协议前缀
|
||||
if self.base_url:
|
||||
# 清理URL:移除多余的空格、引号和反引号
|
||||
self.base_url = self.base_url.strip()
|
||||
# 移除可能的引号
|
||||
for quote_char in ['"', "'", '`']:
|
||||
if self.base_url.startswith(quote_char) and self.base_url.endswith(quote_char):
|
||||
self.base_url = self.base_url[1:-1].strip()
|
||||
break
|
||||
|
||||
# 确保包含协议前缀
|
||||
if not (self.base_url.startswith("http://") or self.base_url.startswith("https://")):
|
||||
print(f"警告: base_url '{self.base_url}' 缺少协议前缀,将添加 'https://'")
|
||||
self.base_url = f"https://{self.base_url}"
|
||||
|
||||
# 延迟创建客户端,直到实际需要时
|
||||
self.client = None
|
||||
|
||||
@ -29,10 +29,20 @@ class LLMService:
|
||||
self.api_key = env_vars.get("DEEPSEEK_API_KEY") or os.getenv("DEEPSEEK_API_KEY")
|
||||
self.base_url = env_vars.get("DEEPSEEK_BASE_URL") or os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
|
||||
|
||||
# 验证并修复base_url,确保包含协议前缀
|
||||
if self.base_url and not (self.base_url.startswith("http://") or self.base_url.startswith("https://")):
|
||||
print(f"警告: base_url '{self.base_url}' 缺少协议前缀,将添加 'https://'")
|
||||
self.base_url = f"https://{self.base_url}"
|
||||
# 清理并修复base_url,确保包含协议前缀
|
||||
if self.base_url:
|
||||
# 清理URL:移除多余的空格、引号和反引号
|
||||
self.base_url = self.base_url.strip()
|
||||
# 移除可能的引号
|
||||
for quote_char in ['"', "'", '`']:
|
||||
if self.base_url.startswith(quote_char) and self.base_url.endswith(quote_char):
|
||||
self.base_url = self.base_url[1:-1].strip()
|
||||
break
|
||||
|
||||
# 确保包含协议前缀
|
||||
if not (self.base_url.startswith("http://") or self.base_url.startswith("https://")):
|
||||
print(f"警告: base_url '{self.base_url}' 缺少协议前缀,将添加 'https://'")
|
||||
self.base_url = f"https://{self.base_url}"
|
||||
|
||||
# 默认模型
|
||||
self.default_model = "deepseek-chat"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user