fix(v3 limiter): cap no-max_tokens TPM floor at smallest configured limit (#28805)

This commit is contained in:
michelligabriele 2026-05-31 04:36:04 +02:00 committed by GitHub
parent 8b16b61114
commit 80cf50dedb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 262 additions and 7 deletions

View File

@ -207,6 +207,11 @@ REDIS_NODE_HASHTAG_NAME = "all_keys"
# *some* output budget; these define that fallback estimate.
DEFAULT_MAX_TOKENS_ESTIMATE = 4096
DEFAULT_CHARS_PER_TOKEN = 4
# Divisor applied to the available output budget to obtain the upfront floor
# reserved when the request omits max_tokens (a value of 4 means the floor is
# one quarter of the budget). Applied to both DEFAULT_MAX_TOKENS_ESTIMATE
# (baseline floor) and to the smallest configured TPM limit (capped floor for
# small per-tenant TPM caps).
_TPM_FLOOR_FRACTION = 4
# Stash for the reserved-token count on the request data dict so success/
# failure callbacks can reconcile against the upfront reservation.
TPM_RESERVED_TOKENS_KEY = "_litellm_tpm_reserved_tokens"
@ -340,10 +345,26 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
"""Return the current time for rate limiting calculations."""
return self._time_provider()
@staticmethod
def _no_max_tokens_output_floor(
    min_configured_tpm_limit: Optional[int],
) -> int:
    """Compute the output-budget floor for requests that omit max_tokens.

    The unconstrained floor is ``DEFAULT_MAX_TOKENS_ESTIMATE`` integer-divided
    by ``_TPM_FLOOR_FRACTION``. When a TPM limit is supplied, the floor is
    additionally capped at that limit integer-divided by the same divisor
    (never below 1), so the floor by itself cannot exceed a small cap.

    Args:
        min_configured_tpm_limit: Smallest TPM limit the request is charged
            against, or ``None`` when no limit is known.

    Returns:
        The unconstrained floor when no limit is given; otherwise the smaller
        of the unconstrained floor and the limit-derived share.
    """
    unconstrained_floor = DEFAULT_MAX_TOKENS_ESTIMATE // _TPM_FLOOR_FRACTION
    if min_configured_tpm_limit is not None:
        limit_share = min_configured_tpm_limit // _TPM_FLOOR_FRACTION
        # Tiny (or zero) limits still reserve at least one token.
        if limit_share < 1:
            limit_share = 1
        if limit_share < unconstrained_floor:
            return limit_share
    return unconstrained_floor
def _estimate_tokens_for_request(
self,
data: dict,
model: Optional[str] = None,
min_configured_tpm_limit: Optional[int] = None,
) -> int:
"""
Estimate total tokens this request will consume so we can reserve them
@ -351,6 +372,12 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
estimated = input_tokens + max_tokens.
Supports chat (messages), completions (prompt), and embeddings (input).
``min_configured_tpm_limit`` is the smallest ``tokens_per_unit`` among
the TPM-bearing descriptors this request will be charged against. When
provided, the no-``max_tokens`` output-budget floor is capped at a
fraction of that limit so small TPM caps remain usable. Omit to
preserve the unconstrained floor.
"""
messages = data.get("messages")
prompt = data.get("prompt")
@ -394,11 +421,14 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
case _:
# No max_tokens specified — reserve at least the input size with a
# conservative floor so a stream of small concurrent requests can't
# collectively bypass the limit.
max_tokens_estimate = max(
estimated_input_tokens,
DEFAULT_MAX_TOKENS_ESTIMATE // 4,
# collectively bypass the limit. Cap the floor by a fraction of
# the smallest TPM limit this request will be charged against,
# so a small per-tenant TPM cap can't be tripped by the floor
# alone.
output_floor = self._no_max_tokens_output_floor(
min_configured_tpm_limit
)
max_tokens_estimate = max(estimated_input_tokens, output_floor)
total_estimated = estimated_input_tokens + max_tokens_estimate
@ -2009,12 +2039,39 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
# in-memory check otherwise — single-worker protection still holds
# even without Redis.
# ----------------------------------------------------------------
has_tpm_limits = any(
(d.get("rate_limit") or {}).get("tokens_per_unit") is not None
configured_tpm_limits = [
int(v)
for d in descriptors
)
for v in [(d.get("rate_limit") or {}).get("tokens_per_unit")]
if v is not None
]
has_tpm_limits = bool(configured_tpm_limits)
if has_tpm_limits:
min_configured_tpm_limit = min(configured_tpm_limits)
# When the configured TPM cap is small enough to constrain the
# no-max_tokens floor, also hard-cap the model output via
# data["max_tokens"] so concurrent unbounded generations can't
# spend past the limit before post-call reconciliation runs.
# Skip when the request already sets max_tokens or has no
# generation budget at all (embeddings).
capped_floor = self._no_max_tokens_output_floor(
min_configured_tpm_limit
)
baseline_floor = DEFAULT_MAX_TOKENS_ESTIMATE // _TPM_FLOOR_FRACTION
has_explicit_max_tokens = (
data.get("max_tokens") is not None
or data.get("max_completion_tokens") is not None
)
is_embedding = data.get("input") is not None
if (
capped_floor < baseline_floor
and not has_explicit_max_tokens
and not is_embedding
):
data["max_tokens"] = capped_floor
# Floor at 1 token so contentless requests (/responses,
# tool-call continuations, empty messages) still flow
# through the atomic counter and get backpressure when at
@ -2026,6 +2083,7 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
self._estimate_tokens_for_request(
data=data,
model=requested_model,
min_configured_tpm_limit=min_configured_tpm_limit,
),
1,
)

View File

@ -995,5 +995,202 @@ async def test_token_rate_limit_headers_present_in_stored_response(rate_limiter)
assert api_key_tokens["limit_remaining"] >= 0
@pytest.mark.asyncio
async def test_estimate_tokens_floor_caps_at_smallest_configured_tpm(rate_limiter):
    """
    Regression: when the smallest configured TPM cap is small and the request
    carries no max_tokens, the estimator's output floor has to shrink to a
    fraction of that cap, so the upfront reservation cannot trip the limit by
    itself.
    """
    limiter, _ = rate_limiter
    tpm_cap = 1000
    request_body = {"messages": [{"role": "user", "content": "hello"}]}

    reserved = limiter._estimate_tokens_for_request(
        data=request_body,
        min_configured_tpm_limit=tpm_cap,
    )

    # Expected shape: roughly 1 input token ("hello" is 5 chars, 5 // 4 = 1)
    # plus an output floor capped at 1000 // 4 = 250, i.e. about 251 total.
    assert reserved <= tpm_cap // 2, (
        f"With TPM=1000, reservation must stay well under the limit; got {reserved}"
    )
    assert reserved >= 1, "Estimate must be at least the call-site floor of 1"
