diff --git a/.env.example b/.env.example
new file mode 100644
index 0000000..9afae17
--- /dev/null
+++ b/.env.example
@@ -0,0 +1,6 @@
+# DeepSeek API Configuration
+DEEPSEEK_API_KEY="your-deepseek-api-key-here"
+
+# Project Configuration
+MODEL_SAVE_PATH="./models"
+DATA_PATH="./data"
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..6ddf800
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,54 @@
+# Python
+__pycache__/
+*.py[cod]
+*$py.class
+
+# Environment
+.env
+.env.local
+.env.development.local
+.env.test.local
+.env.production.local
+
+# Dependencies
+.venv/
+venv/
+env/
+
+# Data
+data/
+*.csv
+*.parquet
+*.h5
+
+# Models
+models/
+*.joblib
+*.pkl
+*.model
+*.txt
+
+# Logs
+logs/
+*.log
+
+# Build
+dist/
+build/
+*.egg-info/
+
+# IDE
+.vscode/
+.idea/
+*.swp
+*.swo
+*~
+
+# Testing
+.pytest_cache/
+.coverage
+htmlcov/
+
+# OS
+.DS_Store
+Thumbs.db
diff --git a/.trae/documents/垃圾短信分类项目实现计划.md b/.trae/documents/垃圾短信分类项目实现计划.md
new file mode 100644
index 0000000..6cabde2
--- /dev/null
+++ b/.trae/documents/垃圾短信分类项目实现计划.md
@@ -0,0 +1,74 @@
+# Spam SMS Classification Project Implementation Plan
+
+## 1. Project Scaffolding
+- Create the project directory structure, including `src`, `data`, and `models`
+- Initialize the project dependencies, managed with uv
+- Set up configuration files and environment-variable management
+
+## 2. Data Processing
+- Load and clean the spam.csv dataset with Polars
+- Translate the English messages into Chinese with the DeepSeek API
+- Define a Pandera data schema for validation (see the sketch in the appendix below)
+- Data preprocessing and feature engineering
+
+## 3. Machine Learning Models
+- Implement at least two models: Logistic Regression as the baseline, LightGBM as the stronger model
+- Model training, validation, and evaluation
+- Model saving and loading
+- Hit the performance target of F1 ≥ 0.70 or ROC-AUC ≥ 0.75
+
+## 4. LLM Integration
+- Use the DeepSeek API to explain and attribute message classifications
+- Generate structured action recommendations
+- Keep the output traceable and reproducible
+
+## 5. Agent Framework
+- Build an agent with structured output on pydantic-ai
+- Implement at least two tools: an ML prediction tool and an evaluation tool
+- Wire up the complete tool-calling flow
+
+## 6. Testing and Deployment
+- Write unit and integration tests
+- Make sure the project runs on the instructor's machine
+- Prepare the project presentation materials
+
+## Tech Stack
+- Python 3.12
+- uv for project management
+- Polars + Pandas for data processing
+- Pandera for data validation
+- Scikit-learn + LightGBM for machine learning
+- pydantic-ai as the agent framework
+- DeepSeek API as the LLM provider
+
+## Expected Deliverables
+- A complete spam SMS classification system
+- A Chinese-translated dataset
+- Reproducible machine learning models
+- LLM-based intelligent recommendation generation
+- Structured, traceable output
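+
+## Appendix: Pandera Schema Sketch
+
+A minimal sketch of the validation step in section 2, under two assumptions:
+the cleaned frame has the `label`/`message` columns produced by `clean_data`,
+and validation goes through pandas (`.to_pandas()`), since the classic Pandera
+API targets pandas DataFrames.
+
+```python
+import pandera as pa
+
+# The cleaned dataset: label is "ham"/"spam", message is non-empty text.
+sms_schema = pa.DataFrameSchema(
+    {
+        "label": pa.Column(str, pa.Check.isin(["ham", "spam"])),
+        "message": pa.Column(str, pa.Check.str_length(min_value=1)),
+    },
+    strict=True,  # reject columns that should have been dropped in cleaning
+)
+
+
+def validate(df_polars):
+    """Raise a pandera SchemaError if the cleaned data is malformed."""
+    return sms_schema.validate(df_polars.to_pandas())
+```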
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..c2fc5c4
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,44 @@
+[project]
+name = "spam-classification"
+version = "0.1.0"
+authors = [{ name = "Your Name", email = "your.email@example.com" }]
+description = "Spam message classification with ML and LLM integration"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+    "pandas>=2.2",
+    "polars>=0.20",
+    "pandera>=0.18",
+    "scikit-learn>=1.4",
+    "lightgbm>=4.3",
+    "pydantic>=2.5",
+    "pydantic-settings>=2.2",
+    "pydantic-ai>=0.3",
+    "python-dotenv>=1.0",
+    "requests>=2.31",
+]
+
+[project.optional-dependencies]
+dev = [
+    "pytest>=7.4",
+    "ruff>=0.2",
+]
+
+[tool.uv]
+index-url = "https://mirrors.aliyun.com/pypi/simple/"
+
+[build-system]
+requires = ["hatchling"]
+build-backend = "hatchling.build"
+
+[tool.ruff]
+line-length = 88
+
+[tool.ruff.lint]
+select = ["E", "F", "W"]
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+python_files = "test_*.py"
+python_classes = "Test*"
+python_functions = "test_*"
diff --git a/simple_test.py b/simple_test.py
new file mode 100644
index 0000000..8702750
--- /dev/null
+++ b/simple_test.py
@@ -0,0 +1,55 @@
+import os
+
+import requests
+from dotenv import load_dotenv
+
+
+# Smoke-test the DeepSeek chat-completions API directly
+def test_deepseek_api():
+    load_dotenv()
+    api_key = os.environ["DEEPSEEK_API_KEY"]  # never hard-code secrets
+    api_base = "https://api.deepseek.com"
+    model = "deepseek-chat"
+
+    headers = {
+        "Authorization": f"Bearer {api_key}",
+        "Content-Type": "application/json"
+    }
+
+    payload = {
+        "model": model,
+        "messages": [
+            {
+                "role": "system",
+                "content": "You are a professional translator. Translate the following text to Chinese. Keep the original meaning and tone. Do not add any additional information."
+            },
+            {
+                "role": "user",
+                "content": "Hello, how are you?"
+            }
+        ],
+        "max_tokens": 1000,
+        "temperature": 0.1
+    }
+
+    try:
+        response = requests.post(
+            f"{api_base}/chat/completions",
+            headers=headers,
+            json=payload,
+            timeout=30
+        )
+        response.raise_for_status()
+
+        result = response.json()
+        print("API response:", result)
+        translated_text = result["choices"][0]["message"]["content"].strip()
+        print(f"Translation: {translated_text}")
+        return translated_text
+    except requests.exceptions.RequestException as e:
+        print(f"Translation failed: {e}")
+        return None
+
+
+if __name__ == "__main__":
+    test_deepseek_api()
diff --git a/src/agent.py b/src/agent.py
new file mode 100644
index 0000000..9013f94
--- /dev/null
+++ b/src/agent.py
@@ -0,0 +1,171 @@
+import asyncio
+from typing import Any, Dict, List
+
+from pydantic import BaseModel, Field
+from pydantic_ai import Agent, Tool
+from pydantic_ai.models.openai import OpenAIModel
+from pydantic_ai.providers.openai import OpenAIProvider
+
+from config import settings
+from machine_learning import SpamClassifier
+from translation import translate_text
+
+
+class ClassificationResult(BaseModel):
+    """Classification result."""
+    label: str = Field(..., description="Predicted label, ham or spam")
+    confidence: float = Field(..., description="Prediction confidence")
+
+
+class Explanation(BaseModel):
+    """Explanation of a classification."""
+    key_words: List[str] = Field(..., description="Key feature words")
+    reason: str = Field(..., description="Why the message was classified this way")
+    suggestion: str = Field(..., description="Recommended action")
+
+
+class AnalysisResult(BaseModel):
+    """Complete analysis result."""
+    message: str = Field(..., description="Original message")
+    message_zh: str = Field(..., description="Chinese translation")
+    classification: ClassificationResult = Field(..., description="Classification result")
+    explanation: Explanation = Field(..., description="Explanation and recommendation")
+
+
+# One shared classifier instance backs both tools
+_classifier = SpamClassifier("lightgbm")
+
+
+def extract_keywords(message: str, top_n: int = 5) -> List[str]:
+    """Naive keyword extraction (a real project could use something stronger)."""
+    words = message.lower().split()
+    # Reuse the vectorizer's stop-word list when it has one
+    stop_words = set(_classifier.vectorizer.get_stop_words() or [])
+    keywords = [word for word in words if word not in stop_words and len(word) > 2]
+    return keywords[:top_n]
+
+
+def generate_explanation(message: str, label: str) -> Explanation:
+    """Build a rule-based explanation and recommendation."""
+    key_words = extract_keywords(message)
+
+    if label == "spam":
+        reason = f"The message contains spam indicator words: {', '.join(key_words)}"
+        suggestion = (
+            "Delete the message immediately; do not click any links or reply, "
+            "to avoid being scammed"
+        )
+    else:
+        reason = f"The message looks legitimate and uses ordinary vocabulary: {', '.join(key_words)}"
+        suggestion = "It is safe to reply to and handle this message normally"
+
+    return Explanation(key_words=key_words, reason=reason, suggestion=suggestion)
+
+
+async def analyze_message_tool(message: str, is_english: bool = True) -> AnalysisResult:
+    """Tool: classify one SMS message, translate it, and explain the decision."""
+    # Translate English messages to Chinese
+    message_zh = translate_text(message, "zh-CN") if is_english else message
+
+    # Classify with the trained ML model
+    classification = _classifier.classify(message)
+
+    # Generate the explanation and recommendation
+    explanation = generate_explanation(message, classification["label"])
+
+    return AnalysisResult(
+        message=message,
+        message_zh=message_zh,
+        classification=ClassificationResult(
+            label=classification["label"],
+            confidence=classification["confidence"],
+        ),
+        explanation=explanation,
+    )
+
+
+def evaluate_model_tool(test_data: List[str], labels: List[str]) -> Dict[str, float]:
+    """Tool: evaluate the trained model on a labelled message set."""
+    from sklearn.metrics import (
+        accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
+    )
+
+    # Vectorize with the fitted vectorizer and predict
+    test_vectors = _classifier.vectorizer.transform(test_data)
+    predictions = _classifier.model.predict(test_vectors)
+    predictions_proba = _classifier.model.predict_proba(test_vectors)[:, 1]
+
+    # Encode the text labels as integers
+    y_true = [1 if label == "spam" else 0 for label in labels]
+
+    return {
+        "accuracy": accuracy_score(y_true, predictions),
+        "precision": precision_score(y_true, predictions),
+        "recall": recall_score(y_true, predictions),
+        "f1": f1_score(y_true, predictions),
+        "roc_auc": roc_auc_score(y_true, predictions_proba),
+    }
+
+
+class SpamAnalysisAgent:
+    """Spam SMS analysis agent built on pydantic-ai."""
+
+    def __init__(self):
+        """Initialize the agent."""
+        # DeepSeek exposes an OpenAI-compatible endpoint, so it is wired
+        # through pydantic-ai's OpenAI model and provider classes
+        model = OpenAIModel(
+            settings.deepseek_model,
+            provider=OpenAIProvider(
+                base_url=settings.deepseek_api_base,
+                api_key=settings.deepseek_api_key,
+            ),
+        )
+        self.agent = Agent(
+            model,
+            output_type=AnalysisResult,
+            tools=[
+                Tool(analyze_message_tool, description="Classify, translate, and explain one SMS message"),
+                Tool(evaluate_model_tool, description="Evaluate the ML model on a labelled dataset"),
+            ],
+            system_prompt=(
+                "You are a spam SMS analyst. Use the tools to classify the "
+                "message, then return the structured analysis result."
+            ),
+        )
+
+    async def analyze_message(self, message: str) -> AnalysisResult:
+        """Analyze a single message."""
+        result = await self.agent.run(f"Analyze this SMS message: {message}")
+        return result.output
+
+    async def batch_analyze(self, messages: List[str]) -> List[AnalysisResult]:
+        """Analyze a list of messages sequentially."""
+        return [await self.analyze_message(message) for message in messages]
+
+
+async def main():
+    """Agent entry point."""
+    agent = SpamAnalysisAgent()
+
+    # Test messages
+    test_messages = [
+        "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's",
+        "Ok lar... Joking wif u oni...",
+        "WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only.",
+    ]
+
+    # Analyze each message
+    for message in test_messages:
+        print("\n=== Message analysis ===")
+        print(f"Original message: {message}")
+        result = await agent.analyze_message(message)
+        print(f"Classification: {result.classification.label} (confidence: {result.classification.confidence:.2f})")
+        print(f"Chinese translation: {result.message_zh}")
+        print(f"Key words: {', '.join(result.explanation.key_words)}")
+        print(f"Reason: {result.explanation.reason}")
+        print(f"Suggestion: {result.explanation.suggestion}")
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/src/config.py b/src/config.py
new file mode 100644
index 0000000..75f13a5
--- /dev/null
+++ b/src/config.py
@@ -0,0 +1,33 @@
+import os
+
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+
+class Settings(BaseSettings):
+    """Project settings."""
+    # DeepSeek API configuration
+    deepseek_api_key: str
+
+    # Project paths
+    model_save_path: str = "./models"
+    data_path: str = "./data"
+
+    # Model configuration
+    random_state: int = 42
+    test_size: float = 0.2
+
+    # DeepSeek endpoint and model name
+    deepseek_api_base: str = "https://api.deepseek.com"
+    deepseek_model: str = "deepseek-chat"
+
+    # Load .env from the project root (one level above src/)
+    model_config = SettingsConfigDict(
+        env_file=os.path.join(
+            os.path.dirname(os.path.dirname(os.path.abspath(__file__))), ".env"
+        ),
+        env_file_encoding="utf-8",
+    )
+
+
+# Global settings instance
+settings = Settings()
diff --git a/src/data_processing.py b/src/data_processing.py
new file mode 100644
index 0000000..48369b1
--- /dev/null
+++ b/src/data_processing.py
@@ -0,0 +1,76 @@
+import os
+from typing import Tuple
+
+import polars as pl
+
+
+def load_data(file_path: str) -> pl.DataFrame:
+    """Load the dataset with Polars."""
+    # The original Kaggle file is latin-1 encoded, so read it accordingly
+    df = pl.read_csv(
+        file_path,
+        encoding="latin-1",
+        ignore_errors=True,
+        has_header=True
+    )
+    return df
+
+
+def clean_data(df: pl.DataFrame) -> pl.DataFrame:
+    """Clean the dataset."""
+    # Inspect the raw dataset
+    print("Raw dataset shape:", df.shape)
+    print("Raw dataset columns:", df.columns)
+
+    # Drop the unnecessary columns (the last three are empty)
+    df = df.drop(df.columns[-3:])
+
+    # Rename the columns
+    df = df.rename({
+        "v1": "label",
+        "v2": "message"
+    })
+
+    # Inspect the cleaned dataset
+    print("Cleaned dataset shape:", df.shape)
+    print("Cleaned dataset columns:", df.columns)
+    print("Label distribution:", df["label"].value_counts())
+
+    return df
+
+
+def preprocess_data(df: pl.DataFrame) -> Tuple[pl.DataFrame, pl.Series]:
+    """Preprocess the data for model training."""
+    # Encode the labels as integers (ham=0, spam=1)
+    df = df.with_columns(
+        pl.when(pl.col("label") == "spam").then(1).otherwise(0).alias("label")
+    )
+
+    # Split features and labels
+    X = df.drop("label")
+    y = df["label"]
+
+    return X, y
+
+
+def save_data(df: pl.DataFrame, file_path: str) -> None:
+    """Save the processed dataset."""
+    # Polars' write_csv takes no pandas-style index argument
+    df.write_csv(file_path)
+    print(f"Dataset saved to: {file_path}")
+
+
+if __name__ == "__main__":
+    # Exercise the data-processing pipeline
+    file_path = "../spam.csv"
+    # Fall back to the current directory if the file is missing
+    if not os.path.exists(file_path):
+        file_path = "./spam.csv"
+    df = load_data(file_path)
+    df_cleaned = clean_data(df)
+    X, y = preprocess_data(df_cleaned)
+
+    print("Feature data shape:", X.shape)
+    print("Label data shape:", y.shape)
+    print("First 5 rows:")
+    print(df_cleaned.head())
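diff --git a/tests/test_data_processing.py b/tests/test_data_processing.py
new file mode 100644
--- /dev/null
+++ b/tests/test_data_processing.py
@@ -0,0 +1,36 @@
+"""A pytest sketch for the data-cleaning step (plan section 6).
+
+Assumes pytest is installed and tests run from the repository root; the
+src/ directory is put on sys.path the same way test_analysis.py does it.
+"""
+import os
+import sys
+
+import polars as pl
+
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
+
+from data_processing import clean_data, preprocess_data
+
+
+def _raw_frame() -> pl.DataFrame:
+    # Mimic the raw spam.csv layout: v1/v2 plus three empty filler columns
+    return pl.DataFrame({
+        "v1": ["ham", "spam", "ham"],
+        "v2": ["See you soon", "WIN a free prize now", "Lunch at noon?"],
+        "c3": [None, None, None],
+        "c4": [None, None, None],
+        "c5": [None, None, None],
+    })
+
+
+def test_clean_data_drops_and_renames_columns():
+    df = clean_data(_raw_frame())
+    assert df.columns == ["label", "message"]
+    assert df.shape == (3, 2)
+
+
+def test_preprocess_data_encodes_labels():
+    X, y = preprocess_data(clean_data(_raw_frame()))
+    assert X.columns == ["message"]
+    assert y.to_list() == [0, 1, 0]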
diff --git a/src/machine_learning.py b/src/machine_learning.py
new file mode 100644
index 0000000..828200a
--- /dev/null
+++ b/src/machine_learning.py
@@ -0,0 +1,315 @@
+from pathlib import Path
+from typing import Any, Dict, Tuple
+
+import joblib
+import lightgbm as lgb
+import polars as pl
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.linear_model import LogisticRegression
+from sklearn.metrics import (
+    accuracy_score, precision_score, recall_score, f1_score,
+    roc_auc_score, classification_report, confusion_matrix
+)
+from sklearn.model_selection import train_test_split
+
+from config import settings
+
+
+class SpamClassifier:
+    """Spam SMS classifier."""
+    def __init__(self, model_name: str = "lightgbm"):
+        """Initialize the classifier."""
+        self.model_name = model_name
+        self.model = None
+        self.vectorizer = None
+        self.load_model()
+
+    def load_model(self):
+        """Load the model and the vectorizer."""
+        model_dir = Path(settings.model_save_path)
+
+        # Load the model
+        model_path = model_dir / f"{self.model_name}_model.joblib"
+        self.model = joblib.load(model_path)
+        print(f"Model loaded from: {model_path}")
+
+        # Load the vectorizer
+        vectorizer_path = model_dir / f"{self.model_name}_vectorizer.joblib"
+        self.vectorizer = joblib.load(vectorizer_path)
+        print(f"Vectorizer loaded from: {vectorizer_path}")
+
+    def classify(self, message: str) -> Dict[str, Any]:
+        """Classify a single message."""
+        # Vectorize the message
+        message_vector = self.vectorizer.transform([message])
+
+        # Predict the label, then take the probability of the predicted class
+        label = int(self.model.predict(message_vector)[0])
+        confidence = float(self.model.predict_proba(message_vector)[0][label])
+
+        # Map the numeric label to text
+        label_text = "spam" if label == 1 else "ham"
+
+        return {
+            "label": label_text,
+            "confidence": confidence
+        }
+
+
+def extract_features(
+    X_train: pl.Series,
+    X_test: pl.Series,
+    max_features: int = 1000
+) -> Tuple[Any, Any, TfidfVectorizer]:
+    """
+    Extract text features with TF-IDF.
+
+    Args:
+        X_train: Training-set texts
+        X_test: Test-set texts
+        max_features: Maximum number of features
+
+    Returns:
+        Training features, test features, and the fitted TF-IDF vectorizer
+    """
+    # Convert the Polars Series to Pandas Series
+    X_train_pd = X_train.to_pandas()
+    X_test_pd = X_test.to_pandas()
+
+    # Initialize the TF-IDF vectorizer
+    tfidf = TfidfVectorizer(
+        max_features=max_features,
+        stop_words="english",
+        ngram_range=(1, 2)
+    )
+
+    # Fit on the training set and transform it
+    X_train_tfidf = tfidf.fit_transform(X_train_pd)
+
+    # Transform the test set
+    X_test_tfidf = tfidf.transform(X_test_pd)
+
+    return X_train_tfidf, X_test_tfidf, tfidf
+
+
+def train_logistic_regression(
+    X_train: Any,
+    y_train: pl.Series
+) -> LogisticRegression:
+    """
+    Train a Logistic Regression model.
+
+    Args:
+        X_train: Training-set features
+        y_train: Training-set labels
+
+    Returns:
+        The trained Logistic Regression model
+    """
+    # Convert the Polars Series to a Pandas Series
+    y_train_pd = y_train.to_pandas()
+
+    # Initialize the model
+    log_reg = LogisticRegression(
+        random_state=settings.random_state,
+        max_iter=1000,
+        class_weight="balanced"
+    )
+
+    # Train
+    log_reg.fit(X_train, y_train_pd)
+
+    return log_reg
+
+
+def train_lightgbm(
+    X_train: Any,
+    y_train: pl.Series
+) -> lgb.LGBMClassifier:
+    """
+    Train a LightGBM model.
+
+    Args:
+        X_train: Training-set features
+        y_train: Training-set labels
+
+    Returns:
+        The trained LightGBM model
+    """
+    # Convert the Polars Series to a Pandas Series
+    y_train_pd = y_train.to_pandas()
+
+    # Initialize the model
+    lgb_clf = lgb.LGBMClassifier(
+        random_state=settings.random_state,
+        class_weight="balanced",
+        n_estimators=1000,
+        learning_rate=0.1,
+        num_leaves=31
+    )
+
+    # Train
+    lgb_clf.fit(X_train, y_train_pd)
+
+    return lgb_clf
+
+
+def evaluate_model(
+    model: Any,
+    X_test: Any,
+    y_test: pl.Series
+) -> Dict[str, float]:
+    """
+    Evaluate model performance.
+
+    Args:
+        model: The trained model
+        X_test: Test-set features
+        y_test: Test-set labels
+
+    Returns:
+        The evaluation metrics
+    """
+    # Convert the Polars Series to a Pandas Series
+    y_test_pd = y_test.to_pandas()
+
+    # Predict
+    y_pred = model.predict(X_test)
+    y_pred_proba = model.predict_proba(X_test)[:, 1] if hasattr(model, 'predict_proba') else None
+
+    # Compute the evaluation metrics
+    metrics = {
+        "accuracy": accuracy_score(y_test_pd, y_pred),
+        "precision": precision_score(y_test_pd, y_pred),
+        "recall": recall_score(y_test_pd, y_pred),
+        "f1": f1_score(y_test_pd, y_pred)
+    }
+
+    # Compute ROC-AUC if the model supports probability predictions
+    if y_pred_proba is not None:
+        metrics["roc_auc"] = roc_auc_score(y_test_pd, y_pred_proba)
+
+    # Print the classification report and confusion matrix
+    print("Classification report:")
+    print(classification_report(y_test_pd, y_pred))
+
+    print("Confusion matrix:")
+    print(confusion_matrix(y_test_pd, y_pred))
+
+    return metrics
+
+
+def save_model(
+    model: Any,
+    model_name: str,
+    vectorizer: Any = None
+) -> None:
+    """
+    Save the model and the vectorizer.
+
+    Args:
+        model: The trained model
+        model_name: Model name
+        vectorizer: The TF-IDF vectorizer
+    """
+    # Create the model directory
+    model_dir = Path(settings.model_save_path)
+    model_dir.mkdir(exist_ok=True)
+
+    # Save the model
+    model_path = model_dir / f"{model_name}_model.joblib"
+    joblib.dump(model, model_path)
+    print(f"Model saved to: {model_path}")
+
+    # Save the vectorizer if one was provided
+    if vectorizer is not None:
+        vectorizer_path = model_dir / f"{model_name}_vectorizer.joblib"
+        joblib.dump(vectorizer, vectorizer_path)
+        print(f"Vectorizer saved to: {vectorizer_path}")
+
+
+def load_model(
+    model_name: str
+) -> Tuple[Any, Any]:
+    """
+    Load the model and the vectorizer.
+
+    Args:
+        model_name: Model name
+
+    Returns:
+        The loaded model and vectorizer
+    """
+    model_dir = Path(settings.model_save_path)
+
+    # Load the model
+    model_path = model_dir / f"{model_name}_model.joblib"
+    model = joblib.load(model_path)
+    print(f"Model loaded from: {model_path}")
+
+    # Load the vectorizer
+    vectorizer_path = model_dir / f"{model_name}_vectorizer.joblib"
+    vectorizer = joblib.load(vectorizer_path)
+    print(f"Vectorizer loaded from: {vectorizer_path}")
+
+    return model, vectorizer
+
+
+def main():
+    """Machine-learning entry point."""
+    # 1. Load the dataset
+    print("Loading the dataset...")
+    df = pl.read_csv("../spam.csv", encoding="latin-1", ignore_errors=True)
+
+    # 2. Clean the dataset
+    print("Cleaning the dataset...")
+    df = df.drop(df.columns[-3:])
+    df = df.rename({"v1": "label", "v2": "message"})
+    df = df.with_columns(
+        pl.when(pl.col("label") == "spam").then(1).otherwise(0).alias("label")
+    )
+
+    # 3. Split features and labels
+    X = df["message"]
+    y = df["label"]
+
+    # 4. Split into training and test sets
+    X_train, X_test, y_train, y_test = train_test_split(
+        X, y, test_size=settings.test_size, random_state=settings.random_state, stratify=y
+    )
+
+    print(f"Training-set size: {len(X_train)}")
+    print(f"Test-set size: {len(X_test)}")
+
+    # 5. Feature extraction
+    print("Extracting features...")
+    X_train_tfidf, X_test_tfidf, tfidf = extract_features(X_train, X_test)
+
+    # 6. Train the Logistic Regression model
+    print("\nTraining the Logistic Regression model...")
+    log_reg_model = train_logistic_regression(X_train_tfidf, y_train)
+
+    # 7. Evaluate the Logistic Regression model
+    print("\nEvaluating the Logistic Regression model:")
+    log_reg_metrics = evaluate_model(log_reg_model, X_test_tfidf, y_test)
+    print(f"Logistic Regression metrics: {log_reg_metrics}")
+
+    # 8. Train the LightGBM model
+    print("\nTraining the LightGBM model...")
+    lgb_model = train_lightgbm(X_train_tfidf, y_train)
+
+    # 9. Evaluate the LightGBM model
+    print("\nEvaluating the LightGBM model:")
+    lgb_metrics = evaluate_model(lgb_model, X_test_tfidf, y_test)
+    print(f"LightGBM metrics: {lgb_metrics}")
+
+    # 10. Save the models
+    print("\nSaving the models...")
+    save_model(log_reg_model, "logistic_regression", tfidf)
+    save_model(lgb_model, "lightgbm", tfidf)
+
+    print("\nMachine-learning pipeline complete!")
+
+
+if __name__ == "__main__":
+    main()
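diff --git a/tests/test_machine_learning.py b/tests/test_machine_learning.py
new file mode 100644
--- /dev/null
+++ b/tests/test_machine_learning.py
@@ -0,0 +1,42 @@
+"""Smoke-test sketch for the model-training step (plan section 3).
+
+Assumes pytest; DEEPSEEK_API_KEY is stubbed because importing
+machine_learning pulls in config.Settings, which requires it.
+"""
+import os
+import sys
+
+os.environ.setdefault("DEEPSEEK_API_KEY", "test-key")
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
+
+import polars as pl
+
+from machine_learning import evaluate_model, extract_features, train_logistic_regression
+
+SPAM = [
+    "WIN a FREE prize now, claim today",
+    "FREE entry, text WIN to claim your reward",
+    "Urgent! You won a cash prize, call now",
+    "Claim your FREE bonus prize, limited offer",
+]
+HAM = [
+    "Are we still meeting for lunch today",
+    "I will call you after the meeting",
+    "Can you send me the notes from class",
+    "See you at the station at six",
+]
+
+
+def test_logistic_regression_smoke():
+    # Train on three messages per class; hold out one of each for evaluation
+    X_train = pl.Series("message", SPAM[:3] + HAM[:3])
+    y_train = pl.Series("label", [1, 1, 1, 0, 0, 0])
+    X_test = pl.Series("message", [SPAM[3], HAM[3]])
+    y_test = pl.Series("label", [1, 0])
+
+    X_train_tfidf, X_test_tfidf, _ = extract_features(X_train, X_test, max_features=100)
+    model = train_logistic_regression(X_train_tfidf, y_train)
+    metrics = evaluate_model(model, X_test_tfidf, y_test)
+
+    assert {"accuracy", "precision", "recall", "f1", "roc_auc"} <= metrics.keys()
+    assert all(0.0 <= value <= 1.0 for value in metrics.values())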
diff --git a/src/main.py b/src/main.py
new file mode 100644
index 0000000..d6a70fa
--- /dev/null
+++ b/src/main.py
@@ -0,0 +1,24 @@
+from data_processing import load_data, clean_data
+from translation import translate_dataset
+
+
+def main():
+    """Entry point."""
+    # 1. Load the dataset
+    print("Loading the dataset...")
+    df = load_data("../spam.csv")
+
+    # 2. Clean the dataset
+    print("\nCleaning the dataset...")
+    df_cleaned = clean_data(df)
+
+    # 3. Translate only the first 10 messages as a smoke test
+    print("\nTranslating the first 10 messages as a smoke test...")
+    df_test = df_cleaned.head(10)
+    translated_path = translate_dataset(df_test)
+
+    print(f"\nDone! The translated test dataset was saved to: {translated_path}")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/simple_agent.py b/src/simple_agent.py
new file mode 100644
index 0000000..a6522f1
--- /dev/null
+++ b/src/simple_agent.py
@@ -0,0 +1,152 @@
+import requests
+from typing import Any, Dict, List, Tuple
+
+from config import settings
+from machine_learning import SpamClassifier
+from translation import translate_text
+
+
+class SimpleSpamAnalysis:
+    """A simple spam SMS analysis pipeline."""
+
+    def __init__(self, model_name: str = "lightgbm"):
+        """Initialize the analysis pipeline."""
+        self.classifier = SpamClassifier(model_name)
+
+    def analyze(self, message: str, is_english: bool = True) -> Dict[str, Any]:
+        """Analyze a single message."""
+        # 1. Translate the message
+        message_zh = translate_text(message, "zh-CN") if is_english else message
+
+        # 2. Classify the message
+        classification = self.classifier.classify(message)
+
+        # 3. Extract keywords
+        key_words = self.extract_keywords(message)
+
+        # 4. Generate a basic explanation and suggestion
+        reason, suggestion = self.generate_explanation(key_words, classification["label"])
+
+        # 5. Generate a more detailed explanation with the DeepSeek API
+        detailed_explanation = self.generate_detailed_explanation(
+            message, message_zh, classification["label"], key_words
+        )
+
+        return {
+            "original_message": message,
+            "translated_message": message_zh,
+            "classification": classification,
+            "key_words": key_words,
+            "reason": reason,
+            "suggestion": suggestion,
+            "detailed_explanation": detailed_explanation
+        }
+
+    def extract_keywords(self, message: str, top_n: int = 5) -> List[str]:
+        """Extract keywords by filtering stop words and short tokens."""
+        words = message.lower().split()
+        stop_words = set([
+            "the", "a", "an", "and", "or", "but", "in", "on", "at", "to", "for",
+            "with", "by", "from", "up", "down", "about", "above", "below", "of",
+            "is", "are", "was", "were", "be", "been", "being", "have", "has",
+            "had", "do", "does", "did", "will", "would", "shall", "should",
+            "may", "might", "must", "can", "could", "not", "no", "yes", "if",
+            "then", "than", "so", "because", "as", "when", "where", "who", "which",
+            "that", "this", "these", "those", "i", "me", "my", "mine", "you",
+            "your", "yours", "he", "him", "his", "she", "her", "hers", "it",
+            "its", "we", "us", "our", "ours", "they", "them", "their", "theirs"
+        ])
+
+        keywords = [word for word in words if word not in stop_words and len(word) > 2]
+        return keywords[:top_n]
+
+    def generate_explanation(self, key_words: List[str], label: str) -> Tuple[str, str]:
+        """Generate a basic explanation and suggestion."""
+        if label == "spam":
+            reason = f"The message contains spam indicator words: {', '.join(key_words)}"
+            suggestion = ("Delete the message immediately; do not click any links "
+                          "or reply, to avoid being scammed")
+        else:
+            reason = f"The message looks legitimate and uses ordinary vocabulary: {', '.join(key_words)}"
+            suggestion = "It is safe to reply to and handle this message normally"
+        return reason, suggestion
+
+    def generate_detailed_explanation(self, message: str, message_zh: str, label: str, key_words: List[str]) -> str:
+        """Generate a detailed explanation via the DeepSeek API."""
+        headers = {
+            "Authorization": f"Bearer {settings.deepseek_api_key}",
+            "Content-Type": "application/json"
+        }
+
+        prompt = f"""
+        Analyze the following SMS message:
+        English: {message}
+        Chinese: {message_zh}
+        Classification: {label}
+        Keywords: {', '.join(key_words)}
+
+        Please provide:
+        1. A detailed reason for the classification
+        2. The main characteristics of the message
+        3. Concrete advice for handling this message
+        4. How to recognize similar messages
+
+        Answer in Chinese, and keep it concise and clear.
+        """
+
+        payload = {
+            "model": settings.deepseek_model,
+            "messages": [
+                {
+                    "role": "system",
+                    "content": "You are a professional spam SMS analyst. Analyze the provided information in detail."
+                },
+                {
+                    "role": "user",
+                    "content": prompt
+                }
+            ],
+            "max_tokens": 500,
+            "temperature": 0.1
+        }
+
+        try:
+            response = requests.post(
+                f"{settings.deepseek_api_base}/chat/completions",
+                headers=headers,
+                json=payload,
+                timeout=30
+            )
+            response.raise_for_status()
+
+            result = response.json()
+            explanation = result["choices"][0]["message"]["content"].strip()
+            return explanation
+        except requests.exceptions.RequestException as e:
+            print(f"Failed to generate the detailed explanation: {e}")
+            return "Could not generate a detailed explanation; check the API connection."
+
+
+if __name__ == "__main__":
+    # Initialize the analysis pipeline
+    analyzer = SimpleSpamAnalysis()
+
+    # Test messages
+    test_messages = [
+        "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's",
+        "Ok lar... Joking wif u oni...",
+        "WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only."
+    ]
+
+    # Analyze the messages
+    for i, message in enumerate(test_messages):
+        print(f"\n=== Message analysis result {i+1} ===")
+        result = analyzer.analyze(message)
+
+        print(f"Original message: {result['original_message']}")
+        print(f"Chinese translation: {result['translated_message']}")
+        print(f"Classification: {result['classification']['label']} (confidence: {result['classification']['confidence']:.2f})")
+        print(f"Keywords: {', '.join(result['key_words'])}")
+        print(f"Reason: {result['reason']}")
+        print(f"Suggestion: {result['suggestion']}")
+        print(f"Detailed explanation: {result['detailed_explanation']}")
diff --git a/src/translation.py b/src/translation.py
new file mode 100644
index 0000000..1251461
--- /dev/null
+++ b/src/translation.py
@@ -0,0 +1,133 @@
+import time
+from typing import List
+
+import requests
+
+from config import settings
+
+
+def translate_text(text: str, target_lang: str = "zh-CN") -> str:
+    """
+    Translate text into the target language with the DeepSeek API.
+
+    Args:
+        text: The text to translate
+        target_lang: Target language, Chinese (zh-CN) by default
+
+    Returns:
+        The translated text, or the original text if the request fails
+    """
+    headers = {
+        "Authorization": f"Bearer {settings.deepseek_api_key}",
+        "Content-Type": "application/json"
+    }
+
+    payload = {
+        "model": settings.deepseek_model,
+        "messages": [
+            {
+                "role": "system",
+                "content": f"You are a professional translator. Translate the following text to {target_lang}. Keep the original meaning and tone. Do not add any additional information."
+            },
+            {
+                "role": "user",
+                "content": text
+            }
+        ],
+        "max_tokens": 1000,
+        "temperature": 0.1
+    }
+
+    try:
+        response = requests.post(
+            f"{settings.deepseek_api_base}/chat/completions",
+            headers=headers,
+            json=payload,
+            timeout=30
+        )
+        response.raise_for_status()
+
+        result = response.json()
+        translated_text = result["choices"][0]["message"]["content"].strip()
+        return translated_text
+    except requests.exceptions.RequestException as e:
+        print(f"Translation failed: {e}")
+        return text
+
+
+def translate_batch(texts: List[str], target_lang: str = "zh-CN", batch_size: int = 10) -> List[str]:
+    """
+    Translate a list of texts.
+
+    Args:
+        texts: The texts to translate
+        target_lang: Target language, Chinese (zh-CN) by default
+        batch_size: Batch size, 10 by default
+
+    Returns:
+        The translated texts
+    """
+    translated_texts = []
+
+    for i in range(0, len(texts), batch_size):
+        batch = texts[i:i+batch_size]
+        batch_translated = []
+
+        for text in batch:
+            translated = translate_text(text, target_lang)
+            batch_translated.append(translated)
+            # Sleep briefly to stay under the API rate limit
+            time.sleep(0.5)
+
+        translated_texts.extend(batch_translated)
+        print(f"Translated {min(i+batch_size, len(texts))}/{len(texts)} texts")
+
+    return translated_texts
+
+
+def translate_dataset(df, text_column: str = "message", target_column: str = "message_zh") -> str:
+    """
+    Translate a text column of the dataset.
+
+    Args:
+        df: A Polars DataFrame
+        text_column: The column to translate
+        target_column: The column for the translated text
+
+    Returns:
+        The path of the translated dataset file
+    """
+    import os
+
+    import polars as pl
+
+    # Create the data directory if it does not exist
+    os.makedirs(settings.data_path, exist_ok=True)
+
+    # Pull out the texts
+    texts = df[text_column].to_list()
+
+    # Translate them
+    print(f"Starting translation of {len(texts)} texts...")
+    translated_texts = translate_batch(texts)
+
+    # Attach the translated column to the dataset
+    df = df.with_columns(
+        pl.Series(target_column, translated_texts)
+    )
+
+    # Save the translated dataset (write_csv takes no index argument)
+    output_path = f"{settings.data_path}/spam_zh.csv"
+    df.write_csv(output_path)
+
+    print(f"The translated dataset was saved to: {output_path}")
+    print(f"Translation finished! {len(texts)} texts translated")
+    return output_path
+
+
+if __name__ == "__main__":
+    # Test the translation function
+    test_text = "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's"
+    translated = translate_text(test_text)
+    print(f"Original: {test_text}")
+    print(f"Translation: {translated}")
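diff --git a/tests/test_translation_offline.py b/tests/test_translation_offline.py
new file mode 100644
--- /dev/null
+++ b/tests/test_translation_offline.py
@@ -0,0 +1,23 @@
+"""Offline test sketch for translate_text's failure fallback (plan section 6).
+
+Assumes pytest; DEEPSEEK_API_KEY is stubbed so importing config succeeds,
+and requests.post is monkeypatched so no network access is needed.
+"""
+import os
+import sys
+
+os.environ.setdefault("DEEPSEEK_API_KEY", "test-key")
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
+
+import requests
+
+import translation
+
+
+def test_translate_text_falls_back_to_input(monkeypatch):
+    def fail(*args, **kwargs):
+        raise requests.exceptions.ConnectionError("no network")
+
+    # translate_text catches RequestException and returns the input unchanged
+    monkeypatch.setattr(translation.requests, "post", fail)
+    assert translation.translate_text("Hello") == "Hello"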
diff --git a/test_analysis.py b/test_analysis.py
new file mode 100644
index 0000000..6b88433
--- /dev/null
+++ b/test_analysis.py
@@ -0,0 +1,31 @@
+import sys
+import os
+
+# Put the src directory on the Python path
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
+
+from simple_agent import SimpleSpamAnalysis
+
+
+# Test messages
+test_messages = [
+    "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's",
+    "Ok lar... Joking wif u oni...",
+    "WINNER!! As a valued network customer you have been selected to receivea £900 prize reward! To claim call 09061701461. Claim code KL341. Valid 12 hours only."
+]
+
+# Initialize the analysis pipeline
+analyzer = SimpleSpamAnalysis()
+
+# Analyze the messages
+for i, message in enumerate(test_messages):
+    print(f"\n=== Message analysis result {i+1} ===")
+    result = analyzer.analyze(message)
+
+    print(f"Original message: {result['original_message'][:100]}...")
+    print(f"Chinese translation: {result['translated_message'][:100]}...")
+    print(f"Classification: {result['classification']['label']} (confidence: {result['classification']['confidence']:.2f})")
+    print(f"Keywords: {', '.join(result['key_words'])}")
+    print(f"Reason: {result['reason']}")
+    print(f"Suggestion: {result['suggestion']}")
+    print(f"Detailed explanation: {result['detailed_explanation'][:200]}...")
diff --git a/test_translation.py b/test_translation.py
new file mode 100644
index 0000000..7df5883
--- /dev/null
+++ b/test_translation.py
@@ -0,0 +1,13 @@
+import os
+import sys
+
+# Put the src directory on the Python path so translation's own imports resolve
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src'))
+
+from translation import translate_text
+
+# Test the single-text translation function
+test_text = "Hello, how are you?"
+print(f"Original: {test_text}")
+translated = translate_text(test_text)
+print(f"Translation: {translated}")
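diff --git a/README.md b/README.md
new file mode 100644
--- /dev/null
+++ b/README.md
@@ -0,0 +1,13 @@
+# spam-classification
+
+Spam SMS classification with ML and LLM integration. This is a minimal
+README sketch for the `readme = "README.md"` reference in pyproject.toml;
+the commands assume training and analysis both run from `src/`, because
+the code uses relative paths for the dataset and the saved models.
+
+## Quick start
+
+1. `uv sync` to install dependencies
+2. `cp .env.example .env` and fill in `DEEPSEEK_API_KEY`
+3. From `src/`: `uv run python machine_learning.py` to train and save the models
+4. From `src/`: `uv run python simple_agent.py` to analyze sample messages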