# group-wbl/.venv/lib/python3.13/site-packages/langgraph/_internal/_runnable.py

from __future__ import annotations
import asyncio
import enum
import inspect
import sys
import warnings
from collections.abc import (
AsyncIterator,
Awaitable,
Callable,
Coroutine,
Generator,
Iterator,
Sequence,
)
from contextlib import AsyncExitStack, contextmanager
from contextvars import Context, Token, copy_context
from functools import partial, wraps
from typing import (
Any,
Optional,
Protocol,
TypeGuard,
cast,
)
from langchain_core.runnables.base import (
Runnable,
RunnableConfig,
RunnableLambda,
RunnableParallel,
RunnableSequence,
)
from langchain_core.runnables.base import (
RunnableLike as LCRunnableLike,
)
from langchain_core.runnables.config import (
run_in_executor,
var_child_runnable_config,
)
from langchain_core.runnables.utils import Input, Output
from langchain_core.tracers.langchain import LangChainTracer
from langgraph.store.base import BaseStore
from langgraph._internal._config import (
ensure_config,
get_async_callback_manager_for_config,
get_callback_manager_for_config,
patch_config,
)
from langgraph._internal._constants import (
CONF,
CONFIG_KEY_RUNTIME,
)
from langgraph._internal._typing import MISSING
from langgraph.types import StreamWriter
try:
from langchain_core.tracers._streaming import _StreamingCallbackHandler
except ImportError:
_StreamingCallbackHandler = None # type: ignore
def _set_config_context(
config: RunnableConfig, run: Any = None
) -> Token[RunnableConfig | None]:
"""Set the child Runnable config + tracing context.
Args:
config: The config to set.
"""
config_token = var_child_runnable_config.set(config)
if run is not None:
from langsmith.run_helpers import _set_tracing_context
_set_tracing_context({"parent": run})
return config_token
def _unset_config_context(token: Token[RunnableConfig | None], run: Any = None) -> None:
"""Set the child Runnable config + tracing context.
Args:
token: The config token to reset.
"""
var_child_runnable_config.reset(token)
if run is not None:
from langsmith.run_helpers import _set_tracing_context
_set_tracing_context(
{
"parent": None,
"project_name": None,
"tags": None,
"metadata": None,
"enabled": None,
"client": None,
}
)
@contextmanager
def set_config_context(
config: RunnableConfig, run: Any = None
) -> Generator[Context, None, None]:
"""Set the child Runnable config + tracing context.
Args:
config: The config to set.
"""
ctx = copy_context()
config_token = ctx.run(_set_config_context, config, run)
try:
yield ctx
finally:
ctx.run(_unset_config_context, config_token, run)
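# Illustrative sketch (not part of this module): how `set_config_context` is
# typically used -- copy the current context, set the child Runnable config
# inside the copy, and run a callable in it so `var_child_runnable_config`
# is visible to nested Runnables. `_my_step` is a hypothetical function.
def _example_set_config_context() -> int:
    config = ensure_config()

    def _my_step(x: int) -> int:
        # Within context.run, var_child_runnable_config.get() returns `config`.
        return x + 1

    with set_config_context(config) as context:
        return context.run(_my_step, 41)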
# Before Python 3.11, the native StrEnum is not available
class StrEnum(str, enum.Enum):
"""A string enum."""
# Special type to denote any type is accepted
ANY_TYPE = object()
ASYNCIO_ACCEPTS_CONTEXT = sys.version_info >= (3, 11)
# List of keyword arguments that can be injected into nodes / tasks / tools at runtime.
# A named argument may appear multiple times if it appears with distinct types.
KWARGS_CONFIG_KEYS: tuple[tuple[str, tuple[Any, ...], str, Any], ...] = (
(
"config",
(
RunnableConfig,
"RunnableConfig",
Optional[RunnableConfig], # noqa: UP045
"Optional[RunnableConfig]",
inspect.Parameter.empty,
),
# for now, use config directly; eventually this will be pulled off of Runtime
"N/A",
inspect.Parameter.empty,
),
(
"writer",
(StreamWriter, "StreamWriter", inspect.Parameter.empty),
"stream_writer",
lambda _: None,
),
(
"store",
(
BaseStore,
"BaseStore",
inspect.Parameter.empty,
),
"store",
inspect.Parameter.empty,
),
(
"store",
(
Optional[BaseStore], # noqa: UP045
"Optional[BaseStore]",
),
"store",
None,
),
(
"previous",
(ANY_TYPE,),
"previous",
inspect.Parameter.empty,
),
(
"runtime",
(ANY_TYPE,),
# we never hit this block, we just inject runtime directly
"N/A",
inspect.Parameter.empty,
),
)
"""List of kwargs that can be passed to functions, and their corresponding
config keys, default values and type annotations.
Used to configure keyword arguments that can be injected at runtime
from the `Runtime` object as kwargs to `invoke`, `ainvoke`, `stream` and `astream`.
For a keyword to be injected from the config object, the function signature
must contain a kwarg with the same name and a matching type annotation.
Each tuple contains:
- the name of the kwarg in the function signature
- the type annotation(s) for the kwarg
- the `Runtime` attribute for fetching the value (N/A if not applicable)
This is fully internal and should be further refactored to use `get_type_hints`
to resolve forward references and optional types formatted like BaseStore | None.
"""
VALID_KINDS = (inspect.Parameter.POSITIONAL_OR_KEYWORD, inspect.Parameter.KEYWORD_ONLY)
class _RunnableWithWriter(Protocol[Input, Output]):
def __call__(self, state: Input, *, writer: StreamWriter) -> Output: ...
class _RunnableWithStore(Protocol[Input, Output]):
def __call__(self, state: Input, *, store: BaseStore) -> Output: ...
class _RunnableWithWriterStore(Protocol[Input, Output]):
def __call__(
self, state: Input, *, writer: StreamWriter, store: BaseStore
) -> Output: ...
class _RunnableWithConfigWriter(Protocol[Input, Output]):
def __call__(
self, state: Input, *, config: RunnableConfig, writer: StreamWriter
) -> Output: ...
class _RunnableWithConfigStore(Protocol[Input, Output]):
def __call__(
self, state: Input, *, config: RunnableConfig, store: BaseStore
) -> Output: ...
class _RunnableWithConfigWriterStore(Protocol[Input, Output]):
def __call__(
self,
state: Input,
*,
config: RunnableConfig,
writer: StreamWriter,
store: BaseStore,
) -> Output: ...
RunnableLike = (
LCRunnableLike
| _RunnableWithWriter[Input, Output]
| _RunnableWithStore[Input, Output]
| _RunnableWithWriterStore[Input, Output]
| _RunnableWithConfigWriter[Input, Output]
| _RunnableWithConfigStore[Input, Output]
| _RunnableWithConfigWriterStore[Input, Output]
)
class RunnableCallable(Runnable):
"""A much simpler version of RunnableLambda that requires sync and async functions."""
def __init__(
self,
func: Callable[..., Any | Runnable] | None,
afunc: Callable[..., Awaitable[Any | Runnable]] | None = None,
*,
name: str | None = None,
tags: Sequence[str] | None = None,
trace: bool = True,
recurse: bool = True,
explode_args: bool = False,
**kwargs: Any,
) -> None:
self.name = name
if self.name is None:
if func:
try:
if func.__name__ != "<lambda>":
self.name = func.__name__
except AttributeError:
pass
elif afunc:
try:
self.name = afunc.__name__
except AttributeError:
pass
self.func = func
self.afunc = afunc
self.tags = tags
self.kwargs = kwargs
self.trace = trace
self.recurse = recurse
self.explode_args = explode_args
# check signature
if func is None and afunc is None:
raise ValueError("At least one of func or afunc must be provided.")
self.func_accepts: dict[str, tuple[str, Any]] = {}
params = inspect.signature(cast(Callable, func or afunc)).parameters
for kw, typ, runtime_key, default in KWARGS_CONFIG_KEYS:
p = params.get(kw)
if p is None or p.kind not in VALID_KINDS:
# If parameter is not found or is not a valid kind, skip
continue
if typ != (ANY_TYPE,) and p.annotation not in typ:
# A specific type is required, but the function annotation does
# not match the expected type.
# If this is a config parameter with incorrect typing, emit a warning
# because we used to support any type but are moving towards more correct typing
if kw == "config" and p.annotation != inspect.Parameter.empty:
warnings.warn(
f"The 'config' parameter should be typed as 'RunnableConfig' or "
f"'RunnableConfig | None', not '{p.annotation}'. ",
UserWarning,
stacklevel=4,
)
continue
# If the kwarg is accepted by the function, store the key / runtime attribute to inject
self.func_accepts[kw] = (runtime_key, default)
def __repr__(self) -> str:
repr_args = {
k: v
for k, v in self.__dict__.items()
if k not in {"name", "func", "afunc", "config", "kwargs", "trace"}
}
return f"{self.get_name()}({', '.join(f'{k}={v!r}' for k, v in repr_args.items())})"
def invoke(
self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
) -> Any:
if self.func is None:
raise TypeError(
f'No synchronous function provided to "{self.name}".'
"\nEither initialize with a synchronous function or invoke"
" via the async API (ainvoke, astream, etc.)"
)
if config is None:
config = ensure_config()
if self.explode_args:
args, _kwargs = input
kwargs = {**self.kwargs, **_kwargs, **kwargs}
else:
args = (input,)
kwargs = {**self.kwargs, **kwargs}
runtime = config.get(CONF, {}).get(CONFIG_KEY_RUNTIME)
for kw, (runtime_key, default) in self.func_accepts.items():
# If the kwarg is already set, use the set value
if kw in kwargs:
continue
kw_value: Any = MISSING
if kw == "config":
kw_value = config
elif runtime:
if kw == "runtime":
kw_value = runtime
else:
try:
kw_value = getattr(runtime, runtime_key)
except AttributeError:
pass
if kw_value is MISSING:
if default is inspect.Parameter.empty:
raise ValueError(
f"Missing required config key '{runtime_key}' for '{self.name}'."
)
kw_value = default
kwargs[kw] = kw_value
if self.trace:
callback_manager = get_callback_manager_for_config(config, self.tags)
run_manager = callback_manager.on_chain_start(
None,
input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
# get the run
for h in run_manager.handlers:
if isinstance(h, LangChainTracer):
run = h.run_map.get(str(run_manager.run_id))
break
else:
run = None
# run in context
with set_config_context(child_config, run) as context:
ret = context.run(self.func, *args, **kwargs)
except BaseException as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(ret)
else:
ret = self.func(*args, **kwargs)
if self.recurse and isinstance(ret, Runnable):
return ret.invoke(input, config)
return ret
async def ainvoke(
self, input: Any, config: RunnableConfig | None = None, **kwargs: Any
) -> Any:
if not self.afunc:
return self.invoke(input, config)
if config is None:
config = ensure_config()
if self.explode_args:
args, _kwargs = input
kwargs = {**self.kwargs, **_kwargs, **kwargs}
else:
args = (input,)
kwargs = {**self.kwargs, **kwargs}
runtime = config.get(CONF, {}).get(CONFIG_KEY_RUNTIME)
for kw, (runtime_key, default) in self.func_accepts.items():
# If the kwarg has already been set, use the set value
if kw in kwargs:
continue
kw_value: Any = MISSING
if kw == "config":
kw_value = config
elif runtime:
if kw == "runtime":
kw_value = runtime
else:
try:
kw_value = getattr(runtime, runtime_key)
except AttributeError:
pass
if kw_value is MISSING:
if default is inspect.Parameter.empty:
raise ValueError(
f"Missing required config key '{runtime_key}' for '{self.name}'."
)
kw_value = default
kwargs[kw] = kw_value
if self.trace:
callback_manager = get_async_callback_manager_for_config(config, self.tags)
run_manager = await callback_manager.on_chain_start(
None,
input,
name=config.get("run_name") or self.name,
run_id=config.pop("run_id", None),
)
try:
child_config = patch_config(config, callbacks=run_manager.get_child())
coro = cast(Coroutine[None, None, Any], self.afunc(*args, **kwargs))
if ASYNCIO_ACCEPTS_CONTEXT:
for h in run_manager.handlers:
if isinstance(h, LangChainTracer):
run = h.run_map.get(str(run_manager.run_id))
break
else:
run = None
with set_config_context(child_config, run) as context:
ret = await asyncio.create_task(coro, context=context)
else:
ret = await coro
except BaseException as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(ret)
else:
ret = await self.afunc(*args, **kwargs)
if self.recurse and isinstance(ret, Runnable):
return await ret.ainvoke(input, config)
return ret
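# Illustrative sketch (not part of this module): constructing a
# `RunnableCallable` from a sync/async pair of hypothetical functions.
# With the default `recurse=True`, a returned Runnable would itself be
# invoked with the same input; here plain ints are returned.
def _example_runnable_callable() -> None:
    def _add_one(x: int) -> int:
        return x + 1

    async def _aadd_one(x: int) -> int:
        return x + 1

    node = RunnableCallable(_add_one, _aadd_one, name="add_one", trace=False)
    assert node.invoke(1) == 2  # sync path; `ainvoke` would use `_aadd_one`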
def is_async_callable(
func: Any,
) -> TypeGuard[Callable[..., Awaitable]]:
"""Check if a function is async."""
return (
inspect.iscoroutinefunction(func)
or hasattr(func, "__call__")
and inspect.iscoroutinefunction(func.__call__)
)
def is_async_generator(
func: Any,
) -> TypeGuard[Callable[..., AsyncIterator]]:
"""Check if a function is an async generator."""
return (
inspect.isasyncgenfunction(func)
or hasattr(func, "__call__")
and inspect.isasyncgenfunction(func.__call__)
)
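# Illustrative check (not part of this module): both helpers also recognise
# objects whose `__call__` is a coroutine or async-generator function, not
# just bare `async def` functions.
class _ExampleAsyncCallable:
    async def __call__(self, x: int) -> int:
        return x

# is_async_callable(_ExampleAsyncCallable())  -> True
# is_async_generator(_ExampleAsyncCallable()) -> False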
def coerce_to_runnable(
thing: RunnableLike, *, name: str | None, trace: bool
) -> Runnable:
"""Coerce a runnable-like object into a Runnable.
Args:
thing: A runnable-like object.
Returns:
A Runnable.
"""
if isinstance(thing, Runnable):
return thing
elif is_async_generator(thing) or inspect.isgeneratorfunction(thing):
return RunnableLambda(thing, name=name)
elif callable(thing):
if is_async_callable(thing):
return RunnableCallable(None, thing, name=name, trace=trace)
else:
return RunnableCallable(
thing,
wraps(thing)(partial(run_in_executor, None, thing)), # type: ignore[arg-type]
name=name,
trace=trace,
)
elif isinstance(thing, dict):
return RunnableParallel(thing)
else:
raise TypeError(
f"Expected a Runnable, callable or dict."
f"Instead got an unsupported type: {type(thing)}"
)
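# Illustrative sketch (not part of this module): what `coerce_to_runnable`
# produces for a few common inputs; the functions and dict are hypothetical.
def _example_coerce_to_runnable() -> None:
    def _double(x: int) -> int:
        return x * 2

    async def _adouble(x: int) -> int:
        return x * 2

    r1 = coerce_to_runnable(_double, name="double", trace=False)  # RunnableCallable; async side runs in an executor
    r2 = coerce_to_runnable(_adouble, name="adouble", trace=False)  # RunnableCallable with only an async function
    r3 = coerce_to_runnable({"key": _double}, name=None, trace=True)  # RunnableParallel
    assert r1.invoke(2) == 4
    assert isinstance(r2, RunnableCallable) and isinstance(r3, RunnableParallel)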
class RunnableSeq(Runnable):
"""Sequence of `Runnable`, where the output of each is the input of the next.
`RunnableSeq` is a simpler version of `RunnableSequence` that is internal to
LangGraph.
"""
def __init__(
self,
*steps: RunnableLike,
name: str | None = None,
trace_inputs: Callable[[Any], Any] | None = None,
) -> None:
"""Create a new RunnableSeq.
Args:
steps: The steps to include in the sequence.
name: The name of the `Runnable`.
Raises:
ValueError: If the sequence has less than 2 steps.
"""
steps_flat: list[Runnable] = []
for step in steps:
if isinstance(step, RunnableSequence):
steps_flat.extend(step.steps)
elif isinstance(step, RunnableSeq):
steps_flat.extend(step.steps)
else:
steps_flat.append(coerce_to_runnable(step, name=None, trace=True))
if len(steps_flat) < 2:
raise ValueError(
f"RunnableSeq must have at least 2 steps, got {len(steps_flat)}"
)
self.steps = steps_flat
self.name = name
self.trace_inputs = trace_inputs
def __or__(
self,
other: Any,
) -> Runnable:
if isinstance(other, RunnableSequence):
return RunnableSeq(
*self.steps,
other.first,
*other.middle,
other.last,
name=self.name or other.name,
)
elif isinstance(other, RunnableSeq):
return RunnableSeq(
*self.steps,
*other.steps,
name=self.name or other.name,
)
else:
return RunnableSeq(
*self.steps,
coerce_to_runnable(other, name=None, trace=True),
name=self.name,
)
def __ror__(
self,
other: Any,
) -> Runnable:
if isinstance(other, RunnableSequence):
return RunnableSequence(
other.first,
*other.middle,
other.last,
*self.steps,
name=other.name or self.name,
)
elif isinstance(other, RunnableSeq):
return RunnableSeq(
*other.steps,
*self.steps,
name=other.name or self.name,
)
else:
return RunnableSequence(
coerce_to_runnable(other, name=None, trace=True),
*self.steps,
name=self.name,
)
def invoke(
self, input: Input, config: RunnableConfig | None = None, **kwargs: Any
) -> Any:
if config is None:
config = ensure_config()
# setup callbacks and context
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
None,
self.trace_inputs(input) if self.trace_inputs is not None else input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# invoke all steps in sequence
try:
for i, step in enumerate(self.steps):
# mark each step as a child run
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
)
# 1st step is the actual node,
# others are writers which don't need to be run in context
if i == 0:
# get the run object
for h in run_manager.handlers:
if isinstance(h, LangChainTracer):
run = h.run_map.get(str(run_manager.run_id))
break
else:
run = None
# run in context
with set_config_context(config, run) as context:
input = context.run(step.invoke, input, config, **kwargs)
else:
input = step.invoke(input, config)
# finish the root run
except BaseException as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(input)
return input
async def ainvoke(
self,
input: Input,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Any:
if config is None:
config = ensure_config()
# setup callbacks
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
None,
self.trace_inputs(input) if self.trace_inputs is not None else input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# invoke all steps in sequence
try:
for i, step in enumerate(self.steps):
# mark each step as a child run
config = patch_config(
config, callbacks=run_manager.get_child(f"seq:step:{i + 1}")
)
# 1st step is the actual node,
# others are writers which don't need to be run in context
if i == 0:
if ASYNCIO_ACCEPTS_CONTEXT:
# get the run object
for h in run_manager.handlers:
if isinstance(h, LangChainTracer):
run = h.run_map.get(str(run_manager.run_id))
break
else:
run = None
# run in context
with set_config_context(config, run) as context:
input = await asyncio.create_task(
step.ainvoke(input, config, **kwargs), context=context
)
else:
input = await step.ainvoke(input, config, **kwargs)
else:
input = await step.ainvoke(input, config)
# finish the root run
except BaseException as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(input)
return input
def stream(
self,
input: Input,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Iterator[Any]:
if config is None:
config = ensure_config()
# setup callbacks
callback_manager = get_callback_manager_for_config(config)
# start the root run
run_manager = callback_manager.on_chain_start(
None,
self.trace_inputs(input) if self.trace_inputs is not None else input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# get the run object
for h in run_manager.handlers:
if isinstance(h, LangChainTracer):
run = h.run_map.get(str(run_manager.run_id))
break
else:
run = None
# create first step config
config = patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{1}"),
)
# run all in context
with set_config_context(config, run) as context:
try:
# stream the last steps
# transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will
# buffer input in memory until all available, and then start emitting output
for idx, step in enumerate(self.steps):
if idx == 0:
iterator = step.stream(input, config, **kwargs)
else:
config = patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{idx + 1}"),
)
iterator = step.transform(iterator, config)
# populates streamed_output in astream_log() output if needed
if _StreamingCallbackHandler is not None:
for h in run_manager.handlers:
if isinstance(h, _StreamingCallbackHandler):
iterator = h.tap_output_iter(run_manager.run_id, iterator)
# consume into final output
output = context.run(_consume_iter, iterator)
# sequence doesn't emit output, yield to mark as generator
yield
except BaseException as e:
run_manager.on_chain_error(e)
raise
else:
run_manager.on_chain_end(output)
async def astream(
self,
input: Input,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> AsyncIterator[Any]:
if config is None:
config = ensure_config()
# setup callbacks
callback_manager = get_async_callback_manager_for_config(config)
# start the root run
run_manager = await callback_manager.on_chain_start(
None,
self.trace_inputs(input) if self.trace_inputs is not None else input,
name=config.get("run_name") or self.get_name(),
run_id=config.pop("run_id", None),
)
# stream the last steps
# transform the input stream of each step with the next
# steps that don't natively support transforming an input stream will
# buffer input in memory until all available, and then start emitting output
if ASYNCIO_ACCEPTS_CONTEXT:
# get the run object
for h in run_manager.handlers:
if isinstance(h, LangChainTracer):
run = h.run_map.get(str(run_manager.run_id))
break
else:
run = None
# create first step config
config = patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{1}"),
)
# run all in context
with set_config_context(config, run) as context:
try:
async with AsyncExitStack() as stack:
for idx, step in enumerate(self.steps):
if idx == 0:
aiterator = step.astream(input, config, **kwargs)
else:
config = patch_config(
config,
callbacks=run_manager.get_child(
f"seq:step:{idx + 1}"
),
)
aiterator = step.atransform(aiterator, config)
if hasattr(aiterator, "aclose"):
stack.push_async_callback(aiterator.aclose)
# populates streamed_output in astream_log() output if needed
if _StreamingCallbackHandler is not None:
for h in run_manager.handlers:
if isinstance(h, _StreamingCallbackHandler):
aiterator = h.tap_output_aiter(
run_manager.run_id, aiterator
)
# consume into final output
output = await asyncio.create_task(
_consume_aiter(aiterator), context=context
)
# sequence doesn't emit output, yield to mark as generator
yield
except BaseException as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(output)
else:
try:
async with AsyncExitStack() as stack:
for idx, step in enumerate(self.steps):
config = patch_config(
config,
callbacks=run_manager.get_child(f"seq:step:{idx + 1}"),
)
if idx == 0:
aiterator = step.astream(input, config, **kwargs)
else:
aiterator = step.atransform(aiterator, config)
if hasattr(aiterator, "aclose"):
stack.push_async_callback(aiterator.aclose)
# populates streamed_output in astream_log() output if needed
if _StreamingCallbackHandler is not None:
for h in run_manager.handlers:
if isinstance(h, _StreamingCallbackHandler):
aiterator = h.tap_output_aiter(
run_manager.run_id, aiterator
)
# consume into final output
output = await _consume_aiter(aiterator)
# sequence doesn't emit output, yield to mark as generator
yield
except BaseException as e:
await run_manager.on_chain_error(e)
raise
else:
await run_manager.on_chain_end(output)
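# Illustrative sketch (not part of this module): composing a two-step
# `RunnableSeq` from hypothetical plain functions. Each step is traced as a
# `seq:step:N` child of a single parent run.
def _example_runnable_seq() -> None:
    def _double(x: int) -> int:
        return x * 2

    def _describe(x: int) -> str:
        return f"value={x}"

    seq = RunnableSeq(_double, _describe, name="example_seq")
    assert seq.invoke(3) == "value=6"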
def _consume_iter(it: Iterator[Any]) -> Any:
"""Consume an iterator."""
output: Any = None
add_supported = False
for chunk in it:
# collect final output
if output is None:
output = chunk
elif add_supported:
try:
output = output + chunk
except TypeError:
output = chunk
add_supported = False
else:
output = chunk
return output
async def _consume_aiter(it: AsyncIterator[Any]) -> Any:
"""Consume an async iterator."""
output: Any = None
add_supported = False
async for chunk in it:
# collect final output
if add_supported:
try:
output = output + chunk
except TypeError:
output = chunk
add_supported = False
else:
output = chunk
return output
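# Illustrative check (not part of this module): with `add_supported`
# initialised to False, both consumers keep the most recent chunk as the
# final output, e.g.:
#   _consume_iter(iter([1, 2, 3]))  -> 3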