fix(containers): record ownership for service-account keys + fix Prisma Json serialization (#28990)
* fix(containers): record ownership for service-account keys + fix Prisma Json field serialization - Track containers created implicitly via /v1/responses by extracting container IDs from the response output and calling record_container_owner for each one, so subsequent file-API calls from the same service account pass ownership checks. - Fix DataError: Prisma Python requires Json fields to be JSON strings; serialize file_object with json.dumps() before insert/update in LiteLLM_ManagedObjectTable. - Add collect_container_ids_from_responses_response utility to responses/utils.py that walks all output item shapes (code_interpreter_call, message annotations). - Tests: two new cases covering the responses-tracking path and the end-to-end record-then-assert flow for service accounts with team scope. Co-authored-by: Cursor <cursoragent@cursor.com> * fix(containers): swallow all exceptions in ownership hook; tighten file_object_json type to str Co-authored-by: Cursor <cursoragent@cursor.com> * fix(containers): parse file_object JSON string in existing ownership test Co-authored-by: Cursor <cursoragent@cursor.com> * fix: container ownership recording bugs - Remove unreachable _aresponses_websocket from route_type set in base_process_llm_request; the WebSocket endpoint never flows through base_process_llm_request, so this branch was dead code that gave a false impression of coverage. - Drop the HTTPException re-raise in record_container_owners_from_responses_response so per-container failures (including HTTP 403/500 from conflicting ownership rows) no longer abort the batch and skip recording for the remaining container IDs in the same response. 
Co-authored-by: Yassin Kortam <yassin@berri.ai> * fix(containers): record ownership for streaming /v1/responses too Streaming /v1/responses returns through the select_data_generator branch in base_process_llm_request and bypasses the non-streaming ownership tail, so code-interpreter containers created mid-stream were never written to LiteLLM_ManagedObjectTable. Follow-up file API calls would then 403. Wrap the SSE generator so container ownership is recorded once the upstream iterator finishes assembling completed_response. Also covers the background-polling path, which loops body_iterator end-to-end. Co-authored-by: Yassin Kortam <yassin@berri.ai> --------- Co-authored-by: Cursor <cursoragent@cursor.com> Co-authored-by: Yassin Kortam <yassin@berri.ai>
This commit is contained in:
parent
9cac0471ae
commit
157e7a0f20
@ -1368,6 +1368,21 @@ class ProxyBaseLLMRequestProcessing:
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=self.data,
|
||||
)
|
||||
if route_type == "aresponses":
|
||||
# Streaming /v1/responses returns here without
|
||||
# reaching the non-streaming ownership tail below.
|
||||
# Wrap the SSE generator so container ownership is
|
||||
# written once the upstream iterator finishes
|
||||
# assembling ``completed_response`` — otherwise
|
||||
# code-interpreter containers created during the
|
||||
# stream stay unregistered and follow-up file API
|
||||
# calls 403. Covers the background-polling path
|
||||
# too, which loops ``body_iterator`` end-to-end.
|
||||
selected_data_generator = ProxyBaseLLMRequestProcessing._wrap_responses_stream_for_container_ownership(
|
||||
original_stream_response=response,
|
||||
wrapped_generator=selected_data_generator,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
return await create_response(
|
||||
generator=selected_data_generator,
|
||||
media_type="text/event-stream",
|
||||
@ -1483,8 +1498,93 @@ class ProxyBaseLLMRequestProcessing:
|
||||
|
||||
await check_response_size_is_safe(response=response)
|
||||
|
||||
if route_type in {"aresponses", "aget_responses"}:
|
||||
await ProxyBaseLLMRequestProcessing._record_container_owners_from_responses_if_needed(
|
||||
response=response,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@staticmethod
async def _record_container_owners_from_responses_if_needed(
    response: Any,
    user_api_key_dict: UserAPIKeyAuth,
) -> None:
    """Register code-interpreter containers so follow-up file APIs pass ownership checks.

    Best-effort hook: any exception raised while recording is logged and
    swallowed so it never propagates to the caller.

    Args:
        response: The Responses API payload to scan for container IDs.
            ``None`` is a no-op.
        user_api_key_dict: Auth context of the caller, forwarded to the
            ownership recorder.
    """
    # Nothing to scan -- return before doing any import work.
    if response is None:
        return

    # Function-scope import, resolved at call time.
    from litellm.proxy.container_endpoints.ownership import (
        record_container_owners_from_responses_response,
    )

    try:
        await record_container_owners_from_responses_response(
            response=response,
            user_api_key_dict=user_api_key_dict,
        )
    except Exception as e:
        # Deliberately broad: ownership bookkeeping must not fail the request.
        verbose_proxy_logger.exception(
            "Container ownership recording failed after responses call: %s",
            e,
        )
|
||||
|
||||
@staticmethod
def _extract_completed_responses_response(stream_response: Any) -> Any:
    """Pull the assembled ``ResponsesAPIResponse`` off a streaming iterator.

    ``ResponsesAPIStreamingIterator`` stores the terminal stream event
    (``response.completed`` / ``response.incomplete`` / ``response.failed``)
    in ``completed_response``; the actual response body hangs off
    that event's ``.response`` attribute. Some iterators store the
    ``ResponsesAPIResponse`` directly. Handle both shapes so the
    container-ownership recording path can walk ``.output`` either way.

    Args:
        stream_response: Object that may expose a ``completed_response``
            attribute.

    Returns:
        ``completed_response.response`` when that attribute is present and
        not ``None``; otherwise ``completed_response`` itself; ``None`` when
        ``completed_response`` is missing or ``None``.
    """
    completed = getattr(stream_response, "completed_response", None)
    if completed is None:
        return None

    # Event wrapper shape: the body lives on ``.response``.
    response_obj = getattr(completed, "response", None)
    if response_obj is not None:
        return response_obj

    # Direct shape: ``completed_response`` already is the body.
    return completed
|
||||
|
||||
@staticmethod
async def _wrap_responses_stream_for_container_ownership(
    original_stream_response: Any,
    wrapped_generator: Any,
    user_api_key_dict: UserAPIKeyAuth,
):
    """Forward SSE chunks, then record container ownership at stream end.

    Streaming ``/v1/responses`` short-circuits out of
    ``base_process_llm_request`` before the non-streaming ownership
    tail runs, so without this wrap the
    ``LiteLLM_ManagedObjectTable`` row for any container created
    during the stream is never written and follow-up file API calls
    return 403.

    Args:
        original_stream_response: The upstream streaming object; its
            ``completed_response`` attribute is read after iteration.
        wrapped_generator: Async iterable of chunks to forward unchanged.
        user_api_key_dict: Auth context forwarded to the ownership recorder.

    Yields:
        Every chunk produced by ``wrapped_generator``, in order.
    """
    try:
        async for chunk in wrapped_generator:
            yield chunk
    finally:
        # Runs on normal exhaustion and also when iteration ends early or
        # raises. If no completed response was stored, the extractor
        # returns ``None`` and recording is skipped.
        try:
            completed_obj = (
                ProxyBaseLLMRequestProcessing._extract_completed_responses_response(
                    original_stream_response
                )
            )
            if completed_obj is not None:
                await ProxyBaseLLMRequestProcessing._record_container_owners_from_responses_if_needed(
                    response=completed_obj,
                    user_api_key_dict=user_api_key_dict,
                )
        except Exception as e:
            # Bookkeeping failures are logged, never surfaced to the stream consumer.
            verbose_proxy_logger.exception(
                "Container ownership recording failed after streaming responses call: %s",
                e,
            )
|
||||
|
||||
async def base_passthrough_process_llm_request(
|
||||
self,
|
||||
request: Request,
|
||||
|
||||
@ -117,6 +117,58 @@ async def _get_prisma_client():
|
||||
return prisma_client
|
||||
|
||||
|
||||
def _custom_llm_provider_from_responses_response(
    response: Any,
    default: str = "openai",
) -> str:
    """Resolve the provider name stored in a response's ``_hidden_params``.

    Args:
        response: A dict (read via the ``"_hidden_params"`` key) or any object
            (read via the ``_hidden_params`` attribute).
        default: Value returned when no usable provider is found.

    Returns:
        The ``custom_llm_provider`` value when it is a non-empty string,
        otherwise ``default``.
    """
    if isinstance(response, dict):
        hidden_params: Any = response.get("_hidden_params")
    else:
        hidden_params = getattr(response, "_hidden_params", None)

    # ``_hidden_params`` may be a dict or an attribute-style object; anything
    # else (None, str, ...) simply yields no provider instead of raising.
    if isinstance(hidden_params, dict):
        provider = hidden_params.get("custom_llm_provider")
    else:
        provider = getattr(hidden_params, "custom_llm_provider", None)

    if isinstance(provider, str) and provider:
        return provider
    return default
|
||||
|
||||
|
||||
async def record_container_owners_from_responses_response(
    response: Any,
    user_api_key_dict: UserAPIKeyAuth,
    custom_llm_provider: Optional[str] = None,
) -> None:
    """Track containers created implicitly by code interpreter in /v1/responses.

    Collects every container ID referenced in ``response`` and calls
    ``record_container_owner`` once per ID.

    Args:
        response: Responses API payload (dict or object) to scan.
        user_api_key_dict: Auth context of the caller that will own the
            containers.
        custom_llm_provider: Explicit provider name. When falsy, the provider
            is read from the response's ``_hidden_params``.
    """
    container_ids = (
        ResponsesAPIRequestUtils.collect_container_ids_from_responses_response(response)
    )
    if not container_ids:
        return

    resolved_provider = (
        custom_llm_provider or _custom_llm_provider_from_responses_response(response)
    )

    for container_id in container_ids:
        try:
            await record_container_owner(
                response={"id": container_id, "object": "container"},
                user_api_key_dict=user_api_key_dict,
                custom_llm_provider=resolved_provider,
            )
        except Exception as e:
            # Per-container errors (including ``HTTPException`` from
            # conflicting/forbidden ownership rows) must not abort the
            # batch — other containers in the same response should still
            # get recorded so their follow-up file API calls don't 403.
            verbose_proxy_logger.exception(
                "Failed to record container ownership from responses output "
                "for container_id=%s: %s",
                container_id,
                e,
            )
|
||||
|
||||
|
||||
async def record_container_owner(
|
||||
response: Any,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
@ -151,6 +203,8 @@ async def record_container_owner(
|
||||
file_object = _dump_response(response)
|
||||
file_object["custom_llm_provider"] = resolved_provider
|
||||
file_object["provider_container_id"] = original_container_id
|
||||
# Prisma Python requires Json fields to be serialized as a JSON string.
|
||||
file_object_json: str = json.dumps(file_object)
|
||||
|
||||
prisma_client = await _get_prisma_client()
|
||||
if prisma_client is None:
|
||||
@ -172,7 +226,7 @@ async def record_container_owner(
|
||||
where={"model_object_id": model_object_id},
|
||||
data={
|
||||
"unified_object_id": container_id,
|
||||
"file_object": file_object,
|
||||
"file_object": file_object_json,
|
||||
"updated_by": owner,
|
||||
},
|
||||
)
|
||||
@ -181,7 +235,7 @@ async def record_container_owner(
|
||||
data={
|
||||
"unified_object_id": container_id,
|
||||
"model_object_id": model_object_id,
|
||||
"file_object": file_object,
|
||||
"file_object": file_object_json,
|
||||
"file_purpose": CONTAINER_OBJECT_PURPOSE,
|
||||
"created_by": owner,
|
||||
"updated_by": owner,
|
||||
|
||||
@ -738,6 +738,98 @@ class ResponsesAPIRequestUtils:
|
||||
model_id,
|
||||
)
|
||||
|
||||
@staticmethod
def _collect_container_ids_from_annotations(
    annotations: Any,
    collected: set[str],
) -> None:
    """Add container IDs found in a list of annotations to ``collected``.

    Args:
        annotations: Expected to be a list of annotation items; any other
            value (including ``None`` or an empty list) is ignored.
        collected: Set that receives the discovered IDs (mutated in place).
    """
    if not annotations or not isinstance(annotations, list):
        return
    for ann in annotations:
        # Each annotation is scanned with the same rules as an output item.
        ResponsesAPIRequestUtils._collect_container_ids_from_output_item(
            ann, collected
        )
|
||||
|
||||
@staticmethod
def _collect_container_ids_from_message_content(
    content: Any,
    collected: set[str],
) -> None:
    """Add container IDs referenced by a message's content parts to ``collected``.

    Args:
        content: Expected to be a list of content parts (dicts or objects);
            falsy or non-list values are ignored.
        collected: Set that receives the discovered IDs (mutated in place).
    """
    if not content or not isinstance(content, list):
        return
    for part in content:
        # Parts may be plain dicts or attribute-style objects.
        if isinstance(part, dict):
            annotations = part.get("annotations")
        else:
            annotations = getattr(part, "annotations", None)
        ResponsesAPIRequestUtils._collect_container_ids_from_annotations(
            annotations,
            collected,
        )
|
||||
|
||||
@staticmethod
def _collect_container_ids_from_output_item(
    item: Any,
    collected: set[str],
) -> None:
    """Collect managed or raw ``container_id`` values from one output item.

    The item may be a dict or an attribute-style object. Three places are
    inspected: the item's own ``container_id``, a nested
    ``code_interpreter_call`` payload, and -- when ``type`` is ``"message"``
    -- the annotations of its content parts.

    Args:
        item: One output item or annotation; ``None`` is ignored.
        collected: Set that receives the discovered IDs (mutated in place).
    """
    if item is None:
        return

    is_mapping = isinstance(item, dict)

    def _field(name: str) -> Any:
        # Uniform field access for dict and object shapes.
        if is_mapping:
            return item.get(name)
        return getattr(item, name, None)

    container_id = _field("container_id")
    if isinstance(container_id, str) and container_id:
        collected.add(container_id)

    # The nested payload is scanned with the same rules, whichever shape
    # (dict or object) it or its parent has.
    nested = _field("code_interpreter_call")
    if nested is not None:
        ResponsesAPIRequestUtils._collect_container_ids_from_output_item(
            nested, collected
        )

    if _field("type") == "message":
        ResponsesAPIRequestUtils._collect_container_ids_from_message_content(
            _field("content"),
            collected,
        )
|
||||
|
||||
@staticmethod
def collect_container_ids_from_responses_response(response: Any) -> list[str]:
    """Return unique container IDs referenced in a Responses API payload.

    Args:
        response: Dict (read via the ``"output"`` key) or object (read via
            the ``output`` attribute). ``None`` yields an empty list.

    Returns:
        The de-duplicated container IDs in sorted order, so the result is
        stable across runs. Empty when ``output`` is missing, empty, or not
        iterable.
    """
    if response is None:
        return []

    if isinstance(response, dict):
        output = response.get("output", [])
    else:
        output = getattr(response, "output", []) or []

    if not output:
        return []

    try:
        items = iter(output)
    except TypeError:
        # Malformed payload: ``output`` is not a sequence of items.
        return []

    collected: set[str] = set()
    for item in items:
        ResponsesAPIRequestUtils._collect_container_ids_from_output_item(
            item, collected
        )
    # Set iteration order is arbitrary; sort for deterministic output.
    return sorted(collected)
|
||||
|
||||
@staticmethod
|
||||
def _update_container_ids_in_response(
|
||||
responses_api_response: Union[ResponsesAPIResponse, Dict[str, Any]],
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import json
|
||||
import sys
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock
|
||||
@ -91,8 +92,9 @@ async def test_should_not_mutate_dict_container_response_when_recording_owner(
|
||||
|
||||
assert returned == {"id": "cntr_provider", "object": "container"}
|
||||
data = table.create.await_args.kwargs["data"]
|
||||
assert data["file_object"]["custom_llm_provider"] == "openai"
|
||||
assert data["file_object"]["provider_container_id"] == "cntr_provider"
|
||||
file_obj = json.loads(data["file_object"])
|
||||
assert file_obj["custom_llm_provider"] == "openai"
|
||||
assert file_obj["provider_container_id"] == "cntr_provider"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -913,3 +915,195 @@ async def test_admin_with_identity_records_container_ownership(monkeypatch):
|
||||
table.create.assert_awaited_once()
|
||||
created_data = table.create.await_args.kwargs["data"]
|
||||
assert created_data["created_by"] == "proxy-admin"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_should_record_containers_from_responses_output_for_service_account(
    monkeypatch,
):
    """A container ID cited in a message annotation is recorded for the team.

    The payload carries the ID only inside a ``container_file_citation``
    annotation; the test expects exactly one ``create`` call owned by
    ``team:team-1`` with the ID stored as ``unified_object_id``.
    """
    table = AsyncMock()
    # No existing ownership row.
    table.find_unique.return_value = None
    prisma_client = SimpleNamespace(
        db=SimpleNamespace(litellm_managedobjecttable=table)
    )
    monkeypatch.setattr(
        ownership,
        "_get_prisma_client",
        AsyncMock(return_value=prisma_client),
    )
    auth = UserAPIKeyAuth(team_id="team-1")
    encoded_container_id = (
        "cntr_bGl0ZWxsbTpjdXN0b21fbGxtX3Byb3ZpZGVyOmF6dXJlO21vZGVsX2lkOmR"
        "lZi0xMjM7Y29udGFpbmVyX2lkOmNudHJfbmF0aXZl"
    )
    responses_payload = {
        "output": [
            {
                "type": "message",
                "content": [
                    {
                        "type": "output_text",
                        "annotations": [
                            {
                                "type": "container_file_citation",
                                "container_id": encoded_container_id,
                                "file_id": "cfile_abc",
                            }
                        ],
                    }
                ],
            }
        ],
        "_hidden_params": {"custom_llm_provider": "azure"},
    }

    await ownership.record_container_owners_from_responses_response(
        response=responses_payload,
        user_api_key_dict=auth,
    )

    table.create.assert_awaited_once()
    created_data = table.create.await_args.kwargs["data"]
    assert created_data["created_by"] == "team:team-1"
    assert created_data["unified_object_id"] == encoded_container_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_service_account_can_access_container_after_responses_tracking(
    monkeypatch,
):
    """Record a container from a responses payload, then assert access.

    After recording a ``code_interpreter_call`` container for a team-scoped
    key, ``assert_user_can_access_container`` is expected to return the
    native container ID and the ``azure`` provider.
    """
    encoded_container_id = (
        "cntr_bGl0ZWxsbTpjdXN0b21fbGxtX3Byb3ZpZGVyOmF6dXJlO21vZGVsX2lkOmR"
        "lZi0xMjM7Y29udGFpbmVyX2lkOmNudHJfbmF0aXZl"
    )
    table = AsyncMock()
    table.find_unique.return_value = None
    prisma_client = SimpleNamespace(
        db=SimpleNamespace(litellm_managedobjecttable=table)
    )
    monkeypatch.setattr(
        ownership,
        "_get_prisma_client",
        AsyncMock(return_value=prisma_client),
    )
    auth = UserAPIKeyAuth(team_id="team-1")

    await ownership.record_container_owners_from_responses_response(
        response={
            "output": [
                {
                    "type": "code_interpreter_call",
                    "container_id": encoded_container_id,
                }
            ],
            "_hidden_params": {"custom_llm_provider": "azure"},
        },
        user_api_key_dict=auth,
    )

    original_id, provider = await ownership.assert_user_can_access_container(
        container_id=encoded_container_id,
        user_api_key_dict=auth,
        custom_llm_provider="azure",
    )
    assert original_id == "cntr_native"
    assert provider == "azure"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_should_record_container_ownership_after_streaming_responses_finish(
    monkeypatch,
):
    """Streaming /v1/responses calls return through the
    ``select_data_generator`` branch and never reach the non-streaming
    container-ownership tail. The wrapper must read
    ``completed_response`` off the upstream iterator once iteration
    finishes and write the row, otherwise code-interpreter containers
    created during the stream stay unregistered and follow-up file API
    calls 403.
    """
    from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing

    encoded_container_id = (
        "cntr_bGl0ZWxsbTpjdXN0b21fbGxtX3Byb3ZpZGVyOmF6dXJlO21vZGVsX2lkOmR"
        "lZi0xMjM7Y29udGFpbmVyX2lkOmNudHJfbmF0aXZl"
    )
    response_body = SimpleNamespace(
        output=[
            SimpleNamespace(
                type="code_interpreter_call",
                container_id=encoded_container_id,
                code_interpreter_call=None,
            )
        ]
    )
    # Event-wrapper shape: the body hangs off ``completed_response.response``.
    stream_response = SimpleNamespace(
        completed_response=SimpleNamespace(response=response_body),
        _hidden_params={"custom_llm_provider": "azure"},
    )

    async def fake_sse_generator():
        yield "data: chunk-1\n\n"
        yield "data: chunk-2\n\n"

    table = AsyncMock()
    table.find_unique.return_value = None
    prisma_client = SimpleNamespace(
        db=SimpleNamespace(litellm_managedobjecttable=table)
    )
    monkeypatch.setattr(
        ownership,
        "_get_prisma_client",
        AsyncMock(return_value=prisma_client),
    )
    auth = UserAPIKeyAuth(team_id="team-1")

    wrapped = (
        ProxyBaseLLMRequestProcessing._wrap_responses_stream_for_container_ownership(
            original_stream_response=stream_response,
            wrapped_generator=fake_sse_generator(),
            user_api_key_dict=auth,
        )
    )

    # Chunks must pass through untouched and in order.
    chunks = [chunk async for chunk in wrapped]
    assert chunks == ["data: chunk-1\n\n", "data: chunk-2\n\n"]

    table.create.assert_awaited_once()
    created_data = table.create.await_args.kwargs["data"]
    assert created_data["created_by"] == "team:team-1"
    assert created_data["unified_object_id"] == encoded_container_id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_streaming_ownership_wrap_no_op_when_stream_did_not_complete(
    monkeypatch,
):
    """If the stream errored before ``response.completed``,
    ``completed_response`` is ``None`` — we must skip the ownership
    write rather than crash the response generator."""
    from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing

    stream_response = SimpleNamespace(completed_response=None)

    async def fake_sse_generator():
        yield "data: chunk-1\n\n"

    # Replace the recorder so any call would be observable.
    record = AsyncMock()
    monkeypatch.setattr(
        ownership,
        "record_container_owners_from_responses_response",
        record,
    )

    wrapped = (
        ProxyBaseLLMRequestProcessing._wrap_responses_stream_for_container_ownership(
            original_stream_response=stream_response,
            wrapped_generator=fake_sse_generator(),
            user_api_key_dict=UserAPIKeyAuth(user_id="user-1"),
        )
    )
    chunks = [chunk async for chunk in wrapped]

    assert chunks == ["data: chunk-1\n\n"]
    record.assert_not_awaited()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user