feat(mcp): per-server env vars with global + per-user scopes (#28917)
This commit is contained in:
parent
53cf3d8416
commit
4ec4ab99d0
@ -0,0 +1,23 @@
|
||||
-- AlterTable: add the admin-configured env_vars JSONB column (default: empty array) to the MCP server table
|
||||
ALTER TABLE "LiteLLM_MCPServerTable" ADD COLUMN IF NOT EXISTS "env_vars" JSONB DEFAULT '[]';
|
||||
|
||||
-- CreateTable: per-user env var values for MCP servers (no foreign key to the server table is declared here)
|
||||
CREATE TABLE IF NOT EXISTS "LiteLLM_MCPUserEnvVars" (
|
||||
"id" TEXT NOT NULL,
|
||||
"user_id" TEXT NOT NULL,
|
||||
"server_id" TEXT NOT NULL,
|
||||
"values_b64" TEXT NOT NULL,
|
||||
"created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
|
||||
CONSTRAINT "LiteLLM_MCPUserEnvVars_pkey" PRIMARY KEY ("id")
|
||||
);
|
||||
|
||||
-- CreateIndex: enforce at most one row per (user_id, server_id) pair
|
||||
CREATE UNIQUE INDEX IF NOT EXISTS "LiteLLM_MCPUserEnvVars_user_id_server_id_key" ON "LiteLLM_MCPUserEnvVars"("user_id", "server_id");
|
||||
|
||||
-- CreateIndex: lookup of all rows belonging to one user
|
||||
CREATE INDEX IF NOT EXISTS "LiteLLM_MCPUserEnvVars_user_id_idx" ON "LiteLLM_MCPUserEnvVars"("user_id");
|
||||
|
||||
-- CreateIndex: lookup of all rows belonging to one server (e.g. cleanup filtered by server_id)
|
||||
CREATE INDEX IF NOT EXISTS "LiteLLM_MCPUserEnvVars_server_id_idx" ON "LiteLLM_MCPUserEnvVars"("server_id");
|
||||
@ -311,6 +311,11 @@ model LiteLLM_MCPServerTable {
|
||||
tool_name_to_description Json? @default("{}")
|
||||
extra_headers String[] @default([])
|
||||
static_headers Json? @default("{}")
|
||||
// Admin-configured environment variables interpolated into static_headers
|
||||
// via ${NAME} syntax. Stored as an array of
|
||||
// {name, value, scope, description}. scope is "global" (value used as-is)
|
||||
// or "user" (value supplied per-user via LiteLLM_MCPUserEnvVars).
|
||||
env_vars Json? @default("[]")
|
||||
// Health check status
|
||||
status String? @default("unknown")
|
||||
last_health_check DateTime?
|
||||
@ -366,6 +371,21 @@ model LiteLLM_MCPUserCredentials {
|
||||
@@unique([user_id, server_id])
|
||||
}
|
||||
|
||||
// Per-user environment variable values for MCP servers.
|
||||
// values_b64 is an encrypted JSON object: {VAR_NAME: "value", ...}.
// One row holds all of a user's values for one server; the model declares
// no relation to LiteLLM_MCPServerTable, so server_id is a plain string.
|
||||
model LiteLLM_MCPUserEnvVars {
|
||||
id String @id @default(uuid())
|
||||
user_id String
|
||||
server_id String
|
||||
values_b64 String
|
||||
created_at DateTime @default(now())
|
||||
// @updatedAt refreshes this timestamp whenever the row is written via Prisma.
updated_at DateTime @default(now()) @updatedAt
|
||||
|
||||
// At most one row per (user, server) pair; also the compound lookup key.
@@unique([user_id, server_id])
|
||||
@@index([user_id])
|
||||
@@index([server_id])
|
||||
}
|
||||
|
||||
// Generate Tokens for Proxy
|
||||
model LiteLLM_VerificationToken {
|
||||
token String @id
|
||||
|
||||
@ -60,6 +60,27 @@ def to_basic_auth(auth_value: str) -> str:
|
||||
return base64.b64encode(auth_value.encode("utf-8")).decode()
|
||||
|
||||
|
||||
def _strip_header_whitespace(headers: Dict[str, str]) -> Dict[str, str]:
|
||||
return {
|
||||
(key.strip() if isinstance(key, str) else key): (
|
||||
value.strip() if isinstance(value, str) else value
|
||||
)
|
||||
for key, value in headers.items()
|
||||
}
|
||||
|
||||
|
||||
def _first_non_cancelled_cause(exc: BaseException) -> Optional[BaseException]:
|
||||
queue: List[BaseException] = [exc]
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
nested = getattr(current, "exceptions", None)
|
||||
if nested:
|
||||
queue.extend(nested)
|
||||
elif not isinstance(current, asyncio.CancelledError):
|
||||
return current
|
||||
return None
|
||||
|
||||
|
||||
TSessionResult = TypeVar("TSessionResult")
|
||||
|
||||
|
||||
@ -335,6 +356,7 @@ class MCPClient:
|
||||
user input (elicitation), or send log messages.
|
||||
"""
|
||||
transport = await transport_ctx.__aenter__()
|
||||
in_flight_error: Optional[BaseException] = None
|
||||
try:
|
||||
read_stream, write_stream = transport[0], transport[1]
|
||||
# Build session kwargs with optional callbacks
|
||||
@ -360,11 +382,21 @@ class MCPClient:
|
||||
await session_ctx.__aexit__(None, None, None)
|
||||
except BaseException as e:
|
||||
verbose_logger.debug(f"Error during session context exit: {e}")
|
||||
except BaseException as e:
|
||||
in_flight_error = e
|
||||
raise
|
||||
finally:
|
||||
try:
|
||||
await transport_ctx.__aexit__(None, None, None)
|
||||
except BaseException as e:
|
||||
verbose_logger.debug(f"Error during transport context exit: {e}")
|
||||
except BaseException as exit_error:
|
||||
verbose_logger.debug(
|
||||
f"Error during transport context exit: {exit_error}"
|
||||
)
|
||||
root_cause = _first_non_cancelled_cause(exit_error)
|
||||
if root_cause is not None and isinstance(
|
||||
in_flight_error, asyncio.CancelledError
|
||||
):
|
||||
raise root_cause from in_flight_error
|
||||
|
||||
async def run_with_session(
|
||||
self, operation: Callable[[ClientSession], Awaitable[TSessionResult]]
|
||||
@ -426,7 +458,7 @@ class MCPClient:
|
||||
# update the headers with the extra headers
|
||||
if self.extra_headers:
|
||||
headers.update(self.extra_headers)
|
||||
return headers
|
||||
return _strip_header_whitespace(headers)
|
||||
|
||||
def _create_httpx_client_factory(self) -> Callable[..., httpx.AsyncClient]:
|
||||
"""
|
||||
|
||||
@ -1,5 +1,6 @@
|
||||
import base64
|
||||
import binascii
|
||||
import hashlib
|
||||
import json
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
|
||||
@ -11,6 +12,7 @@ from litellm.proxy._types import (
|
||||
LiteLLM_ObjectPermissionTable,
|
||||
LiteLLM_TeamTable,
|
||||
MCPApprovalStatus,
|
||||
MCPEnvVarScope,
|
||||
MCPSubmissionsSummary,
|
||||
NewMCPServerRequest,
|
||||
SpecialMCPServerName,
|
||||
@ -28,6 +30,144 @@ from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
from litellm.types.mcp import MCPCredentials
|
||||
|
||||
|
||||
def _is_global_env_var_scope(scope: Any) -> bool:
    """Tell whether an env var entry carries an admin-supplied global value.

    Only ``scope="user"`` (as the enum member or the raw string) marks a
    placeholder the user fills in. Every other scope, including a missing
    one, counts as global.
    """
    is_user_scoped = scope == MCPEnvVarScope.user or scope == "user"
    return not is_user_scoped
|
||||
|
||||
|
||||
def _encrypt_global_env_var_values(env_vars: Iterable[Dict[str, Any]]) -> None:
    """Encrypt the value of every global-scoped entry, mutating ``env_vars`` in place.

    Global values are admin-supplied secrets (API keys, passwords) that end up
    interpolated into headers, so they are encrypted before being persisted.
    User-scoped entries are placeholders rather than secrets and are left
    verbatim, as are entries whose value is empty or missing.
    """
    for item in env_vars:
        if _is_global_env_var_scope(item.get("scope")):
            plaintext = item.get("value")
            if plaintext:
                item["value"] = encrypt_value_helper(plaintext)
|
||||
|
||||
|
||||
def decrypt_global_env_var_values(env_vars: Optional[Iterable[Any]]) -> None:
    """Decrypt the value of every global-scoped entry, mutating ``env_vars`` in place.

    Entries may be plain dicts (raw rows / deserialized JSON) or attribute
    style objects such as ``MCPEnvVar`` models. User-scoped entries and
    entries without a value are skipped.

    A value that cannot be decrypted (the helper returns ``None``, e.g. after
    a salt-key change) is replaced with an empty string and a warning is
    logged, so ciphertext is never forwarded into upstream ``${NAME}`` headers.
    """
    if not env_vars:
        return

    def _field(item: Any, attr: str) -> Any:
        # Uniform read access for dict entries and model entries.
        if isinstance(item, dict):
            return item.get(attr)
        return getattr(item, attr, None)

    for item in env_vars:
        if not _is_global_env_var_scope(_field(item, "scope")):
            continue
        stored = _field(item, "value")
        if not stored:
            continue
        plaintext = decrypt_value_helper(
            value=stored,
            key="mcp_global_env_var",
            exception_type="debug",
            return_original_value=False,
        )
        if plaintext is None:
            verbose_proxy_logger.warning(
                "MCP global env var %s failed to decrypt (LITELLM_SALT_KEY "
                "changed?); dropping it so ciphertext is not sent upstream",
                _field(item, "name"),
            )
            plaintext = ""
        if isinstance(item, dict):
            item["value"] = plaintext
        else:
            item.value = plaintext
|
||||
|
||||
|
||||
def _decrypt_env_vars_on_returned_row(row: Any) -> None:
|
||||
"""Decrypt ``scope="global"`` env var values on a row returned by Prisma create/update.
|
||||
|
||||
Prisma may hand back ``env_vars`` either as a parsed list (the common case for
|
||||
JSONB columns) or as a raw JSON string (observed for some write paths). The
|
||||
in-place decrypt helper only mutates iterables of dicts/models, so a string
|
||||
payload would silently skip decryption and ciphertext would leak into the
|
||||
registry via ``add_server``/``update_server`` (which trust the caller).
|
||||
Parse the string back to a list so the in-place decrypt actually runs, and
|
||||
write the decrypted list back onto the row so downstream consumers see plain
|
||||
values.
|
||||
"""
|
||||
env_vars = getattr(row, "env_vars", None)
|
||||
if env_vars is None:
|
||||
return
|
||||
if isinstance(env_vars, str):
|
||||
try:
|
||||
env_vars = json.loads(env_vars)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return
|
||||
if not isinstance(env_vars, list):
|
||||
return
|
||||
try:
|
||||
setattr(row, "env_vars", env_vars)
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
decrypt_global_env_var_values(env_vars)
|
||||
|
||||
|
||||
def _reencrypt_global_env_var_values(
|
||||
env_vars: Optional[Iterable[Any]], new_encryption_key: str
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""Re-encrypt ``scope="global"`` env var values for master-key rotation.
|
||||
|
||||
Each global value is decrypted with the current salt key and re-encrypted
|
||||
under ``new_encryption_key``. Returns the rebuilt list when at least one
|
||||
value was rotated, else ``None`` so the caller can skip the DB write. A
|
||||
value that fails to decrypt is left untouched (and logged) so a corrupt
|
||||
entry is preserved for recovery rather than overwritten.
|
||||
"""
|
||||
if not env_vars:
|
||||
return None
|
||||
if isinstance(env_vars, str):
|
||||
try:
|
||||
env_vars = json.loads(env_vars)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
if not env_vars:
|
||||
return None
|
||||
rebuilt = [dict(v) for v in env_vars]
|
||||
rotated = False
|
||||
for entry in rebuilt:
|
||||
if not _is_global_env_var_scope(entry.get("scope")):
|
||||
continue
|
||||
value = entry.get("value")
|
||||
if not value:
|
||||
continue
|
||||
decrypted = decrypt_value_helper(
|
||||
value=value,
|
||||
key="mcp_global_env_var",
|
||||
exception_type="debug",
|
||||
return_original_value=False,
|
||||
)
|
||||
if decrypted is None:
|
||||
verbose_proxy_logger.warning(
|
||||
"rotate_mcp_server_credentials_master_key: could not decrypt "
|
||||
"global env var %s, skipping",
|
||||
entry.get("name"),
|
||||
)
|
||||
continue
|
||||
entry["value"] = encrypt_value_helper(
|
||||
decrypted, new_encryption_key=new_encryption_key
|
||||
)
|
||||
rotated = True
|
||||
return rebuilt if rotated else None
|
||||
|
||||
|
||||
def _prepare_mcp_server_data(
|
||||
data: Union[NewMCPServerRequest, UpdateMCPServerRequest],
|
||||
exclude_unset: bool = False,
|
||||
@ -98,6 +238,16 @@ def _prepare_mcp_server_data(
|
||||
if data_dict.get("static_headers") is not None:
|
||||
data_dict["static_headers"] = safe_dumps(data_dict["static_headers"])
|
||||
|
||||
# env_vars is read from ``data_dict`` (not ``data``) like every other JSON
|
||||
# column so the exclude_unset filter is respected: a partial update that
|
||||
# omits env_vars never overwrites the stored value. Global values are
|
||||
# encrypted at rest before serialization.
|
||||
env_vars = data_dict.get("env_vars")
|
||||
if env_vars is not None:
|
||||
serialized_env_vars = [dict(v) for v in env_vars]
|
||||
_encrypt_global_env_var_values(serialized_env_vars)
|
||||
data_dict["env_vars"] = safe_dumps(serialized_env_vars)
|
||||
|
||||
if data_dict.get("mcp_info") is not None:
|
||||
data_dict["mcp_info"] = safe_dumps(data_dict["mcp_info"])
|
||||
|
||||
@ -207,10 +357,13 @@ async def get_all_mcp_servers(
|
||||
where=where if where else {}
|
||||
)
|
||||
|
||||
return [
|
||||
tables = [
|
||||
LiteLLM_MCPServerTable(**mcp_server.model_dump())
|
||||
for mcp_server in mcp_servers
|
||||
]
|
||||
for table in tables:
|
||||
decrypt_global_env_var_values(table.env_vars)
|
||||
return tables
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(
|
||||
"litellm.proxy._experimental.mcp_server.db.py::get_all_mcp_servers - {}".format(
|
||||
@ -226,14 +379,16 @@ async def get_mcp_server(
|
||||
"""
|
||||
Returns the matching mcp server from the db iff exists
|
||||
"""
|
||||
mcp_server: Optional[LiteLLM_MCPServerTable] = (
|
||||
await prisma_client.db.litellm_mcpservertable.find_unique(
|
||||
where={
|
||||
"server_id": server_id,
|
||||
}
|
||||
)
|
||||
mcp_server = await prisma_client.db.litellm_mcpservertable.find_unique(
|
||||
where={
|
||||
"server_id": server_id,
|
||||
}
|
||||
)
|
||||
return mcp_server
|
||||
if mcp_server is None:
|
||||
return None
|
||||
table = LiteLLM_MCPServerTable(**mcp_server.model_dump())
|
||||
decrypt_global_env_var_values(table.env_vars)
|
||||
return table
|
||||
|
||||
|
||||
async def get_mcp_servers(
|
||||
@ -251,7 +406,9 @@ async def get_mcp_servers(
|
||||
)
|
||||
final_mcp_servers: List[LiteLLM_MCPServerTable] = []
|
||||
for _mcp_server in _mcp_servers:
|
||||
final_mcp_servers.append(LiteLLM_MCPServerTable(**_mcp_server.model_dump()))
|
||||
table = LiteLLM_MCPServerTable(**_mcp_server.model_dump())
|
||||
decrypt_global_env_var_values(table.env_vars)
|
||||
final_mcp_servers.append(table)
|
||||
|
||||
return final_mcp_servers
|
||||
|
||||
@ -399,6 +556,11 @@ async def delete_mcp_server(
|
||||
"""
|
||||
Delete the mcp server from the db by server_id
|
||||
|
||||
The server-row delete is the commit point. Per-user env var rows have no FK
|
||||
cascade, so they are cleaned up afterwards on a best-effort basis: a transient
|
||||
failure there leaves only orphaned rows pointing at a now-missing server and
|
||||
must not turn a successful delete into a caller-visible error.
|
||||
|
||||
Returns the deleted mcp server record if it exists, otherwise None
|
||||
"""
|
||||
deleted_server = await prisma_client.db.litellm_mcpservertable.delete(
|
||||
@ -406,6 +568,18 @@ async def delete_mcp_server(
|
||||
"server_id": server_id,
|
||||
},
|
||||
)
|
||||
if deleted_server is not None:
|
||||
try:
|
||||
await prisma_client.db.litellm_mcpuserenvvars.delete_many(
|
||||
where={"server_id": server_id}
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.warning(
|
||||
"MCP server %s deleted but per-user env var cleanup failed; "
|
||||
"orphaned rows can be removed on a later delete: %s",
|
||||
server_id,
|
||||
e,
|
||||
)
|
||||
return deleted_server
|
||||
|
||||
|
||||
@ -429,6 +603,7 @@ async def create_mcp_server(
|
||||
data=data_dict # type: ignore
|
||||
)
|
||||
|
||||
_decrypt_env_vars_on_returned_row(new_mcp_server)
|
||||
return new_mcp_server
|
||||
|
||||
|
||||
@ -506,40 +681,52 @@ async def update_mcp_server(
|
||||
where={"server_id": data.server_id}, data=data_dict # type: ignore
|
||||
)
|
||||
|
||||
_decrypt_env_vars_on_returned_row(updated_mcp_server)
|
||||
return updated_mcp_server
|
||||
|
||||
|
||||
async def rotate_mcp_server_credentials_master_key(
|
||||
prisma_client: PrismaClient, touched_by: str, new_master_key: str
|
||||
):
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
|
||||
mcp_servers = await prisma_client.db.litellm_mcpservertable.find_many()
|
||||
|
||||
updated = 0
|
||||
for mcp_server in mcp_servers:
|
||||
update_data: Dict[str, Any] = {}
|
||||
|
||||
credentials = mcp_server.credentials
|
||||
if not credentials:
|
||||
if credentials:
|
||||
# Decrypt with current key first, then re-encrypt with new key
|
||||
decrypted_credentials = decrypt_credentials(
|
||||
credentials=cast(MCPCredentials, dict(credentials)),
|
||||
)
|
||||
encrypted_credentials = encrypt_credentials(
|
||||
credentials=decrypted_credentials,
|
||||
encryption_key=new_master_key,
|
||||
)
|
||||
update_data["credentials"] = safe_dumps(encrypted_credentials)
|
||||
|
||||
rotated_env_vars = _reencrypt_global_env_var_values(
|
||||
mcp_server.env_vars, new_master_key
|
||||
)
|
||||
if rotated_env_vars is not None:
|
||||
update_data["env_vars"] = safe_dumps(rotated_env_vars)
|
||||
|
||||
if not update_data:
|
||||
continue
|
||||
|
||||
credentials_copy = dict(credentials)
|
||||
# Decrypt with current key first, then re-encrypt with new key
|
||||
decrypted_credentials = decrypt_credentials(
|
||||
credentials=cast(MCPCredentials, credentials_copy),
|
||||
)
|
||||
encrypted_credentials = encrypt_credentials(
|
||||
credentials=decrypted_credentials,
|
||||
encryption_key=new_master_key,
|
||||
)
|
||||
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
|
||||
serialized_credentials = safe_dumps(encrypted_credentials)
|
||||
|
||||
update_data["updated_by"] = touched_by
|
||||
await prisma_client.db.litellm_mcpservertable.update(
|
||||
where={"server_id": mcp_server.server_id},
|
||||
data={
|
||||
"credentials": serialized_credentials,
|
||||
"updated_by": touched_by,
|
||||
},
|
||||
data=update_data,
|
||||
)
|
||||
updated += 1
|
||||
verbose_proxy_logger.info(
|
||||
"rotate_mcp_server_credentials_master_key: rotated %d MCP server row(s)",
|
||||
updated,
|
||||
)
|
||||
|
||||
|
||||
def _decode_user_credential(stored: str) -> Optional[str]:
|
||||
@ -594,6 +781,8 @@ async def rotate_mcp_user_credentials_master_key(
|
||||
are logged and skipped so one corrupt row does not abort the rotation.
|
||||
"""
|
||||
rows = await prisma_client.db.litellm_mcpusercredentials.find_many()
|
||||
rotated = 0
|
||||
skipped = 0
|
||||
for row in rows:
|
||||
plaintext = _decode_user_credential(row.credential_b64)
|
||||
if plaintext is None:
|
||||
@ -603,6 +792,7 @@ async def rotate_mcp_user_credentials_master_key(
|
||||
row.user_id,
|
||||
row.server_id,
|
||||
)
|
||||
skipped += 1
|
||||
continue
|
||||
re_encrypted = encrypt_value_helper(
|
||||
plaintext, new_encryption_key=new_master_key
|
||||
@ -616,6 +806,61 @@ async def rotate_mcp_user_credentials_master_key(
|
||||
},
|
||||
data={"credential_b64": re_encrypted},
|
||||
)
|
||||
rotated += 1
|
||||
verbose_proxy_logger.info(
|
||||
"rotate_mcp_user_credentials_master_key: rotated %d row(s), skipped %d",
|
||||
rotated,
|
||||
skipped,
|
||||
)
|
||||
|
||||
|
||||
async def rotate_mcp_user_env_vars_master_key(
    prisma_client: PrismaClient, new_master_key: str
):
    """Re-encrypt every ``LiteLLM_MCPUserEnvVars`` row under ``new_master_key``.

    Each row's ``values_b64`` blob is decrypted with the current key and
    written back encrypted with the new one. A row that cannot be decrypted
    is logged and left as-is, so a single corrupt row neither aborts the
    rotation nor destroys data that might still be recoverable. A summary of
    rotated and skipped rows is logged at the end.
    """
    rotated = 0
    skipped = 0
    stored_rows = await prisma_client.db.litellm_mcpuserenvvars.find_many()
    for stored in stored_rows:
        current_plaintext = decrypt_value_helper(
            value=stored.values_b64,
            key="mcp_user_env_vars",
            exception_type="debug",
            return_original_value=False,
        )
        if current_plaintext is not None:
            rotated_blob = encrypt_value_helper(
                current_plaintext, new_encryption_key=new_master_key
            )
            row_selector = {
                "user_id_server_id": {
                    "user_id": stored.user_id,
                    "server_id": stored.server_id,
                }
            }
            await prisma_client.db.litellm_mcpuserenvvars.update(
                where=row_selector,
                data={"values_b64": rotated_blob},
            )
            rotated += 1
        else:
            verbose_proxy_logger.warning(
                "rotate_mcp_user_env_vars_master_key: could not decrypt env vars "
                "for user_id=%s server_id=%s, skipping",
                stored.user_id,
                stored.server_id,
            )
            skipped += 1
    verbose_proxy_logger.info(
        "rotate_mcp_user_env_vars_master_key: rotated %d row(s), skipped %d",
        rotated,
        skipped,
    )
|
||||
|
||||
|
||||
async def store_user_credential(
|
||||
@ -927,7 +1172,9 @@ async def approve_mcp_server(
|
||||
"updated_by": touched_by,
|
||||
},
|
||||
)
|
||||
return LiteLLM_MCPServerTable(**updated.model_dump())
|
||||
table = LiteLLM_MCPServerTable(**updated.model_dump())
|
||||
decrypt_global_env_var_values(table.env_vars)
|
||||
return table
|
||||
|
||||
|
||||
async def reject_mcp_server(
|
||||
@ -949,7 +1196,9 @@ async def reject_mcp_server(
|
||||
where={"server_id": server_id},
|
||||
data=data,
|
||||
)
|
||||
return LiteLLM_MCPServerTable(**updated.model_dump())
|
||||
table = LiteLLM_MCPServerTable(**updated.model_dump())
|
||||
decrypt_global_env_var_values(table.env_vars)
|
||||
return table
|
||||
|
||||
|
||||
async def get_mcp_submissions(
|
||||
@ -966,6 +1215,8 @@ async def get_mcp_submissions(
|
||||
take=500, # safety cap; paginate if needed in a future iteration
|
||||
)
|
||||
items = [LiteLLM_MCPServerTable(**r.model_dump()) for r in rows]
|
||||
for item in items:
|
||||
decrypt_global_env_var_values(item.env_vars)
|
||||
|
||||
pending = sum(
|
||||
1 for i in items if i.approval_status == MCPApprovalStatus.pending_review
|
||||
@ -980,3 +1231,121 @@ async def get_mcp_submissions(
|
||||
rejected=rejected,
|
||||
items=items,
|
||||
)
|
||||
|
||||
|
||||
# ── Per-user MCP environment variables ────────────────────────────────────
|
||||
|
||||
|
||||
def _decode_user_env_vars(stored: str) -> Dict[str, str]:
    """Turn an encrypted ``values_b64`` blob into a flat ``{name: value}`` dict.

    Returns an empty dict when the blob cannot be decrypted, is not valid
    JSON, or does not decode to a JSON object. A non-empty blob that fails
    to decrypt additionally logs a warning. Keys and values of the decoded
    object are coerced with ``str``.
    """
    plaintext = decrypt_value_helper(
        value=stored,
        key="mcp_user_env_vars",
        exception_type="debug",
        return_original_value=False,
    )
    if plaintext is None:
        if stored:
            verbose_proxy_logger.warning(
                "MCP per-user env vars failed to decrypt (LITELLM_SALT_KEY "
                "changed?); treating as unset so the user is prompted to "
                "re-enter them rather than silently forwarding ciphertext"
            )
        return {}

    try:
        payload = json.loads(plaintext)
    except (ValueError, TypeError):
        payload = None
    if not isinstance(payload, dict):
        return {}

    result: Dict[str, str] = {}
    for name, raw in payload.items():
        result[str(name)] = str(raw)
    return result
|
||||
|
||||
|
||||
async def get_user_env_vars(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> Dict[str, str]:
    """Fetch one user's stored env var values for a single MCP server.

    Looks the row up by its compound ``(user_id, server_id)`` key and decodes
    its ``values_b64`` blob. Returns an empty dict when no row exists.
    """
    selector = {"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
    record = await prisma_client.db.litellm_mcpuserenvvars.find_unique(where=selector)
    return {} if record is None else _decode_user_env_vars(record.values_b64)
|
||||
|
||||
|
||||
async def get_user_env_vars_bulk(
    prisma_client: PrismaClient,
    user_id: str,
    server_ids: Iterable[str],
) -> Dict[str, Dict[str, str]]:
    """Fetch one user's env var values for several MCP servers in a single query.

    The result maps ``server_id`` to that server's ``{var_name: value}`` dict.
    Servers without a stored row do not appear in the result, and no query is
    issued at all when ``server_ids`` is empty.
    """
    wanted = [*server_ids]
    if len(wanted) == 0:
        return {}
    records = await prisma_client.db.litellm_mcpuserenvvars.find_many(
        where={"user_id": user_id, "server_id": {"in": wanted}}
    )
    decoded: Dict[str, Dict[str, str]] = {}
    for record in records:
        sid = record.server_id
        decoded[sid] = _decode_user_env_vars(record.values_b64)
    return decoded
|
||||
|
||||
|
||||
async def merge_user_env_vars(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
    updates: Dict[str, str],
    allowed_names: Iterable[str],
) -> Dict[str, str]:
    """Apply ``updates`` on top of the user's stored env vars and persist the result.

    The read-modify-write happens inside one transaction that first takes a
    transaction-scoped advisory lock derived from ``(user_id, server_id)``,
    so concurrent writes for the same pair are serialized and none is lost.
    Any name not listed in ``allowed_names`` is dropped from the merged set,
    which also purges values for variables an admin has since retired.

    Returns:
        The merged ``{name: value}`` dict that was written.
    """
    permitted = set(allowed_names)
    # Fold the pair into a signed 64-bit integer, the key type the
    # ``$1::bigint`` cast in the lock statement expects.
    digest = hashlib.blake2b(
        f"{user_id}:{server_id}".encode(), digest_size=8
    ).digest()
    lock_key = int.from_bytes(digest, "big", signed=True)

    async with prisma_client.db.tx() as tx:
        await tx.execute_raw("SELECT pg_advisory_xact_lock($1::bigint)", lock_key)
        current_row = await tx.litellm_mcpuserenvvars.find_unique(
            where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}}
        )
        if current_row is None:
            combined: Dict[str, str] = {}
        else:
            combined = _decode_user_env_vars(current_row.values_b64)
        # Incoming values override stored ones with the same name.
        combined = {**combined, **updates}
        merged: Dict[str, str] = {}
        for name, value in combined.items():
            if name in permitted:
                merged[name] = value
        encoded = encrypt_value_helper(json.dumps(merged))
        create_payload = {
            "user_id": user_id,
            "server_id": server_id,
            "values_b64": encoded,
        }
        await tx.litellm_mcpuserenvvars.upsert(
            where={"user_id_server_id": {"user_id": user_id, "server_id": server_id}},
            data={
                "create": create_payload,
                "update": {"values_b64": encoded},
            },
        )
    return merged
|
||||
|
||||
|
||||
async def delete_user_env_vars(
    prisma_client: PrismaClient,
    user_id: str,
    server_id: str,
) -> None:
    """Delete the user's stored env var values for ``server_id``.

    ``delete_many`` is used on purpose: when no matching row exists the call
    is simply a no-op, while genuine database errors still reach the caller.
    """
    criteria = {"user_id": user_id, "server_id": server_id}
    await prisma_client.db.litellm_mcpuserenvvars.delete_many(where=criteria)
|
||||
|
||||
@ -58,20 +58,26 @@ from litellm.proxy._experimental.mcp_server.sampling_handler import (
|
||||
from litellm.proxy._experimental.mcp_server.oauth2_token_cache import resolve_mcp_auth
|
||||
from litellm.proxy._experimental.mcp_server.utils import (
|
||||
MCP_TOOL_PREFIX_SEPARATOR,
|
||||
MCPMissingUserEnvVarsError,
|
||||
add_server_prefix_to_name,
|
||||
build_env_var_setup_url,
|
||||
collect_env_var_references,
|
||||
compute_short_server_prefix,
|
||||
get_server_prefix,
|
||||
interpolate_headers,
|
||||
is_short_mcp_tool_prefix_enabled,
|
||||
is_tool_name_prefixed,
|
||||
iter_known_server_prefixes,
|
||||
merge_mcp_headers,
|
||||
normalize_server_name,
|
||||
parse_admin_env_vars,
|
||||
split_server_prefix_from_name,
|
||||
validate_mcp_server_name,
|
||||
)
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_MCPServerTable,
|
||||
MCPAuthType,
|
||||
MCPEnvVar,
|
||||
MCPTransport,
|
||||
MCPTransportType,
|
||||
UserAPIKeyAuth,
|
||||
@ -124,6 +130,33 @@ _AZURE_ENTRA_HOSTS = {
|
||||
"login.chinacloudapi.cn", # China
|
||||
}
|
||||
|
||||
# Short-lived in-memory cache for per-user MCP env var values, mirroring the
|
||||
# BYOK credential cache. Keyed by (user_id, server_id); value is
|
||||
# (values_dict, monotonic_timestamp). Keeps the tool-call and tool-listing
|
||||
# paths off the DB on every request within the TTL window.
|
||||
_user_env_vars_cache: Dict[Tuple[str, str], Tuple[Dict[str, str], float]] = {}
|
||||
_USER_ENV_VARS_CACHE_TTL = 60 # seconds
|
||||
_USER_ENV_VARS_CACHE_MAX_SIZE = 4096 # cap to prevent unbounded growth
|
||||
|
||||
|
||||
def invalidate_user_env_vars_cache(user_id: str, server_id: str) -> None:
    """Evict the cached env var values for ``(user_id, server_id)``, if any.

    Meant to be called after the user stores or clears their values so the
    next request does not serve a stale cache entry. A missing entry is not
    an error.
    """
    cache_key = (user_id, server_id)
    _user_env_vars_cache.pop(cache_key, None)
|
||||
|
||||
|
||||
def _write_user_env_vars_cache(
    user_id: str, server_id: str, values: Dict[str, str]
) -> None:
    """Store ``values`` in the per-user env var cache with a fresh timestamp.

    The entry is removed before being re-added so it always lands at the end
    of the dict's insertion order. When the cache is at capacity, only the
    first (oldest-written) entry is evicted; the cache is never cleared
    wholesale, which would send every subsequent request to the DB at once.
    """
    entry_key = (user_id, server_id)
    _user_env_vars_cache.pop(entry_key, None)
    at_capacity = not (len(_user_env_vars_cache) < _USER_ENV_VARS_CACHE_MAX_SIZE)
    if at_capacity:
        oldest_key = next(iter(_user_env_vars_cache))
        _user_env_vars_cache.pop(oldest_key, None)
    stamped = (values, time.monotonic())
    _user_env_vars_cache[entry_key] = stamped
|
||||
|
||||
|
||||
def _should_strip_caller_authorization(
|
||||
mcp_server: MCPServer,
|
||||
@ -295,6 +328,31 @@ def _deserialize_json_dict(data: Any) -> Optional[Dict[str, str]]:
|
||||
return data
|
||||
|
||||
|
||||
def _deserialize_json_list(data: Any) -> Optional[List[Dict[str, Any]]]:
|
||||
"""Deserialize a JSON array stored in the DB (``env_vars`` and friends).
|
||||
|
||||
Returns ``None`` for empty / null / unparseable input. Accepts strings
|
||||
(raw JSON), already-materialized lists of dicts, and lists of Pydantic
|
||||
models (Prisma may hydrate a JSON column such as ``env_vars`` into
|
||||
``MCPEnvVar`` objects); model entries are normalized to plain dicts so
|
||||
downstream consumers expecting ``List[Dict[str, Any]]`` validate.
|
||||
"""
|
||||
if data is None or data == "" or data == []:
|
||||
return None
|
||||
if isinstance(data, str):
|
||||
try:
|
||||
parsed = json.loads(data)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return None
|
||||
data = parsed
|
||||
if not isinstance(data, list):
|
||||
return None
|
||||
return [
|
||||
item.model_dump(mode="json") if hasattr(item, "model_dump") else item
|
||||
for item in data
|
||||
]
|
||||
|
||||
|
||||
def _create_sampling_callback(user_api_key_auth: Optional[Any] = None):
|
||||
"""
|
||||
Create a sampling callback for MCP ClientSession.
|
||||
@ -456,7 +514,8 @@ class MCPServerManager:
|
||||
- server is OpenAPI (spec_path),
|
||||
- non-empty upstream instructions are already cached,
|
||||
- auth preconditions match health_check_server's skip rules
|
||||
(per-user auth / missing static auth token),
|
||||
(per-user auth / missing static auth token / static headers that
|
||||
reference a per-user env var),
|
||||
- a prior probe attempt for this server is within
|
||||
MCP_HEALTH_CHECK_TIMEOUT seconds (the probe is a health-check-shaped
|
||||
op and already uses this knob for its inner call timeout; reusing it
|
||||
@ -471,6 +530,8 @@ class MCPServerManager:
|
||||
return
|
||||
if server.requires_per_user_auth:
|
||||
return
|
||||
if self._references_per_user_env_var(server):
|
||||
return
|
||||
if (
|
||||
server.auth_type
|
||||
and server.auth_type != MCPAuth.none
|
||||
@ -495,8 +556,13 @@ class MCPServerManager:
|
||||
)
|
||||
|
||||
try:
|
||||
resolved_static_headers = await self._resolve_static_headers_with_env_vars(
|
||||
server=server,
|
||||
user_api_key_auth=None,
|
||||
raise_on_missing=False,
|
||||
)
|
||||
extra_headers: Optional[Dict[str, str]] = (
|
||||
dict(server.static_headers) if server.static_headers else None
|
||||
dict(resolved_static_headers) if resolved_static_headers else None
|
||||
)
|
||||
client = await self._create_mcp_client(
|
||||
server=server,
|
||||
@ -656,6 +722,7 @@ class MCPServerManager:
|
||||
allowed_params=server_config.get("allowed_params", None),
|
||||
access_groups=server_config.get("access_groups", None),
|
||||
static_headers=server_config.get("static_headers", None),
|
||||
env_vars=server_config.get("env_vars", None),
|
||||
allow_all_keys=bool(server_config.get("allow_all_keys", False)),
|
||||
available_on_public_internet=bool(
|
||||
server_config.get("available_on_public_internet", True)
|
||||
@ -920,17 +987,41 @@ class MCPServerManager:
|
||||
f"Server ID {mcp_server.server_id} not found in registry"
|
||||
)
|
||||
|
||||
def _resolve_env_vars_list(
    self,
    mcp_server: LiteLLM_MCPServerTable,
    *,
    env_vars_are_encrypted: bool,
) -> Optional[List[Dict[str, Any]]]:
    """Deserialize ``mcp_server.env_vars`` and optionally decrypt it.

    Args:
        mcp_server: Table record whose ``env_vars`` attribute (if any) is
            passed through ``_deserialize_json_list``.
        env_vars_are_encrypted: When true, the deserialized list is handed to
            ``decrypt_global_env_var_values`` before being returned.

    Returns:
        Whatever ``_deserialize_json_list`` produced for the record.
    """
    raw_env_vars = getattr(mcp_server, "env_vars", None)
    resolved = _deserialize_json_list(raw_env_vars)
    if not env_vars_are_encrypted:
        return resolved

    from litellm.proxy._experimental.mcp_server.db import (  # noqa: PLC0415
        decrypt_global_env_var_values,
    )

    # The helper's return value is discarded and the same list is returned,
    # so it presumably decrypts the entries in place -- TODO confirm in db.py.
    decrypt_global_env_var_values(resolved)
    return resolved
|
||||
|
||||
async def build_mcp_server_from_table(
|
||||
self,
|
||||
mcp_server: LiteLLM_MCPServerTable,
|
||||
*,
|
||||
credentials_are_encrypted: bool = True,
|
||||
env_vars_are_encrypted: Optional[bool] = None,
|
||||
) -> MCPServer:
|
||||
_mcp_info: MCPInfo = mcp_server.mcp_info or {}
|
||||
env_dict = _deserialize_json_dict(getattr(mcp_server, "env", None))
|
||||
static_headers_dict = _deserialize_json_dict(
|
||||
getattr(mcp_server, "static_headers", None)
|
||||
)
|
||||
env_vars_list = self._resolve_env_vars_list(
|
||||
mcp_server,
|
||||
env_vars_are_encrypted=(
|
||||
credentials_are_encrypted
|
||||
if env_vars_are_encrypted is None
|
||||
else env_vars_are_encrypted
|
||||
),
|
||||
)
|
||||
credentials_dict = _deserialize_json_dict(
|
||||
getattr(mcp_server, "credentials", None)
|
||||
)
|
||||
@ -1030,6 +1121,7 @@ class MCPServerManager:
|
||||
mcp_info=mcp_info,
|
||||
extra_headers=getattr(mcp_server, "extra_headers", None),
|
||||
static_headers=static_headers_dict,
|
||||
env_vars=env_vars_list,
|
||||
client_id=client_id_value or getattr(mcp_server, "client_id", None),
|
||||
client_secret=client_secret_value
|
||||
or getattr(mcp_server, "client_secret", None),
|
||||
@ -1128,7 +1220,14 @@ class MCPServerManager:
|
||||
return
|
||||
try:
|
||||
if mcp_server.server_id not in self.registry:
|
||||
new_server = await self.build_mcp_server_from_table(mcp_server)
|
||||
# Callers hand us a record returned by the db.py read/write
|
||||
# helpers, which already decrypt global env var values (the
|
||||
# `credentials` field is the only one still encrypted here).
|
||||
# Re-decrypting plaintext would zero the values, so build with
|
||||
# env_vars_are_encrypted=False.
|
||||
new_server = await self.build_mcp_server_from_table(
|
||||
mcp_server, env_vars_are_encrypted=False
|
||||
)
|
||||
self._assign_unique_short_prefix(new_server)
|
||||
self.registry[mcp_server.server_id] = new_server
|
||||
await self._maybe_register_openapi_tools(new_server)
|
||||
@ -1151,7 +1250,11 @@ class MCPServerManager:
|
||||
return
|
||||
try:
|
||||
if mcp_server.server_id in self.registry:
|
||||
new_server = await self.build_mcp_server_from_table(mcp_server)
|
||||
# See add_server: db.py helpers already decrypted env var
|
||||
# values, so don't decrypt them a second time here.
|
||||
new_server = await self.build_mcp_server_from_table(
|
||||
mcp_server, env_vars_are_encrypted=False
|
||||
)
|
||||
# Carry the previously-resolved short prefix across so the
|
||||
# tool names stay stable for clients holding cached lists.
|
||||
existing_prefix = self.registry[mcp_server.server_id].short_prefix
|
||||
@ -1572,6 +1675,180 @@ class MCPServerManager:
|
||||
|
||||
return resolved_env
|
||||
|
||||
def _references_per_user_env_var(self, server: MCPServer) -> bool:
    """Tell whether ``server.static_headers`` mention a user-scoped ``${NAME}``.

    A placeholder for a per-user variable can only be filled from the calling
    user's stored values, so a probe that runs without a user cannot resolve
    it. Callers use this to skip such userless probes.

    Returns:
        ``True`` when at least one ``${NAME}`` reference found in the static
        header values names an env var configured with user scope; ``False``
        when there are no static headers, no env vars, or no overlap.
    """
    headers = server.static_headers
    configured = getattr(server, "env_vars", None)
    if not (headers and configured):
        return False

    _, per_user_specs = parse_admin_env_vars(configured)
    per_user_names = {spec["name"] for spec in per_user_specs}
    if not per_user_names:
        return False

    referenced_names = collect_env_var_references(strings=headers.values())
    return not referenced_names.isdisjoint(per_user_names)
|
||||
|
||||
async def _resolve_static_headers_with_env_vars(
    self,
    server: MCPServer,
    user_api_key_auth: Optional[UserAPIKeyAuth],
    *,
    raise_on_missing: bool = True,
) -> Optional[Dict[str, str]]:
    """Return server.static_headers with ``${NAME}`` interpolated.

    Globals come from ``server.env_vars`` entries with ``scope=="global"``.
    Per-user values come from the ``LiteLLM_MCPUserEnvVars`` row for the
    calling user (loaded through ``_load_user_env_vars``).

    When ``raise_on_missing`` is ``True`` (the tool-*call* path), raises
    ``MCPMissingUserEnvVarsError`` if ``static_headers`` reference a per-user
    variable the calling user has not yet supplied — converted into a
    user-facing 412 by the REST layer.

    When ``raise_on_missing`` is ``False`` (tool listing and the userless
    probes), missing per-user vars are non-blocking: we interpolate whatever
    is available and leave unfilled ``${NAME}`` references untouched, so the
    server's tools still appear in the listing. The user only hits the
    friendly error when they actually invoke a tool that needs the missing
    value.

    Args:
        server: Server whose ``static_headers`` / ``env_vars`` are resolved.
        user_api_key_auth: Calling user's auth context, or ``None`` for a
            userless probe (no per-user values are then available).
        raise_on_missing: Whether unresolved per-user variables and load
            failures are raised (``True``) or tolerated (``False``).

    Returns:
        ``server.static_headers`` itself when there is nothing to
        interpolate, otherwise the result of ``interpolate_headers``.

    Raises:
        MCPMissingUserEnvVarsError: ``raise_on_missing`` is true and a
            referenced per-user variable has no non-empty stored value.
        Exception: any error from ``_load_user_env_vars`` is re-raised
            unchanged when ``raise_on_missing`` is true.
    """
    static_headers = server.static_headers
    env_vars = getattr(server, "env_vars", None)
    if not static_headers and not env_vars:
        return static_headers

    global_values, user_specs = parse_admin_env_vars(env_vars)
    # An empty-valued global is treated as unset: it must not mask a per-user
    # var the user still has to supply, nor override a value the user did
    # supply. The unresolved ${NAME} is then left untouched, like any other
    # undefined reference.
    global_values = {name: value for name, value in global_values.items() if value}
    user_var_names = {spec["name"] for spec in user_specs}

    # If no env vars are configured, return static_headers as-is.
    if not global_values and not user_specs:
        return static_headers

    # Figure out which user-scoped vars are actually referenced. A var that
    # also carries a global value is always covered by that global (globals
    # win in the merge below), so it can never be genuinely "missing" even if
    # the user hasn't filled it in -- only vars without a global fallback do.
    referenced = collect_env_var_references(strings=(static_headers or {}).values())
    referenced_user_vars = referenced & user_var_names
    required_user_vars = {
        name for name in referenced_user_vars if name not in global_values
    }

    user_values: Dict[str, str] = {}
    # Only touch the cache / DB when a referenced per-user var actually needs
    # a user-supplied value.
    if required_user_vars:
        try:
            user_values = await self._load_user_env_vars(server, user_api_key_auth)
        except Exception as exc:
            # On the tool-call path a DB failure must surface as a real
            # server error, not a misleading "set up your credentials" 412.
            # On the listing path we stay best-effort and leave the
            # unfilled ${NAME} references untouched so tools still appear.
            if raise_on_missing:
                raise
            verbose_logger.warning(
                "MCPServerManager: best-effort user env var load failed for "
                "server=%s: %s",
                server.server_id,
                exc,
            )

    if raise_on_missing:
        # An empty stored value counts as missing, same as an absent one.
        missing = sorted(
            name for name in required_user_vars if not user_values.get(name)
        )
        if missing:
            # A cached negative must never produce a 412: cache
            # invalidation is process-local, so a user who just stored
            # values on another worker would otherwise be told their
            # credentials are missing until the entry expires. Confirm
            # against the DB before raising.
            user_values = await self._load_user_env_vars(
                server, user_api_key_auth, force_refresh=True
            )
            missing = sorted(
                name for name in required_user_vars if not user_values.get(name)
            )
        if missing:
            raise MCPMissingUserEnvVarsError(
                server_id=server.server_id,
                server_name=server.server_name or server.name,
                missing=missing,
                setup_url=build_env_var_setup_url(server.server_id),
            )

    # Only honor stored user values for currently user-scoped vars, and let
    # admin globals win, so a stale row from when a var was user-scoped can
    # never override the global value the admin set after switching it.
    scoped_user_values = {
        name: value for name, value in user_values.items() if name in user_var_names
    }
    # Later keys win in a dict merge, so globals override user values.
    merged_vars: Dict[str, str] = {**scoped_user_values, **global_values}
    if not static_headers:
        return static_headers
    return interpolate_headers(static_headers, merged_vars)
|
||||
|
||||
async def _load_user_env_vars(
    self,
    server: MCPServer,
    user_api_key_auth: Optional[UserAPIKeyAuth],
    *,
    force_refresh: bool = False,
) -> Dict[str, str]:
    """Fetch the calling user's stored env var values for ``server``.

    Values are served from the module-level ``_user_env_vars_cache`` (keyed by
    ``(user_id, server_id)``) while the entry is younger than
    ``_USER_ENV_VARS_CACHE_TTL``; otherwise they are read from the database
    and written back to the cache.

    Args:
        server: Server whose ``server_id`` identifies the stored row.
        user_api_key_auth: Calling user's auth context; may be ``None``.
        force_refresh: Skip the cache read and always query the database
            (the result is still written to the cache).

    Returns:
        The stored values, or an empty dict when no user id is available.

    Raises:
        RuntimeError: no database client is configured.
        Exception: any error raised by ``get_user_env_vars`` propagates, so a
            lookup failure is never reported as "user has no values".
    """
    # getattr with a default also covers ``user_api_key_auth`` being None.
    caller_id = getattr(user_api_key_auth, "user_id", None)
    if not caller_id:
        return {}

    server_id = server.server_id

    if not force_refresh:
        entry = _user_env_vars_cache.get((caller_id, server_id))
        if entry is not None:
            cached_values, stored_at = entry
            age = time.monotonic() - stored_at
            if age < _USER_ENV_VARS_CACHE_TTL:
                return cached_values

    from litellm.proxy.proxy_server import prisma_client  # noqa: PLC0415

    if prisma_client is None:
        raise RuntimeError(
            "MCP per-user env vars require a database connection, but none "
            "is configured. Connect a database to your proxy to use per-user "
            "MCP env vars."
        )

    from litellm.proxy._experimental.mcp_server.db import (  # noqa: PLC0415
        get_user_env_vars,
    )

    fresh_values = await get_user_env_vars(prisma_client, caller_id, server_id)
    _write_user_env_vars_cache(caller_id, server_id, fresh_values)
    return fresh_values
|
||||
|
||||
async def _create_mcp_client(
|
||||
self,
|
||||
server: MCPServer,
|
||||
@ -1732,10 +2009,17 @@ class MCPServerManager:
|
||||
client = None
|
||||
|
||||
try:
|
||||
if server.static_headers:
|
||||
# Tool *listing* must not be blocked by missing per-user env vars —
|
||||
# the server's tools should still appear so the client connects. The
|
||||
# friendly "missing vars" error is raised only on the tool-*call*
|
||||
# path (see _call_regular_mcp_tool).
|
||||
resolved_static_headers = await self._resolve_static_headers_with_env_vars(
|
||||
server, user_api_key_auth, raise_on_missing=False
|
||||
)
|
||||
if resolved_static_headers:
|
||||
if extra_headers is None:
|
||||
extra_headers = {}
|
||||
extra_headers.update(server.static_headers)
|
||||
extra_headers.update(resolved_static_headers)
|
||||
|
||||
# MCPJWTSigner: inject signed JWT for tools/list (list path skips pre_call_hook).
|
||||
# Skip entirely when the signer is not configured (avoid an unnecessary
|
||||
@ -3105,10 +3389,17 @@ class MCPServerManager:
|
||||
continue
|
||||
extra_headers[header] = header_value
|
||||
|
||||
if mcp_server.static_headers:
|
||||
# Interpolate env vars into static_headers. Raises
|
||||
# MCPMissingUserEnvVarsError when the calling user has not filled in
|
||||
# a required per-user variable — the REST layer converts that into
|
||||
# a friendly 412 with a setup URL.
|
||||
resolved_static_headers = await self._resolve_static_headers_with_env_vars(
|
||||
mcp_server, user_api_key_auth
|
||||
)
|
||||
if resolved_static_headers:
|
||||
if extra_headers is None:
|
||||
extra_headers = {}
|
||||
extra_headers.update(mcp_server.static_headers)
|
||||
extra_headers.update(resolved_static_headers)
|
||||
|
||||
if hook_extra_headers:
|
||||
if extra_headers is None:
|
||||
@ -3566,7 +3857,13 @@ class MCPServerManager:
|
||||
verbose_logger.debug(
|
||||
f"Building server from DB: {server.server_id} ({server.server_name})"
|
||||
)
|
||||
new_server = await self.build_mcp_server_from_table(server)
|
||||
# raw_rows come straight from the DB, so their global env var
|
||||
# values (like credentials) are still encrypted here, unlike the
|
||||
# already-decrypted records add_server/update_server are handed.
|
||||
# Decrypt them while building the registry entry.
|
||||
new_server = await self.build_mcp_server_from_table(
|
||||
server, env_vars_are_encrypted=True
|
||||
)
|
||||
# Carry the cached short_prefix from the previous registry entry
|
||||
# (if any) so the prefix is stable across reloads.
|
||||
if existing_server is not None and existing_server.short_prefix:
|
||||
@ -3906,11 +4203,21 @@ class MCPServerManager:
|
||||
and not server.authentication_token
|
||||
):
|
||||
should_skip_health_check = True
|
||||
# Skip if static_headers reference a per-user env var: a userless probe
|
||||
# can't fill ${NAME} and would forward the literal placeholder upstream,
|
||||
# flipping the server to unhealthy even though real user calls succeed.
|
||||
elif self._references_per_user_env_var(server):
|
||||
should_skip_health_check = True
|
||||
|
||||
if not should_skip_health_check:
|
||||
extra_headers = {}
|
||||
if server.static_headers:
|
||||
extra_headers.update(server.static_headers)
|
||||
resolved_static_headers = await self._resolve_static_headers_with_env_vars(
|
||||
server=server,
|
||||
user_api_key_auth=None,
|
||||
raise_on_missing=False,
|
||||
)
|
||||
extra_headers = (
|
||||
dict(resolved_static_headers) if resolved_static_headers else {}
|
||||
)
|
||||
|
||||
client = await self._create_mcp_client(
|
||||
server=server,
|
||||
@ -3960,6 +4267,7 @@ class MCPServerManager:
|
||||
extra_headers=server.extra_headers or [],
|
||||
mcp_info=server.mcp_info,
|
||||
static_headers=server.static_headers,
|
||||
env_vars=self._env_vars_to_models(server.env_vars),
|
||||
status=status,
|
||||
last_health_check=datetime.now(),
|
||||
health_check_error=health_check_error,
|
||||
@ -4033,6 +4341,14 @@ class MCPServerManager:
|
||||
|
||||
return list_mcp_servers
|
||||
|
||||
@staticmethod
def _env_vars_to_models(
    env_vars: Optional[List[Dict[str, Any]]],
) -> Optional[List[MCPEnvVar]]:
    """Validate raw env var entries into ``MCPEnvVar`` models.

    Returns ``None`` when ``env_vars`` is ``None`` (an empty list stays an
    empty list); otherwise one ``MCPEnvVar`` per entry, in order.
    """
    if env_vars is None:
        return None

    models: List[MCPEnvVar] = []
    for entry in env_vars:
        models.append(MCPEnvVar.model_validate(entry))
    return models
|
||||
|
||||
def _build_mcp_server_table(self, server: MCPServer) -> LiteLLM_MCPServerTable:
|
||||
return LiteLLM_MCPServerTable(
|
||||
server_id=server.server_id,
|
||||
@ -4053,6 +4369,7 @@ class MCPServerManager:
|
||||
extra_headers=server.extra_headers or [],
|
||||
mcp_info=server.mcp_info,
|
||||
static_headers=server.static_headers,
|
||||
env_vars=self._env_vars_to_models(server.env_vars),
|
||||
status=None, # No health check performed
|
||||
last_health_check=None, # No health check performed
|
||||
health_check_error=None,
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
@ -13,6 +14,7 @@ from typing import (
|
||||
Union,
|
||||
)
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Request, status
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
@ -20,7 +22,10 @@ from litellm.proxy._experimental.mcp_server.exceptions import MCPUpstreamAuthErr
|
||||
from litellm.proxy._experimental.mcp_server.ui_session_utils import (
|
||||
build_effective_auth_contexts,
|
||||
)
|
||||
from litellm.proxy._experimental.mcp_server.utils import merge_mcp_headers
|
||||
from litellm.proxy._experimental.mcp_server.utils import (
|
||||
MCPMissingUserEnvVarsError,
|
||||
merge_mcp_headers,
|
||||
)
|
||||
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.ip_address_utils import IPAddressUtils
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
@ -41,6 +46,28 @@ router = APIRouter(
|
||||
tags=["mcp"],
|
||||
)
|
||||
|
||||
|
||||
def _connection_error_message(exc: BaseException) -> str:
    """Map a connection failure to a short, user-facing explanation.

    The checks run in order and the first matching exception type wins, so
    the connect-specific entry is consulted before the generic timeout entry.
    Anything unrecognised gets a generic message pointing at the proxy logs.
    """
    fixed_messages = (
        (
            (httpx.LocalProtocolError,),
            "Failed to connect to MCP server: a request header is malformed. "
            "Check static headers for leading/trailing spaces or illegal characters.",
        ),
        (
            (httpx.ConnectError, httpx.ConnectTimeout),
            "Failed to connect to MCP server: the server is unreachable. "
            "Check the URL and that the server is running.",
        ),
        (
            (httpx.TimeoutException,),
            "Failed to connect to MCP server: the connection timed out.",
        ),
    )
    for exc_types, message in fixed_messages:
        if isinstance(exc, exc_types):
            return message

    if isinstance(exc, httpx.HTTPStatusError):
        status_code = exc.response.status_code
        return f"Failed to connect to MCP server: it returned HTTP " f"{status_code}."

    return "Failed to connect to MCP server. Check proxy logs for details."
|
||||
|
||||
|
||||
if MCP_AVAILABLE:
|
||||
from mcp.types import Tool as MCPTool
|
||||
|
||||
@ -812,6 +839,23 @@ if MCP_AVAILABLE:
|
||||
requested_server_id=canonical_server_id,
|
||||
)
|
||||
return result
|
||||
except MCPMissingUserEnvVarsError as e:
|
||||
verbose_logger.info(
|
||||
"MCP tool call missing per-user env vars: server_id=%s missing=%s",
|
||||
e.server_id,
|
||||
e.missing,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=412,
|
||||
detail={
|
||||
"error": "missing_user_env_vars",
|
||||
"message": str(e),
|
||||
"server_id": e.server_id,
|
||||
"server_name": e.server_name,
|
||||
"missing": e.missing,
|
||||
"setup_url": e.setup_url,
|
||||
},
|
||||
)
|
||||
except BlockedPiiEntityError as e:
|
||||
verbose_logger.error(f"BlockedPiiEntityError in MCP tool call: {str(e)}")
|
||||
raise HTTPException(
|
||||
@ -961,14 +1005,14 @@ if MCP_AVAILABLE:
|
||||
|
||||
return await operation(client)
|
||||
|
||||
except (KeyboardInterrupt, SystemExit):
|
||||
except (KeyboardInterrupt, SystemExit, asyncio.CancelledError):
|
||||
raise
|
||||
except BaseException as e:
|
||||
verbose_logger.error("Error in MCP operation: %s", e, exc_info=True)
|
||||
return {
|
||||
"status": "error",
|
||||
"error": True,
|
||||
"message": "Failed to connect to MCP server. Check proxy logs for details.",
|
||||
"message": _connection_error_message(e),
|
||||
}
|
||||
|
||||
async def _preview_openapi_tools(spec_path: str) -> dict:
|
||||
|
||||
@ -53,6 +53,7 @@ from litellm.proxy._experimental.mcp_server.utils import (
|
||||
LITELLM_MCP_SERVER_DESCRIPTION,
|
||||
LITELLM_MCP_SERVER_NAME,
|
||||
LITELLM_MCP_SERVER_VERSION,
|
||||
MCPMissingUserEnvVarsError,
|
||||
add_server_prefix_to_name,
|
||||
get_server_prefix,
|
||||
iter_known_server_prefixes,
|
||||
@ -720,6 +721,16 @@ if MCP_AVAILABLE:
|
||||
host_progress_callback=host_progress_callback,
|
||||
**data, # for logging
|
||||
)
|
||||
except MCPMissingUserEnvVarsError as e:
|
||||
verbose_logger.info(
|
||||
"MCP mcp_server_tool_call missing per-user env vars: server_id=%s missing=%s",
|
||||
e.server_id,
|
||||
e.missing,
|
||||
)
|
||||
return CallToolResult(
|
||||
content=[TextContent(text=str(e), type="text")],
|
||||
isError=True,
|
||||
)
|
||||
except BlockedPiiEntityError as e:
|
||||
verbose_logger.error(
|
||||
f"BlockedPiiEntityError in MCP tool call: {str(e)}"
|
||||
|
||||
@ -4,11 +4,23 @@ MCP Server Utilities
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import Any, Dict, Iterator, Mapping, Optional, Tuple, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Iterable,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import hashlib
|
||||
import importlib
|
||||
import os
|
||||
from urllib.parse import quote
|
||||
|
||||
# Constants
|
||||
LITELLM_MCP_SERVER_NAME = "litellm-mcp-server"
|
||||
@ -370,6 +382,130 @@ def validate_mcp_server_name(
|
||||
raise Exception(error_message)
|
||||
|
||||
|
||||
class MCPMissingUserEnvVarsError(Exception):
    """Signals that the calling user has not supplied required per-user env vars.

    The exception text is meant to be shown to the user: it names the server,
    lists each missing variable as a bullet, and ends with the setup URL. The
    constructor arguments are also kept as attributes (``server_id``,
    ``server_name``, ``missing``, ``setup_url``) for structured error handling.
    """

    def __init__(
        self,
        *,
        server_id: str,
        server_name: Optional[str],
        missing: List[str],
        setup_url: str,
    ) -> None:
        self.server_id = server_id
        self.server_name = server_name
        self.missing = missing
        self.setup_url = setup_url

        # Fall back to the id when the server has no (or an empty) name.
        label = server_name if server_name else server_id
        bullets = "\n".join(["- " + str(name) for name in missing])
        parts = [
            f'Cannot connect to MCP server "{label}".',
            "",
            "Your administrator configured this server to require per-user "
            "variables, but you haven't set the following yet:",
            bullets,
            "",
            "Set your credentials here:",
            f"{setup_url}",
        ]
        super().__init__("\n".join(parts))
|
||||
|
||||
|
||||
# Pattern for ``${NAME}`` substitution. NAME follows the usual env-var
# identifier rules: letters, digits and underscores, not starting with a
# digit. Capture group 1 is the bare NAME without the ``${`` / ``}`` wrapper.
_ENV_VAR_PATTERN = re.compile(r"\$\{([A-Za-z_][A-Za-z0-9_]*)\}")
|
||||
|
||||
|
||||
def parse_admin_env_vars(
    env_vars: Optional[Iterable[Any]],
) -> Tuple[Dict[str, str], List[Dict[str, Any]]]:
    """Partition admin-configured env var entries by scope.

    Each entry may be a dict or any object exposing ``model_dump()`` (e.g. a
    Pydantic model). Entries that are neither, or whose ``name`` is not a
    non-empty string, are ignored.

    Returns:
        A ``(global_values, user_specs)`` tuple:

        - ``global_values``: ``{name: value}`` for every entry whose scope is
          not ``"user"`` (a missing/empty scope counts as ``"global"``). The
          value is stringified; ``None`` becomes ``""``. A later entry with
          the same name overwrites an earlier one.
        - ``user_specs``: ``{"name", "description"}`` dicts, in input order,
          for entries with ``scope == "user"``.
    """
    globals_by_name: Dict[str, str] = {}
    per_user: List[Dict[str, Any]] = []

    for item in env_vars or ():
        if hasattr(item, "model_dump"):
            fields = item.model_dump()
        elif isinstance(item, dict):
            fields = item
        else:
            # Covers None as well as any other unsupported entry type.
            continue

        var_name = fields.get("name")
        if not (isinstance(var_name, str) and var_name):
            continue

        if (fields.get("scope") or "global") == "user":
            per_user.append(
                {"name": var_name, "description": fields.get("description")}
            )
            continue

        raw_value = fields.get("value")
        globals_by_name[var_name] = str(raw_value) if raw_value is not None else ""

    return globals_by_name, per_user
|
||||
|
||||
|
||||
def find_env_var_references(value: str) -> Set[str]:
    """Collect the distinct ``NAME`` identifiers of every ``${NAME}`` in ``value``.

    An empty (falsy) ``value`` yields an empty set.
    """
    if value:
        return {hit.group(1) for hit in _ENV_VAR_PATTERN.finditer(value)}
    return set()
|
||||
|
||||
|
||||
def collect_env_var_references(*, strings: Iterable[str]) -> Set[str]:
    """Return every ``${NAME}`` identifier referenced by any item of ``strings``.

    Items that are not ``str`` instances are ignored rather than rejected.
    """
    per_string = (
        find_env_var_references(text) for text in strings if isinstance(text, str)
    )
    merged: Set[str] = set()
    return merged.union(*per_string)
|
||||
|
||||
|
||||
def interpolate_env_vars(value: str, variables: Mapping[str, str]) -> str:
    """Replace ``${NAME}`` references in ``value`` using ``variables``.

    Names that are not keys of ``variables`` are left untouched so callers
    can detect them via ``find_env_var_references`` on the result if needed.

    Args:
        value: Text that may contain ``${NAME}`` placeholders.
        variables: Mapping from placeholder name to its replacement text.

    Returns:
        ``value`` with every known placeholder substituted. Empty values and
        values that are not strings are returned unchanged.
    """
    # A non-string value (e.g. a numeric header value loaded from a config
    # file) cannot contain placeholders. Returning it untouched mirrors the
    # isinstance guard in collect_env_var_references and avoids the TypeError
    # the regex substitution would raise on a non-string subject.
    if not value or not isinstance(value, str):
        return value

    def _sub(match: "re.Match[str]") -> str:
        name = match.group(1)
        if name in variables:
            return variables[name]
        # Unknown name: keep the original ``${NAME}`` text.
        return match.group(0)

    return _ENV_VAR_PATTERN.sub(_sub, value)
|
||||
|
||||
|
||||
def interpolate_headers(
    headers: Mapping[str, str], variables: Mapping[str, str]
) -> Dict[str, str]:
    """Build a new dict with each header value run through ``interpolate_env_vars``.

    Header names are kept as-is; the input mapping is not modified.
    """
    resolved: Dict[str, str] = {}
    for header_name, header_value in headers.items():
        resolved[header_name] = interpolate_env_vars(header_value, variables)
    return resolved
|
||||
|
||||
|
||||
def build_env_var_setup_url(server_id: str) -> str:
    """Build the UI link where a user fills in per-user env vars for a server.

    The ``server_id`` is percent-encoded (including ``/``) into the
    ``fill_env_vars`` query parameter. When the ``PROXY_BASE_URL`` environment
    variable is set, it is prefixed with trailing slashes removed; otherwise
    the relative path alone is returned.
    """
    encoded_id = quote(server_id, safe="")
    relative = f"/ui/?page=mcp-servers&fill_env_vars={encoded_id}"
    prefix = os.environ.get("PROXY_BASE_URL", "").rstrip("/")
    if not prefix:
        return relative
    return prefix + relative
|
||||
|
||||
|
||||
def merge_mcp_headers(
|
||||
*,
|
||||
extra_headers: Optional[Mapping[str, str]] = None,
|
||||
|
||||
@ -1265,6 +1265,34 @@ class MCPApprovalStatus(str, enum.Enum):
|
||||
rejected = "rejected"
|
||||
|
||||
|
||||
class MCPEnvVarScope(str, enum.Enum):
    """Scope for an MCP server environment variable.

    - ``global``: value is provided by the admin and used for all users.
    - ``user``: each user must provide their own value via the per-user
      env-var endpoint. The admin-supplied ``value`` is treated as a
      placeholder/hint and is not used at request time.
    """

    # Trailing underscore because ``global`` is a reserved word in Python; the
    # member's value is still the plain string "global".
    global_ = "global"
    user = "user"
|
||||
|
||||
|
||||
class MCPEnvVar(LiteLLMPydanticObjectBase):
    """One environment variable for an MCP server.

    Variables can be interpolated into ``static_headers`` using ``${NAME}``
    syntax. ``scope=global`` values are stored on the server. ``scope=user``
    values are stored per-user in ``LiteLLM_MCPUserEnvVars`` and supplied by
    each user.
    """

    # Identifier referenced as ``${name}`` inside static header values.
    name: str
    # Admin-supplied value; defaults to an empty string.
    value: str = ""
    # Defaults to global scope when the admin does not specify one.
    scope: MCPEnvVarScope = MCPEnvVarScope.global_
    # Optional free-text description of the variable.
    description: Optional[str] = None
|
||||
|
||||
|
||||
# MCP Proxy Request Types
|
||||
class NewMCPServerRequest(LiteLLMPydanticObjectBase):
|
||||
server_id: Optional[str] = None
|
||||
@ -1283,6 +1311,7 @@ class NewMCPServerRequest(LiteLLMPydanticObjectBase):
|
||||
tool_name_to_description: Optional[Dict[str, str]] = None
|
||||
extra_headers: Optional[List[str]] = None
|
||||
static_headers: Optional[Dict[str, str]] = None
|
||||
env_vars: Optional[List[MCPEnvVar]] = None
|
||||
instructions: Optional[str] = None
|
||||
# Stdio-specific fields
|
||||
command: Optional[str] = None
|
||||
@ -1369,6 +1398,7 @@ class UpdateMCPServerRequest(LiteLLMPydanticObjectBase):
|
||||
tool_name_to_description: Optional[Dict[str, str]] = None
|
||||
extra_headers: Optional[List[str]] = None
|
||||
static_headers: Optional[Dict[str, str]] = None
|
||||
env_vars: Optional[List[MCPEnvVar]] = None
|
||||
instructions: Optional[str] = None
|
||||
# Stdio-specific fields
|
||||
command: Optional[str] = None
|
||||
@ -1438,6 +1468,7 @@ class LiteLLM_MCPServerTable(LiteLLMPydanticObjectBase):
|
||||
extra_headers: List[str] = Field(default_factory=list)
|
||||
mcp_info: Optional[MCPInfo] = None
|
||||
static_headers: Optional[Dict[str, str]] = None
|
||||
env_vars: Optional[List[MCPEnvVar]] = None
|
||||
# Health check status
|
||||
status: Optional[Literal["healthy", "unhealthy", "unknown"]] = Field(
|
||||
default="unknown",
|
||||
@ -1519,6 +1550,35 @@ class MCPUserCredentialListItem(LiteLLMPydanticObjectBase):
|
||||
connected_at: Optional[str] = None # ISO-8601
|
||||
|
||||
|
||||
class MCPUserEnvVarsRequest(LiteLLMPydanticObjectBase):
    """Payload for storing the calling user's per-user env var values."""

    # Maps an env var name to the value the user supplies for it.
    values: Dict[str, str]
|
||||
|
||||
|
||||
class MCPUserEnvVarSpec(LiteLLMPydanticObjectBase):
    """Describes one per-user env var slot for the calling user.

    Stored values are write-only: the status only reports whether a value
    ``is_set`` and never echoes the decrypted secret back to the client.
    """

    # Name of the per-user env var.
    name: str
    # Optional description of the variable.
    description: Optional[str] = None
    # Whether the user has a stored value; the value itself is not a field of
    # this model.
    is_set: bool = False
|
||||
|
||||
|
||||
class MCPUserEnvVarsStatus(LiteLLMPydanticObjectBase):
    """Per-user env var status for a single MCP server."""

    server_id: str
    server_name: Optional[str] = None
    alias: Optional[str] = None
    # One spec per per-user env var slot on this server.
    required: List[MCPUserEnvVarSpec] = Field(default_factory=list)
    # NOTE(review): presumably the number of ``required`` entries whose
    # ``is_set`` is False -- confirm against the endpoint that fills it in.
    missing_count: int = 0
    setup_url: Optional[str] = None  # frontend URL where the user can fill these in
|
||||
|
||||
|
||||
class RejectMCPServerRequest(LiteLLMPydanticObjectBase):
|
||||
review_notes: Optional[str] = None
|
||||
|
||||
|
||||
@ -39,6 +39,7 @@ from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy._experimental.mcp_server.db import (
|
||||
rotate_mcp_server_credentials_master_key,
|
||||
rotate_mcp_user_credentials_master_key,
|
||||
rotate_mcp_user_env_vars_master_key,
|
||||
)
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy._types import LiteLLM_VerificationToken
|
||||
@ -4136,6 +4137,15 @@ async def _rotate_master_key( # noqa: PLR0915
|
||||
"Failed to rotate MCP user credentials: %s", str(e)
|
||||
)
|
||||
|
||||
# 4c. process MCP per-user environment variables table
|
||||
try:
|
||||
await rotate_mcp_user_env_vars_master_key(
|
||||
prisma_client=prisma_client,
|
||||
new_master_key=new_master_key,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.warning("Failed to rotate MCP user env vars: %s", str(e))
|
||||
|
||||
# 5. process credentials table
|
||||
try:
|
||||
credentials = await prisma_client.db.litellm_credentialstable.find_many()
|
||||
|
||||
@ -47,7 +47,10 @@ from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.constants import LITELLM_PROXY_ADMIN_NAME
|
||||
from litellm.proxy._experimental.mcp_server.utils import (
|
||||
build_env_var_setup_url,
|
||||
collect_env_var_references,
|
||||
get_server_prefix,
|
||||
parse_admin_env_vars,
|
||||
)
|
||||
from litellm.proxy._experimental.mcp_server.utils import (
|
||||
validate_and_normalize_mcp_server_payload as _base_validate_and_normalize_mcp_server_payload,
|
||||
@ -111,12 +114,16 @@ if MCP_AVAILABLE:
|
||||
create_mcp_server,
|
||||
delete_mcp_server,
|
||||
delete_user_credential,
|
||||
delete_user_env_vars,
|
||||
get_all_mcp_servers_for_user,
|
||||
get_mcp_server,
|
||||
get_mcp_servers,
|
||||
get_mcp_submissions,
|
||||
get_user_env_vars,
|
||||
get_user_env_vars_bulk,
|
||||
get_user_oauth_credential,
|
||||
list_user_oauth_credentials,
|
||||
merge_user_env_vars,
|
||||
reject_mcp_server,
|
||||
store_user_credential,
|
||||
store_user_oauth_credential,
|
||||
@ -139,6 +146,7 @@ if MCP_AVAILABLE:
|
||||
LitellmUserRoles,
|
||||
MakeMCPServersPublicRequest,
|
||||
MCPApprovalStatus,
|
||||
MCPEnvVarScope,
|
||||
MCPOAuthUserCredentialRequest,
|
||||
MCPOAuthUserCredentialStatus,
|
||||
MCPSubmissionsSummary,
|
||||
@ -146,6 +154,9 @@ if MCP_AVAILABLE:
|
||||
MCPUserCredentialListItem,
|
||||
MCPUserCredentialRequest,
|
||||
MCPUserCredentialResponse,
|
||||
MCPUserEnvVarSpec,
|
||||
MCPUserEnvVarsRequest,
|
||||
MCPUserEnvVarsStatus,
|
||||
NewMCPServerRequest,
|
||||
RejectMCPServerRequest,
|
||||
SpecialMCPServerName,
|
||||
@ -473,6 +484,27 @@ if MCP_AVAILABLE:
|
||||
) -> List[LiteLLM_MCPServerTable]:
|
||||
return [_redact_mcp_credentials(server) for server in mcp_servers]
|
||||
|
||||
def _redact_global_env_var_values(mcp_server: LiteLLM_MCPServerTable) -> None:
    """Erase, in place, the value of every ``scope="global"`` env var on a server.

    A global entry carries the admin's own plaintext secret (API key,
    password, ...), which non-admin callers must never receive. Entries with
    any other scope are left untouched: they only hold a placeholder that
    each user fills in for themselves.
    """
    entries = mcp_server.env_vars
    if not entries:
        # Nothing configured (None or empty list): nothing to redact.
        return
    for entry in entries:
        if not (entry.scope == MCPEnvVarScope.global_):
            continue
        entry.value = ""
|
||||
|
||||
def _user_is_full_admin(user_api_key_dict: UserAPIKeyAuth) -> bool:
    """Report whether the caller holds the ``PROXY_ADMIN`` role exactly.

    ``PROXY_ADMIN_VIEW_ONLY`` (and every other role) yields False. The admin
    edit form is pre-filled with global env var secrets, so only a full admin
    may see them; a read-only admin receives the same redacted view as any
    other caller who cannot manage the server.
    """
    caller_role = user_api_key_dict.user_role
    return caller_role == LitellmUserRoles.PROXY_ADMIN
|
||||
|
||||
def _is_restricted_virtual_key_request(user_api_key_dict: UserAPIKeyAuth) -> bool:
|
||||
"""Best-effort detection for route-restricted virtual keys.
|
||||
|
||||
@ -513,6 +545,11 @@ if MCP_AVAILABLE:
|
||||
sanitized.authorization_url = None
|
||||
sanitized.token_url = None
|
||||
sanitized.registration_url = None
|
||||
# Drop env vars entirely rather than only blanking global values: the
|
||||
# names alone (DB_PASSWORD, GITHUB_API_KEY, ...) leak what secrets the
|
||||
# admin configured. Non-admins get the per-user vars they must fill in
|
||||
# from the dedicated /user-env-vars/status endpoint instead.
|
||||
sanitized.env_vars = None
|
||||
return sanitized
|
||||
|
||||
def _sanitize_mcp_server_list_for_non_admin(
|
||||
@ -544,6 +581,7 @@ if MCP_AVAILABLE:
|
||||
sanitized.allowed_tools = []
|
||||
sanitized.mcp_access_groups = []
|
||||
sanitized.teams = []
|
||||
sanitized.env_vars = None
|
||||
|
||||
sanitized.authorization_url = None
|
||||
sanitized.token_url = None
|
||||
@ -976,6 +1014,10 @@ if MCP_AVAILABLE:
|
||||
if not _user_has_admin_view(user_api_key_dict):
|
||||
return _sanitize_mcp_server_list_for_non_admin(redacted_mcp_servers)
|
||||
|
||||
if not _user_is_full_admin(user_api_key_dict):
|
||||
for server in redacted_mcp_servers:
|
||||
_redact_global_env_var_values(server)
|
||||
|
||||
return redacted_mcp_servers
|
||||
|
||||
@router.get(
|
||||
@ -1144,7 +1186,11 @@ if MCP_AVAILABLE:
|
||||
"Database not connected. Connect a database to your proxy"
|
||||
)
|
||||
|
||||
return await get_mcp_submissions(prisma_client)
|
||||
submissions = await get_mcp_submissions(prisma_client)
|
||||
if not _user_is_full_admin(user_api_key_dict):
|
||||
for item in submissions.items:
|
||||
_redact_global_env_var_values(item)
|
||||
return submissions
|
||||
|
||||
@router.put(
|
||||
"/server/{server_id}/approve",
|
||||
@ -1363,6 +1409,8 @@ if MCP_AVAILABLE:
|
||||
return _sanitize_mcp_server_for_virtual_key(redacted)
|
||||
if not _user_has_admin_view(user_api_key_dict):
|
||||
return _sanitize_mcp_server_for_non_admin(redacted)
|
||||
if not _user_is_full_admin(user_api_key_dict):
|
||||
_redact_global_env_var_values(redacted)
|
||||
return redacted
|
||||
|
||||
@router.post(
|
||||
@ -1432,23 +1480,34 @@ if MCP_AVAILABLE:
|
||||
payload.submitted_by = None
|
||||
payload.submitted_at = None
|
||||
|
||||
# Attempt to create the mcp server
|
||||
# The database write is the commit point: if it fails nothing was
|
||||
# persisted and the request is a genuine failure.
|
||||
try:
|
||||
new_mcp_server = await create_mcp_server(
|
||||
prisma_client,
|
||||
payload,
|
||||
touched_by=user_api_key_dict.user_id or LITELLM_PROXY_ADMIN_NAME,
|
||||
)
|
||||
await global_mcp_server_manager.add_server(new_mcp_server)
|
||||
|
||||
# Ensure registry is up to date by reloading from database
|
||||
await global_mcp_server_manager.reload_servers_from_database()
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error creating mcp server: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail={"error": f"Error creating mcp server: {str(e)}"},
|
||||
)
|
||||
|
||||
# Registry refresh is best-effort: the row is already committed, so a
|
||||
# failure here (e.g. an unrelated malformed row in the table) must not
|
||||
# surface as a 500 and orphan the created server, which would push the
|
||||
# caller to retry and create duplicates.
|
||||
try:
|
||||
await global_mcp_server_manager.add_server(new_mcp_server)
|
||||
await global_mcp_server_manager.reload_servers_from_database()
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
f"MCP server {new_mcp_server.server_id} created but in-memory "
|
||||
f"registry refresh failed: {str(e)}"
|
||||
)
|
||||
|
||||
return _redact_mcp_credentials(new_mcp_server)
|
||||
|
||||
@router.post(
|
||||
@ -2106,6 +2165,249 @@ if MCP_AVAILABLE:
|
||||
)
|
||||
return items
|
||||
|
||||
# ── Per-user MCP env var endpoints ────────────────────────────────────────
|
||||
|
||||
async def _authorize_and_fetch_mcp_server(
    prisma_client,
    user_api_key_dict: UserAPIKeyAuth,
    server_id: str,
) -> LiteLLM_MCPServerTable:
    """Resolve ``server_id`` to a server record the caller is allowed to use.

    Callers with an admin view fetch the record directly and get a 404 when
    it does not exist. Everyone else is matched against the access-scoped
    listing of servers they can see, which avoids a second per-server query
    for a record the authorization step already loaded. A non-admin without
    access always receives 403, never 404, so server ids cannot be probed.

    Raises:
        HTTPException: 404 for an admin-view caller when the server does not
            exist; 403 for any other caller who cannot see the server.
    """
    if not _user_has_admin_view(user_api_key_dict):
        visible_servers = await get_all_mcp_servers_for_user(
            prisma_client, user_api_key_dict
        )
        matched = next(
            (
                candidate
                for candidate in visible_servers
                if candidate.server_id == server_id
            ),
            None,
        )
        if matched is not None:
            return matched
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail={
                "error": (
                    f"User does not have permission to access mcp server with id {server_id}. "
                    "You can only manage env vars for mcp servers that you have access to."
                )
            },
        )

    record = await get_mcp_server(prisma_client, server_id)
    if record is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail={"error": f"MCP Server {server_id} not found"},
        )
    return record
|
||||
|
||||
def _compute_user_env_var_status(
    *,
    server: LiteLLM_MCPServerTable,
    stored_values: Dict[str, str],
) -> MCPUserEnvVarsStatus:
    """Build a status object for one server given the user's stored values.

    Stored credentials are write-only: the response reports only whether
    each value ``is_set`` and never echoes the decrypted secret back, so a
    leaked token can't be used to exfiltrate the raw upstream credential.

    Args:
        server: The MCP server record; its ``env_vars`` and ``static_headers``
            attributes are read (both may be absent or None).
        stored_values: The calling user's stored values keyed by var name.

    Returns:
        An ``MCPUserEnvVarsStatus`` listing each per-user var the user still
        has to supply (``required``), how many of those are unset
        (``missing_count``), and a ``setup_url`` when ``required`` is
        non-empty.
    """
    global_values, user_specs = parse_admin_env_vars(
        getattr(server, "env_vars", None)
    )
    # An empty-valued global is not a usable fallback, so it must not mark a
    # referenced per-user var as covered, matching the empty-global filter in
    # _resolve_static_headers_with_env_vars. Otherwise this endpoint reports no
    # credential needed for a var every tool call still 412s on.
    global_values = {name: value for name, value in global_values.items() if value}

    # A var only blocks when it's referenced by static_headers and has no
    # admin global fallback, mirroring _resolve_static_headers_with_env_vars
    # (globals win the merge) so the status endpoint never asks the user for
    # credentials a tool call wouldn't actually require.
    static_headers = getattr(server, "static_headers", None) or {}
    if isinstance(static_headers, str):
        try:
            static_headers = json.loads(static_headers) or {}
        except (ValueError, TypeError):
            static_headers = {}
    # A stored value that is (or decodes to) a list, number or bare string has
    # no ``.values()``; treat it as "no headers" rather than letting the
    # AttributeError turn every status request for this server into a 500.
    if not callable(getattr(static_headers, "values", None)):
        static_headers = {}
    referenced = collect_env_var_references(strings=static_headers.values())
    user_var_names = {spec["name"] for spec in user_specs}
    blocking = {
        name for name in (referenced & user_var_names) if name not in global_values
    }

    required: List[MCPUserEnvVarSpec] = []
    missing_count = 0
    for spec in user_specs:
        name = spec["name"]
        if name not in blocking:
            continue
        value = stored_values.get(name)
        # An empty stored string counts as unset, same as a missing key.
        is_set = bool(value)
        if not is_set:
            missing_count += 1
        required.append(
            MCPUserEnvVarSpec(
                name=name,
                description=spec.get("description"),
                is_set=is_set,
            )
        )

    return MCPUserEnvVarsStatus(
        server_id=server.server_id,
        server_name=getattr(server, "server_name", None),
        alias=getattr(server, "alias", None),
        required=required,
        missing_count=missing_count,
        setup_url=build_env_var_setup_url(server.server_id) if required else None,
    )
|
||||
|
||||
@router.get(
    "/server/{server_id}/user-env-vars",
    description="Return the calling user's per-user MCP env var status for this server.",
    dependencies=[Depends(user_api_key_auth)],
    response_model=MCPUserEnvVarsStatus,
)
@management_endpoint_wrapper
async def get_mcp_user_env_vars(
    server_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> MCPUserEnvVarsStatus:
    """Report which per-user env vars the caller has set for one server.

    Raises:
        HTTPException: 400 when the auth token carries no user id; whatever
            ``_authorize_and_fetch_mcp_server`` raises when access is denied.
    """
    db_client = get_prisma_client_or_throw(
        "Database not connected. Connect a database to your proxy"
    )
    caller_id = user_api_key_dict.user_id
    if not caller_id:
        # Per-user values are keyed by user id, so an anonymous key can't have any.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": "User ID not found in token"},
        )
    target_server = await _authorize_and_fetch_mcp_server(
        db_client, user_api_key_dict, server_id
    )
    saved_values = await get_user_env_vars(db_client, caller_id, server_id)
    return _compute_user_env_var_status(
        server=target_server, stored_values=saved_values
    )
|
||||
|
||||
@router.post(
    "/server/{server_id}/user-env-vars",
    description=(
        "Store the calling user's per-user MCP env var values for this "
        "server. Submitted values are merged over any previously stored "
        "values, so you only send the fields you want to set or change; a "
        "variable omitted (or sent empty) keeps its stored value. Use "
        "DELETE to clear all stored values."
    ),
    dependencies=[Depends(user_api_key_auth)],
    response_model=MCPUserEnvVarsStatus,
)
@management_endpoint_wrapper
async def store_mcp_user_env_vars(
    server_id: str,
    payload: MCPUserEnvVarsRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> MCPUserEnvVarsStatus:
    """Merge the caller's submitted env var values into their stored set.

    Raises:
        HTTPException: 400 when the auth token carries no user id; whatever
            ``_authorize_and_fetch_mcp_server`` raises when access is denied.
    """
    db_client = get_prisma_client_or_throw(
        "Database not connected. Connect a database to your proxy"
    )
    caller_id = user_api_key_dict.user_id
    if not caller_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": "User ID not found in token"},
        )
    target_server = await _authorize_and_fetch_mcp_server(
        db_client, user_api_key_dict, server_id
    )

    # Accept only the per-user names the admin declared; keys the user makes
    # up are dropped, never persisted. Because stored values are write-only
    # (never shown back), the submission is merged over what is already
    # stored, and an omitted or empty field leaves the stored value alone so
    # changing one credential doesn't force re-entering the rest.
    _, per_user_specs = parse_admin_env_vars(
        getattr(target_server, "env_vars", None)
    )
    allowed_names = {declared["name"] for declared in per_user_specs}
    accepted = {}
    for var_name, var_value in payload.values.items():
        if var_name in allowed_names and var_value != "":
            accepted[var_name] = var_value

    merged_values = await merge_user_env_vars(
        db_client, caller_id, server_id, accepted, allowed_names
    )

    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        invalidate_user_env_vars_cache,
    )

    invalidate_user_env_vars_cache(caller_id, server_id)
    return _compute_user_env_var_status(
        server=target_server, stored_values=merged_values
    )
|
||||
|
||||
@router.delete(
    "/server/{server_id}/user-env-vars",
    description="Clear the calling user's per-user MCP env var values for this server.",
    dependencies=[Depends(user_api_key_auth)],
    response_model=MCPUserEnvVarsStatus,
)
@management_endpoint_wrapper
async def clear_mcp_user_env_vars(
    server_id: str,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> MCPUserEnvVarsStatus:
    """Delete everything the caller stored for this server and report the result.

    Raises:
        HTTPException: 400 when the auth token carries no user id; whatever
            ``_authorize_and_fetch_mcp_server`` raises when access is denied.
    """
    db_client = get_prisma_client_or_throw(
        "Database not connected. Connect a database to your proxy"
    )
    caller_id = user_api_key_dict.user_id
    if not caller_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail={"error": "User ID not found in token"},
        )
    target_server = await _authorize_and_fetch_mcp_server(
        db_client, user_api_key_dict, server_id
    )
    await delete_user_env_vars(db_client, caller_id, server_id)

    from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
        invalidate_user_env_vars_cache,
    )

    invalidate_user_env_vars_cache(caller_id, server_id)
    # The status is computed against an empty value set, since everything
    # stored for this user/server pair was just deleted.
    nothing_stored = {}
    return _compute_user_env_var_status(
        server=target_server, stored_values=nothing_stored
    )
|
||||
|
||||
@router.get(
    "/user-env-vars/status",
    description="Per-user MCP env var status across every server the user can access. "
    "Used by the dashboard to highlight servers with missing per-user vars.",
    dependencies=[Depends(user_api_key_auth)],
    response_model=List[MCPUserEnvVarsStatus],
)
@management_endpoint_wrapper
async def list_mcp_user_env_var_status(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> List[MCPUserEnvVarsStatus]:
    """List env var status for each accessible server that needs per-user vars.

    Servers whose computed status has no ``required`` entries are left out.
    Unlike the per-server endpoints, a token without a user id yields an
    empty list rather than an error.
    """
    db_client = get_prisma_client_or_throw(
        "Database not connected. Connect a database to your proxy"
    )
    caller_id = user_api_key_dict.user_id
    if not caller_id:
        return []
    visible_servers = await get_all_mcp_servers_for_user(
        db_client, user_api_key_dict
    )
    if not visible_servers:
        return []

    # One bulk read for every visible server instead of a query per server.
    visible_ids = [entry.server_id for entry in visible_servers]
    stored_by_server = await get_user_env_vars_bulk(db_client, caller_id, visible_ids)

    computed = (
        _compute_user_env_var_status(
            server=entry,
            stored_values=stored_by_server.get(entry.server_id, {}),
        )
        for entry in visible_servers
    )
    return [server_status for server_status in computed if server_status.required]
|
||||
|
||||
@router.put(
|
||||
"/server",
|
||||
description="Allows deleting mcp serves in the db",
|
||||
|
||||
@ -5,6 +5,7 @@ from functools import wraps
|
||||
from typing import Any, Callable, List, Optional, Tuple
|
||||
|
||||
from fastapi import HTTPException, Request
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
@ -436,6 +437,58 @@ async def send_management_endpoint_alert(
|
||||
)
|
||||
|
||||
|
||||
def _redacted_env_var(entry: Any) -> dict:
|
||||
get = entry.get if isinstance(entry, dict) else lambda k: getattr(entry, k, None)
|
||||
return {
|
||||
"name": get("name"),
|
||||
"scope": get("scope"),
|
||||
"description": get("description"),
|
||||
"value": "",
|
||||
}
|
||||
|
||||
|
||||
def _redact_record_env_vars(record: Any) -> Any:
    """Produce a version of ``record`` whose ``env_vars[].value`` are blank.

    The record is the same object the endpoint hands back to the caller, so
    the redaction is applied to a copy and the original is never mutated. A
    record whose ``env_vars`` is not a list comes back as-is, as does a
    record that is neither a ``dict`` nor a pydantic ``BaseModel``.
    """
    record_is_dict = isinstance(record, dict)
    if record_is_dict:
        raw_entries = record.get("env_vars")
    else:
        raw_entries = getattr(record, "env_vars", None)

    if not isinstance(raw_entries, list):
        return record

    blanked = list(map(_redacted_env_var, raw_entries))
    if record_is_dict:
        return {**record, "env_vars": blanked}
    if isinstance(record, BaseModel):
        return record.model_copy(update={"env_vars": blanked})
    return record
|
||||
|
||||
|
||||
def _redact_env_var_values(response: dict) -> None:
    """Scrub ``env_vars[].value`` out of a management response, in place.

    MCP endpoints hand decrypted ``scope="global"`` values back so the admin
    UI can pre-fill its edit form. Those values are upstream credentials and
    must not be copied verbatim into OTEL spans, where anyone with access to
    observability data could read them. They can appear at the top level
    (single-server create/update) or inside each element of ``items`` (the
    submissions queue); both places are scrubbed. Names, scopes and
    descriptions are preserved so traces remain useful.
    """
    top_level_entries = response.get("env_vars")
    if isinstance(top_level_entries, list):
        response["env_vars"] = list(map(_redacted_env_var, top_level_entries))

    nested_records = response.get("items")
    if isinstance(nested_records, list):
        response["items"] = list(map(_redact_record_env_vars, nested_records))
|
||||
|
||||
|
||||
async def _emit_management_endpoint_otel_span(
|
||||
func: Callable,
|
||||
kwargs: dict,
|
||||
@ -497,6 +550,7 @@ async def _emit_management_endpoint_otel_span(
|
||||
try:
|
||||
raw = dict(result)
|
||||
_response = {k: v for k, v in raw.items() if k not in _CREDENTIAL_FIELDS}
|
||||
_redact_env_var_values(_response)
|
||||
except Exception:
|
||||
_response = None
|
||||
|
||||
|
||||
@ -311,6 +311,11 @@ model LiteLLM_MCPServerTable {
|
||||
tool_name_to_description Json? @default("{}")
|
||||
extra_headers String[] @default([])
|
||||
static_headers Json? @default("{}")
|
||||
// Admin-configured environment variables interpolated into static_headers
|
||||
// via ${NAME} syntax. Stored as an array of
|
||||
// {name, value, scope, description}. scope is "global" (value used as-is)
|
||||
// or "user" (value supplied per-user via LiteLLM_MCPUserEnvVars).
|
||||
env_vars Json? @default("[]")
|
||||
// Health check status
|
||||
status String? @default("unknown")
|
||||
last_health_check DateTime?
|
||||
@ -366,6 +371,21 @@ model LiteLLM_MCPUserCredentials {
|
||||
@@unique([user_id, server_id])
|
||||
}
|
||||
|
||||
// Per-user environment variable values for MCP servers.
|
||||
// values_b64 is an encrypted JSON object: {VAR_NAME: "value", ...}.
|
||||
model LiteLLM_MCPUserEnvVars {
|
||||
// Primary key; defaults to a generated UUID.
id String @id @default(uuid())
|
||||
// User the stored values belong to.
user_id String
|
||||
// MCP server the stored values apply to. No @relation is declared in this
// model, so the schema itself does not tie this to a server row.
server_id String
|
||||
// Encrypted JSON object of {VAR_NAME: "value", ...} for this user/server.
values_b64 String
|
||||
created_at DateTime @default(now())
|
||||
// @updatedAt makes Prisma refresh this timestamp whenever the row is updated.
updated_at DateTime @default(now()) @updatedAt
|
||||
|
||||
// At most one row per (user_id, server_id) pair.
@@unique([user_id, server_id])
|
||||
@@index([user_id])
|
||||
@@index([server_id])
|
||||
}
|
||||
|
||||
// Generate Tokens for Proxy
|
||||
model LiteLLM_VerificationToken {
|
||||
token String @id
|
||||
|
||||
@ -42,6 +42,10 @@ class MCPServer(BaseModel):
|
||||
static_headers: Optional[Dict[str, str]] = (
|
||||
None # static headers to forward to the MCP server
|
||||
)
|
||||
# Admin-configured env vars. Each entry is {name, value, scope, description}.
|
||||
# scope=="global" values are interpolated into static_headers using ${NAME}.
|
||||
# scope=="user" values must be supplied per-user.
|
||||
env_vars: Optional[List[Dict[str, Any]]] = None
|
||||
# OAuth-specific fields
|
||||
client_id: Optional[str] = None
|
||||
client_secret: Optional[str] = None
|
||||
|
||||
@ -311,6 +311,11 @@ model LiteLLM_MCPServerTable {
|
||||
tool_name_to_description Json? @default("{}")
|
||||
extra_headers String[] @default([])
|
||||
static_headers Json? @default("{}")
|
||||
// Admin-configured environment variables interpolated into static_headers
|
||||
// via ${NAME} syntax. Stored as an array of
|
||||
// {name, value, scope, description}. scope is "global" (value used as-is)
|
||||
// or "user" (value supplied per-user via LiteLLM_MCPUserEnvVars).
|
||||
env_vars Json? @default("[]")
|
||||
// Health check status
|
||||
status String? @default("unknown")
|
||||
last_health_check DateTime?
|
||||
@ -366,6 +371,21 @@ model LiteLLM_MCPUserCredentials {
|
||||
@@unique([user_id, server_id])
|
||||
}
|
||||
|
||||
// Per-user environment variable values for MCP servers.
|
||||
// values_b64 is an encrypted JSON object: {VAR_NAME: "value", ...}.
|
||||
model LiteLLM_MCPUserEnvVars {
|
||||
// Primary key; defaults to a generated UUID.
id String @id @default(uuid())
|
||||
// User the stored values belong to.
user_id String
|
||||
// MCP server the stored values apply to. No @relation is declared in this
// model, so the schema itself does not tie this to a server row.
server_id String
|
||||
// Encrypted JSON object of {VAR_NAME: "value", ...} for this user/server.
values_b64 String
|
||||
created_at DateTime @default(now())
|
||||
// @updatedAt makes Prisma refresh this timestamp whenever the row is updated.
updated_at DateTime @default(now()) @updatedAt
|
||||
|
||||
// At most one row per (user_id, server_id) pair.
@@unique([user_id, server_id])
|
||||
@@index([user_id])
|
||||
@@index([server_id])
|
||||
}
|
||||
|
||||
// Generate Tokens for Proxy
|
||||
model LiteLLM_VerificationToken {
|
||||
token String @id
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
import ssl
|
||||
import sys
|
||||
@ -10,10 +11,26 @@ import pytest
|
||||
sys.path.insert(0, "../../../")
|
||||
|
||||
import litellm.experimental_mcp_client.client as mcp_client_module
|
||||
from litellm.experimental_mcp_client.client import MCPClient
|
||||
from litellm.experimental_mcp_client.client import (
|
||||
MCPClient,
|
||||
_first_non_cancelled_cause,
|
||||
)
|
||||
from litellm.types.mcp import MCPAuth, MCPStdioConfig, MCPTransport
|
||||
|
||||
|
||||
class _FakeExceptionGroup(Exception):
|
||||
"""Duck-typed stand-in for an anyio/builtin ExceptionGroup.
|
||||
|
||||
The production unwrapper reads ``.exceptions`` rather than depending on the
|
||||
builtin ``ExceptionGroup`` type, so this exercises the same code path on
|
||||
every Python version.
|
||||
"""
|
||||
|
||||
def __init__(self, message, exceptions):
|
||||
super().__init__(message)
|
||||
self.exceptions = tuple(exceptions)
|
||||
|
||||
|
||||
class TestMCPClient:
|
||||
"""Test MCP Client stdio functionality"""
|
||||
|
||||
@ -307,6 +324,26 @@ class TestMCPClient:
|
||||
assert headers["Authorization"] == "token my-token"
|
||||
assert headers["X-Custom-Header"] == "custom-value"
|
||||
|
||||
def test_get_auth_headers_strips_static_header_whitespace(self):
    """
    Surrounding whitespace is removed from static header names and values.

    h11 treats a header value with leading/trailing whitespace as an
    "Illegal header value" and the MCP connection is silently aborted.
    Without stripping, one stray space in a configured static header would
    make every request to that server fail with an opaque error.
    """
    padded_headers = {"X-Db-Url": " mew://host ", " X-Pad ": "v"}
    client = MCPClient(
        server_url="http://example.com/mcp",
        transport_type="http",
        extra_headers=padded_headers,
    )

    resolved = client._get_auth_headers()

    assert resolved["X-Db-Url"] == "mew://host"
    assert resolved["X-Pad"] == "v"
|
||||
|
||||
def test_token_auth_enum_value(self):
|
||||
"""Test that MCPAuth.token enum exists and has correct value"""
|
||||
assert hasattr(MCPAuth, "token")
|
||||
@ -388,5 +425,123 @@ class TestMCPClientInstructionsCapture:
|
||||
assert client._last_initialize_instructions is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Transport error surfacing
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFirstNonCancelledCause:
    """Digging the real failure out of a (possibly nested) exception group."""

    def test_returns_plain_non_cancelled(self):
        # A bare, non-cancellation exception is its own cause.
        original = ValueError("boom")
        found = _first_non_cancelled_cause(original)
        assert found is original

    def test_returns_none_for_plain_cancelled(self):
        cancelled = asyncio.CancelledError()
        assert _first_non_cancelled_cause(cancelled) is None

    def test_unwraps_group_to_non_cancelled_leaf(self):
        real_failure = httpx.ConnectError("refused")
        members = [asyncio.CancelledError(), real_failure]
        wrapped = _FakeExceptionGroup("g", members)
        assert _first_non_cancelled_cause(wrapped) is real_failure

    def test_unwraps_nested_group(self):
        # The real failure sits two groups deep, behind cancellations at each level.
        real_failure = httpx.LocalProtocolError("Illegal header value")
        inner_group = _FakeExceptionGroup(
            "inner", [asyncio.CancelledError(), real_failure]
        )
        outer_group = _FakeExceptionGroup(
            "outer", [asyncio.CancelledError(), inner_group]
        )
        assert _first_non_cancelled_cause(outer_group) is real_failure

    def test_all_cancelled_returns_none(self):
        only_cancellations = [asyncio.CancelledError(), asyncio.CancelledError()]
        wrapped = _FakeExceptionGroup("g", only_cancellations)
        assert _first_non_cancelled_cause(wrapped) is None

    @pytest.mark.skipif(
        sys.version_info < (3, 11), reason="builtin ExceptionGroup requires 3.11+"
    )
    def test_unwraps_builtin_exception_group(self):
        real_failure = httpx.ConnectError("refused")
        builtin_group = ExceptionGroup("transport failed", [real_failure])  # noqa: F821
        assert _first_non_cancelled_cause(builtin_group) is real_failure
|
||||
|
||||
|
||||
class TestExecuteSessionOperationSurfacesTransportError:
    """_execute_session_operation must report the transport's real failure.

    If the upstream transport's task group fails (illegal header, connection
    refused, ...), the pending ``session.initialize()`` gets cancelled and
    the actual error only shows up when the transport context exits. That
    opaque ``CancelledError`` has to be replaced by the real cause.
    """

    def _make_session(self, mock_session_cls, initialize):
        # Make the patched ClientSession(...) return an async context manager
        # that yields a session using the supplied ``initialize`` mock.
        fake_session = AsyncMock()
        fake_session.initialize = initialize
        context = MagicMock()
        context.__aenter__ = AsyncMock(return_value=fake_session)
        context.__aexit__ = AsyncMock(return_value=False)
        mock_session_cls.return_value = context

    def _make_transport(self, aexit_side_effect):
        # Transport context that opens normally and raises
        # ``aexit_side_effect`` when it is exited.
        context = MagicMock()
        stream_pair = (MagicMock(), MagicMock())
        context.__aenter__ = AsyncMock(return_value=stream_pair)
        context.__aexit__ = AsyncMock(side_effect=aexit_side_effect)
        return context

    @pytest.mark.asyncio
    @patch("litellm.experimental_mcp_client.client.ClientSession")
    async def test_surfaces_connect_error_over_cancelled(self, mock_session_cls):
        client = MCPClient(server_url="http://example.com/mcp", transport_type="http")
        cancelled_initialize = AsyncMock(
            side_effect=asyncio.CancelledError("cancelled by group")
        )
        self._make_session(mock_session_cls, cancelled_initialize)
        real_failure = httpx.ConnectError("All connection attempts failed")
        transport = self._make_transport(
            _FakeExceptionGroup("transport", [real_failure])
        )

        async def _operation(session):
            return "done"

        with pytest.raises(httpx.ConnectError):
            await client._execute_session_operation(transport, _operation)

    @pytest.mark.asyncio
    @patch("litellm.experimental_mcp_client.client.ClientSession")
    async def test_genuine_cancellation_is_not_replaced(self, mock_session_cls):
        client = MCPClient(server_url="http://example.com/mcp", transport_type="http")
        cancelled_initialize = AsyncMock(side_effect=asyncio.CancelledError())
        self._make_session(mock_session_cls, cancelled_initialize)
        # The teardown group holds nothing but cancellations, so there is no
        # real cause to substitute.
        transport = self._make_transport(
            _FakeExceptionGroup("teardown", [asyncio.CancelledError()])
        )

        async def _operation(session):
            return "done"

        with pytest.raises(asyncio.CancelledError):
            await client._execute_session_operation(transport, _operation)

    @pytest.mark.asyncio
    @patch("litellm.experimental_mcp_client.client.ClientSession")
    async def test_cleanup_error_after_success_is_swallowed(self, mock_session_cls):
        client = MCPClient(server_url="http://example.com/mcp", transport_type="http")
        handshake = MagicMock()
        handshake.instructions = None
        self._make_session(mock_session_cls, AsyncMock(return_value=handshake))
        late_failure = httpx.ConnectError("late cleanup error")
        transport = self._make_transport(_FakeExceptionGroup("late", [late_failure]))

        async def _operation(session):
            return "done"

        outcome = await client._execute_session_operation(transport, _operation)
        assert outcome == "done"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
|
||||
@ -200,7 +200,12 @@ def test_tokenizers():
|
||||
model="meta-llama/llama-3-70b-instruct", text=sample_text
|
||||
)
|
||||
|
||||
llama3_tokenizer = create_pretrained_tokenizer("Xenova/llama-3-tokenizer")
|
||||
try:
|
||||
llama3_tokenizer = create_pretrained_tokenizer("Xenova/llama-3-tokenizer")
|
||||
except Exception as e:
|
||||
pytest.skip(
|
||||
f"custom tokenizer download failed (HF hub unreachable): {e}"
|
||||
)
|
||||
llama3_tokens_2 = token_counter(
|
||||
custom_tokenizer=llama3_tokenizer, text=sample_text
|
||||
)
|
||||
|
||||
@ -20,11 +20,14 @@ from litellm.proxy._experimental.mcp_server.db import (
|
||||
get_user_oauth_credential,
|
||||
list_user_oauth_credentials,
|
||||
rotate_mcp_user_credentials_master_key,
|
||||
rotate_mcp_user_env_vars_master_key,
|
||||
store_user_credential,
|
||||
store_user_oauth_credential,
|
||||
)
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import decrypt_value_helper
|
||||
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import (
|
||||
decrypt_value_helper,
|
||||
encrypt_value_helper,
|
||||
)
|
||||
|
||||
SALT_KEY = "test-salt-key-for-byok-credential-tests-1234"
|
||||
|
||||
@ -400,3 +403,69 @@ async def test_rotate_skips_undecodable_rows():
|
||||
assert prisma.db.litellm_mcpusercredentials.update.call_count == 1
|
||||
where = prisma.db.litellm_mcpusercredentials.update.call_args.kwargs["where"]
|
||||
assert where["user_id_server_id"]["server_id"] == "srv-ok"
|
||||
|
||||
|
||||
# ── per-user env-var rotation ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _env_var_row(values_b64: str, user_id="alice", server_id="srv-1"):
|
||||
row = MagicMock()
|
||||
row.values_b64 = values_b64
|
||||
row.user_id = user_id
|
||||
row.server_id = server_id
|
||||
return row
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rotate_user_env_vars_re_encrypts_with_new_key(monkeypatch):
    """Rotated env var ciphertext must decrypt under the new master key."""
    # Start from a row encrypted under the current salt.
    plaintext = {"API_KEY": "sk-secret", "REGION": "us-east-1"}
    old_ciphertext = encrypt_value_helper(json.dumps(plaintext))

    prisma = MagicMock()
    table = prisma.db.litellm_mcpuserenvvars
    table.find_many = AsyncMock(return_value=[_env_var_row(old_ciphertext)])
    table.update = AsyncMock()

    new_master_key = "rotated-env-key-1111-2222-3333-4444"
    await rotate_mcp_user_env_vars_master_key(
        prisma_client=prisma, new_master_key=new_master_key
    )

    written = table.update.call_args.kwargs["data"]
    new_ciphertext = written["values_b64"]
    assert new_ciphertext != old_ciphertext, "rotation must produce different ciphertext"

    # Switch the salt to the NEW key and confirm the stored value round-trips.
    monkeypatch.setenv("LITELLM_SALT_KEY", new_master_key)
    recovered = decrypt_value_helper(
        value=new_ciphertext,
        key="mcp_user_env_vars",
        exception_type="debug",
        return_original_value=False,
    )
    assert json.loads(recovered) == plaintext
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_rotate_user_env_vars_skips_undecryptable_rows():
    """A corrupt row is skipped, not overwritten, and does not abort rotation.

    Skipping keeps potentially recoverable data intact while the remaining
    rows are still rotated.
    """
    decryptable_row = _env_var_row(
        encrypt_value_helper(json.dumps({"A": "1"})), server_id="srv-ok"
    )
    corrupt_row = _env_var_row("!!! not encrypted !!!", server_id="srv-corrupt")

    prisma = MagicMock()
    env_vars_table = prisma.db.litellm_mcpuserenvvars
    # The corrupt row comes first so a hard failure on it would be visible.
    env_vars_table.find_many = AsyncMock(return_value=[corrupt_row, decryptable_row])
    env_vars_table.update = AsyncMock()

    await rotate_mcp_user_env_vars_master_key(
        prisma_client=prisma, new_master_key="new-key-xxxx"
    )

    assert env_vars_table.update.call_count == 1
    updated_where = env_vars_table.update.call_args.kwargs["where"]
    assert updated_where["user_id_server_id"]["server_id"] == "srv-ok"
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -3934,7 +3934,7 @@ class TestMCPServerManagerReload:
|
||||
):
|
||||
await manager.reload_servers_from_database()
|
||||
|
||||
mock_build.assert_awaited_once_with(db_row)
|
||||
mock_build.assert_awaited_once_with(db_row, env_vars_are_encrypted=True)
|
||||
assert manager.registry["server-1"] is rebuilt_server
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -3965,7 +3965,7 @@ class TestMCPServerManagerReload:
|
||||
updated_at=timestamp,
|
||||
)
|
||||
|
||||
async def build_server(db_row):
|
||||
async def build_server(db_row, **kwargs):
|
||||
if db_row.server_id == "bad-server":
|
||||
raise RuntimeError("transient build failure")
|
||||
if db_row.server_id == "healthy-server":
|
||||
@ -4031,7 +4031,7 @@ class TestMCPServerManagerReload:
|
||||
updated_at=timestamp,
|
||||
)
|
||||
|
||||
async def build_server(db_row):
|
||||
async def build_server(db_row, **kwargs):
|
||||
if db_row.server_id == "healthy-server":
|
||||
return healthy_server
|
||||
return bad_openapi_server
|
||||
|
||||
@ -29,10 +29,13 @@ from mcp.types import Tool as MCPTool
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
|
||||
MCPServerManager,
|
||||
_deserialize_json_dict,
|
||||
_deserialize_json_list,
|
||||
)
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_MCPServerTable,
|
||||
MCPApprovalStatus,
|
||||
MCPEnvVar,
|
||||
MCPEnvVarScope,
|
||||
MCPTransport,
|
||||
)
|
||||
from litellm.types.mcp import MCPAuth
|
||||
@ -3057,6 +3060,57 @@ class TestMCPServerTimestamps:
|
||||
assert rebuilt_table.created_at == created
|
||||
assert rebuilt_table.updated_at == updated
|
||||
|
||||
def test_deserialize_json_list_normalizes_pydantic_models(self):
    """``_deserialize_json_list`` must return plain dicts for ``MCPEnvVar`` input.

    Prisma hydrates the ``env_vars`` JSON column into ``MCPEnvVar`` models,
    while ``MCPServer`` types the field as ``List[Dict[str, Any]]``, so the
    helper has to normalize the models for validation to succeed.
    """
    user_scoped = MCPEnvVar(
        name="GITHUB_TOKEN", scope=MCPEnvVarScope.user, description="PAT"
    )
    global_scoped = MCPEnvVar(
        name="REGION", value="us-east-1", scope=MCPEnvVarScope.global_
    )

    normalized = _deserialize_json_list([user_scoped, global_scoped])

    assert normalized is not None
    for item in normalized:
        assert isinstance(item, dict)
    first, second = normalized[0], normalized[1]
    assert first["name"] == "GITHUB_TOKEN"
    assert first["scope"] == "user"
    assert second["value"] == "us-east-1"
|
||||
|
||||
@pytest.mark.asyncio
async def test_build_mcp_server_from_table_with_model_env_vars(self):
    """Regression: model-typed ``env_vars`` on a DB row must still build a server.

    A row whose ``env_vars`` is a list of ``MCPEnvVar`` models (as Prisma
    returns) has to become an ``MCPServer``; previously a Pydantic
    ``dict_type`` validation error silently dropped the server from the
    registry.
    """
    github_token_var = MCPEnvVar(
        name="GITHUB_TOKEN",
        scope=MCPEnvVarScope.user,
        description="Your personal GitHub PAT",
    )
    db_row = LiteLLM_MCPServerTable(
        server_id="env-var-server-1",
        server_name="github_peruser",
        url="https://api.githubcopilot.com/mcp/",
        transport=MCPTransport.http,
        static_headers={"Authorization": "Bearer ${GITHUB_TOKEN}"},
        env_vars=[github_token_var],
    )

    built_server = await MCPServerManager().build_mcp_server_from_table(db_row)

    expected_entry = {
        "name": "GITHUB_TOKEN",
        "value": "",
        "scope": "user",
        "description": "Your personal GitHub PAT",
    }
    assert built_server.env_vars == [expected_entry]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_round_trip_source_url_preserved(self):
|
||||
"""source_url survives the full round-trip: LiteLLM_MCPServerTable -> MCPServer -> LiteLLM_MCPServerTable.
|
||||
@ -4107,6 +4161,179 @@ class TestApprovalStatusGate:
|
||||
assert "never-seen" not in manager.registry
|
||||
|
||||
|
||||
class TestRegistryTableConversionPreservesEnvVars:
    """Guards ``env_vars`` across registry ``MCPServer`` -> ``LiteLLM_MCPServerTable``.

    Those conversions back the GET /v1/mcp/server list and health responses,
    which populate the admin edit form. When they dropped ``env_vars`` the form
    loaded an empty list and saving any edit silently wiped the stored vars, so
    ``${VAR}`` static headers were forwarded upstream un-interpolated.
    """

    @staticmethod
    def _server_with_env_vars() -> MCPServer:
        """Return an OAuth2 HTTP server carrying one global and one user env var."""
        protocol_var = {
            "name": "DB_PROTOCOL",
            "value": "postgresql",
            "scope": "global",
            "description": None,
        }
        corp_user_var = {
            "name": "CORP_USER",
            "value": "",
            "scope": "user",
            "description": "Your DB username",
        }
        return MCPServer(
            server_id="env-vars-server",
            name="env_vars_server",
            url="https://example.com/mcp",
            transport=MCPTransport.http,
            auth_type=MCPAuth.oauth2,
            static_headers={"X-Db-Url": "${DB_PROTOCOL}://${CORP_USER}@${DB_HOST}"},
            env_vars=[protocol_var, corp_user_var],
        )

    @staticmethod
    def _assert_env_vars_round_tripped(table: LiteLLM_MCPServerTable) -> None:
        """Assert both fixture env vars survived the conversion into ``table``."""
        assert table.env_vars is not None
        by_name = {entry.name: entry for entry in table.env_vars}
        assert set(by_name) == {"DB_PROTOCOL", "CORP_USER"}

        protocol_entry = by_name["DB_PROTOCOL"]
        assert protocol_entry.scope == MCPEnvVarScope.global_
        assert protocol_entry.value == "postgresql"

        corp_user_entry = by_name["CORP_USER"]
        assert corp_user_entry.scope == MCPEnvVarScope.user
        assert corp_user_entry.description == "Your DB username"

    def test_build_mcp_server_table_preserves_env_vars(self):
        """``_build_mcp_server_table`` keeps the server's env vars."""
        server = self._server_with_env_vars()
        converted = MCPServerManager()._build_mcp_server_table(server)
        self._assert_env_vars_round_tripped(converted)

    @pytest.mark.asyncio
    async def test_health_check_server_preserves_env_vars(self):
        """``health_check_server`` keeps the server's env vars in its result.

        OAuth2 without client credentials needs a per-user token, so the
        health check is skipped (no network) and the table construction path
        is exercised directly.
        """
        server = self._server_with_env_vars()
        assert server.requires_per_user_auth is True

        manager = MCPServerManager()
        manager.registry[server.server_id] = server

        health_table = await manager.health_check_server(server.server_id)
        self._assert_env_vars_round_tripped(health_table)
|
||||
|
||||
|
||||
class TestHealthCheckInterpolatesGlobalEnvVars:
    """Upstream probes must interpolate global ``${NAME}`` env vars first.

    Both the health check and the initialize-instructions prefetch have to
    substitute global env vars into static headers before opening the
    connection. Forwarding the raw placeholder makes any server whose auth
    header is backed by a global env var fail authentication and flip to
    'unhealthy', even though real tool calls (which do interpolate) keep
    working.
    """

    @staticmethod
    def _server() -> MCPServer:
        """Return a no-auth HTTP server whose Authorization header uses ``${API_TOKEN}``."""
        api_token_var = {
            "name": "API_TOKEN",
            "value": "secret-token",
            "scope": "global",
            "description": None,
        }
        return MCPServer(
            server_id="global-env-server",
            name="global_env_server",
            url="https://example.com/mcp",
            transport=MCPTransport.http,
            auth_type=MCPAuth.none,
            static_headers={"Authorization": "Bearer ${API_TOKEN}"},
            env_vars=[api_token_var],
        )

    @staticmethod
    def _capture_headers(manager: MCPServerManager) -> Dict[str, Any]:
        """Stub ``manager._create_mcp_client`` and record the headers it receives.

        Returns:
            A dict that gains an ``"extra_headers"`` entry once the stubbed
            factory has been called.
        """
        recorded: Dict[str, Any] = {}

        fake_client = AsyncMock()
        fake_client.run_with_session = AsyncMock(return_value="ok")

        async def _create(server, mcp_auth_header, extra_headers, stdio_env):
            # Only the headers matter here; the other arguments are ignored.
            recorded["extra_headers"] = extra_headers
            return fake_client

        manager._create_mcp_client = AsyncMock(side_effect=_create)
        return recorded

    @pytest.mark.asyncio
    async def test_health_check_interpolates_global_env_vars(self):
        """The health check sends the resolved token and reports healthy."""
        server = self._server()
        assert server.requires_per_user_auth is False

        manager = MCPServerManager()
        manager.get_mcp_server_by_id = MagicMock(return_value=server)
        manager._remember_upstream_initialize_instructions = MagicMock()
        recorded = self._capture_headers(manager)

        health = await manager.health_check_server(server.server_id)

        assert recorded["extra_headers"] == {"Authorization": "Bearer secret-token"}
        assert health.status == "healthy"

    @pytest.mark.asyncio
    async def test_initialize_instructions_prefetch_interpolates_global_env_vars(self):
        """The initialize-instructions prefetch sends the resolved token too."""
        manager = MCPServerManager()
        server = self._server()
        recorded = self._capture_headers(manager)
        manager._remember_upstream_initialize_instructions = MagicMock()

        await manager._ensure_upstream_initialize_instructions_cached(server)

        assert recorded["extra_headers"] == {"Authorization": "Bearer secret-token"}
|
||||
|
||||
|
||||
class TestUserEnvVarsCacheEviction:
    """The per-user env var cache must evict one oldest entry at capacity.

    Wiping every entry instead would let a steady stream of distinct callers
    periodically stampede the DB by invalidating every still-valid value.
    """

    @staticmethod
    def _patch_cache(monkeypatch, max_size):
        """Swap in an empty cache dict and a small capacity on the manager module.

        Returns:
            ``(module, cache)`` — the patched module and the dict now serving
            as its ``_user_env_vars_cache``.
        """
        from litellm.proxy._experimental.mcp_server import mcp_server_manager as m

        cache: Dict[Any, Any] = {}
        monkeypatch.setattr(m, "_user_env_vars_cache", cache)
        monkeypatch.setattr(m, "_USER_ENV_VARS_CACHE_MAX_SIZE", max_size)
        return m, cache

    def test_eviction_drops_single_oldest_entry_not_whole_cache(self, monkeypatch):
        """Writing a 4th key into a 3-slot cache removes only the first key."""
        m, cache = self._patch_cache(monkeypatch, max_size=3)

        # Fill the cache exactly to capacity.
        for i in range(3):
            m._write_user_env_vars_cache(f"user{i}", "srv", {"V": str(i)})
        assert set(cache) == {("user0", "srv"), ("user1", "srv"), ("user2", "srv")}

        # One more distinct key must push out only the oldest entry.
        m._write_user_env_vars_cache("user3", "srv", {"V": "3"})

        assert len(cache) == 3
        assert ("user0", "srv") not in cache
        assert ("user3", "srv") in cache
        surviving_values = cache[("user1", "srv")][0]
        assert surviving_values == {"V": "1"}

    def test_refreshing_existing_key_does_not_evict(self, monkeypatch):
        """Rewriting a cached key keeps both entries and moves it to the tail."""
        m, cache = self._patch_cache(monkeypatch, max_size=2)
        write = m._write_user_env_vars_cache

        write("a", "srv", {"V": "1"})
        write("b", "srv", {"V": "2"})
        write("a", "srv", {"V": "1-new"})

        assert set(cache) == {("a", "srv"), ("b", "srv")}
        assert cache[("a", "srv")][0] == {"V": "1-new"}

        # The just-refreshed key must now sit at the tail so the next insert
        # evicts the genuinely older entry instead.
        write("c", "srv", {"V": "3"})
        assert ("b", "srv") not in cache
        assert ("a", "srv") in cache
|
||||
|
||||
|
||||
class TestGetPublicMCPServers:
|
||||
"""
|
||||
/public/mcp_hub strict-whitelist semantics — mirrors /public/model_hub
|
||||
|
||||
@ -1007,6 +1007,7 @@ class TestRotateCredentials:
|
||||
"aws_secret_access_key": "enc_old:SAK",
|
||||
"aws_region_name": "us-east-1",
|
||||
}
|
||||
server.env_vars = None
|
||||
|
||||
mock_prisma = MagicMock()
|
||||
mock_prisma.db.litellm_mcpservertable.find_many = AsyncMock(
|
||||
@ -1043,6 +1044,58 @@ class TestRotateCredentials:
|
||||
# Non-secret fields should pass through unchanged
|
||||
assert stored_creds["aws_region_name"] == "us-east-1"
|
||||
|
||||
@pytest.mark.asyncio
async def test_rotation_reencrypts_global_env_vars(self):
    """Global env var values are re-encrypted under the new key.

    User-scope placeholders must be left untouched, and the ``credentials``
    column must not be written for a server that has none.
    """
    from litellm.proxy._experimental.mcp_server.db import (
        rotate_mcp_server_credentials_master_key,
    )

    server_row = MagicMock()
    server_row.server_id = "srv-env"
    server_row.credentials = None
    server_row.env_vars = [
        {"name": "API_KEY", "value": "enc_old:secret", "scope": "global"},
        {"name": "USER_TOKEN", "value": "", "scope": "user"},
    ]

    mock_prisma = MagicMock()
    server_table = mock_prisma.db.litellm_mcpservertable
    server_table.find_many = AsyncMock(return_value=[server_row])
    server_table.update = AsyncMock()

    def _fake_decrypt(
        value, key, exception_type="error", return_original_value=False
    ):
        # "Decrypting" just strips the marker that stands for the old key.
        return value.replace("enc_old:", "")

    def _fake_encrypt(value, new_encryption_key):
        # "Encrypting" tags the value so the new key's output is recognizable.
        return f"enc_new:{value}"

    with (
        patch(
            "litellm.proxy._experimental.mcp_server.db._get_salt_key",
            return_value="old-key",
        ),
        patch(
            "litellm.proxy._experimental.mcp_server.db.decrypt_value_helper",
            side_effect=_fake_decrypt,
        ),
        patch(
            "litellm.proxy._experimental.mcp_server.db.encrypt_value_helper",
            side_effect=_fake_encrypt,
        ),
    ):
        await rotate_mcp_server_credentials_master_key(
            mock_prisma, "admin", "new-key"
        )

    update_call = server_table.update
    assert update_call.called
    written_data = update_call.call_args.kwargs["data"]
    stored_env = json.loads(written_data["env_vars"])
    # Global value decrypted from old, then re-encrypted with new key
    assert stored_env[0]["value"] == "enc_new:secret"
    # User-scope placeholder untouched
    assert stored_env[1]["value"] == ""
    # Credentials column not written when the server has none
    assert "credentials" not in update_call.call_args.kwargs["data"]
|
||||
|
||||
|
||||
class TestAuthTypeSwitchClearsCredentials:
|
||||
"""Test that switching auth_type without credentials clears stale secrets."""
|
||||
|
||||
@ -2,6 +2,7 @@ import json
|
||||
from typing import Any, Dict, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
from starlette.requests import Request
|
||||
@ -1629,3 +1630,48 @@ class TestPreviewOpenAPITools:
|
||||
"order is out of sync, so collision suffixes (_2, _3, ...) "
|
||||
"land on different operations"
|
||||
)
|
||||
|
||||
|
||||
class TestConnectionErrorMessage:
    """Checks how ``_connection_error_message`` words raw transport errors.

    The message is returned to an admin in an API response, so it must explain
    the failure without echoing the raw header value, which can carry a secret
    (e.g. ``Authorization: Bearer <token>``).
    """

    def test_local_protocol_error_is_actionable_and_redacted(self):
        """An illegal-header error mentions headers but never the header value."""
        secret = "Bearer sk-super-secret-token"
        protocol_error = httpx.LocalProtocolError(
            f"Illegal header value b' {secret}'"
        )

        message = rest_endpoints._connection_error_message(protocol_error)

        assert "header" in message.lower()
        assert secret not in message

    def test_connect_error_points_at_reachability(self):
        """A refused/failed connection is reported as unreachable."""
        connect_error = httpx.ConnectError("All connection attempts failed")
        message = rest_endpoints._connection_error_message(connect_error)
        assert "unreachable" in message.lower()

    def test_timeout_error_message(self):
        """A connect timeout is reported as unreachable as well."""
        timeout_error = httpx.ConnectTimeout("timed out")
        message = rest_endpoints._connection_error_message(timeout_error)
        assert "unreachable" in message.lower()

    def test_http_status_error_includes_status_code(self):
        """An HTTP status failure surfaces the upstream status code."""
        status_error = httpx.HTTPStatusError(
            "server error",
            request=httpx.Request("POST", "http://x/"),
            response=httpx.Response(status_code=503),
        )
        message = rest_endpoints._connection_error_message(status_error)
        assert "503" in message

    def test_unknown_error_falls_back_to_generic(self):
        """Unrecognized errors get a generic message that hides the raw text."""
        message = rest_endpoints._connection_error_message(RuntimeError("weird"))
        assert "weird" not in message
        assert "proxy logs" in message.lower()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@ -19,6 +19,154 @@ from litellm.proxy._types import (
|
||||
from litellm.proxy.management_helpers.utils import add_new_member
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_management_otel_span_redacts_mcp_global_env_var_secrets(monkeypatch):
    """A decrypted MCP global env var secret must never reach telemetry.

    MCP create/update endpoints return the server with decrypted
    ``scope="global"`` env var values so the admin UI can pre-fill the edit
    form. ``management_endpoint_wrapper`` serializes the response into an OTEL
    span, and that span is readable by observability users, so the secret value
    must be blanked there while names/scopes stay for usefulness. The endpoint's
    own return value must keep the decrypted value for the admin.
    """
    import datetime

    import litellm.proxy.proxy_server as proxy_server
    from litellm.proxy._types import (
        LiteLLM_MCPServerTable,
        MCPEnvVar,
        MCPEnvVarScope,
    )
    from litellm.proxy.management_helpers import utils as mgmt_utils

    captured = {}

    class _FakeOtelLogger:
        """Records the response the success hook would attach to the span."""

        async def async_management_endpoint_success_hook(
            self, logging_payload, parent_otel_span
        ):
            captured["response"] = logging_payload.response

    monkeypatch.setattr(proxy_server, "open_telemetry_logger", _FakeOtelLogger())
    monkeypatch.setattr(mgmt_utils, "is_otel_v2_enabled", lambda: False)

    secret = "s3cr3t-p@ss"
    global_var = MCPEnvVar(
        name="DB_PASSWORD", value=secret, scope=MCPEnvVarScope.global_
    )
    user_var = MCPEnvVar(
        name="CORP_USER",
        value="",
        scope=MCPEnvVarScope.user,
        description="Your DB username",
    )
    result = LiteLLM_MCPServerTable(
        server_id="srv-1",
        alias="echo",
        url="http://localhost:8765/mcp",
        transport="http",
        env_vars=[global_var, user_var],
        created_at=datetime.datetime.now(),
        updated_at=datetime.datetime.now(),
    )

    await mgmt_utils._emit_management_endpoint_otel_span(
        func=lambda: None,
        kwargs={},
        parent_otel_span=object(),
        start_time=datetime.datetime.now(),
        end_time=datetime.datetime.now(),
        result=result,
    )

    span_response = captured["response"]
    serialized = span_response["env_vars"]
    # The secret must not appear anywhere the span serializer would stringify.
    assert secret not in str(span_response)
    assert all(entry["value"] == "" for entry in serialized)
    # Names and scopes survive so the trace stays useful.
    serialized_names = {entry["name"] for entry in serialized}
    assert serialized_names == {"DB_PASSWORD", "CORP_USER"}
    assert any(entry["scope"] == MCPEnvVarScope.global_ for entry in serialized)
    # The endpoint's own return value is untouched: the admin still gets the
    # decrypted value to pre-fill the edit form.
    assert result.env_vars[0].value == secret
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_management_otel_span_redacts_nested_submission_env_var_secrets(
    monkeypatch,
):
    """Decrypted global env var secrets nested under ``items`` must also be blanked.

    ``GET /v1/mcp/server/submissions`` returns ``MCPSubmissionsSummary`` whose
    ``items[].env_vars`` carry decrypted ``scope="global"`` values for full admins.
    ``management_endpoint_wrapper`` stringifies that nested ``items`` value into the
    OTEL span, so redaction has to walk into ``items`` and not just the top level,
    while the endpoint's own return value keeps the value for the admin UI.
    """
    import datetime

    import litellm.proxy.proxy_server as proxy_server
    from litellm.proxy._types import (
        LiteLLM_MCPServerTable,
        MCPEnvVar,
        MCPEnvVarScope,
        MCPSubmissionsSummary,
    )
    from litellm.proxy.management_helpers import utils as mgmt_utils

    captured = {}

    class _FakeOtelLogger:
        """Records the response the success hook would attach to the span."""

        async def async_management_endpoint_success_hook(
            self, logging_payload, parent_otel_span
        ):
            captured["response"] = logging_payload.response

    monkeypatch.setattr(proxy_server, "open_telemetry_logger", _FakeOtelLogger())
    monkeypatch.setattr(mgmt_utils, "is_otel_v2_enabled", lambda: False)

    secret = "s3cr3t-submission"
    server = LiteLLM_MCPServerTable(
        server_id="srv-sub",
        alias="echo",
        url="http://localhost:8765/mcp",
        transport="http",
        env_vars=[
            MCPEnvVar(name="DB_PASSWORD", value=secret, scope=MCPEnvVarScope.global_),
        ],
        created_at=datetime.datetime.now(),
        updated_at=datetime.datetime.now(),
    )
    result = MCPSubmissionsSummary(
        total=1, pending_review=1, active=0, rejected=0, items=[server]
    )

    await mgmt_utils._emit_management_endpoint_otel_span(
        func=lambda: None,
        kwargs={},
        parent_otel_span=object(),
        start_time=datetime.datetime.now(),
        end_time=datetime.datetime.now(),
        result=result,
    )

    span_response = captured["response"]
    # The nested secret must not appear anywhere the span serializer stringifies.
    assert secret not in str(span_response)

    # The redacted item may be a plain dict or still a model object.
    redacted_item = span_response["items"][0]
    if isinstance(redacted_item, dict):
        redacted_env_vars = redacted_item["env_vars"]
    else:
        redacted_env_vars = redacted_item.env_vars
    assert [entry["value"] for entry in redacted_env_vars] == [""]
    assert redacted_env_vars[0]["name"] == "DB_PASSWORD"
    # The endpoint's own return value is untouched for the admin UI.
    assert result.items[0].env_vars[0].value == secret
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_new_member_clones_default_team_budget_id():
|
||||
"""
|
||||
|
||||
@ -50,11 +50,11 @@ test.describe("MCP Servers", () => {
|
||||
|
||||
// No teardown needed — the e2e runner spins up a fresh DB per invocation.
|
||||
|
||||
// Success toast and the new row in the table. Scope the row lookup to
|
||||
// the MCP servers table so the form modal's `server_name` input — which
|
||||
// Success toast and the new card in the server grid. Scope the lookup to
|
||||
// the MCP servers grid so the form modal's `server_name` input — which
|
||||
// still holds the timestamped value during its close animation — can't
|
||||
// satisfy the assertion before the server actually lands in the list.
|
||||
await expect(page.getByText("MCP Server created successfully").first()).toBeVisible({ timeout: 15_000 });
|
||||
await expect(page.locator("table tbody").getByText(uniqueName).first()).toBeVisible({ timeout: 10_000 });
|
||||
await expect(page.getByTestId("mcp-servers-grid").getByText(uniqueName).first()).toBeVisible({ timeout: 10_000 });
|
||||
});
|
||||
});
|
||||
|
||||
142
ui/litellm-dashboard/src/components/mcp_tools/EnvVarsSection.tsx
Normal file
142
ui/litellm-dashboard/src/components/mcp_tools/EnvVarsSection.tsx
Normal file
@ -0,0 +1,142 @@
|
||||
import React from "react";
|
||||
import { Form, Input, Select, Button, Tooltip, Typography } from "antd";
|
||||
import { InfoCircleOutlined, MinusCircleOutlined, PlusOutlined } from "@ant-design/icons";
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
// Choices offered by the per-variable scope <Select>. The stored values are
// "global" / "user"; the labels are what the admin sees in the form.
const SCOPE_OPTIONS = [
  { value: "global", label: "Instance" },
  { value: "user", label: "Per-user" },
];
|
||||
|
||||
/**
 * Admin form section for MCP server variables.
 *
 * Renders one editable row per entry of the ``env_vars`` form list:
 * name | value-or-hint | scope, plus a remove control and an "Add Variable"
 * button. Variables can be referenced from Static Headers as ${NAME}.
 * ``scope=global`` (labelled "Instance") values are used as-is, while
 * ``scope=user`` (labelled "Per-user") values are supplied by each user via
 * the MCP Gateway dashboard.
 *
 * The parent form reads the ``env_vars`` field from the form values.
 */
const EnvVarsSection: React.FC = () => {
  // Tooltip body explaining what the two scopes mean.
  const scopeHelp = (
    <>
      Define variables you can interpolate in Static Headers or Authentication using{" "}
      <code>{"${VAR_NAME}"}</code>. <br />
      <b>Instance</b>: admin-defined value used for every user.
      <br />
      <b>Per-user</b>: each user supplies their own value (e.g. personal credentials) via the MCP Gateway
      dashboard.
    </>
  );

  // Column captions; only rendered once the list has at least one row.
  const columnHeadings = (
    <div className="flex gap-3 px-1 text-xs font-medium text-gray-500 uppercase tracking-wide">
      <div style={{ flex: 1 }}>Variable Name</div>
      <div style={{ flex: 1 }}>Value / Description</div>
      <div style={{ width: 160 }}>Scope</div>
      <div style={{ width: 24 }} />
    </div>
  );

  // Names must be non-empty identifiers so they can be used inside ${...}.
  const variableNameRules = [
    { required: true, message: "Variable name is required" },
    {
      pattern: /^[A-Za-z_][A-Za-z0-9_]*$/,
      message: "Use letters, digits, underscores; cannot start with a digit.",
    },
  ];

  return (
    <div className="rounded-lg border border-gray-200 bg-gray-50 p-4">
      <div className="flex items-center gap-2 mb-1">
        <Text strong className="text-sm">
          Variables
        </Text>
        <Tooltip title={scopeHelp}>
          <InfoCircleOutlined className="text-blue-400 hover:text-blue-600 cursor-help" />
        </Tooltip>
      </div>
      <Text className="text-xs text-gray-600 block mb-3">
        Reference these in Static Headers or Authentication as <code>{"${VAR_NAME}"}</code>. For example:{" "}
        <code className="bg-white px-1 rounded border border-gray-200">
          {"${DB_PROTOCOL}://${CORP_USERNAME}:${CORP_PASSWORD}@${DB_HOSTNAME}"}
        </code>
      </Text>

      <Form.List name="env_vars">
        {(fields, { add, remove }) => (
          <div className="space-y-2">
            {fields.length > 0 && columnHeadings}
            {fields.map(({ key: rowKey, name: rowIndex, ...fieldProps }) => (
              <div key={rowKey} className="flex gap-3 items-start">
                <Form.Item
                  {...fieldProps}
                  name={[rowIndex, "name"]}
                  className="mb-0"
                  style={{ flex: 1 }}
                  rules={variableNameRules}
                >
                  <Input placeholder="e.g. DB_PROTOCOL" className="rounded-md font-mono" />
                </Form.Item>
                <div style={{ flex: 1 }}>
                  <ScopedValueOrDescription name={rowIndex} restField={fieldProps} />
                </div>
                <Form.Item
                  {...fieldProps}
                  name={[rowIndex, "scope"]}
                  className="mb-0"
                  initialValue="global"
                  style={{ width: 160 }}
                >
                  <Select options={SCOPE_OPTIONS} />
                </Form.Item>
                <div style={{ width: 24, height: 32 }} className="flex items-center justify-center">
                  <MinusCircleOutlined
                    onClick={() => remove(rowIndex)}
                    className="text-gray-500 hover:text-red-500 cursor-pointer"
                  />
                </div>
              </div>
            ))}
            {/* New rows start out instance-scoped. */}
            <Button type="dashed" onClick={() => add({ scope: "global" })} icon={<PlusOutlined />} block>
              Add Variable
            </Button>
          </div>
        )}
      </Form.List>
    </div>
  );
};
|
||||
|
||||
// Middle column of a variable row. Instance-scoped variables get an input
// bound to the row's ``value``. Per-user variables have no shared value (each
// user supplies it later), so the column instead edits the optional
// ``description`` that the per-user fill-in modal shows as a hint.
const ScopedValueOrDescription: React.FC<{
  name: number;
  restField: object;
}> = ({ name, restField }) => {
  // Re-render whenever this row's scope selection changes.
  const scope = Form.useWatch(["env_vars", name, "scope"]);

  if (scope !== "user") {
    return (
      <Form.Item {...restField} name={[name, "value"]} className="mb-0">
        <Input placeholder="e.g. postgresql" className="rounded-md font-mono" />
      </Form.Item>
    );
  }

  // Prefix label marking the input as a hint rather than a value.
  const hintLabel = (
    <Tooltip title="Per-user variables have no shared value. This text is only a hint shown to each user when they fill in their own value.">
      <span className="text-xs text-gray-500 cursor-help whitespace-nowrap">
        <InfoCircleOutlined className="mr-1" />
        Hint
      </span>
    </Tooltip>
  );

  return (
    <Form.Item {...restField} name={[name, "description"]} className="mb-0">
      <Input
        addonBefore={hintLabel}
        placeholder="e.g. Your DB username"
        styles={{ input: { color: "#9ca3af" } }}
      />
    </Form.Item>
  );
};
|
||||
|
||||
export default EnvVarsSection;
|
||||
@ -53,6 +53,17 @@ const MCPPermissionManagement: React.FC<MCPPermissionManagementProps> = ({
|
||||
}));
|
||||
form.setFieldValue("static_headers", staticHeaders);
|
||||
}
|
||||
if (Array.isArray(mcpServer.env_vars) && mcpServer.env_vars.length > 0) {
|
||||
form.setFieldValue(
|
||||
"env_vars",
|
||||
mcpServer.env_vars.map((entry) => ({
|
||||
name: entry.name,
|
||||
value: entry.value ?? "",
|
||||
scope: entry.scope ?? "global",
|
||||
description: entry.description ?? "",
|
||||
})),
|
||||
);
|
||||
}
|
||||
if (typeof mcpServer.allow_all_keys === "boolean") {
|
||||
form.setFieldValue("allow_all_keys", mcpServer.allow_all_keys);
|
||||
}
|
||||
|
||||
383
ui/litellm-dashboard/src/components/mcp_tools/MCPServerCard.tsx
Normal file
383
ui/litellm-dashboard/src/components/mcp_tools/MCPServerCard.tsx
Normal file
@ -0,0 +1,383 @@
|
||||
import { useState, type FC, type KeyboardEvent, type MouseEvent } from "react";
|
||||
import { Dropdown, Tooltip, Typography, Tag } from "antd";
|
||||
import type { MenuProps } from "antd";
|
||||
import {
|
||||
CheckOutlined,
|
||||
DeleteOutlined,
|
||||
ExclamationCircleFilled,
|
||||
MoreOutlined,
|
||||
ThunderboltOutlined,
|
||||
} from "@ant-design/icons";
|
||||
import type { MCPServer } from "./types";
|
||||
import { getMaskedAndFullUrl } from "./utils";
|
||||
|
||||
const { Text } = Typography;
|
||||
|
||||
// Props accepted by MCPServerCard. All callbacks except `onClick` are optional;
// the matching control (menu item / button) is only rendered when its callback
// is provided.
interface MCPServerCardProps {
|
||||
// The MCP server rendered by this card.
server: MCPServer;
|
||||
// Per-user env-var fields this user still needs to fill in for this server.
|
||||
// Computed by the parent from the bulk /user-env-vars/status response, so
|
||||
// the card never issues a per-row request (no N+1).
|
||||
missingUserFields?: string[];
|
||||
// When true the health chip shows the "Checking" placeholder.
isLoadingHealth?: boolean;
|
||||
// When true the health chip shows "Checking" and the "Test Connection"
// menu item is disabled.
isRechecking?: boolean;
|
||||
// Invoked when the card itself is clicked or activated with Enter/Space.
onClick: () => void;
|
||||
// Adds the "Test Connection" menu item and makes the health chip clickable.
onRecheckHealth?: () => void;
|
||||
// Forwarded to the BYOK row as its connect/update handler.
onByokConnect?: () => void;
|
||||
// Renders the "Set" button next to the missing-fields warning.
onOpenFillFields?: () => void;
|
||||
// Adds the "Delete" menu item.
onDelete?: () => void;
|
||||
}
|
||||
|
||||
// Maps a server health status to the CSS class used for its status dot.
// Statuses without an entry fall back to the `unknown` tone at the lookup site.
const HEALTH_TONE: Record<string, { dot: string }> = {
|
||||
healthy: { dot: "bg-green-500" },
|
||||
unhealthy: { dot: "bg-red-500" },
|
||||
unknown: { dot: "bg-gray-300" },
|
||||
};
|
||||
|
||||
// Keeps an event raised on an interactive child from bubbling up to the
// card-level handlers.
const stop = (event: MouseEvent | KeyboardEvent): void => {
  event.stopPropagation();
};
|
||||
|
||||
/**
 * Grid card summarising a single MCP server: logo/initials, name, alias,
 * short id, a transport-specific subtitle, health/transport/auth/visibility
 * tags, access groups, and (when relevant) the BYOK row and the
 * "missing user fields" warning.
 *
 * The whole card behaves like a button (click / Enter / Space call
 * `onClick`); nested controls stop propagation so they do not also trigger
 * the card-level handler.
 */
const MCPServerCard: FC<MCPServerCardProps> = ({
  server,
  missingUserFields,
  isLoadingHealth,
  isRechecking,
  onClick,
  onRecheckHealth,
  onByokConnect,
  onOpenFillFields,
  onDelete,
}) => {
  const alias = server.alias || server.server_name || "";
  const name = server.server_name || alias || server.server_id;
  // Logo is sourced exclusively from the admin-set `mcp_info.logo_url`.
  const candidateLogo = server.mcp_info?.logo_url ?? undefined;
  // Remember which URL failed to load so we fall back to initials for that
  // URL only; a different logo URL gets a fresh attempt.
  const [failedLogoUrl, setFailedLogoUrl] = useState<string | null>(null);
  const logoUrl = candidateLogo && failedLogoUrl !== candidateLogo ? candidateLogo : undefined;
  const transport = server.transport || "http";
  const displayTransport = server.spec_path && transport !== "stdio" ? "openapi" : transport;
  const authType = server.auth_type || "none";
  const status = server.status || "unknown";
  const healthTone = HEALTH_TONE[status] ?? HEALTH_TONE.unknown;
  const isPublic = server.available_on_public_internet;
  const accessGroups = (server.mcp_access_groups ?? []).filter((g): g is string => typeof g === "string");

  const missing = missingUserFields ?? [];
  const needsAttention = missing.length > 0;

  const cardClass = needsAttention
    ? "border-2 border-red-300 bg-red-50/40 hover:border-red-400 hover:shadow-md"
    : "border border-gray-200 bg-white hover:border-gray-300 hover:shadow-md";

  const url = server.url || "";
  const { maskedUrl } = url ? getMaskedAndFullUrl(url) : { maskedUrl: "" };

  // Transport-adapted identifier shown under the title. Every transport has
  // something useful here, which keeps the tag row vertically aligned across
  // cards in the grid (stdio cards no longer "snap up" because they lack a URL).
  let subtitle = "";
  let subtitleTooltip = "";
  if (transport === "stdio") {
    const parts = [server.command, ...(server.args ?? [])].filter(
      (p): p is string => typeof p === "string" && p.length > 0,
    );
    subtitle = parts.join(" ");
    subtitleTooltip = subtitle;
  } else if (server.spec_path) {
    subtitle = server.spec_path;
    subtitleTooltip = server.spec_path;
  } else if (url) {
    subtitle = maskedUrl;
    subtitleTooltip = url;
  }

  const handleKeyDown = (e: KeyboardEvent<HTMLDivElement>) => {
    // Key events from nested interactive children (e.g. the "Set" and BYOK
    // buttons) bubble up to the card. Only react when the card itself is the
    // event target; otherwise preventDefault() below would suppress the
    // child's own Enter/Space activation and open the card instead.
    if (e.target !== e.currentTarget) return;
    if (e.key === "Enter" || e.key === " ") {
      e.preventDefault();
      onClick();
    }
  };

  // Overflow-menu entries are only added for the callbacks the parent supplied.
  const menuItems: MenuProps["items"] = [];
  if (onRecheckHealth) {
    menuItems.push({
      key: "test-connection",
      label: "Test Connection",
      icon: <ThunderboltOutlined />,
      disabled: isRechecking,
      onClick: ({ domEvent }) => {
        domEvent.stopPropagation();
        onRecheckHealth();
      },
    });
  }
  if (onDelete) {
    if (menuItems.length > 0) {
      menuItems.push({ key: "divider", type: "divider" });
    }
    menuItems.push({
      key: "delete",
      label: "Delete",
      icon: <DeleteOutlined />,
      danger: true,
      onClick: ({ domEvent }) => {
        domEvent.stopPropagation();
        onDelete();
      },
    });
  }

  // Card uses role="button" + nested <button> children (Set, BYOK Connect, the
  // recheck-health Tag), so a real <button> wrapper would produce invalid
  // nested-interactive HTML. The role + tabIndex + Enter/Space handler keeps
  // the whole card clickable and keyboard-accessible.
  return (
    <div
      role="button"
      tabIndex={0}
      onClick={onClick}
      onKeyDown={handleKeyDown}
      className={`group relative flex h-full cursor-pointer flex-col gap-3 rounded-lg p-4 transition-all duration-150 focus:outline-none focus-visible:ring-2 focus-visible:ring-blue-400 ${cardClass}`}
    >
      <div className="flex items-start gap-3">
        {logoUrl ? (
          <img
            src={logoUrl}
            alt={`${name} logo`}
            className="h-10 w-10 flex-shrink-0 rounded object-contain"
            onError={() => setFailedLogoUrl(logoUrl)}
          />
        ) : (
          <div className="flex h-10 w-10 flex-shrink-0 items-center justify-center rounded bg-gray-100 font-semibold text-gray-500">
            {(name || "?").slice(0, 2).toUpperCase()}
          </div>
        )}
        <div className="min-w-0 flex-1">
          <div className="block w-full truncate text-left font-semibold text-gray-900" title={name}>
            {name}
          </div>
          <div className="mt-0.5 flex items-center gap-2 text-xs text-gray-500">
            {alias && <span className="truncate">{alias}</span>}
            {alias && <span className="text-gray-300">·</span>}
            <Tooltip title={server.server_id}>
              <span className="font-mono text-blue-600">{server.server_id.slice(0, 7)}</span>
            </Tooltip>
          </div>
        </div>
        {menuItems.length > 0 && (
          <Dropdown menu={{ items: menuItems }} trigger={["click"]} placement="bottomRight">
            <button
              type="button"
              onClick={stop}
              onKeyDown={stop}
              aria-label="Server actions"
              className="-mr-1 -mt-1 inline-flex h-8 w-8 items-center justify-center rounded-md text-gray-500 transition-colors hover:bg-gray-100 hover:text-blue-600"
            >
              <MoreOutlined style={{ fontSize: 20 }} />
            </button>
          </Dropdown>
        )}
      </div>

      {subtitle ? (
        <Tooltip title={subtitleTooltip}>
          <Text className="truncate font-mono text-xs text-gray-500" ellipsis>
            {subtitle}
          </Text>
        </Tooltip>
      ) : (
        // Defensive placeholder: keep the row even when no identifier is
        // available so the tag row stays vertically aligned across the grid.
        <div className="h-[18px]" aria-hidden />
      )}

      <div className="flex flex-wrap items-center gap-1.5">
        <HealthChip
          status={status}
          isLoadingHealth={isLoadingHealth}
          isRechecking={isRechecking}
          onRecheck={onRecheckHealth}
          lastCheck={server.last_health_check}
          error={server.health_check_error}
          dotClass={healthTone.dot}
        />
        <Tag className="m-0">{displayTransport.toUpperCase()}</Tag>
        <Tag className="m-0">{authType}</Tag>
        <Tag color={isPublic ? "green" : "orange"} className="m-0">
          <span className="inline-flex items-center gap-1">
            <span className={`h-1.5 w-1.5 rounded-full ${isPublic ? "bg-green-500" : "bg-orange-500"}`} />
            {isPublic ? "Public" : "Internal"}
          </span>
        </Tag>
        {accessGroups.slice(0, 2).map((g) => (
          <Tooltip key={g} title={g}>
            <Tag className="m-0 max-w-[120px] truncate">{g}</Tag>
          </Tooltip>
        ))}
        {accessGroups.length > 2 && (
          <Tooltip title={accessGroups.slice(2).join(", ")}>
            <Tag className="m-0">+{accessGroups.length - 2}</Tag>
          </Tooltip>
        )}
      </div>

      {(server.is_byok || needsAttention) && (
        <div className="mt-auto flex flex-col gap-2">
          {server.is_byok && <ByokRow connected={!!server.has_user_credential} onConnect={onByokConnect} />}
          {needsAttention && (
            <div className="flex items-center justify-between gap-2 text-xs">
              <Tooltip
                title={
                  <div>
                    <div className="font-semibold mb-1">Missing user fields:</div>
                    <ul className="ml-3">
                      {missing.map((m) => (
                        <li key={m}>• {m}</li>
                      ))}
                    </ul>
                  </div>
                }
              >
                <span className="inline-flex items-center gap-1 font-semibold text-red-700">
                  <ExclamationCircleFilled />
                  {missing.length} user field
                  {missing.length === 1 ? "" : "s"} missing
                </span>
              </Tooltip>
              {onOpenFillFields && (
                <button
                  type="button"
                  onClick={(e) => {
                    stop(e);
                    onOpenFillFields();
                  }}
                  className="rounded-md bg-red-600 px-3 py-1 text-xs font-medium text-white shadow-sm transition-colors hover:bg-red-700"
                >
                  Set
                </button>
              )}
            </div>
          )}
        </div>
      )}
    </div>
  );
};
|
||||
|
||||
// Props for the HealthChip status pill.
interface HealthChipProps {
|
||||
// Health status string; shown capitalised in the chip and verbatim in the tooltip.
status: string;
|
||||
// Either flag being true replaces the chip with the "Checking" placeholder.
isLoadingHealth?: boolean;
|
||||
isRechecking?: boolean;
|
||||
// When provided the chip becomes clickable and calls this on click.
onRecheck?: () => void;
|
||||
// Value passed to `new Date(...)` for the "Last check" tooltip line.
lastCheck?: string | null;
|
||||
// Error text shown in the tooltip when present.
error?: string | null;
|
||||
// CSS class applied to the coloured status dot.
dotClass: string;
|
||||
}
|
||||
|
||||
/**
 * Status pill for a server's health. While a health load or recheck is in
 * flight it renders a pulsing "Checking" placeholder; otherwise it shows the
 * capitalised status with a tooltip containing the last check time, any
 * error, and (when `onRecheck` is supplied) a hint that clicking rechecks.
 */
const HealthChip: FC<HealthChipProps> = (props) => {
  const { status, isLoadingHealth, isRechecking, onRecheck, lastCheck, error, dotClass } = props;

  const busy = isLoadingHealth || isRechecking;
  if (busy) {
    return (
      <Tag className="m-0">
        <span className="inline-flex items-center gap-1.5 text-xs text-gray-500">
          <span className="h-1.5 w-1.5 animate-pulse rounded-full bg-gray-300" />
          Checking
        </span>
      </Tag>
    );
  }

  const hasNoData = !lastCheck && !error;
  const statusLabel = status.charAt(0).toUpperCase() + status.slice(1);
  const tagClass = `m-0 ${onRecheck ? "cursor-pointer hover:opacity-80" : "cursor-default"}`;

  // Only attach a click handler when rechecking is possible; the handler
  // stops propagation so the enclosing card's onClick does not fire too.
  const handleTagClick = onRecheck
    ? (event: MouseEvent) => {
        event.stopPropagation();
        onRecheck();
      }
    : undefined;

  const tooltipBody = (
    <div className="max-w-xs">
      <div className="font-semibold mb-1">Health: {status}</div>
      {lastCheck && <div className="text-xs mb-1">Last check: {new Date(lastCheck).toLocaleString()}</div>}
      {error && (
        <div className="text-xs">
          <div className="font-medium text-red-300 mb-1">Error</div>
          <div className="break-words">{error}</div>
        </div>
      )}
      {hasNoData && <div className="text-xs text-gray-400">No health data</div>}
      {onRecheck && <div className="mt-1 text-xs text-gray-300">Click to recheck</div>}
    </div>
  );

  return (
    <Tooltip title={tooltipBody} placement="top">
      <Tag className={tagClass} onClick={handleTagClick}>
        <span className="inline-flex items-center gap-1.5">
          <span className={`h-1.5 w-1.5 rounded-full ${dotClass}`} />
          {statusLabel}
        </span>
      </Tag>
    </Tooltip>
  );
};
|
||||
|
||||
// Props for the BYOK credential row shown at the bottom of a server card.
interface ByokRowProps {
|
||||
// True renders the green "Connected" pill; false renders the Connect state.
connected: boolean;
|
||||
// Handler for the Connect / Update button. When omitted, no button is
// rendered (a dash is shown instead in the not-connected state).
onConnect?: () => void;
|
||||
}
|
||||
|
||||
/**
 * "BYOK credential" row. Shows a "Connected" pill plus an optional "Update"
 * link when a credential exists; otherwise a "Connect" button, or a dash when
 * no `onConnect` handler was supplied.
 */
const ByokRow: FC<ByokRowProps> = ({ connected, onConnect }) => {
  // Shared click handler for both buttons: keep the click from reaching the
  // card, then hand off to the caller. Buttons are only rendered when
  // `onConnect` exists, so the optional call never no-ops in practice.
  const triggerConnect = (event: MouseEvent) => {
    stop(event);
    onConnect?.();
  };

  const trailing = connected ? (
    <div className="flex items-center gap-2">
      <span className="inline-flex items-center gap-1 rounded-full border border-green-200 bg-green-50 px-2 py-0.5 font-medium text-green-700">
        <CheckOutlined style={{ fontSize: 10 }} /> Connected
      </span>
      {onConnect && (
        <button
          type="button"
          onClick={triggerConnect}
          className="text-xs text-gray-400 transition-colors hover:text-blue-600"
        >
          Update
        </button>
      )}
    </div>
  ) : onConnect ? (
    <button
      type="button"
      onClick={triggerConnect}
      className="rounded-md bg-blue-600 px-3 py-1 text-xs font-medium text-white shadow-sm transition-colors hover:bg-blue-700"
    >
      Connect
    </button>
  ) : (
    <span className="text-gray-400">—</span>
  );

  return (
    <div className="flex items-center justify-between gap-2 text-xs">
      <span className="text-gray-500">BYOK credential</span>
      {trailing}
    </div>
  );
};
|
||||
|
||||
export default MCPServerCard;
|
||||
@ -0,0 +1,141 @@
|
||||
import React from "react";
|
||||
import { Modal, Form, Input, Button, Alert, Spin, Tag, Typography } from "antd";
|
||||
import { useMutation, useQuery } from "@tanstack/react-query";
|
||||
import { MCPServer, MCPUserEnvVarsStatus } from "./types";
|
||||
import { getMCPUserEnvVars, storeMCPUserEnvVars } from "../networking";
|
||||
import NotificationsManager from "../molecules/notifications_manager";
|
||||
|
||||
const { Text, Title } = Typography;
|
||||
|
||||
// Props for UserEnvVarsModal.
interface UserEnvVarsModalProps {
|
||||
// Server whose per-user fields are being filled in; the status query is
// disabled while this is null.
server: MCPServer | null;
|
||||
// Controls modal visibility; the status query only runs while open.
open: boolean;
|
||||
// Token passed to the networking helpers; the query is disabled while null.
accessToken: string | null;
|
||||
// Called on cancel and after a successful save.
onClose: () => void;
|
||||
// Called with the save mutation's result after a successful save.
onSaved?: (status: MCPUserEnvVarsStatus) => void;
|
||||
}
|
||||
|
||||
/**
 * User-facing modal for filling in per-user MCP environment variables.
 *
 * Backed by GET / POST ``/v1/mcp/server/{id}/user-env-vars``. Each field
 * the admin marked as ``scope=user`` shows up with the admin-supplied
 * description as the placeholder.
 *
 * Fields that are not yet set are required and must contain a non-blank
 * value; fields that are already set may be left empty.
 */
const UserEnvVarsModal: React.FC<UserEnvVarsModalProps> = ({ server, open, accessToken, onClose, onSaved }) => {
  const [form] = Form.useForm();

  // Load the per-user field specs (name, description, is_set) for this server.
  // The non-null assertions are safe because `enabled` gates on both values.
  const {
    data: status,
    isLoading,
    isError,
  } = useQuery<MCPUserEnvVarsStatus>({
    queryKey: ["mcpUserEnvVars", server?.server_id],
    queryFn: () => getMCPUserEnvVars(accessToken!, server!.server_id),
    enabled: open && !!server && !!accessToken,
  });

  const saveMutation = useMutation({
    mutationFn: (values: Record<string, string>) => storeMCPUserEnvVars(accessToken!, server!.server_id, values),
    onSuccess: (saved) => {
      NotificationsManager.success("Credentials saved");
      onSaved?.(saved);
      onClose();
    },
    onError: (err) => {
      NotificationsManager.fromBackend(`Failed to save env vars: ${err instanceof Error ? err.message : String(err)}`);
    },
  });

  // Trim every submitted value before sending it to the backend.
  // NOTE(review): already-set fields left blank are submitted as "" —
  // presumably the backend treats an empty value as "keep existing"; confirm.
  const handleSave = (values: Record<string, string>) => {
    if (!server || !accessToken) return;
    const trimmed: Record<string, string> = {};
    for (const [k, v] of Object.entries(values)) {
      trimmed[k] = (v ?? "").trim();
    }
    saveMutation.mutate(trimmed);
  };

  const displayName = server?.server_name || server?.alias || server?.server_id || "MCP Server";
  const required = status?.required ?? [];
  const isSaving = saveMutation.isPending;

  return (
    <Modal
      open={open}
      onCancel={onClose}
      footer={null}
      width={520}
      destroyOnHidden
      afterOpenChange={(opened) => {
        if (opened) form.resetFields();
      }}
      title={
        <div>
          <div className="flex items-center gap-2">
            <Title level={5} style={{ margin: 0 }}>
              Set your credentials
            </Title>
            <Tag color="blue">Per-user</Tag>
          </div>
          <Text type="secondary" className="text-xs">
            {displayName}
          </Text>
        </div>
      }
    >
      <div className="space-y-4 mt-2">
        {isLoading ? (
          <div className="flex items-center justify-center py-8">
            <Spin />
          </div>
        ) : isError ? (
          <Alert type="error" showIcon message="Failed to load env vars" />
        ) : required.length === 0 ? (
          <Alert type="info" showIcon message="No per-user fields configured for this server." />
        ) : (
          <>
            <Text className="text-sm text-gray-600 block">
              These values are private to you. Your admin configured this MCP server to require these per-user
              credentials. Saved values are never shown back; leave an already-set field blank to keep it, or enter a
              value to set or change it.
            </Text>
            <Form form={form} layout="vertical" onFinish={handleSave} disabled={isSaving}>
              {required.map((spec) => (
                <Form.Item
                  key={spec.name}
                  name={spec.name}
                  label={
                    <span className="flex items-center gap-2">
                      <span className="font-mono text-sm font-semibold">{spec.name}</span>
                      {spec.is_set && <Tag color="green">Set</Tag>}
                    </span>
                  }
                  extra={spec.description || undefined}
                  // `whitespace: true` rejects whitespace-only input, which
                  // handleSave would otherwise trim to an empty string and
                  // submit as the value of a required field.
                  rules={
                    spec.is_set
                      ? undefined
                      : [{ required: true, whitespace: true, message: `${spec.name} is required` }]
                  }
                >
                  <Input.Password
                    placeholder={
                      spec.is_set ? "Enter a new value to overwrite" : spec.description || `Enter your ${spec.name}`
                    }
                    visibilityToggle
                  />
                </Form.Item>
              ))}
              <div className="flex items-center justify-end gap-2 pt-2 border-t border-gray-100">
                <Button onClick={onClose} disabled={isSaving}>
                  Cancel
                </Button>
                <Button type="primary" htmlType="submit" loading={isSaving}>
                  Save Credentials
                </Button>
              </div>
            </Form>
          </>
        )}
      </div>
    </Modal>
  );
};
|
||||
|
||||
export default UserEnvVarsModal;
|
||||
@ -13,8 +13,9 @@ import StdioConfiguration from "./StdioConfiguration";
|
||||
import MCPPermissionManagement from "./MCPPermissionManagement";
|
||||
import OpenAPIFormSection, { OpenAPIKeyTool } from "./OpenAPIFormSection";
|
||||
import MCPLogoSelector from "./MCPLogoSelector";
|
||||
import EnvVarsSection from "./EnvVarsSection";
|
||||
import { isAdminRole } from "@/utils/roles";
|
||||
import { validateMCPServerUrl, validateMCPServerName } from "./utils";
|
||||
import { validateMCPServerUrl, validateMCPServerName, normalizeEnvVars } from "./utils";
|
||||
import NotificationsManager from "../molecules/notifications_manager";
|
||||
import { useMcpOAuthFlow } from "@/hooks/useMcpOAuthFlow";
|
||||
import { useTestMCPConnection } from "@/hooks/useTestMCPConnection";
|
||||
@ -43,7 +44,7 @@ const reduceStaticHeaders = (list: unknown): Record<string, string> => {
|
||||
if (!Array.isArray(list)) return {};
|
||||
return list.reduce((acc: Record<string, string>, entry: Record<string, string>) => {
|
||||
const header = entry?.header?.trim();
|
||||
if (header) acc[header] = entry?.value ?? "";
|
||||
if (header) acc[header] = (entry?.value ?? "").trim();
|
||||
return acc;
|
||||
}, {});
|
||||
};
|
||||
@ -289,6 +290,7 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
|
||||
try {
|
||||
const {
|
||||
static_headers: staticHeadersList,
|
||||
env_vars: envVarsList,
|
||||
stdio_config: rawStdioConfig,
|
||||
credentials: credentialValues,
|
||||
allow_all_keys: allowAllKeysRaw,
|
||||
@ -303,6 +305,7 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
|
||||
const accessGroups = restValues.mcp_access_groups;
|
||||
|
||||
const staticHeaders = reduceStaticHeaders(staticHeadersList);
|
||||
const envVars = normalizeEnvVars(envVarsList);
|
||||
|
||||
const credentialsPayload =
|
||||
credentialValues && typeof credentialValues === "object"
|
||||
@ -403,10 +406,10 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
|
||||
delegate_auth_to_upstream: Boolean(delegateAuthToUpstreamRaw),
|
||||
oauth_passthrough: Boolean(oauthPassthroughRaw),
|
||||
static_headers: staticHeaders,
|
||||
env_vars: envVars,
|
||||
...(tokenValidation !== null && { token_validation: tokenValidation }),
|
||||
};
|
||||
|
||||
payload.static_headers = staticHeaders;
|
||||
const includeCredentials =
|
||||
restValues.auth_type && AUTH_TYPES_REQUIRING_CREDENTIALS.includes(restValues.auth_type);
|
||||
|
||||
@ -1026,6 +1029,11 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
|
||||
<StdioConfiguration isVisible={transportType === "stdio"} />
|
||||
</div>
|
||||
|
||||
{/* Environment Variables Section */}
|
||||
<div className="mt-8">
|
||||
<EnvVarsSection />
|
||||
</div>
|
||||
|
||||
{/* Permission Management / Access Control Section */}
|
||||
<div className="mt-8">
|
||||
<MCPPermissionManagement
|
||||
|
||||
@ -9,7 +9,8 @@ import MCPPermissionManagement from "./MCPPermissionManagement";
|
||||
import MCPToolConfiguration from "./mcp_tool_configuration";
|
||||
import StdioConfiguration from "./StdioConfiguration";
|
||||
import MCPLogoSelector from "./MCPLogoSelector";
|
||||
import { validateMCPServerUrl, validateMCPServerName } from "./utils";
|
||||
import EnvVarsSection from "./EnvVarsSection";
|
||||
import { validateMCPServerUrl, validateMCPServerName, normalizeEnvVars } from "./utils";
|
||||
import NotificationsManager from "../molecules/notifications_manager";
|
||||
import { useMcpOAuthFlow } from "@/hooks/useMcpOAuthFlow";
|
||||
import { getSecureItem, setSecureItem } from "@/utils/secureStorage";
|
||||
@ -117,7 +118,7 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
|
||||
if (!header) {
|
||||
return acc;
|
||||
}
|
||||
acc[header] = entry?.value ?? "";
|
||||
acc[header] = (entry?.value ?? "").trim();
|
||||
return acc;
|
||||
}, {})
|
||||
: ({} as Record<string, string>);
|
||||
@ -169,6 +170,18 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
|
||||
}));
|
||||
}, [mcpServer.static_headers]);
|
||||
|
||||
const initialEnvVars = React.useMemo(() => {
|
||||
if (!Array.isArray(mcpServer.env_vars)) {
|
||||
return [];
|
||||
}
|
||||
return mcpServer.env_vars.map((entry) => ({
|
||||
name: entry.name,
|
||||
value: entry.value ?? "",
|
||||
scope: entry.scope === "user" ? "user" : "global",
|
||||
description: entry.description ?? "",
|
||||
}));
|
||||
}, [mcpServer.env_vars]);
|
||||
|
||||
const initialEnvJson = React.useMemo(() => {
|
||||
const env = mcpServer.env ?? undefined;
|
||||
if (!env || Object.keys(env).length === 0) {
|
||||
@ -194,13 +207,14 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
|
||||
...mcpServer,
|
||||
transport: effectiveTransport,
|
||||
static_headers: initialStaticHeaders,
|
||||
env_vars: initialEnvVars,
|
||||
extra_headers: mcpServer.extra_headers || [],
|
||||
oauth_flow_type: mcpServer.token_url ? OAUTH_FLOW.M2M : OAUTH_FLOW.INTERACTIVE,
|
||||
token_validation_json: mcpServer.token_validation
|
||||
? JSON.stringify(mcpServer.token_validation, null, 2)
|
||||
: undefined,
|
||||
}),
|
||||
[mcpServer, effectiveTransport, initialStaticHeaders, initialEnvJson],
|
||||
[mcpServer, effectiveTransport, initialStaticHeaders, initialEnvVars, initialEnvJson],
|
||||
);
|
||||
|
||||
// Initialize cost config from existing server data
|
||||
@ -388,6 +402,7 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
|
||||
// Ensure access groups is always a string array
|
||||
const {
|
||||
static_headers: staticHeadersList,
|
||||
env_vars: envVarsList,
|
||||
credentials: credentialValues,
|
||||
stdio_config: rawStdioConfig,
|
||||
env_json: rawEnvJson,
|
||||
@ -411,11 +426,13 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
|
||||
if (!header) {
|
||||
return acc;
|
||||
}
|
||||
acc[header] = entry?.value ?? "";
|
||||
acc[header] = (entry?.value ?? "").trim();
|
||||
return acc;
|
||||
}, {})
|
||||
: ({} as Record<string, string>);
|
||||
|
||||
const envVars = normalizeEnvVars(envVarsList);
|
||||
|
||||
const credentialsPayload =
|
||||
credentialValues && typeof credentialValues === "object"
|
||||
? Object.entries(credentialValues).reduce((acc: Record<string, any>, [key, value]) => {
|
||||
@ -571,6 +588,7 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
|
||||
tool_name_to_description: Object.keys(toolNameToDescription).length > 0 ? toolNameToDescription : null,
|
||||
disallowed_tools: restValues.disallowed_tools || [],
|
||||
static_headers: staticHeaders,
|
||||
env_vars: envVars,
|
||||
allow_all_keys: Boolean(allowAllKeysRaw ?? mcpServer.allow_all_keys),
|
||||
available_on_public_internet: Boolean(availableOnPublicInternetRaw ?? mcpServer.available_on_public_internet),
|
||||
// ``delegate_auth_to_upstream`` is only honored server-side for
|
||||
@ -1108,6 +1126,11 @@ const MCPServerEdit: React.FC<MCPServerEditProps> = ({
|
||||
</>
|
||||
)}
|
||||
|
||||
{/* Environment Variables Section */}
|
||||
<div className="mt-6">
|
||||
<EnvVarsSection />
|
||||
</div>
|
||||
|
||||
{/* Permission Management / Access Control Section */}
|
||||
<div className="mt-6">
|
||||
<MCPPermissionManagement
|
||||
|
||||
@ -15,6 +15,7 @@ vi.mock("../networking", () => ({
|
||||
getGeneralSettingsCall: vi.fn().mockResolvedValue([]),
|
||||
updateConfigFieldSetting: vi.fn().mockResolvedValue(undefined),
|
||||
deleteConfigFieldSetting: vi.fn().mockResolvedValue(undefined),
|
||||
listMCPUserEnvVarStatus: vi.fn().mockResolvedValue([]),
|
||||
}));
|
||||
|
||||
// Mock NotificationsManager
|
||||
|
||||
@ -1,26 +1,72 @@
|
||||
import { isAdminRole } from "@/utils/roles";
|
||||
import { QuestionCircleOutlined } from "@ant-design/icons";
|
||||
import { QuestionCircleOutlined, SearchOutlined } from "@ant-design/icons";
|
||||
import { Button, Tab, TabGroup, TabList, TabPanel, TabPanels, Text, Title } from "@tremor/react";
|
||||
import NewBadge from "../common_components/NewBadge";
|
||||
import { Descriptions, Modal, Select, Tooltip, Typography } from "antd";
|
||||
import { Descriptions, Empty, Input, Modal, Select, Spin, Tooltip, Typography } from "antd";
|
||||
import React, { useEffect, useState, useMemo, useCallback } from "react";
|
||||
import { useQuery } from "@tanstack/react-query";
|
||||
import { useMCPServers } from "../../app/(dashboard)/hooks/mcpServers/useMCPServers";
|
||||
import { useMCPServerHealth } from "../../app/(dashboard)/hooks/mcpServers/useMCPServerHealth";
|
||||
import NotificationsManager from "../molecules/notifications_manager";
|
||||
import { deleteMCPServer } from "../networking";
|
||||
import { MCPSubmissionsTab } from "./MCPSubmissionsTab";
|
||||
import { MCPToolsetsTab } from "./MCPToolsetsTab";
|
||||
import { DataTable } from "../view_logs/table";
|
||||
import CreateMCPServer from "./create_mcp_server";
|
||||
import MCPConnect from "./mcp_connect";
|
||||
import { mcpServerColumns } from "./mcp_server_columns";
|
||||
import MCPServerCard from "./MCPServerCard";
|
||||
import { MCPServerView } from "./mcp_server_view";
|
||||
import { DiscoverableMCPServer, MCPServer, MCPServerProps, Team } from "./types";
|
||||
import type { DiscoverableMCPServer, MCPServer, MCPServerProps, MCPUserEnvVarsStatus, Team } from "./types";
|
||||
import MCPSemanticFilterSettings from "../Settings/AdminSettings/MCPSemanticFilterSettings/MCPSemanticFilterSettings";
|
||||
import MCPNetworkSettings from "./MCPNetworkSettings";
|
||||
import MCPDiscovery from "./mcp_discovery";
|
||||
import { ByokCredentialModal } from "./ByokCredentialModal";
|
||||
import { getSecureItem } from "@/utils/secureStorage";
|
||||
import UserEnvVarsModal from "./UserEnvVarsModal";
|
||||
import { listMCPUserEnvVarStatus } from "../networking";
|
||||
|
||||
// Sort orders understood by compareServers.
type SortKey = "created_desc" | "updated_desc" | "name_asc" | "health";
|
||||
|
||||
// Value/label pair for each SortKey.
// NOTE(review): presumably feeds the sort dropdown — the consumer is not visible here.
const SORT_OPTIONS: { value: SortKey; label: string }[] = [
|
||||
{ value: "created_desc", label: "Recently created" },
|
||||
{ value: "updated_desc", label: "Recently updated" },
|
||||
{ value: "name_asc", label: "Name (A→Z)" },
|
||||
{ value: "health", label: "Health (unhealthy first)" },
|
||||
];
|
||||
|
||||
// Ordering weight per health status for the "health" sort: lower ranks sort
// first. Statuses without an entry are treated as rank 1 by compareServers.
const HEALTH_RANK: Record<string, number> = {
|
||||
unhealthy: 0,
|
||||
unknown: 1,
|
||||
healthy: 2,
|
||||
};
|
||||
|
||||
/**
 * Converts a timestamp-like value to epoch milliseconds.
 * Missing or unparseable values map to 0 so the comparator below never
 * returns NaN (which would make Array.prototype.sort order inconsistent).
 */
const serverTimestamp = (value: string | number | Date | null | undefined): number => {
  if (!value) return 0;
  const ms = new Date(value).getTime();
  return Number.isNaN(ms) ? 0 : ms;
};

/**
 * Comparator for sorting MCP servers by the selected sort key.
 *
 * @param a - first server
 * @param b - second server
 * @param sort - requested ordering
 * @returns negative when `a` sorts before `b`, positive when after, 0 when tied
 *
 * - `name_asc`: case-insensitive by server_name, then alias, then server_id.
 * - `updated_desc` / `created_desc`: newest first; servers without a usable
 *   timestamp sort last.
 * - `health`: by HEALTH_RANK (unrecognised statuses rank as 1), ties broken
 *   by newest created first.
 * - any other key falls back to `created_desc`.
 */
const compareServers = (a: MCPServer, b: MCPServer, sort: SortKey): number => {
  const byCreatedDesc = (): number => serverTimestamp(b.created_at) - serverTimestamp(a.created_at);

  switch (sort) {
    case "name_asc": {
      const nameA = (a.server_name || a.alias || a.server_id).toLowerCase();
      const nameB = (b.server_name || b.alias || b.server_id).toLowerCase();
      return nameA.localeCompare(nameB);
    }
    case "updated_desc":
      return serverTimestamp(b.updated_at) - serverTimestamp(a.updated_at);
    case "health": {
      const ra = HEALTH_RANK[a.status ?? "unknown"] ?? 1;
      const rb = HEALTH_RANK[b.status ?? "unknown"] ?? 1;
      if (ra !== rb) return ra - rb;
      return byCreatedDesc();
    }
    case "created_desc":
    default:
      return byCreatedDesc();
  }
};
|
||||
|
||||
const { Text: AntdText, Title: AntdTitle } = Typography;
|
||||
const EDIT_OAUTH_UI_STATE_KEY = "litellm-mcp-oauth-edit-state";
|
||||
@ -67,8 +113,52 @@ const MCPServers: React.FC<MCPServerProps> = ({ accessToken, userRole, userID })
|
||||
const [prefillData, setPrefillData] = useState<DiscoverableMCPServer | null>(null);
|
||||
const [isDeletingServer, setIsDeletingServer] = useState(false);
|
||||
const [byokModalServer, setByokModalServer] = useState<MCPServer | null>(null);
|
||||
// Per-user env-var fill modal target + deep-link source captured once from the URL.
|
||||
const [envVarsModalServer, setEnvVarsModalServer] = useState<MCPServer | null>(null);
|
||||
const [deepLinkServerId, setDeepLinkServerId] = useState<string | null>(() =>
|
||||
typeof window === "undefined" ? null : new URLSearchParams(window.location.search).get("fill_env_vars"),
|
||||
);
|
||||
const [searchQuery, setSearchQuery] = useState<string>("");
|
||||
const [sortKey, setSortKey] = useState<SortKey>("created_desc");
|
||||
const isInternalUser = userRole === "Internal User";
|
||||
|
||||
// Single bulk fetch of this user's per-server env-var status. Drives the
|
||||
// red "N user fields missing" footer on each card with no per-row request.
|
||||
const { data: envVarStatuses, refetch: refetchEnvVarStatus } = useQuery<MCPUserEnvVarsStatus[]>({
|
||||
queryKey: ["mcpUserEnvVarStatus"],
|
||||
queryFn: () => listMCPUserEnvVarStatus(accessToken!),
|
||||
enabled: !!accessToken,
|
||||
});
|
||||
|
||||
// Per-server list of per-user fields this user still needs to fill in.
|
||||
const missingFieldsByServer = useMemo(() => {
|
||||
const map: Record<string, string[]> = {};
|
||||
for (const status of envVarStatuses ?? []) {
|
||||
map[status.server_id] = (status.required ?? []).filter((spec) => !spec.is_set).map((spec) => spec.name);
|
||||
}
|
||||
return map;
|
||||
}, [envVarStatuses]);
|
||||
|
||||
// Deep-link via ?fill_env_vars=<server_id> — the link users follow from the
|
||||
// friendly error the proxy returns when a per-user var is missing. The id is
|
||||
// captured into state above and resolved to a server below; here we only strip
|
||||
// the param so a refresh doesn't reopen the modal.
|
||||
useEffect(() => {
|
||||
if (!deepLinkServerId || typeof window === "undefined") return;
|
||||
const params = new URLSearchParams(window.location.search);
|
||||
if (!params.has("fill_env_vars")) return;
|
||||
params.delete("fill_env_vars");
|
||||
const newSearch = params.toString();
|
||||
const newUrl = window.location.pathname + (newSearch ? `?${newSearch}` : "") + window.location.hash;
|
||||
window.history.replaceState({}, "", newUrl);
|
||||
}, [deepLinkServerId]);
|
||||
|
||||
// Resolves the deep-linked id to a loaded server; null when there is no
// deep link or no server in serversWithHealth carries that id.
const deepLinkServer = useMemo(() => {
  if (!deepLinkServerId) {
    return null;
  }
  const match = serversWithHealth.find((candidate) => candidate.server_id === deepLinkServerId);
  return match ?? null;
}, [deepLinkServerId, serversWithHealth]);
|
||||
const activeEnvVarsServer = envVarsModalServer ?? deepLinkServer;
|
||||
|
||||
useEffect(() => {
|
||||
if (typeof window === "undefined") {
|
||||
return;
|
||||
@ -164,26 +254,20 @@ const MCPServers: React.FC<MCPServerProps> = ({ accessToken, userRole, userID })
|
||||
filterServers(selectedTeam, selectedMcpAccessGroup);
|
||||
}, [serversWithHealth, selectedTeam, selectedMcpAccessGroup, filterServers]);
|
||||
|
||||
const columns = React.useMemo(
|
||||
() =>
|
||||
mcpServerColumns(
|
||||
userRole ?? "",
|
||||
(serverId: string) => {
|
||||
setSelectedServerId(serverId);
|
||||
setEditServer(false);
|
||||
},
|
||||
(serverId: string) => {
|
||||
setSelectedServerId(serverId);
|
||||
setEditServer(true);
|
||||
},
|
||||
handleDelete,
|
||||
isLoadingHealth,
|
||||
(server: MCPServer) => setByokModalServer(server),
|
||||
recheckServerHealth,
|
||||
recheckingServerIds,
|
||||
),
|
||||
[userRole, isLoadingHealth, recheckServerHealth, recheckingServerIds],
|
||||
);
|
||||
// Search and sort, layered on top of the team/access-group filtering that
// produced filteredServers. Matching is a case-insensitive substring test
// against name, alias, URL and id; sorting works on a copy so filteredServers
// itself is never reordered.
const displayedServers = useMemo(() => {
  const needle = searchQuery.trim().toLowerCase();
  const visible =
    needle === ""
      ? filteredServers
      : filteredServers.filter((server) => {
          const haystacks = [server.server_name || "", server.alias || "", server.url || "", server.server_id].map(
            (text) => text.toLowerCase(),
          );
          return haystacks.some((text) => text.includes(needle));
        });
  return Array.from(visible).sort((left, right) => compareServers(left, right, sortKey));
}, [filteredServers, searchQuery, sortKey]);
|
||||
|
||||
function handleDelete(server_id: string) {
|
||||
setServerToDelete(server_id);
|
||||
@ -198,6 +282,14 @@ const MCPServers: React.FC<MCPServerProps> = ({ accessToken, userRole, userID })
|
||||
setIsDeletingServer(true);
|
||||
await deleteMCPServer(accessToken, serverIdToDelete);
|
||||
NotificationsManager.success("Deleted MCP Server successfully");
|
||||
// If the user is currently viewing the detail page of the server they
|
||||
// just deleted, return them to the All Servers list. Otherwise the
|
||||
// detail view would stay mounted, fall back to an empty stub server,
|
||||
// and show a phantom "Unnamed Server" page.
|
||||
if (selectedServerId === serverIdToDelete) {
|
||||
setEditServer(false);
|
||||
setSelectedServerId(null);
|
||||
}
|
||||
refetch();
|
||||
} catch (error) {
|
||||
console.error("Error deleting the mcp server:", error);
|
||||
@ -442,17 +534,75 @@ const MCPServers: React.FC<MCPServerProps> = ({ accessToken, userRole, userID })
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="w-full mt-6">
|
||||
<DataTable
|
||||
data={filteredServers}
|
||||
columns={columns}
|
||||
renderSubComponent={() => <div></div>}
|
||||
getRowCanExpand={() => false}
|
||||
isLoading={isLoadingServers}
|
||||
noDataMessage="No MCP servers configured. Click '+ Add New MCP Server' to get started."
|
||||
loadingMessage="Loading MCP servers..."
|
||||
enableSorting={true}
|
||||
<div className="mt-4 flex flex-wrap items-center gap-3">
|
||||
<Input
|
||||
allowClear
|
||||
prefix={<SearchOutlined className="text-gray-400" />}
|
||||
placeholder="Search by name, alias, URL, or ID"
|
||||
value={searchQuery}
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
style={{ maxWidth: 320 }}
|
||||
/>
|
||||
<div className="flex items-center gap-2">
|
||||
<Text className="whitespace-nowrap text-sm font-medium text-gray-600">Sort</Text>
|
||||
<Select
|
||||
value={sortKey}
|
||||
onChange={(v: SortKey) => setSortKey(v)}
|
||||
style={{ width: 220 }}
|
||||
size="middle"
|
||||
>
|
||||
{SORT_OPTIONS.map((opt) => (
|
||||
<Option key={opt.value} value={opt.value}>
|
||||
{opt.label}
|
||||
</Option>
|
||||
))}
|
||||
</Select>
|
||||
</div>
|
||||
<div className="ml-auto text-xs text-gray-500">
|
||||
{displayedServers.length} of {filteredServers.length} servers
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-4 w-full">
|
||||
{isLoadingServers ? (
|
||||
<div className="flex items-center justify-center rounded-lg border border-dashed border-gray-200 bg-white p-12">
|
||||
<Spin tip="Loading MCP servers..." />
|
||||
</div>
|
||||
) : displayedServers.length === 0 ? (
|
||||
<div className="rounded-lg border border-dashed border-gray-200 bg-white p-12">
|
||||
<Empty
|
||||
description={
|
||||
filteredServers.length === 0
|
||||
? "No MCP servers configured. Click '+ Add New MCP Server' to get started."
|
||||
: "No servers match the current filters or search."
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<div
|
||||
data-testid="mcp-servers-grid"
|
||||
className="grid auto-rows-fr grid-cols-1 gap-4 md:grid-cols-2 xl:grid-cols-3"
|
||||
>
|
||||
{displayedServers.map((server) => (
|
||||
<MCPServerCard
|
||||
key={server.server_id}
|
||||
server={server}
|
||||
missingUserFields={missingFieldsByServer[server.server_id]}
|
||||
isLoadingHealth={isLoadingHealth}
|
||||
isRechecking={recheckingServerIds?.has(server.server_id)}
|
||||
onClick={() => {
|
||||
setSelectedServerId(server.server_id);
|
||||
setEditServer(true);
|
||||
}}
|
||||
onRecheckHealth={
|
||||
recheckServerHealth ? () => recheckServerHealth(server.server_id) : undefined
|
||||
}
|
||||
onByokConnect={server.is_byok ? () => setByokModalServer(server) : undefined}
|
||||
onOpenFillFields={() => setEnvVarsModalServer(server)}
|
||||
onDelete={isAdminRole(userRole) ? () => handleDelete(server.server_id) : undefined}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
@ -489,6 +639,22 @@ const MCPServers: React.FC<MCPServerProps> = ({ accessToken, userRole, userID })
|
||||
accessToken={accessToken || ""}
|
||||
/>
|
||||
)}
|
||||
|
||||
{/* Per-user env-var fill modal — backed by /v1/mcp/server/{id}/user-env-vars */}
|
||||
<UserEnvVarsModal
|
||||
server={activeEnvVarsServer}
|
||||
open={!!activeEnvVarsServer}
|
||||
accessToken={accessToken}
|
||||
onClose={() => {
|
||||
setEnvVarsModalServer(null);
|
||||
setDeepLinkServerId(null);
|
||||
}}
|
||||
onSaved={() => {
|
||||
// Refresh the bulk status so the red "N user fields missing" footer
|
||||
// on each card clears once the user has filled in their values.
|
||||
refetchEnvVarStatus();
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@ -263,6 +263,41 @@ export interface MCPServer {
|
||||
/** Per-user OAuth token storage settings (interactive OAuth only) */
|
||||
token_validation?: Record<string, any> | null;
|
||||
token_storage_ttl_seconds?: number | null;
|
||||
|
||||
/**
|
||||
* Admin-configured env vars interpolated into static_headers via ${NAME}.
|
||||
* Stored as a list so the UI can preserve admin-entered ordering.
|
||||
*/
|
||||
env_vars?: MCPEnvVar[] | null;
|
||||
}
|
||||
|
||||
/**
 * Where an MCP env var's value comes from: "global" means the admin-entered
 * value is used as-is; "user" means each user supplies their own value.
 */
export type MCPEnvVarScope = "global" | "user";

/** One environment variable entry on an MCP server. */
export interface MCPEnvVar {
  /** Variable name, referenced as ${NAME} in static_headers. */
  name: string;
  /**
   * For scope="global": the value used in interpolation.
   * For scope="user": optional placeholder/description shown to users.
   * Note that normalizeEnvVars submits an empty string here for user-scoped
   * entries.
   */
  value: string;
  scope: MCPEnvVarScope;
  description?: string | null;
}

/** One required per-user env var slot returned by the user-env-vars endpoint. */
export interface MCPUserEnvVarSpec {
  name: string;
  description?: string | null;
  /** False while the user still has to fill this field in. */
  is_set: boolean;
}

/** Per-server per-user env var status returned by the API. */
export interface MCPUserEnvVarsStatus {
  server_id: string;
  server_name?: string | null;
  alias?: string | null;
  /** Per-user fields declared for this server, each with its `is_set` flag. */
  required: MCPUserEnvVarSpec[];
  // NOTE(review): presumably the number of `required` entries whose `is_set`
  // is false; confirm against the API.
  missing_count: number;
  // NOTE(review): looks like a link to where the user can fill in the
  // values; confirm against the API.
  setup_url?: string | null;
}
|
||||
|
||||
export interface MCPServerProps {
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
import { MCPEnvVar, MCPEnvVarScope } from "./types";
|
||||
|
||||
export const extractMCPToken = (url: string): { token: string | null; baseUrl: string } => {
|
||||
try {
|
||||
const mcpIndex = url.indexOf("/mcp/");
|
||||
@ -51,3 +53,27 @@ export const validateMCPServerName = (value: string) => {
|
||||
? Promise.reject("Cannot contain '-' (hyphen) or spaces. Please use '_' (underscore) instead.")
|
||||
: Promise.resolve();
|
||||
};
|
||||
|
||||
/**
 * Normalize the env_vars form list into the payload shape the backend expects.
 *
 * Rows are dropped when they are not objects, have an empty name, have a name
 * that is not a valid identifier (letter or underscore first, then letters,
 * digits or underscores), or repeat a name already kept (first one wins).
 * User-scoped entries never carry a value. Any scope other than "user" is
 * treated as "global".
 *
 * @param list Raw form value; anything that is not an array yields `[]`.
 * @returns Cleaned entries in their original order.
 */
export const normalizeEnvVars = (list: unknown): MCPEnvVar[] => {
  if (!Array.isArray(list)) return [];

  // Built once per call rather than once per row.
  const validName = /^[A-Za-z_][A-Za-z0-9_]*$/;
  const seen = new Set<string>();
  const out: MCPEnvVar[] = [];

  for (const entry of list) {
    if (!entry || typeof entry !== "object") continue;
    const record = entry as Record<string, unknown>;

    const name = String(record.name ?? "").trim();
    if (!name || seen.has(name)) continue;
    if (!validName.test(name)) continue;

    const scope: MCPEnvVarScope = record.scope === "user" ? "user" : "global";

    // The input is untrusted form state: keep the description only when it
    // really is a non-empty string, instead of casting whatever is there.
    const rawDescription = record.description;
    const description = typeof rawDescription === "string" && rawDescription !== "" ? rawDescription : undefined;

    out.push({
      name,
      value: scope === "user" ? "" : String(record.value ?? ""),
      scope,
      description,
    });
    seen.add(name);
  }
  return out;
};
|
||||
|
||||
@ -76,6 +76,7 @@ import { UserInfo } from "./view_users/types";
|
||||
import { EmailEventSettingsResponse, EmailEventSettingsUpdateRequest } from "./email_events/types";
|
||||
import { jsonFields } from "./common_components/check_openapi_schema";
|
||||
import NotificationsManager from "./molecules/notifications_manager";
|
||||
import type { MCPUserEnvVarsStatus } from "./mcp_tools/types";
|
||||
import { createApiClient, deriveErrorMessage } from "@/lib/http/client";
|
||||
import { resolveApiBase } from "@/lib/http/resolveApiBase";
|
||||
|
||||
@ -9620,6 +9621,35 @@ export const listMCPUserCredentials = async (accessToken: string): Promise<MCPUs
|
||||
return response.json();
|
||||
};
|
||||
|
||||
// ============================================================
|
||||
// MCP per-user env vars (/v1/mcp/server/{id}/user-env-vars)
|
||||
// ============================================================
|
||||
|
||||
/**
 * Fetch the per-user env-var status for a single MCP server.
 *
 * @param accessToken Token passed through to the API client.
 * @param serverId MCP server id; percent-encoded before it is placed in the
 *   URL path so characters such as "/", "?" or "#" cannot alter the route.
 * @returns The status object returned by GET /v1/mcp/server/{id}/user-env-vars.
 */
export const getMCPUserEnvVars = async (accessToken: string, serverId: string): Promise<MCPUserEnvVarsStatus> => {
  const path = `/v1/mcp/server/${encodeURIComponent(serverId)}/user-env-vars`;
  return apiClient.get<MCPUserEnvVarsStatus>(path, { accessToken });
};
|
||||
|
||||
/**
 * Save per-user env-var values for a single MCP server.
 *
 * @param accessToken Token passed through to the API client.
 * @param serverId MCP server id; percent-encoded before it is placed in the
 *   URL path so characters such as "/", "?" or "#" cannot alter the route.
 * @param values Map of env var name to the value to store, sent as
 *   `{ values }` in the request body.
 * @returns The status object returned by POST /v1/mcp/server/{id}/user-env-vars.
 */
export const storeMCPUserEnvVars = async (
  accessToken: string,
  serverId: string,
  values: Record<string, string>,
): Promise<MCPUserEnvVarsStatus> => {
  const path = `/v1/mcp/server/${encodeURIComponent(serverId)}/user-env-vars`;
  return apiClient.post<MCPUserEnvVarsStatus>(path, {
    accessToken,
    body: { values },
  });
};
|
||||
|
||||
/**
 * Fetch this user's per-user env-var status in one bulk request
 * (GET /v1/mcp/user-env-vars/status).
 *
 * Best-effort: the result only drives status badges, so a failure here must
 * not break the page. Any error is logged and an empty list is returned
 * instead of being surfaced to the caller.
 *
 * @param accessToken Token passed through to the API client.
 * @returns The status list, or `[]` when the request fails or the payload is
 *   not an array.
 */
export const listMCPUserEnvVarStatus = async (accessToken: string): Promise<MCPUserEnvVarsStatus[]> => {
  try {
    const statuses = await apiClient.get<MCPUserEnvVarsStatus[]>("/v1/mcp/user-env-vars/status", { accessToken });
    // Callers iterate the result directly, so never hand back a non-array.
    return Array.isArray(statuses) ? statuses : [];
  } catch (error: unknown) {
    // Still best-effort, but leave a trace so the failure can be diagnosed.
    console.warn("Failed to load MCP user env var status:", error);
    return [];
  }
};
|
||||
|
||||
// ============================================================
|
||||
// Memory management (/v1/memory)
|
||||
// ============================================================
|
||||
|
||||
Loading…
Reference in New Issue
Block a user