From 012d9f6c0a3f6bbe8d284d2f74fe9b57bdd26835 Mon Sep 17 00:00:00 2001 From: Yassin Kortam Date: Thu, 11 Jun 2026 10:34:26 -0700 Subject: [PATCH] feat(rate-limiter): allow opting out of v3 TPM reservation and Redis circuit breaker (#30211) --- litellm/caching/redis_cache.py | 16 ++- litellm/constants.py | 3 + .../hooks/parallel_request_limiter_v3.py | 28 +++-- scripts/health_check/health_check_client.py | 4 +- tests/test_litellm/caching/test_dual_cache.py | 61 ++++++++++ .../hooks/test_parallel_request_limiter_v3.py | 108 ++++++++++++++++++ 6 files changed, 208 insertions(+), 12 deletions(-) diff --git a/litellm/caching/redis_cache.py b/litellm/caching/redis_cache.py index cb9ce475d3..7239bea785 100644 --- a/litellm/caching/redis_cache.py +++ b/litellm/caching/redis_cache.py @@ -22,6 +22,7 @@ import litellm from litellm._logging import print_verbose, verbose_logger from litellm.constants import ( DEFAULT_REDIS_MAJOR_VERSION, + REDIS_CIRCUIT_BREAKER_ENABLED, REDIS_CIRCUIT_BREAKER_FAILURE_THRESHOLD, REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT, ) @@ -114,15 +115,23 @@ class RedisCircuitBreaker: OPEN = "open" HALF_OPEN = "half_open" - def __init__(self, failure_threshold: int, recovery_timeout: int) -> None: + def __init__( + self, + failure_threshold: int, + recovery_timeout: int, + enabled: bool = True, + ) -> None: self.failure_threshold = failure_threshold self.recovery_timeout = recovery_timeout + self.enabled = enabled self._failure_count = 0 self._opened_at: Optional[float] = None self._state = self.CLOSED def is_open(self) -> bool: """Returns True if Redis calls should be skipped.""" + if not self.enabled: + return False if self._state == self.HALF_OPEN: # Probe already in flight — fast-fail all concurrent requests. # Only the one call that caused the OPEN→HALF_OPEN transition @@ -136,6 +145,8 @@ class RedisCircuitBreaker: return False def record_failure(self) -> None: + if not self.enabled: + return self._failure_count += 1 self._opened_at = time.time() if self._failure_count >= self.failure_threshold: @@ -149,6 +160,8 @@ class RedisCircuitBreaker: self._state = self.OPEN def record_success(self) -> None: + if not self.enabled: + return if self._state == self.HALF_OPEN: verbose_logger.info("Redis circuit breaker CLOSED — Redis recovered") self._failure_count = 0 @@ -243,6 +256,7 @@ class RedisCache(BaseCache): self._circuit_breaker = RedisCircuitBreaker( failure_threshold=REDIS_CIRCUIT_BREAKER_FAILURE_THRESHOLD, recovery_timeout=REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT, + enabled=REDIS_CIRCUIT_BREAKER_ENABLED, ) self._setup_health_pings() diff --git a/litellm/constants.py b/litellm/constants.py index 57f55e6c17..ab8e57d735 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -398,6 +398,9 @@ REDIS_CIRCUIT_BREAKER_FAILURE_THRESHOLD = int( REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT = int( os.getenv("REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT", 60) ) +REDIS_CIRCUIT_BREAKER_ENABLED = ( + os.getenv("REDIS_CIRCUIT_BREAKER_ENABLED", "true").lower() == "true" +) # Default Redis major version to assume when version cannot be determined # Using 7 as it's the modern version that supports LPOP with count parameter DEFAULT_REDIS_MAJOR_VERSION = int(os.getenv("DEFAULT_REDIS_MAJOR_VERSION", 7)) diff --git a/litellm/proxy/hooks/parallel_request_limiter_v3.py b/litellm/proxy/hooks/parallel_request_limiter_v3.py index 6b70cea65a..f45c63d138 100644 --- a/litellm/proxy/hooks/parallel_request_limiter_v3.py +++ b/litellm/proxy/hooks/parallel_request_limiter_v3.py @@ -313,6 +313,14 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger): self.window_size = int(os.getenv("LITELLM_RATE_LIMIT_WINDOW_SIZE", 60)) + # When disabled, TPM is enforced post-call from actual usage (pre-v1.82 + # behavior) instead of reserving an estimated budget upfront, shedding + # the extra per-request Redis Lua round-trip and the global-lock + # in-memory fallback that the reservation path incurs. + self.tpm_reservation_enabled = ( + os.getenv("LITELLM_TPM_TOKEN_RESERVATION_ENABLED", "true").lower() == "true" + ) + # Batch rate limiter (lazy loaded) self._batch_rate_limiter: Optional[Any] = None @@ -2113,17 +2121,19 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger): # Only check rate limits if we have descriptors with actual limits if descriptors: # First pass: RPM and max_parallel_requests sliding-window check. - # `skip_tpm_check=True` tells should_rate_limit to ignore each - # descriptor's tokens_per_unit so its +1-per-key Lua / in-memory - # increment never touches the :tokens counters — those are owned - # exclusively by the atomic reserve_tpm_tokens path below. Without - # this, every concurrent in-flight request would pre-inflate the - # :tokens counter by 1, shrinking the effective TPM budget by N - # and causing false-positive 429s under bursts. + # When reservation is enabled, `skip_tpm_check=True` tells + # should_rate_limit to ignore each descriptor's tokens_per_unit so + # its +1-per-key Lua / in-memory increment never touches the + # :tokens counters — those are owned exclusively by the atomic + # reserve_tpm_tokens path below. Without this, every concurrent + # in-flight request would pre-inflate the :tokens counter by 1, + # shrinking the effective TPM budget by N and causing + # false-positive 429s under bursts. When reservation is disabled, + # this pass enforces TPM directly from the post-call counters. response = await self.should_rate_limit( descriptors=descriptors, parent_otel_span=user_api_key_dict.parent_otel_span, - skip_tpm_check=True, + skip_tpm_check=self.tpm_reservation_enabled, ) if response["overall_code"] == "OVER_LIMIT": @@ -2153,7 +2163,7 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger): ] has_tpm_limits = bool(configured_tpm_limits) - if has_tpm_limits: + if has_tpm_limits and self.tpm_reservation_enabled: min_configured_tpm_limit = min(configured_tpm_limits) # When the configured TPM cap is small enough to constrain the diff --git a/scripts/health_check/health_check_client.py b/scripts/health_check/health_check_client.py index 497fd6271b..9ef8b93496 100644 --- a/scripts/health_check/health_check_client.py +++ b/scripts/health_check/health_check_client.py @@ -54,7 +54,7 @@ class LiteLLMHealthCheckClient: timeout: Request timeout in seconds (default: 120, matching Go implementation) completion_prompt: Test prompt for chat/completion models embedding_text: Test text for embedding models - custom_auth_header: Optional custom header name for authentication (e.g., "x-ifood-requester-service"). + custom_auth_header: Optional custom header name for authentication (e.g., "x-requester-service"). If provided, uses this header instead of standard "Authorization" header. """ self.base_url = base_url.rstrip("/") @@ -404,7 +404,7 @@ async def main(): yaml_path = os.environ.get("LITELLM_MODELS_YAML") custom_auth_header = os.environ.get( "LITELLM_CUSTOM_AUTH_HEADER" - ) # e.g., "x-ifood-requester-service" + ) # e.g., "x-requester-service" # Debug: Print custom auth header value if set if custom_auth_header: diff --git a/tests/test_litellm/caching/test_dual_cache.py b/tests/test_litellm/caching/test_dual_cache.py index 6477472620..f4f88def78 100644 --- a/tests/test_litellm/caching/test_dual_cache.py +++ b/tests/test_litellm/caching/test_dual_cache.py @@ -243,6 +243,67 @@ def test_circuit_breaker_half_open_concurrent_calls_are_fast_failed(): ), "concurrent callers should be fast-failed in HALF_OPEN" +def test_circuit_breaker_disabled_never_opens(): + """When disabled, failures never open the circuit and is_open() stays False.""" + from litellm.caching.redis_cache import RedisCircuitBreaker + + cb = RedisCircuitBreaker(failure_threshold=3, recovery_timeout=60, enabled=False) + + for _ in range(100): + cb.record_failure() + + assert cb._state == "closed" + assert cb.is_open() is False + + +def test_circuit_breaker_disabled_record_success_leaves_state_untouched(): + """ + A disabled breaker must not mutate state in any state-machine method. Force + a non-default (OPEN) state and assert record_success() returns without + resetting it — the same enabled-guard contract as is_open/record_failure. + """ + from litellm.caching.redis_cache import RedisCircuitBreaker + + cb = RedisCircuitBreaker(failure_threshold=3, recovery_timeout=60, enabled=False) + cb._state = "open" + cb._failure_count = 3 + + cb.record_success() + + assert cb._state == "open" + assert cb._failure_count == 3 + + +@pytest.mark.asyncio +async def test_circuit_breaker_disabled_guard_always_calls_method(): + """A disabled breaker lets every guarded call through, even after failures.""" + from litellm.caching.redis_cache import ( + RedisCircuitBreaker, + _redis_circuit_breaker_guard, + ) + + class FakeRedis: + def __init__(self): + self._circuit_breaker = RedisCircuitBreaker( + failure_threshold=1, recovery_timeout=60, enabled=False + ) + self.call_count = 0 + + @_redis_circuit_breaker_guard + async def boom(self): + self.call_count += 1 + raise RuntimeError("redis down") + + fr = FakeRedis() + for _ in range(5): + with pytest.raises(RuntimeError, match="redis down"): + await fr.boom() + + # Every call reached the method body; the breaker never short-circuited. + assert fr.call_count == 5 + assert fr._circuit_breaker.is_open() is False + + @pytest.mark.asyncio async def test_async_increment_cache_returns_none_when_no_in_memory_cache_and_redis_fails(): """ diff --git a/tests/test_litellm/proxy/hooks/test_parallel_request_limiter_v3.py b/tests/test_litellm/proxy/hooks/test_parallel_request_limiter_v3.py index d10311b9f4..ae699ff8e1 100644 --- a/tests/test_litellm/proxy/hooks/test_parallel_request_limiter_v3.py +++ b/tests/test_litellm/proxy/hooks/test_parallel_request_limiter_v3.py @@ -3187,3 +3187,111 @@ def test_get_key_mcp_rpm_limit_precedence(): none_set = UserAPIKeyAuth(api_key=hash_token("sk-mcp-key")) assert get_key_mcp_rpm_limit(none_set) is None assert get_team_mcp_rpm_limit(none_set) is None + + +def test_tpm_reservation_enabled_by_default(monkeypatch): + """Upfront TPM reservation is on unless explicitly disabled via env.""" + monkeypatch.delenv("LITELLM_TPM_TOKEN_RESERVATION_ENABLED", raising=False) + handler = _PROXY_MaxParallelRequestsHandler( + internal_usage_cache=InternalUsageCache(DualCache()) + ) + assert handler.tpm_reservation_enabled is True + + +@pytest.mark.parametrize("value", ["false", "False", "FALSE"]) +def test_tpm_reservation_disabled_via_env(monkeypatch, value): + monkeypatch.setenv("LITELLM_TPM_TOKEN_RESERVATION_ENABLED", value) + handler = _PROXY_MaxParallelRequestsHandler( + internal_usage_cache=InternalUsageCache(DualCache()) + ) + assert handler.tpm_reservation_enabled is False + + +@pytest.mark.asyncio +async def test_pre_call_hook_reserves_tpm_when_enabled(monkeypatch): + """ + With reservation enabled, the pre-call hook reserves the estimated token + budget upfront and tells should_rate_limit to skip the :tokens counter so + only the reservation path owns it. + """ + monkeypatch.delenv("LITELLM_TPM_TOKEN_RESERVATION_ENABLED", raising=False) + handler = _PROXY_MaxParallelRequestsHandler( + internal_usage_cache=InternalUsageCache(DualCache()) + ) + + user_api_key_dict = UserAPIKeyAuth(api_key=hash_token("sk-tpm"), tpm_limit=10_000) + + should_rate_limit_calls: List[Dict[str, Any]] = [] + original_should_rate_limit = handler.should_rate_limit + + async def spy_should_rate_limit(*args, **kwargs): + should_rate_limit_calls.append(kwargs) + return await original_should_rate_limit(*args, **kwargs) + + reserve_calls: List[int] = [] + original_reserve = handler.reserve_tpm_tokens + + async def spy_reserve(*args, **kwargs): + reserve_calls.append(kwargs.get("estimated_tokens")) + return await original_reserve(*args, **kwargs) + + monkeypatch.setattr(handler, "should_rate_limit", spy_should_rate_limit) + monkeypatch.setattr(handler, "reserve_tpm_tokens", spy_reserve) + + await handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=handler.internal_usage_cache.dual_cache, + data={"model": "gpt-4", "messages": [{"role": "user", "content": "hi"}]}, + call_type="completion", + ) + + assert len(reserve_calls) == 1, "reservation must run when enabled" + assert should_rate_limit_calls[0]["skip_tpm_check"] is True + + +@pytest.mark.asyncio +async def test_pre_call_hook_skips_reservation_when_disabled(monkeypatch): + """ + With reservation disabled, the pre-call hook never calls reserve_tpm_tokens + and enforces TPM directly in should_rate_limit (skip_tpm_check=False), the + pre-v1.82 post-call accounting behavior. + """ + monkeypatch.setenv("LITELLM_TPM_TOKEN_RESERVATION_ENABLED", "false") + handler = _PROXY_MaxParallelRequestsHandler( + internal_usage_cache=InternalUsageCache(DualCache()) + ) + + user_api_key_dict = UserAPIKeyAuth(api_key=hash_token("sk-tpm"), tpm_limit=10_000) + + should_rate_limit_calls: List[Dict[str, Any]] = [] + original_should_rate_limit = handler.should_rate_limit + + async def spy_should_rate_limit(*args, **kwargs): + should_rate_limit_calls.append(kwargs) + return await original_should_rate_limit(*args, **kwargs) + + reserve_calls: List[Any] = [] + + async def spy_reserve(*args, **kwargs): + reserve_calls.append(kwargs) + raise AssertionError("reserve_tpm_tokens must not run when disabled") + + monkeypatch.setattr(handler, "should_rate_limit", spy_should_rate_limit) + monkeypatch.setattr(handler, "reserve_tpm_tokens", spy_reserve) + + data = {"model": "gpt-4", "messages": [{"role": "user", "content": "hi"}]} + await handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=handler.internal_usage_cache.dual_cache, + data=data, + call_type="completion", + ) + + assert reserve_calls == [], "reservation must be skipped when disabled" + assert should_rate_limit_calls[0]["skip_tpm_check"] is False + # No reservation stash leaks into the request metadata. + from litellm.proxy.hooks.parallel_request_limiter_v3 import ( + TPM_RESERVED_TOKENS_KEY, + ) + + assert TPM_RESERVED_TOKENS_KEY not in (data.get("metadata") or {})