74 lines
1.8 KiB
Python
74 lines
1.8 KiB
Python
"""推理模块测试"""
|
|
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
|
|
from src.infer import (
|
|
explain_prediction,
|
|
predict_pass_prob,
|
|
reset_model_cache,
|
|
)
|
|
|
|
|
|
@pytest.fixture(scope="module")
def train_dummy_model(tmp_path_factory):
    """Fit a small pipeline on synthetic data and pickle it for the inference tests.

    The fixture is module-scoped, so the model is trained once and shared by
    every test in this file.

    Returns:
        Path to ``model.pkl`` inside a fresh temporary ``models`` directory.
    """
    target_path = tmp_path_factory.mktemp("models") / "model.pkl"

    # Imported locally so these modules are only needed when the fixture runs.
    import joblib

    from src.data import generate_data, preprocess_data
    from src.train import get_pipeline

    # Tiny synthetic dataset, converted to a pandas DataFrame before fitting.
    # NOTE(review): with only 20 rows, confirm generate_data always yields
    # both is_pass classes; a single-class fit could break probability output.
    frame = preprocess_data(generate_data(n_samples=20)).to_pandas()
    features = frame.drop(columns=["is_pass"])
    labels = frame["is_pass"]

    # "rf" selects the pipeline variant from src.train (presumably random forest).
    model = get_pipeline("rf")
    model.fit(features, labels)

    joblib.dump(model, target_path)
    return target_path
|
|
|
|
|
|
def test_predict_pass_prob(train_dummy_model):
    """Check that predict_pass_prob returns a probability in [0, 1].

    The model cache is cleared before the call so the patched MODEL_PATH is
    the one consulted. It is cleared again afterwards so the dummy model
    loaded from the temporary directory does not stay cached for later tests.
    """
    reset_model_cache()
    try:
        with patch("src.infer.MODEL_PATH", train_dummy_model):
            proba = predict_pass_prob(
                study_hours=5.0,
                sleep_hours=7.0,
                attendance_rate=0.9,
                stress_level=3,
                study_type="Self",
            )
    finally:
        # Without this, anything that later calls into src.infer could keep
        # using the 20-sample dummy model (assuming the cache holds it).
        reset_model_cache()

    assert 0.0 <= proba <= 1.0
|
|
|
|
|
|
def test_explain_prediction(train_dummy_model):
    """Check that explain_prediction returns the feature-importance report text.

    The model cache is cleared before the call so the patched MODEL_PATH is
    used, and cleared again afterwards so the dummy model does not leak into
    later tests.
    """
    reset_model_cache()
    try:
        with patch("src.infer.MODEL_PATH", train_dummy_model):
            explanation = explain_prediction()
    finally:
        # Drop the model loaded from the temporary path, whatever the outcome.
        reset_model_cache()

    assert isinstance(explanation, str)
    # The report heading reads "model feature importance ranking".
    assert "模型特征重要性排名" in explanation
|
|
|
|
|
|
def test_load_model_missing(tmp_path: Path):
    """Check that a missing model file surfaces as FileNotFoundError.

    The missing path is built under pytest's per-test ``tmp_path``. That
    guarantees it does not exist, independent of the current working
    directory, which the previous relative path depended on.
    """
    missing_model = tmp_path / "non_existent_path" / "model.pkl"

    reset_model_cache()
    try:
        with patch("src.infer.MODEL_PATH", missing_model), pytest.raises(
            FileNotFoundError
        ):
            predict_pass_prob(
                study_hours=1.0,
                sleep_hours=1.0,
                attendance_rate=1.0,
                stress_level=1,
                study_type="Self",
            )
    finally:
        # Leave the cache clean for subsequent tests.
        reset_model_cache()
|