Merge pull request #26463 from stuxf/fix/mcp-routing-auth

fix(mcp): tighten public-route detection and OAuth2 fallback gating
This commit is contained in:
yuneng-jiang 2026-04-29 19:31:00 -07:00 committed by GitHub
commit 2e561bd04e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 401 additions and 25 deletions

View File

@ -117,7 +117,10 @@ class MCPRequestHandler:
return b"{}"
request.body = mock_body # type: ignore
if ".well-known" in str(request.url): # public routes
# Only OAuth metadata routes registered under /.well-known/ are public.
# Match on request.url.path (path-only, exact prefix) so the substring
# cannot be smuggled via query string, hostname, or a deeper URL segment.
if request.url.path.startswith("/.well-known/"):
validated_user_api_key_auth = UserAPIKeyAuth()
elif has_explicit_litellm_key:
# Explicit x-litellm-api-key provided - always validate normally
@ -126,27 +129,37 @@ class MCPRequestHandler:
)
elif oauth2_headers:
# No x-litellm-api-key, but Authorization header present.
# Could be a LiteLLM key (backward compat) OR an OAuth2 token
# from an upstream MCP provider (e.g. Atlassian).
# Try LiteLLM auth first; on auth failure, treat as OAuth2 passthrough.
# Could be a LiteLLM key (backward compat) OR an opaque OAuth2 token
# the operator wants forwarded to an upstream OAuth2-mode MCP server.
# Try LiteLLM auth first; on auth failure, only fall back to anonymous
# passthrough when the request actually targets a server whose operator
# configured ``auth_type=oauth2``. For any other server (api_key,
# bearer_token, basic, etc.), a failed LiteLLM auth is a real failure
# and must propagate — otherwise an attacker can exchange any garbage
# bearer for an anonymous session.
try:
validated_user_api_key_auth = await user_api_key_auth(
api_key=litellm_api_key, request=request
)
except HTTPException as e:
if e.status_code in (401, 403):
except (HTTPException, ProxyException) as e:
# HTTPException.status_code is int; ProxyException.code is
# normalized to str in its __init__ but can be ``"None"`` or any
# non-numeric string when the caller didn't supply a numeric
# code, so we compare against both int and str forms rather
# than coercing (``int("None")`` would raise ValueError and
# rewrite the auth error as a 500).
status = e.status_code if isinstance(e, HTTPException) else e.code
if status in (
401,
403,
"401",
"403",
) and MCPRequestHandler._target_servers_use_oauth2(
path=request.url.path, mcp_servers=mcp_servers
):
verbose_logger.debug(
"MCP OAuth2: Authorization header is not a valid LiteLLM key, "
"treating as OAuth2 token passthrough"
)
validated_user_api_key_auth = UserAPIKeyAuth()
else:
raise
except ProxyException as e:
if str(e.code) in ("401", "403"):
verbose_logger.debug(
"MCP OAuth2: Authorization header is not a valid LiteLLM key, "
"treating as OAuth2 token passthrough"
"MCP OAuth2: target server is OAuth2-mode, treating "
"Authorization as upstream OAuth2 token passthrough"
)
validated_user_api_key_auth = UserAPIKeyAuth()
else:
@ -165,6 +178,62 @@ class MCPRequestHandler:
dict(headers),
)
@staticmethod
def _extract_target_server_names_from_path(path: str) -> List[str]:
"""
Extract the target MCP server name from the standard MCP transport
URL patterns: ``/mcp/{server_name}[/...]`` and
``/{server_name}/mcp[/...]``. Returns ``[]`` for any other path so
callers fail closed when the target cannot be resolved.
REST/admin endpoints, OAuth2 server endpoints
(``/{server_name}/authorize``, ``/token`` etc.), and ``.well-known``
discovery routes intentionally fall through those flows do not need
OAuth2 token passthrough. Clients aggregating multiple servers should
use ``x-mcp-servers``, which takes precedence over path parsing.
"""
segments = [s for s in path.split("/") if s]
if len(segments) >= 2 and segments[0] == "mcp":
return [segments[1]]
if len(segments) >= 2 and segments[1] == "mcp":
return [segments[0]]
return []
@staticmethod
def _target_servers_use_oauth2(path: str, mcp_servers: Optional[List[str]]) -> bool:
"""
True only when EVERY MCP server the request targets is configured for
``auth_type == oauth2``. If any target is non-OAuth2 or if the target
cannot be resolved at all return False so the caller fails closed.
Used to gate the "treat Authorization as opaque OAuth2 token" fallback
in :meth:`process_mcp_request` so a failed LiteLLM-auth cannot be
exchanged for an anonymous session against a non-OAuth2 server.
"""
# Inline imports avoid a circular dependency: mcp_server_manager imports
# from this module.
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
)
from litellm.types.mcp import MCPAuth
# Use the x-mcp-servers header verbatim when present (including the
# explicitly-empty list, which means "no targets" → fail closed).
# Only fall back to path parsing when the header was absent entirely.
target_names = (
mcp_servers
if mcp_servers is not None
else MCPRequestHandler._extract_target_server_names_from_path(path)
)
if not target_names:
return False
for name in target_names:
server = global_mcp_server_manager.get_mcp_server_by_name(name)
if server is None or server.auth_type != MCPAuth.oauth2:
return False
return True
@staticmethod
def _get_mcp_auth_header_from_headers(headers: Headers) -> Optional[str]:
"""

View File

@ -551,11 +551,14 @@ class TestMCPOAuth2AuthFlow:
async def test_oauth2_token_in_authorization_header_fallback(self):
"""
When only Authorization header is present with a non-LiteLLM OAuth2 token,
When only Authorization header is present with a non-LiteLLM OAuth2 token
AND the target server is operator-configured for ``auth_type=oauth2``,
auth should fall back to permissive mode (OAuth2 passthrough).
"""
from fastapi import HTTPException
from litellm.types.mcp import MCPAuth
scope = {
"type": "http",
"method": "POST",
@ -568,10 +571,19 @@ class TestMCPOAuth2AuthFlow:
async def mock_user_api_key_auth_fails(api_key, request):
raise HTTPException(status_code=401, detail="Invalid API key")
with patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth_fails,
oauth2_server = MagicMock()
oauth2_server.auth_type = MCPAuth.oauth2
with (
patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth_fails,
),
patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager"
) as mock_mgr,
):
mock_mgr.get_mcp_server_by_name.return_value = oauth2_server
(
auth_result,
mcp_auth_header,
@ -695,9 +707,11 @@ class TestMCPOAuth2AuthFlow:
async def test_proxy_exception_oauth2_fallback(self):
"""
user_api_key_auth raises ProxyException (not HTTPException) in production.
The OAuth2 fallback must catch ProxyException with code 401/403 too.
The OAuth2 fallback must catch ProxyException with code 401/403 too,
but only when the target server is operator-configured for ``auth_type=oauth2``.
"""
from litellm.proxy._types import ProxyException
from litellm.types.mcp import MCPAuth
scope = {
"type": "http",
@ -716,10 +730,19 @@ class TestMCPOAuth2AuthFlow:
code=401,
)
with patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth_proxy_exception,
oauth2_server = MagicMock()
oauth2_server.auth_type = MCPAuth.oauth2
with (
patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth_proxy_exception,
),
patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager"
) as mock_mgr,
):
mock_mgr.get_mcp_server_by_name.return_value = oauth2_server
(
auth_result,
mcp_auth_header,
@ -768,6 +791,290 @@ class TestMCPOAuth2AuthFlow:
await MCPRequestHandler.process_mcp_request(scope)
@pytest.mark.asyncio
class TestMCPPublicRouteGuard:
"""
Regression tests for GHSA-7cwm-3279-qf3c / HW6xR21d:
the public-route bypass at the top of process_mcp_request must match
the exact `/.well-known/` path prefix, not a substring of the URL.
"""
async def test_well_known_substring_in_query_does_not_bypass_auth(self):
"""
URL with `.well-known` smuggled into the query string must still
require valid LiteLLM auth.
"""
from fastapi import HTTPException
scope = {
"type": "http",
"method": "POST",
"path": "/mcp/private_server",
"query_string": b"redirect=.well-known/oauth-protected-resource",
"headers": [(b"authorization", b"Bearer sk-bogus")],
}
async def mock_user_api_key_auth_fails(api_key, request):
raise HTTPException(status_code=401, detail="Invalid API key")
with (
patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth_fails,
),
patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager"
) as mock_mgr,
):
# Explicit unresolvable target — proves auth still fails even
# when the registry has no info to fall back to.
mock_mgr.get_mcp_server_by_name.return_value = None
with pytest.raises(HTTPException) as exc_info:
await MCPRequestHandler.process_mcp_request(scope)
assert exc_info.value.status_code == 401
async def test_well_known_segment_in_middle_of_path_does_not_bypass_auth(self):
"""
Path containing `.well-known` as a non-prefix component (e.g. a server
name or sub-path) must still require auth.
"""
from fastapi import HTTPException
scope = {
"type": "http",
"method": "POST",
"path": "/mcp/.well-known-fake/tools",
"headers": [(b"authorization", b"Bearer sk-bogus")],
}
async def mock_user_api_key_auth_fails(api_key, request):
raise HTTPException(status_code=401, detail="Invalid API key")
with (
patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth_fails,
),
patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager"
) as mock_mgr,
):
mock_mgr.get_mcp_server_by_name.return_value = None
with pytest.raises(HTTPException) as exc_info:
await MCPRequestHandler.process_mcp_request(scope)
assert exc_info.value.status_code == 401
async def test_legitimate_well_known_path_still_bypasses_auth(self):
"""
Real OAuth discovery routes registered under /.well-known/ must remain
public so unauthenticated clients can fetch them per RFC 8414/9728.
"""
scope = {
"type": "http",
"method": "GET",
"path": "/.well-known/oauth-protected-resource",
"headers": [],
}
# No mock needed — public path should not call user_api_key_auth at all
with patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
) as mock_auth:
(auth_result, *_rest) = await MCPRequestHandler.process_mcp_request(scope)
mock_auth.assert_not_called()
assert isinstance(auth_result, UserAPIKeyAuth)
@pytest.mark.asyncio
class TestMCPOAuth2FallbackTargetGating:
"""
Regression tests for GHSA-h8fm-g6wc-j228 / HW6xR21d:
The OAuth2 passthrough fallback must only fire when the target MCP server
is operator-configured for ``auth_type=oauth2``. A failed LiteLLM-auth
against a non-OAuth2 server (api_key, bearer_token, basic, etc.) must
propagate as a real auth error, not be exchanged for an anonymous session.
"""
@staticmethod
def _make_server(auth_type):
server = MagicMock()
server.auth_type = auth_type
return server
async def test_fallback_blocked_when_target_is_not_oauth2(self):
from fastapi import HTTPException
from litellm.types.mcp import MCPAuth
scope = {
"type": "http",
"method": "POST",
"path": "/mcp/api_key_server",
"headers": [(b"authorization", b"Bearer anything-at-all")],
}
async def mock_user_api_key_auth_fails(api_key, request):
raise HTTPException(status_code=401, detail="Invalid API key")
with (
patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth_fails,
),
patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager"
) as mock_mgr,
):
mock_mgr.get_mcp_server_by_name.return_value = (
TestMCPOAuth2FallbackTargetGating._make_server(MCPAuth.api_key)
)
with pytest.raises(HTTPException) as exc_info:
await MCPRequestHandler.process_mcp_request(scope)
assert exc_info.value.status_code == 401
async def test_fallback_blocked_when_target_unresolvable(self):
"""
If the target server cannot be resolved from path or x-mcp-servers,
we cannot prove it is OAuth2-mode, so we must fail closed.
"""
from fastapi import HTTPException
scope = {
"type": "http",
"method": "POST",
"path": "/mcp/never_registered_server",
"headers": [(b"authorization", b"Bearer anything")],
}
async def mock_user_api_key_auth_fails(api_key, request):
raise HTTPException(status_code=401, detail="Invalid API key")
with (
patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth_fails,
),
patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager"
) as mock_mgr,
):
mock_mgr.get_mcp_server_by_name.return_value = None
with pytest.raises(HTTPException) as exc_info:
await MCPRequestHandler.process_mcp_request(scope)
assert exc_info.value.status_code == 401
async def test_fallback_allowed_when_target_is_oauth2_mode(self):
"""
Operator-configured OAuth2 passthrough still works: target server has
``auth_type=oauth2`` failed LiteLLM auth falls back to anonymous so
the bearer can be forwarded to upstream.
"""
from fastapi import HTTPException
from litellm.types.mcp import MCPAuth
scope = {
"type": "http",
"method": "POST",
"path": "/mcp/atlassian_mcp",
"headers": [
(b"authorization", b"Bearer atlassian-oauth2-access-token-xyz"),
],
}
async def mock_user_api_key_auth_fails(api_key, request):
raise HTTPException(status_code=401, detail="Invalid API key")
with (
patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth_fails,
),
patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager"
) as mock_mgr,
):
mock_mgr.get_mcp_server_by_name.return_value = (
TestMCPOAuth2FallbackTargetGating._make_server(MCPAuth.oauth2)
)
(auth_result, *_rest) = await MCPRequestHandler.process_mcp_request(scope)
assert isinstance(auth_result, UserAPIKeyAuth)
async def test_fallback_blocked_when_any_target_in_header_is_not_oauth2(self):
"""
x-mcp-servers can list multiple targets. If ANY of them is non-OAuth2,
the fallback must be blocked otherwise an attacker can mix one
OAuth2-mode server in to enable bypass against the others.
"""
from fastapi import HTTPException
from litellm.types.mcp import MCPAuth
scope = {
"type": "http",
"method": "POST",
"path": "/mcp",
"headers": [
(b"authorization", b"Bearer anything"),
(b"x-mcp-servers", b"oauth2_server,api_key_server"),
],
}
async def mock_user_api_key_auth_fails(api_key, request):
raise HTTPException(status_code=401, detail="Invalid API key")
def mock_lookup(name, client_ip=None):
if name == "oauth2_server":
return TestMCPOAuth2FallbackTargetGating._make_server(MCPAuth.oauth2)
return TestMCPOAuth2FallbackTargetGating._make_server(MCPAuth.api_key)
with (
patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth_fails,
),
patch(
"litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager"
) as mock_mgr,
):
mock_mgr.get_mcp_server_by_name.side_effect = mock_lookup
with pytest.raises(HTTPException) as exc_info:
await MCPRequestHandler.process_mcp_request(scope)
assert exc_info.value.status_code == 401
async def test_proxy_exception_with_non_numeric_code_propagates(self):
"""
``ProxyException`` normalises ``code`` via ``str()`` in its __init__,
so callers may produce ``"None"`` or any non-numeric string when no
explicit code was supplied. The exception handler must not coerce
with ``int(...)`` (which would raise ``ValueError`` and rewrite the
auth error as an unhandled 500); it must simply re-raise.
"""
from litellm.proxy._types import ProxyException
scope = {
"type": "http",
"method": "POST",
"path": "/mcp/atlassian_mcp",
"headers": [(b"authorization", b"Bearer anything")],
}
async def mock_user_api_key_auth_no_code(api_key, request):
raise ProxyException(
message="Authentication Error",
type="auth_error",
param="api_key",
code=None,
)
with patch(
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
side_effect=mock_user_api_key_auth_no_code,
):
with pytest.raises(ProxyException):
await MCPRequestHandler.process_mcp_request(scope)
class TestMCPCustomHeaderName:
"""Test suite for custom MCP authentication header name functionality"""