feat(proxy): persist allowlisted OIDC claims in CLI SSO poll (#28463)
* feat(proxy): persist allowlisted OIDC claims in CLI SSO poll
Map CLI_SSO_CLAIM_MAP sources into user metadata and return scalar
attribution_metadata from /sso/cli/poll. Build SSOUserDefinedValues in
cli_sso_callback so first-time CLI logins can upsert users. Add mock OIDC
scripts and tests for claim extraction and poll exposure.
Co-authored-by: Cursor <cursoragent@cursor.com>
* docs(proxy): document CLI SSO attribution_metadata in client README
Co-authored-by: Cursor <cursoragent@cursor.com>
* Delete scripts/mock_oidc_server_for_cli_sso.py
* Delete scripts/test_cli_sso_claims_e2e.py
* fix(ui_sso): preserve claim types and avoid metadata. prefix stripping
- Replace _update_dictionary with a local recursive merge so string
OIDC claim values that happen to look numeric are not silently coerced
to int/float when persisting CLI SSO attribution metadata.
- Use a local dot-path resolver in _extract_sso_claim_value so that
source claim paths beginning with 'metadata.' are not silently stripped
by get_nested_value (which is designed for LiteLLM JWT metadata, not
arbitrary OIDC claims).
Co-authored-by: Yassin Kortam <yassin@berri.ai>
* Remove redundant metadata. prefix strip in _set_nested_metadata_value
The _parse_cli_sso_claim_map already strips the metadata. prefix from
dest keys before reaching the setter. The duplicate strip in
_set_nested_metadata_value was a no-op in normal flow but could
mis-place values for dest keys like metadata.metadata.foo.
Co-authored-by: Yassin Kortam <yassin@berri.ai>
* Fix greptile
* Fix ruff
* Move CLI SSO user defined values build inside try/except for consistent error handling
Co-authored-by: Yassin Kortam <yassin@berri.ai>
* fix(proxy): enforce restricted SSO group on CLI SSO callback
Apply verify_user_in_restricted_sso_group before CLI session completion
and user upsert, matching the UI SSO path. Re-raise ProxyException so
restricted-group denials return 403 instead of 500.
Co-authored-by: Cursor <cursoragent@cursor.com>
* fix(proxy): replace recursive CLI SSO metadata helpers with iterative merge
Use stack-based flatten/merge to satisfy recursive_detector CI. Fix mypy
types for UserApiKeyCache and user_id on CLI SSO session completion.
Co-authored-by: Cursor <cursoragent@cursor.com>
* fix: resolve nested CustomOpenID extra_fields in CLI SSO claim extraction
When GENERIC_USER_EXTRA_ATTRIBUTES captures a parent object (e.g. org_info),
extra_fields stores it as {"org_info": {"department": "..."}}. A CLI claim
map entry using a dotted path like org_info.department would silently fail
because the lookup only checked the exact flat key. Fall back to dotted-path
resolution on extra_fields before model_dump().
Co-authored-by: Yassin Kortam <yassin@berri.ai>
* fix(sso): update CLI SSO test for new received_response kwarg and remove redundant 'token' secret fragment
Co-authored-by: Yassin Kortam <yassin@berri.ai>
---------
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Yassin Kortam <yassin@berri.ai>
This commit is contained in:
parent
ef36e89638
commit
50a3f10a92
@ -1443,6 +1443,12 @@ CLI_JWT_EXPIRATION_HOURS = int(
|
||||
or os.getenv("LITELLM_CLI_JWT_EXPIRATION_HOURS")
|
||||
or 24
|
||||
)
|
||||
# Comma-separated allowlisted OIDC claim map for CLI SSO polling, e.g.
|
||||
# "employment_type->acme_employment_type,org_info.department->department"
|
||||
CLI_SSO_CLAIM_MAP = (
|
||||
os.getenv("CLI_SSO_CLAIM_MAP") or os.getenv("LITELLM_CLI_SSO_CLAIM_MAP") or ""
|
||||
)
|
||||
CLI_SSO_CLAIM_MAX_SCALAR_LENGTH = 1024
|
||||
|
||||
########################### UI SESSION DURATION ###########################
|
||||
# Duration for UI login session (username/password, SSO, invitation links). Format: "30s", "30m", "24h", "7d"
|
||||
|
||||
@ -350,7 +350,7 @@ The CLI provides three authentication commands:
|
||||
4. **User Authentication**: User completes SSO authentication in browser
|
||||
5. **Callback Processing**: SSO provider redirects back to proxy with state parameter
|
||||
6. **User Code Verification**: Browser confirms the verification code shown in the CLI
|
||||
7. **Polling**: CLI polls `/sso/cli/poll/{login_id}` with the polling secret header until the JWT is ready
|
||||
7. **Polling**: CLI polls `/sso/cli/poll/{login_id}` with the polling secret header until the JWT is ready. When `CLI_SSO_CLAIM_MAP` is configured on the proxy, the poll response may include `attribution_metadata` (allowlisted scalar OIDC claims for client attribution).
|
||||
8. **Token Storage**: CLI saves the authentication token to `~/.litellm/token.json`
|
||||
|
||||
### Benefits of This Approach
|
||||
|
||||
@ -43,6 +43,8 @@ from litellm.caching.dual_cache import DualCache
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.constants import (
|
||||
CLI_SSO_CLAIM_MAP,
|
||||
CLI_SSO_CLAIM_MAX_SCALAR_LENGTH,
|
||||
CLI_SSO_SESSION_CACHE_KEY_PREFIX,
|
||||
CLI_SSO_SESSION_TTL_SECONDS,
|
||||
LITELLM_CLI_SOURCE_IDENTIFIER,
|
||||
@ -140,6 +142,20 @@ _CLI_SSO_START_RATE_LIMIT_WINDOW_SECONDS = 60
|
||||
_CLI_SSO_START_RATE_LIMIT_MAX_ATTEMPTS = 30
|
||||
_CLI_SSO_USER_CODE_ALPHABET = "ABCDEFGHJKLMNPQRSTUVWXYZ23456789"
|
||||
_CLI_SSO_LOGIN_ID_RE = re.compile(r"^cli-[A-Za-z0-9_-]{12,124}$")
|
||||
_CLI_SSO_SCALAR_TYPES = (str, int, float, bool)
|
||||
_CLI_SSO_DEST_KEY_RE = re.compile(r"^[A-Za-z0-9_.-]+$")
|
||||
_CLI_SSO_SECRET_KEY_FRAGMENTS = frozenset(
|
||||
{
|
||||
"access_token",
|
||||
"api_key",
|
||||
"client_secret",
|
||||
"id_token",
|
||||
"password",
|
||||
"private_key",
|
||||
"refresh_token",
|
||||
"secret",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _hash_cli_sso_secret(secret: str) -> str:
|
||||
@ -225,6 +241,239 @@ def _verify_cli_sso_poll_secret(flow: dict, poll_secret: Optional[str]) -> bool:
|
||||
return secrets.compare_digest(supplied_poll_secret_hash, expected_poll_secret_hash)
|
||||
|
||||
|
||||
def _parse_cli_sso_claim_map() -> List[Tuple[str, str]]:
|
||||
"""
|
||||
Parse CLI_SSO_CLAIM_MAP / LITELLM_CLI_SSO_CLAIM_MAP.
|
||||
|
||||
Format: comma-separated ``source_claim->metadata_key`` pairs, e.g.
|
||||
``employment_type->acme_employment_type,org_info.department->department``.
|
||||
Destination keys may use an optional ``metadata.`` prefix; values are stored
|
||||
on the LiteLLM user's ``metadata`` JSON column.
|
||||
"""
|
||||
claim_map_raw = CLI_SSO_CLAIM_MAP.strip()
|
||||
if not claim_map_raw:
|
||||
return []
|
||||
|
||||
parsed: List[Tuple[str, str]] = []
|
||||
for entry in claim_map_raw.split(","):
|
||||
entry = entry.strip()
|
||||
if not entry or "->" not in entry:
|
||||
continue
|
||||
source_claim, dest_key = entry.split("->", 1)
|
||||
source_claim = source_claim.strip()
|
||||
dest_key = dest_key.strip()
|
||||
if dest_key.startswith("metadata."):
|
||||
dest_key = dest_key[len("metadata.") :]
|
||||
if source_claim and dest_key:
|
||||
parsed.append((source_claim, dest_key))
|
||||
return parsed
|
||||
|
||||
|
||||
def _is_safe_cli_sso_metadata_dest_key(dest_key: str) -> bool:
|
||||
if not dest_key or not _CLI_SSO_DEST_KEY_RE.fullmatch(dest_key):
|
||||
return False
|
||||
lowered = dest_key.lower()
|
||||
return not any(fragment in lowered for fragment in _CLI_SSO_SECRET_KEY_FRAGMENTS)
|
||||
|
||||
|
||||
def _is_safe_cli_sso_scalar_claim_value(value: Any) -> bool:
|
||||
if not isinstance(value, _CLI_SSO_SCALAR_TYPES):
|
||||
return False
|
||||
if isinstance(value, str):
|
||||
if len(value) > CLI_SSO_CLAIM_MAX_SCALAR_LENGTH:
|
||||
return False
|
||||
if value.startswith("eyJ") and value.count(".") >= 2:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _sso_result_to_dict(result: Union[CustomOpenID, OpenID, dict]) -> Dict[str, Any]:
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
if hasattr(result, "model_dump"):
|
||||
dumped = result.model_dump()
|
||||
if isinstance(dumped, dict):
|
||||
return cast(Dict[str, Any], dumped)
|
||||
return {}
|
||||
|
||||
|
||||
def _get_nested_claim_value(data: Dict[str, Any], claim_path: str) -> Any:
|
||||
"""Resolve a dot-notation claim path against an SSO result dict.
|
||||
|
||||
Unlike ``get_nested_value``, this does not strip a leading ``metadata.``
|
||||
prefix, since OIDC claims may legitimately use ``metadata`` as a top-level
|
||||
key.
|
||||
"""
|
||||
if not claim_path:
|
||||
return None
|
||||
if claim_path in data:
|
||||
return data[claim_path]
|
||||
placeholder = "\x00"
|
||||
parts = claim_path.replace("\\.", placeholder).split(".")
|
||||
parts = [p.replace(placeholder, ".") for p in parts]
|
||||
current: Any = data
|
||||
for part in parts:
|
||||
if isinstance(current, dict) and part in current:
|
||||
current = current[part]
|
||||
else:
|
||||
return None
|
||||
return current
|
||||
|
||||
|
||||
def _extract_sso_claim_value(
|
||||
result: Union[CustomOpenID, OpenID, dict], claim_path: str
|
||||
) -> Any:
|
||||
extra_fields = getattr(result, "extra_fields", None)
|
||||
if isinstance(extra_fields, dict):
|
||||
if claim_path in extra_fields:
|
||||
return extra_fields[claim_path]
|
||||
nested = _get_nested_claim_value(extra_fields, claim_path)
|
||||
if nested is not None:
|
||||
return nested
|
||||
|
||||
if isinstance(result, dict):
|
||||
return _get_nested_claim_value(result, claim_path)
|
||||
|
||||
result_dict = _sso_result_to_dict(result)
|
||||
return _get_nested_claim_value(result_dict, claim_path)
|
||||
|
||||
|
||||
def _set_nested_metadata_value(
|
||||
metadata: Dict[str, Any], key_path: str, value: Any
|
||||
) -> None:
|
||||
placeholder = "\x00"
|
||||
parts = key_path.replace("\\.", placeholder).split(".")
|
||||
parts = [p.replace(placeholder, ".") for p in parts]
|
||||
current: Any = metadata
|
||||
for part in parts[:-1]:
|
||||
existing = current.get(part)
|
||||
if not isinstance(existing, dict):
|
||||
existing = {}
|
||||
current[part] = existing
|
||||
current = existing
|
||||
current[parts[-1]] = value
|
||||
|
||||
|
||||
def _flatten_cli_sso_metadata_for_poll(
|
||||
metadata: Dict[str, Any],
|
||||
) -> Dict[str, Union[str, int, float, bool]]:
|
||||
"""Expose scalar attribution metadata as a flat dict for CLI poll responses."""
|
||||
flattened: Dict[str, Union[str, int, float, bool]] = {}
|
||||
stack: List[Tuple[str, Any]] = [("", metadata)]
|
||||
while stack:
|
||||
prefix, value = stack.pop()
|
||||
if isinstance(value, dict):
|
||||
for key, nested in value.items():
|
||||
nested_prefix = f"{prefix}.{key}" if prefix else key
|
||||
stack.append((nested_prefix, nested))
|
||||
elif _is_safe_cli_sso_scalar_claim_value(value):
|
||||
flattened[prefix] = value
|
||||
return flattened
|
||||
|
||||
|
||||
def build_cli_sso_attribution_metadata(
|
||||
result: Union[CustomOpenID, OpenID, dict],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build allowlisted, non-secret scalar attribution metadata from an SSO result.
|
||||
|
||||
Sources are configured via CLI_SSO_CLAIM_MAP / LITELLM_CLI_SSO_CLAIM_MAP and
|
||||
may include claims captured by GENERIC_USER_EXTRA_ATTRIBUTES on CustomOpenID.
|
||||
"""
|
||||
claim_map = _parse_cli_sso_claim_map()
|
||||
if not claim_map:
|
||||
return {}
|
||||
|
||||
metadata: Dict[str, Any] = {}
|
||||
for source_claim, dest_key in claim_map:
|
||||
if not _is_safe_cli_sso_metadata_dest_key(dest_key):
|
||||
verbose_proxy_logger.debug(
|
||||
f"Skipping unsafe CLI SSO metadata destination key: {dest_key}"
|
||||
)
|
||||
continue
|
||||
|
||||
raw_value = _extract_sso_claim_value(result=result, claim_path=source_claim)
|
||||
if not _is_safe_cli_sso_scalar_claim_value(raw_value):
|
||||
continue
|
||||
|
||||
_set_nested_metadata_value(
|
||||
metadata=metadata, key_path=dest_key, value=raw_value
|
||||
)
|
||||
|
||||
return metadata
|
||||
|
||||
|
||||
def _merge_cli_sso_attribution_metadata(
|
||||
existing_metadata: Dict[str, Any], attribution_metadata: Dict[str, Any]
|
||||
) -> Dict[str, Any]:
|
||||
"""Merge attribution metadata into existing user metadata in-place.
|
||||
|
||||
Preserves original value types (in particular, string claim values that
|
||||
happen to look numeric are NOT coerced to ``int``/``float``). Nested dicts
|
||||
are merged iteratively so attribution claims do not clobber unrelated keys
|
||||
under the same parent.
|
||||
"""
|
||||
pending: List[Tuple[Dict[str, Any], Dict[str, Any]]] = [
|
||||
(existing_metadata, attribution_metadata)
|
||||
]
|
||||
while pending:
|
||||
target, source = pending.pop()
|
||||
for key, value in source.items():
|
||||
if value is None:
|
||||
continue
|
||||
existing_value = target.get(key)
|
||||
if isinstance(value, dict) and isinstance(existing_value, dict):
|
||||
pending.append((existing_value, value))
|
||||
else:
|
||||
target[key] = value
|
||||
return existing_metadata
|
||||
|
||||
|
||||
async def _persist_cli_sso_user_metadata(
|
||||
prisma_client: PrismaClient,
|
||||
user_id: str,
|
||||
attribution_metadata: Dict[str, Any],
|
||||
) -> None:
|
||||
if not attribution_metadata:
|
||||
return
|
||||
|
||||
try:
|
||||
user_row = await prisma_client.db.litellm_usertable.find_unique(
|
||||
where={"user_id": user_id}
|
||||
)
|
||||
existing_metadata: Dict[str, Any] = {}
|
||||
if user_row is not None:
|
||||
row_metadata = user_row.metadata
|
||||
if isinstance(row_metadata, dict):
|
||||
existing_metadata = deepcopy(row_metadata)
|
||||
|
||||
merged_metadata = _merge_cli_sso_attribution_metadata(
|
||||
existing_metadata=existing_metadata,
|
||||
attribution_metadata=attribution_metadata,
|
||||
)
|
||||
await prisma_client.db.litellm_usertable.update_many(
|
||||
where={"user_id": user_id},
|
||||
data={"metadata": merged_metadata},
|
||||
)
|
||||
verbose_proxy_logger.info(
|
||||
f"Persisted CLI SSO attribution metadata for user {user_id}: "
|
||||
f"{list(_flatten_cli_sso_metadata_for_poll(attribution_metadata).keys())}"
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
f"Failed to persist CLI SSO attribution metadata for user {user_id}: {e}"
|
||||
)
|
||||
|
||||
|
||||
def _cli_poll_attribution_metadata_from_session(
|
||||
session_data: Dict[str, Any],
|
||||
) -> Dict[str, Union[str, int, float, bool]]:
|
||||
stored = session_data.get("attribution_metadata")
|
||||
if isinstance(stored, dict):
|
||||
return _flatten_cli_sso_metadata_for_poll(stored)
|
||||
return {}
|
||||
|
||||
|
||||
def _render_cli_sso_verification_page(
|
||||
verify_url: str, browser_complete_token: str
|
||||
) -> str:
|
||||
@ -1674,7 +1923,12 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa:
|
||||
key_id = state_parts[1] if len(state_parts) > 1 else None
|
||||
|
||||
verbose_proxy_logger.info("CLI SSO callback detected")
|
||||
return await cli_sso_callback(request=request, key=key_id, result=result)
|
||||
return await cli_sso_callback(
|
||||
request=request,
|
||||
key=key_id,
|
||||
result=result,
|
||||
received_response=received_response,
|
||||
)
|
||||
|
||||
# Control-plane cross-origin: read return_to from cookie.
|
||||
# Starlette's cookie_parser already handles RFC 2109 unquoting.
|
||||
@ -1692,15 +1946,144 @@ async def auth_callback(request: Request, state: Optional[str] = None): # noqa:
|
||||
)
|
||||
|
||||
|
||||
async def _build_cli_sso_user_defined_values(
|
||||
result: Union[OpenID, dict],
|
||||
parsed_openid_result: ParsedOpenIDResult,
|
||||
) -> Optional[SSOUserDefinedValues]:
|
||||
from litellm.proxy.proxy_server import user_custom_sso
|
||||
|
||||
user_id = parsed_openid_result.get("user_id")
|
||||
if user_custom_sso is not None:
|
||||
if inspect.iscoroutinefunction(user_custom_sso):
|
||||
return await user_custom_sso(result) # type: ignore
|
||||
raise ValueError("user_custom_sso must be a coroutine function")
|
||||
if user_id is None:
|
||||
return None
|
||||
return SSOUserDefinedValues(
|
||||
models=[],
|
||||
user_id=user_id,
|
||||
user_email=parsed_openid_result.get("user_email"),
|
||||
max_budget=litellm.max_internal_user_budget,
|
||||
user_role=parsed_openid_result.get("user_role"),
|
||||
budget_duration=litellm.internal_user_budget_duration,
|
||||
)
|
||||
|
||||
|
||||
async def _fetch_cli_sso_team_details(
|
||||
prisma_client: PrismaClient,
|
||||
teams: List[str],
|
||||
) -> List[Dict[str, Any]]:
|
||||
team_details: List[Dict[str, Any]] = []
|
||||
try:
|
||||
if teams:
|
||||
prisma_teams = await prisma_client.db.litellm_teamtable.find_many(
|
||||
where={"team_id": {"in": teams}}
|
||||
)
|
||||
for team_row in prisma_teams:
|
||||
team_dict = team_row.model_dump()
|
||||
team_details.append(
|
||||
{
|
||||
"team_id": team_dict.get("team_id"),
|
||||
"team_alias": team_dict.get("team_alias"),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
f"Error fetching team details for CLI SSO session: {e}"
|
||||
)
|
||||
return team_details
|
||||
|
||||
|
||||
async def _complete_cli_sso_callback_session(
|
||||
*,
|
||||
request: Request,
|
||||
key: str,
|
||||
flow: dict,
|
||||
result: Union[OpenID, dict],
|
||||
parsed_openid_result: ParsedOpenIDResult,
|
||||
user_defined_values: Optional[SSOUserDefinedValues],
|
||||
prisma_client: PrismaClient,
|
||||
user_api_key_cache: UserApiKeyCache,
|
||||
proxy_logging_obj: ProxyLogging,
|
||||
):
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
user_id = parsed_openid_result.get("user_id")
|
||||
user_email = parsed_openid_result.get("user_email")
|
||||
user_info = await get_user_info_from_db(
|
||||
result=result,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
user_email=user_email,
|
||||
user_defined_values=user_defined_values,
|
||||
alternate_user_id=user_id,
|
||||
)
|
||||
if user_info is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to retrieve user information from SSO"
|
||||
)
|
||||
if not user_info.user_id:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to retrieve user information from SSO"
|
||||
)
|
||||
|
||||
teams: List[str] = []
|
||||
if hasattr(user_info, "teams") and user_info.teams:
|
||||
teams = user_info.teams if isinstance(user_info.teams, list) else []
|
||||
|
||||
team_details = await _fetch_cli_sso_team_details(
|
||||
prisma_client=prisma_client, teams=teams
|
||||
)
|
||||
attribution_metadata = build_cli_sso_attribution_metadata(result=result)
|
||||
if attribution_metadata:
|
||||
await _persist_cli_sso_user_metadata(
|
||||
prisma_client=prisma_client,
|
||||
user_id=cast(str, user_info.user_id),
|
||||
attribution_metadata=attribution_metadata,
|
||||
)
|
||||
|
||||
flow["session_data"] = {
|
||||
"user_id": cast(str, user_info.user_id),
|
||||
"user_role": user_info.user_role,
|
||||
"models": user_info.models if hasattr(user_info, "models") else [],
|
||||
"user_email": user_email,
|
||||
"teams": teams,
|
||||
"team_details": team_details,
|
||||
"attribution_metadata": attribution_metadata,
|
||||
}
|
||||
flow["sso_complete"] = True
|
||||
browser_complete_token = secrets.token_urlsafe(32)
|
||||
flow["browser_complete_token_hash"] = _hash_cli_sso_secret(browser_complete_token)
|
||||
_set_cli_sso_flow(login_id=key, cache=user_api_key_cache, flow=flow)
|
||||
|
||||
verbose_proxy_logger.info(
|
||||
f"Stored CLI SSO session for user: {user_info.user_id}, teams: {teams}, num_teams: {len(teams)}"
|
||||
)
|
||||
verify_url = get_custom_url(
|
||||
request_base_url=str(request.base_url),
|
||||
route=f"sso/cli/complete/{key}",
|
||||
)
|
||||
return HTMLResponse(
|
||||
content=_render_cli_sso_verification_page(
|
||||
verify_url=verify_url,
|
||||
browser_complete_token=browser_complete_token,
|
||||
),
|
||||
status_code=200,
|
||||
)
|
||||
|
||||
|
||||
async def cli_sso_callback(
|
||||
request: Request,
|
||||
key: Optional[str] = None,
|
||||
result: Optional[Union[OpenID, dict]] = None,
|
||||
received_response: Optional[dict] = None,
|
||||
):
|
||||
"""CLI SSO callback - stores session info for JWT generation on polling"""
|
||||
verbose_proxy_logger.info("CLI SSO callback")
|
||||
|
||||
from litellm.proxy.proxy_server import (
|
||||
general_settings,
|
||||
prisma_client,
|
||||
proxy_logging_obj,
|
||||
user_api_key_cache,
|
||||
@ -1722,92 +2105,40 @@ async def cli_sso_callback(
|
||||
# After None check, cast to non-None type for type checker
|
||||
result_non_none: Union[OpenID, dict] = cast(Union[OpenID, dict], result)
|
||||
|
||||
parsed_openid_result = SSOAuthenticationHandler._get_user_email_and_id_from_result(
|
||||
result=result_non_none
|
||||
)
|
||||
verbose_proxy_logger.debug(f"parsed_openid_result: {parsed_openid_result}")
|
||||
|
||||
try:
|
||||
# Get full user info from DB
|
||||
user_info = await get_user_info_from_db(
|
||||
parsed_openid_result = (
|
||||
SSOAuthenticationHandler._get_user_email_and_id_from_result(
|
||||
result=result_non_none,
|
||||
generic_client_id=os.getenv("GENERIC_CLIENT_ID", None),
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(f"parsed_openid_result: {parsed_openid_result}")
|
||||
user_defined_values = await _build_cli_sso_user_defined_values(
|
||||
result=result_non_none,
|
||||
parsed_openid_result=parsed_openid_result,
|
||||
)
|
||||
|
||||
SSOAuthenticationHandler.verify_user_in_restricted_sso_group(
|
||||
general_settings=general_settings,
|
||||
result=result_non_none,
|
||||
received_response=received_response,
|
||||
)
|
||||
|
||||
return await _complete_cli_sso_callback_session(
|
||||
request=request,
|
||||
key=cast(str, key),
|
||||
flow=flow,
|
||||
result=result_non_none,
|
||||
parsed_openid_result=parsed_openid_result,
|
||||
user_defined_values=user_defined_values,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
user_email=parsed_openid_result.get("user_email"),
|
||||
user_defined_values=None,
|
||||
alternate_user_id=parsed_openid_result.get("user_id"),
|
||||
)
|
||||
|
||||
if user_info is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to retrieve user information from SSO"
|
||||
)
|
||||
|
||||
# Get all teams from user_info - CLI will let user select which one
|
||||
teams: List[str] = []
|
||||
if hasattr(user_info, "teams") and user_info.teams:
|
||||
teams = user_info.teams if isinstance(user_info.teams, list) else []
|
||||
|
||||
# Also fetch team aliases for a better CLI UX. We keep the original
|
||||
# "teams" list of IDs for backwards compatibility and add an
|
||||
# optional "team_details" field containing objects with both
|
||||
# team_id and team_alias.
|
||||
team_details: List[Dict[str, Any]] = []
|
||||
try:
|
||||
if teams:
|
||||
prisma_teams = await prisma_client.db.litellm_teamtable.find_many(
|
||||
where={"team_id": {"in": teams}}
|
||||
)
|
||||
for team_row in prisma_teams:
|
||||
team_dict = team_row.model_dump()
|
||||
team_details.append(
|
||||
{
|
||||
"team_id": team_dict.get("team_id"),
|
||||
"team_alias": team_dict.get("team_alias"),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
# If anything goes wrong here, fall back gracefully without
|
||||
# impacting the SSO flow.
|
||||
verbose_proxy_logger.error(
|
||||
f"Error fetching team details for CLI SSO session: {e}"
|
||||
)
|
||||
|
||||
session_data = {
|
||||
"user_id": user_info.user_id,
|
||||
"user_role": user_info.user_role,
|
||||
"models": user_info.models if hasattr(user_info, "models") else [],
|
||||
"user_email": parsed_openid_result.get("user_email"),
|
||||
"teams": teams,
|
||||
# Optional rich metadata for clients that want nicer display
|
||||
"team_details": team_details,
|
||||
}
|
||||
|
||||
flow["session_data"] = session_data
|
||||
flow["sso_complete"] = True
|
||||
browser_complete_token = secrets.token_urlsafe(32)
|
||||
flow["browser_complete_token_hash"] = _hash_cli_sso_secret(
|
||||
browser_complete_token
|
||||
)
|
||||
_set_cli_sso_flow(login_id=cast(str, key), cache=user_api_key_cache, flow=flow)
|
||||
|
||||
verbose_proxy_logger.info(
|
||||
f"Stored CLI SSO session for user: {user_info.user_id}, teams: {teams}, num_teams: {len(teams)}"
|
||||
)
|
||||
|
||||
from fastapi.responses import HTMLResponse
|
||||
|
||||
verify_url = get_custom_url(
|
||||
request_base_url=str(request.base_url),
|
||||
route=f"sso/cli/complete/{key}",
|
||||
)
|
||||
html_content = _render_cli_sso_verification_page(
|
||||
verify_url=verify_url,
|
||||
browser_complete_token=browser_complete_token,
|
||||
)
|
||||
return HTMLResponse(content=html_content, status_code=200)
|
||||
|
||||
except ProxyException:
|
||||
raise
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(f"Error with CLI SSO callback: {e}")
|
||||
raise HTTPException(
|
||||
@ -1873,13 +2204,19 @@ async def cli_poll_key(
|
||||
team_details_response = [
|
||||
{"team_id": t, "team_alias": None} for t in user_teams
|
||||
]
|
||||
return {
|
||||
poll_response: Dict[str, Any] = {
|
||||
"status": "ready",
|
||||
"user_id": user_id,
|
||||
"teams": user_teams,
|
||||
"team_details": team_details_response,
|
||||
"requires_team_selection": True,
|
||||
}
|
||||
attribution_metadata = _cli_poll_attribution_metadata_from_session(
|
||||
session_data
|
||||
)
|
||||
if attribution_metadata:
|
||||
poll_response["attribution_metadata"] = attribution_metadata
|
||||
return poll_response
|
||||
|
||||
# Validate team_id if provided
|
||||
if team_id is not None:
|
||||
@ -1912,7 +2249,7 @@ async def cli_poll_key(
|
||||
verbose_proxy_logger.info(
|
||||
f"CLI JWT generated for user: {user_id}, team: {team_id}"
|
||||
)
|
||||
return {
|
||||
poll_response = {
|
||||
"status": "ready",
|
||||
"key": jwt_token,
|
||||
"user_id": user_id,
|
||||
@ -1922,6 +2259,12 @@ async def cli_poll_key(
|
||||
# present nicer information if needed.
|
||||
"team_details": user_team_details,
|
||||
}
|
||||
attribution_metadata = _cli_poll_attribution_metadata_from_session(
|
||||
session_data
|
||||
)
|
||||
if attribution_metadata:
|
||||
poll_response["attribution_metadata"] = attribution_metadata
|
||||
return poll_response
|
||||
else:
|
||||
return {"status": "pending"}
|
||||
|
||||
|
||||
@ -2438,6 +2438,7 @@ class TestCLIKeyRegenerationFlow:
|
||||
request=mock_request,
|
||||
key="cli-new-session-key-456",
|
||||
result=mock_result,
|
||||
received_response=None,
|
||||
)
|
||||
|
||||
def test_get_redirect_url_does_not_include_existing_key_in_url(self):
|
||||
@ -5552,6 +5553,289 @@ def test_generic_response_convertor_extra_attributes_missing_field(monkeypatch):
|
||||
assert result.extra_fields["another_missing"] is None
|
||||
|
||||
|
||||
class TestCliSsoAttributionMetadata:
|
||||
"""CLI SSO allowlisted OIDC claim persistence and poll exposure."""
|
||||
|
||||
def test_parse_cli_sso_claim_map(self, monkeypatch):
|
||||
from litellm.proxy.management_endpoints import ui_sso
|
||||
|
||||
monkeypatch.setattr(
|
||||
ui_sso,
|
||||
"CLI_SSO_CLAIM_MAP",
|
||||
"employment_type->metadata.acme_employment_type, org_info.department -> department",
|
||||
)
|
||||
assert ui_sso._parse_cli_sso_claim_map() == [
|
||||
("employment_type", "acme_employment_type"),
|
||||
("org_info.department", "department"),
|
||||
]
|
||||
|
||||
def test_build_cli_sso_attribution_metadata_filters_non_scalars(self, monkeypatch):
|
||||
from litellm.proxy.management_endpoints import ui_sso
|
||||
from litellm.proxy.management_endpoints.types import CustomOpenID
|
||||
|
||||
monkeypatch.setattr(
|
||||
ui_sso,
|
||||
"CLI_SSO_CLAIM_MAP",
|
||||
"employment_type->acme_employment_type,access_token->should_drop,group->groups",
|
||||
)
|
||||
|
||||
result = CustomOpenID(
|
||||
id="user-1",
|
||||
email="user@example.com",
|
||||
display_name="User",
|
||||
provider="generic",
|
||||
team_ids=[],
|
||||
extra_fields={
|
||||
"employment_type": "full_time",
|
||||
"access_token": "eyJhbGciOiJIUzI1NiJ9.payload.signature",
|
||||
"group": ["team-a", "team-b"],
|
||||
},
|
||||
)
|
||||
|
||||
metadata = ui_sso.build_cli_sso_attribution_metadata(result=result)
|
||||
assert metadata == {"acme_employment_type": "full_time"}
|
||||
|
||||
def test_build_cli_sso_attribution_metadata_from_oidc_dict(self, monkeypatch):
|
||||
from litellm.proxy.management_endpoints import ui_sso
|
||||
|
||||
monkeypatch.setattr(
|
||||
ui_sso,
|
||||
"CLI_SSO_CLAIM_MAP",
|
||||
"org_info.department->department",
|
||||
)
|
||||
|
||||
metadata = ui_sso.build_cli_sso_attribution_metadata(
|
||||
result={
|
||||
"sub": "user-1",
|
||||
"email": "user@example.com",
|
||||
"org_info": {"department": "Engineering"},
|
||||
}
|
||||
)
|
||||
assert metadata == {"department": "Engineering"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cli_sso_callback_passes_user_defined_values_for_new_users(self):
|
||||
"""First CLI SSO login must supply SSOUserDefinedValues so upsert can create the user."""
|
||||
from litellm.proxy._types import LiteLLM_UserTable
|
||||
from litellm.proxy.management_endpoints import ui_sso
|
||||
from litellm.proxy.management_endpoints.types import CustomOpenID
|
||||
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.base_url = "http://internal-proxy.local/"
|
||||
session_key = "cli-session-new-user"
|
||||
mock_user_info = LiteLLM_UserTable(
|
||||
user_id="cli-test-user",
|
||||
user_role="internal_user",
|
||||
teams=[],
|
||||
models=[],
|
||||
)
|
||||
mock_sso_result = CustomOpenID(
|
||||
id="cli-test-user",
|
||||
email="cli-test@example.com",
|
||||
display_name="cli-test-user",
|
||||
provider="generic",
|
||||
team_ids=[],
|
||||
)
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get_cache.return_value = {
|
||||
"poll_secret_hash": "poll-secret-hash",
|
||||
"user_code_hash": "user-code-hash",
|
||||
"sso_complete": False,
|
||||
"user_code_verified": False,
|
||||
"session_data": None,
|
||||
}
|
||||
get_user_info_mock = AsyncMock(return_value=mock_user_info)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"litellm.proxy.management_endpoints.ui_sso.get_user_info_from_db",
|
||||
get_user_info_mock,
|
||||
),
|
||||
patch("litellm.proxy.proxy_server.prisma_client", MagicMock()),
|
||||
patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache),
|
||||
patch("litellm.proxy.proxy_server.user_custom_sso", None),
|
||||
):
|
||||
await ui_sso.cli_sso_callback(
|
||||
request=mock_request,
|
||||
key=session_key,
|
||||
result=mock_sso_result,
|
||||
)
|
||||
|
||||
get_user_info_mock.assert_awaited_once()
|
||||
assert get_user_info_mock.call_args.kwargs["user_defined_values"] is not None
|
||||
assert (
|
||||
get_user_info_mock.call_args.kwargs["user_defined_values"]["user_id"]
|
||||
== "cli-test-user"
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cli_sso_callback_rejects_restricted_sso_group(self):
|
||||
"""CLI SSO must enforce restricted_sso_group before upserting the user."""
|
||||
from litellm.proxy._types import ProxyException
|
||||
from litellm.proxy.management_endpoints import ui_sso
|
||||
from litellm.proxy.management_endpoints.types import CustomOpenID
|
||||
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.base_url = "http://internal-proxy.local/"
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get_cache.return_value = {
|
||||
"poll_secret_hash": "poll-secret-hash",
|
||||
"user_code_hash": "user-code-hash",
|
||||
"sso_complete": False,
|
||||
"user_code_verified": False,
|
||||
"session_data": None,
|
||||
}
|
||||
mock_sso_result = CustomOpenID(
|
||||
id="cli-test-user",
|
||||
email="cli-test@example.com",
|
||||
display_name="cli-test-user",
|
||||
provider="generic",
|
||||
team_ids=["other-group"],
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"litellm.proxy.management_endpoints.ui_sso.get_user_info_from_db",
|
||||
new=AsyncMock(),
|
||||
) as get_user_info_mock,
|
||||
patch("litellm.proxy.proxy_server.prisma_client", MagicMock()),
|
||||
patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache),
|
||||
patch("litellm.proxy.proxy_server.user_custom_sso", None),
|
||||
patch(
|
||||
"litellm.proxy.proxy_server.general_settings",
|
||||
{
|
||||
"ui_access_mode": {
|
||||
"type": "restricted_sso_group",
|
||||
"restricted_sso_group": "required-group",
|
||||
}
|
||||
},
|
||||
),
|
||||
):
|
||||
with pytest.raises(ProxyException):
|
||||
await ui_sso.cli_sso_callback(
|
||||
request=mock_request,
|
||||
key="cli-session-restricted",
|
||||
result=mock_sso_result,
|
||||
received_response={"groups": ["other-group"]},
|
||||
)
|
||||
|
||||
get_user_info_mock.assert_not_awaited()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cli_sso_callback_persists_attribution_metadata(self, monkeypatch):
|
||||
from litellm.proxy._types import LiteLLM_UserTable
|
||||
from litellm.proxy.management_endpoints import ui_sso
|
||||
|
||||
monkeypatch.setattr(
|
||||
ui_sso,
|
||||
"CLI_SSO_CLAIM_MAP",
|
||||
"employment_type->acme_employment_type",
|
||||
)
|
||||
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.base_url = "http://internal-proxy.local/"
|
||||
session_key = "cli-session-4567890"
|
||||
mock_user_info = LiteLLM_UserTable(
|
||||
user_id="test-user-123",
|
||||
user_role="internal_user",
|
||||
teams=["team1"],
|
||||
models=["gpt-4"],
|
||||
)
|
||||
mock_sso_result = {
|
||||
"user_email": "test@example.com",
|
||||
"user_id": "test-user-123",
|
||||
"employment_type": "contractor",
|
||||
}
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get_cache.return_value = {
|
||||
"poll_secret_hash": "poll-secret-hash",
|
||||
"user_code_hash": "user-code-hash",
|
||||
"sso_complete": False,
|
||||
"user_code_verified": False,
|
||||
"session_data": None,
|
||||
}
|
||||
mock_prisma = MagicMock()
|
||||
mock_prisma.db.litellm_usertable.find_unique = AsyncMock(
|
||||
return_value=MagicMock(metadata={"auth_provider": "generic"})
|
||||
)
|
||||
mock_prisma.db.litellm_usertable.update_many = AsyncMock()
|
||||
|
||||
with (
|
||||
patch.dict(
|
||||
os.environ,
|
||||
{
|
||||
"PROXY_BASE_URL": "https://test.example.com",
|
||||
"SERVER_ROOT_PATH": "",
|
||||
},
|
||||
),
|
||||
patch(
|
||||
"litellm.proxy.management_endpoints.ui_sso.get_user_info_from_db",
|
||||
return_value=mock_user_info,
|
||||
),
|
||||
patch("litellm.proxy.proxy_server.prisma_client", mock_prisma),
|
||||
patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache),
|
||||
patch("litellm.proxy.proxy_server.user_custom_sso", None),
|
||||
patch(
|
||||
"litellm.proxy.common_utils.html_forms.cli_sso_success.render_cli_sso_success_page",
|
||||
return_value="<html>Success</html>",
|
||||
),
|
||||
):
|
||||
await ui_sso.cli_sso_callback(
|
||||
request=mock_request,
|
||||
key=session_key,
|
||||
result=mock_sso_result,
|
||||
)
|
||||
|
||||
flow_data = mock_cache.set_cache.call_args.kwargs["value"]
|
||||
assert flow_data["session_data"]["attribution_metadata"] == {
|
||||
"acme_employment_type": "contractor"
|
||||
}
|
||||
mock_prisma.db.litellm_usertable.update_many.assert_awaited_once()
|
||||
update_data = mock_prisma.db.litellm_usertable.update_many.call_args.kwargs[
|
||||
"data"
|
||||
]
|
||||
assert update_data["metadata"]["acme_employment_type"] == "contractor"
|
||||
assert update_data["metadata"]["auth_provider"] == "generic"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cli_poll_key_returns_attribution_metadata(self, monkeypatch):
|
||||
from litellm.proxy.management_endpoints.ui_sso import (
|
||||
_hash_cli_sso_secret,
|
||||
cli_poll_key,
|
||||
)
|
||||
|
||||
session_key = "cli-session-789123"
|
||||
session_data = {
|
||||
"user_id": "test-user-456",
|
||||
"user_role": "internal_user",
|
||||
"teams": ["team-a", "team-b"],
|
||||
"models": ["gpt-4"],
|
||||
"attribution_metadata": {
|
||||
"acme_employment_type": "full_time",
|
||||
"org": {"cost_center": "CC-42"},
|
||||
},
|
||||
}
|
||||
mock_cache = MagicMock()
|
||||
mock_cache.get_cache.return_value = {
|
||||
"poll_secret_hash": _hash_cli_sso_secret("poll-secret"),
|
||||
"sso_complete": True,
|
||||
"user_code_verified": True,
|
||||
"session_data": session_data,
|
||||
}
|
||||
|
||||
with patch("litellm.proxy.proxy_server.user_api_key_cache", mock_cache):
|
||||
result = await cli_poll_key(
|
||||
key_id=session_key,
|
||||
team_id=None,
|
||||
x_litellm_cli_poll_secret="poll-secret",
|
||||
)
|
||||
|
||||
assert result["attribution_metadata"] == {
|
||||
"acme_employment_type": "full_time",
|
||||
"org.cost_center": "CC-42",
|
||||
}
|
||||
|
||||
|
||||
class TestValidateReturnTo:
|
||||
"""Tests for SSOAuthenticationHandler._validate_return_to"""
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user