diff --git a/.gitignore b/.gitignore
index cc0310c..76e7681 100644
--- a/.gitignore
+++ b/.gitignore
@@ -3,8 +3,6 @@
 env/
 venv/
 
-# Model files
-models/
 
 # Environment variables
 .env
diff --git a/models/lightgbm.joblib b/models/lightgbm.joblib
new file mode 100644
index 0000000..76e1104
Binary files /dev/null and b/models/lightgbm.joblib differ
diff --git a/models/lightgbm_classification_report.txt b/models/lightgbm_classification_report.txt
new file mode 100644
index 0000000..85b43b8
--- /dev/null
+++ b/models/lightgbm_classification_report.txt
@@ -0,0 +1,8 @@
+              precision    recall  f1-score   support
+
+         ham       0.98      0.99      0.99       969
+        spam       0.93      0.89      0.91       145
+
+    accuracy                           0.98      1114
+   macro avg       0.96      0.94      0.95      1114
+weighted avg       0.98      0.98      0.98      1114
diff --git a/models/lightgbm_confusion_matrix.png b/models/lightgbm_confusion_matrix.png
new file mode 100644
index 0000000..8505421
Binary files /dev/null and b/models/lightgbm_confusion_matrix.png differ
diff --git a/models/logistic_regression.joblib b/models/logistic_regression.joblib
new file mode 100644
index 0000000..a8c4fc7
Binary files /dev/null and b/models/logistic_regression.joblib differ
diff --git a/models/logistic_regression_classification_report.txt b/models/logistic_regression_classification_report.txt
new file mode 100644
index 0000000..e888fbc
--- /dev/null
+++ b/models/logistic_regression_classification_report.txt
@@ -0,0 +1,8 @@
+              precision    recall  f1-score   support
+
+         ham       0.99      0.99      0.99       969
+        spam       0.92      0.94      0.93       145
+
+    accuracy                           0.98      1114
+   macro avg       0.95      0.96      0.96      1114
+weighted avg       0.98      0.98      0.98      1114
diff --git a/models/logistic_regression_confusion_matrix.png b/models/logistic_regression_confusion_matrix.png
new file mode 100644
index 0000000..84db3f5
Binary files /dev/null and b/models/logistic_regression_confusion_matrix.png differ
diff --git a/src/models/__init__.py b/src/models/__init__.py
new file mode 100644
index 0000000..fae3f00
--- /dev/null
+++ b/src/models/__init__.py
@@ -0,0 +1,42 @@
+from .train import (
+    build_tfidf_vectorizer,
+    build_logistic_regression,
+    build_lightgbm,
+    train_model,
+    save_model,
+    load_model
+)
+from .evaluate import (
+    evaluate_model,
+    print_evaluation_results,
+    plot_confusion_matrix,
+    save_evaluation_results,
+    compare_models
+)
+from .predict import (
+    predict_spam,
+    predict_batch_spam,
+    predict_from_df
+)
+
+__all__ = [
+    # train.py
+    "build_tfidf_vectorizer",
+    "build_logistic_regression",
+    "build_lightgbm",
+    "train_model",
+    "save_model",
+    "load_model",
+
+    # evaluate.py
+    "evaluate_model",
+    "print_evaluation_results",
+    "plot_confusion_matrix",
+    "save_evaluation_results",
+    "compare_models",
+
+    # predict.py
+    "predict_spam",
+    "predict_batch_spam",
+    "predict_from_df"
+]
\ No newline at end of file
diff --git a/src/models/evaluate.py b/src/models/evaluate.py
new file mode 100644
index 0000000..fb99e6d
--- /dev/null
+++ b/src/models/evaluate.py
@@ -0,0 +1,159 @@
+from sklearn.metrics import (
+    accuracy_score,
+    precision_score,
+    recall_score,
+    f1_score,
+    classification_report,
+    confusion_matrix
+)
+import matplotlib.pyplot as plt
+import seaborn as sns
+import os
+from typing import Dict, Any, List, Optional
+import polars as pl
+
+
+
+def evaluate_model(
+    model,
+    test_df: pl.DataFrame
+) -> Dict[str, Any]:
+    """Score `model` on `test_df` and return macro metrics, the text report, the confusion matrix and raw predictions."""
+    # Inputs come from the "clean_text" column, targets from "label_num"
+    X_test = test_df["clean_text"].to_list()
+    y_test = test_df["label_num"].to_list()
+
+    # Hard predictions plus the probability of class index 1
+    y_pred = model.predict(X_test)
+    y_pred_proba = model.predict_proba(X_test)[:, 1]
+
+    # Macro-averaged metrics weight both classes equally
+    accuracy = accuracy_score(y_test, y_pred)
+    precision = precision_score(y_test, y_pred, average="macro")
+    recall = recall_score(y_test, y_pred, average="macro")
+    f1 = f1_score(y_test, y_pred, average="macro")
+
+    # Per-class text report (label order is ham, spam)
+    class_report = classification_report(y_test, y_pred, target_names=["ham", "spam"])
+
+    # Confusion matrix (rows: true label, columns: predicted label)
+    conf_matrix = confusion_matrix(y_test, y_pred)
+
+    # Bundle everything callers need for printing, plotting and saving
+    results = {
+        "accuracy": accuracy,
+        "precision": precision,
+        "recall": recall,
+        "f1": f1,
+        "class_report": class_report,
+        "conf_matrix": conf_matrix,
+        "y_test": y_test,
+        "y_pred": y_pred,
+        "y_pred_proba": y_pred_proba
+    }
+
+    return results
+
+
+def print_evaluation_results(results: Dict[str, Any]) -> None:
+    """Print the scalar metrics and the classification report from an `evaluate_model` result dict."""
+    print(f"Accuracy: {results['accuracy']:.4f}")
+    print(f"Precision (macro): {results['precision']:.4f}")
+    print(f"Recall (macro): {results['recall']:.4f}")
+    print(f"F1 Score (macro): {results['f1']:.4f}")
+    print("\nClassification Report:")
+    print(results['class_report'])
+
+
+def plot_confusion_matrix(
+    conf_matrix: List[List[int]],
+    save_path: Optional[str] = None
+) -> None:
+    """Draw `conf_matrix` as a ham/spam heatmap; write it to `save_path` when given, then close the figure."""
+    plt.figure(figsize=(8, 6))
+    sns.heatmap(
+        conf_matrix,
+        annot=True,
+        fmt="d",
+        cmap="Blues",
+        xticklabels=["ham", "spam"],
+        yticklabels=["ham", "spam"]
+    )
+    plt.xlabel("Predicted Label")
+    plt.ylabel("True Label")
+    plt.title("Confusion Matrix")
+
+    if save_path:
+        os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
+        plt.savefig(save_path, dpi=300, bbox_inches="tight")
+        print(f"混淆矩阵已保存到: {save_path}")
+
+    plt.close()
+
+
+def save_evaluation_results(
+    results: Dict[str, Any],
+    model_name: str,
+    save_dir: str = "./models"
+) -> None:
+    """Write `<model_name>_classification_report.txt` and `<model_name>_confusion_matrix.png` into `save_dir`."""
+    # Make sure the output directory exists
+    os.makedirs(save_dir, exist_ok=True)
+
+    # Save the classification report (explicit encoding keeps the output platform-independent)
+    report_path = os.path.join(save_dir, f"{model_name}_classification_report.txt")
+    with open(report_path, "w", encoding="utf-8") as f:
+        f.write(results["class_report"])
+
+    # Save the confusion matrix plot
+    conf_matrix_path = os.path.join(save_dir, f"{model_name}_confusion_matrix.png")
+    plot_confusion_matrix(results["conf_matrix"], conf_matrix_path)
+
+    print(f"评估结果已保存到: {save_dir}")
+
+
+def compare_models(
+    models: Dict[str, Any],
+    test_df: pl.DataFrame
+) -> Dict[str, Dict[str, Any]]:
+    """Evaluate, print and save results for every model in `models`; return the results keyed by model name."""
+    comparison_results = {}
+
+    for model_name, model in models.items():
+        print(f"\n=== 评估 {model_name} 模型 ===")
+        results = evaluate_model(model, test_df)
+        print_evaluation_results(results)
+        save_evaluation_results(results, model_name)
+        comparison_results[model_name] = results
+
+    return comparison_results
+
+
+def main():
+    """Manual smoke test: load both saved models and compare them on the test split."""
+    from .train import load_model
+    from ..data import load_data, preprocess_data, split_data
+
+    # Load and preprocess the data (path relative to the project root, like ./models and train.py)
+    df = load_data("./data/spam.csv")
+    processed_df = preprocess_data(df)
+    _, test_df = split_data(processed_df)
+
+    # Load the saved models
+    try:
+        lr_model = load_model("logistic_regression")
+        lgbm_model = load_model("lightgbm")
+
+        # Compare the models
+        models = {
+            "logistic_regression": lr_model,
+            "lightgbm": lgbm_model
+        }
+
+        compare_models(models, test_df)
+    except FileNotFoundError:
+        print("模型文件不存在,请先运行train.py训练模型")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/src/models/predict.py b/src/models/predict.py
new file mode 100644
index 0000000..55eca16
--- /dev/null
+++ b/src/models/predict.py
@@ -0,0 +1,135 @@
+from typing import List, Dict, Any
+import polars as pl
+import warnings
+from sklearn.exceptions import DataConversionWarning
+from .train import load_model
+from ..data import clean_text
+
+# Silence the "X does not have valid feature names" warning raised at predict time
+warnings.filterwarnings("ignore", message="X does not have valid feature names")
+
+
+def predict_spam(
+    text: str,
+    model_name: str = "lightgbm"
+) -> Dict[str, Any]:
+    """Classify one message with the saved model `model_name`; return label, numeric label and class probabilities."""
+    # Load the model
+    model = load_model(model_name)
+
+    # Clean the text
+    cleaned_text = clean_text(text)
+
+    # Predict on a one-element batch and take the first row
+    prediction = model.predict([cleaned_text])[0]
+    probability = model.predict_proba([cleaned_text])[0]
+
+    # Map the numeric class to its name (1 is spam, anything else ham)
+    label = "spam" if prediction == 1 else "ham"
+
+    return {
+        "original_text": text,
+        "cleaned_text": cleaned_text,
+        "label": label,
+        "label_num": int(prediction),
+        "probability": {
+            "ham": float(probability[0]),
+            "spam": float(probability[1])
+        }
+    }
+
+
+def predict_batch_spam(
+    texts: List[str],
+    model_name: str = "lightgbm"
+) -> List[Dict[str, Any]]:
+    """Classify each message in `texts`; return one result dict per message, with `id` set to its position."""
+    # Load the model
+    model = load_model(model_name)
+
+    # Clean the texts
+    cleaned_texts = [clean_text(text) for text in texts]
+
+    # Predict the whole batch in one call
+    predictions = model.predict(cleaned_texts)
+    probabilities = model.predict_proba(cleaned_texts)
+
+    # Convert the raw outputs into result dicts
+    results = []
+    for i, (text, cleaned_text, prediction, probability) in enumerate(zip(texts, cleaned_texts, predictions, probabilities)):
+        label = "spam" if prediction == 1 else "ham"
+        results.append({
+            "id": i,
+            "original_text": text,
+            "cleaned_text": cleaned_text,
+            "label": label,
+            "label_num": int(prediction),
+            "probability": {
+                "ham": float(probability[0]),
+                "spam": float(probability[1])
+            }
+        })
+
+    return results
+
+
+def predict_from_df(
+    df: pl.DataFrame,
+    text_column: str = "text",
+    model_name: str = "lightgbm"
+) -> pl.DataFrame:
+    """Return `df` with added clean_text, label_num, ham_prob, spam_prob and label columns for `text_column`."""
+    # Clean the text
+    df = df.with_columns(
+        pl.col(text_column).map_elements(clean_text, return_dtype=pl.String).alias("clean_text")
+    )
+
+    # Prepare the prediction input
+    texts = df["clean_text"].to_list()
+
+    # Predict
+    model = load_model(model_name)
+    predictions = model.predict(texts)
+    probabilities = model.predict_proba(texts)
+
+    # Attach the predictions
+    df = df.with_columns(
+        pl.Series("label_num", predictions).cast(pl.Int64),
+        pl.Series("ham_prob", [p[0] for p in probabilities]).cast(pl.Float64),
+        pl.Series("spam_prob", [p[1] for p in probabilities]).cast(pl.Float64)
+    )
+
+    # Attach the label name; pl.lit is required because then()/otherwise() read bare strings as column names
+    df = df.with_columns(
+        pl.when(pl.col("label_num") == 1)
+        .then(pl.lit("spam"))
+        .otherwise(pl.lit("ham"))
+        .alias("label")
+    )
+
+    return df
+
+
+def main():
+    """Manual smoke test for single and batch prediction."""
+    # Single-message prediction
+    test_text = "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's"
+    result = predict_spam(test_text)
+    print("单条预测结果:")
+    print(result)
+
+    # Batch prediction
+    test_texts = [
+        "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005.",
+        "Ok lar... Joking wif u oni...",
+        "WINNER!! As a valued network customer you have been selected to receivea £900 prize reward!",
+        "I'm gonna be home soon and i don't want to talk about this stuff anymore tonight, k?"
+    ]
+    results = predict_batch_spam(test_texts)
+    print("\n批量预测结果:")
+    for res in results:
+        print(res)
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/src/models/train.py b/src/models/train.py
new file mode 100644
index 0000000..c5f5ddc
--- /dev/null
+++ b/src/models/train.py
@@ -0,0 +1,162 @@
+import os
+import joblib
+from sklearn.feature_extraction.text import TfidfVectorizer
+from sklearn.linear_model import LogisticRegression
+from lightgbm import LGBMClassifier
+from sklearn.pipeline import Pipeline
+from sklearn.model_selection import GridSearchCV
+from typing import Tuple, Dict, Any
+import polars as pl
+from dotenv import load_dotenv
+
+# Load environment variables from a .env file, if present
+load_dotenv()
+
+# Directory where models are saved and loaded (override with MODEL_SAVE_DIR)
+MODEL_SAVE_DIR = os.getenv("MODEL_SAVE_DIR", "./models")
+
+
+def build_tfidf_vectorizer() -> TfidfVectorizer:
+    """Build the TF-IDF vectorizer shared by both pipelines (uni+bigrams, 10k features, English stop words)."""
+    return TfidfVectorizer(
+        stop_words="english",
+        max_features=10000,
+        ngram_range=(1, 2),
+        lowercase=True,
+        token_pattern=r'\b[a-zA-Z0-9_]{1,}\b'
+    )
+
+
+def build_logistic_regression() -> Pipeline:
+    """Build the unfitted TF-IDF + class-balanced LogisticRegression pipeline (steps "tfidf", "lr")."""
+    pipeline = Pipeline([
+        ("tfidf", build_tfidf_vectorizer()),
+        ("lr", LogisticRegression(
+            random_state=42,
+            max_iter=1000,
+            class_weight="balanced"
+        ))
+    ])
+    return pipeline
+
+
+def build_lightgbm() -> Pipeline:
+    """Build the unfitted TF-IDF + class-balanced LGBMClassifier pipeline (steps "tfidf", "lgbm")."""
+    pipeline = Pipeline([
+        ("tfidf", build_tfidf_vectorizer()),
+        ("lgbm", LGBMClassifier(
+            random_state=42,
+            class_weight="balanced",
+            verbose=-1,
+            feature_name="auto"
+        ))
+    ])
+    return pipeline
+
+
+def train_model(
+    train_df: pl.DataFrame,
+    model_type: str = "lightgbm",
+    hyperparam_tune: bool = False
+) -> Tuple[Pipeline, Dict[str, Any]]:
+    """Fit the chosen pipeline on `train_df`; return (fitted model, best params, empty unless tuned). Raises ValueError on unknown `model_type`."""
+    # Inputs come from the "clean_text" column, targets from "label_num"
+    X_train = train_df["clean_text"].to_list()
+    y_train = train_df["label_num"].to_list()
+
+    # Pick the pipeline and its search grid
+    if model_type == "logistic_regression":
+        pipeline = build_logistic_regression()
+        param_grid = {
+            "lr__C": [0.1, 1.0, 10.0],
+            "lr__solver": ["liblinear", "lbfgs"]
+        }
+    elif model_type == "lightgbm":
+        pipeline = build_lightgbm()
+        param_grid = {
+            "lgbm__n_estimators": [100, 200],
+            "lgbm__learning_rate": [0.1, 0.2],
+            "lgbm__max_depth": [3, 5, 7]
+        }
+    else:
+        raise ValueError(f"不支持的模型类型: {model_type}")
+
+    # Optional hyperparameter search: 5-fold CV scored on macro F1
+    if hyperparam_tune:
+        grid_search = GridSearchCV(
+            pipeline,
+            param_grid,
+            cv=5,
+            scoring="f1_macro",
+            n_jobs=-1
+        )
+        grid_search.fit(X_train, y_train)
+        best_model = grid_search.best_estimator_
+        best_params = grid_search.best_params_
+        print(f"最佳参数: {best_params}")
+    else:
+        best_model = pipeline
+        best_model.fit(X_train, y_train)
+        best_params = {}
+
+    return best_model, best_params
+
+
+def save_model(model: Pipeline, model_name: str) -> str:
+    """Dump `model` to `<MODEL_SAVE_DIR>/<model_name>.joblib` and return that path."""
+    # Make sure the save directory exists
+    os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
+
+    # Build the model file path
+    model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}.joblib")
+
+    # Save the model
+    joblib.dump(model, model_path)
+    print(f"模型已保存到: {model_path}")
+
+    return model_path
+
+
+def load_model(model_name: str) -> Pipeline:
+    """Load `<MODEL_SAVE_DIR>/<model_name>.joblib`; raises FileNotFoundError when the file is missing."""
+    # Build the model file path
+    model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}.joblib")
+
+    # Load the model (joblib is pickle-based: only load files from a trusted source)
+    if not os.path.exists(model_path):
+        raise FileNotFoundError(f"模型文件不存在: {model_path}")
+
+    model = joblib.load(model_path)
+    print(f"模型已从: {model_path} 加载")
+
+    return model
+
+
+def main():
+    """Manual entry point: train and save both models from ./data/spam.csv."""
+    import sys
+    import os
+
+    # Put the project root on sys.path so `src.data` resolves when run as a script
+    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
+
+    from src.data import load_data, preprocess_data, split_data
+
+    # Load and preprocess the data
+    df = load_data("./data/spam.csv")
+    processed_df = preprocess_data(df)
+    train_df, test_df = split_data(processed_df)
+
+    print("训练Logistic Regression模型...")
+    lr_model, lr_params = train_model(train_df, model_type="logistic_regression")
+    save_model(lr_model, "logistic_regression")
+
+    print("\n训练LightGBM模型...")
+    lgbm_model, lgbm_params = train_model(train_df, model_type="lightgbm")
+    save_model(lgbm_model, "lightgbm")
+
+    print("\n模型训练完成!")
+
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file