fix(ui): persist Tools-tab MCP OAuth token to DB (#29809)

This commit is contained in:
tin-berri 2026-06-05 22:29:56 -07:00 committed by GitHub
parent 6955e6f2c2
commit 22186f457a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 354 additions and 62 deletions

View File

@ -7,6 +7,7 @@ from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
from litellm._logging import verbose_proxy_logger
from litellm._uuid import uuid
from litellm.constants import MCP_PER_USER_TOKEN_EXPIRY_BUFFER_SECONDS
from litellm.proxy._types import (
LiteLLM_MCPServerTable,
LiteLLM_ObjectPermissionTable,
@ -996,11 +997,14 @@ async def store_user_oauth_credential(
)
def is_oauth_credential_expired(cred: Dict[str, Any]) -> bool:
def is_oauth_credential_expired(cred: Dict[str, Any], buffer_seconds: int = 0) -> bool:
"""Return True if the OAuth2 credential's access_token has expired.
Checks the ``expires_at`` ISO-format string stored in the credential payload.
Returns False when ``expires_at`` is absent or unparseable (treat as non-expired).
With ``buffer_seconds`` > 0, a token that is still valid but expires within the
buffer is also treated as expired, so callers can refresh proactively instead of
handing back a token that may lapse mid-request.
"""
expires_at = cred.get("expires_at")
if not expires_at:
@ -1009,7 +1013,7 @@ def is_oauth_credential_expired(cred: Dict[str, Any]) -> bool:
exp_dt = datetime.fromisoformat(expires_at)
if exp_dt.tzinfo is None:
exp_dt = exp_dt.replace(tzinfo=timezone.utc)
return datetime.now(timezone.utc) > exp_dt
return datetime.now(timezone.utc) + timedelta(seconds=buffer_seconds) > exp_dt
except (ValueError, TypeError):
return False
@ -1157,6 +1161,50 @@ async def refresh_user_oauth_token(
return await get_user_oauth_credential(prisma_client, user_id, server_id)
async def resolve_valid_user_oauth_token(
    user_id: str,
    server: Any,
    cred: Optional[Dict[str, Any]],
    prisma_client: Optional[PrismaClient] = None,
) -> Optional[Dict[str, Any]]:
    """Return an OAuth2 credential whose access_token is good for the next request.

    Returns the credential unchanged while its token is valid for at least
    ``MCP_PER_USER_TOKEN_EXPIRY_BUFFER_SECONDS``. Only when the token is expired (or
    expiring within that buffer) and a refresh_token is stored does it mint a new one
    via ``refresh_user_oauth_token``.

    Args:
        user_id: Owner of the credential; forwarded to the refresh helper.
        server: MCP server object; forwarded to the refresh helper.
        cred: Stored credential payload, or None when nothing is stored.
        prisma_client: Optional DB handle. Fetched lazily, and only when a refresh
            actually happens, so the valid-token path never requires a DB handle.

    Returns:
        A credential dict with a usable ``access_token``, or None when there is no
        usable token: missing token, expired with no refresh_token, or a failed
        refresh. A refresh that *raises* (provider error, database not connected)
        is logged and reported as None as well, so callers can run their
        "no token" cleanup instead of having it skipped by an exception.
    """
    if not cred or not cred.get("access_token"):
        return None

    if not is_oauth_credential_expired(
        cred, buffer_seconds=MCP_PER_USER_TOKEN_EXPIRY_BUFFER_SECONDS
    ):
        return cred

    # Expired (or about to expire): without a refresh_token there is nothing to do.
    if not cred.get("refresh_token"):
        return None

    try:
        if prisma_client is None:
            from litellm.proxy.utils import get_prisma_client_or_throw

            prisma_client = get_prisma_client_or_throw(
                "Database not connected. Cannot refresh OAuth token."
            )
        refreshed = await refresh_user_oauth_token(
            prisma_client=prisma_client,
            user_id=user_id,
            server=server,
            cred=cred,
        )
    except Exception as refresh_exc:
        # Honour the "failed refresh -> None" contract instead of propagating.
        # NOTE(review): assumes ``server`` exposes ``server_id``; falls back to None.
        verbose_proxy_logger.warning(
            "resolve_valid_user_oauth_token: refresh failed for user=%s server=%s: %s",
            user_id,
            getattr(server, "server_id", None),
            refresh_exc,
        )
        return None

    if not refreshed or not refreshed.get("access_token"):
        return None
    return refreshed
async def approve_mcp_server(
prisma_client: PrismaClient,
server_id: str,

View File

@ -146,9 +146,10 @@ if MCP_AVAILABLE:
try:
from litellm.proxy._experimental.mcp_server.db import (
get_user_oauth_credential,
is_oauth_credential_expired,
resolve_valid_user_oauth_token,
)
prisma_client = None
if prefetched_creds is not None:
cred = prefetched_creds.get(server_id)
else:
@ -160,13 +161,13 @@ if MCP_AVAILABLE:
cred = await get_user_oauth_credential(
prisma_client, user_id, server_id
)
cred = await resolve_valid_user_oauth_token(
user_id=user_id,
server=server,
cred=cred,
prisma_client=prisma_client,
)
if cred and cred.get("access_token"):
if is_oauth_credential_expired(cred):
verbose_logger.debug(
f"_get_user_oauth_extra_headers: token expired for "
f"user={user_id} server={server_id}"
)
return None
return {"Authorization": f"Bearer {cred['access_token']}"}
except Exception as e:
verbose_logger.warning(

View File

@ -1322,8 +1322,7 @@ if MCP_AVAILABLE:
try:
from litellm.proxy._experimental.mcp_server.db import ( # noqa: PLC0415
get_user_oauth_credential,
is_oauth_credential_expired,
refresh_user_oauth_token,
resolve_valid_user_oauth_token,
)
from litellm.proxy._experimental.mcp_server.oauth2_token_cache import ( # noqa: PLC0415
_compute_per_user_token_ttl,
@ -1343,6 +1342,7 @@ if MCP_AVAILABLE:
return {"Authorization": f"Bearer {cached_token}"}
# ── Slow path: DB lookup ──────────────────────────────────────────
prisma_client = None
if prefetched_creds is not None:
cred = prefetched_creds.get(server_id)
else:
@ -1360,43 +1360,17 @@ if MCP_AVAILABLE:
if not cred or not cred.get("access_token"):
return None
if is_oauth_credential_expired(cred):
verbose_logger.debug(
"_get_user_oauth_extra_headers_from_db: token expired for user=%s server=%s — attempting refresh",
user_id,
server_id,
)
# Attempt token refresh; requires a DB client (not available from prefetch)
if cred.get("refresh_token"):
try:
from litellm.proxy.utils import ( # noqa: PLC0415
get_prisma_client_or_throw,
)
prisma_client = get_prisma_client_or_throw(
"Database not connected. Cannot refresh OAuth token."
)
cred = await refresh_user_oauth_token(
prisma_client=prisma_client,
user_id=user_id,
server=server,
cred=cred,
)
except Exception as refresh_exc:
verbose_logger.warning(
"_get_user_oauth_extra_headers_from_db: refresh failed for user=%s server=%s: %s",
user_id,
server_id,
refresh_exc,
)
cred = None
if not cred or not cred.get("access_token"):
# Clear stale Redis/cache entry so we don't serve it again.
# Do this for both the individual and prefetch paths so the
# next request doesn't get a stale cache hit.
await mcp_per_user_token_cache.delete(user_id, server_id)
return None
cred = await resolve_valid_user_oauth_token(
user_id=user_id,
server=server,
cred=cred,
prisma_client=prisma_client,
)
if cred is None:
# Refresh failed or token expired with no usable refresh_token —
# clear the stale Redis entry so the next request doesn't reuse it.
await mcp_per_user_token_cache.delete(user_id, server_id)
return None
access_token: str = cred["access_token"]

View File

@ -10,6 +10,7 @@ keeps a plain-base64 fallback on read so existing rows continue to work.
import base64
import json
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock
import pytest
@ -18,7 +19,9 @@ from litellm.proxy._experimental.mcp_server.db import (
_decode_user_credential,
get_user_credential,
get_user_oauth_credential,
is_oauth_credential_expired,
list_user_oauth_credentials,
resolve_valid_user_oauth_token,
rotate_mcp_user_credentials_master_key,
rotate_mcp_user_env_vars_master_key,
store_user_credential,
@ -405,6 +408,177 @@ async def test_rotate_skips_undecodable_rows():
assert where["user_id_server_id"]["server_id"] == "srv-ok"
# ── Expiry buffer + refresh-on-expiry (OBO list-refresh regression) ───────────
def _oauth_cred(access_token="at-live", refresh_token=None, expires_in_seconds=None):
cred = {"type": "oauth2", "access_token": access_token}
if refresh_token is not None:
cred["refresh_token"] = refresh_token
if expires_in_seconds is not None:
cred["expires_at"] = (
datetime.now(timezone.utc) + timedelta(seconds=expires_in_seconds)
).isoformat()
return cred
def test_expiry_no_buffer_treats_soon_to_expire_as_valid():
    """A token with 30s of life left is valid when no buffer is applied."""
    almost_expired = _oauth_cred(expires_in_seconds=30)
    # Cover both the default argument and an explicit zero buffer.
    for kwargs in ({}, {"buffer_seconds": 0}):
        assert is_oauth_credential_expired(almost_expired, **kwargs) is False
def test_expiry_buffer_treats_soon_to_expire_as_expired():
    """A 60s buffer flags a 30s-of-life token as expired but not a 600s one."""
    # Inside the buffer: callers should refresh before it lapses mid-request.
    inside_buffer = _oauth_cred(expires_in_seconds=30)
    assert is_oauth_credential_expired(inside_buffer, buffer_seconds=60) is True

    # Comfortably beyond the buffer: still valid.
    outside_buffer = _oauth_cred(expires_in_seconds=600)
    assert is_oauth_credential_expired(outside_buffer, buffer_seconds=60) is False
def test_expiry_past_is_expired_regardless_of_buffer():
    """A token whose expiry is already in the past is expired with or without a buffer."""
    lapsed = _oauth_cred(expires_in_seconds=-10)
    for kwargs in ({}, {"buffer_seconds": 60}):
        assert is_oauth_credential_expired(lapsed, **kwargs) is True
def test_expiry_missing_expires_at_is_never_expired():
    """A credential without ``expires_at`` is reported as not expired, buffer or not."""
    for kwargs in ({}, {"buffer_seconds": 60}):
        no_expiry = _oauth_cred()
        assert is_oauth_credential_expired(no_expiry, **kwargs) is False
@pytest.mark.asyncio
async def test_resolve_returns_valid_token_without_refreshing(monkeypatch):
    """A token good for 10 minutes is returned as-is, with no refresh call."""
    import litellm.proxy._experimental.mcp_server.db as db_mod

    refresh_spy = AsyncMock()
    monkeypatch.setattr(db_mod, "refresh_user_oauth_token", refresh_spy)

    live_cred = _oauth_cred(
        access_token="at-live", refresh_token="rt-1", expires_in_seconds=600
    )
    resolved = await resolve_valid_user_oauth_token(
        user_id="alice",
        server=MagicMock(),
        cred=live_cred,
        prisma_client=MagicMock(),
    )

    # Same object back, untouched, and the refresh helper was never used.
    assert resolved is live_cred
    assert resolved["access_token"] == "at-live"
    refresh_spy.assert_not_called()
@pytest.mark.asyncio
async def test_resolve_refreshes_expired_token_with_refresh_token(monkeypatch):
    """Core regression: an expired cred with a refresh_token mints a new token.

    Returning None here is what previously left the UI tool list empty.
    """
    import litellm.proxy._experimental.mcp_server.db as db_mod

    fresh_cred = _oauth_cred(
        access_token="at-fresh", refresh_token="rt-2", expires_in_seconds=3600
    )
    refresh_spy = AsyncMock(return_value=fresh_cred)
    monkeypatch.setattr(db_mod, "refresh_user_oauth_token", refresh_spy)

    stale_cred = _oauth_cred(
        access_token="at-dead", refresh_token="rt-1", expires_in_seconds=-5
    )
    resolved = await resolve_valid_user_oauth_token(
        user_id="alice",
        server=MagicMock(),
        cred=stale_cred,
        prisma_client=MagicMock(),
    )

    refresh_spy.assert_awaited_once()
    assert resolved["access_token"] == "at-fresh"
@pytest.mark.asyncio
async def test_resolve_refreshes_token_expiring_within_buffer(monkeypatch):
    """A token with 30s left (inside the 60s buffer) is refreshed proactively."""
    import litellm.proxy._experimental.mcp_server.db as db_mod

    fresh_cred = _oauth_cred(access_token="at-fresh", expires_in_seconds=3600)
    refresh_spy = AsyncMock(return_value=fresh_cred)
    monkeypatch.setattr(db_mod, "refresh_user_oauth_token", refresh_spy)

    expiring_cred = _oauth_cred(
        access_token="at-soon", refresh_token="rt-1", expires_in_seconds=30
    )
    resolved = await resolve_valid_user_oauth_token(
        user_id="alice",
        server=MagicMock(),
        cred=expiring_cred,
        prisma_client=MagicMock(),
    )

    # Still technically valid, but must not be handed back as-is.
    refresh_spy.assert_awaited_once()
    assert resolved["access_token"] == "at-fresh"
@pytest.mark.asyncio
async def test_resolve_returns_none_when_expired_without_refresh_token(monkeypatch):
    """With no refresh_token there is nothing to refresh with: None, and no refresh call."""
    import litellm.proxy._experimental.mcp_server.db as db_mod

    refresh_spy = AsyncMock()
    monkeypatch.setattr(db_mod, "refresh_user_oauth_token", refresh_spy)

    stale_cred = _oauth_cred(access_token="at-dead", expires_in_seconds=-5)
    resolved = await resolve_valid_user_oauth_token(
        user_id="alice",
        server=MagicMock(),
        cred=stale_cred,
        prisma_client=MagicMock(),
    )

    assert resolved is None
    refresh_spy.assert_not_called()
@pytest.mark.asyncio
async def test_resolve_returns_none_when_refresh_fails(monkeypatch):
    """A refresh that yields nothing usable surfaces as None."""
    import litellm.proxy._experimental.mcp_server.db as db_mod

    # The refresh helper is attempted but comes back empty-handed.
    refresh_spy = AsyncMock(return_value=None)
    monkeypatch.setattr(db_mod, "refresh_user_oauth_token", refresh_spy)

    stale_cred = _oauth_cred(
        access_token="at-dead", refresh_token="rt-1", expires_in_seconds=-5
    )
    resolved = await resolve_valid_user_oauth_token(
        user_id="alice",
        server=MagicMock(),
        cred=stale_cred,
        prisma_client=MagicMock(),
    )

    refresh_spy.assert_awaited_once()
    assert resolved is None
@pytest.mark.asyncio
async def test_resolve_returns_none_for_missing_credential(monkeypatch):
    """No credential, or one without an access_token, resolves to None without refreshing."""
    import litellm.proxy._experimental.mcp_server.db as db_mod

    refresh_spy = AsyncMock()
    monkeypatch.setattr(db_mod, "refresh_user_oauth_token", refresh_spy)

    # First: nothing stored at all. Second: a payload lacking access_token.
    for unusable_cred in (None, {"type": "oauth2"}):
        resolved = await resolve_valid_user_oauth_token(
            user_id="alice",
            server=MagicMock(),
            cred=unusable_cred,
            prisma_client=MagicMock(),
        )
        assert resolved is None

    refresh_spy.assert_not_called()
# ── per-user env-var rotation ─────────────────────────────────────────────────

View File

@ -2,19 +2,29 @@ import { act, fireEvent, render, screen, waitFor } from "@testing-library/react"
import userEvent from "@testing-library/user-event";
import { beforeEach, describe, expect, it, vi } from "vitest";
import * as networking from "../networking";
import { setToken } from "@/utils/mcpTokenStore";
import CreateMCPServer from "./create_mcp_server";
vi.mock("../networking", () => ({
createMCPServer: vi.fn(),
registerMCPServer: vi.fn(),
storeMCPOAuthUserCredential: vi.fn().mockResolvedValue({}),
testMCPToolsListRequest: vi.fn().mockResolvedValue({ tools: [], error: null }),
}));
vi.mock("@/utils/mcpTokenStore", () => ({
setToken: vi.fn(),
}));
// Mutable holder so individual tests can simulate "Authorize & Fetch" having
// produced a token before submit.
const oauthHook = vi.hoisted(() => ({ tokenResponse: null as Record<string, unknown> | null }));
vi.mock("@/hooks/useMcpOAuthFlow", () => ({
useMcpOAuthFlow: () => ({
startOAuthFlow: vi.fn(),
status: "idle",
error: null,
tokenResponse: null,
tokenResponse: oauthHook.tokenResponse,
}),
}));
@ -27,7 +37,13 @@ vi.mock("./MCPPermissionManagement", () => ({
}));
vi.mock("./mcp_tool_configuration", () => ({
default: ({ onAllowedToolsChange, onToolAllowlistInteraction }: any) => (
default: ({
onAllowedToolsChange,
onToolAllowlistInteraction,
}: {
onAllowedToolsChange?: (tools: string[]) => void;
onToolAllowlistInteraction?: () => void;
}) => (
<div data-testid="mcp-tool-config">
<button
type="button"
@ -104,6 +120,7 @@ async function selectAntOption(labelText: string, optionText: string) {
describe("CreateMCPServer", () => {
beforeEach(() => {
vi.clearAllMocks();
oauthHook.tokenResponse = null;
});
it("should render the modal with title when visible", () => {
@ -506,6 +523,58 @@ describe("CreateMCPServer", () => {
expect(payload.token_validation).toBeUndefined();
});
// Regression test: when a token was obtained before submit and the form is in
// OBO mode, the credential is written through storeMCPOAuthUserCredential
// (mocked above) and the sessionStorage-backed setToken mock is not used.
it("persists access + refresh token to the DB on submit for OBO mode", async () => {
// "Authorize & Fetch" produced a token before submit.
oauthHook.tokenResponse = {
access_token: "obo-access-token",
refresh_token: "obo-refresh-token",
expires_in: 3600,
token_type: "bearer",
scope: "channels:read chat:write",
};
// The created server's server_id is what the credential is expected to be stored under.
vi.mocked(networking.createMCPServer).mockResolvedValue({
server_id: "obo-server-1",
server_name: "OBO_Server",
alias: "OBO_Server",
url: "https://example.com/mcp",
transport: "http",
auth_type: "oauth2",
created_at: "2024-01-01T00:00:00Z",
created_by: "user-1",
updated_at: "2024-01-01T00:00:00Z",
updated_by: "user-1",
});
// Interactive OAuth + delegate_auth_to_upstream off (the default) => OBO mode.
await setupOAuthInteractive();
// Fill in the remaining fields (server name and URL) before submitting.
const nameInput = document.getElementById("server_name") as HTMLInputElement;
await act(async () => {
fireEvent.change(nameInput, { target: { value: "OBO_Server" } });
});
const urlInput = screen.getByPlaceholderText("https://your-mcp-server.com");
await act(async () => {
fireEvent.change(urlInput, { target: { value: "https://example.com/mcp" } });
});
const submitButton = screen.getByRole("button", { name: "Add MCP Server" });
await act(async () => {
fireEvent.click(submitButton);
});
// The store call happens asynchronously after server creation resolves.
await waitFor(() => {
expect(networking.storeMCPOAuthUserCredential).toHaveBeenCalledTimes(1);
});
// The space-delimited scope string is expected to arrive as a scopes array;
// token_type is not part of the expected payload.
// NOTE(review): "test-token" is presumably the accessToken supplied by the shared render helper — confirm.
expect(networking.storeMCPOAuthUserCredential).toHaveBeenCalledWith("test-token", "obo-server-1", {
access_token: "obo-access-token",
refresh_token: "obo-refresh-token",
expires_in: 3600,
scopes: ["channels:read", "chat:write"],
});
// OBO persists server-side; it must not fall back to the browser-only cache.
expect(setToken).not.toHaveBeenCalled();
});
it("does not submit and shows validation error for invalid JSON in token_validation_json", async () => {
await setupOAuthInteractive();

View File

@ -2,9 +2,18 @@ import React, { useState } from "react";
import { Modal, Tooltip, Form, Select, Input, Switch, Collapse } from "antd";
import { InfoCircleOutlined } from "@ant-design/icons";
import { Button, TextInput } from "@tremor/react";
import { createMCPServer, registerMCPServer } from "../networking";
import { createMCPServer, registerMCPServer, storeMCPOAuthUserCredential } from "../networking";
import { setToken } from "@/utils/mcpTokenStore";
import { AUTH_TYPE, DiscoverableMCPServer, OAUTH_FLOW, MCPServer, MCPServerCostInfo, TRANSPORT } from "./types";
import {
AUTH_TYPE,
DiscoverableMCPServer,
OAUTH_FLOW,
MCPServer,
MCPServerCostInfo,
TRANSPORT,
getMcpOAuthMode,
MCP_OAUTH2_FLOW_M2M,
} from "./types";
import OAuthFormFields from "./OAuthFormFields";
import MCPServerCostConfig from "./mcp_server_cost_config";
import MCPConnectionStatus from "./mcp_connection_status";
@ -424,19 +433,36 @@ const CreateMCPServer: React.FC<CreateMCPServerProps> = ({
? await createMCPServer(accessToken, payload)
: await registerMCPServer(accessToken, payload);
// Cache the OAuth token in sessionStorage so the Tools tab can use it
// immediately without re-authenticating. No backend DB write.
// Persist the token obtained via "Authorize & Fetch" once the server
// exists (so we have its server_id). OBO holds the per-user token in the
// backend, so write it to the DB (has_credentials=True). Passthrough
// forwards a browser-held token, so it stays in sessionStorage only.
if (oauthTokenResponse?.access_token && response?.server_id) {
setToken(
response.server_id,
{
const oauthMode = getMcpOAuthMode({
auth_type: restValues.auth_type,
oauth2_flow: values.oauth_flow_type === OAUTH_FLOW.M2M ? MCP_OAUTH2_FLOW_M2M : null,
delegate_auth_to_upstream: Boolean(delegateAuthToUpstreamRaw),
});
if (oauthMode === "obo") {
const scope = oauthTokenResponse.scope;
await storeMCPOAuthUserCredential(accessToken, response.server_id, {
access_token: oauthTokenResponse.access_token,
expires_in: oauthTokenResponse.expires_in,
refresh_token: oauthTokenResponse.refresh_token,
token_type: oauthTokenResponse.token_type,
},
userID,
);
expires_in: oauthTokenResponse.expires_in,
scopes: typeof scope === "string" && scope ? scope.split(" ") : undefined,
});
} else {
setToken(
response.server_id,
{
access_token: oauthTokenResponse.access_token,
expires_in: oauthTokenResponse.expires_in,
refresh_token: oauthTokenResponse.refresh_token,
token_type: oauthTokenResponse.token_type,
},
userID,
);
}
}
NotificationsManager.success(