CourseDesign/tests/test_infer.py

74 lines
1.8 KiB
Python
Raw Normal View History

2026-01-09 14:30:23 +08:00
"""推理模块测试"""
from pathlib import Path
from unittest.mock import patch
2026-01-09 14:30:23 +08:00
import pytest
from src.infer import (
explain_prediction,
predict_pass_prob,
reset_model_cache,
)
@pytest.fixture(scope="module")
def train_dummy_model(tmp_path_factory):
    """Train a small throwaway model and return the path it was saved to.

    Module-scoped: the model is trained once and shared by every test in
    this file. It generates 20 synthetic rows, preprocesses them, fits the
    "rf" pipeline on every column except ``is_pass`` (the label), and dumps
    the fitted pipeline with joblib into a fresh temporary directory.

    Returns:
        Path to the serialized ``model.pkl``.
    """
    target_file = tmp_path_factory.mktemp("models") / "model.pkl"

    import joblib

    from src.data import generate_data, preprocess_data
    from src.train import get_pipeline

    processed = preprocess_data(generate_data(n_samples=20))
    # Convert to pandas before fitting; the preprocessed frame exposes
    # .to_pandas() (presumably a polars DataFrame — confirm in src.data).
    table = processed.to_pandas()
    features, labels = table.drop(columns=["is_pass"]), table["is_pass"]

    model = get_pipeline("rf")
    model.fit(features, labels)
    joblib.dump(model, target_file)
    return target_file
def test_predict_pass_prob(train_dummy_model):
    """Check that predict_pass_prob returns a probability in [0, 1].

    The module-level MODEL_PATH in src.infer is patched to point at the
    temporary model built by the ``train_dummy_model`` fixture.
    """
    # Clear any previously cached model so the patched path is actually used.
    reset_model_cache()
    try:
        with patch("src.infer.MODEL_PATH", train_dummy_model):
            proba = predict_pass_prob(
                study_hours=5.0,
                sleep_hours=7.0,
                attendance_rate=0.9,
                stress_level=3,
                study_type="Self",
            )
    finally:
        # Once the patch is undone, a model loaded from the temporary path may
        # still be cached (that is what reset_model_cache exists for). Clear
        # it so tests that run later do not silently use the dummy model.
        reset_model_cache()
    assert 0.0 <= proba <= 1.0
def test_explain_prediction(train_dummy_model):
    """Check that explain_prediction returns a feature-importance report.

    The result must be a string containing the heading
    "模型特征重要性排名" ("model feature importance ranking"). MODEL_PATH is
    patched to the temporary model from the ``train_dummy_model`` fixture.
    """
    # Clear any previously cached model so the patched path is actually used.
    reset_model_cache()
    try:
        with patch("src.infer.MODEL_PATH", train_dummy_model):
            explanation = explain_prediction()
    finally:
        # Drop whatever was loaded from the temporary path so later tests do
        # not inherit the dummy model after the patch has been undone.
        reset_model_cache()
    assert isinstance(explanation, str)
    assert "模型特征重要性排名" in explanation
def test_load_model_missing():
    """Predicting with a model path that does not exist must raise FileNotFoundError."""
    # Start from an empty cache so the (missing) patched path is really loaded
    # instead of a model cached by an earlier test.
    reset_model_cache()
    missing = Path("non_existent_path/model.pkl")
    with patch("src.infer.MODEL_PATH", missing), pytest.raises(FileNotFoundError):
        predict_pass_prob(1, 1, 1, 1, "Self")