feat(模型): 添加模型训练、评估和预测功能
- 实现模型训练模块,支持逻辑回归和LightGBM - 添加模型评估功能,包括分类报告和混淆矩阵 - 实现预测功能,支持单条、批量及DataFrame输入 - 更新.gitignore,移除models目录忽略 - 添加模型评估结果文件
This commit is contained in:
parent
e081ed9329
commit
48013e93bd
2
.gitignore
vendored
2
.gitignore
vendored
@ -3,8 +3,6 @@
|
|||||||
env/
|
env/
|
||||||
venv/
|
venv/
|
||||||
|
|
||||||
# Model files
|
|
||||||
models/
|
|
||||||
|
|
||||||
# Environment variables
|
# Environment variables
|
||||||
.env
|
.env
|
||||||
|
|||||||
BIN
models/lightgbm.joblib
Normal file
BIN
models/lightgbm.joblib
Normal file
Binary file not shown.
8
models/lightgbm_classification_report.txt
Normal file
8
models/lightgbm_classification_report.txt
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
precision recall f1-score support
|
||||||
|
|
||||||
|
ham 0.98 0.99 0.99 969
|
||||||
|
spam 0.93 0.89 0.91 145
|
||||||
|
|
||||||
|
accuracy 0.98 1114
|
||||||
|
macro avg 0.96 0.94 0.95 1114
|
||||||
|
weighted avg 0.98 0.98 0.98 1114
|
||||||
BIN
models/lightgbm_confusion_matrix.png
Normal file
BIN
models/lightgbm_confusion_matrix.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 65 KiB |
BIN
models/logistic_regression.joblib
Normal file
BIN
models/logistic_regression.joblib
Normal file
Binary file not shown.
8
models/logistic_regression_classification_report.txt
Normal file
8
models/logistic_regression_classification_report.txt
Normal file
@ -0,0 +1,8 @@
|
|||||||
|
precision recall f1-score support
|
||||||
|
|
||||||
|
ham 0.99 0.99 0.99 969
|
||||||
|
spam 0.92 0.94 0.93 145
|
||||||
|
|
||||||
|
accuracy 0.98 1114
|
||||||
|
macro avg 0.95 0.96 0.96 1114
|
||||||
|
weighted avg 0.98 0.98 0.98 1114
|
||||||
BIN
models/logistic_regression_confusion_matrix.png
Normal file
BIN
models/logistic_regression_confusion_matrix.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 64 KiB |
42
src/models/__init__.py
Normal file
42
src/models/__init__.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from .train import (
|
||||||
|
build_tfidf_vectorizer,
|
||||||
|
build_logistic_regression,
|
||||||
|
build_lightgbm,
|
||||||
|
train_model,
|
||||||
|
save_model,
|
||||||
|
load_model
|
||||||
|
)
|
||||||
|
from .evaluate import (
|
||||||
|
evaluate_model,
|
||||||
|
print_evaluation_results,
|
||||||
|
plot_confusion_matrix,
|
||||||
|
save_evaluation_results,
|
||||||
|
compare_models
|
||||||
|
)
|
||||||
|
from .predict import (
|
||||||
|
predict_spam,
|
||||||
|
predict_batch_spam,
|
||||||
|
predict_from_df
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
# train.py
|
||||||
|
"build_tfidf_vectorizer",
|
||||||
|
"build_logistic_regression",
|
||||||
|
"build_lightgbm",
|
||||||
|
"train_model",
|
||||||
|
"save_model",
|
||||||
|
"load_model",
|
||||||
|
|
||||||
|
# evaluate.py
|
||||||
|
"evaluate_model",
|
||||||
|
"print_evaluation_results",
|
||||||
|
"plot_confusion_matrix",
|
||||||
|
"save_evaluation_results",
|
||||||
|
"compare_models",
|
||||||
|
|
||||||
|
# predict.py
|
||||||
|
"predict_spam",
|
||||||
|
"predict_batch_spam",
|
||||||
|
"predict_from_df"
|
||||||
|
]
|
||||||
159
src/models/evaluate.py
Normal file
159
src/models/evaluate.py
Normal file
@ -0,0 +1,159 @@
|
|||||||
|
from sklearn.metrics import (
|
||||||
|
accuracy_score,
|
||||||
|
precision_score,
|
||||||
|
recall_score,
|
||||||
|
f1_score,
|
||||||
|
classification_report,
|
||||||
|
confusion_matrix
|
||||||
|
)
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
import seaborn as sns
|
||||||
|
import os
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
import polars as pl
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_model(
    model,
    test_df: pl.DataFrame
) -> Dict[str, Any]:
    """Evaluate a fitted classifier on a labelled test frame.

    Args:
        model: Fitted estimator exposing ``predict`` and ``predict_proba``
            that accepts a list of text strings.
        test_df: Frame with a ``clean_text`` column (model input) and a
            ``label_num`` column (ground truth, 0 = ham, 1 = spam).

    Returns:
        Dict with the keys ``accuracy``, ``precision``, ``recall``, ``f1``
        (the last three macro-averaged), ``class_report`` (text report),
        ``conf_matrix`` (2x2 array, rows = true label, columns = predicted),
        ``y_test``, ``y_pred`` and ``y_pred_proba`` (probability of class 1).
    """
    # Fixed label order shared by the report and the confusion matrix.
    class_ids = [0, 1]
    class_names = ["ham", "spam"]

    # Prepare the test data.
    X_test = test_df["clean_text"].to_list()
    y_test = test_df["label_num"].to_list()

    # Predict hard labels and the positive-class probability.
    y_pred = model.predict(X_test)
    y_pred_proba = model.predict_proba(X_test)[:, 1]

    # Scalar metrics.
    accuracy = accuracy_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred, average="macro")
    recall = recall_score(y_test, y_pred, average="macro")
    f1 = f1_score(y_test, y_pred, average="macro")

    # Pin the label set explicitly: without it scikit-learn infers the labels
    # from the data, so a split containing a single class would not match the
    # two target names (report) or would give a 1x1 matrix (plot tick labels).
    class_report = classification_report(
        y_test,
        y_pred,
        labels=class_ids,
        target_names=class_names,
        zero_division=0,
    )
    conf_matrix = confusion_matrix(y_test, y_pred, labels=class_ids)

    # Collect everything callers need for printing, plotting and saving.
    results = {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "class_report": class_report,
        "conf_matrix": conf_matrix,
        "y_test": y_test,
        "y_pred": y_pred,
        "y_pred_proba": y_pred_proba
    }

    return results
