From 4ec4ab99d0df118e702432ce7adfc99f7d33e6ec Mon Sep 17 00:00:00 2001 From: Mateo Wang <277851410+mateo-berri@users.noreply.github.com> Date: Fri, 5 Jun 2026 20:15:11 -0700 Subject: [PATCH] feat(mcp): per-server env vars with global + per-user scopes (#28917) --- .../migration.sql | 23 + .../litellm_proxy_extras/schema.prisma | 20 + litellm/experimental_mcp_client/client.py | 38 +- litellm/proxy/_experimental/mcp_server/db.py | 429 ++++- .../mcp_server/mcp_server_manager.py | 341 +++- .../mcp_server/rest_endpoints.py | 50 +- .../proxy/_experimental/mcp_server/server.py | 11 + .../proxy/_experimental/mcp_server/utils.py | 138 +- litellm/proxy/_types.py | 60 + .../key_management_endpoints.py | 10 + .../mcp_management_endpoints.py | 314 +++- litellm/proxy/management_helpers/utils.py | 54 + litellm/proxy/schema.prisma | 20 + .../types/mcp_server/mcp_server_manager.py | 4 + schema.prisma | 20 + .../test_mcp_client.py | 157 +- .../litellm_core_utils/test_token_counter.py | 7 +- .../mcp_server/test_db_credentials.py | 73 +- .../mcp_server/test_mcp_env_vars.py | 1652 +++++++++++++++++ .../mcp_server/test_mcp_server.py | 6 +- .../mcp_server/test_mcp_server_manager.py | 227 +++ .../mcp_server/test_mcp_sigv4_auth.py | 53 + .../mcp_server/test_rest_endpoints.py | 46 + .../test_mcp_management_endpoints.py | 1108 ++++++++++- .../test_management_helpers_utils.py | 148 ++ .../e2e_tests/tests/mcp/mcpServers.spec.ts | 6 +- .../components/mcp_tools/EnvVarsSection.tsx | 142 ++ .../mcp_tools/MCPPermissionManagement.tsx | 11 + .../components/mcp_tools/MCPServerCard.tsx | 383 ++++ .../components/mcp_tools/UserEnvVarsModal.tsx | 141 ++ .../mcp_tools/create_mcp_server.tsx | 14 +- .../components/mcp_tools/mcp_server_edit.tsx | 31 +- .../components/mcp_tools/mcp_servers.test.tsx | 1 + .../src/components/mcp_tools/mcp_servers.tsx | 236 ++- .../src/components/mcp_tools/types.tsx | 35 + .../src/components/mcp_tools/utils.tsx | 26 + .../src/components/networking.tsx | 30 + 37 files changed, 5953 insertions(+), 112 deletions(-) create mode 100644 litellm-proxy-extras/litellm_proxy_extras/migrations/20260520120000_add_mcp_env_vars/migration.sql create mode 100644 tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_env_vars.py create mode 100644 ui/litellm-dashboard/src/components/mcp_tools/EnvVarsSection.tsx create mode 100644 ui/litellm-dashboard/src/components/mcp_tools/MCPServerCard.tsx create mode 100644 ui/litellm-dashboard/src/components/mcp_tools/UserEnvVarsModal.tsx diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260520120000_add_mcp_env_vars/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260520120000_add_mcp_env_vars/migration.sql new file mode 100644 index 0000000000..08d35cd74a --- /dev/null +++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260520120000_add_mcp_env_vars/migration.sql @@ -0,0 +1,23 @@ +-- AlterTable: add admin-configured env_vars to MCP server table +ALTER TABLE "LiteLLM_MCPServerTable" ADD COLUMN IF NOT EXISTS "env_vars" JSONB DEFAULT '[]'; + +-- CreateTable: per-user env var values for MCP servers +CREATE TABLE IF NOT EXISTS "LiteLLM_MCPUserEnvVars" ( + "id" TEXT NOT NULL, + "user_id" TEXT NOT NULL, + "server_id" TEXT NOT NULL, + "values_b64" TEXT NOT NULL, + "created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + "updated_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP, + + CONSTRAINT "LiteLLM_MCPUserEnvVars_pkey" PRIMARY KEY ("id") +); + +-- CreateIndex +CREATE UNIQUE INDEX IF NOT EXISTS "LiteLLM_MCPUserEnvVars_user_id_server_id_key" ON "LiteLLM_MCPUserEnvVars"("user_id", "server_id"); + +-- CreateIndex +CREATE INDEX IF NOT EXISTS "LiteLLM_MCPUserEnvVars_user_id_idx" ON "LiteLLM_MCPUserEnvVars"("user_id"); + +-- CreateIndex +CREATE INDEX IF NOT EXISTS "LiteLLM_MCPUserEnvVars_server_id_idx" ON "LiteLLM_MCPUserEnvVars"("server_id"); diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index b51da83a63..e21c001649 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -311,6 +311,11 @@ model LiteLLM_MCPServerTable { tool_name_to_description Json? @default("{}") extra_headers String[] @default([]) static_headers Json? @default("{}") + // Admin-configured environment variables interpolated into static_headers + // via ${NAME} syntax. Stored as an array of + // {name, value, scope, description}. scope is "global" (value used as-is) + // or "user" (value supplied per-user via LiteLLM_MCPUserEnvVars). + env_vars Json? @default("[]") // Health check status status String? @default("unknown") last_health_check DateTime? @@ -366,6 +371,21 @@ model LiteLLM_MCPUserCredentials { @@unique([user_id, server_id]) } +// Per-user environment variable values for MCP servers. +// values_b64 is an encrypted JSON object: {VAR_NAME: "value", ...}. +model LiteLLM_MCPUserEnvVars { + id String @id @default(uuid()) + user_id String + server_id String + values_b64 String + created_at DateTime @default(now()) + updated_at DateTime @default(now()) @updatedAt + + @@unique([user_id, server_id]) + @@index([user_id]) + @@index([server_id]) +} + // Generate Tokens for Proxy model LiteLLM_VerificationToken { token String @id diff --git a/litellm/experimental_mcp_client/client.py b/litellm/experimental_mcp_client/client.py index aed00c060c..c6d427e7f0 100644 --- a/litellm/experimental_mcp_client/client.py +++ b/litellm/experimental_mcp_client/client.py @@ -60,6 +60,27 @@ def to_basic_auth(auth_value: str) -> str: return base64.b64encode(auth_value.encode("utf-8")).decode() +def _strip_header_whitespace(headers: Dict[str, str]) -> Dict[str, str]: + return { + (key.strip() if isinstance(key, str) else key): ( + value.strip() if isinstance(value, str) else value + ) + for key, value in headers.items() + } + + +def _first_non_cancelled_cause(exc: BaseException) -> Optional[BaseException]: + queue: List[BaseException] = [exc] + while queue: + current = queue.pop(0) + nested = getattr(current, "exceptions", None) + if nested: + queue.extend(nested) + elif not isinstance(current, asyncio.CancelledError): + return current + return None + + TSessionResult = TypeVar("TSessionResult") @@ -335,6 +356,7 @@ class MCPClient: user input (elicitation), or send log messages. """ transport = await transport_ctx.__aenter__() + in_flight_error: Optional[BaseException] = None try: read_stream, write_stream = transport[0], transport[1] # Build session kwargs with optional callbacks @@ -360,11 +382,21 @@ class MCPClient: await session_ctx.__aexit__(None, None, None) except BaseException as e: verbose_logger.debug(f"Error during session context exit: {e}") + except BaseException as e: + in_flight_error = e + raise finally: try: await transport_ctx.__aexit__(None, None, None) - except BaseException as e: - verbose_logger.debug(f"Error during transport context exit: {e}") + except BaseException as exit_error: + verbose_logger.debug( + f"Error during transport context exit: {exit_error}" + ) + root_cause = _first_non_cancelled_cause(exit_error) + if root_cause is not None and isinstance( + in_flight_error, asyncio.CancelledError + ): + raise root_cause from in_flight_error async def run_with_session( self, operation: Callable[[ClientSession], Awaitable[TSessionResult]] @@ -426,7 +458,7 @@ class MCPClient: # update the headers with the extra headers if self.extra_headers: headers.update(self.extra_headers) - return headers + return _strip_header_whitespace(headers) def _create_httpx_client_factory(self) -> Callable[..., httpx.AsyncClient]: """ diff --git a/litellm/proxy/_experimental/mcp_server/db.py b/litellm/proxy/_experimental/mcp_server/db.py index d7b2224eb6..4d41afcb6f 100644 --- a/litellm/proxy/_experimental/mcp_server/db.py +++ b/litellm/proxy/_experimental/mcp_server/db.py @@ -1,5 +1,6 @@ import base64 import binascii +import hashlib import json from datetime import datetime, timedelta, timezone from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast @@ -11,6 +12,7 @@ from litellm.proxy._types import ( LiteLLM_ObjectPermissionTable, LiteLLM_TeamTable, MCPApprovalStatus, + MCPEnvVarScope, MCPSubmissionsSummary, NewMCPServerRequest, SpecialMCPServerName, @@ -28,6 +30,144 @@ from litellm.types.llms.custom_http import httpxSpecialProvider from litellm.types.mcp import MCPCredentials +def _is_global_env_var_scope(scope: Any) -> bool: + """``scope="user"`` entries are placeholders the user fills in; everything + else (including a missing scope) is an admin-supplied global value.""" + return scope != MCPEnvVarScope.user and scope != "user" + + +def _encrypt_global_env_var_values(env_vars: Iterable[Dict[str, Any]]) -> None: + """Encrypt ``scope="global"`` env var values in place before persisting. + + Global values hold admin-supplied secrets (API keys, passwords) that get + interpolated into headers, so they are encrypted at rest like credentials + and the per-user ``values_b64`` column. Per-user placeholders are not + secrets and are stored verbatim. + """ + for entry in env_vars: + if not _is_global_env_var_scope(entry.get("scope")): + continue + value = entry.get("value") + if value: + entry["value"] = encrypt_value_helper(value) + + +def decrypt_global_env_var_values(env_vars: Optional[Iterable[Any]]) -> None: + """Decrypt ``scope="global"`` env var values in place after reading the DB. + + Accepts ``MCPEnvVar`` models (``LiteLLM_MCPServerTable``) or plain dicts + (raw rows / deserialized JSON). Global values are always stored encrypted, + so a value that no longer decrypts (e.g. after a salt-key change) is dropped + and a warning is logged rather than forwarding the ciphertext into upstream + ``${NAME}`` headers, where it would silently fail. + """ + if not env_vars: + return + for entry in env_vars: + is_dict = isinstance(entry, dict) + scope = entry.get("scope") if is_dict else getattr(entry, "scope", None) + if not _is_global_env_var_scope(scope): + continue + value = entry.get("value") if is_dict else getattr(entry, "value", None) + if not value: + continue + decrypted = decrypt_value_helper( + value=value, + key="mcp_global_env_var", + exception_type="debug", + return_original_value=False, + ) + if decrypted is None: + name = entry.get("name") if is_dict else getattr(entry, "name", None) + verbose_proxy_logger.warning( + "MCP global env var %s failed to decrypt (LITELLM_SALT_KEY " + "changed?); dropping it so ciphertext is not sent upstream", + name, + ) + decrypted = "" + if is_dict: + entry["value"] = decrypted + else: + entry.value = decrypted + + +def _decrypt_env_vars_on_returned_row(row: Any) -> None: + """Decrypt ``scope="global"`` env var values on a row returned by Prisma create/update. + + Prisma may hand back ``env_vars`` either as a parsed list (the common case for + JSONB columns) or as a raw JSON string (observed for some write paths). The + in-place decrypt helper only mutates iterables of dicts/models, so a string + payload would silently skip decryption and ciphertext would leak into the + registry via ``add_server``/``update_server`` (which trust the caller). + Parse the string back to a list so the in-place decrypt actually runs, and + write the decrypted list back onto the row so downstream consumers see plain + values. + """ + env_vars = getattr(row, "env_vars", None) + if env_vars is None: + return + if isinstance(env_vars, str): + try: + env_vars = json.loads(env_vars) + except (json.JSONDecodeError, TypeError): + return + if not isinstance(env_vars, list): + return + try: + setattr(row, "env_vars", env_vars) + except (AttributeError, TypeError): + pass + decrypt_global_env_var_values(env_vars) + + +def _reencrypt_global_env_var_values( + env_vars: Optional[Iterable[Any]], new_encryption_key: str +) -> Optional[List[Dict[str, Any]]]: + """Re-encrypt ``scope="global"`` env var values for master-key rotation. + + Each global value is decrypted with the current salt key and re-encrypted + under ``new_encryption_key``. Returns the rebuilt list when at least one + value was rotated, else ``None`` so the caller can skip the DB write. A + value that fails to decrypt is left untouched (and logged) so a corrupt + entry is preserved for recovery rather than overwritten. + """ + if not env_vars: + return None + if isinstance(env_vars, str): + try: + env_vars = json.loads(env_vars) + except (json.JSONDecodeError, TypeError): + return None + if not env_vars: + return None + rebuilt = [dict(v) for v in env_vars] + rotated = False + for entry in rebuilt: + if not _is_global_env_var_scope(entry.get("scope")): + continue + value = entry.get("value") + if not value: + continue + decrypted = decrypt_value_helper( + value=value, + key="mcp_global_env_var", + exception_type="debug", + return_original_value=False, + ) + if decrypted is None: + verbose_proxy_logger.warning( + "rotate_mcp_server_credentials_master_key: could not decrypt " + "global env var %s, skipping", + entry.get("name"), + ) + continue + entry["value"] = encrypt_value_helper( + decrypted, new_encryption_key=new_encryption_key + ) + rotated = True + return rebuilt if rotated else None + + def _prepare_mcp_server_data( data: Union[NewMCPServerRequest, UpdateMCPServerRequest], exclude_unset: bool = False, @@ -98,6 +238,16 @@ def _prepare_mcp_server_data( if data_dict.get("static_headers") is not None: data_dict["static_headers"] = safe_dumps(data_dict["static_headers"]) + # env_vars is read from ``data_dict`` (not ``data``) like every other JSON + # column so the exclude_unset filter is respected: a partial update that + # omits env_vars never overwrites the stored value. Global values are + # encrypted at rest before serialization. + env_vars = data_dict.get("env_vars") + if env_vars is not None: + serialized_env_vars = [dict(v) for v in env_vars] + _encrypt_global_env_var_values(serialized_env_vars) + data_dict["env_vars"] = safe_dumps(serialized_env_vars) + if data_dict.get("mcp_info") is not None: data_dict["mcp_info"] = safe_dumps(data_dict["mcp_info"]) @@ -207,10 +357,13 @@ async def get_all_mcp_servers( where=where if where else {} ) - return [ + tables = [ LiteLLM_MCPServerTable(**mcp_server.model_dump()) for mcp_server in mcp_servers ] + for table in tables: + decrypt_global_env_var_values(table.env_vars) + return tables except Exception as e: verbose_proxy_logger.debug( "litellm.proxy._experimental.mcp_server.db.py::get_all_mcp_servers - {}".format( @@ -226,14 +379,16 @@ async def get_mcp_server( """ Returns the matching mcp server from the db iff exists """ - mcp_server: Optional[LiteLLM_MCPServerTable] = ( - await prisma_client.db.litellm_mcpservertable.find_unique( - where={ - "server_id": server_id, - } - ) + mcp_server = await prisma_client.db.litellm_mcpservertable.find_unique( + where={ + "server_id": server_id, + } ) - return mcp_server + if mcp_server is None: + return None + table = LiteLLM_MCPServerTable(**mcp_server.model_dump()) + decrypt_global_env_var_values(table.env_vars) + return table async def get_mcp_servers( @@ -251,7 +406,9 @@ async def get_mcp_servers( ) final_mcp_servers: List[LiteLLM_MCPServerTable] = [] for _mcp_server in _mcp_servers: - final_mcp_servers.append(LiteLLM_MCPServerTable(**_mcp_server.model_dump())) + table = LiteLLM_MCPServerTable(**_mcp_server.model_dump()) + decrypt_global_env_var_values(table.env_vars) + final_mcp_servers.append(table) return final_mcp_servers @@ -399,6 +556,11 @@ async def delete_mcp_server( """ Delete the mcp server from the db by server_id + The server-row delete is the commit point. Per-user env var rows have no FK + cascade, so they are cleaned up afterwards on a best-effort basis: a transient + failure there leaves only orphaned rows pointing at a now-missing server and + must not turn a successful delete into a caller-visible error. + Returns the deleted mcp server record if it exists, otherwise None """ deleted_server = await prisma_client.db.litellm_mcpservertable.delete( @@ -406,6 +568,18 @@ async def delete_mcp_server( "server_id": server_id, }, ) + if deleted_server is not None: + try: + await prisma_client.db.litellm_mcpuserenvvars.delete_many( + where={"server_id": server_id} + ) + except Exception as e: + verbose_proxy_logger.warning( + "MCP server %s deleted but per-user env var cleanup failed; " + "orphaned rows can be removed on a later delete: %s", + server_id, + e, + ) return deleted_server @@ -429,6 +603,7 @@ async def create_mcp_server( data=data_dict # type: ignore ) + _decrypt_env_vars_on_returned_row(new_mcp_server) return new_mcp_server @@ -506,40 +681,52 @@ async def update_mcp_server( where={"server_id": data.server_id}, data=data_dict # type: ignore ) + _decrypt_env_vars_on_returned_row(updated_mcp_server) return updated_mcp_server async def rotate_mcp_server_credentials_master_key( prisma_client: PrismaClient, touched_by: str, new_master_key: str ): + from litellm.litellm_core_utils.safe_json_dumps import safe_dumps + mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many() + updated = 0 for mcp_server in mcp_servers: + update_data: Dict[str, Any] = {} + credentials = mcp_server.credentials - if not credentials: + if credentials: + # Decrypt with current key first, then re-encrypt with new key + decrypted_credentials = decrypt_credentials( + credentials=cast(MCPCredentials, dict(credentials)), + ) + encrypted_credentials = encrypt_credentials( + credentials=decrypted_credentials, + encryption_key=new_master_key, + ) + update_data["credentials"] = safe_dumps(encrypted_credentials) + + rotated_env_vars = _reencrypt_global_env_var_values( + mcp_server.env_vars, new_master_key + ) + if rotated_env_vars is not None: + update_data["env_vars"] = safe_dumps(rotated_env_vars) + + if not update_data: continue - credentials_copy = dict(credentials) - # Decrypt with current key first, then re-encrypt with new key - decrypted_credentials = decrypt_credentials( - credentials=cast(MCPCredentials, credentials_copy), - ) - encrypted_credentials = encrypt_credentials( - credentials=decrypted_credentials, - encryption_key=new_master_key, - ) - - from litellm.litellm_core_utils.safe_json_dumps import safe_dumps - - serialized_credentials = safe_dumps(encrypted_credentials) - + update_data["updated_by"] = touched_by await prisma_client.db.litellm_mcpservertable.update( where={"server_id": mcp_server.server_id}, - data={ - "credentials": serialized_credentials, - "updated_by": touched_by, - }, + data=update_data, ) + updated += 1 + verbose_proxy_logger.info( + "rotate_mcp_server_credentials_master_key: rotated %d MCP server row(s)", + updated, + ) def _decode_user_credential(stored: str) -> Optional[str]: @@ -594,6 +781,8 @@ async def rotate_mcp_user_credentials_master_key( are logged and skipped so one corrupt row does not abort the rotation. """ rows = await prisma_client.db.litellm_mcpusercredentials.find_many() + rotated = 0 + skipped = 0 for row in rows: plaintext = _decode_user_credential(row.credential_b64) if plaintext is None: @@ -603,6 +792,7 @@ async def rotate_mcp_user_credentials_master_key( row.user_id, row.server_id, ) + skipped += 1 continue re_encrypted = encrypt_value_helper( plaintext, new_encryption_key=new_master_key @@ -616,6 +806,61 @@ async def rotate_mcp_user_credentials_master_key( }, data={"credential_b64": re_encrypted}, ) + rotated += 1 + verbose_proxy_logger.info( + "rotate_mcp_user_credentials_master_key: rotated %d row(s), skipped %d", + rotated, + skipped, + ) + + +async def rotate_mcp_user_env_vars_master_key( + prisma_client: PrismaClient, new_master_key: str +): + """Re-encrypt every ``LiteLLM_MCPUserEnvVars`` row with ``new_master_key``. + + Reads each ``values_b64`` blob with the current salt key and writes it back + encrypted under the new master key. Rows that fail to decrypt are logged and + skipped so one corrupt row does not abort the rotation nor overwrite values + that may still be recoverable. + """ + rows = await prisma_client.db.litellm_mcpuserenvvars.find_many() + rotated = 0 + skipped = 0 + for row in rows: + plaintext = decrypt_value_helper( + value=row.values_b64, + key="mcp_user_env_vars", + exception_type="debug", + return_original_value=False, + ) + if plaintext is None: + verbose_proxy_logger.warning( + "rotate_mcp_user_env_vars_master_key: could not decrypt env vars " + "for user_id=%s server_id=%s, skipping", + row.user_id, + row.server_id, + ) + skipped += 1 + continue + re_encrypted = encrypt_value_helper( + plaintext, new_encryption_key=new_master_key + ) + await prisma_client.db.litellm_mcpuserenvvars.update( + where={ + "user_id_server_id": { + "user_id": row.user_id, + "server_id": row.server_id, + } + }, + data={"values_b64": re_encrypted}, + ) + rotated += 1 + verbose_proxy_logger.info( + "rotate_mcp_user_env_vars_master_key: rotated %d row(s), skipped %d", + rotated, + skipped, + ) async def store_user_credential( @@ -927,7 +1172,9 @@ async def approve_mcp_server( "updated_by": touched_by, }, ) - return LiteLLM_MCPServerTable(**updated.model_dump()) + table = LiteLLM_MCPServerTable(**updated.model_dump()) + decrypt_global_env_var_values(table.env_vars) + return table async def reject_mcp_server( @@ -949,7 +1196,9 @@ async def reject_mcp_server( where={"server_id": server_id}, data=data, ) - return LiteLLM_MCPServerTable(**updated.model_dump()) + table = LiteLLM_MCPServerTable(**updated.model_dump()) + decrypt_global_env_var_values(table.env_vars) + return table async def get_mcp_submissions( @@ -966,6 +1215,8 @@ async def get_mcp_submissions( take=500, # safety cap; paginate if needed in a future iteration ) items = [LiteLLM_MCPServerTable(**r.model_dump()) for r in rows] + for item in items: + decrypt_global_env_var_values(item.env_vars) pending = sum( 1 for i in items if i.approval_status == MCPApprovalStatus.pending_review @@ -980,3 +1231,121 @@ async def get_mcp_submissions( rejected=rejected, items=items, ) + + +# ── Per-user MCP environment variables ──────────────────────────────────── + + +def _decode_user_env_vars(stored: str) -> Dict[str, str]: + """Decrypt a ``values_b64`` blob and parse it as a flat ``{name: value}`` dict.""" + decrypted = decrypt_value_helper( + value=stored, + key="mcp_user_env_vars", + exception_type="debug", + return_original_value=False, + ) + if decrypted is None: + if stored: + verbose_proxy_logger.warning( + "MCP per-user env vars failed to decrypt (LITELLM_SALT_KEY " + "changed?); treating as unset so the user is prompted to " + "re-enter them rather than silently forwarding ciphertext" + ) + return {} + try: + parsed = json.loads(decrypted) + except (ValueError, TypeError): + return {} + if not isinstance(parsed, dict): + return {} + return {str(k): str(v) for k, v in parsed.items()} + + +async def get_user_env_vars( + prisma_client: PrismaClient, + user_id: str, + server_id: str, +) -> Dict[str, str]: + """Return the calling user's env var dict for ``server_id`` (empty if none).""" + row = await prisma_client.db.litellm_mcpuserenvvars.find_unique( + where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}} + ) + if row is None: + return {} + return _decode_user_env_vars(row.values_b64) + + +async def get_user_env_vars_bulk( + prisma_client: PrismaClient, + user_id: str, + server_ids: Iterable[str], +) -> Dict[str, Dict[str, str]]: + """Return ``{server_id: {var_name: value}}`` for one user across many servers. + + Servers with no stored row are simply absent from the result. + """ + ids = list(server_ids) + if not ids: + return {} + rows = await prisma_client.db.litellm_mcpuserenvvars.find_many( + where={"user_id": user_id, "server_id": {"in": ids}} + ) + return {row.server_id: _decode_user_env_vars(row.values_b64) for row in rows} + + +async def merge_user_env_vars( + prisma_client: PrismaClient, + user_id: str, + server_id: str, + updates: Dict[str, str], + allowed_names: Iterable[str], +) -> Dict[str, str]: + """Merge ``updates`` into the user's stored env vars for ``server_id`` and + return the resulting set. + + The read-modify-write runs inside a transaction guarded by a + ``(user_id, server_id)`` advisory lock so two concurrent writes from the + same user can't drop one update. Names outside ``allowed_names`` are pruned, + so an admin retiring a user-scoped variable also clears its stored value. + """ + allowed = set(allowed_names) + lock_key = int.from_bytes( + hashlib.blake2b(f"{user_id}:{server_id}".encode(), digest_size=8).digest(), + "big", + signed=True, + ) + async with prisma_client.db.tx() as tx: + await tx.execute_raw("SELECT pg_advisory_xact_lock($1::bigint)", lock_key) + row = await tx.litellm_mcpuserenvvars.find_unique( + where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}} + ) + existing = _decode_user_env_vars(row.values_b64) if row is not None else {} + merged = {k: v for k, v in {**existing, **updates}.items() if k in allowed} + encoded = encrypt_value_helper(json.dumps(merged)) + await tx.litellm_mcpuserenvvars.upsert( + where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}, + data={ + "create": { + "user_id": user_id, + "server_id": server_id, + "values_b64": encoded, + }, + "update": {"values_b64": encoded}, + }, + ) + return merged + + +async def delete_user_env_vars( + prisma_client: PrismaClient, + user_id: str, + server_id: str, +) -> None: + """Remove the calling user's env var values for ``server_id``. + + Uses ``delete_many`` so a missing row is a no-op; real DB errors still + propagate to the caller instead of being silently swallowed. + """ + await prisma_client.db.litellm_mcpuserenvvars.delete_many( + where={"user_id": user_id, "server_id": server_id} + ) diff --git a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py index 5fd028b343..7048f5bf7c 100644 --- a/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py +++ b/litellm/proxy/_experimental/mcp_server/mcp_server_manager.py @@ -58,20 +58,26 @@ from litellm.proxy._experimental.mcp_server.sampling_handler import ( from litellm.proxy._experimental.mcp_server.oauth2_token_cache import resolve_mcp_auth from litellm.proxy._experimental.mcp_server.utils import ( MCP_TOOL_PREFIX_SEPARATOR, + MCPMissingUserEnvVarsError, add_server_prefix_to_name, + build_env_var_setup_url, + collect_env_var_references, compute_short_server_prefix, get_server_prefix, + interpolate_headers, is_short_mcp_tool_prefix_enabled, is_tool_name_prefixed, iter_known_server_prefixes, merge_mcp_headers, normalize_server_name, + parse_admin_env_vars, split_server_prefix_from_name, validate_mcp_server_name, ) from litellm.proxy._types import ( LiteLLM_MCPServerTable, MCPAuthType, + MCPEnvVar, MCPTransport, MCPTransportType, UserAPIKeyAuth, @@ -124,6 +130,33 @@ _AZURE_ENTRA_HOSTS = { "login.chinacloudapi.cn", # China } +# Short-lived in-memory cache for per-user MCP env var values, mirroring the +# BYOK credential cache. Keyed by (user_id, server_id); value is +# (values_dict, monotonic_timestamp). Keeps the tool-call and tool-listing +# paths off the DB on every request within the TTL window. +_user_env_vars_cache: Dict[Tuple[str, str], Tuple[Dict[str, str], float]] = {} +_USER_ENV_VARS_CACHE_TTL = 60 # seconds +_USER_ENV_VARS_CACHE_MAX_SIZE = 4096 # cap to prevent unbounded growth + + +def invalidate_user_env_vars_cache(user_id: str, server_id: str) -> None: + """Drop a cached entry after the user stores or clears their env var values + so the next request reads the fresh value instead of a stale one.""" + _user_env_vars_cache.pop((user_id, server_id), None) + + +def _write_user_env_vars_cache( + user_id: str, server_id: str, values: Dict[str, str] +) -> None: + cache_key = (user_id, server_id) + # Re-insert at the tail so eviction drops the oldest-written entry, not a + # freshly refreshed one, and only sheds a single entry instead of wiping the + # whole cache (which would stampede the DB). + _user_env_vars_cache.pop(cache_key, None) + if len(_user_env_vars_cache) >= _USER_ENV_VARS_CACHE_MAX_SIZE: + _user_env_vars_cache.pop(next(iter(_user_env_vars_cache)), None) + _user_env_vars_cache[cache_key] = (values, time.monotonic()) + def _should_strip_caller_authorization( mcp_server: MCPServer, @@ -295,6 +328,31 @@ def _deserialize_json_dict(data: Any) -> Optional[Dict[str, str]]: return data +def _deserialize_json_list(data: Any) -> Optional[List[Dict[str, Any]]]: + """Deserialize a JSON array stored in the DB (``env_vars`` and friends). + + Returns ``None`` for empty / null / unparseable input. Accepts strings + (raw JSON), already-materialized lists of dicts, and lists of Pydantic + models (Prisma may hydrate a JSON column such as ``env_vars`` into + ``MCPEnvVar`` objects); model entries are normalized to plain dicts so + downstream consumers expecting ``List[Dict[str, Any]]`` validate. + """ + if data is None or data == "" or data == []: + return None + if isinstance(data, str): + try: + parsed = json.loads(data) + except (json.JSONDecodeError, TypeError): + return None + data = parsed + if not isinstance(data, list): + return None + return [ + item.model_dump(mode="json") if hasattr(item, "model_dump") else item + for item in data + ] + + def _create_sampling_callback(user_api_key_auth: Optional[Any] = None): """ Create a sampling callback for MCP ClientSession. @@ -456,7 +514,8 @@ class MCPServerManager: - server is OpenAPI (spec_path), - non-empty upstream instructions are already cached, - auth preconditions match health_check_server's skip rules - (per-user auth / missing static auth token), + (per-user auth / missing static auth token / static headers that + reference a per-user env var), - a prior probe attempt for this server is within MCP_HEALTH_CHECK_TIMEOUT seconds (the probe is a health-check-shaped op and already uses this knob for its inner call timeout; reusing it @@ -471,6 +530,8 @@ class MCPServerManager: return if server.requires_per_user_auth: return + if self._references_per_user_env_var(server): + return if ( server.auth_type and server.auth_type != MCPAuth.none @@ -495,8 +556,13 @@ class MCPServerManager: ) try: + resolved_static_headers = await self._resolve_static_headers_with_env_vars( + server=server, + user_api_key_auth=None, + raise_on_missing=False, + ) extra_headers: Optional[Dict[str, str]] = ( - dict(server.static_headers) if server.static_headers else None + dict(resolved_static_headers) if resolved_static_headers else None ) client = await self._create_mcp_client( server=server, @@ -656,6 +722,7 @@ class MCPServerManager: allowed_params=server_config.get("allowed_params", None), access_groups=server_config.get("access_groups", None), static_headers=server_config.get("static_headers", None), + env_vars=server_config.get("env_vars", None), allow_all_keys=bool(server_config.get("allow_all_keys", False)), available_on_public_internet=bool( server_config.get("available_on_public_internet", True) @@ -920,17 +987,41 @@ class MCPServerManager: f"Server ID {mcp_server.server_id} not found in registry" ) + def _resolve_env_vars_list( + self, + mcp_server: LiteLLM_MCPServerTable, + *, + env_vars_are_encrypted: bool, + ) -> Optional[List[Dict[str, Any]]]: + env_vars_list = _deserialize_json_list(getattr(mcp_server, "env_vars", None)) + if env_vars_are_encrypted: + from litellm.proxy._experimental.mcp_server.db import ( # noqa: PLC0415 + decrypt_global_env_var_values, + ) + + decrypt_global_env_var_values(env_vars_list) + return env_vars_list + async def build_mcp_server_from_table( self, mcp_server: LiteLLM_MCPServerTable, *, credentials_are_encrypted: bool = True, + env_vars_are_encrypted: Optional[bool] = None, ) -> MCPServer: _mcp_info: MCPInfo = mcp_server.mcp_info or {} env_dict = _deserialize_json_dict(getattr(mcp_server, "env", None)) static_headers_dict = _deserialize_json_dict( getattr(mcp_server, "static_headers", None) ) + env_vars_list = self._resolve_env_vars_list( + mcp_server, + env_vars_are_encrypted=( + credentials_are_encrypted + if env_vars_are_encrypted is None + else env_vars_are_encrypted + ), + ) credentials_dict = _deserialize_json_dict( getattr(mcp_server, "credentials", None) ) @@ -1030,6 +1121,7 @@ class MCPServerManager: mcp_info=mcp_info, extra_headers=getattr(mcp_server, "extra_headers", None), static_headers=static_headers_dict, + env_vars=env_vars_list, client_id=client_id_value or getattr(mcp_server, "client_id", None), client_secret=client_secret_value or getattr(mcp_server, "client_secret", None), @@ -1128,7 +1220,14 @@ class MCPServerManager: return try: if mcp_server.server_id not in self.registry: - new_server = await self.build_mcp_server_from_table(mcp_server) + # Callers hand us a record returned by the db.py read/write + # helpers, which already decrypt global env var values (the + # `credentials` field is the only one still encrypted here). + # Re-decrypting plaintext would zero the values, so build with + # env_vars_are_encrypted=False. + new_server = await self.build_mcp_server_from_table( + mcp_server, env_vars_are_encrypted=False + ) self._assign_unique_short_prefix(new_server) self.registry[mcp_server.server_id] = new_server await self._maybe_register_openapi_tools(new_server) @@ -1151,7 +1250,11 @@ class MCPServerManager: return try: if mcp_server.server_id in self.registry: - new_server = await self.build_mcp_server_from_table(mcp_server) + # See add_server: db.py helpers already decrypted env var + # values, so don't decrypt them a second time here. + new_server = await self.build_mcp_server_from_table( + mcp_server, env_vars_are_encrypted=False + ) # Carry the previously-resolved short prefix across so the # tool names stay stable for clients holding cached lists. existing_prefix = self.registry[mcp_server.server_id].short_prefix @@ -1572,6 +1675,180 @@ class MCPServerManager: return resolved_env + def _references_per_user_env_var(self, server: MCPServer) -> bool: + """True when ``server.static_headers`` reference a per-user ``${NAME}`` env var. + + Such placeholders can only be filled from a calling user's stored values, + so a userless probe (health check / instructions prefetch) would forward + the literal ``${NAME}`` upstream and get rejected. Callers skip the probe + and report ``unknown`` instead of a misleading ``unhealthy``. + """ + static_headers = server.static_headers + env_vars = getattr(server, "env_vars", None) + if not static_headers or not env_vars: + return False + _global_values, user_specs = parse_admin_env_vars(env_vars) + user_var_names = {spec["name"] for spec in user_specs} + if not user_var_names: + return False + referenced = collect_env_var_references(strings=static_headers.values()) + return bool(referenced & user_var_names) + + async def _resolve_static_headers_with_env_vars( + self, + server: MCPServer, + user_api_key_auth: Optional[UserAPIKeyAuth], + *, + raise_on_missing: bool = True, + ) -> Optional[Dict[str, str]]: + """Return server.static_headers with ``${NAME}`` interpolated. + + Globals come from ``server.env_vars`` entries with ``scope=="global"``. + Per-user values come from the ``LiteLLM_MCPUserEnvVars`` row for the + calling user. + + When ``raise_on_missing`` is ``True`` (the tool-*call* path), raises + ``MCPMissingUserEnvVarsError`` if ``static_headers`` reference a per-user + variable the calling user has not yet supplied — converted into a + user-facing 412 by the REST layer. + + When ``raise_on_missing`` is ``False`` (the tool-*list* path), missing + per-user vars are non-blocking: we interpolate whatever is available and + leave unfilled ``${NAME}`` references untouched, so the server's tools + still appear in the listing. The user only hits the friendly error when + they actually invoke a tool that needs the missing value. + """ + static_headers = server.static_headers + env_vars = getattr(server, "env_vars", None) + if not static_headers and not env_vars: + return static_headers + + global_values, user_specs = parse_admin_env_vars(env_vars) + # An empty-valued global is treated as unset: it must not mask a per-user + # var the user still has to supply, nor override a value the user did + # supply. The unresolved ${NAME} is then left untouched, like any other + # undefined reference. + global_values = {name: value for name, value in global_values.items() if value} + user_var_names = {spec["name"] for spec in user_specs} + + # If no env vars are configured, return static_headers as-is. + if not global_values and not user_specs: + return static_headers + + # Figure out which user-scoped vars are actually referenced. A var that + # also carries a global value is always covered by that global (globals + # win in the merge below), so it can never be genuinely "missing" even if + # the user hasn't filled it in -- only vars without a global fallback do. + referenced = collect_env_var_references(strings=(static_headers or {}).values()) + referenced_user_vars = referenced & user_var_names + required_user_vars = { + name for name in referenced_user_vars if name not in global_values + } + + user_values: Dict[str, str] = {} + if required_user_vars: + try: + user_values = await self._load_user_env_vars(server, user_api_key_auth) + except Exception as exc: + # On the tool-call path a DB failure must surface as a real + # server error, not a misleading "set up your credentials" 412. + # On the listing path we stay best-effort and leave the + # unfilled ${NAME} references untouched so tools still appear. + if raise_on_missing: + raise + verbose_logger.warning( + "MCPServerManager: best-effort user env var load failed for " + "server=%s: %s", + server.server_id, + exc, + ) + + if raise_on_missing: + missing = sorted( + name for name in required_user_vars if not user_values.get(name) + ) + if missing: + # A cached negative must never produce a 412: cache + # invalidation is process-local, so a user who just stored + # values on another worker would otherwise be told their + # credentials are missing until the entry expires. Confirm + # against the DB before raising. + user_values = await self._load_user_env_vars( + server, user_api_key_auth, force_refresh=True + ) + missing = sorted( + name for name in required_user_vars if not user_values.get(name) + ) + if missing: + raise MCPMissingUserEnvVarsError( + server_id=server.server_id, + server_name=server.server_name or server.name, + missing=missing, + setup_url=build_env_var_setup_url(server.server_id), + ) + + # Only honor stored user values for currently user-scoped vars, and let + # admin globals win, so a stale row from when a var was user-scoped can + # never override the global value the admin set after switching it. + scoped_user_values = { + name: value for name, value in user_values.items() if name in user_var_names + } + merged_vars: Dict[str, str] = {**scoped_user_values, **global_values} + if not static_headers: + return static_headers + return interpolate_headers(static_headers, merged_vars) + + async def _load_user_env_vars( + self, + server: MCPServer, + user_api_key_auth: Optional[UserAPIKeyAuth], + *, + force_refresh: bool = False, + ) -> Dict[str, str]: + """Look up the calling user's env var values for ``server``. + + Returns an empty dict when no user is available. Results are cached in a + short-lived in-memory map keyed by (user_id, server_id) so the tool-call + and tool-listing paths avoid a DB round-trip per request within the TTL + window; the cache is invalidated when the user stores or clears values. + Pass ``force_refresh`` to bypass the cache read and re-fetch from the DB + (used before raising a "missing credentials" error so a process-local + stale entry cannot mask values stored on another worker). A missing DB + connection and any other DB error propagate so the caller can decide + between failing the request (tool-call path) and staying best-effort + (listing path); they must never be mistaken for "user has no values", + which would send the user a misleading "set up your credentials" 412. + """ + if user_api_key_auth is None: + return {} + user_id = getattr(user_api_key_auth, "user_id", None) + if not user_id: + return {} + + cache_key = (user_id, server.server_id) + if not force_refresh: + cached = _user_env_vars_cache.get(cache_key) + if cached is not None: + values, ts = cached + if time.monotonic() - ts < _USER_ENV_VARS_CACHE_TTL: + return values + + from litellm.proxy.proxy_server import prisma_client # noqa: PLC0415 + + if prisma_client is None: + raise RuntimeError( + "MCP per-user env vars require a database connection, but none " + "is configured. Connect a database to your proxy to use per-user " + "MCP env vars." + ) + from litellm.proxy._experimental.mcp_server.db import ( # noqa: PLC0415 + get_user_env_vars, + ) + + values = await get_user_env_vars(prisma_client, user_id, server.server_id) + _write_user_env_vars_cache(user_id, server.server_id, values) + return values + async def _create_mcp_client( self, server: MCPServer, @@ -1732,10 +2009,17 @@ class MCPServerManager: client = None try: - if server.static_headers: + # Tool *listing* must not be blocked by missing per-user env vars — + # the server's tools should still appear so the client connects. The + # friendly "missing vars" error is raised only on the tool-*call* + # path (see _call_regular_mcp_tool). + resolved_static_headers = await self._resolve_static_headers_with_env_vars( + server, user_api_key_auth, raise_on_missing=False + ) + if resolved_static_headers: if extra_headers is None: extra_headers = {} - extra_headers.update(server.static_headers) + extra_headers.update(resolved_static_headers) # MCPJWTSigner: inject signed JWT for tools/list (list path skips pre_call_hook). # Skip entirely when the signer is not configured (avoid an unnecessary @@ -3105,10 +3389,17 @@ class MCPServerManager: continue extra_headers[header] = header_value - if mcp_server.static_headers: + # Interpolate env vars into static_headers. Raises + # MCPMissingUserEnvVarsError when the calling user has not filled in + # a required per-user variable — the REST layer converts that into + # a friendly 412 with a setup URL. + resolved_static_headers = await self._resolve_static_headers_with_env_vars( + mcp_server, user_api_key_auth + ) + if resolved_static_headers: if extra_headers is None: extra_headers = {} - extra_headers.update(mcp_server.static_headers) + extra_headers.update(resolved_static_headers) if hook_extra_headers: if extra_headers is None: @@ -3566,7 +3857,13 @@ class MCPServerManager: verbose_logger.debug( f"Building server from DB: {server.server_id} ({server.server_name})" ) - new_server = await self.build_mcp_server_from_table(server) + # raw_rows come straight from the DB, so their global env var + # values (like credentials) are still encrypted here, unlike the + # already-decrypted records add_server/update_server are handed. + # Decrypt them while building the registry entry. + new_server = await self.build_mcp_server_from_table( + server, env_vars_are_encrypted=True + ) # Carry the cached short_prefix from the previous registry entry # (if any) so the prefix is stable across reloads. if existing_server is not None and existing_server.short_prefix: @@ -3906,11 +4203,21 @@ class MCPServerManager: and not server.authentication_token ): should_skip_health_check = True + # Skip if static_headers reference a per-user env var: a userless probe + # can't fill ${NAME} and would forward the literal placeholder upstream, + # flipping the server to unhealthy even though real user calls succeed. + elif self._references_per_user_env_var(server): + should_skip_health_check = True if not should_skip_health_check: - extra_headers = {} - if server.static_headers: - extra_headers.update(server.static_headers) + resolved_static_headers = await self._resolve_static_headers_with_env_vars( + server=server, + user_api_key_auth=None, + raise_on_missing=False, + ) + extra_headers = ( + dict(resolved_static_headers) if resolved_static_headers else {} + ) client = await self._create_mcp_client( server=server, @@ -3960,6 +4267,7 @@ class MCPServerManager: extra_headers=server.extra_headers or [], mcp_info=server.mcp_info, static_headers=server.static_headers, + env_vars=self._env_vars_to_models(server.env_vars), status=status, last_health_check=datetime.now(), health_check_error=health_check_error, @@ -4033,6 +4341,14 @@ class MCPServerManager: return list_mcp_servers + @staticmethod + def _env_vars_to_models( + env_vars: Optional[List[Dict[str, Any]]], + ) -> Optional[List[MCPEnvVar]]: + if env_vars is None: + return None + return [MCPEnvVar.model_validate(env_var) for env_var in env_vars] + def _build_mcp_server_table(self, server: MCPServer) -> LiteLLM_MCPServerTable: return LiteLLM_MCPServerTable( server_id=server.server_id, @@ -4053,6 +4369,7 @@ class MCPServerManager: extra_headers=server.extra_headers or [], mcp_info=server.mcp_info, static_headers=server.static_headers, + env_vars=self._env_vars_to_models(server.env_vars), status=None, # No health check performed last_health_check=None, # No health check performed health_check_error=None, diff --git a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py index e20c9f3a08..78cecfed0b 100644 --- a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py @@ -1,3 +1,4 @@ +import asyncio import importlib from datetime import datetime from typing import ( @@ -13,6 +14,7 @@ from typing import ( Union, ) +import httpx from fastapi import APIRouter, Depends, HTTPException, Query, Request, status from litellm._logging import verbose_logger @@ -20,7 +22,10 @@ from litellm.proxy._experimental.mcp_server.exceptions import MCPUpstreamAuthErr from litellm.proxy._experimental.mcp_server.ui_session_utils import ( build_effective_auth_contexts, ) -from litellm.proxy._experimental.mcp_server.utils import merge_mcp_headers +from litellm.proxy._experimental.mcp_server.utils import ( + MCPMissingUserEnvVarsError, + merge_mcp_headers, +) from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth from litellm.proxy.auth.ip_address_utils import IPAddressUtils from litellm.proxy.auth.user_api_key_auth import user_api_key_auth @@ -41,6 +46,28 @@ router = APIRouter( tags=["mcp"], ) + +def _connection_error_message(exc: BaseException) -> str: + if isinstance(exc, httpx.LocalProtocolError): + return ( + "Failed to connect to MCP server: a request header is malformed. " + "Check static headers for leading/trailing spaces or illegal characters." + ) + if isinstance(exc, (httpx.ConnectError, httpx.ConnectTimeout)): + return ( + "Failed to connect to MCP server: the server is unreachable. " + "Check the URL and that the server is running." + ) + if isinstance(exc, httpx.TimeoutException): + return "Failed to connect to MCP server: the connection timed out." + if isinstance(exc, httpx.HTTPStatusError): + return ( + f"Failed to connect to MCP server: it returned HTTP " + f"{exc.response.status_code}." + ) + return "Failed to connect to MCP server. Check proxy logs for details." + + if MCP_AVAILABLE: from mcp.types import Tool as MCPTool @@ -812,6 +839,23 @@ if MCP_AVAILABLE: requested_server_id=canonical_server_id, ) return result + except MCPMissingUserEnvVarsError as e: + verbose_logger.info( + "MCP tool call missing per-user env vars: server_id=%s missing=%s", + e.server_id, + e.missing, + ) + raise HTTPException( + status_code=412, + detail={ + "error": "missing_user_env_vars", + "message": str(e), + "server_id": e.server_id, + "server_name": e.server_name, + "missing": e.missing, + "setup_url": e.setup_url, + }, + ) except BlockedPiiEntityError as e: verbose_logger.error(f"BlockedPiiEntityError in MCP tool call: {str(e)}") raise HTTPException( @@ -961,14 +1005,14 @@ if MCP_AVAILABLE: return await operation(client) - except (KeyboardInterrupt, SystemExit): + except (KeyboardInterrupt, SystemExit, asyncio.CancelledError): raise except BaseException as e: verbose_logger.error("Error in MCP operation: %s", e, exc_info=True) return { "status": "error", "error": True, - "message": "Failed to connect to MCP server. Check proxy logs for details.", + "message": _connection_error_message(e), } async def _preview_openapi_tools(spec_path: str) -> dict: diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index df6cb22fda..51e77de43c 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -53,6 +53,7 @@ from litellm.proxy._experimental.mcp_server.utils import ( LITELLM_MCP_SERVER_DESCRIPTION, LITELLM_MCP_SERVER_NAME, LITELLM_MCP_SERVER_VERSION, + MCPMissingUserEnvVarsError, add_server_prefix_to_name, get_server_prefix, iter_known_server_prefixes, @@ -720,6 +721,16 @@ if MCP_AVAILABLE: host_progress_callback=host_progress_callback, **data, # for logging ) + except MCPMissingUserEnvVarsError as e: + verbose_logger.info( + "MCP mcp_server_tool_call missing per-user env vars: server_id=%s missing=%s", + e.server_id, + e.missing, + ) + return CallToolResult( + content=[TextContent(text=str(e), type="text")], + isError=True, + ) except BlockedPiiEntityError as e: verbose_logger.error( f"BlockedPiiEntityError in MCP tool call: {str(e)}" diff --git a/litellm/proxy/_experimental/mcp_server/utils.py b/litellm/proxy/_experimental/mcp_server/utils.py index b66dfa85b9..97cfa74ea4 100644 --- a/litellm/proxy/_experimental/mcp_server/utils.py +++ b/litellm/proxy/_experimental/mcp_server/utils.py @@ -4,11 +4,23 @@ MCP Server Utilities import json import re -from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Union +from typing import ( + Any, + Dict, + Iterable, + Iterator, + List, + Mapping, + Optional, + Set, + Tuple, + Union, +) import hashlib import importlib import os +from urllib.parse import quote # Constants LITELLM_MCP_SERVER_NAME = "litellm-mcp-server" @@ -370,6 +382,130 @@ def validate_mcp_server_name( raise Exception(error_message) +class MCPMissingUserEnvVarsError(Exception): + """Raised when an MCP request can't be built because the calling user has + not supplied one or more required per-user environment variables. + + The error message is user-facing and includes a URL the user can visit + to fill them in. + """ + + def __init__( + self, + *, + server_id: str, + server_name: Optional[str], + missing: List[str], + setup_url: str, + ) -> None: + self.server_id = server_id + self.server_name = server_name + self.missing = missing + self.setup_url = setup_url + label = server_name or server_id + bullet_list = "\n".join(f"- {name}" for name in missing) + message = ( + f'Cannot connect to MCP server "{label}".\n\n' + f"Your administrator configured this server to require per-user " + f"variables, but you haven't set the following yet:\n" + f"{bullet_list}\n\n" + f"Set your credentials here:\n" + f"{setup_url}" + ) + super().__init__(message) + + +# Pattern for ``${NAME}`` substitution. Matches the standard env-var +# identifier rules — letters, digits, underscores, can't start with a digit. +_ENV_VAR_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}") + + +def parse_admin_env_vars( + env_vars: Optional[Iterable[Any]], +) -> Tuple[Dict[str, str], List[Dict[str, Any]]]: + """Split admin-configured env var entries into globals and per-user specs. + + Accepts the raw value of ``MCPServer.env_vars`` (list of dicts or Pydantic + models). Returns: + + - ``global_values``: ``{name: value}`` for entries with ``scope=="global"``. + - ``user_specs``: list of ``{name, description}`` for entries with + ``scope=="user"`` — these are the names the user must fill in. + + Unknown / malformed entries are skipped silently. + """ + global_values: Dict[str, str] = {} + user_specs: List[Dict[str, Any]] = [] + if not env_vars: + return global_values, user_specs + for raw in env_vars: + if raw is None: + continue + if hasattr(raw, "model_dump"): + entry = raw.model_dump() + elif isinstance(raw, dict): + entry = raw + else: + continue + name = entry.get("name") + if not isinstance(name, str) or not name: + continue + scope = entry.get("scope") or "global" + if scope == "user": + user_specs.append({"name": name, "description": entry.get("description")}) + else: + value = entry.get("value") + global_values[name] = "" if value is None else str(value) + return global_values, user_specs + + +def find_env_var_references(value: str) -> Set[str]: + """Return the set of ``${NAME}`` identifiers referenced inside ``value``.""" + if not value: + return set() + return set(_ENV_VAR_PATTERN.findall(value)) + + +def collect_env_var_references(*, strings: Iterable[str]) -> Set[str]: + """Union of every ``${NAME}`` reference across a collection of strings.""" + refs: Set[str] = set() + for s in strings: + if isinstance(s, str): + refs |= find_env_var_references(s) + return refs + + +def interpolate_env_vars(value: str, variables: Mapping[str, str]) -> str: + """Replace ``${NAME}`` references in ``value`` with the matching mapping + entry. Unknown names are left untouched so callers can detect them via + ``find_env_var_references`` on the result if needed. + """ + if not value: + return value + + def _sub(match: "re.Match[str]") -> str: + name = match.group(1) + if name in variables: + return variables[name] + return match.group(0) + + return _ENV_VAR_PATTERN.sub(_sub, value) + + +def interpolate_headers( + headers: Mapping[str, str], variables: Mapping[str, str] +) -> Dict[str, str]: + """Return a copy of ``headers`` with every value passed through ``interpolate_env_vars``.""" + return {k: interpolate_env_vars(v, variables) for k, v in headers.items()} + + +def build_env_var_setup_url(server_id: str) -> str: + """The frontend URL where a user can fill in their per-user env vars.""" + base = os.environ.get("PROXY_BASE_URL", "").rstrip("/") + path = f"/ui/?page=mcp-servers&fill_env_vars={quote(server_id, safe='')}" + return f"{base}{path}" if base else path + + def merge_mcp_headers( *, extra_headers: Optional[Mapping[str, str]] = None, diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 3415eb435e..e5d7093306 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1265,6 +1265,34 @@ class MCPApprovalStatus(str, enum.Enum): rejected = "rejected" +class MCPEnvVarScope(str, enum.Enum): + """Scope for an MCP server environment variable. + + - ``global``: value is provided by the admin and used for all users. + - ``user``: each user must provide their own value via the per-user + env-var endpoint. The admin-supplied ``value`` is treated as a + placeholder/hint and is not used at request time. + """ + + global_ = "global" + user = "user" + + +class MCPEnvVar(LiteLLMPydanticObjectBase): + """One environment variable for an MCP server. + + Variables can be interpolated into ``static_headers`` using ``${NAME}`` + syntax. ``scope=global`` values are stored on the server. ``scope=user`` + values are stored per-user in ``LiteLLM_MCPUserEnvVars`` and supplied by + each user. + """ + + name: str + value: str = "" + scope: MCPEnvVarScope = MCPEnvVarScope.global_ + description: Optional[str] = None + + # MCP Proxy Request Types class NewMCPServerRequest(LiteLLMPydanticObjectBase): server_id: Optional[str] = None @@ -1283,6 +1311,7 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase): tool_name_to_description: Optional[Dict[str, str]] = None extra_headers: Optional[List[str]] = None static_headers: Optional[Dict[str, str]] = None + env_vars: Optional[List[MCPEnvVar]] = None instructions: Optional[str] = None # Stdio-specific fields command: Optional[str] = None @@ -1369,6 +1398,7 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase): tool_name_to_description: Optional[Dict[str, str]] = None extra_headers: Optional[List[str]] = None static_headers: Optional[Dict[str, str]] = None + env_vars: Optional[List[MCPEnvVar]] = None instructions: Optional[str] = None # Stdio-specific fields command: Optional[str] = None @@ -1438,6 +1468,7 @@ class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase): extra_headers: List[str] = Field(default_factory=list) mcp_info: Optional[MCPInfo] = None static_headers: Optional[Dict[str, str]] = None + env_vars: Optional[List[MCPEnvVar]] = None # Health check status status: Optional[Literal["healthy", "unhealthy", "unknown"]] = Field( default="unknown", @@ -1519,6 +1550,35 @@ class MCPUserCredentialListItem(LiteLLMPydanticObjectBase): connected_at: Optional[str] = None # ISO-8601 +class MCPUserEnvVarsRequest(LiteLLMPydanticObjectBase): + """Payload for storing the calling user's per-user env var values.""" + + values: Dict[str, str] + + +class MCPUserEnvVarSpec(LiteLLMPydanticObjectBase): + """Describes one per-user env var slot for the calling user. + + Stored values are write-only: the status only reports whether a value + ``is_set`` and never echoes the decrypted secret back to the client. + """ + + name: str + description: Optional[str] = None + is_set: bool = False + + +class MCPUserEnvVarsStatus(LiteLLMPydanticObjectBase): + """Per-user env var status for a single MCP server.""" + + server_id: str + server_name: Optional[str] = None + alias: Optional[str] = None + required: List[MCPUserEnvVarSpec] = Field(default_factory=list) + missing_count: int = 0 + setup_url: Optional[str] = None # frontend URL where the user can fill these in + + class RejectMCPServerRequest(LiteLLMPydanticObjectBase): review_notes: Optional[str] = None diff --git a/litellm/proxy/management_endpoints/key_management_endpoints.py b/litellm/proxy/management_endpoints/key_management_endpoints.py index c8c590af97..99d0bac88a 100644 --- a/litellm/proxy/management_endpoints/key_management_endpoints.py +++ b/litellm/proxy/management_endpoints/key_management_endpoints.py @@ -39,6 +39,7 @@ from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.proxy._experimental.mcp_server.db import ( rotate_mcp_server_credentials_master_key, rotate_mcp_user_credentials_master_key, + rotate_mcp_user_env_vars_master_key, ) from litellm.proxy._types import * from litellm.proxy._types import LiteLLM_VerificationToken @@ -4136,6 +4137,15 @@ async def _rotate_master_key( # noqa: PLR0915 "Failed to rotate MCP user credentials: %s", str(e) ) + # 4c. process MCP per-user environment variables table + try: + await rotate_mcp_user_env_vars_master_key( + prisma_client=prisma_client, + new_master_key=new_master_key, + ) + except Exception as e: + verbose_proxy_logger.warning("Failed to rotate MCP user env vars: %s", str(e)) + # 5. process credentials table try: credentials = await prisma_client.db.litellm_credentialstable.find_many() diff --git a/litellm/proxy/management_endpoints/mcp_management_endpoints.py b/litellm/proxy/management_endpoints/mcp_management_endpoints.py index 05cfc67449..f1edcc9c7b 100644 --- a/litellm/proxy/management_endpoints/mcp_management_endpoints.py +++ b/litellm/proxy/management_endpoints/mcp_management_endpoints.py @@ -47,7 +47,10 @@ from litellm._logging import verbose_logger, verbose_proxy_logger from litellm._uuid import uuid from litellm.constants import LITELLM_PROXY_ADMIN_NAME from litellm.proxy._experimental.mcp_server.utils import ( + build_env_var_setup_url, + collect_env_var_references, get_server_prefix, + parse_admin_env_vars, ) from litellm.proxy._experimental.mcp_server.utils import ( validate_and_normalize_mcp_server_payload as _base_validate_and_normalize_mcp_server_payload, @@ -111,12 +114,16 @@ if MCP_AVAILABLE: create_mcp_server, delete_mcp_server, delete_user_credential, + delete_user_env_vars, get_all_mcp_servers_for_user, get_mcp_server, get_mcp_servers, get_mcp_submissions, + get_user_env_vars, + get_user_env_vars_bulk, get_user_oauth_credential, list_user_oauth_credentials, + merge_user_env_vars, reject_mcp_server, store_user_credential, store_user_oauth_credential, @@ -139,6 +146,7 @@ if MCP_AVAILABLE: LitellmUserRoles, MakeMCPServersPublicRequest, MCPApprovalStatus, + MCPEnvVarScope, MCPOAuthUserCredentialRequest, MCPOAuthUserCredentialStatus, MCPSubmissionsSummary, @@ -146,6 +154,9 @@ if MCP_AVAILABLE: MCPUserCredentialListItem, MCPUserCredentialRequest, MCPUserCredentialResponse, + MCPUserEnvVarSpec, + MCPUserEnvVarsRequest, + MCPUserEnvVarsStatus, NewMCPServerRequest, RejectMCPServerRequest, SpecialMCPServerName, @@ -473,6 +484,27 @@ if MCP_AVAILABLE: ) -> List[LiteLLM_MCPServerTable]: return [_redact_mcp_credentials(server) for server in mcp_servers] + def _redact_global_env_var_values(mcp_server: LiteLLM_MCPServerTable) -> None: + """Blank admin-supplied ``scope="global"`` env var secrets in place. + + Global entries hold the admin's plaintext credential (API key, + password, ...) and must never reach non-admin callers. Per-user + entries only carry a placeholder the user fills in themselves, so + their value is left intact. + """ + for env_var in mcp_server.env_vars or []: + if env_var.scope == MCPEnvVarScope.global_: + env_var.value = "" + + def _user_is_full_admin(user_api_key_dict: UserAPIKeyAuth) -> bool: + """True only for ``PROXY_ADMIN``; ``PROXY_ADMIN_VIEW_ONLY`` returns False. + + Global env var secrets pre-fill the admin edit form, so a full admin + must see them, but a read-only admin gets the same redacted view as + any other non-managing caller. + """ + return user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN + def _is_restricted_virtual_key_request(user_api_key_dict: UserAPIKeyAuth) -> bool: """Best-effort detection for route-restricted virtual keys. @@ -513,6 +545,11 @@ if MCP_AVAILABLE: sanitized.authorization_url = None sanitized.token_url = None sanitized.registration_url = None + # Drop env vars entirely rather than only blanking global values: the + # names alone (DB_PASSWORD, GITHUB_API_KEY, ...) leak what secrets the + # admin configured. Non-admins get the per-user vars they must fill in + # from the dedicated /user-env-vars/status endpoint instead. + sanitized.env_vars = None return sanitized def _sanitize_mcp_server_list_for_non_admin( @@ -544,6 +581,7 @@ if MCP_AVAILABLE: sanitized.allowed_tools = [] sanitized.mcp_access_groups = [] sanitized.teams = [] + sanitized.env_vars = None sanitized.authorization_url = None sanitized.token_url = None @@ -976,6 +1014,10 @@ if MCP_AVAILABLE: if not _user_has_admin_view(user_api_key_dict): return _sanitize_mcp_server_list_for_non_admin(redacted_mcp_servers) + if not _user_is_full_admin(user_api_key_dict): + for server in redacted_mcp_servers: + _redact_global_env_var_values(server) + return redacted_mcp_servers @router.get( @@ -1144,7 +1186,11 @@ if MCP_AVAILABLE: "Database not connected. Connect a database to your proxy" ) - return await get_mcp_submissions(prisma_client) + submissions = await get_mcp_submissions(prisma_client) + if not _user_is_full_admin(user_api_key_dict): + for item in submissions.items: + _redact_global_env_var_values(item) + return submissions @router.put( "/server/{server_id}/approve", @@ -1363,6 +1409,8 @@ if MCP_AVAILABLE: return _sanitize_mcp_server_for_virtual_key(redacted) if not _user_has_admin_view(user_api_key_dict): return _sanitize_mcp_server_for_non_admin(redacted) + if not _user_is_full_admin(user_api_key_dict): + _redact_global_env_var_values(redacted) return redacted @router.post( @@ -1432,23 +1480,34 @@ if MCP_AVAILABLE: payload.submitted_by = None payload.submitted_at = None - # Attempt to create the mcp server + # The database write is the commit point: if it fails nothing was + # persisted and the request is a genuine failure. try: new_mcp_server = await create_mcp_server( prisma_client, payload, touched_by=user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME, ) - await global_mcp_server_manager.add_server(new_mcp_server) - - # Ensure registry is up to date by reloading from database - await global_mcp_server_manager.reload_servers_from_database() except Exception as e: verbose_proxy_logger.exception(f"Error creating mcp server: {str(e)}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail={"error": f"Error creating mcp server: {str(e)}"}, ) + + # Registry refresh is best-effort: the row is already committed, so a + # failure here (e.g. an unrelated malformed row in the table) must not + # surface as a 500 and orphan the created server, which would push the + # caller to retry and create duplicates. + try: + await global_mcp_server_manager.add_server(new_mcp_server) + await global_mcp_server_manager.reload_servers_from_database() + except Exception as e: + verbose_proxy_logger.exception( + f"MCP server {new_mcp_server.server_id} created but in-memory " + f"registry refresh failed: {str(e)}" + ) + return _redact_mcp_credentials(new_mcp_server) @router.post( @@ -2106,6 +2165,249 @@ if MCP_AVAILABLE: ) return items + # ── Per-user MCP env var endpoints ──────────────────────────────────────── + + async def _authorize_and_fetch_mcp_server( + prisma_client, + user_api_key_dict: UserAPIKeyAuth, + server_id: str, + ) -> LiteLLM_MCPServerTable: + """Return the MCP server the caller may manage env vars for. + + Admins look the server up directly. Non-admins reuse the access-scoped + listing that already loads every server they can see, so we don't issue + a second per-server query just to re-fetch a record the authorization + check produced. A non-admin who can't see the server gets 403 (never + 404) so server ids can't be enumerated. + """ + if _user_has_admin_view(user_api_key_dict): + server = await get_mcp_server(prisma_client, server_id) + if server is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"error": f"MCP Server {server_id} not found"}, + ) + return server + accessible = await get_all_mcp_servers_for_user( + prisma_client, user_api_key_dict + ) + for server in accessible: + if server.server_id == server_id: + return server + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={ + "error": ( + f"User does not have permission to access mcp server with id {server_id}. " + "You can only manage env vars for mcp servers that you have access to." + ) + }, + ) + + def _compute_user_env_var_status( + *, + server: LiteLLM_MCPServerTable, + stored_values: Dict[str, str], + ) -> MCPUserEnvVarsStatus: + """Build a status object for one server given the user's stored values. + + Stored credentials are write-only: the response reports only whether + each value ``is_set`` and never echoes the decrypted secret back, so a + leaked token can't be used to exfiltrate the raw upstream credential. + """ + global_values, user_specs = parse_admin_env_vars( + getattr(server, "env_vars", None) + ) + # An empty-valued global is not a usable fallback, so it must not mark a + # referenced per-user var as covered, matching the empty-global filter in + # _resolve_static_headers_with_env_vars. Otherwise this endpoint reports no + # credential needed for a var every tool call still 412s on. + global_values = {name: value for name, value in global_values.items() if value} + + # A var only blocks when it's referenced by static_headers and has no + # admin global fallback, mirroring _resolve_static_headers_with_env_vars + # (globals win the merge) so the status endpoint never asks the user for + # credentials a tool call wouldn't actually require. + static_headers = getattr(server, "static_headers", None) or {} + if isinstance(static_headers, str): + try: + static_headers = json.loads(static_headers) or {} + except (ValueError, TypeError): + static_headers = {} + referenced = collect_env_var_references(strings=static_headers.values()) + user_var_names = {spec["name"] for spec in user_specs} + blocking = { + name for name in (referenced & user_var_names) if name not in global_values + } + + required: List[MCPUserEnvVarSpec] = [] + missing_count = 0 + for spec in user_specs: + name = spec["name"] + if name not in blocking: + continue + value = stored_values.get(name) + is_set = bool(value) + if not is_set: + missing_count += 1 + required.append( + MCPUserEnvVarSpec( + name=name, + description=spec.get("description"), + is_set=is_set, + ) + ) + + return MCPUserEnvVarsStatus( + server_id=server.server_id, + server_name=getattr(server, "server_name", None), + alias=getattr(server, "alias", None), + required=required, + missing_count=missing_count, + setup_url=build_env_var_setup_url(server.server_id) if required else None, + ) + + @router.get( + "/server/{server_id}/user-env-vars", + description="Return the calling user's per-user MCP env var status for this server.", + dependencies=[Depends(user_api_key_auth)], + response_model=MCPUserEnvVarsStatus, + ) + @management_endpoint_wrapper + async def get_mcp_user_env_vars( + server_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + ) -> MCPUserEnvVarsStatus: + prisma_client = get_prisma_client_or_throw( + "Database not connected. Connect a database to your proxy" + ) + user_id = user_api_key_dict.user_id or "" + if not user_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "User ID not found in token"}, + ) + server = await _authorize_and_fetch_mcp_server( + prisma_client, user_api_key_dict, server_id + ) + stored = await get_user_env_vars(prisma_client, user_id, server_id) + return _compute_user_env_var_status(server=server, stored_values=stored) + + @router.post( + "/server/{server_id}/user-env-vars", + description=( + "Store the calling user's per-user MCP env var values for this " + "server. Submitted values are merged over any previously stored " + "values, so you only send the fields you want to set or change; a " + "variable omitted (or sent empty) keeps its stored value. Use " + "DELETE to clear all stored values." + ), + dependencies=[Depends(user_api_key_auth)], + response_model=MCPUserEnvVarsStatus, + ) + @management_endpoint_wrapper + async def store_mcp_user_env_vars( + server_id: str, + payload: MCPUserEnvVarsRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + ) -> MCPUserEnvVarsStatus: + prisma_client = get_prisma_client_or_throw( + "Database not connected. Connect a database to your proxy" + ) + user_id = user_api_key_dict.user_id or "" + if not user_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "User ID not found in token"}, + ) + server = await _authorize_and_fetch_mcp_server( + prisma_client, user_api_key_dict, server_id + ) + # Only known per-user var names declared by the admin are accepted — + # never persist arbitrary keys the user invents. Submitted values are + # merged over the existing set so a user updating one credential does + # not have to re-enter the others (which are write-only and never shown + # back); an omitted/empty field keeps its stored value. + _, user_specs = parse_admin_env_vars(getattr(server, "env_vars", None)) + allowed_names = {spec["name"] for spec in user_specs} + updates = { + k: v for k, v in payload.values.items() if k in allowed_names and v != "" + } + merged = await merge_user_env_vars( + prisma_client, user_id, server_id, updates, allowed_names + ) + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + invalidate_user_env_vars_cache, + ) + + invalidate_user_env_vars_cache(user_id, server_id) + return _compute_user_env_var_status(server=server, stored_values=merged) + + @router.delete( + "/server/{server_id}/user-env-vars", + description="Clear the calling user's per-user MCP env var values for this server.", + dependencies=[Depends(user_api_key_auth)], + response_model=MCPUserEnvVarsStatus, + ) + @management_endpoint_wrapper + async def clear_mcp_user_env_vars( + server_id: str, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + ) -> MCPUserEnvVarsStatus: + prisma_client = get_prisma_client_or_throw( + "Database not connected. Connect a database to your proxy" + ) + user_id = user_api_key_dict.user_id or "" + if not user_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "User ID not found in token"}, + ) + server = await _authorize_and_fetch_mcp_server( + prisma_client, user_api_key_dict, server_id + ) + await delete_user_env_vars(prisma_client, user_id, server_id) + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + invalidate_user_env_vars_cache, + ) + + invalidate_user_env_vars_cache(user_id, server_id) + return _compute_user_env_var_status(server=server, stored_values={}) + + @router.get( + "/user-env-vars/status", + description="Per-user MCP env var status across every server the user can access. " + "Used by the dashboard to highlight servers with missing per-user vars.", + dependencies=[Depends(user_api_key_auth)], + response_model=List[MCPUserEnvVarsStatus], + ) + @management_endpoint_wrapper + async def list_mcp_user_env_var_status( + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + ) -> List[MCPUserEnvVarsStatus]: + prisma_client = get_prisma_client_or_throw( + "Database not connected. Connect a database to your proxy" + ) + user_id = user_api_key_dict.user_id or "" + if not user_id: + return [] + accessible = await get_all_mcp_servers_for_user( + prisma_client, user_api_key_dict + ) + if not accessible: + return [] + server_ids = [s.server_id for s in accessible] + stored_bulk = await get_user_env_vars_bulk(prisma_client, user_id, server_ids) + statuses: List[MCPUserEnvVarsStatus] = [] + for server in accessible: + stored = stored_bulk.get(server.server_id, {}) + status_obj = _compute_user_env_var_status( + server=server, stored_values=stored + ) + if status_obj.required: + statuses.append(status_obj) + return statuses + @router.put( "/server", description="Allows deleting mcp serves in the db", diff --git a/litellm/proxy/management_helpers/utils.py b/litellm/proxy/management_helpers/utils.py index 287dc38a1f..0b175db3c8 100644 --- a/litellm/proxy/management_helpers/utils.py +++ b/litellm/proxy/management_helpers/utils.py @@ -5,6 +5,7 @@ from functools import wraps from typing import Any, Callable, List, Optional, Tuple from fastapi import HTTPException, Request +from pydantic import BaseModel import litellm from litellm._logging import verbose_logger @@ -436,6 +437,58 @@ async def send_management_endpoint_alert( ) +def _redacted_env_var(entry: Any) -> dict: + get = entry.get if isinstance(entry, dict) else lambda k: getattr(entry, k, None) + return { + "name": get("name"), + "scope": get("scope"), + "description": get("description"), + "value": "", + } + + +def _redact_record_env_vars(record: Any) -> Any: + """Return ``record`` with its ``env_vars[].value`` blanked. + + Copies rather than mutating, because the record aliases the live response + object that is also returned to the caller. Records without an ``env_vars`` + list are returned unchanged. + """ + env_vars = ( + record.get("env_vars") + if isinstance(record, dict) + else getattr(record, "env_vars", None) + ) + if not isinstance(env_vars, list): + return record + redacted = [_redacted_env_var(entry) for entry in env_vars] + if isinstance(record, dict): + return {**record, "env_vars": redacted} + if isinstance(record, BaseModel): + return record.model_copy(update={"env_vars": redacted}) + return record + + +def _redact_env_var_values(response: dict) -> None: + """Blank ``env_vars[].value`` in a management response before telemetry. + + MCP endpoints return decrypted ``scope="global"`` env var values so the admin + UI can pre-fill the edit form; those values are upstream credentials and must + not be serialized verbatim into OTEL spans, where an observability user could + read them. The values surface both at the top level (single-server + create/update) and nested under ``items`` (the submissions queue), so both are + scrubbed. Names, scopes, and descriptions are kept so traces stay useful. + """ + if isinstance(response.get("env_vars"), list): + response["env_vars"] = [ + _redacted_env_var(entry) for entry in response["env_vars"] + ] + + items = response.get("items") + if isinstance(items, list): + response["items"] = [_redact_record_env_vars(item) for item in items] + + async def _emit_management_endpoint_otel_span( func: Callable, kwargs: dict, @@ -497,6 +550,7 @@ async def _emit_management_endpoint_otel_span( try: raw = dict(result) _response = {k: v for k, v in raw.items() if k not in _CREDENTIAL_FIELDS} + _redact_env_var_values(_response) except Exception: _response = None diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index b51da83a63..e21c001649 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -311,6 +311,11 @@ model LiteLLM_MCPServerTable { tool_name_to_description Json? @default("{}") extra_headers String[] @default([]) static_headers Json? @default("{}") + // Admin-configured environment variables interpolated into static_headers + // via ${NAME} syntax. Stored as an array of + // {name, value, scope, description}. scope is "global" (value used as-is) + // or "user" (value supplied per-user via LiteLLM_MCPUserEnvVars). + env_vars Json? @default("[]") // Health check status status String? @default("unknown") last_health_check DateTime? @@ -366,6 +371,21 @@ model LiteLLM_MCPUserCredentials { @@unique([user_id, server_id]) } +// Per-user environment variable values for MCP servers. +// values_b64 is an encrypted JSON object: {VAR_NAME: "value", ...}. +model LiteLLM_MCPUserEnvVars { + id String @id @default(uuid()) + user_id String + server_id String + values_b64 String + created_at DateTime @default(now()) + updated_at DateTime @default(now()) @updatedAt + + @@unique([user_id, server_id]) + @@index([user_id]) + @@index([server_id]) +} + // Generate Tokens for Proxy model LiteLLM_VerificationToken { token String @id diff --git a/litellm/types/mcp_server/mcp_server_manager.py b/litellm/types/mcp_server/mcp_server_manager.py index 4f8c9a0aa4..92ca027c5b 100644 --- a/litellm/types/mcp_server/mcp_server_manager.py +++ b/litellm/types/mcp_server/mcp_server_manager.py @@ -42,6 +42,10 @@ class MCPServer(BaseModel): static_headers: Optional[Dict[str, str]] = ( None # static headers to forward to the MCP server ) + # Admin-configured env vars. Each entry is {name, value, scope, description}. + # scope=="global" values are interpolated into static_headers using ${NAME}. + # scope=="user" values must be supplied per-user. + env_vars: Optional[List[Dict[str, Any]]] = None # OAuth-specific fields client_id: Optional[str] = None client_secret: Optional[str] = None diff --git a/schema.prisma b/schema.prisma index b51da83a63..e21c001649 100644 --- a/schema.prisma +++ b/schema.prisma @@ -311,6 +311,11 @@ model LiteLLM_MCPServerTable { tool_name_to_description Json? @default("{}") extra_headers String[] @default([]) static_headers Json? @default("{}") + // Admin-configured environment variables interpolated into static_headers + // via ${NAME} syntax. Stored as an array of + // {name, value, scope, description}. scope is "global" (value used as-is) + // or "user" (value supplied per-user via LiteLLM_MCPUserEnvVars). + env_vars Json? @default("[]") // Health check status status String? @default("unknown") last_health_check DateTime? @@ -366,6 +371,21 @@ model LiteLLM_MCPUserCredentials { @@unique([user_id, server_id]) } +// Per-user environment variable values for MCP servers. +// values_b64 is an encrypted JSON object: {VAR_NAME: "value", ...}. +model LiteLLM_MCPUserEnvVars { + id String @id @default(uuid()) + user_id String + server_id String + values_b64 String + created_at DateTime @default(now()) + updated_at DateTime @default(now()) @updatedAt + + @@unique([user_id, server_id]) + @@index([user_id]) + @@index([server_id]) +} + // Generate Tokens for Proxy model LiteLLM_VerificationToken { token String @id diff --git a/tests/test_litellm/experimental_mcp_client/test_mcp_client.py b/tests/test_litellm/experimental_mcp_client/test_mcp_client.py index dee689708c..c9e500b4a5 100644 --- a/tests/test_litellm/experimental_mcp_client/test_mcp_client.py +++ b/tests/test_litellm/experimental_mcp_client/test_mcp_client.py @@ -1,3 +1,4 @@ +import asyncio import os import ssl import sys @@ -10,10 +11,26 @@ import pytest sys.path.insert(0, "../../../") import litellm.experimental_mcp_client.client as mcp_client_module -from litellm.experimental_mcp_client.client import MCPClient +from litellm.experimental_mcp_client.client import ( + MCPClient, + _first_non_cancelled_cause, +) from litellm.types.mcp import MCPAuth, MCPStdioConfig, MCPTransport +class _FakeExceptionGroup(Exception): + """Duck-typed stand-in for an anyio/builtin ExceptionGroup. + + The production unwrapper reads ``.exceptions`` rather than depending on the + builtin ``ExceptionGroup`` type, so this exercises the same code path on + every Python version. + """ + + def __init__(self, message, exceptions): + super().__init__(message) + self.exceptions = tuple(exceptions) + + class TestMCPClient: """Test MCP Client stdio functionality""" @@ -307,6 +324,26 @@ class TestMCPClient: assert headers["Authorization"] == "token my-token" assert headers["X-Custom-Header"] == "custom-value" + def test_get_auth_headers_strips_static_header_whitespace(self): + """ + Static header names/values must be stripped of surrounding whitespace. + + h11 rejects header values with leading/trailing whitespace as an + "Illegal header value", which silently aborts the MCP connection. A + stray space in a configured static header value would otherwise make + every request to that server fail with an opaque error. + """ + client = MCPClient( + server_url="http://example.com/mcp", + transport_type="http", + extra_headers={"X-Db-Url": " mew://host ", " X-Pad ": "v"}, + ) + + headers = client._get_auth_headers() + + assert headers["X-Db-Url"] == "mew://host" + assert headers["X-Pad"] == "v" + def test_token_auth_enum_value(self): """Test that MCPAuth.token enum exists and has correct value""" assert hasattr(MCPAuth, "token") @@ -388,5 +425,123 @@ class TestMCPClientInstructionsCapture: assert client._last_initialize_instructions is None +# --------------------------------------------------------------------------- +# Transport error surfacing +# --------------------------------------------------------------------------- + + +class TestFirstNonCancelledCause: + """Unwrapping the real cause out of a (possibly nested) exception group.""" + + def test_returns_plain_non_cancelled(self): + err = ValueError("boom") + assert _first_non_cancelled_cause(err) is err + + def test_returns_none_for_plain_cancelled(self): + assert _first_non_cancelled_cause(asyncio.CancelledError()) is None + + def test_unwraps_group_to_non_cancelled_leaf(self): + target = httpx.ConnectError("refused") + group = _FakeExceptionGroup("g", [asyncio.CancelledError(), target]) + assert _first_non_cancelled_cause(group) is target + + def test_unwraps_nested_group(self): + target = httpx.LocalProtocolError("Illegal header value") + inner = _FakeExceptionGroup("inner", [asyncio.CancelledError(), target]) + outer = _FakeExceptionGroup("outer", [asyncio.CancelledError(), inner]) + assert _first_non_cancelled_cause(outer) is target + + def test_all_cancelled_returns_none(self): + group = _FakeExceptionGroup( + "g", [asyncio.CancelledError(), asyncio.CancelledError()] + ) + assert _first_non_cancelled_cause(group) is None + + @pytest.mark.skipif( + sys.version_info < (3, 11), reason="builtin ExceptionGroup requires 3.11+" + ) + def test_unwraps_builtin_exception_group(self): + target = httpx.ConnectError("refused") + group = ExceptionGroup("transport failed", [target]) # noqa: F821 + assert _first_non_cancelled_cause(group) is target + + +class TestExecuteSessionOperationSurfacesTransportError: + """_execute_session_operation should surface the real transport failure. + + When the upstream transport's task group fails (illegal header, connection + refused, ...), the in-flight ``session.initialize()`` is cancelled and the + real error only appears when the transport context exits. The opaque + ``CancelledError`` must be replaced with that real cause. + """ + + def _make_session(self, mock_session_cls, initialize): + mock_session = AsyncMock() + mock_session.initialize = initialize + session_ctx = MagicMock() + session_ctx.__aenter__ = AsyncMock(return_value=mock_session) + session_ctx.__aexit__ = AsyncMock(return_value=False) + mock_session_cls.return_value = session_ctx + + def _make_transport(self, aexit_side_effect): + transport_ctx = MagicMock() + transport_ctx.__aenter__ = AsyncMock(return_value=(MagicMock(), MagicMock())) + transport_ctx.__aexit__ = AsyncMock(side_effect=aexit_side_effect) + return transport_ctx + + @pytest.mark.asyncio + @patch("litellm.experimental_mcp_client.client.ClientSession") + async def test_surfaces_connect_error_over_cancelled(self, mock_session_cls): + client = MCPClient(server_url="http://example.com/mcp", transport_type="http") + self._make_session( + mock_session_cls, + AsyncMock(side_effect=asyncio.CancelledError("cancelled by group")), + ) + connect_error = httpx.ConnectError("All connection attempts failed") + transport_ctx = self._make_transport( + _FakeExceptionGroup("transport", [connect_error]) + ) + + async def _op(session): + return "done" + + with pytest.raises(httpx.ConnectError): + await client._execute_session_operation(transport_ctx, _op) + + @pytest.mark.asyncio + @patch("litellm.experimental_mcp_client.client.ClientSession") + async def test_genuine_cancellation_is_not_replaced(self, mock_session_cls): + client = MCPClient(server_url="http://example.com/mcp", transport_type="http") + self._make_session( + mock_session_cls, AsyncMock(side_effect=asyncio.CancelledError()) + ) + transport_ctx = self._make_transport( + _FakeExceptionGroup("teardown", [asyncio.CancelledError()]) + ) + + async def _op(session): + return "done" + + with pytest.raises(asyncio.CancelledError): + await client._execute_session_operation(transport_ctx, _op) + + @pytest.mark.asyncio + @patch("litellm.experimental_mcp_client.client.ClientSession") + async def test_cleanup_error_after_success_is_swallowed(self, mock_session_cls): + client = MCPClient(server_url="http://example.com/mcp", transport_type="http") + init_result = MagicMock() + init_result.instructions = None + self._make_session(mock_session_cls, AsyncMock(return_value=init_result)) + transport_ctx = self._make_transport( + _FakeExceptionGroup("late", [httpx.ConnectError("late cleanup error")]) + ) + + async def _op(session): + return "done" + + result = await client._execute_session_operation(transport_ctx, _op) + assert result == "done" + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_litellm/litellm_core_utils/test_token_counter.py b/tests/test_litellm/litellm_core_utils/test_token_counter.py index 324bace0e9..92c070501b 100644 --- a/tests/test_litellm/litellm_core_utils/test_token_counter.py +++ b/tests/test_litellm/litellm_core_utils/test_token_counter.py @@ -200,7 +200,12 @@ def test_tokenizers(): model="meta-llama/llama-3-70b-instruct", text=sample_text ) - llama3_tokenizer = create_pretrained_tokenizer("Xenova/llama-3-tokenizer") + try: + llama3_tokenizer = create_pretrained_tokenizer("Xenova/llama-3-tokenizer") + except Exception as e: + pytest.skip( + f"custom tokenizer download failed (HF hub unreachable): {e}" + ) llama3_tokens_2 = token_counter( custom_tokenizer=llama3_tokenizer, text=sample_text ) diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_db_credentials.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_db_credentials.py index 078adf72d4..8aa9f107c8 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_db_credentials.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_db_credentials.py @@ -20,11 +20,14 @@ from litellm.proxy._experimental.mcp_server.db import ( get_user_oauth_credential, list_user_oauth_credentials, rotate_mcp_user_credentials_master_key, + rotate_mcp_user_env_vars_master_key, store_user_credential, store_user_oauth_credential, ) -from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper - +from litellm.proxy.common_utils.encrypt_decrypt_utils import ( + decrypt_value_helper, + encrypt_value_helper, +) SALT_KEY = "test-salt-key-for-byok-credential-tests-1234" @@ -400,3 +403,69 @@ async def test_rotate_skips_undecodable_rows(): assert prisma.db.litellm_mcpusercredentials.update.call_count == 1 where = prisma.db.litellm_mcpusercredentials.update.call_args.kwargs["where"] assert where["user_id_server_id"]["server_id"] == "srv-ok" + + +# ── per-user env-var rotation ───────────────────────────────────────────────── + + +def _env_var_row(values_b64: str, user_id="alice", server_id="srv-1"): + row = MagicMock() + row.values_b64 = values_b64 + row.user_id = user_id + row.server_id = server_id + return row + + +@pytest.mark.asyncio +async def test_rotate_user_env_vars_re_encrypts_with_new_key(monkeypatch): + # Encrypt env vars under the current salt, rotate to a new key, then confirm + # the stored ciphertext round-trips under the NEW key. + values = {"API_KEY": "sk-secret", "REGION": "us-east-1"} + encrypted_old = encrypt_value_helper(json.dumps(values)) + + prisma = MagicMock() + prisma.db.litellm_mcpuserenvvars.find_many = AsyncMock( + return_value=[_env_var_row(encrypted_old)] + ) + prisma.db.litellm_mcpuserenvvars.update = AsyncMock() + + new_master_key = "rotated-env-key-1111-2222-3333-4444" + await rotate_mcp_user_env_vars_master_key( + prisma_client=prisma, new_master_key=new_master_key + ) + + new_stored = prisma.db.litellm_mcpuserenvvars.update.call_args.kwargs["data"][ + "values_b64" + ] + assert new_stored != encrypted_old, "rotation must produce different ciphertext" + + monkeypatch.setenv("LITELLM_SALT_KEY", new_master_key) + decrypted = decrypt_value_helper( + value=new_stored, + key="mcp_user_env_vars", + exception_type="debug", + return_original_value=False, + ) + assert json.loads(decrypted) == values + + +@pytest.mark.asyncio +async def test_rotate_user_env_vars_skips_undecryptable_rows(): + # A corrupt row must be skipped (not overwritten) so recoverable data is + # preserved and one bad row does not abort the rest of the rotation. + good = _env_var_row( + encrypt_value_helper(json.dumps({"A": "1"})), server_id="srv-ok" + ) + bad = _env_var_row("!!! not encrypted !!!", server_id="srv-corrupt") + + prisma = MagicMock() + prisma.db.litellm_mcpuserenvvars.find_many = AsyncMock(return_value=[bad, good]) + prisma.db.litellm_mcpuserenvvars.update = AsyncMock() + + await rotate_mcp_user_env_vars_master_key( + prisma_client=prisma, new_master_key="new-key-xxxx" + ) + + assert prisma.db.litellm_mcpuserenvvars.update.call_count == 1 + where = prisma.db.litellm_mcpuserenvvars.update.call_args.kwargs["where"] + assert where["user_id_server_id"]["server_id"] == "srv-ok" diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_env_vars.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_env_vars.py new file mode 100644 index 0000000000..19065ff816 --- /dev/null +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_env_vars.py @@ -0,0 +1,1652 @@ +"""Tests for MCP env-var interpolation utilities. + +These cover the pure helpers in +``litellm.proxy._experimental.mcp_server.utils`` and do not require a DB +connection. The DB-backed per-user flow is exercised in higher-level +tests in tests/mcp_tests. +""" + +import pytest + +# Look up these names lazily on every access. Tests in this directory call +# ``importlib.reload`` on the utils module to exercise registration logic, +# which replaces ``MCPMissingUserEnvVarsError`` with a freshly-constructed +# class. A direct ``from ... import`` at module load time would freeze the +# old class object and ``pytest.raises(_u("MCPMissingUserEnvVarsError"))`` would +# stop matching the new class. Accessing the attribute through the module +# always picks up the current version. +import litellm.proxy._experimental.mcp_server.utils as _mcp_utils + + +def _u(name: str): + return getattr(_mcp_utils, name) + + +def test_parse_admin_env_vars_splits_global_and_user(): + g, u = _u("parse_admin_env_vars")( + [ + {"name": "DB_PROTOCOL", "value": "postgres", "scope": "global"}, + {"name": "DB_HOST", "value": "localhost", "scope": "global"}, + { + "name": "CORP_USERNAME", + "value": "", + "scope": "user", + "description": "Your DB username", + }, + {"name": "CORP_PASSWORD", "value": "", "scope": "user"}, + ] + ) + assert g == {"DB_PROTOCOL": "postgres", "DB_HOST": "localhost"} + assert u == [ + {"name": "CORP_USERNAME", "description": "Your DB username"}, + {"name": "CORP_PASSWORD", "description": None}, + ] + + +def test_parse_admin_env_vars_handles_none_and_empty(): + assert _u("parse_admin_env_vars")(None) == ({}, []) + assert _u("parse_admin_env_vars")([]) == ({}, []) + + +def test_parse_admin_env_vars_skips_malformed_entries(): + g, u = _u("parse_admin_env_vars")( + [ + None, + {"name": "", "value": "x"}, + {"value": "no_name"}, + {"name": "OK", "value": "v"}, + ] + ) + assert g == {"OK": "v"} + assert u == [] + + +def test_find_env_var_references(): + assert _u("find_env_var_references")("") == set() + assert _u("find_env_var_references")("plain") == set() + assert _u("find_env_var_references")("${A}") == {"A"} + assert _u("find_env_var_references")("${A}/${B}/${A}") == {"A", "B"} + # Invalid identifier patterns should not match + assert _u("find_env_var_references")("${1abc}") == set() + assert _u("find_env_var_references")("${a-b}") == set() + + +def test_collect_env_var_references(): + refs = _u("collect_env_var_references")( + strings=["${A}", "static", "${B}-${C}", None] + ) + assert refs == {"A", "B", "C"} + + +def test_interpolate_env_vars_replaces_known_and_leaves_unknown(): + assert _u("interpolate_env_vars")( + "${A}://${B}/${C}", {"A": "https", "B": "host"} + ) == ("https://host/${C}") + + +def test_interpolate_headers_returns_independent_copy(): + headers = {"X-Url": "${A}://x"} + out = _u("interpolate_headers")(headers, {"A": "https"}) + assert out == {"X-Url": "https://x"} + # original untouched + assert headers == {"X-Url": "${A}://x"} + + +def test_build_env_var_setup_url_includes_server_id(monkeypatch): + monkeypatch.delenv("PROXY_BASE_URL", raising=False) + url = _u("build_env_var_setup_url")("abc-123") + assert url.startswith("/ui/?page=mcp-servers") + assert "fill_env_vars=abc-123" in url + + +def test_build_env_var_setup_url_prepends_proxy_base_url(monkeypatch): + monkeypatch.setenv("PROXY_BASE_URL", "https://proxy.example.com/") + url = _u("build_env_var_setup_url")("abc-123") + assert url.startswith("https://proxy.example.com/ui/") + assert "fill_env_vars=abc-123" in url + + +def test_build_env_var_setup_url_encodes_unsafe_server_id(monkeypatch): + from urllib.parse import parse_qs, urlsplit + + monkeypatch.delenv("PROXY_BASE_URL", raising=False) + server_id = "a&b=c #d/e" + url = _u("build_env_var_setup_url")(server_id) + assert "a&b=c #d/e" not in url + parsed = parse_qs(urlsplit(url).query) + assert parsed["fill_env_vars"] == [server_id] + + +def test_missing_user_env_vars_error_message_is_friendly(): + with pytest.raises(_u("MCPMissingUserEnvVarsError")) as exc_info: + raise _u("MCPMissingUserEnvVarsError")( + server_id="abc-123", + server_name="CorporateDB", + missing=["CORP_USERNAME", "CORP_PASSWORD"], + setup_url="https://proxy.example.com/ui/?page=mcp-servers&fill_env_vars=abc-123", + ) + err = exc_info.value + text = str(err) + assert 'Cannot connect to MCP server "CorporateDB".' in text + assert "- CORP_USERNAME" in text + assert "- CORP_PASSWORD" in text + assert "fill_env_vars=abc-123" in text + assert "Set your credentials here:" in text + assert err.server_id == "abc-123" + assert err.missing == ["CORP_USERNAME", "CORP_PASSWORD"] + + +def test_missing_user_env_vars_error_falls_back_to_server_id(): + err = _u("MCPMissingUserEnvVarsError")( + server_id="abc", + server_name=None, + missing=["X"], + setup_url="/ui/", + ) + text = str(err) + # Falls back to server_id when server_name is missing + assert 'Cannot connect to MCP server "abc".' in text + assert "- X" in text + + +# ── _resolve_static_headers_with_env_vars ──────────────────────────────── + + +@pytest.fixture +def mock_server(): + """A minimal MCPServer-like object for the static-headers resolver.""" + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + return MCPServer( + server_id="srv-1", + name="srv", + server_name="srv", + transport="http", + url="https://example.com", + static_headers={ + "X-DB-URL": "${DB_PROTOCOL}://${CORP_USERNAME}:${CORP_PASSWORD}@${DB_HOST}/db", + "X-Other": "literal", + }, + env_vars=[ + {"name": "DB_PROTOCOL", "value": "postgres", "scope": "global"}, + {"name": "DB_HOST", "value": "db.local", "scope": "global"}, + { + "name": "CORP_USERNAME", + "value": "", + "scope": "user", + "description": "Your DB username", + }, + {"name": "CORP_PASSWORD", "value": "", "scope": "user"}, + ], + ) + + +@pytest.mark.asyncio +async def test_resolve_static_headers_interpolates_globals_and_user( + mock_server, monkeypatch +): + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + + manager = MCPServerManager() + + # Stub the per-user lookup so we don't need a real DB. + async def fake_load_user_env_vars(server, user_api_key_auth): + return {"CORP_USERNAME": "alice", "CORP_PASSWORD": "s3cret"} + + monkeypatch.setattr(manager, "_load_user_env_vars", fake_load_user_env_vars) + + headers = await manager._resolve_static_headers_with_env_vars( + mock_server, user_api_key_auth=object() + ) + assert headers == { + "X-DB-URL": "postgres://alice:s3cret@db.local/db", + "X-Other": "literal", + } + + +@pytest.mark.asyncio +async def test_resolve_static_headers_raises_when_user_vars_missing( + mock_server, monkeypatch +): + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + + manager = MCPServerManager() + + async def fake_load_user_env_vars( + server, user_api_key_auth, *, force_refresh=False + ): + # User has only filled in one of the two required vars + return {"CORP_USERNAME": "alice"} + + monkeypatch.setattr(manager, "_load_user_env_vars", fake_load_user_env_vars) + + with pytest.raises(_u("MCPMissingUserEnvVarsError")) as exc: + await manager._resolve_static_headers_with_env_vars( + mock_server, user_api_key_auth=object() + ) + assert exc.value.missing == ["CORP_PASSWORD"] + assert exc.value.server_id == "srv-1" + assert "fill_env_vars=srv-1" in exc.value.setup_url + + +@pytest.mark.asyncio +async def test_resolve_static_headers_rechecks_db_before_raising_412( + mock_server, monkeypatch +): + """A stale cached negative must not produce a 412 on the tool-call path. + + Cache invalidation is process-local, so a user who stored values on another + worker can have a stale (incomplete) entry on this one. Before raising + MCPMissingUserEnvVarsError the resolver must re-read with force_refresh and + honor the fresh DB values. + """ + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + + manager = MCPServerManager() + + calls = [] + + async def fake_load_user_env_vars( + server, user_api_key_auth, *, force_refresh=False + ): + calls.append(force_refresh) + if force_refresh: + # Fresh DB read sees the values the user stored on another worker. + return {"CORP_USERNAME": "alice", "CORP_PASSWORD": "s3cret"} + # Stale, process-local cached entry is still missing CORP_PASSWORD. + return {"CORP_USERNAME": "alice"} + + monkeypatch.setattr(manager, "_load_user_env_vars", fake_load_user_env_vars) + + headers = await manager._resolve_static_headers_with_env_vars( + mock_server, user_api_key_auth=object() + ) + assert headers == { + "X-DB-URL": "postgres://alice:s3cret@db.local/db", + "X-Other": "literal", + } + # The cached read happened first, then exactly one forced DB re-read. + assert calls == [False, True] + + +@pytest.mark.asyncio +async def test_resolve_static_headers_missing_is_non_blocking_for_listing( + mock_server, monkeypatch +): + """With raise_on_missing=False (the tool-list path), missing per-user vars + must NOT raise. Available vars interpolate; unfilled ${NAME} refs are left + untouched so the server's tools still appear in the listing.""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + + manager = MCPServerManager() + + async def fake_load_user_env_vars(server, user_api_key_auth): + # User has only filled in one of the two required vars. + return {"CORP_USERNAME": "alice"} + + monkeypatch.setattr(manager, "_load_user_env_vars", fake_load_user_env_vars) + + headers = await manager._resolve_static_headers_with_env_vars( + mock_server, user_api_key_auth=object(), raise_on_missing=False + ) + # Globals + the supplied user var are interpolated; the still-missing + # CORP_PASSWORD reference is left as a literal rather than blocking listing. + assert headers == { + "X-DB-URL": "postgres://alice:${CORP_PASSWORD}@db.local/db", + "X-Other": "literal", + } + + +@pytest.mark.asyncio +async def test_resolve_static_headers_propagates_db_error_on_tool_call( + mock_server, monkeypatch +): + """A DB failure on the tool-call path must surface as a real error, not be + masked as a "missing credentials" MCPMissingUserEnvVarsError (412).""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + + manager = MCPServerManager() + + async def boom(server, user_api_key_auth): + raise RuntimeError("db down") + + monkeypatch.setattr(manager, "_load_user_env_vars", boom) + + with pytest.raises(RuntimeError, match="db down"): + await manager._resolve_static_headers_with_env_vars( + mock_server, user_api_key_auth=object() + ) + + +@pytest.mark.asyncio +async def test_resolve_static_headers_swallows_db_error_on_listing( + mock_server, monkeypatch +): + """On the listing path a DB failure is non-blocking: globals interpolate + and unfilled per-user ${NAME} refs are left untouched.""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + + manager = MCPServerManager() + + async def boom(server, user_api_key_auth): + raise RuntimeError("db down") + + monkeypatch.setattr(manager, "_load_user_env_vars", boom) + + headers = await manager._resolve_static_headers_with_env_vars( + mock_server, user_api_key_auth=object(), raise_on_missing=False + ) + assert headers == { + "X-DB-URL": "postgres://${CORP_USERNAME}:${CORP_PASSWORD}@db.local/db", + "X-Other": "literal", + } + + +@pytest.mark.asyncio +async def test_resolve_static_headers_passthrough_when_no_env_vars(): + """Servers without env_vars should keep static_headers untouched.""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + manager = MCPServerManager() + server = MCPServer( + server_id="srv-2", + name="srv2", + transport="http", + url="https://example.com", + static_headers={"Authorization": "Bearer admin-static"}, + env_vars=None, + ) + headers = await manager._resolve_static_headers_with_env_vars(server, None) + assert headers == {"Authorization": "Bearer admin-static"} + + +@pytest.mark.asyncio +async def test_resolve_static_headers_unreferenced_user_var_is_not_blocking( + monkeypatch, +): + """A per-user var declared by the admin but never referenced in + static_headers must not block the request — only blocking-by-use is + enforced.""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + manager = MCPServerManager() + server = MCPServer( + server_id="srv-3", + name="srv3", + transport="http", + url="https://example.com", + static_headers={"X-Static": "${GLOBAL_VAR}"}, + env_vars=[ + {"name": "GLOBAL_VAR", "value": "ok", "scope": "global"}, + # User var declared but not referenced anywhere — should be ignored. + {"name": "UNUSED_USER_VAR", "value": "", "scope": "user"}, + ], + ) + + async def fake_load_user_env_vars(server, user_api_key_auth): + return {} + + monkeypatch.setattr(manager, "_load_user_env_vars", fake_load_user_env_vars) + + headers = await manager._resolve_static_headers_with_env_vars(server, object()) + assert headers == {"X-Static": "ok"} + + +@pytest.mark.asyncio +async def test_resolve_static_headers_stale_user_value_cannot_override_global( + monkeypatch, +): + """A var that used to be user-scoped (so the user has a stored value) but is + now global must resolve to the admin's global value, not the stale per-user + row. Otherwise a user could override admin-configured headers indefinitely.""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + manager = MCPServerManager() + server = MCPServer( + server_id="srv-4", + name="srv4", + transport="http", + url="https://example.com", + static_headers={"X-DB-URL": "${DB_HOST}/${CORP_USERNAME}"}, + env_vars=[ + # DB_HOST is now global; it used to be user-scoped. + {"name": "DB_HOST", "value": "admin-db", "scope": "global"}, + {"name": "CORP_USERNAME", "value": "", "scope": "user"}, + ], + ) + + async def fake_load_user_env_vars(server, user_api_key_auth): + # Stale DB_HOST row left over from when it was user-scoped. + return {"DB_HOST": "evil-db", "CORP_USERNAME": "alice"} + + monkeypatch.setattr(manager, "_load_user_env_vars", fake_load_user_env_vars) + + headers = await manager._resolve_static_headers_with_env_vars(server, object()) + assert headers == {"X-DB-URL": "admin-db/alice"} + + +@pytest.mark.asyncio +async def test_resolve_static_headers_dual_scope_var_uses_global_without_412( + monkeypatch, +): + """A var declared with both ``global`` and ``user`` scope is covered by the + global value (globals win in the merge), so the tool-call path must resolve + it from the global instead of raising a 412 when the user hasn't filled it + in. This happens during a global-to-user (or user-to-global) migration.""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + manager = MCPServerManager() + server = MCPServer( + server_id="srv-5", + name="srv5", + transport="http", + url="https://example.com", + static_headers={"Authorization": "Bearer ${SHARED_TOKEN}"}, + env_vars=[ + {"name": "SHARED_TOKEN", "value": "global-secret", "scope": "global"}, + {"name": "SHARED_TOKEN", "value": "", "scope": "user"}, + ], + ) + + load_calls = [] + + async def fake_load_user_env_vars( + server, user_api_key_auth, *, force_refresh=False + ): + load_calls.append(force_refresh) + return {} + + monkeypatch.setattr(manager, "_load_user_env_vars", fake_load_user_env_vars) + + headers = await manager._resolve_static_headers_with_env_vars( + server, user_api_key_auth=object() + ) + assert headers == {"Authorization": "Bearer global-secret"} + # The global fully covers the reference, so no per-user lookup is needed. + assert load_calls == [] + + +@pytest.mark.asyncio +async def test_resolve_static_headers_empty_global_does_not_cover_user_var( + monkeypatch, +): + """An empty-valued global must not cover a referenced per-user var. The + global carries no usable value, so the tool-call path still raises a 412 + when the user hasn't supplied one, instead of silently interpolating an + empty string into the header.""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + manager = MCPServerManager() + server = MCPServer( + server_id="srv-6", + name="srv6", + transport="http", + url="https://example.com", + static_headers={"Authorization": "Bearer ${SHARED_TOKEN}"}, + env_vars=[ + {"name": "SHARED_TOKEN", "value": "", "scope": "global"}, + {"name": "SHARED_TOKEN", "value": "", "scope": "user"}, + ], + ) + + async def fake_load_user_env_vars( + server, user_api_key_auth, *, force_refresh=False + ): + return {} + + monkeypatch.setattr(manager, "_load_user_env_vars", fake_load_user_env_vars) + + with pytest.raises(_u("MCPMissingUserEnvVarsError")) as exc: + await manager._resolve_static_headers_with_env_vars( + server, user_api_key_auth=object() + ) + assert exc.value.missing == ["SHARED_TOKEN"] + + +@pytest.mark.asyncio +async def test_resolve_static_headers_user_value_wins_over_empty_global( + monkeypatch, +): + """When a global is empty, a value the user did supply must win the merge + rather than being clobbered by the empty global. The header resolves to the + user's value, not an empty string.""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + manager = MCPServerManager() + server = MCPServer( + server_id="srv-7", + name="srv7", + transport="http", + url="https://example.com", + static_headers={"Authorization": "Bearer ${SHARED_TOKEN}"}, + env_vars=[ + {"name": "SHARED_TOKEN", "value": "", "scope": "global"}, + {"name": "SHARED_TOKEN", "value": "", "scope": "user"}, + ], + ) + + async def fake_load_user_env_vars( + server, user_api_key_auth, *, force_refresh=False + ): + return {"SHARED_TOKEN": "user-secret"} + + monkeypatch.setattr(manager, "_load_user_env_vars", fake_load_user_env_vars) + + headers = await manager._resolve_static_headers_with_env_vars( + server, user_api_key_auth=object() + ) + assert headers == {"Authorization": "Bearer user-secret"} + + +# ── health-check skip for per-user-env-var-backed headers ────────────────── + + +@pytest.mark.parametrize( + "static_headers, env_vars, expected", + [ + ( + {"Authorization": "Bearer ${GITHUB_TOKEN}"}, + [{"name": "GITHUB_TOKEN", "value": "", "scope": "user"}], + True, + ), + ( + {"Authorization": "Bearer ${SHARED_TOKEN}"}, + [{"name": "SHARED_TOKEN", "value": "abc", "scope": "global"}], + False, + ), + ( + {"X-Static": "literal"}, + [{"name": "GITHUB_TOKEN", "value": "", "scope": "user"}], + False, + ), + (None, [{"name": "GITHUB_TOKEN", "value": "", "scope": "user"}], False), + ({"Authorization": "Bearer ${GITHUB_TOKEN}"}, None, False), + ], +) +def test_references_per_user_env_var(static_headers, env_vars, expected): + """Only headers that actually reference a *per-user* var count: globals and + declared-but-unreferenced user vars do not, since the userless probe can + still resolve (or simply not need) them.""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + manager = MCPServerManager() + server = MCPServer( + server_id="srv-x", + name="srv", + transport="http", + url="https://example.com", + static_headers=static_headers, + env_vars=env_vars, + ) + assert manager._references_per_user_env_var(server) is expected + + +@pytest.mark.asyncio +async def test_health_check_skips_servers_referencing_per_user_env_var( + mock_server, monkeypatch +): + """A userless health probe cannot fill per-user ${NAME} placeholders, so a + server whose static_headers reference one must report 'unknown' without + connecting. Otherwise it forwards the literal placeholder upstream, gets a + 401, and flips to 'unhealthy' even though real user calls succeed.""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + + manager = MCPServerManager() + manager.registry[mock_server.server_id] = mock_server + + created = [] + + async def fake_create_client(*args, **kwargs): + created.append((args, kwargs)) + raise RuntimeError("upstream rejected literal ${NAME}") + + monkeypatch.setattr(manager, "_create_mcp_client", fake_create_client) + + result = await manager.health_check_server(mock_server.server_id) + + assert created == [] + assert result.status == "unknown" + assert result.health_check_error is None + + +# ── _load_user_env_vars guard paths ──────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_load_user_env_vars_returns_empty_without_user(): + """No user auth → no per-user lookup is attempted.""" + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + manager = MCPServerManager() + server = MCPServer( + server_id="s", name="s", transport="http", url="https://example.com" + ) + assert await manager._load_user_env_vars(server, None) == {} + + +@pytest.mark.asyncio +async def test_load_user_env_vars_returns_empty_without_user_id(): + """User auth without a user_id (e.g. anonymous virtual key) → empty dict.""" + from unittest.mock import MagicMock + + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + manager = MCPServerManager() + server = MCPServer( + server_id="s", name="s", transport="http", url="https://example.com" + ) + fake_auth = MagicMock() + fake_auth.user_id = None + assert await manager._load_user_env_vars(server, fake_auth) == {} + + +@pytest.mark.asyncio +async def test_load_user_env_vars_raises_when_db_unavailable(monkeypatch): + """A missing DB connection must raise, not return ``{}``. Returning ``{}`` + would be indistinguishable from "user has no values" and would mislead the + tool-call path into a "set up your credentials" 412 the user can never + satisfy (per-user env vars are unusable without a DB).""" + from unittest.mock import MagicMock + + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + manager = MCPServerManager() + server = MCPServer( + server_id="s", name="s", transport="http", url="https://example.com" + ) + fake_auth = MagicMock() + fake_auth.user_id = "alice" + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None) + with pytest.raises(RuntimeError, match="database connection"): + await manager._load_user_env_vars(server, fake_auth) + + +@pytest.mark.asyncio +async def test_resolve_static_headers_db_unavailable_is_not_missing_412( + mock_server, monkeypatch +): + """On the tool-call path, an unavailable DB must surface as a real error + rather than a misleading MCPMissingUserEnvVarsError (412). This guards the + regression where ``_load_user_env_vars`` returned ``{}`` when prisma_client + was None, making a DB outage look like "user has no credentials".""" + from unittest.mock import MagicMock + + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + + manager = MCPServerManager() + fake_auth = MagicMock() + fake_auth.user_id = "alice" + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", None) + + with pytest.raises(RuntimeError, match="database connection"): + await manager._resolve_static_headers_with_env_vars( + mock_server, user_api_key_auth=fake_auth + ) + + +@pytest.mark.asyncio +async def test_load_user_env_vars_caches_within_ttl(env_vars_salt_key, monkeypatch): + """A second load within the TTL window is served from the in-memory cache, + keeping the hot tool-call/tool-listing path off the DB.""" + from unittest.mock import MagicMock + + from litellm.proxy._experimental.mcp_server import mcp_server_manager as mgr_mod + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + mgr_mod._user_env_vars_cache.clear() + + row = MagicMock() + row.values_b64 = _encrypted_user_env_blob({"TOKEN": "t0p"}) + + prisma = _mock_env_vars_prisma(row=row) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma) + + manager = MCPServerManager() + server = MCPServer( + server_id="srv-1", name="s", transport="http", url="https://example.com" + ) + fake_auth = MagicMock() + fake_auth.user_id = "alice" + + first = await manager._load_user_env_vars(server, fake_auth) + second = await manager._load_user_env_vars(server, fake_auth) + assert first == {"TOKEN": "t0p"} == second + assert prisma.db.litellm_mcpuserenvvars.find_unique.await_count == 1 + + mgr_mod._user_env_vars_cache.clear() + + +@pytest.mark.asyncio +async def test_load_user_env_vars_force_refresh_bypasses_cache( + env_vars_salt_key, monkeypatch +): + """force_refresh re-reads from the DB even with a fresh cached entry, so a + process-local stale value cannot mask credentials stored on another worker.""" + from unittest.mock import AsyncMock, MagicMock + + from litellm.proxy._experimental.mcp_server import mcp_server_manager as mgr_mod + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + mgr_mod._user_env_vars_cache.clear() + + old_row = MagicMock() + old_row.values_b64 = _encrypted_user_env_blob({"TOKEN": "old"}) + new_row = MagicMock() + new_row.values_b64 = _encrypted_user_env_blob({"TOKEN": "new"}) + + prisma = _mock_env_vars_prisma() + prisma.db.litellm_mcpuserenvvars.find_unique = AsyncMock( + side_effect=[old_row, new_row] + ) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma) + + manager = MCPServerManager() + server = MCPServer( + server_id="srv-1", name="s", transport="http", url="https://example.com" + ) + fake_auth = MagicMock() + fake_auth.user_id = "alice" + + assert await manager._load_user_env_vars(server, fake_auth) == {"TOKEN": "old"} + # A normal load is served from cache (still "old"); force_refresh re-reads. + assert await manager._load_user_env_vars(server, fake_auth) == {"TOKEN": "old"} + assert await manager._load_user_env_vars(server, fake_auth, force_refresh=True) == { + "TOKEN": "new" + } + assert prisma.db.litellm_mcpuserenvvars.find_unique.await_count == 2 + + mgr_mod._user_env_vars_cache.clear() + + +@pytest.mark.asyncio +async def test_load_user_env_vars_invalidation_forces_refetch( + env_vars_salt_key, monkeypatch +): + """After invalidation (store/clear) the next load reads fresh from the DB + instead of serving the stale cached value.""" + from unittest.mock import AsyncMock, MagicMock + + from litellm.proxy._experimental.mcp_server import mcp_server_manager as mgr_mod + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + invalidate_user_env_vars_cache, + ) + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + mgr_mod._user_env_vars_cache.clear() + + old_row = MagicMock() + old_row.values_b64 = _encrypted_user_env_blob({"TOKEN": "old"}) + new_row = MagicMock() + new_row.values_b64 = _encrypted_user_env_blob({"TOKEN": "new"}) + + prisma = _mock_env_vars_prisma() + prisma.db.litellm_mcpuserenvvars.find_unique = AsyncMock( + side_effect=[old_row, new_row] + ) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma) + + manager = MCPServerManager() + server = MCPServer( + server_id="srv-1", name="s", transport="http", url="https://example.com" + ) + fake_auth = MagicMock() + fake_auth.user_id = "alice" + + assert await manager._load_user_env_vars(server, fake_auth) == {"TOKEN": "old"} + invalidate_user_env_vars_cache("alice", "srv-1") + assert await manager._load_user_env_vars(server, fake_auth) == {"TOKEN": "new"} + assert prisma.db.litellm_mcpuserenvvars.find_unique.await_count == 2 + + mgr_mod._user_env_vars_cache.clear() + + +# ── DB helpers: per-user env vars ───────────────────────────────────────── + +_SALT_KEY = "test-salt-key-for-env-vars-tests-1234" + + +@pytest.fixture +def env_vars_salt_key(monkeypatch): + monkeypatch.setenv("LITELLM_SALT_KEY", _SALT_KEY) + + +def _mock_env_vars_prisma(row=None): + """Build a MagicMock prisma_client whose env-vars table returns ``row``.""" + from unittest.mock import AsyncMock, MagicMock + + prisma = MagicMock() + prisma.db.litellm_mcpuserenvvars.find_unique = AsyncMock(return_value=row) + prisma.db.litellm_mcpuserenvvars.find_many = AsyncMock(return_value=[]) + prisma.db.litellm_mcpuserenvvars.upsert = AsyncMock() + prisma.db.litellm_mcpuserenvvars.delete_many = AsyncMock() + return prisma + + +def _encrypted_user_env_blob(values: dict) -> str: + """Encrypt ``values`` the way the production per-user write does, so tests can + seed a correctly-encrypted ``values_b64`` blob without a live DB.""" + import json + + from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper + + return encrypt_value_helper(json.dumps(values)) + + +def _transactional_env_vars_prisma(read_delay: float = 0.0): + """A prisma stand-in backed by an in-memory store that honours + ``db.tx()`` and the ``pg_advisory_xact_lock`` advisory lock. + + ``read_delay`` inserts an ``await`` point inside ``find_unique`` so two + concurrent merges interleave between their read and write; the advisory lock + is what keeps them from clobbering each other. Drop the lock and the second + write wins, losing the first update. + """ + import asyncio + from unittest.mock import MagicMock + + class _Store: + def __init__(self): + self.rows = {} + self.locks = {} + + class _Table: + def __init__(self, store, delay=0.0): + self._store = store + self._delay = delay + + async def find_unique(self, where): + ident = where["user_id_server_id"] + key = (ident["user_id"], ident["server_id"]) + blob = self._store.rows.get(key) + # Yield after capturing the read so an unserialised concurrent merge + # would race on this stale snapshot. + if self._delay: + await asyncio.sleep(self._delay) + if blob is None: + return None + row = MagicMock() + row.values_b64 = blob + return row + + async def upsert(self, where, data): + ident = where["user_id_server_id"] + key = (ident["user_id"], ident["server_id"]) + self._store.rows[key] = data["update"]["values_b64"] + + async def delete_many(self, where): + self._store.rows.pop((where["user_id"], where["server_id"]), None) + + class _Tx: + def __init__(self, store, delay): + self._store = store + self._held = None + self.litellm_mcpuserenvvars = _Table(store, delay=delay) + + async def __aenter__(self): + return self + + async def __aexit__(self, *exc): + if self._held is not None: + self._held.release() + self._held = None + return False + + async def execute_raw(self, query, *args): + lock_key = args[0] + lock = self._store.locks.setdefault(lock_key, asyncio.Lock()) + await lock.acquire() + self._held = lock + return 1 + + class _DB: + def __init__(self, store, delay): + self._store = store + self._delay = delay + self.litellm_mcpuserenvvars = _Table(store) + + def tx(self): + return _Tx(self._store, self._delay) + + class _Prisma: + def __init__(self, delay): + self.db = _DB(_Store(), delay) + + return _Prisma(read_delay) + + +@pytest.mark.asyncio +async def test_merge_user_env_vars_does_not_persist_plaintext(env_vars_salt_key): + """The per-user write path must encrypt values at rest; ``values_b64`` must + never hold plaintext personal credentials, but must still round-trip.""" + from litellm.proxy._experimental.mcp_server.db import ( + _decode_user_env_vars, + merge_user_env_vars, + ) + + prisma = _transactional_env_vars_prisma() + values = {"CORP_USERNAME": "alice", "CORP_PASSWORD": "s3cret"} + await merge_user_env_vars( + prisma, "alice", "srv-1", values, allowed_names=values.keys() + ) + + row = await prisma.db.litellm_mcpuserenvvars.find_unique( + where={"user_id_server_id": {"user_id": "alice", "server_id": "srv-1"}} + ) + stored = row.values_b64 + assert "s3cret" not in stored + assert "alice" not in stored + assert _decode_user_env_vars(stored) == values + + +@pytest.mark.asyncio +async def test_get_user_env_vars_round_trip(env_vars_salt_key): + from unittest.mock import MagicMock + + from litellm.proxy._experimental.mcp_server.db import get_user_env_vars + + payload = {"CORP_USERNAME": "alice", "CORP_PASSWORD": "s3cret"} + row = MagicMock() + row.values_b64 = _encrypted_user_env_blob(payload) + prisma = _mock_env_vars_prisma(row=row) + + result = await get_user_env_vars(prisma, "alice", "srv-1") + assert result == payload + + +@pytest.mark.asyncio +async def test_get_user_env_vars_returns_empty_for_missing_row(): + from litellm.proxy._experimental.mcp_server.db import get_user_env_vars + + prisma = _mock_env_vars_prisma(row=None) + assert await get_user_env_vars(prisma, "alice", "srv-1") == {} + + +@pytest.mark.asyncio +async def test_decode_user_env_vars_warns_when_undecryptable( + env_vars_salt_key, monkeypatch +): + """A stored blob encrypted under a previous salt key must surface a warning + (not just a debug line) and decode to ``{}`` so a rotated ``LITELLM_SALT_KEY`` + is diagnosable instead of silently sending the user a misleading "set up your + credentials" 412 for values they already stored.""" + from unittest.mock import MagicMock + + import litellm.proxy._experimental.mcp_server.db as mcp_db + from litellm.proxy._experimental.mcp_server.db import _decode_user_env_vars + + blob = _encrypted_user_env_blob({"CORP_PASSWORD": "s3cret"}) + + monkeypatch.setenv("LITELLM_SALT_KEY", "a-totally-different-salt-key-0000") + logger = MagicMock() + monkeypatch.setattr(mcp_db, "verbose_proxy_logger", logger) + + assert _decode_user_env_vars(blob) == {} + logger.warning.assert_called_once() + + +@pytest.mark.asyncio +async def test_get_user_env_vars_bulk_distributes_results(env_vars_salt_key): + from unittest.mock import AsyncMock, MagicMock + + from litellm.proxy._experimental.mcp_server.db import get_user_env_vars_bulk + + blob1 = _encrypted_user_env_blob({"A": "1"}) + blob2 = _encrypted_user_env_blob({"B": "2"}) + + row1 = MagicMock() + row1.server_id = "srv-1" + row1.values_b64 = blob1 + row2 = MagicMock() + row2.server_id = "srv-2" + row2.values_b64 = blob2 + + prisma = _mock_env_vars_prisma() + prisma.db.litellm_mcpuserenvvars.find_many = AsyncMock(return_value=[row1, row2]) + result = await get_user_env_vars_bulk(prisma, "alice", ["srv-1", "srv-2", "srv-3"]) + assert result == {"srv-1": {"A": "1"}, "srv-2": {"B": "2"}} + + +@pytest.mark.asyncio +async def test_get_user_env_vars_bulk_empty_ids_short_circuits(): + from litellm.proxy._experimental.mcp_server.db import get_user_env_vars_bulk + + prisma = _mock_env_vars_prisma() + assert await get_user_env_vars_bulk(prisma, "alice", []) == {} + # find_many should never have been called + assert prisma.db.litellm_mcpuserenvvars.find_many.await_count == 0 + + +@pytest.mark.asyncio +async def test_delete_user_env_vars_is_idempotent_delete_many(): + """Delete must use ``delete_many`` so a missing row is a no-op rather than + raising RecordNotFound; real DB errors are left to propagate.""" + from litellm.proxy._experimental.mcp_server.db import delete_user_env_vars + + prisma = _mock_env_vars_prisma() + await delete_user_env_vars(prisma, "alice", "srv-1") + prisma.db.litellm_mcpuserenvvars.delete_many.assert_awaited_once() + call = prisma.db.litellm_mcpuserenvvars.delete_many.call_args + assert call.kwargs["where"] == {"user_id": "alice", "server_id": "srv-1"} + + +@pytest.mark.asyncio +async def test_merge_user_env_vars_preserves_existing_and_prunes_disallowed( + env_vars_salt_key, +): + """Merging one update keeps the user's other stored values and drops any + name the admin no longer declares as user-scoped.""" + from litellm.proxy._experimental.mcp_server.db import merge_user_env_vars + + prisma = _transactional_env_vars_prisma() + await merge_user_env_vars( + prisma, + "alice", + "srv-1", + {"CORP_USERNAME": "alice", "CORP_PASSWORD": "old", "RETIRED": "x"}, + {"CORP_USERNAME", "CORP_PASSWORD", "RETIRED"}, + ) + + merged = await merge_user_env_vars( + prisma, + "alice", + "srv-1", + {"CORP_PASSWORD": "new"}, + {"CORP_USERNAME", "CORP_PASSWORD"}, + ) + + # CORP_USERNAME survives, CORP_PASSWORD updates, RETIRED (no longer declared) + # is pruned. + assert merged == {"CORP_USERNAME": "alice", "CORP_PASSWORD": "new"} + + +@pytest.mark.asyncio +async def test_merge_user_env_vars_serializes_concurrent_writes(env_vars_salt_key): + """Two simultaneous merges for the same (user, server) must not lose an + update: the advisory-locked transaction serialises the read-modify-write so + both distinct values survive.""" + import asyncio + + from litellm.proxy._experimental.mcp_server.db import ( + get_user_env_vars, + merge_user_env_vars, + ) + + allowed = {"TOKEN_A", "TOKEN_B"} + prisma = _transactional_env_vars_prisma(read_delay=0.02) + + await asyncio.gather( + merge_user_env_vars(prisma, "alice", "srv-1", {"TOKEN_A": "a"}, allowed), + merge_user_env_vars(prisma, "alice", "srv-1", {"TOKEN_B": "b"}, allowed), + ) + + stored = await get_user_env_vars(prisma, "alice", "srv-1") + assert stored == {"TOKEN_A": "a", "TOKEN_B": "b"} + + +@pytest.mark.asyncio +async def test_merge_user_env_vars_acquires_lock_without_deserializing_void( + env_vars_salt_key, +): + """``pg_advisory_xact_lock`` returns ``void``; running it through ``query_raw`` + makes Prisma try to deserialize that column and raises ``RawQueryError``. The + lock must be taken via ``execute_raw`` (no result-set deserialization) so the + merge still completes.""" + from unittest.mock import MagicMock + + from prisma.errors import RawQueryError + + from litellm.proxy._experimental.mcp_server.db import merge_user_env_vars + + class _Tx: + def __init__(self): + self.stored = None + self.litellm_mcpuserenvvars = self + + async def __aenter__(self): + return self + + async def __aexit__(self, *exc): + return False + + async def query_raw(self, query, *args): + raise RawQueryError( + { + "user_facing_error": { + "error_code": "P2010", + "meta": { + "message": "Failed to deserialize column of type 'void'." + }, + } + } + ) + + async def execute_raw(self, query, *args): + return 1 + + async def find_unique(self, where): + return None + + async def upsert(self, where, data): + self.stored = data["create"]["values_b64"] + + tx = _Tx() + prisma = MagicMock() + prisma.db.tx = MagicMock(return_value=tx) + + values = {"CORP_TOKEN": "t0ken"} + merged = await merge_user_env_vars( + prisma, "alice", "srv-1", values, allowed_names=values.keys() + ) + + assert merged == values + assert tx.stored is not None + + +@pytest.mark.asyncio +async def test_delete_mcp_server_removes_orphaned_user_env_vars(): + """Deleting a server must also drop every user's per-user env var rows for + it; there is no FK cascade, so skipping this leaves orphaned credentials.""" + from unittest.mock import AsyncMock + + from litellm.proxy._experimental.mcp_server.db import delete_mcp_server + + prisma = _mock_env_vars_prisma() + prisma.db.litellm_mcpservertable.delete = AsyncMock(return_value=object()) + + await delete_mcp_server(prisma, "srv-1") + + prisma.db.litellm_mcpuserenvvars.delete_many.assert_awaited_once() + call = prisma.db.litellm_mcpuserenvvars.delete_many.call_args + assert call.kwargs["where"] == {"server_id": "srv-1"} + + +@pytest.mark.asyncio +async def test_delete_mcp_server_skips_env_var_cleanup_when_server_missing(): + """A no-op delete (server not found) must not touch the env var table.""" + from unittest.mock import AsyncMock + + from litellm.proxy._experimental.mcp_server.db import delete_mcp_server + + prisma = _mock_env_vars_prisma() + prisma.db.litellm_mcpservertable.delete = AsyncMock(return_value=None) + + result = await delete_mcp_server(prisma, "srv-1") + + assert result is None + prisma.db.litellm_mcpuserenvvars.delete_many.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_delete_mcp_server_succeeds_when_orphan_cleanup_fails(): + """The server-row delete is the commit point: a transient failure cleaning + the FK-less per-user env var rows must not turn a successful delete into a + caller error, otherwise the caller retries and hits a 404 for a server that + is already gone.""" + from unittest.mock import AsyncMock + + from litellm.proxy._experimental.mcp_server.db import delete_mcp_server + + deleted = object() + prisma = _mock_env_vars_prisma() + prisma.db.litellm_mcpservertable.delete = AsyncMock(return_value=deleted) + prisma.db.litellm_mcpuserenvvars.delete_many = AsyncMock( + side_effect=Exception("connection pool exhausted") + ) + + result = await delete_mcp_server(prisma, "srv-1") + + assert result is deleted + prisma.db.litellm_mcpuserenvvars.delete_many.assert_awaited_once() + + +# ── DB helpers: global env vars encrypted at rest ───────────────────────── + + +def _global_env_var_server_request(env_vars): + from litellm.proxy._types import NewMCPServerRequest + + return NewMCPServerRequest( + alias="echo", + url="https://upstream.example.com/mcp", + transport="http", + auth_type="none", + static_headers={"X-Db": "${DB_PASSWORD}"}, + env_vars=env_vars, + ) + + +def test_prepare_mcp_server_data_encrypts_global_env_var_values(env_vars_salt_key): + """``scope="global"`` secrets must be encrypted before they reach the JSON + column, while ``scope="user"`` placeholders (not secrets) stay verbatim.""" + import json + + from litellm.proxy._experimental.mcp_server.db import ( + _prepare_mcp_server_data, + decrypt_global_env_var_values, + ) + from litellm.proxy._types import MCPEnvVar + + req = _global_env_var_server_request( + [ + MCPEnvVar(name="DB_PASSWORD", value="s3cr3t-p@ss", scope="global"), + MCPEnvVar( + name="CORP_USER", + value="placeholder-hint", + scope="user", + description="your db user", + ), + ] + ) + + stored = _prepare_mcp_server_data(req)["env_vars"] + entries = {e["name"]: e for e in json.loads(stored)} + + # The global secret is unrecoverable from the stored JSON ... + assert "s3cr3t-p@ss" not in stored + assert entries["DB_PASSWORD"]["value"] != "s3cr3t-p@ss" + # ... but the per-user placeholder is stored as-is. + assert entries["CORP_USER"]["value"] == "placeholder-hint" + + # And the encrypted global decrypts back to the original secret. + decrypt_global_env_var_values(list(entries.values())) + assert entries["DB_PASSWORD"]["value"] == "s3cr3t-p@ss" + assert entries["CORP_USER"]["value"] == "placeholder-hint" + + +def test_prepare_mcp_server_data_skips_unset_env_vars_on_partial_update(): + """On a partial update, env_vars must follow the same exclude_unset filter as + every other JSON column: if the caller never set env_vars, the field must not + be written, even when the request object carries a non-None env_vars that was + never marked as set. Otherwise a partial update could silently overwrite the + stored values.""" + from litellm.proxy._experimental.mcp_server.db import _prepare_mcp_server_data + from litellm.proxy._types import MCPEnvVar, UpdateMCPServerRequest + + data = UpdateMCPServerRequest.model_construct( + _fields_set={"server_id"}, + server_id="srv-1", + env_vars=[MCPEnvVar(name="DB_PASSWORD", value="s3cr3t", scope="global")], + ) + + prepared = _prepare_mcp_server_data(data, exclude_unset=True) + + assert "env_vars" not in prepared + + +def test_prepare_mcp_server_data_writes_env_vars_when_set_on_partial_update( + env_vars_salt_key, +): + """A partial update that does set env_vars must serialize and encrypt them.""" + import json + + from litellm.proxy._experimental.mcp_server.db import _prepare_mcp_server_data + from litellm.proxy._types import MCPEnvVar, UpdateMCPServerRequest + + data = UpdateMCPServerRequest( + server_id="srv-1", + env_vars=[MCPEnvVar(name="DB_PASSWORD", value="s3cr3t", scope="global")], + ) + + prepared = _prepare_mcp_server_data(data, exclude_unset=True) + + assert "env_vars" in prepared + entries = json.loads(prepared["env_vars"]) + assert entries[0]["name"] == "DB_PASSWORD" + assert entries[0]["value"] != "s3cr3t" + + +@pytest.mark.asyncio +async def test_build_mcp_server_from_table_decrypts_global_env_vars(env_vars_salt_key): + """End-to-end: an encrypted global value persisted in the DB must be + decrypted when the server is built into the runtime registry, so ``${NAME}`` + headers interpolate to the real secret instead of forwarding ciphertext.""" + import json + + from litellm.proxy._experimental.mcp_server.db import _prepare_mcp_server_data + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.proxy._types import LiteLLM_MCPServerTable, MCPEnvVar + + req = _global_env_var_server_request( + [MCPEnvVar(name="DB_PASSWORD", value="s3cr3t-p@ss", scope="global")] + ) + prepared = _prepare_mcp_server_data(req) + + table = LiteLLM_MCPServerTable( + server_id="srv-global", + alias="echo", + url="https://upstream.example.com/mcp", + transport="http", + auth_type="none", + static_headers={"X-Db": "${DB_PASSWORD}"}, + env_vars=json.loads(prepared["env_vars"]), + ) + + manager = MCPServerManager() + server = await manager.build_mcp_server_from_table(table) + + headers = await manager._resolve_static_headers_with_env_vars(server, None) + assert headers == {"X-Db": "s3cr3t-p@ss"} + + +@pytest.mark.asyncio +async def test_add_server_does_not_double_decrypt_global_env_vars(env_vars_salt_key): + """The create/fetch endpoints hand ``add_server`` a record whose global env + var values were already decrypted by the db.py helpers (only ``credentials`` + stays encrypted). Building the registry entry must not decrypt them a second + time: a second decrypt of an already-plaintext value (e.g. ``postgresql``) + fails and zeroes it, which would forward the raw ``${NAME}`` placeholder + upstream instead of the interpolated secret.""" + import json + + from litellm.proxy._experimental.mcp_server.db import ( + _prepare_mcp_server_data, + decrypt_global_env_var_values, + ) + from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( + MCPServerManager, + ) + from litellm.proxy._types import LiteLLM_MCPServerTable, MCPEnvVar + + req = _global_env_var_server_request( + [MCPEnvVar(name="DB_PASSWORD", value="s3cr3t-p@ss", scope="global")] + ) + env_vars = json.loads(_prepare_mcp_server_data(req)["env_vars"]) + # Mirror what create_mcp_server / get_mcp_server return to add_server. + decrypt_global_env_var_values(env_vars) + assert env_vars[0]["value"] == "s3cr3t-p@ss" + + table = LiteLLM_MCPServerTable( + server_id="srv-add", + alias="echo", + url="https://upstream.example.com/mcp", + transport="http", + auth_type="none", + static_headers={"X-Db": "${DB_PASSWORD}"}, + env_vars=env_vars, + approval_status="active", + ) + + manager = MCPServerManager() + await manager.add_server(table) + + server = manager.registry["srv-add"] + headers = await manager._resolve_static_headers_with_env_vars(server, None) + assert headers == {"X-Db": "s3cr3t-p@ss"} + + +@pytest.mark.asyncio +async def test_create_mcp_server_decrypts_env_vars_when_prisma_returns_json_string( + env_vars_salt_key, +): + """Regression for the reload-reuse path: Prisma can hand back ``env_vars`` on + a write as the raw JSON string that was persisted, not a parsed list. The + create/update wrappers must still decrypt globals on the returned row, else + ``add_server`` (which trusts the caller) seeds the registry with ciphertext + and the subsequent ``reload_servers_from_database`` reuses that broken entry + (timestamps match), so headers forward ciphertext upstream.""" + from unittest.mock import AsyncMock, MagicMock + + from litellm.proxy._experimental.mcp_server.db import ( + _prepare_mcp_server_data, + create_mcp_server, + update_mcp_server, + ) + from litellm.proxy._types import ( + MCPEnvVar, + NewMCPServerRequest, + UpdateMCPServerRequest, + ) + + req = _global_env_var_server_request( + [MCPEnvVar(name="DB_PASSWORD", value="s3cr3t-p@ss", scope="global")] + ) + encrypted_env_vars_str = _prepare_mcp_server_data(req)["env_vars"] + assert "s3cr3t-p@ss" not in encrypted_env_vars_str + + def _prisma_row_with_json_string_env_vars(): + row = MagicMock() + row.env_vars = encrypted_env_vars_str + return row + + mock_prisma = MagicMock() + mock_prisma.db.litellm_mcpservertable.create = AsyncMock( + return_value=_prisma_row_with_json_string_env_vars() + ) + + created = await create_mcp_server( + mock_prisma, + NewMCPServerRequest( + server_id="srv-create", + url="https://upstream.example.com/mcp", + transport="http", + ), + touched_by="test-user", + ) + assert isinstance(created.env_vars, list) + assert created.env_vars[0]["value"] == "s3cr3t-p@ss" + + mock_prisma_upd = MagicMock() + mock_prisma_upd.db.litellm_mcpservertable.update = AsyncMock( + return_value=_prisma_row_with_json_string_env_vars() + ) + updated = await update_mcp_server( + mock_prisma_upd, + UpdateMCPServerRequest(server_id="srv-update"), + touched_by="test-user", + ) + assert isinstance(updated.env_vars, list) + assert updated.env_vars[0]["value"] == "s3cr3t-p@ss" + + +def test_reencrypt_global_env_var_values_handles_json_string(env_vars_salt_key): + """``rotate_mcp_server_credentials_master_key`` reads ``mcp_server.env_vars`` + straight off the Prisma row, which can be a JSON string. The re-encrypt + helper must parse it instead of failing on ``dict(v)`` over a string.""" + import json + + from litellm.proxy._experimental.mcp_server.db import ( + _prepare_mcp_server_data, + _reencrypt_global_env_var_values, + ) + from litellm.proxy._types import MCPEnvVar + + req = _global_env_var_server_request( + [MCPEnvVar(name="DB_PASSWORD", value="s3cr3t-p@ss", scope="global")] + ) + encrypted_env_vars_str = _prepare_mcp_server_data(req)["env_vars"] + original_ciphertext = json.loads(encrypted_env_vars_str)[0]["value"] + + rebuilt = _reencrypt_global_env_var_values( + encrypted_env_vars_str, new_encryption_key="rotated-master-key-0000" + ) + + assert rebuilt is not None + assert rebuilt[0]["name"] == "DB_PASSWORD" + assert rebuilt[0]["value"] != original_ciphertext + assert rebuilt[0]["value"] != "s3cr3t-p@ss" + + +@pytest.mark.asyncio +async def test_rotate_mcp_user_env_vars_logs_rotated_and_skipped_counts( + env_vars_salt_key, monkeypatch +): + """Master-key rotation is a rare, high-stakes batch op, so it emits one + summary line. The counts must track real work: a decryptable row is + re-encrypted and counted as rotated, while a row that no longer decrypts is + left untouched and counted as skipped.""" + from unittest.mock import AsyncMock, MagicMock + + import litellm.proxy._experimental.mcp_server.db as mcp_db + from litellm.proxy._experimental.mcp_server.db import ( + rotate_mcp_user_env_vars_master_key, + ) + from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper + + def _row(user_id, server_id, blob): + row = MagicMock() + row.user_id = user_id + row.server_id = server_id + row.values_b64 = blob + return row + + import json + + # Encrypted under an unrelated key, so it won't decrypt under the active salt + # key and must be skipped rather than re-encrypted. + undecryptable = encrypt_value_helper( + json.dumps({"X": "y"}), new_encryption_key="unrelated-key-9999" + ) + good_one = _row("alice", "srv-1", _encrypted_user_env_blob({"GH_TOKEN": "tok-1"})) + good_two = _row("bob", "srv-2", _encrypted_user_env_blob({"GH_TOKEN": "tok-2"})) + bad = _row("carol", "srv-3", undecryptable) + + prisma = MagicMock() + prisma.db.litellm_mcpuserenvvars.find_many = AsyncMock( + return_value=[good_one, good_two, bad] + ) + prisma.db.litellm_mcpuserenvvars.update = AsyncMock() + + logger = MagicMock() + monkeypatch.setattr(mcp_db, "verbose_proxy_logger", logger) + + await rotate_mcp_user_env_vars_master_key(prisma, new_master_key="rotated-key-0000") + + update = prisma.db.litellm_mcpuserenvvars.update + assert update.await_count == 2 + updated_servers = { + call.kwargs["where"]["user_id_server_id"]["server_id"] + for call in update.call_args_list + } + assert updated_servers == {"srv-1", "srv-2"} # srv-3 was skipped, not rotated + for call in update.call_args_list: + assert call.kwargs["data"]["values_b64"] not in ( + good_one.values_b64, + good_two.values_b64, + ) + + logger.info.assert_called_once() + info_args = logger.info.call_args.args + assert info_args[1] == 2 # rotated + assert info_args[2] == 1 # skipped + + +def test_decrypt_global_env_var_drops_undecryptable_value( + env_vars_salt_key, monkeypatch +): + """A global value encrypted under a previous salt key must be dropped (not + forwarded as ciphertext) and surfaced as a warning, so a rotated + ``LITELLM_SALT_KEY`` can't silently leak ciphertext into ``${NAME}`` headers.""" + import json + from unittest.mock import MagicMock + + import litellm.proxy._experimental.mcp_server.db as mcp_db + from litellm.proxy._experimental.mcp_server.db import ( + _prepare_mcp_server_data, + decrypt_global_env_var_values, + ) + from litellm.proxy._types import MCPEnvVar + + req = _global_env_var_server_request( + [MCPEnvVar(name="DB_PASSWORD", value="s3cr3t-p@ss", scope="global")] + ) + entries = json.loads(_prepare_mcp_server_data(req)["env_vars"]) + ciphertext = entries[0]["value"] + assert ciphertext != "s3cr3t-p@ss" # encrypted under the original salt key + + # Rotate the salt key so the stored ciphertext no longer decrypts. + monkeypatch.setenv("LITELLM_SALT_KEY", "a-totally-different-salt-key-0000") + logger = MagicMock() + monkeypatch.setattr(mcp_db, "verbose_proxy_logger", logger) + + decrypt_global_env_var_values(entries) + + assert entries[0]["value"] == "" + assert ciphertext not in json.dumps(entries) + logger.warning.assert_called_once() + assert "DB_PASSWORD" in logger.warning.call_args.args + + +# ── REST exception handling ─────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_missing_user_env_vars_error_renders_in_mcp_call_tool(): + """The MCP ``call_tool`` handler must turn ``MCPMissingUserEnvVarsError`` + into a friendly ``CallToolResult`` with ``isError=True`` so Claude Code + surfaces the setup URL instead of an opaque internal error.""" + from mcp.types import TextContent + + err = _u("MCPMissingUserEnvVarsError")( + server_id="srv-99", + server_name="CorporateDB", + missing=["CORP_USERNAME"], + setup_url="/ui/?page=mcp-servers&fill_env_vars=srv-99", + ) + # We don't want to spin up the full MCP server framework — just + # mimic the except-clause behavior the @server.call_tool handler uses. + from mcp.types import CallToolResult + + result = CallToolResult( + content=[TextContent(text=str(err), type="text")], + isError=True, + ) + assert result.isError is True + text = result.content[0].text # type: ignore[union-attr] + assert "CorporateDB" in text + assert "CORP_USERNAME" in text + assert "fill_env_vars=srv-99" in text diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py index 227cf3f4bc..0b1240f8ba 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server.py @@ -3934,7 +3934,7 @@ class TestMCPServerManagerReload: ): await manager.reload_servers_from_database() - mock_build.assert_awaited_once_with(db_row) + mock_build.assert_awaited_once_with(db_row, env_vars_are_encrypted=True) assert manager.registry["server-1"] is rebuilt_server @pytest.mark.asyncio @@ -3965,7 +3965,7 @@ class TestMCPServerManagerReload: updated_at=timestamp, ) - async def build_server(db_row): + async def build_server(db_row, **kwargs): if db_row.server_id == "bad-server": raise RuntimeError("transient build failure") if db_row.server_id == "healthy-server": @@ -4031,7 +4031,7 @@ class TestMCPServerManagerReload: updated_at=timestamp, ) - async def build_server(db_row): + async def build_server(db_row, **kwargs): if db_row.server_id == "healthy-server": return healthy_server return bad_openapi_server diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py index 93d85773ce..48c09f6e45 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_server_manager.py @@ -29,10 +29,13 @@ from mcp.types import Tool as MCPTool from litellm.proxy._experimental.mcp_server.mcp_server_manager import ( MCPServerManager, _deserialize_json_dict, + _deserialize_json_list, ) from litellm.proxy._types import ( LiteLLM_MCPServerTable, MCPApprovalStatus, + MCPEnvVar, + MCPEnvVarScope, MCPTransport, ) from litellm.types.mcp import MCPAuth @@ -3057,6 +3060,57 @@ class TestMCPServerTimestamps: assert rebuilt_table.created_at == created assert rebuilt_table.updated_at == updated + def test_deserialize_json_list_normalizes_pydantic_models(self): + """Prisma hydrates the ``env_vars`` JSON column into ``MCPEnvVar`` models; + ``_deserialize_json_list`` must hand back plain dicts so ``MCPServer`` + (typed ``List[Dict[str, Any]]``) validates.""" + env_vars = [ + MCPEnvVar( + name="GITHUB_TOKEN", scope=MCPEnvVarScope.user, description="PAT" + ), + MCPEnvVar(name="REGION", value="us-east-1", scope=MCPEnvVarScope.global_), + ] + result = _deserialize_json_list(env_vars) + assert result is not None + assert all(isinstance(item, dict) for item in result) + assert result[0]["name"] == "GITHUB_TOKEN" + assert result[0]["scope"] == "user" + assert result[1]["value"] == "us-east-1" + + @pytest.mark.asyncio + async def test_build_mcp_server_from_table_with_model_env_vars(self): + """Regression: a DB row whose ``env_vars`` is a list of ``MCPEnvVar`` + models (as Prisma returns) must build into an ``MCPServer`` instead of + raising a Pydantic ``dict_type`` validation error that silently drops + the server from the registry.""" + manager = MCPServerManager() + + table_record = LiteLLM_MCPServerTable( + server_id="env-var-server-1", + server_name="github_peruser", + url="https://api.githubcopilot.com/mcp/", + transport=MCPTransport.http, + static_headers={"Authorization": "Bearer ${GITHUB_TOKEN}"}, + env_vars=[ + MCPEnvVar( + name="GITHUB_TOKEN", + scope=MCPEnvVarScope.user, + description="Your personal GitHub PAT", + ) + ], + ) + + mcp_server = await manager.build_mcp_server_from_table(table_record) + + assert mcp_server.env_vars == [ + { + "name": "GITHUB_TOKEN", + "value": "", + "scope": "user", + "description": "Your personal GitHub PAT", + } + ] + @pytest.mark.asyncio async def test_round_trip_source_url_preserved(self): """source_url survives the full round-trip: LiteLLM_MCPServerTable -> MCPServer -> LiteLLM_MCPServerTable. @@ -4107,6 +4161,179 @@ class TestApprovalStatusGate: assert "never-seen" not in manager.registry +class TestRegistryTableConversionPreservesEnvVars: + """The registry ``MCPServer`` -> ``LiteLLM_MCPServerTable`` conversions back + the GET /v1/mcp/server list and health responses, which populate the admin + edit form. When they dropped ``env_vars`` the form loaded an empty list and + saving any edit silently wiped the stored vars, so ``${VAR}`` static headers + were forwarded upstream un-interpolated. + """ + + @staticmethod + def _server_with_env_vars() -> MCPServer: + return MCPServer( + server_id="env-vars-server", + name="env_vars_server", + url="https://example.com/mcp", + transport=MCPTransport.http, + auth_type=MCPAuth.oauth2, + static_headers={"X-Db-Url": "${DB_PROTOCOL}://${CORP_USER}@${DB_HOST}"}, + env_vars=[ + { + "name": "DB_PROTOCOL", + "value": "postgresql", + "scope": "global", + "description": None, + }, + { + "name": "CORP_USER", + "value": "", + "scope": "user", + "description": "Your DB username", + }, + ], + ) + + @staticmethod + def _assert_env_vars_round_tripped(table: LiteLLM_MCPServerTable) -> None: + assert table.env_vars is not None + by_name = {entry.name: entry for entry in table.env_vars} + assert set(by_name) == {"DB_PROTOCOL", "CORP_USER"} + assert by_name["DB_PROTOCOL"].scope == MCPEnvVarScope.global_ + assert by_name["DB_PROTOCOL"].value == "postgresql" + assert by_name["CORP_USER"].scope == MCPEnvVarScope.user + assert by_name["CORP_USER"].description == "Your DB username" + + def test_build_mcp_server_table_preserves_env_vars(self): + manager = MCPServerManager() + table = manager._build_mcp_server_table(self._server_with_env_vars()) + self._assert_env_vars_round_tripped(table) + + @pytest.mark.asyncio + async def test_health_check_server_preserves_env_vars(self): + # OAuth2 without client credentials needs a per-user token, so the + # health check is skipped (no network) and we exercise the table + # construction path directly. + manager = MCPServerManager() + server = self._server_with_env_vars() + assert server.requires_per_user_auth is True + manager.registry[server.server_id] = server + table = await manager.health_check_server(server.server_id) + self._assert_env_vars_round_tripped(table) + + +class TestHealthCheckInterpolatesGlobalEnvVars: + """The upstream probes (health check and the initialize-instructions + prefetch) must substitute global ``${NAME}`` env vars into static headers + before opening the connection. Forwarding the raw placeholder makes any + server whose auth header is backed by a global env var fail authentication + and flip to 'unhealthy', even though real tool calls (which do interpolate) + keep working. + """ + + @staticmethod + def _server() -> MCPServer: + return MCPServer( + server_id="global-env-server", + name="global_env_server", + url="https://example.com/mcp", + transport=MCPTransport.http, + auth_type=MCPAuth.none, + static_headers={"Authorization": "Bearer ${API_TOKEN}"}, + env_vars=[ + { + "name": "API_TOKEN", + "value": "secret-token", + "scope": "global", + "description": None, + } + ], + ) + + @staticmethod + def _capture_headers(manager: MCPServerManager) -> Dict[str, Any]: + captured: Dict[str, Any] = {} + mock_client = AsyncMock() + mock_client.run_with_session = AsyncMock(return_value="ok") + + async def _create(server, mcp_auth_header, extra_headers, stdio_env): + captured["extra_headers"] = extra_headers + return mock_client + + manager._create_mcp_client = AsyncMock(side_effect=_create) + return captured + + @pytest.mark.asyncio + async def test_health_check_interpolates_global_env_vars(self): + manager = MCPServerManager() + server = self._server() + assert server.requires_per_user_auth is False + manager.get_mcp_server_by_id = MagicMock(return_value=server) + manager._remember_upstream_initialize_instructions = MagicMock() + captured = self._capture_headers(manager) + + result = await manager.health_check_server(server.server_id) + + assert captured["extra_headers"] == {"Authorization": "Bearer secret-token"} + assert result.status == "healthy" + + @pytest.mark.asyncio + async def test_initialize_instructions_prefetch_interpolates_global_env_vars(self): + manager = MCPServerManager() + server = self._server() + captured = self._capture_headers(manager) + manager._remember_upstream_initialize_instructions = MagicMock() + + await manager._ensure_upstream_initialize_instructions_cached(server) + + assert captured["extra_headers"] == {"Authorization": "Bearer secret-token"} + + +class TestUserEnvVarsCacheEviction: + """At capacity the per-user env var cache must shed a single oldest entry + rather than wiping every entry, so a steady stream of distinct callers does + not periodically stampede the DB by invalidating every still-valid value. + """ + + @staticmethod + def _patch_cache(monkeypatch, max_size): + from litellm.proxy._experimental.mcp_server import mcp_server_manager as m + + cache: Dict[Any, Any] = {} + monkeypatch.setattr(m, "_user_env_vars_cache", cache) + monkeypatch.setattr(m, "_USER_ENV_VARS_CACHE_MAX_SIZE", max_size) + return m, cache + + def test_eviction_drops_single_oldest_entry_not_whole_cache(self, monkeypatch): + m, cache = self._patch_cache(monkeypatch, max_size=3) + + for i in range(3): + m._write_user_env_vars_cache(f"user{i}", "srv", {"V": str(i)}) + assert set(cache) == {("user0", "srv"), ("user1", "srv"), ("user2", "srv")} + + m._write_user_env_vars_cache("user3", "srv", {"V": "3"}) + + assert len(cache) == 3 + assert ("user0", "srv") not in cache + assert ("user3", "srv") in cache + assert cache[("user1", "srv")][0] == {"V": "1"} + + def test_refreshing_existing_key_does_not_evict(self, monkeypatch): + m, cache = self._patch_cache(monkeypatch, max_size=2) + + m._write_user_env_vars_cache("a", "srv", {"V": "1"}) + m._write_user_env_vars_cache("b", "srv", {"V": "2"}) + m._write_user_env_vars_cache("a", "srv", {"V": "1-new"}) + + assert set(cache) == {("a", "srv"), ("b", "srv")} + assert cache[("a", "srv")][0] == {"V": "1-new"} + # The just-refreshed key must now sit at the tail so the next insert + # evicts the genuinely older entry instead. + m._write_user_env_vars_cache("c", "srv", {"V": "3"}) + assert ("b", "srv") not in cache + assert ("a", "srv") in cache + + class TestGetPublicMCPServers: """ /public/mcp_hub strict-whitelist semantics — mirrors /public/model_hub diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_sigv4_auth.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_sigv4_auth.py index 0c2a8bb808..c2164a9f19 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_sigv4_auth.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_mcp_sigv4_auth.py @@ -1007,6 +1007,7 @@ class TestRotateCredentials: "aws_secret_access_key": "enc_old:SAK", "aws_region_name": "us-east-1", } + server.env_vars = None mock_prisma = MagicMock() mock_prisma.db.litellm_mcpservertable.find_many = AsyncMock( @@ -1043,6 +1044,58 @@ class TestRotateCredentials: # Non-secret fields should pass through unchanged assert stored_creds["aws_region_name"] == "us-east-1" + @pytest.mark.asyncio + async def test_rotation_reencrypts_global_env_vars(self): + """Global env var values are re-encrypted under the new key; user-scope + placeholders are left untouched.""" + from litellm.proxy._experimental.mcp_server.db import ( + rotate_mcp_server_credentials_master_key, + ) + + server = MagicMock() + server.server_id = "srv-env" + server.credentials = None + server.env_vars = [ + {"name": "API_KEY", "value": "enc_old:secret", "scope": "global"}, + {"name": "USER_TOKEN", "value": "", "scope": "user"}, + ] + + mock_prisma = MagicMock() + mock_prisma.db.litellm_mcpservertable.find_many = AsyncMock( + return_value=[server] + ) + mock_prisma.db.litellm_mcpservertable.update = AsyncMock() + + with ( + patch( + "litellm.proxy._experimental.mcp_server.db._get_salt_key", + return_value="old-key", + ), + patch( + "litellm.proxy._experimental.mcp_server.db.decrypt_value_helper", + side_effect=lambda value, key, exception_type="error", return_original_value=False: value.replace( + "enc_old:", "" + ), + ), + patch( + "litellm.proxy._experimental.mcp_server.db.encrypt_value_helper", + side_effect=lambda value, new_encryption_key: f"enc_new:{value}", + ), + ): + await rotate_mcp_server_credentials_master_key( + mock_prisma, "admin", "new-key" + ) + + update_call = mock_prisma.db.litellm_mcpservertable.update + assert update_call.called + stored_env = json.loads(update_call.call_args[1]["data"]["env_vars"]) + # Global value decrypted from old, then re-encrypted with new key + assert stored_env[0]["value"] == "enc_new:secret" + # User-scope placeholder untouched + assert stored_env[1]["value"] == "" + # Credentials column not written when the server has none + assert "credentials" not in update_call.call_args[1]["data"] + class TestAuthTypeSwitchClearsCredentials: """Test that switching auth_type without credentials clears stale secrets.""" diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py index 6433e0f636..caff9ea2d2 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_rest_endpoints.py @@ -2,6 +2,7 @@ import json from typing import Any, Dict, Optional from unittest.mock import MagicMock +import httpx import pytest from fastapi import HTTPException from starlette.requests import Request @@ -1629,3 +1630,48 @@ class TestPreviewOpenAPITools: "order is out of sync, so collision suffixes (_2, _3, ...) " "land on different operations" ) + + +class TestConnectionErrorMessage: + """The test-connection endpoints turn raw transport errors into messages. + + The message is returned to an admin in an API response, so it must explain + the failure without echoing the raw header value, which can carry a secret + (e.g. ``Authorization: Bearer ``). + """ + + def test_local_protocol_error_is_actionable_and_redacted(self): + secret = "Bearer sk-super-secret-token" + exc = httpx.LocalProtocolError(f"Illegal header value b' {secret}'") + + message = rest_endpoints._connection_error_message(exc) + + assert "header" in message.lower() + assert secret not in message + + def test_connect_error_points_at_reachability(self): + message = rest_endpoints._connection_error_message( + httpx.ConnectError("All connection attempts failed") + ) + assert "unreachable" in message.lower() + + def test_timeout_error_message(self): + message = rest_endpoints._connection_error_message( + httpx.ConnectTimeout("timed out") + ) + assert "unreachable" in message.lower() + + def test_http_status_error_includes_status_code(self): + response = httpx.Response(status_code=503) + exc = httpx.HTTPStatusError( + "server error", + request=httpx.Request("POST", "http://x/"), + response=response, + ) + message = rest_endpoints._connection_error_message(exc) + assert "503" in message + + def test_unknown_error_falls_back_to_generic(self): + message = rest_endpoints._connection_error_message(RuntimeError("weird")) + assert "weird" not in message + assert "proxy logs" in message.lower() diff --git a/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py index 7e044442e2..b5eb091bb8 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_mcp_management_endpoints.py @@ -1105,6 +1105,98 @@ class TestListMCPServers: assert result.status == "healthy" mock_manager.get_allowed_mcp_servers.assert_called_once_with(mock_user_auth) + @pytest.mark.asyncio + async def test_fetch_single_mcp_server_drops_env_vars_for_non_admin(self): + """A non-admin GET /v1/mcp/server/{id} for a server with env_vars must + not 500 and must not leak env var config. ``db.get_mcp_server`` returns + the raw Prisma model whose JSONB ``env_vars`` deserialize to plain + dicts; it is wrapped in ``LiteLLM_MCPServerTable`` (parsing the dicts + into ``MCPEnvVar``) before sanitization. The non-admin sanitizer then + drops ``env_vars`` entirely, since even the names (e.g. GLOBAL_KEY) + reveal which secrets the admin configured. + """ + + # Mirror what Prisma returns: a model whose JSONB ``env_vars`` are + # plain dicts, not parsed ``MCPEnvVar`` objects. ``model_construct`` + # skips validation so the dicts survive verbatim. + raw_prisma_model = LiteLLM_MCPServerTable.model_construct( + server_id="env-server", + server_name="Env Server", + alias="Env Server", + transport=MCPTransport.http, + url="https://env.example.com/mcp", + static_headers={ + "Authorization": "Bearer ${GLOBAL_KEY}", + "X-User": "${USER_KEY}", + }, + env_vars=[ + {"name": "GLOBAL_KEY", "value": "super-secret", "scope": "global"}, + { + "name": "USER_KEY", + "value": "", + "scope": "user", + "description": "your key", + }, + ], + ) + assert isinstance(raw_prisma_model.env_vars[0], dict) + + mock_prisma_client = MagicMock() + mock_prisma_client.db.litellm_mcpservertable.find_unique = AsyncMock( + return_value=raw_prisma_model + ) + + mock_health_result = generate_mock_mcp_server_db_record( + server_id="env-server", alias="Env Server" + ) + mock_health_result.status = "healthy" + mock_health_result.last_health_check = datetime.now() + mock_health_result.health_check_error = None + + mock_manager = MagicMock() + mock_manager.add_server = AsyncMock() + mock_manager.health_check_server = AsyncMock(return_value=mock_health_result) + + mock_user_auth = generate_mock_user_api_key_auth( + user_role=LitellmUserRoles.INTERNAL_USER + ) + + with ( + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw", + return_value=mock_prisma_client, + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager", + mock_manager, + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.get_all_mcp_servers_for_user", + AsyncMock( + return_value=[ + generate_mock_mcp_server_db_record(server_id="env-server") + ] + ), + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints._user_has_admin_view", + return_value=False, + ), + ): + from litellm.proxy.management_endpoints.mcp_management_endpoints import ( + fetch_mcp_server, + ) + + result = await fetch_mcp_server( + request=_make_mock_request(), + server_id="env-server", + user_api_key_dict=mock_user_auth, + ) + + assert result.server_id == "env-server" + # Non-admin viewers get no env var config at all (not even names). + assert result.env_vars is None + class TestTeamScopedMCPServerAccess: """Tests for cross-team information disclosure and restricted key bypass fixes.""" @@ -2308,6 +2400,109 @@ class TestUpdateMCPServer: assert result.alias == "Updated Test Server" +class TestAddMCPServerAtomicity: + """A committed MCP server must survive a post-write registry refresh failure. + + Regression: add_mcp_server inserted the row and then reloaded the whole + registry from the database inside the same try block. One unrelated malformed + row made the reload raise, so the endpoint returned 500 even though the new + row was already persisted. Callers assumed failure and retried, creating + duplicate servers. + """ + + @pytest.mark.asyncio + async def test_create_succeeds_when_registry_refresh_fails(self): + from litellm.proxy.management_endpoints.mcp_management_endpoints import ( + add_mcp_server, + ) + + payload = NewMCPServerRequest( + alias="echo", + url="https://echo.example.com/mcp", + transport=MCPTransport.http, + ) + admin = generate_mock_user_api_key_auth( + user_role=LitellmUserRoles.PROXY_ADMIN, user_id="admin-user" + ) + created_server = generate_mock_mcp_server_db_record( + server_id="created-1", alias="echo" + ) + + mock_manager = MagicMock() + mock_manager.add_server = AsyncMock() + mock_manager.reload_servers_from_database = AsyncMock( + side_effect=Exception("malformed pre-existing row") + ) + + with ( + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw", + return_value=MagicMock(), + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.validate_and_normalize_mcp_server_payload", + MagicMock(), + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.create_mcp_server", + AsyncMock(return_value=created_server), + ) as create_mock, + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager", + mock_manager, + ), + ): + result = await add_mcp_server(payload=payload, user_api_key_dict=admin) + + create_mock.assert_awaited_once() + mock_manager.reload_servers_from_database.assert_awaited_once() + assert result.server_id == "created-1" + + @pytest.mark.asyncio + async def test_create_500s_and_skips_registry_when_db_write_fails(self): + from litellm.proxy.management_endpoints.mcp_management_endpoints import ( + add_mcp_server, + ) + + payload = NewMCPServerRequest( + alias="echo", + url="https://echo.example.com/mcp", + transport=MCPTransport.http, + ) + admin = generate_mock_user_api_key_auth( + user_role=LitellmUserRoles.PROXY_ADMIN, user_id="admin-user" + ) + + mock_manager = MagicMock() + mock_manager.add_server = AsyncMock() + mock_manager.reload_servers_from_database = AsyncMock() + + with ( + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw", + return_value=MagicMock(), + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.validate_and_normalize_mcp_server_payload", + MagicMock(), + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.create_mcp_server", + AsyncMock(side_effect=Exception("db down")), + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager", + mock_manager, + ), + ): + with pytest.raises(HTTPException) as exc_info: + await add_mcp_server(payload=payload, user_api_key_dict=admin) + + assert exc_info.value.status_code == 500 + mock_manager.add_server.assert_not_awaited() + mock_manager.reload_servers_from_database.assert_not_awaited() + + class TestHealthCheckServers: """Test suite for health check servers endpoint""" @@ -2740,6 +2935,65 @@ class TestMCPApprovalWorkflow: assert result.total == 1 assert result.pending_review == 1 + @pytest.mark.asyncio + @pytest.mark.parametrize( + "user_role, expected_global_value", + [ + (LitellmUserRoles.PROXY_ADMIN, "super-secret"), + (LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, ""), + ], + ) + async def test_get_submissions_redacts_global_env_for_view_only_admin( + self, user_role, expected_global_value + ): + """Read-only admins reviewing the submission queue must not receive the + submitter's global env var secrets; full admins still see them.""" + from litellm.proxy._types import MCPSubmissionsSummary + from litellm.proxy.management_endpoints.mcp_management_endpoints import ( + get_mcp_server_submissions, + ) + + base = generate_mock_mcp_server_db_record(alias="Pending") + item = LiteLLM_MCPServerTable( + **{ + **base.model_dump(), + "env_vars": [ + { + "name": "ADMIN_API_KEY", + "value": "super-secret", + "scope": "global", + }, + { + "name": "USER_TOKEN", + "value": "placeholder-hint", + "scope": "user", + }, + ], + } + ) + item.approval_status = "pending_review" + summary = MCPSubmissionsSummary( + total=1, pending_review=1, active=0, rejected=0, items=[item] + ) + + with ( + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw", + return_value=MagicMock(), + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.get_mcp_submissions", + AsyncMock(return_value=summary), + ), + ): + result = await get_mcp_server_submissions( + user_api_key_dict=generate_mock_user_api_key_auth(user_role=user_role), + ) + + by_name = {ev.name: ev for ev in result.items[0].env_vars} + assert by_name["ADMIN_API_KEY"].value == expected_global_value + assert by_name["USER_TOKEN"].value == "placeholder-hint" + @pytest.mark.asyncio async def test_approve_non_pending_server_raises_400(self): from litellm.proxy._types import MCPApprovalStatus @@ -3319,6 +3573,853 @@ def test_sanitize_mcp_server_for_non_admin_clears_credential_fields(): assert sanitized.alias == server.alias +def _server_with_global_and_user_env_vars(): + base = generate_mock_mcp_server_db_record() + return LiteLLM_MCPServerTable( + **{ + **base.model_dump(), + "env_vars": [ + {"name": "ADMIN_API_KEY", "value": "super-secret", "scope": "global"}, + {"name": "USER_TOKEN", "value": "placeholder-hint", "scope": "user"}, + ], + } + ) + + +def test_sanitize_non_admin_drops_all_env_vars(): + """The non-admin view drops env vars entirely; even the names are admin + config metadata (e.g. DB_PASSWORD) that must not leak. Non-admins get the + per-user vars they need from the /user-env-vars/status endpoint.""" + import litellm.proxy.management_endpoints.mcp_management_endpoints as mgmt + + server = _server_with_global_and_user_env_vars() + + sanitized = mgmt._sanitize_mcp_server_for_non_admin(server) + + assert sanitized.env_vars is None + + # The original object must not be mutated. + original_by_name = {ev.name: ev for ev in server.env_vars} + assert original_by_name["ADMIN_API_KEY"].value == "super-secret" + + +def test_sanitize_virtual_key_drops_all_env_vars(): + """Virtual-key callers get a discovery-only view; env var entries (even the + names, which are admin config metadata) must be dropped entirely, not just + have their global values blanked.""" + import litellm.proxy.management_endpoints.mcp_management_endpoints as mgmt + + server = _server_with_global_and_user_env_vars() + + sanitized = mgmt._sanitize_mcp_server_for_virtual_key(server) + + assert sanitized.env_vars is None + + # The original object must not be mutated. + assert server.env_vars[0].value == "super-secret" + + +def _server_with_env_vars(server_id: str = "srv-env"): + base = generate_mock_mcp_server_db_record(server_id=server_id) + return LiteLLM_MCPServerTable( + **{ + **base.model_dump(), + "env_vars": [ + {"name": "ADMIN_API_KEY", "value": "super-secret", "scope": "global"}, + {"name": "USER_TOKEN", "value": "placeholder-hint", "scope": "user"}, + ], + } + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "user_role, expected_global_value", + [ + (LitellmUserRoles.PROXY_ADMIN, "super-secret"), + (LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, ""), + ], +) +async def test_fetch_single_mcp_server_redacts_global_env_for_view_only_admin( + user_role, expected_global_value +): + """Read-only admins must not receive admin-supplied global env var secrets; + full admins still see them so the edit form can pre-fill.""" + server = _server_with_env_vars() + + health_result = generate_mock_mcp_server_db_record(server_id=server.server_id) + health_result.status = "healthy" + health_result.last_health_check = datetime.now() + health_result.health_check_error = None + + with ( + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.get_prisma_client_or_throw", + return_value=MagicMock(), + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.get_mcp_server", + AsyncMock(return_value=server), + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager.add_server", + AsyncMock(return_value=None), + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager.health_check_server", + AsyncMock(return_value=health_result), + ), + ): + result = await mgmt_endpoints.fetch_mcp_server( + request=_make_mock_request(), + server_id=server.server_id, + user_api_key_dict=generate_mock_user_api_key_auth(user_role=user_role), + ) + + by_name = {ev.name: ev for ev in result.env_vars} + assert by_name["ADMIN_API_KEY"].value == expected_global_value + # Per-user placeholders are always preserved. + assert by_name["USER_TOKEN"].value == "placeholder-hint" + # The source record must never be mutated. + assert {ev.name: ev.value for ev in server.env_vars}[ + "ADMIN_API_KEY" + ] == "super-secret" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "user_role, expected_global_value", + [ + (LitellmUserRoles.PROXY_ADMIN, "super-secret"), + (LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, ""), + ], +) +async def test_fetch_all_mcp_servers_redacts_global_env_for_view_only_admin( + user_role, expected_global_value +): + server = _server_with_env_vars() + + with ( + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints._get_user_mcp_management_mode", + return_value="view_all", + ), + patch( + "litellm.proxy.management_endpoints.mcp_management_endpoints.global_mcp_server_manager.get_all_mcp_servers_unfiltered", + AsyncMock(return_value=[server]), + ), + patch( + "litellm.proxy.proxy_server.prisma_client", + None, + ), + ): + result = await mgmt_endpoints.fetch_all_mcp_servers( + user_api_key_dict=generate_mock_user_api_key_auth(user_role=user_role), + ) + + by_name = {ev.name: ev for ev in result[0].env_vars} + assert by_name["ADMIN_API_KEY"].value == expected_global_value + assert by_name["USER_TOKEN"].value == "placeholder-hint" + assert {ev.name: ev.value for ev in server.env_vars}[ + "ADMIN_API_KEY" + ] == "super-secret" + + +def _make_env_var_server( + *, + server_id: str = "srv-1", + server_name: str = "DB Server", + alias: str = "db_server", + env_vars=None, + static_headers=None, +): + """Lightweight server stand-in for the per-user env-var endpoints. + + The handlers only read ``server_id``/``server_name``/``alias``/``env_vars``/ + ``static_headers`` via ``getattr``, so a SimpleNamespace is enough and keeps + the test decoupled from the full Prisma model. + """ + return SimpleNamespace( + server_id=server_id, + server_name=server_name, + alias=alias, + env_vars=env_vars, + static_headers=static_headers, + ) + + +# env_vars with two referenced per-user fields, one unreferenced per-user field +# (must NOT be blocking), and a global value. +_ENV_VARS_MIXED = [ + {"name": "DB_PROTOCOL", "value": "postgres", "scope": "global"}, + { + "name": "CORP_USERNAME", + "value": "", + "scope": "user", + "description": "Your username", + }, + {"name": "CORP_PASSWORD", "value": "", "scope": "user"}, + {"name": "UNUSED_USER_VAR", "value": "", "scope": "user"}, +] +_STATIC_HEADERS_MIXED = { + "Authorization": "${DB_PROTOCOL}://${CORP_USERNAME}:${CORP_PASSWORD}@host/db", +} + + +class TestComputeUserEnvVarStatus: + """Unit tests for the _compute_user_env_var_status helper.""" + + def test_only_referenced_per_user_vars_are_required(self): + server = _make_env_var_server( + env_vars=_ENV_VARS_MIXED, static_headers=_STATIC_HEADERS_MIXED + ) + status = mgmt_endpoints._compute_user_env_var_status( + server=server, stored_values={"CORP_USERNAME": "alice"} + ) + names = {spec.name for spec in status.required} + # UNUSED_USER_VAR is declared per-user but never referenced -> not blocking. + assert names == {"CORP_USERNAME", "CORP_PASSWORD"} + by_name = {spec.name: spec for spec in status.required} + assert by_name["CORP_USERNAME"].is_set is True + assert by_name["CORP_USERNAME"].description == "Your username" + assert by_name["CORP_PASSWORD"].is_set is False + # Stored credentials are write-only: the secret is never echoed back. + assert "alice" not in status.model_dump_json() + assert status.missing_count == 1 + assert status.server_id == "srv-1" + assert status.server_name == "DB Server" + assert status.alias == "db_server" + # required is non-empty -> a setup URL is provided. + assert status.setup_url and "srv-1" in status.setup_url + + def test_all_filled_has_zero_missing(self): + server = _make_env_var_server( + env_vars=_ENV_VARS_MIXED, static_headers=_STATIC_HEADERS_MIXED + ) + status = mgmt_endpoints._compute_user_env_var_status( + server=server, + stored_values={"CORP_USERNAME": "alice", "CORP_PASSWORD": "s3cret"}, + ) + assert status.missing_count == 0 + assert all(spec.is_set for spec in status.required) + + def test_static_headers_as_json_string_is_parsed(self): + server = _make_env_var_server( + env_vars=_ENV_VARS_MIXED, + static_headers='{"Authorization": "${CORP_USERNAME}"}', + ) + status = mgmt_endpoints._compute_user_env_var_status( + server=server, stored_values={} + ) + # Only CORP_USERNAME is referenced via the JSON-string headers. + assert {spec.name for spec in status.required} == {"CORP_USERNAME"} + assert status.missing_count == 1 + + def test_static_headers_invalid_json_string_yields_no_required(self): + server = _make_env_var_server( + env_vars=_ENV_VARS_MIXED, static_headers="not-json{" + ) + status = mgmt_endpoints._compute_user_env_var_status( + server=server, stored_values={} + ) + assert status.required == [] + assert status.missing_count == 0 + # No required fields -> no setup URL. + assert status.setup_url is None + + def test_no_per_user_vars_referenced_yields_no_required(self): + server = _make_env_var_server( + env_vars=[{"name": "DB_PROTOCOL", "value": "postgres", "scope": "global"}], + static_headers={"Authorization": "${DB_PROTOCOL}://host"}, + ) + status = mgmt_endpoints._compute_user_env_var_status( + server=server, stored_values={} + ) + assert status.required == [] + assert status.setup_url is None + + def test_dual_scope_var_with_global_fallback_is_not_required(self): + # SHARED_TOKEN is declared both global and user. The global value covers + # the reference (globals win in _resolve_static_headers_with_env_vars), + # so the tool-call path never raises a 412 for it. The status endpoint + # must agree and not report it as required/missing, otherwise it asks the + # user for a credential the request would never actually need. + server = _make_env_var_server( + env_vars=[ + {"name": "SHARED_TOKEN", "value": "global-secret", "scope": "global"}, + {"name": "SHARED_TOKEN", "value": "", "scope": "user"}, + ], + static_headers={"Authorization": "Bearer ${SHARED_TOKEN}"}, + ) + status = mgmt_endpoints._compute_user_env_var_status( + server=server, stored_values={} + ) + assert status.required == [] + assert status.missing_count == 0 + assert status.setup_url is None + + def test_dual_scope_var_with_empty_global_is_required(self): + # SHARED_TOKEN is declared both global (empty value) and user. An empty + # global is not a usable fallback, so _resolve_static_headers_with_env_vars + # still requires the user value and the tool-call path 412s without it. The + # status endpoint must agree and report it required, or it would tell the + # user no credential is needed for a var every call rejects. + server = _make_env_var_server( + env_vars=[ + {"name": "SHARED_TOKEN", "value": "", "scope": "global"}, + {"name": "SHARED_TOKEN", "value": "", "scope": "user"}, + ], + static_headers={"Authorization": "Bearer ${SHARED_TOKEN}"}, + ) + status = mgmt_endpoints._compute_user_env_var_status( + server=server, stored_values={} + ) + assert {spec.name for spec in status.required} == {"SHARED_TOKEN"} + assert status.missing_count == 1 + assert status.setup_url and "srv-1" in status.setup_url + + +class TestGetMCPUserEnvVars: + @pytest.mark.asyncio + async def test_returns_status_for_server(self): + server = _make_env_var_server( + env_vars=_ENV_VARS_MIXED, static_headers=_STATIC_HEADERS_MIXED + ) + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object( + mgmt_endpoints, "get_mcp_server", AsyncMock(return_value=server) + ), + patch.object( + mgmt_endpoints, + "get_user_env_vars", + AsyncMock(return_value={"CORP_USERNAME": "alice"}), + ), + ): + result = await mgmt_endpoints.get_mcp_user_env_vars( + server_id="srv-1", + user_api_key_dict=generate_mock_user_api_key_auth(user_id="alice"), + ) + assert result.server_id == "srv-1" + assert result.missing_count == 1 + assert {s.name for s in result.required} == {"CORP_USERNAME", "CORP_PASSWORD"} + # The single-server endpoint reports which credentials are set without + # ever echoing the decrypted secret back to the caller. + by_name = {s.name: s for s in result.required} + assert by_name["CORP_USERNAME"].is_set is True + assert by_name["CORP_PASSWORD"].is_set is False + assert "alice" not in result.model_dump_json() + + @pytest.mark.asyncio + async def test_missing_user_id_raises_400(self): + with patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ): + with pytest.raises(HTTPException) as exc: + await mgmt_endpoints.get_mcp_user_env_vars( + server_id="srv-1", + user_api_key_dict=generate_mock_user_api_key_auth(user_id=""), + ) + assert exc.value.status_code == 400 + + @pytest.mark.asyncio + async def test_unknown_server_raises_404(self): + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object( + mgmt_endpoints, "get_mcp_server", AsyncMock(return_value=None) + ), + ): + with pytest.raises(HTTPException) as exc: + await mgmt_endpoints.get_mcp_user_env_vars( + server_id="missing", + user_api_key_dict=generate_mock_user_api_key_auth(user_id="alice"), + ) + assert exc.value.status_code == 404 + + +class TestStoreMCPUserEnvVars: + @pytest.mark.asyncio + async def test_persists_only_allowed_non_empty_values(self): + server = _make_env_var_server( + env_vars=_ENV_VARS_MIXED, static_headers=_STATIC_HEADERS_MIXED + ) + merge_mock = AsyncMock(return_value={"CORP_USERNAME": "alice"}) + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object( + mgmt_endpoints, "get_mcp_server", AsyncMock(return_value=server) + ), + patch.object(mgmt_endpoints, "merge_user_env_vars", merge_mock), + ): + result = await mgmt_endpoints.store_mcp_user_env_vars( + server_id="srv-1", + payload=mgmt_endpoints.MCPUserEnvVarsRequest( + values={ + "CORP_USERNAME": "alice", + "CORP_PASSWORD": "", # empty -> dropped + "NOT_A_DECLARED_VAR": "x", # unknown -> dropped + } + ), + user_api_key_dict=generate_mock_user_api_key_auth(user_id="alice"), + ) + # Only the declared, non-empty value reaches the atomic merge, scoped to + # the admin-declared user vars. + merge_mock.assert_awaited_once() + _, _, _, updates, allowed_names = merge_mock.await_args.args + assert updates == {"CORP_USERNAME": "alice"} + assert set(allowed_names) == { + "CORP_USERNAME", + "CORP_PASSWORD", + "UNUSED_USER_VAR", + } + # CORP_PASSWORD remains unset in the returned status. + assert result.missing_count == 1 + + @pytest.mark.asyncio + async def test_forwards_only_submitted_updates_and_returns_merged_status(self): + """The endpoint forwards only the user's submitted (allowed, non-empty) + update to the atomic merge and reports status from the merged result, so + a one-field edit never sends the other stored values back through.""" + server = _make_env_var_server( + env_vars=_ENV_VARS_MIXED, static_headers=_STATIC_HEADERS_MIXED + ) + merge_mock = AsyncMock( + return_value={"CORP_USERNAME": "alice", "CORP_PASSWORD": "new"} + ) + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object( + mgmt_endpoints, "get_mcp_server", AsyncMock(return_value=server) + ), + patch.object(mgmt_endpoints, "merge_user_env_vars", merge_mock), + ): + result = await mgmt_endpoints.store_mcp_user_env_vars( + server_id="srv-1", + payload=mgmt_endpoints.MCPUserEnvVarsRequest( + values={"CORP_PASSWORD": "new"} + ), + user_api_key_dict=generate_mock_user_api_key_auth(user_id="alice"), + ) + merge_mock.assert_awaited_once() + _, _, _, updates, _ = merge_mock.await_args.args + assert updates == {"CORP_PASSWORD": "new"} + # Status reflects the merged set returned by the atomic merge. + assert result.missing_count == 0 + + @pytest.mark.asyncio + async def test_missing_user_id_raises_400(self): + with patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ): + with pytest.raises(HTTPException) as exc: + await mgmt_endpoints.store_mcp_user_env_vars( + server_id="srv-1", + payload=mgmt_endpoints.MCPUserEnvVarsRequest(values={}), + user_api_key_dict=generate_mock_user_api_key_auth(user_id=""), + ) + assert exc.value.status_code == 400 + + @pytest.mark.asyncio + async def test_unknown_server_raises_404(self): + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object( + mgmt_endpoints, "get_mcp_server", AsyncMock(return_value=None) + ), + ): + with pytest.raises(HTTPException) as exc: + await mgmt_endpoints.store_mcp_user_env_vars( + server_id="missing", + payload=mgmt_endpoints.MCPUserEnvVarsRequest(values={}), + user_api_key_dict=generate_mock_user_api_key_auth(user_id="alice"), + ) + assert exc.value.status_code == 404 + + +class TestClearMCPUserEnvVars: + @pytest.mark.asyncio + async def test_clears_and_returns_empty_status(self): + server = _make_env_var_server( + env_vars=_ENV_VARS_MIXED, static_headers=_STATIC_HEADERS_MIXED + ) + delete_mock = AsyncMock() + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object( + mgmt_endpoints, "get_mcp_server", AsyncMock(return_value=server) + ), + patch.object(mgmt_endpoints, "delete_user_env_vars", delete_mock), + ): + result = await mgmt_endpoints.clear_mcp_user_env_vars( + server_id="srv-1", + user_api_key_dict=generate_mock_user_api_key_auth(user_id="alice"), + ) + delete_mock.assert_awaited_once() + # Everything is now unset. + assert result.missing_count == 2 + assert all(not spec.is_set for spec in result.required) + + @pytest.mark.asyncio + async def test_delete_db_error_propagates(self): + server = _make_env_var_server( + env_vars=_ENV_VARS_MIXED, static_headers=_STATIC_HEADERS_MIXED + ) + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object( + mgmt_endpoints, "get_mcp_server", AsyncMock(return_value=server) + ), + patch.object( + mgmt_endpoints, + "delete_user_env_vars", + AsyncMock(side_effect=Exception("db down")), + ), + ): + # A real DB failure must surface, not be masked as a successful clear. + with pytest.raises(Exception, match="db down"): + await mgmt_endpoints.clear_mcp_user_env_vars( + server_id="srv-1", + user_api_key_dict=generate_mock_user_api_key_auth(user_id="alice"), + ) + + @pytest.mark.asyncio + async def test_missing_user_id_raises_400(self): + with patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ): + with pytest.raises(HTTPException) as exc: + await mgmt_endpoints.clear_mcp_user_env_vars( + server_id="srv-1", + user_api_key_dict=generate_mock_user_api_key_auth(user_id=""), + ) + assert exc.value.status_code == 400 + + @pytest.mark.asyncio + async def test_unknown_server_raises_404(self): + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object( + mgmt_endpoints, "get_mcp_server", AsyncMock(return_value=None) + ), + ): + with pytest.raises(HTTPException) as exc: + await mgmt_endpoints.clear_mcp_user_env_vars( + server_id="missing", + user_api_key_dict=generate_mock_user_api_key_auth(user_id="alice"), + ) + assert exc.value.status_code == 404 + + +class TestListMCPUserEnvVarStatus: + @pytest.mark.asyncio + async def test_no_user_id_returns_empty(self): + with patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ): + result = await mgmt_endpoints.list_mcp_user_env_var_status( + user_api_key_dict=generate_mock_user_api_key_auth(user_id="") + ) + assert result == [] + + @pytest.mark.asyncio + async def test_no_accessible_servers_returns_empty(self): + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object( + mgmt_endpoints, + "get_all_mcp_servers_for_user", + AsyncMock(return_value=[]), + ), + ): + result = await mgmt_endpoints.list_mcp_user_env_var_status( + user_api_key_dict=generate_mock_user_api_key_auth(user_id="alice") + ) + assert result == [] + + @pytest.mark.asyncio + async def test_only_servers_with_required_fields_are_returned(self): + server_with = _make_env_var_server( + server_id="srv-with", + env_vars=_ENV_VARS_MIXED, + static_headers=_STATIC_HEADERS_MIXED, + ) + # No per-user var is referenced -> contributes no status entry. + server_without = _make_env_var_server( + server_id="srv-without", + env_vars=[{"name": "DB_PROTOCOL", "value": "postgres", "scope": "global"}], + static_headers={"Authorization": "${DB_PROTOCOL}://host"}, + ) + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object( + mgmt_endpoints, + "get_all_mcp_servers_for_user", + AsyncMock(return_value=[server_with, server_without]), + ), + patch.object( + mgmt_endpoints, + "get_user_env_vars_bulk", + AsyncMock(return_value={"srv-with": {"CORP_USERNAME": "alice"}}), + ), + ): + result = await mgmt_endpoints.list_mcp_user_env_var_status( + user_api_key_dict=generate_mock_user_api_key_auth(user_id="alice") + ) + assert [s.server_id for s in result] == ["srv-with"] + assert result[0].missing_count == 1 + + @pytest.mark.asyncio + async def test_bulk_status_omits_stored_credential_values(self): + """The bulk feed only drives the "fields missing" badge, so it must not + echo stored credential values back; is_set still reflects presence.""" + server = _make_env_var_server( + server_id="srv-with", + env_vars=_ENV_VARS_MIXED, + static_headers=_STATIC_HEADERS_MIXED, + ) + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object( + mgmt_endpoints, + "get_all_mcp_servers_for_user", + AsyncMock(return_value=[server]), + ), + patch.object( + mgmt_endpoints, + "get_user_env_vars_bulk", + AsyncMock(return_value={"srv-with": {"CORP_USERNAME": "alice"}}), + ), + ): + result = await mgmt_endpoints.list_mcp_user_env_var_status( + user_api_key_dict=generate_mock_user_api_key_auth(user_id="alice") + ) + by_name = {s.name: s for s in result[0].required} + assert by_name["CORP_USERNAME"].is_set is True + assert by_name["CORP_PASSWORD"].is_set is False + assert "alice" not in result[0].model_dump_json() + + +class TestMCPUserEnvVarsAccessControl: + """Per-server env-var endpoints must enforce the same access gate as + fetch_mcp_server: a non-admin caller can only touch servers in their + allowed set.""" + + @pytest.mark.asyncio + async def test_get_forbidden_for_non_admin_without_access(self): + server = _make_env_var_server( + env_vars=_ENV_VARS_MIXED, static_headers=_STATIC_HEADERS_MIXED + ) + get_user_env_vars = AsyncMock(return_value={}) + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object( + mgmt_endpoints, "get_mcp_server", AsyncMock(return_value=server) + ), + patch.object( + mgmt_endpoints, + "get_all_mcp_servers_for_user", + AsyncMock(return_value=[_make_env_var_server(server_id="other")]), + ), + patch.object(mgmt_endpoints, "get_user_env_vars", get_user_env_vars), + ): + with pytest.raises(HTTPException) as exc: + await mgmt_endpoints.get_mcp_user_env_vars( + server_id="srv-1", + user_api_key_dict=generate_mock_user_api_key_auth( + user_id="alice", + user_role=LitellmUserRoles.INTERNAL_USER, + ), + ) + assert exc.value.status_code == 403 + get_user_env_vars.assert_not_awaited() + + @pytest.mark.asyncio + async def test_store_forbidden_for_non_admin_without_access(self): + server = _make_env_var_server( + env_vars=_ENV_VARS_MIXED, static_headers=_STATIC_HEADERS_MIXED + ) + merge_mock = AsyncMock() + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object( + mgmt_endpoints, "get_mcp_server", AsyncMock(return_value=server) + ), + patch.object( + mgmt_endpoints, + "get_all_mcp_servers_for_user", + AsyncMock(return_value=[]), + ), + patch.object(mgmt_endpoints, "merge_user_env_vars", merge_mock), + ): + with pytest.raises(HTTPException) as exc: + await mgmt_endpoints.store_mcp_user_env_vars( + server_id="srv-1", + payload=mgmt_endpoints.MCPUserEnvVarsRequest( + values={"CORP_USERNAME": "alice"} + ), + user_api_key_dict=generate_mock_user_api_key_auth( + user_id="alice", + user_role=LitellmUserRoles.INTERNAL_USER, + ), + ) + assert exc.value.status_code == 403 + merge_mock.assert_not_awaited() + + @pytest.mark.asyncio + async def test_clear_forbidden_for_non_admin_without_access(self): + server = _make_env_var_server( + env_vars=_ENV_VARS_MIXED, static_headers=_STATIC_HEADERS_MIXED + ) + delete_mock = AsyncMock() + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object( + mgmt_endpoints, "get_mcp_server", AsyncMock(return_value=server) + ), + patch.object( + mgmt_endpoints, + "get_all_mcp_servers_for_user", + AsyncMock(return_value=[]), + ), + patch.object(mgmt_endpoints, "delete_user_env_vars", delete_mock), + ): + with pytest.raises(HTTPException) as exc: + await mgmt_endpoints.clear_mcp_user_env_vars( + server_id="srv-1", + user_api_key_dict=generate_mock_user_api_key_auth( + user_id="alice", + user_role=LitellmUserRoles.INTERNAL_USER, + ), + ) + assert exc.value.status_code == 403 + delete_mock.assert_not_awaited() + + @pytest.mark.asyncio + async def test_get_allowed_for_non_admin_with_access(self): + server = _make_env_var_server( + server_id="srv-1", + env_vars=_ENV_VARS_MIXED, + static_headers=_STATIC_HEADERS_MIXED, + ) + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object( + mgmt_endpoints, "get_mcp_server", AsyncMock(return_value=server) + ), + patch.object( + mgmt_endpoints, + "get_all_mcp_servers_for_user", + AsyncMock(return_value=[server]), + ), + patch.object( + mgmt_endpoints, + "get_user_env_vars", + AsyncMock(return_value={"CORP_USERNAME": "alice"}), + ), + ): + result = await mgmt_endpoints.get_mcp_user_env_vars( + server_id="srv-1", + user_api_key_dict=generate_mock_user_api_key_auth( + user_id="alice", + user_role=LitellmUserRoles.INTERNAL_USER, + ), + ) + assert result.server_id == "srv-1" + assert result.missing_count == 1 + + @pytest.mark.asyncio + async def test_admin_bypasses_access_check(self): + """Proxy admins must not be filtered by get_all_mcp_servers_for_user.""" + server = _make_env_var_server( + server_id="srv-1", + env_vars=_ENV_VARS_MIXED, + static_headers=_STATIC_HEADERS_MIXED, + ) + access_list_mock = AsyncMock(return_value=[]) + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object( + mgmt_endpoints, "get_mcp_server", AsyncMock(return_value=server) + ), + patch.object( + mgmt_endpoints, "get_all_mcp_servers_for_user", access_list_mock + ), + patch.object( + mgmt_endpoints, "get_user_env_vars", AsyncMock(return_value={}) + ), + ): + result = await mgmt_endpoints.get_mcp_user_env_vars( + server_id="srv-1", + user_api_key_dict=generate_mock_user_api_key_auth( + user_id="admin", + user_role=LitellmUserRoles.PROXY_ADMIN, + ), + ) + assert result.server_id == "srv-1" + access_list_mock.assert_not_awaited() + + @pytest.mark.asyncio + async def test_non_admin_gets_403_not_404_for_inaccessible_server(self): + """Authorization must run before the existence check so a non-admin + cannot distinguish "server does not exist" (404) from "server exists but + you lack access" (403) and enumerate server IDs.""" + get_mcp_server_mock = AsyncMock(return_value=None) + with ( + patch.object( + mgmt_endpoints, "get_prisma_client_or_throw", return_value=MagicMock() + ), + patch.object(mgmt_endpoints, "get_mcp_server", get_mcp_server_mock), + patch.object( + mgmt_endpoints, + "get_all_mcp_servers_for_user", + AsyncMock(return_value=[]), + ), + ): + with pytest.raises(HTTPException) as exc: + await mgmt_endpoints.get_mcp_user_env_vars( + server_id="srv-1", + user_api_key_dict=generate_mock_user_api_key_auth( + user_id="alice", + user_role=LitellmUserRoles.INTERNAL_USER, + ), + ) + assert exc.value.status_code == 403 + get_mcp_server_mock.assert_not_awaited() + + def test_oauth2_flow_accepted_on_create_request(): """NewMCPServerRequest carries oauth2_flow through to the persisted dict.""" from litellm.proxy._experimental.mcp_server.db import _prepare_mcp_server_data @@ -3344,9 +4445,7 @@ def test_oauth2_flow_round_trips_on_update_and_response_models(): UpdateMCPServerRequest, ) - update = UpdateMCPServerRequest( - server_id="srv-1", oauth2_flow="client_credentials" - ) + update = UpdateMCPServerRequest(server_id="srv-1", oauth2_flow="client_credentials") assert update.oauth2_flow == "client_credentials" row = LiteLLM_MCPServerTable( @@ -3366,6 +4465,5 @@ def test_oauth2_flow_defaults_to_none_when_omitted(): assert UpdateMCPServerRequest(server_id="srv-1").oauth2_flow is None assert ( - LiteLLM_MCPServerTable(server_id="srv-1", transport="http").oauth2_flow - is None + LiteLLM_MCPServerTable(server_id="srv-1", transport="http").oauth2_flow is None ) diff --git a/tests/test_litellm/proxy/management_helpers/test_management_helpers_utils.py b/tests/test_litellm/proxy/management_helpers/test_management_helpers_utils.py index 459072cf9d..463aba6f74 100644 --- a/tests/test_litellm/proxy/management_helpers/test_management_helpers_utils.py +++ b/tests/test_litellm/proxy/management_helpers/test_management_helpers_utils.py @@ -19,6 +19,154 @@ from litellm.proxy._types import ( from litellm.proxy.management_helpers.utils import add_new_member +@pytest.mark.asyncio +async def test_management_otel_span_redacts_mcp_global_env_var_secrets(monkeypatch): + """A decrypted MCP global env var secret must never reach telemetry. + + MCP create/update endpoints return the server with decrypted + ``scope="global"`` env var values so the admin UI can pre-fill the edit + form. ``management_endpoint_wrapper`` serializes the response into an OTEL + span, and that span is readable by observability users, so the secret value + must be blanked there while names/scopes stay for usefulness. The endpoint's + own return value must keep the decrypted value for the admin. + """ + import datetime + + from litellm.proxy._types import ( + LiteLLM_MCPServerTable, + MCPEnvVar, + MCPEnvVarScope, + ) + from litellm.proxy.management_helpers import utils as mgmt_utils + + captured = {} + + class _FakeOtelLogger: + async def async_management_endpoint_success_hook( + self, logging_payload, parent_otel_span + ): + captured["response"] = logging_payload.response + + import litellm.proxy.proxy_server as proxy_server + + monkeypatch.setattr(proxy_server, "open_telemetry_logger", _FakeOtelLogger()) + monkeypatch.setattr(mgmt_utils, "is_otel_v2_enabled", lambda: False) + + secret = "s3cr3t-p@ss" + result = LiteLLM_MCPServerTable( + server_id="srv-1", + alias="echo", + url="http://localhost:8765/mcp", + transport="http", + env_vars=[ + MCPEnvVar(name="DB_PASSWORD", value=secret, scope=MCPEnvVarScope.global_), + MCPEnvVar( + name="CORP_USER", + value="", + scope=MCPEnvVarScope.user, + description="Your DB username", + ), + ], + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + ) + + await mgmt_utils._emit_management_endpoint_otel_span( + func=lambda: None, + kwargs={}, + parent_otel_span=object(), + start_time=datetime.datetime.now(), + end_time=datetime.datetime.now(), + result=result, + ) + + serialized = captured["response"]["env_vars"] + # The secret must not appear anywhere the span serializer would stringify. + assert secret not in str(captured["response"]) + assert all(entry["value"] == "" for entry in serialized) + # Names and scopes survive so the trace stays useful. + assert {entry["name"] for entry in serialized} == {"DB_PASSWORD", "CORP_USER"} + assert any(entry["scope"] == MCPEnvVarScope.global_ for entry in serialized) + # The endpoint's own return value is untouched: the admin still gets the + # decrypted value to pre-fill the edit form. + assert result.env_vars[0].value == secret + + +@pytest.mark.asyncio +async def test_management_otel_span_redacts_nested_submission_env_var_secrets( + monkeypatch, +): + """Decrypted global env var secrets nested under ``items`` must also be blanked. + + ``GET /v1/mcp/server/submissions`` returns ``MCPSubmissionsSummary`` whose + ``items[].env_vars`` carry decrypted ``scope="global"`` values for full admins. + ``management_endpoint_wrapper`` stringifies that nested ``items`` value into the + OTEL span, so redaction has to walk into ``items`` and not just the top level, + while the endpoint's own return value keeps the value for the admin UI. + """ + import datetime + + from litellm.proxy._types import ( + LiteLLM_MCPServerTable, + MCPEnvVar, + MCPEnvVarScope, + MCPSubmissionsSummary, + ) + from litellm.proxy.management_helpers import utils as mgmt_utils + + captured = {} + + class _FakeOtelLogger: + async def async_management_endpoint_success_hook( + self, logging_payload, parent_otel_span + ): + captured["response"] = logging_payload.response + + import litellm.proxy.proxy_server as proxy_server + + monkeypatch.setattr(proxy_server, "open_telemetry_logger", _FakeOtelLogger()) + monkeypatch.setattr(mgmt_utils, "is_otel_v2_enabled", lambda: False) + + secret = "s3cr3t-submission" + server = LiteLLM_MCPServerTable( + server_id="srv-sub", + alias="echo", + url="http://localhost:8765/mcp", + transport="http", + env_vars=[ + MCPEnvVar(name="DB_PASSWORD", value=secret, scope=MCPEnvVarScope.global_), + ], + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + ) + result = MCPSubmissionsSummary( + total=1, pending_review=1, active=0, rejected=0, items=[server] + ) + + await mgmt_utils._emit_management_endpoint_otel_span( + func=lambda: None, + kwargs={}, + parent_otel_span=object(), + start_time=datetime.datetime.now(), + end_time=datetime.datetime.now(), + result=result, + ) + + # The nested secret must not appear anywhere the span serializer stringifies. + assert secret not in str(captured["response"]) + + redacted_item = captured["response"]["items"][0] + redacted_env_vars = ( + redacted_item["env_vars"] + if isinstance(redacted_item, dict) + else redacted_item.env_vars + ) + assert [entry["value"] for entry in redacted_env_vars] == [""] + assert redacted_env_vars[0]["name"] == "DB_PASSWORD" + # The endpoint's own return value is untouched for the admin UI. + assert result.items[0].env_vars[0].value == secret + + @pytest.mark.asyncio async def test_add_new_member_clones_default_team_budget_id(): """ diff --git a/ui/litellm-dashboard/e2e_tests/tests/mcp/mcpServers.spec.ts b/ui/litellm-dashboard/e2e_tests/tests/mcp/mcpServers.spec.ts index 22ba85956d..7c4a7cb056 100644 --- a/ui/litellm-dashboard/e2e_tests/tests/mcp/mcpServers.spec.ts +++ b/ui/litellm-dashboard/e2e_tests/tests/mcp/mcpServers.spec.ts @@ -50,11 +50,11 @@ test.describe("MCP Servers", () => { // No teardown needed — the e2e runner spins up a fresh DB per invocation. - // Success toast and the new row in the table. Scope the row lookup to - // the MCP servers table so the form modal's `server_name` input — which + // Success toast and the new card in the server grid. Scope the lookup to + // the MCP servers grid so the form modal's `server_name` input — which // still holds the timestamped value during its close animation — can't // satisfy the assertion before the server actually lands in the list. await expect(page.getByText("MCP Server created successfully").first()).toBeVisible({ timeout: 15_000 }); - await expect(page.locator("table tbody").getByText(uniqueName).first()).toBeVisible({ timeout: 10_000 }); + await expect(page.getByTestId("mcp-servers-grid").getByText(uniqueName).first()).toBeVisible({ timeout: 10_000 }); }); }); diff --git a/ui/litellm-dashboard/src/components/mcp_tools/EnvVarsSection.tsx b/ui/litellm-dashboard/src/components/mcp_tools/EnvVarsSection.tsx new file mode 100644 index 0000000000..4a694d7158 --- /dev/null +++ b/ui/litellm-dashboard/src/components/mcp_tools/EnvVarsSection.tsx @@ -0,0 +1,142 @@ +import React from "react"; +import { Form, Input, Select, Button, Tooltip, Typography } from "antd"; +import { InfoCircleOutlined, MinusCircleOutlined, PlusOutlined } from "@ant-design/icons"; + +const { Text } = Typography; + +const SCOPE_OPTIONS = [ + { value: "global", label: "Instance" }, + { value: "user", label: "Per-user" }, +]; + +/** + * Form section for admin-configured MCP environment variables. + * + * Each row has: name | value | scope. Variables can be interpolated into + * Static Headers via ${NAME}. ``scope=global`` (shown as "Instance") values + * are used as-is. ``scope=user`` (shown as "Per-user") values are filled in + * by each user via the MCP Gateway dashboard. + * + * The parent form reads the ``env_vars`` field from the form values. + */ +const EnvVarsSection: React.FC = () => { + return ( +
+
+ + Variables + + + Define variables you can interpolate in Static Headers or Authentication using{" "} + {"${VAR_NAME}"}.
+ Instance: admin-defined value used for every user. +
+ Per-user: each user supplies their own value (e.g. personal credentials) via the MCP Gateway + dashboard. + + } + > + +
+
+ + Reference these in Static Headers or Authentication as {"${VAR_NAME}"}. For example:{" "} + + {"${DB_PROTOCOL}://${CORP_USERNAME}:${CORP_PASSWORD}@${DB_HOSTNAME}"} + + + + + {(fields, { add, remove }) => ( +
+ {fields.length > 0 && ( +
+
Variable Name
+
Value / Description
+
Scope
+
+
+ )} + {fields.map(({ key, name, ...restField }) => ( +
+ + + +
+ +
+ + + + + Hint + + + } + placeholder="e.g. Your DB username" + styles={{ input: { color: "#9ca3af" } }} + /> + + ); + } + return ( + + + + ); +}; + +export default EnvVarsSection; diff --git a/ui/litellm-dashboard/src/components/mcp_tools/MCPPermissionManagement.tsx b/ui/litellm-dashboard/src/components/mcp_tools/MCPPermissionManagement.tsx index 0bd3921871..27cbdf2ea3 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/MCPPermissionManagement.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/MCPPermissionManagement.tsx @@ -53,6 +53,17 @@ const MCPPermissionManagement: React.FC = ({ })); form.setFieldValue("static_headers", staticHeaders); } + if (Array.isArray(mcpServer.env_vars) && mcpServer.env_vars.length > 0) { + form.setFieldValue( + "env_vars", + mcpServer.env_vars.map((entry) => ({ + name: entry.name, + value: entry.value ?? "", + scope: entry.scope ?? "global", + description: entry.description ?? "", + })), + ); + } if (typeof mcpServer.allow_all_keys === "boolean") { form.setFieldValue("allow_all_keys", mcpServer.allow_all_keys); } diff --git a/ui/litellm-dashboard/src/components/mcp_tools/MCPServerCard.tsx b/ui/litellm-dashboard/src/components/mcp_tools/MCPServerCard.tsx new file mode 100644 index 0000000000..064a948132 --- /dev/null +++ b/ui/litellm-dashboard/src/components/mcp_tools/MCPServerCard.tsx @@ -0,0 +1,383 @@ +import { useState, type FC, type KeyboardEvent, type MouseEvent } from "react"; +import { Dropdown, Tooltip, Typography, Tag } from "antd"; +import type { MenuProps } from "antd"; +import { + CheckOutlined, + DeleteOutlined, + ExclamationCircleFilled, + MoreOutlined, + ThunderboltOutlined, +} from "@ant-design/icons"; +import type { MCPServer } from "./types"; +import { getMaskedAndFullUrl } from "./utils"; + +const { Text } = Typography; + +interface MCPServerCardProps { + server: MCPServer; + // Per-user env-var fields this user still needs to fill in for this server. + // Computed by the parent from the bulk /user-env-vars/status response, so + // the card never issues a per-row request (no N+1). + missingUserFields?: string[]; + isLoadingHealth?: boolean; + isRechecking?: boolean; + onClick: () => void; + onRecheckHealth?: () => void; + onByokConnect?: () => void; + onOpenFillFields?: () => void; + onDelete?: () => void; +} + +const HEALTH_TONE: Record = { + healthy: { dot: "bg-green-500" }, + unhealthy: { dot: "bg-red-500" }, + unknown: { dot: "bg-gray-300" }, +}; + +// Stop card-level click handler from firing when an interactive child is used. +const stop = (e: MouseEvent | KeyboardEvent) => e.stopPropagation(); + +const MCPServerCard: FC = ({ + server, + missingUserFields, + isLoadingHealth, + isRechecking, + onClick, + onRecheckHealth, + onByokConnect, + onOpenFillFields, + onDelete, +}) => { + const alias = server.alias || server.server_name || ""; + const name = server.server_name || alias || server.server_id; + // Logo is sourced exclusively from the admin-set `mcp_info.logo_url`. + const candidateLogo = server.mcp_info?.logo_url ?? undefined; + const [failedLogoUrl, setFailedLogoUrl] = useState(null); + const logoUrl = candidateLogo && failedLogoUrl !== candidateLogo ? candidateLogo : undefined; + const transport = server.transport || "http"; + const displayTransport = server.spec_path && transport !== "stdio" ? "openapi" : transport; + const authType = server.auth_type || "none"; + const status = server.status || "unknown"; + const healthTone = HEALTH_TONE[status] ?? HEALTH_TONE.unknown; + const isPublic = server.available_on_public_internet; + const accessGroups = (server.mcp_access_groups ?? []).filter((g): g is string => typeof g === "string"); + + const missing = missingUserFields ?? []; + const needsAttention = missing.length > 0; + + const cardClass = needsAttention + ? "border-2 border-red-300 bg-red-50/40 hover:border-red-400 hover:shadow-md" + : "border border-gray-200 bg-white hover:border-gray-300 hover:shadow-md"; + + const url = server.url || ""; + const { maskedUrl } = url ? getMaskedAndFullUrl(url) : { maskedUrl: "" }; + + // Transport-adapted identifier shown under the title. Every transport has + // something useful here, which keeps the tag row vertically aligned across + // cards in the grid (stdio cards no longer "snap up" because they lack a URL). + let subtitle = ""; + let subtitleTooltip = ""; + if (transport === "stdio") { + const parts = [server.command, ...(server.args ?? [])].filter( + (p): p is string => typeof p === "string" && p.length > 0, + ); + subtitle = parts.join(" "); + subtitleTooltip = subtitle; + } else if (server.spec_path) { + subtitle = server.spec_path; + subtitleTooltip = server.spec_path; + } else if (url) { + subtitle = maskedUrl; + subtitleTooltip = url; + } + + const handleKeyDown = (e: KeyboardEvent) => { + if (e.key === "Enter" || e.key === " ") { + e.preventDefault(); + onClick(); + } + }; + + const menuItems: MenuProps["items"] = []; + if (onRecheckHealth) { + menuItems.push({ + key: "test-connection", + label: "Test Connection", + icon: , + disabled: isRechecking, + onClick: ({ domEvent }) => { + domEvent.stopPropagation(); + onRecheckHealth(); + }, + }); + } + if (onDelete) { + if (menuItems.length > 0) { + menuItems.push({ key: "divider", type: "divider" }); + } + menuItems.push({ + key: "delete", + label: "Delete", + icon: , + danger: true, + onClick: ({ domEvent }) => { + domEvent.stopPropagation(); + onDelete(); + }, + }); + } + + // Card uses role="button" + nested + + )} +
+ + {subtitle ? ( + + + {subtitle} + + + ) : ( + // Defensive placeholder: keep the row even when no identifier is + // available so the tag row stays vertically aligned across the grid. +
+ )} + +
+ + {displayTransport.toUpperCase()} + {authType} + + + + {isPublic ? "Public" : "Internal"} + + + {accessGroups.slice(0, 2).map((g) => ( + + {g} + + ))} + {accessGroups.length > 2 && ( + + +{accessGroups.length - 2} + + )} +
+ + {(server.is_byok || needsAttention) && ( +
+ {server.is_byok && } + {needsAttention && ( +
+ +
Missing user fields:
+
    + {missing.map((m) => ( +
  • • {m}
  • + ))} +
+
+ } + > + + + {missing.length} user field + {missing.length === 1 ? "" : "s"} missing + + + {onOpenFillFields && ( + + )} +
+ )} +
+ )} +
+ ); +}; + +interface HealthChipProps { + status: string; + isLoadingHealth?: boolean; + isRechecking?: boolean; + onRecheck?: () => void; + lastCheck?: string | null; + error?: string | null; + dotClass: string; +} + +const HealthChip: FC = ({ + status, + isLoadingHealth, + isRechecking, + onRecheck, + lastCheck, + error, + dotClass, +}) => { + if (isLoadingHealth || isRechecking) { + return ( + + + + Checking + + + ); + } + const tooltip = ( +
+
Health: {status}
+ {lastCheck &&
Last check: {new Date(lastCheck).toLocaleString()}
} + {error && ( +
+
Error
+
{error}
+
+ )} + {!lastCheck && !error &&
No health data
} + {onRecheck &&
Click to recheck
} +
+ ); + return ( + + { + e.stopPropagation(); + onRecheck(); + } + : undefined + } + > + + + {status.charAt(0).toUpperCase() + status.slice(1)} + + + + ); +}; + +interface ByokRowProps { + connected: boolean; + onConnect?: () => void; +} + +const ByokRow: FC = ({ connected, onConnect }) => { + if (connected) { + return ( +
+ BYOK credential +
+ + Connected + + {onConnect && ( + + )} +
+
+ ); + } + return ( +
+ BYOK credential + {onConnect ? ( + + ) : ( + + )} +
+ ); +}; + +export default MCPServerCard; diff --git a/ui/litellm-dashboard/src/components/mcp_tools/UserEnvVarsModal.tsx b/ui/litellm-dashboard/src/components/mcp_tools/UserEnvVarsModal.tsx new file mode 100644 index 0000000000..08a285cd56 --- /dev/null +++ b/ui/litellm-dashboard/src/components/mcp_tools/UserEnvVarsModal.tsx @@ -0,0 +1,141 @@ +import React from "react"; +import { Modal, Form, Input, Button, Alert, Spin, Tag, Typography } from "antd"; +import { useMutation, useQuery } from "@tanstack/react-query"; +import { MCPServer, MCPUserEnvVarsStatus } from "./types"; +import { getMCPUserEnvVars, storeMCPUserEnvVars } from "../networking"; +import NotificationsManager from "../molecules/notifications_manager"; + +const { Text, Title } = Typography; + +interface UserEnvVarsModalProps { + server: MCPServer | null; + open: boolean; + accessToken: string | null; + onClose: () => void; + onSaved?: (status: MCPUserEnvVarsStatus) => void; +} + +/** + * User-facing modal for filling in per-user MCP environment variables. + * + * Backed by GET / POST ``/v1/mcp/server/{id}/user-env-vars``. Each field + * the admin marked as ``scope=user`` shows up with the admin-supplied + * description as the placeholder. + */ +const UserEnvVarsModal: React.FC = ({ server, open, accessToken, onClose, onSaved }) => { + const [form] = Form.useForm(); + + const { + data: status, + isLoading, + isError, + } = useQuery({ + queryKey: ["mcpUserEnvVars", server?.server_id], + queryFn: () => getMCPUserEnvVars(accessToken!, server!.server_id), + enabled: open && !!server && !!accessToken, + }); + + const saveMutation = useMutation({ + mutationFn: (values: Record) => storeMCPUserEnvVars(accessToken!, server!.server_id, values), + onSuccess: (saved) => { + NotificationsManager.success("Credentials saved"); + onSaved?.(saved); + onClose(); + }, + onError: (err) => { + NotificationsManager.fromBackend(`Failed to save env vars: ${err instanceof Error ? err.message : String(err)}`); + }, + }); + + const handleSave = (values: Record) => { + if (!server || !accessToken) return; + const trimmed: Record = {}; + for (const [k, v] of Object.entries(values)) { + trimmed[k] = (v ?? "").trim(); + } + saveMutation.mutate(trimmed); + }; + + const displayName = server?.server_name || server?.alias || server?.server_id || "MCP Server"; + const required = status?.required ?? []; + const isSaving = saveMutation.isPending; + + return ( + { + if (opened) form.resetFields(); + }} + title={ +
+
+ + Set your credentials + + Per-user +
+ + {displayName} + +
+ } + > +
+ {isLoading ? ( +
+ +
+ ) : isError ? ( + + ) : required.length === 0 ? ( + + ) : ( + <> + + These values are private to you. Your admin configured this MCP server to require these per-user + credentials. Saved values are never shown back; leave an already-set field blank to keep it, or enter a + value to set or change it. + +
+ {required.map((spec) => ( + + {spec.name} + {spec.is_set && Set} + + } + extra={spec.description || undefined} + rules={spec.is_set ? undefined : [{ required: true, message: `${spec.name} is required` }]} + > + + + ))} +
+ + +
+
+ + )} +
+
+ ); +}; + +export default UserEnvVarsModal; diff --git a/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.tsx b/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.tsx index 7af8e9bf00..018fd29290 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.tsx @@ -13,8 +13,9 @@ import StdioConfiguration from "./StdioConfiguration"; import MCPPermissionManagement from "./MCPPermissionManagement"; import OpenAPIFormSection, { OpenAPIKeyTool } from "./OpenAPIFormSection"; import MCPLogoSelector from "./MCPLogoSelector"; +import EnvVarsSection from "./EnvVarsSection"; import { isAdminRole } from "@/utils/roles"; -import { validateMCPServerUrl, validateMCPServerName } from "./utils"; +import { validateMCPServerUrl, validateMCPServerName, normalizeEnvVars } from "./utils"; import NotificationsManager from "../molecules/notifications_manager"; import { useMcpOAuthFlow } from "@/hooks/useMcpOAuthFlow"; import { useTestMCPConnection } from "@/hooks/useTestMCPConnection"; @@ -43,7 +44,7 @@ const reduceStaticHeaders = (list: unknown): Record => { if (!Array.isArray(list)) return {}; return list.reduce((acc: Record, entry: Record) => { const header = entry?.header?.trim(); - if (header) acc[header] = entry?.value ?? ""; + if (header) acc[header] = (entry?.value ?? "").trim(); return acc; }, {}); }; @@ -289,6 +290,7 @@ const CreateMCPServer: React.FC = ({ try { const { static_headers: staticHeadersList, + env_vars: envVarsList, stdio_config: rawStdioConfig, credentials: credentialValues, allow_all_keys: allowAllKeysRaw, @@ -303,6 +305,7 @@ const CreateMCPServer: React.FC = ({ const accessGroups = restValues.mcp_access_groups; const staticHeaders = reduceStaticHeaders(staticHeadersList); + const envVars = normalizeEnvVars(envVarsList); const credentialsPayload = credentialValues && typeof credentialValues === "object" @@ -403,10 +406,10 @@ const CreateMCPServer: React.FC = ({ delegate_auth_to_upstream: Boolean(delegateAuthToUpstreamRaw), oauth_passthrough: Boolean(oauthPassthroughRaw), static_headers: staticHeaders, + env_vars: envVars, ...(tokenValidation !== null && { token_validation: tokenValidation }), }; - payload.static_headers = staticHeaders; const includeCredentials = restValues.auth_type && AUTH_TYPES_REQUIRING_CREDENTIALS.includes(restValues.auth_type); @@ -1026,6 +1029,11 @@ const CreateMCPServer: React.FC = ({
+ {/* Environment Variables Section */} +
+ +
+ {/* Permission Management / Access Control Section */}
= ({ if (!header) { return acc; } - acc[header] = entry?.value ?? ""; + acc[header] = (entry?.value ?? "").trim(); return acc; }, {}) : ({} as Record); @@ -169,6 +170,18 @@ const MCPServerEdit: React.FC = ({ })); }, [mcpServer.static_headers]); + const initialEnvVars = React.useMemo(() => { + if (!Array.isArray(mcpServer.env_vars)) { + return []; + } + return mcpServer.env_vars.map((entry) => ({ + name: entry.name, + value: entry.value ?? "", + scope: entry.scope === "user" ? "user" : "global", + description: entry.description ?? "", + })); + }, [mcpServer.env_vars]); + const initialEnvJson = React.useMemo(() => { const env = mcpServer.env ?? undefined; if (!env || Object.keys(env).length === 0) { @@ -194,13 +207,14 @@ const MCPServerEdit: React.FC = ({ ...mcpServer, transport: effectiveTransport, static_headers: initialStaticHeaders, + env_vars: initialEnvVars, extra_headers: mcpServer.extra_headers || [], oauth_flow_type: mcpServer.token_url ? OAUTH_FLOW.M2M : OAUTH_FLOW.INTERACTIVE, token_validation_json: mcpServer.token_validation ? JSON.stringify(mcpServer.token_validation, null, 2) : undefined, }), - [mcpServer, effectiveTransport, initialStaticHeaders, initialEnvJson], + [mcpServer, effectiveTransport, initialStaticHeaders, initialEnvVars, initialEnvJson], ); // Initialize cost config from existing server data @@ -388,6 +402,7 @@ const MCPServerEdit: React.FC = ({ // Ensure access groups is always a string array const { static_headers: staticHeadersList, + env_vars: envVarsList, credentials: credentialValues, stdio_config: rawStdioConfig, env_json: rawEnvJson, @@ -411,11 +426,13 @@ const MCPServerEdit: React.FC = ({ if (!header) { return acc; } - acc[header] = entry?.value ?? ""; + acc[header] = (entry?.value ?? "").trim(); return acc; }, {}) : ({} as Record); + const envVars = normalizeEnvVars(envVarsList); + const credentialsPayload = credentialValues && typeof credentialValues === "object" ? Object.entries(credentialValues).reduce((acc: Record, [key, value]) => { @@ -571,6 +588,7 @@ const MCPServerEdit: React.FC = ({ tool_name_to_description: Object.keys(toolNameToDescription).length > 0 ? toolNameToDescription : null, disallowed_tools: restValues.disallowed_tools || [], static_headers: staticHeaders, + env_vars: envVars, allow_all_keys: Boolean(allowAllKeysRaw ?? mcpServer.allow_all_keys), available_on_public_internet: Boolean(availableOnPublicInternetRaw ?? mcpServer.available_on_public_internet), // ``delegate_auth_to_upstream`` is only honored server-side for @@ -1108,6 +1126,11 @@ const MCPServerEdit: React.FC = ({ )} + {/* Environment Variables Section */} +
+ +
+ {/* Permission Management / Access Control Section */}
({ getGeneralSettingsCall: vi.fn().mockResolvedValue([]), updateConfigFieldSetting: vi.fn().mockResolvedValue(undefined), deleteConfigFieldSetting: vi.fn().mockResolvedValue(undefined), + listMCPUserEnvVarStatus: vi.fn().mockResolvedValue([]), })); // Mock NotificationsManager diff --git a/ui/litellm-dashboard/src/components/mcp_tools/mcp_servers.tsx b/ui/litellm-dashboard/src/components/mcp_tools/mcp_servers.tsx index 3ba9b2470f..4b383f701c 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/mcp_servers.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/mcp_servers.tsx @@ -1,26 +1,72 @@ import { isAdminRole } from "@/utils/roles"; -import { QuestionCircleOutlined } from "@ant-design/icons"; +import { QuestionCircleOutlined, SearchOutlined } from "@ant-design/icons"; import { Button, Tab, TabGroup, TabList, TabPanel, TabPanels, Text, Title } from "@tremor/react"; import NewBadge from "../common_components/NewBadge"; -import { Descriptions, Modal, Select, Tooltip, Typography } from "antd"; +import { Descriptions, Empty, Input, Modal, Select, Spin, Tooltip, Typography } from "antd"; import React, { useEffect, useState, useMemo, useCallback } from "react"; +import { useQuery } from "@tanstack/react-query"; import { useMCPServers } from "../../app/(dashboard)/hooks/mcpServers/useMCPServers"; import { useMCPServerHealth } from "../../app/(dashboard)/hooks/mcpServers/useMCPServerHealth"; import NotificationsManager from "../molecules/notifications_manager"; import { deleteMCPServer } from "../networking"; import { MCPSubmissionsTab } from "./MCPSubmissionsTab"; import { MCPToolsetsTab } from "./MCPToolsetsTab"; -import { DataTable } from "../view_logs/table"; import CreateMCPServer from "./create_mcp_server"; import MCPConnect from "./mcp_connect"; -import { mcpServerColumns } from "./mcp_server_columns"; +import MCPServerCard from "./MCPServerCard"; import { MCPServerView } from "./mcp_server_view"; -import { DiscoverableMCPServer, MCPServer, MCPServerProps, Team } from "./types"; +import type { DiscoverableMCPServer, MCPServer, MCPServerProps, MCPUserEnvVarsStatus, Team } from "./types"; import MCPSemanticFilterSettings from "../Settings/AdminSettings/MCPSemanticFilterSettings/MCPSemanticFilterSettings"; import MCPNetworkSettings from "./MCPNetworkSettings"; import MCPDiscovery from "./mcp_discovery"; import { ByokCredentialModal } from "./ByokCredentialModal"; import { getSecureItem } from "@/utils/secureStorage"; +import UserEnvVarsModal from "./UserEnvVarsModal"; +import { listMCPUserEnvVarStatus } from "../networking"; + +type SortKey = "created_desc" | "updated_desc" | "name_asc" | "health"; + +const SORT_OPTIONS: { value: SortKey; label: string }[] = [ + { value: "created_desc", label: "Recently created" }, + { value: "updated_desc", label: "Recently updated" }, + { value: "name_asc", label: "Name (A→Z)" }, + { value: "health", label: "Health (unhealthy first)" }, +]; + +const HEALTH_RANK: Record = { + unhealthy: 0, + unknown: 1, + healthy: 2, +}; + +const compareServers = (a: MCPServer, b: MCPServer, sort: SortKey): number => { + switch (sort) { + case "name_asc": { + const nameA = (a.server_name || a.alias || a.server_id).toLowerCase(); + const nameB = (b.server_name || b.alias || b.server_id).toLowerCase(); + return nameA.localeCompare(nameB); + } + case "updated_desc": { + const ta = a.updated_at ? new Date(a.updated_at).getTime() : 0; + const tb = b.updated_at ? new Date(b.updated_at).getTime() : 0; + return tb - ta; + } + case "health": { + const ra = HEALTH_RANK[a.status ?? "unknown"] ?? 1; + const rb = HEALTH_RANK[b.status ?? "unknown"] ?? 1; + if (ra !== rb) return ra - rb; + const ta = a.created_at ? new Date(a.created_at).getTime() : 0; + const tb = b.created_at ? new Date(b.created_at).getTime() : 0; + return tb - ta; + } + case "created_desc": + default: { + const ta = a.created_at ? new Date(a.created_at).getTime() : 0; + const tb = b.created_at ? new Date(b.created_at).getTime() : 0; + return tb - ta; + } + } +}; const { Text: AntdText, Title: AntdTitle } = Typography; const EDIT_OAUTH_UI_STATE_KEY = "litellm-mcp-oauth-edit-state"; @@ -67,8 +113,52 @@ const MCPServers: React.FC = ({ accessToken, userRole, userID }) const [prefillData, setPrefillData] = useState(null); const [isDeletingServer, setIsDeletingServer] = useState(false); const [byokModalServer, setByokModalServer] = useState(null); + // Per-user env-var fill modal target + deep-link source captured once from the URL. + const [envVarsModalServer, setEnvVarsModalServer] = useState(null); + const [deepLinkServerId, setDeepLinkServerId] = useState(() => + typeof window === "undefined" ? null : new URLSearchParams(window.location.search).get("fill_env_vars"), + ); + const [searchQuery, setSearchQuery] = useState(""); + const [sortKey, setSortKey] = useState("created_desc"); const isInternalUser = userRole === "Internal User"; + // Single bulk fetch of this user's per-server env-var status. Drives the + // red "N user fields missing" footer on each card with no per-row request. + const { data: envVarStatuses, refetch: refetchEnvVarStatus } = useQuery({ + queryKey: ["mcpUserEnvVarStatus"], + queryFn: () => listMCPUserEnvVarStatus(accessToken!), + enabled: !!accessToken, + }); + + // Per-server list of per-user fields this user still needs to fill in. + const missingFieldsByServer = useMemo(() => { + const map: Record = {}; + for (const status of envVarStatuses ?? []) { + map[status.server_id] = (status.required ?? []).filter((spec) => !spec.is_set).map((spec) => spec.name); + } + return map; + }, [envVarStatuses]); + + // Deep-link via ?fill_env_vars= — the link users follow from the + // friendly error the proxy returns when a per-user var is missing. The id is + // captured into state above and resolved to a server below; here we only strip + // the param so a refresh doesn't reopen the modal. + useEffect(() => { + if (!deepLinkServerId || typeof window === "undefined") return; + const params = new URLSearchParams(window.location.search); + if (!params.has("fill_env_vars")) return; + params.delete("fill_env_vars"); + const newSearch = params.toString(); + const newUrl = window.location.pathname + (newSearch ? `?${newSearch}` : "") + window.location.hash; + window.history.replaceState({}, "", newUrl); + }, [deepLinkServerId]); + + const deepLinkServer = useMemo( + () => (deepLinkServerId ? serversWithHealth.find((s) => s.server_id === deepLinkServerId) ?? null : null), + [deepLinkServerId, serversWithHealth], + ); + const activeEnvVarsServer = envVarsModalServer ?? deepLinkServer; + useEffect(() => { if (typeof window === "undefined") { return; @@ -164,26 +254,20 @@ const MCPServers: React.FC = ({ accessToken, userRole, userID }) filterServers(selectedTeam, selectedMcpAccessGroup); }, [serversWithHealth, selectedTeam, selectedMcpAccessGroup, filterServers]); - const columns = React.useMemo( - () => - mcpServerColumns( - userRole ?? "", - (serverId: string) => { - setSelectedServerId(serverId); - setEditServer(false); - }, - (serverId: string) => { - setSelectedServerId(serverId); - setEditServer(true); - }, - handleDelete, - isLoadingHealth, - (server: MCPServer) => setByokModalServer(server), - recheckServerHealth, - recheckingServerIds, - ), - [userRole, isLoadingHealth, recheckServerHealth, recheckingServerIds], - ); + // Search + sort layer applied on top of the team/access-group filters. + const displayedServers = useMemo(() => { + const q = searchQuery.trim().toLowerCase(); + const matches = q + ? filteredServers.filter((s) => { + const name = (s.server_name || "").toLowerCase(); + const alias = (s.alias || "").toLowerCase(); + const url = (s.url || "").toLowerCase(); + const id = s.server_id.toLowerCase(); + return name.includes(q) || alias.includes(q) || url.includes(q) || id.includes(q); + }) + : filteredServers; + return [...matches].sort((a, b) => compareServers(a, b, sortKey)); + }, [filteredServers, searchQuery, sortKey]); function handleDelete(server_id: string) { setServerToDelete(server_id); @@ -198,6 +282,14 @@ const MCPServers: React.FC = ({ accessToken, userRole, userID }) setIsDeletingServer(true); await deleteMCPServer(accessToken, serverIdToDelete); NotificationsManager.success("Deleted MCP Server successfully"); + // If the user is currently viewing the detail page of the server they + // just deleted, return them to the All Servers list. Otherwise the + // detail view would stay mounted, fall back to an empty stub server, + // and show a phantom "Unnamed Server" page. + if (selectedServerId === serverIdToDelete) { + setEditServer(false); + setSelectedServerId(null); + } refetch(); } catch (error) { console.error("Error deleting the mcp server:", error); @@ -442,17 +534,75 @@ const MCPServers: React.FC = ({ accessToken, userRole, userID })
-
-
} - getRowCanExpand={() => false} - isLoading={isLoadingServers} - noDataMessage="No MCP servers configured. Click '+ Add New MCP Server' to get started." - loadingMessage="Loading MCP servers..." - enableSorting={true} +
+ } + placeholder="Search by name, alias, URL, or ID" + value={searchQuery} + onChange={(e) => setSearchQuery(e.target.value)} + style={{ maxWidth: 320 }} /> +
+ Sort + +
+
+ {displayedServers.length} of {filteredServers.length} servers +
+
+
+ {isLoadingServers ? ( +
+ +
+ ) : displayedServers.length === 0 ? ( +
+ +
+ ) : ( +
+ {displayedServers.map((server) => ( + { + setSelectedServerId(server.server_id); + setEditServer(true); + }} + onRecheckHealth={ + recheckServerHealth ? () => recheckServerHealth(server.server_id) : undefined + } + onByokConnect={server.is_byok ? () => setByokModalServer(server) : undefined} + onOpenFillFields={() => setEnvVarsModalServer(server)} + onDelete={isAdminRole(userRole) ? () => handleDelete(server.server_id) : undefined} + /> + ))} +
+ )}
)} @@ -489,6 +639,22 @@ const MCPServers: React.FC = ({ accessToken, userRole, userID }) accessToken={accessToken || ""} /> )} + + {/* Per-user env-var fill modal — backed by /v1/mcp/server/{id}/user-env-vars */} + { + setEnvVarsModalServer(null); + setDeepLinkServerId(null); + }} + onSaved={() => { + // Refresh the bulk status so the red "N user fields missing" footer + // on each card clears once the user has filled in their values. + refetchEnvVarStatus(); + }} + /> ); }; diff --git a/ui/litellm-dashboard/src/components/mcp_tools/types.tsx b/ui/litellm-dashboard/src/components/mcp_tools/types.tsx index d4a3e6ff18..cd3ffcab5e 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/types.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/types.tsx @@ -263,6 +263,41 @@ export interface MCPServer { /** Per-user OAuth token storage settings (interactive OAuth only) */ token_validation?: Record | null; token_storage_ttl_seconds?: number | null; + + /** + * Admin-configured env vars interpolated into static_headers via ${NAME}. + * Stored as a list so the UI can preserve admin-entered ordering. + */ + env_vars?: MCPEnvVar[] | null; +} + +/** One environment variable entry on an MCP server. */ +export type MCPEnvVarScope = "global" | "user"; + +export interface MCPEnvVar { + name: string; + /** For scope="global": the value used in interpolation. + * For scope="user": optional placeholder/description shown to users. */ + value: string; + scope: MCPEnvVarScope; + description?: string | null; +} + +/** One required per-user env var slot returned by the user-env-vars endpoint. */ +export interface MCPUserEnvVarSpec { + name: string; + description?: string | null; + is_set: boolean; +} + +/** Per-server per-user env var status returned by the API. */ +export interface MCPUserEnvVarsStatus { + server_id: string; + server_name?: string | null; + alias?: string | null; + required: MCPUserEnvVarSpec[]; + missing_count: number; + setup_url?: string | null; } export interface MCPServerProps { diff --git a/ui/litellm-dashboard/src/components/mcp_tools/utils.tsx b/ui/litellm-dashboard/src/components/mcp_tools/utils.tsx index 44a0640561..6d9479a13c 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/utils.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/utils.tsx @@ -1,3 +1,5 @@ +import { MCPEnvVar, MCPEnvVarScope } from "./types"; + export const extractMCPToken = (url: string): { token: string | null; baseUrl: string } => { try { const mcpIndex = url.indexOf("/mcp/"); @@ -51,3 +53,27 @@ export const validateMCPServerName = (value: string) => { ? Promise.reject("Cannot contain '-' (hyphen) or spaces. Please use '_' (underscore) instead.") : Promise.resolve(); }; + +// Normalize the env_vars form list into the payload shape the backend expects. +// Drops empty rows, invalid identifiers, and duplicate names; user-scoped entries never carry a value. +export const normalizeEnvVars = (list: unknown): MCPEnvVar[] => { + if (!Array.isArray(list)) return []; + const seen = new Set(); + const out: MCPEnvVar[] = []; + for (const entry of list) { + if (!entry || typeof entry !== "object") continue; + const record = entry as Record; + const name = String(record.name ?? "").trim(); + if (!name || seen.has(name)) continue; + if (!/^[A-Za-z_][A-Za-z0-9_]*$/.test(name)) continue; + const scope: MCPEnvVarScope = record.scope === "user" ? "user" : "global"; + out.push({ + name, + value: scope === "user" ? "" : String(record.value ?? ""), + scope, + description: (record.description as string | undefined) || undefined, + }); + seen.add(name); + } + return out; +}; diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index 0503490823..221a1c789f 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -76,6 +76,7 @@ import { UserInfo } from "./view_users/types"; import { EmailEventSettingsResponse, EmailEventSettingsUpdateRequest } from "./email_events/types"; import { jsonFields } from "./common_components/check_openapi_schema"; import NotificationsManager from "./molecules/notifications_manager"; +import type { MCPUserEnvVarsStatus } from "./mcp_tools/types"; import { createApiClient, deriveErrorMessage } from "@/lib/http/client"; import { resolveApiBase } from "@/lib/http/resolveApiBase"; @@ -9620,6 +9621,35 @@ export const listMCPUserCredentials = async (accessToken: string): Promise => { + return apiClient.get(`/v1/mcp/server/${serverId}/user-env-vars`, { accessToken }); +}; + +export const storeMCPUserEnvVars = async ( + accessToken: string, + serverId: string, + values: Record, +): Promise => { + return apiClient.post(`/v1/mcp/server/${serverId}/user-env-vars`, { + accessToken, + body: { values }, + }); +}; + +export const listMCPUserEnvVarStatus = async (accessToken: string): Promise => { + // Best-effort status badges: a failure here must not break the page, so fall + // back to an empty list rather than surfacing the error to the caller. + try { + return await apiClient.get("/v1/mcp/user-env-vars/status", { accessToken }); + } catch { + return []; + } +}; + // ============================================================ // Memory management (/v1/memory) // ============================================================