diff --git a/docker/README.md b/docker/README.md index 7027a30fdd..26d8c9a37b 100644 --- a/docker/README.md +++ b/docker/README.md @@ -13,19 +13,19 @@ To build and run the application, you will use the `docker-compose.yml` file loc ### 1. Set the Master Key -The application requires a `MASTER_KEY` for signing and validating tokens. You must set this key as an environment variable before running the application. +The application requires a `LITELLM_MASTER_KEY` for signing and validating tokens. You must set this key as an environment variable before running the application. Create a `.env` file in the root of the project and add the following line: ``` -MASTER_KEY=your-secret-key +LITELLM_MASTER_KEY=your-secret-key ``` Replace `your-secret-key` with a strong, randomly generated secret. ### 2. Build and Run the Containers -Once you have set the `MASTER_KEY`, you can build and run the containers using the following command: +Once you have set the `LITELLM_MASTER_KEY`, you can build and run the containers using the following command: ```bash docker compose up -d --build @@ -89,4 +89,4 @@ This command should succeed (showing engine versions) even with `--network none` ## Troubleshooting - **`build_admin_ui.sh: not found`**: This error can occur if the Docker build context is not set correctly. Ensure that you are running the `docker-compose` command from the root of the project. -- **`Master key is not initialized`**: This error means the `MASTER_key` environment variable is not set. Make sure you have created a `.env` file in the project root with the `MASTER_KEY` defined. +- **`Master key is not initialized`**: This error means the `LITELLM_MASTER_KEY` environment variable is not set. Make sure you have created a `.env` file in the project root with the `LITELLM_MASTER_KEY` defined. diff --git a/litellm/constants.py b/litellm/constants.py index 28c6c0cc0e..49ec47e251 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -141,6 +141,15 @@ MCP_TOOL_LISTING_TIMEOUT = float(os.getenv("LITELLM_MCP_TOOL_LISTING_TIMEOUT", " MCP_METADATA_TIMEOUT = float(os.getenv("LITELLM_MCP_METADATA_TIMEOUT", "10.0")) MCP_HEALTH_CHECK_TIMEOUT = float(os.getenv("LITELLM_MCP_HEALTH_CHECK_TIMEOUT", "10.0")) +# Allowlist of commands permitted for MCP stdio transport. +# Prevents arbitrary command execution via /mcp-rest/test/* endpoints or server creation. +# Extend via LITELLM_MCP_STDIO_EXTRA_COMMANDS env var (comma-separated). +_MCP_STDIO_EXTRA_COMMANDS = os.getenv("LITELLM_MCP_STDIO_EXTRA_COMMANDS", "") +MCP_STDIO_ALLOWED_COMMANDS: frozenset = frozenset( + {"npx", "uvx", "python", "python3", "node", "docker", "deno"} + | (set(_MCP_STDIO_EXTRA_COMMANDS.split(",")) - {""}) +) + LITELLM_UI_ALLOW_HEADERS = [ "x-litellm-semantic-filter", "x-litellm-semantic-filter-tools", diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 7b87e7e7e6..e8ad0bf748 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -1122,6 +1122,20 @@ class MCPServerManager: from litellm.constants import MCP_NPM_CACHE_DIR resolved_env["NPM_CONFIG_CACHE"] = MCP_NPM_CACHE_DIR + # Defense-in-depth: validate command even if Pydantic validation was bypassed + # (e.g. MCPServer built from config/DB records predating the allowlist) + if server.command: + import os as _os + + from litellm.constants import MCP_STDIO_ALLOWED_COMMANDS + + base_command = _os.path.basename(server.command) + if base_command not in MCP_STDIO_ALLOWED_COMMANDS: + raise ValueError( + f"Command '{server.command}' is not in the allowed commands list " + f"for stdio transport. Allowed commands: {sorted(MCP_STDIO_ALLOWED_COMMANDS)}" + ) + stdio_config: Optional[MCPStdioConfig] = None if server.command and server.args is not None: stdio_config = MCPStdioConfig( diff --git a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py index c0151d47e0..32560a2211 100644 --- a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py @@ -2,14 +2,14 @@ import importlib from datetime import datetime from typing import Any, Awaitable, Callable, Dict, List, Literal, Optional, Set, Union -from fastapi import APIRouter, Depends, HTTPException, Query, Request +from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from litellm._logging import verbose_logger from litellm.proxy._experimental.mcp_server.ui_session_utils import ( build_effective_auth_contexts, ) from litellm.proxy._experimental.mcp_server.utils import merge_mcp_headers -from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth from litellm.proxy.auth.ip_address_utils import IPAddressUtils from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.common_utils.http_parsing_utils import _safe_get_request_headers @@ -1027,6 +1027,13 @@ if MCP_AVAILABLE: """ Test if we can connect to the provided MCP server before adding it """ + if LitellmUserRoles.PROXY_ADMIN != user_api_key_dict.user_role: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "error": "User does not have permission to test MCP server connections. Only PROXY_ADMIN users can perform this action." + }, + ) async def _test_connection_operation(client): async def _noop(session): @@ -1041,7 +1048,7 @@ if MCP_AVAILABLE: raw_headers=_safe_get_request_headers(request), ) - @router.post("/test/tools/list") + @router.post("/test/tools/list", dependencies=[Depends(user_api_key_auth)]) async def test_tools_list( request: Request, new_mcp_server_request: NewMCPServerRequest, @@ -1050,6 +1057,14 @@ if MCP_AVAILABLE: """ Preview tools available from MCP server before adding it """ + if LitellmUserRoles.PROXY_ADMIN != user_api_key_dict.user_role: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "error": "User does not have permission to test MCP server tools. Only PROXY_ADMIN users can perform this action." + }, + ) + # For OpenAPI spec servers, generate tools from the spec directly if new_mcp_server_request.spec_path: return await _preview_openapi_tools(new_mcp_server_request.spec_path) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 441b3b836a..0f537433a3 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1162,6 +1162,17 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase): raise ValueError("command is required for stdio transport") if not values.get("args"): raise ValueError("args is required for stdio transport") + # Validate command against allowlist to prevent arbitrary execution + import os as _os + + from litellm.constants import MCP_STDIO_ALLOWED_COMMANDS + + base_command = _os.path.basename(values["command"]) + if base_command not in MCP_STDIO_ALLOWED_COMMANDS: + raise ValueError( + f"Command '{values['command']}' is not in the allowed commands list " + f"for stdio transport. Allowed commands: {sorted(MCP_STDIO_ALLOWED_COMMANDS)}" + ) elif transport in [MCPTransport.http, MCPTransport.sse]: if not values.get("url") and not values.get("spec_path"): raise ValueError( @@ -1222,6 +1233,17 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase): raise ValueError("command is required for stdio transport") if not values.get("args"): raise ValueError("args is required for stdio transport") + # Validate command against allowlist to prevent arbitrary execution + import os as _os + + from litellm.constants import MCP_STDIO_ALLOWED_COMMANDS + + base_command = _os.path.basename(values["command"]) + if base_command not in MCP_STDIO_ALLOWED_COMMANDS: + raise ValueError( + f"Command '{values['command']}' is not in the allowed commands list " + f"for stdio transport. Allowed commands: {sorted(MCP_STDIO_ALLOWED_COMMANDS)}" + ) elif transport in [MCPTransport.http, MCPTransport.sse]: if not values.get("url") and not values.get("spec_path"): raise ValueError( diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py index 3acbe5465f..25786d982c 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py @@ -156,7 +156,6 @@ class TestExecuteWithMcpClient: "Authorization": "STATIC token", } - @pytest.mark.asyncio async def test_m2m_credentials_forwarded_to_server_model(self, monkeypatch): """M2M OAuth credentials (client_id, client_secret) from the nested @@ -199,9 +198,7 @@ class TestExecuteWithMcpClient: }, ) - result = await rest_endpoints._execute_with_mcp_client( - payload, ok_operation - ) + result = await rest_endpoints._execute_with_mcp_client(payload, ok_operation) assert result["status"] == "ok" server = captured["server"] @@ -262,7 +259,10 @@ class TestExecuteWithMcpClient: assert result["status"] == "ok" # The incoming Authorization must be dropped — extra_headers should # contain no oauth2 headers (only static_headers, which are None here). - assert captured["extra_headers"] is None or "Authorization" not in captured["extra_headers"] + assert ( + captured["extra_headers"] is None + or "Authorization" not in captured["extra_headers"] + ) @pytest.mark.asyncio async def test_catches_exception_group(self, monkeypatch): @@ -300,9 +300,7 @@ class TestExecuteWithMcpClient: auth_type=MCPAuth.none, ) - result = await rest_endpoints._execute_with_mcp_client( - payload, ok_operation - ) + result = await rest_endpoints._execute_with_mcp_client(payload, ok_operation) assert result["status"] == "error" assert result["error"] is True @@ -365,8 +363,12 @@ class TestTestToolsList: credentials={"auth_value": "secret-key"}, ) + from litellm.proxy._types import LitellmUserRoles + result = await rest_endpoints.test_tools_list( - request, payload, user_api_key_dict=UserAPIKeyAuth() + request, + payload, + user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), ) assert result["message"] == "Successfully retrieved tools" @@ -419,8 +421,12 @@ class TestTestToolsList: auth_type=MCPAuth.oauth2, ) + from litellm.proxy._types import LitellmUserRoles + result = await rest_endpoints.test_tools_list( - request, payload, user_api_key_dict=UserAPIKeyAuth() + request, + payload, + user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), ) assert result["message"] == "Successfully retrieved tools" @@ -484,7 +490,11 @@ class TestListToolsRestAPI: captured = {"called": False} async def fake_get_tools( - server, server_auth_header, raw_headers=None, user_api_key_auth=None, extra_headers=None + server, + server_auth_header, + raw_headers=None, + user_api_key_auth=None, + extra_headers=None, ): captured["called"] = True captured["server"] = server @@ -555,27 +565,47 @@ class TestListToolsRestAPI: captured = {"called": False, "server_arg": None} - async def fake_get_tools(server, server_auth_header, raw_headers=None, user_api_key_auth=None, extra_headers=None): + async def fake_get_tools( + server, + server_auth_header, + raw_headers=None, + user_api_key_auth=None, + extra_headers=None, + ): captured["called"] = True captured["server_arg"] = server return ["tool-x"] - monkeypatch.setattr(rest_endpoints, "build_effective_auth_contexts", fake_contexts, raising=False) monkeypatch.setattr( - rest_endpoints.global_mcp_server_manager, "get_allowed_mcp_servers", - fake_get_allowed_mcp_servers, raising=False, + rest_endpoints, + "build_effective_auth_contexts", + fake_contexts, + raising=False, ) monkeypatch.setattr( - rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_name", + rest_endpoints.global_mcp_server_manager, + "get_allowed_mcp_servers", + fake_get_allowed_mcp_servers, + raising=False, + ) + monkeypatch.setattr( + rest_endpoints.global_mcp_server_manager, + "get_mcp_server_by_name", lambda name: stub_server if name == "my-server" else None, raising=False, ) monkeypatch.setattr( - rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_id", + rest_endpoints.global_mcp_server_manager, + "get_mcp_server_by_id", lambda sid: stub_server if sid == "uuid-abc-123" else None, raising=False, ) - monkeypatch.setattr(rest_endpoints, "_get_tools_for_single_server", fake_get_tools, raising=False) + monkeypatch.setattr( + rest_endpoints, + "_get_tools_for_single_server", + fake_get_tools, + raising=False, + ) request = _build_request(path="/mcp-rest/tools/list", method="GET") result = await rest_endpoints.list_tool_rest_api( @@ -609,18 +639,27 @@ class TestListToolsRestAPI: async def fake_get_allowed_mcp_servers(*args, **kwargs): return [] - monkeypatch.setattr(rest_endpoints, "build_effective_auth_contexts", fake_contexts, raising=False) monkeypatch.setattr( - rest_endpoints.global_mcp_server_manager, "get_allowed_mcp_servers", - fake_get_allowed_mcp_servers, raising=False, + rest_endpoints, + "build_effective_auth_contexts", + fake_contexts, + raising=False, ) monkeypatch.setattr( - rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_name", + rest_endpoints.global_mcp_server_manager, + "get_allowed_mcp_servers", + fake_get_allowed_mcp_servers, + raising=False, + ) + monkeypatch.setattr( + rest_endpoints.global_mcp_server_manager, + "get_mcp_server_by_name", lambda name: stub_server if name == "restricted-server" else None, raising=False, ) monkeypatch.setattr( - rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_id", + rest_endpoints.global_mcp_server_manager, + "get_mcp_server_by_id", lambda sid: stub_server if sid == "uuid-xyz-999" else None, raising=False, ) @@ -662,31 +701,54 @@ class TestListToolsRestAPI: oauth_headers = {"Authorization": "Bearer user-oauth-token"} - async def fake_get_user_oauth_extra_headers(server, user_api_key_dict, prefetched_creds=None): + async def fake_get_user_oauth_extra_headers( + server, user_api_key_dict, prefetched_creds=None + ): return oauth_headers captured = {} - async def fake_get_tools(server, server_auth_header, raw_headers=None, user_api_key_auth=None, extra_headers=None): + async def fake_get_tools( + server, + server_auth_header, + raw_headers=None, + user_api_key_auth=None, + extra_headers=None, + ): captured["server"] = server captured["auth_header"] = server_auth_header return ["oauth-tool"] - monkeypatch.setattr(rest_endpoints, "build_effective_auth_contexts", fake_contexts, raising=False) monkeypatch.setattr( - rest_endpoints.global_mcp_server_manager, "get_allowed_mcp_servers", - fake_get_allowed_mcp_servers, raising=False, + rest_endpoints, + "build_effective_auth_contexts", + fake_contexts, + raising=False, ) monkeypatch.setattr( - rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_id", + rest_endpoints.global_mcp_server_manager, + "get_allowed_mcp_servers", + fake_get_allowed_mcp_servers, + raising=False, + ) + monkeypatch.setattr( + rest_endpoints.global_mcp_server_manager, + "get_mcp_server_by_id", lambda sid: stub_server if sid == "oauth-server-id" else None, raising=False, ) monkeypatch.setattr( - rest_endpoints, "_get_user_oauth_extra_headers", - fake_get_user_oauth_extra_headers, raising=False, + rest_endpoints, + "_get_user_oauth_extra_headers", + fake_get_user_oauth_extra_headers, + raising=False, + ) + monkeypatch.setattr( + rest_endpoints, + "_get_tools_for_single_server", + fake_get_tools, + raising=False, ) - monkeypatch.setattr(rest_endpoints, "_get_tools_for_single_server", fake_get_tools, raising=False) request = _build_request(path="/mcp-rest/tools/list", method="GET") result = await rest_endpoints.list_tool_rest_api( @@ -1124,3 +1186,179 @@ class TestGetToolsForSingleServer: assert "tool3" in tool_names assert "tool1" not in tool_names assert "tool4" not in tool_names + + +class TestStdioCommandAllowlist: + """Tests for MCP stdio command allowlist validation.""" + + def test_allowed_command_passes_validation(self): + """npx, uvx, python, etc. should be accepted.""" + req = NewMCPServerRequest( + server_name="test", + transport="stdio", + command="npx", + args=["-y", "@modelcontextprotocol/server-filesystem"], + ) + assert req.command == "npx" + + def test_disallowed_command_raises(self): + """Arbitrary commands like bash should be rejected.""" + with pytest.raises(ValueError, match="not in the allowed commands list"): + NewMCPServerRequest( + server_name="test", + transport="stdio", + command="bash", + args=["-c", "echo pwned"], + ) + + def test_sh_command_raises(self): + """sh should be rejected.""" + with pytest.raises(ValueError, match="not in the allowed commands list"): + NewMCPServerRequest( + server_name="test", + transport="stdio", + command="sh", + args=["-c", "id > /tmp/output.txt"], + ) + + def test_absolute_path_bypass_blocked(self): + """/bin/bash should be blocked (basename is 'bash').""" + with pytest.raises(ValueError, match="not in the allowed commands list"): + NewMCPServerRequest( + server_name="test", + transport="stdio", + command="/bin/bash", + args=["-c", "echo pwned"], + ) + + def test_absolute_path_to_allowed_command_works(self): + """/usr/bin/python3 should pass (basename is 'python3').""" + req = NewMCPServerRequest( + server_name="test", + transport="stdio", + command="/usr/bin/python3", + args=["-m", "some_module"], + ) + assert req.command == "/usr/bin/python3" + + def test_http_transport_ignores_allowlist(self): + """HTTP/SSE transport should not trigger command validation.""" + req = NewMCPServerRequest( + server_name="test", + transport="sse", + url="https://example.com/mcp", + ) + assert req.transport == "sse" + + def test_uvx_command_passes(self): + req = NewMCPServerRequest( + server_name="test", + transport="stdio", + command="uvx", + args=["mcp-server-sqlite"], + ) + assert req.command == "uvx" + + def test_node_command_passes(self): + req = NewMCPServerRequest( + server_name="test", + transport="stdio", + command="node", + args=["server.js"], + ) + assert req.command == "node" + + +class TestEndpointRoleChecks: + """Tests for PROXY_ADMIN role checks on MCP test endpoints.""" + + def test_test_connection_has_auth_dependency(self): + route = _get_route("/mcp-rest/test/connection", "POST") + assert _route_has_dependency(route, user_api_key_auth) + + def test_test_tools_list_has_auth_dependency(self): + route = _get_route("/mcp-rest/test/tools/list", "POST") + assert _route_has_dependency(route, user_api_key_auth) + + @pytest.mark.asyncio + async def test_test_connection_rejects_non_admin(self): + """Non-admin users should get 403 from test_connection.""" + from litellm.proxy._types import LitellmUserRoles + + payload = NewMCPServerRequest( + server_name="test", + url="https://example.com/mcp", + auth_type=MCPAuth.none, + ) + user_key = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + user_id="non_admin", + api_key="sk-test", + ) + request = _build_request() + + with pytest.raises(HTTPException) as exc_info: + await rest_endpoints.test_connection( + request=request, + new_mcp_server_request=payload, + user_api_key_dict=user_key, + ) + assert exc_info.value.status_code == 403 + + @pytest.mark.asyncio + async def test_test_tools_list_rejects_non_admin(self): + """Non-admin users should get 403 from test_tools_list.""" + from litellm.proxy._types import LitellmUserRoles + + payload = NewMCPServerRequest( + server_name="test", + url="https://example.com/mcp", + auth_type=MCPAuth.none, + ) + user_key = UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, + user_id="non_admin", + api_key="sk-test", + ) + request = _build_request() + + with pytest.raises(HTTPException) as exc_info: + await rest_endpoints.test_tools_list( + request=request, + new_mcp_server_request=payload, + user_api_key_dict=user_key, + ) + assert exc_info.value.status_code == 403 + + @pytest.mark.asyncio + async def test_test_connection_allows_admin(self, monkeypatch): + """PROXY_ADMIN should pass the role check.""" + from litellm.proxy._types import LitellmUserRoles + + async def fake_execute(*args, **kwargs): + return {"status": "ok"} + + monkeypatch.setattr( + rest_endpoints, + "_execute_with_mcp_client", + fake_execute, + ) + + payload = NewMCPServerRequest( + server_name="test", + url="https://example.com/mcp", + auth_type=MCPAuth.none, + ) + user_key = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + user_id="admin", + api_key="sk-admin", + ) + request = _build_request() + + result = await rest_endpoints.test_connection( + request=request, + new_mcp_server_request=payload, + user_api_key_dict=user_key, + ) + assert result["status"] == "ok"