2955 lines
117 KiB
Python
2955 lines
117 KiB
Python
"""
|
|
This is a rate limiter implementation based on a similar one by Envoy proxy.
|
|
|
|
This is currently in development and not yet ready for production.
|
|
"""
|
|
|
|
import asyncio
|
|
import binascii
|
|
import os
|
|
from datetime import datetime
|
|
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
List,
|
|
Literal,
|
|
Optional,
|
|
Set,
|
|
Tuple,
|
|
TypedDict,
|
|
Union,
|
|
cast,
|
|
)
|
|
|
|
from fastapi import HTTPException
|
|
|
|
from litellm import DualCache
|
|
from litellm._logging import verbose_proxy_logger
|
|
from litellm.constants import DYNAMIC_RATE_LIMIT_ERROR_THRESHOLD_PER_MINUTE
|
|
from litellm.integrations.custom_logger import CustomLogger
|
|
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
|
get_str_from_messages,
|
|
)
|
|
from litellm.proxy._types import UserAPIKeyAuth
|
|
from litellm.proxy.auth.auth_utils import get_model_rate_limit_from_metadata
|
|
from litellm.types.caching import RedisPipelineIncrementOperation
|
|
from litellm.types.llms.openai import BaseLiteLLMOpenAIResponseObject
|
|
from litellm.types.utils import ModelResponse, Usage
|
|
|
|
if TYPE_CHECKING:
|
|
from opentelemetry.trace import Span as _Span
|
|
|
|
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
|
|
from litellm.types.caching import RedisPipelineIncrementOperation
|
|
|
|
Span = Union[_Span, Any]
|
|
InternalUsageCache = _InternalUsageCache
|
|
else:
|
|
Span = Any
|
|
InternalUsageCache = Any
|
|
|
|
BATCH_RATE_LIMITER_SCRIPT = """
|
|
local results = {}
|
|
local now = tonumber(ARGV[1])
|
|
local window_size = tonumber(ARGV[2])
|
|
|
|
-- Process each window/counter pair
|
|
for i = 1, #KEYS, 2 do
|
|
local window_key = KEYS[i]
|
|
local counter_key = KEYS[i + 1]
|
|
local increment_value = 1
|
|
|
|
-- Check if window exists and is valid
|
|
local window_start = redis.call('GET', window_key)
|
|
if not window_start or (now - tonumber(window_start)) >= window_size then
|
|
-- Reset window and counter
|
|
redis.call('SET', window_key, tostring(now))
|
|
redis.call('SET', counter_key, increment_value)
|
|
redis.call('EXPIRE', window_key, window_size)
|
|
redis.call('EXPIRE', counter_key, window_size)
|
|
table.insert(results, tostring(now)) -- window_start
|
|
table.insert(results, increment_value) -- counter
|
|
else
|
|
local counter = redis.call('INCR', counter_key)
|
|
-- This happens when window_key exists but counter_key doesn't (e.g., tokens key
|
|
-- created after requests key when both share the same window_key)
|
|
local current_ttl = redis.call('TTL', counter_key)
|
|
if current_ttl == -1 then
|
|
redis.call('EXPIRE', counter_key, window_size)
|
|
end
|
|
table.insert(results, window_start) -- window_start
|
|
table.insert(results, counter) -- counter
|
|
end
|
|
end
|
|
|
|
return results
|
|
"""
|
|
|
|
CHECK_AND_INCREMENT_BY_N_SCRIPT = """
|
|
-- Atomic check-and-increment-by-N across one or more descriptors.
|
|
-- All-or-nothing: if any descriptor would exceed its limit, no counter is
|
|
-- modified.
|
|
--
|
|
-- Uses Redis server time (`redis.call('TIME')`) instead of a client-supplied
|
|
-- timestamp so that window resets are deterministic across replicas with
|
|
-- skewed wall-clocks. This prevents a clock-skew-induced reopening of the
|
|
-- TOCTOU window across multi-replica deployments.
|
|
--
|
|
-- KEYS layout: pairs of (window_key, counter_key), one pair per descriptor.
|
|
-- ARGV layout: per-descriptor 4-tuple, starting at ARGV[1]:
|
|
-- ARGV[(i-1)*4 + 1] = limit
|
|
-- ARGV[(i-1)*4 + 2] = increment
|
|
-- ARGV[(i-1)*4 + 3] = ttl_seconds (counter TTL when window resets)
|
|
-- ARGV[(i-1)*4 + 4] = window_size_seconds (sliding-window length)
|
|
--
|
|
-- Return on success: { 0, new_counter_1, new_counter_2, ... }
|
|
-- Return on over-limit: { 1, descriptor_index, current_counter, limit }
|
|
local time_reply = redis.call('TIME')
|
|
local now = tonumber(time_reply[1])
|
|
local descriptor_count = #KEYS / 2
|
|
|
|
-- Pass 1: read state, validate. Abort without writing if any over limit.
|
|
local descriptor_state = {}
|
|
for i = 1, descriptor_count do
|
|
local window_key = KEYS[(i - 1) * 2 + 1]
|
|
local counter_key = KEYS[(i - 1) * 2 + 2]
|
|
local arg_base = (i - 1) * 4 + 1
|
|
local limit = tonumber(ARGV[arg_base])
|
|
local increment = tonumber(ARGV[arg_base + 1])
|
|
local window_size = tonumber(ARGV[arg_base + 3])
|
|
|
|
local window_start = redis.call('GET', window_key)
|
|
local window_expired = (not window_start) or
|
|
((now - tonumber(window_start)) >= window_size)
|
|
|
|
local current_counter
|
|
if window_expired then
|
|
current_counter = 0
|
|
else
|
|
current_counter = tonumber(redis.call('GET', counter_key) or 0)
|
|
end
|
|
|
|
if current_counter + increment > limit then
|
|
return { 1, i, current_counter, limit }
|
|
end
|
|
|
|
descriptor_state[i] = { window_expired, current_counter }
|
|
end
|
|
|
|
-- Pass 2: all checks passed. Apply increments.
|
|
local results = { 0 }
|
|
for i = 1, descriptor_count do
|
|
local window_key = KEYS[(i - 1) * 2 + 1]
|
|
local counter_key = KEYS[(i - 1) * 2 + 2]
|
|
local arg_base = (i - 1) * 4 + 1
|
|
local increment = tonumber(ARGV[arg_base + 1])
|
|
local ttl = tonumber(ARGV[arg_base + 2])
|
|
local window_size = tonumber(ARGV[arg_base + 3])
|
|
|
|
local window_expired = descriptor_state[i][1]
|
|
|
|
if window_expired then
|
|
redis.call('SET', window_key, tostring(now))
|
|
redis.call('SET', counter_key, increment)
|
|
redis.call('EXPIRE', window_key, window_size)
|
|
if ttl > 0 then
|
|
redis.call('EXPIRE', counter_key, ttl)
|
|
end
|
|
table.insert(results, increment)
|
|
else
|
|
local new_counter = redis.call('INCRBY', counter_key, increment)
|
|
local current_ttl = redis.call('TTL', counter_key)
|
|
if current_ttl == -1 and ttl > 0 then
|
|
redis.call('EXPIRE', counter_key, ttl)
|
|
end
|
|
table.insert(results, new_counter)
|
|
end
|
|
end
|
|
|
|
return results
|
|
"""
|
|
|
|
TOKEN_INCREMENT_SCRIPT = """
|
|
local results = {}
|
|
|
|
-- Process each key/increment_value/ttl triplet
|
|
for i = 1, #KEYS do
|
|
local key = KEYS[i]
|
|
local increment_value = tonumber(ARGV[i * 2 - 1])
|
|
local ttl_seconds = tonumber(ARGV[i * 2])
|
|
|
|
-- Increment the value
|
|
local new_value = redis.call('INCRBYFLOAT', key, increment_value)
|
|
|
|
-- Handle TTL: only set expire if ttl_seconds > 0 and key has no current TTL
|
|
-- ttl_seconds can be 0 (no TTL) or positive (set TTL)
|
|
if ttl_seconds and ttl_seconds > 0 then
|
|
local current_ttl = redis.call('TTL', key)
|
|
if current_ttl == -1 then
|
|
redis.call('EXPIRE', key, ttl_seconds)
|
|
end
|
|
end
|
|
|
|
table.insert(results, new_value)
|
|
end
|
|
|
|
return results
|
|
"""
|
|
|
|
# Number of hash slots in a Redis Cluster; key slots are computed modulo this.
REDIS_CLUSTER_SLOTS = 16384
# Group name used when keys are not split by slot (non-cluster Redis): all
# keys are placed in a single group under this name.
REDIS_NODE_HASHTAG_NAME = "all_keys"

# TPM token reservation tuning constants.
# When max_tokens is not specified in the request we still need to reserve
# *some* output budget; these define that fallback estimate.
DEFAULT_MAX_TOKENS_ESTIMATE = 4096
# Characters-per-token ratio used to estimate input tokens from text length.
DEFAULT_CHARS_PER_TOKEN = 4
# Divisor (not a 0..1 fraction) applied with integer division to obtain the
# upfront output-budget floor when the request omits max_tokens, i.e. the
# floor is 1/4 of the budget. Applied to both DEFAULT_MAX_TOKENS_ESTIMATE
# (baseline floor) and to the smallest configured TPM limit (capped floor for
# small per-tenant TPM caps).
_TPM_FLOOR_FRACTION = 4
# Stash for the reserved-token count on the request data dict so success/
# failure callbacks can reconcile against the upfront reservation.
TPM_RESERVED_TOKENS_KEY = "_litellm_tpm_reserved_tokens"
# Stash for the model identifier the reservation was charged against.
# Reconciliation must target the same key that was incremented at
# reservation time.
TPM_RESERVED_MODEL_KEY = "_litellm_tpm_reserved_model"
# Stash for the (scope_key, scope_value) pairs whose :tokens counter the
# upfront reservation incremented. Reconciliation applies the delta to these
# scopes only; scopes without a configured TPM limit were never charged at
# pre-call and must receive the full actual usage instead of the delta —
# otherwise their counters drift negative whenever actual < reserved.
TPM_RESERVED_SCOPES_KEY = "_litellm_tpm_reserved_scopes"
# Idempotency marker for the reservation refund path. Set when any failure
# callback releases the reservation so the next callback in the same flow
# (e.g. async_log_failure_event firing after async_post_call_failure_hook)
# does not double-refund.
TPM_RESERVATION_RELEASED_KEY = "_litellm_tpm_reservation_released"
# Stash key for rate-limit descriptors.
# NOTE(review): presumably holds the descriptor list computed at pre-call;
# the code that reads/writes it is outside this part of the file — confirm.
RATE_LIMIT_DESCRIPTORS_KEY = "_litellm_rate_limit_descriptors"
# Stash keys live ONLY in metadata channels — never at the top level of the
# request body. Top-level keys are forwarded as body params to upstream
# providers, which reject unknown fields with 400/429 errors.
_LITELLM_STASH_KEYS: Tuple[str, ...] = (
    TPM_RESERVED_TOKENS_KEY,
    TPM_RESERVED_MODEL_KEY,
    TPM_RESERVED_SCOPES_KEY,
    TPM_RESERVATION_RELEASED_KEY,
    RATE_LIMIT_DESCRIPTORS_KEY,
)
|
|
|
|
|
|
class RateLimitDescriptorRateLimitObject(TypedDict, total=False):
    """Limits attached to one rate-limit descriptor; every field is optional."""

    # Request-count limit; a ":requests" counter is tracked only when set.
    requests_per_unit: Optional[int]
    # Token limit; a ":tokens" counter is tracked only when set.
    tokens_per_unit: Optional[int]
    # Parallel-request limit; a ":max_parallel_requests" counter is tracked
    # only when set.
    max_parallel_requests: Optional[int]
    # Window length; readers fall back to the handler's default window size
    # when this is missing or falsy.
    window_size: Optional[int]
|
|
|
|
|
|
class RateLimitDescriptor(TypedDict):
    """One (key, value) scope to rate limit, plus its configured limits.

    `key` and `value` are combined into the `{key:value}` prefix of the
    cache/Redis keys used for this scope.
    """

    # Scope type; also reported back as `descriptor_key` in statuses.
    key: str
    # Scope identifier within that type.
    value: str
    # Limits for this scope; None/empty means nothing is enforced for it.
    rate_limit: Optional[RateLimitDescriptorRateLimitObject]
|
|
|
|
|
|
class RateLimitStatus(TypedDict):
    """Outcome for a single counter (one descriptor x one rate-limit type)."""

    # "OK" or "OVER_LIMIT".
    code: str
    # The configured limit this counter was compared against.
    current_limit: int
    # current_limit minus the counter value (can be negative when over limit).
    limit_remaining: int
    # Which kind of counter this status describes.
    rate_limit_type: Literal["requests", "tokens", "max_parallel_requests"]
    # The `key` of the descriptor the counter belongs to.
    descriptor_key: str
|
|
|
|
|
|
class RateLimitResponse(TypedDict):
    """Aggregate result of a rate-limit check."""

    # "OVER_LIMIT" if any status is over its limit, otherwise "OK".
    overall_code: str
    # One entry per evaluated counter.
    statuses: List[RateLimitStatus]
|
|
|
|
|
|
class RateLimitResponseWithDescriptors(TypedDict):
    """A rate-limit response bundled with the descriptors it refers to.

    NOTE(review): not used in the visible part of the file; presumably the
    descriptors are the ones the response was computed for — confirm at the
    call sites.
    """

    descriptors: List[RateLimitDescriptor]
    response: RateLimitResponse
