# G09-BankMarketing/train.py
# Source-viewer metadata: 2026-01-16 19:28:30 +08:00, 91 lines, 3.2 KiB, Python
#
# NOTE: the source viewer this was exported from warned that the file contains
# ambiguous Unicode characters (characters that might be confused with other
# characters). Presumably these are in the Chinese comments and log messages
# and are intentional; verify against the original file if it matters.

import polars as pl
import pandas as pd
import lightgbm as lgb
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, roc_auc_score, f1_score
from sklearn.preprocessing import LabelEncoder, StandardScaler
import joblib
import logging
import os
from src.data import load_and_clean_data
# Logging setup: root logger at INFO, each record formatted as
# "<time> - <LEVEL> - <message>" (no-op if the root logger already has handlers).
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)  # module logger used by train_models
def _split_feature_types(X):
    """Partition the columns of *X* into (categorical, numeric) name lists.

    Numeric means any pandas numeric dtype (all int/uint/float widths,
    including nullable ones), not only int64/float64. Every other column
    (object, category, string, bool, datetime, ...) is treated as
    categorical, so no column is silently left out of both lists.

    Returns:
        ``(cat_cols, num_cols)``, each a list of column names in X's order.
    """
    num_cols = X.select_dtypes(include="number").columns.tolist()
    numeric = set(num_cols)
    cat_cols = [col for col in X.columns if col not in numeric]
    return cat_cols, num_cols


def _encode_categoricals(X, cat_cols):
    """Label-encode each column of *X* named in *cat_cols*, in place.

    Values are cast to ``str`` first so mixed types / missing values get a
    consistent label. Returns a dict mapping column name -> fitted
    LabelEncoder, which is saved with the artifacts for later reuse.
    """
    encoders = {}
    for col in cat_cols:
        le = LabelEncoder()
        X[col] = le.fit_transform(X[col].astype(str))
        encoders[col] = le
    return encoders


def _log_metrics(label, y_true, y_pred, y_prob):
    """Log F1 (from hard predictions) and ROC AUC (from positive-class scores)."""
    logger.info(f"{label} F1: {f1_score(y_true, y_pred):.4f}")
    logger.info(f"{label} AUC: {roc_auc_score(y_true, y_prob):.4f}")


def train_models(data_path="data/bank.csv", model_dir="models"):
    """Train a logistic-regression baseline and a LightGBM model, then save them.

    Steps: load the data via ``load_and_clean_data``, label-encode the
    non-numeric features, make a stratified 80/20 train/test split, fit both
    models, log their test-set F1 and ROC AUC, and dump the models plus the
    preprocessing objects to ``<model_dir>/model_artifacts.pkl`` with joblib.

    Args:
        data_path: Path handed to ``load_and_clean_data``.
        model_dir: Output directory; created if it does not exist.

    Returns:
        dict with keys ``lgb_model``, ``lr_model``, ``scaler`` (StandardScaler
        fitted on the training rows' numeric columns), ``encoders`` (column ->
        LabelEncoder), ``features`` (column order used for training),
        ``cat_cols`` and ``num_cols``.

    Raises:
        ValueError: from ``train_test_split`` if a target class has fewer
            than 2 rows (stratification impossible).
    """
    # 1. Load data and convert to pandas for sklearn / LightGBM.
    df = load_and_clean_data(data_path).to_pandas()

    # 2. Feature preprocessing: separate target, then split feature types.
    target_col = "target"
    X = df.drop(columns=[target_col])
    y = df[target_col]
    cat_cols, num_cols = _split_feature_types(X)
    # Label encoding keeps things simple: LightGBM could consume categories
    # directly, but LogisticRegression needs numeric input.
    encoders = _encode_categoricals(X, cat_cols)

    # 3. Train/test split, stratified so both parts keep the full data's class
    # ratio (otherwise test metrics are noisy when the target is imbalanced).
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )

    # 4. Baseline model: logistic regression on standardized numeric columns.
    # The scaler is fit on training rows only, so no test statistics leak in.
    logger.info("训练基线模型 (Logistic Regression)...")
    scaler = StandardScaler()
    X_train_scaled = X_train.copy()
    X_test_scaled = X_test.copy()
    X_train_scaled[num_cols] = scaler.fit_transform(X_train[num_cols])
    X_test_scaled[num_cols] = scaler.transform(X_test[num_cols])
    lr_model = LogisticRegression(max_iter=1000, random_state=42)
    lr_model.fit(X_train_scaled, y_train)
    # predict_proba column 1 is the score for classes_[1] (the larger label).
    _log_metrics(
        "Baseline",
        y_test,
        lr_model.predict(X_test_scaled),
        lr_model.predict_proba(X_test_scaled)[:, 1],
    )

    # 5. Advanced model: LightGBM on the unscaled (label-encoded) features.
    logger.info("训练进阶模型 (LightGBM)...")
    lgb_model = lgb.LGBMClassifier(
        n_estimators=100, learning_rate=0.05, random_state=42, verbose=-1
    )
    lgb_model.fit(X_train, y_train)
    _log_metrics(
        "LightGBM",
        y_test,
        lgb_model.predict(X_test),
        lgb_model.predict_proba(X_test)[:, 1],
    )

    # 6. Save models and metadata; exist_ok avoids the exists()/makedirs() race.
    os.makedirs(model_dir, exist_ok=True)
    artifacts = {
        "lgb_model": lgb_model,
        "lr_model": lr_model,
        "scaler": scaler,
        "encoders": encoders,
        "features": list(X.columns),
        "cat_cols": cat_cols,
        "num_cols": num_cols,
    }
    artifact_path = os.path.join(model_dir, "model_artifacts.pkl")
    joblib.dump(artifacts, artifact_path)
    logger.info(f"模型已保存至 {artifact_path}")
    return artifacts
if __name__ == "__main__":
    # Optional CLI overrides. SUPPRESS leaves omitted flags out of the parsed
    # namespace entirely, so a bare `python train.py` still calls
    # train_models() with the function's own defaults.
    import argparse

    parser = argparse.ArgumentParser(
        description="Train the LogisticRegression baseline and LightGBM models."
    )
    parser.add_argument(
        "--data-path",
        dest="data_path",
        default=argparse.SUPPRESS,
        help="input data path (default: train_models' data_path default)",
    )
    parser.add_argument(
        "--model-dir",
        dest="model_dir",
        default=argparse.SUPPRESS,
        help="output directory for model_artifacts.pkl "
        "(default: train_models' model_dir default)",
    )
    train_models(**vars(parser.parse_args()))