删除 src/train_tweet_ultimate.py
This commit is contained in:
parent
67fb73e011
commit
9cc826963b
@@ -1,287 +0,0 @@
|
||||
"""推文情感分析模型训练和加载模块
|
||||
|
||||
实现基于 TF-IDF + LightGBM 的情感分类模型。
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import polars as pl
|
||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
import lightgbm as lgb
|
||||
import joblib
|
||||
|
||||
|
||||
class TweetSentimentModel:
    """Tweet sentiment classifier built from TF-IDF text features and LightGBM.

    The feature matrix is the TF-IDF representation of the tweet text with one
    extra column appended: the airline name, integer-encoded by a
    ``LabelEncoder``. Sentiment labels are integer-encoded for training and
    decoded back to their original values in :meth:`predict`.
    """

    def __init__(
        self,
        tfidf_vectorizer: Optional[TfidfVectorizer] = None,
        label_encoder: Optional[LabelEncoder] = None,
        airline_encoder: Optional[LabelEncoder] = None,
        classifier: Optional[lgb.LGBMClassifier] = None,
    ):
        """Set up the model components, creating defaults for any not supplied.

        Args:
            tfidf_vectorizer: TF-IDF vectorizer for the tweet text.
            label_encoder: Encoder for the sentiment labels.
            airline_encoder: Encoder for the airline names.
            classifier: LightGBM classifier.
        """
        # Compare against None explicitly: ``x or Default()`` would silently
        # discard a caller-supplied component that happens to evaluate falsy
        # (for example an object that defines ``__len__`` and is empty).
        if tfidf_vectorizer is None:
            tfidf_vectorizer = TfidfVectorizer(
                max_features=5000,
                ngram_range=(1, 2),
                min_df=2,
                max_df=0.95,
            )
        if label_encoder is None:
            label_encoder = LabelEncoder()
        if airline_encoder is None:
            airline_encoder = LabelEncoder()
        if classifier is None:
            classifier = lgb.LGBMClassifier(
                n_estimators=100,
                learning_rate=0.1,
                max_depth=6,
                random_state=42,
                verbose=-1,
            )

        self.tfidf_vectorizer = tfidf_vectorizer
        self.label_encoder = label_encoder
        self.airline_encoder = airline_encoder
        self.classifier = classifier
        # Flipped to True by fit(); restored from disk by load().
        self._is_fitted = False

    def fit(self, texts: np.ndarray, airlines: np.ndarray, sentiments: np.ndarray) -> "TweetSentimentModel":
        """Fit the encoders, the vectorizer and the classifier.

        Args:
            texts: Array of tweet texts.
            airlines: Array of airline names, aligned with ``texts``.
            sentiments: Array of sentiment labels, aligned with ``texts``.

        Returns:
            This model instance, now fitted.
        """
        # Encode the target labels and the airline column as integers.
        y = self.label_encoder.fit_transform(sentiments)
        airline_codes = self.airline_encoder.fit_transform(airlines)

        # Learn the TF-IDF vocabulary and vectorize the training texts.
        X_text = self.tfidf_vectorizer.fit_transform(texts)

        # Append the airline code as one extra feature column.
        X = self._combine_features(X_text, airline_codes.reshape(-1, 1))

        self.classifier.fit(X, y)
        self._is_fitted = True
        return self

    def predict(self, texts: np.ndarray, airlines: np.ndarray) -> np.ndarray:
        """Predict a sentiment label for each tweet.

        Args:
            texts: Array of tweet texts.
            airlines: Array of airline names, aligned with ``texts``.

        Returns:
            Array of predicted sentiment labels, decoded back to the values
            seen during training.

        Raises:
            ValueError: If the model has not been fitted.
        """
        self._require_fitted()
        y_pred = self.classifier.predict(self._build_features(texts, airlines))
        return self.label_encoder.inverse_transform(y_pred)

    def predict_proba(self, texts: np.ndarray, airlines: np.ndarray) -> np.ndarray:
        """Predict class probabilities for each tweet.

        Args:
            texts: Array of tweet texts.
            airlines: Array of airline names, aligned with ``texts``.

        Returns:
            Probability array of shape (n_samples, n_classes).

        Raises:
            ValueError: If the model has not been fitted.
        """
        self._require_fitted()
        return self.classifier.predict_proba(self._build_features(texts, airlines))

    def _require_fitted(self) -> None:
        """Raise ``ValueError`` unless the model has been fitted."""
        if not self._is_fitted:
            raise ValueError("模型尚未训练,请先调用 fit() 方法")

    def _build_features(self, texts: np.ndarray, airlines: np.ndarray):
        """Build the inference-time feature matrix with the fitted transformers.

        Shared by :meth:`predict` and :meth:`predict_proba` so both apply the
        exact same transformation.
        """
        X_text = self.tfidf_vectorizer.transform(texts)
        airline_codes = self.airline_encoder.transform(airlines)
        return self._combine_features(X_text, airline_codes.reshape(-1, 1))

    def _combine_features(self, text_features: np.ndarray, airline_features: np.ndarray) -> np.ndarray:
        """Horizontally stack the text features and the airline feature column.

        Args:
            text_features: TF-IDF text features. Despite the annotation, the
                callers in this class pass the vectorizer's sparse output.
            airline_features: Airline feature column of shape (n_samples, 1).

        Returns:
            The combined feature matrix as a SciPy sparse matrix in CSR
            format (not a dense ``np.ndarray``, despite the annotation).
        """
        # Imported locally because the module does not import scipy at top level.
        from scipy.sparse import hstack

        # Ask for CSR explicitly so the result supports row indexing and
        # slicing, which a COO matrix does not.
        return hstack([text_features, airline_features], format="csr")

    def save(self, path: Path) -> None:
        """Serialize the model components to ``path`` with joblib.

        Parent directories are created if they do not exist.

        Args:
            path: Destination file path (anything accepted by ``Path``).
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        model_data = {
            "tfidf_vectorizer": self.tfidf_vectorizer,
            "label_encoder": self.label_encoder,
            "airline_encoder": self.airline_encoder,
            "classifier": self.classifier,
            "is_fitted": self._is_fitted,
        }
        joblib.dump(model_data, path)

    @classmethod
    def load(cls, path: Path) -> "TweetSentimentModel":
        """Restore a model previously written by :meth:`save`.

        Args:
            path: Path of the saved model file.

        Returns:
            The restored model, including its fitted flag.
        """
        # SECURITY: joblib.load is pickle-based and can execute arbitrary code
        # while deserializing. Only load model files from trusted sources.
        model_data = joblib.load(path)

        model = cls(
            tfidf_vectorizer=model_data["tfidf_vectorizer"],
            label_encoder=model_data["label_encoder"],
            airline_encoder=model_data["airline_encoder"],
            classifier=model_data["classifier"],
        )
        model._is_fitted = model_data["is_fitted"]
        return model
