fix(mcp): block arbitrary command execution via stdio transport
Add command allowlist for MCP stdio transport to prevent RCE via /mcp-rest/test/* endpoints. Restrict test endpoints to PROXY_ADMIN role. Fix docker/README.md MASTER_KEY -> LITELLM_MASTER_KEY. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
62757ff48f
commit
7b7f304675
@ -13,19 +13,19 @@ To build and run the application, you will use the `docker-compose.yml` file loc
|
||||
|
||||
### 1. Set the Master Key
|
||||
|
||||
The application requires a `MASTER_KEY` for signing and validating tokens. You must set this key as an environment variable before running the application.
|
||||
The application requires a `LITELLM_MASTER_KEY` for signing and validating tokens. You must set this key as an environment variable before running the application.
|
||||
|
||||
Create a `.env` file in the root of the project and add the following line:
|
||||
|
||||
```
|
||||
MASTER_KEY=your-secret-key
|
||||
LITELLM_MASTER_KEY=your-secret-key
|
||||
```
|
||||
|
||||
Replace `your-secret-key` with a strong, randomly generated secret.
|
||||
|
||||
### 2. Build and Run the Containers
|
||||
|
||||
Once you have set the `MASTER_KEY`, you can build and run the containers using the following command:
|
||||
Once you have set the `LITELLM_MASTER_KEY`, you can build and run the containers using the following command:
|
||||
|
||||
```bash
|
||||
docker compose up -d --build
|
||||
@ -89,4 +89,4 @@ This command should succeed (showing engine versions) even with `--network none`
|
||||
## Troubleshooting
|
||||
|
||||
- **`build_admin_ui.sh: not found`**: This error can occur if the Docker build context is not set correctly. Ensure that you are running the `docker-compose` command from the root of the project.
|
||||
- **`Master key is not initialized`**: This error means the `MASTER_key` environment variable is not set. Make sure you have created a `.env` file in the project root with the `MASTER_KEY` defined.
|
||||
- **`Master key is not initialized`**: This error means the `LITELLM_MASTER_KEY` environment variable is not set. Make sure you have created a `.env` file in the project root with the `LITELLM_MASTER_KEY` defined.
|
||||
|
||||
@ -141,6 +141,15 @@ MCP_TOOL_LISTING_TIMEOUT = float(os.getenv("LITELLM_MCP_TOOL_LISTING_TIMEOUT", "
|
||||
MCP_METADATA_TIMEOUT = float(os.getenv("LITELLM_MCP_METADATA_TIMEOUT", "10.0"))
|
||||
MCP_HEALTH_CHECK_TIMEOUT = float(os.getenv("LITELLM_MCP_HEALTH_CHECK_TIMEOUT", "10.0"))
|
||||
|
||||
# Allowlist of commands permitted for MCP stdio transport.
|
||||
# Prevents arbitrary command execution via /mcp-rest/test/* endpoints or server creation.
|
||||
# Extend via LITELLM_MCP_STDIO_EXTRA_COMMANDS env var (comma-separated).
|
||||
_MCP_STDIO_EXTRA_COMMANDS = os.getenv("LITELLM_MCP_STDIO_EXTRA_COMMANDS", "")
|
||||
MCP_STDIO_ALLOWED_COMMANDS: frozenset = frozenset(
|
||||
{"npx", "uvx", "python", "python3", "node", "docker", "deno"}
|
||||
| (set(_MCP_STDIO_EXTRA_COMMANDS.split(",")) - {""})
|
||||
)
|
||||
|
||||
LITELLM_UI_ALLOW_HEADERS = [
|
||||
"x-litellm-semantic-filter",
|
||||
"x-litellm-semantic-filter-tools",
|
||||
|
||||
@ -1122,6 +1122,20 @@ class MCPServerManager:
|
||||
from litellm.constants import MCP_NPM_CACHE_DIR
|
||||
|
||||
resolved_env["NPM_CONFIG_CACHE"] = MCP_NPM_CACHE_DIR
|
||||
# Defense-in-depth: validate command even if Pydantic validation was bypassed
|
||||
# (e.g. MCPServer built from config/DB records predating the allowlist)
|
||||
if server.command:
|
||||
import os as _os
|
||||
|
||||
from litellm.constants import MCP_STDIO_ALLOWED_COMMANDS
|
||||
|
||||
base_command = _os.path.basename(server.command)
|
||||
if base_command not in MCP_STDIO_ALLOWED_COMMANDS:
|
||||
raise ValueError(
|
||||
f"Command '{server.command}' is not in the allowed commands list "
|
||||
f"for stdio transport. Allowed commands: {sorted(MCP_STDIO_ALLOWED_COMMANDS)}"
|
||||
)
|
||||
|
||||
stdio_config: Optional[MCPStdioConfig] = None
|
||||
if server.command and server.args is not None:
|
||||
stdio_config = MCPStdioConfig(
|
||||
|
||||
@ -2,14 +2,14 @@ import importlib
|
||||
from datetime import datetime
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.proxy._experimental.mcp_server.ui_session_utils import (
|
||||
build_effective_auth_contexts,
|
||||
)
|
||||
from litellm.proxy._experimental.mcp_server.utils import merge_mcp_headers
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _safe_get_request_headers
|
||||
@ -1027,6 +1027,13 @@ if MCP_AVAILABLE:
|
||||
"""
|
||||
Test if we can connect to the provided MCP server before adding it
|
||||
"""
|
||||
if LitellmUserRoles.PROXY_ADMIN != user_api_key_dict.user_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={
|
||||
"error": "User does not have permission to test MCP server connections. Only PROXY_ADMIN users can perform this action."
|
||||
},
|
||||
)
|
||||
|
||||
async def _test_connection_operation(client):
|
||||
async def _noop(session):
|
||||
@ -1041,7 +1048,7 @@ if MCP_AVAILABLE:
|
||||
raw_headers=_safe_get_request_headers(request),
|
||||
)
|
||||
|
||||
@router.post("/test/tools/list")
|
||||
@router.post("/test/tools/list", dependencies=[Depends(user_api_key_auth)])
|
||||
async def test_tools_list(
|
||||
request: Request,
|
||||
new_mcp_server_request: NewMCPServerRequest,
|
||||
@ -1050,6 +1057,14 @@ if MCP_AVAILABLE:
|
||||
"""
|
||||
Preview tools available from MCP server before adding it
|
||||
"""
|
||||
if LitellmUserRoles.PROXY_ADMIN != user_api_key_dict.user_role:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={
|
||||
"error": "User does not have permission to test MCP server tools. Only PROXY_ADMIN users can perform this action."
|
||||
},
|
||||
)
|
||||
|
||||
# For OpenAPI spec servers, generate tools from the spec directly
|
||||
if new_mcp_server_request.spec_path:
|
||||
return await _preview_openapi_tools(new_mcp_server_request.spec_path)
|
||||
|
||||
@ -1162,6 +1162,17 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase):
|
||||
raise ValueError("command is required for stdio transport")
|
||||
if not values.get("args"):
|
||||
raise ValueError("args is required for stdio transport")
|
||||
# Validate command against allowlist to prevent arbitrary execution
|
||||
import os as _os
|
||||
|
||||
from litellm.constants import MCP_STDIO_ALLOWED_COMMANDS
|
||||
|
||||
base_command = _os.path.basename(values["command"])
|
||||
if base_command not in MCP_STDIO_ALLOWED_COMMANDS:
|
||||
raise ValueError(
|
||||
f"Command '{values['command']}' is not in the allowed commands list "
|
||||
f"for stdio transport. Allowed commands: {sorted(MCP_STDIO_ALLOWED_COMMANDS)}"
|
||||
)
|
||||
elif transport in [MCPTransport.http, MCPTransport.sse]:
|
||||
if not values.get("url") and not values.get("spec_path"):
|
||||
raise ValueError(
|
||||
@ -1222,6 +1233,17 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase):
|
||||
raise ValueError("command is required for stdio transport")
|
||||
if not values.get("args"):
|
||||
raise ValueError("args is required for stdio transport")
|
||||
# Validate command against allowlist to prevent arbitrary execution
|
||||
import os as _os
|
||||
|
||||
from litellm.constants import MCP_STDIO_ALLOWED_COMMANDS
|
||||
|
||||
base_command = _os.path.basename(values["command"])
|
||||
if base_command not in MCP_STDIO_ALLOWED_COMMANDS:
|
||||
raise ValueError(
|
||||
f"Command '{values['command']}' is not in the allowed commands list "
|
||||
f"for stdio transport. Allowed commands: {sorted(MCP_STDIO_ALLOWED_COMMANDS)}"
|
||||
)
|
||||
elif transport in [MCPTransport.http, MCPTransport.sse]:
|
||||
if not values.get("url") and not values.get("spec_path"):
|
||||
raise ValueError(
|
||||
|
||||
@ -156,7 +156,6 @@ class TestExecuteWithMcpClient:
|
||||
"Authorization": "STATIC token",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_m2m_credentials_forwarded_to_server_model(self, monkeypatch):
|
||||
"""M2M OAuth credentials (client_id, client_secret) from the nested
|
||||
@ -199,9 +198,7 @@ class TestExecuteWithMcpClient:
|
||||
},
|
||||
)
|
||||
|
||||
result = await rest_endpoints._execute_with_mcp_client(
|
||||
payload, ok_operation
|
||||
)
|
||||
result = await rest_endpoints._execute_with_mcp_client(payload, ok_operation)
|
||||
|
||||
assert result["status"] == "ok"
|
||||
server = captured["server"]
|
||||
@ -262,7 +259,10 @@ class TestExecuteWithMcpClient:
|
||||
assert result["status"] == "ok"
|
||||
# The incoming Authorization must be dropped — extra_headers should
|
||||
# contain no oauth2 headers (only static_headers, which are None here).
|
||||
assert captured["extra_headers"] is None or "Authorization" not in captured["extra_headers"]
|
||||
assert (
|
||||
captured["extra_headers"] is None
|
||||
or "Authorization" not in captured["extra_headers"]
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_catches_exception_group(self, monkeypatch):
|
||||
@ -300,9 +300,7 @@ class TestExecuteWithMcpClient:
|
||||
auth_type=MCPAuth.none,
|
||||
)
|
||||
|
||||
result = await rest_endpoints._execute_with_mcp_client(
|
||||
payload, ok_operation
|
||||
)
|
||||
result = await rest_endpoints._execute_with_mcp_client(payload, ok_operation)
|
||||
|
||||
assert result["status"] == "error"
|
||||
assert result["error"] is True
|
||||
@ -365,8 +363,12 @@ class TestTestToolsList:
|
||||
credentials={"auth_value": "secret-key"},
|
||||
)
|
||||
|
||||
from litellm.proxy._types import LitellmUserRoles
|
||||
|
||||
result = await rest_endpoints.test_tools_list(
|
||||
request, payload, user_api_key_dict=UserAPIKeyAuth()
|
||||
request,
|
||||
payload,
|
||||
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||
)
|
||||
|
||||
assert result["message"] == "Successfully retrieved tools"
|
||||
@ -419,8 +421,12 @@ class TestTestToolsList:
|
||||
auth_type=MCPAuth.oauth2,
|
||||
)
|
||||
|
||||
from litellm.proxy._types import LitellmUserRoles
|
||||
|
||||
result = await rest_endpoints.test_tools_list(
|
||||
request, payload, user_api_key_dict=UserAPIKeyAuth()
|
||||
request,
|
||||
payload,
|
||||
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||
)
|
||||
|
||||
assert result["message"] == "Successfully retrieved tools"
|
||||
@ -484,7 +490,11 @@ class TestListToolsRestAPI:
|
||||
captured = {"called": False}
|
||||
|
||||
async def fake_get_tools(
|
||||
server, server_auth_header, raw_headers=None, user_api_key_auth=None, extra_headers=None
|
||||
server,
|
||||
server_auth_header,
|
||||
raw_headers=None,
|
||||
user_api_key_auth=None,
|
||||
extra_headers=None,
|
||||
):
|
||||
captured["called"] = True
|
||||
captured["server"] = server
|
||||
@ -555,27 +565,47 @@ class TestListToolsRestAPI:
|
||||
|
||||
captured = {"called": False, "server_arg": None}
|
||||
|
||||
async def fake_get_tools(server, server_auth_header, raw_headers=None, user_api_key_auth=None, extra_headers=None):
|
||||
async def fake_get_tools(
|
||||
server,
|
||||
server_auth_header,
|
||||
raw_headers=None,
|
||||
user_api_key_auth=None,
|
||||
extra_headers=None,
|
||||
):
|
||||
captured["called"] = True
|
||||
captured["server_arg"] = server
|
||||
return ["tool-x"]
|
||||
|
||||
monkeypatch.setattr(rest_endpoints, "build_effective_auth_contexts", fake_contexts, raising=False)
|
||||
monkeypatch.setattr(
|
||||
rest_endpoints.global_mcp_server_manager, "get_allowed_mcp_servers",
|
||||
fake_get_allowed_mcp_servers, raising=False,
|
||||
rest_endpoints,
|
||||
"build_effective_auth_contexts",
|
||||
fake_contexts,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_name",
|
||||
rest_endpoints.global_mcp_server_manager,
|
||||
"get_allowed_mcp_servers",
|
||||
fake_get_allowed_mcp_servers,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
rest_endpoints.global_mcp_server_manager,
|
||||
"get_mcp_server_by_name",
|
||||
lambda name: stub_server if name == "my-server" else None,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_id",
|
||||
rest_endpoints.global_mcp_server_manager,
|
||||
"get_mcp_server_by_id",
|
||||
lambda sid: stub_server if sid == "uuid-abc-123" else None,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(rest_endpoints, "_get_tools_for_single_server", fake_get_tools, raising=False)
|
||||
monkeypatch.setattr(
|
||||
rest_endpoints,
|
||||
"_get_tools_for_single_server",
|
||||
fake_get_tools,
|
||||
raising=False,
|
||||
)
|
||||
|
||||
request = _build_request(path="/mcp-rest/tools/list", method="GET")
|
||||
result = await rest_endpoints.list_tool_rest_api(
|
||||
@ -609,18 +639,27 @@ class TestListToolsRestAPI:
|
||||
async def fake_get_allowed_mcp_servers(*args, **kwargs):
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(rest_endpoints, "build_effective_auth_contexts", fake_contexts, raising=False)
|
||||
monkeypatch.setattr(
|
||||
rest_endpoints.global_mcp_server_manager, "get_allowed_mcp_servers",
|
||||
fake_get_allowed_mcp_servers, raising=False,
|
||||
rest_endpoints,
|
||||
"build_effective_auth_contexts",
|
||||
fake_contexts,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_name",
|
||||
rest_endpoints.global_mcp_server_manager,
|
||||
"get_allowed_mcp_servers",
|
||||
fake_get_allowed_mcp_servers,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
rest_endpoints.global_mcp_server_manager,
|
||||
"get_mcp_server_by_name",
|
||||
lambda name: stub_server if name == "restricted-server" else None,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_id",
|
||||
rest_endpoints.global_mcp_server_manager,
|
||||
"get_mcp_server_by_id",
|
||||
lambda sid: stub_server if sid == "uuid-xyz-999" else None,
|
||||
raising=False,
|
||||
)
|
||||
@ -662,31 +701,54 @@ class TestListToolsRestAPI:
|
||||
|
||||
oauth_headers = {"Authorization": "Bearer user-oauth-token"}
|
||||
|
||||
async def fake_get_user_oauth_extra_headers(server, user_api_key_dict, prefetched_creds=None):
|
||||
async def fake_get_user_oauth_extra_headers(
|
||||
server, user_api_key_dict, prefetched_creds=None
|
||||
):
|
||||
return oauth_headers
|
||||
|
||||
captured = {}
|
||||
|
||||
async def fake_get_tools(server, server_auth_header, raw_headers=None, user_api_key_auth=None, extra_headers=None):
|
||||
async def fake_get_tools(
|
||||
server,
|
||||
server_auth_header,
|
||||
raw_headers=None,
|
||||
user_api_key_auth=None,
|
||||
extra_headers=None,
|
||||
):
|
||||
captured["server"] = server
|
||||
captured["auth_header"] = server_auth_header
|
||||
return ["oauth-tool"]
|
||||
|
||||
monkeypatch.setattr(rest_endpoints, "build_effective_auth_contexts", fake_contexts, raising=False)
|
||||
monkeypatch.setattr(
|
||||
rest_endpoints.global_mcp_server_manager, "get_allowed_mcp_servers",
|
||||
fake_get_allowed_mcp_servers, raising=False,
|
||||
rest_endpoints,
|
||||
"build_effective_auth_contexts",
|
||||
fake_contexts,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_id",
|
||||
rest_endpoints.global_mcp_server_manager,
|
||||
"get_allowed_mcp_servers",
|
||||
fake_get_allowed_mcp_servers,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
rest_endpoints.global_mcp_server_manager,
|
||||
"get_mcp_server_by_id",
|
||||
lambda sid: stub_server if sid == "oauth-server-id" else None,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
rest_endpoints, "_get_user_oauth_extra_headers",
|
||||
fake_get_user_oauth_extra_headers, raising=False,
|
||||
rest_endpoints,
|
||||
"_get_user_oauth_extra_headers",
|
||||
fake_get_user_oauth_extra_headers,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
rest_endpoints,
|
||||
"_get_tools_for_single_server",
|
||||
fake_get_tools,
|
||||
raising=False,
|
||||
)
|
||||
monkeypatch.setattr(rest_endpoints, "_get_tools_for_single_server", fake_get_tools, raising=False)
|
||||
|
||||
request = _build_request(path="/mcp-rest/tools/list", method="GET")
|
||||
result = await rest_endpoints.list_tool_rest_api(
|
||||
@ -1124,3 +1186,179 @@ class TestGetToolsForSingleServer:
|
||||
assert "tool3" in tool_names
|
||||
assert "tool1" not in tool_names
|
||||
assert "tool4" not in tool_names
|
||||
|
||||
|
||||
class TestStdioCommandAllowlist:
|
||||
"""Tests for MCP stdio command allowlist validation."""
|
||||
|
||||
def test_allowed_command_passes_validation(self):
|
||||
"""npx, uvx, python, etc. should be accepted."""
|
||||
req = NewMCPServerRequest(
|
||||
server_name="test",
|
||||
transport="stdio",
|
||||
command="npx",
|
||||
args=["-y", "@modelcontextprotocol/server-filesystem"],
|
||||
)
|
||||
assert req.command == "npx"
|
||||
|
||||
def test_disallowed_command_raises(self):
|
||||
"""Arbitrary commands like bash should be rejected."""
|
||||
with pytest.raises(ValueError, match="not in the allowed commands list"):
|
||||
NewMCPServerRequest(
|
||||
server_name="test",
|
||||
transport="stdio",
|
||||
command="bash",
|
||||
args=["-c", "echo pwned"],
|
||||
)
|
||||
|
||||
def test_sh_command_raises(self):
|
||||
"""sh should be rejected."""
|
||||
with pytest.raises(ValueError, match="not in the allowed commands list"):
|
||||
NewMCPServerRequest(
|
||||
server_name="test",
|
||||
transport="stdio",
|
||||
command="sh",
|
||||
args=["-c", "id > /tmp/output.txt"],
|
||||
)
|
||||
|
||||
def test_absolute_path_bypass_blocked(self):
|
||||
"""/bin/bash should be blocked (basename is 'bash')."""
|
||||
with pytest.raises(ValueError, match="not in the allowed commands list"):
|
||||
NewMCPServerRequest(
|
||||
server_name="test",
|
||||
transport="stdio",
|
||||
command="/bin/bash",
|
||||
args=["-c", "echo pwned"],
|
||||
)
|
||||
|
||||
def test_absolute_path_to_allowed_command_works(self):
|
||||
"""/usr/bin/python3 should pass (basename is 'python3')."""
|
||||
req = NewMCPServerRequest(
|
||||
server_name="test",
|
||||
transport="stdio",
|
||||
command="/usr/bin/python3",
|
||||
args=["-m", "some_module"],
|
||||
)
|
||||
assert req.command == "/usr/bin/python3"
|
||||
|
||||
def test_http_transport_ignores_allowlist(self):
|
||||
"""HTTP/SSE transport should not trigger command validation."""
|
||||
req = NewMCPServerRequest(
|
||||
server_name="test",
|
||||
transport="sse",
|
||||
url="https://example.com/mcp",
|
||||
)
|
||||
assert req.transport == "sse"
|
||||
|
||||
def test_uvx_command_passes(self):
|
||||
req = NewMCPServerRequest(
|
||||
server_name="test",
|
||||
transport="stdio",
|
||||
command="uvx",
|
||||
args=["mcp-server-sqlite"],
|
||||
)
|
||||
assert req.command == "uvx"
|
||||
|
||||
def test_node_command_passes(self):
|
||||
req = NewMCPServerRequest(
|
||||
server_name="test",
|
||||
transport="stdio",
|
||||
command="node",
|
||||
args=["server.js"],
|
||||
)
|
||||
assert req.command == "node"
|
||||
|
||||
|
||||
class TestEndpointRoleChecks:
|
||||
"""Tests for PROXY_ADMIN role checks on MCP test endpoints."""
|
||||
|
||||
def test_test_connection_has_auth_dependency(self):
|
||||
route = _get_route("/mcp-rest/test/connection", "POST")
|
||||
assert _route_has_dependency(route, user_api_key_auth)
|
||||
|
||||
def test_test_tools_list_has_auth_dependency(self):
|
||||
route = _get_route("/mcp-rest/test/tools/list", "POST")
|
||||
assert _route_has_dependency(route, user_api_key_auth)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_test_connection_rejects_non_admin(self):
|
||||
"""Non-admin users should get 403 from test_connection."""
|
||||
from litellm.proxy._types import LitellmUserRoles
|
||||
|
||||
payload = NewMCPServerRequest(
|
||||
server_name="test",
|
||||
url="https://example.com/mcp",
|
||||
auth_type=MCPAuth.none,
|
||||
)
|
||||
user_key = UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||
user_id="non_admin",
|
||||
api_key="sk-test",
|
||||
)
|
||||
request = _build_request()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await rest_endpoints.test_connection(
|
||||
request=request,
|
||||
new_mcp_server_request=payload,
|
||||
user_api_key_dict=user_key,
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_test_tools_list_rejects_non_admin(self):
|
||||
"""Non-admin users should get 403 from test_tools_list."""
|
||||
from litellm.proxy._types import LitellmUserRoles
|
||||
|
||||
payload = NewMCPServerRequest(
|
||||
server_name="test",
|
||||
url="https://example.com/mcp",
|
||||
auth_type=MCPAuth.none,
|
||||
)
|
||||
user_key = UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.INTERNAL_USER,
|
||||
user_id="non_admin",
|
||||
api_key="sk-test",
|
||||
)
|
||||
request = _build_request()
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await rest_endpoints.test_tools_list(
|
||||
request=request,
|
||||
new_mcp_server_request=payload,
|
||||
user_api_key_dict=user_key,
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_test_connection_allows_admin(self, monkeypatch):
|
||||
"""PROXY_ADMIN should pass the role check."""
|
||||
from litellm.proxy._types import LitellmUserRoles
|
||||
|
||||
async def fake_execute(*args, **kwargs):
|
||||
return {"status": "ok"}
|
||||
|
||||
monkeypatch.setattr(
|
||||
rest_endpoints,
|
||||
"_execute_with_mcp_client",
|
||||
fake_execute,
|
||||
)
|
||||
|
||||
payload = NewMCPServerRequest(
|
||||
server_name="test",
|
||||
url="https://example.com/mcp",
|
||||
auth_type=MCPAuth.none,
|
||||
)
|
||||
user_key = UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
user_id="admin",
|
||||
api_key="sk-admin",
|
||||
)
|
||||
request = _build_request()
|
||||
|
||||
result = await rest_endpoints.test_connection(
|
||||
request=request,
|
||||
new_mcp_server_request=payload,
|
||||
user_api_key_dict=user_key,
|
||||
)
|
||||
assert result["status"] == "ok"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user