186 lines
7.3 KiB
Python
186 lines
7.3 KiB
Python
|
|
import streamlit as st
|
|||
|
|
import asyncio
|
|||
|
|
import os
|
|||
|
|
import sys
|
|||
|
|
|
|||
|
|
# Ensure project root is in path
|
|||
|
|
sys.path.append(os.getcwd())
|
|||
|
|
|
|||
|
|
from src.agent_app import agent, counselor_agent, StudyGuidance
|
|||
|
|
from pydantic_ai.messages import (
|
|||
|
|
ModelMessage, ModelRequest, ModelResponse, UserPromptPart, TextPart,
|
|||
|
|
TextPartDelta, ToolCallPart, ToolReturnPart
|
|||
|
|
)
|
|||
|
|
from pydantic_ai import (
|
|||
|
|
AgentStreamEvent, PartDeltaEvent, FunctionToolCallEvent, FunctionToolResultEvent
|
|||
|
|
)
|
|||
|
|
from dotenv import load_dotenv
|
|||
|
|
|
|||
|
|
# Load variables from a .env file before anything reads the environment.
load_dotenv()

st.set_page_config(page_title="学生成绩预测 AI 助手", page_icon="🎓", layout="wide")

# Sidebar: API credentials followed by the feature selector.
sidebar = st.sidebar
sidebar.header("🔧 配置")
api_key = sidebar.text_input(
    "DeepSeek API Key",
    value=os.getenv("DEEPSEEK_API_KEY", ""),
    type="password",
)

# A non-empty value in the text box is written back to the process
# environment, where the helper functions below look the key up.
if api_key:
    os.environ["DEEPSEEK_API_KEY"] = api_key

sidebar.markdown("---")

# Mode selection: the chosen label is compared against further down to pick
# which view to render.
mode = sidebar.radio("功能选择", ["📊 成绩预测", "💬 心理咨询"])
|
|||
|
|
|
|||
|
|
# --- Helper Functions ---
|
|||
|
|
|
|||
|
|
async def run_analysis(query):
    """Run the prediction agent on ``query`` and return its output.

    Failures are reported in the page through ``st.error`` rather than
    raised to the caller.

    Args:
        query: Prompt text handed to ``agent.run``.

    Returns:
        ``result.output`` of the agent run, or ``None`` when the
        ``DEEPSEEK_API_KEY`` environment variable is missing/empty or when
        any exception occurs during the run.
    """
    try:
        if os.getenv("DEEPSEEK_API_KEY"):
            # The spinner stays visible for as long as the agent call is pending.
            with st.spinner("🤖 Agent 正在思考... (调用 DeepSeek + 随机森林模型)"):
                outcome = await agent.run(query)
                return outcome.output

        # No key configured: tell the user where to provide one.
        st.error("请在侧边栏提供 DeepSeek API Key。")
    except Exception as exc:
        # UI boundary: surface the failure to the user instead of crashing the script.
        st.error(f"分析失败: {exc!s}")
    return None
|
|||
|
|
|
|||
|
|
async def run_counselor_stream(query, history, placeholder):
    """Stream the counselor agent's reply into ``placeholder``.

    Text is rendered incrementally (with a trailing cursor glyph) while the
    agent streams, and a temporary status box is shown while a tool call is
    in flight.

    Args:
        query: The new user prompt passed to ``counselor_agent.run_stream_events``.
        history: Prior conversation, forwarded as ``message_history``.
        placeholder: Streamlit placeholder that receives the reply text, or
            the error message when something fails.

    Returns:
        The full reply text accumulated from the stream, or ``None`` when the
        ``DEEPSEEK_API_KEY`` environment variable is missing/empty or the
        stream raised an exception.
    """
    # Local import keeps this function self-contained: the module-level
    # imports bring in PartDeltaEvent but not PartStartEvent.
    from pydantic_ai.messages import PartStartEvent

    if not os.getenv("DEEPSEEK_API_KEY"):
        placeholder.error("❌ 错误: 请在侧边栏提供 DeepSeek API Key。")
        return None

    # Separate slot used to show which tool is currently being called.
    status_placeholder = st.empty()
    full_response = ""

    try:
        async for event in counselor_agent.run_stream_events(query, message_history=history):
            # The first chunk of a text part arrives inside the part-start
            # event; only the following chunks come as deltas. Skipping this
            # event would drop the opening characters of every reply.
            if isinstance(event, PartStartEvent) and isinstance(event.part, TextPart):
                if event.part.content:
                    full_response += event.part.content
                    placeholder.markdown(full_response + "▌")

            # Subsequent text chunks (wrapped in PartDeltaEvent).
            elif isinstance(event, PartDeltaEvent) and isinstance(event.delta, TextPartDelta):
                full_response += event.delta.content_delta
                placeholder.markdown(full_response + "▌")

            # Tool call started: show which tool is running.
            elif isinstance(event, FunctionToolCallEvent):
                status_placeholder.info(f"🛠️ 咨询师正在使用工具: `{event.part.tool_name}` ...")

            # Tool call finished: remove the status box.
            elif isinstance(event, FunctionToolResultEvent):
                status_placeholder.empty()

        # Final render without the cursor glyph.
        placeholder.markdown(full_response)
        return full_response

    except Exception as exc:
        # UI boundary: show the failure in the chat bubble instead of raising.
        placeholder.error(f"❌ 咨询失败: {exc!s}")
        return None

    finally:
        # Clear the tool status on every exit path, including failures that
        # happen between a tool call and its result.
        status_placeholder.empty()
