Fix MCP DB reload partial failures (#27314)

* Fix MCP database reload partial failures

Co-authored-by: ishaan-berri <ishaan-berri@users.noreply.github.com>

* Avoid staged MCP registry exposure

Co-authored-by: ishaan-berri <ishaan-berri@users.noreply.github.com>

---------

Co-authored-by: oss-agent-shin <279349115+oss-agent-shin@users.noreply.github.com>
Co-authored-by: ishaan-berri <ishaan-berri@users.noreply.github.com>
This commit is contained in:
ishaan-berri 2026-05-06 15:18:18 -07:00 committed by GitHub
parent 924c141843
commit bd1a05aed9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 221 additions and 46 deletions

View File

@ -768,7 +768,9 @@ class MCPServerManager:
)
return new_server
async def _maybe_register_openapi_tools(self, server: MCPServer):
async def _maybe_register_openapi_tools(
self, server: MCPServer, *, initialize_mapping: bool = True
):
"""Register OpenAPI tools if the server has a spec_path configured."""
if server.spec_path:
verbose_logger.info(
@ -779,7 +781,8 @@ class MCPServerManager:
server=server,
base_url=server.url or "",
)
self.initialize_tool_name_to_mcp_server_name_mapping()
if initialize_mapping:
self.initialize_tool_name_to_mcp_server_name_mapping()
async def add_server(self, mcp_server: LiteLLM_MCPServerTable):
try:
@ -1978,7 +1981,11 @@ class MCPServerManager:
_SHORT_PREFIX_MAX_REHASH_ATTEMPTS = 1024
def _assign_unique_short_prefix(self, server: MCPServer) -> None:
def _assign_unique_short_prefix(
self,
server: MCPServer,
registry: Optional[Dict[str, MCPServer]] = None,
) -> None:
"""Resolve and cache a collision-free short tool prefix on ``server``.
Called at registration time for every MCP server entering the
@ -2002,7 +2009,8 @@ class MCPServerManager:
return
used: Dict[str, str] = {}
for other in self.get_registry().values():
registry_for_collision_check = registry or self.get_registry()
for other in registry_for_collision_check.values():
if other.server_id == server.server_id:
continue
if other.short_prefix:
@ -2916,46 +2924,72 @@ class MCPServerManager:
# against the *full* set so dedup is deterministic regardless of
# iteration order.
for server in db_mcp_servers:
existing_server = previous_registry.get(server.server_id)
try:
existing_server = previous_registry.get(server.server_id)
if (
existing_server is not None
and existing_server.updated_at is not None
and server.updated_at is not None
and existing_server.updated_at == server.updated_at
):
# Re-use existing server instance to avoid re-running build_mcp_server_from_table()
# which can perform network discovery for OAuth2 servers.
new_registry[server.server_id] = existing_server
continue
if (
existing_server is not None
and existing_server.updated_at is not None
and server.updated_at is not None
and existing_server.updated_at == server.updated_at
):
# Re-use existing server instance to avoid re-running build_mcp_server_from_table()
# which can perform network discovery for OAuth2 servers.
new_registry[server.server_id] = existing_server
continue
_warn_on_server_name_fields(
server_id=server.server_id,
alias=getattr(server, "alias", None),
server_name=getattr(server, "server_name", None),
)
verbose_logger.debug(
f"Building server from DB: {server.server_id} ({server.server_name})"
)
new_server = await self.build_mcp_server_from_table(server)
# Carry the cached short_prefix from the previous registry entry
# (if any) so the prefix is stable across reloads.
if existing_server is not None and existing_server.short_prefix:
new_server.short_prefix = existing_server.short_prefix
new_registry[server.server_id] = new_server
_warn_on_server_name_fields(
server_id=server.server_id,
alias=getattr(server, "alias", None),
server_name=getattr(server, "server_name", None),
)
verbose_logger.debug(
f"Building server from DB: {server.server_id} ({server.server_name})"
)
new_server = await self.build_mcp_server_from_table(server)
# Carry the cached short_prefix from the previous registry entry
# (if any) so the prefix is stable across reloads.
if existing_server is not None and existing_server.short_prefix:
new_server.short_prefix = existing_server.short_prefix
new_registry[server.server_id] = new_server
except Exception as e:
verbose_logger.exception(
"Skipping MCP server %s (%s) during DB reload: %s",
server.server_id,
getattr(server, "alias", None),
e,
)
# Swap in the new registry first so _assign_unique_short_prefix
# sees the complete set when checking for collisions.
self.registry = new_registry
for new_server in new_registry.values():
self._assign_unique_short_prefix(new_server)
# Register OpenAPI tools *after* the final short prefix is assigned
# so the tools are stored in the global registry under the same
# prefix that lookups will use.
await self._maybe_register_openapi_tools(new_server)
# Assign short prefixes against the full candidate set without
# publishing the staged registry to concurrent callers.
registered_registry: Dict[str, MCPServer] = {}
registered_openapi_tools = False
for server_id, new_server in new_registry.items():
try:
self._assign_unique_short_prefix(new_server, registry=new_registry)
# Register OpenAPI tools *after* the final short prefix is assigned
# so the tools are stored in the global registry under the same
# prefix that lookups will use.
await self._maybe_register_openapi_tools(
new_server, initialize_mapping=False
)
registered_registry[server_id] = new_server
if new_server.spec_path:
registered_openapi_tools = True
except Exception as e:
verbose_logger.exception(
"Skipping MCP server %s (%s) during DB reload: %s",
new_server.server_id,
getattr(new_server, "alias", None),
e,
)
self.registry = registered_registry
if registered_openapi_tools:
self.initialize_tool_name_to_mcp_server_name_mapping()
verbose_logger.debug(
"MCP registry refreshed (%s servers in registry)", len(new_registry)
"MCP registry refreshed (%s servers in registry)", len(registered_registry)
)
def get_mcp_servers_from_ids(self, server_ids: List[str]) -> List[MCPServer]:

View File

@ -2282,6 +2282,144 @@ class TestMCPServerManagerReload:
mock_build.assert_awaited_once_with(db_row)
assert manager.registry["server-1"] is rebuilt_server
@pytest.mark.asyncio
async def test_skips_server_when_build_from_database_fails(self, caplog):
try:
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
MCPServerManager,
)
except ImportError:
pytest.skip("MCP server not available")
manager = MCPServerManager()
timestamp = datetime.utcnow()
healthy_row = _make_db_mcp_server("healthy-server", timestamp)
bad_row = _make_db_mcp_server("bad-server", timestamp)
another_healthy_row = _make_db_mcp_server("another-healthy-server", timestamp)
healthy_server = MCPServer(
server_id="healthy-server",
name="healthy",
transport=MCPTransport.http,
updated_at=timestamp,
)
another_healthy_server = MCPServer(
server_id="another-healthy-server",
name="another-healthy",
transport=MCPTransport.http,
updated_at=timestamp,
)
async def build_server(db_row):
if db_row.server_id == "bad-server":
raise RuntimeError("transient build failure")
if db_row.server_id == "healthy-server":
return healthy_server
return another_healthy_server
mock_prisma = MagicMock()
mock_prisma.db.litellm_mcpservertable.find_many = AsyncMock(
return_value=[healthy_row, bad_row, another_healthy_row]
)
with (
patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw",
return_value=mock_prisma,
),
patch.object(
manager,
"build_mcp_server_from_table",
AsyncMock(side_effect=build_server),
),
patch.object(manager, "_maybe_register_openapi_tools", AsyncMock()),
caplog.at_level("ERROR", logger="LiteLLM"),
):
await manager.reload_servers_from_database()
assert set(manager.registry) == {"healthy-server", "another-healthy-server"}
assert manager.registry["healthy-server"] is healthy_server
assert manager.registry["another-healthy-server"] is another_healthy_server
assert "Skipping MCP server bad-server" in caplog.text
@pytest.mark.asyncio
async def test_skips_server_when_openapi_registration_fails(self, caplog):
try:
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
MCPServerManager,
)
except ImportError:
pytest.skip("MCP server not available")
manager = MCPServerManager()
timestamp = datetime.utcnow()
healthy_row = _make_db_mcp_server("healthy-server", timestamp)
bad_openapi_row = _make_db_mcp_server("bad-openapi-server", timestamp)
existing_server = MCPServer(
server_id="existing-server",
name="existing",
transport=MCPTransport.http,
updated_at=timestamp,
)
manager.registry = {existing_server.server_id: existing_server}
healthy_server = MCPServer(
server_id="healthy-server",
name="healthy",
transport=MCPTransport.http,
updated_at=timestamp,
)
bad_openapi_server = MCPServer(
server_id="bad-openapi-server",
name="bad-openapi",
transport=MCPTransport.http,
spec_path="https://example.invalid/openapi.json",
updated_at=timestamp,
)
async def build_server(db_row):
if db_row.server_id == "healthy-server":
return healthy_server
return bad_openapi_server
observed_registries = []
async def register_openapi_tools(server, **kwargs):
observed_registries.append(set(manager.registry))
assert kwargs == {"initialize_mapping": False}
if server.server_id == "bad-openapi-server":
raise RuntimeError("blocked address")
mock_prisma = MagicMock()
mock_prisma.db.litellm_mcpservertable.find_many = AsyncMock(
return_value=[healthy_row, bad_openapi_row]
)
with (
patch(
"litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw",
return_value=mock_prisma,
),
patch.object(
manager,
"build_mcp_server_from_table",
AsyncMock(side_effect=build_server),
),
patch.object(
manager,
"_maybe_register_openapi_tools",
AsyncMock(side_effect=register_openapi_tools),
),
caplog.at_level("ERROR", logger="LiteLLM"),
):
await manager.reload_servers_from_database()
assert set(manager.registry) == {"healthy-server"}
assert manager.registry["healthy-server"] is healthy_server
assert observed_registries == [
{"existing-server"},
{"existing-server"},
]
assert "Skipping MCP server bad-openapi-server" in caplog.text
@pytest.mark.asyncio
async def test_call_mcp_tool_logs_failure_via_post_call_failure_hook():
@ -2946,7 +3084,7 @@ async def test_list_tools_with_legacy_db_m2m_server_resolves_oauth2_flow():
"""
P1 Regression: list_tools path must apply _resolve_oauth2_flow to legacy DB
rows where oauth2_flow is NULL but M2M credentials are present.
Without this fix, has_client_credentials returns False and the caller's
Authorization header is forwarded upstream instead of being blocked.
"""
@ -3044,7 +3182,7 @@ async def test_call_tool_empty_extra_headers_returns_none():
"""
P2 Regression: When all configured extra_headers are filtered out (e.g.
Authorization for M2M), the resulting extra_headers should be None, not {}.
Downstream code that checks `if extra_headers is None` will behave
differently if an empty dict is passed instead.
"""
@ -3071,7 +3209,10 @@ async def test_call_tool_empty_extra_headers_returns_none():
extra_headers=["Authorization"], # Will be filtered out for M2M
)
raw_headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
raw_headers = {
"Authorization": "Bearer sk-1234",
"Content-Type": "application/json",
}
captured_extra_headers = None
@ -3108,8 +3249,8 @@ async def test_call_tool_empty_extra_headers_returns_none():
pass # We only care about the captured headers
# With P2 fix: extra_headers should be None (not {}) when all headers filtered
assert captured_extra_headers is None, (
"P2 API consistency issue: expected None for empty extra_headers, got: "
+ str(captured_extra_headers)
assert (
captured_extra_headers is None
), "P2 API consistency issue: expected None for empty extra_headers, got: " + str(
captured_extra_headers
)