|
||||||
|
|
||||||
|
|
||||||
|
def print_evaluation_results(results: Dict[str, Any]) -> None:
    """Print the scalar metrics and the classification report to stdout.

    Args:
        results: Mapping holding ``accuracy``, ``precision``, ``recall``,
            ``f1`` (numbers) and ``class_report`` (text).
    """
    # (caption, key) pairs in the order they are reported.
    metric_rows = (
        ("Accuracy", "accuracy"),
        ("Precision (macro)", "precision"),
        ("Recall (macro)", "recall"),
        ("F1 Score (macro)", "f1"),
    )
    for caption, key in metric_rows:
        print(f"{caption}: {results[key]:.4f}")

    print("\nClassification Report:")
    print(results["class_report"])
|
||||||
|
|
||||||
|
|
||||||
|
def plot_confusion_matrix(
    conf_matrix: List[List[int]],
    save_path: str = None
) -> None:
    """Render a ham/spam confusion matrix as a heatmap and optionally save it.

    Args:
        conf_matrix: 2x2 matrix of integer counts (rows = true label,
            columns = predicted label, in ham/spam order).
        save_path: Target image path. When falsy the figure is drawn and
            discarded without being written anywhere.
    """
    plt.figure(figsize=(8, 6))
    try:
        sns.heatmap(
            conf_matrix,
            annot=True,
            fmt="d",
            cmap="Blues",
            xticklabels=["ham", "spam"],
            yticklabels=["ham", "spam"]
        )
        plt.xlabel("Predicted Label")
        plt.ylabel("True Label")
        plt.title("Confusion Matrix")

        if save_path:
            # A bare filename has an empty dirname, and os.makedirs("")
            # raises; only create the directory when there is one.
            target_dir = os.path.dirname(save_path)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)
            plt.savefig(save_path, dpi=300, bbox_inches="tight")
            print(f"混淆矩阵已保存到: {save_path}")
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close()
|
||||||
|
|
||||||
|
|
||||||
|
def save_evaluation_results(
    results: Dict[str, Any],
    model_name: str,
    save_dir: str = "./models"
) -> None:
    """Write the classification report and confusion-matrix image to disk.

    Produces ``<model_name>_classification_report.txt`` and
    ``<model_name>_confusion_matrix.png`` inside ``save_dir``.

    Args:
        results: Evaluation dict providing ``class_report`` (text) and
            ``conf_matrix``.
        model_name: Prefix used for both output file names.
        save_dir: Output directory; created if it does not exist.
    """
    # Make sure the output directory exists.
    os.makedirs(save_dir, exist_ok=True)

    # Save the text report. The encoding is explicit so the file does not
    # depend on the platform's default locale encoding.
    report_path = os.path.join(save_dir, f"{model_name}_classification_report.txt")
    with open(report_path, "w", encoding="utf-8") as f:
        f.write(results["class_report"])

    # Save the confusion-matrix plot next to the report.
    conf_matrix_path = os.path.join(save_dir, f"{model_name}_confusion_matrix.png")
    plot_confusion_matrix(results["conf_matrix"], conf_matrix_path)

    print(f"评估结果已保存到: {save_dir}")
