Merge pull request #26463 from stuxf/fix/mcp-routing-auth
fix(mcp): tighten public-route detection and OAuth2 fallback gating
This commit is contained in:
commit
2e561bd04e
@ -117,7 +117,10 @@ class MCPRequestHandler:
|
||||
return b"{}"
|
||||
|
||||
request.body = mock_body # type: ignore
|
||||
if ".well-known" in str(request.url): # public routes
|
||||
# Only OAuth metadata routes registered under /.well-known/ are public.
|
||||
# Match on request.url.path (path-only, exact prefix) so the substring
|
||||
# cannot be smuggled via query string, hostname, or a deeper URL segment.
|
||||
if request.url.path.startswith("/.well-known/"):
|
||||
validated_user_api_key_auth = UserAPIKeyAuth()
|
||||
elif has_explicit_litellm_key:
|
||||
# Explicit x-litellm-api-key provided - always validate normally
|
||||
@ -126,27 +129,37 @@ class MCPRequestHandler:
|
||||
)
|
||||
elif oauth2_headers:
|
||||
# No x-litellm-api-key, but Authorization header present.
|
||||
# Could be a LiteLLM key (backward compat) OR an OAuth2 token
|
||||
# from an upstream MCP provider (e.g. Atlassian).
|
||||
# Try LiteLLM auth first; on auth failure, treat as OAuth2 passthrough.
|
||||
# Could be a LiteLLM key (backward compat) OR an opaque OAuth2 token
|
||||
# the operator wants forwarded to an upstream OAuth2-mode MCP server.
|
||||
# Try LiteLLM auth first; on auth failure, only fall back to anonymous
|
||||
# passthrough when the request actually targets a server whose operator
|
||||
# configured ``auth_type=oauth2``. For any other server (api_key,
|
||||
# bearer_token, basic, etc.), a failed LiteLLM auth is a real failure
|
||||
# and must propagate — otherwise an attacker can exchange any garbage
|
||||
# bearer for an anonymous session.
|
||||
try:
|
||||
validated_user_api_key_auth = await user_api_key_auth(
|
||||
api_key=litellm_api_key, request=request
|
||||
)
|
||||
except HTTPException as e:
|
||||
if e.status_code in (401, 403):
|
||||
except (HTTPException, ProxyException) as e:
|
||||
# HTTPException.status_code is int; ProxyException.code is
|
||||
# normalized to str in its __init__ but can be ``"None"`` or any
|
||||
# non-numeric string when the caller didn't supply a numeric
|
||||
# code, so we compare against both int and str forms rather
|
||||
# than coercing (``int("None")`` would raise ValueError and
|
||||
# rewrite the auth error as a 500).
|
||||
status = e.status_code if isinstance(e, HTTPException) else e.code
|
||||
if status in (
|
||||
401,
|
||||
403,
|
||||
"401",
|
||||
"403",
|
||||
) and MCPRequestHandler._target_servers_use_oauth2(
|
||||
path=request.url.path, mcp_servers=mcp_servers
|
||||
):
|
||||
verbose_logger.debug(
|
||||
"MCP OAuth2: Authorization header is not a valid LiteLLM key, "
|
||||
"treating as OAuth2 token passthrough"
|
||||
)
|
||||
validated_user_api_key_auth = UserAPIKeyAuth()
|
||||
else:
|
||||
raise
|
||||
except ProxyException as e:
|
||||
if str(e.code) in ("401", "403"):
|
||||
verbose_logger.debug(
|
||||
"MCP OAuth2: Authorization header is not a valid LiteLLM key, "
|
||||
"treating as OAuth2 token passthrough"
|
||||
"MCP OAuth2: target server is OAuth2-mode, treating "
|
||||
"Authorization as upstream OAuth2 token passthrough"
|
||||
)
|
||||
validated_user_api_key_auth = UserAPIKeyAuth()
|
||||
else:
|
||||
@ -165,6 +178,62 @@ class MCPRequestHandler:
|
||||
dict(headers),
|
||||
)
|
||||
|
||||
@staticmethod
def _extract_target_server_names_from_path(path: str) -> List[str]:
    """
    Resolve the MCP server a transport URL is addressed to.

    Two URL shapes are recognised, judged on the first two non-empty
    path segments only:

    * ``/mcp/{server_name}[/...]``  -> ``[server_name]``
    * ``/{server_name}/mcp[/...]``  -> ``[server_name]``

    Every other path yields ``[]``. That covers REST/admin endpoints,
    OAuth2 server endpoints such as ``/{server_name}/authorize`` or
    ``/token``, and ``.well-known`` discovery routes. Those flows do not
    need OAuth2 token passthrough, and an empty result lets the caller
    fail closed when no target can be determined.

    Clients that aggregate several servers are expected to send
    ``x-mcp-servers`` instead; the caller gives that header precedence
    over this path parsing.

    Args:
        path: URL path component of the incoming request.

    Returns:
        A one-element list holding the server name, or an empty list.
    """
    # Drop empty pieces so leading, trailing and doubled slashes are ignored.
    parts = [piece for piece in path.split("/") if piece]
    if len(parts) < 2:
        return []

    first, second = parts[0], parts[1]
    # The "/mcp/{name}" form is checked first, so "/mcp/mcp" resolves to "mcp".
    if first == "mcp":
        return [second]
    if second == "mcp":
        return [first]
    return []
