diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index 8d4bdffb2d..b4d0f82d7b 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -315,6 +315,11 @@ model LiteLLM_MCPServerTable { is_byok Boolean @default(false) byok_description String[] @default([]) byok_api_key_help_url String? + approval_status String @default("approved") + submitted_by String? + submitted_at DateTime? + reviewed_at DateTime? + review_notes String? } // Per-user BYOK credentials for MCP servers diff --git a/litellm/proxy/_experimental/mcp_server/db.py b/litellm/proxy/_experimental/mcp_server/db.py index 477229dc70..119e8171a1 100644 --- a/litellm/proxy/_experimental/mcp_server/db.py +++ b/litellm/proxy/_experimental/mcp_server/db.py @@ -628,6 +628,24 @@ async def store_user_oauth_credential( ) +def is_oauth_credential_expired(cred: Dict[str, Any]) -> bool: + """Return True if the OAuth2 credential's access_token has expired. + + Checks the ``expires_at`` ISO-format string stored in the credential payload. + Returns False when ``expires_at`` is absent or unparseable (treat as non-expired). + """ + expires_at = cred.get("expires_at") + if not expires_at: + return False + try: + exp_dt = datetime.fromisoformat(expires_at) + if exp_dt.tzinfo is None: + exp_dt = exp_dt.replace(tzinfo=timezone.utc) + return datetime.now(timezone.utc) > exp_dt + except (ValueError, TypeError): + return False + + async def get_user_oauth_credential( prisma_client: PrismaClient, user_id: str, diff --git a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py index 0c62f5d5da..f10263ba57 100644 --- a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py @@ -1,6 +1,6 @@ import importlib -from datetime import datetime -from typing import Any, Awaitable, Callable, Dict, List, Optional, Union +from datetime import datetime, timezone +from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Union from fastapi import APIRouter, Depends, HTTPException, Query, Request @@ -69,18 +69,36 @@ if MCP_AVAILABLE: return server_auth return mcp_auth_header + def _get_oauth2_server_ids(allowed_server_ids: List[str]) -> Set[str]: + """Return the subset of *allowed_server_ids* whose servers use OAuth2 auth. + + Used as a cheap pre-flight check to skip bulk credential fetching when no + OAuth2 servers are involved in the current request. + """ + return { + sid + for sid in allowed_server_ids + if getattr( + global_mcp_server_manager.get_mcp_server_by_id(sid), "auth_type", None + ) + == MCPAuth.oauth2 + } + async def _get_user_oauth_extra_headers( server, user_api_key_dict: UserAPIKeyAuth, + prefetched_creds: Optional[Dict[str, Dict[str, Any]]] = None, ) -> Optional[Dict[str, str]]: """ For OAuth2 servers, look up the user's stored access token and return it as extra_headers {"Authorization": "Bearer "} so that it reaches the MCP server the same way the admin "Add MCP / Authorize and Fetch" flow does. Returns None for non-OAuth2 servers or when no credential is stored. - """ - from litellm.types.mcp import MCPAuth + Args: + prefetched_creds: Optional dict keyed by server_id with credential payloads. + When provided, avoids a per-server DB round-trip. + """ if getattr(server, "auth_type", None) != MCPAuth.oauth2: return None user_id = getattr(user_api_key_dict, "user_id", None) @@ -90,18 +108,60 @@ if MCP_AVAILABLE: try: from litellm.proxy._experimental.mcp_server.db import ( get_user_oauth_credential, + is_oauth_credential_expired, + ) + + if prefetched_creds is not None: + cred = prefetched_creds.get(server_id) + else: + from litellm.proxy.utils import get_prisma_client_or_throw + + prisma_client = get_prisma_client_or_throw( + "Database not connected. Connect a database to use OAuth2 MCP tools." + ) + cred = await get_user_oauth_credential(prisma_client, user_id, server_id) + if cred and cred.get("access_token"): + if is_oauth_credential_expired(cred): + verbose_logger.debug( + f"_get_user_oauth_extra_headers: token expired for " + f"user={user_id} server={server_id}" + ) + return None + return {"Authorization": f"Bearer {cred['access_token']}"} + except Exception as e: + verbose_logger.warning( + f"_get_user_oauth_extra_headers: failed to retrieve credential for " + f"user={user_id} server={server_id}: {e}" + ) + return None + + async def _prefetch_user_oauth_creds( + user_api_key_dict: UserAPIKeyAuth, + ) -> Dict[str, Dict[str, Any]]: + """Fetch all OAuth2 credentials for the user in a single DB query. + + Returns a dict keyed by server_id. Used to avoid N+1 DB queries when + iterating over multiple OAuth2 MCP servers. + """ + user_id = getattr(user_api_key_dict, "user_id", None) + if not user_id: + return {} + try: + from litellm.proxy._experimental.mcp_server.db import ( + list_user_oauth_credentials, ) from litellm.proxy.utils import get_prisma_client_or_throw prisma_client = get_prisma_client_or_throw( "Database not connected. Connect a database to use OAuth2 MCP tools." ) - cred = await get_user_oauth_credential(prisma_client, user_id, server_id) - if cred and cred.get("access_token"): - return {"Authorization": f"Bearer {cred['access_token']}"} - except Exception: - verbose_logger.debug("Failed to fetch OAuth credential", exc_info=True) - return None + creds = await list_user_oauth_credentials(prisma_client, user_id) + return {c["server_id"]: c for c in creds if "server_id" in c} + except Exception as e: + verbose_logger.warning( + f"_prefetch_user_oauth_creds: failed to prefetch for user={user_id}: {e}" + ) + return {} async def _get_bulk_user_oauth_headers( user_api_key_dict: UserAPIKeyAuth, @@ -364,13 +424,21 @@ if MCP_AVAILABLE: if server_id: # Resolve a server name to its UUID if needed (MCPConnectPicker passes # server_name strings, but allowed_server_ids_set contains UUIDs). + # _name_resolved is kept so the second check can reuse it for accurate + # IP-filter error reporting if the resolved UUID is not in allowed_server_ids. + _name_resolved = None if server_id not in allowed_server_ids: - _resolved = global_mcp_server_manager.get_mcp_server_by_name(server_id) - if _resolved is not None and _resolved.server_id in set(allowed_server_ids): - server_id = _resolved.server_id + _name_resolved = global_mcp_server_manager.get_mcp_server_by_name(server_id) + if _name_resolved is not None and _name_resolved.server_id in set(allowed_server_ids): + server_id = _name_resolved.server_id if server_id not in allowed_server_ids: - _server = global_mcp_server_manager.get_mcp_server_by_id(server_id) + # Try UUID lookup first; fall back to the name-resolved server so that + # IP-filter reporting works correctly even when server_id is a name string. + _server = ( + global_mcp_server_manager.get_mcp_server_by_id(server_id) + or _name_resolved + ) if ( _server is not None and _rest_client_ip is not None @@ -451,10 +519,15 @@ if MCP_AVAILABLE: }, ) - # Query all servers the user has access to. - # Bulk-fetch OAuth creds once so each per-server call below can - # do an O(1) dict lookup instead of N individual DB queries. - bulk_oauth_headers = await _get_bulk_user_oauth_headers(user_api_key_dict) + # Pre-fetch OAuth credentials only when at least one allowed server uses OAuth2, + # to avoid an unnecessary DB round-trip on requests with no OAuth2 MCP servers. + prefetched_oauth_creds = ( + await _prefetch_user_oauth_creds(user_api_key_dict) + if _get_oauth2_server_ids(allowed_server_ids) + else {} + ) + + # Query all servers the user has access to errors = [] for allowed_server_id in allowed_server_ids: server = global_mcp_server_manager.get_mcp_server_by_id( @@ -466,7 +539,9 @@ if MCP_AVAILABLE: server_auth_header = _get_server_auth_header( server, mcp_server_auth_headers, mcp_auth_header ) - user_oauth_extra_headers = bulk_oauth_headers.get(server.server_id) + user_oauth_extra_headers = await _get_user_oauth_extra_headers( + server, user_api_key_dict, prefetched_creds=prefetched_oauth_creds + ) try: tools_result = await _get_tools_for_single_server( diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 99f6a5234a..d6d44042ff 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -8,7 +8,7 @@ import contextlib import time import traceback import uuid -from datetime import datetime +from datetime import datetime, timezone from typing import ( Any, AsyncIterator, @@ -871,6 +871,84 @@ if MCP_AVAILABLE: return allowed_mcp_servers + async def _get_user_oauth_extra_headers_from_db( + server: MCPServer, + user_api_key_auth: Optional[UserAPIKeyAuth], + prefetched_creds: Optional[Dict[str, Dict[str, Any]]] = None, + ) -> Optional[Dict[str, str]]: + """Look up stored OAuth2 token for (user, server) from DB and return as extra_headers dict. + + Args: + prefetched_creds: Optional dict keyed by server_id with credential payloads. + When provided, avoids a per-server DB round-trip. + """ + if server.auth_type != MCPAuth.oauth2: + return None + if user_api_key_auth is None: + return None + user_id = getattr(user_api_key_auth, "user_id", None) + server_id = getattr(server, "server_id", None) + if not user_id or not server_id: + return None + try: + from litellm.proxy._experimental.mcp_server.db import ( # noqa: PLC0415 + get_user_oauth_credential, + is_oauth_credential_expired, + ) + + if prefetched_creds is not None: + cred = prefetched_creds.get(server_id) + else: + from litellm.proxy.utils import ( # noqa: PLC0415 + get_prisma_client_or_throw, + ) + + prisma_client = get_prisma_client_or_throw( + "Database not connected. Connect a database to use OAuth2 MCP tools." + ) + cred = await get_user_oauth_credential(prisma_client, user_id, server_id) + if cred and cred.get("access_token"): + if is_oauth_credential_expired(cred): + verbose_logger.debug( + f"_get_user_oauth_extra_headers_from_db: token expired for " + f"user={user_id} server={server_id}" + ) + return None + return {"Authorization": f"Bearer {cred['access_token']}"} + except Exception as e: + verbose_logger.warning( + f"_get_user_oauth_extra_headers_from_db: failed to retrieve credential for " + f"user={user_id} server={server_id}: {e}" + ) + return None + + async def _prefetch_oauth_creds_for_user( + user_api_key_auth: Optional[UserAPIKeyAuth], + ) -> Dict[str, Dict[str, Any]]: + """Fetch all OAuth2 credentials for the user in one DB query. + + Returns a dict keyed by server_id to avoid N+1 queries in asyncio.gather loops. + """ + user_id = getattr(user_api_key_auth, "user_id", None) if user_api_key_auth else None + if not user_id: + return {} + try: + from litellm.proxy._experimental.mcp_server.db import ( # noqa: PLC0415 + list_user_oauth_credentials, + ) + from litellm.proxy.utils import get_prisma_client_or_throw # noqa: PLC0415 + + prisma_client = get_prisma_client_or_throw( + "Database not connected. Connect a database to use OAuth2 MCP tools." + ) + creds = await list_user_oauth_credentials(prisma_client, user_id) + return {c["server_id"]: c for c in creds if "server_id" in c} + except Exception as e: + verbose_logger.warning( + f"_prefetch_oauth_creds_for_user: failed to prefetch for user={user_id}: {e}" + ) + return {} + def _prepare_mcp_server_headers( server: MCPServer, mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]], @@ -1015,6 +1093,18 @@ if MCP_AVAILABLE: mcp_servers=mcp_servers, ) + # Pre-fetch OAuth credentials only when at least one server uses OAuth2, + # to avoid an unnecessary DB round-trip on requests with no OAuth2 MCP servers. + _has_oauth2_server = any( + getattr(s, "auth_type", None) == MCPAuth.oauth2 + for s in allowed_mcp_servers + ) + _prefetched_oauth_creds = ( + await _prefetch_oauth_creds_for_user(user_api_key_auth) + if _has_oauth2_server + else {} + ) + async def _fetch_and_filter_server_tools( server: MCPServer, ) -> List[MCPTool]: @@ -1030,6 +1120,12 @@ if MCP_AVAILABLE: raw_headers=raw_headers, ) + # If no OAuth2 token came from request headers, fall back to pre-fetched creds + if extra_headers is None and server.auth_type == MCPAuth.oauth2: + extra_headers = await _get_user_oauth_extra_headers_from_db( + server, user_api_key_auth, prefetched_creds=_prefetched_oauth_creds + ) + try: tools = await global_mcp_server_manager._get_tools_from_server( server=server, diff --git a/litellm/proxy/management_endpoints/mcp_management_endpoints.py b/litellm/proxy/management_endpoints/mcp_management_endpoints.py index 2742f8c8ef..a9ff61dab5 100644 --- a/litellm/proxy/management_endpoints/mcp_management_endpoints.py +++ b/litellm/proxy/management_endpoints/mcp_management_endpoints.py @@ -1306,10 +1306,21 @@ if MCP_AVAILABLE: def _get_cached_temporary_mcp_server_or_404(server_id: str) -> MCPServer: server = get_cached_temporary_mcp_server(server_id) + if server is None: + # Fall back to real DB/config server (e.g. for the user-side OAuth flow + # which calls these endpoints with a real server_id, not a temp session id). + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + global_mcp_server_manager, + ) + + server = ( + global_mcp_server_manager.get_mcp_server_by_id(server_id) + or global_mcp_server_manager.get_mcp_server_by_name(server_id) + ) if server is None: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, - detail={"error": f"Temporary MCP server {server_id} not found"}, + detail={"error": f"MCP server {server_id} not found"}, ) return server @@ -1320,8 +1331,8 @@ if MCP_AVAILABLE: async def mcp_authorize( request: Request, server_id: str, - client_id: str, - redirect_uri: str, + client_id: Optional[str] = None, + redirect_uri: str = Query(...), state: str = "", code_challenge: Optional[str] = None, code_challenge_method: Optional[str] = None, @@ -1329,10 +1340,23 @@ if MCP_AVAILABLE: scope: Optional[str] = None, ): mcp_server = _get_cached_temporary_mcp_server_or_404(server_id) + # Use the server's stored client_id when the caller doesn't supply one + resolved_client_id = mcp_server.client_id or client_id or "" + if not resolved_client_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "error": "missing_client_id", + "message": ( + "No client_id available for this MCP server. " + "Either configure the server with a client_id or supply one in the request." + ), + }, + ) return await authorize_with_server( request=request, mcp_server=mcp_server, - client_id=client_id, + client_id=resolved_client_id, redirect_uri=redirect_uri, state=state, code_challenge=code_challenge, @@ -1351,18 +1375,30 @@ if MCP_AVAILABLE: grant_type: str = Form(...), code: Optional[str] = Form(None), redirect_uri: Optional[str] = Form(None), - client_id: str = Form(...), + client_id: Optional[str] = Form(None), client_secret: Optional[str] = Form(None), code_verifier: Optional[str] = Form(None), ): mcp_server = _get_cached_temporary_mcp_server_or_404(server_id) + resolved_client_id = mcp_server.client_id or client_id or "" + if not resolved_client_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={ + "error": "missing_client_id", + "message": ( + "No client_id available for this MCP server. " + "Either configure the server with a client_id or supply one in the request." + ), + }, + ) return await exchange_token_with_server( request=request, mcp_server=mcp_server, grant_type=grant_type, code=code, redirect_uri=redirect_uri, - client_id=client_id, + client_id=resolved_client_id, client_secret=client_secret, code_verifier=code_verifier, ) diff --git a/litellm/responses/mcp/litellm_proxy_mcp_handler.py b/litellm/responses/mcp/litellm_proxy_mcp_handler.py index 5776ef95ac..7a3934ffda 100644 --- a/litellm/responses/mcp/litellm_proxy_mcp_handler.py +++ b/litellm/responses/mcp/litellm_proxy_mcp_handler.py @@ -1,3 +1,4 @@ +import re import traceback from datetime import datetime from typing import ( @@ -43,6 +44,13 @@ ToolParam = Any LITELLM_PROXY_MCP_SERVER_URL = "litellm_proxy" LITELLM_PROXY_MCP_SERVER_URL_PREFIX = f"{LITELLM_PROXY_MCP_SERVER_URL}/mcp/" +# Matches any URL whose path ends with /mcp/ — covers both root-path +# (http://host:port/mcp/name) and sub-path (http://host/base/mcp/name) proxy deployments. +# A false-positive match (e.g. an external URL that happens to end with /mcp/) results +# in a "server not found" error from the internal gateway, not a silent failure or data leak, +# so this broad pattern is intentional and preferred over anchoring to localhost only. +_PROXY_MCP_PATH_RE = re.compile(r"^https?://.+/mcp/([^/]+)$") + class LiteLLM_Proxy_MCP_Handler: """ @@ -54,7 +62,8 @@ class LiteLLM_Proxy_MCP_Handler: @staticmethod def _should_use_litellm_mcp_gateway(tools: Optional[Iterable[ToolParam]]) -> bool: """ - Returns True if the user passed a MCP tool with server_url="litellm_proxy" + Returns True if any MCP tool should be handled via the litellm proxy MCP gateway. + This includes tools with server_url="litellm_proxy" as well as URLs ending in /mcp/. """ if tools: for tool in tools: @@ -64,6 +73,10 @@ class LiteLLM_Proxy_MCP_Handler: LITELLM_PROXY_MCP_SERVER_URL ): return True + if isinstance(server_url, str) and _PROXY_MCP_PATH_RE.match( + server_url + ): + return True return False @staticmethod @@ -87,6 +100,18 @@ class LiteLLM_Proxy_MCP_Handler: LITELLM_PROXY_MCP_SERVER_URL ): mcp_tools_with_litellm_proxy.append(tool) + elif isinstance(server_url, str): + # Also intercept URLs like http://localhost:4000/mcp/atlassian_test + # by rewriting them to the internal litellm_proxy format. + m = _PROXY_MCP_PATH_RE.match(server_url) + if m: + rewritten = { + **tool, + "server_url": f"{LITELLM_PROXY_MCP_SERVER_URL_PREFIX}{m.group(1)}", + } + mcp_tools_with_litellm_proxy.append(rewritten) + else: + other_tools.append(tool) else: other_tools.append(tool) else: diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py index de2ec13b4a..314eed9598 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py @@ -2093,3 +2093,83 @@ async def test_get_tools_from_mcp_servers_logs_list_tools_to_spendlogs_when_enab assert spend_meta["tool_count_total"] == 1 assert spend_meta["allowed_server_count"] == 1 assert spend_meta["per_server_tool_counts"]["server_a"] == 1 + + +@pytest.mark.asyncio +async def test_get_tools_from_mcp_servers_injects_stored_oauth2_token(): + """ + When _get_tools_from_mcp_servers is called for an OAuth2 MCP server and no + oauth2_headers are provided in the request (e.g. a /responses API call from a + chat UI), the per-user stored token must be fetched from the DB and passed as + extra_headers to _get_tools_from_server. + + The implementation pre-fetches all user credentials in a single bulk query + (_prefetch_oauth_creds_for_user) to avoid N+1 queries in the gather loop. + + This covers the bug where OAuth2 MCP tools were always empty in the /responses + API because the stored credential was never injected. + """ + try: + from litellm.proxy._experimental.mcp_server.server import ( + _get_tools_from_mcp_servers, + ) + from litellm.proxy._types import UserAPIKeyAuth + from litellm.types.mcp import MCPAuth + except ImportError: + pytest.skip("MCP server not available") + + STORED_TOKEN = "atlassian-oauth-access-token-xyz" + SERVER_ID = "srv-oauth2-id" + USER_ID = "user-123" + + user_auth = UserAPIKeyAuth(api_key="test-key", user_id=USER_ID) + + oauth2_server = MagicMock(name="atlassian_server") + oauth2_server.name = "atlassian_test" + oauth2_server.alias = "atlassian_test" + oauth2_server.server_name = "atlassian_test" + oauth2_server.server_id = SERVER_ID + oauth2_server.auth_type = MCPAuth.oauth2 + oauth2_server.extra_headers = None + + # Simulate the DB returning a valid credential for this user+server + prefetched_creds = {SERVER_ID: {"access_token": STORED_TOKEN, "server_id": SERVER_ID}} + + tool_1 = MagicMock() + tool_1.name = "atlassian_test-search" + + with patch( + "litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers", + new=AsyncMock(return_value=[oauth2_server]), + ), patch( + # Patch the bulk prefetch so no real DB connection is needed + "litellm.proxy._experimental.mcp_server.server._prefetch_oauth_creds_for_user", + new=AsyncMock(return_value=prefetched_creds), + ) as mock_prefetch, patch( + "litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager", + ) as mock_manager, patch( + "litellm.proxy._experimental.mcp_server.server.filter_tools_by_allowed_tools", + side_effect=lambda tools, _server: tools, + ), patch( + "litellm.proxy._experimental.mcp_server.server.filter_tools_by_key_team_permissions", + new=AsyncMock(side_effect=lambda tools, **_: tools), + ): + mock_manager._get_tools_from_server = AsyncMock(return_value=[tool_1]) + + tools = await _get_tools_from_mcp_servers( + user_api_key_auth=user_auth, + mcp_auth_header=None, + mcp_servers=["atlassian_test"], + mcp_server_auth_headers=None, + oauth2_headers=None, # No token from request — must fall back to DB + ) + + # Bulk credential prefetch was called once (not once per server) + mock_prefetch.assert_awaited_once_with(user_auth) + + # The stored token was forwarded to the MCP transport layer as extra_headers + mock_manager._get_tools_from_server.assert_awaited_once() + call_kwargs = mock_manager._get_tools_from_server.await_args.kwargs + assert call_kwargs["extra_headers"] == {"Authorization": f"Bearer {STORED_TOKEN}"} + + assert tools == [tool_1] diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py index 4f93270c16..1d296f0440 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py @@ -484,7 +484,7 @@ class TestListToolsRestAPI: captured = {"called": False} async def fake_get_tools( - server, server_auth_header, raw_headers=None, user_api_key_auth=None + server, server_auth_header, raw_headers=None, user_api_key_auth=None, extra_headers=None ): captured["called"] = True captured["server"] = server @@ -529,6 +529,175 @@ class TestListToolsRestAPI: assert result["error"] is None assert result["message"] == "Successfully retrieved tools" + async def test_name_resolution_finds_server_by_uuid(self, monkeypatch): + """When server_id is a name string, it should be resolved to its UUID + and used for the tools lookup when the UUID is in allowed_server_ids.""" + from litellm.proxy._experimental.mcp_server.server import MCPServer + from litellm.types.mcp import MCPTransport + + stub_server = MCPServer( + server_id="uuid-abc-123", + name="my-server", + transport=MCPTransport.sse, + ) + stub_server.alias = "my-server" + stub_server.server_name = "my-server" + stub_server.available_on_public_internet = True + stub_server.allowed_tools = None + stub_server.mcp_info = {"server_name": "my-server"} + + async def fake_contexts(user_api_key_auth): + return [user_api_key_auth] + + # Allowed list contains the UUID, not the name + async def fake_get_allowed_mcp_servers(*args, **kwargs): + return ["uuid-abc-123"] + + captured = {"called": False, "server_arg": None} + + async def fake_get_tools(server, server_auth_header, raw_headers=None, user_api_key_auth=None, extra_headers=None): + captured["called"] = True + captured["server_arg"] = server + return ["tool-x"] + + monkeypatch.setattr(rest_endpoints, "build_effective_auth_contexts", fake_contexts, raising=False) + monkeypatch.setattr( + rest_endpoints.global_mcp_server_manager, "get_allowed_mcp_servers", + fake_get_allowed_mcp_servers, raising=False, + ) + monkeypatch.setattr( + rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_name", + lambda name: stub_server if name == "my-server" else None, + raising=False, + ) + monkeypatch.setattr( + rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_id", + lambda sid: stub_server if sid == "uuid-abc-123" else None, + raising=False, + ) + monkeypatch.setattr(rest_endpoints, "_get_tools_for_single_server", fake_get_tools, raising=False) + + request = _build_request(path="/mcp-rest/tools/list", method="GET") + result = await rest_endpoints.list_tool_rest_api( + request, + server_id="my-server", # pass name, not UUID + user_api_key_dict=UserAPIKeyAuth(), + ) + + assert captured["called"] is True + assert captured["server_arg"] is stub_server + assert result["tools"] == ["tool-x"] + assert result["error"] is None + + async def test_name_not_in_allowed_returns_access_denied(self, monkeypatch): + """When name resolves to a server whose UUID is NOT in allowed_server_ids, + the result should be an access_denied error (not a crash or silent pass).""" + from litellm.proxy._experimental.mcp_server.server import MCPServer + from litellm.types.mcp import MCPTransport + + stub_server = MCPServer( + server_id="uuid-xyz-999", + name="restricted-server", + transport=MCPTransport.sse, + ) + stub_server.available_on_public_internet = True + + async def fake_contexts(user_api_key_auth): + return [user_api_key_auth] + + # No allowed servers for this key + async def fake_get_allowed_mcp_servers(*args, **kwargs): + return [] + + monkeypatch.setattr(rest_endpoints, "build_effective_auth_contexts", fake_contexts, raising=False) + monkeypatch.setattr( + rest_endpoints.global_mcp_server_manager, "get_allowed_mcp_servers", + fake_get_allowed_mcp_servers, raising=False, + ) + monkeypatch.setattr( + rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_name", + lambda name: stub_server if name == "restricted-server" else None, + raising=False, + ) + monkeypatch.setattr( + rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_id", + lambda sid: stub_server if sid == "uuid-xyz-999" else None, + raising=False, + ) + + request = _build_request(path="/mcp-rest/tools/list", method="GET") + result = await rest_endpoints.list_tool_rest_api( + request, + server_id="restricted-server", + user_api_key_dict=UserAPIKeyAuth(), + ) + + assert result["tools"] == [] + assert result["error"] == "unexpected_error" + assert "access_denied" in result["message"] + + async def test_oauth2_user_token_injected_for_single_server(self, monkeypatch): + """For a single-server OAuth2 request, _get_user_oauth_extra_headers is called + and the returned headers are forwarded to _get_tools_for_single_server.""" + from litellm.proxy._experimental.mcp_server.server import MCPServer + from litellm.types.mcp import MCPTransport + + stub_server = MCPServer( + server_id="oauth-server-id", + name="oauth-server", + transport=MCPTransport.sse, + ) + stub_server.alias = "oauth-server" + stub_server.server_name = "oauth-server" + stub_server.available_on_public_internet = True + stub_server.allowed_tools = None + stub_server.mcp_info = {"server_name": "oauth-server"} + stub_server.auth_type = MCPAuth.oauth2 + + async def fake_contexts(user_api_key_auth): + return [user_api_key_auth] + + async def fake_get_allowed_mcp_servers(*args, **kwargs): + return ["oauth-server-id"] + + oauth_headers = {"Authorization": "Bearer user-oauth-token"} + + async def fake_get_user_oauth_extra_headers(server, user_api_key_dict, prefetched_creds=None): + return oauth_headers + + captured = {} + + async def fake_get_tools(server, server_auth_header, raw_headers=None, user_api_key_auth=None, extra_headers=None): + captured["server"] = server + captured["auth_header"] = server_auth_header + return ["oauth-tool"] + + monkeypatch.setattr(rest_endpoints, "build_effective_auth_contexts", fake_contexts, raising=False) + monkeypatch.setattr( + rest_endpoints.global_mcp_server_manager, "get_allowed_mcp_servers", + fake_get_allowed_mcp_servers, raising=False, + ) + monkeypatch.setattr( + rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_id", + lambda sid: stub_server if sid == "oauth-server-id" else None, + raising=False, + ) + monkeypatch.setattr( + rest_endpoints, "_get_user_oauth_extra_headers", + fake_get_user_oauth_extra_headers, raising=False, + ) + monkeypatch.setattr(rest_endpoints, "_get_tools_for_single_server", fake_get_tools, raising=False) + + request = _build_request(path="/mcp-rest/tools/list", method="GET") + result = await rest_endpoints.list_tool_rest_api( + request, + server_id="oauth-server-id", + user_api_key_dict=UserAPIKeyAuth(user_id="user-123"), + ) + + assert result["tools"] == ["oauth-tool"] + assert result["error"] is None + class TestCallToolRestAPI: pytestmark = pytest.mark.asyncio diff --git a/ui/litellm-dashboard/src/components/chat/MCPAppsPanel.tsx b/ui/litellm-dashboard/src/components/chat/MCPAppsPanel.tsx index 32d208261e..b805cab71e 100644 --- a/ui/litellm-dashboard/src/components/chat/MCPAppsPanel.tsx +++ b/ui/litellm-dashboard/src/components/chat/MCPAppsPanel.tsx @@ -319,7 +319,7 @@ const MCPAppsPanel: React.FC = ({ accessToken, selectedServers, onChange // Ignore — credential may already be gone; update UI regardless. } setOauthConnected((prev) => { const n = new Set(prev); n.delete(detailServer.server_id); return n; }); - onChange(selectedServers.filter((s) => s !== name)); + onChangeRef.current(selectedServersRef.current.filter((s) => s !== name)); }} style={{ borderRadius: 8, fontWeight: 600, height: 38, minWidth: 110 }} > @@ -331,7 +331,6 @@ const MCPAppsPanel: React.FC = ({ accessToken, selectedServers, onChange accessToken={accessToken} onConnect={(id) => { setOauthConnected((prev) => new Set(prev).add(id)); - handleToggle(name, true, detailServer.server_id); }} variant="button" /> @@ -559,7 +558,6 @@ const MCPAppsPanel: React.FC = ({ accessToken, selectedServers, onChange accessToken={accessToken} onConnect={(id) => { setOauthConnected((prev) => new Set(prev).add(id)); - handleToggle(nameOf(server), true, server.server_id); }} variant="badge" /> diff --git a/ui/litellm-dashboard/src/components/chat/MCPCredentialsTab.tsx b/ui/litellm-dashboard/src/components/chat/MCPCredentialsTab.tsx index 9714da6d4d..363ec8c0e4 100644 --- a/ui/litellm-dashboard/src/components/chat/MCPCredentialsTab.tsx +++ b/ui/litellm-dashboard/src/components/chat/MCPCredentialsTab.tsx @@ -9,7 +9,8 @@ import React, { useCallback, useEffect, useState } from "react"; import { Spin, message } from "antd"; -import { CheckCircleOutlined, DeleteOutlined, LinkOutlined } from "@ant-design/icons"; +import { DeleteOutlined, LinkOutlined } from "@ant-design/icons"; +import { Badge, Table, TableBody, TableCell, TableHead, TableHeaderCell, TableRow } from "@tremor/react"; import { deleteMCPOAuthUserCredential, listMCPUserCredentials, @@ -86,98 +87,82 @@ const MCPCredentialsTab: React.FC = ({ accessToken }) => { c.alias || c.server_name || c.server_id; return ( -
+
{/* Header */} -
-

- App Credentials -

-

+

+

App Credentials

+

Your stored OAuth connections — used automatically in chat.

{loading ? ( -
+
) : credentials.length === 0 ? ( -
- +
+ No connections yet.
Go to Apps and click Connect to authorize an MCP server.
) : ( -
- {credentials.map((cred) => { - const name = displayName(cred); - const isRevoking = revoking.has(cred.server_id); - const exp = expiryLabel(cred.expires_at); - const connected = relativeTime(cred.connected_at); - const isExpired = exp === "Expired"; +
+ + + + + App + + + Connected + + + Status + + + Actions + + + + + {credentials.map((cred) => { + const name = displayName(cred); + const isRevoking = revoking.has(cred.server_id); + const exp = expiryLabel(cred.expires_at); + const connected = relativeTime(cred.connected_at); + const isExpired = exp === "Expired"; - return ( -
- {/* Status dot */} -
- -
- - {/* Info */} -
-
- {name} -
-
- {connected && ( - - Connected {connected} - - )} - - {exp} - -
-
- - {/* Revoke */} - -
- ); - })} + return ( + + + {name} + + + {connected || "—"} + + + + {exp} + + + + + + + ); + })} +
+
)}