text/src/streamlit_app.py

550 lines
23 KiB
Python
Raw Normal View History

import html
import os
import sys

import pandas as pd
import streamlit as st

# Add the project root (the parent of this file's directory) to the Python
# path so that the `src` package imports below resolve.
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

from src.agent import agent, get_agent
from src.data import load_data, preprocess_data, split_data
from src.models import train_model, save_model, load_model, compare_models
# Page configuration: browser-tab title and icon, wide layout, sidebar open.
# The icon must be a real emoji (or an image); the previous value was a
# corrupted replacement character left over from an encoding error.
st.set_page_config(
    page_title="垃圾短信分类器",
    page_icon="🏰",
    layout="wide",
    initial_sidebar_state="expanded",
)
# Custom CSS giving the app its medieval-European look (dark navy surfaces,
# gold borders, crimson buttons, serif fonts). The stylesheet is kept in a
# named constant and injected once as raw HTML.
_MEDIEVAL_THEME_CSS = """<style>
/* 基础样式 */
body {
background-color: #1a1a2e;
color: #e0e0e0;
font-family: 'Georgia', serif;
}
/* 标题样式 */
.stTitle {
color: #d4af37;
font-family: 'Garamond', serif;
text-shadow: 2px 2px 4px rgba(0, 0, 0, 0.8);
border-bottom: 2px solid #d4af37;
padding-bottom: 10px;
}
/* 侧边栏样式 */
.stSidebar {
background-color: #16213e;
border-right: 2px solid #d4af37;
}
/* 卡片和容器 */
.stExpander, .stContainer {
background-color: #0f3460;
border: 1px solid #d4af37;
border-radius: 8px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);
}
/* 按钮样式 */
.stButton > button {
background-color: #c8102e;
color: #ffffff;
border: 2px solid #d4af37;
border-radius: 8px;
font-family: 'Georgia', serif;
font-weight: bold;
padding: 10px 20px;
transition: all 0.3s ease;
}
.stButton > button:hover {
background-color: #8b0000;
color: #d4af37;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);
transform: translateY(-2px);
}
/* 输入框样式 */
.stTextArea, .stSelectbox > div {
background-color: #16213e;
border: 2px solid #d4af37;
border-radius: 8px;
color: #e0e0e0;
font-family: 'Georgia', serif;
}
/* 标题和文本样式 */
h1, h2, h3, h4, h5, h6 {
color: #d4af37;
font-family: 'Garamond', serif;
}
/* 分隔线 */
hr {
border: 1px solid #d4af37;
}
/* 信息卡片 */
.stAlert {
background-color: #0f3460;
border: 2px solid #d4af37;
border-radius: 8px;
color: #e0e0e0;
}
/* 页脚 */
footer {
color: #d4af37;
font-family: 'Georgia', serif;
text-align: center;
padding: 20px;
border-top: 2px solid #d4af37;
margin-top: 40px;
}
</style>"""
st.markdown(_MEDIEVAL_THEME_CSS, unsafe_allow_html=True)
# Application title banner (medieval style): a centred, gold-bordered card
# holding the app name and its tagline, followed by a horizontal rule.
_TITLE_BANNER_HTML = """
<div style="text-align: center; padding: 20px; border: 3px solid #d4af37; border-radius: 10px; background-color: #16213e; box-shadow: 0 8px 16px rgba(0, 0, 0, 0.8);">
<h1 style="color: #d4af37; font-family: 'Garamond', serif; text-shadow: 3px 3px 6px rgba(0, 0, 0, 0.8); margin: 0;">
中世纪垃圾短信分类器
</h1>
<p style="color: #e0e0e0; font-style: italic; margin-top: 10px;">
保护您的通信抵御垃圾信息的入侵
</p>
</div>
"""
st.markdown(_TITLE_BANNER_HTML, unsafe_allow_html=True)
st.markdown("---")
# Sidebar (medieval style): model picker, language picker and an "about" card.
# Sets `model_option` and `lang_option`, which later parts of the script read.
with st.sidebar:
    st.markdown("""
<div style="text-align: center; padding: 10px; border-bottom: 2px solid #d4af37; margin-bottom: 20px;">
<h2 style="color: #d4af37; font-family: 'Garamond', serif; margin: 0;">
🛡 骑士工坊
</h2>
<p style="color: #e0e0e0; font-size: 14px; margin: 5px 0 0;">
系统配置
</p>
</div>
""", unsafe_allow_html=True)

    # Model selection. The visible heading is the HTML card below; the widget
    # itself still gets a real label (hidden via label_visibility), because
    # Streamlit's documentation says never to pass an empty label.
    st.markdown("""
<div style="margin-bottom: 20px;">
<h3 style="color: #d4af37; font-family: 'Garamond', serif; font-size: 18px;">
选择武器
</h3>
<p style="color: #e0e0e0; font-size: 14px; margin: 5px 0 10px;">
选择用于抵御垃圾信息的武器
</p>
</div>
""", unsafe_allow_html=True)
    # Internal model identifier -> themed display name shown in the dropdown.
    model_display_names = {
        "lightgbm": "圣光使者 (LightGBM)",
        "logistic_regression": "智慧之剑 (Logistic Regression)",
    }
    model_option = st.selectbox(
        label="选择武器",
        label_visibility="collapsed",
        options=list(model_display_names),
        index=0,
        format_func=model_display_names.get,
    )

    # Language selection, same hidden-label approach as above.
    st.markdown("""
<div style="margin-bottom: 20px;">
<h3 style="color: #d4af37; font-family: 'Garamond', serif; font-size: 18px;">
📜 选择语言
</h3>
<p style="color: #e0e0e0; font-size: 14px; margin: 5px 0 10px;">
选择预言师的语言
</p>
</div>
""", unsafe_allow_html=True)
    lang_option = st.selectbox(
        label="选择语言",
        label_visibility="collapsed",
        options=["中文", "英文"],
        index=0,
    )

    # System description ("about the castle") card.
    st.markdown("---")
    st.markdown("""
<div style="text-align: center; padding: 10px; border-bottom: 2px solid #d4af37; margin-bottom: 20px;">
<h2 style="color: #d4af37; font-family: 'Garamond', serif; margin: 0;">
🏰 关于城堡
</h2>
</div>
""", unsafe_allow_html=True)
    st.markdown("""
<div style="background-color: #0f3460; border: 2px solid #d4af37; border-radius: 8px; padding: 15px; color: #e0e0e0; font-family: 'Georgia', serif;">
<p><strong>🛡 城堡防御系统</strong></p>
<p>这是一座由现代魔法和古老智慧构建的防御城堡</p>
<ul>
<li>💫 使用圣光使者 (LightGBM) 和智慧之剑 (Logistic Regression) 守护</li>
<li>🧙 由DeepSeek预言师提供智慧解释</li>
<li>🤖 通过魔法使者 (Agent) 整合所有力量</li>
</ul>
<p style="margin-top: 15px; font-size: 14px; font-style: italic;">
保护您的通信不受垃圾信息的侵袭
</p>
</div>
""", unsafe_allow_html=True)
# Main content area (medieval style): two equal-width columns. The left one
# (built here) collects input; it sets `sms_input`, `classify_button` and
# `uploaded_file`, which the results column reads.
col1, col2 = st.columns([1, 1], gap="large")
with col1:
    # Single-message input card.
    st.markdown("""
<div style="background-color: #0f3460; border: 2px solid #d4af37; border-radius: 10px; padding: 20px; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);">
<h2 style="color: #d4af37; font-family: 'Garamond', serif; text-align: center; border-bottom: 1px solid #d4af37; padding-bottom: 10px;">
📜 信件输入
</h2>
<p style="color: #e0e0e0; text-align: center; font-style: italic;">
输入需要检查的信件内容
</p>
</div>
""", unsafe_allow_html=True)
    # The HTML card above is the visible heading; the widget still gets a
    # real (hidden) label because Streamlit's docs say never to pass label="".
    sms_input = st.text_area(
        label="信件内容",
        label_visibility="collapsed",
        height=200,
        placeholder="例如WINNER!! As a valued network customer you have been selected to receivea £900 prize reward!",
        help="输入需要分类的短信内容",
    )

    # Classify button; disabled while the text area is empty or whitespace.
    classify_button = st.button(
        "⚔️ 开始检查",
        type="primary",
        use_container_width=True,
        disabled=not sms_input.strip(),
    )

    # Batch upload card.
    st.markdown("---")
    st.markdown("""
<div style="background-color: #0f3460; border: 2px solid #d4af37; border-radius: 10px; padding: 20px; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);">
<h2 style="color: #d4af37; font-family: 'Garamond', serif; text-align: center; border-bottom: 1px solid #d4af37; padding-bottom: 10px;">
📦 批量检查
</h2>
<p style="color: #e0e0e0; text-align: center; font-style: italic;">
上传多封信件进行批量检查
</p>
</div>
""", unsafe_allow_html=True)
    uploaded_file = st.file_uploader(
        label="批量信件文件",
        label_visibility="collapsed",
        type=["csv"],
        help="上传包含短信文本的CSV文件需要包含text列",
    )

    # Optional model (re)training, tucked into a collapsed expander.
    with st.expander("🔧 锻造武器", expanded=False):
        st.markdown("""
<div style="background-color: #0f3460; border: 1px solid #d4af37; border-radius: 8px; padding: 15px;">
<h3 style="color: #d4af37; font-family: 'Garamond', serif; margin-top: 0;">
铁匠工坊
</h3>
<p style="color: #e0e0e0; font-size: 14px;">
重新锻造您的武器提升防御能力
</p>
</div>
""", unsafe_allow_html=True)
        if st.button("⚒️ 重新锻造武器"):
            with st.spinner("🔨 铁匠正在锻造武器..."):
                try:
                    # Load, preprocess and split the data set.
                    # NOTE(review): this path is relative to the process's
                    # working directory, not to this file — confirm how
                    # load_data resolves it and where the app is launched.
                    df = load_data("../data/spam.csv")
                    processed_df = preprocess_data(df)
                    train_df, test_df = split_data(processed_df)

                    # Train the model chosen in the sidebar and persist it.
                    model, _params = train_model(train_df, model_type=model_option)
                    save_model(model, model_option)

                    weapon_name = (
                        "圣光使者 (LightGBM)"
                        if model_option == "lightgbm"
                        else "智慧之剑 (Logistic Regression)"
                    )
                    st.markdown(f"""
<div style="background-color: #0f3460; border: 2px solid #d4af37; border-radius: 8px; padding: 15px; color: #d4af37; font-weight: bold;">
武器锻造完成
<p style="color: #e0e0e0; font-weight: normal; margin-top: 10px;">
您的 {weapon_name} 已准备好进行战斗
</p>
</div>
""", unsafe_allow_html=True)
                except Exception as e:
                    # UI boundary: report any failure in the page instead of
                    # crashing the app. The message is escaped because it is
                    # rendered with unsafe_allow_html.
                    st.markdown(f"""
<div style="background-color: #0f3460; border: 2px solid #c8102e; border-radius: 8px; padding: 15px; color: #c8102e; font-weight: bold;">
锻造失败
<p style="color: #e0e0e0; font-weight: normal; margin-top: 10px;">
铁匠遇到了问题{html.escape(str(e))}
</p>
</div>
""", unsafe_allow_html=True)
def _section_header(title):
    """Return the HTML of the gold-bordered heading card shown above a result section.

    `title` is inserted verbatim, so pass only trusted, literal text.
    """
    return f"""
<div style="margin-top: 20px; background-color: #0f3460; border: 2px solid #d4af37; border-radius: 10px; padding: 20px; box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5);">
<h3 style="color: #d4af37; font-family: 'Garamond', serif; text-align: center; border-bottom: 1px solid #d4af37; padding-bottom: 10px;">
{title}
</h3>
</div>
"""


