diff --git a/litellm/proxy/_experimental/mcp_server/db.py b/litellm/proxy/_experimental/mcp_server/db.py index e30667776c..d7b2224eb6 100644 --- a/litellm/proxy/_experimental/mcp_server/db.py +++ b/litellm/proxy/_experimental/mcp_server/db.py @@ -67,6 +67,17 @@ def _prepare_mcp_server_data( # ``alias=None`` is a valid request to clear the stored alias. if data_dict.get("alias") is None and "alias" not in fields_set: data_dict.pop("alias", None) + # Prisma ``allowed_tools`` is a required String[]; ``null`` is invalid. + # The UI sends null to clear a whitelist — treat that as ``[]``. + if "allowed_tools" in data_dict and data_dict["allowed_tools"] is None: + data_dict["allowed_tools"] = [] + # Json map fields use ``@default("{}")``; explicit null means clear overrides. + for json_map_field in ( + "tool_name_to_display_name", + "tool_name_to_description", + ): + if json_map_field in data_dict and data_dict[json_map_field] is None: + data_dict[json_map_field] = {} else: data_dict = data.model_dump(exclude_none=True) # Ensure alias is always present in the dict (even if None) @@ -93,13 +104,13 @@ def _prepare_mcp_server_data( if data_dict.get("env") is not None: data_dict["env"] = safe_dumps(data_dict["env"]) - if data_dict.get("tool_name_to_display_name") is not None: + if "tool_name_to_display_name" in data_dict: data_dict["tool_name_to_display_name"] = safe_dumps( - data_dict["tool_name_to_display_name"] + data_dict["tool_name_to_display_name"] or {} ) - if data_dict.get("tool_name_to_description") is not None: + if "tool_name_to_description" in data_dict: data_dict["tool_name_to_description"] = safe_dumps( - data_dict["tool_name_to_description"] + data_dict["tool_name_to_description"] or {} ) # mcp_access_groups is already List[str], no serialization needed diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index f35aa30a7c..b4678a50b2 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -2429,7 +2429,13 @@ class MCPServerManager: """ Check if the tool is allowed or banned for the given server """ - if server.allowed_tools: + from litellm.proxy._experimental.mcp_server.utils import ( + server_applies_tool_allowlist, + ) + + if server_applies_tool_allowlist(server): + if not server.allowed_tools: + return False return ( tool_name in server.allowed_tools or f"{server.name}-{tool_name}" in server.allowed_tools diff --git a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py index cec5224e18..693ca5a764 100644 --- a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py @@ -365,10 +365,9 @@ if MCP_AVAILABLE: user_api_key_auth=user_api_key_auth, ) - # Filter tools based on allowed_tools configuration - # Only filter if allowed_tools is explicitly configured (not None and not empty) - if server.allowed_tools is not None and len(server.allowed_tools) > 0: - tools = filter_tools_by_allowed_tools(tools, server) + # Always apply allowed_tools/disallowed_tools so the blacklist is + # enforced even when no allowlist is set (matches the SSE/HTTP path). + tools = filter_tools_by_allowed_tools(tools, server) # Filter tools based on user_api_key_auth.object_permission.mcp_tool_permissions # This provides per-key/team/org control over which tools can be accessed diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index a05ce3f741..c17ec13d3e 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -945,10 +945,16 @@ if MCP_AVAILABLE: Returns: Filtered list of tools """ + from litellm.proxy._experimental.mcp_server.utils import ( + server_applies_tool_allowlist, + ) + tools_to_return = tools # Filter by allowed_tools (whitelist) - if mcp_server.allowed_tools: + if server_applies_tool_allowlist(mcp_server): + if not mcp_server.allowed_tools: + return [] tools_to_return = [ tool for tool in tools diff --git a/litellm/proxy/_experimental/mcp_server/utils.py b/litellm/proxy/_experimental/mcp_server/utils.py index b8b9207555..b66dfa85b9 100644 --- a/litellm/proxy/_experimental/mcp_server/utils.py +++ b/litellm/proxy/_experimental/mcp_server/utils.py @@ -2,6 +2,7 @@ MCP Server Utilities """ +import json import re from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Union @@ -162,6 +163,36 @@ def lookup_mcp_server_auth_in_headers( return None +MCP_TOOL_ALLOWLIST_ENFORCED_KEY = "tool_allowlist_enforced" + + +def _parse_mcp_info_dict(mcp_info: Any) -> Optional[Dict[str, Any]]: + if mcp_info is None: + return None + if isinstance(mcp_info, dict): + return mcp_info + if isinstance(mcp_info, str): + try: + parsed = json.loads(mcp_info) + except (ValueError, TypeError): + return None + return parsed if isinstance(parsed, dict) else None + return None + + +def is_server_tool_allowlist_enforced(mcp_server: Any) -> bool: + mcp_info = _parse_mcp_info_dict(getattr(mcp_server, "mcp_info", None)) + if not mcp_info: + return False + return bool(mcp_info.get(MCP_TOOL_ALLOWLIST_ENFORCED_KEY)) + + +def server_applies_tool_allowlist(mcp_server: Any) -> bool: + """Whether server-level allowed_tools whitelist filtering is active.""" + allowed_tools = getattr(mcp_server, "allowed_tools", None) or [] + return is_server_tool_allowlist_enforced(mcp_server) or bool(allowed_tools) + + def validate_and_normalize_mcp_server_payload(payload: Any) -> None: """ Validate and normalize MCP server payload fields (server_name and alias). diff --git a/tests/mcp_tests/test_mcp_server.py b/tests/mcp_tests/test_mcp_server.py index d76ebb0072..e65b45fb38 100644 --- a/tests/mcp_tests/test_mcp_server.py +++ b/tests/mcp_tests/test_mcp_server.py @@ -1862,9 +1862,11 @@ async def test_get_tools_for_single_server(): ) from mcp.types import Tool as MCPTool - # Create a mock server + # Create a mock server (pin allowlist fields; MagicMock auto-attrs are truthy) mock_server = MagicMock() mock_server.mcp_info = {"server_name": "zapier"} + mock_server.allowed_tools = None + mock_server.disallowed_tools = None # Create mock tools mock_tools = [ @@ -1899,6 +1901,44 @@ async def test_get_tools_for_single_server(): assert result[0].mcp_info == {"server_name": "zapier"} +@pytest.mark.asyncio +async def test_get_tools_for_single_server_applies_disallowed_tools_without_allowlist(): + """REST listing must honor disallowed_tools even when no allowlist is set.""" + from litellm.proxy._experimental.mcp_server.rest_endpoints import ( + _get_tools_for_single_server, + ) + from mcp.types import Tool as MCPTool + + mock_server = MagicMock() + mock_server.mcp_info = {"server_name": "zapier"} + mock_server.name = "zapier" + mock_server.server_id = "zapier" + mock_server.allowed_tools = None + mock_server.disallowed_tools = ["send_email"] + + mock_tools = [ + MCPTool( + name="send_email", + description="Send an email", + inputSchema={"type": "object"}, + ), + MCPTool( + name="read_email", + description="Read an email", + inputSchema={"type": "object"}, + ), + ] + + with patch( + "litellm.proxy._experimental.mcp_server.rest_endpoints.global_mcp_server_manager" + ) as mock_manager: + mock_manager._get_tools_from_server = AsyncMock(return_value=mock_tools) + + result = await _get_tools_for_single_server(mock_server, "Bearer test_token") + + assert [tool.name for tool in result] == ["read_email"] + + @pytest.mark.asyncio async def test_list_tool_rest_api_with_server_specific_auth(): """Test list_tool_rest_api with server-specific auth headers.""" diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_partial_update.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_partial_update.py index b5e0f20f66..49facdbaea 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_partial_update.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_partial_update.py @@ -68,6 +68,34 @@ async def test_partial_update_omits_unset_defaultful_fields(): ) +@pytest.mark.asyncio +async def test_partial_update_null_tool_name_maps_clear_to_empty_json(): + """Explicit null on Json map fields must clear overrides (UI legacy).""" + data = UpdateMCPServerRequest( + server_id="my-test-server", + tool_name_to_display_name=None, + tool_name_to_description=None, + ) + + data_dict = await _run_update(data) + + assert data_dict["tool_name_to_display_name"] == "{}" + assert data_dict["tool_name_to_description"] == "{}" + + +@pytest.mark.asyncio +async def test_partial_update_null_allowed_tools_clears_whitelist(): + """Explicit null must clear the whitelist (UI legacy); Prisma requires [].""" + data = UpdateMCPServerRequest( + server_id="my-test-server", + allowed_tools=None, + ) + + data_dict = await _run_update(data) + + assert data_dict["allowed_tools"] == [] + + @pytest.mark.asyncio async def test_partial_update_preserves_http_transport(): """The reported prod incident: a PUT without transport must not flip http->sse.""" diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py index fb21e4ee11..bb0cc86037 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py @@ -4184,6 +4184,85 @@ def test_filter_tools_by_allowed_tools_no_filter(): assert len(filtered_tools) == 2 +def test_filter_tools_enforced_empty_allowlist_blocks_all(): + from mcp.types import Tool + + from litellm.proxy._experimental.mcp_server.server import ( + filter_tools_by_allowed_tools, + ) + from litellm.types.mcp import MCPTransport + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + tools = [ + Tool( + name="read_wiki_structure", + title=None, + description="", + inputSchema={"type": "object"}, + outputSchema=None, + annotations=None, + ), + ] + server = MCPServer( + server_id="deepwiki", + name="deepwiki", + transport=MCPTransport.http, + allowed_tools=[], + mcp_info={"tool_allowlist_enforced": True}, + ) + + assert filter_tools_by_allowed_tools(tools, server) == [] + + +def test_filter_tools_legacy_empty_allowlist_allows_all(): + from mcp.types import Tool + + from litellm.proxy._experimental.mcp_server.server import ( + filter_tools_by_allowed_tools, + ) + from litellm.types.mcp import MCPTransport + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + tools = [ + Tool( + name="read_wiki_structure", + title=None, + description="", + inputSchema={"type": "object"}, + outputSchema=None, + annotations=None, + ), + ] + server = MCPServer( + server_id="legacy", + name="legacy", + transport=MCPTransport.http, + allowed_tools=[], + mcp_info=None, + ) + + assert len(filter_tools_by_allowed_tools(tools, server)) == 1 + + +def test_check_allowed_or_banned_tools_enforced_empty_denies_calls(): + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.types.mcp import MCPTransport + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + manager = MCPServerManager.__new__(MCPServerManager) + server = MCPServer( + server_id="deepwiki", + name="deepwiki", + transport=MCPTransport.http, + allowed_tools=[], + mcp_info={"tool_allowlist_enforced": True}, + ) + + assert manager.check_allowed_or_banned_tools("read_wiki_structure", server) is False + + @pytest.mark.asyncio async def test_get_tools_from_mcp_servers_injects_stored_oauth2_token(): """ @@ -4540,9 +4619,9 @@ class TestEnsureUpstreamInitializeInstructionsCached: await global_mcp_server_manager._ensure_upstream_initialize_instructions_cached( server ) - assert create.await_count == 1, ( - "Second probe within cooldown must not reconnect to upstream" - ) + assert ( + create.await_count == 1 + ), "Second probe within cooldown must not reconnect to upstream" assert ( "empty-server" not in global_mcp_server_manager._upstream_initialize_instructions_by_server_id @@ -4567,7 +4646,9 @@ class TestEnsureUpstreamInitializeInstructionsCached: server = _make_instruction_server(server_id="boom-server", instructions=None) fake_client = MagicMock() - fake_client.run_with_session = AsyncMock(side_effect=RuntimeError("upstream down")) + fake_client.run_with_session = AsyncMock( + side_effect=RuntimeError("upstream down") + ) fake_client._last_initialize_instructions = None create = AsyncMock(return_value=fake_client) @@ -4579,9 +4660,9 @@ class TestEnsureUpstreamInitializeInstructionsCached: await global_mcp_server_manager._ensure_upstream_initialize_instructions_cached( server ) - assert create.await_count == 1, ( - "Second probe within cooldown must not reconnect after failure" - ) + assert ( + create.await_count == 1 + ), "Second probe within cooldown must not reconnect after failure" assert ( "boom-server" not in global_mcp_server_manager._upstream_initialize_instructions_by_server_id diff --git a/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.test.tsx b/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.test.tsx index b425126713..d635d7bb6b 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.test.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.test.tsx @@ -27,7 +27,19 @@ vi.mock("./MCPPermissionManagement", () => ({ })); vi.mock("./mcp_tool_configuration", () => ({ - default: () =>
, + default: ({ onAllowedToolsChange, onToolAllowlistInteraction }: any) => ( +
+ +
+ ), })); vi.mock("./mcp_connection_status", () => ({ @@ -335,6 +347,50 @@ describe("CreateMCPServer", () => { // No credentials should be sent for "none" auth expect(payload.credentials).toBeUndefined(); }); + + it("enforces the allowlist when the user explicitly deselects every tool", async () => { + await selectHttpTransport(); + + const user = userEvent.setup({ delay: null }); + + const nameInput = getServerNameInput(); + await user.type(nameInput, "Locked_Down_Server"); + + const urlInput = screen.getByPlaceholderText("https://your-mcp-server.com"); + await user.type(urlInput, "https://example.com/mcp"); + + await selectAntOption("Authentication", "None"); + + await act(async () => { + fireEvent.click(screen.getByRole("button", { name: "Disable all tools" })); + }); + + vi.mocked(networking.createMCPServer).mockResolvedValue({ + server_id: "new-server-1", + server_name: "Locked_Down_Server", + alias: "Locked_Down_Server", + url: "https://example.com/mcp", + transport: "http", + auth_type: "none", + created_at: "2024-01-01T00:00:00Z", + created_by: "user-1", + updated_at: "2024-01-01T00:00:00Z", + updated_by: "user-1", + }); + + const submitButton = screen.getByRole("button", { name: "Add MCP Server" }); + await act(async () => { + fireEvent.click(submitButton); + }); + + await waitFor(() => { + expect(networking.createMCPServer).toHaveBeenCalledTimes(1); + }); + + const [, payload] = vi.mocked(networking.createMCPServer).mock.calls[0]; + expect(payload.mcp_info.tool_allowlist_enforced).toBe(true); + expect(payload.allowed_tools).toEqual([]); + }); }); describe("when OAuth interactive auth is selected", () => { diff --git a/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.tsx b/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.tsx index 108911bdbf..784de6e03c 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.tsx @@ -69,6 +69,7 @@ const CreateMCPServer: React.FC = ({ } | null>(null); const [aliasManuallyEdited, setAliasManuallyEdited] = useState(false); const [allowedTools, setAllowedTools] = useState([]); + const [hasToolAllowlistInteraction, setHasToolAllowlistInteraction] = useState(false); const [toolNameToDisplayName, setToolNameToDisplayName] = useState>({}); const [toolNameToDescription, setToolNameToDescription] = useState>({}); const [transportType, setTransportType] = useState(""); @@ -106,6 +107,7 @@ const CreateMCPServer: React.FC = ({ transportType, costConfig, allowedTools, + hasToolAllowlistInteraction, searchValue, aliasManuallyEdited, logoUrl, @@ -204,6 +206,9 @@ const CreateMCPServer: React.FC = ({ if (parsed.allowedTools) { setAllowedTools(parsed.allowedTools); } + if (typeof parsed.hasToolAllowlistInteraction === "boolean") { + setHasToolAllowlistInteraction(parsed.hasToolAllowlistInteraction); + } if (parsed.searchValue) { setSearchValue(parsed.searchValue); } @@ -384,12 +389,13 @@ const CreateMCPServer: React.FC = ({ description: restValues.description, logo_url: logoUrl || undefined, mcp_server_cost_info: Object.keys(costConfig).length > 0 ? costConfig : null, + tool_allowlist_enforced: hasToolAllowlistInteraction || allowedTools.length > 0, }, mcp_access_groups: accessGroups, alias: restValues.alias, - allowed_tools: allowedTools.length > 0 ? allowedTools : null, - tool_name_to_display_name: Object.keys(toolNameToDisplayName).length > 0 ? toolNameToDisplayName : null, - tool_name_to_description: Object.keys(toolNameToDescription).length > 0 ? toolNameToDescription : null, + allowed_tools: allowedTools, + tool_name_to_display_name: toolNameToDisplayName, + tool_name_to_description: toolNameToDescription, allow_all_keys: Boolean(allowAllKeysRaw), available_on_public_internet: Boolean(availableOnPublicInternetRaw), delegate_auth_to_upstream: Boolean(delegateAuthToUpstreamRaw), @@ -436,6 +442,7 @@ const CreateMCPServer: React.FC = ({ setCostConfig({}); clearTools(); setAllowedTools([]); + setHasToolAllowlistInteraction(false); setAliasManuallyEdited(false); setLogoUrl(undefined); setModalVisible(false); @@ -457,6 +464,7 @@ const CreateMCPServer: React.FC = ({ setCostConfig({}); clearTools(); setAllowedTools([]); + setHasToolAllowlistInteraction(false); setAliasManuallyEdited(false); setLogoUrl(undefined); setModalVisible(false); @@ -1040,6 +1048,8 @@ const CreateMCPServer: React.FC = ({ allowedTools={allowedTools} existingAllowedTools={null} onAllowedToolsChange={setAllowedTools} + hasToolAllowlistInteraction={hasToolAllowlistInteraction} + onToolAllowlistInteraction={() => setHasToolAllowlistInteraction(true)} toolNameToDisplayName={toolNameToDisplayName} toolNameToDescription={toolNameToDescription} onToolNameToDisplayNameChange={setToolNameToDisplayName} diff --git a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.test.tsx b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.test.tsx index 1f2864f675..ed3b22a569 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.test.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.test.tsx @@ -35,7 +35,34 @@ vi.mock("./MCPPermissionManagement", () => ({ })); vi.mock("./mcp_tool_configuration", () => ({ - default: () =>
, + default: ({ + existingAllowedTools, + onAllowedToolsChange, + onToolAllowlistInteraction, + onToolNameToDisplayNameChange, + onToolNameToDescriptionChange, + }: any) => ( +
+ + +
+ ), })); // ── fixtures ────────────────────────────────────────────────────────────────── @@ -43,7 +70,7 @@ vi.mock("./mcp_tool_configuration", () => ({ const interactiveOAuthServer = { server_id: "oauth_server_1", server_name: "OAuthServer", - alias: "oauth_server", // underscores: hyphens fail validateMCPServerName + alias: "oauth_server", // underscores: hyphens fail validateMCPServerName description: "Interactive OAuth MCP server", transport: "http", url: "https://example.com/mcp", @@ -218,6 +245,128 @@ describe("MCPServerEdit (delegate auth)", () => { }); }); +describe("MCPServerEdit (tool allowlist)", () => { + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("treats legacy empty allowed_tools as unrestricted", () => { + render( + , + ); + + expect(screen.getByTestId("mcp-tool-config")).toHaveAttribute("data-existing-allowed-tools", "null"); + }); + + it("honors enforced empty allowed_tools", () => { + render( + , + ); + + expect(screen.getByTestId("mcp-tool-config")).toHaveAttribute("data-existing-allowed-tools", "[]"); + }); + + it("saves an explicit empty allowlist after legacy unrestricted tools are disabled", async () => { + vi.mocked(networking.updateMCPServer).mockResolvedValue({ + ...interactiveOAuthServer, + allowed_tools: [], + mcp_info: { server_name: "OAuthServer", tool_allowlist_enforced: true }, + }); + + render( + , + ); + + await act(async () => { + fireEvent.click(screen.getByRole("button", { name: "Disable all tools" })); + }); + + const saveButtons = screen.getAllByRole("button", { name: "Save Changes" }); + await act(async () => { + fireEvent.click(saveButtons[0]); + }); + + await waitFor(() => { + expect(networking.updateMCPServer).toHaveBeenCalledTimes(1); + }); + + const [, payload] = vi.mocked(networking.updateMCPServer).mock.calls[0]; + expect(payload.mcp_info.tool_allowlist_enforced).toBe(true); + expect(payload.allowed_tools).toEqual([]); + }); + + it("saves tool overrides for legacy unrestricted servers", async () => { + vi.mocked(networking.updateMCPServer).mockResolvedValue({ + ...interactiveOAuthServer, + tool_name_to_display_name: { read_user: "Read User" }, + tool_name_to_description: { read_user: "Reads users" }, + }); + + render( + , + ); + + await act(async () => { + fireEvent.click(screen.getByRole("button", { name: "Set tool overrides" })); + }); + + const saveButtons = screen.getAllByRole("button", { name: "Save Changes" }); + await act(async () => { + fireEvent.click(saveButtons[0]); + }); + + await waitFor(() => { + expect(networking.updateMCPServer).toHaveBeenCalledTimes(1); + }); + + const [, payload] = vi.mocked(networking.updateMCPServer).mock.calls[0]; + expect(payload.mcp_info.tool_allowlist_enforced).toBe(false); + expect(payload.allowed_tools).toBeUndefined(); + expect(payload.tool_name_to_display_name).toEqual({ read_user: "Read User" }); + expect(payload.tool_name_to_description).toEqual({ read_user: "Reads users" }); + }); +}); + describe("MCPServerEdit (interactive OAuth)", () => { beforeEach(() => { vi.clearAllMocks(); diff --git a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.tsx b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.tsx index 9278d41c3e..ab9c9ed668 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/mcp_server_edit.tsx @@ -41,6 +41,7 @@ const MCPServerEdit: React.FC = ({ const [searchValue, setSearchValue] = useState(""); const [aliasManuallyEdited, setAliasManuallyEdited] = useState(false); const [allowedTools, setAllowedTools] = useState([]); + const [hasToolAllowlistInteraction, setHasToolAllowlistInteraction] = useState(false); const [toolNameToDisplayName, setToolNameToDisplayName] = useState>({}); const [toolNameToDescription, setToolNameToDescription] = useState>({}); const [pendingRestoredValues, setPendingRestoredValues] = useState | null>(null); @@ -68,6 +69,9 @@ const MCPServerEdit: React.FC = ({ const currentAuthorizationUrl = Form.useWatch("authorization_url", form); const currentTokenUrl = Form.useWatch("token_url", form); const currentRegistrationUrl = Form.useWatch("registration_url", form); + const hasExistingToolAllowlist = + Boolean(mcpServer.mcp_info?.tool_allowlist_enforced) || (mcpServer.allowed_tools?.length ?? 0) > 0; + const existingAllowedTools = hasExistingToolAllowlist ? mcpServer.allowed_tools ?? [] : null; const persistEditUiState = () => { if (typeof window === "undefined") { @@ -82,6 +86,7 @@ const MCPServerEdit: React.FC = ({ formValues: values, costConfig, allowedTools, + hasToolAllowlistInteraction, searchValue, aliasManuallyEdited, }), @@ -135,7 +140,7 @@ const MCPServerEdit: React.FC = ({ }, onTokenReceived: (token) => { setOauthAccessToken(token?.access_token ?? null); - + if (token?.access_token) { const credentials = { access_token: token.access_token, @@ -143,11 +148,11 @@ const MCPServerEdit: React.FC = ({ ...(token.expires_in && { expires_in: token.expires_in }), ...(token.scope && { scope: token.scope }), }; - + form.setFieldsValue({ credentials }); - + NotificationsManager.success( - "OAuth authorization successful! Please click 'Update MCP Server' to save the credentials." + "OAuth authorization successful! Please click 'Update MCP Server' to save the credentials.", ); } }, @@ -176,7 +181,6 @@ const MCPServerEdit: React.FC = ({ } }, [mcpServer.env]); - // If server has spec_path, show it as "openapi" transport in the UI const effectiveTransport = React.useMemo(() => { if (mcpServer.spec_path && mcpServer.transport !== "stdio") { @@ -208,12 +212,16 @@ const MCPServerEdit: React.FC = ({ // Initialize allowed tools and tool overrides from existing server data useEffect(() => { - if (mcpServer.allowed_tools) { - setAllowedTools(mcpServer.allowed_tools); + setHasToolAllowlistInteraction(false); + }, [mcpServer.server_id]); + + useEffect(() => { + if (hasExistingToolAllowlist) { + setAllowedTools(mcpServer.allowed_tools ?? []); } setToolNameToDisplayName(mcpServer.tool_name_to_display_name ?? {}); setToolNameToDescription(mcpServer.tool_name_to_description ?? {}); - }, [mcpServer]); + }, [mcpServer, hasExistingToolAllowlist]); useEffect(() => { if (typeof window === "undefined") { @@ -238,6 +246,9 @@ const MCPServerEdit: React.FC = ({ if (parsed.allowedTools) { setAllowedTools(parsed.allowedTools); } + if (typeof parsed.hasToolAllowlistInteraction === "boolean") { + setHasToolAllowlistInteraction(parsed.hasToolAllowlistInteraction); + } if (parsed.searchValue) { setSearchValue(parsed.searchValue); } @@ -529,6 +540,8 @@ const MCPServerEdit: React.FC = ({ mcpServer.alias || "unknown"; + const toolAllowlistEnforced = hasExistingToolAllowlist || hasToolAllowlistInteraction || allowedTools.length > 0; + const payload: Record = { ...restValues, ...stdioFields, @@ -537,16 +550,22 @@ const MCPServerEdit: React.FC = ({ env_json: undefined, server_id: mcpServer.server_id, mcp_info: { + ...(mcpServer.mcp_info ?? {}), server_name: mcpInfoServerName, description: restValues.description, logo_url: logoUrl || undefined, mcp_server_cost_info: Object.keys(costConfig).length > 0 ? costConfig : null, + tool_allowlist_enforced: toolAllowlistEnforced, }, mcp_access_groups: accessGroups, alias: restValues.alias, // Include permission management fields extra_headers: restValues.extra_headers || [], - allowed_tools: allowedTools.length > 0 ? allowedTools : null, + ...(toolAllowlistEnforced + ? { + allowed_tools: allowedTools, + } + : {}), tool_name_to_display_name: Object.keys(toolNameToDisplayName).length > 0 ? toolNameToDisplayName : null, tool_name_to_description: Object.keys(toolNameToDescription).length > 0 ? toolNameToDescription : null, disallowed_tools: restValues.disallowed_tools || [], @@ -563,12 +582,11 @@ const MCPServerEdit: React.FC = ({ ? Boolean(delegateAuthToUpstreamRaw ?? mcpServer.delegate_auth_to_upstream) : false, // Include token_validation when it is set (non-null) or when clearing an existing value - ...(tokenValidation !== null || mcpServer.token_validation - ? { token_validation: tokenValidation } - : {}), + ...(tokenValidation !== null || mcpServer.token_validation ? { token_validation: tokenValidation } : {}), }; - const includeCredentials = restValues.auth_type && AUTH_TYPES_REQUIRING_CREDENTIALS.includes(restValues.auth_type); + const includeCredentials = + restValues.auth_type && AUTH_TYPES_REQUIRING_CREDENTIALS.includes(restValues.auth_type); if (includeCredentials && credentialsPayload && Object.keys(credentialsPayload).length > 0) { payload.credentials = credentialsPayload; @@ -700,10 +718,7 @@ const MCPServerEdit: React.FC = ({ /> - +