def _get_forwarded_auth_from_scope(scope: Scope) -> Optional[str]:
    """Return the upstream-bound ``Authorization`` header value, or None.

    The ``Authorization`` value is only returned when a non-blank
    ``x-litellm-api-key`` header is also present. In that case
    ``Authorization`` is unambiguously the upstream token the caller wants
    forwarded to the MCP server. Without ``x-litellm-api-key`` the
    ``Authorization`` header may itself be the LiteLLM proxy API key
    (backward-compat path in ``MCPRequestHandler.process_mcp_request``), and
    forwarding it upstream would leak the proxy key to a third-party MCP
    server.

    Args:
        scope: ASGI scope. ``scope["headers"]`` is read as an iterable of
            ``(name, value)`` byte pairs; a missing entry means no headers.

    Returns:
        The ``Authorization`` value decoded as latin-1, or ``None`` when
        ``x-litellm-api-key`` is missing or blank, or when no
        ``Authorization`` header was sent.
    """
    authorization: Optional[str] = None
    has_litellm_key_header = False
    for key, value in scope.get("headers", []):
        key_lower = key.lower()
        if key_lower == b"authorization":
            authorization = value.decode("latin-1")
        elif key_lower == b"x-litellm-api-key" and value.strip():
            # A blank value carries no proxy key, so it gives no evidence that
            # ``Authorization`` is an upstream token rather than the proxy key.
            # How the auth handler treats a blank header is not visible here;
            # declining to forward is the safe side either way.
            has_litellm_key_header = True
    if not has_litellm_key_header:
        return None
    return authorization


async def _probe_upstream_auth(
    url: str,
    auth_header: str,
    timeout: float = 5.0,
) -> tuple:
    """Send a JSON-RPC ``initialize`` probe to check whether the token is accepted.

    Uses POST so StreamableHTTP MCP servers run the same auth path as a
    real client request. Fails open with ``(200, None)`` on network errors so
    a transient hiccup does not block valid requests.

    Uses the public ``AsyncHTTPHandler.post()`` interface and catches
    ``httpx.HTTPStatusError`` separately so the 401/403 we want to surface
    is not swallowed by the broad fail-open ``except Exception`` below.

    Args:
        url: Upstream MCP server URL. An empty URL is not probed.
        auth_header: Value sent as the ``Authorization`` request header.
        timeout: Timeout passed both to the HTTP client and to the request.

    Returns:
        A ``(status_code, www_authenticate)`` tuple, where ``www_authenticate``
        is the upstream ``WWW-Authenticate`` header value or ``None``.
    """
    if not url:
        # Nothing to contact. A request to "" could only fail and fall into
        # the fail-open branch, so return that result without the round trip.
        return 200, None

    # Named ``http_client`` so the ``client`` imported from ``litellm.utils``
    # at module level is not shadowed inside this function.
    http_client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.MCP,
        params={"timeout": timeout},
    )
    probe_payload = {
        "jsonrpc": "2.0",
        "id": "litellm-mcp-auth-probe",
        "method": "initialize",
        "params": {
            "protocolVersion": MCPSpecVersion.jun_2025.value,
            "capabilities": {},
            "clientInfo": {
                "name": "litellm-mcp-auth-probe",
                "version": "1.0.0",
            },
        },
    }
    probe_headers = {
        "Authorization": auth_header,
        "Accept": "application/json, text/event-stream",
    }
    try:
        resp = await http_client.post(
            url=url,
            headers=probe_headers,
            json=probe_payload,
            timeout=timeout,
        )
        return resp.status_code, resp.headers.get("www-authenticate")
    except httpx.HTTPStatusError as exc:
        # AsyncHTTPHandler.post() calls raise_for_status(); a 401/403 from
        # upstream lands here. Return its status so the caller can map it
        # to the appropriate response.
        return exc.response.status_code, exc.response.headers.get(
            "www-authenticate"
        )
    except Exception as exc:
        # Deliberate fail-open: any other failure lets the request through.
        verbose_logger.debug(
            f"_probe_upstream_auth: probe to {url} failed ({exc}), allowing request through"
        )
        return 200, None


