feat(rate-limiter): allow opting out of v3 TPM reservation and Redis circuit breaker (#30211)

This commit is contained in:
Yassin Kortam 2026-06-11 10:34:26 -07:00 committed by GitHub
parent 0d120de785
commit 012d9f6c0a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 208 additions and 12 deletions

View File

@ -22,6 +22,7 @@ import litellm
from litellm._logging import print_verbose, verbose_logger from litellm._logging import print_verbose, verbose_logger
from litellm.constants import ( from litellm.constants import (
DEFAULT_REDIS_MAJOR_VERSION, DEFAULT_REDIS_MAJOR_VERSION,
REDIS_CIRCUIT_BREAKER_ENABLED,
REDIS_CIRCUIT_BREAKER_FAILURE_THRESHOLD, REDIS_CIRCUIT_BREAKER_FAILURE_THRESHOLD,
REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT, REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT,
) )
@ -114,15 +115,23 @@ class RedisCircuitBreaker:
OPEN = "open" OPEN = "open"
HALF_OPEN = "half_open" HALF_OPEN = "half_open"
def __init__(self, failure_threshold: int, recovery_timeout: int) -> None: def __init__(
self,
failure_threshold: int,
recovery_timeout: int,
enabled: bool = True,
) -> None:
self.failure_threshold = failure_threshold self.failure_threshold = failure_threshold
self.recovery_timeout = recovery_timeout self.recovery_timeout = recovery_timeout
self.enabled = enabled
self._failure_count = 0 self._failure_count = 0
self._opened_at: Optional[float] = None self._opened_at: Optional[float] = None
self._state = self.CLOSED self._state = self.CLOSED
def is_open(self) -> bool: def is_open(self) -> bool:
"""Returns True if Redis calls should be skipped.""" """Returns True if Redis calls should be skipped."""
if not self.enabled:
return False
if self._state == self.HALF_OPEN: if self._state == self.HALF_OPEN:
# Probe already in flight — fast-fail all concurrent requests. # Probe already in flight — fast-fail all concurrent requests.
# Only the one call that caused the OPEN→HALF_OPEN transition # Only the one call that caused the OPEN→HALF_OPEN transition
@ -136,6 +145,8 @@ class RedisCircuitBreaker:
return False return False
def record_failure(self) -> None: def record_failure(self) -> None:
if not self.enabled:
return
self._failure_count += 1 self._failure_count += 1
self._opened_at = time.time() self._opened_at = time.time()
if self._failure_count >= self.failure_threshold: if self._failure_count >= self.failure_threshold:
@ -149,6 +160,8 @@ class RedisCircuitBreaker:
self._state = self.OPEN self._state = self.OPEN
def record_success(self) -> None: def record_success(self) -> None:
if not self.enabled:
return
if self._state == self.HALF_OPEN: if self._state == self.HALF_OPEN:
verbose_logger.info("Redis circuit breaker CLOSED — Redis recovered") verbose_logger.info("Redis circuit breaker CLOSED — Redis recovered")
self._failure_count = 0 self._failure_count = 0
@ -243,6 +256,7 @@ class RedisCache(BaseCache):
self._circuit_breaker = RedisCircuitBreaker( self._circuit_breaker = RedisCircuitBreaker(
failure_threshold=REDIS_CIRCUIT_BREAKER_FAILURE_THRESHOLD, failure_threshold=REDIS_CIRCUIT_BREAKER_FAILURE_THRESHOLD,
recovery_timeout=REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT, recovery_timeout=REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT,
enabled=REDIS_CIRCUIT_BREAKER_ENABLED,
) )
self._setup_health_pings() self._setup_health_pings()

View File

@ -398,6 +398,9 @@ REDIS_CIRCUIT_BREAKER_FAILURE_THRESHOLD = int(
REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT = int( REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT = int(
os.getenv("REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT", 60) os.getenv("REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT", 60)
) )
REDIS_CIRCUIT_BREAKER_ENABLED = (
os.getenv("REDIS_CIRCUIT_BREAKER_ENABLED", "true").lower() == "true"
)
# Default Redis major version to assume when version cannot be determined # Default Redis major version to assume when version cannot be determined
# Using 7 as it's the modern version that supports LPOP with count parameter # Using 7 as it's the modern version that supports LPOP with count parameter
DEFAULT_REDIS_MAJOR_VERSION = int(os.getenv("DEFAULT_REDIS_MAJOR_VERSION", 7)) DEFAULT_REDIS_MAJOR_VERSION = int(os.getenv("DEFAULT_REDIS_MAJOR_VERSION", 7))

View File

@ -313,6 +313,14 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
self.window_size = int(os.getenv("LITELLM_RATE_LIMIT_WINDOW_SIZE", 60)) self.window_size = int(os.getenv("LITELLM_RATE_LIMIT_WINDOW_SIZE", 60))
# When disabled, TPM is enforced post-call from actual usage (pre-v1.82
# behavior) instead of reserving an estimated budget upfront, shedding
# the extra per-request Redis Lua round-trip and the global-lock
# in-memory fallback that the reservation path incurs.
self.tpm_reservation_enabled = (
os.getenv("LITELLM_TPM_TOKEN_RESERVATION_ENABLED", "true").lower() == "true"
)
# Batch rate limiter (lazy loaded) # Batch rate limiter (lazy loaded)
self._batch_rate_limiter: Optional[Any] = None self._batch_rate_limiter: Optional[Any] = None
@ -2113,17 +2121,19 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
# Only check rate limits if we have descriptors with actual limits # Only check rate limits if we have descriptors with actual limits
if descriptors: if descriptors:
# First pass: RPM and max_parallel_requests sliding-window check. # First pass: RPM and max_parallel_requests sliding-window check.
# `skip_tpm_check=True` tells should_rate_limit to ignore each # When reservation is enabled, `skip_tpm_check=True` tells
# descriptor's tokens_per_unit so its +1-per-key Lua / in-memory # should_rate_limit to ignore each descriptor's tokens_per_unit so
# increment never touches the :tokens counters — those are owned # its +1-per-key Lua / in-memory increment never touches the
# exclusively by the atomic reserve_tpm_tokens path below. Without # :tokens counters — those are owned exclusively by the atomic
# this, every concurrent in-flight request would pre-inflate the # reserve_tpm_tokens path below. Without this, every concurrent
# :tokens counter by 1, shrinking the effective TPM budget by N # in-flight request would pre-inflate the :tokens counter by 1,
# and causing false-positive 429s under bursts. # shrinking the effective TPM budget by N and causing
# false-positive 429s under bursts. When reservation is disabled,
# this pass enforces TPM directly from the post-call counters.
response = await self.should_rate_limit( response = await self.should_rate_limit(
descriptors=descriptors, descriptors=descriptors,
parent_otel_span=user_api_key_dict.parent_otel_span, parent_otel_span=user_api_key_dict.parent_otel_span,
skip_tpm_check=True, skip_tpm_check=self.tpm_reservation_enabled,
) )
if response["overall_code"] == "OVER_LIMIT": if response["overall_code"] == "OVER_LIMIT":
@ -2153,7 +2163,7 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
] ]
has_tpm_limits = bool(configured_tpm_limits) has_tpm_limits = bool(configured_tpm_limits)
if has_tpm_limits: if has_tpm_limits and self.tpm_reservation_enabled:
min_configured_tpm_limit = min(configured_tpm_limits) min_configured_tpm_limit = min(configured_tpm_limits)
# When the configured TPM cap is small enough to constrain the # When the configured TPM cap is small enough to constrain the

