146 lines
5.1 KiB
Python
146 lines
5.1 KiB
Python
from typing import Any, Dict, Mapping, Optional, TypeVar
|
|
from urllib.parse import quote, urlparse, urlunparse
|
|
import logging
|
|
import orjson as json
|
|
import httpx
|
|
|
|
import chromadb.errors as errors
|
|
from chromadb.config import Component, Settings, System
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
# Inherits from Component so that __init__ receives the System; this lets the
# client build its HTTP connection limits from the settings held by System.
|
|
class BaseHTTPClient(Component):
    """Shared plumbing for HTTP-based clients.

    Stores the system settings, derives the ``httpx.Limits`` used for
    connection pooling from them, and offers static helpers for resolving the
    server URL, cleaning query parameters and translating error responses
    into exceptions.
    """

    _settings: Settings
    pre_flight_checks: Any = None
    # Keep-alive expiry (seconds) used when the settings do not provide one.
    DEFAULT_KEEPALIVE_SECS: float = 40.0

    def __init__(self, system: System):
        """Read the HTTP settings from ``system`` and build the pool limits."""
        super().__init__(system)
        self._settings = system.settings
        keepalive_setting = self._settings.chroma_http_keepalive_secs
        self.keepalive_secs: Optional[float] = (
            keepalive_setting
            if keepalive_setting is not None
            else BaseHTTPClient.DEFAULT_KEEPALIVE_SECS
        )
        self._http_limits = self._build_limits()

    def _build_limits(self) -> httpx.Limits:
        """Build ``httpx.Limits`` from the settings.

        Only values that are not ``None`` are passed on, so httpx's own
        defaults apply to anything left unset.
        """
        limit_kwargs: Dict[str, Any] = {}
        if self.keepalive_secs is not None:
            limit_kwargs["keepalive_expiry"] = self.keepalive_secs

        max_connections = self._settings.chroma_http_max_connections
        if max_connections is not None:
            limit_kwargs["max_connections"] = max_connections

        max_keepalive_connections = (
            self._settings.chroma_http_max_keepalive_connections
        )
        if max_keepalive_connections is not None:
            limit_kwargs["max_keepalive_connections"] = max_keepalive_connections

        return httpx.Limits(**limit_kwargs)

    @property
    def http_limits(self) -> httpx.Limits:
        """The connection-pool limits computed at construction time."""
        return self._http_limits

    @staticmethod
    def _validate_host(host: str) -> None:
        """Reject hosts that look like URLs but lack an http(s) scheme.

        A host without any ``/`` is accepted as-is.

        Raises:
            ValueError: if ``host`` contains ``/`` and its scheme is not
                ``http`` or ``https``.
        """
        parsed = urlparse(host)
        if "/" in host and parsed.scheme not in {"http", "https"}:
            raise ValueError(
                "Invalid URL. " f"Unrecognized protocol - {parsed.scheme}."
            )
        # NOTE(review): after the check above this branch appears reachable
        # only when urlparse normalises the scheme (e.g. "HTTP://...") while
        # the raw string does not start with lowercase "http" - confirm.
        if "/" in host and (not host.startswith("http")):
            # Adjacent literals instead of a backslash continuation inside the
            # string, which would embed source indentation in the message.
            raise ValueError(
                "Invalid URL. "
                "Seems that you are trying to pass URL as a host but without "
                "specifying the protocol. "
                "Please add http:// or https:// to the host."
            )

    @staticmethod
    def resolve_url(
        chroma_server_host: str,
        chroma_server_ssl_enabled: Optional[bool] = False,
        default_api_path: Optional[str] = "",
        chroma_server_http_port: Optional[int] = 8000,
    ) -> str:
        """Build the full server URL from a host (or full URL) and options.

        Args:
            chroma_server_host: Bare hostname, or a full ``http(s)://`` URL.
            chroma_server_ssl_enabled: Force the ``https`` scheme when truthy.
            default_api_path: Path used when the host has none, and appended
                when the host's path does not already end with it.
            chroma_server_http_port: Port appended when the host is not a
                full URL and carries no port of its own.

        Returns:
            The assembled URL string.

        Raises:
            ValueError: if the host fails ``_validate_host``.
        """
        BaseHTTPClient._validate_host(chroma_server_host)
        api_path = default_api_path or ""

        # Only an explicit scheme marks a full URL. A bare hostname that
        # merely begins with "http" (e.g. "httpbin.internal") keeps its port.
        skip_port = chroma_server_host.startswith(("http://", "https://"))
        if skip_port:
            logger.debug("Skipping port as the user is passing a full URL")
        parsed = urlparse(chroma_server_host)

        scheme = "https" if chroma_server_ssl_enabled else parsed.scheme or "http"
        net_loc = parsed.netloc or parsed.hostname or chroma_server_host

        port = ""
        if not skip_port:
            port_value = parsed.port or chroma_server_http_port
            # Omit the port entirely rather than rendering ":None".
            if port_value is not None:
                port = f":{port_value}"

        path = parsed.path or api_path
        # For a bare hostname urlparse reports the host as the path.
        if not path or path == net_loc:
            path = api_path
        if not path.endswith(api_path):
            path = path + api_path

        return urlunparse(
            (scheme, f"{net_loc}{port}", quote(path.replace("//", "/")), "", "", "")
        )

    # requests removes None values from the built query string, but httpx includes it as an empty value
    T = TypeVar("T", bound=Dict[Any, Any])

    @staticmethod
    def _clean_params(params: T) -> T:
        """Remove None values from provided dict."""
        return {k: v for k, v in params.items() if v is not None}  # type: ignore

    @staticmethod
    def _raise_chroma_error(resp: httpx.Response) -> None:
        """Raise if the response is not ok, using a ChromaError if possible.

        Returns ``None`` when ``resp.raise_for_status()`` does not raise.
        Otherwise the body is parsed as JSON and, when its ``"error"`` value
        is a key of ``errors.error_types``, the mapped error is raised with
        the body's ``"message"``. If no mapping is possible a plain
        ``Exception`` carrying the response text is raised. The
        ``chroma-trace-id`` response header, when present, is attached to the
        error or appended to the message.
        """
        try:
            resp.raise_for_status()
            return
        except httpx.HTTPStatusError as exc:
            # Rebind: the ``as`` name is cleared when the except block ends.
            status_error = exc

        trace_id = resp.headers.get("chroma-trace-id")

        chroma_error = None
        try:
            body = json.loads(resp.text)
            if "error" in body:
                if body["error"] in errors.error_types:
                    chroma_error = errors.error_types[body["error"]](body["message"])
                    if trace_id:
                        chroma_error.trace_id = trace_id
        except Exception:
            # Best effort: a non-JSON or unexpectedly shaped body falls through
            # to the generic error below. Exception (not BaseException) so
            # KeyboardInterrupt/SystemExit still propagate.
            logger.debug(
                "Could not map error response to a ChromaError", exc_info=True
            )

        if chroma_error is not None:
            raise chroma_error

        if trace_id:
            raise Exception(f"{resp.text} (trace ID: {trace_id})") from status_error
        raise Exception(resp.text) from status_error

    def get_request_headers(self) -> Mapping[str, str]:
        """Return headers used for HTTP requests."""
        return {}

    def get_api_url(self) -> str:
        """Return the API URL for this client."""
        return ""
|