feat(模型): 添加模型训练、评估和预测功能

- 实现模型训练模块,支持逻辑回归和LightGBM
- 添加模型评估功能,包括分类报告和混淆矩阵
- 实现预测功能,支持单条、批量及DataFrame输入
- 更新.gitignore,移除models目录忽略
- 添加模型评估结果文件
This commit is contained in:
2026-01-15 14:15:17 +08:00
parent e081ed9329
commit 48013e93bd
11 changed files with 514 additions and 2 deletions

2
.gitignore vendored
View File

@ -3,8 +3,6 @@
env/ env/
venv/ venv/
# Model files
models/
# Environment variables # Environment variables
.env .env

BIN
models/lightgbm.joblib Normal file

Binary file not shown.

View File

@ -0,0 +1,8 @@
precision recall f1-score support
ham 0.98 0.99 0.99 969
spam 0.93 0.89 0.91 145
accuracy 0.98 1114
macro avg 0.96 0.94 0.95 1114
weighted avg 0.98 0.98 0.98 1114

Binary file not shown.

After

Width:  |  Height:  |  Size: 65 KiB

Binary file not shown.

View File

@ -0,0 +1,8 @@
precision recall f1-score support
ham 0.99 0.99 0.99 969
spam 0.92 0.94 0.93 145
accuracy 0.98 1114
macro avg 0.95 0.96 0.96 1114
weighted avg 0.98 0.98 0.98 1114

Binary file not shown.

After

Width:  |  Height:  |  Size: 64 KiB

42
src/models/__init__.py Normal file
View File

@ -0,0 +1,42 @@
from .train import (
build_tfidf_vectorizer,
build_logistic_regression,
build_lightgbm,
train_model,
save_model,
load_model
)
from .evaluate import (
evaluate_model,
print_evaluation_results,
plot_confusion_matrix,
save_evaluation_results,
compare_models
)
from .predict import (
predict_spam,
predict_batch_spam,
predict_from_df
)
# Public names re-exported by this package, grouped by the submodule they are imported from.
__all__ = [
# train.py
"build_tfidf_vectorizer",
"build_logistic_regression",
"build_lightgbm",
"train_model",
"save_model",
"load_model",
# evaluate.py
"evaluate_model",
"print_evaluation_results",
"plot_confusion_matrix",
"save_evaluation_results",
"compare_models",
# predict.py
"predict_spam",
"predict_batch_spam",
"predict_from_df"
]

159
src/models/evaluate.py Normal file
View File

@ -0,0 +1,159 @@
from sklearn.metrics import (
accuracy_score,
precision_score,
recall_score,
f1_score,
classification_report,
confusion_matrix
)
import matplotlib.pyplot as plt
import seaborn as sns
import os
from typing import Dict, Any, List
import polars as pl
def evaluate_model(
    model,
    test_df: pl.DataFrame
) -> Dict[str, Any]:
    """Score a fitted classifier against a labelled test frame.

    Args:
        model: Fitted estimator exposing ``predict`` and ``predict_proba``;
            both are called with a plain list of texts.
        test_df: Frame providing a ``clean_text`` column (model input) and a
            ``label_num`` column (ground truth).

    Returns:
        Dict with the keys ``accuracy``, ``precision``, ``recall``, ``f1``
        (the last three macro-averaged), ``class_report`` (text report),
        ``conf_matrix``, ``y_test``, ``y_pred`` and ``y_pred_proba``.
    """
    texts = test_df["clean_text"].to_list()
    truth = test_df["label_num"].to_list()

    predicted = model.predict(texts)
    # Keep only the second probability column, i.e. the class that the
    # report below labels "spam".
    spam_scores = model.predict_proba(texts)[:, 1]

    results: Dict[str, Any] = {"accuracy": accuracy_score(truth, predicted)}

    # The three remaining headline metrics share the same macro averaging.
    macro_scorers = (
        ("precision", precision_score),
        ("recall", recall_score),
        ("f1", f1_score),
    )
    for key, scorer in macro_scorers:
        results[key] = scorer(truth, predicted, average="macro")

    results["class_report"] = classification_report(
        truth, predicted, target_names=["ham", "spam"]
    )
    results["conf_matrix"] = confusion_matrix(truth, predicted)

    # Raw vectors are returned as well so callers can compute further metrics.
    results["y_test"] = truth
    results["y_pred"] = predicted
    results["y_pred_proba"] = spam_scores
    return results