|
|
|
|
|
|
class _PROXY_MaxParallelRequestsHandler_v3(CustomLogger):
|
|
def __init__(
    self,
    internal_usage_cache: InternalUsageCache,
    time_provider: Optional[Callable[[], datetime]] = None,
):
    """Set up the limiter.

    Args:
        internal_usage_cache: cache wrapper exposing `dual_cache.redis_cache`
            and the async get/set helpers used by this class.
        time_provider: optional zero-argument callable returning the current
            `datetime`; defaults to `datetime.now`.
    """
    self.internal_usage_cache = internal_usage_cache
    self._time_provider = time_provider or datetime.now

    # The Lua scripts can only be registered when a Redis cache is configured;
    # without one, every script attribute is None and callers take the
    # in-memory code paths.
    redis_cache = self.internal_usage_cache.dual_cache.redis_cache
    if redis_cache is None:
        self.batch_rate_limiter_script = None
        self.token_increment_script = None
        self.check_and_increment_by_n_script = None
    else:
        self.batch_rate_limiter_script = redis_cache.async_register_script(
            BATCH_RATE_LIMITER_SCRIPT
        )
        self.token_increment_script = redis_cache.async_register_script(
            TOKEN_INCREMENT_SCRIPT
        )
        self.check_and_increment_by_n_script = redis_cache.async_register_script(
            CHECK_AND_INCREMENT_BY_N_SCRIPT
        )

    # Default window length, overridable via the environment.
    self.window_size = int(os.getenv("LITELLM_RATE_LIMIT_WINDOW_SIZE", 60))

    # Batch rate limiter (lazy loaded)
    self._batch_rate_limiter: Optional[Any] = None

    # One process-wide lock for the in-memory check+increment sequence, so a
    # read-only check and the following counter increment cannot interleave
    # with another coroutine in this process. Cross-process safety is left to
    # the Redis Lua scripts.
    #
    # The lock is deliberately coarse (one lock for every descriptor). If it
    # ever shows up as contention, the intended follow-up is a per-descriptor
    # lock, e.g. a `weakref.WeakValueDictionary[str, asyncio.Lock]`.
    self._check_and_increment_lock = asyncio.Lock()
