fix(containers): record ownership for service-account keys + fix Prisma Json serialization (#28990)

* fix(containers): record ownership for service-account keys + fix Prisma Json field serialization

- Track containers created implicitly via /v1/responses by extracting container IDs
  from the response output and calling record_container_owner for each one, so
  subsequent file-API calls from the same service account pass ownership checks.
- Fix DataError: Prisma Python requires Json fields to be JSON strings; serialize
  file_object with json.dumps() before insert/update in LiteLLM_ManagedObjectTable.
- Add collect_container_ids_from_responses_response utility to responses/utils.py
  that walks all output item shapes (code_interpreter_call, message annotations).
- Tests: two new cases covering the responses-tracking path and the end-to-end
  record-then-assert flow for service accounts with team scope.

Co-authored-by: Cursor <cursoragent@cursor.com>

* fix(containers): swallow all exceptions in ownership hook; tighten file_object_json type to str

Co-authored-by: Cursor <cursoragent@cursor.com>

* fix(containers): parse file_object JSON string in existing ownership test

Co-authored-by: Cursor <cursoragent@cursor.com>

* fix: container ownership recording bugs

- Remove unreachable _aresponses_websocket from route_type set in
  base_process_llm_request; the WebSocket endpoint never flows through
  base_process_llm_request, so this branch was dead code that gave a
  false impression of coverage.
- Drop the HTTPException re-raise in record_container_owners_from_responses_response
  so per-container failures (including HTTP 403/500 from conflicting
  ownership rows) no longer abort the batch and skip recording for the
  remaining container IDs in the same response.

Co-authored-by: Yassin Kortam <yassin@berri.ai>

* fix(containers): record ownership for streaming /v1/responses too

Streaming /v1/responses returns through the select_data_generator
branch in base_process_llm_request and bypasses the non-streaming
ownership tail, so code-interpreter containers created mid-stream
were never written to LiteLLM_ManagedObjectTable. Follow-up file API
calls would then 403.

Wrap the SSE generator so container ownership is recorded once the
upstream iterator finishes assembling completed_response. Also covers
the background-polling path, which loops body_iterator end-to-end.

Co-authored-by: Yassin Kortam <yassin@berri.ai>

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: Yassin Kortam <yassin@berri.ai>
This commit is contained in:
Sameer Kankute 2026-05-28 09:30:07 +05:30 committed by GitHub
parent 9cac0471ae
commit 157e7a0f20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 444 additions and 4 deletions

View File

@ -1368,6 +1368,21 @@ class ProxyBaseLLMRequestProcessing:
user_api_key_dict=user_api_key_dict,
request_data=self.data,
)
if route_type == "aresponses":
# Streaming /v1/responses returns here without
# reaching the non-streaming ownership tail below.
# Wrap the SSE generator so container ownership is
# written once the upstream iterator finishes
# assembling ``completed_response`` — otherwise
# code-interpreter containers created during the
# stream stay unregistered and follow-up file API
# calls 403. Covers the background-polling path
# too, which loops ``body_iterator`` end-to-end.
selected_data_generator = ProxyBaseLLMRequestProcessing._wrap_responses_stream_for_container_ownership(
original_stream_response=response,
wrapped_generator=selected_data_generator,
user_api_key_dict=user_api_key_dict,
)
return await create_response(
generator=selected_data_generator,
media_type="text/event-stream",
@ -1483,8 +1498,93 @@ class ProxyBaseLLMRequestProcessing:
await check_response_size_is_safe(response=response)
if route_type in {"aresponses", "aget_responses"}:
await ProxyBaseLLMRequestProcessing._record_container_owners_from_responses_if_needed(
response=response,
user_api_key_dict=user_api_key_dict,
)
return response
@staticmethod
async def _record_container_owners_from_responses_if_needed(
    response: Any,
    user_api_key_dict: UserAPIKeyAuth,
) -> None:
    """Best-effort registration of containers referenced by a Responses API result.

    Hands ``response`` to ``record_container_owners_from_responses_response``
    so that code-interpreter containers it references are registered for the
    caller and follow-up file APIs pass ownership checks.

    Args:
        response: The Responses API result to inspect. ``None`` is a no-op.
        user_api_key_dict: Auth context of the caller that should own any
            containers found.

    This hook never raises ``Exception`` subclasses: every failure, including
    a failure to import the ownership module, is logged and swallowed so that
    an ownership bookkeeping problem cannot fail a request that has already
    succeeded.
    """
    if response is None:
        return

    try:
        # Resolved lazily at call time (the module attribute is looked up on
        # every call). Kept inside the ``try`` so an import failure is
        # logged like any other recording error instead of propagating to
        # the request handler.
        from litellm.proxy.container_endpoints.ownership import (
            record_container_owners_from_responses_response,
        )

        await record_container_owners_from_responses_response(
            response=response,
            user_api_key_dict=user_api_key_dict,
        )
    except Exception as e:
        verbose_proxy_logger.exception(
            "Container ownership recording failed after responses call: %s",
            e,
        )
@staticmethod
def _extract_completed_responses_response(stream_response: Any) -> Any:
    """Return the final response object carried by a streaming iterator, if any.

    Reads ``stream_response.completed_response``. Two shapes are accepted:

    * the stored object exposes a non-``None`` ``.response`` attribute
      (per the original author's note, this is the terminal stream event
      wrapping the response body) -- that inner object is returned;
    * otherwise the stored object itself is returned.

    Returns ``None`` when ``completed_response`` is missing or ``None``.
    """
    terminal = getattr(stream_response, "completed_response", None)
    if terminal is None:
        return None

    inner_response = getattr(terminal, "response", None)
    return terminal if inner_response is None else inner_response
@staticmethod
async def _wrap_responses_stream_for_container_ownership(
    original_stream_response: Any,
    wrapped_generator: Any,
    user_api_key_dict: UserAPIKeyAuth,
):
    """Relay SSE chunks unchanged, then record container ownership at stream end.

    This is an async generator. Every chunk produced by ``wrapped_generator``
    is yielded as-is. When iteration stops -- normally, through an error, or
    because this generator is closed -- the completed response is read off
    ``original_stream_response`` and, if present, passed to the ownership
    recorder. Without this wrap, streaming ``/v1/responses`` would leave any
    container created during the stream without a
    ``LiteLLM_ManagedObjectTable`` row and follow-up file API calls would
    return 403.

    Args:
        original_stream_response: The upstream streaming object that exposes
            ``completed_response`` once it has finished.
        wrapped_generator: The async iterable of SSE chunks to forward.
        user_api_key_dict: Auth context of the caller that should own any
            containers found.

    Failures while recording are logged and never replace the outcome of the
    stream itself.
    """
    try:
        async for sse_chunk in wrapped_generator:
            yield sse_chunk
    finally:
        try:
            processing = ProxyBaseLLMRequestProcessing
            final_response = processing._extract_completed_responses_response(
                original_stream_response
            )
            # NOTE(review): only the extracted response body is forwarded;
            # ``_hidden_params`` on ``original_stream_response`` are not
            # consulted here -- confirm provider resolution does not rely
            # on them for streamed responses.
            if final_response is not None:
                await processing._record_container_owners_from_responses_if_needed(
                    response=final_response,
                    user_api_key_dict=user_api_key_dict,
                )
        except Exception as exc:
            verbose_proxy_logger.exception(
                "Container ownership recording failed after streaming responses call: %s",
                exc,
            )
async def base_passthrough_process_llm_request(
self,
request: Request,

View File

@ -117,6 +117,58 @@ async def _get_prisma_client():
return prisma_client
def _custom_llm_provider_from_responses_response(
    response: Any,
    default: str = "openai",
) -> str:
    """Read the provider name from a response's ``_hidden_params``.

    ``response`` may be a dict (looked up by key) or any object (looked up by
    attribute). Returns the ``custom_llm_provider`` entry when it is a
    non-empty string; otherwise returns ``default``.
    """
    if isinstance(response, dict):
        raw_hidden = response.get("_hidden_params")
    else:
        raw_hidden = getattr(response, "_hidden_params", None)

    # Missing or falsy hidden params are treated as an empty mapping.
    hidden: Dict[str, Any] = raw_hidden or {}
    candidate = hidden.get("custom_llm_provider")
    return candidate if isinstance(candidate, str) and candidate else default
async def record_container_owners_from_responses_response(
    response: Any,
    user_api_key_dict: UserAPIKeyAuth,
    custom_llm_provider: Optional[str] = None,
) -> None:
    """Record ownership for containers created implicitly by /v1/responses.

    Collects every container ID referenced in ``response`` and calls
    ``record_container_owner`` once per ID, sequentially.

    Args:
        response: Responses API payload (dict or object) to scan.
        user_api_key_dict: Auth context of the caller that should own the
            containers.
        custom_llm_provider: Provider to record. When falsy, the provider is
            read from the response's ``_hidden_params``.

    A failure for one container is logged and does not stop the remaining
    containers from being processed.
    """
    discovered_ids = (
        ResponsesAPIRequestUtils.collect_container_ids_from_responses_response(response)
    )
    if not discovered_ids:
        return

    if custom_llm_provider:
        provider_name = custom_llm_provider
    else:
        provider_name = _custom_llm_provider_from_responses_response(response)

    for container_id in discovered_ids:
        try:
            await record_container_owner(
                response={"id": container_id, "object": "container"},
                user_api_key_dict=user_api_key_dict,
                custom_llm_provider=provider_name,
            )
        except Exception as exc:
            # Any per-container error, ``HTTPException`` included, is
            # contained here so the rest of the batch is still recorded and
            # those containers' follow-up file API calls don't 403.
            verbose_proxy_logger.exception(
                "Failed to record container ownership from responses output "
                "for container_id=%s: %s",
                container_id,
                exc,
            )
async def record_container_owner(
response: Any,
user_api_key_dict: UserAPIKeyAuth,
@ -151,6 +203,8 @@ async def record_container_owner(
file_object = _dump_response(response)
file_object["custom_llm_provider"] = resolved_provider
file_object["provider_container_id"] = original_container_id
# Prisma Python requires Json fields to be serialized as a JSON string.
file_object_json: str = json.dumps(file_object)
prisma_client = await _get_prisma_client()
if prisma_client is None:
@ -172,7 +226,7 @@ async def record_container_owner(
where={"model_object_id": model_object_id},
data={
"unified_object_id": container_id,
"file_object": file_object,
"file_object": file_object_json,
"updated_by": owner,
},
)
@ -181,7 +235,7 @@ async def record_container_owner(
data={
"unified_object_id": container_id,
"model_object_id": model_object_id,
"file_object": file_object,
"file_object": file_object_json,
"file_purpose": CONTAINER_OBJECT_PURPOSE,
"created_by": owner,
"updated_by": owner,

View File

@ -738,6 +738,98 @@ class ResponsesAPIRequestUtils:
model_id,
)
@staticmethod
def _collect_container_ids_from_annotations(
    annotations: Any,
    collected: set[str],
) -> None:
    """Add container IDs found in a list of annotations to ``collected``.

    Anything that is not a non-empty list is ignored. Each annotation is
    scanned with the same logic used for output items.
    """
    if not (annotations and isinstance(annotations, list)):
        return

    for annotation in annotations:
        ResponsesAPIRequestUtils._collect_container_ids_from_output_item(
            annotation, collected
        )
@staticmethod
def _collect_container_ids_from_message_content(
    content: Any,
    collected: set[str],
) -> None:
    """Add container IDs cited in a message's content parts to ``collected``.

    ``content`` is only processed when it is a non-empty list. Each part's
    ``annotations`` are read by key for dict parts and by attribute for any
    other object, then scanned for container IDs.
    """
    if not (content and isinstance(content, list)):
        return

    for segment in content:
        segment_annotations = (
            segment.get("annotations")
            if isinstance(segment, dict)
            else getattr(segment, "annotations", None)
        )
        ResponsesAPIRequestUtils._collect_container_ids_from_annotations(
            segment_annotations,
            collected,
        )
@staticmethod
def _collect_container_ids_from_output_item(
    item: Any,
    collected: set[str],
) -> None:
    """Collect managed or raw ``container_id`` values from one output item.

    Dict items and attribute-style objects are both supported:

    * the item's own ``container_id`` is added when it is a non-empty string;
    * a nested ``code_interpreter_call`` is inspected -- for dict items only
      when it is itself a dict (its ``container_id`` is read), for objects by
      scanning the nested value recursively;
    * items whose ``type`` is ``"message"`` have their content annotations
      scanned as well.

    ``None`` items are ignored. Results are added to ``collected`` in place.
    """
    if item is None:
        return

    def _remember(candidate: Any) -> None:
        # Only non-empty strings count as container IDs.
        if isinstance(candidate, str) and candidate:
            collected.add(candidate)

    if isinstance(item, dict):
        _remember(item.get("container_id"))

        nested_call = item.get("code_interpreter_call")
        if isinstance(nested_call, dict):
            _remember(nested_call.get("container_id"))

        if item.get("type") == "message":
            ResponsesAPIRequestUtils._collect_container_ids_from_message_content(
                item.get("content"),
                collected,
            )
        return

    _remember(getattr(item, "container_id", None))

    nested_call = getattr(item, "code_interpreter_call", None)
    if nested_call is not None:
        ResponsesAPIRequestUtils._collect_container_ids_from_output_item(
            nested_call, collected
        )

    if getattr(item, "type", None) == "message":
        ResponsesAPIRequestUtils._collect_container_ids_from_message_content(
            getattr(item, "content", None),
            collected,
        )
@staticmethod
def collect_container_ids_from_responses_response(response: Any) -> list[str]:
    """Return unique container IDs referenced in a Responses API payload.

    ``response`` may be a dict or an object; its ``output`` entries are
    scanned one by one. Returns an empty list when ``response`` is ``None``
    or has no output. The order of the returned IDs is not specified because
    they are gathered in a set.
    """
    if response is None:
        return []

    output_items = (
        response.get("output", [])
        if isinstance(response, dict)
        else (getattr(response, "output", []) or [])
    )

    found: set[str] = set()
    # A falsy ``output`` (``None``, empty list) means there is nothing to scan.
    for entry in output_items or ():
        ResponsesAPIRequestUtils._collect_container_ids_from_output_item(
            entry, found
        )
    return list(found)
@staticmethod
def _update_container_ids_in_response(
responses_api_response: Union[ResponsesAPIResponse, Dict[str, Any]],

View File

@ -1,3 +1,4 @@
import json
import sys
from types import SimpleNamespace
from unittest.mock import AsyncMock
@ -91,8 +92,9 @@ async def test_should_not_mutate_dict_container_response_when_recording_owner(
assert returned == {"id": "cntr_provider", "object": "container"}
data = table.create.await_args.kwargs["data"]
assert data["file_object"]["custom_llm_provider"] == "openai"
assert data["file_object"]["provider_container_id"] == "cntr_provider"
file_obj = json.loads(data["file_object"])
assert file_obj["custom_llm_provider"] == "openai"
assert file_obj["provider_container_id"] == "cntr_provider"
@pytest.mark.asyncio
@ -913,3 +915,195 @@ async def test_admin_with_identity_records_container_ownership(monkeypatch):
table.create.assert_awaited_once()
created_data = table.create.await_args.kwargs["data"]
assert created_data["created_by"] == "proxy-admin"
@pytest.mark.asyncio
async def test_should_record_containers_from_responses_output_for_service_account(
    monkeypatch,
):
    """A container cited in a message annotation is recorded for the caller's team."""
    encoded_container_id = (
        "cntr_bGl0ZWxsbTpjdXN0b21fbGxtX3Byb3ZpZGVyOmF6dXJlO21vZGVsX2lkOmR"
        "lZi0xMjM7Y29udGFpbmVyX2lkOmNudHJfbmF0aXZl"
    )

    # Fake Prisma table with no pre-existing ownership row.
    managed_object_table = AsyncMock()
    managed_object_table.find_unique.return_value = None
    fake_prisma = SimpleNamespace(
        db=SimpleNamespace(litellm_managedobjecttable=managed_object_table)
    )
    monkeypatch.setattr(
        ownership,
        "_get_prisma_client",
        AsyncMock(return_value=fake_prisma),
    )

    citation = {
        "type": "container_file_citation",
        "container_id": encoded_container_id,
        "file_id": "cfile_abc",
    }
    text_part = {"type": "output_text", "annotations": [citation]}
    responses_payload = {
        "output": [{"type": "message", "content": [text_part]}],
        "_hidden_params": {"custom_llm_provider": "azure"},
    }

    await ownership.record_container_owners_from_responses_response(
        response=responses_payload,
        user_api_key_dict=UserAPIKeyAuth(team_id="team-1"),
    )

    managed_object_table.create.assert_awaited_once()
    row = managed_object_table.create.await_args.kwargs["data"]
    assert row["created_by"] == "team:team-1"
    assert row["unified_object_id"] == encoded_container_id
@pytest.mark.asyncio
async def test_service_account_can_access_container_after_responses_tracking(
    monkeypatch,
):
    """Record a container from a responses payload, then run the access check for it."""
    encoded_container_id = (
        "cntr_bGl0ZWxsbTpjdXN0b21fbGxtX3Byb3ZpZGVyOmF6dXJlO21vZGVsX2lkOmR"
        "lZi0xMjM7Y29udGFpbmVyX2lkOmNudHJfbmF0aXZl"
    )

    managed_object_table = AsyncMock()
    managed_object_table.find_unique.return_value = None
    fake_prisma = SimpleNamespace(
        db=SimpleNamespace(litellm_managedobjecttable=managed_object_table)
    )
    monkeypatch.setattr(
        ownership,
        "_get_prisma_client",
        AsyncMock(return_value=fake_prisma),
    )

    service_account = UserAPIKeyAuth(team_id="team-1")
    responses_payload = {
        "output": [
            {
                "type": "code_interpreter_call",
                "container_id": encoded_container_id,
            }
        ],
        "_hidden_params": {"custom_llm_provider": "azure"},
    }

    await ownership.record_container_owners_from_responses_response(
        response=responses_payload,
        user_api_key_dict=service_account,
    )

    access_result = await ownership.assert_user_can_access_container(
        container_id=encoded_container_id,
        user_api_key_dict=service_account,
        custom_llm_provider="azure",
    )
    original_id, provider = access_result

    assert original_id == "cntr_native"
    assert provider == "azure"
@pytest.mark.asyncio
async def test_should_record_container_ownership_after_streaming_responses_finish(
    monkeypatch,
):
    """The streaming wrapper records ownership once the stream is exhausted.

    Streaming /v1/responses calls return through the ``select_data_generator``
    branch and never reach the non-streaming container-ownership tail. The
    wrapper therefore has to read ``completed_response`` from the upstream
    iterator after iteration finishes and write the row; otherwise
    code-interpreter containers created during the stream stay unregistered
    and follow-up file API calls 403.
    """
    from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing

    encoded_container_id = (
        "cntr_bGl0ZWxsbTpjdXN0b21fbGxtX3Byb3ZpZGVyOmF6dXJlO21vZGVsX2lkOmR"
        "lZi0xMjM7Y29udGFpbmVyX2lkOmNudHJfbmF0aXZl"
    )

    # Upstream iterator stand-in: terminal event wrapping the response body.
    interpreter_call = SimpleNamespace(
        type="code_interpreter_call",
        container_id=encoded_container_id,
        code_interpreter_call=None,
    )
    response_body = SimpleNamespace(output=[interpreter_call])
    stream_response = SimpleNamespace(
        completed_response=SimpleNamespace(response=response_body),
        _hidden_params={"custom_llm_provider": "azure"},
    )

    async def fake_sse_generator():
        yield "data: chunk-1\n\n"
        yield "data: chunk-2\n\n"

    managed_object_table = AsyncMock()
    managed_object_table.find_unique.return_value = None
    fake_prisma = SimpleNamespace(
        db=SimpleNamespace(litellm_managedobjecttable=managed_object_table)
    )
    monkeypatch.setattr(
        ownership,
        "_get_prisma_client",
        AsyncMock(return_value=fake_prisma),
    )

    wrapped = (
        ProxyBaseLLMRequestProcessing._wrap_responses_stream_for_container_ownership(
            original_stream_response=stream_response,
            wrapped_generator=fake_sse_generator(),
            user_api_key_dict=UserAPIKeyAuth(team_id="team-1"),
        )
    )

    received = [chunk async for chunk in wrapped]

    # Chunks pass through untouched and exactly one ownership row is written.
    assert received == ["data: chunk-1\n\n", "data: chunk-2\n\n"]
    managed_object_table.create.assert_awaited_once()
    row = managed_object_table.create.await_args.kwargs["data"]
    assert row["created_by"] == "team:team-1"
    assert row["unified_object_id"] == encoded_container_id
@pytest.mark.asyncio
async def test_streaming_ownership_wrap_no_op_when_stream_did_not_complete(
    monkeypatch,
):
    """No ownership write happens when the stream never produced a completed response.

    If the stream errored before ``response.completed``,
    ``completed_response`` is ``None``; the wrapper must skip the ownership
    write rather than crash the response generator.
    """
    from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing

    async def fake_sse_generator():
        yield "data: chunk-1\n\n"

    recorder = AsyncMock()
    monkeypatch.setattr(
        ownership,
        "record_container_owners_from_responses_response",
        recorder,
    )

    incomplete_stream = SimpleNamespace(completed_response=None)
    wrapped = (
        ProxyBaseLLMRequestProcessing._wrap_responses_stream_for_container_ownership(
            original_stream_response=incomplete_stream,
            wrapped_generator=fake_sse_generator(),
            user_api_key_dict=UserAPIKeyAuth(user_id="user-1"),
        )
    )

    received = [chunk async for chunk in wrapped]

    assert received == ["data: chunk-1\n\n"]
    recorder.assert_not_awaited()