async def _check_passthrough_upstream_auth(
    scope: Scope,
    user_api_key_auth: Optional[UserAPIKeyAuth],
    mcp_servers: Optional[List[str]],
    client_ip: Optional[str],
) -> None:
    """Probe pass-through upstream servers in parallel before the MCP session starts.

    Only servers the caller's key is already authorized to reach are probed.
    The list is derived from ``_get_allowed_mcp_servers`` so that a user cannot
    trigger an upstream probe against a server their key is not permitted for.

    The MCP SDK commits HTTP 200 headers before invoking handlers, so a 401
    can only be returned before that point. Network errors fail open: they are
    logged by the probe and the request is allowed through.

    Args:
        scope: ASGI scope of the incoming request.
        user_api_key_auth: Caller's auth object, passed to
            ``_get_allowed_mcp_servers``.
        mcp_servers: Server names requested by the caller, passed to
            ``_get_allowed_mcp_servers``.
        client_ip: Caller IP, passed to ``_get_allowed_mcp_servers``.

    Raises:
        HTTPException: 401 with a ``WWW-Authenticate`` header when an upstream
            rejects the token, or 403 when an upstream forbids the caller.
            Results are checked in server order, so the first server that
            returned 401 or 403 decides which is raised.
    """
    forwarded_auth = _get_forwarded_auth_from_scope(scope)
    if not forwarded_auth:
        return

    # Use the authorized server set, not the raw user-supplied names, so that
    # a caller cannot force a probe to a server their key is not allowed to use.
    allowed_servers = await _get_allowed_mcp_servers(
        user_api_key_auth=user_api_key_auth,
        mcp_servers=mcp_servers,
        client_ip=client_ip,
    )
    passthrough_servers = [
        srv
        for srv in allowed_servers
        # A server without a URL has nothing to probe.
        if srv.url
        and srv.extra_headers
        and any(h.lower() == "authorization" for h in srv.extra_headers)
        # Exclude M2M servers: _prepare_mcp_server_headers skips caller
        # Authorization when has_client_credentials is set, so probing
        # those with the caller's token would send the wrong credential.
        and not srv.has_client_credentials
    ]
    if not passthrough_servers:
        return

    probe_results = await asyncio.gather(
        *[
            _probe_upstream_auth(srv.url or "", forwarded_auth)
            for srv in passthrough_servers
        ]
    )
    for srv, (probe_status, _) in zip(passthrough_servers, probe_results):
        if probe_status == 401:
            # Token is missing or expired: direct the client to re-authorize.
            # The base URL is only needed here, so it is resolved lazily
            # instead of on every request that passes the probe.
            base_url = get_request_base_url(StarletteRequest(scope))
            # NOTE(review): the URL is emitted as an unquoted auth-param
            # value; RFC 9110 expects a quoted-string for values containing
            # ":" or "/". Left as-is to keep the header this code already
            # emits; confirm against the clients that parse it.
            authorization_uri = (
                f"Bearer authorization_uri="
                f"{base_url}/.well-known/oauth-authorization-server/{srv.name}"
            )
            raise HTTPException(
                status_code=401,
                detail="Unauthorized",
                headers={"WWW-Authenticate": authorization_uri},
            )
        if probe_status == 403:
            # Token is valid but the caller lacks permission. Do not hint
            # at re-authorization (RFC 9110: a fresh token with the same
            # scopes would just hit 403 again and loop indefinitely).
            raise HTTPException(
                status_code=403,
                detail="Forbidden",
            )
def _ensure_eof(task: asyncio.Task) -> None:
    """Done-callback that unblocks the response stream when the handler task dies.

    If the handler task is cancelled or fails without sending the EOF
    sentinel, a reader of ``body_queue`` would block forever; pushing
    ``None`` guarantees it is released. When the task failed before
    ``headers_ready`` was resolved, the exception is set on it so the original
    error propagates instead of waiting for the header timeout. A task that
    finished normally is left untouched.
    """
    if not task.cancelled():
        failure = task.exception()
        if failure is None:
            # Clean finish: the handler sent its own EOF.
            return
        if not headers_ready.done():
            headers_ready.set_exception(failure)
    # Reached for both cancellation and failure.
    body_queue.put_nowait(None)
@pytest.mark.asyncio
async def test_probe_upstream_auth_returns_upstream_status():
    """The probe hands back the upstream status code and WWW-Authenticate value."""
    from litellm.proxy._experimental.mcp_server.server import _probe_upstream_auth

    upstream_reply = MagicMock()
    upstream_reply.status_code = 401
    upstream_reply.headers = {"www-authenticate": 'Bearer realm="test"'}

    http_handler = MagicMock()
    http_handler.post = AsyncMock(return_value=upstream_reply)

    client_factory = patch(
        "litellm.proxy._experimental.mcp_server.server.get_async_httpx_client",
        return_value=http_handler,
    )
    with client_factory:
        status, www_auth = await _probe_upstream_auth(
            "http://upstream/mcp", "Bearer some-token"
        )

    assert (status, www_auth) == (401, 'Bearer realm="test"')
    http_handler.post.assert_awaited_once()
    # The probe must forward the caller's token and issue an initialize call.
    sent_kwargs = http_handler.post.call_args[1]
    assert sent_kwargs["headers"]["Authorization"] == "Bearer some-token"
    assert sent_kwargs["json"]["method"] == "initialize"


