Initial commit: Customer Sentiment Analysis project
commit 990323d5d6
2 .env.example Normal file
@@ -0,0 +1,2 @@
# API Key for the application
API_KEY=deepseek
22 .gitignore vendored Normal file
@@ -0,0 +1,22 @@
# Python-generated files
__pycache__/
*.py[oc]
build/
dist/
wheels/
*.egg-info

# Virtual environments
.venv

# Environment and secrets
.env
.env.*
!.env.example

# Artifacts and models
artifacts/
*.joblib

# Large data files
data/*.csv
1 .python-version Normal file
@@ -0,0 +1 @@
3.12
207 README.md Normal file
@@ -0,0 +1,207 @@
# Customer Sentiment Prediction and Risk Analysis System

> **Machine Learning (Python) course project**

## 👥 Team Members

| Name | Student ID | Contribution |
|------|------------|--------------|
| 于洋 | 2311020129 | Data processing, model training |
| 张洁 | 2311020131 | Agent development, commits |
| 杨艺瑶 | 2311020127 | Streamlit app, presentation |

## 📝 Project Overview

This project builds an end-to-end customer sentiment analysis and risk early-warning system. It mines a 25,000-record customer review dataset (the Customer Sentiment Dataset), combining transaction attributes (response time, resolution status, purchase channel, etc.) with review text, and uses machine learning to automatically classify each customer's sentiment (Positive/Negative/Neutral).

Beyond high-accuracy sentiment classifiers (Logistic Regression and LightGBM), the system integrates a "predict → analyze → recommend" agent workflow. An interactive Streamlit dashboard accepts customer features in real time and returns a sentiment risk score, an explanation of the key contributing factors, and targeted operational recommendations, helping the business retain high-risk customers in time and improve service quality.

## 1️⃣ Problem Definition and Data

### 1.1 Task Description

**Task type**: Multi-class classification
**Business goal**: Predict the customer's sentiment (`sentiment`) from the review text (`review_text`) and related transaction attributes (rating, product category, platform, etc.). The model is meant to help the business monitor customer feedback automatically and identify negative reviews early enough to act on them, improving satisfaction and retention.

### 1.2 Data Source

| Item | Description |
|------|-------------|
| Dataset name | Customer Sentiment Dataset |
| Data location | Local file (data/Customer_Sentiment.csv) |
| Sample size | 25,000 rows |
| Feature count | 11 (excluding ID and label) |

### 1.3 Data Splits and Leakage Prevention

**Data splits**:
- A **random split** strategy is used.
- **Train : validation : test = 8 : 1 : 1** (or 70% : 15% : 15%).
- A fixed `random_state` makes the experiments reproducible.

**Leakage prevention** (see the sketch below):
1. **Drop IDs**: remove `customer_id` so the model cannot memorize individual users.
2. **Feature screening**: drop `customer_rating`, since the rating is almost perfectly correlated with the sentiment label; using it would leak the target and make the prediction task trivial.
3. **Temporal leakage**: the dataset includes `response_time_hours` and `issue_resolved`, which may be future information when predicting sentiment at review time. If the business scenario is "predict as soon as a review arrives", these should be excluded; for post-hoc attribution they can stay. This project keeps them as features, but their business meaning must be kept in mind.
4. **Preprocessing isolation**: all fitted statistics (the text vectorizer's vocabulary, normalization means/variances, etc.) are computed on the **training set only**; validation and test information is never used.

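A minimal sketch of rules 1, 2, and 4 (illustrative only; the project's actual pipeline is `src/train_models.py`, which uses a stratified 80/20 split):

```python
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer

df = pd.read_csv("data/Cleaned_Customer_Sentiment.csv")
X = df.drop(columns=["sentiment", "customer_id", "customer_rating"])  # drop ID + leaky rating
y = df["sentiment"]

# Fixed seed -> reproducible split; stratify keeps class ratios identical across folds.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

# Fit the vectorizer on the training fold ONLY, then apply it to the test fold.
vec = TfidfVectorizer(max_features=5000)
Xt_train = vec.fit_transform(X_train["review_text"])  # learns vocabulary from train
Xt_test = vec.transform(X_test["review_text"])        # reuses it; no test-set peeking
```
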
## 🚀 Quick Start

```bash
# Clone the repository
git clone http://hblu.top:3000/MachineLearning2025/GXX-Customer_Sentiment_Analysis.git
cd GXX-Customer_Sentiment_Analysis

# Install dependencies
uv sync

# Configure environment variables
cp .env.example .env
# Edit .env and fill in your API key

# Run the demo
uv run streamlit run src/streamlit_app.py
```

## 2️⃣ Machine Learning Pipeline

### 2.1 Baseline Model

| Model | Metric | Result |
|-------|--------|--------|
| Logistic Regression | ROC-AUC | 1.0000 |

### 2.2 Advanced Model

| Model | Metric | Result |
|-------|--------|--------|
| LightGBM | ROC-AUC | 1.0000 |

### 2.3 Error Analysis

(Which samples does the model get wrong, and why?)

**Findings**:
- The model achieves a perfect score on the test set (ROC-AUC = 1.0, Accuracy = 100%).
- **Likely cause**: the dataset appears to be synthetic, with highly regular and highly separable text patterns per sentiment class (e.g., "very disappointed" always maps to Negative, "excellent product" always to Positive). Even after removing the strongly correlated `customer_rating` feature, the text alone carries enough signal for perfect classification.
- **Caveat**: a model that is perfect on this dataset may not generalize to messy real-world reviews. Further testing on more diverse, noisier review data is recommended.
- **Misclassified samples**: none (0/5000 errors).

### 2.4 Threshold Strategy and Cost-Sensitive Analysis (bonus)

- Business costs: a false negative (a negative review not intercepted) costs `C_FN = 10`; a false positive (a review wrongly intercepted) costs `C_FP = 1`.
- Risk thresholds: to reduce expected cost, a tiered policy is used (a worked sketch follows this list):
  - High risk: `risk ≥ 0.7` → escalate immediately, involve customer service, evaluate compensation
  - Medium risk: `0.4 ≤ risk < 0.7` → flag for review, prioritize in the queue
  - Low risk: `risk < 0.4` → routine monitoring
- Evidence: the linear contribution factors from `explain_features` localize the action points (e.g., frequent negative words, long response times, unresolved status).

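To sanity-check the tier boundaries, the expected cost of any cut-off can be swept directly. A minimal sketch, assuming calibrated risk scores on a validation set (the costs and thresholds come from the list above; the data here is synthetic):

```python
import numpy as np

C_FN, C_FP = 10, 1  # costs from the policy above

def expected_cost(threshold, risk_scores, is_negative):
    """Total cost of intercepting every review with risk >= threshold."""
    flagged = risk_scores >= threshold
    fn = np.sum(~flagged & is_negative)   # negatives we failed to intercept
    fp = np.sum(flagged & ~is_negative)   # harmless reviews we intercepted
    return C_FN * fn + C_FP * fp

# Illustrative scores; in practice these come from predict_risk on a validation set.
rng = np.random.default_rng(42)
risk_scores = rng.uniform(0, 1, 1000)
is_negative = rng.uniform(0, 1, 1000) < risk_scores  # scores are calibrated by construction

for t in (0.4, 0.7):
    print(f"threshold={t}: expected cost={expected_cost(t, risk_scores, is_negative)}")
# With C_FN >> C_FP the optimum sits well below 0.5: for calibrated scores the
# classical decision-theoretic cut-off is C_FP / (C_FP + C_FN) = 1/11 ≈ 0.09.
```
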
## Data Processing (required)

- A reproducible data-cleaning pipeline built with **Polars**; script: `src/data_processing.py`
- Schema definition: `define_schema` (`pandera.polars`) at `src/data_processing.py:5`.
- Cleaning steps: `clean_data` at `src/data_processing.py:23` (field standardization, booleanization, categorical conversion).
- Exploration and saving: `load_and_inspect` at `src/data_processing.py:52` (validate → clean → summarize → save to `data/Cleaned_Customer_Sentiment.csv`).
- Run with: `py src/data_processing.py`

### Cleaning Operations (Polars + Pandera)

- Schema validation:
  - `customer_id > 0`; `gender ∈ {male,female,other}`; `age_group ∈ {18-25,26-35,36-45,46-60,60+}`;
  - `region ∈ {north,south,east,west,central}`; `purchase_channel ∈ {online,offline}`;
  - `customer_rating ∈ [1,5]`; `sentiment ∈ {positive,negative,neutral}`; `response_time_hours ≥ 0`;
  - `issue_resolved, complaint_registered ∈ {yes,no}`.
- Value standardization and type conversion:
  - Convert `issue_resolved` and `complaint_registered` from `yes/no` to `bool`.
  - Lowercase `product_category`, `platform`, and `review_text`.
  - Cast the categorical columns (`gender, age_group, region, product_category, purchase_channel, platform, sentiment`) to categorical dtype.
- Leakage-related feature removal: `customer_rating` is dropped at training time.
- Exploration output:
  - Frequency counts for categorical columns (skipping the long text column).
  - A preview of the cleaned data and the printed schema.
- Persistence:
  - The result is saved to `data/Cleaned_Customer_Sentiment.csv` for training and the demo.

### Cleaning Results (key figures)

- Total rows: `25,000`
- Null counts: `0` in every column
- Sentiment distribution:
  - positive: `9,978`
  - negative: `9,937`
  - neutral: `5,085`
- Gender distribution:
  - male: `8,385`, female: `8,356`, other: `8,259`
- Channel distribution:
  - `purchase_channel = online` (`25,000`)
- Business status:
  - `issue_resolved=True` rate: `66.372%`
  - `complaint_registered=True` rate: `39.748%`
- Response time (mean hours by sentiment):
  - neutral: `36.0869`, negative: `36.0222`, positive: `35.9924`
- Top platforms (top 5):
  - nykaa (`1,301`), snapdeal (`1,289`), others (`1,286`), reliance digital (`1,279`), zepto (`1,278`)
- Top product categories (top 5):
  - groceries (`2,858`), automobile (`2,833`), books (`2,812`), travel (`2,811`), fashion (`2,782`)

## Machine Learning (required)

- At least two models compared: **Logistic Regression** and **LightGBM** are implemented (see `src/train_models.py`).
- Metric thresholds met: `ROC-AUC = 1.0000 ≥ 0.75` (the classification report is likewise perfect, satisfying the `F1 ≥ 0.70` requirement).
- Artifact persistence: the pipelines and label encoder are saved to `artifacts/` for reuse by the agent (`src/train_models.py:137-146`; persistence at `src/train_models.py:145-149`); a reload sketch follows.

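For reuse outside the training script, the saved artifacts reload with `joblib`, the same way `src/agent.py` loads them:

```python
import joblib

# Reload the persisted pipelines exactly as the agent does (paths under artifacts/).
lgb_pipeline = joblib.load("artifacts/lgb_pipeline.joblib")
label_encoder = joblib.load("artifacts/label_encoder.joblib")
# Each pipeline bundles preprocessing + model, so it scores raw feature rows directly.
```
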
## 3️⃣ Agent Implementation

### 3.1 Tool Definitions

| Tool | Function | Input | Output |
|------|----------|-------|--------|
| `predict_risk` | Calls the ML model for prediction | CustomerFeatures | float |
| `explain_features` | Explains feature influence | CustomerFeatures | list[str] |

CustomerFeatures fields:
`gender`, `age_group`, `region`, `product_category`, `purchase_channel`, `platform`, `response_time_hours`, `issue_resolved`, `complaint_registered`, `review_text`
(`customer_rating` and `customer_id` are excluded)

Implementation: `src/agent.py`
Models and preprocessing are loaded from `artifacts/` (`lgb_pipeline.joblib`, `lr_pipeline.joblib`, `label_encoder.joblib`)

### 3.2 Decision Flow

- Predict: `predict_risk(features)` uses the LightGBM pipeline to output the probability of negative sentiment (risk, 0.0–1.0).
- Explain: `explain_features(features)` returns influence statements (direction and weight) derived from the logistic regression's linear contributions.
- Recommend: based on `risk` and the explanations, the product layer can generate operational advice (e.g., prioritize high-risk complaints, shorten long response times, tune service scripts around frequent negative words). A usage sketch follows.

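A minimal usage sketch of the two tools (field values are made up; assumes the training artifacts exist under `artifacts/` and that `src/` is on the import path):

```python
from agent import CustomerFeatures, predict_risk, explain_features

customer = CustomerFeatures(
    gender="female",
    age_group="26-35",
    region="north",
    product_category="fashion",
    purchase_channel="online",
    platform="nykaa",
    response_time_hours=48.0,
    issue_resolved=False,
    complaint_registered=True,
    review_text="very disappointed, the issue was never resolved",
)

risk = predict_risk(customer)          # float in [0, 1]
factors = explain_features(customer)   # top contributing features with weights
print(f"risk={risk:.2f}")
for line in factors:
    print("-", line)
```
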
### Agent (required) Compliance

- **Pydantic** models define the inputs and outputs:
  - `CustomerFeatures` input model (`src/agent.py:11-21`, with enums and boundary constraints)
  - `RiskOutput` and `ExplanationOutput` output models (`src/agent.py:23-27`)
- At least two tools:
  - `predict_risk` (ML prediction tool, `src/agent.py:47-52`)
  - `explain_features` (feature-influence explanation, `src/agent.py:57-73`)

## 4️⃣ Development Notes

### 4.1 Main Difficulties and Solutions

- Dependency installation and networking: packages such as `polars`, `pandera`, `scikit-learn`, `lightgbm`, and `pyarrow` install unreliably on domestic networks.
  Solution: configure the Tsinghua mirror (e.g., `pip -i https://pypi.tuna.tsinghua.edu.cn/simple`, or set `UV_INDEX_URL`) and install missing dependencies step by step (e.g., `pyarrow`, needed for `polars.to_pandas()`).
- Project naming at initialization: a Chinese directory name made the package name produced by `uv init` invalid.
  Solution: run `uv init --name customer-sentiment-analysis` to specify a valid English package name.
- Model parameter compatibility: version differences in `LogisticRegression` caused the `multi_class` parameter to error.
  Solution: remove the incompatible parameter and rely on the default automatic mode.
- Data leakage and anomalous metrics: near-perfect scores suggested the rating was tightly coupled to the text label.
  Solution: drop the `customer_rating` feature, keep the text and business attributes, and document the risk and its explanation.
- Artifact persistence and inference: the preprocessing and models must be reusable at inference time.
  Solution: persist the pipelines and label encoder to `artifacts/` and load them centrally in `src/agent.py`.

### 4.2 Reflections on AI-Assisted Programming

- Clear productivity gains: quickly generating the data pipeline, the training and explanation modules, and structured docs; working through error stacks is faster.
- Strict verification required: run and review all generated code and metrics, watching for version compatibility and data leakage.
- Best practices: fix `random_state`, wrap preprocessing in a `Pipeline`, fit on the training set only, save reusable artifacts, and record assumptions and risks clearly in the README.

### 4.3 Limitations and Future Work

- Data and evaluation: bring in more realistic, noisier review data; use cross-validation and time-based splits; add macro-averaged F1 and finer-grained error clustering.
- Explanations and usability: adopt more robust explanation methods such as SHAP/LIME; turn explanations into operations-facing advice.
- System and engineering: wrap `agent.py` as a service/API integrated with Streamlit; add unit tests, CI, type checking, and linting; harden `.env` configuration validation.
- Deployment and monitoring: containerized deployment and resource optimization (parallelism/vectorization); metric monitoring and alerting, model versioning, and a rollback strategy.
18 pyproject.toml Normal file
@@ -0,0 +1,18 @@
[project]
name = "customer-sentiment-analysis"
version = "0.1.0"
description = "Customer sentiment prediction and risk analysis system"
readme = "README.md"
requires-python = ">=3.12"
dependencies = [
    "lightgbm>=4.6.0",
    "matplotlib>=3.10.8",
    "pandas>=2.3.3",
    "pandera>=0.28.1",
    "polars>=1.37.1",
    "scikit-learn>=1.8.0",
    "seaborn>=0.13.2",
    "streamlit>=1.52.2",
    "pydantic>=2.9.2",
    "python-dotenv>=1.0.0",
]
77 src/agent.py Normal file
@@ -0,0 +1,77 @@
import os
import joblib
import numpy as np
import pandas as pd
from typing import Literal, Annotated
from pydantic import BaseModel, Field

_model_lgb = None
_model_lr = None
_le = None

class CustomerFeatures(BaseModel):
    gender: Literal["male", "female", "other"]
    age_group: Literal["18-25", "26-35", "36-45", "46-60", "60+"]
    region: Literal["north", "south", "east", "west", "central"]
    product_category: str
    purchase_channel: Literal["online", "offline"]
    platform: str
    response_time_hours: Annotated[float, Field(ge=0)]
    issue_resolved: bool
    complaint_registered: bool
    review_text: Annotated[str, Field(min_length=3)]

class RiskOutput(BaseModel):
    risk: float

class ExplanationOutput(BaseModel):
    factors: list[str]

def _ensure_loaded():
    global _model_lgb, _model_lr, _le
    if _model_lgb is None:
        _model_lgb = joblib.load(os.path.join("artifacts", "lgb_pipeline.joblib"))
    if _model_lr is None:
        _model_lr = joblib.load(os.path.join("artifacts", "lr_pipeline.joblib"))
    if _le is None:
        _le = joblib.load(os.path.join("artifacts", "label_encoder.joblib"))

def _to_dataframe(features) -> pd.DataFrame:
    if isinstance(features, CustomerFeatures):
        payload = features.model_dump()
    elif isinstance(features, dict):
        payload = features
    else:
        raise TypeError("features must be CustomerFeatures or dict")
    return pd.DataFrame([payload])

def predict_risk(features: CustomerFeatures | dict) -> float:
    _ensure_loaded()
    df = _to_dataframe(features)
    probs = _model_lgb.predict_proba(df)[0]
    idx_neg = int(_le.transform(["negative"])[0])
    return float(probs[idx_neg])

def predict_risk_model(features: CustomerFeatures | dict) -> RiskOutput:
    return RiskOutput(risk=predict_risk(features))
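
# explain_features approximates each feature's pull toward the "negative" class
# for the linear model: contribution_i = x_i * coef[negative, i] in the class's
# linear score, so positive contributions push the prediction toward "negative".
# This is exact only for the linear model; it serves as a lightweight stand-in
# for heavier attribution methods such as SHAP.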
def explain_features(features: CustomerFeatures | dict) -> list[str]:
    _ensure_loaded()
    df = _to_dataframe(features)
    pre = _model_lr.named_steps["preprocessor"]
    Xv = pre.transform(df)
    clf = _model_lr.named_steps["classifier"]
    idx_neg = int(_le.transform(["negative"])[0])
    coefs = clf.coef_[idx_neg]
    vec = Xv.toarray().ravel()
    contrib = vec * coefs
    names = pre.get_feature_names_out()
    order = np.argsort(-np.abs(contrib))[:8]
    out = []
    for i in order:
        direction = "increases" if contrib[i] > 0 else "decreases"
        out.append(f"{names[i]} {direction} negative risk (weight={contrib[i]:.3f})")
    return out

def explain_features_model(features: CustomerFeatures | dict) -> ExplanationOutput:
    return ExplanationOutput(factors=explain_features(features))
94 src/data_processing.py Normal file
@@ -0,0 +1,94 @@
import polars as pl
import pandera as pa
from pandera.polars import DataFrameSchema, Column, Check

def define_schema():
    schema = DataFrameSchema({
        "customer_id": Column(pl.Int64, checks=Check.gt(0)),
        "gender": Column(pl.String, checks=Check.isin(["male", "female", "other"])),
        "age_group": Column(pl.String, checks=Check.isin(["18-25", "26-35", "36-45", "46-60", "60+"])),
        "region": Column(pl.String, checks=Check.isin(["north", "south", "east", "west", "central"])),
        "product_category": Column(pl.String),
        "purchase_channel": Column(pl.String, checks=Check.isin(["online", "offline"])),  # Assuming online/offline
        "platform": Column(pl.String),
        "customer_rating": Column(pl.Int64, checks=Check.in_range(1, 5)),
        "review_text": Column(pl.String),
        "sentiment": Column(pl.String, checks=Check.isin(["positive", "negative", "neutral"])),
        "response_time_hours": Column(pl.Int64, checks=Check.ge(0)),
        "issue_resolved": Column(pl.String, checks=Check.isin(["yes", "no"])),
        "complaint_registered": Column(pl.String, checks=Check.isin(["yes", "no"])),
    })
    return schema

def clean_data(df: pl.DataFrame) -> pl.DataFrame:
    print("Cleaning data...")
    cleaned_df = df.with_columns([
        # Convert yes/no to boolean
        (pl.col("issue_resolved") == "yes").alias("issue_resolved"),
        (pl.col("complaint_registered") == "yes").alias("complaint_registered"),

        # Standardize text to lowercase (optional but good for analysis)
        pl.col("product_category").str.to_lowercase(),
        pl.col("platform").str.to_lowercase(),
        pl.col("review_text").str.to_lowercase(),
    ])

    # Cast to categorical for efficiency where appropriate
    categorical_cols = ["gender", "age_group", "region", "product_category", "purchase_channel", "platform", "sentiment"]
    cleaned_df = cleaned_df.with_columns([
        pl.col(col).cast(pl.Categorical) for col in categorical_cols
    ])

    return cleaned_df

def summarize_data(df: pl.DataFrame):
    print("\n--- Value Counts for Categorical Columns ---")
    categorical_cols = [col for col, dtype in df.schema.items() if isinstance(dtype, (pl.Categorical, pl.String))]
    for col in categorical_cols:
        if col != "review_text":  # Skip long text
            print(f"\nValue counts for {col}:")
            print(df[col].value_counts().sort("count", descending=True))

def load_and_inspect(file_path):
    print(f"Loading data from {file_path}...")
    try:
        df = pl.read_csv(file_path)
        print("Data loaded successfully.")
    except Exception as e:
        print(f"Error loading data: {e}")
        return None

    # Define Schema
    schema = define_schema()
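
    # NOTE: a validation failure below is reported but does not abort the run;
    # cleaning proceeds on the raw frame either way. Re-raise the SchemaError
    # here if strict enforcement is desired.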
print("\n--- Validating Raw Data ---")
|
||||
try:
|
||||
schema.validate(df)
|
||||
print("Raw data validation passed!")
|
||||
except pa.errors.SchemaError as e:
|
||||
print("Raw data validation failed:")
|
||||
print(e)
|
||||
|
||||
# Clean Data
|
||||
cleaned_df = clean_data(df)
|
||||
|
||||
print("\n--- Cleaned Data Head ---")
|
||||
print(cleaned_df.head())
|
||||
|
||||
print("\n--- Cleaned Schema ---")
|
||||
print(cleaned_df.schema)
|
||||
|
||||
# Summarize Data
|
||||
summarize_data(cleaned_df)
|
||||
|
||||
# Save cleaned data
|
||||
output_path = "data/Cleaned_Customer_Sentiment.csv"
|
||||
print(f"\nSaving cleaned data to {output_path}...")
|
||||
cleaned_df.write_csv(output_path)
|
||||
print("Data saved successfully.")
|
||||
|
||||
return cleaned_df
|
||||
|
||||
if __name__ == "__main__":
|
||||
file_path = "data/Customer_Sentiment.csv"
|
||||
cleaned_df = load_and_inspect(file_path)
|
||||
327 src/streamlit_app.py Normal file
@@ -0,0 +1,327 @@
import streamlit as st
import polars as pl
import os
from dotenv import load_dotenv
from agent import CustomerFeatures, predict_risk, explain_features
import altair as alt

st.set_page_config(page_title="Customer Sentiment Analysis", layout="wide")
load_dotenv()

st.markdown(
    """
    <style>
    .main {
        background: radial-gradient(circle at top left, #fffde7 0%, #e8f5e9 40%, #e3f2fd 100%);
        cursor: none;
    }
    .stMetric {
        background-color: rgba(255, 255, 255, 0.78);
        border-radius: 14px;
        padding: 10px 16px;
        box-shadow: 0 4px 10px rgba(0, 0, 0, 0.06);
    }
    .block-container {
        padding-top: 4rem;
    }
    .star {
        position: fixed;
        width: 8px;
        height: 8px;
        background: radial-gradient(circle, #ffc107 0%, rgba(255, 193, 7, 0) 70%);
        pointer-events: none;
        transform: translate(-50%, -50%);
        animation: fade-out 0.6s linear forwards;
        border-radius: 50%;
        z-index: 9999;
    }
    @keyframes fade-out {
        from { opacity: 1; transform: translate(-50%, -50%) scale(1); }
        to { opacity: 0; transform: translate(-50%, -50%) scale(0.4); }
    }
    .top-nav {
        position: fixed;
        top: 0;
        left: 0;
        right: 0;
        height: 3rem;
        background: rgba(255, 255, 255, 0.9);
        backdrop-filter: blur(8px);
        display: flex;
        align-items: center;
        justify-content: center;
        gap: 2rem;
        z-index: 9000;
        border-bottom: 1px solid rgba(0, 0, 0, 0.06);
    }
    .top-nav a {
        text-decoration: none;
        color: #1565c0;
        font-weight: 600;
        padding: 4px 10px;
        border-radius: 999px;
        transition: background-color 0.2s ease, color 0.2s ease;
        font-size: 0.9rem;
    }
    .top-nav a:hover {
        background-color: #ffe082;
        color: #1b5e20;
    }
    </style>
    <div class="top-nav">
        <a href="#overview">Overview</a>
        <a href="#stats">Statistics</a>
        <a href="#prediction-system">Prediction System</a>
    </div>
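    <!-- NOTE (added observation): markup passed through st.markdown is injected
         via innerHTML, so this <script> block is generally NOT executed by the
         browser; st.components.v1.html is the usual route for custom JavaScript
         such as this cursor trail. -->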
<script>
|
||||
document.addEventListener("mousemove", function(e) {
|
||||
const star = document.createElement("div");
|
||||
star.className = "star";
|
||||
star.style.left = e.clientX + "px";
|
||||
star.style.top = e.clientY + "px";
|
||||
document.body.appendChild(star);
|
||||
setTimeout(() => star.remove(), 600);
|
||||
});
|
||||
</script>
|
||||
""",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
|
||||
|
||||
def map_risk_to_advice(risk: float):
|
||||
if risk >= 0.7:
|
||||
level = "高风险"
|
||||
decision = "高风险客户,建议重点挽留"
|
||||
actions = [
|
||||
"立即安排客服电话回访,了解具体不满原因",
|
||||
"提供一次性优惠券或折扣,降低流失意愿",
|
||||
"在后续一周内跟踪满意度变化",
|
||||
]
|
||||
elif risk >= 0.4:
|
||||
level = "中风险"
|
||||
decision = "中风险客户,建议适度关怀"
|
||||
actions = [
|
||||
"通过短信或站内信收集更多反馈",
|
||||
"在下次消费时提供小额优惠",
|
||||
"关注后续服务体验,避免问题累积",
|
||||
]
|
||||
else:
|
||||
level = "低风险"
|
||||
decision = "低风险客户,建议保持现有服务水平"
|
||||
actions = [
|
||||
"定期推送增值服务或会员活动",
|
||||
"保持响应速度与服务质量",
|
||||
"鼓励其在社交平台分享正向评价",
|
||||
]
|
||||
return level, decision, actions
|
||||
|
||||
|
||||
st.title("📊 Customer Sentiment Analysis Dashboard")

st.image(
    "https://via.placeholder.com/1200x200/fffde7/e65100?text=Customer+Sentiment+Prediction+System",
    use_container_width=True,
)


@st.cache_data
def load_data():
    file_path = os.path.join("data", "Cleaned_Customer_Sentiment.csv")
    if not os.path.exists(file_path):
        st.error(f"Data file not found at {file_path}. Please run the data processing script first.")
        return None
    return pl.read_csv(file_path)


df = load_data()

if df is not None:
    st.sidebar.header("Filters")
    sentiments = df["sentiment"].unique().to_list()
    selected_sentiment = st.sidebar.multiselect("Select Sentiment", sentiments, default=sentiments)
    regions = df["region"].unique().to_list()
    selected_region = st.sidebar.multiselect("Select Region", regions, default=regions)

    filtered_df = df.filter(
        (pl.col("sentiment").is_in(selected_sentiment)) &
        (pl.col("region").is_in(selected_region))
    )

    st.markdown('<div id="overview"></div>', unsafe_allow_html=True)
    col1, col2, col3 = st.columns(3)
    col1.metric("Total Reviews", filtered_df.height)
    col2.metric("Positive Reviews", filtered_df.filter(pl.col("sentiment") == "positive").height)
    col3.metric("Avg Rating", f"{filtered_df['customer_rating'].mean():.2f}")

    st.markdown('<div id="stats"></div>', unsafe_allow_html=True)
    st.subheader("Sentiment Distribution")
    sentiment_counts = filtered_df["sentiment"].value_counts()
    sentiment_chart = (
        alt.Chart(sentiment_counts.to_pandas())
        .mark_bar(cornerRadiusTopLeft=6, cornerRadiusTopRight=6)
        .encode(
            x=alt.X("sentiment:N", title="Sentiment"),
            y=alt.Y("count:Q", title="Count"),
            color=alt.Color(
                "sentiment:N",
                scale=alt.Scale(
                    domain=["negative", "neutral", "positive"],
                    range=["#ef9a9a", "#fff59d", "#a5d6a7"],
                ),
                legend=None,
            ),
            tooltip=["sentiment", "count"],
        )
    )
    st.altair_chart(sentiment_chart, use_container_width=True)

    st.subheader("Data Preview")
    st.dataframe(filtered_df.to_pandas())

    banner_col1, banner_col2 = st.columns([2, 1])
    with banner_col1:
        st.subheader("Configuration Check")
        api_key = os.getenv("API_KEY")
        if api_key:
            st.success(f"API Key loaded: {api_key[:4]}****")
        else:
            st.warning("API Key not found in environment variables.")
    with banner_col2:
        st.image(
            "https://via.placeholder.com/400x200/e8f5e9/1b5e20?text=Customer+Care",
            use_container_width=True,
        )

    st.subheader("Cleaning Stats")
    nulls_df = filtered_df.select([pl.col(c).is_null().sum().alias(c) for c in filtered_df.columns]).melt()
    nulls_pd = nulls_df.rename({"variable": "feature", "value": "nulls"}).to_pandas()
    c1, c2 = st.columns(2)
    c1.metric("Issue Resolved Rate", f"{float(filtered_df['issue_resolved'].cast(pl.Boolean).mean())*100:.2f}%")
    c2.metric("Complaint Registered Rate", f"{float(filtered_df['complaint_registered'].cast(pl.Boolean).mean())*100:.2f}%")
    st.subheader("Nulls Per Column")
    nulls_chart = (
        alt.Chart(nulls_pd)
        .mark_bar(cornerRadiusTopLeft=4, cornerRadiusTopRight=4)
        .encode(
            x=alt.X("feature:N", sort="-y", title="Feature"),
            y=alt.Y("nulls:Q", title="Null Count"),
            color=alt.value("#80cbc4"),
            tooltip=["feature", "nulls"],
        )
    )
    st.altair_chart(nulls_chart, use_container_width=True)

    st.subheader("Gender Distribution")
    gender_counts = filtered_df["gender"].value_counts()
    gender_chart = (
        alt.Chart(gender_counts.to_pandas())
        .mark_bar(cornerRadiusTopLeft=6, cornerRadiusTopRight=6)
        .encode(
            x=alt.X("gender:N", title="Gender"),
            y=alt.Y("count:Q", title="Count"),
            color=alt.value("#90caf9"),
            tooltip=["gender", "count"],
        )
    )
    st.altair_chart(gender_chart, use_container_width=True)

    st.subheader("Avg Response Time by Sentiment")
    avg_resp = filtered_df.group_by("sentiment").agg(pl.col("response_time_hours").mean()).sort("response_time_hours", descending=True)
    avg_resp_chart = (
        alt.Chart(avg_resp.to_pandas())
        .mark_line(point=True)
        .encode(
            x=alt.X("sentiment:N", title="Sentiment"),
            y=alt.Y("response_time_hours:Q", title="Avg Response Time (hours)"),
            color=alt.value("#ffb74d"),
            tooltip=["sentiment", "response_time_hours"],
        )
    )
    st.altair_chart(avg_resp_chart, use_container_width=True)

    st.subheader("Top Platforms")
    top_platforms = filtered_df["platform"].value_counts().sort("count", descending=True).head(10)
    platform_chart = (
        alt.Chart(top_platforms.to_pandas())
        .mark_bar(cornerRadiusTopLeft=6, cornerRadiusTopRight=6)
        .encode(
            x=alt.X("platform:N", sort="-y", title="Platform"),
            y=alt.Y("count:Q", title="Count"),
            color=alt.value("#ce93d8"),
            tooltip=["platform", "count"],
        )
    )
    st.altair_chart(platform_chart, use_container_width=True)

    st.subheader("Top Product Categories")
    top_products = filtered_df["product_category"].value_counts().sort("count", descending=True).head(10)
    product_chart = (
        alt.Chart(top_products.to_pandas())
        .mark_bar(cornerRadiusTopLeft=6, cornerRadiusTopRight=6)
        .encode(
            x=alt.X("product_category:N", sort="-y", title="Product Category"),
            y=alt.Y("count:Q", title="Count"),
            color=alt.value("#ffcc80"),
            tooltip=["product_category", "count"],
        )
    )
    st.altair_chart(product_chart, use_container_width=True)

    st.markdown('<div id="prediction-system"></div>', unsafe_allow_html=True)
    st.subheader("Prediction System: Predict → Analyze → Recommend")
    with st.form("recommendation_form"):
        left, right = st.columns(2)
        with left:
            gender_label = st.selectbox("Gender", ["Male", "Female", "Other"])
            gender_map = {"Male": "male", "Female": "female", "Other": "other"}
            gender = gender_map[gender_label]
            age_group = st.selectbox("Age Group", ["18-25", "26-35", "36-45", "46-60", "60+"])
            region = st.selectbox("Region", ["north", "south", "east", "west", "central"])
            purchase_channel = st.selectbox("Purchase Channel", ["online", "offline"])
            issue_resolved_label = st.selectbox("Issue Resolved?", ["Yes", "No"])
        with right:
            platforms = df["platform"].unique().to_list()
            product_categories = df["product_category"].unique().to_list()
            platform = st.selectbox("Platform", platforms)
            product_category = st.selectbox("Product Category", product_categories)
            default_resp = float(filtered_df["response_time_hours"].mean())
            response_time_hours = st.number_input("Response Time (hours)", min_value=0.0, value=default_resp, step=0.5)
            complaint_registered_label = st.selectbox("Complaint Registered?", ["Yes", "No"])
        review_text = st.text_area("Review Text", value="Great service attitude; my issue was handled promptly.")
        submitted = st.form_submit_button("Run Prediction System")

    if submitted:
        features = {
            "gender": gender,
            "age_group": age_group,
            "region": region,
            "product_category": product_category,
            "purchase_channel": purchase_channel,
            "platform": platform,
            "response_time_hours": response_time_hours,
            "issue_resolved": issue_resolved_label == "Yes",
            "complaint_registered": complaint_registered_label == "Yes",
            "review_text": review_text,
        }
        try:
            customer = CustomerFeatures(**features)
            risk = predict_risk(customer)
            explanations = explain_features(customer)
            level, decision, actions = map_risk_to_advice(risk)

            st.markdown("#### Prediction")
            pred_col1, pred_col2 = st.columns(2)
            pred_col1.metric("Negative Sentiment Risk Score", f"{risk:.2f}")
            pred_col2.metric("Risk Level", level)
            st.progress(risk)

            st.markdown("#### Analysis")
            for item in explanations:
                st.write(f"- {item}")

            st.markdown("#### Recommendations")
            st.write(decision)
            for act in actions:
                st.write(f"- {act}")
        except Exception as e:
            st.error(f"Error running the prediction system: {e}")
154 src/train_models.py Normal file
@@ -0,0 +1,154 @@
import polars as pl
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score, classification_report, confusion_matrix
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
import lightgbm as lgb
import joblib
import os

# Set random seed for reproducibility
RANDOM_STATE = 42

def load_data():
    file_path = os.path.join("data", "Cleaned_Customer_Sentiment.csv")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    # Read with Polars but convert to Pandas for Scikit-Learn compatibility
    df = pl.read_csv(file_path).to_pandas()
    return df

def preprocess_data(df):
    # Drop columns that are irrelevant or leaky for prediction.
    # Note: 'response_time_hours' and 'issue_resolved' may be post-event features;
    # they are kept here because the goal is to predict sentiment given the full
    # context, including resolution. If predicting strictly from text, drop them too.
    X = df.drop(columns=['sentiment', 'customer_id', 'customer_rating'])
    y = df['sentiment']

    # Encode target
    le = LabelEncoder()
    y_encoded = le.fit_transform(y)

    # Split data
    X_train, X_test, y_train, y_test = train_test_split(
        X, y_encoded, test_size=0.2, random_state=RANDOM_STATE, stratify=y_encoded
    )

    return X_train, X_test, y_train, y_test, le

def build_preprocessor(X_train):
    numeric_features = ['response_time_hours']
    categorical_features = ['gender', 'age_group', 'region', 'product_category', 'purchase_channel', 'platform', 'issue_resolved', 'complaint_registered']
    text_features = 'review_text'

    # Transformers
    numeric_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='median')),
    ])

    categorical_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='constant', fill_value='missing')),
        ('onehot', OneHotEncoder(handle_unknown='ignore'))
    ])

    text_transformer = TfidfVectorizer(max_features=5000, stop_words='english')

    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numeric_transformer, numeric_features),
            ('cat', categorical_transformer, categorical_features),
            ('text', text_transformer, text_features)
        ])

    return preprocessor

def train_logistic_regression(X_train, y_train, preprocessor):
    print("\n--- Training Logistic Regression ---")
    clf = Pipeline(steps=[('preprocessor', preprocessor),
                          ('classifier', LogisticRegression(max_iter=1000, random_state=RANDOM_STATE))])

    clf.fit(X_train, y_train)
    return clf

def train_lightgbm(X_train, y_train, preprocessor):
    print("\n--- Training LightGBM ---")

    # LightGBM could handle categories natively, but to keep a single pipeline
    # alongside the text features we feed it the output of the ColumnTransformer
    # (a sparse matrix, because of the TF-IDF and one-hot encodings).
    clf = Pipeline(steps=[('preprocessor', preprocessor),
                          ('classifier', lgb.LGBMClassifier(random_state=RANDOM_STATE, verbose=-1))])

    clf.fit(X_train, y_train)
    return clf
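
# roc_auc_score with multi_class='ovr' computes a one-vs-rest AUC per class and
# (with the default average='macro') averages them, so all three sentiment
# classes contribute equally regardless of their frequency.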
def evaluate_model(model, X_test, y_test, le, model_name):
    y_pred = model.predict(X_test)
    y_prob = model.predict_proba(X_test)

    # ROC-AUC (One-vs-Rest)
    auc = roc_auc_score(y_test, y_prob, multi_class='ovr')
    print(f"\n{model_name} Results:")
    print(f"ROC-AUC: {auc:.4f}")

    print("\nClassification Report:")
    print(classification_report(y_test, y_pred, target_names=le.classes_))

    return auc, y_pred, y_prob

def analyze_errors(model, X_test, y_test, y_pred, le):
    print("\n--- Error Analysis ---")
    # Convert back to dataframe for easier inspection
    test_df = X_test.copy()
    test_df['true_label'] = le.inverse_transform(y_test)
    test_df['predicted_label'] = le.inverse_transform(y_pred)

    # Identify errors
    errors = test_df[test_df['true_label'] != test_df['predicted_label']]
    print(f"Total Errors: {len(errors)} out of {len(test_df)} ({len(errors)/len(test_df):.2%})")

    print("\nConfusion Matrix:")
    print(confusion_matrix(y_test, y_pred))

    print("\nSample Errors:")
    print(errors[['review_text', 'true_label', 'predicted_label']].head(5))

    # Common error patterns
    print("\nError Counts by True Label:")
    print(errors['true_label'].value_counts())

def main():
    df = load_data()
    X_train, X_test, y_train, y_test, le = preprocess_data(df)

    preprocessor = build_preprocessor(X_train)

    # 1. Logistic Regression
    lr_model = train_logistic_regression(X_train, y_train, preprocessor)
    lr_auc, lr_pred, _ = evaluate_model(lr_model, X_test, y_test, le, "Logistic Regression")

    # 2. LightGBM
    lgb_model = train_lightgbm(X_train, y_train, preprocessor)
    lgb_auc, lgb_pred, _ = evaluate_model(lgb_model, X_test, y_test, le, "LightGBM")

    os.makedirs("artifacts", exist_ok=True)
    joblib.dump(lr_model, os.path.join("artifacts", "lr_pipeline.joblib"))
    joblib.dump(lgb_model, os.path.join("artifacts", "lgb_pipeline.joblib"))
    joblib.dump(le, os.path.join("artifacts", "label_encoder.joblib"))

    # Error Analysis (using LightGBM as it's likely better)
    analyze_errors(lgb_model, X_test, y_test, lgb_pred, le)

if __name__ == "__main__":
    main()