"""Factory functions for chat models."""
from __future__ import annotations
import warnings
from importlib import util
from typing import TYPE_CHECKING, Any, Literal, TypeAlias, cast, overload
from langchain_core.language_models import BaseChatModel, LanguageModelInput
from langchain_core.messages import AIMessage, AnyMessage
from langchain_core.runnables import Runnable, RunnableConfig, ensure_config
from typing_extensions import override
if TYPE_CHECKING:
from collections.abc import AsyncIterator, Callable, Iterator, Sequence
from langchain_core.runnables.schema import StreamEvent
from langchain_core.tools import BaseTool
from langchain_core.tracers import RunLog, RunLogPatch
from pydantic import BaseModel
@overload
def init_chat_model(
model: str,
*,
model_provider: str | None = None,
configurable_fields: None = None,
config_prefix: str | None = None,
**kwargs: Any,
) -> BaseChatModel: ...
@overload
def init_chat_model(
model: None = None,
*,
model_provider: str | None = None,
configurable_fields: None = None,
config_prefix: str | None = None,
**kwargs: Any,
) -> _ConfigurableModel: ...
@overload
def init_chat_model(
model: str | None = None,
*,
model_provider: str | None = None,
configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = ...,
config_prefix: str | None = None,
**kwargs: Any,
) -> _ConfigurableModel: ...
# FOR CONTRIBUTORS: If adding support for a new provider, please append the provider
# name to the supported list in the docstring below. Do *not* change the order of the
# existing providers.
def init_chat_model(
model: str | None = None,
*,
model_provider: str | None = None,
configurable_fields: Literal["any"] | list[str] | tuple[str, ...] | None = None,
config_prefix: str | None = None,
**kwargs: Any,
) -> BaseChatModel | _ConfigurableModel:
"""Initialize a chat model from any supported provider using a unified interface.
**Two main use cases:**
    1. **Fixed model**: specify the model upfront and get a ready-to-use chat model.
    2. **Configurable model**: specify parameters (including the model name) at
        runtime via `config`, making it easy to switch between models/providers
        without changing your code.
!!! note
Requires the integration package for the chosen model provider to be installed.
See the `model_provider` parameter below for specific package names
(e.g., `pip install langchain-openai`).
Refer to the [provider integration's API reference](https://docs.langchain.com/oss/python/integrations/providers)
for supported model parameters to use as `**kwargs`.
Args:
model: The name or ID of the model, e.g. `'o3-mini'`, `'claude-sonnet-4-5-20250929'`.
You can also specify model and model provider in a single argument using
`'{model_provider}:{model}'` format, e.g. `'openai:o1'`.
Will attempt to infer `model_provider` from model if not specified.
The following providers will be inferred based on these model prefixes:
- `gpt-...` | `o1...` | `o3...` -> `openai`
- `claude...` -> `anthropic`
- `amazon...` -> `bedrock`
- `gemini...` -> `google_vertexai`
- `command...` -> `cohere`
- `accounts/fireworks...` -> `fireworks`
- `mistral...` -> `mistralai`
- `deepseek...` -> `deepseek`
- `grok...` -> `xai`
- `sonar...` -> `perplexity`
- `solar...` -> `upstage`
model_provider: The model provider if not specified as part of the model arg
(see above).
Supported `model_provider` values and the corresponding integration package
are:
- `openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
- `anthropic` -> [`langchain-anthropic`](https://docs.langchain.com/oss/python/integrations/providers/anthropic)
- `azure_openai` -> [`langchain-openai`](https://docs.langchain.com/oss/python/integrations/providers/openai)
- `azure_ai` -> [`langchain-azure-ai`](https://docs.langchain.com/oss/python/integrations/providers/microsoft)
- `google_vertexai` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `google_genai` -> [`langchain-google-genai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `bedrock` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
- `bedrock_converse` -> [`langchain-aws`](https://docs.langchain.com/oss/python/integrations/providers/aws)
- `cohere` -> [`langchain-cohere`](https://docs.langchain.com/oss/python/integrations/providers/cohere)
- `fireworks` -> [`langchain-fireworks`](https://docs.langchain.com/oss/python/integrations/providers/fireworks)
- `together` -> [`langchain-together`](https://docs.langchain.com/oss/python/integrations/providers/together)
- `mistralai` -> [`langchain-mistralai`](https://docs.langchain.com/oss/python/integrations/providers/mistralai)
- `huggingface` -> [`langchain-huggingface`](https://docs.langchain.com/oss/python/integrations/providers/huggingface)
- `groq` -> [`langchain-groq`](https://docs.langchain.com/oss/python/integrations/providers/groq)
- `ollama` -> [`langchain-ollama`](https://docs.langchain.com/oss/python/integrations/providers/ollama)
- `google_anthropic_vertex` -> [`langchain-google-vertexai`](https://docs.langchain.com/oss/python/integrations/providers/google)
- `deepseek` -> [`langchain-deepseek`](https://docs.langchain.com/oss/python/integrations/providers/deepseek)
- `ibm` -> [`langchain-ibm`](https://docs.langchain.com/oss/python/integrations/providers/ibm)
- `nvidia` -> [`langchain-nvidia-ai-endpoints`](https://docs.langchain.com/oss/python/integrations/providers/nvidia)
- `xai` -> [`langchain-xai`](https://docs.langchain.com/oss/python/integrations/providers/xai)
- `perplexity` -> [`langchain-perplexity`](https://docs.langchain.com/oss/python/integrations/providers/perplexity)
- `upstage` -> [`langchain-upstage`](https://docs.langchain.com/oss/python/integrations/providers/upstage)
configurable_fields: Which model parameters are configurable at runtime:
- `None`: No configurable fields (i.e., a fixed model).
- `'any'`: All fields are configurable. **See security note below.**
            - `list[str] | tuple[str, ...]`: Only the specified fields are configurable.
              Field names are assumed to have `config_prefix` stripped if a
              `config_prefix` is specified.
If `model` is specified, then defaults to `None`.
If `model` is not specified, then defaults to `("model", "model_provider")`.
!!! warning "Security note"
Setting `configurable_fields="any"` means fields like `api_key`,
`base_url`, etc., can be altered at runtime, potentially redirecting
model requests to a different service/user.
                If you accept untrusted configurations, make sure to enumerate the
                `configurable_fields=(...)` explicitly.
config_prefix: Optional prefix for configuration keys.
Useful when you have multiple configurable models in the same application.
            If `config_prefix` is a non-empty string, then the model will be
            configurable at runtime via the
            `config["configurable"]["{config_prefix}_{param}"]` keys. See examples
            below.
            If `config_prefix` is an empty string, then the model will be configurable
            via `config["configurable"]["{param}"]`.
**kwargs: Additional model-specific keyword args to pass to the underlying
chat model's `__init__` method. Common parameters include:
- `temperature`: Model temperature for controlling randomness.
- `max_tokens`: Maximum number of output tokens.
- `timeout`: Maximum time (in seconds) to wait for a response.
- `max_retries`: Maximum number of retry attempts for failed requests.
- `base_url`: Custom API endpoint URL.
- `rate_limiter`: A
[`BaseRateLimiter`][langchain_core.rate_limiters.BaseRateLimiter]
instance to control request rate.
Refer to the specific model provider's
[integration reference](https://reference.langchain.com/python/integrations/)
for all available parameters.
Returns:
        A `BaseChatModel` corresponding to the specified `model` and `model_provider`
        if no fields are configurable. Otherwise, a chat model emulator that
        initializes the underlying model at runtime once a config is passed in.
Raises:
ValueError: If `model_provider` cannot be inferred or isn't supported.
ImportError: If the model provider integration package is not installed.
???+ example "Initialize a non-configurable model"
```python
# pip install langchain langchain-openai langchain-anthropic langchain-google-vertexai
from langchain.chat_models import init_chat_model
o3_mini = init_chat_model("openai:o3-mini", temperature=0)
claude_sonnet = init_chat_model("anthropic:claude-sonnet-4-5-20250929", temperature=0)
        gemini_2_5_flash = init_chat_model("google_vertexai:gemini-2.5-flash", temperature=0)
o3_mini.invoke("what's your name")
claude_sonnet.invoke("what's your name")
        gemini_2_5_flash.invoke("what's your name")
```
??? example "Partially configurable model with no default"
```python
# pip install langchain langchain-openai langchain-anthropic
from langchain.chat_models import init_chat_model
        # (configurable_fields defaults to ("model", "model_provider") when no
        # model is specified.)
        configurable_model = init_chat_model(temperature=0)
        # Use GPT-4o to generate the response
        configurable_model.invoke("what's your name", config={"configurable": {"model": "gpt-4o"}})
        # Use Claude Sonnet 4.5 to generate the response
        configurable_model.invoke(
            "what's your name",
            config={"configurable": {"model": "claude-sonnet-4-5-20250929"}},
        )
```
??? example "Fully configurable model with a default"
```python
# pip install langchain langchain-openai langchain-anthropic
from langchain.chat_models import init_chat_model
configurable_model_with_default = init_chat_model(
"openai:gpt-4o",
configurable_fields="any", # This allows us to configure other params like temperature, max_tokens, etc at runtime.
config_prefix="foo",
temperature=0,
)
configurable_model_with_default.invoke("what's your name")
# GPT-4o response with temperature 0 (as set in default)
configurable_model_with_default.invoke(
"what's your name",
config={
"configurable": {
"foo_model": "anthropic:claude-sonnet-4-5-20250929",
"foo_temperature": 0.6,
}
},
)
# Override default to use Sonnet 4.5 with temperature 0.6 to generate response
```
??? example "Bind tools to a configurable model"
        You can call declarative methods such as `bind_tools` and
        `with_structured_output` on a configurable model just as you would on a
        regular chat model:
```python
# pip install langchain langchain-openai langchain-anthropic
from langchain.chat_models import init_chat_model
from pydantic import BaseModel, Field
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
class GetPopulation(BaseModel):
'''Get the current population in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
configurable_model = init_chat_model(
"gpt-4o", configurable_fields=("model", "model_provider"), temperature=0
)
configurable_model_with_tools = configurable_model.bind_tools(
[
GetWeather,
GetPopulation,
]
)
configurable_model_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?"
)
# Use GPT-4o
configurable_model_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?",
config={"configurable": {"model": "claude-sonnet-4-5-20250929"}},
)
# Use Sonnet 4.5
```
""" # noqa: E501
if not model and not configurable_fields:
configurable_fields = ("model", "model_provider")
config_prefix = config_prefix or ""
if config_prefix and not configurable_fields:
warnings.warn(
f"{config_prefix=} has been set but no fields are configurable. Set "
f"`configurable_fields=(...)` to specify the model params that are "
f"configurable.",
stacklevel=2,
)
if not configurable_fields:
return _init_chat_model_helper(
cast("str", model),
model_provider=model_provider,
**kwargs,
)
if model:
kwargs["model"] = model
if model_provider:
kwargs["model_provider"] = model_provider
return _ConfigurableModel(
default_config=kwargs,
config_prefix=config_prefix,
configurable_fields=configurable_fields,
)
def _init_chat_model_helper(
model: str,
*,
model_provider: str | None = None,
**kwargs: Any,
) -> BaseChatModel:
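    """Dispatch to the provider-specific chat model class for the parsed provider."""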
model, model_provider = _parse_model(model, model_provider)
if model_provider == "openai":
_check_pkg("langchain_openai")
from langchain_openai import ChatOpenAI
return ChatOpenAI(model=model, **kwargs)
if model_provider == "anthropic":
_check_pkg("langchain_anthropic")
from langchain_anthropic import ChatAnthropic
return ChatAnthropic(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
if model_provider == "azure_openai":
_check_pkg("langchain_openai")
from langchain_openai import AzureChatOpenAI
return AzureChatOpenAI(model=model, **kwargs)
if model_provider == "azure_ai":
_check_pkg("langchain_azure_ai")
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel
return AzureAIChatCompletionsModel(model=model, **kwargs)
if model_provider == "cohere":
_check_pkg("langchain_cohere")
from langchain_cohere import ChatCohere
return ChatCohere(model=model, **kwargs)
if model_provider == "google_vertexai":
_check_pkg("langchain_google_vertexai")
from langchain_google_vertexai import ChatVertexAI
return ChatVertexAI(model=model, **kwargs)
if model_provider == "google_genai":
_check_pkg("langchain_google_genai")
from langchain_google_genai import ChatGoogleGenerativeAI
return ChatGoogleGenerativeAI(model=model, **kwargs)
if model_provider == "fireworks":
_check_pkg("langchain_fireworks")
from langchain_fireworks import ChatFireworks
return ChatFireworks(model=model, **kwargs)
if model_provider == "ollama":
try:
_check_pkg("langchain_ollama")
from langchain_ollama import ChatOllama
except ImportError:
# For backwards compatibility
try:
_check_pkg("langchain_community")
from langchain_community.chat_models import ChatOllama
except ImportError:
# If both langchain-ollama and langchain-community aren't available,
# raise an error related to langchain-ollama
_check_pkg("langchain_ollama")
return ChatOllama(model=model, **kwargs)
if model_provider == "together":
_check_pkg("langchain_together")
from langchain_together import ChatTogether
return ChatTogether(model=model, **kwargs)
if model_provider == "mistralai":
_check_pkg("langchain_mistralai")
from langchain_mistralai import ChatMistralAI
return ChatMistralAI(model=model, **kwargs) # type: ignore[call-arg,unused-ignore]
if model_provider == "huggingface":
_check_pkg("langchain_huggingface")
from langchain_huggingface import ChatHuggingFace
return ChatHuggingFace.from_model_id(model_id=model, **kwargs)
if model_provider == "groq":
_check_pkg("langchain_groq")
from langchain_groq import ChatGroq
return ChatGroq(model=model, **kwargs)
if model_provider == "bedrock":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrock
return ChatBedrock(model_id=model, **kwargs)
if model_provider == "bedrock_converse":
_check_pkg("langchain_aws")
from langchain_aws import ChatBedrockConverse
return ChatBedrockConverse(model=model, **kwargs)
if model_provider == "google_anthropic_vertex":
_check_pkg("langchain_google_vertexai")
from langchain_google_vertexai.model_garden import ChatAnthropicVertex
return ChatAnthropicVertex(model=model, **kwargs)
if model_provider == "deepseek":
_check_pkg("langchain_deepseek", pkg_kebab="langchain-deepseek")
from langchain_deepseek import ChatDeepSeek
return ChatDeepSeek(model=model, **kwargs)
if model_provider == "nvidia":
_check_pkg("langchain_nvidia_ai_endpoints")
from langchain_nvidia_ai_endpoints import ChatNVIDIA
return ChatNVIDIA(model=model, **kwargs)
if model_provider == "ibm":
_check_pkg("langchain_ibm")
from langchain_ibm import ChatWatsonx
return ChatWatsonx(model_id=model, **kwargs)
if model_provider == "xai":
_check_pkg("langchain_xai")
from langchain_xai import ChatXAI
return ChatXAI(model=model, **kwargs)
if model_provider == "perplexity":
_check_pkg("langchain_perplexity")
from langchain_perplexity import ChatPerplexity
return ChatPerplexity(model=model, **kwargs)
if model_provider == "upstage":
_check_pkg("langchain_upstage")
from langchain_upstage import ChatUpstage
return ChatUpstage(model=model, **kwargs)
supported = ", ".join(_SUPPORTED_PROVIDERS)
msg = f"Unsupported {model_provider=}.\n\nSupported model providers are: {supported}"
raise ValueError(msg)
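# Keep this set in sync with the provider list in the init_chat_model docstring.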
_SUPPORTED_PROVIDERS = {
"openai",
"anthropic",
"azure_openai",
"azure_ai",
"cohere",
"google_vertexai",
"google_genai",
"fireworks",
"ollama",
"together",
"mistralai",
"huggingface",
"groq",
"bedrock",
"bedrock_converse",
"google_anthropic_vertex",
"deepseek",
"ibm",
"xai",
"perplexity",
"upstage",
}
def _attempt_infer_model_provider(model_name: str) -> str | None:
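    """Infer a provider from a known model-name prefix, e.g. 'gpt-4o' -> 'openai'.

    Returns None if no prefix matches.
    """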
if any(model_name.startswith(pre) for pre in ("gpt-", "o1", "o3")):
return "openai"
if model_name.startswith("claude"):
return "anthropic"
if model_name.startswith("command"):
return "cohere"
if model_name.startswith("accounts/fireworks"):
return "fireworks"
if model_name.startswith("gemini"):
return "google_vertexai"
if model_name.startswith("amazon."):
return "bedrock"
if model_name.startswith("mistral"):
return "mistralai"
if model_name.startswith("deepseek"):
return "deepseek"
if model_name.startswith("grok"):
return "xai"
if model_name.startswith("sonar"):
return "perplexity"
if model_name.startswith("solar"):
return "upstage"
return None
def _parse_model(model: str, model_provider: str | None) -> tuple[str, str]:
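    """Resolve (model, provider), splitting '{provider}:{model}' strings if needed.

    E.g. _parse_model("openai:gpt-4o", None) -> ("gpt-4o", "openai") and
    _parse_model("claude-sonnet-4-5", None) -> ("claude-sonnet-4-5", "anthropic").
    """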
if (
not model_provider
and ":" in model
and model.split(":", maxsplit=1)[0] in _SUPPORTED_PROVIDERS
):
        model_provider, model = model.split(":", maxsplit=1)
model_provider = model_provider or _attempt_infer_model_provider(model)
if not model_provider:
msg = (
f"Unable to infer model provider for {model=}, please specify model_provider directly."
)
raise ValueError(msg)
model_provider = model_provider.replace("-", "_").lower()
return model, model_provider
def _check_pkg(pkg: str, *, pkg_kebab: str | None = None) -> None:
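    """Raise an ImportError with install instructions if `pkg` cannot be found."""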
if not util.find_spec(pkg):
pkg_kebab = pkg_kebab if pkg_kebab is not None else pkg.replace("_", "-")
msg = f"Unable to import {pkg}. Please install with `pip install -U {pkg_kebab}`"
raise ImportError(msg)
def _remove_prefix(s: str, prefix: str) -> str:
return s.removeprefix(prefix)
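# Declarative methods that cannot be applied until the underlying model exists;
# _ConfigurableModel queues them and replays them at initialization time.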
_DECLARATIVE_METHODS = ("bind_tools", "with_structured_output")
class _ConfigurableModel(Runnable[LanguageModelInput, Any]):
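    """A chat model emulator that defers model initialization until runtime.

    Configurable params from `config["configurable"]` (filtered by
    `configurable_fields` and stripped of `config_prefix`) are merged over
    `default_config`, and any queued declarative operations (`bind_tools`,
    `with_structured_output`) are replayed on the freshly initialized model.
    """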
def __init__(
self,
*,
default_config: dict | None = None,
configurable_fields: Literal["any"] | list[str] | tuple[str, ...] = "any",
config_prefix: str = "",
queued_declarative_operations: Sequence[tuple[str, tuple, dict]] = (),
) -> None:
self._default_config: dict = default_config or {}
self._configurable_fields: Literal["any"] | list[str] = (
configurable_fields if configurable_fields == "any" else list(configurable_fields)
)
self._config_prefix = (
config_prefix + "_"
if config_prefix and not config_prefix.endswith("_")
else config_prefix
)
self._queued_declarative_operations: list[tuple[str, tuple, dict]] = list(
queued_declarative_operations,
)
def __getattr__(self, name: str) -> Any:
if name in _DECLARATIVE_METHODS:
# Declarative operations that cannot be applied until after an actual model
# object is instantiated. So instead of returning the actual operation,
# we record the operation and its arguments in a queue. This queue is
# then applied in order whenever we actually instantiate the model (in
# self._model()).
def queue(*args: Any, **kwargs: Any) -> _ConfigurableModel:
queued_declarative_operations = list(
self._queued_declarative_operations,
)
queued_declarative_operations.append((name, args, kwargs))
return _ConfigurableModel(
default_config=dict(self._default_config),
configurable_fields=list(self._configurable_fields)
if isinstance(self._configurable_fields, list)
else self._configurable_fields,
config_prefix=self._config_prefix,
queued_declarative_operations=queued_declarative_operations,
)
return queue
if self._default_config and (model := self._model()) and hasattr(model, name):
return getattr(model, name)
msg = f"{name} is not a BaseChatModel attribute"
if self._default_config:
msg += " and is not implemented on the default model"
msg += "."
raise AttributeError(msg)
def _model(self, config: RunnableConfig | None = None) -> Runnable:
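        """Initialize the underlying model from defaults plus runtime config params."""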
params = {**self._default_config, **self._model_params(config)}
model = _init_chat_model_helper(**params)
for name, args, kwargs in self._queued_declarative_operations:
model = getattr(model, name)(*args, **kwargs)
return model
def _model_params(self, config: RunnableConfig | None) -> dict:
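        """Extract prefix-stripped model params from `config["configurable"]`."""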
config = ensure_config(config)
model_params = {
_remove_prefix(k, self._config_prefix): v
for k, v in config.get("configurable", {}).items()
if k.startswith(self._config_prefix)
}
if self._configurable_fields != "any":
model_params = {k: v for k, v in model_params.items() if k in self._configurable_fields}
return model_params
def with_config(
self,
config: RunnableConfig | None = None,
**kwargs: Any,
) -> _ConfigurableModel:
"""Bind config to a `Runnable`, returning a new `Runnable`."""
config = RunnableConfig(**(config or {}), **cast("RunnableConfig", kwargs))
model_params = self._model_params(config)
remaining_config = {k: v for k, v in config.items() if k != "configurable"}
remaining_config["configurable"] = {
k: v
for k, v in config.get("configurable", {}).items()
if _remove_prefix(k, self._config_prefix) not in model_params
}
queued_declarative_operations = list(self._queued_declarative_operations)
if remaining_config:
queued_declarative_operations.append(
(
"with_config",
(),
{"config": remaining_config},
),
)
return _ConfigurableModel(
default_config={**self._default_config, **model_params},
configurable_fields=list(self._configurable_fields)
if isinstance(self._configurable_fields, list)
else self._configurable_fields,
config_prefix=self._config_prefix,
queued_declarative_operations=queued_declarative_operations,
)
@property
@override
def InputType(self) -> TypeAlias:
"""Get the input type for this `Runnable`."""
from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue
# This is a version of LanguageModelInput which replaces the abstract
# base class BaseMessage with a union of its subclasses, which makes
# for a much better schema.
return str | StringPromptValue | ChatPromptValueConcrete | list[AnyMessage]
@override
def invoke(
self,
input: LanguageModelInput,
config: RunnableConfig | None = None,
**kwargs: Any,
) -> Any:
return self._model(config).invoke(input, config=config, **kwargs)
@override
async def ainvoke(
self,
input: LanguageModelInput,
config: RunnableConfig | None = None,
**kwargs: Any,
) -> Any:
return await self._model(config).ainvoke(input, config=config, **kwargs)
@override
def stream(
self,
input: LanguageModelInput,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Iterator[Any]:
yield from self._model(config).stream(input, config=config, **kwargs)
@override
async def astream(
self,
input: LanguageModelInput,
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> AsyncIterator[Any]:
async for x in self._model(config).astream(input, config=config, **kwargs):
yield x
def batch(
self,
inputs: list[LanguageModelInput],
config: RunnableConfig | list[RunnableConfig] | None = None,
*,
return_exceptions: bool = False,
**kwargs: Any | None,
) -> list[Any]:
config = config or None
        # If <= 1 config, use the underlying model's batch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
return self._model(config).batch(
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
        # With multiple configs, fall back to Runnable.batch, which uses an executor
        # to invoke in parallel.
return super().batch(
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
async def abatch(
self,
inputs: list[LanguageModelInput],
config: RunnableConfig | list[RunnableConfig] | None = None,
*,
return_exceptions: bool = False,
**kwargs: Any | None,
) -> list[Any]:
config = config or None
        # If <= 1 config, use the underlying model's abatch implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
return await self._model(config).abatch(
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
        # With multiple configs, fall back to Runnable.abatch, which uses an executor
        # to invoke in parallel.
return await super().abatch(
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
def batch_as_completed(
self,
inputs: Sequence[LanguageModelInput],
config: RunnableConfig | Sequence[RunnableConfig] | None = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> Iterator[tuple[int, Any | Exception]]:
config = config or None
        # If <= 1 config, use the underlying model's batch_as_completed implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
yield from self._model(cast("RunnableConfig", config)).batch_as_completed( # type: ignore[call-overload]
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
        # With multiple configs, fall back to Runnable.batch_as_completed, which uses
        # an executor to invoke in parallel.
else:
yield from super().batch_as_completed( # type: ignore[call-overload]
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
)
async def abatch_as_completed(
self,
inputs: Sequence[LanguageModelInput],
config: RunnableConfig | Sequence[RunnableConfig] | None = None,
*,
return_exceptions: bool = False,
**kwargs: Any,
) -> AsyncIterator[tuple[int, Any]]:
config = config or None
        # If <= 1 config, use the underlying model's abatch_as_completed implementation.
if config is None or isinstance(config, dict) or len(config) <= 1:
if isinstance(config, list):
config = config[0]
async for x in self._model(
cast("RunnableConfig", config),
).abatch_as_completed( # type: ignore[call-overload]
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
):
yield x
        # With multiple configs, fall back to Runnable.abatch_as_completed, which uses
        # an executor to invoke in parallel.
else:
async for x in super().abatch_as_completed( # type: ignore[call-overload]
inputs,
config=config,
return_exceptions=return_exceptions,
**kwargs,
):
yield x
@override
def transform(
self,
input: Iterator[LanguageModelInput],
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> Iterator[Any]:
yield from self._model(config).transform(input, config=config, **kwargs)
@override
async def atransform(
self,
input: AsyncIterator[LanguageModelInput],
config: RunnableConfig | None = None,
**kwargs: Any | None,
) -> AsyncIterator[Any]:
async for x in self._model(config).atransform(input, config=config, **kwargs):
yield x
@overload
@override
def astream_log(
self,
input: Any,
config: RunnableConfig | None = None,
*,
diff: Literal[True] = True,
with_streamed_output_list: bool = True,
include_names: Sequence[str] | None = None,
include_types: Sequence[str] | None = None,
include_tags: Sequence[str] | None = None,
exclude_names: Sequence[str] | None = None,
exclude_types: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
**kwargs: Any,
) -> AsyncIterator[RunLogPatch]: ...
@overload
@override
def astream_log(
self,
input: Any,
config: RunnableConfig | None = None,
*,
diff: Literal[False],
with_streamed_output_list: bool = True,
include_names: Sequence[str] | None = None,
include_types: Sequence[str] | None = None,
include_tags: Sequence[str] | None = None,
exclude_names: Sequence[str] | None = None,
exclude_types: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
**kwargs: Any,
) -> AsyncIterator[RunLog]: ...
@override
async def astream_log(
self,
input: Any,
config: RunnableConfig | None = None,
*,
diff: bool = True,
with_streamed_output_list: bool = True,
include_names: Sequence[str] | None = None,
include_types: Sequence[str] | None = None,
include_tags: Sequence[str] | None = None,
exclude_names: Sequence[str] | None = None,
exclude_types: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
**kwargs: Any,
) -> AsyncIterator[RunLogPatch] | AsyncIterator[RunLog]:
async for x in self._model(config).astream_log( # type: ignore[call-overload, misc]
input,
config=config,
diff=diff,
with_streamed_output_list=with_streamed_output_list,
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_tags=exclude_tags,
exclude_types=exclude_types,
exclude_names=exclude_names,
**kwargs,
):
yield x
@override
async def astream_events(
self,
input: Any,
config: RunnableConfig | None = None,
*,
version: Literal["v1", "v2"] = "v2",
include_names: Sequence[str] | None = None,
include_types: Sequence[str] | None = None,
include_tags: Sequence[str] | None = None,
exclude_names: Sequence[str] | None = None,
exclude_types: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
**kwargs: Any,
) -> AsyncIterator[StreamEvent]:
async for x in self._model(config).astream_events(
input,
config=config,
version=version,
include_names=include_names,
include_types=include_types,
include_tags=include_tags,
exclude_tags=exclude_tags,
exclude_types=exclude_types,
exclude_names=exclude_names,
**kwargs,
):
yield x
# Explicitly added to satisfy downstream linters.
def bind_tools(
self,
tools: Sequence[dict[str, Any] | type[BaseModel] | Callable | BaseTool],
**kwargs: Any,
) -> Runnable[LanguageModelInput, AIMessage]:
return self.__getattr__("bind_tools")(tools, **kwargs)
# Explicitly added to satisfy downstream linters.
def with_structured_output(
self,
schema: dict | type[BaseModel],
**kwargs: Any,
) -> Runnable[LanguageModelInput, dict | BaseModel]:
return self.__getattr__("with_structured_output")(schema, **kwargs)