CourseDesign/src/streamlit_app.py
2026-01-09 14:30:23 +08:00

251 lines
8.5 KiB
Python

"""Streamlit 演示应用
学生成绩预测 AI 助手 - 支持成绩预测分析和心理咨询对话。
"""
import asyncio
import os
import sys

import streamlit as st

# Ensure project root is in path
sys.path.append(os.getcwd())

from dotenv import load_dotenv
from pydantic_ai import FunctionToolCallEvent, FunctionToolResultEvent, PartDeltaEvent
from pydantic_ai.messages import (
    ModelRequest,
    ModelResponse,
    PartStartEvent,
    TextPart,
    TextPartDelta,
    UserPromptPart,
)

from src.agent_app import AgentDeps, counselor_agent, study_advisor
from src.features import StudentFeatures
# Pull settings such as DEEPSEEK_API_KEY from a local .env file into os.environ.
load_dotenv()

# Page title, icon and wide layout for the whole app.
st.set_page_config(page_title="学生成绩预测 AI 助手", page_icon="🎓", layout="wide")

# --- Sidebar: configuration ---
sidebar = st.sidebar
sidebar.header("🔧 配置")
env_key = os.getenv("DEEPSEEK_API_KEY", "")
api_key = sidebar.text_input("DeepSeek API Key", type="password", value=env_key)
# A non-empty entry is exported so the DEEPSEEK_API_KEY checks in the async
# helpers (which read os.environ) see it.
# NOTE(review): os.environ is process-wide, so a key typed here is presumably
# visible to every session served by this Streamlit process, and clearing the
# field does not remove a key exported earlier. Consider per-session storage.
if api_key:
    os.environ["DEEPSEEK_API_KEY"] = api_key
sidebar.markdown("---")

# --- Sidebar: feature switch ---
mode = sidebar.radio("功能选择", ["📊 成绩预测", "💬 心理咨询"])
# --- Helper Functions ---
async def run_analysis(
    study_hours: float,
    sleep_hours: float,
    attendance_rate: float,
    stress_level: int,
    study_type: str,
):
    """Run the study-advisor agent for one student and return its structured output.

    Args:
        study_hours: Weekly study time in hours, as collected by the prediction form.
        sleep_hours: Average daily sleep in hours, as collected by the prediction form.
        attendance_rate: Attendance as a fraction between 0 and 1 (form slider).
        stress_level: Self-reported stress level, 1 (low) to 5 (high).
        study_type: Main study mode as offered by the form ("Self", "Group", "Online").

    Returns:
        ``result.output`` from ``study_advisor.run`` on success, or ``None`` when no
        DeepSeek API key is configured or anything in the run raises. Failures are
        reported to the user with ``st.error`` instead of being propagated.
    """
    # Bail out early with a readable message instead of an auth error from the API.
    if not os.getenv("DEEPSEEK_API_KEY"):
        st.error("请在侧边栏提供 DeepSeek API Key。")
        return None

    # The prompt carries no student numbers itself; it instructs the agent to call
    # its tools, which presumably read the features from ``deps.student``
    # (defined in src.agent_app — verify there).
    query = (
        "请分析这位学生的通过率并给出建议。"
        "学生信息已通过工具获取,请调用 predict_pass_probability 和 get_model_explanation 工具。"
    )

    try:
        student = StudentFeatures(
            study_hours=study_hours,
            sleep_hours=sleep_hours,
            attendance_rate=attendance_rate,
            stress_level=stress_level,
            study_type=study_type,
        )
        deps = AgentDeps(student=student)
        with st.spinner("🤖 Agent 正在思考... (调用 DeepSeek + 随机森林模型)"):
            result = await study_advisor.run(query, deps=deps)
        return result.output
    except Exception as e:
        # UI boundary: show any failure to the user rather than crashing the page.
        st.error(f"分析失败: {e!s}")
        return None
async def run_counselor_stream(
    query: str,
    history: list,
    placeholder,
    student: StudentFeatures,
):
    """Stream one counsellor reply into *placeholder*, surfacing tool activity.

    Args:
        query: The user's latest chat message.
        history: Earlier turns as pydantic-ai ``ModelRequest`` / ``ModelResponse``
            objects (built by the chat view from session state).
        placeholder: A Streamlit slot (created with ``st.empty()`` by the caller)
            that is re-rendered with the growing reply text.
        student: Student features passed to the agent through ``AgentDeps``.

    Returns:
        The complete reply text, or ``None`` if no API key is configured or the
        run fails (the error is rendered into *placeholder*).
    """
    if not os.getenv("DEEPSEEK_API_KEY"):
        placeholder.error("❌ 错误: 请在侧边栏提供 DeepSeek API Key。")
        return None

    # Transient slot for "tool in use" notices shown while the agent calls tools.
    status_placeholder = st.empty()
    full_response = ""
    try:
        deps = AgentDeps(student=student)
        async for event in counselor_agent.run_stream_events(
            query, deps=deps, message_history=history
        ):
            if isinstance(event, PartStartEvent) and isinstance(event.part, TextPart):
                # pydantic-ai delivers the opening chunk of each new text part inside
                # the PartStartEvent, not as a delta; skipping it truncates the reply.
                full_response += event.part.content
                placeholder.markdown(full_response)
            elif isinstance(event, PartDeltaEvent) and isinstance(event.delta, TextPartDelta):
                full_response += event.delta.content_delta
                placeholder.markdown(full_response)
            elif isinstance(event, FunctionToolCallEvent):
                status_placeholder.info(f"🛠️ 咨询师正在使用工具: `{event.part.tool_name}` ...")
            elif isinstance(event, FunctionToolResultEvent):
                status_placeholder.empty()
        # Final render; also covers a run that emitted no text events at all.
        placeholder.markdown(full_response)
    except Exception as e:
        placeholder.error(f"❌ 咨询失败: {e!s}")
        return None
    finally:
        # Always clear the tool notice, including when the run fails mid-tool-call.
        status_placeholder.empty()
    return full_response
