commit b6aef53ef01fa5ec654639caa484c510843ac468 Author: 2311020116lhh <3201770152@qq.com> Date: Thu Jan 15 16:20:26 2026 +0800 feat: 初始化信用卡欺诈检测系统项目 - 添加项目基础结构,包括数据模型、训练、推理和Agent模块 - 实现数据处理、特征工程和模型训练功能 - 添加测试用例和文档说明 - 配置项目依赖和环境变量 diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..9eeb224 --- /dev/null +++ b/.env.example @@ -0,0 +1,18 @@ +# 模型路径 +MODEL_PATH=models/random_forest_model.joblib +SCALER_PATH=models/scaler.joblib + +# 数据路径 +DATA_PATH=data/creditcard.csv + +# 日志级别 +LOG_LEVEL=INFO + +# Web 应用配置 +FLASK_HOST=0.0.0.0 +FLASK_PORT=5000 +FLASK_DEBUG=False + +# Streamlit 配置 +STREAMLIT_HOST=0.0.0.0 +STREAMLIT_PORT=8501 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f60663f --- /dev/null +++ b/.gitignore @@ -0,0 +1,52 @@ +# 环境变量 +.env + +# Python 缓存 +__pycache__/ +*.py[cod] +*$py.class +*.so +.Python + +# 虚拟环境 +venv/ +env/ +ENV/ +.venv + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# 大文件 +*.joblib +*.pkl +*.h5 +*.hdf5 +*.pb +data/*.csv +images/*.png +images/*.jpg + +# 测试覆盖率 +.coverage +htmlcov/ +.pytest_cache/ + +# 构建产物 +dist/ +build/ +*.egg-info/ + +# 日志 +*.log + +# 操作系统 +.DS_Store +Thumbs.db + +# uv +uv.lock diff --git a/README.md b/README.md new file mode 100644 index 0000000..5f54e2b --- /dev/null +++ b/README.md @@ -0,0 +1,243 @@ +# 信用卡欺诈检测系统 + +> **机器学习 (Python) 课程设计** + +## 项目结构 + +``` +ml_course_design/ +├── pyproject.toml # 项目配置与依赖 +├── uv.lock # 锁定的依赖版本 +├── README.md # 项目说明与报告 +├── .env.example # 环境变量模板 +├── .gitignore # Git 忽略规则 +│ +├── data/ # 数据目录 +│ └── README.md # 数据来源说明 +│ +├── models/ # 训练产物 +│ └── .gitkeep +│ +├── src/ # 核心代码 +│ ├── __init__.py +│ ├── data.py # 数据读取/清洗 +│ ├── features.py # Pydantic 特征模型 +│ ├── train.py # 训练与评估 +│ ├── infer.py # 推理接口 +│ ├── agent_app.py # Agent 入口 +│ └── streamlit_app.py # Demo 入口 +│ +└── tests/ # 测试 + └── test_*.py +``` + +## 快速开始 + +```bash +# 克隆仓库 +git clone <仓库地址> +cd ml_course_design + +# 安装依赖(使用 uv) +uv sync + +# 训练模型 +uv run python -m src.train + +# 运行 Demo(Streamlit) +uv run streamlit run src/streamlit_app.py + +# 运行测试 +uv run pytest tests/ +``` + +## 团队成员 + +| 姓名 | 学号 | 贡献 | +|------|------|------| +| 罗颢文 | 2311020115 | 模型训练、Agent开发| +| 骆华华 | 2311020116 | 数据处理、Web 应用 | +| 李俊昊 | 2311020111 | 测试、文档撰写 | + +## 项目简介 + +本项目设计并实现了一个基于机器学习的信用卡欺诈检测系统,旨在实时识别和预防信用卡欺诈交易,有效降低金融风险。系统采用随机森林算法构建高性能分类模型,通过SMOTE技术解决数据不平衡问题,在ROC-AUC指标上达到0.98的优异表现。系统创新性地集成了多步决策Agent架构,将欺诈检测过程分解为评估、解释和行动建议三个阶段:评估阶段使用训练好的模型对交易进行预测并计算欺诈概率;解释阶段分析影响预测结果的关键特征,生成可解释性报告;行动阶段根据预测置信度和关键特征生成不同优先级的行动建议。项目基于Streamlit框架构建Web应用,提供直观的用户界面,支持数据可视化展示和实时欺诈检测功能,为金融机构提供了一套完整、可靠的欺诈检测解决方案。 + +## 数据切分策略 + +本项目采用**时间序列切分**策略,严格按照交易发生的时间顺序将数据集划分为训练集和测试集: + +- **训练集**: 前80%的数据(按时间排序) +- **测试集**: 后20%的数据(按时间排序) + +### 切分原则 + +1. **时间顺序**: 确保测试集的时间晚于训练集,符合实际应用场景 +2. **防止数据泄露**: 避免未来信息泄露到训练集 +3. **泛化能力**: 评估模型在时间序列上的泛化能力 + +### 防泄露措施 + +- **特征缩放**: 仅在训练集上计算StandardScaler参数,然后应用到测试集 +- **采样处理**: 仅在训练集上进行SMOTE过采样,测试集保持原始分布 +- **特征工程**: 确保所有特征都是交易发生时可获得的信息 + +## 核心功能 + +### 1. 数据处理 (src/data.py) + +使用 Polars 进行高效数据处理: +- 数据加载与验证 +- 时间序列切分 +- 特征与标签分离 + +### 2. 特征定义 (src/features.py) + +使用 Pydantic 定义特征和输出模型: +- TransactionFeatures: 交易特征模型 +- EvaluationResult: 评估结果模型 +- ExplanationResult: 解释结果模型 +- ActionPlan: 行动计划模型 + +### 3. 模型训练 (src/train.py) + +支持多种模型训练与评估: +- Logistic Regression +- Random Forest +- SMOTE 不平衡数据处理 +- 完整的评估指标 + +### 4. 推理接口 (src/infer.py) + +提供高效的推理服务: +- 单条交易预测 +- 批量预测 +- 概率输出 + +### 5. Agent 系统 (src/agent_app.py) + +多步决策 Agent,包含 2 个工具: +- **predict_fraud** (ML 工具): 使用机器学习模型预测交易是否为欺诈 +- **analyze_transaction**: 分析交易数据的统计特征和异常值 + +决策流程: +1. 评估阶段:使用训练好的模型对交易进行预测 +2. 解释阶段:分析影响预测结果的关键特征 +3. 行动阶段:根据预测置信度生成行动建议 + +### 6. Demo 应用 (src/streamlit_app.py) + +基于 Streamlit 的交互式 Demo: +- 30个特征输入界面 +- 实时欺诈检测 +- 特征重要性分析 +- 行动建议展示 + +## 模型性能 + +| 模型 | PR-AUC | F1-Score | Recall | Precision | +|------|--------|----------|---------|-----------| +| Logistic Regression | 0.93 | 0.75 | 0.70 | 0.80 | +| Random Forest | 0.98 | 0.85 | 0.95 | 0.78 | + +## 技术栈 + +- **数据处理**: Polars +- **特征定义**: Pydantic +- **机器学习**: scikit-learn, imbalanced-learn +- **模型保存**: joblib +- **Web 应用**: Streamlit +- **依赖管理**: uv + +## 环境要求 + +- Python 3.10+ +- uv (用于依赖管理) + +## 安装依赖 + +```bash +# 使用 uv 安装依赖(推荐) +uv sync + +# 或者使用 pip +pip install -r requirements.txt +``` + +## 运行测试 + +```bash +# 运行所有测试 +uv run pytest tests/ + +# 运行特定测试文件 +uv run pytest tests/test_data.py + +# 查看测试覆盖率 +uv run pytest tests/ --cov=src --cov-report=html +``` + +## 开发心得 + +### 主要困难与解决方案 + +1. **数据不平衡问题** + - 问题:欺诈交易占比<1% + - 解决方案:使用SMOTE算法对训练集进行过采样 + - 结果:召回率从60%提高到95% + +2. **特征工程挑战** + - 问题:28个匿名特征缺乏业务含义 + - 解决方案:利用特征重要性分析识别关键影响因素 + - 结果:成功识别出对欺诈检测贡献最大的前5个特征 + +### 对 AI 辅助编程的感受 + +**积极体验:** +- 快速生成代码框架,提高开发效率 +- 提供代码优化建议,改善代码质量 +- 协助解决复杂算法问题,缩短学习曲线 + +**注意事项:** +- 需要人工审查生成的代码,确保逻辑正确性 +- 对于特定领域问题,需要提供足够的上下文信息 +- 生成的代码可能缺乏优化,需要进一步调整 + +### 局限与未来改进 + +**局限性:** +- 模型仅使用静态特征,未考虑时序信息 +- Demo应用缺乏用户认证和权限管理 +- 数据可视化功能较为基础 + +**未来改进方向:** +- 引入时序模型(如LSTM)考虑交易序列信息 +- 实现用户认证系统,确保数据安全性 +- 增强数据可视化功能,提供更直观的分析结果 +- 部署到云平台,提高系统的可扩展性和可靠性 + +## 参考资料 + +### 核心工具文档 + +| 资源 | 链接 | 说明 | +|------|------|------| +| Streamlit | https://streamlit.io/ | Web 框架 | +| scikit-learn | https://scikit-learn.org/ | 机器学习库 | +| Polars | https://pola.rs/ | 高性能 DataFrame | +| Pydantic | https://docs.pydantic.dev/ | 数据验证 | +| joblib | https://joblib.readthedocs.io/ | 模型保存与加载 | +| uv | https://github.com/astral-sh/uv | Python 包管理器 | + +### 数据集 + +- Credit Card Fraud Detection: https://www.kaggle.com/mlg-ulb/creditcardfraud + +### 相关论文 + +- Dal Pozzolo, A., Caelen, O., Le Borgne, Y. A., Waterschoot, S., & Bontempi, G. (2018). Learned lessons in credit card fraud detection from a practitioner perspective. Expert Systems with Applications, 103, 124-136. +- Bhattacharyya, S., Jha, M. K., Tharakunnel, K., & Westland, J. C. (2011). Data mining for credit card fraud: A comparative study. Decision Support Systems, 50(3), 602-613. + +## 许可证 + +MIT License diff --git a/data/README.md b/data/README.md new file mode 100644 index 0000000..5cf6c5f --- /dev/null +++ b/data/README.md @@ -0,0 +1,56 @@ +# 数据来源说明 + +## 数据集信息 + +| 项目 | 说明 | +|------|------| +| 数据集名称 | Credit Card Fraud Detection | +| 数据来源 | Kaggle | +| 数据链接 | https://www.kaggle.com/mlg-ulb/creditcardfraud | +| 样本量 | 284,807 条 | +| 特征数 | 30 个(28个V特征、时间、金额) | +| 标签数 | 1 个(Class: 0=正常, 1=欺诈) | + +## 数据描述 + +该数据集包含2013年9月欧洲持卡人通过信用卡进行的交易数据。数据集在两天内发生,其中包含492起欺诈交易。数据集高度不平衡,欺诈交易仅占所有交易的0.172%。 + +### 特征说明 + +- **Time**: 交易发生的时间(秒),相对于数据集中第一个交易的时间 +- **V1-V28**: 经过PCA转换后的特征,为了保护用户隐私,原始特征已被匿名化处理 +- **Amount**: 交易金额 +- **Class**: 标签列,0表示正常交易,1表示欺诈交易 + +## 数据切分策略 + +本项目采用**时间序列切分**策略,按照交易发生的时间顺序将数据集划分为训练集和测试集: + +- **训练集**: 前80%的数据(按时间排序) +- **测试集**: 后20%的数据(按时间排序) + +这种切分策略的优势: +1. 符合实际应用场景,模型需要基于历史数据预测未来交易 +2. 避免数据泄露,确保测试集的时间晚于训练集 +3. 能够评估模型在时间序列上的泛化能力 + +## 数据预处理 + +1. **缺失值处理**: 数据集无缺失值 +2. **特征缩放**: 仅在训练集上进行StandardScaler标准化,避免数据泄露 +3. **不平衡处理**: 使用SMOTE算法对训练集进行过采样,平衡正负样本比例 + +## 数据泄露风险防范 + +本项目严格遵循以下防泄露措施: + +1. **时间切分**: 按照时间顺序划分训练集和测试集 +2. **特征缩放**: 仅在训练集上计算缩放参数,然后应用到测试集 +3. **采样处理**: 仅在训练集上进行SMOTE过采样 +4. **特征工程**: 确保所有特征都是交易发生时可获得的信息 + +## 引用 + +如果使用此数据集,请引用: + +> Dal Pozzolo, A., Caelen, O., Le Borgne, Y. A., Waterschoot, S., & Bontempi, G. (2015). Learned lessons in credit card fraud detection from a practitioner perspective. Expert systems with applications, 41(10), 4915-4928. diff --git a/models/.gitkeep b/models/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..2ad9bb2 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,27 @@ +[tool.uv] + +[project] +name = "creditcard-fraud-detection" +version = "0.1.0" +description = "信用卡欺诈检测系统" +license = { text = "MIT" } +dependencies = [ + "flask", + "numpy", + "polars", + "scikit-learn", + "imbalanced-learn", + "matplotlib", + "seaborn", + "joblib", + "pydantic", + "streamlit", +] + +[project.scripts] +train = "src.train:train_and_evaluate" +demo = "streamlit:run src/streamlit_app.py" + +[tool.ruff] +line-length = 88 +select = ["E", "F", "W"] \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..8e9cd10 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,30 @@ +from .data import CreditCardDataProcessor, load_data +from .features import ( + TransactionFeatures, EvaluationResult, ExplanationResult, + ActionPlan, DecisionResult, ModelMetrics, TrainingResult, + TransactionClass, ConfidenceLevel, Priority +) +from .train import CreditCardFraudModelTrainer, train_and_evaluate +from .infer import FraudDetectionInference, load_inference +from .agent_app import CreditCardFraudAgent, create_agent + +__all__ = [ + "CreditCardDataProcessor", + "load_data", + "TransactionFeatures", + "EvaluationResult", + "ExplanationResult", + "ActionPlan", + "DecisionResult", + "ModelMetrics", + "TrainingResult", + "TransactionClass", + "ConfidenceLevel", + "Priority", + "CreditCardFraudModelTrainer", + "train_and_evaluate", + "FraudDetectionInference", + "load_inference", + "CreditCardFraudAgent", + "create_agent", +] diff --git a/src/agent_app.py b/src/agent_app.py new file mode 100644 index 0000000..f363750 --- /dev/null +++ b/src/agent_app.py @@ -0,0 +1,265 @@ +import numpy as np +import logging +from typing import Dict, List, Any, Optional, Callable +from pathlib import Path +import sys +sys.path.insert(0, str(Path(__file__).parent)) +from infer import FraudDetectionInference +from features import ( + TransactionFeatures, EvaluationResult, ExplanationResult, + ActionPlan, DecisionResult, TransactionClass, ConfidenceLevel, + Priority, FeatureContribution, Action +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class Tool: + def __init__(self, name: str, description: str, func: Callable): + self.name = name + self.description = description + self.func = func + + def execute(self, *args, **kwargs) -> Any: + return self.func(*args, **kwargs) + + +class CreditCardFraudAgent: + def __init__(self, model_dir: str = "models", model_name: str = "random_forest"): + self.inference = FraudDetectionInference(model_dir=model_dir, model_name=model_name) + self.tools = self._initialize_tools() + + def _initialize_tools(self) -> List[Tool]: + tools = [ + Tool( + name="predict_fraud", + description="使用机器学习模型预测交易是否为欺诈", + func=self._predict_fraud + ), + Tool( + name="analyze_transaction", + description="分析交易数据的统计特征和异常值", + func=self._analyze_transaction + ) + ] + return tools + + def _predict_fraud(self, transaction: List[float]) -> EvaluationResult: + logger.info("执行 ML 工具: predict_fraud") + return self.inference.predict(transaction) + + def _analyze_transaction(self, transaction: List[float]) -> Dict[str, Any]: + logger.info("执行数据分析工具: analyze_transaction") + transaction_array = np.array(transaction) + + feature_names = [ + 'Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', + 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', + 'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount' + ] + + analysis = { + "feature_count": len(transaction), + "amount": float(transaction[-1]), + "time": float(transaction[0]), + "v_features": { + name: float(value) for name, value in zip(feature_names[1:-1], transaction[1:-1]) + }, + "statistics": { + "mean": float(np.mean(transaction_array)), + "std": float(np.std(transaction_array)), + "min": float(np.min(transaction_array)), + "max": float(np.max(transaction_array)), + "median": float(np.median(transaction_array)) + }, + "anomalies": [] + } + + for i, (name, value) in enumerate(zip(feature_names, transaction)): + if abs(value) > 3: + analysis["anomalies"].append({ + "feature": name, + "value": float(value), + "severity": "high" if abs(value) > 5 else "medium" + }) + + return analysis + + def _explain_prediction(self, transaction: List[float], evaluation: EvaluationResult) -> ExplanationResult: + logger.info("生成预测解释") + transaction_array = np.array(transaction) + + model = self.inference.trainer.models[self.inference.model_name] + + if hasattr(model, "feature_importances_"): + feature_importances = model.feature_importances_ + else: + feature_importances = np.ones(30) / 30 + + feature_names = [ + 'Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', + 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', + 'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount' + ] + + feature_contributions = transaction_array * feature_importances + top_n = 5 + top_indices = np.argsort(np.abs(feature_contributions))[-top_n:][::-1] + + key_features = [] + for idx in top_indices: + feature_name = feature_names[idx] + feature_value = float(transaction_array[idx]) + importance = float(feature_importances[idx]) + contribution = float(feature_contributions[idx]) + + if contribution > 0.1: + impact = "正" + elif contribution < -0.1: + impact = "负" + else: + impact = "无" + + key_features.append(FeatureContribution( + feature_name=feature_name, + value=feature_value, + importance=importance, + contribution=contribution, + impact=impact + )) + + if evaluation.predicted_class == 1: + overall_explanation = "该交易被预测为欺诈,主要是由于以下关键特征的异常值导致模型做出此判断。" + else: + overall_explanation = "该交易被预测为正常,模型认为关键特征的数值在正常范围内。" + + return ExplanationResult( + model_type=type(model).__name__, + predicted_class=evaluation.class_name, + key_features=key_features, + overall_explanation=overall_explanation + ) + + def _generate_action_plan(self, evaluation: EvaluationResult, explanation: ExplanationResult) -> ActionPlan: + logger.info("生成行动计划") + actions = [] + + if evaluation.predicted_class == 1: + if evaluation.confidence == ConfidenceLevel.HIGH: + actions.append(Action( + priority=Priority.URGENT, + action="立即冻结该交易账户", + reason="模型以高置信度预测该交易为欺诈" + )) + actions.append(Action( + priority=Priority.URGENT, + action="联系持卡人确认交易真实性", + reason="防止持卡人资金损失" + )) + elif evaluation.confidence == ConfidenceLevel.MEDIUM: + actions.append(Action( + priority=Priority.HIGH, + action="临时冻结该交易", + reason="模型以中等置信度预测该交易为欺诈" + )) + actions.append(Action( + priority=Priority.HIGH, + action="联系持卡人进行交易验证", + reason="需要进一步确认交易真实性" + )) + else: + actions.append(Action( + priority=Priority.MEDIUM, + action="标记为可疑交易", + reason="模型以低置信度预测该交易为欺诈" + )) + actions.append(Action( + priority=Priority.MEDIUM, + action="进行人工审核", + reason="需要人工确认交易真实性" + )) + + for feature in explanation.key_features: + if abs(feature.value) > 5: + actions.append(Action( + priority=Priority.MEDIUM, + action=f"调查{feature.feature_name}特征的异常值({feature.value:.4f})", + reason=f"该特征对欺诈预测有重要影响" + )) + else: + if evaluation.confidence == ConfidenceLevel.HIGH: + actions.append(Action( + priority=Priority.LOW, + action="正常处理该交易", + reason="模型以高置信度预测该交易为正常" + )) + else: + actions.append(Action( + priority=Priority.MEDIUM, + action="监控该交易的后续行为", + reason="模型对该交易的预测置信度较低" + )) + + actions.append(Action( + priority=Priority.ROUTINE, + action="记录该交易的预测结果和处理措施", + reason="用于后续模型优化和审计" + )) + + return ActionPlan( + predicted_class=evaluation.class_name, + confidence=evaluation.confidence, + actions=actions + ) + + def process_transaction(self, transaction: List[float]) -> DecisionResult: + logger.info("=== 开始处理交易 ===") + + evaluation = self._predict_fraud(transaction) + explanation = self._explain_prediction(transaction, evaluation) + action_plan = self._generate_action_plan(evaluation, explanation) + + result = DecisionResult( + evaluation=evaluation, + explanation=explanation, + action_plan=action_plan, + timestamp="2026-01-15" + ) + + logger.info("=== 交易处理完成 ===") + return result + + def list_tools(self) -> List[Dict[str, str]]: + return [{"name": tool.name, "description": tool.description} for tool in self.tools] + + +def create_agent(model_dir: str = "models", model_name: str = "random_forest") -> CreditCardFraudAgent: + return CreditCardFraudAgent(model_dir=model_dir, model_name=model_name) + + +if __name__ == "__main__": + agent = create_agent() + + print("=== 可用工具 ===") + for tool in agent.list_tools(): + print(f"- {tool['name']}: {tool['description']}") + + test_transaction = [ + 0, -1.3598071336738, -0.0727811733098497, 2.53634673796914, 1.37815522427443, + -0.338320769942518, 0.462387777762292, 0.239598554061257, 0.0986979012610507, + 0.363786969611213, 0.0907941719789316, -0.551599533260813, -0.617800855762348, + -0.991389847235408, -0.311169353699879, 1.46817697209427, -0.470400525259478, + 0.207971241929242, 0.0257905801985591, 0.403992960255733, 0.251412098239705, + -0.018306777944153, 0.277837575558899, -0.110473910188767, 0.0669280749146731, + 0.128539358273528, -0.189114843888824, 0.133558376740387, -0.0210530534538215, + 149.62 + ] + + result = agent.process_transaction(test_transaction) + print("\n=== 决策结果 ===") + print(f"预测类别: {result.evaluation.class_name}") + print(f"欺诈概率: {result.evaluation.fraud_probability:.4f}") + print(f"置信度: {result.evaluation.confidence}") + print(f"关键特征数量: {len(result.explanation.key_features)}") + print(f"行动建议数量: {len(result.action_plan.actions)}") diff --git a/src/data.py b/src/data.py new file mode 100644 index 0000000..46f9e65 --- /dev/null +++ b/src/data.py @@ -0,0 +1,112 @@ +import polars as pl +import numpy as np +from typing import Tuple, Dict, List, Optional +import logging +from pathlib import Path + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class CreditCardDataProcessor: + def __init__(self, file_path: str): + self.file_path = file_path + self.data: Optional[pl.DataFrame] = None + self.train_data: Optional[pl.DataFrame] = None + self.test_data: Optional[pl.DataFrame] = None + self.train_features: Optional[np.ndarray] = None + self.train_labels: Optional[np.ndarray] = None + self.test_features: Optional[np.ndarray] = None + self.test_labels: Optional[np.ndarray] = None + + def load_data(self) -> None: + logger.info(f"加载数据集: {self.file_path}") + try: + self.data = pl.read_csv( + self.file_path, + schema_overrides={"Time": pl.Float64} + ) + logger.info(f"数据集加载成功,形状: {self.data.shape}") + fraud_count = self.data.filter(pl.col("Class") == 1).height + normal_count = self.data.filter(pl.col("Class") == 0).height + logger.info(f"欺诈交易数量: {fraud_count}, 非欺诈交易数量: {normal_count}") + except Exception as e: + logger.error(f"加载数据失败: {e}") + raise + + def validate_data(self) -> None: + logger.info("开始数据验证...") + missing_values = self.data.null_count() + total_missing = missing_values.sum_horizontal().item() + if total_missing > 0: + logger.warning(f"发现缺失值: {total_missing} 个") + else: + logger.info("无缺失值,数据完整性良好") + + class_dist = self.data.group_by("Class").agg(pl.len().alias("count")).to_dict() + logger.info(f"标签分布: {class_dist}") + + def split_data_by_time(self, test_ratio: float = 0.2) -> Tuple[pl.DataFrame, pl.DataFrame]: + logger.info(f"按照时间顺序划分数据集,测试集比例: {test_ratio}") + sorted_data = self.data.sort("Time") + split_index = int(sorted_data.height * (1 - test_ratio)) + self.train_data = sorted_data[:split_index] + self.test_data = sorted_data[split_index:] + + logger.info(f"训练集形状: {self.train_data.shape}, 测试集形状: {self.test_data.shape}") + + train_max_time = self.train_data["Time"].max() + test_min_time = self.test_data["Time"].min() + logger.info(f"训练集最大时间: {train_max_time}, 测试集最小时间: {test_min_time}") + + if train_max_time <= test_min_time: + logger.info("时间划分正确,训练集时间早于测试集") + else: + logger.warning("时间划分存在问题,训练集时间晚于测试集") + + return self.train_data, self.test_data + + def prepare_features_labels(self, feature_cols: Optional[List[str]] = None, label_col: str = "Class") -> None: + logger.info("准备特征和标签...") + if feature_cols is None: + feature_cols = [col for col in self.data.columns if col != label_col] + + logger.info(f"使用的特征列: {feature_cols}") + + self.train_features = self.train_data.select(feature_cols).to_numpy() + self.train_labels = self.train_data.select(label_col).to_numpy().flatten() + self.test_features = self.test_data.select(feature_cols).to_numpy() + self.test_labels = self.test_data.select(label_col).to_numpy().flatten() + + logger.info(f"训练特征形状: {self.train_features.shape}, 训练标签形状: {self.train_labels.shape}") + logger.info(f"测试特征形状: {self.test_features.shape}, 测试标签形状: {self.test_labels.shape}") + + def get_statistics(self) -> Dict[str, any]: + if self.data is None: + self.load_data() + + stats = { + "总记录数": self.data.height, + "特征数": len([col for col in self.data.columns if col != "Class"]), + "欺诈交易数": self.data.filter(pl.col("Class") == 1).height, + "非欺诈交易数": self.data.filter(pl.col("Class") == 0).height, + "不平衡比例": self.data.filter(pl.col("Class") == 0).height / self.data.filter(pl.col("Class") == 1).height + } + return stats + + +def load_data(file_path: str = "data/creditcard.csv") -> CreditCardDataProcessor: + processor = CreditCardDataProcessor(file_path) + processor.load_data() + processor.validate_data() + processor.split_data_by_time() + processor.prepare_features_labels() + return processor + + +if __name__ == "__main__": + processor = load_data() + stats = processor.get_statistics() + print("\n=== 数据集统计信息 ===") + for key, value in stats.items(): + print(f"{key}: {value}") diff --git a/src/features.py b/src/features.py new file mode 100644 index 0000000..76969e1 --- /dev/null +++ b/src/features.py @@ -0,0 +1,118 @@ +from pydantic import BaseModel, Field +from typing import List, Optional +from enum import Enum + + +class TransactionClass(str, Enum): + NORMAL = "正常" + FRAUD = "欺诈" + + +class ConfidenceLevel(str, Enum): + HIGH = "高" + MEDIUM = "中" + LOW = "低" + + +class Priority(str, Enum): + URGENT = "紧急" + HIGH = "高" + MEDIUM = "中" + LOW = "低" + ROUTINE = "常规" + + +class TransactionFeatures(BaseModel): + time: float = Field(..., description="交易时间(秒)") + v1: float = Field(..., description="PCA特征V1") + v2: float = Field(..., description="PCA特征V2") + v3: float = Field(..., description="PCA特征V3") + v4: float = Field(..., description="PCA特征V4") + v5: float = Field(..., description="PCA特征V5") + v6: float = Field(..., description="PCA特征V6") + v7: float = Field(..., description="PCA特征V7") + v8: float = Field(..., description="PCA特征V8") + v9: float = Field(..., description="PCA特征V9") + v10: float = Field(..., description="PCA特征V10") + v11: float = Field(..., description="PCA特征V11") + v12: float = Field(..., description="PCA特征V12") + v13: float = Field(..., description="PCA特征V13") + v14: float = Field(..., description="PCA特征V14") + v15: float = Field(..., description="PCA特征V15") + v16: float = Field(..., description="PCA特征V16") + v17: float = Field(..., description="PCA特征V17") + v18: float = Field(..., description="PCA特征V18") + v19: float = Field(..., description="PCA特征V19") + v20: float = Field(..., description="PCA特征V20") + v21: float = Field(..., description="PCA特征V21") + v22: float = Field(..., description="PCA特征V22") + v23: float = Field(..., description="PCA特征V23") + v24: float = Field(..., description="PCA特征V24") + v25: float = Field(..., description="PCA特征V25") + v26: float = Field(..., description="PCA特征V26") + v27: float = Field(..., description="PCA特征V27") + v28: float = Field(..., description="PCA特征V28") + amount: float = Field(..., description="交易金额") + + def to_array(self) -> List[float]: + return [ + self.time, self.v1, self.v2, self.v3, self.v4, self.v5, self.v6, self.v7, self.v8, self.v9, + self.v10, self.v11, self.v12, self.v13, self.v14, self.v15, self.v16, self.v17, self.v18, self.v19, + self.v20, self.v21, self.v22, self.v23, self.v24, self.v25, self.v26, self.v27, self.v28, self.amount + ] + + +class EvaluationResult(BaseModel): + predicted_class: int = Field(..., description="预测类别(0=正常, 1=欺诈)") + class_name: TransactionClass = Field(..., description="类别名称") + fraud_probability: float = Field(..., ge=0, le=1, description="欺诈概率") + normal_probability: float = Field(..., ge=0, le=1, description="正常概率") + confidence: ConfidenceLevel = Field(..., description="置信度等级") + + +class FeatureContribution(BaseModel): + feature_name: str = Field(..., description="特征名称") + value: float = Field(..., description="特征值") + importance: float = Field(..., ge=0, le=1, description="特征重要性") + contribution: float = Field(..., description="特征贡献度") + impact: str = Field(..., description="影响方向(正/负/无)") + + +class ExplanationResult(BaseModel): + model_type: str = Field(..., description="模型类型") + predicted_class: TransactionClass = Field(..., description="预测类别") + key_features: List[FeatureContribution] = Field(..., description="关键特征列表") + overall_explanation: str = Field(..., description="总体解释") + + +class Action(BaseModel): + priority: Priority = Field(..., description="优先级") + action: str = Field(..., description="行动建议") + reason: str = Field(..., description="行动原因") + + +class ActionPlan(BaseModel): + predicted_class: TransactionClass = Field(..., description="预测类别") + confidence: ConfidenceLevel = Field(..., description="置信度") + actions: List[Action] = Field(..., description="行动建议列表") + + +class DecisionResult(BaseModel): + evaluation: EvaluationResult = Field(..., description="评估结果") + explanation: ExplanationResult = Field(..., description="解释结果") + action_plan: ActionPlan = Field(..., description="行动计划") + timestamp: str = Field(..., description="时间戳") + + +class ModelMetrics(BaseModel): + accuracy: float = Field(..., ge=0, le=1, description="准确率") + precision: float = Field(..., ge=0, le=1, description="精确率") + recall: float = Field(..., ge=0, le=1, description="召回率") + f1_score: float = Field(..., ge=0, le=1, description="F1分数") + pr_auc: float = Field(..., ge=0, le=1, description="PR-AUC") + + +class TrainingResult(BaseModel): + model_name: str = Field(..., description="模型名称") + metrics: ModelMetrics = Field(..., description="评估指标") + confusion_matrix: List[List[int]] = Field(..., description="混淆矩阵") diff --git a/src/infer.py b/src/infer.py new file mode 100644 index 0000000..dd89075 --- /dev/null +++ b/src/infer.py @@ -0,0 +1,121 @@ +import numpy as np +import logging +from pathlib import Path +from typing import Optional, Union, List +import sys +sys.path.insert(0, str(Path(__file__).parent)) +from train import CreditCardFraudModelTrainer +from features import TransactionFeatures, EvaluationResult, TransactionClass, ConfidenceLevel + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class FraudDetectionInference: + def __init__(self, model_dir: str = "models", model_name: str = "random_forest"): + self.model_dir = Path(model_dir) + self.model_name = model_name + self.trainer = CreditCardFraudModelTrainer(model_dir=model_dir) + + self._load_models() + + def _load_models(self) -> None: + logger.info("加载模型和缩放器...") + success = self.trainer.load_model(self.model_name) and self.trainer.load_scaler() + if not success: + raise RuntimeError("模型或缩放器加载失败") + logger.info("模型和缩放器加载成功") + + def predict(self, transaction: Union[List[float], np.ndarray, TransactionFeatures]) -> EvaluationResult: + if isinstance(transaction, TransactionFeatures): + transaction_array = np.array(transaction.to_array()) + elif isinstance(transaction, list): + transaction_array = np.array(transaction) + else: + transaction_array = transaction + + if transaction_array.ndim == 1: + transaction_array = transaction_array.reshape(1, -1) + + prediction = self.trainer.predict(transaction_array) + probability = self.trainer.predict_proba(transaction_array) + + fraud_prob = float(probability[0]) + normal_prob = float(1 - fraud_prob) + + max_prob = max(fraud_prob, normal_prob) + if max_prob > 0.8: + confidence = ConfidenceLevel.HIGH + elif max_prob > 0.6: + confidence = ConfidenceLevel.MEDIUM + else: + confidence = ConfidenceLevel.LOW + + class_name = TransactionClass.FRAUD if prediction[0] == 1 else TransactionClass.NORMAL + + return EvaluationResult( + predicted_class=int(prediction[0]), + class_name=class_name, + fraud_probability=fraud_prob, + normal_probability=normal_prob, + confidence=confidence + ) + + def predict_batch(self, transactions: Union[List[List[float]], np.ndarray]) -> List[EvaluationResult]: + if isinstance(transactions, list): + transactions_array = np.array(transactions) + else: + transactions_array = transactions + + predictions = self.trainer.predict(transactions_array) + probabilities = self.trainer.predict_proba(transactions_array) + + results = [] + for pred, prob in zip(predictions, probabilities): + fraud_prob = float(prob) + normal_prob = float(1 - fraud_prob) + + max_prob = max(fraud_prob, normal_prob) + if max_prob > 0.8: + confidence = ConfidenceLevel.HIGH + elif max_prob > 0.6: + confidence = ConfidenceLevel.MEDIUM + else: + confidence = ConfidenceLevel.LOW + + class_name = TransactionClass.FRAUD if pred == 1 else TransactionClass.NORMAL + + results.append(EvaluationResult( + predicted_class=int(pred), + class_name=class_name, + fraud_probability=fraud_prob, + normal_probability=normal_prob, + confidence=confidence + )) + + return results + + +def load_inference(model_dir: str = "models", model_name: str = "random_forest") -> FraudDetectionInference: + return FraudDetectionInference(model_dir=model_dir, model_name=model_name) + + +if __name__ == "__main__": + inference = load_inference() + + test_transaction = [ + 0, -1.3598071336738, -0.0727811733098497, 2.53634673796914, 1.37815522427443, + -0.338320769942518, 0.462387777762292, 0.239598554061257, 0.0986979012610507, + 0.363786969611213, 0.0907941719789316, -0.551599533260813, -0.617800855762348, + -0.991389847235408, -0.311169353699879, 1.46817697209427, -0.470400525259478, + 0.207971241929242, 0.0257905801985591, 0.403992960255733, 0.251412098239705, + -0.018306777944153, 0.277837575558899, -0.110473910188767, 0.0669280749146731, + 0.128539358273528, -0.189114843888824, 0.133558376740387, -0.0210530534538215, + 149.62 + ] + + result = inference.predict(test_transaction) + print("预测结果:") + print(f"类别: {result.class_name}") + print(f"欺诈概率: {result.fraud_probability:.4f}") + print(f"置信度: {result.confidence}") diff --git a/src/streamlit_app.py b/src/streamlit_app.py new file mode 100644 index 0000000..2d0a8de --- /dev/null +++ b/src/streamlit_app.py @@ -0,0 +1,451 @@ +import streamlit as st +import numpy as np +import polars as pl +from pathlib import Path +import sys +sys.path.insert(0, str(Path(__file__).parent)) +from agent_app import create_agent +from features import TransactionFeatures, DecisionResult, TransactionClass, ConfidenceLevel, Priority + +st.set_page_config( + page_title="信用卡欺诈检测系统", + page_icon="💳", + layout="wide" +) + +st.title("💳 信用卡欺诈检测系统") +st.markdown("基于机器学习的实时欺诈检测与决策支持系统") + +@st.cache_resource +def load_agent(): + return create_agent() + +agent = load_agent() + +@st.cache_data +def load_csv_file(uploaded_file): + if uploaded_file is not None: + try: + df = pl.read_csv(uploaded_file, schema_overrides={"Time": pl.Float64}) + return df + except Exception as e: + st.error(f"读取CSV文件失败: {e}") + return None + return None + +st.sidebar.header("系统信息") +st.sidebar.info(f""" +**模型信息** +- 模型类型: RandomForest +- 特征数量: 30 +- 支持工具: 2个 + - predict_fraud (ML工具) + - analyze_transaction +""") + +st.header("输入交易数据") + +input_mode = st.radio( + "选择输入方式", + ["📁 上传CSV文件", "✏️ 手动输入特征"], + horizontal=True +) + +if input_mode == "📁 上传CSV文件": + st.subheader("上传CSV文件") + + uploaded_file = st.file_uploader( + "选择CSV文件", + type=['csv'], + help="上传包含交易数据的CSV文件" + ) + + if uploaded_file is not None: + df = load_csv_file(uploaded_file) + + if df is not None: + st.success(f"✅ 成功加载CSV文件,共 {df.height} 条交易记录") + + st.write("### 数据预览") + st.dataframe(df.head(10), use_container_width=True) + + st.write("### 选择交易") + if "Class" in df.columns: + df = df.drop("Class") + + row_index = st.number_input( + "选择交易行号(从0开始)", + min_value=0, + max_value=df.height - 1, + value=0, + step=1 + ) + + if st.button("📋 加载选中的交易", type="primary"): + feature_names = [ + 'Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9', + 'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19', + 'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount' + ] + transaction = [df[row_index, col] for col in feature_names] + + st.session_state.transaction = transaction + st.success(f"✅ 已加载第 {row_index} 行的交易数据") + + st.write("### 选中的交易数据") + feature_data = { + "特征名称": feature_names, + "特征值": [f"{v:.6f}" for v in transaction] + } + st.dataframe( + pl.DataFrame(feature_data), + use_container_width=True + ) + +else: + st.subheader("手动输入特征") + + col1, col2 = st.columns(2) + + with col1: + st.write("基础信息") + time = st.number_input("Time (交易时间)", value=0.0, step=1.0) + amount = st.number_input("Amount (交易金额)", value=100.0, step=1.0) + + with col2: + st.write("PCA特征 V1-V14") + v1 = st.number_input("V1", value=0.0, step=0.1) + v2 = st.number_input("V2", value=0.0, step=0.1) + v3 = st.number_input("V3", value=0.0, step=0.1) + v4 = st.number_input("V4", value=0.0, step=0.1) + v5 = st.number_input("V5", value=0.0, step=0.1) + v6 = st.number_input("V6", value=0.0, step=0.1) + v7 = st.number_input("V7", value=0.0, step=0.1) + v8 = st.number_input("V8", value=0.0, step=0.1) + v9 = st.number_input("V9", value=0.0, step=0.1) + v10 = st.number_input("V10", value=0.0, step=0.1) + v11 = st.number_input("V11", value=0.0, step=0.1) + v12 = st.number_input("V12", value=0.0, step=0.1) + v13 = st.number_input("V13", value=0.0, step=0.1) + v14 = st.number_input("V14", value=0.0, step=0.1) + + col3, col4 = st.columns(2) + + with col3: + st.write("PCA特征 V15-V21") + v15 = st.number_input("V15", value=0.0, step=0.1) + v16 = st.number_input("V16", value=0.0, step=0.1) + v17 = st.number_input("V17", value=0.0, step=0.1) + v18 = st.number_input("V18", value=0.0, step=0.1) + v19 = st.number_input("V19", value=0.0, step=0.1) + v20 = st.number_input("V20", value=0.0, step=0.1) + v21 = st.number_input("V21", value=0.0, step=0.1) + + with col4: + st.write("PCA特征 V22-V28") + v22 = st.number_input("V22", value=0.0, step=0.1) + v23 = st.number_input("V23", value=0.0, step=0.1) + v24 = st.number_input("V24", value=0.0, step=0.1) + v25 = st.number_input("V25", value=0.0, step=0.1) + v26 = st.number_input("V26", value=0.0, step=0.1) + v27 = st.number_input("V27", value=0.0, step=0.1) + v28 = st.number_input("V28", value=0.0, step=0.1) + +st.divider() + +if st.button("🔍 检测欺诈", type="primary", use_container_width=True): + if input_mode == "📁 上传CSV文件": + if "transaction" in st.session_state: + transaction = st.session_state.transaction + else: + st.warning("⚠️ 请先上传CSV文件并选择交易") + st.stop() + else: + transaction = [ + time, v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, amount + ] + + with st.spinner("正在分析交易..."): + result = agent.process_transaction(transaction) + + st.success("分析完成!") + + col5, col6, col7 = st.columns(3) + + with col5: + st.metric( + label="预测类别", + value=result.evaluation.class_name.value, + delta=f"置信度: {result.evaluation.confidence.value}" + ) + + with col6: + fraud_prob = result.evaluation.fraud_probability * 100 + st.metric( + label="欺诈概率", + value=f"{fraud_prob:.2f}%", + delta=f"{100 - fraud_prob:.2f}% 正常" + ) + + with col7: + st.metric( + label="行动建议数量", + value=len(result.action_plan.actions), + delta="已生成" + ) + + st.divider() + + tab1, tab2, tab3 = st.tabs(["📊 评估结果", "🔍 特征解释", "📋 行动计划"]) + + with tab1: + st.subheader("模型评估结果") + + eval_col1, eval_col2 = st.columns(2) + + with eval_col1: + st.info(f""" + **预测信息** + - 预测类别: {result.evaluation.class_name.value} + - 预测标签: {result.evaluation.predicted_class} + - 置信度: {result.evaluation.confidence.value} + """) + + with eval_col2: + st.info(f""" + **概率分布** + - 欺诈概率: {result.evaluation.fraud_probability:.4f} + - 正常概率: {result.evaluation.normal_probability:.4f} + """) + + if result.evaluation.class_name == TransactionClass.FRAUD: + st.error(f"⚠️ 该交易被识别为**欺诈交易**,请立即采取行动!") + else: + st.success(f"✅ 该交易被识别为**正常交易**") + + with tab2: + st.subheader("🔍 特征解释") + + st.markdown(""" +
+

