63 lines
1.9 KiB
Python
63 lines
1.9 KiB
Python
|
|
import pytest
|
||
|
|
import numpy as np
|
||
|
|
from src.data import CreditCardDataProcessor, load_data
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
|
||
|
|
def test_data_processor_initialization():
    """A freshly constructed processor keeps its path and has no data yet."""
    csv_path = "data/creditcard.csv"
    fresh = CreditCardDataProcessor(csv_path)

    assert fresh.file_path == csv_path
    # The constructor alone must leave `data` unset.
    assert fresh.data is None
|
||
|
|
|
||
|
|
|
||
|
|
def test_load_data():
    """load_data() populates a non-empty table that contains a "Class" column."""
    proc = CreditCardDataProcessor("data/creditcard.csv")
    proc.load_data()

    frame = proc.data
    assert frame is not None
    assert frame.height > 0
    assert "Class" in frame.columns
|
||
|
|
|
||
|
|
|
||
|
|
def test_validate_data():
    """validate_data() runs on freshly loaded data and leaves `data` populated."""
    proc = CreditCardDataProcessor("data/creditcard.csv")
    proc.load_data()

    # The call itself is the main check: any exception it raises fails the test.
    proc.validate_data()

    assert proc.data is not None
|
||
|
|
|
||
|
|
|
||
|
|
def test_split_data_by_time():
    """split_data_by_time() returns a non-empty test split smaller than the train split.

    Uses ``test_ratio=0.2``, so the train split is expected to hold more rows
    than the test split.
    """
    processor = CreditCardDataProcessor("data/creditcard.csv")
    processor.load_data()

    train_data, test_data = processor.split_data_by_time(test_ratio=0.2)

    assert train_data is not None
    assert test_data is not None
    # Without this check, an empty test split would still satisfy the size
    # comparison below and the test would pass on a broken split.
    assert test_data.height > 0
    assert train_data.height > test_data.height
|
||
|
|
|
||
|
|
|
||
|
|
def test_prepare_features_labels():
    """After load + default split, prepare_features_labels() sets all four outputs."""
    proc = CreditCardDataProcessor("data/creditcard.csv")
    proc.load_data()
    proc.split_data_by_time()
    proc.prepare_features_labels()

    # Checked in the same order as the original explicit assertions.
    for attr_name in ("train_features", "train_labels", "test_features", "test_labels"):
        assert getattr(proc, attr_name) is not None, attr_name
|
||
|
|
|
||
|
|
|
||
|
|
def test_load_data_function():
    """The module-level load_data() helper returns a processor with data and features set."""
    result = load_data("data/creditcard.csv")

    assert result.data is not None
    # Features for both splits should already be prepared by the helper.
    train_x = result.train_features
    assert train_x is not None
    test_x = result.test_features
    assert test_x is not None
|
||
|
|
|
||
|
|
|
||
|
|
def test_get_statistics():
    """get_statistics() reports the expected summary keys and a positive record total."""
    stats = load_data("data/creditcard.csv").get_statistics()

    # Keys (in English): total records, feature count,
    # fraudulent transaction count, non-fraudulent transaction count.
    required_keys = ("总记录数", "特征数", "欺诈交易数", "非欺诈交易数")
    for key in required_keys:
        assert key in stats

    total_records = stats["总记录数"]
    assert total_records > 0
|