From 80cf50dedbcd55b3ddd7809b47f6df63b5a90d57 Mon Sep 17 00:00:00 2001 From: michelligabriele Date: Sun, 31 May 2026 04:36:04 +0200 Subject: [PATCH] fix(v3 limiter): cap no-max_tokens TPM floor at smallest configured limit (#28805) --- .../hooks/parallel_request_limiter_v3.py | 72 ++++++- .../proxy/hooks/test_tpm_concurrent.py | 197 ++++++++++++++++++ 2 files changed, 262 insertions(+), 7 deletions(-) diff --git a/litellm/proxy/hooks/parallel_request_limiter_v3.py b/litellm/proxy/hooks/parallel_request_limiter_v3.py index 283a3d8d10..d03ad70562 100644 --- a/litellm/proxy/hooks/parallel_request_limiter_v3.py +++ b/litellm/proxy/hooks/parallel_request_limiter_v3.py @@ -207,6 +207,11 @@ REDIS_NODE_HASHTAG_NAME = "all_keys" # *some* output budget; these define that fallback estimate. DEFAULT_MAX_TOKENS_ESTIMATE = 4096 DEFAULT_CHARS_PER_TOKEN = 4 +# Fraction of the available output budget reserved as the upfront floor when +# the request omits max_tokens. Applied to both DEFAULT_MAX_TOKENS_ESTIMATE +# (baseline floor) and to the smallest configured TPM limit (capped floor for +# small per-tenant TPM caps). +_TPM_FLOOR_FRACTION = 4 # Stash for the reserved-token count on the request data dict so success/ # failure callbacks can reconcile against the upfront reservation. TPM_RESERVED_TOKENS_KEY = "_litellm_tpm_reserved_tokens" @@ -340,10 +345,26 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger): """Return the current time for rate limiting calculations.""" return self._time_provider() + @staticmethod + def _no_max_tokens_output_floor( + min_configured_tpm_limit: Optional[int], + ) -> int: + """Output-budget floor used when the request omits max_tokens. + + Capped at a fraction of the smallest configured TPM limit so a small + per-tenant cap can't be tripped by the floor alone. Returns the + baseline floor when no limit is provided. + """ + baseline = DEFAULT_MAX_TOKENS_ESTIMATE // _TPM_FLOOR_FRACTION + if min_configured_tpm_limit is None: + return baseline + return min(baseline, max(1, min_configured_tpm_limit // _TPM_FLOOR_FRACTION)) + def _estimate_tokens_for_request( self, data: dict, model: Optional[str] = None, + min_configured_tpm_limit: Optional[int] = None, ) -> int: """ Estimate total tokens this request will consume so we can reserve them @@ -351,6 +372,12 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger): estimated = input_tokens + max_tokens. Supports chat (messages), completions (prompt), and embeddings (input). + + ``min_configured_tpm_limit`` is the smallest ``tokens_per_unit`` among + the TPM-bearing descriptors this request will be charged against. When + provided, the no-``max_tokens`` output-budget floor is capped at a + fraction of that limit so small TPM caps remain usable. Omit to + preserve the unconstrained floor. """ messages = data.get("messages") prompt = data.get("prompt") @@ -394,11 +421,14 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger): case _: # No max_tokens specified — reserve at least the input size with a # conservative floor so a stream of small concurrent requests can't - # collectively bypass the limit. - max_tokens_estimate = max( - estimated_input_tokens, - DEFAULT_MAX_TOKENS_ESTIMATE // 4, + # collectively bypass the limit. Cap the floor by a fraction of + # the smallest TPM limit this request will be charged against, + # so a small per-tenant TPM cap can't be tripped by the floor + # alone. + output_floor = self._no_max_tokens_output_floor( + min_configured_tpm_limit ) + max_tokens_estimate = max(estimated_input_tokens, output_floor) total_estimated = estimated_input_tokens + max_tokens_estimate @@ -2009,12 +2039,39 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger): # in-memory check otherwise — single-worker protection still holds # even without Redis. # ---------------------------------------------------------------- - has_tpm_limits = any( - (d.get("rate_limit") or {}).get("tokens_per_unit") is not None + configured_tpm_limits = [ + int(v) for d in descriptors - ) + for v in [(d.get("rate_limit") or {}).get("tokens_per_unit")] + if v is not None + ] + has_tpm_limits = bool(configured_tpm_limits) if has_tpm_limits: + min_configured_tpm_limit = min(configured_tpm_limits) + + # When the configured TPM cap is small enough to constrain the + # no-max_tokens floor, also hard-cap the model output via + # data["max_tokens"] so concurrent unbounded generations can't + # spend past the limit before post-call reconciliation runs. + # Skip when the request already sets max_tokens or has no + # generation budget at all (embeddings). + capped_floor = self._no_max_tokens_output_floor( + min_configured_tpm_limit + ) + baseline_floor = DEFAULT_MAX_TOKENS_ESTIMATE // _TPM_FLOOR_FRACTION + has_explicit_max_tokens = ( + data.get("max_tokens") is not None + or data.get("max_completion_tokens") is not None + ) + is_embedding = data.get("input") is not None + if ( + capped_floor < baseline_floor + and not has_explicit_max_tokens + and not is_embedding + ): + data["max_tokens"] = capped_floor + # Floor at 1 token so contentless requests (/responses, # tool-call continuations, empty messages) still flow # through the atomic counter and get backpressure when at @@ -2026,6 +2083,7 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger): self._estimate_tokens_for_request( data=data, model=requested_model, + min_configured_tpm_limit=min_configured_tpm_limit, ), 1, ) diff --git a/tests/test_litellm/proxy/hooks/test_tpm_concurrent.py b/tests/test_litellm/proxy/hooks/test_tpm_concurrent.py index e294d1471d..b02f6c1516 100644 --- a/tests/test_litellm/proxy/hooks/test_tpm_concurrent.py +++ b/tests/test_litellm/proxy/hooks/test_tpm_concurrent.py @@ -995,5 +995,202 @@ async def test_token_rate_limit_headers_present_in_stored_response(rate_limiter) assert api_key_tokens["limit_remaining"] >= 0 +@pytest.mark.asyncio +async def test_estimate_tokens_floor_caps_at_smallest_configured_tpm(rate_limiter): + """ + Regression: with a small configured TPM cap and no max_tokens, the + output-budget floor must be capped at a fraction of that limit so the + reservation alone can't trip the limit. + """ + handler, _cache = rate_limiter + + estimate = handler._estimate_tokens_for_request( + data={"messages": [{"role": "user", "content": "hello"}]}, + min_configured_tpm_limit=1000, + ) + # input ~= 5//4 = 1 token; output floor capped at 1000//4 = 250; + # total ~= 251 (well under 1000). + assert ( + estimate <= 1000 // 2 + ), f"With TPM=1000, reservation must stay well under the limit; got {estimate}" + assert estimate >= 1, "Estimate must be at least the call-site floor of 1" + + +@pytest.mark.asyncio +async def test_estimate_tokens_floor_unchanged_for_large_tpm(rate_limiter): + """ + Large TPM budgets must keep the 1024-token floor so a stream of small + concurrent requests can't collectively bypass the limit. + """ + handler, _cache = rate_limiter + + estimate = handler._estimate_tokens_for_request( + data={"messages": [{"role": "user", "content": "hello"}]}, + min_configured_tpm_limit=100_000, + ) + # input ~= 1; output floor = min(1024, 100_000//4=25_000) = 1024; + # total ~= 1025. + assert estimate == 1 + 1024 + + +@pytest.mark.asyncio +async def test_estimate_tokens_floor_unchanged_when_kwarg_omitted(rate_limiter): + """ + Callers that don't pass min_configured_tpm_limit (legacy path, tests that + stub the estimator) must observe the pre-fix floor. + """ + handler, _cache = rate_limiter + + estimate = handler._estimate_tokens_for_request( + data={"messages": [{"role": "user", "content": "hello"}]}, + ) + assert estimate == 1 + 1024 + + +@pytest.mark.asyncio +async def test_small_tpm_cap_admits_no_max_tokens_request(rate_limiter): + """ + Regression (end-to-end at the hook level): a project-level model_tpm_limit + of 1000 with a tiny no-max_tokens request must not 429 on the first call. + Pre-fix the 1024-token floor tripped OVER_LIMIT against the 1000-token cap + on every request. + """ + handler, cache = rate_limiter + + api_key = hash_token("sk-small-tpm") + user_api_key_dict = UserAPIKeyAuth( + api_key=api_key, + project_id="proj-small-tpm", + project_metadata={ + "model_tpm_limit": {"gpt-3.5-turbo": 1000}, + "model_rpm_limit": {"gpt-3.5-turbo": 60}, + }, + ) + + data = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "hello"}], + } + + # Must not raise — pre-fix this was a 429. + await handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=cache, + data=data, + call_type="", + ) + + reserved = (data.get("metadata") or {}).get(TPM_RESERVED_TOKENS_KEY) + assert reserved is not None, "Reservation should have been stashed" + assert reserved <= 1000 // 2, ( + f"Capped floor must keep the reservation well under the 1000 TPM " + f"cap; got {reserved}" + ) + + +@pytest.mark.asyncio +async def test_small_tpm_cap_injects_matching_max_tokens(rate_limiter): + """ + When a small TPM cap forces the no-max_tokens floor below the baseline, + the hook must also write data['max_tokens'] = capped_floor so the actual + model output is bounded by the reservation. Without this cap, concurrent + no-max_tokens generations can spend past the TPM limit before post-call + reconciliation runs. + """ + handler, cache = rate_limiter + + user_api_key_dict = UserAPIKeyAuth( + api_key=hash_token("sk-small-tpm-cap"), + project_id="proj-small-tpm-cap", + project_metadata={ + "model_tpm_limit": {"gpt-3.5-turbo": 1000}, + }, + ) + + data: dict = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "hello"}], + } + + await handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=cache, + data=data, + call_type="", + ) + + assert data.get("max_tokens") == 1000 // 4, ( + f"Capped floor must be written to max_tokens to bound the actual " + f"model output; got {data.get('max_tokens')}" + ) + + +@pytest.mark.asyncio +async def test_large_tpm_cap_does_not_inject_max_tokens(rate_limiter): + """ + A TPM cap that doesn't constrain the floor must not silently inject + max_tokens — that would change behaviour for tenants who already have + plenty of budget. + """ + handler, cache = rate_limiter + + user_api_key_dict = UserAPIKeyAuth( + api_key=hash_token("sk-large-tpm-cap"), + project_id="proj-large-tpm-cap", + project_metadata={ + "model_tpm_limit": {"gpt-3.5-turbo": 100_000}, + }, + ) + + data: dict = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "hello"}], + } + + await handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=cache, + data=data, + call_type="", + ) + + assert "max_tokens" not in data, ( + f"Large TPM caps should leave max_tokens alone; got " + f"{data.get('max_tokens')}" + ) + + +@pytest.mark.asyncio +async def test_small_tpm_cap_preserves_explicit_max_tokens(rate_limiter): + """ + Explicit max_tokens from the caller must never be overwritten by the + bypass mitigation — the user already declared their budget. + """ + handler, cache = rate_limiter + + user_api_key_dict = UserAPIKeyAuth( + api_key=hash_token("sk-explicit-max-tokens"), + project_id="proj-explicit-max-tokens", + project_metadata={ + "model_tpm_limit": {"gpt-3.5-turbo": 1000}, + }, + ) + + data: dict = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "hello"}], + "max_tokens": 500, + } + + await handler.async_pre_call_hook( + user_api_key_dict=user_api_key_dict, + cache=cache, + data=data, + call_type="", + ) + + assert data["max_tokens"] == 500 + + if __name__ == "__main__": pytest.main([__file__, "-v", "-s"])