def _alert_card(heading, detail, color="#c8102e"):
    """Return the HTML of an error/warning card.

    `heading` is trusted literal text; `detail` may come from exceptions or
    uploaded data and is HTML-escaped. `color` is used for border and heading.
    """
    return f"""
<div style="background-color: #0f3460; border: 2px solid {color}; border-radius: 8px; padding: 15px; color: {color};">
{heading}
<p style="color: #e0e0e0; margin-top: 10px;">
{html.escape(str(detail))}
</p>
</div>
"""


def _verdict_card(is_spam):
    """Return the HTML of the red (spam) or green (ham) verdict card."""
    if is_spam:
        color, verdict, advice = "#c8102e", "垃圾信件", "建议您谨慎对待此信件"
    else:
        color, verdict, advice = "#228B22", "正常信件", "此信件安全可以放心阅读"
    # <strong> rather than Markdown ** markers: Markdown is not parsed inside
    # a raw HTML block, so the asterisks would be shown literally.
    return f"""
<div style="background-color: #0f3460; border: 2px solid {color}; border-radius: 8px; padding: 15px; text-align: center; font-size: 18px; font-weight: bold;">
<span style="color: {color};">这是一封<strong>{verdict}</strong></span>
<p style="font-size: 14px; font-weight: normal; margin-top: 10px; color: #e0e0e0;">
{advice}
</p>
</div>
"""


