import os import joblib from src.infer import predict_pass_prob from src.train import MODEL_PATH, train def test_train_creates_model(): # 确保模型不存在或被覆盖 if os.path.exists(MODEL_PATH): os.remove(MODEL_PATH) train() assert os.path.exists(MODEL_PATH) model = joblib.load(MODEL_PATH) assert model is not None def test_inference(): # 确保模型存在 if not os.path.exists(MODEL_PATH): train() # 高概率情况 (大量学习/睡眠/出勤 + Group学习 + 低压力) prob_high = predict_pass_prob(15, 8, 1.0, 1, "Group") assert prob_high > 0.5 # 低概率情况 (不学习/不睡/缺勤 + 在线 + 高压力) prob_low = predict_pass_prob(0, 3, 0.0, 5, "Online") assert prob_low < 0.5