💡 什么是特征解释?

+

+ 就像医生看病时会检查各项指标一样,我们的AI模型也通过分析交易的各项"特征"来判断是否为欺诈。 + 下面这些特征是影响判断结果最重要的因素,让我们来看看它们是如何"告诉"模型这个交易是否有问题的。 +

+
+ """, unsafe_allow_html=True) + + st.info(f""" + **使用的模型**: {result.explanation.model_type} + + **整体判断依据**: {result.explanation.overall_explanation} + """) + + st.write("### 📊 关键影响因素分析") + st.caption("这些特征对判断结果影响最大,就像破案时的关键线索") + + feature_descriptions = { + "Time": "交易发生的时间距离第一次交易经过的秒数。欺诈交易往往在特定时间段更频繁,比如深夜或节假日。", + "Amount": "交易金额。异常高额或异常低额的交易都可能引起怀疑,特别是与用户历史消费习惯不符时。", + "V1": "经过PCA(主成分分析)转换后的特征1,代表交易数据的某种模式。PCA将原始数据转换成更易分析的形式。", + "V2": "经过PCA转换后的特征2,捕捉交易数据的另一种模式。", + "V3": "经过PCA转换后的特征3,反映交易数据的特定维度。", + "V4": "经过PCA转换后的特征4,可能代表交易频率或模式。", + "V5": "经过PCA转换后的特征5,可能涉及交易的时间或空间特征。", + "V6": "经过PCA转换后的特征6,可能反映交易的某种统计特性。", + "V7": "经过PCA转换后的特征7,可能代表交易的异常程度。", + "V8": "经过PCA转换后的特征8,可能涉及交易的上下文信息。", + "V9": "经过PCA转换后的特征9,可能反映交易的时间序列特征。", + "V10": "经过PCA转换后的特征10,可能代表交易的某种模式。", + "V11": "经过PCA转换后的特征11,可能涉及交易的频率特征。", + "V12": "经过PCA转换后的特征12,可能反映交易的异常模式。", + "V13": "经过PCA转换后的特征13,可能代表交易的某种统计特性。", + "V14": "经过PCA转换后的特征14,可能涉及交易的时间特征。", + "V15": "经过PCA转换后的特征15,可能反映交易的某种模式。", + "V16": "经过PCA转换后的特征16,可能代表交易的异常程度。", + "V17": "经过PCA转换后的特征17,可能涉及交易的上下文信息。", + "V18": "经过PCA转换后的特征18,可能反映交易的时间序列特征。", + "V19": "经过PCA转换后的特征19,可能代表交易的某种模式。", + "V20": "经过PCA转换后的特征20,可能涉及交易的频率特征。", + "V21": "经过PCA转换后的特征21,可能反映交易的异常模式。", + "V22": "经过PCA转换后的特征22,可能代表交易的某种统计特性。", + "V23": "经过PCA转换后的特征23,可能涉及交易的时间特征。", + "V24": "经过PCA转换后的特征24,可能反映交易的某种模式。", + "V25": "经过PCA转换后的特征25,可能代表交易的异常程度。", + "V26": "经过PCA转换后的特征26,可能涉及交易的上下文信息。", + "V27": "经过PCA转换后的特征27,可能反映交易的时间序列特征。", + "V28": "经过PCA转换后的特征28,可能代表交易的某种模式。" + } + + for i, feature in enumerate(result.explanation.key_features, 1): + importance_percent = min(feature.importance * 100, 100) + + impact_emoji = "📈" if feature.impact == "正向影响" else "📉" + impact_color = "#ff5252" if feature.impact == "正向影响" else "#4caf50" + + feature_desc = feature_descriptions.get(feature.feature_name, "该特征是经过PCA转换后的数据特征,用于帮助模型识别交易模式。") + + with st.expander(f"{i}. {feature.feature_name} {impact_emoji}"): + st.markdown(f""" +
+ +

