fix(managed_batches): convert raw output_file_id to managed ID in CheckBatchCost poller (#27984)

* fix(managed_batches): convert raw output_file_id to managed ID in CheckBatchCost poller

CheckBatchCost bypasses async_post_call_success_hook, causing raw provider
output_file_ids to be persisted in LiteLLM_ManagedObjectTable. This fix converts
output_file_id and error_file_id to managed base64 IDs before the DB write.

Co-authored-by: Cursor <cursoragent@cursor.com>

* fix(check_batch_cost): persist managed file before mutating response and propagate team_id

- Move setattr after store_unified_file_id so the response only receives the
  managed ID once the DB record is successfully written. Avoids serializing
  an orphaned managed ID into file_object when the store call fails.
- Populate team_id on the minimal UserAPIKeyAuth from job.team_id so the
  managed file record is created with the correct team ownership, allowing
  other team members to access the batch output file via /files/{id}/content.

Co-authored-by: Yassin Kortam <yassin@berri.ai>

* test(managed_batches): extend test to cover error_file_id conversion

Co-authored-by: Cursor <cursoragent@cursor.com>

* fix managed file test

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Yassin Kortam <yassin@berri.ai>
This commit is contained in:
Sameer Kankute 2026-05-15 17:11:38 +05:30 committed by GitHub
parent fe755ee02a
commit cbdc70d544
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 182 additions and 1 deletions

View File

@ -300,6 +300,42 @@ class CheckBatchCost:
custom_llm_provider=custom_llm_provider,
)
# CheckBatchCost bypasses async_post_call_success_hook, so convert raw
# output/error file IDs to managed base64 IDs before the DB write here.
managed_files_hook = self.proxy_logging_obj.get_proxy_hook("managed_files")
if managed_files_hook is not None:
from litellm.proxy._types import UserAPIKeyAuth
_minimal_auth = UserAPIKeyAuth(
user_id=job.created_by or "default-user-id",
team_id=getattr(job, "team_id", None),
)
for _file_attr in ["output_file_id", "error_file_id"]:
_raw_file_id = getattr(response, _file_attr, None)
if _raw_file_id and not _is_base64_encoded_unified_file_id(_raw_file_id):
try:
_unified_file_id = managed_files_hook.get_unified_output_file_id(
output_file_id=_raw_file_id,
model_id=model_id,
model_name=str(model_name) if model_name else deployment_info.model_name or None,
)
await managed_files_hook.store_unified_file_id(
file_id=_unified_file_id,
file_object=None,
litellm_parent_otel_span=None,
model_mappings={model_id: _raw_file_id},
user_api_key_dict=_minimal_auth,
)
setattr(response, _file_attr, _unified_file_id)
verbose_proxy_logger.info(
f"CheckBatchCost: converted {_file_attr} "
f"{_raw_file_id!r} -> managed ID for batch {batch_id}"
)
except Exception as _e:
verbose_proxy_logger.warning(
f"CheckBatchCost: failed to create managed file ID for "
f"{_file_attr}={_raw_file_id!r}: {_e}"
)
# Pass deployment model_info so custom batch pricing
# (input_cost_per_token_batches etc.) is used for cost calc
deployment_model_info = deployment_info.model_info.model_dump() if deployment_info.model_info else {}

View File