# Right-hand column: results for the single message and for the batch upload.
with col2:
    st.header("分类结果")

    # --- Single-message result -------------------------------------------
    if classify_button and sms_input.strip():
        with st.spinner("正在分类..."):
            try:
                # The agent returns both the classification and its explanation.
                result = agent.classify_and_explain(sms_input)
                classification = result['classification']

                st.subheader("📋 分类标签")
                st.markdown(
                    _verdict_card(classification['label'] == "spam"),
                    unsafe_allow_html=True,
                )

                # Probability chart. Rows are relabelled by key, not by
                # position, so the captions stay correct whatever order the
                # probability dict is in.
                st.markdown(_section_header("📊 预言概率"), unsafe_allow_html=True)
                if lang_option == '中文':
                    prob_labels = {"spam": "垃圾信件", "ham": "正常信件"}
                else:
                    prob_labels = {"spam": "Spam", "ham": "Ham"}
                prob_df = pd.DataFrame.from_dict(
                    classification['probability'],
                    orient='index',
                    columns=['概率'],
                ).rename(index=prob_labels)
                st.bar_chart(prob_df)

                # Raw classification payload.
                st.markdown(_section_header("🔍 详细预言"), unsafe_allow_html=True)
                with st.expander("查看详细分类结果", expanded=True):
                    st.json(classification, expanded=False)

                # Explanation. Every field is escaped before being placed in
                # markup rendered with unsafe_allow_html.
                st.markdown(_section_header("🧙 预言师的解释"), unsafe_allow_html=True)
                explanation = result['explanation']
                with st.expander("查看预言解释", expanded=True):
                    summary, reason, level, level_note = (
                        html.escape(str(explanation[field]))
                        for field in (
                            'content_summary',
                            'classification_reason',
                            'confidence_level',
                            'confidence_explanation',
                        )
                    )
                    st.markdown(f"""
<div style="background-color: #16213e; border: 1px solid #d4af37; border-radius: 8px; padding: 15px; color: #e0e0e0;">
<p><strong style="color: #d4af37;">📝 内容摘要</strong>{summary}</p>
<p><strong style="color: #d4af37;"> 预言原因</strong>{reason}</p>
<p><strong style="color: #d4af37;">🔮 可信度</strong>{level} - {level_note}</p>
</div>
""", unsafe_allow_html=True)

                # Suggested actions as an ordered list.
                st.markdown(_section_header("💡 行动建议"), unsafe_allow_html=True)
                suggestion_items = "".join(
                    f"<li style='margin-bottom: 10px;'>{html.escape(str(suggestion))}</li>"
                    for suggestion in explanation['suggestions']
                )
                st.markdown(f"""
<div style="background-color: #16213e; border: 1px solid #d4af37; border-radius: 8px; padding: 15px;">
<ol style="color: #e0e0e0; list-style-type: decimal; padding-left: 20px;">
{suggestion_items}
</ol>
</div>
""", unsafe_allow_html=True)
            except Exception as e:
                # UI boundary: show the failure in the page instead of crashing.
                st.markdown(
                    _alert_card("预言失败", f"预言师遇到了问题{e}"),
                    unsafe_allow_html=True,
                )

    # --- Batch result ------------------------------------------------------
    if uploaded_file is not None:
        with st.spinner("🧙‍♂️ 预言师正在批量解析信件..."):
            try:
                df = pd.read_csv(uploaded_file)
                if "text" not in df.columns:
                    st.markdown(
                        _alert_card("预言失败", "信件文件必须包含'text'"),
                        unsafe_allow_html=True,
                    )
                else:
                    # Cap the number of rows processed per upload.
                    max_rows = 100
                    if len(df) > max_rows:
                        st.markdown(
                            _alert_card(
                                "警告",
                                f"信件文件包含 {len(df)} 封信件预言师将只解析前 {max_rows}",
                                color="#d4af37",
                            ),
                            unsafe_allow_html=True,
                        )
                        df = df.head(max_rows)

                    # Classify row by row. Empty cells (NaN) are skipped and
                    # everything else is coerced to str before reaching the agent.
                    results = []
                    for text in df["text"].dropna().astype(str):
                        row_result = agent.classify_and_explain(text)
                        row_class = row_result['classification']
                        row_expl = row_result['explanation']
                        results.append({
                            "text": text,
                            "label": row_class['label'],
                            "spam_probability": row_class['probability']['spam'],
                            "ham_probability": row_class['probability']['ham'],
                            "content_summary": row_expl['content_summary'],
                            "classification_reason": row_expl['classification_reason'],
                        })

                    if not results:
                        # Nothing to show; without this guard the empty frame
                        # below would raise KeyError on the "label" column.
                        st.markdown(
                            _alert_card("警告", "信件文件中没有可解析的信件", color="#d4af37"),
                            unsafe_allow_html=True,
                        )
                    else:
                        results_df = pd.DataFrame(results)

                        # Label distribution chart.
                        st.markdown(_section_header("📊 预言统计"), unsafe_allow_html=True)
                        label_counts = results_df["label"].value_counts().rename(
                            index={"spam": "垃圾信件", "ham": "正常信件"}
                        )
                        st.bar_chart(label_counts)

                        # Per-message results table.
                        st.markdown(_section_header("📋 预言结果"), unsafe_allow_html=True)
                        st.dataframe(
                            results_df,
                            use_container_width=True,
                            column_config={
                                "text": st.column_config.TextColumn("信件内容", width="medium"),
                                "label": st.column_config.TextColumn("预言标签"),
                                "spam_probability": st.column_config.ProgressColumn(
                                    "垃圾信件概率",
                                    format="%.2f",
                                    min_value=0.0,
                                    max_value=1.0,
                                ),
                                "ham_probability": st.column_config.ProgressColumn(
                                    "正常信件概率",
                                    format="%.2f",
                                    min_value=0.0,
                                    max_value=1.0,
                                ),
                                "content_summary": st.column_config.TextColumn("内容摘要", width="medium"),
                                "classification_reason": st.column_config.TextColumn("预言原因", width="medium"),
                            },
                        )

                        # CSV download of the results.
                        st.markdown(_section_header("💾 保存预言"), unsafe_allow_html=True)
                        csv = results_df.to_csv(index=False).encode('utf-8')
                        st.download_button(
                            label="📄 下载预言结果",
                            data=csv,
                            file_name="spam_classification_results.csv",
                            mime="text/csv",
                            use_container_width=True,
                        )
            except Exception as e:
                # UI boundary: show the failure in the page instead of crashing.
                st.markdown(
                    _alert_card("预言失败", f"预言师遇到了问题{e}"),
                    unsafe_allow_html=True,
                )
# Page footer (medieval style): a horizontal rule followed by a centred
# block with the app name, a credits line and the tagline.
_FOOTER_HTML = """
<div style="text-align: center; padding: 20px; border-top: 2px solid #d4af37; margin-top: 40px;">
<h3 style="color: #d4af37; font-family: 'Garamond', serif; margin-bottom: 10px;">
🏰 中世纪垃圾短信防御城堡
</h3>
<p style="color: #e0e0e0; font-family: 'Georgia', serif; font-size: 14px;">
© 2026 由骑士团建造 | 基于传统魔法 + LLM 预言 + Agent 使者
</p>
<p style="color: #d4af37; font-size: 12px; margin-top: 10px; font-style: italic;">
保护您的通信不受垃圾信息的侵袭
</p>
</div>
"""
st.markdown("---")
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)