G09-BankMarketing/smart_agent.py
2026-01-16 19:22:13 +08:00

237 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import pandas as pd
import joblib
import json
import re
# ==========================================
# 0. 基础工具类定义
# ==========================================
class BaseTool:
    """Minimal base class shared by all agent tools.

    Each tool carries a human-readable ``name`` and ``description``;
    concrete subclasses override :meth:`run` with their actual behaviour.
    """

    def __init__(self, name, description):
        # Identifying metadata, e.g. printed by the orchestrator when the tool is invoked.
        self.name, self.description = name, description

    def run(self, *args, **kwargs):
        """Execute the tool. Subclasses must override this."""
        raise NotImplementedError
# ==========================================
# 1. 工具实现
# ==========================================
class MLPredictionTool(BaseTool):
    """Tool 1: machine-learning prediction tool.

    Loads a pre-trained model bundle, predicts a customer's conversion
    probability and reports a simple attribution based on the model's
    global feature importances.
    """

    def __init__(self, artifact_path='model_artifacts.pkl'):
        """Load the model bundle stored at *artifact_path*.

        The bundle is read as a mapping with the keys ``'model'``,
        ``'encoders'`` and ``'feature_meta'``; ``feature_meta['all_cols']``
        gives the model's input column order.
        """
        super().__init__("ML_Predictor", "输入客户特征,输出购买概率和关键影响因素")
        print(f"[{self.name}] 正在加载模型资产...")
        # NOTE(review): joblib.load unpickles the file, which can execute
        # arbitrary code -- only point this at trusted artifacts.
        self.artifacts = joblib.load(artifact_path)
        self.model = self.artifacts['model']
        self.encoders = self.artifacts['encoders']
        self.feature_meta = self.artifacts['feature_meta']
        # Model-wide feature importances, highest first; run() reports the
        # head of this ranking as the customer's "key factors".
        self.feature_importances = pd.Series(
            self.model.feature_importances_,
            index=self.feature_meta['all_cols'],
        ).sort_values(ascending=False)

    def preprocess(self, customer_data):
        """Turn one raw customer record into a single-row model input frame.

        Args:
            customer_data: mapping of feature name -> raw value for one customer.

        Returns:
            A one-row DataFrame with exactly the columns of
            ``feature_meta['all_cols']``, in that order.
        """
        df = pd.DataFrame([customer_data])
        # 'duration' is never fed to the model (presumably to avoid leakage,
        # since call duration is only known after the call -- confirm
        # against the training pipeline).
        df = df.drop(columns='duration', errors='ignore')
        for col, encoder in self.encoders.items():
            if col not in df.columns:
                continue
            try:
                df[col] = encoder.transform(df[col])
            except (ValueError, TypeError):
                # Value the encoder cannot handle (e.g. an unseen category):
                # fall back to code 0 instead of failing the whole request.
                df[col] = 0
        # Add any missing model columns as 0 and enforce the expected order.
        return df.reindex(columns=self.feature_meta['all_cols'], fill_value=0)

    def run(self, customer_data, top_k=5):
        """Predict the conversion probability and list the key factors.

        Args:
            customer_data: raw feature mapping for one customer.
            top_k: how many of the globally most important features to
                report (defaults to 5).

        Returns:
            dict with ``probability`` (rounded to 4 decimals),
            ``risk_level`` and ``key_factors`` (feature -> the customer's raw
            value, or ``'N/A'`` when the record lacks that feature).
        """
        X = self.preprocess(customer_data)
        # Column 1 is taken as the positive-class probability (assumes a
        # binary classifier with the positive class second -- confirm).
        # float() turns the NumPy scalar into a plain Python float.
        prob = float(self.model.predict_proba(X)[0][1])

        # Attribution: the customer's values for the top-k globally important features.
        top_features = self.feature_importances.head(top_k).index.tolist()
        attribution = {feat: customer_data.get(feat, 'N/A') for feat in top_features}

        # Note the inverted scale: a LOW probability yields "High" risk
        # (presumably the risk that the customer does not convert).
        if prob < 0.3:
            risk_level = "High"
        elif prob < 0.7:
            risk_level = "Medium"
        else:
            risk_level = "Low"

        return {
            "probability": round(prob, 4),
            "risk_level": risk_level,
            "key_factors": attribution,
        }
