Litellm jwt mapping virtualkeys (#28510)

* restore an explicit no-match policy

* fix(jwt): fix AUTO_REGISTER sentinel bypass, race condition, and inline import comment

- AUTO_REGISTER now evicts stale __NO_MAPPING__ sentinel instead of silently
  returning None when cached under a prior fallback_team_mapping config
- Race condition in _auto_register_jwt_mapping: catch P2002 unique-constraint
  violation on concurrent creates, fetch the winning mapping, proceed cleanly
- Added comment on inline generate_key_helper_fn import explaining the circular
  dependency (key_management_endpoints imports user_api_key_auth at line 51)
- 3 new tests: stale sentinel eviction, race condition winner fallback, and the
  existing auto_register happy path

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(jwt): cache __NO_MAPPING__ sentinel before raising 403 in REJECT mode

REJECT mode was raising HTTPException immediately on a DB miss without writing
the __NO_MAPPING__ sentinel, causing every subsequent rejected request to
re-query the DB. Write the sentinel first so repeated rejections are served
from cache within virtual_key_mapping_cache_ttl.

Adds test asserting DB is not hit on the second reject after a cache-warm miss.
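
The cache-then-raise ordering can be sketched as follows. This is a minimal illustration, not the proxy's code: a plain dict stands in for the cache (so the TTL is omitted), and `Forbidden`, `resolve_or_reject`, and `fake_db_lookup` are hypothetical names introduced here.

```python
from typing import Callable, Dict, Optional

NO_MAPPING = "__NO_MAPPING__"  # sentinel value named in the commit message


class Forbidden(Exception):
    """Hypothetical stand-in for an HTTP 403 response."""


def resolve_or_reject(
    cache: Dict[str, str],
    cache_key: str,
    lookup_mapping: Callable[[], Optional[str]],
) -> str:
    cached = cache.get(cache_key)
    if cached == NO_MAPPING:
        # Known miss: reject from cache without touching the DB.
        raise Forbidden(cache_key)
    if cached is not None:
        return cached

    token_hash = lookup_mapping()
    if token_hash is None:
        # Write the sentinel first, then raise, so the next reject is a cache hit.
        cache[cache_key] = NO_MAPPING
        raise Forbidden(cache_key)

    cache[cache_key] = token_hash
    return token_hash


db_calls = []


def fake_db_lookup() -> Optional[str]:
    db_calls.append(1)
    return None  # simulate a DB miss


cache: Dict[str, str] = {}
for _ in range(2):
    try:
        resolve_or_reject(cache, "jwt_key_mapping:email:a@example.com", fake_db_lookup)
    except Forbidden:
        pass
```

Both calls are rejected, but only the first one reaches the simulated DB.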

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(jwt): enforce no-match policy when prisma_client is None

The early `if prisma_client is None: return None` guard ran before the
no-match policy check, silently bypassing REJECT and AUTO_REGISTER — every
JWT client fell through to team auth regardless of configuration.

Fix: treat prisma_client=None as a definitive DB miss and fall through to the
same policy block as a real miss. REJECT now raises 403, AUTO_REGISTER raises
500 with a clear message (can't create keys without a DB), FALLBACK_TEAM_MAPPING
returns None unchanged.

Adds three tests: REJECT/403 with no DB, FALLBACK returns None with no DB,
AUTO_REGISTER/500 with no DB.
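
As a rough summary of the no-DB outcomes listed above, the sketch below maps each behavior to the status it raises, or to None for fall-through. `Behavior` and `no_db_outcome` are illustrative names, not part of the proxy.

```python
import enum
from typing import Optional


class Behavior(str, enum.Enum):
    # Values mirror the three configured behaviors described above.
    FALLBACK_TEAM_MAPPING = "fallback_team_mapping"
    REJECT = "reject"
    AUTO_REGISTER = "auto_register"


def no_db_outcome(behavior: Behavior) -> Optional[int]:
    """Return the HTTP status raised when no DB is configured, or None for fall-through."""
    if behavior is Behavior.REJECT:
        return 403  # a missing DB is treated as a definitive miss
    if behavior is Behavior.AUTO_REGISTER:
        return 500  # keys cannot be created without a DB
    return None  # FALLBACK_TEAM_MAPPING: continue to team-based JWT auth
```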

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

* fix(jwt): consistent AUTO_REGISTER on cached sentinel; clean up race orphans

Addresses Greptile review on PR #25570 cherry-pick.

1. Inconsistent AUTO_REGISTER when __NO_MAPPING__ sentinel is cached:
   The cached-sentinel branch silently returned None when prisma_client was
   None, while the fresh path raised HTTP 500 under the same config. Same
   request, different access-control outcome depending on cache state. Both
   paths now raise the same 500.

2. Orphaned virtual keys from race-condition losers:
   On unique-constraint conflict, generate_key_helper_fn had already persisted
   an unrestricted virtual key in LiteLLM_VerificationToken with the cleartext
   in request memory. Under sustained concurrency these accumulated
   indefinitely. The loser now deletes its orphan before falling back to the
   winner's mapping; failure to delete is logged but does not fail the request.

Also corrects a latent FK bug surfaced while fixing #2: the mapping row was
storing the plaintext key in LiteLLM_JWTKeyMapping.token, but that column FKs
to the hashed LiteLLM_VerificationToken.token — now hashed at the call site.
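
A small sketch of hashing at the call site. It assumes a SHA-256 hex digest purely for illustration; the proxy's real `hash_token` may differ, and `hash_token_sketch` and the example key are made up.

```python
import hashlib


def hash_token_sketch(token: str) -> str:
    # Assumption: SHA-256 hex digest. The proxy's own hash_token may differ.
    return hashlib.sha256(token.encode()).hexdigest()


plaintext_key = "sk-example-not-a-real-key"  # hypothetical key for illustration

# The mapping row stores the hash, so it matches the hashed primary key of the
# key table instead of the plaintext that only the caller should ever see.
mapping_row = {
    "jwt_claim_name": "email",
    "jwt_claim_value": "a@example.com",
    "token": hash_token_sketch(plaintext_key),
}
```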

Tests:
- updated test_auto_register_creates_key_and_mapping to assert the hashed
  token is stored, not the plaintext
- updated test_auto_register_race_condition_unique_conflict to assert the
  orphan is deleted with the correct hashed token
- added test_auto_register_raises_500_when_sentinel_cached_and_no_db
- added test_auto_register_race_conflict_tolerates_delete_failure

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(jwt): close REJECT bypass when JWT omits the configured claim field

A JWT presented without the configured `virtual_key_claim_field` previously
returned None at the `claim_value is None` guard before the
`unregistered_jwt_client_behavior` check ran. A caller who knows the configured
claim-field name could bypass REJECT by simply omitting that field and falling
through to team-based JWT auth.

Apply the no-match policy on a missing claim:
  - REJECT          → 403
  - AUTO_REGISTER   → 403 (no stable identity to map; refuse rather than
                     create a sentinel-keyed record)
  - FALLBACK_TEAM_MAPPING → return None (unchanged, backward-compatible)

Adds three tests covering each branch of the missing-claim path.
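
The missing-claim branches above can be sketched like this. `Forbidden` and `claim_or_policy` are hypothetical names, and the behavior is passed as a plain string for brevity.

```python
from typing import Any, Dict, Optional


class Forbidden(Exception):
    """Hypothetical stand-in for an HTTP 403 response."""


def claim_or_policy(claims: Dict[str, Any], field: str, behavior: str) -> Optional[str]:
    value = claims.get(field)
    if value is not None:
        return str(value)
    # Missing claim: apply the no-match policy instead of returning early,
    # so omitting the field cannot be used to skip REJECT.
    if behavior in ("reject", "auto_register"):
        raise Forbidden(f"required claim '{field}' is missing")
    return None  # fallback_team_mapping: continue to team-based JWT auth


outcomes = {}
for behavior in ("reject", "auto_register", "fallback_team_mapping"):
    try:
        outcomes[behavior] = claim_or_policy({"sub": "abc"}, "email", behavior)
    except Forbidden:
        outcomes[behavior] = "403"
```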

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(jwt): AUTO_REGISTER inherits team_id so keys are bounded by team limits

Auto-registered virtual keys were created with no team, model, route, rate, or
budget constraints — broader access than the standard team-based JWT auth path
the same client would have taken. Under AUTO_REGISTER, resolve the team_id
from the JWT (via the operator-configured team_id_jwt_field / team_id_default)
and stamp it on the new key. Downstream auth then applies the team's
budget/models/tpm/rpm/allowed_routes via the existing virtual-key flow.

Policy when team_id_jwt_field is configured:
  - JWT carries team claim → stamp resolved team_id
  - JWT lacks claim + team_id_default set → stamp default
  - JWT lacks claim + no default → 403 (refuse to create an unbounded key)

When neither team_id_jwt_field nor team_id_default is configured, the
operator has explicitly opted out of team-based limits — the auto-created
key has no team_id (matches what team-auth would do in the same config).

Adds 4 tests covering each branch.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(jwt): make AUTO_REGISTER functional in prod; raise on missing winner

Two correctness fixes flagged by Greptile on the AUTO_REGISTER path:

1. generate_key_helper_fn was called without table_name="key". Without that,
   the helper falls into the user-upsert branch (table_name in (None, "user"))
   and tries to insert into LiteLLM_UserTable with user_id=None, which hits
   the NOT NULL @id constraint. AUTO_REGISTER would never have succeeded in
   production. Now passes table_name="key" explicitly, matching the
   /key/generate caller.

2. When the race loser refetches the winner's mapping and gets None (winner
   row concurrently deleted), the previous code returned None — and the
   caller in _resolve_jwt_to_virtual_key then fell through to less-
   restrictive team-based JWT auth, silently bypassing the configured
   AUTO_REGISTER policy. Now raises HTTP 503 so the caller retries against
   a stable state rather than getting unintended fallback access.

Adds one test for the 503 winner-vanishes path.
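
A single-threaded sketch of the race-loser handling, using an in-memory store in place of the database. `FakeStore`, `UniqueViolation`, `ServiceUnavailable`, and `register` are hypothetical names. The vanished-winner branch cannot fire in this simulation and is included only to mirror the behavior described above.

```python
from typing import Dict, Set


class UniqueViolation(Exception):
    """Hypothetical stand-in for a unique-constraint error."""


class ServiceUnavailable(Exception):
    """Hypothetical stand-in for an HTTP 503 response."""


class FakeStore:
    def __init__(self) -> None:
        self.keys: Set[str] = set()
        self.mappings: Dict[str, str] = {}

    def create_mapping(self, claim: str, token_hash: str) -> None:
        if claim in self.mappings:
            raise UniqueViolation(claim)
        self.mappings[claim] = token_hash


def register(store: FakeStore, claim: str, token_hash: str) -> str:
    store.keys.add(token_hash)  # the key is persisted before the mapping is attempted
    try:
        store.create_mapping(claim, token_hash)
        return token_hash
    except UniqueViolation:
        store.keys.discard(token_hash)  # the loser deletes its orphaned key
        winner = store.mappings.get(claim)
        if winner is None:
            # Winner's mapping vanished: fail loudly rather than fall through.
            raise ServiceUnavailable("winner's mapping was removed; retry")
        return winner


store = FakeStore()
first = register(store, "a@example.com", "hash-1")
second = register(store, "a@example.com", "hash-2")  # loses the simulated race
```

Both calls resolve to the winner's hash, and only the winner's key remains in the store.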

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(jwt): defer AUTO_REGISTER until JWT policy is enforced by auth_builder

Closes the JWT policy bypass on the AUTO_REGISTER path flagged by veria-ai.

Before: when unregistered_jwt_client_behavior=auto_register and the JWT's
claim was unmapped, _resolve_jwt_to_virtual_key validated the JWT signature
and then immediately created a virtual key + mapping. JWTAuthManager.auth_builder
never ran for the first request (the new key short-circuited the team-auth
path), and every subsequent request hit the cached mapping — so custom_validate,
RBAC, scope_mappings, and user_allowed_email_domain were never enforced for
auto-registered clients.

After: _resolve_jwt_to_virtual_key returns a _PendingAutoRegister signal
instead of creating the key. The caller in _user_api_key_auth_builder runs
JWTAuthManager.auth_builder, then — only on a validated, policy-passing
result — calls _auto_register_jwt_mapping with the team_id / user_id from
that result. The created key inherits team + user limits from the validated
identity, and future cache hits load that already-policy-checked key.

Also drops the interim _resolve_inherited_team_id helper that pulled team_id
from raw JWT claims — same bypass risk; team_id now comes exclusively from
auth_builder.
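
The deferred flow can be sketched as a two-step pattern: resolve returns a signal, and the key is created only after the policy step succeeds. Everything here (`Pending`, `PolicyError`, `resolve`, `authenticate`, the key string format) is a hypothetical simplification, not the proxy's API.

```python
from typing import Callable, List, NamedTuple, Optional, Union


class Pending(NamedTuple):
    """Signal: the claim is unmapped and registration must wait for policy checks."""

    claim_value: str


class PolicyError(Exception):
    """Hypothetical stand-in for a JWT policy rejection."""


def resolve(claim_value: str, mapping: Optional[str]) -> Union[str, Pending]:
    if mapping is not None:
        return mapping
    return Pending(claim_value)  # no key is created at this point


def authenticate(
    claim_value: str,
    mapping: Optional[str],
    run_policy: Callable[[], str],
    created: List[str],
) -> str:
    result = resolve(claim_value, mapping)
    if isinstance(result, Pending):
        team_id = run_policy()  # raises PolicyError before any key exists
        key = f"key:{result.claim_value}:{team_id}"
        created.append(key)
        return key
    return result


def failing_policy() -> str:
    raise PolicyError("email domain not allowed")


def passing_policy() -> str:
    return "team-a"


created: List[str] = []
try:
    authenticate("client-1", None, failing_policy, created)
except PolicyError:
    pass
rejected_count = len(created)  # still 0: the policy failure created nothing
issued = authenticate("client-1", None, passing_policy, created)
```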

Tests:
  - Rewrote two existing tests to assert _resolve_jwt_to_virtual_key returns
    _PendingAutoRegister (no key created yet) for both the fresh-DB-miss
    and stale-sentinel branches
  - Added a contract test that _auto_register_jwt_mapping stamps the
    validated team_id/user_id onto generate_key_helper_fn
  - Removed four stale team-binding tests that exercised the prior
    raw-claim helper

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* Update user_api_key_auth.py

* fix(jwt): cache proxy-admin AUTO_REGISTER path to avoid repeated DB lookups

Cache-miss regression introduced by the deferred-auto-register refactor:
when a JWT under AUTO_REGISTER resolved to a proxy admin, the is_proxy_admin
early-return in _user_api_key_auth_builder ran *before* the pending
auto-register cache-write block. Result: no cache entry, so every
subsequent proxy-admin request re-queried get_jwt_key_mapping_object
indefinitely.

Fix: write a __JWT_PROXY_ADMIN__ sentinel to user_api_key_cache before the
early return when a pending auto-register existed. _resolve_jwt_to_virtual_key
treats that sentinel as "skip mapping, fall through to auth_builder", so
future requests from the same JWT identity hit the cache instead of the DB.
auth_builder still runs full JWT policy on every request — only the
mapping DB lookup is short-circuited.

Adds one test asserting the sentinel cache-hit returns None without
hitting prisma_client.db.litellm_jwtkeymapping.find_first.
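
A sketch of the sentinel short-circuit, with a dict as the cache and a counting callback in place of the DB query. `resolve_mapping` and `counting_lookup` are hypothetical names; only the sentinel string comes from the text above.

```python
from typing import Callable, Dict, Optional

PROXY_ADMIN_SENTINEL = "__JWT_PROXY_ADMIN__"  # value named in the commit message


def resolve_mapping(
    cache: Dict[str, str],
    cache_key: str,
    db_lookup: Callable[[], Optional[str]],
) -> Optional[str]:
    cached = cache.get(cache_key)
    if cached == PROXY_ADMIN_SENTINEL:
        # Known proxy admin: skip the mapping lookup. The caller still runs
        # full JWT policy on this request.
        return None
    if cached is not None:
        return cached
    return db_lookup()


lookups = []


def counting_lookup() -> Optional[str]:
    lookups.append(1)
    return None  # no mapping exists for a proxy admin


cache: Dict[str, str] = {}
key = "jwt_key_mapping:sub:admin-1"

resolve_mapping(cache, key, counting_lookup)  # first request: DB miss
cache[key] = PROXY_ADMIN_SENTINEL  # caller learned this identity is a proxy admin
second = resolve_mapping(cache, key, counting_lookup)  # served from cache
```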

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>

* fix(proxy): stamp org context on JWT auto-registered keys

AUTO_REGISTER keys were created with team_id and user_id only, so org budget checks were skipped after switching to the key-scoped path.

Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Shivam Rawat 2026-06-04 19:00:36 -07:00 committed by GitHub
parent 41e90a6ada
commit 3bd89f209e
4 changed files with 1304 additions and 15 deletions


@ -4448,6 +4448,25 @@ class JWTRoutingOverride(BaseModel):
}
class UnregisteredJWTClientBehavior(str, enum.Enum):
    """
    Controls what happens when `virtual_key_claim_field` is configured but the
    JWT claim value has no registered mapping in `litellm_jwtkeymapping`.

    - fallback_team_mapping: Fall through to standard team-based JWT auth (default,
      backward-compatible).
    - reject: Immediately return HTTP 403. Use this when every valid JWT client
      must have a pre-registered virtual key; unknown callers are denied.
    - auto_register: Automatically create a new virtual key and mapping on first
      encounter. The new key has no budget/model restrictions; admins can tighten
      it later via /jwt_client/update.
    """

    FALLBACK_TEAM_MAPPING = "fallback_team_mapping"
    REJECT = "reject"
    AUTO_REGISTER = "auto_register"
class JWTIssuerConfig(BaseModel):
"""
Issuer-bound JWT validation configuration.
@ -4614,6 +4633,15 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
default=300,
description="TTL (seconds) for caching JWT-to-virtual-key mapping lookups.",
)
unregistered_jwt_client_behavior: UnregisteredJWTClientBehavior = Field(
default=UnregisteredJWTClientBehavior.FALLBACK_TEAM_MAPPING,
description=(
"What to do when virtual_key_claim_field is set but the JWT claim value "
"has no registered mapping. 'fallback_team_mapping' (default): fall through "
"to team-based JWT auth. 'reject': return HTTP 403. "
"'auto_register': auto-create a virtual key and mapping on first encounter."
),
)
routing_overrides: Optional[List[JWTRoutingOverride]] = Field(
default=None,
description="Optional claim-based routing overrides for JWT-shaped tokens. Matching rules route requests to oauth2 before default JWT flow.",
@ -4633,6 +4661,13 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
# ``s3://`` / ``gcs://`` when this is None.
config_file_path = kwargs.pop("config_file_path", None)
# Backward-compat: jwt_client_id_field was renamed to virtual_key_claim_field
if "jwt_client_id_field" in kwargs:
if "virtual_key_claim_field" not in kwargs:
kwargs["virtual_key_claim_field"] = kwargs.pop("jwt_client_id_field")
else:
kwargs.pop("jwt_client_id_field")
# get the attribute names for this Pydantic model
allowed_keys = LiteLLM_JWTAuth.__annotations__.keys()


@ -12,7 +12,7 @@ import fnmatch
import re
import secrets
from datetime import datetime, timezone
-from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
+from typing import Any, Dict, Iterator, NamedTuple, List, Optional, Tuple, Union, cast
import fastapi
from fastapi import HTTPException, Request, WebSocket, status
@ -607,6 +607,169 @@ async def check_api_key_for_custom_headers_or_pass_through_endpoints(
return api_key
# Cache sentinel written when a JWT under AUTO_REGISTER resolved to a proxy
# admin via auth_builder. Proxy admins don't need a mapped virtual key (they
# have full access via auth_builder anyway), but without a cache entry every
# subsequent request from the same JWT identity would re-query the DB for a
# non-existent mapping. Sentinel tells _resolve_jwt_to_virtual_key to skip
# the lookup and return None (caller proceeds to auth_builder).
_JWT_PROXY_ADMIN_SENTINEL = "__JWT_PROXY_ADMIN__"
class _PendingAutoRegister(NamedTuple):
    """
    Signal returned by ``_resolve_jwt_to_virtual_key`` when the JWT's claim is
    unmapped and ``unregistered_jwt_client_behavior`` is AUTO_REGISTER.

    The caller MUST run standard ``JWTAuthManager.auth_builder`` to apply RBAC,
    scope mappings, ``custom_validate``, and ``user_allowed_email_domain``
    policy BEFORE calling ``_auto_register_jwt_mapping`` with the validated
    ``team_id`` / ``user_id`` from the auth_builder result. Auto-registering
    purely on a signature-valid JWT (the old behavior) bypassed every JWT
    policy beyond signature verification.
    """

    claim_field: str
    claim_value: str
    cache_key: str
async def _auto_register_jwt_mapping(
virtual_key_claim_field: str,
claim_value: str,
jwt_handler: JWTHandler,
prisma_client: PrismaClient,
user_api_key_cache: UserApiKeyCache,
parent_otel_span: Optional[Span],
proxy_logging_obj: ProxyLogging,
cache_key: str,
team_id: Optional[str] = None,
user_id: Optional[str] = None,
org_id: Optional[str] = None,
end_user_id: Optional[str] = None,
) -> Optional[UserAPIKeyAuth]:
"""
Auto-register: create a new virtual key + mapping for an unrecognised JWT
claim value. ``team_id`` and ``user_id`` must come from a successful
``JWTAuthManager.auth_builder`` run; they encode the JWT identity AFTER
RBAC/scope/custom_validate/email-domain policy has been enforced. The key
is stamped with those values so the cached future-request path inherits
the same team/user/org limits the auth_builder path would have applied.
Race safety: if two concurrent requests both reach here simultaneously (both
saw no mapping in the DB), one will win the unique-constraint race on
litellm_jwtkeymapping. The loser catches the conflict, deletes its orphaned
key, fetches the winner's mapping, and proceeds — no error surfaced.
"""
# Inline import required: key_management_endpoints imports user_api_key_auth
# (line 51) so a module-level import here would create a circular dependency.
from litellm.proxy.management_endpoints.key_management_endpoints import (
generate_key_helper_fn,
)
# ``table_name="key"`` is required: without it, generate_key_helper_fn
# falls into the user-upsert branch (`table_name is None or "user"`) and
# attempts to insert into LiteLLM_UserTable with user_id=None, which fails
# the NOT NULL @id constraint. Every successful key-creation caller (e.g.
# /key/generate) passes table_name="key" explicitly.
key_data = await generate_key_helper_fn(
request_type="key",
table_name="key",
team_id=team_id,
user_id=user_id,
organization_id=org_id,
metadata={
"auto_registered": True,
"jwt_claim_field": virtual_key_claim_field,
"jwt_claim_value": claim_value,
},
)
# generate_key_helper_fn returns the plaintext key in "token"; the persisted
# row in LiteLLM_VerificationToken uses its hash, so hash here to get the FK
# value referenced by LiteLLM_JWTKeyMapping.token.
token_hash = hash_token(key_data["token"])
try:
await prisma_client.db.litellm_jwtkeymapping.create(
data={
"jwt_claim_name": virtual_key_claim_field,
"jwt_claim_value": claim_value,
"token": token_hash,
"created_by": "auto_register",
"updated_by": "auto_register",
}
)
except Exception as e:
error_str = str(e).lower()
if "unique" in error_str or "p2002" in error_str:
# A concurrent request won the race. The key generate_key_helper_fn
# just persisted to LiteLLM_VerificationToken is orphaned — nothing
# maps to it, but it's a fully valid unrestricted API key sitting in
# the DB and the cleartext is in memory on this request. Delete it
# so orphans don't accumulate under sustained concurrency.
verbose_proxy_logger.debug(
"JWT Key Mapping (auto_register): unique conflict on create — "
"deleting orphaned virtual key and fetching winner's mapping for %s='%s'.",
virtual_key_claim_field,
claim_value,
)
try:
await prisma_client.db.litellm_verificationtoken.delete(
where={"token": token_hash}
)
except Exception as delete_err:
# Don't fail the request if cleanup fails — the orphan is
# unmapped and inert. Log so an operator can prune it later.
verbose_proxy_logger.warning(
"JWT Key Mapping (auto_register): failed to delete orphaned key after race: %s",
delete_err,
)
token_hash = await get_jwt_key_mapping_object(
jwt_claim_name=virtual_key_claim_field,
jwt_claim_value=claim_value,
prisma_client=prisma_client,
)
if token_hash is None:
# The winner's mapping vanished between the unique-constraint
# conflict and our re-fetch (concurrent delete). Returning None
# here would silently fall through to team-based JWT auth —
# a less-restrictive path than the operator configured. Raise
# 503 so the caller retries against a stable state instead.
raise HTTPException(
status_code=503,
detail=(
"JWT Key Mapping: AUTO_REGISTER race resolution failed — "
"winner's mapping was concurrently removed. Retry the request."
),
)
else:
raise
await user_api_key_cache.async_set_cache(
key=cache_key,
value=token_hash,
ttl=jwt_handler.litellm_jwtauth.virtual_key_mapping_cache_ttl,
)
verbose_proxy_logger.info(
"JWT Key Mapping (auto_register): created new virtual key for %s='%s'.",
virtual_key_claim_field,
claim_value,
)
auto_registered_key = await get_key_object(
hashed_token=token_hash,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
if auto_registered_key is not None:
auto_registered_key.org_id = org_id
auto_registered_key.end_user_id = end_user_id
return auto_registered_key
async def _resolve_jwt_to_virtual_key(
jwt_claims: dict,
jwt_handler: JWTHandler,
@ -614,7 +777,22 @@ async def _resolve_jwt_to_virtual_key(
user_api_key_cache: UserApiKeyCache,
parent_otel_span: Optional[Span],
proxy_logging_obj: ProxyLogging,
-) -> Optional[UserAPIKeyAuth]:
+) -> Union[Optional[UserAPIKeyAuth], "_PendingAutoRegister"]:
"""
Returns:
- ``UserAPIKeyAuth``: a resolved virtual key (cache hit or DB hit). The
caller may use this directly; JWT policy has been enforced previously
(at key-creation time or, for cached results, before caching).
- ``_PendingAutoRegister``: claim is unmapped and behavior is AUTO_REGISTER.
The caller MUST run ``JWTAuthManager.auth_builder`` to enforce JWT
policy (RBAC, scope, custom_validate, email-domain), then invoke
``_auto_register_jwt_mapping`` with the validated team_id/user_id.
- ``None``: claim is unmapped and behavior is FALLBACK_TEAM_MAPPING.
The caller falls through to standard team-based JWT auth (which itself
enforces full JWT policy via auth_builder).
- Raises HTTPException: REJECT policy hit, missing claim under
REJECT/AUTO_REGISTER, or other policy violations.
"""
virtual_key_claim_field = jwt_handler.litellm_jwtauth.virtual_key_claim_field
if virtual_key_claim_field is None:
return None
@ -629,12 +807,61 @@ async def _resolve_jwt_to_virtual_key(
verbose_proxy_logger.debug(
f"JWT Key Mapping: Claim field '{virtual_key_claim_field}' not found in JWT claims."
)
# A missing claim is an unmapped client — apply the no-match policy
# rather than returning early. Otherwise a caller can bypass REJECT
# simply by presenting a JWT that omits the configured field. For
# AUTO_REGISTER there is no stable identity to map without a claim
# value, so we deny rather than create a sentinel-keyed record.
behavior = jwt_handler.litellm_jwtauth.unregistered_jwt_client_behavior
if behavior in (
UnregisteredJWTClientBehavior.REJECT,
UnregisteredJWTClientBehavior.AUTO_REGISTER,
):
raise HTTPException(
status_code=403,
detail=(
f"JWT Key Mapping: Required claim '{virtual_key_claim_field}' "
"is missing from the JWT. Access denied."
),
)
return None
cache_key = f"jwt_key_mapping:{virtual_key_claim_field}:{claim_value}"
cached_mapping = await user_api_key_cache.async_get_cache(cache_key)
if cached_mapping == _JWT_PROXY_ADMIN_SENTINEL:
# Previously resolved to a proxy admin via auth_builder; skip the
# mapping lookup and let the caller re-run auth_builder. Avoids a
# repeated DB hit on every proxy-admin request under AUTO_REGISTER.
return None
if cached_mapping == "__NO_MAPPING__":
behavior = jwt_handler.litellm_jwtauth.unregistered_jwt_client_behavior
if behavior == UnregisteredJWTClientBehavior.REJECT:
raise HTTPException(
status_code=403,
detail=f"JWT Key Mapping: No registered mapping for {virtual_key_claim_field}='{claim_value}'. Access denied.",
)
if behavior == UnregisteredJWTClientBehavior.AUTO_REGISTER:
# Stale sentinel written under a prior fallback_team_mapping config —
# evict it and defer auto-register to after auth_builder runs. Raise
# the same 500 as the fresh-path AUTO_REGISTER branch when there is
# no DB, so behavior is consistent regardless of whether the cache
# happens to hold the sentinel.
if prisma_client is None:
raise HTTPException(
status_code=500,
detail=(
"JWT Key Mapping: AUTO_REGISTER requires a database connection. "
"Configure a database or change unregistered_jwt_client_behavior."
),
)
await user_api_key_cache.async_delete_cache(cache_key)
return _PendingAutoRegister(
claim_field=virtual_key_claim_field,
claim_value=str(claim_value),
cache_key=cache_key,
)
return None
elif cached_mapping is not None:
return await get_key_object(
@ -645,14 +872,15 @@ async def _resolve_jwt_to_virtual_key(
proxy_logging_obj=proxy_logging_obj,
)
-    if prisma_client is None:
-        return None
-    token_hash = await get_jwt_key_mapping_object(
-        jwt_claim_name=virtual_key_claim_field,
-        jwt_claim_value=str(claim_value),
-        prisma_client=prisma_client,
-    )
+    # Resolve the mapping from DB, or treat prisma_client=None as a definitive
+    # miss (no DB → no mapping can exist → apply no-match policy below).
+    token_hash: Optional[str] = None
+    if prisma_client is not None:
+        token_hash = await get_jwt_key_mapping_object(
+            jwt_claim_name=virtual_key_claim_field,
+            jwt_claim_value=str(claim_value),
+            prisma_client=prisma_client,
+        )
if token_hash is not None:
await user_api_key_cache.async_set_cache(
@ -667,13 +895,50 @@ async def _resolve_jwt_to_virtual_key(
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
else:
# No mapping found (DB miss or no DB) — apply no-match policy.
behavior = jwt_handler.litellm_jwtauth.unregistered_jwt_client_behavior
if behavior == UnregisteredJWTClientBehavior.REJECT:
# Cache the miss before raising so repeated rejections are served from
# cache and don't re-query the DB on every request.
await user_api_key_cache.async_set_cache(
key=cache_key,
value="__NO_MAPPING__",
ttl=jwt_handler.litellm_jwtauth.virtual_key_mapping_cache_ttl,
)
-        return None
raise HTTPException(
status_code=403,
detail=f"JWT Key Mapping: No registered mapping for {virtual_key_claim_field}='{claim_value}'. Access denied.",
)
if behavior == UnregisteredJWTClientBehavior.AUTO_REGISTER:
if prisma_client is None:
raise HTTPException(
status_code=500,
detail=(
"JWT Key Mapping: AUTO_REGISTER requires a database connection. "
"Configure a database or change unregistered_jwt_client_behavior."
),
)
# Defer: caller runs JWTAuthManager.auth_builder to enforce RBAC, scope,
# custom_validate, and email-domain policy, then auto-registers using
# the validated identity. Auto-registering here on a signature-only
# JWT would bypass every JWT policy beyond signature verification.
return _PendingAutoRegister(
claim_field=virtual_key_claim_field,
claim_value=str(claim_value),
cache_key=cache_key,
)
# FALLBACK_TEAM_MAPPING (default): cache the miss and return None so the
# caller falls through to standard team-based JWT auth.
await user_api_key_cache.async_set_cache(
key=cache_key,
value="__NO_MAPPING__",
ttl=jwt_handler.litellm_jwtauth.virtual_key_mapping_cache_ttl,
)
return None
def _ensure_parent_otel_span_on_request_state(request: Request) -> None:
@ -893,6 +1158,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
# Try JWT-to-Virtual-Key mapping first to avoid
# unnecessary DB queries in auth_builder
do_standard_jwt_auth = True
pending_auto_register: Optional[_PendingAutoRegister] = None
if jwt_handler.litellm_jwtauth.virtual_key_claim_field is not None:
# Decode JWT to get claims without running full auth_builder
jwt_claims: Optional[dict]
@ -901,7 +1167,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
else:
jwt_claims = await jwt_handler.auth_jwt(token=api_key)
-valid_token = await _resolve_jwt_to_virtual_key(
+resolve_result = await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
@ -909,11 +1175,19 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
-if valid_token is not None:
+if isinstance(resolve_result, UserAPIKeyAuth):
+    valid_token = resolve_result
api_key = valid_token.token or ""
valid_token.jwt_claims = jwt_claims
do_standard_jwt_auth = False
# Fall through to virtual key checks
elif isinstance(resolve_result, _PendingAutoRegister):
# Run full JWT policy (RBAC, scope, custom_validate,
# email-domain) via auth_builder, then create the key
# from the validated identity below.
pending_auto_register = resolve_result
# else: None → FALLBACK_TEAM_MAPPING, falls through to
# standard JWT auth_builder below
if do_standard_jwt_auth:
with tracer.trace("litellm.proxy.auth.jwt_auth_builder"):
@ -946,6 +1220,19 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
jwt_claims = result.get("jwt_claims", None)
if is_proxy_admin:
# Proxy admins authenticate via auth_builder (full
# access), not via a mapped virtual key. If
# AUTO_REGISTER was pending, cache a sentinel so
# future requests from this JWT identity skip the
# DB mapping lookup in _resolve_jwt_to_virtual_key.
# Without this, every proxy-admin request under
# AUTO_REGISTER re-hits get_jwt_key_mapping_object.
if pending_auto_register is not None:
await user_api_key_cache.async_set_cache(
key=pending_auto_register.cache_key,
value=_JWT_PROXY_ADMIN_SENTINEL,
ttl=jwt_handler.litellm_jwtauth.virtual_key_mapping_cache_ttl,
)
return UserAPIKeyAuth(
api_key=None,
user_role=LitellmUserRoles.PROXY_ADMIN,
@ -1032,6 +1319,32 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
else None
)
# AUTO_REGISTER deferred from _resolve_jwt_to_virtual_key.
# JWT policy (RBAC, scope, custom_validate, email-domain)
# has now been enforced by auth_builder above. Create the
# mapping + virtual key from the *validated* identity, then
# replace valid_token with the new key so downstream checks
# use the key-scoped path.
if pending_auto_register is not None and prisma_client is not None:
auto_registered = await _auto_register_jwt_mapping(
virtual_key_claim_field=pending_auto_register.claim_field,
claim_value=pending_auto_register.claim_value,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
cache_key=pending_auto_register.cache_key,
team_id=team_id,
user_id=user_id,
org_id=org_id,
end_user_id=end_user_id,
)
if auto_registered is not None:
auto_registered.jwt_claims = jwt_claims
valid_token = auto_registered
api_key = valid_token.token or ""
# Check if model has zero cost - if so, skip all budget checks
model = _get_model_from_request_context(
request_data=request_data,


@ -27,7 +27,6 @@ from litellm.proxy.management_endpoints.jwt_key_mapping_endpoints import (
from litellm.caching.caching import DualCache
from fastapi import HTTPException
# ──────────────────────────────────────────────
# Tests: _resolve_jwt_to_virtual_key
# ──────────────────────────────────────────────
@ -454,3 +453,856 @@ async def test_create_success_returns_response_without_token():
assert isinstance(result, JWTKeyMappingResponse)
assert "token" not in result.model_fields
assert result.jwt_claim_name == "email"
# ──────────────────────────────────────────────
# Tests: unregistered_jwt_client_behavior
# ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_reject_behavior_raises_403_on_no_mapping():
"""
When unregistered_jwt_client_behavior='reject' and no mapping exists,
_resolve_jwt_to_virtual_key must raise HTTP 403.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="email",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT,
)
jwt_claims = {"email": "unknown@example.com"}
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
user_api_key_cache = DualCache()
with patch(
"litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
):
with pytest.raises(HTTPException) as exc_info:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info.value.status_code == 403
assert "unknown@example.com" in exc_info.value.detail
@pytest.mark.asyncio
async def test_reject_behavior_caches_sentinel_after_db_miss():
"""
On a fresh DB miss with REJECT, the __NO_MAPPING__ sentinel must be written
to cache so that subsequent rejected requests are served from cache and do
not re-query the DB.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="email",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT,
virtual_key_mapping_cache_ttl=300,
)
jwt_claims = {"email": "unknown@example.com"}
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
user_api_key_cache = DualCache()
with patch(
"litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
):
# First call — DB miss, should raise 403 and write sentinel
with pytest.raises(HTTPException) as exc_info:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info.value.status_code == 403
# Sentinel must now be in cache
cached = await user_api_key_cache.async_get_cache(
"jwt_key_mapping:email:unknown@example.com"
)
assert cached == "__NO_MAPPING__"
# Second call — must raise 403 from cache, no additional DB hit
prisma_client.db.litellm_jwtkeymapping.find_first.reset_mock()
with pytest.raises(HTTPException) as exc_info2:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info2.value.status_code == 403
prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called()
@pytest.mark.asyncio
async def test_reject_behavior_raises_403_on_cached_no_mapping():
"""
When the negative-cache sentinel __NO_MAPPING__ is present and behavior is
'reject', the function must also raise HTTP 403 (not return None silently).
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="email",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT,
)
jwt_claims = {"email": "unknown@example.com"}
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
# Pre-populate the negative cache so the DB is not hit
user_api_key_cache = DualCache()
cache_key = "jwt_key_mapping:email:unknown@example.com"
await user_api_key_cache.async_set_cache(cache_key, "__NO_MAPPING__")
with patch(
"litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
):
with pytest.raises(HTTPException) as exc_info:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info.value.status_code == 403
# DB must NOT have been hit (sentinel served from cache)
prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called()
@pytest.mark.asyncio
async def test_auto_register_returns_pending_signal_without_creating_key():
"""
Security: when unregistered_jwt_client_behavior='auto_register' and no
mapping exists, _resolve_jwt_to_virtual_key must NOT create the key yet.
It returns a _PendingAutoRegister signal so the caller can run
JWTAuthManager.auth_builder (enforcing RBAC, scope mappings,
custom_validate, user_allowed_email_domain) FIRST. Creating the key here
would bypass every JWT policy beyond signature verification.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
from litellm.proxy.auth.user_api_key_auth import _PendingAutoRegister
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
virtual_key_mapping_cache_ttl=300,
)
jwt_claims = {"sub": "new-user-42"}
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock()
user_api_key_cache = DualCache()
with patch(
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
new_callable=AsyncMock,
) as mock_gen_key:
result = await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert isinstance(result, _PendingAutoRegister)
assert result.claim_field == "sub"
assert result.claim_value == "new-user-42"
assert result.cache_key == "jwt_key_mapping:sub:new-user-42"
# CRITICAL: no key was created — that must wait until after auth_builder
mock_gen_key.assert_not_called()
prisma_client.db.litellm_jwtkeymapping.create.assert_not_called()
@pytest.mark.asyncio
async def test_auto_register_creates_key_and_mapping_when_helper_invoked():
"""
When the caller invokes _auto_register_jwt_mapping directly (after
auth_builder validation), the helper creates the key + mapping row and
returns a UserAPIKeyAuth. The mapping row stores the hashed token (FK to
LiteLLM_VerificationToken), not the plaintext key.
"""
from litellm.proxy._types import hash_token
from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
virtual_key_mapping_cache_ttl=300,
)
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock()
user_api_key_cache = DualCache()
plaintext_key = "sk-auto-key"
expected_hash = hash_token(plaintext_key)
mock_key_obj = UserAPIKeyAuth(token=expected_hash, team_id="validated-team")
with (
patch(
"litellm.proxy.auth.user_api_key_auth.get_key_object",
new_callable=AsyncMock,
) as mock_get_key,
patch(
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
new_callable=AsyncMock,
) as mock_gen_key,
):
mock_gen_key.return_value = {"token": plaintext_key, "key": plaintext_key}
mock_get_key.return_value = mock_key_obj
result = await _auto_register_jwt_mapping(
virtual_key_claim_field="sub",
claim_value="new-user-42",
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
cache_key="jwt_key_mapping:sub:new-user-42",
team_id="validated-team",
user_id="validated-user",
)
assert result == mock_key_obj
# generate_key_helper_fn was passed table_name="key" (not user-upsert path)
# and the validated team_id + user_id from auth_builder
assert mock_gen_key.call_args.kwargs["table_name"] == "key"
assert mock_gen_key.call_args.kwargs["team_id"] == "validated-team"
assert mock_gen_key.call_args.kwargs["user_id"] == "validated-user"
# Mapping row was created with the hashed token (FK target)
call_data = prisma_client.db.litellm_jwtkeymapping.create.call_args[1]["data"]
assert call_data["jwt_claim_name"] == "sub"
assert call_data["jwt_claim_value"] == "new-user-42"
assert call_data["token"] == expected_hash
cached = await user_api_key_cache.async_get_cache("jwt_key_mapping:sub:new-user-42")
assert cached == expected_hash
@pytest.mark.asyncio
async def test_auto_register_returns_pending_signal_on_stale_no_mapping_sentinel():
"""
If the cache holds a stale __NO_MAPPING__ sentinel (written under a prior
fallback_team_mapping config) and behavior is now AUTO_REGISTER, the
resolver must evict the sentinel and return _PendingAutoRegister (so the
caller can run auth_builder before creating the key). It must neither
silently return None nor create the key on the spot.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
from litellm.proxy.auth.user_api_key_auth import _PendingAutoRegister
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="email",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
virtual_key_mapping_cache_ttl=300,
)
jwt_claims = {"email": "alice@corp.com"}
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock()
user_api_key_cache = DualCache()
await user_api_key_cache.async_set_cache(
"jwt_key_mapping:email:alice@corp.com", "__NO_MAPPING__"
)
with patch(
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
new_callable=AsyncMock,
) as mock_gen_key:
result = await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert isinstance(result, _PendingAutoRegister)
# Stale sentinel must be evicted so the deferred auto-register actually
# runs after auth_builder validates the JWT
cached_after = await user_api_key_cache.async_get_cache(
"jwt_key_mapping:email:alice@corp.com"
)
assert cached_after is None
mock_gen_key.assert_not_called()
prisma_client.db.litellm_jwtkeymapping.create.assert_not_called()
@pytest.mark.asyncio
async def test_auto_register_race_condition_unique_conflict():
"""
If two concurrent requests both call _auto_register_jwt_mapping and the
second hits a unique-constraint violation on create, it must:
1) delete the orphaned virtual key it just created (so orphans don't
accumulate in LiteLLM_VerificationToken under sustained concurrency),
2) fall back to the winner's mapping,
3) not surface an error.
"""
from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping
from litellm.proxy._types import UnregisteredJWTClientBehavior, hash_token
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
virtual_key_mapping_cache_ttl=300,
)
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock(
side_effect=Exception("Unique constraint failed (P2002)")
)
prisma_client.db.litellm_verificationtoken.delete = AsyncMock()
# Simulate the winner's mapping already in DB after the conflict
winner_mapping = MagicMock()
winner_mapping.token = "winner_token_hash"
winner_mapping.is_active = True
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(
return_value=winner_mapping
)
user_api_key_cache = DualCache()
loser_plaintext = "sk-loser"
loser_hash = hash_token(loser_plaintext)
mock_key_obj = UserAPIKeyAuth(token="winner_token_hash", team_id=None)
with (
patch(
"litellm.proxy.auth.user_api_key_auth.get_key_object",
new_callable=AsyncMock,
) as mock_get_key,
patch(
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
new_callable=AsyncMock,
return_value={"token": loser_plaintext, "key": loser_plaintext},
),
):
mock_get_key.return_value = mock_key_obj
result = await _auto_register_jwt_mapping(
virtual_key_claim_field="sub",
claim_value="user-42",
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
cache_key="jwt_key_mapping:sub:user-42",
)
assert result == mock_key_obj
# The orphaned loser key must be deleted from LiteLLM_VerificationToken
prisma_client.db.litellm_verificationtoken.delete.assert_called_once_with(
where={"token": loser_hash}
)
# Cache should hold the winner's token, not the loser's
cached = await user_api_key_cache.async_get_cache("jwt_key_mapping:sub:user-42")
assert cached == "winner_token_hash"
mock_get_key.assert_called_once_with(
hashed_token="winner_token_hash",
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
# ──────────────────────────────────────────────
# Tests: prisma_client=None does not bypass no-match policy
# ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_reject_behavior_enforced_when_prisma_client_is_none():
"""
When prisma_client is None and behavior is REJECT, a 403 must be raised;
the request must not silently fall through to team auth.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="email",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT,
)
jwt_claims = {"email": "unknown@example.com"}
user_api_key_cache = DualCache()
with pytest.raises(HTTPException) as exc_info:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=None, # no DB
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info.value.status_code == 403
assert "unknown@example.com" in exc_info.value.detail
@pytest.mark.asyncio
async def test_reject_raises_403_when_claim_field_missing_from_jwt():
"""
Security: a JWT that omits the configured virtual_key_claim_field must NOT
bypass the REJECT policy. Previously the early `if claim_value is None:
return None` branch ran before the policy check, letting a caller who knows
the configured claim-field name silently fall through to team-based auth.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT,
)
# JWT does NOT contain "sub"
jwt_claims = {"email": "user@example.com"}
with pytest.raises(HTTPException) as exc_info:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=MagicMock(),
user_api_key_cache=DualCache(),
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info.value.status_code == 403
assert "'sub'" in exc_info.value.detail
assert "missing from the JWT" in exc_info.value.detail
@pytest.mark.asyncio
async def test_auto_register_raises_403_when_claim_field_missing_from_jwt():
"""
AUTO_REGISTER cannot create a mapping without a stable identity. When the
configured claim field is missing from the JWT, return 403 rather than
silently falling through (which would bypass the unregistered-client policy)
or creating a sentinel-keyed record.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
)
jwt_claims = {"email": "user@example.com"}
with pytest.raises(HTTPException) as exc_info:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=MagicMock(),
user_api_key_cache=DualCache(),
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info.value.status_code == 403
assert "missing from the JWT" in exc_info.value.detail
@pytest.mark.asyncio
async def test_fallback_team_mapping_returns_none_when_claim_field_missing_from_jwt():
"""
Under FALLBACK_TEAM_MAPPING (the default, backward-compatible mode), a JWT
without the configured claim field must still fall through to team-based
JWT auth rather than raise. This preserves the pre-existing contract.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.FALLBACK_TEAM_MAPPING,
)
jwt_claims = {"email": "user@example.com"}
result = await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=MagicMock(),
user_api_key_cache=DualCache(),
parent_otel_span=None,
proxy_logging_obj=None,
)
assert result is None
@pytest.mark.asyncio
async def test_fallback_team_mapping_returns_none_when_prisma_client_is_none():
"""
When prisma_client is None and behavior is FALLBACK_TEAM_MAPPING, the
function must return None (fall through to team auth) rather than raise.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="email",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.FALLBACK_TEAM_MAPPING,
)
jwt_claims = {"email": "anyone@example.com"}
result = await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=None,
user_api_key_cache=DualCache(),
parent_otel_span=None,
proxy_logging_obj=None,
)
assert result is None
@pytest.mark.asyncio
async def test_auto_register_raises_500_when_prisma_client_is_none():
"""
AUTO_REGISTER without a DB connection must raise HTTP 500 with a clear
message: it cannot create keys without a database.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
)
jwt_claims = {"sub": "new-user-42"}
with pytest.raises(HTTPException) as exc_info:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=None,
user_api_key_cache=DualCache(),
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info.value.status_code == 500
assert "AUTO_REGISTER requires a database" in exc_info.value.detail
@pytest.mark.asyncio
async def test_auto_register_raises_500_when_sentinel_cached_and_no_db():
"""
AUTO_REGISTER + cached __NO_MAPPING__ sentinel + prisma_client is None must
raise HTTP 500, matching the fresh-path behavior. Previously this path
silently returned None and let the request fall through to team auth,
creating different access-control outcomes under identical configuration.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
virtual_key_mapping_cache_ttl=300,
)
jwt_claims = {"sub": "user-42"}
user_api_key_cache = DualCache()
# Stale sentinel written under a prior fallback_team_mapping config
await user_api_key_cache.async_set_cache(
"jwt_key_mapping:sub:user-42", "__NO_MAPPING__"
)
with pytest.raises(HTTPException) as exc_info:
await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=None,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert exc_info.value.status_code == 500
assert "AUTO_REGISTER requires a database" in exc_info.value.detail
@pytest.mark.asyncio
async def test_auto_register_race_conflict_tolerates_delete_failure():
"""
If deleting the orphaned virtual key after a race-condition conflict fails
(e.g. transient DB error), the request must still succeed by returning the
winner's mapping — the orphan is unmapped and inert.
"""
from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
virtual_key_mapping_cache_ttl=300,
)
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock(
side_effect=Exception("Unique constraint failed (P2002)")
)
prisma_client.db.litellm_verificationtoken.delete = AsyncMock(
side_effect=Exception("transient DB error")
)
winner_mapping = MagicMock()
winner_mapping.token = "winner_token_hash"
winner_mapping.is_active = True
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(
return_value=winner_mapping
)
user_api_key_cache = DualCache()
mock_key_obj = UserAPIKeyAuth(token="winner_token_hash", team_id=None)
with (
patch(
"litellm.proxy.auth.user_api_key_auth.get_key_object",
new_callable=AsyncMock,
) as mock_get_key,
patch(
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
new_callable=AsyncMock,
return_value={"token": "sk-loser", "key": "sk-loser"},
),
):
mock_get_key.return_value = mock_key_obj
result = await _auto_register_jwt_mapping(
virtual_key_claim_field="sub",
claim_value="user-42",
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
cache_key="jwt_key_mapping:sub:user-42",
)
# Caller still receives the winner's mapping even when cleanup fails
assert result == mock_key_obj
prisma_client.db.litellm_verificationtoken.delete.assert_called_once()
@pytest.mark.asyncio
async def test_auto_register_raises_503_when_winner_mapping_vanishes():
"""
Race edge case: this request loses the unique-constraint race, deletes its
orphan, then refetches the winner's mapping — but the winner's row was
concurrently deleted. Previously this returned None, silently falling
through to less-restrictive team-based JWT auth (bypassing the configured
AUTO_REGISTER policy). Must now raise HTTP 503 so the caller retries
rather than getting unintended fallback access.
"""
from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
virtual_key_mapping_cache_ttl=300,
)
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock(
side_effect=Exception("Unique constraint failed (P2002)")
)
prisma_client.db.litellm_verificationtoken.delete = AsyncMock()
# Winner row no longer exists by the time we refetch
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
user_api_key_cache = DualCache()
with (
patch(
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
new_callable=AsyncMock,
return_value={"token": "sk-loser", "key": "sk-loser"},
),
pytest.raises(HTTPException) as exc_info,
):
await _auto_register_jwt_mapping(
virtual_key_claim_field="sub",
claim_value="user-42",
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
cache_key="jwt_key_mapping:sub:user-42",
)
assert exc_info.value.status_code == 503
assert "concurrently removed" in exc_info.value.detail
@pytest.mark.asyncio
async def test_proxy_admin_sentinel_skips_db_lookup_on_cache_hit():
"""
When the cache holds the proxy-admin sentinel (written after a prior
request's is_proxy_admin early-return), _resolve_jwt_to_virtual_key must
return None *without* hitting the DB. Caller proceeds to auth_builder.
Without this, every subsequent proxy-admin request under AUTO_REGISTER
would re-query get_jwt_key_mapping_object, a cache-miss regression
introduced by the deferred-auto-register refactor.
"""
from litellm.proxy._types import UnregisteredJWTClientBehavior
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
virtual_key_mapping_cache_ttl=300,
)
jwt_claims = {"sub": "admin-user"}
prisma_client = MagicMock()
# Will fail the test if accessed — proves the sentinel short-circuits DB
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(
side_effect=AssertionError("DB must not be hit when sentinel is cached")
)
user_api_key_cache = DualCache()
await user_api_key_cache.async_set_cache(
"jwt_key_mapping:sub:admin-user", "__JWT_PROXY_ADMIN__"
)
result = await _resolve_jwt_to_virtual_key(
jwt_claims=jwt_claims,
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=None,
proxy_logging_obj=None,
)
assert result is None
prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called()
# ──────────────────────────────────────────────
# Tests: AUTO_REGISTER stamps validated identity from auth_builder
# ──────────────────────────────────────────────
@pytest.mark.asyncio
async def test_auto_register_helper_stamps_validated_identity_context():
"""
The deferred-auto-register contract: _auto_register_jwt_mapping is called
with identity fields from JWTAuthManager.auth_builder's *validated*
result (after RBAC, scope mappings, custom_validate, email-domain policy).
These must be passed to generate_key_helper_fn so the created key carries
them; the cached future-request path then inherits the same team/user/org
limits the auth_builder path would have applied.
"""
from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
virtual_key_mapping_cache_ttl=300,
)
prisma_client = MagicMock()
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock()
mock_key_obj = UserAPIKeyAuth(
token="hashed", team_id="validated-team", user_id="validated-user"
)
with (
patch(
"litellm.proxy.auth.user_api_key_auth.get_key_object",
new_callable=AsyncMock,
) as mock_get_key,
patch(
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
new_callable=AsyncMock,
) as mock_gen_key,
):
mock_gen_key.return_value = {"token": "sk-newkey", "key": "sk-newkey"}
mock_get_key.return_value = mock_key_obj
result = await _auto_register_jwt_mapping(
virtual_key_claim_field="sub",
claim_value="new-user",
jwt_handler=jwt_handler,
prisma_client=prisma_client,
user_api_key_cache=DualCache(),
parent_otel_span=None,
proxy_logging_obj=None,
cache_key="jwt_key_mapping:sub:new-user",
team_id="validated-team",
user_id="validated-user",
org_id="validated-org",
end_user_id="validated-end-user",
)
assert result == mock_key_obj
assert mock_gen_key.call_args.kwargs["team_id"] == "validated-team"
assert mock_gen_key.call_args.kwargs["user_id"] == "validated-user"
assert mock_gen_key.call_args.kwargs["organization_id"] == "validated-org"
assert result.org_id == "validated-org"
assert result.end_user_id == "validated-end-user"
# ──────────────────────────────────────────────
# Tests: backward-compat alias jwt_client_id_field
# ──────────────────────────────────────────────
def test_jwt_client_id_field_alias_maps_to_virtual_key_claim_field():
"""
jwt_client_id_field (old doc name) must silently alias to virtual_key_claim_field.
"""
auth = LiteLLM_JWTAuth(jwt_client_id_field="azp")
assert auth.virtual_key_claim_field == "azp"
def test_jwt_client_id_field_does_not_raise_on_duplicate():
"""
If both jwt_client_id_field and virtual_key_claim_field are supplied,
virtual_key_claim_field takes precedence and no error is raised.
"""
auth = LiteLLM_JWTAuth(
jwt_client_id_field="old_field",
virtual_key_claim_field="new_field",
)
assert auth.virtual_key_claim_field == "new_field"
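
The REJECT-mode sentinel tests above rely on one ordering rule: a DB miss is written to the cache before the no-match policy is applied, so repeated rejections are served from cache. The following is a minimal, self-contained sketch of that rule, not the litellm implementation. `Behavior`, `Rejected`, `FakeDB`, and `resolve` are hypothetical stand-ins, a plain dict replaces `DualCache`, and the TTL is omitted; only the cache-key format and the `__NO_MAPPING__` string come from the tests.

```python
import asyncio
from enum import Enum
from typing import Dict, Optional, Tuple

NO_MAPPING = "__NO_MAPPING__"  # sentinel string used in the tests above


class Behavior(Enum):
    """Hypothetical stand-in for UnregisteredJWTClientBehavior."""
    REJECT = "reject"
    FALLBACK_TEAM_MAPPING = "fallback_team_mapping"


class Rejected(Exception):
    """Hypothetical stand-in for HTTPException(status_code=403)."""


class FakeDB:
    """In-memory lookup table that counts queries (assumption: not Prisma)."""

    def __init__(self, rows: Dict[Tuple[str, str], str]) -> None:
        self.rows = rows
        self.lookups = 0

    async def find(self, field: str, value: str) -> Optional[str]:
        self.lookups += 1
        return self.rows.get((field, value))


async def resolve(
    field: str,
    value: str,
    behavior: Behavior,
    db: FakeDB,
    cache: Dict[str, str],
) -> Optional[str]:
    cache_key = f"jwt_key_mapping:{field}:{value}"
    cached = cache.get(cache_key)
    if cached is None:
        token = await db.find(field, value)
        # Record the outcome (hit or miss) before the policy check, so a
        # rejected request still leaves the sentinel behind.
        cached = cache[cache_key] = token if token is not None else NO_MAPPING
    if cached != NO_MAPPING:
        return cached
    if behavior is Behavior.REJECT:
        raise Rejected(f"no virtual key mapping for {field}={value}")
    return None  # FALLBACK_TEAM_MAPPING: caller continues with team auth


async def demo() -> int:
    db, cache = FakeDB({}), {}
    for _ in range(2):
        try:
            await resolve("email", "unknown@example.com", Behavior.REJECT, db, cache)
        except Rejected:
            pass
    return db.lookups


if __name__ == "__main__":
    print(asyncio.run(demo()))  # prints 1: the second reject never reaches the DB
```

Writing the sentinel after the policy check instead would leave REJECT with no cached entry, which is the re-query behavior the commit message describes.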

@@ -31,12 +31,14 @@ from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.auth.auth_checks import get_key_object, _cache_key_object
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.auth.user_api_key_auth import (
_PendingAutoRegister,
_matches_routing_override,
_reserve_budget_after_common_checks,
_route_requires_auth_despite_public,
_routing_selector_matches_claim,
_run_centralized_common_checks,
_run_post_custom_auth_checks,
_user_api_key_auth_builder,
get_api_key,
user_api_key_auth,
)
@@ -1550,6 +1552,93 @@ class TestJWTOAuth2Coexistence:
assert mock_jwt_auth.call_args.kwargs["request_method"] == "POST"
assert result.user_id == "jwt-human-user"
@pytest.mark.asyncio
async def test_auto_register_passes_validated_org_context_to_generated_key(self):
jwt_token = "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ1c2VyMSJ9.signature"
general_settings = {"enable_jwt_auth": True}
user_api_key_cache = DualCache()
prisma_client = MagicMock()
jwt_handler = MagicMock()
jwt_handler.is_jwt.return_value = True
jwt_handler.auth_jwt = AsyncMock(return_value={"sub": "user1"})
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
virtual_key_claim_field="sub",
virtual_key_mapping_cache_ttl=300,
)
auto_registered_key = UserAPIKeyAuth(
token="hashed-auto-key",
team_id="validated-team",
user_id="validated-user",
org_id="validated-org",
end_user_id="validated-end-user",
)
mock_jwt_result = {
"is_proxy_admin": False,
"team_object": None,
"user_object": None,
"end_user_object": None,
"org_object": None,
"token": jwt_token,
"team_id": "validated-team",
"user_id": "validated-user",
"end_user_id": "validated-end-user",
"org_id": "validated-org",
"team_membership": None,
"jwt_claims": {"sub": "user1"},
}
mock_request = MagicMock()
mock_request.url.path = "/v1/chat/completions"
mock_request.method = "POST"
mock_request.headers = {"authorization": f"Bearer {jwt_token}"}
mock_request.query_params = {}
mock_request.state = SimpleNamespace()
with (
patch("litellm.proxy.proxy_server.general_settings", general_settings),
patch("litellm.proxy.proxy_server.premium_user", True),
patch("litellm.proxy.proxy_server.master_key", "sk-master"),
patch("litellm.proxy.proxy_server.prisma_client", prisma_client),
patch("litellm.proxy.proxy_server.user_api_key_cache", user_api_key_cache),
patch("litellm.proxy.proxy_server.proxy_logging_obj", MagicMock()),
patch("litellm.proxy.proxy_server.jwt_handler", jwt_handler),
patch(
"litellm.proxy.auth.user_api_key_auth._resolve_jwt_to_virtual_key",
new_callable=AsyncMock,
return_value=_PendingAutoRegister(
claim_field="sub",
claim_value="user1",
cache_key="jwt_key_mapping:sub:user1",
),
),
patch(
"litellm.proxy.auth.user_api_key_auth.JWTAuthManager.auth_builder",
new_callable=AsyncMock,
return_value=mock_jwt_result,
),
patch(
"litellm.proxy.auth.user_api_key_auth._auto_register_jwt_mapping",
new_callable=AsyncMock,
return_value=auto_registered_key,
) as mock_auto_register,
):
result = await _user_api_key_auth_builder(
request=mock_request,
api_key=jwt_token,
azure_api_key_header="",
anthropic_api_key_header=None,
google_ai_studio_api_key_header=None,
azure_apim_header=None,
request_data={"model": "gpt-4o-mini"},
)
mock_auto_register.assert_awaited_once()
assert mock_auto_register.call_args.kwargs["team_id"] == "validated-team"
assert mock_auto_register.call_args.kwargs["user_id"] == "validated-user"
assert mock_auto_register.call_args.kwargs["org_id"] == "validated-org"
assert mock_auto_register.call_args.kwargs["end_user_id"] == "validated-end-user"
assert result.org_id == "validated-org"
@pytest.mark.asyncio
async def test_routing_override_routes_matching_jwt_to_oauth2(self):
"""