Title: Fix managed batch cancel credential resolution (#29734)
* Fix managed batch cancel credential resolution: decode unified batch IDs before cancel routing, and resolve litellm_credential_name to api_key in Router._acancel_batch, so that JWT team-scoped deployments cancel with the same credentials used at create time. Co-authored-by: Cursor <cursoragent@cursor.com> * Fix batch cancellation credential cleanup. Co-authored-by: Cursor <cursoragent@cursor.com> --------- Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
51769a8ede
commit
1fbb78d2a4
@ -13,11 +13,9 @@ import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.batches.main import CancelBatchRequest, RetrieveBatchRequest
|
||||
from litellm.proxy._types import *
|
||||
from litellm.proxy.common_utils.callback_utils import sanitize_openai_provider_metadata
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.proxy.common_utils.callback_utils import (
|
||||
sanitize_openai_provider_metadata,
|
||||
)
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.proxy.common_utils.openai_endpoint_utils import (
|
||||
get_custom_llm_provider_from_request_headers,
|
||||
@ -28,6 +26,7 @@ from litellm.proxy.openai_files_endpoints.common_utils import (
|
||||
decode_model_from_file_id,
|
||||
encode_batch_response_ids,
|
||||
encode_file_id_with_model,
|
||||
get_batch_id_from_unified_batch_id,
|
||||
get_batch_from_database,
|
||||
get_credentials_for_model,
|
||||
get_model_id_from_unified_batch_id,
|
||||
@ -109,6 +108,7 @@ async def create_batch( # noqa: PLR0915
|
||||
proxy_config=proxy_config,
|
||||
route_type="acreate_batch",
|
||||
)
|
||||
data["metadata"] = sanitize_openai_provider_metadata(data.get("metadata"))
|
||||
|
||||
## check if model is a loadbalanced model
|
||||
router_model: Optional[str] = None
|
||||
@ -123,9 +123,6 @@ async def create_batch( # noqa: PLR0915
|
||||
or get_custom_llm_provider_from_request_headers(request=request)
|
||||
or "openai"
|
||||
)
|
||||
if isinstance(data.get("metadata"), dict):
|
||||
data["metadata"] = sanitize_openai_provider_metadata(data["metadata"])
|
||||
|
||||
_create_batch_data = LiteLLMBatchCreateRequest(**data)
|
||||
|
||||
# Apply team-level batch output expiry enforcement
|
||||
@ -529,10 +526,6 @@ async def retrieve_batch( # noqa: PLR0915
|
||||
custom_llm_provider=custom_llm_provider, **data # type: ignore
|
||||
)
|
||||
|
||||
response = await proxy_logging_obj.post_call_success_hook(
|
||||
data=data, user_api_key_dict=user_api_key_dict, response=response
|
||||
)
|
||||
|
||||
# FIX: Update the database with the latest state from provider
|
||||
await update_batch_in_database(
|
||||
batch_id=batch_id,
|
||||
@ -543,9 +536,19 @@ async def retrieve_batch( # noqa: PLR0915
|
||||
verbose_proxy_logger=verbose_proxy_logger,
|
||||
db_batch_object=db_batch_object,
|
||||
operation="retrieve",
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
response = await proxy_logging_obj.post_call_success_hook(
|
||||
data=data, user_api_key_dict=user_api_key_dict, response=response
|
||||
)
|
||||
|
||||
# Fix: bug_feb14_batch_retrieve_returns_raw_input_file_id
|
||||
# Resolve raw provider file IDs (input, output, error) to unified IDs.
|
||||
if unified_batch_id:
|
||||
await resolve_input_file_id_to_unified(response, prisma_client)
|
||||
await resolve_output_file_ids_to_unified(response, prisma_client)
|
||||
|
||||
### ALERTING ###
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.update_request_status(
|
||||
@ -891,11 +894,19 @@ async def cancel_batch(
|
||||
},
|
||||
)
|
||||
|
||||
# Hook has already extracted model and unwrapped batch_id into data dict
|
||||
model_id_from_batch = get_model_id_from_unified_batch_id(unified_batch_id)
|
||||
if model_id_from_batch is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Invalid LiteLLM managed batch ID. Missing model_id."
|
||||
},
|
||||
)
|
||||
data["model"] = model_id_from_batch
|
||||
data["batch_id"] = get_batch_id_from_unified_batch_id(unified_batch_id)
|
||||
response = await llm_router.acancel_batch(**data) # type: ignore
|
||||
response._hidden_params["unified_batch_id"] = unified_batch_id
|
||||
|
||||
# Ensure model_id is set for the post_call_success_hook to re-encode IDs
|
||||
if not response._hidden_params.get("model_id") and data.get("model"):
|
||||
response._hidden_params["model_id"] = data["model"]
|
||||
|
||||
@ -917,14 +928,10 @@ async def cancel_batch(
|
||||
**_cancel_batch_data,
|
||||
)
|
||||
|
||||
# FIX: Update the database with the new cancelled state
|
||||
managed_files_obj = proxy_logging_obj.get_proxy_hook("managed_files")
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
response = await proxy_logging_obj.post_call_success_hook(
|
||||
data=data, user_api_key_dict=user_api_key_dict, response=response
|
||||
)
|
||||
|
||||
# FIX: Update the database with the new cancelled state
|
||||
await update_batch_in_database(
|
||||
batch_id=batch_id,
|
||||
unified_batch_id=unified_batch_id,
|
||||
@ -933,7 +940,11 @@ async def cancel_batch(
|
||||
prisma_client=prisma_client,
|
||||
verbose_proxy_logger=verbose_proxy_logger,
|
||||
operation="cancel",
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
|
||||
### CALL HOOKS ### - modify outgoing data
|
||||
response = await proxy_logging_obj.post_call_success_hook(
|
||||
data=data, user_api_key_dict=user_api_key_dict, response=response
|
||||
)
|
||||
|
||||
### ALERTING ###
|
||||
|
||||
@ -79,9 +79,10 @@ def get_batch_id_from_unified_batch_id(file_id: str) -> str:
|
||||
if not isinstance(file_id, str):
|
||||
return ""
|
||||
if "llm_batch_id" in file_id:
|
||||
return file_id.split("llm_batch_id:")[1].split(",")[0]
|
||||
batch_id = file_id.split("llm_batch_id:", 1)[1]
|
||||
else:
|
||||
return file_id.split("generic_response_id:")[1].split(",")[0]
|
||||
batch_id = file_id.split("generic_response_id:", 1)[1]
|
||||
return re.split(r"[;,]", batch_id, maxsplit=1)[0]
|
||||
|
||||
|
||||
def encode_file_id_with_model(
|
||||
|
||||
@ -5561,7 +5561,14 @@ class Router:
|
||||
request_kwargs=kwargs,
|
||||
)
|
||||
|
||||
selected_deployment_id = (deployment.get("model_info") or {}).get("id")
|
||||
data = deployment["litellm_params"].copy()
|
||||
resolved_credentials = self.get_deployment_credentials_with_provider(
|
||||
model_id=selected_deployment_id or model
|
||||
)
|
||||
if resolved_credentials is not None:
|
||||
data.update(resolved_credentials)
|
||||
data.pop("litellm_credential_name", None)
|
||||
model_name = data["model"]
|
||||
self._update_kwargs_with_deployment(
|
||||
deployment=deployment, kwargs=kwargs, function_name="_acancel_batch"
|
||||
|
||||
@ -5,14 +5,15 @@ Verifies that create_batch encodes response IDs with model info so that
|
||||
retrieve_batch can route back to the correct provider/credentials.
|
||||
"""
|
||||
|
||||
import base64
|
||||
from typing import Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm.proxy.openai_files_endpoints.common_utils import (
|
||||
decode_model_from_file_id,
|
||||
get_batch_id_from_unified_batch_id,
|
||||
get_original_file_id,
|
||||
)
|
||||
from litellm.types.utils import LiteLLMBatch
|
||||
@ -30,6 +31,11 @@ def _make_mock_request(headers: dict) -> MagicMock:
|
||||
return mock_request
|
||||
|
||||
|
||||
def _make_unified_batch_id(model_id: str, batch_id: str) -> str:
    """Build an unpadded, URL-safe base64 managed batch ID for use in tests.

    The encoded payload has the form
    ``litellm_proxy;model_id:<model_id>;llm_batch_id:<batch_id>``.

    Args:
        model_id: Deployment/model identifier to embed in the ID.
        batch_id: Raw provider batch identifier to embed in the ID.

    Returns:
        The base64 (URL-safe alphabet) encoding of the payload with trailing
        ``=`` padding stripped.
    """
    decoded_id = f"litellm_proxy;model_id:{model_id};llm_batch_id:{batch_id}"
    encoded = base64.urlsafe_b64encode(decoded_id.encode("utf-8"))
    # Padding is stripped so the ID has no trailing "=" characters.
    return encoded.decode("ascii").rstrip("=")
|
||||
|
||||
|
||||
def _make_batch_response(
|
||||
batch_id: str = "batch_abc123",
|
||||
input_file_id: str = "file-input456",
|
||||
@ -51,6 +57,15 @@ def _make_batch_response(
|
||||
)
|
||||
|
||||
|
||||
def test_get_batch_id_from_unified_batch_id_handles_appended_fields():
    """The batch ID is extracted even when a further ``;``-separated field follows it."""
    unified_id = (
        "litellm_proxy;model_id:deployment-123;"
        "llm_batch_id:batch_openai_123;llm_output_file_id:file-output"
    )
    extracted = get_batch_id_from_unified_batch_id(unified_id)

    assert extracted == "batch_openai_123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_batch_with_x_litellm_model_encodes_batch_id():
|
||||
"""
|
||||
@ -68,6 +83,7 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id():
|
||||
mock_user_api_key_dict = MagicMock()
|
||||
mock_user_api_key_dict.parent_otel_span = None
|
||||
mock_user_api_key_dict.user_id = "test_user"
|
||||
mock_user_api_key_dict.team_metadata = {}
|
||||
|
||||
mock_credentials = {
|
||||
"api_key": "sk-test",
|
||||
@ -83,6 +99,11 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id():
|
||||
"input_file_id": "file-input456",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"metadata": {
|
||||
"customer_id": "cust-123",
|
||||
"applied_guardrails": ["pii"],
|
||||
"attempt": 1,
|
||||
},
|
||||
}
|
||||
),
|
||||
),
|
||||
@ -98,8 +119,8 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id():
|
||||
),
|
||||
patch(
|
||||
"litellm.acreate_batch",
|
||||
new=AsyncMock(return_value=mock_response),
|
||||
),
|
||||
new_callable=AsyncMock,
|
||||
) as mock_create_batch,
|
||||
patch(
|
||||
"litellm.proxy.batches_endpoints.endpoints.is_known_model",
|
||||
return_value=False,
|
||||
@ -116,6 +137,7 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id():
|
||||
),
|
||||
),
|
||||
):
|
||||
mock_create_batch.return_value = mock_response
|
||||
# Setup the mock processor to return data and logging obj
|
||||
mock_processor = MagicMock()
|
||||
mock_processor.common_processing_pre_call_logic = AsyncMock(
|
||||
@ -124,6 +146,11 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id():
|
||||
"input_file_id": "file-input456",
|
||||
"endpoint": "/v1/chat/completions",
|
||||
"completion_window": "24h",
|
||||
"metadata": {
|
||||
"customer_id": "cust-123",
|
||||
"applied_guardrails": ["pii"],
|
||||
"attempt": 1,
|
||||
},
|
||||
},
|
||||
MagicMock(),
|
||||
)
|
||||
@ -155,6 +182,7 @@ async def test_create_batch_with_x_litellm_model_encodes_batch_id():
|
||||
assert (
|
||||
original_id == raw_batch_id
|
||||
), f"Expected original ID '{raw_batch_id}', got: {original_id}"
|
||||
assert mock_create_batch.call_args.kwargs["metadata"] == {"customer_id": "cust-123"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -180,6 +208,7 @@ async def test_create_batch_with_x_litellm_model_encodes_output_and_error_file_i
|
||||
mock_user_api_key_dict = MagicMock()
|
||||
mock_user_api_key_dict.parent_otel_span = None
|
||||
mock_user_api_key_dict.user_id = "test_user"
|
||||
mock_user_api_key_dict.team_metadata = {}
|
||||
|
||||
mock_credentials = {
|
||||
"api_key": "sk-test",
|
||||
@ -272,6 +301,7 @@ async def test_create_batch_without_x_litellm_model_returns_raw_ids():
|
||||
mock_user_api_key_dict = MagicMock()
|
||||
mock_user_api_key_dict.parent_otel_span = None
|
||||
mock_user_api_key_dict.user_id = "test_user"
|
||||
mock_user_api_key_dict.team_metadata = {}
|
||||
|
||||
with (
|
||||
patch(
|
||||
@ -384,3 +414,74 @@ class TestBatchIdRoundTripWithRetrieve:
|
||||
assert encoded.startswith("batch_")
|
||||
assert decode_model_from_file_id(encoded) == model
|
||||
assert get_original_file_id(encoded) == raw_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_cancel_batch_with_unified_id_routes_with_decoded_model_and_batch_id():
    """cancel_batch must route a unified batch ID using its decoded parts.

    The endpoint is called with a base64 unified batch ID. The router's
    ``acancel_batch`` is expected to be awaited exactly once with ``model``
    set to the embedded model ID and ``batch_id`` set to the raw provider
    batch ID, and the returned response must carry ``model_id`` in its
    hidden params.
    """
    from litellm.proxy.batches_endpoints.endpoints import cancel_batch

    model_id = "deployment-123"
    raw_batch_id = "batch_openai_123"
    unified_batch_id = _make_unified_batch_id(
        model_id=model_id, batch_id=raw_batch_id
    )

    # Provider response returned by the mocked router and success hook.
    mock_response = _make_batch_response(batch_id=raw_batch_id, status="cancelled")
    mock_response._hidden_params = {}

    router_stub = MagicMock()
    router_stub.acancel_batch = AsyncMock(return_value=mock_response)

    request_stub = _make_mock_request(headers={})
    request_stub.url.path = f"/v1/batches/{unified_batch_id}/cancel"

    fastapi_response_stub = MagicMock()
    fastapi_response_stub.headers = {}

    key_dict_stub = MagicMock()
    key_dict_stub.parent_otel_span = None
    key_dict_stub.user_id = "test_user"
    key_dict_stub.allowed_model_region = None
    key_dict_stub.team_metadata = {}

    logging_stub = MagicMock(
        get_proxy_hook=MagicMock(return_value=None),
        post_call_success_hook=AsyncMock(return_value=mock_response),
        post_call_failure_hook=AsyncMock(),
        update_request_status=AsyncMock(),
    )

    # Pre-call processing hands back the still-encoded batch ID.
    processor_stub = MagicMock()
    processor_stub.common_processing_pre_call_logic = AsyncMock(
        return_value=({"batch_id": unified_batch_id}, MagicMock())
    )

    with (
        patch(
            "litellm.proxy.batches_endpoints.endpoints.ProxyBaseLLMRequestProcessing"
        ) as mock_processor_cls,
        patch(
            "litellm.proxy.batches_endpoints.endpoints.update_batch_in_database",
            new=AsyncMock(),
        ),
        patch(
            "litellm.proxy.proxy_server.add_litellm_data_to_request",
            new=AsyncMock(side_effect=lambda data, **_: data),
        ),
        patch("litellm.proxy.proxy_server.general_settings", {}),
        patch("litellm.proxy.proxy_server.llm_router", router_stub),
        patch("litellm.proxy.proxy_server.proxy_config", MagicMock()),
        patch("litellm.proxy.proxy_server.version", "1.0.0"),
        patch("litellm.proxy.proxy_server.prisma_client", None),
        patch("litellm.proxy.proxy_server.proxy_logging_obj", logging_stub),
    ):
        mock_processor_cls.return_value = processor_stub

        response = await cancel_batch(
            request=request_stub,
            batch_id=unified_batch_id,
            fastapi_response=fastapi_response_stub,
            provider=None,
            user_api_key_dict=key_dict_stub,
        )

    router_stub.acancel_batch.assert_awaited_once()
    routed_kwargs = router_stub.acancel_batch.await_args.kwargs
    assert routed_kwargs["model"] == model_id
    assert routed_kwargs["batch_id"] == raw_batch_id
    assert response._hidden_params["model_id"] == model_id
|
||||
|
||||
@ -13,6 +13,7 @@ import pytest
|
||||
from unittest.mock import patch, AsyncMock, MagicMock
|
||||
from litellm import Router
|
||||
import litellm
|
||||
from litellm.types.utils import CredentialItem
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -52,3 +53,79 @@ async def test_router_acancel_batch(router):
|
||||
assert mock_cancel.called
|
||||
assert response.id == "batch_123"
|
||||
assert response.status == "cancelled"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_router_acancel_batch_resolves_credential_name():
    """A deployment's ``litellm_credential_name`` is resolved before cancelling.

    Registers a credential in the global ``litellm.credential_list``, cancels
    a batch through the router, and checks that ``litellm.acancel_batch``
    received the resolved ``api_key`` and no ``litellm_credential_name``.
    """
    # Remember the global state so the test restores exactly what it found,
    # instead of clobbering it with an empty list.
    previous_credentials = litellm.credential_list

    mock_response = MagicMock()
    mock_response.id = "batch_123"
    mock_response.status = "cancelled"

    try:
        # Mutate the global and build the router inside the try block so the
        # finally clause cleans up even if setup itself raises.
        litellm.credential_list = [
            CredentialItem(
                credential_name="openai-test-credential",
                credential_info={"custom_llm_provider": "openai"},
                credential_values={"api_key": "resolved-openai-key"},
            )
        ]
        router = Router(
            model_list=[
                {
                    "model_name": "gpt-5.5",
                    "litellm_params": {
                        "model": "openai/gpt-5.5",
                        "litellm_credential_name": "openai-test-credential",
                    },
                }
            ]
        )

        with patch.object(
            litellm, "acancel_batch", new_callable=AsyncMock
        ) as mock_cancel:
            mock_cancel.return_value = mock_response

            await router.acancel_batch(
                model="gpt-5.5",
                batch_id="batch_123",
            )

        call_kwargs = mock_cancel.call_args.kwargs
        assert call_kwargs["api_key"] == "resolved-openai-key"
        assert "litellm_credential_name" not in call_kwargs
    finally:
        litellm.credential_list = previous_credentials
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_router_acancel_batch_removes_unresolved_credential_name():
    """An unresolvable ``litellm_credential_name`` is not forwarded to the provider call.

    Credential resolution is patched to return ``None``; the kwargs passed to
    ``litellm.acancel_batch`` must then not contain ``litellm_credential_name``.
    """
    deployment = {
        "model_name": "gpt-5.5",
        "litellm_params": {
            "model": "openai/gpt-5.5",
            "litellm_credential_name": "missing-openai-credential",
        },
    }
    router = Router(model_list=[deployment])

    cancelled_batch = MagicMock()
    cancelled_batch.id = "batch_123"
    cancelled_batch.status = "cancelled"

    # Force the "credential could not be resolved" path.
    no_credentials = patch.object(
        router, "get_deployment_credentials_with_provider", return_value=None
    )
    provider_cancel = patch.object(litellm, "acancel_batch", new_callable=AsyncMock)

    with no_credentials, provider_cancel as mock_cancel:
        mock_cancel.return_value = cancelled_batch

        await router.acancel_batch(model="gpt-5.5", batch_id="batch_123")

    forwarded_kwargs = mock_cancel.call_args.kwargs
    assert "litellm_credential_name" not in forwarded_kwargs
|
||||
|
||||
Loading…
Reference in New Issue
Block a user