142 lines
4.4 KiB
Python
142 lines
4.4 KiB
Python
|
|
"""
|
|||
|
|
LLM 客户端封装 - 统一 Anthropic/OpenAI/AIHubMix 接口
|
|||
|
|
"""
|
|||
|
|
from typing import Generator
|
|||
|
|
import os
|
|||
|
|
|
|||
|
|
|
|||
|
|
class LLMClient:
|
|||
|
|
"""LLM API 统一客户端"""
|
|||
|
|
|
|||
|
|
def __init__(
|
|||
|
|
self,
|
|||
|
|
provider: str = None,
|
|||
|
|
api_key: str = None,
|
|||
|
|
base_url: str = None,
|
|||
|
|
model: str = None
|
|||
|
|
):
|
|||
|
|
"""
|
|||
|
|
初始化 LLM 客户端
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
provider: 'anthropic', 'openai', 'aihubmix', 或 'custom'
|
|||
|
|
api_key: API 密钥
|
|||
|
|
base_url: 自定义 API 地址(用于 aihubmix/custom)
|
|||
|
|
model: 指定模型名称
|
|||
|
|
"""
|
|||
|
|
self.provider = provider or "aihubmix"
|
|||
|
|
self.model = model or "gpt-4o"
|
|||
|
|
|
|||
|
|
if self.provider == "anthropic":
|
|||
|
|
from anthropic import Anthropic
|
|||
|
|
self.client = Anthropic(api_key=api_key)
|
|||
|
|
|
|||
|
|
elif self.provider == "openai":
|
|||
|
|
from openai import OpenAI
|
|||
|
|
self.client = OpenAI(api_key=api_key)
|
|||
|
|
self.model = model or "gpt-4o"
|
|||
|
|
|
|||
|
|
elif self.provider == "aihubmix":
|
|||
|
|
# AIHubMix 兼容 OpenAI API 格式
|
|||
|
|
from openai import OpenAI
|
|||
|
|
self.client = OpenAI(
|
|||
|
|
api_key=api_key,
|
|||
|
|
base_url=base_url or "https://aihubmix.com/v1"
|
|||
|
|
)
|
|||
|
|
self.model = model or "gpt-4o"
|
|||
|
|
|
|||
|
|
elif self.provider == "custom":
|
|||
|
|
# 自定义 OpenAI 兼容接口(vLLM、Ollama、TGI 等)
|
|||
|
|
from openai import OpenAI
|
|||
|
|
self.client = OpenAI(
|
|||
|
|
api_key=api_key or "not-needed",
|
|||
|
|
base_url=base_url or "http://localhost:8000/v1"
|
|||
|
|
)
|
|||
|
|
self.model = model or "local-model"
|
|||
|
|
|
|||
|
|
else:
|
|||
|
|
raise ValueError(f"不支持的 provider: {self.provider}")
|
|||
|
|
|
|||
|
|
def chat_stream(
|
|||
|
|
self,
|
|||
|
|
system_prompt: str,
|
|||
|
|
user_prompt: str,
|
|||
|
|
max_tokens: int = 1024
|
|||
|
|
) -> Generator[str, None, None]:
|
|||
|
|
"""
|
|||
|
|
流式对话
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
system_prompt: 系统提示词
|
|||
|
|
user_prompt: 用户输入
|
|||
|
|
max_tokens: 最大输出 token 数
|
|||
|
|
|
|||
|
|
Yields:
|
|||
|
|
str: 流式输出的文本片段
|
|||
|
|
"""
|
|||
|
|
if self.provider == "anthropic":
|
|||
|
|
yield from self._anthropic_stream(system_prompt, user_prompt, max_tokens)
|
|||
|
|
else:
|
|||
|
|
yield from self._openai_stream(system_prompt, user_prompt, max_tokens)
|
|||
|
|
|
|||
|
|
def _anthropic_stream(
|
|||
|
|
self,
|
|||
|
|
system_prompt: str,
|
|||
|
|
user_prompt: str,
|
|||
|
|
max_tokens: int
|
|||
|
|
) -> Generator[str, None, None]:
|
|||
|
|
"""Anthropic 流式调用"""
|
|||
|
|
with self.client.messages.stream(
|
|||
|
|
model=self.model,
|
|||
|
|
max_tokens=max_tokens,
|
|||
|
|
system=system_prompt,
|
|||
|
|
messages=[{"role": "user", "content": user_prompt}]
|
|||
|
|
) as stream:
|
|||
|
|
for text in stream.text_stream:
|
|||
|
|
yield text
|
|||
|
|
|
|||
|
|
def _openai_stream(
|
|||
|
|
self,
|
|||
|
|
system_prompt: str,
|
|||
|
|
user_prompt: str,
|
|||
|
|
max_tokens: int
|
|||
|
|
) -> Generator[str, None, None]:
|
|||
|
|
"""OpenAI 兼容接口流式调用(支持 AIHubMix、vLLM 等)"""
|
|||
|
|
try:
|
|||
|
|
stream = self.client.chat.completions.create(
|
|||
|
|
model=self.model,
|
|||
|
|
max_tokens=max_tokens,
|
|||
|
|
stream=True,
|
|||
|
|
messages=[
|
|||
|
|
{"role": "system", "content": system_prompt},
|
|||
|
|
{"role": "user", "content": user_prompt}
|
|||
|
|
]
|
|||
|
|
)
|
|||
|
|
for chunk in stream:
|
|||
|
|
# 安全地获取 content,处理各种边界情况
|
|||
|
|
if chunk.choices and len(chunk.choices) > 0:
|
|||
|
|
delta = chunk.choices[0].delta
|
|||
|
|
if delta and hasattr(delta, 'content') and delta.content:
|
|||
|
|
yield delta.content
|
|||
|
|
except Exception as e:
|
|||
|
|
yield f"\n\n[错误: {str(e)}]"
|
|||
|
|
|
|||
|
|
def chat(
|
|||
|
|
self,
|
|||
|
|
system_prompt: str,
|
|||
|
|
user_prompt: str,
|
|||
|
|
max_tokens: int = 1024
|
|||
|
|
) -> str:
|
|||
|
|
"""
|
|||
|
|
非流式对话
|
|||
|
|
|
|||
|
|
Args:
|
|||
|
|
system_prompt: 系统提示词
|
|||
|
|
user_prompt: 用户输入
|
|||
|
|
max_tokens: 最大输出 token 数
|
|||
|
|
|
|||
|
|
Returns:
|
|||
|
|
str: 完整的响应文本
|
|||
|
|
"""
|
|||
|
|
return "".join(self.chat_stream(system_prompt, user_prompt, max_tokens))
|