def print_evaluation_results(results: Dict[str, Any]) -> None:
    """Print the headline metrics and the classification report to stdout.

    Args:
        results: Mapping with numeric ``accuracy``, ``precision``, ``recall``
            and ``f1`` entries plus a printable ``class_report`` entry.
    """
    headline = (
        ("Accuracy", "accuracy"),
        ("Precision (macro)", "precision"),
        ("Recall (macro)", "recall"),
        ("F1 Score (macro)", "f1"),
    )
    # Every metric is rendered with four decimal places.
    for caption, key in headline:
        print(f"{caption}: {results[key]:.4f}")

    print("\nClassification Report:")
    print(results["class_report"])
def plot_confusion_matrix(
    conf_matrix: List[List[int]],
    save_path: str = None
) -> None:
    """Draw a ham/spam confusion-matrix heatmap and optionally save it.

    Args:
        conf_matrix: Square matrix of integer counts; cells are annotated
            with integer formatting and both axes are labelled ham/spam.
        save_path: Destination image path. When falsy (the default) nothing
            is written. Missing parent directories are created.

    The figure is always closed before returning, including when drawing or
    saving raises.
    """
    figure = plt.figure(figsize=(8, 6))
    try:
        sns.heatmap(
            conf_matrix,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=["ham", "spam"],
            yticklabels=["ham", "spam"]
        )
        plt.xlabel("Predicted Label")
        plt.ylabel("True Label")
        plt.title("Confusion Matrix")

        if save_path:
            # dirname() is "" for a bare filename and os.makedirs("") raises,
            # so only create the directory when there is one to create.
            target_dir = os.path.dirname(save_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
            print(f"混淆矩阵已保存到: {save_path}")
    finally:
        # Release the figure even on failure so repeated calls do not
        # accumulate open figures.
        plt.close(figure)
def save_evaluation_results(
    results: Dict[str, Any],
    model_name: str,
    save_dir: str = "./models"
) -> None:
    """Persist the text report and the confusion-matrix image for one model.

    Args:
        results: Evaluation mapping; ``class_report`` (text) and
            ``conf_matrix`` are the entries read here.
        model_name: Prefix for the two output file names.
        save_dir: Output directory, created when missing.

    Writes ``<model_name>_classification_report.txt`` and
    ``<model_name>_confusion_matrix.png`` into ``save_dir``.
    """
    os.makedirs(save_dir, exist_ok=True)

    report_path = os.path.join(save_dir, f"{model_name}_classification_report.txt")
    # Pin the encoding so the report file does not depend on the platform's
    # default locale encoding.
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(results["class_report"])

    conf_matrix_path = os.path.join(save_dir, f"{model_name}_confusion_matrix.png")
    plot_confusion_matrix(results["conf_matrix"], conf_matrix_path)

    print(f"评估结果已保存到: {save_dir}")
def compare_models(
    models: Dict[str, Any],
    test_df: pl.DataFrame
) -> Dict[str, Dict[str, Any]]:
    """Evaluate several models on the same test frame.

    For each entry the model is evaluated, its metrics are printed, and its
    report and confusion matrix are saved using the default save directory.

    Args:
        models: Mapping of model name to fitted estimator.
        test_df: Test frame forwarded to ``evaluate_model``.

    Returns:
        Mapping of model name to that model's evaluation results.
    """
    comparison_results: Dict[str, Dict[str, Any]] = {}
    for name, estimator in models.items():
        print(f"\n=== 评估 {name} 模型 ===")
        metrics = evaluate_model(estimator, test_df)
        print_evaluation_results(metrics)
        save_evaluation_results(metrics, name)
        comparison_results[name] = metrics
    return comparison_results
def main():
    """Manual smoke test: evaluate both saved models on the test split."""
    from .train import load_model
    from ..data import load_data, preprocess_data, split_data

    # Load and preprocess the data; only the test split is needed here.
    raw_df = load_data("../data/spam.csv")
    _, test_df = split_data(preprocess_data(raw_df))

    try:
        # Models are loaded in this order; a missing file aborts the run.
        models = {
            name: load_model(name)
            for name in ("logistic_regression", "lightgbm")
        }
        compare_models(models, test_df)
    except FileNotFoundError:
        print("模型文件不存在请先运行train.py训练模型")


if __name__ == "__main__":
    main()

135
src/models/predict.py Normal file
View File

@ -0,0 +1,135 @@
from typing import List, Dict, Any
import polars as pl
import warnings
from sklearn.exceptions import DataConversionWarning
from .train import load_model
from ..data import clean_text
# Ignore the feature-name mismatch warning ("X does not have valid feature names")
warnings.filterwarnings("ignore", message="X does not have valid feature names")
def predict_spam(
    text: str,
    model_name: str = "lightgbm"
) -> Dict[str, Any]:
    """Classify a single message as spam or ham.

    Args:
        text: Raw message text.
        model_name: Name of the saved model to load.

    Returns:
        Dict with ``original_text``, ``cleaned_text``, ``label``
        ("spam" when the predicted class equals 1, otherwise "ham"),
        ``label_num`` and a ``probability`` dict holding the ``ham`` and
        ``spam`` scores.
    """
    classifier = load_model(model_name)
    cleaned = clean_text(text)

    # The model works on batches, so wrap the message in a one-item list
    # and take the first row of each output.
    batch = [cleaned]
    predicted = classifier.predict(batch)[0]
    scores = classifier.predict_proba(batch)[0]

    if predicted == 1:
        label = "spam"
    else:
        label = "ham"

    class_scores = {
        "ham": float(scores[0]),
        "spam": float(scores[1])
    }
    return {
        "original_text": text,
        "cleaned_text": cleaned,
        "label": label,
        "label_num": int(predicted),
        "probability": class_scores
    }
def predict_batch_spam(
    texts: List[str],
    model_name: str = "lightgbm"
) -> List[Dict[str, Any]]:
    """Classify a batch of messages as spam or ham.

    Args:
        texts: Raw message texts. Any iterable is accepted; it is
            materialised once up front.
        model_name: Name of the saved model to load.

    Returns:
        One dict per input, in input order, with ``id`` (position in the
        batch), ``original_text``, ``cleaned_text``, ``label``,
        ``label_num`` and a ``probability`` dict (``ham`` / ``spam``).
        An empty input yields an empty list without loading the model.
    """
    # Materialise once: an iterator would otherwise be exhausted by the
    # cleaning pass and the zip below would silently produce no results.
    texts = list(texts)

    # Return early for an empty batch instead of handing zero samples to
    # the pipeline, which rejects them.
    if not texts:
        return []

    model = load_model(model_name)
    cleaned_texts = [clean_text(text) for text in texts]

    predictions = model.predict(cleaned_texts)
    probabilities = model.predict_proba(cleaned_texts)

    return [
        {
            "id": i,
            "original_text": text,
            "cleaned_text": cleaned_text,
            "label": "spam" if prediction == 1 else "ham",
            "label_num": int(prediction),
            "probability": {
                "ham": float(probability[0]),
                "spam": float(probability[1])
            }
        }
        for i, (text, cleaned_text, prediction, probability) in enumerate(
            zip(texts, cleaned_texts, predictions, probabilities)
        )
    ]
def predict_from_df(
    df: pl.DataFrame,
    text_column: str = "text",
    model_name: str = "lightgbm"
) -> pl.DataFrame:
    """Classify every message in a DataFrame column.

    Args:
        df: Input frame.
        text_column: Name of the column holding the raw message text.
        model_name: Name of the saved model to load.

    Returns:
        The input frame with these columns added: ``clean_text``,
        ``label_num`` (Int64), ``ham_prob`` and ``spam_prob`` (Float64),
        and ``label`` ("spam" where ``label_num`` is 1, otherwise "ham").
    """
    # Clean each message element-wise with the project's clean_text helper.
    df = df.with_columns(
        pl.col(text_column).map_elements(clean_text, return_dtype=pl.String).alias("clean_text")
    )

    texts = df["clean_text"].to_list()

    model = load_model(model_name)
    predictions = model.predict(texts)
    probabilities = model.predict_proba(texts)

    df = df.with_columns(
        pl.Series("label_num", predictions).cast(pl.Int64),
        pl.Series("ham_prob", [p[0] for p in probabilities]).cast(pl.Float64),
        pl.Series("spam_prob", [p[1] for p in probabilities]).cast(pl.Float64)
    )

    # Bare strings in then()/otherwise() are parsed as column names, so the
    # label values must be wrapped in pl.lit() to be treated as literals.
    df = df.with_columns(
        pl.when(pl.col("label_num") == 1)
        .then(pl.lit("spam"))
        .otherwise(pl.lit("ham"))
        .alias("label")
    )
    return df
def main():
    """Manual smoke test: run one single prediction and one batch prediction."""
    # Single-message prediction.
    test_text = "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's"
    result = predict_spam(test_text)
    print("单条预测结果:")
    print(result)

    # Batch prediction. The third sample previously carried a mis-encoded
    # character in place of the pound sign; it is restored here.
    test_texts = [
        "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005.",
        "Ok lar... Joking wif u oni...",
        "WINNER!! As a valued network customer you have been selected to receivea £900 prize reward!",
        "I'm gonna be home soon and i don't want to talk about this stuff anymore tonight, k?"
    ]
    results = predict_batch_spam(test_texts)
    print("\n批量预测结果:")
    for res in results:
        print(res)


if __name__ == "__main__":
    main()

162
src/models/train.py Normal file
View File

@ -0,0 +1,162 @@
import os
import joblib
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from lightgbm import LGBMClassifier
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
from typing import Tuple, Dict, Any
import polars as pl
from dotenv import load_dotenv
# Load environment variables
load_dotenv()
# Model save directory: taken from the MODEL_SAVE_DIR environment variable, defaulting to "./models"
MODEL_SAVE_DIR = os.getenv("MODEL_SAVE_DIR", "./models")
def build_tfidf_vectorizer() -> TfidfVectorizer:
    """Create the TF-IDF vectorizer shared by both model pipelines.

    Returns:
        An unfitted ``TfidfVectorizer`` configured with English stop words,
        at most 10000 features, unigrams and bigrams, lowercasing, and a
        token pattern matching runs of letters, digits and underscores.
    """
    settings = {
        "stop_words": "english",
        "max_features": 10000,
        "ngram_range": (1, 2),
        "lowercase": True,
        # {1,} keeps single-character tokens.
        "token_pattern": r'\b[a-zA-Z0-9_]{1,}\b',
    }
    return TfidfVectorizer(**settings)
def build_logistic_regression() -> Pipeline:
    """Create the unfitted TF-IDF + logistic-regression pipeline.

    Returns:
        A ``Pipeline`` with the steps ``tfidf`` and ``lr``. These step names
        are the prefixes used in hyperparameter grids.
    """
    vectorizer = build_tfidf_vectorizer()
    classifier = LogisticRegression(
        random_state=42,
        max_iter=1000,
        class_weight="balanced"
    )
    steps = [("tfidf", vectorizer), ("lr", classifier)]
    return Pipeline(steps)
def build_lightgbm() -> Pipeline:
    """Create the unfitted TF-IDF + LightGBM pipeline.

    Returns:
        A ``Pipeline`` with the steps ``tfidf`` and ``lgbm``. These step
        names are the prefixes used in hyperparameter grids.
    """
    vectorizer = build_tfidf_vectorizer()
    # NOTE(review): feature_name is passed to the constructor here; confirm
    # that LGBMClassifier honours it there rather than only in fit().
    classifier = LGBMClassifier(
        random_state=42,
        class_weight="balanced",
        verbose=-1,
        feature_name="auto"
    )
    steps = [("tfidf", vectorizer), ("lgbm", classifier)]
    return Pipeline(steps)
def train_model(
    train_df: pl.DataFrame,
    model_type: str = "lightgbm",
    hyperparam_tune: bool = False
) -> Tuple[Pipeline, Dict[str, Any]]:
    """Fit one of the supported pipelines on the training frame.

    Args:
        train_df: Frame with ``clean_text`` (inputs) and ``label_num``
            (targets) columns.
        model_type: ``"logistic_regression"`` or ``"lightgbm"``.
        hyperparam_tune: When true, run a 5-fold grid search scored by
            macro F1 and return the best estimator; otherwise fit the
            pipeline with its default settings.

    Returns:
        ``(fitted_pipeline, best_params)``; ``best_params`` is an empty
        dict when no tuning was requested.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    features = train_df["clean_text"].to_list()
    targets = train_df["label_num"].to_list()

    # Pick the pipeline and its search grid.
    if model_type == "lightgbm":
        pipeline = build_lightgbm()
        param_grid = {
            "lgbm__n_estimators": [100, 200],
            "lgbm__learning_rate": [0.1, 0.2],
            "lgbm__max_depth": [3, 5, 7]
        }
    elif model_type == "logistic_regression":
        pipeline = build_logistic_regression()
        param_grid = {
            "lr__C": [0.1, 1.0, 10.0],
            "lr__solver": ["liblinear", "lbfgs"]
        }
    else:
        raise ValueError(f"不支持的模型类型: {model_type}")

    # Plain fit: no search, no tuned parameters to report.
    if not hyperparam_tune:
        pipeline.fit(features, targets)
        return pipeline, {}

    search = GridSearchCV(
        pipeline,
        param_grid,
        cv=5,
        scoring="f1_macro",
        n_jobs=-1
    )
    search.fit(features, targets)
    tuned_model = search.best_estimator_
    tuned_params = search.best_params_
    print(f"最佳参数: {tuned_params}")
    return tuned_model, tuned_params
def save_model(model: Pipeline, model_name: str) -> str:
    """Serialize a model into the model directory with joblib.

    Args:
        model: Pipeline to persist.
        model_name: File stem; the file is written as ``<model_name>.joblib``
            under ``MODEL_SAVE_DIR``.

    Returns:
        The path the model was written to.
    """
    # Create the target directory on first use.
    os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

    filename = f"{model_name}.joblib"
    destination = os.path.join(MODEL_SAVE_DIR, filename)
    joblib.dump(model, destination)

    print(f"模型已保存到: {destination}")
    return destination
def load_model(model_name: str) -> Pipeline:
    """Load a previously saved model from the model directory.

    Args:
        model_name: File stem; ``<model_name>.joblib`` is looked up under
            ``MODEL_SAVE_DIR``.

    Returns:
        The deserialized model.

    Raises:
        FileNotFoundError: If the model file does not exist.
    """
    source = os.path.join(MODEL_SAVE_DIR, f"{model_name}.joblib")

    if os.path.exists(source):
        # NOTE(review): joblib.load unpickles the file, so only model files
        # from a trusted source should be placed in the model directory.
        loaded = joblib.load(source)
        print(f"模型已从: {source} 加载")
        return loaded

    raise FileNotFoundError(f"模型文件不存在: {source}")
def main():
    """Manual entry point: train and save both models with default settings."""
    import sys
    import os

    # Put the project root (two levels above this file) on sys.path so the
    # absolute ``src.data`` import below resolves when run as a script.
    project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
    sys.path.append(project_root)
    from src.data import load_data, preprocess_data, split_data

    # Load and preprocess the data; only the training split is used here.
    processed_df = preprocess_data(load_data("./data/spam.csv"))
    train_df, _ = split_data(processed_df)

    # Each run: announce, fit without tuning, save under the model type name.
    runs = (
        ("训练Logistic Regression模型...", "logistic_regression"),
        ("\n训练LightGBM模型...", "lightgbm"),
    )
    for banner, model_type in runs:
        print(banner)
        fitted, _ = train_model(train_df, model_type=model_type)
        save_model(fitted, model_type)

    print("\n模型训练完成!")


if __name__ == "__main__":
    main()