fix(mcp): OAuth2 chat connect - tools fetch, auth, and status fixes (#23406)

* fix(mcp): OAuth2 chat connect - tools fetch, auth flow, and status fixes

- schema.prisma: add missing MCP table fields (approval_status, submitted_by, submitted_at, reviewed_at, review_notes) to prevent destructive migrations
- rest_endpoints.py: inject user OAuth token via extra_headers for OAuth2 servers so tools list is populated; add server name->UUID resolution so MCPConnectPicker name lookups work
- mcp_registry.json: fix Atlassian defaults (transport: http, url: .../v1/mcp)
- ChatPage.tsx: read mcpOauthReturn param to init sidebarView="apps" on OAuth return, clean up param after mount
- MCPAppsPanel.tsx: auto-add OAuth2 servers to selectedServers when credential detected; onConnect also enables server for chat; disconnect removes from selectedServers
- mcp_servers.tsx: sort servers by created_at DESC
- useUserMcpOAuthFlow.tsx: append mcpOauthReturn=apps to return URL so Apps panel is mounted on return

* address greptile review feedback (greploop iteration 1)

* fix(mcp): inject stored OAuth2 token when fetching tools via /responses API

When a user has connected an OAuth2 MCP server (e.g. Atlassian) and then
uses the /responses endpoint with that server, tool listing was failing
because the stored per-user OAuth token was never injected.

Two fixes:
1. server.py: add _get_user_oauth_extra_headers_from_db() helper; call it
   in _get_tools_from_mcp_servers when oauth2_headers is None for an OAuth2
   server, falling back to the user's stored token in LiteLLM_MCPUserCredentials
2. litellm_proxy_mcp_handler.py: also intercept MCP tools whose server_url
   matches */mcp/<server_name> (e.g. http://localhost:4000/mcp/atlassian_test)
   by rewriting them to litellm_proxy/mcp/<server_name> so they go through
   the internal handler (and get the OAuth token injected) instead of being
   forwarded to OpenAI raw where localhost is unreachable

* address greptile review feedback (greploop iteration 2)

* test(mcp): add unit test for OAuth2 token injection in _get_tools_from_mcp_servers

Verifies that when _get_tools_from_mcp_servers is called for an OAuth2 MCP
server without oauth2_headers in the request, the implementation:
- calls _prefetch_oauth_creds_for_user once (not per-server) to avoid N+1 queries
- passes the stored token as extra_headers={"Authorization": "Bearer ..."} to
  _get_tools_from_server so the upstream OAuth2 MCP server authenticates correctly

* address greptile review feedback (greploop iteration 3)

* address greptile review feedback (greploop iteration 4)

* address greptile review feedback (greploop iteration 5)

* redesign credentials table to use Tremor table layout matching Keys page

* fix: /server/oauth authorize 422 - make client_id optional, fall back to real DB server

* fix: mcp_token client_id optional, resolve from server record

* fix: look up real server by UUID (get_mcp_server_by_id) before falling back to name

* Update litellm/responses/mcp/litellm_proxy_mcp_handler.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* fix: address greptile feedback - client_id guards, dict spread, helper refactor, tests

- mcp_management_endpoints: raise 400 when resolved_client_id is empty in
  mcp_authorize and mcp_token instead of forwarding "" to upstream
- litellm_proxy_mcp_handler: use {**tool, "server_url": ...} spread instead
  of dict(tool) + mutation for shallow copy safety
- rest_endpoints: extract _oauth2_server_ids set comprehension to a named
  _get_oauth2_server_ids() helper for clarity; add Set to typing imports
- test_rest_endpoints: add tests for name→UUID resolution path,
  access-denied when resolved UUID not in allowed list, and OAuth2 user
  token injection for single-server requests; fix fake_get_tools signature
  to accept extra_headers kwarg

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
This commit is contained in:
Ishaan Jaff 2026-03-11 22:07:02 -07:00 committed by GitHub
parent 2b7b7d3086
commit 19db79db17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 596 additions and 109 deletions

View File

@ -315,6 +315,11 @@ model LiteLLM_MCPServerTable {
is_byok Boolean @default(false)
byok_description String[] @default([])
byok_api_key_help_url String?
approval_status String @default("approved")
submitted_by String?
submitted_at DateTime?
reviewed_at DateTime?
review_notes String?
}
// Per-user BYOK credentials for MCP servers

View File

@ -628,6 +628,24 @@ async def store_user_oauth_credential(
)
def is_oauth_credential_expired(cred: Dict[str, Any]) -> bool:
"""Return True if the OAuth2 credential's access_token has expired.
Checks the ``expires_at`` ISO-format string stored in the credential payload.
Returns False when ``expires_at`` is absent or unparseable (treat as non-expired).
"""
expires_at = cred.get("expires_at")
if not expires_at:
return False
try:
exp_dt = datetime.fromisoformat(expires_at)
if exp_dt.tzinfo is None:
exp_dt = exp_dt.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > exp_dt
except (ValueError, TypeError):
return False
async def get_user_oauth_credential(
prisma_client: PrismaClient,
user_id: str,

View File

@ -1,6 +1,6 @@
import importlib
from datetime import datetime
from typing import Any, Awaitable, Callable, Dict, List, Optional, Union
from datetime import datetime, timezone
from typing import Any, Awaitable, Callable, Dict, List, Optional, Set, Union
from fastapi import APIRouter, Depends, HTTPException, Query, Request
@ -69,18 +69,36 @@ if MCP_AVAILABLE:
return server_auth
return mcp_auth_header
def _get_oauth2_server_ids(allowed_server_ids: List[str]) -> Set[str]:
"""Return the subset of *allowed_server_ids* whose servers use OAuth2 auth.
Used as a cheap pre-flight check to skip bulk credential fetching when no
OAuth2 servers are involved in the current request.
"""
return {
sid
for sid in allowed_server_ids
if getattr(
global_mcp_server_manager.get_mcp_server_by_id(sid), "auth_type", None
)
== MCPAuth.oauth2
}
async def _get_user_oauth_extra_headers(
server,
user_api_key_dict: UserAPIKeyAuth,
prefetched_creds: Optional[Dict[str, Dict[str, Any]]] = None,
) -> Optional[Dict[str, str]]:
"""
For OAuth2 servers, look up the user's stored access token and return it
as extra_headers {"Authorization": "Bearer <token>"} so that it reaches
the MCP server the same way the admin "Add MCP / Authorize and Fetch" flow does.
Returns None for non-OAuth2 servers or when no credential is stored.
"""
from litellm.types.mcp import MCPAuth
Args:
prefetched_creds: Optional dict keyed by server_id with credential payloads.
When provided, avoids a per-server DB round-trip.
"""
if getattr(server, "auth_type", None) != MCPAuth.oauth2:
return None
user_id = getattr(user_api_key_dict, "user_id", None)
@ -90,18 +108,60 @@ if MCP_AVAILABLE:
try:
from litellm.proxy._experimental.mcp_server.db import (
get_user_oauth_credential,
is_oauth_credential_expired,
)
if prefetched_creds is not None:
cred = prefetched_creds.get(server_id)
else:
from litellm.proxy.utils import get_prisma_client_or_throw
prisma_client = get_prisma_client_or_throw(
"Database not connected. Connect a database to use OAuth2 MCP tools."
)
cred = await get_user_oauth_credential(prisma_client, user_id, server_id)
if cred and cred.get("access_token"):
if is_oauth_credential_expired(cred):
verbose_logger.debug(
f"_get_user_oauth_extra_headers: token expired for "
f"user={user_id} server={server_id}"
)
return None
return {"Authorization": f"Bearer {cred['access_token']}"}
except Exception as e:
verbose_logger.warning(
f"_get_user_oauth_extra_headers: failed to retrieve credential for "
f"user={user_id} server={server_id}: {e}"
)
return None
async def _prefetch_user_oauth_creds(
user_api_key_dict: UserAPIKeyAuth,
) -> Dict[str, Dict[str, Any]]:
"""Fetch all OAuth2 credentials for the user in a single DB query.
Returns a dict keyed by server_id. Used to avoid N+1 DB queries when
iterating over multiple OAuth2 MCP servers.
"""
user_id = getattr(user_api_key_dict, "user_id", None)
if not user_id:
return {}
try:
from litellm.proxy._experimental.mcp_server.db import (
list_user_oauth_credentials,
)
from litellm.proxy.utils import get_prisma_client_or_throw
prisma_client = get_prisma_client_or_throw(
"Database not connected. Connect a database to use OAuth2 MCP tools."
)
cred = await get_user_oauth_credential(prisma_client, user_id, server_id)
if cred and cred.get("access_token"):
return {"Authorization": f"Bearer {cred['access_token']}"}
except Exception:
verbose_logger.debug("Failed to fetch OAuth credential", exc_info=True)
return None
creds = await list_user_oauth_credentials(prisma_client, user_id)
return {c["server_id"]: c for c in creds if "server_id" in c}
except Exception as e:
verbose_logger.warning(
f"_prefetch_user_oauth_creds: failed to prefetch for user={user_id}: {e}"
)
return {}
async def _get_bulk_user_oauth_headers(
user_api_key_dict: UserAPIKeyAuth,
@ -364,13 +424,21 @@ if MCP_AVAILABLE:
if server_id:
# Resolve a server name to its UUID if needed (MCPConnectPicker passes
# server_name strings, but allowed_server_ids_set contains UUIDs).
# _name_resolved is kept so the second check can reuse it for accurate
# IP-filter error reporting if the resolved UUID is not in allowed_server_ids.
_name_resolved = None
if server_id not in allowed_server_ids:
_resolved = global_mcp_server_manager.get_mcp_server_by_name(server_id)
if _resolved is not None and _resolved.server_id in set(allowed_server_ids):
server_id = _resolved.server_id
_name_resolved = global_mcp_server_manager.get_mcp_server_by_name(server_id)
if _name_resolved is not None and _name_resolved.server_id in set(allowed_server_ids):
server_id = _name_resolved.server_id
if server_id not in allowed_server_ids:
_server = global_mcp_server_manager.get_mcp_server_by_id(server_id)
# Try UUID lookup first; fall back to the name-resolved server so that
# IP-filter reporting works correctly even when server_id is a name string.
_server = (
global_mcp_server_manager.get_mcp_server_by_id(server_id)
or _name_resolved
)
if (
_server is not None
and _rest_client_ip is not None
@ -451,10 +519,15 @@ if MCP_AVAILABLE:
},
)
# Query all servers the user has access to.
# Bulk-fetch OAuth creds once so each per-server call below can
# do an O(1) dict lookup instead of N individual DB queries.
bulk_oauth_headers = await _get_bulk_user_oauth_headers(user_api_key_dict)
# Pre-fetch OAuth credentials only when at least one allowed server uses OAuth2,
# to avoid an unnecessary DB round-trip on requests with no OAuth2 MCP servers.
prefetched_oauth_creds = (
await _prefetch_user_oauth_creds(user_api_key_dict)
if _get_oauth2_server_ids(allowed_server_ids)
else {}
)
# Query all servers the user has access to
errors = []
for allowed_server_id in allowed_server_ids:
server = global_mcp_server_manager.get_mcp_server_by_id(
@ -466,7 +539,9 @@ if MCP_AVAILABLE:
server_auth_header = _get_server_auth_header(
server, mcp_server_auth_headers, mcp_auth_header
)
user_oauth_extra_headers = bulk_oauth_headers.get(server.server_id)
user_oauth_extra_headers = await _get_user_oauth_extra_headers(
server, user_api_key_dict, prefetched_creds=prefetched_oauth_creds
)
try:
tools_result = await _get_tools_for_single_server(

View File

@ -8,7 +8,7 @@ import contextlib
import time
import traceback
import uuid
from datetime import datetime
from datetime import datetime, timezone
from typing import (
Any,
AsyncIterator,
@ -871,6 +871,84 @@ if MCP_AVAILABLE:
return allowed_mcp_servers
async def _get_user_oauth_extra_headers_from_db(
server: MCPServer,
user_api_key_auth: Optional[UserAPIKeyAuth],
prefetched_creds: Optional[Dict[str, Dict[str, Any]]] = None,
) -> Optional[Dict[str, str]]:
"""Look up stored OAuth2 token for (user, server) from DB and return as extra_headers dict.
Args:
prefetched_creds: Optional dict keyed by server_id with credential payloads.
When provided, avoids a per-server DB round-trip.
"""
if server.auth_type != MCPAuth.oauth2:
return None
if user_api_key_auth is None:
return None
user_id = getattr(user_api_key_auth, "user_id", None)
server_id = getattr(server, "server_id", None)
if not user_id or not server_id:
return None
try:
from litellm.proxy._experimental.mcp_server.db import ( # noqa: PLC0415
get_user_oauth_credential,
is_oauth_credential_expired,
)
if prefetched_creds is not None:
cred = prefetched_creds.get(server_id)
else:
from litellm.proxy.utils import ( # noqa: PLC0415
get_prisma_client_or_throw,
)
prisma_client = get_prisma_client_or_throw(
"Database not connected. Connect a database to use OAuth2 MCP tools."
)
cred = await get_user_oauth_credential(prisma_client, user_id, server_id)
if cred and cred.get("access_token"):
if is_oauth_credential_expired(cred):
verbose_logger.debug(
f"_get_user_oauth_extra_headers_from_db: token expired for "
f"user={user_id} server={server_id}"
)
return None
return {"Authorization": f"Bearer {cred['access_token']}"}
except Exception as e:
verbose_logger.warning(
f"_get_user_oauth_extra_headers_from_db: failed to retrieve credential for "
f"user={user_id} server={server_id}: {e}"
)
return None
async def _prefetch_oauth_creds_for_user(
user_api_key_auth: Optional[UserAPIKeyAuth],
) -> Dict[str, Dict[str, Any]]:
"""Fetch all OAuth2 credentials for the user in one DB query.
Returns a dict keyed by server_id to avoid N+1 queries in asyncio.gather loops.
"""
user_id = getattr(user_api_key_auth, "user_id", None) if user_api_key_auth else None
if not user_id:
return {}
try:
from litellm.proxy._experimental.mcp_server.db import ( # noqa: PLC0415
list_user_oauth_credentials,
)
from litellm.proxy.utils import get_prisma_client_or_throw # noqa: PLC0415
prisma_client = get_prisma_client_or_throw(
"Database not connected. Connect a database to use OAuth2 MCP tools."
)
creds = await list_user_oauth_credentials(prisma_client, user_id)
return {c["server_id"]: c for c in creds if "server_id" in c}
except Exception as e:
verbose_logger.warning(
f"_prefetch_oauth_creds_for_user: failed to prefetch for user={user_id}: {e}"
)
return {}
def _prepare_mcp_server_headers(
server: MCPServer,
mcp_server_auth_headers: Optional[Dict[str, Dict[str, str]]],
@ -1015,6 +1093,18 @@ if MCP_AVAILABLE:
mcp_servers=mcp_servers,
)
# Pre-fetch OAuth credentials only when at least one server uses OAuth2,
# to avoid an unnecessary DB round-trip on requests with no OAuth2 MCP servers.
_has_oauth2_server = any(
getattr(s, "auth_type", None) == MCPAuth.oauth2
for s in allowed_mcp_servers
)
_prefetched_oauth_creds = (
await _prefetch_oauth_creds_for_user(user_api_key_auth)
if _has_oauth2_server
else {}
)
async def _fetch_and_filter_server_tools(
server: MCPServer,
) -> List[MCPTool]:
@ -1030,6 +1120,12 @@ if MCP_AVAILABLE:
raw_headers=raw_headers,
)
# If no OAuth2 token came from request headers, fall back to pre-fetched creds
if extra_headers is None and server.auth_type == MCPAuth.oauth2:
extra_headers = await _get_user_oauth_extra_headers_from_db(
server, user_api_key_auth, prefetched_creds=_prefetched_oauth_creds
)
try:
tools = await global_mcp_server_manager._get_tools_from_server(
server=server,

View File

@ -1306,10 +1306,21 @@ if MCP_AVAILABLE:
def _get_cached_temporary_mcp_server_or_404(server_id: str) -> MCPServer:
server = get_cached_temporary_mcp_server(server_id)
if server is None:
# Fall back to real DB/config server (e.g. for the user-side OAuth flow
# which calls these endpoints with a real server_id, not a temp session id).
from litellm.proxy._experimental.mcp_server.mcp_server_manager import (
global_mcp_server_manager,
)
server = (
global_mcp_server_manager.get_mcp_server_by_id(server_id)
or global_mcp_server_manager.get_mcp_server_by_name(server_id)
)
if server is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail={"error": f"Temporary MCP server {server_id} not found"},
detail={"error": f"MCP server {server_id} not found"},
)
return server
@ -1320,8 +1331,8 @@ if MCP_AVAILABLE:
async def mcp_authorize(
request: Request,
server_id: str,
client_id: str,
redirect_uri: str,
client_id: Optional[str] = None,
redirect_uri: str = Query(...),
state: str = "",
code_challenge: Optional[str] = None,
code_challenge_method: Optional[str] = None,
@ -1329,10 +1340,23 @@ if MCP_AVAILABLE:
scope: Optional[str] = None,
):
mcp_server = _get_cached_temporary_mcp_server_or_404(server_id)
# Use the server's stored client_id when the caller doesn't supply one
resolved_client_id = mcp_server.client_id or client_id or ""
if not resolved_client_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "missing_client_id",
"message": (
"No client_id available for this MCP server. "
"Either configure the server with a client_id or supply one in the request."
),
},
)
return await authorize_with_server(
request=request,
mcp_server=mcp_server,
client_id=client_id,
client_id=resolved_client_id,
redirect_uri=redirect_uri,
state=state,
code_challenge=code_challenge,
@ -1351,18 +1375,30 @@ if MCP_AVAILABLE:
grant_type: str = Form(...),
code: Optional[str] = Form(None),
redirect_uri: Optional[str] = Form(None),
client_id: str = Form(...),
client_id: Optional[str] = Form(None),
client_secret: Optional[str] = Form(None),
code_verifier: Optional[str] = Form(None),
):
mcp_server = _get_cached_temporary_mcp_server_or_404(server_id)
resolved_client_id = mcp_server.client_id or client_id or ""
if not resolved_client_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={
"error": "missing_client_id",
"message": (
"No client_id available for this MCP server. "
"Either configure the server with a client_id or supply one in the request."
),
},
)
return await exchange_token_with_server(
request=request,
mcp_server=mcp_server,
grant_type=grant_type,
code=code,
redirect_uri=redirect_uri,
client_id=client_id,
client_id=resolved_client_id,
client_secret=client_secret,
code_verifier=code_verifier,
)

View File

@ -1,3 +1,4 @@
import re
import traceback
from datetime import datetime
from typing import (
@ -43,6 +44,13 @@ ToolParam = Any
LITELLM_PROXY_MCP_SERVER_URL = "litellm_proxy"
LITELLM_PROXY_MCP_SERVER_URL_PREFIX = f"{LITELLM_PROXY_MCP_SERVER_URL}/mcp/"
# Matches any URL whose path ends with /mcp/<server_name> — covers both root-path
# (http://host:port/mcp/name) and sub-path (http://host/base/mcp/name) proxy deployments.
# A false-positive match (e.g. an external URL that happens to end with /mcp/<name>) results
# in a "server not found" error from the internal gateway, not a silent failure or data leak,
# so this broad pattern is intentional and preferred over anchoring to localhost only.
_PROXY_MCP_PATH_RE = re.compile(r"^https?://.+/mcp/([^/]+)$")
class LiteLLM_Proxy_MCP_Handler:
"""
@ -54,7 +62,8 @@ class LiteLLM_Proxy_MCP_Handler:
@staticmethod
def _should_use_litellm_mcp_gateway(tools: Optional[Iterable[ToolParam]]) -> bool:
"""
Returns True if the user passed a MCP tool with server_url="litellm_proxy"
Returns True if any MCP tool should be handled via the litellm proxy MCP gateway.
This includes tools with server_url="litellm_proxy" as well as URLs ending in /mcp/<name>.
"""
if tools:
for tool in tools:
@ -64,6 +73,10 @@ class LiteLLM_Proxy_MCP_Handler:
LITELLM_PROXY_MCP_SERVER_URL
):
return True
if isinstance(server_url, str) and _PROXY_MCP_PATH_RE.match(
server_url
):
return True
return False
@staticmethod
@ -87,6 +100,18 @@ class LiteLLM_Proxy_MCP_Handler:
LITELLM_PROXY_MCP_SERVER_URL
):
mcp_tools_with_litellm_proxy.append(tool)
elif isinstance(server_url, str):
# Also intercept URLs like http://localhost:4000/mcp/atlassian_test
# by rewriting them to the internal litellm_proxy format.
m = _PROXY_MCP_PATH_RE.match(server_url)
if m:
rewritten = {
**tool,
"server_url": f"{LITELLM_PROXY_MCP_SERVER_URL_PREFIX}{m.group(1)}",
}
mcp_tools_with_litellm_proxy.append(rewritten)
else:
other_tools.append(tool)
else:
other_tools.append(tool)
else:

View File

@ -2093,3 +2093,83 @@ async def test_get_tools_from_mcp_servers_logs_list_tools_to_spendlogs_when_enab
assert spend_meta["tool_count_total"] == 1
assert spend_meta["allowed_server_count"] == 1
assert spend_meta["per_server_tool_counts"]["server_a"] == 1
@pytest.mark.asyncio
async def test_get_tools_from_mcp_servers_injects_stored_oauth2_token():
"""
When _get_tools_from_mcp_servers is called for an OAuth2 MCP server and no
oauth2_headers are provided in the request (e.g. a /responses API call from a
chat UI), the per-user stored token must be fetched from the DB and passed as
extra_headers to _get_tools_from_server.
The implementation pre-fetches all user credentials in a single bulk query
(_prefetch_oauth_creds_for_user) to avoid N+1 queries in the gather loop.
This covers the bug where OAuth2 MCP tools were always empty in the /responses
API because the stored credential was never injected.
"""
try:
from litellm.proxy._experimental.mcp_server.server import (
_get_tools_from_mcp_servers,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.mcp import MCPAuth
except ImportError:
pytest.skip("MCP server not available")
STORED_TOKEN = "atlassian-oauth-access-token-xyz"
SERVER_ID = "srv-oauth2-id"
USER_ID = "user-123"
user_auth = UserAPIKeyAuth(api_key="test-key", user_id=USER_ID)
oauth2_server = MagicMock(name="atlassian_server")
oauth2_server.name = "atlassian_test"
oauth2_server.alias = "atlassian_test"
oauth2_server.server_name = "atlassian_test"
oauth2_server.server_id = SERVER_ID
oauth2_server.auth_type = MCPAuth.oauth2
oauth2_server.extra_headers = None
# Simulate the DB returning a valid credential for this user+server
prefetched_creds = {SERVER_ID: {"access_token": STORED_TOKEN, "server_id": SERVER_ID}}
tool_1 = MagicMock()
tool_1.name = "atlassian_test-search"
with patch(
"litellm.proxy._experimental.mcp_server.server._get_allowed_mcp_servers",
new=AsyncMock(return_value=[oauth2_server]),
), patch(
# Patch the bulk prefetch so no real DB connection is needed
"litellm.proxy._experimental.mcp_server.server._prefetch_oauth_creds_for_user",
new=AsyncMock(return_value=prefetched_creds),
) as mock_prefetch, patch(
"litellm.proxy._experimental.mcp_server.server.global_mcp_server_manager",
) as mock_manager, patch(
"litellm.proxy._experimental.mcp_server.server.filter_tools_by_allowed_tools",
side_effect=lambda tools, _server: tools,
), patch(
"litellm.proxy._experimental.mcp_server.server.filter_tools_by_key_team_permissions",
new=AsyncMock(side_effect=lambda tools, **_: tools),
):
mock_manager._get_tools_from_server = AsyncMock(return_value=[tool_1])
tools = await _get_tools_from_mcp_servers(
user_api_key_auth=user_auth,
mcp_auth_header=None,
mcp_servers=["atlassian_test"],
mcp_server_auth_headers=None,
oauth2_headers=None, # No token from request — must fall back to DB
)
# Bulk credential prefetch was called once (not once per server)
mock_prefetch.assert_awaited_once_with(user_auth)
# The stored token was forwarded to the MCP transport layer as extra_headers
mock_manager._get_tools_from_server.assert_awaited_once()
call_kwargs = mock_manager._get_tools_from_server.await_args.kwargs
assert call_kwargs["extra_headers"] == {"Authorization": f"Bearer {STORED_TOKEN}"}
assert tools == [tool_1]

View File

@ -484,7 +484,7 @@ class TestListToolsRestAPI:
captured = {"called": False}
async def fake_get_tools(
server, server_auth_header, raw_headers=None, user_api_key_auth=None
server, server_auth_header, raw_headers=None, user_api_key_auth=None, extra_headers=None
):
captured["called"] = True
captured["server"] = server
@ -529,6 +529,175 @@ class TestListToolsRestAPI:
assert result["error"] is None
assert result["message"] == "Successfully retrieved tools"
async def test_name_resolution_finds_server_by_uuid(self, monkeypatch):
"""When server_id is a name string, it should be resolved to its UUID
and used for the tools lookup when the UUID is in allowed_server_ids."""
from litellm.proxy._experimental.mcp_server.server import MCPServer
from litellm.types.mcp import MCPTransport
stub_server = MCPServer(
server_id="uuid-abc-123",
name="my-server",
transport=MCPTransport.sse,
)
stub_server.alias = "my-server"
stub_server.server_name = "my-server"
stub_server.available_on_public_internet = True
stub_server.allowed_tools = None
stub_server.mcp_info = {"server_name": "my-server"}
async def fake_contexts(user_api_key_auth):
return [user_api_key_auth]
# Allowed list contains the UUID, not the name
async def fake_get_allowed_mcp_servers(*args, **kwargs):
return ["uuid-abc-123"]
captured = {"called": False, "server_arg": None}
async def fake_get_tools(server, server_auth_header, raw_headers=None, user_api_key_auth=None, extra_headers=None):
captured["called"] = True
captured["server_arg"] = server
return ["tool-x"]
monkeypatch.setattr(rest_endpoints, "build_effective_auth_contexts", fake_contexts, raising=False)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager, "get_allowed_mcp_servers",
fake_get_allowed_mcp_servers, raising=False,
)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_name",
lambda name: stub_server if name == "my-server" else None,
raising=False,
)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_id",
lambda sid: stub_server if sid == "uuid-abc-123" else None,
raising=False,
)
monkeypatch.setattr(rest_endpoints, "_get_tools_for_single_server", fake_get_tools, raising=False)
request = _build_request(path="/mcp-rest/tools/list", method="GET")
result = await rest_endpoints.list_tool_rest_api(
request,
server_id="my-server", # pass name, not UUID
user_api_key_dict=UserAPIKeyAuth(),
)
assert captured["called"] is True
assert captured["server_arg"] is stub_server
assert result["tools"] == ["tool-x"]
assert result["error"] is None
async def test_name_not_in_allowed_returns_access_denied(self, monkeypatch):
"""When name resolves to a server whose UUID is NOT in allowed_server_ids,
the result should be an access_denied error (not a crash or silent pass)."""
from litellm.proxy._experimental.mcp_server.server import MCPServer
from litellm.types.mcp import MCPTransport
stub_server = MCPServer(
server_id="uuid-xyz-999",
name="restricted-server",
transport=MCPTransport.sse,
)
stub_server.available_on_public_internet = True
async def fake_contexts(user_api_key_auth):
return [user_api_key_auth]
# No allowed servers for this key
async def fake_get_allowed_mcp_servers(*args, **kwargs):
return []
monkeypatch.setattr(rest_endpoints, "build_effective_auth_contexts", fake_contexts, raising=False)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager, "get_allowed_mcp_servers",
fake_get_allowed_mcp_servers, raising=False,
)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_name",
lambda name: stub_server if name == "restricted-server" else None,
raising=False,
)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_id",
lambda sid: stub_server if sid == "uuid-xyz-999" else None,
raising=False,
)
request = _build_request(path="/mcp-rest/tools/list", method="GET")
result = await rest_endpoints.list_tool_rest_api(
request,
server_id="restricted-server",
user_api_key_dict=UserAPIKeyAuth(),
)
assert result["tools"] == []
assert result["error"] == "unexpected_error"
assert "access_denied" in result["message"]
async def test_oauth2_user_token_injected_for_single_server(self, monkeypatch):
"""For a single-server OAuth2 request, _get_user_oauth_extra_headers is called
and the returned headers are forwarded to _get_tools_for_single_server."""
from litellm.proxy._experimental.mcp_server.server import MCPServer
from litellm.types.mcp import MCPTransport
stub_server = MCPServer(
server_id="oauth-server-id",
name="oauth-server",
transport=MCPTransport.sse,
)
stub_server.alias = "oauth-server"
stub_server.server_name = "oauth-server"
stub_server.available_on_public_internet = True
stub_server.allowed_tools = None
stub_server.mcp_info = {"server_name": "oauth-server"}
stub_server.auth_type = MCPAuth.oauth2
async def fake_contexts(user_api_key_auth):
return [user_api_key_auth]
async def fake_get_allowed_mcp_servers(*args, **kwargs):
return ["oauth-server-id"]
oauth_headers = {"Authorization": "Bearer user-oauth-token"}
async def fake_get_user_oauth_extra_headers(server, user_api_key_dict, prefetched_creds=None):
return oauth_headers
captured = {}
async def fake_get_tools(server, server_auth_header, raw_headers=None, user_api_key_auth=None, extra_headers=None):
captured["server"] = server
captured["auth_header"] = server_auth_header
return ["oauth-tool"]
monkeypatch.setattr(rest_endpoints, "build_effective_auth_contexts", fake_contexts, raising=False)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager, "get_allowed_mcp_servers",
fake_get_allowed_mcp_servers, raising=False,
)
monkeypatch.setattr(
rest_endpoints.global_mcp_server_manager, "get_mcp_server_by_id",
lambda sid: stub_server if sid == "oauth-server-id" else None,
raising=False,
)
monkeypatch.setattr(
rest_endpoints, "_get_user_oauth_extra_headers",
fake_get_user_oauth_extra_headers, raising=False,
)
monkeypatch.setattr(rest_endpoints, "_get_tools_for_single_server", fake_get_tools, raising=False)
request = _build_request(path="/mcp-rest/tools/list", method="GET")
result = await rest_endpoints.list_tool_rest_api(
request,
server_id="oauth-server-id",
user_api_key_dict=UserAPIKeyAuth(user_id="user-123"),
)
assert result["tools"] == ["oauth-tool"]
assert result["error"] is None
class TestCallToolRestAPI:
pytestmark = pytest.mark.asyncio

