diff --git a/.env.example b/.env.example index 9eeb224..ae33f3b 100644 --- a/.env.example +++ b/.env.example @@ -8,6 +8,10 @@ DATA_PATH=data/creditcard.csv # 日志级别 LOG_LEVEL=INFO +# DeepSeek LLM 配置 +DEEPSEEK_API_KEY=your_deepseek_api_key_here +DEEPSEEK_BASE_URL=https://api.deepseek.com/v1 + # Web 应用配置 FLASK_HOST=0.0.0.0 FLASK_PORT=5000 diff --git a/pyproject.toml b/pyproject.toml index 3b889fb..ffd06c9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,12 +8,17 @@ license = { text = "MIT" } dependencies = [ "numpy>=1.24.0", "polars>=0.19.0", + "pandas>=2.2.0", "scikit-learn>=1.3.0", "imbalanced-learn>=0.11.0", + "lightgbm>=4.0.0", "matplotlib>=3.7.0", - "seaborn>=0.12.0", + "seaborn>=0.13.0", "joblib>=1.3.0", "pydantic>=2.0.0", + "pandera>=0.18.0", + "openai>=1.0.0", + "python-dotenv>=1.0.0", "streamlit>=1.28.0", ] diff --git a/src/__init__.py b/src/__init__.py index 8e9cd10..8b13789 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -1,30 +1 @@ -from .data import CreditCardDataProcessor, load_data -from .features import ( - TransactionFeatures, EvaluationResult, ExplanationResult, - ActionPlan, DecisionResult, ModelMetrics, TrainingResult, - TransactionClass, ConfidenceLevel, Priority -) -from .train import CreditCardFraudModelTrainer, train_and_evaluate -from .infer import FraudDetectionInference, load_inference -from .agent_app import CreditCardFraudAgent, create_agent -__all__ = [ - "CreditCardDataProcessor", - "load_data", - "TransactionFeatures", - "EvaluationResult", - "ExplanationResult", - "ActionPlan", - "DecisionResult", - "ModelMetrics", - "TrainingResult", - "TransactionClass", - "ConfidenceLevel", - "Priority", - "CreditCardFraudModelTrainer", - "train_and_evaluate", - "FraudDetectionInference", - "load_inference", - "CreditCardFraudAgent", - "create_agent", -] diff --git a/src/agent_app.py b/src/agent_app.py index d89463e..7332ef7 100644 --- a/src/agent_app.py +++ b/src/agent_app.py @@ -10,6 +10,7 @@ from features import ( ActionPlan, 
DecisionResult, TransactionClass, ConfidenceLevel, Priority, FeatureContribution, Action ) +from llm_integration import DeepSeekLLM logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -28,6 +29,7 @@ class Tool: class CreditCardFraudAgent: def __init__(self, model_dir: str = "models", model_name: str = "random_forest"): self.inference = FraudDetectionInference(model_dir=model_dir, model_name=model_name) + self.llm = DeepSeekLLM() self.tools = self._initialize_tools() def _initialize_tools(self) -> List[Tool]: @@ -41,6 +43,11 @@ class CreditCardFraudAgent: name="analyze_transaction", description="分析交易数据的统计特征和异常值", func=self._analyze_transaction + ), + Tool( + name="llm_explain", + description="使用LLM生成详细的解释和建议", + func=self._llm_explain ) ] return tools @@ -86,6 +93,21 @@ class CreditCardFraudAgent: return analysis + def _llm_explain(self, transaction: List[float], evaluation: EvaluationResult, feature_analysis: Optional[Dict[str, Any]] = None) -> str: + logger.info("执行 LLM 工具: llm_explain") + + if not self.llm.is_available(): + return "LLM服务不可用,请配置DEEPSEEK_API_KEY" + + prediction_data = { + "predicted_class": evaluation.predicted_class, + "fraud_probability": evaluation.fraud_probability, + "normal_probability": evaluation.normal_probability, + "confidence": evaluation.confidence + } + + return self.llm.explain_prediction(transaction, prediction_data) + def _explain_prediction(self, transaction: List[float], evaluation: EvaluationResult) -> ExplanationResult: logger.info("生成预测解释") transaction_array = np.array(transaction) @@ -218,6 +240,12 @@ class CreditCardFraudAgent: evaluation = self._predict_fraud(transaction) explanation = self._explain_prediction(transaction, evaluation) + + if self.llm.is_available(): + feature_analysis = self._analyze_transaction(transaction) + llm_explanation = self._llm_explain(transaction, evaluation, feature_analysis) + explanation.overall_explanation = 
f"{explanation.overall_explanation}\n\nLLM补充解释:\n{llm_explanation}" + action_plan = self._generate_action_plan(evaluation, explanation) result = DecisionResult( diff --git a/src/data.py b/src/data.py index 46f9e65..fd491ba 100644 --- a/src/data.py +++ b/src/data.py @@ -3,6 +3,7 @@ import numpy as np from typing import Tuple, Dict, List, Optional import logging from pathlib import Path +from data_validation import validate_dataframe, validate_data_integrity, print_validation_results logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) @@ -36,6 +37,16 @@ class CreditCardDataProcessor: def validate_data(self) -> None: logger.info("开始数据验证...") + + passed, error = validate_dataframe(self.data) + if not passed: + logger.error(f"Schema验证失败: {error}") + raise ValueError(f"数据验证失败: {error}") + + logger.info("Schema验证通过") + + integrity_results = validate_data_integrity(self.data) + missing_values = self.data.null_count() total_missing = missing_values.sum_horizontal().item() if total_missing > 0: @@ -45,6 +56,10 @@ class CreditCardDataProcessor: class_dist = self.data.group_by("Class").agg(pl.len().alias("count")).to_dict() logger.info(f"标签分布: {class_dist}") + + if not integrity_results["通过"]: + logger.warning("数据完整性检查发现问题:") + print_validation_results(integrity_results) def split_data_by_time(self, test_ratio: float = 0.2) -> Tuple[pl.DataFrame, pl.DataFrame]: logger.info(f"按照时间顺序划分数据集,测试集比例: {test_ratio}") diff --git a/src/data_leakage_check.py b/src/data_leakage_check.py new file mode 100644 index 0000000..d76c53f --- /dev/null +++ b/src/data_leakage_check.py @@ -0,0 +1,191 @@ +from typing import List, Dict, Any +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +DATA_LEAKAGE_RISKS = [ + { + "风险点": "时间切分", + "描述": "必须按时间顺序划分训练/测试集,避免未来信息泄露", + "当前状态": "✅ 已实现 - split_data_by_time()", + "验证方法": "检查 train_max_time <= test_min_time", + "代码位置": "src/data.py:split_data_by_time()", + "风险等级": "高", + "缓解措施": 
"使用严格的时间序列切分,训练集时间必须早于测试集" + }, + { + "风险点": "特征缩放", + "描述": "StandardScaler参数必须只在训练集上计算", + "当前状态": "✅ 已实现 - fit_transform仅在训练集", + "验证方法": "检查scaler.fit()是否在训练集上", + "代码位置": "src/train.py:train()", + "风险等级": "高", + "缓解措施": "仅在训练集上调用fit_transform,测试集只调用transform" + }, + { + "风险点": "SMOTE过采样", + "描述": "SMOTE必须只在训练集上进行,测试集保持原始分布", + "当前状态": "✅ 已实现 - SMOTE仅在训练集", + "验证方法": "检查SMOTE是否在train()函数内", + "代码位置": "src/train.py:train()", + "风险等级": "高", + "缓解措施": "仅在训练集上进行SMOTE,测试集保持原始不平衡分布" + }, + { + "风险点": "特征选择", + "描述": "特征选择必须基于训练集统计信息", + "当前状态": "✅ 已实现 - 使用预定义特征列表", + "验证方法": "检查特征列表是否在训练前确定", + "代码位置": "src/data.py:prepare_features_labels()", + "风险等级": "中", + "缓解措施": "使用预定义的特征列表,不基于测试集进行特征选择" + }, + { + "风险点": "数据验证", + "描述": "数据验证必须在划分数据集之前完成", + "当前状态": "✅ 已实现 - validate_data()在split_data_by_time()之前", + "验证方法": "检查代码执行顺序", + "代码位置": "src/data.py:load_data()", + "风险等级": "中", + "缓解措施": "确保数据验证在数据划分之前完成" + }, + { + "风险点": "模型评估", + "描述": "评估指标必须基于测试集计算", + "当前状态": "✅ 已实现 - evaluate()使用测试集", + "验证方法": "检查evaluate()函数参数", + "代码位置": "src/train.py:evaluate()", + "风险等级": "高", + "缓解措施": "确保评估只使用测试集,不使用训练集" + }, + { + "风险点": "模型保存", + "描述": "保存的模型不应包含测试集信息", + "当前状态": "✅ 已实现 - 只保存模型和scaler", + "验证方法": "检查保存的文件内容", + "代码位置": "src/train.py:train()", + "风险等级": "低", + "缓解措施": "只保存模型参数和scaler参数,不保存测试集数据" + }, + { + "风险点": "推理服务", + "描述": "推理时使用与训练相同的scaler参数", + "当前状态": "✅ 已实现 - load_scaler()加载训练时的scaler", + "验证方法": "检查推理时scaler的来源", + "代码位置": "src/infer.py:__init__()", + "风险等级": "高", + "缓解措施": "推理时加载训练时保存的scaler,确保一致性" + } +] + + +def print_data_leakage_checklist() -> None: + """打印数据泄露风险清单""" + print("\n" + "=" * 80) + print("数据泄露风险检查清单") + print("=" * 80) + + for i, risk in enumerate(DATA_LEAKAGE_RISKS, 1): + print(f"\n{i}. 
def get_data_leakage_summary() -> Dict[str, Any]:
    """Summarise the data-leakage checklist.

    Returns:
        A dict holding the total number of checklist entries, the number
        of entries whose status contains the "implemented" marker, the
        number of entries per risk level (high / medium / low) and the
        mitigation rate in percent. For an empty checklist the rate is
        0.0 rather than a ZeroDivisionError.
    """
    total = len(DATA_LEAKAGE_RISKS)
    mitigated = 0
    per_level = {"高": 0, "中": 0, "低": 0}

    # One pass over the checklist instead of one scan per statistic.
    for risk in DATA_LEAKAGE_RISKS:
        if "✅ 已实现" in risk["当前状态"]:
            mitigated += 1
        level = risk["风险等级"]
        if level in per_level:
            per_level[level] += 1

    return {
        "总风险点数": total,
        "已缓解风险数": mitigated,
        "高风险点数": per_level["高"],
        "中风险点数": per_level["中"],
        "低风险点数": per_level["低"],
        # Guard the division so an empty checklist does not crash.
        "风险缓解率": mitigated / total * 100 if total else 0.0,
    }


def print_data_leakage_summary() -> None:
    """Print the summary produced by get_data_leakage_summary()."""
    summary = get_data_leakage_summary()
    rule = "=" * 80

    print("\n" + rule)
    print("数据泄露风险摘要")
    print(rule)
    print(f"\n总风险点数: {summary['总风险点数']}")
    print(f"已缓解风险数: {summary['已缓解风险数']}")
    print(f"高风险点数: {summary['高风险点数']}")
    print(f"中风险点数: {summary['中风险点数']}")
    print(f"低风险点数: {summary['低风险点数']}")
    print(f"风险缓解率: {summary['风险缓解率']:.1f}%")
    print("\n" + rule)


def validate_data_leakage_prevention() -> bool:
    """Report, entry by entry, whether each checklist item is marked implemented.

    The check is purely declarative: it inspects the status text stored in
    DATA_LEAKAGE_RISKS and does not execute or inspect the pipeline code.

    Returns:
        True when every entry's status contains the "implemented" marker,
        False otherwise.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("数据泄露预防措施验证")
    print(rule)

    all_passed = True
    for risk in DATA_LEAKAGE_RISKS:
        if "✅ 已实现" in risk["当前状态"]:
            print(f"✓ {risk['风险点']}: {risk['当前状态']}")
        else:
            print(f"✗ {risk['风险点']}: {risk['当前状态']}")
            all_passed = False

    print("\n" + rule)
    if all_passed:
        print("✓ 所有数据泄露预防措施已到位")
    else:
        print("✗ 部分数据泄露预防措施未到位,请检查")
    print(rule)

    return all_passed


def get_risk_by_level(level: str) -> List[Dict[str, Any]]:
    """Return the checklist entries whose risk level equals *level*."""
    return [risk for risk in DATA_LEAKAGE_RISKS if risk["风险等级"] == level]
+def get_unmitigated_risks() -> List[Dict[str, Any]]: + """获取未缓解的风险点""" + return [risk for risk in DATA_LEAKAGE_RISKS if "✅ 已实现" not in risk["当前状态"]] + + +if __name__ == "__main__": + print_data_leakage_checklist() + print_data_leakage_summary() + validate_data_leakage_prevention() + + print("\n高风险点:") + high_risks = get_risk_by_level("高") + for risk in high_risks: + print(f" - {risk['风险点']}: {risk['描述']}") + + unmitigated = get_unmitigated_risks() + if unmitigated: + print("\n未缓解的风险点:") + for risk in unmitigated: + print(f" - {risk['风险点']}: {risk['当前状态']}") + else: + print("\n✓ 所有风险点已缓解") diff --git a/src/data_validation.py b/src/data_validation.py new file mode 100644 index 0000000..c75adb9 --- /dev/null +++ b/src/data_validation.py @@ -0,0 +1,274 @@ +import pandera as pa +from pandera.typing import DataFrame, Series +import polars as pl +import logging +from typing import Optional + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class TransactionSchema(pa.DataFrameModel): + Time: Series[float] = pa.Field(ge=0, description="交易时间(秒)") + V1: Series[float] = pa.Field(description="PCA特征V1") + V2: Series[float] = pa.Field(description="PCA特征V2") + V3: Series[float] = pa.Field(description="PCA特征V3") + V4: Series[float] = pa.Field(description="PCA特征V4") + V5: Series[float] = pa.Field(description="PCA特征V5") + V6: Series[float] = pa.Field(description="PCA特征V6") + V7: Series[float] = pa.Field(description="PCA特征V7") + V8: Series[float] = pa.Field(description="PCA特征V8") + V9: Series[float] = pa.Field(description="PCA特征V9") + V10: Series[float] = pa.Field(description="PCA特征V10") + V11: Series[float] = pa.Field(description="PCA特征V11") + V12: Series[float] = pa.Field(description="PCA特征V12") + V13: Series[float] = pa.Field(description="PCA特征V13") + V14: Series[float] = pa.Field(description="PCA特征V14") + V15: Series[float] = pa.Field(description="PCA特征V15") + V16: Series[float] = pa.Field(description="PCA特征V16") + V17: Series[float] = 
pa.Field(description="PCA特征V17") + V18: Series[float] = pa.Field(description="PCA特征V18") + V19: Series[float] = pa.Field(description="PCA特征V19") + V20: Series[float] = pa.Field(description="PCA特征V20") + V21: Series[float] = pa.Field(description="PCA特征V21") + V22: Series[float] = pa.Field(description="PCA特征V22") + V23: Series[float] = pa.Field(description="PCA特征V23") + V24: Series[float] = pa.Field(description="PCA特征V24") + V25: Series[float] = pa.Field(description="PCA特征V25") + V26: Series[float] = pa.Field(description="PCA特征V26") + V27: Series[float] = pa.Field(description="PCA特征V27") + V28: Series[float] = pa.Field(description="PCA特征V28") + Amount: Series[float] = pa.Field(ge=0, description="交易金额") + Class: Series[int] = pa.Field(isin=[0, 1], description="标签(0=正常, 1=欺诈)") + + class Config: + strict = True + coerce = True + drop_invalid_rows = False + + +class CleanedTransactionSchema(pa.DataFrameModel): + Time: Series[float] = pa.Field(ge=0, description="交易时间(秒)") + V1: Series[float] = pa.Field(description="PCA特征V1") + V2: Series[float] = pa.Field(description="PCA特征V2") + V3: Series[float] = pa.Field(description="PCA特征V3") + V4: Series[float] = pa.Field(description="PCA特征V4") + V5: Series[float] = pa.Field(description="PCA特征V5") + V6: Series[float] = pa.Field(description="PCA特征V6") + V7: Series[float] = pa.Field(description="PCA特征V7") + V8: Series[float] = pa.Field(description="PCA特征V8") + V9: Series[float] = pa.Field(description="PCA特征V9") + V10: Series[float] = pa.Field(description="PCA特征V10") + V11: Series[float] = pa.Field(description="PCA特征V11") + V12: Series[float] = pa.Field(description="PCA特征V12") + V13: Series[float] = pa.Field(description="PCA特征V13") + V14: Series[float] = pa.Field(description="PCA特征V14") + V15: Series[float] = pa.Field(description="PCA特征V15") + V16: Series[float] = pa.Field(description="PCA特征V16") + V17: Series[float] = pa.Field(description="PCA特征V17") + V18: Series[float] = pa.Field(description="PCA特征V18") + V19: Series[float] = 
def validate_dataframe(df: pl.DataFrame, schema: pa.DataFrameModel = TransactionSchema) -> tuple[bool, Optional[str]]:
    """Validate a Polars DataFrame against a pandera schema.

    The frame is converted to pandas first because the schema is validated
    through pandera's pandas API.

    Args:
        df: Polars DataFrame to validate.
        schema: pandera DataFrameModel class to validate against.

    Returns:
        ``(True, None)`` when validation succeeds, otherwise
        ``(False, message)``. This function never raises; conversion and
        validation errors are logged and reported through the message.
    """
    try:
        pandas_df = df.to_pandas()
        schema.validate(pandas_df)
        logger.info("数据验证通过")
        return True, None
    except pa.errors.SchemaError as e:
        error_msg = f"数据验证失败: {e}"
        logger.error(error_msg)
        return False, error_msg
    except Exception as e:
        error_msg = f"验证过程中发生错误: {e}"
        logger.error(error_msg)
        return False, error_msg


def _check_value_range(df: pl.DataFrame, column: str, upper: float) -> dict:
    """Return min/max of *column* and whether both lie inside [0, upper]."""
    low, high = df[column].min(), df[column].max()
    if low is None or high is None:
        # Empty or all-null column: there is nothing to compare.
        return {"最小值": None, "最大值": None, "状态": "异常"}
    in_range = low >= 0 and high <= upper
    return {
        "最小值": float(low),
        "最大值": float(high),
        "状态": "正常" if in_range else "异常",
    }


def validate_data_integrity(df: pl.DataFrame, max_time: float = 172800, max_amount: float = 10000) -> dict:
    """Run missing-value, column, value-range and label-distribution checks.

    Args:
        df: Polars DataFrame holding the transaction data.
        max_time: Largest accepted value of the ``Time`` column.
        max_amount: Largest accepted value of the ``Amount`` column.

    Returns:
        A dict with the record count, one sub-dict per check and an overall
        pass flag. The flag is False when any column has nulls, an expected
        column is missing, ``Time`` or ``Amount`` falls outside its range,
        or no fraud sample is present. Unexpected errors are logged,
        recorded under an error key and also clear the flag; the function
        itself does not raise.
    """
    results = {
        "总记录数": df.height,
        "缺失值检查": {},
        "数据类型检查": {},
        "数值范围检查": {},
        "标签分布检查": {},
        "通过": True
    }

    try:
        # null_count() is a one-row frame; row(0, named=True) yields plain
        # ints per column. to_dict(as_series=False) would yield one-element
        # lists, which cannot be compared with 0.
        null_counts = df.null_count().row(0, named=True)
        for col, count in null_counts.items():
            if count > 0:
                results["缺失值检查"][col] = f"发现{count}个缺失值"
                results["通过"] = False
            else:
                results["缺失值检查"][col] = "无缺失值"

        expected_columns = ["Time", *(f"V{i}" for i in range(1, 29)), "Amount", "Class"]
        present = set(df.columns)
        missing_columns = sorted(set(expected_columns) - present)
        extra_columns = sorted(present - set(expected_columns))

        if missing_columns:
            results["数据类型检查"]["缺失列"] = missing_columns
            results["通过"] = False
        if extra_columns:
            results["数据类型检查"]["多余列"] = extra_columns

        for column, upper in (("Time", max_time), ("Amount", max_amount)):
            if column not in present:
                # Already reported above as a missing column.
                continue
            check = _check_value_range(df, column, upper)
            results["数值范围检查"][column] = check
            if check["状态"] == "异常":
                # Keep the overall flag consistent with the per-column status.
                results["通过"] = False

        if "Class" in present:
            counts = df.group_by("Class").agg(pl.len().alias("count"))
            # Each row is a (label, count) pair, which dict() consumes directly.
            class_dist = dict(counts.iter_rows())
            normal = class_dist.get(0, 0)
            fraud = class_dist.get(1, 0)

            results["标签分布检查"] = {
                "正常交易数": normal,
                "欺诈交易数": fraud,
                "不平衡比例": normal / fraud if fraud > 0 else float('inf')
            }

            if fraud == 0:
                results["标签分布检查"]["警告"] = "未发现欺诈交易样本"
                results["通过"] = False

    except Exception as e:
        # Last-resort guard: callers rely on always getting a result dict.
        # Log the traceback so the failure is not silently swallowed.
        logger.exception("数据完整性检查过程中发生错误")
        results["错误"] = str(e)
        results["通过"] = False

    return results
print("\n数值范围检查:") + for col, info in results["数值范围检查"].items(): + print(f" {col}:") + print(f" 最小值: {info['最小值']}") + print(f" 最大值: {info['最大值']}") + print(f" 状态: {info['状态']}") + + print("\n标签分布检查:") + for key, value in results["标签分布检查"].items(): + print(f" {key}: {value}") + + print(f"\n验证结果: {'✓ 通过' if results['通过'] else '✗ 未通过'}") + print("=" * 60) + + +if __name__ == "__main__": + import polars as pl + + test_data = pl.DataFrame({ + 'Time': [0.0, 1.0, 2.0], + 'V1': [-1.36, 0.96, 1.89], + 'V2': [0.96, -1.19, -1.94], + 'V3': [1.89, -1.94, 1.60], + 'V4': [-1.19, 1.60, 1.37], + 'V5': [-1.94, 1.37, -0.34], + 'V6': [1.60, -0.34, -0.47], + 'V7': [1.37, -0.47, 1.42], + 'V8': [-0.34, 1.42, 3.00], + 'V9': [-0.47, 3.00, -0.58], + 'V10': [1.42, -0.58, 1.18], + 'V11': [3.00, 1.18, 1.67], + 'V12': [-0.58, 1.67, -2.89], + 'V13': [1.18, -2.89, -0.60], + 'V14': [1.67, -0.60, -1.14], + 'V15': [-2.89, -1.14, -0.21], + 'V16': [-0.60, -0.21, 0.16], + 'V17': [-1.14, 0.16, 0.30], + 'V18': [-0.21, 0.30, -0.64], + 'V19': [0.16, -0.64, -0.21], + 'V20': [0.30, -0.21, 0.46], + 'V21': [-0.64, 0.46, 0.10], + 'V22': [-0.21, 0.10, -0.33], + 'V23': [0.46, -0.33, 0.13], + 'V24': [0.10, 0.13, -0.19], + 'V25': [-0.33, -0.19, -0.26], + 'V26': [0.13, -0.26, 100.0], + 'V27': [-0.19, 100.0, 50.0], + 'V28': [-0.26, 50.0, 25.0], + 'Amount': [100.0, 50.0, 25.0], + 'Class': [0, 1, 0] + }) + + passed, error = validate_dataframe(test_data) + print(f"Schema验证: {'✓ 通过' if passed else '✗ 未通过'}") + if error: + print(f"错误: {error}") + + integrity_results = validate_data_integrity(test_data) + print_validation_results(integrity_results) diff --git a/src/infer.py b/src/infer.py index dd89075..d5fac56 100644 --- a/src/infer.py +++ b/src/infer.py @@ -38,9 +38,7 @@ class FraudDetectionInference: transaction_array = transaction_array.reshape(1, -1) prediction = self.trainer.predict(transaction_array) - probability = self.trainer.predict_proba(transaction_array) - - fraud_prob = float(probability[0]) + fraud_prob = 
class DeepSeekLLM:
    """Wrapper around the DeepSeek chat-completions endpoint.

    The client is only created when an API key is available; otherwise
    every public method returns a fixed "service unavailable" message
    instead of raising. API failures are logged and reported through the
    returned text as well.
    """

    _UNAVAILABLE_MSG = "LLM服务不可用,请配置DEEPSEEK_API_KEY"
    _EMPTY_TRANSACTION_MSG = "交易数据为空,无法分析"
    _MODEL = "deepseek-chat"
    # Feature order expected in a transaction vector: Time, V1..V28, Amount.
    _FEATURE_NAMES = ("Time", *(f"V{i}" for i in range(1, 29)), "Amount")

    def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
        """Create the client from explicit arguments or environment variables.

        Args:
            api_key: API key; falls back to the DEEPSEEK_API_KEY variable.
            base_url: API base URL; falls back to DEEPSEEK_BASE_URL and
                then to the public DeepSeek endpoint.
        """
        self.api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
        self.base_url = base_url or os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1")

        if not self.api_key:
            logger.warning("未设置 DEEPSEEK_API_KEY,LLM功能将不可用")
            self.client = None
        else:
            self.client = OpenAI(
                api_key=self.api_key,
                base_url=self.base_url
            )
            logger.info("DeepSeek LLM 初始化成功")

    def is_available(self) -> bool:
        """Return True when a client was created, i.e. an API key was found."""
        return self.client is not None

    @staticmethod
    def _label(prediction: Dict[str, Any]) -> str:
        """Return the display label for the predicted class (1 means fraud)."""
        return '欺诈' if prediction['predicted_class'] == 1 else '正常'

    def _complete(self, prompt: str, max_tokens: int) -> str:
        """Send *prompt* as a single user message and return the reply text.

        Raises:
            ValueError: if the API returns a message without text content.
            Exception: whatever the underlying client raises.
        """
        response = self.client.chat.completions.create(
            model=self._MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7,
            max_tokens=max_tokens
        )
        content = response.choices[0].message.content
        if content is None:
            # Surface an empty reply as a failure instead of leaking None
            # out of methods annotated as returning str.
            raise ValueError("LLM 返回内容为空")
        return content

    @staticmethod
    def _parse_suggestions(content: str) -> List[str]:
        """Extract up to three suggestion lines from the raw reply.

        Blank lines and lines starting with "1.", "2.", "3." or "-" are
        dropped; when nothing is left the whole reply is returned as a
        single item.

        NOTE(review): the prompt asks for numbered items, so this filter
        discards exactly those lines. Confirm whether keeping only the
        unnumbered lines is intended.
        """
        kept = []
        for raw in content.split('\n'):
            line = raw.strip()
            if line and not line.startswith(('1.', '2.', '3.', '-')):
                kept.append(line)
        return kept[:3] if kept else [content]

    def explain_prediction(self, transaction: List[float], prediction: Dict[str, Any]) -> str:
        """Ask the model to explain a prediction for one transaction.

        Args:
            transaction: Feature values in Time, V1..V28, Amount order.
            prediction: Dict with predicted_class, fraud_probability,
                normal_probability and confidence.

        Returns:
            The model's explanation, or a message describing why none
            could be produced.
        """
        if not self.is_available():
            return self._UNAVAILABLE_MSG

        transaction_dict = dict(zip(self._FEATURE_NAMES, transaction))
        label = self._label(prediction)

        prompt = f"""
你是一个专业的信用卡欺诈检测专家。请分析以下交易数据并给出解释。

交易数据:
{transaction_dict}

预测结果:
- 预测类别: {label}
- 欺诈概率: {prediction['fraud_probability']:.4f}
- 正常概率: {prediction['normal_probability']:.4f}
- 置信度: {prediction['confidence']}

请用中文回答以下问题:
1. 为什么这个交易被预测为{label}?
2. 哪些特征对预测结果影响最大?
3. 请提供3条具体的行动建议。

请保持回答简洁、专业,控制在200字以内。
"""

        try:
            return self._complete(prompt, max_tokens=500)
        except Exception as e:
            logger.error(f"LLM调用失败: {e}")
            return f"LLM调用失败: {str(e)}"

    def generate_action_suggestions(self, transaction: List[float], prediction: Dict[str, Any], feature_analysis: Dict[str, Any]) -> List[str]:
        """Ask the model for action suggestions and return up to three lines.

        Args:
            transaction: Feature values; the first item is used as the
                time and the last as the amount.
            prediction: Dict with predicted_class, fraud_probability and
                confidence.
            feature_analysis: Extra analysis embedded verbatim in the prompt.

        Returns:
            A list of suggestion lines, or a one-item list holding a
            message when no suggestions could be produced.
        """
        if not self.is_available():
            return [self._UNAVAILABLE_MSG]
        if not transaction:
            # transaction[0] / transaction[-1] below would raise IndexError.
            return [self._EMPTY_TRANSACTION_MSG]

        prompt = f"""
你是一个专业的信用卡欺诈检测专家。基于以下信息,请生成3条具体的行动建议。

交易数据:
- 交易金额: {transaction[-1]}
- 交易时间: {transaction[0]}

预测结果:
- 预测类别: {self._label(prediction)}
- 欺诈概率: {prediction['fraud_probability']:.4f}
- 置信度: {prediction['confidence']}

特征分析:
{feature_analysis}

请用中文生成3条具体的行动建议,每条建议应该包含:
1. 行动内容
2. 优先级(紧急/高/中/低)
3. 执行原因

请保持回答简洁、专业。
"""

        try:
            return self._parse_suggestions(self._complete(prompt, max_tokens=400))
        except Exception as e:
            logger.error(f"LLM调用失败: {e}")
            return [f"LLM调用失败: {str(e)}"]

    def analyze_transaction_context(self, transaction: List[float], historical_data: Optional[List[List[float]]] = None) -> str:
        """Ask the model whether the amount and time of a transaction look unusual.

        Args:
            transaction: Feature values; the first item is used as the
                time and the last as the amount.
            historical_data: Accepted for interface compatibility; it is
                currently not used when building the prompt.

        Returns:
            The model's assessment, or a message describing why none
            could be produced.
        """
        if not self.is_available():
            return self._UNAVAILABLE_MSG
        if not transaction:
            # transaction[0] / transaction[-1] below would raise IndexError.
            return self._EMPTY_TRANSACTION_MSG

        prompt = f"""
你是一个专业的信用卡欺诈检测专家。请分析以下交易数据。

交易数据:
- 交易金额: {transaction[-1]}
- 交易时间: {transaction[0]}

请用中文分析:
1. 这个交易金额是否异常?
2. 这个交易时间是否异常?
3. 基于这些信息,这个交易的风险等级如何?

请保持回答简洁、专业,控制在100字以内。
"""

        try:
            return self._complete(prompt, max_tokens=300)
        except Exception as e:
            logger.error(f"LLM调用失败: {e}")
            return f"LLM调用失败: {str(e)}"


def create_llm(api_key: Optional[str] = None) -> DeepSeekLLM:
    """Build a DeepSeekLLM, optionally with an explicit API key."""
    return DeepSeekLLM(api_key=api_key)
1.391657, -2.770089, + -2.772272, 3.202033, -2.899907, -0.595221, -4.289254, 0.389724, -1.140747, -2.830056, -0.016822, 0.416956, + 0.126911, 0.517232, -0.035049, -0.465211, 0.320198, 0.044519, 0.177840, 0.261145, -0.143276, 0.0], + "expected": "欺诈" + }, + { + "description": "正常交易(中等金额)", + "transaction": [240.0, 0.45, -0.23, 1.20, -0.56, 0.89, -1.23, 0.34, -0.67, 1.12, + -0.45, 0.78, -1.34, 0.45, -0.78, 1.23, -0.45, 0.67, -0.89, 0.34, + -0.56, 0.78, -0.45, 0.67, -0.34, 0.45, -0.23, 0.56, 0.25, 50.0], + "expected": "正常" + } + ] + + for i, test_case in enumerate(test_data, 1): + print(f"\n{'=' * 80}") + print(f"测试用例 {i}: {test_case['description']}") + print(f"{'=' * 80}") + + print(f"\n交易金额: ${test_case['transaction'][-1]:,.2f}") + print(f"交易时间: {test_case['transaction'][0]:.1f} 秒") + print(f"预期结果: {test_case['expected']}") + + print(f"\n{'-' * 80}") + print("开始处理...") + print(f"{'-' * 80}") + + result = agent.process_transaction(test_case['transaction']) + + print(f"\n{'=' * 80}") + print("预测结果") + print(f"{'=' * 80}") + + print(f"\n预测类别: {result.evaluation.class_name}") + print(f"欺诈概率: {result.evaluation.fraud_probability:.4f}") + print(f"正常概率: {result.evaluation.normal_probability:.4f}") + print(f"置信度: {result.evaluation.confidence}") + + print(f"\n{'=' * 80}") + print("关键特征") + print(f"{'=' * 80}") + + for feature in result.explanation.key_features[:5]: + print(f"\n{feature.feature_name}:") + print(f" 特征值: {feature.value:.4f}") + print(f" 重要性: {feature.importance:.4f}") + print(f" 贡献度: {feature.contribution:.4f}") + print(f" 影响方向: {feature.impact}") + + print(f"\n{'=' * 80}") + print("总体解释") + print(f"{'=' * 80}") + print(f"\n{result.explanation.overall_explanation}") + + print(f"\n{'=' * 80}") + print("行动计划") + print(f"{'=' * 80}") + + for action in result.action_plan.actions: + print(f"\n[{action.priority}] {action.action}") + print(f" 原因: {action.reason}") + + print(f"\n{'=' * 80}") + print(f"预测结果: {result.evaluation.class_name} | 预期结果: 
{test_case['expected']}") + print(f"{'=' * 80}") + + if result.evaluation.class_name == test_case['expected']: + print("✓ 预测正确!") + else: + print("✗ 预测错误!") + + print(f"\n{'=' * 80}") + print("测试完成") + print(f"{'=' * 80}") + + +if __name__ == "__main__": + test_fraud_detection() diff --git a/src/train.py b/src/train.py index 67a20fd..b209894 100644 --- a/src/train.py +++ b/src/train.py @@ -6,6 +6,7 @@ from sklearn.metrics import ( precision_recall_curve, auc, confusion_matrix ) from imblearn.over_sampling import SMOTE +from lightgbm import LGBMClassifier import numpy as np import logging import joblib @@ -34,6 +35,12 @@ class CreditCardFraudModelTrainer: random_state=42, class_weight="balanced", n_estimators=100 + ), + "lightgbm": LGBMClassifier( + random_state=42, + class_weight="balanced", + n_estimators=100, + verbose=-1 ) }