CourseDesign/tests/test_data.py
2026-01-09 14:30:23 +08:00

112 lines
3.0 KiB
Python

"""数据模块测试
测试 Polars 数据生成、Pandera 校验和预处理功能。
"""
import polars as pl
import pytest
from src.data import (
CleanStudentDataSchema,
RawStudentDataSchema,
generate_data,
get_feature_columns,
preprocess_data,
validate_clean_data,
validate_raw_data,
)
def test_generate_data_structure():
    """Generated data is a Polars DataFrame with the requested rows and columns."""
    frame = generate_data(n_samples=50)

    assert isinstance(frame, pl.DataFrame)
    assert len(frame) == 50

    # Every column the rest of the pipeline relies on must be present.
    required = (
        "study_hours",
        "sleep_hours",
        "attendance_rate",
        "study_type",
        "stress_level",
        "is_pass",
    )
    absent = [name for name in required if name not in frame.columns]
    assert not absent
def test_generate_data_content_range():
    """Generated column values stay within the bounds asserted below."""
    frame = generate_data(n_samples=50)

    study = frame["study_hours"]
    assert study.min() >= 0
    assert study.max() <= 20

    sleep = frame["sleep_hours"]
    assert sleep.min() >= 0

    stress = frame["stress_level"]
    assert stress.min() >= 1
    assert stress.max() <= 5

    # The target column may only hold the values 0 and 1.
    label_is_binary = frame["is_pass"].is_in([0, 1])
    assert label_is_binary.all()
def test_generate_data_missing_values():
    """Generated data actually contains missing values in ``attendance_rate``.

    The previous assertion (``null_count >= 0``) could never fail because a
    count is never negative, so the test verified nothing. It now requires at
    least one null, and fewer nulls than rows.
    """
    df = generate_data(n_samples=500, random_seed=42)

    # Per the original note, each attendance_rate value is null with ~5%
    # probability (assumption about generate_data -- confirm against its
    # implementation). With 500 samples, zero nulls is then practically
    # impossible.
    null_count = df["attendance_rate"].null_count()
    assert 0 < null_count < len(df)
def test_validate_raw_data():
    """Raw-data schema validation (lenient mode) accepts generated data."""
    raw_frame = generate_data(n_samples=50)

    # Validation is expected to succeed even if missing values are present.
    result = validate_raw_data(raw_frame)

    assert isinstance(result, pl.DataFrame)
def test_validate_clean_data():
    """Clean-data schema validation (strict mode) accepts null-free data."""
    # Drop rows with nulls first, then hand the result to the strict validator.
    without_nulls = generate_data(n_samples=50).drop_nulls()

    result = validate_clean_data(without_nulls)

    assert isinstance(result, pl.DataFrame)
def test_preprocess_data_removes_nulls():
    """Preprocessing leaves no nulls in ``attendance_rate`` and drops those rows.

    ``null_before`` was previously computed but never used. It now bounds the
    cleaned row count, so the test checks that rows holding nulls were actually
    removed rather than merely that the frame did not grow.
    """
    df = generate_data(n_samples=500, random_seed=42)
    null_before = df["attendance_rate"].null_count()

    df_clean = preprocess_data(df, validate=True)
    null_after = df_clean["attendance_rate"].null_count()

    assert null_after == 0
    # Assumes preprocess_data drops rows with nulls (as the test name states)
    # rather than imputing them -- confirm against its implementation. Each
    # null attendance_rate row must then be gone; other cleaning steps may
    # remove more rows, hence "<=".
    assert len(df_clean) <= len(df) - null_before
def test_preprocess_data_removes_duplicates():
    """Preprocessing collapses fully identical rows into one."""
    # Rows at index 1 and 2 are identical in every column, so the four input
    # rows contain exactly three distinct records.
    columns = {
        "study_hours": [1.0, 2.0, 2.0, 3.0],
        "sleep_hours": [7.0, 7.0, 7.0, 7.0],
        "attendance_rate": [0.8, 0.8, 0.8, 0.8],
        "stress_level": [1, 2, 2, 3],
        "study_type": ["Self", "Self", "Self", "Self"],
        "is_pass": [0, 1, 1, 1],
    }
    with_duplicates = pl.DataFrame(columns)

    deduplicated = preprocess_data(with_duplicates, validate=True)

    remaining_rows = len(deduplicated)
    assert remaining_rows == 3
def test_get_feature_columns():
    """Feature-column lookup returns numeric and categorical name collections."""
    feature_groups = get_feature_columns()
    # The call yields a pair: numeric features first, categorical second.
    numeric_features, categorical_features = feature_groups

    assert "study_hours" in numeric_features
    assert "study_type" in categorical_features
def test_schema_classes_exist():
    """Both schema classes are importable and not None."""
    for schema in (RawStudentDataSchema, CleanStudentDataSchema):
        assert schema is not None