233 lines
8.9 KiB
Python
233 lines
8.9 KiB
Python
|
|
import streamlit as st
|
|||
|
|
import pandas as pd
|
|||
|
|
import sys
|
|||
|
|
import os
|
|||
|
|
|
|||
|
|
# Make the project root (the parent of this file's directory) importable so
# that the `src` package imported below can be found.
# Streamlit re-executes this script on every rerun while the `sys` module
# persists, so only add the entry once instead of appending a duplicate on
# every user interaction.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)
|
|||
|
|
|
|||
|
|
from src.agent import agent, get_agent
|
|||
|
|
from src.data import load_data, preprocess_data, split_data
|
|||
|
|
from src.models import train_model, save_model, load_model, compare_models
|
|||
|
|
|
|||
|
|
# Page-level settings; kept ahead of every other st.* call.
_APP_NAME = "垃圾短信分类系统"

st.set_page_config(
    page_title=_APP_NAME,
    page_icon="📱",
    layout="wide",
    initial_sidebar_state="expanded",
)

# Application heading followed by a horizontal rule.
st.title(f"📱 {_APP_NAME}")
st.markdown("---")
|
|||
|
|
|
|||
|
|
# Sidebar: model / output-language selection plus a short system description.
with st.sidebar:
    st.header("系统配置")

    # Model choice. NOTE(review): in the code visible here this value is only
    # consumed by the "retrain model" action; classification goes through
    # `agent.classify_and_explain`, which is not given this value. Confirm
    # whether the agent should follow this selection (e.g. via `get_agent`).
    model_option = st.selectbox(
        "选择模型",
        options=["lightgbm", "logistic_regression"],
        index=0,
        help="选择用于分类的机器学习模型",
    )

    # Output language. NOTE(review): `lang_option` is not read anywhere in the
    # code visible here, so this selector currently has no effect.
    # TODO: wire it into the agent call once the agent's signature is confirmed.
    lang_option = st.selectbox(
        "输出语言",
        options=["中文", "英文"],
        index=0,
        help="选择分类结果和解释的输出语言",
    )

    # System description.
    st.markdown("---")
    st.header("关于系统")
    # Adjacent string literals inside the call's parentheses are concatenated
    # implicitly, so no backslash line continuations are needed.
    st.info(
        "这是一个基于传统机器学习 + LLM + Agent的垃圾短信分类系统。\n"
        "- 使用LightGBM和Logistic Regression进行分类\n"
        "- 利用DeepSeek LLM解释分类结果\n"
        "- 通过Agent实现工具调用和结果整合"
    )
|
|||
|
|
|
|||
|
|
# Main content: two equal-width columns (input on the left, results on the right).
col1, col2 = st.columns([1, 1], gap="large")

with col1:
    # Single-message input.
    st.header("输入短信")

    sms_input = st.text_area(
        "请输入要分类的短信",
        height=200,
        placeholder="例如:WINNER!! As a valued network customer you have been selected to receivea £900 prize reward!",
    )

    # The button is deliberately not disabled on empty input. st.text_area only
    # commits its value on blur / Ctrl+Enter, so a `disabled` flag computed from
    # that value would still be set when the user clicks right after typing,
    # and that first click would be lost. Empty input is reported here instead.
    classify_button = st.button(
        "📊 开始分类",
        type="primary",
        use_container_width=True,
    )
    if classify_button and not sms_input.strip():
        st.warning("请先输入要分类的短信")

    # Batch upload.
    st.markdown("---")
    st.header("批量分类")
    uploaded_file = st.file_uploader(
        "上传CSV文件(包含text列)",
        type=["csv"],
        help="上传包含短信文本的CSV文件,系统将自动分类",
    )
