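"""
LLM 客户端封装 - 统一 Anthropic、OpenAI 及 OpenAI 兼容接口(AIHubMix、DeepSeek、SiliconFlow、自定义服务)

Unified LLM client wrapper for Anthropic, OpenAI and OpenAI-compatible
endpoints (AIHubMix, DeepSeek, SiliconFlow, custom servers).
"""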
|
||
from typing import Generator
|
||
import os
|
||
|
||
|
||
class LLMClient:
    """Unified LLM API client.

    Routes chat requests either to the Anthropic SDK (``provider="anthropic"``)
    or to the OpenAI SDK, which is also used for the OpenAI-compatible
    providers ``"aihubmix"``, ``"deepseek"``, ``"siliconflow"`` and
    ``"custom"``.
    """

    # Default endpoints for the OpenAI-compatible providers. An explicit
    # ``base_url`` passed to ``__init__`` always takes precedence. The keys of
    # this table are also the set of supported OpenAI-compatible providers.
    _DEFAULT_BASE_URLS = {
        "aihubmix": "https://aihubmix.com/v1",
        "deepseek": "https://api.deepseek.com",
        "siliconflow": "https://api.siliconflow.cn/v1",
        "custom": "http://localhost:8000/v1",
    }

    # Provider-specific default models. "gpt-4o" is not served by the
    # Anthropic or DeepSeek APIs, so those providers need their own default;
    # every other provider keeps the historical "gpt-4o" fallback.
    _DEFAULT_MODELS = {
        "anthropic": "claude-sonnet-4-20250514",
        "deepseek": "deepseek-chat",
    }
    _FALLBACK_MODEL = "gpt-4o"

    def __init__(
        self,
        provider: str = None,
        api_key: str = None,
        base_url: str = None,
        model: str = None
    ):
        """Create the underlying SDK client for *provider*.

        Args:
            provider: 'anthropic', 'openai', 'aihubmix', 'deepseek',
                'siliconflow' or 'custom'. Defaults to 'aihubmix'.
            api_key: API key, passed unchanged to the SDK constructor.
            base_url: Custom API endpoint. For the OpenAI-compatible
                providers it overrides the built-in default URL; for
                'anthropic' and 'openai' it is forwarded to the SDK
                constructor (``None`` is that argument's own default there).
            model: Model name. Defaults to a provider-specific model, or
                'gpt-4o' when the provider has no dedicated default.

        Raises:
            ValueError: If *provider* is not a supported name.
        """
        self.provider = provider or "aihubmix"
        self.model = model or self._DEFAULT_MODELS.get(
            self.provider, self._FALLBACK_MODEL
        )

        # SDK imports are deferred so only the package for the chosen
        # provider has to be installed.
        if self.provider == "anthropic":
            from anthropic import Anthropic

            self.client = Anthropic(api_key=api_key, base_url=base_url)

        elif self.provider == "openai":
            from openai import OpenAI

            self.client = OpenAI(api_key=api_key, base_url=base_url)

        elif self.provider in self._DEFAULT_BASE_URLS:
            # OpenAI-compatible providers: same SDK, different endpoint.
            from openai import OpenAI

            self.client = OpenAI(
                api_key=api_key,
                base_url=base_url or self._DEFAULT_BASE_URLS[self.provider],
            )

        else:
            raise ValueError(f"不支持的 provider: {self.provider}")

    def chat_stream(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int = 1024
    ) -> Generator[str, None, None]:
        """Stream a single-turn chat completion.

        Args:
            system_prompt: System prompt.
            user_prompt: User message.
            max_tokens: Maximum number of output tokens.

        Yields:
            str: Text fragments in the order they are received.

        Note:
            Error handling differs by backend: the Anthropic path lets SDK
            exceptions propagate, while the OpenAI-compatible path catches
            them and yields a trailing "[错误: ...]" fragment instead.
        """
        if self.provider == "anthropic":
            yield from self._anthropic_stream(system_prompt, user_prompt, max_tokens)
        else:
            yield from self._openai_stream(system_prompt, user_prompt, max_tokens)

    def _anthropic_stream(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int
    ) -> Generator[str, None, None]:
        """Stream text from the Anthropic Messages API.

        Exceptions raised by the SDK propagate to the caller.
        """
        with self.client.messages.stream(
            model=self.model,
            max_tokens=max_tokens,
            system=system_prompt,
            messages=[{"role": "user", "content": user_prompt}],
        ) as stream:
            for text in stream.text_stream:
                yield text

    def _openai_stream(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int
    ) -> Generator[str, None, None]:
        """Stream text from an OpenAI-compatible Chat Completions endpoint.

        Works for OpenAI itself and for compatible servers such as AIHubMix
        or vLLM. Any exception raised while creating or consuming the stream
        is turned into a final "[错误: ...]" text fragment rather than raised.
        """
        try:
            stream = self.client.chat.completions.create(
                model=self.model,
                max_tokens=max_tokens,
                stream=True,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt},
                ],
            )
            for chunk in stream:
                # A chunk may carry an empty/None `choices` list, a missing
                # delta, or a delta whose content is None/empty; skip those.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta
                text = getattr(delta, "content", None) if delta else None
                if text:
                    yield text
        except Exception as e:
            # Deliberate best-effort: surface the failure in the text stream
            # so streaming consumers can display it.
            yield f"\n\n[错误: {str(e)}]"

    def chat(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int = 1024
    ) -> str:
        """Run a single-turn chat and return the complete response text.

        Args:
            system_prompt: System prompt.
            user_prompt: User message.
            max_tokens: Maximum number of output tokens.

        Returns:
            str: The concatenation of every fragment from ``chat_stream``.
        """
        return "".join(self.chat_stream(system_prompt, user_prompt, max_tokens))