124 lines
3.7 KiB
Python
124 lines
3.7 KiB
Python
"""Shared retry utilities for agent middleware.
|
|
|
|
This module contains common constants, utilities, and logic used by both
|
|
model and tool retry middleware implementations.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import random
|
|
from collections.abc import Callable
|
|
from typing import Literal
|
|
|
|
# Type aliases
|
|
RetryOn = tuple[type[Exception], ...] | Callable[[Exception], bool]
"""Type for specifying which exceptions should trigger a retry.

Can be either:

- A tuple of exception types; an exception is retried when it is an instance
  of any of them (an `isinstance` check).
- A callable that takes the exception and returns `True` if it should be retried.
"""
|
|
|
|
OnFailure = Literal["error", "continue"] | Callable[[Exception], str]
"""Type for specifying failure handling behavior.

Can be either:

- A literal action string (`'error'` or `'continue'`):
    - `'error'`: Re-raise the exception, stopping agent execution.
    - `'continue'`: Inject a message with the error details, allowing the agent
      to continue.
      For tool retries, a `ToolMessage` with the error details will be injected.
      For model retries, an `AIMessage` with the error details will be returned.
- A callable that takes an exception and returns a string to use as the error
  message content.
"""
|
|
|
|
|
|
def validate_retry_params(
    max_retries: int,
    initial_delay: float,
    max_delay: float,
    backoff_factor: float,
) -> None:
    """Validate retry parameters.

    Each parameter must be a number greater than or equal to zero. The checks
    run in signature order and the first failing parameter is reported.

    Args:
        max_retries: Maximum number of retry attempts.
        initial_delay: Initial delay in seconds before first retry.
        max_delay: Maximum delay in seconds between retries.
        backoff_factor: Multiplier for exponential backoff.

    Raises:
        ValueError: If any parameter is negative or NaN.
    """
    checks = (
        ("max_retries", max_retries),
        ("initial_delay", initial_delay),
        ("max_delay", max_delay),
        ("backoff_factor", backoff_factor),
    )
    for name, value in checks:
        # `not value >= 0` instead of `value < 0`: every comparison with NaN is
        # false, so the negated form rejects NaN as well as negative numbers.
        if not value >= 0:
            msg = f"{name} must be >= 0"
            raise ValueError(msg)
|
|
|
|
|
|
def should_retry_exception(
    exc: Exception,
    retry_on: RetryOn,
) -> bool:
    """Check if an exception should trigger a retry.

    Args:
        exc: The exception that occurred.
        retry_on: Either a tuple of exception types to retry on, or a callable
            that takes an exception and returns `True` if it should be retried.
            A single exception class is also accepted and is treated like a
            one-element tuple.

    Returns:
        `True` if the exception should be retried, `False` otherwise.
    """
    # Exception classes are themselves callable: calling one would build a new
    # (always truthy) exception instance instead of testing `exc`. Classes must
    # therefore go through `isinstance`, as tuples of classes do.
    if isinstance(retry_on, type) or not callable(retry_on):
        return isinstance(exc, retry_on)
    # Coerce so the declared `bool` return holds even for predicates that
    # return a merely truthy/falsy value.
    return bool(retry_on(exc))
|
|
|
|
|
|
def calculate_delay(
    retry_number: int,
    *,
    backoff_factor: float,
    initial_delay: float,
    max_delay: float,
    jitter: bool,
) -> float:
    """Calculate delay for a retry attempt with exponential backoff and optional jitter.

    The base delay is `initial_delay * backoff_factor**retry_number`, capped at
    `max_delay`. Jitter, when enabled, is applied after the cap, so the returned
    value may differ from the capped delay by up to 25% in either direction.

    Args:
        retry_number: The retry attempt number (0-indexed).
        backoff_factor: Multiplier for exponential backoff.
            Set to `0.0` for constant delay.
        initial_delay: Initial delay in seconds before first retry.
        max_delay: Maximum delay in seconds between retries.
            Caps exponential backoff growth.
        jitter: Whether to add random jitter to delay to avoid thundering herd.

    Returns:
        Delay in seconds before next retry.
    """
    if backoff_factor == 0.0:
        delay = initial_delay
    else:
        try:
            delay = initial_delay * (backoff_factor**retry_number)
        except OverflowError:
            # The growth factor left float range (e.g. 2.0**1024). Any positive
            # initial delay is then far beyond the cap; a zero initial delay
            # stays zero no matter how large the factor is.
            delay = max_delay if initial_delay > 0 else initial_delay

    # Cap at max_delay
    delay = min(delay, max_delay)

    if jitter and delay > 0:
        jitter_amount = delay * 0.25  # ±25% jitter
        delay = delay + random.uniform(-jitter_amount, jitter_amount)  # noqa: S311
        # Ensure delay is not negative after jitter
        delay = max(0.0, delay)

    return delay
|