[internal copy of #30137] perf(realtime): eliminate redundant per-frame JSON work on OpenAI realtime relay (#30142)
* perf(realtime): eliminate redundant per-frame JSON work on OpenAI realtime relay

  The GA realtime support added in #27110 made backend_to_client_send_messages parse every backend frame up to three times for beta clients (OpenAI-Beta: realtime=v1), build a discarded Pydantic object per frame for logging, and re-serialize even frames that need no translation. For high-frequency response.output_audio.delta frames carrying multi-KB base64 payloads, that redundant per-frame CPU work, executed sequentially on the hottest relay path, drove the latency regression between v1.83.14 and v1.88.1 for gpt-realtime-1.5 and gpt-realtime-2.

  This change:
  - parses each frame once via _parse_backend_event and threads the resulting dict into _handle_raw_backend_message, store_message, and _translate_event_to_beta;
  - short-circuits store_message before the Pydantic build for events not in the logged set;
  - returns the original event object unchanged from _translate_event_to_beta when no translation applies (no type rename and no item / response.output content to remap), so the raw frame is forwarded without re-serialization;
  - calls json.dumps only when the event was actually translated.

* fix(realtime): widen store_message type hint to accept plain dict

  The parse-once refactor passes the dict produced by _parse_backend_event into store_message, but the parameter was typed as str | bytes | OpenAIRealtimeEvents (a union of TypedDicts), which mypy does not consider compatible with a plain dict. Add dict to the accepted union; the body already handles it.

---------

Co-authored-by: Miguel Armenta <maarmenta92@gmail.com>
This commit is contained in:
parent
4a3860df1f
commit
7a96b3490d
@ -144,7 +144,7 @@ class RealTimeStreaming:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def store_message(self, message: Union[str, bytes, OpenAIRealtimeEvents]):
|
def store_message(self, message: Union[str, bytes, dict, OpenAIRealtimeEvents]):
|
||||||
"""Store message in list"""
|
"""Store message in list"""
|
||||||
if isinstance(message, bytes):
|
if isinstance(message, bytes):
|
||||||
message = message.decode("utf-8")
|
message = message.decode("utf-8")
|
||||||
@ -154,22 +154,20 @@ class RealTimeStreaming:
|
|||||||
else:
|
else:
|
||||||
message_obj = cast(Dict[str, Any], json.loads(cast(str, message)))
|
message_obj = cast(Dict[str, Any], json.loads(cast(str, message)))
|
||||||
self._collect_tool_calls_from_response_done(cast(dict, message_obj))
|
self._collect_tool_calls_from_response_done(cast(dict, message_obj))
|
||||||
|
if not self._should_store_message(message_obj):
|
||||||
|
return
|
||||||
try:
|
try:
|
||||||
event_type = message_obj.get("type", "")
|
event_type = message_obj.get("type", "")
|
||||||
if event_type in self._SESSION_EVENT_TYPES:
|
if event_type in self._SESSION_EVENT_TYPES:
|
||||||
typed_obj = OpenAIRealtimeStreamSessionEvents(**message_obj) # type: ignore
|
typed_obj: OpenAIRealtimeEvents = OpenAIRealtimeStreamSessionEvents(**message_obj) # type: ignore
|
||||||
else:
|
else:
|
||||||
# Use the base object as a safe catch-all for all other event types
|
# Catch-all base object so unknown/new event names never raise.
|
||||||
# (both beta and GA), so unknown/new event names never raise here.
|
|
||||||
typed_obj = OpenAIRealtimeStreamResponseBaseObject(**message_obj) # type: ignore
|
typed_obj = OpenAIRealtimeStreamResponseBaseObject(**message_obj) # type: ignore
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_logger.debug(f"Error parsing message for logging: {e}")
|
verbose_logger.debug(f"Error parsing message for logging: {e}")
|
||||||
# Don't re-raise — a parse failure must not drop or delay the message
|
self.messages.append(message_obj) # type: ignore[arg-type]
|
||||||
if self._should_store_message(message_obj):
|
|
||||||
self.messages.append(message_obj) # type: ignore[arg-type]
|
|
||||||
return
|
return
|
||||||
if self._should_store_message(typed_obj):
|
self.messages.append(typed_obj)
|
||||||
self.messages.append(typed_obj)
|
|
||||||
|
|
||||||
def _collect_user_input_from_client_event(self, message: Union[str, dict]) -> None:
|
def _collect_user_input_from_client_event(self, message: Union[str, dict]) -> None:
|
||||||
"""Extract user text content from client WebSocket events for spend logging."""
|
"""Extract user text content from client WebSocket events for spend logging."""
|
||||||
@ -358,8 +356,7 @@ class RealTimeStreaming:
|
|||||||
for msg in self._pending_messages_until_setup
|
for msg in self._pending_messages_until_setup
|
||||||
)
|
)
|
||||||
verbose_logger.debug(
|
verbose_logger.debug(
|
||||||
"Failed to flush buffered client message after setup: %s "
|
"Failed to flush buffered client message after setup: %s (%d buffered message(s) retained)",
|
||||||
"(%d buffered message(s) retained)",
|
|
||||||
e,
|
e,
|
||||||
len(unsent),
|
len(unsent),
|
||||||
)
|
)
|
||||||
@ -376,8 +373,7 @@ class RealTimeStreaming:
|
|||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_logger.warning(
|
verbose_logger.warning(
|
||||||
"Failed to translate %s to beta protocol, forwarding "
|
"Failed to translate %s to beta protocol, forwarding untranslated event to client: %s",
|
||||||
"untranslated event to client: %s",
|
|
||||||
event.get("type"),
|
event.get("type"),
|
||||||
e,
|
e,
|
||||||
)
|
)
|
||||||
@ -705,48 +701,48 @@ class RealTimeStreaming:
|
|||||||
self.store_message(event_str)
|
self.store_message(event_str)
|
||||||
await self._send_event_to_client(event, event_str)
|
await self._send_event_to_client(event, event_str)
|
||||||
|
|
||||||
async def _handle_raw_backend_message(self, raw_response) -> bool:
|
@staticmethod
|
||||||
|
def _parse_backend_event(raw_response: str) -> Optional[dict]:
|
||||||
|
"""Parse a backend frame once. Returns None for non-JSON or non-object frames."""
|
||||||
|
try:
|
||||||
|
event = json.loads(raw_response)
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
return None
|
||||||
|
return event if isinstance(event, dict) else None
|
||||||
|
|
||||||
|
async def _handle_raw_backend_message(
|
||||||
|
self, event_obj: dict, raw_response: str
|
||||||
|
) -> bool:
|
||||||
"""Process a backend message without provider_config (raw path).
|
"""Process a backend message without provider_config (raw path).
|
||||||
|
|
||||||
Returns True if the caller should skip the default store+forward (i.e. continue the loop).
|
Returns True if the caller should skip the default store+forward (i.e. continue the loop).
|
||||||
"""
|
"""
|
||||||
try:
|
event_type = event_obj.get("type")
|
||||||
event_obj = json.loads(raw_response)
|
|
||||||
|
|
||||||
# For audio/VAD guardrail path: once the session is ready, tell the backend
|
# Send session.created to the client FIRST so it stays in sync, then inject
|
||||||
# not to auto-respond after VAD detects end-of-speech. We send the
|
# the disable-auto-response session.update; otherwise a backend error could
|
||||||
# session.created to the client FIRST so the client is always in sync, then
|
# reach the client before it sees session.created.
|
||||||
# inject the session.update so a potential error from the backend doesn't
|
if (
|
||||||
# arrive before the client sees session.created.
|
event_type == "session.created"
|
||||||
if (
|
and self._has_audio_transcription_guardrails()
|
||||||
event_obj.get("type") == "session.created"
|
):
|
||||||
and self._has_audio_transcription_guardrails()
|
self.store_message(event_obj)
|
||||||
):
|
await self.websocket.send_text(raw_response)
|
||||||
self.store_message(raw_response)
|
await self._send_to_backend(self._make_disable_auto_response_message())
|
||||||
await self.websocket.send_text(raw_response)
|
return True
|
||||||
await self._send_to_backend(self._make_disable_auto_response_message())
|
|
||||||
return True
|
|
||||||
|
|
||||||
if (
|
if event_type == "conversation.item.input_audio_transcription.completed":
|
||||||
event_obj.get("type")
|
transcript = event_obj.get("transcript", "")
|
||||||
== "conversation.item.input_audio_transcription.completed"
|
self._collect_user_input_from_backend_event(event_obj)
|
||||||
):
|
self.store_message(event_obj)
|
||||||
transcript = event_obj.get("transcript", "")
|
await self.websocket.send_text(raw_response)
|
||||||
self._collect_user_input_from_backend_event(event_obj)
|
blocked = await self.run_realtime_guardrails(
|
||||||
## LOGGING — must happen before continue below
|
transcript,
|
||||||
self.store_message(raw_response)
|
item_id=event_obj.get("item_id"),
|
||||||
# Forward transcript to client so user sees what they said
|
)
|
||||||
await self.websocket.send_text(raw_response)
|
if not blocked:
|
||||||
blocked = await self.run_realtime_guardrails(
|
await self._send_to_backend(json.dumps({"type": "response.create"}))
|
||||||
transcript,
|
return True
|
||||||
item_id=event_obj.get("item_id"),
|
|
||||||
)
|
|
||||||
if not blocked:
|
|
||||||
# Clean — trigger LLM response
|
|
||||||
await self._send_to_backend(json.dumps({"type": "response.create"}))
|
|
||||||
return True
|
|
||||||
except (json.JSONDecodeError, AttributeError):
|
|
||||||
pass
|
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def backend_to_client_send_messages(self):
|
async def backend_to_client_send_messages(self):
|
||||||
@ -779,25 +775,25 @@ class RealTimeStreaming:
|
|||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
else:
|
else:
|
||||||
handled = await self._handle_raw_backend_message(raw_response)
|
event = self._parse_backend_event(raw_response)
|
||||||
if handled:
|
if event is None:
|
||||||
continue
|
|
||||||
## LOGGING
|
|
||||||
self.store_message(raw_response)
|
|
||||||
|
|
||||||
# If the client opted into beta protocol, translate GA event
|
|
||||||
# names/shapes back to the beta equivalents before forwarding.
|
|
||||||
if self._client_wants_beta:
|
|
||||||
try:
|
|
||||||
event_dict = json.loads(raw_response)
|
|
||||||
translated = self._translate_event_to_beta(event_dict)
|
|
||||||
if translated is None:
|
|
||||||
continue # drop GA-only events (e.g. conversation.item.done)
|
|
||||||
await self.websocket.send_text(json.dumps(translated))
|
|
||||||
except Exception:
|
|
||||||
await self.websocket.send_text(raw_response)
|
|
||||||
else:
|
|
||||||
await self.websocket.send_text(raw_response)
|
await self.websocket.send_text(raw_response)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if await self._handle_raw_backend_message(event, raw_response):
|
||||||
|
continue
|
||||||
|
self.store_message(event)
|
||||||
|
|
||||||
|
if not self._client_wants_beta:
|
||||||
|
await self.websocket.send_text(raw_response)
|
||||||
|
continue
|
||||||
|
|
||||||
|
translated = self._translate_event_to_beta(event)
|
||||||
|
if translated is None:
|
||||||
|
continue
|
||||||
|
await self.websocket.send_text(
|
||||||
|
raw_response if translated is event else json.dumps(translated)
|
||||||
|
)
|
||||||
|
|
||||||
except websockets.exceptions.ConnectionClosed as e: # type: ignore
|
except websockets.exceptions.ConnectionClosed as e: # type: ignore
|
||||||
verbose_logger.exception(
|
verbose_logger.exception(
|
||||||
@ -927,41 +923,43 @@ class RealTimeStreaming:
|
|||||||
def _translate_event_to_beta(event: dict) -> Optional[dict]:
|
def _translate_event_to_beta(event: dict) -> Optional[dict]:
|
||||||
"""Translate a single GA event dict to its beta equivalent.
|
"""Translate a single GA event dict to its beta equivalent.
|
||||||
|
|
||||||
Returns None if the event should be dropped entirely (e.g. the GA-only
|
Returns None when the event must be dropped (the GA-only
|
||||||
conversation.item.done has no beta counterpart).
|
conversation.item.done has no beta counterpart). Returns the original
|
||||||
Returns the (possibly mutated copy of the) event otherwise.
|
event object unchanged when no translation applies, so the caller can
|
||||||
|
forward the raw frame without re-serializing; otherwise returns a
|
||||||
|
translated copy.
|
||||||
"""
|
"""
|
||||||
event_type = event.get("type", "")
|
event_type = event.get("type", "")
|
||||||
|
|
||||||
# conversation.item.done has no beta equivalent — the client already
|
|
||||||
# received conversation.item.created (translated from .added).
|
|
||||||
if event_type == "conversation.item.done":
|
if event_type == "conversation.item.done":
|
||||||
return None
|
return None
|
||||||
|
|
||||||
# Shallow-copy so we don't mutate the stored message
|
renamed_type = RealTimeStreaming._GA_TO_BETA_EVENT_TYPES.get(event_type)
|
||||||
|
has_item = isinstance(event.get("item"), dict)
|
||||||
|
response = event.get("response")
|
||||||
|
has_response_output = isinstance(response, dict) and isinstance(
|
||||||
|
response.get("output"), list
|
||||||
|
)
|
||||||
|
if renamed_type is None and not has_item and not has_response_output:
|
||||||
|
return event
|
||||||
|
|
||||||
translated = dict(event)
|
translated = dict(event)
|
||||||
|
if renamed_type is not None:
|
||||||
# Rename the type field
|
translated["type"] = renamed_type
|
||||||
if event_type in RealTimeStreaming._GA_TO_BETA_EVENT_TYPES:
|
if has_item:
|
||||||
translated["type"] = RealTimeStreaming._GA_TO_BETA_EVENT_TYPES[event_type]
|
|
||||||
|
|
||||||
# Fix content block types inside items (response.done output list,
|
|
||||||
# conversation.item.created item content, etc.)
|
|
||||||
if "item" in translated and isinstance(translated["item"], dict):
|
|
||||||
translated["item"] = RealTimeStreaming._translate_item_content_types(
|
translated["item"] = RealTimeStreaming._translate_item_content_types(
|
||||||
dict(translated["item"])
|
dict(translated["item"])
|
||||||
)
|
)
|
||||||
if "response" in translated and isinstance(translated["response"], dict):
|
if has_response_output:
|
||||||
resp = dict(translated["response"])
|
resp = dict(translated["response"])
|
||||||
if "output" in resp and isinstance(resp["output"], list):
|
resp["output"] = [
|
||||||
resp["output"] = [
|
(
|
||||||
(
|
RealTimeStreaming._translate_item_content_types(dict(o))
|
||||||
RealTimeStreaming._translate_item_content_types(dict(o))
|
if isinstance(o, dict)
|
||||||
if isinstance(o, dict)
|
else o
|
||||||
else o
|
)
|
||||||
)
|
for o in resp["output"]
|
||||||
for o in resp["output"]
|
]
|
||||||
]
|
|
||||||
translated["response"] = resp
|
translated["response"] = resp
|
||||||
|
|
||||||
return translated
|
return translated
|
||||||
|
|||||||
@ -773,17 +773,15 @@ async def test_realtime_guardrail_blocks_prompt_injection():
|
|||||||
guardrail_items = [
|
guardrail_items = [
|
||||||
e for e in sent_to_backend if e.get("type") == "conversation.item.create"
|
e for e in sent_to_backend if e.get("type") == "conversation.item.create"
|
||||||
]
|
]
|
||||||
assert len(guardrail_items) == 1, (
|
assert (
|
||||||
f"Guardrail should inject a conversation.item.create with violation message, "
|
len(guardrail_items) == 1
|
||||||
f"got: {guardrail_items}"
|
), f"Guardrail should inject a conversation.item.create with violation message, got: {guardrail_items}"
|
||||||
)
|
|
||||||
response_creates = [
|
response_creates = [
|
||||||
e for e in sent_to_backend if e.get("type") == "response.create"
|
e for e in sent_to_backend if e.get("type") == "response.create"
|
||||||
]
|
]
|
||||||
assert len(response_creates) == 1, (
|
assert (
|
||||||
f"Guardrail should send exactly one response.create to voice the violation, "
|
len(response_creates) == 1
|
||||||
f"got: {response_creates}"
|
), f"Guardrail should send exactly one response.create to voice the violation, got: {response_creates}"
|
||||||
)
|
|
||||||
|
|
||||||
# ASSERT 2: error event was sent directly to the client WebSocket
|
# ASSERT 2: error event was sent directly to the client WebSocket
|
||||||
sent_to_client = [
|
sent_to_client = [
|
||||||
@ -1050,10 +1048,9 @@ async def test_realtime_function_call_output_guardrail_blocks_and_returns_error(
|
|||||||
# every toolCall with a toolResponse (Gemini/Vertex Live) exit their
|
# every toolCall with a toolResponse (Gemini/Vertex Live) exit their
|
||||||
# pending-tool-call state instead of stalling. The placeholder must NOT
|
# pending-tool-call state instead of stalling. The placeholder must NOT
|
||||||
# contain any of the blocked content.
|
# contain any of the blocked content.
|
||||||
assert len(forwarded_tool_outputs) == 1, (
|
assert (
|
||||||
f"Sanitized function_call_output should be forwarded, got: "
|
len(forwarded_tool_outputs) == 1
|
||||||
f"{forwarded_tool_outputs}"
|
), f"Sanitized function_call_output should be forwarded, got: {forwarded_tool_outputs}"
|
||||||
)
|
|
||||||
sanitized_item = forwarded_tool_outputs[0]["item"]
|
sanitized_item = forwarded_tool_outputs[0]["item"]
|
||||||
assert sanitized_item["call_id"] == "call_123"
|
assert sanitized_item["call_id"] == "call_123"
|
||||||
assert "test@example.com" not in sanitized_item["output"]
|
assert "test@example.com" not in sanitized_item["output"]
|
||||||
@ -2110,3 +2107,202 @@ async def test_deferred_setup_caps_non_audio_buffered_bytes(monkeypatch):
|
|||||||
assert (
|
assert (
|
||||||
streaming._pending_messages_byte_total <= RealTimeStreaming._MAX_BUFFERED_BYTES
|
streaming._pending_messages_byte_total <= RealTimeStreaming._MAX_BUFFERED_BYTES
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _beta_client_ws():
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.scope = {"headers": [(b"openai-beta", b"realtime=v1")]}
|
||||||
|
ws.send_text = AsyncMock()
|
||||||
|
return ws
|
||||||
|
|
||||||
|
|
||||||
|
def _ga_client_ws():
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.scope = {"headers": []}
|
||||||
|
ws.send_text = AsyncMock()
|
||||||
|
return ws
|
||||||
|
|
||||||
|
|
||||||
|
def _streaming_with(client_ws):
|
||||||
|
backend_ws = MagicMock()
|
||||||
|
logging_obj = MagicMock()
|
||||||
|
logging_obj.async_success_handler = AsyncMock()
|
||||||
|
logging_obj.success_handler = MagicMock()
|
||||||
|
return RealTimeStreaming(client_ws, backend_ws, logging_obj)
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_backend_event_returns_none_for_non_json():
|
||||||
|
assert RealTimeStreaming._parse_backend_event("not json") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_backend_event_returns_none_for_non_dict_json():
|
||||||
|
assert RealTimeStreaming._parse_backend_event("[1, 2, 3]") is None
|
||||||
|
assert RealTimeStreaming._parse_backend_event('"a string"') is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_backend_event_returns_dict():
|
||||||
|
parsed = RealTimeStreaming._parse_backend_event('{"type": "x", "v": 1}')
|
||||||
|
assert parsed == {"type": "x", "v": 1}
|
||||||
|
|
||||||
|
|
||||||
|
def test_translate_event_to_beta_returns_identity_when_no_translation():
|
||||||
|
"""An event with no renamed type and no item/response is returned unchanged
|
||||||
|
(same object), so the caller can forward the raw frame without re-serializing."""
|
||||||
|
ev = {"type": "error", "error": {"message": "boom"}}
|
||||||
|
out = RealTimeStreaming._translate_event_to_beta(ev)
|
||||||
|
assert out is ev
|
||||||
|
|
||||||
|
|
||||||
|
def test_translate_event_to_beta_preserves_audio_delta_payload():
|
||||||
|
payload = "QUJDREVG" * 200
|
||||||
|
out = RealTimeStreaming._translate_event_to_beta(
|
||||||
|
{"type": "response.output_audio.delta", "delta": payload, "event_id": "e1"}
|
||||||
|
)
|
||||||
|
assert out is not None
|
||||||
|
assert out["type"] == "response.audio.delta"
|
||||||
|
assert out["delta"] == payload
|
||||||
|
|
||||||
|
|
||||||
|
def test_translate_event_to_beta_remaps_response_done_output_content_types():
|
||||||
|
out = RealTimeStreaming._translate_event_to_beta(
|
||||||
|
{
|
||||||
|
"type": "response.done",
|
||||||
|
"response": {
|
||||||
|
"output": [
|
||||||
|
{
|
||||||
|
"type": "message",
|
||||||
|
"content": [{"type": "output_audio", "transcript": "hi"}],
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
assert out is not None
|
||||||
|
assert out["response"]["output"][0]["content"][0]["type"] == "audio"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_beta_client_receives_translated_audio_delta():
|
||||||
|
client_ws = _beta_client_ws()
|
||||||
|
frame = json.dumps(
|
||||||
|
{"type": "response.output_audio.delta", "delta": "QUJD", "event_id": "e1"}
|
||||||
|
)
|
||||||
|
backend_ws = MagicMock()
|
||||||
|
backend_ws.recv = AsyncMock(
|
||||||
|
side_effect=[frame.encode(), ConnectionClosed(None, None)]
|
||||||
|
)
|
||||||
|
logging_obj = MagicMock()
|
||||||
|
logging_obj.async_success_handler = AsyncMock()
|
||||||
|
logging_obj.success_handler = MagicMock()
|
||||||
|
streaming = RealTimeStreaming(client_ws, backend_ws, logging_obj)
|
||||||
|
|
||||||
|
await streaming.backend_to_client_send_messages()
|
||||||
|
|
||||||
|
assert client_ws.send_text.await_count == 1
|
||||||
|
sent = json.loads(client_ws.send_text.await_args.args[0])
|
||||||
|
assert sent["type"] == "response.audio.delta"
|
||||||
|
assert sent["delta"] == "QUJD"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_ga_client_receives_raw_passthrough():
|
||||||
|
client_ws = _ga_client_ws()
|
||||||
|
frame = json.dumps(
|
||||||
|
{"type": "response.output_audio.delta", "delta": "QUJD", "event_id": "e1"}
|
||||||
|
)
|
||||||
|
backend_ws = MagicMock()
|
||||||
|
backend_ws.recv = AsyncMock(
|
||||||
|
side_effect=[frame.encode(), ConnectionClosed(None, None)]
|
||||||
|
)
|
||||||
|
logging_obj = MagicMock()
|
||||||
|
logging_obj.async_success_handler = AsyncMock()
|
||||||
|
logging_obj.success_handler = MagicMock()
|
||||||
|
streaming = RealTimeStreaming(client_ws, backend_ws, logging_obj)
|
||||||
|
|
||||||
|
await streaming.backend_to_client_send_messages()
|
||||||
|
|
||||||
|
assert client_ws.send_text.await_count == 1
|
||||||
|
# GA client gets the byte-identical frame, no re-serialization.
|
||||||
|
assert client_ws.send_text.await_args.args[0] == frame
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_beta_client_non_translated_event_forwarded_raw():
|
||||||
|
"""For a beta client, an event needing no translation is forwarded as the
|
||||||
|
original raw frame (identity return path), not a re-serialized copy."""
|
||||||
|
client_ws = _beta_client_ws()
|
||||||
|
frame = json.dumps({"type": "error", "error": {"message": "boom"}})
|
||||||
|
backend_ws = MagicMock()
|
||||||
|
backend_ws.recv = AsyncMock(
|
||||||
|
side_effect=[frame.encode(), ConnectionClosed(None, None)]
|
||||||
|
)
|
||||||
|
logging_obj = MagicMock()
|
||||||
|
logging_obj.async_success_handler = AsyncMock()
|
||||||
|
logging_obj.success_handler = MagicMock()
|
||||||
|
streaming = RealTimeStreaming(client_ws, backend_ws, logging_obj)
|
||||||
|
|
||||||
|
await streaming.backend_to_client_send_messages()
|
||||||
|
|
||||||
|
assert client_ws.send_text.await_count == 1
|
||||||
|
assert client_ws.send_text.await_args.args[0] == frame
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_beta_client_drops_conversation_item_done():
|
||||||
|
client_ws = _beta_client_ws()
|
||||||
|
frame = json.dumps({"type": "conversation.item.done", "item": {"id": "i1"}})
|
||||||
|
backend_ws = MagicMock()
|
||||||
|
backend_ws.recv = AsyncMock(
|
||||||
|
side_effect=[frame.encode(), ConnectionClosed(None, None)]
|
||||||
|
)
|
||||||
|
logging_obj = MagicMock()
|
||||||
|
logging_obj.async_success_handler = AsyncMock()
|
||||||
|
logging_obj.success_handler = MagicMock()
|
||||||
|
streaming = RealTimeStreaming(client_ws, backend_ws, logging_obj)
|
||||||
|
|
||||||
|
await streaming.backend_to_client_send_messages()
|
||||||
|
|
||||||
|
assert client_ws.send_text.await_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_message_skips_pydantic_for_unlogged_audio_delta():
|
||||||
|
"""Audio deltas are not in DefaultLoggedRealTimeEventTypes; store_message must
|
||||||
|
skip the Pydantic build entirely (no append, no validation)."""
|
||||||
|
streaming = _streaming_with(_ga_client_ws())
|
||||||
|
with patch(
|
||||||
|
"litellm.litellm_core_utils.realtime_streaming.OpenAIRealtimeStreamResponseBaseObject"
|
||||||
|
) as base_obj:
|
||||||
|
streaming.store_message({"type": "response.output_audio.delta", "delta": "x"})
|
||||||
|
base_obj.assert_not_called()
|
||||||
|
assert streaming.messages == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_audio_delta_frame_parsed_at_most_once():
|
||||||
|
client_ws = _beta_client_ws()
|
||||||
|
frame = json.dumps(
|
||||||
|
{"type": "response.output_audio.delta", "delta": "QUJD", "event_id": "e1"}
|
||||||
|
)
|
||||||
|
backend_ws = MagicMock()
|
||||||
|
backend_ws.recv = AsyncMock(
|
||||||
|
side_effect=[frame.encode(), ConnectionClosed(None, None)]
|
||||||
|
)
|
||||||
|
logging_obj = MagicMock()
|
||||||
|
logging_obj.async_success_handler = AsyncMock()
|
||||||
|
logging_obj.success_handler = MagicMock()
|
||||||
|
streaming = RealTimeStreaming(client_ws, backend_ws, logging_obj)
|
||||||
|
|
||||||
|
real_loads = json.loads
|
||||||
|
calls = {"n": 0}
|
||||||
|
|
||||||
|
def counting_loads(*args, **kwargs):
|
||||||
|
calls["n"] += 1
|
||||||
|
return real_loads(*args, **kwargs)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"litellm.litellm_core_utils.realtime_streaming.json.loads",
|
||||||
|
side_effect=counting_loads,
|
||||||
|
):
|
||||||
|
await streaming.backend_to_client_send_messages()
|
||||||
|
|
||||||
|
assert calls["n"] == 1
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user