From 000913fa12ec41916c04ab1107cd31d196c3fb0f Mon Sep 17 00:00:00 2001 From: drorIvry Date: Wed, 7 Jan 2026 13:53:12 +0200 Subject: [PATCH 1/9] Hotfix - docs qualifire (#18724) * Hotfix - docs qualifire * Hotfix - docs qualifire * Hotfix - docs qualifire * Hotfix - docs qualifire * Hotfix - docs qualifire * Hotfix - docs qualifire * Hotfix - docs qualifire --- .../docs/proxy/guardrails/qualifire.md | 41 +- docs/my-website/sidebars.js | 1 + .../guardrail_hooks/qualifire/qualifire.py | 244 ++++++---- .../guardrail_hooks/test_qualifire.py | 433 +++++++++++++----- 4 files changed, 475 insertions(+), 244 deletions(-) diff --git a/docs/my-website/docs/proxy/guardrails/qualifire.md b/docs/my-website/docs/proxy/guardrails/qualifire.md index 66961c92d9..850af37e47 100644 --- a/docs/my-website/docs/proxy/guardrails/qualifire.md +++ b/docs/my-website/docs/proxy/guardrails/qualifire.md @@ -8,13 +8,7 @@ Use [Qualifire](https://qualifire.ai) to evaluate LLM outputs for quality, safet ## Quick Start -### 1. Install the Qualifire SDK - -```bash -pip install qualifire -``` - -### 2. Define Guardrails on your LiteLLM config.yaml +### 1. Define Guardrails on your LiteLLM config.yaml Define your guardrails under the `guardrails` section: @@ -61,13 +55,13 @@ guardrails: - `post_call` Run **after** LLM call, on **input & output** - `during_call` Run **during** LLM call, on **input**. Same as `pre_call` but runs in parallel as LLM call. Response not returned until guardrail check completes -### 3. Start LiteLLM Gateway +### 2. Start LiteLLM Gateway ```shell litellm --config config.yaml --detailed_debug ``` -### 4. Test request +### 3. Test request **[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)** @@ -142,7 +136,7 @@ guardrails: evaluation_id: eval_abc123 # Your evaluation ID from Qualifire dashboard ``` -When `evaluation_id` is provided, LiteLLM will use `invoke_evaluation()` instead of `evaluate()`, running the pre-configured evaluation from your dashboard. +When `evaluation_id` is provided, LiteLLM will use the invoke evaluation API endpoint instead of the evaluate endpoint, running the pre-configured evaluation from your dashboard. ## Available Checks @@ -213,19 +207,19 @@ guardrails: ### Parameter Reference -| Parameter | Type | Default | Description | -| ------------------------------ | ----------- | --------------------------- | -------------------------------------------------------- | -| `api_key` | `str` | `QUALIFIRE_API_KEY` env var | Your Qualifire API key | -| `api_base` | `str` | `None` | Custom API base URL (optional) | -| `evaluation_id` | `str` | `None` | Pre-configured evaluation ID from Qualifire dashboard | -| `prompt_injections` | `bool` | `true` (if no other checks) | Enable prompt injection detection | -| `hallucinations_check` | `bool` | `None` | Enable hallucination detection | -| `grounding_check` | `bool` | `None` | Enable grounding verification | -| `pii_check` | `bool` | `None` | Enable PII detection | -| `content_moderation_check` | `bool` | `None` | Enable content moderation | -| `tool_selection_quality_check` | `bool` | `None` | Enable tool selection quality check | -| `assertions` | `List[str]` | `None` | Custom assertions to validate | -| `on_flagged` | `str` | `"block"` | Action when content is flagged: `"block"` or `"monitor"` | +| Parameter | Type | Default | Description | +| ------------------------------ | ----------- | ---------------------------- | -------------------------------------------------------- | +| `api_key` | `str` | `QUALIFIRE_API_KEY` env var | Your Qualifire API key | +| `api_base` | `str` | `https://proxy.qualifire.ai` | Custom API base URL (optional) | +| `evaluation_id` | `str` | `None` | Pre-configured evaluation ID from Qualifire dashboard | +| `prompt_injections` | `bool` | `true` (if no other checks) | Enable prompt injection detection | +| `hallucinations_check` | `bool` | `None` | Enable hallucination detection | +| `grounding_check` | `bool` | `None` | Enable grounding verification | +| `pii_check` | `bool` | `None` | Enable PII detection | +| `content_moderation_check` | `bool` | `None` | Enable content moderation | +| `tool_selection_quality_check` | `bool` | `None` | Enable tool selection quality check | +| `assertions` | `List[str]` | `None` | Custom assertions to validate | +| `on_flagged` | `str` | `"block"` | Action when content is flagged: `"block"` or `"monitor"` | ### Default Behavior @@ -261,4 +255,3 @@ This evaluates whether the LLM selected the appropriate tools and provided corre - [Qualifire Documentation](https://docs.qualifire.ai) - [Qualifire Dashboard](https://app.qualifire.ai) -- [Qualifire Python SDK](https://github.com/qualifire-dev/qualifire-python-sdk) diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 5d2f096156..004132ca05 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -55,6 +55,7 @@ const sidebars = { "proxy/guardrails/test_playground", "proxy/guardrails/litellm_content_filter", ...[ + "proxy/guardrails/qualifire", "proxy/guardrails/aim_security", "proxy/guardrails/onyx_security", "proxy/guardrails/aporia_api", diff --git a/litellm/proxy/guardrails/guardrail_hooks/qualifire/qualifire.py b/litellm/proxy/guardrails/guardrail_hooks/qualifire/qualifire.py index a6971b49f3..87da11efad 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/qualifire/qualifire.py +++ b/litellm/proxy/guardrails/guardrail_hooks/qualifire/qualifire.py @@ -5,6 +5,7 @@ # +-------------------------------------------------------------+ # Qualifire - Evaluate LLM outputs for quality, safety, and reliability +import json import os from typing import Any, Dict, List, Literal, Optional, Type @@ -15,12 +16,17 @@ from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.litellm_core_utils.litellm_logging import ( Logging as LiteLLMLoggingObj, ) +from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, +) from litellm.secret_managers.main import get_secret_str from litellm.types.llms.openai import AllMessageValues from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel from litellm.types.utils import GenericGuardrailAPIInputs GUARDRAIL_NAME = "qualifire" +DEFAULT_QUALIFIRE_API_BASE = "https://proxy.qualifire.ai" class QualifireGuardrail(CustomGuardrail): @@ -44,7 +50,7 @@ class QualifireGuardrail(CustomGuardrail): Args: api_key: API key for Qualifire (or use QUALIFIRE_API_KEY env var) - api_base: Optional custom API base URL + api_base: Optional custom API base URL (defaults to https://api.qualifire.ai) evaluation_id: Pre-configured evaluation ID from Qualifire dashboard prompt_injections: Enable prompt injection detection (default if no other checks) hallucinations_check: Enable hallucination detection @@ -64,6 +70,7 @@ class QualifireGuardrail(CustomGuardrail): api_base or get_secret_str("QUALIFIRE_BASE_URL") or os.environ.get("QUALIFIRE_BASE_URL") + or DEFAULT_QUALIFIRE_API_BASE ) self.evaluation_id = evaluation_id self.prompt_injections = prompt_injections @@ -79,7 +86,11 @@ class QualifireGuardrail(CustomGuardrail): if not self._has_any_check_enabled() and not self.evaluation_id: self.prompt_injections = True - self._client = None + # Initialize async HTTP client for direct API calls + self.async_handler = get_async_httpx_client( + llm_provider=httpxSpecialProvider.GuardrailCallback + ) + super().__init__(**kwargs) def _has_any_check_enabled(self) -> bool: @@ -96,43 +107,22 @@ class QualifireGuardrail(CustomGuardrail): ] ) - def _get_client(self): - """Lazy initialization of Qualifire client.""" - if self._client is None: - try: - from qualifire.client import Client - except ImportError: - raise ImportError( - "qualifire package is required for QualifireGuardrail. " - "Install it with: pip install qualifire" - ) - - client_kwargs: Dict[str, Any] = {} - if self.qualifire_api_key: - client_kwargs["api_key"] = self.qualifire_api_key - if self.qualifire_api_base: - client_kwargs["base_url"] = self.qualifire_api_base - - self._client = Client(**client_kwargs) - - return self._client - - def _convert_messages_to_qualifire_format( + def _convert_messages_to_api_format( self, messages: List[AllMessageValues] - ) -> List[Any]: + ) -> List[Dict[str, Any]]: """ - Convert LiteLLM messages to Qualifire's LLMMessage format. + Convert LiteLLM messages to Qualifire API format. Supports tool calls for tool_selection_quality_check. - """ - try: - from qualifire.types import LLMMessage, LLMToolCall - except ImportError: - raise ImportError( - "qualifire package is required for QualifireGuardrail. " - "Install it with: pip install qualifire" - ) - qualifire_messages = [] + Returns a list of dicts matching the API's ModelInvocationCanonicalMessage schema: + { + "role": "user" | "assistant" | "system" | "tool", + "content": "...", + "tool_call_id": "...", # optional + "tool_calls": [{"id": "...", "name": "...", "arguments": {...}}] # optional + } + """ + api_messages = [] for msg in messages: role = msg.get("role", "user") content = msg.get("content", "") @@ -147,42 +137,86 @@ class QualifireGuardrail(CustomGuardrail): text_parts.append(part) content = "\n".join(text_parts) - llm_message_kwargs: Dict[str, Any] = { + api_message: Dict[str, Any] = { "role": role, "content": content if isinstance(content, str) else str(content), } + # Handle tool_call_id for tool response messages + tool_call_id = msg.get("tool_call_id") + if tool_call_id: + api_message["tool_call_id"] = tool_call_id + # Handle tool calls if present tool_calls = msg.get("tool_calls") if tool_calls and isinstance(tool_calls, list): - qualifire_tool_calls = [] + api_tool_calls = [] for tc in tool_calls: if isinstance(tc, dict): function_info = tc.get("function", {}) # Arguments can be a string (JSON) or dict args = function_info.get("arguments", {}) if isinstance(args, str): - import json - try: args = json.loads(args) except json.JSONDecodeError: args = {} - qualifire_tool_calls.append( - LLMToolCall( - id=tc.get("id") or "", - name=function_info.get("name") or "", - arguments=args if isinstance(args, dict) else {}, - ) + api_tool_calls.append( + { + "id": tc.get("id") or "", + "name": function_info.get("name") or "", + "arguments": args if isinstance(args, dict) else {}, + } ) - if qualifire_tool_calls: - llm_message_kwargs["tool_calls"] = qualifire_tool_calls + if api_tool_calls: + api_message["tool_calls"] = api_tool_calls - qualifire_messages.append(LLMMessage(**llm_message_kwargs)) + api_messages.append(api_message) - return qualifire_messages + return api_messages - def _check_if_flagged(self, result: Any) -> bool: + def _convert_tools_to_api_format( + self, tools: Optional[List[Any]] + ) -> Optional[List[Dict[str, Any]]]: + """ + Convert OpenAI-format tools to Qualifire API format. + + Returns a list of dicts matching the API's ModelInvocationToolDefinition schema: + { + "name": "...", + "description": "...", + "parameters": {...} + } + """ + if not tools: + return None + + api_tools = [] + for tool in tools: + if isinstance(tool, dict): + # Handle OpenAI function tool format + if tool.get("type") == "function": + function_def = tool.get("function", {}) + api_tools.append( + { + "name": function_def.get("name", ""), + "description": function_def.get("description", ""), + "parameters": function_def.get("parameters", {}), + } + ) + # Handle direct tool format + elif "name" in tool: + api_tools.append( + { + "name": tool.get("name", ""), + "description": tool.get("description", ""), + "parameters": tool.get("parameters", {}), + } + ) + + return api_tools if api_tools else None + + def _check_if_flagged(self, result: Dict[str, Any]) -> bool: """ Check if the Qualifire evaluation result indicates flagged content. @@ -190,65 +224,53 @@ class QualifireGuardrail(CustomGuardrail): A high score (close to 100) indicates GOOD content, low score indicates problems. """ # Check evaluation results for any flagged items - evaluation_results = getattr(result, "evaluationResults", None) or [] - if isinstance(result, dict): - evaluation_results = result.get("evaluationResults", []) or [] + evaluation_results = result.get("evaluationResults", []) or [] for eval_result in evaluation_results: - results: List[Any] = [] - if isinstance(eval_result, dict): - results = eval_result.get("results", []) or [] - else: - results = getattr(eval_result, "results", []) or [] - + results = eval_result.get("results", []) or [] for r in results: - flagged = ( - r.get("flagged") - if isinstance(r, dict) - else getattr(r, "flagged", False) - ) - if flagged: + if r.get("flagged"): return True return False - def _build_evaluate_kwargs( + def _build_evaluate_payload( self, - qualifire_messages: List[Any], + api_messages: List[Dict[str, Any]], output: Optional[str], assertions: Optional[List[str]], - available_tools: Optional[List[Any]], + available_tools: Optional[List[Dict[str, Any]]], ) -> Dict[str, Any]: - """Build kwargs dictionary for the evaluate call.""" - kwargs: Dict[str, Any] = {"messages": qualifire_messages} + """Build payload dictionary for the /api/evaluation/evaluate endpoint.""" + payload: Dict[str, Any] = {"messages": api_messages} if output is not None: - kwargs["output"] = output + payload["output"] = output # Add enabled checks if self.prompt_injections: - kwargs["prompt_injections"] = True + payload["prompt_injections"] = True if self.hallucinations_check: - kwargs["hallucinations_check"] = True + payload["hallucinations_check"] = True if self.grounding_check: - kwargs["grounding_check"] = True + payload["grounding_check"] = True if self.pii_check: - kwargs["pii_check"] = True + payload["pii_check"] = True if self.content_moderation_check: - kwargs["content_moderation_check"] = True + payload["content_moderation_check"] = True if self.tool_selection_quality_check: # Only enable tool_selection_quality_check if available_tools is provided if available_tools: - kwargs["tool_selection_quality_check"] = True - kwargs["available_tools"] = available_tools + payload["tool_selection_quality_check"] = True + payload["available_tools"] = available_tools else: verbose_proxy_logger.debug( "Qualifire Guardrail: tool_selection_quality_check enabled but no available_tools provided, skipping this check" ) if assertions: - kwargs["assertions"] = assertions + payload["assertions"] = assertions - return kwargs + return payload async def _run_qualifire_check( self, @@ -274,11 +296,17 @@ class QualifireGuardrail(CustomGuardrail): assertions = dynamic_params.get("assertions") or self.assertions on_flagged = dynamic_params.get("on_flagged") or self.on_flagged - try: - client = self._get_client() - qualifire_messages = self._convert_messages_to_qualifire_format(messages) + # Prepare headers + headers = { + "X-Qualifire-API-Key": self.qualifire_api_key or "", + "Content-Type": "application/json", + } - # Use invoke_evaluation if evaluation_id is provided + try: + # Convert messages to API format + api_messages = self._convert_messages_to_api_format(messages) + + # Use invoke endpoint if evaluation_id is provided if evaluation_id: # For invoke_evaluation, we need to extract input/output input_text = "" @@ -291,25 +319,47 @@ class QualifireGuardrail(CustomGuardrail): input_text = content break - result = client.invoke_evaluation( - evaluation_id=evaluation_id, - input=input_text, - output=output or "", - ) + payload = { + "evaluation_id": evaluation_id, + "input": input_text, + "output": output or "", + "messages": api_messages, + } + + # Convert tools if provided + api_tools = self._convert_tools_to_api_format(available_tools) + if api_tools: + payload["available_tools"] = api_tools + + url = f"{self.qualifire_api_base}/api/evaluation/invoke" else: - # Use evaluate with individual checks - kwargs = self._build_evaluate_kwargs( - qualifire_messages=qualifire_messages, + # Use evaluate endpoint with individual checks + api_tools = self._convert_tools_to_api_format(available_tools) + payload = self._build_evaluate_payload( + api_messages=api_messages, output=output, assertions=assertions, - available_tools=available_tools, + available_tools=api_tools, ) - result = client.evaluate(**kwargs) + url = f"{self.qualifire_api_base}/api/evaluation/evaluate" - # Convert result to dict for logging + verbose_proxy_logger.debug( + f"Qualifire Guardrail: Making request to {url}" + ) + + # Make the API request + response = await self.async_handler.post( + url=url, + headers=headers, + json=payload, + ) + response.raise_for_status() + result = response.json() + + # Extract response info for logging qualifire_response = { - "score": getattr(result, "score", None), - "status": getattr(result, "status", None), + "score": result.get("score"), + "status": result.get("status"), } verbose_proxy_logger.debug( diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_qualifire.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_qualifire.py index 35ed49a84e..fd72185d1e 100644 --- a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_qualifire.py +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_qualifire.py @@ -2,7 +2,6 @@ Unit tests for Qualifire guardrail integration. """ -import sys from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -75,9 +74,37 @@ class TestQualifireGuardrailInit: assert guardrail.on_flagged == "monitor" + def test_init_with_default_api_base(self): + """Test that default API base is set when not provided.""" + from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import ( + DEFAULT_QUALIFIRE_API_BASE, + QualifireGuardrail, + ) + + guardrail = QualifireGuardrail( + api_key="test_key", + guardrail_name="test_guardrail", + ) + + assert guardrail.qualifire_api_base == DEFAULT_QUALIFIRE_API_BASE + + def test_init_with_custom_api_base(self): + """Test initialization with custom API base URL.""" + from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import ( + QualifireGuardrail, + ) + + guardrail = QualifireGuardrail( + api_key="test_key", + api_base="https://custom.qualifire.ai", + guardrail_name="test_guardrail", + ) + + assert guardrail.qualifire_api_base == "https://custom.qualifire.ai" + class TestQualifireGuardrailMessageConversion: - """Tests for message conversion to Qualifire format.""" + """Tests for message conversion to API format.""" def test_convert_simple_messages(self): """Test conversion of simple text messages.""" @@ -95,15 +122,13 @@ class TestQualifireGuardrailMessageConversion: {"role": "assistant", "content": "Hi there!"}, ] - # Create mock LLMMessage class - mock_llm_message = MagicMock() + result = guardrail._convert_messages_to_api_format(messages) - with patch( - "litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire.QualifireGuardrail._convert_messages_to_qualifire_format" - ) as mock_convert: - mock_convert.return_value = [mock_llm_message, mock_llm_message] - result = guardrail._convert_messages_to_qualifire_format(messages) - assert len(result) == 2 + assert len(result) == 2 + assert result[0]["role"] == "user" + assert result[0]["content"] == "Hello, world!" + assert result[1]["role"] == "assistant" + assert result[1]["content"] == "Hi there!" def test_convert_multimodal_messages(self): """Test conversion of multimodal messages with text parts.""" @@ -126,112 +151,258 @@ class TestQualifireGuardrailMessageConversion: }, ] - with patch( - "litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire.QualifireGuardrail._convert_messages_to_qualifire_format" - ) as mock_convert: - mock_convert.return_value = [MagicMock()] - result = guardrail._convert_messages_to_qualifire_format(messages) - assert len(result) == 1 + result = guardrail._convert_messages_to_api_format(messages) + + assert len(result) == 1 + assert result[0]["role"] == "user" + assert result[0]["content"] == "First part\nSecond part" + + def test_convert_messages_with_tool_calls(self): + """Test conversion of messages with tool calls.""" + from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import ( + QualifireGuardrail, + ) + + guardrail = QualifireGuardrail( + api_key="test_key", + guardrail_name="test_guardrail", + ) + + messages = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "NYC"}', + }, + } + ], + }, + ] + + result = guardrail._convert_messages_to_api_format(messages) + + assert len(result) == 1 + assert result[0]["role"] == "assistant" + assert "tool_calls" in result[0] + assert len(result[0]["tool_calls"]) == 1 + assert result[0]["tool_calls"][0]["id"] == "call_123" + assert result[0]["tool_calls"][0]["name"] == "get_weather" + assert result[0]["tool_calls"][0]["arguments"] == {"location": "NYC"} -class TestQualifireGuardrailEvaluateKwargs: - """Tests for evaluate kwargs passed to Qualifire client.""" +class TestQualifireGuardrailToolConversion: + """Tests for tool definition conversion.""" + + def test_convert_openai_function_tools(self): + """Test conversion of OpenAI function tool format.""" + from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import ( + QualifireGuardrail, + ) + + guardrail = QualifireGuardrail( + api_key="test_key", + guardrail_name="test_guardrail", + ) + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a location", + "parameters": {"type": "object", "properties": {}}, + }, + } + ] + + result = guardrail._convert_tools_to_api_format(tools) + + assert result is not None + assert len(result) == 1 + assert result[0]["name"] == "get_weather" + assert result[0]["description"] == "Get weather for a location" + + def test_convert_empty_tools(self): + """Test that empty tools returns None.""" + from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import ( + QualifireGuardrail, + ) + + guardrail = QualifireGuardrail( + api_key="test_key", + guardrail_name="test_guardrail", + ) + + result = guardrail._convert_tools_to_api_format(None) + assert result is None + + result = guardrail._convert_tools_to_api_format([]) + assert result is None + + +class TestQualifireGuardrailAPICall: + """Tests for API call with httpx client.""" @pytest.mark.asyncio async def test_evaluate_called_with_prompt_injections(self): - """Test that evaluate is called with prompt_injections enabled.""" - # Mock the qualifire module and its types - mock_qualifire_types = MagicMock() - mock_llm_message = MagicMock() - mock_llm_tool_call = MagicMock() - mock_message_instance = MagicMock() - mock_llm_message.return_value = mock_message_instance - - mock_qualifire_types.LLMMessage = mock_llm_message - mock_qualifire_types.LLMToolCall = mock_llm_tool_call - - with patch.dict('sys.modules', {'qualifire': MagicMock(), 'qualifire.types': mock_qualifire_types}): - from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import ( - QualifireGuardrail, - ) + """Test that evaluate endpoint is called with prompt_injections enabled.""" + from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import ( + QualifireGuardrail, + ) - guardrail = QualifireGuardrail( - api_key="test_key", - prompt_injections=True, - guardrail_name="test_guardrail", - ) + guardrail = QualifireGuardrail( + api_key="test_key", + prompt_injections=True, + guardrail_name="test_guardrail", + ) - # Mock the client - mock_client = MagicMock() - mock_result = MagicMock() - mock_result.score = 100 - mock_result.status = "completed" - mock_result.evaluationResults = [] - mock_client.evaluate.return_value = mock_result - guardrail._client = mock_client + # Mock the async HTTP handler + mock_response = MagicMock() + mock_response.json.return_value = { + "score": 100, + "status": "completed", + "evaluationResults": [], + } + mock_response.raise_for_status = MagicMock() + guardrail.async_handler.post = AsyncMock(return_value=mock_response) - messages = [{"role": "user", "content": "Hello, world!"}] + messages = [{"role": "user", "content": "Hello, world!"}] - await guardrail._run_qualifire_check( - messages=messages, output=None, dynamic_params={} - ) + await guardrail._run_qualifire_check( + messages=messages, output=None, dynamic_params={} + ) - # Verify evaluate was called with correct kwargs - mock_client.evaluate.assert_called_once() - call_kwargs = mock_client.evaluate.call_args[1] - assert call_kwargs["prompt_injections"] is True - assert "messages" in call_kwargs + # Verify the API was called + guardrail.async_handler.post.assert_called_once() + call_kwargs = guardrail.async_handler.post.call_args[1] + + assert "json" in call_kwargs + payload = call_kwargs["json"] + assert payload["prompt_injections"] is True + assert "messages" in payload + assert call_kwargs["url"].endswith("/api/evaluation/evaluate") @pytest.mark.asyncio async def test_evaluate_called_with_multiple_checks(self): """Test that evaluate is called with multiple checks enabled.""" - # Mock the qualifire module and its types - mock_qualifire_types = MagicMock() - mock_llm_message = MagicMock() - mock_llm_tool_call = MagicMock() - mock_message_instance = MagicMock() - mock_llm_message.return_value = mock_message_instance - - mock_qualifire_types.LLMMessage = mock_llm_message - mock_qualifire_types.LLMToolCall = mock_llm_tool_call - - with patch.dict('sys.modules', {'qualifire': MagicMock(), 'qualifire.types': mock_qualifire_types}): - from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import ( - QualifireGuardrail, - ) + from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import ( + QualifireGuardrail, + ) - guardrail = QualifireGuardrail( - api_key="test_key", - prompt_injections=True, - pii_check=True, - hallucinations_check=True, - assertions=["Output must be valid JSON"], - guardrail_name="test_guardrail", - ) + guardrail = QualifireGuardrail( + api_key="test_key", + prompt_injections=True, + pii_check=True, + hallucinations_check=True, + assertions=["Output must be valid JSON"], + guardrail_name="test_guardrail", + ) - # Mock the client - mock_client = MagicMock() - mock_result = MagicMock() - mock_result.score = 100 - mock_result.status = "completed" - mock_result.evaluationResults = [] - mock_client.evaluate.return_value = mock_result - guardrail._client = mock_client + # Mock the async HTTP handler + mock_response = MagicMock() + mock_response.json.return_value = { + "score": 100, + "status": "completed", + "evaluationResults": [], + } + mock_response.raise_for_status = MagicMock() + guardrail.async_handler.post = AsyncMock(return_value=mock_response) - messages = [{"role": "user", "content": "Hello, world!"}] + messages = [{"role": "user", "content": "Hello, world!"}] - await guardrail._run_qualifire_check( - messages=messages, output="Test output", dynamic_params={} - ) + await guardrail._run_qualifire_check( + messages=messages, output="Test output", dynamic_params={} + ) - # Verify evaluate was called with correct kwargs - mock_client.evaluate.assert_called_once() - call_kwargs = mock_client.evaluate.call_args[1] - assert call_kwargs["prompt_injections"] is True - assert call_kwargs["pii_check"] is True - assert call_kwargs["hallucinations_check"] is True - assert call_kwargs["assertions"] == ["Output must be valid JSON"] - assert call_kwargs["output"] == "Test output" + # Verify the API was called with correct payload + guardrail.async_handler.post.assert_called_once() + call_kwargs = guardrail.async_handler.post.call_args[1] + + payload = call_kwargs["json"] + assert payload["prompt_injections"] is True + assert payload["pii_check"] is True + assert payload["hallucinations_check"] is True + assert payload["assertions"] == ["Output must be valid JSON"] + assert payload["output"] == "Test output" + + @pytest.mark.asyncio + async def test_invoke_endpoint_used_with_evaluation_id(self): + """Test that invoke endpoint is used when evaluation_id is provided.""" + from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import ( + QualifireGuardrail, + ) + + guardrail = QualifireGuardrail( + api_key="test_key", + evaluation_id="eval_123", + guardrail_name="test_guardrail", + ) + + # Mock the async HTTP handler + mock_response = MagicMock() + mock_response.json.return_value = { + "score": 100, + "status": "completed", + "evaluationResults": [], + } + mock_response.raise_for_status = MagicMock() + guardrail.async_handler.post = AsyncMock(return_value=mock_response) + + messages = [{"role": "user", "content": "Hello, world!"}] + + await guardrail._run_qualifire_check( + messages=messages, output="Test output", dynamic_params={} + ) + + # Verify the invoke endpoint was called + guardrail.async_handler.post.assert_called_once() + call_kwargs = guardrail.async_handler.post.call_args[1] + + assert call_kwargs["url"].endswith("/api/evaluation/invoke") + payload = call_kwargs["json"] + assert payload["evaluation_id"] == "eval_123" + assert payload["input"] == "Hello, world!" + assert payload["output"] == "Test output" + + @pytest.mark.asyncio + async def test_correct_headers_sent(self): + """Test that correct headers are sent with the API request.""" + from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import ( + QualifireGuardrail, + ) + + guardrail = QualifireGuardrail( + api_key="my_api_key", + guardrail_name="test_guardrail", + ) + + # Mock the async HTTP handler + mock_response = MagicMock() + mock_response.json.return_value = { + "score": 100, + "status": "completed", + "evaluationResults": [], + } + mock_response.raise_for_status = MagicMock() + guardrail.async_handler.post = AsyncMock(return_value=mock_response) + + messages = [{"role": "user", "content": "Hello!"}] + + await guardrail._run_qualifire_check( + messages=messages, output=None, dynamic_params={} + ) + + call_kwargs = guardrail.async_handler.post.call_args[1] + headers = call_kwargs["headers"] + + assert headers["X-Qualifire-API-Key"] == "my_api_key" + assert headers["Content-Type"] == "application/json" class TestQualifireGuardrailCheckIfFlagged: @@ -248,12 +419,14 @@ class TestQualifireGuardrailCheckIfFlagged: guardrail_name="test_guardrail", ) - # Mock result with completed status and no flagged items - mock_result = MagicMock() - mock_result.status = "completed" - mock_result.evaluationResults = [] + # Result with completed status and no flagged items (dict format) + result = { + "status": "completed", + "score": 100, + "evaluationResults": [], + } - assert guardrail._check_if_flagged(mock_result) is False + assert guardrail._check_if_flagged(result) is False def test_check_if_flagged_returns_true_for_flagged_content(self): """Test that _check_if_flagged returns True when content is flagged.""" @@ -266,18 +439,25 @@ class TestQualifireGuardrailCheckIfFlagged: guardrail_name="test_guardrail", ) - # Mock result with flagged item - mock_inner_result = MagicMock() - mock_inner_result.flagged = True + # Result with flagged item (dict format matching API response) + result = { + "status": "completed", + "score": 15, + "evaluationResults": [ + { + "type": "prompt_injection", + "results": [ + { + "flagged": True, + "score": 0.15, + "reason": "Prompt injection detected", + } + ], + } + ], + } - mock_eval_result = MagicMock() - mock_eval_result.results = [mock_inner_result] - - mock_result = MagicMock() - mock_result.status = "completed" - mock_result.evaluationResults = [mock_eval_result] - - assert guardrail._check_if_flagged(mock_result) is True + assert guardrail._check_if_flagged(result) is True def test_check_if_flagged_returns_false_when_no_flagged_items(self): """Test that _check_if_flagged returns False when no items are flagged.""" @@ -291,17 +471,24 @@ class TestQualifireGuardrailCheckIfFlagged: ) # Result with evaluation results but nothing flagged - mock_inner_result = MagicMock() - mock_inner_result.flagged = False + result = { + "status": "completed", + "score": 95, + "evaluationResults": [ + { + "type": "prompt_injection", + "results": [ + { + "flagged": False, + "score": 0.95, + "reason": "No issues detected", + } + ], + } + ], + } - mock_eval_result = MagicMock() - mock_eval_result.results = [mock_inner_result] - - mock_result = MagicMock() - mock_result.status = "success" - mock_result.evaluationResults = [mock_eval_result] - - assert guardrail._check_if_flagged(mock_result) is False + assert guardrail._check_if_flagged(result) is False class TestQualifireGuardrailShouldRun: From 91b5c66cf2851cc9e3dd76b0299e8bbf819d1ef7 Mon Sep 17 00:00:00 2001 From: Kris Xia Date: Wed, 7 Jan 2026 23:56:47 +0800 Subject: [PATCH 2/9] fix(proxy): return json error response instead of sse format for initial streaming errors (#18757) * adding signoz integration to observability docs * Fixing build * Adding timeout for flaky test * Fixing e2e * fix(proxy): return json error response instead of sse format for initial streaming errors when the first chunk of a streaming response contains an error, return a standard json error response instead of sse format. this ensures clients receive properly formatted error responses before the stream actually begins. - rename create_streaming_response to create_response - add logic to detect error in first chunk and return JSONResponse - add _extract_error_from_sse_chunk helper function - update all call sites to use the new function name - update tests to reflect the function rename * test(proxy): add comprehensive tests for error extraction from sse chunks - Add new test class TestExtractErrorFromSSEChunk with 10 test cases - Update existing tests to verify JSONResponse returned for initial streaming errors - Add tests for error code as string, bytes input, invalid JSON, and edge cases - Verify correct error format extraction from SSE chunks --------- Co-authored-by: Goutham Karthi Co-authored-by: yuneng-jiang Co-authored-by: YutaSaito <36355491+uc4w6c@users.noreply.github.com> --- docs/my-website/docs/observability/signoz.md | 394 ++++++++++++++++++ .../proxy/anthropic_endpoints/endpoints.py | 4 +- litellm/proxy/common_request_processing.py | 75 +++- litellm/proxy/proxy_server.py | 4 +- .../proxy/test_common_request_processing.py | 176 ++++++-- .../tests/users/viewInternalUsers.spec.ts | 1 + .../ModelsAndEndpointsView.tsx | 20 +- 7 files changed, 616 insertions(+), 58 deletions(-) create mode 100644 docs/my-website/docs/observability/signoz.md diff --git a/docs/my-website/docs/observability/signoz.md b/docs/my-website/docs/observability/signoz.md new file mode 100644 index 0000000000..4b65916fdf --- /dev/null +++ b/docs/my-website/docs/observability/signoz.md @@ -0,0 +1,394 @@ +import Tabs from '@theme/Tabs'; +import TabItem from '@theme/TabItem'; + +# SigNoz LiteLLM Integration + +For more details on setting up observability for LiteLLM, check out the [SigNoz LiteLLM observability docs](https://signoz.io/docs/litellm-observability/). + + +## Overview + +This guide walks you through setting up observability and monitoring for LiteLLM SDK and Proxy Server using [OpenTelemetry](https://opentelemetry.io/) and exporting logs, traces, and metrics to SigNoz. With this integration, you can observe various models performance, capture request/response details, and track system-level metrics in SigNoz, giving you real-time visibility into latency, error rates, and usage trends for your LiteLLM applications. + +Instrumenting LiteLLM in your AI applications with telemetry ensures full observability across your AI workflows, making it easier to debug issues, optimize performance, and understand user interactions. By leveraging SigNoz, you can analyze correlated traces, logs, and metrics in unified dashboards, configure alerts, and gain actionable insights to continuously improve reliability, responsiveness, and user experience. + +## Prerequisites + +- A [SigNoz Cloud account](https://signoz.io/teams/) with an active ingestion key +- Internet access to send telemetry data to SigNoz Cloud +- [LiteLLM](https://www.litellm.ai/) SDK or Proxy integration +- For Python: `pip` installed for managing Python packages and _(optional but recommended)_ a Python virtual environment to isolate dependencies + +## Monitoring LiteLLM + +LiteLLM can be monitored in two ways: using the **LiteLLM SDK** (directly embedded in your Python application code for programmatic LLM calls) or the **LiteLLM Proxy Server** (a standalone server that acts as a centralized gateway for managing and routing LLM requests across your infrastructure). + + + + +For more detailed info on instrumenting your LiteLLM SDK applications click [here](https://docs.litellm.ai/docs/observability/opentelemetry_integration). + + + + + +No-code auto-instrumentation is recommended for quick setup with minimal code changes. It's ideal when you want to get observability up and running without modifying your application code and are leveraging standard instrumentor libraries. + +**Step 1:** Install the necessary packages in your Python environment. + +```bash +pip install \ + opentelemetry-api \ + opentelemetry-distro \ + opentelemetry-exporter-otlp \ + httpx \ + opentelemetry-instrumentation-httpx \ + litellm +``` + +**Step 2:** Add Automatic Instrumentation + +```bash +opentelemetry-bootstrap --action=install +``` + +**Step 3:** Instrument your LiteLLM SDK application + +Initialize LiteLLM SDK instrumentation by calling `litellm.callbacks = ["otel"]`: + +```python +from litellm import litellm + +litellm.callbacks = ["otel"] +``` + +This call enables automatic tracing, logs, and metrics collection for all LiteLLM SDK calls in your application. + +> πŸ“Œ Note: Ensure this is called before any LiteLLM related calls to properly configure instrumentation of your application + +**Step 4:** Run an example + +```python +from litellm import completion, litellm + +litellm.callbacks = ["otel"] + +response = completion( + model="openai/gpt-4o", + messages=[{ "content": "What is SigNoz","role": "user"}] +) + +print(response) +``` + +> πŸ“Œ Note: LiteLLM supports a [variety of model providers](https://docs.litellm.ai/docs/providers) for LLMs. In this example, we're using OpenAI. Before running this code, ensure that you have set the environment variable `OPENAI_API_KEY` with your generated API key. + +**Step 5:** Run your application with auto-instrumentation + +```bash +OTEL_RESOURCE_ATTRIBUTES="service.name=" \ +OTEL_EXPORTER_OTLP_ENDPOINT="https://ingest..signoz.cloud:443" \ +OTEL_EXPORTER_OTLP_HEADERS="signoz-ingestion-key=" \ +OTEL_EXPORTER_OTLP_PROTOCOL=grpc \ +OTEL_TRACES_EXPORTER=otlp \ +OTEL_METRICS_EXPORTER=otlp \ +OTEL_LOGS_EXPORTER=otlp \ +OTEL_PYTHON_LOG_CORRELATION=true \ +OTEL_PYTHON_LOGGING_AUTO_INSTRUMENTATION_ENABLED=true \ +OTEL_PYTHON_DISABLED_INSTRUMENTATIONS=openai \ +opentelemetry-instrument +``` + +> πŸ“Œ Note: We're using `OTEL_PYTHON_DISABLED_INSTRUMENTATIONS=openai` in the run command to disable the OpenAI instrumentor for tracing. This avoids conflicts with LiteLLM's native telemetry/instrumentation, ensuring that telemetry is captured exclusively through LiteLLM's built-in instrumentation. + +- **``**Β is the name of your service +- Set the `` to match your SigNoz Cloud [region](https://signoz.io/docs/ingestion/signoz-cloud/overview/#endpoint) +- Replace `` with your SigNoz [ingestion key](https://signoz.io/docs/ingestion/signoz-cloud/keys/) +- Replace `` with the actual command you would use to run your application. For example: `python main.py` + +> πŸ“Œ Note: Using self-hosted SigNoz? Most steps are identical. To adapt this guide, update the endpoint and remove the ingestion key header as shown in [Cloud β†’ Self-Hosted](https://signoz.io/docs/ingestion/cloud-vs-self-hosted/#cloud-to-self-hosted). + + + + + + +Code-based instrumentation gives you fine-grained control over your telemetry configuration. Use this approach when you need to customize resource attributes, sampling strategies, or integrate with existing observability infrastructure. + +**Step 1:** Install the necessary packages in your Python environment. + +```bash +pip install \ + opentelemetry-api \ + opentelemetry-sdk \ + opentelemetry-exporter-otlp \ + opentelemetry-instrumentation-httpx \ + opentelemetry-instrumentation-system-metrics \ + litellm +``` + +**Step 2:** Import the necessary modules in your Python application + +**Traces:** + +```python +from opentelemetry import trace +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +``` + +**Logs:** + +```python +from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler +from opentelemetry.sdk._logs.export import BatchLogRecordProcessor +from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter +from opentelemetry._logs import set_logger_provider +import logging +``` + +**Metrics:** + +```python +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry import metrics +from opentelemetry.instrumentation.system_metrics import SystemMetricsInstrumentor +from opentelemetry.instrumentation.httpx import HTTPXClientInstrumentor +``` + +**Step 3:** Set up the OpenTelemetry Tracer Provider to send traces directly to SigNoz Cloud + +```python +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter +from opentelemetry import trace +import os + +resource = Resource.create({"service.name": ""}) +provider = TracerProvider(resource=resource) +span_exporter = OTLPSpanExporter( + endpoint= os.getenv("OTEL_EXPORTER_TRACES_ENDPOINT"), + headers={"signoz-ingestion-key": os.getenv("SIGNOZ_INGESTION_KEY")}, +) +processor = BatchSpanProcessor(span_exporter) +provider.add_span_processor(processor) +trace.set_tracer_provider(provider) +``` + +- **``**Β is the name of your service +- **`OTEL_EXPORTER_TRACES_ENDPOINT`** β†’ SigNoz Cloud trace endpoint with appropriate [region](https://signoz.io/docs/ingestion/signoz-cloud/overview/#endpoint):`https://ingest..signoz.cloud:443/v1/traces` +- **`SIGNOZ_INGESTION_KEY`** β†’ Your SigNoz [ingestion key](https://signoz.io/docs/ingestion/signoz-cloud/keys/) + + +> πŸ“Œ Note: Using self-hosted SigNoz? Most steps are identical. To adapt this guide, update the endpoint and remove the ingestion key header as shown in [Cloud β†’ Self-Hosted](https://signoz.io/docs/ingestion/cloud-vs-self-hosted/#cloud-to-self-hosted). + + +**Step 4**: Setup Logs + +```python +import logging +from opentelemetry.sdk.resources import Resource +from opentelemetry._logs import set_logger_provider +from opentelemetry.sdk._logs import LoggerProvider, LoggingHandler +from opentelemetry.sdk._logs.export import BatchLogRecordProcessor +from opentelemetry.exporter.otlp.proto.http._log_exporter import OTLPLogExporter +import os + +resource = Resource.create({"service.name": ""}) +logger_provider = LoggerProvider(resource=resource) +set_logger_provider(logger_provider) + +otlp_log_exporter = OTLPLogExporter( + endpoint= os.getenv("OTEL_EXPORTER_LOGS_ENDPOINT"), + headers={"signoz-ingestion-key": os.getenv("SIGNOZ_INGESTION_KEY")}, +) +logger_provider.add_log_record_processor( + BatchLogRecordProcessor(otlp_log_exporter) +) +# Attach OTel logging handler to root logger +handler = LoggingHandler(level=logging.INFO, logger_provider=logger_provider) +logging.basicConfig(level=logging.INFO, handlers=[handler]) + +logger = logging.getLogger(__name__) +``` + +- **``**Β is the name of your service +- **`OTEL_EXPORTER_LOGS_ENDPOINT`** β†’ SigNoz Cloud endpoint with appropriate [region](https://signoz.io/docs/ingestion/signoz-cloud/overview/#endpoint):`https://ingest..signoz.cloud:443/v1/logs` +- **`SIGNOZ_INGESTION_KEY`** β†’ Your SigNoz [ingestion key](https://signoz.io/docs/ingestion/signoz-cloud/keys/) + +> πŸ“Œ Note: Using self-hosted SigNoz? Most steps are identical. To adapt this guide, update the endpoint and remove the ingestion key header as shown in [Cloud β†’ Self-Hosted](https://signoz.io/docs/ingestion/cloud-vs-self-hosted/#cloud-to-self-hosted). + + +**Step 5**: Setup Metrics + +```python +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry import metrics +from opentelemetry.instrumentation.system_metrics import SystemMetricsInstrumentor +import os + +resource = Resource.create({"service.name": ""}) +metric_exporter = OTLPMetricExporter( + endpoint= os.getenv("OTEL_EXPORTER_METRICS_ENDPOINT"), + headers={"signoz-ingestion-key": os.getenv("SIGNOZ_INGESTION_KEY")}, +) +reader = PeriodicExportingMetricReader(metric_exporter) +metric_provider = MeterProvider(metric_readers=[reader], resource=resource) +metrics.set_meter_provider(metric_provider) + +meter = metrics.get_meter(__name__) + +# turn on out-of-the-box metrics +SystemMetricsInstrumentor().instrument() +HTTPXClientInstrumentor().instrument() +``` + +- **``**Β is the name of your service +- **`OTEL_EXPORTER_METRICS_ENDPOINT`** β†’ SigNoz Cloud endpoint with appropriate [region](https://signoz.io/docs/ingestion/signoz-cloud/overview/#endpoint):`https://ingest..signoz.cloud:443/v1/metrics` +- **`SIGNOZ_INGESTION_KEY`** β†’ Your SigNoz [ingestion key](https://signoz.io/docs/ingestion/signoz-cloud/keys/) + +> πŸ“Œ Note: Using self-hosted SigNoz? Most steps are identical. To adapt this guide, update the endpoint and remove the ingestion key header as shown in [Cloud β†’ Self-Hosted](https://signoz.io/docs/ingestion/cloud-vs-self-hosted/#cloud-to-self-hosted). + + +> πŸ“Œ Note: SystemMetricsInstrumentor provides system metrics (CPU, memory, etc.), and HTTPXClientInstrumentor provides outbound HTTP request metrics such as request duration. If you want to add custom metrics to your LiteLLM application, see [Python Custom Metrics](https://signoz.io/opentelemetry/python-custom-metrics/). + +**Step 6:** Instrument your LiteLLM application + +Initialize LiteLLM SDK instrumentation by calling `litellm.callbacks = ["otel"]`: + +```python +from litellm import litellm + +litellm.callbacks = ["otel"] +``` + +This call enables automatic tracing, logs, and metrics collection for all LiteLLM SDK calls in your application. + +> πŸ“Œ Note: Ensure this is called before any LiteLLM related calls to properly configure instrumentation of your application + +**Step 7:** Run an example + +```python +from litellm import completion, litellm + +litellm.callbacks = ["otel"] + +response = completion( + model="openai/gpt-4o", + messages=[{ "content": "What is SigNoz","role": "user"}] +) + +print(response) +``` + +> πŸ“Œ Note: LiteLLM supports a [variety of model providers](https://docs.litellm.ai/docs/providers) for LLMs. In this example, we're using OpenAI. Before running this code, ensure that you have set the environment variable `OPENAI_API_KEY` with your generated API key. + + + + +## View Traces, Logs, and Metrics in SigNoz + +Your LiteLLM commands should now automatically emit traces, logs, and metrics. + +You should be able to view traces in Signoz Cloud under the traces tab: + +![LiteLLM SDK Trace View](https://signoz.io/img/docs/llm/litellm/litellmsdk-traces.webp) + +When you click on a trace in SigNoz, you'll see a detailed view of the trace, including all associated spans, along with their events and attributes. + +![LiteLLM SDK Detailed Trace View](https://signoz.io/img/docs/llm/litellm/litellmsdk-detailed-traces.webp) + +You should be able to view logs in Signoz Cloud under the logs tab. You can also view logs by clicking on the β€œRelated Logs” button in the trace view to see correlated logs: + +![LiteLLM SDK Logs View](https://signoz.io/img/docs/llm/litellm/litellmsdk-logs.webp) + +When you click on any of these logs in SigNoz, you'll see a detailed view of the log, including attributes: + +![LiteLLM SDK Detailed Logs View](https://signoz.io/img/docs/llm/litellm/litellmsdk-detailed-logs.webp) + +You should be able to see LiteLLM related metrics in Signoz Cloud under the metrics tab: + +![LiteLLM SDK Metrics View](https://signoz.io/img/docs/llm/litellm/litellmsdk-metrics.webp) + +When you click on any of these metrics in SigNoz, you'll see a detailed view of the metric, including attributes: + +![LiteLLM Detailed Metrics View](https://signoz.io/img/docs/llm/litellm/litellmsdk-detailed-metrics.webp) + +## Dashboard + +You can also check out our custom LiteLLM SDK dashboardΒ [here](https://signoz.io/docs/dashboards/dashboard-templates/litellm-sdk-dashboard/) which provides specialized visualizations for monitoring your LiteLLM usage in applications. The dashboard includes pre-built charts specifically tailored for LLM usage, along with import instructions to get started quickly. + +![LiteLLM SDK Dashboard Template](https://signoz.io/img/docs/llm/litellm/litellm-sdk-dashboard.webp) + + + + + +**Step 1:** Install the necessary packages in your Python environment. + +```bash +pip install opentelemetry-api \ + opentelemetry-sdk \ + opentelemetry-exporter-otlp \ + 'litellm[proxy]' +``` + +**Step 2:** Configure otel for the LiteLLM Proxy Server + +Add the following to `config.yaml`: + +```yaml +litellm_settings: + callbacks: ['otel'] +``` + +**Step 3:** Set the following environment variables: + +```bash +export OTEL_EXPORTER_OTLP_ENDPOINT="https://ingest..signoz.cloud:443" +export OTEL_EXPORTER_OTLP_HEADERS="signoz-ingestion-key=" +export OTEL_EXPORTER_OTLP_PROTOCOL="grpc" +export OTEL_TRACES_EXPORTER="otlp" +export OTEL_METRICS_EXPORTER="otlp" +export OTEL_LOGS_EXPORTER="otlp" +``` + +- Set the `` to match your SigNoz Cloud [region](https://signoz.io/docs/ingestion/signoz-cloud/overview/#endpoint) +- Replace `` with your SigNoz [ingestion key](https://signoz.io/docs/ingestion/signoz-cloud/keys/) + +> πŸ“Œ Note: Using self-hosted SigNoz? Most steps are identical. To adapt this guide, update the endpoint and remove the ingestion key header as shown in [Cloud β†’ Self-Hosted](https://signoz.io/docs/ingestion/cloud-vs-self-hosted/#cloud-to-self-hosted). + + +**Step 4:** Run the proxy server using the config file: + +```bash +litellm --config config.yaml +``` + +Now any calls made through your LiteLLM proxy server will be traced and sent to SigNoz. + +You should be able to view traces in Signoz Cloud under the traces tab: + +![LiteLLM Proxy Trace View](https://signoz.io/img/docs/llm/litellm/litellmproxy-traces.webp) + +When you click on a trace in SigNoz, you'll see a detailed view of the trace, including all associated spans, along with their events and attributes. + +![LiteLLM Proxy Detailed Trace View](https://signoz.io/img/docs/llm/litellm/litellmproxy-detailed-traces.webp) + +## Dashboard + +You can also check out our custom LiteLLM Proxy dashboardΒ [here](https://signoz.io/docs/dashboards/dashboard-templates/litellm-proxy-dashboard/) which provides specialized visualizations for monitoring your LiteLLM Proxy usage in applications. The dashboard includes pre-built charts specifically tailored for LLM usage, along with import instructions to get started quickly. + +![LiteLLM Proxy Dashboard Template](https://signoz.io/img/docs/llm/litellm/litellm-proxy-dashboard.webp) + + + diff --git a/litellm/proxy/anthropic_endpoints/endpoints.py b/litellm/proxy/anthropic_endpoints/endpoints.py index 334362a027..7de6b7fccf 100644 --- a/litellm/proxy/anthropic_endpoints/endpoints.py +++ b/litellm/proxy/anthropic_endpoints/endpoints.py @@ -11,7 +11,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.integrations.custom_guardrail import ModifyResponseException from litellm.proxy.common_request_processing import ( ProxyBaseLLMRequestProcessing, - create_streaming_response, + create_response, ) from litellm.proxy.common_utils.http_parsing_utils import _read_request_body from litellm.types.utils import TokenCountResponse @@ -106,7 +106,7 @@ async def anthropic_response( # noqa: PLR0915 ) ) - return await create_streaming_response( + return await create_response( generator=selected_data_generator, media_type="text/event-stream", headers={}, diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 537b48f06e..95c23be5b8 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -17,7 +17,7 @@ from typing import ( import httpx import orjson from fastapi import HTTPException, Request, status -from fastapi.responses import Response, StreamingResponse +from fastapi.responses import JSONResponse, Response, StreamingResponse import litellm from litellm._logging import verbose_proxy_logger @@ -96,16 +96,55 @@ async def _parse_event_data_for_error(event_line: Union[str, bytes]) -> Optional return None -async def create_streaming_response( +def _extract_error_from_sse_chunk(event_line: Union[str, bytes]) -> dict: + """ + Extract error dictionary from SSE format chunk. + + Args: + event_line: SSE format event line, e.g. "data: {"error": {...}}\n\n" + + Returns: + Error dictionary in OpenAI API format + """ + event_line = ( + event_line.decode("utf-8") if isinstance(event_line, bytes) else event_line + ) + + # Default error format + default_error = { + "message": "Unknown error", + "type": "internal_server_error", + "param": None, + "code": "500", + } + + if event_line.startswith("data: "): + json_str = event_line[len("data: ") :].strip() + if not json_str or json_str == "[DONE]": + return default_error + + try: + data = orjson.loads(json_str) + if isinstance(data, dict) and "error" in data: + error_obj = data["error"] + if isinstance(error_obj, dict): + return error_obj + except (orjson.JSONDecodeError, json.JSONDecodeError): + pass + + return default_error + + +async def create_response( generator: AsyncGenerator[str, None], media_type: str, headers: dict, default_status_code: int = status.HTTP_200_OK, -) -> StreamingResponse: +) -> Union[StreamingResponse, JSONResponse]: """ - Creates a StreamingResponse by inspecting the first chunk for an error code. - The entire original generator content is streamed, but the HTTP status code - of the response is set based on the first chunk if it's a recognized error. + Create streaming response, checking if the first chunk is an error. + If the first chunk is an error, return a standard JSON error response. + Otherwise, return StreamingResponse and stream all content. """ first_chunk_value: Optional[str] = None final_status_code = default_status_code @@ -124,9 +163,27 @@ async def create_streaming_response( first_chunk_value ) if error_code_from_chunk is not None: + # First chunk is an error, stream hasn't really started yet + # Should return standard JSON error response instead of SSE format final_status_code = error_code_from_chunk verbose_proxy_logger.debug( - f"Error detected in first stream chunk. Status code set to: {final_status_code}" + f"Error detected in first stream chunk. Returning JSON error response with status code: {final_status_code}" + ) + + # Parse error content + error_dict = _extract_error_from_sse_chunk(first_chunk_value) + + # Consume and close generator (avoid resource leak) + try: + await generator.aclose() + except Exception: + pass + + # Return JSON format error response + return JSONResponse( + status_code=final_status_code, + content={"error": error_dict}, + headers=headers, ) except Exception as e: verbose_proxy_logger.debug(f"Error parsing first chunk value: {e}") @@ -647,7 +704,7 @@ class ProxyBaseLLMRequestProcessing: proxy_logging_obj=proxy_logging_obj, ) ) - return await create_streaming_response( + return await create_response( generator=selected_data_generator, media_type="text/event-stream", headers=custom_headers, @@ -658,7 +715,7 @@ class ProxyBaseLLMRequestProcessing: user_api_key_dict=user_api_key_dict, request_data=self.data, ) - return await create_streaming_response( + return await create_response( generator=selected_data_generator, media_type="text/event-stream", headers=custom_headers, diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 06525e3913..9e314a77c3 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -229,7 +229,7 @@ from litellm.proxy.batches_endpoints.endpoints import router as batches_router from litellm.proxy.caching_routes import router as caching_router from litellm.proxy.common_request_processing import ( ProxyBaseLLMRequestProcessing, - create_streaming_response, + create_response, ) from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy from litellm.proxy.common_utils.debug_utils import init_verbose_loggers @@ -6736,7 +6736,7 @@ async def run_thread( if ( "stream" in data and data["stream"] is True ): # use generate_responses to stream responses - return await create_streaming_response( + return await create_response( generator=async_assistants_data_generator( user_api_key_dict=user_api_key_dict, response=response, diff --git a/tests/test_litellm/proxy/test_common_request_processing.py b/tests/test_litellm/proxy/test_common_request_processing.py index b5d4438569..ed1c29f5dc 100644 --- a/tests/test_litellm/proxy/test_common_request_processing.py +++ b/tests/test_litellm/proxy/test_common_request_processing.py @@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock import pytest from fastapi import Request, status -from fastapi.responses import StreamingResponse +from fastapi.responses import JSONResponse, StreamingResponse import litellm from litellm._uuid import uuid @@ -11,9 +11,10 @@ from litellm.integrations.opentelemetry import UserAPIKeyAuth from litellm.proxy.common_request_processing import ( ProxyBaseLLMRequestProcessing, ProxyConfig, + _extract_error_from_sse_chunk, _get_cost_breakdown_from_logging_obj, _parse_event_data_for_error, - create_streaming_response, + create_response, ) from litellm.proxy.utils import ProxyLogging @@ -602,21 +603,27 @@ class TestCommonRequestProcessingHelpers: assert await _parse_event_data_for_error(event_line) == expected_code async def test_create_streaming_response_first_chunk_is_error(self): + """ + Test that when the first chunk is an error, a JSON error response is returned + instead of an SSE streaming response + """ async def mock_generator(): yield 'data: {"error": {"code": 403, "message": "forbidden"}}\n\n' yield 'data: {"content": "more data"}\n\n' yield "data: [DONE]\n\n" - response = await create_streaming_response( + response = await create_response( mock_generator(), "text/event-stream", {} ) + # Should return JSONResponse instead of StreamingResponse + assert isinstance(response, JSONResponse) assert response.status_code == status.HTTP_403_FORBIDDEN - content = await self.consume_stream(response) - assert content == [ - 'data: {"error": {"code": 403, "message": "forbidden"}}\n\n', - 'data: {"content": "more data"}\n\n', - "data: [DONE]\n\n", - ] + # Verify the response is in standard JSON error format + import json + body = json.loads(response.body.decode()) + assert "error" in body + assert body["error"]["code"] == 403 + assert body["error"]["message"] == "forbidden" async def test_create_streaming_response_first_chunk_not_error(self): async def mock_generator(): @@ -624,7 +631,7 @@ class TestCommonRequestProcessingHelpers: yield 'data: {"content": "second part"}\n\n' yield "data: [DONE]\n\n" - response = await create_streaming_response( + response = await create_response( mock_generator(), "text/event-stream", {} ) assert response.status_code == status.HTTP_200_OK @@ -641,7 +648,7 @@ class TestCommonRequestProcessingHelpers: yield # Implicitly raises StopAsyncIteration - response = await create_streaming_response( + response = await create_response( mock_generator(), "text/event-stream", {} ) assert response.status_code == status.HTTP_200_OK @@ -654,7 +661,7 @@ class TestCommonRequestProcessingHelpers: mock_gen = AsyncMock() mock_gen.__anext__.side_effect = StopAsyncIteration - response = await create_streaming_response(mock_gen, "text/event-stream", {}) + response = await create_response(mock_gen, "text/event-stream", {}) assert response.status_code == status.HTTP_200_OK content = await self.consume_stream(response) assert content == [] @@ -665,7 +672,7 @@ class TestCommonRequestProcessingHelpers: mock_gen = AsyncMock() mock_gen.__anext__.side_effect = ValueError("Test error from generator") - response = await create_streaming_response(mock_gen, "text/event-stream", {}) + response = await create_response(mock_gen, "text/event-stream", {}) assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR content = await self.consume_stream(response) expected_error_data = { @@ -682,19 +689,24 @@ class TestCommonRequestProcessingHelpers: assert content[1] == "data: [DONE]\n\n" async def test_create_streaming_response_first_chunk_error_string_code(self): + """ + Test that when the first chunk contains a string error code, a JSON error response is returned + """ async def mock_generator(): yield 'data: {"error": {"code": "429", "message": "too many requests"}}\n\n' yield "data: [DONE]\n\n" - response = await create_streaming_response( + response = await create_response( mock_generator(), "text/event-stream", {} ) + assert isinstance(response, JSONResponse) assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS - content = await self.consume_stream(response) - assert content == [ - 'data: {"error": {"code": "429", "message": "too many requests"}}\n\n', - "data: [DONE]\n\n", - ] + # Verify the response is in standard JSON error format + import json + body = json.loads(response.body.decode()) + assert "error" in body + assert body["error"]["code"] == "429" + assert body["error"]["message"] == "too many requests" async def test_create_streaming_response_custom_headers(self): async def mock_generator(): @@ -702,7 +714,7 @@ class TestCommonRequestProcessingHelpers: yield "data: [DONE]\n\n" custom_headers = {"X-Custom-Header": "TestValue"} - response = await create_streaming_response( + response = await create_response( mock_generator(), "text/event-stream", custom_headers ) assert response.headers["x-custom-header"] == "TestValue" @@ -712,7 +724,7 @@ class TestCommonRequestProcessingHelpers: yield 'data: {"content": "data"}\n\n' yield "data: [DONE]\n\n" - response = await create_streaming_response( + response = await create_response( mock_generator(), "text/event-stream", {}, @@ -729,7 +741,7 @@ class TestCommonRequestProcessingHelpers: async def mock_generator(): yield "data: [DONE]\n\n" - response = await create_streaming_response( + response = await create_response( mock_generator(), "text/event-stream", {} ) assert response.status_code == status.HTTP_200_OK # Default status @@ -742,7 +754,7 @@ class TestCommonRequestProcessingHelpers: yield 'data: {"content": "actual data"}\n\n' yield "data: [DONE]\n\n" - response = await create_streaming_response( + response = await create_response( mock_generator(), "text/event-stream", {} ) assert response.status_code == status.HTTP_200_OK # Default status @@ -773,7 +785,7 @@ class TestCommonRequestProcessingHelpers: # Patch the tracer in the common_request_processing module with patch("litellm.proxy.common_request_processing.tracer", mock_tracer): - response = await create_streaming_response( + response = await create_response( mock_generator(), "text/event-stream", {} ) @@ -810,7 +822,10 @@ class TestCommonRequestProcessingHelpers: ), f"Call {i} should have operation name 'streaming.chunk.yield', got {args[0]}" async def test_create_streaming_response_dd_trace_with_error_chunk(self): - """Test that dd trace is applied even when the first chunk contains an error""" + """ + Test that when the first chunk contains an error, JSONResponse is returned + and tracing is not triggered (since it's not a streaming response) + """ from unittest.mock import patch # Create a mock tracer @@ -827,28 +842,107 @@ class TestCommonRequestProcessingHelpers: # Patch the tracer in the common_request_processing module with patch("litellm.proxy.common_request_processing.tracer", mock_tracer): - response = await create_streaming_response( + response = await create_response( mock_generator(), "text/event-stream", {} ) - # Even with error, status should be set to error code but tracing should still work + # Should return JSONResponse instead of StreamingResponse + assert isinstance(response, JSONResponse) assert response.status_code == 400 - # Consume the stream to trigger the tracer calls - content = await self.consume_stream(response) + # Verify the response is in standard JSON error format + import json + body = json.loads(response.body.decode()) + assert "error" in body + assert body["error"]["code"] == 400 + assert body["error"]["message"] == "bad request" - # Verify all chunks are present - assert len(content) == 3 + # Since JSONResponse is returned instead of StreamingResponse, streaming tracing should not be triggered + # tracer.trace should not be called + assert mock_tracer.trace.call_count == 0 - # Verify that tracer.trace was called for each chunk - assert mock_tracer.trace.call_count == 3 - # Verify that each call was made with the correct operation name - actual_calls = mock_tracer.trace.call_args_list - assert len(actual_calls) == 3 +class TestExtractErrorFromSSEChunk: + """Tests for _extract_error_from_sse_chunk function""" + + def test_extract_error_from_sse_chunk_with_valid_error(self): + """Test extracting error information from a standard SSE chunk""" + chunk = 'data: {"error": {"code": 403, "message": "forbidden", "type": "auth_error", "param": "api_key"}}\n\n' + error = _extract_error_from_sse_chunk(chunk) + + assert error["code"] == 403 + assert error["message"] == "forbidden" + assert error["type"] == "auth_error" + assert error["param"] == "api_key" + + def test_extract_error_from_sse_chunk_with_string_code(self): + """Test error code as string type""" + chunk = 'data: {"error": {"code": "429", "message": "too many requests"}}\n\n' + error = _extract_error_from_sse_chunk(chunk) + + assert error["code"] == "429" + assert error["message"] == "too many requests" + + def test_extract_error_from_sse_chunk_with_bytes(self): + """Test input as bytes type""" + chunk = b'data: {"error": {"code": 500, "message": "internal error"}}\n\n' + error = _extract_error_from_sse_chunk(chunk) + + assert error["code"] == 500 + assert error["message"] == "internal error" + + def test_extract_error_from_sse_chunk_with_done(self): + """Test [DONE] marker should return default error""" + chunk = "data: [DONE]\n\n" + error = _extract_error_from_sse_chunk(chunk) + + assert error["message"] == "Unknown error" + assert error["type"] == "internal_server_error" + assert error["code"] == "500" + assert error["param"] is None + + def test_extract_error_from_sse_chunk_without_error_field(self): + """Test missing error field should return default error""" + chunk = 'data: {"content": "some content"}\n\n' + error = _extract_error_from_sse_chunk(chunk) + + assert error["message"] == "Unknown error" + assert error["type"] == "internal_server_error" + assert error["code"] == "500" + + def test_extract_error_from_sse_chunk_with_invalid_json(self): + """Test invalid JSON should return default error""" + chunk = 'data: {invalid json}\n\n' + error = _extract_error_from_sse_chunk(chunk) + + assert error["message"] == "Unknown error" + assert error["type"] == "internal_server_error" + assert error["code"] == "500" + + def test_extract_error_from_sse_chunk_without_data_prefix(self): + """Test missing 'data:' prefix should return default error""" + chunk = '{"error": {"code": 400, "message": "bad request"}}\n\n' + error = _extract_error_from_sse_chunk(chunk) + + assert error["message"] == "Unknown error" + assert error["type"] == "internal_server_error" + assert error["code"] == "500" + + def test_extract_error_from_sse_chunk_with_empty_string(self): + """Test empty string should return default error""" + chunk = "" + error = _extract_error_from_sse_chunk(chunk) + + assert error["message"] == "Unknown error" + assert error["type"] == "internal_server_error" + assert error["code"] == "500" + + def test_extract_error_from_sse_chunk_with_minimal_error(self): + """Test minimal error object""" + chunk = 'data: {"error": {"message": "error occurred"}}\n\n' + error = _extract_error_from_sse_chunk(chunk) + + assert error["message"] == "error occurred" + # Other fields should be obtained from the original error object (if exists) + - for i, call in enumerate(actual_calls): - args, kwargs = call - assert ( - args[0] == "streaming.chunk.yield" - ), f"Call {i} should have operation name 'streaming.chunk.yield', got {args[0]}" diff --git a/ui/litellm-dashboard/e2e_tests/tests/users/viewInternalUsers.spec.ts b/ui/litellm-dashboard/e2e_tests/tests/users/viewInternalUsers.spec.ts index 980c7233e4..65797028ed 100644 --- a/ui/litellm-dashboard/e2e_tests/tests/users/viewInternalUsers.spec.ts +++ b/ui/litellm-dashboard/e2e_tests/tests/users/viewInternalUsers.spec.ts @@ -43,6 +43,7 @@ test.describe("Internal Users Page", () => { await expect(prevButton).toBeDisabled(); } + await page.waitForTimeout(1000); // Check if there are more pages const hasMorePages = infoText.includes("of") && !infoText.endsWith("25 of 25"); if (hasMorePages) { diff --git a/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/ModelsAndEndpointsView.tsx b/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/ModelsAndEndpointsView.tsx index a8b1d2cddc..1cce704467 100644 --- a/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/ModelsAndEndpointsView.tsx +++ b/ui/litellm-dashboard/src/app/(dashboard)/models-and-endpoints/ModelsAndEndpointsView.tsx @@ -279,12 +279,13 @@ const ModelsAndEndpointsView: React.FC = ({ premiumUser, te {/* Missing Provider Banner */}
- +

Missing a provider?

- The LiteLLM engineering team is constantly adding support for new LLM models, providers, endpoints. If you don't see the one you need, let us know and we'll prioritize it. + The LiteLLM engineering team is constantly adding support for new LLM models, providers, endpoints. If + you don't see the one you need, let us know and we'll prioritize it.

= ({ premiumUser, te className="flex-shrink-0 inline-flex items-center gap-2 px-4 py-2 bg-[#6366f1] hover:bg-[#5558e3] text-white text-sm font-medium rounded-lg transition-colors" > Request Provider - - + +
From 92f7789f1000c29519f3270442cc4ccd16cb989e Mon Sep 17 00:00:00 2001 From: Harshit Jain <48647625+Harshit28j@users.noreply.github.com> Date: Wed, 7 Jan 2026 21:29:04 +0530 Subject: [PATCH 3/9] feat(prometheus): add caching metrics (#18755) --- litellm/integrations/prometheus.py | 101 +++++++-- litellm/types/integrations/prometheus.py | 21 +- .../test_prometheus_cache_metrics.py | 211 ++++++++++++++++++ 3 files changed, 318 insertions(+), 15 deletions(-) create mode 100644 tests/test_litellm/integrations/test_prometheus_cache_metrics.py diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index c01f748127..2ec2f41b3b 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -14,6 +14,7 @@ from typing import ( Literal, Optional, Tuple, + Union, cast, ) @@ -44,6 +45,7 @@ def _get_cached_end_user_id_for_cost_tracking(): global _get_end_user_id_for_cost_tracking if _get_end_user_id_for_cost_tracking is None: from litellm.utils import get_end_user_id_for_cost_tracking + _get_end_user_id_for_cost_tracking = get_end_user_id_for_cost_tracking return _get_end_user_id_for_cost_tracking @@ -329,6 +331,25 @@ class PrometheusLogger(CustomLogger): labelnames=self.get_labels_for_metric("litellm_requests_metric"), ) + # Cache metrics + self.litellm_cache_hits_metric = self._counter_factory( + name="litellm_cache_hits_metric", + documentation="Total number of LiteLLM cache hits", + labelnames=self.get_labels_for_metric("litellm_cache_hits_metric"), + ) + + self.litellm_cache_misses_metric = self._counter_factory( + name="litellm_cache_misses_metric", + documentation="Total number of LiteLLM cache misses", + labelnames=self.get_labels_for_metric("litellm_cache_misses_metric"), + ) + + self.litellm_cached_tokens_metric = self._counter_factory( + name="litellm_cached_tokens_metric", + documentation="Total tokens served from LiteLLM cache", + labelnames=self.get_labels_for_metric("litellm_cached_tokens_metric"), + ) + except Exception as e: print_verbose(f"Got exception on init prometheus client {str(e)}") raise e @@ -795,7 +816,7 @@ class PrometheusLogger(CustomLogger): litellm_params = kwargs.get("litellm_params", {}) or {} _metadata = litellm_params.get("metadata", {}) get_end_user_id_for_cost_tracking = _get_cached_end_user_id_for_cost_tracking() - + end_user_id = get_end_user_id_for_cost_tracking( litellm_params, service_type="prometheus" ) @@ -815,7 +836,7 @@ class PrometheusLogger(CustomLogger): user_api_key_auth_metadata: Optional[dict] = standard_logging_payload[ "metadata" ].get("user_api_key_auth_metadata") - + # Include top-level metadata fields (excluding nested dictionaries) # This allows accessing fields like requester_ip_address from top-level metadata top_level_metadata = standard_logging_payload.get("metadata", {}) @@ -826,7 +847,7 @@ class PrometheusLogger(CustomLogger): for k, v in top_level_metadata.items() if not isinstance(v, dict) # Exclude nested dicts to avoid conflicts } - + combined_metadata: Dict[str, Any] = { **top_level_fields, # Include top-level fields first **(_requester_metadata if _requester_metadata else {}), @@ -945,6 +966,12 @@ class PrometheusLogger(CustomLogger): kwargs, start_time, end_time, enum_values, output_tokens ) + # cache metrics + self._increment_cache_metrics( + standard_logging_payload=standard_logging_payload, # type: ignore + enum_values=enum_values, + ) + if ( standard_logging_payload["stream"] is True ): # log successful streaming requests from logging event hook. @@ -1014,6 +1041,54 @@ class PrometheusLogger(CustomLogger): standard_logging_payload["completion_tokens"] ) + def _increment_cache_metrics( + self, + standard_logging_payload: StandardLoggingPayload, + enum_values: UserAPIKeyLabelValues, + ): + """ + Increment cache-related Prometheus metrics based on cache hit/miss status. + + Args: + standard_logging_payload: Contains cache_hit field (True/False/None) + enum_values: Label values for Prometheus metrics + """ + cache_hit = standard_logging_payload.get("cache_hit") + + # Only track if cache_hit has a definite value (True or False) + if cache_hit is None: + return + + if cache_hit is True: + # Increment cache hits counter + _labels = prometheus_label_factory( + supported_enum_labels=self.get_labels_for_metric( + metric_name="litellm_cache_hits_metric" + ), + enum_values=enum_values, + ) + self.litellm_cache_hits_metric.labels(**_labels).inc() + + # Increment cached tokens counter + total_tokens = standard_logging_payload.get("total_tokens", 0) + if total_tokens > 0: + _labels = prometheus_label_factory( + supported_enum_labels=self.get_labels_for_metric( + metric_name="litellm_cached_tokens_metric" + ), + enum_values=enum_values, + ) + self.litellm_cached_tokens_metric.labels(**_labels).inc(total_tokens) + else: + # cache_hit is False - increment cache misses counter + _labels = prometheus_label_factory( + supported_enum_labels=self.get_labels_for_metric( + metric_name="litellm_cache_misses_metric" + ), + enum_values=enum_values, + ) + self.litellm_cache_misses_metric.labels(**_labels).inc() + async def _increment_remaining_budget_metrics( self, user_api_team: Optional[str], @@ -1196,7 +1271,7 @@ class PrometheusLogger(CustomLogger): ) litellm_params = kwargs.get("litellm_params", {}) or {} get_end_user_id_for_cost_tracking = _get_cached_end_user_id_for_cost_tracking() - + end_user_id = get_end_user_id_for_cost_tracking( litellm_params, service_type="prometheus" ) @@ -1398,7 +1473,6 @@ class PrometheusLogger(CustomLogger): api_provider=llm_provider or "", ) if exception is not None: - _labels = prometheus_label_factory( supported_enum_labels=self.get_labels_for_metric( metric_name="litellm_deployment_failure_responses" @@ -1431,12 +1505,11 @@ class PrometheusLogger(CustomLogger): enum_values: UserAPIKeyLabelValues, output_tokens: float = 1.0, ): - try: verbose_logger.debug("setting remaining tokens requests metric") - standard_logging_payload: Optional[StandardLoggingPayload] = ( - request_kwargs.get("standard_logging_object") - ) + standard_logging_payload: Optional[ + StandardLoggingPayload + ] = request_kwargs.get("standard_logging_object") if standard_logging_payload is None: return @@ -2208,10 +2281,10 @@ class PrometheusLogger(CustomLogger): from litellm.constants import PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES from litellm.integrations.custom_logger import CustomLogger - prometheus_loggers: List[CustomLogger] = ( - litellm.logging_callback_manager.get_custom_loggers_for_type( - callback_type=PrometheusLogger - ) + prometheus_loggers: List[ + CustomLogger + ] = litellm.logging_callback_manager.get_custom_loggers_for_type( + callback_type=PrometheusLogger ) # we need to get the initialized prometheus logger instance(s) and call logger.initialize_remaining_budget_metrics() on them verbose_logger.debug("found %s prometheus loggers", len(prometheus_loggers)) @@ -2283,7 +2356,7 @@ def prometheus_label_factory( if UserAPIKeyLabelNames.END_USER.value in filtered_labels: get_end_user_id_for_cost_tracking = _get_cached_end_user_id_for_cost_tracking() - + filtered_labels["end_user"] = get_end_user_id_for_cost_tracking( litellm_params={"user_api_key_end_user_id": enum_values.end_user}, service_type="prometheus", diff --git a/litellm/types/integrations/prometheus.py b/litellm/types/integrations/prometheus.py index 6a254fc825..fb439a9541 100644 --- a/litellm/types/integrations/prometheus.py +++ b/litellm/types/integrations/prometheus.py @@ -1,7 +1,7 @@ import re from dataclasses import dataclass from enum import Enum -from typing import Dict, List, Literal, Optional, Tuple, Union +from typing import Dict, List, Literal, Optional, Tuple from pydantic import BaseModel, Field from typing_extensions import Annotated @@ -185,6 +185,10 @@ DEFINED_PROMETHEUS_METRICS = Literal[ "litellm_redis_daily_spend_update_queue_size", "litellm_in_memory_spend_update_queue_size", "litellm_redis_spend_update_queue_size", + # Cache metrics + "litellm_cache_hits_metric", + "litellm_cache_misses_metric", + "litellm_cached_tokens_metric", ] @@ -436,6 +440,21 @@ class PrometheusMetricLabels: litellm_redis_spend_update_queue_size: List[str] = [] + # Cache metrics - track cache hits, misses, and tokens served from cache + _cache_metric_labels = [ + UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value, + UserAPIKeyLabelNames.API_KEY_HASH.value, + UserAPIKeyLabelNames.API_KEY_ALIAS.value, + UserAPIKeyLabelNames.TEAM.value, + UserAPIKeyLabelNames.TEAM_ALIAS.value, + UserAPIKeyLabelNames.END_USER.value, + UserAPIKeyLabelNames.USER.value, + ] + + litellm_cache_hits_metric = _cache_metric_labels + litellm_cache_misses_metric = _cache_metric_labels + litellm_cached_tokens_metric = _cache_metric_labels + @staticmethod def get_labels(label_name: DEFINED_PROMETHEUS_METRICS) -> List[str]: default_labels = getattr(PrometheusMetricLabels, label_name) diff --git a/tests/test_litellm/integrations/test_prometheus_cache_metrics.py b/tests/test_litellm/integrations/test_prometheus_cache_metrics.py new file mode 100644 index 0000000000..660757673f --- /dev/null +++ b/tests/test_litellm/integrations/test_prometheus_cache_metrics.py @@ -0,0 +1,211 @@ +""" +Unit tests for cache Prometheus metrics. + +Run with: poetry run pytest tests/test_litellm/integrations/test_prometheus_cache_metrics.py -v +""" +import pytest +from unittest.mock import MagicMock, patch +from litellm.types.integrations.prometheus import UserAPIKeyLabelValues + + +class TestPrometheusCacheMetrics: + """Tests for cache-related Prometheus metrics""" + + @pytest.fixture + def sample_enum_values(self): + """Create sample enum values for labels""" + return UserAPIKeyLabelValues( + end_user="test-end-user", + hashed_api_key="test-key-hash", + api_key_alias="test-key-alias", + team="test-team", + team_alias="test-team-alias", + user="test-user", + model="gpt-3.5-turbo", + ) + + def test_cache_metrics_defined_in_types(self): + """Test that cache metrics are defined in DEFINED_PROMETHEUS_METRICS""" + from litellm.types.integrations.prometheus import DEFINED_PROMETHEUS_METRICS + from typing import get_args + + defined_metrics = get_args(DEFINED_PROMETHEUS_METRICS) + + assert "litellm_cache_hits_metric" in defined_metrics + assert "litellm_cache_misses_metric" in defined_metrics + assert "litellm_cached_tokens_metric" in defined_metrics + + def test_cache_metric_labels_defined(self): + """Test that cache metric labels are properly defined""" + from litellm.types.integrations.prometheus import PrometheusMetricLabels + + # Verify labels are defined for each cache metric + assert hasattr(PrometheusMetricLabels, "litellm_cache_hits_metric") + assert hasattr(PrometheusMetricLabels, "litellm_cache_misses_metric") + assert hasattr(PrometheusMetricLabels, "litellm_cached_tokens_metric") + + # Verify labels include expected keys + expected_labels = [ + "model", + "hashed_api_key", + "api_key_alias", + "team", + "team_alias", + "end_user", + "user", + ] + for label in expected_labels: + assert label in PrometheusMetricLabels.litellm_cache_hits_metric + assert label in PrometheusMetricLabels.litellm_cache_misses_metric + assert label in PrometheusMetricLabels.litellm_cached_tokens_metric + + def test_increment_cache_metrics_on_cache_hit(self, sample_enum_values): + """Test that cache hit increments the correct metrics""" + # Create mock for PrometheusLogger instance + mock_logger = MagicMock() + + # Import the method directly and bind it to our mock + from litellm.integrations.prometheus import PrometheusLogger + + # Create a mock standard logging payload with cache_hit=True + standard_logging_payload = { + "cache_hit": True, + "total_tokens": 100, + "prompt_tokens": 50, + "completion_tokens": 50, + "model_group": "openai", + "request_tags": [], + } + + # Create mock metrics + mock_logger.litellm_cache_hits_metric = MagicMock() + mock_logger.litellm_cache_misses_metric = MagicMock() + mock_logger.litellm_cached_tokens_metric = MagicMock() + mock_logger.get_labels_for_metric = MagicMock( + return_value=[ + "model", + "hashed_api_key", + "api_key_alias", + "team", + "team_alias", + "end_user", + "user", + ] + ) + + # Call the method using unbound method approach + PrometheusLogger._increment_cache_metrics( + mock_logger, + standard_logging_payload=standard_logging_payload, + enum_values=sample_enum_values, + ) + + # Verify cache hits metric was incremented + mock_logger.litellm_cache_hits_metric.labels.assert_called() + mock_logger.litellm_cache_hits_metric.labels().inc.assert_called_once() + + # Verify cached tokens metric was incremented with total_tokens + mock_logger.litellm_cached_tokens_metric.labels.assert_called() + mock_logger.litellm_cached_tokens_metric.labels().inc.assert_called_once_with( + 100 + ) + + # Verify cache misses metric was NOT called + mock_logger.litellm_cache_misses_metric.labels.assert_not_called() + + def test_increment_cache_metrics_on_cache_miss(self, sample_enum_values): + """Test that cache miss increments the correct metrics""" + # Create mock for PrometheusLogger instance + mock_logger = MagicMock() + + from litellm.integrations.prometheus import PrometheusLogger + + # Create a mock standard logging payload with cache_hit=False + standard_logging_payload = { + "cache_hit": False, + "total_tokens": 100, + "prompt_tokens": 50, + "completion_tokens": 50, + "model_group": "openai", + "request_tags": [], + } + + # Create mock metrics + mock_logger.litellm_cache_hits_metric = MagicMock() + mock_logger.litellm_cache_misses_metric = MagicMock() + mock_logger.litellm_cached_tokens_metric = MagicMock() + mock_logger.get_labels_for_metric = MagicMock( + return_value=[ + "model", + "hashed_api_key", + "api_key_alias", + "team", + "team_alias", + "end_user", + "user", + ] + ) + + # Call the method + PrometheusLogger._increment_cache_metrics( + mock_logger, + standard_logging_payload=standard_logging_payload, + enum_values=sample_enum_values, + ) + + # Verify cache misses metric was incremented + mock_logger.litellm_cache_misses_metric.labels.assert_called() + mock_logger.litellm_cache_misses_metric.labels().inc.assert_called_once() + + # Verify cache hits and cached tokens metrics were NOT called + mock_logger.litellm_cache_hits_metric.labels.assert_not_called() + mock_logger.litellm_cached_tokens_metric.labels.assert_not_called() + + def test_increment_cache_metrics_when_cache_hit_is_none(self, sample_enum_values): + """Test that no metrics are incremented when cache_hit is None""" + # Create mock for PrometheusLogger instance + mock_logger = MagicMock() + + from litellm.integrations.prometheus import PrometheusLogger + + # Create a mock standard logging payload with cache_hit=None + standard_logging_payload = { + "cache_hit": None, + "total_tokens": 100, + "prompt_tokens": 50, + "completion_tokens": 50, + "model_group": "openai", + "request_tags": [], + } + + # Create mock metrics + mock_logger.litellm_cache_hits_metric = MagicMock() + mock_logger.litellm_cache_misses_metric = MagicMock() + mock_logger.litellm_cached_tokens_metric = MagicMock() + mock_logger.get_labels_for_metric = MagicMock( + return_value=[ + "model", + "hashed_api_key", + "api_key_alias", + "team", + "team_alias", + "end_user", + "user", + ] + ) + + # Call the method + PrometheusLogger._increment_cache_metrics( + mock_logger, + standard_logging_payload=standard_logging_payload, + enum_values=sample_enum_values, + ) + + # Verify NO metrics were called + mock_logger.litellm_cache_hits_metric.labels.assert_not_called() + mock_logger.litellm_cache_misses_metric.labels.assert_not_called() + mock_logger.litellm_cached_tokens_metric.labels.assert_not_called() + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) From 1b8708fccc30dc5337c92709ba146bb96aa7ac8e Mon Sep 17 00:00:00 2001 From: kothamah <104782493+kothamah@users.noreply.github.com> Date: Wed, 7 Jan 2026 11:10:36 -0500 Subject: [PATCH 4/9] Litellm embeddings calltype fix for guardrail precallhook (#18740) * adding signoz integration to observability docs * Fixing build * Adding timeout for flaky test * Fixing e2e * add team member budget duration in team/update * Reusable Duration Select and update team member budget UI * feat: allow configuring project name for OpenTelemetry service name * docs: sets ARIZE_PROJECT_NAME * added valid callType for bedrock guardrail pre hook This is to resolve the error when bedrock guardrails are enabled and invoke the embedding models. {"error":{"message":"'embeddings' is not a valid CallTypes","type":"None","param":"None","code":"500"}}* * updated the test case to reflect valid callType --------- Co-authored-by: Goutham Karthi Co-authored-by: yuneng-jiang Co-authored-by: YutaSaito <36355491+uc4w6c@users.noreply.github.com> Co-authored-by: Yuta Saito --- .../docs/observability/arize_integration.md | 1 + litellm/integrations/arize/arize.py | 2 + litellm/integrations/opentelemetry.py | 88 +++-- litellm/litellm_core_utils/litellm_logging.py | 1 + litellm/proxy/_types.py | 1 + litellm/proxy/common_request_processing.py | 4 +- .../management_endpoints/team_endpoints.py | 15 +- litellm/types/integrations/arize.py | 1 + tests/local_testing/test_arize_ai.py | 3 + .../arize/test_arize_health_check.py | 10 +- .../integrations/test_custom_guardrail.py | 2 +- .../integrations/test_opentelemetry.py | 52 +-- .../test_team_endpoints.py | 366 +++++++++++++++++- .../common_components/DurationSelect.test.tsx | 49 +++ .../common_components/DurationSelect.tsx | 17 + .../src/components/team/team_info.tsx | 12 + 16 files changed, 552 insertions(+), 72 deletions(-) create mode 100644 ui/litellm-dashboard/src/components/common_components/DurationSelect.test.tsx create mode 100644 ui/litellm-dashboard/src/components/common_components/DurationSelect.tsx diff --git a/docs/my-website/docs/observability/arize_integration.md b/docs/my-website/docs/observability/arize_integration.md index 0b457f0868..b3ccf98ea3 100644 --- a/docs/my-website/docs/observability/arize_integration.md +++ b/docs/my-website/docs/observability/arize_integration.md @@ -68,6 +68,7 @@ environment_variables: ARIZE_API_KEY: "141a****" ARIZE_ENDPOINT: "https://otlp.arize.com/v1" # OPTIONAL - your custom arize GRPC api endpoint ARIZE_HTTP_ENDPOINT: "https://otlp.arize.com/v1" # OPTIONAL - your custom arize HTTP api endpoint. Set either this or ARIZE_ENDPOINT or Neither (defaults to https://otlp.arize.com/v1 on grpc) + ARIZE_PROJECT_NAME: "my-litellm-project" # OPTIONAL - sets the arize project name ``` 2. Start the proxy diff --git a/litellm/integrations/arize/arize.py b/litellm/integrations/arize/arize.py index 4d1aa80dcc..9c2f0d95d4 100644 --- a/litellm/integrations/arize/arize.py +++ b/litellm/integrations/arize/arize.py @@ -51,6 +51,7 @@ class ArizeLogger(OpenTelemetry): space_id = os.environ.get("ARIZE_SPACE_ID") space_key = os.environ.get("ARIZE_SPACE_KEY") api_key = os.environ.get("ARIZE_API_KEY") + project_name = os.environ.get("ARIZE_PROJECT_NAME") grpc_endpoint = os.environ.get("ARIZE_ENDPOINT") http_endpoint = os.environ.get("ARIZE_HTTP_ENDPOINT") @@ -74,6 +75,7 @@ class ArizeLogger(OpenTelemetry): api_key=api_key, protocol=protocol, endpoint=endpoint, + project_name=project_name, ) async def async_service_success_hook( diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index a7d2326d93..7e0cfab617 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -54,38 +54,6 @@ RAW_REQUEST_SPAN_NAME = "raw_gen_ai_request" LITELLM_REQUEST_SPAN_NAME = "litellm_request" -def _get_litellm_resource(): - """ - Create a proper OpenTelemetry Resource that respects OTEL_RESOURCE_ATTRIBUTES - while maintaining backward compatibility with LiteLLM-specific environment variables. - """ - from opentelemetry.sdk.resources import OTELResourceDetector, Resource - - # Create base resource attributes with LiteLLM-specific defaults - # These will be overridden by OTEL_RESOURCE_ATTRIBUTES if present - base_attributes: Dict[str, Optional[str]] = { - "service.name": os.getenv("OTEL_SERVICE_NAME", "litellm"), - "deployment.environment": os.getenv("OTEL_ENVIRONMENT_NAME", "production"), - # Fix the model_id to use proper environment variable or default to service name - "model_id": os.getenv( - "OTEL_MODEL_ID", os.getenv("OTEL_SERVICE_NAME", "litellm") - ), - } - - # Create base resource with LiteLLM-specific defaults - base_resource = Resource.create(base_attributes) # type: ignore - - # Create resource from OTEL_RESOURCE_ATTRIBUTES using the detector - otel_resource_detector = OTELResourceDetector() - env_resource = otel_resource_detector.detect() - - # Merge the resources: env_resource takes precedence over base_resource - # This ensures OTEL_RESOURCE_ATTRIBUTES overrides LiteLLM defaults - merged_resource = base_resource.merge(env_resource) - - return merged_resource - - @dataclass class OpenTelemetryConfig: exporter: Union[str, SpanExporter] = "console" @@ -93,6 +61,19 @@ class OpenTelemetryConfig: headers: Optional[str] = None enable_metrics: bool = False enable_events: bool = False + service_name: Optional[str] = None + deployment_environment: Optional[str] = None + model_id: Optional[str] = None + + def __post_init__(self) -> None: + if not self.service_name: + self.service_name = os.getenv("OTEL_SERVICE_NAME", "litellm") + if not self.deployment_environment: + self.deployment_environment = os.getenv( + "OTEL_ENVIRONMENT_NAME", "production" + ) + if not self.model_id: + self.model_id = os.getenv("OTEL_MODEL_ID", self.service_name) @classmethod def from_env(cls): @@ -122,6 +103,9 @@ class OpenTelemetryConfig: os.getenv("LITELLM_OTEL_INTEGRATION_ENABLE_EVENTS", "false").lower() == "true" ) + service_name = os.getenv("OTEL_SERVICE_NAME", "litellm") + deployment_environment = os.getenv("OTEL_ENVIRONMENT_NAME", "production") + model_id = os.getenv("OTEL_MODEL_ID", service_name) if exporter == "in_memory": return cls(exporter=InMemorySpanExporter()) @@ -131,6 +115,9 @@ class OpenTelemetryConfig: headers=headers, # example: OTEL_HEADERS=x-honeycomb-team=B85YgLm96***" enable_metrics=enable_metrics, enable_events=enable_events, + service_name=service_name, + deployment_environment=deployment_environment, + model_id=model_id, ) @@ -174,6 +161,22 @@ class OpenTelemetry(CustomLogger): self._init_logs(logger_provider) self._init_otel_logger_on_litellm_proxy() + @staticmethod + def _get_litellm_resource(config: OpenTelemetryConfig): + """Create an OpenTelemetry Resource using config-driven defaults.""" + from opentelemetry.sdk.resources import OTELResourceDetector, Resource + + base_attributes: Dict[str, Optional[str]] = { + "service.name": config.service_name, + "deployment.environment": config.deployment_environment, + "model_id": config.model_id or config.service_name, + } + + base_resource = Resource.create(base_attributes) # type: ignore[arg-type] + otel_resource_detector = OTELResourceDetector() + env_resource = otel_resource_detector.detect() + return base_resource.merge(env_resource) + def _init_otel_logger_on_litellm_proxy(self): """ Initializes OpenTelemetry for litellm proxy server @@ -266,7 +269,7 @@ class OpenTelemetry(CustomLogger): from opentelemetry.trace import SpanKind def create_tracer_provider(): - provider = TracerProvider(resource=_get_litellm_resource()) + provider = TracerProvider(resource=self._get_litellm_resource(self.config)) provider.add_span_processor(self._get_span_processor()) return provider @@ -300,7 +303,8 @@ class OpenTelemetry(CustomLogger): def create_meter_provider(): metric_reader = self._get_metric_reader() return MeterProvider( - metric_readers=[metric_reader], resource=_get_litellm_resource() + metric_readers=[metric_reader], + resource=self._get_litellm_resource(self.config), ) meter_provider = self._get_or_create_provider( @@ -355,7 +359,9 @@ class OpenTelemetry(CustomLogger): from opentelemetry.sdk._logs.export import BatchLogRecordProcessor def create_logger_provider(): - provider = OTLoggerProvider(resource=_get_litellm_resource()) + provider = OTLoggerProvider( + resource=self._get_litellm_resource(self.config) + ) log_exporter = self._get_log_exporter() provider.add_log_record_processor( BatchLogRecordProcessor(log_exporter) # type: ignore[arg-type] @@ -606,7 +612,7 @@ class OpenTelemetry(CustomLogger): from opentelemetry.sdk.trace import TracerProvider # Create a temporary tracer provider with dynamic headers - temp_provider = TracerProvider(resource=_get_litellm_resource()) + temp_provider = TracerProvider(resource=self._get_litellm_resource(self.config)) temp_provider.add_span_processor( self._get_span_processor(dynamic_headers=dynamic_headers) ) @@ -987,9 +993,9 @@ class OpenTelemetry(CustomLogger): # Get the resource from the logger provider logger_provider = get_logger_provider() - resource = ( - getattr(logger_provider, "_resource", None) or _get_litellm_resource() - ) + resource = getattr( + logger_provider, "_resource", None + ) or self._get_litellm_resource(self.config) parent_ctx = span.get_span_context() provider = (kwargs.get("litellm_params") or {}).get( @@ -1910,7 +1916,9 @@ class OpenTelemetry(CustomLogger): ) _split_otel_headers = OpenTelemetry._get_headers_dictionary(self.OTEL_HEADERS) - normalized_endpoint = self._normalize_otel_endpoint(self.OTEL_ENDPOINT, "metrics") + normalized_endpoint = self._normalize_otel_endpoint( + self.OTEL_ENDPOINT, "metrics" + ) if self.OTEL_EXPORTER == "console": exporter = ConsoleMetricExporter() diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index cd32493556..5448fe7c77 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -3630,6 +3630,7 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 otel_config = OpenTelemetryConfig( exporter=arize_config.protocol, endpoint=arize_config.endpoint, + service_name=arize_config.project_name, ) os.environ[ diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 954c26e2cb..cff3bee1ca 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1532,6 +1532,7 @@ class UpdateTeamRequest(LiteLLMPydanticObjectBase): guardrails: Optional[List[str]] = None object_permission: Optional[LiteLLM_ObjectPermissionBase] = None team_member_budget: Optional[float] = None + team_member_budget_duration: Optional[str] = None team_member_rpm_limit: Optional[int] = None team_member_tpm_limit: Optional[int] = None team_member_key_duration: Optional[str] = None diff --git a/litellm/proxy/common_request_processing.py b/litellm/proxy/common_request_processing.py index 95c23be5b8..4e0c4f2381 100644 --- a/litellm/proxy/common_request_processing.py +++ b/litellm/proxy/common_request_processing.py @@ -957,11 +957,11 @@ class ProxyBaseLLMRequestProcessing: @staticmethod def _get_pre_call_type( route_type: Literal["acompletion", "aembedding", "aresponses", "allm_passthrough_route"], - ) -> Literal["completion", "embeddings", "responses", "allm_passthrough_route"]: + ) -> Literal["completion", "embedding", "responses", "allm_passthrough_route"]: if route_type == "acompletion": return "completion" elif route_type == "aembedding": - return "embeddings" + return "embedding" elif route_type == "aresponses": return "responses" elif route_type == "allm_passthrough_route": diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 76c607f5c4..920105edc1 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -112,6 +112,7 @@ class TeamMemberBudgetHandler: team_member_budget: Optional[float] = None, team_member_rpm_limit: Optional[int] = None, team_member_tpm_limit: Optional[int] = None, + team_member_budget_duration: Optional[str] = None, ) -> bool: """Check if any team member limits are provided""" return any( @@ -119,6 +120,7 @@ class TeamMemberBudgetHandler: team_member_budget is not None, team_member_rpm_limit is not None, team_member_tpm_limit is not None, + team_member_budget_duration is not None, ] ) @@ -130,6 +132,7 @@ class TeamMemberBudgetHandler: team_member_budget: Optional[float] = None, team_member_rpm_limit: Optional[int] = None, team_member_tpm_limit: Optional[int] = None, + team_member_budget_duration: Optional[str] = None, ) -> dict: """Create team member budget table with provided limits""" from litellm.proxy._types import BudgetNewRequest @@ -147,7 +150,7 @@ class TeamMemberBudgetHandler: # Create budget request with all provided limits budget_request = BudgetNewRequest( budget_id=budget_id, - budget_duration=data.budget_duration, + budget_duration=data.budget_duration or team_member_budget_duration, ) if team_member_budget is not None: @@ -156,6 +159,8 @@ class TeamMemberBudgetHandler: budget_request.rpm_limit = team_member_rpm_limit if team_member_tpm_limit is not None: budget_request.tpm_limit = team_member_tpm_limit + if team_member_budget_duration is not None: + budget_request.budget_duration = team_member_budget_duration team_member_budget_table = await new_budget( budget_obj=budget_request, @@ -182,6 +187,7 @@ class TeamMemberBudgetHandler: team_member_budget: Optional[float] = None, team_member_rpm_limit: Optional[int] = None, team_member_tpm_limit: Optional[int] = None, + team_member_budget_duration: Optional[str] = None, ) -> dict: """Upsert team member budget table with provided limits""" from litellm.proxy._types import BudgetNewRequest @@ -203,6 +209,8 @@ class TeamMemberBudgetHandler: budget_request.rpm_limit = team_member_rpm_limit if team_member_tpm_limit is not None: budget_request.tpm_limit = team_member_tpm_limit + if team_member_budget_duration is not None: + budget_request.budget_duration = team_member_budget_duration budget_row = await update_budget( budget_obj=budget_request, @@ -223,6 +231,7 @@ class TeamMemberBudgetHandler: team_member_budget=team_member_budget, team_member_rpm_limit=team_member_rpm_limit, team_member_tpm_limit=team_member_tpm_limit, + team_member_budget_duration=team_member_budget_duration, ) # Remove team member fields from updated_kv @@ -233,6 +242,7 @@ class TeamMemberBudgetHandler: def _clean_team_member_fields(data_dict: dict) -> None: """Remove team member fields from data dictionary""" data_dict.pop("team_member_budget", None) + data_dict.pop("team_member_budget_duration", None) data_dict.pop("team_member_rpm_limit", None) data_dict.pop("team_member_tpm_limit", None) @@ -1214,6 +1224,7 @@ async def update_team( # noqa: PLR0915 - disable_global_guardrails: Optional[bool] - Whether to disable global guardrails for the key. - object_permission: Optional[LiteLLM_ObjectPermissionBase] - team-specific object permission. Example - {"vector_stores": ["vector_store_1", "vector_store_2"], "agents": ["agent_1", "agent_2"], "agent_access_groups": ["dev_group"]}. IF null or {} then no object permission. - team_member_budget: Optional[float] - The maximum budget allocated to an individual team member. + - team_member_budget_duration: Optional[str] - The duration of the budget for the team member. Doc [here](https://docs.litellm.ai/docs/proxy/team_budgets) - team_member_rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for individual team members. - team_member_tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for individual team members. - team_member_key_duration: Optional[str] - The duration for a team member's key. e.g. "1d", "1w", "1mo" @@ -1349,6 +1360,7 @@ async def update_team( # noqa: PLR0915 team_member_budget=data.team_member_budget, team_member_rpm_limit=data.team_member_rpm_limit, team_member_tpm_limit=data.team_member_tpm_limit, + team_member_budget_duration=data.team_member_budget_duration, ): updated_kv = await TeamMemberBudgetHandler.upsert_team_member_budget_table( team_table=existing_team_row, @@ -1357,6 +1369,7 @@ async def update_team( # noqa: PLR0915 team_member_budget=data.team_member_budget, team_member_rpm_limit=data.team_member_rpm_limit, team_member_tpm_limit=data.team_member_tpm_limit, + team_member_budget_duration=data.team_member_budget_duration, ) else: TeamMemberBudgetHandler._clean_team_member_fields(updated_kv) diff --git a/litellm/types/integrations/arize.py b/litellm/types/integrations/arize.py index be4df30e79..248fdac3b3 100644 --- a/litellm/types/integrations/arize.py +++ b/litellm/types/integrations/arize.py @@ -14,3 +14,4 @@ class ArizeConfig(BaseModel): api_key: Optional[str] = None protocol: Protocol endpoint: str + project_name: Optional[str] = None diff --git a/tests/local_testing/test_arize_ai.py b/tests/local_testing/test_arize_ai.py index 6a77352143..3b497d638a 100644 --- a/tests/local_testing/test_arize_ai.py +++ b/tests/local_testing/test_arize_ai.py @@ -71,6 +71,7 @@ def test_get_arize_config(mock_env_vars): assert config.api_key == "test_api_key" assert config.endpoint == "https://otlp.arize.com/v1" assert config.protocol == "otlp_grpc" + assert config.project_name is None def test_get_arize_config_with_endpoints(mock_env_vars, monkeypatch): @@ -79,10 +80,12 @@ def test_get_arize_config_with_endpoints(mock_env_vars, monkeypatch): """ monkeypatch.setenv("ARIZE_ENDPOINT", "grpc://test.endpoint") monkeypatch.setenv("ARIZE_HTTP_ENDPOINT", "http://test.endpoint") + monkeypatch.setenv("ARIZE_PROJECT_NAME", "custom-project") config = ArizeLogger.get_arize_config() assert config.endpoint == "grpc://test.endpoint" assert config.protocol == "otlp_grpc" + assert config.project_name == "custom-project" @pytest.mark.skip( diff --git a/tests/test_litellm/integrations/arize/test_arize_health_check.py b/tests/test_litellm/integrations/arize/test_arize_health_check.py index 91d0b42d48..8d86b7dc09 100644 --- a/tests/test_litellm/integrations/arize/test_arize_health_check.py +++ b/tests/test_litellm/integrations/arize/test_arize_health_check.py @@ -123,7 +123,8 @@ class TestArizeIntegrationWithProxy: with patch.dict(os.environ, { "ARIZE_SPACE_KEY": "test-space-123", "ARIZE_API_KEY": "test-api-456", - "ARIZE_ENDPOINT": "https://custom.arize.com/v1" + "ARIZE_ENDPOINT": "https://custom.arize.com/v1", + "ARIZE_PROJECT_NAME": "custom-project", }): config = ArizeLogger.get_arize_config() @@ -131,13 +132,15 @@ class TestArizeIntegrationWithProxy: assert config.api_key == "test-api-456" assert config.endpoint == "https://custom.arize.com/v1" assert config.protocol == "otlp_grpc" + assert config.project_name == "custom-project" def test_arize_get_config_defaults(self): """Test ArizeLogger.get_arize_config() with default endpoint.""" with patch.dict(os.environ, { "ARIZE_SPACE_KEY": "test-space-default", - "ARIZE_API_KEY": "test-api-default" + "ARIZE_API_KEY": "test-api-default", + "ARIZE_PROJECT_NAME": "default-project", }, clear=True): config = ArizeLogger.get_arize_config() @@ -145,6 +148,7 @@ class TestArizeIntegrationWithProxy: assert config.api_key == "test-api-default" assert config.endpoint == "https://otlp.arize.com/v1" # Default endpoint assert config.protocol == "otlp_grpc" # Default protocol + assert config.project_name == "default-project" def test_arize_construct_dynamic_headers(self): """Test dynamic OTEL headers construction for team/key logging.""" @@ -180,4 +184,4 @@ class TestArizeIntegrationWithProxy: if __name__ == "__main__": - pytest.main([__file__, "-v"]) \ No newline at end of file + pytest.main([__file__, "-v"]) diff --git a/tests/test_litellm/integrations/test_custom_guardrail.py b/tests/test_litellm/integrations/test_custom_guardrail.py index a322dfe9a2..d7d7720ff4 100644 --- a/tests/test_litellm/integrations/test_custom_guardrail.py +++ b/tests/test_litellm/integrations/test_custom_guardrail.py @@ -530,7 +530,7 @@ class TestPassthroughCallTypeHandling: ) assert ( ProxyBaseLLMRequestProcessing._get_pre_call_type(route_type="aembedding") - == "embeddings" + == "embedding" ) assert ( ProxyBaseLLMRequestProcessing._get_pre_call_type(route_type="aresponses") diff --git a/tests/test_litellm/integrations/test_opentelemetry.py b/tests/test_litellm/integrations/test_opentelemetry.py index 6c17570e13..55b65fbb92 100644 --- a/tests/test_litellm/integrations/test_opentelemetry.py +++ b/tests/test_litellm/integrations/test_opentelemetry.py @@ -258,6 +258,22 @@ class TestOpenTelemetry(unittest.TestCase): MODEL = "arn:aws:bedrock:us-west-2:1234567890123:inference-profile/us.anthropic.claude-3-7-sonnet-20250219-v1:0" HERE = os.path.dirname(__file__) + @patch.dict(os.environ, {}, clear=True) + def test_open_telemetry_config_manual_defaults(self): + """Manual OpenTelemetryConfig creation should populate default identifiers.""" + config = OpenTelemetryConfig(exporter="console", endpoint="http://collector") + self.assertEqual(config.service_name, "litellm") + self.assertEqual(config.deployment_environment, "production") + self.assertEqual(config.model_id, "litellm") + + @patch.dict(os.environ, {}, clear=True) + def test_open_telemetry_config_custom_service_name(self): + """Model ID should inherit provided service name when not explicitly set.""" + config = OpenTelemetryConfig(service_name="custom-service", exporter="console") + self.assertEqual(config.service_name, "custom-service") + self.assertEqual(config.deployment_environment, "production") + self.assertEqual(config.model_id, "custom-service") + def wait_for_spans(self, exporter: InMemorySpanExporter, prefix: str): """Poll until we see at least one span with an attribute key starting with `prefix`.""" deadline = time.time() + self.POLL_TIMEOUT @@ -504,8 +520,6 @@ class TestOpenTelemetry(unittest.TestCase): self, mock_detector_cls, mock_resource_create ): """Test _get_litellm_resource with default values when no environment variables are set.""" - from litellm.integrations.opentelemetry import _get_litellm_resource - # Mock the Resource.create method mock_base_resource = MagicMock() mock_resource_create.return_value = mock_base_resource @@ -520,8 +534,8 @@ class TestOpenTelemetry(unittest.TestCase): mock_merged_resource = MagicMock() mock_base_resource.merge.return_value = mock_merged_resource - # Call the function - result = _get_litellm_resource() + config = OpenTelemetryConfig() + result = OpenTelemetry._get_litellm_resource(config) # Verify Resource.create was called with correct default attributes expected_attributes = { @@ -549,8 +563,6 @@ class TestOpenTelemetry(unittest.TestCase): self, mock_detector_cls, mock_resource_create ): """Test _get_litellm_resource with LiteLLM-specific environment variables.""" - from litellm.integrations.opentelemetry import _get_litellm_resource - # Mock the Resource.create method mock_base_resource = MagicMock() mock_resource_create.return_value = mock_base_resource @@ -565,8 +577,8 @@ class TestOpenTelemetry(unittest.TestCase): mock_merged_resource = MagicMock() mock_base_resource.merge.return_value = mock_merged_resource - # Call the function - result = _get_litellm_resource() + config = OpenTelemetryConfig.from_env() + result = OpenTelemetry._get_litellm_resource(config) # Verify Resource.create was called with environment variable values expected_attributes = { @@ -593,8 +605,6 @@ class TestOpenTelemetry(unittest.TestCase): self, mock_detector_cls, mock_resource_create ): """Test _get_litellm_resource with OTEL_RESOURCE_ATTRIBUTES environment variable.""" - from litellm.integrations.opentelemetry import _get_litellm_resource - # Mock the Resource.create method to simulate the actual behavior # In reality, Resource.create() would parse OTEL_RESOURCE_ATTRIBUTES and merge it mock_base_resource = MagicMock() @@ -610,8 +620,8 @@ class TestOpenTelemetry(unittest.TestCase): mock_merged_resource = MagicMock() mock_base_resource.merge.return_value = mock_merged_resource - # Call the function - result = _get_litellm_resource() + config = OpenTelemetryConfig.from_env() + result = OpenTelemetry._get_litellm_resource(config) # Verify Resource.create was called with the base attributes # The actual OTEL_RESOURCE_ATTRIBUTES parsing is handled by OpenTelemetry SDK @@ -628,10 +638,8 @@ class TestOpenTelemetry(unittest.TestCase): @patch.dict(os.environ, {}, clear=True) def test_get_litellm_resource_integration_with_real_resource(self): """Integration test to verify _get_litellm_resource works with actual OpenTelemetry Resource.""" - from litellm.integrations.opentelemetry import _get_litellm_resource - - # This test uses the real OpenTelemetry Resource.create() method - result = _get_litellm_resource() + config = OpenTelemetryConfig() + result = OpenTelemetry._get_litellm_resource(config) # Verify the result is a Resource instance from opentelemetry.sdk.resources import Resource @@ -653,10 +661,8 @@ class TestOpenTelemetry(unittest.TestCase): ) def test_get_litellm_resource_real_otel_resource_attributes(self): """Integration test to verify OTEL_RESOURCE_ATTRIBUTES is properly handled.""" - from litellm.integrations.opentelemetry import _get_litellm_resource - - # This test uses the real OpenTelemetry Resource.create() method - result = _get_litellm_resource() + config = OpenTelemetryConfig.from_env() + result = OpenTelemetry._get_litellm_resource(config) print("RESULT", result) @@ -683,10 +689,8 @@ class TestOpenTelemetry(unittest.TestCase): ) def test_get_litellm_resource_precedence(self): """Test that OTEL_SERVICE_NAME takes precedence over OTEL_RESOURCE_ATTRIBUTES according to OpenTelemetry spec.""" - from litellm.integrations.opentelemetry import _get_litellm_resource - - # This test verifies the OpenTelemetry standard behavior - result = _get_litellm_resource() + config = OpenTelemetryConfig.from_env() + result = OpenTelemetry._get_litellm_resource(config) # Verify the result is a Resource instance from opentelemetry.sdk.resources import Resource diff --git a/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py index 6cf8f745e0..57064586af 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_team_endpoints.py @@ -1279,7 +1279,7 @@ async def test_update_team_team_member_budget_not_passed_to_db(): # Mock budget upsert to return updated_kv without team_member_budget def mock_upsert_side_effect( - team_table, user_api_key_dict, updated_kv, team_member_budget=None, team_member_rpm_limit=None, team_member_tpm_limit=None + team_table, user_api_key_dict, updated_kv, team_member_budget=None, team_member_rpm_limit=None, team_member_tpm_limit=None, team_member_budget_duration=None ): # Remove team_member_budget from updated_kv as the real function does result_kv = updated_kv.copy() @@ -1376,6 +1376,370 @@ async def test_update_team_team_member_budget_not_passed_to_db(): ) +def test_clean_team_member_fields(): + """ + Test that _clean_team_member_fields removes all team member fields from a dictionary. + """ + from litellm.proxy.management_endpoints.team_endpoints import ( + TeamMemberBudgetHandler, + ) + + data_dict = { + "team_id": "test_team", + "team_alias": "Test Team", + "team_member_budget": 100.0, + "team_member_budget_duration": "30d", + "team_member_rpm_limit": 50, + "team_member_tpm_limit": 1000, + "other_field": "should_remain", + } + + TeamMemberBudgetHandler._clean_team_member_fields(data_dict) + + assert "team_member_budget" not in data_dict + assert "team_member_budget_duration" not in data_dict + assert "team_member_rpm_limit" not in data_dict + assert "team_member_tpm_limit" not in data_dict + assert data_dict["team_id"] == "test_team" + assert data_dict["team_alias"] == "Test Team" + assert data_dict["other_field"] == "should_remain" + + +def test_clean_team_member_fields_with_missing_fields(): + """ + Test that _clean_team_member_fields handles dictionaries without team member fields gracefully. + """ + from litellm.proxy.management_endpoints.team_endpoints import ( + TeamMemberBudgetHandler, + ) + + data_dict = { + "team_id": "test_team", + "team_alias": "Test Team", + } + + TeamMemberBudgetHandler._clean_team_member_fields(data_dict) + + assert data_dict["team_id"] == "test_team" + assert data_dict["team_alias"] == "Test Team" + + +@pytest.mark.asyncio +async def test_create_team_member_budget_table(): + """ + Test that create_team_member_budget_table creates a budget and adds it to metadata. + """ + from unittest.mock import AsyncMock, MagicMock, patch + + from litellm.proxy._types import LitellmUserRoles, NewTeamRequest, UserAPIKeyAuth + from litellm.proxy.management_endpoints.team_endpoints import ( + TeamMemberBudgetHandler, + ) + + mock_user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, user_id="test_user_id" + ) + + data = NewTeamRequest( + team_id="test_team_id", + team_alias="Test Team", + budget_duration="1mo", + ) + new_team_data_json = { + "team_id": "test_team_id", + "team_alias": "Test Team", + "team_member_budget": 100.0, + "team_member_budget_duration": "30d", + "team_member_rpm_limit": 50, + "team_member_tpm_limit": 1000, + } + + mock_budget_response = MagicMock() + mock_budget_response.budget_id = "budget_123" + + with patch( + "litellm.proxy.management_endpoints.budget_management_endpoints.new_budget", + new_callable=AsyncMock + ) as mock_new_budget: + mock_new_budget.return_value = mock_budget_response + + result = await TeamMemberBudgetHandler.create_team_member_budget_table( + data=data, + new_team_data_json=new_team_data_json, + user_api_key_dict=mock_user_api_key_dict, + team_member_budget=100.0, + team_member_rpm_limit=50, + team_member_tpm_limit=1000, + team_member_budget_duration="30d", + ) + + assert mock_new_budget.called + call_args = mock_new_budget.call_args + budget_request = call_args[1]["budget_obj"] + + assert budget_request.max_budget == 100.0 + assert budget_request.rpm_limit == 50 + assert budget_request.tpm_limit == 1000 + assert budget_request.budget_duration == "30d" + assert budget_request.budget_id is not None + assert "team-" in budget_request.budget_id + + assert "team_member_budget_id" in result["metadata"] + assert result["metadata"]["team_member_budget_id"] == "budget_123" + + assert "team_member_budget" not in result + assert "team_member_budget_duration" not in result + assert "team_member_rpm_limit" not in result + assert "team_member_tpm_limit" not in result + + +@pytest.mark.asyncio +async def test_create_team_member_budget_table_without_team_alias(): + """ + Test that create_team_member_budget_table generates budget_id correctly when team_alias is None. + """ + from unittest.mock import AsyncMock, MagicMock, patch + + from litellm.proxy._types import LitellmUserRoles, NewTeamRequest, UserAPIKeyAuth + from litellm.proxy.management_endpoints.team_endpoints import ( + TeamMemberBudgetHandler, + ) + + mock_user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, user_id="test_user_id" + ) + + data = NewTeamRequest(team_id="test_team_id") + new_team_data_json = { + "team_id": "test_team_id", + "team_member_budget": 100.0, + } + + mock_budget_response = MagicMock() + mock_budget_response.budget_id = "budget_123" + + with patch( + "litellm.proxy.management_endpoints.budget_management_endpoints.new_budget", + new_callable=AsyncMock + ) as mock_new_budget: + mock_new_budget.return_value = mock_budget_response + + result = await TeamMemberBudgetHandler.create_team_member_budget_table( + data=data, + new_team_data_json=new_team_data_json, + user_api_key_dict=mock_user_api_key_dict, + team_member_budget=100.0, + ) + + assert mock_new_budget.called + call_args = mock_new_budget.call_args + budget_request = call_args[1]["budget_obj"] + + assert budget_request.budget_id is not None + assert budget_request.budget_id.startswith("team-budget-") + + +@pytest.mark.asyncio +async def test_upsert_team_member_budget_table_existing_budget(): + """ + Test that upsert_team_member_budget_table updates an existing budget when team_member_budget_id exists. + """ + from unittest.mock import AsyncMock, MagicMock, patch + + from litellm.proxy._types import LitellmUserRoles, LiteLLM_TeamTable, UserAPIKeyAuth + from litellm.proxy.management_endpoints.team_endpoints import ( + TeamMemberBudgetHandler, + ) + + mock_user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, user_id="test_user_id" + ) + + team_table = MagicMock(spec=LiteLLM_TeamTable) + team_table.metadata = {"team_member_budget_id": "existing_budget_123"} + + updated_kv = { + "team_id": "test_team_id", + "team_member_budget": 200.0, + "team_member_budget_duration": "60d", + "team_member_rpm_limit": 100, + } + + mock_budget_response = MagicMock() + mock_budget_response.budget_id = "existing_budget_123" + + with patch( + "litellm.proxy.management_endpoints.budget_management_endpoints.update_budget", + new_callable=AsyncMock + ) as mock_update_budget: + mock_update_budget.return_value = mock_budget_response + + result = await TeamMemberBudgetHandler.upsert_team_member_budget_table( + team_table=team_table, + user_api_key_dict=mock_user_api_key_dict, + updated_kv=updated_kv, + team_member_budget=200.0, + team_member_budget_duration="60d", + team_member_rpm_limit=100, + ) + + assert mock_update_budget.called + call_args = mock_update_budget.call_args + budget_request = call_args[1]["budget_obj"] + + assert budget_request.budget_id == "existing_budget_123" + assert budget_request.max_budget == 200.0 + assert budget_request.budget_duration == "60d" + assert budget_request.rpm_limit == 100 + + assert "team_member_budget_id" in result["metadata"] + assert result["metadata"]["team_member_budget_id"] == "existing_budget_123" + + assert "team_member_budget" not in result + assert "team_member_budget_duration" not in result + assert "team_member_rpm_limit" not in result + + +@pytest.mark.asyncio +async def test_upsert_team_member_budget_table_no_existing_budget(): + """ + Test that upsert_team_member_budget_table creates a new budget when team_member_budget_id does not exist. + """ + from unittest.mock import AsyncMock, MagicMock, patch + + from litellm.proxy._types import LitellmUserRoles, LiteLLM_TeamTable, UserAPIKeyAuth + from litellm.proxy.management_endpoints.team_endpoints import ( + TeamMemberBudgetHandler, + ) + + mock_user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, user_id="test_user_id" + ) + + team_table = MagicMock(spec=LiteLLM_TeamTable) + team_table.metadata = {} + team_table.team_alias = "Test Team" + team_table.budget_duration = None + + updated_kv = { + "team_id": "test_team_id", + "team_member_budget": 150.0, + "team_member_budget_duration": "45d", + } + + mock_budget_response = MagicMock() + mock_budget_response.budget_id = "new_budget_456" + + with patch( + "litellm.proxy.management_endpoints.budget_management_endpoints.new_budget", + new_callable=AsyncMock + ) as mock_new_budget: + mock_new_budget.return_value = mock_budget_response + + result = await TeamMemberBudgetHandler.upsert_team_member_budget_table( + team_table=team_table, + user_api_key_dict=mock_user_api_key_dict, + updated_kv=updated_kv, + team_member_budget=150.0, + team_member_budget_duration="45d", + ) + + assert mock_new_budget.called + assert "team_member_budget_id" in result["metadata"] + assert result["metadata"]["team_member_budget_id"] == "new_budget_456" + + assert "team_member_budget" not in result + assert "team_member_budget_duration" not in result + + +@pytest.mark.asyncio +async def test_update_team_with_team_member_budget_duration(): + """ + Test that team/update endpoint properly handles team_member_budget_duration. + """ + from unittest.mock import AsyncMock, MagicMock, Mock, patch + + from fastapi import Request + + from litellm.proxy._types import LitellmUserRoles, UpdateTeamRequest, UserAPIKeyAuth + from litellm.proxy.management_endpoints.team_endpoints import update_team + + mock_request = Mock(spec=Request) + mock_user_api_key_dict = UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, user_id="test_user_id" + ) + + with patch("litellm.proxy.proxy_server.prisma_client") as mock_prisma_client, patch( + "litellm.proxy.proxy_server.llm_router" + ) as mock_llm_router, patch( + "litellm.proxy.proxy_server.user_api_key_cache" + ) as mock_cache, patch( + "litellm.proxy.proxy_server.proxy_logging_obj" + ) as mock_logging, patch( + "litellm.proxy.proxy_server.litellm_proxy_admin_name", "admin" + ), patch( + "litellm.proxy.auth.auth_checks._cache_team_object" + ) as mock_cache_team, patch( + "litellm.proxy.management_endpoints.team_endpoints.TeamMemberBudgetHandler.upsert_team_member_budget_table" + ) as mock_upsert_budget: + + mock_existing_team = MagicMock() + mock_existing_team.model_dump.return_value = { + "team_id": "test_team_id", + "team_alias": "test_team", + "metadata": {"team_member_budget_id": "budget_123"}, + } + mock_existing_team.metadata = {"team_member_budget_id": "budget_123"} + mock_prisma_client.db.litellm_teamtable.find_unique = AsyncMock( + return_value=mock_existing_team + ) + + mock_updated_team = MagicMock() + mock_updated_team.team_id = "test_team_id" + mock_updated_team.model_dump.return_value = {"team_id": "test_team_id"} + mock_prisma_client.db.litellm_teamtable.update = AsyncMock( + return_value=mock_updated_team + ) + mock_prisma_client.jsonify_team_object = MagicMock( + side_effect=lambda db_data: db_data + ) + + def mock_upsert_side_effect( + team_table, user_api_key_dict, updated_kv, team_member_budget=None, team_member_rpm_limit=None, team_member_tpm_limit=None, team_member_budget_duration=None + ): + result_kv = updated_kv.copy() + result_kv.pop("team_member_budget", None) + result_kv.pop("team_member_budget_duration", None) + return result_kv + + mock_upsert_budget.side_effect = mock_upsert_side_effect + + update_request = UpdateTeamRequest( + team_id="test_team_id", + team_alias="updated_alias", + team_member_budget=100.0, + team_member_budget_duration="30d", + ) + + result = await update_team( + data=update_request, + http_request=mock_request, + user_api_key_dict=mock_user_api_key_dict, + ) + + assert mock_upsert_budget.called + call_args = mock_upsert_budget.call_args + assert call_args[1]["team_member_budget"] == 100.0 + assert call_args[1]["team_member_budget_duration"] == "30d" + + assert mock_prisma_client.db.litellm_teamtable.update.called + update_call_args = mock_prisma_client.db.litellm_teamtable.update.call_args + update_data = update_call_args[1]["data"] + + assert "team_member_budget" not in update_data + assert "team_member_budget_duration" not in update_data + + @pytest.mark.asyncio async def test_bulk_team_member_add_success(): """ diff --git a/ui/litellm-dashboard/src/components/common_components/DurationSelect.test.tsx b/ui/litellm-dashboard/src/components/common_components/DurationSelect.test.tsx new file mode 100644 index 0000000000..296ef1ae63 --- /dev/null +++ b/ui/litellm-dashboard/src/components/common_components/DurationSelect.test.tsx @@ -0,0 +1,49 @@ +import { render, screen } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { describe, it, expect, vi } from "vitest"; +import DurationSelect from "./DurationSelect"; + +describe("DurationSelect", () => { + it("should render", () => { + render(); + expect(screen.getByRole("combobox")).toBeInTheDocument(); + }); + + it("should render all three duration options", async () => { + const user = userEvent.setup(); + render(); + + const select = screen.getByRole("combobox"); + await user.click(select); + + expect(screen.getByText("Daily")).toBeInTheDocument(); + expect(screen.getByText("Weekly")).toBeInTheDocument(); + expect(screen.getByText("Monthly")).toBeInTheDocument(); + }); + + it("should apply className prop", () => { + render(); + const select = screen.getByRole("combobox"); + expect(select.closest(".test-class")).toBeInTheDocument(); + }); + + it("should call onChange when an option is selected", async () => { + const user = userEvent.setup(); + const onChange = vi.fn(); + render(); + + const select = screen.getByRole("combobox"); + await user.click(select); + + const dailyOption = screen.getByText("Daily"); + await user.click(dailyOption); + + expect(onChange).toHaveBeenCalledWith("24h", expect.any(Object)); + }); + + it("should accept and pass value prop to Select", () => { + render(); + const select = screen.getByRole("combobox"); + expect(select).toBeInTheDocument(); + }); +}); diff --git a/ui/litellm-dashboard/src/components/common_components/DurationSelect.tsx b/ui/litellm-dashboard/src/components/common_components/DurationSelect.tsx new file mode 100644 index 0000000000..a84e8aeb11 --- /dev/null +++ b/ui/litellm-dashboard/src/components/common_components/DurationSelect.tsx @@ -0,0 +1,17 @@ +import { Select } from "antd"; + +interface DurationSelectProps { + className?: string; + value?: string; + onChange?: (value: string) => void; +} + +export default function DurationSelect({ className, value, onChange }: DurationSelectProps) { + return ( + + ); +} diff --git a/ui/litellm-dashboard/src/components/team/team_info.tsx b/ui/litellm-dashboard/src/components/team/team_info.tsx index d2d1c88593..fd7994aa1d 100644 --- a/ui/litellm-dashboard/src/components/team/team_info.tsx +++ b/ui/litellm-dashboard/src/components/team/team_info.tsx @@ -48,6 +48,7 @@ import EditLoggingSettings from "./EditLoggingSettings"; import MemberModal from "./EditMembership"; import MemberPermissions from "./member_permissions"; import TeamMembersComponent from "./team_member_view"; +import DurationSelect from "../common_components/DurationSelect"; export interface TeamMembership { user_id: string; @@ -413,6 +414,7 @@ const TeamInfoView: React.FC = ({ }; updateData.max_budget = mapEmptyStringToNull(updateData.max_budget); + updateData.team_member_budget_duration = values.team_member_budget_duration; if (values.team_member_budget !== undefined) { updateData.team_member_budget = Number(values.team_member_budget); @@ -650,6 +652,8 @@ const TeamInfoView: React.FC = ({ budget_duration: info.budget_duration, team_member_tpm_limit: info.team_member_budget_table?.tpm_limit, team_member_rpm_limit: info.team_member_budget_table?.rpm_limit, + team_member_budget: info.team_member_budget_table?.max_budget, + team_member_budget_duration: info.team_member_budget_table?.budget_duration, guardrails: info.metadata?.guardrails || [], disable_global_guardrails: info.metadata?.disable_global_guardrails || false, metadata: info.metadata @@ -747,6 +751,13 @@ const TeamInfoView: React.FC = ({ + + form.setFieldValue("team_member_budget_duration", value)} + value={form.getFieldValue("team_member_budget_duration")} + /> + + = ({
Max Budget: {info.team_member_budget_table?.max_budget || "No Limit"}
+
Budget Duration: {info.team_member_budget_table?.budget_duration || "No Limit"}
Key Duration: {info.metadata?.team_member_key_duration || "No Limit"}
TPM Limit: {info.team_member_budget_table?.tpm_limit || "No Limit"}
RPM Limit: {info.team_member_budget_table?.rpm_limit || "No Limit"}
From f7212d84d5ae7ac2f7bf469709bea03080acb6b7 Mon Sep 17 00:00:00 2001 From: tianduo-fh <153234738+Tianduo16@users.noreply.github.com> Date: Wed, 7 Jan 2026 11:12:27 -0500 Subject: [PATCH 5/9] fix: prevent duplicate User-Agent tags in request_tags (#18723) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The `_get_request_tags` function was returning a reference to the original tags list from metadata, then mutating it with `.extend()`. This caused duplicate User-Agent tags when the function was called multiple times during a single request lifecycle (e.g., by logging, prometheus, and guardrails). The fix uses `.copy()` to create a new list before extending, ensuring the original metadata tags are not mutated. πŸ€– Generated with [Claude Code](https://claude.com/claude-code) Co-authored-by: Tianduo Zhai Co-authored-by: Claude Opus 4.5 --- litellm/litellm_core_utils/litellm_logging.py | 4 +- .../test_litellm_logging.py | 57 +++++++++++++++++++ 2 files changed, 59 insertions(+), 2 deletions(-) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 5448fe7c77..d9161f1db4 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -4824,9 +4824,9 @@ class StandardLoggingPayloadSetup: metadata = litellm_params.get("metadata") or {} litellm_metadata = litellm_params.get("litellm_metadata") or {} if metadata.get("tags", []): - request_tags = metadata.get("tags", []) + request_tags = metadata.get("tags", []).copy() elif litellm_metadata.get("tags", []): - request_tags = litellm_metadata.get("tags", []) + request_tags = litellm_metadata.get("tags", []).copy() else: request_tags = [] user_agent_tags = StandardLoggingPayloadSetup._get_user_agent_tags( diff --git a/tests/test_litellm/litellm_core_utils/test_litellm_logging.py b/tests/test_litellm/litellm_core_utils/test_litellm_logging.py index 9b150fd89f..bae0e5bbb4 100644 --- a/tests/test_litellm/litellm_core_utils/test_litellm_logging.py +++ b/tests/test_litellm/litellm_core_utils/test_litellm_logging.py @@ -393,6 +393,63 @@ def test_get_request_tags_from_metadata_and_litellm_metadata(): assert "User-Agent: litellm/1.0.0" in tags +def test_get_request_tags_does_not_mutate_original_tags(): + """ + Test that _get_request_tags does not mutate the original tags list in metadata. + + This is a regression test for a bug where calling _get_request_tags multiple times + would cause User-Agent tags to be duplicated because the function was mutating + the original tags list instead of creating a copy. + """ + from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup + + # Create metadata with original tags + original_tags = ["custom-tag-1", "custom-tag-2"] + metadata = {"tags": original_tags} + litellm_params = {"metadata": metadata} + proxy_server_request = { + "headers": { + "user-agent": "AsyncOpenAI/Python 1.99.9", + } + } + + # Call _get_request_tags multiple times (simulating multiple callbacks) + tags1 = StandardLoggingPayloadSetup._get_request_tags( + litellm_params=litellm_params, + proxy_server_request=proxy_server_request, + ) + tags2 = StandardLoggingPayloadSetup._get_request_tags( + litellm_params=litellm_params, + proxy_server_request=proxy_server_request, + ) + tags3 = StandardLoggingPayloadSetup._get_request_tags( + litellm_params=litellm_params, + proxy_server_request=proxy_server_request, + ) + + # Verify the original tags list was NOT mutated + assert original_tags == ["custom-tag-1", "custom-tag-2"], ( + f"Original tags list was mutated: {original_tags}" + ) + assert metadata["tags"] == ["custom-tag-1", "custom-tag-2"], ( + f"metadata['tags'] was mutated: {metadata['tags']}" + ) + + # Verify each returned list has exactly 2 User-Agent tags (not duplicated) + user_agent_count_1 = len([t for t in tags1 if t.startswith("User-Agent:")]) + user_agent_count_2 = len([t for t in tags2 if t.startswith("User-Agent:")]) + user_agent_count_3 = len([t for t in tags3 if t.startswith("User-Agent:")]) + + assert user_agent_count_1 == 2, f"Expected 2 User-Agent tags, got {user_agent_count_1}" + assert user_agent_count_2 == 2, f"Expected 2 User-Agent tags, got {user_agent_count_2}" + assert user_agent_count_3 == 2, f"Expected 2 User-Agent tags, got {user_agent_count_3}" + + # Verify all returned lists are independent (different objects) + assert tags1 is not tags2 + assert tags2 is not tags3 + assert tags1 is not original_tags + + def test_get_extra_header_tags(): """Test the _get_extra_header_tags method with various scenarios.""" import litellm From 8ce4eea88f977a5c0e31358babed40c468885f67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Wilson=20Marc=C3=ADlio=20J=C3=BAnior?= Date: Wed, 7 Jan 2026 13:13:55 -0300 Subject: [PATCH 6/9] make base_connection_pool_limit default value the same (#18721) --- litellm/proxy/_types.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index cff3bee1ca..bb6b56d509 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1937,7 +1937,7 @@ class ConfigGeneralSettings(LiteLLMPydanticObjectBase): description="connect to a postgres db - needed for generating temporary keys + tracking spend / key", ) database_connection_pool_limit: Optional[int] = Field( - 100, + 10, description="default connection pool for prisma client connecting to postgres db", ) database_connection_timeout: Optional[float] = Field( From dc4ce7c5a21615504b739d445d9c3ffd35aaeb38 Mon Sep 17 00:00:00 2001 From: Abliteration AI Date: Wed, 7 Jan 2026 08:16:54 -0800 Subject: [PATCH 7/9] feat: Add abliteration.ai provider (#18678) * feat: Add abliteration.ai provider * adding signoz integration to observability docs * Fixing build * Adding timeout for flaky test * Fixing e2e * add team member budget duration in team/update * Reusable Duration Select and update team member budget UI --------- Co-authored-by: Goutham Karthi Co-authored-by: yuneng-jiang Co-authored-by: YutaSaito <36355491+uc4w6c@users.noreply.github.com> --- README.md | 2 +- .../my-website/docs/providers/abliteration.md | 109 ++++++++++++++++++ docs/my-website/sidebars.js | 13 ++- litellm/llms/openai_like/providers.json | 4 + provider_endpoints_support.json | 17 +++ .../openai_like/test_abliteration_provider.py | 50 ++++++++ 6 files changed, 188 insertions(+), 7 deletions(-) create mode 100644 docs/my-website/docs/providers/abliteration.md create mode 100644 tests/litellm/llms/openai_like/test_abliteration_provider.py diff --git a/README.md b/README.md index a020bd8089..75a23faa5c 100644 --- a/README.md +++ b/README.md @@ -262,6 +262,7 @@ Support for more providers. Missing a provider or LLM Platform, raise a [feature | Provider | `/chat/completions` | `/messages` | `/responses` | `/embeddings` | `/image/generations` | `/audio/transcriptions` | `/audio/speech` | `/moderations` | `/batches` | `/rerank` | |-------------------------------------------------------------------------------------|---------------------|-------------|--------------|---------------|----------------------|-------------------------|-----------------|----------------|-----------|-----------| +| [Abliteration (`abliteration`)](https://docs.litellm.ai/docs/providers/abliteration) | βœ… | | | | | | | | | | | [AI/ML API (`aiml`)](https://docs.litellm.ai/docs/providers/aiml) | βœ… | βœ… | βœ… | βœ… | βœ… | | | | | | | [AI21 (`ai21`)](https://docs.litellm.ai/docs/providers/ai21) | βœ… | βœ… | βœ… | | | | | | | | | [AI21 Chat (`ai21_chat`)](https://docs.litellm.ai/docs/providers/ai21) | βœ… | βœ… | βœ… | | | | | | | | @@ -455,4 +456,3 @@ All these checks must pass before your PR can be merged. - diff --git a/docs/my-website/docs/providers/abliteration.md b/docs/my-website/docs/providers/abliteration.md new file mode 100644 index 0000000000..a0fc7f3931 --- /dev/null +++ b/docs/my-website/docs/providers/abliteration.md @@ -0,0 +1,109 @@ +# Abliteration + +## Overview + +| Property | Details | +|-------|-------| +| Description | Abliteration provides an OpenAI-compatible `/chat/completions` endpoint. | +| Provider Route on LiteLLM | `abliteration/` | +| Link to Provider Doc | [Abliteration](https://abliteration.ai) | +| Base URL | `https://api.abliteration.ai/v1` | +| Supported Operations | [`/chat/completions`](#sample-usage) | + +
+ +## Required Variables + +```python showLineNumbers title="Environment Variables" +os.environ["ABLITERATION_API_KEY"] = "" # your Abliteration API key +``` + +## Sample Usage + +```python showLineNumbers title="Abliteration Completion" +import os +from litellm import completion + +os.environ["ABLITERATION_API_KEY"] = "" + +response = completion( + model="abliteration/abliterated-model", + messages=[{"role": "user", "content": "Hello from LiteLLM"}], +) + +print(response) +``` + +## Sample Usage - Streaming + +```python showLineNumbers title="Abliteration Streaming Completion" +import os +from litellm import completion + +os.environ["ABLITERATION_API_KEY"] = "" + +response = completion( + model="abliteration/abliterated-model", + messages=[{"role": "user", "content": "Stream a short reply"}], + stream=True, +) + +for chunk in response: + print(chunk) +``` + +## Usage with LiteLLM Proxy Server + +1. Add the model to your proxy config: + +```yaml showLineNumbers title="config.yaml" +model_list: + - model_name: abliteration-chat + litellm_params: + model: abliteration/abliterated-model + api_key: os.environ/ABLITERATION_API_KEY +``` + +2. Start the proxy: + +```bash +litellm --config /path/to/config.yaml +``` + +## Direct API Usage (Bearer Token) + +Use the environment variable as a Bearer token against the OpenAI-compatible endpoint: +`https://api.abliteration.ai/v1/chat/completions`. + +```bash showLineNumbers title="cURL" +export ABLITERATION_API_KEY="" +curl https://api.abliteration.ai/v1/chat/completions \ + -H "Authorization: Bearer ${ABLITERATION_API_KEY}" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "abliterated-model", + "messages": [{"role": "user", "content": "Hello from Abliteration"}] + }' +``` + +```python showLineNumbers title="Python (requests)" +import os +import requests + +api_key = os.environ["ABLITERATION_API_KEY"] + +response = requests.post( + "https://api.abliteration.ai/v1/chat/completions", + headers={ + "Authorization": f"Bearer {api_key}", + "Content-Type": "application/json", + }, + json={ + "model": "abliterated-model", + "messages": [{"role": "user", "content": "Hello from Abliteration"}], + }, + timeout=60, +) + +print(response.json()) +``` diff --git a/docs/my-website/sidebars.js b/docs/my-website/sidebars.js index 004132ca05..4c50449afc 100644 --- a/docs/my-website/sidebars.js +++ b/docs/my-website/sidebars.js @@ -683,12 +683,13 @@ const sidebars = { "providers/bedrock_writer", "providers/bedrock_batches", "providers/aws_polly", - "providers/bedrock_vector_store", - ] - }, - "providers/litellm_proxy", - "providers/ai21", - "providers/aiml", + "providers/bedrock_vector_store", + ] + }, + "providers/litellm_proxy", + "providers/abliteration", + "providers/ai21", + "providers/aiml", "providers/aleph_alpha", "providers/amazon_nova", "providers/anyscale", diff --git a/litellm/llms/openai_like/providers.json b/litellm/llms/openai_like/providers.json index 206aee1359..bda3684a8a 100644 --- a/litellm/llms/openai_like/providers.json +++ b/litellm/llms/openai_like/providers.json @@ -61,6 +61,10 @@ "max_completion_tokens": "max_tokens" } }, + "abliteration": { + "base_url": "https://api.abliteration.ai/v1", + "api_key_env": "ABLITERATION_API_KEY" + }, "llamagate": { "base_url": "https://api.llamagate.dev/v1", "api_key_env": "LLAMAGATE_API_KEY", diff --git a/provider_endpoints_support.json b/provider_endpoints_support.json index bc5dea7b97..2521d71b5c 100644 --- a/provider_endpoints_support.json +++ b/provider_endpoints_support.json @@ -34,6 +34,23 @@ } }, "providers": { + "abliteration": { + "display_name": "Abliteration (`abliteration`)", + "url": "https://docs.litellm.ai/docs/providers/abliteration", + "endpoints": { + "chat_completions": true, + "messages": false, + "responses": false, + "embeddings": false, + "image_generations": false, + "audio_transcriptions": false, + "audio_speech": false, + "moderations": false, + "batches": false, + "rerank": false, + "a2a": false + } + }, "aiml": { "display_name": "AI/ML API (`aiml`)", "url": "https://docs.litellm.ai/docs/providers/aiml", diff --git a/tests/litellm/llms/openai_like/test_abliteration_provider.py b/tests/litellm/llms/openai_like/test_abliteration_provider.py new file mode 100644 index 0000000000..8b8d443fc4 --- /dev/null +++ b/tests/litellm/llms/openai_like/test_abliteration_provider.py @@ -0,0 +1,50 @@ +""" +Unit tests for the Abliteration OpenAI-like provider. +""" + +import os +import sys + +sys.path.insert( + 0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../..")) +) + +from litellm.llms.openai_like.dynamic_config import create_config_class +from litellm.llms.openai_like.json_loader import JSONProviderRegistry + +ABLITERATION_BASE_URL = "https://api.abliteration.ai/v1" + + +def _get_config(): + provider = JSONProviderRegistry.get("abliteration") + assert provider is not None + config_class = create_config_class(provider) + return config_class() + + +def test_abliteration_provider_registered(): + provider = JSONProviderRegistry.get("abliteration") + assert provider is not None + assert provider.base_url == ABLITERATION_BASE_URL + assert provider.api_key_env == "ABLITERATION_API_KEY" + + +def test_abliteration_resolves_env_api_key(monkeypatch): + config = _get_config() + monkeypatch.setenv("ABLITERATION_API_KEY", "test-key") + api_base, api_key = config._get_openai_compatible_provider_info(None, None) + assert api_base == ABLITERATION_BASE_URL + assert api_key == "test-key" + + +def test_abliteration_complete_url_appends_endpoint(): + config = _get_config() + url = config.get_complete_url( + api_base=ABLITERATION_BASE_URL, + api_key="test-key", + model="abliteration/abliterated-model", + optional_params={}, + litellm_params={}, + stream=False, + ) + assert url == f"{ABLITERATION_BASE_URL}/chat/completions" From 5a4242e1882517cf850849721897bf39f46e2cf9 Mon Sep 17 00:00:00 2001 From: Wen-Tien Chang Date: Thu, 8 Jan 2026 03:25:59 +0800 Subject: [PATCH 8/9] fix(braintrust): pass span_attributes in async logging and skip tags on non-root spans (#18409) * fix(braintrust): handle tags and span attributes for non-root spans * fix(braintrust): refactor span attributes handling and remove duplicate metrics assignment --- litellm/integrations/braintrust_logging.py | 35 ++++++++++++++++++---- 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/litellm/integrations/braintrust_logging.py b/litellm/integrations/braintrust_logging.py index 364fa3f5de..585de510e8 100644 --- a/litellm/integrations/braintrust_logging.py +++ b/litellm/integrations/braintrust_logging.py @@ -225,10 +225,13 @@ class BraintrustLogger(CustomLogger): "id": litellm_call_id, "input": prompt["messages"], "metadata": standard_logging_object, - "tags": tags, "span_attributes": {"name": span_name, "type": "llm"}, } - + + # Braintrust cannot specify 'tags' for non-root spans + if dynamic_metadata.get("root_span_id") is None: + request_data["tags"] = tags + # Only add those that are not None (or falsy) for key, value in span_attributes.items(): if value: @@ -351,14 +354,37 @@ class BraintrustLogger(CustomLogger): # Allow metadata override for span name span_name = dynamic_metadata.get("span_name", "Chat Completion") + # Span parents is a special case + span_parents = dynamic_metadata.get("span_parents") + + # Convert comma-separated string to list if present + if span_parents: + span_parents = [s.strip() for s in span_parents.split(",") if s.strip()] + + # Add optional span attributes only if present + span_attributes = { + "span_id": dynamic_metadata.get("span_id"), + "root_span_id": dynamic_metadata.get("root_span_id"), + "span_parents": span_parents, + } + request_data = { "id": litellm_call_id, "input": prompt["messages"], "output": output, "metadata": standard_logging_object, - "tags": tags, "span_attributes": {"name": span_name, "type": "llm"}, } + + # Braintrust cannot specify 'tags' for non-root spans + if dynamic_metadata.get("root_span_id") is None: + request_data["tags"] = tags + + # Only add those that are not None (or falsy) + for key, value in span_attributes.items(): + if value: + request_data[key] = value + if choices is not None: request_data["output"] = [choice.dict() for choice in choices] else: @@ -367,9 +393,6 @@ class BraintrustLogger(CustomLogger): if metrics is not None: request_data["metrics"] = metrics - if metrics is not None: - request_data["metrics"] = metrics - try: await self.global_braintrust_http_handler.post( url=f"{self.api_base}/project_logs/{project_id}/insert", From bae625bdc6e9f9608c48bc5251134ff2515cb84e Mon Sep 17 00:00:00 2001 From: Elkhan Eminov Date: Wed, 7 Jan 2026 19:27:31 +0000 Subject: [PATCH 9/9] OpenRouter embeddings API support (#18391) * support for OpenRouter embeddings * add bearer * add content header --- .../openrouter/embedding/transformation.py | 182 ++++++++++++++++++ litellm/main.py | 45 +++++ litellm/utils.py | 5 + provider_endpoints_support.json | 2 +- tests/llm_translation/test_openrouter.py | 15 ++ ...est_openrouter_embedding_transformation.py | 132 +++++++++++++ 6 files changed, 380 insertions(+), 1 deletion(-) create mode 100644 litellm/llms/openrouter/embedding/transformation.py create mode 100644 tests/test_litellm/llms/openrouter/test_openrouter_embedding_transformation.py diff --git a/litellm/llms/openrouter/embedding/transformation.py b/litellm/llms/openrouter/embedding/transformation.py new file mode 100644 index 0000000000..d1d0e911d1 --- /dev/null +++ b/litellm/llms/openrouter/embedding/transformation.py @@ -0,0 +1,182 @@ +""" +OpenRouter Embedding API Configuration. + +This module provides the configuration for OpenRouter's Embedding API. +OpenRouter is OpenAI-compatible and supports embeddings via the /v1/embeddings endpoint. + +Docs: https://openrouter.ai/docs +""" +from typing import TYPE_CHECKING, Any, Optional + +import httpx + +from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig +from litellm.types.llms.openai import AllEmbeddingInputValues +from litellm.types.utils import EmbeddingResponse +from litellm.utils import convert_to_model_response_object + +from ..common_utils import OpenRouterException + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj + + LiteLLMLoggingObj = _LiteLLMLoggingObj +else: + LiteLLMLoggingObj = Any + + +class OpenrouterEmbeddingConfig(BaseEmbeddingConfig): + """ + Configuration for OpenRouter's Embedding API. + + Reference: https://openrouter.ai/docs + """ + + def validate_environment( + self, + headers: dict, + model: str, + messages: list, + optional_params: dict, + litellm_params: dict, + api_key: Optional[str] = None, + api_base: Optional[str] = None, + ) -> dict: + """ + Validate environment and set up headers for OpenRouter API. + + OpenRouter requires: + - Authorization header with Bearer token + - HTTP-Referer header (site URL) + - X-Title header (app name) + """ + from litellm import get_secret + + # Get OpenRouter-specific headers + openrouter_site_url = get_secret("OR_SITE_URL") or "https://litellm.ai" + openrouter_app_name = get_secret("OR_APP_NAME") or "liteLLM" + + openrouter_headers = { + "HTTP-Referer": openrouter_site_url, + "X-Title": openrouter_app_name, + "Content-Type": "application/json", + } + + # Add Authorization header if api_key is provided + if api_key: + openrouter_headers["Authorization"] = f"Bearer {api_key}" + + # Merge with existing headers (user's extra_headers take priority) + merged_headers = {**openrouter_headers, **headers} + + return merged_headers + + def get_complete_url( + self, + api_base: Optional[str], + api_key: Optional[str], + model: str, + optional_params: dict, + litellm_params: dict, + stream: Optional[bool] = None, + ) -> str: + """ + Get the complete URL for OpenRouter Embedding API endpoint. + """ + # api_base is already set to https://openrouter.ai/api/v1 in main.py + # Remove trailing slashes + if api_base: + api_base = api_base.rstrip("/") + else: + api_base = "https://openrouter.ai/api/v1" + + # Return the embeddings endpoint + return f"{api_base}/embeddings" + + def transform_embedding_request( + self, + model: str, + input: AllEmbeddingInputValues, + optional_params: dict, + headers: dict, + ) -> dict: + """ + Transform embedding request to OpenRouter format (OpenAI-compatible). + """ + # Ensure input is a list + if isinstance(input, str): + input = [input] + + # OpenRouter expects the full model name (e.g., google/gemini-embedding-001) + # Strip 'openrouter/' prefix if present + if model.startswith("openrouter/"): + model = model.replace("openrouter/", "", 1) + + return { + "model": model, + "input": input, + **optional_params, + } + + def transform_embedding_response( + self, + model: str, + raw_response: httpx.Response, + model_response: EmbeddingResponse, + logging_obj: LiteLLMLoggingObj, + api_key: Optional[str], + request_data: dict, + optional_params: dict, + litellm_params: dict, + ) -> EmbeddingResponse: + """ + Transform embedding response from OpenRouter format (OpenAI-compatible). + """ + logging_obj.post_call(original_response=raw_response.text) + + # OpenRouter returns standard OpenAI-compatible embedding response + response_json = raw_response.json() + + return convert_to_model_response_object( + response_object=response_json, + model_response_object=model_response, + response_type="embedding", + ) + + def get_supported_openai_params(self, model: str) -> list: + """ + Get list of supported OpenAI parameters for OpenRouter embeddings. + """ + return [ + "timeout", + "dimensions", + "encoding_format", + "user", + ] + + def map_openai_params( + self, + non_default_params: dict, + optional_params: dict, + model: str, + drop_params: bool, + ) -> dict: + """ + Map OpenAI parameters to OpenRouter format. + """ + for param, value in non_default_params.items(): + if param in self.get_supported_openai_params(model): + optional_params[param] = value + return optional_params + + def get_error_class( + self, error_message: str, status_code: int, headers: Any + ) -> Any: + """ + Get the error class for OpenRouter errors. + """ + return OpenRouterException( + message=error_message, + status_code=status_code, + headers=headers, + ) diff --git a/litellm/main.py b/litellm/main.py index e8a8b504d9..6905d8f6a8 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -4701,6 +4701,51 @@ def embedding( # noqa: PLR0915 litellm_params=litellm_params_dict, headers=headers, ) + elif custom_llm_provider == "openrouter": + api_base = ( + api_base + or litellm.api_base + or get_secret_str("OPENROUTER_API_BASE") + or "https://openrouter.ai/api/v1" + ) + + api_key = ( + api_key + or litellm.api_key + or litellm.openrouter_key + or get_secret("OPENROUTER_API_KEY") + or get_secret("OR_API_KEY") + ) + + openrouter_site_url = get_secret("OR_SITE_URL") or "https://litellm.ai" + openrouter_app_name = get_secret("OR_APP_NAME") or "liteLLM" + + openrouter_headers = { + "HTTP-Referer": openrouter_site_url, + "X-Title": openrouter_app_name, + } + + _headers = headers or litellm.headers + if _headers: + openrouter_headers.update(_headers) + + headers = openrouter_headers + + response = base_llm_http_handler.embedding( + model=model, + input=input, + custom_llm_provider=custom_llm_provider, + api_base=api_base, + api_key=api_key, + logging_obj=logging, + timeout=timeout, + model_response=EmbeddingResponse(), + optional_params=optional_params, + client=client, + aembedding=aembedding, + litellm_params=litellm_params_dict, + headers=headers, + ) elif custom_llm_provider == "huggingface": api_key = ( api_key diff --git a/litellm/utils.py b/litellm/utils.py index fbbaa94f7a..fb01337bc5 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -7718,6 +7718,11 @@ class ProviderConfigManager: return litellm.CometAPIEmbeddingConfig() elif litellm.LlmProviders.GITHUB_COPILOT == provider: return litellm.GithubCopilotEmbeddingConfig() + elif litellm.LlmProviders.OPENROUTER == provider: + from litellm.llms.openrouter.embedding.transformation import ( + OpenrouterEmbeddingConfig, + ) + return OpenrouterEmbeddingConfig() elif litellm.LlmProviders.GIGACHAT == provider: return litellm.GigaChatEmbeddingConfig() elif litellm.LlmProviders.SAGEMAKER == provider: diff --git a/provider_endpoints_support.json b/provider_endpoints_support.json index 2521d71b5c..a29b66f6c5 100644 --- a/provider_endpoints_support.json +++ b/provider_endpoints_support.json @@ -1565,7 +1565,7 @@ "chat_completions": true, "messages": true, "responses": true, - "embeddings": false, + "embeddings": true, "image_generations": false, "audio_transcriptions": false, "audio_speech": false, diff --git a/tests/llm_translation/test_openrouter.py b/tests/llm_translation/test_openrouter.py index 839d08e12b..105b05d344 100644 --- a/tests/llm_translation/test_openrouter.py +++ b/tests/llm_translation/test_openrouter.py @@ -32,3 +32,18 @@ def test_completion_openrouter_image_generation(): .message.images[0]["image_url"]["url"] .startswith("data:image/png;base64,") ) + + +def test_openrouter_embedding(): + """Test OpenRouter embeddings support.""" + litellm._turn_on_debug() + resp = litellm.embedding( + model="openrouter/openai/text-embedding-3-small", + input=["Hello world", "How are you?"], + ) + print(resp) + assert resp is not None + assert len(resp.data) == 2 + assert resp.data[0]["embedding"] is not None + assert isinstance(resp.data[0]["embedding"], list) + assert len(resp.data[0]["embedding"]) > 0 diff --git a/tests/test_litellm/llms/openrouter/test_openrouter_embedding_transformation.py b/tests/test_litellm/llms/openrouter/test_openrouter_embedding_transformation.py new file mode 100644 index 0000000000..714adc346d --- /dev/null +++ b/tests/test_litellm/llms/openrouter/test_openrouter_embedding_transformation.py @@ -0,0 +1,132 @@ +""" +Unit tests for OpenRouter embedding transformation logic. +""" +from litellm.llms.openrouter.embedding.transformation import ( + OpenrouterEmbeddingConfig, +) + + +def test_openrouter_embedding_supported_params(): + """Test that supported OpenAI params are correctly defined.""" + config = OpenrouterEmbeddingConfig() + supported = config.get_supported_openai_params("test-model") + + assert "timeout" in supported + assert "dimensions" in supported + assert "encoding_format" in supported + assert "user" in supported + + +def test_openrouter_embedding_transform_request(): + """Test request transformation logic.""" + config = OpenrouterEmbeddingConfig() + + # Test with string input + result = config.transform_embedding_request( + model="openrouter/google/text-embedding-004", + input="Hello world", + optional_params={}, + headers={}, + ) + + assert result["model"] == "google/text-embedding-004" + assert result["input"] == ["Hello world"] + + # Test with list input + result = config.transform_embedding_request( + model="google/text-embedding-004", + input=["Hello", "World"], + optional_params={"dimensions": 512}, + headers={}, + ) + + assert result["model"] == "google/text-embedding-004" + assert result["input"] == ["Hello", "World"] + assert result["dimensions"] == 512 + + +def test_openrouter_embedding_validate_environment(): + """Test environment validation and header setup.""" + config = OpenrouterEmbeddingConfig() + + # Test with API key + headers = config.validate_environment( + headers={"Custom-Header": "value"}, + model="test-model", + messages=[], + optional_params={}, + litellm_params={}, + api_key="test-api-key", + ) + + # Should include OpenRouter-specific headers + assert "HTTP-Referer" in headers + assert "X-Title" in headers + # Should include Content-Type header + assert "Content-Type" in headers + assert headers["Content-Type"] == "application/json" + # Should include Authorization header + assert "Authorization" in headers + assert headers["Authorization"] == "Bearer test-api-key" + # Should preserve custom headers + assert headers["Custom-Header"] == "value" + + # Test without API key + headers_no_key = config.validate_environment( + headers={}, + model="test-model", + messages=[], + optional_params={}, + litellm_params={}, + api_key=None, + ) + + # Should still include OpenRouter headers but not Authorization + assert "HTTP-Referer" in headers_no_key + assert "X-Title" in headers_no_key + assert "Content-Type" in headers_no_key + assert "Authorization" not in headers_no_key + + +def test_openrouter_embedding_get_complete_url(): + """Test URL construction.""" + config = OpenrouterEmbeddingConfig() + + url = config.get_complete_url( + api_base="https://openrouter.ai/api/v1", + api_key="test-key", + model="test-model", + optional_params={}, + litellm_params={}, + ) + + assert url == "https://openrouter.ai/api/v1/embeddings" + + # Test with trailing slash + url = config.get_complete_url( + api_base="https://openrouter.ai/api/v1/", + api_key="test-key", + model="test-model", + optional_params={}, + litellm_params={}, + ) + + assert url == "https://openrouter.ai/api/v1/embeddings" + + +def test_openrouter_embedding_map_params(): + """Test parameter mapping.""" + config = OpenrouterEmbeddingConfig() + + result = config.map_openai_params( + non_default_params={"dimensions": 512, "timeout": 30, "unsupported": "value"}, + optional_params={}, + model="test-model", + drop_params=False, + ) + + # Supported params should be included + assert result["dimensions"] == 512 + assert result["timeout"] == 30 + # Unsupported params should not be included + assert "unsupported" not in result