46 lines
1.3 KiB
Python
46 lines
1.3 KiB
Python
from enum import Enum
|
|
|
|
from langchain_core.exceptions import OutputParserException
|
|
from langchain_core.output_parsers import BaseOutputParser
|
|
from langchain_core.utils import pre_init
|
|
from typing_extensions import override
|
|
|
|
|
|
class EnumOutputParser(BaseOutputParser[Enum]):
|
|
"""Parse an output that is one of a set of values."""
|
|
|
|
enum: type[Enum]
|
|
"""The enum to parse. Its values must be strings."""
|
|
|
|
@pre_init
|
|
def _raise_deprecation(cls, values: dict) -> dict:
|
|
enum = values["enum"]
|
|
if not all(isinstance(e.value, str) for e in enum):
|
|
msg = "Enum values must be strings"
|
|
raise ValueError(msg)
|
|
return values
|
|
|
|
@property
|
|
def _valid_values(self) -> list[str]:
|
|
return [e.value for e in self.enum]
|
|
|
|
@override
|
|
def parse(self, response: str) -> Enum:
|
|
try:
|
|
return self.enum(response.strip())
|
|
except ValueError as e:
|
|
msg = (
|
|
f"Response '{response}' is not one of the "
|
|
f"expected values: {self._valid_values}"
|
|
)
|
|
raise OutputParserException(msg) from e
|
|
|
|
@override
|
|
def get_format_instructions(self) -> str:
|
|
return f"Select one of the following options: {', '.join(self._valid_values)}"
|
|
|
|
@property
|
|
@override
|
|
def OutputType(self) -> type[Enum]:
|
|
return self.enum
|