diff --git a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py index ed374635fe..3beddd2c43 100644 --- a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py @@ -512,12 +512,13 @@ async def exchange_token_with_server( result = { "access_token": access_token, "token_type": token_response.get("token_type", "Bearer"), - "expires_in": token_response.get("expires_in", 3600), } - if "refresh_token" in token_response and token_response["refresh_token"]: + if token_response.get("expires_in") is not None: + result["expires_in"] = token_response["expires_in"] + if token_response.get("refresh_token"): result["refresh_token"] = token_response["refresh_token"] - if "scope" in token_response and token_response["scope"]: + if token_response.get("scope"): result["scope"] = token_response["scope"] # RFC 6749 ยง5.1: token responses must not be cached. diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py index da66d60aed..6fd935e336 100644 --- a/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_discoverable_endpoints.py @@ -1,5 +1,6 @@ """Tests for MCP OAuth discoverable endpoints""" +import json from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -2661,3 +2662,74 @@ async def test_token_endpoint_sets_no_store_cache_control(): assert response.headers["cache-control"] == "no-store" assert response.headers["pragma"] == "no-cache" + + +async def _exchange_with_upstream_token_response(upstream_body): + from fastapi import Request + + from litellm.proxy._experimental.mcp_server.discoverable_endpoints import ( + exchange_token_with_server, + ) + from litellm.proxy._types import MCPTransport + from litellm.types.mcp import MCPAuth + from litellm.types.mcp_server.mcp_server_manager import MCPServer + + server = MCPServer( + server_id="t", + name="t", + server_name="t", + alias="t", + transport=MCPTransport.http, + auth_type=MCPAuth.oauth2, + client_id="cid", + client_secret="cs", + authorization_url="https://provider.com/oauth/authorize", + token_url="https://provider.com/oauth/token", + ) + mock_request = MagicMock(spec=Request) + mock_request.base_url = "https://litellm.example.com/" + mock_request.headers = {} + + fake_http_response = MagicMock() + fake_http_response.json.return_value = upstream_body + fake_http_response.raise_for_status = MagicMock() + fake_http_client = MagicMock() + fake_http_client.post = AsyncMock(return_value=fake_http_response) + + with patch( + "litellm.proxy._experimental.mcp_server.discoverable_endpoints.get_async_httpx_client", + return_value=fake_http_client, + ): + response = await exchange_token_with_server( + request=mock_request, + mcp_server=server, + grant_type="authorization_code", + code="c", + redirect_uri="http://127.0.0.1:3000/cb", + client_id="cid", + client_secret=None, + code_verifier=None, + ) + return json.loads(response.body) + + +@pytest.mark.asyncio +async def test_token_exchange_omits_expires_in_when_upstream_omits_it(): + """A provider that issues a non-expiring token (e.g. Slack without token + rotation) returns no ``expires_in``. The exchange must mirror that and omit + ``expires_in`` rather than fabricate a 1-hour TTL, so the stored credential + is treated as non-expiring instead of dying after an hour.""" + body = await _exchange_with_upstream_token_response( + {"access_token": "tok", "token_type": "Bearer"} + ) + assert "expires_in" not in body + + +@pytest.mark.asyncio +async def test_token_exchange_passes_through_upstream_expires_in(): + """When the provider does send ``expires_in`` (e.g. Slack with token + rotation), the exchange forwards the real value unchanged.""" + body = await _exchange_with_upstream_token_response( + {"access_token": "tok", "token_type": "Bearer", "expires_in": 43200} + ) + assert body["expires_in"] == 43200