30 lines
824 B
Python
30 lines
824 B
Python
|
|
import os
|
||
|
|
import joblib
|
||
|
|
import pytest
|
||
|
|
from src.train import train, MODEL_PATH
|
||
|
|
from src.infer import load_model, predict_pass_prob
|
||
|
|
|
||
|
|
def test_train_creates_model():
|
||
|
|
# 确保模型不存在或被覆盖
|
||
|
|
if os.path.exists(MODEL_PATH):
|
||
|
|
os.remove(MODEL_PATH)
|
||
|
|
|
||
|
|
train()
|
||
|
|
assert os.path.exists(MODEL_PATH)
|
||
|
|
|
||
|
|
model = joblib.load(MODEL_PATH)
|
||
|
|
assert model is not None
|
||
|
|
|
||
|
|
def test_inference():
|
||
|
|
# 确保模型存在
|
||
|
|
if not os.path.exists(MODEL_PATH):
|
||
|
|
train()
|
||
|
|
|
||
|
|
# 高概率情况 (大量学习/睡眠/出勤 + Group学习 + 低压力)
|
||
|
|
prob_high = predict_pass_prob(15, 8, 1.0, 1, "Group")
|
||
|
|
assert prob_high > 0.5
|
||
|
|
|
||
|
|
# 低概率情况 (不学习/不睡/缺勤 + 在线 + 高压力)
|
||
|
|
prob_low = predict_pass_prob(0, 3, 0.0, 5, "Online")
|
||
|
|
assert prob_low < 0.5
|