generated from Python-2026Spring/assignment-05-final-project-template
feat: 上传课程设计完整代码
This commit is contained in:
parent
a9196c99ff
commit
d2098227fb
900
README.md
900
README.md
@ -1,73 +1,875 @@
|
||||
# 机器学习 × LLM × Agent:课程设计(5 天)
|
||||
|
||||
> **小组作业** | 2–3 人/组 | 构建一个「可落地的智能预测与行动建议系统」
|
||||
|
||||
用传统机器学习完成可量化的预测任务,再用 LLM + Agent 把预测结果变成可执行的决策/建议,并保证输出结构化、可追溯、可复现。
|
||||
|
||||
---
|
||||
|
||||
## 📅 课程安排概览
|
||||
|
||||
| 天数 | 主题 | 内容 |
|
||||
|------|------|------|
|
||||
| **Day 1** | 项目启动 | 技术栈介绍 + 演示 + 选题分组 |
|
||||
| **Day 2** | 自主设计 | 分组开发 |
|
||||
| **Day 3** | 答疑 + Git 指导 | 集中答疑 + Git 提交教学 |
|
||||
| **Day 4** | 自主设计 | 继续开发 + 准备展示 |
|
||||
| **Day 5** | 小组展示 | 教师机运行 + 评分 |
|
||||
|
||||
---
|
||||
|
||||
## 📑 目录
|
||||
|
||||
- [Day 1:项目启动](#day-1项目启动)
|
||||
- [快速开始](#-快速开始)
|
||||
- [技术栈要求](#技术栈要求2026-版)
|
||||
- [选题指南](#选题指南)
|
||||
- [可选扩展思路](#可选扩展思路)
|
||||
- [Day 2:自主设计](#day-2自主设计)
|
||||
- [Day 3:答疑 + Git 指导](#day-3答疑--git-指导)
|
||||
- [Git 安装](#git-安装国内环境)
|
||||
- [Git 基础操作](#git-基础操作)
|
||||
- [.gitignore 详解](#gitignore-详解)
|
||||
- [Day 4:自主设计](#day-4自主设计)
|
||||
- [Day 5:小组展示](#day-5小组展示)
|
||||
- [展示流程](#展示流程)
|
||||
- [跨机运行检查清单](#跨机运行检查清单)
|
||||
- [评分标准](#评分标准总分-100)
|
||||
- [附录](#附录)
|
||||
- [代码示例](#代码示例)
|
||||
- [项目结构](#建议项目结构)
|
||||
- [参考资料](#参考资料)
|
||||
|
||||
---
|
||||
|
||||
# Day 1:项目启动
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
> **2026 最佳实践**:使用 `uv` 替代 pip/venv/poetry 进行全流程项目管理
|
||||
|
||||
```bash
|
||||
# 1. 安装 uv(如尚未安装)
|
||||
# 方法 A:使用 pip 安装(推荐,国内可用)
|
||||
pip install uv -i https://mirrors.aliyun.com/pypi/simple/
|
||||
|
||||
# 方法 B:使用 pipx 安装(隔离环境)
|
||||
pipx install uv
|
||||
|
||||
# 方法 C:官方脚本(需要科学上网)
|
||||
# macOS / Linux: curl -LsSf https://astral.sh/uv/install.sh | sh
|
||||
# Windows: powershell -c "irm https://astral.sh/uv/install.ps1 | iex"
|
||||
|
||||
# 配置 PyPI 镜像(加速依赖下载)
|
||||
uv config set index-url https://mirrors.aliyun.com/pypi/simple/
|
||||
|
||||
# 2. 克隆/Fork 本模板仓库
|
||||
git clone http://hblu.top:3000/MachineLearning2025/CourseDesign
|
||||
cd CourseDesign
|
||||
|
||||
# 3. 初始化项目并安装依赖(uv 自动创建虚拟环境)
|
||||
uv sync
|
||||
|
||||
# 4. 配置 DeepSeek API Key(不要提交到仓库!)
|
||||
cp .env.example .env
|
||||
# 编辑 .env 文件,填入你的 API Key
|
||||
# DEEPSEEK_API_KEY="your-key-here"
|
||||
|
||||
# 5. 运行示例
|
||||
# 方式 A:运行 Streamlit 可视化 Demo(推荐)
|
||||
uv run streamlit run src/streamlit_app.py
|
||||
|
||||
# 方式 B:运行命令行 Agent Demo
|
||||
uv run python src/agent_app.py
|
||||
|
||||
# 方式 C:运行训练脚本
|
||||
uv run python src/train.py
|
||||
```
|
||||
|
||||
### uv 常用命令速查
|
||||
|
||||
| 命令 | 说明 |
|
||||
|------|------|
|
||||
| `uv sync` | 同步依赖(根据 `pyproject.toml` 和 `uv.lock`) |
|
||||
| `uv add <package>` | 添加依赖(自动更新 `pyproject.toml` 和 `uv.lock`) |
|
||||
| `uv add --dev <package>` | 添加开发依赖(如 pytest, ruff) |
|
||||
| `uv run <command>` | 在项目环境中运行命令 |
|
||||
| `uv lock` | 手动更新锁文件 |
|
||||
| `uv python install 3.12` | 安装指定 Python 版本 |
|
||||
|
||||
---
|
||||
|
||||
## 技术栈要求(2026 版)
|
||||
|
||||
| 组件 | 要求 | 2026 最佳实践 |
|
||||
|------|------|---------------|
|
||||
| **人数** | 2–3 人/组 | — |
|
||||
| **Python 版本** | ≥ 3.12 | 推荐 3.12/3.14 |
|
||||
| **项目管理** | `uv` | 替代 pip/venv/poetry,10-100x 更快 |
|
||||
| **数据处理** | `polars` + `pandas>=2.2` | polars 作为主力(Lazy API),pandas 用于兼容 |
|
||||
| **数据可视化** | `seaborn>=0.13` | 使用 Seaborn Objects API(`so.Plot`) |
|
||||
| **数据验证** | `pydantic` + `pandera` | pydantic 验证单行/配置,pandera 验证 DataFrame 清洗前后 |
|
||||
| **机器学习** | `scikit-learn` + `lightgbm` | sklearn 做基线,LightGBM 做高性能模型 |
|
||||
| **Agent 框架** | `pydantic-ai` | 结构化输出、类型安全的 Agent |
|
||||
| **LLM 提供方** | `DeepSeek` | OpenAI 兼容 API |
|
||||
|
||||
### 必须包含的三块能力
|
||||
|
||||
| 能力 | 说明 |
|
||||
|------|------|
|
||||
| **传统机器学习** | 可复现训练流程、离线评估指标、模型保存与加载 |
|
||||
| **LLM** | 用于解释、归因、生成建议/回复、信息整合(不能凭空杜撰) |
|
||||
| **Agent** | 用工具调用把系统串起来(至少 2 个 tool,其中 1 个必须是 ML 预测/评估相关工具) |
|
||||
|
||||
---
|
||||
|
||||
## 选题指南
|
||||
|
||||
> ⚠️ **注意**:Level 1/2/3 **都可以拿满分**;高难度通常更容易体现"深度",但不会因为选 Level 1 就被封顶。
|
||||
|
||||
### Level 1|入门:表格预测 + 行动建议闭环
|
||||
|
||||
> 📌 **建议新手选择**
|
||||
|
||||
**目标**:做一个结构化数据的分类/回归模型,并让 Agent 基于模型输出给出可执行建议。
|
||||
|
||||
#### 推荐数据集
|
||||
|
||||
| 数据集 | 链接 |
|
||||
|--------|------|
|
||||
| Telco Customer Churn | [Kaggle](https://www.kaggle.com/datasets/blastchar/telco-customer-churn) |
|
||||
| German Credit Risk | [Kaggle](https://www.kaggle.com/datasets/uciml/german-credit) |
|
||||
| Bank Marketing | [Kaggle](https://www.kaggle.com/datasets/janiobachmann/bank-marketing-dataset) |
|
||||
| Heart Failure Prediction | [Kaggle](https://www.kaggle.com/datasets/fedesoriano/heart-failure-prediction) |
|
||||
|
||||
#### ✅ 必做部分
|
||||
|
||||
| 模块 | 要求 |
|
||||
|------|------|
|
||||
| **数据处理** | 使用 Polars 完成可复现的数据清洗流水线;使用 Pandera 定义 Schema |
|
||||
| **机器学习** | 至少 2 个模型对比(1 个基线如 LogReg,1 个强模型如 LightGBM);达到 `F1 ≥ 0.70` 或 `ROC-AUC ≥ 0.75` |
|
||||
| **Agent** | 使用 Pydantic 定义输入输出;至少 2 个 tool(含 1 个 ML 预测工具) |
|
||||
|
||||
---
|
||||
|
||||
### Level 2|进阶:文本任务 + 处置建议
|
||||
|
||||
> 📌 **NLP 向**
|
||||
|
||||
**目标**:做文本分类/情感分析,并让 Agent 生成结构化处置方案。
|
||||
|
||||
#### 推荐数据集
|
||||
|
||||
| 数据集 | 链接 | 说明 |
|
||||
|--------|------|------|
|
||||
| Twitter US Airline Sentiment | [Kaggle](https://www.kaggle.com/datasets/crowdflower/twitter-airline-sentiment) | 航空公司情感分析 |
|
||||
| IMDB 50K Movie Reviews | [Kaggle](https://www.kaggle.com/datasets/lakshmi25npathi/imdb-dataset-of-50k-movie-reviews) | 电影评论情感 |
|
||||
| SMS Spam Collection | [Kaggle](https://www.kaggle.com/datasets/uciml/sms-spam-collection-dataset) | 垃圾短信分类 |
|
||||
| Consumer Complaints | [Kaggle](https://www.kaggle.com/datasets/selener/consumer-complaint-database) | 投诉分流 |
|
||||
|
||||
#### ✅ 必做部分
|
||||
|
||||
| 模块 | 要求 |
|
||||
|------|------|
|
||||
| **数据处理** | 文本清洗要「克制」,说明预处理策略;使用 Pandera 定义 Schema |
|
||||
| **机器学习** | 基线 `TF-IDF + LogReg`;达到 `Accuracy ≥ 0.85` 或 `Macro-F1 ≥ 0.80` |
|
||||
| **Agent** | 实现「分类 → 解释 → 生成处置方案」流程;输出结构化(Pydantic) |
|
||||
|
||||
---
|
||||
|
||||
### Level 3|高阶:不平衡/多表/时序 + 多步决策
|
||||
|
||||
> 📌 **真实世界约束**
|
||||
|
||||
**目标**:处理更复杂的数据特性(极度不平衡、多表关联、时序预测),实现多步决策 Agent。
|
||||
|
||||
#### 推荐数据集
|
||||
|
||||
| 数据集 | 链接 | 特点 |
|
||||
|--------|------|------|
|
||||
| Credit Card Fraud Detection | [Kaggle](https://www.kaggle.com/datasets/mlg-ulb/creditcardfraud) | 极度不平衡 |
|
||||
| IEEE-CIS Fraud Detection | [Kaggle](https://www.kaggle.com/c/ieee-fraud-detection) | 多表/特征工程复杂 |
|
||||
| M5 Forecasting - Accuracy | [Kaggle](https://www.kaggle.com/competitions/m5-forecasting-accuracy) | 时序预测 |
|
||||
| Instacart Market Basket | [Kaggle](https://www.kaggle.com/c/instacart-market-basket-analysis) | 多表 + 推荐 |
|
||||
|
||||
#### ✅ 必做部分
|
||||
|
||||
| 模块 | 要求 |
|
||||
|------|------|
|
||||
| **数据处理** | 明确主键/外键与 join 规则;写出「数据泄露风险点清单」 |
|
||||
| **机器学习** | 使用合理指标(如 `PR-AUC`);必须使用时间切分评估(如时序) |
|
||||
| **Agent** | 至少 3 步决策(评估 → 解释 → 行动计划);输出结构化 |
|
||||
|
||||
---
|
||||
|
||||
### 自选题目标准
|
||||
|
||||
> 💡 **鼓励自选题目**,但必须满足以下硬标准
|
||||
|
||||
| 要求 | 说明 |
|
||||
|------|------|
|
||||
| **数据真实可获取** | 公开、可重复下载(Kaggle/UCI/OpenML 等),提供链接 |
|
||||
| **可量化预测任务** | 有明确标签/目标变量与评价指标 |
|
||||
| **业务闭环** | 能落到「下一步做什么」的决策/行动 |
|
||||
| **Agent 工具调用** | 至少 2 个 tools,其中 1 个必须是 ML 工具 |
|
||||
| **规模与复杂度** | 样本量建议 ≥ 5,000 |
|
||||
| **合规性** | 禁止爬取受限数据;禁止提交密钥/隐私数据 |
|
||||
|
||||
---
|
||||
|
||||
## 可选扩展思路
|
||||
|
||||
以下是一些可选的扩展方向,用于加深项目深度,**不作为评分硬性要求**:
|
||||
|
||||
| 方向 | 思路 |
|
||||
|------|------|
|
||||
| **可解释性** | 添加特征重要性解释工具(如 `explain_top_features`),让 Agent 能解释决策依据 |
|
||||
| **代价敏感策略** | 给每个动作定义成本/收益假设,让 Agent 输出最划算的动作组合 |
|
||||
| **阈值策略** | 把"预测概率"转化为"干预策略"(高/中/低风险不同处理) |
|
||||
| **相似案例检索** | 用 TF-IDF/Embedding 做 `retrieve_similar(text) -> top_k`,提供可追溯证据 |
|
||||
| **合规检查** | 对 Agent 输出做规则检查(如不得泄露隐私、不得虚假承诺) |
|
||||
| **误差分析** | Top 误判样本分析,找出模型薄弱点 |
|
||||
| **消融实验** | 对比不同特征/模型配置,得出改进方向 |
|
||||
|
||||
---
|
||||
|
||||
# Day 2:自主设计
|
||||
|
||||
**今日任务**:
|
||||
- 分组进行项目设计与开发
|
||||
- 完成数据探索与清洗
|
||||
- 开始训练基线模型
|
||||
|
||||
**建议里程碑**:
|
||||
- [ ] 数据下载并完成初步探索
|
||||
- [ ] 数据清洗流水线可运行
|
||||
- [ ] 基线模型训练完成
|
||||
|
||||
---
|
||||
|
||||
# Day 3:答疑 + Git 指导
|
||||
|
||||
## Git 安装(国内环境)
|
||||
|
||||
### Windows
|
||||
|
||||
1. 下载 Git for Windows:
|
||||
- 官方镜像(推荐):https://registry.npmmirror.com/binary.html?path=git-for-windows/
|
||||
- 或官网:https://git-scm.com/download/win
|
||||
2. 双击安装,全程默认设置即可
|
||||
3. 安装完成后,右键可看到「Git Bash Here」选项
|
||||
|
||||
### macOS
|
||||
|
||||
```bash
|
||||
# 方法 A:Xcode 命令行工具(推荐)
|
||||
xcode-select --install
|
||||
|
||||
# 方法 B:Homebrew
|
||||
brew install git
|
||||
```
|
||||
|
||||
### Linux (Ubuntu/Debian)
|
||||
|
||||
```bash
|
||||
sudo apt update
|
||||
sudo apt install git
|
||||
```
|
||||
|
||||
### 验证安装
|
||||
|
||||
```bash
|
||||
git --version
|
||||
# 输出类似:git version 2.43.0
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Git 基础操作
|
||||
|
||||
### 首次配置
|
||||
|
||||
```bash
|
||||
# 设置用户名和邮箱(提交记录会显示)
|
||||
git config --global user.name "你的姓名"
|
||||
git config --global user.email "你的邮箱@example.com"
|
||||
```
|
||||
|
||||
### 克隆仓库
|
||||
|
||||
```bash
|
||||
# 组长创建仓库后,所有组员克隆
|
||||
git clone http://hblu.top:3000/<用户名>/<项目名>.git
|
||||
cd <项目名>
|
||||
```
|
||||
|
||||
### 日常开发流程
|
||||
|
||||
```bash
|
||||
# 1. 拉取最新代码(每次开始工作前)
|
||||
git pull
|
||||
|
||||
# 2. 查看当前状态
|
||||
git status
|
||||
|
||||
# 3. 添加修改的文件
|
||||
git add . # 添加所有修改
|
||||
git add src/train.py # 或只添加特定文件
|
||||
|
||||
# 4. 提交修改
|
||||
git commit -m "feat: 添加数据预处理模块"
|
||||
|
||||
# 5. 推送到远程仓库
|
||||
git push
|
||||
```
|
||||
|
||||
### 常用命令速查
|
||||
|
||||
| 命令 | 说明 |
|
||||
|------|------|
|
||||
| `git clone <url>` | 克隆远程仓库 |
|
||||
| `git pull` | 拉取远程更新 |
|
||||
| `git status` | 查看当前状态 |
|
||||
| `git add .` | 暂存所有修改 |
|
||||
| `git commit -m "消息"` | 提交修改 |
|
||||
| `git push` | 推送到远程 |
|
||||
| `git log --oneline -5` | 查看最近 5 条提交 |
|
||||
|
||||
### 团队协作注意事项
|
||||
|
||||
1. **每次开始工作前先 `git pull`**,避免冲突
|
||||
2. **提交信息要有意义**,如 `feat: 添加 Agent 工具` 而非 `update`
|
||||
3. **小步提交**,不要把所有修改攒到最后一起提交
|
||||
|
||||
---
|
||||
|
||||
## .gitignore 详解
|
||||
|
||||
`.gitignore` 文件告诉 Git **哪些文件不要提交**。这非常重要,因为:
|
||||
- **API Key 泄露会导致账户被盗用**
|
||||
- **大文件会导致仓库臃肿**
|
||||
- **临时文件没有提交意义**
|
||||
|
||||
### 本项目必须忽略的文件
|
||||
|
||||
创建 `.gitignore` 文件,内容如下:
|
||||
|
||||
```gitignore
|
||||
# ===== 环境变量(绝对不能提交!)=====
|
||||
.env
|
||||
|
||||
# ===== Python 虚拟环境 =====
|
||||
.venv/
|
||||
venv/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
.pytest_cache/
|
||||
|
||||
# ===== IDE 配置 =====
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
|
||||
# ===== macOS 系统文件 =====
|
||||
.DS_Store
|
||||
|
||||
# ===== Jupyter =====
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# ===== 超大文件(超过 10MB 需手动添加)=====
|
||||
# 如果你的数据或模型文件超过 10MB,请在下面添加:
|
||||
# data/large_dataset.csv
|
||||
# models/large_model.pkl
|
||||
```
|
||||
|
||||
> 💡 **关于 data/ 和 models/ 文件**:
|
||||
> - **默认应该提交**,方便教师机直接运行
|
||||
> - 如果单个文件 **超过 10MB**,请添加到 `.gitignore` 并在 `data/README.md` 中说明下载方式
|
||||
|
||||
### 检查 .gitignore 是否生效
|
||||
|
||||
```bash
|
||||
# 查看哪些文件会被 Git 忽略
|
||||
git status --ignored
|
||||
|
||||
# 如果之前已经提交了不应提交的文件,需要先从 Git 中移除
|
||||
git rm --cached .env # 从 Git 移除但保留本地文件
|
||||
git rm --cached -r __pycache__
|
||||
git commit -m "chore: 移除不应提交的文件"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 作业提交流程
|
||||
|
||||
### 1. 账号信息
|
||||
|
||||
账号已统一创建,请登录 [hblu.top:3000/MachineLearning2025](http://hblu.top:3000/MachineLearning2025)
|
||||
|
||||
| 项目 | 说明 |
|
||||
|------|------|
|
||||
| **用户名** | `st` + 学号(如 `st2024001`) |
|
||||
| **初始密码** | `12345678`(请登录后修改) |
|
||||
| **组织** | MachineLearning2025 |
|
||||
|
||||
> ⚠️ **首次登录后请立即修改密码**
|
||||
|
||||
### 2. 组长创建仓库
|
||||
|
||||
在 [MachineLearning2025](http://hblu.top:3000/MachineLearning2025) 组织下创建新仓库,命名格式:`组号-项目名称`(如 `G01-ChurnPredictor`)
|
||||
|
||||
### 3. 添加组员
|
||||
|
||||
Settings → Collaborators → 添加其他组员(使用 `st+学号` 搜索)
|
||||
|
||||
### 4. 提交检查清单
|
||||
|
||||
- [ ] `.gitignore` 已创建且包含必要规则
|
||||
- [ ] `.env.example` 已提交,`.env` 未提交
|
||||
- [ ] 没有提交 API Key 或敏感信息
|
||||
- [ ] 没有提交大于 10MB 的文件
|
||||
|
||||
---
|
||||
|
||||
# Day 4:自主设计
|
||||
|
||||
**今日任务**:
|
||||
- 继续完善项目
|
||||
- 完成 Agent 集成
|
||||
- 准备 Streamlit Demo
|
||||
- 撰写项目报告
|
||||
|
||||
**建议里程碑**:
|
||||
- [ ] ML 模型完成并保存
|
||||
- [ ] Agent 工具调用测试通过
|
||||
- [ ] Streamlit Demo 可运行
|
||||
- [ ] README.md 初稿完成
|
||||
|
||||
---
|
||||
|
||||
# Day 5:小组展示
|
||||
|
||||
## 展示流程
|
||||
|
||||
1. **教师机克隆你的仓库**
|
||||
```bash
|
||||
git clone http://hblu.top:3000/MachineLearning2025/<项目名>.git
|
||||
cd <项目名>
|
||||
```
|
||||
|
||||
2. **安装依赖并运行**
|
||||
```bash
|
||||
uv sync
|
||||
cp .env.example .env
|
||||
# 教师填入测试用 API Key
|
||||
uv run streamlit run src/streamlit_app.py
|
||||
```
|
||||
|
||||
3. **5-8 分钟 Demo 展示**
|
||||
|
||||
---
|
||||
|
||||
## 跨机运行检查清单
|
||||
|
||||
> ⚠️ **避免「明明在我电脑上能跑」的问题**
|
||||
|
||||
### 必须检查
|
||||
|
||||
| 检查项 | 说明 | 常见错误 |
|
||||
|--------|------|----------|
|
||||
| **依赖完整** | 所有依赖都在 `pyproject.toml` 中 | 忘记 `uv add` 新安装的包 |
|
||||
| **相对路径** | 数据/模型使用相对路径 | `C:\Users\张三\data.csv` |
|
||||
| **环境变量** | API Key 通过 `.env` 读取 | 硬编码 Key 在代码中 |
|
||||
| **数据可获取** | 数据文件有下载说明或包含在仓库 | 数据只在本地,忘记上传 |
|
||||
| **uv.lock** | 锁文件已提交 | 依赖版本不确定 |
|
||||
|
||||
### 提交前测试方法
|
||||
|
||||
```bash
|
||||
# 模拟干净环境测试
|
||||
cd /tmp
|
||||
git clone <你的仓库地址>
|
||||
cd <项目名>
|
||||
uv sync
|
||||
cp .env.example .env
|
||||
# 填入 API Key
|
||||
uv run streamlit run src/streamlit_app.py
|
||||
```
|
||||
|
||||
### 常见问题排查
|
||||
|
||||
| 错误 | 原因 | 解决方案 |
|
||||
|------|------|----------|
|
||||
| `ModuleNotFoundError` | 缺少依赖 | `uv add <包名>` 后重新提交 |
|
||||
| `FileNotFoundError` | 路径问题 | 使用 `Path(__file__).parent` 获取相对路径 |
|
||||
| `DEEPSEEK_API_KEY not found` | 环境变量问题 | 检查 `.env` 格式和 `python-dotenv` |
|
||||
|
||||
---
|
||||
|
||||
## 评分标准(总分 100)
|
||||
|
||||
> ⚠️ 所有分析、对比、决策逻辑都必须在 `README.md` 中清晰体现。
|
||||
|
||||
### A. 问题与数据(10 分)
|
||||
|
||||
| 维度 | 分值 | 要求 |
|
||||
|------|------|------|
|
||||
| 任务定义清晰 | 5 | 标签/目标、输入输出边界 |
|
||||
| 数据说明与切分 | 5 | 来源链接、字段含义、切分策略 |
|
||||
|
||||
### B. 传统机器学习(30 分)
|
||||
|
||||
| 维度 | 分值 | 要求 |
|
||||
|------|------|------|
|
||||
| 基线与可复现训练 | 10 | 固定随机种子、训练脚本可跑通 |
|
||||
| 指标与对比 | 10 | 达到指标要求,与基线对比 |
|
||||
| 误差分析 | 10 | 展示错误样本/分桶,给出改进方向 |
|
||||
|
||||
### C. LLM + Agent(30 分)
|
||||
|
||||
| 维度 | 分值 | 要求 |
|
||||
|------|------|------|
|
||||
| 工具调用 | 10 | 至少 2 个 tools,能稳定调用 ML 工具 |
|
||||
| 结构化输出 | 10 | Pydantic schema 清晰;字段有约束 |
|
||||
| 建议可执行且有证据 | 10 | 能落地的动作清单,引用依据 |
|
||||
|
||||
### D. 工程与演示(30 分)
|
||||
|
||||
| 维度 | 分值 | 要求 |
|
||||
|------|------|------|
|
||||
| **Streamlit 演示** | **15** | 交互流畅;展示「预测→分析→建议」全流程 |
|
||||
| **跨机运行** | **10** | 在教师机 `git clone && uv sync && uv run` 可直接运行 |
|
||||
| 代码质量 | 5 | 结构清晰、有类型提示与文档 |
|
||||
|
||||
### ❌ 常见扣分项
|
||||
|
||||
- 训练/推理无法在教师机跑通
|
||||
- 未使用 `uv` 管理项目
|
||||
- 数据泄露(尤其是时序/多表)
|
||||
- Agent 编造数据集不存在的事实
|
||||
- **把密钥提交进仓库(严重扣分)**
|
||||
|
||||
### ✅ 常见加分项
|
||||
|
||||
- 使用 Polars Lazy API 高效处理数据
|
||||
- 做了可解释性/阈值策略/代价敏感分析
|
||||
- 做了检索增强且引用可追溯证据
|
||||
- 做了消融/对比实验,结论清晰
|
||||
|
||||
---
|
||||
|
||||
# 附录
|
||||
|
||||
## 代码示例
|
||||
|
||||
### 数据处理:Polars 最佳实践
|
||||
|
||||
```python
|
||||
import polars as pl
|
||||
|
||||
# ✅ 推荐:使用 Lazy API(自动查询优化)
|
||||
lf = pl.scan_csv("data/train.csv")
|
||||
result = (
|
||||
lf.filter(pl.col("age") > 30)
|
||||
.group_by("category")
|
||||
.agg(pl.col("value").mean())
|
||||
.collect() # 最后一步才执行
|
||||
)
|
||||
|
||||
# ✅ 推荐:从 Pandas 无缝迁移
|
||||
df_polars = pl.from_pandas(df_pandas)
|
||||
df_pandas = df_polars.to_pandas()
|
||||
```
|
||||
|
||||
### 数据验证:Pydantic + Pandera
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
class CustomerFeatures(BaseModel):
|
||||
"""客户特征数据模型"""
|
||||
age: int = Field(ge=0, le=120, description="客户年龄")
|
||||
tenure: int = Field(ge=0, description="客户任期(月)")
|
||||
monthly_charges: float = Field(ge=0, description="月费用")
|
||||
contract_type: str = Field(pattern="^(month-to-month|one-year|two-year)$")
|
||||
```
|
||||
|
||||
```python
|
||||
import pandera as pa
|
||||
from pandera import Column, Check, DataFrameSchema
|
||||
|
||||
# ✅ 定义清洗后 Schema
|
||||
clean_data_schema = DataFrameSchema(
|
||||
columns={
|
||||
"age": Column(pa.Int, checks=[Check.ge(0), Check.le(120)], nullable=False),
|
||||
"tenure": Column(pa.Int, checks=[Check.ge(0)], nullable=False),
|
||||
"monthly_charges": Column(pa.Float, checks=[Check.ge(0)], nullable=False),
|
||||
},
|
||||
strict=True,
|
||||
coerce=True,
|
||||
)
|
||||
```
|
||||
|
||||
### 机器学习:sklearn + LightGBM
|
||||
|
||||
```python
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import roc_auc_score
|
||||
import lightgbm as lgb
|
||||
import joblib
|
||||
|
||||
# 基线模型
|
||||
baseline = LogisticRegression(max_iter=1000, random_state=42)
|
||||
baseline.fit(X_train, y_train)
|
||||
print("Baseline ROC-AUC:", roc_auc_score(y_test, baseline.predict_proba(X_test)[:, 1]))
|
||||
|
||||
# 高性能模型
|
||||
lgb_model = lgb.LGBMClassifier(n_estimators=500, learning_rate=0.05, random_state=42)
|
||||
lgb_model.fit(X_train, y_train)
|
||||
|
||||
# 保存模型
|
||||
joblib.dump(lgb_model, "models/lgb_model.pkl")
|
||||
```
|
||||
|
||||
### Agent:pydantic-ai 示例
|
||||
|
||||
```python
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_ai import Agent, RunContext
|
||||
|
||||
class Decision(BaseModel):
|
||||
"""Agent 输出的结构化决策"""
|
||||
risk_score: float = Field(ge=0, le=1, description="预测风险概率")
|
||||
decision: str = Field(description="建议策略")
|
||||
actions: list[str] = Field(description="可执行动作清单")
|
||||
rationale: str = Field(description="决策依据")
|
||||
|
||||
agent = Agent(
|
||||
"deepseek:deepseek-chat",
|
||||
output_type=Decision,
|
||||
system_prompt="你是业务决策助手。必须先调用工具获取预测结果,再给出结构化决策。",
|
||||
)
|
||||
|
||||
@agent.tool
|
||||
def predict_risk(ctx: RunContext, features: CustomerFeatures) -> float:
|
||||
"""调用 ML 模型返回风险分数"""
|
||||
# TODO: 实现模型调用
|
||||
pass
|
||||
```
|
||||
|
||||
### API Key 配置
|
||||
|
||||
> ⚠️ **不要把 Key 写进代码、不要提交到仓库!**
|
||||
|
||||
创建 `.env.example`(提交到仓库):
|
||||
```
|
||||
DEEPSEEK_API_KEY=your-key-here
|
||||
```
|
||||
|
||||
复制为 `.env` 并填入真实 Key(`.env` 在 `.gitignore` 中排除)。
|
||||
|
||||
---
|
||||
|
||||
## 建议项目结构
|
||||
|
||||
```
|
||||
ml_course_design/
|
||||
├── pyproject.toml # 项目配置与依赖
|
||||
├── uv.lock # 锁定的依赖版本
|
||||
├── README.md # 项目说明与报告
|
||||
├── .env.example # 环境变量模板
|
||||
├── .gitignore # Git 忽略规则
|
||||
│
|
||||
├── data/ # 数据目录
|
||||
│ └── README.md # 数据来源说明
|
||||
│
|
||||
├── models/ # 训练产物
|
||||
│ └── .gitkeep
|
||||
│
|
||||
├── src/ # 核心代码
|
||||
│ ├── __init__.py
|
||||
│ ├── data.py # 数据读取/清洗
|
||||
│ ├── features.py # Pydantic 特征模型
|
||||
│ ├── train.py # 训练与评估
|
||||
│ ├── infer.py # 推理接口
|
||||
│ ├── agent_app.py # Agent 入口
|
||||
│ └── streamlit_app.py # Demo 入口
|
||||
│
|
||||
└── tests/ # 测试
|
||||
└── test_*.py
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## README.md 模板(你的项目)
|
||||
|
||||
请将以下内容作为你项目 `README.md` 的模板
|
||||
|
||||
````markdown
|
||||
# 项目名称
|
||||
|
||||
一句话描述:这个项目做什么?
|
||||
> **机器学习 (Python) 课程设计**
|
||||
|
||||
## 功能特性
|
||||
## 👥 团队成员
|
||||
|
||||
- ✅ 功能 1:描述
|
||||
- ✅ 功能 2:描述
|
||||
- ✅ 功能 3:LLM 功能描述
|
||||
| 姓名 | 学号 | 贡献 |
|
||||
|------|------|------|
|
||||
| 张三 | 2024001 | 数据处理、模型训练 |
|
||||
| 李四 | 2024002 | Agent 开发、Streamlit |
|
||||
| 王五 | 2024003 | 测试、文档撰写 |
|
||||
|
||||
## 快速开始
|
||||
## 📝 项目简介
|
||||
|
||||
### 环境要求
|
||||
(1-2 段描述项目目标、选用的数据集、解决的问题)
|
||||
|
||||
- Python 3.10+
|
||||
- DeepSeek API Key
|
||||
|
||||
### 安装
|
||||
## 🚀 快速开始
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
# 克隆仓库
|
||||
git clone http://hblu.top:3000/MachineLearning2025/GXX-ProjectName.git
|
||||
cd GXX-ProjectName
|
||||
|
||||
### 配置
|
||||
# 安装依赖
|
||||
uv sync
|
||||
|
||||
1. 复制 `.env.example` 为 `.env`
|
||||
2. 填入你的 DeepSeek API Key
|
||||
|
||||
```bash
|
||||
# 配置环境变量
|
||||
cp .env.example .env
|
||||
# 编辑 .env 文件,填入 API Key
|
||||
# 编辑 .env 填入 API Key
|
||||
|
||||
# 运行 Demo
|
||||
uv run streamlit run src/streamlit_app.py
|
||||
```
|
||||
|
||||
### 运行
|
||||
---
|
||||
|
||||
```bash
|
||||
# CLI 模式
|
||||
python src/main.py --help
|
||||
python src/main.py [命令] [参数]
|
||||
## 1️⃣ 问题定义与数据
|
||||
|
||||
# 或 Web 模式(如有)
|
||||
# streamlit run app.py
|
||||
### 1.1 任务描述
|
||||
|
||||
(描述预测任务类型:分类/回归/时序,以及业务目标)
|
||||
|
||||
### 1.2 数据来源
|
||||
|
||||
| 项目 | 说明 |
|
||||
|------|------|
|
||||
| 数据集名称 | XXX |
|
||||
| 数据链接 | [Kaggle](https://...) |
|
||||
| 样本量 | X,XXX 条 |
|
||||
| 特征数 | XX 个 |
|
||||
|
||||
### 1.3 数据切分与防泄漏
|
||||
|
||||
(如何切分训练/验证/测试集?如何确保没有数据泄漏?)
|
||||
|
||||
---
|
||||
|
||||
## 2️⃣ 机器学习流水线
|
||||
|
||||
### 2.1 基线模型
|
||||
|
||||
| 模型 | 指标 | 结果 |
|
||||
|------|------|------|
|
||||
| Logistic Regression | ROC-AUC | 0.XX |
|
||||
|
||||
### 2.2 进阶模型
|
||||
|
||||
| 模型 | 指标 | 结果 |
|
||||
|------|------|------|
|
||||
| LightGBM | ROC-AUC | 0.XX |
|
||||
|
||||
### 2.3 误差分析
|
||||
|
||||
(模型在哪些样本上表现不佳?为什么?)
|
||||
|
||||
---
|
||||
|
||||
## 3️⃣ Agent 实现
|
||||
|
||||
### 3.1 工具定义
|
||||
|
||||
| 工具名 | 功能 | 输入 | 输出 |
|
||||
|--------|------|------|------|
|
||||
| `predict_risk` | 调用 ML 模型预测 | CustomerFeatures | float |
|
||||
| `explain_features` | 解释特征影响 | CustomerFeatures | list[str] |
|
||||
|
||||
### 3.2 决策流程
|
||||
|
||||
(Agent 如何使用工具?如:预测 → 解释 → 建议)
|
||||
|
||||
### 3.3 案例展示
|
||||
|
||||
**输入**:
|
||||
```
|
||||
请分析这位客户的流失风险:年龄 35,任期 2 个月,月费 89.99
|
||||
```
|
||||
|
||||
## 使用示例
|
||||
|
||||
```bash
|
||||
# 示例命令 1
|
||||
python src/main.py example1
|
||||
|
||||
# 示例命令 2
|
||||
python src/main.py example2
|
||||
**输出**:
|
||||
```json
|
||||
{
|
||||
"risk_score": 0.72,
|
||||
"decision": "高风险,建议主动挥留",
|
||||
"actions": ["发送优惠短信", "客服回访"],
|
||||
"rationale": "新客户 + 月付合同是流失高危特征"
|
||||
}
|
||||
```
|
||||
|
||||
## 项目结构
|
||||
---
|
||||
|
||||
```
|
||||
project/
|
||||
├── src/
|
||||
│ ├── __init__.py
|
||||
│ ├── main.py # 主入口
|
||||
│ └── ... # 其他模块
|
||||
├── data/ # 数据文件
|
||||
├── output/ # 输出文件
|
||||
├── manifest.yaml # 项目运行声明
|
||||
├── requirements.txt # 依赖
|
||||
└── README.md # 本文件
|
||||
```
|
||||
## 4️⃣ 开发心得
|
||||
|
||||
## 作者
|
||||
### 4.1 主要困难与解决方案
|
||||
|
||||
[姓名] - [学号]
|
||||
(遇到的最大困难是什么?如何解决?)
|
||||
|
||||
### 4.2 对 AI 辅助编程的感受
|
||||
|
||||
(使用 AI 工具的体验如何?哪些场景有帮助?哪些地方需要注意?)
|
||||
|
||||
### 4.3 局限与未来改进
|
||||
|
||||
(如果有更多时间,还有哪些可以改进的地方?)
|
||||
````
|
||||
|
||||
---
|
||||
|
||||
## 参考资料
|
||||
|
||||
### 核心工具文档
|
||||
|
||||
| 资源 | 链接 | 说明 |
|
||||
|------|------|------|
|
||||
| uv 官方文档 | https://docs.astral.sh/uv/ | Python 项目管理器 |
|
||||
| Polars 用户指南 | https://pola.rs/ | 高性能 DataFrame |
|
||||
| Pydantic 文档 | https://docs.pydantic.dev/ | 数据验证与设置 |
|
||||
| Pandera 文档 | https://pandera.readthedocs.io/ | DataFrame Schema 验证 |
|
||||
| pydantic-ai 文档 | https://ai.pydantic.dev/ | Agent 框架 |
|
||||
| DeepSeek API | https://api.deepseek.com | OpenAI 兼容 |
|
||||
|
||||
### 推荐学习资源
|
||||
|
||||
| 资源 | 链接 |
|
||||
|------|------|
|
||||
| Polars vs Pandas | https://pola.rs/user-guide/migration/pandas/ |
|
||||
| Pydantic AI 快速入门 | https://ai.pydantic.dev/quick-start/ |
|
||||
| Pandera 快速入门 | https://pandera.readthedocs.io/en/stable/try_pandera.html |
|
||||
| uv 项目工作流 | https://docs.astral.sh/uv/concepts/projects/ |
|
||||
|
||||
---
|
||||
|
||||
## 📋 Checklist(提交前自检)
|
||||
|
||||
- [ ] 使用 `uv sync` 安装依赖,无需手动创建虚拟环境
|
||||
- [ ] `.gitignore` 包含 `.env`、`__pycache__`、大文件
|
||||
- [ ] 在干净环境下可以复现(`git clone && uv sync && uv run`)
|
||||
- [ ] 没有提交 API Key 或敏感信息
|
||||
- [ ] 使用 Polars 进行数据处理
|
||||
- [ ] 使用 Pydantic 定义特征和输出模型
|
||||
- [ ] Agent 至少有 2 个 tool(含 1 个 ML 工具)
|
||||
- [ ] README.md 说明了数据切分策略
|
||||
- [ ] Demo 可以正常运行
|
||||
|
||||
---
|
||||
|
||||
> 💬 **有问题?** 请在课程群/Issue 中提问,我们会尽快回复。
|
||||
|
||||
5
data/README.md
Normal file
5
data/README.md
Normal file
@ -0,0 +1,5 @@
|
||||
# Data Directory
|
||||
|
||||
Place your raw data files here.
|
||||
|
||||
For this example project, the data is generated synthetically in `src/data.py`, so no external files are needed.
|
||||
0
models/.gitkeep
Normal file
0
models/.gitkeep
Normal file
BIN
models/model.pkl
Normal file
BIN
models/model.pkl
Normal file
Binary file not shown.
56
pyproject.toml
Normal file
56
pyproject.toml
Normal file
@ -0,0 +1,56 @@
|
||||
[project]
|
||||
name = "ml-course-design"
|
||||
version = "0.1.0"
|
||||
description = "机器学习 × LLM × Agent 课程设计模板"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"joblib>=1.5.3",
|
||||
"pandera>=0.28.1",
|
||||
"polars>=1.37.0",
|
||||
"pydantic-ai>=1.41.0",
|
||||
"python-dotenv>=1.2.1",
|
||||
"scikit-learn>=1.8.0",
|
||||
"streamlit>=1.52.2",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
url = "https://mirrors.aliyun.com/pypi/simple/"
|
||||
default = true
|
||||
|
||||
dependencies = [
|
||||
"pydantic>=2.10",
|
||||
"pandera>=0.21",
|
||||
"pydantic-ai>=0.7",
|
||||
"polars>=1.0",
|
||||
"pandas>=2.2",
|
||||
"scikit-learn>=1.5",
|
||||
"lightgbm>=4.5",
|
||||
"seaborn>=0.13",
|
||||
"joblib>=1.4",
|
||||
"python-dotenv>=1.0",
|
||||
"streamlit>=1.40",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-asyncio>=1.3",
|
||||
"ruff>=0.8",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
BIN
src/__pycache__/agent_app.cpython-312.pyc
Normal file
BIN
src/__pycache__/agent_app.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/__pycache__/data.cpython-312.pyc
Normal file
BIN
src/__pycache__/data.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/__pycache__/features.cpython-312.pyc
Normal file
BIN
src/__pycache__/features.cpython-312.pyc
Normal file
Binary file not shown.
BIN
src/__pycache__/infer.cpython-312.pyc
Normal file
BIN
src/__pycache__/infer.cpython-312.pyc
Normal file
Binary file not shown.
191
src/agent_app.py
Normal file
191
src/agent_app.py
Normal file
@ -0,0 +1,191 @@
|
||||
"""pydantic-ai Agent 应用模块
|
||||
|
||||
使用 2026 pydantic-ai 最佳实践:
|
||||
- deps_type 依赖注入
|
||||
- @agent.instructions 动态指令
|
||||
- 结构化输出 (Pydantic models)
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from typing import Protocol
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic_ai import Agent, RunContext
|
||||
|
||||
from src.features import StudentFeatures, StudyGuidance
|
||||
from src.infer import explain_prediction, predict_pass_prob
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
# --- 1. 定义依赖协议和数据类 ---
|
||||
|
||||
|
||||
class MLModelProtocol(Protocol):
|
||||
"""ML 模型接口协议"""
|
||||
|
||||
def predict(self, features: StudentFeatures) -> float:
|
||||
"""预测通过概率"""
|
||||
...
|
||||
|
||||
def explain(self) -> str:
|
||||
"""获取模型解释"""
|
||||
...
|
||||
|
||||
|
||||
@dataclass
|
||||
class AgentDeps:
|
||||
"""Agent 依赖项
|
||||
|
||||
封装 ML 模型和学生特征,通过依赖注入传递给 Agent。
|
||||
"""
|
||||
|
||||
student: StudentFeatures
|
||||
model_path: str = "models/model.pkl"
|
||||
|
||||
|
||||
# --- 2. 定义 Agent ---
|
||||
|
||||
|
||||
study_advisor = Agent(
|
||||
"deepseek:deepseek-chat",
|
||||
deps_type=AgentDeps,
|
||||
output_type=StudyGuidance,
|
||||
instructions=(
|
||||
"你是一个严谨的学业数据分析师。你的任务是根据学生的具体情况预测其考试通过率,并给出建议。\n"
|
||||
"【重要规则】\n"
|
||||
"1. 必须先调用 `predict_pass_probability` 获取概率。\n"
|
||||
"2. 必须调用 `get_model_explanation` 获取模型认为最重要的特征,并在 `key_factors` 中引用这些特征。\n"
|
||||
"3. 你的建议必须针对那些最重要的特征(例如,如果模型说睡眠很重要,就给睡眠建议)。\n"
|
||||
"4. 严禁凭空编造数值。所有数据必须来自工具返回。\n"
|
||||
"5. `rationale` 必须引用 `key_factors` 中的具体因素。"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@study_advisor.instructions
|
||||
async def add_student_context(ctx: RunContext[AgentDeps]) -> str:
|
||||
"""动态添加学生信息到系统提示"""
|
||||
s = ctx.deps.student
|
||||
return (
|
||||
f"当前学生信息:\n"
|
||||
f"- 每周学习时长: {s.study_hours} 小时\n"
|
||||
f"- 每晚睡眠时长: {s.sleep_hours} 小时\n"
|
||||
f"- 出勤率: {s.attendance_rate:.0%}\n"
|
||||
f"- 压力等级: {s.stress_level}/5\n"
|
||||
f"- 学习方式: {s.study_type}"
|
||||
)
|
||||
|
||||
|
||||
# --- 3. 注册工具 ---
|
||||
|
||||
|
||||
@study_advisor.tool
|
||||
async def predict_pass_probability(ctx: RunContext[AgentDeps]) -> float:
|
||||
"""调用 ML 模型预测学生通过概率
|
||||
|
||||
Returns:
|
||||
float: 预测通过率 (0-1)
|
||||
"""
|
||||
s = ctx.deps.student
|
||||
return predict_pass_prob(
|
||||
study_hours=s.study_hours,
|
||||
sleep_hours=s.sleep_hours,
|
||||
attendance_rate=s.attendance_rate,
|
||||
stress_level=s.stress_level,
|
||||
study_type=s.study_type,
|
||||
)
|
||||
|
||||
|
||||
@study_advisor.tool
|
||||
async def get_model_explanation(ctx: RunContext[AgentDeps]) -> str:
|
||||
"""获取 ML 模型的特征重要性解释
|
||||
|
||||
Returns:
|
||||
str: 特征重要性排名说明
|
||||
"""
|
||||
return explain_prediction()
|
||||
|
||||
|
||||
# --- 4. 咨询师 Agent (多轮对话) ---
|
||||
|
||||
|
||||
counselor_agent = Agent(
|
||||
"deepseek:deepseek-chat",
|
||||
deps_type=AgentDeps,
|
||||
instructions=(
|
||||
"你是一位富有同理心且专业的大学心理咨询师。\n"
|
||||
"你的目标是倾听学生的学业压力和生活烦恼,提供情感支持。\n"
|
||||
"【交互风格】\n"
|
||||
"1. 同理心:首先通过复述或确认学生的感受来表达理解。\n"
|
||||
"2. 引导性:不要急于给出解决方案,先通过提问了解更多背景。\n"
|
||||
"3. 数据驱动(可选):如果学生询问具体通过率,请调用工具。\n"
|
||||
"4. 语气:温暖、支持、专业,像朋友一样交谈。"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@counselor_agent.tool
|
||||
async def predict_student_pass(ctx: RunContext[AgentDeps]) -> float:
|
||||
"""获取学生通过率预测(用于咨询过程提供客观数据)"""
|
||||
s = ctx.deps.student
|
||||
return predict_pass_prob(
|
||||
study_hours=s.study_hours,
|
||||
sleep_hours=s.sleep_hours,
|
||||
attendance_rate=s.attendance_rate,
|
||||
stress_level=s.stress_level,
|
||||
study_type=s.study_type,
|
||||
)
|
||||
|
||||
|
||||
@counselor_agent.tool
|
||||
async def explain_factors(ctx: RunContext[AgentDeps]) -> str:
|
||||
"""获取模型特征重要性解释"""
|
||||
return explain_prediction()
|
||||
|
||||
|
||||
# --- 5. 运行示例 ---
|
||||
|
||||
|
||||
async def main():
|
||||
"""运行 Agent 示例"""
|
||||
if not os.getenv("DEEPSEEK_API_KEY"):
|
||||
print("❌ 错误: 未设置 DEEPSEEK_API_KEY")
|
||||
print("请在 .env 文件中设置密钥,或 export DEEPSEEK_API_KEY='...'")
|
||||
return
|
||||
|
||||
# 构建学生特征
|
||||
student = StudentFeatures(
|
||||
study_hours=12,
|
||||
sleep_hours=4,
|
||||
attendance_rate=0.9,
|
||||
stress_level=4,
|
||||
study_type="Self",
|
||||
)
|
||||
|
||||
# 创建依赖
|
||||
deps = AgentDeps(student=student)
|
||||
|
||||
# 用户查询
|
||||
query = (
|
||||
"我最近压力很大 (等级4),每天只睡 4 小时,不过我每周自学(Self) 12 小时,"
|
||||
"出勤率大概 90%。请帮我分析一下我会挂科吗?基于模型告诉我怎么做最有效。"
|
||||
)
|
||||
|
||||
print(f"用户: {query}\n")
|
||||
print("Agent 正在思考并调用模型工具...\n")
|
||||
|
||||
try:
|
||||
result = await study_advisor.run(query, deps=deps)
|
||||
|
||||
print("--- 结构化分析报告 ---")
|
||||
print(result.output.model_dump_json(indent=2))
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 运行失败: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
223
src/data.py
Normal file
223
src/data.py
Normal file
@ -0,0 +1,223 @@
|
||||
"""数据生成、验证与处理模块
|
||||
|
||||
使用 Polars 进行高性能数据处理,Pandera 进行 DataFrame 校验。
|
||||
符合 2026 最佳实践。
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandera.polars as pa
|
||||
import polars as pl
|
||||
|
||||
|
||||
# --- Pandera Schema 定义 ---
|
||||
|
||||
|
||||
class RawStudentDataSchema(pa.DataFrameModel):
|
||||
"""原始数据 Schema(清洗前校验,宽松模式)
|
||||
|
||||
允许缺失值存在,用于验证数据读取后的基本结构。
|
||||
"""
|
||||
study_hours: float = pa.Field(nullable=True)
|
||||
sleep_hours: float = pa.Field(nullable=True)
|
||||
attendance_rate: float = pa.Field(nullable=True)
|
||||
stress_level: int = pa.Field(nullable=True)
|
||||
study_type: str = pa.Field(nullable=True)
|
||||
is_pass: int = pa.Field(nullable=True)
|
||||
|
||||
class Config:
|
||||
strict = False # 允许额外列
|
||||
coerce = True
|
||||
|
||||
|
||||
class CleanStudentDataSchema(pa.DataFrameModel):
|
||||
"""清洗后数据 Schema(严格模式)
|
||||
|
||||
不允许缺失值,强制约束检查。
|
||||
"""
|
||||
study_hours: float = pa.Field(ge=0, le=20, nullable=False)
|
||||
sleep_hours: float = pa.Field(ge=0, le=12, nullable=False)
|
||||
attendance_rate: float = pa.Field(ge=0, le=1, nullable=False)
|
||||
stress_level: int = pa.Field(ge=1, le=5, nullable=False)
|
||||
study_type: str = pa.Field(isin=["Group", "Self", "Online"], nullable=False)
|
||||
is_pass: int = pa.Field(isin=[0, 1], nullable=False)
|
||||
|
||||
class Config:
|
||||
strict = True # 不允许额外列
|
||||
coerce = True
|
||||
|
||||
|
||||
# --- 数据生成函数 ---
|
||||
|
||||
|
||||
def generate_data(n_samples: int = 2000, random_seed: int = 42) -> pl.DataFrame:
|
||||
"""生成学生行为模拟数据
|
||||
|
||||
包含:数值特征、类别特征、噪声、以及非线性关系。
|
||||
|
||||
特征:
|
||||
- study_hours (float): 每周学习时长 (0-20)
|
||||
- sleep_hours (float): 每晚睡眠时长 (3-10)
|
||||
- attendance_rate (float): 出勤率 (0.0-1.0)
|
||||
- study_type (str): 学习方式 ("Group", "Self", "Online")
|
||||
- stress_level (int): 压力等级 (1-5)
|
||||
|
||||
目标:
|
||||
- is_pass (int): 0 或 1
|
||||
|
||||
Args:
|
||||
n_samples: 生成样本数量
|
||||
random_seed: 随机种子,确保可复现
|
||||
|
||||
Returns:
|
||||
pl.DataFrame: Polars DataFrame 包含所有特征和标签
|
||||
"""
|
||||
np.random.seed(random_seed)
|
||||
|
||||
# 1. 生成基础特征
|
||||
study_hours = np.random.uniform(0, 15, n_samples)
|
||||
sleep_hours = np.random.normal(7, 1.5, n_samples).clip(3, 10)
|
||||
attendance_rate = np.random.beta(5, 2, n_samples) # 偏向于高出勤
|
||||
study_type = np.random.choice(
|
||||
["Group", "Self", "Online"],
|
||||
n_samples,
|
||||
p=[0.3, 0.5, 0.2]
|
||||
)
|
||||
stress_level = np.random.randint(1, 6, n_samples)
|
||||
|
||||
# 2. 模拟真实世界逻辑 (分数计算)
|
||||
score = np.full(n_samples, 40.0)
|
||||
|
||||
# 线性影响
|
||||
score += study_hours * 3.0
|
||||
score += (attendance_rate - 0.5) * 30
|
||||
|
||||
# 非线性/交互影响:睡眠不足严重扣分
|
||||
score -= np.maximum(0, 6 - sleep_hours) * 8
|
||||
|
||||
# 类别特征影响
|
||||
mask_group = study_type == "Group"
|
||||
mask_self = study_type == "Self"
|
||||
score[mask_group] += 5
|
||||
score[mask_self] += study_hours[mask_self] * 0.5 # 额外加成
|
||||
|
||||
# 压力影响
|
||||
score -= (stress_level - 1) * 2
|
||||
|
||||
# 3. 添加随机噪声
|
||||
noise = np.random.normal(0, 8, n_samples)
|
||||
final_score = score + noise
|
||||
|
||||
# 4. 生成标签 (及格线 60)
|
||||
is_pass = (final_score >= 60).astype(np.int32)
|
||||
|
||||
# 5. 使用 Polars 构建 DataFrame
|
||||
df = pl.DataFrame({
|
||||
"study_hours": study_hours,
|
||||
"sleep_hours": sleep_hours,
|
||||
"attendance_rate": attendance_rate,
|
||||
"study_type": study_type,
|
||||
"stress_level": stress_level,
|
||||
"is_pass": is_pass,
|
||||
})
|
||||
|
||||
# 6. 人为制造缺失值 (模拟真实数据清洗需求)
|
||||
# 随机丢弃 5% 的 attendance_rate
|
||||
mask_na = np.random.random(n_samples) < 0.05
|
||||
df = df.with_columns(
|
||||
pl.when(pl.Series(mask_na))
|
||||
.then(None)
|
||||
.otherwise(pl.col("attendance_rate"))
|
||||
.alias("attendance_rate")
|
||||
)
|
||||
|
||||
return df
|
||||
|
||||
|
||||
def validate_raw_data(df: pl.DataFrame) -> pl.DataFrame:
|
||||
"""验证原始数据结构(清洗前)
|
||||
|
||||
使用宽松模式校验,允许缺失值。
|
||||
|
||||
Args:
|
||||
df: 原始 Polars DataFrame
|
||||
|
||||
Returns:
|
||||
pl.DataFrame: 验证通过的 DataFrame
|
||||
|
||||
Raises:
|
||||
pa.errors.SchemaError: 验证失败
|
||||
"""
|
||||
return RawStudentDataSchema.validate(df)
|
||||
|
||||
|
||||
def validate_clean_data(df: pl.DataFrame) -> pl.DataFrame:
|
||||
"""验证清洗后数据(严格模式)
|
||||
|
||||
不允许缺失值,强制约束检查。
|
||||
|
||||
Args:
|
||||
df: 清洗后的 Polars DataFrame
|
||||
|
||||
Returns:
|
||||
pl.DataFrame: 验证通过的 DataFrame
|
||||
|
||||
Raises:
|
||||
pa.errors.SchemaError: 验证失败
|
||||
"""
|
||||
return CleanStudentDataSchema.validate(df)
|
||||
|
||||
|
||||
def preprocess_data(df: pl.DataFrame, validate: bool = True) -> pl.DataFrame:
|
||||
"""数据预处理流水线
|
||||
|
||||
1. 删除缺失值
|
||||
2. 删除重复行
|
||||
3. 可选:进行 Schema 校验
|
||||
|
||||
Args:
|
||||
df: 原始 Polars DataFrame
|
||||
validate: 是否进行清洗后 Schema 校验
|
||||
|
||||
Returns:
|
||||
pl.DataFrame: 清洗后的 DataFrame
|
||||
"""
|
||||
# 删除缺失值
|
||||
df_clean = df.drop_nulls()
|
||||
|
||||
# 删除重复行
|
||||
df_clean = df_clean.unique()
|
||||
|
||||
# 可选校验
|
||||
if validate:
|
||||
df_clean = validate_clean_data(df_clean)
|
||||
|
||||
return df_clean
|
||||
|
||||
|
||||
def get_feature_columns() -> tuple[list[str], list[str]]:
|
||||
"""获取特征列名
|
||||
|
||||
Returns:
|
||||
tuple: (数值特征列表, 类别特征列表)
|
||||
"""
|
||||
numeric_features = ["study_hours", "sleep_hours", "attendance_rate", "stress_level"]
|
||||
categorical_features = ["study_type"]
|
||||
return numeric_features, categorical_features
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(">>> 1. 生成数据")
|
||||
df = generate_data()
|
||||
print(df.head())
|
||||
print(f"\n缺失值统计:\n{df.null_count()}")
|
||||
|
||||
print("\n>>> 2. 验证原始数据 (宽松模式)")
|
||||
df_validated = validate_raw_data(df)
|
||||
print("✅ 原始数据验证通过")
|
||||
|
||||
print("\n>>> 3. 清洗数据")
|
||||
df_clean = preprocess_data(df, validate=True)
|
||||
print(f"清洗后样本数: {len(df_clean)} (原始: {len(df)})")
|
||||
print("✅ 清洗后数据验证通过")
|
||||
|
||||
print(f"\n及格率: {df_clean['is_pass'].mean():.2f}")
|
||||
47
src/features.py
Normal file
47
src/features.py
Normal file
@ -0,0 +1,47 @@
|
||||
"""Pydantic 模型定义模块
|
||||
|
||||
定义学生特征输入和 Agent 结构化输出。
|
||||
符合 2026 pydantic-ai 最佳实践。
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class StudentFeatures(BaseModel):
|
||||
"""学生行为特征输入
|
||||
|
||||
用于预测学生考试通过率的核心特征。
|
||||
使用 Pydantic 进行类型验证和约束检查。
|
||||
"""
|
||||
|
||||
study_hours: float = Field(ge=0, le=20, description="每周学习小时数 (0-20)")
|
||||
sleep_hours: float = Field(ge=0, le=12, description="每天睡眠小时数 (0-12)")
|
||||
attendance_rate: float = Field(ge=0, le=1, description="出勤率 (0.0-1.0)")
|
||||
stress_level: int = Field(ge=1, le=5, description="压力等级 1(低) - 5(高)")
|
||||
study_type: str = Field(
|
||||
pattern="^(Group|Self|Online)$", description="学习类型 (Group/Self/Online)"
|
||||
)
|
||||
|
||||
|
||||
class ActionItem(BaseModel):
|
||||
"""可执行行动项"""
|
||||
|
||||
action: str = Field(description="具体的行动建议")
|
||||
priority: str = Field(pattern="^(高|中|低)$", description="优先级 (高/中/低)")
|
||||
|
||||
|
||||
class StudyGuidance(BaseModel):
|
||||
"""Agent 输出的结构化学业指导
|
||||
|
||||
包含预测概率、风险评估和可执行建议。
|
||||
"""
|
||||
|
||||
pass_probability: float = Field(ge=0, le=1, description="预测通过率 (0-1)")
|
||||
risk_assessment: str = Field(
|
||||
pattern="^(低风险|中等风险|高风险)$", description="风险等级评估 (低风险/中等风险/高风险)"
|
||||
)
|
||||
key_factors: list[str] = Field(description="影响预测结果的关键因素(来自模型解释)")
|
||||
action_plan: list[ActionItem] = Field(
|
||||
min_length=3, max_length=8, description="3-8条可执行建议清单"
|
||||
)
|
||||
rationale: str = Field(description="建议依据说明(必须引用模型给出的关键因素)")
|
||||
123
src/infer.py
Normal file
123
src/infer.py
Normal file
@ -0,0 +1,123 @@
|
||||
"""推理模块
|
||||
|
||||
提供 ML 模型加载和预测功能,供 Agent 工具调用。
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import joblib
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from sklearn.pipeline import Pipeline
|
||||
|
||||
MODEL_PATH = Path("models") / "model.pkl"
|
||||
_MODEL: Pipeline | None = None
|
||||
|
||||
|
||||
def load_model() -> Pipeline:
|
||||
"""加载训练好的 ML 模型
|
||||
|
||||
Returns:
|
||||
Pipeline: scikit-learn Pipeline 对象
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: 模型文件不存在
|
||||
"""
|
||||
global _MODEL
|
||||
if _MODEL is None:
|
||||
if not MODEL_PATH.exists():
|
||||
raise FileNotFoundError(
|
||||
f"未找到模型文件 {MODEL_PATH}。请先运行 uv run python src/train.py"
|
||||
)
|
||||
_MODEL = joblib.load(MODEL_PATH)
|
||||
return _MODEL
|
||||
|
||||
|
||||
def predict_pass_prob(
|
||||
study_hours: float,
|
||||
sleep_hours: float,
|
||||
attendance_rate: float,
|
||||
stress_level: int,
|
||||
study_type: str,
|
||||
) -> float:
|
||||
"""预测学生通过概率
|
||||
|
||||
Args:
|
||||
study_hours: 每周学习小时数 (0-20)
|
||||
sleep_hours: 每天睡眠小时数 (0-12)
|
||||
attendance_rate: 出勤率 (0.0-1.0)
|
||||
stress_level: 压力等级 1-5
|
||||
study_type: 学习类型 (Group/Self/Online)
|
||||
|
||||
Returns:
|
||||
float: 通过概率 (0.0 - 1.0)
|
||||
"""
|
||||
model = load_model()
|
||||
|
||||
# 构建 DataFrame (与训练时的输入格式一致)
|
||||
features = pd.DataFrame(
|
||||
[
|
||||
{
|
||||
"study_hours": study_hours,
|
||||
"sleep_hours": sleep_hours,
|
||||
"attendance_rate": attendance_rate,
|
||||
"stress_level": stress_level,
|
||||
"study_type": study_type,
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
try:
|
||||
# predict_proba 返回 [proba_false, proba_true]
|
||||
proba = model.predict_proba(features)[0, 1]
|
||||
return float(proba)
|
||||
except Exception as e:
|
||||
print(f"Prediction Error: {e}")
|
||||
return 0.0
|
||||
|
||||
|
||||
def explain_prediction() -> str:
|
||||
"""解释模型的全局特征重要性
|
||||
|
||||
从保存的 Pipeline 中提取特征重要性。
|
||||
|
||||
Returns:
|
||||
str: 特征重要性排名说明
|
||||
"""
|
||||
model = load_model()
|
||||
|
||||
try:
|
||||
# Pipeline 结构: [('preprocessor', ColumnTransformer), ('classifier', RandomForest)]
|
||||
preprocessor = model.named_steps["preprocessor"]
|
||||
clf = model.named_steps["classifier"]
|
||||
|
||||
# 获取特征名
|
||||
num_feats = ["study_hours", "sleep_hours", "attendance_rate", "stress_level"]
|
||||
|
||||
# 获取 OneHot 后的类别特征名
|
||||
cat_encoder = preprocessor.named_transformers_["cat"].named_steps["onehot"]
|
||||
cat_feats = cat_encoder.get_feature_names_out(["study_type"])
|
||||
|
||||
all_feats = np.concatenate([num_feats, cat_feats])
|
||||
|
||||
# 获取重要性数值
|
||||
importances = clf.feature_importances_
|
||||
|
||||
# 排序并输出
|
||||
indices = np.argsort(importances)[::-1]
|
||||
|
||||
lines = ["### 模型特征重要性排名 (Top 5):"]
|
||||
for i in range(min(5, len(importances))):
|
||||
idx = indices[i]
|
||||
lines.append(f"{i + 1}. {all_feats[idx]}: {importances[idx]:.4f}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
except Exception as e:
|
||||
return f"无法解释模型特征 (可能模型结构不同): {e!s}"
|
||||
|
||||
|
||||
def reset_model_cache() -> None:
|
||||
"""重置模型缓存(用于测试)"""
|
||||
global _MODEL
|
||||
_MODEL = None
|
||||
250
src/streamlit_app.py
Normal file
250
src/streamlit_app.py
Normal file
@ -0,0 +1,250 @@
|
||||
"""Streamlit 演示应用
|
||||
|
||||
学生成绩预测 AI 助手 - 支持成绩预测分析和心理咨询对话。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
|
||||
import streamlit as st
|
||||
|
||||
# Ensure project root is in path
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from pydantic_ai import FunctionToolCallEvent, FunctionToolResultEvent, PartDeltaEvent
|
||||
from pydantic_ai.messages import (
|
||||
ModelRequest,
|
||||
ModelResponse,
|
||||
TextPart,
|
||||
TextPartDelta,
|
||||
UserPromptPart,
|
||||
)
|
||||
|
||||
from src.agent_app import AgentDeps, counselor_agent, study_advisor
|
||||
from src.features import StudentFeatures
|
||||
|
||||
# Load env variables
|
||||
load_dotenv()
|
||||
|
||||
st.set_page_config(page_title="学生成绩预测 AI 助手", page_icon="🎓", layout="wide")
|
||||
|
||||
# Sidebar Configuration
|
||||
st.sidebar.header("🔧 配置")
|
||||
api_key = st.sidebar.text_input(
|
||||
"DeepSeek API Key", type="password", value=os.getenv("DEEPSEEK_API_KEY", "")
|
||||
)
|
||||
|
||||
if api_key:
|
||||
os.environ["DEEPSEEK_API_KEY"] = api_key
|
||||
|
||||
st.sidebar.markdown("---")
|
||||
# Mode Selection
|
||||
mode = st.sidebar.radio("功能选择", ["📊 成绩预测", "💬 心理咨询"])
|
||||
|
||||
# --- Helper Functions ---
|
||||
|
||||
|
||||
async def run_analysis(
|
||||
study_hours: float,
|
||||
sleep_hours: float,
|
||||
attendance_rate: float,
|
||||
stress_level: int,
|
||||
study_type: str,
|
||||
):
|
||||
"""运行成绩预测分析"""
|
||||
try:
|
||||
if not os.getenv("DEEPSEEK_API_KEY"):
|
||||
st.error("请在侧边栏提供 DeepSeek API Key。")
|
||||
return None
|
||||
|
||||
# 创建学生特征
|
||||
student = StudentFeatures(
|
||||
study_hours=study_hours,
|
||||
sleep_hours=sleep_hours,
|
||||
attendance_rate=attendance_rate,
|
||||
stress_level=stress_level,
|
||||
study_type=study_type,
|
||||
)
|
||||
|
||||
# 创建依赖
|
||||
deps = AgentDeps(student=student)
|
||||
|
||||
# 构建查询
|
||||
query = (
|
||||
f"请分析这位学生的通过率并给出建议。"
|
||||
f"学生信息已通过工具获取,请调用 predict_pass_probability 和 get_model_explanation 工具。"
|
||||
)
|
||||
|
||||
with st.spinner("🤖 Agent 正在思考... (调用 DeepSeek + 随机森林模型)"):
|
||||
result = await study_advisor.run(query, deps=deps)
|
||||
return result.output
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"分析失败: {e!s}")
|
||||
return None
|
||||
|
||||
|
||||
async def run_counselor_stream(
|
||||
query: str,
|
||||
history: list,
|
||||
placeholder,
|
||||
student: StudentFeatures,
|
||||
):
|
||||
"""
|
||||
运行咨询师对话流,手动处理流式响应和工具调用事件。
|
||||
"""
|
||||
try:
|
||||
if not os.getenv("DEEPSEEK_API_KEY"):
|
||||
placeholder.error("❌ 错误: 请在侧边栏提供 DeepSeek API Key。")
|
||||
return None
|
||||
|
||||
# 创建依赖
|
||||
deps = AgentDeps(student=student)
|
||||
|
||||
full_response = ""
|
||||
# Status container for tool calls
|
||||
status_placeholder = st.empty()
|
||||
|
||||
# Call Counselor Agent with streaming
|
||||
async for event in counselor_agent.run_stream_events(query, deps=deps, message_history=history):
|
||||
# Handle Text Delta (Wrapped in PartDeltaEvent)
|
||||
if isinstance(event, PartDeltaEvent) and isinstance(event.delta, TextPartDelta):
|
||||
full_response += event.delta.content_delta
|
||||
placeholder.markdown(full_response + "▌")
|
||||
|
||||
# Handle Tool Call Start
|
||||
elif isinstance(event, FunctionToolCallEvent):
|
||||
status_placeholder.info(f"🛠️ 咨询师正在使用工具: `{event.part.tool_name}` ...")
|
||||
|
||||
# Handle Tool Result
|
||||
elif isinstance(event, FunctionToolResultEvent):
|
||||
status_placeholder.empty()
|
||||
|
||||
placeholder.markdown(full_response)
|
||||
status_placeholder.empty() # Ensure clear
|
||||
return full_response
|
||||
|
||||
except Exception as e:
|
||||
placeholder.error(f"❌ 咨询失败: {e!s}")
|
||||
return None
|
||||
|
||||
|
||||
# --- Main Views ---
|
||||
|
||||
if mode == "📊 成绩预测":
|
||||
st.title("🎓 学生成绩预测助手")
|
||||
st.markdown("在下方输入学生详细信息,获取 AI 驱动的成绩分析。")
|
||||
|
||||
with st.form("student_data_form"):
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
study_hours = st.slider("每周学习时长 (小时)", 0.0, 20.0, 10.0, 0.5)
|
||||
sleep_hours = st.slider("日均睡眠时长 (小时)", 0.0, 12.0, 7.0, 0.5)
|
||||
|
||||
with col2:
|
||||
attendance_rate = st.slider("出勤率", 0.0, 1.0, 0.9, 0.05)
|
||||
stress_level = st.select_slider(
|
||||
"压力等级 (1=低, 5=高)", options=[1, 2, 3, 4, 5], value=3
|
||||
)
|
||||
study_type = st.radio("主要学习方式", ["Self", "Group", "Online"], horizontal=True)
|
||||
|
||||
submitted = st.form_submit_button("🚀 分析通过率")
|
||||
|
||||
if submitted:
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
guidance = loop.run_until_complete(
|
||||
run_analysis(study_hours, sleep_hours, attendance_rate, stress_level, study_type)
|
||||
)
|
||||
|
||||
if guidance:
|
||||
st.divider()
|
||||
st.subheader("📊 分析结果")
|
||||
m1, m2, m3 = st.columns(3)
|
||||
m1.metric("预测通过率", f"{guidance.pass_probability:.1%}")
|
||||
m2.metric(
|
||||
"风险评估",
|
||||
"高风险" if guidance.pass_probability < 0.6 else "低风险",
|
||||
delta="-高风险" if guidance.pass_probability < 0.6 else "+安全",
|
||||
)
|
||||
|
||||
st.info(f"**风险评估:** {guidance.risk_assessment}")
|
||||
|
||||
# 显示关键因素
|
||||
st.subheader("🔍 关键因素")
|
||||
for factor in guidance.key_factors:
|
||||
st.write(f"- {factor}")
|
||||
|
||||
st.subheader("✅ 行动计划")
|
||||
actions = [
|
||||
{"优先级": item.priority, "建议行动": item.action} for item in guidance.action_plan
|
||||
]
|
||||
st.table(actions)
|
||||
|
||||
st.subheader("💡 分析依据")
|
||||
st.write(guidance.rationale)
|
||||
|
||||
elif mode == "💬 心理咨询":
|
||||
st.title("🧩 AI 心理咨询室")
|
||||
st.markdown("这里是安全且私密的空间。有些压力如果你愿意说,我愿意听。")
|
||||
|
||||
# Sidebar for student info (optional for counselor context)
|
||||
with st.sidebar.expander("📝 学生信息 (可选)", expanded=False):
|
||||
c_study_hours = st.slider("每周学习时长", 0.0, 20.0, 10.0, 0.5, key="c_study")
|
||||
c_sleep_hours = st.slider("日均睡眠时长", 0.0, 12.0, 7.0, 0.5, key="c_sleep")
|
||||
c_attendance = st.slider("出勤率", 0.0, 1.0, 0.9, 0.05, key="c_att")
|
||||
c_stress = st.select_slider("压力等级", options=[1, 2, 3, 4, 5], value=3, key="c_stress")
|
||||
c_study_type = st.radio("学习方式", ["Self", "Group", "Online"], key="c_type")
|
||||
|
||||
# Initialize chat history
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state.messages = []
|
||||
|
||||
# Display chat messages from history on app rerun
|
||||
for message in st.session_state.messages:
|
||||
with st.chat_message(message["role"]):
|
||||
st.markdown(message["content"])
|
||||
|
||||
# React to user input
|
||||
if prompt := st.chat_input("想聊聊什么?"):
|
||||
# Display user message
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
# Add user message to history
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
|
||||
# Prepare history for pydantic-ai
|
||||
api_history = []
|
||||
for msg in st.session_state.messages[:-1]:
|
||||
if msg["role"] == "user":
|
||||
api_history.append(ModelRequest(parts=[UserPromptPart(content=msg["content"])]))
|
||||
elif msg["role"] == "assistant":
|
||||
api_history.append(ModelResponse(parts=[TextPart(content=msg["content"])]))
|
||||
|
||||
# Create student features for counselor context
|
||||
student = StudentFeatures(
|
||||
study_hours=c_study_hours,
|
||||
sleep_hours=c_sleep_hours,
|
||||
attendance_rate=c_attendance,
|
||||
stress_level=c_stress,
|
||||
study_type=c_study_type,
|
||||
)
|
||||
|
||||
# Generate response
|
||||
with st.chat_message("assistant"):
|
||||
placeholder = st.empty()
|
||||
with st.spinner("咨询师正在倾听..."):
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
# Run the manual streaming function
|
||||
response_text = loop.run_until_complete(
|
||||
run_counselor_stream(prompt, api_history, placeholder, student)
|
||||
)
|
||||
|
||||
if response_text:
|
||||
st.session_state.messages.append(
|
||||
{"role": "assistant", "content": response_text}
|
||||
)
|
||||
124
src/train.py
Normal file
124
src/train.py
Normal file
@ -0,0 +1,124 @@
|
||||
"""训练模块
|
||||
|
||||
使用 Polars 进行数据处理,scikit-learn 进行模型训练。
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import joblib
|
||||
from sklearn.compose import ColumnTransformer
|
||||
from sklearn.ensemble import RandomForestClassifier
|
||||
from sklearn.impute import SimpleImputer
|
||||
from sklearn.linear_model import LogisticRegression
|
||||
from sklearn.metrics import classification_report, f1_score
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.pipeline import Pipeline
|
||||
from sklearn.preprocessing import OneHotEncoder, StandardScaler
|
||||
|
||||
from src.data import generate_data, get_feature_columns, preprocess_data
|
||||
|
||||
MODELS_DIR = Path("models")
|
||||
MODEL_PATH = MODELS_DIR / "model.pkl"
|
||||
|
||||
|
||||
def get_pipeline(model_type: str = "rf") -> Pipeline:
|
||||
"""构建 sklearn 处理流水线
|
||||
|
||||
1. 数值特征 -> 缺失填充 (均值) -> 标准化
|
||||
2. 类别特征 -> 缺失填充 (众数) -> OneHot编码
|
||||
3. 模型 -> LR 或 RF
|
||||
|
||||
Args:
|
||||
model_type: 模型类型 ("lr" 或 "rf")
|
||||
|
||||
Returns:
|
||||
Pipeline: 完整的 sklearn Pipeline
|
||||
"""
|
||||
numeric_features, categorical_features = get_feature_columns()
|
||||
|
||||
# 数值处理管道
|
||||
numeric_transformer = Pipeline(
|
||||
steps=[
|
||||
("imputer", SimpleImputer(strategy="mean")),
|
||||
("scaler", StandardScaler()),
|
||||
]
|
||||
)
|
||||
|
||||
# 类别处理管道
|
||||
categorical_transformer = Pipeline(
|
||||
steps=[
|
||||
("imputer", SimpleImputer(strategy="most_frequent")),
|
||||
("onehot", OneHotEncoder(handle_unknown="ignore")),
|
||||
]
|
||||
)
|
||||
|
||||
# 组合预处理
|
||||
preprocessor = ColumnTransformer(
|
||||
transformers=[
|
||||
("num", numeric_transformer, numeric_features),
|
||||
("cat", categorical_transformer, categorical_features),
|
||||
]
|
||||
)
|
||||
|
||||
# 选择模型
|
||||
if model_type == "lr":
|
||||
clf = LogisticRegression(random_state=42, max_iter=1000)
|
||||
else:
|
||||
clf = RandomForestClassifier(n_estimators=500, max_depth=5, random_state=42)
|
||||
|
||||
return Pipeline(steps=[("preprocessor", preprocessor), ("classifier", clf)])
|
||||
|
||||
|
||||
def train() -> None:
|
||||
"""执行完整训练流程"""
|
||||
print(">>> 1. 数据准备 (使用 Polars)")
|
||||
df_polars = generate_data(n_samples=2000)
|
||||
df_polars = preprocess_data(df_polars)
|
||||
|
||||
# 转换为 pandas 用于 sklearn
|
||||
df = df_polars.to_pandas()
|
||||
|
||||
X = df.drop(columns=["is_pass"])
|
||||
y = df["is_pass"]
|
||||
|
||||
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
|
||||
print(f"训练集大小: {X_train.shape}, 测试集大小: {X_test.shape}")
|
||||
|
||||
print("\n>>> 2. 模型训练与对比")
|
||||
# 模型 A: 逻辑回归 (Baseline)
|
||||
pipe_lr = get_pipeline("lr")
|
||||
pipe_lr.fit(X_train, y_train)
|
||||
y_pred_lr = pipe_lr.predict(X_test)
|
||||
f1_lr = f1_score(y_test, y_pred_lr)
|
||||
print(f"[Baseline - LogisticRegression] F1: {f1_lr:.4f}")
|
||||
|
||||
# 模型 B: 随机森林 (Target)
|
||||
pipe_rf = get_pipeline("rf")
|
||||
pipe_rf.fit(X_train, y_train)
|
||||
y_pred_rf = pipe_rf.predict(X_test)
|
||||
f1_rf = f1_score(y_test, y_pred_rf)
|
||||
print(f"[Target - RandomForest] F1: {f1_rf:.4f}")
|
||||
|
||||
print("\n>>> 3. 详细评估")
|
||||
best_model = pipe_rf
|
||||
print(classification_report(y_test, y_pred_rf))
|
||||
|
||||
print("\n>>> 4. 误差分析 (Error Analysis)")
|
||||
test_df = X_test.copy()
|
||||
test_df["True Label"] = y_test
|
||||
test_df["Pred Label"] = y_pred_rf
|
||||
|
||||
errors = test_df[test_df["True Label"] != test_df["Pred Label"]]
|
||||
print(f"总计错误样本数: {len(errors)}")
|
||||
if len(errors) > 0:
|
||||
print("典型错误样本预览:")
|
||||
print(errors.head(3))
|
||||
|
||||
print("\n>>> 5. 保存最佳模型")
|
||||
MODELS_DIR.mkdir(exist_ok=True)
|
||||
joblib.dump(best_model, MODEL_PATH)
|
||||
print(f"模型 Pipeline 已完整保存至 {MODEL_PATH}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
83
tests/test_agent.py
Normal file
83
tests/test_agent.py
Normal file
@ -0,0 +1,83 @@
|
||||
"""Agent 模块测试
|
||||
|
||||
测试 Agent 工具函数和依赖注入。
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
# 设置虚拟 key 避免 pydantic-ai 初始化错误
|
||||
os.environ["DEEPSEEK_API_KEY"] = "dummy_key_for_testing"
|
||||
|
||||
from src.agent_app import AgentDeps, study_advisor
|
||||
from src.features import StudentFeatures
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_student() -> StudentFeatures:
|
||||
"""创建测试用学生特征"""
|
||||
return StudentFeatures(
|
||||
study_hours=12,
|
||||
sleep_hours=7,
|
||||
attendance_rate=0.9,
|
||||
stress_level=2,
|
||||
study_type="Self",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_deps(sample_student: StudentFeatures) -> AgentDeps:
|
||||
"""创建测试用依赖"""
|
||||
return AgentDeps(student=sample_student)
|
||||
|
||||
|
||||
def test_agent_deps_creation(sample_deps: AgentDeps):
|
||||
"""测试 AgentDeps 创建"""
|
||||
assert sample_deps.student.study_hours == 12
|
||||
assert sample_deps.model_path == "models/model.pkl"
|
||||
|
||||
|
||||
def test_student_features_validation():
|
||||
"""测试 StudentFeatures 验证"""
|
||||
# 有效数据
|
||||
student = StudentFeatures(
|
||||
study_hours=10,
|
||||
sleep_hours=7,
|
||||
attendance_rate=0.85,
|
||||
stress_level=3,
|
||||
study_type="Group",
|
||||
)
|
||||
assert student.study_type == "Group"
|
||||
|
||||
# 无效 study_type
|
||||
with pytest.raises(ValueError):
|
||||
StudentFeatures(
|
||||
study_hours=10,
|
||||
sleep_hours=7,
|
||||
attendance_rate=0.85,
|
||||
stress_level=3,
|
||||
study_type="Invalid",
|
||||
)
|
||||
|
||||
|
||||
def test_tool_function_mock(sample_deps: AgentDeps):
|
||||
"""测试工具函数(mock 底层推理)"""
|
||||
with patch("src.agent_app.predict_pass_prob") as mock_predict:
|
||||
mock_predict.return_value = 0.85
|
||||
|
||||
# 由于工具是 async,我们直接测试底层函数
|
||||
|
||||
with patch("src.infer.load_model"):
|
||||
with patch("src.infer._MODEL") as mock_model:
|
||||
mock_model.predict_proba.return_value = [[0.15, 0.85]]
|
||||
# 这里只验证 mock 设置正确
|
||||
assert mock_predict.return_value == 0.85
|
||||
|
||||
|
||||
def test_agent_structure():
|
||||
"""测试 Agent 结构"""
|
||||
assert study_advisor is not None
|
||||
assert hasattr(study_advisor, "run")
|
||||
assert hasattr(study_advisor, "run_sync")
|
||||
39
tests/test_counselor_agent.py
Normal file
39
tests/test_counselor_agent.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""咨询师 Agent 测试"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
os.environ["DEEPSEEK_API_KEY"] = "dummy_key_for_testing"
|
||||
|
||||
from src.agent_app import AgentDeps, counselor_agent
|
||||
from src.features import StudentFeatures
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_deps() -> AgentDeps:
|
||||
"""创建测试用依赖"""
|
||||
return AgentDeps(
|
||||
student=StudentFeatures(
|
||||
study_hours=10,
|
||||
sleep_hours=6,
|
||||
attendance_rate=0.8,
|
||||
stress_level=4,
|
||||
study_type="Group",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_counselor_agent_structure():
|
||||
"""测试咨询师 Agent 结构"""
|
||||
assert counselor_agent is not None
|
||||
assert hasattr(counselor_agent, "run")
|
||||
assert hasattr(counselor_agent, "run_stream")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_counselor_agent_deps_type():
|
||||
"""测试 Agent 依赖类型"""
|
||||
# 验证 deps_type 设置正确
|
||||
assert counselor_agent._deps_type == AgentDeps
|
||||
111
tests/test_data.py
Normal file
111
tests/test_data.py
Normal file
@ -0,0 +1,111 @@
|
||||
"""数据模块测试
|
||||
|
||||
测试 Polars 数据生成、Pandera 校验和预处理功能。
|
||||
"""
|
||||
|
||||
import polars as pl
|
||||
import pytest
|
||||
|
||||
from src.data import (
|
||||
CleanStudentDataSchema,
|
||||
RawStudentDataSchema,
|
||||
generate_data,
|
||||
get_feature_columns,
|
||||
preprocess_data,
|
||||
validate_clean_data,
|
||||
validate_raw_data,
|
||||
)
|
||||
|
||||
|
||||
def test_generate_data_structure():
|
||||
"""测试生成数据的结构是否正确"""
|
||||
df = generate_data(n_samples=50)
|
||||
|
||||
assert isinstance(df, pl.DataFrame)
|
||||
assert len(df) == 50
|
||||
|
||||
expected_cols = [
|
||||
"study_hours",
|
||||
"sleep_hours",
|
||||
"attendance_rate",
|
||||
"study_type",
|
||||
"stress_level",
|
||||
"is_pass",
|
||||
]
|
||||
for col in expected_cols:
|
||||
assert col in df.columns
|
||||
|
||||
|
||||
def test_generate_data_content_range():
|
||||
"""测试生成数据的值范围是否正确"""
|
||||
df = generate_data(n_samples=50)
|
||||
|
||||
assert df["study_hours"].min() >= 0
|
||||
assert df["study_hours"].max() <= 20
|
||||
assert df["sleep_hours"].min() >= 0
|
||||
assert df["stress_level"].min() >= 1
|
||||
assert df["stress_level"].max() <= 5
|
||||
assert df["is_pass"].is_in([0, 1]).all()
|
||||
|
||||
|
||||
def test_generate_data_missing_values():
|
||||
"""测试数据是否包含预期的缺失值"""
|
||||
df = generate_data(n_samples=500, random_seed=42)
|
||||
# attendance_rate 有 5% 概率为 null
|
||||
null_count = df["attendance_rate"].null_count()
|
||||
assert null_count >= 0
|
||||
|
||||
|
||||
def test_validate_raw_data():
|
||||
"""测试原始数据 Schema 校验(宽松模式)"""
|
||||
df = generate_data(n_samples=50)
|
||||
# 应该能通过校验,即使有缺失值
|
||||
validated = validate_raw_data(df)
|
||||
assert isinstance(validated, pl.DataFrame)
|
||||
|
||||
|
||||
def test_validate_clean_data():
|
||||
"""测试清洗后数据 Schema 校验(严格模式)"""
|
||||
df = generate_data(n_samples=50)
|
||||
df_clean = df.drop_nulls()
|
||||
validated = validate_clean_data(df_clean)
|
||||
assert isinstance(validated, pl.DataFrame)
|
||||
|
||||
|
||||
def test_preprocess_data_removes_nulls():
|
||||
"""测试预处理是否删除缺失值"""
|
||||
df = generate_data(n_samples=500, random_seed=42)
|
||||
null_before = df["attendance_rate"].null_count()
|
||||
|
||||
df_clean = preprocess_data(df, validate=True)
|
||||
null_after = df_clean["attendance_rate"].null_count()
|
||||
|
||||
assert null_after == 0
|
||||
assert len(df_clean) <= len(df)
|
||||
|
||||
|
||||
def test_preprocess_data_removes_duplicates():
|
||||
"""测试去重预处理"""
|
||||
df = pl.DataFrame({
|
||||
"study_hours": [1.0, 2.0, 2.0, 3.0],
|
||||
"sleep_hours": [7.0, 7.0, 7.0, 7.0],
|
||||
"attendance_rate": [0.8, 0.8, 0.8, 0.8],
|
||||
"stress_level": [1, 2, 2, 3],
|
||||
"study_type": ["Self", "Self", "Self", "Self"],
|
||||
"is_pass": [0, 1, 1, 1],
|
||||
})
|
||||
clean_df = preprocess_data(df, validate=True)
|
||||
assert len(clean_df) == 3
|
||||
|
||||
|
||||
def test_get_feature_columns():
|
||||
"""测试特征列获取"""
|
||||
num_feats, cat_feats = get_feature_columns()
|
||||
assert "study_hours" in num_feats
|
||||
assert "study_type" in cat_feats
|
||||
|
||||
|
||||
def test_schema_classes_exist():
|
||||
"""测试 Schema 类是否可用"""
|
||||
assert RawStudentDataSchema is not None
|
||||
assert CleanStudentDataSchema is not None
|
||||
73
tests/test_infer.py
Normal file
73
tests/test_infer.py
Normal file
@ -0,0 +1,73 @@
|
||||
"""推理模块测试"""
|
||||
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from src.infer import (
|
||||
explain_prediction,
|
||||
predict_pass_prob,
|
||||
reset_model_cache,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def train_dummy_model(tmp_path_factory):
|
||||
"""训练临时模型用于测试"""
|
||||
models_dir = tmp_path_factory.mktemp("models")
|
||||
model_path = models_dir / "model.pkl"
|
||||
|
||||
import joblib
|
||||
|
||||
from src.data import generate_data, preprocess_data
|
||||
from src.train import get_pipeline
|
||||
|
||||
df = generate_data(n_samples=20)
|
||||
df = preprocess_data(df)
|
||||
|
||||
# 转换为 pandas
|
||||
df_pandas = df.to_pandas()
|
||||
X = df_pandas.drop(columns=["is_pass"])
|
||||
y = df_pandas["is_pass"]
|
||||
|
||||
pipeline = get_pipeline("rf")
|
||||
pipeline.fit(X, y)
|
||||
|
||||
joblib.dump(pipeline, model_path)
|
||||
|
||||
return model_path
|
||||
|
||||
|
||||
def test_predict_pass_prob(train_dummy_model):
|
||||
"""测试预测函数"""
|
||||
reset_model_cache()
|
||||
|
||||
with patch("src.infer.MODEL_PATH", train_dummy_model):
|
||||
proba = predict_pass_prob(
|
||||
study_hours=5.0,
|
||||
sleep_hours=7.0,
|
||||
attendance_rate=0.9,
|
||||
stress_level=3,
|
||||
study_type="Self",
|
||||
)
|
||||
assert 0.0 <= proba <= 1.0
|
||||
|
||||
|
||||
def test_explain_prediction(train_dummy_model):
|
||||
"""测试解释函数"""
|
||||
reset_model_cache()
|
||||
|
||||
with patch("src.infer.MODEL_PATH", train_dummy_model):
|
||||
explanation = explain_prediction()
|
||||
assert isinstance(explanation, str)
|
||||
assert "模型特征重要性排名" in explanation
|
||||
|
||||
|
||||
def test_load_model_missing():
|
||||
"""测试模型文件不存在时的错误处理"""
|
||||
reset_model_cache()
|
||||
|
||||
with patch("src.infer.MODEL_PATH", Path("non_existent_path/model.pkl")):
|
||||
with pytest.raises(FileNotFoundError):
|
||||
predict_pass_prob(1, 1, 1, 1, "Self")
|
||||
32
tests/test_model.py
Normal file
32
tests/test_model.py
Normal file
@ -0,0 +1,32 @@
|
||||
import os
|
||||
|
||||
import joblib
|
||||
|
||||
from src.infer import predict_pass_prob
|
||||
from src.train import MODEL_PATH, train
|
||||
|
||||
|
||||
def test_train_creates_model():
|
||||
# 确保模型不存在或被覆盖
|
||||
if os.path.exists(MODEL_PATH):
|
||||
os.remove(MODEL_PATH)
|
||||
|
||||
train()
|
||||
assert os.path.exists(MODEL_PATH)
|
||||
|
||||
model = joblib.load(MODEL_PATH)
|
||||
assert model is not None
|
||||
|
||||
|
||||
def test_inference():
|
||||
# 确保模型存在
|
||||
if not os.path.exists(MODEL_PATH):
|
||||
train()
|
||||
|
||||
# 高概率情况 (大量学习/睡眠/出勤 + Group学习 + 低压力)
|
||||
prob_high = predict_pass_prob(15, 8, 1.0, 1, "Group")
|
||||
assert prob_high > 0.5
|
||||
|
||||
# 低概率情况 (不学习/不睡/缺勤 + 在线 + 高压力)
|
||||
prob_low = predict_pass_prob(0, 3, 0.0, 5, "Online")
|
||||
assert prob_low < 0.5
|
||||
45
tests/test_train.py
Normal file
45
tests/test_train.py
Normal file
@ -0,0 +1,45 @@
|
||||
"""训练模块测试"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sklearn.pipeline import Pipeline
|
||||
|
||||
from src.train import get_pipeline, train
|
||||
|
||||
|
||||
def test_get_pipeline_structure():
|
||||
"""测试 Pipeline 结构"""
|
||||
pipeline = get_pipeline("rf")
|
||||
assert isinstance(pipeline, Pipeline)
|
||||
assert "preprocessor" in pipeline.named_steps
|
||||
assert "classifier" in pipeline.named_steps
|
||||
|
||||
|
||||
def test_get_pipeline_lr():
|
||||
"""测试逻辑回归 Pipeline"""
|
||||
pipeline = get_pipeline("lr")
|
||||
assert isinstance(pipeline, Pipeline)
|
||||
|
||||
|
||||
def test_train_function_runs(tmp_path):
|
||||
"""测试训练函数能正常运行"""
|
||||
models_dir = tmp_path / "models"
|
||||
model_path = models_dir / "model.pkl"
|
||||
|
||||
with (
|
||||
patch("src.train.MODELS_DIR", models_dir),
|
||||
patch("src.train.MODEL_PATH", model_path),
|
||||
patch("src.train.generate_data") as mock_gen,
|
||||
):
|
||||
from src.data import generate_data
|
||||
|
||||
real_small_df = generate_data(n_samples=20)
|
||||
mock_gen.return_value = real_small_df
|
||||
|
||||
try:
|
||||
train()
|
||||
except Exception as e:
|
||||
pytest.fail(f"Train function failed: {e}")
|
||||
|
||||
assert model_path.exists()
|
||||
Loading…
Reference in New Issue
Block a user