diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index feab3b04a5..6c36c813e7 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -4448,6 +4448,25 @@ class JWTRoutingOverride(BaseModel): } +class UnregisteredJWTClientBehavior(str, enum.Enum): + """ + Controls what happens when `virtual_key_claim_field` is configured but the + JWT claim value has no registered mapping in `litellm_jwtkeymapping`. + + - fallback_team_mapping: Fall through to standard team-based JWT auth (default, + backward-compatible). + - reject: Immediately return HTTP 403. Use this when every valid JWT client + must have a pre-registered virtual key — unknown callers are denied. + - auto_register: Automatically create a new virtual key and mapping on first + encounter. The new key has no budget/model restrictions; admins can tighten + it later via /jwt_client/update. + """ + + FALLBACK_TEAM_MAPPING = "fallback_team_mapping" + REJECT = "reject" + AUTO_REGISTER = "auto_register" + + class JWTIssuerConfig(BaseModel): """ Issuer-bound JWT validation configuration. @@ -4614,6 +4633,15 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): default=300, description="TTL (seconds) for caching JWT-to-virtual-key mapping lookups.", ) + unregistered_jwt_client_behavior: UnregisteredJWTClientBehavior = Field( + default=UnregisteredJWTClientBehavior.FALLBACK_TEAM_MAPPING, + description=( + "What to do when virtual_key_claim_field is set but the JWT claim value " + "has no registered mapping. 'fallback_team_mapping' (default): fall through " + "to team-based JWT auth. 'reject': return HTTP 403. " + "'auto_register': auto-create a virtual key and mapping on first encounter." + ), + ) routing_overrides: Optional[List[JWTRoutingOverride]] = Field( default=None, description="Optional claim-based routing overrides for JWT-shaped tokens. Matching rules route requests to oauth2 before default JWT flow.", @@ -4633,6 +4661,13 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase): # ``s3://`` / ``gcs://`` when this is None. config_file_path = kwargs.pop("config_file_path", None) + # Backward-compat: jwt_client_id_field was renamed to virtual_key_claim_field + if "jwt_client_id_field" in kwargs: + if "virtual_key_claim_field" not in kwargs: + kwargs["virtual_key_claim_field"] = kwargs.pop("jwt_client_id_field") + else: + kwargs.pop("jwt_client_id_field") + # get the attribute names for this Pydantic model allowed_keys = LiteLLM_JWTAuth.__annotations__.keys() diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 9d4efbaeee..92765fc8ea 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -12,7 +12,7 @@ import fnmatch import re import secrets from datetime import datetime, timezone -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union, cast +from typing import Any, Dict, Iterator, NamedTuple, List, Optional, Tuple, Union, cast import fastapi from fastapi import HTTPException, Request, WebSocket, status @@ -607,6 +607,169 @@ async def check_api_key_for_custom_headers_or_pass_through_endpoints( return api_key +# Cache sentinel written when a JWT under AUTO_REGISTER resolved to a proxy +# admin via auth_builder. Proxy admins don't need a mapped virtual key (they +# have full access via auth_builder anyway), but without a cache entry every +# subsequent request from the same JWT identity would re-query the DB for a +# non-existent mapping. Sentinel tells _resolve_jwt_to_virtual_key to skip +# the lookup and return None (caller proceeds to auth_builder). +_JWT_PROXY_ADMIN_SENTINEL = "__JWT_PROXY_ADMIN__" + + +class _PendingAutoRegister(NamedTuple): + """ + Signal returned by ``_resolve_jwt_to_virtual_key`` when the JWT's claim is + unmapped and ``unregistered_jwt_client_behavior`` is AUTO_REGISTER. + + The caller MUST run standard ``JWTAuthManager.auth_builder`` to apply RBAC, + scope mappings, ``custom_validate``, and ``user_allowed_email_domain`` + policy BEFORE calling ``_auto_register_jwt_mapping`` with the validated + ``team_id`` / ``user_id`` from the auth_builder result. Auto-registering + purely on a signature-valid JWT (the old behavior) bypassed every JWT + policy beyond signature verification. + """ + + claim_field: str + claim_value: str + cache_key: str + + +async def _auto_register_jwt_mapping( + virtual_key_claim_field: str, + claim_value: str, + jwt_handler: JWTHandler, + prisma_client: PrismaClient, + user_api_key_cache: UserApiKeyCache, + parent_otel_span: Optional[Span], + proxy_logging_obj: ProxyLogging, + cache_key: str, + team_id: Optional[str] = None, + user_id: Optional[str] = None, + org_id: Optional[str] = None, + end_user_id: Optional[str] = None, +) -> Optional[UserAPIKeyAuth]: + """ + Auto-register: create a new virtual key + mapping for an unrecognised JWT + claim value. ``team_id`` and ``user_id`` must come from a successful + ``JWTAuthManager.auth_builder`` run — they encode the JWT identity AFTER + RBAC/scope/custom_validate/email-domain policy has been enforced. The key + is stamped with those values so the cached future-request path inherits + the same team/user/org limits the auth_builder path would have applied. + + Race safety: if two concurrent requests both reach here simultaneously (both + saw no mapping in the DB), one will win the unique-constraint race on + litellm_jwtkeymapping. The loser catches the conflict, deletes its orphaned + key, fetches the winner's mapping, and proceeds — no error surfaced. + """ + # Inline import required: key_management_endpoints imports user_api_key_auth + # (line 51) so a module-level import here would create a circular dependency. + from litellm.proxy.management_endpoints.key_management_endpoints import ( + generate_key_helper_fn, + ) + + # ``table_name="key"`` is required: without it, generate_key_helper_fn + # falls into the user-upsert branch (`table_name is None or "user"`) and + # attempts to insert into LiteLLM_UserTable with user_id=None, which fails + # the NOT NULL @id constraint. Every successful key-creation caller (e.g. + # /key/generate) passes table_name="key" explicitly. + key_data = await generate_key_helper_fn( + request_type="key", + table_name="key", + team_id=team_id, + user_id=user_id, + organization_id=org_id, + metadata={ + "auto_registered": True, + "jwt_claim_field": virtual_key_claim_field, + "jwt_claim_value": claim_value, + }, + ) + # generate_key_helper_fn returns the plaintext key in "token"; the persisted + # row in LiteLLM_VerificationToken uses its hash, so hash here to get the FK + # value referenced by LiteLLM_JWTKeyMapping.token. + token_hash = hash_token(key_data["token"]) + + try: + await prisma_client.db.litellm_jwtkeymapping.create( + data={ + "jwt_claim_name": virtual_key_claim_field, + "jwt_claim_value": claim_value, + "token": token_hash, + "created_by": "auto_register", + "updated_by": "auto_register", + } + ) + except Exception as e: + error_str = str(e).lower() + if "unique" in error_str or "p2002" in error_str: + # A concurrent request won the race. The key generate_key_helper_fn + # just persisted to LiteLLM_VerificationToken is orphaned — nothing + # maps to it, but it's a fully valid unrestricted API key sitting in + # the DB and the cleartext is in memory on this request. Delete it + # so orphans don't accumulate under sustained concurrency. + verbose_proxy_logger.debug( + "JWT Key Mapping (auto_register): unique conflict on create — " + "deleting orphaned virtual key and fetching winner's mapping for %s='%s'.", + virtual_key_claim_field, + claim_value, + ) + try: + await prisma_client.db.litellm_verificationtoken.delete( + where={"token": token_hash} + ) + except Exception as delete_err: + # Don't fail the request if cleanup fails — the orphan is + # unmapped and inert. Log so an operator can prune it later. + verbose_proxy_logger.warning( + "JWT Key Mapping (auto_register): failed to delete orphaned key after race: %s", + delete_err, + ) + token_hash = await get_jwt_key_mapping_object( + jwt_claim_name=virtual_key_claim_field, + jwt_claim_value=claim_value, + prisma_client=prisma_client, + ) + if token_hash is None: + # The winner's mapping vanished between the unique-constraint + # conflict and our re-fetch (concurrent delete). Returning None + # here would silently fall through to team-based JWT auth — + # a less-restrictive path than the operator configured. Raise + # 503 so the caller retries against a stable state instead. + raise HTTPException( + status_code=503, + detail=( + "JWT Key Mapping: AUTO_REGISTER race resolution failed — " + "winner's mapping was concurrently removed. Retry the request." + ), + ) + else: + raise + + await user_api_key_cache.async_set_cache( + key=cache_key, + value=token_hash, + ttl=jwt_handler.litellm_jwtauth.virtual_key_mapping_cache_ttl, + ) + + verbose_proxy_logger.info( + "JWT Key Mapping (auto_register): created new virtual key for %s='%s'.", + virtual_key_claim_field, + claim_value, + ) + + auto_registered_key = await get_key_object( + hashed_token=token_hash, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + ) + if auto_registered_key is not None: + auto_registered_key.org_id = org_id + auto_registered_key.end_user_id = end_user_id + return auto_registered_key + + async def _resolve_jwt_to_virtual_key( jwt_claims: dict, jwt_handler: JWTHandler, @@ -614,7 +777,22 @@ async def _resolve_jwt_to_virtual_key( user_api_key_cache: UserApiKeyCache, parent_otel_span: Optional[Span], proxy_logging_obj: ProxyLogging, -) -> Optional[UserAPIKeyAuth]: +) -> Union[Optional[UserAPIKeyAuth], "_PendingAutoRegister"]: + """ + Returns: + - ``UserAPIKeyAuth``: a resolved virtual key (cache hit or DB hit). The + caller may use this directly; JWT policy has been enforced previously + (at key-creation time or, for cached results, before caching). + - ``_PendingAutoRegister``: claim is unmapped and behavior is AUTO_REGISTER. + The caller MUST run ``JWTAuthManager.auth_builder`` to enforce JWT + policy (RBAC, scope, custom_validate, email-domain), then invoke + ``_auto_register_jwt_mapping`` with the validated team_id/user_id. + - ``None``: claim is unmapped and behavior is FALLBACK_TEAM_MAPPING. + The caller falls through to standard team-based JWT auth (which itself + enforces full JWT policy via auth_builder). + - Raises HTTPException: REJECT policy hit, missing claim under + REJECT/AUTO_REGISTER, or other policy violations. + """ virtual_key_claim_field = jwt_handler.litellm_jwtauth.virtual_key_claim_field if virtual_key_claim_field is None: return None @@ -629,12 +807,61 @@ async def _resolve_jwt_to_virtual_key( verbose_proxy_logger.debug( f"JWT Key Mapping: Claim field '{virtual_key_claim_field}' not found in JWT claims." ) + # A missing claim is an unmapped client — apply the no-match policy + # rather than returning early. Otherwise a caller can bypass REJECT + # simply by presenting a JWT that omits the configured field. For + # AUTO_REGISTER there is no stable identity to map without a claim + # value, so we deny rather than create a sentinel-keyed record. + behavior = jwt_handler.litellm_jwtauth.unregistered_jwt_client_behavior + if behavior in ( + UnregisteredJWTClientBehavior.REJECT, + UnregisteredJWTClientBehavior.AUTO_REGISTER, + ): + raise HTTPException( + status_code=403, + detail=( + f"JWT Key Mapping: Required claim '{virtual_key_claim_field}' " + "is missing from the JWT. Access denied." + ), + ) return None cache_key = f"jwt_key_mapping:{virtual_key_claim_field}:{claim_value}" cached_mapping = await user_api_key_cache.async_get_cache(cache_key) + if cached_mapping == _JWT_PROXY_ADMIN_SENTINEL: + # Previously resolved to a proxy admin via auth_builder; skip the + # mapping lookup and let the caller re-run auth_builder. Avoids a + # repeated DB hit on every proxy-admin request under AUTO_REGISTER. + return None + if cached_mapping == "__NO_MAPPING__": + behavior = jwt_handler.litellm_jwtauth.unregistered_jwt_client_behavior + if behavior == UnregisteredJWTClientBehavior.REJECT: + raise HTTPException( + status_code=403, + detail=f"JWT Key Mapping: No registered mapping for {virtual_key_claim_field}='{claim_value}'. Access denied.", + ) + if behavior == UnregisteredJWTClientBehavior.AUTO_REGISTER: + # Stale sentinel written under a prior fallback_team_mapping config — + # evict it and defer auto-register to after auth_builder runs. Raise + # the same 500 as the fresh-path AUTO_REGISTER branch when there is + # no DB, so behavior is consistent regardless of whether the cache + # happens to hold the sentinel. + if prisma_client is None: + raise HTTPException( + status_code=500, + detail=( + "JWT Key Mapping: AUTO_REGISTER requires a database connection. " + "Configure a database or change unregistered_jwt_client_behavior." + ), + ) + await user_api_key_cache.async_delete_cache(cache_key) + return _PendingAutoRegister( + claim_field=virtual_key_claim_field, + claim_value=str(claim_value), + cache_key=cache_key, + ) return None elif cached_mapping is not None: return await get_key_object( @@ -645,14 +872,15 @@ async def _resolve_jwt_to_virtual_key( proxy_logging_obj=proxy_logging_obj, ) - if prisma_client is None: - return None - - token_hash = await get_jwt_key_mapping_object( - jwt_claim_name=virtual_key_claim_field, - jwt_claim_value=str(claim_value), - prisma_client=prisma_client, - ) + # Resolve the mapping from DB, or treat prisma_client=None as a definitive + # miss (no DB → no mapping can exist → apply no-match policy below). + token_hash: Optional[str] = None + if prisma_client is not None: + token_hash = await get_jwt_key_mapping_object( + jwt_claim_name=virtual_key_claim_field, + jwt_claim_value=str(claim_value), + prisma_client=prisma_client, + ) if token_hash is not None: await user_api_key_cache.async_set_cache( @@ -667,13 +895,50 @@ async def _resolve_jwt_to_virtual_key( parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, ) - else: + + # No mapping found (DB miss or no DB) — apply no-match policy. + behavior = jwt_handler.litellm_jwtauth.unregistered_jwt_client_behavior + + if behavior == UnregisteredJWTClientBehavior.REJECT: + # Cache the miss before raising so repeated rejections are served from + # cache and don't re-query the DB on every request. await user_api_key_cache.async_set_cache( key=cache_key, value="__NO_MAPPING__", ttl=jwt_handler.litellm_jwtauth.virtual_key_mapping_cache_ttl, ) - return None + raise HTTPException( + status_code=403, + detail=f"JWT Key Mapping: No registered mapping for {virtual_key_claim_field}='{claim_value}'. Access denied.", + ) + + if behavior == UnregisteredJWTClientBehavior.AUTO_REGISTER: + if prisma_client is None: + raise HTTPException( + status_code=500, + detail=( + "JWT Key Mapping: AUTO_REGISTER requires a database connection. " + "Configure a database or change unregistered_jwt_client_behavior." + ), + ) + # Defer: caller runs JWTAuthManager.auth_builder to enforce RBAC, scope, + # custom_validate, and email-domain policy, then auto-registers using + # the validated identity. Auto-registering here on a signature-only + # JWT would bypass every JWT policy beyond signature verification. + return _PendingAutoRegister( + claim_field=virtual_key_claim_field, + claim_value=str(claim_value), + cache_key=cache_key, + ) + + # FALLBACK_TEAM_MAPPING (default): cache the miss and return None so the + # caller falls through to standard team-based JWT auth. + await user_api_key_cache.async_set_cache( + key=cache_key, + value="__NO_MAPPING__", + ttl=jwt_handler.litellm_jwtauth.virtual_key_mapping_cache_ttl, + ) + return None def _ensure_parent_otel_span_on_request_state(request: Request) -> None: @@ -893,6 +1158,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 # Try JWT-to-Virtual-Key mapping first to avoid # unnecessary DB queries in auth_builder do_standard_jwt_auth = True + pending_auto_register: Optional[_PendingAutoRegister] = None if jwt_handler.litellm_jwtauth.virtual_key_claim_field is not None: # Decode JWT to get claims without running full auth_builder jwt_claims: Optional[dict] @@ -901,7 +1167,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 else: jwt_claims = await jwt_handler.auth_jwt(token=api_key) - valid_token = await _resolve_jwt_to_virtual_key( + resolve_result = await _resolve_jwt_to_virtual_key( jwt_claims=jwt_claims, jwt_handler=jwt_handler, prisma_client=prisma_client, @@ -909,11 +1175,19 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 parent_otel_span=parent_otel_span, proxy_logging_obj=proxy_logging_obj, ) - if valid_token is not None: + if isinstance(resolve_result, UserAPIKeyAuth): + valid_token = resolve_result api_key = valid_token.token or "" valid_token.jwt_claims = jwt_claims do_standard_jwt_auth = False # Fall through to virtual key checks + elif isinstance(resolve_result, _PendingAutoRegister): + # Run full JWT policy (RBAC, scope, custom_validate, + # email-domain) via auth_builder, then create the key + # from the validated identity below. + pending_auto_register = resolve_result + # else: None → FALLBACK_TEAM_MAPPING, falls through to + # standard JWT auth_builder below if do_standard_jwt_auth: with tracer.trace("litellm.proxy.auth.jwt_auth_builder"): @@ -946,6 +1220,19 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 jwt_claims = result.get("jwt_claims", None) if is_proxy_admin: + # Proxy admins authenticate via auth_builder (full + # access), not via a mapped virtual key. If + # AUTO_REGISTER was pending, cache a sentinel so + # future requests from this JWT identity skip the + # DB mapping lookup in _resolve_jwt_to_virtual_key. + # Without this, every proxy-admin request under + # AUTO_REGISTER re-hits get_jwt_key_mapping_object. + if pending_auto_register is not None: + await user_api_key_cache.async_set_cache( + key=pending_auto_register.cache_key, + value=_JWT_PROXY_ADMIN_SENTINEL, + ttl=jwt_handler.litellm_jwtauth.virtual_key_mapping_cache_ttl, + ) return UserAPIKeyAuth( api_key=None, user_role=LitellmUserRoles.PROXY_ADMIN, @@ -1032,6 +1319,32 @@ async def _user_api_key_auth_builder( # noqa: PLR0915 else None ) + # AUTO_REGISTER deferred from _resolve_jwt_to_virtual_key. + # JWT policy (RBAC, scope, custom_validate, email-domain) + # has now been enforced by auth_builder above. Create the + # mapping + virtual key from the *validated* identity, then + # replace valid_token with the new key so downstream checks + # use the key-scoped path. + if pending_auto_register is not None and prisma_client is not None: + auto_registered = await _auto_register_jwt_mapping( + virtual_key_claim_field=pending_auto_register.claim_field, + claim_value=pending_auto_register.claim_value, + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, + cache_key=pending_auto_register.cache_key, + team_id=team_id, + user_id=user_id, + org_id=org_id, + end_user_id=end_user_id, + ) + if auto_registered is not None: + auto_registered.jwt_claims = jwt_claims + valid_token = auto_registered + api_key = valid_token.token or "" + # Check if model has zero cost - if so, skip all budget checks model = _get_model_from_request_context( request_data=request_data, diff --git a/tests/proxy_unit_tests/test_jwt_key_mapping.py b/tests/proxy_unit_tests/test_jwt_key_mapping.py index bf1c4a3f6c..61c2418396 100644 --- a/tests/proxy_unit_tests/test_jwt_key_mapping.py +++ b/tests/proxy_unit_tests/test_jwt_key_mapping.py @@ -27,7 +27,6 @@ from litellm.proxy.management_endpoints.jwt_key_mapping_endpoints import ( from litellm.caching.caching import DualCache from fastapi import HTTPException - # ────────────────────────────────────────────── # Tests: _resolve_jwt_to_virtual_key # ────────────────────────────────────────────── @@ -454,3 +453,856 @@ async def test_create_success_returns_response_without_token(): assert isinstance(result, JWTKeyMappingResponse) assert "token" not in result.model_fields assert result.jwt_claim_name == "email" + + +# ────────────────────────────────────────────── +# Tests: unregistered_jwt_client_behavior +# ────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_reject_behavior_raises_403_on_no_mapping(): + """ + When unregistered_jwt_client_behavior='reject' and no mapping exists, + _resolve_jwt_to_virtual_key must raise HTTP 403. + """ + from litellm.proxy._types import UnregisteredJWTClientBehavior + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="email", + unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT, + ) + jwt_claims = {"email": "unknown@example.com"} + + prisma_client = MagicMock() + prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None) + + user_api_key_cache = DualCache() + + with patch( + "litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock + ): + with pytest.raises(HTTPException) as exc_info: + await _resolve_jwt_to_virtual_key( + jwt_claims=jwt_claims, + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=None, + ) + assert exc_info.value.status_code == 403 + assert "unknown@example.com" in exc_info.value.detail + + +@pytest.mark.asyncio +async def test_reject_behavior_caches_sentinel_after_db_miss(): + """ + On a fresh DB miss with REJECT, the __NO_MAPPING__ sentinel must be written + to cache so that subsequent rejected requests are served from cache and do + not re-query the DB. + """ + from litellm.proxy._types import UnregisteredJWTClientBehavior + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="email", + unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT, + virtual_key_mapping_cache_ttl=300, + ) + jwt_claims = {"email": "unknown@example.com"} + + prisma_client = MagicMock() + prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None) + + user_api_key_cache = DualCache() + + with patch( + "litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock + ): + # First call — DB miss, should raise 403 and write sentinel + with pytest.raises(HTTPException) as exc_info: + await _resolve_jwt_to_virtual_key( + jwt_claims=jwt_claims, + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=None, + ) + assert exc_info.value.status_code == 403 + + # Sentinel must now be in cache + cached = await user_api_key_cache.async_get_cache( + "jwt_key_mapping:email:unknown@example.com" + ) + assert cached == "__NO_MAPPING__" + + # Second call — must raise 403 from cache, no additional DB hit + prisma_client.db.litellm_jwtkeymapping.find_first.reset_mock() + with pytest.raises(HTTPException) as exc_info2: + await _resolve_jwt_to_virtual_key( + jwt_claims=jwt_claims, + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=None, + ) + assert exc_info2.value.status_code == 403 + prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called() + + +@pytest.mark.asyncio +async def test_reject_behavior_raises_403_on_cached_no_mapping(): + """ + When the negative-cache sentinel __NO_MAPPING__ is present and behavior is + 'reject', the function must also raise HTTP 403 (not return None silently). + """ + from litellm.proxy._types import UnregisteredJWTClientBehavior + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="email", + unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT, + ) + jwt_claims = {"email": "unknown@example.com"} + + prisma_client = MagicMock() + prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None) + + # Pre-populate the negative cache so the DB is not hit + user_api_key_cache = DualCache() + cache_key = "jwt_key_mapping:email:unknown@example.com" + await user_api_key_cache.async_set_cache(cache_key, "__NO_MAPPING__") + + with patch( + "litellm.proxy.auth.user_api_key_auth.get_key_object", new_callable=AsyncMock + ): + with pytest.raises(HTTPException) as exc_info: + await _resolve_jwt_to_virtual_key( + jwt_claims=jwt_claims, + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=None, + ) + assert exc_info.value.status_code == 403 + # DB must NOT have been hit (sentinel served from cache) + prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called() + + +@pytest.mark.asyncio +async def test_auto_register_returns_pending_signal_without_creating_key(): + """ + Security: when unregistered_jwt_client_behavior='auto_register' and no + mapping exists, _resolve_jwt_to_virtual_key must NOT create the key yet. + It returns a _PendingAutoRegister signal so the caller can run + JWTAuthManager.auth_builder (enforcing RBAC, scope mappings, + custom_validate, user_allowed_email_domain) FIRST. Creating the key here + would bypass every JWT policy beyond signature verification. + """ + from litellm.proxy._types import UnregisteredJWTClientBehavior + from litellm.proxy.auth.user_api_key_auth import _PendingAutoRegister + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="sub", + unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER, + virtual_key_mapping_cache_ttl=300, + ) + jwt_claims = {"sub": "new-user-42"} + + prisma_client = MagicMock() + prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None) + prisma_client.db.litellm_jwtkeymapping.create = AsyncMock() + + user_api_key_cache = DualCache() + + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn", + new_callable=AsyncMock, + ) as mock_gen_key: + result = await _resolve_jwt_to_virtual_key( + jwt_claims=jwt_claims, + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=None, + ) + + assert isinstance(result, _PendingAutoRegister) + assert result.claim_field == "sub" + assert result.claim_value == "new-user-42" + assert result.cache_key == "jwt_key_mapping:sub:new-user-42" + # CRITICAL: no key was created — that must wait until after auth_builder + mock_gen_key.assert_not_called() + prisma_client.db.litellm_jwtkeymapping.create.assert_not_called() + + +@pytest.mark.asyncio +async def test_auto_register_creates_key_and_mapping_when_helper_invoked(): + """ + When the caller invokes _auto_register_jwt_mapping directly (after + auth_builder validation), the helper creates the key + mapping row and + returns a UserAPIKeyAuth. The mapping row stores the hashed token (FK to + LiteLLM_VerificationToken), not the plaintext key. + """ + from litellm.proxy._types import hash_token + from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="sub", + virtual_key_mapping_cache_ttl=300, + ) + + prisma_client = MagicMock() + prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None) + prisma_client.db.litellm_jwtkeymapping.create = AsyncMock() + + user_api_key_cache = DualCache() + plaintext_key = "sk-auto-key" + expected_hash = hash_token(plaintext_key) + mock_key_obj = UserAPIKeyAuth(token=expected_hash, team_id="validated-team") + + with ( + patch( + "litellm.proxy.auth.user_api_key_auth.get_key_object", + new_callable=AsyncMock, + ) as mock_get_key, + patch( + "litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn", + new_callable=AsyncMock, + ) as mock_gen_key, + ): + mock_gen_key.return_value = {"token": plaintext_key, "key": plaintext_key} + mock_get_key.return_value = mock_key_obj + + result = await _auto_register_jwt_mapping( + virtual_key_claim_field="sub", + claim_value="new-user-42", + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=None, + cache_key="jwt_key_mapping:sub:new-user-42", + team_id="validated-team", + user_id="validated-user", + ) + + assert result == mock_key_obj + # generate_key_helper_fn was passed table_name="key" (not user-upsert path) + # and the validated team_id + user_id from auth_builder + assert mock_gen_key.call_args.kwargs["table_name"] == "key" + assert mock_gen_key.call_args.kwargs["team_id"] == "validated-team" + assert mock_gen_key.call_args.kwargs["user_id"] == "validated-user" + # Mapping row was created with the hashed token (FK target) + call_data = prisma_client.db.litellm_jwtkeymapping.create.call_args[1]["data"] + assert call_data["jwt_claim_name"] == "sub" + assert call_data["jwt_claim_value"] == "new-user-42" + assert call_data["token"] == expected_hash + cached = await user_api_key_cache.async_get_cache("jwt_key_mapping:sub:new-user-42") + assert cached == expected_hash + + +@pytest.mark.asyncio +async def test_auto_register_returns_pending_signal_on_stale_no_mapping_sentinel(): + """ + If the cache holds a stale __NO_MAPPING__ sentinel (written under a prior + fallback_team_mapping config) and behavior is now AUTO_REGISTER, the + resolver must evict the sentinel and return _PendingAutoRegister (so the + caller can run auth_builder before creating the key) — not silently return + None and not create the key on the spot. + """ + from litellm.proxy._types import UnregisteredJWTClientBehavior + from litellm.proxy.auth.user_api_key_auth import _PendingAutoRegister + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="email", + unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER, + virtual_key_mapping_cache_ttl=300, + ) + jwt_claims = {"email": "alice@corp.com"} + + prisma_client = MagicMock() + prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None) + prisma_client.db.litellm_jwtkeymapping.create = AsyncMock() + + user_api_key_cache = DualCache() + await user_api_key_cache.async_set_cache( + "jwt_key_mapping:email:alice@corp.com", "__NO_MAPPING__" + ) + + with patch( + "litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn", + new_callable=AsyncMock, + ) as mock_gen_key: + result = await _resolve_jwt_to_virtual_key( + jwt_claims=jwt_claims, + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=None, + ) + + assert isinstance(result, _PendingAutoRegister) + # Stale sentinel must be evicted so the deferred auto-register actually + # runs after auth_builder validates the JWT + cached_after = await user_api_key_cache.async_get_cache( + "jwt_key_mapping:email:alice@corp.com" + ) + assert cached_after is None + mock_gen_key.assert_not_called() + prisma_client.db.litellm_jwtkeymapping.create.assert_not_called() + + +@pytest.mark.asyncio +async def test_auto_register_race_condition_unique_conflict(): + """ + If two concurrent requests both call _auto_register_jwt_mapping and the + second hits a unique-constraint violation on create, it must: + 1) delete the orphaned virtual key it just created (so orphans don't + accumulate in LiteLLM_VerificationToken under sustained concurrency), + 2) fall back to the winner's mapping, + 3) not surface an error. + """ + from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping + from litellm.proxy._types import UnregisteredJWTClientBehavior, hash_token + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="sub", + unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER, + virtual_key_mapping_cache_ttl=300, + ) + + prisma_client = MagicMock() + prisma_client.db.litellm_jwtkeymapping.create = AsyncMock( + side_effect=Exception("Unique constraint failed (P2002)") + ) + prisma_client.db.litellm_verificationtoken.delete = AsyncMock() + # Simulate the winner's mapping already in DB after the conflict + winner_mapping = MagicMock() + winner_mapping.token = "winner_token_hash" + winner_mapping.is_active = True + prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock( + return_value=winner_mapping + ) + + user_api_key_cache = DualCache() + loser_plaintext = "sk-loser" + loser_hash = hash_token(loser_plaintext) + mock_key_obj = UserAPIKeyAuth(token="winner_token_hash", team_id=None) + + with ( + patch( + "litellm.proxy.auth.user_api_key_auth.get_key_object", + new_callable=AsyncMock, + ) as mock_get_key, + patch( + "litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn", + new_callable=AsyncMock, + return_value={"token": loser_plaintext, "key": loser_plaintext}, + ), + ): + mock_get_key.return_value = mock_key_obj + + result = await _auto_register_jwt_mapping( + virtual_key_claim_field="sub", + claim_value="user-42", + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=None, + cache_key="jwt_key_mapping:sub:user-42", + ) + + assert result == mock_key_obj + # The orphaned loser key must be deleted from LiteLLM_VerificationToken + prisma_client.db.litellm_verificationtoken.delete.assert_called_once_with( + where={"token": loser_hash} + ) + # Cache should hold the winner's token, not the loser's + cached = await user_api_key_cache.async_get_cache("jwt_key_mapping:sub:user-42") + assert cached == "winner_token_hash" + mock_get_key.assert_called_once_with( + hashed_token="winner_token_hash", + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=None, + ) + + +# ────────────────────────────────────────────── +# Tests: prisma_client=None does not bypass no-match policy +# ────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_reject_behavior_enforced_when_prisma_client_is_none(): + """ + When prisma_client is None and behavior is REJECT, a 403 must be raised — + not silently fallen through to team auth. + """ + from litellm.proxy._types import UnregisteredJWTClientBehavior + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="email", + unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT, + ) + jwt_claims = {"email": "unknown@example.com"} + + user_api_key_cache = DualCache() + + with pytest.raises(HTTPException) as exc_info: + await _resolve_jwt_to_virtual_key( + jwt_claims=jwt_claims, + jwt_handler=jwt_handler, + prisma_client=None, # no DB + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=None, + ) + assert exc_info.value.status_code == 403 + assert "unknown@example.com" in exc_info.value.detail + + +@pytest.mark.asyncio +async def test_reject_raises_403_when_claim_field_missing_from_jwt(): + """ + Security: a JWT that omits the configured virtual_key_claim_field must NOT + bypass the REJECT policy. Previously the early `if claim_value is None: + return None` branch ran before the policy check, letting a caller who knows + the configured claim-field name silently fall through to team-based auth. + """ + from litellm.proxy._types import UnregisteredJWTClientBehavior + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="sub", + unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.REJECT, + ) + # JWT does NOT contain "sub" + jwt_claims = {"email": "user@example.com"} + + with pytest.raises(HTTPException) as exc_info: + await _resolve_jwt_to_virtual_key( + jwt_claims=jwt_claims, + jwt_handler=jwt_handler, + prisma_client=MagicMock(), + user_api_key_cache=DualCache(), + parent_otel_span=None, + proxy_logging_obj=None, + ) + assert exc_info.value.status_code == 403 + assert "'sub'" in exc_info.value.detail + assert "missing from the JWT" in exc_info.value.detail + + +@pytest.mark.asyncio +async def test_auto_register_raises_403_when_claim_field_missing_from_jwt(): + """ + AUTO_REGISTER cannot create a mapping without a stable identity. When the + configured claim field is missing from the JWT, return 403 rather than + silently falling through (which would bypass the unregistered-client policy) + or creating a sentinel-keyed record. + """ + from litellm.proxy._types import UnregisteredJWTClientBehavior + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="sub", + unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER, + ) + jwt_claims = {"email": "user@example.com"} + + with pytest.raises(HTTPException) as exc_info: + await _resolve_jwt_to_virtual_key( + jwt_claims=jwt_claims, + jwt_handler=jwt_handler, + prisma_client=MagicMock(), + user_api_key_cache=DualCache(), + parent_otel_span=None, + proxy_logging_obj=None, + ) + assert exc_info.value.status_code == 403 + assert "missing from the JWT" in exc_info.value.detail + + +@pytest.mark.asyncio +async def test_fallback_team_mapping_returns_none_when_claim_field_missing_from_jwt(): + """ + Under FALLBACK_TEAM_MAPPING (the default, backward-compatible mode), a JWT + without the configured claim field must still fall through to team-based + JWT auth — not raise. This preserves the pre-existing contract. + """ + from litellm.proxy._types import UnregisteredJWTClientBehavior + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="sub", + unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.FALLBACK_TEAM_MAPPING, + ) + jwt_claims = {"email": "user@example.com"} + + result = await _resolve_jwt_to_virtual_key( + jwt_claims=jwt_claims, + jwt_handler=jwt_handler, + prisma_client=MagicMock(), + user_api_key_cache=DualCache(), + parent_otel_span=None, + proxy_logging_obj=None, + ) + assert result is None + + +@pytest.mark.asyncio +async def test_fallback_team_mapping_returns_none_when_prisma_client_is_none(): + """ + When prisma_client is None and behavior is FALLBACK_TEAM_MAPPING, the + function must return None (fall through to team auth) — not raise. + """ + from litellm.proxy._types import UnregisteredJWTClientBehavior + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="email", + unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.FALLBACK_TEAM_MAPPING, + ) + jwt_claims = {"email": "anyone@example.com"} + + result = await _resolve_jwt_to_virtual_key( + jwt_claims=jwt_claims, + jwt_handler=jwt_handler, + prisma_client=None, + user_api_key_cache=DualCache(), + parent_otel_span=None, + proxy_logging_obj=None, + ) + assert result is None + + +@pytest.mark.asyncio +async def test_auto_register_raises_500_when_prisma_client_is_none(): + """ + AUTO_REGISTER without a DB connection must raise HTTP 500 with a clear + message — it cannot create keys without a database. + """ + from litellm.proxy._types import UnregisteredJWTClientBehavior + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="sub", + unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER, + ) + jwt_claims = {"sub": "new-user-42"} + + with pytest.raises(HTTPException) as exc_info: + await _resolve_jwt_to_virtual_key( + jwt_claims=jwt_claims, + jwt_handler=jwt_handler, + prisma_client=None, + user_api_key_cache=DualCache(), + parent_otel_span=None, + proxy_logging_obj=None, + ) + assert exc_info.value.status_code == 500 + assert "AUTO_REGISTER requires a database" in exc_info.value.detail + + +@pytest.mark.asyncio +async def test_auto_register_raises_500_when_sentinel_cached_and_no_db(): + """ + AUTO_REGISTER + cached __NO_MAPPING__ sentinel + prisma_client is None must + raise HTTP 500, matching the fresh-path behavior. Previously this path + silently returned None and let the request fall through to team auth, + creating different access-control outcomes under identical configuration. + """ + from litellm.proxy._types import UnregisteredJWTClientBehavior + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="sub", + unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER, + virtual_key_mapping_cache_ttl=300, + ) + jwt_claims = {"sub": "user-42"} + + user_api_key_cache = DualCache() + # Stale sentinel written under a prior fallback_team_mapping config + await user_api_key_cache.async_set_cache( + "jwt_key_mapping:sub:user-42", "__NO_MAPPING__" + ) + + with pytest.raises(HTTPException) as exc_info: + await _resolve_jwt_to_virtual_key( + jwt_claims=jwt_claims, + jwt_handler=jwt_handler, + prisma_client=None, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=None, + ) + assert exc_info.value.status_code == 500 + assert "AUTO_REGISTER requires a database" in exc_info.value.detail + + +@pytest.mark.asyncio +async def test_auto_register_race_conflict_tolerates_delete_failure(): + """ + If deleting the orphaned virtual key after a race-condition conflict fails + (e.g. transient DB error), the request must still succeed by returning the + winner's mapping — the orphan is unmapped and inert. + """ + from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping + from litellm.proxy._types import UnregisteredJWTClientBehavior + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="sub", + unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER, + virtual_key_mapping_cache_ttl=300, + ) + + prisma_client = MagicMock() + prisma_client.db.litellm_jwtkeymapping.create = AsyncMock( + side_effect=Exception("Unique constraint failed (P2002)") + ) + prisma_client.db.litellm_verificationtoken.delete = AsyncMock( + side_effect=Exception("transient DB error") + ) + winner_mapping = MagicMock() + winner_mapping.token = "winner_token_hash" + winner_mapping.is_active = True + prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock( + return_value=winner_mapping + ) + + user_api_key_cache = DualCache() + mock_key_obj = UserAPIKeyAuth(token="winner_token_hash", team_id=None) + + with ( + patch( + "litellm.proxy.auth.user_api_key_auth.get_key_object", + new_callable=AsyncMock, + ) as mock_get_key, + patch( + "litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn", + new_callable=AsyncMock, + return_value={"token": "sk-loser", "key": "sk-loser"}, + ), + ): + mock_get_key.return_value = mock_key_obj + + result = await _auto_register_jwt_mapping( + virtual_key_claim_field="sub", + claim_value="user-42", + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=None, + cache_key="jwt_key_mapping:sub:user-42", + ) + + # Caller still receives the winner's mapping even when cleanup fails + assert result == mock_key_obj + prisma_client.db.litellm_verificationtoken.delete.assert_called_once() + + +@pytest.mark.asyncio +async def test_auto_register_raises_503_when_winner_mapping_vanishes(): + """ + Race edge case: this request loses the unique-constraint race, deletes its + orphan, then refetches the winner's mapping — but the winner's row was + concurrently deleted. Previously this returned None, silently falling + through to less-restrictive team-based JWT auth (bypassing the configured + AUTO_REGISTER policy). Must now raise HTTP 503 so the caller retries + rather than getting unintended fallback access. + """ + from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping + from litellm.proxy._types import UnregisteredJWTClientBehavior + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="sub", + unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER, + virtual_key_mapping_cache_ttl=300, + ) + + prisma_client = MagicMock() + prisma_client.db.litellm_jwtkeymapping.create = AsyncMock( + side_effect=Exception("Unique constraint failed (P2002)") + ) + prisma_client.db.litellm_verificationtoken.delete = AsyncMock() + # Winner row no longer exists by the time we refetch + prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock(return_value=None) + + user_api_key_cache = DualCache() + + with ( + patch( + "litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn", + new_callable=AsyncMock, + return_value={"token": "sk-loser", "key": "sk-loser"}, + ), + pytest.raises(HTTPException) as exc_info, + ): + await _auto_register_jwt_mapping( + virtual_key_claim_field="sub", + claim_value="user-42", + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=None, + cache_key="jwt_key_mapping:sub:user-42", + ) + + assert exc_info.value.status_code == 503 + assert "concurrently removed" in exc_info.value.detail + + +@pytest.mark.asyncio +async def test_proxy_admin_sentinel_skips_db_lookup_on_cache_hit(): + """ + When the cache holds the proxy-admin sentinel (written after a prior + request's is_proxy_admin early-return), _resolve_jwt_to_virtual_key must + return None *without* hitting the DB. Caller proceeds to auth_builder. + + Without this, every subsequent proxy-admin request under AUTO_REGISTER + would re-query get_jwt_key_mapping_object — a cache-miss regression + introduced by the deferred-auto-register refactor. + """ + from litellm.proxy._types import UnregisteredJWTClientBehavior + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="sub", + unregistered_jwt_client_behavior=UnregisteredJWTClientBehavior.AUTO_REGISTER, + virtual_key_mapping_cache_ttl=300, + ) + jwt_claims = {"sub": "admin-user"} + + prisma_client = MagicMock() + # Will fail the test if accessed — proves the sentinel short-circuits DB + prisma_client.db.litellm_jwtkeymapping.find_first = AsyncMock( + side_effect=AssertionError("DB must not be hit when sentinel is cached") + ) + + user_api_key_cache = DualCache() + await user_api_key_cache.async_set_cache( + "jwt_key_mapping:sub:admin-user", "__JWT_PROXY_ADMIN__" + ) + + result = await _resolve_jwt_to_virtual_key( + jwt_claims=jwt_claims, + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + parent_otel_span=None, + proxy_logging_obj=None, + ) + + assert result is None + prisma_client.db.litellm_jwtkeymapping.find_first.assert_not_called() + + +# ────────────────────────────────────────────── +# Tests: AUTO_REGISTER stamps validated identity from auth_builder +# ────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_auto_register_helper_stamps_validated_identity_context(): + """ + The deferred-auto-register contract: _auto_register_jwt_mapping is called + with identity fields from JWTAuthManager.auth_builder's *validated* + result (after RBAC, scope mappings, custom_validate, email-domain policy). + These must be passed to generate_key_helper_fn so the created key carries + them — the cached future-request path then inherits the same team/user/org + limits the auth_builder path would have applied. + """ + from litellm.proxy.auth.user_api_key_auth import _auto_register_jwt_mapping + + jwt_handler = JWTHandler() + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="sub", + virtual_key_mapping_cache_ttl=300, + ) + + prisma_client = MagicMock() + prisma_client.db.litellm_jwtkeymapping.create = AsyncMock() + mock_key_obj = UserAPIKeyAuth( + token="hashed", team_id="validated-team", user_id="validated-user" + ) + + with ( + patch( + "litellm.proxy.auth.user_api_key_auth.get_key_object", + new_callable=AsyncMock, + ) as mock_get_key, + patch( + "litellm.proxy.management_endpoints.key_management_endpoints.generate_key_helper_fn", + new_callable=AsyncMock, + ) as mock_gen_key, + ): + mock_gen_key.return_value = {"token": "sk-newkey", "key": "sk-newkey"} + mock_get_key.return_value = mock_key_obj + + result = await _auto_register_jwt_mapping( + virtual_key_claim_field="sub", + claim_value="new-user", + jwt_handler=jwt_handler, + prisma_client=prisma_client, + user_api_key_cache=DualCache(), + parent_otel_span=None, + proxy_logging_obj=None, + cache_key="jwt_key_mapping:sub:new-user", + team_id="validated-team", + user_id="validated-user", + org_id="validated-org", + end_user_id="validated-end-user", + ) + + assert result == mock_key_obj + assert mock_gen_key.call_args.kwargs["team_id"] == "validated-team" + assert mock_gen_key.call_args.kwargs["user_id"] == "validated-user" + assert mock_gen_key.call_args.kwargs["organization_id"] == "validated-org" + assert result.org_id == "validated-org" + assert result.end_user_id == "validated-end-user" + + +# ────────────────────────────────────────────── +# Tests: backward-compat alias jwt_client_id_field +# ────────────────────────────────────────────── + + +def test_jwt_client_id_field_alias_maps_to_virtual_key_claim_field(): + """ + jwt_client_id_field (old doc name) must silently alias to virtual_key_claim_field. + """ + auth = LiteLLM_JWTAuth(jwt_client_id_field="azp") + assert auth.virtual_key_claim_field == "azp" + + +def test_jwt_client_id_field_does_not_raise_on_duplicate(): + """ + If both jwt_client_id_field and virtual_key_claim_field are supplied, + virtual_key_claim_field takes precedence and no error is raised. + """ + auth = LiteLLM_JWTAuth( + jwt_client_id_field="old_field", + virtual_key_claim_field="new_field", + ) + assert auth.virtual_key_claim_field == "new_field" diff --git a/tests/test_litellm/proxy/auth/test_user_api_key_auth.py b/tests/test_litellm/proxy/auth/test_user_api_key_auth.py index a3452ac802..fa6cc8bed1 100644 --- a/tests/test_litellm/proxy/auth/test_user_api_key_auth.py +++ b/tests/test_litellm/proxy/auth/test_user_api_key_auth.py @@ -31,12 +31,14 @@ from litellm.proxy.auth.handle_jwt import JWTHandler from litellm.proxy.auth.auth_checks import get_key_object, _cache_key_object from litellm.proxy.auth.route_checks import RouteChecks from litellm.proxy.auth.user_api_key_auth import ( + _PendingAutoRegister, _matches_routing_override, _reserve_budget_after_common_checks, _route_requires_auth_despite_public, _routing_selector_matches_claim, _run_centralized_common_checks, _run_post_custom_auth_checks, + _user_api_key_auth_builder, get_api_key, user_api_key_auth, ) @@ -1550,6 +1552,93 @@ class TestJWTOAuth2Coexistence: assert mock_jwt_auth.call_args.kwargs["request_method"] == "POST" assert result.user_id == "jwt-human-user" + @pytest.mark.asyncio + async def test_auto_register_passes_validated_org_context_to_generated_key(self): + jwt_token = "eyJhbGciOiJSUzI1NiJ9.eyJzdWIiOiJ1c2VyMSJ9.signature" + general_settings = {"enable_jwt_auth": True} + user_api_key_cache = DualCache() + prisma_client = MagicMock() + jwt_handler = MagicMock() + jwt_handler.is_jwt.return_value = True + jwt_handler.auth_jwt = AsyncMock(return_value={"sub": "user1"}) + jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth( + virtual_key_claim_field="sub", + virtual_key_mapping_cache_ttl=300, + ) + auto_registered_key = UserAPIKeyAuth( + token="hashed-auto-key", + team_id="validated-team", + user_id="validated-user", + org_id="validated-org", + end_user_id="validated-end-user", + ) + mock_jwt_result = { + "is_proxy_admin": False, + "team_object": None, + "user_object": None, + "end_user_object": None, + "org_object": None, + "token": jwt_token, + "team_id": "validated-team", + "user_id": "validated-user", + "end_user_id": "validated-end-user", + "org_id": "validated-org", + "team_membership": None, + "jwt_claims": {"sub": "user1"}, + } + + mock_request = MagicMock() + mock_request.url.path = "/v1/chat/completions" + mock_request.method = "POST" + mock_request.headers = {"authorization": f"Bearer {jwt_token}"} + mock_request.query_params = {} + mock_request.state = SimpleNamespace() + + with ( + patch("litellm.proxy.proxy_server.general_settings", general_settings), + patch("litellm.proxy.proxy_server.premium_user", True), + patch("litellm.proxy.proxy_server.master_key", "sk-master"), + patch("litellm.proxy.proxy_server.prisma_client", prisma_client), + patch("litellm.proxy.proxy_server.user_api_key_cache", user_api_key_cache), + patch("litellm.proxy.proxy_server.proxy_logging_obj", MagicMock()), + patch("litellm.proxy.proxy_server.jwt_handler", jwt_handler), + patch( + "litellm.proxy.auth.user_api_key_auth._resolve_jwt_to_virtual_key", + new_callable=AsyncMock, + return_value=_PendingAutoRegister( + claim_field="sub", + claim_value="user1", + cache_key="jwt_key_mapping:sub:user1", + ), + ), + patch( + "litellm.proxy.auth.user_api_key_auth.JWTAuthManager.auth_builder", + new_callable=AsyncMock, + return_value=mock_jwt_result, + ), + patch( + "litellm.proxy.auth.user_api_key_auth._auto_register_jwt_mapping", + new_callable=AsyncMock, + return_value=auto_registered_key, + ) as mock_auto_register, + ): + result = await _user_api_key_auth_builder( + request=mock_request, + api_key=jwt_token, + azure_api_key_header="", + anthropic_api_key_header=None, + google_ai_studio_api_key_header=None, + azure_apim_header=None, + request_data={"model": "gpt-4o-mini"}, + ) + + mock_auto_register.assert_awaited_once() + assert mock_auto_register.call_args.kwargs["team_id"] == "validated-team" + assert mock_auto_register.call_args.kwargs["user_id"] == "validated-user" + assert mock_auto_register.call_args.kwargs["org_id"] == "validated-org" + assert mock_auto_register.call_args.kwargs["end_user_id"] == "validated-end-user" + assert result.org_id == "validated-org" + @pytest.mark.asyncio async def test_routing_override_routes_matching_jwt_to_oauth2(self): """