Litellm jwt mapping virtualkeys (#28510)
* restore an explicit no-match policy * fix(jwt): fix AUTO_REGISTER sentinel bypass, race condition, and inline import comment - AUTO_REGISTER now evicts stale __NO_MAPPING__ sentinel instead of silently returning None when cached under a prior fallback_team_mapping config - Race condition in _auto_register_jwt_mapping: catch P2002 unique-constraint violation on concurrent creates, fetch the winning mapping, proceed cleanly - Added comment on inline generate_key_helper_fn import explaining the circular dependency (key_management_endpoints imports user_api_key_auth at line 51) - 3 new tests: stale sentinel eviction, race condition winner fallback, and the existing auto_register happy path Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(jwt): cache __NO_MAPPING__ sentinel before raising 403 in REJECT mode REJECT mode was raising HTTPException immediately on a DB miss without writing the __NO_MAPPING__ sentinel, causing every subsequent rejected request to re-query the DB. Write the sentinel first so repeated rejections are served from cache within virtual_key_mapping_cache_ttl. Adds test asserting DB is not hit on the second reject after a cache-warm miss. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(jwt): enforce no-match policy when prisma_client is None The early `if prisma_client is None: return None` guard ran before the no-match policy check, silently bypassing REJECT and AUTO_REGISTER — every JWT client fell through to team auth regardless of configuration. Fix: treat prisma_client=None as a definitive DB miss and fall through to the same policy block as a real miss. REJECT now raises 403, AUTO_REGISTER raises 500 with a clear message (can't create keys without a DB), FALLBACK_TEAM_MAPPING returns None unchanged. Adds three tests: REJECT/403 with no DB, FALLBACK returns None with no DB, AUTO_REGISTER/500 with no DB. Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * fix(jwt): consistent AUTO_REGISTER on cached sentinel; clean up race orphans Addresses Greptile review on PR #25570 cherry-pick. 1. Inconsistent AUTO_REGISTER when __NO_MAPPING__ sentinel is cached: The cached-sentinel branch silently returned None when prisma_client was None, while the fresh path raised HTTP 500 under the same config. Same request, different access-control outcome depending on cache state. Both paths now raise the same 500. 2. Orphaned virtual keys from race-condition losers: On unique-constraint conflict, generate_key_helper_fn had already persisted an unrestricted virtual key in LiteLLM_VerificationToken with the cleartext in request memory. Under sustained concurrency these accumulated indefinitely. The loser now deletes its orphan before falling back to the winner's mapping; failure to delete is logged but does not fail the request. Also corrects a latent FK bug surfaced while fixing #2: the mapping row was storing the plaintext key in LiteLLM_JWTKeyMapping.token, but that column FKs to the hashed LiteLLM_VerificationToken.token — now hashed at the call site. Tests: - updated test_auto_register_creates_key_and_mapping to assert the hashed token is stored, not the plaintext - updated test_auto_register_race_condition_unique_conflict to assert the orphan is deleted with the correct hashed token - added test_auto_register_raises_500_when_sentinel_cached_and_no_db - added test_auto_register_race_conflict_tolerates_delete_failure Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix(jwt): close REJECT bypass when JWT omits the configured claim field A JWT presented without the configured `virtual_key_claim_field` previously returned None at the `claim_value is None` guard before the `unregistered_jwt_client_behavior` check ran. A caller who knows the configured claim-field name could bypass REJECT by simply omitting that field and falling through to team-based JWT auth. Apply the no-match policy on a missing claim: - REJECT → 403 - AUTO_REGISTER → 403 (no stable identity to map; refuse rather than create a sentinel-keyed record) - FALLBACK_TEAM_MAPPING → return None (unchanged, backward-compatible) Adds three tests covering each branch of the missing-claim path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix(jwt): AUTO_REGISTER inherits team_id so keys are bounded by team limits Auto-registered virtual keys were created with no team, model, route, rate, or budget constraints — broader access than the standard team-based JWT auth path the same client would have taken. Under AUTO_REGISTER, resolve the team_id from the JWT (via the operator-configured team_id_jwt_field / team_id_default) and stamp it on the new key. Downstream auth then applies the team's budget/models/tpm/rpm/allowed_routes via the existing virtual-key flow. Policy when team_id_jwt_field is configured: - JWT carries team claim → stamp resolved team_id - JWT lacks claim + team_id_default set → stamp default - JWT lacks claim + no default → 403 (refuse to create an unbounded key) When neither team_id_jwt_field nor team_id_default is configured, the operator has explicitly opted out of team-based limits — the auto-created key has no team_id (matches what team-auth would do in the same config). Adds 4 tests covering each branch. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix(jwt): make AUTO_REGISTER functional in prod; raise on missing winner Two correctness fixes flagged by Greptile on the AUTO_REGISTER path: 1. generate_key_helper_fn was called without table_name="key". Without that, the helper falls into the user-upsert branch (table_name in (None, "user")) and tries to insert into LiteLLM_UserTable with user_id=None, which hits the NOT NULL @id constraint. AUTO_REGISTER would never have succeeded in production. Now passes table_name="key" explicitly, matching the /key/generate caller. 2. When the race loser refetches the winner's mapping and gets None (winner row concurrently deleted), the previous code returned None — and the caller in _resolve_jwt_to_virtual_key then fell through to less- restrictive team-based JWT auth, silently bypassing the configured AUTO_REGISTER policy. Now raises HTTP 503 so the caller retries against a stable state rather than getting unintended fallback access. Adds one test for the 503 winner-vanishes path. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix(jwt): defer AUTO_REGISTER until JWT policy is enforced by auth_builder Closes the JWT policy bypass on the AUTO_REGISTER path flagged by veria-ai. Before: when unregistered_jwt_client_behavior=auto_register and the JWT's claim was unmapped, _resolve_jwt_to_virtual_key validated the JWT signature and then immediately created a virtual key + mapping. JWTAuthManager.auth_builder never ran for the first request (the new key short-circuited the team-auth path), and every subsequent request hit the cached mapping — so custom_validate, RBAC, scope_mappings, and user_allowed_email_domain were never enforced for auto-registered clients. After: _resolve_jwt_to_virtual_key returns a _PendingAutoRegister signal instead of creating the key. The caller in _user_api_key_auth_builder runs JWTAuthManager.auth_builder, then — only on a validated, policy-passing result — calls _auto_register_jwt_mapping with the team_id / user_id from that result. The created key inherits team + user limits from the validated identity, and future cache hits load that already-policy-checked key. Also drops the interim _resolve_inherited_team_id helper that pulled team_id from raw JWT claims — same bypass risk; team_id now comes exclusively from auth_builder. Tests: - Rewrote two existing tests to assert _resolve_jwt_to_virtual_key returns _PendingAutoRegister (no key created yet) for both the fresh-DB-miss and stale-sentinel branches - Added a contract test that _auto_register_jwt_mapping stamps the validated team_id/user_id onto generate_key_helper_fn - Removed four stale team-binding tests that exercised the prior raw-claim helper Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * Update user_api_key_auth.py * fix(jwt): cache proxy-admin AUTO_REGISTER path to avoid repeated DB lookups Cache-miss regression introduced by the deferred-auto-register refactor: when a JWT under AUTO_REGISTER resolved to a proxy admin, the is_proxy_admin early-return in _user_api_key_auth_builder ran *before* the pending auto-register cache-write block. Result: no cache entry, so every subsequent proxy-admin request re-queried get_jwt_key_mapping_object indefinitely. Fix: write a __JWT_PROXY_ADMIN__ sentinel to user_api_key_cache before the early return when a pending auto-register existed. _resolve_jwt_to_virtual_key treats that sentinel as "skip mapping, fall through to auth_builder", so future requests from the same JWT identity hit the cache instead of the DB. auth_builder still runs full JWT policy on every request — only the mapping DB lookup is short-circuited. Adds one test asserting the sentinel cache-hit returns None without hitting prisma_client.db.litellm_jwtkeymapping.find_first. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com> * fix(proxy): stamp org context on JWT auto-registered keys AUTO_REGISTER keys were created with team_id and user_id only, so org budget checks were skipped after switching to the key-scoped path. Co-authored-by: Cursor <cursoragent@cursor.com> --------- Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com> Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
41e90a6ada
commit
3bd89f209e
@ -4448,6 +4448,25 @@ class JWTRoutingOverride(BaseModel):
|
||||
}
|
||||
|
||||
|
||||
class UnregisteredJWTClientBehavior(str, enum.Enum):
|
||||
"""
|
||||
Controls what happens when `virtual_key_claim_field` is configured but the
|
||||
JWT claim value has no registered mapping in `litellm_jwtkeymapping`.
|
||||
|
||||
- fallback_team_mapping: Fall through to standard team-based JWT auth (default,
|
||||
backward-compatible).
|
||||
- reject: Immediately return HTTP 403. Use this when every valid JWT client
|
||||
must have a pre-registered virtual key — unknown callers are denied.
|
||||
- auto_register: Automatically create a new virtual key and mapping on first
|
||||
encounter. The new key has no budget/model restrictions; admins can tighten
|
||||
it later via /jwt_client/update.
|
||||
"""
|
||||
|
||||
FALLBACK_TEAM_MAPPING = "fallback_team_mapping"
|
||||
REJECT = "reject"
|
||||
AUTO_REGISTER = "auto_register"
|
||||
|
||||
|
||||
class JWTIssuerConfig(BaseModel):
|
||||
"""
|
||||
Issuer-bound JWT validation configuration.
|
||||
@ -4614,6 +4633,15 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||
default=300,
|
||||
description="TTL (seconds) for caching JWT-to-virtual-key mapping lookups.",
|
||||
)
|
||||
unregistered_jwt_client_behavior: UnregisteredJWTClientBehavior = Field(
|
||||
default=UnregisteredJWTClientBehavior.FALLBACK_TEAM_MAPPING,
|
||||
description=(
|
||||
"What to do when virtual_key_claim_field is set but the JWT claim value "
|
||||
"has no registered mapping. 'fallback_team_mapping' (default): fall through "
|
||||
"to team-based JWT auth. 'reject': return HTTP 403. "
|
||||
"'auto_register': auto-create a virtual key and mapping on first encounter."
|
||||
),
|
||||
)
|
||||
routing_overrides: Optional[List[JWTRoutingOverride]] = Field(
|
||||
default=None,
|
||||
description="Optional claim-based routing overrides for JWT-shaped tokens. Matching rules route requests to oauth2 before default JWT flow.",
|
||||
@ -4633,6 +4661,13 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||
# ``s3://`` / ``gcs://`` when this is None.
|
||||
config_file_path = kwargs.pop("config_file_path", None)
|
||||
|
||||
# Backward-compat: jwt_client_id_field was renamed to virtual_key_claim_field
|
||||
if "jwt_client_id_field" in kwargs:
|
||||
if "virtual_key_claim_field" not in kwargs:
|
||||
kwargs["virtual_key_claim_field"] = kwargs.pop("jwt_client_id_field")
|
||||
else:
|
||||
kwargs.pop("jwt_client_id_field")
|
||||
|
||||
# get the attribute names for this Pydantic model
|
||||
allowed_keys = LiteLLM_JWTAuth.__annotations__.keys()
|
||||
|
||||
|
||||
@ -12,7 +12,7 @@ import fnmatch
|
||||
import re
|
||||
import secrets
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast
|
||||
from typing import Any, Dict, Iterator, NamedTuple, List, Optional, Tuple, Union, cast
|
||||
|
||||
import fastapi
|
||||
from fastapi import HTTPException, Request, WebSocket, status
|
||||
@ -607,6 +607,169 @@ async def check_api_key_for_custom_headers_or_pass_through_endpoints(
|
||||
return api_key
|
||||
|
||||
|
||||
# Cache sentinel written when a JWT under AUTO_REGISTER resolved to a proxy
|
||||
# admin via auth_builder. Proxy admins don't need a mapped virtual key (they
|
||||
# have full access via auth_builder anyway), but without a cache entry every
|
||||
# subsequent request from the same JWT identity would re-query the DB for a
|
||||
# non-existent mapping. Sentinel tells _resolve_jwt_to_virtual_key to skip
|
||||
# the lookup and return None (caller proceeds to auth_builder).
|
||||
_JWT_PROXY_ADMIN_SENTINEL = "__JWT_PROXY_ADMIN__"
|
||||
|
||||
|
||||
class _PendingAutoRegister(NamedTuple):
|
||||
"""
|
||||
Signal returned by ``_resolve_jwt_to_virtual_key`` when the JWT's claim is
|
||||
unmapped and ``unregistered_jwt_client_behavior`` is AUTO_REGISTER.
|
||||
|
||||
The caller MUST run standard ``JWTAuthManager.auth_builder`` to apply RBAC,
|
||||
scope mappings, ``custom_validate``, and ``user_allowed_email_domain``
|
||||
policy BEFORE calling ``_auto_register_jwt_mapping`` with the validated
|
||||
``team_id`` / ``user_id`` from the auth_builder result. Auto-registering
|
||||
purely on a signature-valid JWT (the old behavior) bypassed every JWT
|
||||
policy beyond signature verification.
|
||||
"""
|
||||
|
||||
claim_field: str
|
||||
claim_value: str
|
||||
cache_key: str
|
||||
|
||||
|
||||
async def _auto_register_jwt_mapping(
|
||||
virtual_key_claim_field: str,
|
||||
claim_value: str,
|
||||
jwt_handler: JWTHandler,
|
||||
prisma_client: PrismaClient,
|
||||
user_api_key_cache: UserApiKeyCache,
|
||||
parent_otel_span: Optional[Span],
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
cache_key: str,
|
||||
team_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
org_id: Optional[str] = None,
|
||||
end_user_id: Optional[str] = None,
|
||||
) -> Optional[UserAPIKeyAuth]:
|
||||
"""
|
||||
Auto-register: create a new virtual key + mapping for an unrecognised JWT
|
||||
claim value. ``team_id`` and ``user_id`` must come from a successful
|
||||
``JWTAuthManager.auth_builder`` run — they encode the JWT identity AFTER
|
||||
RBAC/scope/custom_validate/email-domain policy has been enforced. The key
|
||||
is stamped with those values so the cached future-request path inherits
|
||||
the same team/user/org limits the auth_builder path would have applied.
|
||||
|
||||
Race safety: if two concurrent requests both reach here simultaneously (both
|
||||
saw no mapping in the DB), one will win the unique-constraint race on
|
||||
litellm_jwtkeymapping. The loser catches the conflict, deletes its orphaned
|
||||
key, fetches the winner's mapping, and proceeds — no error surfaced.
|
||||
"""
|
||||
# Inline import required: key_management_endpoints imports user_api_key_auth
|
||||
# (line 51) so a module-level import here would create a circular dependency.
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
generate_key_helper_fn,
|
||||
)
|
||||
|
||||
# ``table_name="key"`` is required: without it, generate_key_helper_fn
|
||||
# falls into the user-upsert branch (`table_name is None or "user"`) and
|
||||
# attempts to insert into LiteLLM_UserTable with user_id=None, which fails
|
||||
# the NOT NULL @id constraint. Every successful key-creation caller (e.g.
|
||||
# /key/generate) passes table_name="key" explicitly.
|
||||
key_data = await generate_key_helper_fn(
|
||||
request_type="key",
|
||||
table_name="key",
|
||||
team_id=team_id,
|
||||
user_id=user_id,
|
||||
organization_id=org_id,
|
||||
metadata={
|
||||
"auto_registered": True,
|
||||
"jwt_claim_field": virtual_key_claim_field,
|
||||
"jwt_claim_value": claim_value,
|
||||
},
|
||||
)
|
||||
# generate_key_helper_fn returns the plaintext key in "token"; the persisted
|
||||
# row in LiteLLM_VerificationToken uses its hash, so hash here to get the FK
|
||||
# value referenced by LiteLLM_JWTKeyMapping.token.
|
||||
token_hash = hash_token(key_data["token"])
|
||||
|
||||
try:
|
||||
await prisma_client.db.litellm_jwtkeymapping.create(
|
||||
data={
|
||||
"jwt_claim_name": virtual_key_claim_field,
|
||||
"jwt_claim_value": claim_value,
|
||||
"token": token_hash,
|
||||
"created_by": "auto_register",
|
||||
"updated_by": "auto_register",
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if "unique" in error_str or "p2002" in error_str:
|
||||
# A concurrent request won the race. The key generate_key_helper_fn
|
||||
# just persisted to LiteLLM_VerificationToken is orphaned — nothing
|
||||
# maps to it, but it's a fully valid unrestricted API key sitting in
|
||||
# the DB and the cleartext is in memory on this request. Delete it
|
||||
# so orphans don't accumulate under sustained concurrency.
|
||||
verbose_proxy_logger.debug(
|
||||
"JWT Key Mapping (auto_register): unique conflict on create — "
|
||||
"deleting orphaned virtual key and fetching winner's mapping for %s='%s'.",
|
||||
virtual_key_claim_field,
|
||||
claim_value,
|
||||
)
|
||||
try:
|
||||
await prisma_client.db.litellm_verificationtoken.delete(
|
||||
where={"token": token_hash}
|
||||
)
|
||||
except Exception as delete_err:
|
||||
# Don't fail the request if cleanup fails — the orphan is
|
||||
# unmapped and inert. Log so an operator can prune it later.
|
||||
verbose_proxy_logger.warning(
|
||||
"JWT Key Mapping (auto_register): failed to delete orphaned key after race: %s",
|
||||
delete_err,
|
||||
)
|
||||
token_hash = await get_jwt_key_mapping_object(
|
||||
jwt_claim_name=virtual_key_claim_field,
|
||||
jwt_claim_value=claim_value,
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
if token_hash is None:
|
||||
# The winner's mapping vanished between the unique-constraint
|
||||
# conflict and our re-fetch (concurrent delete). Returning None
|
||||
# here would silently fall through to team-based JWT auth —
|
||||
# a less-restrictive path than the operator configured. Raise
|
||||
# 503 so the caller retries against a stable state instead.
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=(
|
||||
"JWT Key Mapping: AUTO_REGISTER race resolution failed — "
|
||||
"winner's mapping was concurrently removed. Retry the request."
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise
|
||||
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=cache_key,
|
||||
value=token_hash,
|
||||
ttl=jwt_handler.litellm_jwtauth.virtual_key_mapping_cache_ttl,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.info(
|
||||
"JWT Key Mapping (auto_register): created new virtual key for %s='%s'.",
|
||||
virtual_key_claim_field,
|
||||
claim_value,
|
||||
)
|
||||
|
||||
auto_registered_key = await get_key_object(
|
||||
hashed_token=token_hash,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
if auto_registered_key is not None:
|
||||
auto_registered_key.org_id = org_id
|
||||
auto_registered_key.end_user_id = end_user_id
|
||||
return auto_registered_key
|
||||
|
||||
|
||||
async def _resolve_jwt_to_virtual_key(
|
||||
jwt_claims: dict,
|
||||
jwt_handler: JWTHandler,
|
||||
@ -614,7 +777,22 @@ async def _resolve_jwt_to_virtual_key(
|
||||
user_api_key_cache: UserApiKeyCache,
|
||||
parent_otel_span: Optional[Span],
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
) -> Optional[UserAPIKeyAuth]:
|
||||
) -> Union[Optional[UserAPIKeyAuth], "_PendingAutoRegister"]:
|
||||
"""
|
||||
Returns:
|
||||
- ``UserAPIKeyAuth``: a resolved virtual key (cache hit or DB hit). The
|
||||
caller may use this directly; JWT policy has been enforced previously
|
||||
(at key-creation time or, for cached results, before caching).
|
||||
- ``_PendingAutoRegister``: claim is unmapped and behavior is AUTO_REGISTER.
|
||||
The caller MUST run ``JWTAuthManager.auth_builder`` to enforce JWT
|
||||
policy (RBAC, scope, custom_validate, email-domain), then invoke
|
||||
``_auto_register_jwt_mapping`` with the validated team_id/user_id.
|
||||
- ``None``: claim is unmapped and behavior is FALLBACK_TEAM_MAPPING.
|
||||
The caller falls through to standard team-based JWT auth (which itself
|
||||
enforces full JWT policy via auth_builder).
|
||||
- Raises HTTPException: REJECT policy hit, missing claim under
|
||||
REJECT/AUTO_REGISTER, or other policy violations.
|
||||
"""
|
||||
virtual_key_claim_field = jwt_handler.litellm_jwtauth.virtual_key_claim_field
|
||||
if virtual_key_claim_field is None:
|
||||
return None
|
||||
@ -629,12 +807,61 @@ async def _resolve_jwt_to_virtual_key(
|
||||
verbose_proxy_logger.debug(
|
||||
f"JWT Key Mapping: Claim field '{virtual_key_claim_field}' not found in JWT claims."
|
||||
)
|
||||
# A missing claim is an unmapped client — apply the no-match policy
|
||||
# rather than returning early. Otherwise a caller can bypass REJECT
|
||||
# simply by presenting a JWT that omits the configured field. For
|
||||
# AUTO_REGISTER there is no stable identity to map without a claim
|
||||
# value, so we deny rather than create a sentinel-keyed record.
|
||||
behavior = jwt_handler.litellm_jwtauth.unregistered_jwt_client_behavior
|
||||
if behavior in (
|
||||
UnregisteredJWTClientBehavior.REJECT,
|
||||
UnregisteredJWTClientBehavior.AUTO_REGISTER,
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=(
|
||||
f"JWT Key Mapping: Required claim '{virtual_key_claim_field}' "
|
||||
"is missing from the JWT. Access denied."
|
||||
),
|
||||
)
|
||||
return None
|
||||
|
||||
cache_key = f"jwt_key_mapping:{virtual_key_claim_field}:{claim_value}"
|
||||
cached_mapping = await user_api_key_cache.async_get_cache(cache_key)
|
||||
|
||||
if cached_mapping == _JWT_PROXY_ADMIN_SENTINEL:
|
||||
# Previously resolved to a proxy admin via auth_builder; skip the
|
||||
# mapping lookup and let the caller re-run auth_builder. Avoids a
|
||||
# repeated DB hit on every proxy-admin request under AUTO_REGISTER.
|
||||
return None
|
||||
|
||||
if cached_mapping == "__NO_MAPPING__":
|
||||
behavior = jwt_handler.litellm_jwtauth.unregistered_jwt_client_behavior
|
||||
if behavior == UnregisteredJWTClientBehavior.REJECT:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"JWT Key Mapping: No registered mapping for {virtual_key_claim_field}='{claim_value}'. Access denied.",
|
||||
)
|
||||
if behavior == UnregisteredJWTClientBehavior.AUTO_REGISTER:
|
||||
# Stale sentinel written under a prior fallback_team_mapping config —
|
||||
# evict it and defer auto-register to after auth_builder runs. Raise
|
||||
# the same 500 as the fresh-path AUTO_REGISTER branch when there is
|
||||
# no DB, so behavior is consistent regardless of whether the cache
|
||||
# happens to hold the sentinel.
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=(
|
||||
"JWT Key Mapping: AUTO_REGISTER requires a database connection. "
|
||||
"Configure a database or change unregistered_jwt_client_behavior."
|
||||
),
|
||||
)
|
||||
await user_api_key_cache.async_delete_cache(cache_key)
|
||||
return _PendingAutoRegister(
|
||||
claim_field=virtual_key_claim_field,
|
||||
claim_value=str(claim_value),
|
||||
cache_key=cache_key,
|
||||
)
|
||||
return None
|
||||
elif cached_mapping is not None:
|
||||
return await get_key_object(
|
||||
@ -645,14 +872,15 @@ async def _resolve_jwt_to_virtual_key(
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
if prisma_client is None:
|
||||
return None
|
||||
|
||||
token_hash = await get_jwt_key_mapping_object(
|
||||
jwt_claim_name=virtual_key_claim_field,
|
||||
jwt_claim_value=str(claim_value),
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
# Resolve the mapping from DB, or treat prisma_client=None as a definitive
|
||||
# miss (no DB → no mapping can exist → apply no-match policy below).
|
||||
token_hash: Optional[str] = None
|
||||
if prisma_client is not None:
|
||||
token_hash = await get_jwt_key_mapping_object(
|
||||
jwt_claim_name=virtual_key_claim_field,
|
||||
jwt_claim_value=str(claim_value),
|
||||
prisma_client=prisma_client,
|
||||
)
|
||||
|
||||
if token_hash is not None:
|
||||
await user_api_key_cache.async_set_cache(
|
||||
@ -667,13 +895,50 @@ async def _resolve_jwt_to_virtual_key(
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
else:
|
||||
|
||||
# No mapping found (DB miss or no DB) — apply no-match policy.
|
||||
behavior = jwt_handler.litellm_jwtauth.unregistered_jwt_client_behavior
|
||||
|
||||
if behavior == UnregisteredJWTClientBehavior.REJECT:
|
||||
# Cache the miss before raising so repeated rejections are served from
|
||||
# cache and don't re-query the DB on every request.
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=cache_key,
|
||||
value="__NO_MAPPING__",
|
||||
ttl=jwt_handler.litellm_jwtauth.virtual_key_mapping_cache_ttl,
|
||||
)
|
||||
return None
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"JWT Key Mapping: No registered mapping for {virtual_key_claim_field}='{claim_value}'. Access denied.",
|
||||
)
|
||||
|
||||
if behavior == UnregisteredJWTClientBehavior.AUTO_REGISTER:
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=(
|
||||
"JWT Key Mapping: AUTO_REGISTER requires a database connection. "
|
||||
"Configure a database or change unregistered_jwt_client_behavior."
|
||||
),
|
||||
)
|
||||
# Defer: caller runs JWTAuthManager.auth_builder to enforce RBAC, scope,
|
||||
# custom_validate, and email-domain policy, then auto-registers using
|
||||
# the validated identity. Auto-registering here on a signature-only
|
||||
# JWT would bypass every JWT policy beyond signature verification.
|
||||
return _PendingAutoRegister(
|
||||
claim_field=virtual_key_claim_field,
|
||||
claim_value=str(claim_value),
|
||||
cache_key=cache_key,
|
||||
)
|
||||
|
||||
# FALLBACK_TEAM_MAPPING (default): cache the miss and return None so the
|
||||
# caller falls through to standard team-based JWT auth.
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=cache_key,
|
||||
value="__NO_MAPPING__",
|
||||
ttl=jwt_handler.litellm_jwtauth.virtual_key_mapping_cache_ttl,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def _ensure_parent_otel_span_on_request_state(request: Request) -> None:
|
||||
@ -893,6 +1158,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
# Try JWT-to-Virtual-Key mapping first to avoid
|
||||
# unnecessary DB queries in auth_builder
|
||||
do_standard_jwt_auth = True
|
||||
pending_auto_register: Optional[_PendingAutoRegister] = None
|
||||
if jwt_handler.litellm_jwtauth.virtual_key_claim_field is not None:
|
||||
# Decode JWT to get claims without running full auth_builder
|
||||
jwt_claims: Optional[dict]
|
||||
@ -901,7 +1167,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
else:
|
||||
jwt_claims = await jwt_handler.auth_jwt(token=api_key)
|
||||
|
||||
valid_token = await _resolve_jwt_to_virtual_key(
|
||||
resolve_result = await _resolve_jwt_to_virtual_key(
|
||||
jwt_claims=jwt_claims,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=prisma_client,
|
||||
@ -909,11 +1175,19 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
if valid_token is not None:
|
||||
if isinstance(resolve_result, UserAPIKeyAuth):
|
||||
valid_token = resolve_result
|
||||
api_key = valid_token.token or ""
|
||||
valid_token.jwt_claims = jwt_claims
|
||||
do_standard_jwt_auth = False
|
||||
# Fall through to virtual key checks
|
||||
elif isinstance(resolve_result, _PendingAutoRegister):
|
||||
# Run full JWT policy (RBAC, scope, custom_validate,
|
||||
# email-domain) via auth_builder, then create the key
|
||||
# from the validated identity below.
|
||||
pending_auto_register = resolve_result
|
||||
# else: None → FALLBACK_TEAM_MAPPING, falls through to
|
||||
# standard JWT auth_builder below
|
||||
|
||||
if do_standard_jwt_auth:
|
||||
with tracer.trace("litellm.proxy.auth.jwt_auth_builder"):
|
||||
@ -946,6 +1220,19 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
jwt_claims = result.get("jwt_claims", None)
|
||||
|
||||
if is_proxy_admin:
|
||||
# Proxy admins authenticate via auth_builder (full
|
||||
# access), not via a mapped virtual key. If
|
||||
# AUTO_REGISTER was pending, cache a sentinel so
|
||||
# future requests from this JWT identity skip the
|
||||
# DB mapping lookup in _resolve_jwt_to_virtual_key.
|
||||
# Without this, every proxy-admin request under
|
||||
# AUTO_REGISTER re-hits get_jwt_key_mapping_object.
|
||||
if pending_auto_register is not None:
|
||||
await user_api_key_cache.async_set_cache(
|
||||
key=pending_auto_register.cache_key,
|
||||
value=_JWT_PROXY_ADMIN_SENTINEL,
|
||||
ttl=jwt_handler.litellm_jwtauth.virtual_key_mapping_cache_ttl,
|
||||
)
|
||||
return UserAPIKeyAuth(
|
||||
api_key=None,
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
@ -1032,6 +1319,32 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
|
||||
else None
|
||||
)
|
||||
|
||||
# AUTO_REGISTER deferred from _resolve_jwt_to_virtual_key.
|
||||
# JWT policy (RBAC, scope, custom_validate, email-domain)
|
||||
# has now been enforced by auth_builder above. Create the
|
||||
# mapping + virtual key from the *validated* identity, then
|
||||
# replace valid_token with the new key so downstream checks
|
||||
# use the key-scoped path.
|
||||
if pending_auto_register is not None and prisma_client is not None:
|
||||
auto_registered = await _auto_register_jwt_mapping(
|
||||
virtual_key_claim_field=pending_auto_register.claim_field,
|
||||
claim_value=pending_auto_register.claim_value,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
cache_key=pending_auto_register.cache_key,
|
||||
team_id=team_id,
|
||||
user_id=user_id,
|
||||
org_id=org_id,
|
||||
end_user_id=end_user_id,
|
||||
)
|
||||
if auto_registered is not None:
|
||||
auto_registered.jwt_claims = jwt_claims
|
||||
valid_token = auto_registered
|
||||
api_key = valid_token.token or ""
|
||||
|
||||
# Check if model has zero cost - if so, skip all budget checks
|
||||
model = _get_model_from_request_context(
|
||||
request_data=request_data,
|
||||
|
||||
@ -27,7 +27,6 @@ from litellm.proxy.management_endpoints.jwt_key_mapping_endpoints import (
|
||||
from litellm.caching.caching import DualCache
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Tests: _resolve_jwt_to_virtual_key
|
||||
# ──────────────────────────────────────────────
|
||||
@ -454,3 +453,856 @@ async def test_create_success_returns_response_without_token():
|
||||
assert isinstance(result, JWTKeyMappingResponse)
|
||||
assert "token" not in result.model_fields
|
||||
assert result.jwt_claim_name == "email"
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Tests: unregistered_jwt_client_behavior
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reject_behavior_raises_403_on_no_mapping():
|
||||
"""
|
||||
When unregistered_jwt_client_behavior='reject' and no mapping exists,
|
||||
_resolve_jwt_to_virtual_key must raise HTTP 403.
|
||||
"""
|
||||
from litellm.proxy._types import UnregisteredJWTClientBehavior
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="email",
|
||||
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT,
|
||||
)
|
||||
jwt_claims = {"email": "unknown@example.com"}
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
|
||||
|
||||
user_api_key_cache = DualCache()
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await _resolve_jwt_to_virtual_key(
|
||||
jwt_claims=jwt_claims,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "unknown@example.com" in exc_info.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reject_behavior_caches_sentinel_after_db_miss():
|
||||
"""
|
||||
On a fresh DB miss with REJECT, the __NO_MAPPING__ sentinel must be written
|
||||
to cache so that subsequent rejected requests are served from cache and do
|
||||
not re-query the DB.
|
||||
"""
|
||||
from litellm.proxy._types import UnregisteredJWTClientBehavior
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="email",
|
||||
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT,
|
||||
virtual_key_mapping_cache_ttl=300,
|
||||
)
|
||||
jwt_claims = {"email": "unknown@example.com"}
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
|
||||
|
||||
user_api_key_cache = DualCache()
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
|
||||
):
|
||||
# First call — DB miss, should raise 403 and write sentinel
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await _resolve_jwt_to_virtual_key(
|
||||
jwt_claims=jwt_claims,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
# Sentinel must now be in cache
|
||||
cached = await user_api_key_cache.async_get_cache(
|
||||
"jwt_key_mapping:email:unknown@example.com"
|
||||
)
|
||||
assert cached == "__NO_MAPPING__"
|
||||
|
||||
# Second call — must raise 403 from cache, no additional DB hit
|
||||
prisma_client.db.litellm_jwtkeymapping.find_first.reset_mock()
|
||||
with pytest.raises(HTTPException) as exc_info2:
|
||||
await _resolve_jwt_to_virtual_key(
|
||||
jwt_claims=jwt_claims,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
assert exc_info2.value.status_code == 403
|
||||
prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reject_behavior_raises_403_on_cached_no_mapping():
|
||||
"""
|
||||
When the negative-cache sentinel __NO_MAPPING__ is present and behavior is
|
||||
'reject', the function must also raise HTTP 403 (not return None silently).
|
||||
"""
|
||||
from litellm.proxy._types import UnregisteredJWTClientBehavior
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="email",
|
||||
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT,
|
||||
)
|
||||
jwt_claims = {"email": "unknown@example.com"}
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
|
||||
|
||||
# Pre-populate the negative cache so the DB is not hit
|
||||
user_api_key_cache = DualCache()
|
||||
cache_key = "jwt_key_mapping:email:unknown@example.com"
|
||||
await user_api_key_cache.async_set_cache(cache_key, "__NO_MAPPING__")
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock
|
||||
):
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await _resolve_jwt_to_virtual_key(
|
||||
jwt_claims=jwt_claims,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
# DB must NOT have been hit (sentinel served from cache)
|
||||
prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_register_returns_pending_signal_without_creating_key():
|
||||
"""
|
||||
Security: when unregistered_jwt_client_behavior='auto_register' and no
|
||||
mapping exists, _resolve_jwt_to_virtual_key must NOT create the key yet.
|
||||
It returns a _PendingAutoRegister signal so the caller can run
|
||||
JWTAuthManager.auth_builder (enforcing RBAC, scope mappings,
|
||||
custom_validate, user_allowed_email_domain) FIRST. Creating the key here
|
||||
would bypass every JWT policy beyond signature verification.
|
||||
"""
|
||||
from litellm.proxy._types import UnregisteredJWTClientBehavior
|
||||
from litellm.proxy.auth.user_api_key_auth import _PendingAutoRegister
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="sub",
|
||||
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
|
||||
virtual_key_mapping_cache_ttl=300,
|
||||
)
|
||||
jwt_claims = {"sub": "new-user-42"}
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
|
||||
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock()
|
||||
|
||||
user_api_key_cache = DualCache()
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_gen_key:
|
||||
result = await _resolve_jwt_to_virtual_key(
|
||||
jwt_claims=jwt_claims,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
|
||||
assert isinstance(result, _PendingAutoRegister)
|
||||
assert result.claim_field == "sub"
|
||||
assert result.claim_value == "new-user-42"
|
||||
assert result.cache_key == "jwt_key_mapping:sub:new-user-42"
|
||||
# CRITICAL: no key was created — that must wait until after auth_builder
|
||||
mock_gen_key.assert_not_called()
|
||||
prisma_client.db.litellm_jwtkeymapping.create.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_register_creates_key_and_mapping_when_helper_invoked():
|
||||
"""
|
||||
When the caller invokes _auto_register_jwt_mapping directly (after
|
||||
auth_builder validation), the helper creates the key + mapping row and
|
||||
returns a UserAPIKeyAuth. The mapping row stores the hashed token (FK to
|
||||
LiteLLM_VerificationToken), not the plaintext key.
|
||||
"""
|
||||
from litellm.proxy._types import hash_token
|
||||
from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="sub",
|
||||
virtual_key_mapping_cache_ttl=300,
|
||||
)
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
|
||||
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock()
|
||||
|
||||
user_api_key_cache = DualCache()
|
||||
plaintext_key = "sk-auto-key"
|
||||
expected_hash = hash_token(plaintext_key)
|
||||
mock_key_obj = UserAPIKeyAuth(token=expected_hash, team_id="validated-team")
|
||||
|
||||
with (
|
||||
patch(
|
||||
"litellm.proxy.auth.user_api_key_auth.get_key_object",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_key,
|
||||
patch(
|
||||
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_gen_key,
|
||||
):
|
||||
mock_gen_key.return_value = {"token": plaintext_key, "key": plaintext_key}
|
||||
mock_get_key.return_value = mock_key_obj
|
||||
|
||||
result = await _auto_register_jwt_mapping(
|
||||
virtual_key_claim_field="sub",
|
||||
claim_value="new-user-42",
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
cache_key="jwt_key_mapping:sub:new-user-42",
|
||||
team_id="validated-team",
|
||||
user_id="validated-user",
|
||||
)
|
||||
|
||||
assert result == mock_key_obj
|
||||
# generate_key_helper_fn was passed table_name="key" (not user-upsert path)
|
||||
# and the validated team_id + user_id from auth_builder
|
||||
assert mock_gen_key.call_args.kwargs["table_name"] == "key"
|
||||
assert mock_gen_key.call_args.kwargs["team_id"] == "validated-team"
|
||||
assert mock_gen_key.call_args.kwargs["user_id"] == "validated-user"
|
||||
# Mapping row was created with the hashed token (FK target)
|
||||
call_data = prisma_client.db.litellm_jwtkeymapping.create.call_args[1]["data"]
|
||||
assert call_data["jwt_claim_name"] == "sub"
|
||||
assert call_data["jwt_claim_value"] == "new-user-42"
|
||||
assert call_data["token"] == expected_hash
|
||||
cached = await user_api_key_cache.async_get_cache("jwt_key_mapping:sub:new-user-42")
|
||||
assert cached == expected_hash
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_register_returns_pending_signal_on_stale_no_mapping_sentinel():
|
||||
"""
|
||||
If the cache holds a stale __NO_MAPPING__ sentinel (written under a prior
|
||||
fallback_team_mapping config) and behavior is now AUTO_REGISTER, the
|
||||
resolver must evict the sentinel and return _PendingAutoRegister (so the
|
||||
caller can run auth_builder before creating the key) — not silently return
|
||||
None and not create the key on the spot.
|
||||
"""
|
||||
from litellm.proxy._types import UnregisteredJWTClientBehavior
|
||||
from litellm.proxy.auth.user_api_key_auth import _PendingAutoRegister
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="email",
|
||||
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
|
||||
virtual_key_mapping_cache_ttl=300,
|
||||
)
|
||||
jwt_claims = {"email": "alice@corp.com"}
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
|
||||
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock()
|
||||
|
||||
user_api_key_cache = DualCache()
|
||||
await user_api_key_cache.async_set_cache(
|
||||
"jwt_key_mapping:email:alice@corp.com", "__NO_MAPPING__"
|
||||
)
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_gen_key:
|
||||
result = await _resolve_jwt_to_virtual_key(
|
||||
jwt_claims=jwt_claims,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
|
||||
assert isinstance(result, _PendingAutoRegister)
|
||||
# Stale sentinel must be evicted so the deferred auto-register actually
|
||||
# runs after auth_builder validates the JWT
|
||||
cached_after = await user_api_key_cache.async_get_cache(
|
||||
"jwt_key_mapping:email:alice@corp.com"
|
||||
)
|
||||
assert cached_after is None
|
||||
mock_gen_key.assert_not_called()
|
||||
prisma_client.db.litellm_jwtkeymapping.create.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_register_race_condition_unique_conflict():
|
||||
"""
|
||||
If two concurrent requests both call _auto_register_jwt_mapping and the
|
||||
second hits a unique-constraint violation on create, it must:
|
||||
1) delete the orphaned virtual key it just created (so orphans don't
|
||||
accumulate in LiteLLM_VerificationToken under sustained concurrency),
|
||||
2) fall back to the winner's mapping,
|
||||
3) not surface an error.
|
||||
"""
|
||||
from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping
|
||||
from litellm.proxy._types import UnregisteredJWTClientBehavior, hash_token
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="sub",
|
||||
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
|
||||
virtual_key_mapping_cache_ttl=300,
|
||||
)
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock(
|
||||
side_effect=Exception("Unique constraint failed (P2002)")
|
||||
)
|
||||
prisma_client.db.litellm_verificationtoken.delete = AsyncMock()
|
||||
# Simulate the winner's mapping already in DB after the conflict
|
||||
winner_mapping = MagicMock()
|
||||
winner_mapping.token = "winner_token_hash"
|
||||
winner_mapping.is_active = True
|
||||
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(
|
||||
return_value=winner_mapping
|
||||
)
|
||||
|
||||
user_api_key_cache = DualCache()
|
||||
loser_plaintext = "sk-loser"
|
||||
loser_hash = hash_token(loser_plaintext)
|
||||
mock_key_obj = UserAPIKeyAuth(token="winner_token_hash", team_id=None)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"litellm.proxy.auth.user_api_key_auth.get_key_object",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_key,
|
||||
patch(
|
||||
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"token": loser_plaintext, "key": loser_plaintext},
|
||||
),
|
||||
):
|
||||
mock_get_key.return_value = mock_key_obj
|
||||
|
||||
result = await _auto_register_jwt_mapping(
|
||||
virtual_key_claim_field="sub",
|
||||
claim_value="user-42",
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
cache_key="jwt_key_mapping:sub:user-42",
|
||||
)
|
||||
|
||||
assert result == mock_key_obj
|
||||
# The orphaned loser key must be deleted from LiteLLM_VerificationToken
|
||||
prisma_client.db.litellm_verificationtoken.delete.assert_called_once_with(
|
||||
where={"token": loser_hash}
|
||||
)
|
||||
# Cache should hold the winner's token, not the loser's
|
||||
cached = await user_api_key_cache.async_get_cache("jwt_key_mapping:sub:user-42")
|
||||
assert cached == "winner_token_hash"
|
||||
mock_get_key.assert_called_once_with(
|
||||
hashed_token="winner_token_hash",
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Tests: prisma_client=None does not bypass no-match policy
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reject_behavior_enforced_when_prisma_client_is_none():
|
||||
"""
|
||||
When prisma_client is None and behavior is REJECT, a 403 must be raised —
|
||||
not silently fallen through to team auth.
|
||||
"""
|
||||
from litellm.proxy._types import UnregisteredJWTClientBehavior
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="email",
|
||||
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT,
|
||||
)
|
||||
jwt_claims = {"email": "unknown@example.com"}
|
||||
|
||||
user_api_key_cache = DualCache()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await _resolve_jwt_to_virtual_key(
|
||||
jwt_claims=jwt_claims,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=None, # no DB
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "unknown@example.com" in exc_info.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_reject_raises_403_when_claim_field_missing_from_jwt():
|
||||
"""
|
||||
Security: a JWT that omits the configured virtual_key_claim_field must NOT
|
||||
bypass the REJECT policy. Previously the early `if claim_value is None:
|
||||
return None` branch ran before the policy check, letting a caller who knows
|
||||
the configured claim-field name silently fall through to team-based auth.
|
||||
"""
|
||||
from litellm.proxy._types import UnregisteredJWTClientBehavior
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="sub",
|
||||
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT,
|
||||
)
|
||||
# JWT does NOT contain "sub"
|
||||
jwt_claims = {"email": "user@example.com"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await _resolve_jwt_to_virtual_key(
|
||||
jwt_claims=jwt_claims,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=MagicMock(),
|
||||
user_api_key_cache=DualCache(),
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "'sub'" in exc_info.value.detail
|
||||
assert "missing from the JWT" in exc_info.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_register_raises_403_when_claim_field_missing_from_jwt():
|
||||
"""
|
||||
AUTO_REGISTER cannot create a mapping without a stable identity. When the
|
||||
configured claim field is missing from the JWT, return 403 rather than
|
||||
silently falling through (which would bypass the unregistered-client policy)
|
||||
or creating a sentinel-keyed record.
|
||||
"""
|
||||
from litellm.proxy._types import UnregisteredJWTClientBehavior
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="sub",
|
||||
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
|
||||
)
|
||||
jwt_claims = {"email": "user@example.com"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await _resolve_jwt_to_virtual_key(
|
||||
jwt_claims=jwt_claims,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=MagicMock(),
|
||||
user_api_key_cache=DualCache(),
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "missing from the JWT" in exc_info.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_team_mapping_returns_none_when_claim_field_missing_from_jwt():
|
||||
"""
|
||||
Under FALLBACK_TEAM_MAPPING (the default, backward-compatible mode), a JWT
|
||||
without the configured claim field must still fall through to team-based
|
||||
JWT auth — not raise. This preserves the pre-existing contract.
|
||||
"""
|
||||
from litellm.proxy._types import UnregisteredJWTClientBehavior
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="sub",
|
||||
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.FALLBACK_TEAM_MAPPING,
|
||||
)
|
||||
jwt_claims = {"email": "user@example.com"}
|
||||
|
||||
result = await _resolve_jwt_to_virtual_key(
|
||||
jwt_claims=jwt_claims,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=MagicMock(),
|
||||
user_api_key_cache=DualCache(),
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_team_mapping_returns_none_when_prisma_client_is_none():
|
||||
"""
|
||||
When prisma_client is None and behavior is FALLBACK_TEAM_MAPPING, the
|
||||
function must return None (fall through to team auth) — not raise.
|
||||
"""
|
||||
from litellm.proxy._types import UnregisteredJWTClientBehavior
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="email",
|
||||
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.FALLBACK_TEAM_MAPPING,
|
||||
)
|
||||
jwt_claims = {"email": "anyone@example.com"}
|
||||
|
||||
result = await _resolve_jwt_to_virtual_key(
|
||||
jwt_claims=jwt_claims,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=None,
|
||||
user_api_key_cache=DualCache(),
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_register_raises_500_when_prisma_client_is_none():
|
||||
"""
|
||||
AUTO_REGISTER without a DB connection must raise HTTP 500 with a clear
|
||||
message — it cannot create keys without a database.
|
||||
"""
|
||||
from litellm.proxy._types import UnregisteredJWTClientBehavior
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="sub",
|
||||
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
|
||||
)
|
||||
jwt_claims = {"sub": "new-user-42"}
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await _resolve_jwt_to_virtual_key(
|
||||
jwt_claims=jwt_claims,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=None,
|
||||
user_api_key_cache=DualCache(),
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "AUTO_REGISTER requires a database" in exc_info.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_register_raises_500_when_sentinel_cached_and_no_db():
|
||||
"""
|
||||
AUTO_REGISTER + cached __NO_MAPPING__ sentinel + prisma_client is None must
|
||||
raise HTTP 500, matching the fresh-path behavior. Previously this path
|
||||
silently returned None and let the request fall through to team auth,
|
||||
creating different access-control outcomes under identical configuration.
|
||||
"""
|
||||
from litellm.proxy._types import UnregisteredJWTClientBehavior
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="sub",
|
||||
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
|
||||
virtual_key_mapping_cache_ttl=300,
|
||||
)
|
||||
jwt_claims = {"sub": "user-42"}
|
||||
|
||||
user_api_key_cache = DualCache()
|
||||
# Stale sentinel written under a prior fallback_team_mapping config
|
||||
await user_api_key_cache.async_set_cache(
|
||||
"jwt_key_mapping:sub:user-42", "__NO_MAPPING__"
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await _resolve_jwt_to_virtual_key(
|
||||
jwt_claims=jwt_claims,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=None,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
assert exc_info.value.status_code == 500
|
||||
assert "AUTO_REGISTER requires a database" in exc_info.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_register_race_conflict_tolerates_delete_failure():
|
||||
"""
|
||||
If deleting the orphaned virtual key after a race-condition conflict fails
|
||||
(e.g. transient DB error), the request must still succeed by returning the
|
||||
winner's mapping — the orphan is unmapped and inert.
|
||||
"""
|
||||
from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping
|
||||
from litellm.proxy._types import UnregisteredJWTClientBehavior
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="sub",
|
||||
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
|
||||
virtual_key_mapping_cache_ttl=300,
|
||||
)
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock(
|
||||
side_effect=Exception("Unique constraint failed (P2002)")
|
||||
)
|
||||
prisma_client.db.litellm_verificationtoken.delete = AsyncMock(
|
||||
side_effect=Exception("transient DB error")
|
||||
)
|
||||
winner_mapping = MagicMock()
|
||||
winner_mapping.token = "winner_token_hash"
|
||||
winner_mapping.is_active = True
|
||||
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(
|
||||
return_value=winner_mapping
|
||||
)
|
||||
|
||||
user_api_key_cache = DualCache()
|
||||
mock_key_obj = UserAPIKeyAuth(token="winner_token_hash", team_id=None)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"litellm.proxy.auth.user_api_key_auth.get_key_object",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_key,
|
||||
patch(
|
||||
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"token": "sk-loser", "key": "sk-loser"},
|
||||
),
|
||||
):
|
||||
mock_get_key.return_value = mock_key_obj
|
||||
|
||||
result = await _auto_register_jwt_mapping(
|
||||
virtual_key_claim_field="sub",
|
||||
claim_value="user-42",
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
cache_key="jwt_key_mapping:sub:user-42",
|
||||
)
|
||||
|
||||
# Caller still receives the winner's mapping even when cleanup fails
|
||||
assert result == mock_key_obj
|
||||
prisma_client.db.litellm_verificationtoken.delete.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_register_raises_503_when_winner_mapping_vanishes():
|
||||
"""
|
||||
Race edge case: this request loses the unique-constraint race, deletes its
|
||||
orphan, then refetches the winner's mapping — but the winner's row was
|
||||
concurrently deleted. Previously this returned None, silently falling
|
||||
through to less-restrictive team-based JWT auth (bypassing the configured
|
||||
AUTO_REGISTER policy). Must now raise HTTP 503 so the caller retries
|
||||
rather than getting unintended fallback access.
|
||||
"""
|
||||
from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping
|
||||
from litellm.proxy._types import UnregisteredJWTClientBehavior
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="sub",
|
||||
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
|
||||
virtual_key_mapping_cache_ttl=300,
|
||||
)
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock(
|
||||
side_effect=Exception("Unique constraint failed (P2002)")
|
||||
)
|
||||
prisma_client.db.litellm_verificationtoken.delete = AsyncMock()
|
||||
# Winner row no longer exists by the time we refetch
|
||||
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None)
|
||||
|
||||
user_api_key_cache = DualCache()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
|
||||
new_callable=AsyncMock,
|
||||
return_value={"token": "sk-loser", "key": "sk-loser"},
|
||||
),
|
||||
pytest.raises(HTTPException) as exc_info,
|
||||
):
|
||||
await _auto_register_jwt_mapping(
|
||||
virtual_key_claim_field="sub",
|
||||
claim_value="user-42",
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
cache_key="jwt_key_mapping:sub:user-42",
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 503
|
||||
assert "concurrently removed" in exc_info.value.detail
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_proxy_admin_sentinel_skips_db_lookup_on_cache_hit():
|
||||
"""
|
||||
When the cache holds the proxy-admin sentinel (written after a prior
|
||||
request's is_proxy_admin early-return), _resolve_jwt_to_virtual_key must
|
||||
return None *without* hitting the DB. Caller proceeds to auth_builder.
|
||||
|
||||
Without this, every subsequent proxy-admin request under AUTO_REGISTER
|
||||
would re-query get_jwt_key_mapping_object — a cache-miss regression
|
||||
introduced by the deferred-auto-register refactor.
|
||||
"""
|
||||
from litellm.proxy._types import UnregisteredJWTClientBehavior
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="sub",
|
||||
unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER,
|
||||
virtual_key_mapping_cache_ttl=300,
|
||||
)
|
||||
jwt_claims = {"sub": "admin-user"}
|
||||
|
||||
prisma_client = MagicMock()
|
||||
# Will fail the test if accessed — proves the sentinel short-circuits DB
|
||||
prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(
|
||||
side_effect=AssertionError("DB must not be hit when sentinel is cached")
|
||||
)
|
||||
|
||||
user_api_key_cache = DualCache()
|
||||
await user_api_key_cache.async_set_cache(
|
||||
"jwt_key_mapping:sub:admin-user", "__JWT_PROXY_ADMIN__"
|
||||
)
|
||||
|
||||
result = await _resolve_jwt_to_virtual_key(
|
||||
jwt_claims=jwt_claims,
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called()
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Tests: AUTO_REGISTER stamps validated identity from auth_builder
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_register_helper_stamps_validated_identity_context():
|
||||
"""
|
||||
The deferred-auto-register contract: _auto_register_jwt_mapping is called
|
||||
with identity fields from JWTAuthManager.auth_builder's *validated*
|
||||
result (after RBAC, scope mappings, custom_validate, email-domain policy).
|
||||
These must be passed to generate_key_helper_fn so the created key carries
|
||||
them — the cached future-request path then inherits the same team/user/org
|
||||
limits the auth_builder path would have applied.
|
||||
"""
|
||||
from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="sub",
|
||||
virtual_key_mapping_cache_ttl=300,
|
||||
)
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.db.litellm_jwtkeymapping.create = AsyncMock()
|
||||
mock_key_obj = UserAPIKeyAuth(
|
||||
token="hashed", team_id="validated-team", user_id="validated-user"
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"litellm.proxy.auth.user_api_key_auth.get_key_object",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_key,
|
||||
patch(
|
||||
"litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_gen_key,
|
||||
):
|
||||
mock_gen_key.return_value = {"token": "sk-newkey", "key": "sk-newkey"}
|
||||
mock_get_key.return_value = mock_key_obj
|
||||
|
||||
result = await _auto_register_jwt_mapping(
|
||||
virtual_key_claim_field="sub",
|
||||
claim_value="new-user",
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=DualCache(),
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
cache_key="jwt_key_mapping:sub:new-user",
|
||||
team_id="validated-team",
|
||||
user_id="validated-user",
|
||||
org_id="validated-org",
|
||||
end_user_id="validated-end-user",
|
||||
)
|
||||
|
||||
assert result == mock_key_obj
|
||||
assert mock_gen_key.call_args.kwargs["team_id"] == "validated-team"
|
||||
assert mock_gen_key.call_args.kwargs["user_id"] == "validated-user"
|
||||
assert mock_gen_key.call_args.kwargs["organization_id"] == "validated-org"
|
||||
assert result.org_id == "validated-org"
|
||||
assert result.end_user_id == "validated-end-user"
|
||||
|
||||
|
||||
# ──────────────────────────────────────────────
|
||||
# Tests: backward-compat alias jwt_client_id_field
|
||||
# ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_jwt_client_id_field_alias_maps_to_virtual_key_claim_field():
|
||||
"""
|
||||
jwt_client_id_field (old doc name) must silently alias to virtual_key_claim_field.
|
||||
"""
|
||||
auth = LiteLLM_JWTAuth(jwt_client_id_field="azp")
|
||||
assert auth.virtual_key_claim_field == "azp"
|
||||
|
||||
|
||||
def test_jwt_client_id_field_does_not_raise_on_duplicate():
|
||||
"""
|
||||
If both jwt_client_id_field and virtual_key_claim_field are supplied,
|
||||
virtual_key_claim_field takes precedence and no error is raised.
|
||||
"""
|
||||
auth = LiteLLM_JWTAuth(
|
||||
jwt_client_id_field="old_field",
|
||||
virtual_key_claim_field="new_field",
|
||||
)
|
||||
assert auth.virtual_key_claim_field == "new_field"
|
||||
|
||||
@ -31,12 +31,14 @@ from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.auth.auth_checks import get_key_object, _cache_key_object
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.auth.user_api_key_auth import (
|
||||
_PendingAutoRegister,
|
||||
_matches_routing_override,
|
||||
_reserve_budget_after_common_checks,
|
||||
_route_requires_auth_despite_public,
|
||||
_routing_selector_matches_claim,
|
||||
_run_centralized_common_checks,
|
||||
_run_post_custom_auth_checks,
|
||||
_user_api_key_auth_builder,
|
||||
get_api_key,
|
||||
user_api_key_auth,
|
||||
)
|
||||
@ -1550,6 +1552,93 @@ class TestJWTOAuth2Coexistence:
|
||||
assert mock_jwt_auth.call_args.kwargs["request_method"] == "POST"
|
||||
assert result.user_id == "jwt-human-user"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_auto_register_passes_validated_org_context_to_generated_key(self):
|
||||
jwt_token = "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ1c2VyMSJ9.signature"
|
||||
general_settings = {"enable_jwt_auth": True}
|
||||
user_api_key_cache = DualCache()
|
||||
prisma_client = MagicMock()
|
||||
jwt_handler = MagicMock()
|
||||
jwt_handler.is_jwt.return_value = True
|
||||
jwt_handler.auth_jwt = AsyncMock(return_value={"sub": "user1"})
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
virtual_key_claim_field="sub",
|
||||
virtual_key_mapping_cache_ttl=300,
|
||||
)
|
||||
auto_registered_key = UserAPIKeyAuth(
|
||||
token="hashed-auto-key",
|
||||
team_id="validated-team",
|
||||
user_id="validated-user",
|
||||
org_id="validated-org",
|
||||
end_user_id="validated-end-user",
|
||||
)
|
||||
mock_jwt_result = {
|
||||
"is_proxy_admin": False,
|
||||
"team_object": None,
|
||||
"user_object": None,
|
||||
"end_user_object": None,
|
||||
"org_object": None,
|
||||
"token": jwt_token,
|
||||
"team_id": "validated-team",
|
||||
"user_id": "validated-user",
|
||||
"end_user_id": "validated-end-user",
|
||||
"org_id": "validated-org",
|
||||
"team_membership": None,
|
||||
"jwt_claims": {"sub": "user1"},
|
||||
}
|
||||
|
||||
mock_request = MagicMock()
|
||||
mock_request.url.path = "/v1/chat/completions"
|
||||
mock_request.method = "POST"
|
||||
mock_request.headers = {"authorization": f"Bearer {jwt_token}"}
|
||||
mock_request.query_params = {}
|
||||
mock_request.state = SimpleNamespace()
|
||||
|
||||
with (
|
||||
patch("litellm.proxy.proxy_server.general_settings", general_settings),
|
||||
patch("litellm.proxy.proxy_server.premium_user", True),
|
||||
patch("litellm.proxy.proxy_server.master_key", "sk-master"),
|
||||
patch("litellm.proxy.proxy_server.prisma_client", prisma_client),
|
||||
patch("litellm.proxy.proxy_server.user_api_key_cache", user_api_key_cache),
|
||||
patch("litellm.proxy.proxy_server.proxy_logging_obj", MagicMock()),
|
||||
patch("litellm.proxy.proxy_server.jwt_handler", jwt_handler),
|
||||
patch(
|
||||
"litellm.proxy.auth.user_api_key_auth._resolve_jwt_to_virtual_key",
|
||||
new_callable=AsyncMock,
|
||||
return_value=_PendingAutoRegister(
|
||||
claim_field="sub",
|
||||
claim_value="user1",
|
||||
cache_key="jwt_key_mapping:sub:user1",
|
||||
),
|
||||
),
|
||||
patch(
|
||||
"litellm.proxy.auth.user_api_key_auth.JWTAuthManager.auth_builder",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_jwt_result,
|
||||
),
|
||||
patch(
|
||||
"litellm.proxy.auth.user_api_key_auth._auto_register_jwt_mapping",
|
||||
new_callable=AsyncMock,
|
||||
return_value=auto_registered_key,
|
||||
) as mock_auto_register,
|
||||
):
|
||||
result = await _user_api_key_auth_builder(
|
||||
request=mock_request,
|
||||
api_key=jwt_token,
|
||||
azure_api_key_header="",
|
||||
anthropic_api_key_header=None,
|
||||
google_ai_studio_api_key_header=None,
|
||||
azure_apim_header=None,
|
||||
request_data={"model": "gpt-4o-mini"},
|
||||
)
|
||||
|
||||
mock_auto_register.assert_awaited_once()
|
||||
assert mock_auto_register.call_args.kwargs["team_id"] == "validated-team"
|
||||
assert mock_auto_register.call_args.kwargs["user_id"] == "validated-user"
|
||||
assert mock_auto_register.call_args.kwargs["org_id"] == "validated-org"
|
||||
assert mock_auto_register.call_args.kwargs["end_user_id"] == "validated-end-user"
|
||||
assert result.org_id == "validated-org"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routing_override_routes_matching_jwt_to_oauth2(self):
|
||||
"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user