@pytest.mark.asyncio
async def test_estimate_tokens_floor_unchanged_for_large_tpm(rate_limiter):
    """
    A generous TPM budget must leave the 1024-token floor intact; otherwise
    many small concurrent requests could collectively slip past the limit.
    """
    limiter, _ = rate_limiter
    request_body = {"messages": [{"role": "user", "content": "hello"}]}

    reserved = limiter._estimate_tokens_for_request(
        data=request_body,
        min_configured_tpm_limit=100_000,
    )

    # Input is about 1 token; the floor is min(1024, 100_000 // 4) = 1024.
    expected_input_tokens = 1
    expected_output_floor = 1024
    assert reserved == expected_input_tokens + expected_output_floor
@pytest.mark.asyncio
async def test_estimate_tokens_floor_unchanged_when_kwarg_omitted(rate_limiter):
    """
    When min_configured_tpm_limit is not passed at all (legacy callers, tests
    stubbing the estimator), the estimate must match the pre-fix floor.
    """
    limiter, _ = rate_limiter
    request_body = {"messages": [{"role": "user", "content": "hello"}]}

    reserved = limiter._estimate_tokens_for_request(data=request_body)

    # About 1 input token plus the uncapped 1024-token output floor.
    expected_input_tokens = 1
    expected_output_floor = 1024
    assert reserved == expected_input_tokens + expected_output_floor
