fix(jwt): use resolved DB user_id for spend on legacy email match (#29217)

* fix(jwt): attribute spend to resolved DB user_id on email/sso fuzzy match

When user_id_upsert is enabled with JWT auth and a pre-migration user row
exists whose user_email matches the JWT email but whose user_id is a UUID,
get_user_object resolves the legacy row via fuzzy lookup, but the JWT-claim
user_id (the email) still flowed into team-membership lookup,
JWTAuthBuilderResult.user_id, UserAPIKeyAuth and the spend tables. Spend was
orphaned under a phantom email id; /user/info and the Usage page showed $0
for the legacy user (GH #26789).

Treat the resolved user_object as the source of truth: add
_canonical_user_id_from_db, rebind inside get_objects, and return
effective_user_id so auth_builder unpacks it without adding statements.

Fixes #26789

Co-authored-by: Cursor <cursoragent@cursor.com>

* fix(jwt): log user_id rebind at DEBUG to avoid email PII in INFO streams

Greptile review on #29217: rebinding often logs JWT email claims at INFO.

Co-authored-by: Cursor <cursoragent@cursor.com>

* test(jwt): update passthrough allowlist mock for 5-tuple get_objects

Staging #29256 added a test that still mocked get_objects with a
4-tuple; our PR expanded the return to 5 values (effective_user_id).

Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
milan-berri 2026-06-06 01:59:41 +03:00 committed by GitHub
parent 95e3d136e1
commit b7f47a3b52
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 177 additions and 13 deletions

View File

@ -1506,6 +1506,21 @@ class JWTAuthManager:
)
return user_id, user_email, valid_user_email
@staticmethod
def _canonical_user_id_from_db(
user_id: Optional[str],
user_object: Optional[LiteLLM_UserTable],
) -> Optional[str]:
"""Id used for spend / team-membership attribution.
JWT claim (often email) is only a lookup key. If fuzzy match in
``get_user_object`` resolved a legacy row with a different ``user_id``,
use that row's id; otherwise keep the claim. GH #26789.
"""
if user_object is not None and user_object.user_id:
return user_object.user_id
return user_id
@staticmethod
async def get_objects(
user_id: Optional[str],
@ -1526,8 +1541,13 @@ class JWTAuthManager:
Optional[LiteLLM_OrganizationTable],
Optional[LiteLLM_EndUserTable],
Optional[LiteLLM_TeamMembership],
Optional[str],
]:
"""Get user, org, and end user objects. Also resolves org aliases to IDs if configured."""
"""Get user, org, end-user, and team-membership objects.
Returns ``(..., effective_user_id)``: JWT claim unless fuzzy lookup
matched a legacy row (GH #26789).
"""
# Get org object - first try by ID, then by alias
org_object: Optional[LiteLLM_OrganizationTable] = None
@ -1602,6 +1622,18 @@ class JWTAuthManager:
else None
)
# Rebind to resolved DB user_id for team_membership + auth_builder (GH #26789).
effective_user_id = JWTAuthManager._canonical_user_id_from_db(
user_id=user_id, user_object=user_object
)
if effective_user_id != user_id:
verbose_proxy_logger.debug(
"JWT Auth: rebinding user_id %r -> DB user_id %r (email/sso match)",
user_id,
effective_user_id,
)
user_id = effective_user_id
team_membership_object: Optional[LiteLLM_TeamMembership] = None
if user_id and team_id:
team_membership_object = (
@ -1617,7 +1649,13 @@ class JWTAuthManager:
else None
)
return user_object, org_object, end_user_object, team_membership_object
return (
user_object,
org_object,
end_user_object,
team_membership_object,
user_id,
)
@staticmethod
def validate_object_id(
@ -2075,12 +2113,13 @@ class JWTAuthManager:
# Extract alias fields for resolution (if configured)
org_alias = jwt_handler.get_org_alias(token=jwt_valid_token, default_value=None)
# Get other objects
# get_objects returns effective_user_id for downstream spend attribution (GH #26789).
(
user_object,
org_object,
end_user_object,
team_membership_object,
user_id,
) = await JWTAuthManager.get_objects(
user_id=user_id,
user_email=user_email,

View File

@ -332,7 +332,7 @@ async def test_auth_builder_proxy_admin_user_role():
JWTAuthManager,
"get_objects",
new_callable=AsyncMock,
return_value=(user_object, None, None, None),
return_value=(user_object, None, None, None, user_object.user_id),
) as mock_get_objects,
patch.object(
JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock
@ -427,7 +427,7 @@ async def test_auth_builder_non_proxy_admin_user_role():
JWTAuthManager,
"get_objects",
new_callable=AsyncMock,
return_value=(user_object, None, None, None),
return_value=(user_object, None, None, None, user_object.user_id),
) as mock_get_objects,
patch.object(
JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock
@ -1206,7 +1206,7 @@ async def test_auth_builder_returns_team_membership_object():
JWTAuthManager,
"get_objects",
new_callable=AsyncMock,
return_value=(user_object, None, None, mock_team_membership),
return_value=(user_object, None, None, mock_team_membership, user_object.user_id),
) as mock_get_objects,
patch.object(
JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock
@ -1345,7 +1345,7 @@ async def test_auth_builder_with_oidc_userinfo_enabled():
JWTAuthManager,
"get_objects",
new_callable=AsyncMock,
return_value=(user_object, None, None, None),
return_value=(user_object, None, None, None, user_object.user_id),
) as mock_get_objects,
patch.object(
JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock
@ -1469,7 +1469,7 @@ async def test_auth_builder_with_oidc_userinfo_disabled():
JWTAuthManager,
"get_objects",
new_callable=AsyncMock,
return_value=(user_object, None, None, None),
return_value=(user_object, None, None, None, user_object.user_id),
) as mock_get_objects,
patch.object(
JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock
@ -1585,7 +1585,7 @@ async def test_auth_builder_oidc_enabled_falls_back_to_jwt_auth_for_jwt_tokens()
JWTAuthManager,
"get_objects",
new_callable=AsyncMock,
return_value=(user_object, None, None, None),
return_value=(user_object, None, None, None, user_object.user_id),
),
patch.object(JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock),
patch.object(JWTAuthManager, "validate_object_id", return_value=True),
@ -1681,7 +1681,7 @@ async def test_auth_builder_uses_team_from_header_e2e():
JWTAuthManager,
"get_objects",
new_callable=AsyncMock,
return_value=(user_object, None, None, None),
return_value=(user_object, None, None, None, user_object.user_id),
),
patch.object(JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock),
patch.object(
@ -1900,7 +1900,7 @@ async def test_auth_builder_rbac_team_loads_team_for_passthrough_allowlist():
JWTAuthManager,
"get_objects",
new_callable=AsyncMock,
return_value=(None, None, None, None),
return_value=(None, None, None, None, None),
),
patch.object(JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock),
patch.object(
@ -2462,6 +2462,7 @@ async def test_get_objects_resolves_org_by_name():
result_org_obj,
result_end_user_obj,
result_team_membership,
_result_user_id,
) = await JWTAuthManager.get_objects(
user_id=None,
user_email=None,
@ -2909,7 +2910,7 @@ async def test_auth_builder_single_team_db_fallback_when_jwt_has_no_team(
JWTAuthManager,
"get_objects",
new_callable=AsyncMock,
return_value=(user_object, None, None, None),
return_value=(user_object, None, None, None, user_object.user_id),
),
patch.object(JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock),
patch.object(JWTAuthManager, "validate_object_id", return_value=True),
@ -3027,7 +3028,7 @@ async def test_auth_builder_single_team_fallback_membership_error_skips_no_raise
JWTAuthManager,
"get_objects",
new_callable=AsyncMock,
return_value=(user_object, None, None, None),
return_value=(user_object, None, None, None, user_object.user_id),
),
patch.object(JWTAuthManager, "map_user_to_teams", new_callable=AsyncMock),
patch.object(JWTAuthManager, "validate_object_id", return_value=True),
@ -3181,6 +3182,68 @@ def test_build_decode_kwargs_no_warning_when_scoped(
assert matching == []
# GH #26789: JWT claim user_id must rebind to legacy DB row after fuzzy match.
def test_canonical_user_id_rebinds_to_legacy_uuid():
"""JWT email resolves to a legacy UUID row -> use the UUID for attribution."""
legacy_uuid = "bb8ab11f-09aa-47ae-b063-6e80506ac3bc"
jwt_email = "matt@example.com"
user_object = LiteLLM_UserTable(user_id=legacy_uuid, user_email=jwt_email)
assert (
JWTAuthManager._canonical_user_id_from_db(
user_id=jwt_email, user_object=user_object
)
== legacy_uuid
)
def test_canonical_user_id_no_change_when_ids_match():
"""Fresh upserted user (row.user_id == claim) -> claim returned unchanged."""
same = "alice@example.com"
user_object = LiteLLM_UserTable(user_id=same, user_email=same)
assert (
JWTAuthManager._canonical_user_id_from_db(
user_id=same, user_object=user_object
)
== same
)
def test_canonical_user_id_returns_claim_when_no_user_object():
"""No resolved row (e.g. upsert disabled / brand new) -> keep the claim."""
assert (
JWTAuthManager._canonical_user_id_from_db(
user_id="newcomer@example.com", user_object=None
)
== "newcomer@example.com"
)
def test_canonical_user_id_returns_none_when_claim_none_and_no_object():
"""Defensive: no claim and no row -> stays None, never invents an id."""
assert (
JWTAuthManager._canonical_user_id_from_db(user_id=None, user_object=None)
is None
)
def test_canonical_user_id_no_change_when_db_user_id_falsy():
"""Defensive: an empty user_object.user_id must not clobber the claim."""
class _Stub:
user_id = ""
assert (
JWTAuthManager._canonical_user_id_from_db(
user_id="jwt@example.com", user_object=_Stub()
)
== "jwt@example.com"
)
@pytest.mark.asyncio
async def test_auth_jwt_expired_token_raises_401_jwk_path():
"""An expired JWT (access token) decoded via the JWK/dict public-key path
@ -3410,6 +3473,68 @@ def test_get_jwks_url_for_issuer_falls_back_to_discovery_document():
)
@pytest.mark.asyncio
async def test_get_objects_team_membership_uses_rebound_user_id():
"""team_membership lookup uses resolved DB user_id, not JWT email claim."""
from litellm.caching.caching import DualCache
legacy_uuid = "bb8ab11f-09aa-47ae-b063-6e80506ac3bc"
jwt_email = "matt@example.com"
team_id = "team-1"
resolved_user = LiteLLM_UserTable(user_id=legacy_uuid, user_email=jwt_email)
captured = {}
async def fake_get_user_object(*args, **kwargs):
return resolved_user
async def fake_get_team_membership(user_id, team_id, *args, **kwargs):
captured["user_id"] = user_id
captured["team_id"] = team_id
return None
jwt_handler = JWTHandler()
jwt_handler.litellm_jwtauth = LiteLLM_JWTAuth(
user_id_jwt_field="email", user_id_upsert=True
)
with patch(
"litellm.proxy.auth.handle_jwt.get_user_object",
side_effect=fake_get_user_object,
), patch(
"litellm.proxy.auth.handle_jwt.get_team_membership",
side_effect=fake_get_team_membership,
):
(
user_object,
_org_object,
_end_user_object,
_team_membership_object,
effective_user_id,
) = await JWTAuthManager.get_objects(
user_id=jwt_email,
user_email=jwt_email,
org_id=None,
end_user_id=None,
team_id=team_id,
valid_user_email=None,
jwt_handler=jwt_handler,
prisma_client=MagicMock(),
user_api_key_cache=DualCache(),
parent_otel_span=None,
proxy_logging_obj=MagicMock(),
route="/chat/completions",
)
assert user_object is not None and user_object.user_id == legacy_uuid
assert effective_user_id == legacy_uuid
assert captured["user_id"] == legacy_uuid, (
"team_membership lookup must use the resolved DB user_id, not the JWT "
f"email claim (got {captured['user_id']!r})"
)
assert captured["team_id"] == team_id
@pytest.mark.asyncio
async def test_multi_issuer_jwt_validates_selected_issuer_and_maps_claims(
monkeypatch,