77 lines
2.0 KiB
Python
77 lines
2.0 KiB
Python
|
|
import polars as pl
|
|||
|
|
import pandas as pd
|
|||
|
|
from pathlib import Path
|
|||
|
|
from typing import Tuple
|
|||
|
|
|
|||
|
|
|
|||
|
|
def load_data(file_path: str) -> pl.DataFrame:
    """Read the CSV file at *file_path* into a Polars DataFrame.

    The file is decoded as Latin-1 and its first row is used as the header.
    ``ignore_errors=True`` is passed so Polars keeps reading past lines that
    yield parse errors instead of aborting the whole read.
    """
    csv_options = {
        "encoding": "latin-1",
        "ignore_errors": True,
        "has_header": True,
    }
    return pl.read_csv(file_path, **csv_options)
|
|||
|
|
|
|||
|
|
|
|||
|
|
def clean_data(df: pl.DataFrame) -> pl.DataFrame:
    """Clean the raw SMS spam dataset.

    Keeps only the ``v1`` and ``v2`` columns and renames them to ``label`` and
    ``message``. Shape, column names and (after cleaning) the label
    distribution are printed before and after for inspection.

    Args:
        df: Raw frame as returned by ``load_data``; must contain the columns
            ``v1`` and ``v2``.

    Returns:
        A new frame with exactly the columns ``label`` and ``message``.

    Raises:
        A Polars column-not-found error if ``v1`` or ``v2`` is missing.
    """
    # Summary of the raw input.
    print("原始数据集形状:", df.shape)
    print("原始数据集列名:", df.columns)

    # Select the needed columns by name rather than dropping the last three
    # by position. A positional drop would remove v1/v2 themselves if the CSV
    # lacks the trailing filler columns, and would keep junk columns if the
    # file has more than three of them.
    df = df.select(["v1", "v2"]).rename({
        "v1": "label",
        "v2": "message",
    })

    # Summary of the cleaned result.
    print("清洗后数据集形状:", df.shape)
    print("清洗后数据集列名:", df.columns)
    print("标签分布:", df["label"].value_counts())

    return df
|
|||
|
|
|
|||
|
|
|
|||
|
|
def preprocess_data(df: pl.DataFrame) -> Tuple[pl.DataFrame, pl.Series]:
    """Encode the label numerically and split the frame into features/target.

    ``label`` values equal to ``"spam"`` are encoded as 1; every other value
    (``"ham"``, anything unexpected, or null) is encoded as 0.

    Args:
        df: Cleaned frame containing a ``label`` column.

    Returns:
        ``(X, y)``: ``X`` is the frame without ``label``; ``y`` is the encoded
        ``label`` Series.
    """
    spam_flag = pl.when(pl.col("label") == "spam").then(1).otherwise(0)
    encoded = df.with_columns(spam_flag.alias("label"))

    labels = encoded["label"]
    features = encoded.drop("label")
    return features, labels
|
|||
|
|
|
|||
|
|
|
|||
|
|
def save_data(df: pl.DataFrame, file_path: str) -> None:
    """Write *df* to *file_path* as CSV and print where it was saved.

    Args:
        df: Frame to write.
        file_path: Destination path of the CSV file.
    """
    # Polars DataFrames have no index, so there is nothing to suppress.
    # ``write_csv`` does not accept pandas' ``index=`` keyword; passing it
    # raised TypeError on every call.
    df.write_csv(file_path)
    print(f"数据集已保存到: {file_path}")
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Smoke-test the load -> clean -> preprocess pipeline.
    # Look for spam.csv in the parent directory first, then in the current
    # one (both relative to the current working directory).
    candidates = [Path("../spam.csv"), Path("./spam.csv")]
    data_path = next((p for p in candidates if p.exists()), None)
    if data_path is None:
        # Fail early with the paths that were tried, instead of letting the
        # CSV reader fail on an unchecked fallback path.
        raise FileNotFoundError(
            "spam.csv not found; tried: "
            + ", ".join(str(p) for p in candidates)
        )
    file_path = str(data_path)

    df = load_data(file_path)
    df_cleaned = clean_data(df)
    X, y = preprocess_data(df_cleaned)

    print("特征数据形状:", X.shape)
    print("标签数据形状:", y.shape)
    print("前5行数据:")
    print(df_cleaned.head())
|