# --- Main Views ---
# Pass probability below which the prediction view labels a student high risk.
HIGH_RISK_THRESHOLD = 0.6


def _run_async(coro):
    """Run *coro* to completion on this browser session's persistent event loop.

    The previous approach created a brand-new event loop for every submit/message
    and never closed it, leaking the loop's selector and file descriptors on each
    interaction. One loop is now kept per session in ``st.session_state`` and
    reused. A loop is not closed after each call because async HTTP clients that
    pool connections across calls may still hold resources bound to it.
    """
    loop = st.session_state.get("_event_loop")
    # Replace a missing/closed loop, or one that is somehow still running
    # elsewhere (run_until_complete would refuse to re-enter it).
    if loop is None or loop.is_closed() or loop.is_running():
        loop = asyncio.new_event_loop()
        st.session_state["_event_loop"] = loop
    asyncio.set_event_loop(loop)
    return loop.run_until_complete(coro)


if mode == "📊 成绩预测":
    st.title("🎓 学生成绩预测助手")
    st.markdown("在下方输入学生详细信息,获取 AI 驱动的成绩分析。")
    # Inputs are batched in a form so the agent only runs on explicit submit.
    with st.form("student_data_form"):
        col1, col2 = st.columns(2)
        with col1:
            study_hours = st.slider("每周学习时长 (小时)", 0.0, 20.0, 10.0, 0.5)
            sleep_hours = st.slider("日均睡眠时长 (小时)", 0.0, 12.0, 7.0, 0.5)
        with col2:
            attendance_rate = st.slider("出勤率", 0.0, 1.0, 0.9, 0.05)
            stress_level = st.select_slider(
                "压力等级 (1=低, 5=高)", options=[1, 2, 3, 4, 5], value=3
            )
        study_type = st.radio("主要学习方式", ["Self", "Group", "Online"], horizontal=True)
        submitted = st.form_submit_button("🚀 分析通过率")
    if submitted:
        guidance = _run_async(
            run_analysis(study_hours, sleep_hours, attendance_rate, stress_level, study_type)
        )
        # run_analysis returns None on failure (and has already shown the error).
        if guidance:
            is_high_risk = guidance.pass_probability < HIGH_RISK_THRESHOLD
            st.divider()
            st.subheader("📊 分析结果")
            # Three columns for layout; only the first two are used.
            m1, m2, _ = st.columns(3)
            m1.metric("预测通过率", f"{guidance.pass_probability:.1%}")
            m2.metric(
                "风险评估",
                "高风险" if is_high_risk else "低风险",
                # A leading "-" makes st.metric render the delta as negative (red, down).
                delta="-高风险" if is_high_risk else "+安全",
            )
            st.info(f"**风险评估:** {guidance.risk_assessment}")
            # Key factors
            st.subheader("🔍 关键因素")
            for factor in guidance.key_factors:
                st.write(f"- {factor}")
            st.subheader("✅ 行动计划")
            actions = [
                {"优先级": item.priority, "建议行动": item.action} for item in guidance.action_plan
            ]
            st.table(actions)
            st.subheader("💡 分析依据")
            st.write(guidance.rationale)
elif mode == "💬 心理咨询":
    st.title("🧩 AI 心理咨询室")
    st.markdown("这里是安全且私密的空间。有些压力如果你愿意说,我愿意听。")
    # Optional student context for the counsellor, in a collapsed sidebar expander.
    with st.sidebar.expander("📝 学生信息 (可选)", expanded=False):
        c_study_hours = st.slider("每周学习时长", 0.0, 20.0, 10.0, 0.5, key="c_study")
        c_sleep_hours = st.slider("日均睡眠时长", 0.0, 12.0, 7.0, 0.5, key="c_sleep")
        c_attendance = st.slider("出勤率", 0.0, 1.0, 0.9, 0.05, key="c_att")
        c_stress = st.select_slider("压力等级", options=[1, 2, 3, 4, 5], value=3, key="c_stress")
        c_study_type = st.radio("学习方式", ["Self", "Group", "Online"], key="c_type")
    # The transcript lives in session state as {"role", "content"} dicts so it
    # survives Streamlit reruns.
    if "messages" not in st.session_state:
        st.session_state.messages = []
    # Replay the transcript on every rerun.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
    if prompt := st.chat_input("想聊聊什么?"):
        with st.chat_message("user"):
            st.markdown(prompt)
        st.session_state.messages.append({"role": "user", "content": prompt})
        # Convert prior turns (everything except the message just appended, which
        # is sent as the query) into pydantic-ai message objects.
        api_history = []
        for msg in st.session_state.messages[:-1]:
            if msg["role"] == "user":
                api_history.append(ModelRequest(parts=[UserPromptPart(content=msg["content"])]))
            elif msg["role"] == "assistant":
                api_history.append(ModelResponse(parts=[TextPart(content=msg["content"])]))
        # Student context for the counsellor, from the sidebar expander.
        student = StudentFeatures(
            study_hours=c_study_hours,
            sleep_hours=c_sleep_hours,
            attendance_rate=c_attendance,
            stress_level=c_stress,
            study_type=c_study_type,
        )
        with st.chat_message("assistant"):
            placeholder = st.empty()
            with st.spinner("咨询师正在倾听..."):
                response_text = _run_async(
                    run_counselor_stream(prompt, api_history, placeholder, student)
                )
            # Only successful, non-empty replies are added to the transcript.
            if response_text:
                st.session_state.messages.append(
                    {"role": "assistant", "content": response_text}
                )