@pytest.mark.asyncio
async def test_small_tpm_cap_admits_no_max_tokens_request(rate_limiter):
    """
    Hook-level regression: with a project model_tpm_limit of 1000, a tiny
    request without max_tokens must be admitted on its first call. Before the
    fix, the 1024-token floor exceeded the 1000-token cap, so every such
    request was rejected with a 429.
    """
    limiter, shared_cache = rate_limiter

    payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hello"}],
    }
    hashed_key = hash_token("sk-small-tpm")
    auth = UserAPIKeyAuth(
        api_key=hashed_key,
        project_id="proj-small-tpm",
        project_metadata={
            "model_tpm_limit": {"gpt-3.5-turbo": 1000},
            "model_rpm_limit": {"gpt-3.5-turbo": 60},
        },
    )

    # The call itself is the first assertion: it must complete without raising.
    await limiter.async_pre_call_hook(
        user_api_key_dict=auth,
        cache=shared_cache,
        data=payload,
        call_type="",
    )

    # The hook stashes the reserved-token count under the request metadata.
    request_metadata = payload.get("metadata") or {}
    reserved = request_metadata.get(TPM_RESERVED_TOKENS_KEY)
    assert reserved is not None, "Reservation should have been stashed"
    assert reserved <= 1000 // 2, (
        f"Capped floor must keep the reservation well under the 1000 TPM "
        f"cap; got {reserved}"
    )
@pytest.mark.asyncio
async def test_small_tpm_cap_injects_matching_max_tokens(rate_limiter):
    """
    If a small TPM cap pushes the no-max_tokens floor below the baseline, the
    hook must also set data['max_tokens'] to that capped floor, so the real
    model output is bounded by what was reserved. Otherwise concurrent
    generations without max_tokens could overspend the TPM limit before
    post-call reconciliation runs.
    """
    limiter, shared_cache = rate_limiter

    payload: dict = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hello"}],
    }
    auth = UserAPIKeyAuth(
        api_key=hash_token("sk-small-tpm-cap"),
        project_id="proj-small-tpm-cap",
        project_metadata={
            "model_tpm_limit": {"gpt-3.5-turbo": 1000},
        },
    )

    await limiter.async_pre_call_hook(
        user_api_key_dict=auth,
        cache=shared_cache,
        data=payload,
        call_type="",
    )

    # Capped floor for a 1000-token limit is 1000 // 4.
    injected = payload.get("max_tokens")
    assert injected == 1000 // 4, (
        f"Capped floor must be written to max_tokens to bound the actual "
        f"model output; got {injected}"
    )
@pytest.mark.asyncio
async def test_large_tpm_cap_does_not_inject_max_tokens(rate_limiter):
    """
    When the TPM cap is too large to constrain the floor, the hook must not
    quietly add max_tokens; doing so would change behaviour for tenants that
    already have ample budget.
    """
    limiter, shared_cache = rate_limiter

    payload: dict = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hello"}],
    }
    auth = UserAPIKeyAuth(
        api_key=hash_token("sk-large-tpm-cap"),
        project_id="proj-large-tpm-cap",
        project_metadata={
            "model_tpm_limit": {"gpt-3.5-turbo": 100_000},
        },
    )

    await limiter.async_pre_call_hook(
        user_api_key_dict=auth,
        cache=shared_cache,
        data=payload,
        call_type="",
    )

    # The key must be absent altogether, not merely set to None.
    assert "max_tokens" not in payload, (
        f"Large TPM caps should leave max_tokens alone; got "
        f"{payload.get('max_tokens')}"
    )
@pytest.mark.asyncio
async def test_small_tpm_cap_preserves_explicit_max_tokens(rate_limiter):
    """
    A max_tokens value supplied by the caller must never be overwritten by
    the bypass mitigation: the user has already declared their budget.
    """
    limiter, shared_cache = rate_limiter

    caller_max_tokens = 500
    payload: dict = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hello"}],
        "max_tokens": caller_max_tokens,
    }
    auth = UserAPIKeyAuth(
        api_key=hash_token("sk-explicit-max-tokens"),
        project_id="proj-explicit-max-tokens",
        project_metadata={
            "model_tpm_limit": {"gpt-3.5-turbo": 1000},
        },
    )

    await limiter.async_pre_call_hook(
        user_api_key_dict=auth,
        cache=shared_cache,
        data=payload,
        call_type="",
    )

    # Same small cap as the injection test, but the explicit value survives.
    assert payload["max_tokens"] == caller_max_tokens
if __name__ == "__main__":
    # Running this file directly executes its tests verbosely (-v) with
    # output capturing disabled (-s).
    pytest.main(args=[__file__, "-v", "-s"])