# 278 lines, 9.0 KiB, Python
from __future__ import annotations
|
|
|
|
from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
|
|
from functools import cached_property
|
|
from typing import (
|
|
Any,
|
|
)
|
|
|
|
from langchain_core.runnables import Runnable, RunnableConfig
|
|
|
|
from langgraph._internal._config import merge_configs
|
|
from langgraph._internal._constants import CONF, CONFIG_KEY_READ
|
|
from langgraph._internal._runnable import RunnableCallable, RunnableSeq
|
|
from langgraph.pregel._utils import find_subgraph_pregel
|
|
from langgraph.pregel._write import ChannelWrite
|
|
from langgraph.pregel.protocol import PregelProtocol
|
|
from langgraph.types import CachePolicy, RetryPolicy
|
|
|
|
# Signature of the read function stored in config under CONFIG_KEY_READ:
# takes a channel name (or sequence of names) and a `fresh` flag, and
# returns the channel value (or a dict of channel name -> value).
READ_TYPE = Callable[[str | Sequence[str], bool], Any | dict[str, Any]]


# Cache key identifying a node's input: (mapper callable, tuple of channel names).
INPUT_CACHE_KEY_TYPE = tuple[Callable[..., Any], tuple[str, ...]]
|
|
|
|
|
|
class ChannelRead(RunnableCallable):
    """Implements the logic for reading state from CONFIG_KEY_READ.

    Usable both as a runnable as well as a static method to call imperatively."""

    channel: str | list[str]

    fresh: bool = False

    mapper: Callable[[Any], Any] | None = None

    def __init__(
        self,
        channel: str | list[str],
        *,
        fresh: bool = False,
        mapper: Callable[[Any], Any] | None = None,
        tags: list[str] | None = None,
    ) -> None:
        """Create a reader for the given channel(s).

        Args:
            channel: Channel name, or list of names, to read.
            fresh: Whether to read the freshest value rather than any cached one.
            mapper: Optional transform applied to the value before returning it.
            tags: Tracing tags to attach to this runnable.
        """
        super().__init__(
            func=self._read,
            afunc=self._aread,
            tags=tags,
            name=None,
            trace=False,
        )
        self.fresh = fresh
        self.mapper = mapper
        self.channel = channel

    def get_name(self, suffix: str | None = None, *, name: str | None = None) -> str:
        """Return a display name that includes the channel(s) being read."""
        if not name:
            # Derive a descriptive default name from the channel selection.
            if isinstance(self.channel, str):
                name = f"ChannelRead<{self.channel}>"
            else:
                name = f"ChannelRead<{','.join(self.channel)}>"
        return super().get_name(suffix, name=name)

    def _read(self, _: Any, config: RunnableConfig) -> Any:
        return self.do_read(
            config, select=self.channel, fresh=self.fresh, mapper=self.mapper
        )

    async def _aread(self, _: Any, config: RunnableConfig) -> Any:
        # Reading is synchronous; the async path reuses do_read directly.
        return self.do_read(
            config, select=self.channel, fresh=self.fresh, mapper=self.mapper
        )

    @staticmethod
    def do_read(
        config: RunnableConfig,
        *,
        select: str | list[str],
        fresh: bool = False,
        mapper: Callable[[Any], Any] | None = None,
    ) -> Any:
        """Read the selected channel(s) using the read function from config.

        Args:
            config: Runnable config expected to carry CONFIG_KEY_READ under CONF.
            select: Channel name, or list of names, to read.
            fresh: Whether to read the freshest value.
            mapper: Optional transform applied to the value(s) read.

        Raises:
            RuntimeError: If the config carries no read function, i.e. this is
                called outside the context of a Pregel process.
        """
        try:
            read: READ_TYPE = config[CONF][CONFIG_KEY_READ]
        except KeyError:
            # Fix: the two literals previously concatenated without a
            # separator, yielding "...read functionMake sure...".
            # `from None` hides the unhelpful KeyError traceback context.
            raise RuntimeError(
                "Not configured with a read function. "
                "Make sure to call in the context of a Pregel process"
            ) from None
        if mapper:
            return mapper(read(select, fresh))
        else:
            return read(select, fresh)
