fix(v3 limiter): cap no-max_tokens TPM floor at smallest configured limit (#28805)
This commit is contained in:
parent
8b16b61114
commit
80cf50dedb
@ -207,6 +207,11 @@ REDIS_NODE_HASHTAG_NAME = "all_keys"
|
||||
# *some* output budget; these define that fallback estimate.
|
||||
DEFAULT_MAX_TOKENS_ESTIMATE = 4096
|
||||
DEFAULT_CHARS_PER_TOKEN = 4
|
||||
# Divisor N: 1/N of the available output budget is reserved as the upfront floor when
|
||||
# the request omits max_tokens. Applied to both DEFAULT_MAX_TOKENS_ESTIMATE
|
||||
# (baseline floor) and to the smallest configured TPM limit (capped floor for
|
||||
# small per-tenant TPM caps).
|
||||
_TPM_FLOOR_FRACTION = 4
|
||||
# Stash for the reserved-token count on the request data dict so success/
|
||||
# failure callbacks can reconcile against the upfront reservation.
|
||||
TPM_RESERVED_TOKENS_KEY = "_litellm_tpm_reserved_tokens"
|
||||
@ -340,10 +345,26 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
||||
"""Return the current time for rate limiting calculations."""
|
||||
return self._time_provider()
|
||||
|
||||
@staticmethod
def _no_max_tokens_output_floor(
    min_configured_tpm_limit: Optional[int],
) -> int:
    """Compute the output-token floor for a request that omits max_tokens.

    The default floor is ``DEFAULT_MAX_TOKENS_ESTIMATE // _TPM_FLOOR_FRACTION``.
    When a smallest configured TPM limit is supplied, the floor is lowered
    to ``min_configured_tpm_limit // _TPM_FLOOR_FRACTION`` (never below one
    token) if that share is smaller than the default, so the floor by
    itself cannot exceed a small per-tenant cap.

    Args:
        min_configured_tpm_limit: Smallest TPM limit the request is charged
            against, or ``None`` when no limit is known.

    Returns:
        The floor, in tokens, to use as the output-budget estimate.
    """
    default_floor = DEFAULT_MAX_TOKENS_ESTIMATE // _TPM_FLOOR_FRACTION
    if min_configured_tpm_limit is not None:
        limit_share = min_configured_tpm_limit // _TPM_FLOOR_FRACTION
        # Keep at least one token so tiny (or zero) limits still reserve
        # something.
        if limit_share < 1:
            limit_share = 1
        if limit_share < default_floor:
            return limit_share
    return default_floor
|
||||
|
||||
def _estimate_tokens_for_request(
|
||||
self,
|
||||
data: dict,
|
||||
model: Optional[str] = None,
|
||||
min_configured_tpm_limit: Optional[int] = None,
|
||||
) -> int:
|
||||
"""
|
||||
Estimate total tokens this request will consume so we can reserve them
|
||||
@ -351,6 +372,12 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
||||
estimated = input_tokens + max_tokens.
|
||||
|
||||
Supports chat (messages), completions (prompt), and embeddings (input).
|
||||
|
||||
``min_configured_tpm_limit`` is the smallest ``tokens_per_unit`` among
|
||||
the TPM-bearing descriptors this request will be charged against. When
|
||||
provided, the no-``max_tokens`` output-budget floor is capped at a
|
||||
fraction of that limit so small TPM caps remain usable. Omit to
|
||||
preserve the unconstrained floor.
|
||||
"""
|
||||
messages = data.get("messages")
|
||||
prompt = data.get("prompt")
|
||||
@ -394,11 +421,14 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
||||
case _:
|
||||
# No max_tokens specified — reserve at least the input size with a
|
||||
# conservative floor so a stream of small concurrent requests can't
|
||||
# collectively bypass the limit.
|
||||
max_tokens_estimate = max(
|
||||
estimated_input_tokens,
|
||||
DEFAULT_MAX_TOKENS_ESTIMATE // 4,
|
||||
# collectively bypass the limit. Cap the floor by a fraction of
|
||||
# the smallest TPM limit this request will be charged against,
|
||||
# so a small per-tenant TPM cap can't be tripped by the floor
|
||||
# alone.
|
||||
output_floor = self._no_max_tokens_output_floor(
|
||||
min_configured_tpm_limit
|
||||
)
|
||||
max_tokens_estimate = max(estimated_input_tokens, output_floor)
|
||||
|
||||
total_estimated = estimated_input_tokens + max_tokens_estimate
|
||||
|
||||
@ -2009,12 +2039,39 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
||||
# in-memory check otherwise — single-worker protection still holds
|
||||
# even without Redis.
|
||||
# ----------------------------------------------------------------
|
||||
has_tpm_limits = any(
|
||||
(d.get("rate_limit") or {}).get("tokens_per_unit") is not None
|
||||
configured_tpm_limits = [
|
||||
int(v)
|
||||
for d in descriptors
|
||||
)
|
||||
for v in [(d.get("rate_limit") or {}).get("tokens_per_unit")]
|
||||
if v is not None
|
||||
]
|
||||
has_tpm_limits = bool(configured_tpm_limits)
|
||||
|
||||
if has_tpm_limits:
|
||||
min_configured_tpm_limit = min(configured_tpm_limits)
|
||||
|
||||
# When the configured TPM cap is small enough to constrain the
|
||||
# no-max_tokens floor, also hard-cap the model output via
|
||||
# data["max_tokens"] so concurrent unbounded generations can't
|
||||
# spend past the limit before post-call reconciliation runs.
|
||||
# Skip when the request already sets max_tokens or has no
|
||||
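# generation budget at all (embeddings, detected via a non-None "input").
# NOTE(review): other endpoints (e.g. /responses) may also send "input";
# confirm they are meant to skip this cap.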
|
||||
capped_floor = self._no_max_tokens_output_floor(
|
||||
min_configured_tpm_limit
|
||||
)
|
||||
baseline_floor = DEFAULT_MAX_TOKENS_ESTIMATE // _TPM_FLOOR_FRACTION
|
||||
has_explicit_max_tokens = (
|
||||
data.get("max_tokens") is not None
|
||||
or data.get("max_completion_tokens") is not None
|
||||
)
|
||||
is_embedding = data.get("input") is not None
|
||||
if (
|
||||
capped_floor < baseline_floor
|
||||
and not has_explicit_max_tokens
|
||||
and not is_embedding
|
||||
):
|
||||
data["max_tokens"] = capped_floor
|
||||
|
||||
# Floor at 1 token so contentless requests (/responses,
|
||||
# tool-call continuations, empty messages) still flow
|
||||
# through the atomic counter and get backpressure when at
|
||||
@ -2026,6 +2083,7 @@ class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
||||
self._estimate_tokens_for_request(
|
||||
data=data,
|
||||
model=requested_model,
|
||||
min_configured_tpm_limit=min_configured_tpm_limit,
|
||||
),
|
||||
1,
|
||||
)
|
||||
|
||||
@ -995,5 +995,202 @@ async def test_token_rate_limit_headers_present_in_stored_response(rate_limiter)
|
||||
assert api_key_tokens["limit_remaining"] >= 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_estimate_tokens_floor_caps_at_smallest_configured_tpm(rate_limiter):
    """
    Regression: with a small configured TPM cap and no max_tokens, the
    output-budget floor must be capped at a fraction of that limit so the
    reservation alone can't trip the limit.

    The expected value is pinned exactly (rather than bounded loosely) so a
    regression that silently shrinks or grows the capped floor is caught.
    """
    handler, _cache = rate_limiter
    tpm_limit = 1000

    estimate = handler._estimate_tokens_for_request(
        data={"messages": [{"role": "user", "content": "hello"}]},
        min_configured_tpm_limit=tpm_limit,
    )
    # input ~= 5//4 = 1 token (the same prompt is pinned at 1 input token by
    # the large-TPM and kwarg-omitted cases); output floor capped at
    # 1000//4 = 250; total = 251 (well under 1000).
    assert (
        estimate == 1 + tpm_limit // 4
    ), f"With TPM=1000, reservation must stay well under the limit; got {estimate}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_estimate_tokens_floor_unchanged_for_large_tpm(rate_limiter):
    """
    A generous TPM budget must not lower the 1024-token floor; otherwise many
    small concurrent requests could collectively slip past the limit.
    """
    limiter, _unused_cache = rate_limiter
    request_body = {"messages": [{"role": "user", "content": "hello"}]}

    reserved_tokens = limiter._estimate_tokens_for_request(
        data=request_body,
        min_configured_tpm_limit=100_000,
    )

    # input ~= 1; output floor = min(1024, 100_000 // 4 = 25_000) = 1024.
    expected_input_tokens = 1
    expected_output_floor = 1024
    assert reserved_tokens == expected_input_tokens + expected_output_floor
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_estimate_tokens_floor_unchanged_when_kwarg_omitted(rate_limiter):
    """
    When min_configured_tpm_limit is not passed (legacy callers, tests that
    stub the estimator), the estimator must keep the pre-fix floor.
    """
    limiter, _unused_cache = rate_limiter
    request_body = {"messages": [{"role": "user", "content": "hello"}]}

    reserved_tokens = limiter._estimate_tokens_for_request(data=request_body)

    # 1 input token plus the uncapped 1024-token output floor.
    expected_total = 1 + 1024
    assert reserved_tokens == expected_total
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_small_tpm_cap_admits_no_max_tokens_request(rate_limiter):
    """
    Hook-level regression: with a project model_tpm_limit of 1000, a tiny
    request that omits max_tokens must be admitted on its first call.
    Before the fix, the 1024-token floor exceeded the 1000-token cap and every
    such request was rejected as OVER_LIMIT.
    """
    handler, cache = rate_limiter

    project_limits = {
        "model_tpm_limit": {"gpt-3.5-turbo": 1000},
        "model_rpm_limit": {"gpt-3.5-turbo": 60},
    }
    auth = UserAPIKeyAuth(
        api_key=hash_token("sk-small-tpm"),
        project_id="proj-small-tpm",
        project_metadata=project_limits,
    )
    data = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hello"}],
    }

    # The call itself is the first check: pre-fix it raised a 429.
    await handler.async_pre_call_hook(
        user_api_key_dict=auth,
        cache=cache,
        data=data,
        call_type="",
    )

    request_metadata = data.get("metadata") or {}
    reserved = request_metadata.get(TPM_RESERVED_TOKENS_KEY)
    assert reserved is not None, "Reservation should have been stashed"
    assert reserved <= 1000 // 2, (
        f"Capped floor must keep the reservation well under the 1000 TPM "
        f"cap; got {reserved}"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_small_tpm_cap_injects_matching_max_tokens(rate_limiter):
    """
    If a small TPM cap pushes the no-max_tokens floor below the baseline, the
    hook must also set data['max_tokens'] to that capped floor. This bounds
    the real model output by the reservation; otherwise concurrent unbounded
    generations could overspend the TPM limit before post-call reconciliation.
    """
    handler, cache = rate_limiter

    auth = UserAPIKeyAuth(
        api_key=hash_token("sk-small-tpm-cap"),
        project_id="proj-small-tpm-cap",
        project_metadata={"model_tpm_limit": {"gpt-3.5-turbo": 1000}},
    )
    data: dict = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hello"}],
    }

    await handler.async_pre_call_hook(
        user_api_key_dict=auth,
        cache=cache,
        data=data,
        call_type="",
    )

    expected_cap = 1000 // 4
    assert data.get("max_tokens") == expected_cap, (
        f"Capped floor must be written to max_tokens to bound the actual "
        f"model output; got {data.get('max_tokens')}"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_large_tpm_cap_does_not_inject_max_tokens(rate_limiter):
    """
    When the TPM cap is too large to constrain the floor, the hook must leave
    max_tokens unset. Injecting it would silently change behaviour for
    tenants with ample budget.
    """
    handler, cache = rate_limiter

    auth = UserAPIKeyAuth(
        api_key=hash_token("sk-large-tpm-cap"),
        project_id="proj-large-tpm-cap",
        project_metadata={"model_tpm_limit": {"gpt-3.5-turbo": 100_000}},
    )
    data: dict = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hello"}],
    }

    await handler.async_pre_call_hook(
        user_api_key_dict=auth,
        cache=cache,
        data=data,
        call_type="",
    )

    assert "max_tokens" not in data, (
        f"Large TPM caps should leave max_tokens alone; got "
        f"{data.get('max_tokens')}"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_small_tpm_cap_preserves_explicit_max_tokens(rate_limiter):
    """
    A caller-supplied max_tokens is the caller's declared budget; the bypass
    mitigation must leave it untouched even under a small TPM cap.
    """
    handler, cache = rate_limiter

    auth = UserAPIKeyAuth(
        api_key=hash_token("sk-explicit-max-tokens"),
        project_id="proj-explicit-max-tokens",
        project_metadata={"model_tpm_limit": {"gpt-3.5-turbo": 1000}},
    )
    caller_max_tokens = 500
    data: dict = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hello"}],
        "max_tokens": caller_max_tokens,
    }

    await handler.async_pre_call_hook(
        user_api_key_dict=auth,
        cache=cache,
        data=data,
        call_type="",
    )

    assert data["max_tokens"] == caller_max_tokens
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Allow running this test module directly, passing "-v" and "-s" to pytest.
    pytest.main(args=[__file__, "-v", "-s"])
|
||||
|
||||
Loading…
Reference in New Issue
Block a user