sms-castle-walls/src/llm/llm_service.py

250 lines
9.7 KiB
Python
Raw Normal View History

import os
from openai import OpenAI
from dotenv import load_dotenv
from typing import Dict, Any, Optional
from pydantic import BaseModel
# Load environment variables (python-dotenv) at import time. LLMService below
# additionally parses the project-root .env file itself and falls back to
# os.getenv for anything not found there.
load_dotenv()
class LLMService:
    """Service wrapper for calling the DeepSeek chat API via the OpenAI SDK.

    Configuration (``DEEPSEEK_API_KEY``, ``DEEPSEEK_BASE_URL``) is read from the
    project-root ``.env`` file (two directories above this module) first, then
    from the process environment. The SDK client is created lazily on the first
    API call, so constructing the service does not require a key.
    """

    # Used when DEEPSEEK_BASE_URL is unset, empty, or only quotes/whitespace.
    DEFAULT_BASE_URL = "https://api.deepseek.com"

    # Keys the explanation prompt asks the model for; always present in the
    # dict returned by explain_prediction (plus "suggestions").
    _EXPLANATION_FIELDS = (
        "content_summary",
        "classification_reason",
        "confidence_level",
        "confidence_explanation",
    )

    def __init__(self):
        """Load the API key and base URL; defer client creation to first use."""
        env_path = os.path.join(os.path.dirname(__file__), "..", "..", ".env")
        env_vars = self._read_env_file(env_path)

        # Values from the project .env take precedence over the process env.
        self.api_key = env_vars.get("DEEPSEEK_API_KEY") or os.getenv("DEEPSEEK_API_KEY")
        raw_base_url = env_vars.get("DEEPSEEK_BASE_URL") or os.getenv(
            "DEEPSEEK_BASE_URL", self.DEFAULT_BASE_URL
        )
        # An empty result (e.g. DEEPSEEK_BASE_URL="") falls back to the default
        # instead of producing "" or a bare "https://".
        self.base_url = self._normalize_base_url(raw_base_url) or self.DEFAULT_BASE_URL

        # Default model used when _call_api is not given one explicitly.
        self.default_model = "deepseek-chat"
        # Created lazily by _ensure_client().
        self.client = None

    @staticmethod
    def _strip_quotes(value: str) -> str:
        """Strip whitespace and one pair of matching surrounding quotes (", ' or `)."""
        value = value.strip()
        for quote_char in ('"', "'", "`"):
            if len(value) >= 2 and value[0] == quote_char and value[-1] == quote_char:
                return value[1:-1].strip()
        return value

    @classmethod
    def _read_env_file(cls, env_path: str) -> Dict[str, str]:
        """Parse a simple ``KEY=VALUE`` .env file.

        Args:
            env_path: Path to the .env file.

        Returns:
            A dict of keys to values. It is empty when the path is not a regular
            file. Blank lines, ``#`` comments and lines without ``=`` are skipped
            instead of raising. Values lose one pair of surrounding quotes, so
            ``DEEPSEEK_API_KEY="sk-..."`` yields the bare key.
        """
        env_vars: Dict[str, str] = {}
        if not os.path.isfile(env_path):
            return env_vars
        # utf-8-sig transparently drops a BOM (common for files saved on Windows)
        # that would otherwise become part of the first key.
        with open(env_path, "r", encoding="utf-8-sig") as f:
            for raw_line in f:
                line = raw_line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                key, value = line.split("=", 1)
                env_vars[key.strip()] = cls._strip_quotes(value)
        return env_vars

    @classmethod
    def _normalize_base_url(cls, url: Optional[str]) -> str:
        """Clean a base URL and make sure it carries an http(s) scheme.

        Returns "" when nothing usable remains so the caller can fall back.
        """
        if not url:
            return ""
        url = cls._strip_quotes(url)
        if not url:
            return ""
        if not url.startswith(("http://", "https://")):
            print(f"警告: base_url '{url}' 缺少协议前缀,将添加 'https://'")
            url = f"https://{url}"
        return url

    def _ensure_client(self):
        """Create the OpenAI-compatible client on first use.

        Raises:
            ValueError: if no API key was configured.
        """
        if self.client is None:
            if not self.api_key:
                raise ValueError("DEEPSEEK_API_KEY环境变量未设置")
            self.client = OpenAI(
                api_key=self.api_key,
                base_url=self.base_url
            )

    def _call_api(
        self,
        messages: list,
        model: Optional[str] = None,
        temperature: float = 0.7,
        max_tokens: int = 1000,
        response_format: Optional[Dict[str, Any]] = None
    ) -> str:
        """Send a chat-completion request and return the reply text.

        Args:
            messages: Chat messages (``{"role": ..., "content": ...}`` dicts).
            model: Model name; defaults to ``self.default_model``.
            temperature: Sampling temperature.
            max_tokens: Maximum tokens in the reply.
            response_format: Optional response format, e.g. ``{"type": "json_object"}``.

        Returns:
            The reply text. It is "" on any failure (missing key, network or API
            error; diagnostics are printed) and also when the model returns no
            content, so callers always get a ``str``.
        """
        if model is None:
            model = self.default_model

        request: Dict[str, Any] = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }
        # Only send response_format when asked for, rather than an explicit null.
        if response_format is not None:
            request["response_format"] = response_format

        try:
            self._ensure_client()
            response = self.client.chat.completions.create(**request)
            # message.content can be None; normalize so the return is always str.
            return response.choices[0].message.content or ""
        except Exception as e:
            self._report_api_error(e)
            return ""

    @staticmethod
    def _report_api_error(e: Exception) -> None:
        """Print an API failure, any HTTP error body attached to it, and its traceback."""
        import json
        import traceback

        print(f"LLM API调用失败: {e}")
        print(f"错误类型: {type(e).__name__}")
        # SDK status errors expose the HTTP response as ``e.response``; compare
        # with None instead of relying on the response object's truthiness.
        http_response = getattr(e, "response", None)
        if http_response is not None:
            try:
                error_details = http_response.json()
                print(f"详细错误信息: {json.dumps(error_details, ensure_ascii=False, indent=2)}")
            except Exception:
                try:
                    print(f"原始响应内容: {http_response.text}")
                except Exception:
                    # Best-effort diagnostics must never mask the original error.
                    pass
        traceback.print_exception(type(e), e, e.__traceback__)

    def translate_text(self, text: str, target_lang: str = "zh-CN") -> str:
        """Translate ``text`` into ``target_lang``; returns "" if the API call fails."""
        messages = [
            {
                "role": "system",
                "content": f"你是一个专业的翻译助手,请将给定的文本翻译成{target_lang}。保持原文的意思,翻译要准确、自然。"
            },
            {
                "role": "user",
                "content": text
            }
        ]
        return self._call_api(messages, temperature=0.3)

    @classmethod
    def _empty_explanation(cls) -> Dict[str, Any]:
        """Return the fallback explanation used when no usable JSON object is returned."""
        result: Dict[str, Any] = {field: "" for field in cls._EXPLANATION_FIELDS}
        result["suggestions"] = []
        return result

    @staticmethod
    def _normalize_suggestions(suggestions: Any) -> list:
        """Coerce the model's "suggestions" value into a list.

        A list is returned unchanged. A string has the prefixes "建议用户:" and
        "建议:" removed and is split on numbered markers such as "1. ". Anything
        else becomes [].
        """
        import re

        if isinstance(suggestions, list):
            return suggestions
        if not isinstance(suggestions, str):
            return []
        suggestions = suggestions.replace("建议用户:", "").replace("建议:", "")
        parts = re.split(r'\d+\.\s*', suggestions)
        return [part.strip() for part in parts if part.strip()]

    def explain_prediction(self, text: str, label: str, probability: float) -> Dict[str, Any]:
        """Ask the model to explain a classification result.

        Args:
            text: The SMS text.
            label: The predicted label (e.g. "spam").
            probability: The classifier's probability, formatted to 2 decimals.

        Returns:
            The dict parsed from the model's JSON reply, guaranteed to contain
            content_summary, classification_reason, confidence_level and
            confidence_explanation (default "") plus a list-valued
            "suggestions". If the reply is empty or not a JSON object, all
            fields are empty.
        """
        import json

        messages = [
            {
                "role": "system",
                "content": "你是一个专业的短信分类解释专家。请根据给定的短信内容、分类结果和概率,生成一个清晰、简洁的解释。解释应该包括:\n1. 短信的主要内容\n2. 为什么被分类为垃圾短信或正常短信\n3. 分类的可信度\n\n请使用结构化的JSON格式输出包含以下字段\n- content_summary: 短信内容摘要\n- classification_reason: 分类原因\n- confidence_level: 可信度级别(高、中、低)\n- confidence_explanation: 可信度解释\n- suggestions: 针对该短信的建议"
            },
            {
                "role": "user",
                "content": f"短信内容: {text}\n分类结果: {label}\n分类概率: {probability:.2f}"
            }
        ]
        response = self._call_api(
            messages,
            temperature=0.5,
            response_format={"type": "json_object"}
        )

        try:
            result = json.loads(response)
        except (json.JSONDecodeError, TypeError):
            result = None
        # Valid JSON that is not an object (list, number, ...) is also unusable.
        if not isinstance(result, dict):
            return self._empty_explanation()

        for field in self._EXPLANATION_FIELDS:
            result.setdefault(field, "")
        result["suggestions"] = self._normalize_suggestions(result.get("suggestions"))
        return result

    def generate_advice(self, text: str, label: str) -> str:
        """Generate action advice for an SMS and its label; returns "" on failure."""
        messages = [
            {
                "role": "system",
                "content": "你是一个专业的短信管理顾问。请根据给定的短信内容和分类结果,生成具体、实用的行动建议。建议要简洁明了,针对性强。"
            },
            {
                "role": "user",
                "content": f"短信内容: {text}\n分类结果: {label}"
            }
        ]
        return self._call_api(messages, temperature=0.5)

    def analyze_spam_patterns(self, spam_texts: list, max_samples: int = 5) -> str:
        """Summarize common patterns across spam samples.

        Args:
            spam_texts: Sliceable sequence of spam messages (``len()`` is used,
                so a pandas Series works as well as a list). ``None`` is
                treated as empty.
            max_samples: Only the first ``max_samples`` messages are sent, to
                keep the prompt within API limits.

        Returns:
            The model's analysis, "" on API failure, or a fixed notice when no
            samples are given (no API call is made in that case).
        """
        if spam_texts is None or len(spam_texts) == 0:
            return "没有提供垃圾短信样本"
        sample_texts = spam_texts[:max_samples]
        texts_str = "\n".join(f"{i+1}. {text}" for i, text in enumerate(sample_texts))
        messages = [
            {
                "role": "system",
                "content": "你是一个垃圾短信模式分析专家。请分析给定的垃圾短信样本,总结出常见的模式和特征。分析要全面、准确,包括但不限于:\n1. 内容特征\n2. 语言风格\n3. 发送目的\n4. 常见关键词\n\n请使用简洁明了的语言输出分析结果。"
            },
            {
                "role": "user",
                "content": f"垃圾短信样本:\n{texts_str}"
            }
        ]
        return self._call_api(messages, temperature=0.5)
