Migrate MCP servers to react query

This commit is contained in:
yuneng-jiang 2025-12-22 14:33:38 -08:00
parent e12d14e969
commit dc3bdffaee
10 changed files with 155 additions and 184 deletions

View File

@ -0,0 +1,13 @@
import { useQuery } from "@tanstack/react-query";
import { createQueryKeys } from "../common/queryKeysFactory";
import { fetchMCPAccessGroups } from "@/components/networking";
const mcpAccessGroupsKeys = createQueryKeys("mcpAccessGroups");
export const useMCPAccessGroups = (accessToken: string | null) => {
return useQuery<string[]>({
queryKey: mcpAccessGroupsKeys.list({}),
queryFn: async () => await fetchMCPAccessGroups(accessToken!),
enabled: !!accessToken,
});
};

View File

@ -0,0 +1,14 @@
import { useQuery } from "@tanstack/react-query";
import { createQueryKeys } from "../common/queryKeysFactory";
import { fetchMCPServers } from "@/components/networking";
import { MCPServer } from "@/components/mcp_tools/types";
const mcpServersKeys = createQueryKeys("mcpServers");
export const useMCPServers = (accessToken: string | null) => {
return useQuery<MCPServer[]>({
queryKey: mcpServersKeys.list({}),
queryFn: async () => await fetchMCPServers(accessToken!),
enabled: !!accessToken,
});
};

View File

