diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 9923c3ce4b..12fd1ef89a 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -768,7 +768,9 @@ class MCPServerManager: ) return new_server - async def _maybe_register_openapi_tools(self, server: MCPServer): + async def _maybe_register_openapi_tools( + self, server: MCPServer, *, initialize_mapping: bool = True + ): """Register OpenAPI tools if the server has a spec_path configured.""" if server.spec_path: verbose_logger.info( @@ -779,7 +781,8 @@ class MCPServerManager: server=server, base_url=server.url or "", ) - self.initialize_tool_name_to_mcp_server_name_mapping() + if initialize_mapping: + self.initialize_tool_name_to_mcp_server_name_mapping() async def add_server(self, mcp_server: LiteLLM_MCPServerTable): try: @@ -1978,7 +1981,11 @@ class MCPServerManager: _SHORT_PREFIX_MAX_REHASH_ATTEMPTS = 1024 - def _assign_unique_short_prefix(self, server: MCPServer) -> None: + def _assign_unique_short_prefix( + self, + server: MCPServer, + registry: Optional[Dict[str, MCPServer]] = None, + ) -> None: """Resolve and cache a collision-free short tool prefix on ``server``. Called at registration time for every MCP server entering the @@ -2002,7 +2009,8 @@ class MCPServerManager: return used: Dict[str, str] = {} - for other in self.get_registry().values(): + registry_for_collision_check = registry or self.get_registry() + for other in registry_for_collision_check.values(): if other.server_id == server.server_id: continue if other.short_prefix: @@ -2916,46 +2924,72 @@ class MCPServerManager: # against the *full* set so dedup is deterministic regardless of # iteration order. for server in db_mcp_servers: - existing_server = previous_registry.get(server.server_id) + try: + existing_server = previous_registry.get(server.server_id) - if ( - existing_server is not None - and existing_server.updated_at is not None - and server.updated_at is not None - and existing_server.updated_at == server.updated_at - ): - # Re-use existing server instance to avoid re-running build_mcp_server_from_table() - # which can perform network discovery for OAuth2 servers. - new_registry[server.server_id] = existing_server - continue + if ( + existing_server is not None + and existing_server.updated_at is not None + and server.updated_at is not None + and existing_server.updated_at == server.updated_at + ): + # Re-use existing server instance to avoid re-running build_mcp_server_from_table() + # which can perform network discovery for OAuth2 servers. + new_registry[server.server_id] = existing_server + continue - _warn_on_server_name_fields( - server_id=server.server_id, - alias=getattr(server, "alias", None), - server_name=getattr(server, "server_name", None), - ) - verbose_logger.debug( - f"Building server from DB: {server.server_id} ({server.server_name})" - ) - new_server = await self.build_mcp_server_from_table(server) - # Carry the cached short_prefix from the previous registry entry - # (if any) so the prefix is stable across reloads. - if existing_server is not None and existing_server.short_prefix: - new_server.short_prefix = existing_server.short_prefix - new_registry[server.server_id] = new_server + _warn_on_server_name_fields( + server_id=server.server_id, + alias=getattr(server, "alias", None), + server_name=getattr(server, "server_name", None), + ) + verbose_logger.debug( + f"Building server from DB: {server.server_id} ({server.server_name})" + ) + new_server = await self.build_mcp_server_from_table(server) + # Carry the cached short_prefix from the previous registry entry + # (if any) so the prefix is stable across reloads. + if existing_server is not None and existing_server.short_prefix: + new_server.short_prefix = existing_server.short_prefix + new_registry[server.server_id] = new_server + except Exception as e: + verbose_logger.exception( + "Skipping MCP server %s (%s) during DB reload: %s", + server.server_id, + getattr(server, "alias", None), + e, + ) - # Swap in the new registry first so _assign_unique_short_prefix - # sees the complete set when checking for collisions. - self.registry = new_registry - for new_server in new_registry.values(): - self._assign_unique_short_prefix(new_server) - # Register OpenAPI tools *after* the final short prefix is assigned - # so the tools are stored in the global registry under the same - # prefix that lookups will use. - await self._maybe_register_openapi_tools(new_server) + # Assign short prefixes against the full candidate set without + # publishing the staged registry to concurrent callers. + registered_registry: Dict[str, MCPServer] = {} + registered_openapi_tools = False + for server_id, new_server in new_registry.items(): + try: + self._assign_unique_short_prefix(new_server, registry=new_registry) + # Register OpenAPI tools *after* the final short prefix is assigned + # so the tools are stored in the global registry under the same + # prefix that lookups will use. + await self._maybe_register_openapi_tools( + new_server, initialize_mapping=False + ) + registered_registry[server_id] = new_server + if new_server.spec_path: + registered_openapi_tools = True + except Exception as e: + verbose_logger.exception( + "Skipping MCP server %s (%s) during DB reload: %s", + new_server.server_id, + getattr(new_server, "alias", None), + e, + ) + + self.registry = registered_registry + if registered_openapi_tools: + self.initialize_tool_name_to_mcp_server_name_mapping() verbose_logger.debug( - "MCP registry refreshed (%s servers in registry)", len(new_registry) + "MCP registry refreshed (%s servers in registry)", len(registered_registry) ) def get_mcp_servers_from_ids(self, server_ids: List[str]) -> List[MCPServer]: diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py index 06f95159c0..d3e90246c3 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py @@ -2282,6 +2282,144 @@ class TestMCPServerManagerReload: mock_build.assert_awaited_once_with(db_row) assert manager.registry["server-1"] is rebuilt_server + @pytest.mark.asyncio + async def test_skips_server_when_build_from_database_fails(self, caplog): + try: + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + except ImportError: + pytest.skip("MCP server not available") + + manager = MCPServerManager() + timestamp = datetime.utcnow() + healthy_row = _make_db_mcp_server("healthy-server", timestamp) + bad_row = _make_db_mcp_server("bad-server", timestamp) + another_healthy_row = _make_db_mcp_server("another-healthy-server", timestamp) + + healthy_server = MCPServer( + server_id="healthy-server", + name="healthy", + transport=MCPTransport.http, + updated_at=timestamp, + ) + another_healthy_server = MCPServer( + server_id="another-healthy-server", + name="another-healthy", + transport=MCPTransport.http, + updated_at=timestamp, + ) + + async def build_server(db_row): + if db_row.server_id == "bad-server": + raise RuntimeError("transient build failure") + if db_row.server_id == "healthy-server": + return healthy_server + return another_healthy_server + + mock_prisma = MagicMock() + mock_prisma.db.litellm_mcpservertable.find_many = AsyncMock( + return_value=[healthy_row, bad_row, another_healthy_row] + ) + with ( + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw", + return_value=mock_prisma, + ), + patch.object( + manager, + "build_mcp_server_from_table", + AsyncMock(side_effect=build_server), + ), + patch.object(manager, "_maybe_register_openapi_tools", AsyncMock()), + caplog.at_level("ERROR", logger="LiteLLM"), + ): + await manager.reload_servers_from_database() + + assert set(manager.registry) == {"healthy-server", "another-healthy-server"} + assert manager.registry["healthy-server"] is healthy_server + assert manager.registry["another-healthy-server"] is another_healthy_server + assert "Skipping MCP server bad-server" in caplog.text + + @pytest.mark.asyncio + async def test_skips_server_when_openapi_registration_fails(self, caplog): + try: + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + except ImportError: + pytest.skip("MCP server not available") + + manager = MCPServerManager() + timestamp = datetime.utcnow() + healthy_row = _make_db_mcp_server("healthy-server", timestamp) + bad_openapi_row = _make_db_mcp_server("bad-openapi-server", timestamp) + existing_server = MCPServer( + server_id="existing-server", + name="existing", + transport=MCPTransport.http, + updated_at=timestamp, + ) + manager.registry = {existing_server.server_id: existing_server} + + healthy_server = MCPServer( + server_id="healthy-server", + name="healthy", + transport=MCPTransport.http, + updated_at=timestamp, + ) + bad_openapi_server = MCPServer( + server_id="bad-openapi-server", + name="bad-openapi", + transport=MCPTransport.http, + spec_path="https://example.invalid/openapi.json", + updated_at=timestamp, + ) + + async def build_server(db_row): + if db_row.server_id == "healthy-server": + return healthy_server + return bad_openapi_server + + observed_registries = [] + + async def register_openapi_tools(server, **kwargs): + observed_registries.append(set(manager.registry)) + assert kwargs == {"initialize_mapping": False} + if server.server_id == "bad-openapi-server": + raise RuntimeError("blocked address") + + mock_prisma = MagicMock() + mock_prisma.db.litellm_mcpservertable.find_many = AsyncMock( + return_value=[healthy_row, bad_openapi_row] + ) + with ( + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw", + return_value=mock_prisma, + ), + patch.object( + manager, + "build_mcp_server_from_table", + AsyncMock(side_effect=build_server), + ), + patch.object( + manager, + "_maybe_register_openapi_tools", + AsyncMock(side_effect=register_openapi_tools), + ), + caplog.at_level("ERROR", logger="LiteLLM"), + ): + await manager.reload_servers_from_database() + + assert set(manager.registry) == {"healthy-server"} + assert manager.registry["healthy-server"] is healthy_server + assert observed_registries == [ + {"existing-server"}, + {"existing-server"}, + ] + assert "Skipping MCP server bad-openapi-server" in caplog.text + @pytest.mark.asyncio async def test_call_mcp_tool_logs_failure_via_post_call_failure_hook(): @@ -2946,7 +3084,7 @@ async def test_list_tools_with_legacy_db_m2m_server_resolves_oauth2_flow(): """ P1 Regression: list_tools path must apply _resolve_oauth2_flow to legacy DB rows where oauth2_flow is NULL but M2M credentials are present. - + Without this fix, has_client_credentials returns False and the caller's Authorization header is forwarded upstream instead of being blocked. """ @@ -3044,7 +3182,7 @@ async def test_call_tool_empty_extra_headers_returns_none(): """ P2 Regression: When all configured extra_headers are filtered out (e.g. Authorization for M2M), the resulting extra_headers should be None, not {}. - + Downstream code that checks `if extra_headers is None` will behave differently if an empty dict is passed instead. """ @@ -3071,7 +3209,10 @@ async def test_call_tool_empty_extra_headers_returns_none(): extra_headers=["Authorization"], # Will be filtered out for M2M ) - raw_headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"} + raw_headers = { + "Authorization": "Bearer sk-1234", + "Content-Type": "application/json", + } captured_extra_headers = None @@ -3108,8 +3249,8 @@ async def test_call_tool_empty_extra_headers_returns_none(): pass # We only care about the captured headers # With P2 fix: extra_headers should be None (not {}) when all headers filtered - assert captured_extra_headers is None, ( - "P2 API consistency issue: expected None for empty extra_headers, got: " - + str(captured_extra_headers) + assert ( + captured_extra_headers is None + ), "P2 API consistency issue: expected None for empty extra_headers, got: " + str( + captured_extra_headers ) -