@pytest.mark.asyncio
async def test_probe_upstream_auth_surfaces_httpx_status_error():
    """Status and WWW-Authenticate are read off a raised httpx.HTTPStatusError.

    AsyncHTTPHandler.post() calls raise_for_status() internally, so an
    upstream 401/403 arrives as httpx.HTTPStatusError rather than as a
    returned response. The probe has to handle that case ahead of its
    fail-open `except Exception`, otherwise the auth check is silently
    defeated.
    """
    import httpx

    from litellm.proxy._experimental.mcp_server.server import _probe_upstream_auth

    rejected_reply = MagicMock()
    rejected_reply.status_code = 401
    rejected_reply.headers = {"www-authenticate": 'Bearer realm="test"'}
    status_error = httpx.HTTPStatusError(
        message="401 Unauthorized",
        request=httpx.Request("POST", "http://upstream/mcp"),
        response=rejected_reply,
    )

    http_handler = MagicMock()
    http_handler.post = AsyncMock(side_effect=status_error)

    client_factory = patch(
        "litellm.proxy._experimental.mcp_server.server.get_async_httpx_client",
        return_value=http_handler,
    )
    with client_factory:
        status, www_auth = await _probe_upstream_auth(
            "http://upstream/mcp", "Bearer some-token"
        )

    assert (status, www_auth) == (401, 'Bearer realm="test"')


@pytest.mark.asyncio
async def test_probe_upstream_auth_fails_open_on_network_error():
    """A failing network call makes the probe report (200, None)."""
    from litellm.proxy._experimental.mcp_server.server import _probe_upstream_auth

    http_handler = MagicMock()
    http_handler.post = AsyncMock(side_effect=Exception("connection refused"))

    client_factory = patch(
        "litellm.proxy._experimental.mcp_server.server.get_async_httpx_client",
        return_value=http_handler,
    )
    with client_factory:
        status, www_auth = await _probe_upstream_auth(
            "http://upstream/mcp", "Bearer some-token"
        )

    assert status == 200
    assert www_auth is None


def test_get_forwarded_auth_from_scope_extracts_header():
    """Authorization is returned when x-litellm-api-key accompanies it."""
    from litellm.proxy._experimental.mcp_server.server import (
        _get_forwarded_auth_from_scope,
    )

    request_headers = [
        (b"content-type", b"application/json"),
        (b"x-litellm-api-key", b"sk-litellm-proxy-key"),
        (b"authorization", b"Bearer my-token"),
    ]
    forwarded = _get_forwarded_auth_from_scope({"headers": request_headers})
    assert forwarded == "Bearer my-token"
def test_get_forwarded_auth_from_scope_returns_none_when_missing():
    """A scope with no headers yields None."""
    from litellm.proxy._experimental.mcp_server.server import (
        _get_forwarded_auth_from_scope,
    )

    empty_scope = {"headers": []}
    assert _get_forwarded_auth_from_scope(empty_scope) is None


def test_get_forwarded_auth_from_scope_skips_when_no_litellm_key_header():
    """Nothing is forwarded when ``x-litellm-api-key`` is absent.

    In that situation the ``Authorization`` header may itself be the LiteLLM
    proxy API key (backward-compat), and sending it upstream would leak the
    proxy key. The helper therefore has to return None so that no probe is
    fired.
    """
    from litellm.proxy._experimental.mcp_server.server import (
        _get_forwarded_auth_from_scope,
    )

    request_headers = [
        (b"content-type", b"application/json"),
        (b"authorization", b"Bearer ambiguous-token"),
    ]
    forwarded = _get_forwarded_auth_from_scope({"headers": request_headers})
    assert forwarded is None


# --- tests/test_litellm/proxy/test_mcp_asgi_response.py (new file) ---
import asyncio

import pytest
from fastapi import HTTPException

from litellm.proxy.proxy_server import _stream_mcp_asgi_response


@pytest.mark.asyncio
async def test_stream_mcp_asgi_response_propagates_pre_header_http_exception():
    """An HTTPException raised before any response headers reaches the caller."""
    challenge = {
        "WWW-Authenticate": "Bearer authorization_uri=https://example.test/auth"
    }

    async def handle_fn(_scope, _receive, _send):
        # Fail before sending anything, as a pre-flight auth rejection would.
        raise HTTPException(
            status_code=401,
            detail="Unauthorized",
            headers=dict(challenge),
        )

    async def receive():
        return {"type": "http.request", "body": b"", "more_body": False}

    asgi_scope = {"type": "http", "method": "POST", "path": "/mcp", "headers": []}

    with pytest.raises(HTTPException) as exc_info:
        # The outer timeout keeps a regression (waiting on the header
        # timeout) from hanging the suite.
        await asyncio.wait_for(
            _stream_mcp_asgi_response(handle_fn, asgi_scope, receive),
            timeout=1.0,
        )

    raised = exc_info.value
    assert raised.status_code == 401
    assert raised.headers == challenge