Fix MCP DB reload partial failures (#27314)
* Fix MCP database reload partial failures Co-authored-by: ishaan-berri <ishaan-berri@users.noreply.github.com> * Avoid staged MCP registry exposure Co-authored-by: ishaan-berri <ishaan-berri@users.noreply.github.com> --------- Co-authored-by: oss-agent-shin <279349115+oss-agent-shin@users.noreply.github.com> Co-authored-by: ishaan-berri <ishaan-berri@users.noreply.github.com>
This commit is contained in:
parent
924c141843
commit
bd1a05aed9
@ -768,7 +768,9 @@ class MCPServerManager:
|
||||
)
|
||||
return new_server
|
||||
|
||||
async def _maybe_register_openapi_tools(self, server: MCPServer):
|
||||
async def _maybe_register_openapi_tools(
|
||||
self, server: MCPServer, *, initialize_mapping: bool = True
|
||||
):
|
||||
"""Register OpenAPI tools if the server has a spec_path configured."""
|
||||
if server.spec_path:
|
||||
verbose_logger.info(
|
||||
@ -779,7 +781,8 @@ class MCPServerManager:
|
||||
server=server,
|
||||
base_url=server.url or "",
|
||||
)
|
||||
self.initialize_tool_name_to_mcp_server_name_mapping()
|
||||
if initialize_mapping:
|
||||
self.initialize_tool_name_to_mcp_server_name_mapping()
|
||||
|
||||
async def add_server(self, mcp_server: LiteLLM_MCPServerTable):
|
||||
try:
|
||||
@ -1978,7 +1981,11 @@ class MCPServerManager:
|
||||
|
||||
_SHORT_PREFIX_MAX_REHASH_ATTEMPTS = 1024
|
||||
|
||||
def _assign_unique_short_prefix(self, server: MCPServer) -> None:
|
||||
def _assign_unique_short_prefix(
|
||||
self,
|
||||
server: MCPServer,
|
||||
registry: Optional[Dict[str, MCPServer]] = None,
|
||||
) -> None:
|
||||
"""Resolve and cache a collision-free short tool prefix on ``server``.
|
||||
|
||||
Called at registration time for every MCP server entering the
|
||||
@ -2002,7 +2009,8 @@ class MCPServerManager:
|
||||
return
|
||||
|
||||
used: Dict[str, str] = {}
|
||||
for other in self.get_registry().values():
|
||||
registry_for_collision_check = registry or self.get_registry()
|
||||
for other in registry_for_collision_check.values():
|
||||
if other.server_id == server.server_id:
|
||||
continue
|
||||
if other.short_prefix:
|
||||
@ -2916,46 +2924,72 @@ class MCPServerManager:
|
||||
# against the *full* set so dedup is deterministic regardless of
|
||||
# iteration order.
|
||||
for server in db_mcp_servers:
|
||||
existing_server = previous_registry.get(server.server_id)
|
||||
try:
|
||||
existing_server = previous_registry.get(server.server_id)
|
||||
|
||||
if (
|
||||
existing_server is not None
|
||||
and existing_server.updated_at is not None
|
||||
and server.updated_at is not None
|
||||
and existing_server.updated_at == server.updated_at
|
||||
):
|
||||
# Re-use existing server instance to avoid re-running build_mcp_server_from_table()
|
||||
# which can perform network discovery for OAuth2 servers.
|
||||
new_registry[server.server_id] = existing_server
|
||||
continue
|
||||
if (
|
||||
existing_server is not None
|
||||
and existing_server.updated_at is not None
|
||||
and server.updated_at is not None
|
||||
and existing_server.updated_at == server.updated_at
|
||||
):
|
||||
# Re-use existing server instance to avoid re-running build_mcp_server_from_table()
|
||||
# which can perform network discovery for OAuth2 servers.
|
||||
new_registry[server.server_id] = existing_server
|
||||
continue
|
||||
|
||||
_warn_on_server_name_fields(
|
||||
server_id=server.server_id,
|
||||
alias=getattr(server, "alias", None),
|
||||
server_name=getattr(server, "server_name", None),
|
||||
)
|
||||
verbose_logger.debug(
|
||||
f"Building server from DB: {server.server_id} ({server.server_name})"
|
||||
)
|
||||
new_server = await self.build_mcp_server_from_table(server)
|
||||
# Carry the cached short_prefix from the previous registry entry
|
||||
# (if any) so the prefix is stable across reloads.
|
||||
if existing_server is not None and existing_server.short_prefix:
|
||||
new_server.short_prefix = existing_server.short_prefix
|
||||
new_registry[server.server_id] = new_server
|
||||
_warn_on_server_name_fields(
|
||||
server_id=server.server_id,
|
||||
alias=getattr(server, "alias", None),
|
||||
server_name=getattr(server, "server_name", None),
|
||||
)
|
||||
verbose_logger.debug(
|
||||
f"Building server from DB: {server.server_id} ({server.server_name})"
|
||||
)
|
||||
new_server = await self.build_mcp_server_from_table(server)
|
||||
# Carry the cached short_prefix from the previous registry entry
|
||||
# (if any) so the prefix is stable across reloads.
|
||||
if existing_server is not None and existing_server.short_prefix:
|
||||
new_server.short_prefix = existing_server.short_prefix
|
||||
new_registry[server.server_id] = new_server
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
"Skipping MCP server %s (%s) during DB reload: %s",
|
||||
server.server_id,
|
||||
getattr(server, "alias", None),
|
||||
e,
|
||||
)
|
||||
|
||||
# Swap in the new registry first so _assign_unique_short_prefix
|
||||
# sees the complete set when checking for collisions.
|
||||
self.registry = new_registry
|
||||
for new_server in new_registry.values():
|
||||
self._assign_unique_short_prefix(new_server)
|
||||
# Register OpenAPI tools *after* the final short prefix is assigned
|
||||
# so the tools are stored in the global registry under the same
|
||||
# prefix that lookups will use.
|
||||
await self._maybe_register_openapi_tools(new_server)
|
||||
# Assign short prefixes against the full candidate set without
|
||||
# publishing the staged registry to concurrent callers.
|
||||
registered_registry: Dict[str, MCPServer] = {}
|
||||
registered_openapi_tools = False
|
||||
for server_id, new_server in new_registry.items():
|
||||
try:
|
||||
self._assign_unique_short_prefix(new_server, registry=new_registry)
|
||||
# Register OpenAPI tools *after* the final short prefix is assigned
|
||||
# so the tools are stored in the global registry under the same
|
||||
# prefix that lookups will use.
|
||||
await self._maybe_register_openapi_tools(
|
||||
new_server, initialize_mapping=False
|
||||
)
|
||||
registered_registry[server_id] = new_server
|
||||
if new_server.spec_path:
|
||||
registered_openapi_tools = True
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
"Skipping MCP server %s (%s) during DB reload: %s",
|
||||
new_server.server_id,
|
||||
getattr(new_server, "alias", None),
|
||||
e,
|
||||
)
|
||||
|
||||
self.registry = registered_registry
|
||||
if registered_openapi_tools:
|
||||
self.initialize_tool_name_to_mcp_server_name_mapping()
|
||||
|
||||
verbose_logger.debug(
|
||||
"MCP registry refreshed (%s servers in registry)", len(new_registry)
|
||||
"MCP registry refreshed (%s servers in registry)", len(registered_registry)
|
||||
)
|
||||
|
||||
def get_mcp_servers_from_ids(self, server_ids: List[str]) -> List[MCPServer]:
|
||||
|
||||
@ -2282,6 +2282,144 @@ class TestMCPServerManagerReload:
|
||||
mock_build.assert_awaited_once_with(db_row)
|
||||
assert manager.registry["server-1"] is rebuilt_server
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_server_when_build_from_database_fails(self, caplog):
|
||||
try:
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
MCPServerManager,
|
||||
)
|
||||
except ImportError:
|
||||
pytest.skip("MCP server not available")
|
||||
|
||||
manager = MCPServerManager()
|
||||
timestamp = datetime.utcnow()
|
||||
healthy_row = _make_db_mcp_server("healthy-server", timestamp)
|
||||
bad_row = _make_db_mcp_server("bad-server", timestamp)
|
||||
another_healthy_row = _make_db_mcp_server("another-healthy-server", timestamp)
|
||||
|
||||
healthy_server = MCPServer(
|
||||
server_id="healthy-server",
|
||||
name="healthy",
|
||||
transport=MCPTransport.http,
|
||||
updated_at=timestamp,
|
||||
)
|
||||
another_healthy_server = MCPServer(
|
||||
server_id="another-healthy-server",
|
||||
name="another-healthy",
|
||||
transport=MCPTransport.http,
|
||||
updated_at=timestamp,
|
||||
)
|
||||
|
||||
async def build_server(db_row):
|
||||
if db_row.server_id == "bad-server":
|
||||
raise RuntimeError("transient build failure")
|
||||
if db_row.server_id == "healthy-server":
|
||||
return healthy_server
|
||||
return another_healthy_server
|
||||
|
||||
mock_prisma = MagicMock()
|
||||
mock_prisma.db.litellm_mcpservertable.find_many = AsyncMock(
|
||||
return_value=[healthy_row, bad_row, another_healthy_row]
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw",
|
||||
return_value=mock_prisma,
|
||||
),
|
||||
patch.object(
|
||||
manager,
|
||||
"build_mcp_server_from_table",
|
||||
AsyncMock(side_effect=build_server),
|
||||
),
|
||||
patch.object(manager, "_maybe_register_openapi_tools", AsyncMock()),
|
||||
caplog.at_level("ERROR", logger="LiteLLM"),
|
||||
):
|
||||
await manager.reload_servers_from_database()
|
||||
|
||||
assert set(manager.registry) == {"healthy-server", "another-healthy-server"}
|
||||
assert manager.registry["healthy-server"] is healthy_server
|
||||
assert manager.registry["another-healthy-server"] is another_healthy_server
|
||||
assert "Skipping MCP server bad-server" in caplog.text
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_server_when_openapi_registration_fails(self, caplog):
|
||||
try:
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
MCPServerManager,
|
||||
)
|
||||
except ImportError:
|
||||
pytest.skip("MCP server not available")
|
||||
|
||||
manager = MCPServerManager()
|
||||
timestamp = datetime.utcnow()
|
||||
healthy_row = _make_db_mcp_server("healthy-server", timestamp)
|
||||
bad_openapi_row = _make_db_mcp_server("bad-openapi-server", timestamp)
|
||||
existing_server = MCPServer(
|
||||
server_id="existing-server",
|
||||
name="existing",
|
||||
transport=MCPTransport.http,
|
||||
updated_at=timestamp,
|
||||
)
|
||||
manager.registry = {existing_server.server_id: existing_server}
|
||||
|
||||
healthy_server = MCPServer(
|
||||
server_id="healthy-server",
|
||||
name="healthy",
|
||||
transport=MCPTransport.http,
|
||||
updated_at=timestamp,
|
||||
)
|
||||
bad_openapi_server = MCPServer(
|
||||
server_id="bad-openapi-server",
|
||||
name="bad-openapi",
|
||||
transport=MCPTransport.http,
|
||||
spec_path="https://example.invalid/openapi.json",
|
||||
updated_at=timestamp,
|
||||
)
|
||||
|
||||
async def build_server(db_row):
|
||||
if db_row.server_id == "healthy-server":
|
||||
return healthy_server
|
||||
return bad_openapi_server
|
||||
|
||||
observed_registries = []
|
||||
|
||||
async def register_openapi_tools(server, **kwargs):
|
||||
observed_registries.append(set(manager.registry))
|
||||
assert kwargs == {"initialize_mapping": False}
|
||||
if server.server_id == "bad-openapi-server":
|
||||
raise RuntimeError("blocked address")
|
||||
|
||||
mock_prisma = MagicMock()
|
||||
mock_prisma.db.litellm_mcpservertable.find_many = AsyncMock(
|
||||
return_value=[healthy_row, bad_openapi_row]
|
||||
)
|
||||
with (
|
||||
patch(
|
||||
"litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw",
|
||||
return_value=mock_prisma,
|
||||
),
|
||||
patch.object(
|
||||
manager,
|
||||
"build_mcp_server_from_table",
|
||||
AsyncMock(side_effect=build_server),
|
||||
),
|
||||
patch.object(
|
||||
manager,
|
||||
"_maybe_register_openapi_tools",
|
||||
AsyncMock(side_effect=register_openapi_tools),
|
||||
),
|
||||
caplog.at_level("ERROR", logger="LiteLLM"),
|
||||
):
|
||||
await manager.reload_servers_from_database()
|
||||
|
||||
assert set(manager.registry) == {"healthy-server"}
|
||||
assert manager.registry["healthy-server"] is healthy_server
|
||||
assert observed_registries == [
|
||||
{"existing-server"},
|
||||
{"existing-server"},
|
||||
]
|
||||
assert "Skipping MCP server bad-openapi-server" in caplog.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_mcp_tool_logs_failure_via_post_call_failure_hook():
|
||||
@ -2946,7 +3084,7 @@ async def test_list_tools_with_legacy_db_m2m_server_resolves_oauth2_flow():
|
||||
"""
|
||||
P1 Regression: list_tools path must apply _resolve_oauth2_flow to legacy DB
|
||||
rows where oauth2_flow is NULL but M2M credentials are present.
|
||||
|
||||
|
||||
Without this fix, has_client_credentials returns False and the caller's
|
||||
Authorization header is forwarded upstream instead of being blocked.
|
||||
"""
|
||||
@ -3044,7 +3182,7 @@ async def test_call_tool_empty_extra_headers_returns_none():
|
||||
"""
|
||||
P2 Regression: When all configured extra_headers are filtered out (e.g.
|
||||
Authorization for M2M), the resulting extra_headers should be None, not {}.
|
||||
|
||||
|
||||
Downstream code that checks `if extra_headers is None` will behave
|
||||
differently if an empty dict is passed instead.
|
||||
"""
|
||||
@ -3071,7 +3209,10 @@ async def test_call_tool_empty_extra_headers_returns_none():
|
||||
extra_headers=["Authorization"], # Will be filtered out for M2M
|
||||
)
|
||||
|
||||
raw_headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
raw_headers = {
|
||||
"Authorization": "Bearer sk-1234",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
captured_extra_headers = None
|
||||
|
||||
@ -3108,8 +3249,8 @@ async def test_call_tool_empty_extra_headers_returns_none():
|
||||
pass # We only care about the captured headers
|
||||
|
||||
# With P2 fix: extra_headers should be None (not {}) when all headers filtered
|
||||
assert captured_extra_headers is None, (
|
||||
"P2 API consistency issue: expected None for empty extra_headers, got: "
|
||||
+ str(captured_extra_headers)
|
||||
assert (
|
||||
captured_extra_headers is None
|
||||
), "P2 API consistency issue: expected None for empty extra_headers, got: " + str(
|
||||
captured_extra_headers
|
||||
)
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user