fix(mcp): block arbitrary command execution via stdio transport

Add command allowlist for MCP stdio transport to prevent RCE via
/mcp-rest/test/* endpoints. Restrict test endpoints to PROXY_ADMIN
role. Fix docker/README.md MASTER_KEY -> LITELLM_MASTER_KEY.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Sameer Kankute 2026-03-26 15:56:54 +05:30
parent 62757ff48f
commit 7b7f304675
No known key found for this signature in database
6 changed files with 337 additions and 39 deletions

View File

@ -13,19 +13,19 @@ To build and run the application, you will use the `docker-compose.yml` file loc
### 1. Set the Master Key
The application requires a `MASTER_KEY` for signing and validating tokens. You must set this key as an environment variable before running the application.
The application requires a `LITELLM_MASTER_KEY` for signing and validating tokens. You must set this key as an environment variable before running the application.
Create a `.env` file in the root of the project and add the following line:
```
MASTER_KEY=your-secret-key
LITELLM_MASTER_KEY=your-secret-key
```
Replace `your-secret-key` with a strong, randomly generated secret.
### 2. Build and Run the Containers
Once you have set the `MASTER_KEY`, you can build and run the containers using the following command:
Once you have set the `LITELLM_MASTER_KEY`, you can build and run the containers using the following command:
```bash
docker compose up -d --build
@ -89,4 +89,4 @@ This command should succeed (showing engine versions) even with `--network none`
## Troubleshooting
- **`build_admin_ui.sh: not found`**: This error can occur if the Docker build context is not set correctly. Ensure that you are running the `docker-compose` command from the root of the project.
- **`Master key is not initialized`**: This error means the `MASTER_key` environment variable is not set. Make sure you have created a `.env` file in the project root with the `MASTER_KEY` defined.
- **`Master key is not initialized`**: This error means the `LITELLM_MASTER_KEY` environment variable is not set. Make sure you have created a `.env` file in the project root with the `LITELLM_MASTER_KEY` defined.

View File

@ -141,6 +141,15 @@ MCP_TOOL_LISTING_TIMEOUT = float(os.getenv("LITELLM_MCP_TOOL_LISTING_TIMEOUT", "
MCP_METADATA_TIMEOUT = float(os.getenv("LITELLM_MCP_METADATA_TIMEOUT", "10.0"))
MCP_HEALTH_CHECK_TIMEOUT = float(os.getenv("LITELLM_MCP_HEALTH_CHECK_TIMEOUT", "10.0"))
# Allowlist of commands permitted for MCP stdio transport.
# Prevents arbitrary command execution via /mcp-rest/test/* endpoints or server creation.
# Extend via LITELLM_MCP_STDIO_EXTRA_COMMANDS env var (comma-separated).
_MCP_STDIO_EXTRA_COMMANDS = os.getenv("LITELLM_MCP_STDIO_EXTRA_COMMANDS", "")
MCP_STDIO_ALLOWED_COMMANDS: frozenset = frozenset(
{"npx", "uvx", "python", "python3", "node", "docker", "deno"}
| (set(_MCP_STDIO_EXTRA_COMMANDS.split(",")) - {""})
)
LITELLM_UI_ALLOW_HEADERS = [
"x-litellm-semantic-filter",
"x-litellm-semantic-filter-tools",

View File

@ -1122,6 +1122,20 @@ class MCPServerManager:
from litellm.constants import MCP_NPM_CACHE_DIR
resolved_env["NPM_CONFIG_CACHE"] = MCP_NPM_CACHE_DIR
# Defense-in-depth: validate command even if Pydantic validation was bypassed
# (e.g. MCPServer built from config/DB records predating the allowlist)
if server.command:
import os as _os
from litellm.constants import MCP_STDIO_ALLOWED_COMMANDS
base_command = _os.path.basename(server.command)
if base_command not in MCP_STDIO_ALLOWED_COMMANDS:
raise ValueError(
f"Command '{server.command}' is not in the allowed commands list "
f"for stdio transport. Allowed commands: {sorted(MCP_STDIO_ALLOWED_COMMANDS)}"
)
stdio_config: Optional[MCPStdioConfig] = None
if server.command and server.args is not None:
stdio_config = MCPStdioConfig(

View File

@ -2,14 +2,14 @@ import importlib
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Set, Union
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
from litellm._logging import verbose_logger
from litellm.proxy._experimental.mcp_server.ui_session_utils import (
build_effective_auth_contexts,
)
from litellm.proxy._experimental.mcp_server.utils import merge_mcp_headers
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_utils.http_parsing_utils import _safe_get_request_headers
@ -1027,6 +1027,13 @@ if MCP_AVAILABLE:
"""
Test if we can connect to the provided MCP server before adding it
"""
if LitellmUserRoles.PROXY_ADMIN != user_api_key_dict.user_role:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"error": "User does not have permission to test MCP server connections. Only PROXY_ADMIN users can perform this action."
},
)
async def _test_connection_operation(client):
async def _noop(session):
@ -1041,7 +1048,7 @@ if MCP_AVAILABLE:
raw_headers=_safe_get_request_headers(request),
)
@router.post("/test/tools/list")
@router.post("/test/tools/list", dependencies=[Depends(user_api_key_auth)])
async def test_tools_list(
request: Request,
new_mcp_server_request: NewMCPServerRequest,
@ -1050,6 +1057,14 @@ if MCP_AVAILABLE:
"""
Preview tools available from MCP server before adding it
"""
if LitellmUserRoles.PROXY_ADMIN != user_api_key_dict.user_role:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"error": "User does not have permission to test MCP server tools. Only PROXY_ADMIN users can perform this action."
},
)
# For OpenAPI spec servers, generate tools from the spec directly
if new_mcp_server_request.spec_path:
return await _preview_openapi_tools(new_mcp_server_request.spec_path)

View File

@ -1162,6 +1162,17 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase):
raise ValueError("command is required for stdio transport")
if not values.get("args"):
raise ValueError("args is required for stdio transport")
# Validate command against allowlist to prevent arbitrary execution
import os as _os
from litellm.constants import MCP_STDIO_ALLOWED_COMMANDS
base_command = _os.path.basename(values["command"])
if base_command not in MCP_STDIO_ALLOWED_COMMANDS:
raise ValueError(
f"Command '{values['command']}' is not in the allowed commands list "
f"for stdio transport. Allowed commands: {sorted(MCP_STDIO_ALLOWED_COMMANDS)}"
)
elif transport in [MCPTransport.http, MCPTransport.sse]:
if not values.get("url") and not values.get("spec_path"):
raise ValueError(
@ -1222,6 +1233,17 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase):
raise ValueError("command is required for stdio transport")
if not values.get("args"):
raise ValueError("args is required for stdio transport")
# Validate command against allowlist to prevent arbitrary execution
import os as _os
from litellm.constants import MCP_STDIO_ALLOWED_COMMANDS
base_command = _os.path.basename(values["command"])
if base_command not in MCP_STDIO_ALLOWED_COMMANDS:
raise ValueError(
f"Command '{values['command']}' is not in the allowed commands list "
f"for stdio transport. Allowed commands: {sorted(MCP_STDIO_ALLOWED_COMMANDS)}"
)
elif transport in [MCPTransport.http, MCPTransport.sse]:
if not values.get("url") and not values.get("spec_path"):
raise ValueError(

View File

@ -156,7 +156,6 @@ class TestExecuteWithMcpClient:
"Authorization": "STATIC token",
}
@pytest.mark.asyncio
async def test_m2m_credentials_forwarded_to_server_model(self, monkeypatch):
"""M2M OAuth credentials (client_id, client_secret) from the nested
@ -199,9 +198,7 @@ class TestExecuteWithMcpClient:
},
)
result = await rest_endpoints._execute_with_mcp_client(
payload, ok_operation
)
result = await rest_endpoints._execute_with_mcp_client(payload, ok_operation)
assert result["status"] == "ok"
server = captured["server"]
@ -262,7 +259,10 @@ class TestExecuteWithMcpClient:
assert result["status"] == "ok"
# The incoming Authorization must be dropped — extra_headers should
# contain no oauth2 headers (only static_headers, which are None here).
assert captured["extra_headers"] is None or "Authorization" not in captured["extra_headers"]
assert (
captured["extra_headers"] is None
or "Authorization" not in captured["extra_headers"]
)
@pytest.mark.asyncio
async def test_catches_exception_group(self, monkeypatch):
@ -300,9 +300,7 @@ class TestExecuteWithMcpClient:
auth_type=MCPAuth.none,
)
result = await rest_endpoints._execute_with_mcp_client(
payload, ok_operation
)
result = await rest_endpoints._execute_with_mcp_client(payload, ok_operation)
assert result["status"] == "error"
assert result["error"] is True
@ -365,8 +363,12 @@ class TestTestToolsList:
credentials={"auth_value": "secret-key"},
)
from litellm.proxy._types import LitellmUserRoles
result = await rest_endpoints.test_tools_list(
request, payload, user_api_key_dict=UserAPIKeyAuth()
request,
payload,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
assert result["message"] == "Successfully retrieved tools"
@ -419,8 +421,12 @@ class TestTestToolsList:
auth_type=MCPAuth.oauth2,
)
from litellm.proxy._types import LitellmUserRoles
result = await rest_endpoints.test_tools_list(
request, payload, user_api_key_dict=UserAPIKeyAuth()
request,
payload,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
assert result["message"] == "Successfully retrieved tools"
@ -484,7 +490,11 @@ class TestListToolsRestAPI:
captured = {"called": False}
async def fake_get_tools(
server, server_auth_header, raw_headers=None, user_api_key_auth=None, extra_headers=None
server,
server_auth_header,
raw_headers=None,
user_api_key_auth=None,
extra_headers=None,
):
captured["called"] = True
captured["server"] = server
@ -555,27 +565,47 @@ class TestListToolsRestAPI:
captured = {"called": False, "server_arg": None}
async def fake_get_tools(server, server_auth_header, raw_headers=None, user_api_key_auth=None, extra_headers=None):
async def fake_get_tools(
server,
server_auth_header,
raw_headers=None,
user_api_key_auth=None,
extra_headers=None,
):
captured["called"] = True
captured["server_arg"] = server
return ["tool-x"]
monkeypatch.setattr(rest_endpoints, "build_effective_auth_contexts", fake_contexts, raising=False)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager, "get_allowed_mcp_servers",
fake_get_allowed_mcp_servers, raising=False,
rest_endpoints,
"build_effective_auth_contexts",
fake_contexts,
raising=False,
)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_name",
rest_endpoints.global_mcp_server_manager,
"get_allowed_mcp_servers",
fake_get_allowed_mcp_servers,
raising=False,
)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager,
"get_mcp_server_by_name",
lambda name: stub_server if name == "my-server" else None,
raising=False,
)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_id",
rest_endpoints.global_mcp_server_manager,
"get_mcp_server_by_id",
lambda sid: stub_server if sid == "uuid-abc-123" else None,
raising=False,
)
monkeypatch.setattr(rest_endpoints, "_get_tools_for_single_server", fake_get_tools, raising=False)
monkeypatch.setattr(
rest_endpoints,
"_get_tools_for_single_server",
fake_get_tools,
raising=False,
)
request = _build_request(path="/mcp-rest/tools/list", method="GET")
result = await rest_endpoints.list_tool_rest_api(
@ -609,18 +639,27 @@ class TestListToolsRestAPI:
async def fake_get_allowed_mcp_servers(*args, **kwargs):
return []
monkeypatch.setattr(rest_endpoints, "build_effective_auth_contexts", fake_contexts, raising=False)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager, "get_allowed_mcp_servers",
fake_get_allowed_mcp_servers, raising=False,
rest_endpoints,
"build_effective_auth_contexts",
fake_contexts,
raising=False,
)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_name",
rest_endpoints.global_mcp_server_manager,
"get_allowed_mcp_servers",
fake_get_allowed_mcp_servers,
raising=False,
)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager,
"get_mcp_server_by_name",
lambda name: stub_server if name == "restricted-server" else None,
raising=False,
)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_id",
rest_endpoints.global_mcp_server_manager,
"get_mcp_server_by_id",
lambda sid: stub_server if sid == "uuid-xyz-999" else None,
raising=False,
)
@ -662,31 +701,54 @@ class TestListToolsRestAPI:
oauth_headers = {"Authorization": "Bearer user-oauth-token"}
async def fake_get_user_oauth_extra_headers(server, user_api_key_dict, prefetched_creds=None):
async def fake_get_user_oauth_extra_headers(
server, user_api_key_dict, prefetched_creds=None
):
return oauth_headers
captured = {}
async def fake_get_tools(server, server_auth_header, raw_headers=None, user_api_key_auth=None, extra_headers=None):
async def fake_get_tools(
server,
server_auth_header,
raw_headers=None,
user_api_key_auth=None,
extra_headers=None,
):
captured["server"] = server
captured["auth_header"] = server_auth_header
return ["oauth-tool"]
monkeypatch.setattr(rest_endpoints, "build_effective_auth_contexts", fake_contexts, raising=False)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager, "get_allowed_mcp_servers",
fake_get_allowed_mcp_servers, raising=False,
rest_endpoints,
"build_effective_auth_contexts",
fake_contexts,
raising=False,
)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_id",
rest_endpoints.global_mcp_server_manager,
"get_allowed_mcp_servers",
fake_get_allowed_mcp_servers,
raising=False,
)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager,
"get_mcp_server_by_id",
lambda sid: stub_server if sid == "oauth-server-id" else None,
raising=False,
)
monkeypatch.setattr(
rest_endpoints, "_get_user_oauth_extra_headers",
fake_get_user_oauth_extra_headers, raising=False,
rest_endpoints,
"_get_user_oauth_extra_headers",
fake_get_user_oauth_extra_headers,
raising=False,
)
monkeypatch.setattr(
rest_endpoints,
"_get_tools_for_single_server",
fake_get_tools,
raising=False,
)
monkeypatch.setattr(rest_endpoints, "_get_tools_for_single_server", fake_get_tools, raising=False)
request = _build_request(path="/mcp-rest/tools/list", method="GET")
result = await rest_endpoints.list_tool_rest_api(
@ -1124,3 +1186,179 @@ class TestGetToolsForSingleServer:
assert "tool3" in tool_names
assert "tool1" not in tool_names
assert "tool4" not in tool_names
class TestStdioCommandAllowlist:
"""Tests for MCP stdio command allowlist validation."""
def test_allowed_command_passes_validation(self):
"""npx, uvx, python, etc. should be accepted."""
req = NewMCPServerRequest(
server_name="test",
transport="stdio",
command="npx",
args=["-y", "@modelcontextprotocol/server-filesystem"],
)
assert req.command == "npx"
def test_disallowed_command_raises(self):
"""Arbitrary commands like bash should be rejected."""
with pytest.raises(ValueError, match="not in the allowed commands list"):
NewMCPServerRequest(
server_name="test",
transport="stdio",
command="bash",
args=["-c", "echo pwned"],
)
def test_sh_command_raises(self):
"""sh should be rejected."""
with pytest.raises(ValueError, match="not in the allowed commands list"):
NewMCPServerRequest(
server_name="test",
transport="stdio",
command="sh",
args=["-c", "id > /tmp/output.txt"],
)
def test_absolute_path_bypass_blocked(self):
"""/bin/bash should be blocked (basename is 'bash')."""
with pytest.raises(ValueError, match="not in the allowed commands list"):
NewMCPServerRequest(
server_name="test",
transport="stdio",
command="/bin/bash",
args=["-c", "echo pwned"],
)
def test_absolute_path_to_allowed_command_works(self):
"""/usr/bin/python3 should pass (basename is 'python3')."""
req = NewMCPServerRequest(
server_name="test",
transport="stdio",
command="/usr/bin/python3",
args=["-m", "some_module"],
)
assert req.command == "/usr/bin/python3"
def test_http_transport_ignores_allowlist(self):
"""HTTP/SSE transport should not trigger command validation."""
req = NewMCPServerRequest(
server_name="test",
transport="sse",
url="https://example.com/mcp",
)
assert req.transport == "sse"
def test_uvx_command_passes(self):
req = NewMCPServerRequest(
server_name="test",
transport="stdio",
command="uvx",
args=["mcp-server-sqlite"],
)
assert req.command == "uvx"
def test_node_command_passes(self):
req = NewMCPServerRequest(
server_name="test",
transport="stdio",
command="node",
args=["server.js"],
)
assert req.command == "node"
class TestEndpointRoleChecks:
"""Tests for PROXY_ADMIN role checks on MCP test endpoints."""
def test_test_connection_has_auth_dependency(self):
route = _get_route("/mcp-rest/test/connection", "POST")
assert _route_has_dependency(route, user_api_key_auth)
def test_test_tools_list_has_auth_dependency(self):
route = _get_route("/mcp-rest/test/tools/list", "POST")
assert _route_has_dependency(route, user_api_key_auth)
@pytest.mark.asyncio
async def test_test_connection_rejects_non_admin(self):
"""Non-admin users should get 403 from test_connection."""
from litellm.proxy._types import LitellmUserRoles
payload = NewMCPServerRequest(
server_name="test",
url="https://example.com/mcp",
auth_type=MCPAuth.none,
)
user_key = UserAPIKeyAuth(
user_role=LitellmUserRoles.INTERNAL_USER,
user_id="non_admin",
api_key="sk-test",
)
request = _build_request()
with pytest.raises(HTTPException) as exc_info:
await rest_endpoints.test_connection(
request=request,
new_mcp_server_request=payload,
user_api_key_dict=user_key,
)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_test_tools_list_rejects_non_admin(self):
"""Non-admin users should get 403 from test_tools_list."""
from litellm.proxy._types import LitellmUserRoles
payload = NewMCPServerRequest(
server_name="test",
url="https://example.com/mcp",
auth_type=MCPAuth.none,
)
user_key = UserAPIKeyAuth(
user_role=LitellmUserRoles.INTERNAL_USER,
user_id="non_admin",
api_key="sk-test",
)
request = _build_request()
with pytest.raises(HTTPException) as exc_info:
await rest_endpoints.test_tools_list(
request=request,
new_mcp_server_request=payload,
user_api_key_dict=user_key,
)
assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_test_connection_allows_admin(self, monkeypatch):
"""PROXY_ADMIN should pass the role check."""
from litellm.proxy._types import LitellmUserRoles
async def fake_execute(*args, **kwargs):
return {"status": "ok"}
monkeypatch.setattr(
rest_endpoints,
"_execute_with_mcp_client",
fake_execute,
)
payload = NewMCPServerRequest(
server_name="test",
url="https://example.com/mcp",
auth_type=MCPAuth.none,
)
user_key = UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
user_id="admin",
api_key="sk-admin",
)
request = _build_request()
result = await rest_endpoints.test_connection(
request=request,
new_mcp_server_request=payload,
user_api_key_dict=user_key,
)
assert result["status"] == "ok"