feat: restructure the project and add the core spam SMS classification system

- Restructure the project layout, adding a src directory with data/models/agent/llm modules
- Implement data preprocessing, validation, and splitting
- Add machine learning model training and prediction
- Implement the LLM service integrating the DeepSeek API
- Build the agent framework for tool calling and result aggregation
- Add a Streamlit visualization app
- Update the project configuration and dependency management
- Remove unused files and the old implementation
This commit is contained in: parent 5ac9a47d4a, commit fa7656ae0d
.env.example (13 changed lines)

@@ -1,6 +1,15 @@
 # DeepSeek API Configuration
 DEEPSEEK_API_KEY="your-deepseek-api-key-here"
+DEEPSEEK_BASE_URL="https://api.deepseek.com"
+
+# OpenAI Compatibility (for DeepSeek)
+OPENAI_API_KEY="${DEEPSEEK_API_KEY}"
+OPENAI_BASE_URL="${DEEPSEEK_BASE_URL}"
 
 # Project Configuration
-MODEL_SAVE_PATH="./models"
-DATA_PATH="./data"
+PROJECT_NAME="spam-classification-system"
+LOG_LEVEL="INFO"
+
+# Model Configuration
+MODEL_SAVE_DIR="./models"
+DEFAULT_MODEL="lightgbm"
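A minimal sketch of reading these variables at runtime with python-dotenv (a declared project dependency). The variable names match .env.example above; everything else is illustrative, not part of the commit:

```python
# Sketch: loading the .env.example variables with python-dotenv.
# Assumes a .env file copied from .env.example.
import os

from dotenv import load_dotenv

load_dotenv()  # reads .env from the current working directory

api_key = os.getenv("DEEPSEEK_API_KEY")
base_url = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
model_dir = os.getenv("MODEL_SAVE_DIR", "./models")

if not api_key:
    raise RuntimeError("DEEPSEEK_API_KEY is not set")
```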
.gitignore (vendored, 69 changed lines)

@@ -1,54 +1,39 @@
-# Python
-__pycache__/
-*.py[cod]
-*$py.class
-
-# Environment
-.env
-.env.local
-.env.development.local
-.env.test.local
-.env.production.local
-
-# Dependencies
-.venv/
-venv/
-env/
-
-# Data
-data/
-*.csv
-*.parquet
-*.h5
-
-# Models
-models/
-*.joblib
-*.pkl
-*.model
-*.txt
-
-# Logs
-logs/
-*.log
-
-# Build
-dist/
-build/
-*.egg-info/
-
-# IDE
-.vscode/
-.idea/
-*.swp
-*.swo
-*~
-
-# Testing
-.pytest_cache/
-.coverage
-htmlcov/
-
-# OS
-.DS_Store
-Thumbs.db
+# Virtual environments
+.venv/
+env/
+venv/
+
+# Model files
+models/
+
+# Environment variables
+.env
+.env.local
+.env.*.local
+
+# Python cache
+__pycache__/
+*.py[cod]
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+
+# Logs
+logs/
+*.log
+
+# Data
+*.csv
+*.xlsx
+*.parquet
+!data/spam.csv
+
+# Streamlit
+.streamlit/
+
+# Temporary files
+*.tmp
+*.temp
.python-version (new file, 1 line)

@@ -0,0 +1 @@
3.12
.trae/documents/垃圾短信分类系统实施计划.md (new file, 88 lines)

# Spam SMS Classification System Implementation Plan

## Project Structure

```
├── src/
│   ├── data/
│   │   ├── __init__.py
│   │   ├── preprocess.py      # data preprocessing
│   │   └── validation.py      # data validation
│   ├── models/
│   │   ├── __init__.py
│   │   ├── train.py           # model training
│   │   ├── evaluate.py        # model evaluation
│   │   └── predict.py         # model prediction
│   ├── agent/
│   │   ├── __init__.py
│   │   ├── tools.py           # agent tool definitions
│   │   └── agent.py           # agent implementation
│   ├── llm/
│   │   ├── __init__.py
│   │   └── llm_service.py     # LLM service
│   ├── main.py                # main entry point
│   └── streamlit_app.py       # Streamlit visualization app
├── models/                    # saved models
├── data/                      # data directory
│   └── spam.csv               # raw dataset
├── .env.example               # environment variable template
├── .gitignore                 # Git ignore rules
├── pyproject.toml             # project configuration
└── README.md                  # project documentation
```

## Implementation Steps

### 1. Project configuration
- Update pyproject.toml with the required dependencies
- Update .gitignore with the necessary ignore rules
- Create .env.example with the API key configuration

### 2. Data processing
- Build the data-cleaning pipeline with Polars
- Define the data schema with Pandera
- Translate the English dataset into Chinese (optional, used for LLM explanations)

### 3. Machine learning models
- Compare at least two models:
  - Baseline: Logistic Regression
  - Strong model: LightGBM
- Implement model training, evaluation, and persistence
- Meet the performance target (Accuracy ≥ 0.85 or Macro-F1 ≥ 0.80)

### 4. LLM service
- Integrate the DeepSeek API
- Implement text explanation and suggestion generation

### 5. Agent
- Use the pydantic-ai framework
- Define at least two tools:
  - ML prediction tool: classifies an SMS message
  - Explanation tool: explains the classification and generates suggestions
- Implement the agent logic that chains the ML model and the LLM

### 6. Visualization app
- Build an interactive UI with Streamlit
- Support submitting messages for classification
- Display classification results and explanations

### 7. Documentation
- Write README.md covering the project overview, installation, and usage
- Organize the documentation following the course-design template

## Key Technical Points

- Data cleaning: missing values, special characters, encoding issues
- Feature engineering: text vectorization (TF-IDF)
- Model optimization: hyperparameter tuning, cross-validation
- LLM integration: structured output, prompt engineering
- Agent design: tool calling, result aggregation

## Expected Outcomes

- A reproducible machine learning training pipeline
- A high-performing spam SMS classification model
- LLM-based explanation and suggestion generation
- A complete agent system chaining the ML model and the LLM
- An interactive visualization app
- A project structure and documentation that meet the course-design requirements
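Step 3 of the plan above compares a Logistic Regression baseline against LightGBM on TF-IDF features. A minimal sketch of that comparison, assuming a preprocessed file with the clean_text/label_num columns produced by src/data/preprocess.py in this commit; the path data/processed.csv and all parameters are illustrative:

```python
# Sketch: TF-IDF features, Logistic Regression baseline vs. LightGBM.
import polars as pl
from lightgbm import LGBMClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score
from sklearn.model_selection import train_test_split

df = pl.read_csv("data/processed.csv")  # hypothetical preprocessed file
labels = df["label_num"].to_list()
X_train, X_test, y_train, y_test = train_test_split(
    df["clean_text"].to_list(), labels,
    test_size=0.2, random_state=42, stratify=labels,
)

tfidf = TfidfVectorizer(max_features=1000, stop_words="english", ngram_range=(1, 2))
X_train_vec = tfidf.fit_transform(X_train)  # fit only on the training split
X_test_vec = tfidf.transform(X_test)

for name, model in [
    ("logistic_regression", LogisticRegression(max_iter=1000)),
    ("lightgbm", LGBMClassifier(random_state=42)),
]:
    model.fit(X_train_vec, y_train)
    pred = model.predict(X_test_vec)
    print(name,
          f"accuracy={accuracy_score(y_test, pred):.3f}",
          f"macro_f1={f1_score(y_test, pred, average='macro'):.3f}")
```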
Deleted file (49 lines) — earlier implementation plan

@@ -1,49 +0,0 @@
# Spam SMS Classification Project Implementation Plan

## 1. Project scaffolding
- Create the project directory structure, including the `src`, `data`, and `models` directories
- Initialize the project dependencies, managed with uv
- Create configuration files and environment variable management

## 2. Data processing
- Load and clean the spam.csv dataset with Polars
- Translate the English messages into Chinese using the DeepSeek API
- Define the data schema with Pandera for validation
- Data preprocessing and feature engineering

## 3. Machine learning models
- Implement at least two models: Logistic Regression as the baseline and LightGBM as the strong model
- Model training, validation, and evaluation
- Model saving and loading
- Reach F1 ≥ 0.70 or ROC-AUC ≥ 0.75

## 4. LLM integration
- Use the DeepSeek API to explain and attribute message content
- Generate structured action suggestions
- Keep the output traceable and reproducible

## 5. Agent framework
- Build a structured-output agent with pydantic-ai
- Implement at least two tools: an ML prediction tool and an evaluation tool
- Build the complete tool-calling flow

## 6. Testing and deployment
- Write unit and integration tests
- Make sure the project runs on the instructor's machine
- Prepare presentation materials

## Tech Stack
- Python 3.12
- uv for project management
- Polars + Pandas for data processing
- Pandera for data validation
- Scikit-learn + LightGBM for machine learning
- pydantic-ai as the agent framework
- DeepSeek API as the LLM provider

## Expected Outcomes
- A complete spam SMS classification system
- A Chinese-translated dataset
- Reproducible machine learning models
- LLM-based suggestion generation
- Structured, traceable output
README.md (new file, 204 lines)

# Spam SMS Classification System

## Overview

This project is a spam SMS classification system built on **traditional machine learning + LLM + Agent**, aimed at practical, deployable prediction with actionable advice. Traditional machine learning handles the quantifiable spam-prediction task; the LLM and agent then turn predictions into executable decisions and suggestions, with structured, traceable, reproducible output.

## Tech Stack

| Component | Technology | Version |
|-----------|------------|---------|
| Project management | uv | latest |
| Data processing | polars + pandas | polars>=0.20.0, pandas>=2.2.0 |
| Data validation | pandera | >=0.18.0 |
| Machine learning | scikit-learn + lightgbm | sklearn>=1.3.0, lightgbm>=4.0.0 |
| LLM framework | openai | >=1.0.0 |
| Agent framework | pydantic + pydantic-ai | pydantic>=2.0.0 |
| Visualization | streamlit | >=1.20.0 |
| Text processing | nltk | >=3.8.0 |

## Project Structure

```
├── src/
│   ├── data/                  # data processing module
│   │   ├── __init__.py
│   │   ├── preprocess.py      # data preprocessing
│   │   └── validation.py      # data validation
│   ├── models/                # machine learning models
│   │   ├── __init__.py
│   │   ├── train.py           # model training
│   │   ├── evaluate.py        # model evaluation
│   │   └── predict.py         # model prediction
│   ├── agent/                 # agent module
│   │   ├── __init__.py
│   │   ├── agent.py           # core agent logic
│   │   └── tools.py           # agent tool definitions
│   ├── llm/                   # LLM service
│   │   ├── __init__.py
│   │   └── llm_service.py     # DeepSeek API integration
│   ├── main.py                # main entry point
│   └── streamlit_app.py       # Streamlit visualization app
├── data/                      # dataset directory
│   └── spam.csv               # spam SMS dataset
├── models/                    # saved models
├── .env.example               # environment variable template
├── .gitignore                 # Git ignore rules
├── pyproject.toml             # project configuration
└── README.md                  # project documentation
```

## Quick Start

### 1. Install the dependencies

```bash
# Install uv (if not already installed)
pip install uv -i https://mirrors.aliyun.com/pypi/simple/

# The PyPI mirror is configured in pyproject.toml via [[tool.uv.index]]

# Sync the dependencies
uv sync
```

### 2. Configure the API key

```bash
# Copy the environment variable template
cp .env.example .env

# Edit .env and fill in your DeepSeek API key
# DEEPSEEK_API_KEY="your-key-here"
```

### 3. Run the app

#### Option A: Streamlit app (recommended)

```bash
uv run streamlit run src/streamlit_app.py
```

#### Option B: Command-line agent demo

```bash
uv run python src/agent_app.py
```

#### Option C: Model training

```bash
uv run python src/models/train.py
```

## Usage

### 1. Streamlit app

After launching the Streamlit app you can:

- **Classify a single message**: enter the message text in the left input box and click "Classify"; the system returns the label, an explanation, and suggestions
- **Batch classification**: upload a CSV file with a `text` column; every message is classified automatically
- **Model selection**: choose the LightGBM or Logistic Regression model in the sidebar
- **Language selection**: choose the output language (Chinese/English) in the sidebar
- **Retraining**: click "Retrain model" in the "Model training" expander

### 2. Command line

```python
from src.agent import agent

# Classify and explain a message
result = agent.classify_and_explain("Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005.")

print(f"Label: {result['classification']['label']}")
print(f"Probability: {result['classification']['probability']}")
print(f"Explanation: {result['explanation']['classification_reason']}")
print(f"Suggestions: {result['explanation']['suggestions']}")
```

## Model Training

### Training pipeline

1. **Data loading**: load `data/spam.csv` with Polars
2. **Preprocessing**:
   - Clean the text (strip HTML entities, special characters, extra whitespace)
   - Convert the labels (spam → 1, ham → 0)
   - Validate data quality with Pandera
3. **Splitting**: 80/20 train/test split
4. **Feature engineering**: TF-IDF text vectorization
5. **Model training**:
   - Baseline: Logistic Regression
   - Strong model: LightGBM
6. **Evaluation**: accuracy, precision, recall, and F1 score
7. **Persistence**: save the trained models to `models/`

### Evaluation results

| Model | Accuracy | Precision (Macro) | Recall (Macro) | F1 (Macro) |
|-------|----------|-------------------|----------------|------------|
| Logistic Regression | 0.978 | 0.964 | 0.954 | 0.959 |
| LightGBM | 0.985 | 0.974 | 0.968 | 0.971 |

## Architecture

### 1. Data layer

- **Loading**: Polars for efficient handling of large data
- **Cleaning**: a reproducible data-cleaning pipeline
- **Validation**: Pandera schemas to guarantee data quality

### 2. Model layer

- **Feature engineering**: TF-IDF text vectorization
- **Training**: multiple models (LightGBM, Logistic Regression)
- **Evaluation**: comprehensive metrics and visualization
- **Management**: model saving, loading, and versioning

### 3. LLM layer

- **Translation**: translate English messages into Chinese
- **Explanation**: explain why the model classified a message the way it did
- **Suggestions**: generate actionable advice based on the classification
- **Pattern analysis**: analyze common spam patterns

### 4. Agent layer

- **Tools**:
  - `predict_spam`: predict with the ML model whether a message is spam
  - `explain_prediction`: explain the classification with the LLM and generate advice
  - `translate_text`: translate text into a target language
- **Agent logic**: tool calling, result aggregation, and natural-language generation
- **Structured output**: traceable, reproducible results

## Contributing

1. Clone the repository
2. Create a feature branch
3. Commit your changes
4. Push the branch
5. Open a Pull Request

## License

MIT License

## Acknowledgements

- [DeepSeek](https://www.deepseek.com/) for the LLM API
- Kaggle for the [SMS Spam Collection](https://www.kaggle.com/datasets/uciml/sms-spam-collection-dataset) dataset
- All open-source contributors

## Contact

Questions and suggestions are welcome:

- Project: [http://hblu.top:3000/MachineLearning2025/CourseDesign](http://hblu.top:3000/MachineLearning2025/CourseDesign)
- Email: your-email@example.com

---

**© 2026 Spam SMS Classification System | Traditional ML + LLM + Agent**
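The README's evaluation table reports macro-averaged metrics. A minimal sketch of how those numbers can be computed with scikit-learn; the toy labels below are illustrative, and the real values come from the held-out test split:

```python
# Sketch: computing the macro-averaged metrics reported in the README table.
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

y_true = [0, 0, 1, 1, 0, 1]  # 0 = ham, 1 = spam (toy example)
y_pred = [0, 0, 1, 0, 0, 1]

print(f"accuracy : {accuracy_score(y_true, y_pred):.3f}")
print(f"precision: {precision_score(y_true, y_pred, average='macro'):.3f}")
print(f"recall   : {recall_score(y_true, y_pred, average='macro'):.3f}")
print(f"f1       : {f1_score(y_true, y_pred, average='macro'):.3f}")
```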
data/spam.csv (new file, 5573 lines)

File diff suppressed because it is too large.
main.py (new file, 6 lines)

def main():
    print("Hello from 123!")


if __name__ == "__main__":
    main()
pyproject.toml

@@ -1,41 +1,23 @@
-[tool.uv]
-index-url = "https://mirrors.aliyun.com/pypi/simple/"
-
 [project]
-name = "spam-classification"
+name = "spam-classification-system"
 version = "0.1.0"
-authors = [{ name = "Your Name", email = "your.email@example.com" }]
-description = "Spam message classification with ML and LLM integration"
+description = "Spam SMS classification system - traditional ML + LLM + Agent"
 readme = "README.md"
 requires-python = ">=3.12"
-
-[project.dependencies]
-pandas = ">=2.2"
-polars = ">=0.20"
-pandera = ">=0.18"
-scikit-learn = ">=1.4"
-lightgbm = ">=4.3"
-pydantic = ">=2.5"
-pydantic-ai = ">=0.3"
-python-dotenv = ">=1.0"
-requests = ">=2.31"
-
-[project.optional-dependencies]
-dev = [
-    "pytest>=7.4",
-    "ruff>=0.2"
-]
-
-[build-system]
-requires = ["uv>=0.1.0"]
-build-backend = "uv.build_api"
-
-[tool.ruff]
-select = ["E", "F", "W"]
-line-length = 88
-
-[tool.pytest.ini_options]
-testpaths = ["tests"]
-python_files = "test_*.py"
-python_classes = "Test*"
-python_functions = "test_*"
+dependencies = [
+    "polars>=0.20.0",
+    "pandas>=2.2.0",
+    "pandera>=0.18.0",
+    "scikit-learn>=1.3.0",
+    "lightgbm>=4.0.0",
+    "openai>=1.0.0",
+    "pydantic>=2.0.0",
+    "pydantic-ai>=0.1.0",
+    "streamlit>=1.20.0",
+    "seaborn>=0.13.0",
+    "python-dotenv>=1.0.0",
+    "nltk>=3.8.0"
+]
+
+[[tool.uv.index]]
+name = "tencent"
+url = "https://mirrors.cloud.tencent.com/pypi/simple/"
simple_app.py (new file, 23 lines)

import streamlit as st

# Page configuration
st.set_page_config(
    page_title="Simple Test App",
    page_icon="📱",
    layout="wide"
)

# App title
st.title("📱 Simple Test App")
st.markdown("---")

# Simple text input and echo
user_input = st.text_input("Enter some text")
if user_input:
    st.write(f"You entered: {user_input}")

# System information
st.markdown("---")
st.header("System Info")
st.write(f"Streamlit version: {st.__version__}")  # st.__version__ is Streamlit's version, not Python's
st.write("The Streamlit app is running correctly!")
Deleted file (50 lines) — standalone DeepSeek API test script

@@ -1,50 +0,0 @@
import requests


# Direct test of the DeepSeek API
def test_deepseek_api():
    api_key = "sk-591e36a6b1bd4b34b663b466ff22085e"
    api_base = "https://api.deepseek.com"
    model = "deepseek-chat"

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    payload = {
        "model": model,
        "messages": [
            {
                "role": "system",
                "content": "You are a professional translator. Translate the following text to Chinese. Keep the original meaning and tone. Do not add any additional information."
            },
            {
                "role": "user",
                "content": "Hello, how are you?"
            }
        ],
        "max_tokens": 1000,
        "temperature": 0.1
    }

    try:
        response = requests.post(
            f"{api_base}/chat/completions",
            headers=headers,
            json=payload,
            timeout=30
        )
        response.raise_for_status()

        result = response.json()
        print("API response:", result)
        translated_text = result["choices"][0]["message"]["content"].strip()
        print(f"Translation: {translated_text}")
        return translated_text
    except requests.exceptions.RequestException as e:
        print(f"Translation failed: {e}")
        return None


if __name__ == "__main__":
    test_deepseek_api()
src/agent.py (deleted, 250 lines)

@@ -1,250 +0,0 @@
import polars as pl
import pandas as pd
from typing import List, Dict, Any, Optional
from pydantic import BaseModel, Field
from pydantic_ai import AI
from pydantic_ai.agent import Tool
import joblib
from pathlib import Path
from config import settings
from machine_learning import extract_features
from translation import translate_text


class Message(BaseModel):
    """SMS message model"""
    content: str = Field(..., description="Message text")
    is_english: bool = Field(default=True, description="Whether the message is in English")


class ClassificationResult(BaseModel):
    """Classification result model"""
    label: str = Field(..., description="Label, ham or spam")
    confidence: float = Field(..., description="Classification confidence")


class Explanation(BaseModel):
    """Explanation model"""
    key_words: List[str] = Field(..., description="Key feature words")
    reason: str = Field(..., description="Reason for the classification")
    suggestion: str = Field(..., description="Action suggestion")


class AnalysisResult(BaseModel):
    """Analysis result model"""
    message: str = Field(..., description="Original message")
    message_zh: str = Field(..., description="Chinese translation")
    classification: ClassificationResult = Field(..., description="Classification result")
    explanation: Explanation = Field(..., description="Explanation and suggestion")


class SpamClassifier:
    """Spam SMS classifier"""
    def __init__(self, model_name: str = "lightgbm"):
        """Initialize the classifier"""
        self.model_name = model_name
        self.model = None
        self.vectorizer = None
        self.load_model()

    def load_model(self):
        """Load the model and vectorizer"""
        model_dir = Path(settings.model_save_path)

        # Load the model
        model_path = model_dir / f"{self.model_name}_model.joblib"
        self.model = joblib.load(model_path)
        print(f"Model loaded from: {model_path}")

        # Load the vectorizer
        vectorizer_path = model_dir / f"{self.model_name}_vectorizer.joblib"
        self.vectorizer = joblib.load(vectorizer_path)
        print(f"Vectorizer loaded from: {vectorizer_path}")

    def classify(self, message: str) -> Dict[str, Any]:
        """Classify a single message"""
        # Vectorize the message
        message_vector = self.vectorizer.transform([message])

        # Predict the label and confidence
        label = self.model.predict(message_vector)[0]
        confidence = self.model.predict_proba(message_vector)[0][label]

        # Convert the label to text
        label_text = "spam" if label == 1 else "ham"

        return {
            "label": label_text,
            "confidence": confidence
        }


class SpamAnalysisTool(Tool):
    """Spam SMS analysis tool"""

    def __init__(self, classifier: SpamClassifier):
        super().__init__(name="spam_analysis_tool", description="Analyze whether a message is spam and provide an explanation and suggestion")
        self.classifier = classifier

    async def __call__(self, message: str, is_english: bool = True) -> AnalysisResult:
        """Analyze a message with the tool"""
        # Translate English messages into Chinese
        message_zh = translate_text(message, "zh-CN") if is_english else message

        # Classify the message
        classification = self.classifier.classify(message)

        # Generate the explanation and suggestion
        explanation = self.generate_explanation(message, classification["label"])

        return AnalysisResult(
            message=message,
            message_zh=message_zh,
            classification=ClassificationResult(
                label=classification["label"],
                confidence=classification["confidence"]
            ),
            explanation=explanation
        )

    def generate_explanation(self, message: str, label: str) -> Explanation:
        """Generate an explanation and suggestion"""
        # Simple keyword extraction (a real project could use a more sophisticated method)
        key_words = self.extract_keywords(message)

        # Generate the reason and suggestion
        if label == "spam":
            reason = f"The message contains spam feature words: {', '.join(key_words)}"
            suggestion = "Delete the message immediately; do not click any links or reply, to avoid being scammed"
        else:
            reason = f"The message is legitimate and contains common vocabulary: {', '.join(key_words)}"
            suggestion = "The message can be replied to and handled normally"

        return Explanation(
            key_words=key_words,
            reason=reason,
            suggestion=suggestion
        )

    def extract_keywords(self, message: str, top_n: int = 5) -> List[str]:
        """Extract keywords"""
        # Use the TF-IDF vectorizer's stop words to filter keywords
        words = message.lower().split()

        # Filter out stop words
        stop_words = set(self.vectorizer.get_stop_words()) if self.vectorizer.get_stop_words() else set()
        keywords = [word for word in words if word not in stop_words and len(word) > 2]

        # Return only the first top_n keywords
        return keywords[:top_n]


class ModelEvaluationTool(Tool):
    """Model evaluation tool"""

    def __init__(self, classifier: SpamClassifier):
        super().__init__(name="model_evaluation_tool", description="Evaluate model performance on a given dataset")
        self.classifier = classifier

    async def __call__(self, test_data: List[str], labels: List[str]) -> Dict[str, float]:
        """Evaluate model performance"""
        # Convert the data format
        test_series = pl.Series("message", test_data)

        # Extract features
        # Note: the vectorizer would need to be retrained or reused;
        # for simplicity, transform the data with the existing vectorizer
        test_vectors = self.classifier.vectorizer.transform(test_data)

        # Predict
        predictions = self.classifier.model.predict(test_vectors)
        predictions_proba = self.classifier.model.predict_proba(test_vectors)[:, 1]

        # Convert the labels to numbers
        y_true = [1 if label == "spam" else 0 for label in labels]

        # Compute evaluation metrics
        from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score

        metrics = {
            "accuracy": accuracy_score(y_true, predictions),
            "precision": precision_score(y_true, predictions),
            "recall": recall_score(y_true, predictions),
            "f1": f1_score(y_true, predictions),
            "roc_auc": roc_auc_score(y_true, predictions_proba)
        }

        return metrics


class SpamAnalysisAgent:
    """Spam SMS analysis agent"""

    def __init__(self, model_name: str = "lightgbm"):
        """Initialize the agent"""
        # Create the classifier
        self.classifier = SpamClassifier(model_name)

        # Create the tools
        self.tools = [
            SpamAnalysisTool(self.classifier),
            ModelEvaluationTool(self.classifier)
        ]

        # Create the AI instance
        self.ai = AI(
            model=settings.deepseek_model,
            api_key=settings.deepseek_api_key,
            api_base=settings.deepseek_api_base,
            tools=self.tools
        )

    async def analyze_message(self, message: str, is_english: bool = True) -> AnalysisResult:
        """Analyze a single message"""
        # Analyze the message with the AI tools
        result = await self.ai.run(
            f"Analyze the following message: {message}",
            output_model=AnalysisResult,
            max_tokens=1000,
            temperature=0.1
        )

        return result

    async def batch_analyze(self, messages: List[str], is_english: bool = True) -> List[AnalysisResult]:
        """Analyze messages in batch"""
        results = []
        for message in messages:
            result = await self.analyze_message(message, is_english)
            results.append(result)

        return results


async def main():
    """Agent entry point"""
    # Create the agent
    agent = SpamAnalysisAgent()

    # Test messages
    test_messages = [
        "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's",
        "Ok lar... Joking wif u oni...",
        "WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only."
    ]

    # Analyze the messages
    for message in test_messages:
        print(f"\n=== Analyzing message ===")
        print(f"Original: {message}")
        result = await agent.analyze_message(message)
        print(f"Label: {result.classification.label} (confidence: {result.classification.confidence:.2f})")
        print(f"Chinese translation: {result.message_zh}")
        print(f"Key words: {', '.join(result.explanation.key_words)}")
        print(f"Reason: {result.explanation.reason}")
        print(f"Suggestion: {result.explanation.suggestion}")


if __name__ == "__main__":
    import asyncio
    asyncio.run(main())
src/agent/__init__.py (new file, 34 lines)

from .agent import SpamClassificationAgent, AgentConfig, Message, agent, get_agent
from .tools import (
    PredictSpamInput,
    PredictSpamOutput,
    predict_spam_tool,
    ExplainPredictionInput,
    ExplainPredictionOutput,
    explain_prediction_tool,
    TranslateTextInput,
    TranslateTextOutput,
    translate_text_tool,
    TOOLS
)

__all__ = [
    # agent.py
    "SpamClassificationAgent",
    "AgentConfig",
    "Message",
    "agent",
    "get_agent",

    # tools.py
    "PredictSpamInput",
    "PredictSpamOutput",
    "predict_spam_tool",
    "ExplainPredictionInput",
    "ExplainPredictionOutput",
    "explain_prediction_tool",
    "TranslateTextInput",
    "TranslateTextOutput",
    "translate_text_tool",
    "TOOLS"
]
src/agent/agent.py (new file, 259 lines)

import os
from typing import List, Dict, Any, Optional
from pydantic import BaseModel, Field
from dotenv import load_dotenv
from openai import OpenAI
from .tools import TOOLS

# Load environment variables
load_dotenv()


class AgentConfig(BaseModel):
    """Agent configuration"""
    model: str = Field(default="deepseek-chat", description="LLM model to use")
    temperature: float = Field(default=0.7, description="Sampling temperature")
    max_tokens: int = Field(default=2000, description="Maximum number of generated tokens")


class Message(BaseModel):
    """Chat message"""
    role: str = Field(..., description="Message role: system, user, assistant, function")
    content: Optional[str] = Field(None, description="Message content")
    name: Optional[str] = Field(None, description="Function name, only used with the function role")
    function_call: Optional[Dict[str, Any]] = Field(None, description="Function call information")


class SpamClassificationAgent:
    """Spam SMS classification agent"""

    def __init__(self, config: Optional[AgentConfig] = None):
        """Initialize the agent"""
        if config is None:
            config = AgentConfig()

        self.config = config

        # Read the API key directly from the .env file
        env_path = os.path.join(os.path.dirname(__file__), "..", "..", ".env")
        env_vars = {}

        if os.path.exists(env_path):
            with open(env_path, "r") as f:
                for line in f:
                    line = line.strip()
                    if line and not line.startswith("#"):
                        key, value = line.split("=", 1)
                        env_vars[key.strip()] = value.strip()

        # Resolve the API key and base URL
        self.api_key = env_vars.get("DEEPSEEK_API_KEY") or os.getenv("DEEPSEEK_API_KEY")
        self.base_url = env_vars.get("DEEPSEEK_BASE_URL") or os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")

        # Create the client lazily, only when it is actually needed
        self.client = None

        # Register the tools
        self.tools = {tool["name"]: tool for tool in TOOLS}

        # System prompt
        self.system_prompt = """You are a professional spam SMS classification assistant. Your job is to:
1. Receive the SMS text provided by the user
2. Use the predict_spam tool to predict whether the message is spam
3. Use the explain_prediction tool to explain the classification and generate action suggestions
4. If the user provides an English message, optionally use the translate_text tool to translate it into Chinese
5. Return a clear, complete classification result, explanation, and suggestions

Follow these steps strictly:
- First classify with the predict_spam tool
- Then explain the result with the explain_prediction tool
- Finally return the complete result to the user

Keep your answers friendly and professional, and include all necessary information."""

    def _ensure_client(self):
        """Make sure the client is initialized"""
        if self.client is None:
            # The API key must be present
            if not self.api_key:
                raise ValueError("The DEEPSEEK_API_KEY environment variable is not set")

            # Create the client
            self.client = OpenAI(
                api_key=self.api_key,
                base_url=self.base_url
            )

    def _format_function_schema(self, tool: Dict[str, Any]) -> Dict[str, Any]:
        """Format a tool's function schema"""
        input_schema = tool["input_schema"].model_json_schema()

        # Drop fields that are not needed
        if "$defs" in input_schema:
            del input_schema["$defs"]

        return {
            "type": "function",
            "function": {
                "name": tool["name"],
                "description": tool["description"],
                "parameters": input_schema
            }
        }

    def _call_function(self, function_name: str, arguments: str) -> Any:
        """Call a tool function"""
        import json

        if function_name not in self.tools:
            return f"Error: unknown function {function_name}"

        try:
            # Parse the arguments
            args = json.loads(arguments)

            # Look up the tool
            tool = self.tools[function_name]

            # Build the input object
            input_obj = tool["input_schema"](**args)

            # Call the function
            result = tool["func"](input_obj)

            # Return the result
            return result.model_dump_json()
        except json.JSONDecodeError as e:
            return f"Error: failed to parse arguments - {e}"
        except Exception as e:
            return f"Error: function call failed - {e}"

    def chat(self, user_message: str) -> str:
        """Chat with the agent"""
        # Initialize the message list
        messages = [
            {
                "role": "system",
                "content": self.system_prompt
            },
            {
                "role": "user",
                "content": user_message
            }
        ]

        # Build the function schemas
        function_schemas = [self._format_function_schema(tool) for tool in TOOLS]

        # Make sure the client is initialized
        self._ensure_client()

        # First round: let the model emit tool calls
        response = self.client.chat.completions.create(
            model=self.config.model,
            messages=messages,
            temperature=self.config.temperature,
            max_tokens=self.config.max_tokens,
            tools=function_schemas,
            tool_choice="auto"
        )

        # Get the response message
        response_message = response.choices[0].message
        messages.append(response_message.model_dump(exclude_none=True))

        # If the model requested tool calls
        if response_message.tool_calls:
            for tool_call in response_message.tool_calls:
                function_name = tool_call.function.name
                function_args = tool_call.function.arguments

                # Call the function
                function_result = self._call_function(function_name, function_args)

                # Append the tool result to the message list
                # (the tool-calls API expects role "tool" with the matching tool_call_id)
                messages.append({
                    "role": "tool",
                    "tool_call_id": tool_call.id,
                    "content": function_result
                })

            # Second round: generate the final response
            final_response = self.client.chat.completions.create(
                model=self.config.model,
                messages=messages,
                temperature=self.config.temperature,
                max_tokens=self.config.max_tokens
            )

            return final_response.choices[0].message.content

        # No tool call was needed; return the response directly
        return response_message.content

    def classify_and_explain(self, text: str) -> Dict[str, Any]:
        """Classify a given message directly and explain the result"""
        # Use the predict_spam tool
        from .tools import PredictSpamInput, predict_spam_tool
        predict_input = PredictSpamInput(text=text)
        predict_result = predict_spam_tool(predict_input)

        # Use the explain_prediction tool
        from .tools import ExplainPredictionInput, explain_prediction_tool
        explain_input = ExplainPredictionInput(
            text=text,
            label=predict_result.label,
            probability=predict_result.probability[predict_result.label]
        )
        explain_result = explain_prediction_tool(explain_input)

        # Merge the results
        return {
            "classification": predict_result.model_dump(),
            "explanation": explain_result.model_dump()
        }


# Global agent instance
agent = SpamClassificationAgent()


def get_agent() -> SpamClassificationAgent:
    """Get the agent instance"""
    return agent


def main():
    """Entry point for manual testing"""
    # Test examples
    test_messages = [
        "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's",
        "Ok lar... Joking wif u oni...",
        "WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only."
    ]

    for i, message in enumerate(test_messages):
        print(f"\n=== Test example {i+1} ===")
        print(f"User input: {message}")

        # Run the agent
        result = agent.classify_and_explain(message)

        print(f"\nClassification:")
        print(f"- Label: {result['classification']['label']}")
        print(f"- Probability: {result['classification']['probability']}")

        print(f"\nExplanation:")
        print(f"- Content summary: {result['explanation']['content_summary']}")
        print(f"- Reason: {result['explanation']['classification_reason']}")
        print(f"- Confidence level: {result['explanation']['confidence_level']}")
        print(f"- Confidence explanation: {result['explanation']['confidence_explanation']}")

        print(f"\nSuggestions:")
        for j, suggestion in enumerate(result['explanation']['suggestions']):
            print(f"  {j+1}. {suggestion}")


if __name__ == "__main__":
    main()
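A minimal usage sketch of the chat() tool-calling loop defined in src/agent/agent.py above. It assumes a valid DEEPSEEK_API_KEY in the environment and trained models under models/; the sample message is illustrative:

```python
# Sketch: exercising the agent's tool-calling chat loop.
from src.agent import get_agent

agent = get_agent()
reply = agent.chat("Is this spam? 'WINNER!! Claim your £900 prize, call 09061701461.'")
print(reply)
```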
src/agent/tools.py (new file, 93 lines)

from pydantic import BaseModel, Field
from typing import Dict, Any, List
from ..models import predict_spam
from ..llm import llm_service


class PredictSpamInput(BaseModel):
    """Input for the spam prediction tool"""
    text: str = Field(..., description="SMS text to classify")
    model_name: str = Field(default="lightgbm", description="Name of the model to use")


class PredictSpamOutput(BaseModel):
    """Output of the spam prediction tool"""
    original_text: str = Field(..., description="Original SMS text")
    cleaned_text: str = Field(..., description="Cleaned SMS text")
    label: str = Field(..., description="Label: ham for legitimate messages, spam for junk messages")
    label_num: int = Field(..., description="Numeric label: 0 for ham, 1 for spam")
    probability: Dict[str, float] = Field(..., description="Class probabilities")


def predict_spam_tool(input_data: PredictSpamInput) -> PredictSpamOutput:
    """Predict with the machine learning model whether a message is spam"""
    result = predict_spam(input_data.text, input_data.model_name)
    return PredictSpamOutput(**result)


class ExplainPredictionInput(BaseModel):
    """Input for the prediction explanation tool"""
    text: str = Field(..., description="SMS text")
    label: str = Field(..., description="Predicted label")
    probability: float = Field(..., description="Predicted probability")


class ExplainPredictionOutput(BaseModel):
    """Output of the prediction explanation tool"""
    content_summary: str = Field(..., description="Summary of the message content")
    classification_reason: str = Field(..., description="Reason for the classification")
    confidence_level: str = Field(..., description="Confidence level: high, medium, or low")
    confidence_explanation: str = Field(..., description="Explanation of the confidence level")
    suggestions: List[str] = Field(..., description="Suggestions for handling the message")


def explain_prediction_tool(input_data: ExplainPredictionInput) -> ExplainPredictionOutput:
    """Explain the classification with the LLM and generate action suggestions"""
    result = llm_service.explain_prediction(input_data.text, input_data.label, input_data.probability)
    return ExplainPredictionOutput(**result)


class TranslateTextInput(BaseModel):
    """Input for the text translation tool"""
    text: str = Field(..., description="Text to translate")
    target_lang: str = Field(default="zh-CN", description="Target language")


class TranslateTextOutput(BaseModel):
    """Output of the text translation tool"""
    translated_text: str = Field(..., description="Translated text")


def translate_text_tool(input_data: TranslateTextInput) -> TranslateTextOutput:
    """Translate text into the target language"""
    translated_text = llm_service.translate_text(input_data.text, input_data.target_lang)
    return TranslateTextOutput(translated_text=translated_text)


# Tool registry
TOOLS = [
    {
        "name": "predict_spam",
        "description": "Predict with the machine learning model whether an SMS message is spam",
        "input_schema": PredictSpamInput,
        "output_schema": PredictSpamOutput,
        "func": predict_spam_tool
    },
    {
        "name": "explain_prediction",
        "description": "Explain the classification with the LLM and generate action suggestions",
        "input_schema": ExplainPredictionInput,
        "output_schema": ExplainPredictionOutput,
        "func": explain_prediction_tool
    },
    {
        "name": "translate_text",
        "description": "Translate text into the target language",
        "input_schema": TranslateTextInput,
        "output_schema": TranslateTextOutput,
        "func": translate_text_tool
    }
]
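A minimal sketch of dispatching a registered tool by name from the TOOLS list above, mirroring what SpamClassificationAgent._call_function does. The argument JSON is illustrative, and a trained model under models/ is assumed:

```python
# Sketch: validating arguments with the tool's pydantic schema, then calling it.
import json

from src.agent.tools import TOOLS

registry = {tool["name"]: tool for tool in TOOLS}

raw_args = '{"text": "WINNER!! Claim your prize now", "model_name": "lightgbm"}'
tool = registry["predict_spam"]

input_obj = tool["input_schema"](**json.loads(raw_args))  # pydantic validation
output = tool["func"](input_obj)                          # returns PredictSpamOutput
print(output.model_dump_json())
```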
Deleted file (29 lines) — project settings module

@@ -1,29 +0,0 @@
from pydantic_settings import BaseSettings
from typing import Optional


class Settings(BaseSettings):
    """Project settings"""
    # DeepSeek API configuration
    deepseek_api_key: str

    # Project paths
    model_save_path: str = "./models"
    data_path: str = "./data"

    # Model configuration
    random_state: int = 42
    test_size: float = 0.2

    # DeepSeek API configuration
    deepseek_api_base: str = "https://api.deepseek.com"
    deepseek_model: str = "deepseek-chat"

    class Config:
        import os
        env_file = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env")
        env_file_encoding = "utf-8"


# Global settings instance
settings = Settings()
src/data/__init__.py (new file, 18 lines)

from .preprocess import load_data, clean_text, preprocess_data, split_data
from .validation import (
    raw_data_schema,
    processed_data_schema,
    validate_raw_data,
    validate_processed_data
)

__all__ = [
    "load_data",
    "clean_text",
    "preprocess_data",
    "split_data",
    "raw_data_schema",
    "processed_data_schema",
    "validate_raw_data",
    "validate_processed_data"
]
src/data/preprocess.py (new file, 87 lines)

import polars as pl
import re
from typing import Optional


def load_data(file_path: str) -> pl.DataFrame:
    """Load the data file"""
    # Read the CSV, keep only the first two columns, and handle encoding issues
    df = pl.read_csv(
        file_path,
        columns=["v1", "v2"],
        new_columns=["label", "text"],
        encoding="ISO-8859-1"  # ISO-8859-1 handles the invalid UTF-8 sequences
    )
    return df


def clean_text(text: str) -> str:
    """Clean text data"""
    # Strip HTML entities
    text = re.sub(r'&[a-zA-Z]+;', ' ', text)

    # Strip special characters; keep letters, digits, and basic punctuation
    text = re.sub(r'[^a-zA-Z0-9\s.,!?]', ' ', text)

    # Collapse extra whitespace
    text = re.sub(r'\s+', ' ', text).strip()

    return text


def preprocess_data(df: pl.DataFrame) -> pl.DataFrame:
    """Preprocess the data"""
    # Clean the text
    df = df.with_columns(
        pl.col("text").map_elements(clean_text, return_dtype=pl.String).alias("clean_text")
    )

    # Convert the labels
    df = df.with_columns(
        pl.when(pl.col("label") == "spam")
        .then(1)
        .otherwise(0)
        .alias("label_num")
    )

    # Drop empty texts
    df = df.filter(pl.col("clean_text") != "")

    return df


def split_data(df: pl.DataFrame, test_size: float = 0.2, random_state: int = 42) -> tuple[pl.DataFrame, pl.DataFrame]:
    """Split the data into train and test sets"""
    # Shuffle the data
    df = df.sample(fraction=1, seed=random_state)

    # Compute the split point
    split_idx = int(len(df) * (1 - test_size))

    # Split the data
    train_df = df[:split_idx]
    test_df = df[split_idx:]

    return train_df, test_df


def main():
    """Entry point for manual testing"""
    # Load the data
    df = load_data("../data/spam.csv")
    print(f"Raw data shape: {df.shape}")
    print(f"Label distribution: {df['label'].value_counts()}")

    # Preprocess the data
    processed_df = preprocess_data(df)
    print(f"Preprocessed data shape: {processed_df.shape}")
    print(f"Preprocessed label distribution: {processed_df['label_num'].value_counts()}")

    # Split into train and test sets
    train_df, test_df = split_data(processed_df)
    print(f"Train set shape: {train_df.shape}")
    print(f"Test set shape: {test_df.shape}")


if __name__ == "__main__":
    main()
src/data/validation.py (new file, 54 lines)

import pandera as pa
from pandera import Column, DataFrameSchema


# Schema definitions
raw_data_schema = DataFrameSchema({
    "label": Column(str, checks=pa.Check.isin(["ham", "spam"])),
    "text": Column(str, checks=pa.Check(lambda x: x.str.len() > 0))
})

processed_data_schema = DataFrameSchema({
    "label": Column(str, checks=pa.Check.isin(["ham", "spam"])),
    "text": Column(str, checks=pa.Check(lambda x: x.str.len() > 0)),
    "clean_text": Column(str, checks=pa.Check(lambda x: x.str.len() > 0)),
    "label_num": Column(int, checks=pa.Check.isin([0, 1]))
})


def validate_raw_data(df) -> bool:
    """Validate the raw data"""
    try:
        raw_data_schema.validate(df)
        return True
    except pa.errors.SchemaError as e:
        print(f"Raw data validation failed: {e}")
        return False


def validate_processed_data(df) -> bool:
    """Validate the preprocessed data"""
    try:
        processed_data_schema.validate(df)
        return True
    except pa.errors.SchemaError as e:
        print(f"Preprocessed data validation failed: {e}")
        return False


def main():
    """Entry point for manual testing"""
    import polars as pl
    from .preprocess import load_data, preprocess_data

    # Load the data
    df = load_data("../data/spam.csv")
    print(f"Raw data valid: {validate_raw_data(df.to_pandas())}")

    # Preprocess the data
    processed_df = preprocess_data(df)
    print(f"Preprocessed data valid: {validate_processed_data(processed_df.to_pandas())}")


if __name__ == "__main__":
    main()
Deleted file (76 lines) — old data processing module

@@ -1,76 +0,0 @@
import polars as pl
import pandas as pd
from pathlib import Path
from typing import Tuple


def load_data(file_path: str) -> pl.DataFrame:
    """Load the dataset with Polars"""
    # Load the CSV file and handle encoding issues
    df = pl.read_csv(
        file_path,
        encoding="latin-1",
        ignore_errors=True,
        has_header=True
    )
    return df


def clean_data(df: pl.DataFrame) -> pl.DataFrame:
    """Clean the dataset"""
    # Inspect basic dataset information
    print("Raw dataset shape:", df.shape)
    print("Raw dataset columns:", df.columns)

    # Drop unnecessary columns (the last three are empty)
    df = df.drop(df.columns[-3:])

    # Rename the columns
    df = df.rename({
        "v1": "label",
        "v2": "message"
    })

    # Inspect the cleaned dataset
    print("Cleaned dataset shape:", df.shape)
    print("Cleaned dataset columns:", df.columns)
    print("Label distribution:", df["label"].value_counts())

    return df


def preprocess_data(df: pl.DataFrame) -> Tuple[pl.DataFrame, pl.Series]:
    """Preprocess the data for model training"""
    # Convert the labels to numbers (ham=0, spam=1)
    df = df.with_columns(
        pl.when(pl.col("label") == "spam").then(1).otherwise(0).alias("label")
    )

    # Separate features and labels
    X = df.drop("label")
    y = df["label"]

    return X, y


def save_data(df: pl.DataFrame, file_path: str) -> None:
    """Save the processed dataset"""
    df.write_csv(file_path, index=False)
    print(f"Dataset saved to: {file_path}")


if __name__ == "__main__":
    # Exercise the data processing pipeline
    file_path = "../spam.csv"
    # Check whether the file exists
    import os
    if not os.path.exists(file_path):
        file_path = "./spam.csv"
    df = load_data(file_path)
    df_cleaned = clean_data(df)
    X, y = preprocess_data(df_cleaned)

    print("Feature data shape:", X.shape)
    print("Label data shape:", y.shape)
    print("First 5 rows:")
    print(df_cleaned.head())
src/llm/__init__.py (new file, 7 lines)

from .llm_service import LLMService, llm_service, get_llm_service

__all__ = [
    "LLMService",
    "llm_service",
    "get_llm_service"
]
235
src/llm/llm_service.py
Normal file
235
src/llm/llm_service.py
Normal file
@ -0,0 +1,235 @@
|
|||||||
|
import os
|
||||||
|
from openai import OpenAI
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from typing import Dict, Any, Optional
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
# 加载环境变量
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
class LLMService:
|
||||||
|
"""LLM服务类,用于调用DeepSeek API"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化LLM服务"""
|
||||||
|
# 直接从.env文件读取API密钥
|
||||||
|
env_path = os.path.join(os.path.dirname(__file__), "..", "..", ".env")
|
||||||
|
env_vars = {}
|
||||||
|
|
||||||
|
if os.path.exists(env_path):
|
||||||
|
with open(env_path, "r") as f:
|
||||||
|
for line in f:
|
||||||
|
line = line.strip()
|
||||||
|
if line and not line.startswith("#"):
|
||||||
|
key, value = line.split("=", 1)
|
||||||
|
env_vars[key.strip()] = value.strip()
|
||||||
|
|
||||||
|
# 获取API密钥和基础URL
|
||||||
|
self.api_key = env_vars.get("DEEPSEEK_API_KEY") or os.getenv("DEEPSEEK_API_KEY")
|
||||||
|
self.base_url = env_vars.get("DEEPSEEK_BASE_URL") or os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
|
||||||
|
|
||||||
|
# 默认模型
|
||||||
|
self.default_model = "deepseek-chat"
|
||||||
|
|
||||||
|
# 延迟创建客户端,直到实际需要时
|
||||||
|
self.client = None
|
||||||
|
|
||||||
|
def _ensure_client(self):
|
||||||
|
"""确保客户端已初始化"""
|
||||||
|
if self.client is None:
|
||||||
|
# 确保API密钥存在
|
||||||
|
if not self.api_key:
|
||||||
|
raise ValueError("DEEPSEEK_API_KEY环境变量未设置")
|
||||||
|
|
||||||
|
# 创建客户端
|
||||||
|
self.client = OpenAI(
|
||||||
|
api_key=self.api_key,
|
||||||
|
base_url=self.base_url
|
||||||
|
)
|
||||||
|
|
||||||
|
def _call_api(
|
||||||
|
self,
|
||||||
|
messages: list,
|
||||||
|
model: str = None,
|
||||||
|
temperature: float = 0.7,
|
||||||
|
max_tokens: int = 1000,
|
||||||
|
response_format: Optional[Dict[str, Any]] = None
|
||||||
|
) -> str:
|
||||||
|
"""调用LLM API"""
|
||||||
|
if model is None:
|
||||||
|
model = self.default_model
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 确保客户端已初始化
|
||||||
|
self._ensure_client()
|
||||||
|
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
temperature=temperature,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
response_format=response_format
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content
|
||||||
|
except Exception as e:
|
||||||
|
import json
|
||||||
|
import traceback
|
||||||
|
print(f"LLM API调用失败: {e}")
|
||||||
|
print(f"错误类型: {type(e).__name__}")
|
||||||
|
# 尝试获取更详细的错误信息
|
||||||
|
if hasattr(e, 'response') and e.response:
|
||||||
|
try:
|
||||||
|
error_details = e.response.json()
|
||||||
|
print(f"详细错误信息: {json.dumps(error_details, ensure_ascii=False, indent=2)}")
|
||||||
|
except:
|
||||||
|
print(f"原始响应内容: {e.response.text}")
|
||||||
|
traceback.print_exc()
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def translate_text(self, text: str, target_lang: str = "zh-CN") -> str:
|
||||||
|
"""将文本翻译成目标语言"""
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": f"你是一个专业的翻译助手,请将给定的文本翻译成{target_lang}。保持原文的意思,翻译要准确、自然。"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": text
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
return self._call_api(messages, temperature=0.3)
|
||||||
|
|
||||||
|
def explain_prediction(self, text: str, label: str, probability: float) -> Dict[str, Any]:
|
||||||
|
"""解释模型预测结果"""
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "你是一个专业的短信分类解释专家。请根据给定的短信内容、分类结果和概率,生成一个清晰、简洁的解释。解释应该包括:\n1. 短信的主要内容\n2. 为什么被分类为垃圾短信或正常短信\n3. 分类的可信度\n\n请使用结构化的JSON格式输出,包含以下字段:\n- content_summary: 短信内容摘要\n- classification_reason: 分类原因\n- confidence_level: 可信度级别(高、中、低)\n- confidence_explanation: 可信度解释\n- suggestions: 针对该短信的建议"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"短信内容: {text}\n分类结果: {label}\n分类概率: {probability:.2f}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
response = self._call_api(
|
||||||
|
messages,
|
||||||
|
temperature=0.5,
|
||||||
|
response_format={"type": "json_object"}
|
||||||
|
)
|
||||||
|
|
||||||
|
import json
|
||||||
|
try:
|
||||||
|
result = json.loads(response)
|
||||||
|
|
||||||
|
# 确保suggestions字段是一个列表
|
||||||
|
if not isinstance(result.get("suggestions"), list):
|
||||||
|
suggestions_text = result.get("suggestions", "")
|
||||||
|
# 如果是字符串,尝试分割成列表
|
||||||
|
if isinstance(suggestions_text, str):
|
||||||
|
# 移除可能的前缀
|
||||||
|
suggestions_text = suggestions_text.replace("建议用户:", "")
|
||||||
|
suggestions_text = suggestions_text.replace("建议:", "")
|
||||||
|
# 按序号分割
|
||||||
|
import re
|
||||||
|
# 匹配数字. 开头的模式
|
||||||
|
suggestions = re.split(r'\d+\.\s*', suggestions_text)
|
||||||
|
# 过滤掉空字符串
|
||||||
|
suggestions = [s.strip() for s in suggestions if s.strip()]
|
||||||
|
result["suggestions"] = suggestions
|
||||||
|
else:
|
||||||
|
# 如果是其他类型,设置为空列表
|
||||||
|
result["suggestions"] = []
|
||||||
|
|
||||||
|
return result
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return {
|
||||||
|
"content_summary": "",
|
||||||
|
"classification_reason": "",
|
||||||
|
"confidence_level": "中",
|
||||||
|
"confidence_explanation": "",
|
||||||
|
"suggestions": []
|
||||||
|
}
|
||||||
|
|
||||||
|
def generate_advice(self, text: str, label: str) -> str:
|
||||||
|
"""生成行动建议"""
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "你是一个专业的短信管理顾问。请根据给定的短信内容和分类结果,生成具体、实用的行动建议。建议要简洁明了,针对性强。"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"短信内容: {text}\n分类结果: {label}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
return self._call_api(messages, temperature=0.5)
|
||||||
|
|
||||||
|
def analyze_spam_patterns(self, spam_texts: list) -> str:
|
||||||
|
"""分析垃圾短信的模式"""
|
||||||
|
if len(spam_texts) == 0:
|
||||||
|
return "没有提供垃圾短信样本"
|
||||||
|
|
||||||
|
# 限制短信数量,避免超过API限制
|
||||||
|
sample_texts = spam_texts[:5]
|
||||||
|
texts_str = "\n".join([f"{i+1}. {text}" for i, text in enumerate(sample_texts)])
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{
|
||||||
|
"role": "system",
|
||||||
|
"content": "你是一个垃圾短信模式分析专家。请分析给定的垃圾短信样本,总结出常见的模式和特征。分析要全面、准确,包括但不限于:\n1. 内容特征\n2. 语言风格\n3. 发送目的\n4. 常见关键词\n\n请使用简洁明了的语言输出分析结果。"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"垃圾短信样本:\n{texts_str}"
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
return self._call_api(messages, temperature=0.5)
|
||||||
|
|
||||||
|
|
||||||
|
# 创建全局LLM服务实例
|
||||||
|
llm_service = LLMService()
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm_service() -> LLMService:
|
||||||
|
"""获取LLM服务实例"""
|
||||||
|
return llm_service
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""主函数,用于测试"""
|
||||||
|
# 测试翻译功能
|
||||||
|
test_text = "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's"
|
||||||
|
print("翻译测试:")
|
||||||
|
print(f"原文: {test_text}")
|
||||||
|
translation = llm_service.translate_text(test_text)
|
||||||
|
print(f"译文: {translation}")
|
||||||
|
|
||||||
|
# 测试解释功能
|
||||||
|
print("\n解释测试:")
|
||||||
|
explanation = llm_service.explain_prediction(test_text, "spam", 0.95)
|
||||||
|
print(f"解释结果: {explanation}")
|
||||||
|
|
||||||
|
# 测试建议功能
|
||||||
|
print("\n建议测试:")
|
||||||
|
advice = llm_service.generate_advice(test_text, "spam")
|
||||||
|
print(f"建议: {advice}")
|
||||||
|
|
||||||
|
# 测试模式分析功能
|
||||||
|
print("\n模式分析测试:")
|
||||||
|
spam_samples = [
|
||||||
|
"Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005.",
|
||||||
|
"WINNER!! As a valued network customer you have been selected to receivea <20>900 prize reward!",
|
||||||
|
"Had your mobile 11 months or more? U R entitled to Update to the latest colour mobiles with camera for Free!"
|
||||||
|
]
|
||||||
|
pattern_analysis = llm_service.analyze_spam_patterns(spam_samples)
|
||||||
|
print(f"模式分析: {pattern_analysis}")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
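A minimal usage sketch of the service above (assuming the module path `src/llm/llm_service.py` from the project layout and a valid `DEEPSEEK_API_KEY` configured via `.env`):

```python
from src.llm.llm_service import get_llm_service

svc = get_llm_service()

# explain_prediction returns a dict with content_summary, classification_reason,
# confidence_level, confidence_explanation, and a suggestions list (normalized above)
explanation = svc.explain_prediction(
    "WINNER!! You have been selected to receive a prize reward!", "spam", 0.97
)
print(explanation["classification_reason"])
for tip in explanation["suggestions"]:
    print("-", tip)

# generate_advice returns free-form advice text for a classified message
print(svc.generate_advice("Ok lar... Joking wif u oni...", "ham"))
```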
@@ -1,316 +0,0 @@
import polars as pl
import pandas as pd
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
import lightgbm as lgb
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, classification_report, confusion_matrix
)
import joblib
from pathlib import Path
from typing import Tuple, Dict, Any, Optional
from config import settings


class SpamClassifier:
    """Spam SMS classifier."""

    def __init__(self, model_name: str = "lightgbm"):
        """Initialize the classifier."""
        self.model_name = model_name
        self.model = None
        self.vectorizer = None
        self.load_model()

    def load_model(self):
        """Load the trained model and vectorizer from disk."""
        model_dir = Path(settings.model_save_path)

        # Load the model
        model_path = model_dir / f"{self.model_name}_model.joblib"
        self.model = joblib.load(model_path)
        print(f"模型已从: {model_path} 加载")

        # Load the vectorizer
        vectorizer_path = model_dir / f"{self.model_name}_vectorizer.joblib"
        self.vectorizer = joblib.load(vectorizer_path)
        print(f"向量器已从: {vectorizer_path} 加载")

    def classify(self, message: str) -> Dict[str, Any]:
        """Classify a single message."""
        # Vectorize the message
        message_vector = self.vectorizer.transform([message])

        # Predict the label and its confidence
        label = self.model.predict(message_vector)[0]
        confidence = self.model.predict_proba(message_vector)[0][label]

        # Map the numeric label to text
        label_text = "spam" if label == 1 else "ham"

        return {
            "label": label_text,
            "confidence": confidence
        }


def extract_features(
    X_train: pl.Series,
    X_test: pl.Series,
    max_features: int = 1000
) -> Tuple[Any, Any, TfidfVectorizer]:
    """
    Extract text features with TF-IDF.

    Args:
        X_train: training-set texts
        X_test: test-set texts
        max_features: maximum number of features

    Returns:
        Training features, test features, and the fitted TF-IDF vectorizer.
    """
    # Convert the Polars Series to Pandas Series
    X_train_pd = X_train.to_pandas()
    X_test_pd = X_test.to_pandas()

    # Initialize the TF-IDF vectorizer
    tfidf = TfidfVectorizer(
        max_features=max_features,
        stop_words="english",
        ngram_range=(1, 2)
    )

    # Fit on the training set and transform it
    X_train_tfidf = tfidf.fit_transform(X_train_pd)

    # Transform the test set
    X_test_tfidf = tfidf.transform(X_test_pd)

    return X_train_tfidf, X_test_tfidf, tfidf


def train_logistic_regression(
    X_train: Any,
    y_train: pl.Series
) -> LogisticRegression:
    """
    Train a Logistic Regression model.

    Args:
        X_train: training-set features
        y_train: training-set labels

    Returns:
        The trained Logistic Regression model.
    """
    # Convert the Polars Series to a Pandas Series
    y_train_pd = y_train.to_pandas()

    # Initialize the Logistic Regression model
    log_reg = LogisticRegression(
        random_state=settings.random_state,
        max_iter=1000,
        class_weight="balanced"
    )

    # Train the model
    log_reg.fit(X_train, y_train_pd)

    return log_reg


def train_lightgbm(
    X_train: Any,
    y_train: pl.Series
) -> lgb.LGBMClassifier:
    """
    Train a LightGBM model.

    Args:
        X_train: training-set features
        y_train: training-set labels

    Returns:
        The trained LightGBM model.
    """
    # Convert the Polars Series to a Pandas Series
    y_train_pd = y_train.to_pandas()

    # Initialize the LightGBM model
    lgb_clf = lgb.LGBMClassifier(
        random_state=settings.random_state,
        class_weight="balanced",
        n_estimators=1000,
        learning_rate=0.1,
        num_leaves=31
    )

    # Train the model
    lgb_clf.fit(X_train, y_train_pd)

    return lgb_clf


def evaluate_model(
    model: Any,
    X_test: Any,
    y_test: pl.Series
) -> Dict[str, float]:
    """
    Evaluate model performance.

    Args:
        model: the trained model
        X_test: test-set features
        y_test: test-set labels

    Returns:
        A dict of evaluation metrics.
    """
    # Convert the Polars Series to a Pandas Series
    y_test_pd = y_test.to_pandas()

    # Predict
    y_pred = model.predict(X_test)
    y_pred_proba = model.predict_proba(X_test)[:, 1] if hasattr(model, 'predict_proba') else None

    # Compute the evaluation metrics
    metrics = {
        "accuracy": accuracy_score(y_test_pd, y_pred),
        "precision": precision_score(y_test_pd, y_pred),
        "recall": recall_score(y_test_pd, y_pred),
        "f1": f1_score(y_test_pd, y_pred)
    }

    # Compute ROC-AUC if the model supports probability predictions
    if y_pred_proba is not None:
        metrics["roc_auc"] = roc_auc_score(y_test_pd, y_pred_proba)

    # Print the classification report and confusion matrix
    print("分类报告:")
    print(classification_report(y_test_pd, y_pred))

    print("混淆矩阵:")
    print(confusion_matrix(y_test_pd, y_pred))

    return metrics


def save_model(
    model: Any,
    model_name: str,
    vectorizer: Any = None
) -> None:
    """
    Save a model and its vectorizer.

    Args:
        model: the trained model
        model_name: model name
        vectorizer: the TF-IDF vectorizer
    """
    # Create the model directory
    model_dir = Path(settings.model_save_path)
    model_dir.mkdir(exist_ok=True)

    # Save the model
    model_path = model_dir / f"{model_name}_model.joblib"
    joblib.dump(model, model_path)
    print(f"模型已保存到: {model_path}")

    # Save the vectorizer, if provided
    if vectorizer is not None:
        vectorizer_path = model_dir / f"{model_name}_vectorizer.joblib"
        joblib.dump(vectorizer, vectorizer_path)
        print(f"向量器已保存到: {vectorizer_path}")


def load_model(
    model_name: str
) -> Tuple[Any, Any]:
    """
    Load a model and its vectorizer.

    Args:
        model_name: model name

    Returns:
        The loaded model and vectorizer.
    """
    # Locate the model directory
    model_dir = Path(settings.model_save_path)

    # Load the model
    model_path = model_dir / f"{model_name}_model.joblib"
    model = joblib.load(model_path)
    print(f"模型已从: {model_path} 加载")

    # Load the vectorizer
    vectorizer_path = model_dir / f"{model_name}_vectorizer.joblib"
    vectorizer = joblib.load(vectorizer_path)
    print(f"向量器已从: {vectorizer_path} 加载")

    return model, vectorizer


def main():
    """Machine-learning pipeline entry point."""
    # 1. Load the dataset
    print("正在加载数据集...")
    df = pl.read_csv("../spam.csv", encoding="latin-1", ignore_errors=True)

    # 2. Clean the dataset
    print("正在清洗数据集...")
    df = df.drop(df.columns[-3:])
    df = df.rename({"v1": "label", "v2": "message"})
    df = df.with_columns(
        pl.when(pl.col("label") == "spam").then(1).otherwise(0).alias("label")
    )

    # 3. Separate features and labels
    X = df["message"]
    y = df["label"]

    # 4. Split into training and test sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=settings.test_size, random_state=settings.random_state, stratify=y
    )

    print(f"训练集大小: {len(X_train)}")
    print(f"测试集大小: {len(X_test)}")

    # 5. Feature extraction
    print("正在提取特征...")
    X_train_tfidf, X_test_tfidf, tfidf = extract_features(X_train, X_test)

    # 6. Train the Logistic Regression model
    print("\n正在训练Logistic Regression模型...")
    log_reg_model = train_logistic_regression(X_train_tfidf, y_train)

    # 7. Evaluate the Logistic Regression model
    print("\n评估Logistic Regression模型:")
    log_reg_metrics = evaluate_model(log_reg_model, X_test_tfidf, y_test)
    print(f"Logistic Regression指标: {log_reg_metrics}")

    # 8. Train the LightGBM model
    print("\n正在训练LightGBM模型...")
    lgb_model = train_lightgbm(X_train_tfidf, y_train)

    # 9. Evaluate the LightGBM model
    print("\n评估LightGBM模型:")
    lgb_metrics = evaluate_model(lgb_model, X_test_tfidf, y_test)
    print(f"LightGBM指标: {lgb_metrics}")

    # 10. Save the models
    print("\n正在保存模型...")
    save_model(log_reg_model, "logistic_regression", tfidf)
    save_model(lgb_model, "lightgbm", tfidf)

    print("\n机器学习流程完成!")


if __name__ == "__main__":
    main()
24 src/main.py
@@ -1,24 +0,0 @@
from data_processing import load_data, clean_data, save_data
from translation import translate_dataset


def main():
    """Pipeline entry point."""
    # 1. Load the dataset
    print("正在加载数据集...")
    df = load_data("../spam.csv")

    # 2. Clean the dataset
    print("\n正在清洗数据集...")
    df_cleaned = clean_data(df)

    # 3. Translate only the first 10 messages as a smoke test
    print("\n正在翻译前10条短信进行测试...")
    df_test = df_cleaned.head(10)
    translated_path = translate_dataset(df_test)

    print(f"\n测试完成!翻译后的测试数据集已保存到: {translated_path}")


if __name__ == "__main__":
    main()
@@ -1,150 +0,0 @@
import requests
from typing import List, Dict, Any
from config import settings
from machine_learning import SpamClassifier
from translation import translate_text


class SimpleSpamAnalysis:
    """A simple spam-message analysis system."""

    def __init__(self, model_name: str = "lightgbm"):
        """Initialize the analysis system."""
        self.classifier = SpamClassifier(model_name)

    def analyze(self, message: str, is_english: bool = True) -> Dict[str, Any]:
        """Analyze a single message."""
        # 1. Translate the message
        message_zh = translate_text(message, "zh-CN") if is_english else message

        # 2. Classify the message
        classification = self.classifier.classify(message)

        # 3. Extract keywords
        key_words = self.extract_keywords(message)

        # 4. Generate a basic explanation and suggestion
        reason, suggestion = self.generate_explanation(key_words, classification["label"])

        # 5. Generate a more detailed explanation via the DeepSeek API
        detailed_explanation = self.generate_detailed_explanation(
            message, message_zh, classification["label"], key_words
        )

        return {
            "original_message": message,
            "translated_message": message_zh,
            "classification": classification,
            "key_words": key_words,
            "reason": reason,
            "suggestion": suggestion,
            "detailed_explanation": detailed_explanation
        }

    def extract_keywords(self, message: str, top_n: int = 5) -> List[str]:
        """Extract keywords."""
        words = message.lower().split()
        stop_words = set([
            "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
            "with", "by", "from", "up", "down", "about", "above", "below", "of",
            "is", "are", "was", "were", "be", "been", "being", "have", "has",
            "had", "do", "does", "did", "will", "would", "shall", "should",
            "may", "might", "must", "can", "could", "not", "no", "yes", "if",
            "then", "than", "so", "because", "as", "when", "where", "who", "which",
            "that", "this", "these", "those", "i", "me", "my", "mine", "you",
            "your", "yours", "he", "him", "his", "she", "her", "hers", "it",
            "its", "we", "us", "our", "ours", "they", "them", "their", "theirs"
        ])

        keywords = [word for word in words if word not in stop_words and len(word) > 2]
        return keywords[:top_n]

    def generate_explanation(self, key_words: List[str], label: str) -> tuple:
        """Generate a basic explanation and suggestion."""
        if label == "spam":
            reason = f"该短信包含垃圾短信特征词: {', '.join(key_words)}"
            suggestion = "建议立即删除该短信,不要点击任何链接,不要回复,避免上当受骗"
        else:
            reason = f"该短信为正常短信,包含常用词汇: {', '.join(key_words)}"
            suggestion = "可以正常回复和处理该短信"
        return reason, suggestion

    def generate_detailed_explanation(self, message: str, message_zh: str, label: str, key_words: List[str]) -> str:
        """Generate a detailed explanation via the DeepSeek API."""
        headers = {
            "Authorization": f"Bearer {settings.deepseek_api_key}",
            "Content-Type": "application/json"
        }

        prompt = f"""
分析以下短信:
英文:{message}
中文:{message_zh}
分类结果:{label}
关键词:{', '.join(key_words)}

请提供:
1. 详细的分类原因
2. 短信的主要特征
3. 针对该短信的具体建议
4. 如何识别类似的短信

请使用中文回答,保持简洁明了。
"""

        payload = {
            "model": settings.deepseek_model,
            "messages": [
                {
                    "role": "system",
                    "content": "你是一名专业的垃圾短信分析师,请根据提供的信息进行详细分析。"
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": 500,
            "temperature": 0.1
        }

        try:
            response = requests.post(
                f"{settings.deepseek_api_base}/chat/completions",
                headers=headers,
                json=payload,
                timeout=30
            )
            response.raise_for_status()

            result = response.json()
            explanation = result["choices"][0]["message"]["content"].strip()
            return explanation
        except requests.exceptions.RequestException as e:
            print(f"生成详细解释失败: {e}")
            return "无法生成详细解释,请检查API连接。"


if __name__ == "__main__":
    # Initialize the analysis system
    analyzer = SimpleSpamAnalysis()

    # Test messages
    test_messages = [
        "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's",
        "Ok lar... Joking wif u oni...",
        "WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only."
    ]

    # Analyze the messages
    for i, message in enumerate(test_messages):
        print(f"\n=== 短信分析结果 {i+1} ===")
        result = analyzer.analyze(message)

        print(f"原始短信: {result['original_message']}")
        print(f"中文翻译: {result['translated_message']}")
        print(f"分类结果: {result['classification']['label']} (置信度: {result['classification']['confidence']:.2f})")
        print(f"关键词: {', '.join(result['key_words'])}")
        print(f"原因: {result['reason']}")
        print(f"建议: {result['suggestion']}")
        print(f"详细解释: {result['detailed_explanation']}")
233 src/streamlit_app.py Normal file
@@ -0,0 +1,233 @@
import streamlit as st
import pandas as pd
import sys
import os

# Add the project root to the Python path
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.agent import agent, get_agent
from src.data import load_data, preprocess_data, split_data
from src.models import train_model, save_model, load_model, compare_models

# Page configuration
st.set_page_config(
    page_title="垃圾短信分类系统",
    page_icon="📱",
    layout="wide",
    initial_sidebar_state="expanded"
)

# App title
st.title("📱 垃圾短信分类系统")
st.markdown("---")

# Sidebar
with st.sidebar:
    st.header("系统配置")

    # Model selection
    model_option = st.selectbox(
        "选择模型",
        options=["lightgbm", "logistic_regression"],
        index=0,
        help="选择用于分类的机器学习模型"
    )

    # Language selection
    lang_option = st.selectbox(
        "输出语言",
        options=["中文", "英文"],
        index=0,
        help="选择分类结果和解释的输出语言"
    )

    # About
    st.markdown("---")
    st.header("关于系统")
    st.info(
        "这是一个基于传统机器学习 + LLM + Agent的垃圾短信分类系统。\n"
        "- 使用LightGBM和Logistic Regression进行分类\n"
        "- 利用DeepSeek LLM解释分类结果\n"
        "- 通过Agent实现工具调用和结果整合"
    )

# Main content area
col1, col2 = st.columns([1, 1], gap="large")

with col1:
    # Message input
    st.header("输入短信")

    # Single-message input
    sms_input = st.text_area(
        "请输入要分类的短信",
        height=200,
        placeholder="例如:WINNER!! As a valued network customer you have been selected to receivea £900 prize reward!"
    )

    # Classify button
    classify_button = st.button(
        "📊 开始分类",
        type="primary",
        use_container_width=True,
        disabled=sms_input.strip() == ""
    )

    # Batch upload
    st.markdown("---")
    st.header("批量分类")
    uploaded_file = st.file_uploader(
        "上传CSV文件(包含text列)",
        type=["csv"],
        help="上传包含短信文本的CSV文件,系统将自动分类"
    )

    # Optional model retraining
    with st.expander("🔧 模型训练", expanded=False):
        if st.button("重新训练模型"):
            with st.spinner("正在训练模型..."):
                try:
                    # Load and preprocess the data
                    df = load_data("../data/spam.csv")
                    processed_df = preprocess_data(df)
                    train_df, test_df = split_data(processed_df)

                    # Train the model
                    model, params = train_model(train_df, model_type=model_option)
                    save_model(model, model_option)

                    st.success(f"{model_option} 模型训练完成!")
                except Exception as e:
                    st.error(f"模型训练失败:{e}")

with col2:
    # Classification results
    st.header("分类结果")

    # Single-message result
    if classify_button and sms_input.strip():
        with st.spinner("正在分类..."):
            try:
                # Classify and explain via the Agent
                result = agent.classify_and_explain(sms_input)

                # Show the label
                st.subheader("📋 分类标签")

                # Style the message according to the label
                if result['classification']['label'] == "spam":
                    st.error("⚠️ 这是一条**垃圾短信**")
                else:
                    st.success("✅ 这是一条**正常短信**")

                # Show the probabilities
                st.subheader("📊 分类概率")
                prob_df = pd.DataFrame.from_dict(
                    result['classification']['probability'],
                    orient='index',
                    columns=['概率']
                )
                st.bar_chart(prob_df)

                # Show the full result
                st.subheader("📝 详细结果")
                with st.expander("查看详细分类结果", expanded=True):
                    st.json(result['classification'], expanded=False)

                # Show the explanation and suggestions
                st.subheader("🤔 结果解释")
                with st.expander("查看分类解释", expanded=True):
                    st.write(f"**内容摘要**:{result['explanation']['content_summary']}")
                    st.write(f"**分类原因**:{result['explanation']['classification_reason']}")
                    st.write(f"**可信度**:{result['explanation']['confidence_level']} - {result['explanation']['confidence_explanation']}")

                st.subheader("💡 行动建议")
                for i, suggestion in enumerate(result['explanation']['suggestions']):
                    st.write(f"{i+1}. {suggestion}")

            except Exception as e:
                st.error(f"分类失败:{e}")

    # Batch results
    if uploaded_file is not None:
        with st.spinner("正在批量分类..."):
            try:
                # Read the uploaded file
                df = pd.read_csv(uploaded_file)

                if "text" not in df.columns:
                    st.error("CSV文件必须包含'text'列")
                else:
                    # Cap the number of rows to process
                    max_rows = 100
                    if len(df) > max_rows:
                        st.warning(f"文件包含 {len(df)} 条记录,仅处理前 {max_rows} 条")
                        df = df.head(max_rows)

                    # Batch classification
                    results = []
                    for text in df["text"].tolist():
                        result = agent.classify_and_explain(text)
                        results.append({
                            "text": text,
                            "label": result['classification']['label'],
                            "spam_probability": result['classification']['probability']['spam'],
                            "ham_probability": result['classification']['probability']['ham'],
                            "content_summary": result['explanation']['content_summary'],
                            "classification_reason": result['explanation']['classification_reason']
                        })

                    # Convert to a DataFrame
                    results_df = pd.DataFrame(results)

                    # Summary statistics
                    st.subheader("📊 分类统计")
                    label_counts = results_df["label"].value_counts()
                    st.bar_chart(label_counts)

                    # Result table
                    st.subheader("📋 分类结果")
                    st.dataframe(
                        results_df,
                        use_container_width=True,
                        column_config={
                            "text": st.column_config.TextColumn("短信内容", width="medium"),
                            "label": st.column_config.TextColumn("分类标签"),
                            "spam_probability": st.column_config.ProgressColumn(
                                "垃圾短信概率",
                                format="%.2f",
                                min_value=0.0,
                                max_value=1.0
                            ),
                            "ham_probability": st.column_config.ProgressColumn(
                                "正常短信概率",
                                format="%.2f",
                                min_value=0.0,
                                max_value=1.0
                            ),
                            "content_summary": st.column_config.TextColumn("内容摘要", width="medium"),
                            "classification_reason": st.column_config.TextColumn("分类原因", width="medium")
                        }
                    )

                    # Download the results
                    st.subheader("💾 下载结果")
                    csv = results_df.to_csv(index=False).encode('utf-8')
                    st.download_button(
                        label="下载分类结果 (CSV)",
                        data=csv,
                        file_name="spam_classification_results.csv",
                        mime="text/csv",
                        use_container_width=True
                    )

            except Exception as e:
                st.error(f"批量分类失败:{e}")

# Footer
st.markdown("---")
st.markdown(
    "<center>© 2026 垃圾短信分类系统 | 基于传统机器学习 + LLM + Agent</center>",
    unsafe_allow_html=True
)
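Two hedged notes on the new page, since the agent module itself sits outside this hunk: it is presumably launched with the standard Streamlit CLI, e.g. `streamlit run src/streamlit_app.py` from the repository root (the `sys.path` tweak at the top makes the `src.*` imports resolve from there); and the result shape the UI assumes from `agent.classify_and_explain()` — inferred purely from the fields the page accesses, with illustrative values — is roughly:

```python
# Result contract assumed by the page (inferred from field accesses; values illustrative)
result = {
    "classification": {
        "label": "spam",                             # "spam" or "ham"
        "probability": {"spam": 0.97, "ham": 0.03},  # plotted with st.bar_chart
    },
    "explanation": {
        "content_summary": "...",
        "classification_reason": "...",
        "confidence_level": "高",
        "confidence_explanation": "...",
        "suggestions": ["...", "..."],               # always a list
    },
}
```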
@@ -1,130 +0,0 @@
import requests
from typing import List, Dict
from config import settings
import time


def translate_text(text: str, target_lang: str = "zh-CN") -> str:
    """
    Translate text into the target language via the DeepSeek API.

    Args:
        text: the text to translate
        target_lang: target language, defaults to Chinese (zh-CN)

    Returns:
        The translated text.
    """
    headers = {
        "Authorization": f"Bearer {settings.deepseek_api_key}",
        "Content-Type": "application/json"
    }

    payload = {
        "model": settings.deepseek_model,
        "messages": [
            {
                "role": "system",
                "content": f"You are a professional translator. Translate the following text to {target_lang}. Keep the original meaning and tone. Do not add any additional information."
            },
            {
                "role": "user",
                "content": text
            }
        ],
        "max_tokens": 1000,
        "temperature": 0.1
    }

    try:
        response = requests.post(
            f"{settings.deepseek_api_base}/chat/completions",
            headers=headers,
            json=payload,
            timeout=30
        )
        response.raise_for_status()

        result = response.json()
        translated_text = result["choices"][0]["message"]["content"].strip()
        return translated_text
    except requests.exceptions.RequestException as e:
        print(f"翻译失败: {e}")
        return text


def translate_batch(texts: List[str], target_lang: str = "zh-CN", batch_size: int = 10) -> List[str]:
    """
    Translate a list of texts in batches.

    Args:
        texts: the texts to translate
        target_lang: target language, defaults to Chinese (zh-CN)
        batch_size: batch size, defaults to 10

    Returns:
        The translated texts.
    """
    translated_texts = []

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        batch_translated = []

        for text in batch:
            translated = translate_text(text, target_lang)
            batch_translated.append(translated)
            # Sleep briefly to avoid API rate limits
            time.sleep(0.5)

        translated_texts.extend(batch_translated)
        print(f"已翻译 {min(i+batch_size, len(texts))}/{len(texts)} 条文本")

    return translated_texts


def translate_dataset(df, text_column: str = "message", target_column: str = "message_zh") -> str:
    """
    Translate a text column of a dataset.

    Args:
        df: a Polars DataFrame
        text_column: name of the column to translate
        target_column: name of the translated column

    Returns:
        Path of the translated dataset file.
    """
    import polars as pl
    import os

    # Create the data directory if it does not exist
    os.makedirs(settings.data_path, exist_ok=True)

    # Collect the texts
    texts = df[text_column].to_list()

    # Translate them
    print(f"开始翻译 {len(texts)} 条文本...")
    translated_texts = translate_batch(texts)

    # Append the translated column to the dataset
    df = df.with_columns(
        pl.Series(target_column, translated_texts)
    )

    # Save the translated dataset
    output_path = f"{settings.data_path}/spam_zh.csv"
    df.write_csv(output_path)

    print(f"翻译后的数据集已保存到: {output_path}")
    print(f"翻译完成!共翻译了 {len(texts)} 条文本")
    return output_path


if __name__ == "__main__":
    # Test the translation function
    test_text = "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's"
    translated = translate_text(test_text)
    print(f"原文: {test_text}")
    print(f"译文: {translated}")
@@ -1,31 +0,0 @@
import sys
import os

# Add the src directory to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from simple_agent import SimpleSpamAnalysis


# Test messages
test_messages = [
    "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's",
    "Ok lar... Joking wif u oni...",
    "WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only."
]

# Initialize the analysis system
analyzer = SimpleSpamAnalysis()

# Analyze the messages
for i, message in enumerate(test_messages):
    print(f"\n=== 短信分析结果 {i+1} ===")
    result = analyzer.analyze(message)

    print(f"原始短信: {result['original_message'][:100]}...")
    print(f"中文翻译: {result['translated_message'][:100]}...")
    print(f"分类结果: {result['classification']['label']} (置信度: {result['classification']['confidence']:.2f})")
    print(f"关键词: {', '.join(result['key_words'])}")
    print(f"原因: {result['reason']}")
    print(f"建议: {result['suggestion']}")
    print(f"详细解释: {result['detailed_explanation'][:200]}...")
@@ -1,7 +0,0 @@
from src.translation import translate_text

# Test the single-text translation function
test_text = "Hello, how are you?"
print(f"原文: {test_text}")
translated = translate_text(test_text)
print(f"译文: {translated}")