From 466f06df6dca01a4a9a8d76db49a90bf4e351b11 Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Thu, 14 May 2026 00:33:36 +0530 Subject: [PATCH] fix(mcp): surface upstream 401 for token-forwarding MCP servers (#27847) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix(mcp): surface upstream 401 for token-forwarding MCP servers For MCP servers configured with extra_headers: [Authorization], the gateway forwards the client token directly to the upstream. When that token is rejected (expired or invalid) the upstream returns 401, but the MCP SDK starts the SSE stream with 200 OK before calling handlers, so the 401 can't be returned mid-stream. Fix: add a pre-flight httpx probe in handle_streamable_http_mcp — before the SDK opens the session — so the gateway can still return HTTP 401 with WWW-Authenticate: Bearer authorization_uri= when the upstream rejects the token. The probe fails-open (returns 200) on network errors so a transient hiccup does not block valid requests. Co-authored-by: Cursor * fix(mcp): parallelize pre-flight auth probes and use HEAD to avoid side effects - Extract forwarded_auth outside the pass-through server loop (was called N times for the same scope value) - Gather all upstream auth probes concurrently with asyncio.gather instead of sequentially; eliminates N×5 s worst-case latency - Switch probe from POST+initialize JSON-RPC body to HEAD request; HEAD carries the Authorization header so the upstream rejects invalid tokens with 401 but never allocates a session or writes an audit entry Co-authored-by: Cursor * fix(mcp): use get_async_httpx_client in _probe_upstream_auth Replaces bare httpx.AsyncClient with the project-standard get_async_httpx_client(httpxSpecialProvider.MCP) to satisfy the ensure_async_clients_test code coverage check and avoid the +500 ms per-request overhead of creating a new client on every probe call. Co-authored-by: Cursor * refactor(mcp): extract pre-flight probe into _check_passthrough_upstream_auth Moves the parallel upstream auth probe logic out of handle_streamable_http_mcp into a dedicated helper to satisfy Ruff PLR0915 (Too many statements > 50). Co-authored-by: Cursor * fix(mcp): gate pre-flight probes on authorized server set to prevent bypass _check_passthrough_upstream_auth was resolving user-supplied server names directly before authorization ran, letting any permitted LiteLLM key trigger an upstream HEAD probe to a server it was not allowed to use. Changes: - Call _get_allowed_mcp_servers inside the helper so only servers the caller's key is authorized for are probed. - Move the call site to after toolset scoping so the auth context is fully resolved before the probe list is built. - Thread user_api_key_auth into the helper signature (replaces the raw mcp_servers name list). Co-authored-by: Cursor * Add async HTTP HEAD support Co-authored-by: Yassin Kortam * fix(mcp): use Scope type annotation in _get_forwarded_auth_from_scope Co-authored-by: Cursor * Fix MCP upstream auth probe method Co-authored-by: Yassin Kortam * Remove unused AsyncHTTPHandler head method Co-authored-by: Yassin Kortam * fix(mcp): exclude has_client_credentials servers from pre-flight auth probe _prepare_mcp_server_headers skips caller Authorization when the server uses OAuth client-credentials (M2M), but the pre-flight probe was still selecting those servers and forwarding the caller's raw token in the HEAD request. Exclude servers with has_client_credentials from the probe list to match the actual downstream header-preparation logic. Co-authored-by: Cursor * fix(mcp): propagate upstream 403 as 403, not 401 with WWW-Authenticate Per RFC 9110, 401 means "go get new credentials." Mapping an upstream 403 to a gateway 401 causes OAuth clients to restart the authorization flow, obtain a fresh token with identical scopes, hit 403 again, and loop indefinitely. 401 from upstream → gateway 401 + WWW-Authenticate (re-authorize) 403 from upstream → gateway 403 (no WWW-Authenticate hint) Co-authored-by: Cursor * fix(mcp): skip auth probe when Authorization may be the LiteLLM proxy key The pre-flight upstream probe must not forward the caller's Authorization header when it could itself be the LiteLLM proxy API key. Restrict the probe to requests that supply x-litellm-api-key explicitly — only then is the Authorization header unambiguously the upstream OAuth token the caller wants forwarded. * Fix MCP ASGI HTTPException propagation Co-authored-by: Yassin Kortam * fix(mcp): use public AsyncHTTPHandler.post() in auth probe Use AsyncHTTPHandler.post() and catch httpx.HTTPStatusError explicitly so the 401/403 we want to surface is not silently swallowed by the broad fail-open except Exception block. Avoids reaching into the handler's private client attribute, which would silently regress to fail-open if AsyncHTTPHandler is ever refactored. * Fix MCP auth probe tests Co-authored-by: Yassin Kortam * test(mcp): add coverage for httpx.HTTPStatusError path in auth probe AsyncHTTPHandler.post() calls raise_for_status() internally, so a real upstream 401/403 lands as httpx.HTTPStatusError. Add a test that exercises that specific exception path so a regression that swallows the error in the broad fail-open except Exception would be caught. --------- Co-authored-by: Cursor Co-authored-by: Yassin Kortam Co-authored-by: claude-bot --- .../proxy/_experimental/mcp_server/server.py | 165 +++++++++++++++++- litellm/proxy/proxy_server.py | 11 +- .../mcp_server/test_mcp_server.py | 134 ++++++++++++++ .../proxy/test_mcp_asgi_response.py | 36 ++++ 4 files changed, 344 insertions(+), 2 deletions(-) create mode 100644 tests/test_litellm/proxy/test_mcp_asgi_response.py diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 276a6e8a3b..685232d049 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -23,6 +23,7 @@ from typing import ( cast, ) +import httpx from fastapi import FastAPI, HTTPException from pydantic import AnyUrl, ConfigDict from starlette.requests import Request as StarletteRequest @@ -51,13 +52,17 @@ from litellm.proxy._experimental.mcp_server.utils import ( get_server_prefix, iter_known_server_prefixes, ) +from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, +) from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy.auth.ip_address_utils import IPAddressUtils from litellm.proxy.litellm_pre_call_utils import ( LiteLLMProxyRequestSetup, get_chain_id_from_headers, ) -from litellm.types.mcp import MCPAuth +from litellm.types.mcp import MCPAuth, MCPSpecVersion from litellm.types.mcp_server.mcp_server_manager import MCPInfo, MCPServer from litellm.types.utils import CallTypes, StandardLoggingMCPToolCall from litellm.utils import Rules, client, function_setup @@ -2754,6 +2759,157 @@ if MCP_AVAILABLE: ) return user_api_key_auth.model_copy(update={"object_permission": updated_op}) + def _get_forwarded_auth_from_scope(scope: Scope) -> Optional[str]: + """Return the upstream-bound ``Authorization`` header value, or None. + + Only returns the ``Authorization`` header when ``x-litellm-api-key`` is + also present. In that case ``Authorization`` is unambiguously the + upstream token the caller wants forwarded to the MCP server. When + ``x-litellm-api-key`` is absent the ``Authorization`` header may itself + be the LiteLLM proxy API key (backward-compat path in + ``MCPRequestHandler.process_mcp_request``), and forwarding it upstream + would leak the proxy key to a third-party MCP server. + """ + authorization = None + has_litellm_key_header = False + for key, value in scope.get("headers", []): + key_lower = key.lower() + if key_lower == b"authorization": + authorization = value.decode("latin-1") + elif key_lower == b"x-litellm-api-key": + has_litellm_key_header = True + if not has_litellm_key_header: + return None + return authorization + + async def _probe_upstream_auth( + url: str, + auth_header: str, + timeout: float = 5.0, + ) -> tuple: + """JSON-RPC initialize-probe the upstream URL to check whether the token is accepted. + + Uses POST so StreamableHTTP MCP servers run the same auth path as a + real client request. Returns (status_code, www_authenticate). + Fails-open with (200, None) on network errors so a transient hiccup + does not block valid requests. + + Uses the public ``AsyncHTTPHandler.post()`` interface and catches + ``httpx.HTTPStatusError`` separately so the 401/403 we want to surface + is not swallowed by the broad fail-open ``except Exception`` below. + """ + client = get_async_httpx_client( + llm_provider=httpxSpecialProvider.MCP, + params={"timeout": timeout}, + ) + probe_payload = { + "jsonrpc": "2.0", + "id": "litellm-mcp-auth-probe", + "method": "initialize", + "params": { + "protocolVersion": MCPSpecVersion.jun_2025.value, + "capabilities": {}, + "clientInfo": { + "name": "litellm-mcp-auth-probe", + "version": "1.0.0", + }, + }, + } + probe_headers = { + "Authorization": auth_header, + "Accept": "application/json, text/event-stream", + } + try: + resp = await client.post( + url=url, + headers=probe_headers, + json=probe_payload, + timeout=timeout, + ) + return resp.status_code, resp.headers.get("www-authenticate") + except httpx.HTTPStatusError as exc: + # AsyncHTTPHandler.post() calls raise_for_status(); a 401/403 from + # upstream lands here. Return its status so the caller can map it + # to the appropriate response. + return exc.response.status_code, exc.response.headers.get( + "www-authenticate" + ) + except Exception as exc: + verbose_logger.debug( + f"_probe_upstream_auth: probe to {url} failed ({exc}), allowing request through" + ) + return 200, None + + async def _check_passthrough_upstream_auth( + scope: Scope, + user_api_key_auth: Optional[UserAPIKeyAuth], + mcp_servers: Optional[List[str]], + client_ip: Optional[str], + ) -> None: + """Probe pass-through upstream servers in parallel before the MCP session starts. + + Only servers the caller's key is already authorized to reach are probed — + the list is derived from _get_allowed_mcp_servers so that a user cannot + trigger an upstream probe against a server their key is not permitted for. + + The MCP SDK commits HTTP 200 headers before invoking handlers, so a 401 + can only be returned before that point. This function raises HTTPException(401) + with a WWW-Authenticate header if any upstream rejects the client token. + Fails-open: network errors are logged and the request is allowed through. + """ + forwarded_auth = _get_forwarded_auth_from_scope(scope) + if not forwarded_auth: + return + + # Use the authorized server set, not the raw user-supplied names, so that + # a caller cannot force a probe to a server their key is not allowed to use. + allowed_servers = await _get_allowed_mcp_servers( + user_api_key_auth=user_api_key_auth, + mcp_servers=mcp_servers, + client_ip=client_ip, + ) + passthrough_servers = [ + srv + for srv in allowed_servers + if srv.extra_headers + and any(h.lower() == "authorization" for h in srv.extra_headers) + # Exclude M2M servers: _prepare_mcp_server_headers skips caller + # Authorization when has_client_credentials is set, so probing + # those with the caller's token would send the wrong credential. + and not srv.has_client_credentials + ] + if not passthrough_servers: + return + + probe_results = await asyncio.gather( + *[ + _probe_upstream_auth(srv.url or "", forwarded_auth) + for srv in passthrough_servers + ] + ) + request = StarletteRequest(scope) + base_url = get_request_base_url(request) + for srv, (probe_status, _) in zip(passthrough_servers, probe_results): + if probe_status == 401: + # Token is missing or expired — direct the client to re-authorize. + authorization_uri = ( + f"Bearer authorization_uri=" + f"{base_url}/.well-known/oauth-authorization-server/{srv.name}" + ) + raise HTTPException( + status_code=401, + detail="Unauthorized", + headers={"WWW-Authenticate": authorization_uri}, + ) + if probe_status == 403: + # Token is valid but the caller lacks permission — do not hint + # at re-authorization (RFC 9110: a fresh token with the same + # scopes would just hit 403 again and loop indefinitely). + raise HTTPException( + status_code=403, + detail="Forbidden", + ) + async def handle_streamable_http_mcp( scope: Scope, receive: Receive, send: Send ) -> None: @@ -2827,6 +2983,13 @@ if MCP_AVAILABLE: user_api_key_auth, active_toolset_id ) + # Pre-flight auth check for pass-through servers. Must run after + # toolset scoping so the probe list is derived from the fully-authorized + # server set, not the raw user-supplied names. + await _check_passthrough_upstream_auth( + scope, user_api_key_auth, mcp_servers, _client_ip + ) + # Inject masked debug headers when client sends x-litellm-mcp-debug: true _debug_headers = MCPDebug.maybe_build_debug_headers( raw_headers=raw_headers, diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 5851d2550b..ecef96351a 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -15026,8 +15026,17 @@ async def _stream_mcp_asgi_response( # If the handler task dies (exception or cancellation) without sending the EOF # sentinel, body_iter() would block forever on body_queue.get(). The callback # below guarantees the queue gets unblocked regardless of how the task ends. + # When this happens before response headers, propagate the original exception + # instead of waiting for the header timeout. def _ensure_eof(task: asyncio.Task) -> None: - if task.cancelled() or task.exception() is not None: + if task.cancelled(): + body_queue.put_nowait(None) + return + + task_exception = task.exception() + if task_exception is not None: + if not headers_ready.done(): + headers_ready.set_exception(task_exception) body_queue.put_nowait(None) handler_task.add_done_callback(_ensure_eof) diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py index a7649502bd..e1eddfc9c7 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py @@ -3256,3 +3256,137 @@ async def test_call_tool_empty_extra_headers_returns_none(): ), "P2 API consistency issue: expected None for empty extra_headers, got: " + str( captured_extra_headers ) + + +# --------------------------------------------------------------------------- +# Pre-flight upstream auth check tests +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_probe_upstream_auth_returns_upstream_status(): + """_probe_upstream_auth forwards the status code from the upstream server.""" + from litellm.proxy._experimental.mcp_server.server import _probe_upstream_auth + + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.headers = {"www-authenticate": 'Bearer realm="test"'} + + mock_client = MagicMock() + mock_client.post = AsyncMock(return_value=mock_response) + + with patch( + "litellm.proxy._experimental.mcp_server.server.get_async_httpx_client", + return_value=mock_client, + ): + status, www_auth = await _probe_upstream_auth( + "http://upstream/mcp", "Bearer some-token" + ) + + assert status == 401 + assert www_auth == 'Bearer realm="test"' + mock_client.post.assert_awaited_once() + _, kwargs = mock_client.post.call_args + assert kwargs["headers"]["Authorization"] == "Bearer some-token" + assert kwargs["json"]["method"] == "initialize" + + +@pytest.mark.asyncio +async def test_probe_upstream_auth_surfaces_httpx_status_error(): + """Probe extracts status + WWW-Authenticate from httpx.HTTPStatusError. + + AsyncHTTPHandler.post() calls raise_for_status() internally, so when the + upstream returns 401/403 the call raises httpx.HTTPStatusError rather than + returning the response. The probe must catch that specifically (before the + fail-open `except Exception`) so the auth check is not silently defeated. + """ + import httpx + + from litellm.proxy._experimental.mcp_server.server import _probe_upstream_auth + + mock_response = MagicMock() + mock_response.status_code = 401 + mock_response.headers = {"www-authenticate": 'Bearer realm="test"'} + request = httpx.Request("POST", "http://upstream/mcp") + error = httpx.HTTPStatusError( + message="401 Unauthorized", request=request, response=mock_response + ) + + mock_client = MagicMock() + mock_client.post = AsyncMock(side_effect=error) + + with patch( + "litellm.proxy._experimental.mcp_server.server.get_async_httpx_client", + return_value=mock_client, + ): + status, www_auth = await _probe_upstream_auth( + "http://upstream/mcp", "Bearer some-token" + ) + + assert status == 401 + assert www_auth == 'Bearer realm="test"' + + +@pytest.mark.asyncio +async def test_probe_upstream_auth_fails_open_on_network_error(): + """_probe_upstream_auth returns (200, None) when the network call fails.""" + from litellm.proxy._experimental.mcp_server.server import _probe_upstream_auth + + mock_client = MagicMock() + mock_client.post = AsyncMock(side_effect=Exception("connection refused")) + + with patch( + "litellm.proxy._experimental.mcp_server.server.get_async_httpx_client", + return_value=mock_client, + ): + status, www_auth = await _probe_upstream_auth( + "http://upstream/mcp", "Bearer some-token" + ) + + assert status == 200 + assert www_auth is None + + +def test_get_forwarded_auth_from_scope_extracts_header(): + """Returns Authorization value when x-litellm-api-key is also present.""" + from litellm.proxy._experimental.mcp_server.server import ( + _get_forwarded_auth_from_scope, + ) + + scope = { + "headers": [ + (b"content-type", b"application/json"), + (b"x-litellm-api-key", b"sk-litellm-proxy-key"), + (b"authorization", b"Bearer my-token"), + ] + } + assert _get_forwarded_auth_from_scope(scope) == "Bearer my-token" + + +def test_get_forwarded_auth_from_scope_returns_none_when_missing(): + from litellm.proxy._experimental.mcp_server.server import ( + _get_forwarded_auth_from_scope, + ) + + assert _get_forwarded_auth_from_scope({"headers": []}) is None + + +def test_get_forwarded_auth_from_scope_skips_when_no_litellm_key_header(): + """Skip when ``x-litellm-api-key`` is absent. + + Without ``x-litellm-api-key``, the ``Authorization`` header may itself be + the LiteLLM proxy API key (backward-compat). Forwarding it upstream would + leak the proxy key, so the helper must return None and the probe must + not fire. + """ + from litellm.proxy._experimental.mcp_server.server import ( + _get_forwarded_auth_from_scope, + ) + + scope = { + "headers": [ + (b"content-type", b"application/json"), + (b"authorization", b"Bearer ambiguous-token"), + ] + } + assert _get_forwarded_auth_from_scope(scope) is None diff --git a/tests/test_litellm/proxy/test_mcp_asgi_response.py b/tests/test_litellm/proxy/test_mcp_asgi_response.py new file mode 100644 index 0000000000..d030f65af4 --- /dev/null +++ b/tests/test_litellm/proxy/test_mcp_asgi_response.py @@ -0,0 +1,36 @@ +import asyncio + +import pytest +from fastapi import HTTPException + +from litellm.proxy.proxy_server import _stream_mcp_asgi_response + + +@pytest.mark.asyncio +async def test_stream_mcp_asgi_response_propagates_pre_header_http_exception(): + async def handle_fn(_scope, _receive, _send): + raise HTTPException( + status_code=401, + detail="Unauthorized", + headers={ + "WWW-Authenticate": "Bearer authorization_uri=https://example.test/auth" + }, + ) + + async def receive(): + return {"type": "http.request", "body": b"", "more_body": False} + + with pytest.raises(HTTPException) as exc_info: + await asyncio.wait_for( + _stream_mcp_asgi_response( + handle_fn, + {"type": "http", "method": "POST", "path": "/mcp", "headers": []}, + receive, + ), + timeout=1.0, + ) + + assert exc_info.value.status_code == 401 + assert exc_info.value.headers == { + "WWW-Authenticate": "Bearer authorization_uri=https://example.test/auth" + }