generated from Python-2026Spring/assignment-05-final-project-template
46 lines
1.2 KiB
Python
46 lines
1.2 KiB
Python
|
|
"""训练模块测试"""
|
||
|
|
|
||
|
|
from unittest.mock import patch
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
from sklearn.pipeline import Pipeline
|
||
|
|
|
||
|
|
from src.train import get_pipeline, train
|
||
|
|
|
||
|
|
|
||
|
|
def test_get_pipeline_structure():
    """Test the Pipeline structure returned for model type ``"rf"``.

    ``get_pipeline("rf")`` must return an sklearn ``Pipeline`` that has a
    step named ``"preprocessor"`` and a step named ``"classifier"``.
    """
    rf_pipeline = get_pipeline("rf")
    assert isinstance(rf_pipeline, Pipeline)

    # Look up the name -> step mapping once and check every required name.
    steps = rf_pipeline.named_steps
    for step_name in ("preprocessor", "classifier"):
        assert step_name in steps
|
||
|
|
|
||
|
|
|
||
|
|
def test_get_pipeline_lr():
    """Test the logistic-regression Pipeline returned for model type ``"lr"``.

    ``get_pipeline("lr")`` must return an sklearn ``Pipeline`` with the same
    ``"preprocessor"`` and ``"classifier"`` named steps that
    ``test_get_pipeline_structure`` requires of the ``"rf"`` pipeline.
    """
    pipeline = get_pipeline("lr")

    assert isinstance(pipeline, Pipeline)
    # Checking only the type would let a mis-wired "lr" branch pass; verify
    # the step names too, consistent with the "rf" structure test.
    assert "preprocessor" in pipeline.named_steps
    assert "classifier" in pipeline.named_steps
|
||
|
|
|
||
|
|
|
||
|
|
def test_train_function_runs(tmp_path):
    """Test that ``train()`` runs end to end and writes the model file.

    ``src.train.MODELS_DIR`` and ``src.train.MODEL_PATH`` are redirected into
    pytest's ``tmp_path``. ``src.train.generate_data`` is replaced by a mock
    that returns a small dataset built with the real ``src.data`` generator,
    which keeps the run small.
    """
    # Import the real generator from its defining module. The patch below
    # targets ``src.train.generate_data`` and leaves ``src.data`` untouched.
    from src.data import generate_data

    # Build the small fixture dataset before any patching is active.
    real_small_df = generate_data(n_samples=20)

    models_dir = tmp_path / "models"
    model_path = models_dir / "model.pkl"

    with (
        patch("src.train.MODELS_DIR", models_dir),
        patch("src.train.MODEL_PATH", model_path),
        patch(
            "src.train.generate_data", return_value=real_small_df
        ) as mock_gen,
    ):
        # Let any exception propagate: pytest then reports the failure with
        # the full traceback instead of a one-line pytest.fail message.
        train()

    # The patched generator must actually have been used; otherwise the
    # return_value above had no effect on what train() trained on.
    mock_gen.assert_called()
    assert model_path.is_file()
|