@ -22,7 +22,9 @@ class TestCheckBatchCost:
@pytest.fixture
def mock_proxy_logging_obj(self):
return MagicMock()
mock = MagicMock()
mock.get_proxy_hook.return_value = None
return mock
@pytest.fixture
def mock_llm_router(self):
@ -372,3 +374,141 @@ class TestCheckBatchCost:
update_data["batch_processed"] is True
), "update() must include batch_processed=True when column is present"
assert update_data["status"] == "complete"
@pytest.mark.asyncio
async def test_raw_output_file_id_converted_to_managed_id(
self, check_batch_cost_instance, mock_prisma_client, mock_llm_router
):
"""CheckBatchCost must convert a raw provider output_file_id to a managed base64 ID.
Without this, GET /batches/{id} returns a raw file ID that cannot be routed
through the proxy, causing API_KEY errors when clients call GET /files/{id}/content.
"""
mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock(
return_value=0
)
mock_prisma_client.db.litellm_managedobjecttable.update = AsyncMock()
mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock(
return_value=None
)
mock_job = MagicMock()
mock_job.id = "job-raw-file-1"
mock_job.unified_object_id = "dW5pZmllZF9iYXRjaF9pZA=="
mock_job.created_by = "user-1"
mock_job.team_id = None
check_batch_cost_instance._has_batch_processed_column = True
mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock(
return_value=[mock_job]
)
raw_output_file_id = "file-batch-output-abc123"
raw_error_file_id = "file-batch-error-xyz456"
fake_managed_output_id = "bGl0ZWxsbV9wcm94eTo6b3V0cHV0"
fake_managed_error_id = "bGl0ZWxsbV9wcm94eTo6ZXJyb3I="
mock_response = MagicMock()
mock_response.status = "completed"
mock_response.output_file_id = raw_output_file_id
mock_response.error_file_id = raw_error_file_id
mock_response.model_dump_json.return_value = (
'{"id":"batch-1","status":"completed"}'
)
mock_llm_router.aretrieve_batch = AsyncMock(return_value=mock_response)
mock_llm_router.get_deployment_credentials_with_provider = MagicMock(
return_value={"api_key": "sk-test"}
)
mock_deployment = MagicMock()
mock_deployment.litellm_params.custom_llm_provider = "azure"
mock_deployment.litellm_params.model = "azure/gpt-5-mini"
mock_deployment.model_name = "gpt-5-batch"
mock_deployment.model_info.model_dump.return_value = {}
mock_llm_router.get_deployment = MagicMock(return_value=mock_deployment)
mock_hook = MagicMock()
mock_hook.get_unified_output_file_id.side_effect = [
fake_managed_output_id,
fake_managed_error_id,
]
mock_hook.store_unified_file_id = AsyncMock()
check_batch_cost_instance.proxy_logging_obj.get_proxy_hook.return_value = (
mock_hook
)
mock_file_content = MagicMock()
mock_file_content.content = b'{"id":"req-1"}'
decoded_id = "llm_model_id,model-123;llm_batch_id,batch-456;"
with (
patch(
"litellm.proxy.openai_files_endpoints.common_utils._is_base64_encoded_unified_file_id",
# call 1: job unified_object_id decode, call 2: existing raw check for output_file_id,
# call 3: fix guard for output_file_id, call 4: fix guard for error_file_id
side_effect=[decoded_id, None, None, None],
),
patch(
"litellm.proxy.openai_files_endpoints.common_utils.get_model_id_from_unified_batch_id",
return_value="model-123",
),
patch(
"litellm.proxy.openai_files_endpoints.common_utils.get_batch_id_from_unified_batch_id",
return_value="batch-456",
),
patch(
"litellm.files.main.afile_content",
new_callable=AsyncMock,
return_value=mock_file_content,
),
patch(
"litellm.batches.batch_utils._get_file_content_as_dictionary",
return_value=[{"id": "req-1"}],
),
patch(
"litellm.batches.batch_utils.calculate_batch_cost_and_usage",
new_callable=AsyncMock,
return_value=(
0.01,
{"prompt_tokens": 10, "completion_tokens": 5},
["gpt-4"],
),
),
patch(
"litellm.litellm_core_utils.get_llm_provider_logic.get_llm_provider",
return_value=("gpt-5-mini", "azure", None, None),
),
patch(
"litellm.litellm_core_utils.litellm_logging.Logging"
) as mock_logging_cls,
):
mock_logging_obj = MagicMock()
mock_logging_obj.async_success_handler = AsyncMock()
mock_logging_cls.return_value = mock_logging_obj
await check_batch_cost_instance.check_batch_cost()
assert mock_hook.get_unified_output_file_id.call_count == 2
mock_hook.get_unified_output_file_id.assert_any_call(
output_file_id=raw_output_file_id,
model_id="model-123",
model_name="gpt-5-mini",
)
mock_hook.get_unified_output_file_id.assert_any_call(
output_file_id=raw_error_file_id,
model_id="model-123",
model_name="gpt-5-mini",
)
assert mock_hook.store_unified_file_id.await_count == 2
# {raw_file_id: managed_file_id} for each store call
stored = {
next(iter(c[1]["model_mappings"].values())): c[1]["file_id"]
for c in mock_hook.store_unified_file_id.call_args_list
}
assert stored == {
raw_output_file_id: fake_managed_output_id,
raw_error_file_id: fake_managed_error_id,
}
assert mock_response.output_file_id == fake_managed_output_id
assert mock_response.error_file_id == fake_managed_error_id

View File

@ -184,6 +184,7 @@ async def test_check_batch_cost_should_call_afile_content_directly_with_credenti
mock_job.unified_object_id = unified_object_id
mock_job.created_by = "user-A"
mock_job.id = "job-1"
mock_job.team_id = None
# Mock prisma
mock_prisma = MagicMock()
@ -196,6 +197,10 @@ async def test_check_batch_cost_should_call_afile_content_directly_with_credenti
mock_proxy_logging = MagicMock()
mock_managed_files_hook = MagicMock()
mock_managed_files_hook.afile_content = AsyncMock()
mock_managed_files_hook.store_unified_file_id = AsyncMock()
mock_managed_files_hook.get_unified_output_file_id.return_value = (
"bGl0ZWxsbV9wcm94eTo6bWFuYWdlZA=="
)
mock_proxy_logging.get_proxy_hook = MagicMock(return_value=mock_managed_files_hook)
# Mock the batch response (completed, with output file)