feat(rate-limiter): allow opting out of v3 TPM reservation and Redis circuit breaker (#30211)
This commit is contained in:
parent
0d120de785
commit
012d9f6c0a
@ -22,6 +22,7 @@ import litellm
|
|||||||
from litellm._logging import print_verbose, verbose_logger
|
from litellm._logging import print_verbose, verbose_logger
|
||||||
from litellm.constants import (
|
from litellm.constants import (
|
||||||
DEFAULT_REDIS_MAJOR_VERSION,
|
DEFAULT_REDIS_MAJOR_VERSION,
|
||||||
|
REDIS_CIRCUIT_BREAKER_ENABLED,
|
||||||
REDIS_CIRCUIT_BREAKER_FAILURE_THRESHOLD,
|
REDIS_CIRCUIT_BREAKER_FAILURE_THRESHOLD,
|
||||||
REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT,
|
REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT,
|
||||||
)
|
)
|
||||||
@ -114,15 +115,23 @@ class RedisCircuitBreaker:
|
|||||||
OPEN = "open"
|
OPEN = "open"
|
||||||
HALF_OPEN = "half_open"
|
HALF_OPEN = "half_open"
|
||||||
|
|
||||||
def __init__(self, failure_threshold: int, recovery_timeout: int) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
failure_threshold: int,
|
||||||
|
recovery_timeout: int,
|
||||||
|
enabled: bool = True,
|
||||||
|
) -> None:
|
||||||
self.failure_threshold = failure_threshold
|
self.failure_threshold = failure_threshold
|
||||||
self.recovery_timeout = recovery_timeout
|
self.recovery_timeout = recovery_timeout
|
||||||
|
self.enabled = enabled
|
||||||
self._failure_count = 0
|
self._failure_count = 0
|
||||||
self._opened_at: Optional[float] = None
|
self._opened_at: Optional[float] = None
|
||||||
self._state = self.CLOSED
|
self._state = self.CLOSED
|
||||||
|
|
||||||
def is_open(self) -> bool:
|
def is_open(self) -> bool:
|
||||||
"""Returns True if Redis calls should be skipped."""
|
"""Returns True if Redis calls should be skipped."""
|
||||||
|
if not self.enabled:
|
||||||
|
return False
|
||||||
if self._state == self.HALF_OPEN:
|
if self._state == self.HALF_OPEN:
|
||||||
# Probe already in flight — fast-fail all concurrent requests.
|
# Probe already in flight — fast-fail all concurrent requests.
|
||||||
# Only the one call that caused the OPEN→HALF_OPEN transition
|
# Only the one call that caused the OPEN→HALF_OPEN transition
|
||||||
@ -136,6 +145,8 @@ class RedisCircuitBreaker:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def record_failure(self) -> None:
|
def record_failure(self) -> None:
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
self._failure_count += 1
|
self._failure_count += 1
|
||||||
self._opened_at = time.time()
|
self._opened_at = time.time()
|
||||||
if self._failure_count >= self.failure_threshold:
|
if self._failure_count >= self.failure_threshold:
|
||||||
@ -149,6 +160,8 @@ class RedisCircuitBreaker:
|
|||||||
self._state = self.OPEN
|
self._state = self.OPEN
|
||||||
|
|
||||||
def record_success(self) -> None:
|
def record_success(self) -> None:
|
||||||
|
if not self.enabled:
|
||||||
|
return
|
||||||
if self._state == self.HALF_OPEN:
|
if self._state == self.HALF_OPEN:
|
||||||
verbose_logger.info("Redis circuit breaker CLOSED — Redis recovered")
|
verbose_logger.info("Redis circuit breaker CLOSED — Redis recovered")
|
||||||
self._failure_count = 0
|
self._failure_count = 0
|
||||||
@ -243,6 +256,7 @@ class RedisCache(BaseCache):
|
|||||||
self._circuit_breaker = RedisCircuitBreaker(
|
self._circuit_breaker = RedisCircuitBreaker(
|
||||||
failure_threshold=REDIS_CIRCUIT_BREAKER_FAILURE_THRESHOLD,
|
failure_threshold=REDIS_CIRCUIT_BREAKER_FAILURE_THRESHOLD,
|
||||||
recovery_timeout=REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT,
|
recovery_timeout=REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT,
|
||||||
|
enabled=REDIS_CIRCUIT_BREAKER_ENABLED,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._setup_health_pings()
|
self._setup_health_pings()
|
||||||
|
|||||||
@ -398,6 +398,9 @@ REDIS_CIRCUIT_BREAKER_FAILURE_THRESHOLD = int(
|
|||||||
REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT = int(
|
REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT = int(
|
||||||
os.getenv("REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT", 60)
|
os.getenv("REDIS_CIRCUIT_BREAKER_RECOVERY_TIMEOUT", 60)
|
||||||
)
|
)
|
||||||
|
REDIS_CIRCUIT_BREAKER_ENABLED = (
|
||||||
|
os.getenv("REDIS_CIRCUIT_BREAKER_ENABLED", "true").lower() == "true"
|
||||||
|
)
|
||||||
# Default Redis major version to assume when version cannot be determined
|
# Default Redis major version to assume when version cannot be determined
|
||||||
# Using 7 as it's the modern version that supports LPOP with count parameter
|
# Using 7 as it's the modern version that supports LPOP with count parameter
|
||||||
DEFAULT_REDIS_MAJOR_VERSION = int(os.getenv("DEFAULT_REDIS_MAJOR_VERSION", 7))
|
DEFAULT_REDIS_MAJOR_VERSION = int(os.getenv("DEFAULT_REDIS_MAJOR_VERSION", 7))
|
||||||
|
|||||||
@ -313,6 +313,14 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
|||||||
|
|
||||||
self.window_size = int(os.getenv("LITELLM_RATE_LIMIT_WINDOW_SIZE", 60))
|
self.window_size = int(os.getenv("LITELLM_RATE_LIMIT_WINDOW_SIZE", 60))
|
||||||
|
|
||||||
|
# When disabled, TPM is enforced post-call from actual usage (pre-v1.82
|
||||||
|
# behavior) instead of reserving an estimated budget upfront, shedding
|
||||||
|
# the extra per-request Redis Lua round-trip and the global-lock
|
||||||
|
# in-memory fallback that the reservation path incurs.
|
||||||
|
self.tpm_reservation_enabled = (
|
||||||
|
os.getenv("LITELLM_TPM_TOKEN_RESERVATION_ENABLED", "true").lower() == "true"
|
||||||
|
)
|
||||||
|
|
||||||
# Batch rate limiter (lazy loaded)
|
# Batch rate limiter (lazy loaded)
|
||||||
self._batch_rate_limiter: Optional[Any] = None
|
self._batch_rate_limiter: Optional[Any] = None
|
||||||
|
|
||||||
@ -2113,17 +2121,19 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
|||||||
# Only check rate limits if we have descriptors with actual limits
|
# Only check rate limits if we have descriptors with actual limits
|
||||||
if descriptors:
|
if descriptors:
|
||||||
# First pass: RPM and max_parallel_requests sliding-window check.
|
# First pass: RPM and max_parallel_requests sliding-window check.
|
||||||
# `skip_tpm_check=True` tells should_rate_limit to ignore each
|
# When reservation is enabled, `skip_tpm_check=True` tells
|
||||||
# descriptor's tokens_per_unit so its +1-per-key Lua / in-memory
|
# should_rate_limit to ignore each descriptor's tokens_per_unit so
|
||||||
# increment never touches the :tokens counters — those are owned
|
# its +1-per-key Lua / in-memory increment never touches the
|
||||||
# exclusively by the atomic reserve_tpm_tokens path below. Without
|
# :tokens counters — those are owned exclusively by the atomic
|
||||||
# this, every concurrent in-flight request would pre-inflate the
|
# reserve_tpm_tokens path below. Without this, every concurrent
|
||||||
# :tokens counter by 1, shrinking the effective TPM budget by N
|
# in-flight request would pre-inflate the :tokens counter by 1,
|
||||||
# and causing false-positive 429s under bursts.
|
# shrinking the effective TPM budget by N and causing
|
||||||
|
# false-positive 429s under bursts. When reservation is disabled,
|
||||||
|
# this pass enforces TPM directly from the post-call counters.
|
||||||
response = await self.should_rate_limit(
|
response = await self.should_rate_limit(
|
||||||
descriptors=descriptors,
|
descriptors=descriptors,
|
||||||
parent_otel_span=user_api_key_dict.parent_otel_span,
|
parent_otel_span=user_api_key_dict.parent_otel_span,
|
||||||
skip_tpm_check=True,
|
skip_tpm_check=self.tpm_reservation_enabled,
|
||||||
)
|
)
|
||||||
|
|
||||||
if response["overall_code"] == "OVER_LIMIT":
|
if response["overall_code"] == "OVER_LIMIT":
|
||||||
@ -2153,7 +2163,7 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
|||||||
]
|
]
|
||||||
has_tpm_limits = bool(configured_tpm_limits)
|
has_tpm_limits = bool(configured_tpm_limits)
|
||||||
|
|
||||||
if has_tpm_limits:
|
if has_tpm_limits and self.tpm_reservation_enabled:
|
||||||
min_configured_tpm_limit = min(configured_tpm_limits)
|
min_configured_tpm_limit = min(configured_tpm_limits)
|
||||||
|
|
||||||
# When the configured TPM cap is small enough to constrain the
|
# When the configured TPM cap is small enough to constrain the
|
||||||
|
|||||||
@ -54,7 +54,7 @@ class LiteLLMHealthCheckClient:
|
|||||||
timeout: Request timeout in seconds (default: 120, matching Go implementation)
|
timeout: Request timeout in seconds (default: 120, matching Go implementation)
|
||||||
completion_prompt: Test prompt for chat/completion models
|
completion_prompt: Test prompt for chat/completion models
|
||||||
embedding_text: Test text for embedding models
|
embedding_text: Test text for embedding models
|
||||||
custom_auth_header: Optional custom header name for authentication (e.g., "x-ifood-requester-service").
|
custom_auth_header: Optional custom header name for authentication (e.g., "x-requester-service").
|
||||||
If provided, uses this header instead of standard "Authorization" header.
|
If provided, uses this header instead of standard "Authorization" header.
|
||||||
"""
|
"""
|
||||||
self.base_url = base_url.rstrip("/")
|
self.base_url = base_url.rstrip("/")
|
||||||
@ -404,7 +404,7 @@ async def main():
|
|||||||
yaml_path = os.environ.get("LITELLM_MODELS_YAML")
|
yaml_path = os.environ.get("LITELLM_MODELS_YAML")
|
||||||
custom_auth_header = os.environ.get(
|
custom_auth_header = os.environ.get(
|
||||||
"LITELLM_CUSTOM_AUTH_HEADER"
|
"LITELLM_CUSTOM_AUTH_HEADER"
|
||||||
) # e.g., "x-ifood-requester-service"
|
) # e.g., "x-requester-service"
|
||||||
|
|
||||||
# Debug: Print custom auth header value if set
|
# Debug: Print custom auth header value if set
|
||||||
if custom_auth_header:
|
if custom_auth_header:
|
||||||
|
|||||||
@ -243,6 +243,67 @@ def test_circuit_breaker_half_open_concurrent_calls_are_fast_failed():
|
|||||||
), "concurrent callers should be fast-failed in HALF_OPEN"
|
), "concurrent callers should be fast-failed in HALF_OPEN"
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_breaker_disabled_never_opens():
|
||||||
|
"""When disabled, failures never open the circuit and is_open() stays False."""
|
||||||
|
from litellm.caching.redis_cache import RedisCircuitBreaker
|
||||||
|
|
||||||
|
cb = RedisCircuitBreaker(failure_threshold=3, recovery_timeout=60, enabled=False)
|
||||||
|
|
||||||
|
for _ in range(100):
|
||||||
|
cb.record_failure()
|
||||||
|
|
||||||
|
assert cb._state == "closed"
|
||||||
|
assert cb.is_open() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_circuit_breaker_disabled_record_success_leaves_state_untouched():
|
||||||
|
"""
|
||||||
|
A disabled breaker must not mutate state in any state-machine method. Force
|
||||||
|
a non-default (OPEN) state and assert record_success() returns without
|
||||||
|
resetting it — the same enabled-guard contract as is_open/record_failure.
|
||||||
|
"""
|
||||||
|
from litellm.caching.redis_cache import RedisCircuitBreaker
|
||||||
|
|
||||||
|
cb = RedisCircuitBreaker(failure_threshold=3, recovery_timeout=60, enabled=False)
|
||||||
|
cb._state = "open"
|
||||||
|
cb._failure_count = 3
|
||||||
|
|
||||||
|
cb.record_success()
|
||||||
|
|
||||||
|
assert cb._state == "open"
|
||||||
|
assert cb._failure_count == 3
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_circuit_breaker_disabled_guard_always_calls_method():
|
||||||
|
"""A disabled breaker lets every guarded call through, even after failures."""
|
||||||
|
from litellm.caching.redis_cache import (
|
||||||
|
RedisCircuitBreaker,
|
||||||
|
_redis_circuit_breaker_guard,
|
||||||
|
)
|
||||||
|
|
||||||
|
class FakeRedis:
|
||||||
|
def __init__(self):
|
||||||
|
self._circuit_breaker = RedisCircuitBreaker(
|
||||||
|
failure_threshold=1, recovery_timeout=60, enabled=False
|
||||||
|
)
|
||||||
|
self.call_count = 0
|
||||||
|
|
||||||
|
@_redis_circuit_breaker_guard
|
||||||
|
async def boom(self):
|
||||||
|
self.call_count += 1
|
||||||
|
raise RuntimeError("redis down")
|
||||||
|
|
||||||
|
fr = FakeRedis()
|
||||||
|
for _ in range(5):
|
||||||
|
with pytest.raises(RuntimeError, match="redis down"):
|
||||||
|
await fr.boom()
|
||||||
|
|
||||||
|
# Every call reached the method body; the breaker never short-circuited.
|
||||||
|
assert fr.call_count == 5
|
||||||
|
assert fr._circuit_breaker.is_open() is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_async_increment_cache_returns_none_when_no_in_memory_cache_and_redis_fails():
|
async def test_async_increment_cache_returns_none_when_no_in_memory_cache_and_redis_fails():
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -3187,3 +3187,111 @@ def test_get_key_mcp_rpm_limit_precedence():
|
|||||||
none_set = UserAPIKeyAuth(api_key=hash_token("sk-mcp-key"))
|
none_set = UserAPIKeyAuth(api_key=hash_token("sk-mcp-key"))
|
||||||
assert get_key_mcp_rpm_limit(none_set) is None
|
assert get_key_mcp_rpm_limit(none_set) is None
|
||||||
assert get_team_mcp_rpm_limit(none_set) is None
|
assert get_team_mcp_rpm_limit(none_set) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_tpm_reservation_enabled_by_default(monkeypatch):
|
||||||
|
"""Upfront TPM reservation is on unless explicitly disabled via env."""
|
||||||
|
monkeypatch.delenv("LITELLM_TPM_TOKEN_RESERVATION_ENABLED", raising=False)
|
||||||
|
handler = _PROXY_MaxParallelRequestsHandler(
|
||||||
|
internal_usage_cache=InternalUsageCache(DualCache())
|
||||||
|
)
|
||||||
|
assert handler.tpm_reservation_enabled is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("value", ["false", "False", "FALSE"])
|
||||||
|
def test_tpm_reservation_disabled_via_env(monkeypatch, value):
|
||||||
|
monkeypatch.setenv("LITELLM_TPM_TOKEN_RESERVATION_ENABLED", value)
|
||||||
|
handler = _PROXY_MaxParallelRequestsHandler(
|
||||||
|
internal_usage_cache=InternalUsageCache(DualCache())
|
||||||
|
)
|
||||||
|
assert handler.tpm_reservation_enabled is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pre_call_hook_reserves_tpm_when_enabled(monkeypatch):
|
||||||
|
"""
|
||||||
|
With reservation enabled, the pre-call hook reserves the estimated token
|
||||||
|
budget upfront and tells should_rate_limit to skip the :tokens counter so
|
||||||
|
only the reservation path owns it.
|
||||||
|
"""
|
||||||
|
monkeypatch.delenv("LITELLM_TPM_TOKEN_RESERVATION_ENABLED", raising=False)
|
||||||
|
handler = _PROXY_MaxParallelRequestsHandler(
|
||||||
|
internal_usage_cache=InternalUsageCache(DualCache())
|
||||||
|
)
|
||||||
|
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(api_key=hash_token("sk-tpm"), tpm_limit=10_000)
|
||||||
|
|
||||||
|
should_rate_limit_calls: List[Dict[str, Any]] = []
|
||||||
|
original_should_rate_limit = handler.should_rate_limit
|
||||||
|
|
||||||
|
async def spy_should_rate_limit(*args, **kwargs):
|
||||||
|
should_rate_limit_calls.append(kwargs)
|
||||||
|
return await original_should_rate_limit(*args, **kwargs)
|
||||||
|
|
||||||
|
reserve_calls: List[int] = []
|
||||||
|
original_reserve = handler.reserve_tpm_tokens
|
||||||
|
|
||||||
|
async def spy_reserve(*args, **kwargs):
|
||||||
|
reserve_calls.append(kwargs.get("estimated_tokens"))
|
||||||
|
return await original_reserve(*args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(handler, "should_rate_limit", spy_should_rate_limit)
|
||||||
|
monkeypatch.setattr(handler, "reserve_tpm_tokens", spy_reserve)
|
||||||
|
|
||||||
|
await handler.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=handler.internal_usage_cache.dual_cache,
|
||||||
|
data={"model": "gpt-4", "messages": [{"role": "user", "content": "hi"}]},
|
||||||
|
call_type="completion",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(reserve_calls) == 1, "reservation must run when enabled"
|
||||||
|
assert should_rate_limit_calls[0]["skip_tpm_check"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pre_call_hook_skips_reservation_when_disabled(monkeypatch):
|
||||||
|
"""
|
||||||
|
With reservation disabled, the pre-call hook never calls reserve_tpm_tokens
|
||||||
|
and enforces TPM directly in should_rate_limit (skip_tpm_check=False), the
|
||||||
|
pre-v1.82 post-call accounting behavior.
|
||||||
|
"""
|
||||||
|
monkeypatch.setenv("LITELLM_TPM_TOKEN_RESERVATION_ENABLED", "false")
|
||||||
|
handler = _PROXY_MaxParallelRequestsHandler(
|
||||||
|
internal_usage_cache=InternalUsageCache(DualCache())
|
||||||
|
)
|
||||||
|
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(api_key=hash_token("sk-tpm"), tpm_limit=10_000)
|
||||||
|
|
||||||
|
should_rate_limit_calls: List[Dict[str, Any]] = []
|
||||||
|
original_should_rate_limit = handler.should_rate_limit
|
||||||
|
|
||||||
|
async def spy_should_rate_limit(*args, **kwargs):
|
||||||
|
should_rate_limit_calls.append(kwargs)
|
||||||
|
return await original_should_rate_limit(*args, **kwargs)
|
||||||
|
|
||||||
|
reserve_calls: List[Any] = []
|
||||||
|
|
||||||
|
async def spy_reserve(*args, **kwargs):
|
||||||
|
reserve_calls.append(kwargs)
|
||||||
|
raise AssertionError("reserve_tpm_tokens must not run when disabled")
|
||||||
|
|
||||||
|
monkeypatch.setattr(handler, "should_rate_limit", spy_should_rate_limit)
|
||||||
|
monkeypatch.setattr(handler, "reserve_tpm_tokens", spy_reserve)
|
||||||
|
|
||||||
|
data = {"model": "gpt-4", "messages": [{"role": "user", "content": "hi"}]}
|
||||||
|
await handler.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=handler.internal_usage_cache.dual_cache,
|
||||||
|
data=data,
|
||||||
|
call_type="completion",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert reserve_calls == [], "reservation must be skipped when disabled"
|
||||||
|
assert should_rate_limit_calls[0]["skip_tpm_check"] is False
|
||||||
|
# No reservation stash leaks into the request metadata.
|
||||||
|
from litellm.proxy.hooks.parallel_request_limiter_v3 import (
|
||||||
|
TPM_RESERVED_TOKENS_KEY,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert TPM_RESERVED_TOKENS_KEY not in (data.get("metadata") or {})
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user