From 8db338c6fb2f868f374df2c9a781cad9d38ea61a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9E=97=E5=98=89=E7=83=A8?= Date: Fri, 16 Jan 2026 19:28:30 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=8A=E4=BC=A0=E6=96=87=E4=BB=B6=E8=87=B3?= =?UTF-8?q?=20/?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- agent_app.cpython-312.pyc | Bin 0 -> 7843 bytes agent_app.py | 184 ++++++++++++++++++++++++++++++++++++++ data.py | 92 +++++++++++++++++++ streamlit_app.py | 152 +++++++++++++++++++++++++++++++ train.py | 90 +++++++++++++++++++ 5 files changed, 518 insertions(+) create mode 100644 agent_app.cpython-312.pyc create mode 100644 agent_app.py create mode 100644 data.py create mode 100644 streamlit_app.py create mode 100644 train.py diff --git a/agent_app.cpython-312.pyc b/agent_app.cpython-312.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df358bca87cf1423eb39620539c0b80735fa1e2c GIT binary patch literal 7843 zcmc&YT~t%onkUK0e@FrZ34a;|l%Rr$Q>|M6pjz8o+G=}S5u1rQ2O7)|Jtt`kgWhPV z2CIVCs?b73TWzdVX{#3$vF*AK)3xR$Y0@zsyVl&hniImyOnPp#CbKP!7SHz5f0SgD zydoBQA_A{1O~YneT6YC#M5S3TCbkijw1=Q%wCIW?HJXb@8W3;WVAFrFWriWPI0H$TsDK|n8loOi3lWw)s4yvikTv=shZ63RmHnXPXrphV@{TA(TSu=en&9c@ufWmqy26D1iH{-N= zY}rIeQ=JFjV?LmXW)Vd+i(dhLXqLb!NJ>meC~{Pqxos&eqve!rJuJj<`KTgIce8?4 z0!*2Ksj^@y3{z)dnk<+a!*Vh(c!g1Y=G8P9nwx>WI^X; zL3J2fn1LFz@ai$FC<7~|N~qGTTgb)xmSy1cv)1O}wF@$^V#+kSFmsxGymCHiokK0KzTWPDu%h0SBNhN>vLfk(- z-FuV2b1-q|6#w4Ec;Gz$@%uA*0GxRTDz7FpjlX$)=1$LL2`6r&IkAK0)D+Fy7<-4u z?sjn!U@or=CMOOai%ooZS;T0e!%5oRdj%YqKjc4~oId&o{_qV>?zA#?kJW)8@zaBe zLtjk?-sf*!;*>PywOJ9yF4gN*oiy9+a+@o>1%N9)aEr&$C7(aCy zxLMv|^>}E;C2T0C=vJ%O;lZdA*Dt~OVpD@N7lt|cUaQ0EveAzsi%s>#{of?6kH<&v zaPl|YUe@kvn=@tN^VqFhsVStxZFLFA4}LKH^<@0aP5yQtC%3s>9;?kWgH+SZ0;s%5 zJ;aB6eE*4rF!St1PC{8bz36fu0Sx(LlT)132`Jvc?-CyxjgOv>U!UZU22zWmdwR?) zX3%>GCso+2&JL@+&E;JNV~O$M=}VKbiKDPkSi*mL8J3u;me<85KS`Vq#D_i?0A^0w z0d%koQnkWS@lJg2@$AhwMF&H_WrrKV6bq;Ynx-dD#XlZ`32v{)23*x-F60yz3sN^0 z3#YPJoNmhNK>r$xuuTe+g8k!JUK+#N0PpLmbL(o!xgu%jI9MY z+~u)YJ34AQIvEAv!xbWs6~VK=N8BS9-P0Kc!~^V@Hlo|;Yw+iX$&D~HVvOjjg>Ln| zeB;RLk^EJ@ZT_8Ma@Dvlx)cM21B&Yl0a6Z8cGj^LUX zBKjA7jsEg5`QqFqK$B=-Z?I-#M7K$xV$;3cdBO74k=!-Dr~UO|a?L$M(ZEZ=Wm_YL zZN6vxwlKL3&v+@4yF-|KTVrUd*G{Q)5@!E)C9qLEObxd{oP9ASk0suZs{J)tL zFdK+nz#wHAmMoQ-gH@&ut4zH)ms2#-HaiQN$cxx&GkSazQj$k*O@G*fC7_2tc7{Lv z_4F5KuqZLerNQz$bD<}Gx;+GChpEzQwTCu#&%&1`or(*2vJ_~DeGqiUp z3*P+W1$!)5GV@j2xVp<@##K>hPa&nG_X;r)YVPRU@YwmLvdXqiQTf+rf$ath#3`8_6AcGx)@d z5#vj~XZxjL@}+yF3xhSgBc&}t($Ei_-h%nKC1Thk@b#X1Me~D8w?v8>KmuRxFYnzI zCL2IZgG+Wr3Yw9eM#yGiYqaYlypE27%7bU2n&ov$MTvUAfuh7c1USzUT1t_$td#`U zm23v*iB_cga_B2la23E6&{uaV&1z2jyp`EUd%!qhPPC=mWV{>tm!MrYL3Nhs7IhOA z_!A06Ko40ZT)K(&w2PM>g7!-49~@1>ENU`KI2CAc+F=3H`Ua=8G9G&?FagV{E$~1T zQtz|_OfSjOj#g$3tY#4nCKH1m1DOK5%kHsQx(Xkqx+V=kjxUS!wvUJwSNO?ex@b`u zG}@>_b5~&uDU3fVit!!WGNi+7L^?ibHH@AmJkS+Nd-ipt#~yP~I6m@G=9IL*%53u} z)1Y=7E2Vov`@gS-YtG>IAm;~{_H<;G9)OmD2YMVG9}&aiG%e^UvalcvIHyqJU&BI4 zh9%i}%w&@n*-Sx)@#A0dM@JGj284w-$JaSqr~i#dFzbmisW%Aujh3KdE#2R`;&%+t9mV zpyFKZU~NQQ{Zm=_x!r@i1BPo0E-wg|J?U>e*7TDmukX3u=Yqv`V-@4IU)Bcewgin0 zKWZ9d0BGuM3YM)JdwJaQr6suP>EP40U_DDOyu*Gbdwc@Sk0e$RB!kLH>XZ zLH>w#7?@jFDJ4R|(~x*4NN78nKc}RnaKG4-Vco~ENw>I7lz!g*lE?2xk;vAyXV`gm ztMo37#GZ`=(E=3jC-+Oc$!^KxSpF?|lw$=-qFdUYGm9hr*DdK5_DO}p>2cJ0kg38Y zO>4SEtB{pFwwe5uyQ1WX`u(zQ8Bhr>-+uB}=@S8cd6X=J(+D^v-7>sh{wsH*z)(fE zq8&-aoYE~%(@9}06!seSD?MltF5Sw^3>Bqr6H^-Wrsaf)@E~Upmv&@M=M)NV`B4#{ zrF$1~?E7~lyNI_%<{U4&9F-|mUUs`3rn&CPV0>@Jhw*=z-j zoYK*@S5W=f(-C~}g#dxDWILLm0dLxpmK1Z2An_?nLV2AXoVbJHl#N!8bt?m*0w=Y5 zXeWzeCx_SRVmYOUvAS54mEdGn*3;QR3l2a7AZMONJLEt}LfLH|O9$iLYh{obF{`ld zU@(a!1;{xu)yhc#DT7#+%Slm$!xPx;H84&cUJr`5PzHrWs92=!(>I3<>o;`5#z({+juQLe8e=dBocbmVVKQF4w>)(5}{bc(O zI#ZIA=9;4OP3Ihgj=;)*?egkt8!m6ST^?S#Av}L$|CWCC_uG<6!nihBN)(ly(++A! zx( zliGoQ2s#Ck8Ws3u6is~gsf|G0iZ zW0@55(IdJQc~n@pHo_0`Z79(`_+=DAB8rlN{K3y!!2Nd^cD?d))Xq_Qcx!W)Hj z3Jw^A9dQmKapZGw=OI){^k1C$z=tV|O^xsur+}cuHz(sm=Y;)**d9@al-Km3oBWZ< znG1*czBAx2#1CHM{Sz2rZ0dw?a)?)AQ>TR67#^IN97x3&sSJl75<0@M;wJ|Zmkwf| z9drdU_4uJ5aT+ykw8sM8J<>g$eTev(RF_+sg5pK9-Y4-xXBhQ zqkfT-25picbF?5@+Wh{VM|c0J#2j6)0?&9gSoT`h3@uT*Ji4GJI)7!btbTS}a6hvh zOZ4oR7n?xTA8|h)W6dcMSb4;~bIM1MN(V^BRh$I_L^3SB0=Q8TA=~&E``tZj`^DX2 z!OmkZKrO7xuDdgC;C|ABsK>>VvDV!ZN}5hYc9SW4Cx<_5A(5?ue{BC6i9V@K+(y{M zd*on`rIe~$ir^>3$CM0D<&^L}R)5*>@;$?V=zQAi`J*5>ulfd5{1|+<6x?223$a85HXZ@ogrw#X+rE z!89NcQgt{5hErtB9@O27s^w5|QpW9dQD%9T*;6B9WI6fKr+KIj28vn8MyV7s9 z-`j$Q9TCM&Fc*a-cjqk$&07*sgy+@y)lq%Pz_yUS!Y>7w!FV^nB9vb-@=7>=sbBF> zLl{bbDl8p%Wn}ZI-$lzRN9>`p716?qk=FvwQ04ky#Z$q84bl2_6VLrn|4i7hJyicp zz}kOc#5%I`hXQl3{+WDH@Q&>a?3*b1cFp(OLk+(TZh1BMYHP5~9qia2YC8~Y z{aunEHi>tLp()%ceyEliRDMNL4~E#U`>($!iGpWEEE2-+)@rsa(l(hzUB%cTU~6$> z$ppPM9s3++&D_id^3iaRf|Gz8iM$fwrI2-*eJ;k1x@J^WqY4Gc%oFJfr83yDJ5lifHOUrHfmAzEnIxbIG(JF0vZhFC43;iS63~tr zAEHLs1#N)BaHrMovRG!yQCH=(x~y$9$`|1TJY%ORP63&7kKM|2a$=X;9t0-gB#?FI z6gJx7@H(u_jF6>uT061RH}l_6IJsBw6wNw;Ls5+Tf6ml)6{?M4^T-09q#35I5n z4mOK8ncYPJ@#bPkJa+M{aN0`U0l}bMqSQLkOyo*^v&Uhzbe+l1W1=}VElj`sekrgp-M_b zBGJDPmG_C9`-J*Fk$<1ayH9BE6G|8Xxe^(C^1DQSh{zvUa&Fn+vS4lVkHqe%ve;jE zS6LiV77wfmD=U1mhZ?SVKoAT|E3A0$P8 z))a>|rM{ dict: + """ + Tool 1: 调用 ML 模型预测购买概率 + """ + if not self.artifacts: + return {"score": 0.0, "reason": "Model not loaded"} + + # 转换输入为 DataFrame + data = features.model_dump() + df = pd.DataFrame([data]) + + # 预处理 (使用训练时保存的 encoder) + # 注意:这里需要严格复现训练时的预处理逻辑 + # 训练时我们做了 Label Encoding + for col, le in self.artifacts['encoders'].items(): + if col in df.columns: + # 处理未知类别 + try: + df[col] = le.transform(df[col].astype(str)) + except: + # 遇到未知类别,这里简单处理为 0 (或者 mode) + logger.warning(f"Unknown category in {col}") + df[col] = 0 + + # 确保列顺序一致 + # 我们训练时用了 X (df.drop(target)) + # 这里需要筛选出 numeric_cols + categorical_cols + # 简单起见,我们假设 feature names 保存了顺序 + feature_names = self.artifacts['features'] + + # 补齐可能缺失的列 + for col in feature_names: + if col not in df.columns: + df[col] = 0 + + X_input = df[feature_names] + + # 预测 + model = self.artifacts['lgb_model'] # 优先使用 LightGBM + prob = model.predict_proba(X_input)[0][1] + + return { + "score": float(prob), + "top_features": ["balance", "poutcome"] # 这里简化,实际可用 SHAP + } + + def get_strategy(self, score: float) -> dict: + """ + Tool 2: 规则引擎/检索工具 + """ + if score > 0.6: + return { + "segment": "高意向 VIP", + "action_type": "人工介入", + "templates": ["尊贵的客户,鉴于您...", "专属理财经理一对一服务"] + } + elif score > 0.3: + return { + "segment": "潜在客户", + "action_type": "自动化营销", + "templates": ["你好,近期理财活动...", "点击领取加息券"] + } + else: + return { + "segment": "低意向群体", + "action_type": "静默/邮件", + "templates": ["月度财经摘要"] + } + + def run(self, features: CustomerFeatures) -> Decision: + """ + Agent 主流程 + """ + logger.info(f"Agent 正在处理客户: {features.job}, {features.age}岁") + + # 1. 感知 (调用 ML 工具) + pred_result = self.predict_risk(features) + score = pred_result["score"] + + # 2. 规划 (调用 策略工具) + strategy = self.get_strategy(score) + + # 3. 决策 (模拟 LLM 整合) + # 在真实场景中,这里构建 Prompt 发送给 DeepSeek + # 这里我们用 Python 逻辑模拟 LLM 的结构化输出能力 + + decision = Decision( + risk_score=round(score, 4), + customer_segment=strategy["segment"], + decision=f"建议采取 {strategy['action_type']}", + actions=[f"使用话术: {t}" for t in strategy["templates"]], + rationale=f"模型预测概率为 {score:.1%},属于{strategy['segment']}。该群体对{strategy['action_type']}转化率较高。" + ) + + return decision + +if __name__ == "__main__": + # 测试 Agent + agent = MarketingAgent() + + # 构造一个测试用例 + test_customer = CustomerFeatures( + age=35, + job="management", + marital="married", + education="tertiary", + default="no", + balance=2000, + housing="yes", + loan="no", + contact="cellular", + day=15, + month="may", + campaign=1, + pdays=-1, + previous=0, + poutcome="unknown" + ) + + result = agent.run(test_customer) + print("\n=== Agent Decision ===") + print(result.model_dump_json(indent=2)) diff --git a/data.py b/data.py new file mode 100644 index 0000000..067d120 --- /dev/null +++ b/data.py @@ -0,0 +1,92 @@ +import polars as pl +import pandera as pa +from pandera import Column, Check, DataFrameSchema +import logging + +# 配置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +# ========================================== +# 1. 定义 Pandera Schema (数据契约) +# ========================================== +# 原始数据 Schema +raw_schema = DataFrameSchema({ + "age": Column(int, checks=Check.ge(18)), + "job": Column(str), + "marital": Column(str), + "education": Column(str), + "default": Column(str, checks=Check.isin(["yes", "no"])), + "balance": Column(int), + "housing": Column(str, checks=Check.isin(["yes", "no"])), + "loan": Column(str, checks=Check.isin(["yes", "no"])), + "contact": Column(str), + "day": Column(int, checks=[Check.ge(1), Check.le(31)]), + "month": Column(str), + "duration": Column(int, checks=Check.ge(0)), + "campaign": Column(int, checks=Check.ge(1)), + "pdays": Column(int), + "previous": Column(int, checks=Check.ge(0)), + "poutcome": Column(str), + "deposit": Column(str, checks=Check.isin(["yes", "no"])), +}) + +# 清洗后 Schema +processed_schema = DataFrameSchema({ + "age": Column(int), + "balance": Column(int), + "deposit": Column(int, checks=Check.isin([0, 1])), + # 其他数值化或保留的特征... +}) + +def load_and_clean_data(file_path: str): + """ + 使用 Polars 加载并清洗数据 + """ + logger.info(f"正在加载数据: {file_path}") + + # 1. Lazy Load + lf = pl.scan_csv(file_path) + + # 2. 初步清洗计划 + # - 移除 duration (避免数据泄露) + # - 将 deposit (yes/no) 转换为 (1/0) + # - 简单的分类变量编码 (为了 LightGBM,我们可以保留分类类型或做 Label Encoding) + # LightGBM 原生支持 Category,但 sklearn 需要数值。 + # 为了通用性,这里做 Label Encoding 或者 One-Hot。 + # 但 Polars 的 Label Encoding 比较手动。 + # 我们这里先只做核心转换。 + + processed_lf = ( + lf.drop(["duration"]) # 移除泄露特征 + .with_columns([ + pl.col("deposit").replace({"yes": 1, "no": 0}).cast(pl.Int64).alias("target"), + # 简单的特征工程示例:将 pdays -1 处理为 999 或单独一类 (这里保持原样,树模型能处理) + ]) + .drop("deposit") # 移除原始标签列,保留 target + ) + + # 3. 执行计算 (Collect) + df = processed_lf.collect() + + logger.info(f"数据加载完成,形状: {df.shape}") + + # 4. Pandera 验证 (转换回 Pandas 验证,因为 Pandera 对 Polars 支持尚在实验阶段或部分支持) + # 这里我们验证关键字段 + try: + # 简单验证一下 target 是否只有 0 和 1 + assert df["target"].n_unique() <= 2 + logger.info("基础数据验证通过") + except Exception as e: + logger.error(f"数据验证失败: {e}") + raise e + + return df + +if __name__ == "__main__": + # 测试代码 + try: + df = load_and_clean_data("data/bank.csv") + print(df.head()) + except Exception as e: + print(f"Error: {e}") diff --git a/streamlit_app.py b/streamlit_app.py new file mode 100644 index 0000000..037386f --- /dev/null +++ b/streamlit_app.py @@ -0,0 +1,152 @@ +import streamlit as st +import pandas as pd +import joblib +import os +import sys + +# 添加项目根目录到 Path 以便导入 src +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) + +from src.agent_app import MarketingAgent, CustomerFeatures + +# 配置页面 +st.set_page_config( + page_title="智能银行营销助手", + page_icon="🤖", + layout="wide" +) + +# 侧边栏:项目信息 +with st.sidebar: + st.title("🏦 智能营销系统") + st.markdown("---") + st.info("**Day 5 演示版**") + st.markdown(""" + **核心能力:** + 1. 📊 **LightGBM** 客户购买预测 + 2. 🧠 **Agent** 策略生成 + 3. 📝 **Pydantic** 结构化输出 + """) + st.markdown("---") + st.caption("由第 X 组开发") + +# 主界面 +st.title("🤖 客户意向预测与决策系统") + +# 1. 模拟客户输入 +st.header("1. 录入客户信息") + +col1, col2, col3 = st.columns(3) + +# 映射字典 +job_map = { + "management": "管理人员", "technician": "技术人员", "entrepreneur": "企业家", + "blue-collar": "蓝领", "unknown": "未知", "retired": "退休人员", + "admin.": "行政人员", "services": "服务业", "self-employed": "自雇人士", + "unemployed": "失业", "maid": "家政", "student": "学生" +} +education_map = {"tertiary": "高等教育", "secondary": "中等教育", "primary": "初等教育", "unknown": "未知"} +marital_map = {"married": "已婚", "single": "单身", "divorced": "离异"} +binary_map = {"yes": "是", "no": "否"} +contact_map = {"cellular": "手机", "telephone": "座机", "unknown": "未知"} +month_map = { + "jan": "1月", "feb": "2月", "mar": "3月", "apr": "4月", "may": "5月", "jun": "6月", + "jul": "7月", "aug": "8月", "sep": "9月", "oct": "10月", "nov": "11月", "dec": "12月" +} +poutcome_map = {"unknown": "未知", "failure": "失败", "other": "其他", "success": "成功"} + +# 辅助函数:反向查找 key +def get_key(val, my_dict): + for key, value in my_dict.items(): + if val == value: return key + return val + +with col1: + age = st.number_input("年龄", 18, 100, 30) + job_display = st.selectbox("职业", list(job_map.values())) + job = get_key(job_display, job_map) + + education_display = st.selectbox("教育", list(education_map.values())) + education = get_key(education_display, education_map) + + balance = st.number_input("账户余额 (欧元)", -1000, 100000, 1500) + +with col2: + marital_display = st.selectbox("婚姻", list(marital_map.values())) + marital = get_key(marital_display, marital_map) + + housing_display = st.selectbox("是否有房贷", list(binary_map.values())) + housing = get_key(housing_display, binary_map) + + loan_display = st.selectbox("是否有个人贷", list(binary_map.values())) + loan = get_key(loan_display, binary_map) + + default_display = st.selectbox("是否有违约记录", list(binary_map.values())) + default = get_key(default_display, binary_map) + +with col3: + contact_display = st.selectbox("联系方式", list(contact_map.values())) + contact = get_key(contact_display, contact_map) + + month_display = st.selectbox("最后联系月份", list(month_map.values())) + month = get_key(month_display, month_map) + + day = st.slider("最后联系日", 1, 31, 15) + + poutcome_display = st.selectbox("上次活动结果", list(poutcome_map.values())) + poutcome = get_key(poutcome_display, poutcome_map) + +# 隐藏的高级特征 +with st.expander("高级营销特征 (可选)"): + campaign = st.number_input("本次活动联系次数", 1, 50, 1) + pdays = st.number_input("距离上次联系天数 (-1代表无)", -1, 999, -1) + previous = st.number_input("活动前联系次数", 0, 100, 0) + +# 2. 触发 Agent +if st.button("🚀 开始分析与决策", type="primary"): + try: + # 构造 Pydantic 对象 + customer = CustomerFeatures( + age=age, job=job, marital=marital, education=education, + default=default, balance=balance, housing=housing, loan=loan, + contact=contact, day=day, month=month, + campaign=campaign, pdays=pdays, previous=previous, poutcome=poutcome + ) + + # 初始化 Agent + with st.spinner("Agent 正在加载模型并思考..."): + agent = MarketingAgent() + decision = agent.run(customer) + + # 3. 展示结果 + st.divider() + st.header("2. 智能分析报告") + + # 结果看板 + res_col1, res_col2 = st.columns([1, 2]) + + with res_col1: + st.metric("预测购买概率", f"{decision.risk_score:.1%}") + if decision.risk_score > 0.6: + st.success(f"分群:{decision.customer_segment}") + elif decision.risk_score > 0.3: + st.warning(f"分群:{decision.customer_segment}") + else: + st.error(f"分群:{decision.customer_segment}") + + with res_col2: + st.subheader("💡 决策建议") + st.info(decision.decision) + st.markdown(f"**决策依据:** {decision.rationale}") + + # 行动清单 + st.subheader("📝 执行清单") + for i, action in enumerate(decision.actions, 1): + st.write(f"{i}. {action}") + + # JSON 视图 + with st.expander("查看原始 JSON 输出 (Traceable)"): + st.json(decision.model_dump()) + + except Exception as e: + st.error(f"发生错误: {str(e)}") diff --git a/train.py b/train.py new file mode 100644 index 0000000..862a0d1 --- /dev/null +++ b/train.py @@ -0,0 +1,90 @@ +import polars as pl +import pandas as pd +import lightgbm as lgb +from sklearn.linear_model import LogisticRegression +from sklearn.model_selection import train_test_split +from sklearn.metrics import classification_report, roc_auc_score, f1_score +from sklearn.preprocessing import LabelEncoder, StandardScaler +import joblib +import logging +import os +from src.data import load_and_clean_data + +# 配置日志 +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') +logger = logging.getLogger(__name__) + +def train_models(data_path="data/bank.csv", model_dir="models"): + # 1. 加载数据 + df_pl = load_and_clean_data(data_path) + df = df_pl.to_pandas() # 转换为 Pandas 以兼容 Sklearn + + # 2. 特征预处理 + # 区分分类和数值特征 + target_col = "target" + X = df.drop(columns=[target_col]) + y = df[target_col] + + cat_cols = X.select_dtypes(include=['object', 'category']).columns.tolist() + num_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist() + + # Label Encoding (为了简化,LightGBM 可以直接处理 Category,但 Sklearn 需要编码) + encoders = {} + for col in cat_cols: + le = LabelEncoder() + X[col] = le.fit_transform(X[col].astype(str)) + encoders[col] = le + + # 3. 数据切分 + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) + + # 4. 训练基线模型 (Logistic Regression) + logger.info("训练基线模型 (Logistic Regression)...") + # 逻辑回归需要归一化 + scaler = StandardScaler() + X_train_scaled = X_train.copy() + X_test_scaled = X_test.copy() + X_train_scaled[num_cols] = scaler.fit_transform(X_train[num_cols]) + X_test_scaled[num_cols] = scaler.transform(X_test[num_cols]) + + lr_model = LogisticRegression(max_iter=1000, random_state=42) + lr_model.fit(X_train_scaled, y_train) + + lr_pred = lr_model.predict(X_test_scaled) + lr_prob = lr_model.predict_proba(X_test_scaled)[:, 1] + + logger.info(f"Baseline F1: {f1_score(y_test, lr_pred):.4f}") + logger.info(f"Baseline AUC: {roc_auc_score(y_test, lr_prob):.4f}") + + # 5. 训练进阶模型 (LightGBM) + logger.info("训练进阶模型 (LightGBM)...") + lgb_model = lgb.LGBMClassifier(n_estimators=100, learning_rate=0.05, random_state=42, verbose=-1) + lgb_model.fit(X_train, y_train) + + lgb_pred = lgb_model.predict(X_test) + lgb_prob = lgb_model.predict_proba(X_test)[:, 1] + + logger.info(f"LightGBM F1: {f1_score(y_test, lgb_pred):.4f}") + logger.info(f"LightGBM AUC: {roc_auc_score(y_test, lgb_prob):.4f}") + + # 6. 保存模型与元数据 + if not os.path.exists(model_dir): + os.makedirs(model_dir) + + artifacts = { + "lgb_model": lgb_model, + "lr_model": lr_model, + "scaler": scaler, + "encoders": encoders, + "features": list(X.columns), + "cat_cols": cat_cols, + "num_cols": num_cols + } + + joblib.dump(artifacts, os.path.join(model_dir, "model_artifacts.pkl")) + logger.info(f"模型已保存至 {model_dir}/model_artifacts.pkl") + + return artifacts + +if __name__ == "__main__": + train_models()