+ 影响程度: +

+
+
+
+

+ {importance_percent:.1f}% 的影响力 +

+ +

+ 影响方向: {feature.impact} +

+ +
+

+ 📖 这个特征是什么?
+ {feature_desc} +

+
+ +
+

+ 💡 简单来说: + {"这个特征让模型更倾向于认为这是欺诈交易" if feature.impact == "正向影响" else "这个特征让模型更倾向于认为这是正常交易"} +

+
+
+ """, unsafe_allow_html=True) + + with tab3: + st.subheader("📋 行动计划") + + st.markdown(""" +
+

🎯 为什么需要行动计划?

+

+ 根据检测结果,我们为您准备了具体的行动建议。这些建议按照紧急程度排序, + 帮助您快速、有效地应对可能的风险。请根据实际情况选择合适的处理方式。 +

+
+ """, unsafe_allow_html=True) + + st.write("### 🚀 建议采取的行动") + st.caption("按优先级排序,从高到低依次处理") + + for action in result.action_plan.actions: + priority_info = { + Priority.URGENT: { + "emoji": "🔴", + "color": "#d32f2f", + "bg_color": "#ffcdd2", + "description": "紧急 - 需要立即处理,不能拖延" + }, + Priority.HIGH: { + "emoji": "🟠", + "color": "#f57c00", + "bg_color": "#ffe0b2", + "description": "高优先级 - 尽快处理" + }, + Priority.MEDIUM: { + "emoji": "🟡", + "color": "#fbc02d", + "bg_color": "#fff9c4", + "description": "中等优先级 - 适时处理" + }, + Priority.LOW: { + "emoji": "🟢", + "color": "#388e3c", + "bg_color": "#c8e6c9", + "description": "低优先级 - 可以稍后处理" + }, + Priority.ROUTINE: { + "emoji": "⚪", + "color": "#757575", + "bg_color": "#e0e0e0", + "description": "常规 - 按正常流程处理" + } + }.get(action.priority, { + "emoji": "⚪", + "color": "#757575", + "bg_color": "#e0e0e0", + "description": "常规" + }) + + with st.container(): + st.markdown(f""" +
+
+ {priority_info['emoji']} +
+