|
|
|
|
def _get_batch_rate_limiter(self) -> Optional[Any]:
|
|
"""Get or lazy-load the batch rate limiter."""
|
|
if self._batch_rate_limiter is None:
|
|
try:
|
|
from litellm.proxy.hooks.batch_rate_limiter import (
|
|
_PROXY_BatchRateLimiter,
|
|
)
|
|
|
|
self._batch_rate_limiter = _PROXY_BatchRateLimiter(
|
|
internal_usage_cache=self.internal_usage_cache,
|
|
parallel_request_limiter=self,
|
|
)
|
|
except Exception as e:
|
|
verbose_proxy_logger.debug(
|
|
f"Could not load batch rate limiter: {str(e)}"
|
|
)
|
|
return self._batch_rate_limiter
|
|
|
|
def _get_current_time(self) -> datetime:
|
|
"""Return the current time for rate limiting calculations."""
|
|
return self._time_provider()
|
|
|
|
@staticmethod
def _no_max_tokens_output_floor(
    min_configured_tpm_limit: Optional[int],
) -> int:
    """Output-token floor to reserve when a request has no max_tokens.

    The default floor is ``DEFAULT_MAX_TOKENS_ESTIMATE // _TPM_FLOOR_FRACTION``.
    When a TPM limit is supplied, the floor is additionally capped at
    ``limit // _TPM_FLOOR_FRACTION`` (never below 1), so the floor by itself
    cannot exhaust a small TPM limit.

    Args:
        min_configured_tpm_limit: smallest TPM limit the request is charged
            against, or None when there is none.

    Returns:
        The floor, in tokens.
    """
    default_floor = DEFAULT_MAX_TOKENS_ESTIMATE // _TPM_FLOOR_FRACTION
    if min_configured_tpm_limit is not None:
        # max(1, ...) keeps the cap at one token or more for tiny limits.
        limit_based_floor = max(1, min_configured_tpm_limit // _TPM_FLOOR_FRACTION)
        return min(default_floor, limit_based_floor)
    return default_floor
|
|
|
|
def _estimate_tokens_for_request(
    self,
    data: dict,
    model: Optional[str] = None,
    min_configured_tpm_limit: Optional[int] = None,
) -> int:
    """
    Estimate how many tokens to reserve upfront for a request:
    ``estimated input tokens + output budget``.

    Input size is taken from the first usable field among ``messages``
    (chat), ``prompt`` (completions) and ``input`` (embeddings), converted
    from characters to tokens with ``DEFAULT_CHARS_PER_TOKEN``.

    Output budget, in priority order:
      1. the explicit ``max_tokens`` / ``max_completion_tokens`` value;
      2. 0 when ``input`` is present (embeddings have no output tokens);
      3. 0 when the request carries no content at all;
      4. otherwise ``max(input estimate, floor)`` where the floor comes from
         ``_no_max_tokens_output_floor(min_configured_tpm_limit)``.

    Args:
        data: request body.
        model: accepted for interface compatibility; not used here.
        min_configured_tpm_limit: smallest ``tokens_per_unit`` among the
            descriptors this request is charged against; caps the floor
            used in case 4. Omit for the uncapped floor.

    Returns:
        Total estimated tokens.
    """
    messages = data.get("messages")
    prompt = data.get("prompt")
    input_text = data.get("input")  # embeddings

    # Character count of whichever content field applies; messages win over
    # prompt, prompt wins over input.
    if messages:
        total_chars = len(get_str_from_messages(messages))
    elif isinstance(prompt, str):
        total_chars = len(prompt)
    elif isinstance(prompt, list):
        total_chars = sum(len(str(item)) for item in prompt)
    elif isinstance(input_text, str):
        total_chars = len(input_text)
    elif isinstance(input_text, list):
        total_chars = sum(len(str(item)) for item in input_text)
    else:
        total_chars = 0

    if total_chars > 0:
        estimated_input_tokens = max(1, total_chars // DEFAULT_CHARS_PER_TOKEN)
    else:
        estimated_input_tokens = 0

    explicit_max_tokens = data.get("max_tokens") or data.get(
        "max_completion_tokens"
    )

    if explicit_max_tokens is not None:
        max_tokens_estimate = int(explicit_max_tokens)
    elif input_text:
        # Embeddings have no output tokens
        max_tokens_estimate = 0
    elif total_chars == 0:
        # Nothing to estimate from (no messages, prompt, or input). Skip the
        # output floor so a small TPM limit is not pushed into a false 429;
        # the caller applies its own minimum of 1.
        max_tokens_estimate = 0
    else:
        # No max_tokens given: reserve at least the input size, with a floor
        # so many small concurrent requests cannot jointly slip under the
        # limit. The floor is capped relative to the smallest TPM limit.
        output_floor = self._no_max_tokens_output_floor(min_configured_tpm_limit)
        max_tokens_estimate = max(estimated_input_tokens, output_floor)

    total_estimated = estimated_input_tokens + max_tokens_estimate

    verbose_proxy_logger.debug(
        f"TPM reservation estimate: input={estimated_input_tokens}, "
        f"max_tokens={max_tokens_estimate} (explicit={explicit_max_tokens is not None}), "
        f"total={total_estimated}"
    )

    return total_estimated
|
|
|
|
def _is_redis_cluster(self) -> bool:
    """
    Tell whether the configured Redis cache is a Redis Cluster cache.

    Returns:
        bool: True when a Redis cache is configured and it is a
        `RedisClusterCache`; False otherwise (including no Redis at all).
    """
    from litellm.caching.redis_cluster_cache import RedisClusterCache

    redis_cache = self.internal_usage_cache.dual_cache.redis_cache
    if redis_cache is None:
        return False
    return isinstance(redis_cache, RedisClusterCache)
|
|
|
|
async def in_memory_cache_sliding_window(
    self,
    keys: List[str],
    now_int: int,
    window_size: int,
) -> List[Any]:
    """
    In-memory counterpart of the batch rate-limiter Lua script.

    `keys` is a flat list of (window_key, counter_key) pairs. For each pair
    the local cache is read and updated:

    * window missing or at least `window_size` old -> window start becomes
      `now_int`, counter becomes 1;
    * otherwise -> counter is incremented by 1 (a missing counter counts
      as 0).

    Every write uses `window_size` as its TTL.

    Returns:
        Flat list `[window_start, counter, window_start, counter, ...]` in
        the order of `keys`.
    """
    cache = self.internal_usage_cache
    step = 1
    flat_results: List[Any] = []

    for offset in range(0, len(keys), 2):
        window_key, counter_key = keys[offset], keys[offset + 1]

        stored_window_start = await cache.async_get_cache(
            key=window_key,
            litellm_parent_otel_span=None,
            local_only=True,
        )

        window_is_live = (
            stored_window_start is not None
            and (now_int - int(stored_window_start)) < window_size
        )

        if window_is_live:
            # Same window: bump the existing counter.
            stored_counter = await cache.async_get_cache(
                key=counter_key,
                litellm_parent_otel_span=None,
                local_only=True,
            )
            previous = 0 if stored_counter is None else int(stored_counter)
            bumped = previous + step
            await cache.async_set_cache(
                key=counter_key,
                value=bumped,
                ttl=window_size,
                litellm_parent_otel_span=None,
                local_only=True,
            )
            flat_results += [stored_window_start, bumped]
            continue

        # New window: restart both the window marker and the counter.
        await cache.async_set_cache(
            key=window_key,
            value=str(now_int),
            ttl=window_size,
            litellm_parent_otel_span=None,
            local_only=True,
        )
        await cache.async_set_cache(
            key=counter_key,
            value=step,
            ttl=window_size,
            litellm_parent_otel_span=None,
            local_only=True,
        )
        flat_results += [str(now_int), step]

    return flat_results
|
|
|
|
def create_rate_limit_keys(
    self,
    key: str,
    value: str,
    rate_limit_type: Literal["requests", "tokens", "max_parallel_requests"],
) -> str:
    """
    Build the counter key for one descriptor and rate-limit type.

    The result has the form ``{key:value}:rate_limit_type``; the braces are
    literal and wrap the ``key:value`` part.
    """
    return "{{{}:{}}}:{}".format(key, value, rate_limit_type)
|
|
|
|
def is_cache_list_over_limit(
    self,
    keys_to_fetch: List[str],
    cache_values: List[Any],
    key_metadata: Dict[str, Any],
) -> RateLimitResponse:
    """
    Compare fetched counter values against their configured limits.

    `keys_to_fetch` and `cache_values` are parallel flat lists of
    (window, counter) pairs. The limit that applies to a counter is chosen
    from the counter key's suffix (":requests", ":max_parallel_requests" or
    ":tokens") and looked up in `key_metadata[window_key]`.

    Args:
        keys_to_fetch: flat list of (window_key, counter_key) pairs.
        cache_values: values for those keys, index-aligned; a counter value
            of None is treated as "no usage yet".
        key_metadata: per-window-key dict with "requests_limit",
            "tokens_limit", "max_parallel_requests_limit" and
            "descriptor_key".

    Returns:
        RateLimitResponse with one status per counter that has a limit, and
        overall_code "OVER_LIMIT" if any counter is strictly above its limit.

    Raises:
        ValueError: if a counter value is neither integer- nor float-like.
    """

    def _counter_as_int(raw: Any) -> int:
        # Counters incremented with INCRBYFLOAT can come back as float-
        # formatted strings (e.g. "12.5"), which plain int() rejects. Try
        # int() first so large integer counters keep full precision.
        try:
            return int(raw)
        except ValueError:
            return int(float(raw))

    statuses: List[RateLimitStatus] = []
    overall_code = "OK"

    # Only walk complete (window, counter) pairs present in BOTH lists, so a
    # short or odd-length value list cannot raise IndexError.
    usable_length = min(len(keys_to_fetch), len(cache_values))

    for i in range(0, usable_length - 1, 2):
        item_code = "OK"
        window_key = keys_to_fetch[i]
        counter_key = keys_to_fetch[i + 1]
        counter_value = cache_values[i + 1]
        requests_limit = key_metadata[window_key]["requests_limit"]
        max_parallel_requests_limit = key_metadata[window_key][
            "max_parallel_requests_limit"
        ]
        tokens_limit = key_metadata[window_key]["tokens_limit"]

        # Determine which limit to use for current_limit and limit_remaining
        current_limit: Optional[int] = None
        rate_limit_type: Optional[
            Literal["requests", "tokens", "max_parallel_requests"]
        ] = None
        if counter_key.endswith(":requests"):
            current_limit = requests_limit
            rate_limit_type = "requests"
        elif counter_key.endswith(":max_parallel_requests"):
            current_limit = max_parallel_requests_limit
            rate_limit_type = "max_parallel_requests"
        elif counter_key.endswith(":tokens"):
            current_limit = tokens_limit
            rate_limit_type = "tokens"

        # No configured limit for this counter type: nothing to report.
        if current_limit is None or rate_limit_type is None:
            continue

        used = _counter_as_int(counter_value) if counter_value is not None else None

        if used is not None and used > current_limit:
            overall_code = "OVER_LIMIT"
            item_code = "OVER_LIMIT"

        # A missing counter means the full limit is still available.
        limit_remaining = current_limit - used if used is not None else current_limit

        statuses.append(
            {
                "code": item_code,
                "current_limit": current_limit,
                "limit_remaining": limit_remaining,
                "rate_limit_type": rate_limit_type,
                "descriptor_key": key_metadata[window_key]["descriptor_key"],
            }
        )

    return RateLimitResponse(overall_code=overall_code, statuses=statuses)
|
|
|
|
def keyslot_for_redis_cluster(self, key: str) -> int:
    """
    Return the Redis Cluster hash slot of `key`:
    ``CRC16(hashed part) mod REDIS_CLUSTER_SLOTS``.

    Hash tags are honoured: when the key contains ``{...}`` with at least
    one character between the first ``{`` and the next ``}``, only that
    substring is hashed; otherwise the whole key is.

    Args:
        key (str): The Redis key.

    Returns:
        int: The slot number, in ``range(REDIS_CLUSTER_SLOTS)``.
    """
    hashed_part = key

    open_idx = key.find("{")
    if open_idx != -1:
        close_idx = key.find("}", open_idx + 1)
        tag_is_non_empty = close_idx != -1 and close_idx != open_idx + 1
        if tag_is_non_empty:
            hashed_part = key[open_idx + 1 : close_idx]

    checksum = binascii.crc_hqx(hashed_part.encode("utf-8"), 0)
    return checksum % REDIS_CLUSTER_SLOTS
|
|
|
|
def _group_keys_by_hash_tag(self, keys: List[str]) -> Dict[str, List[str]]:
|
|
"""
|
|
Group keys by their Redis hash tag to ensure cluster compatibility.
|
|
|
|
For Redis clusters, uses slot calculation to group keys that belong to the same slot.
|
|
For regular Redis, no grouping is needed - all keys can be processed together.
|
|
"""
|
|
groups: Dict[str, List[str]] = {}
|
|
|
|
# Use slot calculation for Redis clusters only
|
|
if self._is_redis_cluster():
|
|
for key in keys:
|
|
slot = self.keyslot_for_redis_cluster(key)
|
|
slot_key = f"slot_{slot}"
|
|
|
|
if slot_key not in groups:
|
|
groups[slot_key] = []
|
|
groups[slot_key].append(key)
|
|
else:
|
|
# For regular Redis, no grouping needed - process all keys together
|
|
groups[REDIS_NODE_HASHTAG_NAME] = keys
|
|
|
|
return groups
|
|
|
|
async def _execute_redis_batch_rate_limiter_script(
|
|
self,
|
|
keys_to_fetch: List[str],
|
|
now_int: int,
|
|
) -> List[Any]:
|
|
"""
|
|
Execute Redis operations grouped by hash tag for cluster compatibility.
|
|
|
|
Args:
|
|
keys_to_fetch: List[str] - List of keys to fetch
|
|
now_int: int - Current timestamp
|
|
|
|
Returns:
|
|
List[Any] - List of cache values
|
|
"""
|
|
if self.batch_rate_limiter_script is None:
|
|
return []
|
|
|
|
key_groups = self._group_keys_by_hash_tag(keys_to_fetch)
|
|
all_cache_values = []
|
|
|
|
for hash_tag, group_keys in key_groups.items():
|
|
try:
|
|
group_cache_values = await self.batch_rate_limiter_script(
|
|
keys=group_keys,
|
|
args=[now_int, self.window_size], # Use integer timestamp
|
|
)
|
|
all_cache_values.extend(group_cache_values)
|
|
except Exception as e:
|
|
verbose_proxy_logger.warning(
|
|
f"Redis Lua script failed for hash tag {hash_tag}: {str(e)}"
|
|
)
|
|
# Fallback to in-memory cache for this group
|
|
group_cache_values = await self.in_memory_cache_sliding_window(
|
|
keys=group_keys,
|
|
now_int=now_int,
|
|
window_size=self.window_size,
|
|
)
|
|
all_cache_values.extend(group_cache_values)
|
|
|
|
return all_cache_values
|
|
|
|
async def should_rate_limit(
    self,
    descriptors: List[RateLimitDescriptor],
    parent_otel_span: Optional[Span] = None,
    read_only: bool = False,
    skip_tpm_check: bool = False,
) -> RateLimitResponse:
    """
    Check the descriptors against their limits and, unless `read_only`,
    count this request against them.

    Flow:
      1. Build the (window_key, counter_key) pairs for every configured
         limit of every descriptor. If there are none, return "OK" without
         touching any cache.
      2. Check the local in-memory cache; if it already shows OVER_LIMIT,
         return that response immediately.
      3. `read_only`: read current values (Redis included) without
         incrementing. Otherwise increment via the Redis Lua script when
         one is registered (mirroring the results into the local cache),
         or via the in-memory sliding window.
      4. Evaluate the resulting values.

    Args:
        descriptors: List of rate limit descriptors to check
        parent_otel_span: Optional OpenTelemetry span for tracing
        read_only: If True, only check limits without incrementing counters
        skip_tpm_check: If True, ignore each descriptor's ``tokens_per_unit``
            — the :tokens counter is neither read nor incremented by this
            pass. Callers that handle TPM via the atomic
            ``reserve_tpm_tokens`` reservation path should set this to
            avoid the +1-per-key Lua / in-memory increment double-charging
            the tokens counter.

    Returns:
        RateLimitResponse with the overall code and one status per counter.
    """
    current_time = self._get_current_time()
    now = current_time.timestamp()
    now_int = int(now)  # Convert to integer for Redis Lua script

    # Collect all keys and their metadata upfront
    keys_to_fetch: List[str] = []
    key_metadata: Dict[str, Any] = {}  # Store metadata for each window key
    for descriptor in descriptors:
        descriptor_key = descriptor["key"]
        descriptor_value = descriptor["value"]
        rate_limit: RateLimitDescriptorRateLimitObject = (
            descriptor.get("rate_limit") or RateLimitDescriptorRateLimitObject()
        )
        requests_limit = rate_limit.get("requests_per_unit")
        tokens_limit = None if skip_tpm_check else rate_limit.get("tokens_per_unit")
        max_parallel_requests_limit = rate_limit.get("max_parallel_requests")
        window_size = rate_limit.get("window_size") or self.window_size

        # All counters of one descriptor share this window key.
        window_key = f"{{{descriptor_key}:{descriptor_value}}}:window"

        rate_limit_set = False
        if requests_limit is not None:
            rpm_key = self.create_rate_limit_keys(
                descriptor_key, descriptor_value, "requests"
            )
            keys_to_fetch.extend([window_key, rpm_key])
            rate_limit_set = True
        if tokens_limit is not None:
            tpm_key = self.create_rate_limit_keys(
                descriptor_key, descriptor_value, "tokens"
            )
            keys_to_fetch.extend([window_key, tpm_key])
            rate_limit_set = True
        if max_parallel_requests_limit is not None:
            max_parallel_requests_key = self.create_rate_limit_keys(
                descriptor_key, descriptor_value, "max_parallel_requests"
            )
            keys_to_fetch.extend([window_key, max_parallel_requests_key])
            rate_limit_set = True

        if not rate_limit_set:
            continue

        key_metadata[window_key] = {
            "requests_limit": (
                int(requests_limit) if requests_limit is not None else None
            ),
            "tokens_limit": int(tokens_limit) if tokens_limit is not None else None,
            "max_parallel_requests_limit": (
                int(max_parallel_requests_limit)
                if max_parallel_requests_limit is not None
                else None
            ),
            "window_size": int(window_size),
            "descriptor_key": descriptor_key,
        }

    # Nothing to enforce: skip the cache reads and the Redis round-trip that
    # would otherwise run with an empty key list.
    if not keys_to_fetch:
        return RateLimitResponse(overall_code="OK", statuses=[])

    ## CHECK IN-MEMORY CACHE
    cache_values = await self.internal_usage_cache.async_batch_get_cache(
        keys=keys_to_fetch,
        parent_otel_span=parent_otel_span,
        local_only=True,
    )

    if cache_values is not None:
        rate_limit_response = self.is_cache_list_over_limit(
            keys_to_fetch, cache_values, key_metadata
        )
        if rate_limit_response["overall_code"] == "OVER_LIMIT":
            return rate_limit_response

    ## IF under limit in-memory, check Redis
    if read_only:
        # READ-ONLY MODE: Just read current values without incrementing
        cache_values = await self.internal_usage_cache.async_batch_get_cache(
            keys=keys_to_fetch,
            parent_otel_span=parent_otel_span,
            local_only=False,  # Check Redis too
        )

        # Nothing cached yet: report a fresh window and zeroed counters.
        if cache_values is None:
            cache_values = [
                str(now_int) if cache_key.endswith(":window") else 0
                for cache_key in keys_to_fetch
            ]
    elif self.batch_rate_limiter_script is not None:
        # NORMAL MODE: Increment counters in Redis
        # Group keys by hash tag for Redis cluster compatibility
        cache_values = await self._execute_redis_batch_rate_limiter_script(
            keys_to_fetch=keys_to_fetch,
            now_int=now_int,
        )

        # update in-memory cache with new values
        for i in range(0, len(cache_values), 2):
            window_key = keys_to_fetch[i]
            counter_key = keys_to_fetch[i + 1]
            window_value = cache_values[i]
            counter_value = cache_values[i + 1]
            await self.internal_usage_cache.async_set_cache(
                key=counter_key,
                value=counter_value,
                ttl=self.window_size,
                litellm_parent_otel_span=parent_otel_span,
                local_only=True,
            )
            await self.internal_usage_cache.async_set_cache(
                key=window_key,
                value=window_value,
                ttl=self.window_size,
                litellm_parent_otel_span=parent_otel_span,
                local_only=True,
            )
    else:
        # NORMAL MODE: In-memory sliding window (no Redis)
        cache_values = await self.in_memory_cache_sliding_window(
            keys=keys_to_fetch,
            now_int=now_int,
            window_size=self.window_size,
        )

    rate_limit_response = self.is_cache_list_over_limit(
        keys_to_fetch, cache_values, key_metadata
    )
    return rate_limit_response
|
|
|
|
async def atomic_check_and_increment_by_n(
    self,
    descriptors: List[RateLimitDescriptor],
    increments: List[Dict[Literal["requests", "tokens"], int]],
    parent_otel_span: Optional[Span] = None,
) -> RateLimitResponse:
    """
    Check and increment several counters by arbitrary amounts, atomically.

    All-or-nothing: when any descriptor would go over its limit, the
    response has `overall_code = "OVER_LIMIT"` and carries the offending
    descriptor's status.

    Two execution paths:

    * Redis Lua script registered -> delegated to
      `_atomic_lua_per_descriptor`, with one key group per descriptor. All
      keys of a group share the descriptor's `{key:value}` hash tag, so
      they sit on one Redis Cluster slot.
    * No script -> `_atomic_check_and_increment_in_memory` on the flattened
      per-counter metadata, run under `_check_and_increment_lock`.

    Args:
        descriptors: rate-limit descriptors to check
        increments: per-descriptor increment amounts, index-parallel to
            `descriptors`; each entry is `{"requests": int, "tokens": int}`.

    Returns:
        RateLimitResponse shaped like `should_rate_limit`'s; an empty "OK"
        response when no descriptor yields any key.

    Raises:
        ValueError: if `descriptors` and `increments` differ in length.
    """
    if len(descriptors) != len(increments):
        raise ValueError(
            "atomic_check_and_increment_by_n: descriptors and increments "
            "must have the same length"
        )

    # One (keys, args, meta) group per descriptor that has something to charge.
    descriptor_groups: List[Tuple[List[str], List[Any], List[Dict[str, Any]]]] = []
    for position, descriptor in enumerate(descriptors):
        group_keys, group_args, group_meta = self._build_descriptor_atomic_payload(
            descriptor=descriptor,
            increment_amounts=increments[position],
        )
        if not group_keys:
            continue
        descriptor_groups.append((group_keys, group_args, group_meta))

    if len(descriptor_groups) == 0:
        return RateLimitResponse(overall_code="OK", statuses=[])

    # Redis path: atomicity comes from the Lua script.
    if self.check_and_increment_by_n_script is not None:
        return await self._atomic_lua_per_descriptor(
            descriptor_groups=descriptor_groups,
            parent_otel_span=parent_otel_span,
        )

    # In-memory path: flatten the per-counter metadata and serialize the
    # check+increment with the process-wide lock.
    flat_meta: List[Dict[str, Any]] = []
    for _group_keys, _group_args, per_counter in descriptor_groups:
        flat_meta.extend(per_counter)

    async with self._check_and_increment_lock:
        return await self._atomic_check_and_increment_in_memory(
            per_counter_meta=flat_meta,
            parent_otel_span=parent_otel_span,
        )
|
|
|
|
def _build_descriptor_atomic_payload(
    self,
    descriptor: RateLimitDescriptor,
    increment_amounts: Dict[Literal["requests", "tokens"], int],
) -> Tuple[List[str], List[Any], List[Dict[str, Any]]]:
    """
    Assemble the Lua inputs for one descriptor.

    Returns a ``(keys, args, meta)`` triple. For each counter type
    ("requests" first, then "tokens") that has a configured limit AND a
    positive increment:

    * ``keys`` gains ``[window_key, counter_key]``
    * ``args`` gains ``[limit, increment, ttl_seconds, window_size_seconds]``
    * ``meta`` gains one dict describing that counter

    Counter types with no limit, or with an increment <= 0, are skipped, so
    all three lists may come back empty. The window key is wrapped in a
    ``{key:value}`` hash tag built from the descriptor.
    """
    name = descriptor["key"]
    value = descriptor["value"]
    limits_cfg: RateLimitDescriptorRateLimitObject = (
        descriptor.get("rate_limit") or RateLimitDescriptorRateLimitObject()
    )
    # Descriptor-level window wins; otherwise use the instance default.
    window_size = limits_cfg.get("window_size") or self.window_size
    window_key = f"{{{name}:{value}}}:window"

    # Insertion order fixes the processing order: requests, then tokens.
    configured: Dict[Literal["requests", "tokens"], Optional[int]] = {
        "requests": limits_cfg.get("requests_per_unit"),
        "tokens": limits_cfg.get("tokens_per_unit"),
    }

    lua_keys: List[str] = []
    lua_args: List[Any] = []
    counters: List[Dict[str, Any]] = []

    for kind, ceiling in configured.items():
        amount = int(increment_amounts.get(kind, 0) or 0)
        if ceiling is None or amount <= 0:
            continue

        counter_key = self.create_rate_limit_keys(name, value, kind)
        # TTL of the counter key and length of the sliding window are kept
        # as two separate values even though they are equal today, so a
        # later custom-TTL descriptor cannot silently change the window.
        ttl_seconds = int(window_size)
        window_size_seconds = int(window_size)
        ceiling_int = int(ceiling)

        lua_keys += [window_key, counter_key]
        # ARGV layout per counter:
        # [limit, increment, ttl_seconds, window_size_seconds].
        lua_args += [ceiling_int, amount, ttl_seconds, window_size_seconds]
        counters.append(
            {
                "descriptor_key": name,
                "current_limit": ceiling_int,
                "rate_limit_type": kind,
                "window_key": window_key,
                "counter_key": counter_key,
                "increment": amount,
                "ttl": ttl_seconds,
                "window_size": window_size_seconds,
            }
        )

    return lua_keys, lua_args, counters
|
|
|
|
async def _atomic_lua_per_descriptor(
    self,
    descriptor_groups: List[Tuple[List[str], List[Any], List[Dict[str, Any]]]],
    parent_otel_span: Optional[Span] = None,
) -> RateLimitResponse:
    """
    Execute the check-and-increment Lua script once per descriptor group.

    Each group is a ``(keys, args, meta)`` triple. Groups run sequentially:

    * If a group reports OVER_LIMIT, every previously applied group is
      refunded and that OVER_LIMIT response is returned as-is.
    * If the script call raises, previously applied groups are refunded and
      the whole request is re-evaluated against the in-memory path under
      ``self._check_and_increment_lock``.
    * Otherwise the OK statuses of all groups are concatenated.
    """
    applied: List[List[Dict[str, Any]]] = []
    ok_statuses: List[RateLimitStatus] = []

    for keys, args, meta in descriptor_groups:
        try:
            raw = await self.check_and_increment_by_n_script(
                keys=keys,
                args=args,
            )
        except Exception as e:
            # The failing call may or may not have touched Redis. Undo the
            # groups known to have been applied, then decide using the
            # local counters, which are separate from Redis.
            verbose_proxy_logger.error(
                f"atomic_check_and_increment_by_n: Redis Lua execution "
                f"failed ({type(e).__name__}: {e}). Refunding "
                f"{len(applied)} prior descriptors and falling back to "
                f"in-memory enforcement — counters will diverge from "
                f"Redis until window expires (window_size="
                f"{self.window_size}s)."
            )
            await self._refund_applied_descriptor_groups(applied)

            every_counter: List[Dict[str, Any]] = []
            for _k, _a, group_meta in descriptor_groups:
                every_counter.extend(group_meta)

            async with self._check_and_increment_lock:
                return await self._atomic_check_and_increment_in_memory(
                    per_counter_meta=every_counter,
                    parent_otel_span=parent_otel_span,
                )

        outcome = self._build_atomic_response(raw, meta)
        if outcome["overall_code"] != "OVER_LIMIT":
            applied.append(meta)
            ok_statuses.extend(outcome["statuses"])
            continue

        # This group was rejected: roll back the earlier ones and stop.
        await self._refund_applied_descriptor_groups(applied)
        return outcome

    return RateLimitResponse(overall_code="OK", statuses=ok_statuses)
|
|
|
|
async def _refund_applied_descriptor_groups(
    self,
    applied: List[List[Dict[str, Any]]],
) -> None:
    """
    Give back increments for descriptor groups that were already applied.

    For every counter entry in ``applied`` the Redis counter is incremented
    by the negated ``increment``. This is best-effort: a failing refund is
    logged as a warning and the remaining refunds still run. Nothing happens
    when ``applied`` is empty or no Redis cache is configured.
    """
    if not applied:
        return

    redis_cache = self.internal_usage_cache.dual_cache.redis_cache
    if redis_cache is None:
        return

    # NOTE(review): no TTL is passed here; if the counter key expired
    # before the refund, confirm how `async_increment` treats the new key.
    for entry in (counter for group in applied for counter in group):
        try:
            await redis_cache.async_increment(
                key=entry["counter_key"],
                value=-entry["increment"],
            )
        except Exception as e:
            verbose_proxy_logger.warning(
                f"Failed to refund {entry['counter_key']} on "
                f"cross-descriptor rollback: {e}"
            )
|
|
|
|
def _build_atomic_response(
    self,
    raw: List[Any],
    per_counter_meta: List[Dict[str, Any]],
) -> RateLimitResponse:
    """Translate the Lua script's return value into a RateLimitResponse.

    ``raw`` is read as follows:

    * empty -> OK with no statuses
    * ``[1, counter_index, current_counter, limit]`` -> OVER_LIMIT, where
      ``counter_index`` is 1-based and selects ``per_counter_meta[index - 1]``
    * anything else -> OK, with ``raw[1:]`` paired positionally with
      ``per_counter_meta`` as the new counter values

    ``per_counter_meta`` is indexed per COUNTER, not per descriptor: a
    descriptor with both request and token limits contributes two entries.
    Keep it parallel to the script's key pairs when changing this code.
    """
    if not raw:
        return RateLimitResponse(overall_code="OK", statuses=[])

    if int(raw[0]) != 1:
        ok_statuses: List[RateLimitStatus] = [
            RateLimitStatus(
                code="OK",
                current_limit=entry["current_limit"],
                limit_remaining=max(0, entry["current_limit"] - int(counter)),
                rate_limit_type=entry["rate_limit_type"],
                descriptor_key=entry["descriptor_key"],
            )
            for entry, counter in zip(per_counter_meta, raw[1:])
        ]
        return RateLimitResponse(overall_code="OK", statuses=ok_statuses)

    # Over limit: the script reports which counter tripped (1-based).
    counter_pos = int(raw[1]) - 1
    observed = int(raw[2])
    ceiling = int(raw[3])
    tripped = per_counter_meta[counter_pos]
    rejection = RateLimitStatus(
        code="OVER_LIMIT",
        current_limit=ceiling,
        limit_remaining=max(0, ceiling - observed),
        rate_limit_type=tripped["rate_limit_type"],
        descriptor_key=tripped["descriptor_key"],
    )
    return RateLimitResponse(overall_code="OVER_LIMIT", statuses=[rejection])
|
|
|
|
async def _atomic_check_and_increment_in_memory(
    self,
    per_counter_meta: List[Dict[str, Any]],
    parent_otel_span: Optional[Span] = None,
) -> RateLimitResponse:
    """All-or-nothing check-and-increment against the local cache.

    The caller is expected to hold the lock. Every cache call uses
    ``local_only=True``, so this store is separate from Redis; after a Lua
    failure the two will disagree until the affected keys expire.

    Phase 1 reads each counter and returns OVER_LIMIT for the first one
    whose ``current + increment`` would exceed its limit, writing nothing.
    Phase 2 runs only if every counter passed, and writes the new values.
    """
    # A single timestamp so every counter judges window expiry the same way.
    now_int = int(self._get_current_time().timestamp())
    cache = self.internal_usage_cache

    # Phase 1: read and validate, no writes.
    snapshots: List[Tuple[bool, int]] = []
    for meta in per_counter_meta:
        window_size = meta["window_size"]
        window_start = await cache.async_get_cache(
            key=meta["window_key"],
            litellm_parent_otel_span=parent_otel_span,
            local_only=True,
        )
        if window_start is None or (now_int - int(window_start)) >= window_size:
            # No window, or a stale one: the counter restarts from zero.
            expired, used = True, 0
        else:
            stored = await cache.async_get_cache(
                key=meta["counter_key"],
                litellm_parent_otel_span=parent_otel_span,
                local_only=True,
            )
            expired, used = False, int(stored or 0)

        ceiling = meta["current_limit"]
        if used + meta["increment"] > ceiling:
            rejection = RateLimitStatus(
                code="OVER_LIMIT",
                current_limit=ceiling,
                limit_remaining=max(0, ceiling - used),
                rate_limit_type=meta["rate_limit_type"],
                descriptor_key=meta["descriptor_key"],
            )
            return RateLimitResponse(
                overall_code="OVER_LIMIT", statuses=[rejection]
            )
        snapshots.append((expired, used))

    # Phase 2: every counter passed, so commit all increments.
    committed: List[RateLimitStatus] = []
    for meta, (expired, used) in zip(per_counter_meta, snapshots):
        # `used` is already 0 for an expired window, so this is the full
        # new value in both cases.
        new_counter = used + meta["increment"]
        if expired:
            await cache.async_set_cache(
                key=meta["window_key"],
                value=str(now_int),
                ttl=meta["window_size"],
                litellm_parent_otel_span=parent_otel_span,
                local_only=True,
            )
        await cache.async_set_cache(
            key=meta["counter_key"],
            value=new_counter,
            ttl=meta["ttl"],
            litellm_parent_otel_span=parent_otel_span,
            local_only=True,
        )
        committed.append(
            RateLimitStatus(
                code="OK",
                current_limit=meta["current_limit"],
                limit_remaining=max(0, meta["current_limit"] - new_counter),
                rate_limit_type=meta["rate_limit_type"],
                descriptor_key=meta["descriptor_key"],
            )
        )

    return RateLimitResponse(overall_code="OK", statuses=committed)
|
|
|
|
async def reserve_tpm_tokens(
    self,
    descriptors: List[RateLimitDescriptor],
    estimated_tokens: int,
    parent_otel_span: Optional[Span] = None,
) -> RateLimitResponse:
    """
    Reserve ``estimated_tokens`` on every descriptor that carries a token
    limit, ahead of the upstream call.

    Descriptors whose ``rate_limit`` has no ``tokens_per_unit`` are ignored.
    If none remain, an OK response with no statuses is returned. Otherwise
    the work is delegated to ``atomic_check_and_increment_by_n`` with a
    tokens-only increment per descriptor, which supplies the all-or-nothing
    behaviour.
    """
    token_limited: List[RateLimitDescriptor] = []
    increments: List[Dict[Literal["requests", "tokens"], int]] = []
    for candidate in descriptors:
        limits = candidate.get("rate_limit") or {}
        if limits.get("tokens_per_unit") is None:
            continue
        token_limited.append(candidate)
        # A fresh dict per descriptor, parallel to `token_limited`.
        increments.append({"tokens": estimated_tokens})

    if not token_limited:
        return RateLimitResponse(overall_code="OK", statuses=[])

    return await self.atomic_check_and_increment_by_n(
        descriptors=token_limited,
        increments=increments,
        parent_otel_span=parent_otel_span,
    )
|
|
|
|
def create_organization_rate_limit_descriptor(
    self, user_api_key_dict: UserAPIKeyAuth, requested_model: Optional[str] = None
) -> List[RateLimitDescriptor]:
    """
    Build the organization-scoped rate limit descriptors for a request.

    Up to two descriptors are produced:

    * ``organization`` - when ``org_id`` is set and an organization RPM or
      TPM limit exists.
    * ``model_per_organization`` - when ``organization_metadata`` carries a
      per-model RPM or TPM map that contains ``requested_model``.

    Returns an empty list when neither applies.
    """
    auth = user_api_key_dict
    found: List[RateLimitDescriptor] = []

    # Organization-wide limits.
    if auth.org_id is not None and not (
        auth.organization_rpm_limit is None and auth.organization_tpm_limit is None
    ):
        found.append(
            RateLimitDescriptor(
                key="organization",
                value=auth.org_id,
                rate_limit={
                    "requests_per_unit": auth.organization_rpm_limit,
                    "tokens_per_unit": auth.organization_tpm_limit,
                    "window_size": self.window_size,
                },
            )
        )

    # Per-model limits stored in the organization metadata.
    rpm_map_present = (
        get_model_rate_limit_from_metadata(
            auth, "organization_metadata", "model_rpm_limit"
        )
        is not None
    )
    if not rpm_map_present and (
        get_model_rate_limit_from_metadata(
            auth, "organization_metadata", "model_tpm_limit"
        )
        is None
    ):
        return found

    org_tpm_by_model = (
        get_model_rate_limit_from_metadata(
            auth, "organization_metadata", "model_tpm_limit"
        )
        or {}
    )
    org_rpm_by_model = (
        get_model_rate_limit_from_metadata(
            auth, "organization_metadata", "model_rpm_limit"
        )
        or {}
    )

    has_tpm = requested_model in org_tpm_by_model
    has_rpm = requested_model in org_rpm_by_model
    if not (has_tpm or has_rpm):
        return found

    found.append(
        RateLimitDescriptor(
            key="model_per_organization",
            value=f"{auth.org_id}:{requested_model}",
            rate_limit={
                "requests_per_unit": (
                    org_rpm_by_model[requested_model] if has_rpm else None
                ),
                "tokens_per_unit": (
                    org_tpm_by_model[requested_model] if has_tpm else None
                ),
                "window_size": self.window_size,
            },
        )
    )
    return found
|
|
|
|
def _add_model_per_key_rate_limit_descriptor(
    self,
    user_api_key_dict: UserAPIKeyAuth,
    requested_model: Optional[str],
    descriptors: List[RateLimitDescriptor],
) -> None:
    """
    Append a ``model_per_key`` descriptor when the API key has a per-model
    limit for the requested model.

    Args:
        user_api_key_dict: Auth object for the calling API key
        requested_model: Model named in the request; nothing is added when
            it is empty
        descriptors: List that is extended in place

    Nothing is appended when neither the TPM nor the RPM per-model map
    exists, or when neither map contains ``requested_model``.
    """
    from litellm.proxy.auth.auth_utils import (
        get_key_model_rpm_limit,
        get_key_model_tpm_limit,
    )

    if not requested_model:
        return

    tpm_lookup = get_key_model_tpm_limit(
        user_api_key_dict, model_name=requested_model
    )
    rpm_lookup = get_key_model_rpm_limit(
        user_api_key_dict, model_name=requested_model
    )
    if tpm_lookup is None and rpm_lookup is None:
        return

    tpm_by_model = tpm_lookup or {}
    rpm_by_model = rpm_lookup or {}

    # The model must appear in at least one of the two maps.
    if (
        requested_model not in tpm_by_model
        and requested_model not in rpm_by_model
    ):
        return

    tokens_cap: Optional[int] = tpm_by_model.get(requested_model)
    requests_cap: Optional[int] = rpm_by_model.get(requested_model)

    descriptors.append(
        RateLimitDescriptor(
            key="model_per_key",
            value=f"{user_api_key_dict.api_key}:{requested_model}",
            rate_limit={
                "requests_per_unit": requests_cap,
                "tokens_per_unit": tokens_cap,
                "window_size": self.window_size,
            },
        )
    )
|
|
|
|
def _should_enforce_rate_limit(
|
|
self,
|
|
limit_type: Optional[str],
|
|
model_has_failures: bool,
|
|
) -> bool:
|
|
"""
|
|
Determine if rate limit should be enforced based on limit type and model health.
|
|
|
|
Args:
|
|
limit_type: Type of rate limit ("dynamic", "guaranteed_throughput", "best_effort_throughput", or None)
|
|
model_has_failures: Whether the model has recent failures
|
|
|
|
Returns:
|
|
True if rate limit should be enforced, False otherwise
|
|
"""
|
|
if limit_type == "dynamic":
|
|
# Dynamic mode: only enforce if model has failures
|
|
return model_has_failures
|
|
# All other modes (including None): always enforce
|
|
return True
|
|
|
|
def _get_enforced_limit(
|
|
self,
|
|
limit_value: Optional[int],
|
|
limit_type: Optional[str],
|
|
model_has_failures: bool,
|
|
) -> Optional[int]:
|
|
"""
|
|
Get the rate limit value to enforce based on limit type and model health.
|
|
|
|
Args:
|
|
limit_value: The configured limit value
|
|
limit_type: Type of rate limit ("dynamic", "guaranteed_throughput", "best_effort_throughput", or None)
|
|
model_has_failures: Whether the model has recent failures
|
|
|
|
Returns:
|
|
The limit value if it should be enforced, None otherwise
|
|
"""
|
|
if limit_value is None:
|
|
return None
|
|
|
|
if self._should_enforce_rate_limit(
|
|
limit_type=limit_type,
|
|
model_has_failures=model_has_failures,
|
|
):
|
|
return limit_value
|
|
|
|
return None
|
|
|
|
def _is_dynamic_rate_limiting_enabled(
|
|
self,
|
|
rpm_limit_type: Optional[str],
|
|
tpm_limit_type: Optional[str],
|
|
) -> bool:
|
|
"""
|
|
Check if dynamic rate limiting is enabled for either RPM or TPM.
|
|
|
|
Args:
|
|
rpm_limit_type: RPM rate limit type
|
|
tpm_limit_type: TPM rate limit type
|
|
|
|
Returns:
|
|
True if dynamic mode is enabled for either limit type
|
|
"""
|
|
return rpm_limit_type == "dynamic" or tpm_limit_type == "dynamic"
|
|
|
|
def _get_agent_from_registry(self, agent_id: str) -> Optional[Any]:
    """Return the agent that the global agent registry holds for ``agent_id``."""
    # Imported at call time, as in the rest of this module's proxy lookups.
    from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry

    registered_agent = global_agent_registry.get_agent_by_id(agent_id=agent_id)
    return registered_agent
|
|
|
|
def _get_resolved_agent_id(
    self, user_api_key_dict: UserAPIKeyAuth, data: dict
) -> Optional[str]:
    """
    Work out which agent this request belongs to.

    An ``agent_id`` attribute on the API key wins when it is truthy.
    Otherwise the ``agent_id`` entry of the request's ``metadata`` dict is
    returned, which is None when absent.
    """
    from_key = getattr(user_api_key_dict, "agent_id", None)
    if not from_key:
        # `metadata` may be missing or None; treat both as empty.
        request_metadata = data.get("metadata") or {}
        return request_metadata.get("agent_id")
    return from_key
|
|
|
|
def _get_session_id_from_data(self, data: dict) -> Optional[str]:
|
|
"""Extract session_id from request metadata or litellm_session_id."""
|
|
session_id = data.get("litellm_session_id")
|
|
if session_id:
|
|
return str(session_id)
|
|
metadata = data.get("metadata") or {}
|
|
session_id = metadata.get("session_id")
|
|
if session_id:
|
|
return str(session_id)
|
|
litellm_metadata = data.get("litellm_metadata") or {}
|
|
session_id = litellm_metadata.get("session_id")
|
|
if session_id:
|
|
return str(session_id)
|
|
return None
|
|
|
|
def _create_agent_rate_limit_descriptors(
    self,
    agent_id: str,
    data: dict,
) -> List[RateLimitDescriptor]:
    """
    Build the agent-scoped descriptors for a request.

    * ``agent`` - added when the registered agent has ``rpm_limit`` or
      ``tpm_limit``; keyed by ``agent_id``.
    * ``agent_session`` - added when the agent has ``session_rpm_limit`` or
      ``session_tpm_limit`` AND the request carries a session id; keyed by
      ``"{agent_id}:{session_id}"``.

    Returns an empty list when the agent is not in the registry.
    """
    built: List[RateLimitDescriptor] = []

    agent = self._get_agent_from_registry(agent_id)
    if agent is None:
        return built

    # Limits shared by all sessions of this agent.
    agent_rpm = getattr(agent, "rpm_limit", None)
    agent_tpm = getattr(agent, "tpm_limit", None)
    if not (agent_rpm is None and agent_tpm is None):
        built.append(
            RateLimitDescriptor(
                key="agent",
                value=agent_id,
                rate_limit={
                    "requests_per_unit": agent_rpm,
                    "tokens_per_unit": agent_tpm,
                    "window_size": self.window_size,
                },
            )
        )

    # Limits that apply within one session.
    session_rpm = getattr(agent, "session_rpm_limit", None)
    session_tpm = getattr(agent, "session_tpm_limit", None)
    if session_rpm is None and session_tpm is None:
        return built

    session_id = self._get_session_id_from_data(data)
    if session_id is None:
        # Without a session id there is nothing to key the counter on.
        return built

    built.append(
        RateLimitDescriptor(
            key="agent_session",
            value=f"{agent_id}:{session_id}",
            rate_limit={
                "requests_per_unit": session_rpm,
                "tokens_per_unit": session_tpm,
                "window_size": self.window_size,
            },
        )
    )
    return built
|
|
|
|
def _create_rate_limit_descriptors(
    self,
    user_api_key_dict: UserAPIKeyAuth,
    data: dict,
    rpm_limit_type: Optional[str],
    tpm_limit_type: Optional[str],
    model_has_failures: bool,
) -> List[RateLimitDescriptor]:
    """
    Collect every rate limit descriptor that applies to this request.

    Descriptors are appended in this order, each only when its identifier
    and at least one of its limits are present: ``api_key``, ``user``,
    ``team``, ``team_member``, ``end_user``, ``model_per_key``,
    ``model_per_team``, then the agent and agent-session descriptors.

    Only the ``api_key`` descriptor passes its RPM/TPM limits through
    ``_get_enforced_limit`` (so dynamic mode can switch them off); it also
    carries ``max_parallel_requests``.
    """
    from litellm.proxy.auth.auth_utils import (
        get_team_model_rpm_limit,
        get_team_model_tpm_limit,
    )

    auth = user_api_key_dict
    collected = []

    # --- API key ---------------------------------------------------------
    if auth.api_key and not (
        auth.rpm_limit is None
        and auth.tpm_limit is None
        and auth.max_parallel_requests is None
    ):
        key_requests = self._get_enforced_limit(
            limit_value=auth.rpm_limit,
            limit_type=rpm_limit_type,
            model_has_failures=model_has_failures,
        )
        key_tokens = self._get_enforced_limit(
            limit_value=auth.tpm_limit,
            limit_type=tpm_limit_type,
            model_has_failures=model_has_failures,
        )
        collected.append(
            RateLimitDescriptor(
                key="api_key",
                value=auth.api_key,
                rate_limit={
                    "requests_per_unit": key_requests,
                    "tokens_per_unit": key_tokens,
                    "max_parallel_requests": auth.max_parallel_requests,
                    "window_size": self.window_size,
                },
            )
        )

    # --- User ------------------------------------------------------------
    if auth.user_id and not (
        auth.user_rpm_limit is None and auth.user_tpm_limit is None
    ):
        collected.append(
            RateLimitDescriptor(
                key="user",
                value=auth.user_id,
                rate_limit={
                    "requests_per_unit": auth.user_rpm_limit,
                    "tokens_per_unit": auth.user_tpm_limit,
                    "window_size": self.window_size,
                },
            )
        )

    # --- Team ------------------------------------------------------------
    if auth.team_id and not (
        auth.team_rpm_limit is None and auth.team_tpm_limit is None
    ):
        collected.append(
            RateLimitDescriptor(
                key="team",
                value=auth.team_id,
                rate_limit={
                    "requests_per_unit": auth.team_rpm_limit,
                    "tokens_per_unit": auth.team_tpm_limit,
                    "window_size": self.window_size,
                },
            )
        )

    # --- Team member -----------------------------------------------------
    # Gated on user_id only; team_id is interpolated into the value as-is.
    if auth.user_id and not (
        auth.team_member_rpm_limit is None and auth.team_member_tpm_limit is None
    ):
        collected.append(
            RateLimitDescriptor(
                key="team_member",
                value=f"{auth.team_id}:{auth.user_id}",
                rate_limit={
                    "requests_per_unit": auth.team_member_rpm_limit,
                    "tokens_per_unit": auth.team_member_tpm_limit,
                    "window_size": self.window_size,
                },
            )
        )

    # --- End user --------------------------------------------------------
    if auth.end_user_id and not (
        auth.end_user_rpm_limit is None and auth.end_user_tpm_limit is None
    ):
        collected.append(
            RateLimitDescriptor(
                key="end_user",
                value=auth.end_user_id,
                rate_limit={
                    "requests_per_unit": auth.end_user_rpm_limit,
                    "tokens_per_unit": auth.end_user_tpm_limit,
                    "window_size": self.window_size,
                },
            )
        )

    # --- Per-model limits on the key --------------------------------------
    requested_model = data.get("model")
    self._add_model_per_key_rate_limit_descriptor(
        user_api_key_dict=auth,
        requested_model=requested_model,
        descriptors=collected,
    )

    # --- Per-model limits on the team -------------------------------------
    team_maps_exist = (
        get_team_model_rpm_limit(auth) is not None
        or get_team_model_tpm_limit(auth) is not None
    )
    if team_maps_exist:
        team_tpm_by_model = get_team_model_tpm_limit(auth) or {}
        team_rpm_by_model = get_team_model_rpm_limit(auth) or {}
        in_tpm = requested_model in team_tpm_by_model
        in_rpm = requested_model in team_rpm_by_model

        if in_tpm or in_rpm:
            collected.append(
                RateLimitDescriptor(
                    key="model_per_team",
                    value=f"{auth.team_id}:{requested_model}",
                    rate_limit={
                        "requests_per_unit": (
                            team_rpm_by_model[requested_model] if in_rpm else None
                        ),
                        "tokens_per_unit": (
                            team_tpm_by_model[requested_model] if in_tpm else None
                        ),
                        "window_size": self.window_size,
                    },
                )
            )

    # --- Agent and agent-session ------------------------------------------
    resolved_agent_id = self._get_resolved_agent_id(auth, data)
    if resolved_agent_id:
        agent_descriptors = self._create_agent_rate_limit_descriptors(
            agent_id=resolved_agent_id,
            data=data,
        )
        collected.extend(agent_descriptors)

    return collected
|
|
|
|
async def _check_model_has_recent_failures(
    self,
    model: str,
    parent_otel_span: Optional[Span] = None,
) -> bool:
    """
    Report whether any deployment of ``model`` is currently failing, using
    the router's per-minute failure tracking.

    Returns:
        True when a deployment's failure count for the current minute is
        above ``DYNAMIC_RATE_LIMIT_ERROR_THRESHOLD_PER_MINUTE``, and also
        when the check itself raises (fail-safe: limits stay enforced).
        False when there is no router, no deployment for the model, or no
        deployment is over the threshold.
    """
    from litellm.proxy.proxy_server import llm_router
    from litellm.router_utils.router_callbacks.track_deployment_metrics import (
        get_deployment_failures_for_current_minute,
    )

    if llm_router is None:
        return False

    try:
        deployments = llm_router.get_model_list(model_name=model)
        if deployments:
            for deployment in deployments:
                deployment_id = deployment.get("model_info", {}).get("id")
                if not deployment_id:
                    # Cannot look up failures without an id.
                    continue

                failure_count = get_deployment_failures_for_current_minute(
                    litellm_router_instance=llm_router,
                    deployment_id=deployment_id,
                )
                if failure_count <= DYNAMIC_RATE_LIMIT_ERROR_THRESHOLD_PER_MINUTE:
                    continue

                verbose_proxy_logger.debug(
                    f"[Dynamic Rate Limit] Deployment {deployment_id} has {failure_count} failures "
                    f"in current minute - enforcing rate limits for model {model}"
                )
                return True

            verbose_proxy_logger.debug(
                f"[Dynamic Rate Limit] No failures detected for model {model} - allowing dynamic exceeding"
            )
        return False
    except Exception as e:
        verbose_proxy_logger.debug(
            f"Error checking model failure status: {str(e)}, defaulting to enforce limits"
        )
        # Health is unknown, so behave as if the model were unhealthy.
        return True
|
|
|
|
def get_rate_limiter_for_call_type(self, call_type: str) -> Optional[Any]:
    """Return the dedicated rate limiter for ``call_type``, if there is one.

    Only "acreate_batch" has one (the batch rate limiter); every other call
    type yields None.
    """
    if call_type != "acreate_batch":
        return None
    return self._get_batch_rate_limiter()
|
|
|
|
def _add_team_model_rate_limit_descriptor_from_metadata(
    self,
    user_api_key_dict: UserAPIKeyAuth,
    requested_model: Optional[str],
    descriptors: List[RateLimitDescriptor],
) -> None:
    """Append a ``model_per_team`` descriptor sourced from ``team_metadata``.

    Nothing is appended unless ``team_metadata`` has a ``model_rpm_limit``
    or ``model_tpm_limit`` map, ``requested_model`` is not None, and the
    model appears in at least one of those maps. ``descriptors`` is extended
    in place.
    """
    rpm_map_present = (
        get_model_rate_limit_from_metadata(
            user_api_key_dict, "team_metadata", "model_rpm_limit"
        )
        is not None
    )
    if not rpm_map_present and (
        get_model_rate_limit_from_metadata(
            user_api_key_dict, "team_metadata", "model_tpm_limit"
        )
        is None
    ):
        return

    tpm_by_model = (
        get_model_rate_limit_from_metadata(
            user_api_key_dict, "team_metadata", "model_tpm_limit"
        )
        or {}
    )
    rpm_by_model = (
        get_model_rate_limit_from_metadata(
            user_api_key_dict, "team_metadata", "model_rpm_limit"
        )
        or {}
    )

    if requested_model not in tpm_by_model and requested_model not in rpm_by_model:
        return
    if requested_model is None:
        return

    descriptors.append(
        RateLimitDescriptor(
            key="model_per_team",
            value=f"{user_api_key_dict.team_id}:{requested_model}",
            rate_limit={
                "requests_per_unit": rpm_by_model.get(requested_model),
                "tokens_per_unit": tpm_by_model.get(requested_model),
                "window_size": self.window_size,
            },
        )
    )
|
|
|
|
def _add_project_model_rate_limit_descriptor_from_metadata(
    self,
    user_api_key_dict: UserAPIKeyAuth,
    requested_model: Optional[str],
    descriptors: List[RateLimitDescriptor],
) -> None:
    """Append a ``model_per_project`` descriptor sourced from ``project_metadata``.

    Nothing is appended unless ``project_metadata`` has a ``model_rpm_limit``
    or ``model_tpm_limit`` map, ``requested_model`` is not None, and the
    model appears in at least one of those maps. ``descriptors`` is extended
    in place.
    """
    rpm_map_present = (
        get_model_rate_limit_from_metadata(
            user_api_key_dict, "project_metadata", "model_rpm_limit"
        )
        is not None
    )
    if not rpm_map_present and (
        get_model_rate_limit_from_metadata(
            user_api_key_dict, "project_metadata", "model_tpm_limit"
        )
        is None
    ):
        return

    tpm_by_model = (
        get_model_rate_limit_from_metadata(
            user_api_key_dict, "project_metadata", "model_tpm_limit"
        )
        or {}
    )
    rpm_by_model = (
        get_model_rate_limit_from_metadata(
            user_api_key_dict, "project_metadata", "model_rpm_limit"
        )
        or {}
    )

    if requested_model not in tpm_by_model and requested_model not in rpm_by_model:
        return
    if requested_model is None:
        return

    descriptors.append(
        RateLimitDescriptor(
            key="model_per_project",
            value=f"{user_api_key_dict.project_id}:{requested_model}",
            rate_limit={
                "requests_per_unit": rpm_by_model.get(requested_model),
                "tokens_per_unit": tpm_by_model.get(requested_model),
                "window_size": self.window_size,
            },
        )
    )
|
|
|
|
def _handle_rate_limit_error(
    self,
    response: RateLimitResponse,
    descriptors: List[RateLimitDescriptor],
) -> None:
    """Raise an HTTP 429 for the first OVER_LIMIT status in ``response``.

    Args:
        response: Rate limit result whose ``statuses`` are scanned in order
        descriptors: Descriptors of the request; used only to look up the
            value of the descriptor that tripped (matched by ``key``,
            "unknown" when no descriptor has that key)

    Raises:
        HTTPException: status 429 with ``retry-after``, ``rate_limit_type``
            and ``reset_at`` headers, for the first OVER_LIMIT status.

    Returns None without raising when no status is OVER_LIMIT.

    The reset time is ``now + self.window_size`` and is rendered in UTC.
    The previous code formatted it with ``datetime.fromtimestamp(ts)``,
    which yields server-local time, while labelling the string "UTC".
    """
    # Imported here because the module only imports `datetime` itself.
    from datetime import timezone

    for status in response["statuses"]:
        if status["code"] != "OVER_LIMIT":
            continue

        descriptor_key = status["descriptor_key"]
        matching_descriptor = next(
            (desc for desc in descriptors if desc["key"] == descriptor_key),
            None,
        )
        descriptor_value = (
            matching_descriptor["value"]
            if matching_descriptor is not None
            else "unknown"
        )

        now = self._get_current_time().timestamp()
        reset_time = now + self.window_size
        # Convert with an explicit UTC tz so the "UTC" suffix is truthful
        # whatever the server's local timezone is.
        reset_time_formatted = datetime.fromtimestamp(
            reset_time, tz=timezone.utc
        ).strftime("%Y-%m-%d %H:%M:%S UTC")

        # Never report a negative remaining allowance.
        remaining_display = max(0, status["limit_remaining"])
        rate_limit_type = status["rate_limit_type"]
        current_limit = status["current_limit"]

        detail = (
            f"Rate limit exceeded for {descriptor_key}: {descriptor_value}. "
            f"Limit type: {rate_limit_type}. "
            f"Current limit: {current_limit}, Remaining: {remaining_display}. "
            f"Limit resets at: {reset_time_formatted}"
        )

        raise HTTPException(
            status_code=429,
            detail=detail,
            headers={
                "retry-after": str(self.window_size),
                "rate_limit_type": str(status["rate_limit_type"]),
                "reset_at": reset_time_formatted,
            },
        )
|
|
|
|
async def async_pre_call_hook(
    self,
    user_api_key_dict: UserAPIKeyAuth,
    cache: DualCache,
    data: dict,
    call_type: str,
):
    """
    Enforce rate limits before the request is forwarded.

    Steps, in order:
      1. Remove stash keys already present in ``data`` (top level and the
         ``metadata`` / ``litellm_metadata`` channels).
      2. If ``get_rate_limiter_for_call_type`` returns a limiter, delegate to
         its ``async_pre_call_hook`` and return that result.
      3. Build the descriptor list (base descriptors, team-model,
         project-model and organization descriptors).
      4. Run ``should_rate_limit`` with ``skip_tpm_check=True``.
      5. If any descriptor sets ``tokens_per_unit``, reserve an estimated
         token count via ``reserve_tpm_tokens`` and stash the reservation in
         the request metadata.

    An ``OVER_LIMIT`` result in step 4 or 5 is handed to
    ``_handle_rate_limit_error``. ``data`` is mutated in place; nothing is
    returned except in the delegation case of step 2.
    """
    verbose_proxy_logger.debug("Inside Rate Limit Pre-Call Hook")

    # Caller-supplied stash values must be dropped before anything reads or
    # writes them, so a client cannot name the scopes that a later refund
    # would be applied to.
    self._strip_stash_keys_from_all_channels(data)

    # Some call types (the original comment cites Batch APIs) have their own
    # limiter; when one exists it fully replaces the logic below.
    delegate = self.get_rate_limiter_for_call_type(call_type=call_type)
    if delegate:
        return await delegate.async_pre_call_hook(
            user_api_key_dict=user_api_key_dict,
            cache=cache,
            data=data,
            call_type=call_type,
        )

    # Limit modes configured on the key's metadata.
    key_metadata = user_api_key_dict.metadata or {}
    rpm_limit_type = key_metadata.get("rpm_limit_type")
    tpm_limit_type = key_metadata.get("tpm_limit_type")

    requested_model = data.get("model", None)

    # Dynamic mode only: find out whether the model failed recently.
    model_has_failures = False
    dynamic_mode = self._is_dynamic_rate_limiting_enabled(
        rpm_limit_type=rpm_limit_type,
        tpm_limit_type=tpm_limit_type,
    )
    if dynamic_mode and requested_model:
        model_has_failures = await self._check_model_has_recent_failures(
            model=requested_model,
            parent_otel_span=user_api_key_dict.parent_otel_span,
        )

    descriptors = self._create_rate_limit_descriptors(
        user_api_key_dict=user_api_key_dict,
        data=data,
        rpm_limit_type=rpm_limit_type,
        tpm_limit_type=tpm_limit_type,
        model_has_failures=model_has_failures,
    )

    # Team-level per-model limits (appended to ``descriptors`` in place).
    self._add_team_model_rate_limit_descriptor_from_metadata(
        user_api_key_dict=user_api_key_dict,
        requested_model=requested_model,
        descriptors=descriptors,
    )

    # Project-level per-model limits (appended to ``descriptors`` in place).
    self._add_project_model_rate_limit_descriptor_from_metadata(
        user_api_key_dict=user_api_key_dict,
        requested_model=requested_model,
        descriptors=descriptors,
    )

    # Organization-level limits.
    descriptors.extend(
        self.create_organization_rate_limit_descriptor(
            user_api_key_dict, requested_model
        )
    )

    if descriptors:
        # Pass 1: requests / parallel-requests only. ``skip_tpm_check=True``
        # asks should_rate_limit to leave token limits alone; token counters
        # are handled solely by reserve_tpm_tokens below (the original
        # rationale: otherwise each in-flight request would also bump the
        # token counters and shrink the effective TPM budget).
        rpm_response = await self.should_rate_limit(
            descriptors=descriptors,
            parent_otel_span=user_api_key_dict.parent_otel_span,
            skip_tpm_check=True,
        )

        if rpm_response["overall_code"] == "OVER_LIMIT":
            self._handle_rate_limit_error(
                response=rpm_response,
                descriptors=descriptors,
            )
        else:
            # Kept on the request so the post-call hook can emit headers.
            data["litellm_proxy_rate_limit_response"] = rpm_response

        # Pass 2: TPM reservation. Estimated tokens are reserved up front so
        # that concurrent requests cannot all see "under limit" before any
        # of them has been counted.
        tpm_caps: List[int] = []
        for descriptor in descriptors:
            cap = (descriptor.get("rate_limit") or {}).get("tokens_per_unit")
            if cap is not None:
                tpm_caps.append(int(cap))

        if tpm_caps:
            min_configured_tpm_limit = min(tpm_caps)

            # If the tightest TPM cap lowers the "no max_tokens" floor below
            # its baseline, also cap the model output through
            # data["max_tokens"]. Skipped when the request already sets an
            # output limit or carries an ``input`` field (treated here as an
            # embedding-style request).
            capped_floor = self._no_max_tokens_output_floor(
                min_configured_tpm_limit
            )
            baseline_floor = DEFAULT_MAX_TOKENS_ESTIMATE // _TPM_FLOOR_FRACTION
            caller_set_output_limit = (
                data.get("max_tokens") is not None
                or data.get("max_completion_tokens") is not None
            )
            has_input_field = data.get("input") is not None
            if (
                capped_floor < baseline_floor
                and not caller_set_output_limit
                and not has_input_field
            ):
                data["max_tokens"] = capped_floor

            # Reserve at least one token so requests whose estimate is zero
            # still pass through the counter.
            raw_estimate = self._estimate_tokens_for_request(
                data=data,
                model=requested_model,
                min_configured_tpm_limit=min_configured_tpm_limit,
            )
            estimated_tokens = max(raw_estimate, 1)

            tpm_response = await self.reserve_tpm_tokens(
                descriptors=descriptors,
                estimated_tokens=estimated_tokens,
                parent_otel_span=user_api_key_dict.parent_otel_span,
            )

            if tpm_response["overall_code"] == "OVER_LIMIT":
                self._handle_rate_limit_error(
                    response=tpm_response,
                    descriptors=descriptors,
                )
            else:
                self._stash_value_in_metadata_channels(
                    data=data,
                    key=RATE_LIMIT_DESCRIPTORS_KEY,
                    value=descriptors,
                )

                # Record exactly which (key, value) scopes carry a token
                # limit, i.e. which ones the reservation applies to, so
                # reconciliation can tell reserved scopes from the rest.
                reserved_scopes: List[Tuple[str, str]] = []
                for descriptor in descriptors:
                    limits = descriptor.get("rate_limit") or {}
                    if limits.get("tokens_per_unit") is not None:
                        reserved_scopes.append(
                            (descriptor["key"], descriptor["value"])
                        )

                self._stash_reservation_in_data(
                    data=data,
                    estimated_tokens=estimated_tokens,
                    reserved_model=requested_model,
                    reserved_scopes=reserved_scopes,
                )

                # Fold the token statuses into the stored response so the
                # token headers are emitted as well; pass 1 skipped TPM and
                # therefore holds no token statuses.
                stored_response = data.get("litellm_proxy_rate_limit_response")
                if isinstance(stored_response, dict):
                    stored_response.setdefault("statuses", []).extend(
                        tpm_response["statuses"]
                    )
                elif tpm_response["statuses"]:
                    data["litellm_proxy_rate_limit_response"] = tpm_response

                verbose_proxy_logger.debug(
                    f"TPM tokens reserved: {estimated_tokens} for model {requested_model}"
                )

    # Defense-in-depth: make sure no stash key sits at the top level of the
    # request body.
    self._strip_stash_keys_from_top_level(data)
|
|
|
|
@staticmethod
def _strip_stash_keys_from_top_level(data: Any) -> None:
    """
    Remove every key listed in ``_LITELLM_STASH_KEYS`` from ``data`` itself.

    ``data`` is mutated in place; anything that is not a dict is ignored.
    Nested metadata dicts are not touched here.
    """
    if isinstance(data, dict):
        for reserved_name in _LITELLM_STASH_KEYS:
            # ``pop`` with a default: absent keys are simply skipped.
            data.pop(reserved_name, None)
|
|
|
|
@classmethod
def _strip_stash_keys_from_all_channels(cls, data: Any) -> None:
    """
    Remove every ``_LITELLM_STASH_KEYS`` entry from ``data`` and from its
    ``metadata`` / ``litellm_metadata`` dicts.

    ``data`` and the nested dicts are mutated in place. Non-dict ``data`` is
    ignored, as is any channel whose value is not a dict.
    """
    if not isinstance(data, dict):
        return

    cls._strip_stash_keys_from_top_level(data)

    for channel_name in ("metadata", "litellm_metadata"):
        bucket = data.get(channel_name)
        if not isinstance(bucket, dict):
            continue
        for reserved_name in _LITELLM_STASH_KEYS:
            bucket.pop(reserved_name, None)
|
|
|
|
def _create_pipeline_operations(
    self,
    key: str,
    value: str,
    rate_limit_type: Literal["requests", "tokens", "max_parallel_requests"],
    total_tokens: int,
) -> List["RedisPipelineIncrementOperation"]:
    """
    Build the increment operation for one scope's token counter.

    Args:
        key: Descriptor key of the scope.
        value: Descriptor value of the scope.
        rate_limit_type: Accepted but not consulted; the counter key is
            always created with ``rate_limit_type="tokens"``.
        total_tokens: Amount to add to the counter.

    Returns:
        A one-element list holding the increment operation, with
        ``ttl=self.window_size``.
    """
    tokens_counter_key = self.create_rate_limit_keys(
        key=key,
        value=value,
        rate_limit_type="tokens",
    )
    return [
        RedisPipelineIncrementOperation(
            key=tokens_counter_key,
            increment_value=total_tokens,
            ttl=self.window_size,
        )
    ]
|
|
|
|
def _get_total_tokens_from_usage(
|
|
self, usage: Optional[Any], rate_limit_type: Literal["output", "input", "total"]
|
|
) -> int:
|
|
"""
|
|
Get total tokens from response usage for rate limiting.
|
|
|
|
For 'input' and 'total' rate limit types, cached tokens are excluded
|
|
because providers like AWS Bedrock don't count cached tokens toward
|
|
rate limits. This aligns LiteLLM's TPM calculation with provider behavior.
|
|
"""
|
|
total_tokens = 0
|
|
cached_tokens = 0
|
|
|
|
if usage:
|
|
if isinstance(usage, Usage):
|
|
if rate_limit_type == "output":
|
|
total_tokens = usage.completion_tokens or 0
|
|
elif rate_limit_type == "input":
|
|
total_tokens = usage.prompt_tokens or 0
|
|
elif rate_limit_type == "total":
|
|
total_tokens = usage.total_tokens or 0
|
|
|
|
# Get cached tokens to exclude from input/total
|
|
if rate_limit_type in ("input", "total"):
|
|
if (
|
|
hasattr(usage, "prompt_tokens_details")
|
|
and usage.prompt_tokens_details is not None
|
|
):
|
|
cached_tokens = (
|
|
getattr(usage.prompt_tokens_details, "cached_tokens", 0)
|
|
or 0
|
|
)
|
|
|
|
elif isinstance(usage, dict):
|
|
# Responses API usage comes as a dict
|
|
if rate_limit_type == "output":
|
|
total_tokens = usage.get("completion_tokens", 0) or 0
|
|
elif rate_limit_type == "input":
|
|
total_tokens = usage.get("prompt_tokens", 0) or 0
|
|
elif rate_limit_type == "total":
|
|
total_tokens = usage.get("total_tokens", 0) or 0
|
|
|
|
# Get cached tokens from dict
|
|
if rate_limit_type in ("input", "total"):
|
|
prompt_details = usage.get("prompt_tokens_details") or {}
|
|
if isinstance(prompt_details, dict):
|
|
cached_tokens = prompt_details.get("cached_tokens", 0) or 0
|
|
|
|
# Subtract cached tokens for input/total (providers don't count them)
|
|
if cached_tokens > 0:
|
|
total_tokens = max(0, total_tokens - cached_tokens)
|
|
|
|
return total_tokens
|
|
|
|
async def _execute_token_increment_script(
|
|
self,
|
|
pipeline_operations: List["RedisPipelineIncrementOperation"],
|
|
) -> None:
|
|
"""
|
|
Execute token increment script grouped by hash tag for cluster compatibility.
|
|
"""
|
|
if self.token_increment_script is None:
|
|
return
|
|
|
|
# Group operations by hash tag for Redis cluster compatibility
|
|
operation_keys = [op["key"] for op in pipeline_operations]
|
|
key_groups = self._group_keys_by_hash_tag(operation_keys)
|
|
|
|
for _hash_tag, group_keys in key_groups.items():
|
|
# Get operations for this hash tag group
|
|
group_operations = [
|
|
op for op in pipeline_operations if op["key"] in group_keys
|
|
]
|
|
|
|
keys = []
|
|
args = []
|
|
|
|
for op in group_operations:
|
|
# Convert None TTL to 0 for Lua script
|
|
ttl_value = op["ttl"] if op["ttl"] is not None else 0
|
|
|
|
verbose_proxy_logger.debug(
|
|
f"Executing TTL-preserving increment for key={op['key']}, "
|
|
f"increment={op['increment_value']}, ttl={ttl_value}"
|
|
)
|
|
keys.append(op["key"])
|
|
args.extend([op["increment_value"], ttl_value])
|
|
|
|
await self.token_increment_script(
|
|
keys=keys,
|
|
args=args,
|
|
)
|
|
|
|
async def async_increment_tokens_with_ttl_preservation(
    self,
    pipeline_operations: List["RedisPipelineIncrementOperation"],
    parent_otel_span: Optional[Span] = None,
) -> None:
    """
    Apply counter increments, preferring the TTL-preserving script.

    Behaviour:
      * empty ``pipeline_operations``: nothing happens;
      * no ``token_increment_script`` configured: the regular
        ``async_increment_cache_pipeline`` is used;
      * script configured: it is executed, and if it raises, a warning is
        logged and the regular pipeline is used as a fallback.
    """
    if not pipeline_operations:
        return

    async def _run_regular_pipeline() -> None:
        # Shared by the "no script" path and the "script failed" path.
        await self.internal_usage_cache.dual_cache.async_increment_cache_pipeline(
            increment_list=pipeline_operations,
            litellm_parent_otel_span=parent_otel_span,
        )

    if self.token_increment_script is None:
        verbose_proxy_logger.debug(
            "TTL preservation script not available, using regular pipeline"
        )
        await _run_regular_pipeline()
        return

    try:
        await self._execute_token_increment_script(pipeline_operations)

        verbose_proxy_logger.debug(
            f"Successfully executed TTL-preserving increment for {len(pipeline_operations)} keys"
        )
    except Exception as e:
        verbose_proxy_logger.warning(
            f"TTL preservation failed, falling back to regular pipeline: {str(e)}"
        )
        await _run_regular_pipeline()
|
|
|
|
def get_rate_limit_type(self) -> Literal["output", "input", "total"]:
    """
    Return the token-counting mode from ``general_settings``.

    Reads ``general_settings["token_rate_limit_type"]``; a missing value or
    anything other than ``"output"``, ``"input"`` or ``"total"`` yields
    ``"total"``.
    """
    from litellm.proxy.proxy_server import general_settings

    configured = general_settings.get("token_rate_limit_type", "total")
    if configured in ["output", "input", "total"]:
        return configured
    return "total"  # unrecognised setting: fall back to the default
|
|
|
|
@staticmethod
|
|
def _stash_value_in_metadata_channels(
|
|
data: Dict[str, Any],
|
|
key: str,
|
|
value: Any,
|
|
) -> None:
|
|
for channel in ("metadata", "litellm_metadata"):
|
|
existing = data.get(channel)
|
|
if isinstance(existing, dict):
|
|
existing[key] = value
|
|
elif channel == "metadata":
|
|
# ``litellm_metadata`` is owned by the router; don't conjure
|
|
# it here.
|
|
data[channel] = {key: value}
|
|
|
|
@classmethod
def _stash_reservation_in_data(
    cls,
    data: Dict[str, Any],
    estimated_tokens: int,
    reserved_model: Optional[str],
    reserved_scopes: Optional[List[Tuple[str, str]]] = None,
) -> None:
    """
    Record a TPM reservation in the request's metadata channels.

    Always stores ``estimated_tokens`` under ``TPM_RESERVED_TOKENS_KEY``.
    ``reserved_model`` and ``reserved_scopes`` are stored (under
    ``TPM_RESERVED_MODEL_KEY`` / ``TPM_RESERVED_SCOPES_KEY``) only when
    truthy. Scopes are written as a list of ``[key, value]`` lists rather
    than tuples so the payload survives JSON-based metadata transports.
    """
    # Build every payload before writing anything.
    to_stash: List[Tuple[str, Any]] = [(TPM_RESERVED_TOKENS_KEY, estimated_tokens)]
    if reserved_model:
        to_stash.append((TPM_RESERVED_MODEL_KEY, reserved_model))
    if reserved_scopes:
        to_stash.append(
            (
                TPM_RESERVED_SCOPES_KEY,
                [[scope_key, scope_value] for scope_key, scope_value in reserved_scopes],
            )
        )

    for stash_key, payload in to_stash:
        cls._stash_value_in_metadata_channels(
            data=data, key=stash_key, value=payload
        )
|
|
|
|
@staticmethod
|
|
def _lookup_stashed_value(
|
|
kwargs: Any,
|
|
standard_logging_metadata: Optional[Dict[str, Any]],
|
|
key: str,
|
|
) -> Any:
|
|
"""
|
|
Resolve a stashed value from any metadata channel the request data
|
|
can flow through to a callback. Top-level ``kwargs`` is not checked
|
|
because stash keys must never live there.
|
|
"""
|
|
candidate: Any = None
|
|
if isinstance(kwargs, dict):
|
|
for channel in ("metadata", "litellm_metadata"):
|
|
channel_dict = kwargs.get(channel)
|
|
if isinstance(channel_dict, dict) and key in channel_dict:
|
|
candidate = channel_dict.get(key)
|
|
if candidate is not None:
|
|
return candidate
|
|
litellm_params = kwargs.get("litellm_params")
|
|
if isinstance(litellm_params, dict):
|
|
lp_metadata = litellm_params.get("metadata")
|
|
if isinstance(lp_metadata, dict):
|
|
candidate = lp_metadata.get(key)
|
|
if candidate is None and isinstance(standard_logging_metadata, dict):
|
|
candidate = standard_logging_metadata.get(key)
|
|
return candidate
|
|
|
|
@classmethod
def _get_reserved_tokens_from_kwargs(
    cls,
    kwargs: Any,
    standard_logging_metadata: Optional[Dict[str, Any]] = None,
) -> int:
    """
    Return the reserved token count stashed under
    ``TPM_RESERVED_TOKENS_KEY``.

    Falsy or missing values give 0, as do values ``int()`` rejects with
    ``TypeError`` or ``ValueError``.
    """
    stashed = cls._lookup_stashed_value(
        kwargs, standard_logging_metadata, TPM_RESERVED_TOKENS_KEY
    )
    try:
        return int(stashed) if stashed else 0
    except (TypeError, ValueError):
        return 0
|
|
|
|
@classmethod
def _get_reserved_model_from_kwargs(
    cls,
    kwargs: Any,
    standard_logging_metadata: Optional[Dict[str, Any]] = None,
) -> Optional[str]:
    """
    Return the model name stashed under ``TPM_RESERVED_MODEL_KEY``.

    The stashed name is the model the pre-call reservation was made for;
    callers use it so reconciliation targets the same model-scoped counter
    even if a different ``model_group`` is present later.

    Returns ``None`` unless the stashed value is a non-empty string.
    """
    stashed = cls._lookup_stashed_value(
        kwargs, standard_logging_metadata, TPM_RESERVED_MODEL_KEY
    )
    if isinstance(stashed, str) and stashed:
        return stashed
    return None
|
|
|
|
@classmethod
def _get_reserved_scopes_from_kwargs(
    cls,
    kwargs: Any,
    standard_logging_metadata: Optional[Dict[str, Any]] = None,
) -> Set[Tuple[str, str]]:
    """
    Return the ``(scope_key, scope_value)`` pairs stashed under
    ``TPM_RESERVED_SCOPES_KEY``.

    These are the scopes the pre-call reservation was recorded against.
    Reconciliation uses the set to apply only the delta to reserved scopes
    and the full usage to the others.

    The stashed value must be a list; entries that are not two-element
    lists/tuples of strings are dropped. Anything else yields an empty set.
    """
    stashed = cls._lookup_stashed_value(
        kwargs, standard_logging_metadata, TPM_RESERVED_SCOPES_KEY
    )
    if not isinstance(stashed, list):
        return set()

    def _is_string_pair(entry: Any) -> bool:
        # Accept both lists (the JSON round-trip form) and tuples.
        return (
            isinstance(entry, (list, tuple))
            and len(entry) == 2
            and isinstance(entry[0], str)
            and isinstance(entry[1], str)
        )

    return {(entry[0], entry[1]) for entry in stashed if _is_string_pair(entry)}
|
|
|
|
@classmethod
def _is_reservation_released(
    cls,
    kwargs: Any,
    standard_logging_metadata: Optional[Dict[str, Any]] = None,
) -> bool:
    """
    Report whether the released flag (``TPM_RESERVATION_RELEASED_KEY``) is
    set to a truthy value in any metadata channel searched by
    ``_lookup_stashed_value``.
    """
    released_flag = cls._lookup_stashed_value(
        kwargs, standard_logging_metadata, TPM_RESERVATION_RELEASED_KEY
    )
    return True if released_flag else False
|
|
|
|
@staticmethod
def _mark_reservation_released(data: Any) -> None:
    """
    Set ``TPM_RESERVATION_RELEASED_KEY = True`` in every metadata dict
    reachable from ``data``.

    Dicts written to, when present:
      * ``data["metadata"]`` and ``data["litellm_metadata"]``
      * ``data["litellm_params"]["metadata"]``
      * ``data["standard_logging_object"]["metadata"]``

    ``data`` may be either the request data dict or callback kwargs; the flag
    is stamped in all of these places so that whichever failure callback runs
    later can see it. Non-dict ``data`` and non-dict intermediate values are
    ignored; no dict is created here.
    """
    if not isinstance(data, dict):
        return

    def _stamp(container: Any) -> None:
        if isinstance(container, dict):
            container[TPM_RESERVATION_RELEASED_KEY] = True

    for channel_name in ("metadata", "litellm_metadata"):
        _stamp(data.get(channel_name))

    litellm_params = data.get("litellm_params")
    if isinstance(litellm_params, dict):
        _stamp(litellm_params.get("metadata"))

    logging_payload = data.get("standard_logging_object")
    if isinstance(logging_payload, dict):
        _stamp(logging_payload.get("metadata"))
|
|
|
|
def _collect_tpm_scope_targets(
|
|
self,
|
|
standard_logging_metadata: Dict[str, Any],
|
|
kwargs: Any,
|
|
model_group: Optional[str],
|
|
) -> List[Tuple[str, str]]:
|
|
"""
|
|
Enumerate every (scope_key, scope_value) pair that *might* carry a
|
|
TPM counter for this request — independent of whether each scope had
|
|
a configured TPM limit at pre-call. Reservation awareness happens at
|
|
the emitter; this helper just lists the candidate scopes so callers
|
|
can split reserved-vs-unreserved.
|
|
"""
|
|
user_api_key = standard_logging_metadata.get("user_api_key_hash")
|
|
user_api_key_user_id = standard_logging_metadata.get("user_api_key_user_id")
|
|
user_api_key_team_id = standard_logging_metadata.get("user_api_key_team_id")
|
|
user_api_key_organization_id = standard_logging_metadata.get(
|
|
"user_api_key_org_id"
|
|
)
|
|
user_api_key_project_id = standard_logging_metadata.get(
|
|
"user_api_key_project_id"
|
|
)
|
|
user_api_key_end_user_id = (
|
|
kwargs.get("user") if isinstance(kwargs, dict) else None
|
|
) or standard_logging_metadata.get("user_api_key_end_user_id")
|
|
agent_id = standard_logging_metadata.get("agent_id")
|
|
session_id = standard_logging_metadata.get(
|
|
"session_id"
|
|
) or standard_logging_metadata.get("trace_id")
|
|
|
|
targets: List[Tuple[str, str]] = []
|
|
if user_api_key:
|
|
targets.append(("api_key", user_api_key))
|
|
if user_api_key_user_id:
|
|
targets.append(("user", user_api_key_user_id))
|
|
if user_api_key_team_id:
|
|
targets.append(("team", user_api_key_team_id))
|
|
if user_api_key_team_id and user_api_key_user_id:
|
|
targets.append(
|
|
("team_member", f"{user_api_key_team_id}:{user_api_key_user_id}")
|
|
)
|
|
if user_api_key_end_user_id:
|
|
targets.append(("end_user", user_api_key_end_user_id))
|
|
if user_api_key_organization_id:
|
|
targets.append(("organization", user_api_key_organization_id))
|
|
if model_group:
|
|
if user_api_key:
|
|
targets.append(("model_per_key", f"{user_api_key}:{model_group}"))
|
|
if user_api_key_team_id:
|
|
targets.append(
|
|
("model_per_team", f"{user_api_key_team_id}:{model_group}")
|
|
)
|
|
if user_api_key_organization_id:
|
|
targets.append(
|
|
(
|
|
"model_per_organization",
|
|
f"{user_api_key_organization_id}:{model_group}",
|
|
)
|
|
)
|
|
if user_api_key_project_id:
|
|
targets.append(
|
|
(
|
|
"model_per_project",
|
|
f"{user_api_key_project_id}:{model_group}",
|
|
)
|
|
)
|
|
if agent_id:
|
|
targets.append(("agent", agent_id))
|
|
if session_id:
|
|
targets.append(("agent_session", f"{agent_id}:{session_id}"))
|
|
return targets
|
|
|
|
def _build_reservation_aware_tpm_ops(
    self,
    targets: List[Tuple[str, str]],
    reserved_scopes: Set[Tuple[str, str]],
    actual_tokens: int,
    reserved_tokens: int,
) -> List[RedisPipelineIncrementOperation]:
    """
    Build token-counter increments that account for a prior reservation.

    For each ``(scope_key, scope_value)`` in ``targets``:
      * scope in ``reserved_scopes``: increment by
        ``actual_tokens - reserved_tokens`` (the counter already holds the
        reservation);
      * otherwise: increment by ``actual_tokens``.

    Zero increments are omitted. Passing ``actual_tokens=0`` therefore gives
    a pure refund: ``-reserved_tokens`` on reserved scopes, nothing on the
    rest.

    Returns:
        One operation per non-zero increment, keyed by the scope's
        ``"tokens"`` rate-limit key, with ``ttl=self.window_size``.
    """
    operations: List[RedisPipelineIncrementOperation] = []
    for scope_key, scope_value in targets:
        was_reserved = (scope_key, scope_value) in reserved_scopes
        increment = (
            actual_tokens - reserved_tokens if was_reserved else actual_tokens
        )
        if increment != 0:
            operations.append(
                RedisPipelineIncrementOperation(
                    key=self.create_rate_limit_keys(scope_key, scope_value, "tokens"),
                    increment_value=increment,
                    ttl=self.window_size,
                )
            )
    return operations
|
|
|
|
def _build_success_event_pipeline_operations(
    self,
    kwargs: Any,
    response_obj: Any,
    rate_limit_type: Literal["output", "input", "total"],
) -> List[RedisPipelineIncrementOperation]:
    """
    Build the counter updates to apply after a successful call.

    The result contains:
      * a ``-1`` on the key's ``max_parallel_requests`` counter, when the
        key hash is known;
      * token-counter updates for every candidate scope, produced by
        ``_build_reservation_aware_tpm_ops`` so reserved scopes receive
        ``actual - reserved`` and all others receive the full actual usage.

    Actual usage is read from ``response_obj.usage`` only for
    ``ModelResponse`` / ``BaseLiteLLMOpenAIResponseObject`` instances;
    otherwise it is 0.
    """
    from litellm.proxy.common_utils.callback_utils import (
        get_model_group_from_litellm_kwargs,
    )

    # Identity fields are read from standard_logging_object.metadata.
    logging_payload = kwargs.get("standard_logging_object") or {}
    logging_metadata = logging_payload.get("metadata") or {}

    key_hash = logging_metadata.get("user_api_key_hash")
    model_group = get_model_group_from_litellm_kwargs(kwargs)

    # Actual token usage of the response.
    total_tokens = 0
    if isinstance(response_obj, (ModelResponse, BaseLiteLLMOpenAIResponseObject)):
        total_tokens = self._get_total_tokens_from_usage(
            usage=getattr(response_obj, "usage", None),
            rate_limit_type=rate_limit_type,
        )

    reserved_tokens = self._get_reserved_tokens_from_kwargs(
        kwargs=kwargs,
        standard_logging_metadata=logging_metadata,
    )
    reserved_model = self._get_reserved_model_from_kwargs(
        kwargs=kwargs,
        standard_logging_metadata=logging_metadata,
    )
    reserved_scopes = self._get_reserved_scopes_from_kwargs(
        kwargs=kwargs,
        standard_logging_metadata=logging_metadata,
    )

    # The stashed model wins so the model-scoped counters updated here are
    # the ones the reservation was made for; without a reservation the
    # model group from the kwargs is used.
    reconcile_model = reserved_model or model_group

    operations: List[RedisPipelineIncrementOperation] = []

    # The parallel-request counter is per API key and is always released.
    if key_hash:
        parallel_counter_key = self.create_rate_limit_keys(
            key="api_key",
            value=key_hash,
            rate_limit_type="max_parallel_requests",
        )
        operations.append(
            RedisPipelineIncrementOperation(
                key=parallel_counter_key,
                increment_value=-1,
                ttl=self.window_size,
            )
        )

    # TPM reconciliation. With no reservation, reserved_tokens is 0 and
    # reserved_scopes is empty, so every scope is charged the full usage.
    targets = self._collect_tpm_scope_targets(
        standard_logging_metadata=logging_metadata,
        kwargs=kwargs,
        model_group=reconcile_model,
    )
    if reserved_tokens > 0 and total_tokens < reserved_tokens:
        verbose_proxy_logger.debug(
            f"Releasing unused TPM budget on success: "
            f"reserved={reserved_tokens}, actual={total_tokens}, "
            f"release={reserved_tokens - total_tokens}"
        )
    operations.extend(
        self._build_reservation_aware_tpm_ops(
            targets=targets,
            reserved_scopes=reserved_scopes,
            actual_tokens=total_tokens,
            reserved_tokens=reserved_tokens,
        )
    )

    return operations
|
|
|
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
    """
    Success callback: apply the post-call counter updates.

    Builds the operations with ``_build_success_event_pipeline_operations``
    and applies them through ``async_increment_tokens_with_ttl_preservation``.
    Errors raised while building or applying are logged and swallowed;
    ``start_time`` and ``end_time`` are not used.
    """
    from litellm.litellm_core_utils.core_helpers import (
        _get_parent_otel_span_from_kwargs,
    )

    # Resolved before the try block, so failures here propagate to the caller.
    token_mode = self.get_rate_limit_type()
    otel_span: Union[Span, None] = _get_parent_otel_span_from_kwargs(kwargs)

    try:
        verbose_proxy_logger.debug(
            "INSIDE parallel request limiter ASYNC SUCCESS LOGGING"
        )

        increments = self._build_success_event_pipeline_operations(
            kwargs=kwargs,
            response_obj=response_obj,
            rate_limit_type=token_mode,
        )
        if increments:
            await self.async_increment_tokens_with_ttl_preservation(
                pipeline_operations=increments,
                parent_otel_span=otel_span,
            )
    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error in rate limit success event: {str(e)}"
        )
|
|
|
|
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
    """
    Failure callback: release what the pre-call hook took.

    * Always decrements the key's ``max_parallel_requests`` counter when the
      key hash is known.
    * Refunds the TPM reservation, but only on the scopes recorded as
      reserved, and only if the released flag is not already set.
    * After applying the updates, stamps the released flag on ``kwargs`` if
      a refund was included.

    All errors are logged and swallowed. ``response_obj``, ``start_time``
    and ``end_time`` are not used.
    """
    from litellm.litellm_core_utils.core_helpers import (
        _get_parent_otel_span_from_kwargs,
    )

    try:
        otel_span: Union[Span, None] = _get_parent_otel_span_from_kwargs(kwargs)
        logging_payload = kwargs.get("standard_logging_object") or {}
        logging_metadata = logging_payload.get("metadata") or {}
        key_hash = logging_metadata.get("user_api_key_hash")

        increments: List[RedisPipelineIncrementOperation] = []

        if key_hash:
            parallel_counter_key = self.create_rate_limit_keys(
                key="api_key",
                value=key_hash,
                rate_limit_type="max_parallel_requests",
            )
            increments.append(
                RedisPipelineIncrementOperation(
                    key=parallel_counter_key,
                    increment_value=-1,
                    ttl=self.window_size,
                )
            )

        # If the released flag is already set (per the original comment, by
        # async_post_call_failure_hook), the reservation must not be refunded
        # a second time. The parallel-request decrement above is independent
        # of that.
        if self._is_reservation_released(
            kwargs=kwargs,
            standard_logging_metadata=logging_metadata,
        ):
            reserved_tokens = 0
        else:
            reserved_tokens = self._get_reserved_tokens_from_kwargs(
                kwargs=kwargs,
                standard_logging_metadata=logging_metadata,
            )

        if reserved_tokens > 0:
            verbose_proxy_logger.debug(
                f"Releasing reserved TPM tokens on failure: {reserved_tokens}"
            )
            # The reserved set doubles as the target list, so only scopes
            # that were actually reserved are touched; with actual_tokens=0
            # each of them receives -reserved_tokens.
            reserved_scopes = self._get_reserved_scopes_from_kwargs(
                kwargs=kwargs,
                standard_logging_metadata=logging_metadata,
            )
            increments.extend(
                self._build_reservation_aware_tpm_ops(
                    targets=list(reserved_scopes),
                    reserved_scopes=reserved_scopes,
                    actual_tokens=0,
                    reserved_tokens=reserved_tokens,
                )
            )

        if increments:
            # NOTE(review): the success path applies its updates through
            # async_increment_tokens_with_ttl_preservation, while this path
            # calls the regular pipeline directly. Confirm that is intended.
            await self.internal_usage_cache.dual_cache.async_increment_cache_pipeline(
                increment_list=increments,
                litellm_parent_otel_span=otel_span,
            )
            if reserved_tokens > 0:
                self._mark_reservation_released(kwargs)
    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error in rate limit failure event: {str(e)}"
        )
