# Source listing metadata: 250 lines, 9.7 KiB, Python
import json
import os
import re
import traceback
from typing import Dict, Any, Optional

from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel
|
||
|
||
# Load environment variables (python-dotenv) before any service is constructed.
load_dotenv()
|
||
|
||
|
||
class LLMService:
    """Service wrapper that calls the DeepSeek API through the OpenAI client.

    Credentials and the endpoint are resolved once in ``__init__``; the
    underlying client object is only created on the first API call, so
    constructing the service never fails just because the key is missing.
    """

    # Endpoint used when no DEEPSEEK_BASE_URL is configured (or it is empty).
    DEFAULT_BASE_URL = "https://api.deepseek.com"
    # Model used when a call does not name one explicitly.
    DEFAULT_MODEL = "deepseek-chat"
    # Upper bound on samples sent by analyze_spam_patterns, to keep prompts small.
    MAX_PATTERN_SAMPLES = 5

    def __init__(self):
        """Resolve the API key and base URL; defer client creation."""
        # The project-level .env file is read directly; its values take
        # precedence over the process environment.
        env_path = os.path.join(os.path.dirname(__file__), "..", "..", ".env")
        env_vars = self._read_env_file(env_path)

        self.api_key = self._clean_value(
            env_vars.get("DEEPSEEK_API_KEY") or os.getenv("DEEPSEEK_API_KEY")
        )
        self.base_url = self._normalize_base_url(
            env_vars.get("DEEPSEEK_BASE_URL") or os.getenv("DEEPSEEK_BASE_URL")
        )

        self.default_model = self.DEFAULT_MODEL

        # Created lazily by _ensure_client().
        self.client = None

    @staticmethod
    def _clean_value(value: Optional[str]) -> Optional[str]:
        """Strip whitespace and one pair of matching surrounding quotes.

        Returns ``None`` unchanged so callers can pass through missing values.
        """
        if value is None:
            return None
        value = value.strip()
        for quote_char in ('"', "'", "`"):
            if (
                len(value) >= 2
                and value.startswith(quote_char)
                and value.endswith(quote_char)
            ):
                return value[1:-1].strip()
        return value

    @classmethod
    def _read_env_file(cls, env_path: str) -> Dict[str, str]:
        """Parse ``KEY=VALUE`` lines from *env_path* into a dict.

        Blank lines, ``#`` comments and lines without ``=`` are skipped
        instead of raising, so one malformed line cannot break start-up.
        A missing file yields an empty dict.
        """
        env_vars: Dict[str, str] = {}
        if not os.path.isfile(env_path):
            return env_vars

        with open(env_path, "r", encoding="utf-8") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith("#"):
                    continue
                key, sep, value = line.partition("=")
                key = key.strip()
                if not sep or not key:
                    continue
                env_vars[key] = cls._clean_value(value)
        return env_vars

    @classmethod
    def _normalize_base_url(cls, raw_url: Optional[str]) -> str:
        """Return a usable base URL: cleaned, non-empty, with a protocol prefix."""
        url = cls._clean_value(raw_url)
        if not url:
            return cls.DEFAULT_BASE_URL

        if not url.startswith(("http://", "https://")):
            print(f"警告: base_url '{url}' 缺少协议前缀,将添加 'https://'")
            url = f"https://{url}"
        return url

    @staticmethod
    def _normalize_suggestions(suggestions: Any) -> list:
        """Coerce the model's ``suggestions`` field into a list.

        A list is returned as-is. A string is stripped of its known prefixes
        and split on ``N.`` numbering. Any other type becomes ``[]``.
        """
        if isinstance(suggestions, list):
            return suggestions
        if not isinstance(suggestions, str):
            return []

        text = suggestions.replace("建议用户:", "").replace("建议:", "")
        parts = re.split(r'\d+\.\s*', text)
        return [part.strip() for part in parts if part.strip()]

    def _ensure_client(self):
        """Create the OpenAI client on first use.

        Raises:
            ValueError: if no API key was configured.
        """
        if self.client is None:
            if not self.api_key:
                raise ValueError("DEEPSEEK_API_KEY环境变量未设置")

            self.client = OpenAI(
                api_key=self.api_key,
                base_url=self.base_url
            )

    def _call_api(
        self,
        messages: list,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 1000,
        response_format: Optional[Dict[str, Any]] = None
    ) -> str:
        """Send a chat-completion request and return the reply text.

        Args:
            messages: chat messages (dicts with ``role`` and ``content``).
            model: model name; falls back to ``self.default_model``.
            temperature: sampling temperature forwarded to the API.
            max_tokens: completion token limit forwarded to the API.
            response_format: optional response-format spec; only sent when
                provided, so the API default applies otherwise.

        Returns:
            The first choice's message content, or ``""`` when the call
            fails for any reason (the error is printed, not raised).
        """
        if model is None:
            model = self.default_model

        request: Dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        # Omit the field rather than sending an explicit null.
        if response_format is not None:
            request["response_format"] = response_format

        try:
            self._ensure_client()
            response = self.client.chat.completions.create(**request)
            # content may be None; keep the declared ``str`` contract.
            return response.choices[0].message.content or ""
        except Exception as e:
            # Best-effort boundary: report everything, then signal failure
            # with an empty string so callers can degrade gracefully.
            print(f"LLM API调用失败: {e}")
            print(f"错误类型: {type(e).__name__}")
            error_response = getattr(e, "response", None)
            if error_response is not None:
                try:
                    error_details = error_response.json()
                    print(f"详细错误信息: {json.dumps(error_details, ensure_ascii=False, indent=2)}")
                except Exception:
                    print(f"原始响应内容: {getattr(error_response, 'text', '')}")
            traceback.print_exc()
            return ""

    def translate_text(self, text: str, target_lang: str = "zh-CN") -> str:
        """Translate *text* into *target_lang*; returns ``""`` on API failure."""
        messages = [
            {
                "role": "system",
                "content": f"你是一个专业的翻译助手,请将给定的文本翻译成{target_lang}。保持原文的意思,翻译要准确、自然。"
            },
            {
                "role": "user",
                "content": text
            }
        ]

        return self._call_api(messages, temperature=0.3)

    def explain_prediction(self, text: str, label: str, probability: float) -> Dict[str, Any]:
        """Ask the model to explain a classification result.

        Args:
            text: the SMS content that was classified.
            label: the predicted class label.
            probability: the predicted probability (formatted to 2 decimals).

        Returns:
            The parsed JSON object from the model, with ``suggestions``
            normalized to a list. If the reply is missing, is not valid JSON,
            or is not a JSON object, a placeholder dict with empty fields and
            confidence level ``"中"`` is returned instead.
        """
        messages = [
            {
                "role": "system",
                "content": "你是一个专业的短信分类解释专家。请根据给定的短信内容、分类结果和概率,生成一个清晰、简洁的解释。解释应该包括:\n1. 短信的主要内容\n2. 为什么被分类为垃圾短信或正常短信\n3. 分类的可信度\n\n请使用结构化的JSON格式输出,包含以下字段:\n- content_summary: 短信内容摘要\n- classification_reason: 分类原因\n- confidence_level: 可信度级别(高、中、低)\n- confidence_explanation: 可信度解释\n- suggestions: 针对该短信的建议"
            },
            {
                "role": "user",
                "content": f"短信内容: {text}\n分类结果: {label}\n分类概率: {probability:.2f}"
            }
        ]

        response = self._call_api(
            messages,
            temperature=0.5,
            response_format={"type": "json_object"}
        )

        fallback = {
            "content_summary": "",
            "classification_reason": "",
            "confidence_level": "中",
            "confidence_explanation": "",
            "suggestions": []
        }

        try:
            result = json.loads(response)
        except (json.JSONDecodeError, TypeError):
            return fallback

        # Valid JSON that is not an object (list, number, ...) is unusable.
        if not isinstance(result, dict):
            return fallback

        result["suggestions"] = self._normalize_suggestions(result.get("suggestions", ""))
        return result

    def generate_advice(self, text: str, label: str) -> str:
        """Generate action advice for an SMS and its label; ``""`` on API failure."""
        messages = [
            {
                "role": "system",
                "content": "你是一个专业的短信管理顾问。请根据给定的短信内容和分类结果,生成具体、实用的行动建议。建议要简洁明了,针对性强。"
            },
            {
                "role": "user",
                "content": f"短信内容: {text}\n分类结果: {label}"
            }
        ]

        return self._call_api(messages, temperature=0.5)

    def analyze_spam_patterns(self, spam_texts: list) -> str:
        """Summarize common patterns across spam samples.

        Only the first ``MAX_PATTERN_SAMPLES`` texts are sent. Returns a fixed
        notice when no samples are given, and ``""`` on API failure.
        """
        if spam_texts is None or len(spam_texts) == 0:
            return "没有提供垃圾短信样本"

        sample_texts = spam_texts[:self.MAX_PATTERN_SAMPLES]
        texts_str = "\n".join(
            f"{index}. {text}" for index, text in enumerate(sample_texts, start=1)
        )

        messages = [
            {
                "role": "system",
                "content": "你是一个垃圾短信模式分析专家。请分析给定的垃圾短信样本,总结出常见的模式和特征。分析要全面、准确,包括但不限于:\n1. 内容特征\n2. 语言风格\n3. 发送目的\n4. 常见关键词\n\n请使用简洁明了的语言输出分析结果。"
            },
            {
                "role": "user",
                "content": f"垃圾短信样本:\n{texts_str}"
            }
        ]

        return self._call_api(messages, temperature=0.5)