|
||||||
|
|
||||||
|
|
||||||
|
def compare_models(
    models: Dict[str, Any],
    test_df: pl.DataFrame,
    save_dir: str = "./models"
) -> Dict[str, Dict[str, Any]]:
    """Evaluate several models on the same test frame.

    For every model the metrics are computed, printed and written to disk.

    Args:
        models: Mapping of model name to fitted estimator. The name is also
            used as the prefix of the saved result files.
        test_df: Test frame passed unchanged to ``evaluate_model``.
        save_dir: Directory for the saved reports and plots. Defaults to
            ``"./models"``, the directory used before this parameter existed.

    Returns:
        Mapping of model name to the result dict from ``evaluate_model``.
    """
    comparison_results = {}

    for model_name, model in models.items():
        print(f"\n=== 评估 {model_name} 模型 ===")
        results = evaluate_model(model, test_df)
        print_evaluation_results(results)
        save_evaluation_results(results, model_name, save_dir)
        comparison_results[model_name] = results

    return comparison_results
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Manual test entry point: evaluate both saved models on the test split.

    Loads and preprocesses the dataset, takes the second element returned by
    ``split_data`` as the test frame, loads the two persisted models and
    runs ``compare_models`` on them. Prints a hint and returns when a model
    file is missing.
    """
    from .train import load_model
    from ..data import load_data, preprocess_data, split_data

    # Load and preprocess the data. The path is relative to the working
    # directory, matching train.main and the "./models" output directory.
    df = load_data("./data/spam.csv")
    processed_df = preprocess_data(df)
    _, test_df = split_data(processed_df)

    # Only model loading is guarded: a FileNotFoundError raised later (for
    # example while saving results) must not be reported as a missing model.
    try:
        models = {
            "logistic_regression": load_model("logistic_regression"),
            "lightgbm": load_model("lightgbm")
        }
    except FileNotFoundError:
        print("模型文件不存在,请先运行train.py训练模型")
        return

    # Compare the models.
    compare_models(models, test_df)


if __name__ == "__main__":
    main()
|
||||||
135
src/models/predict.py
Normal file
135
src/models/predict.py
Normal file
@ -0,0 +1,135 @@
|
|||||||
|
from typing import List, Dict, Any
|
||||||
|
import polars as pl
|
||||||
|
import warnings
|
||||||
|
from sklearn.exceptions import DataConversionWarning
|
||||||
|
from .train import load_model
|
||||||
|
from ..data import clean_text
|
||||||
|
|
||||||
|
# Silence the "X does not have valid feature names" warning emitted at
# prediction time (feature-name mismatch between fit and predict input).
warnings.filterwarnings(
    action="ignore",
    message="X does not have valid feature names",
)
|
||||||
|
|
||||||
|
|
||||||
|
def predict_spam(
    text: str,
    model_name: str = "lightgbm"
) -> Dict[str, Any]:
    """Classify a single message as spam or ham.

    Args:
        text: Raw message text.
        model_name: Name of the persisted model to load.

    Returns:
        Dict with ``original_text``, ``cleaned_text``, ``label``
        (``"spam"``/``"ham"``), ``label_num`` and a ``probability`` dict
        holding the ``ham`` and ``spam`` probabilities.
    """
    classifier = load_model(model_name)

    # The model is fed the cleaned text as a one-element batch.
    cleaned_text = clean_text(text)
    batch = [cleaned_text]

    predicted = classifier.predict(batch)[0]
    proba_row = classifier.predict_proba(batch)[0]

    # Class 1 means spam; anything else is reported as ham.
    if predicted == 1:
        label = "spam"
    else:
        label = "ham"

    class_probabilities = {
        "ham": float(proba_row[0]),
        "spam": float(proba_row[1])
    }

    return {
        "original_text": text,
        "cleaned_text": cleaned_text,
        "label": label,
        "label_num": int(predicted),
        "probability": class_probabilities
    }
|
||||||
|
|
||||||
|
|
||||||
|
def predict_batch_spam(
    texts: List[str],
    model_name: str = "lightgbm"
) -> List[Dict[str, Any]]:
    """Classify a batch of messages as spam or ham.

    Args:
        texts: Raw message texts. Any iterable is accepted; it is
            materialized once so it can be traversed twice.
        model_name: Name of the persisted model to load.

    Returns:
        One dict per input, in input order, with ``id`` (position in the
        batch), ``original_text``, ``cleaned_text``, ``label``,
        ``label_num`` and a ``probability`` dict (``ham``/``spam``).
        An empty input yields an empty list.
    """
    # Materialize so a one-shot iterable is not exhausted by the cleaning
    # pass before the results are zipped back to the original texts.
    texts = list(texts)

    # Nothing to score: return early instead of handing the estimator an
    # empty batch (and skip the model load altogether).
    if not texts:
        return []

    # Load the model.
    model = load_model(model_name)

    # Clean the texts.
    cleaned_texts = [clean_text(text) for text in texts]

    # Predict labels and class probabilities for the whole batch at once.
    predictions = model.predict(cleaned_texts)
    probabilities = model.predict_proba(cleaned_texts)

    # Convert the raw outputs to plain Python result dicts.
    results = []
    for i, (text, cleaned_text, prediction, probability) in enumerate(
        zip(texts, cleaned_texts, predictions, probabilities)
    ):
        label = "spam" if prediction == 1 else "ham"
        results.append({
            "id": i,
            "original_text": text,
            "cleaned_text": cleaned_text,
            "label": label,
            "label_num": int(prediction),
            "probability": {
                "ham": float(probability[0]),
                "spam": float(probability[1])
            }
        })

    return results
|
||||||
|
|
||||||
|
|
||||||
|
def predict_from_df(
    df: pl.DataFrame,
    text_column: str = "text",
    model_name: str = "lightgbm"
) -> pl.DataFrame:
    """Classify every row of a DataFrame as spam or ham.

    Args:
        df: Input frame containing the raw text column.
        text_column: Name of the column holding the raw message text.
        model_name: Name of the persisted model to load.

    Returns:
        The input frame with these columns added: ``clean_text``,
        ``label_num`` (Int64), ``ham_prob`` and ``spam_prob`` (Float64) and
        ``label`` (``"spam"`` when ``label_num`` is 1, otherwise ``"ham"``).
    """
    # Clean the text.
    df = df.with_columns(
        pl.col(text_column).map_elements(clean_text, return_dtype=pl.String).alias("clean_text")
    )

    # Prepare the model input.
    texts = df["clean_text"].to_list()

    # Predict.
    model = load_model(model_name)
    predictions = model.predict(texts)
    probabilities = model.predict_proba(texts)

    # Attach the numeric prediction and both class probabilities.
    df = df.with_columns(
        pl.Series("label_num", predictions).cast(pl.Int64),
        pl.Series("ham_prob", [p[0] for p in probabilities]).cast(pl.Float64),
        pl.Series("spam_prob", [p[1] for p in probabilities]).cast(pl.Float64)
    )

    # Attach the text label. The values must be wrapped in pl.lit: bare
    # strings passed to then()/otherwise() are parsed as column names, so
    # "spam"/"ham" would be looked up as (non-existent) columns.
    df = df.with_columns(
        pl.when(pl.col("label_num") == 1)
        .then(pl.lit("spam"))
        .otherwise(pl.lit("ham"))
        .alias("label")
    )

    return df
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Manual test entry point: run one single and one batch prediction."""
    # Single-message prediction.
    single_sample = "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's"
    single_result = predict_spam(single_sample)
    print("单条预测结果:")
    print(single_result)

    # Batch prediction.
    # NOTE(review): the "<20>" in the third sample looks like a mis-encoded
    # character (possibly a currency symbol) - confirm against the dataset.
    batch_samples = [
        "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005.",
        "Ok lar... Joking wif u oni...",
        "WINNER!! As a valued network customer you have been selected to receivea <20>900 prize reward!",
        "I'm gonna be home soon and i don't want to talk about this stuff anymore tonight, k?"
    ]
    batch_results = predict_batch_spam(batch_samples)
    print("\n批量预测结果:")
    for entry in batch_results:
        print(entry)


