feat: initialize the basic structure of the spam message classification project

Add the core project file structure, including:
- configuration files and environment variable management
- data processing and translation modules
- machine learning model training and evaluation
- an LLM-based intelligent analysis agent
- test scripts and project documentation
parent d597ddd2ff
commit aa10e463b4
6 .env.example Normal file
@@ -0,0 +1,6 @@
# DeepSeek API Configuration
DEEPSEEK_API_KEY="your-deepseek-api-key-here"

# Project Configuration
MODEL_SAVE_PATH="./models"
DATA_PATH="./data"
54 .gitignore vendored Normal file
@@ -0,0 +1,54 @@
# Python
__pycache__/
*.py[cod]
*$py.class

# Environment
.env
.env.local
.env.development.local
.env.test.local
.env.production.local

# Dependencies
.venv/
venv/
env/

# Data
data/
*.csv
*.parquet
*.h5

# Models
models/
*.joblib
*.pkl
*.model
*.txt

# Logs
logs/
*.log

# Build
dist/
build/
*.egg-info/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# Testing
.pytest_cache/
.coverage
htmlcov/

# OS
.DS_Store
Thumbs.db
49 .trae/documents/垃圾短信分类项目实现计划.md Normal file
@@ -0,0 +1,49 @@
# Spam Message Classification Project Implementation Plan

## 1. Project Structure
- Create the project directory structure, including `src`, `data`, `models`, etc.
- Initialize project dependencies, managed with uv
- Create configuration files and environment variable management

## 2. Data Processing
- Load and clean the spam.csv dataset with Polars
- Translate the English messages into Chinese using the DeepSeek API
- Define a data schema with Pandera for validation (see the sketch after this plan)
- Data preprocessing and feature engineering

## 3. Machine Learning Models
- Implement at least two models: Logistic Regression as the baseline, LightGBM as the strong model
- Model training, validation, and evaluation
- Model saving and loading
- Reach a performance target of F1 ≥ 0.70 or ROC-AUC ≥ 0.75

## 4. LLM Integration
- Use the DeepSeek API to explain and attribute message content
- Generate structured action recommendations
- Ensure the output is traceable and reproducible

## 5. Agent Framework
- Build an Agent with structured output using pydantic-ai
- Implement at least two tools: an ML prediction tool and an evaluation tool
- Build the complete tool-calling flow

## 6. Testing and Deployment
- Write unit and integration tests
- Make sure the project runs on the instructor's machine
- Prepare the project presentation materials

## Tech Stack
- Python 3.12
- uv for project management
- Polars + Pandas for data processing
- Pandera for data validation
- Scikit-learn + LightGBM for machine learning
- pydantic-ai as the Agent framework
- DeepSeek API as the LLM provider

## Expected Deliverables
- A complete spam message classification system
- A Chinese-translated dataset
- Reproducible machine learning models
- LLM-based intelligent recommendation generation
- Structured, traceable output
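The Pandera validation step named in section 2 of the plan is not implemented anywhere in this commit. A minimal sketch of what that schema could look like, assuming the column layout produced by data_processing.clean_data (a string label in {ham, spam} and a non-null message column); the checks are illustrative assumptions, not shipped code:

import pandera as pa

from data_processing import clean_data, load_data

# Illustrative schema derived from the plan, validating the cleaned dataset
schema = pa.DataFrameSchema({
    "label": pa.Column(str, pa.Check.isin(["ham", "spam"])),
    "message": pa.Column(str, nullable=False),
})

df_cleaned = clean_data(load_data("./spam.csv"))
schema.validate(df_cleaned.to_pandas())  # raises pandera.errors.SchemaError on bad rows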
41 pyproject.toml Normal file
@@ -0,0 +1,41 @@
[tool.uv]
index-url = "https://mirrors.aliyun.com/pypi/simple/"

[project]
name = "spam-classification"
version = "0.1.0"
authors = [{ name = "Your Name", email = "your.email@example.com" }]
description = "Spam message classification with ML and LLM integration"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "pandas>=2.2",
    "polars>=0.20",
    "pandera>=0.18",
    "scikit-learn>=1.4",
    "lightgbm>=4.3",
    "pydantic>=2.5",
    "pydantic-ai>=0.3",
    "pydantic-settings>=2.0",
    "python-dotenv>=1.0",
    "requests>=2.31",
    "joblib>=1.3",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.4",
    "ruff>=0.2",
]

[build-system]
requires = ["uv_build>=0.7"]
build-backend = "uv_build"

[tool.ruff]
line-length = 88

[tool.ruff.lint]
select = ["E", "F", "W"]

[tool.pytest.ini_options]
testpaths = ["tests"]
python_files = "test_*.py"
python_classes = "Test*"
python_functions = "test_*"
50 simple_test.py Normal file
@@ -0,0 +1,50 @@
import os

import requests
from dotenv import load_dotenv

load_dotenv()


# Call the DeepSeek API directly
def test_deepseek_api():
    # Read the key from the environment instead of hardcoding a secret
    api_key = os.environ["DEEPSEEK_API_KEY"]
    api_base = "https://api.deepseek.com"
    model = "deepseek-chat"

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json"
    }

    payload = {
        "model": model,
        "messages": [
            {
                "role": "system",
                "content": "You are a professional translator. Translate the following text to Chinese. Keep the original meaning and tone. Do not add any additional information."
            },
            {
                "role": "user",
                "content": "Hello, how are you?"
            }
        ],
        "max_tokens": 1000,
        "temperature": 0.1
    }

    try:
        response = requests.post(
            f"{api_base}/chat/completions",
            headers=headers,
            json=payload,
            timeout=30
        )
        response.raise_for_status()

        result = response.json()
        print("API response:", result)
        translated_text = result["choices"][0]["message"]["content"].strip()
        print(f"Translated text: {translated_text}")
        return translated_text
    except requests.exceptions.RequestException as e:
        print(f"Translation failed: {e}")
        return None


if __name__ == "__main__":
    test_deepseek_api()
250 src/agent.py Normal file
@@ -0,0 +1,250 @@
import asyncio
from pathlib import Path
from typing import Any, Dict, List

import joblib
from pydantic import BaseModel, Field
from pydantic_ai import Agent, Tool
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider

from config import settings
from translation import translate_text


class Message(BaseModel):
    """SMS message model."""
    content: str = Field(..., description="Message text")
    is_english: bool = Field(default=True, description="Whether the message is in English")


class ClassificationResult(BaseModel):
    """Classification result model."""
    label: str = Field(..., description="Predicted label, ham or spam")
    confidence: float = Field(..., description="Prediction confidence")


class Explanation(BaseModel):
    """Explanation model."""
    key_words: List[str] = Field(..., description="Key feature words")
    reason: str = Field(..., description="Reason for the classification")
    suggestion: str = Field(..., description="Recommended action")


class AnalysisResult(BaseModel):
    """Analysis result model."""
    message: str = Field(..., description="Original message")
    message_zh: str = Field(..., description="Chinese translation")
    classification: ClassificationResult = Field(..., description="Classification result")
    explanation: Explanation = Field(..., description="Explanation and suggestion")


class SpamClassifier:
    """Spam message classifier."""

    def __init__(self, model_name: str = "lightgbm"):
        """Initialize the classifier."""
        self.model_name = model_name
        self.model = None
        self.vectorizer = None
        self.load_model()

    def load_model(self):
        """Load the model and vectorizer."""
        model_dir = Path(settings.model_save_path)

        # Load the model
        model_path = model_dir / f"{self.model_name}_model.joblib"
        self.model = joblib.load(model_path)
        print(f"Model loaded from: {model_path}")

        # Load the vectorizer
        vectorizer_path = model_dir / f"{self.model_name}_vectorizer.joblib"
        self.vectorizer = joblib.load(vectorizer_path)
        print(f"Vectorizer loaded from: {vectorizer_path}")

    def classify(self, message: str) -> Dict[str, Any]:
        """Classify a single message."""
        # Vectorize the message
        message_vector = self.vectorizer.transform([message])

        # Predict the label and its confidence
        label = int(self.model.predict(message_vector)[0])
        confidence = float(self.model.predict_proba(message_vector)[0][label])

        # Convert the numeric label to text
        label_text = "spam" if label == 1 else "ham"

        return {
            "label": label_text,
            "confidence": confidence
        }


class SpamAnalysisTool:
    """Spam message analysis tool."""

    name = "spam_analysis_tool"
    description = "Analyze whether a message is spam and provide an explanation and a suggestion"

    def __init__(self, classifier: SpamClassifier):
        self.classifier = classifier

    async def __call__(self, message: str, is_english: bool = True) -> AnalysisResult:
        """Analyze a message."""
        # Translate to Chinese if the message is in English
        message_zh = translate_text(message, "zh-CN") if is_english else message

        # Classify the message
        classification = self.classifier.classify(message)

        # Generate an explanation and a suggestion
        explanation = self.generate_explanation(message, classification["label"])

        return AnalysisResult(
            message=message,
            message_zh=message_zh,
            classification=ClassificationResult(
                label=classification["label"],
                confidence=classification["confidence"]
            ),
            explanation=explanation
        )

    def generate_explanation(self, message: str, label: str) -> Explanation:
        """Generate an explanation and a suggestion."""
        # Simple keyword extraction (a real project could use a more sophisticated method)
        key_words = self.extract_keywords(message)

        # Generate the reason and the suggestion
        if label == "spam":
            reason = f"The message contains typical spam keywords: {', '.join(key_words)}"
            suggestion = "Delete the message immediately; do not click any links or reply, to avoid being scammed"
        else:
            reason = f"The message is legitimate and contains common vocabulary: {', '.join(key_words)}"
            suggestion = "The message can be replied to and handled normally"

        return Explanation(
            key_words=key_words,
            reason=reason,
            suggestion=suggestion
        )

    def extract_keywords(self, message: str, top_n: int = 5) -> List[str]:
        """Extract keywords."""
        words = message.lower().split()

        # Filter stop words using the fitted vectorizer's stop word list
        stop_words = set(self.classifier.vectorizer.get_stop_words() or [])
        keywords = [word for word in words if word not in stop_words and len(word) > 2]

        # Return only the first top_n keywords
        return keywords[:top_n]


class ModelEvaluationTool:
    """Model evaluation tool."""

    name = "model_evaluation_tool"
    description = "Evaluate model performance on a given dataset"

    def __init__(self, classifier: SpamClassifier):
        self.classifier = classifier

    async def __call__(self, test_data: List[str], labels: List[str]) -> Dict[str, float]:
        """Evaluate model performance."""
        # Reuse the already-fitted vectorizer to transform the test data
        test_vectors = self.classifier.vectorizer.transform(test_data)

        # Predict
        predictions = self.classifier.model.predict(test_vectors)
        predictions_proba = self.classifier.model.predict_proba(test_vectors)[:, 1]

        # Convert the labels to numeric values
        y_true = [1 if label == "spam" else 0 for label in labels]

        # Compute the evaluation metrics
        from sklearn.metrics import (
            accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
        )

        metrics = {
            "accuracy": accuracy_score(y_true, predictions),
            "precision": precision_score(y_true, predictions),
            "recall": recall_score(y_true, predictions),
            "f1": f1_score(y_true, predictions),
            "roc_auc": roc_auc_score(y_true, predictions_proba)
        }

        return metrics


class SpamAnalysisAgent:
    """Spam analysis agent."""

    def __init__(self, model_name: str = "lightgbm"):
        """Initialize the agent."""
        # Create the classifier
        self.classifier = SpamClassifier(model_name)

        # Wrap the callables as pydantic-ai tools
        analysis = SpamAnalysisTool(self.classifier)
        evaluation = ModelEvaluationTool(self.classifier)
        self.tools = [
            Tool(analysis.__call__, name=analysis.name, description=analysis.description),
            Tool(evaluation.__call__, name=evaluation.name, description=evaluation.description)
        ]

        # DeepSeek exposes an OpenAI-compatible API, so use the OpenAI model class
        model = OpenAIModel(
            settings.deepseek_model,
            provider=OpenAIProvider(
                base_url=settings.deepseek_api_base,
                api_key=settings.deepseek_api_key
            )
        )
        self.agent = Agent(model, output_type=AnalysisResult, tools=self.tools)

    async def analyze_message(self, message: str, is_english: bool = True) -> AnalysisResult:
        """Analyze a single message."""
        result = await self.agent.run(f"Analyze the following message: {message}")
        return result.output

    async def batch_analyze(self, messages: List[str], is_english: bool = True) -> List[AnalysisResult]:
        """Analyze messages in batch."""
        results = []
        for message in messages:
            result = await self.analyze_message(message, is_english)
            results.append(result)

        return results


async def main():
    """Agent entry point."""
    # Create the agent instance
    agent = SpamAnalysisAgent()

    # Test messages
    test_messages = [
        "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's",
        "Ok lar... Joking wif u oni...",
        "WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only."
    ]

    # Analyze the messages
    for message in test_messages:
        print("\n=== Message analysis ===")
        print(f"Original message: {message}")
        result = await agent.analyze_message(message)
        print(f"Classification: {result.classification.label} (confidence: {result.classification.confidence:.2f})")
        print(f"Chinese translation: {result.message_zh}")
        print(f"Key words: {', '.join(result.explanation.key_words)}")
        print(f"Reason: {result.explanation.reason}")
        print(f"Suggestion: {result.explanation.suggestion}")


if __name__ == "__main__":
    asyncio.run(main())
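Because the tools above are plain async callables before being wrapped for pydantic-ai, they can also be exercised directly, which is handy for checking the plan's F1 ≥ 0.70 / ROC-AUC ≥ 0.75 target without an LLM call. A minimal sketch, assuming models/ already holds the artifacts produced by machine_learning.main(); the two sample messages and labels are placeholders:

import asyncio

from agent import ModelEvaluationTool, SpamClassifier


async def check_metrics():
    # Invoke the evaluation tool directly, bypassing the agent loop
    tool = ModelEvaluationTool(SpamClassifier("lightgbm"))
    metrics = await tool(
        ["WINNER!! You have been selected for a prize", "See you at dinner tonight"],
        ["spam", "ham"],
    )
    print(metrics)  # accuracy / precision / recall / f1 / roc_auc


asyncio.run(check_metrics())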
29 src/config.py Normal file
@@ -0,0 +1,29 @@
import os

from pydantic_settings import BaseSettings, SettingsConfigDict


class Settings(BaseSettings):
    """Project settings."""

    model_config = SettingsConfigDict(
        env_file=os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env"),
        env_file_encoding="utf-8",
    )

    # DeepSeek API configuration
    deepseek_api_key: str
    deepseek_api_base: str = "https://api.deepseek.com"
    deepseek_model: str = "deepseek-chat"

    # Project path configuration
    model_save_path: str = "./models"
    data_path: str = "./data"

    # Model configuration
    random_state: int = 42
    test_size: float = 0.2


# Global settings instance
settings = Settings()
76 src/data_processing.py Normal file
@@ -0,0 +1,76 @@
from typing import Tuple

import polars as pl


def load_data(file_path: str) -> pl.DataFrame:
    """Load the dataset with Polars."""
    # Load the csv file, handling the non-UTF-8 encoding
    df = pl.read_csv(
        file_path,
        encoding="latin-1",
        ignore_errors=True,
        has_header=True
    )
    return df


def clean_data(df: pl.DataFrame) -> pl.DataFrame:
    """Clean the dataset."""
    # Inspect the raw dataset
    print("Raw dataset shape:", df.shape)
    print("Raw dataset columns:", df.columns)

    # Drop the unnecessary columns (the last three are empty)
    df = df.drop(df.columns[-3:])

    # Rename the columns
    df = df.rename({
        "v1": "label",
        "v2": "message"
    })

    # Inspect the cleaned dataset
    print("Cleaned dataset shape:", df.shape)
    print("Cleaned dataset columns:", df.columns)
    print("Label distribution:", df["label"].value_counts())

    return df


def preprocess_data(df: pl.DataFrame) -> Tuple[pl.DataFrame, pl.Series]:
    """Preprocess the data for model training."""
    # Convert the labels to numeric values (ham=0, spam=1)
    df = df.with_columns(
        pl.when(pl.col("label") == "spam").then(1).otherwise(0).alias("label")
    )

    # Separate features and labels
    X = df.drop("label")
    y = df["label"]

    return X, y


def save_data(df: pl.DataFrame, file_path: str) -> None:
    """Save the processed dataset."""
    df.write_csv(file_path)
    print(f"Dataset saved to: {file_path}")


if __name__ == "__main__":
    # Smoke-test the data processing pipeline
    import os

    file_path = "../spam.csv"
    # Fall back if the file is not one directory up
    if not os.path.exists(file_path):
        file_path = "./spam.csv"
    df = load_data(file_path)
    df_cleaned = clean_data(df)
    X, y = preprocess_data(df_cleaned)

    print("Feature data shape:", X.shape)
    print("Label data shape:", y.shape)
    print("First 5 rows:")
    print(df_cleaned.head())
316 src/machine_learning.py Normal file
@@ -0,0 +1,316 @@
from pathlib import Path
from typing import Any, Dict, Tuple

import joblib
import lightgbm as lgb
import polars as pl
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, classification_report, confusion_matrix
)
from sklearn.model_selection import train_test_split

from config import settings


class SpamClassifier:
    """Spam message classifier."""
    def __init__(self, model_name: str = "lightgbm"):
        """Initialize the classifier."""
        self.model_name = model_name
        self.model = None
        self.vectorizer = None
        self.load_model()

    def load_model(self):
        """Load the model and vectorizer."""
        model_dir = Path(settings.model_save_path)

        # Load the model
        model_path = model_dir / f"{self.model_name}_model.joblib"
        self.model = joblib.load(model_path)
        print(f"Model loaded from: {model_path}")

        # Load the vectorizer
        vectorizer_path = model_dir / f"{self.model_name}_vectorizer.joblib"
        self.vectorizer = joblib.load(vectorizer_path)
        print(f"Vectorizer loaded from: {vectorizer_path}")

    def classify(self, message: str) -> Dict[str, Any]:
        """Classify a single message."""
        # Vectorize the message
        message_vector = self.vectorizer.transform([message])

        # Predict the label and its confidence
        label = int(self.model.predict(message_vector)[0])
        confidence = float(self.model.predict_proba(message_vector)[0][label])

        # Convert the numeric label to text
        label_text = "spam" if label == 1 else "ham"

        return {
            "label": label_text,
            "confidence": confidence
        }


def extract_features(
    X_train: pl.Series,
    X_test: pl.Series,
    max_features: int = 1000
) -> Tuple[Any, Any, TfidfVectorizer]:
    """
    Extract text features with TF-IDF.

    Args:
        X_train: training set texts
        X_test: test set texts
        max_features: maximum number of features

    Returns:
        Training set features, test set features, and the TF-IDF vectorizer
    """
    # Convert the Polars Series to Pandas Series
    X_train_pd = X_train.to_pandas()
    X_test_pd = X_test.to_pandas()

    # Initialize the TF-IDF vectorizer
    tfidf = TfidfVectorizer(
        max_features=max_features,
        stop_words="english",
        ngram_range=(1, 2)
    )

    # Fit and transform the training set
    X_train_tfidf = tfidf.fit_transform(X_train_pd)

    # Transform the test set
    X_test_tfidf = tfidf.transform(X_test_pd)

    return X_train_tfidf, X_test_tfidf, tfidf


def train_logistic_regression(
    X_train: Any,
    y_train: pl.Series
) -> LogisticRegression:
    """
    Train a Logistic Regression model.

    Args:
        X_train: training set features
        y_train: training set labels

    Returns:
        The trained Logistic Regression model
    """
    # Convert the Polars Series to a Pandas Series
    y_train_pd = y_train.to_pandas()

    # Initialize the Logistic Regression model
    log_reg = LogisticRegression(
        random_state=settings.random_state,
        max_iter=1000,
        class_weight="balanced"
    )

    # Train the model
    log_reg.fit(X_train, y_train_pd)

    return log_reg


def train_lightgbm(
    X_train: Any,
    y_train: pl.Series
) -> lgb.LGBMClassifier:
    """
    Train a LightGBM model.

    Args:
        X_train: training set features
        y_train: training set labels

    Returns:
        The trained LightGBM model
    """
    # Convert the Polars Series to a Pandas Series
    y_train_pd = y_train.to_pandas()

    # Initialize the LightGBM model
    lgb_clf = lgb.LGBMClassifier(
        random_state=settings.random_state,
        class_weight="balanced",
        n_estimators=1000,
        learning_rate=0.1,
        num_leaves=31
    )

    # Train the model
    lgb_clf.fit(X_train, y_train_pd)

    return lgb_clf


def evaluate_model(
    model: Any,
    X_test: Any,
    y_test: pl.Series
) -> Dict[str, float]:
    """
    Evaluate model performance.

    Args:
        model: the trained model
        X_test: test set features
        y_test: test set labels

    Returns:
        Model evaluation metrics
    """
    # Convert the Polars Series to a Pandas Series
    y_test_pd = y_test.to_pandas()

    # Predict
    y_pred = model.predict(X_test)
    y_pred_proba = model.predict_proba(X_test)[:, 1] if hasattr(model, 'predict_proba') else None

    # Compute the evaluation metrics
    metrics = {
        "accuracy": accuracy_score(y_test_pd, y_pred),
        "precision": precision_score(y_test_pd, y_pred),
        "recall": recall_score(y_test_pd, y_pred),
        "f1": f1_score(y_test_pd, y_pred)
    }

    # Compute ROC-AUC (if the model supports probability predictions)
    if y_pred_proba is not None:
        metrics["roc_auc"] = roc_auc_score(y_test_pd, y_pred_proba)

    # Print the classification report and the confusion matrix
    print("Classification report:")
    print(classification_report(y_test_pd, y_pred))

    print("Confusion matrix:")
    print(confusion_matrix(y_test_pd, y_pred))

    return metrics


def save_model(
    model: Any,
    model_name: str,
    vectorizer: Any = None
) -> None:
    """
    Save the model and vectorizer.

    Args:
        model: the trained model
        model_name: the model name
        vectorizer: the TF-IDF vectorizer
    """
    # Create the model directory
    model_dir = Path(settings.model_save_path)
    model_dir.mkdir(exist_ok=True)

    # Save the model
    model_path = model_dir / f"{model_name}_model.joblib"
    joblib.dump(model, model_path)
    print(f"Model saved to: {model_path}")

    # Save the vectorizer (if provided)
    if vectorizer is not None:
        vectorizer_path = model_dir / f"{model_name}_vectorizer.joblib"
        joblib.dump(vectorizer, vectorizer_path)
        print(f"Vectorizer saved to: {vectorizer_path}")


def load_model(
    model_name: str
) -> Tuple[Any, Any]:
    """
    Load the model and vectorizer.

    Args:
        model_name: the model name

    Returns:
        The loaded model and vectorizer
    """
    model_dir = Path(settings.model_save_path)

    # Load the model
    model_path = model_dir / f"{model_name}_model.joblib"
    model = joblib.load(model_path)
    print(f"Model loaded from: {model_path}")

    # Load the vectorizer
    vectorizer_path = model_dir / f"{model_name}_vectorizer.joblib"
    vectorizer = joblib.load(vectorizer_path)
    print(f"Vectorizer loaded from: {vectorizer_path}")

    return model, vectorizer


def main():
    """Machine learning entry point."""
    # 1. Load the dataset
    print("Loading the dataset...")
    df = pl.read_csv("../spam.csv", encoding="latin-1", ignore_errors=True)

    # 2. Clean the dataset
    print("Cleaning the dataset...")
    df = df.drop(df.columns[-3:])
    df = df.rename({"v1": "label", "v2": "message"})
    df = df.with_columns(
        pl.when(pl.col("label") == "spam").then(1).otherwise(0).alias("label")
    )

    # 3. Separate features and labels
    X = df["message"]
    y = df["label"]

    # 4. Split into training and test sets
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=settings.test_size, random_state=settings.random_state, stratify=y
    )

    print(f"Training set size: {len(X_train)}")
    print(f"Test set size: {len(X_test)}")

    # 5. Feature extraction
    print("Extracting features...")
    X_train_tfidf, X_test_tfidf, tfidf = extract_features(X_train, X_test)

    # 6. Train the Logistic Regression model
    print("\nTraining the Logistic Regression model...")
    log_reg_model = train_logistic_regression(X_train_tfidf, y_train)

    # 7. Evaluate the Logistic Regression model
    print("\nEvaluating the Logistic Regression model:")
    log_reg_metrics = evaluate_model(log_reg_model, X_test_tfidf, y_test)
    print(f"Logistic Regression metrics: {log_reg_metrics}")

    # 8. Train the LightGBM model
    print("\nTraining the LightGBM model...")
    lgb_model = train_lightgbm(X_train_tfidf, y_train)

    # 9. Evaluate the LightGBM model
    print("\nEvaluating the LightGBM model:")
    lgb_metrics = evaluate_model(lgb_model, X_test_tfidf, y_test)
    print(f"LightGBM metrics: {lgb_metrics}")

    # 10. Save the models
    print("\nSaving the models...")
    save_model(log_reg_model, "logistic_regression", tfidf)
    save_model(lgb_model, "lightgbm", tfidf)

    print("\nMachine learning pipeline finished!")


if __name__ == "__main__":
    main()
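A minimal usage sketch of the saved artifacts, assuming main() has already run and written lightgbm_model.joblib and lightgbm_vectorizer.joblib under settings.model_save_path; the sample text is a placeholder:

from machine_learning import load_model

# Reload the persisted model and vectorizer, then score one message
model, vectorizer = load_model("lightgbm")
vec = vectorizer.transform(["Free entry in 2 a wkly comp to win FA Cup final tkts"])
print("spam" if int(model.predict(vec)[0]) == 1 else "ham")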
24 src/main.py Normal file
@@ -0,0 +1,24 @@
from data_processing import load_data, clean_data
from translation import translate_dataset


def main():
    """Entry point."""
    # 1. Load the dataset
    print("Loading the dataset...")
    df = load_data("../spam.csv")

    # 2. Clean the dataset
    print("\nCleaning the dataset...")
    df_cleaned = clean_data(df)

    # 3. Translate only the first 10 messages as a test
    print("\nTranslating the first 10 messages as a test...")
    df_test = df_cleaned.head(10)
    translated_path = translate_dataset(df_test)

    print(f"\nTest finished! The translated test dataset is saved to: {translated_path}")


if __name__ == "__main__":
    main()
150 src/simple_agent.py Normal file
@@ -0,0 +1,150 @@
from typing import Any, Dict, List, Tuple

import requests

from config import settings
from machine_learning import SpamClassifier
from translation import translate_text


class SimpleSpamAnalysis:
    """Simple spam message analysis system."""

    def __init__(self, model_name: str = "lightgbm"):
        """Initialize the analysis system."""
        self.classifier = SpamClassifier(model_name)

    def analyze(self, message: str, is_english: bool = True) -> Dict[str, Any]:
        """Analyze a single message."""
        # 1. Translate the message
        message_zh = translate_text(message, "zh-CN") if is_english else message

        # 2. Classify the message
        classification = self.classifier.classify(message)

        # 3. Extract keywords
        key_words = self.extract_keywords(message)

        # 4. Generate a basic explanation and suggestion
        reason, suggestion = self.generate_explanation(key_words, classification["label"])

        # 5. Generate a more detailed explanation with the DeepSeek API
        detailed_explanation = self.generate_detailed_explanation(
            message, message_zh, classification["label"], key_words
        )

        return {
            "original_message": message,
            "translated_message": message_zh,
            "classification": classification,
            "key_words": key_words,
            "reason": reason,
            "suggestion": suggestion,
            "detailed_explanation": detailed_explanation
        }

    def extract_keywords(self, message: str, top_n: int = 5) -> List[str]:
        """Extract keywords."""
        words = message.lower().split()
        stop_words = set([
            "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
            "with", "by", "from", "up", "down", "about", "above", "below", "of",
            "is", "are", "was", "were", "be", "been", "being", "have", "has",
            "had", "do", "does", "did", "will", "would", "shall", "should",
            "may", "might", "must", "can", "could", "not", "no", "yes", "if",
            "then", "than", "so", "because", "as", "when", "where", "who", "which",
            "that", "this", "these", "those", "i", "me", "my", "mine", "you",
            "your", "yours", "he", "him", "his", "she", "her", "hers", "it",
            "its", "we", "us", "our", "ours", "they", "them", "their", "theirs"
        ])

        keywords = [word for word in words if word not in stop_words and len(word) > 2]
        return keywords[:top_n]

    def generate_explanation(self, key_words: List[str], label: str) -> Tuple[str, str]:
        """Generate a basic explanation and suggestion."""
        if label == "spam":
            reason = f"The message contains typical spam keywords: {', '.join(key_words)}"
            suggestion = "Delete the message immediately; do not click any links or reply, to avoid being scammed"
        else:
            reason = f"The message is legitimate and contains common vocabulary: {', '.join(key_words)}"
            suggestion = "The message can be replied to and handled normally"
        return reason, suggestion

    def generate_detailed_explanation(self, message: str, message_zh: str, label: str, key_words: List[str]) -> str:
        """Generate a detailed explanation with the DeepSeek API."""
        headers = {
            "Authorization": f"Bearer {settings.deepseek_api_key}",
            "Content-Type": "application/json"
        }

        prompt = f"""
Analyze the following message:
English: {message}
Chinese: {message_zh}
Classification: {label}
Keywords: {', '.join(key_words)}

Please provide:
1. A detailed reason for the classification
2. The main characteristics of the message
3. Concrete advice for handling this message
4. How to recognize similar messages

Answer in Chinese and keep it concise.
"""

        payload = {
            "model": settings.deepseek_model,
            "messages": [
                {
                    "role": "system",
                    "content": "You are a professional spam message analyst. Analyze the given information in detail."
                },
                {
                    "role": "user",
                    "content": prompt
                }
            ],
            "max_tokens": 500,
            "temperature": 0.1
        }

        try:
            response = requests.post(
                f"{settings.deepseek_api_base}/chat/completions",
                headers=headers,
                json=payload,
                timeout=30
            )
            response.raise_for_status()

            result = response.json()
            explanation = result["choices"][0]["message"]["content"].strip()
            return explanation
        except requests.exceptions.RequestException as e:
            print(f"Failed to generate the detailed explanation: {e}")
            return "Could not generate a detailed explanation; check the API connection."


if __name__ == "__main__":
    # Initialize the analysis system
    analyzer = SimpleSpamAnalysis()

    # Test messages
    test_messages = [
        "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's",
        "Ok lar... Joking wif u oni...",
        "WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only."
    ]

    # Analyze the messages
    for i, message in enumerate(test_messages):
        print(f"\n=== Message analysis result {i+1} ===")
        result = analyzer.analyze(message)

        print(f"Original message: {result['original_message']}")
        print(f"Chinese translation: {result['translated_message']}")
        print(f"Classification: {result['classification']['label']} (confidence: {result['classification']['confidence']:.2f})")
        print(f"Keywords: {', '.join(result['key_words'])}")
        print(f"Reason: {result['reason']}")
        print(f"Suggestion: {result['suggestion']}")
        print(f"Detailed explanation: {result['detailed_explanation']}")
130 src/translation.py Normal file
@@ -0,0 +1,130 @@
import time
from typing import List

import requests

from config import settings


def translate_text(text: str, target_lang: str = "zh-CN") -> str:
    """
    Translate text to the target language with the DeepSeek API.

    Args:
        text: the text to translate
        target_lang: the target language, Chinese (zh-CN) by default

    Returns:
        The translated text
    """
    headers = {
        "Authorization": f"Bearer {settings.deepseek_api_key}",
        "Content-Type": "application/json"
    }

    payload = {
        "model": settings.deepseek_model,
        "messages": [
            {
                "role": "system",
                "content": f"You are a professional translator. Translate the following text to {target_lang}. Keep the original meaning and tone. Do not add any additional information."
            },
            {
                "role": "user",
                "content": text
            }
        ],
        "max_tokens": 1000,
        "temperature": 0.1
    }

    try:
        response = requests.post(
            f"{settings.deepseek_api_base}/chat/completions",
            headers=headers,
            json=payload,
            timeout=30
        )
        response.raise_for_status()

        result = response.json()
        translated_text = result["choices"][0]["message"]["content"].strip()
        return translated_text
    except requests.exceptions.RequestException as e:
        print(f"Translation failed: {e}")
        return text


def translate_batch(texts: List[str], target_lang: str = "zh-CN", batch_size: int = 10) -> List[str]:
    """
    Translate texts in batches.

    Args:
        texts: the texts to translate
        target_lang: the target language, Chinese (zh-CN) by default
        batch_size: the batch size, 10 by default

    Returns:
        The translated texts
    """
    translated_texts = []

    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        batch_translated = []

        for text in batch:
            translated = translate_text(text, target_lang)
            batch_translated.append(translated)
            # Sleep briefly to avoid API rate limits
            time.sleep(0.5)

        translated_texts.extend(batch_translated)
        print(f"Translated {min(i+batch_size, len(texts))}/{len(texts)} texts")

    return translated_texts


def translate_dataset(df, text_column: str = "message", target_column: str = "message_zh") -> str:
    """
    Translate a text column of a dataset.

    Args:
        df: a Polars DataFrame
        text_column: the column to translate
        target_column: the column for the translated text

    Returns:
        The file path of the translated dataset
    """
    import os

    import polars as pl

    # Create the data directory if it does not exist
    os.makedirs(settings.data_path, exist_ok=True)

    # Extract the texts
    texts = df[text_column].to_list()

    # Translate the texts
    print(f"Translating {len(texts)} texts...")
    translated_texts = translate_batch(texts)

    # Add the translated column to the dataset
    df = df.with_columns(
        pl.Series(target_column, translated_texts)
    )

    # Save the translated dataset
    output_path = f"{settings.data_path}/spam_zh.csv"
    df.write_csv(output_path)

    print(f"The translated dataset is saved to: {output_path}")
    print(f"Translation finished! Translated {len(texts)} texts in total")
    return output_path


if __name__ == "__main__":
    # Test the translation function
    test_text = "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's"
    translated = translate_text(test_text)
    print(f"Source: {test_text}")
    print(f"Translation: {translated}")
31 test_analysis.py Normal file
@@ -0,0 +1,31 @@
import os
import sys

# Add the src directory to the Python path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from simple_agent import SimpleSpamAnalysis


# Test messages
test_messages = [
    "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's",
    "Ok lar... Joking wif u oni...",
    "WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only."
]

# Initialize the analysis system
analyzer = SimpleSpamAnalysis()

# Analyze the messages
for i, message in enumerate(test_messages):
    print(f"\n=== Message analysis result {i+1} ===")
    result = analyzer.analyze(message)

    print(f"Original message: {result['original_message'][:100]}...")
    print(f"Chinese translation: {result['translated_message'][:100]}...")
    print(f"Classification: {result['classification']['label']} (confidence: {result['classification']['confidence']:.2f})")
    print(f"Keywords: {', '.join(result['key_words'])}")
    print(f"Reason: {result['reason']}")
    print(f"Suggestion: {result['suggestion']}")
    print(f"Detailed explanation: {result['detailed_explanation'][:200]}...")
7 test_translation.py Normal file
@@ -0,0 +1,7 @@
import os
import sys

# Add the src directory to the Python path so that `from config import settings` resolves inside translation.py
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))

from translation import translate_text


# Test the single-text translation function
test_text = "Hello, how are you?"
print(f"Source: {test_text}")
translated = translate_text(test_text)
print(f"Translation: {translated}")