class StrategyRetrievalTool(BaseTool):
    """Tool 2: strategy retrieval tool.

    Maps a conversion probability onto an intent tier and returns that
    tier's marketing playbook (segment, channel, product, script template).
    """

    def __init__(self):
        super().__init__("Strategy_Retriever", "根据意向分检索营销策略")
        # In-memory stand-in for a vector store / rule base, keyed by intent tier.
        self.knowledge_base = {
            "High_Intent": dict(
                segment="VIP_Growth",
                channel="Personal_Call",
                product="大额存单/结构性存款",
                script_template="尊贵的{name},鉴于您良好的{key_factor},我们要为您推荐专属...",
            ),
            "Medium_Intent": dict(
                segment="Potential_Saver",
                channel="SMS_Web",
                product="灵活理财/定投",
                script_template="你好,发现您对{key_factor}感兴趣,这里有一份理财攻略...",
            ),
            "Low_Intent": dict(
                segment="General_Mass",
                channel="Email",
                product="货币基金/新人礼包",
                script_template="本月财经快讯:如何打理您的零钱...",
            ),
        }

    def run(self, probability):
        """Return the playbook matching *probability*.

        Tiers are checked top-down with a strict ``>``: above 0.7 is
        High_Intent, above 0.4 is Medium_Intent, anything else Low_Intent.
        The stored dict itself is returned, not a copy.
        """
        for threshold, tier in ((0.7, "High_Intent"), (0.4, "Medium_Intent")):
            if probability > threshold:
                return self.knowledge_base[tier]
        return self.knowledge_base["Low_Intent"]
