"""Tests for the data module.

Covers Polars data generation, Pandera schema validation, and preprocessing.
"""

import polars as pl
import pytest

from src.data import (
    CleanStudentDataSchema,
    RawStudentDataSchema,
    generate_data,
    get_feature_columns,
    preprocess_data,
    validate_clean_data,
    validate_raw_data,
)


def test_generate_data_structure():
    """Generated data is a Polars DataFrame with the requested rows and columns."""
    df = generate_data(n_samples=50)

    assert isinstance(df, pl.DataFrame)
    assert len(df) == 50

    expected_cols = {
        "study_hours",
        "sleep_hours",
        "attendance_rate",
        "study_type",
        "stress_level",
        "is_pass",
    }
    # Report every missing column at once instead of stopping at the first one.
    missing = expected_cols - set(df.columns)
    assert not missing, f"missing columns: {sorted(missing)}"


def test_generate_data_content_range():
    """Generated values stay inside their expected ranges."""
    df = generate_data(n_samples=50)

    assert df["study_hours"].min() >= 0
    assert df["study_hours"].max() <= 20
    assert df["sleep_hours"].min() >= 0
    assert df["stress_level"].min() >= 1
    assert df["stress_level"].max() <= 5
    assert df["is_pass"].is_in([0, 1]).all()


def test_generate_data_missing_values():
    """Generated data contains injected missing values in attendance_rate."""
    df = generate_data(n_samples=500, random_seed=42)

    # attendance_rate is expected to be null with ~5% probability per row.
    # With 500 rows the chance of getting zero nulls is ~0.95**500 (about
    # 7e-12), so a null-free column means the injection is broken. The old
    # check `null_count >= 0` was a tautology and could never fail.
    null_count = df["attendance_rate"].null_count()
    assert null_count > 0
    # The column must not be entirely null either.
    assert null_count < len(df)


def test_validate_raw_data():
    """The lenient raw schema accepts generated data even with missing values."""
    df = generate_data(n_samples=50)

    validated = validate_raw_data(df)
    assert isinstance(validated, pl.DataFrame)


def test_validate_clean_data():
    """The strict clean schema accepts generated data once nulls are dropped."""
    df_clean = generate_data(n_samples=50).drop_nulls()

    validated = validate_clean_data(df_clean)
    assert isinstance(validated, pl.DataFrame)


def test_preprocess_data_removes_nulls():
    """Preprocessing drops rows with a missing attendance_rate."""
    df = generate_data(n_samples=500, random_seed=42)
    null_before = df["attendance_rate"].null_count()
    # Guard against a vacuous pass: the input must actually contain nulls,
    # otherwise the assertions below would not exercise null removal at all.
    assert null_before > 0

    df_clean = preprocess_data(df, validate=True)

    assert df_clean["attendance_rate"].null_count() == 0
    # Every row that had a null must be gone; other cleaning steps
    # (e.g. de-duplication) may only remove additional rows.
    assert len(df_clean) <= len(df) - null_before


def test_preprocess_data_removes_duplicates():
    """Preprocessing removes exact duplicate rows."""
    df = pl.DataFrame(
        {
            "study_hours": [1.0, 2.0, 2.0, 3.0],
            "sleep_hours": [7.0, 7.0, 7.0, 7.0],
            "attendance_rate": [0.8, 0.8, 0.8, 0.8],
            "stress_level": [1, 2, 2, 3],
            "study_type": ["Self", "Self", "Self", "Self"],
            "is_pass": [0, 1, 1, 1],
        }
    )

    clean_df = preprocess_data(df, validate=True)

    # Rows at index 1 and 2 are identical, so exactly one of them is dropped.
    assert len(clean_df) == 3
    assert not clean_df.is_duplicated().any()


def test_get_feature_columns():
    """Feature columns are split into numeric and categorical groups."""
    num_feats, cat_feats = get_feature_columns()

    assert "study_hours" in num_feats
    assert "study_type" in cat_feats


def test_schema_classes_exist():
    """The raw and clean schema classes are importable."""
    assert RawStudentDataSchema is not None
    assert CleanStudentDataSchema is not None