diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index ef1d64335b..a5865e71c2 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -1368,6 +1368,21 @@ class ProxyBaseLLMRequestProcessing: user_api_key_dict=user_api_key_dict, request_data=self.data, ) + if route_type == "aresponses": + # Streaming /v1/responses returns here without + # reaching the non-streaming ownership tail below. + # Wrap the SSE generator so container ownership is + # written once the upstream iterator finishes + # assembling ``completed_response`` — otherwise + # code-interpreter containers created during the + # stream stay unregistered and follow-up file API + # calls 403. Covers the background-polling path + # too, which loops ``body_iterator`` end-to-end. + selected_data_generator = ProxyBaseLLMRequestProcessing._wrap_responses_stream_for_container_ownership( + original_stream_response=response, + wrapped_generator=selected_data_generator, + user_api_key_dict=user_api_key_dict, + ) return await create_response( generator=selected_data_generator, media_type="text/event-stream", @@ -1483,8 +1498,93 @@ class ProxyBaseLLMRequestProcessing: await check_response_size_is_safe(response=response) + if route_type in {"aresponses", "aget_responses"}: + await ProxyBaseLLMRequestProcessing._record_container_owners_from_responses_if_needed( + response=response, + user_api_key_dict=user_api_key_dict, + ) + return response + @staticmethod + async def _record_container_owners_from_responses_if_needed( + response: Any, + user_api_key_dict: UserAPIKeyAuth, + ) -> None: + """Register code-interpreter containers so follow-up file APIs pass ownership checks.""" + from litellm.proxy.container_endpoints.ownership import ( + record_container_owners_from_responses_response, + ) + + if response is None: + return + + try: + await record_container_owners_from_responses_response( + response=response, + user_api_key_dict=user_api_key_dict, + ) + except Exception as e: + verbose_proxy_logger.exception( + "Container ownership recording failed after responses call: %s", + e, + ) + + @staticmethod + def _extract_completed_responses_response(stream_response: Any) -> Any: + """Pull the assembled ``ResponsesAPIResponse`` off a streaming iterator. + + ``ResponsesAPIStreamingIterator`` stores the terminal stream event + (``response.completed`` / ``response.incomplete`` / ``response.failed``) + in ``completed_response``; the actual response body hangs off + that event's ``.response`` attribute. Some iterators store the + ``ResponsesAPIResponse`` directly. Handle both shapes so the + container-ownership recording path can walk ``.output`` either way. + """ + completed = getattr(stream_response, "completed_response", None) + if completed is None: + return None + response_obj = getattr(completed, "response", None) + if response_obj is not None: + return response_obj + return completed + + @staticmethod + async def _wrap_responses_stream_for_container_ownership( + original_stream_response: Any, + wrapped_generator: Any, + user_api_key_dict: UserAPIKeyAuth, + ): + """Forward SSE chunks, then record container ownership at stream end. + + Streaming ``/v1/responses`` short-circuits out of + ``base_process_llm_request`` before the non-streaming ownership + tail runs, so without this wrap the + ``LiteLLM_ManagedObjectTable`` row for any container created + during the stream is never written and follow-up file API calls + return 403. + """ + try: + async for chunk in wrapped_generator: + yield chunk + finally: + try: + completed_obj = ( + ProxyBaseLLMRequestProcessing._extract_completed_responses_response( + original_stream_response + ) + ) + if completed_obj is not None: + await ProxyBaseLLMRequestProcessing._record_container_owners_from_responses_if_needed( + response=completed_obj, + user_api_key_dict=user_api_key_dict, + ) + except Exception as e: + verbose_proxy_logger.exception( + "Container ownership recording failed after streaming responses call: %s", + e, + ) + async def base_passthrough_process_llm_request( self, request: Request, diff --git a/litellm/proxy/container_endpoints/ownership.py b/litellm/proxy/container_endpoints/ownership.py index 57de6c4a63..e0015e112e 100644 --- a/litellm/proxy/container_endpoints/ownership.py +++ b/litellm/proxy/container_endpoints/ownership.py @@ -117,6 +117,58 @@ async def _get_prisma_client(): return prisma_client +def _custom_llm_provider_from_responses_response( + response: Any, + default: str = "openai", +) -> str: + hidden_params: Dict[str, Any] = {} + if isinstance(response, dict): + hidden_params = response.get("_hidden_params") or {} + else: + hidden_params = getattr(response, "_hidden_params", None) or {} + + provider = hidden_params.get("custom_llm_provider") + if isinstance(provider, str) and provider: + return provider + return default + + +async def record_container_owners_from_responses_response( + response: Any, + user_api_key_dict: UserAPIKeyAuth, + custom_llm_provider: Optional[str] = None, +) -> None: + """Track containers created implicitly by code interpreter in /v1/responses.""" + container_ids = ( + ResponsesAPIRequestUtils.collect_container_ids_from_responses_response(response) + ) + if not container_ids: + return + + resolved_provider = ( + custom_llm_provider or _custom_llm_provider_from_responses_response(response) + ) + + for container_id in container_ids: + try: + await record_container_owner( + response={"id": container_id, "object": "container"}, + user_api_key_dict=user_api_key_dict, + custom_llm_provider=resolved_provider, + ) + except Exception as e: + # Per-container errors (including ``HTTPException`` from + # conflicting/forbidden ownership rows) must not abort the + # batch — other containers in the same response should still + # get recorded so their follow-up file API calls don't 403. + verbose_proxy_logger.exception( + "Failed to record container ownership from responses output " + "for container_id=%s: %s", + container_id, + e, + ) + + async def record_container_owner( response: Any, user_api_key_dict: UserAPIKeyAuth, @@ -151,6 +203,8 @@ async def record_container_owner( file_object = _dump_response(response) file_object["custom_llm_provider"] = resolved_provider file_object["provider_container_id"] = original_container_id + # Prisma Python requires Json fields to be serialized as a JSON string. + file_object_json: str = json.dumps(file_object) prisma_client = await _get_prisma_client() if prisma_client is None: @@ -172,7 +226,7 @@ async def record_container_owner( where={"model_object_id": model_object_id}, data={ "unified_object_id": container_id, - "file_object": file_object, + "file_object": file_object_json, "updated_by": owner, }, ) @@ -181,7 +235,7 @@ async def record_container_owner( data={ "unified_object_id": container_id, "model_object_id": model_object_id, - "file_object": file_object, + "file_object": file_object_json, "file_purpose": CONTAINER_OBJECT_PURPOSE, "created_by": owner, "updated_by": owner, diff --git a/litellm/responses/utils.py b/litellm/responses/utils.py index 74e4d7a533..46a2894bd1 100644 --- a/litellm/responses/utils.py +++ b/litellm/responses/utils.py @@ -738,6 +738,98 @@ class ResponsesAPIRequestUtils: model_id, ) + @staticmethod + def _collect_container_ids_from_annotations( + annotations: Any, + collected: set[str], + ) -> None: + if not annotations or not isinstance(annotations, list): + return + for ann in annotations: + ResponsesAPIRequestUtils._collect_container_ids_from_output_item( + ann, collected + ) + + @staticmethod + def _collect_container_ids_from_message_content( + content: Any, + collected: set[str], + ) -> None: + if not content: + return + if isinstance(content, list): + for part in content: + if isinstance(part, dict): + ResponsesAPIRequestUtils._collect_container_ids_from_annotations( + part.get("annotations"), + collected, + ) + else: + ResponsesAPIRequestUtils._collect_container_ids_from_annotations( + getattr(part, "annotations", None), + collected, + ) + + @staticmethod + def _collect_container_ids_from_output_item( + item: Any, + collected: set[str], + ) -> None: + """Collect managed or raw ``container_id`` values from one output item.""" + if item is None: + return + + if isinstance(item, dict): + cid = item.get("container_id") + if isinstance(cid, str) and cid: + collected.add(cid) + nested = item.get("code_interpreter_call") + if isinstance(nested, dict): + nc = nested.get("container_id") + if isinstance(nc, str) and nc: + collected.add(nc) + if item.get("type") == "message": + ResponsesAPIRequestUtils._collect_container_ids_from_message_content( + item.get("content"), + collected, + ) + return + + cid_attr = getattr(item, "container_id", None) + if isinstance(cid_attr, str) and cid_attr: + collected.add(cid_attr) + + nested_obj = getattr(item, "code_interpreter_call", None) + if nested_obj is not None: + ResponsesAPIRequestUtils._collect_container_ids_from_output_item( + nested_obj, collected + ) + + if getattr(item, "type", None) == "message": + ResponsesAPIRequestUtils._collect_container_ids_from_message_content( + getattr(item, "content", None), + collected, + ) + + @staticmethod + def collect_container_ids_from_responses_response(response: Any) -> list[str]: + """Return unique container IDs referenced in a Responses API payload.""" + if response is None: + return [] + + if isinstance(response, dict): + output = response.get("output", []) + else: + output = getattr(response, "output", []) or [] + + collected: set[str] = set() + if output: + for item in output: + ResponsesAPIRequestUtils._collect_container_ids_from_output_item( + item, collected + ) + return list(collected) + @staticmethod def _update_container_ids_in_response( responses_api_response: Union[ResponsesAPIResponse, Dict[str, Any]], diff --git a/tests/test_litellm/containers/test_container_proxy_ownership.py b/tests/test_litellm/containers/test_container_proxy_ownership.py index c295805bdb..176405bb9c 100644 --- a/tests/test_litellm/containers/test_container_proxy_ownership.py +++ b/tests/test_litellm/containers/test_container_proxy_ownership.py @@ -1,3 +1,4 @@ +import json import sys from types import SimpleNamespace from unittest.mock import AsyncMock @@ -91,8 +92,9 @@ async def test_should_not_mutate_dict_container_response_when_recording_owner( assert returned == {"id": "cntr_provider", "object": "container"} data = table.create.await_args.kwargs["data"] - assert data["file_object"]["custom_llm_provider"] == "openai" - assert data["file_object"]["provider_container_id"] == "cntr_provider" + file_obj = json.loads(data["file_object"]) + assert file_obj["custom_llm_provider"] == "openai" + assert file_obj["provider_container_id"] == "cntr_provider" @pytest.mark.asyncio @@ -913,3 +915,195 @@ async def test_admin_with_identity_records_container_ownership(monkeypatch): table.create.assert_awaited_once() created_data = table.create.await_args.kwargs["data"] assert created_data["created_by"] == "proxy-admin" + + +@pytest.mark.asyncio +async def test_should_record_containers_from_responses_output_for_service_account( + monkeypatch, +): + table = AsyncMock() + table.find_unique.return_value = None + prisma_client = SimpleNamespace( + db=SimpleNamespace(litellm_managedobjecttable=table) + ) + monkeypatch.setattr( + ownership, + "_get_prisma_client", + AsyncMock(return_value=prisma_client), + ) + auth = UserAPIKeyAuth(team_id="team-1") + encoded_container_id = ( + "cntr_bGl0ZWxsbTpjdXN0b21fbGxtX3Byb3ZpZGVyOmF6dXJlO21vZGVsX2lkOmR" + "lZi0xMjM7Y29udGFpbmVyX2lkOmNudHJfbmF0aXZl" + ) + responses_payload = { + "output": [ + { + "type": "message", + "content": [ + { + "type": "output_text", + "annotations": [ + { + "type": "container_file_citation", + "container_id": encoded_container_id, + "file_id": "cfile_abc", + } + ], + } + ], + } + ], + "_hidden_params": {"custom_llm_provider": "azure"}, + } + + await ownership.record_container_owners_from_responses_response( + response=responses_payload, + user_api_key_dict=auth, + ) + + table.create.assert_awaited_once() + created_data = table.create.await_args.kwargs["data"] + assert created_data["created_by"] == "team:team-1" + assert created_data["unified_object_id"] == encoded_container_id + + +@pytest.mark.asyncio +async def test_service_account_can_access_container_after_responses_tracking( + monkeypatch, +): + encoded_container_id = ( + "cntr_bGl0ZWxsbTpjdXN0b21fbGxtX3Byb3ZpZGVyOmF6dXJlO21vZGVsX2lkOmR" + "lZi0xMjM7Y29udGFpbmVyX2lkOmNudHJfbmF0aXZl" + ) + table = AsyncMock() + table.find_unique.return_value = None + prisma_client = SimpleNamespace( + db=SimpleNamespace(litellm_managedobjecttable=table) + ) + monkeypatch.setattr( + ownership, + "_get_prisma_client", + AsyncMock(return_value=prisma_client), + ) + auth = UserAPIKeyAuth(team_id="team-1") + + await ownership.record_container_owners_from_responses_response( + response={ + "output": [ + { + "type": "code_interpreter_call", + "container_id": encoded_container_id, + } + ], + "_hidden_params": {"custom_llm_provider": "azure"}, + }, + user_api_key_dict=auth, + ) + + original_id, provider = await ownership.assert_user_can_access_container( + container_id=encoded_container_id, + user_api_key_dict=auth, + custom_llm_provider="azure", + ) + assert original_id == "cntr_native" + assert provider == "azure" + + +@pytest.mark.asyncio +async def test_should_record_container_ownership_after_streaming_responses_finish( + monkeypatch, +): + """Streaming /v1/responses calls return through the + ``select_data_generator`` branch and never reach the non-streaming + container-ownership tail. The wrapper must read + ``completed_response`` off the upstream iterator once iteration + finishes and write the row, otherwise code-interpreter containers + created during the stream stay unregistered and follow-up file API + calls 403. + """ + from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing + + encoded_container_id = ( + "cntr_bGl0ZWxsbTpjdXN0b21fbGxtX3Byb3ZpZGVyOmF6dXJlO21vZGVsX2lkOmR" + "lZi0xMjM7Y29udGFpbmVyX2lkOmNudHJfbmF0aXZl" + ) + response_body = SimpleNamespace( + output=[ + SimpleNamespace( + type="code_interpreter_call", + container_id=encoded_container_id, + code_interpreter_call=None, + ) + ] + ) + stream_response = SimpleNamespace( + completed_response=SimpleNamespace(response=response_body), + _hidden_params={"custom_llm_provider": "azure"}, + ) + + async def fake_sse_generator(): + yield "data: chunk-1\n\n" + yield "data: chunk-2\n\n" + + table = AsyncMock() + table.find_unique.return_value = None + prisma_client = SimpleNamespace( + db=SimpleNamespace(litellm_managedobjecttable=table) + ) + monkeypatch.setattr( + ownership, + "_get_prisma_client", + AsyncMock(return_value=prisma_client), + ) + auth = UserAPIKeyAuth(team_id="team-1") + + wrapped = ( + ProxyBaseLLMRequestProcessing._wrap_responses_stream_for_container_ownership( + original_stream_response=stream_response, + wrapped_generator=fake_sse_generator(), + user_api_key_dict=auth, + ) + ) + + chunks = [chunk async for chunk in wrapped] + assert chunks == ["data: chunk-1\n\n", "data: chunk-2\n\n"] + + table.create.assert_awaited_once() + created_data = table.create.await_args.kwargs["data"] + assert created_data["created_by"] == "team:team-1" + assert created_data["unified_object_id"] == encoded_container_id + + +@pytest.mark.asyncio +async def test_streaming_ownership_wrap_no_op_when_stream_did_not_complete( + monkeypatch, +): + """If the stream errored before ``response.completed``, + ``completed_response`` is ``None`` — we must skip the ownership + write rather than crash the response generator.""" + from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing + + stream_response = SimpleNamespace(completed_response=None) + + async def fake_sse_generator(): + yield "data: chunk-1\n\n" + + record = AsyncMock() + monkeypatch.setattr( + ownership, + "record_container_owners_from_responses_response", + record, + ) + + wrapped = ( + ProxyBaseLLMRequestProcessing._wrap_responses_stream_for_container_ownership( + original_stream_response=stream_response, + wrapped_generator=fake_sse_generator(), + user_api_key_dict=UserAPIKeyAuth(user_id="user-1"), + ) + ) + chunks = [chunk async for chunk in wrapped] + + assert chunks == ["data: chunk-1\n\n"] + record.assert_not_awaited()