# ==========================================
# 2. Agent 定义 (Orchestrator)
# ==========================================
class SalesAgent:
    """Orchestrator that chains the ML predictor, the strategy retriever and
    a (mocked) LLM step into one structured marketing recommendation."""

    def __init__(self):
        # Both tools are built eagerly; MLPredictionTool loads its model
        # artifacts from disk during construction.
        self.tools = {
            "predictor": MLPredictionTool(),
            "retriever": StrategyRetrievalTool(),
        }

    @staticmethod
    def _json_default(obj):
        """``json.dumps`` fallback that converts NumPy scalars to Python values.

        Customer records built with ``DataFrame.iloc[i].to_dict()`` may hold
        NumPy scalars (e.g. ``numpy.int64``, depending on the pandas version),
        which the standard JSON encoder rejects.

        Raises:
            TypeError: for any other non-serializable object, as ``json`` expects.
        """
        import numpy as np  # local import: only needed on this fallback path

        if isinstance(obj, np.generic):
            return obj.item()
        raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")

    def mock_llm_inference(self, prompt):
        """Simulate the LLM generation step.

        A real deployment would call a chat-completion API here (e.g.
        ``openai.ChatCompletion.create(model="gpt-4", messages=...)``). The
        mock extracts the JSON context placed between the 【Context】 and
        【Instruction】 markers and fills in the retrieved script template.

        Args:
            prompt: prompt text in the format built by :meth:`process_request`.

        Returns:
            A JSON string with ``thought_process`` and ``final_decision``, or
            a JSON ``{"error": ...}`` object if the prompt cannot be handled.
        """
        try:
            match = re.search(r"【Context】(.*?)【Instruction】", prompt, re.S)
            if match is None:
                raise ValueError("prompt has no 【Context】...【Instruction】 section")
            context = json.loads(match.group(1))
            pred_result = context['prediction']
            strategy_result = context['strategy']
            probability = pred_result['probability']
            key_factors = pred_result['key_factors']

            # Fill the template with the name of the first key factor.
            script = strategy_result['script_template'].format(
                name="客户",
                key_factor=list(key_factors)[0],
            )
            response = {
                "thought_process": f"模型预测概率为 {probability},属于 {strategy_result['segment']} 客群。已检索到对应策略,建议通过 {strategy_result['channel']} 触达。",
                "final_decision": {
                    "action": strategy_result['channel'],
                    "product_recommendation": strategy_result['product'],
                    "personalized_script": script,
                    # ".2%" renders e.g. 0.57 as "57.00%" rather than the raw
                    # float product "56.99999999999999%".
                    "attribution_explanation": f"预测模型显示该客户成交概率为 {probability:.2%},主要受 {json.dumps(key_factors, ensure_ascii=False)} 等因素影响。",
                },
            }
            return json.dumps(response, ensure_ascii=False, indent=2)
        except Exception as e:  # mock boundary: report any failure as JSON
            return json.dumps({"error": f"LLM Mock Failed: {e}"}, ensure_ascii=False)

    def process_request(self, customer_data):
        """Run the predict -> retrieve -> generate pipeline for one customer.

        Args:
            customer_data: raw feature mapping for one customer.

        Returns:
            The JSON string produced by :meth:`mock_llm_inference`.
        """
        predictor = self.tools['predictor']
        retriever = self.tools['retriever']
        print(f"\n[Agent] 收到新请求: {customer_data.get('job', 'Unknown')} | {customer_data.get('age')}")

        # Step 1: ML prediction.
        print(f"[Agent] 调用工具: {predictor.name} ...")
        pred_result = predictor.run(customer_data)
        print(f" >>> 预测结果: 概率={pred_result['probability']}, 关键因素={list(pred_result['key_factors'].keys())}")

        # Step 2: strategy retrieval based on the predicted probability.
        print(f"[Agent] 调用工具: {retriever.name} ...")
        strategy_result = retriever.run(pred_result['probability'])
        print(f" >>> 检索结果: 渠道={strategy_result['channel']}, 产品={strategy_result['product']}")

        # Step 3: let the (mock) LLM combine everything.
        print("[Agent] 请求 LLM 进行最终决策与生成...")
        context = {
            "customer_raw": customer_data,
            "prediction": pred_result,
            "strategy": strategy_result,
        }
        # default= unwraps NumPy scalars that pandas rows may contain.
        context_json = json.dumps(context, ensure_ascii=False, default=self._json_default)
        prompt = "\n".join([
            "你是一个智能营销助手。请根据以下上下文信息,生成结构化的营销建议。",
            "【Context】",
            context_json,
            "【Instruction】",
            "1. 解释模型预测结果。",
            "2. 结合策略库,生成具体的话术。",
            "3. 输出 JSON 格式。",
        ])
        return self.mock_llm_inference(prompt)
# ==========================================
# 3. 主程序入口
# ==========================================
if __name__ == "__main__":
    # 1. Load the raw customer table.
    df_raw = pd.read_csv('bank.csv')

    # 2. Build the agent (this loads the model artifacts).
    agent = SalesAgent()

    # 3. Demo scenarios.
    banner = "=" * 60
    print("\n" + banner)
    print("场景演示: Agent 协调多个工具完成决策")
    print(banner)

    # Scenario A: a real dataset row, assumed to be a low-intent customer.
    # 'deposit' (presumably the target label) is removed before prediction.
    customer_a = df_raw.iloc[1].to_dict()
    customer_a.pop('deposit', None)
    result_a = agent.process_request(customer_a)
    print("\n[Agent Final Output]")
    print(result_a)
    print("-" * 60)

    # Scenario B: a hand-crafted, presumably high-intent variant of A.
    # 'duration' is supplied only to mimic raw input; the predictor drops it.
    customer_b = {
        **customer_a,
        'poutcome': 'success',
        'duration': 1000,
        'contact': 'cellular',
        'month': 'oct',
    }
    result_b = agent.process_request(customer_b)
    print("\n[Agent Final Output]")
    print(result_b)