337 lines
11 KiB
Python
337 lines
11 KiB
Python
|
|
"""Ensemble Retriever.
|
||
|
|
|
||
|
|
Ensemble retriever that ensembles the results of
|
||
|
|
multiple retrievers by using weighted Reciprocal Rank Fusion.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
from collections import defaultdict
|
||
|
|
from collections.abc import Callable, Hashable, Iterable, Iterator
|
||
|
|
from itertools import chain
|
||
|
|
from typing import (
|
||
|
|
Any,
|
||
|
|
TypeVar,
|
||
|
|
cast,
|
||
|
|
)
|
||
|
|
|
||
|
|
from langchain_core.callbacks import (
|
||
|
|
AsyncCallbackManagerForRetrieverRun,
|
||
|
|
CallbackManagerForRetrieverRun,
|
||
|
|
)
|
||
|
|
from langchain_core.documents import Document
|
||
|
|
from langchain_core.retrievers import BaseRetriever, RetrieverLike
|
||
|
|
from langchain_core.runnables import RunnableConfig
|
||
|
|
from langchain_core.runnables.config import ensure_config, patch_config
|
||
|
|
from langchain_core.runnables.utils import (
|
||
|
|
ConfigurableFieldSpec,
|
||
|
|
get_unique_config_specs,
|
||
|
|
)
|
||
|
|
from pydantic import model_validator
|
||
|
|
from typing_extensions import override
|
||
|
|
|
||
|
|
T = TypeVar("T")
H = TypeVar("H", bound=Hashable)


def unique_by_key(iterable: Iterable[T], key: Callable[[T], H]) -> Iterator[T]:
    """Lazily yield the elements of ``iterable`` whose ``key`` has not been seen yet.

    Order is preserved: for every distinct key, only the first element that
    produced it is emitted. ``key`` is called exactly once per element.

    Args:
        iterable: The iterable to filter.
        key: A function that returns a hashable key for each element.

    Yields:
        Unique elements of the iterable based on the key function.
    """
    observed_keys: set[H] = set()
    for element in iterable:
        element_key = key(element)
        if element_key in observed_keys:
            # A previous element already claimed this key; drop the duplicate.
            continue
        observed_keys.add(element_key)
        yield element
|
||
|
|
|
||
|
|
|
||
|
|
class EnsembleRetriever(BaseRetriever):
    """Retriever that ensembles the results of multiple retrievers.

    Each retriever is queried and the returned rank lists are combined with
    weighted Reciprocal Rank Fusion: a document at 1-based position ``rank`` in
    a retriever's results contributes ``weight / (rank + c)`` to its score, the
    contributions are summed across retrievers, and the de-duplicated documents
    are returned in descending score order.

    Args:
        retrievers: A list of retrievers to ensemble.
        weights: A list of weights corresponding to the retrievers. Defaults to equal
            weighting for all retrievers.
        c: A constant added to the rank, controlling the balance between the importance
            of high-ranked items and the consideration given to lower-ranked items.
        id_key: The key in the document's metadata used to determine unique documents.
            If not specified, page_content is used.
    """

    # Child retrievers whose results are fused.
    retrievers: list[RetrieverLike]
    # One weight per retriever (same order); filled in by `_set_weights` if omitted.
    weights: list[float]
    # Constant added to every rank in the RRF denominator.
    c: int = 60
    # Metadata key identifying a document; None means page_content is the identity.
    id_key: str | None = None

    @property
    def config_specs(self) -> list[ConfigurableFieldSpec]:
        """List configurable fields for this runnable."""
        # Specs are collected from every child retriever and passed through
        # get_unique_config_specs.
        return get_unique_config_specs(
            spec for retriever in self.retrievers for spec in retriever.config_specs
        )

    @model_validator(mode="before")
    @classmethod
    def _set_weights(cls, values: dict[str, Any]) -> Any:
        """Default ``weights`` to an equal split across ``retrievers``.

        This runs before field validation, so ``values`` is the raw input mapping.
        A new mapping is returned instead of mutating ``values`` in place, so a
        dict passed by the caller (e.g. to ``model_validate``) is left untouched.
        """
        if values.get("weights"):
            return values
        retrievers = values.get("retrievers")
        if retrievers is None:
            # Let normal field validation report the missing `retrievers` field
            # instead of failing here with a bare KeyError.
            return values
        n_retrievers = len(retrievers)
        # An empty ensemble gets no weights instead of raising ZeroDivisionError.
        weights = [1 / n_retrievers] * n_retrievers if n_retrievers else []
        return {**values, "weights": weights}

    @override
    def invoke(
        self,
        input: str,
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> list[Document]:
        """Retrieve and fuse documents for ``input`` as a single traced retriever run.

        Args:
            input: The query to search for.
            config: Optional runnable config. Its callbacks, tags, metadata and
                run name configure this run, and the config is forwarded (with
                per-retriever child callbacks) to every child retriever.
            **kwargs: Forwarded to the retriever start/end callbacks; ``verbose``
                is also read to configure the callback manager.

        Returns:
            The fused, de-duplicated list of documents.
        """
        from langchain_core.callbacks import CallbackManager

        config = ensure_config(config)
        callback_manager = CallbackManager.configure(
            config.get("callbacks"),
            None,
            verbose=kwargs.get("verbose", False),
            inheritable_tags=config.get("tags", []),
            local_tags=self.tags,
            inheritable_metadata=config.get("metadata", {}),
            local_metadata=self.metadata,
        )
        run_manager = callback_manager.on_retriever_start(
            None,
            input,
            name=config.get("run_name") or self.get_name(),
            **kwargs,
        )
        try:
            result = self.rank_fusion(input, run_manager=run_manager, config=config)
        except Exception as e:
            # Report the failure on this run, then propagate it unchanged.
            run_manager.on_retriever_error(e)
            raise
        else:
            run_manager.on_retriever_end(
                result,
                **kwargs,
            )
            return result

    @override
    async def ainvoke(
        self,
        input: str,
        config: RunnableConfig | None = None,
        **kwargs: Any,
    ) -> list[Document]:
        """Asynchronously retrieve and fuse documents for ``input`` as one traced run.

        Args:
            input: The query to search for.
            config: Optional runnable config. Its callbacks, tags, metadata and
                run name configure this run, and the config is forwarded (with
                per-retriever child callbacks) to every child retriever.
            **kwargs: Forwarded to the retriever start/end callbacks; ``verbose``
                is also read to configure the callback manager.

        Returns:
            The fused, de-duplicated list of documents.
        """
        from langchain_core.callbacks import AsyncCallbackManager

        config = ensure_config(config)
        callback_manager = AsyncCallbackManager.configure(
            config.get("callbacks"),
            None,
            verbose=kwargs.get("verbose", False),
            inheritable_tags=config.get("tags", []),
            local_tags=self.tags,
            inheritable_metadata=config.get("metadata", {}),
            local_metadata=self.metadata,
        )
        run_manager = await callback_manager.on_retriever_start(
            None,
            input,
            name=config.get("run_name") or self.get_name(),
            **kwargs,
        )
        try:
            result = await self.arank_fusion(
                input,
                run_manager=run_manager,
                config=config,
            )
        except Exception as e:
            # Report the failure on this run, then propagate it unchanged.
            await run_manager.on_retriever_error(e)
            raise
        else:
            await run_manager.on_retriever_end(
                result,
                **kwargs,
            )
            return result

    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Get the relevant documents for a given query.

        Args:
            query: The query to search for.
            run_manager: The callback handler to use.

        Returns:
            A list of reranked documents.
        """
        # Get fused result of the retrievers.
        return self.rank_fusion(query, run_manager)

    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
    ) -> list[Document]:
        """Asynchronously get the relevant documents for a given query.

        Args:
            query: The query to search for.
            run_manager: The callback handler to use.

        Returns:
            A list of reranked documents.
        """
        # Get fused result of the retrievers.
        return await self.arank_fusion(query, run_manager)

    @staticmethod
    def _coerce_documents(docs: Iterable[Any]) -> list[Document]:
        """Wrap bare strings from a retriever's output into ``Document`` objects.

        Shared by the sync and async paths so both treat retriever output the
        same way: strings are wrapped, every other item is passed through as-is.
        """
        return [
            Document(page_content=doc) if isinstance(doc, str) else doc
            for doc in docs
        ]

    def _doc_key(self, doc: Document) -> Hashable:
        """Return the identity used to merge one document across retrievers.

        This is ``doc.metadata[self.id_key]`` when ``id_key`` is set (a
        ``KeyError`` is raised if the document lacks that key) and
        ``doc.page_content`` otherwise.
        """
        if self.id_key is None:
            return doc.page_content
        return doc.metadata[self.id_key]

    def rank_fusion(
        self,
        query: str,
        run_manager: CallbackManagerForRetrieverRun,
        *,
        config: RunnableConfig | None = None,
    ) -> list[Document]:
        """Rank fusion.

        Retrieve the results of the retrievers one after another, then fuse
        them with weighted Reciprocal Rank Fusion.

        Args:
            query: The query to search for.
            run_manager: The callback handler to use.
            config: Optional configuration for the retrievers.

        Returns:
            A list of reranked documents.
        """
        # Each retriever runs under its own child callback tagged with its
        # 1-based position, so its run can be told apart in traces.
        retriever_docs = [
            retriever.invoke(
                query,
                patch_config(
                    config,
                    callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"),
                ),
            )
            for i, retriever in enumerate(self.retrievers)
        ]
        return self.weighted_reciprocal_rank(
            [self._coerce_documents(docs) for docs in retriever_docs],
        )

    async def arank_fusion(
        self,
        query: str,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        *,
        config: RunnableConfig | None = None,
    ) -> list[Document]:
        """Rank fusion.

        Asynchronously retrieve the results of all retrievers concurrently
        (via ``asyncio.gather``), then fuse them with weighted Reciprocal Rank
        Fusion.

        Args:
            query: The query to search for.
            run_manager: The callback handler to use.
            config: Optional configuration for the retrievers.

        Returns:
            A list of reranked documents.
        """
        # gather preserves input order, so results stay aligned with self.weights.
        retriever_docs = await asyncio.gather(
            *[
                retriever.ainvoke(
                    query,
                    patch_config(
                        config,
                        callbacks=run_manager.get_child(tag=f"retriever_{i + 1}"),
                    ),
                )
                for i, retriever in enumerate(self.retrievers)
            ],
        )
        return self.weighted_reciprocal_rank(
            [self._coerce_documents(docs) for docs in retriever_docs],
        )

    def weighted_reciprocal_rank(
        self,
        doc_lists: list[list[Document]],
    ) -> list[Document]:
        """Perform weighted Reciprocal Rank Fusion on multiple rank lists.

        You can find more details about RRF here:
        https://plg.uwaterloo.ca/~gvcormac/cormacksigir09-rrf.pdf.

        Args:
            doc_lists: A list of rank lists, where each rank list contains unique items.
                Must be aligned with, and the same length as, ``self.weights``.

        Returns:
            The final aggregated list of items sorted by their weighted RRF
            scores in descending order.

        Raises:
            ValueError: If ``len(doc_lists) != len(self.weights)``.
        """
        if len(doc_lists) != len(self.weights):
            msg = "Number of rank lists must be equal to the number of weights."
            raise ValueError(msg)

        # Accumulate each document identity's RRF score; the same identity
        # returned by several retrievers collects a contribution from each.
        rrf_score: dict[Hashable, float] = defaultdict(float)
        for doc_list, weight in zip(doc_lists, self.weights, strict=False):
            for rank, doc in enumerate(doc_list, start=1):
                rrf_score[self._doc_key(doc)] += weight / (rank + self.c)

        # Keep the first occurrence of each identity, then order by score.
        # sorted() is stable, so ties keep that first-seen order.
        all_docs = chain.from_iterable(doc_lists)
        return sorted(
            unique_by_key(all_docs, self._doc_key),
            reverse=True,
            key=lambda doc: rrf_score[self._doc_key(doc)],
        )
|