From 22186f457af26f83143d277e6c4305a12ba311f4 Mon Sep 17 00:00:00 2001 From: tin-berri Date: Fri, 5 Jun 2026 22:29:56 -0700 Subject: [PATCH] fix(ui): persist Tools-tab MCP OAuth token to DB (#29809) --- litellm/proxy/_experimental/mcp_server/db.py | 52 +++++- .../mcp_server/rest_endpoints.py | 15 +- .../proxy/_experimental/mcp_server/server.py | 52 ++---- .../mcp_server/test_db_credentials.py | 174 ++++++++++++++++++ .../mcp_tools/create_mcp_server.test.tsx | 73 +++++++- .../mcp_tools/create_mcp_server.tsx | 50 +++-- 6 files changed, 354 insertions(+), 62 deletions(-) diff --git a/litellm/proxy/_experimental/mcp_server/db.py b/litellm/proxy/_experimental/mcp_server/db.py index 4d41afcb6f..0ba0181200 100644 --- a/litellm/proxy/_experimental/mcp_server/db.py +++ b/litellm/proxy/_experimental/mcp_server/db.py @@ -7,6 +7,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast from litellm._logging import verbose_proxy_logger from litellm._uuid import uuid +from litellm.constants import MCP_PER_USER_TOKEN_EXPIRY_BUFFER_SECONDS from litellm.proxy._types import ( LiteLLM_MCPServerTable, LiteLLM_ObjectPermissionTable, @@ -996,11 +997,14 @@ async def store_user_oauth_credential( ) -def is_oauth_credential_expired(cred: Dict[str, Any]) -> bool: +def is_oauth_credential_expired(cred: Dict[str, Any], buffer_seconds: int = 0) -> bool: """Return True if the OAuth2 credential's access_token has expired. Checks the ``expires_at`` ISO-format string stored in the credential payload. Returns False when ``expires_at`` is absent or unparseable (treat as non-expired). + With ``buffer_seconds`` > 0, a token that is still valid but expires within the + buffer is also treated as expired, so callers can refresh proactively instead of + handing back a token that may lapse mid-request. """ expires_at = cred.get("expires_at") if not expires_at: @@ -1009,7 +1013,7 @@ def is_oauth_credential_expired(cred: Dict[str, Any]) -> bool: exp_dt = datetime.fromisoformat(expires_at) if exp_dt.tzinfo is None: exp_dt = exp_dt.replace(tzinfo=timezone.utc) - return datetime.now(timezone.utc) > exp_dt + return datetime.now(timezone.utc) + timedelta(seconds=buffer_seconds) > exp_dt except (ValueError, TypeError): return False @@ -1157,6 +1161,50 @@ async def refresh_user_oauth_token( return await get_user_oauth_credential(prisma_client, user_id, server_id) +async def resolve_valid_user_oauth_token( + user_id: str, + server: Any, + cred: Optional[Dict[str, Any]], + prisma_client: Optional[PrismaClient] = None, +) -> Optional[Dict[str, Any]]: + """Return an OAuth2 credential whose access_token is good for the next request. + + Returns the credential unchanged while its token is valid for at least + ``MCP_PER_USER_TOKEN_EXPIRY_BUFFER_SECONDS``. Only when the token is expired (or + expiring within that buffer) and a refresh_token is stored does it mint a new one + via ``refresh_user_oauth_token``. Returns None when there is no usable token + (missing token, expired with no refresh_token, or a failed refresh). + + The refresh_token is only ever sent to the server's token_url inside + ``refresh_user_oauth_token``; it is never exposed to the caller beyond the cred + dict it already holds. ``prisma_client`` is fetched lazily and only when a refresh + actually happens, so the valid-token path never requires a DB handle. + """ + if not cred or not cred.get("access_token"): + return None + if not is_oauth_credential_expired( + cred, buffer_seconds=MCP_PER_USER_TOKEN_EXPIRY_BUFFER_SECONDS + ): + return cred + if not cred.get("refresh_token"): + return None + if prisma_client is None: + from litellm.proxy.utils import get_prisma_client_or_throw + + prisma_client = get_prisma_client_or_throw( + "Database not connected. Cannot refresh OAuth token." + ) + refreshed = await refresh_user_oauth_token( + prisma_client=prisma_client, + user_id=user_id, + server=server, + cred=cred, + ) + if not refreshed or not refreshed.get("access_token"): + return None + return refreshed + + async def approve_mcp_server( prisma_client: PrismaClient, server_id: str, diff --git a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py index 78cecfed0b..725f7a335b 100644 --- a/litellm/proxy/_experimental/mcp_server/rest_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/rest_endpoints.py @@ -146,9 +146,10 @@ if MCP_AVAILABLE: try: from litellm.proxy._experimental.mcp_server.db import ( get_user_oauth_credential, - is_oauth_credential_expired, + resolve_valid_user_oauth_token, ) + prisma_client = None if prefetched_creds is not None: cred = prefetched_creds.get(server_id) else: @@ -160,13 +161,13 @@ if MCP_AVAILABLE: cred = await get_user_oauth_credential( prisma_client, user_id, server_id ) + cred = await resolve_valid_user_oauth_token( + user_id=user_id, + server=server, + cred=cred, + prisma_client=prisma_client, + ) if cred and cred.get("access_token"): - if is_oauth_credential_expired(cred): - verbose_logger.debug( - f"_get_user_oauth_extra_headers: token expired for " - f"user={user_id} server={server_id}" - ) - return None return {"Authorization": f"Bearer {cred['access_token']}"} except Exception as e: verbose_logger.warning( diff --git a/litellm/proxy/_experimental/mcp_server/server.py b/litellm/proxy/_experimental/mcp_server/server.py index 51e77de43c..0477a5d324 100644 --- a/litellm/proxy/_experimental/mcp_server/server.py +++ b/litellm/proxy/_experimental/mcp_server/server.py @@ -1322,8 +1322,7 @@ if MCP_AVAILABLE: try: from litellm.proxy._experimental.mcp_server.db import ( # noqa: PLC0415 get_user_oauth_credential, - is_oauth_credential_expired, - refresh_user_oauth_token, + resolve_valid_user_oauth_token, ) from litellm.proxy._experimental.mcp_server.oauth2_token_cache import ( # noqa: PLC0415 _compute_per_user_token_ttl, @@ -1343,6 +1342,7 @@ if MCP_AVAILABLE: return {"Authorization": f"Bearer {cached_token}"} # ── Slow path: DB lookup ────────────────────────────────────────── + prisma_client = None if prefetched_creds is not None: cred = prefetched_creds.get(server_id) else: @@ -1360,43 +1360,17 @@ if MCP_AVAILABLE: if not cred or not cred.get("access_token"): return None - if is_oauth_credential_expired(cred): - verbose_logger.debug( - "_get_user_oauth_extra_headers_from_db: token expired for user=%s server=%s — attempting refresh", - user_id, - server_id, - ) - # Attempt token refresh; requires a DB client (not available from prefetch) - if cred.get("refresh_token"): - try: - from litellm.proxy.utils import ( # noqa: PLC0415 - get_prisma_client_or_throw, - ) - - prisma_client = get_prisma_client_or_throw( - "Database not connected. Cannot refresh OAuth token." - ) - cred = await refresh_user_oauth_token( - prisma_client=prisma_client, - user_id=user_id, - server=server, - cred=cred, - ) - except Exception as refresh_exc: - verbose_logger.warning( - "_get_user_oauth_extra_headers_from_db: refresh failed for user=%s server=%s: %s", - user_id, - server_id, - refresh_exc, - ) - cred = None - - if not cred or not cred.get("access_token"): - # Clear stale Redis/cache entry so we don't serve it again. - # Do this for both the individual and prefetch paths so the - # next request doesn't get a stale cache hit. - await mcp_per_user_token_cache.delete(user_id, server_id) - return None + cred = await resolve_valid_user_oauth_token( + user_id=user_id, + server=server, + cred=cred, + prisma_client=prisma_client, + ) + if cred is None: + # Refresh failed or token expired with no usable refresh_token — + # clear the stale Redis entry so the next request doesn't reuse it. + await mcp_per_user_token_cache.delete(user_id, server_id) + return None access_token: str = cred["access_token"] diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_db_credentials.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_db_credentials.py index 8aa9f107c8..c230cfd6cd 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_db_credentials.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_db_credentials.py @@ -10,6 +10,7 @@ keeps a plain-base64 fallback on read so existing rows continue to work. import base64 import json +from datetime import datetime, timedelta, timezone from unittest.mock import AsyncMock, MagicMock import pytest @@ -18,7 +19,9 @@ from litellm.proxy._experimental.mcp_server.db import ( _decode_user_credential, get_user_credential, get_user_oauth_credential, + is_oauth_credential_expired, list_user_oauth_credentials, + resolve_valid_user_oauth_token, rotate_mcp_user_credentials_master_key, rotate_mcp_user_env_vars_master_key, store_user_credential, @@ -405,6 +408,177 @@ async def test_rotate_skips_undecodable_rows(): assert where["user_id_server_id"]["server_id"] == "srv-ok" +# ── Expiry buffer + refresh-on-expiry (OBO list-refresh regression) ─────────── + + +def _oauth_cred(access_token="at-live", refresh_token=None, expires_in_seconds=None): + cred = {"type": "oauth2", "access_token": access_token} + if refresh_token is not None: + cred["refresh_token"] = refresh_token + if expires_in_seconds is not None: + cred["expires_at"] = ( + datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds) + ).isoformat() + return cred + + +def test_expiry_no_buffer_treats_soon_to_expire_as_valid(): + # Without a buffer, a token with 30s of life left is still valid. + cred = _oauth_cred(expires_in_seconds=30) + assert is_oauth_credential_expired(cred) is False + assert is_oauth_credential_expired(cred, buffer_seconds=0) is False + + +def test_expiry_buffer_treats_soon_to_expire_as_expired(): + # With a 60s buffer, the same 30s-of-life token must be treated as expired + # so callers refresh before it lapses mid-request. + cred = _oauth_cred(expires_in_seconds=30) + assert is_oauth_credential_expired(cred, buffer_seconds=60) is True + # A token comfortably beyond the buffer stays valid. + assert ( + is_oauth_credential_expired( + _oauth_cred(expires_in_seconds=600), buffer_seconds=60 + ) + is False + ) + + +def test_expiry_past_is_expired_regardless_of_buffer(): + cred = _oauth_cred(expires_in_seconds=-10) + assert is_oauth_credential_expired(cred) is True + assert is_oauth_credential_expired(cred, buffer_seconds=60) is True + + +def test_expiry_missing_expires_at_is_never_expired(): + assert is_oauth_credential_expired(_oauth_cred()) is False + assert is_oauth_credential_expired(_oauth_cred(), buffer_seconds=60) is False + + +@pytest.mark.asyncio +async def test_resolve_returns_valid_token_without_refreshing(monkeypatch): + # A token good for 10 minutes must be returned as-is, with no refresh call. + import litellm.proxy._experimental.mcp_server.db as db_mod + + refresh = AsyncMock() + monkeypatch.setattr(db_mod, "refresh_user_oauth_token", refresh) + + cred = _oauth_cred( + access_token="at-live", refresh_token="rt-1", expires_in_seconds=600 + ) + result = await resolve_valid_user_oauth_token( + user_id="alice", server=MagicMock(), cred=cred, prisma_client=MagicMock() + ) + + assert result is cred + assert result["access_token"] == "at-live" + refresh.assert_not_called() + + +@pytest.mark.asyncio +async def test_resolve_refreshes_expired_token_with_refresh_token(monkeypatch): + # The core regression: an expired OBO cred with a refresh_token must mint a + # new token rather than returning None (which left the UI tool list empty). + import litellm.proxy._experimental.mcp_server.db as db_mod + + refreshed = _oauth_cred( + access_token="at-fresh", refresh_token="rt-2", expires_in_seconds=3600 + ) + refresh = AsyncMock(return_value=refreshed) + monkeypatch.setattr(db_mod, "refresh_user_oauth_token", refresh) + + expired = _oauth_cred( + access_token="at-dead", refresh_token="rt-1", expires_in_seconds=-5 + ) + result = await resolve_valid_user_oauth_token( + user_id="alice", server=MagicMock(), cred=expired, prisma_client=MagicMock() + ) + + refresh.assert_awaited_once() + assert result["access_token"] == "at-fresh" + + +@pytest.mark.asyncio +async def test_resolve_refreshes_token_expiring_within_buffer(monkeypatch): + # A token still technically valid (30s left) but inside the 60s buffer must + # be proactively refreshed, not handed back. + import litellm.proxy._experimental.mcp_server.db as db_mod + + refreshed = _oauth_cred(access_token="at-fresh", expires_in_seconds=3600) + refresh = AsyncMock(return_value=refreshed) + monkeypatch.setattr(db_mod, "refresh_user_oauth_token", refresh) + + soon = _oauth_cred( + access_token="at-soon", refresh_token="rt-1", expires_in_seconds=30 + ) + result = await resolve_valid_user_oauth_token( + user_id="alice", server=MagicMock(), cred=soon, prisma_client=MagicMock() + ) + + refresh.assert_awaited_once() + assert result["access_token"] == "at-fresh" + + +@pytest.mark.asyncio +async def test_resolve_returns_none_when_expired_without_refresh_token(monkeypatch): + # No refresh_token means nothing to refresh with — return None, never call refresh. + import litellm.proxy._experimental.mcp_server.db as db_mod + + refresh = AsyncMock() + monkeypatch.setattr(db_mod, "refresh_user_oauth_token", refresh) + + expired = _oauth_cred(access_token="at-dead", expires_in_seconds=-5) + result = await resolve_valid_user_oauth_token( + user_id="alice", server=MagicMock(), cred=expired, prisma_client=MagicMock() + ) + + assert result is None + refresh.assert_not_called() + + +@pytest.mark.asyncio +async def test_resolve_returns_none_when_refresh_fails(monkeypatch): + # A failed refresh (provider returns nothing usable) must surface as None. + import litellm.proxy._experimental.mcp_server.db as db_mod + + refresh = AsyncMock(return_value=None) + monkeypatch.setattr(db_mod, "refresh_user_oauth_token", refresh) + + expired = _oauth_cred( + access_token="at-dead", refresh_token="rt-1", expires_in_seconds=-5 + ) + result = await resolve_valid_user_oauth_token( + user_id="alice", server=MagicMock(), cred=expired, prisma_client=MagicMock() + ) + + refresh.assert_awaited_once() + assert result is None + + +@pytest.mark.asyncio +async def test_resolve_returns_none_for_missing_credential(monkeypatch): + import litellm.proxy._experimental.mcp_server.db as db_mod + + refresh = AsyncMock() + monkeypatch.setattr(db_mod, "refresh_user_oauth_token", refresh) + + assert ( + await resolve_valid_user_oauth_token( + user_id="alice", server=MagicMock(), cred=None, prisma_client=MagicMock() + ) + is None + ) + assert ( + await resolve_valid_user_oauth_token( + user_id="alice", + server=MagicMock(), + cred={"type": "oauth2"}, + prisma_client=MagicMock(), + ) + is None + ) + refresh.assert_not_called() + + # ── per-user env-var rotation ───────────────────────────────────────────────── diff --git a/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.test.tsx b/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.test.tsx index d635d7bb6b..042f02251b 100644 --- a/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.test.tsx +++ b/ui/litellm-dashboard/src/components/mcp_tools/create_mcp_server.test.tsx @@ -2,19 +2,29 @@ import { act, fireEvent, render, screen, waitFor } from "@testing-library/react" import userEvent from "@testing-library/user-event"; import { beforeEach, describe, expect, it, vi } from "vitest"; import * as networking from "../networking"; +import { setToken } from "@/utils/mcpTokenStore"; import CreateMCPServer from "./create_mcp_server"; vi.mock("../networking", () => ({ createMCPServer: vi.fn(), + registerMCPServer: vi.fn(), + storeMCPOAuthUserCredential: vi.fn().mockResolvedValue({}), testMCPToolsListRequest: vi.fn().mockResolvedValue({ tools: [], error: null }), })); +vi.mock("@/utils/mcpTokenStore", () => ({ + setToken: vi.fn(), +})); + +// Mutable holder so individual tests can simulate "Authorize & Fetch" having +// produced a token before submit. +const oauthHook = vi.hoisted(() => ({ tokenResponse: null as Record | null })); vi.mock("@/hooks/useMcpOAuthFlow", () => ({ useMcpOAuthFlow: () => ({ startOAuthFlow: vi.fn(), status: "idle", error: null, - tokenResponse: null, + tokenResponse: oauthHook.tokenResponse, }), })); @@ -27,7 +37,13 @@ vi.mock("./MCPPermissionManagement", () => ({ })); vi.mock("./mcp_tool_configuration", () => ({ - default: ({ onAllowedToolsChange, onToolAllowlistInteraction }: any) => ( + default: ({ + onAllowedToolsChange, + onToolAllowlistInteraction, + }: { + onAllowedToolsChange?: (tools: string[]) => void; + onToolAllowlistInteraction?: () => void; + }) => (