telco-customer-churn-predic.../visualization.py

165 lines
7.4 KiB
Python
Raw Permalink Normal View History

# visualization.py - 客户流失预测模型可视化(直接运行即可)
import joblib
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.metrics import confusion_matrix, roc_curve, auc
# -------------------------- 基础设置(解决中文显示、图表样式)--------------------------
plt.rcParams['font.sans-serif'] = ['SimHei'] # 中文显示Windows
# 如果是Mac/Linux替换为plt.rcParams['font.sans-serif'] = ['Arial Unicode MS']
plt.rcParams['axes.unicode_minus'] = False # 解决负号显示问题
plt.style.use('seaborn-v0_8-whitegrid') # 图表样式(清爽易读)
# -------------------------- 1. 加载模型和数据(复用已有逻辑)--------------------------
def load_model_and_data():
"""加载训练好的模型和测试集数据"""
# 加载模型(确保模型文件路径正确)
try:
model = joblib.load('telco_churn_model.pkl')
print("✅ 模型加载成功")
except FileNotFoundError:
raise FileNotFoundError("❌ 未找到模型文件!请先运行 src/model.py 训练模型")
# 加载并切分数据(复用 src/data.py 的逻辑,避免重复代码)
try:
from src.data import load_data, split_data
df = load_data()
X_train, X_test, y_train, y_test = split_data(df)
print("✅ 测试集数据加载成功共1409条")
return model, X_test, y_test
except ImportError:
raise ImportError("❌ 未找到 src/data.py请确保项目目录结构正确")
# -------------------------- 2. 特征重要性TOP10可视化核心业务洞察--------------------------
def plot_feature_importance(model):
"""绘制特征重要性TOP10图表"""
# 提取预处理后的特征名和重要性得分
preprocessor = model.named_steps['preprocessor']
feature_names = preprocessor.get_feature_names_out()
feature_importance = model.named_steps['classifier'].feature_importances_
# 整理数据(排序+取TOP10简化特征名方便显示
feature_df = pd.DataFrame({
'特征名': feature_names,
'重要性': feature_importance
}).sort_values('重要性', ascending=False).head(10)
# 简化特征名(原特征名太长,图表显示优化)
feature_name_map = {
'tenure': '客户在网时长',
'TotalCharges': '总消费金额',
'MonthlyCharges': '月消费金额',
'Contract_Two year': '合约期-2年',
'InternetService_Fiber optic': '网络类型-光纤',
'PaymentMethod_Electronic check': '支付方式-电子支票',
'Contract_One year': '合约期-1年',
'OnlineSecurity_Yes': '在线安全服务-有',
'TechSupport_Yes': '技术支持-有',
'PaperlessBilling_Yes': '电子账单-是'
}
feature_df['简化特征名'] = feature_df['特征名'].map(lambda x: feature_name_map.get(x, x[:15])) # 兜底避免报错
# 绘制水平条形图(更易读)
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(
x='重要性', y='简化特征名', data=feature_df,
palette='viridis_r', ax=ax # 颜色渐变(反向,重要性越高颜色越深)
)
# 图表美化(标题、标签、数值标注)
ax.set_title('客户流失预测 - 特征重要性TOP10', fontsize=16, fontweight='bold', pad=20)
ax.set_xlabel('重要性得分', fontsize=12)
ax.set_ylabel('特征', fontsize=12)
ax.tick_params(axis='y', labelsize=10)
# 在条形图上添加数值(直观展示得分)
for i, v in enumerate(feature_df['重要性']):
ax.text(v + 0.002, i, f'{v:.3f}', va='center', fontsize=9)
# 保存图表高清可直接插入PPT
plt.tight_layout()
plt.savefig('特征重要性TOP10.png', dpi=300, bbox_inches='tight')
print("✅ 特征重要性图表已保存为特征重要性TOP10.png")
# -------------------------- 3. 混淆矩阵可视化(模型效果直观展示)--------------------------
def plot_confusion_matrix(model, X_test, y_test):
"""绘制混淆矩阵(展示模型预测准确率、漏判/误判情况)"""
# 生成预测结果
y_pred = model.predict(X_test)
# 计算混淆矩阵
cm = confusion_matrix(y_test, y_pred)
# 混淆矩阵标签0=未流失1=流失)
labels = ['未流失', '流失']
# 绘制热力图
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(
cm, annot=True, fmt='d', cmap='Blues', # fmt='d' 显示整数
xticklabels=labels, yticklabels=labels, ax=ax,
cbar_kws={'label': '客户数量'} # 颜色条标签
)
# 图表美化
ax.set_title('客户流失预测 - 混淆矩阵', fontsize=16, fontweight='bold', pad=20)
ax.set_xlabel('预测标签', fontsize=12)
ax.set_ylabel('真实标签', fontsize=12)
# 添加统计信息(准确率、流失识别率)
total = cm.sum()
accuracy = (cm[0,0] + cm[1,1]) / total
recall_churn = cm[1,1] / (cm[1,0] + cm[1,1]) # 流失客户识别率
ax.text(0.5, -0.15, f'准确率:{accuracy:.3f} | 流失识别率:{recall_churn:.3f}',
ha='center', transform=ax.transAxes, fontsize=11)
# 保存图表
plt.tight_layout()
plt.savefig('混淆矩阵.png', dpi=300, bbox_inches='tight')
print("✅ 混淆矩阵图表已保存为:混淆矩阵.png")
# -------------------------- 4. 可选ROC曲线可视化进阶模型评估--------------------------
def plot_roc_curve(model, X_test, y_test):
"""绘制ROC曲线展示模型区分能力AUC值"""
# 生成预测概率
y_pred_proba = model.predict_proba(X_test)[:, 1] # 取流失1类的概率
fpr, tpr, _ = roc_curve(y_test, y_pred_proba)
roc_auc = auc(fpr, tpr)
# 绘制ROC曲线
fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC曲线 (AUC = {roc_auc:.3f})')
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', label='随机猜测')
# 图表美化
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_title('客户流失预测 - ROC曲线', fontsize=16, fontweight='bold', pad=20)
ax.set_xlabel('假阳性率(误判为流失)', fontsize=12)
ax.set_ylabel('真阳性率(正确识别流失)', fontsize=12)
ax.legend(loc="lower right", fontsize=11)
ax.grid(True, alpha=0.3)
# 保存图表
plt.tight_layout()
plt.savefig('ROC曲线.png', dpi=300, bbox_inches='tight')
print("✅ ROC曲线图表已保存为ROC曲线.png")
# -------------------------- 主函数(一键运行所有可视化)--------------------------
if __name__ == "__main__":
print("🚀 开始生成可视化图表...")
try:
# 加载模型和数据
model, X_test, y_test = load_model_and_data()
# 生成3张图表特征重要性 + 混淆矩阵 + ROC曲线
plot_feature_importance(model)
plot_confusion_matrix(model, X_test, y_test)
plot_roc_curve(model, X_test, y_test)
print("\n🎉 所有图表生成完成!文件保存在项目根目录:")
print("1. 特征重要性TOP10.png业务洞察核心")
print("2. 混淆矩阵.png模型效果直观展示")
print("3. ROC曲线.png进阶评估AUC值")
except Exception as e:
print(f"\n❌ 生成失败:{str(e)}")