student-score-prediction/tests/test_train.py

46 lines
1.2 KiB
Python
Raw Permalink Normal View History

2026-01-12 11:06:20 +08:00
"""训练模块测试"""
from unittest.mock import patch
import pytest
from sklearn.pipeline import Pipeline
from src.train import get_pipeline, train
def test_get_pipeline_structure():
    """Check the structure of the pipeline built for the "rf" model key.

    The result must be a scikit-learn ``Pipeline`` whose named steps
    include both "preprocessor" and "classifier".
    """
    rf_pipeline = get_pipeline("rf")
    assert isinstance(rf_pipeline, Pipeline)

    step_names = rf_pipeline.named_steps
    for required_step in ("preprocessor", "classifier"):
        assert required_step in step_names
def test_get_pipeline_lr():
    """Check that the "lr" (logistic regression) model key yields a Pipeline."""
    assert isinstance(get_pipeline("lr"), Pipeline)
def test_train_function_runs(tmp_path):
    """Check that ``train()`` completes and the model file exists afterwards.

    ``src.train.MODELS_DIR`` and ``src.train.MODEL_PATH`` are redirected
    into pytest's ``tmp_path`` for the duration of the call, and
    ``src.train.generate_data`` is replaced by a mock that returns a small
    20-sample dataset produced by the real ``src.data.generate_data``.
    """
    artifacts_dir = tmp_path / "models"
    saved_model = artifacts_dir / "model.pkl"

    with (
        patch("src.train.MODELS_DIR", artifacts_dir),
        patch("src.train.MODEL_PATH", saved_model),
        patch("src.train.generate_data") as fake_generate,
    ):
        # Imported from src.data, so this is the real generator rather than
        # the mock installed on src.train above.
        from src.data import generate_data

        fake_generate.return_value = generate_data(n_samples=20)

        try:
            train()
        except Exception as exc:
            # Report any training error as an explicit test failure.
            pytest.fail(f"Train function failed: {exc}")
        else:
            # Training finished: the file must exist at the patched path.
            assert saved_model.exists()