if __name__ == "__main__":
    main()
|
||||||
162
src/models/train.py
Normal file
162
src/models/train.py
Normal file
@ -0,0 +1,162 @@
|
|||||||
|
import os
|
||||||
|
import joblib
|
||||||
|
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||||
|
from sklearn.linear_model import LogisticRegression
|
||||||
|
from lightgbm import LGBMClassifier
|
||||||
|
from sklearn.pipeline import Pipeline
|
||||||
|
from sklearn.model_selection import GridSearchCV
|
||||||
|
from typing import Tuple, Dict, Any
|
||||||
|
import polars as pl
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
# Load environment variables (python-dotenv).
load_dotenv()

# Directory where models are persisted; overridable through the
# MODEL_SAVE_DIR environment variable.
MODEL_SAVE_DIR = os.environ.get("MODEL_SAVE_DIR", "./models")
|
||||||
|
|
||||||
|
|
||||||
|
def build_tfidf_vectorizer() -> TfidfVectorizer:
    """Create the TF-IDF vectorizer shared by both model pipelines.

    Returns:
        An unfitted ``TfidfVectorizer`` using English stop words, at most
        10000 features, unigrams and bigrams, lowercasing, and a token
        pattern matching runs of ASCII letters, digits and underscores.
    """
    vectorizer_options = {
        "stop_words": "english",
        "max_features": 10000,
        "ngram_range": (1, 2),
        "lowercase": True,
        "token_pattern": r'\b[a-zA-Z0-9_]{1,}\b',
    }
    return TfidfVectorizer(**vectorizer_options)
