From 1fbb78d2a4d8abe12bd67cf9ce3acf8710bab273 Mon Sep 17 00:00:00 2001 From: Shivam Rawat Date: Sat, 6 Jun 2026 12:35:18 -0700 Subject: [PATCH] Title: Fix managed batch cancel credential resolution (#29734) * Fix managed batch cancel credential resolution Decode unified batch IDs before cancel routing and resolve litellm_credential_name to api_key in Router._acancel_batch so JWT team-scoped deployments cancel with the same credentials used at create time Co-authored-by: Cursor * fix batch cancellation credential cleanup Co-authored-by: Cursor --------- Co-authored-by: Cursor --- litellm/proxy/batches_endpoints/endpoints.py | 49 ++++---- .../openai_files_endpoints/common_utils.py | 5 +- litellm/router.py | 7 ++ .../test_batch_x_litellm_model_encoding.py | 107 +++++++++++++++++- .../test_router_acancel_batch.py | 77 +++++++++++++ 5 files changed, 221 insertions(+), 24 deletions(-) diff --git a/litellm/proxy/batches_endpoints/endpoints.py b/litellm/proxy/batches_endpoints/endpoints.py index 8516570995..ea479a5721 100644 --- a/litellm/proxy/batches_endpoints/endpoints.py +++ b/litellm/proxy/batches_endpoints/endpoints.py @@ -13,11 +13,9 @@ import litellm from litellm._logging import verbose_proxy_logger from litellm.batches.main import CancelBatchRequest, RetrieveBatchRequest from litellm.proxy._types import * +from litellm.proxy.common_utils.callback_utils import sanitize_openai_provider_metadata from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing -from litellm.proxy.common_utils.callback_utils import ( - sanitize_openai_provider_metadata, -) from litellm.proxy.common_utils.http_parsing_utils import _read_request_body from litellm.proxy.common_utils.openai_endpoint_utils import ( get_custom_llm_provider_from_request_headers, @@ -28,6 +26,7 @@ from litellm.proxy.openai_files_endpoints.common_utils import ( decode_model_from_file_id, encode_batch_response_ids, encode_file_id_with_model, + get_batch_id_from_unified_batch_id, get_batch_from_database, get_credentials_for_model, get_model_id_from_unified_batch_id, @@ -109,6 +108,7 @@ async def create_batch( # noqa: PLR0915 proxy_config=proxy_config, route_type="acreate_batch", ) + data["metadata"] = sanitize_openai_provider_metadata(data.get("metadata")) ## check if model is a loadbalanced model router_model: Optional[str] = None @@ -123,9 +123,6 @@ async def create_batch( # noqa: PLR0915 or get_custom_llm_provider_from_request_headers(request=request) or "openai" ) - if isinstance(data.get("metadata"), dict): - data["metadata"] = sanitize_openai_provider_metadata(data["metadata"]) - _create_batch_data = LiteLLMBatchCreateRequest(**data) # Apply team-level batch output expiry enforcement @@ -529,10 +526,6 @@ async def retrieve_batch( # noqa: PLR0915 custom_llm_provider=custom_llm_provider, **data # type: ignore ) - response = await proxy_logging_obj.post_call_success_hook( - data=data, user_api_key_dict=user_api_key_dict, response=response - ) - # FIX: Update the database with the latest state from provider await update_batch_in_database( batch_id=batch_id, @@ -543,9 +536,19 @@ async def retrieve_batch( # noqa: PLR0915 verbose_proxy_logger=verbose_proxy_logger, db_batch_object=db_batch_object, operation="retrieve", - user_api_key_dict=user_api_key_dict, ) + ### CALL HOOKS ### - modify outgoing data + response = await proxy_logging_obj.post_call_success_hook( + data=data, user_api_key_dict=user_api_key_dict, response=response + ) + + # Fix: bug_feb14_batch_retrieve_returns_raw_input_file_id + # Resolve raw provider file IDs (input, output, error) to unified IDs. + if unified_batch_id: + await resolve_input_file_id_to_unified(response, prisma_client) + await resolve_output_file_ids_to_unified(response, prisma_client) + ### ALERTING ### asyncio.create_task( proxy_logging_obj.update_request_status( @@ -891,11 +894,19 @@ async def cancel_batch( }, ) - # Hook has already extracted model and unwrapped batch_id into data dict + model_id_from_batch = get_model_id_from_unified_batch_id(unified_batch_id) + if model_id_from_batch is None: + raise HTTPException( + status_code=400, + detail={ + "error": "Invalid LiteLLM managed batch ID. Missing model_id." + }, + ) + data["model"] = model_id_from_batch + data["batch_id"] = get_batch_id_from_unified_batch_id(unified_batch_id) response = await llm_router.acancel_batch(**data) # type: ignore response._hidden_params["unified_batch_id"] = unified_batch_id - # Ensure model_id is set for the post_call_success_hook to re-encode IDs if not response._hidden_params.get("model_id") and data.get("model"): response._hidden_params["model_id"] = data["model"] @@ -917,14 +928,10 @@ async def cancel_batch( **_cancel_batch_data, ) + # FIX: Update the database with the new cancelled state managed_files_obj = proxy_logging_obj.get_proxy_hook("managed_files") from litellm.proxy.proxy_server import prisma_client - response = await proxy_logging_obj.post_call_success_hook( - data=data, user_api_key_dict=user_api_key_dict, response=response - ) - - # FIX: Update the database with the new cancelled state await update_batch_in_database( batch_id=batch_id, unified_batch_id=unified_batch_id, @@ -933,7 +940,11 @@ async def cancel_batch( prisma_client=prisma_client, verbose_proxy_logger=verbose_proxy_logger, operation="cancel", - user_api_key_dict=user_api_key_dict, + ) + + ### CALL HOOKS ### - modify outgoing data + response = await proxy_logging_obj.post_call_success_hook( + data=data, user_api_key_dict=user_api_key_dict, response=response ) ### ALERTING ### diff --git a/litellm/proxy/openai_files_endpoints/common_utils.py b/litellm/proxy/openai_files_endpoints/common_utils.py index 0415bb456e..cc0d06e4f4 100644 --- a/litellm/proxy/openai_files_endpoints/common_utils.py +++ b/litellm/proxy/openai_files_endpoints/common_utils.py @@ -79,9 +79,10 @@ def get_batch_id_from_unified_batch_id(file_id: str) -> str: if not isinstance(file_id, str): return "" if "llm_batch_id" in file_id: - return file_id.split("llm_batch_id:")[1].split(",")[0] + batch_id = file_id.split("llm_batch_id:", 1)[1] else: - return file_id.split("generic_response_id:")[1].split(",")[0] + batch_id = file_id.split("generic_response_id:", 1)[1] + return re.split(r"[;,]", batch_id, maxsplit=1)[0] def encode_file_id_with_model( diff --git a/litellm/router.py b/litellm/router.py index a92590d3db..d0f4e5ff44 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -5561,7 +5561,14 @@ class Router: request_kwargs=kwargs, ) + selected_deployment_id = (deployment.get("model_info") or {}).get("id") data = deployment["litellm_params"].copy() + resolved_credentials = self.get_deployment_credentials_with_provider( + model_id=selected_deployment_id or model + ) + if resolved_credentials is not None: + data.update(resolved_credentials) + data.pop("litellm_credential_name", None) model_name = data["model"] self._update_kwargs_with_deployment( deployment=deployment, kwargs=kwargs, function_name="_acancel_batch" diff --git a/tests/litellm/proxy/test_batch_x_litellm_model_encoding.py b/tests/litellm/proxy/test_batch_x_litellm_model_encoding.py index 1d498b48ca..49e0498f14 100644 --- a/tests/litellm/proxy/test_batch_x_litellm_model_encoding.py +++ b/tests/litellm/proxy/test_batch_x_litellm_model_encoding.py @@ -5,14 +5,15 @@ Verifies that create_batch encodes response IDs with model info so that retrieve_batch can route back to the correct provider/credentials. """ +import base64 from typing import Optional from unittest.mock import AsyncMock, MagicMock, patch import pytest -import litellm from litellm.proxy.openai_files_endpoints.common_utils import ( decode_model_from_file_id, + get_batch_id_from_unified_batch_id, get_original_file_id, ) from litellm.types.utils import LiteLLMBatch @@ -30,6 +31,11 @@ def _make_mock_request(headers: dict) -> MagicMock: return mock_request +def _make_unified_batch_id(model_id: str, batch_id: str) -> str: + decoded_id = f"litellm_proxy;model_id:{model_id};llm_batch_id:{batch_id}" + return base64.urlsafe_b64encode(decoded_id.encode()).decode().rstrip("=") + + def _make_batch_response( batch_id: str = "batch_abc123", input_file_id: str = "file-input456", @@ -51,6 +57,15 @@ def _make_batch_response( ) +def test_get_batch_id_from_unified_batch_id_handles_appended_fields(): + decoded_id = ( + "litellm_proxy;model_id:deployment-123;" + "llm_batch_id:batch_openai_123;llm_output_file_id:file-output" + ) + + assert get_batch_id_from_unified_batch_id(decoded_id) == "batch_openai_123" + + @pytest.mark.asyncio async def test_create_batch_with_x_litellm_model_encodes_batch_id(): """ @@ -68,6 +83,7 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id(): mock_user_api_key_dict = MagicMock() mock_user_api_key_dict.parent_otel_span = None mock_user_api_key_dict.user_id = "test_user" + mock_user_api_key_dict.team_metadata = {} mock_credentials = { "api_key": "sk-test", @@ -83,6 +99,11 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id(): "input_file_id": "file-input456", "endpoint": "/v1/chat/completions", "completion_window": "24h", + "metadata": { + "customer_id": "cust-123", + "applied_guardrails": ["pii"], + "attempt": 1, + }, } ), ), @@ -98,8 +119,8 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id(): ), patch( "litellm.acreate_batch", - new=AsyncMock(return_value=mock_response), - ), + new_callable=AsyncMock, + ) as mock_create_batch, patch( "litellm.proxy.batches_endpoints.endpoints.is_known_model", return_value=False, @@ -116,6 +137,7 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id(): ), ), ): + mock_create_batch.return_value = mock_response # Setup the mock processor to return data and logging obj mock_processor = MagicMock() mock_processor.common_processing_pre_call_logic = AsyncMock( @@ -124,6 +146,11 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id(): "input_file_id": "file-input456", "endpoint": "/v1/chat/completions", "completion_window": "24h", + "metadata": { + "customer_id": "cust-123", + "applied_guardrails": ["pii"], + "attempt": 1, + }, }, MagicMock(), ) @@ -155,6 +182,7 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id(): assert ( original_id == raw_batch_id ), f"Expected original ID '{raw_batch_id}', got: {original_id}" + assert mock_create_batch.call_args.kwargs["metadata"] == {"customer_id": "cust-123"} @pytest.mark.asyncio @@ -180,6 +208,7 @@ async def test_create_batch_with_x_litellm_model_encodes_output_and_error_file_i mock_user_api_key_dict = MagicMock() mock_user_api_key_dict.parent_otel_span = None mock_user_api_key_dict.user_id = "test_user" + mock_user_api_key_dict.team_metadata = {} mock_credentials = { "api_key": "sk-test", @@ -272,6 +301,7 @@ async def test_create_batch_without_x_litellm_model_returns_raw_ids(): mock_user_api_key_dict = MagicMock() mock_user_api_key_dict.parent_otel_span = None mock_user_api_key_dict.user_id = "test_user" + mock_user_api_key_dict.team_metadata = {} with ( patch( @@ -384,3 +414,74 @@ class TestBatchIdRoundTripWithRetrieve: assert encoded.startswith("batch_") assert decode_model_from_file_id(encoded) == model assert get_original_file_id(encoded) == raw_id + + +@pytest.mark.asyncio +async def test_cancel_batch_with_unified_id_routes_with_decoded_model_and_batch_id(): + from litellm.proxy.batches_endpoints.endpoints import cancel_batch + + model_id = "deployment-123" + raw_batch_id = "batch_openai_123" + unified_batch_id = _make_unified_batch_id( + model_id=model_id, batch_id=raw_batch_id + ) + mock_response = _make_batch_response(batch_id=raw_batch_id, status="cancelled") + mock_response._hidden_params = {} + mock_router = MagicMock() + mock_router.acancel_batch = AsyncMock(return_value=mock_response) + mock_request = _make_mock_request(headers={}) + mock_request.url.path = f"/v1/batches/{unified_batch_id}/cancel" + mock_fastapi_response = MagicMock() + mock_fastapi_response.headers = {} + mock_user_api_key_dict = MagicMock() + mock_user_api_key_dict.parent_otel_span = None + mock_user_api_key_dict.user_id = "test_user" + mock_user_api_key_dict.allowed_model_region = None + mock_user_api_key_dict.team_metadata = {} + + with ( + patch( + "litellm.proxy.batches_endpoints.endpoints.ProxyBaseLLMRequestProcessing" + ) as mock_processor_cls, + patch( + "litellm.proxy.batches_endpoints.endpoints.update_batch_in_database", + new=AsyncMock(), + ), + patch( + "litellm.proxy.proxy_server.add_litellm_data_to_request", + new=AsyncMock(side_effect=lambda data, **_: data), + ), + patch("litellm.proxy.proxy_server.general_settings", {}), + patch("litellm.proxy.proxy_server.llm_router", mock_router), + patch("litellm.proxy.proxy_server.proxy_config", MagicMock()), + patch("litellm.proxy.proxy_server.version", "1.0.0"), + patch("litellm.proxy.proxy_server.prisma_client", None), + patch( + "litellm.proxy.proxy_server.proxy_logging_obj", + MagicMock( + get_proxy_hook=MagicMock(return_value=None), + post_call_success_hook=AsyncMock(return_value=mock_response), + post_call_failure_hook=AsyncMock(), + update_request_status=AsyncMock(), + ), + ), + ): + mock_processor = MagicMock() + mock_processor.common_processing_pre_call_logic = AsyncMock( + return_value=({"batch_id": unified_batch_id}, MagicMock()) + ) + mock_processor_cls.return_value = mock_processor + + response = await cancel_batch( + request=mock_request, + batch_id=unified_batch_id, + fastapi_response=mock_fastapi_response, + provider=None, + user_api_key_dict=mock_user_api_key_dict, + ) + + mock_router.acancel_batch.assert_awaited_once() + cancel_kwargs = mock_router.acancel_batch.await_args.kwargs + assert cancel_kwargs["model"] == model_id + assert cancel_kwargs["batch_id"] == raw_batch_id + assert response._hidden_params["model_id"] == model_id diff --git a/tests/router_unit_tests/test_router_acancel_batch.py b/tests/router_unit_tests/test_router_acancel_batch.py index b364a66752..016da592e9 100644 --- a/tests/router_unit_tests/test_router_acancel_batch.py +++ b/tests/router_unit_tests/test_router_acancel_batch.py @@ -13,6 +13,7 @@ import pytest from unittest.mock import patch, AsyncMock, MagicMock from litellm import Router import litellm +from litellm.types.utils import CredentialItem @pytest.fixture @@ -52,3 +53,79 @@ async def test_router_acancel_batch(router): assert mock_cancel.called assert response.id == "batch_123" assert response.status == "cancelled" + + +@pytest.mark.asyncio +async def test_router_acancel_batch_resolves_credential_name(): + litellm.credential_list = [ + CredentialItem( + credential_name="openai-test-credential", + credential_info={"custom_llm_provider": "openai"}, + credential_values={"api_key": "resolved-openai-key"}, + ) + ] + router = Router( + model_list=[ + { + "model_name": "gpt-5.5", + "litellm_params": { + "model": "openai/gpt-5.5", + "litellm_credential_name": "openai-test-credential", + }, + } + ] + ) + mock_response = MagicMock() + mock_response.id = "batch_123" + mock_response.status = "cancelled" + + try: + with patch.object( + litellm, "acancel_batch", new_callable=AsyncMock + ) as mock_cancel: + mock_cancel.return_value = mock_response + + await router.acancel_batch( + model="gpt-5.5", + batch_id="batch_123", + ) + + call_kwargs = mock_cancel.call_args.kwargs + assert call_kwargs["api_key"] == "resolved-openai-key" + assert "litellm_credential_name" not in call_kwargs + finally: + litellm.credential_list = [] + + +@pytest.mark.asyncio +async def test_router_acancel_batch_removes_unresolved_credential_name(): + router = Router( + model_list=[ + { + "model_name": "gpt-5.5", + "litellm_params": { + "model": "openai/gpt-5.5", + "litellm_credential_name": "missing-openai-credential", + }, + } + ] + ) + mock_response = MagicMock() + mock_response.id = "batch_123" + mock_response.status = "cancelled" + + with ( + patch.object( + router, "get_deployment_credentials_with_provider", return_value=None + ), + patch.object(litellm, "acancel_batch", new_callable=AsyncMock) as mock_cancel, + ): + mock_cancel.return_value = mock_response + + await router.acancel_batch( + model="gpt-5.5", + batch_id="batch_123", + ) + + call_kwargs = mock_cancel.call_args.kwargs + assert "litellm_credential_name" not in call_kwargs