49 lines
1.7 KiB
Python
49 lines
1.7 KiB
Python
import sys
|
|
import os
|
|
import pytest
|
|
from sklearn.pipeline import Pipeline
|
|
from unittest.mock import patch, MagicMock
|
|
|
|
# Put the current working directory on the import path so that the
# `src.*` package imports below resolve. NOTE(review): this only works when
# pytest is launched from the project root — confirm that is the convention.
_project_root = os.getcwd()
sys.path.append(_project_root)
|
|
|
|
from src.train import get_pipeline, train
|
|
|
|
def test_get_pipeline_structure():
    """Check that ``get_pipeline("rf")`` builds a scikit-learn ``Pipeline``
    exposing both a ``"preprocessor"`` and a ``"classifier"`` named step."""
    rf_pipe = get_pipeline("rf")

    # Type check first: ``named_steps`` is only meaningful on a Pipeline.
    assert isinstance(rf_pipe, Pipeline)

    step_names = rf_pipe.named_steps
    for expected_step in ("preprocessor", "classifier"):
        assert expected_step in step_names
|
|
|
|
def test_train_function_runs(tmp_path):
    """Smoke-test that ``train()`` completes and writes a model file.

    ``src.train.generate_data`` is patched to return a tiny (10-row) dataset
    so the test runs quickly. The ``MODELS_DIR`` / ``MODEL_PATH`` constants
    are redirected into pytest's ``tmp_path``, so nothing is written into the
    real project tree.

    Args:
        tmp_path: pytest's per-test temporary directory fixture.
    """
    models_dir = tmp_path / "models"
    model_path = models_dir / "model.pkl"

    # Build the small fixture from the real generator so it has exactly the
    # columns the pipeline expects. It is imported from src.data, which the
    # patch of src.train.generate_data below does not affect.
    from src.data import generate_data

    small_df = generate_data(n_samples=10)

    # src/train.py builds paths with os.path.join(MODELS_DIR, ...), so the
    # module-level constants are patched with plain strings, not Path objects.
    with patch("src.train.MODELS_DIR", str(models_dir)), \
            patch("src.train.MODEL_PATH", str(model_path)), \
            patch("src.train.generate_data", return_value=small_df):
        # Let any exception propagate: pytest reports it as a failure with
        # the full traceback, which a pytest.fail(f"...{e}") wrapper discards.
        train()

    assert model_path.exists(), f"train() did not write a model to {model_path}"
|