import threading
import random
from typing import Callable, Dict, List, Optional, TypeVar

import grpc
from overrides import overrides

from chromadb.api.types import GetResult, Metadata, QueryResult
from chromadb.config import System
from chromadb.execution.executor.abstract import Executor
from chromadb.execution.expression.operator import Scan
from chromadb.execution.expression.plan import CountPlan, GetPlan, KNNPlan
from chromadb.proto import convert
from chromadb.proto.query_executor_pb2_grpc import QueryExecutorStub
from chromadb.segment.impl.manager.distributed import DistributedSegmentManager
from chromadb.telemetry.opentelemetry.grpc import OtelInterceptor
from tenacity import (
    RetryCallState,
    Retrying,
    stop_after_attempt,
    wait_exponential_jitter,
    retry_if_exception,
)
from opentelemetry.trace import Span


def _clean_metadata(metadata: Optional[Metadata]) -> Optional[Metadata]:
    """Remove any chroma-specific metadata keys that the client shouldn't see from a metadata map."""
    if not metadata:
        return None
    result = {}
    for k, v in metadata.items():
        if not k.startswith("chroma:"):
            result[k] = v
    if len(result) == 0:
        return None
    return result


def _uri(metadata: Optional[Metadata]) -> Optional[str]:
    """Retrieve the uri (if any) from a Metadata map"""
    if metadata and "chroma:uri" in metadata:
        return str(metadata["chroma:uri"])
    return None


# Type variables for input and output types of the round-robin retry function
I = TypeVar("I")  # noqa: E741
O = TypeVar("O")  # noqa: E741


class DistributedExecutor(Executor):
    _mtx: threading.Lock
    _grpc_stub_pool: Dict[str, QueryExecutorStub]
    _manager: DistributedSegmentManager
    _request_timeout_seconds: int
    _query_replication_factor: int

    def __init__(self, system: System):
        super().__init__(system)
        self._mtx = threading.Lock()
        self._grpc_stub_pool = {}
        self._manager = self.require(DistributedSegmentManager)
        self._request_timeout_seconds = system.settings.require(
            "chroma_query_request_timeout_seconds"
        )
        self._query_replication_factor = system.settings.require(
            "chroma_query_replication_factor"
        )

    def _round_robin_retry(self, funcs: List[Callable[[I], O]], args: I) -> O:
        """
        Retry a list of functions in a round-robin fashion until one of them succeeds.

        funcs: List of functions to retry
        args: Arguments to pass to each function
        """
        attempt_count = 0
        sleep_span: Optional[Span] = None

        def before_sleep(_: RetryCallState) -> None:
            # HACK(hammadb) 1/14/2024 - this is a hack to avoid the fact that tracer is not yet available and there are boot order issues
            # This should really use our component system to get the tracer. Since our grpc utils use this pattern
            # we are copying it here.
            # This should be removed once we have a better way to get the tracer
            from chromadb.telemetry.opentelemetry import tracer

            nonlocal sleep_span
            if tracer is not None:
                sleep_span = tracer.start_span("Waiting to retry RPC")

        for attempt in Retrying(
            stop=stop_after_attempt(5),
            wait=wait_exponential_jitter(0.1, jitter=0.1),
            reraise=True,
            retry=retry_if_exception(
                lambda x: isinstance(x, grpc.RpcError)
                and x.code()
                in [grpc.StatusCode.UNAVAILABLE, grpc.StatusCode.UNKNOWN]
            ),
            before_sleep=before_sleep,
        ):
            if sleep_span is not None:
                sleep_span.end()
                sleep_span = None
            with attempt:
                return funcs[attempt_count % len(funcs)](args)
            attempt_count += 1

        # NOTE(hammadb) because Retrying() will always either return or raise an exception, this line should never be reached
        raise Exception("Unreachable code error - should never reach here")

    @overrides
    def count(self, plan: CountPlan) -> int:
        endpoints = self._get_grpc_endpoints(plan.scan)
        count_funcs = [self._get_stub(endpoint).Count for endpoint in endpoints]
        count_result = self._round_robin_retry(
            count_funcs, convert.to_proto_count_plan(plan)
        )
        return convert.from_proto_count_result(count_result)

    @overrides
    def get(self, plan: GetPlan) -> GetResult:
        endpoints = self._get_grpc_endpoints(plan.scan)
        get_funcs = [self._get_stub(endpoint).Get for endpoint in endpoints]
        get_result = self._round_robin_retry(get_funcs, convert.to_proto_get_plan(plan))
        records = convert.from_proto_get_result(get_result)

        ids = [record["id"] for record in records]
        embeddings = (
            [record["embedding"] for record in records]
            if plan.projection.embedding
            else None
        )
        documents = (
            [record["document"] for record in records]
            if plan.projection.document
            else None
        )
        uris = (
            [_uri(record["metadata"]) for record in records]
            if plan.projection.uri
            else None
        )
        metadatas = (
            [_clean_metadata(record["metadata"]) for record in records]
            if plan.projection.metadata
            else None
        )

        # TODO: Fix typing
        return GetResult(
            ids=ids,
            embeddings=embeddings,  # type: ignore[typeddict-item]
            documents=documents,  # type: ignore[typeddict-item]
            uris=uris,  # type: ignore[typeddict-item]
            data=None,
            metadatas=metadatas,  # type: ignore[typeddict-item]
            included=plan.projection.included,
        )

    @overrides
    def knn(self, plan: KNNPlan) -> QueryResult:
        endpoints = self._get_grpc_endpoints(plan.scan)
        knn_funcs = [self._get_stub(endpoint).KNN for endpoint in endpoints]
        knn_result = self._round_robin_retry(knn_funcs, convert.to_proto_knn_plan(plan))
        results = convert.from_proto_knn_batch_result(knn_result)

        ids = [[record["record"]["id"] for record in records] for records in results]
        embeddings = (
            [
                [record["record"]["embedding"] for record in records]
                for records in results
            ]
            if plan.projection.embedding
            else None
        )
        documents = (
            [
                [record["record"]["document"] for record in records]
                for records in results
            ]
            if plan.projection.document
            else None
        )
        uris = (
            [
                [_uri(record["record"]["metadata"]) for record in records]
                for records in results
            ]
            if plan.projection.uri
            else None
        )
        metadatas = (
            [
                [_clean_metadata(record["record"]["metadata"]) for record in records]
                for records in results
            ]
            if plan.projection.metadata
            else None
        )
        distances = (
            [[record["distance"] for record in records] for records in results]
            if plan.projection.rank
            else None
        )

        # TODO: Fix typing
        return QueryResult(
            ids=ids,
            embeddings=embeddings,  # type: ignore[typeddict-item]
            documents=documents,  # type: ignore[typeddict-item]
            uris=uris,  # type: ignore[typeddict-item]
            data=None,
            metadatas=metadatas,  # type: ignore[typeddict-item]
            distances=distances,  # type: ignore[typeddict-item]
            included=plan.projection.included,
        )

    def _get_grpc_endpoints(self, scan: Scan) -> List[str]:
        # Since the grpc endpoint is determined by the collection uuid,
        # the endpoint should be the same for all segments of the same collection
        grpc_urls = self._manager.get_endpoints(
            scan.record, self._query_replication_factor
        )
        # Shuffle the grpc urls to distribute the load evenly
        random.shuffle(grpc_urls)
        return grpc_urls

    def _get_stub(self, grpc_url: str) -> QueryExecutorStub:
        with self._mtx:
            if grpc_url not in self._grpc_stub_pool:
                channel = grpc.insecure_channel(
                    grpc_url,
                    options=[
                        ("grpc.max_concurrent_streams", 1000),
                        ("grpc.max_receive_message_length", 32000000),  # 32 MB
                    ],
                )
                interceptors = [OtelInterceptor()]
                channel = grpc.intercept_channel(channel, *interceptors)
                self._grpc_stub_pool[grpc_url] = QueryExecutorStub(channel)
            return self._grpc_stub_pool[grpc_url]
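
# Illustrative sketch only -- not used by DistributedExecutor. It shows how
# _round_robin_retry rotates across per-replica callables: a retriable RpcError
# (UNAVAILABLE/UNKNOWN) on one replica causes the next attempt to go to the next
# replica. The helper function, the fake error class, and the "replica" callables
# below are hypothetical and exist purely for demonstration.
def _example_round_robin_rotation(executor: DistributedExecutor) -> str:
    class _FakeUnavailableError(grpc.RpcError):
        # Simulates a transient, retriable gRPC failure (status UNAVAILABLE).
        def code(self) -> grpc.StatusCode:
            return grpc.StatusCode.UNAVAILABLE

    def replica_a(arg: str) -> str:
        raise _FakeUnavailableError()

    def replica_b(arg: str) -> str:
        return f"ok:{arg}"

    # Attempt 1 calls replica_a and fails with a retriable error; the retry loop
    # then rotates to replica_b, whose result is returned.
    return executor._round_robin_retry([replica_a, replica_b], "req")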