69 lines
1.9 KiB
Python
69 lines
1.9 KiB
Python
|
|
from typing import Any
|
||
|
|
|
||
|
|
from langchain_core.callbacks import (
|
||
|
|
AsyncCallbackManagerForRetrieverRun,
|
||
|
|
CallbackManagerForRetrieverRun,
|
||
|
|
)
|
||
|
|
from langchain_core.documents import BaseDocumentCompressor, Document
|
||
|
|
from langchain_core.retrievers import BaseRetriever, RetrieverLike
|
||
|
|
from pydantic import ConfigDict
|
||
|
|
from typing_extensions import override
|
||
|
|
|
||
|
|
|
||
|
|
class ContextualCompressionRetriever(BaseRetriever):
    """Retriever that wraps a base retriever and compresses the results."""

    base_compressor: BaseDocumentCompressor
    """Compressor for compressing retrieved documents."""

    base_retriever: RetrieverLike
    """Base Retriever to use for getting relevant documents."""

    # `arbitrary_types_allowed` is enabled so pydantic accepts the
    # compressor / retriever field types declared above.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @override
    def _get_relevant_documents(
        self,
        query: str,
        *,
        run_manager: CallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> list[Document]:
        """Retrieve documents for ``query`` and compress them.

        Args:
            query: The query string handed to both the base retriever and
                the compressor.
            run_manager: Callback manager for this run; a child manager is
                derived from it for each nested call.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``base_retriever.invoke``.

        Returns:
            The compressor's output as a list, or an empty list when the
            base retriever yields nothing (the compressor is not called in
            that case).
        """
        retrieved = self.base_retriever.invoke(
            query,
            config={"callbacks": run_manager.get_child()},
            **kwargs,
        )
        # Nothing retrieved: skip the compressor entirely.
        if not retrieved:
            return []

        compressed = self.base_compressor.compress_documents(
            retrieved,
            query,
            callbacks=run_manager.get_child(),
        )
        return list(compressed)

    @override
    async def _aget_relevant_documents(
        self,
        query: str,
        *,
        run_manager: AsyncCallbackManagerForRetrieverRun,
        **kwargs: Any,
    ) -> list[Document]:
        """Asynchronously retrieve documents for ``query`` and compress them.

        Args:
            query: The query string handed to both the base retriever and
                the compressor.
            run_manager: Async callback manager for this run; a child manager
                is derived from it for each nested call.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``base_retriever.ainvoke``.

        Returns:
            The compressor's output as a list, or an empty list when the
            base retriever yields nothing (the compressor is not awaited in
            that case).
        """
        retrieved = await self.base_retriever.ainvoke(
            query,
            config={"callbacks": run_manager.get_child()},
            **kwargs,
        )
        # Nothing retrieved: skip the compressor entirely.
        if not retrieved:
            return []

        compressed = await self.base_compressor.acompress_documents(
            retrieved,
            query,
            callbacks=run_manager.get_child(),
        )
        return list(compressed)
|