# group-wbl/.venv/lib/python3.13/site-packages/langgraph/store/base/embed.py
"""Utilities for working with embedding functions and LangChain's Embeddings interface.
This module provides tools to wrap arbitrary embedding functions (both sync and async)
into LangChain's Embeddings interface. This enables using custom embedding functions
with LangChain-compatible tools while maintaining support for both synchronous and
asynchronous operations.
"""
from __future__ import annotations
import asyncio
import functools
import json
from collections.abc import Awaitable, Callable, Sequence
from typing import Any
from langchain_core.embeddings import Embeddings
EmbeddingsFunc = Callable[[Sequence[str]], list[list[float]]]
"""Type for synchronous embedding functions.
The function should take a sequence of strings and return a list of embeddings,
where each embedding is a list of floats. The dimensionality of the embeddings
should be consistent for all inputs.
"""
AEmbeddingsFunc = Callable[[Sequence[str]], Awaitable[list[list[float]]]]
"""Type for asynchronous embedding functions.
Similar to EmbeddingsFunc, but returns an awaitable that resolves to the embeddings.
"""


def ensure_embeddings(
    embed: Embeddings | EmbeddingsFunc | AEmbeddingsFunc | str | None,
) -> Embeddings:
    """Ensure that an embedding function conforms to LangChain's Embeddings interface.

    This function wraps arbitrary embedding functions to make them compatible with
    LangChain's Embeddings interface. It handles both synchronous and asynchronous
    functions.

    Args:
        embed: Either an existing Embeddings instance, or a function that converts
            text to embeddings. If the function is async, it will be used for both
            sync and async operations.

    Returns:
        An Embeddings instance that wraps the provided function(s).

    ??? example "Examples"
        Wrap a synchronous embedding function:

        ```python
        def my_embed_fn(texts):
            return [[0.1, 0.2] for _ in texts]

        embeddings = ensure_embeddings(my_embed_fn)
        result = embeddings.embed_query("hello")  # Returns [0.1, 0.2]
        ```

        Wrap an asynchronous embedding function:

        ```python
        async def my_async_fn(texts):
            return [[0.1, 0.2] for _ in texts]

        embeddings = ensure_embeddings(my_async_fn)
        result = await embeddings.aembed_query("hello")  # Returns [0.1, 0.2]
        ```

        Initialize embeddings using a provider string:

        ```python
        # Requires langchain>=0.3.9 and langgraph-checkpoint>=2.0.11
        embeddings = ensure_embeddings("openai:text-embedding-3-small")
        result = embeddings.embed_query("hello")
        ```
    """
    if embed is None:
        raise ValueError("embed must be provided")
    if isinstance(embed, str):
        init_embeddings = _get_init_embeddings()
        if init_embeddings is None:
            from importlib.metadata import PackageNotFoundError, version

            try:
                lc_version = version("langchain")
                version_info = f"Found langchain version {lc_version}, but"
            except PackageNotFoundError:
                version_info = "langchain is not installed;"
            raise ValueError(
                f"Could not load embeddings from string '{embed}'. {version_info} "
                "loading embeddings by provider:identifier string requires langchain>=0.3.9 "
                "as well as the provider-specific package. "
                "Install LangChain with: pip install 'langchain>=0.3.9' "
                "and the provider-specific package (e.g., 'langchain-openai>=0.3.0'). "
                "Alternatively, specify 'embed' as a compatible Embeddings object or python function."
            )
        return init_embeddings(embed)
    if isinstance(embed, Embeddings):
        return embed
    return EmbeddingsLambda(embed)
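
# Illustrative usage sketch (not part of this module's API surface): ensure_embeddings
# also accepts a callable *object*, because _is_async_callable below checks for an
# async __call__ method. The class name MyAsyncEmbedder is hypothetical.
#
#     class MyAsyncEmbedder:
#         async def __call__(self, texts: Sequence[str]) -> list[list[float]]:
#             return [[0.1, 0.2] for _ in texts]
#
#     embeddings = ensure_embeddings(MyAsyncEmbedder())
#     await embeddings.aembed_query("hello")   # -> [0.1, 0.2]
#     embeddings.embed_query("hello")          # raises ValueError (no sync function)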


class EmbeddingsLambda(Embeddings):
    """Wrapper to convert embedding functions into LangChain's Embeddings interface.

    This class allows arbitrary embedding functions to be used with LangChain-compatible
    tools. It supports both synchronous and asynchronous operations, and can handle:

    1. A synchronous function for sync operations (async operations fall back to the sync function)
    2. An asynchronous function for async operations (sync operations will raise an error)

    The embedding functions should convert text into fixed-dimensional vectors that
    capture the semantic meaning of the text.

    Args:
        func: Function that converts text to embeddings. Can be sync or async.
            If async, it will be used for async operations, but sync operations
            will raise an error. If sync, it will be used for both sync and async operations.

    ??? example "Examples"
        With a sync function:

        ```python
        def my_embed_fn(texts):
            # Return 2D embeddings for each text
            return [[0.1, 0.2] for _ in texts]

        embeddings = EmbeddingsLambda(my_embed_fn)
        result = embeddings.embed_query("hello")  # Returns [0.1, 0.2]
        await embeddings.aembed_query("hello")  # Also returns [0.1, 0.2]
        ```

        With an async function:

        ```python
        async def my_async_fn(texts):
            return [[0.1, 0.2] for _ in texts]

        embeddings = EmbeddingsLambda(my_async_fn)
        await embeddings.aembed_query("hello")  # Returns [0.1, 0.2]
        # Note: embed_query() would raise an error
        ```
    """

    def __init__(
        self,
        func: EmbeddingsFunc | AEmbeddingsFunc,
    ) -> None:
        if func is None:
            raise ValueError("func must be provided")
        if _is_async_callable(func):
            self.afunc = func
        else:
            self.func = func

    def embed_documents(self, texts: list[str]) -> list[list[float]]:
        """Embed a list of texts into vectors.

        Args:
            texts: list of texts to convert to embeddings.

        Returns:
            list of embeddings, one per input text. Each embedding is a list of floats.

        Raises:
            ValueError: If the instance was initialized with only an async function.
        """
        func = getattr(self, "func", None)
        if func is None:
            raise ValueError(
                "EmbeddingsLambda was initialized with an async function but no sync function. "
                "Use aembed_documents for async operation or provide a sync function."
            )
        return func(texts)

    def embed_query(self, text: str) -> list[float]:
        """Embed a single piece of text.

        Args:
            text: Text to convert to an embedding.

        Returns:
            Embedding vector as a list of floats.

        Note:
            This is equivalent to calling embed_documents with a single text
            and taking the first result.
        """
        return self.embed_documents([text])[0]

    async def aembed_documents(self, texts: list[str]) -> list[list[float]]:
        """Asynchronously embed a list of texts into vectors.

        Args:
            texts: list of texts to convert to embeddings.

        Returns:
            list of embeddings, one per input text. Each embedding is a list of floats.

        Note:
            If no async function was provided, this falls back to the sync implementation.
        """
        afunc = getattr(self, "afunc", None)
        if afunc is None:
            return await super().aembed_documents(texts)
        return await afunc(texts)

    async def aembed_query(self, text: str) -> list[float]:
        """Asynchronously embed a single piece of text.

        Args:
            text: Text to convert to an embedding.

        Returns:
            Embedding vector as a list of floats.

        Note:
            This is equivalent to calling aembed_documents with a single text
            and taking the first result.
        """
        afunc = getattr(self, "afunc", None)
        if afunc is None:
            return await super().aembed_query(text)
        return (await afunc([text]))[0]
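
# Minimal sketch (assumption: behavior follows from the fallback paths above, where the
# langchain_core base class's default async methods delegate to the sync ones): a
# sync-only EmbeddingsLambda remains usable from async code.
#
#     def embed_sync(texts: Sequence[str]) -> list[list[float]]:
#         return [[0.0, 1.0] for _ in texts]
#
#     lam = EmbeddingsLambda(embed_sync)
#     lam.embed_query("hi")            # -> [0.0, 1.0]
#     await lam.aembed_query("hi")     # -> [0.0, 1.0], via super().aembed_query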


def get_text_at_path(obj: Any, path: str | list[str]) -> list[str]:
    """Extract text from an object using a path expression or pre-tokenized path.

    Args:
        obj: The object to extract text from.
        path: Either a path string or pre-tokenized path list.

    !!! info "Path types handled"
        - Simple paths: "field1.field2"
        - Array indexing: "[0]", "[*]", "[-1]"
        - Wildcards: "*"
        - Multi-field selection: "{field1,field2}"
        - Nested paths in multi-field: "{field1,nested.field2}"
    """
    if not path or path == "$":
        return [json.dumps(obj, sort_keys=True, ensure_ascii=False)]

    tokens = tokenize_path(path) if isinstance(path, str) else path

    def _extract_from_obj(obj: Any, tokens: list[str], pos: int) -> list[str]:
        if pos >= len(tokens):
            if isinstance(obj, (str, int, float, bool)):
                return [str(obj)]
            elif obj is None:
                return []
            elif isinstance(obj, (list, dict)):
                return [json.dumps(obj, sort_keys=True, ensure_ascii=False)]
            return []

        token = tokens[pos]
        results = []

        # Handle array indexing: "[0]", "[*]", "[-1]"
        if token.startswith("[") and token.endswith("]"):
            if not isinstance(obj, list):
                return []
            index = token[1:-1]
            if index == "*":
                for item in obj:
                    results.extend(_extract_from_obj(item, tokens, pos + 1))
            else:
                try:
                    idx = int(index)
                    if idx < 0:
                        idx = len(obj) + idx
                    if 0 <= idx < len(obj):
                        results.extend(_extract_from_obj(obj[idx], tokens, pos + 1))
                except (ValueError, IndexError):
                    return []
        # Handle multi-field selection: "{field1,nested.field2}"
        elif token.startswith("{") and token.endswith("}"):
            if not isinstance(obj, dict):
                return []
            fields = [f.strip() for f in token[1:-1].split(",")]
            for field in fields:
                nested_tokens = tokenize_path(field)
                if nested_tokens:
                    current_obj: dict | None = obj
                    for nested_token in nested_tokens:
                        if (
                            isinstance(current_obj, dict)
                            and nested_token in current_obj
                        ):
                            current_obj = current_obj[nested_token]
                        else:
                            current_obj = None
                            break
                    if current_obj is not None:
                        if isinstance(current_obj, (str, int, float, bool)):
                            results.append(str(current_obj))
                        elif isinstance(current_obj, (list, dict)):
                            results.append(
                                json.dumps(
                                    current_obj, sort_keys=True, ensure_ascii=False
                                )
                            )
        # Handle wildcard
        elif token == "*":
            if isinstance(obj, dict):
                for value in obj.values():
                    results.extend(_extract_from_obj(value, tokens, pos + 1))
            elif isinstance(obj, list):
                for item in obj:
                    results.extend(_extract_from_obj(item, tokens, pos + 1))
        # Handle regular field
        else:
            if isinstance(obj, dict) and token in obj:
                results.extend(_extract_from_obj(obj[token], tokens, pos + 1))

        return results

    return _extract_from_obj(obj, tokens, 0)
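
# Illustrative sketch of the path grammar above (hypothetical document, traced from the
# branches of _extract_from_obj; not from the library's documentation):
#
#     doc = {"title": "Hello", "tags": ["a", "b"], "meta": {"author": "X"}}
#     get_text_at_path(doc, "title")                 # -> ["Hello"]
#     get_text_at_path(doc, "tags[*]")               # -> ["a", "b"]
#     get_text_at_path(doc, "{title,meta.author}")   # -> ["Hello", "X"]
#     get_text_at_path(doc, "$")                     # -> [json.dumps(doc, sort_keys=True, ensure_ascii=False)]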


# Private utility functions
def tokenize_path(path: str) -> list[str]:
    """Tokenize a path into components.

    !!! info "Types handled"
        - Simple paths: "field1.field2"
        - Array indexing: "[0]", "[*]", "[-1]"
        - Wildcards: "*"
        - Multi-field selection: "{field1,field2}"
    """
    if not path:
        return []

    tokens = []
    current: list[str] = []
    i = 0
    while i < len(path):
        char = path[i]
        if char == "[":  # Handle array index
            if current:
                tokens.append("".join(current))
                current = []
            bracket_count = 1
            index_chars = ["["]
            i += 1
            while i < len(path) and bracket_count > 0:
                if path[i] == "[":
                    bracket_count += 1
                elif path[i] == "]":
                    bracket_count -= 1
                index_chars.append(path[i])
                i += 1
            tokens.append("".join(index_chars))
            continue
        elif char == "{":  # Handle multi-field selection
            if current:
                tokens.append("".join(current))
                current = []
            brace_count = 1
            field_chars = ["{"]
            i += 1
            while i < len(path) and brace_count > 0:
                if path[i] == "{":
                    brace_count += 1
                elif path[i] == "}":
                    brace_count -= 1
                field_chars.append(path[i])
                i += 1
            tokens.append("".join(field_chars))
            continue
        elif char == ".":  # Handle regular field
            if current:
                tokens.append("".join(current))
                current = []
        else:
            current.append(char)
        i += 1

    if current:
        tokens.append("".join(current))

    return tokens
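
# Illustrative tokenization examples (hypothetical inputs, traced from the loop above):
#
#     tokenize_path("metadata.title")      # -> ["metadata", "title"]
#     tokenize_path("tags[0]")             # -> ["tags", "[0]"]
#     tokenize_path("{title,body.text}")   # -> ["{title,body.text}"]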


def _is_async_callable(
    func: Any,
) -> bool:
    """Check if a function is async.

    This includes both async def functions and classes with async __call__ methods.

    Args:
        func: Function or callable object to check.

    Returns:
        True if the function is async, False otherwise.
    """
    return (
        asyncio.iscoroutinefunction(func)
        or hasattr(func, "__call__")  # noqa: B004
        and asyncio.iscoroutinefunction(func.__call__)
    )
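
# For reference (hypothetical inputs, consistent with the check above):
#
#     async def afn(texts): ...
#     def sfn(texts): ...
#     _is_async_callable(afn)   # -> True
#     _is_async_callable(sfn)   # -> False (a plain function's __call__ is synchronous)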


@functools.lru_cache
def _get_init_embeddings() -> Callable[[str], Embeddings] | None:
    try:
        from langchain.embeddings import init_embeddings  # type: ignore

        return init_embeddings
    except ImportError:
        return None


__all__ = [
    "ensure_embeddings",
    "EmbeddingsFunc",
    "AEmbeddingsFunc",
]