CourseDesign/tests/test_agent.py
2026-01-09 14:30:23 +08:00

84 lines
2.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Tests for the Agent module.

Covers the Agent tool functions and dependency injection.
"""
import os
from unittest.mock import patch
import pytest
# Set a dummy API key to avoid a pydantic-ai initialization error.
# NOTE(review): presumably this must run before the src.agent_app import below — confirm.
os.environ["DEEPSEEK_API_KEY"] = "dummy_key_for_testing"
from src.agent_app import AgentDeps, study_advisor
from src.features import StudentFeatures
@pytest.fixture
def sample_student() -> StudentFeatures:
    """Provide a StudentFeatures instance with fixed sample values for tests."""
    feature_values = {
        "study_hours": 12,
        "sleep_hours": 7,
        "attendance_rate": 0.9,
        "stress_level": 2,
        "study_type": "Self",
    }
    return StudentFeatures(**feature_values)
@pytest.fixture
def sample_deps(sample_student: StudentFeatures) -> AgentDeps:
    """Provide an AgentDeps instance built around the sample student."""
    deps = AgentDeps(student=sample_student)
    return deps
def test_agent_deps_creation(sample_deps: AgentDeps):
    """Check that AgentDeps carries the injected student and a model path."""
    student = sample_deps.student
    assert student.study_hours == 12
    # NOTE(review): presumably the default model path of AgentDeps — confirm.
    assert sample_deps.model_path == "models/model.pkl"
def test_student_features_validation():
    """Check that StudentFeatures accepts a valid study_type and rejects an invalid one."""
    # Fields shared by both the valid and the invalid construction below.
    common_fields = {
        "study_hours": 10,
        "sleep_hours": 7,
        "attendance_rate": 0.85,
        "stress_level": 3,
    }

    # Valid data: construction succeeds and the study_type is kept.
    valid_student = StudentFeatures(**common_fields, study_type="Group")
    assert valid_student.study_type == "Group"

    # Invalid study_type: construction must raise ValueError.
    with pytest.raises(ValueError):
        StudentFeatures(**common_fields, study_type="Invalid")
def test_tool_function_mock(sample_deps: AgentDeps):
    """Exercise the mock setup for the tool function's underlying inference.

    NOTE(review): the final assertion only reads back the value configured on
    the mock; no tool or inference code is invoked by this test.
    """
    predict_patch = patch("src.agent_app.predict_pass_prob")
    loader_patch = patch("src.infer.load_model")
    model_patch = patch("src.infer._MODEL")

    with predict_patch as mock_predict:
        mock_predict.return_value = 0.85
        # The tool is async, so the underlying function is targeted directly.
        with loader_patch, model_patch as mock_model:
            mock_model.predict_proba.return_value = [[0.15, 0.85]]
            # Only verifies that the mock was configured correctly.
            assert mock_predict.return_value == 0.85
def test_agent_structure():
    """Check that study_advisor exists and exposes its run entry points."""
    assert study_advisor is not None
    for method_name in ("run", "run_sync"):
        assert hasattr(study_advisor, method_name)