|
|||
|
|
|
|||
|
|
# --- Main Views ---


def _run_blocking(coro):
    """Run ``coro`` to completion on a newly created event loop and return its result."""
    fresh_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    return fresh_loop.run_until_complete(coro)


if mode == "📊 成绩预测":
    st.title("🎓 学生成绩预测助手")
    st.markdown("在下方输入学生详细信息,获取 AI 驱动的成绩分析。")

    # Input form: widget values are only submitted when the button is pressed.
    with st.form("student_data_form"):
        left_col, right_col = st.columns(2)

        with left_col:
            study_hours = st.slider("每周学习时长 (小时)", 0.0, 20.0, 10.0, 0.5)
            sleep_hours = st.slider("日均睡眠时长 (小时)", 0.0, 12.0, 7.0, 0.5)

        with right_col:
            attendance_rate = st.slider("出勤率", 0.0, 1.0, 0.9, 0.05)
            stress_level = st.select_slider("压力等级 (1=低, 5=高)", options=[1, 2, 3, 4, 5], value=3)
            study_type = st.radio("主要学习方式", ["Self", "Group", "Online"], horizontal=True)

        submitted = st.form_submit_button("🚀 分析通过率")

    if submitted:
        # Assemble the prompt from the form values; the pieces are joined
        # without any separator.
        query_parts = [
            "我是一名学生,情况如下:",
            f"每周学习时间: {study_hours} 小时;",
            f"平均睡眠时间: {sleep_hours} 小时;",
            f"出勤率: {attendance_rate:.2f};",
            f"压力等级: {stress_level} (1-5);",
            f"主要学习方式: {study_type}。",
            "请调用 `predict_student` 预测我的通过率,并调用 `explain_model` 分析关键因素,最后给出针对性的建议。",
        ]
        user_query = "".join(query_parts)

        guidance = _run_blocking(run_analysis(user_query))

        if guidance:
            st.divider()
            st.subheader("📊 分析结果")

            # Three columns are laid out; only the first two carry a metric.
            prob_col, risk_col, _unused_col = st.columns(3)
            prob_col.metric("预测通过率", f"{guidance.pass_probability:.1%}")

            # A pass probability below 0.6 is labelled as high risk.
            if guidance.pass_probability < 0.6:
                risk_label, risk_delta = "高风险", "-高风险"
            else:
                risk_label, risk_delta = "低风险", "+安全"
            risk_col.metric("风险评估", risk_label, delta=risk_delta)

            st.info(f"**风险评估:** {guidance.risk_assessment}")
            st.write(f"**关键因素:** {guidance.key_drivers}")

            st.subheader("✅ 行动计划")
            plan_rows = []
            for step in guidance.action_plan:
                plan_rows.append({"优先级": step.priority, "建议行动": step.action})
            st.table(plan_rows)

elif mode == "💬 心理咨询":
    st.title("🧩 AI 心理咨询室")
    st.markdown("这里是安全且私密的空间。有些压力如果你愿意说,我愿意听。")

    # The chat transcript is kept in session state; start it empty on first use.
    if "messages" not in st.session_state:
        st.session_state.messages = []

    # Replay the stored transcript.
    for past in st.session_state.messages:
        with st.chat_message(past["role"]):
            st.markdown(past["content"])

    prompt = st.chat_input("想聊聊什么?")
    if prompt:
        # Echo the new user message and record it in the transcript.
        with st.chat_message("user"):
            st.markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})

        # Convert the stored transcript into pydantic-ai messages. The last
        # entry (the prompt just appended) is left out because it is passed
        # to the agent separately; entries with any other role are skipped.
        to_model_message = {
            "user": lambda text: ModelRequest(parts=[UserPromptPart(content=text)]),
            "assistant": lambda text: ModelResponse(parts=[TextPart(content=text)]),
        }
        api_history = [
            to_model_message[entry["role"]](entry["content"])
            for entry in st.session_state.messages[:-1]
            if entry["role"] in to_model_message
        ]

        # Stream the assistant reply into a placeholder inside its chat bubble.
        with st.chat_message("assistant"):
            placeholder = st.empty()
            with st.spinner("咨询师正在倾听..."):
                response_text = _run_blocking(
                    run_counselor_stream(prompt, api_history, placeholder)
                )

        # Only non-empty replies are stored in the transcript.
        if response_text:
            st.session_state.messages.append({"role": "assistant", "content": response_text})
|