|
||
|
||
|
||
# Module-level LLMService instance, created when this module is imported.
llm_service = LLMService()
|
||
|
||
|
||
def get_llm_service() -> LLMService:
    """Return the module-level ``LLMService`` instance."""
    service = llm_service
    return service
|
||
|
||
|
||
def main():
    """Manually exercise every LLMService feature and print the results.

    Calls translation, prediction explanation, advice generation and spam
    pattern analysis in turn on sample English spam messages. Each call goes
    through the service's API wrapper, which returns an empty string on
    failure, so this function prints empty results rather than raising.
    """
    service = get_llm_service()

    # Translation
    test_text = "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's"
    print("翻译测试:")
    print(f"原文: {test_text}")
    translation = service.translate_text(test_text)
    print(f"译文: {translation}")

    # Prediction explanation
    print("\n解释测试:")
    explanation = service.explain_prediction(test_text, "spam", 0.95)
    print(f"解释结果: {explanation}")

    # Action advice
    print("\n建议测试:")
    advice = service.generate_advice(test_text, "spam")
    print(f"建议: {advice}")

    # Spam pattern analysis
    print("\n模式分析测试:")
    spam_samples = [
        "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005.",
        "WINNER!! As a valued network customer you have been selected to receivea £900 prize reward!",
        "Had your mobile 11 months or more? U R entitled to Update to the latest colour mobiles with camera for Free!"
    ]
    pattern_analysis = service.analyze_spam_patterns(spam_samples)
    print(f"模式分析: {pattern_analysis}")
|
||
|
||
|
||
# Run the manual smoke test only when executed as a script, not on import.
if __name__ == "__main__":
    main()