CourseDesign/src/train.py

117 lines
3.7 KiB
Python
Raw Normal View History

import sys
import os
# 修复模块路径问题,让你可以在根目录直接 python src/train.py
sys.path.append(os.getcwd())
import joblib
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.compose import ColumnTransformer
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import classification_report, accuracy_score, f1_score
from src.data import generate_data, preprocess_data
# Directory that receives trained artifacts. This is a relative path, so it is
# resolved against the current working directory at save time (presumably the
# project root, since the script is meant to be run as `python src/train.py`).
MODELS_DIR = "models"
# File the fitted pipeline is serialized to with joblib.dump in train().
MODEL_PATH = os.path.join(MODELS_DIR, "model.pkl")
def get_pipeline(model_type="rf", numeric_features=None, categorical_features=None):
    """Build the standard scikit-learn preprocessing + classification pipeline.

    Steps:
      1. Numeric features     -> impute missing values (mean) -> standardize.
      2. Categorical features -> impute missing values (most frequent) -> one-hot encode.
      3. Classifier           -> LogisticRegression or RandomForestClassifier.

    Args:
        model_type: ``"lr"`` selects LogisticRegression; any other value
            (including the default ``"rf"``) selects RandomForestClassifier.
        numeric_features: Iterable of numeric column names to impute and scale.
            Defaults to ``["study_hours", "sleep_hours", "attendance_rate",
            "stress_level"]``.
        categorical_features: Iterable of categorical column names to impute and
            one-hot encode. Defaults to ``["study_type"]``.

    Returns:
        An unfitted ``Pipeline`` with steps ``"preprocessor"`` and ``"classifier"``.
        Input columns not listed in either feature list are dropped by the
        ColumnTransformer (its default ``remainder="drop"``).
    """
    # None sentinels (instead of list defaults) avoid sharing one mutable list
    # across calls; copying also decouples the pipeline from the caller's object.
    if numeric_features is None:
        numeric_features = ["study_hours", "sleep_hours", "attendance_rate", "stress_level"]
    if categorical_features is None:
        categorical_features = ["study_type"]
    numeric_features = list(numeric_features)
    categorical_features = list(categorical_features)

    # Numeric branch: mean imputation, then zero-mean/unit-variance scaling.
    numeric_transformer = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="mean")),
        ("scaler", StandardScaler()),
    ])

    # Categorical branch: mode imputation, then one-hot encoding. Categories not
    # seen during fit are encoded as all zeros instead of raising.
    categorical_transformer = Pipeline(steps=[
        ("imputer", SimpleImputer(strategy="most_frequent")),
        ("onehot", OneHotEncoder(handle_unknown="ignore")),
    ])

    # Route each column group to its branch.
    preprocessor = ColumnTransformer(
        transformers=[
            ("num", numeric_transformer, numeric_features),
            ("cat", categorical_transformer, categorical_features),
        ]
    )

    # Choose the classifier; fixed random_state keeps runs reproducible.
    if model_type == "lr":
        clf = LogisticRegression(random_state=42)
    else:
        clf = RandomForestClassifier(n_estimators=500, max_depth=5, random_state=42)

    return Pipeline(steps=[
        ("preprocessor", preprocessor),
        ("classifier", clf),
    ])
def _fit_and_score(pipe, X_train, y_train, X_test, y_test):
    """Fit ``pipe`` on the training split; return (test predictions, test F1)."""
    pipe.fit(X_train, y_train)
    y_pred = pipe.predict(X_test)
    return y_pred, f1_score(y_test, y_pred)


def _report_errors(X_test, y_test, y_pred):
    """Print the number of misclassified test samples and preview up to three."""
    test_df = X_test.copy()
    # y_test shares X_test's index (both come from the same train_test_split),
    # so this assignment aligns row-by-row; y_pred is aligned positionally.
    test_df["True Label"] = y_test
    test_df["Pred Label"] = y_pred
    errors = test_df[test_df["True Label"] != test_df["Pred Label"]]
    print(f"总计错误样本数: {len(errors)}")
    if len(errors) > 0:
        print("典型错误样本预览:")
        print(errors.head(3))


def train():
    """Train, compare, evaluate and persist the student pass/fail classifier.

    Workflow:
      1. Generate 2000 samples, preprocess them, and split 80/20
         (random_state=42) into train/test sets with target column ``is_pass``.
      2. Fit a LogisticRegression baseline and a RandomForest target pipeline
         and report each one's test F1.
      3. Keep whichever model has the higher test F1 (ties go to RandomForest)
         and print its classification report.
      4. Print an error analysis of that model's misclassified test samples.
      5. Save the chosen pipeline to ``MODEL_PATH`` with joblib.

    Returns:
        The selected, fitted ``Pipeline`` (also written to ``MODEL_PATH``).
    """
    print(">>> 1. 数据准备")
    df = generate_data(n_samples=2000)
    df = preprocess_data(df)
    X = df.drop(columns=["is_pass"])
    y = df["is_pass"]
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42
    )
    print(f"训练集大小: {X_train.shape}, 测试集大小: {X_test.shape}")

    print("\n>>> 2. 模型训练与对比")
    # Model A: logistic regression (baseline).
    pipe_lr = get_pipeline("lr")
    y_pred_lr, f1_lr = _fit_and_score(pipe_lr, X_train, y_train, X_test, y_test)
    print(f"[Baseline - LogisticRegression] F1: {f1_lr:.4f}")
    # Model B: random forest (target).
    pipe_rf = get_pipeline("rf")
    y_pred_rf, f1_rf = _fit_and_score(pipe_rf, X_train, y_train, X_test, y_test)
    print(f"[Target - RandomForest] F1: {f1_rf:.4f}")

    print("\n>>> 3. 按 F1 选择最佳模型并进行详细评估")
    # Keep the model that actually scored higher; ties favour the RF target.
    # NOTE(review): selecting on the test set makes the reported test metrics
    # slightly optimistic; a validation split or CV would avoid that.
    if f1_rf >= f1_lr:
        best_name, best_model, best_pred = "RandomForest", pipe_rf, y_pred_rf
    else:
        best_name, best_model, best_pred = "LogisticRegression", pipe_lr, y_pred_lr
    print(f"最佳模型: {best_name}")
    print(classification_report(y_test, best_pred))

    print("\n>>> 4. 误差分析 (Error Analysis)")
    _report_errors(X_test, y_test, best_pred)

    print("\n>>> 5. 保存最佳模型")
    os.makedirs(MODELS_DIR, exist_ok=True)
    joblib.dump(best_model, MODEL_PATH)
    print(f"模型 Pipeline 已完整保存至 {MODEL_PATH}")
    return best_model
if __name__ == "__main__":
    # Run the full training workflow only when executed as a script, not on import.
    train()