|
|
|
|
async def async_post_call_success_hook(
    self, data: dict, user_api_key_dict: UserAPIKeyAuth, response
):
    """
    Copy rate-limit statuses into the response's additional headers.

    Reads ``data["litellm_proxy_rate_limit_response"]``. For each status it
    sets ``x-ratelimit-<descriptor_key>-remaining-<type>`` and
    ``x-ratelimit-<descriptor_key>-limit-<type>`` inside
    ``response._hidden_params["additional_headers"]``.

    Nothing happens when no rate-limit response is stored, or when
    ``_hidden_params`` is absent or is neither a dict nor a pydantic model.
    A pydantic ``_hidden_params`` is replaced by its ``model_dump()`` dict.
    Errors are logged and swallowed; the method returns ``None``.
    """
    try:
        from pydantic import BaseModel

        rate_limit_response = cast(
            Optional[RateLimitResponse],
            data.get("litellm_proxy_rate_limit_response", None),
        )
        if rate_limit_response is None:
            return

        hidden_params = getattr(response, "_hidden_params", None)
        if hidden_params is None or not isinstance(
            hidden_params, (BaseModel, dict)
        ):
            return

        if isinstance(hidden_params, BaseModel):
            hidden_params = hidden_params.model_dump()

        headers = hidden_params.get("additional_headers", {}) or {}

        for status in rate_limit_response["statuses"]:
            prefix = f"x-ratelimit-{status['descriptor_key']}"
            headers[f"{prefix}-remaining-{status['rate_limit_type']}"] = status[
                "limit_remaining"
            ]
            headers[f"{prefix}-limit-{status['rate_limit_type']}"] = status[
                "current_limit"
            ]

        # Assign a fresh dict rather than mutating the attribute in place.
        setattr(
            response,
            "_hidden_params",
            {**hidden_params, "additional_headers": headers},
        )
    except Exception as e:
        verbose_proxy_logger.exception(
            f"Error in rate limit post-call hook: {str(e)}"
        )
