删除 src/train_tweet_ultimate.py

This commit is contained in:
张则文 2026-01-15 23:25:12 +08:00
parent 67fb73e011
commit 9cc826963b

View File

@ -1,287 +0,0 @@
"""推文情感分析模型训练和加载模块
实现基于 TF-IDF + LightGBM 的情感分类模型
"""
from pathlib import Path
from typing import Optional
import numpy as np
import polars as pl
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import LabelEncoder
import lightgbm as lgb
import joblib
class TweetSentimentModel:
    """Tweet sentiment classifier combining TF-IDF text features with LightGBM.

    Each sample is described by its tweet text (vectorized with TF-IDF) plus
    one extra column holding the label-encoded airline. Sentiment labels are
    label-encoded for training and decoded again in :meth:`predict`.
    """

    def __init__(
        self,
        tfidf_vectorizer: Optional[TfidfVectorizer] = None,
        label_encoder: Optional[LabelEncoder] = None,
        airline_encoder: Optional[LabelEncoder] = None,
        classifier: Optional[lgb.LGBMClassifier] = None,
    ):
        """Initialise the model, building default components where omitted.

        Args:
            tfidf_vectorizer: TF-IDF vectorizer for the tweet text.
            label_encoder: Encoder for the sentiment labels.
            airline_encoder: Encoder for the airline names.
            classifier: LightGBM classifier trained on the combined features.
        """
        # Compare against None explicitly: ``obj or default`` depends on the
        # truthiness of the estimator, which is not reliable for objects that
        # define ``__len__`` (pipelines, unfitted ensembles, ...).
        if tfidf_vectorizer is None:
            tfidf_vectorizer = TfidfVectorizer(
                max_features=5000,
                ngram_range=(1, 2),
                min_df=2,
                max_df=0.95,
            )
        if label_encoder is None:
            label_encoder = LabelEncoder()
        if airline_encoder is None:
            airline_encoder = LabelEncoder()
        if classifier is None:
            classifier = lgb.LGBMClassifier(
                n_estimators=100,
                learning_rate=0.1,
                max_depth=6,
                random_state=42,
                verbose=-1,
            )

        self.tfidf_vectorizer = tfidf_vectorizer
        self.label_encoder = label_encoder
        self.airline_encoder = airline_encoder
        self.classifier = classifier
        self._is_fitted = False

    def fit(self, texts: np.ndarray, airlines: np.ndarray, sentiments: np.ndarray) -> "TweetSentimentModel":
        """Fit the encoders, the vectorizer and the classifier.

        Args:
            texts: Array of tweet texts.
            airlines: Array of airline names, aligned with ``texts``.
            sentiments: Array of sentiment labels, aligned with ``texts``.

        Returns:
            The fitted model itself, so calls can be chained.
        """
        # Encode the target labels.
        self.label_encoder.fit(sentiments)
        y = self.label_encoder.transform(sentiments)

        # Encode the airline names as integer codes.
        self.airline_encoder.fit(airlines)
        airline_codes = self.airline_encoder.transform(airlines)

        # Learn the TF-IDF vocabulary and vectorize the training texts.
        text_features = self.tfidf_vectorizer.fit_transform(texts)

        X = self._combine_features(text_features, airline_codes.reshape(-1, 1))
        self.classifier.fit(X, y)
        self._is_fitted = True
        return self

    def predict(self, texts: np.ndarray, airlines: np.ndarray) -> np.ndarray:
        """Predict sentiment labels.

        Args:
            texts: Array of tweet texts.
            airlines: Array of airline names, aligned with ``texts``.

        Returns:
            Array of predicted sentiment labels, decoded back to the values
            seen during :meth:`fit`.

        Raises:
            ValueError: If the model has not been fitted yet.
        """
        self._require_fitted()
        X = self._transform_features(texts, airlines)
        y_pred = self.classifier.predict(X)
        return self.label_encoder.inverse_transform(y_pred)

    def predict_proba(self, texts: np.ndarray, airlines: np.ndarray) -> np.ndarray:
        """Predict class probabilities.

        Args:
            texts: Array of tweet texts.
            airlines: Array of airline names, aligned with ``texts``.

        Returns:
            The classifier's ``predict_proba`` output for the samples,
            documented by the original author as (n_samples, n_classes).

        Raises:
            ValueError: If the model has not been fitted yet.
        """
        self._require_fitted()
        return self.classifier.predict_proba(self._transform_features(texts, airlines))

    def _require_fitted(self) -> None:
        """Raise ``ValueError`` unless :meth:`fit` (or a fitted load) happened."""
        if not self._is_fitted:
            raise ValueError("模型尚未训练,请先调用 fit() 方法")

    def _transform_features(self, texts: np.ndarray, airlines: np.ndarray):
        """Build the inference feature matrix with the already-fitted transformers.

        Shared by :meth:`predict` and :meth:`predict_proba` so both always use
        exactly the same feature pipeline.
        """
        text_features = self.tfidf_vectorizer.transform(texts)
        # NOTE(review): LabelEncoder.transform is documented to reject labels
        # it did not see during fit, so airlines absent from the training data
        # are expected to raise here - confirm this is the desired behaviour.
        airline_codes = self.airline_encoder.transform(airlines)
        return self._combine_features(text_features, airline_codes.reshape(-1, 1))

    def _combine_features(self, text_features: np.ndarray, airline_features: np.ndarray) -> np.ndarray:
        """Stack the text features and the airline column side by side.

        Args:
            text_features: TF-IDF text features.
            airline_features: Airline codes as a single column.

        Returns:
            The combined feature matrix. Despite the annotation, the result is
            a SciPy sparse matrix because it comes from ``scipy.sparse.hstack``.
        """
        from scipy.sparse import hstack

        # Request CSR directly instead of hstack's default COO result, so the
        # matrix is already in a row-oriented format for the classifier.
        return hstack([text_features, airline_features], format="csr")

    def save(self, path: Path) -> None:
        """Persist the model components to ``path`` with joblib.

        Missing parent directories are created first.

        Args:
            path: Destination file path.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        model_data = {
            "tfidf_vectorizer": self.tfidf_vectorizer,
            "label_encoder": self.label_encoder,
            "airline_encoder": self.airline_encoder,
            "classifier": self.classifier,
            "is_fitted": self._is_fitted,
        }
        joblib.dump(model_data, path)

    @classmethod
    def load(cls, path: Path) -> "TweetSentimentModel":
        """Restore a model previously written by :meth:`save`.

        Args:
            path: Path of the saved model file.

        Returns:
            The reconstructed model, including its fitted flag.
        """
        # SECURITY: joblib.load unpickles the file and can therefore execute
        # arbitrary code; only load model files from trusted sources.
        model_data = joblib.load(path)
        model = cls(
            tfidf_vectorizer=model_data["tfidf_vectorizer"],
            label_encoder=model_data["label_encoder"],
            airline_encoder=model_data["airline_encoder"],
            classifier=model_data["classifier"],
        )
        model._is_fitted = model_data["is_fitted"]
        return model
def load_model(model_path: Optional[Path] = None) -> TweetSentimentModel:
    """Load a saved model, falling back to a freshly trained example model.

    Args:
        model_path: Optional path of a saved model. Plain strings are accepted
            as well as ``Path`` objects. When omitted, or when the path does
            not exist, an example model is trained and returned instead.

    Returns:
        The loaded model, or the example model as a fallback.
    """
    if model_path is not None:
        # Coerce to Path (as TweetSentimentModel.save does) so a plain string
        # path works instead of failing on the missing ``.exists`` attribute.
        candidate = Path(model_path)
        if candidate.exists():
            return TweetSentimentModel.load(candidate)

    # No usable saved model: build one from the bundled example data.
    return _create_example_model()
def _create_example_model() -> TweetSentimentModel:
    """Train a small demonstration model on hard-coded sample tweets.

    Returns:
        A model fitted on the ten built-in example rows.
    """
    # One row per example: (tweet text, airline, sentiment label).
    # NOTE(review): with only ten rows, LightGBM's default minimum-samples
    # settings may prevent any tree split, so predictions could collapse to
    # the class priors - confirm before relying on this model's output.
    example_rows = (
        (
            "@United This is the worst airline ever! My flight was delayed for 5 hours and no one helped!",
            "United",
            "negative",
        ),
        (
            "@Southwest Thank you for the amazing flight! The crew was so helpful and friendly.",
            "Southwest",
            "positive",
        ),
        (
            "@American What is the baggage policy for international flights?",
            "American",
            "neutral",
        ),
        (
            "@Delta Terrible service! Lost my luggage and no response from customer support.",
            "Delta",
            "negative",
        ),
        (
            "@JetBlue Great experience! On time departure and friendly staff.",
            "JetBlue",
            "positive",
        ),
        (
            "@United Why is my flight cancelled again? This is unacceptable!",
            "United",
            "negative",
        ),
        (
            "@Southwest Love the free snacks and great customer service!",
            "Southwest",
            "positive",
        ),
        (
            "@American Can you help me with my booking?",
            "American",
            "neutral",
        ),
        (
            "@Delta Worst experience ever! Will never fly again!",
            "Delta",
            "negative",
        ),
        (
            "@JetBlue Thank you for the smooth flight and excellent service!",
            "JetBlue",
            "positive",
        ),
    )

    # Split the row table back into the three parallel arrays fit() expects.
    tweet_column, airline_column, sentiment_column = zip(*example_rows)
    texts = np.array(list(tweet_column))
    airlines = np.array(list(airline_column))
    sentiments = np.array(list(sentiment_column))

    example_model = TweetSentimentModel()
    example_model.fit(texts, airlines, sentiments)
    return example_model
if __name__ == "__main__":
    # Demo: obtain a model (the example model, since no path is given) and
    # run it on a few sample tweets.
    print("加载模型...")
    model = load_model()

    print("\n测试预测...")
    test_texts = np.array([
        "@United This is terrible!",
        "@Southwest Thank you so much!",
        "@American How do I check in?",
    ])
    test_airlines = np.array(["United", "Southwest", "American"])

    predictions = model.predict(test_texts, test_airlines)
    probabilities = model.predict_proba(test_texts, test_airlines)

    # Emit one four-line report per sample, preceded by a blank line.
    for idx in range(min(len(test_texts), len(test_airlines), len(predictions), len(probabilities))):
        text = test_texts[idx]
        airline = test_airlines[idx]
        pred = predictions[idx]
        prob = probabilities[idx]
        print(
            f"\n文本: {text}",
            f"航空公司: {airline}",
            f"预测: {pred}",
            f"概率: {prob}",
            sep="\n",
        )