|
|||
|
|
|
|||
|
|
# Optional: retrain and save the model selected in the sidebar.
# Re-entering `with col1:` appends to the left column, below the uploader.
with col1:
    with st.expander("🔧 模型训练", expanded=False):
        if st.button("重新训练模型"):
            with st.spinner("正在训练模型..."):
                try:
                    # Resolve the dataset relative to this file instead of the
                    # process working directory, using the same "<this dir>/.."
                    # project root that is added to sys.path at the top of the file.
                    data_path = os.path.join(
                        os.path.dirname(os.path.abspath(__file__)),
                        "..",
                        "data",
                        "spam.csv",
                    )
                    df = load_data(data_path)
                    processed_df = preprocess_data(df)
                    train_df, _test_df = split_data(processed_df)

                    # Train and persist the selected model type. The second return
                    # value of train_model is not used here.
                    model, _params = train_model(train_df, model_type=model_option)
                    save_model(model, model_option)

                    # NOTE(review): whether the already-imported `agent` reloads the
                    # newly saved model is not visible from this file; confirm.
                    st.success(f"{model_option} 模型训练完成!")
                except Exception as e:
                    st.error(f"模型训练失败:{e}")
|
|||
|
|
|
|||
|
|
# Right-hand column: classification results.
with col2:
    st.header("分类结果")

    # Single-message result, rendered only on the run where the button was
    # clicked with non-blank input.
    if classify_button and sms_input.strip():
        with st.spinner("正在分类..."):
            try:
                # Classify and explain via the agent. Judging by the keys read
                # below, `result` is expected to hold a "classification" section
                # (label, probability mapping) and an "explanation" section.
                result = agent.classify_and_explain(sms_input)
                classification = result["classification"]

                # Label, styled by class.
                st.subheader("📋 分类标签")
                if classification["label"] == "spam":
                    st.error("⚠️ 这是一条**垃圾短信**")
                else:
                    st.success("✅ 这是一条**正常短信**")

                # Class probabilities: the mapping's keys become the chart index.
                st.subheader("📊 分类概率")
                prob_df = pd.DataFrame.from_dict(
                    classification["probability"],
                    orient="index",
                    columns=["概率"],
                )
                st.bar_chart(prob_df)

                # Raw classification payload.
                st.subheader("📝 详细结果")
                with st.expander("查看详细分类结果", expanded=True):
                    st.json(classification, expanded=False)

                # Explanation. Looked up only here so that, if it is missing, the
                # sections above have already been rendered before the error shows.
                explanation = result["explanation"]
                st.subheader("🤔 结果解释")
                with st.expander("查看分类解释", expanded=True):
                    st.write(f"**内容摘要**:{explanation['content_summary']}")
                    st.write(f"**分类原因**:{explanation['classification_reason']}")
                    st.write(
                        f"**可信度**:{explanation['confidence_level']} - "
                        f"{explanation['confidence_explanation']}"
                    )

                # Suggested actions as a numbered list.
                st.subheader("💡 行动建议")
                for number, suggestion in enumerate(explanation["suggestions"], start=1):
                    st.write(f"{number}. {suggestion}")

            except Exception as e:
                st.error(f"分类失败:{e}")
|
|||
|
|
|
|||
|
|
# Batch classification results. Re-entering `with col2:` appends to the right
# column, below the single-message result.
import hashlib

# st.session_state key holding (file digest, results DataFrame) for the most
# recently classified upload.
_BATCH_CACHE_KEY = "batch_classification_cache"

# Column order of the batch results table / downloaded CSV.
_BATCH_RESULT_COLUMNS = [
    "text",
    "label",
    "spam_probability",
    "ham_probability",
    "content_summary",
    "classification_reason",
]