View File

@ -54,7 +54,7 @@ class LiteLLMHealthCheckClient:
timeout: Request timeout in seconds (default: 120, matching Go implementation) timeout: Request timeout in seconds (default: 120, matching Go implementation)
completion_prompt: Test prompt for chat/completion models completion_prompt: Test prompt for chat/completion models
embedding_text: Test text for embedding models embedding_text: Test text for embedding models
custom_auth_header: Optional custom header name for authentication (e.g., "x-ifood-requester-service"). custom_auth_header: Optional custom header name for authentication (e.g., "x-requester-service").
If provided, uses this header instead of standard "Authorization" header. If provided, uses this header instead of standard "Authorization" header.
""" """
self.base_url = base_url.rstrip("/") self.base_url = base_url.rstrip("/")
@ -404,7 +404,7 @@ async def main():
yaml_path = os.environ.get("LITELLM_MODELS_YAML") yaml_path = os.environ.get("LITELLM_MODELS_YAML")
custom_auth_header = os.environ.get( custom_auth_header = os.environ.get(
"LITELLM_CUSTOM_AUTH_HEADER" "LITELLM_CUSTOM_AUTH_HEADER"
) # e.g., "x-ifood-requester-service" ) # e.g., "x-requester-service"
# Debug: Print custom auth header value if set # Debug: Print custom auth header value if set
if custom_auth_header: if custom_auth_header:

View File

@ -243,6 +243,67 @@ def test_circuit_breaker_half_open_concurrent_calls_are_fast_failed():
), "concurrent callers should be fast-failed in HALF_OPEN" ), "concurrent callers should be fast-failed in HALF_OPEN"
def test_circuit_breaker_disabled_never_opens():
"""When disabled, failures never open the circuit and is_open() stays False."""
from litellm.caching.redis_cache import RedisCircuitBreaker
cb = RedisCircuitBreaker(failure_threshold=3, recovery_timeout=60, enabled=False)
for _ in range(100):
cb.record_failure()
assert cb._state == "closed"
assert cb.is_open() is False
def test_circuit_breaker_disabled_record_success_leaves_state_untouched():
"""
A disabled breaker must not mutate state in any state-machine method. Force
a non-default (OPEN) state and assert record_success() returns without
resetting it the same enabled-guard contract as is_open/record_failure.
"""
from litellm.caching.redis_cache import RedisCircuitBreaker
cb = RedisCircuitBreaker(failure_threshold=3, recovery_timeout=60, enabled=False)
cb._state = "open"
cb._failure_count = 3
cb.record_success()
assert cb._state == "open"
assert cb._failure_count == 3
@pytest.mark.asyncio
async def test_circuit_breaker_disabled_guard_always_calls_method():
"""A disabled breaker lets every guarded call through, even after failures."""
from litellm.caching.redis_cache import (
RedisCircuitBreaker,
_redis_circuit_breaker_guard,
)
class FakeRedis:
def __init__(self):
self._circuit_breaker = RedisCircuitBreaker(
failure_threshold=1, recovery_timeout=60, enabled=False
)
self.call_count = 0
@_redis_circuit_breaker_guard
async def boom(self):
self.call_count += 1
raise RuntimeError("redis down")
fr = FakeRedis()
for _ in range(5):
with pytest.raises(RuntimeError, match="redis down"):
await fr.boom()
# Every call reached the method body; the breaker never short-circuited.
assert fr.call_count == 5
assert fr._circuit_breaker.is_open() is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_increment_cache_returns_none_when_no_in_memory_cache_and_redis_fails(): async def test_async_increment_cache_returns_none_when_no_in_memory_cache_and_redis_fails():
""" """

