添加项目所有文件
This commit is contained in:
parent
58437d6a48
commit
47e287c011
26
.gitignore
vendored
Normal file
26
.gitignore
vendored
Normal file
@ -0,0 +1,26 @@
|
||||
# ===== 环境变量(绝对不能提交!)=====
|
||||
.env
|
||||
|
||||
# ===== Python 虚拟环境 =====
|
||||
.venv/
|
||||
venv/
|
||||
__pycache__/
|
||||
*.pyc
|
||||
*.pyo
|
||||
.pytest_cache/
|
||||
|
||||
# ===== IDE 配置 =====
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
|
||||
# ===== macOS 系统文件 =====
|
||||
.DS_Store
|
||||
|
||||
# ===== Jupyter =====
|
||||
.ipynb_checkpoints/
|
||||
|
||||
# ===== 超大文件(超过 10MB 需手动添加)=====
|
||||
# 如果你的数据或模型文件超过 10MB,请在下面添加:
|
||||
# data/large_dataset.csv
|
||||
# models/large_model.pkl
|
||||
53
pyproject.toml
Normal file
53
pyproject.toml
Normal file
@ -0,0 +1,53 @@
|
||||
[project]
|
||||
name = "ml-course-design"
|
||||
version = "0.1.0"
|
||||
description = "机器学习 × LLM × Agent 课程设计模板"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"pydantic>=2.10",
|
||||
"pandera>=0.21",
|
||||
"pydantic-ai>=0.7",
|
||||
"polars>=1.0",
|
||||
"pandas>=2.2",
|
||||
"scikit-learn>=1.5",
|
||||
"lightgbm>=4.5",
|
||||
"seaborn>=0.13",
|
||||
"joblib>=1.4",
|
||||
"python-dotenv>=1.0",
|
||||
"streamlit>=1.40",
|
||||
"xgboost>=3.1.3",
|
||||
"httpx>=0.27.0",
|
||||
"aiohttp>=3.9.0",
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
name = "tencent"
|
||||
url = "https://mirrors.cloud.tencent.com/pypi/simple/"
|
||||
default = true
|
||||
|
||||
[[tool.uv.index]]
|
||||
url = "https://mirrors.aliyun.com/pypi/simple/"
|
||||
|
||||
[dependency-groups]
|
||||
dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-asyncio>=1.3",
|
||||
"ruff>=0.8",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["src"]
|
||||
|
||||
[tool.ruff]
|
||||
line-length = 100
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = ["E", "F", "I"]
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests"]
|
||||
109
src/airline_detector.py
Normal file
109
src/airline_detector.py
Normal file
@ -0,0 +1,109 @@
|
||||
"""智能航空公司识别模块
|
||||
|
||||
自动从推文文本中识别涉及的航空公司,提供更智能的分析体验。
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional, List, Dict
|
||||
|
||||
|
||||
class AirlineDetector:
|
||||
"""航空公司识别器"""
|
||||
|
||||
def __init__(self):
|
||||
# 航空公司名称映射(包含常见变体)
|
||||
self.airline_patterns = {
|
||||
"United": [
|
||||
r"@[Uu]nited", r"@[Uu]nitedAir", r"united airlines?",
|
||||
r"united air", r"#united"
|
||||
],
|
||||
"Delta": [
|
||||
r"@[Dd]elta", r"@[Dd]eltaAir", r"delta airlines?",
|
||||
r"delta air", r"#delta"
|
||||
],
|
||||
"American": [
|
||||
r"@[Aa]merican", r"@[Aa]mericanAir", r"american airlines?",
|
||||
r"american air", r"#american"
|
||||
],
|
||||
"Southwest": [
|
||||
r"@[Ss]outhwest", r"@[Ss]outhwestAir", r"southwest airlines?",
|
||||
r"southwest air", r"#southwest"
|
||||
],
|
||||
"US Airways": [
|
||||
r"@[Uu][Ss][Aa]irways", r"us airways", r"usairways",
|
||||
r"#usairways"
|
||||
],
|
||||
"JetBlue": [
|
||||
r"@[Jj]etblue", r"@[Jj]etblueair", r"jetblue airlines?",
|
||||
r"jetblue air", r"#jetblue"
|
||||
],
|
||||
"Virgin America": [
|
||||
r"@[Vv]irginamerica", r"virgin america", r"virginamerica",
|
||||
r"#virginamerica"
|
||||
]
|
||||
}
|
||||
|
||||
def detect_airlines(self, text: str) -> List[str]:
|
||||
"""检测推文中提到的所有航空公司"""
|
||||
|
||||
detected_airlines = []
|
||||
|
||||
for airline, patterns in self.airline_patterns.items():
|
||||
for pattern in patterns:
|
||||
if re.search(pattern, text, re.IGNORECASE):
|
||||
if airline not in detected_airlines:
|
||||
detected_airlines.append(airline)
|
||||
break # 找到一个匹配就继续下一个航空公司
|
||||
|
||||
return detected_airlines
|
||||
|
||||
def get_primary_airline(self, text: str) -> Optional[str]:
|
||||
"""获取主要的航空公司(基于出现频率和位置)"""
|
||||
|
||||
airlines = self.detect_airlines(text)
|
||||
|
||||
if not airlines:
|
||||
return None
|
||||
|
||||
# 简单的优先级策略
|
||||
# 1. 出现在开头的航空公司优先
|
||||
text_lower = text.lower()
|
||||
for airline in airlines:
|
||||
airline_lower = airline.lower()
|
||||
if text_lower.startswith(f"@{airline_lower}") or \
|
||||
text_lower.startswith(airline_lower):
|
||||
return airline
|
||||
|
||||
# 2. 返回第一个检测到的航空公司
|
||||
return airlines[0]
|
||||
|
||||
def analyze_airline_context(self, text: str) -> Dict:
|
||||
"""分析航空公司的上下文信息"""
|
||||
|
||||
airlines = self.detect_airlines(text)
|
||||
|
||||
return {
|
||||
"detected_airlines": airlines,
|
||||
"airline_count": len(airlines),
|
||||
"primary_airline": self.get_primary_airline(text) if airlines else None,
|
||||
"is_multiple_airlines": len(airlines) > 1,
|
||||
"has_airline_mention": len(airlines) > 0
|
||||
}
|
||||
|
||||
|
||||
def extract_airline_context(text: str) -> Dict:
|
||||
"""提取航空公司上下文信息的便捷函数"""
|
||||
detector = AirlineDetector()
|
||||
return detector.analyze_airline_context(text)
|
||||
|
||||
|
||||
def get_suggested_airline(text: str) -> str:
|
||||
"""获取建议的航空公司(用于兼容现有接口)"""
|
||||
detector = AirlineDetector()
|
||||
primary_airline = detector.get_primary_airline(text)
|
||||
|
||||
if primary_airline:
|
||||
return primary_airline
|
||||
else:
|
||||
# 如果没有检测到航空公司,使用默认值或根据内容推断
|
||||
return "通用航空公司" # 或根据其他上下文推断
|
||||
76
src/config.py
Normal file
76
src/config.py
Normal file
@ -0,0 +1,76 @@
|
||||
"""项目配置文件"""
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class Config:
|
||||
"""全局配置类"""
|
||||
|
||||
# API 配置 - 使用属性动态获取环境变量
|
||||
@classmethod
|
||||
def get_api_key(cls) -> Optional[str]:
|
||||
"""动态获取API密钥"""
|
||||
return os.getenv("DEEPSEEK_API_KEY")
|
||||
|
||||
@classmethod
|
||||
def get_base_url(cls) -> str:
|
||||
"""动态获取API基础URL"""
|
||||
return os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1")
|
||||
|
||||
DEEPSEEK_MODEL: str = "deepseek-chat"
|
||||
|
||||
# 应用配置
|
||||
MAX_RETRY_ATTEMPTS: int = 3
|
||||
REQUEST_TIMEOUT: int = 30
|
||||
CACHE_EXPIRY: int = 300 # 5分钟缓存
|
||||
|
||||
# 情感分析配置
|
||||
SENTIMENT_THRESHOLDS = {
|
||||
"high_confidence": 0.8,
|
||||
"medium_confidence": 0.6,
|
||||
"low_confidence": 0.4
|
||||
}
|
||||
|
||||
# 航空公司配置
|
||||
AIRLINES = [
|
||||
"United", "US Airways", "American", "Southwest", "Delta", "Virgin America",
|
||||
"JetBlue", "Alaska", "Frontier", "Spirit", "Hawaiian"
|
||||
]
|
||||
|
||||
# UI 配置
|
||||
MAX_DISPLAY_FACTORS: int = 5
|
||||
MAX_DISPLAY_ACTIONS: int = 4
|
||||
|
||||
@classmethod
|
||||
def validate_config(cls) -> bool:
|
||||
"""验证配置是否有效"""
|
||||
api_key = cls.get_api_key()
|
||||
if not api_key:
|
||||
print("⚠️ 警告: DEEPSEEK_API_KEY 未配置")
|
||||
return False
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def validate_config_with_message(cls) -> tuple[bool, str]:
|
||||
"""验证配置并返回详细消息"""
|
||||
api_key = cls.get_api_key()
|
||||
base_url = cls.get_base_url()
|
||||
|
||||
if not api_key:
|
||||
return False, "API 密钥未配置,请检查 .env 文件"
|
||||
|
||||
# 检查API密钥格式
|
||||
if not api_key.startswith("sk-"):
|
||||
return False, "API 密钥格式不正确,应以 'sk-' 开头"
|
||||
|
||||
# 检查URL格式
|
||||
if base_url and not base_url.startswith("http"):
|
||||
return False, "API 地址格式不正确,应以 'http' 或 'https' 开头"
|
||||
|
||||
return True, "API 配置正常"
|
||||
|
||||
@classmethod
|
||||
def get_available_airlines(cls) -> list[str]:
|
||||
"""获取可用的航空公司列表"""
|
||||
return cls.AIRLINES
|
||||
241
src/deepseek_agent.py
Normal file
241
src/deepseek_agent.py
Normal file
@ -0,0 +1,241 @@
|
||||
"""DeepSeek API 驱动的智能情感分析 Agent
|
||||
|
||||
使用 DeepSeek API 实现实时、智能的推文情感分析、解释和处置方案生成。
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Optional, Dict, Any
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class DeepSeekClient:
|
||||
"""DeepSeek API 客户端"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
|
||||
self.api_key = api_key or os.getenv("DEEPSEEK_API_KEY")
|
||||
self.base_url = base_url or os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com")
|
||||
|
||||
if not self.api_key:
|
||||
raise ValueError("DeepSeek API Key 未配置,请设置 DEEPSEEK_API_KEY 环境变量")
|
||||
|
||||
async def chat_completion(self, messages: list, model: str = "deepseek-chat", temperature: float = 0.7) -> str:
|
||||
"""调用 DeepSeek API 进行对话补全"""
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=data
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return result["choices"][0]["message"]["content"]
|
||||
else:
|
||||
raise Exception(f"DeepSeek API 调用失败: {response.status_code} - {response.text}")
|
||||
|
||||
|
||||
class SentimentAnalysisResult(BaseModel):
|
||||
"""情感分析结果"""
|
||||
sentiment: str = Field(description="情感类别: negative/neutral/positive")
|
||||
confidence: float = Field(description="置信度 (0-1)")
|
||||
reasoning: str = Field(description="情感判断的推理过程")
|
||||
key_factors: list[str] = Field(description="影响情感判断的关键因素")
|
||||
|
||||
|
||||
class DisposalPlan(BaseModel):
|
||||
"""处置方案"""
|
||||
priority: str = Field(description="处理优先级: high/medium/low")
|
||||
action_type: str = Field(description="行动类型: response/investigate/monitor/ignore")
|
||||
suggested_response: Optional[str] = Field(description="建议回复内容", default=None)
|
||||
follow_up_actions: list[str] = Field(description="后续行动建议")
|
||||
reasoning: str = Field(description="处置方案制定的理由")
|
||||
|
||||
|
||||
class TweetAnalysisResult(BaseModel):
|
||||
"""推文分析完整结果"""
|
||||
tweet_text: str = Field(description="原始推文文本")
|
||||
airline: str = Field(description="航空公司")
|
||||
sentiment_analysis: SentimentAnalysisResult = Field(description="情感分析结果")
|
||||
disposal_plan: DisposalPlan = Field(description="处置方案")
|
||||
|
||||
|
||||
class DeepSeekTweetAgent:
|
||||
"""基于 DeepSeek API 的智能推文分析 Agent"""
|
||||
|
||||
def __init__(self, api_key: Optional[str] = None, base_url: Optional[str] = None):
|
||||
self.client = DeepSeekClient(api_key, base_url)
|
||||
|
||||
async def analyze_sentiment(self, text: str, airline: str) -> SentimentAnalysisResult:
|
||||
"""使用 DeepSeek API 进行情感分析"""
|
||||
|
||||
prompt = f"""
|
||||
请分析以下航空推文的情感倾向,并给出详细的分析过程:
|
||||
|
||||
推文内容:"{text}"
|
||||
航空公司:{airline}
|
||||
|
||||
请按照以下格式输出分析结果:
|
||||
1. 情感类别:negative/neutral/positive
|
||||
2. 置信度:0-1之间的数值
|
||||
3. 关键因素:列出影响情感判断的关键因素(最多5个)
|
||||
4. 推理过程:详细说明情感判断的推理过程
|
||||
|
||||
请确保分析准确、客观,并考虑航空行业的特殊性。
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一位专业的航空行业情感分析专家,擅长分析推文中的情感倾向和潜在问题。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
response = await self.client.chat_completion(messages, temperature=0.3)
|
||||
|
||||
# 解析响应并构建结果
|
||||
return self._parse_sentiment_response(response, text, airline)
|
||||
|
||||
async def generate_disposal_plan(self, text: str, airline: str, sentiment_result: SentimentAnalysisResult) -> DisposalPlan:
|
||||
"""生成处置方案"""
|
||||
|
||||
prompt = f"""
|
||||
基于以下推文分析和情感判断结果,为航空公司制定一个合理的处置方案:
|
||||
|
||||
推文内容:"{text}"
|
||||
航空公司:{airline}
|
||||
情感分析结果:
|
||||
- 情感类别:{sentiment_result.sentiment}
|
||||
- 置信度:{sentiment_result.confidence}
|
||||
- 关键因素:{', '.join(sentiment_result.key_factors)}
|
||||
- 推理过程:{sentiment_result.reasoning}
|
||||
|
||||
请按照以下格式输出处置方案:
|
||||
1. 优先级:high/medium/low
|
||||
2. 行动类型:response/investigate/monitor/ignore
|
||||
3. 建议回复:如果行动类型是response,提供具体的回复建议
|
||||
4. 后续行动:列出2-4个具体的后续行动建议
|
||||
5. 制定理由:说明为什么制定这样的处置方案
|
||||
|
||||
请确保处置方案符合航空行业的服务标准和客户关系管理最佳实践。
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": "你是一位航空公司的客户服务专家,擅长制定合理的客户反馈处置方案。"},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
response = await self.client.chat_completion(messages, temperature=0.5)
|
||||
|
||||
return self._parse_disposal_response(response)
|
||||
|
||||
async def analyze_tweet(self, text: str, airline: str) -> TweetAnalysisResult:
|
||||
"""完整的推文分析流程"""
|
||||
|
||||
# 1. 情感分析
|
||||
sentiment_result = await self.analyze_sentiment(text, airline)
|
||||
|
||||
# 2. 生成处置方案
|
||||
disposal_plan = await self.generate_disposal_plan(text, airline, sentiment_result)
|
||||
|
||||
# 3. 返回完整结果
|
||||
return TweetAnalysisResult(
|
||||
tweet_text=text,
|
||||
airline=airline,
|
||||
sentiment_analysis=sentiment_result,
|
||||
disposal_plan=disposal_plan
|
||||
)
|
||||
|
||||
def _parse_sentiment_response(self, response: str, text: str, airline: str) -> SentimentAnalysisResult:
|
||||
"""解析情感分析响应"""
|
||||
|
||||
# 简化解析逻辑,实际应用中可以使用更复杂的解析
|
||||
lines = response.strip().split('\n')
|
||||
|
||||
sentiment = "neutral"
|
||||
confidence = 0.5
|
||||
key_factors = []
|
||||
reasoning = ""
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line.startswith("情感类别:"):
|
||||
sentiment = line.replace("情感类别:", "").strip().lower()
|
||||
elif line.startswith("置信度:"):
|
||||
try:
|
||||
confidence = float(line.replace("置信度:", "").strip())
|
||||
except:
|
||||
confidence = 0.5
|
||||
elif line.startswith("关键因素:"):
|
||||
factors = line.replace("关键因素:", "").strip()
|
||||
key_factors = [f.strip() for f in factors.split(',') if f.strip()]
|
||||
elif line.startswith("推理过程:"):
|
||||
reasoning = line.replace("推理过程:", "").strip()
|
||||
|
||||
# 如果解析失败,使用默认值
|
||||
if not reasoning:
|
||||
reasoning = f"基于推文内容和航空行业特点进行综合分析,判断情感倾向为{sentiment}。"
|
||||
|
||||
return SentimentAnalysisResult(
|
||||
sentiment=sentiment,
|
||||
confidence=confidence,
|
||||
reasoning=reasoning,
|
||||
key_factors=key_factors
|
||||
)
|
||||
|
||||
def _parse_disposal_response(self, response: str) -> DisposalPlan:
|
||||
"""解析处置方案响应"""
|
||||
|
||||
lines = response.strip().split('\n')
|
||||
|
||||
priority = "medium"
|
||||
action_type = "monitor"
|
||||
suggested_response = None
|
||||
follow_up_actions = []
|
||||
reasoning = ""
|
||||
|
||||
for line in lines:
|
||||
line = line.strip()
|
||||
if line.startswith("优先级:"):
|
||||
priority = line.replace("优先级:", "").strip().lower()
|
||||
elif line.startswith("行动类型:"):
|
||||
action_type = line.replace("行动类型:", "").strip().lower()
|
||||
elif line.startswith("建议回复:"):
|
||||
suggested_response = line.replace("建议回复:", "").strip()
|
||||
elif line.startswith("后续行动:"):
|
||||
actions = line.replace("后续行动:", "").strip()
|
||||
follow_up_actions = [a.strip() for a in actions.split(',') if a.strip()]
|
||||
elif line.startswith("制定理由:"):
|
||||
reasoning = line.replace("制定理由:", "").strip()
|
||||
|
||||
if not reasoning:
|
||||
reasoning = "基于情感分析结果和航空行业最佳实践制定此处置方案。"
|
||||
|
||||
return DisposalPlan(
|
||||
priority=priority,
|
||||
action_type=action_type,
|
||||
suggested_response=suggested_response,
|
||||
follow_up_actions=follow_up_actions,
|
||||
reasoning=reasoning
|
||||
)
|
||||
|
||||
|
||||
# 同步版本的包装函数(为了兼容现有接口)
|
||||
import asyncio
|
||||
|
||||
|
||||
def analyze_tweet_deepseek(text: str, airline: str) -> TweetAnalysisResult:
|
||||
"""同步版本的推文分析函数"""
|
||||
agent = DeepSeekTweetAgent()
|
||||
return asyncio.run(agent.analyze_tweet(text, airline))
|
||||
493
src/deepseek_agent_optimized.py
Normal file
493
src/deepseek_agent_optimized.py
Normal file
@ -0,0 +1,493 @@
|
||||
"""优化版 DeepSeek API 驱动的智能情感分析 Agent"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import re
|
||||
from typing import Optional, Dict, Any
|
||||
from functools import lru_cache
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.config import Config
|
||||
|
||||
|
||||
class APIError(Exception):
|
||||
"""API 错误异常类"""
|
||||
def __init__(self, message: str, status_code: Optional[int] = None):
|
||||
self.message = message
|
||||
self.status_code = status_code
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class DeepSeekClient:
|
||||
"""优化版 DeepSeek API 客户端"""
|
||||
|
||||
def __init__(self):
|
||||
api_key = Config.get_api_key()
|
||||
if not api_key:
|
||||
raise ValueError("DeepSeek API Key 未配置")
|
||||
|
||||
self.api_key = api_key
|
||||
self.base_url = Config.get_base_url()
|
||||
self.model = Config.DEEPSEEK_MODEL
|
||||
|
||||
async def chat_completion_with_retry(
|
||||
self,
|
||||
messages: list,
|
||||
temperature: float = 0.3,
|
||||
max_tokens: int = 500
|
||||
) -> str:
|
||||
"""带重试机制的 API 调用"""
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
data = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": temperature,
|
||||
"max_tokens": max_tokens,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
last_error = None
|
||||
|
||||
for attempt in range(Config.MAX_RETRY_ATTEMPTS):
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=Config.REQUEST_TIMEOUT) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=data
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
return result["choices"][0]["message"]["content"]
|
||||
elif response.status_code == 401:
|
||||
raise APIError("API 密钥无效", response.status_code)
|
||||
elif response.status_code == 429:
|
||||
# 限流,等待后重试
|
||||
wait_time = 2 ** attempt # 指数退避
|
||||
await asyncio.sleep(wait_time)
|
||||
continue
|
||||
else:
|
||||
raise APIError(f"API 调用失败: {response.status_code}", response.status_code)
|
||||
|
||||
except (httpx.ConnectError, httpx.TimeoutException) as e:
|
||||
last_error = e
|
||||
if attempt < Config.MAX_RETRY_ATTEMPTS - 1:
|
||||
await asyncio.sleep(1) # 等待1秒后重试
|
||||
continue
|
||||
else:
|
||||
raise APIError(f"网络连接失败: {str(e)}")
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
if attempt < Config.MAX_RETRY_ATTEMPTS - 1:
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
else:
|
||||
raise APIError(f"API 调用异常: {str(e)}")
|
||||
|
||||
raise last_error or APIError("未知错误")
|
||||
|
||||
|
||||
class SentimentAnalysisResult(BaseModel):
|
||||
"""情感分析结果"""
|
||||
sentiment: str = Field(description="情感类别: negative/neutral/positive")
|
||||
confidence: float = Field(description="置信度 (0-1)")
|
||||
reasoning: str = Field(description="情感判断的推理过程")
|
||||
key_factors: list[str] = Field(description="影响情感判断的关键因素")
|
||||
intensity: str = Field(description="情感强度: mild/moderate/strong")
|
||||
|
||||
|
||||
class DisposalPlan(BaseModel):
|
||||
"""处置方案"""
|
||||
priority: str = Field(description="处理优先级: high/medium/low")
|
||||
action_type: str = Field(description="行动类型: response/investigate/monitor/ignore")
|
||||
suggested_response: Optional[str] = Field(description="建议回复内容", default=None)
|
||||
follow_up_actions: list[str] = Field(description="后续行动建议")
|
||||
reasoning: str = Field(description="处置方案制定的理由")
|
||||
urgency_level: str = Field(description="紧急程度: immediate/soon/normal")
|
||||
|
||||
|
||||
class TweetAnalysisResult(BaseModel):
|
||||
"""推文分析完整结果"""
|
||||
tweet_text: str = Field(description="原始推文文本")
|
||||
airline: str = Field(description="航空公司")
|
||||
sentiment_analysis: SentimentAnalysisResult = Field(description="情感分析结果")
|
||||
disposal_plan: DisposalPlan = Field(description="处置方案")
|
||||
processing_time: float = Field(description="处理耗时(秒)")
|
||||
api_used: bool = Field(description="是否使用了 API")
|
||||
|
||||
|
||||
class ResponseParser:
|
||||
"""API 响应解析器"""
|
||||
|
||||
@staticmethod
|
||||
def parse_sentiment_response(response: str) -> Dict[str, Any]:
|
||||
"""解析情感分析响应"""
|
||||
|
||||
# 使用正则表达式进行更精确的解析
|
||||
patterns = {
|
||||
"sentiment": r"情感类别[::]\s*(negative|neutral|positive)",
|
||||
"confidence": r"置信度[::]\s*([0-9]*\.?[0-9]+)",
|
||||
"intensity": r"情感强度[::]\s*(mild|moderate|strong)",
|
||||
}
|
||||
|
||||
result = {}
|
||||
|
||||
for key, pattern in patterns.items():
|
||||
match = re.search(pattern, response, re.IGNORECASE)
|
||||
if match:
|
||||
result[key] = match.group(1).lower() if key != "confidence" else float(match.group(1))
|
||||
|
||||
# 解析关键因素
|
||||
factors_match = re.search(r"关键因素[::]([^\n]*)(?:\n|$)", response)
|
||||
if factors_match:
|
||||
factors_text = factors_match.group(1).strip()
|
||||
result["key_factors"] = [f.strip() for f in factors_text.split(",") if f.strip()]
|
||||
else:
|
||||
result["key_factors"] = []
|
||||
|
||||
# 提取推理过程
|
||||
reasoning_match = re.search(r"推理过程[::]([^\n]*)(?:\n|$)", response)
|
||||
if reasoning_match:
|
||||
result["reasoning"] = reasoning_match.group(1).strip()
|
||||
else:
|
||||
# 如果找不到,使用默认推理
|
||||
result["reasoning"] = "基于推文内容和航空行业特点进行综合分析"
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def parse_disposal_response(response: str) -> Dict[str, Any]:
|
||||
"""解析处置方案响应"""
|
||||
|
||||
patterns = {
|
||||
"priority": r"优先级[::]\s*(high|medium|low)",
|
||||
"action_type": r"行动类型[::]\s*(response|investigate|monitor|ignore)",
|
||||
"urgency_level": r"紧急程度[::]\s*(immediate|soon|normal)",
|
||||
}
|
||||
|
||||
result = {}
|
||||
|
||||
for key, pattern in patterns.items():
|
||||
match = re.search(pattern, response, re.IGNORECASE)
|
||||
if match:
|
||||
result[key] = match.group(1).lower()
|
||||
|
||||
# 解析建议回复
|
||||
response_match = re.search(r"建议回复[::]([^\n]*)(?:\n|$)", response)
|
||||
if response_match:
|
||||
result["suggested_response"] = response_match.group(1).strip()
|
||||
|
||||
# 解析后续行动
|
||||
actions_match = re.search(r"后续行动[::]([^\n]*)(?:\n|$)", response)
|
||||
if actions_match:
|
||||
actions_text = actions_match.group(1).strip()
|
||||
result["follow_up_actions"] = [a.strip() for a in actions_text.split(",") if a.strip()]
|
||||
else:
|
||||
result["follow_up_actions"] = []
|
||||
|
||||
# 解析制定理由
|
||||
reasoning_match = re.search(r"制定理由[::]([^\n]*)(?:\n|$)", response)
|
||||
if reasoning_match:
|
||||
result["reasoning"] = reasoning_match.group(1).strip()
|
||||
else:
|
||||
result["reasoning"] = "基于情感分析结果制定"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class OptimizedDeepSeekTweetAgent:
|
||||
"""优化版 DeepSeek 推文分析 Agent"""
|
||||
|
||||
def __init__(self):
|
||||
self.client = DeepSeekClient()
|
||||
self.parser = ResponseParser()
|
||||
|
||||
async def analyze_sentiment(self, text: str, airline: str) -> SentimentAnalysisResult:
|
||||
"""优化版情感分析"""
|
||||
|
||||
prompt = f"""
|
||||
你是一位专业的航空行业情感分析专家。请分析以下推文的情感倾向:
|
||||
|
||||
推文内容:"{text}"
|
||||
航空公司:{airline}
|
||||
|
||||
请严格按照以下JSON格式输出分析结果:
|
||||
{{
|
||||
"sentiment": "negative/neutral/positive",
|
||||
"confidence": 0.0-1.0之间的数值,
|
||||
"intensity": "mild/moderate/strong",
|
||||
"key_factors": ["因素1", "因素2", "因素3"],
|
||||
"reasoning": "详细的情感判断推理过程"
|
||||
}}
|
||||
|
||||
分析要求:
|
||||
1. 情感判断要准确反映推文的真实情感
|
||||
2. 置信度要基于推文的明确程度和情感强度
|
||||
3. 关键因素要具体、相关
|
||||
4. 推理过程要详细、有逻辑
|
||||
|
||||
请只输出JSON格式的结果,不要有其他内容。
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一位专业的航空行业情感分析专家,擅长准确识别推文中的情感倾向。"
|
||||
},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
response = await self.client.chat_completion_with_retry(messages, temperature=0.1)
|
||||
|
||||
# 清理响应文本,移除可能的标记和空白
|
||||
cleaned_response = response.strip()
|
||||
|
||||
# 尝试解析JSON响应
|
||||
try:
|
||||
# 尝试提取JSON部分(如果响应包含其他文本)
|
||||
json_match = re.search(r'\{[^}]+\}', cleaned_response)
|
||||
if json_match:
|
||||
json_text = json_match.group(0)
|
||||
result_data = json.loads(json_text)
|
||||
else:
|
||||
result_data = json.loads(cleaned_response)
|
||||
|
||||
# 验证必需字段
|
||||
required_fields = ["sentiment", "confidence", "intensity"]
|
||||
for field in required_fields:
|
||||
if field not in result_data:
|
||||
raise ValueError(f"缺少必需字段: {field}")
|
||||
|
||||
return SentimentAnalysisResult(**result_data)
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as json_error:
|
||||
# JSON解析失败,使用正则解析
|
||||
print(f"JSON解析失败,使用正则解析: {json_error}")
|
||||
parsed_data = self.parser.parse_sentiment_response(response)
|
||||
|
||||
# 确保必需字段有默认值
|
||||
default_values = {
|
||||
"sentiment": "neutral",
|
||||
"confidence": 0.5,
|
||||
"intensity": "moderate"
|
||||
}
|
||||
|
||||
for field, default_value in default_values.items():
|
||||
if field not in parsed_data or not parsed_data[field]:
|
||||
parsed_data[field] = default_value
|
||||
|
||||
return SentimentAnalysisResult(**parsed_data)
|
||||
|
||||
except APIError as e:
|
||||
# API调用失败,返回默认结果
|
||||
print(f"API调用失败: {e.message}")
|
||||
return SentimentAnalysisResult(
|
||||
sentiment="neutral",
|
||||
confidence=0.5,
|
||||
intensity="moderate",
|
||||
key_factors=["API调用失败,使用默认分析"],
|
||||
reasoning=f"API调用失败: {e.message}"
|
||||
)
|
||||
|
||||
async def generate_disposal_plan(
|
||||
self,
|
||||
text: str,
|
||||
airline: str,
|
||||
sentiment_result: SentimentAnalysisResult
|
||||
) -> DisposalPlan:
|
||||
"""生成优化版处置方案"""
|
||||
|
||||
prompt = f"""
|
||||
基于以下推文分析和情感判断结果,为航空公司制定一个合理的处置方案:
|
||||
|
||||
推文内容:"{text}"
|
||||
航空公司:{airline}
|
||||
情感分析结果:
|
||||
- 情感类别:{sentiment_result.sentiment}
|
||||
- 置信度:{sentiment_result.confidence:.1%}
|
||||
- 情感强度:{sentiment_result.intensity}
|
||||
- 关键因素:{', '.join(sentiment_result.key_factors)}
|
||||
|
||||
请严格按照以下JSON格式输出处置方案:
|
||||
{{
|
||||
"priority": "high/medium/low",
|
||||
"action_type": "response/investigate/monitor/ignore",
|
||||
"suggested_response": "具体的回复建议(如适用)",
|
||||
"follow_up_actions": ["行动1", "行动2"],
|
||||
"reasoning": "制定此方案的理由",
|
||||
"urgency_level": "immediate/soon/normal"
|
||||
}}
|
||||
|
||||
要求:
|
||||
1. 优先级要基于情感强度和置信度
|
||||
2. 行动类型要符合航空行业最佳实践
|
||||
3. 建议回复要专业、有同理心
|
||||
4. 后续行动要具体、可执行
|
||||
|
||||
请只输出JSON格式的结果。
|
||||
"""
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "你是一位航空公司的客户服务专家,擅长制定合理的客户反馈处置方案。"
|
||||
},
|
||||
{"role": "user", "content": prompt}
|
||||
]
|
||||
|
||||
try:
|
||||
response = await self.client.chat_completion_with_retry(messages, temperature=0.3)
|
||||
|
||||
# 清理响应文本
|
||||
cleaned_response = response.strip()
|
||||
|
||||
try:
|
||||
# 尝试提取JSON部分
|
||||
json_match = re.search(r'\{[^}]+\}', cleaned_response)
|
||||
if json_match:
|
||||
json_text = json_match.group(0)
|
||||
result_data = json.loads(json_text)
|
||||
else:
|
||||
result_data = json.loads(cleaned_response)
|
||||
|
||||
# 验证必需字段
|
||||
required_fields = ["priority", "action_type", "reasoning"]
|
||||
for field in required_fields:
|
||||
if field not in result_data:
|
||||
raise ValueError(f"缺少必需字段: {field}")
|
||||
|
||||
return DisposalPlan(**result_data)
|
||||
|
||||
except (json.JSONDecodeError, ValueError) as json_error:
|
||||
# JSON解析失败,使用正则解析
|
||||
print(f"处置方案JSON解析失败,使用正则解析: {json_error}")
|
||||
parsed_data = self.parser.parse_disposal_response(response)
|
||||
|
||||
# 确保必需字段有默认值
|
||||
default_values = {
|
||||
"priority": "medium",
|
||||
"action_type": "monitor",
|
||||
"reasoning": "基于情感分析结果制定",
|
||||
"follow_up_actions": [],
|
||||
"urgency_level": "normal"
|
||||
}
|
||||
|
||||
for field, default_value in default_values.items():
|
||||
if field not in parsed_data or not parsed_data[field]:
|
||||
parsed_data[field] = default_value
|
||||
|
||||
return DisposalPlan(**parsed_data)
|
||||
|
||||
except APIError as e:
|
||||
# API调用失败,返回默认处置方案
|
||||
print(f"处置方案API调用失败: {e.message}")
|
||||
return self._generate_default_disposal_plan(sentiment_result)
|
||||
|
||||
def _generate_default_disposal_plan(self, sentiment_result: SentimentAnalysisResult) -> DisposalPlan:
|
||||
"""生成默认处置方案"""
|
||||
|
||||
if sentiment_result.sentiment == "negative":
|
||||
return DisposalPlan(
|
||||
priority="medium",
|
||||
action_type="investigate",
|
||||
suggested_response=None,
|
||||
follow_up_actions=["进一步核实情况", "根据核实结果决定行动"],
|
||||
reasoning="负面情感需要进一步调查",
|
||||
urgency_level="soon"
|
||||
)
|
||||
elif sentiment_result.sentiment == "positive":
|
||||
return DisposalPlan(
|
||||
priority="low",
|
||||
action_type="monitor",
|
||||
suggested_response=None,
|
||||
follow_up_actions=["持续关注用户动态"],
|
||||
reasoning="正面情感保持关注即可",
|
||||
urgency_level="normal"
|
||||
)
|
||||
else:
|
||||
return DisposalPlan(
|
||||
priority="low",
|
||||
action_type="monitor",
|
||||
suggested_response=None,
|
||||
follow_up_actions=["常规关注"],
|
||||
reasoning="中性情感常规处理",
|
||||
urgency_level="normal"
|
||||
)
|
||||
|
||||
async def analyze_tweet(self, text: str, airline: str) -> TweetAnalysisResult:
|
||||
"""完整的推文分析流程"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 1. 情感分析
|
||||
sentiment_result = await self.analyze_sentiment(text, airline)
|
||||
|
||||
# 2. 生成处置方案
|
||||
disposal_plan = await self.generate_disposal_plan(text, airline, sentiment_result)
|
||||
|
||||
# 3. 计算处理时间
|
||||
processing_time = time.time() - start_time
|
||||
|
||||
# 返回完整结果
|
||||
return TweetAnalysisResult(
|
||||
tweet_text=text,
|
||||
airline=airline,
|
||||
sentiment_analysis=sentiment_result,
|
||||
disposal_plan=disposal_plan,
|
||||
processing_time=processing_time,
|
||||
api_used=True
|
||||
)
|
||||
|
||||
|
||||
# 同步版本的包装函数
|
||||
async def analyze_tweet_async(text: str, airline: str) -> TweetAnalysisResult:
|
||||
"""异步版本的推文分析"""
|
||||
agent = OptimizedDeepSeekTweetAgent()
|
||||
return await agent.analyze_tweet(text, airline)
|
||||
|
||||
|
||||
def analyze_tweet_sync(text: str, airline: str) -> TweetAnalysisResult:
|
||||
"""同步版本的推文分析函数"""
|
||||
return asyncio.run(analyze_tweet_async(text, airline))
|
||||
|
||||
|
||||
# 终极版本 - 完全不需要航空公司参数
|
||||
async def analyze_tweet_ultimate_async(text: str) -> TweetAnalysisResult:
|
||||
"""终极版本异步推文分析 - 无需航空公司参数"""
|
||||
agent = OptimizedDeepSeekTweetAgent()
|
||||
|
||||
# 自动检测航空公司或使用通用标识
|
||||
airline = "通用航空公司"
|
||||
|
||||
# 简单的航空公司检测逻辑
|
||||
airline_keywords = {
|
||||
"united": "United Airlines",
|
||||
"delta": "Delta Air Lines",
|
||||
"american": "American Airlines",
|
||||
"southwest": "Southwest Airlines",
|
||||
"jetblue": "JetBlue Airways",
|
||||
"air china": "中国国际航空",
|
||||
"china eastern": "中国东方航空",
|
||||
"china southern": "中国南方航空"
|
||||
}
|
||||
|
||||
text_lower = text.lower()
|
||||
for keyword, airline_name in airline_keywords.items():
|
||||
if keyword in text_lower:
|
||||
airline = airline_name
|
||||
break
|
||||
|
||||
return await agent.analyze_tweet(text, airline)
|
||||
|
||||
|
||||
def analyze_tweet_sync_ultimate(text: str) -> TweetAnalysisResult:
|
||||
"""终极版本同步推文分析 - 完全无需航空公司参数"""
|
||||
return asyncio.run(analyze_tweet_ultimate_async(text))
|
||||
642
src/streamlit_tweet_app_smart.py
Normal file
642
src/streamlit_tweet_app_smart.py
Normal file
@ -0,0 +1,642 @@
|
||||
"""智能版 Streamlit 演示应用 - 推文情感分析
|
||||
|
||||
移除手动选择航空公司功能,实现智能识别和上下文感知分析。
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import asyncio
|
||||
import streamlit as st
|
||||
from datetime import datetime
|
||||
|
||||
# Ensure project root is in path
|
||||
sys.path.append(os.getcwd())
|
||||
|
||||
# 首先加载环境变量
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv()
|
||||
|
||||
# 然后导入配置和其他模块
|
||||
from src.config import Config
|
||||
from src.deepseek_agent_optimized import analyze_tweet_sync
|
||||
from src.tweet_agent_enhanced import analyze_tweet as analyze_tweet_enhanced
|
||||
from src.airline_detector import extract_airline_context
|
||||
|
||||
# 页面配置
|
||||
st.set_page_config(
|
||||
page_title="智能航空推文分析",
|
||||
page_icon="✈️",
|
||||
layout="wide",
|
||||
initial_sidebar_state="expanded"
|
||||
)
|
||||
|
||||
# 自定义CSS样式
|
||||
st.markdown("""
|
||||
<style>
|
||||
.main-header {
|
||||
text-align: center;
|
||||
padding: 1rem;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
border-radius: 10px;
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
.sentiment-card {
|
||||
padding: 1rem;
|
||||
border-radius: 10px;
|
||||
margin: 0.5rem 0;
|
||||
box-shadow: 0 2px 4px rgba(0,0,0,0.1);
|
||||
}
|
||||
.negative { background-color: #ff6b6b; color: white; }
|
||||
.neutral { background-color: #ffd93d; color: #333; }
|
||||
.positive { background-color: #6bcb77; color: white; }
|
||||
.priority-high { background-color: #ff4757; color: white; }
|
||||
.priority-medium { background-color: #ffa502; color: white; }
|
||||
.priority-low { background-color: #2ed573; color: white; }
|
||||
.context-info {
|
||||
background-color: #f8f9fa;
|
||||
padding: 0.5rem;
|
||||
border-radius: 5px;
|
||||
border-left: 4px solid #6c757d;
|
||||
margin: 0.5rem 0;
|
||||
}
|
||||
</style>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
# 页面标题
|
||||
st.markdown("""
|
||||
<div class="main-header">
|
||||
<h1>✈️ 智能航空推文分析系统</h1>
|
||||
<p>基于上下文感知的实时情感分析与智能处置方案</p>
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
# 侧边栏配置
|
||||
with st.sidebar:
|
||||
st.header("⚙️ 系统配置")
|
||||
|
||||
# API 状态显示
|
||||
api_valid, api_message = Config.validate_config_with_message()
|
||||
api_status = "✅ 可用" if api_valid else "❌ 不可用"
|
||||
|
||||
st.metric("DeepSeek API 状态", api_status)
|
||||
|
||||
if not api_valid:
|
||||
st.error(f"⚠️ {api_message}")
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# 分析模式选择
|
||||
st.markdown("### 🔧 分析模式")
|
||||
|
||||
# 根据API状态动态调整可用模式
|
||||
available_modes = ["🤖 智能模式 (推荐)", "⚙️ 传统模型"]
|
||||
if api_valid:
|
||||
available_modes.insert(1, "🧠 DeepSeek API")
|
||||
|
||||
analysis_mode = st.radio(
|
||||
"选择分析引擎",
|
||||
available_modes,
|
||||
help="智能模式会自动选择最优分析方式"
|
||||
)
|
||||
|
||||
# 模式描述
|
||||
mode_descriptions = {
|
||||
"🤖 智能模式 (推荐)": "自动选择最优分析方式,平衡准确性和速度",
|
||||
"🧠 DeepSeek API": "使用大语言模型进行深度语义分析",
|
||||
"⚙️ 传统模型": "使用预训练的机器学习模型"
|
||||
}
|
||||
|
||||
if analysis_mode == "🧠 DeepSeek API" and not api_valid:
|
||||
st.warning("⚠️ DeepSeek API 当前不可用,将自动切换到传统模型")
|
||||
else:
|
||||
st.info(mode_descriptions[analysis_mode])
|
||||
|
||||
st.markdown("---")
|
||||
|
||||
# 功能选择
|
||||
st.markdown("### 📋 功能菜单")
|
||||
app_mode = st.radio(
|
||||
"选择功能",
|
||||
["📝 智能分析", "📊 批量处理", "📈 分析统计"]
|
||||
)
|
||||
|
||||
# 初始化会话状态
|
||||
if "analysis_history" not in st.session_state:
|
||||
st.session_state.analysis_history = []
|
||||
|
||||
if "performance_stats" not in st.session_state:
|
||||
st.session_state.performance_stats = {
|
||||
"total_analyses": 0,
|
||||
"avg_processing_time": 0,
|
||||
"sentiment_distribution": {"negative": 0, "neutral": 0, "positive": 0}
|
||||
}
|
||||
|
||||
# 工具函数
|
||||
def get_sentiment_emoji(sentiment: str) -> str:
|
||||
"""获取情感对应的表情符号"""
|
||||
emoji_map = {
|
||||
"negative": "😠",
|
||||
"neutral": "😐",
|
||||
"positive": "😊",
|
||||
}
|
||||
return emoji_map.get(sentiment, "❓")
|
||||
|
||||
def get_priority_emoji(priority: str) -> str:
|
||||
"""获取优先级对应的表情符号"""
|
||||
emoji_map = {
|
||||
"high": "🔴",
|
||||
"medium": "🟡",
|
||||
"low": "🟢",
|
||||
}
|
||||
return emoji_map.get(priority, "⚪")
|
||||
|
||||
def get_urgency_emoji(urgency: str) -> str:
|
||||
"""获取紧急程度对应的表情符号"""
|
||||
emoji_map = {
|
||||
"immediate": "🚨",
|
||||
"soon": "⏰",
|
||||
"normal": "📅",
|
||||
}
|
||||
return emoji_map.get(urgency, "📌")
|
||||
|
||||
# 主要功能界面
|
||||
if app_mode == "📝 智能分析":
|
||||
st.markdown("### 📝 智能推文分析")
|
||||
|
||||
# 输入区域
|
||||
with st.container():
|
||||
tweet_text = st.text_area(
|
||||
"✍️ 推文内容",
|
||||
placeholder="例如:@United 航班延误了3小时,服务态度很差,非常失望...",
|
||||
height=120,
|
||||
help="请输入要分析的航空推文内容,系统会自动识别航空公司"
|
||||
)
|
||||
|
||||
# 分析按钮
|
||||
col1, col2, col3 = st.columns([1, 2, 1])
|
||||
with col2:
|
||||
analyze_btn = st.button(
|
||||
"🚀 开始智能分析",
|
||||
type="primary",
|
||||
use_container_width=True,
|
||||
disabled=not tweet_text.strip()
|
||||
)
|
||||
|
||||
if analyze_btn and tweet_text:
|
||||
# 显示分析进度
|
||||
progress_bar = st.progress(0)
|
||||
status_text = st.empty()
|
||||
|
||||
try:
|
||||
# 分析进度模拟
|
||||
for i in range(4):
|
||||
progress = (i + 1) * 25
|
||||
progress_bar.progress(progress)
|
||||
|
||||
if i == 0:
|
||||
status_text.text("🔍 正在识别上下文...")
|
||||
elif i == 1:
|
||||
status_text.text("😊 正在分析情感倾向...")
|
||||
elif i == 2:
|
||||
status_text.text("💭 正在生成解释...")
|
||||
else:
|
||||
status_text.text("🚀 正在制定处置方案...")
|
||||
|
||||
time.sleep(0.3)
|
||||
|
||||
# 自动识别航空公司上下文(不显示给用户)
|
||||
airline_context = extract_airline_context(tweet_text)
|
||||
suggested_airline = airline_context.get("primary_airline", "通用航空公司")
|
||||
|
||||
# 2. 执行分析
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
if analysis_mode == "🧠 DeepSeek API":
|
||||
if api_valid:
|
||||
result = analyze_tweet_sync(tweet_text, suggested_airline)
|
||||
else:
|
||||
st.warning("DeepSeek API 不可用,自动切换到传统模型")
|
||||
result = analyze_tweet_enhanced(tweet_text, suggested_airline, "traditional")
|
||||
elif analysis_mode == "⚙️ 传统模型":
|
||||
result = analyze_tweet_enhanced(tweet_text, suggested_airline, "traditional")
|
||||
else: # 智能模式
|
||||
if api_valid:
|
||||
result = analyze_tweet_sync(tweet_text, suggested_airline)
|
||||
else:
|
||||
result = analyze_tweet_enhanced(tweet_text, suggested_airline, "traditional")
|
||||
|
||||
processing_time = time.time() - start_time
|
||||
|
||||
except Exception as analysis_error:
|
||||
# 分析失败时的优雅降级
|
||||
processing_time = time.time() - start_time
|
||||
st.error(f"分析过程中出现错误: {str(analysis_error)}")
|
||||
|
||||
# 尝试使用传统模型作为备选
|
||||
try:
|
||||
status_text.text("🔄 尝试使用传统模型...")
|
||||
result = analyze_tweet_enhanced(tweet_text, suggested_airline, "traditional")
|
||||
st.info("✅ 已成功使用传统模型完成分析")
|
||||
except Exception as fallback_error:
|
||||
# 如果传统模型也失败,抛出错误
|
||||
raise Exception(f"所有分析模式都失败: {str(fallback_error)}")
|
||||
|
||||
# 完成进度
|
||||
progress_bar.progress(100)
|
||||
status_text.text("✅ 分析完成!")
|
||||
|
||||
# 显示结果
|
||||
st.success(f"分析完成!耗时: {processing_time:.2f}秒 | 上下文: {suggested_airline}")
|
||||
|
||||
# 情感分析结果
|
||||
st.markdown("### 😊 情感分析结果")
|
||||
|
||||
# 统一处理不同的结果对象
|
||||
if hasattr(result, 'sentiment_analysis'):
|
||||
sentiment_result = result.sentiment_analysis
|
||||
disposal_plan = result.disposal_plan
|
||||
else:
|
||||
sentiment_result = result.classification
|
||||
disposal_plan = result.disposal_plan
|
||||
|
||||
sentiment = sentiment_result.sentiment
|
||||
confidence = sentiment_result.confidence
|
||||
intensity = getattr(sentiment_result, 'intensity', 'moderate')
|
||||
|
||||
# 情感卡片
|
||||
sentiment_color = {
|
||||
"negative": "negative",
|
||||
"neutral": "neutral",
|
||||
"positive": "positive"
|
||||
}[sentiment]
|
||||
|
||||
st.markdown(f"""
|
||||
<div class="sentiment-card {sentiment_color}">
|
||||
<h2>{get_sentiment_emoji(sentiment)} {sentiment.upper()}</h2>
|
||||
<p><strong>置信度:</strong> {confidence:.1%} | <strong>情感强度:</strong> {intensity}</p>
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
# 分析详情
|
||||
col1, col2 = st.columns(2)
|
||||
|
||||
with col1:
|
||||
st.markdown("#### 🔍 关键因素")
|
||||
factors = getattr(sentiment_result, 'key_factors', [])
|
||||
if factors:
|
||||
factors = factors[:Config.MAX_DISPLAY_FACTORS]
|
||||
for factor in factors:
|
||||
st.markdown(f"- {factor}")
|
||||
else:
|
||||
st.info("未识别到明显的关键因素")
|
||||
|
||||
with col2:
|
||||
st.markdown("#### 💭 推理过程")
|
||||
reasoning = getattr(sentiment_result, 'reasoning', '基于模型分析得出结果')
|
||||
st.info(reasoning)
|
||||
|
||||
# 处置方案
|
||||
st.markdown("### 🚀 处置方案")
|
||||
|
||||
priority = disposal_plan.priority
|
||||
urgency = getattr(disposal_plan, 'urgency_level', 'normal')
|
||||
|
||||
st.markdown(f"""
|
||||
<div class="priority-{priority}" style="padding: 1rem; border-radius: 10px; color: white;">
|
||||
<h3>{get_priority_emoji(priority)} 优先级: {priority.upper()}</h3>
|
||||
<p>{get_urgency_emoji(urgency)} 紧急程度: {urgency} | 📋 行动类型: {disposal_plan.action_type}</p>
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
if disposal_plan.suggested_response:
|
||||
st.markdown("#### 💬 建议回复")
|
||||
with st.expander("查看建议回复"):
|
||||
st.success(disposal_plan.suggested_response)
|
||||
|
||||
st.markdown("#### 📋 后续行动")
|
||||
actions = disposal_plan.follow_up_actions[:Config.MAX_DISPLAY_ACTIONS]
|
||||
for i, action in enumerate(actions, 1):
|
||||
st.markdown(f"{i}. {action}")
|
||||
|
||||
st.markdown("#### 📝 制定理由")
|
||||
reasoning = getattr(disposal_plan, 'reasoning', '基于情感分析结果制定')
|
||||
st.info(reasoning)
|
||||
|
||||
# 记录分析历史
|
||||
analysis_record = {
|
||||
"timestamp": datetime.now(),
|
||||
"tweet": tweet_text[:100] + "..." if len(tweet_text) > 100 else tweet_text,
|
||||
"airline": suggested_airline,
|
||||
"sentiment": sentiment,
|
||||
"confidence": confidence,
|
||||
"processing_time": processing_time,
|
||||
"mode": analysis_mode,
|
||||
"airline_context": airline_context
|
||||
}
|
||||
st.session_state.analysis_history.insert(0, analysis_record)
|
||||
|
||||
# 更新统计信息
|
||||
st.session_state.performance_stats["total_analyses"] += 1
|
||||
st.session_state.performance_stats["sentiment_distribution"][sentiment] += 1
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"❌ 分析失败: {str(e)}")
|
||||
st.info("💡 建议检查网络连接或尝试切换分析模式")
|
||||
|
||||
elif app_mode == "📊 批量处理":
|
||||
st.markdown("### 📊 批量推文处理")
|
||||
|
||||
# 批量处理模式选择
|
||||
batch_mode = st.radio(
|
||||
"选择批量处理方式",
|
||||
["📁 上传CSV文件", "✍️ 手动输入多条数据"],
|
||||
horizontal=True
|
||||
)
|
||||
|
||||
if batch_mode == "📁 上传CSV文件":
|
||||
st.markdown("#### 📁 CSV文件上传")
|
||||
|
||||
uploaded_file = st.file_uploader(
|
||||
"上传CSV文件",
|
||||
type=['csv'],
|
||||
help="请上传包含推文数据的CSV文件,文件应包含'tweet_text'列"
|
||||
)
|
||||
|
||||
if uploaded_file is not None:
|
||||
try:
|
||||
import pandas as pd
|
||||
|
||||
# 读取CSV文件
|
||||
df = pd.read_csv(uploaded_file)
|
||||
|
||||
# 检查必要的列
|
||||
if 'tweet_text' not in df.columns:
|
||||
st.error("❌ CSV文件必须包含'tweet_text'列")
|
||||
else:
|
||||
st.success(f"✅ 成功读取文件,共 {len(df)} 条推文")
|
||||
|
||||
# 显示数据预览
|
||||
with st.expander("📋 数据预览", expanded=True):
|
||||
st.dataframe(df.head(), use_container_width=True)
|
||||
|
||||
# 批量分析按钮
|
||||
if st.button("🚀 开始批量分析", type="primary"):
|
||||
# 批量分析逻辑
|
||||
batch_results = []
|
||||
progress_bar = st.progress(0)
|
||||
|
||||
for i, row in df.iterrows():
|
||||
tweet_text = str(row['tweet_text'])
|
||||
|
||||
# 更新进度
|
||||
progress = int((i + 1) / len(df) * 100)
|
||||
progress_bar.progress(progress)
|
||||
|
||||
try:
|
||||
# 执行单条分析
|
||||
airline_context = extract_airline_context(tweet_text)
|
||||
suggested_airline = airline_context.get("primary_airline", "通用航空公司")
|
||||
|
||||
# 根据分析模式选择分析函数
|
||||
if analysis_mode == "🧠 DeepSeek API" and api_valid:
|
||||
result = analyze_tweet_sync(tweet_text, suggested_airline)
|
||||
else:
|
||||
result = analyze_tweet_enhanced(tweet_text, suggested_airline, "traditional")
|
||||
|
||||
# 统一处理结果
|
||||
if hasattr(result, 'sentiment_analysis'):
|
||||
sentiment_result = result.sentiment_analysis
|
||||
disposal_plan = result.disposal_plan
|
||||
else:
|
||||
sentiment_result = result.classification
|
||||
disposal_plan = result.disposal_plan
|
||||
|
||||
batch_results.append({
|
||||
'序号': i + 1,
|
||||
'推文': tweet_text[:100] + "..." if len(tweet_text) > 100 else tweet_text,
|
||||
'航空公司': suggested_airline,
|
||||
'情感': sentiment_result.sentiment,
|
||||
'置信度': sentiment_result.confidence,
|
||||
'优先级': disposal_plan.priority,
|
||||
'行动类型': disposal_plan.action_type
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
batch_results.append({
|
||||
'序号': i + 1,
|
||||
'推文': tweet_text[:100] + "..." if len(tweet_text) > 100 else tweet_text,
|
||||
'航空公司': "分析失败",
|
||||
'情感': "错误",
|
||||
'置信度': 0.0,
|
||||
'优先级': "未知",
|
||||
'行动类型': "检查错误"
|
||||
})
|
||||
|
||||
# 显示批量分析结果
|
||||
st.markdown("### 📊 批量分析结果")
|
||||
|
||||
# 转换为DataFrame显示
|
||||
results_df = pd.DataFrame(batch_results)
|
||||
st.dataframe(results_df, use_container_width=True)
|
||||
|
||||
# 统计信息
|
||||
st.markdown("#### 📈 分析统计")
|
||||
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
|
||||
with col1:
|
||||
st.metric("总推文数", len(df))
|
||||
with col2:
|
||||
success_count = len([r for r in batch_results if r['情感'] != "错误"])
|
||||
st.metric("成功分析", success_count)
|
||||
with col3:
|
||||
error_count = len([r for r in batch_results if r['情感'] == "错误"])
|
||||
st.metric("分析失败", error_count)
|
||||
with col4:
|
||||
success_rate = (success_count / len(df)) * 100 if len(df) > 0 else 0
|
||||
st.metric("成功率", f"{success_rate:.1f}%")
|
||||
|
||||
# 情感分布
|
||||
sentiment_counts = results_df['情感'].value_counts()
|
||||
|
||||
if not sentiment_counts.empty:
|
||||
st.markdown("#### 😊 情感分布")
|
||||
|
||||
sentiment_cols = st.columns(len(sentiment_counts))
|
||||
for idx, (sentiment, count) in enumerate(sentiment_counts.items()):
|
||||
with sentiment_cols[idx % len(sentiment_cols)]:
|
||||
st.metric(f"{get_sentiment_emoji(sentiment)} {sentiment}", count)
|
||||
|
||||
# 导出功能
|
||||
st.markdown("#### 💾 结果导出")
|
||||
|
||||
csv_data = results_df.to_csv(index=False).encode('utf-8')
|
||||
st.download_button(
|
||||
label="<EFBFBD> 下载CSV结果",
|
||||
data=csv_data,
|
||||
file_name=f"批量分析结果_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
|
||||
mime="text/csv"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
st.error(f"❌ 文件处理错误: {str(e)}")
|
||||
|
||||
else: # 手动输入多条数据
|
||||
st.markdown("#### ✍️ 手动输入多条数据")
|
||||
|
||||
# 多行文本输入
|
||||
multi_tweets = st.text_area(
|
||||
"输入多条推文(每行一条)",
|
||||
placeholder="例如:\n@United 航班延误了3小时,服务态度很差\n@Delta 服务非常专业,体验很好\n机场安检排队时间太长",
|
||||
height=200,
|
||||
help="每行输入一条推文,系统将逐条进行分析"
|
||||
)
|
||||
|
||||
if st.button("🚀 开始批量分析", type="primary") and multi_tweets.strip():
|
||||
# 分割推文
|
||||
tweets = [tweet.strip() for tweet in multi_tweets.split('\n') if tweet.strip()]
|
||||
|
||||
if tweets:
|
||||
st.success(f"✅ 识别到 {len(tweets)} 条推文")
|
||||
|
||||
# 批量分析逻辑
|
||||
batch_results = []
|
||||
progress_bar = st.progress(0)
|
||||
|
||||
for i, tweet_text in enumerate(tweets):
|
||||
# 更新进度
|
||||
progress = int((i + 1) / len(tweets) * 100)
|
||||
progress_bar.progress(progress)
|
||||
|
||||
try:
|
||||
# 执行单条分析
|
||||
airline_context = extract_airline_context(tweet_text)
|
||||
suggested_airline = airline_context.get("primary_airline", "通用航空公司")
|
||||
|
||||
# 根据分析模式选择分析函数
|
||||
if analysis_mode == "🧠 DeepSeek API" and api_valid:
|
||||
result = analyze_tweet_sync(tweet_text, suggested_airline)
|
||||
else:
|
||||
result = analyze_tweet_enhanced(tweet_text, suggested_airline, "traditional")
|
||||
|
||||
# 统一处理结果
|
||||
if hasattr(result, 'sentiment_analysis'):
|
||||
sentiment_result = result.sentiment_analysis
|
||||
disposal_plan = result.disposal_plan
|
||||
else:
|
||||
sentiment_result = result.classification
|
||||
disposal_plan = result.disposal_plan
|
||||
|
||||
batch_results.append({
|
||||
'序号': i + 1,
|
||||
'推文': tweet_text[:100] + "..." if len(tweet_text) > 100 else tweet_text,
|
||||
'航空公司': suggested_airline,
|
||||
'情感': sentiment_result.sentiment,
|
||||
'置信度': sentiment_result.confidence,
|
||||
'优先级': disposal_plan.priority,
|
||||
'行动类型': disposal_plan.action_type
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
batch_results.append({
|
||||
'序号': i + 1,
|
||||
'推文': tweet_text[:100] + "..." if len(tweet_text) > 100 else tweet_text,
|
||||
'航空公司': "分析失败",
|
||||
'情感': "错误",
|
||||
'置信度': 0.0,
|
||||
'优先级': "未知",
|
||||
'行动类型': "检查错误"
|
||||
})
|
||||
|
||||
# 显示批量分析结果
|
||||
st.markdown("### 📊 批量分析结果")
|
||||
|
||||
# 转换为DataFrame显示
|
||||
import pandas as pd
|
||||
results_df = pd.DataFrame(batch_results)
|
||||
st.dataframe(results_df, use_container_width=True)
|
||||
|
||||
# 统计信息
|
||||
st.markdown("#### 📈 分析统计")
|
||||
|
||||
col1, col2, col3, col4 = st.columns(4)
|
||||
|
||||
with col1:
|
||||
st.metric("总推文数", len(tweets))
|
||||
with col2:
|
||||
success_count = len([r for r in batch_results if r['情感'] != "错误"])
|
||||
st.metric("成功分析", success_count)
|
||||
with col3:
|
||||
error_count = len([r for r in batch_results if r['情感'] == "错误"])
|
||||
st.metric("分析失败", error_count)
|
||||
with col4:
|
||||
success_rate = (success_count / len(tweets)) * 100 if len(tweets) > 0 else 0
|
||||
st.metric("成功率", f"{success_rate:.1f}%")
|
||||
|
||||
# 情感分布
|
||||
sentiment_counts = results_df['情感'].value_counts()
|
||||
|
||||
if not sentiment_counts.empty:
|
||||
st.markdown("#### 😊 情感分布")
|
||||
|
||||
sentiment_cols = st.columns(len(sentiment_counts))
|
||||
for idx, (sentiment, count) in enumerate(sentiment_counts.items()):
|
||||
with sentiment_cols[idx % len(sentiment_cols)]:
|
||||
st.metric(f"{get_sentiment_emoji(sentiment)} {sentiment}", count)
|
||||
|
||||
# 导出功能
|
||||
st.markdown("#### 💾 结果导出")
|
||||
|
||||
csv_data = results_df.to_csv(index=False).encode('utf-8')
|
||||
st.download_button(
|
||||
label="📥 下载CSV结果",
|
||||
data=csv_data,
|
||||
file_name=f"批量分析结果_{datetime.now().strftime('%Y%m%d_%H%M%S')}.csv",
|
||||
mime="text/csv"
|
||||
)
|
||||
|
||||
elif app_mode == "📈 分析统计":
|
||||
st.markdown("### 📈 分析统计")
|
||||
|
||||
if not st.session_state.analysis_history:
|
||||
st.info("📊 暂无分析数据,请先进行智能分析")
|
||||
else:
|
||||
# 统计信息
|
||||
total_analyses = st.session_state.performance_stats["total_analyses"]
|
||||
sentiment_dist = st.session_state.performance_stats["sentiment_distribution"]
|
||||
|
||||
col1, col2, col3 = st.columns(3)
|
||||
|
||||
with col1:
|
||||
st.metric("总分析次数", total_analyses)
|
||||
with col2:
|
||||
negative_pct = (sentiment_dist["negative"] / total_analyses * 100) if total_analyses > 0 else 0
|
||||
st.metric("负面情感比例", f"{negative_pct:.1f}%")
|
||||
with col3:
|
||||
avg_time = sum([r["processing_time"] for r in st.session_state.analysis_history]) / total_analyses
|
||||
st.metric("平均处理时间", f"{avg_time:.2f}s")
|
||||
|
||||
# 航空公司统计
|
||||
airline_stats = {}
|
||||
for record in st.session_state.analysis_history:
|
||||
airline = record["airline"]
|
||||
if airline not in airline_stats:
|
||||
airline_stats[airline] = 0
|
||||
airline_stats[airline] += 1
|
||||
|
||||
if airline_stats:
|
||||
st.markdown("#### ✈️ 航空公司分布")
|
||||
for airline, count in sorted(airline_stats.items(), key=lambda x: x[1], reverse=True)[:5]:
|
||||
st.write(f"- **{airline}**: {count} 次分析")
|
||||
|
||||
# 页脚信息
|
||||
st.markdown("---")
|
||||
st.markdown("""
|
||||
<div style="text-align: center; color: #666; font-size: 0.9em;">
|
||||
<p>🚀 智能航空推文分析系统 v3.0 | 基于上下文感知的实时情感分析</p>
|
||||
<p>💡 技术支持: 机器学习课程设计项目</p>
|
||||
</div>
|
||||
""", unsafe_allow_html=True)
|
||||
358
src/tweet_agent_enhanced.py
Normal file
358
src/tweet_agent_enhanced.py
Normal file
@ -0,0 +1,358 @@
|
||||
"""增强版推文情感分析 Agent
|
||||
|
||||
支持 DeepSeek API 和传统 ML 模型的混合模式,提供更智能的分析能力。
|
||||
"""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Literal
|
||||
import asyncio
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from src.deepseek_agent import DeepSeekTweetAgent, TweetAnalysisResult as DeepSeekResult
|
||||
from src.train_tweet_ultimate import load_model as load_ultimate_model
|
||||
|
||||
|
||||
class AnalysisMode:
|
||||
"""分析模式枚举"""
|
||||
DEEPSEEK = "deepseek"
|
||||
TRADITIONAL = "traditional"
|
||||
HYBRID = "hybrid"
|
||||
|
||||
|
||||
class EnhancedSentimentClassification(BaseModel):
|
||||
"""增强版情感分类结果"""
|
||||
sentiment: str = Field(description="情感类别: negative/neutral/positive")
|
||||
confidence: float = Field(description="置信度 (0-1)")
|
||||
reasoning: str = Field(description="情感判断的推理过程")
|
||||
key_factors: list[str] = Field(description="影响情感判断的关键因素")
|
||||
analysis_mode: str = Field(description="使用的分析模式: deepseek/traditional/hybrid")
|
||||
|
||||
|
||||
class EnhancedDisposalPlan(BaseModel):
|
||||
"""增强版处置方案"""
|
||||
priority: str = Field(description="处理优先级: high/medium/low")
|
||||
action_type: str = Field(description="行动类型: response/investigate/monitor/ignore")
|
||||
suggested_response: Optional[str] = Field(description="建议回复内容", default=None)
|
||||
follow_up_actions: list[str] = Field(description="后续行动建议")
|
||||
reasoning: str = Field(description="处置方案制定的理由")
|
||||
|
||||
|
||||
class EnhancedTweetAnalysisResult(BaseModel):
|
||||
"""增强版推文分析结果"""
|
||||
tweet_text: str = Field(description="原始推文文本")
|
||||
airline: str = Field(description="航空公司")
|
||||
classification: EnhancedSentimentClassification = Field(description="情感分类结果")
|
||||
disposal_plan: EnhancedDisposalPlan = Field(description="处置方案")
|
||||
processing_time: float = Field(description="处理耗时(秒)")
|
||||
api_used: bool = Field(description="是否使用了 API")
|
||||
|
||||
|
||||
class EnhancedTweetSentimentAgent:
|
||||
"""增强版推文情感分析 Agent
|
||||
|
||||
支持 DeepSeek API 和传统 ML 模型的混合模式。
|
||||
"""
|
||||
|
||||
def __init__(self, default_mode: str = AnalysisMode.HYBRID):
|
||||
"""初始化 Agent
|
||||
|
||||
Args:
|
||||
default_mode: 默认分析模式 (deepseek/traditional/hybrid)
|
||||
"""
|
||||
self.default_mode = default_mode
|
||||
|
||||
# 初始化传统 ML 模型
|
||||
self.traditional_model = None
|
||||
try:
|
||||
self.traditional_model = load_ultimate_model()
|
||||
print("✅ 传统 ML 模型加载成功")
|
||||
except Exception as e:
|
||||
print(f"⚠️ 传统 ML 模型加载失败: {e}")
|
||||
|
||||
# 初始化 DeepSeek API 客户端
|
||||
self.deepseek_agent = None
|
||||
try:
|
||||
self.deepseek_agent = DeepSeekTweetAgent()
|
||||
print("✅ DeepSeek API 客户端初始化成功")
|
||||
except Exception as e:
|
||||
print(f"⚠️ DeepSeek API 客户端初始化失败: {e}")
|
||||
|
||||
def _traditional_classify(self, text: str, airline: str) -> EnhancedSentimentClassification:
|
||||
"""传统 ML 模型分类"""
|
||||
if not self.traditional_model:
|
||||
raise ValueError("传统 ML 模型未加载")
|
||||
|
||||
import numpy as np
|
||||
|
||||
# 预测
|
||||
sentiment = self.traditional_model.predict(np.array([text]), np.array([airline]))[0]
|
||||
|
||||
# 预测概率
|
||||
proba = self.traditional_model.predict_proba(np.array([text]), np.array([airline]))[0]
|
||||
|
||||
# 获取预测类别的置信度
|
||||
sentiment_idx = self.traditional_model.label_encoder.transform([sentiment])[0]
|
||||
confidence = float(proba[sentiment_idx])
|
||||
|
||||
# 生成简单解释
|
||||
reasoning = f"基于传统机器学习模型预测,情感倾向为{sentiment},置信度{confidence:.1%}。"
|
||||
|
||||
# 提取关键词(简化版)
|
||||
key_factors = [f"模型预测置信度: {confidence:.1%}", f"航空公司: {airline}"]
|
||||
|
||||
return EnhancedSentimentClassification(
|
||||
sentiment=sentiment,
|
||||
confidence=confidence,
|
||||
reasoning=reasoning,
|
||||
key_factors=key_factors,
|
||||
analysis_mode=AnalysisMode.TRADITIONAL
|
||||
)
|
||||
|
||||
async def _deepseek_classify(self, text: str, airline: str) -> EnhancedSentimentClassification:
|
||||
"""DeepSeek API 分类"""
|
||||
if not self.deepseek_agent:
|
||||
raise ValueError("DeepSeek API 客户端未初始化")
|
||||
|
||||
result = await self.deepseek_agent.analyze_sentiment(text, airline)
|
||||
|
||||
return EnhancedSentimentClassification(
|
||||
sentiment=result.sentiment,
|
||||
confidence=result.confidence,
|
||||
reasoning=result.reasoning,
|
||||
key_factors=result.key_factors,
|
||||
analysis_mode=AnalysisMode.DEEPSEEK
|
||||
)
|
||||
|
||||
def _hybrid_classify(self, text: str, airline: str, traditional_result, deepseek_result) -> EnhancedSentimentClassification:
|
||||
"""混合模式分类"""
|
||||
|
||||
# 简单的混合策略:优先使用 DeepSeek,如果失败则使用传统模型
|
||||
if deepseek_result.confidence >= 0.7:
|
||||
# DeepSeek 置信度高,使用 DeepSeek 结果
|
||||
return deepseek_result
|
||||
else:
|
||||
# DeepSeek 置信度低,使用传统模型结果
|
||||
return traditional_result
|
||||
|
||||
def _generate_disposal_plan(self, text: str, airline: str, classification: EnhancedSentimentClassification) -> EnhancedDisposalPlan:
|
||||
"""生成处置方案"""
|
||||
|
||||
sentiment = classification.sentiment
|
||||
confidence = classification.confidence
|
||||
|
||||
# 基于情感和置信度确定优先级和行动类型
|
||||
if sentiment == "negative":
|
||||
if confidence >= 0.8:
|
||||
priority = "high"
|
||||
action_type = "response"
|
||||
suggested_response = self._generate_negative_response(text, airline)
|
||||
follow_up_actions = [
|
||||
"记录客户投诉详情",
|
||||
"转交相关部门处理",
|
||||
"跟进处理进度",
|
||||
"在24小时内给予反馈",
|
||||
]
|
||||
reasoning = "负面情感且置信度高,需要立即响应和处理"
|
||||
else:
|
||||
priority = "medium"
|
||||
action_type = "investigate"
|
||||
suggested_response = None
|
||||
follow_up_actions = [
|
||||
"进一步核实情况",
|
||||
"根据核实结果决定是否需要回复",
|
||||
]
|
||||
reasoning = "负面情感但置信度一般,需要进一步调查"
|
||||
elif sentiment == "positive":
|
||||
if confidence >= 0.8:
|
||||
priority = "low"
|
||||
action_type = "response"
|
||||
suggested_response = self._generate_positive_response(text, airline)
|
||||
follow_up_actions = [
|
||||
"感谢客户反馈",
|
||||
"分享正面评价至内部团队",
|
||||
"考虑在官方渠道展示",
|
||||
]
|
||||
reasoning = "正面情感且置信度高,适合回复感谢"
|
||||
else:
|
||||
priority = "low"
|
||||
action_type = "monitor"
|
||||
suggested_response = None
|
||||
follow_up_actions = [
|
||||
"持续关注该用户后续动态",
|
||||
]
|
||||
reasoning = "正面情感但置信度一般,保持关注即可"
|
||||
else: # neutral
|
||||
if "?" in text or "help" in text.lower():
|
||||
priority = "medium"
|
||||
action_type = "response"
|
||||
suggested_response = self._generate_neutral_response(text, airline)
|
||||
follow_up_actions = [
|
||||
"提供准确信息",
|
||||
"确保客户问题得到解答",
|
||||
]
|
||||
reasoning = "中性情感但包含询问,需要提供帮助"
|
||||
else:
|
||||
priority = "low"
|
||||
action_type = "monitor"
|
||||
suggested_response = None
|
||||
follow_up_actions = [
|
||||
"持续关注",
|
||||
]
|
||||
reasoning = "中性情感,保持常规关注"
|
||||
|
||||
return EnhancedDisposalPlan(
|
||||
priority=priority,
|
||||
action_type=action_type,
|
||||
suggested_response=suggested_response,
|
||||
follow_up_actions=follow_up_actions,
|
||||
reasoning=reasoning
|
||||
)
|
||||
|
||||
def _generate_negative_response(self, text: str, airline: str) -> str:
|
||||
"""生成负面情感回复"""
|
||||
responses = [
|
||||
f"感谢您的反馈。我们非常重视您提到的问题,将立即进行调查并尽快给您答复。",
|
||||
f"对于您的不愉快体验,我们深表歉意。请私信我们详细情况,我们将全力为您解决。",
|
||||
f"收到您的反馈,我们对此感到抱歉。相关部门已介入,将尽快处理并给您满意的答复。",
|
||||
]
|
||||
return responses[hash(text) % len(responses)]
|
||||
|
||||
def _generate_positive_response(self, text: str, airline: str) -> str:
|
||||
"""生成正面情感回复"""
|
||||
responses = [
|
||||
f"感谢您的认可和支持!我们会继续努力为您提供更好的服务。",
|
||||
f"很高兴听到您的正面反馈!您的满意是我们前进的动力。",
|
||||
f"感谢您的分享!我们会将您的反馈传达给团队,激励我们做得更好。",
|
||||
]
|
||||
return responses[hash(text) % len(responses)]
|
||||
|
||||
def _generate_neutral_response(self, text: str, airline: str) -> str:
|
||||
"""生成中性情感回复"""
|
||||
responses = [
|
||||
f"感谢您的询问。请问您需要了解哪方面的信息?我们将竭诚为您解答。",
|
||||
f"收到您的问题。请提供更多细节,以便我们更好地为您提供帮助。",
|
||||
]
|
||||
return responses[hash(text) % len(responses)]
|
||||
|
||||
async def analyze(self, text: str, airline: str, mode: Optional[str] = None) -> EnhancedTweetAnalysisResult:
|
||||
"""完整分析流程
|
||||
|
||||
Args:
|
||||
text: 推文文本
|
||||
airline: 航空公司
|
||||
mode: 分析模式 (deepseek/traditional/hybrid)
|
||||
|
||||
Returns:
|
||||
完整分析结果
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
analysis_mode = mode or self.default_mode
|
||||
api_used = False
|
||||
|
||||
# 1. 情感分类
|
||||
if analysis_mode == AnalysisMode.TRADITIONAL:
|
||||
classification = self._traditional_classify(text, airline)
|
||||
elif analysis_mode == AnalysisMode.DEEPSEEK:
|
||||
classification = await self._deepseek_classify(text, airline)
|
||||
api_used = True
|
||||
else: # hybrid
|
||||
# 并行执行两种分析
|
||||
traditional_task = asyncio.to_thread(self._traditional_classify, text, airline)
|
||||
deepseek_task = self._deepseek_classify(text, airline)
|
||||
|
||||
traditional_result, deepseek_result = await asyncio.gather(
|
||||
traditional_task, deepseek_task, return_exceptions=True
|
||||
)
|
||||
|
||||
# 处理异常情况
|
||||
if isinstance(traditional_result, Exception):
|
||||
print(f"传统模型分析失败: {traditional_result}")
|
||||
traditional_result = None
|
||||
|
||||
if isinstance(deepseek_result, Exception):
|
||||
print(f"DeepSeek API 分析失败: {deepseek_result}")
|
||||
deepseek_result = None
|
||||
|
||||
if deepseek_result:
|
||||
api_used = True
|
||||
|
||||
# 混合策略
|
||||
if traditional_result and deepseek_result:
|
||||
classification = self._hybrid_classify(text, airline, traditional_result, deepseek_result)
|
||||
elif deepseek_result:
|
||||
classification = deepseek_result
|
||||
elif traditional_result:
|
||||
classification = traditional_result
|
||||
else:
|
||||
raise ValueError("所有分析模式都失败了")
|
||||
|
||||
# 2. 生成处置方案
|
||||
disposal_plan = self._generate_disposal_plan(text, airline, classification)
|
||||
|
||||
# 3. 计算处理时间
|
||||
processing_time = time.time() - start_time
|
||||
|
||||
# 返回完整结果
|
||||
return EnhancedTweetAnalysisResult(
|
||||
tweet_text=text,
|
||||
airline=airline,
|
||||
classification=classification,
|
||||
disposal_plan=disposal_plan,
|
||||
processing_time=processing_time,
|
||||
api_used=api_used
|
||||
)
|
||||
|
||||
|
||||
def analyze_tweet(text: str, airline: str, mode: str = "hybrid") -> EnhancedTweetAnalysisResult:
|
||||
"""同步版本的推文分析函数(兼容原有接口)
|
||||
|
||||
Args:
|
||||
text: 推文文本
|
||||
airline: 航空公司
|
||||
mode: 分析模式 (deepseek/traditional/hybrid)
|
||||
|
||||
Returns:
|
||||
分析结果
|
||||
"""
|
||||
agent = EnhancedTweetSentimentAgent()
|
||||
return asyncio.run(agent.analyze(text, airline, mode))
|
||||
|
||||
|
||||
def analyze_tweet_ultimate(text: str, mode: str = "traditional") -> EnhancedTweetAnalysisResult:
|
||||
"""终极版本推文分析函数 - 完全无需航空公司参数
|
||||
|
||||
Args:
|
||||
text: 推文文本
|
||||
mode: 分析模式 (traditional)
|
||||
|
||||
Returns:
|
||||
分析结果
|
||||
"""
|
||||
agent = EnhancedTweetSentimentAgent()
|
||||
|
||||
# 自动检测航空公司或使用通用标识
|
||||
airline = "通用航空公司"
|
||||
|
||||
# 简单的航空公司检测逻辑
|
||||
airline_keywords = {
|
||||
"united": "United Airlines",
|
||||
"delta": "Delta Air Lines",
|
||||
"american": "American Airlines",
|
||||
"southwest": "Southwest Airlines",
|
||||
"jetblue": "JetBlue Airways",
|
||||
"air china": "中国国际航空",
|
||||
"china eastern": "中国东方航空",
|
||||
"china southern": "中国南方航空"
|
||||
}
|
||||
|
||||
text_lower = text.lower()
|
||||
for keyword, airline_name in airline_keywords.items():
|
||||
if keyword in text_lower:
|
||||
airline = airline_name
|
||||
break
|
||||
|
||||
return asyncio.run(agent.analyze(text, airline, mode))
|
||||
146
test_api_connection.py
Normal file
146
test_api_connection.py
Normal file
@ -0,0 +1,146 @@
|
||||
"""测试 DeepSeek API 连接"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
import httpx
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 加载环境变量
|
||||
load_dotenv()
|
||||
|
||||
async def test_deepseek_api():
|
||||
"""测试 DeepSeek API 连接"""
|
||||
|
||||
api_key = os.getenv("DEEPSEEK_API_KEY")
|
||||
base_url = os.getenv("DEEPSEEK_BASE_URL", "https://api.deepseek.com/v1")
|
||||
|
||||
print("🔍 开始测试 DeepSeek API 连接...")
|
||||
print(f"API Key: {'已配置' if api_key else '未配置'}")
|
||||
print(f"Base URL: {base_url}")
|
||||
|
||||
if not api_key:
|
||||
print("❌ 错误: 未找到 DEEPSEEK_API_KEY 环境变量")
|
||||
return False
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
# 简单的测试消息
|
||||
data = {
|
||||
"model": "deepseek-chat",
|
||||
"messages": [
|
||||
{"role": "system", "content": "你是一个测试助手,只需要回复'连接成功'即可。"},
|
||||
{"role": "user", "content": "请回复'连接成功'"}
|
||||
],
|
||||
"temperature": 0.1,
|
||||
"max_tokens": 10,
|
||||
"stream": False
|
||||
}
|
||||
|
||||
try:
|
||||
print("🔄 发送测试请求...")
|
||||
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
response = await client.post(
|
||||
f"{base_url}/chat/completions",
|
||||
headers=headers,
|
||||
json=data
|
||||
)
|
||||
|
||||
print(f"📊 响应状态码: {response.status_code}")
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
content = result["choices"][0]["message"]["content"]
|
||||
print(f"✅ API 连接成功!")
|
||||
print(f"📝 响应内容: {content}")
|
||||
print(f"🔢 使用令牌数: {result.get('usage', {}).get('total_tokens', '未知')}")
|
||||
return True
|
||||
else:
|
||||
print(f"❌ API 调用失败: {response.status_code}")
|
||||
print(f"📋 错误信息: {response.text}")
|
||||
return False
|
||||
|
||||
except httpx.ConnectError as e:
|
||||
print(f"❌ 连接错误: {e}")
|
||||
print("💡 请检查网络连接和API地址")
|
||||
return False
|
||||
except httpx.TimeoutException as e:
|
||||
print(f"❌ 请求超时: {e}")
|
||||
print("💡 请检查网络连接或增加超时时间")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ 未知错误: {e}")
|
||||
return False
|
||||
|
||||
async def test_agent_integration():
|
||||
"""测试 Agent 集成"""
|
||||
|
||||
print("\n🧪 测试 Agent 集成...")
|
||||
|
||||
try:
|
||||
from src.deepseek_agent import DeepSeekTweetAgent
|
||||
|
||||
agent = DeepSeekTweetAgent()
|
||||
|
||||
# 测试推文分析
|
||||
test_tweet = "@United This is the worst airline ever! My flight was delayed for 5 hours and the staff was rude."
|
||||
test_airline = "United"
|
||||
|
||||
print(f"📝 测试推文: {test_tweet}")
|
||||
print(f"✈️ 航空公司: {test_airline}")
|
||||
|
||||
result = await agent.analyze_sentiment(test_tweet, test_airline)
|
||||
|
||||
print("✅ Agent 集成测试成功!")
|
||||
print(f"😊 情感分析结果: {result.sentiment}")
|
||||
print(f"📊 置信度: {result.confidence:.1%}")
|
||||
print(f"🔍 关键因素: {result.key_factors}")
|
||||
|
||||
return True
|
||||
|
||||
except ImportError as e:
|
||||
print(f"❌ 导入错误: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ Agent 测试失败: {e}")
|
||||
return False
|
||||
|
||||
async def main():
|
||||
"""主测试函数"""
|
||||
|
||||
print("=" * 50)
|
||||
print("🚀 DeepSeek API 连接测试")
|
||||
print("=" * 50)
|
||||
|
||||
# 测试基础 API 连接
|
||||
api_success = await test_deepseek_api()
|
||||
|
||||
# 测试 Agent 集成
|
||||
agent_success = await test_agent_integration()
|
||||
|
||||
print("\n" + "=" * 50)
|
||||
print("📋 测试结果汇总")
|
||||
print("=" * 50)
|
||||
|
||||
if api_success and agent_success:
|
||||
print("🎉 所有测试通过! DeepSeek API 集成成功!")
|
||||
print("💡 您现在可以正常使用增强版应用了")
|
||||
else:
|
||||
print("⚠️ 部分测试失败,请检查配置")
|
||||
|
||||
if not api_success:
|
||||
print("❌ API 连接测试失败")
|
||||
print("💡 请检查:")
|
||||
print(" 1. API Key 是否正确配置")
|
||||
print(" 2. 网络连接是否正常")
|
||||
print(" 3. API 地址是否正确")
|
||||
|
||||
if not agent_success:
|
||||
print("❌ Agent 集成测试失败")
|
||||
print("💡 请检查模块导入和依赖安装")
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Loading…
Reference in New Issue
Block a user