Merge pull request #18763 from BerriAI/litellm_staging_01_07_2026
Staging - 01/07/2026
commit
844c766c65
@ -262,6 +262,7 @@ Support for more providers. Missing a provider or LLM Platform, raise a [feature
| Provider                                                                            | `/chat/completions` | `/messages` | `/responses` | `/embeddings` | `/image/generations` | `/audio/transcriptions` | `/audio/speech` | `/moderations` | `/batches` | `/rerank` |
|-------------------------------------------------------------------------------------|---------------------|-------------|--------------|---------------|----------------------|-------------------------|-----------------|----------------|-----------|-----------|
| [Abliteration (`abliteration`)](https://docs.litellm.ai/docs/providers/abliteration) | ✅ | | | | | | | | | |
| [AI/ML API (`aiml`)](https://docs.litellm.ai/docs/providers/aiml) | ✅ | ✅ | ✅ | ✅ | ✅ | | | | | |
| [AI21 (`ai21`)](https://docs.litellm.ai/docs/providers/ai21) | ✅ | ✅ | ✅ | | | | | | | |
| [AI21 Chat (`ai21_chat`)](https://docs.litellm.ai/docs/providers/ai21) | ✅ | ✅ | ✅ | | | | | | | |
@ -455,4 +456,3 @@ All these checks must pass before your PR can be merged.
<img src="https://contrib.rocks/image?repo=BerriAI/litellm" />
</a>
109
docs/my-website/docs/providers/abliteration.md
Normal file
@ -0,0 +1,109 @@
# Abliteration

## Overview

| Property | Details |
|-------|-------|
| Description | Abliteration provides an OpenAI-compatible `/chat/completions` endpoint. |
| Provider Route on LiteLLM | `abliteration/` |
| Link to Provider Doc | [Abliteration](https://abliteration.ai) |
| Base URL | `https://api.abliteration.ai/v1` |
| Supported Operations | [`/chat/completions`](#sample-usage) |

<br />

## Required Variables

```python showLineNumbers title="Environment Variables"
import os

os.environ["ABLITERATION_API_KEY"] = ""  # your Abliteration API key
```

## Sample Usage

```python showLineNumbers title="Abliteration Completion"
import os
from litellm import completion

os.environ["ABLITERATION_API_KEY"] = ""

response = completion(
    model="abliteration/abliterated-model",
    messages=[{"role": "user", "content": "Hello from LiteLLM"}],
)

print(response)
```

## Sample Usage - Streaming

```python showLineNumbers title="Abliteration Streaming Completion"
import os
from litellm import completion

os.environ["ABLITERATION_API_KEY"] = ""

response = completion(
    model="abliteration/abliterated-model",
    messages=[{"role": "user", "content": "Stream a short reply"}],
    stream=True,
)

for chunk in response:
    print(chunk)
```

## Usage with LiteLLM Proxy Server

1. Add the model to your proxy config:

```yaml showLineNumbers title="config.yaml"
model_list:
  - model_name: abliteration-chat
    litellm_params:
      model: abliteration/abliterated-model
      api_key: os.environ/ABLITERATION_API_KEY
```

2. Start the proxy:

```bash
litellm --config /path/to/config.yaml
```

## Direct API Usage (Bearer Token)

Pass the value of `ABLITERATION_API_KEY` as a Bearer token to the OpenAI-compatible endpoint, `https://api.abliteration.ai/v1/chat/completions`.

```bash showLineNumbers title="cURL"
export ABLITERATION_API_KEY=""
curl https://api.abliteration.ai/v1/chat/completions \
  -H "Authorization: Bearer ${ABLITERATION_API_KEY}" \
  -H "Content-Type: application/json" \
  -d '{
    "model": "abliterated-model",
    "messages": [{"role": "user", "content": "Hello from Abliteration"}]
  }'
```

```python showLineNumbers title="Python (requests)"
import os
import requests

api_key = os.environ["ABLITERATION_API_KEY"]

response = requests.post(
    "https://api.abliteration.ai/v1/chat/completions",
    headers={
        "Authorization": f"Bearer {api_key}",
        "Content-Type": "application/json",
    },
    json={
        "model": "abliterated-model",
        "messages": [{"role": "user", "content": "Hello from Abliteration"}],
    },
    timeout=60,
)

print(response.json())
```
@ -8,13 +8,7 @@ Use [Qualifire](https://qualifire.ai) to evaluate LLM outputs for quality, safet
|
||||
|
||||
## Quick Start
|
||||
|
||||
### 1. Install the Qualifire SDK
|
||||
|
||||
```bash
|
||||
pip install qualifire
|
||||
```
|
||||
|
||||
### 2. Define Guardrails on your LiteLLM config.yaml
|
||||
### 1. Define Guardrails on your LiteLLM config.yaml
|
||||
|
||||
Define your guardrails under the `guardrails` section:
|
||||
|
||||
@ -61,13 +55,13 @@ guardrails:
|
||||
- `post_call` Run **after** LLM call, on **input & output**
|
||||
- `during_call` Run **during** LLM call, on **input**. Same as `pre_call` but runs in parallel with the LLM call. The response is not returned until the guardrail check completes
|
||||
|
||||
### 3. Start LiteLLM Gateway
|
||||
### 2. Start LiteLLM Gateway
|
||||
|
||||
```shell
|
||||
litellm --config config.yaml --detailed_debug
|
||||
```
|
||||
|
||||
### 4. Test request
|
||||
### 3. Test request
|
||||
|
||||
**[Langchain, OpenAI SDK Usage Examples](../proxy/user_keys#request-format)**
|
||||
|
||||
@ -142,7 +136,7 @@ guardrails:
|
||||
evaluation_id: eval_abc123 # Your evaluation ID from Qualifire dashboard
|
||||
```
|
||||
|
||||
When `evaluation_id` is provided, LiteLLM will use `invoke_evaluation()` instead of `evaluate()`, running the pre-configured evaluation from your dashboard.
|
||||
When `evaluation_id` is provided, LiteLLM will use the invoke evaluation API endpoint instead of the evaluate endpoint, running the pre-configured evaluation from your dashboard.
|
||||
|
||||
## Available Checks
|
||||
|
||||
@ -213,19 +207,19 @@ guardrails:
|
||||
|
||||
### Parameter Reference
|
||||
|
||||
| Parameter | Type | Default | Description |
|
||||
| ------------------------------ | ----------- | --------------------------- | -------------------------------------------------------- |
|
||||
| `api_key` | `str` | `QUALIFIRE_API_KEY` env var | Your Qualifire API key |
|
||||
| `api_base` | `str` | `None` | Custom API base URL (optional) |
|
||||
| `evaluation_id` | `str` | `None` | Pre-configured evaluation ID from Qualifire dashboard |
|
||||
| `prompt_injections` | `bool` | `true` (if no other checks) | Enable prompt injection detection |
|
||||
| `hallucinations_check` | `bool` | `None` | Enable hallucination detection |
|
||||
| `grounding_check` | `bool` | `None` | Enable grounding verification |
|
||||
| `pii_check` | `bool` | `None` | Enable PII detection |
|
||||
| `content_moderation_check` | `bool` | `None` | Enable content moderation |
|
||||
| `tool_selection_quality_check` | `bool` | `None` | Enable tool selection quality check |
|
||||
| `assertions` | `List[str]` | `None` | Custom assertions to validate |
|
||||
| `on_flagged` | `str` | `"block"` | Action when content is flagged: `"block"` or `"monitor"` |
|
||||
| Parameter | Type | Default | Description |
|
||||
| ------------------------------ | ----------- | ---------------------------- | -------------------------------------------------------- |
|
||||
| `api_key` | `str` | `QUALIFIRE_API_KEY` env var | Your Qualifire API key |
|
||||
| `api_base` | `str` | `https://proxy.qualifire.ai` | Custom API base URL (optional) |
|
||||
| `evaluation_id` | `str` | `None` | Pre-configured evaluation ID from Qualifire dashboard |
|
||||
| `prompt_injections` | `bool` | `true` (if no other checks) | Enable prompt injection detection |
|
||||
| `hallucinations_check` | `bool` | `None` | Enable hallucination detection |
|
||||
| `grounding_check` | `bool` | `None` | Enable grounding verification |
|
||||
| `pii_check` | `bool` | `None` | Enable PII detection |
|
||||
| `content_moderation_check` | `bool` | `None` | Enable content moderation |
|
||||
| `tool_selection_quality_check` | `bool` | `None` | Enable tool selection quality check |
|
||||
| `assertions` | `List[str]` | `None` | Custom assertions to validate |
|
||||
| `on_flagged` | `str` | `"block"` | Action when content is flagged: `"block"` or `"monitor"` |
|
||||
|
||||
### Default Behavior
|
||||
|
||||
@ -261,4 +255,3 @@ This evaluates whether the LLM selected the appropriate tools and provided corre
|
||||
|
||||
- [Qualifire Documentation](https://docs.qualifire.ai)
|
||||
- [Qualifire Dashboard](https://app.qualifire.ai)
|
||||
- [Qualifire Python SDK](https://github.com/qualifire-dev/qualifire-python-sdk)
|
||||
|
||||
@ -55,6 +55,7 @@ const sidebars = {
|
||||
"proxy/guardrails/test_playground",
|
||||
"proxy/guardrails/litellm_content_filter",
|
||||
...[
|
||||
"proxy/guardrails/qualifire",
|
||||
"proxy/guardrails/aim_security",
|
||||
"proxy/guardrails/onyx_security",
|
||||
"proxy/guardrails/aporia_api",
|
||||
@ -653,12 +654,13 @@ const sidebars = {
|
||||
"providers/bedrock_writer",
|
||||
"providers/bedrock_batches",
|
||||
"providers/aws_polly",
|
||||
"providers/bedrock_vector_store",
|
||||
]
|
||||
},
|
||||
"providers/litellm_proxy",
|
||||
"providers/ai21",
|
||||
"providers/aiml",
|
||||
"providers/bedrock_vector_store",
|
||||
]
|
||||
},
|
||||
"providers/litellm_proxy",
|
||||
"providers/abliteration",
|
||||
"providers/ai21",
|
||||
"providers/aiml",
|
||||
"providers/aleph_alpha",
|
||||
"providers/amazon_nova",
|
||||
"providers/anyscale",
|
||||
|
||||
@ -225,10 +225,13 @@ class BraintrustLogger(CustomLogger):
|
||||
"id": litellm_call_id,
|
||||
"input": prompt["messages"],
|
||||
"metadata": standard_logging_object,
|
||||
"tags": tags,
|
||||
"span_attributes": {"name": span_name, "type": "llm"},
|
||||
}
|
||||
|
||||
|
||||
# Braintrust cannot specify 'tags' for non-root spans
|
||||
if dynamic_metadata.get("root_span_id") is None:
|
||||
request_data["tags"] = tags
|
||||
|
||||
# Only add those that are not None (or falsy)
|
||||
for key, value in span_attributes.items():
|
||||
if value:
|
||||
@ -351,14 +354,37 @@ class BraintrustLogger(CustomLogger):
|
||||
# Allow metadata override for span name
|
||||
span_name = dynamic_metadata.get("span_name", "Chat Completion")
|
||||
|
||||
# Span parents is a special case
|
||||
span_parents = dynamic_metadata.get("span_parents")
|
||||
|
||||
# Convert comma-separated string to list if present
|
||||
if span_parents:
|
||||
span_parents = [s.strip() for s in span_parents.split(",") if s.strip()]
|
||||
|
||||
# Add optional span attributes only if present
|
||||
span_attributes = {
|
||||
"span_id": dynamic_metadata.get("span_id"),
|
||||
"root_span_id": dynamic_metadata.get("root_span_id"),
|
||||
"span_parents": span_parents,
|
||||
}
|
||||
|
||||
request_data = {
|
||||
"id": litellm_call_id,
|
||||
"input": prompt["messages"],
|
||||
"output": output,
|
||||
"metadata": standard_logging_object,
|
||||
"tags": tags,
|
||||
"span_attributes": {"name": span_name, "type": "llm"},
|
||||
}
|
||||
|
||||
# Braintrust cannot specify 'tags' for non-root spans
|
||||
if dynamic_metadata.get("root_span_id") is None:
|
||||
request_data["tags"] = tags
|
||||
|
||||
# Only add those that are not None (or falsy)
|
||||
for key, value in span_attributes.items():
|
||||
if value:
|
||||
request_data[key] = value
|
||||
|
||||
if choices is not None:
|
||||
request_data["output"] = [choice.dict() for choice in choices]
|
||||
else:
|
||||
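
The span handling in this hunk can be sketched in isolation. The snippet below is an illustration under stated assumptions, not the logger's actual code: `parse_span_parents` and `build_span_fields` are hypothetical helper names, and only the three rules visible in the diff are modelled (comma-separated `span_parents` become a list, `tags` are attached only when no `root_span_id` is given, and empty span attributes are skipped).

```python
from typing import Any, Dict, List, Optional


def parse_span_parents(raw: Optional[str]) -> Optional[List[str]]:
    """Split a comma-separated string into a list, dropping empty items."""
    if not raw:
        return None
    return [s.strip() for s in raw.split(",") if s.strip()]


def build_span_fields(dynamic_metadata: Dict[str, Any], tags: List[str]) -> Dict[str, Any]:
    """Hypothetical helper: optional span fields plus root-only tags."""
    request_data: Dict[str, Any] = {}

    # Tags are attached only for root spans, i.e. when no root_span_id is supplied.
    if dynamic_metadata.get("root_span_id") is None:
        request_data["tags"] = tags

    span_attributes = {
        "span_id": dynamic_metadata.get("span_id"),
        "root_span_id": dynamic_metadata.get("root_span_id"),
        "span_parents": parse_span_parents(dynamic_metadata.get("span_parents")),
    }
    # Copy over only the attributes that have a truthy value.
    for key, value in span_attributes.items():
        if value:
            request_data[key] = value
    return request_data


root = build_span_fields({}, ["prod"])
child = build_span_fields(
    {"span_id": "s1", "root_span_id": "r1", "span_parents": "p1, p2,,"}, ["prod"]
)
print(root)
print(child)
```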
@ -367,9 +393,6 @@ class BraintrustLogger(CustomLogger):
|
||||
if metrics is not None:
|
||||
request_data["metrics"] = metrics
|
||||
|
||||
if metrics is not None:
|
||||
request_data["metrics"] = metrics
|
||||
|
||||
try:
|
||||
await self.global_braintrust_http_handler.post(
|
||||
url=f"{self.api_base}/project_logs/{project_id}/insert",
|
||||
|
||||
@ -45,6 +45,7 @@ def _get_cached_end_user_id_for_cost_tracking():
|
||||
global _get_end_user_id_for_cost_tracking
|
||||
if _get_end_user_id_for_cost_tracking is None:
|
||||
from litellm.utils import get_end_user_id_for_cost_tracking
|
||||
|
||||
_get_end_user_id_for_cost_tracking = get_end_user_id_for_cost_tracking
|
||||
return _get_end_user_id_for_cost_tracking
|
||||
|
||||
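
The helper in this hunk follows a deferred-import-and-cache pattern: the function is imported on first use and the reference is reused afterwards. A minimal sketch of the pattern, using `json.dumps` as a stand-in for the real import; the motivation (avoiding import cost or a circular import at module load) is an assumption, since the diff does not state it.

```python
from typing import Callable, Optional

# Module-level slot that holds the imported callable once it has been resolved.
_cached_dumps: Optional[Callable[..., str]] = None


def get_cached_dumps() -> Callable[..., str]:
    """Import json.dumps on first use and reuse the reference afterwards."""
    global _cached_dumps
    if _cached_dumps is None:
        from json import dumps  # deferred import, executed on the first call only

        _cached_dumps = dumps
    return _cached_dumps


print(get_cached_dumps()({"a": 1}))
```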
@ -330,6 +331,25 @@ class PrometheusLogger(CustomLogger):
|
||||
labelnames=self.get_labels_for_metric("litellm_requests_metric"),
|
||||
)
|
||||
|
||||
# Cache metrics
|
||||
self.litellm_cache_hits_metric = self._counter_factory(
|
||||
name="litellm_cache_hits_metric",
|
||||
documentation="Total number of LiteLLM cache hits",
|
||||
labelnames=self.get_labels_for_metric("litellm_cache_hits_metric"),
|
||||
)
|
||||
|
||||
self.litellm_cache_misses_metric = self._counter_factory(
|
||||
name="litellm_cache_misses_metric",
|
||||
documentation="Total number of LiteLLM cache misses",
|
||||
labelnames=self.get_labels_for_metric("litellm_cache_misses_metric"),
|
||||
)
|
||||
|
||||
self.litellm_cached_tokens_metric = self._counter_factory(
|
||||
name="litellm_cached_tokens_metric",
|
||||
documentation="Total tokens served from LiteLLM cache",
|
||||
labelnames=self.get_labels_for_metric("litellm_cached_tokens_metric"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print_verbose(f"Got exception on init prometheus client {str(e)}")
|
||||
raise e
|
||||
@ -801,7 +821,7 @@ class PrometheusLogger(CustomLogger):
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
_metadata = litellm_params.get("metadata", {})
|
||||
get_end_user_id_for_cost_tracking = _get_cached_end_user_id_for_cost_tracking()
|
||||
|
||||
|
||||
end_user_id = get_end_user_id_for_cost_tracking(
|
||||
litellm_params, service_type="prometheus"
|
||||
)
|
||||
@ -821,7 +841,7 @@ class PrometheusLogger(CustomLogger):
|
||||
user_api_key_auth_metadata: Optional[dict] = standard_logging_payload[
|
||||
"metadata"
|
||||
].get("user_api_key_auth_metadata")
|
||||
|
||||
|
||||
# Include top-level metadata fields (excluding nested dictionaries)
|
||||
# This allows accessing fields like requester_ip_address from top-level metadata
|
||||
top_level_metadata = standard_logging_payload.get("metadata", {})
|
||||
@ -832,7 +852,7 @@ class PrometheusLogger(CustomLogger):
|
||||
for k, v in top_level_metadata.items()
|
||||
if not isinstance(v, dict) # Exclude nested dicts to avoid conflicts
|
||||
}
|
||||
|
||||
|
||||
combined_metadata: Dict[str, Any] = {
|
||||
**top_level_fields, # Include top-level fields first
|
||||
**(_requester_metadata if _requester_metadata else {}),
|
||||
@ -951,6 +971,12 @@ class PrometheusLogger(CustomLogger):
|
||||
kwargs, start_time, end_time, enum_values, output_tokens
|
||||
)
|
||||
|
||||
# cache metrics
|
||||
self._increment_cache_metrics(
|
||||
standard_logging_payload=standard_logging_payload, # type: ignore
|
||||
enum_values=enum_values,
|
||||
)
|
||||
|
||||
if (
|
||||
standard_logging_payload["stream"] is True
|
||||
): # log successful streaming requests from logging event hook.
|
||||
@ -1020,6 +1046,54 @@ class PrometheusLogger(CustomLogger):
|
||||
standard_logging_payload["completion_tokens"]
|
||||
)
|
||||
|
||||
    def _increment_cache_metrics(
        self,
        standard_logging_payload: StandardLoggingPayload,
        enum_values: UserAPIKeyLabelValues,
    ):
        """
        Increment cache-related Prometheus metrics based on cache hit/miss status.

        Args:
            standard_logging_payload: Contains cache_hit field (True/False/None)
            enum_values: Label values for Prometheus metrics
        """
        cache_hit = standard_logging_payload.get("cache_hit")

        # Only track if cache_hit has a definite value (True or False)
        if cache_hit is None:
            return

        if cache_hit is True:
            # Increment cache hits counter
            _labels = prometheus_label_factory(
                supported_enum_labels=self.get_labels_for_metric(
                    metric_name="litellm_cache_hits_metric"
                ),
                enum_values=enum_values,
            )
            self.litellm_cache_hits_metric.labels(**_labels).inc()

            # Increment cached tokens counter
            total_tokens = standard_logging_payload.get("total_tokens", 0)
            if total_tokens > 0:
                _labels = prometheus_label_factory(
                    supported_enum_labels=self.get_labels_for_metric(
                        metric_name="litellm_cached_tokens_metric"
                    ),
                    enum_values=enum_values,
                )
                self.litellm_cached_tokens_metric.labels(**_labels).inc(total_tokens)
        else:
            # cache_hit is False - increment cache misses counter
            _labels = prometheus_label_factory(
                supported_enum_labels=self.get_labels_for_metric(
                    metric_name="litellm_cache_misses_metric"
                ),
                enum_values=enum_values,
            )
            self.litellm_cache_misses_metric.labels(**_labels).inc()
|
||||
|
||||
async def _increment_remaining_budget_metrics(
|
||||
self,
|
||||
user_api_team: Optional[str],
|
||||
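
The branching of `_increment_cache_metrics` can be checked without a Prometheus client. The sketch below is a simplification under stated assumptions: a `collections.Counter` stands in for the three Prometheus counters, labels are omitted, and `increment_cache_metrics` is a hypothetical free function rather than the logger method.

```python
from collections import Counter
from typing import Any, Dict

# Stand-in for the three Prometheus counters (labels omitted for brevity).
counters: Counter = Counter()


def increment_cache_metrics(payload: Dict[str, Any]) -> None:
    cache_hit = payload.get("cache_hit")

    # Requests without a definite cache status are not counted at all.
    if cache_hit is None:
        return

    if cache_hit is True:
        counters["litellm_cache_hits_metric"] += 1
        total_tokens = payload.get("total_tokens", 0)
        if total_tokens > 0:
            counters["litellm_cached_tokens_metric"] += total_tokens
    else:
        counters["litellm_cache_misses_metric"] += 1


for payload in (
    {"cache_hit": True, "total_tokens": 120},
    {"cache_hit": False, "total_tokens": 80},
    {"cache_hit": None},
    {},
):
    increment_cache_metrics(payload)

print(dict(counters))
```

Note that tokens are added to the cached-tokens counter only on a hit; the 80 tokens of the miss are not counted there.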
@ -1208,7 +1282,7 @@ class PrometheusLogger(CustomLogger):
|
||||
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
get_end_user_id_for_cost_tracking = _get_cached_end_user_id_for_cost_tracking()
|
||||
|
||||
|
||||
end_user_id = get_end_user_id_for_cost_tracking(
|
||||
litellm_params, service_type="prometheus"
|
||||
)
|
||||
@ -1562,7 +1636,6 @@ class PrometheusLogger(CustomLogger):
|
||||
api_provider=llm_provider or "",
|
||||
)
|
||||
if exception is not None:
|
||||
|
||||
_labels = prometheus_label_factory(
|
||||
supported_enum_labels=self.get_labels_for_metric(
|
||||
metric_name="litellm_deployment_failure_responses"
|
||||
@ -1595,12 +1668,11 @@ class PrometheusLogger(CustomLogger):
|
||||
enum_values: UserAPIKeyLabelValues,
|
||||
output_tokens: float = 1.0,
|
||||
):
|
||||
|
||||
try:
|
||||
verbose_logger.debug("setting remaining tokens requests metric")
|
||||
standard_logging_payload: Optional[StandardLoggingPayload] = (
|
||||
request_kwargs.get("standard_logging_object")
|
||||
)
|
||||
standard_logging_payload: Optional[
|
||||
StandardLoggingPayload
|
||||
] = request_kwargs.get("standard_logging_object")
|
||||
|
||||
if standard_logging_payload is None:
|
||||
return
|
||||
@ -2380,10 +2452,10 @@ class PrometheusLogger(CustomLogger):
|
||||
from litellm.constants import PROMETHEUS_BUDGET_METRICS_REFRESH_INTERVAL_MINUTES
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
||||
prometheus_loggers: List[CustomLogger] = (
|
||||
litellm.logging_callback_manager.get_custom_loggers_for_type(
|
||||
callback_type=PrometheusLogger
|
||||
)
|
||||
prometheus_loggers: List[
|
||||
CustomLogger
|
||||
] = litellm.logging_callback_manager.get_custom_loggers_for_type(
|
||||
callback_type=PrometheusLogger
|
||||
)
|
||||
# we need to get the initialized prometheus logger instance(s) and call logger.initialize_remaining_budget_metrics() on them
|
||||
verbose_logger.debug("found %s prometheus loggers", len(prometheus_loggers))
|
||||
@ -2455,7 +2527,7 @@ def prometheus_label_factory(
|
||||
|
||||
if UserAPIKeyLabelNames.END_USER.value in filtered_labels:
|
||||
get_end_user_id_for_cost_tracking = _get_cached_end_user_id_for_cost_tracking()
|
||||
|
||||
|
||||
filtered_labels["end_user"] = get_end_user_id_for_cost_tracking(
|
||||
litellm_params={"user_api_key_end_user_id": enum_values.end_user},
|
||||
service_type="prometheus",
|
||||
|
||||
@ -4839,9 +4839,9 @@ class StandardLoggingPayloadSetup:
|
||||
metadata = litellm_params.get("metadata") or {}
|
||||
litellm_metadata = litellm_params.get("litellm_metadata") or {}
|
||||
if metadata.get("tags", []):
|
||||
request_tags = metadata.get("tags", [])
|
||||
request_tags = metadata.get("tags", []).copy()
|
||||
elif litellm_metadata.get("tags", []):
|
||||
request_tags = litellm_metadata.get("tags", [])
|
||||
request_tags = litellm_metadata.get("tags", []).copy()
|
||||
else:
|
||||
request_tags = []
|
||||
user_agent_tags = StandardLoggingPayloadSetup._get_user_agent_tags(
|
||||
|
||||
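
The switch to `.copy()` in this hunk guards against list aliasing: without the copy, `request_tags` is the same list object as the one stored in the caller's metadata, so any later in-place change to it also changes the caller's data. A minimal sketch of the difference; `get_request_tags` and its `copy_tags` flag are hypothetical, and the appended user-agent tag is an assumed example of such a later mutation.

```python
from typing import Any, Dict, List


def get_request_tags(metadata: Dict[str, Any], copy_tags: bool) -> List[str]:
    """Return the request tags, optionally as an independent copy."""
    tags = metadata.get("tags", [])
    return tags.copy() if copy_tags else tags


# Without a copy: appending to the result mutates the caller's metadata.
caller_metadata = {"tags": ["prod"]}
aliased = get_request_tags(caller_metadata, copy_tags=False)
aliased.append("User-Agent: curl")
leaked = list(caller_metadata["tags"])

# With a copy: the caller's list is left untouched.
caller_metadata = {"tags": ["prod"]}
copied = get_request_tags(caller_metadata, copy_tags=True)
copied.append("User-Agent: curl")

print(leaked, caller_metadata["tags"])
```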
@ -61,6 +61,10 @@
|
||||
"max_completion_tokens": "max_tokens"
|
||||
}
|
||||
},
|
||||
"abliteration": {
|
||||
"base_url": "https://api.abliteration.ai/v1",
|
||||
"api_key_env": "ABLITERATION_API_KEY"
|
||||
},
|
||||
"llamagate": {
|
||||
"base_url": "https://api.llamagate.dev/v1",
|
||||
"api_key_env": "LLAMAGATE_API_KEY",
|
||||
|
||||
182
litellm/llms/openrouter/embedding/transformation.py
Normal file
@ -0,0 +1,182 @@
"""
OpenRouter Embedding API Configuration.

This module provides the configuration for OpenRouter's Embedding API.
OpenRouter is OpenAI-compatible and supports embeddings via the /v1/embeddings endpoint.

Docs: https://openrouter.ai/docs
"""
from typing import TYPE_CHECKING, Any, Optional

import httpx

from litellm.llms.base_llm.embedding.transformation import BaseEmbeddingConfig
from litellm.types.llms.openai import AllEmbeddingInputValues
from litellm.types.utils import EmbeddingResponse
from litellm.utils import convert_to_model_response_object

from ..common_utils import OpenRouterException

if TYPE_CHECKING:
    from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj

    LiteLLMLoggingObj = _LiteLLMLoggingObj
else:
    LiteLLMLoggingObj = Any


class OpenrouterEmbeddingConfig(BaseEmbeddingConfig):
    """
    Configuration for OpenRouter's Embedding API.

    Reference: https://openrouter.ai/docs
    """

    def validate_environment(
        self,
        headers: dict,
        model: str,
        messages: list,
        optional_params: dict,
        litellm_params: dict,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
    ) -> dict:
        """
        Validate environment and set up headers for OpenRouter API.

        OpenRouter requires:
        - Authorization header with Bearer token
        - HTTP-Referer header (site URL)
        - X-Title header (app name)
        """
        from litellm import get_secret

        # Get OpenRouter-specific headers
        openrouter_site_url = get_secret("OR_SITE_URL") or "https://litellm.ai"
        openrouter_app_name = get_secret("OR_APP_NAME") or "liteLLM"

        openrouter_headers = {
            "HTTP-Referer": openrouter_site_url,
            "X-Title": openrouter_app_name,
            "Content-Type": "application/json",
        }

        # Add Authorization header if api_key is provided
        if api_key:
            openrouter_headers["Authorization"] = f"Bearer {api_key}"

        # Merge with existing headers (user's extra_headers take priority)
        merged_headers = {**openrouter_headers, **headers}

        return merged_headers

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        optional_params: dict,
        litellm_params: dict,
        stream: Optional[bool] = None,
    ) -> str:
        """
        Get the complete URL for OpenRouter Embedding API endpoint.
        """
        # api_base is already set to https://openrouter.ai/api/v1 in main.py
        # Remove trailing slashes
        if api_base:
            api_base = api_base.rstrip("/")
        else:
            api_base = "https://openrouter.ai/api/v1"

        # Return the embeddings endpoint
        return f"{api_base}/embeddings"

    def transform_embedding_request(
        self,
        model: str,
        input: AllEmbeddingInputValues,
        optional_params: dict,
        headers: dict,
    ) -> dict:
        """
        Transform embedding request to OpenRouter format (OpenAI-compatible).
        """
        # Ensure input is a list
        if isinstance(input, str):
            input = [input]

        # OpenRouter expects the full model name (e.g., google/gemini-embedding-001)
        # Strip 'openrouter/' prefix if present
        if model.startswith("openrouter/"):
            model = model.replace("openrouter/", "", 1)

        return {
            "model": model,
            "input": input,
            **optional_params,
        }

    def transform_embedding_response(
        self,
        model: str,
        raw_response: httpx.Response,
        model_response: EmbeddingResponse,
        logging_obj: LiteLLMLoggingObj,
        api_key: Optional[str],
        request_data: dict,
        optional_params: dict,
        litellm_params: dict,
    ) -> EmbeddingResponse:
        """
        Transform embedding response from OpenRouter format (OpenAI-compatible).
        """
        logging_obj.post_call(original_response=raw_response.text)

        # OpenRouter returns standard OpenAI-compatible embedding response
        response_json = raw_response.json()

        return convert_to_model_response_object(
            response_object=response_json,
            model_response_object=model_response,
            response_type="embedding",
        )

    def get_supported_openai_params(self, model: str) -> list:
        """
        Get list of supported OpenAI parameters for OpenRouter embeddings.
        """
        return [
            "timeout",
            "dimensions",
            "encoding_format",
            "user",
        ]

    def map_openai_params(
        self,
        non_default_params: dict,
        optional_params: dict,
        model: str,
        drop_params: bool,
    ) -> dict:
        """
        Map OpenAI parameters to OpenRouter format.
        """
        for param, value in non_default_params.items():
            if param in self.get_supported_openai_params(model):
                optional_params[param] = value
        return optional_params

    def get_error_class(
        self, error_message: str, status_code: int, headers: Any
    ) -> Any:
        """
        Get the error class for OpenRouter errors.
        """
        return OpenRouterException(
            message=error_message,
            status_code=status_code,
            headers=headers,
        )
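
The three request-shaping rules of `OpenrouterEmbeddingConfig` (user headers override defaults, trailing slashes are trimmed before `/embeddings` is appended, and a leading `openrouter/` prefix is stripped from the model name) can be exercised without LiteLLM installed. The free functions below are hypothetical stand-ins written for illustration; the default header values are copied from the diff, and secret lookup is left out.

```python
from typing import Dict, Optional

DEFAULT_API_BASE = "https://openrouter.ai/api/v1"


def merge_headers(api_key: Optional[str], user_headers: Dict[str, str]) -> Dict[str, str]:
    """Build default headers, then let user-supplied headers take priority."""
    defaults = {
        "HTTP-Referer": "https://litellm.ai",
        "X-Title": "liteLLM",
        "Content-Type": "application/json",
    }
    if api_key:
        defaults["Authorization"] = f"Bearer {api_key}"
    return {**defaults, **user_headers}


def embeddings_url(api_base: Optional[str]) -> str:
    """Trim trailing slashes (or fall back to the default base) and append /embeddings."""
    base = api_base.rstrip("/") if api_base else DEFAULT_API_BASE
    return f"{base}/embeddings"


def strip_provider_prefix(model: str) -> str:
    """Remove a leading 'openrouter/' so the upstream model name is sent."""
    prefix = "openrouter/"
    return model[len(prefix):] if model.startswith(prefix) else model


print(merge_headers("sk-demo", {"X-Title": "my-app"}))
print(embeddings_url("https://openrouter.ai/api/v1/"))
print(strip_provider_prefix("openrouter/google/gemini-embedding-001"))
```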
@ -4707,6 +4707,51 @@ def embedding( # noqa: PLR0915
|
||||
litellm_params=litellm_params_dict,
|
||||
headers=headers,
|
||||
)
|
||||
elif custom_llm_provider == "openrouter":
|
||||
api_base = (
|
||||
api_base
|
||||
or litellm.api_base
|
||||
or get_secret_str("OPENROUTER_API_BASE")
|
||||
or "https://openrouter.ai/api/v1"
|
||||
)
|
||||
|
||||
api_key = (
|
||||
api_key
|
||||
or litellm.api_key
|
||||
or litellm.openrouter_key
|
||||
or get_secret("OPENROUTER_API_KEY")
|
||||
or get_secret("OR_API_KEY")
|
||||
)
|
||||
|
||||
openrouter_site_url = get_secret("OR_SITE_URL") or "https://litellm.ai"
|
||||
openrouter_app_name = get_secret("OR_APP_NAME") or "liteLLM"
|
||||
|
||||
openrouter_headers = {
|
||||
"HTTP-Referer": openrouter_site_url,
|
||||
"X-Title": openrouter_app_name,
|
||||
}
|
||||
|
||||
_headers = headers or litellm.headers
|
||||
if _headers:
|
||||
openrouter_headers.update(_headers)
|
||||
|
||||
headers = openrouter_headers
|
||||
|
||||
response = base_llm_http_handler.embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
logging_obj=logging,
|
||||
timeout=timeout,
|
||||
model_response=EmbeddingResponse(),
|
||||
optional_params=optional_params,
|
||||
client=client,
|
||||
aembedding=aembedding,
|
||||
litellm_params=litellm_params_dict,
|
||||
headers=headers,
|
||||
)
|
||||
elif custom_llm_provider == "huggingface":
|
||||
api_key = (
|
||||
api_key
|
||||
|
||||
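
The `api_base` resolution in this branch is a first-non-empty chain: explicit argument, then the global setting, then the `OPENROUTER_API_BASE` secret, then the hard-coded default. A minimal sketch of that precedence, assuming plain environment variables in place of `get_secret_str`; `first_set` and `resolve_openrouter_api_base` are hypothetical names.

```python
import os
from typing import Optional

DEFAULT_OPENROUTER_API_BASE = "https://openrouter.ai/api/v1"


def first_set(*candidates: Optional[str]) -> Optional[str]:
    """Return the first candidate that is neither None nor empty."""
    for value in candidates:
        if value:
            return value
    return None


def resolve_openrouter_api_base(
    explicit: Optional[str] = None, global_base: Optional[str] = None
) -> str:
    # os.environ stands in for the secret lookup used in the diff (assumption).
    env_base = os.environ.get("OPENROUTER_API_BASE")
    return first_set(explicit, global_base, env_base) or DEFAULT_OPENROUTER_API_BASE


print(resolve_openrouter_api_base(explicit="https://gateway.example/v1"))
```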
@ -1942,7 +1942,7 @@ class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
|
||||
description="connect to a postgres db - needed for generating temporary keys + tracking spend / key",
|
||||
)
|
||||
database_connection_pool_limit: Optional[int] = Field(
|
||||
100,
|
||||
10,
|
||||
description="default connection pool for prisma client connecting to postgres db",
|
||||
)
|
||||
database_connection_timeout: Optional[float] = Field(
|
||||
|
||||
@ -11,7 +11,7 @@ from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.integrations.custom_guardrail import ModifyResponseException
|
||||
from litellm.proxy.common_request_processing import (
|
||||
ProxyBaseLLMRequestProcessing,
|
||||
create_streaming_response,
|
||||
create_response,
|
||||
)
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm.types.utils import TokenCountResponse
|
||||
@ -106,7 +106,7 @@ async def anthropic_response( # noqa: PLR0915
|
||||
)
|
||||
)
|
||||
|
||||
return await create_streaming_response(
|
||||
return await create_response(
|
||||
generator=selected_data_generator,
|
||||
media_type="text/event-stream",
|
||||
headers={},
|
||||
|
||||
@ -17,7 +17,7 @@ from typing import (
|
||||
import httpx
|
||||
import orjson
|
||||
from fastapi import HTTPException, Request, status
|
||||
from fastapi.responses import Response, StreamingResponse
|
||||
from fastapi.responses import JSONResponse, Response, StreamingResponse
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
@ -96,16 +96,55 @@ async def _parse_event_data_for_error(event_line: Union[str, bytes]) -> Optional
|
||||
return None
|
||||
|
||||
|
||||
async def create_streaming_response(
|
||||
def _extract_error_from_sse_chunk(event_line: Union[str, bytes]) -> dict:
    """
    Extract error dictionary from SSE format chunk.

    Args:
        event_line: SSE format event line, e.g. "data: {"error": {...}}\n\n"

    Returns:
        Error dictionary in OpenAI API format
    """
    event_line = (
        event_line.decode("utf-8") if isinstance(event_line, bytes) else event_line
    )

    # Default error format
    default_error = {
        "message": "Unknown error",
        "type": "internal_server_error",
        "param": None,
        "code": "500",
    }

    if event_line.startswith("data: "):
        json_str = event_line[len("data: ") :].strip()
        if not json_str or json_str == "[DONE]":
            return default_error

        try:
            data = orjson.loads(json_str)
            if isinstance(data, dict) and "error" in data:
                error_obj = data["error"]
                if isinstance(error_obj, dict):
                    return error_obj
        except (orjson.JSONDecodeError, json.JSONDecodeError):
            pass

    return default_error
|
||||
|
||||
|
||||
async def create_response(
|
||||
generator: AsyncGenerator[str, None],
|
||||
media_type: str,
|
||||
headers: dict,
|
||||
default_status_code: int = status.HTTP_200_OK,
|
||||
) -> StreamingResponse:
|
||||
) -> Union[StreamingResponse, JSONResponse]:
|
||||
"""
|
||||
Creates a StreamingResponse by inspecting the first chunk for an error code.
|
||||
The entire original generator content is streamed, but the HTTP status code
|
||||
of the response is set based on the first chunk if it's a recognized error.
|
||||
Create streaming response, checking if the first chunk is an error.
|
||||
If the first chunk is an error, return a standard JSON error response.
|
||||
Otherwise, return StreamingResponse and stream all content.
|
||||
"""
|
||||
first_chunk_value: Optional[str] = None
|
||||
final_status_code = default_status_code
|
||||
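
The parsing rules of `_extract_error_from_sse_chunk` can be tried out with the standard library alone. The sketch below swaps `orjson` for `json` and uses a hypothetical function name without the leading underscore; it illustrates the fallback behaviour and is not the proxy's actual implementation.

```python
import json
from typing import Any, Dict, Union

DEFAULT_ERROR: Dict[str, Any] = {
    "message": "Unknown error",
    "type": "internal_server_error",
    "param": None,
    "code": "500",
}


def extract_error_from_sse_chunk(event_line: Union[str, bytes]) -> Dict[str, Any]:
    """Return the 'error' object of an SSE data line, or a generic default."""
    if isinstance(event_line, bytes):
        event_line = event_line.decode("utf-8")

    if event_line.startswith("data: "):
        json_str = event_line[len("data: "):].strip()
        if json_str and json_str != "[DONE]":
            try:
                data = json.loads(json_str)
            except json.JSONDecodeError:
                data = None
            if isinstance(data, dict) and isinstance(data.get("error"), dict):
                return data["error"]

    # Anything unparseable falls back to a copy of the generic error.
    return dict(DEFAULT_ERROR)


chunk = b'data: {"error": {"message": "rate limited", "code": "429"}}\n\n'
print(extract_error_from_sse_chunk(chunk))
```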
@ -124,9 +163,27 @@ async def create_streaming_response(
|
||||
first_chunk_value
|
||||
)
|
||||
if error_code_from_chunk is not None:
|
||||
# First chunk is an error, stream hasn't really started yet
|
||||
# Should return standard JSON error response instead of SSE format
|
||||
final_status_code = error_code_from_chunk
|
||||
verbose_proxy_logger.debug(
|
||||
f"Error detected in first stream chunk. Status code set to: {final_status_code}"
|
||||
f"Error detected in first stream chunk. Returning JSON error response with status code: {final_status_code}"
|
||||
)
|
||||
|
||||
# Parse error content
|
||||
error_dict = _extract_error_from_sse_chunk(first_chunk_value)
|
||||
|
||||
# Consume and close generator (avoid resource leak)
|
||||
try:
|
||||
await generator.aclose()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Return JSON format error response
|
||||
return JSONResponse(
|
||||
status_code=final_status_code,
|
||||
content={"error": error_dict},
|
||||
headers=headers,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"Error parsing first chunk value: {e}")
|
||||
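The first-chunk check in `create_response` can be illustrated without FastAPI. The sketch below is a simplified model under stated assumptions: `error_status_from_chunk` and `peek_then_respond` are hypothetical names, the tuple return value stands in for `JSONResponse` / `StreamingResponse`, and the status-code parsing is a guess at what the elided helper does.

```python
import asyncio
import json
from typing import Optional


def error_status_from_chunk(chunk: str) -> Optional[int]:
    """Hypothetical helper: HTTP status if the SSE chunk carries an error, else None."""
    if not chunk.startswith("data: "):
        return None
    try:
        data = json.loads(chunk[len("data: "):].strip())
    except json.JSONDecodeError:
        return None
    if isinstance(data, dict) and isinstance(data.get("error"), dict):
        try:
            return int(data["error"].get("code", 500))
        except (TypeError, ValueError):
            return 500
    return None


async def peek_then_respond(generator):
    """Read one chunk; return a JSON-style error or a stream that replays it."""
    try:
        first = await generator.__anext__()
    except StopAsyncIteration:
        first = None

    status = error_status_from_chunk(first) if first is not None else None
    if status is not None:
        # The stream has not really started: close the generator and
        # answer with a plain JSON error body instead of SSE.
        await generator.aclose()
        error = json.loads(first[len("data: "):])["error"]
        return ("json", status, {"error": error})

    async def replay():
        # Re-emit the chunk that was consumed while peeking, then the rest.
        if first is not None:
            yield first
        async for chunk in generator:
            yield chunk

    return ("stream", 200, replay())
```

Closing the generator before returning the error mirrors the patch's `generator.aclose()` call, which avoids leaking the upstream connection.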
@ -670,7 +727,7 @@ class ProxyBaseLLMRequestProcessing:
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
)
|
||||
return await create_streaming_response(
|
||||
return await create_response(
|
||||
generator=selected_data_generator,
|
||||
media_type="text/event-stream",
|
||||
headers=custom_headers,
|
||||
@ -681,7 +738,7 @@ class ProxyBaseLLMRequestProcessing:
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
request_data=self.data,
|
||||
)
|
||||
return await create_streaming_response(
|
||||
return await create_response(
|
||||
generator=selected_data_generator,
|
||||
media_type="text/event-stream",
|
||||
headers=custom_headers,
|
||||
@ -923,11 +980,11 @@ class ProxyBaseLLMRequestProcessing:
|
||||
@staticmethod
|
||||
def _get_pre_call_type(
|
||||
route_type: Literal["acompletion", "aembedding", "aresponses", "allm_passthrough_route"],
|
||||
) -> Literal["completion", "embeddings", "responses", "allm_passthrough_route"]:
|
||||
) -> Literal["completion", "embedding", "responses", "allm_passthrough_route"]:
|
||||
if route_type == "acompletion":
|
||||
return "completion"
|
||||
elif route_type == "aembedding":
|
||||
return "embeddings"
|
||||
return "embedding"
|
||||
elif route_type == "aresponses":
|
||||
return "responses"
|
||||
elif route_type == "allm_passthrough_route":
|
||||
|
||||
@ -5,6 +5,7 @@
|
||||
# +-------------------------------------------------------------+
|
||||
# Qualifire - Evaluate LLM outputs for quality, safety, and reliability
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Literal, Optional, Type
|
||||
|
||||
@ -15,12 +16,17 @@ from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
Logging as LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.proxy.guardrails.guardrail_hooks.base import GuardrailConfigModel
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
GUARDRAIL_NAME = "qualifire"
|
||||
DEFAULT_QUALIFIRE_API_BASE = "https://proxy.qualifire.ai"
|
||||
|
||||
|
||||
class QualifireGuardrail(CustomGuardrail):
|
||||
@ -44,7 +50,7 @@ class QualifireGuardrail(CustomGuardrail):
|
||||
|
||||
Args:
|
||||
api_key: API key for Qualifire (or use QUALIFIRE_API_KEY env var)
|
||||
api_base: Optional custom API base URL
|
||||
api_base: Optional custom API base URL (defaults to https://proxy.qualifire.ai)
|
||||
evaluation_id: Pre-configured evaluation ID from Qualifire dashboard
|
||||
prompt_injections: Enable prompt injection detection (default if no other checks)
|
||||
hallucinations_check: Enable hallucination detection
|
||||
@ -64,6 +70,7 @@ class QualifireGuardrail(CustomGuardrail):
|
||||
api_base
|
||||
or get_secret_str("QUALIFIRE_BASE_URL")
|
||||
or os.environ.get("QUALIFIRE_BASE_URL")
|
||||
or DEFAULT_QUALIFIRE_API_BASE
|
||||
)
|
||||
self.evaluation_id = evaluation_id
|
||||
self.prompt_injections = prompt_injections
|
||||
@ -79,7 +86,11 @@ class QualifireGuardrail(CustomGuardrail):
|
||||
if not self._has_any_check_enabled() and not self.evaluation_id:
|
||||
self.prompt_injections = True
|
||||
|
||||
self._client = None
|
||||
# Initialize async HTTP client for direct API calls
|
||||
self.async_handler = get_async_httpx_client(
|
||||
llm_provider=httpxSpecialProvider.GuardrailCallback
|
||||
)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
def _has_any_check_enabled(self) -> bool:
|
||||
@ -96,43 +107,22 @@ class QualifireGuardrail(CustomGuardrail):
|
||||
]
|
||||
)
|
||||
|
||||
def _get_client(self):
|
||||
"""Lazy initialization of Qualifire client."""
|
||||
if self._client is None:
|
||||
try:
|
||||
from qualifire.client import Client
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"qualifire package is required for QualifireGuardrail. "
|
||||
"Install it with: pip install qualifire"
|
||||
)
|
||||
|
||||
client_kwargs: Dict[str, Any] = {}
|
||||
if self.qualifire_api_key:
|
||||
client_kwargs["api_key"] = self.qualifire_api_key
|
||||
if self.qualifire_api_base:
|
||||
client_kwargs["base_url"] = self.qualifire_api_base
|
||||
|
||||
self._client = Client(**client_kwargs)
|
||||
|
||||
return self._client
|
||||
|
||||
def _convert_messages_to_qualifire_format(
|
||||
def _convert_messages_to_api_format(
|
||||
self, messages: List[AllMessageValues]
|
||||
) -> List[Any]:
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Convert LiteLLM messages to Qualifire's LLMMessage format.
|
||||
Convert LiteLLM messages to Qualifire API format.
|
||||
Supports tool calls for tool_selection_quality_check.
|
||||
"""
|
||||
try:
|
||||
from qualifire.types import LLMMessage, LLMToolCall
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"qualifire package is required for QualifireGuardrail. "
|
||||
"Install it with: pip install qualifire"
|
||||
)
|
||||
|
||||
qualifire_messages = []
|
||||
Returns a list of dicts matching the API's ModelInvocationCanonicalMessage schema:
|
||||
{
|
||||
"role": "user" | "assistant" | "system" | "tool",
|
||||
"content": "...",
|
||||
"tool_call_id": "...", # optional
|
||||
"tool_calls": [{"id": "...", "name": "...", "arguments": {...}}] # optional
|
||||
}
|
||||
"""
|
||||
api_messages = []
|
||||
for msg in messages:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
@ -147,42 +137,86 @@ class QualifireGuardrail(CustomGuardrail):
|
||||
text_parts.append(part)
|
||||
content = "\n".join(text_parts)
|
||||
|
||||
llm_message_kwargs: Dict[str, Any] = {
|
||||
api_message: Dict[str, Any] = {
|
||||
"role": role,
|
||||
"content": content if isinstance(content, str) else str(content),
|
||||
}
|
||||
|
||||
# Handle tool_call_id for tool response messages
|
||||
tool_call_id = msg.get("tool_call_id")
|
||||
if tool_call_id:
|
||||
api_message["tool_call_id"] = tool_call_id
|
||||
|
||||
# Handle tool calls if present
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if tool_calls and isinstance(tool_calls, list):
|
||||
qualifire_tool_calls = []
|
||||
api_tool_calls = []
|
||||
for tc in tool_calls:
|
||||
if isinstance(tc, dict):
|
||||
function_info = tc.get("function", {})
|
||||
# Arguments can be a string (JSON) or dict
|
||||
args = function_info.get("arguments", {})
|
||||
if isinstance(args, str):
|
||||
import json
|
||||
|
||||
try:
|
||||
args = json.loads(args)
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
qualifire_tool_calls.append(
|
||||
LLMToolCall(
|
||||
id=tc.get("id") or "",
|
||||
name=function_info.get("name") or "",
|
||||
arguments=args if isinstance(args, dict) else {},
|
||||
)
|
||||
api_tool_calls.append(
|
||||
{
|
||||
"id": tc.get("id") or "",
|
||||
"name": function_info.get("name") or "",
|
||||
"arguments": args if isinstance(args, dict) else {},
|
||||
}
|
||||
)
|
||||
if qualifire_tool_calls:
|
||||
llm_message_kwargs["tool_calls"] = qualifire_tool_calls
|
||||
if api_tool_calls:
|
||||
api_message["tool_calls"] = api_tool_calls
|
||||
|
||||
qualifire_messages.append(LLMMessage(**llm_message_kwargs))
|
||||
api_messages.append(api_message)
|
||||
|
||||
return qualifire_messages
|
||||
return api_messages
|
||||
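To make the tool-call flattening concrete, here is a small standalone sketch of the same transformation. The function name `flatten_tool_calls` is hypothetical; the output keys (`id`, `name`, `arguments`) follow the schema described in the docstring above.

```python
import json
from typing import Any, Dict, List


def flatten_tool_calls(tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Flatten OpenAI-style tool calls into {"id", "name", "arguments"} dicts."""
    flat = []
    for tc in tool_calls:
        if not isinstance(tc, dict):
            continue
        function_info = tc.get("function", {})
        args = function_info.get("arguments", {})
        # OpenAI clients usually send arguments as a JSON string.
        if isinstance(args, str):
            try:
                args = json.loads(args)
            except json.JSONDecodeError:
                args = {}
        flat.append(
            {
                "id": tc.get("id") or "",
                "name": function_info.get("name") or "",
                "arguments": args if isinstance(args, dict) else {},
            }
        )
    return flat
```

Malformed or non-object arguments collapse to `{}` rather than raising, which matches the defensive behavior of the patched method.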
|
||||
def _check_if_flagged(self, result: Any) -> bool:
|
||||
def _convert_tools_to_api_format(
|
||||
self, tools: Optional[List[Any]]
|
||||
) -> Optional[List[Dict[str, Any]]]:
|
||||
"""
|
||||
Convert OpenAI-format tools to Qualifire API format.
|
||||
|
||||
Returns a list of dicts matching the API's ModelInvocationToolDefinition schema:
|
||||
{
|
||||
"name": "...",
|
||||
"description": "...",
|
||||
"parameters": {...}
|
||||
}
|
||||
"""
|
||||
if not tools:
|
||||
return None
|
||||
|
||||
api_tools = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, dict):
|
||||
# Handle OpenAI function tool format
|
||||
if tool.get("type") == "function":
|
||||
function_def = tool.get("function", {})
|
||||
api_tools.append(
|
||||
{
|
||||
"name": function_def.get("name", ""),
|
||||
"description": function_def.get("description", ""),
|
||||
"parameters": function_def.get("parameters", {}),
|
||||
}
|
||||
)
|
||||
# Handle direct tool format
|
||||
elif "name" in tool:
|
||||
api_tools.append(
|
||||
{
|
||||
"name": tool.get("name", ""),
|
||||
"description": tool.get("description", ""),
|
||||
"parameters": tool.get("parameters", {}),
|
||||
}
|
||||
)
|
||||
|
||||
return api_tools if api_tools else None
|
||||
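The two accepted tool shapes (OpenAI function tools and already-flat tools) can be seen side by side in the following sketch. `convert_tools` is a hypothetical standalone version of the method above, written only for illustration.

```python
from typing import Any, Dict, List, Optional


def convert_tools(tools: Optional[List[Any]]) -> Optional[List[Dict[str, Any]]]:
    """Normalize tools to {"name", "description", "parameters"} dicts."""
    if not tools:
        return None

    api_tools = []
    for tool in tools:
        if not isinstance(tool, dict):
            continue
        if tool.get("type") == "function":
            # OpenAI format: the definition is nested under "function".
            source = tool.get("function", {})
        elif "name" in tool:
            # Direct format: the definition is the tool dict itself.
            source = tool
        else:
            continue
        api_tools.append(
            {
                "name": source.get("name", ""),
                "description": source.get("description", ""),
                "parameters": source.get("parameters", {}),
            }
        )
    return api_tools or None
```

Returning `None` for an empty result lets callers use a plain truthiness check before adding `available_tools` to a payload.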
|
||||
def _check_if_flagged(self, result: Dict[str, Any]) -> bool:
|
||||
"""
|
||||
Check if the Qualifire evaluation result indicates flagged content.
|
||||
|
||||
@ -190,65 +224,53 @@ class QualifireGuardrail(CustomGuardrail):
|
||||
A high score (close to 100) indicates GOOD content, low score indicates problems.
|
||||
"""
|
||||
# Check evaluation results for any flagged items
|
||||
evaluation_results = getattr(result, "evaluationResults", None) or []
|
||||
if isinstance(result, dict):
|
||||
evaluation_results = result.get("evaluationResults", []) or []
|
||||
evaluation_results = result.get("evaluationResults", []) or []
|
||||
|
||||
for eval_result in evaluation_results:
|
||||
results: List[Any] = []
|
||||
if isinstance(eval_result, dict):
|
||||
results = eval_result.get("results", []) or []
|
||||
else:
|
||||
results = getattr(eval_result, "results", []) or []
|
||||
|
||||
results = eval_result.get("results", []) or []
|
||||
for r in results:
|
||||
flagged = (
|
||||
r.get("flagged")
|
||||
if isinstance(r, dict)
|
||||
else getattr(r, "flagged", False)
|
||||
)
|
||||
if flagged:
|
||||
if r.get("flagged"):
|
||||
return True
|
||||
|
||||
return False
|
||||
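As an illustration of the flag check on a plain dict response, the sketch below walks the same `evaluationResults` / `results` / `flagged` keys. The sample response values in the usage note are invented for the example and do not come from the Qualifire API documentation.

```python
from typing import Any, Dict


def is_flagged(result: Dict[str, Any]) -> bool:
    """Return True if any nested result entry has a truthy "flagged" value."""
    for eval_result in result.get("evaluationResults", []) or []:
        for r in eval_result.get("results", []) or []:
            if r.get("flagged"):
                return True
    return False
```

For instance, `is_flagged({"evaluationResults": [{"results": [{"flagged": True}]}]})` is `True`, while a response with a missing or `null` `evaluationResults` key is treated as not flagged.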
|
||||
def _build_evaluate_kwargs(
|
||||
def _build_evaluate_payload(
|
||||
self,
|
||||
qualifire_messages: List[Any],
|
||||
api_messages: List[Dict[str, Any]],
|
||||
output: Optional[str],
|
||||
assertions: Optional[List[str]],
|
||||
available_tools: Optional[List[Any]],
|
||||
available_tools: Optional[List[Dict[str, Any]]],
|
||||
) -> Dict[str, Any]:
|
||||
"""Build kwargs dictionary for the evaluate call."""
|
||||
kwargs: Dict[str, Any] = {"messages": qualifire_messages}
|
||||
"""Build payload dictionary for the /api/evaluation/evaluate endpoint."""
|
||||
payload: Dict[str, Any] = {"messages": api_messages}
|
||||
|
||||
if output is not None:
|
||||
kwargs["output"] = output
|
||||
payload["output"] = output
|
||||
|
||||
# Add enabled checks
|
||||
if self.prompt_injections:
|
||||
kwargs["prompt_injections"] = True
|
||||
payload["prompt_injections"] = True
|
||||
if self.hallucinations_check:
|
||||
kwargs["hallucinations_check"] = True
|
||||
payload["hallucinations_check"] = True
|
||||
if self.grounding_check:
|
||||
kwargs["grounding_check"] = True
|
||||
payload["grounding_check"] = True
|
||||
if self.pii_check:
|
||||
kwargs["pii_check"] = True
|
||||
payload["pii_check"] = True
|
||||
if self.content_moderation_check:
|
||||
kwargs["content_moderation_check"] = True
|
||||
payload["content_moderation_check"] = True
|
||||
if self.tool_selection_quality_check:
|
||||
# Only enable tool_selection_quality_check if available_tools is provided
|
||||
if available_tools:
|
||||
kwargs["tool_selection_quality_check"] = True
|
||||
kwargs["available_tools"] = available_tools
|
||||
payload["tool_selection_quality_check"] = True
|
||||
payload["available_tools"] = available_tools
|
||||
else:
|
||||
verbose_proxy_logger.debug(
|
||||
"Qualifire Guardrail: tool_selection_quality_check enabled but no available_tools provided, skipping this check"
|
||||
)
|
||||
if assertions:
|
||||
kwargs["assertions"] = assertions
|
||||
payload["assertions"] = assertions
|
||||
|
||||
return kwargs
|
||||
return payload
|
||||
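The payload-building rules above, in particular the gating of `tool_selection_quality_check` on `available_tools`, can be sketched as a free function. This is a simplified model: `build_evaluate_payload` takes a hypothetical `enabled_checks` list in place of the guardrail's boolean attributes.

```python
from typing import Any, Dict, List, Optional


def build_evaluate_payload(
    api_messages: List[Dict[str, Any]],
    enabled_checks: List[str],  # hypothetical stand-in for the self.* flags
    output: Optional[str] = None,
    assertions: Optional[List[str]] = None,
    available_tools: Optional[List[Dict[str, Any]]] = None,
) -> Dict[str, Any]:
    """Build a request body containing only the checks that can actually run."""
    payload: Dict[str, Any] = {"messages": api_messages}
    if output is not None:
        payload["output"] = output

    for check in enabled_checks:
        if check == "tool_selection_quality_check":
            # This check is skipped unless tool definitions are supplied.
            if not available_tools:
                continue
            payload["available_tools"] = available_tools
        payload[check] = True

    if assertions:
        payload["assertions"] = assertions
    return payload
```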
|
||||
async def _run_qualifire_check(
|
||||
self,
|
||||
@ -274,11 +296,17 @@ class QualifireGuardrail(CustomGuardrail):
|
||||
assertions = dynamic_params.get("assertions") or self.assertions
|
||||
on_flagged = dynamic_params.get("on_flagged") or self.on_flagged
|
||||
|
||||
try:
|
||||
client = self._get_client()
|
||||
qualifire_messages = self._convert_messages_to_qualifire_format(messages)
|
||||
# Prepare headers
|
||||
headers = {
|
||||
"X-Qualifire-API-Key": self.qualifire_api_key or "",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# Use invoke_evaluation if evaluation_id is provided
|
||||
try:
|
||||
# Convert messages to API format
|
||||
api_messages = self._convert_messages_to_api_format(messages)
|
||||
|
||||
# Use invoke endpoint if evaluation_id is provided
|
||||
if evaluation_id:
|
||||
# For invoke_evaluation, we need to extract input/output
|
||||
input_text = ""
|
||||
@ -291,25 +319,47 @@ class QualifireGuardrail(CustomGuardrail):
|
||||
input_text = content
|
||||
break
|
||||
|
||||
result = client.invoke_evaluation(
|
||||
evaluation_id=evaluation_id,
|
||||
input=input_text,
|
||||
output=output or "",
|
||||
)
|
||||
payload = {
|
||||
"evaluation_id": evaluation_id,
|
||||
"input": input_text,
|
||||
"output": output or "",
|
||||
"messages": api_messages,
|
||||
}
|
||||
|
||||
# Convert tools if provided
|
||||
api_tools = self._convert_tools_to_api_format(available_tools)
|
||||
if api_tools:
|
||||
payload["available_tools"] = api_tools
|
||||
|
||||
url = f"{self.qualifire_api_base}/api/evaluation/invoke"
|
||||
else:
|
||||
# Use evaluate with individual checks
|
||||
kwargs = self._build_evaluate_kwargs(
|
||||
qualifire_messages=qualifire_messages,
|
||||
# Use evaluate endpoint with individual checks
|
||||
api_tools = self._convert_tools_to_api_format(available_tools)
|
||||
payload = self._build_evaluate_payload(
|
||||
api_messages=api_messages,
|
||||
output=output,
|
||||
assertions=assertions,
|
||||
available_tools=available_tools,
|
||||
available_tools=api_tools,
|
||||
)
|
||||
result = client.evaluate(**kwargs)
|
||||
url = f"{self.qualifire_api_base}/api/evaluation/evaluate"
|
||||
|
||||
# Convert result to dict for logging
|
||||
verbose_proxy_logger.debug(
|
||||
f"Qualifire Guardrail: Making request to {url}"
|
||||
)
|
||||
|
||||
# Make the API request
|
||||
response = await self.async_handler.post(
|
||||
url=url,
|
||||
headers=headers,
|
||||
json=payload,
|
||||
)
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
# Extract response info for logging
|
||||
qualifire_response = {
|
||||
"score": getattr(result, "score", None),
|
||||
"status": getattr(result, "status", None),
|
||||
"score": result.get("score"),
|
||||
"status": result.get("status"),
|
||||
}
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
|
||||
@ -229,7 +229,7 @@ from litellm.proxy.batches_endpoints.endpoints import router as batches_router
|
||||
from litellm.proxy.caching_routes import router as caching_router
|
||||
from litellm.proxy.common_request_processing import (
|
||||
ProxyBaseLLMRequestProcessing,
|
||||
create_streaming_response,
|
||||
create_response,
|
||||
)
|
||||
from litellm.proxy.common_utils.callback_utils import initialize_callbacks_on_proxy
|
||||
from litellm.proxy.common_utils.debug_utils import init_verbose_loggers
|
||||
@ -6824,7 +6824,7 @@ async def run_thread(
|
||||
if (
|
||||
"stream" in data and data["stream"] is True
|
||||
): # use generate_responses to stream responses
|
||||
return await create_streaming_response(
|
||||
return await create_response(
|
||||
generator=async_assistants_data_generator(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
response=response,
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Literal, Optional, Tuple, Union
|
||||
from typing import Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
@ -185,6 +185,10 @@ DEFINED_PROMETHEUS_METRICS = Literal[
|
||||
"litellm_redis_daily_spend_update_queue_size",
|
||||
"litellm_in_memory_spend_update_queue_size",
|
||||
"litellm_redis_spend_update_queue_size",
|
||||
# Cache metrics
|
||||
"litellm_cache_hits_metric",
|
||||
"litellm_cache_misses_metric",
|
||||
"litellm_cached_tokens_metric",
|
||||
]
|
||||
|
||||
|
||||
@ -436,6 +440,21 @@ class PrometheusMetricLabels:
|
||||
|
||||
litellm_redis_spend_update_queue_size: List[str] = []
|
||||
|
||||
# Cache metrics - track cache hits, misses, and tokens served from cache
|
||||
_cache_metric_labels = [
|
||||
UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value,
|
||||
UserAPIKeyLabelNames.API_KEY_HASH.value,
|
||||
UserAPIKeyLabelNames.API_KEY_ALIAS.value,
|
||||
UserAPIKeyLabelNames.TEAM.value,
|
||||
UserAPIKeyLabelNames.TEAM_ALIAS.value,
|
||||
UserAPIKeyLabelNames.END_USER.value,
|
||||
UserAPIKeyLabelNames.USER.value,
|
||||
]
|
||||
|
||||
litellm_cache_hits_metric = _cache_metric_labels
|
||||
litellm_cache_misses_metric = _cache_metric_labels
|
||||
litellm_cached_tokens_metric = _cache_metric_labels
|
||||
|
||||
@staticmethod
|
||||
def get_labels(label_name: DEFINED_PROMETHEUS_METRICS) -> List[str]:
|
||||
default_labels = getattr(PrometheusMetricLabels, label_name)
|
||||
|
||||
@ -7718,6 +7718,11 @@ class ProviderConfigManager:
|
||||
return litellm.CometAPIEmbeddingConfig()
|
||||
elif litellm.LlmProviders.GITHUB_COPILOT == provider:
|
||||
return litellm.GithubCopilotEmbeddingConfig()
|
||||
elif litellm.LlmProviders.OPENROUTER == provider:
|
||||
from litellm.llms.openrouter.embedding.transformation import (
|
||||
OpenrouterEmbeddingConfig,
|
||||
)
|
||||
return OpenrouterEmbeddingConfig()
|
||||
elif litellm.LlmProviders.GIGACHAT == provider:
|
||||
return litellm.GigaChatEmbeddingConfig()
|
||||
elif litellm.LlmProviders.SAGEMAKER == provider:
|
||||
|
||||
@ -32,6 +32,23 @@
|
||||
}
|
||||
},
|
||||
"providers": {
|
||||
"abliteration": {
|
||||
"display_name": "Abliteration (`abliteration`)",
|
||||
"url": "https://docs.litellm.ai/docs/providers/abliteration",
|
||||
"endpoints": {
|
||||
"chat_completions": true,
|
||||
"messages": false,
|
||||
"responses": false,
|
||||
"embeddings": false,
|
||||
"image_generations": false,
|
||||
"audio_transcriptions": false,
|
||||
"audio_speech": false,
|
||||
"moderations": false,
|
||||
"batches": false,
|
||||
"rerank": false,
|
||||
"a2a": false
|
||||
}
|
||||
},
|
||||
"aiml": {
|
||||
"display_name": "AI/ML API (`aiml`)",
|
||||
"url": "https://docs.litellm.ai/docs/providers/aiml",
|
||||
@ -1559,7 +1576,7 @@
|
||||
"chat_completions": true,
|
||||
"messages": true,
|
||||
"responses": true,
|
||||
"embeddings": false,
|
||||
"embeddings": true,
|
||||
"image_generations": false,
|
||||
"audio_transcriptions": false,
|
||||
"audio_speech": false,
|
||||
|
||||
50
tests/litellm/llms/openai_like/test_abliteration_provider.py
Normal file
@ -0,0 +1,50 @@
|
||||
"""
|
||||
Unit tests for the Abliteration OpenAI-like provider.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../.."))
|
||||
)
|
||||
|
||||
from litellm.llms.openai_like.dynamic_config import create_config_class
|
||||
from litellm.llms.openai_like.json_loader import JSONProviderRegistry
|
||||
|
||||
ABLITERATION_BASE_URL = "https://api.abliteration.ai/v1"
|
||||
|
||||
|
||||
def _get_config():
|
||||
provider = JSONProviderRegistry.get("abliteration")
|
||||
assert provider is not None
|
||||
config_class = create_config_class(provider)
|
||||
return config_class()
|
||||
|
||||
|
||||
def test_abliteration_provider_registered():
|
||||
provider = JSONProviderRegistry.get("abliteration")
|
||||
assert provider is not None
|
||||
assert provider.base_url == ABLITERATION_BASE_URL
|
||||
assert provider.api_key_env == "ABLITERATION_API_KEY"
|
||||
|
||||
|
||||
def test_abliteration_resolves_env_api_key(monkeypatch):
|
||||
config = _get_config()
|
||||
monkeypatch.setenv("ABLITERATION_API_KEY", "test-key")
|
||||
api_base, api_key = config._get_openai_compatible_provider_info(None, None)
|
||||
assert api_base == ABLITERATION_BASE_URL
|
||||
assert api_key == "test-key"
|
||||
|
||||
|
||||
def test_abliteration_complete_url_appends_endpoint():
|
||||
config = _get_config()
|
||||
url = config.get_complete_url(
|
||||
api_base=ABLITERATION_BASE_URL,
|
||||
api_key="test-key",
|
||||
model="abliteration/abliterated-model",
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
stream=False,
|
||||
)
|
||||
assert url == f"{ABLITERATION_BASE_URL}/chat/completions"
|
||||
@ -32,3 +32,18 @@ def test_completion_openrouter_image_generation():
|
||||
.message.images[0]["image_url"]["url"]
|
||||
.startswith("data:image/png;base64,")
|
||||
)
|
||||
|
||||
|
||||
def test_openrouter_embedding():
|
||||
"""Test OpenRouter embeddings support."""
|
||||
litellm._turn_on_debug()
|
||||
resp = litellm.embedding(
|
||||
model="openrouter/openai/text-embedding-3-small",
|
||||
input=["Hello world", "How are you?"],
|
||||
)
|
||||
print(resp)
|
||||
assert resp is not None
|
||||
assert len(resp.data) == 2
|
||||
assert resp.data[0]["embedding"] is not None
|
||||
assert isinstance(resp.data[0]["embedding"], list)
|
||||
assert len(resp.data[0]["embedding"]) > 0
|
||||
|
||||
@ -530,7 +530,7 @@ class TestPassthroughCallTypeHandling:
|
||||
)
|
||||
assert (
|
||||
ProxyBaseLLMRequestProcessing._get_pre_call_type(route_type="aembedding")
|
||||
== "embeddings"
|
||||
== "embedding"
|
||||
)
|
||||
assert (
|
||||
ProxyBaseLLMRequestProcessing._get_pre_call_type(route_type="aresponses")
|
||||
|
||||
211
tests/test_litellm/integrations/test_prometheus_cache_metrics.py
Normal file
@ -0,0 +1,211 @@
|
||||
"""
|
||||
Unit tests for cache Prometheus metrics.
|
||||
|
||||
Run with: poetry run pytest tests/test_litellm/integrations/test_prometheus_cache_metrics.py -v
|
||||
"""
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from litellm.types.integrations.prometheus import UserAPIKeyLabelValues
|
||||
|
||||
|
||||
class TestPrometheusCacheMetrics:
|
||||
"""Tests for cache-related Prometheus metrics"""
|
||||
|
||||
@pytest.fixture
|
||||
def sample_enum_values(self):
|
||||
"""Create sample enum values for labels"""
|
||||
return UserAPIKeyLabelValues(
|
||||
end_user="test-end-user",
|
||||
hashed_api_key="test-key-hash",
|
||||
api_key_alias="test-key-alias",
|
||||
team="test-team",
|
||||
team_alias="test-team-alias",
|
||||
user="test-user",
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
|
||||
def test_cache_metrics_defined_in_types(self):
|
||||
"""Test that cache metrics are defined in DEFINED_PROMETHEUS_METRICS"""
|
||||
from litellm.types.integrations.prometheus import DEFINED_PROMETHEUS_METRICS
|
||||
from typing import get_args
|
||||
|
||||
defined_metrics = get_args(DEFINED_PROMETHEUS_METRICS)
|
||||
|
||||
assert "litellm_cache_hits_metric" in defined_metrics
|
||||
assert "litellm_cache_misses_metric" in defined_metrics
|
||||
assert "litellm_cached_tokens_metric" in defined_metrics
|
||||
|
||||
def test_cache_metric_labels_defined(self):
|
||||
"""Test that cache metric labels are properly defined"""
|
||||
from litellm.types.integrations.prometheus import PrometheusMetricLabels
|
||||
|
||||
# Verify labels are defined for each cache metric
|
||||
assert hasattr(PrometheusMetricLabels, "litellm_cache_hits_metric")
|
||||
assert hasattr(PrometheusMetricLabels, "litellm_cache_misses_metric")
|
||||
assert hasattr(PrometheusMetricLabels, "litellm_cached_tokens_metric")
|
||||
|
||||
# Verify labels include expected keys
|
||||
expected_labels = [
|
||||
"model",
|
||||
"hashed_api_key",
|
||||
"api_key_alias",
|
||||
"team",
|
||||
"team_alias",
|
||||
"end_user",
|
||||
"user",
|
||||
]
|
||||
for label in expected_labels:
|
||||
assert label in PrometheusMetricLabels.litellm_cache_hits_metric
|
||||
assert label in PrometheusMetricLabels.litellm_cache_misses_metric
|
||||
assert label in PrometheusMetricLabels.litellm_cached_tokens_metric
|
||||
|
||||
def test_increment_cache_metrics_on_cache_hit(self, sample_enum_values):
|
||||
"""Test that cache hit increments the correct metrics"""
|
||||
# Create mock for PrometheusLogger instance
|
||||
mock_logger = MagicMock()
|
||||
|
||||
# Import the method directly and bind it to our mock
|
||||
from litellm.integrations.prometheus import PrometheusLogger
|
||||
|
||||
# Create a mock standard logging payload with cache_hit=True
|
||||
standard_logging_payload = {
|
||||
"cache_hit": True,
|
||||
"total_tokens": 100,
|
||||
"prompt_tokens": 50,
|
||||
"completion_tokens": 50,
|
||||
"model_group": "openai",
|
||||
"request_tags": [],
|
||||
}
|
||||
|
||||
# Create mock metrics
|
||||
mock_logger.litellm_cache_hits_metric = MagicMock()
|
||||
mock_logger.litellm_cache_misses_metric = MagicMock()
|
||||
mock_logger.litellm_cached_tokens_metric = MagicMock()
|
||||
mock_logger.get_labels_for_metric = MagicMock(
|
||||
return_value=[
|
||||
"model",
|
||||
"hashed_api_key",
|
||||
"api_key_alias",
|
||||
"team",
|
||||
"team_alias",
|
||||
"end_user",
|
||||
"user",
|
||||
]
|
||||
)
|
||||
|
||||
# Call the method using unbound method approach
|
||||
PrometheusLogger._increment_cache_metrics(
|
||||
mock_logger,
|
||||
standard_logging_payload=standard_logging_payload,
|
||||
enum_values=sample_enum_values,
|
||||
)
|
||||
|
||||
# Verify cache hits metric was incremented
|
||||
mock_logger.litellm_cache_hits_metric.labels.assert_called()
|
||||
mock_logger.litellm_cache_hits_metric.labels().inc.assert_called_once()
|
||||
|
||||
# Verify cached tokens metric was incremented with total_tokens
|
||||
mock_logger.litellm_cached_tokens_metric.labels.assert_called()
|
||||
mock_logger.litellm_cached_tokens_metric.labels().inc.assert_called_once_with(
|
||||
100
|
||||
)
|
||||
|
||||
# Verify cache misses metric was NOT called
|
||||
mock_logger.litellm_cache_misses_metric.labels.assert_not_called()
|
||||
|
||||
def test_increment_cache_metrics_on_cache_miss(self, sample_enum_values):
|
||||
"""Test that cache miss increments the correct metrics"""
|
||||
# Create mock for PrometheusLogger instance
|
||||
mock_logger = MagicMock()
|
||||
|
||||
from litellm.integrations.prometheus import PrometheusLogger
|
||||
|
||||
# Create a mock standard logging payload with cache_hit=False
|
||||
standard_logging_payload = {
|
||||
"cache_hit": False,
|
||||
"total_tokens": 100,
|
||||
"prompt_tokens": 50,
|
||||
"completion_tokens": 50,
|
||||
"model_group": "openai",
|
||||
"request_tags": [],
|
||||
}
|
||||
|
||||
# Create mock metrics
|
||||
mock_logger.litellm_cache_hits_metric = MagicMock()
|
||||
mock_logger.litellm_cache_misses_metric = MagicMock()
|
||||
mock_logger.litellm_cached_tokens_metric = MagicMock()
|
||||
mock_logger.get_labels_for_metric = MagicMock(
|
||||
return_value=[
|
||||
"model",
|
||||
"hashed_api_key",
|
||||
"api_key_alias",
|
||||
"team",
|
||||
"team_alias",
|
||||
"end_user",
|
||||
"user",
|
||||
]
|
||||
)
|
||||
|
||||
# Call the method
|
||||
PrometheusLogger._increment_cache_metrics(
|
||||
mock_logger,
|
||||
standard_logging_payload=standard_logging_payload,
|
||||
enum_values=sample_enum_values,
|
||||
)
|
||||
|
||||
# Verify cache misses metric was incremented
|
||||
mock_logger.litellm_cache_misses_metric.labels.assert_called()
|
||||
mock_logger.litellm_cache_misses_metric.labels().inc.assert_called_once()
|
||||
|
||||
# Verify cache hits and cached tokens metrics were NOT called
|
||||
mock_logger.litellm_cache_hits_metric.labels.assert_not_called()
|
||||
mock_logger.litellm_cached_tokens_metric.labels.assert_not_called()
|
||||
|
||||
def test_increment_cache_metrics_when_cache_hit_is_none(self, sample_enum_values):
|
||||
"""Test that no metrics are incremented when cache_hit is None"""
|
||||
# Create mock for PrometheusLogger instance
|
||||
mock_logger = MagicMock()
|
||||
|
||||
from litellm.integrations.prometheus import PrometheusLogger
|
||||
|
||||
# Create a mock standard logging payload with cache_hit=None
|
||||
standard_logging_payload = {
|
||||
"cache_hit": None,
|
||||
"total_tokens": 100,
|
||||
"prompt_tokens": 50,
|
||||
"completion_tokens": 50,
|
||||
"model_group": "openai",
|
||||
"request_tags": [],
|
||||
}
|
||||
|
||||
# Create mock metrics
|
||||
mock_logger.litellm_cache_hits_metric = MagicMock()
|
||||
mock_logger.litellm_cache_misses_metric = MagicMock()
|
||||
mock_logger.litellm_cached_tokens_metric = MagicMock()
|
||||
mock_logger.get_labels_for_metric = MagicMock(
|
||||
return_value=[
|
||||
"model",
|
||||
"hashed_api_key",
|
||||
"api_key_alias",
|
||||
"team",
|
||||
"team_alias",
|
||||
"end_user",
|
||||
"user",
|
||||
]
|
||||
)
|
||||
|
||||
# Call the method
|
||||
PrometheusLogger._increment_cache_metrics(
|
||||
mock_logger,
|
||||
standard_logging_payload=standard_logging_payload,
|
||||
enum_values=sample_enum_values,
|
||||
)
|
||||
|
||||
# Verify NO metrics were called
|
||||
mock_logger.litellm_cache_hits_metric.labels.assert_not_called()
|
||||
mock_logger.litellm_cache_misses_metric.labels.assert_not_called()
|
||||
mock_logger.litellm_cached_tokens_metric.labels.assert_not_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
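The behavior these tests pin down can be summarized in a short sketch. The implementation of `_increment_cache_metrics` is not shown in this diff, so the code below is inferred from the assertions only; it uses `collections.Counter` as a stand-in for Prometheus counters, and all names are hypothetical.

```python
from collections import Counter
from typing import Any, Dict, Tuple

# Stand-ins for the three Prometheus counters, keyed by a label tuple.
cache_hits: Counter = Counter()
cache_misses: Counter = Counter()
cached_tokens: Counter = Counter()


def increment_cache_metrics(payload: Dict[str, Any], labels: Tuple[str, ...]) -> None:
    """Update hit/miss counters according to the payload's "cache_hit" field."""
    cache_hit = payload.get("cache_hit")
    if cache_hit is None:
        # Caching was not involved in this request: record nothing.
        return
    if cache_hit:
        cache_hits[labels] += 1
        # Tokens served from cache are counted using total_tokens.
        cached_tokens[labels] += payload.get("total_tokens") or 0
    else:
        cache_misses[labels] += 1
```

The explicit `is None` check matters: treating `None` as falsy would wrongly count uncached requests as misses.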
@ -393,6 +393,63 @@ def test_get_request_tags_from_metadata_and_litellm_metadata():
|
||||
assert "User-Agent: litellm/1.0.0" in tags
|
||||
|
||||
|
||||
def test_get_request_tags_does_not_mutate_original_tags():
|
||||
"""
|
||||
Test that _get_request_tags does not mutate the original tags list in metadata.
|
||||
|
||||
This is a regression test for a bug where calling _get_request_tags multiple times
|
||||
would cause User-Agent tags to be duplicated because the function was mutating
|
||||
the original tags list instead of creating a copy.
|
||||
"""
|
||||
from litellm.litellm_core_utils.litellm_logging import StandardLoggingPayloadSetup
|
||||
|
||||
# Create metadata with original tags
|
||||
original_tags = ["custom-tag-1", "custom-tag-2"]
|
||||
metadata = {"tags": original_tags}
|
||||
litellm_params = {"metadata": metadata}
|
||||
proxy_server_request = {
|
||||
"headers": {
|
||||
"user-agent": "AsyncOpenAI/Python 1.99.9",
|
||||
}
|
||||
}
|
||||
|
||||
# Call _get_request_tags multiple times (simulating multiple callbacks)
|
||||
tags1 = StandardLoggingPayloadSetup._get_request_tags(
|
||||
litellm_params=litellm_params,
|
||||
proxy_server_request=proxy_server_request,
|
||||
)
|
||||
tags2 = StandardLoggingPayloadSetup._get_request_tags(
|
||||
litellm_params=litellm_params,
|
||||
proxy_server_request=proxy_server_request,
|
||||
)
|
||||
tags3 = StandardLoggingPayloadSetup._get_request_tags(
|
||||
litellm_params=litellm_params,
|
||||
proxy_server_request=proxy_server_request,
|
||||
)
|
||||
|
||||
# Verify the original tags list was NOT mutated
|
||||
assert original_tags == ["custom-tag-1", "custom-tag-2"], (
|
||||
f"Original tags list was mutated: {original_tags}"
|
||||
)
|
||||
assert metadata["tags"] == ["custom-tag-1", "custom-tag-2"], (
|
||||
f"metadata['tags'] was mutated: {metadata['tags']}"
|
||||
)
|
||||
|
||||
# Verify each returned list has exactly 2 User-Agent tags (not duplicated)
|
||||
user_agent_count_1 = len([t for t in tags1 if t.startswith("User-Agent:")])
|
||||
user_agent_count_2 = len([t for t in tags2 if t.startswith("User-Agent:")])
|
||||
user_agent_count_3 = len([t for t in tags3 if t.startswith("User-Agent:")])
|
||||
|
||||
assert user_agent_count_1 == 2, f"Expected 2 User-Agent tags, got {user_agent_count_1}"
|
||||
assert user_agent_count_2 == 2, f"Expected 2 User-Agent tags, got {user_agent_count_2}"
|
||||
assert user_agent_count_3 == 2, f"Expected 2 User-Agent tags, got {user_agent_count_3}"
|
||||
|
||||
# Verify all returned lists are independent (different objects)
|
||||
assert tags1 is not tags2
|
||||
assert tags2 is not tags3
|
||||
assert tags1 is not original_tags
|
||||
|
||||
|
||||
def test_get_extra_header_tags():
|
||||
"""Test the _get_extra_header_tags method with various scenarios."""
|
||||
import litellm
|
||||
|
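The regression test above guards against in-place mutation of the caller's tags list. A minimal sketch of the copy-before-append pattern it protects follows. `build_request_tags` is a hypothetical helper, not LiteLLM's actual `_get_request_tags`, and the two-tag User-Agent format is an assumption inferred from the test's expected count of 2.

```python
from typing import Dict, List, Optional


def build_request_tags(
    metadata: Dict[str, List[str]], user_agent: Optional[str]
) -> List[str]:
    """Hypothetical helper: return request tags without touching metadata["tags"]."""
    original = metadata.get("tags") or []
    # Copy first: appending to `original` directly would leak User-Agent tags
    # into the caller's list and duplicate them on every subsequent call.
    tags = list(original)
    if user_agent:
        # Assumed format: one tag for the product name, one for the full string.
        product = user_agent.split("/", 1)[0]
        tags.append(f"User-Agent: {product}")
        tags.append(f"User-Agent: {user_agent}")
    return tags
```

Because each call builds a fresh list, repeated calls from several callbacks return equal but independent results, which is the property the `is not` assertions in the test check.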
||||
@ -0,0 +1,132 @@
|
||||
"""
|
||||
Unit tests for OpenRouter embedding transformation logic.
|
||||
"""
|
||||
from litellm.llms.openrouter.embedding.transformation import (
|
||||
OpenrouterEmbeddingConfig,
|
||||
)
|
||||
|
||||
|
||||
def test_openrouter_embedding_supported_params():
|
||||
"""Test that supported OpenAI params are correctly defined."""
|
||||
config = OpenrouterEmbeddingConfig()
|
||||
supported = config.get_supported_openai_params("test-model")
|
||||
|
||||
assert "timeout" in supported
|
||||
assert "dimensions" in supported
|
||||
assert "encoding_format" in supported
|
||||
assert "user" in supported
|
||||
|
||||
|
||||
def test_openrouter_embedding_transform_request():
|
||||
"""Test request transformation logic."""
|
||||
config = OpenrouterEmbeddingConfig()
|
||||
|
||||
# Test with string input
|
||||
result = config.transform_embedding_request(
|
||||
model="openrouter/google/text-embedding-004",
|
||||
input="Hello world",
|
||||
optional_params={},
|
||||
headers={},
|
||||
)
|
||||
|
||||
assert result["model"] == "google/text-embedding-004"
|
||||
assert result["input"] == ["Hello world"]
|
||||
|
||||
# Test with list input
|
||||
result = config.transform_embedding_request(
|
||||
model="google/text-embedding-004",
|
||||
input=["Hello", "World"],
|
||||
optional_params={"dimensions": 512},
|
||||
headers={},
|
||||
)
|
||||
|
||||
assert result["model"] == "google/text-embedding-004"
|
||||
assert result["input"] == ["Hello", "World"]
|
||||
assert result["dimensions"] == 512
|
||||
|
||||
|
||||
def test_openrouter_embedding_validate_environment():
|
||||
"""Test environment validation and header setup."""
|
||||
config = OpenrouterEmbeddingConfig()
|
||||
|
||||
# Test with API key
|
||||
headers = config.validate_environment(
|
||||
headers={"Custom-Header": "value"},
|
||||
model="test-model",
|
||||
messages=[],
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
api_key="test-api-key",
|
||||
)
|
||||
|
||||
# Should include OpenRouter-specific headers
|
||||
assert "HTTP-Referer" in headers
|
||||
assert "X-Title" in headers
|
||||
# Should include Content-Type header
|
||||
assert "Content-Type" in headers
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
# Should include Authorization header
|
||||
assert "Authorization" in headers
|
||||
assert headers["Authorization"] == "Bearer test-api-key"
|
||||
# Should preserve custom headers
|
||||
assert headers["Custom-Header"] == "value"
|
||||
|
||||
# Test without API key
|
||||
headers_no_key = config.validate_environment(
|
||||
headers={},
|
||||
model="test-model",
|
||||
messages=[],
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
api_key=None,
|
||||
)
|
||||
|
||||
# Should still include OpenRouter headers but not Authorization
|
||||
assert "HTTP-Referer" in headers_no_key
|
||||
assert "X-Title" in headers_no_key
|
||||
assert "Content-Type" in headers_no_key
|
||||
assert "Authorization" not in headers_no_key
|
||||
|
||||
|
||||
def test_openrouter_embedding_get_complete_url():
|
||||
"""Test URL construction."""
|
||||
config = OpenrouterEmbeddingConfig()
|
||||
|
||||
url = config.get_complete_url(
|
||||
api_base="https://openrouter.ai/api/v1",
|
||||
api_key="test-key",
|
||||
model="test-model",
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
)
|
||||
|
||||
assert url == "https://openrouter.ai/api/v1/embeddings"
|
||||
|
||||
# Test with trailing slash
|
||||
url = config.get_complete_url(
|
||||
api_base="https://openrouter.ai/api/v1/",
|
||||
api_key="test-key",
|
||||
model="test-model",
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
)
|
||||
|
||||
assert url == "https://openrouter.ai/api/v1/embeddings"
|
||||
|
||||
|
||||
def test_openrouter_embedding_map_params():
|
||||
"""Test parameter mapping."""
|
||||
config = OpenrouterEmbeddingConfig()
|
||||
|
||||
result = config.map_openai_params(
|
||||
non_default_params={"dimensions": 512, "timeout": 30, "unsupported": "value"},
|
||||
optional_params={},
|
||||
model="test-model",
|
||||
drop_params=False,
|
||||
)
|
||||
|
||||
# Supported params should be included
|
||||
assert result["dimensions"] == 512
|
||||
assert result["timeout"] == 30
|
||||
# Unsupported params should not be included
|
||||
assert "unsupported" not in result
|
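The parameter-mapping test above expects supported parameters to be copied through and unknown ones to be left out. A rough sketch of that filtering step, assuming a hypothetical `map_supported_params` function and a supported list taken from the names asserted in `test_openrouter_embedding_supported_params` (the real `OpenrouterEmbeddingConfig.map_openai_params` may differ, for example in how it treats `drop_params`):

```python
from typing import Any, Dict, List

# Names asserted as supported in the test above; the real list may be longer.
SUPPORTED_PARAMS: List[str] = ["timeout", "dimensions", "encoding_format", "user"]


def map_supported_params(
    non_default_params: Dict[str, Any], optional_params: Dict[str, Any]
) -> Dict[str, Any]:
    """Hypothetical helper: copy only supported params into optional_params."""
    for name, value in non_default_params.items():
        if name in SUPPORTED_PARAMS:
            optional_params[name] = value
    return optional_params
```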
||||
@ -2,7 +2,6 @@
|
||||
Unit tests for Qualifire guardrail integration.
|
||||
"""
|
||||
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@ -75,9 +74,37 @@ class TestQualifireGuardrailInit:
|
||||
|
||||
assert guardrail.on_flagged == "monitor"
|
||||
|
||||
def test_init_with_default_api_base(self):
|
||||
"""Test that default API base is set when not provided."""
|
||||
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
|
||||
DEFAULT_QUALIFIRE_API_BASE,
|
||||
QualifireGuardrail,
|
||||
)
|
||||
|
||||
guardrail = QualifireGuardrail(
|
||||
api_key="test_key",
|
||||
guardrail_name="test_guardrail",
|
||||
)
|
||||
|
||||
assert guardrail.qualifire_api_base == DEFAULT_QUALIFIRE_API_BASE
|
||||
|
||||
def test_init_with_custom_api_base(self):
|
||||
"""Test initialization with custom API base URL."""
|
||||
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
|
||||
QualifireGuardrail,
|
||||
)
|
||||
|
||||
guardrail = QualifireGuardrail(
|
||||
api_key="test_key",
|
||||
api_base="https://custom.qualifire.ai",
|
||||
guardrail_name="test_guardrail",
|
||||
)
|
||||
|
||||
assert guardrail.qualifire_api_base == "https://custom.qualifire.ai"
|
||||
|
||||
|
||||
class TestQualifireGuardrailMessageConversion:
|
||||
"""Tests for message conversion to Qualifire format."""
|
||||
"""Tests for message conversion to API format."""
|
||||
|
||||
def test_convert_simple_messages(self):
|
||||
"""Test conversion of simple text messages."""
|
||||
@ -95,15 +122,13 @@ class TestQualifireGuardrailMessageConversion:
|
||||
{"role": "assistant", "content": "Hi there!"},
|
||||
]
|
||||
|
||||
# Create mock LLMMessage class
|
||||
mock_llm_message = MagicMock()
|
||||
result = guardrail._convert_messages_to_api_format(messages)
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire.QualifireGuardrail._convert_messages_to_qualifire_format"
|
||||
) as mock_convert:
|
||||
mock_convert.return_value = [mock_llm_message, mock_llm_message]
|
||||
result = guardrail._convert_messages_to_qualifire_format(messages)
|
||||
assert len(result) == 2
|
||||
assert len(result) == 2
|
||||
assert result[0]["role"] == "user"
|
||||
assert result[0]["content"] == "Hello, world!"
|
||||
assert result[1]["role"] == "assistant"
|
||||
assert result[1]["content"] == "Hi there!"
|
||||
|
||||
def test_convert_multimodal_messages(self):
|
||||
"""Test conversion of multimodal messages with text parts."""
|
||||
@ -126,112 +151,258 @@ class TestQualifireGuardrailMessageConversion:
|
||||
},
|
||||
]
|
||||
|
||||
with patch(
|
||||
"litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire.QualifireGuardrail._convert_messages_to_qualifire_format"
|
||||
) as mock_convert:
|
||||
mock_convert.return_value = [MagicMock()]
|
||||
result = guardrail._convert_messages_to_qualifire_format(messages)
|
||||
assert len(result) == 1
|
||||
result = guardrail._convert_messages_to_api_format(messages)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["role"] == "user"
|
||||
assert result[0]["content"] == "First part\nSecond part"
|
||||
|
||||
def test_convert_messages_with_tool_calls(self):
|
||||
"""Test conversion of messages with tool calls."""
|
||||
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
|
||||
QualifireGuardrail,
|
||||
)
|
||||
|
||||
guardrail = QualifireGuardrail(
|
||||
api_key="test_key",
|
||||
guardrail_name="test_guardrail",
|
||||
)
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": '{"location": "NYC"}',
|
||||
},
|
||||
}
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
result = guardrail._convert_messages_to_api_format(messages)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0]["role"] == "assistant"
|
||||
assert "tool_calls" in result[0]
|
||||
assert len(result[0]["tool_calls"]) == 1
|
||||
assert result[0]["tool_calls"][0]["id"] == "call_123"
|
||||
assert result[0]["tool_calls"][0]["name"] == "get_weather"
|
||||
assert result[0]["tool_calls"][0]["arguments"] == {"location": "NYC"}
|
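The assertions above imply that each OpenAI-style tool call is flattened to `id`, `name`, and a parsed `arguments` dict. A minimal sketch of such a conversion, assuming a hypothetical `flatten_tool_calls` helper; the fallback for unparseable arguments is an illustrative choice and is not taken from the Qualifire guardrail code:

```python
import json
from typing import Any, Dict, List


def flatten_tool_calls(tool_calls: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Hypothetical helper: flatten OpenAI tool calls to id/name/arguments."""
    flattened: List[Dict[str, Any]] = []
    for call in tool_calls:
        function = call.get("function") or {}
        raw_args = function.get("arguments", "{}")
        if isinstance(raw_args, str):
            try:
                # OpenAI sends arguments as a JSON-encoded string.
                arguments = json.loads(raw_args)
            except json.JSONDecodeError:
                # Assumption: keep the raw text rather than fail the whole check.
                arguments = {"raw": raw_args}
        else:
            arguments = raw_args
        flattened.append(
            {
                "id": call.get("id"),
                "name": function.get("name"),
                "arguments": arguments,
            }
        )
    return flattened
```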
||||
|
||||
|
||||
class TestQualifireGuardrailEvaluateKwargs:
|
||||
"""Tests for evaluate kwargs passed to Qualifire client."""
|
||||
class TestQualifireGuardrailToolConversion:
|
||||
"""Tests for tool definition conversion."""
|
||||
|
||||
def test_convert_openai_function_tools(self):
|
||||
"""Test conversion of OpenAI function tool format."""
|
||||
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
|
||||
QualifireGuardrail,
|
||||
)
|
||||
|
||||
guardrail = QualifireGuardrail(
|
||||
api_key="test_key",
|
||||
guardrail_name="test_guardrail",
|
||||
)
|
||||
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"description": "Get weather for a location",
|
||||
"parameters": {"type": "object", "properties": {}},
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
result = guardrail._convert_tools_to_api_format(tools)
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 1
|
||||
assert result[0]["name"] == "get_weather"
|
||||
assert result[0]["description"] == "Get weather for a location"
|
||||
|
||||
def test_convert_empty_tools(self):
|
||||
"""Test that empty tools returns None."""
|
||||
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
|
||||
QualifireGuardrail,
|
||||
)
|
||||
|
||||
guardrail = QualifireGuardrail(
|
||||
api_key="test_key",
|
||||
guardrail_name="test_guardrail",
|
||||
)
|
||||
|
||||
result = guardrail._convert_tools_to_api_format(None)
|
||||
assert result is None
|
||||
|
||||
result = guardrail._convert_tools_to_api_format([])
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestQualifireGuardrailAPICall:
|
||||
"""Tests for API call with httpx client."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_evaluate_called_with_prompt_injections(self):
|
||||
"""Test that evaluate is called with prompt_injections enabled."""
|
||||
# Mock the qualifire module and its types
|
||||
mock_qualifire_types = MagicMock()
|
||||
mock_llm_message = MagicMock()
|
||||
mock_llm_tool_call = MagicMock()
|
||||
mock_message_instance = MagicMock()
|
||||
mock_llm_message.return_value = mock_message_instance
|
||||
|
||||
mock_qualifire_types.LLMMessage = mock_llm_message
|
||||
mock_qualifire_types.LLMToolCall = mock_llm_tool_call
|
||||
|
||||
with patch.dict('sys.modules', {'qualifire': MagicMock(), 'qualifire.types': mock_qualifire_types}):
|
||||
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
|
||||
QualifireGuardrail,
|
||||
)
|
||||
"""Test that evaluate endpoint is called with prompt_injections enabled."""
|
||||
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
|
||||
QualifireGuardrail,
|
||||
)
|
||||
|
||||
guardrail = QualifireGuardrail(
|
||||
api_key="test_key",
|
||||
prompt_injections=True,
|
||||
guardrail_name="test_guardrail",
|
||||
)
|
||||
guardrail = QualifireGuardrail(
|
||||
api_key="test_key",
|
||||
prompt_injections=True,
|
||||
guardrail_name="test_guardrail",
|
||||
)
|
||||
|
||||
# Mock the client
|
||||
mock_client = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.score = 100
|
||||
mock_result.status = "completed"
|
||||
mock_result.evaluationResults = []
|
||||
mock_client.evaluate.return_value = mock_result
|
||||
guardrail._client = mock_client
|
||||
# Mock the async HTTP handler
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"score": 100,
|
||||
"status": "completed",
|
||||
"evaluationResults": [],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
guardrail.async_handler.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
messages = [{"role": "user", "content": "Hello, world!"}]
|
||||
messages = [{"role": "user", "content": "Hello, world!"}]
|
||||
|
||||
await guardrail._run_qualifire_check(
|
||||
messages=messages, output=None, dynamic_params={}
|
||||
)
|
||||
await guardrail._run_qualifire_check(
|
||||
messages=messages, output=None, dynamic_params={}
|
||||
)
|
||||
|
||||
# Verify evaluate was called with correct kwargs
|
||||
mock_client.evaluate.assert_called_once()
|
||||
call_kwargs = mock_client.evaluate.call_args[1]
|
||||
assert call_kwargs["prompt_injections"] is True
|
||||
assert "messages" in call_kwargs
|
||||
# Verify the API was called
|
||||
guardrail.async_handler.post.assert_called_once()
|
||||
call_kwargs = guardrail.async_handler.post.call_args[1]
|
||||
|
||||
assert "json" in call_kwargs
|
||||
payload = call_kwargs["json"]
|
||||
assert payload["prompt_injections"] is True
|
||||
assert "messages" in payload
|
||||
assert call_kwargs["url"].endswith("/api/evaluation/evaluate")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_evaluate_called_with_multiple_checks(self):
|
||||
"""Test that evaluate is called with multiple checks enabled."""
|
||||
# Mock the qualifire module and its types
|
||||
mock_qualifire_types = MagicMock()
|
||||
mock_llm_message = MagicMock()
|
||||
mock_llm_tool_call = MagicMock()
|
||||
mock_message_instance = MagicMock()
|
||||
mock_llm_message.return_value = mock_message_instance
|
||||
|
||||
mock_qualifire_types.LLMMessage = mock_llm_message
|
||||
mock_qualifire_types.LLMToolCall = mock_llm_tool_call
|
||||
|
||||
with patch.dict('sys.modules', {'qualifire': MagicMock(), 'qualifire.types': mock_qualifire_types}):
|
||||
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
|
||||
QualifireGuardrail,
|
||||
)
|
||||
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
|
||||
QualifireGuardrail,
|
||||
)
|
||||
|
||||
guardrail = QualifireGuardrail(
|
||||
api_key="test_key",
|
||||
prompt_injections=True,
|
||||
pii_check=True,
|
||||
hallucinations_check=True,
|
||||
assertions=["Output must be valid JSON"],
|
||||
guardrail_name="test_guardrail",
|
||||
)
|
||||
guardrail = QualifireGuardrail(
|
||||
api_key="test_key",
|
||||
prompt_injections=True,
|
||||
pii_check=True,
|
||||
hallucinations_check=True,
|
||||
assertions=["Output must be valid JSON"],
|
||||
guardrail_name="test_guardrail",
|
||||
)
|
||||
|
||||
# Mock the client
|
||||
mock_client = MagicMock()
|
||||
mock_result = MagicMock()
|
||||
mock_result.score = 100
|
||||
mock_result.status = "completed"
|
||||
mock_result.evaluationResults = []
|
||||
mock_client.evaluate.return_value = mock_result
|
||||
guardrail._client = mock_client
|
||||
# Mock the async HTTP handler
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"score": 100,
|
||||
"status": "completed",
|
||||
"evaluationResults": [],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
guardrail.async_handler.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
messages = [{"role": "user", "content": "Hello, world!"}]
|
||||
messages = [{"role": "user", "content": "Hello, world!"}]
|
||||
|
||||
await guardrail._run_qualifire_check(
|
||||
messages=messages, output="Test output", dynamic_params={}
|
||||
)
|
||||
await guardrail._run_qualifire_check(
|
||||
messages=messages, output="Test output", dynamic_params={}
|
||||
)
|
||||
|
||||
# Verify evaluate was called with correct kwargs
|
||||
mock_client.evaluate.assert_called_once()
|
||||
call_kwargs = mock_client.evaluate.call_args[1]
|
||||
assert call_kwargs["prompt_injections"] is True
|
||||
assert call_kwargs["pii_check"] is True
|
||||
assert call_kwargs["hallucinations_check"] is True
|
||||
assert call_kwargs["assertions"] == ["Output must be valid JSON"]
|
||||
assert call_kwargs["output"] == "Test output"
|
||||
# Verify the API was called with correct payload
|
||||
guardrail.async_handler.post.assert_called_once()
|
||||
call_kwargs = guardrail.async_handler.post.call_args[1]
|
||||
|
||||
payload = call_kwargs["json"]
|
||||
assert payload["prompt_injections"] is True
|
||||
assert payload["pii_check"] is True
|
||||
assert payload["hallucinations_check"] is True
|
||||
assert payload["assertions"] == ["Output must be valid JSON"]
|
||||
assert payload["output"] == "Test output"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invoke_endpoint_used_with_evaluation_id(self):
|
||||
"""Test that invoke endpoint is used when evaluation_id is provided."""
|
||||
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
|
||||
QualifireGuardrail,
|
||||
)
|
||||
|
||||
guardrail = QualifireGuardrail(
|
||||
api_key="test_key",
|
||||
evaluation_id="eval_123",
|
||||
guardrail_name="test_guardrail",
|
||||
)
|
||||
|
||||
# Mock the async HTTP handler
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"score": 100,
|
||||
"status": "completed",
|
||||
"evaluationResults": [],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
guardrail.async_handler.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
messages = [{"role": "user", "content": "Hello, world!"}]
|
||||
|
||||
await guardrail._run_qualifire_check(
|
||||
messages=messages, output="Test output", dynamic_params={}
|
||||
)
|
||||
|
||||
# Verify the invoke endpoint was called
|
||||
guardrail.async_handler.post.assert_called_once()
|
||||
call_kwargs = guardrail.async_handler.post.call_args[1]
|
||||
|
||||
assert call_kwargs["url"].endswith("/api/evaluation/invoke")
|
||||
payload = call_kwargs["json"]
|
||||
assert payload["evaluation_id"] == "eval_123"
|
||||
assert payload["input"] == "Hello, world!"
|
||||
assert payload["output"] == "Test output"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_correct_headers_sent(self):
|
||||
"""Test that correct headers are sent with the API request."""
|
||||
from litellm.proxy.guardrails.guardrail_hooks.qualifire.qualifire import (
|
||||
QualifireGuardrail,
|
||||
)
|
||||
|
||||
guardrail = QualifireGuardrail(
|
||||
api_key="my_api_key",
|
||||
guardrail_name="test_guardrail",
|
||||
)
|
||||
|
||||
# Mock the async HTTP handler
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"score": 100,
|
||||
"status": "completed",
|
||||
"evaluationResults": [],
|
||||
}
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
guardrail.async_handler.post = AsyncMock(return_value=mock_response)
|
||||
|
||||
messages = [{"role": "user", "content": "Hello!"}]
|
||||
|
||||
await guardrail._run_qualifire_check(
|
||||
messages=messages, output=None, dynamic_params={}
|
||||
)
|
||||
|
||||
call_kwargs = guardrail.async_handler.post.call_args[1]
|
||||
headers = call_kwargs["headers"]
|
||||
|
||||
assert headers["X-Qualifire-API-Key"] == "my_api_key"
|
||||
assert headers["Content-Type"] == "application/json"
|
||||
|
||||
|
||||
class TestQualifireGuardrailCheckIfFlagged:
|
||||
@ -248,12 +419,14 @@ class TestQualifireGuardrailCheckIfFlagged:
|
||||
guardrail_name="test_guardrail",
|
||||
)
|
||||
|
||||
# Mock result with completed status and no flagged items
|
||||
mock_result = MagicMock()
|
||||
mock_result.status = "completed"
|
||||
mock_result.evaluationResults = []
|
||||
# Result with completed status and no flagged items (dict format)
|
||||
result = {
|
||||
"status": "completed",
|
||||
"score": 100,
|
||||
"evaluationResults": [],
|
||||
}
|
||||
|
||||
assert guardrail._check_if_flagged(mock_result) is False
|
||||
assert guardrail._check_if_flagged(result) is False
|
||||
|
||||
def test_check_if_flagged_returns_true_for_flagged_content(self):
|
||||
"""Test that _check_if_flagged returns True when content is flagged."""
|
||||
@ -266,18 +439,25 @@ class TestQualifireGuardrailCheckIfFlagged:
|
||||
guardrail_name="test_guardrail",
|
||||
)
|
||||
|
||||
# Mock result with flagged item
|
||||
mock_inner_result = MagicMock()
|
||||
mock_inner_result.flagged = True
|
||||
# Result with flagged item (dict format matching API response)
|
||||
result = {
|
||||
"status": "completed",
|
||||
"score": 15,
|
||||
"evaluationResults": [
|
||||
{
|
||||
"type": "prompt_injection",
|
||||
"results": [
|
||||
{
|
||||
"flagged": True,
|
||||
"score": 0.15,
|
||||
"reason": "Prompt injection detected",
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
mock_eval_result = MagicMock()
|
||||
mock_eval_result.results = [mock_inner_result]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.status = "completed"
|
||||
mock_result.evaluationResults = [mock_eval_result]
|
||||
|
||||
assert guardrail._check_if_flagged(mock_result) is True
|
||||
assert guardrail._check_if_flagged(result) is True
|
||||
|
||||
def test_check_if_flagged_returns_false_when_no_flagged_items(self):
|
||||
"""Test that _check_if_flagged returns False when no items are flagged."""
|
||||
@ -291,17 +471,24 @@ class TestQualifireGuardrailCheckIfFlagged:
|
||||
)
|
||||
|
||||
# Result with evaluation results but nothing flagged
|
||||
mock_inner_result = MagicMock()
|
||||
mock_inner_result.flagged = False
|
||||
result = {
|
||||
"status": "completed",
|
||||
"score": 95,
|
||||
"evaluationResults": [
|
||||
{
|
||||
"type": "prompt_injection",
|
||||
"results": [
|
||||
{
|
||||
"flagged": False,
|
||||
"score": 0.95,
|
||||
"reason": "No issues detected",
|
||||
}
|
||||
],
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
mock_eval_result = MagicMock()
|
||||
mock_eval_result.results = [mock_inner_result]
|
||||
|
||||
mock_result = MagicMock()
|
||||
mock_result.status = "success"
|
||||
mock_result.evaluationResults = [mock_eval_result]
|
||||
|
||||
assert guardrail._check_if_flagged(mock_result) is False
|
||||
assert guardrail._check_if_flagged(result) is False
|
||||
|
||||
|
||||
class TestQualifireGuardrailShouldRun:
|
||||
|
||||
@ -3,7 +3,7 @@ from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import Request, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
|
||||
import litellm
|
||||
from litellm._uuid import uuid
|
||||
@ -11,9 +11,10 @@ from litellm.integrations.opentelemetry import UserAPIKeyAuth
|
||||
from litellm.proxy.common_request_processing import (
|
||||
ProxyBaseLLMRequestProcessing,
|
||||
ProxyConfig,
|
||||
_extract_error_from_sse_chunk,
|
||||
_get_cost_breakdown_from_logging_obj,
|
||||
_parse_event_data_for_error,
|
||||
create_streaming_response,
|
||||
create_response,
|
||||
)
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
|
||||
@ -680,21 +681,27 @@ class TestCommonRequestProcessingHelpers:
|
||||
assert await _parse_event_data_for_error(event_line) == expected_code
|
||||
|
||||
async def test_create_streaming_response_first_chunk_is_error(self):
|
||||
"""
|
||||
Test that when the first chunk is an error, a JSON error response is returned
|
||||
instead of an SSE streaming response
|
||||
"""
|
||||
async def mock_generator():
|
||||
yield 'data: {"error": {"code": 403, "message": "forbidden"}}\n\n'
|
||||
yield 'data: {"content": "more data"}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
response = await create_streaming_response(
|
||||
response = await create_response(
|
||||
mock_generator(), "text/event-stream", {}
|
||||
)
|
||||
# Should return JSONResponse instead of StreamingResponse
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.status_code == status.HTTP_403_FORBIDDEN
|
||||
content = await self.consume_stream(response)
|
||||
assert content == [
|
||||
'data: {"error": {"code": 403, "message": "forbidden"}}\n\n',
|
||||
'data: {"content": "more data"}\n\n',
|
||||
"data: [DONE]\n\n",
|
||||
]
|
||||
# Verify the response is in standard JSON error format
|
||||
import json
|
||||
body = json.loads(response.body.decode())
|
||||
assert "error" in body
|
||||
assert body["error"]["code"] == 403
|
||||
assert body["error"]["message"] == "forbidden"
|
||||
|
||||
async def test_create_streaming_response_first_chunk_not_error(self):
|
||||
async def mock_generator():
|
||||
@ -702,7 +709,7 @@ class TestCommonRequestProcessingHelpers:
|
||||
yield 'data: {"content": "second part"}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
response = await create_streaming_response(
|
||||
response = await create_response(
|
||||
mock_generator(), "text/event-stream", {}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@ -719,7 +726,7 @@ class TestCommonRequestProcessingHelpers:
|
||||
yield
|
||||
# Implicitly raises StopAsyncIteration
|
||||
|
||||
response = await create_streaming_response(
|
||||
response = await create_response(
|
||||
mock_generator(), "text/event-stream", {}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
@ -732,7 +739,7 @@ class TestCommonRequestProcessingHelpers:
|
||||
mock_gen = AsyncMock()
|
||||
mock_gen.__anext__.side_effect = StopAsyncIteration
|
||||
|
||||
response = await create_streaming_response(mock_gen, "text/event-stream", {})
|
||||
response = await create_response(mock_gen, "text/event-stream", {})
|
||||
assert response.status_code == status.HTTP_200_OK
|
||||
content = await self.consume_stream(response)
|
||||
assert content == []
|
||||
@ -743,7 +750,7 @@ class TestCommonRequestProcessingHelpers:
|
||||
mock_gen = AsyncMock()
|
||||
mock_gen.__anext__.side_effect = ValueError("Test error from generator")
|
||||
|
||||
response = await create_streaming_response(mock_gen, "text/event-stream", {})
|
||||
response = await create_response(mock_gen, "text/event-stream", {})
|
||||
assert response.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR
|
||||
content = await self.consume_stream(response)
|
||||
expected_error_data = {
|
||||
@ -760,19 +767,24 @@ class TestCommonRequestProcessingHelpers:
|
||||
assert content[1] == "data: [DONE]\n\n"
|
||||
|
||||
async def test_create_streaming_response_first_chunk_error_string_code(self):
|
||||
"""
|
||||
Test that when the first chunk contains a string error code, a JSON error response is returned
|
||||
"""
|
||||
async def mock_generator():
|
||||
yield 'data: {"error": {"code": "429", "message": "too many requests"}}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
response = await create_streaming_response(
|
||||
response = await create_response(
|
||||
mock_generator(), "text/event-stream", {}
|
||||
)
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS
|
||||
content = await self.consume_stream(response)
|
||||
assert content == [
|
||||
'data: {"error": {"code": "429", "message": "too many requests"}}\n\n',
|
||||
"data: [DONE]\n\n",
|
||||
]
|
||||
# Verify the response is in standard JSON error format
|
||||
import json
|
||||
body = json.loads(response.body.decode())
|
||||
assert "error" in body
|
||||
assert body["error"]["code"] == "429"
|
||||
assert body["error"]["message"] == "too many requests"
|
||||
|
||||
async def test_create_streaming_response_custom_headers(self):
|
||||
async def mock_generator():
|
||||
@ -780,7 +792,7 @@ class TestCommonRequestProcessingHelpers:
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
custom_headers = {"X-Custom-Header": "TestValue"}
|
||||
response = await create_streaming_response(
|
||||
response = await create_response(
|
||||
mock_generator(), "text/event-stream", custom_headers
|
||||
)
|
||||
assert response.headers["x-custom-header"] == "TestValue"
|
||||
@ -790,7 +802,7 @@ class TestCommonRequestProcessingHelpers:
|
||||
yield 'data: {"content": "data"}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
response = await create_streaming_response(
|
||||
response = await create_response(
|
||||
mock_generator(),
|
||||
"text/event-stream",
|
||||
{},
|
||||
@ -807,7 +819,7 @@ class TestCommonRequestProcessingHelpers:
|
||||
async def mock_generator():
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
response = await create_streaming_response(
|
||||
response = await create_response(
|
||||
mock_generator(), "text/event-stream", {}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK # Default status
|
||||
@ -820,7 +832,7 @@ class TestCommonRequestProcessingHelpers:
|
||||
yield 'data: {"content": "actual data"}\n\n'
|
||||
yield "data: [DONE]\n\n"
|
||||
|
||||
response = await create_streaming_response(
|
||||
response = await create_response(
|
||||
mock_generator(), "text/event-stream", {}
|
||||
)
|
||||
assert response.status_code == status.HTTP_200_OK # Default status
|
||||
@ -851,7 +863,7 @@ class TestCommonRequestProcessingHelpers:
|
||||
|
||||
# Patch the tracer in the common_request_processing module
|
||||
with patch("litellm.proxy.common_request_processing.tracer", mock_tracer):
|
||||
response = await create_streaming_response(
|
||||
response = await create_response(
|
||||
mock_generator(), "text/event-stream", {}
|
||||
)
|
||||
|
||||
@ -888,7 +900,10 @@ class TestCommonRequestProcessingHelpers:
|
||||
), f"Call {i} should have operation name 'streaming.chunk.yield', got {args[0]}"
|
||||
|
||||
async def test_create_streaming_response_dd_trace_with_error_chunk(self):
|
||||
"""Test that dd trace is applied even when the first chunk contains an error"""
|
||||
"""
|
||||
Test that when the first chunk contains an error, JSONResponse is returned
|
||||
and tracing is not triggered (since it's not a streaming response)
|
||||
"""
|
||||
from unittest.mock import patch
|
||||
|
||||
# Create a mock tracer
|
||||
@ -905,28 +920,107 @@ class TestCommonRequestProcessingHelpers:
|
||||
|
||||
# Patch the tracer in the common_request_processing module
|
||||
with patch("litellm.proxy.common_request_processing.tracer", mock_tracer):
|
||||
response = await create_streaming_response(
|
||||
response = await create_response(
|
||||
mock_generator(), "text/event-stream", {}
|
||||
)
|
||||
|
||||
# Even with error, status should be set to error code but tracing should still work
|
||||
# Should return JSONResponse instead of StreamingResponse
|
||||
assert isinstance(response, JSONResponse)
|
||||
assert response.status_code == 400
|
||||
|
||||
# Consume the stream to trigger the tracer calls
|
||||
content = await self.consume_stream(response)
|
||||
# Verify the response is in standard JSON error format
|
||||
import json
|
||||
body = json.loads(response.body.decode())
|
||||
assert "error" in body
|
||||
assert body["error"]["code"] == 400
|
||||
assert body["error"]["message"] == "bad request"
|
||||
|
||||
# Verify all chunks are present
|
||||
assert len(content) == 3
|
||||
# Since JSONResponse is returned instead of StreamingResponse, streaming tracing should not be triggered
|
||||
# tracer.trace should not be called
|
||||
assert mock_tracer.trace.call_count == 0
|
||||
|
||||
# Verify that tracer.trace was called for each chunk
|
||||
assert mock_tracer.trace.call_count == 3
|
||||
|
||||
# Verify that each call was made with the correct operation name
|
||||
actual_calls = mock_tracer.trace.call_args_list
|
||||
assert len(actual_calls) == 3
|
||||
class TestExtractErrorFromSSEChunk:
|
||||
"""Tests for _extract_error_from_sse_chunk function"""
|
||||
|
||||
def test_extract_error_from_sse_chunk_with_valid_error(self):
|
||||
"""Test extracting error information from a standard SSE chunk"""
|
||||
chunk = 'data: {"error": {"code": 403, "message": "forbidden", "type": "auth_error", "param": "api_key"}}\n\n'
|
||||
error = _extract_error_from_sse_chunk(chunk)
|
||||
|
||||
assert error["code"] == 403
|
||||
assert error["message"] == "forbidden"
|
||||
assert error["type"] == "auth_error"
|
||||
assert error["param"] == "api_key"
|
||||
|
||||
def test_extract_error_from_sse_chunk_with_string_code(self):
|
||||
"""Test error code as string type"""
|
||||
chunk = 'data: {"error": {"code": "429", "message": "too many requests"}}\n\n'
|
||||
error = _extract_error_from_sse_chunk(chunk)
|
||||
|
||||
assert error["code"] == "429"
|
||||
assert error["message"] == "too many requests"
|
||||
|
||||
def test_extract_error_from_sse_chunk_with_bytes(self):
|
||||
"""Test input as bytes type"""
|
||||
chunk = b'data: {"error": {"code": 500, "message": "internal error"}}\n\n'
|
||||
error = _extract_error_from_sse_chunk(chunk)
|
||||
|
||||
assert error["code"] == 500
|
||||
assert error["message"] == "internal error"
|
||||
|
||||
def test_extract_error_from_sse_chunk_with_done(self):
|
||||
"""Test [DONE] marker should return default error"""
|
||||
chunk = "data: [DONE]\n\n"
|
||||
error = _extract_error_from_sse_chunk(chunk)
|
||||
|
||||
assert error["message"] == "Unknown error"
|
||||
assert error["type"] == "internal_server_error"
|
||||
assert error["code"] == "500"
|
||||
assert error["param"] is None
|
||||
|
||||
def test_extract_error_from_sse_chunk_without_error_field(self):
|
||||
"""Test missing error field should return default error"""
|
||||
chunk = 'data: {"content": "some content"}\n\n'
|
||||
error = _extract_error_from_sse_chunk(chunk)
|
||||
|
||||
assert error["message"] == "Unknown error"
|
||||
assert error["type"] == "internal_server_error"
|
||||
assert error["code"] == "500"
|
||||
|
||||
def test_extract_error_from_sse_chunk_with_invalid_json(self):
|
||||
"""Test invalid JSON should return default error"""
|
||||
chunk = 'data: {invalid json}\n\n'
|
||||
error = _extract_error_from_sse_chunk(chunk)
|
||||
|
||||
assert error["message"] == "Unknown error"
|
||||
assert error["type"] == "internal_server_error"
|
||||
assert error["code"] == "500"
|
||||
|
||||
def test_extract_error_from_sse_chunk_without_data_prefix(self):
|
||||
"""Test missing 'data:' prefix should return default error"""
|
||||
chunk = '{"error": {"code": 400, "message": "bad request"}}\n\n'
|
||||
error = _extract_error_from_sse_chunk(chunk)
|
||||
|
||||
assert error["message"] == "Unknown error"
|
||||
assert error["type"] == "internal_server_error"
|
||||
assert error["code"] == "500"
|
||||
|
||||
def test_extract_error_from_sse_chunk_with_empty_string(self):
|
||||
"""Test empty string should return default error"""
|
||||
chunk = ""
|
||||
error = _extract_error_from_sse_chunk(chunk)
|
||||
|
||||
assert error["message"] == "Unknown error"
|
||||
assert error["type"] == "internal_server_error"
|
||||
assert error["code"] == "500"
|
||||
|
||||
def test_extract_error_from_sse_chunk_with_minimal_error(self):
|
||||
"""Test minimal error object"""
|
||||
chunk = 'data: {"error": {"message": "error occurred"}}\n\n'
|
||||
error = _extract_error_from_sse_chunk(chunk)
|
||||
|
||||
assert error["message"] == "error occurred"
|
||||
# Other fields should be obtained from the original error object (if exists)
|
||||
|
||||
|
||||
for i, call in enumerate(actual_calls):
|
||||
args, kwargs = call
|
||||
assert (
|
||||
args[0] == "streaming.chunk.yield"
|
||||
), f"Call {i} should have operation name 'streaming.chunk.yield', got {args[0]}"
|
||||
|
||||
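The `TestExtractErrorFromSSEChunk` cases above pin down the observable behavior of `_extract_error_from_sse_chunk` fairly tightly. As a reading aid, here is a minimal sketch that would satisfy those cases. It is not the implementation in `litellm.proxy.common_request_processing`; the function name `extract_error_from_sse_chunk_sketch` and the `DEFAULT_ERROR` constant are hypothetical, and the assumption that the parsed `error` object is returned unchanged is inferred only from the tests.

```python
import json
from typing import Any, Dict, Union

# Hypothetical constant: the fallback values asserted by the tests above.
DEFAULT_ERROR: Dict[str, Any] = {
    "message": "Unknown error",
    "type": "internal_server_error",
    "code": "500",
    "param": None,
}


def extract_error_from_sse_chunk_sketch(chunk: Union[str, bytes]) -> Dict[str, Any]:
    """Illustrative only: pull the `error` object out of one SSE chunk."""
    # Accept both str and bytes input, as the bytes test case requires.
    if isinstance(chunk, bytes):
        chunk = chunk.decode("utf-8", errors="replace")

    text = chunk.strip()
    # Anything without the SSE "data:" prefix falls back to the default.
    if not text.startswith("data:"):
        return dict(DEFAULT_ERROR)

    payload = text[len("data:"):].strip()
    # The end-of-stream marker carries no error information.
    if payload == "[DONE]":
        return dict(DEFAULT_ERROR)

    try:
        parsed = json.loads(payload)
    except json.JSONDecodeError:
        return dict(DEFAULT_ERROR)

    error = parsed.get("error") if isinstance(parsed, dict) else None
    if not isinstance(error, dict):
        return dict(DEFAULT_ERROR)

    # Assumption: the error object is returned as-is, so a string code
    # such as "429" keeps its original type.
    return error
```

Returning a copy of the default on every failure path keeps callers from mutating shared state, and leaving the parsed error untouched matches the string-code and minimal-error cases, where no fields are coerced or filled in.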