telco-customer-churn-predic.../save_model.py

31 lines
937 B
Python
Raw Permalink Normal View History

# save_model.py
import joblib
from src.data import load_data, split_data
from src.model import build_preprocessor
from sklearn.ensemble import RandomForestClassifier
from sklearn.pipeline import Pipeline
# Default location of the serialized pipeline (relative to the current working
# directory, matching the script's original behavior).
DEFAULT_MODEL_PATH = "telco_churn_model.pkl"


def build_model(preprocessor):
    """Return the tuned (unfitted) random-forest churn pipeline.

    Args:
        preprocessor: The preprocessing step produced by
            ``src.model.build_preprocessor``; it becomes the first pipeline stage.

    Returns:
        An unfitted ``sklearn.pipeline.Pipeline`` with steps
        ``preprocessor`` -> ``classifier``.
    """
    # Hyperparameters are the tuned values for this project's random forest.
    classifier = RandomForestClassifier(
        n_estimators=200,
        max_depth=15,
        min_samples_split=8,
        min_samples_leaf=4,
        # Per scikit-learn, 'balanced_subsample' recomputes class weights from
        # each tree's bootstrap sample.
        class_weight='balanced_subsample',
        random_state=42,  # fixed seed for reproducible training runs
        n_jobs=-1,  # use all available CPU cores
    )
    return Pipeline(steps=[
        ('preprocessor', preprocessor),
        ('classifier', classifier),
    ])


def main(output_path=DEFAULT_MODEL_PATH):
    """Train the tuned churn pipeline and save it with joblib.

    Loads the dataset via ``load_data``, splits it with ``split_data``, fits
    the pipeline on the training split only, then dumps it to *output_path*.

    Args:
        output_path: Destination file for the pickled pipeline (str or
            path-like). Missing parent directories are created.

    Returns:
        The fitted pipeline.
    """
    from pathlib import Path

    # Load the data.
    df = load_data()
    # Only the training split is used to fit; the held-out split is not used
    # in this script.
    X_train, _X_test, y_train, _y_test = split_data(df)
    preprocessor = build_preprocessor(X_train)

    # Train the tuned random forest model.
    model = build_model(preprocessor)
    model.fit(X_train, y_train)

    # Save the model; make sure the target directory exists first.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    joblib.dump(model, output_path)
    print(f"✅ 模型已保存为{output_path}")
    return model


if __name__ == "__main__":
    main()