feat(jwt-auth): opt-in fallback to DB team on unresolved JWT claim (#28913)
* fix(jwt-auth): defer to single-team DB fallback on claim mismatch Extends the single-team DB fallback introduced in #26418 to two more cases where it previously could not run: * `find_and_validate_specific_team_id`: when `team_id_jwt_field` is configured and a claim value is present in the token but the team does not exist in the LiteLLM DB (HTTPException 404 from `get_team_object`), return `(None, None)` instead of raising — the auth_builder fallback then attributes the request to the user's single DB team. Only HTTPException is caught; other errors (e.g. "No DB Connected") still propagate. * `find_team_with_model_access`: when none of the `team_ids_jwt_field` groups resolve to a real LiteLLM team, return `(None, None)` instead of raising 403 so the same fallback path runs. If at least one group DID resolve to a team but none granted the requested model, the original 403 is preserved (legitimate access denial — not a claim mismatch). Tracked via the new `any_claim_team_resolved` flag. The strict `is_required_team_id` raise and `enforce_team_based_model_access` raise remain unchanged. Unit tests cover both new soft-fail paths and guard each preserved path (strict required, enforce_team_based, the preserved 403, and the non-HTTPException propagation). Co-authored-by: Cursor <cursoragent@cursor.com> * fix(jwt-auth): narrow HTTPException catch to 404 (greptile review) Address Greptile review comments on #28913: * `find_and_validate_specific_team_id`: re-raise HTTPException when `status_code != 404`, pinning the catch to the "team doesn't exist in db" path documented for `get_team_object`. A future change that introduces a different status code (e.g. 403 for a blocked team) will now propagate instead of silently falling through to the single-team DB fallback. * Add `test_find_and_validate_specific_team_id_non_404_http_exception_propagates` parametrised over 400 / 403 / 500 to lock in the contract. Co-authored-by: Cursor <cursoragent@cursor.com> * fix(jwt-auth): gate claim-mismatch fallback behind opt-in flag The unresolved-team-claim fallback added in the previous commit weakened the strict claim-based authorization contract by default — an authenticated user whose JWT carries a stale or invalid team claim could still consume their single DB team's models/quota via the fallback. Gate both soft-fail paths in `find_and_validate_specific_team_id` and `find_team_with_model_access` behind a new opt-in flag `team_claim_fallback` on `LiteLLM_JWTAuth` (default False). Default-off preserves the pre-existing strict behavior. Operators who intentionally treat IdP team claims as advisory (e.g. machine tokens whose group claims live in a separate namespace from LiteLLM team_ids) opt in via config. Adds two regression tests guarding the default-off behavior. Co-authored-by: Cursor <cursoragent@cursor.com> --------- Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
69a7bdb247
commit
a7ecf6b5b1
@ -4195,6 +4195,16 @@ class LiteLLM_JWTAuth(LiteLLMPydanticObjectBase):
|
||||
default=None,
|
||||
description="Optional claim-based routing overrides for JWT-shaped tokens. Matching rules route requests to oauth2 before default JWT flow.",
|
||||
)
|
||||
team_claim_fallback: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"If True, when a configured team_id_jwt_field / team_ids_jwt_field "
|
||||
"claim is present but does not resolve to any known team, defer to "
|
||||
"the single-team DB fallback (caller's only team membership) "
|
||||
"instead of raising. Default False preserves strict claim-based "
|
||||
"authorization."
|
||||
),
|
||||
)
|
||||
issuers: Optional[List[JWTIssuerConfig]] = Field(
|
||||
default=None,
|
||||
description="Optional issuer-bound JWT validation rules. When a token's `iss` matches a configured issuer, validation uses that issuer's JWKS, audience, and claim mappings. Tokens with an unlisted `iss` fall back to the global JWT_AUDIENCE/JWT_ISSUER validation path — this is additive routing, not an allow-list.",
|
||||
|
||||
@ -1299,15 +1299,29 @@ class JWTAuthManager:
|
||||
|
||||
# First try to get team by team_id
|
||||
if individual_team_id:
|
||||
team_object = await get_team_object(
|
||||
team_id=individual_team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert,
|
||||
)
|
||||
return individual_team_id, team_object
|
||||
try:
|
||||
team_object = await get_team_object(
|
||||
team_id=individual_team_id,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
parent_otel_span=parent_otel_span,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
team_id_upsert=jwt_handler.litellm_jwtauth.team_id_upsert,
|
||||
)
|
||||
return individual_team_id, team_object
|
||||
except HTTPException as e:
|
||||
if (
|
||||
e.status_code != 404
|
||||
or not jwt_handler.litellm_jwtauth.team_claim_fallback
|
||||
):
|
||||
raise
|
||||
# Claim doesn't map to a known team — defer to fallback.
|
||||
verbose_proxy_logger.debug(
|
||||
"JWT team_id claim '%s' did not resolve to a team: %s",
|
||||
individual_team_id,
|
||||
e.detail,
|
||||
)
|
||||
return None, None
|
||||
|
||||
# If no team_id found, try to resolve via team_alias_jwt_field
|
||||
team_alias = jwt_handler.get_team_alias(
|
||||
@ -1431,6 +1445,7 @@ class JWTAuthManager:
|
||||
)
|
||||
return None, None
|
||||
|
||||
any_claim_team_resolved = False
|
||||
for team_id in team_ids:
|
||||
try:
|
||||
team_object = await get_team_object(
|
||||
@ -1441,6 +1456,9 @@ class JWTAuthManager:
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
if team_object is not None:
|
||||
any_claim_team_resolved = True
|
||||
|
||||
if team_object and team_object.models is not None:
|
||||
team_models = team_object.models
|
||||
if isinstance(team_models, list) and (
|
||||
@ -1478,12 +1496,17 @@ class JWTAuthManager:
|
||||
if denied_auth_enforced_pass_through_route:
|
||||
JWTAuthManager._raise_team_passthrough_route_denial(route=route)
|
||||
|
||||
if requested_model:
|
||||
if requested_model and (
|
||||
any_claim_team_resolved
|
||||
or not jwt_handler.litellm_jwtauth.team_claim_fallback
|
||||
):
|
||||
# Claim resolved but no model access, or fallback disabled — deny.
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"No team has access to the requested model: {requested_model}. Checked teams={team_ids}. Check `/models` to see all available models.",
|
||||
)
|
||||
|
||||
# No claim team resolved and fallback enabled — defer to fallback.
|
||||
return None, None
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -3182,6 +3182,310 @@ def test_build_decode_kwargs_no_warning_when_scoped(
|
||||
assert matching == []
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Defer to single-team DB fallback (PR #26418) when JWT claims are present
|
||||
# but do not resolve to a LiteLLM team.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_and_validate_specific_team_id_unresolved_claim_returns_none():
|
||||
"""With `team_claim_fallback=True`: team_id claim is present in the JWT
|
||||
but the team is missing in the DB — return (None, None) so the
|
||||
auth_builder single-team fallback can run, instead of raising and
|
||||
failing auth."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
|
||||
team_id_jwt_field="team_id",
|
||||
team_claim_fallback=True,
|
||||
)
|
||||
token = {"sub": "user-1", "team_id": "claim-team-not-in-db"}
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.auth.handle_jwt.get_team_object",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_team:
|
||||
mock_get_team.side_effect = HTTPException(status_code=404, detail="missing")
|
||||
|
||||
team_id, team_object = await JWTAuthManager.find_and_validate_specific_team_id(
|
||||
jwt_handler=jwt_handler,
|
||||
jwt_valid_token=token,
|
||||
prisma_client=None,
|
||||
user_api_key_cache=None,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
|
||||
assert team_id is None
|
||||
assert team_object is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_team_with_model_access_unresolved_group_claim_returns_none(
|
||||
monkeypatch,
|
||||
):
|
||||
"""With `team_claim_fallback=True`: group claim resolves to team_ids that
|
||||
don't exist in the DB — return (None, None) instead of raising 403, so
|
||||
the single-team fallback can run."""
|
||||
import sys
|
||||
import types
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm.router import Router
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{"model_name": "gpt-4o-mini", "litellm_params": {"model": "gpt-4o-mini"}}
|
||||
]
|
||||
)
|
||||
proxy_server_module = types.ModuleType("proxy_server")
|
||||
proxy_server_module.llm_router = router
|
||||
monkeypatch.setitem(sys.modules, "litellm.proxy.proxy_server", proxy_server_module)
|
||||
|
||||
async def raise_404(*_args, **_kwargs):
|
||||
raise HTTPException(status_code=404, detail="missing")
|
||||
|
||||
monkeypatch.setattr("litellm.proxy.auth.handle_jwt.get_team_object", raise_404)
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_claim_fallback=True)
|
||||
|
||||
team_id, team_object = await JWTAuthManager.find_team_with_model_access(
|
||||
team_ids={"idp-group-a", "idp-group-b"},
|
||||
requested_model="gpt-4o-mini",
|
||||
route="/chat/completions",
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=None,
|
||||
user_api_key_cache=None,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
|
||||
assert team_id is None
|
||||
assert team_object is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_and_validate_specific_team_id_non_http_exception_still_propagates():
|
||||
"""Regression guard: only the 404 HTTPException raised by
|
||||
`get_team_object` ("team doesn't exist in db") is softened. Other
|
||||
errors — e.g. "No DB Connected" — must still propagate so operator-side
|
||||
problems are loud."""
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="team_id")
|
||||
token = {"sub": "user-1", "team_id": "some-claim-team"}
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.auth.handle_jwt.get_team_object",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_team:
|
||||
mock_get_team.side_effect = RuntimeError("simulated infrastructure error")
|
||||
|
||||
with pytest.raises(RuntimeError, match="simulated infrastructure error"):
|
||||
await JWTAuthManager.find_and_validate_specific_team_id(
|
||||
jwt_handler=jwt_handler,
|
||||
jwt_valid_token=token,
|
||||
prisma_client=None,
|
||||
user_api_key_cache=None,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_and_validate_specific_team_id_non_404_http_exception_propagates():
|
||||
"""Regression guard: only 404 HTTPException is softened. If
|
||||
`get_team_object` is ever updated to raise a different HTTP status code
|
||||
(e.g. 403 for a blocked team), that error must still propagate rather
|
||||
than silently fall through to the single-team DB fallback."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="team_id")
|
||||
token = {"sub": "user-1", "team_id": "some-claim-team"}
|
||||
|
||||
for status_code in (400, 403, 500):
|
||||
with patch(
|
||||
"litellm.proxy.auth.handle_jwt.get_team_object",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_team:
|
||||
mock_get_team.side_effect = HTTPException(
|
||||
status_code=status_code, detail="non-404 failure"
|
||||
)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await JWTAuthManager.find_and_validate_specific_team_id(
|
||||
jwt_handler=jwt_handler,
|
||||
jwt_valid_token=token,
|
||||
prisma_client=None,
|
||||
user_api_key_cache=None,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
assert exc_info.value.status_code == status_code
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_team_with_model_access_enforce_team_based_access_still_raises():
|
||||
"""Regression guard: when no group claims are present and
|
||||
`enforce_team_based_model_access` is on, the original 403 still fires —
|
||||
the new soft-fail only applies to the unresolved-claim path inside the
|
||||
loop, not to the no-team-claims-at-all path at the top."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(enforce_team_based_model_access=True)
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await JWTAuthManager.find_team_with_model_access(
|
||||
team_ids=set(),
|
||||
requested_model="gpt-4o-mini",
|
||||
route="/chat/completions",
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=None,
|
||||
user_api_key_cache=None,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "enforce_team_based_model_access" in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_team_with_model_access_resolved_team_without_model_still_raises_403(
|
||||
monkeypatch,
|
||||
):
|
||||
"""Regression guard: when the JWT group claim DOES resolve to a real
|
||||
LiteLLM team but that team does not grant the requested model, keep the
|
||||
original 403. Only the unresolved-claim case is softened."""
|
||||
import sys
|
||||
import types
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm.router import Router
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{"model_name": "gpt-4o-mini", "litellm_params": {"model": "gpt-4o-mini"}},
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo",
|
||||
"litellm_params": {"model": "gpt-3.5-turbo"},
|
||||
},
|
||||
]
|
||||
)
|
||||
proxy_server_module = types.ModuleType("proxy_server")
|
||||
proxy_server_module.llm_router = router
|
||||
monkeypatch.setitem(sys.modules, "litellm.proxy.proxy_server", proxy_server_module)
|
||||
|
||||
team = LiteLLM_TeamTable(team_id="real-team", models=["gpt-3.5-turbo"])
|
||||
|
||||
async def mock_get_team_object(*_args, **_kwargs):
|
||||
return team
|
||||
|
||||
monkeypatch.setattr(
|
||||
"litellm.proxy.auth.handle_jwt.get_team_object", mock_get_team_object
|
||||
)
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await JWTAuthManager.find_team_with_model_access(
|
||||
team_ids={"real-team"},
|
||||
requested_model="gpt-4o-mini",
|
||||
route="/chat/completions",
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=None,
|
||||
user_api_key_cache=None,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
assert "No team has access to the requested model" in str(exc_info.value.detail)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_and_validate_specific_team_id_unresolved_claim_default_raises():
|
||||
"""Default `team_claim_fallback=False`: unresolved team_id claim must
|
||||
still raise — preserves the strict claim-based authorization boundary
|
||||
when the operator has not opted in to the fallback."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(team_id_jwt_field="team_id")
|
||||
token = {"sub": "user-1", "team_id": "claim-team-not-in-db"}
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.auth.handle_jwt.get_team_object",
|
||||
new_callable=AsyncMock,
|
||||
) as mock_get_team:
|
||||
mock_get_team.side_effect = HTTPException(status_code=404, detail="missing")
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await JWTAuthManager.find_and_validate_specific_team_id(
|
||||
jwt_handler=jwt_handler,
|
||||
jwt_valid_token=token,
|
||||
prisma_client=None,
|
||||
user_api_key_cache=None,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_team_with_model_access_unresolved_group_claim_default_raises(
|
||||
monkeypatch,
|
||||
):
|
||||
"""Default `team_claim_fallback=False`: group claims that don't resolve
|
||||
to any LiteLLM team must still raise 403 — preserves the strict
|
||||
claim-based authorization boundary."""
|
||||
import sys
|
||||
import types
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm.router import Router
|
||||
|
||||
router = Router(
|
||||
model_list=[
|
||||
{"model_name": "gpt-4o-mini", "litellm_params": {"model": "gpt-4o-mini"}}
|
||||
]
|
||||
)
|
||||
proxy_server_module = types.ModuleType("proxy_server")
|
||||
proxy_server_module.llm_router = router
|
||||
monkeypatch.setitem(sys.modules, "litellm.proxy.proxy_server", proxy_server_module)
|
||||
|
||||
async def raise_404(*_args, **_kwargs):
|
||||
raise HTTPException(status_code=404, detail="missing")
|
||||
|
||||
monkeypatch.setattr("litellm.proxy.auth.handle_jwt.get_team_object", raise_404)
|
||||
|
||||
jwt_handler = JWTHandler()
|
||||
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await JWTAuthManager.find_team_with_model_access(
|
||||
team_ids={"idp-group-a", "idp-group-b"},
|
||||
requested_model="gpt-4o-mini",
|
||||
route="/chat/completions",
|
||||
jwt_handler=jwt_handler,
|
||||
prisma_client=None,
|
||||
user_api_key_cache=None,
|
||||
parent_otel_span=None,
|
||||
proxy_logging_obj=None,
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
# GH #26789: JWT claim user_id must rebind to legacy DB row after fuzzy match.
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user