# wd666/utils/llm_client.py
"""
LLM 客户端封装 - 统一 Anthropic/OpenAI/AIHubMix 接口
"""
from typing import Generator, Optional
import os

import config

class LLMClient:
"""LLM API 统一客户端"""
    def __init__(
        self,
        provider: Optional[str] = None,
        api_key: Optional[str] = None,
        base_url: Optional[str] = None,
        model: Optional[str] = None
    ):
"""
初始化 LLM 客户端
Args:
provider: 'anthropic', 'openai', 'aihubmix', 'custom'
api_key: API 密钥
base_url: 自定义 API 地址用于 aihubmix/custom
model: 指定模型名称
"""
        self.provider = provider or "aihubmix"

        if self.provider == "anthropic":
            from anthropic import Anthropic
            self.client = Anthropic(api_key=api_key)
            # The default model name here is an assumption; pass `model` explicitly
            # if your account expects a different one.
            self.model = model or "claude-3-5-sonnet-20241022"
        elif self.provider == "openai":
            from openai import OpenAI
            self.client = OpenAI(api_key=api_key)
            self.model = model or "gpt-4o"
        elif self.provider in ["aihubmix", "deepseek", "siliconflow", "custom"]:
            # Providers that expose an OpenAI-compatible API
from openai import OpenAI
default_urls = {
"aihubmix": "https://aihubmix.com/v1",
"deepseek": "https://api.deepseek.com",
"siliconflow": "https://api.siliconflow.cn/v1",
"custom": "http://localhost:8000/v1"
}
final_base_url = base_url or default_urls.get(self.provider)
            self.client = OpenAI(
                api_key=api_key,
                base_url=final_base_url
)
self.model = model or "gpt-4o"
else:
            raise ValueError(f"Unsupported provider: {self.provider}")
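
    # Example instantiations (a sketch; the API keys and model names below are
    # placeholders, not values shipped with this project):
    #
    #   LLMClient(provider="anthropic", api_key="sk-ant-...", model="claude-3-5-sonnet-20241022")
    #   LLMClient(provider="deepseek", api_key="sk-...", model="deepseek-chat")
    #   LLMClient(provider="custom", base_url="http://localhost:8000/v1",
    #             api_key="EMPTY", model="my-local-model")
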
def chat_stream(
self,
system_prompt: str,
user_prompt: str,
max_tokens: int = config.MAX_OUTPUT_TOKENS
) -> Generator[str, None, None]:
"""
流式对话
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
max_tokens: 最大输出 token
Yields:
str: 流式输出的文本片段
"""
if self.provider == "anthropic":
yield from self._anthropic_stream(system_prompt, user_prompt, max_tokens)
else:
yield from self._openai_stream(system_prompt, user_prompt, max_tokens)
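
    # Consuming the stream incrementally (a sketch): callers that want both live
    # updates and the final text can accumulate the fragments themselves, which is
    # exactly what `chat()` below does.
    #
    #   parts = []
    #   for fragment in client.chat_stream(system_prompt, user_prompt):
    #       parts.append(fragment)   # e.g. update the UI with `fragment` here
    #   full_text = "".join(parts)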

    def _anthropic_stream(
self,
system_prompt: str,
user_prompt: str,
max_tokens: int
) -> Generator[str, None, None]:
"""Anthropic 流式调用"""
with self.client.messages.stream(
model=self.model,
max_tokens=max_tokens,
system=system_prompt,
messages=[{"role": "user", "content": user_prompt}]
) as stream:
for text in stream.text_stream:
yield text

    def _openai_stream(
self,
system_prompt: str,
user_prompt: str,
max_tokens: int
) -> Generator[str, None, None]:
"""OpenAI 兼容接口流式调用(支持 AIHubMix、vLLM 等)"""
try:
stream = self.client.chat.completions.create(
model=self.model,
max_tokens=max_tokens,
stream=True,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
)
for chunk in stream:
                # Safely extract the delta content, handling empty choices and missing/None content
if chunk.choices and len(chunk.choices) > 0:
delta = chunk.choices[0].delta
if delta and hasattr(delta, 'content') and delta.content:
yield delta.content
except Exception as e:
            yield f"\n\n[Error: {e}]"

    def chat(
self,
system_prompt: str,
user_prompt: str,
max_tokens: int = 1024
) -> str:
"""
非流式对话
Args:
system_prompt: 系统提示词
user_prompt: 用户输入
max_tokens: 最大输出 token
Returns:
str: 完整的响应文本
"""
return "".join(self.chat_stream(system_prompt, user_prompt, max_tokens))