|
||||
|
||||
@staticmethod
def _target_servers_use_oauth2(path: str, mcp_servers: Optional[List[str]]) -> bool:
    """
    Report whether every MCP server targeted by the request is OAuth2-mode.

    The result is ``True`` only if all targets resolve to a registered
    server whose ``auth_type`` equals ``MCPAuth.oauth2``. A single
    non-OAuth2 target, an unknown server name, or an empty target list
    yields ``False`` so that the caller fails closed.

    :meth:`process_mcp_request` uses this to gate the "treat Authorization
    as an opaque OAuth2 token" fallback, so that a failed LiteLLM auth is
    never traded for an anonymous session against a non-OAuth2 server.

    Args:
        path: URL path of the request; parsed for a server name only when
            ``mcp_servers`` is ``None``.
        mcp_servers: Server names taken from the ``x-mcp-servers`` header,
            or ``None`` when the header was absent.

    Returns:
        ``True`` when all targets are OAuth2-mode, otherwise ``False``.
    """
    # Imported here rather than at module level because mcp_server_manager
    # imports from this module (circular dependency).
    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        global_mcp_server_manager,
    )
    from litellm.types.mcp import MCPAuth

    # A header that was supplied is authoritative, even when it is an empty
    # list ("no targets" -> fail closed below). Path parsing is the fallback
    # only when the header was not sent at all.
    if mcp_servers is None:
        requested = MCPRequestHandler._extract_target_server_names_from_path(path)
    else:
        requested = mcp_servers

    if not requested:
        return False

    for server_name in requested:
        resolved = global_mcp_server_manager.get_mcp_server_by_name(server_name)
        if resolved is None:
            # Unknown server: cannot prove it is OAuth2-mode.
            return False
        if resolved.auth_type != MCPAuth.oauth2:
            return False
    return True
