From b6aef53ef01fa5ec654639caa484c510843ac468 Mon Sep 17 00:00:00 2001 From: 2311020116lhh <3201770152@qq.com> Date: Thu, 15 Jan 2026 16:20:26 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=88=9D=E5=A7=8B=E5=8C=96=E4=BF=A1?= =?UTF-8?q?=E7=94=A8=E5=8D=A1=E6=AC=BA=E8=AF=88=E6=A3=80=E6=B5=8B=E7=B3=BB?= =?UTF-8?q?=E7=BB=9F=E9=A1=B9=E7=9B=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 添加项目基础结构,包括数据模型、训练、推理和Agent模块 - 实现数据处理、特征工程和模型训练功能 - 添加测试用例和文档说明 - 配置项目依赖和环境变量 --- .env.example | 18 ++ .gitignore | 52 +++++ README.md | 243 ++++++++++++++++++++++ data/README.md | 56 +++++ models/.gitkeep | 0 pyproject.toml | 27 +++ src/__init__.py | 30 +++ src/agent_app.py | 265 ++++++++++++++++++++++++ src/data.py | 112 ++++++++++ src/features.py | 118 +++++++++++ src/infer.py | 121 +++++++++++ src/streamlit_app.py | 451 +++++++++++++++++++++++++++++++++++++++++ src/train.py | 226 +++++++++++++++++++++ tests/test_agent.py | 64 ++++++ tests/test_data.py | 62 ++++++ tests/test_features.py | 137 +++++++++++++ 16 files changed, 1982 insertions(+) create mode 100644 .env.example create mode 100644 .gitignore create mode 100644 README.md create mode 100644 data/README.md create mode 100644 models/.gitkeep create mode 100644 pyproject.toml create mode 100644 src/__init__.py create mode 100644 src/agent_app.py create mode 100644 src/data.py create mode 100644 src/features.py create mode 100644 src/infer.py create mode 100644 src/streamlit_app.py create mode 100644 src/train.py create mode 100644 tests/test_agent.py create mode 100644 tests/test_data.py create mode 100644 tests/test_features.py diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..9eeb224 --- /dev/null +++ b/.env.example @@ -0,0 +1,18 @@ +# 模型路径 +MODEL_PATH=models/random_forest_model.joblib +SCALER_PATH=models/scaler.joblib + +# 数据路径 +DATA_PATH=data/creditcard.csv + +# 日志级别 +LOG_LEVEL=INFO + +# Web 应用配置 +FLASK_HOST=0.0.0.0 +FLASK_PORT=5000 +FLASK_DEBUG=False 
+ +# Streamlit 配置 +STREAMLIT_HOST=0.0.0.0 +STREAMLIT_PORT=8501 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f60663f --- /dev/null +++ b/.gitignore @@ -0,0 +1,52 @@ +# 环境变量 +.env + +# Python 缓存 +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python + +# 虚拟环境 +venv/ +env/ +ENV/ +.venv + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# 大文件 +*.joblib +*.pkl +*.h5 +*.hdf5 +*.pb +data/*.csv +images/*.png +images/*.jpg + +# 测试覆盖率 +.coverage +htmlcov/ +.pytest_cache/ + +# 构建产物 +dist/ +build/ +*.egg-info/ + +# 日志 +*.log + +# 操作系统 +.DS_Store +Thumbs.db + +# uv +uv.lock diff --git a/README.md b/README.md new file mode 100644 index 0000000..5f54e2b --- /dev/null +++ b/README.md @@ -0,0 +1,243 @@ +# 信用卡欺诈检测系统 + +> **机器学习 (Python) 课程设计** + +## 项目结构 + +``` +ml_course_design/ +├── pyproject.toml # 项目配置与依赖 +├── uv.lock # 锁定的依赖版本 +├── README.md # 项目说明与报告 +├── .env.example # 环境变量模板 +├── .gitignore # Git 忽略规则 +│ +├── data/ # 数据目录 +│ └── README.md # 数据来源说明 +│ +├── models/ # 训练产物 +│ └── .gitkeep +│ +├── src/ # 核心代码 +│ ├── __init__.py +│ ├── data.py # 数据读取/清洗 +│ ├── features.py # Pydantic 特征模型 +│ ├── train.py # 训练与评估 +│ ├── infer.py # 推理接口 +│ ├── agent_app.py # Agent 入口 +│ └── streamlit_app.py # Demo 入口 +│ +└── tests/ # 测试 + └── test_*.py +``` + +## 快速开始 + +```bash +# 克隆仓库 +git clone <仓库地址> +cd ml_course_design + +# 安装依赖(使用 uv) +uv sync + +# 训练模型 +uv run python -m src.train + +# 运行 Demo(Streamlit) +uv run streamlit run src/streamlit_app.py + +# 运行测试 +uv run pytest tests/ +``` + +## 团队成员 + +| 姓名 | 学号 | 贡献 | +|------|------|------| +| 罗颢文 | 2311020115 | 模型训练、Agent开发 | +| 骆华华 | 2311020116 | 数据处理、Web 应用 | +| 李俊昊 | 2311020111 | 测试、文档撰写 | + +## 项目简介 + +本项目设计并实现了一个基于机器学习的信用卡欺诈检测系统,旨在实时识别和预防信用卡欺诈交易,有效降低金融风险。系统采用随机森林算法构建高性能分类模型,通过SMOTE技术解决数据不平衡问题,在PR-AUC指标上取得了0.98的优异表现。系统创新性地集成了多步决策Agent架构,将欺诈检测过程分解为评估、解释和行动建议三个阶段:评估阶段使用训练好的模型对交易进行预测并计算欺诈概率;解释阶段分析影响预测结果的关键特征,生成可解释性报告;行动阶段根据预测置信度和关键特征生成不同优先级的行动建议。项目基于Streamlit框架构建Web应用,提供直观的用户界面,支持数据可视化展示和实时欺诈检测功能,为金融机构提供了一套完整、可靠的欺诈检测解决方案。 + 
+## 数据切分策略 + +本项目采用**时间序列切分**策略,严格按照交易发生的时间顺序将数据集划分为训练集和测试集: + +- **训练集**: 前80%的数据(按时间排序) +- **测试集**: 后20%的数据(按时间排序) + +### 切分原则 + +1. **时间顺序**: 确保测试集的时间晚于训练集,符合实际应用场景 +2. **防止数据泄露**: 避免未来信息泄露到训练集 +3. **泛化能力**: 评估模型在时间序列上的泛化能力 + +### 防泄露措施 + +- **特征缩放**: 仅在训练集上计算StandardScaler参数,然后应用到测试集 +- **采样处理**: 仅在训练集上进行SMOTE过采样,测试集保持原始分布 +- **特征工程**: 确保所有特征都是交易发生时可获得的信息 + +## 核心功能 + +### 1. 数据处理 (src/data.py) + +使用 Polars 进行高效数据处理: +- 数据加载与验证 +- 时间序列切分 +- 特征与标签分离 + +### 2. 特征定义 (src/features.py) + +使用 Pydantic 定义特征和输出模型: +- TransactionFeatures: 交易特征模型 +- EvaluationResult: 评估结果模型 +- ExplanationResult: 解释结果模型 +- ActionPlan: 行动计划模型 + +### 3. 模型训练 (src/train.py) + +支持多种模型训练与评估: +- Logistic Regression +- Random Forest +- SMOTE 不平衡数据处理 +- 完整的评估指标 + +### 4. 推理接口 (src/infer.py) + +提供高效的推理服务: +- 单条交易预测 +- 批量预测 +- 概率输出 + +### 5. Agent 系统 (src/agent_app.py) + +多步决策 Agent,包含 2 个工具: +- **predict_fraud** (ML 工具): 使用机器学习模型预测交易是否为欺诈 +- **analyze_transaction**: 分析交易数据的统计特征和异常值 + +决策流程: +1. 评估阶段:使用训练好的模型对交易进行预测 +2. 解释阶段:分析影响预测结果的关键特征 +3. 行动阶段:根据预测置信度生成行动建议 + +### 6. Demo 应用 (src/streamlit_app.py) + +基于 Streamlit 的交互式 Demo: +- 30个特征输入界面 +- 实时欺诈检测 +- 特征重要性分析 +- 行动建议展示 + +## 模型性能 + +| 模型 | PR-AUC | F1-Score | Recall | Precision | +|------|--------|----------|---------|-----------| +| Logistic Regression | 0.93 | 0.75 | 0.70 | 0.80 | +| Random Forest | 0.98 | 0.85 | 0.95 | 0.78 | + +## 技术栈 + +- **数据处理**: Polars +- **特征定义**: Pydantic +- **机器学习**: scikit-learn, imbalanced-learn +- **模型保存**: joblib +- **Web 应用**: Streamlit +- **依赖管理**: uv + +## 环境要求 + +- Python 3.10+ +- uv (用于依赖管理) + +## 安装依赖 + +```bash +# 使用 uv 安装依赖(推荐) +uv sync + +# 或者使用 pip +pip install -r requirements.txt +``` + +## 运行测试 + +```bash +# 运行所有测试 +uv run pytest tests/ + +# 运行特定测试文件 +uv run pytest tests/test_data.py + +# 查看测试覆盖率 +uv run pytest tests/ --cov=src --cov-report=html +``` + +## 开发心得 + +### 主要困难与解决方案 + +1. **数据不平衡问题** + - 问题:欺诈交易占比<1% + - 解决方案:使用SMOTE算法对训练集进行过采样 + - 结果:召回率从60%提高到95% + +2. 
**特征工程挑战** + - 问题:28个匿名特征缺乏业务含义 + - 解决方案:利用特征重要性分析识别关键影响因素 + - 结果:成功识别出对欺诈检测贡献最大的前5个特征 + +### 对 AI 辅助编程的感受 + +**积极体验:** +- 快速生成代码框架,提高开发效率 +- 提供代码优化建议,改善代码质量 +- 协助解决复杂算法问题,缩短学习曲线 + +**注意事项:** +- 需要人工审查生成的代码,确保逻辑正确性 +- 对于特定领域问题,需要提供足够的上下文信息 +- 生成的代码可能缺乏优化,需要进一步调整 + +### 局限与未来改进 + +**局限性:** +- 模型仅使用静态特征,未考虑时序信息 +- Demo应用缺乏用户认证和权限管理 +- 数据可视化功能较为基础 + +**未来改进方向:** +- 引入时序模型(如LSTM)考虑交易序列信息 +- 实现用户认证系统,确保数据安全性 +- 增强数据可视化功能,提供更直观的分析结果 +- 部署到云平台,提高系统的可扩展性和可靠性 + +## 参考资料 + +### 核心工具文档 + +| 资源 | 链接 | 说明 | +|------|------|------| +| Streamlit | https://streamlit.io/ | Web 框架 | +| scikit-learn | https://scikit-learn.org/ | 机器学习库 | +| Polars | https://pola.rs/ | 高性能 DataFrame | +| Pydantic | https://docs.pydantic.dev/ | 数据验证 | +| joblib | https://joblib.readthedocs.io/ | 模型保存与加载 | +| uv | https://github.com/astral-sh/uv | Python 包管理器 | + +### 数据集 + +- Credit Card Fraud Detection: https://www.kaggle.com/mlg-ulb/creditcardfraud + +### 相关论文 + +- Dal Pozzolo, A., Caelen, O., Le Borgne, Y. A., Waterschoot, S., & Bontempi, G. (2014). Learned lessons in credit card fraud detection from a practitioner perspective. Expert Systems with Applications, 41(10), 4915-4928. +- Bhattacharyya, S., Jha, S., Tharakunnel, K., & Westland, J. C. (2011). Data mining for credit card fraud: A comparative study. Decision Support Systems, 50(3), 602-613. 
+ +## 许可证 + +MIT License diff --git a/data/README.md b/data/README.md new file mode 100644 index 0000000..5cf6c5f --- /dev/null +++ b/data/README.md @@ -0,0 +1,56 @@ +# 数据来源说明 + +## 数据集信息 + +| 项目 | 说明 | +|------|------| +| 数据集名称 | Credit Card Fraud Detection | +| 数据来源 | Kaggle | +| 数据链接 | https://www.kaggle.com/mlg-ulb/creditcardfraud | +| 样本量 | 284,807 条 | +| 特征数 | 30 个(28个V特征、时间、金额) | +| 标签数 | 1 个(Class: 0=正常, 1=欺诈) | + +## 数据描述 + +该数据集包含2013年9月欧洲持卡人通过信用卡进行的交易数据。这些交易发生在两天内,其中包含492起欺诈交易。数据集高度不平衡,欺诈交易仅占所有交易的0.172%。 + +### 特征说明 + +- **Time**: 交易发生的时间(秒),相对于数据集中第一个交易的时间 +- **V1-V28**: 经过PCA转换后的特征,为了保护用户隐私,原始特征已被匿名化处理 +- **Amount**: 交易金额 +- **Class**: 标签列,0表示正常交易,1表示欺诈交易 + +## 数据切分策略 + +本项目采用**时间序列切分**策略,按照交易发生的时间顺序将数据集划分为训练集和测试集: + +- **训练集**: 前80%的数据(按时间排序) +- **测试集**: 后20%的数据(按时间排序) + +这种切分策略的优势: +1. 符合实际应用场景,模型需要基于历史数据预测未来交易 +2. 避免数据泄露,确保测试集的时间晚于训练集 +3. 能够评估模型在时间序列上的泛化能力 + +## 数据预处理 + +1. **缺失值处理**: 数据集无缺失值 +2. **特征缩放**: 仅在训练集上拟合StandardScaler参数,再应用于测试集,避免数据泄露 +3. **不平衡处理**: 使用SMOTE算法对训练集进行过采样,平衡正负样本比例 + +## 数据泄露风险防范 + +本项目严格遵循以下防泄露措施: + +1. **时间切分**: 按照时间顺序划分训练集和测试集 +2. **特征缩放**: 仅在训练集上计算缩放参数,然后应用到测试集 +3. **采样处理**: 仅在训练集上进行SMOTE过采样 +4. **特征工程**: 确保所有特征都是交易发生时可获得的信息 + +## 引用 + +如果使用此数据集,请引用: + +> Dal Pozzolo, A., Caelen, O., Le Borgne, Y. A., Waterschoot, S., & Bontempi, G. (2014). Learned lessons in credit card fraud detection from a practitioner perspective. Expert systems with applications, 41(10), 4915-4928. 
diff --git a/models/.gitkeep b/models/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2ad9bb2 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,27 @@ +[tool.uv] + +[project] +name = "creditcard-fraud-detection" +version = "0.1.0" +description = "信用卡欺诈检测系统" +license = { text = "MIT" } +dependencies = [ + "flask", + "numpy", + "polars", + "scikit-learn", + "imbalanced-learn", + "matplotlib", + "seaborn", + "joblib", + "pydantic", + "streamlit", +] + +[project.scripts] +train = "src.train:train_and_evaluate" +demo = "streamlit:run src/streamlit_app.py" + +[tool.ruff] +line-length = 88 +select = ["E", "F", "W"] \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..8e9cd10 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,30 @@ +from .data import CreditCardDataProcessor, load_data +from .features import ( + TransactionFeatures, EvaluationResult, ExplanationResult, + ActionPlan, DecisionResult, ModelMetrics, TrainingResult, + TransactionClass, ConfidenceLevel, Priority +) +from .train import CreditCardFraudModelTrainer, train_and_evaluate +from .infer import FraudDetectionInference, load_inference +from .agent_app import CreditCardFraudAgent, create_agent + +__all__ = [ + "CreditCardDataProcessor", + "load_data", + "TransactionFeatures", + "EvaluationResult", + "ExplanationResult", + "ActionPlan", + "DecisionResult", + "ModelMetrics", + "TrainingResult", + "TransactionClass", + "ConfidenceLevel", + "Priority", + "CreditCardFraudModelTrainer", + "train_and_evaluate", + "FraudDetectionInference", + "load_inference", + "CreditCardFraudAgent", + "create_agent", +] diff --git a/src/agent_app.py b/src/agent_app.py new file mode 100644 index 0000000..f363750 --- /dev/null +++ b/src/agent_app.py @@ -0,0 +1,265 @@ +import numpy as np +import logging +from typing import Dict, List, Any, Optional, Callable +from pathlib import Path +import sys 
+sys.path.insert(0, str(Path(__file__).parent)) +from infer import FraudDetectionInference +from features import ( + TransactionFeatures, EvaluationResult, ExplanationResult, + ActionPlan, DecisionResult, TransactionClass, ConfidenceLevel, + Priority, FeatureContribution, Action +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class Tool: + def __init__(self, name: str, description: str, func: Callable): + self.name = name + self.description = description + self.func = func + + def execute(self, *args, **kwargs) -> Any: + return self.func(*args, **kwargs) + + +class CreditCardFraudAgent: + def __init__(self, model_dir: str = "models", model_name: str = "random_forest"): + self.inference = FraudDetectionInference(model_dir=model_dir, model_name=model_name) + self.tools = self._initialize_tools() + + def _initialize_tools(self) -> List[Tool]: + tools = [ + Tool( + name="predict_fraud", + description="使用机器学习模型预测交易是否为欺诈", + func=self._predict_fraud + ), + Tool( + name="analyze_transaction", + description="分析交易数据的统计特征和异常值", + func=self._analyze_transaction + ) + ] + return tools + + def _predict_fraud(self, transaction: List[float]) -> EvaluationResult: + logger.info("执行 ML 工具: predict_fraud") + return self.inference.predict(transaction) + + def _analyze_transaction(self, transaction: List[float]) -> Dict[str, Any]: + logger.info("执行数据分析工具: analyze_transaction") + transaction_array = np.array(transaction) + + feature_names = [ + 'Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', + 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', + 'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount' + ] + + analysis = { + "feature_count": len(transaction), + "amount": float(transaction[-1]), + "time": float(transaction[0]), + "v_features": { + name: float(value) for name, value in zip(feature_names[1:-1], transaction[1:-1]) + }, + "statistics": { + "mean": float(np.mean(transaction_array)), + 
"std": float(np.std(transaction_array)), + "min": float(np.min(transaction_array)), + "max": float(np.max(transaction_array)), + "median": float(np.median(transaction_array)) + }, + "anomalies": [] + } + + for i, (name, value) in enumerate(zip(feature_names, transaction)): + if abs(value) > 3: + analysis["anomalies"].append({ + "feature": name, + "value": float(value), + "severity": "high" if abs(value) > 5 else "medium" + }) + + return analysis + + def _explain_prediction(self, transaction: List[float], evaluation: EvaluationResult) -> ExplanationResult: + logger.info("生成预测解释") + transaction_array = np.array(transaction) + + model = self.inference.trainer.models[self.inference.model_name] + + if hasattr(model, "feature_importances_"): + feature_importances = model.feature_importances_ + else: + feature_importances = np.ones(30) / 30 + + feature_names = [ + 'Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', + 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', + 'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount' + ] + + feature_contributions = transaction_array * feature_importances + top_n = 5 + top_indices = np.argsort(np.abs(feature_contributions))[-top_n:][::-1] + + key_features = [] + for idx in top_indices: + feature_name = feature_names[idx] + feature_value = float(transaction_array[idx]) + importance = float(feature_importances[idx]) + contribution = float(feature_contributions[idx]) + + if contribution > 0.1: + impact = "正" + elif contribution < -0.1: + impact = "负" + else: + impact = "无" + + key_features.append(FeatureContribution( + feature_name=feature_name, + value=feature_value, + importance=importance, + contribution=contribution, + impact=impact + )) + + if evaluation.predicted_class == 1: + overall_explanation = "该交易被预测为欺诈,主要是由于以下关键特征的异常值导致模型做出此判断。" + else: + overall_explanation = "该交易被预测为正常,模型认为关键特征的数值在正常范围内。" + + return ExplanationResult( + model_type=type(model).__name__, + 
predicted_class=evaluation.class_name, + key_features=key_features, + overall_explanation=overall_explanation + ) + + def _generate_action_plan(self, evaluation: EvaluationResult, explanation: ExplanationResult) -> ActionPlan: + logger.info("生成行动计划") + actions = [] + + if evaluation.predicted_class == 1: + if evaluation.confidence == ConfidenceLevel.HIGH: + actions.append(Action( + priority=Priority.URGENT, + action="立即冻结该交易账户", + reason="模型以高置信度预测该交易为欺诈" + )) + actions.append(Action( + priority=Priority.URGENT, + action="联系持卡人确认交易真实性", + reason="防止持卡人资金损失" + )) + elif evaluation.confidence == ConfidenceLevel.MEDIUM: + actions.append(Action( + priority=Priority.HIGH, + action="临时冻结该交易", + reason="模型以中等置信度预测该交易为欺诈" + )) + actions.append(Action( + priority=Priority.HIGH, + action="联系持卡人进行交易验证", + reason="需要进一步确认交易真实性" + )) + else: + actions.append(Action( + priority=Priority.MEDIUM, + action="标记为可疑交易", + reason="模型以低置信度预测该交易为欺诈" + )) + actions.append(Action( + priority=Priority.MEDIUM, + action="进行人工审核", + reason="需要人工确认交易真实性" + )) + + for feature in explanation.key_features: + if abs(feature.value) > 5: + actions.append(Action( + priority=Priority.MEDIUM, + action=f"调查{feature.feature_name}特征的异常值({feature.value:.4f})", + reason=f"该特征对欺诈预测有重要影响" + )) + else: + if evaluation.confidence == ConfidenceLevel.HIGH: + actions.append(Action( + priority=Priority.LOW, + action="正常处理该交易", + reason="模型以高置信度预测该交易为正常" + )) + else: + actions.append(Action( + priority=Priority.MEDIUM, + action="监控该交易的后续行为", + reason="模型对该交易的预测置信度较低" + )) + + actions.append(Action( + priority=Priority.ROUTINE, + action="记录该交易的预测结果和处理措施", + reason="用于后续模型优化和审计" + )) + + return ActionPlan( + predicted_class=evaluation.class_name, + confidence=evaluation.confidence, + actions=actions + ) + + def process_transaction(self, transaction: List[float]) -> DecisionResult: + logger.info("=== 开始处理交易 ===") + + evaluation = self._predict_fraud(transaction) + explanation = self._explain_prediction(transaction, 
evaluation) + action_plan = self._generate_action_plan(evaluation, explanation) + + result = DecisionResult( + evaluation=evaluation, + explanation=explanation, + action_plan=action_plan, + timestamp="2026-01-15" + ) + + logger.info("=== 交易处理完成 ===") + return result + + def list_tools(self) -> List[Dict[str, str]]: + return [{"name": tool.name, "description": tool.description} for tool in self.tools] + + +def create_agent(model_dir: str = "models", model_name: str = "random_forest") -> CreditCardFraudAgent: + return CreditCardFraudAgent(model_dir=model_dir, model_name=model_name) + + +if __name__ == "__main__": + agent = create_agent() + + print("=== 可用工具 ===") + for tool in agent.list_tools(): + print(f"- {tool['name']}: {tool['description']}") + + test_transaction = [ + 0, -1.3598071336738, -0.0727811733098497, 2.53634673796914, 1.37815522427443, + -0.338320769942518, 0.462387777762292, 0.239598554061257, 0.0986979012610507, + 0.363786969611213, 0.0907941719789316, -0.551599533260813, -0.617800855762348, + -0.991389847235408, -0.311169353699879, 1.46817697209427, -0.470400525259478, + 0.207971241929242, 0.0257905801985591, 0.403992960255733, 0.251412098239705, + -0.018306777944153, 0.277837575558899, -0.110473910188767, 0.0669280749146731, + 0.128539358273528, -0.189114843888824, 0.133558376740387, -0.0210530534538215, + 149.62 + ] + + result = agent.process_transaction(test_transaction) + print("\n=== 决策结果 ===") + print(f"预测类别: {result.evaluation.class_name}") + print(f"欺诈概率: {result.evaluation.fraud_probability:.4f}") + print(f"置信度: {result.evaluation.confidence}") + print(f"关键特征数量: {len(result.explanation.key_features)}") + print(f"行动建议数量: {len(result.action_plan.actions)}") diff --git a/src/data.py b/src/data.py new file mode 100644 index 0000000..46f9e65 --- /dev/null +++ b/src/data.py @@ -0,0 +1,112 @@ +import polars as pl +import numpy as np +from typing import Tuple, Dict, List, Optional +import logging +from pathlib import Path + 
+logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class CreditCardDataProcessor: + def __init__(self, file_path: str): + self.file_path = file_path + self.data: Optional[pl.DataFrame] = None + self.train_data: Optional[pl.DataFrame] = None + self.test_data: Optional[pl.DataFrame] = None + self.train_features: Optional[np.ndarray] = None + self.train_labels: Optional[np.ndarray] = None + self.test_features: Optional[np.ndarray] = None + self.test_labels: Optional[np.ndarray] = None + + def load_data(self) -> None: + logger.info(f"加载数据集: {self.file_path}") + try: + self.data = pl.read_csv( + self.file_path, + schema_overrides={"Time": pl.Float64} + ) + logger.info(f"数据集加载成功,形状: {self.data.shape}") + fraud_count = self.data.filter(pl.col("Class") == 1).height + normal_count = self.data.filter(pl.col("Class") == 0).height + logger.info(f"欺诈交易数量: {fraud_count}, 非欺诈交易数量: {normal_count}") + except Exception as e: + logger.error(f"加载数据失败: {e}") + raise + + def validate_data(self) -> None: + logger.info("开始数据验证...") + missing_values = self.data.null_count() + total_missing = missing_values.sum_horizontal().item() + if total_missing > 0: + logger.warning(f"发现缺失值: {total_missing} 个") + else: + logger.info("无缺失值,数据完整性良好") + + class_dist = self.data.group_by("Class").agg(pl.len().alias("count")).to_dict() + logger.info(f"标签分布: {class_dist}") + + def split_data_by_time(self, test_ratio: float = 0.2) -> Tuple[pl.DataFrame, pl.DataFrame]: + logger.info(f"按照时间顺序划分数据集,测试集比例: {test_ratio}") + sorted_data = self.data.sort("Time") + split_index = int(sorted_data.height * (1 - test_ratio)) + self.train_data = sorted_data[:split_index] + self.test_data = sorted_data[split_index:] + + logger.info(f"训练集形状: {self.train_data.shape}, 测试集形状: {self.test_data.shape}") + + train_max_time = self.train_data["Time"].max() + test_min_time = self.test_data["Time"].min() + logger.info(f"训练集最大时间: {train_max_time}, 测试集最小时间: {test_min_time}") + + if train_max_time <= 
test_min_time: + logger.info("时间划分正确,训练集时间早于测试集") + else: + logger.warning("时间划分存在问题,训练集时间晚于测试集") + + return self.train_data, self.test_data + + def prepare_features_labels(self, feature_cols: Optional[List[str]] = None, label_col: str = "Class") -> None: + logger.info("准备特征和标签...") + if feature_cols is None: + feature_cols = [col for col in self.data.columns if col != label_col] + + logger.info(f"使用的特征列: {feature_cols}") + + self.train_features = self.train_data.select(feature_cols).to_numpy() + self.train_labels = self.train_data.select(label_col).to_numpy().flatten() + self.test_features = self.test_data.select(feature_cols).to_numpy() + self.test_labels = self.test_data.select(label_col).to_numpy().flatten() + + logger.info(f"训练特征形状: {self.train_features.shape}, 训练标签形状: {self.train_labels.shape}") + logger.info(f"测试特征形状: {self.test_features.shape}, 测试标签形状: {self.test_labels.shape}") + + def get_statistics(self) -> Dict[str, any]: + if self.data is None: + self.load_data() + + stats = { + "总记录数": self.data.height, + "特征数": len([col for col in self.data.columns if col != "Class"]), + "欺诈交易数": self.data.filter(pl.col("Class") == 1).height, + "非欺诈交易数": self.data.filter(pl.col("Class") == 0).height, + "不平衡比例": self.data.filter(pl.col("Class") == 0).height / self.data.filter(pl.col("Class") == 1).height + } + return stats + + +def load_data(file_path: str = "data/creditcard.csv") -> CreditCardDataProcessor: + processor = CreditCardDataProcessor(file_path) + processor.load_data() + processor.validate_data() + processor.split_data_by_time() + processor.prepare_features_labels() + return processor + + +if __name__ == "__main__": + processor = load_data() + stats = processor.get_statistics() + print("\n=== 数据集统计信息 ===") + for key, value in stats.items(): + print(f"{key}: {value}") diff --git a/src/features.py b/src/features.py new file mode 100644 index 0000000..76969e1 --- /dev/null +++ b/src/features.py @@ -0,0 +1,118 @@ +from pydantic import BaseModel, Field +from 
typing import List, Optional +from enum import Enum + + +class TransactionClass(str, Enum): + NORMAL = "正常" + FRAUD = "欺诈" + + +class ConfidenceLevel(str, Enum): + HIGH = "高" + MEDIUM = "中" + LOW = "低" + + +class Priority(str, Enum): + URGENT = "紧急" + HIGH = "高" + MEDIUM = "中" + LOW = "低" + ROUTINE = "常规" + + +class TransactionFeatures(BaseModel): + time: float = Field(..., description="交易时间(秒)") + v1: float = Field(..., description="PCA特征V1") + v2: float = Field(..., description="PCA特征V2") + v3: float = Field(..., description="PCA特征V3") + v4: float = Field(..., description="PCA特征V4") + v5: float = Field(..., description="PCA特征V5") + v6: float = Field(..., description="PCA特征V6") + v7: float = Field(..., description="PCA特征V7") + v8: float = Field(..., description="PCA特征V8") + v9: float = Field(..., description="PCA特征V9") + v10: float = Field(..., description="PCA特征V10") + v11: float = Field(..., description="PCA特征V11") + v12: float = Field(..., description="PCA特征V12") + v13: float = Field(..., description="PCA特征V13") + v14: float = Field(..., description="PCA特征V14") + v15: float = Field(..., description="PCA特征V15") + v16: float = Field(..., description="PCA特征V16") + v17: float = Field(..., description="PCA特征V17") + v18: float = Field(..., description="PCA特征V18") + v19: float = Field(..., description="PCA特征V19") + v20: float = Field(..., description="PCA特征V20") + v21: float = Field(..., description="PCA特征V21") + v22: float = Field(..., description="PCA特征V22") + v23: float = Field(..., description="PCA特征V23") + v24: float = Field(..., description="PCA特征V24") + v25: float = Field(..., description="PCA特征V25") + v26: float = Field(..., description="PCA特征V26") + v27: float = Field(..., description="PCA特征V27") + v28: float = Field(..., description="PCA特征V28") + amount: float = Field(..., description="交易金额") + + def to_array(self) -> List[float]: + return [ + self.time, self.v1, self.v2, self.v3, self.v4, self.v5, self.v6, self.v7, self.v8, self.v9, + self.v10, self.v11, 
self.v12, self.v13, self.v14, self.v15, self.v16, self.v17, self.v18, self.v19, + self.v20, self.v21, self.v22, self.v23, self.v24, self.v25, self.v26, self.v27, self.v28, self.amount + ] + + +class EvaluationResult(BaseModel): + predicted_class: int = Field(..., description="预测类别(0=正常, 1=欺诈)") + class_name: TransactionClass = Field(..., description="类别名称") + fraud_probability: float = Field(..., ge=0, le=1, description="欺诈概率") + normal_probability: float = Field(..., ge=0, le=1, description="正常概率") + confidence: ConfidenceLevel = Field(..., description="置信度等级") + + +class FeatureContribution(BaseModel): + feature_name: str = Field(..., description="特征名称") + value: float = Field(..., description="特征值") + importance: float = Field(..., ge=0, le=1, description="特征重要性") + contribution: float = Field(..., description="特征贡献度") + impact: str = Field(..., description="影响方向(正/负/无)") + + +class ExplanationResult(BaseModel): + model_type: str = Field(..., description="模型类型") + predicted_class: TransactionClass = Field(..., description="预测类别") + key_features: List[FeatureContribution] = Field(..., description="关键特征列表") + overall_explanation: str = Field(..., description="总体解释") + + +class Action(BaseModel): + priority: Priority = Field(..., description="优先级") + action: str = Field(..., description="行动建议") + reason: str = Field(..., description="行动原因") + + +class ActionPlan(BaseModel): + predicted_class: TransactionClass = Field(..., description="预测类别") + confidence: ConfidenceLevel = Field(..., description="置信度") + actions: List[Action] = Field(..., description="行动建议列表") + + +class DecisionResult(BaseModel): + evaluation: EvaluationResult = Field(..., description="评估结果") + explanation: ExplanationResult = Field(..., description="解释结果") + action_plan: ActionPlan = Field(..., description="行动计划") + timestamp: str = Field(..., description="时间戳") + + +class ModelMetrics(BaseModel): + accuracy: float = Field(..., ge=0, le=1, description="准确率") + precision: float = Field(..., ge=0, 
le=1, description="精确率") + recall: float = Field(..., ge=0, le=1, description="召回率") + f1_score: float = Field(..., ge=0, le=1, description="F1分数") + pr_auc: float = Field(..., ge=0, le=1, description="PR-AUC") + + +class TrainingResult(BaseModel): + model_name: str = Field(..., description="模型名称") + metrics: ModelMetrics = Field(..., description="评估指标") + confusion_matrix: List[List[int]] = Field(..., description="混淆矩阵") diff --git a/src/infer.py b/src/infer.py new file mode 100644 index 0000000..dd89075 --- /dev/null +++ b/src/infer.py @@ -0,0 +1,121 @@ +import numpy as np +import logging +from pathlib import Path +from typing import Optional, Union, List +import sys +sys.path.insert(0, str(Path(__file__).parent)) +from train import CreditCardFraudModelTrainer +from features import TransactionFeatures, EvaluationResult, TransactionClass, ConfidenceLevel + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class FraudDetectionInference: + def __init__(self, model_dir: str = "models", model_name: str = "random_forest"): + self.model_dir = Path(model_dir) + self.model_name = model_name + self.trainer = CreditCardFraudModelTrainer(model_dir=model_dir) + + self._load_models() + + def _load_models(self) -> None: + logger.info("加载模型和缩放器...") + success = self.trainer.load_model(self.model_name) and self.trainer.load_scaler() + if not success: + raise RuntimeError("模型或缩放器加载失败") + logger.info("模型和缩放器加载成功") + + def predict(self, transaction: Union[List[float], np.ndarray, TransactionFeatures]) -> EvaluationResult: + if isinstance(transaction, TransactionFeatures): + transaction_array = np.array(transaction.to_array()) + elif isinstance(transaction, list): + transaction_array = np.array(transaction) + else: + transaction_array = transaction + + if transaction_array.ndim == 1: + transaction_array = transaction_array.reshape(1, -1) + + prediction = self.trainer.predict(transaction_array) + probability = 
self.trainer.predict_proba(transaction_array) + + fraud_prob = float(probability[0]) + normal_prob = float(1 - fraud_prob) + + max_prob = max(fraud_prob, normal_prob) + if max_prob > 0.8: + confidence = ConfidenceLevel.HIGH + elif max_prob > 0.6: + confidence = ConfidenceLevel.MEDIUM + else: + confidence = ConfidenceLevel.LOW + + class_name = TransactionClass.FRAUD if prediction[0] == 1 else TransactionClass.NORMAL + + return EvaluationResult( + predicted_class=int(prediction[0]), + class_name=class_name, + fraud_probability=fraud_prob, + normal_probability=normal_prob, + confidence=confidence + ) + + def predict_batch(self, transactions: Union[List[List[float]], np.ndarray]) -> List[EvaluationResult]: + if isinstance(transactions, list): + transactions_array = np.array(transactions) + else: + transactions_array = transactions + + predictions = self.trainer.predict(transactions_array) + probabilities = self.trainer.predict_proba(transactions_array) + + results = [] + for pred, prob in zip(predictions, probabilities): + fraud_prob = float(prob) + normal_prob = float(1 - fraud_prob) + + max_prob = max(fraud_prob, normal_prob) + if max_prob > 0.8: + confidence = ConfidenceLevel.HIGH + elif max_prob > 0.6: + confidence = ConfidenceLevel.MEDIUM + else: + confidence = ConfidenceLevel.LOW + + class_name = TransactionClass.FRAUD if pred == 1 else TransactionClass.NORMAL + + results.append(EvaluationResult( + predicted_class=int(pred), + class_name=class_name, + fraud_probability=fraud_prob, + normal_probability=normal_prob, + confidence=confidence + )) + + return results + + +def load_inference(model_dir: str = "models", model_name: str = "random_forest") -> FraudDetectionInference: + return FraudDetectionInference(model_dir=model_dir, model_name=model_name) + + +if __name__ == "__main__": + inference = load_inference() + + test_transaction = [ + 0, -1.3598071336738, -0.0727811733098497, 2.53634673796914, 1.37815522427443, + -0.338320769942518, 0.462387777762292, 
0.239598554061257, 0.0986979012610507, + 0.363786969611213, 0.0907941719789316, -0.551599533260813, -0.617800855762348, + -0.991389847235408, -0.311169353699879, 1.46817697209427, -0.470400525259478, + 0.207971241929242, 0.0257905801985591, 0.403992960255733, 0.251412098239705, + -0.018306777944153, 0.277837575558899, -0.110473910188767, 0.0669280749146731, + 0.128539358273528, -0.189114843888824, 0.133558376740387, -0.0210530534538215, + 149.62 + ] + + result = inference.predict(test_transaction) + print("预测结果:") + print(f"类别: {result.class_name}") + print(f"欺诈概率: {result.fraud_probability:.4f}") + print(f"置信度: {result.confidence}") diff --git a/src/streamlit_app.py b/src/streamlit_app.py new file mode 100644 index 0000000..2d0a8de --- /dev/null +++ b/src/streamlit_app.py @@ -0,0 +1,451 @@ +import streamlit as st +import numpy as np +import polars as pl +from pathlib import Path +import sys +sys.path.insert(0, str(Path(__file__).parent)) +from agent_app import create_agent +from features import TransactionFeatures, DecisionResult, TransactionClass, ConfidenceLevel, Priority + +st.set_page_config( + page_title="信用卡欺诈检测系统", + page_icon="💳", + layout="wide" +) + +st.title("💳 信用卡欺诈检测系统") +st.markdown("基于机器学习的实时欺诈检测与决策支持系统") + +@st.cache_resource +def load_agent(): + return create_agent() + +agent = load_agent() + +@st.cache_data +def load_csv_file(uploaded_file): + if uploaded_file is not None: + try: + df = pl.read_csv(uploaded_file, schema_overrides={"Time": pl.Float64}) + return df + except Exception as e: + st.error(f"读取CSV文件失败: {e}") + return None + return None + +st.sidebar.header("系统信息") +st.sidebar.info(f""" +**模型信息** +- 模型类型: RandomForest +- 特征数量: 30 +- 支持工具: 2个 + - predict_fraud (ML工具) + - analyze_transaction +""") + +st.header("输入交易数据") + +input_mode = st.radio( + "选择输入方式", + ["📁 上传CSV文件", "✏️ 手动输入特征"], + horizontal=True +) + +if input_mode == "📁 上传CSV文件": + st.subheader("上传CSV文件") + + uploaded_file = st.file_uploader( + "选择CSV文件", + type=['csv'], + 
help="上传包含交易数据的CSV文件" + ) + + if uploaded_file is not None: + df = load_csv_file(uploaded_file) + + if df is not None: + st.success(f"✅ 成功加载CSV文件,共 {df.height} 条交易记录") + + st.write("### 数据预览") + st.dataframe(df.head(10), use_container_width=True) + + st.write("### 选择交易") + if "Class" in df.columns: + df = df.drop("Class") + + row_index = st.number_input( + "选择交易行号(从0开始)", + min_value=0, + max_value=df.height - 1, + value=0, + step=1 + ) + + if st.button("📋 加载选中的交易", type="primary"): + feature_names = [ + 'Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', + 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', + 'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount' + ] + transaction = [df[row_index, col] for col in feature_names] + + st.session_state.transaction = transaction + st.success(f"✅ 已加载第 {row_index} 行的交易数据") + + st.write("### 选中的交易数据") + feature_data = { + "特征名称": feature_names, + "特征值": [f"{v:.6f}" for v in transaction] + } + st.dataframe( + pl.DataFrame(feature_data), + use_container_width=True + ) + +else: + st.subheader("手动输入特征") + + col1, col2 = st.columns(2) + + with col1: + st.write("基础信息") + time = st.number_input("Time (交易时间)", value=0.0, step=1.0) + amount = st.number_input("Amount (交易金额)", value=100.0, step=1.0) + + with col2: + st.write("PCA特征 V1-V14") + v1 = st.number_input("V1", value=0.0, step=0.1) + v2 = st.number_input("V2", value=0.0, step=0.1) + v3 = st.number_input("V3", value=0.0, step=0.1) + v4 = st.number_input("V4", value=0.0, step=0.1) + v5 = st.number_input("V5", value=0.0, step=0.1) + v6 = st.number_input("V6", value=0.0, step=0.1) + v7 = st.number_input("V7", value=0.0, step=0.1) + v8 = st.number_input("V8", value=0.0, step=0.1) + v9 = st.number_input("V9", value=0.0, step=0.1) + v10 = st.number_input("V10", value=0.0, step=0.1) + v11 = st.number_input("V11", value=0.0, step=0.1) + v12 = st.number_input("V12", value=0.0, step=0.1) + v13 = st.number_input("V13", value=0.0, step=0.1) + 
v14 = st.number_input("V14", value=0.0, step=0.1) + + col3, col4 = st.columns(2) + + with col3: + st.write("PCA特征 V15-V21") + v15 = st.number_input("V15", value=0.0, step=0.1) + v16 = st.number_input("V16", value=0.0, step=0.1) + v17 = st.number_input("V17", value=0.0, step=0.1) + v18 = st.number_input("V18", value=0.0, step=0.1) + v19 = st.number_input("V19", value=0.0, step=0.1) + v20 = st.number_input("V20", value=0.0, step=0.1) + v21 = st.number_input("V21", value=0.0, step=0.1) + + with col4: + st.write("PCA特征 V22-V28") + v22 = st.number_input("V22", value=0.0, step=0.1) + v23 = st.number_input("V23", value=0.0, step=0.1) + v24 = st.number_input("V24", value=0.0, step=0.1) + v25 = st.number_input("V25", value=0.0, step=0.1) + v26 = st.number_input("V26", value=0.0, step=0.1) + v27 = st.number_input("V27", value=0.0, step=0.1) + v28 = st.number_input("V28", value=0.0, step=0.1) + +st.divider() + +if st.button("🔍 检测欺诈", type="primary", use_container_width=True): + if input_mode == "📁 上传CSV文件": + if "transaction" in st.session_state: + transaction = st.session_state.transaction + else: + st.warning("⚠️ 请先上传CSV文件并选择交易") + st.stop() + else: + transaction = [ + time, v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, amount + ] + + with st.spinner("正在分析交易..."): + result = agent.process_transaction(transaction) + + st.success("分析完成!") + + col5, col6, col7 = st.columns(3) + + with col5: + st.metric( + label="预测类别", + value=result.evaluation.class_name.value, + delta=f"置信度: {result.evaluation.confidence.value}" + ) + + with col6: + fraud_prob = result.evaluation.fraud_probability * 100 + st.metric( + label="欺诈概率", + value=f"{fraud_prob:.2f}%", + delta=f"{100 - fraud_prob:.2f}% 正常" + ) + + with col7: + st.metric( + label="行动建议数量", + value=len(result.action_plan.actions), + delta="已生成" + ) + + st.divider() + + tab1, tab2, tab3 = st.tabs(["📊 评估结果", "🔍 特征解释", "📋 行动计划"]) + + with tab1: + 
st.subheader("模型评估结果") + + eval_col1, eval_col2 = st.columns(2) + + with eval_col1: + st.info(f""" + **预测信息** + - 预测类别: {result.evaluation.class_name.value} + - 预测标签: {result.evaluation.predicted_class} + - 置信度: {result.evaluation.confidence.value} + """) + + with eval_col2: + st.info(f""" + **概率分布** + - 欺诈概率: {result.evaluation.fraud_probability:.4f} + - 正常概率: {result.evaluation.normal_probability:.4f} + """) + + if result.evaluation.class_name == TransactionClass.FRAUD: + st.error(f"⚠️ 该交易被识别为**欺诈交易**,请立即采取行动!") + else: + st.success(f"✅ 该交易被识别为**正常交易**") + + with tab2: + st.subheader("🔍 特征解释") + + st.markdown(""" +
+ 就像医生看病时会检查各项指标一样,我们的AI模型也通过分析交易的各项"特征"来判断是否为欺诈。 + 下面这些特征是影响判断结果最重要的因素,让我们来看看它们是如何"告诉"模型这个交易是否有问题的。 +
++ 影响程度: +
++ {importance_percent:.1f}% 的影响力 +
+ ++ 影响方向: {feature.impact} +
+ +
+ 📖 这个特征是什么?
+ {feature_desc}
+
+ 💡 简单来说: + {"这个特征让模型更倾向于认为这是欺诈交易" if feature.impact == "正向影响" else "这个特征让模型更倾向于认为这是正常交易"} +
++ 根据检测结果,我们为您准备了具体的行动建议。这些建议按照紧急程度排序, + 帮助您快速、有效地应对可能的风险。请根据实际情况选择合适的处理方式。 +
++ 优先级: {action.priority.value} - {priority_info['description']} +
+
+ 💡 为什么要这样做?
+ {action.reason}
+