feat(proxy): persist allowlisted OIDC claims in CLI SSO poll (#28463)

* feat(proxy): persist allowlisted OIDC claims in CLI SSO poll

Map CLI_SSO_CLAIM_MAP sources into user metadata and return scalar
attribution_metadata from /sso/cli/poll. Build SSOUserDefinedValues in
cli_sso_callback so first-time CLI logins can upsert users. Add mock OIDC
scripts and tests for claim extraction and poll exposure.

Co-authored-by: Cursor <cursoragent@cursor.com>

* docs(proxy): document CLI SSO attribution_metadata in client README

Co-authored-by: Cursor <cursoragent@cursor.com>

* Delete scripts/mock_oidc_server_for_cli_sso.py

* Delete scripts/test_cli_sso_claims_e2e.py

* fix(ui_sso): preserve claim types and avoid metadata. prefix stripping

- Replace _update_dictionary with a local recursive merge so string
  OIDC claim values that happen to look numeric are not silently coerced
  to int/float when persisting CLI SSO attribution metadata.
- Use a local dot-path resolver in _extract_sso_claim_value so that
  source claim paths beginning with 'metadata.' are not silently stripped
  by get_nested_value (which is designed for LiteLLM JWT metadata, not
  arbitrary OIDC claims).

Co-authored-by: Yassin Kortam <yassin@berri.ai>

* Remove redundant metadata. prefix strip in _set_nested_metadata_value

The _parse_cli_sso_claim_map already strips the metadata. prefix from
dest keys before reaching the setter. The duplicate strip in
_set_nested_metadata_value was a no-op in normal flow but could
mis-place values for dest keys like metadata.metadata.foo.

Co-authored-by: Yassin Kortam <yassin@berri.ai>

* Fix greptile

* Fix ruff

* Move CLI SSO user defined values build inside try/except for consistent error handling

Co-authored-by: Yassin Kortam <yassin@berri.ai>

* fix(proxy): enforce restricted SSO group on CLI SSO callback

Apply verify_user_in_restricted_sso_group before CLI session completion
and user upsert, matching the UI SSO path. Re-raise ProxyException so
restricted-group denials return 403 instead of 500.

Co-authored-by: Cursor <cursoragent@cursor.com>

* fix(proxy): replace recursive CLI SSO metadata helpers with iterative merge

Use stack-based flatten/merge to satisfy recursive_detector CI. Fix mypy
types for UserApiKeyCache and user_id on CLI SSO session completion.

Co-authored-by: Cursor <cursoragent@cursor.com>

* fix: resolve nested CustomOpenID extra_fields in CLI SSO claim extraction

When GENERIC_USER_EXTRA_ATTRIBUTES captures a parent object (e.g. org_info),
extra_fields stores it as {"org_info": {"department": "..."}}. A CLI claim
map entry using a dotted path like org_info.department would silently fail
because the lookup only checked the exact flat key. Fall back to dotted-path
resolution on extra_fields before model_dump().

Co-authored-by: Yassin Kortam <yassin@berri.ai>

* fix(sso): update CLI SSO test for new received_response kwarg and remove redundant 'token' secret fragment

Co-authored-by: Yassin Kortam <yassin@berri.ai>

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Yassin Kortam <yassin@berri.ai>
This commit is contained in:
Sameer Kankute 2026-05-22 22:28:50 +05:30 committed by GitHub
parent ef36e89638
commit 50a3f10a92
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 717 additions and 84 deletions

View File

@ -1443,6 +1443,12 @@ CLI_JWT_EXPIRATION_HOURS = int(
or os.getenv("LITELLM_CLI_JWT_EXPIRATION_HOURS")
or 24
)
# Comma-separated allowlisted OIDC claim map for CLI SSO polling, e.g.
# "employment_type->acme_employment_type,org_info.department->department"
CLI_SSO_CLAIM_MAP = (
os.getenv("CLI_SSO_CLAIM_MAP") or os.getenv("LITELLM_CLI_SSO_CLAIM_MAP") or ""
)
CLI_SSO_CLAIM_MAX_SCALAR_LENGTH = 1024
########################### UI SESSION DURATION ###########################
# Duration for UI login session (username/password, SSO, invitation links). Format: "30s", "30m", "24h", "7d"

View File

@ -350,7 +350,7 @@ The CLI provides three authentication commands:
4. **User Authentication**: User completes SSO authentication in browser
5. **Callback Processing**: SSO provider redirects back to proxy with state parameter
6. **User Code Verification**: Browser confirms the verification code shown in the CLI
7. **Polling**: CLI polls `/sso/cli/poll/{login_id}` with the polling secret header until the JWT is ready
7. **Polling**: CLI polls `/sso/cli/poll/{login_id}` with the polling secret header until the JWT is ready. When `CLI_SSO_CLAIM_MAP` is configured on the proxy, the poll response may include `attribution_metadata` (allowlisted scalar OIDC claims for client attribution).
8. **Token Storage**: CLI saves the authentication token to `~/.litellm/token.json`
### Benefits of This Approach

View File

@ -43,6 +43,8 @@ from litellm.caching.dual_cache import DualCache
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.constants import (
CLI_SSO_CLAIM_MAP,
CLI_SSO_CLAIM_MAX_SCALAR_LENGTH,
CLI_SSO_SESSION_CACHE_KEY_PREFIX,
CLI_SSO_SESSION_TTL_SECONDS,
LITELLM_CLI_SOURCE_IDENTIFIER,
@ -140,6 +142,20 @@ _CLI_SSO_START_RATE_LIMIT_WINDOW_SECONDS = 60
_CLI_SSO_START_RATE_LIMIT_MAX_ATTEMPTS = 30
_CLI_SSO_USER_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
_CLI_SSO_LOGIN_ID_RE = re.compile(r"^cli-[A-Za-z0-9_-]{12,124}$")
_CLI_SSO_SCALAR_TYPES = (str, int, float, bool)
_CLI_SSO_DEST_KEY_RE = re.compile(r"^[A-Za-z0-9_.-]+$")
_CLI_SSO_SECRET_KEY_FRAGMENTS = frozenset(
{
"access_token",
"api_key",
"client_secret",
"id_token",
"password",
"private_key",
"refresh_token",
"secret",
}
)
def _hash_cli_sso_secret(secret: str) -> str:
@ -225,6 +241,239 @@ def _verify_cli_sso_poll_secret(flow: dict, poll_secret: Optional[str]) -> bool:
return secrets.compare_digest(supplied_poll_secret_hash, expected_poll_secret_hash)
def _parse_cli_sso_claim_map() -> List[Tuple[str, str]]:
"""
Parse CLI_SSO_CLAIM_MAP / LITELLM_CLI_SSO_CLAIM_MAP.
Format: comma-separated ``source_claim->metadata_key`` pairs, e.g.
``employment_type->acme_employment_type,org_info.department->department``.
Destination keys may use an optional ``metadata.`` prefix; values are stored
on the LiteLLM user's ``metadata`` JSON column.
"""
claim_map_raw = CLI_SSO_CLAIM_MAP.strip()
if not claim_map_raw:
return []
parsed: List[Tuple[str, str]] = []
for entry in claim_map_raw.split(","):
entry = entry.strip()
if not entry or "->" not in entry:
continue
source_claim, dest_key = entry.split("->", 1)
source_claim = source_claim.strip()
dest_key = dest_key.strip()
if dest_key.startswith("metadata."):
dest_key = dest_key[len("metadata.") :]
if source_claim and dest_key:
parsed.append((source_claim, dest_key))
return parsed
def _is_safe_cli_sso_metadata_dest_key(dest_key: str) -> bool:
if not dest_key or not _CLI_SSO_DEST_KEY_RE.fullmatch(dest_key):
return False
lowered = dest_key.lower()
return not any(fragment in lowered for fragment in _CLI_SSO_SECRET_KEY_FRAGMENTS)
def _is_safe_cli_sso_scalar_claim_value(value: Any) -> bool:
if not isinstance(value, _CLI_SSO_SCALAR_TYPES):
return False
if isinstance(value, str):
if len(value) > CLI_SSO_CLAIM_MAX_SCALAR_LENGTH:
return False
if value.startswith("eyJ") and value.count(".") >= 2:
return False
return True
def _sso_result_to_dict(result: Union[CustomOpenID, OpenID, dict]) -> Dict[str, Any]:
if isinstance(result, dict):
return result
if hasattr(result, "model_dump"):
dumped = result.model_dump()
if isinstance(dumped, dict):
return cast(Dict[str, Any], dumped)
return {}
def _get_nested_claim_value(data: Dict[str, Any], claim_path: str) -> Any:
"""Resolve a dot-notation claim path against an SSO result dict.
Unlike ``get_nested_value``, this does not strip a leading ``metadata.``
prefix, since OIDC claims may legitimately use ``metadata`` as a top-level
key.
"""
if not claim_path:
return None
if claim_path in data:
return data[claim_path]
placeholder = "\x00"
parts = claim_path.replace("\\.", placeholder).split(".")
parts = [p.replace(placeholder, ".") for p in parts]
current: Any = data
for part in parts:
if isinstance(current, dict) and part in current:
current = current[part]
else:
return None
return current
def _extract_sso_claim_value(
result: Union[CustomOpenID, OpenID, dict], claim_path: str
) -> Any:
extra_fields = getattr(result, "extra_fields", None)
if isinstance(extra_fields, dict):
if claim_path in extra_fields:
return extra_fields[claim_path]
nested = _get_nested_claim_value(extra_fields, claim_path)
if nested is not None:
return nested
if isinstance(result, dict):
return _get_nested_claim_value(result, claim_path)
result_dict = _sso_result_to_dict(result)
return _get_nested_claim_value(result_dict, claim_path)
def _set_nested_metadata_value(
metadata: Dict[str, Any], key_path: str, value: Any
) -> None:
placeholder = "\x00"
parts = key_path.replace("\\.", placeholder).split(".")
parts = [p.replace(placeholder, ".") for p in parts]
current: Any = metadata
for part in parts[:-1]:
existing = current.get(part)
if not isinstance(existing, dict):
existing = {}
current[part] = existing
current = existing
current[parts[-1]] = value
def _flatten_cli_sso_metadata_for_poll(
metadata: Dict[str, Any],
) -> Dict[str, Union[str, int, float, bool]]:
"""Expose scalar attribution metadata as a flat dict for CLI poll responses."""
flattened: Dict[str, Union[str, int, float, bool]] = {}
stack: List[Tuple[str, Any]] = [("", metadata)]
while stack:
prefix, value = stack.pop()
if isinstance(value, dict):
for key, nested in value.items():
nested_prefix = f"{prefix}.{key}" if prefix else key
stack.append((nested_prefix, nested))
elif _is_safe_cli_sso_scalar_claim_value(value):
flattened[prefix] = value
return flattened
def build_cli_sso_attribution_metadata(
result: Union[CustomOpenID, OpenID, dict],
) -> Dict[str, Any]:
"""
Build allowlisted, non-secret scalar attribution metadata from an SSO result.
Sources are configured via CLI_SSO_CLAIM_MAP / LITELLM_CLI_SSO_CLAIM_MAP and
may include claims captured by GENERIC_USER_EXTRA_ATTRIBUTES on CustomOpenID.
"""
claim_map = _parse_cli_sso_claim_map()
if not claim_map:
return {}
metadata: Dict[str, Any] = {}
for source_claim, dest_key in claim_map:
if not _is_safe_cli_sso_metadata_dest_key(dest_key):
verbose_proxy_logger.debug(
f"Skipping unsafe CLI SSO metadata destination key: {dest_key}"
)
continue
raw_value = _extract_sso_claim_value(result=result, claim_path=source_claim)
if not _is_safe_cli_sso_scalar_claim_value(raw_value):
continue
_set_nested_metadata_value(
metadata=metadata, key_path=dest_key, value=raw_value
)
return metadata
def _merge_cli_sso_attribution_metadata(
existing_metadata: Dict[str, Any], attribution_metadata: Dict[str, Any]
) -> Dict[str, Any]:
"""Merge attribution metadata into existing user metadata in-place.
Preserves original value types (in particular, string claim values that
happen to look numeric are NOT coerced to ``int``/``float``). Nested dicts
are merged iteratively so attribution claims do not clobber unrelated keys
under the same parent.
"""
pending: List[Tuple[Dict[str, Any], Dict[str, Any]]] = [
(existing_metadata, attribution_metadata)
]
while pending:
target, source = pending.pop()
for key, value in source.items():
if value is None:
continue
existing_value = target.get(key)
if isinstance(value, dict) and isinstance(existing_value, dict):
pending.append((existing_value, value))
else:
target[key] = value
return existing_metadata
async def _persist_cli_sso_user_metadata(
prisma_client: PrismaClient,
user_id: str,
attribution_metadata: Dict[str, Any],
) -> None:
if not attribution_metadata:
return
try:
user_row = await prisma_client.db.litellm_usertable.find_unique(
where={"user_id": user_id}
)
existing_metadata: Dict[str, Any] = {}
if user_row is not None:
row_metadata = user_row.metadata
if isinstance(row_metadata, dict):
existing_metadata = deepcopy(row_metadata)
merged_metadata = _merge_cli_sso_attribution_metadata(
existing_metadata=existing_metadata,
attribution_metadata=attribution_metadata,
)
await prisma_client.db.litellm_usertable.update_many(
where={"user_id": user_id},
data={"metadata": merged_metadata},
)
verbose_proxy_logger.info(
f"Persisted CLI SSO attribution metadata for user {user_id}: "
f"{list(_flatten_cli_sso_metadata_for_poll(attribution_metadata).keys())}"
)
except Exception as e:
verbose_proxy_logger.error(
f"Failed to persist CLI SSO attribution metadata for user {user_id}: {e}"
)
def _cli_poll_attribution_metadata_from_session(
session_data: Dict[str, Any],
) -> Dict[str, Union[str, int, float, bool]]:
stored = session_data.get("attribution_metadata")
if isinstance(stored, dict):
return _flatten_cli_sso_metadata_for_poll(stored)
return {}
def _render_cli_sso_verification_page(
verify_url: str, browser_complete_token: str
) -> str:
@ -1674,7 +1923,12 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa:
key_id = state_parts[1] if len(state_parts) > 1 else None
verbose_proxy_logger.info("CLI SSO callback detected")
return await cli_sso_callback(request=request, key=key_id, result=result)
return await cli_sso_callback(
request=request,
key=key_id,
result=result,
received_response=received_response,
)
# Control-plane cross-origin: read return_to from cookie.
# Starlette's cookie_parser already handles RFC 2109 unquoting.
@ -1692,15 +1946,144 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa:
)
async def _build_cli_sso_user_defined_values(
result: Union[OpenID, dict],
parsed_openid_result: ParsedOpenIDResult,
) -> Optional[SSOUserDefinedValues]:
from litellm.proxy.proxy_server import user_custom_sso
user_id = parsed_openid_result.get("user_id")
if user_custom_sso is not None:
if inspect.iscoroutinefunction(user_custom_sso):
return await user_custom_sso(result) # type: ignore
raise ValueError("user_custom_sso must be a coroutine function")
if user_id is None:
return None
return SSOUserDefinedValues(
models=[],
user_id=user_id,
user_email=parsed_openid_result.get("user_email"),
max_budget=litellm.max_internal_user_budget,
user_role=parsed_openid_result.get("user_role"),
budget_duration=litellm.internal_user_budget_duration,
)
async def _fetch_cli_sso_team_details(
prisma_client: PrismaClient,
teams: List[str],
) -> List[Dict[str, Any]]:
team_details: List[Dict[str, Any]] = []
try:
if teams:
prisma_teams = await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": teams}}
)
for team_row in prisma_teams:
team_dict = team_row.model_dump()
team_details.append(
{
"team_id": team_dict.get("team_id"),
"team_alias": team_dict.get("team_alias"),
}
)
except Exception as e:
verbose_proxy_logger.error(
f"Error fetching team details for CLI SSO session: {e}"
)
return team_details
async def _complete_cli_sso_callback_session(
*,
request: Request,
key: str,
flow: dict,
result: Union[OpenID, dict],
parsed_openid_result: ParsedOpenIDResult,
user_defined_values: Optional[SSOUserDefinedValues],
prisma_client: PrismaClient,
user_api_key_cache: UserApiKeyCache,
proxy_logging_obj: ProxyLogging,
):
from fastapi.responses import HTMLResponse
user_id = parsed_openid_result.get("user_id")
user_email = parsed_openid_result.get("user_email")
user_info = await get_user_info_from_db(
result=result,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
user_email=user_email,
user_defined_values=user_defined_values,
alternate_user_id=user_id,
)
if user_info is None:
raise HTTPException(
status_code=500, detail="Failed to retrieve user information from SSO"
)
if not user_info.user_id:
raise HTTPException(
status_code=500, detail="Failed to retrieve user information from SSO"
)
teams: List[str] = []
if hasattr(user_info, "teams") and user_info.teams:
teams = user_info.teams if isinstance(user_info.teams, list) else []
team_details = await _fetch_cli_sso_team_details(
prisma_client=prisma_client, teams=teams
)
attribution_metadata = build_cli_sso_attribution_metadata(result=result)
if attribution_metadata:
await _persist_cli_sso_user_metadata(
prisma_client=prisma_client,
user_id=cast(str, user_info.user_id),
attribution_metadata=attribution_metadata,
)
flow["session_data"] = {
"user_id": cast(str, user_info.user_id),
"user_role": user_info.user_role,
"models": user_info.models if hasattr(user_info, "models") else [],
"user_email": user_email,
"teams": teams,
"team_details": team_details,
"attribution_metadata": attribution_metadata,
}
flow["sso_complete"] = True
browser_complete_token = secrets.token_urlsafe(32)
flow["browser_complete_token_hash"] = _hash_cli_sso_secret(browser_complete_token)
_set_cli_sso_flow(login_id=key, cache=user_api_key_cache, flow=flow)
verbose_proxy_logger.info(
f"Stored CLI SSO session for user: {user_info.user_id}, teams: {teams}, num_teams: {len(teams)}"
)
verify_url = get_custom_url(
request_base_url=str(request.base_url),
route=f"sso/cli/complete/{key}",
)
return HTMLResponse(
content=_render_cli_sso_verification_page(
verify_url=verify_url,
browser_complete_token=browser_complete_token,
),
status_code=200,
)
async def cli_sso_callback(
request: Request,
key: Optional[str] = None,
result: Optional[Union[OpenID, dict]] = None,
received_response: Optional[dict] = None,
):
"""CLI SSO callback - stores session info for JWT generation on polling"""
verbose_proxy_logger.info("CLI SSO callback")
from litellm.proxy.proxy_server import (
general_settings,
prisma_client,
proxy_logging_obj,
user_api_key_cache,
@ -1722,92 +2105,40 @@ async def cli_sso_callback(
# After None check, cast to non-None type for type checker
result_non_none: Union[OpenID, dict] = cast(Union[OpenID, dict], result)
parsed_openid_result = SSOAuthenticationHandler._get_user_email_and_id_from_result(
result=result_non_none
)
verbose_proxy_logger.debug(f"parsed_openid_result: {parsed_openid_result}")
try:
# Get full user info from DB
user_info = await get_user_info_from_db(
parsed_openid_result = (
SSOAuthenticationHandler._get_user_email_and_id_from_result(
result=result_non_none,
generic_client_id=os.getenv("GENERIC_CLIENT_ID", None),
)
)
verbose_proxy_logger.debug(f"parsed_openid_result: {parsed_openid_result}")
user_defined_values = await _build_cli_sso_user_defined_values(
result=result_non_none,
parsed_openid_result=parsed_openid_result,
)
SSOAuthenticationHandler.verify_user_in_restricted_sso_group(
general_settings=general_settings,
result=result_non_none,
received_response=received_response,
)
return await _complete_cli_sso_callback_session(
request=request,
key=cast(str, key),
flow=flow,
result=result_non_none,
parsed_openid_result=parsed_openid_result,
user_defined_values=user_defined_values,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
user_email=parsed_openid_result.get("user_email"),
user_defined_values=None,
alternate_user_id=parsed_openid_result.get("user_id"),
)
if user_info is None:
raise HTTPException(
status_code=500, detail="Failed to retrieve user information from SSO"
)
# Get all teams from user_info - CLI will let user select which one
teams: List[str] = []
if hasattr(user_info, "teams") and user_info.teams:
teams = user_info.teams if isinstance(user_info.teams, list) else []
# Also fetch team aliases for a better CLI UX. We keep the original
# "teams" list of IDs for backwards compatibility and add an
# optional "team_details" field containing objects with both
# team_id and team_alias.
team_details: List[Dict[str, Any]] = []
try:
if teams:
prisma_teams = await prisma_client.db.litellm_teamtable.find_many(
where={"team_id": {"in": teams}}
)
for team_row in prisma_teams:
team_dict = team_row.model_dump()
team_details.append(
{
"team_id": team_dict.get("team_id"),
"team_alias": team_dict.get("team_alias"),
}
)
except Exception as e:
# If anything goes wrong here, fall back gracefully without
# impacting the SSO flow.
verbose_proxy_logger.error(
f"Error fetching team details for CLI SSO session: {e}"
)
session_data = {
"user_id": user_info.user_id,
"user_role": user_info.user_role,
"models": user_info.models if hasattr(user_info, "models") else [],
"user_email": parsed_openid_result.get("user_email"),
"teams": teams,
# Optional rich metadata for clients that want nicer display
"team_details": team_details,
}
flow["session_data"] = session_data
flow["sso_complete"] = True
browser_complete_token = secrets.token_urlsafe(32)
flow["browser_complete_token_hash"] = _hash_cli_sso_secret(
browser_complete_token
)
_set_cli_sso_flow(login_id=cast(str, key), cache=user_api_key_cache, flow=flow)
verbose_proxy_logger.info(
f"Stored CLI SSO session for user: {user_info.user_id}, teams: {teams}, num_teams: {len(teams)}"
)
from fastapi.responses import HTMLResponse
verify_url = get_custom_url(
request_base_url=str(request.base_url),
route=f"sso/cli/complete/{key}",
)
html_content = _render_cli_sso_verification_page(
verify_url=verify_url,
browser_complete_token=browser_complete_token,
)
return HTMLResponse(content=html_content, status_code=200)
except ProxyException:
raise
except HTTPException:
raise
except Exception as e:
verbose_proxy_logger.error(f"Error with CLI SSO callback: {e}")
raise HTTPException(
@ -1873,13 +2204,19 @@ async def cli_poll_key(
team_details_response = [
{"team_id": t, "team_alias": None} for t in user_teams
]
return {
poll_response: Dict[str, Any] = {
"status": "ready",
"user_id": user_id,
"teams": user_teams,
"team_details": team_details_response,
"requires_team_selection": True,
}
attribution_metadata = _cli_poll_attribution_metadata_from_session(
session_data
)
if attribution_metadata:
poll_response["attribution_metadata"] = attribution_metadata
return poll_response
# Validate team_id if provided
if team_id is not None:
@ -1912,7 +2249,7 @@ async def cli_poll_key(
verbose_proxy_logger.info(
f"CLI JWT generated for user: {user_id}, team: {team_id}"
)
return {
poll_response = {
"status": "ready",
"key": jwt_token,
"user_id": user_id,
@ -1922,6 +2259,12 @@ async def cli_poll_key(
# present nicer information if needed.
"team_details": user_team_details,
}
attribution_metadata = _cli_poll_attribution_metadata_from_session(
session_data
)
if attribution_metadata:
poll_response["attribution_metadata"] = attribution_metadata
return poll_response
else:
return {"status": "pending"}

View File

@ -2438,6 +2438,7 @@ class TestCLIKeyRegenerationFlow:
request=mock_request,
key="cli-new-session-key-456",
result=mock_result,
received_response=None,
)
def test_get_redirect_url_does_not_include_existing_key_in_url(self):
@ -5552,6 +5553,289 @@ def test_generic_response_convertor_extra_attributes_missing_field(monkeypatch):
assert result.extra_fields["another_missing"] is None
class TestCliSsoAttributionMetadata:
"""CLI SSO allowlisted OIDC claim persistence and poll exposure."""
def test_parse_cli_sso_claim_map(self, monkeypatch):
from litellm.proxy.management_endpoints import ui_sso
monkeypatch.setattr(
ui_sso,
"CLI_SSO_CLAIM_MAP",
"employment_type->metadata.acme_employment_type, org_info.department -> department",
)
assert ui_sso._parse_cli_sso_claim_map() == [
("employment_type", "acme_employment_type"),
("org_info.department", "department"),
]
def test_build_cli_sso_attribution_metadata_filters_non_scalars(self, monkeypatch):
from litellm.proxy.management_endpoints import ui_sso
from litellm.proxy.management_endpoints.types import CustomOpenID
monkeypatch.setattr(
ui_sso,
"CLI_SSO_CLAIM_MAP",
"employment_type->acme_employment_type,access_token->should_drop,group->groups",
)
result = CustomOpenID(
id="user-1",
email="user@example.com",
display_name="User",
provider="generic",
team_ids=[],
extra_fields={
"employment_type": "full_time",
"access_token": "eyJhbGciOiJIUzI1NiJ9.payload.signature",
"group": ["team-a", "team-b"],
},
)
metadata = ui_sso.build_cli_sso_attribution_metadata(result=result)
assert metadata == {"acme_employment_type": "full_time"}
def test_build_cli_sso_attribution_metadata_from_oidc_dict(self, monkeypatch):
from litellm.proxy.management_endpoints import ui_sso
monkeypatch.setattr(
ui_sso,
"CLI_SSO_CLAIM_MAP",
"org_info.department->department",
)
metadata = ui_sso.build_cli_sso_attribution_metadata(
result={
"sub": "user-1",
"email": "user@example.com",
"org_info": {"department": "Engineering"},
}
)
assert metadata == {"department": "Engineering"}
@pytest.mark.asyncio
async def test_cli_sso_callback_passes_user_defined_values_for_new_users(self):
"""First CLI SSO login must supply SSOUserDefinedValues so upsert can create the user."""
from litellm.proxy._types import LiteLLM_UserTable
from litellm.proxy.management_endpoints import ui_sso
from litellm.proxy.management_endpoints.types import CustomOpenID
mock_request = MagicMock(spec=Request)
mock_request.base_url = "http://internal-proxy.local/"
session_key = "cli-session-new-user"
mock_user_info = LiteLLM_UserTable(
user_id="cli-test-user",
user_role="internal_user",
teams=[],
models=[],
)
mock_sso_result = CustomOpenID(
id="cli-test-user",
email="cli-test@example.com",
display_name="cli-test-user",
provider="generic",
team_ids=[],
)
mock_cache = MagicMock()
mock_cache.get_cache.return_value = {
"poll_secret_hash": "poll-secret-hash",
"user_code_hash": "user-code-hash",
"sso_complete": False,
"user_code_verified": False,
"session_data": None,
}
get_user_info_mock = AsyncMock(return_value=mock_user_info)
with (
patch(
"litellm.proxy.management_endpoints.ui_sso.get_user_info_from_db",
get_user_info_mock,
),
patch("litellm.proxy.proxy_server.prisma_client", MagicMock()),
patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache),
patch("litellm.proxy.proxy_server.user_custom_sso", None),
):
await ui_sso.cli_sso_callback(
request=mock_request,
key=session_key,
result=mock_sso_result,
)
get_user_info_mock.assert_awaited_once()
assert get_user_info_mock.call_args.kwargs["user_defined_values"] is not None
assert (
get_user_info_mock.call_args.kwargs["user_defined_values"]["user_id"]
== "cli-test-user"
)
@pytest.mark.asyncio
async def test_cli_sso_callback_rejects_restricted_sso_group(self):
"""CLI SSO must enforce restricted_sso_group before upserting the user."""
from litellm.proxy._types import ProxyException
from litellm.proxy.management_endpoints import ui_sso
from litellm.proxy.management_endpoints.types import CustomOpenID
mock_request = MagicMock(spec=Request)
mock_request.base_url = "http://internal-proxy.local/"
mock_cache = MagicMock()
mock_cache.get_cache.return_value = {
"poll_secret_hash": "poll-secret-hash",
"user_code_hash": "user-code-hash",
"sso_complete": False,
"user_code_verified": False,
"session_data": None,
}
mock_sso_result = CustomOpenID(
id="cli-test-user",
email="cli-test@example.com",
display_name="cli-test-user",
provider="generic",
team_ids=["other-group"],
)
with (
patch(
"litellm.proxy.management_endpoints.ui_sso.get_user_info_from_db",
new=AsyncMock(),
) as get_user_info_mock,
patch("litellm.proxy.proxy_server.prisma_client", MagicMock()),
patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache),
patch("litellm.proxy.proxy_server.user_custom_sso", None),
patch(
"litellm.proxy.proxy_server.general_settings",
{
"ui_access_mode": {
"type": "restricted_sso_group",
"restricted_sso_group": "required-group",
}
},
),
):
with pytest.raises(ProxyException):
await ui_sso.cli_sso_callback(
request=mock_request,
key="cli-session-restricted",
result=mock_sso_result,
received_response={"groups": ["other-group"]},
)
get_user_info_mock.assert_not_awaited()
@pytest.mark.asyncio
async def test_cli_sso_callback_persists_attribution_metadata(self, monkeypatch):
from litellm.proxy._types import LiteLLM_UserTable
from litellm.proxy.management_endpoints import ui_sso
monkeypatch.setattr(
ui_sso,
"CLI_SSO_CLAIM_MAP",
"employment_type->acme_employment_type",
)
mock_request = MagicMock(spec=Request)
mock_request.base_url = "http://internal-proxy.local/"
session_key = "cli-session-4567890"
mock_user_info = LiteLLM_UserTable(
user_id="test-user-123",
user_role="internal_user",
teams=["team1"],
models=["gpt-4"],
)
mock_sso_result = {
"user_email": "test@example.com",
"user_id": "test-user-123",
"employment_type": "contractor",
}
mock_cache = MagicMock()
mock_cache.get_cache.return_value = {
"poll_secret_hash": "poll-secret-hash",
"user_code_hash": "user-code-hash",
"sso_complete": False,
"user_code_verified": False,
"session_data": None,
}
mock_prisma = MagicMock()
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(
return_value=MagicMock(metadata={"auth_provider": "generic"})
)
mock_prisma.db.litellm_usertable.update_many = AsyncMock()
with (
patch.dict(
os.environ,
{
"PROXY_BASE_URL": "https://test.example.com",
"SERVER_ROOT_PATH": "",
},
),
patch(
"litellm.proxy.management_endpoints.ui_sso.get_user_info_from_db",
return_value=mock_user_info,
),
patch("litellm.proxy.proxy_server.prisma_client", mock_prisma),
patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache),
patch("litellm.proxy.proxy_server.user_custom_sso", None),
patch(
"litellm.proxy.common_utils.html_forms.cli_sso_success.render_cli_sso_success_page",
return_value="<html>Success</html>",
),
):
await ui_sso.cli_sso_callback(
request=mock_request,
key=session_key,
result=mock_sso_result,
)
flow_data = mock_cache.set_cache.call_args.kwargs["value"]
assert flow_data["session_data"]["attribution_metadata"] == {
"acme_employment_type": "contractor"
}
mock_prisma.db.litellm_usertable.update_many.assert_awaited_once()
update_data = mock_prisma.db.litellm_usertable.update_many.call_args.kwargs[
"data"
]
assert update_data["metadata"]["acme_employment_type"] == "contractor"
assert update_data["metadata"]["auth_provider"] == "generic"
@pytest.mark.asyncio
async def test_cli_poll_key_returns_attribution_metadata(self, monkeypatch):
from litellm.proxy.management_endpoints.ui_sso import (
_hash_cli_sso_secret,
cli_poll_key,
)
session_key = "cli-session-789123"
session_data = {
"user_id": "test-user-456",
"user_role": "internal_user",
"teams": ["team-a", "team-b"],
"models": ["gpt-4"],
"attribution_metadata": {
"acme_employment_type": "full_time",
"org": {"cost_center": "CC-42"},
},
}
mock_cache = MagicMock()
mock_cache.get_cache.return_value = {
"poll_secret_hash": _hash_cli_sso_secret("poll-secret"),
"sso_complete": True,
"user_code_verified": True,
"session_data": session_data,
}
with patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache):
result = await cli_poll_key(
key_id=session_key,
team_id=None,
x_litellm_cli_poll_secret="poll-secret",
)
assert result["attribution_metadata"] == {
"acme_employment_type": "full_time",
"org.cost_center": "CC-42",
}
class TestValidateReturnTo:
"""Tests for SSOAuthenticationHandler._validate_return_to"""