diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/mcpServers/useMCPAccessGroups.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/mcpServers/useMCPAccessGroups.ts new file mode 100644 index 0000000000..eeeb76bb74 --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/mcpServers/useMCPAccessGroups.ts @@ -0,0 +1,13 @@ +import { useQuery } from "@tanstack/react-query"; +import { createQueryKeys } from "../common/queryKeysFactory"; +import { fetchMCPAccessGroups } from "@/components/networking"; + +const mcpAccessGroupsKeys = createQueryKeys("mcpAccessGroups"); + +export const useMCPAccessGroups = (accessToken: string | null) => { + return useQuery({ + queryKey: mcpAccessGroupsKeys.list({}), + queryFn: async () => await fetchMCPAccessGroups(accessToken!), + enabled: !!accessToken, + }); +}; diff --git a/ui/litellm-dashboard/src/app/(dashboard)/hooks/mcpServers/useMCPServers.ts b/ui/litellm-dashboard/src/app/(dashboard)/hooks/mcpServers/useMCPServers.ts new file mode 100644 index 0000000000..02e471d8e5 --- /dev/null +++ b/ui/litellm-dashboard/src/app/(dashboard)/hooks/mcpServers/useMCPServers.ts @@ -0,0 +1,14 @@ +import { useQuery } from "@tanstack/react-query"; +import { createQueryKeys } from "../common/queryKeysFactory"; +import { fetchMCPServers } from "@/components/networking"; +import { MCPServer } from "@/components/mcp_tools/types"; + +const mcpServersKeys = createQueryKeys("mcpServers"); + +export const useMCPServers = (accessToken: string | null) => { + return useQuery({ + queryKey: mcpServersKeys.list({}), + queryFn: async () => await fetchMCPServers(accessToken!), + enabled: !!accessToken, + }); +}; diff --git a/ui/litellm-dashboard/src/components/mcp_server_management/MCPServerSelector.tsx b/ui/litellm-dashboard/src/components/mcp_server_management/MCPServerSelector.tsx index 7b79f5cc70..f30a76a229 100644 --- a/ui/litellm-dashboard/src/components/mcp_server_management/MCPServerSelector.tsx +++ b/ui/litellm-dashboard/src/components/mcp_server_management/MCPServerSelector.tsx @@ -1,15 +1,13 @@ -import React, { useEffect, useState } from "react"; +import React from "react"; import { Select } from "antd"; -import { fetchMCPServers, fetchMCPAccessGroups } from "../networking"; +import { useMCPServers } from "@/app/(dashboard)/hooks/mcpServers/useMCPServers"; +import { useMCPAccessGroups } from "@/app/(dashboard)/hooks/mcpServers/useMCPAccessGroups"; import { MCPServer } from "../mcp_tools/types"; interface MCPServerSelectorProps { - onChange: (selected: { - servers: string[]; - accessGroups: string[]; - }) => void; - value?: { - servers: string[]; + onChange: (selected: { servers: string[]; accessGroups: string[] }) => void; + value?: { + servers: string[]; accessGroups: string[]; }; className?: string; @@ -26,31 +24,10 @@ const MCPServerSelector: React.FC = ({ placeholder = "Select MCP servers", disabled = false, }) => { - const [mcpServers, setMCPServers] = useState([]); - const [accessGroups, setAccessGroups] = useState([]); - const [loading, setLoading] = useState(false); + const { data: mcpServers = [], isLoading: serversLoading } = useMCPServers(accessToken); + const { data: accessGroups = [], isLoading: groupsLoading } = useMCPAccessGroups(accessToken); - useEffect(() => { - const fetchData = async () => { - if (!accessToken) return; - setLoading(true); - try { - const [serversRes, groupsRes] = await Promise.all([ - fetchMCPServers(accessToken), - fetchMCPAccessGroups(accessToken), - ]); - let servers = Array.isArray(serversRes) ? serversRes : serversRes.data || []; - let groups = Array.isArray(groupsRes) ? groupsRes : groupsRes.data || []; - setMCPServers(servers); - setAccessGroups(groups); - } catch (error) { - console.error("Error fetching MCP servers or access groups:", error); - } finally { - setLoading(false); - } - }; - fetchData(); - }, [accessToken]); + const loading = serversLoading || groupsLoading; // Combine options, access groups first const options = [ diff --git a/ui/litellm-dashboard/src/components/mcp_server_management/MCPToolPermissions.test.tsx b/ui/litellm-dashboard/src/components/mcp_server_management/MCPToolPermissions.test.tsx index 93f96966d0..ec541d7a63 100644 --- a/ui/litellm-dashboard/src/components/mcp_server_management/MCPToolPermissions.test.tsx +++ b/ui/litellm-dashboard/src/components/mcp_server_management/MCPToolPermissions.test.tsx @@ -1,11 +1,22 @@ import { describe, it, expect, vi, beforeEach } from "vitest"; import { render, screen, waitFor } from "@testing-library/react"; import userEvent from "@testing-library/user-event"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; import MCPToolPermissions from "./MCPToolPermissions"; import * as networking from "../networking"; vi.mock("../networking"); +const createQueryClient = () => + new QueryClient({ + defaultOptions: { + queries: { + retry: false, + gcTime: 0, + }, + }, + }); + describe("MCPToolPermissions", () => { const mockAccessToken = "test-token"; const mockServerId = "server-123"; @@ -28,15 +39,13 @@ describe("MCPToolPermissions", () => { ]; // Mock fetchMCPServers to return server details - vi.mocked(networking.fetchMCPServers).mockResolvedValue({ - data: [ - { - server_id: mockServerId, - server_name: mockServerName, - alias: mockServerName, - }, - ], - }); + vi.mocked(networking.fetchMCPServers).mockResolvedValue([ + { + server_id: mockServerId, + server_name: mockServerName, + alias: mockServerName, + }, + ]); // Mock listMCPTools to return tools for the server vi.mocked(networking.listMCPTools).mockResolvedValue({ @@ -44,13 +53,16 @@ describe("MCPToolPermissions", () => { error: false, }); + const queryClient = createQueryClient(); render( - + + + , ); // Wait for server and tools to load @@ -76,4 +88,3 @@ describe("MCPToolPermissions", () => { expect(networking.listMCPTools).toHaveBeenCalledWith(mockAccessToken, mockServerId); }); }); - diff --git a/ui/litellm-dashboard/src/components/mcp_server_management/MCPToolPermissions.tsx b/ui/litellm-dashboard/src/components/mcp_server_management/MCPToolPermissions.tsx index 3f36caf467..a567791e22 100644 --- a/ui/litellm-dashboard/src/components/mcp_server_management/MCPToolPermissions.tsx +++ b/ui/litellm-dashboard/src/components/mcp_server_management/MCPToolPermissions.tsx @@ -1,9 +1,10 @@ -import React, { useEffect, useState } from "react"; -import { listMCPTools, fetchMCPServers } from "../networking"; +import React, { useEffect, useState, useMemo } from "react"; +import { listMCPTools } from "../networking"; import { MCPTool, MCPServer } from "../mcp_tools/types"; import { Text } from "@tremor/react"; import { Spin, Checkbox } from "antd"; import { XIcon } from "lucide-react"; +import { useMCPServers } from "../../app/(dashboard)/hooks/mcpServers/useMCPServers"; interface MCPToolPermissionsProps { accessToken: string; @@ -20,63 +21,43 @@ const MCPToolPermissions: React.FC = ({ onChange, disabled = false, }) => { - const [servers, setServers] = useState([]); + const { data: allServers = [] } = useMCPServers(accessToken); const [serverTools, setServerTools] = useState>({}); const [loadingTools, setLoadingTools] = useState>({}); const [toolErrors, setToolErrors] = useState>({}); - // Fetch server details - useEffect(() => { - const loadServerDetails = async () => { - if (selectedServers.length === 0) { - setServers([]); - return; - } - - try { - const response = await fetchMCPServers(accessToken); - const allServers = Array.isArray(response) ? response : response.data || []; - - const filteredServers = allServers.filter((server: MCPServer) => - selectedServers.includes(server.server_id) - ); - - setServers(filteredServers); - } catch (error) { - console.error("Error fetching MCP servers:", error); - setServers([]); - } - }; - - loadServerDetails(); - }, [selectedServers, accessToken]); + // Filter servers based on selectedServers + const servers = useMemo(() => { + if (selectedServers.length === 0) return []; + return allServers.filter((server: MCPServer) => selectedServers.includes(server.server_id)); + }, [allServers, selectedServers]); // Fetch tools for a specific server const fetchToolsForServer = async (serverId: string) => { - setLoadingTools(prev => ({ ...prev, [serverId]: true })); - setToolErrors(prev => ({ ...prev, [serverId]: "" })); - + setLoadingTools((prev) => ({ ...prev, [serverId]: true })); + setToolErrors((prev) => ({ ...prev, [serverId]: "" })); + try { const response = await listMCPTools(accessToken, serverId); - + if (response.error) { - setToolErrors(prev => ({ ...prev, [serverId]: response.message || "Failed to fetch tools" })); - setServerTools(prev => ({ ...prev, [serverId]: [] })); + setToolErrors((prev) => ({ ...prev, [serverId]: response.message || "Failed to fetch tools" })); + setServerTools((prev) => ({ ...prev, [serverId]: [] })); } else { - setServerTools(prev => ({ ...prev, [serverId]: response.tools || [] })); + setServerTools((prev) => ({ ...prev, [serverId]: response.tools || [] })); } } catch (err) { console.error(`Error fetching tools for server ${serverId}:`, err); - setToolErrors(prev => ({ ...prev, [serverId]: "Failed to fetch tools" })); - setServerTools(prev => ({ ...prev, [serverId]: [] })); + setToolErrors((prev) => ({ ...prev, [serverId]: "Failed to fetch tools" })); + setServerTools((prev) => ({ ...prev, [serverId]: [] })); } finally { - setLoadingTools(prev => ({ ...prev, [serverId]: false })); + setLoadingTools((prev) => ({ ...prev, [serverId]: false })); } }; // Auto-fetch tools when servers change useEffect(() => { - servers.forEach(server => { + servers.forEach((server) => { if (!serverTools[server.server_id] && !loadingTools[server.server_id]) { fetchToolsForServer(server.server_id); } @@ -87,9 +68,9 @@ const MCPToolPermissions: React.FC = ({ const handleToolToggle = (serverId: string, toolName: string) => { const currentTools = toolPermissions[serverId] || []; const newTools = currentTools.includes(toolName) - ? currentTools.filter(name => name !== toolName) + ? currentTools.filter((name) => name !== toolName) : [...currentTools, toolName]; - + const updatedPermissions = { ...toolPermissions, [serverId]: newTools, @@ -101,7 +82,7 @@ const MCPToolPermissions: React.FC = ({ const tools = serverTools[serverId] || []; onChange({ ...toolPermissions, - [serverId]: tools.map(t => t.name), + [serverId]: tools.map((t) => t.name), }); }; @@ -131,9 +112,7 @@ const MCPToolPermissions: React.FC = ({
{serverName} - {server.description && ( - {server.description} - )} + {server.description && {server.description}}
diff --git a/ui/litellm-dashboard/src/components/mcp_tools/mcp_servers.tsx b/ui/litellm-dashboard/src/components/mcp_tools/mcp_servers.tsx index a46738ab36..83393c4a94 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/mcp_servers.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/mcp_servers.tsx @@ -1,11 +1,11 @@ import { isAdminRole } from "@/utils/roles"; import { QuestionCircleOutlined } from "@ant-design/icons"; -import { useQuery } from "@tanstack/react-query"; import { Button, Tab, TabGroup, TabList, TabPanel, TabPanels, Text, Title } from "@tremor/react"; import { Descriptions, Modal, Select, Tooltip, Typography } from "antd"; import React, { useEffect, useState } from "react"; +import { useMCPServers } from "../../app/(dashboard)/hooks/mcpServers/useMCPServers"; import NotificationsManager from "../molecules/notifications_manager"; -import { deleteMCPServer, fetchMCPServers } from "../networking"; +import { deleteMCPServer } from "../networking"; import { DataTable } from "../view_logs/table"; import CreateMCPServer from "./create_mcp_server"; import MCPConnect from "./mcp_connect"; @@ -19,19 +19,7 @@ const EDIT_OAUTH_UI_STATE_KEY = "litellm-mcp-oauth-edit-state"; const { Option } = Select; const MCPServers: React.FC = ({ accessToken, userRole, userID }) => { - const { - data: mcpServers, - isLoading: isLoadingServers, - refetch, - dataUpdatedAt, - } = useQuery({ - queryKey: ["mcpServers"], - queryFn: () => { - if (!accessToken) throw new Error("Access Token required"); - return fetchMCPServers(accessToken); - }, - enabled: !!accessToken, - }) as { data: MCPServer[]; isLoading: boolean; refetch: () => void; dataUpdatedAt: number }; + const { data: mcpServers, isLoading: isLoadingServers, refetch, dataUpdatedAt } = useMCPServers(accessToken); // Log allowed_tools from fetched servers React.useEffect(() => { diff --git a/ui/litellm-dashboard/src/components/organisms/create_key_button.tsx b/ui/litellm-dashboard/src/components/organisms/create_key_button.tsx index 3c0a0f520d..f3ad803e27 100644 --- a/ui/litellm-dashboard/src/components/organisms/create_key_button.tsx +++ b/ui/litellm-dashboard/src/components/organisms/create_key_button.tsx @@ -1,43 +1,40 @@ "use client"; -import React, { useState, useEffect, useCallback } from "react"; -import { Button, TextInput, Grid, Col } from "@tremor/react"; -import { Text, Title, Accordion, AccordionHeader, AccordionBody } from "@tremor/react"; -import { CopyToClipboard } from "react-copy-to-clipboard"; -import { Button as Button2, Modal, Form, Input, Select, Radio, Switch } from "antd"; -import NumericalInput from "../shared/numerical_input"; -import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key"; -import SchemaFormFields from "../common_components/check_openapi_schema"; -import { - keyCreateCall, - modelAvailableCall, - getGuardrailsList, - proxyBaseUrl, - getPossibleUserRoles, - userFilterUICall, - keyCreateServiceAccountCall, - fetchMCPAccessGroups, - getPromptsList, -} from "../networking"; -import VectorStoreSelector from "../vector_store_management/VectorStoreSelector"; -import PassThroughRoutesSelector from "../common_components/PassThroughRoutesSelector"; -import { Team } from "../key_team_helpers/key_list"; -import TeamDropdown from "../common_components/team_dropdown"; -import { InfoCircleOutlined } from "@ant-design/icons"; -import { Tooltip } from "antd"; -import PremiumLoggingSettings from "../common_components/PremiumLoggingSettings"; -import Createuser from "../create_user_button"; -import debounce from "lodash/debounce"; -import { rolesWithWriteAccess } from "../../utils/roles"; -import BudgetDurationDropdown from "../common_components/budget_duration_dropdown"; import { formatNumberWithCommas } from "@/utils/dataUtils"; +import { InfoCircleOutlined } from "@ant-design/icons"; +import { Accordion, AccordionBody, AccordionHeader, Button, Col, Grid, Text, TextInput, Title } from "@tremor/react"; +import { Button as Button2, Form, Input, Modal, Radio, Select, Switch, Tooltip } from "antd"; +import debounce from "lodash/debounce"; +import React, { useCallback, useEffect, useState } from "react"; +import { CopyToClipboard } from "react-copy-to-clipboard"; +import { rolesWithWriteAccess } from "../../utils/roles"; +import AgentSelector from "../agent_management/AgentSelector"; import { mapDisplayToInternalNames } from "../callback_info_helpers"; +import BudgetDurationDropdown from "../common_components/budget_duration_dropdown"; +import SchemaFormFields from "../common_components/check_openapi_schema"; +import KeyLifecycleSettings from "../common_components/KeyLifecycleSettings"; +import ModelAliasManager from "../common_components/ModelAliasManager"; +import PassThroughRoutesSelector from "../common_components/PassThroughRoutesSelector"; +import PremiumLoggingSettings from "../common_components/PremiumLoggingSettings"; +import RateLimitTypeFormItem from "../common_components/RateLimitTypeFormItem"; +import TeamDropdown from "../common_components/team_dropdown"; +import Createuser from "../create_user_button"; +import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key"; +import { Team } from "../key_team_helpers/key_list"; import MCPServerSelector from "../mcp_server_management/MCPServerSelector"; import MCPToolPermissions from "../mcp_server_management/MCPToolPermissions"; -import AgentSelector from "../agent_management/AgentSelector"; -import ModelAliasManager from "../common_components/ModelAliasManager"; import NotificationsManager from "../molecules/notifications_manager"; -import KeyLifecycleSettings from "../common_components/KeyLifecycleSettings"; -import RateLimitTypeFormItem from "../common_components/RateLimitTypeFormItem"; +import { + getGuardrailsList, + getPossibleUserRoles, + getPromptsList, + keyCreateCall, + keyCreateServiceAccountCall, + modelAvailableCall, + proxyBaseUrl, + userFilterUICall, +} from "../networking"; +import NumericalInput from "../shared/numerical_input"; +import VectorStoreSelector from "../vector_store_management/VectorStoreSelector"; const { Option } = Select; @@ -168,7 +165,6 @@ const CreateKey: React.FC = ({ const [userOptions, setUserOptions] = useState([]); const [userSearchLoading, setUserSearchLoading] = useState(false); const [mcpAccessGroups, setMcpAccessGroups] = useState([]); - const [mcpAccessGroupsLoaded, setMcpAccessGroupsLoaded] = useState(false); const [disabledCallbacks, setDisabledCallbacks] = useState([]); const [keyType, setKeyType] = useState("default"); const [modelAliases, setModelAliases] = useState<{ [key: string]: string }>({}); @@ -205,22 +201,6 @@ const CreateKey: React.FC = ({ } }, [accessToken, userID, userRole]); - const fetchMcpAccessGroups = async () => { - try { - if (accessToken == null) { - return; - } - const groups = await fetchMCPAccessGroups(accessToken); - setMcpAccessGroups(groups); - } catch (error) { - console.error("Failed to fetch MCP access groups:", error); - } - }; - - useEffect(() => { - fetchMcpAccessGroups(); - }, [accessToken]); - useEffect(() => { const fetchGuardrails = async () => { try { @@ -1053,15 +1033,7 @@ const CreateKey: React.FC = ({ options={predefinedTags} /> - { - if (!mcpAccessGroupsLoaded) { - fetchMcpAccessGroups(); - setMcpAccessGroupsLoaded(true); - } - }} - > + MCP Settings diff --git a/ui/litellm-dashboard/src/components/team/team_info.test.tsx b/ui/litellm-dashboard/src/components/team/team_info.test.tsx index 9b19611828..4f7d70ba6a 100644 --- a/ui/litellm-dashboard/src/components/team/team_info.test.tsx +++ b/ui/litellm-dashboard/src/components/team/team_info.test.tsx @@ -1,5 +1,6 @@ import * as networking from "@/components/networking"; -import { act, fireEvent, render, screen, waitFor } from "@testing-library/react"; +import { act, fireEvent, screen, waitFor } from "@testing-library/react"; +import { renderWithProviders } from "../../../tests/test-utils"; import { afterEach, describe, expect, it, vi } from "vitest"; import TeamInfoView from "./team_info"; @@ -62,7 +63,7 @@ describe("TeamInfoView", () => { vi.mocked(networking.getGuardrailsList).mockResolvedValue({ guardrails: [] }); vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]); - render( + renderWithProviders( {}} @@ -124,7 +125,7 @@ describe("TeamInfoView", () => { vi.mocked(networking.getGuardrailsList).mockResolvedValue({ guardrails: [] }); vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]); - render( + renderWithProviders( {}} @@ -219,7 +220,7 @@ describe("TeamInfoView", () => { vi.mocked(networking.getGuardrailsList).mockResolvedValue({ guardrails: [] }); vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]); - render( + renderWithProviders( {}} @@ -310,7 +311,7 @@ describe("TeamInfoView", () => { vi.mocked(networking.getGuardrailsList).mockResolvedValue({ guardrails: [] }); vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]); - render( + renderWithProviders( {}} @@ -373,7 +374,7 @@ describe("TeamInfoView", () => { vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]); vi.mocked(networking.teamUpdateCall).mockResolvedValue({ data: teamResponse.team_info, team_id: "123" } as any); - render( + renderWithProviders( {}} @@ -450,7 +451,7 @@ describe("TeamInfoView", () => { vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]); vi.mocked(networking.teamUpdateCall).mockResolvedValue({ data: teamResponse.team_info, team_id: "123" } as any); - render( + renderWithProviders( {}} diff --git a/ui/litellm-dashboard/src/components/templates/key_edit_view.test.tsx b/ui/litellm-dashboard/src/components/templates/key_edit_view.test.tsx index e23bc88dda..85c6192693 100644 --- a/ui/litellm-dashboard/src/components/templates/key_edit_view.test.tsx +++ b/ui/litellm-dashboard/src/components/templates/key_edit_view.test.tsx @@ -1,4 +1,5 @@ -import { render, waitFor } from "@testing-library/react"; +import { waitFor } from "@testing-library/react"; +import { renderWithProviders } from "../../../tests/test-utils"; import { describe, expect, it, vi } from "vitest"; import { KeyEditView } from "./key_edit_view"; import { KeyResponse } from "../key_team_helpers/key_list"; @@ -89,7 +90,7 @@ describe("KeyEditView", () => { key_rotation_at: undefined, }; it("should render", async () => { - const { getByText } = render( + const { getByText } = renderWithProviders( {}} @@ -107,7 +108,7 @@ describe("KeyEditView", () => { }); it("should render tags", async () => { - const { getByText } = render( + const { getByText } = renderWithProviders( {}} @@ -125,7 +126,7 @@ describe("KeyEditView", () => { }); it("should not render tags in metadata textarea", async () => { - const { getByLabelText } = render( + const { getByLabelText } = renderWithProviders( {}} diff --git a/ui/litellm-dashboard/tests/test-utils.tsx b/ui/litellm-dashboard/tests/test-utils.tsx index cf7fbaf0d8..ed1f248648 100644 --- a/ui/litellm-dashboard/tests/test-utils.tsx +++ b/ui/litellm-dashboard/tests/test-utils.tsx @@ -1,9 +1,26 @@ import React, { PropsWithChildren } from "react"; import { render, RenderOptions } from "@testing-library/react"; +import { QueryClient, QueryClientProvider } from "@tanstack/react-query"; + +// Create a client for testing +const queryClient = new QueryClient({ + defaultOptions: { + queries: { + retry: false, + gcTime: Infinity, + staleTime: Infinity, + refetchOnWindowFocus: false, + refetchOnReconnect: false, + refetchOnMount: false, + }, + mutations: { + retry: false, + }, + }, +}); const Providers: React.FC = ({ children }) => { - // Add future providers here (Theme/Router/QueryClient/etc.) - return <>{children}; + return {children}; }; export const renderWithProviders = (ui: React.ReactElement, options?: RenderOptions) =>