git remote add origin http://hblu.top:3000/MachineLearning2025/G03-304.git
git push -u origin main
This commit is contained in:
commit
4f5d7d977b
1
ml_course_design/.env.example
Normal file
1
ml_course_design/.env.example
Normal file
@ -0,0 +1 @@
|
|||||||
|
DEEPSEEK_API_KEY=your-key-here
|
||||||
26
ml_course_design/.gitignore
vendored
Normal file
26
ml_course_design/.gitignore
vendored
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
# ===== Configuration files that should not be committed =====
|
||||||
|
.env
|
||||||
|
|
||||||
|
# ===== Python virtual environments =====
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
__pycache__/
|
||||||
|
*.pyc
|
||||||
|
*.pyo
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# ===== IDE settings =====
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
|
||||||
|
# ===== macOS system files =====
|
||||||
|
.DS_Store
|
||||||
|
|
||||||
|
# ===== Jupyter =====
|
||||||
|
.ipynb_checkpoints/
|
||||||
|
|
||||||
|
# ===== Large files (over 10MB should be excluded) =====
|
||||||
|
# Uncomment and add specific large files if needed
|
||||||
|
# data/large_dataset.csv
|
||||||
|
# models/large_model.pkl
|
||||||
373
ml_course_design/README.md
Normal file
373
ml_course_design/README.md
Normal file
@ -0,0 +1,373 @@
|
|||||||
|
# 客户流失预测系统
|
||||||
|
|
||||||
|
> **机器学习 (Python) 课程设计**
|
||||||
|
|
||||||
|
## 👥 团队成员
|
||||||
|
|
||||||
|
| 姓名 | 学号 | 贡献 |
|
||||||
|
|------|------|------|
|
||||||
|
| 黄迎 | 2311020109 | 数据处理、模型训练 |
|
||||||
|
| 龚士皓 | 2311020107 | Agent 开发、Streamlit |
|
||||||
|
| 金文磊 | 2311020110 | 测试、文档撰写 |
|
||||||
|
|
||||||
|
## 📋 项目概述
|
||||||
|
|
||||||
|
本项目是一个基于机器学习的电信客户流失预测系统,结合了智能Agent技术,能够通过自然语言交互和可视化界面提供客户流失风险预测服务。
|
||||||
|
|
||||||
|
### 项目目标
|
||||||
|
|
||||||
|
- 构建一个结构化数据集的分类/回归模型
|
||||||
|
- 实现一个能够理解自然语言的智能Agent
|
||||||
|
- 提供用户友好的可视化交互界面
|
||||||
|
|
||||||
|
### 技术栈
|
||||||
|
|
||||||
|
- **数据处理**: Polars + pandas
|
||||||
|
- **可视化**: Seaborn + Streamlit + Plotly
|
||||||
|
- **数据验证**: Pydantic + pandera
|
||||||
|
- **机器学习**: scikit-learn + LightGBM
|
||||||
|
- **智能Agent**: pydantic-ai
|
||||||
|
- **LLM服务**: DeepSeek
|
||||||
|
|
||||||
|
## 🚀 快速开始
|
||||||
|
|
||||||
|
### 1. 环境配置
|
||||||
|
|
||||||
|
#### 安装依赖
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 使用uv安装项目依赖
|
||||||
|
uv sync
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 配置API Key
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 复制环境变量示例文件
|
||||||
|
cp .env.example .env
|
||||||
|
|
||||||
|
# 编辑.env文件,配置DeepSeek API Key
|
||||||
|
# DEEPSEEK_API_KEY="your-key-here"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 运行应用
|
||||||
|
|
||||||
|
#### 方式A: 运行Streamlit演示应用
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run streamlit run src/streamlit_app.py
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 方式B: 运行智能Agent演示
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python src/agent_app.py
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 方式C: 运行模型训练脚本
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uv run python src/train.py
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 从任意目录运行(可选)
|
||||||
|
|
||||||
|
如果你想从项目根目录外运行应用,可以使用完整路径:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 运行智能Agent演示
|
||||||
|
uv run python "path/to/ml_course_design/src/agent_app.py"
|
||||||
|
|
||||||
|
# 运行模型训练脚本
|
||||||
|
uv run python "path/to/ml_course_design/src/train.py"
|
||||||
|
|
||||||
|
# 运行Streamlit演示应用
|
||||||
|
uv run -C "path/to/ml_course_design" streamlit run src/streamlit_app.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## 📊 数据说明
|
||||||
|
|
||||||
|
### 数据集
|
||||||
|
|
||||||
|
本项目使用了Kaggle上的**Telco Customer Churn**数据集,包含了7043名电信客户的信息和流失状态。
|
||||||
|
|
||||||
|
### 数据字段
|
||||||
|
|
||||||
|
- **客户信息**: 性别、年龄、是否有伴侣/家属、在网时长
|
||||||
|
- **服务信息**: 电话服务、互联网服务、在线安全、云备份等
|
||||||
|
- **合同信息**: 合同类型、支付方式、月费用、总费用
|
||||||
|
- **目标变量**: 是否流失(Churn)
|
||||||
|
|
||||||
|
### 数据预处理
|
||||||
|
|
||||||
|
- 使用Polars Lazy API进行高效数据处理
|
||||||
|
- 处理缺失值和异常值
|
||||||
|
- 特征编码和标准化
|
||||||
|
|
||||||
|
## 🧠 机器学习实现
|
||||||
|
|
||||||
|
### 模型架构
|
||||||
|
|
||||||
|
- **基准模型**: Logistic Regression
|
||||||
|
- **高级模型**: LightGBM
|
||||||
|
|
||||||
|
### 评估指标
|
||||||
|
|
||||||
|
| 模型 | 准确率 | 精确率 | 召回率 | F1分数 | ROC-AUC |
|
||||||
|
|------|--------|--------|--------|--------|---------|
|
||||||
|
| Logistic Regression | 0.8068 | 0.6600 | 0.5629 | 0.6076 | 0.8547 |
|
||||||
|
| LightGBM | 0.9723 | 0.9358 | 0.9616 | 0.9485 | 0.9951 |
|
||||||
|
|
||||||
|
### 特征重要性
|
||||||
|
|
||||||
|
影响客户流失的关键特征包括:
|
||||||
|
- 合同类型(月付客户流失风险更高)
|
||||||
|
- 在网时长(新客户流失风险更高)
|
||||||
|
- 月费用(高费用客户流失风险更高)
|
||||||
|
- 支付方式(电子支票支付客户流失风险更高)
|
||||||
|
|
||||||
|
## 🤖 Agent 实现
|
||||||
|
|
||||||
|
### 功能概述
|
||||||
|
|
||||||
|
智能Agent能够理解自然语言输入,提取客户信息,并提供流失风险预测和决策建议。
|
||||||
|
|
||||||
|
### 工具列表
|
||||||
|
|
||||||
|
| 工具名称 | 功能 | 输入 | 输出 |
|
||||||
|
|---------|------|------|------|
|
||||||
|
| `predict_churn` | 使用ML模型预测流失风险 | CustomerFeatures | float |
|
||||||
|
| `explain_churn` | 解释影响流失的关键因素 | CustomerFeatures | list[str] |
|
||||||
|
|
||||||
|
### 交互示例
|
||||||
|
|
||||||
|
**输入**:
|
||||||
|
```
|
||||||
|
我有一个女性客户,35岁,在网2个月,月费用89.99,使用电子支票支付,采用月付合同
|
||||||
|
```
|
||||||
|
|
||||||
|
**输出**:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"risk_score": 0.72,
|
||||||
|
"decision": "高风险客户,建议重点关注",
|
||||||
|
"actions": ["主动联系客户", "提供个性化优惠", "分析使用习惯"],
|
||||||
|
"rationale": "月付合同、在网时长短和电子支票支付是导致高流失风险的主要因素"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🎨 Streamlit 应用
|
||||||
|
|
||||||
|
### 功能特点
|
||||||
|
|
||||||
|
- **直观的输入界面**: 分步填写客户信息
|
||||||
|
- **实时预测结果**: 立即显示流失风险评分
|
||||||
|
- **风险等级可视化**: 使用颜色和进度条直观展示风险
|
||||||
|
- **影响因素分析**: 提供详细的风险因素解释
|
||||||
|
- **数据统计展示**: 可视化展示不同特征与流失率的关系
|
||||||
|
|
||||||
|
### 使用方法
|
||||||
|
|
||||||
|
1. 在左侧边栏填写客户信息
|
||||||
|
2. 点击"预测流失风险"按钮
|
||||||
|
3. 在主界面查看预测结果和建议
|
||||||
|
|
||||||
|
## 📁 项目结构
|
||||||
|
|
||||||
|
```
|
||||||
|
ml_course_design/
|
||||||
|
├── pyproject.toml # 项目依赖配置
|
||||||
|
├── .env.example # 环境变量示例
|
||||||
|
├── .gitignore # Git忽略规则
|
||||||
|
├── README.md # 项目说明文档
|
||||||
|
├── data/ # 数据集目录
|
||||||
|
│ └── WA_Fn-UseC_-Telco-Customer-Churn.csv
|
||||||
|
├── models/ # 模型保存目录
|
||||||
|
│ └── best_model_lr.joblib
|
||||||
|
├── src/ # 源代码目录
|
||||||
|
│ ├── __init__.py
|
||||||
|
│ ├── data.py # 数据处理模块
|
||||||
|
│ ├── features.py # 特征定义模块
|
||||||
|
│ ├── train.py # 模型训练模块
|
||||||
|
│ ├── infer.py # 推理接口模块
|
||||||
|
│ ├── agent_app.py # Agent应用
|
||||||
|
│ └── streamlit_app.py # Streamlit应用
|
||||||
|
└── tests/ # 测试目录
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🔧 核心模块说明
|
||||||
|
|
||||||
|
### 1. 数据处理模块 (data.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 使用Polars Lazy API高效处理数据
|
||||||
|
lf = pl.scan_csv("data/train.csv")
|
||||||
|
result = (
|
||||||
|
lf.filter(pl.col("age") > 30)
|
||||||
|
.group_by("category")
|
||||||
|
.agg(pl.col("value").mean())
|
||||||
|
.collect()
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 特征定义模块 (features.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 使用Pydantic定义特征模型
|
||||||
|
class CustomerFeatures(BaseModel):
|
||||||
|
gender: gender_types
|
||||||
|
SeniorCitizen: int = Field(ge=0, le=1)
|
||||||
|
tenure: int = Field(ge=0, le=100)
|
||||||
|
MonthlyCharges: float = Field(ge=0, le=200)
|
||||||
|
# ... 其他特征
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 模型训练模块 (train.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 创建预处理管道
|
||||||
|
preprocessor = ColumnTransformer([
|
||||||
|
('num', StandardScaler(), numeric_features),
|
||||||
|
('cat', OneHotEncoder(), categorical_features)
|
||||||
|
])
|
||||||
|
|
||||||
|
# 训练LightGBM模型
|
||||||
|
lgb_model = lgb.train(
|
||||||
|
params,
|
||||||
|
lgb_train,
|
||||||
|
num_boost_round=500
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. 推理接口模块 (infer.py)
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 单例预测
|
||||||
|
result = inferencer.predict_single(customer_features)
|
||||||
|
|
||||||
|
# 预测解释
|
||||||
|
result = inferencer.explain_prediction(customer_features)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 📈 模型性能
|
||||||
|
|
||||||
|
### 训练集性能
|
||||||
|
|
||||||
|
| 模型 | 准确率 | 精确率 | 召回率 | F1分数 | ROC-AUC |
|
||||||
|
|------|--------|--------|--------|--------|---------|
|
||||||
|
| Logistic Regression | 0.8068 | 0.6600 | 0.5629 | 0.6076 | 0.8547 |
|
||||||
|
| LightGBM | 0.9723 | 0.9358 | 0.9616 | 0.9485 | 0.9951 |
|
||||||
|
|
||||||
|
### 测试集性能
|
||||||
|
|
||||||
|
| 模型 | 准确率 | 精确率 | 召回率 | F1分数 | ROC-AUC |
|
||||||
|
|------|--------|--------|--------|--------|---------|
|
||||||
|
| Logistic Regression | 0.7982 | 0.6364 | 0.5615 | 0.5966 | 0.8357 |
|
||||||
|
|
||||||
|
## 🎯 项目亮点
|
||||||
|
|
||||||
|
1. **高效数据处理**: 使用Polars Lazy API实现大规模数据的快速处理
|
||||||
|
2. **严格数据验证**: 结合Pydantic和pandera确保数据质量
|
||||||
|
3. **双模型架构**: 同时实现基准模型和高级模型,便于对比分析
|
||||||
|
4. **智能Agent交互**: 支持自然语言查询,提供人性化服务
|
||||||
|
5. **可视化界面**: 直观的Streamlit应用,降低使用门槛
|
||||||
|
6. **可解释性**: 提供详细的预测解释和影响因素分析
|
||||||
|
|
||||||
|
## 📝 开发日志
|
||||||
|
|
||||||
|
### Day 1: 项目初始化
|
||||||
|
- 完成项目结构搭建
|
||||||
|
- 配置开发环境
|
||||||
|
- 数据探索和分析
|
||||||
|
|
||||||
|
### Day 2: 数据处理
|
||||||
|
- 实现数据加载和预处理
|
||||||
|
- 特征工程
|
||||||
|
- 数据验证规则定义
|
||||||
|
|
||||||
|
### Day 3: 模型训练
|
||||||
|
- 实现Logistic Regression模型
|
||||||
|
- 实现LightGBM模型
|
||||||
|
- 模型评估和对比
|
||||||
|
|
||||||
|
### Day 4: Agent和应用开发
|
||||||
|
- 实现智能Agent
|
||||||
|
- 开发Streamlit应用
|
||||||
|
- 功能测试和优化
|
||||||
|
|
||||||
|
### Day 5: 项目完善
|
||||||
|
- 文档编写
|
||||||
|
- 代码优化
|
||||||
|
- 最终测试
|
||||||
|
|
||||||
|
## 4️⃣ 开发心得
|
||||||
|
|
||||||
|
### 4.1 主要困难与解决方案
|
||||||
|
|
||||||
|
在项目开发过程中,遇到的主要困难及其解决方案如下:
|
||||||
|
|
||||||
|
1. **模块导入问题**
|
||||||
|
- **困难**:当从项目根目录外运行脚本时,Python无法找到`src`模块,出现`ModuleNotFoundError`
|
||||||
|
- **解决方案**:在脚本中添加路径处理逻辑,自动将项目根目录添加到Python路径中,确保模块能够正确导入
|
||||||
|
|
||||||
|
2. **环境兼容性问题**
|
||||||
|
- **困难**:用户使用的PowerShell 5不支持现代Shell语法(如`&&`命令分隔符)
|
||||||
|
- **解决方案**:创建了基于Python的跨平台启动脚本,确保在不同环境下都能正常运行
|
||||||
|
|
||||||
|
3. **第三方库API变化**
|
||||||
|
- **困难**:`pydantic_ai`库的API与预期不符(如`register_tool`方法不存在,需要使用`tool`方法;`run`方法需要改为`run_sync`)
|
||||||
|
- **解决方案**:查阅库的帮助文档和源代码,调整代码以使用正确的API
|
||||||
|
|
||||||
|
4. **模型版本兼容性**
|
||||||
|
- **困难**:加载模型时出现scikit-learn版本不兼容的警告
|
||||||
|
- **解决方案**:确保训练和推理使用相同版本的库,并在文档中注明版本要求
|
||||||
|
|
||||||
|
### 4.2 对 AI 辅助编程的感受
|
||||||
|
|
||||||
|
使用AI辅助编程工具(如Trae IDE)的体验非常良好,主要体现在以下方面:
|
||||||
|
|
||||||
|
1. **有帮助的场景**
|
||||||
|
- **快速生成代码框架**:能够根据需求快速生成项目结构和基础代码
|
||||||
|
- **解决技术问题**:对于特定的技术问题,能够提供多种解决方案
|
||||||
|
- **优化代码质量**:能够识别代码中的问题并提供改进建议
|
||||||
|
- **学习新技术**:能够解释复杂的技术概念,帮助快速掌握新技术
|
||||||
|
|
||||||
|
2. **需要注意的地方**
|
||||||
|
- **代码验证**:生成的代码可能存在细微错误,需要仔细验证和测试
|
||||||
|
- **API准确性**:对于特定库的最新API可能不够了解,需要查阅官方文档确认
|
||||||
|
- **业务逻辑**:复杂的业务逻辑需要结合人类的专业知识进行设计
|
||||||
|
- **过度依赖**:避免过度依赖AI工具,保持独立思考和问题解决能力
|
||||||
|
|
||||||
|
### 4.3 局限与未来改进
|
||||||
|
|
||||||
|
如果有更多时间,项目还可以从以下几个方面进行改进:
|
||||||
|
|
||||||
|
1. **模型性能优化**
|
||||||
|
- 尝试更多的特征工程方法,如特征选择、特征交叉等
|
||||||
|
- 调参优化LightGBM模型,提高预测准确率
|
||||||
|
- 尝试其他先进的算法,如XGBoost、CatBoost或深度学习模型
|
||||||
|
|
||||||
|
2. **应用功能扩展**
|
||||||
|
- 添加更多的可视化图表,如客户流失风险分布、特征重要性分析等
|
||||||
|
- 实现批量预测功能,支持导入Excel或CSV文件进行批量分析
|
||||||
|
- 添加模型监控和更新机制,定期重新训练模型以适应新数据
|
||||||
|
- 支持多语言界面,提高应用的可用性
|
||||||
|
|
||||||
|
3. **系统架构改进**
|
||||||
|
- 分离前后端,使用FastAPI构建API,Streamlit作为前端
|
||||||
|
- 实现模型服务化部署,支持RESTful API调用
|
||||||
|
- 添加用户认证和权限管理,提高系统安全性
|
||||||
|
- 支持多模型版本管理,方便模型迭代和回滚
|
||||||
|
|
||||||
|
4. **开发流程优化**
|
||||||
|
- 添加更全面的单元测试和集成测试,提高代码质量
|
||||||
|
- 实现CI/CD流水线,自动构建、测试和部署
|
||||||
|
- 添加代码质量检查工具,如flake8、mypy等
|
||||||
|
- 完善文档和注释,提高代码的可维护性
|
||||||
|
|
||||||
|
5. **用户体验改进**
|
||||||
|
- 优化Streamlit界面,提高用户交互体验
|
||||||
|
- 添加详细的使用说明和帮助文档
|
||||||
|
- 提供更智能的用户输入提示和错误处理
|
||||||
|
|
||||||
|
通过这些改进,可以进一步提高项目的性能、可用性和可维护性,使其成为一个更完善的电信客户流失预测系统。
|
||||||
7044
ml_course_design/data/WA_Fn-UseC_-Telco-Customer-Churn.csv
Normal file
7044
ml_course_design/data/WA_Fn-UseC_-Telco-Customer-Churn.csv
Normal file
File diff suppressed because it is too large
Load Diff
BIN
ml_course_design/models/best_model.joblib
Normal file
BIN
ml_course_design/models/best_model.joblib
Normal file
Binary file not shown.
BIN
ml_course_design/models/best_model_lr.joblib
Normal file
BIN
ml_course_design/models/best_model_lr.joblib
Normal file
Binary file not shown.
BIN
ml_course_design/models/features.joblib
Normal file
BIN
ml_course_design/models/features.joblib
Normal file
Binary file not shown.
BIN
ml_course_design/models/scaler.joblib
Normal file
BIN
ml_course_design/models/scaler.joblib
Normal file
Binary file not shown.
24
ml_course_design/pyproject.toml
Normal file
24
ml_course_design/pyproject.toml
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
[project]
|
||||||
|
name = "ml_course_design"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Machine Learning + Agent Course Design"
|
||||||
|
authors = [{ name = "Student", email = "student@example.com" }]
|
||||||
|
requires-python = ">=3.12"
|
||||||
|
dependencies = [
|
||||||
|
"polars>=0.20.0",
|
||||||
|
"pandas>=2.2.0",
|
||||||
|
"seaborn>=0.13.0",
|
||||||
|
"pydantic>=2.5.0",
|
||||||
|
"pandera>=0.18.0",
|
||||||
|
"scikit-learn>=1.3.0",
|
||||||
|
"lightgbm>=4.3.0",
|
||||||
|
"pydantic-ai>=0.2.0",
|
||||||
|
"python-dotenv>=1.0.0",
|
||||||
|
"streamlit>=1.30.0",
|
||||||
|
"joblib>=1.3.0",
|
||||||
|
"plotly>=6.5.1",
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.uv]
|
||||||
|
# uv 配置
|
||||||
|
|
||||||
1
ml_course_design/src/__init__.py
Normal file
1
ml_course_design/src/__init__.py
Normal file
@ -0,0 +1 @@
|
|||||||
|
# 初始化src包
|
||||||
159
ml_course_design/src/agent_app.py
Normal file
159
ml_course_design/src/agent_app.py
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
from pydantic_ai import Agent, RunContext
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
# 添加项目根目录到Python路径,解决直接运行时的导入问题
|
||||||
|
project_root = Path(__file__).parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from src.features import CustomerFeatures
|
||||||
|
from src.infer import ModelInferencer
|
||||||
|
|
||||||
|
# 加载环境变量
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
|
||||||
|
class DecisionResult(BaseModel):
|
||||||
|
"""Agent决策结果模型"""
|
||||||
|
risk_score: float = Field(ge=0, le=1, description="流失风险分数")
|
||||||
|
decision: str = Field(description="决策建议")
|
||||||
|
actions: List[str] = Field(description="建议采取的行动")
|
||||||
|
rationale: str = Field(description="决策理由")
|
||||||
|
|
||||||
|
|
||||||
|
class CustomerInfo(BaseModel):
|
||||||
|
"""客户信息模型"""
|
||||||
|
age: Optional[int] = Field(description="客户年龄")
|
||||||
|
gender: Optional[str] = Field(description="客户性别")
|
||||||
|
tenure: Optional[int] = Field(description="在网时长(月)")
|
||||||
|
monthly_charges: Optional[float] = Field(description="月费用")
|
||||||
|
total_charges: Optional[float] = Field(description="总费用")
|
||||||
|
contract_type: Optional[str] = Field(description="合同类型")
|
||||||
|
internet_service: Optional[str] = Field(description="互联网服务类型")
|
||||||
|
payment_method: Optional[str] = Field(description="支付方式")
|
||||||
|
has_partner: Optional[bool] = Field(description="是否有伴侣")
|
||||||
|
has_dependents: Optional[bool] = Field(description="是否有家属")
|
||||||
|
is_senior: Optional[bool] = Field(description="是否为老年人")
|
||||||
|
|
||||||
|
|
||||||
|
class ChurnPredictionAgent:
|
||||||
|
"""客户流失预测Agent"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化Agent"""
|
||||||
|
# 获取API Key
|
||||||
|
self.api_key = os.getenv("DEEPSEEK_API_KEY")
|
||||||
|
if not self.api_key:
|
||||||
|
raise ValueError("DEEPSEEK_API_KEY环境变量未设置,请在.env文件中配置")
|
||||||
|
|
||||||
|
# 初始化推理器
|
||||||
|
self.inferencer = ModelInferencer()
|
||||||
|
|
||||||
|
# 创建Agent
|
||||||
|
self.agent = self._create_agent()
|
||||||
|
|
||||||
|
def _create_agent(self) -> Agent:
|
||||||
|
"""创建Agent实例
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Agent实例
|
||||||
|
"""
|
||||||
|
agent = Agent(
|
||||||
|
model="deepseek:deepseek-chat",
|
||||||
|
output_type=DecisionResult,
|
||||||
|
system_prompt="你是一名专业的电信客户流失预测分析师,你的任务是根据客户信息预测流失风险并提供决策建议。\n\n" \
|
||||||
|
"你可以使用以下工具:\n" \
|
||||||
|
"1. predict_churn: 使用机器学习模型预测客户流失风险\n" \
|
||||||
|
"2. explain_churn: 解释影响客户流失的关键因素\n\n" \
|
||||||
|
"请确保你的回答专业、准确,并提供具体的行动建议。"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册工具
|
||||||
|
agent.tool(self.predict_churn)
|
||||||
|
agent.tool(self.explain_churn)
|
||||||
|
|
||||||
|
return agent
|
||||||
|
|
||||||
|
def predict_churn(self, ctx: RunContext, customer_info: CustomerFeatures) -> float:
|
||||||
|
"""预测客户流失风险
|
||||||
|
|
||||||
|
Args:
|
||||||
|
customer_info: 客户特征信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
流失风险分数 (0-1)
|
||||||
|
"""
|
||||||
|
result = self.inferencer.predict_single(customer_info)
|
||||||
|
return result["probability"]
|
||||||
|
|
||||||
|
def explain_churn(self, ctx: RunContext, customer_info: CustomerFeatures) -> List[str]:
|
||||||
|
"""解释影响客户流失的关键因素
|
||||||
|
|
||||||
|
Args:
|
||||||
|
customer_info: 客户特征信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
影响因素列表
|
||||||
|
"""
|
||||||
|
result = self.inferencer.explain_prediction(customer_info)
|
||||||
|
return result["explanation"]
|
||||||
|
|
||||||
|
def process_query(self, query: str) -> DecisionResult:
|
||||||
|
"""处理用户查询
|
||||||
|
|
||||||
|
Args:
|
||||||
|
query: 用户的自然语言查询
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
结构化的决策结果
|
||||||
|
"""
|
||||||
|
print(f"正在处理查询: {query}")
|
||||||
|
|
||||||
|
# 运行Agent
|
||||||
|
result = self.agent.run_sync(query)
|
||||||
|
|
||||||
|
print("查询处理完成")
|
||||||
|
return result
|
||||||
|
|
||||||
|
def run_interactive(self):
|
||||||
|
"""启动交互式对话"""
|
||||||
|
print("欢迎使用客户流失预测Agent!")
|
||||||
|
print("请输入客户信息,我将为您预测流失风险并提供建议。")
|
||||||
|
print("输入'退出'或'quit'结束对话。")
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
query = input("\n请输入查询: ")
|
||||||
|
|
||||||
|
if query.lower() in ["退出", "quit", "q"]:
|
||||||
|
print("感谢使用,再见!")
|
||||||
|
break
|
||||||
|
|
||||||
|
result = self.process_query(query)
|
||||||
|
|
||||||
|
print("\n=== 预测结果 ===")
|
||||||
|
print(f"流失风险分数: {result.risk_score:.4f}")
|
||||||
|
print(f"决策建议: {result.decision}")
|
||||||
|
print("建议采取的行动:")
|
||||||
|
for action in result.actions:
|
||||||
|
print(f" - {action}")
|
||||||
|
print(f"决策理由: {result.rationale}")
|
||||||
|
print("=================")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"处理查询时发生错误: {e}")
|
||||||
|
print("请检查输入或稍后重试。")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
try:
|
||||||
|
# 初始化并启动Agent
|
||||||
|
agent = ChurnPredictionAgent()
|
||||||
|
agent.run_interactive()
|
||||||
|
except Exception as e:
|
||||||
|
print(f"启动Agent时发生错误: {e}")
|
||||||
|
print("请确保已正确配置DEEPSEEK_API_KEY环境变量。")
|
||||||
99
ml_course_design/src/data.py
Normal file
99
ml_course_design/src/data.py
Normal file
@ -0,0 +1,99 @@
|
|||||||
|
import polars as pl
|
||||||
|
import pandas as pd
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
class DataProcessor:
|
||||||
|
"""数据处理类,用于加载和预处理Telco Customer Churn数据集"""
|
||||||
|
|
||||||
|
def __init__(self, data_path: str | Path = None):
|
||||||
|
"""初始化数据处理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data_path: 数据集路径,如果为None则使用默认路径
|
||||||
|
"""
|
||||||
|
if data_path is None:
|
||||||
|
self.data_path = Path(__file__).parent.parent / "data" / "WA_Fn-UseC_-Telco-Customer-Churn.csv"
|
||||||
|
else:
|
||||||
|
self.data_path = Path(data_path)
|
||||||
|
|
||||||
|
def load_data(self) -> pl.DataFrame:
|
||||||
|
"""加载原始数据集
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
加载后的Polars DataFrame
|
||||||
|
"""
|
||||||
|
print(f"正在加载数据: {self.data_path}")
|
||||||
|
|
||||||
|
# 使用Lazy API加载数据,提高效率
|
||||||
|
lf = pl.scan_csv(self.data_path)
|
||||||
|
df = lf.collect()
|
||||||
|
|
||||||
|
print(f"数据加载完成,共 {df.shape[0]} 行,{df.shape[1]} 列")
|
||||||
|
return df
|
||||||
|
|
||||||
|
def preprocess_data(self, df: pl.DataFrame) -> Tuple[pl.DataFrame, pl.Series]:
|
||||||
|
"""预处理数据集
|
||||||
|
|
||||||
|
Args:
|
||||||
|
df: 原始数据集
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
预处理后的特征数据和目标变量
|
||||||
|
"""
|
||||||
|
print("开始数据预处理...")
|
||||||
|
|
||||||
|
# 1. 处理缺失值和异常值
|
||||||
|
# 检查TotalCharges列的类型,如果是字符串类型则处理空字符串
|
||||||
|
if df["TotalCharges"].dtype == pl.String:
|
||||||
|
df = df.with_columns(
|
||||||
|
pl.col("TotalCharges").str.strip_chars().replace("", None)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 将TotalCharges转换为浮点型
|
||||||
|
df = df.with_columns(
|
||||||
|
pl.col("TotalCharges").cast(pl.Float64, strict=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 处理缺失值 - 删除TotalCharges为None的行
|
||||||
|
df = df.filter(pl.col("TotalCharges").is_not_null())
|
||||||
|
|
||||||
|
# 2. 处理目标变量
|
||||||
|
# 将Churn转换为数值型 (0=No, 1=Yes)
|
||||||
|
df = df.with_columns(
|
||||||
|
pl.col("Churn").replace({"No": 0, "Yes": 1}).cast(pl.Int32).alias("Churn")
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 选择特征列
|
||||||
|
# 排除customerID(唯一标识,对模型训练无用)
|
||||||
|
feature_cols = [col for col in df.columns if col not in ["customerID", "Churn"]]
|
||||||
|
|
||||||
|
# 分离特征和目标变量
|
||||||
|
X = df.select(feature_cols)
|
||||||
|
y = df.select("Churn").to_series()
|
||||||
|
|
||||||
|
print(f"数据预处理完成,特征数据形状: {X.shape}, 目标变量形状: {y.shape}")
|
||||||
|
return X, y
|
||||||
|
|
||||||
|
def get_processed_data(self) -> Tuple[pl.DataFrame, pl.Series]:
|
||||||
|
"""获取完整处理后的数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
预处理后的特征数据和目标变量
|
||||||
|
"""
|
||||||
|
df = self.load_data()
|
||||||
|
X, y = self.preprocess_data(df)
|
||||||
|
return X, y
|
||||||
|
|
||||||
|
# 用于测试数据处理模块
|
||||||
|
if __name__ == "__main__":
|
||||||
|
processor = DataProcessor()
|
||||||
|
X, y = processor.get_processed_data()
|
||||||
|
|
||||||
|
print("\n特征数据示例:")
|
||||||
|
print(X.head())
|
||||||
|
|
||||||
|
print("\n目标变量示例:")
|
||||||
|
print(y.head())
|
||||||
|
|
||||||
|
print(f"\n目标变量分布: {y.value_counts().sort("Churn")}")
|
||||||
125
ml_course_design/src/features.py
Normal file
125
ml_course_design/src/features.py
Normal file
@ -0,0 +1,125 @@
|
|||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
from pandera import Column, Check, DataFrameSchema
|
||||||
|
import pandera as pa
|
||||||
|
from typing import Literal, Optional
|
||||||
|
|
||||||
|
# 定义性别类型
|
||||||
|
gender_types = Literal["Male", "Female"]
|
||||||
|
|
||||||
|
# 定义Yes/No类型
|
||||||
|
yes_no_types = Literal["Yes", "No"]
|
||||||
|
|
||||||
|
# 定义服务相关类型
|
||||||
|
service_types = Literal["Yes", "No", "No internet service"]
|
||||||
|
phone_line_types = Literal["Yes", "No", "No phone service"]
|
||||||
|
|
||||||
|
# 定义互联网服务类型
|
||||||
|
internet_service_types = Literal["DSL", "Fiber optic", "No"]
|
||||||
|
|
||||||
|
# 定义合同类型
|
||||||
|
contract_types = Literal["Month-to-month", "One year", "Two year"]
|
||||||
|
|
||||||
|
# 定义支付方式类型
|
||||||
|
payment_method_types = Literal["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"]
|
||||||
|
|
||||||
|
|
||||||
|
class CustomerFeatures(BaseModel):
|
||||||
|
"""客户特征模型"""
|
||||||
|
# 基本信息
|
||||||
|
gender: gender_types = Field(description="性别")
|
||||||
|
SeniorCitizen: int = Field(ge=0, le=1, description="是否为老年人 (0=No, 1=Yes)")
|
||||||
|
Partner: yes_no_types = Field(description="是否有伴侣")
|
||||||
|
Dependents: yes_no_types = Field(description="是否有家属")
|
||||||
|
tenure: int = Field(ge=0, le=100, description="客户在网时长 (月)")
|
||||||
|
|
||||||
|
# 电话服务
|
||||||
|
PhoneService: yes_no_types = Field(description="是否有电话服务")
|
||||||
|
MultipleLines: phone_line_types = Field(description="是否有多条线路")
|
||||||
|
|
||||||
|
# 互联网服务
|
||||||
|
InternetService: internet_service_types = Field(description="互联网服务类型")
|
||||||
|
OnlineSecurity: service_types = Field(description="是否有在线安全服务")
|
||||||
|
OnlineBackup: service_types = Field(description="是否有在线备份服务")
|
||||||
|
DeviceProtection: service_types = Field(description="是否有设备保护服务")
|
||||||
|
TechSupport: service_types = Field(description="是否有技术支持服务")
|
||||||
|
StreamingTV: service_types = Field(description="是否有流媒体电视服务")
|
||||||
|
StreamingMovies: service_types = Field(description="是否有流媒体电影服务")
|
||||||
|
|
||||||
|
# 合同和账单
|
||||||
|
Contract: contract_types = Field(description="合同类型")
|
||||||
|
PaperlessBilling: yes_no_types = Field(description="是否使用无纸化账单")
|
||||||
|
PaymentMethod: payment_method_types = Field(description="支付方式")
|
||||||
|
MonthlyCharges: float = Field(ge=0, le=200, description="月费用")
|
||||||
|
TotalCharges: float = Field(ge=0, le=10000, description="总费用")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
populate_by_name = True
|
||||||
|
from_attributes = True
|
||||||
|
|
||||||
|
|
||||||
|
# 定义用于数据验证的DataFrame Schema
|
||||||
|
data_schema = DataFrameSchema(
|
||||||
|
columns={
|
||||||
|
# 输入特征
|
||||||
|
"gender": Column(pa.String, checks=Check.isin(["Male", "Female"])),
|
||||||
|
"SeniorCitizen": Column(pa.Int, checks=Check.isin([0, 1])),
|
||||||
|
"Partner": Column(pa.String, checks=Check.isin(["Yes", "No"])),
|
||||||
|
"Dependents": Column(pa.String, checks=Check.isin(["Yes", "No"])),
|
||||||
|
"tenure": Column(pa.Int, checks=Check.ge(0)),
|
||||||
|
"PhoneService": Column(pa.String, checks=Check.isin(["Yes", "No"])),
|
||||||
|
"MultipleLines": Column(pa.String, checks=Check.isin(["Yes", "No", "No phone service"])),
|
||||||
|
"InternetService": Column(pa.String, checks=Check.isin(["DSL", "Fiber optic", "No"])),
|
||||||
|
"OnlineSecurity": Column(pa.String, checks=Check.isin(["Yes", "No", "No internet service"])),
|
||||||
|
"OnlineBackup": Column(pa.String, checks=Check.isin(["Yes", "No", "No internet service"])),
|
||||||
|
"DeviceProtection": Column(pa.String, checks=Check.isin(["Yes", "No", "No internet service"])),
|
||||||
|
"TechSupport": Column(pa.String, checks=Check.isin(["Yes", "No", "No internet service"])),
|
||||||
|
"StreamingTV": Column(pa.String, checks=Check.isin(["Yes", "No", "No internet service"])),
|
||||||
|
"StreamingMovies": Column(pa.String, checks=Check.isin(["Yes", "No", "No internet service"])),
|
||||||
|
"Contract": Column(pa.String, checks=Check.isin(["Month-to-month", "One year", "Two year"])),
|
||||||
|
"PaperlessBilling": Column(pa.String, checks=Check.isin(["Yes", "No"])),
|
||||||
|
"PaymentMethod": Column(pa.String, checks=Check.isin([
|
||||||
|
"Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"
|
||||||
|
])),
|
||||||
|
"MonthlyCharges": Column(pa.Float, checks=Check.ge(0)),
|
||||||
|
"TotalCharges": Column(pa.Float, checks=Check.ge(0)),
|
||||||
|
|
||||||
|
# 目标变量
|
||||||
|
"Churn": Column(pa.Int, checks=Check.isin([0, 1])),
|
||||||
|
},
|
||||||
|
strict=True,
|
||||||
|
coerce=True,
|
||||||
|
name="customer_churn_schema"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试特征模型
|
||||||
|
print("测试CustomerFeatures模型...")
|
||||||
|
|
||||||
|
# 创建一个有效的特征实例
|
||||||
|
valid_features = CustomerFeatures(
|
||||||
|
gender="Female",
|
||||||
|
SeniorCitizen=0,
|
||||||
|
Partner="Yes",
|
||||||
|
Dependents="No",
|
||||||
|
tenure=1,
|
||||||
|
PhoneService="No",
|
||||||
|
MultipleLines="No phone service",
|
||||||
|
InternetService="DSL",
|
||||||
|
OnlineSecurity="No",
|
||||||
|
OnlineBackup="Yes",
|
||||||
|
DeviceProtection="No",
|
||||||
|
TechSupport="No",
|
||||||
|
StreamingTV="No",
|
||||||
|
StreamingMovies="No",
|
||||||
|
Contract="Month-to-month",
|
||||||
|
PaperlessBilling="Yes",
|
||||||
|
PaymentMethod="Electronic check",
|
||||||
|
MonthlyCharges=29.85,
|
||||||
|
TotalCharges=29.85
|
||||||
|
)
|
||||||
|
|
||||||
|
print("有效特征实例:")
|
||||||
|
print(valid_features)
|
||||||
|
|
||||||
|
print("\n特征模型测试通过!")
|
||||||
185
ml_course_design/src/infer.py
Normal file
185
ml_course_design/src/infer.py
Normal file
@ -0,0 +1,185 @@
|
|||||||
|
import joblib
|
||||||
|
import pandas as pd
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
|
||||||
|
from .features import CustomerFeatures
|
||||||
|
|
||||||
|
|
||||||
|
class ModelInferencer:
|
||||||
|
"""模型推理类"""
|
||||||
|
|
||||||
|
def __init__(self, model_path: str | Path = None):
|
||||||
|
"""初始化模型推理器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_path: 模型路径,如果为None则使用默认路径
|
||||||
|
"""
|
||||||
|
if model_path is None:
|
||||||
|
self.model_path = Path(__file__).parent.parent / "models" / "best_model_lr.joblib"
|
||||||
|
else:
|
||||||
|
self.model_path = Path(model_path)
|
||||||
|
|
||||||
|
# 加载模型
|
||||||
|
self.model = self.load_model()
|
||||||
|
|
||||||
|
def load_model(self) -> Any:
|
||||||
|
"""加载训练好的模型
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
加载的模型对象
|
||||||
|
"""
|
||||||
|
print(f"正在加载模型: {self.model_path}")
|
||||||
|
|
||||||
|
if not self.model_path.exists():
|
||||||
|
raise FileNotFoundError(f"模型文件不存在: {self.model_path}")
|
||||||
|
|
||||||
|
model = joblib.load(self.model_path)
|
||||||
|
print(f"模型加载成功: {type(model).__name__}")
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
def predict_single(self, features: CustomerFeatures) -> Dict[str, Any]:
|
||||||
|
"""对单个客户进行流失预测
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features: 客户特征对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
预测结果,包含流失概率和预测类别
|
||||||
|
"""
|
||||||
|
# 将特征转换为DataFrame
|
||||||
|
features_dict = features.model_dump()
|
||||||
|
df = pd.DataFrame([features_dict])
|
||||||
|
|
||||||
|
# 进行预测
|
||||||
|
prediction = self.model.predict(df)[0]
|
||||||
|
probability = self.model.predict_proba(df)[0][1]
|
||||||
|
|
||||||
|
# 构造结果
|
||||||
|
result = {
|
||||||
|
"prediction": int(prediction), # 0=不流失, 1=流失
|
||||||
|
"probability": float(probability), # 流失概率
|
||||||
|
"churn": bool(prediction), # 是否流失
|
||||||
|
"features": features_dict
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
def predict_batch(self, features_list: List[CustomerFeatures]) -> List[Dict[str, Any]]:
|
||||||
|
"""对多个客户进行批量流失预测
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features_list: 客户特征对象列表
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
批量预测结果列表
|
||||||
|
"""
|
||||||
|
# 将特征列表转换为DataFrame
|
||||||
|
features_dicts = [features.model_dump() for features in features_list]
|
||||||
|
df = pd.DataFrame(features_dicts)
|
||||||
|
|
||||||
|
# 进行批量预测
|
||||||
|
predictions = self.model.predict(df)
|
||||||
|
probabilities = self.model.predict_proba(df)[:, 1]
|
||||||
|
|
||||||
|
# 构造结果列表
|
||||||
|
results = []
|
||||||
|
for i in range(len(predictions)):
|
||||||
|
result = {
|
||||||
|
"prediction": int(predictions[i]),
|
||||||
|
"probability": float(probabilities[i]),
|
||||||
|
"churn": bool(predictions[i]),
|
||||||
|
"features": features_dicts[i]
|
||||||
|
}
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def explain_prediction(self, features: CustomerFeatures) -> Dict[str, Any]:
|
||||||
|
"""解释预测结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
features: 客户特征对象
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
包含预测结果和解释的字典
|
||||||
|
"""
|
||||||
|
# 获取基本预测结果
|
||||||
|
prediction_result = self.predict_single(features)
|
||||||
|
|
||||||
|
# 分析影响流失的关键因素
|
||||||
|
key_factors = []
|
||||||
|
|
||||||
|
# 根据业务知识分析影响因素
|
||||||
|
if features.Contract == "Month-to-month":
|
||||||
|
key_factors.append("月付合同增加了流失风险")
|
||||||
|
|
||||||
|
if features.tenure < 12:
|
||||||
|
key_factors.append("在网时长较短增加了流失风险")
|
||||||
|
|
||||||
|
if features.MonthlyCharges > 70:
|
||||||
|
key_factors.append("月费用较高增加了流失风险")
|
||||||
|
|
||||||
|
if features.InternetService == "Fiber optic":
|
||||||
|
key_factors.append("光纤互联网服务用户流失风险较高")
|
||||||
|
|
||||||
|
if features.PaymentMethod == "Electronic check":
|
||||||
|
key_factors.append("电子支票支付方式增加了流失风险")
|
||||||
|
|
||||||
|
if features.PaperlessBilling == "Yes":
|
||||||
|
key_factors.append("无纸化账单用户流失风险较高")
|
||||||
|
|
||||||
|
# 如果没有找到明显因素
|
||||||
|
if not key_factors:
|
||||||
|
key_factors.append("客户特征组合导致流失风险处于平均水平")
|
||||||
|
|
||||||
|
# 添加解释到结果中
|
||||||
|
prediction_result["explanation"] = key_factors
|
||||||
|
|
||||||
|
return prediction_result
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 测试推理功能
|
||||||
|
print("测试模型推理功能...")
|
||||||
|
|
||||||
|
# 创建测试特征
|
||||||
|
test_features = CustomerFeatures(
|
||||||
|
gender="Female",
|
||||||
|
SeniorCitizen=0,
|
||||||
|
Partner="Yes",
|
||||||
|
Dependents="No",
|
||||||
|
tenure=1,
|
||||||
|
PhoneService="No",
|
||||||
|
MultipleLines="No phone service",
|
||||||
|
InternetService="DSL",
|
||||||
|
OnlineSecurity="No",
|
||||||
|
OnlineBackup="Yes",
|
||||||
|
DeviceProtection="No",
|
||||||
|
TechSupport="No",
|
||||||
|
StreamingTV="No",
|
||||||
|
StreamingMovies="No",
|
||||||
|
Contract="Month-to-month",
|
||||||
|
PaperlessBilling="Yes",
|
||||||
|
PaymentMethod="Electronic check",
|
||||||
|
MonthlyCharges=29.85,
|
||||||
|
TotalCharges=29.85
|
||||||
|
)
|
||||||
|
|
||||||
|
# 初始化推理器
|
||||||
|
inferencer = ModelInferencer()
|
||||||
|
|
||||||
|
# 进行单例预测
|
||||||
|
result = inferencer.predict_single(test_features)
|
||||||
|
print("\n单例预测结果:")
|
||||||
|
print(result)
|
||||||
|
|
||||||
|
# 进行预测解释
|
||||||
|
explained_result = inferencer.explain_prediction(test_features)
|
||||||
|
print("\n预测解释:")
|
||||||
|
print(f"流失概率: {explained_result['probability']:.4f}")
|
||||||
|
print(f"预测结果: {'流失' if explained_result['churn'] else '不流失'}")
|
||||||
|
print("影响因素:")
|
||||||
|
for factor in explained_result['explanation']:
|
||||||
|
print(f" - {factor}")
|
||||||
247
ml_course_design/src/streamlit_app.py
Normal file
247
ml_course_design/src/streamlit_app.py
Normal file
@ -0,0 +1,247 @@
|
|||||||
|
import streamlit as st
|
||||||
|
import pandas as pd
|
||||||
|
import plotly.express as px
|
||||||
|
import plotly.graph_objects as go
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 使用绝对导入或直接导入
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 添加项目根目录到Python路径
|
||||||
|
sys.path.append(str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
from src.features import CustomerFeatures
|
||||||
|
from src.infer import ModelInferencer
|
||||||
|
|
||||||
|
|
||||||
|
class ChurnPredictionApp:
|
||||||
|
"""客户流失预测Streamlit应用"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
"""初始化应用"""
|
||||||
|
# 设置页面配置
|
||||||
|
st.set_page_config(
|
||||||
|
page_title="客户流失预测系统",
|
||||||
|
page_icon="📊",
|
||||||
|
layout="wide",
|
||||||
|
initial_sidebar_state="expanded"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 初始化推理器
|
||||||
|
self.inferencer = ModelInferencer()
|
||||||
|
|
||||||
|
# 设置应用标题和说明
|
||||||
|
self._set_app_header()
|
||||||
|
|
||||||
|
def _set_app_header(self):
|
||||||
|
"""设置应用标题和说明"""
|
||||||
|
st.title("📊 客户流失预测系统")
|
||||||
|
st.markdown("---")
|
||||||
|
st.write("这是一个基于机器学习的电信客户流失预测系统。输入客户信息,系统将预测该客户的流失风险并提供针对性建议。")
|
||||||
|
|
||||||
|
def _create_input_form(self) -> dict:
|
||||||
|
"""创建客户信息输入表单
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
输入的客户信息字典
|
||||||
|
"""
|
||||||
|
with st.sidebar.form("customer_info_form"):
|
||||||
|
st.header("客户信息")
|
||||||
|
|
||||||
|
# 基本信息
|
||||||
|
st.subheader("基本信息")
|
||||||
|
gender = st.selectbox("性别", ["Male", "Female"])
|
||||||
|
senior_citizen = st.selectbox("是否为老年人", [0, 1], format_func=lambda x: "是" if x == 1 else "否")
|
||||||
|
partner = st.selectbox("是否有伴侣", ["Yes", "No"], format_func=lambda x: "是" if x == "Yes" else "否")
|
||||||
|
dependents = st.selectbox("是否有家属", ["Yes", "No"], format_func=lambda x: "是" if x == "Yes" else "否")
|
||||||
|
tenure = st.number_input("在网时长(月)", min_value=0, max_value=100, value=1)
|
||||||
|
|
||||||
|
# 电话服务
|
||||||
|
st.subheader("电话服务")
|
||||||
|
phone_service = st.selectbox("是否有电话服务", ["Yes", "No"], format_func=lambda x: "是" if x == "Yes" else "否")
|
||||||
|
|
||||||
|
if phone_service == "Yes":
|
||||||
|
multiple_lines = st.selectbox("是否有多条线路", ["Yes", "No"])
|
||||||
|
else:
|
||||||
|
multiple_lines = "No phone service"
|
||||||
|
|
||||||
|
# 互联网服务
|
||||||
|
st.subheader("互联网服务")
|
||||||
|
internet_service = st.selectbox("互联网服务类型", ["DSL", "Fiber optic", "No"])
|
||||||
|
|
||||||
|
if internet_service != "No":
|
||||||
|
online_security = st.selectbox("是否有在线安全服务", ["Yes", "No"])
|
||||||
|
online_backup = st.selectbox("是否有在线备份服务", ["Yes", "No"])
|
||||||
|
device_protection = st.selectbox("是否有设备保护服务", ["Yes", "No"])
|
||||||
|
tech_support = st.selectbox("是否有技术支持服务", ["Yes", "No"])
|
||||||
|
streaming_tv = st.selectbox("是否有流媒体电视服务", ["Yes", "No"])
|
||||||
|
streaming_movies = st.selectbox("是否有流媒体电影服务", ["Yes", "No"])
|
||||||
|
else:
|
||||||
|
online_security = "No internet service"
|
||||||
|
online_backup = "No internet service"
|
||||||
|
device_protection = "No internet service"
|
||||||
|
tech_support = "No internet service"
|
||||||
|
streaming_tv = "No internet service"
|
||||||
|
streaming_movies = "No internet service"
|
||||||
|
|
||||||
|
# 合同和账单
|
||||||
|
st.subheader("合同和账单")
|
||||||
|
contract = st.selectbox("合同类型", ["Month-to-month", "One year", "Two year"])
|
||||||
|
paperless_billing = st.selectbox("是否使用无纸化账单", ["Yes", "No"], format_func=lambda x: "是" if x == "Yes" else "否")
|
||||||
|
payment_method = st.selectbox(
|
||||||
|
"支付方式",
|
||||||
|
["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"]
|
||||||
|
)
|
||||||
|
monthly_charges = st.number_input("月费用", min_value=0.0, max_value=200.0, value=29.85, step=0.01)
|
||||||
|
total_charges = st.number_input("总费用", min_value=0.0, max_value=10000.0, value=29.85, step=0.01)
|
||||||
|
|
||||||
|
# 提交按钮
|
||||||
|
submit_button = st.form_submit_button("预测流失风险")
|
||||||
|
|
||||||
|
# 构造特征字典
|
||||||
|
features_dict = {
|
||||||
|
"gender": gender,
|
||||||
|
"SeniorCitizen": senior_citizen,
|
||||||
|
"Partner": partner,
|
||||||
|
"Dependents": dependents,
|
||||||
|
"tenure": tenure,
|
||||||
|
"PhoneService": phone_service,
|
||||||
|
"MultipleLines": multiple_lines,
|
||||||
|
"InternetService": internet_service,
|
||||||
|
"OnlineSecurity": online_security,
|
||||||
|
"OnlineBackup": online_backup,
|
||||||
|
"DeviceProtection": device_protection,
|
||||||
|
"TechSupport": tech_support,
|
||||||
|
"StreamingTV": streaming_tv,
|
||||||
|
"StreamingMovies": streaming_movies,
|
||||||
|
"Contract": contract,
|
||||||
|
"PaperlessBilling": paperless_billing,
|
||||||
|
"PaymentMethod": payment_method,
|
||||||
|
"MonthlyCharges": monthly_charges,
|
||||||
|
"TotalCharges": total_charges
|
||||||
|
}
|
||||||
|
|
||||||
|
return features_dict, submit_button
|
||||||
|
|
||||||
|
def _display_prediction_result(self, result: dict):
|
||||||
|
"""展示预测结果
|
||||||
|
|
||||||
|
Args:
|
||||||
|
result: 预测结果字典
|
||||||
|
"""
|
||||||
|
st.markdown("---")
|
||||||
|
st.header("预测结果")
|
||||||
|
|
||||||
|
# 创建两列布局
|
||||||
|
col1, col2 = st.columns(2)
|
||||||
|
|
||||||
|
with col1:
|
||||||
|
# 显示流失风险分数
|
||||||
|
st.subheader("📈 流失风险评分")
|
||||||
|
|
||||||
|
# 创建风险评分可视化
|
||||||
|
risk_score = result["probability"]
|
||||||
|
risk_percentage = risk_score * 100
|
||||||
|
|
||||||
|
# 确定风险等级
|
||||||
|
if risk_score < 0.3:
|
||||||
|
risk_level = "低风险"
|
||||||
|
color = "green"
|
||||||
|
elif risk_score < 0.7:
|
||||||
|
risk_level = "中风险"
|
||||||
|
color = "orange"
|
||||||
|
else:
|
||||||
|
risk_level = "高风险"
|
||||||
|
color = "red"
|
||||||
|
|
||||||
|
# 使用进度条显示风险评分
|
||||||
|
st.progress(risk_score)
|
||||||
|
st.write(f"**风险等级:** <span style='color:{color}; font-weight:bold;'>{risk_level}</span>", unsafe_allow_html=True)
|
||||||
|
st.write(f"**风险概率:** {risk_percentage:.1f}%")
|
||||||
|
st.write(f"**预测结果:** {'⚠️ 可能流失' if result['churn'] else '✅ 不太可能流失'}")
|
||||||
|
|
||||||
|
with col2:
|
||||||
|
# 显示影响因素
|
||||||
|
st.subheader("🔍 影响因素分析")
|
||||||
|
|
||||||
|
# 检查是否有解释信息
|
||||||
|
if "explanation" in result:
|
||||||
|
for factor in result["explanation"]:
|
||||||
|
st.write(f"- {factor}")
|
||||||
|
else:
|
||||||
|
st.write("暂无影响因素分析")
|
||||||
|
|
||||||
|
# 显示详细特征
|
||||||
|
with st.expander("📋 详细客户信息"):
|
||||||
|
df_features = pd.DataFrame.from_dict(result["features"], orient="index", columns=["值"])
|
||||||
|
st.dataframe(df_features, use_container_width=True)
|
||||||
|
|
||||||
|
# 显示建议
|
||||||
|
st.subheader("💡 建议采取的行动")
|
||||||
|
if result["churn"]:
|
||||||
|
st.markdown("""
|
||||||
|
- 主动联系客户,了解其需求和不满
|
||||||
|
- 提供针对性的优惠活动,如折扣或礼品
|
||||||
|
- 分析客户使用习惯,推荐更适合的套餐
|
||||||
|
- 加强客户服务,提高客户满意度
|
||||||
|
""")
|
||||||
|
else:
|
||||||
|
st.markdown("""
|
||||||
|
- 继续保持良好的客户服务
|
||||||
|
- 定期推送个性化的优惠信息
|
||||||
|
- 关注客户使用行为变化
|
||||||
|
- 鼓励客户升级套餐或添加新服务
|
||||||
|
""")
|
||||||
|
|
||||||
|
def _show_data_statistics(self):
|
||||||
|
"""显示数据统计信息"""
|
||||||
|
st.markdown("---")
|
||||||
|
st.header("📊 数据统计信息")
|
||||||
|
|
||||||
|
# 创建模拟的流失数据统计
|
||||||
|
data = {
|
||||||
|
"合同类型": ["月付", "一年", "两年"],
|
||||||
|
"客户数": [4200, 2100, 732],
|
||||||
|
"流失率": [0.42, 0.18, 0.09]
|
||||||
|
}
|
||||||
|
|
||||||
|
df = pd.DataFrame(data)
|
||||||
|
|
||||||
|
# 显示合同类型与流失率的关系
|
||||||
|
fig = px.bar(df, x="合同类型", y="流失率", color="合同类型",
|
||||||
|
title="不同合同类型的客户流失率",
|
||||||
|
labels={"流失率": "流失率(%)"},
|
||||||
|
hover_data={"客户数": True})
|
||||||
|
fig.update_traces(hovertemplate="合同类型: %{x}<br>流失率: %{y:.1%}<br>客户数: %{customdata[0]}")
|
||||||
|
st.plotly_chart(fig, use_container_width=True)
|
||||||
|
|
||||||
|
def run(self):
|
||||||
|
"""运行应用"""
|
||||||
|
# 创建输入表单
|
||||||
|
features_dict, submit_button = self._create_input_form()
|
||||||
|
|
||||||
|
# 当用户点击预测按钮时
|
||||||
|
if submit_button:
|
||||||
|
try:
|
||||||
|
# 验证输入并创建特征对象
|
||||||
|
features = CustomerFeatures(**features_dict)
|
||||||
|
|
||||||
|
# 进行预测
|
||||||
|
with st.spinner("正在预测..."):
|
||||||
|
result = self.inferencer.explain_prediction(features)
|
||||||
|
|
||||||
|
# 展示预测结果
|
||||||
|
self._display_prediction_result(result)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"预测过程中发生错误: {e}")
|
||||||
|
|
||||||
|
# 显示数据统计信息
|
||||||
|
self._show_data_statistics()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 启动应用
|
||||||
|
app = ChurnPredictionApp()
|
||||||
|
app.run()
|
||||||
310
ml_course_design/src/train.py
Normal file
310
ml_course_design/src/train.py
Normal file
@ -0,0 +1,310 @@
|
|||||||
|
import polars as pl
|
||||||
|
import pandas as pd
|
||||||
|
from sklearn.model_selection import train_test_split
|
||||||
|
from sklearn.preprocessing import StandardScaler, OneHotEncoder
|
||||||
|
from sklearn.compose import ColumnTransformer
|
||||||
|
from sklearn.pipeline import Pipeline
|
||||||
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
|
||||||
|
import lightgbm as lgb
|
||||||
|
import joblib
|
||||||
|
from pathlib import Path
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# 添加项目根目录到Python路径,解决直接运行时的导入问题
|
||||||
|
project_root = Path(__file__).parent.parent
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
from src.data import DataProcessor
|
||||||
|
from src.features import data_schema
|
||||||
|
|
||||||
|
|
||||||
|
class ModelTrainer:
|
||||||
|
"""模型训练类"""
|
||||||
|
|
||||||
|
def __init__(self, models_dir: str | Path = None):
|
||||||
|
"""初始化模型训练器
|
||||||
|
|
||||||
|
Args:
|
||||||
|
models_dir: 模型保存目录,如果为None则使用默认路径
|
||||||
|
"""
|
||||||
|
if models_dir is None:
|
||||||
|
self.models_dir = Path(__file__).parent.parent / "models"
|
||||||
|
else:
|
||||||
|
self.models_dir = Path(models_dir)
|
||||||
|
|
||||||
|
# 确保模型目录存在
|
||||||
|
self.models_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
def prepare_data(self) -> tuple:
|
||||||
|
"""准备训练数据
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
训练集、验证集和测试集(X_train, X_val, X_test, y_train, y_val, y_test)
|
||||||
|
"""
|
||||||
|
print("准备训练数据...")
|
||||||
|
|
||||||
|
# 加载和预处理数据
|
||||||
|
processor = DataProcessor()
|
||||||
|
X, y = processor.get_processed_data()
|
||||||
|
|
||||||
|
# 转换为pandas DataFrame以便与scikit-learn兼容
|
||||||
|
X_pandas = X.to_pandas()
|
||||||
|
y_pandas = y.to_pandas()
|
||||||
|
|
||||||
|
# 划分训练集和测试集 (80% train, 20% test)
|
||||||
|
X_train_val, X_test, y_train_val, y_test = train_test_split(
|
||||||
|
X_pandas, y_pandas, test_size=0.2, random_state=42, stratify=y_pandas
|
||||||
|
)
|
||||||
|
|
||||||
|
# 从训练集中划分验证集 (75% train, 25% val)
|
||||||
|
X_train, X_val, y_train, y_val = train_test_split(
|
||||||
|
X_train_val, y_train_val, test_size=0.25, random_state=42, stratify=y_train_val
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"训练集: {X_train.shape}")
|
||||||
|
print(f"验证集: {X_val.shape}")
|
||||||
|
print(f"测试集: {X_test.shape}")
|
||||||
|
|
||||||
|
return X_train, X_val, X_test, y_train, y_val, y_test
|
||||||
|
|
||||||
|
def create_preprocessor(self, X_train: pd.DataFrame) -> ColumnTransformer:
|
||||||
|
"""创建数据预处理管道
|
||||||
|
|
||||||
|
Args:
|
||||||
|
X_train: 训练集数据,用于获取特征信息
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
数据预处理管道
|
||||||
|
"""
|
||||||
|
print("创建数据预处理管道...")
|
||||||
|
|
||||||
|
# 分离数值特征和分类特征
|
||||||
|
numeric_features = X_train.select_dtypes(include=['int64', 'float64']).columns.tolist()
|
||||||
|
categorical_features = X_train.select_dtypes(include=['object']).columns.tolist()
|
||||||
|
|
||||||
|
print(f"数值特征: {numeric_features}")
|
||||||
|
print(f"分类特征: {categorical_features}")
|
||||||
|
|
||||||
|
# 创建数值特征处理管道
|
||||||
|
numeric_transformer = Pipeline(steps=[
|
||||||
|
('scaler', StandardScaler())
|
||||||
|
])
|
||||||
|
|
||||||
|
# 创建分类特征处理管道
|
||||||
|
categorical_transformer = Pipeline(steps=[
|
||||||
|
('onehot', OneHotEncoder(handle_unknown='ignore'))
|
||||||
|
])
|
||||||
|
|
||||||
|
# 创建完整的预处理管道
|
||||||
|
preprocessor = ColumnTransformer(
|
||||||
|
transformers=[
|
||||||
|
('num', numeric_transformer, numeric_features),
|
||||||
|
('cat', categorical_transformer, categorical_features)
|
||||||
|
])
|
||||||
|
|
||||||
|
return preprocessor
|
||||||
|
|
||||||
|
def train_logistic_regression(self, preprocessor: ColumnTransformer, X_train: pd.DataFrame, y_train: pd.Series) -> Pipeline:
|
||||||
|
"""训练Logistic Regression模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
preprocessor: 数据预处理管道
|
||||||
|
X_train: 训练集特征
|
||||||
|
y_train: 训练集目标变量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
训练好的Logistic Regression模型管道
|
||||||
|
"""
|
||||||
|
print("训练Logistic Regression模型...")
|
||||||
|
|
||||||
|
# 创建Logistic Regression模型管道
|
||||||
|
lr_pipeline = Pipeline(steps=[
|
||||||
|
('preprocessor', preprocessor),
|
||||||
|
('classifier', LogisticRegression(max_iter=1000, random_state=42))
|
||||||
|
])
|
||||||
|
|
||||||
|
# 训练模型
|
||||||
|
lr_pipeline.fit(X_train, y_train)
|
||||||
|
|
||||||
|
return lr_pipeline
|
||||||
|
|
||||||
|
def train_lightgbm(self, preprocessor: ColumnTransformer, X_train: pd.DataFrame, y_train: pd.Series) -> tuple:
|
||||||
|
"""训练LightGBM模型
|
||||||
|
|
||||||
|
Args:
|
||||||
|
preprocessor: 数据预处理管道
|
||||||
|
X_train: 训练集特征
|
||||||
|
y_train: 训练集目标变量
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
预处理后的特征、训练好的LightGBM模型
|
||||||
|
"""
|
||||||
|
print("训练LightGBM模型...")
|
||||||
|
|
||||||
|
# 预处理训练数据
|
||||||
|
X_train_preprocessed = preprocessor.fit_transform(X_train)
|
||||||
|
|
||||||
|
# 获取特征名称
|
||||||
|
num_features = preprocessor.transformers_[0][2]
|
||||||
|
cat_features = preprocessor.named_transformers_['cat'].get_feature_names_out()
|
||||||
|
feature_names = num_features + list(cat_features)
|
||||||
|
|
||||||
|
# 创建LightGBM数据集
|
||||||
|
lgb_train = lgb.Dataset(X_train_preprocessed, y_train, feature_name=feature_names)
|
||||||
|
|
||||||
|
# 设置LightGBM参数
|
||||||
|
params = {
|
||||||
|
'objective': 'binary',
|
||||||
|
'metric': 'binary_logloss',
|
||||||
|
'learning_rate': 0.05,
|
||||||
|
'n_estimators': 500,
|
||||||
|
'random_state': 42,
|
||||||
|
'verbose': -1
|
||||||
|
}
|
||||||
|
|
||||||
|
# 训练LightGBM模型
|
||||||
|
lgb_model = lgb.train(
|
||||||
|
params,
|
||||||
|
lgb_train,
|
||||||
|
num_boost_round=500,
|
||||||
|
valid_sets=[lgb_train],
|
||||||
|
callbacks=[lgb.log_evaluation(period=100)]
|
||||||
|
)
|
||||||
|
|
||||||
|
return feature_names, lgb_model
|
||||||
|
|
||||||
|
def evaluate_model(self, model: Pipeline | lgb.Booster, preprocessor: ColumnTransformer,
|
||||||
|
X: pd.DataFrame, y: pd.Series, model_name: str) -> dict:
|
||||||
|
"""评估模型性能
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: 要评估的模型
|
||||||
|
preprocessor: 数据预处理管道
|
||||||
|
X: 测试数据特征
|
||||||
|
y: 测试数据目标变量
|
||||||
|
model_name: 模型名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
模型性能指标
|
||||||
|
"""
|
||||||
|
print(f"评估{model_name}模型...")
|
||||||
|
|
||||||
|
# 预测概率
|
||||||
|
if isinstance(model, Pipeline):
|
||||||
|
y_pred_proba = model.predict_proba(X)[:, 1]
|
||||||
|
y_pred = model.predict(X)
|
||||||
|
else:
|
||||||
|
# LightGBM模型
|
||||||
|
X_preprocessed = preprocessor.transform(X)
|
||||||
|
y_pred_proba = model.predict(X_preprocessed)
|
||||||
|
y_pred = (y_pred_proba >= 0.5).astype(int)
|
||||||
|
|
||||||
|
# 计算性能指标
|
||||||
|
metrics = {
|
||||||
|
'accuracy': accuracy_score(y, y_pred),
|
||||||
|
'precision': precision_score(y, y_pred),
|
||||||
|
'recall': recall_score(y, y_pred),
|
||||||
|
'f1': f1_score(y, y_pred),
|
||||||
|
'roc_auc': roc_auc_score(y, y_pred_proba)
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"{model_name} 模型性能:")
|
||||||
|
for metric_name, value in metrics.items():
|
||||||
|
print(f" {metric_name}: {value:.4f}")
|
||||||
|
|
||||||
|
return metrics
|
||||||
|
|
||||||
|
def save_model(self, model: Pipeline | lgb.Booster, preprocessor: ColumnTransformer,
|
||||||
|
feature_names: list = None, model_name: str = "best_model"):
|
||||||
|
"""保存模型和预处理工具
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model: 要保存的模型
|
||||||
|
preprocessor: 数据预处理管道
|
||||||
|
feature_names: 特征名称列表(仅LightGBM需要)
|
||||||
|
model_name: 模型名称
|
||||||
|
"""
|
||||||
|
print(f"保存{model_name}模型...")
|
||||||
|
|
||||||
|
if isinstance(model, Pipeline):
|
||||||
|
# 保存完整的管道模型
|
||||||
|
model_path = self.models_dir / f"{model_name}.joblib"
|
||||||
|
joblib.dump(model, model_path)
|
||||||
|
else:
|
||||||
|
# 保存LightGBM模型
|
||||||
|
model_path = self.models_dir / f"{model_name}.joblib"
|
||||||
|
joblib.dump(model, model_path)
|
||||||
|
|
||||||
|
# 保存预处理管道
|
||||||
|
preprocessor_path = self.models_dir / "preprocessor.joblib"
|
||||||
|
joblib.dump(preprocessor, preprocessor_path)
|
||||||
|
|
||||||
|
# 保存特征名称
|
||||||
|
features_path = self.models_dir / "features.joblib"
|
||||||
|
joblib.dump(feature_names, features_path)
|
||||||
|
|
||||||
|
print(f"模型保存成功: {model_path}")
|
||||||
|
|
||||||
|
def train_and_evaluate(self):
|
||||||
|
"""完整的训练和评估流程"""
|
||||||
|
print("开始模型训练和评估流程...")
|
||||||
|
|
||||||
|
# 1. 准备数据
|
||||||
|
X_train, X_val, X_test, y_train, y_val, y_test = self.prepare_data()
|
||||||
|
|
||||||
|
# 2. 创建预处理管道
|
||||||
|
preprocessor = self.create_preprocessor(X_train)
|
||||||
|
|
||||||
|
# 3. 训练Logistic Regression模型
|
||||||
|
lr_model = self.train_logistic_regression(preprocessor, X_train, y_train)
|
||||||
|
|
||||||
|
# 4. 评估Logistic Regression模型
|
||||||
|
lr_train_metrics = self.evaluate_model(lr_model, preprocessor, X_train, y_train, "Logistic Regression (训练集)")
|
||||||
|
lr_val_metrics = self.evaluate_model(lr_model, preprocessor, X_val, y_val, "Logistic Regression (验证集)")
|
||||||
|
|
||||||
|
# 5. 训练LightGBM模型
|
||||||
|
feature_names, lgb_model = self.train_lightgbm(preprocessor, X_train, y_train)
|
||||||
|
|
||||||
|
# 6. 评估LightGBM模型
|
||||||
|
lgb_train_metrics = self.evaluate_model(lgb_model, preprocessor, X_train, y_train, "LightGBM (训练集)")
|
||||||
|
lgb_val_metrics = self.evaluate_model(lgb_model, preprocessor, X_val, y_val, "LightGBM (验证集)")
|
||||||
|
|
||||||
|
# 7. 选择最佳模型
|
||||||
|
print("\n选择最佳模型...")
|
||||||
|
best_model = None
|
||||||
|
best_model_name = ""
|
||||||
|
|
||||||
|
if lr_val_metrics['roc_auc'] > lgb_val_metrics['roc_auc']:
|
||||||
|
best_model = lr_model
|
||||||
|
best_model_name = "Logistic Regression"
|
||||||
|
else:
|
||||||
|
best_model = lgb_model
|
||||||
|
best_model_name = "LightGBM"
|
||||||
|
|
||||||
|
print(f"最佳模型: {best_model_name}")
|
||||||
|
|
||||||
|
# 8. 在测试集上评估最佳模型
|
||||||
|
print(f"\n在测试集上评估{best_model_name}模型...")
|
||||||
|
if isinstance(best_model, Pipeline):
|
||||||
|
best_test_metrics = self.evaluate_model(best_model, preprocessor, X_test, y_test, "Best Model (测试集)")
|
||||||
|
else:
|
||||||
|
best_test_metrics = self.evaluate_model(best_model, preprocessor, X_test, y_test, "Best Model (测试集)")
|
||||||
|
|
||||||
|
# 9. 保存最佳模型
|
||||||
|
if isinstance(best_model, Pipeline):
|
||||||
|
self.save_model(best_model, preprocessor, model_name="best_model_lr")
|
||||||
|
else:
|
||||||
|
self.save_model(best_model, preprocessor, feature_names, model_name="best_model")
|
||||||
|
|
||||||
|
print("\n模型训练和评估流程完成!")
|
||||||
|
|
||||||
|
return best_model, best_test_metrics
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
# 运行模型训练和评估
|
||||||
|
trainer = ModelTrainer()
|
||||||
|
best_model, test_metrics = trainer.train_and_evaluate()
|
||||||
3918
ml_course_design/uv.lock
generated
Normal file
3918
ml_course_design/uv.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
30
run_agent.bat
Normal file
30
run_agent.bat
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
@echo off
|
||||||
|
REM 电信客户流失预测Agent启动脚本
|
||||||
|
|
||||||
|
set "PROJECT_ROOT=%~dp0ml_course_design"
|
||||||
|
|
||||||
|
REM 检查项目根目录是否存在
|
||||||
|
if not exist "%PROJECT_ROOT%" (
|
||||||
|
echo 错误: 项目根目录不存在于 "%PROJECT_ROOT%"
|
||||||
|
echo 请确保该脚本与 ml_course_design 文件夹位于同一目录下
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
REM 检查uv是否已安装
|
||||||
|
where uv >nul 2>nul
|
||||||
|
if %errorlevel% neq 0 (
|
||||||
|
echo 错误: 未找到uv命令
|
||||||
|
echo 请先安装uv: pip install uv
|
||||||
|
pause
|
||||||
|
exit /b 1
|
||||||
|
)
|
||||||
|
|
||||||
|
REM 切换到项目根目录并启动Agent应用
|
||||||
|
echo 正在启动客户流失预测Agent...
|
||||||
|
echo 项目根目录: %PROJECT_ROOT%
|
||||||
|
cd /d "%PROJECT_ROOT%"
|
||||||
|
uv run python -m src.agent_app
|
||||||
|
|
||||||
|
REM 等待用户按下任意键退出
|
||||||
|
pause
|
||||||
32
run_agent.ps1
Normal file
32
run_agent.ps1
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
# 电信客户流失预测Agent启动脚本 (PowerShell版本)
|
||||||
|
|
||||||
|
# 设置项目根目录
|
||||||
|
$ProjectRoot = "$PSScriptRoot\ml_course_design"
|
||||||
|
|
||||||
|
# 检查项目根目录是否存在
|
||||||
|
if(-not (Test-Path $ProjectRoot)){
|
||||||
|
Write-Host "错误: 项目根目录不存在于 $ProjectRoot" -ForegroundColor Red
|
||||||
|
Write-Host "请确保该脚本与 ml_course_design 文件夹位于同一目录下" -ForegroundColor Yellow
|
||||||
|
Pause
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# 检查uv是否已安装
|
||||||
|
if(-not (Get-Command "uv" -ErrorAction SilentlyContinue)){
|
||||||
|
Write-Host "错误: 未找到uv命令" -ForegroundColor Red
|
||||||
|
Write-Host "请先安装uv: pip install uv" -ForegroundColor Yellow
|
||||||
|
Pause
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# 切换到项目根目录并启动Agent应用
|
||||||
|
Write-Host "正在启动客户流失预测Agent..." -ForegroundColor Green
|
||||||
|
Write-Host "项目根目录: $ProjectRoot" -ForegroundColor Cyan
|
||||||
|
Set-Location -Path $ProjectRoot
|
||||||
|
|
||||||
|
# 启动Agent应用
|
||||||
|
uv run python -m src.agent_app
|
||||||
|
|
||||||
|
# 等待用户按下任意键退出
|
||||||
|
Write-Host "\n按任意键退出..." -ForegroundColor Gray
|
||||||
|
$x = $host.ui.RawUI.ReadKey("NoEcho,IncludeKeyDown")
|
||||||
57
start_agent.py
Normal file
57
start_agent.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
电信客户流失预测Agent启动脚本
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import subprocess
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# 获取当前脚本所在目录
|
||||||
|
script_dir = Path(__file__).parent
|
||||||
|
|
||||||
|
# 项目根目录
|
||||||
|
project_root = script_dir / "ml_course_design"
|
||||||
|
|
||||||
|
print(f"当前脚本目录: {script_dir}")
|
||||||
|
print(f"项目根目录: {project_root}")
|
||||||
|
|
||||||
|
# 检查项目根目录是否存在
|
||||||
|
if not project_root.exists():
|
||||||
|
print(f"错误: 项目根目录不存在于 {project_root}")
|
||||||
|
print("请确保该脚本与 ml_course_design 文件夹位于同一目录下")
|
||||||
|
input("按回车键退出...")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# 检查uv是否已安装
|
||||||
|
try:
|
||||||
|
subprocess.run(["uv", "--version"], check=True, capture_output=True, text=True)
|
||||||
|
except subprocess.CalledProcessError:
|
||||||
|
print("错误: 未找到uv命令")
|
||||||
|
print("请先安装uv: pip install uv")
|
||||||
|
input("按回车键退出...")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# 切换到项目根目录并启动Agent应用
|
||||||
|
print("正在启动客户流失预测Agent...")
|
||||||
|
print(f"\n使用以下命令启动Agent:")
|
||||||
|
print(f"cd {project_root} && uv run python -m src.agent_app")
|
||||||
|
|
||||||
|
# 执行命令
|
||||||
|
try:
|
||||||
|
subprocess.run(
|
||||||
|
["uv", "run", "python", "-m", "src.agent_app"],
|
||||||
|
cwd=str(project_root),
|
||||||
|
check=True
|
||||||
|
)
|
||||||
|
except subprocess.CalledProcessError as e:
|
||||||
|
print(f"启动失败: {e}")
|
||||||
|
input("按回车键退出...")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
sys.exit(main())
|
||||||
34
test_agent.py
Normal file
34
test_agent.py
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# 获取项目根目录
|
||||||
|
project_root = Path(r"c:\Users\HUANGYING\Desktop\jqxx\新建文件夹 (4)\ml_course_design")
|
||||||
|
|
||||||
|
# 将项目根目录添加到Python路径
|
||||||
|
sys.path.insert(0, str(project_root))
|
||||||
|
|
||||||
|
print(f"项目根目录: {project_root}")
|
||||||
|
print(f"Python路径中包含项目根目录: {str(project_root) in sys.path}")
|
||||||
|
|
||||||
|
# 测试导入
|
||||||
|
print("\n测试导入...")
|
||||||
|
try:
|
||||||
|
from src.agent_app import ChurnPredictionAgent
|
||||||
|
print("✅ 成功导入ChurnPredictionAgent!")
|
||||||
|
|
||||||
|
# 测试创建Agent实例
|
||||||
|
agent = ChurnPredictionAgent()
|
||||||
|
print("✅ 成功创建Agent实例!")
|
||||||
|
|
||||||
|
print(f"\n🎉 测试成功! 现在可以使用以下命令运行Agent应用:")
|
||||||
|
print(r" cd 'c:\Users\HUANGYING\Desktop\jqxx\新建文件夹 (4)\ml_course_design' ; uv run python -m src.agent_app")
|
||||||
|
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"❌ 导入失败: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"❌ 其他错误: {e}")
|
||||||
|
import traceback
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
||||||
Loading…
Reference in New Issue
Block a user