import polars as pl
import numpy as np
from typing import Any, Tuple, Dict, List, Optional
import logging
from pathlib import Path

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class CreditCardDataProcessor:
    """Load, validate, chronologically split and vectorise a credit-card CSV.

    Intended call order: ``load_data`` -> ``validate_data`` ->
    ``split_data_by_time`` -> ``prepare_features_labels``. Methods that only
    need the raw data load it on demand; ``prepare_features_labels`` requires
    a prior split.

    The CSV must contain a ``Time`` column (read as Float64) and a ``Class``
    label column, where the code counts ``Class == 1`` as fraud and
    ``Class == 0`` as non-fraud.
    """

    def __init__(self, file_path: str):
        self.file_path = file_path
        # Full dataset as read from ``file_path``.
        self.data: Optional[pl.DataFrame] = None
        # Chronological partitions produced by split_data_by_time().
        self.train_data: Optional[pl.DataFrame] = None
        self.test_data: Optional[pl.DataFrame] = None
        # NumPy arrays produced by prepare_features_labels().
        self.train_features: Optional[np.ndarray] = None
        self.train_labels: Optional[np.ndarray] = None
        self.test_features: Optional[np.ndarray] = None
        self.test_labels: Optional[np.ndarray] = None

    def _ensure_loaded(self) -> pl.DataFrame:
        """Return ``self.data``, reading it from ``file_path`` first if unset."""
        if self.data is None:
            self.load_data()
        return self.data

    @staticmethod
    def _class_counts(data: pl.DataFrame) -> Tuple[int, int]:
        """Return ``(fraud_count, normal_count)``: rows with Class 1 and Class 0."""
        fraud_count = data.filter(pl.col("Class") == 1).height
        normal_count = data.filter(pl.col("Class") == 0).height
        return fraud_count, normal_count

    def load_data(self) -> None:
        """Read the CSV into ``self.data`` and log its shape and class counts.

        Raises:
            Exception: anything raised while reading or counting is logged and
                re-raised unchanged.
        """
        logger.info(f"加载数据集: {self.file_path}")
        try:
            self.data = pl.read_csv(
                self.file_path, schema_overrides={"Time": pl.Float64}
            )
            logger.info(f"数据集加载成功,形状: {self.data.shape}")
            fraud_count, normal_count = self._class_counts(self.data)
            logger.info(f"欺诈交易数量: {fraud_count}, 非欺诈交易数量: {normal_count}")
        except Exception as e:
            logger.error(f"加载数据失败: {e}")
            raise

    def validate_data(self) -> None:
        """Log the total number of null cells and the per-class row counts."""
        logger.info("开始数据验证...")
        data = self._ensure_loaded()
        missing_values = data.null_count()
        total_missing = missing_values.sum_horizontal().item()
        if total_missing > 0:
            logger.warning(f"发现缺失值: {total_missing} 个")
        else:
            logger.info("无缺失值,数据完整性良好")

        # group_by output order is not guaranteed, so sort for a stable log line;
        # as_series=False yields plain lists instead of multi-line Series reprs.
        class_dist = (
            data.group_by("Class")
            .agg(pl.len().alias("count"))
            .sort("Class")
            .to_dict(as_series=False)
        )
        logger.info(f"标签分布: {class_dist}")

    def split_data_by_time(
        self, test_ratio: float = 0.2
    ) -> Tuple[pl.DataFrame, pl.DataFrame]:
        """Split by ``Time``: the earliest rows train, the latest rows test.

        Args:
            test_ratio: fraction of rows (in ``Time`` order) assigned to the
                test set; must lie strictly between 0 and 1.

        Returns:
            ``(train_data, test_data)``; both are also stored on the instance.

        Raises:
            ValueError: if ``test_ratio`` is outside (0, 1), or if the dataset
                is too small for both partitions to be non-empty.
        """
        # `not 0 < x < 1` also rejects NaN.
        if not 0 < test_ratio < 1:
            raise ValueError(f"test_ratio 必须在 (0, 1) 之间, 当前值: {test_ratio}")
        logger.info(f"按照时间顺序划分数据集,测试集比例: {test_ratio}")
        data = self._ensure_loaded()

        # maintain_order keeps rows with identical timestamps in file order, so
        # the rows on either side of the split boundary are reproducible.
        sorted_data = data.sort("Time", maintain_order=True)
        split_index = int(sorted_data.height * (1 - test_ratio))
        if split_index == 0 or split_index >= sorted_data.height:
            # An empty partition would make the max/min comparison below fail.
            raise ValueError(
                f"划分后训练集或测试集为空 (总行数: {sorted_data.height}, "
                f"test_ratio: {test_ratio})"
            )

        self.train_data = sorted_data.head(split_index)
        self.test_data = sorted_data.slice(split_index)
        logger.info(f"训练集形状: {self.train_data.shape}, 测试集形状: {self.test_data.shape}")

        train_max_time = self.train_data["Time"].max()
        test_min_time = self.test_data["Time"].min()
        logger.info(f"训练集最大时间: {train_max_time}, 测试集最小时间: {test_min_time}")
        if train_max_time <= test_min_time:
            logger.info("时间划分正确,训练集时间早于测试集")
        else:
            logger.warning("时间划分存在问题,训练集时间晚于测试集")
        return self.train_data, self.test_data

    def prepare_features_labels(
        self, feature_cols: Optional[List[str]] = None, label_col: str = "Class"
    ) -> None:
        """Convert the train/test partitions to NumPy feature and label arrays.

        Args:
            feature_cols: columns to use as features; defaults to every column
                except ``label_col``.
            label_col: name of the label column.

        Raises:
            RuntimeError: if ``split_data_by_time`` has not been called yet.
        """
        if self.train_data is None or self.test_data is None:
            raise RuntimeError("尚未划分数据集,请先调用 split_data_by_time()")
        logger.info("准备特征和标签...")
        if feature_cols is None:
            feature_cols = [col for col in self.train_data.columns if col != label_col]
        logger.info(f"使用的特征列: {feature_cols}")

        self.train_features = self.train_data.select(feature_cols).to_numpy()
        # flatten() turns the (n, 1) selection into a 1-D (writable) copy.
        self.train_labels = self.train_data.select(label_col).to_numpy().flatten()
        self.test_features = self.test_data.select(feature_cols).to_numpy()
        self.test_labels = self.test_data.select(label_col).to_numpy().flatten()

        logger.info(f"训练特征形状: {self.train_features.shape}, 训练标签形状: {self.train_labels.shape}")
        logger.info(f"测试特征形状: {self.test_features.shape}, 测试标签形状: {self.test_labels.shape}")

    def get_statistics(self) -> Dict[str, Any]:
        """Return summary counts for the loaded dataset, loading it if needed.

        The imbalance ratio is non-fraud rows per fraud row. It is ``inf`` when
        there are no fraud rows, and ``nan`` when there are no rows of either
        class.
        """
        data = self._ensure_loaded()
        fraud_count, normal_count = self._class_counts(data)
        if fraud_count:
            imbalance_ratio = normal_count / fraud_count
        else:
            imbalance_ratio = float("inf") if normal_count else float("nan")
        return {
            "总记录数": data.height,
            "特征数": len([col for col in data.columns if col != "Class"]),
            "欺诈交易数": fraud_count,
            "非欺诈交易数": normal_count,
            "不平衡比例": imbalance_ratio,
        }


def load_data(file_path: str = "data/creditcard.csv") -> CreditCardDataProcessor:
    """Run load -> validate -> time split -> feature prep and return the processor."""
    processor = CreditCardDataProcessor(file_path)
    processor.load_data()
    processor.validate_data()
    processor.split_data_by_time()
    processor.prepare_features_labels()
    return processor


if __name__ == "__main__":
    processor = load_data()
    stats = processor.get_statistics()
    print("\n=== 数据集统计信息 ===")
    for key, value in stats.items():
        print(f"{key}: {value}")