Merge pull request #18763 from BerriAI/litellm_staging_01_07_2026

Staging - 01/07/2026
Sameer Kankute 2026-01-09 17:01:58 +05:30 committed by GitHub
commit 844c766c65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
26 changed files with 1657 additions and 333 deletions

View File

@ -262,6 +262,7 @@ Support for more providers. Missing a provider or LLM Platform, raise a [feature
| Provider | `/chat/completions` | `/messages` | `/responses` | `/embeddings` | `/image/generations` | `/audio/transcriptions` | `/audio/speech` | `/moderations` | `/batches` | `/rerank` |
|-------------------------------------------------------------------------------------|---------------------|-------------|--------------|---------------|----------------------|-------------------------|-----------------|----------------|-----------|-----------|
| [Abliteration (`abliteration`)](https://docs.litellm.ai/docs/providers/abliteration) | ✅ | | | | | | | | | |
| [AI/ML API (`aiml`)](https://docs.litellm.ai/docs/providers/aiml) | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
| [AI21 (`ai21`)](https://docs.litellm.ai/docs/providers/ai21) | ✅ | ✅ | ✅ | | | | | | | |
| [AI21 Chat (`ai21_chat`)](https://docs.litellm.ai/docs/providers/ai21) | ✅ | ✅ | ✅ | | | | | | | |
@ -455,4 +456,3 @@ All these checks must pass before your PR can be merged.
<img src="https://contrib.rocks/image?repo=BerriAI/litellm" />
</a>

View File

@ -0,0 +1,109 @@
# Abliteration
## Overview
| Property | Details |
|-------|-------|
| Description | Abliteration provides an OpenAI-compatible `/chat/completions` endpoint. |
| Provider Route on LiteLLM | `abliteration/` |
| Link to Provider Doc | [Abliteration](https://abliteration.ai) |
| Base URL | `https://api.abliteration.ai/v1` |
| Supported Operations | [`/chat/completions`](#sample-usage) |
<br />
## Required Variables
```python showLineNumbers title="Environment Variables"
os.environ["ABLITERATION_API_KEY"] = "" # your Abliteration API key
```
## Sample Usage
```python showLineNumbers title="Abliteration Completion"
import os
from litellm import completion
os.environ["ABLITERATION_API_KEY"] = ""
response = completion(
model="abliteration/abliterated-model",
messages=[{"role": "user", "content": "Hello from LiteLLM"}],
)
print(response)
```
## Sample Usage - Streaming
```python showLineNumbers title="Abliteration Streaming Completion"
import os
from litellm import completion
os.environ["ABLITERATION_API_KEY"] = ""
response = completion(
model="abliteration/abliterated-model",
messages=[{"role": "user", "content": "Stream a short reply"}],
stream=True,
)
for chunk in response:
print(chunk)
```
## Usage with LiteLLM Proxy Server
1. Add the model to your proxy config:
```yaml showLineNumbers title="config.yaml"
model_list:
- model_name: abliteration-chat
litellm_params:
model: abliteration/abliterated-model
api_key: os.environ/ABLITERATION_API_KEY
```
2. Start the proxy:
```bash
litellm --config /path/to/config.yaml
```
## Direct API Usage (Bearer Token)
Send the API key from the `ABLITERATION_API_KEY` environment variable as a Bearer token to the OpenAI-compatible endpoint:
`https://api.abliteration.ai/v1/chat/completions`.
```bash showLineNumbers title="cURL"
export ABLITERATION_API_KEY=""
curl https://api.abliteration.ai/v1/chat/completions \
-H "Authorization: Bearer ${ABLITERATION_API_KEY}" \
-H "Content-Type: application/json" \
-d '{
"model": "abliterated-model",
"messages": [{"role": "user", "content": "Hello from Abliteration"}]
}'
```
```python showLineNumbers title="Python (requests)"
import os
import requests
api_key = os.environ["ABLITERATION_API_KEY"]
response = requests.post(
"https://api.abliteration.ai/v1/chat/completions",
headers={
"Authorization": f"Bearer {api_key}",
"Content-Type": "application/json",
},
json={
"model": "abliterated-model",
"messages": [{"role": "user", "content": "Hello from Abliteration"}],
},
timeout=60,
)
print(response.json())
```
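Because the endpoint is described as OpenAI-compatible, the JSON body returned by the `requests` call would normally follow the chat-completions shape. The following is a minimal sketch of pulling the reply text out of such a body, assuming that shape holds. `extract_reply` is a hypothetical helper written for illustration, not part of LiteLLM or Abliteration, and the sample payload is hand-written rather than a real API response.

```python
from typing import Any, Dict, Optional


def extract_reply(payload: Dict[str, Any]) -> Optional[str]:
    """Return the first choice's message content, or None if the shape differs."""
    choices = payload.get("choices")
    if not isinstance(choices, list) or not choices:
        return None
    first = choices[0]
    if not isinstance(first, dict):
        return None
    message = first.get("message")
    if not isinstance(message, dict):
        return None
    content = message.get("content")
    return content if isinstance(content, str) else None


# Hand-written sample payload for illustration; not a real API response.
sample = {"choices": [{"message": {"role": "assistant", "content": "Hi there"}}]}
print(extract_reply(sample))
```

Returning `None` instead of raising keeps the caller in control when the provider sends an error body or an unexpected shape.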

View File

@ -8,13 +8,7 @@ Use [Qualifire](https://qualifire.ai) to evaluate LLM outputs for quality, safet
## Quick Start
### 1. Install the Qualifire SDK
```bash
pip install qualifire
```
### 2. Define Guardrails on your LiteLLM config.yaml
### 1. Define Guardrails on your LiteLLM config.yaml
Define your guardrails under the `guardrails` section:
@ -61,13 +55,13 @@ guardrails:
- `post_call` Run **after** LLM call, on **input & output**
- `during_call` Run **during** LLM call, on **input**. Same as `pre_call` but runs in parallel with the LLM call. The response is not returned until the guardrail check completes.
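The `during_call` mode described in the list above can be pictured with a small `asyncio` sketch. This is only an illustration of the timing, not LiteLLM's implementation: `call_llm`, `check_input`, and `during_call` are hypothetical stand-ins, and the "forbidden" rule is made up for the example.

```python
import asyncio


async def call_llm(prompt: str) -> str:
    # Stand-in for the real LLM request.
    await asyncio.sleep(0.01)
    return f"echo: {prompt}"


async def check_input(prompt: str) -> bool:
    # Stand-in for the guardrail; returns True when the input is allowed.
    await asyncio.sleep(0.01)
    return "forbidden" not in prompt


async def during_call(prompt: str) -> str:
    # Start both at once; the response is released only after the check finishes.
    response, allowed = await asyncio.gather(call_llm(prompt), check_input(prompt))
    if not allowed:
        raise ValueError("blocked by guardrail")
    return response


print(asyncio.run(during_call("hello")))
```

Compared with `pre_call`, the guardrail latency overlaps with the LLM latency, at the cost of possibly paying for an LLM call whose response is then discarded.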
### 3. Start LiteLLM Gateway
### 2. Start LiteLLM Gateway
```shell
litellm --config config.yaml --detailed_debug
```
### 4. Test request
### 3. Test request
**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**
@ -142,7 +136,7 @@ guardrails:
evaluation_id: eval_abc123 # Your evaluation ID from Qualifire dashboard
```
When `evaluation_id` is provided, LiteLLM will use `invoke_evaluation()` instead of `evaluate()`, running the pre-configured evaluation from your dashboard.
When `evaluation_id` is provided, LiteLLM will use the invoke evaluation API endpoint instead of the evaluate endpoint, running the pre-configured evaluation from your dashboard.
## Available Checks
@ -213,19 +207,19 @@ guardrails:
### Parameter Reference
| Parameter | Type | Default | Description |
| ------------------------------ | ----------- | --------------------------- | -------------------------------------------------------- |
| `api_key` | `str` | `QUALIFIRE_API_KEY` env var | Your Qualifire API key |
| `api_base` | `str` | `None` | Custom API base URL (optional) |
| `evaluation_id` | `str` | `None` | Pre-configured evaluation ID from Qualifire dashboard |
| `prompt_injections` | `bool` | `true` (if no other checks) | Enable prompt injection detection |
| `hallucinations_check` | `bool` | `None` | Enable hallucination detection |
| `grounding_check` | `bool` | `None` | Enable grounding verification |
| `pii_check` | `bool` | `None` | Enable PII detection |
| `content_moderation_check` | `bool` | `None` | Enable content moderation |
| `tool_selection_quality_check` | `bool` | `None` | Enable tool selection quality check |
| `assertions` | `List[str]` | `None` | Custom assertions to validate |
| `on_flagged` | `str` | `"block"` | Action when content is flagged: `"block"` or `"monitor"` |
| Parameter | Type | Default | Description |
| ------------------------------ | ----------- | ---------------------------- | -------------------------------------------------------- |
| `api_key` | `str` | `QUALIFIRE_API_KEY` env var | Your Qualifire API key |
| `api_base` | `str` | `https://proxy.qualifire.ai` | Custom API base URL (optional) |
| `evaluation_id` | `str` | `None` | Pre-configured evaluation ID from Qualifire dashboard |
| `prompt_injections` | `bool` | `true` (if no other checks) | Enable prompt injection detection |
| `hallucinations_check` | `bool` | `None` | Enable hallucination detection |
| `grounding_check` | `bool` | `None` | Enable grounding verification |
| `pii_check` | `bool` | `None` | Enable PII detection |
| `content_moderation_check` | `bool` | `None` | Enable content moderation |
| `tool_selection_quality_check` | `bool` | `None` | Enable tool selection quality check |
| `assertions` | `List[str]` | `None` | Custom assertions to validate |
| `on_flagged` | `str` | `"block"` | Action when content is flagged: `"block"` or `"monitor"` |
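The defaults in the parameter table can be read as a small resolution step. The sketch below is an assumption-laden illustration, not the guardrail's code: `resolve_settings` and `CHECK_NAMES` are hypothetical names, and treating `assertions` as a check that suppresses the `prompt_injections` default is a guess.

```python
import os
from typing import Any, Dict, Mapping, Optional

DEFAULT_API_BASE = "https://proxy.qualifire.ai"  # default listed in the table

# Boolean check parameters from the table (grouping them like this is an assumption).
CHECK_NAMES = (
    "prompt_injections",
    "hallucinations_check",
    "grounding_check",
    "pii_check",
    "content_moderation_check",
    "tool_selection_quality_check",
)


def resolve_settings(
    params: Dict[str, Any], env: Optional[Mapping[str, str]] = None
) -> Dict[str, Any]:
    """Fill in table defaults for any parameter the config left unset."""
    env = os.environ if env is None else env
    resolved = dict(params)
    resolved["api_key"] = params.get("api_key") or env.get("QUALIFIRE_API_KEY")
    resolved["api_base"] = params.get("api_base") or DEFAULT_API_BASE
    resolved.setdefault("on_flagged", "block")
    # Assumption: a non-empty `assertions` list counts as an enabled check.
    any_check = any(params.get(name) for name in CHECK_NAMES) or bool(
        params.get("assertions")
    )
    if not any_check and not params.get("evaluation_id"):
        resolved["prompt_injections"] = True
    return resolved


print(resolve_settings({}, env={"QUALIFIRE_API_KEY": "example-key"}))
```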
### Default Behavior
@ -261,4 +255,3 @@ This evaluates whether the LLM selected the appropriate tools and provided corre
- [Qualifire Documentation](https://docs.qualifire.ai)
- [Qualifire Dashboard](https://app.qualifire.ai)
- [Qualifire Python SDK](https://github.com/qualifire-dev/qualifire-python-sdk)

View File

@ -55,6 +55,7 @@ const sidebars = {
"proxy/guardrails/test_playground",
"proxy/guardrails/litellm_content_filter",
...[
"proxy/guardrails/qualifire",
"proxy/guardrails/aim_security",
"proxy/guardrails/onyx_security",
"proxy/guardrails/aporia_api",
@ -653,12 +654,13 @@ const sidebars = {
"providers/bedrock_writer",
"providers/bedrock_batches",
"providers/aws_polly",
"providers/bedrock_vector_store",
]
},
"providers/litellm_proxy",
"providers/ai21",
"providers/aiml",
"providers/bedrock_vector_store",
]
},
"providers/litellm_proxy",
"providers/abliteration",
"providers/ai21",
"providers/aiml",
"providers/aleph_alpha",
"providers/amazon_nova",
"providers/anyscale",

View File

@ -225,10 +225,13 @@ class BraintrustLogger(CustomLogger):
"id": litellm_call_id,
"input": prompt["messages"],
"metadata": standard_logging_object,
"tags": tags,
"span_attributes": {"name": span_name, "type": "llm"},
}
# Braintrust does not accept 'tags' on non-root spans
if dynamic_metadata.get("root_span_id") is None:
request_data["tags"] = tags
# Only add span attributes whose values are truthy (skips None and empty values)
for key, value in span_attributes.items():
if value:
@ -351,14 +354,37 @@ class BraintrustLogger(CustomLogger):
# Allow metadata override for span name
span_name = dynamic_metadata.get("span_name", "Chat Completion")
# Span parents is a special case
span_parents = dynamic_metadata.get("span_parents")
# Convert comma-separated string to list if present
if span_parents:
span_parents = [s.strip() for s in span_parents.split(",") if s.strip()]
# Add optional span attributes only if present
span_attributes = {
"span_id": dynamic_metadata.get("span_id"),
"root_span_id": dynamic_metadata.get("root_span_id"),
"span_parents": span_parents,
}
request_data = {
"id": litellm_call_id,
"input": prompt["messages"],
"output": output,
"metadata": standard_logging_object,
"tags": tags,
"span_attributes": {"name": span_name, "type": "llm"},
}
# Braintrust does not accept 'tags' on non-root spans
if dynamic_metadata.get("root_span_id") is None:
request_data["tags"] = tags
# Only add span attributes whose values are truthy (skips None and empty values)
for key, value in span_attributes.items():
if value:
request_data[key] = value
if choices is not None:
request_data["output"] = [choice.dict() for choice in choices]
else:
@ -367,9 +393,6 @@ class BraintrustLogger(CustomLogger):
if metrics is not None:
request_data["metrics"] = metrics
if metrics is not None:
request_data["metrics"] = metrics
try:
await self.global_braintrust_http_handler.post(
url=f"{self.api_base}/project_logs/{project_id}/insert",
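The span handling in this file can be exercised in isolation. The sketch below mirrors the logic in the hunks above (comma-separated `span_parents`, `tags` only on root spans, falsy attributes skipped). `build_span_fields` is a hypothetical helper for illustration and does not exist in `BraintrustLogger`.

```python
from typing import Any, Dict, List


def build_span_fields(dynamic_metadata: Dict[str, Any], tags: List[str]) -> Dict[str, Any]:
    """Return the span-related fields that would be merged into request_data."""
    fields: Dict[str, Any] = {}

    # Convert a comma-separated string into a clean list of parent span IDs.
    span_parents = dynamic_metadata.get("span_parents")
    if span_parents:
        span_parents = [s.strip() for s in span_parents.split(",") if s.strip()]

    optional = {
        "span_id": dynamic_metadata.get("span_id"),
        "root_span_id": dynamic_metadata.get("root_span_id"),
        "span_parents": span_parents,
    }

    # Tags are attached only when no root_span_id is given (i.e. a root span).
    if dynamic_metadata.get("root_span_id") is None:
        fields["tags"] = tags

    # Skip attributes that are None or empty.
    for key, value in optional.items():
        if value:
            fields[key] = value
    return fields


print(build_span_fields({}, ["prod"]))
print(build_span_fields(
    {"span_id": "s1", "root_span_id": "r1", "span_parents": "p1, p2,,"}, ["prod"]
))
```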

View File

@ -45,6 +45,7 @@ def _get_cached_end_user_id_for_cost_tracking():
global _get_end_user_id_for_cost_tracking
if _get_end_user_id_for_cost_tracking is None:
from litellm.utils import get_end_user_id_for_cost_tracking
_get_end_user_id_for_cost_tracking = get_end_user_id_for_cost_tracking
return _get_end_user_id_for_cost_tracking
@ -330,6 +331,25 @@ class PrometheusLogger(CustomLogger):
labelnames=self.get_labels_for_metric("litellm_requests_metric"),
)
# Cache metrics
self.litellm_cache_hits_metric = self._counter_factory(
name="litellm_cache_hits_metric",
documentation="Total number of LiteLLM cache hits",
labelnames=self.get_labels_for_metric("litellm_cache_hits_metric"),
)
self.litellm_cache_misses_metric = self._counter_factory(
name="litellm_cache_misses_metric",
documentation="Total number of LiteLLM cache misses",
labelnames=self.get_labels_for_metric("litellm_cache_misses_metric"),
)
self.litellm_cached_tokens_metric = self._counter_factory(
name="litellm_cached_tokens_metric",
documentation="Total tokens served from LiteLLM cache",
labelnames=self.get_labels_for_metric("litellm_cached_tokens_metric"),
)
except Exception as e:
print_verbose(f"Got exception on init prometheus client {str(e)}")
raise e
@ -801,7 +821,7 @@ class PrometheusLogger(CustomLogger):
litellm_params = kwargs.get("litellm_params", {}) or {}
_metadata = litellm_params.get("metadata", {})
get_end_user_id_for_cost_tracking = _get_cached_end_user_id_for_cost_tracking()
end_user_id = get_end_user_id_for_cost_tracking(
litellm_params, service_type="prometheus"
)
@ -821,7 +841,7 @@ class PrometheusLogger(CustomLogger):
user_api_key_auth_metadata: Optional[dict] = standard_logging_payload[
"metadata"
].get("user_api_key_auth_metadata")
# Include top-level metadata fields (excluding nested dictionaries)
# This allows accessing fields like requester_ip_address from top-level metadata
top_level_metadata = standard_logging_payload.get("metadata", {})
@ -832,7 +852,7 @@ class PrometheusLogger(CustomLogger):
for k, v in top_level_metadata.items()
if not isinstance(v, dict) # Exclude nested dicts to avoid conflicts
}
combined_metadata: Dict[str, Any] = {
**top_level_fields, # Include top-level fields first
**(_requester_metadata if _requester_metadata else {}),
@ -951,6 +971,12 @@ class PrometheusLogger(CustomLogger):
kwargs, start_time, end_time, enum_values, output_tokens
)
# cache metrics
self._increment_cache_metrics(
standard_logging_payload=standard_logging_payload, # type: ignore
enum_values=enum_values,
)
if (
standard_logging_payload["stream"] is True
): # log successful streaming requests from logging event hook.
@ -1020,6 +1046,54 @@ class PrometheusLogger(CustomLogger):
standard_logging_payload["completion_tokens"]
)
def _increment_cache_metrics(
self,
standard_logging_payload: StandardLoggingPayload,
enum_values: UserAPIKeyLabelValues,
):
"""
Increment cache-related Prometheus metrics based on cache hit/miss status.
Args:
standard_logging_payload: Contains cache_hit field (True/False/None)
enum_values: Label values for Prometheus metrics
"""
cache_hit = standard_logging_payload.get("cache_hit")
# Only track if cache_hit has a definite value (True or False)
if cache_hit is None:
return
if cache_hit is True:
# Increment cache hits counter
_labels = prometheus_label_factory(
supported_enum_labels=self.get_labels_for_metric(
metric_name="litellm_cache_hits_metric"
),
enum_values=enum_values,
)
self.litellm_cache_hits_metric.labels(**_labels).inc()
# Increment cached tokens counter
total_tokens = standard_logging_payload.get("total_tokens", 0)
if total_tokens > 0:
_labels = prometheus_label_factory(
supported_enum_labels=self.get_labels_for_metric(
metric_name="litellm_cached_tokens_metric"
),
enum_values=enum_values,
)
self.litellm_cached_tokens_metric.labels(**_labels).inc(total_tokens)
else:
# cache_hit is False - increment cache misses counter
_labels = prometheus_label_factory(
supported_enum_labels=self.get_labels_for_metric(
metric_name="litellm_cache_misses_metric"
),
enum_values=enum_values,
)
self.litellm_cache_misses_metric.labels(**_labels).inc()
async def _increment_remaining_budget_metrics(
self,
user_api_team: Optional[str],
@ -1208,7 +1282,7 @@ class PrometheusLogger(CustomLogger):
litellm_params = kwargs.get("litellm_params", {}) or {}
get_end_user_id_for_cost_tracking = _get_cached_end_user_id_for_cost_tracking()
end_user_id = get_end_user_id_for_cost_tracking(
litellm_params, service_type="prometheus"
)
@ -1562,7 +1636,6 @@ class PrometheusLogger(CustomLogger):
api_provider=llm_provider or "",
)
if exception is not None:
_labels = prometheus_label_factory(
supported_enum_labels=self.get_labels_for_metric(
metric_name="litellm_deployment_failure_responses"
@ -1595,12 +1668,11 @@ class PrometheusLogger(CustomLogger):
enum_values: UserAPIKeyLabelValues,
output_tokens: float = 1.0,
):
try:
verbose_logger.debug("setting remaining tokens requests metric")
standard_logging_payload: Optional[StandardLoggingPayload] = (
request_kwargs.get("standard_logging_object")
)
standard_logging_payload: Optional[
StandardLoggingPayload
] = request_kwargs.get("standard_logging_object")
if standard_logging_payload is None:
return
@ -2380,10 +2452,10 @@ class PrometheusLogger(CustomLogger):
from litellm.constants import PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
from litellm.integrations.custom_logger import CustomLogger
prometheus_loggers: List[CustomLogger] = (
litellm.logging_callback_manager.get_custom_loggers_for_type(
callback_type=PrometheusLogger
)
prometheus_loggers: List[
CustomLogger
] = litellm.logging_callback_manager.get_custom_loggers_for_type(
callback_type=PrometheusLogger
)
# we need to get the initialized prometheus logger instance(s) and call logger.initialize_remaining_budget_metrics() on them
verbose_logger.debug("found %s prometheus loggers", len(prometheus_loggers))
@ -2455,7 +2527,7 @@ def prometheus_label_factory(
if UserAPIKeyLabelNames.END_USER.value in filtered_labels:
get_end_user_id_for_cost_tracking = _get_cached_end_user_id_for_cost_tracking()
filtered_labels["end_user"] = get_end_user_id_for_cost_tracking(
litellm_params={"user_api_key_end_user_id": enum_values.end_user},
service_type="prometheus",
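The branching in `_increment_cache_metrics` can be checked without a Prometheus client. The sketch below is a dependency-free stand-in that uses a `collections.Counter` in place of the three Prometheus counters and drops the label handling. `record_cache_metrics` is a hypothetical name, and the `or 0` guard against a `None` token count is an addition that the diff itself does not have.

```python
from collections import Counter
from typing import Any, Dict

counters: Counter = Counter()


def record_cache_metrics(payload: Dict[str, Any]) -> None:
    cache_hit = payload.get("cache_hit")
    # Only track requests where cache_hit is definitely True or False.
    if cache_hit is None:
        return
    if cache_hit is True:
        counters["cache_hits"] += 1
        # Guard against a missing or None token count (not in the original diff).
        total_tokens = payload.get("total_tokens") or 0
        if total_tokens > 0:
            counters["cached_tokens"] += total_tokens
    else:
        counters["cache_misses"] += 1


for p in (
    {"cache_hit": True, "total_tokens": 120},
    {"cache_hit": False, "total_tokens": 80},
    {"cache_hit": None, "total_tokens": 50},
):
    record_cache_metrics(p)

print(dict(counters))
```

Requests with `cache_hit` set to `None` are left out of both counters, so a hit ratio computed from hits and misses only reflects requests where caching was actually evaluated.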

View File

@ -4839,9 +4839,9 @@ class StandardLoggingPayloadSetup:
metadata = litellm_params.get("metadata") or {}
litellm_metadata = litellm_params.get("litellm_metadata") or {}
if metadata.get("tags", []):
request_tags = metadata.get("tags", [])
request_tags = metadata.get("tags", []).copy()
elif litellm_metadata.get("tags", []):
request_tags = litellm_metadata.get("tags", [])
request_tags = litellm_metadata.get("tags", []).copy()
else:
request_tags = []
user_agent_tags = StandardLoggingPayloadSetup._get_user_agent_tags(
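The switch to `.copy()` matters if the returned list is later mutated, for example by appending user-agent tags (that later mutation is an assumption here, since the rest of the function is not shown in this hunk). The sketch below uses two hypothetical helpers to contrast the aliasing behaviour with the copying behaviour.

```python
from typing import Any, Dict, List


def get_tags_aliasing(metadata: Dict[str, Any]) -> List[str]:
    tags = metadata.get("tags", [])  # same list object as the caller's
    tags.append("User-Agent: demo")
    return tags


def get_tags_copying(metadata: Dict[str, Any]) -> List[str]:
    tags = metadata.get("tags", []).copy()  # independent shallow copy
    tags.append("User-Agent: demo")
    return tags


shared = {"tags": ["prod"]}
get_tags_aliasing(shared)
print(shared["tags"])  # the caller's list has grown

fresh = {"tags": ["prod"]}
get_tags_copying(fresh)
print(fresh["tags"])  # the caller's list is unchanged
```

A shallow copy is enough here because the tags are plain strings.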

View File

@ -61,6 +61,10 @@
"max_completion_tokens": "max_tokens"
}
},
"abliteration": {
"base_url": "https://api.abliteration.ai/v1",
"api_key_env": "ABLITERATION_API_KEY"
},
"llamagate": {
"base_url": "https://api.llamagate.dev/v1",
"api_key_env": "LLAMAGATE_API_KEY",

View File

@ -0,0 +1,182 @@
"""
OpenRouter Embedding API Configuration.
This module provides the configuration for OpenRouter's Embedding API.
OpenRouter is OpenAI-compatible and supports embeddings via the /v1/embeddings endpoint.
Docs: https://openrouter.ai/docs
"""
from typing import TYPE_CHECKING, Any, Optional
import httpx
from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.types.llms.openai import AllEmbeddingInputValues
from litellm.types.utils import EmbeddingResponse
from litellm.utils import convert_to_model_response_object
from ..common_utils import OpenRouterException
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
LiteLLMLoggingObj = Any
class OpenrouterEmbeddingConfig(BaseEmbeddingConfig):
"""
Configuration for OpenRouter's Embedding API.
Reference: https://openrouter.ai/docs
"""
def validate_environment(
self,
headers: dict,
model: str,
messages: list,
optional_params: dict,
litellm_params: dict,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
) -> dict:
"""
Validate environment and set up headers for OpenRouter API.
OpenRouter requires:
- Authorization header with Bearer token
- HTTP-Referer header (site URL)
- X-Title header (app name)
"""
from litellm import get_secret
# Get OpenRouter-specific headers
openrouter_site_url = get_secret("OR_SITE_URL") or "https://litellm.ai"
openrouter_app_name = get_secret("OR_APP_NAME") or "liteLLM"
openrouter_headers = {
"HTTP-Referer": openrouter_site_url,
"X-Title": openrouter_app_name,
"Content-Type": "application/json",
}
# Add Authorization header if api_key is provided
if api_key:
openrouter_headers["Authorization"] = f"Bearer {api_key}"
# Merge with existing headers (user's extra_headers take priority)
merged_headers = {**openrouter_headers, **headers}
return merged_headers
def get_complete_url(
self,
api_base: Optional[str],
api_key: Optional[str],
model: str,
optional_params: dict,
litellm_params: dict,
stream: Optional[bool] = None,
) -> str:
"""
Get the complete URL for OpenRouter Embedding API endpoint.
"""
# api_base is already set to https://openrouter.ai/api/v1 in main.py
# Remove trailing slashes
if api_base:
api_base = api_base.rstrip("/")
else:
api_base = "https://openrouter.ai/api/v1"
# Return the embeddings endpoint
return f"{api_base}/embeddings"
def transform_embedding_request(
self,
model: str,
input: AllEmbeddingInputValues,
optional_params: dict,
headers: dict,
) -> dict:
"""
Transform embedding request to OpenRouter format (OpenAI-compatible).
"""
# Ensure input is a list
if isinstance(input, str):
input = [input]
# OpenRouter expects the full model name (e.g., google/gemini-embedding-001)
# Strip 'openrouter/' prefix if present
if model.startswith("openrouter/"):
model = model.replace("openrouter/", "", 1)
return {
"model": model,
"input": input,
**optional_params,
}
def transform_embedding_response(
self,
model: str,
raw_response: httpx.Response,
model_response: EmbeddingResponse,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str],
request_data: dict,
optional_params: dict,
litellm_params: dict,
) -> EmbeddingResponse:
"""
Transform embedding response from OpenRouter format (OpenAI-compatible).
"""
logging_obj.post_call(original_response=raw_response.text)
# OpenRouter returns standard OpenAI-compatible embedding response
response_json = raw_response.json()
return convert_to_model_response_object(
response_object=response_json,
model_response_object=model_response,
response_type="embedding",
)
def get_supported_openai_params(self, model: str) -> list:
"""
Get list of supported OpenAI parameters for OpenRouter embeddings.
"""
return [
"timeout",
"dimensions",
"encoding_format",
"user",
]
def map_openai_params(
self,
non_default_params: dict,
optional_params: dict,
model: str,
drop_params: bool,
) -> dict:
"""
Map OpenAI parameters to OpenRouter format.
"""
for param, value in non_default_params.items():
if param in self.get_supported_openai_params(model):
optional_params[param] = value
return optional_params
def get_error_class(
self, error_message: str, status_code: int, headers: Any
) -> Any:
"""
Get the error class for OpenRouter errors.
"""
return OpenRouterException(
message=error_message,
status_code=status_code,
headers=headers,
)
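Two behaviours of `OpenrouterEmbeddingConfig` are easy to verify offline: caller-supplied headers override the OpenRouter defaults, and the `openrouter/` prefix is stripped from the model name exactly once. The sketch below re-implements just those two steps as free functions. `merge_headers` and `build_body` are hypothetical names, and the secret lookups for `OR_SITE_URL` and `OR_APP_NAME` are replaced by their fallback values.

```python
from typing import Any, Dict, List, Optional, Union


def merge_headers(api_key: Optional[str], extra_headers: Dict[str, str]) -> Dict[str, str]:
    defaults = {
        "HTTP-Referer": "https://litellm.ai",
        "X-Title": "liteLLM",
        "Content-Type": "application/json",
    }
    if api_key:
        defaults["Authorization"] = f"Bearer {api_key}"
    # Later keys win, so the caller's headers take priority over the defaults.
    return {**defaults, **extra_headers}


def build_body(
    model: str, text_input: Union[str, List[str]], optional_params: Dict[str, Any]
) -> Dict[str, Any]:
    # Normalise a single string into a one-element list.
    if isinstance(text_input, str):
        text_input = [text_input]
    # Strip only the leading provider prefix.
    if model.startswith("openrouter/"):
        model = model.replace("openrouter/", "", 1)
    return {"model": model, "input": text_input, **optional_params}


print(merge_headers("sk-test", {"X-Title": "my-app"}))
print(build_body("openrouter/google/gemini-embedding-001", "hello", {"dimensions": 256}))
```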

View File

@ -4707,6 +4707,51 @@ def embedding( # noqa: PLR0915
litellm_params=litellm_params_dict,
headers=headers,
)
elif custom_llm_provider == "openrouter":
api_base = (
api_base
or litellm.api_base
or get_secret_str("OPENROUTER_API_BASE")
or "https://openrouter.ai/api/v1"
)
api_key = (
api_key
or litellm.api_key
or litellm.openrouter_key
or get_secret("OPENROUTER_API_KEY")
or get_secret("OR_API_KEY")
)
openrouter_site_url = get_secret("OR_SITE_URL") or "https://litellm.ai"
openrouter_app_name = get_secret("OR_APP_NAME") or "liteLLM"
openrouter_headers = {
"HTTP-Referer": openrouter_site_url,
"X-Title": openrouter_app_name,
}
_headers = headers or litellm.headers
if _headers:
openrouter_headers.update(_headers)
headers = openrouter_headers
response = base_llm_http_handler.embedding(
model=model,
input=input,
custom_llm_provider=custom_llm_provider,
api_base=api_base,
api_key=api_key,
logging_obj=logging,
timeout=timeout,
model_response=EmbeddingResponse(),
optional_params=optional_params,
client=client,
aembedding=aembedding,
litellm_params=litellm_params_dict,
headers=headers,
)
elif custom_llm_provider == "huggingface":
api_key = (
api_key

View File

@ -1942,7 +1942,7 @@ class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
description="connect to a postgres db - needed for generating temporary keys + tracking spend / key",
)
database_connection_pool_limit: Optional[int] = Field(
100,
10,
description="default connection pool for prisma client connecting to postgres db",
)
database_connection_timeout: Optional[float] = Field(

View File

@ -11,7 +11,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.integrations.custom_guardrail import ModifyResponseException
from litellm.proxy.common_request_processing import (
ProxyBaseLLMRequestProcessing,
create_streaming_response,
create_response,
)
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.types.utils import TokenCountResponse
@ -106,7 +106,7 @@ async def anthropic_response( # noqa: PLR0915
)
)
return await create_streaming_response(
return await create_response(
generator=selected_data_generator,
media_type="text/event-stream",
headers={},

View File

@ -17,7 +17,7 @@ from typing import (
import httpx
import orjson
from fastapi import HTTPException, Request, status
from fastapi.responses import Response, StreamingResponse
from fastapi.responses import JSONResponse, Response, StreamingResponse
import litellm
from litellm._logging import verbose_proxy_logger
@ -96,16 +96,55 @@ async def _parse_event_data_for_error(event_line: Union[str, bytes]) -> Optional
return None
async def create_streaming_response(
def _extract_error_from_sse_chunk(event_line: Union[str, bytes]) -> dict:
"""
Extract error dictionary from SSE format chunk.
Args:
event_line: SSE format event line, e.g. "data: {"error": {...}}\n\n"
Returns:
Error dictionary in OpenAI API format
"""
event_line = (
event_line.decode("utf-8") if isinstance(event_line, bytes) else event_line
)
# Default error format
default_error = {
"message": "Unknown error",
"type": "internal_server_error",
"param": None,
"code": "500",
}
if event_line.startswith("data: "):
json_str = event_line[len("data: ") :].strip()
if not json_str or json_str == "[DONE]":
return default_error
try:
data = orjson.loads(json_str)
if isinstance(data, dict) and "error" in data:
error_obj = data["error"]
if isinstance(error_obj, dict):
return error_obj
except (orjson.JSONDecodeError, json.JSONDecodeError):
pass
return default_error
async def create_response(
generator: AsyncGenerator[str, None],
media_type: str,
headers: dict,
default_status_code: int = status.HTTP_200_OK,
) -> StreamingResponse:
) -> Union[StreamingResponse, JSONResponse]:
"""
Creates a StreamingResponse by inspecting the first chunk for an error code.
The entire original generator content is streamed, but the HTTP status code
of the response is set based on the first chunk if it's a recognized error.
Create streaming response, checking if the first chunk is an error.
If the first chunk is an error, return a standard JSON error response.
Otherwise, return StreamingResponse and stream all content.
"""
first_chunk_value: Optional[str] = None
final_status_code = default_status_code
@ -124,9 +163,27 @@ async def create_streaming_response(
first_chunk_value
)
if error_code_from_chunk is not None:
# First chunk is an error, stream hasn't really started yet
# Should return standard JSON error response instead of SSE format
final_status_code = error_code_from_chunk
verbose_proxy_logger.debug(
f"Error detected in first stream chunk. Status code set to: {final_status_code}"
f"Error detected in first stream chunk. Returning JSON error response with status code: {final_status_code}"
)
# Parse error content
error_dict = _extract_error_from_sse_chunk(first_chunk_value)
# Close the generator without consuming it (avoid resource leak)
try:
await generator.aclose()
except Exception:
pass
# Return JSON format error response
return JSONResponse(
status_code=final_status_code,
content={"error": error_dict},
headers=headers,
)
except Exception as e:
verbose_proxy_logger.debug(f"Error parsing first chunk value: {e}")
@ -670,7 +727,7 @@ class ProxyBaseLLMRequestProcessing:
proxy_logging_obj=proxy_logging_obj,
)
)
return await create_streaming_response(
return await create_response(
generator=selected_data_generator,
media_type="text/event-stream",
headers=custom_headers,
@ -681,7 +738,7 @@ class ProxyBaseLLMRequestProcessing:
user_api_key_dict=user_api_key_dict,
request_data=self.data,
)
return await create_streaming_response(
return await create_response(
generator=selected_data_generator,
media_type="text/event-stream",
headers=custom_headers,
@ -923,11 +980,11 @@ class ProxyBaseLLMRequestProcessing:
@staticmethod
def _get_pre_call_type(
route_type: Literal["acompletion", "aembedding", "aresponses", "allm_passthrough_route"],
) -> Literal["completion", "embeddings", "responses", "allm_passthrough_route"]:
) -> Literal["completion", "embedding", "responses", "allm_passthrough_route"]:
if route_type == "acompletion":
return "completion"
elif route_type == "aembedding":
return "embeddings"
return "embedding"
elif route_type == "aresponses":
return "responses"
elif route_type == "allm_passthrough_route":
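The SSE error extraction added in this file can be tried out on its own. The sketch below follows the same steps as `_extract_error_from_sse_chunk` but uses the standard-library `json` module instead of `orjson` so that it runs without extra dependencies. `extract_error` and `DEFAULT_ERROR` are names chosen for this illustration.

```python
import json
from typing import Any, Dict, Union

DEFAULT_ERROR: Dict[str, Any] = {
    "message": "Unknown error",
    "type": "internal_server_error",
    "param": None,
    "code": "500",
}


def extract_error(event_line: Union[str, bytes]) -> Dict[str, Any]:
    """Pull the `error` object out of an SSE `data:` line, else return a default."""
    if isinstance(event_line, bytes):
        event_line = event_line.decode("utf-8")
    if not event_line.startswith("data: "):
        return dict(DEFAULT_ERROR)
    json_str = event_line[len("data: "):].strip()
    if not json_str or json_str == "[DONE]":
        return dict(DEFAULT_ERROR)
    try:
        data = json.loads(json_str)
    except json.JSONDecodeError:
        return dict(DEFAULT_ERROR)
    if isinstance(data, dict) and isinstance(data.get("error"), dict):
        return data["error"]
    return dict(DEFAULT_ERROR)


chunk = 'data: {"error": {"message": "rate limited", "code": "429"}}\n\n'
print(extract_error(chunk))
```

With the error dictionary in hand, a caller can return it as a plain JSON body with the matching status code, which is what clients that have not started reading a stream generally expect.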

View File

@ -5,6 +5,7 @@
# +-------------------------------------------------------------+
# Qualifire - Evaluate LLM outputs for quality, safety, and reliability
import json
import os
from typing import Any, Dict, List, Literal, Optional, Type
@ -15,12 +16,17 @@ from litellm.integrations.custom_guardrail import CustomGuardrail
from litellm.litellm_core_utils.litellm_logging import (
Logging as LiteLLMLoggingObj,
)
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import AllMessageValues
from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel
from litellm.types.utils import GenericGuardrailAPIInputs
GUARDRAIL_NAME = "qualifire"
DEFAULT_QUALIFIRE_API_BASE = "https://proxy.qualifire.ai"
class QualifireGuardrail(CustomGuardrail):
@ -44,7 +50,7 @@ class QualifireGuardrail(CustomGuardrail):
Args:
api_key: API key for Qualifire (or use QUALIFIRE_API_KEY env var)
api_base: Optional custom API base URL
api_base: Optional custom API base URL (defaults to https://proxy.qualifire.ai)
evaluation_id: Pre-configured evaluation ID from Qualifire dashboard
prompt_injections: Enable prompt injection detection (default if no other checks)
hallucinations_check: Enable hallucination detection
@ -64,6 +70,7 @@ class QualifireGuardrail(CustomGuardrail):
api_base
or get_secret_str("QUALIFIRE_BASE_URL")
or os.environ.get("QUALIFIRE_BASE_URL")
or DEFAULT_QUALIFIRE_API_BASE
)
self.evaluation_id = evaluation_id
self.prompt_injections = prompt_injections
@ -79,7 +86,11 @@ class QualifireGuardrail(CustomGuardrail):
if not self._has_any_check_enabled() and not self.evaluation_id:
self.prompt_injections = True
self._client = None
# Initialize async HTTP client for direct API calls
self.async_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.GuardrailCallback
)
super().__init__(**kwargs)
def _has_any_check_enabled(self) -> bool:
@ -96,43 +107,22 @@ class QualifireGuardrail(CustomGuardrail):
]
)
def _get_client(self):
"""Lazy initialization of Qualifire client."""
if self._client is None:
try:
from qualifire.client import Client
except ImportError:
raise ImportError(
"qualifire package is required for QualifireGuardrail. "
"Install it with: pip install qualifire"
)
client_kwargs: Dict[str, Any] = {}
if self.qualifire_api_key:
client_kwargs["api_key"] = self.qualifire_api_key
if self.qualifire_api_base:
client_kwargs["base_url"] = self.qualifire_api_base
self._client = Client(**client_kwargs)
return self._client
def _convert_messages_to_qualifire_format(
def _convert_messages_to_api_format(
self, messages: List[AllMessageValues]
) -> List[Any]:
) -> List[Dict[str, Any]]:
"""
Convert LiteLLM messages to Qualifire's LLMMessage format.
Convert LiteLLM messages to Qualifire API format.
Supports tool calls for tool_selection_quality_check.
"""
try:
from qualifire.types import LLMMessage, LLMToolCall
except ImportError:
raise ImportError(
"qualifire package is required for QualifireGuardrail. "
"Install it with: pip install qualifire"
)
qualifire_messages = []
Returns a list of dicts matching the API's ModelInvocationCanonicalMessage schema:
{
"role": "user" | "assistant" | "system" | "tool",
"content": "...",
"tool_call_id": "...", # optional
"tool_calls": [{"id": "...", "name": "...", "arguments": {...}}] # optional
}
"""
api_messages = []
for msg in messages:
role = msg.get("role", "user")
content = msg.get("content", "")
@ -147,42 +137,86 @@ class QualifireGuardrail(CustomGuardrail):
text_parts.append(part)
content = "\n".join(text_parts)
llm_message_kwargs: Dict[str, Any] = {
api_message: Dict[str, Any] = {
"role": role,
"content": content if isinstance(content, str) else str(content),
}
# Handle tool_call_id for tool response messages
tool_call_id = msg.get("tool_call_id")
if tool_call_id:
api_message["tool_call_id"] = tool_call_id
# Handle tool calls if present
tool_calls = msg.get("tool_calls")
if tool_calls and isinstance(tool_calls, list):
qualifire_tool_calls = []
api_tool_calls = []
for tc in tool_calls:
if isinstance(tc, dict):
function_info = tc.get("function", {})
# Arguments can be a string (JSON) or dict
args = function_info.get("arguments", {})
if isinstance(args, str):
import json
try:
args = json.loads(args)
except json.JSONDecodeError:
args = {}
qualifire_tool_calls.append(
LLMToolCall(
id=tc.get("id") or "",
name=function_info.get("name") or "",
arguments=args if isinstance(args, dict) else {},
)
api_tool_calls.append(
{
"id": tc.get("id") or "",
"name": function_info.get("name") or "",
"arguments": args if isinstance(args, dict) else {},
}
)
if qualifire_tool_calls:
llm_message_kwargs["tool_calls"] = qualifire_tool_calls
if api_tool_calls:
api_message["tool_calls"] = api_tool_calls
qualifire_messages.append(LLMMessage(**llm_message_kwargs))
api_messages.append(api_message)
return qualifire_messages
return api_messages
def _check_if_flagged(self, result: Any) -> bool:
def _convert_tools_to_api_format(
self, tools: Optional[List[Any]]
) -> Optional[List[Dict[str, Any]]]:
"""
Convert OpenAI-format tools to Qualifire API format.
Returns a list of dicts matching the API's ModelInvocationToolDefinition schema:
{
"name": "...",
"description": "...",
"parameters": {...}
}
"""
if not tools:
return None
api_tools = []
for tool in tools:
if isinstance(tool, dict):
# Handle OpenAI function tool format
if tool.get("type") == "function":
function_def = tool.get("function", {})
api_tools.append(
{
"name": function_def.get("name", ""),
"description": function_def.get("description", ""),
"parameters": function_def.get("parameters", {}),
}
)
# Handle direct tool format
elif "name" in tool:
api_tools.append(
{
"name": tool.get("name", ""),
"description": tool.get("description", ""),
"parameters": tool.get("parameters", {}),
}
)
return api_tools if api_tools else None
def _check_if_flagged(self, result: Dict[str, Any]) -> bool:
"""
Check if the Qualifire evaluation result indicates flagged content.
@ -190,65 +224,53 @@ class QualifireGuardrail(CustomGuardrail):
A high score (close to 100) indicates GOOD content, low score indicates problems.
"""
# Check evaluation results for any flagged items
evaluation_results = getattr(result, "evaluationResults", None) or []
if isinstance(result, dict):
evaluation_results = result.get("evaluationResults", []) or []
evaluation_results = result.get("evaluationResults", []) or []
for eval_result in evaluation_results:
results: List[Any] = []
if isinstance(eval_result, dict):
results = eval_result.get("results", []) or []
else:
results = getattr(eval_result, "results", []) or []
results = eval_result.get("results", []) or []
for r in results:
flagged = (
r.get("flagged")
if isinstance(r, dict)
else getattr(r, "flagged", False)
)
if flagged:
if r.get("flagged"):
return True
return False
def _build_evaluate_kwargs(
def _build_evaluate_payload(
self,
qualifire_messages: List[Any],
api_messages: List[Dict[str, Any]],
output: Optional[str],
assertions: Optional[List[str]],
available_tools: Optional[List[Any]],
available_tools: Optional[List[Dict[str, Any]]],
) -> Dict[str, Any]:
"""Build kwargs dictionary for the evaluate call."""
kwargs: Dict[str, Any] = {"messages": qualifire_messages}
"""Build payload dictionary for the /api/evaluation/evaluate endpoint."""
payload: Dict[str, Any] = {"messages": api_messages}
if output is not None:
kwargs["output"] = output
payload["output"] = output
# Add enabled checks
if self.prompt_injections:
kwargs["prompt_injections"] = True
payload["prompt_injections"] = True
if self.hallucinations_check:
kwargs["hallucinations_check"] = True
payload["hallucinations_check"] = True
if self.grounding_check:
kwargs["grounding_check"] = True
payload["grounding_check"] = True
if self.pii_check:
kwargs["pii_check"] = True
payload["pii_check"] = True
if self.content_moderation_check:
kwargs["content_moderation_check"] = True
payload["content_moderation_check"] = True
if self.tool_selection_quality_check:
# Only enable tool_selection_quality_check if available_tools is provided
if available_tools:
kwargs["tool_selection_quality_check"] = True
kwargs["available_tools"] = available_tools
payload["tool_selection_quality_check"] = True
payload["available_tools"] = available_tools
else:
verbose_proxy_logger.debug(
"Qualifire Guardrail: tool_selection_quality_check enabled but no available_tools provided, skipping this check"
)
if assertions:
kwargs["assertions"] = assertions
payload["assertions"] = assertions
return kwargs
return payload
async def _run_qualifire_check(
self,
@ -274,11 +296,17 @@ class QualifireGuardrail(CustomGuardrail):
assertions = dynamic_params.get("assertions") or self.assertions
on_flagged = dynamic_params.get("on_flagged") or self.on_flagged
try:
client = self._get_client()
qualifire_messages = self._convert_messages_to_qualifire_format(messages)
# Prepare headers
headers = {
"X-Qualifire-API-Key": self.qualifire_api_key or "",
"Content-Type": "application/json",
}
# Use invoke_evaluation if evaluation_id is provided
try:
# Convert messages to API format
api_messages = self._convert_messages_to_api_format(messages)
# Use invoke endpoint if evaluation_id is provided
if evaluation_id:
# For invoke_evaluation, we need to extract input/output
input_text = ""
@ -291,25 +319,47 @@ class QualifireGuardrail(CustomGuardrail):
input_text = content
break
result = client.invoke_evaluation(
evaluation_id=evaluation_id,
input=input_text,
output=output or "",
)
payload = {
"evaluation_id": evaluation_id,
"input": input_text,
"output": output or "",
"messages": api_messages,
}
# Convert tools if provided
api_tools = self._convert_tools_to_api_format(available_tools)
if api_tools:
payload["available_tools"] = api_tools
url = f"{self.qualifire_api_base}/api/evaluation/invoke"
else:
# Use evaluate with individual checks
kwargs = self._build_evaluate_kwargs(
qualifire_messages=qualifire_messages,
# Use evaluate endpoint with individual checks
api_tools = self._convert_tools_to_api_format(available_tools)
payload = self._build_evaluate_payload(
api_messages=api_messages,
output=output,
assertions=assertions,
available_tools=available_tools,
available_tools=api_tools,
)
result = client.evaluate(**kwargs)
url = f"{self.qualifire_api_base}/api/evaluation/evaluate"
# Convert result to dict for logging
verbose_proxy_logger.debug(
f"Qualifire Guardrail: Making request to {url}"
)
# Make the API request
response = await self.async_handler.post(
url=url,
headers=headers,
json=payload,
)
response.raise_for_status()
result = response.json()
# Extract response info for logging
qualifire_response = {
"score": getattr(result, "score", None),
"status": getattr(result, "status", None),
"score": result.get("score"),
"status": result.get("status"),
}
verbose_proxy_logger.debug(
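The conversion helpers in this file flatten OpenAI-style tool calls and tool definitions into plain dicts before they are posted to the Qualifire API. A minimal standalone sketch of that logic follows. The free-function names `convert_tool_calls` and `convert_tools` are hypothetical (the commit implements this as methods on `QualifireGuardrail`), and the sketch makes no HTTP request.

```python
import json
from typing import Any, Dict, List, Optional


def convert_tool_calls(tool_calls: List[Any]) -> List[Dict[str, Any]]:
    """Flatten OpenAI-style tool calls into {"id", "name", "arguments"} dicts."""
    api_tool_calls: List[Dict[str, Any]] = []
    for tc in tool_calls:
        if not isinstance(tc, dict):
            continue
        function_info = tc.get("function", {})
        # Arguments may arrive as a JSON string or as a dict.
        args = function_info.get("arguments", {})
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except json.JSONDecodeError:
                args = {}
        api_tool_calls.append(
            {
                "id": tc.get("id") or "",
                "name": function_info.get("name") or "",
                # Anything that did not decode to a dict becomes an empty dict.
                "arguments": args if isinstance(args, dict) else {},
            }
        )
    return api_tool_calls


def convert_tools(tools: Optional[List[Any]]) -> Optional[List[Dict[str, Any]]]:
    """Convert tool definitions to {"name", "description", "parameters"} dicts."""
    if not tools:
        return None
    api_tools: List[Dict[str, Any]] = []
    for tool in tools:
        if not isinstance(tool, dict):
            continue
        if tool.get("type") == "function":
            source = tool.get("function", {})  # OpenAI function tool format
        elif "name" in tool:
            source = tool  # direct tool format
        else:
            continue
        api_tools.append(
            {
                "name": source.get("name", ""),
                "description": source.get("description", ""),
                "parameters": source.get("parameters", {}),
            }
        )
    return api_tools or None
```

Returning `None` instead of an empty list lets the caller skip the `available_tools` key entirely when there is nothing to send.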

View File

@ -229,7 +229,7 @@ from litellm.proxy.batches_endpoints.endpoints import router as batches_router
from litellm.proxy.caching_routes import router as caching_router
from litellm.proxy.common_request_processing import (
ProxyBaseLLMRequestProcessing,
-create_streaming_response,
+create_response,
)
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
from litellm.proxy.common_utils.debug_utils import init_verbose_loggers
@ -6824,7 +6824,7 @@ async def run_thread(
if (
"stream" in data and data["stream"] is True
): # use generate_responses to stream responses
-return await create_streaming_response(
+return await create_response(
generator=async_assistants_data_generator(
user_api_key_dict=user_api_key_dict,
response=response,

View File

@ -1,7 +1,7 @@
import re
from dataclasses import dataclass
from enum import Enum
-from typing import Dict, List, Literal, Optional, Tuple, Union
+from typing import Dict, List, Literal, Optional, Tuple
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@ -185,6 +185,10 @@ DEFINED_PROMETHEUS_METRICS = Literal[
"litellm_redis_daily_spend_update_queue_size",
"litellm_in_memory_spend_update_queue_size",
"litellm_redis_spend_update_queue_size",
# Cache metrics
"litellm_cache_hits_metric",
"litellm_cache_misses_metric",
"litellm_cached_tokens_metric",
]
@ -436,6 +440,21 @@ class PrometheusMetricLabels:
litellm_redis_spend_update_queue_size: List[str] = []
# Cache metrics - track cache hits, misses, and tokens served from cache
_cache_metric_labels = [
UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value,
UserAPIKeyLabelNames.API_KEY_HASH.value,
UserAPIKeyLabelNames.API_KEY_ALIAS.value,
UserAPIKeyLabelNames.TEAM.value,
UserAPIKeyLabelNames.TEAM_ALIAS.value,
UserAPIKeyLabelNames.END_USER.value,
UserAPIKeyLabelNames.USER.value,
]
litellm_cache_hits_metric = _cache_metric_labels
litellm_cache_misses_metric = _cache_metric_labels
litellm_cached_tokens_metric = _cache_metric_labels
@staticmethod
def get_labels(label_name: DEFINED_PROMETHEUS_METRICS) -> List[str]:
default_labels = getattr(PrometheusMetricLabels, label_name)
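The three cache metrics above are assigned the same `_cache_metric_labels` list, and `get_labels` looks a metric's labels up by attribute name. A small sketch of that pattern, assuming a stripped-down stand-in class (`MetricLabels` and its attribute names are hypothetical, and returning a copy is a choice made for this sketch, not something the diff shows):

```python
from typing import List


class MetricLabels:
    # One shared label list, reused by several metrics.
    _cache_metric_labels: List[str] = ["model", "hashed_api_key", "team", "user"]

    cache_hits: List[str] = _cache_metric_labels
    cache_misses: List[str] = _cache_metric_labels

    @staticmethod
    def get_labels(label_name: str) -> List[str]:
        # Look the list up by attribute name, then copy it so callers
        # cannot accidentally mutate the shared class-level list.
        return list(getattr(MetricLabels, label_name))
```

Because both attributes alias one list object, an in-place change to either would affect every cache metric, which is why the sketch hands out copies.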

View File

@ -7718,6 +7718,11 @@ class ProviderConfigManager:
return litellm.CometAPIEmbeddingConfig()
elif litellm.LlmProviders.GITHUB_COPILOT == provider:
return litellm.GithubCopilotEmbeddingConfig()
elif litellm.LlmProviders.OPENROUTER == provider:
from litellm.llms.openrouter.embedding.transformation import (
OpenrouterEmbeddingConfig,
)
return OpenrouterEmbeddingConfig()
elif litellm.LlmProviders.GIGACHAT == provider:
return litellm.GigaChatEmbeddingConfig()
elif litellm.LlmProviders.SAGEMAKER == provider:

View File

@ -32,6 +32,23 @@
}
},
"providers": {
"abliteration": {
"display_name": "Abliteration (`abliteration`)",
"url": "https://docs.litellm.ai/docs/providers/abliteration",
"endpoints": {
"chat_completions": true,
"messages": false,
"responses": false,
"embeddings": false,
"image_generations": false,
"audio_transcriptions": false,
"audio_speech": false,
"moderations": false,
"batches": false,
"rerank": false,
"a2a": false
}
},
"aiml": {
"display_name": "AI/ML API (`aiml`)",
"url": "https://docs.litellm.ai/docs/providers/aiml",
@ -1559,7 +1576,7 @@
"chat_completions": true,
"messages": true,
"responses": true,
"embeddings": false,
"embeddings": true,
"image_generations": false,
"audio_transcriptions": false,
"audio_speech": false,

View File

@ -0,0 +1,50 @@
"""
Unit tests for the Abliteration OpenAI-like provider.
"""
import os
import sys
sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../.."))
)
from litellm.llms.openai_like.dynamic_config import create_config_class
from litellm.llms.openai_like.json_loader import JSONProviderRegistry
ABLITERATION_BASE_URL = "https://api.abliteration.ai/v1"
def _get_config():
provider = JSONProviderRegistry.get("abliteration")
assert provider is not None
config_class = create_config_class(provider)
return config_class()
def test_abliteration_provider_registered():
provider = JSONProviderRegistry.get("abliteration")
assert provider is not None
assert provider.base_url == ABLITERATION_BASE_URL
assert provider.api_key_env == "ABLITERATION_API_KEY"
def test_abliteration_resolves_env_api_key(monkeypatch):
config = _get_config()
monkeypatch.setenv("ABLITERATION_API_KEY", "test-key")
api_base, api_key = config._get_openai_compatible_provider_info(None, None)
assert api_base == ABLITERATION_BASE_URL
assert api_key == "test-key"
def test_abliteration_complete_url_appends_endpoint():
config = _get_config()
url = config.get_complete_url(
api_base=ABLITERATION_BASE_URL,
api_key="test-key",
model="abliteration/abliterated-model",
optional_params={},
litellm_params={},
stream=False,
)
assert url == f"{ABLITERATION_BASE_URL}/chat/completions"
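The tests above check two behaviours: the API key falls back to `ABLITERATION_API_KEY` when none is passed, and the chat endpoint is appended to the base URL. A rough sketch of both, assuming that explicit arguments take precedence over defaults (the function names are hypothetical and this is not the JSON-provider implementation itself):

```python
import os
from typing import Optional, Tuple

ABLITERATION_BASE_URL = "https://api.abliteration.ai/v1"


def resolve_provider_info(
    api_base: Optional[str], api_key: Optional[str]
) -> Tuple[str, Optional[str]]:
    """Fall back to the default base URL and the environment API key."""
    resolved_base = api_base or ABLITERATION_BASE_URL
    resolved_key = api_key or os.environ.get("ABLITERATION_API_KEY")
    return resolved_base, resolved_key


def complete_chat_url(api_base: str) -> str:
    """Append the chat completions endpoint to the base URL."""
    return f"{api_base.rstrip('/')}/chat/completions"
```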

View File

@ -32,3 +32,18 @@ def test_completion_openrouter_image_generation():
.message.images[0]["image_url"]["url"]
.startswith("data:image/png;base64,")
)
def test_openrouter_embedding():
"""Test OpenRouter embeddings support."""
litellm._turn_on_debug()
resp = litellm.embedding(
model="openrouter/openai/text-embedding-3-small",
input=["Hello world", "How are you?"],
)
print(resp)
assert resp is not None
assert len(resp.data) == 2
assert resp.data[0]["embedding"] is not None
assert isinstance(resp.data[0]["embedding"], list)
assert len(resp.data[0]["embedding"]) > 0

View File

@ -530,7 +530,7 @@ class TestPassthroughCallTypeHandling:
)
assert (
ProxyBaseLLMRequestProcessing._get_pre_call_type(route_type="aembedding")
-== "embeddings"
+== "embedding"
)
assert (
ProxyBaseLLMRequestProcessing._get_pre_call_type(route_type="aresponses")

View File

@ -0,0 +1,211 @@
"""
Unit tests for cache Prometheus metrics.
Run with: poetry run pytest tests/test_litellm/integrations/test_prometheus_cache_metrics.py -v
"""
import pytest
from unittest.mock import MagicMock, patch
from litellm.types.integrations.prometheus import UserAPIKeyLabelValues
class TestPrometheusCacheMetrics:
"""Tests for cache-related Prometheus metrics"""
@pytest.fixture
def sample_enum_values(self):
"""Create sample enum values for labels"""
return UserAPIKeyLabelValues(
end_user="test-end-user",
hashed_api_key="test-key-hash",
api_key_alias="test-key-alias",
team="test-team",
team_alias="test-team-alias",
user="test-user",
model="gpt-3.5-turbo",
)
def test_cache_metrics_defined_in_types(self):
"""Test that cache metrics are defined in DEFINED_PROMETHEUS_METRICS"""
from litellm.types.integrations.prometheus import DEFINED_PROMETHEUS_METRICS
from typing import get_args
defined_metrics = get_args(DEFINED_PROMETHEUS_METRICS)
assert "litellm_cache_hits_metric" in defined_metrics
assert "litellm_cache_misses_metric" in defined_metrics
assert "litellm_cached_tokens_metric" in defined_metrics
def test_cache_metric_labels_defined(self):
"""Test that cache metric labels are properly defined"""
from litellm.types.integrations.prometheus import PrometheusMetricLabels
# Verify labels are defined for each cache metric
assert hasattr(PrometheusMetricLabels, "litellm_cache_hits_metric")
assert hasattr(PrometheusMetricLabels, "litellm_cache_misses_metric")
assert hasattr(PrometheusMetricLabels, "litellm_cached_tokens_metric")
# Verify labels include expected keys
expected_labels = [
"model",
"hashed_api_key",
"api_key_alias",
"team",
"team_alias",
"end_user",
"user",
]
for label in expected_labels:
assert label in PrometheusMetricLabels.litellm_cache_hits_metric
assert label in PrometheusMetricLabels.litellm_cache_misses_metric
assert label in PrometheusMetricLabels.litellm_cached_tokens_metric
def test_increment_cache_metrics_on_cache_hit(self, sample_enum_values):
"""Test that cache hit increments the correct metrics"""
# Create mock for PrometheusLogger instance
mock_logger = MagicMock()
# Import the class so its method can be called unbound, with the mock as `self`
from litellm.integrations.prometheus import PrometheusLogger
# Create a mock standard logging payload with cache_hit=True
standard_logging_payload = {
"cache_hit": True,
"total_tokens": 100,
"prompt_tokens": 50,
"completion_tokens": 50,
"model_group": "openai",
"request_tags": [],
}
# Create mock metrics
mock_logger.litellm_cache_hits_metric = MagicMock()
mock_logger.litellm_cache_misses_metric = MagicMock()
mock_logger.litellm_cached_tokens_metric = MagicMock()
mock_logger.get_labels_for_metric = MagicMock(
return_value=[
"model",
"hashed_api_key",
"api_key_alias",
"team",
"team_alias",
"end_user",
"user",
]
)
# Call the method unbound, passing the mock logger as `self`
PrometheusLogger._increment_cache_metrics(
mock_logger,
standard_logging_payload=standard_logging_payload,
enum_values=sample_enum_values,
)
# Verify cache hits metric was incremented
mock_logger.litellm_cache_hits_metric.labels.assert_called()
mock_logger.litellm_cache_hits_metric.labels().inc.assert_called_once()
# Verify cached tokens metric was incremented with total_tokens
mock_logger.litellm_cached_tokens_metric.labels.assert_called()
mock_logger.litellm_cached_tokens_metric.labels().inc.assert_called_once_with(
100
)
# Verify cache misses metric was NOT called
mock_logger.litellm_cache_misses_metric.labels.assert_not_called()
def test_increment_cache_metrics_on_cache_miss(self, sample_enum_values):
"""Test that cache miss increments the correct metrics"""
# Create mock for PrometheusLogger instance
mock_logger = MagicMock()
from litellm.integrations.prometheus import PrometheusLogger
# Create a mock standard logging payload with cache_hit=False
standard_logging_payload = {
"cache_hit": False,
"total_tokens": 100,
"prompt_tokens": 50,
"completion_tokens": 50,
"model_group": "openai",
"request_tags": [],
}
# Create mock metrics
mock_logger.litellm_cache_hits_metric = MagicMock()
mock_logger.litellm_cache_misses_metric = MagicMock()
mock_logger.litellm_cached_tokens_metric = MagicMock()
mock_logger.get_labels_for_metric = MagicMock(
return_value=[
"model",
"hashed_api_key",
"api_key_alias",
"team",
"team_alias",
"end_user",
"user",
]
)
# Call the method
PrometheusLogger._increment_cache_metrics(
mock_logger,
standard_logging_payload=standard_logging_payload,
enum_values=sample_enum_values,
)
# Verify cache misses metric was incremented
mock_logger.litellm_cache_misses_metric.labels.assert_called()
mock_logger.litellm_cache_misses_metric.labels().inc.assert_called_once()
# Verify cache hits and cached tokens metrics were NOT called
mock_logger.litellm_cache_hits_metric.labels.assert_not_called()
mock_logger.litellm_cached_tokens_metric.labels.assert_not_called()
def test_increment_cache_metrics_when_cache_hit_is_none(self, sample_enum_values):
"""Test that no metrics are incremented when cache_hit is None"""
# Create mock for PrometheusLogger instance
mock_logger = MagicMock()
from litellm.integrations.prometheus import PrometheusLogger
# Create a mock standard logging payload with cache_hit=None
standard_logging_payload = {
"cache_hit": None,
"total_tokens": 100,
"prompt_tokens": 50,
"completion_tokens": 50,
"model_group": "openai",
"request_tags": [],
}
# Create mock metrics
mock_logger.litellm_cache_hits_metric = MagicMock()
mock_logger.litellm_cache_misses_metric = MagicMock()
mock_logger.litellm_cached_tokens_metric = MagicMock()
mock_logger.get_labels_for_metric = MagicMock(
return_value=[
"model",
"hashed_api_key",
"api_key_alias",
"team",
"team_alias",
"end_user",
"user",
]
)
# Call the method
PrometheusLogger._increment_cache_metrics(
mock_logger,
standard_logging_payload=standard_logging_payload,
enum_values=sample_enum_values,
)
# Verify NO metrics were called
mock_logger.litellm_cache_hits_metric.labels.assert_not_called()
mock_logger.litellm_cache_misses_metric.labels.assert_not_called()
mock_logger.litellm_cached_tokens_metric.labels.assert_not_called()
if __name__ == "__main__":
pytest.main([__file__, "-v"])
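Taken together, the three behavioural tests above pin down a simple rule: a cache hit increments the hit counter and adds `total_tokens` to the cached-tokens counter, a miss increments only the miss counter, and `cache_hit=None` changes nothing. A minimal sketch of that rule, using a plain `collections.Counter` instead of Prometheus metrics (the function name and counter keys are hypothetical, not the `PrometheusLogger` code):

```python
from collections import Counter
from typing import Any, Dict


def increment_cache_metrics(payload: Dict[str, Any], counters: Counter) -> None:
    """Update hit, miss, and cached-token counters from a logging payload."""
    cache_hit = payload.get("cache_hit")
    if cache_hit is None:
        # Caching was not involved in this request; record nothing.
        return
    if cache_hit:
        counters["cache_hits"] += 1
        counters["cached_tokens"] += payload.get("total_tokens") or 0
    else:
        counters["cache_misses"] += 1
```

Treating `None` separately from `False` keeps requests that never consulted the cache out of the miss count.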

View File

@ -393,6 +393,63 @@ def test_get_request_tags_from_metadata_and_litellm_metadata():
assert "User-Agent: litellm/1.0.0" in tags
def test_get_request_tags_does_not_mutate_original_tags():
"""
Test that _get_request_tags does not mutate the original tags list in metadata.
This is a regression test for a bug where calling _get_request_tags multiple times
would cause User-Agent tags to be duplicated because the function was mutating
the original tags list instead of creating a copy.
"""
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
# Create metadata with original tags
original_tags = ["custom-tag-1", "custom-tag-2"]
metadata = {"tags": original_tags}
litellm_params = {"metadata": metadata}
proxy_server_request = {
"headers": {
"user-agent": "AsyncOpenAI/Python 1.99.9",
}
}
# Call _get_request_tags multiple times (simulating multiple callbacks)
tags1 = StandardLoggingPayloadSetup._get_request_tags(
litellm_params=litellm_params,
proxy_server_request=proxy_server_request,
)
tags2 = StandardLoggingPayloadSetup._get_request_tags(
litellm_params=litellm_params,
proxy_server_request=proxy_server_request,
)
tags3 = StandardLoggingPayloadSetup._get_request_tags(
litellm_params=litellm_params,
proxy_server_request=proxy_server_request,
)
# Verify the original tags list was NOT mutated
assert original_tags == ["custom-tag-1", "custom-tag-2"], (
f"Original tags list was mutated: {original_tags}"
)
assert metadata["tags"] == ["custom-tag-1", "custom-tag-2"], (
f"metadata['tags'] was mutated: {metadata['tags']}"
)
# Verify each returned list has exactly 2 User-Agent tags (not duplicated)
user_agent_count_1 = len([t for t in tags1 if t.startswith("User-Agent:")])
user_agent_count_2 = len([t for t in tags2 if t.startswith("User-Agent:")])
user_agent_count_3 = len([t for t in tags3 if t.startswith("User-Agent:")])
assert user_agent_count_1 == 2, f"Expected 2 User-Agent tags, got {user_agent_count_1}"
assert user_agent_count_2 == 2, f"Expected 2 User-Agent tags, got {user_agent_count_2}"
assert user_agent_count_3 == 2, f"Expected 2 User-Agent tags, got {user_agent_count_3}"
# Verify all returned lists are independent (different objects)
assert tags1 is not tags2
assert tags2 is not tags3
assert tags1 is not original_tags
def test_get_extra_header_tags():
"""Test the _get_extra_header_tags method with various scenarios."""
import litellm
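The regression test `test_get_request_tags_does_not_mutate_original_tags` above guards against appending User-Agent tags to the caller's own list. One way to get that behaviour is to copy the list before appending. The sketch below assumes a simplified signature and a two-tag shape (product name, then full User-Agent string). Both are illustrative assumptions, not the actual `_get_request_tags` implementation.

```python
from typing import Any, Dict, List


def get_request_tags(metadata: Dict[str, Any], user_agent: str) -> List[str]:
    """Return request tags plus User-Agent tags without mutating metadata."""
    # Copy first: appending to metadata["tags"] directly would duplicate
    # the User-Agent tags every time a callback calls this function.
    tags = list(metadata.get("tags") or [])
    if user_agent:
        product = user_agent.split("/")[0]  # assumed tag shape
        tags.append(f"User-Agent: {product}")
        tags.append(f"User-Agent: {user_agent}")
    return tags
```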

View File

@ -0,0 +1,132 @@
"""
Unit tests for OpenRouter embedding transformation logic.
"""
from litellm.llms.openrouter.embedding.transformation import (
OpenrouterEmbeddingConfig,
)
def test_openrouter_embedding_supported_params():
"""Test that supported OpenAI params are correctly defined."""
config = OpenrouterEmbeddingConfig()
supported = config.get_supported_openai_params("test-model")
assert "timeout" in supported
assert "dimensions" in supported
assert "encoding_format" in supported
assert "user" in supported
def test_openrouter_embedding_transform_request():
"""Test request transformation logic."""
config = OpenrouterEmbeddingConfig()
# Test with string input
result = config.transform_embedding_request(
model="openrouter/google/text-embedding-004",
input="Hello world",
optional_params={},
headers={},
)
assert result["model"] == "google/text-embedding-004"
assert result["input"] == ["Hello world"]
# Test with list input
result = config.transform_embedding_request(
model="google/text-embedding-004",
input=["Hello", "World"],
optional_params={"dimensions": 512},
headers={},
)
assert result["model"] == "google/text-embedding-004"
assert result["input"] == ["Hello", "World"]
assert result["dimensions"] == 512
def test_openrouter_embedding_validate_environment():
"""Test environment validation and header setup."""
config = OpenrouterEmbeddingConfig()
# Test with API key
headers = config.validate_environment(
headers={"Custom-Header": "value"},
model="test-model",
messages=[],
optional_params={},
litellm_params={},
api_key="test-api-key",
)
# Should include OpenRouter-specific headers
assert "HTTP-Referer" in headers
assert "X-Title" in headers
# Should include Content-Type header
assert "Content-Type" in headers
assert headers["Content-Type"] == "application/json"
# Should include Authorization header
assert "Authorization" in headers
assert headers["Authorization"] == "Bearer test-api-key"
# Should preserve custom headers
assert headers["Custom-Header"] == "value"
# Test without API key
headers_no_key = config.validate_environment(
headers={},
model="test-model",
messages=[],
optional_params={},
litellm_params={},
api_key=None,
)
# Should still include OpenRouter headers but not Authorization
assert "HTTP-Referer" in headers_no_key
assert "X-Title" in headers_no_key
assert "Content-Type" in headers_no_key
assert "Authorization" not in headers_no_key
def test_openrouter_embedding_get_complete_url():
"""Test URL construction."""
config = OpenrouterEmbeddingConfig()
url = config.get_complete_url(
api_base="https://openrouter.ai/api/v1",
api_key="test-key",
model="test-model",
optional_params={},
litellm_params={},
)
assert url == "https://openrouter.ai/api/v1/embeddings"
# Test with trailing slash
url = config.get_complete_url(
api_base="https://openrouter.ai/api/v1/",
api_key="test-key",
model="test-model",
optional_params={},
litellm_params={},
)
assert url == "https://openrouter.ai/api/v1/embeddings"
def test_openrouter_embedding_map_params():
"""Test parameter mapping."""
config = OpenrouterEmbeddingConfig()
result = config.map_openai_params(
non_default_params={"dimensions": 512, "timeout": 30, "unsupported": "value"},
optional_params={},
model="test-model",
drop_params=False,
)
# Supported params should be included
assert result["dimensions"] == 512
assert result["timeout"] == 30
# Unsupported params should not be included
assert "unsupported" not in result
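The transformation tests above imply three small rules: strip a leading `openrouter/` from the model name, wrap a single string input in a list, and join the base URL with `/embeddings` regardless of a trailing slash. A compact sketch of those rules, using hypothetical free functions rather than the real `OpenrouterEmbeddingConfig` methods:

```python
from typing import Any, Dict, List, Union


def transform_embedding_request(
    model: str, input: Union[str, List[str]], optional_params: Dict[str, Any]
) -> Dict[str, Any]:
    """Build an embeddings request body from LiteLLM-style arguments."""
    prefix = "openrouter/"
    if model.startswith(prefix):
        model = model[len(prefix):]  # the provider prefix is routing-only
    if isinstance(input, str):
        input = [input]  # the body always carries a list of inputs
    return {"model": model, "input": input, **optional_params}


def get_complete_url(api_base: str) -> str:
    """Append /embeddings, tolerating a trailing slash on the base URL."""
    return f"{api_base.rstrip('/')}/embeddings"
```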

View File

@ -2,7 +2,6 @@
Unit tests for Qualifire guardrail integration.
"""
import sys
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -75,9 +74,37 @@ class TestQualifireGuardrailInit:
assert guardrail.on_flagged == "monitor"
def test_init_with_default_api_base(self):
"""Test that default API base is set when not provided."""
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
DEFAULT_QUALIFIRE_API_BASE,
QualifireGuardrail,
)
guardrail = QualifireGuardrail(
api_key="test_key",
guardrail_name="test_guardrail",
)
assert guardrail.qualifire_api_base == DEFAULT_QUALIFIRE_API_BASE
def test_init_with_custom_api_base(self):
"""Test initialization with custom API base URL."""
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
QualifireGuardrail,
)
guardrail = QualifireGuardrail(
api_key="test_key",
api_base="https://custom.qualifire.ai",
guardrail_name="test_guardrail",
)
assert guardrail.qualifire_api_base == "https://custom.qualifire.ai"
class TestQualifireGuardrailMessageConversion:
"""Tests for message conversion to Qualifire format."""
"""Tests for message conversion to API format."""
def test_convert_simple_messages(self):
"""Test conversion of simple text messages."""
@ -95,15 +122,13 @@ class TestQualifireGuardrailMessageConversion:
{"role": "assistant", "content": "Hi there!"},
]
# Create mock LLMMessage class
mock_llm_message = MagicMock()
result = guardrail._convert_messages_to_api_format(messages)
with patch(
"litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire.QualifireGuardrail._convert_messages_to_qualifire_format"
) as mock_convert:
mock_convert.return_value = [mock_llm_message, mock_llm_message]
result = guardrail._convert_messages_to_qualifire_format(messages)
assert len(result) == 2
assert len(result) == 2
assert result[0]["role"] == "user"
assert result[0]["content"] == "Hello, world!"
assert result[1]["role"] == "assistant"
assert result[1]["content"] == "Hi there!"
def test_convert_multimodal_messages(self):
"""Test conversion of multimodal messages with text parts."""
@ -126,112 +151,258 @@ class TestQualifireGuardrailMessageConversion:
},
]
with patch(
"litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire.QualifireGuardrail._convert_messages_to_qualifire_format"
) as mock_convert:
mock_convert.return_value = [MagicMock()]
result = guardrail._convert_messages_to_qualifire_format(messages)
assert len(result) == 1
result = guardrail._convert_messages_to_api_format(messages)
assert len(result) == 1
assert result[0]["role"] == "user"
assert result[0]["content"] == "First part\nSecond part"
def test_convert_messages_with_tool_calls(self):
"""Test conversion of messages with tool calls."""
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
QualifireGuardrail,
)
guardrail = QualifireGuardrail(
api_key="test_key",
guardrail_name="test_guardrail",
)
messages = [
{
"role": "assistant",
"content": "",
"tool_calls": [
{
"id": "call_123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": '{"location": "NYC"}',
},
}
],
},
]
result = guardrail._convert_messages_to_api_format(messages)
assert len(result) == 1
assert result[0]["role"] == "assistant"
assert "tool_calls" in result[0]
assert len(result[0]["tool_calls"]) == 1
assert result[0]["tool_calls"][0]["id"] == "call_123"
assert result[0]["tool_calls"][0]["name"] == "get_weather"
assert result[0]["tool_calls"][0]["arguments"] == {"location": "NYC"}
class TestQualifireGuardrailEvaluateKwargs:
"""Tests for evaluate kwargs passed to Qualifire client."""
class TestQualifireGuardrailToolConversion:
"""Tests for tool definition conversion."""
def test_convert_openai_function_tools(self):
"""Test conversion of OpenAI function tool format."""
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
QualifireGuardrail,
)
guardrail = QualifireGuardrail(
api_key="test_key",
guardrail_name="test_guardrail",
)
tools = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather for a location",
"parameters": {"type": "object", "properties": {}},
},
}
]
result = guardrail._convert_tools_to_api_format(tools)
assert result is not None
assert len(result) == 1
assert result[0]["name"] == "get_weather"
assert result[0]["description"] == "Get weather for a location"
def test_convert_empty_tools(self):
"""Test that empty tools returns None."""
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
QualifireGuardrail,
)
guardrail = QualifireGuardrail(
api_key="test_key",
guardrail_name="test_guardrail",
)
result = guardrail._convert_tools_to_api_format(None)
assert result is None
result = guardrail._convert_tools_to_api_format([])
assert result is None
class TestQualifireGuardrailAPICall:
"""Tests for API call with httpx client."""
@pytest.mark.asyncio
async def test_evaluate_called_with_prompt_injections(self):
"""Test that evaluate is called with prompt_injections enabled."""
# Mock the qualifire module and its types
mock_qualifire_types = MagicMock()
mock_llm_message = MagicMock()
mock_llm_tool_call = MagicMock()
mock_message_instance = MagicMock()
mock_llm_message.return_value = mock_message_instance
mock_qualifire_types.LLMMessage = mock_llm_message
mock_qualifire_types.LLMToolCall = mock_llm_tool_call
with patch.dict('sys.modules', {'qualifire': MagicMock(), 'qualifire.types': mock_qualifire_types}):
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
QualifireGuardrail,
)
"""Test that evaluate endpoint is called with prompt_injections enabled."""
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
QualifireGuardrail,
)
guardrail = QualifireGuardrail(
api_key="test_key",
prompt_injections=True,
guardrail_name="test_guardrail",
)
guardrail = QualifireGuardrail(
api_key="test_key",
prompt_injections=True,
guardrail_name="test_guardrail",
)
# Mock the client
mock_client = MagicMock()
mock_result = MagicMock()
mock_result.score = 100
mock_result.status = "completed"
mock_result.evaluationResults = []
mock_client.evaluate.return_value = mock_result
guardrail._client = mock_client
# Mock the async HTTP handler
mock_response = MagicMock()
mock_response.json.return_value = {
"score": 100,
"status": "completed",
"evaluationResults": [],
}
mock_response.raise_for_status = MagicMock()
guardrail.async_handler.post = AsyncMock(return_value=mock_response)
messages = [{"role": "user", "content": "Hello, world!"}]
messages = [{"role": "user", "content": "Hello, world!"}]
await guardrail._run_qualifire_check(
messages=messages, output=None, dynamic_params={}
)
await guardrail._run_qualifire_check(
messages=messages, output=None, dynamic_params={}
)
# Verify evaluate was called with correct kwargs
mock_client.evaluate.assert_called_once()
call_kwargs = mock_client.evaluate.call_args[1]
assert call_kwargs["prompt_injections"] is True
assert "messages" in call_kwargs
# Verify the API was called
guardrail.async_handler.post.assert_called_once()
call_kwargs = guardrail.async_handler.post.call_args[1]
assert "json" in call_kwargs
payload = call_kwargs["json"]
assert payload["prompt_injections"] is True
assert "messages" in payload
assert call_kwargs["url"].endswith("/api/evaluation/evaluate")
@pytest.mark.asyncio
async def test_evaluate_called_with_multiple_checks(self):
"""Test that evaluate is called with multiple checks enabled."""
# Mock the qualifire module and its types
mock_qualifire_types = MagicMock()
mock_llm_message = MagicMock()
mock_llm_tool_call = MagicMock()
mock_message_instance = MagicMock()
mock_llm_message.return_value = mock_message_instance
mock_qualifire_types.LLMMessage = mock_llm_message
mock_qualifire_types.LLMToolCall = mock_llm_tool_call
with patch.dict('sys.modules', {'qualifire': MagicMock(), 'qualifire.types': mock_qualifire_types}):
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
QualifireGuardrail,
)
guardrail = QualifireGuardrail(
api_key="test_key",
prompt_injections=True,
pii_check=True,
hallucinations_check=True,
assertions=["Output must be valid JSON"],
guardrail_name="test_guardrail",
)
# Mock the client
mock_client = MagicMock()
mock_result = MagicMock()
mock_result.score = 100
mock_result.status = "completed"
mock_result.evaluationResults = []
mock_client.evaluate.return_value = mock_result
guardrail._client = mock_client
# Mock the async HTTP handler
mock_response = MagicMock()
mock_response.json.return_value = {
"score": 100,
"status": "completed",
"evaluationResults": [],
}
mock_response.raise_for_status = MagicMock()
guardrail.async_handler.post = AsyncMock(return_value=mock_response)
messages = [{"role": "user", "content": "Hello, world!"}]
await guardrail._run_qualifire_check(
messages=messages, output="Test output", dynamic_params={}
)
# Verify evaluate was called with correct kwargs
mock_client.evaluate.assert_called_once()
call_kwargs = mock_client.evaluate.call_args[1]
assert call_kwargs["prompt_injections"] is True
assert call_kwargs["pii_check"] is True
assert call_kwargs["hallucinations_check"] is True
assert call_kwargs["assertions"] == ["Output must be valid JSON"]
assert call_kwargs["output"] == "Test output"
# Verify the API was called with correct payload
guardrail.async_handler.post.assert_called_once()
call_kwargs = guardrail.async_handler.post.call_args[1]
payload = call_kwargs["json"]
assert payload["prompt_injections"] is True
assert payload["pii_check"] is True
assert payload["hallucinations_check"] is True
assert payload["assertions"] == ["Output must be valid JSON"]
assert payload["output"] == "Test output"
@pytest.mark.asyncio
async def test_invoke_endpoint_used_with_evaluation_id(self):
"""Test that invoke endpoint is used when evaluation_id is provided."""
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
QualifireGuardrail,
)
guardrail = QualifireGuardrail(
api_key="test_key",
evaluation_id="eval_123",
guardrail_name="test_guardrail",
)
# Mock the async HTTP handler
mock_response = MagicMock()
mock_response.json.return_value = {
"score": 100,
"status": "completed",
"evaluationResults": [],
}
mock_response.raise_for_status = MagicMock()
guardrail.async_handler.post = AsyncMock(return_value=mock_response)
messages = [{"role": "user", "content": "Hello, world!"}]
await guardrail._run_qualifire_check(
messages=messages, output="Test output", dynamic_params={}
)
# Verify the invoke endpoint was called
guardrail.async_handler.post.assert_called_once()
call_kwargs = guardrail.async_handler.post.call_args[1]
assert call_kwargs["url"].endswith("/api/evaluation/invoke")
payload = call_kwargs["json"]
assert payload["evaluation_id"] == "eval_123"
assert payload["input"] == "Hello, world!"
assert payload["output"] == "Test output"
@pytest.mark.asyncio
async def test_correct_headers_sent(self):
"""Test that correct headers are sent with the API request."""
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
QualifireGuardrail,
)
guardrail = QualifireGuardrail(
api_key="my_api_key",
guardrail_name="test_guardrail",
)
# Mock the async HTTP handler
mock_response = MagicMock()
mock_response.json.return_value = {
"score": 100,
"status": "completed",
"evaluationResults": [],
}
mock_response.raise_for_status = MagicMock()
guardrail.async_handler.post = AsyncMock(return_value=mock_response)
messages = [{"role": "user", "content": "Hello!"}]
await guardrail._run_qualifire_check(
messages=messages, output=None, dynamic_params={}
)
call_kwargs = guardrail.async_handler.post.call_args[1]
headers = call_kwargs["headers"]
assert headers["X-Qualifire-API-Key"] == "my_api_key"
assert headers["Content-Type"] == "application/json"
class TestQualifireGuardrailCheckIfFlagged:
@ -248,12 +419,14 @@ class TestQualifireGuardrailCheckIfFlagged:
guardrail_name="test_guardrail",
)
# Mock result with completed status and no flagged items
mock_result = MagicMock()
mock_result.status = "completed"
mock_result.evaluationResults = []
# Result with completed status and no flagged items (dict format)
result = {
"status": "completed",
"score": 100,
"evaluationResults": [],
}
assert guardrail._check_if_flagged(mock_result) is False
assert guardrail._check_if_flagged(result) is False
def test_check_if_flagged_returns_true_for_flagged_content(self):
"""Test that _check_if_flagged returns True when content is flagged."""
@ -266,18 +439,25 @@ class TestQualifireGuardrailCheckIfFlagged:
guardrail_name="test_guardrail",
)
# Mock result with flagged item
mock_inner_result = MagicMock()
mock_inner_result.flagged = True
# Result with flagged item (dict format matching API response)
result = {
"status": "completed",
"score": 15,
"evaluationResults": [
{
"type": "prompt_injection",
"results": [
{
"flagged": True,
"score": 0.15,
"reason": "Prompt injection detected",
}
],
}
],
}
mock_eval_result = MagicMock()
mock_eval_result.results = [mock_inner_result]
mock_result = MagicMock()
mock_result.status = "completed"
mock_result.evaluationResults = [mock_eval_result]
assert guardrail._check_if_flagged(mock_result) is True
assert guardrail._check_if_flagged(result) is True
def test_check_if_flagged_returns_false_when_no_flagged_items(self):
"""Test that _check_if_flagged returns False when no items are flagged."""
@ -291,17 +471,24 @@ class TestQualifireGuardrailCheckIfFlagged:
)
# Result with evaluation results but nothing flagged
mock_inner_result = MagicMock()
mock_inner_result.flagged = False
result = {
"status": "completed",
"score": 95,
"evaluationResults": [
{
"type": "prompt_injection",
"results": [
{
"flagged": False,
"score": 0.95,
"reason": "No issues detected",
}
],
}
],
}
mock_eval_result = MagicMock()
mock_eval_result.results = [mock_inner_result]
mock_result = MagicMock()
mock_result.status = "success"
mock_result.evaluationResults = [mock_eval_result]
assert guardrail._check_if_flagged(mock_result) is False
assert guardrail._check_if_flagged(result) is False
class TestQualifireGuardrailShouldRun:
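
The updated `_check_if_flagged` tests above pass plain dicts shaped like the API response instead of `MagicMock` objects. As a rough illustration of the behaviour those tests expect, here is a minimal sketch of a dict-based check. The function name `check_if_flagged` is hypothetical and this is not the guardrail's actual implementation; it only assumes the `evaluationResults` / `results` / `flagged` layout shown in the test fixtures.

```python
from typing import Any, Dict


def check_if_flagged(result: Dict[str, Any]) -> bool:
    """Hypothetical stand-in for the guardrail's `_check_if_flagged`."""
    # Each entry in `evaluationResults` groups the results of one check type.
    for evaluation in result.get("evaluationResults") or []:
        for item in evaluation.get("results") or []:
            # Only an explicit True counts as flagged in this sketch.
            if item.get("flagged") is True:
                return True
    return False


# Fixtures copied from the shape used in the tests above.
clean = {"status": "completed", "score": 100, "evaluationResults": []}
flagged = {
    "status": "completed",
    "score": 15,
    "evaluationResults": [
        {
            "type": "prompt_injection",
            "results": [
                {
                    "flagged": True,
                    "score": 0.15,
                    "reason": "Prompt injection detected",
                }
            ],
        }
    ],
}

clean_is_flagged = check_if_flagged(clean)
flagged_is_flagged = check_if_flagged(flagged)
```

Using `.get(...) or []` keeps the sketch tolerant of missing or `None` fields, which is a design choice made here rather than something the tests require.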

View File

@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
import pytest
from fastapi import Request, status
from fastapi.responses import StreamingResponse
from fastapi.responses import JSONResponse, StreamingResponse
import litellm
from litellm._uuid import uuid
@ -11,9 +11,10 @@ from litellm.integrations.opentelemetry import UserAPIKeyAuth
from litellm.proxy.common_request_processing import (
ProxyBaseLLMRequestProcessing,
ProxyConfig,
_extract_error_from_sse_chunk,
_get_cost_breakdown_from_logging_obj,
_parse_event_data_for_error,
create_streaming_response,
create_response,
)
from litellm.proxy.utils import ProxyLogging
@ -680,21 +681,27 @@ class TestCommonRequestProcessingHelpers:
assert await _parse_event_data_for_error(event_line) == expected_code
async def test_create_streaming_response_first_chunk_is_error(self):
"""
Test that when the first chunk is an error, a JSON error response is returned
instead of an SSE streaming response
"""
async def mock_generator():
yield 'data: {"error": {"code": 403, "message": "forbidden"}}\n\n'
yield 'data: {"content": "more data"}\n\n'
yield "data: [DONE]\n\n"
response = await create_streaming_response(
response = await create_response(
mock_generator(), "text/event-stream", {}
)
# Should return JSONResponse instead of StreamingResponse
assert isinstance(response, JSONResponse)
assert response.status_code == status.HTTP_403_FORBIDDEN
content = await self.consume_stream(response)
assert content == [
'data: {"error": {"code": 403, "message": "forbidden"}}\n\n',
'data: {"content": "more data"}\n\n',
"data: [DONE]\n\n",
]
# Verify the response is in standard JSON error format
import json
body = json.loads(response.body.decode())
assert "error" in body
assert body["error"]["code"] == 403
assert body["error"]["message"] == "forbidden"
async def test_create_streaming_response_first_chunk_not_error(self):
async def mock_generator():
@ -702,7 +709,7 @@ class TestCommonRequestProcessingHelpers:
yield 'data: {"content": "second part"}\n\n'
yield "data: [DONE]\n\n"
response = await create_streaming_response(
response = await create_response(
mock_generator(), "text/event-stream", {}
)
assert response.status_code == status.HTTP_200_OK
@ -719,7 +726,7 @@ class TestCommonRequestProcessingHelpers:
yield
# Implicitly raises StopAsyncIteration
response = await create_streaming_response(
response = await create_response(
mock_generator(), "text/event-stream", {}
)
assert response.status_code == status.HTTP_200_OK
@ -732,7 +739,7 @@ class TestCommonRequestProcessingHelpers:
mock_gen = AsyncMock()
mock_gen.__anext__.side_effect = StopAsyncIteration
response = await create_streaming_response(mock_gen, "text/event-stream", {})
response = await create_response(mock_gen, "text/event-stream", {})
assert response.status_code == status.HTTP_200_OK
content = await self.consume_stream(response)
assert content == []
@ -743,7 +750,7 @@ class TestCommonRequestProcessingHelpers:
mock_gen = AsyncMock()
mock_gen.__anext__.side_effect = ValueError("Test error from generator")
response = await create_streaming_response(mock_gen, "text/event-stream", {})
response = await create_response(mock_gen, "text/event-stream", {})
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
content = await self.consume_stream(response)
expected_error_data = {
@ -760,19 +767,24 @@ class TestCommonRequestProcessingHelpers:
assert content[1] == "data: [DONE]\n\n"
async def test_create_streaming_response_first_chunk_error_string_code(self):
"""
Test that when the first chunk contains a string error code, a JSON error response is returned
"""
async def mock_generator():
yield 'data: {"error": {"code": "429", "message": "too many requests"}}\n\n'
yield "data: [DONE]\n\n"
response = await create_streaming_response(
response = await create_response(
mock_generator(), "text/event-stream", {}
)
assert isinstance(response, JSONResponse)
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
content = await self.consume_stream(response)
assert content == [
'data: {"error": {"code": "429", "message": "too many requests"}}\n\n',
"data: [DONE]\n\n",
]
# Verify the response is in standard JSON error format
import json
body = json.loads(response.body.decode())
assert "error" in body
assert body["error"]["code"] == "429"
assert body["error"]["message"] == "too many requests"
async def test_create_streaming_response_custom_headers(self):
async def mock_generator():
@ -780,7 +792,7 @@ class TestCommonRequestProcessingHelpers:
yield "data: [DONE]\n\n"
custom_headers = {"X-Custom-Header": "TestValue"}
response = await create_streaming_response(
response = await create_response(
mock_generator(), "text/event-stream", custom_headers
)
assert response.headers["x-custom-header"] == "TestValue"
@ -790,7 +802,7 @@ class TestCommonRequestProcessingHelpers:
yield 'data: {"content": "data"}\n\n'
yield "data: [DONE]\n\n"
response = await create_streaming_response(
response = await create_response(
mock_generator(),
"text/event-stream",
{},
@ -807,7 +819,7 @@ class TestCommonRequestProcessingHelpers:
async def mock_generator():
yield "data: [DONE]\n\n"
response = await create_streaming_response(
response = await create_response(
mock_generator(), "text/event-stream", {}
)
assert response.status_code == status.HTTP_200_OK # Default status
@ -820,7 +832,7 @@ class TestCommonRequestProcessingHelpers:
yield 'data: {"content": "actual data"}\n\n'
yield "data: [DONE]\n\n"
response = await create_streaming_response(
response = await create_response(
mock_generator(), "text/event-stream", {}
)
assert response.status_code == status.HTTP_200_OK # Default status
@ -851,7 +863,7 @@ class TestCommonRequestProcessingHelpers:
# Patch the tracer in the common_request_processing module
with patch("litellm.proxy.common_request_processing.tracer", mock_tracer):
response = await create_streaming_response(
response = await create_response(
mock_generator(), "text/event-stream", {}
)
@ -888,7 +900,10 @@ class TestCommonRequestProcessingHelpers:
), f"Call {i} should have operation name 'streaming.chunk.yield', got {args[0]}"
async def test_create_streaming_response_dd_trace_with_error_chunk(self):
"""Test that dd trace is applied even when the first chunk contains an error"""
"""
Test that when the first chunk contains an error, JSONResponse is returned
and tracing is not triggered (since it's not a streaming response)
"""
from unittest.mock import patch
# Create a mock tracer
@ -905,28 +920,107 @@ class TestCommonRequestProcessingHelpers:
# Patch the tracer in the common_request_processing module
with patch("litellm.proxy.common_request_processing.tracer", mock_tracer):
response = await create_streaming_response(
response = await create_response(
mock_generator(), "text/event-stream", {}
)
# Even with error, status should be set to error code but tracing should still work
# Should return JSONResponse instead of StreamingResponse
assert isinstance(response, JSONResponse)
assert response.status_code == 400
# Consume the stream to trigger the tracer calls
content = await self.consume_stream(response)
# Verify the response is in standard JSON error format
import json
body = json.loads(response.body.decode())
assert "error" in body
assert body["error"]["code"] == 400
assert body["error"]["message"] == "bad request"
# Verify all chunks are present
assert len(content) == 3
# Since JSONResponse is returned instead of StreamingResponse, streaming tracing should not be triggered
# tracer.trace should not be called
assert mock_tracer.trace.call_count == 0
# Verify that tracer.trace was called for each chunk
assert mock_tracer.trace.call_count == 3
# Verify that each call was made with the correct operation name
actual_calls = mock_tracer.trace.call_args_list
assert len(actual_calls) == 3
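
The `create_response` tests above describe one pattern: read the first chunk of the stream, and if it carries an error, answer with a plain JSON error instead of an SSE stream. The sketch below illustrates that pattern with the standard library only. The names `error_status` and `peek_first_chunk` are hypothetical, and the function returns a tuple rather than FastAPI's `JSONResponse` or `StreamingResponse`, so it should be read as an approximation of the idea and not as the code in `common_request_processing`. It also does not cover the case where the generator itself raises.

```python
import asyncio
import json
from typing import AsyncIterator, Optional, Tuple


def error_status(chunk: str) -> Optional[int]:
    """Return the error code carried by an SSE chunk, or None if it is not an error."""
    if not chunk.startswith("data:"):
        return None
    try:
        data = json.loads(chunk[len("data:"):])
    except json.JSONDecodeError:
        return None  # e.g. "data: [DONE]"
    error = data.get("error") if isinstance(data, dict) else None
    if not isinstance(error, dict):
        return None
    try:
        # Codes may arrive as 403 or "429"; a missing code defaults to 500 here.
        return int(error.get("code", 500))
    except (TypeError, ValueError):
        return 500


async def _empty() -> AsyncIterator[str]:
    if False:
        yield ""


async def peek_first_chunk(gen: AsyncIterator[str]) -> Tuple[str, int, object]:
    """Classify a stream as ("json", status, body) or ("stream", 200, iterator)."""
    try:
        first = await gen.__anext__()
    except StopAsyncIteration:
        return "stream", 200, _empty()

    status = error_status(first)
    if status is not None:
        # Error in the first chunk: hand back the parsed JSON body.
        return "json", status, json.loads(first[len("data:"):])

    async def replay() -> AsyncIterator[str]:
        # Re-emit the chunk that was consumed while peeking, then the rest.
        yield first
        async for chunk in gen:
            yield chunk

    return "stream", 200, replay()


async def error_stream():
    yield 'data: {"error": {"code": "429", "message": "too many requests"}}\n\n'
    yield "data: [DONE]\n\n"


async def ok_stream():
    yield 'data: {"content": "first part"}\n\n'
    yield "data: [DONE]\n\n"


async def collect(gen):
    kind, status, chunks = await peek_first_chunk(gen)
    return kind, status, [c async for c in chunks]


kind, status, body = asyncio.run(peek_first_chunk(error_stream()))
ok_kind, ok_status, ok_chunks = asyncio.run(collect(ok_stream()))
```

Replaying the peeked chunk matters: without it, a healthy stream would silently lose its first event.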
class TestExtractErrorFromSSEChunk:
"""Tests for _extract_error_from_sse_chunk function"""
def test_extract_error_from_sse_chunk_with_valid_error(self):
"""Test extracting error information from a standard SSE chunk"""
chunk = 'data: {"error": {"code": 403, "message": "forbidden", "type": "auth_error", "param": "api_key"}}\n\n'
error = _extract_error_from_sse_chunk(chunk)
assert error["code"] == 403
assert error["message"] == "forbidden"
assert error["type"] == "auth_error"
assert error["param"] == "api_key"
def test_extract_error_from_sse_chunk_with_string_code(self):
"""Test that a string error code is returned unchanged"""
chunk = 'data: {"error": {"code": "429", "message": "too many requests"}}\n\n'
error = _extract_error_from_sse_chunk(chunk)
assert error["code"] == "429"
assert error["message"] == "too many requests"
def test_extract_error_from_sse_chunk_with_bytes(self):
"""Test that a bytes chunk is decoded and parsed"""
chunk = b'data: {"error": {"code": 500, "message": "internal error"}}\n\n'
error = _extract_error_from_sse_chunk(chunk)
assert error["code"] == 500
assert error["message"] == "internal error"
def test_extract_error_from_sse_chunk_with_done(self):
"""Test [DONE] marker should return default error"""
chunk = "data: [DONE]\n\n"
error = _extract_error_from_sse_chunk(chunk)
assert error["message"] == "Unknown error"
assert error["type"] == "internal_server_error"
assert error["code"] == "500"
assert error["param"] is None
def test_extract_error_from_sse_chunk_without_error_field(self):
"""Test missing error field should return default error"""
chunk = 'data: {"content": "some content"}\n\n'
error = _extract_error_from_sse_chunk(chunk)
assert error["message"] == "Unknown error"
assert error["type"] == "internal_server_error"
assert error["code"] == "500"
def test_extract_error_from_sse_chunk_with_invalid_json(self):
"""Test invalid JSON should return default error"""
chunk = 'data: {invalid json}\n\n'
error = _extract_error_from_sse_chunk(chunk)
assert error["message"] == "Unknown error"
assert error["type"] == "internal_server_error"
assert error["code"] == "500"
def test_extract_error_from_sse_chunk_without_data_prefix(self):
"""Test missing 'data:' prefix should return default error"""
chunk = '{"error": {"code": 400, "message": "bad request"}}\n\n'
error = _extract_error_from_sse_chunk(chunk)
assert error["message"] == "Unknown error"
assert error["type"] == "internal_server_error"
assert error["code"] == "500"
def test_extract_error_from_sse_chunk_with_empty_string(self):
"""Test empty string should return default error"""
chunk = ""
error = _extract_error_from_sse_chunk(chunk)
assert error["message"] == "Unknown error"
assert error["type"] == "internal_server_error"
assert error["code"] == "500"
def test_extract_error_from_sse_chunk_with_minimal_error(self):
"""Test minimal error object"""
chunk = 'data: {"error": {"message": "error occurred"}}\n\n'
error = _extract_error_from_sse_chunk(chunk)
assert error["message"] == "error occurred"
# Other fields should come from the original error object (if present)
for i, call in enumerate(actual_calls):
args, kwargs = call
assert (
args[0] == "streaming.chunk.yield"
), f"Call {i} should have operation name 'streaming.chunk.yield', got {args[0]}"
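
The `TestExtractErrorFromSSEChunk` cases above pin down a small contract: accept `str` or `bytes`, require a `data:` prefix, and fall back to a default error for `[DONE]`, invalid JSON, or a missing `error` field. A minimal sketch consistent with those cases follows. The name `extract_error_from_sse_chunk` (no leading underscore) is hypothetical and this is not the library's implementation. In particular, filling missing fields of a partial error from the defaults is an assumption made here; the tests only check `message` for the minimal case.

```python
import json
from typing import Any, Dict, Union

# Default values taken from the assertions in the tests above.
DEFAULT_ERROR: Dict[str, Any] = {
    "message": "Unknown error",
    "type": "internal_server_error",
    "param": None,
    "code": "500",
}


def extract_error_from_sse_chunk(chunk: Union[str, bytes]) -> Dict[str, Any]:
    """Hypothetical sketch: pull the `error` object out of one SSE chunk."""
    default = dict(DEFAULT_ERROR)
    if isinstance(chunk, bytes):
        chunk = chunk.decode("utf-8", errors="replace")

    text = chunk.strip()
    if not text.startswith("data:"):
        return default

    payload = text[len("data:"):].strip()
    if not payload or payload == "[DONE]":
        return default

    try:
        data = json.loads(payload)
    except json.JSONDecodeError:
        return default

    error = data.get("error") if isinstance(data, dict) else None
    if not isinstance(error, dict):
        return default

    # Assumption: fields absent from the error object fall back to the defaults.
    return {**default, **error}


parsed = extract_error_from_sse_chunk(
    'data: {"error": {"code": 403, "message": "forbidden"}}\n\n'
)
from_bytes = extract_error_from_sse_chunk(
    b'data: {"error": {"code": 500, "message": "internal error"}}\n\n'
)
done = extract_error_from_sse_chunk("data: [DONE]\n\n")
```

Note that the code is passed through as-is, so `403` stays an integer and `"429"` stays a string, matching the assertions in the tests.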