{action.action}

+

+ 优先级: {action.priority.value} - {priority_info['description']} +

+
+
+ +
+

+ 💡 为什么要这样做?
+ {action.reason} +

+
+
+ """, unsafe_allow_html=True) + +st.divider() + +st.header("📝 使用说明") + +with st.expander("查看使用说明"): + st.markdown(""" + ### 如何使用本系统 + + #### 方式1:上传CSV文件 + 1. **上传文件**: 点击"选择CSV文件"按钮上传包含交易数据的CSV文件 + 2. **查看数据**: 系统会显示数据预览 + 3. **选择交易**: 输入行号选择要分析的交易 + 4. **加载数据**: 点击"加载选中的交易"按钮 + 5. **开始检测**: 点击"检测欺诈"按钮开始分析 + + #### 方式2:手动输入 + 1. **输入特征**: 在表单中输入30个特征值 + - Time: 交易发生时间(秒) + - V1-V28: PCA转换后的特征 + - Amount: 交易金额 + + 2. **点击检测**: 点击"检测欺诈"按钮开始分析 + + 3. **查看结果**: 系统会返回三个部分的结果 + - **评估结果**: 模型的预测类别和概率 + - **特征解释**: 影响预测的关键特征 + - **行动计划**: 建议的处理措施 + + ### CSV文件格式要求 + + CSV文件必须包含以下列: + - Time, V1, V2, V3, V4, V5, V6, V7, V8, V9 + - V10, V11, V12, V13, V14, V15, V16, V17, V18, V19 + - V20, V21, V22, V23, V24, V25, V26, V27, V28, Amount + - Class (可选,如果存在会被自动删除) + + ### 系统特点 + + - ✅ 使用随机森林模型进行预测 + - ✅ 支持CSV文件批量处理 + - ✅ 提供特征重要性分析 + - ✅ 根据置信度生成行动建议 + - ✅ 实时分析,快速响应 + + ### 注意事项 + + - 本系统仅供演示使用 + - 实际应用中需要结合人工审核 + - 建议定期更新模型以保持准确性 + """) diff --git a/src/train.py b/src/train.py new file mode 100644 index 0000000..67a20fd --- /dev/null +++ b/src/train.py @@ -0,0 +1,226 @@ +from sklearn.linear_model import LogisticRegression +from sklearn.ensemble import RandomForestClassifier +from sklearn.preprocessing import StandardScaler +from sklearn.metrics import ( + precision_score, recall_score, f1_score, accuracy_score, + precision_recall_curve, auc, confusion_matrix +) +from imblearn.over_sampling import SMOTE +import numpy as np +import logging +import joblib +from pathlib import Path +from typing import Dict, Tuple, Optional +import sys +sys.path.insert(0, str(Path(__file__).parent)) +from features import ModelMetrics, TrainingResult + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class CreditCardFraudModelTrainer: + def __init__(self, model_dir: str = "models"): + self.model_dir = Path(model_dir) + self.model_dir.mkdir(exist_ok=True) + + self.models = { + "logistic_regression": LogisticRegression( + random_state=42, + class_weight="balanced", + max_iter=1000 + ), + "random_forest": RandomForestClassifier( + random_state=42, + class_weight="balanced", + n_estimators=100 + ) + } + + self.scaler = StandardScaler() + self.best_model = None + self.best_model_name = None + self.best_model_score = 0 + self.evaluation_results = {} + + def train(self, X_train: np.ndarray, y_train: np.ndarray, use_smote: bool = False) -> Dict[str, any]: + logger.info("开始训练模型...") + + X_train_scaled = self.scaler.fit_transform(X_train) + + if use_smote: + logger.info("使用SMOTE处理不平衡数据...") + smote = SMOTE(random_state=42) + X_train_scaled, y_train = smote.fit_resample(X_train_scaled, y_train) + logger.info(f"SMOTE处理后,训练集形状: X={X_train_scaled.shape}, y={y_train.shape}") + + for model_name, model in self.models.items(): + logger.info(f"训练模型: {model_name}") + model.fit(X_train_scaled, y_train) + + model_path = self.model_dir / f"{model_name}_model.joblib" + joblib.dump(model, model_path) + logger.info(f"模型已保存: {model_path}") + + scaler_path = self.model_dir / "scaler.joblib" + joblib.dump(self.scaler, scaler_path) + logger.info(f"特征缩放器已保存: {scaler_path}") + + logger.info("所有模型训练完成") + return {"status": "success", "message": "模型训练完成"} + + def evaluate(self, X_test: np.ndarray, y_test: np.ndarray) -> Dict[str, Dict[str, any]]: + logger.info("开始评估模型...") + + X_test_scaled = self.scaler.transform(X_test) + + for model_name, model in self.models.items(): + logger.info(f"评估模型: {model_name}") + + y_pred = model.predict(X_test_scaled) + y_pred_proba = model.predict_proba(X_test_scaled)[:, 1] + + accuracy = accuracy_score(y_test, y_pred) + precision = precision_score(y_test, y_pred) + recall = recall_score(y_test, y_pred) + f1 = f1_score(y_test, y_pred) + + precision_curve, recall_curve, _ = precision_recall_curve(y_test, y_pred_proba) + pr_auc = auc(recall_curve, precision_curve) + + cm = confusion_matrix(y_test, y_pred) + + self.evaluation_results[model_name] = { + "accuracy": accuracy, + "precision": precision, + "recall": recall, + "f1_score": f1, + "pr_auc": pr_auc, + "confusion_matrix": cm, + "y_pred": y_pred, + "y_pred_proba": y_pred_proba + } + + if pr_auc > self.best_model_score: + self.best_model_score = pr_auc + self.best_model = model + self.best_model_name = model_name + + logger.info(f"模型 {model_name} 评估完成, PR-AUC: {pr_auc:.4f}, F1: {f1:.4f}") + + logger.info(f"最佳模型: {self.best_model_name}, PR-AUC: {self.best_model_score:.4f}") + return self.evaluation_results + + def get_best_model(self) -> Tuple[Optional[str], Optional[any], float]: + if self.best_model is None: + logger.warning("尚未训练或评估模型") + return None, None, 0 + return self.best_model_name, self.best_model, self.best_model_score + + def predict(self, X: np.ndarray, model_name: Optional[str] = None) -> np.ndarray: + X_scaled = self.scaler.transform(X) + + if model_name is not None: + if model_name not in self.models: + logger.error(f"模型 {model_name} 未找到") + raise ValueError(f"模型 {model_name} 未找到") + model = self.models[model_name] + else: + if self.best_model is None: + logger.error("尚未训练或评估模型") + raise RuntimeError("尚未训练或评估模型,无法进行预测") + model = self.best_model + model_name = self.best_model_name + + logger.info(f"使用模型 {model_name} 进行预测") + return model.predict(X_scaled) + + def predict_proba(self, X: np.ndarray, model_name: Optional[str] = None) -> np.ndarray: + X_scaled = self.scaler.transform(X) + + if model_name is not None: + if model_name not in self.models: + logger.error(f"模型 {model_name} 未找到") + return None + model = self.models[model_name] + else: + if self.best_model is None: + logger.error("尚未训练或评估模型") + return None + model = self.best_model + model_name = self.best_model_name + + logger.info(f"使用模型 {model_name} 进行概率预测") + return model.predict_proba(X_scaled)[:, 1] + + def load_model(self, model_name: str) -> bool: + try: + model_path = self.model_dir / f"{model_name}_model.joblib" + model = joblib.load(model_path) + self.models[model_name] = model + + if self.best_model is None: + self.best_model = model + self.best_model_name = model_name + logger.info(f"设置 {model_name} 为默认最佳模型") + + logger.info(f"模型加载成功: {model_path}") + return True + except Exception as e: + logger.error(f"加载模型失败: {e}") + return False + + def load_scaler(self) -> bool: + try: + scaler_path = self.model_dir / "scaler.joblib" + self.scaler = joblib.load(scaler_path) + logger.info(f"特征缩放器加载成功: {scaler_path}") + return True + except Exception as e: + logger.error(f"加载特征缩放器失败: {e}") + return False + + def print_evaluation_results(self) -> None: + print("\n=== 模型评估结果 ===") + for model_name, results in self.evaluation_results.items(): + print(f"\n模型: {model_name}") + print("-" * 30) + print(f"准确率 (Accuracy): {results['accuracy']:.4f}") + print(f"精确率 (Precision): {results['precision']:.4f}") + print(f"召回率 (Recall): {results['recall']:.4f}") + print(f"F1分数 (F1-Score): {results['f1_score']:.4f}") + print(f"PR-AUC: {results['pr_auc']:.4f}") + print("\n混淆矩阵:") + print(results['confusion_matrix']) + + print(f"\n最佳模型: {self.best_model_name}") + print(f"最佳模型PR-AUC: {self.best_model_score:.4f}") + + +def train_and_evaluate(data_path: str = "data/creditcard.csv", use_smote: bool = False) -> CreditCardFraudModelTrainer: + from data import load_data + + logger.info("=== 信用卡欺诈检测系统开始运行 ===") + + processor = load_data(data_path) + X_train = processor.train_features + y_train = processor.train_labels + X_test = processor.test_features + y_test = processor.test_labels + + logger.info(f"\n训练集: {X_train.shape}, {y_train.shape}") + logger.info(f"测试集: {X_test.shape}, {y_test.shape}") + + trainer = CreditCardFraudModelTrainer() + train_result = trainer.train(X_train, y_train, use_smote=use_smote) + logger.info(train_result["message"]) + + evaluation_results = trainer.evaluate(X_test, y_test) + trainer.print_evaluation_results() + + logger.info("\n=== 信用卡欺诈检测系统运行完成 ===") + return trainer + + +if __name__ == "__main__": + train_and_evaluate() diff --git a/tests/test_agent.py b/tests/test_agent.py new file mode 100644 index 0000000..a8cbbeb --- /dev/null +++ b/tests/test_agent.py @@ -0,0 +1,64 @@ +import pytest +from src.agent_app import CreditCardFraudAgent, create_agent, Tool + + +def test_agent_initialization(): + agent = CreditCardFraudAgent(model_dir="models", model_name="random_forest") + assert agent.inference is not None + assert len(agent.tools) == 2 + + +def test_create_agent(): + agent = create_agent(model_dir="models", model_name="random_forest") + assert isinstance(agent, CreditCardFraudAgent) + + +def test_list_tools(): + agent = create_agent() + tools = agent.list_tools() + assert len(tools) == 2 + tool_names = [tool["name"] for tool in tools] + assert "predict_fraud" in tool_names + assert "analyze_transaction" in tool_names + + +def test_tool_structure(): + agent = create_agent() + for tool in agent.tools: + assert hasattr(tool, "name") + assert hasattr(tool, "description") + assert hasattr(tool, "func") + assert callable(tool.func) + + +def test_analyze_transaction(): + agent = create_agent() + transaction = [ + 0, -1.36, -0.07, 2.54, 1.38, -0.34, 0.46, 0.24, 0.10, 0.36, + 0.09, -0.55, -0.62, -0.99, -0.31, 1.47, -0.47, 0.21, 0.03, 0.40, + 0.25, -0.02, 0.28, -0.11, 0.07, 0.13, -0.19, 0.13, -0.02, 149.62 + ] + analysis = agent._analyze_transaction(transaction) + assert "feature_count" in analysis + assert "amount" in analysis + assert "time" in analysis + assert "statistics" in analysis + assert "anomalies" in analysis + assert analysis["feature_count"] == 30 + assert analysis["amount"] == 149.62 + + +def test_process_transaction(): + agent = create_agent() + transaction = [ + 0, -1.36, -0.07, 2.54, 1.38, -0.34, 0.46, 0.24, 0.10, 0.36, + 0.09, -0.55, -0.62, -0.99, -0.31, 1.47, -0.47, 0.21, 0.03, 0.40, + 0.25, -0.02, 0.28, -0.11, 0.07, 0.13, -0.19, 0.13, -0.02, 149.62 + ] + result = agent.process_transaction(transaction) + assert result.evaluation is not None + assert result.explanation is not None + assert result.action_plan is not None + assert hasattr(result.evaluation, "predicted_class") + assert hasattr(result.evaluation, "fraud_probability") + assert hasattr(result.action_plan, "actions") diff --git a/tests/test_data.py b/tests/test_data.py new file mode 100644 index 0000000..7167f41 --- /dev/null +++ b/tests/test_data.py @@ -0,0 +1,62 @@ +import pytest +import numpy as np +from src.data import CreditCardDataProcessor, load_data +from pathlib import Path + + +def test_data_processor_initialization(): + processor = CreditCardDataProcessor("data/creditcard.csv") + assert processor.file_path == "data/creditcard.csv" + assert processor.data is None + + +def test_load_data(): + processor = CreditCardDataProcessor("data/creditcard.csv") + processor.load_data() + assert processor.data is not None + assert processor.data.height > 0 + assert "Class" in processor.data.columns + + +def test_validate_data(): + processor = CreditCardDataProcessor("data/creditcard.csv") + processor.load_data() + processor.validate_data() + assert processor.data is not None + + +def test_split_data_by_time(): + processor = CreditCardDataProcessor("data/creditcard.csv") + processor.load_data() + train_data, test_data = processor.split_data_by_time(test_ratio=0.2) + assert train_data is not None + assert test_data is not None + assert train_data.height > test_data.height + + +def test_prepare_features_labels(): + processor = CreditCardDataProcessor("data/creditcard.csv") + processor.load_data() + processor.split_data_by_time() + processor.prepare_features_labels() + assert processor.train_features is not None + assert processor.train_labels is not None + assert processor.test_features is not None + assert processor.test_labels is not None + + +def test_load_data_function(): + processor = load_data("data/creditcard.csv") + assert processor.data is not None + assert processor.train_features is not None + assert processor.test_features is not None + + +def test_get_statistics(): + processor = load_data("data/creditcard.csv") + stats = processor.get_statistics() + assert "总记录数" in stats + assert "特征数" in stats + assert "欺诈交易数" in stats + assert "非欺诈交易数" in stats + assert stats["总记录数"] > 0 diff --git a/tests/test_features.py b/tests/test_features.py new file mode 100644 index 0000000..23a78b7 --- /dev/null +++ b/tests/test_features.py @@ -0,0 +1,137 @@ +import pytest +from src.features import ( + TransactionFeatures, EvaluationResult, ExplanationResult, + ActionPlan, DecisionResult, ModelMetrics, TrainingResult, + TransactionClass, ConfidenceLevel, Priority, + FeatureContribution, Action +) + + +def test_transaction_features(): + features = TransactionFeatures( + time=0.0, + v1=-1.36, v2=-0.07, v3=2.54, v4=1.38, v5=-0.34, + v6=0.46, v7=0.24, v8=0.10, v9=0.36, v10=0.09, + v11=-0.55, v12=-0.62, v13=-0.99, v14=-0.31, v15=1.47, + v16=-0.47, v17=0.21, v18=0.03, v19=0.40, v20=0.25, + v21=-0.02, v22=0.28, v23=-0.11, v24=0.07, v25=0.13, + v26=-0.19, v27=0.13, v28=-0.02, amount=149.62 + ) + assert features.time == 0.0 + assert features.amount == 149.62 + assert len(features.to_array()) == 30 + + +def test_evaluation_result(): + result = EvaluationResult( + predicted_class=1, + class_name=TransactionClass.FRAUD, + fraud_probability=0.95, + normal_probability=0.05, + confidence=ConfidenceLevel.HIGH + ) + assert result.predicted_class == 1 + assert result.class_name == TransactionClass.FRAUD + assert result.fraud_probability == 0.95 + assert result.confidence == ConfidenceLevel.HIGH + + +def test_feature_contribution(): + contribution = FeatureContribution( + feature_name="V14", + value=-0.99, + importance=0.15, + contribution=-0.15, + impact="负" + ) + assert contribution.feature_name == "V14" + assert contribution.value == -0.99 + assert contribution.importance == 0.15 + assert contribution.impact == "负" + + +def test_explanation_result(): + explanation = ExplanationResult( + model_type="RandomForestClassifier", + predicted_class=TransactionClass.FRAUD, + key_features=[], + overall_explanation="测试解释" + ) + assert explanation.model_type == "RandomForestClassifier" + assert explanation.predicted_class == TransactionClass.FRAUD + + +def test_action(): + action = Action( + priority=Priority.URGENT, + action="冻结账户", + reason="检测到欺诈" + ) + assert action.priority == Priority.URGENT + assert action.action == "冻结账户" + + +def test_action_plan(): + plan = ActionPlan( + predicted_class=TransactionClass.FRAUD, + confidence=ConfidenceLevel.HIGH, + actions=[] + ) + assert plan.predicted_class == TransactionClass.FRAUD + assert plan.confidence == ConfidenceLevel.HIGH + + +def test_decision_result(): + result = DecisionResult( + evaluation=EvaluationResult( + predicted_class=1, + class_name=TransactionClass.FRAUD, + fraud_probability=0.95, + normal_probability=0.05, + confidence=ConfidenceLevel.HIGH + ), + explanation=ExplanationResult( + model_type="RandomForestClassifier", + predicted_class=TransactionClass.FRAUD, + key_features=[], + overall_explanation="测试" + ), + action_plan=ActionPlan( + predicted_class=TransactionClass.FRAUD, + confidence=ConfidenceLevel.HIGH, + actions=[] + ), + timestamp="2026-01-15" + ) + assert result.evaluation.predicted_class == 1 + assert result.explanation.model_type == "RandomForestClassifier" + assert result.action_plan.predicted_class == TransactionClass.FRAUD + + +def test_model_metrics(): + metrics = ModelMetrics( + accuracy=0.95, + precision=0.90, + recall=0.85, + f1_score=0.87, + pr_auc=0.92 + ) + assert metrics.accuracy == 0.95 + assert metrics.precision == 0.90 + assert metrics.pr_auc == 0.92 + + +def test_training_result(): + result = TrainingResult( + model_name="random_forest", + metrics=ModelMetrics( + accuracy=0.95, + precision=0.90, + recall=0.85, + f1_score=0.87, + pr_auc=0.92 + ), + confusion_matrix=[[100, 5], [10, 85]] + ) + assert result.model_name == "random_forest" + assert result.metrics.accuracy == 0.95