View File

@ -3187,3 +3187,111 @@ def test_get_key_mcp_rpm_limit_precedence():
none_set = UserAPIKeyAuth(api_key=hash_token("sk-mcp-key")) none_set = UserAPIKeyAuth(api_key=hash_token("sk-mcp-key"))
assert get_key_mcp_rpm_limit(none_set) is None assert get_key_mcp_rpm_limit(none_set) is None
assert get_team_mcp_rpm_limit(none_set) is None assert get_team_mcp_rpm_limit(none_set) is None
def test_tpm_reservation_enabled_by_default(monkeypatch):
"""Upfront TPM reservation is on unless explicitly disabled via env."""
monkeypatch.delenv("LITELLM_TPM_TOKEN_RESERVATION_ENABLED", raising=False)
handler = _PROXY_MaxParallelRequestsHandler(
internal_usage_cache=InternalUsageCache(DualCache())
)
assert handler.tpm_reservation_enabled is True
@pytest.mark.parametrize("value", ["false", "False", "FALSE"])
def test_tpm_reservation_disabled_via_env(monkeypatch, value):
monkeypatch.setenv("LITELLM_TPM_TOKEN_RESERVATION_ENABLED", value)
handler = _PROXY_MaxParallelRequestsHandler(
internal_usage_cache=InternalUsageCache(DualCache())
)
assert handler.tpm_reservation_enabled is False
@pytest.mark.asyncio
async def test_pre_call_hook_reserves_tpm_when_enabled(monkeypatch):
"""
With reservation enabled, the pre-call hook reserves the estimated token
budget upfront and tells should_rate_limit to skip the :tokens counter so
only the reservation path owns it.
"""
monkeypatch.delenv("LITELLM_TPM_TOKEN_RESERVATION_ENABLED", raising=False)
handler = _PROXY_MaxParallelRequestsHandler(
internal_usage_cache=InternalUsageCache(DualCache())
)
user_api_key_dict = UserAPIKeyAuth(api_key=hash_token("sk-tpm"), tpm_limit=10_000)
should_rate_limit_calls: List[Dict[str, Any]] = []
original_should_rate_limit = handler.should_rate_limit
async def spy_should_rate_limit(*args, **kwargs):
should_rate_limit_calls.append(kwargs)
return await original_should_rate_limit(*args, **kwargs)
reserve_calls: List[int] = []
original_reserve = handler.reserve_tpm_tokens
async def spy_reserve(*args, **kwargs):
reserve_calls.append(kwargs.get("estimated_tokens"))
return await original_reserve(*args, **kwargs)
monkeypatch.setattr(handler, "should_rate_limit", spy_should_rate_limit)
monkeypatch.setattr(handler, "reserve_tpm_tokens", spy_reserve)
await handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=handler.internal_usage_cache.dual_cache,
data={"model": "gpt-4", "messages": [{"role": "user", "content": "hi"}]},
call_type="completion",
)
assert len(reserve_calls) == 1, "reservation must run when enabled"
assert should_rate_limit_calls[0]["skip_tpm_check"] is True
@pytest.mark.asyncio
async def test_pre_call_hook_skips_reservation_when_disabled(monkeypatch):
"""
With reservation disabled, the pre-call hook never calls reserve_tpm_tokens
and enforces TPM directly in should_rate_limit (skip_tpm_check=False), the
pre-v1.82 post-call accounting behavior.
"""
monkeypatch.setenv("LITELLM_TPM_TOKEN_RESERVATION_ENABLED", "false")
handler = _PROXY_MaxParallelRequestsHandler(
internal_usage_cache=InternalUsageCache(DualCache())
)
user_api_key_dict = UserAPIKeyAuth(api_key=hash_token("sk-tpm"), tpm_limit=10_000)
should_rate_limit_calls: List[Dict[str, Any]] = []
original_should_rate_limit = handler.should_rate_limit
async def spy_should_rate_limit(*args, **kwargs):
should_rate_limit_calls.append(kwargs)
return await original_should_rate_limit(*args, **kwargs)
reserve_calls: List[Any] = []
async def spy_reserve(*args, **kwargs):
reserve_calls.append(kwargs)
raise AssertionError("reserve_tpm_tokens must not run when disabled")
monkeypatch.setattr(handler, "should_rate_limit", spy_should_rate_limit)
monkeypatch.setattr(handler, "reserve_tpm_tokens", spy_reserve)
data = {"model": "gpt-4", "messages": [{"role": "user", "content": "hi"}]}
await handler.async_pre_call_hook(
user_api_key_dict=user_api_key_dict,
cache=handler.internal_usage_cache.dual_cache,
data=data,
call_type="completion",
)
assert reserve_calls == [], "reservation must be skipped when disabled"
assert should_rate_limit_calls[0]["skip_tpm_check"] is False
# No reservation stash leaks into the request metadata.
from litellm.proxy.hooks.parallel_request_limiter_v3 import (
TPM_RESERVED_TOKENS_KEY,
)
assert TPM_RESERVED_TOKENS_KEY not in (data.get("metadata") or {})