sms-castle-walls/src/data_processing.py

77 lines
2.0 KiB
Python
Raw Normal View History

import polars as pl
import pandas as pd
from pathlib import Path
from typing import Tuple
def load_data(file_path: str) -> pl.DataFrame:
    """Read the raw CSV dataset at *file_path* with Polars.

    Args:
        file_path: Path to the CSV file; its first row is treated as the header.

    Returns:
        The parsed data as a Polars DataFrame.
    """
    # Latin-1 decoding works around encoding problems in the raw file, and
    # ignore_errors keeps Polars reading past rows it cannot parse instead of
    # aborting the whole load.
    return pl.read_csv(
        file_path,
        has_header=True,
        encoding="latin-1",
        ignore_errors=True,
    )
def clean_data(df: pl.DataFrame) -> pl.DataFrame:
    """Reduce the raw frame to a ``label`` / ``message`` column pair.

    Keeps only the ``v1`` and ``v2`` columns, renames them to ``label`` and
    ``message``, and prints shape, column-name and label-distribution
    diagnostics before and after cleaning.

    Args:
        df: Raw frame (e.g. from ``load_data``); must contain ``v1`` and ``v2``.

    Returns:
        A new frame with exactly the columns ``label`` and ``message``.
    """
    # Report the raw dataset's basic information.
    print("原始数据集形状:", df.shape)
    print("原始数据集列名:", df.columns)
    # Select the two meaningful columns by name instead of dropping the last
    # three positionally: the positional drop removed v1/v2 themselves whenever
    # the file had fewer than five columns, which made the rename below fail.
    df = df.select(["v1", "v2"]).rename({"v1": "label", "v2": "message"})
    # Report the cleaned dataset.
    print("清洗后数据集形状:", df.shape)
    print("清洗后数据集列名:", df.columns)
    print("标签分布:", df["label"].value_counts())
    return df
def preprocess_data(df: pl.DataFrame) -> Tuple[pl.DataFrame, pl.Series]:
    """Encode the label column numerically and split features from the target.

    The ``label`` column is replaced by an integer column: 1 where the value
    equals ``"spam"``, 0 for every other value (so ``"ham"`` maps to 0).

    Args:
        df: Frame containing a ``label`` column, e.g. the output of ``clean_data``.

    Returns:
        ``(X, y)``: ``X`` holds every column except ``label``; ``y`` is the
        encoded ``label`` Series.
    """
    # NOTE(review): any value other than "spam" (typos, nulls from malformed
    # rows) is silently encoded as 0 / ham — confirm that is intended.
    is_spam = pl.col("label") == "spam"
    encoded = df.with_columns(pl.when(is_spam).then(1).otherwise(0).alias("label"))
    features = encoded.drop("label")
    target = encoded.get_column("label")
    return features, target
def save_data(df: pl.DataFrame, file_path: str) -> None:
    """Write *df* to *file_path* as CSV and print the destination.

    Args:
        df: Frame to persist.
        file_path: Destination path for the CSV file.
    """
    # Polars' write_csv has no ``index`` parameter (Polars frames carry no row
    # index); passing pandas' ``index=False`` raised a TypeError before writing.
    df.write_csv(file_path)
    print(f"数据集已保存到: {file_path}")
if __name__ == "__main__":
    # Smoke-test the data processing pipeline end to end.
    # Look for the dataset one directory up first, then in the current
    # directory; both candidates are relative to the working directory.
    candidates = (Path("../spam.csv"), Path("./spam.csv"))
    data_path = next((p for p in candidates if p.is_file()), None)
    if data_path is None:
        # Fail early with the searched locations instead of a later read error.
        raise FileNotFoundError(
            "spam.csv not found; looked in: "
            + ", ".join(str(p) for p in candidates)
        )
    df = load_data(str(data_path))
    df_cleaned = clean_data(df)
    X, y = preprocess_data(df_cleaned)
    print("特征数据形状:", X.shape)
    print("标签数据形状:", y.shape)
    print("前5行数据:")
    print(df_cleaned.head())