|
||||||
|
|
||||||
|
|
||||||
|
def build_logistic_regression() -> Pipeline:
    """Create the unfitted TF-IDF + logistic regression pipeline.

    Returns:
        A ``Pipeline`` with the steps ``"tfidf"`` and ``"lr"``.
    """
    classifier = LogisticRegression(
        random_state=42,
        max_iter=1000,
        class_weight="balanced"
    )
    steps = [
        ("tfidf", build_tfidf_vectorizer()),
        ("lr", classifier),
    ]
    return Pipeline(steps)
|
||||||
|
|
||||||
|
|
||||||
|
def build_lightgbm() -> Pipeline:
    """Create the unfitted TF-IDF + LightGBM pipeline.

    Returns:
        A ``Pipeline`` with the steps ``"tfidf"`` and ``"lgbm"``.
    """
    # NOTE(review): feature_name looks like a fit()-time option rather than
    # a constructor argument - confirm it has any effect when passed here.
    classifier = LGBMClassifier(
        random_state=42,
        class_weight="balanced",
        verbose=-1,
        feature_name="auto"
    )
    steps = [
        ("tfidf", build_tfidf_vectorizer()),
        ("lgbm", classifier),
    ]
    return Pipeline(steps)
|
||||||
|
|
||||||
|
|
||||||
|
def train_model(
    train_df: pl.DataFrame,
    model_type: str = "lightgbm",
    hyperparam_tune: bool = False
) -> Tuple[Pipeline, Dict[str, Any]]:
    """Fit one of the supported pipelines on the training frame.

    Args:
        train_df: Frame with a ``clean_text`` column (input text) and a
            ``label_num`` column (target).
        model_type: ``"logistic_regression"`` or ``"lightgbm"``.
        hyperparam_tune: When true, run a 5-fold grid search scored by macro
            F1 and return the best estimator; otherwise fit the pipeline
            with its default settings.

    Returns:
        ``(fitted_pipeline, best_params)``; ``best_params`` is an empty dict
        when no tuning was requested.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    # Pull features and targets out of the frame.
    features = train_df["clean_text"].to_list()
    targets = train_df["label_num"].to_list()

    # Pick the pipeline together with its search space.
    if model_type == "logistic_regression":
        pipeline = build_logistic_regression()
        param_grid = {
            "lr__C": [0.1, 1.0, 10.0],
            "lr__solver": ["liblinear", "lbfgs"]
        }
    elif model_type == "lightgbm":
        pipeline = build_lightgbm()
        param_grid = {
            "lgbm__n_estimators": [100, 200],
            "lgbm__learning_rate": [0.1, 0.2],
            "lgbm__max_depth": [3, 5, 7]
        }
    else:
        raise ValueError(f"不支持的模型类型: {model_type}")

    # Plain fit: no search, no tuned parameters to report.
    if not hyperparam_tune:
        pipeline.fit(features, targets)
        return pipeline, {}

    # Exhaustive grid search over the selected search space.
    search = GridSearchCV(
        pipeline,
        param_grid,
        cv=5,
        scoring="f1_macro",
        n_jobs=-1
    )
    search.fit(features, targets)
    best_model = search.best_estimator_
    best_params = search.best_params_
    print(f"最佳参数: {best_params}")

    return best_model, best_params
|
||||||
|
|
||||||
|
|
||||||
|
def save_model(model: Pipeline, model_name: str) -> str:
    """Persist a pipeline as ``<MODEL_SAVE_DIR>/<model_name>.joblib``.

    Args:
        model: Pipeline to serialize.
        model_name: File stem of the saved model.

    Returns:
        The path the model was written to.
    """
    # Create the target directory on first use.
    os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

    destination = os.path.join(MODEL_SAVE_DIR, f"{model_name}.joblib")
    joblib.dump(model, destination)
    print(f"模型已保存到: {destination}")

    return destination
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(model_name: str) -> Pipeline:
    """Load ``<MODEL_SAVE_DIR>/<model_name>.joblib`` from disk.

    Args:
        model_name: File stem of the saved model.

    Returns:
        The deserialized pipeline.

    Raises:
        FileNotFoundError: If the model file does not exist.
    """
    source = os.path.join(MODEL_SAVE_DIR, f"{model_name}.joblib")

    if os.path.exists(source):
        # NOTE: joblib files are pickle-based; only load files from a
        # trusted location.
        loaded = joblib.load(source)
        print(f"模型已从: {source} 加载")
        return loaded

    raise FileNotFoundError(f"模型文件不存在: {source}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
    """Manual test entry point: train and save both models."""
    import sys
    import os

    # Put the project root (two levels above this file) on the import path
    # so that the ``src`` package can be imported when run as a script.
    this_dir = os.path.dirname(__file__)
    project_root = os.path.abspath(os.path.join(this_dir, "..", ".."))
    sys.path.append(project_root)

    from src.data import load_data, preprocess_data, split_data

    # Load, preprocess and split the data; only the first part is used here.
    raw_df = load_data("./data/spam.csv")
    prepared_df = preprocess_data(raw_df)
    train_df, _ = split_data(prepared_df)

    print("训练Logistic Regression模型...")
    lr_model, _ = train_model(train_df, model_type="logistic_regression")
    save_model(lr_model, "logistic_regression")

    print("\n训练LightGBM模型...")
    lgbm_model, _ = train_model(train_df, model_type="lightgbm")
    save_model(lgbm_model, "lightgbm")

    print("\n模型训练完成!")


if __name__ == "__main__":
    main()
|
||||||
Loading…
Reference in New Issue
Block a user