Title: Fix managed batch cancel credential resolution (#29734)

* Fix managed batch cancel credential resolution

Decode unified batch IDs before cancel routing and resolve litellm_credential_name to api_key in Router._acancel_batch so JWT team-scoped deployments cancel with the same credentials used at create time

Co-authored-by: Cursor <cursoragent@cursor.com>

* fix batch cancellation credential cleanup

Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
Shivam Rawat 2026-06-06 12:35:18 -07:00 committed by GitHub
parent 51769a8ede
commit 1fbb78d2a4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 221 additions and 24 deletions

View File

@ -13,11 +13,9 @@ import litellm
from litellm._logging import verbose_proxy_logger
from litellm.batches.main import CancelBatchRequest, RetrieveBatchRequest
from litellm.proxy._types import *
from litellm.proxy.common_utils.callback_utils import sanitize_openai_provider_metadata
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.common_utils.callback_utils import (
sanitize_openai_provider_metadata,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.common_utils.openai_endpoint_utils import (
get_custom_llm_provider_from_request_headers,
@ -28,6 +26,7 @@ from litellm.proxy.openai_files_endpoints.common_utils import (
decode_model_from_file_id,
encode_batch_response_ids,
encode_file_id_with_model,
get_batch_id_from_unified_batch_id,
get_batch_from_database,
get_credentials_for_model,
get_model_id_from_unified_batch_id,
@ -109,6 +108,7 @@ async def create_batch( # noqa: PLR0915
proxy_config=proxy_config,
route_type="acreate_batch",
)
data["metadata"] = sanitize_openai_provider_metadata(data.get("metadata"))
## check if model is a loadbalanced model
router_model: Optional[str] = None
@ -123,9 +123,6 @@ async def create_batch( # noqa: PLR0915
or get_custom_llm_provider_from_request_headers(request=request)
or "openai"
)
if isinstance(data.get("metadata"), dict):
data["metadata"] = sanitize_openai_provider_metadata(data["metadata"])
_create_batch_data = LiteLLMBatchCreateRequest(**data)
# Apply team-level batch output expiry enforcement
@ -529,10 +526,6 @@ async def retrieve_batch( # noqa: PLR0915
custom_llm_provider=custom_llm_provider, **data # type: ignore
)
response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response
)
# FIX: Update the database with the latest state from provider
await update_batch_in_database(
batch_id=batch_id,
@ -543,9 +536,19 @@ async def retrieve_batch( # noqa: PLR0915
verbose_proxy_logger=verbose_proxy_logger,
db_batch_object=db_batch_object,
operation="retrieve",
user_api_key_dict=user_api_key_dict,
)
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response
)
# Fix: bug_feb14_batch_retrieve_returns_raw_input_file_id
# Resolve raw provider file IDs (input, output, error) to unified IDs.
if unified_batch_id:
await resolve_input_file_id_to_unified(response, prisma_client)
await resolve_output_file_ids_to_unified(response, prisma_client)
### ALERTING ###
asyncio.create_task(
proxy_logging_obj.update_request_status(
@ -891,11 +894,19 @@ async def cancel_batch(
},
)
# Hook has already extracted model and unwrapped batch_id into data dict
model_id_from_batch = get_model_id_from_unified_batch_id(unified_batch_id)
if model_id_from_batch is None:
raise HTTPException(
status_code=400,
detail={
"error": "Invalid LiteLLM managed batch ID. Missing model_id."
},
)
data["model"] = model_id_from_batch
data["batch_id"] = get_batch_id_from_unified_batch_id(unified_batch_id)
response = await llm_router.acancel_batch(**data) # type: ignore
response._hidden_params["unified_batch_id"] = unified_batch_id
# Ensure model_id is set for the post_call_success_hook to re-encode IDs
if not response._hidden_params.get("model_id") and data.get("model"):
response._hidden_params["model_id"] = data["model"]
@ -917,14 +928,10 @@ async def cancel_batch(
**_cancel_batch_data,
)
# FIX: Update the database with the new cancelled state
managed_files_obj = proxy_logging_obj.get_proxy_hook("managed_files")
from litellm.proxy.proxy_server import prisma_client
response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response
)
# FIX: Update the database with the new cancelled state
await update_batch_in_database(
batch_id=batch_id,
unified_batch_id=unified_batch_id,
@ -933,7 +940,11 @@ async def cancel_batch(
prisma_client=prisma_client,
verbose_proxy_logger=verbose_proxy_logger,
operation="cancel",
user_api_key_dict=user_api_key_dict,
)
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response
)
### ALERTING ###

View File

@ -79,9 +79,10 @@ def get_batch_id_from_unified_batch_id(file_id: str) -> str:
if not isinstance(file_id, str):
return ""
if "llm_batch_id" in file_id:
return file_id.split("llm_batch_id:")[1].split(",")[0]
batch_id = file_id.split("llm_batch_id:", 1)[1]
else:
return file_id.split("generic_response_id:")[1].split(",")[0]
batch_id = file_id.split("generic_response_id:", 1)[1]
return re.split(r"[;,]", batch_id, maxsplit=1)[0]
def encode_file_id_with_model(

View File

@ -5561,7 +5561,14 @@ class Router:
request_kwargs=kwargs,
)
selected_deployment_id = (deployment.get("model_info") or {}).get("id")
data = deployment["litellm_params"].copy()
resolved_credentials = self.get_deployment_credentials_with_provider(
model_id=selected_deployment_id or model
)
if resolved_credentials is not None:
data.update(resolved_credentials)
data.pop("litellm_credential_name", None)
model_name = data["model"]
self._update_kwargs_with_deployment(
deployment=deployment, kwargs=kwargs, function_name="_acancel_batch"

View File

@ -5,14 +5,15 @@ Verifies that create_batch encodes response IDs with model info so that
retrieve_batch can route back to the correct provider/credentials.
"""
import base64
from typing import Optional
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import litellm
from litellm.proxy.openai_files_endpoints.common_utils import (
decode_model_from_file_id,
get_batch_id_from_unified_batch_id,
get_original_file_id,
)
from litellm.types.utils import LiteLLMBatch
@ -30,6 +31,11 @@ def _make_mock_request(headers: dict) -> MagicMock:
return mock_request
def _make_unified_batch_id(model_id: str, batch_id: str) -> str:
decoded_id = f"litellm_proxy;model_id:{model_id};llm_batch_id:{batch_id}"
return base64.urlsafe_b64encode(decoded_id.encode()).decode().rstrip("=")
def _make_batch_response(
batch_id: str = "batch_abc123",
input_file_id: str = "file-input456",
@ -51,6 +57,15 @@ def _make_batch_response(
)
def test_get_batch_id_from_unified_batch_id_handles_appended_fields():
decoded_id = (
"litellm_proxy;model_id:deployment-123;"
"llm_batch_id:batch_openai_123;llm_output_file_id:file-output"
)
assert get_batch_id_from_unified_batch_id(decoded_id) == "batch_openai_123"
@pytest.mark.asyncio
async def test_create_batch_with_x_litellm_model_encodes_batch_id():
"""
@ -68,6 +83,7 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id():
mock_user_api_key_dict = MagicMock()
mock_user_api_key_dict.parent_otel_span = None
mock_user_api_key_dict.user_id = "test_user"
mock_user_api_key_dict.team_metadata = {}
mock_credentials = {
"api_key": "sk-test",
@ -83,6 +99,11 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id():
"input_file_id": "file-input456",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"metadata": {
"customer_id": "cust-123",
"applied_guardrails": ["pii"],
"attempt": 1,
},
}
),
),
@ -98,8 +119,8 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id():
),
patch(
"litellm.acreate_batch",
new=AsyncMock(return_value=mock_response),
),
new_callable=AsyncMock,
) as mock_create_batch,
patch(
"litellm.proxy.batches_endpoints.endpoints.is_known_model",
return_value=False,
@ -116,6 +137,7 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id():
),
),
):
mock_create_batch.return_value = mock_response
# Setup the mock processor to return data and logging obj
mock_processor = MagicMock()
mock_processor.common_processing_pre_call_logic = AsyncMock(
@ -124,6 +146,11 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id():
"input_file_id": "file-input456",
"endpoint": "/v1/chat/completions",
"completion_window": "24h",
"metadata": {
"customer_id": "cust-123",
"applied_guardrails": ["pii"],
"attempt": 1,
},
},
MagicMock(),
)
@ -155,6 +182,7 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id():
assert (
original_id == raw_batch_id
), f"Expected original ID '{raw_batch_id}', got: {original_id}"
assert mock_create_batch.call_args.kwargs["metadata"] == {"customer_id": "cust-123"}
@pytest.mark.asyncio
@ -180,6 +208,7 @@ async def test_create_batch_with_x_litellm_model_encodes_output_and_error_file_i
mock_user_api_key_dict = MagicMock()
mock_user_api_key_dict.parent_otel_span = None
mock_user_api_key_dict.user_id = "test_user"
mock_user_api_key_dict.team_metadata = {}
mock_credentials = {
"api_key": "sk-test",
@ -272,6 +301,7 @@ async def test_create_batch_without_x_litellm_model_returns_raw_ids():
mock_user_api_key_dict = MagicMock()
mock_user_api_key_dict.parent_otel_span = None
mock_user_api_key_dict.user_id = "test_user"
mock_user_api_key_dict.team_metadata = {}
with (
patch(
@ -384,3 +414,74 @@ class TestBatchIdRoundTripWithRetrieve:
assert encoded.startswith("batch_")
assert decode_model_from_file_id(encoded) == model
assert get_original_file_id(encoded) == raw_id
@pytest.mark.asyncio
async def test_cancel_batch_with_unified_id_routes_with_decoded_model_and_batch_id():
from litellm.proxy.batches_endpoints.endpoints import cancel_batch
model_id = "deployment-123"
raw_batch_id = "batch_openai_123"
unified_batch_id = _make_unified_batch_id(
model_id=model_id, batch_id=raw_batch_id
)
mock_response = _make_batch_response(batch_id=raw_batch_id, status="cancelled")
mock_response._hidden_params = {}
mock_router = MagicMock()
mock_router.acancel_batch = AsyncMock(return_value=mock_response)
mock_request = _make_mock_request(headers={})
mock_request.url.path = f"/v1/batches/{unified_batch_id}/cancel"
mock_fastapi_response = MagicMock()
mock_fastapi_response.headers = {}
mock_user_api_key_dict = MagicMock()
mock_user_api_key_dict.parent_otel_span = None
mock_user_api_key_dict.user_id = "test_user"
mock_user_api_key_dict.allowed_model_region = None
mock_user_api_key_dict.team_metadata = {}
with (
patch(
"litellm.proxy.batches_endpoints.endpoints.ProxyBaseLLMRequestProcessing"
) as mock_processor_cls,
patch(
"litellm.proxy.batches_endpoints.endpoints.update_batch_in_database",
new=AsyncMock(),
),
patch(
"litellm.proxy.proxy_server.add_litellm_data_to_request",
new=AsyncMock(side_effect=lambda data, **_: data),
),
patch("litellm.proxy.proxy_server.general_settings", {}),
patch("litellm.proxy.proxy_server.llm_router", mock_router),
patch("litellm.proxy.proxy_server.proxy_config", MagicMock()),
patch("litellm.proxy.proxy_server.version", "1.0.0"),
patch("litellm.proxy.proxy_server.prisma_client", None),
patch(
"litellm.proxy.proxy_server.proxy_logging_obj",
MagicMock(
get_proxy_hook=MagicMock(return_value=None),
post_call_success_hook=AsyncMock(return_value=mock_response),
post_call_failure_hook=AsyncMock(),
update_request_status=AsyncMock(),
),
),
):
mock_processor = MagicMock()
mock_processor.common_processing_pre_call_logic = AsyncMock(
return_value=({"batch_id": unified_batch_id}, MagicMock())
)
mock_processor_cls.return_value = mock_processor
response = await cancel_batch(
request=mock_request,
batch_id=unified_batch_id,
fastapi_response=mock_fastapi_response,
provider=None,
user_api_key_dict=mock_user_api_key_dict,
)
mock_router.acancel_batch.assert_awaited_once()
cancel_kwargs = mock_router.acancel_batch.await_args.kwargs
assert cancel_kwargs["model"] == model_id
assert cancel_kwargs["batch_id"] == raw_batch_id
assert response._hidden_params["model_id"] == model_id

View File

@ -13,6 +13,7 @@ import pytest
from unittest.mock import patch, AsyncMock, MagicMock
from litellm import Router
import litellm
from litellm.types.utils import CredentialItem
@pytest.fixture
@ -52,3 +53,79 @@ async def test_router_acancel_batch(router):
assert mock_cancel.called
assert response.id == "batch_123"
assert response.status == "cancelled"
@pytest.mark.asyncio
async def test_router_acancel_batch_resolves_credential_name():
litellm.credential_list = [
CredentialItem(
credential_name="openai-test-credential",
credential_info={"custom_llm_provider": "openai"},
credential_values={"api_key": "resolved-openai-key"},
)
]
router = Router(
model_list=[
{
"model_name": "gpt-5.5",
"litellm_params": {
"model": "openai/gpt-5.5",
"litellm_credential_name": "openai-test-credential",
},
}
]
)
mock_response = MagicMock()
mock_response.id = "batch_123"
mock_response.status = "cancelled"
try:
with patch.object(
litellm, "acancel_batch", new_callable=AsyncMock
) as mock_cancel:
mock_cancel.return_value = mock_response
await router.acancel_batch(
model="gpt-5.5",
batch_id="batch_123",
)
call_kwargs = mock_cancel.call_args.kwargs
assert call_kwargs["api_key"] == "resolved-openai-key"
assert "litellm_credential_name" not in call_kwargs
finally:
litellm.credential_list = []
@pytest.mark.asyncio
async def test_router_acancel_batch_removes_unresolved_credential_name():
router = Router(
model_list=[
{
"model_name": "gpt-5.5",
"litellm_params": {
"model": "openai/gpt-5.5",
"litellm_credential_name": "missing-openai-credential",
},
}
]
)
mock_response = MagicMock()
mock_response.id = "batch_123"
mock_response.status = "cancelled"
with (
patch.object(
router, "get_deployment_credentials_with_provider", return_value=None
),
patch.object(litellm, "acancel_batch", new_callable=AsyncMock) as mock_cancel,
):
mock_cancel.return_value = mock_response
await router.acancel_batch(
model="gpt-5.5",
batch_id="batch_123",
)
call_kwargs = mock_cancel.call_args.kwargs
assert "litellm_credential_name" not in call_kwargs