diff --git a/AGENTS.md b/AGENTS.md index bfd44304d5..96776a3fae 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -189,4 +189,39 @@ When opening issues or pull requests, follow these templates: - Check similar provider implementations - Ensure comprehensive test coverage - Update documentation appropriately -- Consider backward compatibility impact \ No newline at end of file +- Consider backward compatibility impact + +## Cursor Cloud specific instructions + +### Environment + +- Poetry is installed in `~/.local/bin`; the update script ensures it is on `PATH`. +- Python 3.12 and Node 22 are pre-installed. +- The virtual environment lives under `~/.cache/pypoetry/virtualenvs/`. + +### Running the proxy server + +Start the proxy with a config file: + +```bash +poetry run litellm --config dev_config.yaml --port 4000 +``` + +The proxy takes ~15-20 seconds to fully start (it runs Prisma migrations on boot). Wait for `/health` to respond before sending requests. Without a PostgreSQL `DATABASE_URL`, the proxy connects to a default Neon dev database embedded in the `litellm-proxy-extras` package. + +### Running tests + +See `CLAUDE.md` and the `Makefile` for standard commands. Key notes: + +- `psycopg-binary` must be installed (`poetry run pip install psycopg-binary`) because the pytest-postgresql plugin requires it and the lock file only includes `psycopg` (no binary). +- The `--timeout` pytest flag is NOT available; don't pass it. +- Unit tests: `poetry run pytest tests/test_litellm/ -x -vv -n 4` +- Black `--check` may report pre-existing formatting issues; this does not block test runs. + +### Lint + +```bash +cd litellm && poetry run ruff check . +``` + +Ruff is the primary fast linter. For the full lint suite (including mypy, black, and circular-import checks), run `make lint` per `CLAUDE.md`. 
\ No newline at end of file diff --git a/litellm/litellm_core_utils/realtime_streaming.py b/litellm/litellm_core_utils/realtime_streaming.py index 759eaf6003..d15d23f8ee 100644 --- a/litellm/litellm_core_utils/realtime_streaming.py +++ b/litellm/litellm_core_utils/realtime_streaming.py @@ -49,6 +49,9 @@ class RealTimeStreaming: self.logging_obj = logging_obj self.messages: List[OpenAIRealtimeEvents] = [] self.input_message: Dict = {} + self.input_messages: List[Dict[str, str]] = [] + self.session_tools: List[Dict] = [] + self.tool_calls: List[Dict] = [] _logged_real_time_event_types = litellm.logged_real_time_event_types @@ -85,6 +88,7 @@ class RealTimeStreaming: message_obj = message else: message_obj = json.loads(message) + self._collect_tool_calls_from_response_done(message_obj) try: if ( not isinstance(message, dict) @@ -100,15 +104,107 @@ class RealTimeStreaming: if self._should_store_message(message_obj): self.messages.append(message_obj) - def store_input(self, message: dict): + def _collect_user_input_from_client_event( + self, message: Union[str, dict] + ) -> None: + """Extract user text content from client WebSocket events for spend logging.""" + try: + if isinstance(message, str): + msg_obj = json.loads(message) + elif isinstance(message, dict): + msg_obj = message + else: + return + + msg_type = msg_obj.get("type", "") + + if msg_type == "conversation.item.create": + item = msg_obj.get("item", {}) + if item.get("role") == "user": + content_list = item.get("content", []) + for content in content_list: + if ( + isinstance(content, dict) + and content.get("type") == "input_text" + ): + text = content.get("text", "") + if text: + self.input_messages.append( + {"role": "user", "content": text} + ) + elif msg_type == "session.update": + session = msg_obj.get("session", {}) + instructions = session.get("instructions", "") + if instructions: + self.input_messages.append( + {"role": "system", "content": instructions} + ) + tools = session.get("tools") + if tools 
and isinstance(tools, list): + self.session_tools = tools + except (json.JSONDecodeError, AttributeError, TypeError): + pass + + def _collect_user_input_from_backend_event(self, event_obj: dict) -> None: + """Extract user voice transcription from backend events for spend logging.""" + try: + event_type = event_obj.get("type", "") + if ( + event_type + == "conversation.item.input_audio_transcription.completed" + ): + transcript = event_obj.get("transcript", "") + if transcript: + self.input_messages.append( + {"role": "user", "content": transcript} + ) + except (AttributeError, TypeError): + pass + + def _collect_tool_calls_from_response_done( + self, event_obj: dict + ) -> None: + """Extract function_call items from response.done events for spend logging.""" + try: + if event_obj.get("type") != "response.done": + return + response = event_obj.get("response", {}) + for item in response.get("output", []): + if item.get("type") == "function_call": + self.tool_calls.append( + { + "id": item.get("call_id", ""), + "type": "function", + "function": { + "name": item.get("name", ""), + "arguments": item.get("arguments", "{}"), + }, + } + ) + except (AttributeError, TypeError): + pass + + def store_input(self, message: Union[str, dict]): """Store input message""" - self.input_message = message + self.input_message = message if isinstance(message, dict) else {} + self._collect_user_input_from_client_event(message) if self.logging_obj: self.logging_obj.pre_call(input=message, api_key="") async def log_messages(self): """Log messages in list""" if self.logging_obj: + if self.input_messages: + self.logging_obj.model_call_details["messages"] = ( + self.input_messages + ) + if self.session_tools or self.tool_calls: + self.logging_obj.model_call_details[ + "realtime_tools" + ] = self.session_tools + self.logging_obj.model_call_details[ + "realtime_tool_calls" + ] = self.tool_calls ## ASYNC LOGGING # Create an event loop for the new thread 
asyncio.create_task(self.logging_obj.async_success_handler(self.messages)) @@ -259,6 +355,7 @@ class RealTimeStreaming: == "conversation.item.input_audio_transcription.completed" ): transcript = event.get("transcript", "") + self._collect_user_input_from_backend_event(event) self.store_message(event_str) await self.websocket.send_text(event_str) blocked = await self.run_realtime_guardrails( @@ -308,6 +405,7 @@ class RealTimeStreaming: == "conversation.item.input_audio_transcription.completed" ): transcript = event_obj.get("transcript", "") + self._collect_user_input_from_backend_event(event_obj) ## LOGGING — must happen before continue below self.store_message(raw_response) # Forward transcript to client so user sees what they said diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index 6106f69ea6..133f8ec136 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -207,6 +207,13 @@ async def user_api_key_auth_websocket(websocket: WebSocket): # If no Authorization header, try the api-key header if not authorization: api_key = websocket.headers.get("api-key") + if not api_key: + # Try extracting from WebSocket subprotocol (browser clients) + for protocol in websocket.headers.get("sec-websocket-protocol", "").split(","): + protocol = protocol.strip() + if protocol.startswith("openai-insecure-api-key."): + api_key = protocol[len("openai-insecure-api-key."):] + break if not api_key: await websocket.close(code=status.WS_1008_POLICY_VIOLATION) raise HTTPException(status_code=403, detail="No API key provided") diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 82cfd455be..f0b1e66818 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -7360,7 +7360,15 @@ async def realtime_websocket_endpoint( ), user_api_key_dict=Depends(user_api_key_auth_websocket), ): - await websocket.accept() + requested_protocols = [ + p.strip() + for p 
in (websocket.headers.get("sec-websocket-protocol") or "").split(",") + if p.strip() + ] + accept_kwargs: dict = {} + if requested_protocols: + accept_kwargs["subprotocol"] = requested_protocols[0] + await websocket.accept(**accept_kwargs) # Only use explicit parameters, not all query params query_params = cast( diff --git a/litellm/proxy/spend_tracking/spend_tracking_utils.py b/litellm/proxy/spend_tracking/spend_tracking_utils.py index 05dbb8ed71..1ef8556233 100644 --- a/litellm/proxy/spend_tracking/spend_tracking_utils.py +++ b/litellm/proxy/spend_tracking/spend_tracking_utils.py @@ -593,6 +593,16 @@ def _get_messages_for_spend_logs_payload( standard_logging_payload: Optional[StandardLoggingPayload], metadata: Optional[dict] = None, ) -> str: + if _should_store_prompts_and_responses_in_spend_logs(): + if standard_logging_payload is not None: + call_type = standard_logging_payload.get("call_type", "") + if call_type == "_arealtime": + messages = standard_logging_payload.get("messages") + if messages is not None: + try: + return json.dumps(messages, default=str) + except Exception: + return "{}" return "{}" @@ -734,7 +744,13 @@ def _get_proxy_server_request_for_spend_logs_payload( ) if _proxy_server_request is not None: _request_body = _proxy_server_request.get("body", {}) or {} - + + if kwargs is not None: + realtime_tools = kwargs.get("realtime_tools") + if realtime_tools: + _request_body = dict(_request_body) + _request_body["tools"] = realtime_tools + # Apply message redaction if turn_off_message_logging is enabled if kwargs is not None: from litellm.litellm_core_utils.redact_messages import ( @@ -796,7 +812,13 @@ def _get_response_for_spend_logs_payload( response_obj: Any = payload.get("response") if response_obj is None: return "{}" - + + if kwargs is not None: + realtime_tool_calls = kwargs.get("realtime_tool_calls") + if realtime_tool_calls and isinstance(response_obj, dict): + response_obj = dict(response_obj) + response_obj["tool_calls"] = 
realtime_tool_calls + # Apply message redaction if turn_off_message_logging is enabled if kwargs is not None: from litellm.litellm_core_utils.redact_messages import ( diff --git a/tests/test_litellm/litellm_core_utils/test_realtime_streaming.py b/tests/test_litellm/litellm_core_utils/test_realtime_streaming.py index 6b10369604..aaaab95ce6 100644 --- a/tests/test_litellm/litellm_core_utils/test_realtime_streaming.py +++ b/tests/test_litellm/litellm_core_utils/test_realtime_streaming.py @@ -66,6 +66,300 @@ def test_realtime_streaming_store_message(): assert len(streaming.messages) == 2 # Should not store the new message +def test_collect_user_input_from_text_conversation_item(): + """ + Test that conversation.item.create with input_text content is collected as user input. + """ + websocket = MagicMock() + backend_ws = MagicMock() + logging_obj = MagicMock() + streaming = RealTimeStreaming(websocket, backend_ws, logging_obj) + + msg = json.dumps({ + "type": "conversation.item.create", + "item": { + "role": "user", + "content": [ + {"type": "input_text", "text": "Hello, how are you?"} + ] + } + }) + streaming.store_input(msg) + + assert len(streaming.input_messages) == 1 + assert streaming.input_messages[0]["role"] == "user" + assert streaming.input_messages[0]["content"] == "Hello, how are you?" + + +def test_collect_user_input_from_session_update_instructions(): + """ + Test that session.update with instructions is collected as system input. + """ + websocket = MagicMock() + backend_ws = MagicMock() + logging_obj = MagicMock() + streaming = RealTimeStreaming(websocket, backend_ws, logging_obj) + + msg = json.dumps({ + "type": "session.update", + "session": { + "instructions": "You are a helpful assistant." + } + }) + streaming.store_input(msg) + + assert len(streaming.input_messages) == 1 + assert streaming.input_messages[0]["role"] == "system" + assert streaming.input_messages[0]["content"] == "You are a helpful assistant." 
+ + +def test_collect_user_input_from_transcription_event(): + """ + Test that conversation.item.input_audio_transcription.completed events + are collected as user input from backend events. + """ + websocket = MagicMock() + backend_ws = MagicMock() + logging_obj = MagicMock() + streaming = RealTimeStreaming(websocket, backend_ws, logging_obj) + + event_obj = { + "type": "conversation.item.input_audio_transcription.completed", + "transcript": "What is the weather today?", + "item_id": "item_123", + } + streaming._collect_user_input_from_backend_event(event_obj) + + assert len(streaming.input_messages) == 1 + assert streaming.input_messages[0]["role"] == "user" + assert streaming.input_messages[0]["content"] == "What is the weather today?" + + +def test_collect_user_input_ignores_irrelevant_events(): + """ + Test that irrelevant client events don't get collected as user input. + """ + websocket = MagicMock() + backend_ws = MagicMock() + logging_obj = MagicMock() + streaming = RealTimeStreaming(websocket, backend_ws, logging_obj) + + # input_audio_buffer.append should not be collected + msg = json.dumps({"type": "input_audio_buffer.append", "audio": "base64data"}) + streaming.store_input(msg) + assert len(streaming.input_messages) == 0 + + # response.create should not be collected + msg = json.dumps({"type": "response.create"}) + streaming.store_input(msg) + assert len(streaming.input_messages) == 0 + + +def test_collect_user_input_empty_transcript_not_collected(): + """ + Test that transcription events with empty transcripts are not collected. 
+ """ + websocket = MagicMock() + backend_ws = MagicMock() + logging_obj = MagicMock() + streaming = RealTimeStreaming(websocket, backend_ws, logging_obj) + + event_obj = { + "type": "conversation.item.input_audio_transcription.completed", + "transcript": "", + "item_id": "item_123", + } + streaming._collect_user_input_from_backend_event(event_obj) + assert len(streaming.input_messages) == 0 + + +@pytest.mark.asyncio +async def test_log_messages_sets_input_messages_on_logging_obj(): + """ + Test that log_messages() sets input_messages on the logging object's model_call_details. + """ + websocket = MagicMock() + backend_ws = MagicMock() + logging_obj = MagicMock() + logging_obj.model_call_details = {"messages": "default-message-value"} + logging_obj.async_success_handler = AsyncMock() + logging_obj.success_handler = MagicMock() + streaming = RealTimeStreaming(websocket, backend_ws, logging_obj) + + streaming.input_messages = [ + {"role": "user", "content": "Hello from voice"}, + {"role": "user", "content": "Tell me a joke"}, + ] + + await streaming.log_messages() + + assert logging_obj.model_call_details["messages"] == [ + {"role": "user", "content": "Hello from voice"}, + {"role": "user", "content": "Tell me a joke"}, + ] + + +@pytest.mark.asyncio +async def test_transcription_captured_in_backend_to_client(): + """ + Test that conversation.item.input_audio_transcription.completed events + from the backend are captured as user input during the WebSocket session. 
+ """ + import litellm + + client_ws = MagicMock() + client_ws.send_text = AsyncMock() + + transcript_event = json.dumps({ + "type": "conversation.item.input_audio_transcription.completed", + "transcript": "What are the opening hours?", + "item_id": "item_789", + }).encode() + + backend_ws = MagicMock() + backend_ws.recv = AsyncMock( + side_effect=[ + transcript_event, + ConnectionClosed(None, None), + ] + ) + backend_ws.send = AsyncMock() + + logging_obj = MagicMock() + logging_obj.model_call_details = {"messages": "default-message-value"} + logging_obj.async_success_handler = AsyncMock() + logging_obj.success_handler = MagicMock() + streaming = RealTimeStreaming(client_ws, backend_ws, logging_obj) + await streaming.backend_to_client_send_messages() + + assert len(streaming.input_messages) == 1 + assert streaming.input_messages[0]["role"] == "user" + assert streaming.input_messages[0]["content"] == "What are the opening hours?" + assert logging_obj.model_call_details["messages"] == streaming.input_messages + + +def test_collect_session_tools_from_session_update(): + """ + Test that tools from session.update events are collected. + """ + websocket = MagicMock() + backend_ws = MagicMock() + logging_obj = MagicMock() + streaming = RealTimeStreaming(websocket, backend_ws, logging_obj) + + msg = json.dumps({ + "type": "session.update", + "session": { + "tools": [ + { + "type": "function", + "name": "get_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": {"location": {"type": "string"}}, + "required": ["location"], + }, + } + ], + "instructions": "You are a weather assistant." 
+ } + }) + streaming.store_input(msg) + + assert len(streaming.session_tools) == 1 + assert streaming.session_tools[0]["name"] == "get_weather" + assert len(streaming.input_messages) == 1 + assert streaming.input_messages[0]["role"] == "system" + + +def test_collect_tool_calls_from_response_done(): + """ + Test that function_call items are extracted from response.done events. + """ + websocket = MagicMock() + backend_ws = MagicMock() + logging_obj = MagicMock() + streaming = RealTimeStreaming(websocket, backend_ws, logging_obj) + streaming.logged_real_time_event_types = "*" + + response_done = json.dumps({ + "type": "response.done", + "event_id": "evt_123", + "response": { + "output": [ + { + "type": "function_call", + "call_id": "call_abc123", + "name": "get_weather", + "arguments": '{"location": "Paris"}', + } + ] + } + }) + streaming.store_message(response_done) + + assert len(streaming.tool_calls) == 1 + assert streaming.tool_calls[0]["id"] == "call_abc123" + assert streaming.tool_calls[0]["type"] == "function" + assert streaming.tool_calls[0]["function"]["name"] == "get_weather" + assert streaming.tool_calls[0]["function"]["arguments"] == '{"location": "Paris"}' + + +def test_tool_calls_not_collected_from_non_function_call_output(): + """ + Test that non-function_call output items in response.done are not collected. 
+ """ + websocket = MagicMock() + backend_ws = MagicMock() + logging_obj = MagicMock() + streaming = RealTimeStreaming(websocket, backend_ws, logging_obj) + streaming.logged_real_time_event_types = "*" + + response_done = json.dumps({ + "type": "response.done", + "event_id": "evt_456", + "response": { + "output": [ + { + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "Hello!"}] + } + ] + } + }) + streaming.store_message(response_done) + + assert len(streaming.tool_calls) == 0 + + +@pytest.mark.asyncio +async def test_log_messages_includes_tools_in_model_call_details(): + """ + Test that log_messages() sets session_tools and tool_calls on the logging object. + """ + websocket = MagicMock() + backend_ws = MagicMock() + logging_obj = MagicMock() + logging_obj.model_call_details = {"messages": "default-message-value"} + logging_obj.async_success_handler = AsyncMock() + logging_obj.success_handler = MagicMock() + streaming = RealTimeStreaming(websocket, backend_ws, logging_obj) + + streaming.session_tools = [ + {"type": "function", "name": "get_weather", "description": "Get weather"} + ] + streaming.tool_calls = [ + {"id": "call_1", "type": "function", "function": {"name": "get_weather", "arguments": '{"location": "Paris"}'}} + ] + + await streaming.log_messages() + + assert logging_obj.model_call_details["realtime_tools"] == streaming.session_tools + assert logging_obj.model_call_details["realtime_tool_calls"] == streaming.tool_calls + + @pytest.mark.asyncio async def test_realtime_guardrail_blocks_prompt_injection(): """ diff --git a/tests/test_litellm/proxy/spend_tracking/test_spend_tracking_utils.py b/tests/test_litellm/proxy/spend_tracking/test_spend_tracking_utils.py index 06a544e7f4..30257cd2da 100644 --- a/tests/test_litellm/proxy/spend_tracking/test_spend_tracking_utils.py +++ b/tests/test_litellm/proxy/spend_tracking/test_spend_tracking_utils.py @@ -19,6 +19,7 @@ import litellm from litellm.constants import 
LITELLM_TRUNCATED_PAYLOAD_FIELD, REDACTED_BY_LITELM_STRING from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.proxy.spend_tracking.spend_tracking_utils import ( + _get_messages_for_spend_logs_payload, _get_proxy_server_request_for_spend_logs_payload, _get_request_duration_ms, _get_response_for_spend_logs_payload, @@ -246,6 +247,75 @@ def test_get_vector_store_request_for_spend_logs_payload_null_input(mock_should_ assert result is None +@patch( + "litellm.proxy.spend_tracking.spend_tracking_utils._should_store_prompts_and_responses_in_spend_logs" +) +def test_get_messages_for_spend_logs_realtime_returns_messages(mock_should_store): + """ + Test that _get_messages_for_spend_logs_payload returns messages + for realtime calls when store_prompts_in_spend_logs is True. + """ + mock_should_store.return_value = True + realtime_messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the weather today?"}, + ] + payload = cast( + StandardLoggingPayload, + { + "call_type": "_arealtime", + "messages": realtime_messages, + }, + ) + result = _get_messages_for_spend_logs_payload(payload) + parsed = json.loads(result) + assert len(parsed) == 2 + assert parsed[0]["role"] == "system" + assert parsed[0]["content"] == "You are a helpful assistant." + assert parsed[1]["role"] == "user" + assert parsed[1]["content"] == "What is the weather today?" + + +@patch( + "litellm.proxy.spend_tracking.spend_tracking_utils._should_store_prompts_and_responses_in_spend_logs" +) +def test_get_messages_for_spend_logs_realtime_empty_when_disabled(mock_should_store): + """ + Test that _get_messages_for_spend_logs_payload returns '{}' for realtime calls + when store_prompts_in_spend_logs is False. 
+ """ + mock_should_store.return_value = False + payload = cast( + StandardLoggingPayload, + { + "call_type": "_arealtime", + "messages": [{"role": "user", "content": "Hello"}], + }, + ) + result = _get_messages_for_spend_logs_payload(payload) + assert result == "{}" + + +@patch( + "litellm.proxy.spend_tracking.spend_tracking_utils._should_store_prompts_and_responses_in_spend_logs" +) +def test_get_messages_for_spend_logs_non_realtime_returns_empty(mock_should_store): + """ + Test that _get_messages_for_spend_logs_payload returns '{}' for non-realtime + calls even when store_prompts_in_spend_logs is True. + """ + mock_should_store.return_value = True + payload = cast( + StandardLoggingPayload, + { + "call_type": "acompletion", + "messages": [{"role": "user", "content": "Hello"}], + }, + ) + result = _get_messages_for_spend_logs_payload(payload) + assert result == "{}" + + @patch( "litellm.proxy.spend_tracking.spend_tracking_utils._should_store_prompts_and_responses_in_spend_logs" ) diff --git a/ui/litellm-dashboard/src/components/ToolPolicies.tsx b/ui/litellm-dashboard/src/components/ToolPolicies.tsx index 860093cead..82785496ea 100644 --- a/ui/litellm-dashboard/src/components/ToolPolicies.tsx +++ b/ui/litellm-dashboard/src/components/ToolPolicies.tsx @@ -2,7 +2,19 @@ import React, { useCallback, useDeferredValue, useEffect, useState } from "react"; import { Select, Switch, Tooltip } from "antd"; +<<<<<<< cursor/development-environment-setup-13a7 +// @ts-ignore - duplicate import removed +import { + Table, + TableHead, + TableHeaderCell, + TableBody, + TableRow, + TableCell, +} from "@tremor/react"; +======= import { Table, TableHead, TableHeaderCell, TableBody, TableRow, TableCell } from "@tremor/react"; +>>>>>>> main import { TimeCell } from "./view_logs/time_cell"; import { TableHeaderSortDropdown } from "./common_components/TableHeaderSortDropdown/TableHeaderSortDropdown"; import type { SortState } from 
"./common_components/TableHeaderSortDropdown/TableHeaderSortDropdown"; @@ -48,6 +60,21 @@ const PolicySelect: React.FC<{ minWidth: 110, fontWeight: 500, }} +<<<<<<< cursor/development-environment-setup-13a7 + {...{styles: { + selector: { + backgroundColor: style.bg, + borderColor: style.border, + color: style.color, + borderRadius: 999, + fontSize: 11, + fontWeight: 600, + paddingLeft: 8, + paddingRight: 4, + }, + }} as any} +======= +>>>>>>> main popupMatchSelectWidth={false} options={POLICY_OPTIONS.map((o) => ({ value: o.value, diff --git a/ui/litellm-dashboard/src/components/organisms/create_key_button.tsx b/ui/litellm-dashboard/src/components/organisms/create_key_button.tsx index 378026e460..961a5d4d46 100644 --- a/ui/litellm-dashboard/src/components/organisms/create_key_button.tsx +++ b/ui/litellm-dashboard/src/components/organisms/create_key_button.tsx @@ -298,7 +298,7 @@ const CreateKey: React.FC = ({ team, teams, data, addKey }) => { formValues.user_id = userID; } else if (keyOwner === "agent") { if (!selectedAgentId) { - NotificationsManager.error("Please select an agent"); + NotificationsManager.fromBackend("Please select an agent"); return; } formValues.agent_id = selectedAgentId; diff --git a/ui/litellm-dashboard/src/components/playground/chat_ui/ChatUI.tsx b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatUI.tsx index 4a47cc00b2..3e165c7065 100644 --- a/ui/litellm-dashboard/src/components/playground/chat_ui/ChatUI.tsx +++ b/ui/litellm-dashboard/src/components/playground/chat_ui/ChatUI.tsx @@ -68,6 +68,7 @@ import ResponsesImageUpload from "./ResponsesImageUpload"; import { createDisplayMessage, createMultimodalMessage } from "./ResponsesImageUtils"; import { SearchResultsDisplay } from "./SearchResultsDisplay"; import SessionManagement from "./SessionManagement"; +import RealtimePlayground from "./RealtimePlayground"; import { A2ATaskMetadata, MessageType } from "./types"; import { useCodeInterpreter } from "./useCodeInterpreter"; @@ -1826,6 
+1827,14 @@ const ChatUI: React.FC = ({ {/* Main Chat Area */}
+ {endpointType === EndpointType.REALTIME ? ( + + ) : ( + <>
{simplified ? "Chat" : "Test Key"}
@@ -2422,6 +2431,8 @@ const ChatUI: React.FC = ({ )}
+ + )}
diff --git a/ui/litellm-dashboard/src/components/playground/chat_ui/RealtimePlayground.tsx b/ui/litellm-dashboard/src/components/playground/chat_ui/RealtimePlayground.tsx new file mode 100644 index 0000000000..87305385bc --- /dev/null +++ b/ui/litellm-dashboard/src/components/playground/chat_ui/RealtimePlayground.tsx @@ -0,0 +1,459 @@ +"use client"; + +import { + AudioMutedOutlined, + AudioOutlined, + CloseCircleOutlined, + SendOutlined, + SoundOutlined, +} from "@ant-design/icons"; +import { Button, Input, Select, Typography } from "antd"; +import React, { useCallback, useEffect, useRef, useState } from "react"; +import { getProxyBaseUrl } from "../../networking"; +import { OPEN_AI_VOICE_SELECT_OPTIONS } from "./chatConstants"; + +const { Text } = Typography; + +interface RealtimeMessage { + role: "user" | "assistant" | "system" | "status"; + content: string; + timestamp: Date; +} + +interface RealtimePlaygroundProps { + accessToken: string; + selectedModel: string; + customProxyBaseUrl?: string; +} + +const RealtimePlayground: React.FC = ({ + accessToken, + selectedModel, + customProxyBaseUrl, +}) => { + const [messages, setMessages] = useState([]); + const [inputText, setInputText] = useState(""); + const [isConnected, setIsConnected] = useState(false); + const [isConnecting, setIsConnecting] = useState(false); + const [isRecording, setIsRecording] = useState(false); + const [selectedVoice, setSelectedVoice] = useState("alloy"); + const wsRef = useRef(null); + const audioContextRef = useRef(null); + const mediaStreamRef = useRef(null); + const processorRef = useRef(null); + const playbackQueueRef = useRef([]); + const isPlayingRef = useRef(false); + const messagesEndRef = useRef(null); + const nextPlayTimeRef = useRef(0); + + const scrollToBottom = useCallback(() => { + messagesEndRef.current?.scrollIntoView({ behavior: "smooth" }); + }, []); + + useEffect(() => { + scrollToBottom(); + }, [messages, scrollToBottom]); + + const addMessage = useCallback( + (role: 
RealtimeMessage["role"], content: string) => { + setMessages((prev) => [...prev, { role, content, timestamp: new Date() }]); + }, + [] + ); + + const appendAssistantText = useCallback((text: string) => { + setMessages((prev) => { + const last = prev[prev.length - 1]; + if (last && last.role === "assistant") { + return [...prev.slice(0, -1), { ...last, content: last.content + text }]; + } + return [...prev, { role: "assistant", content: text, timestamp: new Date() }]; + }); + }, []); + + const playAudioChunk = useCallback((base64Audio: string) => { + const raw = atob(base64Audio); + const bytes = new Uint8Array(raw.length); + for (let i = 0; i < raw.length; i++) bytes[i] = raw.charCodeAt(i); + const pcm16 = new Int16Array(bytes.buffer); + const float32 = new Float32Array(pcm16.length); + for (let i = 0; i < pcm16.length; i++) float32[i] = pcm16[i] / 32768; + + const ctx = audioContextRef.current; + if (!ctx) return; + + const buffer = ctx.createBuffer(1, float32.length, 24000); + buffer.getChannelData(0).set(float32); + const source = ctx.createBufferSource(); + source.buffer = buffer; + source.connect(ctx.destination); + + const now = ctx.currentTime; + const startTime = Math.max(now, nextPlayTimeRef.current); + source.start(startTime); + nextPlayTimeRef.current = startTime + buffer.duration; + }, []); + + const connect = useCallback(async () => { + if (wsRef.current) return; + if (!selectedModel) { + addMessage("status", "Please select a model first"); + return; + } + setIsConnecting(true); + + try { + audioContextRef.current = new AudioContext({ sampleRate: 24000 }); + + const baseUrl = customProxyBaseUrl || getProxyBaseUrl(); + const wsBase = baseUrl.replace(/^http/, "ws"); + const url = `${wsBase}/v1/realtime?model=${encodeURIComponent(selectedModel)}`; + + const ws = new WebSocket(url, ["realtime", `openai-insecure-api-key.${accessToken}`]); + + ws.onopen = () => { + setIsConnected(true); + setIsConnecting(false); + addMessage("status", "Connected to realtime 
API"); + }; + + ws.onmessage = async (event) => { + try { + let raw = event.data; + if (raw instanceof Blob) { + raw = await raw.text(); + } else if (raw instanceof ArrayBuffer) { + raw = new TextDecoder().decode(raw); + } + const data = JSON.parse(raw); + const type = data.type; + + if (type === "session.created") { + ws.send( + JSON.stringify({ + type: "session.update", + session: { + modalities: ["text", "audio"], + voice: selectedVoice, + input_audio_format: "pcm16", + output_audio_format: "pcm16", + input_audio_transcription: { model: "gpt-4o-mini-transcribe" }, + turn_detection: null, + }, + }) + ); + } else if (type === "session.updated") { + // session configured + } else if (type === "response.audio.delta") { + if (data.delta) playAudioChunk(data.delta); + } else if (type === "response.audio_transcript.delta" || type === "response.text.delta") { + if (data.delta) appendAssistantText(data.delta); + } else if ( + type === "conversation.item.input_audio_transcription.completed" + ) { + if (data.transcript) addMessage("user", data.transcript); + } else if (type === "response.done") { + // Ensure we have the full text if deltas were missed + setMessages((prev) => { + const last = prev[prev.length - 1]; + if (last && last.role === "assistant" && last.content) return prev; + // No assistant message yet — extract from response.done + const output = data.response?.output || []; + const texts: string[] = []; + for (const item of output) { + for (const c of item.content || []) { + const t = c.text || c.transcript; + if (t) texts.push(t); + } + } + if (texts.length > 0) { + return [...prev, { role: "assistant" as const, content: texts.join(""), timestamp: new Date() }]; + } + return prev; + }); + } else if (type === "error") { + addMessage("status", `Error: ${data.error?.message || JSON.stringify(data.error)}`); + } + } catch { + // ignore parse errors + } + }; + + ws.onerror = () => { + addMessage("status", "WebSocket error"); + setIsConnected(false); + 
setIsConnecting(false);
+      };
+
+      ws.onclose = () => {
+        addMessage("status", "Disconnected");
+        setIsConnected(false);
+        setIsConnecting(false);
+        wsRef.current = null;
+      };
+
+      wsRef.current = ws;
+    } catch (err: any) {
+      addMessage("status", `Connection failed: ${err.message}`);
+      setIsConnecting(false);
+    }
+  }, [accessToken, selectedModel, selectedVoice, customProxyBaseUrl, addMessage, appendAssistantText, playAudioChunk]);
+
+  // Stop recording, close the socket and the audio context, and reset the
+  // per-connection refs/state.
+  // `stopRecording` and `configureSessionRef` are declared further down; they are
+  // only dereferenced when this callback runs, so they are deliberately kept out
+  // of the dependency array (which is evaluated during render, before those
+  // `const` bindings are initialised).
+  const disconnect = useCallback(() => {
+    stopRecording();
+    wsRef.current?.close();
+    wsRef.current = null;
+    audioContextRef.current?.close();
+    audioContextRef.current = null;
+    nextPlayTimeRef.current = 0;
+    configureSessionRef.current = false;
+    setIsConnected(false);
+  }, []);
+
+  // Switch the session to server VAD, capture the microphone and stream it to
+  // the socket as base64-encoded PCM16 `input_audio_buffer.append` events.
+  // No-op when the socket is not open.
+  const startRecording = useCallback(async () => {
+    if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) return;
+
+    // Switch to server VAD mode for voice input
+    wsRef.current.send(
+      JSON.stringify({
+        type: "session.update",
+        session: {
+          modalities: ["text", "audio"],
+          voice: selectedVoice,
+          input_audio_format: "pcm16",
+          output_audio_format: "pcm16",
+          input_audio_transcription: { model: "gpt-4o-mini-transcribe" },
+          turn_detection: { type: "server_vad" },
+        },
+      })
+    );
+
+    try {
+      const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
+      mediaStreamRef.current = stream;
+
+      // Reuse the existing context when there is one; otherwise ask for 24 kHz.
+      const ctx = audioContextRef.current || new AudioContext({ sampleRate: 24000 });
+      audioContextRef.current = ctx;
+
+      const source = ctx.createMediaStreamSource(stream);
+      const processor = ctx.createScriptProcessor(4096, 1, 1);
+      processorRef.current = processor;
+
+      processor.onaudioprocess = (e) => {
+        if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) return;
+        const input = e.inputBuffer.getChannelData(0);
+
+        // Resample to 24kHz if needed (nearest-sample pick, no filtering)
+        const sampleRate = ctx.sampleRate;
+        const targetRate = 24000;
+        let samples: Float32Array;
+        if (sampleRate !== targetRate) {
+          const ratio = sampleRate / targetRate;
+          const newLength = Math.round(input.length / ratio);
+          samples = new Float32Array(newLength);
+          for (let i = 0; i < newLength; i++) {
+            // Clamp: rounding can land one past the last input sample, which
+            // would otherwise be replaced by a spurious 0.
+            const srcIndex = Math.min(input.length - 1, Math.round(i * ratio));
+            samples[i] = input[srcIndex] || 0;
+          }
+        } else {
+          samples = input;
+        }
+
+        // Convert to PCM16
+        const pcm16 = new Int16Array(samples.length);
+        for (let i = 0; i < samples.length; i++) {
+          const s = Math.max(-1, Math.min(1, samples[i]));
+          pcm16[i] = s < 0 ? s * 0x8000 : s * 0x7fff;
+        }
+
+        // Base64 encode and send
+        const bytes = new Uint8Array(pcm16.buffer);
+        let binary = "";
+        for (let i = 0; i < bytes.length; i++) binary += String.fromCharCode(bytes[i]);
+        const b64 = btoa(binary);
+
+        wsRef.current!.send(
+          JSON.stringify({ type: "input_audio_buffer.append", audio: b64 })
+        );
+      };
+
+      source.connect(processor);
+      processor.connect(ctx.destination);
+      setIsRecording(true);
+      addMessage("status", "🎙️ Listening...");
+    } catch (err: any) {
+      addMessage("status", `Microphone error: ${err.message}`);
+    }
+    // `selectedVoice` is read in the session.update above, so it must be a
+    // dependency; otherwise the callback keeps sending the voice it first saw.
+  }, [addMessage, selectedVoice]);
+
+  // Detach the audio processor, stop every microphone track and clear the
+  // recording flag.
+  const stopRecording = useCallback(() => {
+    processorRef.current?.disconnect();
+    processorRef.current = null;
+    mediaStreamRef.current?.getTracks().forEach((t) => t.stop());
+    mediaStreamRef.current = null;
+    setIsRecording(false);
+  }, []);
+
+  // Tracks whether the text-mode session.update has already been sent on the
+  // current connection (reset to false in `disconnect`).
+  const configureSessionRef = useRef(false);
+
+  // Send a one-time session.update with `turn_detection: null`.
+  // No-op when the socket is not open or the update was already sent.
+  const ensureTextSession = useCallback(() => {
+    if (!wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) return;
+    if (configureSessionRef.current) return;
+    configureSessionRef.current = true;
+    wsRef.current.send(
+      JSON.stringify({
+        type: "session.update",
+        session: {
+          modalities: ["text", "audio"],
+          voice: selectedVoice,
+          input_audio_format: "pcm16",
+          output_audio_format: "pcm16",
+          input_audio_transcription: { model: "gpt-4o-mini-transcribe" },
+          turn_detection: null,
+        },
+      })
+    );
+  }, [selectedVoice]);
+
+  // Echo the trimmed input into the transcript, clear the input box, then send
+  // it as a user `conversation.item.create` followed by `response.create`.
+  // No-op for blank input or when the socket is not open.
+  const sendTextMessage = useCallback(() => {
+    if (!inputText.trim() || !wsRef.current || wsRef.current.readyState !== WebSocket.OPEN) return;
+    const text = inputText.trim();
+    addMessage("user", text);
+    setInputText("");
+
+    wsRef.current.send(
+      JSON.stringify({
+        type: "conversation.item.create",
+        item: {
+          type: "message",
+          role: "user",
+          content: [{ type: "input_text", text }],
+        },
+      })
+    );
+    wsRef.current.send(JSON.stringify({ type: "response.create" }));
+    // NOTE(review): `ensureTextSession` used to be listed here but is never
+    // called in this callback, so it only caused needless re-creation. Confirm
+    // whether a call to it before the sends above was intended.
+  }, [inputText, addMessage]);
+
+  // On unmount: close the socket and audio context and release the microphone.
+  useEffect(() => {
+    return () => {
+      wsRef.current?.close();
+      audioContextRef.current?.close();
+      mediaStreamRef.current?.getTracks().forEach((t) => t.stop());
+    };
+  }, []);
+
+  return (
+
+ {/* Header */} +
+
+ + Realtime Voice Chat + + + {isConnected ? "Connected" : isConnecting ? "Connecting..." : "Disconnected"} + +
+
+ setInputText(e.target.value)} + onPressEnter={sendTextMessage} + className="flex-1" + size="large" + /> +
+ {isRecording && ( +
+ + Listening — speak into your microphone. Server VAD will detect when you stop. +
+ )} +
+ )} +
+ ); +}; + +export default RealtimePlayground; diff --git a/ui/litellm-dashboard/src/components/playground/chat_ui/chatConstants.ts b/ui/litellm-dashboard/src/components/playground/chat_ui/chatConstants.ts index 4fd478711a..919e3bc1c6 100644 --- a/ui/litellm-dashboard/src/components/playground/chat_ui/chatConstants.ts +++ b/ui/litellm-dashboard/src/components/playground/chat_ui/chatConstants.ts @@ -44,4 +44,5 @@ export const ENDPOINT_OPTIONS = [ { value: EndpointType.TRANSCRIPTION, label: "/v1/audio/transcriptions" }, { value: EndpointType.A2A_AGENTS, label: "/v1/a2a/message/send" }, { value: EndpointType.MCP, label: "/mcp-rest/tools/call" }, + { value: EndpointType.REALTIME, label: "/v1/realtime" }, ]; diff --git a/ui/litellm-dashboard/src/components/playground/chat_ui/mode_endpoint_mapping.tsx b/ui/litellm-dashboard/src/components/playground/chat_ui/mode_endpoint_mapping.tsx index 6bd765b1bf..f354efe641 100644 --- a/ui/litellm-dashboard/src/components/playground/chat_ui/mode_endpoint_mapping.tsx +++ b/ui/litellm-dashboard/src/components/playground/chat_ui/mode_endpoint_mapping.tsx @@ -27,7 +27,7 @@ export enum EndpointType { TRANSCRIPTION = "transcription", A2A_AGENTS = "a2a_agents", MCP = "mcp", - // add additional endpoint types if required + REALTIME = "realtime", } // Create a mapping between the model mode and the corresponding endpoint type diff --git a/ui/litellm-dashboard/src/components/view_logs/ToolsSection/utils.test.ts b/ui/litellm-dashboard/src/components/view_logs/ToolsSection/utils.test.ts index 75f975e9a1..b3c118ef36 100644 --- a/ui/litellm-dashboard/src/components/view_logs/ToolsSection/utils.test.ts +++ b/ui/litellm-dashboard/src/components/view_logs/ToolsSection/utils.test.ts @@ -258,6 +258,83 @@ describe("ToolsSection utils", () => { expect(result[1].index).toBe(2); expect(result[2].index).toBe(3); }); + + it("should parse tool calls from realtime API response with tool_calls field", () => { + const log: Partial = { + request_id: 
"test-realtime-1", + proxy_server_request: { + tools: [ + { + type: "function", + name: "get_weather", + description: "Get current weather", + }, + ], + }, + response: { + results: [], + usage: {}, + tool_calls: [ + { + id: "call_abc", + type: "function", + function: { + name: "get_weather", + arguments: '{"location": "Paris"}', + }, + }, + ], + }, + } as any; + + const result = parseToolsFromLog(log as LogEntry); + + expect(result).toHaveLength(1); + expect(result[0].name).toBe("get_weather"); + expect(result[0].called).toBe(true); + expect(result[0].callData?.arguments).toEqual({ location: "Paris" }); + }); + + it("should parse tool calls from realtime response.done events", () => { + const log: Partial = { + request_id: "test-realtime-2", + proxy_server_request: { + tools: [ + { + type: "function", + name: "get_weather", + description: "Get current weather", + }, + ], + }, + response: { + results: [ + { type: "session.created", session: {} }, + { + type: "response.done", + response: { + output: [ + { + type: "function_call", + call_id: "call_xyz", + name: "get_weather", + arguments: '{"location": "Tokyo"}', + }, + ], + }, + }, + ], + usage: {}, + }, + } as any; + + const result = parseToolsFromLog(log as LogEntry); + + expect(result).toHaveLength(1); + expect(result[0].name).toBe("get_weather"); + expect(result[0].called).toBe(true); + expect(result[0].callData?.arguments).toEqual({ location: "Tokyo" }); + }); }); describe("hasTools", () => { diff --git a/ui/litellm-dashboard/src/components/view_logs/ToolsSection/utils.ts b/ui/litellm-dashboard/src/components/view_logs/ToolsSection/utils.ts index 351bbf5116..d0fd662efe 100644 --- a/ui/litellm-dashboard/src/components/view_logs/ToolsSection/utils.ts +++ b/ui/litellm-dashboard/src/components/view_logs/ToolsSection/utils.ts @@ -77,6 +77,33 @@ function extractToolCallsFromResponse(log: LogEntry): ToolCall[] { } } + // Realtime API format: response.tool_calls (added by spend tracking for realtime calls) + if 
(Array.isArray(responseData.tool_calls)) { + return responseData.tool_calls; + } + + // Realtime API format: response.results[].response.output[].type === "function_call" + if (Array.isArray(responseData.results)) { + const toolCalls: ToolCall[] = []; + for (const result of responseData.results) { + if (result.type === "response.done" && result.response?.output) { + for (const item of result.response.output) { + if (item.type === "function_call") { + toolCalls.push({ + id: item.call_id || "", + type: "function", + function: { + name: item.name || "", + arguments: item.arguments || "{}", + }, + }); + } + } + } + } + if (toolCalls.length > 0) return toolCalls; + } + return []; }