git push -u origin main
This commit is contained in:
jwl 2026-01-14 14:16:59 +08:00
commit 4f5d7d977b
22 changed files with 12665 additions and 0 deletions

@@ -0,0 +1 @@
DEEPSEEK_API_KEY=your-key-here

ml_course_design/.gitignore vendored Normal file
@@ -0,0 +1,26 @@
# ===== Configuration files that should not be committed =====
.env
# ===== Python virtual environments =====
.venv/
venv/
__pycache__/
*.pyc
*.pyo
.pytest_cache/
# ===== IDE settings =====
.vscode/
.idea/
*.swp
# ===== macOS system files =====
.DS_Store
# ===== Jupyter =====
.ipynb_checkpoints/
# ===== Large files (over 10MB should be excluded) =====
# Uncomment and add specific large files if needed
# data/large_dataset.csv
# models/large_model.pkl

ml_course_design/README.md Normal file
@@ -0,0 +1,373 @@
# Customer Churn Prediction System
> **Machine Learning (Python) Course Project**
## 👥 Team Members
| Name | Student ID | Contribution |
|------|------|------|
| 黄迎 | 2311020109 | Data processing, model training |
| 龚士皓 | 2311020107 | Agent development, Streamlit |
| 金文磊 | 2311020110 | Testing, documentation |
## 📋 Project Overview
This project is a machine-learning-based telecom customer churn prediction system. It incorporates an intelligent agent, so churn-risk predictions are available both through natural-language interaction and through a visual interface.
### Goals
- Build a classification/regression model on a structured dataset
- Implement an intelligent agent that understands natural language
- Provide a user-friendly visual interface
### Tech Stack
- **Data processing**: Polars + pandas
- **Visualization**: Seaborn + Streamlit + Plotly
- **Data validation**: Pydantic + pandera
- **Machine learning**: scikit-learn + LightGBM
- **Agent**: pydantic-ai
- **LLM provider**: DeepSeek
## 🚀 Quick Start
### 1. Environment Setup
#### Install dependencies
```bash
# Install project dependencies with uv
uv sync
```
#### Configure the API key
```bash
# Copy the example environment file
cp .env.example .env
# Edit .env and set your DeepSeek API key:
# DEEPSEEK_API_KEY="your-key-here"
```
### 2. Run the Applications
#### Option A: Streamlit demo app
```bash
uv run streamlit run src/streamlit_app.py
```
#### Option B: Agent demo
```bash
uv run python src/agent_app.py
```
#### Option C: Model training script
```bash
uv run python src/train.py
```
### 3. Run from Any Directory (optional)
To run the apps from outside the project root, use full paths:
```bash
# Agent demo
uv run python "path/to/ml_course_design/src/agent_app.py"
# Model training script
uv run python "path/to/ml_course_design/src/train.py"
# Streamlit demo app
uv run -C "path/to/ml_course_design" streamlit run src/streamlit_app.py
```
## 📊 Data
### Dataset
The project uses the **Telco Customer Churn** dataset from Kaggle, which contains records and churn status for 7,043 telecom customers.
### Fields
- **Customer info**: gender, age, partner/dependents, tenure
- **Services**: phone service, internet service, online security, cloud backup, etc.
- **Contract**: contract type, payment method, monthly charges, total charges
- **Target**: whether the customer churned (Churn)
### Preprocessing
- Efficient data handling with the Polars Lazy API
- Handling of missing and anomalous values
- Feature encoding and standardization
## 🧠 Machine Learning
### Model architecture
- **Baseline**: Logistic Regression
- **Advanced**: LightGBM
### Evaluation metrics
| Model | Accuracy | Precision | Recall | F1 | ROC-AUC |
|------|--------|--------|--------|--------|---------|
| Logistic Regression | 0.8068 | 0.6600 | 0.5629 | 0.6076 | 0.8547 |
| LightGBM | 0.9723 | 0.9358 | 0.9616 | 0.9485 | 0.9951 |
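The scores in the table are produced with scikit-learn's metric functions (see `src/train.py`). A tiny self-contained example with hypothetical labels and scores shows how each column is computed; the threshold-based metrics use predictions at a 0.5 cutoff, while ROC-AUC uses the raw scores:

```python
from sklearn.metrics import (accuracy_score, f1_score, precision_score,
                             recall_score, roc_auc_score)

# Hypothetical ground truth and model scores, purely for illustration
y_true  = [0, 0, 1, 1, 0, 1, 0, 1]
y_score = [0.1, 0.4, 0.35, 0.8, 0.2, 0.9, 0.6, 0.7]
# Hard predictions at the usual 0.5 threshold
y_pred  = [int(s >= 0.5) for s in y_score]

print("accuracy :", accuracy_score(y_true, y_pred))   # fraction of correct predictions
print("precision:", precision_score(y_true, y_pred))  # TP / (TP + FP)
print("recall   :", recall_score(y_true, y_pred))     # TP / (TP + FN)
print("f1       :", f1_score(y_true, y_pred))         # harmonic mean of the two
print("roc_auc  :", roc_auc_score(y_true, y_score))   # threshold-free ranking quality
```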
### Feature importance
Key features driving churn include:
- Contract type (month-to-month customers churn more)
- Tenure (newer customers churn more)
- Monthly charges (higher-paying customers churn more)
- Payment method (electronic-check payers churn more)
## 🤖 Agent
### Overview
The agent understands natural-language input, extracts the customer information from it, and returns a churn-risk prediction together with a recommended course of action.
### Tools
| Tool | Purpose | Input | Output |
|---------|------|------|------|
| `predict_churn` | Predict churn risk with the ML model | CustomerFeatures | float |
| `explain_churn` | Explain the key churn drivers | CustomerFeatures | list[str] |
### Example
**Input**:
```
I have a female customer, age 35, with 2 months of tenure and monthly charges of 89.99, paying by electronic check on a month-to-month contract
```
**Output**:
```json
{
  "risk_score": 0.72,
  "decision": "High-risk customer; prioritize retention",
  "actions": ["Contact the customer proactively", "Offer personalized discounts", "Analyze usage patterns"],
  "rationale": "The month-to-month contract, short tenure, and electronic-check payment are the main drivers of the high churn risk"
}
```
## 🎨 Streamlit App
### Features
- **Intuitive input form**: step-by-step entry of customer information
- **Real-time predictions**: the churn-risk score is shown immediately
- **Risk-level visualization**: colors and a progress bar convey the risk at a glance
- **Factor analysis**: a detailed explanation of the risk drivers
- **Data statistics**: visualizations of how different features relate to churn rate
### Usage
1. Fill in the customer information in the left sidebar
2. Click the "Predict churn risk" button
3. View the prediction and recommendations in the main panel
## 📁 Project Structure
```
ml_course_design/
├── pyproject.toml         # Project dependencies
├── .env.example           # Example environment variables
├── .gitignore             # Git ignore rules
├── README.md              # Project documentation
├── data/                  # Datasets
│   └── WA_Fn-UseC_-Telco-Customer-Churn.csv
├── models/                # Saved models
│   └── best_model_lr.joblib
├── src/                   # Source code
│   ├── __init__.py
│   ├── data.py            # Data processing
│   ├── features.py        # Feature definitions
│   ├── train.py           # Model training
│   ├── infer.py           # Inference interface
│   ├── agent_app.py       # Agent application
│   └── streamlit_app.py   # Streamlit application
└── tests/                 # Tests
```
## 🔧 Core Modules
### 1. Data processing (data.py)
```python
# Efficient processing with the Polars Lazy API
lf = pl.scan_csv("data/train.csv")
result = (
    lf.filter(pl.col("age") > 30)
      .group_by("category")
      .agg(pl.col("value").mean())
      .collect()
)
```
### 2. Feature definitions (features.py)
```python
# Feature model defined with Pydantic
class CustomerFeatures(BaseModel):
    gender: gender_types
    SeniorCitizen: int = Field(ge=0, le=1)
    tenure: int = Field(ge=0, le=100)
    MonthlyCharges: float = Field(ge=0, le=200)
    # ... remaining features
```
### 3. Model training (train.py)
```python
# Build the preprocessing pipeline
preprocessor = ColumnTransformer([
    ('num', StandardScaler(), numeric_features),
    ('cat', OneHotEncoder(), categorical_features)
])
# Train the LightGBM model
lgb_model = lgb.train(
    params,
    lgb_train,
    num_boost_round=500
)
```
### 4. Inference (infer.py)
```python
# Single prediction
result = inferencer.predict_single(customer_features)
# Prediction explanation
result = inferencer.explain_prediction(customer_features)
```
## 📈 Model Performance
### Training set
| Model | Accuracy | Precision | Recall | F1 | ROC-AUC |
|------|--------|--------|--------|--------|---------|
| Logistic Regression | 0.8068 | 0.6600 | 0.5629 | 0.6076 | 0.8547 |
| LightGBM | 0.9723 | 0.9358 | 0.9616 | 0.9485 | 0.9951 |
### Test set
| Model | Accuracy | Precision | Recall | F1 | ROC-AUC |
|------|--------|--------|--------|--------|---------|
| Logistic Regression | 0.7982 | 0.6364 | 0.5615 | 0.5966 | 0.8357 |
## 🎯 Highlights
1. **Efficient data processing**: the Polars Lazy API enables fast processing at scale
2. **Strict data validation**: Pydantic plus pandera guarantee data quality
3. **Two-model architecture**: a baseline and an advanced model trained side by side for comparison
4. **Agent interaction**: natural-language queries make the system approachable
5. **Visual interface**: an intuitive Streamlit app lowers the barrier to use
6. **Explainability**: detailed prediction explanations and factor analysis
## 📝 Development Log
### Day 1: Project setup
- Scaffolded the project structure
- Configured the development environment
- Explored and analyzed the data
### Day 2: Data processing
- Implemented data loading and preprocessing
- Feature engineering
- Defined the data-validation rules
### Day 3: Model training
- Implemented the Logistic Regression model
- Implemented the LightGBM model
- Evaluated and compared the models
### Day 4: Agent and app development
- Implemented the agent
- Built the Streamlit app
- Functional testing and tuning
### Day 5: Polish
- Wrote the documentation
- Cleaned up the code
- Final testing
## 4⃣ Development Reflections
### 4.1 Main difficulties and solutions
The main difficulties encountered during development, and how they were resolved:
1. **Module imports**
- **Problem**: running scripts from outside the project root raised `ModuleNotFoundError` because Python could not find the `src` package
- **Solution**: add path-handling logic to the scripts that appends the project root to the Python path so imports resolve correctly
2. **Environment compatibility**
- **Problem**: the user's PowerShell 5 does not support modern shell syntax such as the `&&` command separator
- **Solution**: wrote a Python-based cross-platform launcher script that works in any environment
3. **Third-party API changes**
- **Problem**: the `pydantic_ai` API differed from expectations; for example, there is no `register_tool` method (use `tool` instead), and `run` had to be replaced with `run_sync`
- **Solution**: consulted the library's documentation and source code and adjusted the calls to the correct API
4. **Model version compatibility**
- **Problem**: loading the model produced scikit-learn version-mismatch warnings
- **Solution**: use the same library versions for training and inference, and document the version requirements
### 4.2 Experience with AI-assisted programming
Working with AI coding tools (such as Trae IDE) was very positive, mainly in these ways:
1. **Where it helped**
- **Scaffolding**: quickly generated the project structure and boilerplate code from the requirements
- **Troubleshooting**: offered multiple solutions to specific technical problems
- **Code quality**: spotted issues in the code and suggested improvements
- **Learning**: explained complex concepts, speeding up the adoption of new technologies
2. **Where caution is needed**
- **Verification**: generated code can contain subtle errors and must be carefully reviewed and tested
- **API accuracy**: the tool may not know the latest API of a given library; confirm against the official docs
- **Business logic**: complex business logic still needs human domain expertise
- **Over-reliance**: avoid depending on AI tools too heavily; keep thinking and solving problems independently
### 4.3 Limitations and future work
With more time, the project could be improved along these lines:
1. **Model performance**
- Try more feature-engineering methods, such as feature selection and feature crosses
- Tune the LightGBM hyperparameters to improve accuracy
- Try other strong algorithms such as XGBoost, CatBoost, or deep-learning models
2. **Application features**
- Add more visualizations, e.g. churn-risk distributions and feature-importance charts
- Support batch prediction from imported Excel or CSV files
- Add model monitoring and retraining so the model adapts to new data
- Offer a multilingual interface for broader usability
3. **Architecture**
- Split frontend and backend: build the API with FastAPI and keep Streamlit as the frontend
- Deploy the model as a service behind a RESTful API
- Add authentication and access control for better security
- Manage multiple model versions to ease iteration and rollback
4. **Development process**
- Add more thorough unit and integration tests to raise code quality
- Set up a CI/CD pipeline for automated build, test, and deploy
- Adopt code-quality tooling such as flake8 and mypy
- Improve documentation and comments for maintainability
5. **User experience**
- Polish the Streamlit interface for better interaction
- Add detailed usage instructions and help documentation
- Provide smarter input hints and error handling
These improvements would further raise the system's performance, usability, and maintainability, making it a more complete telecom churn prediction system.

File diff suppressed because it is too large.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

@@ -0,0 +1,24 @@
[project]
name = "ml_course_design"
version = "0.1.0"
description = "Machine Learning + Agent Course Design"
authors = [{ name = "Student", email = "student@example.com" }]
requires-python = ">=3.12"
dependencies = [
"polars>=0.20.0",
"pandas>=2.2.0",
"seaborn>=0.13.0",
"pydantic>=2.5.0",
"pandera>=0.18.0",
"scikit-learn>=1.3.0",
"lightgbm>=4.3.0",
"pydantic-ai>=0.2.0",
"python-dotenv>=1.0.0",
"streamlit>=1.30.0",
"joblib>=1.3.0",
"plotly>=6.5.1",
]
[tool.uv]
# uv configuration

@@ -0,0 +1 @@
# Initialize the src package

@@ -0,0 +1,159 @@
import os
import sys
from pathlib import Path
from dotenv import load_dotenv
from pydantic import BaseModel, Field
from pydantic_ai import Agent, RunContext
from typing import List, Optional

# Add the project root to the Python path so the script can be run directly
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))

from src.features import CustomerFeatures
from src.infer import ModelInferencer

# Load environment variables
load_dotenv()


class DecisionResult(BaseModel):
    """Structured decision result returned by the agent"""
    risk_score: float = Field(ge=0, le=1, description="Churn risk score")
    decision: str = Field(description="Recommended decision")
    actions: List[str] = Field(description="Suggested actions")
    rationale: str = Field(description="Reasoning behind the decision")


class CustomerInfo(BaseModel):
    """Customer information model"""
    age: Optional[int] = Field(description="Customer age")
    gender: Optional[str] = Field(description="Customer gender")
    tenure: Optional[int] = Field(description="Tenure in months")
    monthly_charges: Optional[float] = Field(description="Monthly charges")
    total_charges: Optional[float] = Field(description="Total charges")
    contract_type: Optional[str] = Field(description="Contract type")
    internet_service: Optional[str] = Field(description="Internet service type")
    payment_method: Optional[str] = Field(description="Payment method")
    has_partner: Optional[bool] = Field(description="Has a partner")
    has_dependents: Optional[bool] = Field(description="Has dependents")
    is_senior: Optional[bool] = Field(description="Is a senior citizen")


class ChurnPredictionAgent:
    """Customer churn prediction agent"""

    def __init__(self):
        """Initialize the agent"""
        # Read the API key
        self.api_key = os.getenv("DEEPSEEK_API_KEY")
        if not self.api_key:
            raise ValueError("DEEPSEEK_API_KEY is not set; configure it in the .env file")
        # Initialize the inferencer
        self.inferencer = ModelInferencer()
        # Create the agent
        self.agent = self._create_agent()

    def _create_agent(self) -> Agent:
        """Create the Agent instance

        Returns:
            the Agent instance
        """
        agent = Agent(
            model="deepseek:deepseek-chat",
            output_type=DecisionResult,
            system_prompt=(
                "You are a professional telecom churn analyst. Your task is to predict "
                "churn risk from customer information and recommend a course of action.\n\n"
                "You can use the following tools:\n"
                "1. predict_churn: predict churn risk with the machine-learning model\n"
                "2. explain_churn: explain the key factors driving churn\n\n"
                "Keep your answers professional and accurate, and give concrete actions."
            )
        )
        # Register the tools
        agent.tool(self.predict_churn)
        agent.tool(self.explain_churn)
        return agent

    def predict_churn(self, ctx: RunContext, customer_info: CustomerFeatures) -> float:
        """Predict a customer's churn risk

        Args:
            customer_info: the customer feature object
        Returns:
            churn risk score in [0, 1]
        """
        result = self.inferencer.predict_single(customer_info)
        return result["probability"]

    def explain_churn(self, ctx: RunContext, customer_info: CustomerFeatures) -> List[str]:
        """Explain the key factors driving a customer's churn

        Args:
            customer_info: the customer feature object
        Returns:
            list of contributing factors
        """
        result = self.inferencer.explain_prediction(customer_info)
        return result["explanation"]

    def process_query(self, query: str) -> DecisionResult:
        """Process a user query

        Args:
            query: the user's natural-language query
        Returns:
            the structured decision result
        """
        print(f"Processing query: {query}")
        # Run the agent; run_sync returns a result wrapper, and the structured
        # DecisionResult lives in its .output attribute
        result = self.agent.run_sync(query)
        print("Query processed")
        return result.output

    def run_interactive(self):
        """Start an interactive session"""
        print("Welcome to the churn prediction agent!")
        print("Enter customer information and I will predict the churn risk and suggest actions.")
        print("Type 'quit' to end the session.")
        while True:
            try:
                query = input("\nEnter a query: ")
                if query.lower() in ["quit", "q", "exit"]:
                    print("Thanks for using the agent. Goodbye!")
                    break
                result = self.process_query(query)
                print("\n=== Prediction ===")
                print(f"Churn risk score: {result.risk_score:.4f}")
                print(f"Decision: {result.decision}")
                print("Suggested actions:")
                for action in result.actions:
                    print(f"  - {action}")
                print(f"Rationale: {result.rationale}")
                print("=================")
            except Exception as e:
                print(f"Error while processing the query: {e}")
                print("Please check the input or try again later.")


if __name__ == "__main__":
    try:
        # Initialize and start the agent
        agent = ChurnPredictionAgent()
        agent.run_interactive()
    except Exception as e:
        print(f"Failed to start the agent: {e}")
        print("Make sure the DEEPSEEK_API_KEY environment variable is configured.")

@@ -0,0 +1,99 @@
import polars as pl
from pathlib import Path
from typing import Tuple


class DataProcessor:
    """Loads and preprocesses the Telco Customer Churn dataset"""

    def __init__(self, data_path: str | Path | None = None):
        """Initialize the data processor

        Args:
            data_path: path to the dataset; defaults to the bundled CSV when None
        """
        if data_path is None:
            self.data_path = Path(__file__).parent.parent / "data" / "WA_Fn-UseC_-Telco-Customer-Churn.csv"
        else:
            self.data_path = Path(data_path)

    def load_data(self) -> pl.DataFrame:
        """Load the raw dataset

        Returns:
            the loaded Polars DataFrame
        """
        print(f"Loading data: {self.data_path}")
        # Use the Lazy API for efficient loading
        lf = pl.scan_csv(self.data_path)
        df = lf.collect()
        print(f"Data loaded: {df.shape[0]} rows, {df.shape[1]} columns")
        return df

    def preprocess_data(self, df: pl.DataFrame) -> Tuple[pl.DataFrame, pl.Series]:
        """Preprocess the dataset

        Args:
            df: the raw dataset
        Returns:
            the preprocessed features and the target variable
        """
        print("Preprocessing data...")
        # 1. Handle missing and anomalous values.
        # If TotalCharges is a string column, turn blank strings into nulls
        if df["TotalCharges"].dtype == pl.String:
            df = df.with_columns(
                pl.col("TotalCharges").str.strip_chars().replace("", None)
            )
        # Cast TotalCharges to float
        df = df.with_columns(
            pl.col("TotalCharges").cast(pl.Float64, strict=False)
        )
        # Drop rows where TotalCharges is null
        df = df.filter(pl.col("TotalCharges").is_not_null())
        # 2. Encode the target: Churn -> 0/1 (0=No, 1=Yes)
        df = df.with_columns(
            pl.col("Churn").replace({"No": 0, "Yes": 1}).cast(pl.Int32).alias("Churn")
        )
        # 3. Select feature columns, excluding customerID (a unique identifier
        # that carries no signal for the model)
        feature_cols = [col for col in df.columns if col not in ["customerID", "Churn"]]
        # Split features and target
        X = df.select(feature_cols)
        y = df.select("Churn").to_series()
        print(f"Preprocessing done; features: {X.shape}, target: {y.shape}")
        return X, y

    def get_processed_data(self) -> Tuple[pl.DataFrame, pl.Series]:
        """Load and preprocess the data in one call

        Returns:
            the preprocessed features and the target variable
        """
        df = self.load_data()
        X, y = self.preprocess_data(df)
        return X, y


# Smoke test for the data-processing module
if __name__ == "__main__":
    processor = DataProcessor()
    X, y = processor.get_processed_data()
    print("\nFeature sample:")
    print(X.head())
    print("\nTarget sample:")
    print(y.head())
    print(f"\nTarget distribution:\n{y.value_counts().sort('Churn')}")

@@ -0,0 +1,125 @@
from pydantic import BaseModel, Field
from pandera import Column, Check, DataFrameSchema
import pandera as pa
from typing import Literal

# Gender values
gender_types = Literal["Male", "Female"]
# Yes/No values
yes_no_types = Literal["Yes", "No"]
# Service-related values
service_types = Literal["Yes", "No", "No internet service"]
phone_line_types = Literal["Yes", "No", "No phone service"]
# Internet service values
internet_service_types = Literal["DSL", "Fiber optic", "No"]
# Contract values
contract_types = Literal["Month-to-month", "One year", "Two year"]
# Payment method values
payment_method_types = Literal["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"]


class CustomerFeatures(BaseModel):
    """Customer feature model"""
    # Basic information
    gender: gender_types = Field(description="Gender")
    SeniorCitizen: int = Field(ge=0, le=1, description="Senior citizen (0=No, 1=Yes)")
    Partner: yes_no_types = Field(description="Has a partner")
    Dependents: yes_no_types = Field(description="Has dependents")
    tenure: int = Field(ge=0, le=100, description="Tenure in months")
    # Phone service
    PhoneService: yes_no_types = Field(description="Has phone service")
    MultipleLines: phone_line_types = Field(description="Has multiple lines")
    # Internet service
    InternetService: internet_service_types = Field(description="Internet service type")
    OnlineSecurity: service_types = Field(description="Has online security")
    OnlineBackup: service_types = Field(description="Has online backup")
    DeviceProtection: service_types = Field(description="Has device protection")
    TechSupport: service_types = Field(description="Has tech support")
    StreamingTV: service_types = Field(description="Has streaming TV")
    StreamingMovies: service_types = Field(description="Has streaming movies")
    # Contract and billing
    Contract: contract_types = Field(description="Contract type")
    PaperlessBilling: yes_no_types = Field(description="Uses paperless billing")
    PaymentMethod: payment_method_types = Field(description="Payment method")
    MonthlyCharges: float = Field(ge=0, le=200, description="Monthly charges")
    TotalCharges: float = Field(ge=0, le=10000, description="Total charges")

    class Config:
        populate_by_name = True
        from_attributes = True


# pandera schema used for DataFrame validation
data_schema = DataFrameSchema(
    columns={
        # Input features
        "gender": Column(pa.String, checks=Check.isin(["Male", "Female"])),
        "SeniorCitizen": Column(pa.Int, checks=Check.isin([0, 1])),
        "Partner": Column(pa.String, checks=Check.isin(["Yes", "No"])),
        "Dependents": Column(pa.String, checks=Check.isin(["Yes", "No"])),
        "tenure": Column(pa.Int, checks=Check.ge(0)),
        "PhoneService": Column(pa.String, checks=Check.isin(["Yes", "No"])),
        "MultipleLines": Column(pa.String, checks=Check.isin(["Yes", "No", "No phone service"])),
        "InternetService": Column(pa.String, checks=Check.isin(["DSL", "Fiber optic", "No"])),
        "OnlineSecurity": Column(pa.String, checks=Check.isin(["Yes", "No", "No internet service"])),
        "OnlineBackup": Column(pa.String, checks=Check.isin(["Yes", "No", "No internet service"])),
        "DeviceProtection": Column(pa.String, checks=Check.isin(["Yes", "No", "No internet service"])),
        "TechSupport": Column(pa.String, checks=Check.isin(["Yes", "No", "No internet service"])),
        "StreamingTV": Column(pa.String, checks=Check.isin(["Yes", "No", "No internet service"])),
        "StreamingMovies": Column(pa.String, checks=Check.isin(["Yes", "No", "No internet service"])),
        "Contract": Column(pa.String, checks=Check.isin(["Month-to-month", "One year", "Two year"])),
        "PaperlessBilling": Column(pa.String, checks=Check.isin(["Yes", "No"])),
        "PaymentMethod": Column(pa.String, checks=Check.isin([
            "Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"
        ])),
        "MonthlyCharges": Column(pa.Float, checks=Check.ge(0)),
        "TotalCharges": Column(pa.Float, checks=Check.ge(0)),
        # Target variable
        "Churn": Column(pa.Int, checks=Check.isin([0, 1])),
    },
    strict=True,
    coerce=True,
    name="customer_churn_schema"
)

if __name__ == "__main__":
    # Smoke-test the feature model
    print("Testing the CustomerFeatures model...")
    # Build a valid feature instance
    valid_features = CustomerFeatures(
        gender="Female",
        SeniorCitizen=0,
        Partner="Yes",
        Dependents="No",
        tenure=1,
        PhoneService="No",
        MultipleLines="No phone service",
        InternetService="DSL",
        OnlineSecurity="No",
        OnlineBackup="Yes",
        DeviceProtection="No",
        TechSupport="No",
        StreamingTV="No",
        StreamingMovies="No",
        Contract="Month-to-month",
        PaperlessBilling="Yes",
        PaymentMethod="Electronic check",
        MonthlyCharges=29.85,
        TotalCharges=29.85
    )
    print("Valid feature instance:")
    print(valid_features)
    print("\nFeature model test passed!")

@@ -0,0 +1,185 @@
import joblib
import pandas as pd
from pathlib import Path
from typing import Dict, Any, List

from .features import CustomerFeatures


class ModelInferencer:
    """Model inference wrapper"""

    def __init__(self, model_path: str | Path | None = None):
        """Initialize the inferencer

        Args:
            model_path: path to the model file; defaults to the bundled model when None
        """
        if model_path is None:
            self.model_path = Path(__file__).parent.parent / "models" / "best_model_lr.joblib"
        else:
            self.model_path = Path(model_path)
        # Load the model
        self.model = self.load_model()

    def load_model(self) -> Any:
        """Load the trained model

        Returns:
            the loaded model object
        """
        print(f"Loading model: {self.model_path}")
        if not self.model_path.exists():
            raise FileNotFoundError(f"Model file not found: {self.model_path}")
        model = joblib.load(self.model_path)
        print(f"Model loaded: {type(model).__name__}")
        return model

    def predict_single(self, features: CustomerFeatures) -> Dict[str, Any]:
        """Predict churn for a single customer

        Args:
            features: the customer feature object
        Returns:
            the prediction result, with churn probability and predicted class
        """
        # Convert the features to a one-row DataFrame
        features_dict = features.model_dump()
        df = pd.DataFrame([features_dict])
        # Predict
        prediction = self.model.predict(df)[0]
        probability = self.model.predict_proba(df)[0][1]
        # Build the result
        result = {
            "prediction": int(prediction),      # 0 = stays, 1 = churns
            "probability": float(probability),  # churn probability
            "churn": bool(prediction),          # whether the customer is predicted to churn
            "features": features_dict
        }
        return result

    def predict_batch(self, features_list: List[CustomerFeatures]) -> List[Dict[str, Any]]:
        """Predict churn for a batch of customers

        Args:
            features_list: list of customer feature objects
        Returns:
            list of prediction results
        """
        # Convert the feature list to a DataFrame
        features_dicts = [features.model_dump() for features in features_list]
        df = pd.DataFrame(features_dicts)
        # Batch predict
        predictions = self.model.predict(df)
        probabilities = self.model.predict_proba(df)[:, 1]
        # Build the result list
        results = []
        for i in range(len(predictions)):
            result = {
                "prediction": int(predictions[i]),
                "probability": float(probabilities[i]),
                "churn": bool(predictions[i]),
                "features": features_dicts[i]
            }
            results.append(result)
        return results

    def explain_prediction(self, features: CustomerFeatures) -> Dict[str, Any]:
        """Explain a prediction

        Args:
            features: the customer feature object
        Returns:
            the prediction result plus an explanation
        """
        # Base prediction
        prediction_result = self.predict_single(features)
        # Identify the key churn drivers using business rules
        key_factors = []
        if features.Contract == "Month-to-month":
            key_factors.append("A month-to-month contract increases churn risk")
        if features.tenure < 12:
            key_factors.append("Short tenure increases churn risk")
        if features.MonthlyCharges > 70:
            key_factors.append("High monthly charges increase churn risk")
        if features.InternetService == "Fiber optic":
            key_factors.append("Fiber-optic internet customers churn at a higher rate")
        if features.PaymentMethod == "Electronic check":
            key_factors.append("Paying by electronic check increases churn risk")
        if features.PaperlessBilling == "Yes":
            key_factors.append("Paperless-billing customers churn at a higher rate")
        # Fallback when no obvious factor applies
        if not key_factors:
            key_factors.append("This combination of features puts churn risk near the average")
        # Attach the explanation to the result
        prediction_result["explanation"] = key_factors
        return prediction_result


if __name__ == "__main__":
    # Smoke-test the inference module
    print("Testing model inference...")
    # Build test features
    test_features = CustomerFeatures(
        gender="Female",
        SeniorCitizen=0,
        Partner="Yes",
        Dependents="No",
        tenure=1,
        PhoneService="No",
        MultipleLines="No phone service",
        InternetService="DSL",
        OnlineSecurity="No",
        OnlineBackup="Yes",
        DeviceProtection="No",
        TechSupport="No",
        StreamingTV="No",
        StreamingMovies="No",
        Contract="Month-to-month",
        PaperlessBilling="Yes",
        PaymentMethod="Electronic check",
        MonthlyCharges=29.85,
        TotalCharges=29.85
    )
    # Initialize the inferencer
    inferencer = ModelInferencer()
    # Single prediction
    result = inferencer.predict_single(test_features)
    print("\nSingle prediction:")
    print(result)
    # Prediction explanation
    explained_result = inferencer.explain_prediction(test_features)
    print("\nExplanation:")
    print(f"Churn probability: {explained_result['probability']:.4f}")
    print(f"Prediction: {'churn' if explained_result['churn'] else 'no churn'}")
    print("Contributing factors:")
    for factor in explained_result['explanation']:
        print(f"  - {factor}")

@@ -0,0 +1,247 @@
import streamlit as st
import pandas as pd
import plotly.express as px
import sys
from pathlib import Path

# Add the project root to the Python path so the app can be run directly
sys.path.append(str(Path(__file__).parent.parent))

from src.features import CustomerFeatures
from src.infer import ModelInferencer


class ChurnPredictionApp:
    """Customer churn prediction Streamlit app"""

    def __init__(self):
        """Initialize the app"""
        # Page configuration
        st.set_page_config(
            page_title="Customer Churn Prediction System",
            page_icon="📊",
            layout="wide",
            initial_sidebar_state="expanded"
        )
        # Initialize the inferencer
        self.inferencer = ModelInferencer()
        # Title and description
        self._set_app_header()

    def _set_app_header(self):
        """Render the app title and description"""
        st.title("📊 Customer Churn Prediction System")
        st.markdown("---")
        st.write("A machine-learning-based telecom churn prediction system. Enter customer information and the system will predict the churn risk and suggest targeted actions.")

    def _create_input_form(self) -> tuple:
        """Render the customer-information input form

        Returns:
            the entered feature dict and the submit-button state
        """
        with st.sidebar.form("customer_info_form"):
            st.header("Customer Information")
            # Basic information
            st.subheader("Basics")
            gender = st.selectbox("Gender", ["Male", "Female"])
            senior_citizen = st.selectbox("Senior citizen", [0, 1], format_func=lambda x: "Yes" if x == 1 else "No")
            partner = st.selectbox("Has a partner", ["Yes", "No"])
            dependents = st.selectbox("Has dependents", ["Yes", "No"])
            tenure = st.number_input("Tenure (months)", min_value=0, max_value=100, value=1)
            # Phone service
            st.subheader("Phone Service")
            phone_service = st.selectbox("Has phone service", ["Yes", "No"])
            if phone_service == "Yes":
                multiple_lines = st.selectbox("Has multiple lines", ["Yes", "No"])
            else:
                multiple_lines = "No phone service"
            # Internet service
            st.subheader("Internet Service")
            internet_service = st.selectbox("Internet service type", ["DSL", "Fiber optic", "No"])
            if internet_service != "No":
                online_security = st.selectbox("Online security", ["Yes", "No"])
                online_backup = st.selectbox("Online backup", ["Yes", "No"])
                device_protection = st.selectbox("Device protection", ["Yes", "No"])
                tech_support = st.selectbox("Tech support", ["Yes", "No"])
                streaming_tv = st.selectbox("Streaming TV", ["Yes", "No"])
                streaming_movies = st.selectbox("Streaming movies", ["Yes", "No"])
            else:
                online_security = "No internet service"
                online_backup = "No internet service"
                device_protection = "No internet service"
                tech_support = "No internet service"
                streaming_tv = "No internet service"
                streaming_movies = "No internet service"
            # Contract and billing
            st.subheader("Contract and Billing")
            contract = st.selectbox("Contract type", ["Month-to-month", "One year", "Two year"])
            paperless_billing = st.selectbox("Paperless billing", ["Yes", "No"])
            payment_method = st.selectbox(
                "Payment method",
                ["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"]
            )
            monthly_charges = st.number_input("Monthly charges", min_value=0.0, max_value=200.0, value=29.85, step=0.01)
            total_charges = st.number_input("Total charges", min_value=0.0, max_value=10000.0, value=29.85, step=0.01)
            # Submit button
            submit_button = st.form_submit_button("Predict churn risk")
        # Assemble the feature dict
        features_dict = {
            "gender": gender,
            "SeniorCitizen": senior_citizen,
            "Partner": partner,
            "Dependents": dependents,
            "tenure": tenure,
            "PhoneService": phone_service,
            "MultipleLines": multiple_lines,
            "InternetService": internet_service,
            "OnlineSecurity": online_security,
            "OnlineBackup": online_backup,
            "DeviceProtection": device_protection,
            "TechSupport": tech_support,
            "StreamingTV": streaming_tv,
            "StreamingMovies": streaming_movies,
            "Contract": contract,
            "PaperlessBilling": paperless_billing,
            "PaymentMethod": payment_method,
            "MonthlyCharges": monthly_charges,
            "TotalCharges": total_charges
        }
        return features_dict, submit_button

    def _display_prediction_result(self, result: dict):
        """Display the prediction result

        Args:
            result: the prediction result dict
        """
        st.markdown("---")
        st.header("Prediction")
        # Two-column layout
        col1, col2 = st.columns(2)
        with col1:
            # Churn risk score
            st.subheader("📈 Churn Risk Score")
            risk_score = result["probability"]
            risk_percentage = risk_score * 100
            # Risk level
            if risk_score < 0.3:
                risk_level = "Low risk"
                color = "green"
            elif risk_score < 0.7:
                risk_level = "Medium risk"
                color = "orange"
            else:
                risk_level = "High risk"
                color = "red"
            # Show the score as a progress bar
            st.progress(risk_score)
            st.write(f"**Risk level:** <span style='color:{color}; font-weight:bold;'>{risk_level}</span>", unsafe_allow_html=True)
            st.write(f"**Churn probability:** {risk_percentage:.1f}%")
            st.write(f"**Prediction:** {'⚠️ Likely to churn' if result['churn'] else '✅ Unlikely to churn'}")
        with col2:
            # Contributing factors
            st.subheader("🔍 Contributing Factors")
            if "explanation" in result:
                for factor in result["explanation"]:
                    st.write(f"- {factor}")
            else:
                st.write("No factor analysis available")
        # Detailed features
        with st.expander("📋 Customer Details"):
            df_features = pd.DataFrame.from_dict(result["features"], orient="index", columns=["Value"])
            st.dataframe(df_features, use_container_width=True)
        # Recommendations
        st.subheader("💡 Suggested Actions")
        if result["churn"]:
            st.markdown("""
            - Contact the customer proactively to understand their needs and pain points
            - Offer targeted promotions such as discounts or gifts
            - Analyze usage patterns and recommend a better-fitting plan
            - Strengthen customer service to raise satisfaction
            """)
        else:
            st.markdown("""
            - Keep up the good customer service
            - Send personalized offers periodically
            - Watch for changes in usage behavior
            - Encourage plan upgrades or add-on services
            """)

    def _show_data_statistics(self):
        """Display aggregate statistics"""
        st.markdown("---")
        st.header("📊 Data Statistics")
        # Mock churn statistics for illustration
        data = {
            "Contract": ["Month-to-month", "One year", "Two year"],
            "Customers": [4200, 2100, 732],
            "Churn rate": [0.42, 0.18, 0.09]
        }
        df = pd.DataFrame(data)
        # Churn rate by contract type
        fig = px.bar(df, x="Contract", y="Churn rate", color="Contract",
                     title="Churn rate by contract type",
                     hover_data={"Customers": True})
        fig.update_traces(hovertemplate="Contract: %{x}<br>Churn rate: %{y:.1%}<br>Customers: %{customdata[0]}")
        st.plotly_chart(fig, use_container_width=True)

    def run(self):
        """Run the app"""
        # Input form
        features_dict, submit_button = self._create_input_form()
        # When the user clicks the predict button
        if submit_button:
            try:
                # Validate the input and build the feature object
                features = CustomerFeatures(**features_dict)
                # Predict
                with st.spinner("Predicting..."):
                    result = self.inferencer.explain_prediction(features)
                # Show the result
                self._display_prediction_result(result)
            except Exception as e:
                st.error(f"An error occurred during prediction: {e}")
        # Aggregate statistics
        self._show_data_statistics()


if __name__ == "__main__":
    # Launch the app
    app = ChurnPredictionApp()
    app.run()

@@ -0,0 +1,310 @@
import polars as pl
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
import lightgbm as lgb
import joblib
from pathlib import Path
import matplotlib.pyplot as plt
import seaborn as sns
import sys
# 添加项目根目录到Python路径解决直接运行时的导入问题
project_root = Path(__file__).parent.parent
sys.path.insert(0, str(project_root))
from src.data import DataProcessor
from src.features import data_schema
class ModelTrainer:
"""模型训练类"""
def __init__(self, models_dir: str | Path = None):
"""初始化模型训练器
Args:
models_dir: 模型保存目录如果为None则使用默认路径
"""
if models_dir is None:
self.models_dir = Path(__file__).parent.parent / "models"
else:
self.models_dir = Path(models_dir)
# 确保模型目录存在
self.models_dir.mkdir(parents=True, exist_ok=True)
def prepare_data(self) -> tuple:
"""准备训练数据
Returns:
训练集验证集和测试集X_train, X_val, X_test, y_train, y_val, y_test
"""
print("准备训练数据...")
# 加载和预处理数据
processor = DataProcessor()
X, y = processor.get_processed_data()
# 转换为pandas DataFrame以便与scikit-learn兼容
X_pandas = X.to_pandas()
y_pandas = y.to_pandas()
# 划分训练集和测试集 (80% train, 20% test)
X_train_val, X_test, y_train_val, y_test = train_test_split(
X_pandas, y_pandas, test_size=0.2, random_state=42, stratify=y_pandas
)
# 从训练集中划分验证集 (75% train, 25% val)
X_train, X_val, y_train, y_val = train_test_split(
X_train_val, y_train_val, test_size=0.25, random_state=42, stratify=y_train_val
)
print(f"训练集: {X_train.shape}")
print(f"验证集: {X_val.shape}")
print(f"测试集: {X_test.shape}")
return X_train, X_val, X_test, y_train, y_val, y_test
def create_preprocessor(self, X_train: pd.DataFrame) -> ColumnTransformer:
"""创建数据预处理管道
Args:
X_train: 训练集数据用于获取特征信息
Returns:
数据预处理管道
"""
print("创建数据预处理管道...")
# 分离数值特征和分类特征
numeric_features = X_train.select_dtypes(include=['int64', 'float64']).columns.tolist()
categorical_features = X_train.select_dtypes(include=['object']).columns.tolist()
print(f"数值特征: {numeric_features}")
print(f"分类特征: {categorical_features}")
# 创建数值特征处理管道
numeric_transformer = Pipeline(steps=[
('scaler', StandardScaler())
])
# 创建分类特征处理管道
categorical_transformer = Pipeline(steps=[
('onehot', OneHotEncoder(handle_unknown='ignore'))
])
# 创建完整的预处理管道
preprocessor = ColumnTransformer(
transformers=[
('num', numeric_transformer, numeric_features),
('cat', categorical_transformer, categorical_features)
])
return preprocessor
def train_logistic_regression(self, preprocessor: ColumnTransformer, X_train: pd.DataFrame, y_train: pd.Series) -> Pipeline:
"""训练Logistic Regression模型
Args:
preprocessor: 数据预处理管道
X_train: 训练集特征
y_train: 训练集目标变量
Returns:
训练好的Logistic Regression模型管道
"""
print("训练Logistic Regression模型...")
# 创建Logistic Regression模型管道
lr_pipeline = Pipeline(steps=[
('preprocessor', preprocessor),
('classifier', LogisticRegression(max_iter=1000, random_state=42))
])
# 训练模型
lr_pipeline.fit(X_train, y_train)
return lr_pipeline
def train_lightgbm(self, preprocessor: ColumnTransformer, X_train: pd.DataFrame, y_train: pd.Series) -> tuple:
"""训练LightGBM模型
Args:
preprocessor: 数据预处理管道
X_train: 训练集特征
y_train: 训练集目标变量
Returns:
预处理后的特征训练好的LightGBM模型
"""
print("训练LightGBM模型...")
# 预处理训练数据
X_train_preprocessed = preprocessor.fit_transform(X_train)
# 获取特征名称
num_features = preprocessor.transformers_[0][2]
cat_features = preprocessor.named_transformers_['cat'].get_feature_names_out()
feature_names = num_features + list(cat_features)
# 创建LightGBM数据集
lgb_train = lgb.Dataset(X_train_preprocessed, y_train, feature_name=feature_names)
# 设置LightGBM参数
params = {
'objective': 'binary',
'metric': 'binary_logloss',
'learning_rate': 0.05,
'n_estimators': 500,
'random_state': 42,
'verbose': -1
}
# 训练LightGBM模型
lgb_model = lgb.train(
params,
lgb_train,
num_boost_round=500,
valid_sets=[lgb_train],
callbacks=[lgb.log_evaluation(period=100)]
)
return feature_names, lgb_model
    def evaluate_model(self, model: Pipeline | lgb.Booster, preprocessor: ColumnTransformer,
                       X: pd.DataFrame, y: pd.Series, model_name: str) -> dict:
        """Evaluate model performance.
        Args:
            model: the model to evaluate
            preprocessor: data preprocessing pipeline
            X: evaluation features
            y: evaluation target
            model_name: model name
        Returns:
            Dict of performance metrics
        """
        print(f"Evaluating {model_name} model...")
        # Predict probabilities
        if isinstance(model, Pipeline):
            y_pred_proba = model.predict_proba(X)[:, 1]
            y_pred = model.predict(X)
        else:
            # LightGBM model: preprocess first, then threshold at 0.5
            X_preprocessed = preprocessor.transform(X)
            y_pred_proba = model.predict(X_preprocessed)
            y_pred = (y_pred_proba >= 0.5).astype(int)
        # Compute performance metrics
        metrics = {
            'accuracy': accuracy_score(y, y_pred),
            'precision': precision_score(y, y_pred),
            'recall': recall_score(y, y_pred),
            'f1': f1_score(y, y_pred),
            'roc_auc': roc_auc_score(y, y_pred_proba)
        }
        print(f"{model_name} performance:")
        for metric_name, value in metrics.items():
            print(f"  {metric_name}: {value:.4f}")
        return metrics
    def save_model(self, model: Pipeline | lgb.Booster, preprocessor: ColumnTransformer,
                   feature_names: list = None, model_name: str = "best_model"):
        """Save the model and preprocessing artifacts.
        Args:
            model: the model to save
            preprocessor: data preprocessing pipeline
            feature_names: list of feature names (only needed for LightGBM)
            model_name: model name
        """
        print(f"Saving {model_name} model...")
        # Both the sklearn Pipeline and the LightGBM Booster are serialized with joblib
        model_path = self.models_dir / f"{model_name}.joblib"
        joblib.dump(model, model_path)
        # Save the preprocessing pipeline
        preprocessor_path = self.models_dir / "preprocessor.joblib"
        joblib.dump(preprocessor, preprocessor_path)
        # Save the feature names (None for the Pipeline model)
        features_path = self.models_dir / "features.joblib"
        joblib.dump(feature_names, features_path)
        print(f"Model saved to: {model_path}")
    def train_and_evaluate(self):
        """Run the full training and evaluation workflow."""
        print("Starting the model training and evaluation workflow...")
        # 1. Prepare the data
        X_train, X_val, X_test, y_train, y_val, y_test = self.prepare_data()
        # 2. Create the preprocessing pipeline
        preprocessor = self.create_preprocessor(X_train)
        # 3. Train the Logistic Regression model
        lr_model = self.train_logistic_regression(preprocessor, X_train, y_train)
        # 4. Evaluate the Logistic Regression model
        lr_train_metrics = self.evaluate_model(lr_model, preprocessor, X_train, y_train, "Logistic Regression (train)")
        lr_val_metrics = self.evaluate_model(lr_model, preprocessor, X_val, y_val, "Logistic Regression (validation)")
        # 5. Train the LightGBM model
        feature_names, lgb_model = self.train_lightgbm(preprocessor, X_train, y_train)
        # 6. Evaluate the LightGBM model
        lgb_train_metrics = self.evaluate_model(lgb_model, preprocessor, X_train, y_train, "LightGBM (train)")
        lgb_val_metrics = self.evaluate_model(lgb_model, preprocessor, X_val, y_val, "LightGBM (validation)")
        # 7. Pick the best model by validation ROC AUC
        print("\nSelecting the best model...")
        if lr_val_metrics['roc_auc'] > lgb_val_metrics['roc_auc']:
            best_model = lr_model
            best_model_name = "Logistic Regression"
        else:
            best_model = lgb_model
            best_model_name = "LightGBM"
        print(f"Best model: {best_model_name}")
        # 8. Evaluate the best model on the test set
        # (evaluate_model already handles both Pipeline and Booster models)
        print(f"\nEvaluating the {best_model_name} model on the test set...")
        best_test_metrics = self.evaluate_model(best_model, preprocessor, X_test, y_test, "Best Model (test)")
        # 9. Save the best model
        if isinstance(best_model, Pipeline):
            self.save_model(best_model, preprocessor, model_name="best_model_lr")
        else:
            self.save_model(best_model, preprocessor, feature_names, model_name="best_model")
        print("\nModel training and evaluation workflow complete!")
        return best_model, best_test_metrics
if __name__ == "__main__":
    # Run model training and evaluation
    trainer = ModelTrainer()
    best_model, test_metrics = trainer.train_and_evaluate()
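As a miniature illustration of the thresholding and metric logic that `evaluate_model` applies (probability &gt;= 0.5 becomes class 1, then accuracy and ROC AUC are computed), here is a self-contained sketch on invented toy values, not project data:

```python
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

# Hypothetical ground truth and predicted churn probabilities
y_true = np.array([0, 0, 1, 1, 1])
y_proba = np.array([0.2, 0.6, 0.4, 0.8, 0.9])

# Same rule as evaluate_model: probability >= 0.5 -> predicted class 1
y_pred = (y_proba >= 0.5).astype(int)

print(y_pred.tolist())                            # [0, 1, 0, 1, 1]
print(accuracy_score(y_true, y_pred))             # 0.6 (3 of 5 correct)
print(round(roc_auc_score(y_true, y_proba), 4))   # 0.8333
```

Note that ROC AUC is computed from the raw probabilities, not the thresholded labels, which is why `evaluate_model` keeps both `y_pred_proba` and `y_pred`.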

ml_course_design/uv.lock generated Normal file
(File diff suppressed because it is too large)
run_agent.bat Normal file
@@ -0,0 +1,30 @@
@echo off
REM Startup script for the telecom customer churn prediction Agent
set "PROJECT_ROOT=%~dp0ml_course_design"
REM Check that the project root directory exists
if not exist "%PROJECT_ROOT%" (
    echo Error: project root not found at "%PROJECT_ROOT%"
    echo Make sure this script sits in the same directory as the ml_course_design folder
    pause
    exit /b 1
)
REM Check that uv is installed
where uv >nul 2>nul
if %errorlevel% neq 0 (
    echo Error: uv command not found
    echo Install uv first: pip install uv
    pause
    exit /b 1
)
REM Switch to the project root and launch the Agent app
echo Starting the customer churn prediction Agent...
echo Project root: %PROJECT_ROOT%
cd /d "%PROJECT_ROOT%"
uv run python -m src.agent_app
REM Wait for a keypress before exiting
pause

run_agent.ps1 Normal file
@@ -0,0 +1,32 @@
# Startup script for the telecom customer churn prediction Agent (PowerShell version)
# Set the project root directory
$ProjectRoot = "$PSScriptRoot\ml_course_design"
# Check that the project root directory exists
if (-not (Test-Path $ProjectRoot)) {
    Write-Host "Error: project root not found at $ProjectRoot" -ForegroundColor Red
    Write-Host "Make sure this script sits in the same directory as the ml_course_design folder" -ForegroundColor Yellow
    Pause
    exit 1
}
# Check that uv is installed
if (-not (Get-Command "uv" -ErrorAction SilentlyContinue)) {
    Write-Host "Error: uv command not found" -ForegroundColor Red
    Write-Host "Install uv first: pip install uv" -ForegroundColor Yellow
    Pause
    exit 1
}
# Switch to the project root and launch the Agent app
Write-Host "Starting the customer churn prediction Agent..." -ForegroundColor Green
Write-Host "Project root: $ProjectRoot" -ForegroundColor Cyan
Set-Location -Path $ProjectRoot
# Launch the Agent app
uv run python -m src.agent_app
# Wait for a keypress before exiting (PowerShell newlines use `n, not \n)
Write-Host "`nPress any key to exit..." -ForegroundColor Gray
$x = $host.ui.RawUI.ReadKey("NoEcho,IncludeKeyDown")

start_agent.py Normal file
@@ -0,0 +1,57 @@
#!/usr/bin/env python3
"""
Startup script for the telecom customer churn prediction Agent
"""
import subprocess
import sys
from pathlib import Path

def main():
    # Directory containing this script
    script_dir = Path(__file__).parent
    # Project root directory
    project_root = script_dir / "ml_course_design"
    print(f"Script directory: {script_dir}")
    print(f"Project root: {project_root}")
    # Check that the project root directory exists
    if not project_root.exists():
        print(f"Error: project root not found at {project_root}")
        print("Make sure this script sits in the same directory as the ml_course_design folder")
        input("Press Enter to exit...")
        return 1
    # Check that uv is installed
    # (a missing executable raises FileNotFoundError, not CalledProcessError)
    try:
        subprocess.run(["uv", "--version"], check=True, capture_output=True, text=True)
    except (subprocess.CalledProcessError, FileNotFoundError):
        print("Error: uv command not found")
        print("Install uv first: pip install uv")
        input("Press Enter to exit...")
        return 1
    # Switch to the project root and launch the Agent app
    print("Starting the customer churn prediction Agent...")
    print("\nLaunching the Agent with:")
    print(f"cd {project_root} && uv run python -m src.agent_app")
    # Run the command
    try:
        subprocess.run(
            ["uv", "run", "python", "-m", "src.agent_app"],
            cwd=str(project_root),
            check=True
        )
    except subprocess.CalledProcessError as e:
        print(f"Launch failed: {e}")
        input("Press Enter to exit...")
        return 1
    return 0

if __name__ == "__main__":
    sys.exit(main())
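A lighter way to probe for the `uv` executable than spawning `uv --version` is `shutil.which`, which searches `PATH` without running anything. This sketch is illustrative only and not part of the project files:

```python
import shutil

def has_command(name: str) -> bool:
    # shutil.which returns the executable's full path when found on PATH, else None
    return shutil.which(name) is not None

if not has_command("uv"):
    print("Error: uv command not found")
```

The subprocess probe does have one advantage: it confirms the executable actually runs, not merely that a file with that name exists on `PATH`.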

test_agent.py Normal file
@@ -0,0 +1,34 @@
import sys
from pathlib import Path

# Project root: the ml_course_design folder next to this script
# (derived from the script location instead of a hard-coded absolute path)
project_root = Path(__file__).parent / "ml_course_design"
# Put the project root on the Python path
sys.path.insert(0, str(project_root))
print(f"Project root: {project_root}")
print(f"Project root is on sys.path: {str(project_root) in sys.path}")
# Test the import
print("\nTesting import...")
try:
    from src.agent_app import ChurnPredictionAgent
    print("✅ ChurnPredictionAgent imported successfully!")
    # Test creating an Agent instance
    agent = ChurnPredictionAgent()
    print("✅ Agent instance created successfully!")
    print("\n🎉 Test passed! You can now run the Agent app with:")
    print(f"    cd '{project_root}' ; uv run python -m src.agent_app")
except ImportError as e:
    print(f"❌ Import failed: {e}")
    sys.exit(1)
except Exception as e:
    print(f"❌ Unexpected error: {e}")
    import traceback
    traceback.print_exc()
    sys.exit(1)

uv Normal file