|
|
|
|
|
|
# Identity runnable used when a PregelNode is constructed without `bound`;
# passes its input through unchanged. Compared with `is` elsewhere, so this
# must remain a single shared sentinel instance.
DEFAULT_BOUND = RunnableCallable(lambda input: input)
|
|
|
|
|
|
class PregelNode:
    """A node in a Pregel graph. This won't be invoked as a runnable by the graph
    itself, but instead acts as a container for the components necessary to make
    a PregelExecutableTask for a node."""

    channels: str | list[str]
    """The channels that will be passed as input to `bound`.
    If a str, the node will be invoked with its value if it isn't empty.
    If a list, the node will be invoked with a dict of those channels' values."""

    triggers: list[str]
    """If any of these channels is written to, this node will be triggered in
    the next step."""

    mapper: Callable[[Any], Any] | None
    """A function to transform the input before passing it to `bound`."""

    writers: list[Runnable]
    """A list of writers that will be executed after `bound`, responsible for
    taking the output of `bound` and writing it to the appropriate channels."""

    bound: Runnable[Any, Any]
    """The main logic of the node. This will be invoked with the input from
    `channels`."""

    retry_policy: Sequence[RetryPolicy] | None
    """The retry policies to use when invoking the node."""

    cache_policy: CachePolicy | None
    """The cache policy to use when invoking the node."""

    tags: Sequence[str] | None
    """Tags to attach to the node for tracing."""

    metadata: Mapping[str, Any] | None
    """Metadata to attach to the node for tracing."""

    subgraphs: Sequence[PregelProtocol]
    """Subgraphs used by the node."""

    def __init__(
        self,
        *,
        channels: str | list[str],
        triggers: Sequence[str],
        mapper: Callable[[Any], Any] | None = None,
        writers: list[Runnable] | None = None,
        tags: list[str] | None = None,
        metadata: Mapping[str, Any] | None = None,
        bound: Runnable[Any, Any] | None = None,
        retry_policy: RetryPolicy | Sequence[RetryPolicy] | None = None,
        cache_policy: CachePolicy | None = None,
        subgraphs: Sequence[PregelProtocol] | None = None,
    ) -> None:
        """Initialize the node container.

        Args:
            channels: Channel name(s) read as input for `bound`.
            triggers: Channels whose writes trigger this node.
            mapper: Optional input transform applied before `bound`.
            writers: Runnables executed after `bound` to write its output.
            tags: Tracing tags.
            metadata: Tracing metadata.
            bound: Main node logic; defaults to the identity DEFAULT_BOUND.
            retry_policy: Single policy or sequence of policies; a single
                policy is normalized to a one-element tuple.
            cache_policy: Cache policy for node invocations.
            subgraphs: Explicit subgraphs; if omitted, one is auto-detected
                from `bound` on a best-effort basis.
        """
        self.channels = channels
        self.triggers = list(triggers)
        self.mapper = mapper
        self.writers = writers or []
        self.bound = bound if bound is not None else DEFAULT_BOUND
        self.cache_policy = cache_policy
        # Normalize so retry_policy is always either None or a sequence.
        if isinstance(retry_policy, RetryPolicy):
            self.retry_policy = (retry_policy,)
        else:
            self.retry_policy = retry_policy
        self.tags = tags
        self.metadata = metadata
        if subgraphs is not None:
            # Caller provided subgraphs explicitly; trust them as-is.
            self.subgraphs = subgraphs
        elif self.bound is not DEFAULT_BOUND:
            # Best-effort auto-detection of a subgraph inside `bound`;
            # any failure during detection is treated as "no subgraph".
            try:
                subgraph = find_subgraph_pregel(self.bound)
            except Exception:
                subgraph = None
            if subgraph:
                self.subgraphs = [subgraph]
            else:
                self.subgraphs = []
        else:
            self.subgraphs = []

    def copy(self, update: dict[str, Any]) -> PregelNode:
        """Return a new PregelNode with `update` overriding current attributes.

        Built from this instance's `__dict__`, so the popped keys below must
        match the cached_property names exactly to avoid passing stale cache
        entries to `__init__`.
        """
        attrs = {**self.__dict__, **update}
        # Drop the cached properties
        attrs.pop("flat_writers", None)
        attrs.pop("node", None)
        attrs.pop("input_cache_key", None)
        return PregelNode(**attrs)

    @cached_property
    def flat_writers(self) -> list[Runnable]:
        """Get writers with optimizations applied. Dedupes consecutive ChannelWrites."""
        writers = self.writers.copy()
        # Merge trailing consecutive ChannelWrites pairwise until at most
        # one remains at the tail.
        while (
            len(writers) > 1
            and isinstance(writers[-1], ChannelWrite)
            and isinstance(writers[-2], ChannelWrite)
        ):
            # we can combine writes if they are consecutive
            # careful to not modify the original writers list or ChannelWrite
            writers[-2] = ChannelWrite(
                writes=writers[-2].writes + writers[-1].writes,
            )
            writers.pop()
        return writers

    @cached_property
    def node(self) -> Runnable[Any, Any] | None:
        """Get a runnable that combines `bound` and `writers`."""
        writers = self.flat_writers
        if self.bound is DEFAULT_BOUND and not writers:
            # Nothing to run at all.
            return None
        elif self.bound is DEFAULT_BOUND and len(writers) == 1:
            # Single writer and identity bound: no sequence wrapper needed.
            return writers[0]
        elif self.bound is DEFAULT_BOUND:
            return RunnableSeq(*writers)
        elif writers:
            return RunnableSeq(self.bound, *writers)
        else:
            return self.bound

    @cached_property
    def input_cache_key(self) -> INPUT_CACHE_KEY_TYPE:
        """Get a cache key for the input to the node.
        This is used to avoid calculating the same input multiple times."""
        return (
            self.mapper,
            # Normalize to a tuple of channel names in both str/list cases.
            tuple(self.channels)
            if isinstance(self.channels, list)
            else (self.channels,),
        )

    def invoke(
        self,
        input: Any,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> Any:
        """Invoke `bound` with this node's tags/metadata merged into config."""
        self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
        return self.bound.invoke(
            input,
            merge_configs(self_config, config),
            **kwargs,
        )

    async def ainvoke(
        self,
        input: Any,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> Any:
        """Async-invoke `bound` with this node's tags/metadata merged into config."""
        self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
        return await self.bound.ainvoke(
            input,
            merge_configs(self_config, config),
            **kwargs,
        )

    def stream(
        self,
        input: Any,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> Iterator[Any]:
        """Stream from `bound` with this node's tags/metadata merged into config."""
        self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
        yield from self.bound.stream(
            input,
            merge_configs(self_config, config),
            **kwargs,
        )

    async def astream(
        self,
        input: Any,
        config: RunnableConfig | None = None,
        **kwargs: Any | None,
    ) -> AsyncIterator[Any]:
        """Async-stream from `bound` with this node's tags/metadata merged into config."""
        self_config: RunnableConfig = {"metadata": self.metadata, "tags": self.tags}
        async for item in self.bound.astream(
            input,
            merge_configs(self_config, config),
            **kwargs,
        ):
            yield item
|