diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260513120000_add_delegate_auth_to_upstream_to_mcp_servers/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260513120000_add_delegate_auth_to_upstream_to_mcp_servers/migration.sql new file mode 100644 index 0000000000..50a4874390 --- /dev/null +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260513120000_add_delegate_auth_to_upstream_to_mcp_servers/migration.sql @@ -0,0 +1,2 @@ +-- AlterTable +ALTER TABLE "LiteLLM_MCPServerTable" ADD COLUMN IF NOT EXISTS "delegate_auth_to_upstream" BOOLEAN NOT NULL DEFAULT false; diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index 84ce99557e..b53507abe6 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -323,6 +323,7 @@ model LiteLLM_MCPServerTable { registration_url String? allow_all_keys Boolean @default(false) available_on_public_internet Boolean @default(true) + delegate_auth_to_upstream Boolean @default(false) is_byok Boolean @default(false) byok_description String[] @default([]) byok_api_key_help_url String? diff --git a/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py b/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py index a05af66118..c87e8c414c 100644 --- a/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py +++ b/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py @@ -1,3 +1,4 @@ +import re from typing import Dict, List, Optional, Set, Tuple, cast from fastapi import HTTPException @@ -122,6 +123,24 @@ class MCPRequestHandler: # cannot be smuggled via query string, hostname, or a deeper URL segment. if request.url.path.startswith("/.well-known/"): validated_user_api_key_auth = UserAPIKeyAuth() + elif ( + not litellm_api_key + and MCPRequestHandler._target_servers_delegate_auth_to_upstream( # noqa: E501 + path=request.url.path, mcp_servers=mcp_servers + ) + ): + # Operator opted this oauth2 server into upstream-delegated auth + # (PKCE passthrough): skip LiteLLM API-key/SSO entirely so the + # client authenticates directly with the upstream MCP server. + # Fires ONLY when neither x-litellm-api-key nor Authorization is + # present. If any LiteLLM key is supplied (primary or secondary + # header), we fall through so user_id is resolved, spend/rate + # limiting apply, and any stored OAuth token can be retrieved + # and forwarded upstream. Gated by + # _target_servers_delegate_auth_to_upstream, which only returns + # True when EVERY target is auth_type=oauth2 AND has the + # delegate_auth_to_upstream flag set — fails closed otherwise. + validated_user_api_key_auth = UserAPIKeyAuth() elif has_explicit_litellm_key: # Explicit x-litellm-api-key provided - always validate normally validated_user_api_key_auth = await user_api_key_auth( @@ -181,23 +200,62 @@ class MCPRequestHandler: @staticmethod def _extract_target_server_names_from_path(path: str) -> List[str]: """ - Extract the target MCP server name from the standard MCP transport - URL patterns: ``/mcp/{server_name}[/...]`` and + Extract the target MCP server name(s) from the standard MCP transport + URL patterns: ``/mcp/{server_name_or_csv}[/...]`` and ``/{server_name}/mcp[/...]``. Returns ``[]`` for any other path so callers fail closed when the target cannot be resolved. + Mirrors the regex-based parser in ``server.py::_get_mcp_servers_in_path`` + so the names used for auth gating match the names used for downstream + filtering. Without this alignment, an attacker could craft + ``/mcp//`` so that auth treats the request + as targeting the delegate server (bypassing LiteLLM auth) while + downstream filtering sees a different (non-existent) target and falls + back to the caller's full allowed-server set. + REST/admin endpoints, OAuth2 server endpoints (``/{server_name}/authorize``, ``/token`` etc.), and ``.well-known`` discovery routes intentionally fall through — those flows do not need OAuth2 token passthrough. Clients aggregating multiple servers should - use ``x-mcp-servers``, which takes precedence over path parsing. + use ``x-mcp-servers`` on a path that does not encode a target. """ + # ``/{server_name}/mcp[/...]`` form — single server. The literal + # ``mcp`` must be the second segment (not the first, which would be + # the ``/mcp/...`` form handled below). This branch must stay in sync + # with ``server.py::_get_mcp_servers_in_path``, which also accepts the + # un-rewritten form (some entry points may skip the + # ``dynamic_mcp_route`` rewrite). segments = [s for s in path.split("/") if s] - if len(segments) >= 2 and segments[0] == "mcp": - return [segments[1]] - if len(segments) >= 2 and segments[1] == "mcp": + if len(segments) >= 2 and segments[1] == "mcp" and segments[0] != "mcp": return [segments[0]] - return [] + + # ``/mcp/...`` form — server name(s) may contain a slash (e.g. + # ``custom_solutions/user_123``) and may be a comma-separated list. + # Use the same parsing logic as ``_get_mcp_servers_in_path`` so the + # parsed names match downstream routing. + mcp_path_match = re.match(r"^/mcp/([^?#]+)(?:\?.*)?(?:#.*)?$", path) + if not mcp_path_match: + return [] + servers_and_path = mcp_path_match.group(1) + if not servers_and_path: + return [] + + if "," in servers_and_path: + # Comma-separated servers, possibly followed by a trailing path. + path_match = re.search(r"/([^/,]+(?:/[^/,]+)*)$", servers_and_path) + if path_match: + servers_part = servers_and_path[: -(len(path_match.group(1)) + 1)] + else: + servers_part = servers_and_path + return [s.strip() for s in servers_part.split(",") if s.strip()] + + # Single-server case — server name may contain at most one slash. + single_server_match = re.match( + r"^([^/]+(?:/[^/]+)?)(?:/.*)?$", servers_and_path + ) + if single_server_match: + return [single_server_match.group(1)] + return [servers_and_path] @staticmethod def _target_servers_use_oauth2(path: str, mcp_servers: Optional[List[str]]) -> bool: @@ -217,13 +275,13 @@ class MCPRequestHandler: ) from litellm.types.mcp import MCPAuth - # Use the x-mcp-servers header verbatim when present (including the - # explicitly-empty list, which means "no targets" → fail closed). - # Only fall back to path parsing when the header was absent entirely. - target_names = ( - mcp_servers - if mcp_servers is not None - else MCPRequestHandler._extract_target_server_names_from_path(path) + # Resolve the same target list downstream routing will use. For + # ``/mcp/...`` routes, ``extract_mcp_auth_context`` overrides the + # ``x-mcp-servers`` header with path-derived names, so we must mirror + # that here — otherwise a caller could set the header to a permissive + # server while the path targets a stricter one (header/path TOCTOU). + target_names = MCPRequestHandler._resolve_target_server_names( + path=path, mcp_servers_header=mcp_servers ) if not target_names: return False @@ -234,6 +292,78 @@ class MCPRequestHandler: return False return True + @staticmethod + def _target_servers_delegate_auth_to_upstream( + path: str, mcp_servers: Optional[List[str]] + ) -> bool: + """ + True only when EVERY MCP server the request targets is configured for + ``auth_type == oauth2`` AND has ``delegate_auth_to_upstream=True``. + Fails closed when any target does not opt in or cannot be resolved. + + Used by :meth:`process_mcp_request` to skip LiteLLM API-key/SSO auth + entirely (PKCE passthrough) so the client authenticates directly with + the upstream MCP server. Mixed-target requests (e.g. one delegated + + one non-delegated server) fall back to normal LiteLLM auth. + """ + # Inline imports avoid a circular dependency: mcp_server_manager imports + # from this module. + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + from litellm.types.mcp import MCPAuth + + # See _target_servers_use_oauth2: must mirror the downstream + # header-vs-path override or an attacker could set + # ``x-mcp-servers`` to a delegate-enabled server while the URL path + # targets a non-delegate server, skipping LiteLLM auth for it. + target_names = MCPRequestHandler._resolve_target_server_names( + path=path, mcp_servers_header=mcp_servers + ) + if not target_names: + return False + + for name in target_names: + server = global_mcp_server_manager.get_mcp_server_by_name(name) + if server is None or server.auth_type != MCPAuth.oauth2: + return False + # `is True` is intentional: opt-in must be an explicit boolean + # True. A MagicMock attribute (in tests) or any other truthy + # non-bool must not silently enable the bypass. + if getattr(server, "delegate_auth_to_upstream", False) is not True: + return False + if not getattr(server, "available_on_public_internet", True): + return False + # Never delegate for M2M (client_credentials) servers: LiteLLM + # fetches the upstream token automatically using stored credentials, + # so allowing anonymous bypass would let any external caller invoke + # tools authenticated as LiteLLM's service account. + if server.has_client_credentials: + return False + return True + + @staticmethod + def _resolve_target_server_names( + path: str, mcp_servers_header: Optional[List[str]] + ) -> List[str]: + """ + Resolve the target MCP server names exactly as downstream routing + does (``server.py::extract_mcp_auth_context``). + + For ``/mcp/...`` paths, downstream routing **overrides** any + ``x-mcp-servers`` header value with the path-derived names. Mirror + that here so an attacker cannot use a permissive header value to + flip an auth gate while the path targets a stricter server + (header/path TOCTOU). For non-``/mcp/...`` paths (where the path + does not encode targets), fall back to the header. + """ + path_targets = MCPRequestHandler._extract_target_server_names_from_path(path) + if path_targets: + return path_targets + # Path did not resolve to /mcp/... targets — trust the header + # (including an explicitly empty list, which means "no targets"). + return mcp_servers_header if mcp_servers_header is not None else [] + @staticmethod def _get_mcp_auth_header_from_headers(headers: Headers) -> Optional[str]: """ diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 4901bc76d2..31ed0918f3 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -402,6 +402,9 @@ class MCPServerManager: available_on_public_internet=bool( server_config.get("available_on_public_internet", True) ), + delegate_auth_to_upstream=bool( + server_config.get("delegate_auth_to_upstream", False) + ), # AWS SigV4 fields aws_access_key_id=server_config.get("aws_access_key_id", None), aws_secret_access_key=server_config.get("aws_secret_access_key", None), @@ -796,6 +799,9 @@ class MCPServerManager: available_on_public_internet=bool( getattr(mcp_server, "available_on_public_internet", True) ), + delegate_auth_to_upstream=bool( + getattr(mcp_server, "delegate_auth_to_upstream", False) + ), created_at=getattr(mcp_server, "created_at", None), updated_at=getattr(mcp_server, "updated_at", None), tool_name_to_display_name=_deserialize_json_dict( @@ -967,6 +973,34 @@ class MCPServerManager: if not in_toolset_scope: combined_servers.update(allow_all_server_ids) + # For anonymous callers (no user_id, no role), also surface any + # servers the operator has opted into upstream-delegated auth. + # These servers handle their own auth at the upstream level, so + # LiteLLM granting access here does not bypass any security gate. + is_anonymous = not ( + user_api_key_auth + and ( + getattr(user_api_key_auth, "user_id", None) + or getattr(user_api_key_auth, "user_role", None) + or getattr(user_api_key_auth, "api_key", None) + ) + ) + if is_anonymous: + delegate_server_ids = [ + server.server_id + for server in self.get_registry().values() + if getattr(server, "auth_type", None) == MCPAuth.oauth2 + and getattr(server, "delegate_auth_to_upstream", False) is True + # M2M servers must not be exposed anonymously: an + # unauthenticated caller would get LiteLLM to proxy tool + # calls using its stored client_credentials. + and not server.has_client_credentials + # Internal-only servers must not be reachable from public + # internet callers who happen to carry an upstream token. + and getattr(server, "available_on_public_internet", True) + ] + combined_servers.update(delegate_server_ids) + if len(combined_servers) == 0: verbose_logger.debug( "No allowed MCP Servers found for user api key auth." diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 685232d049..0a74a92f9c 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -1337,8 +1337,24 @@ if MCP_AVAILABLE: raw_headers=raw_headers, ) - # If no OAuth2 token came from request headers, fall back to pre-fetched creds - if extra_headers is None and server.auth_type == MCPAuth.oauth2: + # Prefer server-stored per-user OAuth when configured, so a stale + # Authorization header from the MCP client cannot override Redis/DB + # (same issue as call_tool in mcp_server_manager: VS Code caches tokens). + if ( + server.auth_type == MCPAuth.oauth2 + and getattr(server, "needs_user_oauth_token", False) + and user_api_key_auth is not None + ): + db_headers = await _get_user_oauth_extra_headers_from_db( + server, + user_api_key_auth, + prefetched_creds=_prefetched_oauth_creds, + ) + if db_headers: + extra_headers = db_headers + + # If still no OAuth2 token, fall back to pre-fetched creds (non-stale-client path) + elif extra_headers is None and server.auth_type == MCPAuth.oauth2: extra_headers = await _get_user_oauth_extra_headers_from_db( server, user_api_key_auth, @@ -2541,6 +2557,10 @@ if MCP_AVAILABLE: import re mcp_servers_from_path: Optional[List[str]] = None + segments = [s for s in path.split("/") if s] + if len(segments) >= 2 and segments[1] == "mcp" and segments[0] != "mcp": + return [segments[0]] + # Match /mcp/ # Where servers can be comma-separated list of server names # Server names can contain slashes (e.g., "custom_solutions/user_123") diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index f513381c86..9a9d27b9b8 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1274,6 +1274,7 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase): oauth2_flow: Optional[Literal["client_credentials", "authorization_code"]] = None allow_all_keys: bool = False available_on_public_internet: bool = True + delegate_auth_to_upstream: bool = False is_byok: bool = False byok_description: List[str] = Field(default_factory=list) byok_api_key_help_url: Optional[str] = None @@ -1356,6 +1357,7 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase): registration_url: Optional[str] = None allow_all_keys: bool = False available_on_public_internet: bool = True + delegate_auth_to_upstream: bool = False is_byok: bool = False byok_description: List[str] = Field(default_factory=list) byok_api_key_help_url: Optional[str] = None @@ -1427,6 +1429,7 @@ class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase): registration_url: Optional[str] = None allow_all_keys: bool = False available_on_public_internet: bool = True + delegate_auth_to_upstream: bool = False is_byok: bool = False byok_description: List[str] = Field(default_factory=list) byok_api_key_help_url: Optional[str] = None diff --git a/litellm/proxy/management_endpoints/mcp_management_endpoints.py b/litellm/proxy/management_endpoints/mcp_management_endpoints.py index f2e64fdde8..587b80d472 100644 --- a/litellm/proxy/management_endpoints/mcp_management_endpoints.py +++ b/litellm/proxy/management_endpoints/mcp_management_endpoints.py @@ -163,7 +163,7 @@ if MCP_AVAILABLE: ) from litellm.proxy.management_endpoints.common_utils import _user_has_admin_view from litellm.proxy.management_helpers.utils import management_endpoint_wrapper - from litellm.types.mcp import MCPCredentials + from litellm.types.mcp import MCPAuth, MCPCredentials from litellm.types.mcp_server.mcp_server_manager import MCPServer @dataclass @@ -1551,6 +1551,57 @@ if MCP_AVAILABLE: except _jwt.InvalidTokenError: pass + # For delegate_auth_to_upstream servers the entire PKCE handshake + # (both /authorize browser redirect and /token authorization_code + # exchange) must work without a LiteLLM session. /authorize is opened + # in a VS Code webview that may have no cookie; /token is a programmatic + # POST from VS Code. PKCE security (code_verifier) guarantees the + # authorization_code exchange cannot be replayed, so anonymous access + # is safe for that grant only. + # + # Importantly, NOT safe for refresh_token grants: ``mcp_token`` will + # forward the request to the upstream issuer with LiteLLM's stored + # ``client_secret`` attached, so any caller holding a refresh token + # issued to this client could mint fresh upstream access tokens through + # us. Require normal LiteLLM auth for those. + if not api_key: + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( # noqa: PLC0415 + global_mcp_server_manager, + ) + + server_id = request.path_params.get("server_id", "") + if server_id: + _s = global_mcp_server_manager.get_mcp_server_by_id(server_id) + if not _s: + _s = global_mcp_server_manager.get_mcp_server_by_name(server_id) + if ( + _s + and getattr(_s, "auth_type", None) == MCPAuth.oauth2 + and getattr(_s, "delegate_auth_to_upstream", False) is True + and getattr(_s, "available_on_public_internet", True) + # M2M servers fetch tokens with stored credentials; never + # expose their /authorize or /token endpoints anonymously. + and not _s.has_client_credentials + ): + # For /token, require PKCE authorization_code; refresh_token + # grants must NOT bypass auth (see comment above). + path_lower = (request.url.path or "").rstrip("/").lower() + if path_lower.endswith("/token"): + body_data = await _read_request_body(request=request) + grant_type = (body_data or {}).get("grant_type", "") + if grant_type != "authorization_code": + # Fall through to normal LiteLLM auth (will 401 if + # no key supplied). + pass + else: + return UserAPIKeyAuth() + else: + # /authorize and other PKCE-flow GETs are safe to + # bypass: PKCE binds the upstream issuer's ``code`` + # to the original ``code_challenge`` so no anonymous + # token can be minted via the redirect alone. + return UserAPIKeyAuth() + request_data = await _read_request_body(request=request) request_data = populate_request_with_path_params( request_data=request_data, request=request diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 84ce99557e..b53507abe6 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -323,6 +323,7 @@ model LiteLLM_MCPServerTable { registration_url String? allow_all_keys Boolean @default(false) available_on_public_internet Boolean @default(true) + delegate_auth_to_upstream Boolean @default(false) is_byok Boolean @default(false) byok_description String[] @default([]) byok_api_key_help_url String? diff --git a/litellm/types/mcp_server/mcp_server_manager.py b/litellm/types/mcp_server/mcp_server_manager.py index 268d064eac..776c7fa67a 100644 --- a/litellm/types/mcp_server/mcp_server_manager.py +++ b/litellm/types/mcp_server/mcp_server_manager.py @@ -68,6 +68,12 @@ class MCPServer(BaseModel): access_groups: Optional[List[str]] = None allow_all_keys: bool = False available_on_public_internet: bool = True + # When True AND auth_type == oauth2, MCP requests targeting this server + # bypass LiteLLM API-key/SSO auth (and the pre-emptive 401) so the client + # completes PKCE directly with the upstream MCP server. Honored only for + # auth_type=oauth2; ignored for any other auth_type. See + # MCPRequestHandler._target_servers_delegate_auth_to_upstream. + delegate_auth_to_upstream: bool = False is_byok: bool = False byok_description: List[str] = [] byok_api_key_help_url: Optional[str] = None diff --git a/schema.prisma b/schema.prisma index 84ce99557e..b53507abe6 100644 --- a/schema.prisma +++ b/schema.prisma @@ -323,6 +323,7 @@ model LiteLLM_MCPServerTable { registration_url String? allow_all_keys Boolean @default(false) available_on_public_internet Boolean @default(true) + delegate_auth_to_upstream Boolean @default(false) is_byok Boolean @default(false) byok_description String[] @default([]) byok_api_key_help_url String? diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py b/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py index 6e0dadcd4d..5bb16a4cd4 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py @@ -1075,6 +1075,727 @@ class TestMCPOAuth2FallbackTargetGating: await MCPRequestHandler.process_mcp_request(scope) +@pytest.mark.asyncio +class TestMCPDelegateAuthToUpstream: + """ + Tests for the ``delegate_auth_to_upstream`` per-server flag. + + When set on an ``auth_type=oauth2`` MCP server, LiteLLM must skip its own + API-key/SSO check entirely so the client completes PKCE directly with the + upstream MCP server. The gate must fail closed for any non-oauth2 server, + any mixed-target request, and any request where the target cannot be + resolved. + """ + + @staticmethod + def _make_server(auth_type, delegate_auth_to_upstream=False): + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + return MCPServer( + server_id="test-server-id", + name="test-server", + transport="http", + auth_type=auth_type, + delegate_auth_to_upstream=delegate_auth_to_upstream, + ) + + async def test_delegate_skips_litellm_auth_with_no_authorization(self): + """ + oauth2 + delegate_auth_to_upstream=True, no Authorization header at + all → anonymous UserAPIKeyAuth and ``user_api_key_auth`` is never + called. + """ + from litellm.types.mcp import MCPAuth + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/delegated_oauth_server", + "headers": [], + } + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + ) as mock_auth, + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.return_value = ( + TestMCPDelegateAuthToUpstream._make_server( + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=True, + ) + ) + (auth_result, *_rest) = await MCPRequestHandler.process_mcp_request(scope) + assert isinstance(auth_result, UserAPIKeyAuth) + mock_auth.assert_not_called() + + async def test_delegate_with_upstream_token_in_authorization_falls_back_to_anonymous( + self, + ): + """ + oauth2 + delegate_auth_to_upstream=True with an upstream OAuth token in + ``Authorization`` (not a LiteLLM key): LiteLLM auth is attempted first + (and fails), then the existing oauth2 fallback returns anonymous so the + bearer is forwarded upstream untouched. The delegate branch itself does + not fire when Authorization is present — that is what protects spend + tracking for callers using Authorization-style LiteLLM keys. + """ + from fastapi import HTTPException + + from litellm.types.mcp import MCPAuth + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/delegated_oauth_server", + "headers": [(b"authorization", b"Bearer upstream-pkce-token")], + } + + async def mock_user_api_key_auth_fails(api_key, request): + raise HTTPException(status_code=401, detail="Invalid API key") + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_fails, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.return_value = ( + TestMCPDelegateAuthToUpstream._make_server( + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=True, + ) + ) + ( + auth_result, + _, + _, + _, + oauth2_headers, + _, + ) = await MCPRequestHandler.process_mcp_request(scope) + assert isinstance(auth_result, UserAPIKeyAuth) + assert oauth2_headers.get("Authorization") == "Bearer upstream-pkce-token" + + async def test_delegate_off_still_requires_litellm_auth(self): + """ + oauth2 server but delegate flag is OFF → existing behaviour: a missing + / invalid LiteLLM key still 401s (no anonymous fast-path). + """ + from fastapi import HTTPException + + from litellm.types.mcp import MCPAuth + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/non_delegated_oauth_server", + "headers": [], + } + + async def mock_user_api_key_auth_fails(api_key, request): + raise HTTPException(status_code=401, detail="Invalid API key") + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_fails, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.return_value = ( + TestMCPDelegateAuthToUpstream._make_server( + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=False, + ) + ) + with pytest.raises(HTTPException) as exc_info: + await MCPRequestHandler.process_mcp_request(scope) + assert exc_info.value.status_code == 401 + + async def test_delegate_ignored_for_non_oauth2_server(self): + """ + Defense in depth: even if an operator turns on delegate_auth_to_upstream + for a non-oauth2 server (api_key, bearer_token, etc.), the gate must + not fire — only oauth2 servers may delegate. + """ + from fastapi import HTTPException + + from litellm.types.mcp import MCPAuth + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/api_key_server", + "headers": [], + } + + async def mock_user_api_key_auth_fails(api_key, request): + raise HTTPException(status_code=401, detail="Invalid API key") + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_fails, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.return_value = ( + TestMCPDelegateAuthToUpstream._make_server( + auth_type=MCPAuth.api_key, + delegate_auth_to_upstream=True, + ) + ) + with pytest.raises(HTTPException) as exc_info: + await MCPRequestHandler.process_mcp_request(scope) + assert exc_info.value.status_code == 401 + + async def test_delegate_mixed_targets_fail_closed(self): + """ + x-mcp-servers can list multiple targets. If ANY of them does not opt in + to delegate_auth_to_upstream, the bypass must NOT fire — otherwise an + attacker could mix one delegated server in to skip auth on the others. + """ + from fastapi import HTTPException + + from litellm.types.mcp import MCPAuth + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": [ + (b"x-mcp-servers", b"delegated_oauth,plain_oauth"), + ], + } + + async def mock_user_api_key_auth_fails(api_key, request): + raise HTTPException(status_code=401, detail="Invalid API key") + + def mock_lookup(name, client_ip=None): + if name == "delegated_oauth": + return TestMCPDelegateAuthToUpstream._make_server( + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=True, + ) + return TestMCPDelegateAuthToUpstream._make_server( + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=False, + ) + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_fails, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.side_effect = mock_lookup + with pytest.raises(HTTPException) as exc_info: + await MCPRequestHandler.process_mcp_request(scope) + assert exc_info.value.status_code == 401 + + async def test_delegate_no_resolvable_target_fail_closed(self): + """ + If the target server cannot be resolved at all (e.g. admin/REST path + that isn't ``/mcp/{name}`` or ``/{name}/mcp``), we cannot prove the + gate's preconditions, so we must fail closed and run normal auth. + """ + from fastapi import HTTPException + + scope = { + "type": "http", + "method": "GET", + "path": "/admin/whatever", + "headers": [], + } + + async def mock_user_api_key_auth_fails(api_key, request): + raise HTTPException(status_code=401, detail="Invalid API key") + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_fails, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.return_value = None + with pytest.raises(HTTPException) as exc_info: + await MCPRequestHandler.process_mcp_request(scope) + assert exc_info.value.status_code == 401 + + async def test_explicit_litellm_key_takes_precedence_over_delegate(self): + """ + When ``x-litellm-api-key`` is present, normal auth runs even for a + delegate server, so ``user_id`` is resolved and any stored upstream + OAuth credentials can be looked up and forwarded. The bypass only + fires when no LiteLLM key is supplied. + """ + from litellm.types.mcp import MCPAuth + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/delegated_oauth_server", + "headers": [(b"x-litellm-api-key", b"Bearer sk-1234")], + } + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + new_callable=AsyncMock, + return_value=UserAPIKeyAuth(user_id="real-user"), + ) as mock_auth, + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.return_value = ( + TestMCPDelegateAuthToUpstream._make_server( + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=True, + ) + ) + (auth_result, *_rest) = await MCPRequestHandler.process_mcp_request(scope) + assert isinstance(auth_result, UserAPIKeyAuth) + assert auth_result.user_id == "real-user" + mock_auth.assert_called_once() + + async def test_litellm_key_via_authorization_header_not_bypassed(self): + """ + Regression: a LiteLLM key sent via the secondary ``Authorization`` header + (e.g. ``Authorization: Bearer sk-...``) must still trigger normal auth + and not be silently swallowed by the delegate bypass — otherwise spend + tracking and rate limiting are skipped for those callers. + """ + from litellm.types.mcp import MCPAuth + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/delegated_oauth_server", + "headers": [(b"authorization", b"Bearer sk-1234")], + } + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + new_callable=AsyncMock, + return_value=UserAPIKeyAuth(user_id="real-user"), + ) as mock_auth, + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.return_value = ( + TestMCPDelegateAuthToUpstream._make_server( + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=True, + ) + ) + (auth_result, *_rest) = await MCPRequestHandler.process_mcp_request(scope) + assert isinstance(auth_result, UserAPIKeyAuth) + assert auth_result.user_id == "real-user" + mock_auth.assert_called_once() + + async def test_delegate_ignored_for_client_credentials_server(self): + """ + oauth2 + delegate_auth_to_upstream=True but oauth2_flow=client_credentials + → bypass must NOT fire; normal LiteLLM auth must be attempted. + + M2M servers fetch the upstream token automatically using stored + credentials, so allowing anonymous bypass would let any external + caller invoke tools as LiteLLM's service account. + """ + from fastapi import HTTPException + + from litellm.types.mcp import MCPAuth + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/m2m_server", + "headers": [], + } + + m2m_server = MCPServer( + server_id="m2m-server-id", + name="m2m_server", + transport="http", + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=True, + oauth2_flow="client_credentials", + ) + + async def mock_auth_raises(*_args, **_kwargs): + raise HTTPException(status_code=401, detail="No key provided") + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_auth_raises, + ) as mock_auth, + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.return_value = m2m_server + # No delegate bypass → normal auth is attempted → 401 raised + with pytest.raises(HTTPException) as exc_info: + await MCPRequestHandler.process_mcp_request(scope) + assert exc_info.value.status_code == 401 + mock_auth.assert_called_once() + + async def test_delegate_ignored_for_non_public_server(self): + """ + Internal-only delegate servers must not bypass LiteLLM auth for + anonymous public callers. + """ + from fastapi import HTTPException + + from litellm.types.mcp import MCPAuth + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/internal_server", + "headers": [], + } + + internal_server = MCPServer( + server_id="internal-server-id", + name="internal_server", + transport="http", + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=True, + available_on_public_internet=False, + ) + + async def mock_auth_raises(*_args, **_kwargs): + raise HTTPException(status_code=401, detail="No key provided") + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_auth_raises, + ) as mock_auth, + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.return_value = internal_server + with pytest.raises(HTTPException) as exc_info: + await MCPRequestHandler.process_mcp_request(scope) + assert exc_info.value.status_code == 401 + mock_auth.assert_called_once() + + async def test_get_allowed_servers_excludes_client_credentials_delegate(self): + """ + get_allowed_mcp_servers must not surface M2M (client_credentials) delegate + servers to anonymous callers even if delegate_auth_to_upstream=True. + """ + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.types.mcp import MCPAuth + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + manager = MCPServerManager() + pkce_server = MCPServer( + server_id="pkce-server", + name="pkce_server", + transport="http", + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=True, + available_on_public_internet=True, + ) + m2m_server = MCPServer( + server_id="m2m-server", + name="m2m_server", + transport="http", + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=True, + oauth2_flow="client_credentials", + available_on_public_internet=True, + ) + manager.registry = { + pkce_server.server_id: pkce_server, + m2m_server.server_id: m2m_server, + } + + with patch.object( + MCPRequestHandler, + "get_allowed_mcp_servers", + new_callable=AsyncMock, + return_value=[], + ): + result = await manager.get_allowed_mcp_servers(None) + + assert "pkce-server" in result + assert "m2m-server" not in result + + async def test_get_allowed_servers_excludes_non_public_delegate(self): + """ + Internal-only (available_on_public_internet=False) delegate servers + must not appear in the anonymous allow-list. + """ + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.types.mcp import MCPAuth + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + manager = MCPServerManager() + public_server = MCPServer( + server_id="public-server", + name="public_server", + transport="http", + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=True, + available_on_public_internet=True, + ) + internal_server = MCPServer( + server_id="internal-server", + name="internal_server", + transport="http", + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=True, + available_on_public_internet=False, + ) + manager.registry = { + public_server.server_id: public_server, + internal_server.server_id: internal_server, + } + + with patch.object( + MCPRequestHandler, + "get_allowed_mcp_servers", + new_callable=AsyncMock, + return_value=[], + ): + result = await manager.get_allowed_mcp_servers(None) + + assert "public-server" in result + assert "internal-server" not in result + + def test_extract_target_server_names_matches_routing_parser(self): + """ + Regression: _extract_target_server_names_from_path must match the + downstream regex parser in server.py::_get_mcp_servers_in_path. + + Previously, a request to ``/mcp//garbage`` was parsed as + targeting ```` by the auth gate (bypassing LiteLLM auth) + while the routing layer parsed it as ``/garbage`` — when + that name did not resolve, the request fell back to the anonymous + allow-list which can include ``allow_all_keys`` servers that normally + require a LiteLLM key. + """ + from litellm.proxy._experimental.mcp_server.server import ( + _get_mcp_servers_in_path, + ) + + cases = [ + # Single server, single segment. + ("/mcp/foo", ["foo"]), + # Server name with one embedded slash (two segments). + ("/mcp/foo/bar", ["foo/bar"]), + # Server name with embedded slash + extra path → name stays at two segments. + ("/mcp/foo/bar/tools", ["foo/bar"]), + # Comma-separated servers, no trailing path. + ("/mcp/foo,bar", ["foo", "bar"]), + # Comma-separated servers with trailing path. + ("/mcp/foo,bar/tools", ["foo", "bar"]), + # Alternative form ``//mcp`` is also parsed (both auth + # parser and routing parser handle it for defense-in-depth — some + # entry points may not be rewritten by ``dynamic_mcp_route``). + ("/foo/mcp", ["foo"]), + ("/foo/mcp/tools", ["foo"]), + # Non-MCP paths → empty (fail closed). + ("/.well-known/oauth-authorization-server", []), + ("/v1/keys", []), + ("/", []), + ] + for path_input, expected in cases: + assert ( + MCPRequestHandler._extract_target_server_names_from_path(path_input) + == expected + ), f"path={path_input!r} → expected {expected!r}" + assert ( + _get_mcp_servers_in_path(path_input) or [] + ) == expected, f"path={path_input!r} → routing expected {expected!r}" + + async def test_delegate_does_not_bypass_on_extra_path_segment(self): + """ + Regression: ``/mcp//`` must NOT bypass auth. + + The bypass key check is now performed against the same parsed target + as downstream routing — ``/`` — which will not + resolve to a delegate-enabled server, so normal LiteLLM auth runs. + """ + from fastapi import HTTPException + + from litellm.types.mcp import MCPAuth + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/delegated_server/extra", + "headers": [], + } + + delegate_server = TestMCPDelegateAuthToUpstream._make_server( + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=True, + ) + + def lookup_by_name(name): + # Only the *exact* delegated name resolves. Anything else (e.g. + # ``delegated_server/extra``) returns None so the bypass fails. + if name == "delegated_server": + return delegate_server + return None + + async def mock_auth_raises(*_args, **_kwargs): + raise HTTPException(status_code=401, detail="No key provided") + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_auth_raises, + ) as mock_auth, + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.side_effect = lookup_by_name + with pytest.raises(HTTPException) as exc_info: + await MCPRequestHandler.process_mcp_request(scope) + assert exc_info.value.status_code == 401 + # Auth was attempted (not bypassed) because the parsed target + # name does not match any registered delegate server. + mock_auth.assert_called_once() + + async def test_delegate_ignores_x_mcp_servers_header_for_mcp_paths(self): + """ + Regression (header/path TOCTOU): For ``/mcp/...`` routes, downstream + routing overrides ``x-mcp-servers`` with the path-derived names. + The auth bypass must do the same — otherwise an attacker could send + ``x-mcp-servers: `` while the URL path targets a + non-delegate server, flipping the auth gate on a server that should + require a LiteLLM key. + """ + from fastapi import HTTPException + + from litellm.types.mcp import MCPAuth + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/non_delegate_server", + "headers": [(b"x-mcp-servers", b"delegated_server")], + } + + delegate_server = MCPServer( + server_id="delegate-id", + name="delegated_server", + transport="http", + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=True, + available_on_public_internet=True, + ) + non_delegate = MCPServer( + server_id="non-delegate-id", + name="non_delegate_server", + transport="http", + auth_type=MCPAuth.api_key, + ) + + def lookup_by_name(name): + return { + "delegated_server": delegate_server, + "non_delegate_server": non_delegate, + }.get(name) + + async def mock_auth_raises(*_args, **_kwargs): + raise HTTPException(status_code=401, detail="No key provided") + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_auth_raises, + ) as mock_auth, + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.side_effect = lookup_by_name + # Bypass MUST NOT fire — path-derived target is the non-delegate + # server. Normal auth runs and 401s. + with pytest.raises(HTTPException) as exc_info: + await MCPRequestHandler.process_mcp_request(scope) + assert exc_info.value.status_code == 401 + mock_auth.assert_called_once() + + async def test_resolve_target_server_names_prefers_path_over_header(self): + """ + ``_resolve_target_server_names`` must: + + - For ``/mcp/`` paths, return the path-derived list and ignore + the header (mirrors downstream routing). + - For non-MCP paths, fall back to the header (including the explicit + empty-list case, which fails closed). + """ + # Path matches /mcp/... — header is ignored. + assert MCPRequestHandler._resolve_target_server_names( + path="/mcp/foo", mcp_servers_header=["evil"] + ) == ["foo"] + assert MCPRequestHandler._resolve_target_server_names( + path="/mcp/foo,bar", mcp_servers_header=["evil"] + ) == ["foo", "bar"] + assert MCPRequestHandler._resolve_target_server_names( + path="/foo/mcp", mcp_servers_header=["evil"] + ) == ["foo"] + # Path does not match — header is trusted. + assert MCPRequestHandler._resolve_target_server_names( + path="/.well-known/oauth-authorization-server", + mcp_servers_header=["foo"], + ) == ["foo"] + # Explicit empty list on a non-MCP path → empty (fail closed). + assert ( + MCPRequestHandler._resolve_target_server_names( + path="/.well-known/oauth-authorization-server", + mcp_servers_header=[], + ) + == [] + ) + # No header on a non-MCP path → empty. + assert ( + MCPRequestHandler._resolve_target_server_names( + path="/.well-known/oauth-authorization-server", + mcp_servers_header=None, + ) + == [] + ) + + class TestMCPCustomHeaderName: """Test suite for custom MCP authentication header name functionality""" diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py index b53420f000..794864f658 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py @@ -2456,6 +2456,51 @@ class TestMCPServerManager: assert "test_server_1" in result assert "test_server_2" in result + @pytest.mark.asyncio + async def test_get_allowed_mcp_servers_anonymous_delegate_requires_oauth2(self): + """Anonymous delegated auth listing should only include oauth2 servers.""" + from litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp import ( + MCPRequestHandler, + ) + + manager = MCPServerManager() + oauth_delegate_server = MCPServer( + server_id="oauth-delegate", + name="oauth_delegate", + transport=MCPTransport.http, + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=True, + ) + api_key_delegate_server = MCPServer( + server_id="api-key-delegate", + name="api_key_delegate", + transport=MCPTransport.http, + auth_type=MCPAuth.api_key, + delegate_auth_to_upstream=True, + ) + oauth_non_delegate_server = MCPServer( + server_id="oauth-non-delegate", + name="oauth_non_delegate", + transport=MCPTransport.http, + auth_type=MCPAuth.oauth2, + delegate_auth_to_upstream=False, + ) + manager.registry = { + oauth_delegate_server.server_id: oauth_delegate_server, + api_key_delegate_server.server_id: api_key_delegate_server, + oauth_non_delegate_server.server_id: oauth_non_delegate_server, + } + + with patch.object( + MCPRequestHandler, + "get_allowed_mcp_servers", + new_callable=AsyncMock, + return_value=[], + ): + result = await manager.get_allowed_mcp_servers(None) + + assert set(result) == {"oauth-delegate"} + def test_get_mcp_server_from_tool_name_uses_server_name_not_name(self): """ Test that _get_mcp_server_from_tool_name uses server.server_name instead of server.name diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_stale_session.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_stale_session.py index a0bfbff422..7b3bb81e04 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_stale_session.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_stale_session.py @@ -497,31 +497,39 @@ async def test_per_user_oauth_missing_stored_token_returns_preemptive_401(): oauth_server.auth_type = MCPAuth.oauth2 oauth_server.needs_user_oauth_token = True - with patch( - "litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context", - new_callable=AsyncMock, - return_value=(user_auth, None, ["repro_oauth_server"], None, None, None), - ), patch( - "litellm.proxy._experimental.mcp_server.server.set_auth_context", - ), patch( - "litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED", - True, - ), patch( - "litellm.proxy._experimental.mcp_server.server._handle_stale_mcp_session", - new_callable=AsyncMock, - return_value=False, - ), patch( - "litellm.proxy._experimental.mcp_server.server._get_user_oauth_extra_headers_from_db", - new_callable=AsyncMock, - return_value=None, - ) as mock_get_stored_token, patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_mcp_server_by_name", - return_value=oauth_server, - ), patch.object( - session_manager, - "handle_request", - new_callable=AsyncMock, - ) as mock_handle_request: + with ( + patch( + "litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context", + new_callable=AsyncMock, + return_value=(user_auth, None, ["repro_oauth_server"], None, None, None), + ), + patch( + "litellm.proxy._experimental.mcp_server.server.set_auth_context", + ), + patch( + "litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED", + True, + ), + patch( + "litellm.proxy._experimental.mcp_server.server._handle_stale_mcp_session", + new_callable=AsyncMock, + return_value=False, + ), + patch( + "litellm.proxy._experimental.mcp_server.server._get_user_oauth_extra_headers_from_db", + new_callable=AsyncMock, + return_value=None, + ) as mock_get_stored_token, + patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_mcp_server_by_name", + return_value=oauth_server, + ), + patch.object( + session_manager, + "handle_request", + new_callable=AsyncMock, + ) as mock_handle_request, + ): with pytest.raises(HTTPException) as exc_info: await handle_streamable_http_mcp(scope, receive, send) @@ -562,32 +570,119 @@ async def test_per_user_oauth_with_stored_token_skips_preemptive_401(): oauth_server.auth_type = MCPAuth.oauth2 oauth_server.needs_user_oauth_token = True - with patch( - "litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context", - new_callable=AsyncMock, - return_value=(user_auth, None, ["repro_oauth_server"], None, None, None), - ), patch( - "litellm.proxy._experimental.mcp_server.server.set_auth_context", - ), patch( - "litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED", - True, - ), patch( - "litellm.proxy._experimental.mcp_server.server._handle_stale_mcp_session", - new_callable=AsyncMock, - return_value=False, - ), patch( - "litellm.proxy._experimental.mcp_server.server._get_user_oauth_extra_headers_from_db", - new_callable=AsyncMock, - return_value={"Authorization": "Bearer cached-token"}, - ) as mock_get_stored_token, patch( - "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_mcp_server_by_name", - return_value=oauth_server, - ), patch.object( - session_manager, - "handle_request", - new_callable=AsyncMock, - ) as mock_handle_request: + with ( + patch( + "litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context", + new_callable=AsyncMock, + return_value=(user_auth, None, ["repro_oauth_server"], None, None, None), + ), + patch( + "litellm.proxy._experimental.mcp_server.server.set_auth_context", + ), + patch( + "litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED", + True, + ), + patch( + "litellm.proxy._experimental.mcp_server.server._handle_stale_mcp_session", + new_callable=AsyncMock, + return_value=False, + ), + patch( + "litellm.proxy._experimental.mcp_server.server._get_user_oauth_extra_headers_from_db", + new_callable=AsyncMock, + return_value={"Authorization": "Bearer cached-token"}, + ) as mock_get_stored_token, + patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_mcp_server_by_name", + return_value=oauth_server, + ), + patch.object( + session_manager, + "handle_request", + new_callable=AsyncMock, + ) as mock_handle_request, + ): await handle_streamable_http_mcp(scope, receive, send) assert mock_get_stored_token.await_count == 1 assert mock_handle_request.await_count == 1 + + +@pytest.mark.asyncio +async def test_handle_streamable_http_mcp_emits_401_for_delegated_server_without_token(): + """ + OAuth2 server with ``delegate_auth_to_upstream=True`` and no Authorization + header must still emit a pre-emptive 401 with WWW-Authenticate so the + client kicks off PKCE. The 401 points at LiteLLM's discovery shim, which + in turn delegates to the upstream OAuth issuer. + """ + from fastapi import HTTPException + + try: + from litellm.proxy._experimental.mcp_server.server import ( + handle_streamable_http_mcp, + session_manager, + ) + except ImportError: + pytest.skip("MCP server not available") + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": [ + (b"content-type", b"application/json"), + (b"host", b"litellm.example.com"), + ], + } + receive = AsyncMock() + send = AsyncMock() + user_auth = MagicMock() + user_auth.user_id = None + delegated_server = MagicMock() + delegated_server.auth_type = MCPAuth.oauth2 + delegated_server.delegate_auth_to_upstream = True + delegated_server.needs_user_oauth_token = True + + with ( + patch( + "litellm.proxy._experimental.mcp_server.server.extract_mcp_auth_context", + new_callable=AsyncMock, + return_value=( + user_auth, + None, + ["delegated_oauth_server"], + None, + None, + None, + ), + ), + patch( + "litellm.proxy._experimental.mcp_server.server.set_auth_context", + ), + patch( + "litellm.proxy._experimental.mcp_server.server._SESSION_MANAGERS_INITIALIZED", + True, + ), + patch( + "litellm.proxy._experimental.mcp_server.server._handle_stale_mcp_session", + new_callable=AsyncMock, + return_value=False, + ), + patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager.get_mcp_server_by_name", + return_value=delegated_server, + ), + patch.object( + session_manager, + "handle_request", + new_callable=AsyncMock, + ) as mock_handle_request, + ): + with pytest.raises(HTTPException) as exc_info: + await handle_streamable_http_mcp(scope, receive, send) + + assert exc_info.value.status_code == 401 + assert "www-authenticate" in exc_info.value.headers + assert mock_handle_request.await_count == 0 diff --git a/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py index 30ad84e18b..47e058786a 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py @@ -1645,6 +1645,108 @@ class TestTemporaryMCPSessionEndpoints: _, call_kwargs = auth_builder_mock.call_args assert call_kwargs["api_key"] == "Bearer sk-header-key" + @pytest.mark.asyncio + async def test_mcp_oauth_user_api_key_auth_requires_oauth2_for_delegate_bypass( + self, + ): + """Non-oauth2 servers must not get anonymous access from the delegate flag.""" + from litellm.proxy.management_endpoints.mcp_management_endpoints import ( + _mcp_oauth_user_api_key_auth, + ) + + expected_auth = generate_mock_user_api_key_auth( + user_role=LitellmUserRoles.PROXY_ADMIN + ) + mock_request = MagicMock() + mock_request.headers = {} + mock_request.cookies = {} + mock_request.path_params = {"server_id": "server-1"} + non_oauth_server = MagicMock() + non_oauth_server.auth_type = MCPAuth.api_key + non_oauth_server.delegate_auth_to_upstream = True + mock_manager = MagicMock() + mock_manager.get_mcp_server_by_id.return_value = non_oauth_server + mock_manager.get_mcp_server_by_name.return_value = None + fake_proxy_server = types.SimpleNamespace(master_key=None) + + with ( + patch.dict(sys.modules, {"litellm.proxy.proxy_server": fake_proxy_server}), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager", + mock_manager, + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints._user_api_key_auth_builder", + AsyncMock(return_value=expected_auth), + ) as auth_builder_mock, + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints._read_request_body", + AsyncMock(return_value={}), + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.populate_request_with_path_params", + side_effect=lambda request_data, request: request_data, + ), + ): + result = await _mcp_oauth_user_api_key_auth(mock_request) + + assert result is expected_auth + auth_builder_mock.assert_awaited_once() + _, call_kwargs = auth_builder_mock.call_args + assert call_kwargs["api_key"] == "" + + @pytest.mark.asyncio + async def test_mcp_oauth_user_api_key_auth_requires_public_server_for_delegate_bypass( + self, + ): + """Internal-only delegate servers must still require LiteLLM auth.""" + from litellm.proxy.management_endpoints.mcp_management_endpoints import ( + _mcp_oauth_user_api_key_auth, + ) + + expected_auth = generate_mock_user_api_key_auth( + user_role=LitellmUserRoles.PROXY_ADMIN + ) + mock_request = MagicMock() + mock_request.headers = {} + mock_request.cookies = {} + mock_request.path_params = {"server_id": "server-1"} + internal_server = MagicMock() + internal_server.auth_type = MCPAuth.oauth2 + internal_server.delegate_auth_to_upstream = True + internal_server.available_on_public_internet = False + internal_server.has_client_credentials = False + mock_manager = MagicMock() + mock_manager.get_mcp_server_by_id.return_value = internal_server + mock_manager.get_mcp_server_by_name.return_value = None + fake_proxy_server = types.SimpleNamespace(master_key=None) + + with ( + patch.dict(sys.modules, {"litellm.proxy.proxy_server": fake_proxy_server}), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager", + mock_manager, + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints._user_api_key_auth_builder", + AsyncMock(return_value=expected_auth), + ) as auth_builder_mock, + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints._read_request_body", + AsyncMock(return_value={}), + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.populate_request_with_path_params", + side_effect=lambda request_data, request: request_data, + ), + ): + result = await _mcp_oauth_user_api_key_auth(mock_request) + + assert result is expected_auth + auth_builder_mock.assert_awaited_once() + _, call_kwargs = auth_builder_mock.call_args + assert call_kwargs["api_key"] == "" + @pytest.mark.asyncio async def test_mcp_authorize_proxies_to_discoverable_endpoint(self): from litellm.proxy.management_endpoints.mcp_management_endpoints import ( diff --git a/ui/litellm-dashboard/src/components/mcp_tools/MCPPermissionManagement.tsx b/ui/litellm-dashboard/src/components/mcp_tools/MCPPermissionManagement.tsx index dcb7298c83..8f60e50d2b 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/MCPPermissionManagement.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/MCPPermissionManagement.tsx @@ -1,7 +1,7 @@ import React, { useEffect } from "react"; import { Form, Select, Tooltip, Collapse, Input, Space, Button, Switch } from "antd"; import { InfoCircleOutlined, MinusCircleOutlined, PlusOutlined } from "@ant-design/icons"; -import { MCPServer } from "./types"; +import { MCPServer, AUTH_TYPE } from "./types"; const { Panel } = Collapse; interface MCPPermissionManagementProps { @@ -23,6 +23,8 @@ const MCPPermissionManagement: React.FC = ({ getAccessGroupOptions, }) => { const form = Form.useFormInstance(); + const watchedAuthType = Form.useWatch("auth_type", form); + const isOAuth2 = watchedAuthType === AUTH_TYPE.OAUTH2; // Set initial values when mcpServer changes useEffect(() => { @@ -40,12 +42,25 @@ const MCPPermissionManagement: React.FC = ({ if (typeof mcpServer.available_on_public_internet === "boolean") { form.setFieldValue("available_on_public_internet", mcpServer.available_on_public_internet); } + if (typeof mcpServer.delegate_auth_to_upstream === "boolean") { + form.setFieldValue("delegate_auth_to_upstream", mcpServer.delegate_auth_to_upstream); + } } else { form.setFieldValue("allow_all_keys", false); form.setFieldValue("available_on_public_internet", true); + form.setFieldValue("delegate_auth_to_upstream", false); } }, [mcpServer, form]); + // delegate_auth_to_upstream is only honored server-side when auth_type=oauth2. + // Force it back to false whenever the user switches away from oauth2 so a + // stale toggle value doesn't get persisted with another auth type. + useEffect(() => { + if (!isOAuth2) { + form.setFieldValue("delegate_auth_to_upstream", false); + } + }, [isOAuth2, form]); + return ( = ({ + {isOAuth2 && ( +
+
+ + Delegate auth to upstream (PKCE passthrough) + + + + +

+ Bypass LiteLLM auth so clients authenticate directly with the upstream OAuth MCP server. +

+
+ + + +
+ )} + diff --git a/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.tsx b/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.tsx index 17bcd59c43..f8b0141b25 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.tsx @@ -284,6 +284,7 @@ const CreateMCPServer: React.FC = ({ credentials: credentialValues, allow_all_keys: allowAllKeysRaw, available_on_public_internet: availableOnPublicInternetRaw, + delegate_auth_to_upstream: delegateAuthToUpstreamRaw, token_validation_json: rawTokenValidationJson, ...restValues } = values; @@ -388,6 +389,7 @@ const CreateMCPServer: React.FC = ({ tool_name_to_description: Object.keys(toolNameToDescription).length > 0 ? toolNameToDescription : null, allow_all_keys: Boolean(allowAllKeysRaw), available_on_public_internet: Boolean(availableOnPublicInternetRaw), + delegate_auth_to_upstream: Boolean(delegateAuthToUpstreamRaw), static_headers: staticHeaders, ...(tokenValidation !== null && { token_validation: tokenValidation }), }; diff --git a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.test.tsx b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.test.tsx index 760504d579..1f2864f675 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.test.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.test.tsx @@ -177,6 +177,47 @@ describe("MCPServerEdit (stdio)", () => { }); }); +describe("MCPServerEdit (delegate auth)", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("should clear delegate auth flag when saving a non-oauth2 server", async () => { + vi.mocked(networking.updateMCPServer).mockResolvedValue({ + ...interactiveOAuthServer, + auth_type: "none", + delegate_auth_to_upstream: false, + }); + + render( + , + ); + + const saveButtons = screen.getAllByRole("button", { name: "Save Changes" }); + await act(async () => { + fireEvent.click(saveButtons[0]); + }); + + await waitFor(() => { + expect(networking.updateMCPServer).toHaveBeenCalledTimes(1); + }); + + const [, payload] = vi.mocked(networking.updateMCPServer).mock.calls[0]; + expect(payload.auth_type).toBe("none"); + expect(payload.delegate_auth_to_upstream).toBe(false); + }); +}); + describe("MCPServerEdit (interactive OAuth)", () => { beforeEach(() => { vi.clearAllMocks(); diff --git a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.tsx b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.tsx index 997e76982d..9278d41c3e 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.tsx @@ -384,6 +384,7 @@ const MCPServerEdit: React.FC = ({ args: rawArgs, allow_all_keys: allowAllKeysRaw, available_on_public_internet: availableOnPublicInternetRaw, + delegate_auth_to_upstream: delegateAuthToUpstreamRaw, token_validation_json: rawTokenValidationJson, ...restValues } = values; @@ -552,6 +553,15 @@ const MCPServerEdit: React.FC = ({ static_headers: staticHeaders, allow_all_keys: Boolean(allowAllKeysRaw ?? mcpServer.allow_all_keys), available_on_public_internet: Boolean(availableOnPublicInternetRaw ?? mcpServer.available_on_public_internet), + // ``delegate_auth_to_upstream`` is only honored server-side for + // ``auth_type=oauth2``. The Form.Item is conditionally rendered so the + // value drops out of the form on auth_type change; force false for any + // non-oauth2 server to avoid persisting a stale ``true`` that would + // silently re-activate if auth_type is later switched back to oauth2. + delegate_auth_to_upstream: + restValues.auth_type === AUTH_TYPE.OAUTH2 + ? Boolean(delegateAuthToUpstreamRaw ?? mcpServer.delegate_auth_to_upstream) + : false, // Include token_validation when it is set (non-null) or when clearing an existing value ...(tokenValidation !== null || mcpServer.token_validation ? { token_validation: tokenValidation } diff --git a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_view.tsx b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_view.tsx index f9a3d57e95..1f8f7f68d3 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_view.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_view.tsx @@ -272,6 +272,23 @@ export const MCPServerView: React.FC = ({ )} + {handleAuth(mcpServer.auth_type) === "oauth2" && ( +
+ Delegate Auth to Upstream +
+ {mcpServer.delegate_auth_to_upstream ? ( + + + Enabled (PKCE passthrough) + + ) : ( + + Disabled + + )} +
+
+ )}
Access Groups
diff --git a/ui/litellm-dashboard/src/components/mcp_tools/types.tsx b/ui/litellm-dashboard/src/components/mcp_tools/types.tsx index 814c5d74f4..7cfe08d9ee 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/types.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/types.tsx @@ -202,6 +202,7 @@ export interface MCPServer { tool_name_to_description?: Record; allow_all_keys?: boolean; available_on_public_internet?: boolean; + delegate_auth_to_upstream?: boolean; /** Stdio-only fields (present when transport === 'stdio') */ command?: string | null; diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index 7f28dbe6da..756348f493 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -8800,6 +8800,7 @@ interface ExchangeMcpOAuthTokenParams { clientSecret?: string; codeVerifier: string; redirectUri: string; + accessToken?: string | null; } export const exchangeMcpOAuthToken = async ({ @@ -8809,6 +8810,7 @@ export const exchangeMcpOAuthToken = async ({ clientSecret, codeVerifier, redirectUri, + accessToken, }: ExchangeMcpOAuthTokenParams) => { const base = getProxyBaseUrl(); const normalizedServerId = encodeURIComponent(serverId.trim()); @@ -8826,11 +8828,16 @@ export const exchangeMcpOAuthToken = async ({ body.set("code_verifier", codeVerifier); body.set("redirect_uri", redirectUri); + const headers: Record = { + "Content-Type": "application/x-www-form-urlencoded", + }; + if (accessToken) { + headers["Authorization"] = `Bearer ${accessToken}`; + } + const response = await fetch(url, { method: "POST", - headers: { - "Content-Type": "application/x-www-form-urlencoded", - }, + headers, body: body.toString(), }); diff --git a/ui/litellm-dashboard/src/hooks/useMcpOAuthFlow.tsx b/ui/litellm-dashboard/src/hooks/useMcpOAuthFlow.tsx index 9157fedbe2..7edeade4cb 100644 --- a/ui/litellm-dashboard/src/hooks/useMcpOAuthFlow.tsx +++ b/ui/litellm-dashboard/src/hooks/useMcpOAuthFlow.tsx @@ -279,6 +279,7 @@ export const useMcpOAuthFlow = ({ clientSecret: flowState.clientSecret, codeVerifier: flowState.codeVerifier, redirectUri: flowState.redirectUri, + accessToken, }); onTokenReceived(token); diff --git a/ui/litellm-dashboard/src/hooks/useUserMcpOAuthFlow.tsx b/ui/litellm-dashboard/src/hooks/useUserMcpOAuthFlow.tsx index aa7ce84de5..cf0a81dcad 100644 --- a/ui/litellm-dashboard/src/hooks/useUserMcpOAuthFlow.tsx +++ b/ui/litellm-dashboard/src/hooks/useUserMcpOAuthFlow.tsx @@ -224,6 +224,7 @@ export const useUserMcpOAuthFlow = ({ clientSecret: flowState.clientSecret, codeVerifier: flowState.codeVerifier, redirectUri: flowState.redirectUri, + accessToken, }); // Persist the token for this user via the backend.