@ -1,15 +1,13 @@
import React, { useEffect, useState } from "react";
import React from "react";
import { Select } from "antd";
import { fetchMCPServers, fetchMCPAccessGroups } from "../networking";
import { useMCPServers } from "@/app/(dashboard)/hooks/mcpServers/useMCPServers";
import { useMCPAccessGroups } from "@/app/(dashboard)/hooks/mcpServers/useMCPAccessGroups";
import { MCPServer } from "../mcp_tools/types";
interface MCPServerSelectorProps {
onChange: (selected: {
servers: string[];
accessGroups: string[];
}) => void;
value?: {
servers: string[];
onChange: (selected: { servers: string[]; accessGroups: string[] }) => void;
value?: {
servers: string[];
accessGroups: string[];
};
className?: string;
@ -26,31 +24,10 @@ const MCPServerSelector: React.FC<MCPServerSelectorProps> = ({
placeholder = "Select MCP servers",
disabled = false,
}) => {
const [mcpServers, setMCPServers] = useState<MCPServer[]>([]);
const [accessGroups, setAccessGroups] = useState<string[]>([]);
const [loading, setLoading] = useState(false);
const { data: mcpServers = [], isLoading: serversLoading } = useMCPServers(accessToken);
const { data: accessGroups = [], isLoading: groupsLoading } = useMCPAccessGroups(accessToken);
useEffect(() => {
const fetchData = async () => {
if (!accessToken) return;
setLoading(true);
try {
const [serversRes, groupsRes] = await Promise.all([
fetchMCPServers(accessToken),
fetchMCPAccessGroups(accessToken),
]);
let servers = Array.isArray(serversRes) ? serversRes : serversRes.data || [];
let groups = Array.isArray(groupsRes) ? groupsRes : groupsRes.data || [];
setMCPServers(servers);
setAccessGroups(groups);
} catch (error) {
console.error("Error fetching MCP servers or access groups:", error);
} finally {
setLoading(false);
}
};
fetchData();
}, [accessToken]);
const loading = serversLoading || groupsLoading;
// Combine options, access groups first
const options = [

View File

@ -1,11 +1,22 @@
import { describe, it, expect, vi, beforeEach } from "vitest";
import { render, screen, waitFor } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
import MCPToolPermissions from "./MCPToolPermissions";
import * as networking from "../networking";
vi.mock("../networking");
const createQueryClient = () =>
new QueryClient({
defaultOptions: {
queries: {
retry: false,
gcTime: 0,
},
},
});
describe("MCPToolPermissions", () => {
const mockAccessToken = "test-token";
const mockServerId = "server-123";
@ -28,15 +39,13 @@ describe("MCPToolPermissions", () => {
];
// Mock fetchMCPServers to return server details
vi.mocked(networking.fetchMCPServers).mockResolvedValue({
data: [
{
server_id: mockServerId,
server_name: mockServerName,
alias: mockServerName,
},
],
});
vi.mocked(networking.fetchMCPServers).mockResolvedValue([
{
server_id: mockServerId,
server_name: mockServerName,
alias: mockServerName,
},
]);
// Mock listMCPTools to return tools for the server
vi.mocked(networking.listMCPTools).mockResolvedValue({
@ -44,13 +53,16 @@ describe("MCPToolPermissions", () => {
error: false,
});
const queryClient = createQueryClient();
render(
<MCPToolPermissions
accessToken={mockAccessToken}
selectedServers={[mockServerId]}
toolPermissions={{}}
onChange={mockOnChange}
/>
<QueryClientProvider client={queryClient}>
<MCPToolPermissions
accessToken={mockAccessToken}
selectedServers={[mockServerId]}
toolPermissions={{}}
onChange={mockOnChange}
/>
</QueryClientProvider>,
);
// Wait for server and tools to load
@ -76,4 +88,3 @@ describe("MCPToolPermissions", () => {
expect(networking.listMCPTools).toHaveBeenCalledWith(mockAccessToken, mockServerId);
});
});

View File

@ -1,9 +1,10 @@
import React, { useEffect, useState } from "react";
import { listMCPTools, fetchMCPServers } from "../networking";
import React, { useEffect, useState, useMemo } from "react";
import { listMCPTools } from "../networking";
import { MCPTool, MCPServer } from "../mcp_tools/types";
import { Text } from "@tremor/react";
import { Spin, Checkbox } from "antd";
import { XIcon } from "lucide-react";
import { useMCPServers } from "../../app/(dashboard)/hooks/mcpServers/useMCPServers";
interface MCPToolPermissionsProps {
accessToken: string;
@ -20,63 +21,43 @@ const MCPToolPermissions: React.FC<MCPToolPermissionsProps> = ({
onChange,
disabled = false,
}) => {
const [servers, setServers] = useState<MCPServer[]>([]);
const { data: allServers = [] } = useMCPServers(accessToken);
const [serverTools, setServerTools] = useState<Record<string, MCPTool[]>>({});
const [loadingTools, setLoadingTools] = useState<Record<string, boolean>>({});
const [toolErrors, setToolErrors] = useState<Record<string, string>>({});
// Fetch server details
useEffect(() => {
const loadServerDetails = async () => {
if (selectedServers.length === 0) {
setServers([]);
return;
}
try {
const response = await fetchMCPServers(accessToken);
const allServers = Array.isArray(response) ? response : response.data || [];
const filteredServers = allServers.filter((server: MCPServer) =>
selectedServers.includes(server.server_id)
);
setServers(filteredServers);
} catch (error) {
console.error("Error fetching MCP servers:", error);
setServers([]);
}
};
loadServerDetails();
}, [selectedServers, accessToken]);
// Filter servers based on selectedServers
const servers = useMemo(() => {
if (selectedServers.length === 0) return [];
return allServers.filter((server: MCPServer) => selectedServers.includes(server.server_id));
}, [allServers, selectedServers]);
// Fetch tools for a specific server
const fetchToolsForServer = async (serverId: string) => {
setLoadingTools(prev => ({ ...prev, [serverId]: true }));
setToolErrors(prev => ({ ...prev, [serverId]: "" }));
setLoadingTools((prev) => ({ ...prev, [serverId]: true }));
setToolErrors((prev) => ({ ...prev, [serverId]: "" }));
try {
const response = await listMCPTools(accessToken, serverId);
if (response.error) {
setToolErrors(prev => ({ ...prev, [serverId]: response.message || "Failed to fetch tools" }));
setServerTools(prev => ({ ...prev, [serverId]: [] }));
setToolErrors((prev) => ({ ...prev, [serverId]: response.message || "Failed to fetch tools" }));
setServerTools((prev) => ({ ...prev, [serverId]: [] }));
} else {
setServerTools(prev => ({ ...prev, [serverId]: response.tools || [] }));
setServerTools((prev) => ({ ...prev, [serverId]: response.tools || [] }));
}
} catch (err) {
console.error(`Error fetching tools for server ${serverId}:`, err);
setToolErrors(prev => ({ ...prev, [serverId]: "Failed to fetch tools" }));
setServerTools(prev => ({ ...prev, [serverId]: [] }));
setToolErrors((prev) => ({ ...prev, [serverId]: "Failed to fetch tools" }));
setServerTools((prev) => ({ ...prev, [serverId]: [] }));
} finally {
setLoadingTools(prev => ({ ...prev, [serverId]: false }));
setLoadingTools((prev) => ({ ...prev, [serverId]: false }));
}
};
// Auto-fetch tools when servers change
useEffect(() => {
servers.forEach(server => {
servers.forEach((server) => {
if (!serverTools[server.server_id] && !loadingTools[server.server_id]) {
fetchToolsForServer(server.server_id);
}
@ -87,9 +68,9 @@ const MCPToolPermissions: React.FC<MCPToolPermissionsProps> = ({
const handleToolToggle = (serverId: string, toolName: string) => {
const currentTools = toolPermissions[serverId] || [];
const newTools = currentTools.includes(toolName)
? currentTools.filter(name => name !== toolName)
? currentTools.filter((name) => name !== toolName)
: [...currentTools, toolName];
const updatedPermissions = {
...toolPermissions,
[serverId]: newTools,
@ -101,7 +82,7 @@ const MCPToolPermissions: React.FC<MCPToolPermissionsProps> = ({
const tools = serverTools[serverId] || [];
onChange({
...toolPermissions,
[serverId]: tools.map(t => t.name),
[serverId]: tools.map((t) => t.name),
});
};
@ -131,9 +112,7 @@ const MCPToolPermissions: React.FC<MCPToolPermissionsProps> = ({
<div className="flex items-center justify-between p-4 border-b bg-white rounded-t-lg">
<div>
<Text className="font-semibold text-gray-900">{serverName}</Text>
{server.description && (
<Text className="text-sm text-gray-500">{server.description}</Text>
)}
{server.description && <Text className="text-sm text-gray-500">{server.description}</Text>}
</div>
<div className="flex items-center gap-3">
<button
@ -197,9 +176,7 @@ const MCPToolPermissions: React.FC<MCPToolPermissionsProps> = ({
<div className="flex-1 min-w-0">
<div className="flex items-center gap-2">
<Text className="font-medium text-gray-900">{tool.name}</Text>
<Text className="text-sm text-gray-500">
- {tool.description || "No description"}
</Text>
<Text className="text-sm text-gray-500">- {tool.description || "No description"}</Text>
</div>
</div>
</div>

View File

@ -1,11 +1,11 @@
import { isAdminRole } from "@/utils/roles";
import { QuestionCircleOutlined } from "@ant-design/icons";
import { useQuery } from "@tanstack/react-query";
import { Button, Tab, TabGroup, TabList, TabPanel, TabPanels, Text, Title } from "@tremor/react";
import { Descriptions, Modal, Select, Tooltip, Typography } from "antd";
import React, { useEffect, useState } from "react";
import { useMCPServers } from "../../app/(dashboard)/hooks/mcpServers/useMCPServers";
import NotificationsManager from "../molecules/notifications_manager";
import { deleteMCPServer, fetchMCPServers } from "../networking";
import { deleteMCPServer } from "../networking";
import { DataTable } from "../view_logs/table";
import CreateMCPServer from "./create_mcp_server";
import MCPConnect from "./mcp_connect";
@ -19,19 +19,7 @@ const EDIT_OAUTH_UI_STATE_KEY = "litellm-mcp-oauth-edit-state";
const { Option } = Select;
const MCPServers: React.FC<MCPServerProps> = ({ accessToken, userRole, userID }) => {
const {
data: mcpServers,
isLoading: isLoadingServers,
refetch,
dataUpdatedAt,
} = useQuery({
queryKey: ["mcpServers"],
queryFn: () => {
if (!accessToken) throw new Error("Access Token required");
return fetchMCPServers(accessToken);
},
enabled: !!accessToken,
}) as { data: MCPServer[]; isLoading: boolean; refetch: () => void; dataUpdatedAt: number };
const { data: mcpServers, isLoading: isLoadingServers, refetch, dataUpdatedAt } = useMCPServers(accessToken);
// Log allowed_tools from fetched servers
React.useEffect(() => {

View File

@ -1,43 +1,40 @@
"use client";
import React, { useState, useEffect, useCallback } from "react";
import { Button, TextInput, Grid, Col } from "@tremor/react";
import { Text, Title, Accordion, AccordionHeader, AccordionBody } from "@tremor/react";
import { CopyToClipboard } from "react-copy-to-clipboard";
import { Button as Button2, Modal, Form, Input, Select, Radio, Switch } from "antd";
import NumericalInput from "../shared/numerical_input";
import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key";
import SchemaFormFields from "../common_components/check_openapi_schema";
import {
keyCreateCall,
modelAvailableCall,
getGuardrailsList,
proxyBaseUrl,
getPossibleUserRoles,
userFilterUICall,
keyCreateServiceAccountCall,
fetchMCPAccessGroups,
getPromptsList,
} from "../networking";
import VectorStoreSelector from "../vector_store_management/VectorStoreSelector";
import PassThroughRoutesSelector from "../common_components/PassThroughRoutesSelector";
import { Team } from "../key_team_helpers/key_list";
import TeamDropdown from "../common_components/team_dropdown";
import { InfoCircleOutlined } from "@ant-design/icons";
import { Tooltip } from "antd";
import PremiumLoggingSettings from "../common_components/PremiumLoggingSettings";
import Createuser from "../create_user_button";
import debounce from "lodash/debounce";
import { rolesWithWriteAccess } from "../../utils/roles";
import BudgetDurationDropdown from "../common_components/budget_duration_dropdown";
import { formatNumberWithCommas } from "@/utils/dataUtils";
import { InfoCircleOutlined } from "@ant-design/icons";
import { Accordion, AccordionBody, AccordionHeader, Button, Col, Grid, Text, TextInput, Title } from "@tremor/react";
import { Button as Button2, Form, Input, Modal, Radio, Select, Switch, Tooltip } from "antd";
import debounce from "lodash/debounce";
import React, { useCallback, useEffect, useState } from "react";
import { CopyToClipboard } from "react-copy-to-clipboard";
import { rolesWithWriteAccess } from "../../utils/roles";
import AgentSelector from "../agent_management/AgentSelector";
import { mapDisplayToInternalNames } from "../callback_info_helpers";
import BudgetDurationDropdown from "../common_components/budget_duration_dropdown";
import SchemaFormFields from "../common_components/check_openapi_schema";
import KeyLifecycleSettings from "../common_components/KeyLifecycleSettings";
import ModelAliasManager from "../common_components/ModelAliasManager";
import PassThroughRoutesSelector from "../common_components/PassThroughRoutesSelector";
import PremiumLoggingSettings from "../common_components/PremiumLoggingSettings";
import RateLimitTypeFormItem from "../common_components/RateLimitTypeFormItem";
import TeamDropdown from "../common_components/team_dropdown";
import Createuser from "../create_user_button";
import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key";
import { Team } from "../key_team_helpers/key_list";
import MCPServerSelector from "../mcp_server_management/MCPServerSelector";
import MCPToolPermissions from "../mcp_server_management/MCPToolPermissions";
import AgentSelector from "../agent_management/AgentSelector";
import ModelAliasManager from "../common_components/ModelAliasManager";
import NotificationsManager from "../molecules/notifications_manager";
import KeyLifecycleSettings from "../common_components/KeyLifecycleSettings";
import RateLimitTypeFormItem from "../common_components/RateLimitTypeFormItem";
import {
getGuardrailsList,
getPossibleUserRoles,
getPromptsList,
keyCreateCall,
keyCreateServiceAccountCall,
modelAvailableCall,
proxyBaseUrl,
userFilterUICall,
} from "../networking";
import NumericalInput from "../shared/numerical_input";
import VectorStoreSelector from "../vector_store_management/VectorStoreSelector";
const { Option } = Select;
@ -168,7 +165,6 @@ const CreateKey: React.FC<CreateKeyProps> = ({
const [userOptions, setUserOptions] = useState<UserOption[]>([]);
const [userSearchLoading, setUserSearchLoading] = useState<boolean>(false);
const [mcpAccessGroups, setMcpAccessGroups] = useState<string[]>([]);
const [mcpAccessGroupsLoaded, setMcpAccessGroupsLoaded] = useState(false);
const [disabledCallbacks, setDisabledCallbacks] = useState<string[]>([]);
const [keyType, setKeyType] = useState<string>("default");
const [modelAliases, setModelAliases] = useState<{ [key: string]: string }>({});
@ -205,22 +201,6 @@ const CreateKey: React.FC<CreateKeyProps> = ({
}
}, [accessToken, userID, userRole]);
const fetchMcpAccessGroups = async () => {
try {
if (accessToken == null) {
return;
}
const groups = await fetchMCPAccessGroups(accessToken);
setMcpAccessGroups(groups);
} catch (error) {
console.error("Failed to fetch MCP access groups:", error);
}
};
useEffect(() => {
fetchMcpAccessGroups();
}, [accessToken]);
useEffect(() => {
const fetchGuardrails = async () => {
try {
@ -1053,15 +1033,7 @@ const CreateKey: React.FC<CreateKeyProps> = ({
options={predefinedTags}
/>
</Form.Item>
<Accordion
className="mt-4 mb-4"
onClick={() => {
if (!mcpAccessGroupsLoaded) {
fetchMcpAccessGroups();
setMcpAccessGroupsLoaded(true);
}
}}
>
<Accordion className="mt-4 mb-4">
<AccordionHeader>
<b>MCP Settings</b>
</AccordionHeader>

View File

@ -1,5 +1,6 @@
import * as networking from "@/components/networking";
import { act, fireEvent, render, screen, waitFor } from "@testing-library/react";
import { act, fireEvent, screen, waitFor } from "@testing-library/react";
import { renderWithProviders } from "../../../tests/test-utils";
import { afterEach, describe, expect, it, vi } from "vitest";
import TeamInfoView from "./team_info";
@ -62,7 +63,7 @@ describe("TeamInfoView", () => {
vi.mocked(networking.getGuardrailsList).mockResolvedValue({ guardrails: [] });
vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]);
render(
renderWithProviders(
<TeamInfoView
teamId="123"
onUpdate={() => {}}
@ -124,7 +125,7 @@ describe("TeamInfoView", () => {
vi.mocked(networking.getGuardrailsList).mockResolvedValue({ guardrails: [] });
vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]);
render(
renderWithProviders(
<TeamInfoView
teamId="123"
onUpdate={() => {}}
@ -219,7 +220,7 @@ describe("TeamInfoView", () => {
vi.mocked(networking.getGuardrailsList).mockResolvedValue({ guardrails: [] });
vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]);
render(
renderWithProviders(
<TeamInfoView
teamId="123"
onUpdate={() => {}}
@ -310,7 +311,7 @@ describe("TeamInfoView", () => {
vi.mocked(networking.getGuardrailsList).mockResolvedValue({ guardrails: [] });
vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]);
render(
renderWithProviders(
<TeamInfoView
teamId="123"
onUpdate={() => {}}
@ -373,7 +374,7 @@ describe("TeamInfoView", () => {
vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]);
vi.mocked(networking.teamUpdateCall).mockResolvedValue({ data: teamResponse.team_info, team_id: "123" } as any);
render(
renderWithProviders(
<TeamInfoView
teamId="123"
onUpdate={() => {}}
@ -450,7 +451,7 @@ describe("TeamInfoView", () => {
vi.mocked(networking.fetchMCPAccessGroups).mockResolvedValue([]);
vi.mocked(networking.teamUpdateCall).mockResolvedValue({ data: teamResponse.team_info, team_id: "123" } as any);
render(
renderWithProviders(
<TeamInfoView
teamId="123"
onUpdate={() => {}}

View File

@ -1,4 +1,5 @@
import { render, waitFor } from "@testing-library/react";
import { waitFor } from "@testing-library/react";
import { renderWithProviders } from "../../../tests/test-utils";
import { describe, expect, it, vi } from "vitest";
import { KeyEditView } from "./key_edit_view";
import { KeyResponse } from "../key_team_helpers/key_list";
@ -89,7 +90,7 @@ describe("KeyEditView", () => {
key_rotation_at: undefined,
};
it("should render", async () => {
const { getByText } = render(
const { getByText } = renderWithProviders(
<KeyEditView
keyData={MOCK_KEY_DATA}
onCancel={() => {}}
@ -107,7 +108,7 @@ describe("KeyEditView", () => {
});
it("should render tags", async () => {
const { getByText } = render(
const { getByText } = renderWithProviders(
<KeyEditView
keyData={MOCK_KEY_DATA}
onCancel={() => {}}
@ -125,7 +126,7 @@ describe("KeyEditView", () => {
});
it("should not render tags in metadata textarea", async () => {
const { getByLabelText } = render(
const { getByLabelText } = renderWithProviders(
<KeyEditView
keyData={MOCK_KEY_DATA}
onCancel={() => {}}

View File

@ -1,9 +1,26 @@
import React, { PropsWithChildren } from "react";
import { render, RenderOptions } from "@testing-library/react";
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
// Create a client for testing
const queryClient = new QueryClient({
defaultOptions: {
queries: {
retry: false,
gcTime: Infinity,
staleTime: Infinity,
refetchOnWindowFocus: false,
refetchOnReconnect: false,
refetchOnMount: false,
},
mutations: {
retry: false,
},
},
});
const Providers: React.FC<PropsWithChildren> = ({ children }) => {
// Add future providers here (Theme/Router/QueryClient/etc.)
return <>{children}</>;
return <QueryClientProvider client={queryClient}>{children}</QueryClientProvider>;
};
export const renderWithProviders = (ui: React.ReactElement, options?: RenderOptions) =>