|
||||
|
||||
|
||||
def load_model(model_path: Optional[Path] = None) -> TweetSentimentModel:
    """Load a saved model, or fall back to a freshly trained example model.

    Args:
        model_path: Optional path of a saved model file. Plain string paths
            are accepted as well as ``Path`` objects.

    Returns:
        The model loaded from ``model_path`` when that file exists; otherwise
        (no path given, or the path does not exist) an example model trained
        on the built-in sample data.
    """
    if model_path is not None:
        # Coerce to Path so string paths work too, matching the coercion
        # TweetSentimentModel.save() already applies to its argument.
        candidate = Path(model_path)
        if candidate.exists():
            return TweetSentimentModel.load(candidate)

    # No usable saved model: build one from the built-in example data.
    return _create_example_model()
|
||||
|
||||
|
||||
def _create_example_model() -> TweetSentimentModel:
    """Train and return a model on a small built-in set of example tweets.

    NOTE(review): only ten samples are used, which is presumably meant for
    demos and smoke tests rather than real predictions. Confirm before
    relying on the output.

    Returns:
        The fitted example model.
    """
    # One row per example: (airline, sentiment, tweet text).
    samples = [
        ("United", "negative", "@United This is the worst airline ever! My flight was delayed for 5 hours and no one helped!"),
        ("Southwest", "positive", "@Southwest Thank you for the amazing flight! The crew was so helpful and friendly."),
        ("American", "neutral", "@American What is the baggage policy for international flights?"),
        ("Delta", "negative", "@Delta Terrible service! Lost my luggage and no response from customer support."),
        ("JetBlue", "positive", "@JetBlue Great experience! On time departure and friendly staff."),
        ("United", "negative", "@United Why is my flight cancelled again? This is unacceptable!"),
        ("Southwest", "positive", "@Southwest Love the free snacks and great customer service!"),
        ("American", "neutral", "@American Can you help me with my booking?"),
        ("Delta", "negative", "@Delta Worst experience ever! Will never fly again!"),
        ("JetBlue", "positive", "@JetBlue Thank you for the smooth flight and excellent service!"),
    ]

    # Split the rows back into the three parallel arrays that fit() expects.
    airline_column, sentiment_column, text_column = zip(*samples)
    airlines = np.array(airline_column)
    sentiments = np.array(sentiment_column)
    texts = np.array(text_column)

    example_model = TweetSentimentModel()
    example_model.fit(texts, airlines, sentiments)
    return example_model
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Demo: obtain a model (no path given, so the example model is used)
    # and print predictions for a few sample tweets.
    print("加载模型...")
    model = load_model()

    print("\n测试预测...")
    # Each sample pairs an airline name with a tweet text.
    demo_samples = [
        ("United", "@United This is terrible!"),
        ("Southwest", "@Southwest Thank you so much!"),
        ("American", "@American How do I check in?"),
    ]
    test_airlines = np.array([airline for airline, _ in demo_samples])
    test_texts = np.array([text for _, text in demo_samples])

    predictions = model.predict(test_texts, test_airlines)
    probabilities = model.predict_proba(test_texts, test_airlines)

    # Report the predicted label and class probabilities for each sample.
    for text, airline, pred, prob in zip(test_texts, test_airlines, predictions, probabilities):
        print(
            f"\n文本: {text}",
            f"航空公司: {airline}",
            f"预测: {pred}",
            f"概率: {prob}",
            sep="\n",
        )
|
||||
Loading…
Reference in New Issue
Block a user