fix(managed_batches): convert raw output_file_id to managed ID in CheckBatchCost poller (#27984)
* fix(managed_batches): convert raw output_file_id to managed ID in CheckBatchCost poller
CheckBatchCost bypasses async_post_call_success_hook, causing raw provider
output_file_ids to be persisted in LiteLLM_ManagedObjectTable. This fix converts
output_file_id and error_file_id to managed base64 IDs before the DB write.
Co-authored-by: Cursor <cursoragent@cursor.com>
* fix(check_batch_cost): persist managed file before mutating response and propagate team_id
- Move setattr after store_unified_file_id so the response only receives the
managed ID once the DB record is successfully written. Avoids serializing
an orphaned managed ID into file_object when the store call fails.
- Populate team_id on the minimal UserAPIKeyAuth from job.team_id so the
managed file record is created with the correct team ownership, allowing
other team members to access the batch output file via /files/{id}/content.
Co-authored-by: Yassin Kortam <yassin@berri.ai>
* test(managed_batches): extend test to cover error_file_id conversion
Co-authored-by: Cursor <cursoragent@cursor.com>
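* test(managed_batches): fix managed file test by setting mock_job.team_id and mocking store_unified_file_id / get_unified_output_file_id on the managed-files hook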
---------
Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Yassin Kortam <yassin@berri.ai>
This commit is contained in:
parent
fe755ee02a
commit
cbdc70d544
@ -300,6 +300,42 @@ class CheckBatchCost:
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
# CheckBatchCost bypasses async_post_call_success_hook, so convert raw
|
||||
# output/error file IDs to managed base64 IDs before the DB write here.
|
||||
managed_files_hook = self.proxy_logging_obj.get_proxy_hook("managed_files")
|
||||
if managed_files_hook is not None:
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
_minimal_auth = UserAPIKeyAuth(
|
||||
user_id=job.created_by or "default-user-id",
|
||||
team_id=getattr(job, "team_id", None),
|
||||
)
|
||||
for _file_attr in ["output_file_id", "error_file_id"]:
|
||||
_raw_file_id = getattr(response, _file_attr, None)
|
||||
if _raw_file_id and not _is_base64_encoded_unified_file_id(_raw_file_id):
|
||||
try:
|
||||
_unified_file_id = managed_files_hook.get_unified_output_file_id(
|
||||
output_file_id=_raw_file_id,
|
||||
model_id=model_id,
|
||||
model_name=str(model_name) if model_name else deployment_info.model_name or None,
|
||||
)
|
||||
await managed_files_hook.store_unified_file_id(
|
||||
file_id=_unified_file_id,
|
||||
file_object=None,
|
||||
litellm_parent_otel_span=None,
|
||||
model_mappings={model_id: _raw_file_id},
|
||||
user_api_key_dict=_minimal_auth,
|
||||
)
|
||||
setattr(response, _file_attr, _unified_file_id)
|
||||
verbose_proxy_logger.info(
|
||||
f"CheckBatchCost: converted {_file_attr} "
|
||||
f"{_raw_file_id!r} -> managed ID for batch {batch_id}"
|
||||
)
|
||||
except Exception as _e:
|
||||
verbose_proxy_logger.warning(
|
||||
f"CheckBatchCost: failed to create managed file ID for "
|
||||
f"{_file_attr}={_raw_file_id!r}: {_e}"
|
||||
)
|
||||
|
||||
# Pass deployment model_info so custom batch pricing
|
||||
# (input_cost_per_token_batches etc.) is used for cost calc
|
||||
deployment_model_info = deployment_info.model_info.model_dump() if deployment_info.model_info else {}
|
||||
|
||||
@ -22,7 +22,9 @@ class TestCheckBatchCost:
|
||||
|
||||
@pytest.fixture
|
||||
def mock_proxy_logging_obj(self):
|
||||
return MagicMock()
|
||||
mock = MagicMock()
|
||||
mock.get_proxy_hook.return_value = None
|
||||
return mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_llm_router(self):
|
||||
@ -372,3 +374,141 @@ class TestCheckBatchCost:
|
||||
update_data["batch_processed"] is True
|
||||
), "update() must include batch_processed=True when column is present"
|
||||
assert update_data["status"] == "complete"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_raw_output_file_id_converted_to_managed_id(
|
||||
self, check_batch_cost_instance, mock_prisma_client, mock_llm_router
|
||||
):
|
||||
"""CheckBatchCost must convert a raw provider output_file_id to a managed base64 ID.
|
||||
|
||||
Without this, GET /batches/{id} returns a raw file ID that cannot be routed
|
||||
through the proxy, causing API_KEY errors when clients call GET /files/{id}/content.
|
||||
"""
|
||||
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
|
||||
return_value=0
|
||||
)
|
||||
mock_prisma_client.db.litellm_managedobjecttable.update = AsyncMock()
|
||||
mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock(
|
||||
return_value=None
|
||||
)
|
||||
|
||||
mock_job = MagicMock()
|
||||
mock_job.id = "job-raw-file-1"
|
||||
mock_job.unified_object_id = "dW5pZmllZF9iYXRjaF9pZA=="
|
||||
mock_job.created_by = "user-1"
|
||||
mock_job.team_id = None
|
||||
|
||||
check_batch_cost_instance._has_batch_processed_column = True
|
||||
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
|
||||
return_value=[mock_job]
|
||||
)
|
||||
|
||||
raw_output_file_id = "file-batch-output-abc123"
|
||||
raw_error_file_id = "file-batch-error-xyz456"
|
||||
fake_managed_output_id = "bGl0ZWxsbV9wcm94eTo6b3V0cHV0"
|
||||
fake_managed_error_id = "bGl0ZWxsbV9wcm94eTo6ZXJyb3I="
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status = "completed"
|
||||
mock_response.output_file_id = raw_output_file_id
|
||||
mock_response.error_file_id = raw_error_file_id
|
||||
mock_response.model_dump_json.return_value = (
|
||||
'{"id":"batch-1","status":"completed"}'
|
||||
)
|
||||
|
||||
mock_llm_router.aretrieve_batch = AsyncMock(return_value=mock_response)
|
||||
mock_llm_router.get_deployment_credentials_with_provider = MagicMock(
|
||||
return_value={"api_key": "sk-test"}
|
||||
)
|
||||
|
||||
mock_deployment = MagicMock()
|
||||
mock_deployment.litellm_params.custom_llm_provider = "azure"
|
||||
mock_deployment.litellm_params.model = "azure/gpt-5-mini"
|
||||
mock_deployment.model_name = "gpt-5-batch"
|
||||
mock_deployment.model_info.model_dump.return_value = {}
|
||||
mock_llm_router.get_deployment = MagicMock(return_value=mock_deployment)
|
||||
|
||||
mock_hook = MagicMock()
|
||||
mock_hook.get_unified_output_file_id.side_effect = [
|
||||
fake_managed_output_id,
|
||||
fake_managed_error_id,
|
||||
]
|
||||
mock_hook.store_unified_file_id = AsyncMock()
|
||||
check_batch_cost_instance.proxy_logging_obj.get_proxy_hook.return_value = (
|
||||
mock_hook
|
||||
)
|
||||
|
||||
mock_file_content = MagicMock()
|
||||
mock_file_content.content = b'{"id":"req-1"}'
|
||||
decoded_id = "llm_model_id,model-123;llm_batch_id,batch-456;"
|
||||
|
||||
with (
|
||||
patch(
|
||||
"litellm.proxy.openai_files_endpoints.common_utils._is_base64_encoded_unified_file_id",
|
||||
# call 1: job unified_object_id decode, call 2: existing raw check for output_file_id,
|
||||
# call 3: fix guard for output_file_id, call 4: fix guard for error_file_id
|
||||
side_effect=[decoded_id, None, None, None],
|
||||
),
|
||||
patch(
|
||||
"litellm.proxy.openai_files_endpoints.common_utils.get_model_id_from_unified_batch_id",
|
||||
return_value="model-123",
|
||||
),
|
||||
patch(
|
||||
"litellm.proxy.openai_files_endpoints.common_utils.get_batch_id_from_unified_batch_id",
|
||||
return_value="batch-456",
|
||||
),
|
||||
patch(
|
||||
"litellm.files.main.afile_content",
|
||||
new_callable=AsyncMock,
|
||||
return_value=mock_file_content,
|
||||
),
|
||||
patch(
|
||||
"litellm.batches.batch_utils._get_file_content_as_dictionary",
|
||||
return_value=[{"id": "req-1"}],
|
||||
),
|
||||
patch(
|
||||
"litellm.batches.batch_utils.calculate_batch_cost_and_usage",
|
||||
new_callable=AsyncMock,
|
||||
return_value=(
|
||||
0.01,
|
||||
{"prompt_tokens": 10, "completion_tokens": 5},
|
||||
["gpt-4"],
|
||||
),
|
||||
),
|
||||
patch(
|
||||
"litellm.litellm_core_utils.get_llm_provider_logic.get_llm_provider",
|
||||
return_value=("gpt-5-mini", "azure", None, None),
|
||||
),
|
||||
patch(
|
||||
"litellm.litellm_core_utils.litellm_logging.Logging"
|
||||
) as mock_logging_cls,
|
||||
):
|
||||
mock_logging_obj = MagicMock()
|
||||
mock_logging_obj.async_success_handler = AsyncMock()
|
||||
mock_logging_cls.return_value = mock_logging_obj
|
||||
|
||||
await check_batch_cost_instance.check_batch_cost()
|
||||
|
||||
assert mock_hook.get_unified_output_file_id.call_count == 2
|
||||
mock_hook.get_unified_output_file_id.assert_any_call(
|
||||
output_file_id=raw_output_file_id,
|
||||
model_id="model-123",
|
||||
model_name="gpt-5-mini",
|
||||
)
|
||||
mock_hook.get_unified_output_file_id.assert_any_call(
|
||||
output_file_id=raw_error_file_id,
|
||||
model_id="model-123",
|
||||
model_name="gpt-5-mini",
|
||||
)
|
||||
assert mock_hook.store_unified_file_id.await_count == 2
|
||||
# {raw_file_id: managed_file_id} for each store call
|
||||
stored = {
|
||||
next(iter(c[1]["model_mappings"].values())): c[1]["file_id"]
|
||||
for c in mock_hook.store_unified_file_id.call_args_list
|
||||
}
|
||||
assert stored == {
|
||||
raw_output_file_id: fake_managed_output_id,
|
||||
raw_error_file_id: fake_managed_error_id,
|
||||
}
|
||||
assert mock_response.output_file_id == fake_managed_output_id
|
||||
assert mock_response.error_file_id == fake_managed_error_id
|
||||
|
||||
@ -184,6 +184,7 @@ async def test_check_batch_cost_should_call_afile_content_directly_with_credenti
|
||||
mock_job.unified_object_id = unified_object_id
|
||||
mock_job.created_by = "user-A"
|
||||
mock_job.id = "job-1"
|
||||
mock_job.team_id = None
|
||||
|
||||
# Mock prisma
|
||||
mock_prisma = MagicMock()
|
||||
@ -196,6 +197,10 @@ async def test_check_batch_cost_should_call_afile_content_directly_with_credenti
|
||||
mock_proxy_logging = MagicMock()
|
||||
mock_managed_files_hook = MagicMock()
|
||||
mock_managed_files_hook.afile_content = AsyncMock()
|
||||
mock_managed_files_hook.store_unified_file_id = AsyncMock()
|
||||
mock_managed_files_hook.get_unified_output_file_id.return_value = (
|
||||
"bGl0ZWxsbV9wcm94eTo6bWFuYWdlZA=="
|
||||
)
|
||||
mock_proxy_logging.get_proxy_hook = MagicMock(return_value=mock_managed_files_hook)
|
||||
|
||||
# Mock the batch response (completed, with output file)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user