def _classify_batch(texts):
    """Run the agent on every message and tabulate label, probabilities and explanation.

    Args:
        texts: Iterable of message strings.

    Returns:
        pandas.DataFrame with one row per message and the columns listed in
        ``_BATCH_RESULT_COLUMNS`` (present even when ``texts`` is empty).
    """
    rows = []
    for text in texts:
        result = agent.classify_and_explain(text)
        classification = result["classification"]
        explanation = result["explanation"]
        rows.append({
            "text": text,
            "label": classification["label"],
            "spam_probability": classification["probability"]["spam"],
            "ham_probability": classification["probability"]["ham"],
            "content_summary": explanation["content_summary"],
            "classification_reason": explanation["classification_reason"],
        })
    return pd.DataFrame(rows, columns=_BATCH_RESULT_COLUMNS)


def _render_batch_results(results_df):
    """Show per-label counts, the results table and a CSV download button."""
    # Label statistics.
    st.subheader("📊 分类统计")
    label_counts = results_df["label"].value_counts()
    st.bar_chart(label_counts)

    # Results table.
    st.subheader("📋 分类结果")
    st.dataframe(
        results_df,
        use_container_width=True,
        column_config={
            "text": st.column_config.TextColumn("短信内容", width="medium"),
            "label": st.column_config.TextColumn("分类标签"),
            "spam_probability": st.column_config.ProgressColumn(
                "垃圾短信概率",
                format="%.2f",
                min_value=0.0,
                max_value=1.0,
            ),
            "ham_probability": st.column_config.ProgressColumn(
                "正常短信概率",
                format="%.2f",
                min_value=0.0,
                max_value=1.0,
            ),
            "content_summary": st.column_config.TextColumn("内容摘要", width="medium"),
            "classification_reason": st.column_config.TextColumn("分类原因", width="medium"),
        },
    )

    # Download. "utf-8-sig" prepends a BOM so Excel detects UTF-8 and renders
    # non-ASCII text (such as Chinese) correctly; pandas ignores the BOM on read.
    st.subheader("💾 下载结果")
    csv = results_df.to_csv(index=False).encode("utf-8-sig")
    st.download_button(
        label="下载分类结果 (CSV)",
        data=csv,
        file_name="spam_classification_results.csv",
        mime="text/csv",
        use_container_width=True,
    )


with col2:
    if uploaded_file is not None:
        with st.spinner("正在批量分类..."):
            try:
                # Rewind in case the upload buffer has already been read.
                uploaded_file.seek(0)
                df = pd.read_csv(uploaded_file)

                if "text" not in df.columns:
                    st.error("CSV文件必须包含'text'列")
                else:
                    # Limit how many rows are processed per upload.
                    max_rows = 100
                    if len(df) > max_rows:
                        st.warning(f"文件包含 {len(df)} 条记录,仅处理前 {max_rows} 条")
                        df = df.head(max_rows)

                    # Streamlit re-executes the whole script on every interaction
                    # (including clicking the download button). Without this cache,
                    # every agent call would be repeated each time. Results are reused
                    # while the same file content stays uploaded.
                    # NOTE(review): the cache is not invalidated by retraining.
                    digest = hashlib.sha256(uploaded_file.getvalue()).hexdigest()
                    cached = st.session_state.get(_BATCH_CACHE_KEY)
                    if cached is not None and cached[0] == digest:
                        results_df = cached[1]
                    else:
                        # Empty cells are read as NaN; skip them and blank strings
                        # instead of sending them to the agent.
                        texts = [
                            text
                            for text in df["text"].dropna().astype(str)
                            if text.strip()
                        ]
                        results_df = _classify_batch(texts)
                        st.session_state[_BATCH_CACHE_KEY] = (digest, results_df)

                    if results_df.empty:
                        st.warning("CSV文件中没有可分类的短信")
                    else:
                        _render_batch_results(results_df)

            except Exception as e:
                st.error(f"批量分类失败:{e}")
|
|||
|
|
|
|||
|
|
# Footer: a horizontal rule, then a centered copyright line. The line is raw
# HTML, hence unsafe_allow_html=True.
_FOOTER_HTML = "<center>© 2026 垃圾短信分类系统 | 基于传统机器学习 + LLM + Agent</center>"

st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)
|