# wd666/utils/llm_client.py
"""
LLM 客户端封装 - 统一 Anthropic/OpenAI/AIHubMix 接口
"""
from typing import Generator
import os
class LLMClient:
"""LLM API 统一客户端"""
def __init__(
self,
provider: str = None,
api_key: str = None,
base_url: str = None,
model: str = None
):
"""
初始化 LLM 客户端
Args:
provider: 'anthropic', 'openai', 'aihubmix', 或 'custom'
api_key: API 密钥
base_url: 自定义 API 地址(用于 aihubmix/custom
model: 指定模型名称
"""
self.provider = provider or "aihubmix"
self.model = model or "gpt-4o"
if self.provider == "anthropic":
from anthropic import Anthropic
self.client = Anthropic(api_key=api_key)
elif self.provider == "openai":
from openai import OpenAI
self.client = OpenAI(api_key=api_key)
self.model = model or "gpt-4o"
elif self.provider == "aihubmix":
# AIHubMix 兼容 OpenAI API 格式
from openai import OpenAI
self.client = OpenAI(
api_key=api_key,
base_url=base_url or "https://aihubmix.com/v1"
)
self.model = model or "gpt-4o"
elif self.provider == "custom":
# 自定义 OpenAI 兼容接口vLLM、Ollama、TGI 等)
from openai import OpenAI
self.client = OpenAI(
api_key=api_key or "not-needed",
base_url=base_url or "http://localhost:8000/v1"
)
self.model = model or "local-model"
else:
raise ValueError(f"不支持的 provider: {self.provider}")

    def chat_stream(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int = 1024
    ) -> Generator[str, None, None]:
        """
        Streaming chat.

        Args:
            system_prompt: System prompt
            user_prompt: User input
            max_tokens: Maximum number of output tokens

        Yields:
            str: Text fragments as they stream in
        """
        if self.provider == "anthropic":
            yield from self._anthropic_stream(system_prompt, user_prompt, max_tokens)
        else:
            yield from self._openai_stream(system_prompt, user_prompt, max_tokens)

    def _anthropic_stream(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int
    ) -> Generator[str, None, None]:
        """Streaming call via the Anthropic API."""
        with self.client.messages.stream(
            model=self.model,
            max_tokens=max_tokens,
            system=system_prompt,
            messages=[{"role": "user", "content": user_prompt}]
        ) as stream:
            for text in stream.text_stream:
                yield text

    def _openai_stream(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int
    ) -> Generator[str, None, None]:
        """Streaming call via an OpenAI-compatible API (AIHubMix, vLLM, etc.)."""
        try:
            stream = self.client.chat.completions.create(
                model=self.model,
                max_tokens=max_tokens,
                stream=True,
                messages=[
                    {"role": "system", "content": system_prompt},
                    {"role": "user", "content": user_prompt}
                ]
            )
            for chunk in stream:
                # Read content defensively: some backends emit chunks with
                # empty choices or a delta that has no content.
                if chunk.choices and len(chunk.choices) > 0:
                    delta = chunk.choices[0].delta
                    if delta and hasattr(delta, 'content') and delta.content:
                        yield delta.content
        except Exception as e:
            yield f"\n\n[Error: {str(e)}]"

    def chat(
        self,
        system_prompt: str,
        user_prompt: str,
        max_tokens: int = 1024
    ) -> str:
        """
        Non-streaming chat.

        Args:
            system_prompt: System prompt
            user_prompt: User input
            max_tokens: Maximum number of output tokens

        Returns:
            str: The complete response text
        """
        return "".join(self.chat_stream(system_prompt, user_prompt, max_tokens))