134 lines
5.1 KiB
Python
134 lines
5.1 KiB
Python
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
import asyncio
|
|||
|
|
from typing import Any, List
|
|||
|
|
|
|||
|
|
sys.path.append(os.getcwd())
|
|||
|
|
|
|||
|
|
from pydantic import BaseModel, Field
|
|||
|
|
from pydantic_ai import Agent, RunContext
|
|||
|
|
from dotenv import load_dotenv
|
|||
|
|
|
|||
|
|
from src.infer import predict_pass_prob, explain_prediction
|
|||
|
|
|
|||
|
|
load_dotenv()
|
|||
|
|
|
|||
|
|
# --- 1. 定义结构化输出 (Level 1 Requirement) ---
|
|||
|
|
class ActionItem(BaseModel):
|
|||
|
|
action: str = Field(description="具体的行动建议")
|
|||
|
|
priority: str = Field(description="优先级 (高/中/低)")
|
|||
|
|
|
|||
|
|
class StudyGuidance(BaseModel):
|
|||
|
|
pass_probability: float = Field(description="预测通过率 (0-1)")
|
|||
|
|
risk_assessment: str = Field(description="风险评估 (自然语言描述)")
|
|||
|
|
key_drivers: str = Field(description="导致该预测结果的主要因素 (来自模型解释)")
|
|||
|
|
action_plan: List[ActionItem] = Field(description="3-5条建议清单")
|
|||
|
|
|
|||
|
|
# --- 2. 初始化 Agent ---
|
|||
|
|
# 必须强调:不要编造事实,必须基于工具返回的数据。
|
|||
|
|
agent = Agent(
|
|||
|
|
"deepseek:deepseek-chat",
|
|||
|
|
output_type=StudyGuidance,
|
|||
|
|
system_prompt=(
|
|||
|
|
"你是一个极其严谨的学业数据分析师。"
|
|||
|
|
"你的任务是根据学生的具体情况预测其考试通过率,并给出建议。"
|
|||
|
|
"【重要规则】"
|
|||
|
|
"1. 必须先调用 `predict_student` 获取概率。"
|
|||
|
|
"2. 必须调用 `explain_model` 获取模型认为最重要的特征,并在 `key_drivers` 中引用这些特征。"
|
|||
|
|
"3. 你的建议必须针对那些最重要的特征(例如,如果模型说睡眠很重要,就给睡眠建议)。"
|
|||
|
|
"4. 严禁凭空编造数值。"
|
|||
|
|
),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# --- 2.1 定义 Counselor Agent ---
|
|||
|
|
counselor_agent = Agent(
|
|||
|
|
"deepseek:deepseek-chat",
|
|||
|
|
system_prompt=(
|
|||
|
|
"你是一位富有同理心且专业的大学心理咨询师。"
|
|||
|
|
"你的目标是倾听学生的学业压力和生活烦恼,提供情感支持,并根据需要给出建议。"
|
|||
|
|
"【交互风格】"
|
|||
|
|
"1. 同理心:首先通过复述或确认学生的感受来表达理解(例如:“听起来你最近压力真的很大...”)。"
|
|||
|
|
"2. 引导性:不要急于给出解决方案,先通过提问了解更多背景。"
|
|||
|
|
"3. 数据驱动(可选):如果学生询问具体通过率或客观分析,请调用 `predict_student_tool` 或 `explain_model_tool`。"
|
|||
|
|
"4. 语气:温暖、支持、专业,但像朋友一样交谈。"
|
|||
|
|
"【工具使用】"
|
|||
|
|
"如果学生提供了具体的学习时长、睡眠等数据,或者明确询问预测结果,请使用工具。"
|
|||
|
|
"不要在每一句话里都引用数据,只在通过率相关的话题中使用。"
|
|||
|
|
),
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# --- 3. 注册工具 (Level 1 Requirement: 至少2个工具) ---
|
|||
|
|
|
|||
|
|
@agent.tool
|
|||
|
|
def predict_student(ctx: RunContext[Any],
|
|||
|
|
study_hours: float,
|
|||
|
|
sleep_hours: float,
|
|||
|
|
attendance_rate: float,
|
|||
|
|
stress_level: int,
|
|||
|
|
study_type: str) -> float:
|
|||
|
|
"""
|
|||
|
|
根据学生行为预测通过率。
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
study_hours: 每周学习小时数 (0-20)
|
|||
|
|
sleep_hours: 每天睡眠小时数 (0-12)
|
|||
|
|
attendance_rate: 出勤率 (0.0-1.0)
|
|||
|
|
stress_level: 压力等级 1(低) - 5(高)
|
|||
|
|
study_type: 学习类型 ("Group", "Self", "Online")
|
|||
|
|
"""
|
|||
|
|
return predict_pass_prob(study_hours, sleep_hours, attendance_rate, stress_level, study_type)
|
|||
|
|
|
|||
|
|
@counselor_agent.tool
|
|||
|
|
def predict_student_tool(ctx: RunContext[Any],
|
|||
|
|
study_hours: float,
|
|||
|
|
sleep_hours: float,
|
|||
|
|
attendance_rate: float,
|
|||
|
|
stress_level: int,
|
|||
|
|
study_type: str) -> float:
|
|||
|
|
"""
|
|||
|
|
根据学生行为预测通过率。用于咨询过程中提供客观数据支持。
|
|||
|
|
"""
|
|||
|
|
return predict_pass_prob(study_hours, sleep_hours, attendance_rate, stress_level, study_type)
|
|||
|
|
|
|||
|
|
@agent.tool
|
|||
|
|
def explain_model(ctx: RunContext[Any]) -> str:
|
|||
|
|
"""
|
|||
|
|
获取机器学习模型的全局特征重要性解释。
|
|||
|
|
返回哪些特征对预测结果影响最大。
|
|||
|
|
"""
|
|||
|
|
return explain_prediction()
|
|||
|
|
|
|||
|
|
@counselor_agent.tool
|
|||
|
|
def explain_model_tool(ctx: RunContext[Any]) -> str:
|
|||
|
|
"""
|
|||
|
|
获取机器学习模型的全局特征重要性解释。
|
|||
|
|
"""
|
|||
|
|
return explain_prediction()
|
|||
|
|
|
|||
|
|
async def main():
|
|||
|
|
# 模拟真实的学生查询
|
|||
|
|
query = (
|
|||
|
|
"我最近压力很大 (等级4),每天只睡 4 小时,不过我每周自学(Self) 12 小时,"
|
|||
|
|
"出勤率大概 90%。请帮我分析一下我会挂科吗?基于模型告诉我怎么做最有效。"
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
print(f"用户: {query}\n")
|
|||
|
|
print("Agent 正在思考并调用模型工具...\n")
|
|||
|
|
|
|||
|
|
try:
|
|||
|
|
if not os.getenv("DEEPSEEK_API_KEY"):
|
|||
|
|
print("❌ 错误: 未设置 DEEPSEEK_API_KEY,无法运行 Agent。")
|
|||
|
|
print("请在 .env 文件中设置密钥,或 export DEEPSEEK_API_KEY='...'")
|
|||
|
|
return
|
|||
|
|
|
|||
|
|
result = await agent.run(query)
|
|||
|
|
|
|||
|
|
print("--- 结构化分析报告 ---")
|
|||
|
|
print(result.output.model_dump_json(indent=2))
|
|||
|
|
|
|||
|
|
except Exception as e:
|
|||
|
|
print(f"❌ 运行失败: {e}")
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
asyncio.run(main())
|