84 lines
2.6 KiB
Python
84 lines
2.6 KiB
Python
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
from collections.abc import AsyncIterator
|
||
|
|
from typing import Any, Literal, cast
|
||
|
|
|
||
|
|
from langchain_core.callbacks import AsyncCallbackHandler
|
||
|
|
from langchain_core.outputs import LLMResult
|
||
|
|
from typing_extensions import override
|
||
|
|
|
||
|
|
# TODO: If used by two LLM runs in parallel this won't work as expected
|
||
|
|
|
||
|
|
|
||
|
|
class AsyncIteratorCallbackHandler(AsyncCallbackHandler):
|
||
|
|
"""Callback handler that returns an async iterator."""
|
||
|
|
|
||
|
|
queue: asyncio.Queue[str]
|
||
|
|
|
||
|
|
done: asyncio.Event
|
||
|
|
|
||
|
|
@property
|
||
|
|
def always_verbose(self) -> bool:
|
||
|
|
"""Always verbose."""
|
||
|
|
return True
|
||
|
|
|
||
|
|
def __init__(self) -> None:
|
||
|
|
"""Instantiate AsyncIteratorCallbackHandler."""
|
||
|
|
self.queue = asyncio.Queue()
|
||
|
|
self.done = asyncio.Event()
|
||
|
|
|
||
|
|
@override
|
||
|
|
async def on_llm_start(
|
||
|
|
self,
|
||
|
|
serialized: dict[str, Any],
|
||
|
|
prompts: list[str],
|
||
|
|
**kwargs: Any,
|
||
|
|
) -> None:
|
||
|
|
# If two calls are made in a row, this resets the state
|
||
|
|
self.done.clear()
|
||
|
|
|
||
|
|
@override
|
||
|
|
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||
|
|
if token is not None and token != "":
|
||
|
|
self.queue.put_nowait(token)
|
||
|
|
|
||
|
|
@override
|
||
|
|
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
|
||
|
|
self.done.set()
|
||
|
|
|
||
|
|
@override
|
||
|
|
async def on_llm_error(self, error: BaseException, **kwargs: Any) -> None:
|
||
|
|
self.done.set()
|
||
|
|
|
||
|
|
# TODO: implement the other methods
|
||
|
|
|
||
|
|
async def aiter(self) -> AsyncIterator[str]:
|
||
|
|
"""Asynchronous iterator that yields tokens."""
|
||
|
|
while not self.queue.empty() or not self.done.is_set():
|
||
|
|
# Wait for the next token in the queue,
|
||
|
|
# but stop waiting if the done event is set
|
||
|
|
done, other = await asyncio.wait(
|
||
|
|
[
|
||
|
|
# NOTE: If you add other tasks here, update the code below,
|
||
|
|
# which assumes each set has exactly one task each
|
||
|
|
asyncio.ensure_future(self.queue.get()),
|
||
|
|
asyncio.ensure_future(self.done.wait()),
|
||
|
|
],
|
||
|
|
return_when=asyncio.FIRST_COMPLETED,
|
||
|
|
)
|
||
|
|
|
||
|
|
# Cancel the other task
|
||
|
|
if other:
|
||
|
|
other.pop().cancel()
|
||
|
|
|
||
|
|
# Extract the value of the first completed task
|
||
|
|
token_or_done = cast("str | Literal[True]", done.pop().result())
|
||
|
|
|
||
|
|
# If the extracted value is the boolean True, the done event was set
|
||
|
|
if token_or_done is True:
|
||
|
|
break
|
||
|
|
|
||
|
|
# Otherwise, the extracted value is a token, which we yield
|
||
|
|
yield token_or_done
|