"""Utilities for batching operations in a background task.""" from __future__ import annotations import asyncio import functools import weakref from collections.abc import Callable, Iterable from typing import Any, Literal, TypeVar from langgraph.store.base import ( NOT_PROVIDED, BaseStore, GetOp, Item, ListNamespacesOp, MatchCondition, NamespacePath, NotProvided, Op, PutOp, Result, SearchItem, SearchOp, _ensure_refresh, _ensure_ttl, _validate_namespace, ) F = TypeVar("F", bound=Callable) def _check_loop(func: F) -> F: @functools.wraps(func) def wrapper(store: AsyncBatchedBaseStore, *args: Any, **kwargs: Any) -> Any: method_name: str = func.__name__ try: current_loop = asyncio.get_running_loop() if current_loop is store._loop: replacement_str = ( f"Specifically, replace `store.{method_name}(...)` with `await store.a{method_name}(...)" if method_name else "For example, replace `store.get(...)` with `await store.aget(...)`" ) raise asyncio.InvalidStateError( f"Synchronous calls to {store.__class__.__name__} detected in the main event loop. " "This can lead to deadlocks or performance issues. " "Please use the asynchronous interface for main thread operations. " f"{replacement_str} " ) except RuntimeError: pass return func(store, *args, **kwargs) return wrapper class AsyncBatchedBaseStore(BaseStore): """Efficiently batch operations in a background task.""" __slots__ = ("_loop", "_aqueue", "_task") def __init__(self) -> None: super().__init__() self._loop = asyncio.get_running_loop() self._aqueue: asyncio.Queue[tuple[asyncio.Future, Op]] = asyncio.Queue() self._task: asyncio.Task | None = None self._ensure_task() def __del__(self) -> None: try: if self._task: self._task.cancel() except RuntimeError: pass def _ensure_task(self) -> None: """Ensure the background processing loop is running.""" if self._task is None or self._task.done(): self._task = self._loop.create_task(_run(self._aqueue, weakref.ref(self))) async def aget( self, namespace: tuple[str, ...], key: str, *, refresh_ttl: bool | None = None, ) -> Item | None: self._ensure_task() fut = self._loop.create_future() self._aqueue.put_nowait( ( fut, GetOp( namespace, key, refresh_ttl=_ensure_refresh(self.ttl_config, refresh_ttl), ), ) ) return await fut async def asearch( self, namespace_prefix: tuple[str, ...], /, *, query: str | None = None, filter: dict[str, Any] | None = None, limit: int = 10, offset: int = 0, refresh_ttl: bool | None = None, ) -> list[SearchItem]: self._ensure_task() fut = self._loop.create_future() self._aqueue.put_nowait( ( fut, SearchOp( namespace_prefix, filter, limit, offset, query, refresh_ttl=_ensure_refresh(self.ttl_config, refresh_ttl), ), ) ) return await fut async def aput( self, namespace: tuple[str, ...], key: str, value: dict[str, Any], index: Literal[False] | list[str] | None = None, *, ttl: float | None | NotProvided = NOT_PROVIDED, ) -> None: self._ensure_task() _validate_namespace(namespace) fut = self._loop.create_future() self._aqueue.put_nowait( ( fut, PutOp( namespace, key, value, index, ttl=_ensure_ttl(self.ttl_config, ttl) ), ) ) return await fut async def adelete( self, namespace: tuple[str, ...], key: str, ) -> None: self._ensure_task() fut = self._loop.create_future() self._aqueue.put_nowait((fut, PutOp(namespace, key, None))) return await fut async def alist_namespaces( self, *, prefix: NamespacePath | None = None, suffix: NamespacePath | None = None, max_depth: int | None = None, limit: int = 100, offset: int = 0, ) -> list[tuple[str, ...]]: self._ensure_task() fut = self._loop.create_future() match_conditions = [] if prefix: match_conditions.append(MatchCondition(match_type="prefix", path=prefix)) if suffix: match_conditions.append(MatchCondition(match_type="suffix", path=suffix)) op = ListNamespacesOp( match_conditions=tuple(match_conditions), max_depth=max_depth, limit=limit, offset=offset, ) self._aqueue.put_nowait((fut, op)) return await fut @_check_loop def batch(self, ops: Iterable[Op]) -> list[Result]: return asyncio.run_coroutine_threadsafe(self.abatch(ops), self._loop).result() @_check_loop def get( self, namespace: tuple[str, ...], key: str, *, refresh_ttl: bool | None = None, ) -> Item | None: return asyncio.run_coroutine_threadsafe( self.aget(namespace, key=key, refresh_ttl=refresh_ttl), self._loop ).result() @_check_loop def search( self, namespace_prefix: tuple[str, ...], /, *, query: str | None = None, filter: dict[str, Any] | None = None, limit: int = 10, offset: int = 0, refresh_ttl: bool | None = None, ) -> list[SearchItem]: return asyncio.run_coroutine_threadsafe( self.asearch( namespace_prefix, query=query, filter=filter, limit=limit, offset=offset, refresh_ttl=refresh_ttl, ), self._loop, ).result() @_check_loop def put( self, namespace: tuple[str, ...], key: str, value: dict[str, Any], index: Literal[False] | list[str] | None = None, *, ttl: float | None | NotProvided = NOT_PROVIDED, ) -> None: _validate_namespace(namespace) asyncio.run_coroutine_threadsafe( self.aput( namespace, key=key, value=value, index=index, ttl=_ensure_ttl(self.ttl_config, ttl), ), self._loop, ).result() @_check_loop def delete( self, namespace: tuple[str, ...], key: str, ) -> None: asyncio.run_coroutine_threadsafe( self.adelete(namespace, key=key), self._loop ).result() @_check_loop def list_namespaces( self, *, prefix: NamespacePath | None = None, suffix: NamespacePath | None = None, max_depth: int | None = None, limit: int = 100, offset: int = 0, ) -> list[tuple[str, ...]]: return asyncio.run_coroutine_threadsafe( self.alist_namespaces( prefix=prefix, suffix=suffix, max_depth=max_depth, limit=limit, offset=offset, ), self._loop, ).result() def _dedupe_ops(values: list[Op]) -> tuple[list[int] | None, list[Op]]: """Dedupe operations while preserving order for results. Args: values: List of operations to dedupe Returns: Tuple of (listen indices, deduped operations) where listen indices map deduped operation results back to original positions """ if len(values) <= 1: return None, list(values) dedupped: list[Op] = [] listen: list[int] = [] puts: dict[tuple[tuple[str, ...], str], int] = {} for op in values: if isinstance(op, (GetOp, SearchOp, ListNamespacesOp)): try: listen.append(dedupped.index(op)) except ValueError: listen.append(len(dedupped)) dedupped.append(op) elif isinstance(op, PutOp): putkey = (op.namespace, op.key) if putkey in puts: # Overwrite previous put ix = puts[putkey] dedupped[ix] = op listen.append(ix) else: puts[putkey] = len(dedupped) listen.append(len(dedupped)) dedupped.append(op) else: # Any new ops will be treated regularly listen.append(len(dedupped)) dedupped.append(op) return listen, dedupped async def _run( aqueue: asyncio.Queue[tuple[asyncio.Future, Op]], store: weakref.ReferenceType[BaseStore], ) -> None: while item := await aqueue.get(): # check if store is still alive if s := store(): try: # accumulate operations scheduled in same tick items = [item] try: while item := aqueue.get_nowait(): items.append(item) except asyncio.QueueEmpty: pass # get the operations to run futs = [item[0] for item in items] values = [item[1] for item in items] # action each operation try: listen, dedupped = _dedupe_ops(values) results = await s.abatch(dedupped) if listen is not None: results = [results[ix] for ix in listen] # set the results of each operation for fut, result in zip(futs, results, strict=False): # guard against future being done (e.g. cancelled) if not fut.done(): fut.set_result(result) except Exception as e: for fut in futs: # guard against future being done (e.g. cancelled) if not fut.done(): fut.set_exception(e) finally: # remove strong ref to store del s else: break