85 lines
2.7 KiB
Python
85 lines
2.7 KiB
Python
|
|
"""数据模块测试"""
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
import pandas as pd
|
||
|
|
from pathlib import Path
|
||
|
|
import sys
|
||
|
|
|
||
|
|
# Put the project root (two directory levels above this file) on the import
# path so the `src` package can be imported below.
_project_root = Path(__file__).parent.parent
sys.path.append(str(_project_root))
|
||
|
|
|
||
|
|
from src.data import DataProcessor, TweetSchema
|
||
|
|
|
||
|
|
|
||
|
|
class TestDataProcessor:
    """Tests for ``DataProcessor`` and ``TweetSchema`` from ``src.data``."""

    def setup_method(self):
        """Create a fresh ``DataProcessor`` before each test.

        The CSV path is anchored to the project root (two levels above this
        file, the same directory this module appends to ``sys.path``) rather
        than to the current working directory, so the tests do not depend on
        where pytest is launched from.
        """
        project_root = Path(__file__).parent.parent
        self.processor = DataProcessor(str(project_root / "data" / "Tweets.csv"))

    def test_abbreviation_dict(self):
        """The abbreviation dictionary maps ``'pls'`` to ``'please'``."""
        abb_dict = self.processor.abb_dict
        assert 'pls' in abb_dict
        assert abb_dict['pls'] == 'please'

    def test_preprocess_text(self):
        """``preprocess_text`` keeps mentions, hashtags and URLs and expands abbreviations."""
        # Basic cleaning.
        text = "Hello @user This is a #test http://example.com"
        processed = self.processor.preprocess_text(text)

        # The @mention, the #hashtag and the URL must all survive cleaning.
        assert "@user" in processed
        assert "#test" in processed
        assert "http://example.com" in processed

        # Abbreviation replacement.
        text_with_abb = "pls help thx"
        processed = self.processor.preprocess_text(text_with_abb)
        assert "please" in processed
        assert "thanks" in processed

    def test_schema_validation(self):
        """A well-formed three-row frame passes ``TweetSchema.validate``."""
        # Build the test data.
        test_data = {
            'tweet_id': [1, 2, 3],
            'airline_sentiment': ['negative', 'neutral', 'positive'],
            'airline_sentiment_confidence': [0.9, 0.8, 0.95],
            'negativereason': ['Late Flight', None, 'Bad Service'],
            'airline': ['united', 'delta', 'american'],
            'text': ['test tweet 1', 'test tweet 2', 'test tweet 3'],
            'tweet_created': ['2023-01-01', '2023-01-02', '2023-01-03'],
        }

        df = pd.DataFrame(test_data)

        # Validation should succeed and return all three rows.
        validated_df = TweetSchema.validate(df)
        assert len(validated_df) == 3

    def test_feature_extraction(self):
        """``extract_features`` adds the text columns and sets the mention/hashtag/URL flags."""
        # Build the test data: one tweet containing a mention, a hashtag and a URL.
        test_data = {
            'text': ['@user test #hashtag http://example.com'],
            'airline': ['united'],
            'airline_sentiment_confidence': [0.8],
            'tweet_created': ['2023-01-01 10:00:00'],
        }

        df = pd.DataFrame(test_data)
        df_processed = self.processor.extract_features(df)

        # Check the extracted features.
        assert 'cleaned_text' in df_processed.columns
        assert 'text_length' in df_processed.columns
        assert 'has_mention' in df_processed.columns
        assert df_processed['has_mention'].iloc[0] == 1
        assert df_processed['has_hashtag'].iloc[0] == 1
        assert df_processed['has_url'].iloc[0] == 1
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # pytest.main() returns the exit code of the test run; raise it as the
    # process exit status so that running this file directly exits non-zero
    # when a test fails.
    raise SystemExit(pytest.main([__file__]))
|