|
||||
|
||||
@staticmethod
|
||||
def _get_mcp_auth_header_from_headers(headers: Headers) -> Optional[str]:
|
||||
"""
|
||||
|
||||
@ -551,11 +551,14 @@ class TestMCPOAuth2AuthFlow:
|
||||
|
||||
async def test_oauth2_token_in_authorization_header_fallback(self):
|
||||
"""
|
||||
When only Authorization header is present with a non-LiteLLM OAuth2 token,
|
||||
When only Authorization header is present with a non-LiteLLM OAuth2 token
|
||||
AND the target server is operator-configured for ``auth_type=oauth2``,
|
||||
auth should fall back to permissive mode (OAuth2 passthrough).
|
||||
"""
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm.types.mcp import MCPAuth
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
"method": "POST",
|
||||
@ -568,10 +571,19 @@ class TestMCPOAuth2AuthFlow:
|
||||
async def mock_user_api_key_auth_fails(api_key, request):
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
|
||||
with patch(
|
||||
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
|
||||
side_effect=mock_user_api_key_auth_fails,
|
||||
oauth2_server = MagicMock()
|
||||
oauth2_server.auth_type = MCPAuth.oauth2
|
||||
|
||||
with (
|
||||
patch(
|
||||
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
|
||||
side_effect=mock_user_api_key_auth_fails,
|
||||
),
|
||||
patch(
|
||||
"litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager"
|
||||
) as mock_mgr,
|
||||
):
|
||||
mock_mgr.get_mcp_server_by_name.return_value = oauth2_server
|
||||
(
|
||||
auth_result,
|
||||
mcp_auth_header,
|
||||
@ -695,9 +707,11 @@ class TestMCPOAuth2AuthFlow:
|
||||
async def test_proxy_exception_oauth2_fallback(self):
|
||||
"""
|
||||
user_api_key_auth raises ProxyException (not HTTPException) in production.
|
||||
The OAuth2 fallback must catch ProxyException with code 401/403 too.
|
||||
The OAuth2 fallback must catch ProxyException with code 401/403 too,
|
||||
but only when the target server is operator-configured for ``auth_type=oauth2``.
|
||||
"""
|
||||
from litellm.proxy._types import ProxyException
|
||||
from litellm.types.mcp import MCPAuth
|
||||
|
||||
scope = {
|
||||
"type": "http",
|
||||
@ -716,10 +730,19 @@ class TestMCPOAuth2AuthFlow:
|
||||
code=401,
|
||||
)
|
||||
|
||||
with patch(
|
||||
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
|
||||
side_effect=mock_user_api_key_auth_proxy_exception,
|
||||
oauth2_server = MagicMock()
|
||||
oauth2_server.auth_type = MCPAuth.oauth2
|
||||
|
||||
with (
|
||||
patch(
|
||||
"litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth",
|
||||
side_effect=mock_user_api_key_auth_proxy_exception,
|
||||
),
|
||||
patch(
|
||||
"litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager"
|
||||
) as mock_mgr,
|
||||
):
|
||||
mock_mgr.get_mcp_server_by_name.return_value = oauth2_server
|
||||
(
|
||||
auth_result,
|
||||
mcp_auth_header,
|
||||
@ -768,6 +791,290 @@ class TestMCPOAuth2AuthFlow:
|
||||
await MCPRequestHandler.process_mcp_request(scope)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestMCPPublicRouteGuard:
    """
    Regression tests for GHSA-7cwm-3279-qf3c / HW6xR21d.

    The public-route bypass at the top of process_mcp_request has to match
    the exact `/.well-known/` path prefix, never a mere substring of the URL.
    """

    # Dotted paths of the objects replaced by the tests below.
    _AUTH_TARGET = "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth"
    _MANAGER_TARGET = "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager"

    @staticmethod
    async def _reject_with_401(api_key, request):
        """Stand-in for user_api_key_auth that always raises HTTP 401."""
        from fastapi import HTTPException

        raise HTTPException(status_code=401, detail="Invalid API key")

    async def _assert_request_is_rejected(self, scope):
        """Run ``scope`` through the handler and require an HTTP 401."""
        from fastapi import HTTPException

        auth_patch = patch(self._AUTH_TARGET, side_effect=self._reject_with_401)
        manager_patch = patch(self._MANAGER_TARGET)
        with auth_patch, manager_patch as mock_mgr:
            # Explicit unresolvable target: proves auth still fails even
            # when the registry has no info to fall back to.
            mock_mgr.get_mcp_server_by_name.return_value = None
            with pytest.raises(HTTPException) as exc_info:
                await MCPRequestHandler.process_mcp_request(scope)
            assert exc_info.value.status_code == 401

    async def test_well_known_substring_in_query_does_not_bypass_auth(self):
        """
        A `.well-known` substring smuggled into the query string must not
        exempt the request from LiteLLM auth.
        """
        await self._assert_request_is_rejected(
            {
                "type": "http",
                "method": "POST",
                "path": "/mcp/private_server",
                "query_string": b"redirect=.well-known/oauth-protected-resource",
                "headers": [(b"authorization", b"Bearer sk-bogus")],
            }
        )

    async def test_well_known_segment_in_middle_of_path_does_not_bypass_auth(self):
        """
        A path that carries `.well-known` as a non-prefix component (for
        example a server name or sub-path) must still require auth.
        """
        await self._assert_request_is_rejected(
            {
                "type": "http",
                "method": "POST",
                "path": "/mcp/.well-known-fake/tools",
                "headers": [(b"authorization", b"Bearer sk-bogus")],
            }
        )

    async def test_legitimate_well_known_path_still_bypasses_auth(self):
        """
        Genuine OAuth discovery routes under /.well-known/ stay public so
        unauthenticated clients can fetch them per RFC 8414/9728.
        """
        discovery_scope = {
            "type": "http",
            "method": "GET",
            "path": "/.well-known/oauth-protected-resource",
            "headers": [],
        }

        # No side effect configured: a public path must never reach
        # user_api_key_auth in the first place.
        with patch(self._AUTH_TARGET) as mock_auth:
            result = await MCPRequestHandler.process_mcp_request(discovery_scope)
            (auth_result, *_rest) = result
            mock_auth.assert_not_called()
            assert isinstance(auth_result, UserAPIKeyAuth)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestMCPOAuth2FallbackTargetGating:
    """
    Regression tests for GHSA-h8fm-g6wc-j228 / HW6xR21d.

    The OAuth2 passthrough fallback may only fire when the target MCP server
    is operator-configured for ``auth_type=oauth2``. A failed LiteLLM auth
    against any other server type (api_key, bearer_token, basic, etc.) has
    to surface as a real auth error instead of becoming an anonymous session.
    """

    # Dotted paths of the objects replaced by the tests below.
    _AUTH_TARGET = "litellm.proxy._experimental.mcp_server.auth.user_api_key_auth_mcp.user_api_key_auth"
    _MANAGER_TARGET = "litellm.proxy._experimental.mcp_server.mcp_server_manager.global_mcp_server_manager"

    @staticmethod
    def _make_server(auth_type):
        """Build a mock MCP server whose ``auth_type`` is ``auth_type``."""
        fake_server = MagicMock()
        fake_server.auth_type = auth_type
        return fake_server

    @staticmethod
    async def _reject_with_401(api_key, request):
        """Stand-in for user_api_key_auth that always raises HTTP 401."""
        from fastapi import HTTPException

        raise HTTPException(status_code=401, detail="Invalid API key")

    async def _assert_401(self, scope, *, server=None, lookup=None):
        """
        Run ``scope`` through the handler and require an HTTP 401.

        The registry lookup either returns ``server`` for every name or, when
        ``lookup`` is given, delegates to that callable.
        """
        from fastapi import HTTPException

        auth_patch = patch(self._AUTH_TARGET, side_effect=self._reject_with_401)
        manager_patch = patch(self._MANAGER_TARGET)
        with auth_patch, manager_patch as mock_mgr:
            if lookup is not None:
                mock_mgr.get_mcp_server_by_name.side_effect = lookup
            else:
                mock_mgr.get_mcp_server_by_name.return_value = server
            with pytest.raises(HTTPException) as exc_info:
                await MCPRequestHandler.process_mcp_request(scope)
            assert exc_info.value.status_code == 401

    async def test_fallback_blocked_when_target_is_not_oauth2(self):
        from litellm.types.mcp import MCPAuth

        api_key_server = TestMCPOAuth2FallbackTargetGating._make_server(
            MCPAuth.api_key
        )
        await self._assert_401(
            {
                "type": "http",
                "method": "POST",
                "path": "/mcp/api_key_server",
                "headers": [(b"authorization", b"Bearer anything-at-all")],
            },
            server=api_key_server,
        )

    async def test_fallback_blocked_when_target_unresolvable(self):
        """
        When the target server cannot be resolved from the path or from
        x-mcp-servers, nothing proves it is OAuth2-mode, so fail closed.
        """
        await self._assert_401(
            {
                "type": "http",
                "method": "POST",
                "path": "/mcp/never_registered_server",
                "headers": [(b"authorization", b"Bearer anything")],
            },
            server=None,
        )

    async def test_fallback_allowed_when_target_is_oauth2_mode(self):
        """
        Operator-configured OAuth2 passthrough keeps working: with a target of
        ``auth_type=oauth2`` a failed LiteLLM auth falls back to anonymous so
        the bearer can be forwarded upstream.
        """
        from litellm.types.mcp import MCPAuth

        oauth2_scope = {
            "type": "http",
            "method": "POST",
            "path": "/mcp/atlassian_mcp",
            "headers": [
                (b"authorization", b"Bearer atlassian-oauth2-access-token-xyz"),
            ],
        }

        auth_patch = patch(self._AUTH_TARGET, side_effect=self._reject_with_401)
        manager_patch = patch(self._MANAGER_TARGET)
        with auth_patch, manager_patch as mock_mgr:
            mock_mgr.get_mcp_server_by_name.return_value = (
                TestMCPOAuth2FallbackTargetGating._make_server(MCPAuth.oauth2)
            )
            result = await MCPRequestHandler.process_mcp_request(oauth2_scope)
            (auth_result, *_rest) = result
            assert isinstance(auth_result, UserAPIKeyAuth)

    async def test_fallback_blocked_when_any_target_in_header_is_not_oauth2(self):
        """
        x-mcp-servers may name several targets. One non-OAuth2 entry is enough
        to block the fallback; otherwise an attacker could add a single
        OAuth2-mode server to unlock a bypass against the rest.
        """
        from litellm.types.mcp import MCPAuth

        def mock_lookup(name, client_ip=None):
            # Only "oauth2_server" is OAuth2-mode; every other name is api_key.
            kind = MCPAuth.oauth2 if name == "oauth2_server" else MCPAuth.api_key
            return TestMCPOAuth2FallbackTargetGating._make_server(kind)

        await self._assert_401(
            {
                "type": "http",
                "method": "POST",
                "path": "/mcp",
                "headers": [
                    (b"authorization", b"Bearer anything"),
                    (b"x-mcp-servers", b"oauth2_server,api_key_server"),
                ],
            },
            lookup=mock_lookup,
        )

    async def test_proxy_exception_with_non_numeric_code_propagates(self):
        """
        ``ProxyException`` normalises ``code`` with ``str()`` in its __init__,
        so a missing code can surface as ``"None"`` or another non-numeric
        string. The handler must not coerce it with ``int(...)`` (a
        ``ValueError`` there would turn the auth error into an unhandled 500);
        it has to re-raise the original exception.
        """
        from litellm.proxy._types import ProxyException

        async def mock_user_api_key_auth_no_code(api_key, request):
            raise ProxyException(
                message="Authentication Error",
                type="auth_error",
                param="api_key",
                code=None,
            )

        no_code_scope = {
            "type": "http",
            "method": "POST",
            "path": "/mcp/atlassian_mcp",
            "headers": [(b"authorization", b"Bearer anything")],
        }

        with patch(self._AUTH_TARGET, side_effect=mock_user_api_key_auth_no_code):
            with pytest.raises(ProxyException):
                await MCPRequestHandler.process_mcp_request(no_code_scope)
|
||||
|
||||
|
||||
class TestMCPCustomHeaderName:
|
||||
"""Test suite for custom MCP authentication header name functionality"""
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user