|
|
|
|
# Proxy-level failure hook: gives back the TPM tokens reserved for this
# request by issuing negative increments on the stashed descriptors'
# "tokens" counters. Returns None; exceptions are logged, not raised.
async def async_post_call_failure_hook(
|
|
self,
|
|
request_data: dict,
|
|
original_exception: Exception,
|
|
user_api_key_dict: UserAPIKeyAuth,
|
|
traceback_str: Optional[str] = None,
|
|
) -> None:
|
|
"""
|
|
Release any TPM reservation when the request is rejected after the
|
|
pre-call hook reserved tokens but before the LLM call ran (e.g. a
|
|
downstream guardrail/auth hook raised). Without this, those
|
|
reservations are stranded — async_log_failure_event is a litellm
|
|
completion-level callback and never fires for proxy-side rejections.
|
|
|
|
|
|
Idempotent via TPM_RESERVATION_RELEASED_KEY: if both this hook and
|
|
async_log_failure_event end up running in the same flow, only the
|
|
first refund applies.
|
|
"""
|
|
try:
|
|
# Skip when this request's reservation is already flagged as released.
if self._is_reservation_released(kwargs=request_data):
|
|
return
|
|
reserved_tokens = self._get_reserved_tokens_from_kwargs(kwargs=request_data)
|
|
# No positive reservation recorded for this request: nothing to refund.
if reserved_tokens <= 0:
|
|
return
|
|
|
|
# Refund directly against the descriptors we reserved against —
|
|
# the pre-call hook stashes them in the request-data metadata
|
|
# channels before success/failure callbacks run.
|
|
stashed = self._lookup_stashed_value(
|
|
kwargs=request_data,
|
|
standard_logging_metadata=None,
|
|
key=RATE_LIMIT_DESCRIPTORS_KEY,
|
|
)
|
|
# Anything other than a list is treated as "no descriptors stashed".
descriptors: List[RateLimitDescriptor] = (
|
|
stashed if isinstance(stashed, list) else []
|
|
)
|
|
ops: List[RedisPipelineIncrementOperation] = []
|
|
for descriptor in descriptors:
|
|
# Only descriptors whose rate_limit defines tokens_per_unit get a refund op.
rate_limit = descriptor.get("rate_limit") or {}
|
|
if rate_limit.get("tokens_per_unit") is None:
|
|
continue
|
|
# Negative increment of the reserved amount on the descriptor's "tokens" key.
ops.append(
|
|
RedisPipelineIncrementOperation(
|
|
key=self.create_rate_limit_keys(
|
|
descriptor["key"],
|
|
descriptor["value"],
|
|
"tokens",
|
|
),
|
|
increment_value=-reserved_tokens,
|
|
ttl=self.window_size,
|
|
)
|
|
)
|
|
# The pipeline is only called when at least one refund op was built.
if ops:
|
|
verbose_proxy_logger.debug(
|
|
f"Releasing reserved TPM tokens on proxy-level "
|
|
f"rejection: {reserved_tokens}"
|
|
)
|
|
await self.internal_usage_cache.dual_cache.async_increment_cache_pipeline(
|
|
increment_list=ops,
|
|
litellm_parent_otel_span=user_api_key_dict.parent_otel_span,
|
|
)
|
|
# NOTE(review): indentation is lost in this copy, so it is unclear whether
# this call runs only when `ops` is non-empty or unconditionally after the
# refund attempt — confirm against the properly indented file.
self._mark_reservation_released(request_data)
|
|
# Any failure above is logged with its traceback instead of propagating.
except Exception as e:
|
|
verbose_proxy_logger.exception(
|
|
f"Error releasing TPM reservation on post-call failure: {e}"
|
|
)
|
|
return None
|