From cbdc70d5442e6ccb9f91c011feef9d58be6522dc Mon Sep 17 00:00:00 2001 From: Sameer Kankute Date: Fri, 15 May 2026 17:11:38 +0530 Subject: [PATCH] fix(managed_batches): convert raw output_file_id to managed ID in CheckBatchCost poller (#27984) * fix(managed_batches): convert raw output_file_id to managed ID in CheckBatchCost poller CheckBatchCost bypasses async_post_call_success_hook, causing raw provider output_file_ids to be persisted in LiteLLM_ManagedObjectTable. This fix converts output_file_id and error_file_id to managed base64 IDs before the DB write. Co-authored-by: Cursor * fix(check_batch_cost): persist managed file before mutating response and propagate team_id - Move setattr after store_unified_file_id so the response only receives the managed ID once the DB record is successfully written. Avoids serializing an orphaned managed ID into file_object when the store call fails. - Populate team_id on the minimal UserAPIKeyAuth from job.team_id so the managed file record is created with the correct team ownership, allowing other team members to access the batch output file via /files/{id}/content. Co-authored-by: Yassin Kortam * test(managed_batches): extend test to cover error_file_id conversion Co-authored-by: Cursor * fix managed file test --------- Co-authored-by: Cursor Co-authored-by: Yassin Kortam --- .../proxy/common_utils/check_batch_cost.py | 36 +++++ .../proxy_unit_tests/test_check_batch_cost.py | 142 +++++++++++++++++- .../proxy/test_managed_files_access_check.py | 5 + 3 files changed, 182 insertions(+), 1 deletion(-) diff --git a/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py b/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py index 356f6ecd4b..ee7745d0ad 100644 --- a/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py +++ b/enterprise/litellm_enterprise/proxy/common_utils/check_batch_cost.py @@ -300,6 +300,42 @@ class CheckBatchCost: custom_llm_provider=custom_llm_provider, ) + # CheckBatchCost bypasses async_post_call_success_hook, so convert raw + # output/error file IDs to managed base64 IDs before the DB write here. + managed_files_hook = self.proxy_logging_obj.get_proxy_hook("managed_files") + if managed_files_hook is not None: + from litellm.proxy._types import UserAPIKeyAuth + _minimal_auth = UserAPIKeyAuth( + user_id=job.created_by or "default-user-id", + team_id=getattr(job, "team_id", None), + ) + for _file_attr in ["output_file_id", "error_file_id"]: + _raw_file_id = getattr(response, _file_attr, None) + if _raw_file_id and not _is_base64_encoded_unified_file_id(_raw_file_id): + try: + _unified_file_id = managed_files_hook.get_unified_output_file_id( + output_file_id=_raw_file_id, + model_id=model_id, + model_name=str(model_name) if model_name else deployment_info.model_name or None, + ) + await managed_files_hook.store_unified_file_id( + file_id=_unified_file_id, + file_object=None, + litellm_parent_otel_span=None, + model_mappings={model_id: _raw_file_id}, + user_api_key_dict=_minimal_auth, + ) + setattr(response, _file_attr, _unified_file_id) + verbose_proxy_logger.info( + f"CheckBatchCost: converted {_file_attr} " + f"{_raw_file_id!r} -> managed ID for batch {batch_id}" + ) + except Exception as _e: + verbose_proxy_logger.warning( + f"CheckBatchCost: failed to create managed file ID for " + f"{_file_attr}={_raw_file_id!r}: {_e}" + ) + # Pass deployment model_info so custom batch pricing # (input_cost_per_token_batches etc.) is used for cost calc deployment_model_info = deployment_info.model_info.model_dump() if deployment_info.model_info else {} diff --git a/tests/proxy_unit_tests/test_check_batch_cost.py b/tests/proxy_unit_tests/test_check_batch_cost.py index 8b4ce1e382..e8acaf6fea 100644 --- a/tests/proxy_unit_tests/test_check_batch_cost.py +++ b/tests/proxy_unit_tests/test_check_batch_cost.py @@ -22,7 +22,9 @@ class TestCheckBatchCost: @pytest.fixture def mock_proxy_logging_obj(self): - return MagicMock() + mock = MagicMock() + mock.get_proxy_hook.return_value = None + return mock @pytest.fixture def mock_llm_router(self): @@ -372,3 +374,141 @@ class TestCheckBatchCost: update_data["batch_processed"] is True ), "update() must include batch_processed=True when column is present" assert update_data["status"] == "complete" + + @pytest.mark.asyncio + async def test_raw_output_file_id_converted_to_managed_id( + self, check_batch_cost_instance, mock_prisma_client, mock_llm_router + ): + """CheckBatchCost must convert a raw provider output_file_id to a managed base64 ID. + + Without this, GET /batches/{id} returns a raw file ID that cannot be routed + through the proxy, causing API_KEY errors when clients call GET /files/{id}/content. + """ + mock_prisma_client.db.litellm_managedobjecttable.update_many = AsyncMock( + return_value=0 + ) + mock_prisma_client.db.litellm_managedobjecttable.update = AsyncMock() + mock_prisma_client.db.litellm_usertable.find_unique = AsyncMock( + return_value=None + ) + + mock_job = MagicMock() + mock_job.id = "job-raw-file-1" + mock_job.unified_object_id = "dW5pZmllZF9iYXRjaF9pZA==" + mock_job.created_by = "user-1" + mock_job.team_id = None + + check_batch_cost_instance._has_batch_processed_column = True + mock_prisma_client.db.litellm_managedobjecttable.find_many = AsyncMock( + return_value=[mock_job] + ) + + raw_output_file_id = "file-batch-output-abc123" + raw_error_file_id = "file-batch-error-xyz456" + fake_managed_output_id = "bGl0ZWxsbV9wcm94eTo6b3V0cHV0" + fake_managed_error_id = "bGl0ZWxsbV9wcm94eTo6ZXJyb3I=" + + mock_response = MagicMock() + mock_response.status = "completed" + mock_response.output_file_id = raw_output_file_id + mock_response.error_file_id = raw_error_file_id + mock_response.model_dump_json.return_value = ( + '{"id":"batch-1","status":"completed"}' + ) + + mock_llm_router.aretrieve_batch = AsyncMock(return_value=mock_response) + mock_llm_router.get_deployment_credentials_with_provider = MagicMock( + return_value={"api_key": "sk-test"} + ) + + mock_deployment = MagicMock() + mock_deployment.litellm_params.custom_llm_provider = "azure" + mock_deployment.litellm_params.model = "azure/gpt-5-mini" + mock_deployment.model_name = "gpt-5-batch" + mock_deployment.model_info.model_dump.return_value = {} + mock_llm_router.get_deployment = MagicMock(return_value=mock_deployment) + + mock_hook = MagicMock() + mock_hook.get_unified_output_file_id.side_effect = [ + fake_managed_output_id, + fake_managed_error_id, + ] + mock_hook.store_unified_file_id = AsyncMock() + check_batch_cost_instance.proxy_logging_obj.get_proxy_hook.return_value = ( + mock_hook + ) + + mock_file_content = MagicMock() + mock_file_content.content = b'{"id":"req-1"}' + decoded_id = "llm_model_id,model-123;llm_batch_id,batch-456;" + + with ( + patch( + "litellm.proxy.openai_files_endpoints.common_utils._is_base64_encoded_unified_file_id", + # call 1: job unified_object_id decode, call 2: existing raw check for output_file_id, + # call 3: fix guard for output_file_id, call 4: fix guard for error_file_id + side_effect=[decoded_id, None, None, None], + ), + patch( + "litellm.proxy.openai_files_endpoints.common_utils.get_model_id_from_unified_batch_id", + return_value="model-123", + ), + patch( + "litellm.proxy.openai_files_endpoints.common_utils.get_batch_id_from_unified_batch_id", + return_value="batch-456", + ), + patch( + "litellm.files.main.afile_content", + new_callable=AsyncMock, + return_value=mock_file_content, + ), + patch( + "litellm.batches.batch_utils._get_file_content_as_dictionary", + return_value=[{"id": "req-1"}], + ), + patch( + "litellm.batches.batch_utils.calculate_batch_cost_and_usage", + new_callable=AsyncMock, + return_value=( + 0.01, + {"prompt_tokens": 10, "completion_tokens": 5}, + ["gpt-4"], + ), + ), + patch( + "litellm.litellm_core_utils.get_llm_provider_logic.get_llm_provider", + return_value=("gpt-5-mini", "azure", None, None), + ), + patch( + "litellm.litellm_core_utils.litellm_logging.Logging" + ) as mock_logging_cls, + ): + mock_logging_obj = MagicMock() + mock_logging_obj.async_success_handler = AsyncMock() + mock_logging_cls.return_value = mock_logging_obj + + await check_batch_cost_instance.check_batch_cost() + + assert mock_hook.get_unified_output_file_id.call_count == 2 + mock_hook.get_unified_output_file_id.assert_any_call( + output_file_id=raw_output_file_id, + model_id="model-123", + model_name="gpt-5-mini", + ) + mock_hook.get_unified_output_file_id.assert_any_call( + output_file_id=raw_error_file_id, + model_id="model-123", + model_name="gpt-5-mini", + ) + assert mock_hook.store_unified_file_id.await_count == 2 + # {raw_file_id: managed_file_id} for each store call + stored = { + next(iter(c[1]["model_mappings"].values())): c[1]["file_id"] + for c in mock_hook.store_unified_file_id.call_args_list + } + assert stored == { + raw_output_file_id: fake_managed_output_id, + raw_error_file_id: fake_managed_error_id, + } + assert mock_response.output_file_id == fake_managed_output_id + assert mock_response.error_file_id == fake_managed_error_id diff --git a/tests/test_litellm/enterprise/proxy/test_managed_files_access_check.py b/tests/test_litellm/enterprise/proxy/test_managed_files_access_check.py index b87c833531..d9a0b27539 100644 --- a/tests/test_litellm/enterprise/proxy/test_managed_files_access_check.py +++ b/tests/test_litellm/enterprise/proxy/test_managed_files_access_check.py @@ -184,6 +184,7 @@ async def test_check_batch_cost_should_call_afile_content_directly_with_credenti mock_job.unified_object_id = unified_object_id mock_job.created_by = "user-A" mock_job.id = "job-1" + mock_job.team_id = None # Mock prisma mock_prisma = MagicMock() @@ -196,6 +197,10 @@ async def test_check_batch_cost_should_call_afile_content_directly_with_credenti mock_proxy_logging = MagicMock() mock_managed_files_hook = MagicMock() mock_managed_files_hook.afile_content = AsyncMock() + mock_managed_files_hook.store_unified_file_id = AsyncMock() + mock_managed_files_hook.get_unified_output_file_id.return_value = ( + "bGl0ZWxsbV9wcm94eTo6bWFuYWdlZA==" + ) mock_proxy_logging.get_proxy_hook = MagicMock(return_value=mock_managed_files_hook) # Mock the batch response (completed, with output file)