diff --git a/litellm/constants.py b/litellm/constants.py index e36746326c..fb765c0226 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -1443,6 +1443,12 @@ CLI_JWT_EXPIRATION_HOURS = int( or os.getenv("LITELLM_CLI_JWT_EXPIRATION_HOURS") or 24 ) +# Comma-separated allowlisted OIDC claim map for CLI SSO polling, e.g. +# "employment_type->acme_employment_type,org_info.department->department" +CLI_SSO_CLAIM_MAP = ( + os.getenv("CLI_SSO_CLAIM_MAP") or os.getenv("LITELLM_CLI_SSO_CLAIM_MAP") or "" +) +CLI_SSO_CLAIM_MAX_SCALAR_LENGTH = 1024 ########################### UI SESSION DURATION ########################### # Duration for UI login session (username/password, SSO, invitation links). Format: "30s", "30m", "24h", "7d" diff --git a/litellm/proxy/client/README.md b/litellm/proxy/client/README.md index adf562d69c..9fbc6f2197 100644 --- a/litellm/proxy/client/README.md +++ b/litellm/proxy/client/README.md @@ -350,7 +350,7 @@ The CLI provides three authentication commands: 4. **User Authentication**: User completes SSO authentication in browser 5. **Callback Processing**: SSO provider redirects back to proxy with state parameter 6. **User Code Verification**: Browser confirms the verification code shown in the CLI -7. **Polling**: CLI polls `/sso/cli/poll/{login_id}` with the polling secret header until the JWT is ready +7. **Polling**: CLI polls `/sso/cli/poll/{login_id}` with the polling secret header until the JWT is ready. When `CLI_SSO_CLAIM_MAP` is configured on the proxy, the poll response may include `attribution_metadata` (allowlisted scalar OIDC claims for client attribution). 8. **Token Storage**: CLI saves the authentication token to `~/.litellm/token.json` ### Benefits of This Approach diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index ff3bbf4738..d3e1099d96 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -43,6 +43,8 @@ from litellm.caching.dual_cache import DualCache from litellm._logging import verbose_proxy_logger from litellm._uuid import uuid from litellm.constants import ( + CLI_SSO_CLAIM_MAP, + CLI_SSO_CLAIM_MAX_SCALAR_LENGTH, CLI_SSO_SESSION_CACHE_KEY_PREFIX, CLI_SSO_SESSION_TTL_SECONDS, LITELLM_CLI_SOURCE_IDENTIFIER, @@ -140,6 +142,20 @@ _CLI_SSO_START_RATE_LIMIT_WINDOW_SECONDS = 60 _CLI_SSO_START_RATE_LIMIT_MAX_ATTEMPTS = 30 _CLI_SSO_USER_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789" _CLI_SSO_LOGIN_ID_RE = re.compile(r"^cli-[A-Za-z0-9_-]{12,124}$") +_CLI_SSO_SCALAR_TYPES = (str, int, float, bool) +_CLI_SSO_DEST_KEY_RE = re.compile(r"^[A-Za-z0-9_.-]+$") +_CLI_SSO_SECRET_KEY_FRAGMENTS = frozenset( + { + "access_token", + "api_key", + "client_secret", + "id_token", + "password", + "private_key", + "refresh_token", + "secret", + } +) def _hash_cli_sso_secret(secret: str) -> str: @@ -225,6 +241,239 @@ def _verify_cli_sso_poll_secret(flow: dict, poll_secret: Optional[str]) -> bool: return secrets.compare_digest(supplied_poll_secret_hash, expected_poll_secret_hash) +def _parse_cli_sso_claim_map() -> List[Tuple[str, str]]: + """ + Parse CLI_SSO_CLAIM_MAP / LITELLM_CLI_SSO_CLAIM_MAP. + + Format: comma-separated ``source_claim->metadata_key`` pairs, e.g. + ``employment_type->acme_employment_type,org_info.department->department``. + Destination keys may use an optional ``metadata.`` prefix; values are stored + on the LiteLLM user's ``metadata`` JSON column. + """ + claim_map_raw = CLI_SSO_CLAIM_MAP.strip() + if not claim_map_raw: + return [] + + parsed: List[Tuple[str, str]] = [] + for entry in claim_map_raw.split(","): + entry = entry.strip() + if not entry or "->" not in entry: + continue + source_claim, dest_key = entry.split("->", 1) + source_claim = source_claim.strip() + dest_key = dest_key.strip() + if dest_key.startswith("metadata."): + dest_key = dest_key[len("metadata.") :] + if source_claim and dest_key: + parsed.append((source_claim, dest_key)) + return parsed + + +def _is_safe_cli_sso_metadata_dest_key(dest_key: str) -> bool: + if not dest_key or not _CLI_SSO_DEST_KEY_RE.fullmatch(dest_key): + return False + lowered = dest_key.lower() + return not any(fragment in lowered for fragment in _CLI_SSO_SECRET_KEY_FRAGMENTS) + + +def _is_safe_cli_sso_scalar_claim_value(value: Any) -> bool: + if not isinstance(value, _CLI_SSO_SCALAR_TYPES): + return False + if isinstance(value, str): + if len(value) > CLI_SSO_CLAIM_MAX_SCALAR_LENGTH: + return False + if value.startswith("eyJ") and value.count(".") >= 2: + return False + return True + + +def _sso_result_to_dict(result: Union[CustomOpenID, OpenID, dict]) -> Dict[str, Any]: + if isinstance(result, dict): + return result + if hasattr(result, "model_dump"): + dumped = result.model_dump() + if isinstance(dumped, dict): + return cast(Dict[str, Any], dumped) + return {} + + +def _get_nested_claim_value(data: Dict[str, Any], claim_path: str) -> Any: + """Resolve a dot-notation claim path against an SSO result dict. + + Unlike ``get_nested_value``, this does not strip a leading ``metadata.`` + prefix, since OIDC claims may legitimately use ``metadata`` as a top-level + key. + """ + if not claim_path: + return None + if claim_path in data: + return data[claim_path] + placeholder = "\x00" + parts = claim_path.replace("\\.", placeholder).split(".") + parts = [p.replace(placeholder, ".") for p in parts] + current: Any = data + for part in parts: + if isinstance(current, dict) and part in current: + current = current[part] + else: + return None + return current + + +def _extract_sso_claim_value( + result: Union[CustomOpenID, OpenID, dict], claim_path: str +) -> Any: + extra_fields = getattr(result, "extra_fields", None) + if isinstance(extra_fields, dict): + if claim_path in extra_fields: + return extra_fields[claim_path] + nested = _get_nested_claim_value(extra_fields, claim_path) + if nested is not None: + return nested + + if isinstance(result, dict): + return _get_nested_claim_value(result, claim_path) + + result_dict = _sso_result_to_dict(result) + return _get_nested_claim_value(result_dict, claim_path) + + +def _set_nested_metadata_value( + metadata: Dict[str, Any], key_path: str, value: Any +) -> None: + placeholder = "\x00" + parts = key_path.replace("\\.", placeholder).split(".") + parts = [p.replace(placeholder, ".") for p in parts] + current: Any = metadata + for part in parts[:-1]: + existing = current.get(part) + if not isinstance(existing, dict): + existing = {} + current[part] = existing + current = existing + current[parts[-1]] = value + + +def _flatten_cli_sso_metadata_for_poll( + metadata: Dict[str, Any], +) -> Dict[str, Union[str, int, float, bool]]: + """Expose scalar attribution metadata as a flat dict for CLI poll responses.""" + flattened: Dict[str, Union[str, int, float, bool]] = {} + stack: List[Tuple[str, Any]] = [("", metadata)] + while stack: + prefix, value = stack.pop() + if isinstance(value, dict): + for key, nested in value.items(): + nested_prefix = f"{prefix}.{key}" if prefix else key + stack.append((nested_prefix, nested)) + elif _is_safe_cli_sso_scalar_claim_value(value): + flattened[prefix] = value + return flattened + + +def build_cli_sso_attribution_metadata( + result: Union[CustomOpenID, OpenID, dict], +) -> Dict[str, Any]: + """ + Build allowlisted, non-secret scalar attribution metadata from an SSO result. + + Sources are configured via CLI_SSO_CLAIM_MAP / LITELLM_CLI_SSO_CLAIM_MAP and + may include claims captured by GENERIC_USER_EXTRA_ATTRIBUTES on CustomOpenID. + """ + claim_map = _parse_cli_sso_claim_map() + if not claim_map: + return {} + + metadata: Dict[str, Any] = {} + for source_claim, dest_key in claim_map: + if not _is_safe_cli_sso_metadata_dest_key(dest_key): + verbose_proxy_logger.debug( + f"Skipping unsafe CLI SSO metadata destination key: {dest_key}" + ) + continue + + raw_value = _extract_sso_claim_value(result=result, claim_path=source_claim) + if not _is_safe_cli_sso_scalar_claim_value(raw_value): + continue + + _set_nested_metadata_value( + metadata=metadata, key_path=dest_key, value=raw_value + ) + + return metadata + + +def _merge_cli_sso_attribution_metadata( + existing_metadata: Dict[str, Any], attribution_metadata: Dict[str, Any] +) -> Dict[str, Any]: + """Merge attribution metadata into existing user metadata in-place. + + Preserves original value types (in particular, string claim values that + happen to look numeric are NOT coerced to ``int``/``float``). Nested dicts + are merged iteratively so attribution claims do not clobber unrelated keys + under the same parent. + """ + pending: List[Tuple[Dict[str, Any], Dict[str, Any]]] = [ + (existing_metadata, attribution_metadata) + ] + while pending: + target, source = pending.pop() + for key, value in source.items(): + if value is None: + continue + existing_value = target.get(key) + if isinstance(value, dict) and isinstance(existing_value, dict): + pending.append((existing_value, value)) + else: + target[key] = value + return existing_metadata + + +async def _persist_cli_sso_user_metadata( + prisma_client: PrismaClient, + user_id: str, + attribution_metadata: Dict[str, Any], +) -> None: + if not attribution_metadata: + return + + try: + user_row = await prisma_client.db.litellm_usertable.find_unique( + where={"user_id": user_id} + ) + existing_metadata: Dict[str, Any] = {} + if user_row is not None: + row_metadata = user_row.metadata + if isinstance(row_metadata, dict): + existing_metadata = deepcopy(row_metadata) + + merged_metadata = _merge_cli_sso_attribution_metadata( + existing_metadata=existing_metadata, + attribution_metadata=attribution_metadata, + ) + await prisma_client.db.litellm_usertable.update_many( + where={"user_id": user_id}, + data={"metadata": merged_metadata}, + ) + verbose_proxy_logger.info( + f"Persisted CLI SSO attribution metadata for user {user_id}: " + f"{list(_flatten_cli_sso_metadata_for_poll(attribution_metadata).keys())}" + ) + except Exception as e: + verbose_proxy_logger.error( + f"Failed to persist CLI SSO attribution metadata for user {user_id}: {e}" + ) + + +def _cli_poll_attribution_metadata_from_session( + session_data: Dict[str, Any], +) -> Dict[str, Union[str, int, float, bool]]: + stored = session_data.get("attribution_metadata") + if isinstance(stored, dict): + return _flatten_cli_sso_metadata_for_poll(stored) + return {} + + def _render_cli_sso_verification_page( verify_url: str, browser_complete_token: str ) -> str: @@ -1674,7 +1923,12 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa: key_id = state_parts[1] if len(state_parts) > 1 else None verbose_proxy_logger.info("CLI SSO callback detected") - return await cli_sso_callback(request=request, key=key_id, result=result) + return await cli_sso_callback( + request=request, + key=key_id, + result=result, + received_response=received_response, + ) # Control-plane cross-origin: read return_to from cookie. # Starlette's cookie_parser already handles RFC 2109 unquoting. @@ -1692,15 +1946,144 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa: ) +async def _build_cli_sso_user_defined_values( + result: Union[OpenID, dict], + parsed_openid_result: ParsedOpenIDResult, +) -> Optional[SSOUserDefinedValues]: + from litellm.proxy.proxy_server import user_custom_sso + + user_id = parsed_openid_result.get("user_id") + if user_custom_sso is not None: + if inspect.iscoroutinefunction(user_custom_sso): + return await user_custom_sso(result) # type: ignore + raise ValueError("user_custom_sso must be a coroutine function") + if user_id is None: + return None + return SSOUserDefinedValues( + models=[], + user_id=user_id, + user_email=parsed_openid_result.get("user_email"), + max_budget=litellm.max_internal_user_budget, + user_role=parsed_openid_result.get("user_role"), + budget_duration=litellm.internal_user_budget_duration, + ) + + +async def _fetch_cli_sso_team_details( + prisma_client: PrismaClient, + teams: List[str], +) -> List[Dict[str, Any]]: + team_details: List[Dict[str, Any]] = [] + try: + if teams: + prisma_teams = await prisma_client.db.litellm_teamtable.find_many( + where={"team_id": {"in": teams}} + ) + for team_row in prisma_teams: + team_dict = team_row.model_dump() + team_details.append( + { + "team_id": team_dict.get("team_id"), + "team_alias": team_dict.get("team_alias"), + } + ) + except Exception as e: + verbose_proxy_logger.error( + f"Error fetching team details for CLI SSO session: {e}" + ) + return team_details + + +async def _complete_cli_sso_callback_session( + *, + request: Request, + key: str, + flow: dict, + result: Union[OpenID, dict], + parsed_openid_result: ParsedOpenIDResult, + user_defined_values: Optional[SSOUserDefinedValues], + prisma_client: PrismaClient, + user_api_key_cache: UserApiKeyCache, + proxy_logging_obj: ProxyLogging, +): + from fastapi.responses import HTMLResponse + + user_id = parsed_openid_result.get("user_id") + user_email = parsed_openid_result.get("user_email") + user_info = await get_user_info_from_db( + result=result, + prisma_client=prisma_client, + user_api_key_cache=user_api_key_cache, + proxy_logging_obj=proxy_logging_obj, + user_email=user_email, + user_defined_values=user_defined_values, + alternate_user_id=user_id, + ) + if user_info is None: + raise HTTPException( + status_code=500, detail="Failed to retrieve user information from SSO" + ) + if not user_info.user_id: + raise HTTPException( + status_code=500, detail="Failed to retrieve user information from SSO" + ) + + teams: List[str] = [] + if hasattr(user_info, "teams") and user_info.teams: + teams = user_info.teams if isinstance(user_info.teams, list) else [] + + team_details = await _fetch_cli_sso_team_details( + prisma_client=prisma_client, teams=teams + ) + attribution_metadata = build_cli_sso_attribution_metadata(result=result) + if attribution_metadata: + await _persist_cli_sso_user_metadata( + prisma_client=prisma_client, + user_id=cast(str, user_info.user_id), + attribution_metadata=attribution_metadata, + ) + + flow["session_data"] = { + "user_id": cast(str, user_info.user_id), + "user_role": user_info.user_role, + "models": user_info.models if hasattr(user_info, "models") else [], + "user_email": user_email, + "teams": teams, + "team_details": team_details, + "attribution_metadata": attribution_metadata, + } + flow["sso_complete"] = True + browser_complete_token = secrets.token_urlsafe(32) + flow["browser_complete_token_hash"] = _hash_cli_sso_secret(browser_complete_token) + _set_cli_sso_flow(login_id=key, cache=user_api_key_cache, flow=flow) + + verbose_proxy_logger.info( + f"Stored CLI SSO session for user: {user_info.user_id}, teams: {teams}, num_teams: {len(teams)}" + ) + verify_url = get_custom_url( + request_base_url=str(request.base_url), + route=f"sso/cli/complete/{key}", + ) + return HTMLResponse( + content=_render_cli_sso_verification_page( + verify_url=verify_url, + browser_complete_token=browser_complete_token, + ), + status_code=200, + ) + + async def cli_sso_callback( request: Request, key: Optional[str] = None, result: Optional[Union[OpenID, dict]] = None, + received_response: Optional[dict] = None, ): """CLI SSO callback - stores session info for JWT generation on polling""" verbose_proxy_logger.info("CLI SSO callback") from litellm.proxy.proxy_server import ( + general_settings, prisma_client, proxy_logging_obj, user_api_key_cache, @@ -1722,92 +2105,40 @@ async def cli_sso_callback( # After None check, cast to non-None type for type checker result_non_none: Union[OpenID, dict] = cast(Union[OpenID, dict], result) - parsed_openid_result = SSOAuthenticationHandler._get_user_email_and_id_from_result( - result=result_non_none - ) - verbose_proxy_logger.debug(f"parsed_openid_result: {parsed_openid_result}") - try: - # Get full user info from DB - user_info = await get_user_info_from_db( + parsed_openid_result = ( + SSOAuthenticationHandler._get_user_email_and_id_from_result( + result=result_non_none, + generic_client_id=os.getenv("GENERIC_CLIENT_ID", None), + ) + ) + verbose_proxy_logger.debug(f"parsed_openid_result: {parsed_openid_result}") + user_defined_values = await _build_cli_sso_user_defined_values( result=result_non_none, + parsed_openid_result=parsed_openid_result, + ) + + SSOAuthenticationHandler.verify_user_in_restricted_sso_group( + general_settings=general_settings, + result=result_non_none, + received_response=received_response, + ) + + return await _complete_cli_sso_callback_session( + request=request, + key=cast(str, key), + flow=flow, + result=result_non_none, + parsed_openid_result=parsed_openid_result, + user_defined_values=user_defined_values, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, proxy_logging_obj=proxy_logging_obj, - user_email=parsed_openid_result.get("user_email"), - user_defined_values=None, - alternate_user_id=parsed_openid_result.get("user_id"), ) - - if user_info is None: - raise HTTPException( - status_code=500, detail="Failed to retrieve user information from SSO" - ) - - # Get all teams from user_info - CLI will let user select which one - teams: List[str] = [] - if hasattr(user_info, "teams") and user_info.teams: - teams = user_info.teams if isinstance(user_info.teams, list) else [] - - # Also fetch team aliases for a better CLI UX. We keep the original - # "teams" list of IDs for backwards compatibility and add an - # optional "team_details" field containing objects with both - # team_id and team_alias. - team_details: List[Dict[str, Any]] = [] - try: - if teams: - prisma_teams = await prisma_client.db.litellm_teamtable.find_many( - where={"team_id": {"in": teams}} - ) - for team_row in prisma_teams: - team_dict = team_row.model_dump() - team_details.append( - { - "team_id": team_dict.get("team_id"), - "team_alias": team_dict.get("team_alias"), - } - ) - except Exception as e: - # If anything goes wrong here, fall back gracefully without - # impacting the SSO flow. - verbose_proxy_logger.error( - f"Error fetching team details for CLI SSO session: {e}" - ) - - session_data = { - "user_id": user_info.user_id, - "user_role": user_info.user_role, - "models": user_info.models if hasattr(user_info, "models") else [], - "user_email": parsed_openid_result.get("user_email"), - "teams": teams, - # Optional rich metadata for clients that want nicer display - "team_details": team_details, - } - - flow["session_data"] = session_data - flow["sso_complete"] = True - browser_complete_token = secrets.token_urlsafe(32) - flow["browser_complete_token_hash"] = _hash_cli_sso_secret( - browser_complete_token - ) - _set_cli_sso_flow(login_id=cast(str, key), cache=user_api_key_cache, flow=flow) - - verbose_proxy_logger.info( - f"Stored CLI SSO session for user: {user_info.user_id}, teams: {teams}, num_teams: {len(teams)}" - ) - - from fastapi.responses import HTMLResponse - - verify_url = get_custom_url( - request_base_url=str(request.base_url), - route=f"sso/cli/complete/{key}", - ) - html_content = _render_cli_sso_verification_page( - verify_url=verify_url, - browser_complete_token=browser_complete_token, - ) - return HTMLResponse(content=html_content, status_code=200) - + except ProxyException: + raise + except HTTPException: + raise except Exception as e: verbose_proxy_logger.error(f"Error with CLI SSO callback: {e}") raise HTTPException( @@ -1873,13 +2204,19 @@ async def cli_poll_key( team_details_response = [ {"team_id": t, "team_alias": None} for t in user_teams ] - return { + poll_response: Dict[str, Any] = { "status": "ready", "user_id": user_id, "teams": user_teams, "team_details": team_details_response, "requires_team_selection": True, } + attribution_metadata = _cli_poll_attribution_metadata_from_session( + session_data + ) + if attribution_metadata: + poll_response["attribution_metadata"] = attribution_metadata + return poll_response # Validate team_id if provided if team_id is not None: @@ -1912,7 +2249,7 @@ async def cli_poll_key( verbose_proxy_logger.info( f"CLI JWT generated for user: {user_id}, team: {team_id}" ) - return { + poll_response = { "status": "ready", "key": jwt_token, "user_id": user_id, @@ -1922,6 +2259,12 @@ async def cli_poll_key( # present nicer information if needed. "team_details": user_team_details, } + attribution_metadata = _cli_poll_attribution_metadata_from_session( + session_data + ) + if attribution_metadata: + poll_response["attribution_metadata"] = attribution_metadata + return poll_response else: return {"status": "pending"} diff --git a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py index 23216542f3..a72633b726 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py +++ b/tests/test_litellm/proxy/management_endpoints/test_ui_sso.py @@ -2438,6 +2438,7 @@ class TestCLIKeyRegenerationFlow: request=mock_request, key="cli-new-session-key-456", result=mock_result, + received_response=None, ) def test_get_redirect_url_does_not_include_existing_key_in_url(self): @@ -5552,6 +5553,289 @@ def test_generic_response_convertor_extra_attributes_missing_field(monkeypatch): assert result.extra_fields["another_missing"] is None +class TestCliSsoAttributionMetadata: + """CLI SSO allowlisted OIDC claim persistence and poll exposure.""" + + def test_parse_cli_sso_claim_map(self, monkeypatch): + from litellm.proxy.management_endpoints import ui_sso + + monkeypatch.setattr( + ui_sso, + "CLI_SSO_CLAIM_MAP", + "employment_type->metadata.acme_employment_type, org_info.department -> department", + ) + assert ui_sso._parse_cli_sso_claim_map() == [ + ("employment_type", "acme_employment_type"), + ("org_info.department", "department"), + ] + + def test_build_cli_sso_attribution_metadata_filters_non_scalars(self, monkeypatch): + from litellm.proxy.management_endpoints import ui_sso + from litellm.proxy.management_endpoints.types import CustomOpenID + + monkeypatch.setattr( + ui_sso, + "CLI_SSO_CLAIM_MAP", + "employment_type->acme_employment_type,access_token->should_drop,group->groups", + ) + + result = CustomOpenID( + id="user-1", + email="user@example.com", + display_name="User", + provider="generic", + team_ids=[], + extra_fields={ + "employment_type": "full_time", + "access_token": "eyJhbGciOiJIUzI1NiJ9.payload.signature", + "group": ["team-a", "team-b"], + }, + ) + + metadata = ui_sso.build_cli_sso_attribution_metadata(result=result) + assert metadata == {"acme_employment_type": "full_time"} + + def test_build_cli_sso_attribution_metadata_from_oidc_dict(self, monkeypatch): + from litellm.proxy.management_endpoints import ui_sso + + monkeypatch.setattr( + ui_sso, + "CLI_SSO_CLAIM_MAP", + "org_info.department->department", + ) + + metadata = ui_sso.build_cli_sso_attribution_metadata( + result={ + "sub": "user-1", + "email": "user@example.com", + "org_info": {"department": "Engineering"}, + } + ) + assert metadata == {"department": "Engineering"} + + @pytest.mark.asyncio + async def test_cli_sso_callback_passes_user_defined_values_for_new_users(self): + """First CLI SSO login must supply SSOUserDefinedValues so upsert can create the user.""" + from litellm.proxy._types import LiteLLM_UserTable + from litellm.proxy.management_endpoints import ui_sso + from litellm.proxy.management_endpoints.types import CustomOpenID + + mock_request = MagicMock(spec=Request) + mock_request.base_url = "http://internal-proxy.local/" + session_key = "cli-session-new-user" + mock_user_info = LiteLLM_UserTable( + user_id="cli-test-user", + user_role="internal_user", + teams=[], + models=[], + ) + mock_sso_result = CustomOpenID( + id="cli-test-user", + email="cli-test@example.com", + display_name="cli-test-user", + provider="generic", + team_ids=[], + ) + mock_cache = MagicMock() + mock_cache.get_cache.return_value = { + "poll_secret_hash": "poll-secret-hash", + "user_code_hash": "user-code-hash", + "sso_complete": False, + "user_code_verified": False, + "session_data": None, + } + get_user_info_mock = AsyncMock(return_value=mock_user_info) + + with ( + patch( + "litellm.proxy.management_endpoints.ui_sso.get_user_info_from_db", + get_user_info_mock, + ), + patch("litellm.proxy.proxy_server.prisma_client", MagicMock()), + patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache), + patch("litellm.proxy.proxy_server.user_custom_sso", None), + ): + await ui_sso.cli_sso_callback( + request=mock_request, + key=session_key, + result=mock_sso_result, + ) + + get_user_info_mock.assert_awaited_once() + assert get_user_info_mock.call_args.kwargs["user_defined_values"] is not None + assert ( + get_user_info_mock.call_args.kwargs["user_defined_values"]["user_id"] + == "cli-test-user" + ) + + @pytest.mark.asyncio + async def test_cli_sso_callback_rejects_restricted_sso_group(self): + """CLI SSO must enforce restricted_sso_group before upserting the user.""" + from litellm.proxy._types import ProxyException + from litellm.proxy.management_endpoints import ui_sso + from litellm.proxy.management_endpoints.types import CustomOpenID + + mock_request = MagicMock(spec=Request) + mock_request.base_url = "http://internal-proxy.local/" + mock_cache = MagicMock() + mock_cache.get_cache.return_value = { + "poll_secret_hash": "poll-secret-hash", + "user_code_hash": "user-code-hash", + "sso_complete": False, + "user_code_verified": False, + "session_data": None, + } + mock_sso_result = CustomOpenID( + id="cli-test-user", + email="cli-test@example.com", + display_name="cli-test-user", + provider="generic", + team_ids=["other-group"], + ) + + with ( + patch( + "litellm.proxy.management_endpoints.ui_sso.get_user_info_from_db", + new=AsyncMock(), + ) as get_user_info_mock, + patch("litellm.proxy.proxy_server.prisma_client", MagicMock()), + patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache), + patch("litellm.proxy.proxy_server.user_custom_sso", None), + patch( + "litellm.proxy.proxy_server.general_settings", + { + "ui_access_mode": { + "type": "restricted_sso_group", + "restricted_sso_group": "required-group", + } + }, + ), + ): + with pytest.raises(ProxyException): + await ui_sso.cli_sso_callback( + request=mock_request, + key="cli-session-restricted", + result=mock_sso_result, + received_response={"groups": ["other-group"]}, + ) + + get_user_info_mock.assert_not_awaited() + + @pytest.mark.asyncio + async def test_cli_sso_callback_persists_attribution_metadata(self, monkeypatch): + from litellm.proxy._types import LiteLLM_UserTable + from litellm.proxy.management_endpoints import ui_sso + + monkeypatch.setattr( + ui_sso, + "CLI_SSO_CLAIM_MAP", + "employment_type->acme_employment_type", + ) + + mock_request = MagicMock(spec=Request) + mock_request.base_url = "http://internal-proxy.local/" + session_key = "cli-session-4567890" + mock_user_info = LiteLLM_UserTable( + user_id="test-user-123", + user_role="internal_user", + teams=["team1"], + models=["gpt-4"], + ) + mock_sso_result = { + "user_email": "test@example.com", + "user_id": "test-user-123", + "employment_type": "contractor", + } + mock_cache = MagicMock() + mock_cache.get_cache.return_value = { + "poll_secret_hash": "poll-secret-hash", + "user_code_hash": "user-code-hash", + "sso_complete": False, + "user_code_verified": False, + "session_data": None, + } + mock_prisma = MagicMock() + mock_prisma.db.litellm_usertable.find_unique = AsyncMock( + return_value=MagicMock(metadata={"auth_provider": "generic"}) + ) + mock_prisma.db.litellm_usertable.update_many = AsyncMock() + + with ( + patch.dict( + os.environ, + { + "PROXY_BASE_URL": "https://test.example.com", + "SERVER_ROOT_PATH": "", + }, + ), + patch( + "litellm.proxy.management_endpoints.ui_sso.get_user_info_from_db", + return_value=mock_user_info, + ), + patch("litellm.proxy.proxy_server.prisma_client", mock_prisma), + patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache), + patch("litellm.proxy.proxy_server.user_custom_sso", None), + patch( + "litellm.proxy.common_utils.html_forms.cli_sso_success.render_cli_sso_success_page", + return_value="Success", + ), + ): + await ui_sso.cli_sso_callback( + request=mock_request, + key=session_key, + result=mock_sso_result, + ) + + flow_data = mock_cache.set_cache.call_args.kwargs["value"] + assert flow_data["session_data"]["attribution_metadata"] == { + "acme_employment_type": "contractor" + } + mock_prisma.db.litellm_usertable.update_many.assert_awaited_once() + update_data = mock_prisma.db.litellm_usertable.update_many.call_args.kwargs[ + "data" + ] + assert update_data["metadata"]["acme_employment_type"] == "contractor" + assert update_data["metadata"]["auth_provider"] == "generic" + + @pytest.mark.asyncio + async def test_cli_poll_key_returns_attribution_metadata(self, monkeypatch): + from litellm.proxy.management_endpoints.ui_sso import ( + _hash_cli_sso_secret, + cli_poll_key, + ) + + session_key = "cli-session-789123" + session_data = { + "user_id": "test-user-456", + "user_role": "internal_user", + "teams": ["team-a", "team-b"], + "models": ["gpt-4"], + "attribution_metadata": { + "acme_employment_type": "full_time", + "org": {"cost_center": "CC-42"}, + }, + } + mock_cache = MagicMock() + mock_cache.get_cache.return_value = { + "poll_secret_hash": _hash_cli_sso_secret("poll-secret"), + "sso_complete": True, + "user_code_verified": True, + "session_data": session_data, + } + + with patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache): + result = await cli_poll_key( + key_id=session_key, + team_id=None, + x_litellm_cli_poll_secret="poll-secret", + ) + + assert result["attribution_metadata"] == { + "acme_employment_type": "full_time", + "org.cost_center": "CC-42", + } + + class TestValidateReturnTo: """Tests for SSOAuthenticationHandler._validate_return_to"""