feat(模型): 添加模型训练、评估和预测功能
- 实现模型训练模块,支持逻辑回归和LightGBM - 添加模型评估功能,包括分类报告和混淆矩阵 - 实现预测功能,支持单条、批量及DataFrame输入 - 更新.gitignore,移除models目录忽略 - 添加模型评估结果文件
This commit is contained in:
parent
e081ed9329
commit
48013e93bd
2
.gitignore
vendored
2
.gitignore
vendored
@ -3,8 +3,6 @@
|
||||
env/
|
||||
venv/
|
||||
|
||||
# Model files
|
||||
models/
|
||||
|
||||
# Environment variables
|
||||
.env
|
||||
|
||||
BIN
models/lightgbm.joblib
Normal file
BIN
models/lightgbm.joblib
Normal file
Binary file not shown.
8
models/lightgbm_classification_report.txt
Normal file
8
models/lightgbm_classification_report.txt
Normal file
@ -0,0 +1,8 @@
|
||||
precision recall f1-score support
|
||||
|
||||
ham 0.98 0.99 0.99 969
|
||||
spam 0.93 0.89 0.91 145
|
||||
|
||||
accuracy 0.98 1114
|
||||
macro avg 0.96 0.94 0.95 1114
|
||||
weighted avg 0.98 0.98 0.98 1114
|
||||
BIN
models/lightgbm_confusion_matrix.png
Normal file
BIN
models/lightgbm_confusion_matrix.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 65 KiB |
BIN
models/logistic_regression.joblib
Normal file
BIN
models/logistic_regression.joblib
Normal file
Binary file not shown.
8
models/logistic_regression_classification_report.txt
Normal file
8
models/logistic_regression_classification_report.txt
Normal file
@ -0,0 +1,8 @@
|
||||
precision recall f1-score support
|
||||
|
||||
ham 0.99 0.99 0.99 969
|
||||
spam 0.92 0.94 0.93 145
|
||||
|
||||
accuracy 0.98 1114
|
||||
macro avg 0.95 0.96 0.96 1114
|
||||
weighted avg 0.98 0.98 0.98 1114
|
||||
BIN
models/logistic_regression_confusion_matrix.png
Normal file
BIN
models/logistic_regression_confusion_matrix.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 64 KiB |
42
src/models/__init__.py
Normal file
42
src/models/__init__.py
Normal file
@ -0,0 +1,42 @@
|
||||
from .train import (
|
||||
build_tfidf_vectorizer,
|
||||
build_logistic_regression,
|
||||
build_lightgbm,
|
||||
train_model,
|
||||
save_model,
|
||||
load_model
|
||||
)
|
||||
from .evaluate import (
|
||||
evaluate_model,
|
||||
print_evaluation_results,
|
||||
plot_confusion_matrix,
|
||||
save_evaluation_results,
|
||||
compare_models
|
||||
)
|
||||
from .predict import (
|
||||
predict_spam,
|
||||
predict_batch_spam,
|
||||
predict_from_df
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
# train.py
|
||||
"build_tfidf_vectorizer",
|
||||
"build_logistic_regression",
|
||||
"build_lightgbm",
|
||||
"train_model",
|
||||
"save_model",
|
||||
"load_model",
|
||||
|
||||
# evaluate.py
|
||||
"evaluate_model",
|
||||
"print_evaluation_results",
|
||||
"plot_confusion_matrix",
|
||||
"save_evaluation_results",
|
||||
"compare_models",
|
||||
|
||||
# predict.py
|
||||
"predict_spam",
|
||||
"predict_batch_spam",
|
||||
"predict_from_df"
|
||||
]
|
||||
159
src/models/evaluate.py
Normal file
159
src/models/evaluate.py
Normal file
@ -0,0 +1,159 @@
|
||||
from sklearn.metrics import (
|
||||
accuracy_score,
|
||||
precision_score,
|
||||
recall_score,
|
||||
f1_score,
|
||||
classification_report,
|
||||
confusion_matrix
|
||||
)
|
||||
import matplotlib.pyplot as plt
|
||||
import seaborn as sns
|
||||
import os
|
||||
from typing import Dict, Any, List
|
||||
import polars as pl
|
||||
|
||||
|
||||
|
||||
def evaluate_model(
|
||||
model,
|
||||
test_df: pl.DataFrame
|
||||
) -> Dict[str, Any]:
|
||||
"""评估模型"""
|
||||
# 准备测试数据
|
||||
X_test = test_df["clean_text"].to_list()
|
||||
y_test = test_df["label_num"].to_list()
|
||||
|
||||
# 预测
|
||||
y_pred = model.predict(X_test)
|
||||
y_pred_proba = model.predict_proba(X_test)[:, 1]
|
||||
|
||||
# 计算指标
|
||||
accuracy = accuracy_score(y_test, y_pred)
|
||||
precision = precision_score(y_test, y_pred, average="macro")
|
||||
recall = recall_score(y_test, y_pred, average="macro")
|
||||
f1 = f1_score(y_test, y_pred, average="macro")
|
||||
|
||||
# 生成分类报告
|
||||
class_report = classification_report(y_test, y_pred, target_names=["ham", "spam"])
|
||||
|
||||
# 生成混淆矩阵
|
||||
conf_matrix = confusion_matrix(y_test, y_pred)
|
||||
|
||||
# 收集结果
|
||||
results = {
|
||||
"accuracy": accuracy,
|
||||
"precision": precision,
|
||||
"recall": recall,
|
||||
"f1": f1,
|
||||
"class_report": class_report,
|
||||
"conf_matrix": conf_matrix,
|
||||
"y_test": y_test,
|
||||
"y_pred": y_pred,
|
||||
"y_pred_proba": y_pred_proba
|
||||
}
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def print_evaluation_results(results: Dict[str, Any]) -> None:
|
||||
"""打印评估结果"""
|
||||
print(f"Accuracy: {results['accuracy']:.4f}")
|
||||
print(f"Precision (macro): {results['precision']:.4f}")
|
||||
print(f"Recall (macro): {results['recall']:.4f}")
|
||||
print(f"F1 Score (macro): {results['f1']:.4f}")
|
||||
print("\nClassification Report:")
|
||||
print(results['class_report'])
|
||||
|
||||
|
||||
def plot_confusion_matrix(
|
||||
conf_matrix: List[List[int]],
|
||||
save_path: str = None
|
||||
) -> None:
|
||||
"""绘制混淆矩阵"""
|
||||
plt.figure(figsize=(8, 6))
|
||||
sns.heatmap(
|
||||
conf_matrix,
|
||||
annot=True,
|
||||
fmt="d",
|
||||
cmap="Blues",
|
||||
xticklabels=["ham", "spam"],
|
||||
yticklabels=["ham", "spam"]
|
||||
)
|
||||
plt.xlabel("Predicted Label")
|
||||
plt.ylabel("True Label")
|
||||
plt.title("Confusion Matrix")
|
||||
|
||||
if save_path:
|
||||
os.makedirs(os.path.dirname(save_path), exist_ok=True)
|
||||
plt.savefig(save_path, dpi=300, bbox_inches="tight")
|
||||
print(f"混淆矩阵已保存到: {save_path}")
|
||||
|
||||
plt.close()
|
||||
|
||||
|
||||
def save_evaluation_results(
|
||||
results: Dict[str, Any],
|
||||
model_name: str,
|
||||
save_dir: str = "./models"
|
||||
) -> None:
|
||||
"""保存评估结果"""
|
||||
# 确保保存目录存在
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
|
||||
# 保存分类报告
|
||||
report_path = os.path.join(save_dir, f"{model_name}_classification_report.txt")
|
||||
with open(report_path, "w") as f:
|
||||
f.write(results["class_report"])
|
||||
|
||||
# 保存混淆矩阵
|
||||
conf_matrix_path = os.path.join(save_dir, f"{model_name}_confusion_matrix.png")
|
||||
plot_confusion_matrix(results["conf_matrix"], conf_matrix_path)
|
||||
|
||||
print(f"评估结果已保存到: {save_dir}")
|
||||
|
||||
|
||||
def compare_models(
|
||||
models: Dict[str, Any],
|
||||
test_df: pl.DataFrame
|
||||
) -> Dict[str, Dict[str, Any]]:
|
||||
"""比较多个模型"""
|
||||
comparison_results = {}
|
||||
|
||||
for model_name, model in models.items():
|
||||
print(f"\n=== 评估 {model_name} 模型 ===")
|
||||
results = evaluate_model(model, test_df)
|
||||
print_evaluation_results(results)
|
||||
save_evaluation_results(results, model_name)
|
||||
comparison_results[model_name] = results
|
||||
|
||||
return comparison_results
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数,用于测试"""
|
||||
from .train import load_model
|
||||
from ..data import load_data, preprocess_data, split_data
|
||||
|
||||
# 加载和预处理数据
|
||||
df = load_data("../data/spam.csv")
|
||||
processed_df = preprocess_data(df)
|
||||
_, test_df = split_data(processed_df)
|
||||
|
||||
# 加载模型
|
||||
try:
|
||||
lr_model = load_model("logistic_regression")
|
||||
lgbm_model = load_model("lightgbm")
|
||||
|
||||
# 比较模型
|
||||
models = {
|
||||
"logistic_regression": lr_model,
|
||||
"lightgbm": lgbm_model
|
||||
}
|
||||
|
||||
compare_models(models, test_df)
|
||||
except FileNotFoundError:
|
||||
print("模型文件不存在,请先运行train.py训练模型")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
135
src/models/predict.py
Normal file
135
src/models/predict.py
Normal file
@ -0,0 +1,135 @@
|
||||
from typing import List, Dict, Any
|
||||
import polars as pl
|
||||
import warnings
|
||||
from sklearn.exceptions import DataConversionWarning
|
||||
from .train import load_model
|
||||
from ..data import clean_text
|
||||
|
||||
# 忽略特征名称不匹配的警告
|
||||
warnings.filterwarnings("ignore", message="X does not have valid feature names")
|
||||
|
||||
|
||||
def predict_spam(
|
||||
text: str,
|
||||
model_name: str = "lightgbm"
|
||||
) -> Dict[str, Any]:
|
||||
"""预测单条短信是否为垃圾短信"""
|
||||
# 加载模型
|
||||
model = load_model(model_name)
|
||||
|
||||
# 清洗文本
|
||||
cleaned_text = clean_text(text)
|
||||
|
||||
# 预测
|
||||
prediction = model.predict([cleaned_text])[0]
|
||||
probability = model.predict_proba([cleaned_text])[0]
|
||||
|
||||
# 转换结果
|
||||
label = "spam" if prediction == 1 else "ham"
|
||||
|
||||
return {
|
||||
"original_text": text,
|
||||
"cleaned_text": cleaned_text,
|
||||
"label": label,
|
||||
"label_num": int(prediction),
|
||||
"probability": {
|
||||
"ham": float(probability[0]),
|
||||
"spam": float(probability[1])
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def predict_batch_spam(
|
||||
texts: List[str],
|
||||
model_name: str = "lightgbm"
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""批量预测短信是否为垃圾短信"""
|
||||
# 加载模型
|
||||
model = load_model(model_name)
|
||||
|
||||
# 清洗文本
|
||||
cleaned_texts = [clean_text(text) for text in texts]
|
||||
|
||||
# 预测
|
||||
predictions = model.predict(cleaned_texts)
|
||||
probabilities = model.predict_proba(cleaned_texts)
|
||||
|
||||
# 转换结果
|
||||
results = []
|
||||
for i, (text, cleaned_text, prediction, probability) in enumerate(zip(texts, cleaned_texts, predictions, probabilities)):
|
||||
label = "spam" if prediction == 1 else "ham"
|
||||
results.append({
|
||||
"id": i,
|
||||
"original_text": text,
|
||||
"cleaned_text": cleaned_text,
|
||||
"label": label,
|
||||
"label_num": int(prediction),
|
||||
"probability": {
|
||||
"ham": float(probability[0]),
|
||||
"spam": float(probability[1])
|
||||
}
|
||||
})
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def predict_from_df(
|
||||
df: pl.DataFrame,
|
||||
text_column: str = "text",
|
||||
model_name: str = "lightgbm"
|
||||
) -> pl.DataFrame:
|
||||
"""从DataFrame中预测短信是否为垃圾短信"""
|
||||
# 清洗文本
|
||||
df = df.with_columns(
|
||||
pl.col(text_column).map_elements(clean_text, return_dtype=pl.String).alias("clean_text")
|
||||
)
|
||||
|
||||
# 准备预测数据
|
||||
texts = df["clean_text"].to_list()
|
||||
|
||||
# 预测
|
||||
model = load_model(model_name)
|
||||
predictions = model.predict(texts)
|
||||
probabilities = model.predict_proba(texts)
|
||||
|
||||
# 添加预测结果
|
||||
df = df.with_columns(
|
||||
pl.Series("label_num", predictions).cast(pl.Int64),
|
||||
pl.Series("ham_prob", [p[0] for p in probabilities]).cast(pl.Float64),
|
||||
pl.Series("spam_prob", [p[1] for p in probabilities]).cast(pl.Float64)
|
||||
)
|
||||
|
||||
# 添加标签
|
||||
df = df.with_columns(
|
||||
pl.when(pl.col("label_num") == 1)
|
||||
.then("spam")
|
||||
.otherwise("ham")
|
||||
.alias("label")
|
||||
)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数,用于测试"""
|
||||
# 测试单条预测
|
||||
test_text = "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's"
|
||||
result = predict_spam(test_text)
|
||||
print("单条预测结果:")
|
||||
print(result)
|
||||
|
||||
# 测试批量预测
|
||||
test_texts = [
|
||||
"Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005.",
|
||||
"Ok lar... Joking wif u oni...",
|
||||
"WINNER!! As a valued network customer you have been selected to receivea <20>900 prize reward!",
|
||||
"I'm gonna be home soon and i don't want to talk about this stuff anymore tonight, k?"
|
||||
]
|
||||
results = predict_batch_spam(test_texts)
|
||||
print("\n批量预测结果:")
|
||||
for res in results:
|
||||
print(res)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
162
src/models/train.py
Normal file
162
src/models/train.py
Normal file
@ -0,0 +1,162 @@
|
||||
import os
|
||||
import joblib
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from lightgbm import LGBMClassifier
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.model_selection import GridSearchCV
|
||||
from typing import Tuple, Dict, Any
|
||||
import polars as pl
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
# 获取模型保存目录
|
||||
MODEL_SAVE_DIR = os.getenv("MODEL_SAVE_DIR", "./models")
|
||||
|
||||
|
||||
def build_tfidf_vectorizer() -> TfidfVectorizer:
|
||||
"""构建TF-IDF向量化器"""
|
||||
return TfidfVectorizer(
|
||||
stop_words="english",
|
||||
max_features=10000,
|
||||
ngram_range=(1, 2),
|
||||
lowercase=True,
|
||||
token_pattern=r'\b[a-zA-Z0-9_]{1,}\b'
|
||||
)
|
||||
|
||||
|
||||
def build_logistic_regression() -> Pipeline:
|
||||
"""构建Logistic Regression模型流水线"""
|
||||
pipeline = Pipeline([
|
||||
("tfidf", build_tfidf_vectorizer()),
|
||||
("lr", LogisticRegression(
|
||||
random_state=42,
|
||||
max_iter=1000,
|
||||
class_weight="balanced"
|
||||
))
|
||||
])
|
||||
return pipeline
|
||||
|
||||
|
||||
def build_lightgbm() -> Pipeline:
|
||||
"""构建LightGBM模型流水线"""
|
||||
pipeline = Pipeline([
|
||||
("tfidf", build_tfidf_vectorizer()),
|
||||
("lgbm", LGBMClassifier(
|
||||
random_state=42,
|
||||
class_weight="balanced",
|
||||
verbose=-1,
|
||||
feature_name="auto"
|
||||
))
|
||||
])
|
||||
return pipeline
|
||||
|
||||
|
||||
def train_model(
|
||||
train_df: pl.DataFrame,
|
||||
model_type: str = "lightgbm",
|
||||
hyperparam_tune: bool = False
|
||||
) -> Tuple[Pipeline, Dict[str, Any]]:
|
||||
"""训练模型"""
|
||||
# 准备训练数据
|
||||
X_train = train_df["clean_text"].to_list()
|
||||
y_train = train_df["label_num"].to_list()
|
||||
|
||||
# 选择模型
|
||||
if model_type == "logistic_regression":
|
||||
pipeline = build_logistic_regression()
|
||||
param_grid = {
|
||||
"lr__C": [0.1, 1.0, 10.0],
|
||||
"lr__solver": ["liblinear", "lbfgs"]
|
||||
}
|
||||
elif model_type == "lightgbm":
|
||||
pipeline = build_lightgbm()
|
||||
param_grid = {
|
||||
"lgbm__n_estimators": [100, 200],
|
||||
"lgbm__learning_rate": [0.1, 0.2],
|
||||
"lgbm__max_depth": [3, 5, 7]
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"不支持的模型类型: {model_type}")
|
||||
|
||||
# 超参数调优
|
||||
if hyperparam_tune:
|
||||
grid_search = GridSearchCV(
|
||||
pipeline,
|
||||
param_grid,
|
||||
cv=5,
|
||||
scoring="f1_macro",
|
||||
n_jobs=-1
|
||||
)
|
||||
grid_search.fit(X_train, y_train)
|
||||
best_model = grid_search.best_estimator_
|
||||
best_params = grid_search.best_params_
|
||||
print(f"最佳参数: {best_params}")
|
||||
else:
|
||||
best_model = pipeline
|
||||
best_model.fit(X_train, y_train)
|
||||
best_params = {}
|
||||
|
||||
return best_model, best_params
|
||||
|
||||
|
||||
def save_model(model: Pipeline, model_name: str) -> str:
|
||||
"""保存模型"""
|
||||
# 确保模型保存目录存在
|
||||
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
|
||||
|
||||
# 构建模型文件路径
|
||||
model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}.joblib")
|
||||
|
||||
# 保存模型
|
||||
joblib.dump(model, model_path)
|
||||
print(f"模型已保存到: {model_path}")
|
||||
|
||||
return model_path
|
||||
|
||||
|
||||
def load_model(model_name: str) -> Pipeline:
|
||||
"""加载模型"""
|
||||
# 构建模型文件路径
|
||||
model_path = os.path.join(MODEL_SAVE_DIR, f"{model_name}.joblib")
|
||||
|
||||
# 加载模型
|
||||
if not os.path.exists(model_path):
|
||||
raise FileNotFoundError(f"模型文件不存在: {model_path}")
|
||||
|
||||
model = joblib.load(model_path)
|
||||
print(f"模型已从: {model_path} 加载")
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数,用于测试"""
|
||||
import sys
|
||||
import os
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
|
||||
|
||||
from src.data import load_data, preprocess_data, split_data
|
||||
|
||||
# 加载和预处理数据
|
||||
df = load_data("./data/spam.csv")
|
||||
processed_df = preprocess_data(df)
|
||||
train_df, test_df = split_data(processed_df)
|
||||
|
||||
print("训练Logistic Regression模型...")
|
||||
lr_model, lr_params = train_model(train_df, model_type="logistic_regression")
|
||||
save_model(lr_model, "logistic_regression")
|
||||
|
||||
print("\n训练LightGBM模型...")
|
||||
lgbm_model, lgbm_params = train_model(train_df, model_type="lightgbm")
|
||||
save_model(lgbm_model, "lightgbm")
|
||||
|
||||
print("\n模型训练完成!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Loading…
Reference in New Issue
Block a user