删除 src/train_tweet_ultimate.py
This commit is contained in:
parent
67fb73e011
commit
9cc826963b
@ -1,287 +0,0 @@
|
|||||||
"""推文情感分析模型训练和加载模块
|
|
||||||
|
|
||||||
实现基于 TF-IDF + LightGBM 的情感分类模型。
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import polars as pl
|
|
||||||
from sklearn.feature_extraction.text import TfidfVectorizer
|
|
||||||
from sklearn.preprocessing import LabelEncoder
|
|
||||||
import lightgbm as lgb
|
|
||||||
import joblib
|
|
||||||
|
|
||||||
|
|
||||||
class TweetSentimentModel:
    """Tweet sentiment classifier built on TF-IDF features and LightGBM.

    The feature matrix fed to the classifier is the TF-IDF representation of
    the tweet text with one extra trailing column holding the label-encoded
    airline name.
    """

    def __init__(
        self,
        tfidf_vectorizer: "Optional[TfidfVectorizer]" = None,
        label_encoder: "Optional[LabelEncoder]" = None,
        airline_encoder: "Optional[LabelEncoder]" = None,
        classifier: "Optional[lgb.LGBMClassifier]" = None,
    ):
        """Initialise the model, creating default components where omitted.

        Args:
            tfidf_vectorizer: TF-IDF vectorizer for the tweet text.
            label_encoder: Encoder for the sentiment labels.
            airline_encoder: Encoder for the airline names.
            classifier: LightGBM classifier.
        """
        # Compare against None explicitly: ``obj or default`` would silently
        # discard a caller-supplied component that happens to be falsy.
        if tfidf_vectorizer is None:
            tfidf_vectorizer = TfidfVectorizer(
                max_features=5000,
                ngram_range=(1, 2),
                min_df=2,
                max_df=0.95,
            )
        if label_encoder is None:
            label_encoder = LabelEncoder()
        if airline_encoder is None:
            airline_encoder = LabelEncoder()
        if classifier is None:
            classifier = lgb.LGBMClassifier(
                n_estimators=100,
                learning_rate=0.1,
                max_depth=6,
                random_state=42,
                verbose=-1,
            )

        self.tfidf_vectorizer = tfidf_vectorizer
        self.label_encoder = label_encoder
        self.airline_encoder = airline_encoder
        self.classifier = classifier
        self._is_fitted = False

    def fit(self, texts: np.ndarray, airlines: np.ndarray, sentiments: np.ndarray) -> "TweetSentimentModel":
        """Fit the encoders, the vectorizer and the classifier.

        Args:
            texts: Array of tweet texts.
            airlines: Array of airline names, aligned with ``texts``.
            sentiments: Array of sentiment labels, aligned with ``texts``.

        Returns:
            The fitted model (``self``), so calls can be chained.
        """
        # Encode the target labels.
        y = self.label_encoder.fit_transform(sentiments)

        # Encode the airline names into a single integer column.
        airline_column = self.airline_encoder.fit_transform(airlines).reshape(-1, 1)

        # Learn the TF-IDF vocabulary and vectorize the training texts.
        text_matrix = self.tfidf_vectorizer.fit_transform(texts)

        X = self._combine_features(text_matrix, airline_column)
        self.classifier.fit(X, y)

        self._is_fitted = True
        return self

    def predict(self, texts: np.ndarray, airlines: np.ndarray) -> np.ndarray:
        """Predict sentiment labels.

        Args:
            texts: Array of tweet texts.
            airlines: Array of airline names, aligned with ``texts``.

        Returns:
            Array of predicted sentiment labels, decoded back to their
            original values.

        Raises:
            ValueError: If the model has not been fitted yet.
        """
        self._require_fitted()
        encoded_predictions = self.classifier.predict(self._build_features(texts, airlines))
        return self.label_encoder.inverse_transform(encoded_predictions)

    def predict_proba(self, texts: np.ndarray, airlines: np.ndarray) -> np.ndarray:
        """Predict sentiment class probabilities.

        Args:
            texts: Array of tweet texts.
            airlines: Array of airline names, aligned with ``texts``.

        Returns:
            Probability array of shape (n_samples, n_classes).

        Raises:
            ValueError: If the model has not been fitted yet.
        """
        self._require_fitted()
        return self.classifier.predict_proba(self._build_features(texts, airlines))

    def _require_fitted(self) -> None:
        """Raise ``ValueError`` when ``fit()`` has not been run yet."""
        if not self._is_fitted:
            raise ValueError("模型尚未训练，请先调用 fit() 方法")

    def _build_features(self, texts: np.ndarray, airlines: np.ndarray):
        """Turn raw inputs into the feature matrix using the fitted components.

        Shared by ``predict`` and ``predict_proba`` so both always apply the
        same transformation steps.
        """
        text_matrix = self.tfidf_vectorizer.transform(texts)
        airline_column = self.airline_encoder.transform(airlines).reshape(-1, 1)
        return self._combine_features(text_matrix, airline_column)

    def _combine_features(self, text_features, airline_features):
        """Append the airline column to the TF-IDF text features.

        Args:
            text_features: TF-IDF text features (sparse matrix).
            airline_features: Airline feature column of shape (n_samples, 1).

        Returns:
            The horizontally stacked feature matrix in CSR format.
        """
        from scipy.sparse import hstack

        # ``hstack`` yields COO by default; request CSR explicitly so the
        # result supports row slicing and no downstream conversion is needed.
        return hstack([text_features, airline_features], format="csr")

    def save(self, path: Path) -> None:
        """Persist the model components to disk with joblib.

        Args:
            path: Destination file; missing parent directories are created.
        """
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        model_data = {
            "tfidf_vectorizer": self.tfidf_vectorizer,
            "label_encoder": self.label_encoder,
            "airline_encoder": self.airline_encoder,
            "classifier": self.classifier,
            "is_fitted": self._is_fitted,
        }

        joblib.dump(model_data, path)

    @classmethod
    def load(cls, path: Path) -> "TweetSentimentModel":
        """Load a model previously written by ``save``.

        Args:
            path: Path of the saved model file.

        Returns:
            The restored model.
        """
        # NOTE: joblib.load unpickles arbitrary Python objects. Only load
        # model files that come from a trusted source.
        model_data = joblib.load(path)

        model = cls(
            tfidf_vectorizer=model_data["tfidf_vectorizer"],
            label_encoder=model_data["label_encoder"],
            airline_encoder=model_data["airline_encoder"],
            classifier=model_data["classifier"],
        )
        model._is_fitted = model_data["is_fitted"]

        return model
|
|
||||||
|
|
||||||
|
|
||||||
def load_model(model_path: Optional[Path] = None) -> TweetSentimentModel:
    """Load a saved model, falling back to a small example model.

    Args:
        model_path: Location of a saved model file (optional). ``str`` paths
            are accepted as well as ``Path`` objects.

    Returns:
        The model loaded from ``model_path`` when that path exists; otherwise
        a model freshly trained on the built-in example data.
    """
    if model_path is not None:
        # Coerce to Path so plain string paths work, as they do in
        # TweetSentimentModel.save().
        candidate = Path(model_path)
        if candidate.exists():
            return TweetSentimentModel.load(candidate)

    # No usable saved model: train and return the example model instead.
    return _create_example_model()
|
|
||||||
|
|
||||||
|
|
||||||
def _create_example_model() -> TweetSentimentModel:
    """Create an example model trained on a tiny built-in sample set.

    Returns:
        A fitted example model.
    """
    # Example data: (tweet text, airline, sentiment) kept together so the
    # three arrays cannot drift out of alignment.
    samples = [
        ("@United This is the worst airline ever! My flight was delayed for 5 hours and no one helped!", "United", "negative"),
        ("@Southwest Thank you for the amazing flight! The crew was so helpful and friendly.", "Southwest", "positive"),
        ("@American What is the baggage policy for international flights?", "American", "neutral"),
        ("@Delta Terrible service! Lost my luggage and no response from customer support.", "Delta", "negative"),
        ("@JetBlue Great experience! On time departure and friendly staff.", "JetBlue", "positive"),
        ("@United Why is my flight cancelled again? This is unacceptable!", "United", "negative"),
        ("@Southwest Love the free snacks and great customer service!", "Southwest", "positive"),
        ("@American Can you help me with my booking?", "American", "neutral"),
        ("@Delta Worst experience ever! Will never fly again!", "Delta", "negative"),
        ("@JetBlue Thank you for the smooth flight and excellent service!", "JetBlue", "positive"),
    ]
    texts = np.array([text for text, _, _ in samples])
    airlines = np.array([airline for _, airline, _ in samples])
    sentiments = np.array([sentiment for _, _, sentiment in samples])

    # LightGBM's default min_child_samples (20) is larger than this 10-row
    # sample set, so with the default classifier no tree could ever split and
    # the model would only output class priors. Relax the minimum leaf/bin
    # sizes for this toy dataset; the other hyper-parameters mirror the
    # TweetSentimentModel defaults.
    classifier = lgb.LGBMClassifier(
        n_estimators=100,
        learning_rate=0.1,
        max_depth=6,
        random_state=42,
        verbose=-1,
        min_child_samples=1,
        min_data_in_bin=1,
    )

    # Train the model.
    model = TweetSentimentModel(classifier=classifier)
    model.fit(texts, airlines, sentiments)

    return model
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    # Demo: obtain a model (the example model, since no path is given) and
    # run it on a few sample tweets.
    print("加载模型...")
    model = load_model()

    print("\n测试预测...")
    test_airlines = np.array(["United", "Southwest", "American"])
    test_texts = np.array(
        [
            "@United This is terrible!",
            "@Southwest Thank you so much!",
            "@American How do I check in?",
        ]
    )

    predictions = model.predict(test_texts, test_airlines)
    probabilities = model.predict_proba(test_texts, test_airlines)

    # Report each sample together with its predicted label and probabilities.
    for row, pred in enumerate(predictions):
        print(f"\n文本: {test_texts[row]}")
        print(f"航空公司: {test_airlines[row]}")
        print(f"预测: {pred}")
        print(f"概率: {probabilities[row]}")
|
|
||||||
Loading…
Reference in New Issue
Block a user