diff --git a/.env.example b/.env.example index 9149aee..1302fca 100644 --- a/.env.example +++ b/.env.example @@ -1,10 +1,10 @@ # DeepSeek API Configuration -DEEPSEEK_API_KEY="your-deepseek-api-key-here" -DEEPSEEK_BASE_URL="https://api.deepseek.com" +DEEPSEEK_API_KEY=your-deepseek-api-key-here +DEEPSEEK_BASE_URL=https://api.deepseek.com # OpenAI Compatibility (for DeepSeek) -OPENAI_API_KEY="${DEEPSEEK_API_KEY}" -OPENAI_BASE_URL="${DEEPSEEK_BASE_URL}" +OPENAI_API_KEY=your-deepseek-api-key-here +OPENAI_BASE_URL=https://api.deepseek.com # Project Configuration PROJECT_NAME="spam-classification-system" diff --git a/README.md b/README.md index 9e9244f..29e2a9b 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,188 @@ # 垃圾短信分类系统 -## 项目概述 -本项目是一个基于**传统机器学习 + LLM + Agent**的垃圾短信分类系统,旨在实现可落地的智能预测与行动建议。系统使用传统机器学习完成可量化的垃圾短信预测任务,再用 LLM + Agent 把预测结果变成可执行的决策/建议,并保证输出结构化、可追溯、可复现。 +> **机器学习 (Python) 课程设计** + + +## 👥 团队成员 + + +| 姓名 | 学号 | 贡献 | +|------|------|------| +| 朱指乐 | 2311020135 | 数据处理、模型训练 | +| 肖康 | 2311020125 | Agent 开发、LLM 服务 | +| 龙思富 | 2311020114 | 可视化、优化streamlit应用、文档撰写| + + +## 📝 项目简介 + + +本项目是一个基于**传统机器学习 + LLM + Agent**的垃圾短信分类系统,旨在实现可落地的智能预测与行动建议。系统使用 SMS Spam Collection 数据集,通过传统机器学习完成垃圾短信的量化预测,再利用 LLM 和 Agent 技术将预测结果转化为结构化、可执行的决策建议,确保输出结果可追溯、可复现。 + + +## 🚀 快速开始 + + +```bash +# 克隆仓库 +git clone http://hblu.top:3000/MachineLearning2025/CourseDesign.git +cd CourseDesign + +# 安装依赖 +pip install uv -i https://mirrors.aliyun.com/pypi/simple/ +uv config set index-url https://mirrors.aliyun.com/pypi/simple/ +uv sync + +# 配置环境变量 +cp .env.example .env +# 编辑 .env 填入 API Key + +# 运行 Demo +uv run streamlit run src/streamlit_app.py +``` + + +## 1️⃣ 问题定义与数据 + + +### 1.1 任务描述 + + +本项目是一个二分类任务,目标是自动识别垃圾短信(spam)和正常短信(ham)。业务目标是构建一个高准确率、可解释的垃圾短信分类系统,帮助用户有效过滤垃圾信息,提升信息安全和用户体验。 + + +### 1.2 数据来源 + + +| 项目 | 说明 | +|------|------| +| 数据集名称 | SMS Spam Collection | +| 数据链接 | `https://www.kaggle.com/datasets/uciml/sms-spam-collection-dataset` | +| 样本量 | 5,572 条 | +| 特征数 | 1 个(短信文本) | + + +### 1.3 数据切分与防泄漏 + + +数据按 8:2 比例分割为训练集和测试集,确保模型在独立的测试集上进行评估。在数据预处理和特征工程阶段,所有操作仅在训练集上进行,避免信息泄漏到测试集。使用 TF-IDF 进行文本向量化时,同样严格遵循先训练后应用的原则。 + + +## 2️⃣ 机器学习流水线 + + +### 2.1 基线模型 + + +| 模型 | 指标 | 结果 | +|------|------|------| +| Logistic Regression | 准确率 | 0.978 | +| Logistic Regression | F1 分数(Macro) | 0.959 | + + +### 2.2 进阶模型 + + +| 模型 | 指标 | 结果 | +|------|------|------| +| LightGBM | 准确率 | 0.985 | +| LightGBM | F1 分数(Macro) | 0.971 | + + +### 2.3 误差分析 + + +模型在以下类型的样本上表现相对较差: +1. 包含大量特殊字符或缩写的短信 +2. 内容模糊、边界不清的促销短信 +3. 混合中英文的短信 +4. 模仿正常短信格式的垃圾短信 + +这主要是因为文本特征提取方法(TF-IDF)对语义理解有限,无法完全捕捉复杂的语言模式和上下文信息。 + + +## 3️⃣ Agent 实现 + + +### 3.1 工具定义 + + +| 工具名 | 功能 | 输入 | 输出 | +|--------|------|------|------| +| `predict_spam` | 使用机器学习模型预测短信是否为垃圾短信 | 短信文本 | 分类结果(spam/ham)和概率 | +| `explain_prediction` | 解释模型预测结果并生成行动建议 | 短信文本、分类结果、概率 | 结构化的解释和建议 | +| `translate_text` | 将文本翻译成目标语言 | 文本、目标语言 | 翻译后的文本 | + + +### 3.2 决策流程 + + +Agent 按照以下流程执行任务: +1. 接收用户提供的短信文本 +2. 使用 `predict_spam` 工具进行分类预测 +3. 使用 `explain_prediction` 工具解释分类结果并生成行动建议 +4. 如果短信为英文,可选择使用 `translate_text` 工具翻译成中文 +5. 向用户提供清晰、完整的分类结果、解释和建议 + + +### 3.3 案例展示 + + +**输入**: +``` +Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's +``` + +**输出**: +```json +{ + "classification": { + "label": "spam", + "probability": { + "spam": 0.98, + "ham": 0.02 + } + }, + "explanation": { + "content_summary": "这是一条关于免费赢取足总杯决赛门票的竞赛广告短信", + "classification_reason": "短信包含'Free entry'、'win'、'comp'等典型的垃圾短信关键词,且提供了需要用户回复的电话号码,符合垃圾短信的特征", + "confidence_level": "高", + "confidence_explanation": "模型以98%的概率将其分类为垃圾短信,基于文本中的垃圾短信特征词汇和结构", + "suggestions": ["不要回复此短信,避免产生额外费用", "将此号码加入黑名单", "删除该短信"] + } +} +``` + + +## 4️⃣ 开发心得 + + +### 4.1 主要困难与解决方案 + + +1. **文本特征提取**:原始文本数据难以直接用于机器学习模型,解决方案是使用 TF-IDF 进行文本向量化,将文本转化为数值特征。 +2. **模型可解释性**:传统机器学习模型的预测结果缺乏可解释性,解决方案是集成 LLM 服务,对模型预测结果进行自然语言解释。 +3. **API 集成与错误处理**:LLM API 调用可能会遇到各种错误,解决方案是实现完善的错误处理机制,确保系统稳定性。 + + +### 4.2 对 AI 辅助编程的感受 + + +AI 辅助编程工具(如 GitHub Copilot)在代码编写和问题解决方面提供了很大帮助,特别是在处理重复性任务和学习新框架时。它可以快速生成代码模板,提供解决方案建议,显著提高开发效率。但同时也需要注意,AI 生成的代码可能存在错误或不符合项目规范,需要人工仔细检查和调试。 + + +### 4.3 局限与未来改进 + + +1. **模型性能**:当前模型在处理复杂语言模式和上下文理解方面仍有提升空间,可以考虑使用更先进的文本表示方法(如 BERT)。 +2. **多语言支持**:目前系统主要支持中英文短信,未来可以扩展到更多语言。 +3. **实时性**:可以优化模型推理速度,实现实时分类功能。 +4. **用户界面**:可以进一步改进 Streamlit 应用的用户体验,增加更多交互功能和可视化效果。 + ## 技术栈 + | 组件 | 技术 | 版本要求 | |------|------|----------| | 项目管理 | uv | 最新版 | @@ -13,189 +190,33 @@ | 数据验证 | pandera | >=0.18.0 | | 机器学习 | scikit-learn + lightgbm | sklearn>=1.3.0, lightgbm>=4.0.0 | | LLM 框架 | openai | >=1.0.0 | -| Agent 框架 | pydantic + pydantic-ai | pydantic>=2.0.0 | +| Agent 框架 | pydantic | pydantic>=2.0.0 | | 可视化 | streamlit | >=1.20.0 | | 文本处理 | nltk | >=3.8.0 | -## 项目结构 - -``` -├── src/ -│ ├── data/ # 数据处理模块 -│ │ ├── __init__.py -│ │ ├── preprocess.py # 数据预处理 -│ │ └── validation.py # 数据验证 -│ ├── models/ # 机器学习模型 -│ │ ├── __init__.py -│ │ ├── train.py # 模型训练 -│ │ ├── evaluate.py # 模型评估 -│ │ └── predict.py # 模型预测 -│ ├── agent/ # Agent 模块 -│ │ ├── __init__.py -│ │ ├── agent.py # Agent 核心逻辑 -│ │ └── tools.py # Agent 工具定义 -│ ├── llm/ # LLM 服务 -│ │ ├── __init__.py -│ │ └── llm_service.py # DeepSeek API 集成 -│ ├── main.py # 主程序入口 -│ └── streamlit_app.py # Streamlit 可视化应用 -├── data/ # 数据集目录 -│ └── spam.csv # 垃圾短信数据集 -├── models/ # 模型保存目录 -├── .env.example # 环境变量示例 -├── .gitignore # Git 忽略文件 -├── pyproject.toml # 项目配置 -└── README.md # 项目说明文档 -``` - -## 快速开始 - -### 1. 安装依赖 - -```bash -# 安装 uv(如果尚未安装) -pip install uv -i https://mirrors.aliyun.com/pypi/simple/ - -# 配置 PyPI 镜像 -uv config set index-url https://mirrors.aliyun.com/pypi/simple/ - -# 同步依赖 -uv sync -``` - -### 2. 配置 API 密钥 - -```bash -# 复制环境变量示例文件 -cp .env.example .env - -# 编辑 .env 文件,填入你的 DeepSeek API Key -# DEEPSEEK_API_KEY="your-key-here" -``` - -### 3. 运行应用 - -#### 方式 A:Streamlit 可视化应用(推荐) - -```bash -uv run streamlit run src/streamlit_app.py -``` - -#### 方式 B:命令行 Agent Demo - -```bash -uv run python src/agent_app.py -``` - -#### 方式 C:运行模型训练 - -```bash -uv run python src/models/train.py -``` - -## 使用说明 - -### 1. Streamlit 应用 - -启动 Streamlit 应用后,你可以: - -- **单条短信分类**:在左侧输入框中输入短信内容,点击「开始分类」按钮,系统将返回分类结果、解释和建议 -- **批量分类**:上传包含 `text` 列的 CSV 文件,系统将自动对所有短信进行分类 -- **模型选择**:在侧边栏选择使用 LightGBM 或 Logistic Regression 模型 -- **语言选择**:在侧边栏选择输出结果的语言(中文/英文) -- **重新训练模型**:在「模型训练」展开栏中点击「重新训练模型」按钮 - -### 2. 命令行使用 - -```python -from src.agent import agent - -# 分类并解释短信 -result = agent.classify_and_explain("Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005.") - -print(f"分类结果: {result['classification']['label']}") -print(f"分类概率: {result['classification']['probability']}") -print(f"解释: {result['explanation']['classification_reason']}") -print(f"建议: {result['explanation']['suggestions']}") -``` - -## 模型训练 - -### 训练流程 - -1. **数据加载**:使用 Polars 加载 `data/spam.csv` 数据集 -2. **数据预处理**: - - 清洗文本(去除 HTML 实体、特殊字符、多余空格) - - 转换标签(spam → 1,ham → 0) - - 数据验证(使用 Pandera 验证数据质量) -3. **数据分割**:按 8:2 比例分割训练集和测试集 -4. **特征工程**:使用 TF-IDF 进行文本向量化 -5. **模型训练**: - - 基线模型:Logistic Regression - - 强模型:LightGBM -6. **模型评估**:计算准确率、精确率、召回率、F1 分数 -7. **模型保存**:将训练好的模型保存到 `models/` 目录 - -### 评估结果 - -| 模型 | 准确率 | 精确率(Macro) | 召回率(Macro) | F1 分数(Macro) | -|------|--------|----------------|----------------|------------------| -| Logistic Regression | 0.978 | 0.964 | 0.954 | 0.959 | -| LightGBM | 0.985 | 0.974 | 0.968 | 0.971 | - -## 系统架构 - -### 1. 数据层 - -- **数据加载**:使用 Polars 高效加载和处理大规模数据 -- **数据清洗**:实现可复现的数据清洗流水线 -- **数据验证**:使用 Pandera 定义数据 Schema,确保数据质量 - -### 2. 模型层 - -- **特征工程**:TF-IDF 文本向量化 -- **模型训练**:支持多种模型(LightGBM、Logistic Regression) -- **模型评估**:全面的评估指标和可视化 -- **模型管理**:模型保存、加载和版本控制 - -### 3. LLM 层 - -- **文本翻译**:将英文短信翻译成中文 -- **结果解释**:解释模型分类结果的原因 -- **行动建议**:根据分类结果生成可执行的建议 -- **模式分析**:分析垃圾短信的常见模式 - -### 4. Agent 层 - -- **工具定义**: - - `predict_spam`:使用机器学习模型预测短信是否为垃圾短信 - - `explain_prediction`:使用 LLM 解释分类结果并生成行动建议 - - `translate_text`:将文本翻译成目标语言 -- **Agent 逻辑**:实现工具调用、结果整合和自然语言生成 -- **结构化输出**:确保输出结果可追溯、可复现 - -## 贡献指南 - -1. 朱指乐-仓库创建-项目初始化-数据获取-数据预处理-模型训练-模型评估-模型保存 -2. 肖康-LLM服务-实现文本解释和建议生成功能-agent实现-工具调用-结构化输出 -3. 龙思富-可视化-交互式可视化应用-文档编写 ## 许可证 + MIT License + ## 致谢 + - 感谢 [DeepSeek](https://www.deepseek.com/) 提供的 LLM API - 感谢 Kaggle 提供的 [SMS Spam Collection](https://www.kaggle.com/datasets/uciml/sms-spam-collection-dataset) 数据集 - 感谢所有开源库的贡献者 + ## 联系方式 + 如有问题或建议,欢迎通过以下方式联系: - 项目地址:[http://hblu.top:3000/MachineLearning2025/CourseDesign](http://hblu.top:3000/MachineLearning2025/CourseDesign) -- 邮箱:your-email@example.com +- 邮箱:xxxxxxxxxx@gmail.com + --- diff --git a/src/agent/agent.py b/src/agent/agent.py index 0f5cd0b..d60ef1d 100644 --- a/src/agent/agent.py +++ b/src/agent/agent.py @@ -50,10 +50,20 @@ class SpamClassificationAgent: self.api_key = env_vars.get("DEEPSEEK_API_KEY") or os.getenv("DEEPSEEK_API_KEY") self.base_url = env_vars.get("DEEPSEEK_BASE_URL") or os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com") - # 验证并修复base_url,确保包含协议前缀 - if self.base_url and not (self.base_url.startswith("http://") or self.base_url.startswith("https://")): - print(f"警告: base_url '{self.base_url}' 缺少协议前缀,将添加 'https://'") - self.base_url = f"https://{self.base_url}" + # 清理并修复base_url,确保包含协议前缀 + if self.base_url: + # 清理URL:移除多余的空格、引号和反引号 + self.base_url = self.base_url.strip() + # 移除可能的引号 + for quote_char in ['"', "'", '`']: + if self.base_url.startswith(quote_char) and self.base_url.endswith(quote_char): + self.base_url = self.base_url[1:-1].strip() + break + + # 确保包含协议前缀 + if not (self.base_url.startswith("http://") or self.base_url.startswith("https://")): + print(f"警告: base_url '{self.base_url}' 缺少协议前缀,将添加 'https://'") + self.base_url = f"https://{self.base_url}" # 延迟创建客户端,直到实际需要时 self.client = None diff --git a/src/llm/llm_service.py b/src/llm/llm_service.py index 067c563..97ce4eb 100644 --- a/src/llm/llm_service.py +++ b/src/llm/llm_service.py @@ -29,10 +29,20 @@ class LLMService: self.api_key = env_vars.get("DEEPSEEK_API_KEY") or os.getenv("DEEPSEEK_API_KEY") self.base_url = env_vars.get("DEEPSEEK_BASE_URL") or os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com") - # 验证并修复base_url,确保包含协议前缀 - if self.base_url and not (self.base_url.startswith("http://") or self.base_url.startswith("https://")): - print(f"警告: base_url '{self.base_url}' 缺少协议前缀,将添加 'https://'") - self.base_url = f"https://{self.base_url}" + # 清理并修复base_url,确保包含协议前缀 + if self.base_url: + # 清理URL:移除多余的空格、引号和反引号 + self.base_url = self.base_url.strip() + # 移除可能的引号 + for quote_char in ['"', "'", '`']: + if self.base_url.startswith(quote_char) and self.base_url.endswith(quote_char): + self.base_url = self.base_url[1:-1].strip() + break + + # 确保包含协议前缀 + if not (self.base_url.startswith("http://") or self.base_url.startswith("https://")): + print(f"警告: base_url '{self.base_url}' 缺少协议前缀,将添加 'https://'") + self.base_url = f"https://{self.base_url}" # 默认模型 self.default_model = "deepseek-chat"