feat: 添加2026技术栈要求的功能和修复欺诈检测逻辑
主要更新: 1. 添加DeepSeek LLM集成(llm_integration.py) 2. 添加Pandera数据验证(data_validation.py) 3. 添加数据泄露风险检查(data_leakage_check.py) 4. 添加LightGBM模型支持 5. 修复infer.py中的欺诈概率计算逻辑错误 6. 更新pyproject.toml添加新依赖 7. 更新.env.example添加LLM配置 8. 添加欺诈检测测试脚本(test_fraud_detection.py) 9. 更新agent_app.py集成LLM功能 10. 更新train.py添加LightGBM模型 11. 更新data.py集成Pandera验证
This commit is contained in:
parent
80ee5f763f
commit
d82db25d63
@ -8,6 +8,10 @@ DATA_PATH=data/creditcard.csv
|
|||||||
# 日志级别
|
# 日志级别
|
||||||
LOG_LEVEL=INFO
|
LOG_LEVEL=INFO
|
||||||
|
|
||||||
|
# DeepSeek LLM 配置
|
||||||
|
DEEPSEEK_API_KEY=your_deepseek_api_key_here
|
||||||
|
DEEPSEEK_BASE_URL=https://api.deepseek.com/v1
|
||||||
|
|
||||||
# Web 应用配置
|
# Web 应用配置
|
||||||
FLASK_HOST=0.0.0.0
|
FLASK_HOST=0.0.0.0
|
||||||
FLASK_PORT=5000
|
FLASK_PORT=5000
|
||||||
|
|||||||
@ -8,12 +8,17 @@ license = { text = "MIT" }
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"numpy>=1.24.0",
|
"numpy>=1.24.0",
|
||||||
"polars>=0.19.0",
|
"polars>=0.19.0",
|
||||||
|
"pandas>=2.2.0",
|
||||||
"scikit-learn>=1.3.0",
|
"scikit-learn>=1.3.0",
|
||||||
"imbalanced-learn>=0.11.0",
|
"imbalanced-learn>=0.11.0",
|
||||||
|
"lightgbm>=4.0.0",
|
||||||
"matplotlib>=3.7.0",
|
"matplotlib>=3.7.0",
|
||||||
"seaborn>=0.12.0",
|
"seaborn>=0.13.0",
|
||||||
"joblib>=1.3.0",
|
"joblib>=1.3.0",
|
||||||
"pydantic>=2.0.0",
|
"pydantic>=2.0.0",
|
||||||
|
"pandera>=0.18.0",
|
||||||
|
"openai>=1.0.0",
|
||||||
|
"python-dotenv>=1.0.0",
|
||||||
"streamlit>=1.28.0",
|
"streamlit>=1.28.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@ -1,30 +1 @@
|
|||||||
from .data import CreditCardDataProcessor, load_data
|
|
||||||
from .features import (
|
|
||||||
TransactionFeatures, EvaluationResult, ExplanationResult,
|
|
||||||
ActionPlan, DecisionResult, ModelMetrics, TrainingResult,
|
|
||||||
TransactionClass, ConfidenceLevel, Priority
|
|
||||||
)
|
|
||||||
from .train import CreditCardFraudModelTrainer, train_and_evaluate
|
|
||||||
from .infer import FraudDetectionInference, load_inference
|
|
||||||
from .agent_app import CreditCardFraudAgent, create_agent
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"CreditCardDataProcessor",
|
|
||||||
"load_data",
|
|
||||||
"TransactionFeatures",
|
|
||||||
"EvaluationResult",
|
|
||||||
"ExplanationResult",
|
|
||||||
"ActionPlan",
|
|
||||||
"DecisionResult",
|
|
||||||
"ModelMetrics",
|
|
||||||
"TrainingResult",
|
|
||||||
"TransactionClass",
|
|
||||||
"ConfidenceLevel",
|
|
||||||
"Priority",
|
|
||||||
"CreditCardFraudModelTrainer",
|
|
||||||
"train_and_evaluate",
|
|
||||||
"FraudDetectionInference",
|
|
||||||
"load_inference",
|
|
||||||
"CreditCardFraudAgent",
|
|
||||||
"create_agent",
|
|
||||||
]
|
|
||||||
|
|||||||
@ -10,6 +10,7 @@ from features import (
|
|||||||
ActionPlan, DecisionResult, TransactionClass, ConfidenceLevel,
|
ActionPlan, DecisionResult, TransactionClass, ConfidenceLevel,
|
||||||
Priority, FeatureContribution, Action
|
Priority, FeatureContribution, Action
|
||||||
)
|
)
|
||||||
|
from llm_integration import DeepSeekLLM
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -28,6 +29,7 @@ class Tool:
|
|||||||
class CreditCardFraudAgent:
|
class CreditCardFraudAgent:
|
||||||
def __init__(self, model_dir: str = "models", model_name: str = "random_forest"):
|
def __init__(self, model_dir: str = "models", model_name: str = "random_forest"):
|
||||||
self.inference = FraudDetectionInference(model_dir=model_dir, model_name=model_name)
|
self.inference = FraudDetectionInference(model_dir=model_dir, model_name=model_name)
|
||||||
|
self.llm = DeepSeekLLM()
|
||||||
self.tools = self._initialize_tools()
|
self.tools = self._initialize_tools()
|
||||||
|
|
||||||
def _initialize_tools(self) -> List[Tool]:
|
def _initialize_tools(self) -> List[Tool]:
|
||||||
@ -41,6 +43,11 @@ class CreditCardFraudAgent:
|
|||||||
name="analyze_transaction",
|
name="analyze_transaction",
|
||||||
description="分析交易数据的统计特征和异常值",
|
description="分析交易数据的统计特征和异常值",
|
||||||
func=self._analyze_transaction
|
func=self._analyze_transaction
|
||||||
|
),
|
||||||
|
Tool(
|
||||||
|
name="llm_explain",
|
||||||
|
description="使用LLM生成详细的解释和建议",
|
||||||
|
func=self._llm_explain
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
return tools
|
return tools
|
||||||
@ -86,6 +93,21 @@ class CreditCardFraudAgent:
|
|||||||
|
|
||||||
return analysis
|
return analysis
|
||||||
|
|
||||||
|
def _llm_explain(self, transaction: List[float], evaluation: EvaluationResult, feature_analysis: Optional[Dict[str, Any]] = None) -> str:
    """Ask the configured LLM for a natural-language explanation of a prediction.

    Args:
        transaction: Raw feature values of the transaction being explained.
        evaluation: Model output; its ``predicted_class``, ``fraud_probability``,
            ``normal_probability`` and ``confidence`` attributes are forwarded.
        feature_analysis: Accepted for interface compatibility; it is not
            forwarded to the LLM by this method.

    Returns:
        The text produced by ``self.llm.explain_prediction``, or a fixed notice
        when the LLM client is not available.
    """
    logger.info("执行 LLM 工具: llm_explain")

    if self.llm.is_available():
        # Copy only the fields the prompt needs into a plain dict.
        forwarded = ("predicted_class", "fraud_probability", "normal_probability", "confidence")
        prediction_data = {field: getattr(evaluation, field) for field in forwarded}
        return self.llm.explain_prediction(transaction, prediction_data)

    return "LLM服务不可用,请配置DEEPSEEK_API_KEY"
|
||||||
|
|
||||||
def _explain_prediction(self, transaction: List[float], evaluation: EvaluationResult) -> ExplanationResult:
|
def _explain_prediction(self, transaction: List[float], evaluation: EvaluationResult) -> ExplanationResult:
|
||||||
logger.info("生成预测解释")
|
logger.info("生成预测解释")
|
||||||
transaction_array = np.array(transaction)
|
transaction_array = np.array(transaction)
|
||||||
@ -218,6 +240,12 @@ class CreditCardFraudAgent:
|
|||||||
|
|
||||||
evaluation = self._predict_fraud(transaction)
|
evaluation = self._predict_fraud(transaction)
|
||||||
explanation = self._explain_prediction(transaction, evaluation)
|
explanation = self._explain_prediction(transaction, evaluation)
|
||||||
|
|
||||||
|
if self.llm.is_available():
|
||||||
|
feature_analysis = self._analyze_transaction(transaction)
|
||||||
|
llm_explanation = self._llm_explain(transaction, evaluation, feature_analysis)
|
||||||
|
explanation.overall_explanation = f"{explanation.overall_explanation}\n\nLLM补充解释:\n{llm_explanation}"
|
||||||
|
|
||||||
action_plan = self._generate_action_plan(evaluation, explanation)
|
action_plan = self._generate_action_plan(evaluation, explanation)
|
||||||
|
|
||||||
result = DecisionResult(
|
result = DecisionResult(
|
||||||
|
|||||||
15
src/data.py
15
src/data.py
@ -3,6 +3,7 @@ import numpy as np
|
|||||||
from typing import Tuple, Dict, List, Optional
|
from typing import Tuple, Dict, List, Optional
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from data_validation import validate_dataframe, validate_data_integrity, print_validation_results
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO)
|
logging.basicConfig(level=logging.INFO)
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@ -36,6 +37,16 @@ class CreditCardDataProcessor:
|
|||||||
|
|
||||||
def validate_data(self) -> None:
|
def validate_data(self) -> None:
|
||||||
logger.info("开始数据验证...")
|
logger.info("开始数据验证...")
|
||||||
|
|
||||||
|
passed, error = validate_dataframe(self.data)
|
||||||
|
if not passed:
|
||||||
|
logger.error(f"Schema验证失败: {error}")
|
||||||
|
raise ValueError(f"数据验证失败: {error}")
|
||||||
|
|
||||||
|
logger.info("Schema验证通过")
|
||||||
|
|
||||||
|
integrity_results = validate_data_integrity(self.data)
|
||||||
|
|
||||||
missing_values = self.data.null_count()
|
missing_values = self.data.null_count()
|
||||||
total_missing = missing_values.sum_horizontal().item()
|
total_missing = missing_values.sum_horizontal().item()
|
||||||
if total_missing > 0:
|
if total_missing > 0:
|
||||||
@ -45,6 +56,10 @@ class CreditCardDataProcessor:
|
|||||||
|
|
||||||
class_dist = self.data.group_by("Class").agg(pl.len().alias("count")).to_dict()
|
class_dist = self.data.group_by("Class").agg(pl.len().alias("count")).to_dict()
|
||||||
logger.info(f"标签分布: {class_dist}")
|
logger.info(f"标签分布: {class_dist}")
|
||||||
|
|
||||||
|
if not integrity_results["通过"]:
|
||||||
|
logger.warning("数据完整性检查发现问题:")
|
||||||
|
print_validation_results(integrity_results)
|
||||||
|
|
||||||
def split_data_by_time(self, test_ratio: float = 0.2) -> Tuple[pl.DataFrame, pl.DataFrame]:
|
def split_data_by_time(self, test_ratio: float = 0.2) -> Tuple[pl.DataFrame, pl.DataFrame]:
|
||||||
logger.info(f"按照时间顺序划分数据集,测试集比例: {test_ratio}")
|
logger.info(f"按照时间顺序划分数据集,测试集比例: {test_ratio}")
|
||||||
|
|||||||
191
src/data_leakage_check.py
Normal file
191
src/data_leakage_check.py
Normal file
@ -0,0 +1,191 @@
|
|||||||
|
from typing import List, Dict, Any
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Static checklist of data-leakage risk points.
# Each entry is a dict keyed by: 风险点 (risk name), 描述 (description),
# 当前状态 (current status), 验证方法 (how to verify), 代码位置 (code location),
# 风险等级 (risk level: 高/中/低) and 缓解措施 (mitigation).
# The helper functions in this module treat an entry as mitigated when its
# 当前状态 text contains the marker "✅ 已实现"; the statuses are hand-written
# here and are not derived from inspecting the referenced code.
DATA_LEAKAGE_RISKS = [
    {
        "风险点": "时间切分",
        "描述": "必须按时间顺序划分训练/测试集,避免未来信息泄露",
        "当前状态": "✅ 已实现 - split_data_by_time()",
        "验证方法": "检查 train_max_time <= test_min_time",
        "代码位置": "src/data.py:split_data_by_time()",
        "风险等级": "高",
        "缓解措施": "使用严格的时间序列切分,训练集时间必须早于测试集"
    },
    {
        "风险点": "特征缩放",
        "描述": "StandardScaler参数必须只在训练集上计算",
        "当前状态": "✅ 已实现 - fit_transform仅在训练集",
        "验证方法": "检查scaler.fit()是否在训练集上",
        "代码位置": "src/train.py:train()",
        "风险等级": "高",
        "缓解措施": "仅在训练集上调用fit_transform,测试集只调用transform"
    },
    {
        "风险点": "SMOTE过采样",
        "描述": "SMOTE必须只在训练集上进行,测试集保持原始分布",
        "当前状态": "✅ 已实现 - SMOTE仅在训练集",
        "验证方法": "检查SMOTE是否在train()函数内",
        "代码位置": "src/train.py:train()",
        "风险等级": "高",
        "缓解措施": "仅在训练集上进行SMOTE,测试集保持原始不平衡分布"
    },
    {
        "风险点": "特征选择",
        "描述": "特征选择必须基于训练集统计信息",
        "当前状态": "✅ 已实现 - 使用预定义特征列表",
        "验证方法": "检查特征列表是否在训练前确定",
        "代码位置": "src/data.py:prepare_features_labels()",
        "风险等级": "中",
        "缓解措施": "使用预定义的特征列表,不基于测试集进行特征选择"
    },
    {
        "风险点": "数据验证",
        "描述": "数据验证必须在划分数据集之前完成",
        "当前状态": "✅ 已实现 - validate_data()在split_data_by_time()之前",
        "验证方法": "检查代码执行顺序",
        "代码位置": "src/data.py:load_data()",
        "风险等级": "中",
        "缓解措施": "确保数据验证在数据划分之前完成"
    },
    {
        "风险点": "模型评估",
        "描述": "评估指标必须基于测试集计算",
        "当前状态": "✅ 已实现 - evaluate()使用测试集",
        "验证方法": "检查evaluate()函数参数",
        "代码位置": "src/train.py:evaluate()",
        "风险等级": "高",
        "缓解措施": "确保评估只使用测试集,不使用训练集"
    },
    {
        "风险点": "模型保存",
        "描述": "保存的模型不应包含测试集信息",
        "当前状态": "✅ 已实现 - 只保存模型和scaler",
        "验证方法": "检查保存的文件内容",
        "代码位置": "src/train.py:train()",
        "风险等级": "低",
        "缓解措施": "只保存模型参数和scaler参数,不保存测试集数据"
    },
    {
        "风险点": "推理服务",
        "描述": "推理时使用与训练相同的scaler参数",
        "当前状态": "✅ 已实现 - load_scaler()加载训练时的scaler",
        "验证方法": "检查推理时scaler的来源",
        "代码位置": "src/infer.py:__init__()",
        "风险等级": "高",
        "缓解措施": "推理时加载训练时保存的scaler,确保一致性"
    }
]
|
||||||
|
|
||||||
|
|
||||||
|
def print_data_leakage_checklist() -> None:
    """Print every entry of ``DATA_LEAKAGE_RISKS`` as a numbered checklist."""
    heavy_rule = "=" * 80
    light_rule = "-" * 80
    # The printed labels are identical to the dictionary keys, so one tuple
    # drives both the lookup and the label.
    detail_keys = ("描述", "当前状态", "验证方法", "代码位置", "风险等级", "缓解措施")

    print("\n" + heavy_rule)
    print("数据泄露风险检查清单")
    print(heavy_rule)

    for number, entry in enumerate(DATA_LEAKAGE_RISKS, 1):
        print(f"\n{number}. {entry['风险点']}")
        print(light_rule)
        for key in detail_keys:
            print(f"{key}: {entry[key]}")

    print("\n" + heavy_rule)
    print("数据泄露风险检查完成")
    print(heavy_rule)
|
||||||
|
|
||||||
|
|
||||||
|
def get_data_leakage_summary() -> Dict[str, Any]:
    """Aggregate ``DATA_LEAKAGE_RISKS`` into headline counts.

    Returns:
        A dict with the total number of risk points, how many carry the
        "✅ 已实现" marker in their status, the number of entries per risk
        level (高/中/低), and the mitigated share as a percentage.
    """
    total = len(DATA_LEAKAGE_RISKS)
    mitigated = sum(1 for entry in DATA_LEAKAGE_RISKS if "✅ 已实现" in entry["当前状态"])
    levels = [entry["风险等级"] for entry in DATA_LEAKAGE_RISKS]

    return {
        "总风险点数": total,
        "已缓解风险数": mitigated,
        "高风险点数": levels.count("高"),
        "中风险点数": levels.count("中"),
        "低风险点数": levels.count("低"),
        "风险缓解率": mitigated / total * 100,
    }
|
||||||
|
|
||||||
|
|
||||||
|
def print_data_leakage_summary() -> None:
    """Print the counts returned by ``get_data_leakage_summary``."""
    summary = get_data_leakage_summary()
    rule = "=" * 80

    print("\n" + rule)
    print("数据泄露风险摘要")
    print(rule)

    # The first count is preceded by a blank line; the rest follow directly.
    print(f"\n总风险点数: {summary['总风险点数']}")
    for key in ("已缓解风险数", "高风险点数", "中风险点数", "低风险点数"):
        print(f"{key}: {summary[key]}")
    print(f"风险缓解率: {summary['风险缓解率']:.1f}%")

    print("\n" + rule)
|
||||||
|
|
||||||
|
|
||||||
|
def validate_data_leakage_prevention() -> bool:
    """Report, per risk point, whether its status is marked as implemented.

    An entry counts as implemented when its ``当前状态`` text contains
    "✅ 已实现". One line is printed per entry, followed by an overall verdict.

    Returns:
        True when every entry is marked as implemented, otherwise False.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("数据泄露预防措施验证")
    print(rule)

    all_passed = True
    for entry in DATA_LEAKAGE_RISKS:
        implemented = "✅ 已实现" in entry["当前状态"]
        mark = "✓" if implemented else "✗"
        print(f"{mark} {entry['风险点']}: {entry['当前状态']}")
        all_passed = all_passed and implemented

    print("\n" + rule)
    verdict = "✓ 所有数据泄露预防措施已到位" if all_passed else "✗ 部分数据泄露预防措施未到位,请检查"
    print(verdict)
    print(rule)

    return all_passed
|
||||||
|
|
||||||
|
|
||||||
|
def get_risk_by_level(level: str) -> List[Dict[str, Any]]:
    """Return the risk entries whose ``风险等级`` equals ``level``."""
    return list(filter(lambda entry: entry["风险等级"] == level, DATA_LEAKAGE_RISKS))
|
||||||
|
|
||||||
|
|
||||||
|
def get_unmitigated_risks() -> List[Dict[str, Any]]:
    """Return the risk entries whose status lacks the "✅ 已实现" marker."""
    return list(filter(lambda entry: "✅ 已实现" not in entry["当前状态"], DATA_LEAKAGE_RISKS))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Full report: checklist, summary counts, then the per-entry verdicts.
    print_data_leakage_checklist()
    print_data_leakage_summary()
    validate_data_leakage_prevention()

    print("\n高风险点:")
    for entry in get_risk_by_level("高"):
        print(f" - {entry['风险点']}: {entry['描述']}")

    pending = get_unmitigated_risks()
    if not pending:
        print("\n✓ 所有风险点已缓解")
    else:
        print("\n未缓解的风险点:")
        for entry in pending:
            print(f" - {entry['风险点']}: {entry['当前状态']}")
|
||||||
274
src/data_validation.py
Normal file
274
src/data_validation.py
Normal file
@ -0,0 +1,274 @@
|
|||||||
|
import pandera as pa
|
||||||
|
from pandera.typing import DataFrame, Series
|
||||||
|
import polars as pl
|
||||||
|
import logging
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# Pandera schema for the raw transaction table: a non-negative Time column,
# 28 unconstrained float columns V1..V28, a non-negative Amount column and a
# Class label restricted to 0 or 1.
class TransactionSchema(pa.DataFrameModel):
    Time: Series[float] = pa.Field(ge=0, description="交易时间(秒)")
    V1: Series[float] = pa.Field(description="PCA特征V1")
    V2: Series[float] = pa.Field(description="PCA特征V2")
    V3: Series[float] = pa.Field(description="PCA特征V3")
    V4: Series[float] = pa.Field(description="PCA特征V4")
    V5: Series[float] = pa.Field(description="PCA特征V5")
    V6: Series[float] = pa.Field(description="PCA特征V6")
    V7: Series[float] = pa.Field(description="PCA特征V7")
    V8: Series[float] = pa.Field(description="PCA特征V8")
    V9: Series[float] = pa.Field(description="PCA特征V9")
    V10: Series[float] = pa.Field(description="PCA特征V10")
    V11: Series[float] = pa.Field(description="PCA特征V11")
    V12: Series[float] = pa.Field(description="PCA特征V12")
    V13: Series[float] = pa.Field(description="PCA特征V13")
    V14: Series[float] = pa.Field(description="PCA特征V14")
    V15: Series[float] = pa.Field(description="PCA特征V15")
    V16: Series[float] = pa.Field(description="PCA特征V16")
    V17: Series[float] = pa.Field(description="PCA特征V17")
    V18: Series[float] = pa.Field(description="PCA特征V18")
    V19: Series[float] = pa.Field(description="PCA特征V19")
    V20: Series[float] = pa.Field(description="PCA特征V20")
    V21: Series[float] = pa.Field(description="PCA特征V21")
    V22: Series[float] = pa.Field(description="PCA特征V22")
    V23: Series[float] = pa.Field(description="PCA特征V23")
    V24: Series[float] = pa.Field(description="PCA特征V24")
    V25: Series[float] = pa.Field(description="PCA特征V25")
    V26: Series[float] = pa.Field(description="PCA特征V26")
    V27: Series[float] = pa.Field(description="PCA特征V27")
    V28: Series[float] = pa.Field(description="PCA特征V28")
    Amount: Series[float] = pa.Field(ge=0, description="交易金额")
    Class: Series[int] = pa.Field(isin=[0, 1], description="标签(0=正常, 1=欺诈)")

    class Config:
        # strict: columns not declared above are rejected.
        strict = True
        # coerce: values are cast to the declared dtype before checks run.
        coerce = True
        # Invalid rows are reported as errors rather than dropped.
        drop_invalid_rows = False
|
||||||
|
|
||||||
|
|
||||||
|
# Schema for cleaned transaction data. It inherits every column declaration
# from TransactionSchema (so the two schemas cannot drift apart) and adds
# upper bounds on Time and Amount.
class CleanedTransactionSchema(TransactionSchema):

    # Time must not exceed 172800 seconds (48 hours).
    @pa.check("Time")
    def time_not_future(cls, series: Series[float]) -> Series[bool]:
        return series <= 172800

    # Amount must not exceed 10000.
    @pa.check("Amount")
    def amount_reasonable(cls, series: Series[float]) -> Series[bool]:
        return series <= 10000

    # Restated explicitly so the options do not depend on Config inheritance.
    class Config:
        strict = True
        coerce = True
        drop_invalid_rows = False
|
||||||
|
|
||||||
|
|
||||||
|
def validate_dataframe(df: pl.DataFrame, schema: pa.DataFrameModel = TransactionSchema) -> tuple[bool, Optional[str]]:
    """Validate a Polars frame against a Pandera schema.

    The frame is converted to pandas before ``schema.validate`` is called.
    Failures are logged and reported through the return value instead of
    being raised.

    Args:
        df: Polars DataFrame to validate.
        schema: Pandera schema to validate against.

    Returns:
        ``(True, None)`` on success, otherwise ``(False, message)``.
    """
    failure: Optional[str] = None
    try:
        schema.validate(df.to_pandas())
        logger.info("数据验证通过")
    except pa.errors.SchemaError as exc:
        failure = f"数据验证失败: {exc}"
    except Exception as exc:
        # Anything else (e.g. the pandas conversion failing) is reported too.
        failure = f"验证过程中发生错误: {exc}"

    if failure is None:
        return True, None

    logger.error(failure)
    return False, failure
|
||||||
|
|
||||||
|
|
||||||
|
def validate_data_integrity(df: pl.DataFrame) -> dict:
    """Run best-effort integrity checks on a transaction frame.

    Checks performed: per-column null counts, presence of the expected
    columns, value ranges of ``Time`` and ``Amount``, and the distribution of
    the ``Class`` label.

    Args:
        df: Polars DataFrame to inspect.

    Returns:
        A dict with the keys 总记录数, 缺失值检查, 数据类型检查, 数值范围检查,
        标签分布检查 and 通过. 通过 is False when any column has nulls, an
        expected column is missing, a range check is 异常, no fraud samples
        are present, or an exception occurred (its text is stored under 错误).
    """
    results = {
        "总记录数": df.height,
        "缺失值检查": {},
        "数据类型检查": {},
        "数值范围检查": {},
        "标签分布检查": {},
        "通过": True
    }

    try:
        # to_dict(as_series=False) maps each column to a one-element list,
        # because null_count() returns a single-row frame.
        null_counts = df.null_count().to_dict(as_series=False)
        for col, counts in null_counts.items():
            count = counts[0]
            if count > 0:
                results["缺失值检查"][col] = f"发现{count}个缺失值"
                results["通过"] = False
            else:
                results["缺失值检查"][col] = "无缺失值"

        expected_columns = ['Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9',
                            'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19',
                            'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount', 'Class']

        missing_columns = set(expected_columns) - set(df.columns)
        extra_columns = set(df.columns) - set(expected_columns)

        if missing_columns:
            results["数据类型检查"]["缺失列"] = list(missing_columns)
            results["通过"] = False
        if extra_columns:
            # Extra columns are reported but do not fail the check.
            results["数据类型检查"]["多余列"] = list(extra_columns)

        # (column, inclusive upper bound); the lower bound is 0 for both.
        for col, upper in (("Time", 172800), ("Amount", 10000)):
            low = df[col].min()
            high = df[col].max()
            in_range = low >= 0 and high <= upper
            results["数值范围检查"][col] = {
                "最小值": float(low),
                "最大值": float(high),
                "状态": "正常" if in_range else "异常"
            }
            if not in_range:
                # Keep the overall verdict consistent with the reported status.
                results["通过"] = False

        # The aggregated frame becomes {"Class": [...], "count": [...]};
        # pair the two lists up to get label -> count.
        class_counts = df.group_by("Class").agg(pl.len().alias("count")).to_dict(as_series=False)
        class_dist = dict(zip(class_counts["Class"], class_counts["count"]))

        normal_count = class_dist.get(0, 0)
        fraud_count = class_dist.get(1, 0)

        results["标签分布检查"] = {
            "正常交易数": normal_count,
            "欺诈交易数": fraud_count,
            "不平衡比例": normal_count / fraud_count if fraud_count > 0 else float('inf')
        }

        if fraud_count == 0:
            results["标签分布检查"]["警告"] = "未发现欺诈交易样本"
            results["通过"] = False

    except Exception as e:
        # Deliberately best-effort: report the failure in the result instead
        # of raising, but log it so it is not silently lost.
        logger.exception("数据完整性检查过程中发生错误")
        results["错误"] = str(e)
        results["通过"] = False

    return results
|
||||||
|
|
||||||
|
|
||||||
|
def print_validation_results(results: dict) -> None:
    """Print an integrity-check result dict as a sectioned report.

    Args:
        results: Dict with the keys 总记录数, 缺失值检查, 数值范围检查,
            标签分布检查 and 通过; 数据类型检查 is printed only when present.
    """
    rule = "=" * 60

    def show_pairs(title: str, mapping: dict) -> None:
        # One section: heading line, then one "key: value" line per item.
        print(title)
        for name, value in mapping.items():
            print(f" {name}: {value}")

    print("\n" + rule)
    print("数据验证结果")
    print(rule)

    print(f"\n总记录数: {results['总记录数']}")

    show_pairs("\n缺失值检查:", results["缺失值检查"])

    if "数据类型检查" in results:
        show_pairs("\n数据类型检查:", results["数据类型检查"])

    print("\n数值范围检查:")
    for col, info in results["数值范围检查"].items():
        print(f" {col}:")
        for label in ("最小值", "最大值", "状态"):
            print(f" {label}: {info[label]}")

    show_pairs("\n标签分布检查:", results["标签分布检查"])

    verdict = "✓ 通过" if results["通过"] else "✗ 未通过"
    print(f"\n验证结果: {verdict}")
    print(rule)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke test: build a three-row frame with every expected column,
    # run the schema validation and the integrity checks, and print the results.
    import polars as pl

    test_data = pl.DataFrame({
        'Time': [0.0, 1.0, 2.0],
        'V1': [-1.36, 0.96, 1.89],
        'V2': [0.96, -1.19, -1.94],
        'V3': [1.89, -1.94, 1.60],
        'V4': [-1.19, 1.60, 1.37],
        'V5': [-1.94, 1.37, -0.34],
        'V6': [1.60, -0.34, -0.47],
        'V7': [1.37, -0.47, 1.42],
        'V8': [-0.34, 1.42, 3.00],
        'V9': [-0.47, 3.00, -0.58],
        'V10': [1.42, -0.58, 1.18],
        'V11': [3.00, 1.18, 1.67],
        'V12': [-0.58, 1.67, -2.89],
        'V13': [1.18, -2.89, -0.60],
        'V14': [1.67, -0.60, -1.14],
        'V15': [-2.89, -1.14, -0.21],
        'V16': [-0.60, -0.21, 0.16],
        'V17': [-1.14, 0.16, 0.30],
        'V18': [-0.21, 0.30, -0.64],
        'V19': [0.16, -0.64, -0.21],
        'V20': [0.30, -0.21, 0.46],
        'V21': [-0.64, 0.46, 0.10],
        'V22': [-0.21, 0.10, -0.33],
        'V23': [0.46, -0.33, 0.13],
        'V24': [0.10, 0.13, -0.19],
        'V25': [-0.33, -0.19, -0.26],
        'V26': [0.13, -0.26, 100.0],
        'V27': [-0.19, 100.0, 50.0],
        'V28': [-0.26, 50.0, 25.0],
        'Amount': [100.0, 50.0, 25.0],
        'Class': [0, 1, 0]
    })

    # Schema validation against the default TransactionSchema.
    passed, error = validate_dataframe(test_data)
    print(f"Schema验证: {'✓ 通过' if passed else '✗ 未通过'}")
    if error:
        print(f"错误: {error}")

    # Integrity checks, printed as a sectioned report.
    integrity_results = validate_data_integrity(test_data)
    print_validation_results(integrity_results)
|
||||||
@ -38,9 +38,7 @@ class FraudDetectionInference:
|
|||||||
transaction_array = transaction_array.reshape(1, -1)
|
transaction_array = transaction_array.reshape(1, -1)
|
||||||
|
|
||||||
prediction = self.trainer.predict(transaction_array)
|
prediction = self.trainer.predict(transaction_array)
|
||||||
probability = self.trainer.predict_proba(transaction_array)
|
fraud_prob = float(self.trainer.predict_proba(transaction_array))
|
||||||
|
|
||||||
fraud_prob = float(probability[0])
|
|
||||||
normal_prob = float(1 - fraud_prob)
|
normal_prob = float(1 - fraud_prob)
|
||||||
|
|
||||||
max_prob = max(fraud_prob, normal_prob)
|
max_prob = max(fraud_prob, normal_prob)
|
||||||
|
|||||||
178
src/llm_integration.py
Normal file
178
src/llm_integration.py
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
from openai import OpenAI
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
from typing import Optional, Dict, Any, List
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.INFO)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DeepSeekLLM:
|
||||||
|
def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
|
||||||
|
self.api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
|
||||||
|
self.base_url = base_url or os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1")
|
||||||
|
|
||||||
|
if not self.api_key:
|
||||||
|
logger.warning("未设置 DEEPSEEK_API_KEY,LLM功能将不可用")
|
||||||
|
self.client = None
|
||||||
|
else:
|
||||||
|
self.client = OpenAI(
|
||||||
|
api_key=self.api_key,
|
||||||
|
base_url=self.base_url
|
||||||
|
)
|
||||||
|
logger.info("DeepSeek LLM 初始化成功")
|
||||||
|
|
||||||
|
def is_available(self) -> bool:
|
||||||
|
return self.client is not None
|
||||||
|
|
||||||
|
def explain_prediction(self, transaction: List[float], prediction: Dict[str, Any]) -> str:
|
||||||
|
if not self.is_available():
|
||||||
|
return "LLM服务不可用,请配置DEEPSEEK_API_KEY"
|
||||||
|
|
||||||
|
feature_names = [
|
||||||
|
'Time', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6', 'V7', 'V8', 'V9',
|
||||||
|
'V10', 'V11', 'V12', 'V13', 'V14', 'V15', 'V16', 'V17', 'V18', 'V19',
|
||||||
|
'V20', 'V21', 'V22', 'V23', 'V24', 'V25', 'V26', 'V27', 'V28', 'Amount'
|
||||||
|
]
|
||||||
|
|
||||||
|
transaction_dict = dict(zip(feature_names, transaction))
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
你是一个专业的信用卡欺诈检测专家。请分析以下交易数据并给出解释。
|
||||||
|
|
||||||
|
交易数据:
|
||||||
|
{transaction_dict}
|
||||||
|
|
||||||
|
预测结果:
|
||||||
|
- 预测类别: {'欺诈' if prediction['predicted_class'] == 1 else '正常'}
|
||||||
|
- 欺诈概率: {prediction['fraud_probability']:.4f}
|
||||||
|
- 正常概率: {prediction['normal_probability']:.4f}
|
||||||
|
- 置信度: {prediction['confidence']}
|
||||||
|
|
||||||
|
请用中文回答以下问题:
|
||||||
|
1. 为什么这个交易被预测为{'欺诈' if prediction['predicted_class'] == 1 else '正常'}?
|
||||||
|
2. 哪些特征对预测结果影响最大?
|
||||||
|
3. 请提供3条具体的行动建议。
|
||||||
|
|
||||||
|
请保持回答简洁、专业,控制在200字以内。
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model="deepseek-chat",
|
||||||
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=500
|
||||||
|
)
|
||||||
|
return response.choices[0].message.content
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"LLM调用失败: {e}")
|
||||||
|
return f"LLM调用失败: {str(e)}"
|
||||||
|
|
||||||
|
def generate_action_suggestions(self, transaction: List[float], prediction: Dict[str, Any], feature_analysis: Dict[str, Any]) -> List[str]:
|
||||||
|
if not self.is_available():
|
||||||
|
return ["LLM服务不可用,请配置DEEPSEEK_API_KEY"]
|
||||||
|
|
||||||
|
prompt = f"""
|
||||||
|
你是一个专业的信用卡欺诈检测专家。基于以下信息,请生成3条具体的行动建议。
|
||||||
|
|
||||||
|
交易数据:
|
||||||
|
- 交易金额: {transaction[-1]}
|
||||||
|
- 交易时间: {transaction[0]}
|
||||||
|
|
||||||
|
预测结果:
|
||||||
|
- 预测类别: {'欺诈' if prediction['predicted_class'] == 1 else '正常'}
|
||||||
|
- 欺诈概率: {prediction['fraud_probability']:.4f}
|
||||||
|
- 置信度: {prediction['confidence']}
|
||||||
|
|
||||||
|
特征分析:
|
||||||
|
{feature_analysis}
|
||||||
|
|
||||||
|
请用中文生成3条具体的行动建议,每条建议应该包含:
|
||||||
|
1. 行动内容
|
||||||
|
2. 优先级(紧急/高/中/低)
|
||||||
|
3. 执行原因
|
||||||
|
|
||||||
|
请保持回答简洁、专业。
|
||||||
|
"""
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.client.chat.completions.create(
|
||||||
|
model="deepseek-chat",
|
||||||
|
messages=[{"role": "user", "content": prompt}],
|
||||||
|
temperature=0.7,
|
||||||
|
max_tokens=400
|
||||||
|
)
|
||||||
|
content = response.choices[0].message.content
|
||||||
|
|
||||||
|
suggestions = []
|
||||||
|
lines = content.split('\n')
|
||||||
|
for line in lines:
|
||||||
|
line = line.strip()
|
||||||
|
if line and not line.startswith(('1.', '2.', '3.', '-')):
|
||||||
|
suggestions.append(line)
|
||||||
|
|
||||||
|
return suggestions[:3] if suggestions else [content]
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"LLM调用失败: {e}")
|
||||||
|
return [f"LLM调用失败: {str(e)}"]
|
||||||
|
|
||||||
|
def analyze_transaction_context(
    self,
    transaction: List[float],
    historical_data: Optional[List[List[float]]] = None,
) -> str:
    """Ask the LLM for a short risk assessment of one transaction.

    Args:
        transaction: Feature vector; the prompt reports ``transaction[-1]``
            as the amount and ``transaction[0]`` as the time.
        historical_data: Accepted for interface compatibility; this
            implementation does not read it.

    Returns:
        The model's reply text. When the service is unavailable or the call
        raises, a human-readable message is returned instead.
    """
    if not self.is_available():
        return "LLM服务不可用,请配置DEEPSEEK_API_KEY"

    prompt = f"""
你是一个专业的信用卡欺诈检测专家。请分析以下交易数据。

交易数据:
- 交易金额: {transaction[-1]}
- 交易时间: {transaction[0]}

请用中文分析:
1. 这个交易金额是否异常?
2. 这个交易时间是否异常?
3. 基于这些信息,这个交易的风险等级如何?

请保持回答简洁、专业,控制在100字以内。
"""

    try:
        request = {
            "model": "deepseek-chat",
            "messages": [{"role": "user", "content": prompt}],
            "temperature": 0.7,
            "max_tokens": 300,
        }
        completion = self.client.chat.completions.create(**request)
        first_choice = completion.choices[0]
        return first_choice.message.content
    except Exception as e:
        # Best-effort contract: report the failure as text instead of raising.
        logger.error(f"LLM调用失败: {e}")
        return f"LLM调用失败: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
|
def create_llm(api_key: Optional[str] = None) -> DeepSeekLLM:
    """Build a ``DeepSeekLLM`` instance.

    Args:
        api_key: Forwarded unchanged to the ``DeepSeekLLM`` constructor.
            How a ``None`` key is resolved is decided there.

    Returns:
        The newly constructed ``DeepSeekLLM``.
    """
    llm = DeepSeekLLM(api_key=api_key)
    return llm
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Manual smoke check: request one explanation for a fixed sample
    # transaction, or report that the LLM service is not configured.
    llm = create_llm()

    if not llm.is_available():
        print("LLM服务不可用,请配置DEEPSEEK_API_KEY")
    else:
        test_transaction = [
            0.0, -1.36, 0.96, 1.89, -1.19, -1.94, 1.60, 1.37, -0.34, -0.47,
            1.42, 3.00, -0.58, 1.18, 1.67, -2.89, -0.60, -1.14, -0.21, 0.16,
            0.30, -0.64, -0.21, 0.46, 0.10, -0.33, 0.13, -0.19, -0.26, 100.0,
        ]
        test_prediction = {
            "predicted_class": 0,
            "fraud_probability": 0.05,
            "normal_probability": 0.95,
            "confidence": "高",
        }

        # Fetch the explanation first so the header is printed only if the call returns.
        explanation = llm.explain_prediction(test_transaction, test_prediction)
        print("=== LLM 解释 ===")
        print(explanation)
|
||||||
105
src/test_fraud_detection.py
Normal file
105
src/test_fraud_detection.py
Normal file
@ -0,0 +1,105 @@
|
|||||||
|
import polars as pl
|
||||||
|
import numpy as np
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent))
|
||||||
|
from agent_app import create_agent
|
||||||
|
|
||||||
|
def test_fraud_detection():
    """Run three sample transactions through the agent and print a report.

    For each sample the report shows the input summary, the predicted class
    and probabilities, up to five key features, the overall explanation, the
    action plan, and whether the predicted class equals the expected label.
    A mismatch is only printed; this function never raises on a wrong
    prediction.
    """
    heavy_rule = "=" * 80
    light_rule = "-" * 80

    def banner(title, rule=heavy_rule, gap="\n"):
        # Title framed by two rule lines; `gap` is the text printed before the first rule.
        print(f"{gap}{rule}")
        print(title)
        print(rule)

    banner("欺诈检测系统测试", gap="")

    agent = create_agent(model_dir="models", model_name="random_forest")

    test_data = [
        {
            "description": "正常交易(小额)",
            "transaction": [
                0.0, -1.36, 0.96, 1.89, -1.19, -1.94, 1.60, 1.37, -0.34, -0.47,
                1.42, 3.00, -0.58, 1.18, 1.67, -2.89, -0.60, -1.14, -0.21, 0.16,
                0.30, -0.64, -0.21, 0.46, 0.10, -0.33, 0.13, -0.19, -0.26, 100.0,
            ],
            "expected": "正常",
        },
        {
            "description": "欺诈交易(真实样本)",
            "transaction": [
                406.0, -2.312227, 1.951992, -1.609851, 3.997906, -0.522188, -1.426545, -2.537387, 1.391657, -2.770089,
                -2.772272, 3.202033, -2.899907, -0.595221, -4.289254, 0.389724, -1.140747, -2.830056, -0.016822, 0.416956,
                0.126911, 0.517232, -0.035049, -0.465211, 0.320198, 0.044519, 0.177840, 0.261145, -0.143276, 0.0,
            ],
            "expected": "欺诈",
        },
        {
            "description": "正常交易(中等金额)",
            "transaction": [
                240.0, 0.45, -0.23, 1.20, -0.56, 0.89, -1.23, 0.34, -0.67, 1.12,
                -0.45, 0.78, -1.34, 0.45, -0.78, 1.23, -0.45, 0.67, -0.89, 0.34,
                -0.56, 0.78, -0.45, 0.67, -0.34, 0.45, -0.23, 0.56, 0.25, 50.0,
            ],
            "expected": "正常",
        },
    ]

    for case_number, case in enumerate(test_data, 1):
        transaction = case['transaction']
        expected = case['expected']

        banner(f"测试用例 {case_number}: {case['description']}")

        print(f"\n交易金额: ${transaction[-1]:,.2f}")
        print(f"交易时间: {transaction[0]:.1f} 秒")
        print(f"预期结果: {expected}")

        banner("开始处理...", rule=light_rule)

        result = agent.process_transaction(transaction)
        evaluation = result.evaluation

        banner("预测结果")

        print(f"\n预测类别: {evaluation.class_name}")
        print(f"欺诈概率: {evaluation.fraud_probability:.4f}")
        print(f"正常概率: {evaluation.normal_probability:.4f}")
        print(f"置信度: {evaluation.confidence}")

        banner("关键特征")

        # Only the first five key features are shown.
        for feature in result.explanation.key_features[:5]:
            print(f"\n{feature.feature_name}:")
            print(f" 特征值: {feature.value:.4f}")
            print(f" 重要性: {feature.importance:.4f}")
            print(f" 贡献度: {feature.contribution:.4f}")
            print(f" 影响方向: {feature.impact}")

        banner("总体解释")
        print(f"\n{result.explanation.overall_explanation}")

        banner("行动计划")

        for action in result.action_plan.actions:
            print(f"\n[{action.priority}] {action.action}")
            print(f" 原因: {action.reason}")

        banner(f"预测结果: {evaluation.class_name} | 预期结果: {expected}")

        verdict = "✓ 预测正确!" if evaluation.class_name == expected else "✗ 预测错误!"
        print(verdict)

    banner("测试完成")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Run the printed walkthrough only when this file is executed directly.
    test_fraud_detection()
|
||||||
@ -6,6 +6,7 @@ from sklearn.metrics import (
|
|||||||
precision_recall_curve, auc, confusion_matrix
|
precision_recall_curve, auc, confusion_matrix
|
||||||
)
|
)
|
||||||
from imblearn.over_sampling import SMOTE
|
from imblearn.over_sampling import SMOTE
|
||||||
|
from lightgbm import LGBMClassifier
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import logging
|
import logging
|
||||||
import joblib
|
import joblib
|
||||||
@ -34,6 +35,12 @@ class CreditCardFraudModelTrainer:
|
|||||||
random_state=42,
|
random_state=42,
|
||||||
class_weight="balanced",
|
class_weight="balanced",
|
||||||
n_estimators=100
|
n_estimators=100
|
||||||
|
),
|
||||||
|
"lightgbm": LGBMClassifier(
|
||||||
|
random_state=42,
|
||||||
|
class_weight="balanced",
|
||||||
|
n_estimators=100,
|
||||||
|
verbose=-1
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user