diff --git a/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py b/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py index d4799e9f20..756b2ed91d 100644 --- a/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py +++ b/litellm/proxy/_experimental/mcp_server/auth/user_api_key_auth_mcp.py @@ -117,7 +117,10 @@ class MCPRequestHandler: return b"{}" request.body = mock_body # type: ignore - if ".well-known" in str(request.url): # public routes + # Only OAuth metadata routes registered under /.well-known/ are public. + # Match on request.url.path (path-only, exact prefix) so the substring + # cannot be smuggled via query string, hostname, or a deeper URL segment. + if request.url.path.startswith("/.well-known/"): validated_user_api_key_auth = UserAPIKeyAuth() elif has_explicit_litellm_key: # Explicit x-litellm-api-key provided - always validate normally @@ -126,27 +129,37 @@ class MCPRequestHandler: ) elif oauth2_headers: # No x-litellm-api-key, but Authorization header present. - # Could be a LiteLLM key (backward compat) OR an OAuth2 token - # from an upstream MCP provider (e.g. Atlassian). - # Try LiteLLM auth first; on auth failure, treat as OAuth2 passthrough. + # Could be a LiteLLM key (backward compat) OR an opaque OAuth2 token + # the operator wants forwarded to an upstream OAuth2-mode MCP server. + # Try LiteLLM auth first; on auth failure, only fall back to anonymous + # passthrough when the request actually targets a server whose operator + # configured ``auth_type=oauth2``. For any other server (api_key, + # bearer_token, basic, etc.), a failed LiteLLM auth is a real failure + # and must propagate — otherwise an attacker can exchange any garbage + # bearer for an anonymous session. try: validated_user_api_key_auth = await user_api_key_auth( api_key=litellm_api_key, request=request ) - except HTTPException as e: - if e.status_code in (401, 403): + except (HTTPException, ProxyException) as e: + # HTTPException.status_code is int; ProxyException.code is + # normalized to str in its __init__ but can be ``"None"`` or any + # non-numeric string when the caller didn't supply a numeric + # code, so we compare against both int and str forms rather + # than coercing (``int("None")`` would raise ValueError and + # rewrite the auth error as a 500). + status = e.status_code if isinstance(e, HTTPException) else e.code + if status in ( + 401, + 403, + "401", + "403", + ) and MCPRequestHandler._target_servers_use_oauth2( + path=request.url.path, mcp_servers=mcp_servers + ): verbose_logger.debug( - "MCP OAuth2: Authorization header is not a valid LiteLLM key, " - "treating as OAuth2 token passthrough" - ) - validated_user_api_key_auth = UserAPIKeyAuth() - else: - raise - except ProxyException as e: - if str(e.code) in ("401", "403"): - verbose_logger.debug( - "MCP OAuth2: Authorization header is not a valid LiteLLM key, " - "treating as OAuth2 token passthrough" + "MCP OAuth2: target server is OAuth2-mode, treating " + "Authorization as upstream OAuth2 token passthrough" ) validated_user_api_key_auth = UserAPIKeyAuth() else: @@ -165,6 +178,62 @@ class MCPRequestHandler: dict(headers), ) + @staticmethod + def _extract_target_server_names_from_path(path: str) -> List[str]: + """ + Extract the target MCP server name from the standard MCP transport + URL patterns: ``/mcp/{server_name}[/...]`` and + ``/{server_name}/mcp[/...]``. Returns ``[]`` for any other path so + callers fail closed when the target cannot be resolved. + + REST/admin endpoints, OAuth2 server endpoints + (``/{server_name}/authorize``, ``/token`` etc.), and ``.well-known`` + discovery routes intentionally fall through — those flows do not need + OAuth2 token passthrough. Clients aggregating multiple servers should + use ``x-mcp-servers``, which takes precedence over path parsing. + """ + segments = [s for s in path.split("/") if s] + if len(segments) >= 2 and segments[0] == "mcp": + return [segments[1]] + if len(segments) >= 2 and segments[1] == "mcp": + return [segments[0]] + return [] + + @staticmethod + def _target_servers_use_oauth2(path: str, mcp_servers: Optional[List[str]]) -> bool: + """ + True only when EVERY MCP server the request targets is configured for + ``auth_type == oauth2``. If any target is non-OAuth2 — or if the target + cannot be resolved at all — return False so the caller fails closed. + + Used to gate the "treat Authorization as opaque OAuth2 token" fallback + in :meth:`process_mcp_request` so a failed LiteLLM-auth cannot be + exchanged for an anonymous session against a non-OAuth2 server. + """ + # Inline imports avoid a circular dependency: mcp_server_manager imports + # from this module. + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + from litellm.types.mcp import MCPAuth + + # Use the x-mcp-servers header verbatim when present (including the + # explicitly-empty list, which means "no targets" → fail closed). + # Only fall back to path parsing when the header was absent entirely. + target_names = ( + mcp_servers + if mcp_servers is not None + else MCPRequestHandler._extract_target_server_names_from_path(path) + ) + if not target_names: + return False + + for name in target_names: + server = global_mcp_server_manager.get_mcp_server_by_name(name) + if server is None or server.auth_type != MCPAuth.oauth2: + return False + return True + @staticmethod def _get_mcp_auth_header_from_headers(headers: Headers) -> Optional[str]: """ diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py b/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py index fd1a2b1236..dd352d0999 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/auth/test_user_api_key_auth_mcp.py @@ -551,11 +551,14 @@ class TestMCPOAuth2AuthFlow: async def test_oauth2_token_in_authorization_header_fallback(self): """ - When only Authorization header is present with a non-LiteLLM OAuth2 token, + When only Authorization header is present with a non-LiteLLM OAuth2 token + AND the target server is operator-configured for ``auth_type=oauth2``, auth should fall back to permissive mode (OAuth2 passthrough). """ from fastapi import HTTPException + from litellm.types.mcp import MCPAuth + scope = { "type": "http", "method": "POST", @@ -568,10 +571,19 @@ class TestMCPOAuth2AuthFlow: async def mock_user_api_key_auth_fails(api_key, request): raise HTTPException(status_code=401, detail="Invalid API key") - with patch( - "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", - side_effect=mock_user_api_key_auth_fails, + oauth2_server = MagicMock() + oauth2_server.auth_type = MCPAuth.oauth2 + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_fails, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, ): + mock_mgr.get_mcp_server_by_name.return_value = oauth2_server ( auth_result, mcp_auth_header, @@ -695,9 +707,11 @@ class TestMCPOAuth2AuthFlow: async def test_proxy_exception_oauth2_fallback(self): """ user_api_key_auth raises ProxyException (not HTTPException) in production. - The OAuth2 fallback must catch ProxyException with code 401/403 too. + The OAuth2 fallback must catch ProxyException with code 401/403 too, + but only when the target server is operator-configured for ``auth_type=oauth2``. """ from litellm.proxy._types import ProxyException + from litellm.types.mcp import MCPAuth scope = { "type": "http", @@ -716,10 +730,19 @@ class TestMCPOAuth2AuthFlow: code=401, ) - with patch( - "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", - side_effect=mock_user_api_key_auth_proxy_exception, + oauth2_server = MagicMock() + oauth2_server.auth_type = MCPAuth.oauth2 + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_proxy_exception, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, ): + mock_mgr.get_mcp_server_by_name.return_value = oauth2_server ( auth_result, mcp_auth_header, @@ -768,6 +791,290 @@ class TestMCPOAuth2AuthFlow: await MCPRequestHandler.process_mcp_request(scope) +@pytest.mark.asyncio +class TestMCPPublicRouteGuard: + """ + Regression tests for GHSA-7cwm-3279-qf3c / HW6xR21d: + the public-route bypass at the top of process_mcp_request must match + the exact `/.well-known/` path prefix, not a substring of the URL. + """ + + async def test_well_known_substring_in_query_does_not_bypass_auth(self): + """ + URL with `.well-known` smuggled into the query string must still + require valid LiteLLM auth. + """ + from fastapi import HTTPException + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/private_server", + "query_string": b"redirect=.well-known/oauth-protected-resource", + "headers": [(b"authorization", b"Bearer sk-bogus")], + } + + async def mock_user_api_key_auth_fails(api_key, request): + raise HTTPException(status_code=401, detail="Invalid API key") + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_fails, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + # Explicit unresolvable target — proves auth still fails even + # when the registry has no info to fall back to. + mock_mgr.get_mcp_server_by_name.return_value = None + with pytest.raises(HTTPException) as exc_info: + await MCPRequestHandler.process_mcp_request(scope) + assert exc_info.value.status_code == 401 + + async def test_well_known_segment_in_middle_of_path_does_not_bypass_auth(self): + """ + Path containing `.well-known` as a non-prefix component (e.g. a server + name or sub-path) must still require auth. + """ + from fastapi import HTTPException + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/.well-known-fake/tools", + "headers": [(b"authorization", b"Bearer sk-bogus")], + } + + async def mock_user_api_key_auth_fails(api_key, request): + raise HTTPException(status_code=401, detail="Invalid API key") + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_fails, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.return_value = None + with pytest.raises(HTTPException) as exc_info: + await MCPRequestHandler.process_mcp_request(scope) + assert exc_info.value.status_code == 401 + + async def test_legitimate_well_known_path_still_bypasses_auth(self): + """ + Real OAuth discovery routes registered under /.well-known/ must remain + public so unauthenticated clients can fetch them per RFC 8414/9728. + """ + scope = { + "type": "http", + "method": "GET", + "path": "/.well-known/oauth-protected-resource", + "headers": [], + } + + # No mock needed — public path should not call user_api_key_auth at all + with patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + ) as mock_auth: + (auth_result, *_rest) = await MCPRequestHandler.process_mcp_request(scope) + mock_auth.assert_not_called() + assert isinstance(auth_result, UserAPIKeyAuth) + + +@pytest.mark.asyncio +class TestMCPOAuth2FallbackTargetGating: + """ + Regression tests for GHSA-h8fm-g6wc-j228 / HW6xR21d: + The OAuth2 passthrough fallback must only fire when the target MCP server + is operator-configured for ``auth_type=oauth2``. A failed LiteLLM-auth + against a non-OAuth2 server (api_key, bearer_token, basic, etc.) must + propagate as a real auth error, not be exchanged for an anonymous session. + """ + + @staticmethod + def _make_server(auth_type): + server = MagicMock() + server.auth_type = auth_type + return server + + async def test_fallback_blocked_when_target_is_not_oauth2(self): + from fastapi import HTTPException + + from litellm.types.mcp import MCPAuth + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/api_key_server", + "headers": [(b"authorization", b"Bearer anything-at-all")], + } + + async def mock_user_api_key_auth_fails(api_key, request): + raise HTTPException(status_code=401, detail="Invalid API key") + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_fails, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.return_value = ( + TestMCPOAuth2FallbackTargetGating._make_server(MCPAuth.api_key) + ) + with pytest.raises(HTTPException) as exc_info: + await MCPRequestHandler.process_mcp_request(scope) + assert exc_info.value.status_code == 401 + + async def test_fallback_blocked_when_target_unresolvable(self): + """ + If the target server cannot be resolved from path or x-mcp-servers, + we cannot prove it is OAuth2-mode, so we must fail closed. + """ + from fastapi import HTTPException + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/never_registered_server", + "headers": [(b"authorization", b"Bearer anything")], + } + + async def mock_user_api_key_auth_fails(api_key, request): + raise HTTPException(status_code=401, detail="Invalid API key") + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_fails, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.return_value = None + with pytest.raises(HTTPException) as exc_info: + await MCPRequestHandler.process_mcp_request(scope) + assert exc_info.value.status_code == 401 + + async def test_fallback_allowed_when_target_is_oauth2_mode(self): + """ + Operator-configured OAuth2 passthrough still works: target server has + ``auth_type=oauth2`` → failed LiteLLM auth falls back to anonymous so + the bearer can be forwarded to upstream. + """ + from fastapi import HTTPException + + from litellm.types.mcp import MCPAuth + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/atlassian_mcp", + "headers": [ + (b"authorization", b"Bearer atlassian-oauth2-access-token-xyz"), + ], + } + + async def mock_user_api_key_auth_fails(api_key, request): + raise HTTPException(status_code=401, detail="Invalid API key") + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_fails, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.return_value = ( + TestMCPOAuth2FallbackTargetGating._make_server(MCPAuth.oauth2) + ) + (auth_result, *_rest) = await MCPRequestHandler.process_mcp_request(scope) + assert isinstance(auth_result, UserAPIKeyAuth) + + async def test_fallback_blocked_when_any_target_in_header_is_not_oauth2(self): + """ + x-mcp-servers can list multiple targets. If ANY of them is non-OAuth2, + the fallback must be blocked — otherwise an attacker can mix one + OAuth2-mode server in to enable bypass against the others. + """ + from fastapi import HTTPException + + from litellm.types.mcp import MCPAuth + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp", + "headers": [ + (b"authorization", b"Bearer anything"), + (b"x-mcp-servers", b"oauth2_server,api_key_server"), + ], + } + + async def mock_user_api_key_auth_fails(api_key, request): + raise HTTPException(status_code=401, detail="Invalid API key") + + def mock_lookup(name, client_ip=None): + if name == "oauth2_server": + return TestMCPOAuth2FallbackTargetGating._make_server(MCPAuth.oauth2) + return TestMCPOAuth2FallbackTargetGating._make_server(MCPAuth.api_key) + + with ( + patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_fails, + ), + patch( + "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager" + ) as mock_mgr, + ): + mock_mgr.get_mcp_server_by_name.side_effect = mock_lookup + with pytest.raises(HTTPException) as exc_info: + await MCPRequestHandler.process_mcp_request(scope) + assert exc_info.value.status_code == 401 + + async def test_proxy_exception_with_non_numeric_code_propagates(self): + """ + ``ProxyException`` normalises ``code`` via ``str()`` in its __init__, + so callers may produce ``"None"`` or any non-numeric string when no + explicit code was supplied. The exception handler must not coerce + with ``int(...)`` (which would raise ``ValueError`` and rewrite the + auth error as an unhandled 500); it must simply re-raise. + """ + from litellm.proxy._types import ProxyException + + scope = { + "type": "http", + "method": "POST", + "path": "/mcp/atlassian_mcp", + "headers": [(b"authorization", b"Bearer anything")], + } + + async def mock_user_api_key_auth_no_code(api_key, request): + raise ProxyException( + message="Authentication Error", + type="auth_error", + param="api_key", + code=None, + ) + + with patch( + "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth", + side_effect=mock_user_api_key_auth_no_code, + ): + with pytest.raises(ProxyException): + await MCPRequestHandler.process_mcp_request(scope) + + class TestMCPCustomHeaderName: """Test suite for custom MCP authentication header name functionality"""