2026-01-07 11:02:05 +08:00
|
|
|
|
"""
|
|
|
|
|
|
LLM 客户端封装 - 统一 Anthropic/OpenAI/AIHubMix 接口
|
|
|
|
|
|
"""
|
|
|
|
|
|
from typing import Generator
|
|
|
|
|
|
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
2026-01-09 09:25:02 +08:00
|
|
|
|
import config
|
|
|
|
|
|
|
2026-01-07 11:02:05 +08:00
|
|
|
|
class LLMClient:
    """Unified LLM API client.

    Talks to Anthropic through the ``anthropic`` SDK. OpenAI and every
    OpenAI-compatible endpoint (AIHubMix, DeepSeek, SiliconFlow, a custom
    server) go through the ``openai`` SDK. All of them sit behind a single
    streaming / non-streaming interface.
    """

    # Providers served through the OpenAI SDK, mapped to their default base
    # URL. An explicit ``base_url`` argument always takes precedence.
    OPENAI_COMPATIBLE_BASE_URLS = {
        "aihubmix": "https://aihubmix.com/v1",
        "deepseek": "https://api.deepseek.com",
        "siliconflow": "https://api.siliconflow.cn/v1",
        "custom": "http://localhost:8000/v1",
    }

    # Per-provider model used when the caller passes no ``model``.
    # Anthropic needs a Claude model name, because the generic fallback
    # "gpt-4o" is not accepted by the Anthropic API. Override by passing
    # ``model`` explicitly.
    DEFAULT_MODELS = {
        "anthropic": "claude-sonnet-4-5",
    }
    # Default model for every provider not listed in DEFAULT_MODELS.
    FALLBACK_MODEL = "gpt-4o"

    def __init__(
        self,
        provider: str = None,
        api_key: str = None,
        base_url: str = None,
        model: str = None
    ):
        """
        Initialize the LLM client.

        Args:
            provider: 'anthropic', 'openai', 'aihubmix', 'deepseek',
                'siliconflow' or 'custom'. Defaults to 'aihubmix'.
            api_key: API key. When None, the underlying SDK decides how to
                obtain one.
            base_url: Custom API base URL. For OpenAI-compatible providers it
                overrides the built-in default URL. For 'openai' and
                'anthropic' it is forwarded to the SDK; None keeps the SDK's
                own default endpoint.
            model: Model name. Defaults to a provider-appropriate model
                (see DEFAULT_MODELS / FALLBACK_MODEL).

        Raises:
            ValueError: if ``provider`` is not one of the supported names.
        """
        self.provider = provider or "aihubmix"
        self.model = model or self.DEFAULT_MODELS.get(
            self.provider, self.FALLBACK_MODEL
        )

        if self.provider == "anthropic":
            # SDKs are imported lazily so only the one in use must be installed.
            from anthropic import Anthropic
            self.client = Anthropic(api_key=api_key, base_url=base_url)

        elif self.provider == "openai":
            from openai import OpenAI
            self.client = OpenAI(api_key=api_key, base_url=base_url)

        elif self.provider in self.OPENAI_COMPATIBLE_BASE_URLS:
            # OpenAI-compatible providers: same SDK, different endpoint.
            from openai import OpenAI
            self.client = OpenAI(
                api_key=api_key,
                base_url=base_url or self.OPENAI_COMPATIBLE_BASE_URLS[self.provider],
            )

        else:
            raise ValueError(f"不支持的 provider: {self.provider}")

    def chat_stream(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int = config.MAX_OUTPUT_TOKENS
    ) -> Generator[str, None, None]:
        """
        Stream a single-turn conversation.

        Args:
            system_prompt: System prompt.
            user_prompt: User input.
            max_tokens: Maximum number of output tokens.

        Yields:
            str: Text fragments as they arrive.
        """
        if self.provider == "anthropic":
            yield from self._anthropic_stream(system_prompt, user_prompt, max_tokens)
        else:
            yield from self._openai_stream(system_prompt, user_prompt, max_tokens)

    def _anthropic_stream(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int
    ) -> Generator[str, None, None]:
        """Stream via the Anthropic Messages API.

        Exceptions raised by the SDK propagate to the caller.
        """
        with self.client.messages.stream(
            model=self.model,
            max_tokens=max_tokens,
            system=system_prompt,
            messages=[{"role": "user", "content": user_prompt}]
        ) as stream:
            for text in stream.text_stream:
                yield text

    def _openai_stream(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int
    ) -> Generator[str, None, None]:
        """Stream via an OpenAI-compatible chat completions API (OpenAI, AIHubMix, vLLM, ...).

        Any exception is caught and emitted as a final error text fragment
        instead of being raised, so a streaming consumer always receives text.
        """
        try:
            stream = self.client.chat.completions.create(
                model=self.model,
                max_tokens=max_tokens,
                stream=True,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ]
            )
            for chunk in stream:
                # Some compatible servers send chunks with no choices, a
                # missing delta, or a delta whose content is None/empty; skip those.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                content = getattr(delta, "content", None) if delta else None
                if content:
                    yield content
        except Exception as e:
            # Deliberate best-effort: surface the error inline in the stream.
            yield f"\n\n[错误: {str(e)}]"

    def chat(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int = 1024
    ) -> str:
        """
        Run a single-turn conversation without streaming.

        Args:
            system_prompt: System prompt.
            user_prompt: User input.
            max_tokens: Maximum number of output tokens.

        Returns:
            str: The full response text. It is built by concatenating
            ``chat_stream`` output, so for OpenAI-compatible providers it may
            end with an inline error message instead of raising.
        """
        return "".join(self.chat_stream(system_prompt, user_prompt, max_tokens))
|