# Shared module-level instance, built once when this module is imported.
# Construction only reads configuration; the SDK client is created on first use.
llm_service: LLMService = LLMService()


def get_llm_service() -> LLMService:
    """Return the shared module-level LLMService instance."""
    return llm_service
def main():
    """Manual smoke test: run each LLMService feature once against the live API."""
    service = get_llm_service()
    sample_sms = "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005. Text FA to 87121 to receive entry question(std txt rate)T&C's apply 08452810075over18's"

    # Translation
    print("翻译测试:")
    print(f"原文: {sample_sms}")
    translated = service.translate_text(sample_sms)
    print(f"译文: {translated}")

    # Classification explanation
    print("\n解释测试:")
    explained = service.explain_prediction(sample_sms, "spam", 0.95)
    print(f"解释结果: {explained}")

    # Action advice
    print("\n建议测试:")
    advice_text = service.generate_advice(sample_sms, "spam")
    print(f"建议: {advice_text}")

    # Spam pattern analysis
    print("\n模式分析测试:")
    samples = [
        "Free entry in 2 a wkly comp to win FA Cup final tkts 21st May 2005.",
        "WINNER!! As a valued network customer you have been selected to receivea <20>900 prize reward!",
        "Had your mobile 11 months or more? U R entitled to Update to the latest colour mobiles with camera for Free!",
    ]
    patterns = service.analyze_spam_patterns(samples)
    print(f"模式分析: {patterns}")


if __name__ == "__main__":
    main()