View File

@ -319,7 +319,7 @@ const MCPAppsPanel: React.FC<Props> = ({ accessToken, selectedServers, onChange
// Ignore — credential may already be gone; update UI regardless.
}
setOauthConnected((prev) => { const n = new Set(prev); n.delete(detailServer.server_id); return n; });
onChange(selectedServers.filter((s) => s !== name));
onChangeRef.current(selectedServersRef.current.filter((s) => s !== name));
}}
style={{ borderRadius: 8, fontWeight: 600, height: 38, minWidth: 110 }}
>
@ -331,7 +331,6 @@ const MCPAppsPanel: React.FC<Props> = ({ accessToken, selectedServers, onChange
accessToken={accessToken}
onConnect={(id) => {
setOauthConnected((prev) => new Set(prev).add(id));
handleToggle(name, true, detailServer.server_id);
}}
variant="button"
/>
@ -559,7 +558,6 @@ const MCPAppsPanel: React.FC<Props> = ({ accessToken, selectedServers, onChange
accessToken={accessToken}
onConnect={(id) => {
setOauthConnected((prev) => new Set(prev).add(id));
handleToggle(nameOf(server), true, server.server_id);
}}
variant="badge"
/>

View File

@ -9,7 +9,8 @@
import React, { useCallback, useEffect, useState } from "react";
import { Spin, message } from "antd";
import { CheckCircleOutlined, DeleteOutlined, LinkOutlined } from "@ant-design/icons";
import { DeleteOutlined, LinkOutlined } from "@ant-design/icons";
import { Badge, Table, TableBody, TableCell, TableHead, TableHeaderCell, TableRow } from "@tremor/react";
import {
deleteMCPOAuthUserCredential,
listMCPUserCredentials,
@ -86,98 +87,82 @@ const MCPCredentialsTab: React.FC<Props> = ({ accessToken }) => {
c.alias || c.server_name || c.server_id;
return (
<div style={{ width: "100%" }}>
<div className="w-full">
{/* Header */}
<div style={{ marginBottom: 20 }}>
<h2 style={{ margin: "0 0 4px", fontSize: 18, fontWeight: 600, color: "#111827" }}>
App Credentials
</h2>
<p style={{ margin: 0, fontSize: 13, color: "#6b7280" }}>
<div className="mb-4">
<h2 className="text-base font-semibold text-gray-900 mb-0.5">App Credentials</h2>
<p className="text-sm text-gray-500 m-0">
Your stored OAuth connections used automatically in chat.
</p>
</div>
{loading ? (
<div style={{ display: "flex", justifyContent: "center", padding: "48px 0" }}>
<div className="flex justify-center py-12">
<Spin />
</div>
) : credentials.length === 0 ? (
<div style={{
textAlign: "center", color: "#9ca3af", fontSize: 13,
padding: "48px 12px", border: "1px dashed #e5e7eb", borderRadius: 10,
}}>
<LinkOutlined style={{ fontSize: 28, marginBottom: 12, display: "block", color: "#d1d5db" }} />
<div className="text-center text-gray-400 text-sm py-12 border border-dashed border-gray-200 rounded-lg">
<LinkOutlined className="text-2xl mb-3 block text-gray-300" />
No connections yet.
<br />
Go to <strong>Apps</strong> and click <strong>Connect</strong> to authorize an MCP server.
</div>
) : (
<div style={{ display: "flex", flexDirection: "column", gap: 8 }}>
{credentials.map((cred) => {
const name = displayName(cred);
const isRevoking = revoking.has(cred.server_id);
const exp = expiryLabel(cred.expires_at);
const connected = relativeTime(cred.connected_at);
const isExpired = exp === "Expired";
<div className="rounded-lg border border-gray-200 overflow-hidden">
<Table>
<TableHead>
<TableRow>
<TableHeaderCell className="text-xs font-medium text-gray-500 py-2 px-4">
App
</TableHeaderCell>
<TableHeaderCell className="text-xs font-medium text-gray-500 py-2 px-4">
Connected
</TableHeaderCell>
<TableHeaderCell className="text-xs font-medium text-gray-500 py-2 px-4">
Status
</TableHeaderCell>
<TableHeaderCell className="text-xs font-medium text-gray-500 py-2 px-4 text-right">
Actions
</TableHeaderCell>
</TableRow>
</TableHead>
<TableBody>
{credentials.map((cred) => {
const name = displayName(cred);
const isRevoking = revoking.has(cred.server_id);
const exp = expiryLabel(cred.expires_at);
const connected = relativeTime(cred.connected_at);
const isExpired = exp === "Expired";
return (
<div
key={cred.server_id}
style={{
display: "flex", alignItems: "center", gap: 12,
padding: "14px 16px", border: "1px solid #e5e7eb",
borderRadius: 10, background: "#fff",
}}
>
{/* Status dot */}
<div style={{ flexShrink: 0 }}>
<CheckCircleOutlined style={{ fontSize: 20, color: isExpired ? "#d1d5db" : "#52c41a" }} />
</div>
{/* Info */}
<div style={{ flex: 1, minWidth: 0 }}>
<div style={{ fontSize: 14, fontWeight: 600, color: "#111827", overflow: "hidden", textOverflow: "ellipsis", whiteSpace: "nowrap" }}>
{name}
</div>
<div style={{ display: "flex", alignItems: "center", gap: 10, marginTop: 2, flexWrap: "wrap" }}>
{connected && (
<span style={{ fontSize: 12, color: "#9ca3af" }}>
Connected {connected}
</span>
)}
<span style={{
fontSize: 11, fontWeight: 600,
color: isExpired ? "#ef4444" : "#16a34a",
background: isExpired ? "#fef2f2" : "#f0fdf4",
borderRadius: 4, padding: "1px 6px",
}}>
{exp}
</span>
</div>
</div>
{/* Revoke */}
<button
onClick={() => handleRevoke(cred.server_id)}
disabled={isRevoking}
title="Revoke connection"
style={{
background: "none", border: "1px solid #e5e7eb",
borderRadius: 6, padding: "4px 8px",
cursor: isRevoking ? "not-allowed" : "pointer",
color: "#9ca3af", display: "flex", alignItems: "center",
opacity: isRevoking ? 0.5 : 1, flexShrink: 0,
}}
>
{isRevoking ? (
<Spin size="small" />
) : (
<DeleteOutlined style={{ fontSize: 14 }} />
)}
</button>
</div>
);
})}
return (
<TableRow key={cred.server_id} className="h-10 hover:bg-gray-50">
<TableCell className="py-2 px-4">
<span className="text-sm font-medium text-gray-900">{name}</span>
</TableCell>
<TableCell className="py-2 px-4">
<span className="text-sm text-gray-500">{connected || "—"}</span>
</TableCell>
<TableCell className="py-2 px-4">
<Badge color={isExpired ? "red" : "green"} size="xs">
{exp}
</Badge>
</TableCell>
<TableCell className="py-2 px-4 text-right">
<button
onClick={() => handleRevoke(cred.server_id)}
disabled={isRevoking}
title="Revoke connection"
className={`inline-flex items-center justify-center rounded-md border border-gray-200 px-2 py-1 text-gray-400 hover:text-red-500 hover:border-red-200 transition-colors ${isRevoking ? "opacity-50 cursor-not-allowed" : "cursor-pointer"}`}
style={{ background: "none" }}
>
{isRevoking ? <Spin size="small" /> : <DeleteOutlined className="text-sm" />}
</button>
</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
</div>
)}
</div>