77 lines
2.0 KiB
Python
77 lines
2.0 KiB
Python
import polars as pl
|
||
import pandas as pd
|
||
from pathlib import Path
|
||
from typing import Tuple
|
||
|
||
|
||
def load_data(file_path: str) -> pl.DataFrame:
    """Read the CSV at *file_path* into a polars DataFrame.

    The file is decoded as latin-1 rather than UTF-8 (polars decodes
    non-utf8 encodings in Python before parsing). Rows polars cannot parse
    are tolerated via ``ignore_errors=True``, and the first row is used as
    the header.

    Args:
        file_path: Path of the CSV file to read.

    Returns:
        The parsed DataFrame.
    """
    return pl.read_csv(
        file_path,
        has_header=True,
        encoding="latin-1",  # source file is not valid UTF-8
        ignore_errors=True,
    )
|
||
|
||
|
||
def clean_data(df: pl.DataFrame) -> pl.DataFrame:
    """Reduce the raw dataset to its label and message columns.

    Prints the shape and column names before and after cleaning, plus the
    label distribution, as a quick sanity check.

    Args:
        df: Raw frame as returned by ``load_data``. It must contain a ``v1``
            column (the label) and a ``v2`` column (the message text).

    Returns:
        A frame with exactly two columns, ``label`` and ``message``.

    Raises:
        polars raises a column-not-found error if ``v1`` or ``v2`` is absent.
    """
    # Inspect the raw dataset
    print("原始数据集形状:", df.shape)
    print("原始数据集列名:", df.columns)

    # Keep the two known columns by name instead of dropping the last three
    # by position. A positional drop silently discards real data whenever the
    # file does not have exactly five columns; an already-cleaned two-column
    # CSV, for example, would lose both columns. For the expected 5-column
    # layout (v1, v2 + three trailing extras) the result is the same.
    df = df.select(["v1", "v2"]).rename({
        "v1": "label",
        "v2": "message",
    })

    # Inspect the cleaned dataset
    print("清洗后数据集形状:", df.shape)
    print("清洗后数据集列名:", df.columns)
    print("标签分布:", df["label"].value_counts())

    return df
|
||
|
||
|
||
def preprocess_data(df: pl.DataFrame) -> Tuple[pl.DataFrame, pl.Series]:
    """Encode the label numerically and split features from the target.

    The ``label`` column becomes 1 where it equals ``"spam"`` and 0 otherwise.
    Any other value, including nulls, is encoded as 0.

    Args:
        df: Frame containing a ``label`` column (e.g. from ``clean_data``).

    Returns:
        A pair ``(features, target)``. ``features`` holds every column except
        ``label``, and ``target`` is the encoded ``label`` Series.
    """
    spam_flag = (
        pl.when(pl.col("label") == "spam")
        .then(1)
        .otherwise(0)
        .alias("label")
    )
    encoded = df.with_columns(spam_flag)

    features = encoded.select(pl.exclude("label"))
    target = encoded.get_column("label")
    return features, target
|
||
|
||
|
||
def save_data(df: pl.DataFrame, file_path: str) -> None:
    """Write *df* to *file_path* as CSV and print where it was saved.

    Args:
        df: The frame to persist.
        file_path: Destination path of the CSV file.
    """
    # polars never writes a row index, and ``DataFrame.write_csv`` has no
    # ``index`` parameter. The pandas-style ``index=False`` made this call
    # raise TypeError, so it is omitted.
    df.write_csv(file_path)
    print(f"数据集已保存到: {file_path}")
|
||
|
||
|
||
if __name__ == "__main__":
    # Smoke-test the load -> clean -> preprocess pipeline.
    import os

    # Prefer the CSV one directory up; otherwise fall back to the current
    # directory (not itself checked -- load_data will fail if it is missing).
    preferred_path = "../spam.csv"
    file_path = preferred_path if os.path.exists(preferred_path) else "./spam.csv"

    raw = load_data(file_path)
    cleaned = clean_data(raw)
    features, labels = preprocess_data(cleaned)

    print("特征数据形状:", features.shape)
    print("标签数据形状:", labels.shape)
    print("前5行数据:")
    print(cleaned.head())
|