Litellm OSS Staging 010626 (#29422)

This commit is contained in:
Sameer Kankute 2026-06-02 10:12:51 +05:30 committed by GitHub
parent b7bbddbd4d
commit 5fd27141cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
49 changed files with 5146 additions and 136 deletions

View File

@ -106,6 +106,7 @@ GATEWAY_PATH_PREFIXES: tuple[str, ...] = (
# Health & ops
"/health",
"/metrics",
"/watsonx"
)
GATEWAY_EXACT_PATHS: frozenset[str] = frozenset(

View File

@ -771,6 +771,7 @@ openai_compatible_endpoints: List = [
"https://api.moonshot.ai/v1",
"https://api.publicai.co/v1",
"https://api.synthetic.new/openai/v1",
"https://serverless.tensormesh.ai/v1",
"https://api.stima.tech/v1",
"https://nano-gpt.com/api/v1",
"https://api.poe.com/v1",
@ -820,6 +821,7 @@ openai_compatible_providers: List = [
"meta_llama",
"publicai", # PublicAI - JSON-configured provider
"synthetic", # Synthetic - JSON-configured provider
"tensormesh", # Tensormesh - JSON-configured provider
"apertis", # Apertis - JSON-configured provider
"nano-gpt", # Nano-GPT - JSON-configured provider
"poe", # Poe - JSON-configured provider
@ -855,6 +857,7 @@ openai_text_completion_compatible_providers: List = (
"moonshot",
"publicai",
"synthetic",
"tensormesh",
"apertis",
"nano-gpt",
"poe",

View File

@ -95,7 +95,9 @@ class FocusTransformer:
pl.lit("Usage-Based").alias("ChargeFrequency"),
fmt(pl.col("ChargePeriodEnd")).alias("ChargePeriodEnd"),
fmt(pl.col("ChargePeriodStart")).alias("ChargePeriodStart"),
dec(pl.lit(1.0)).alias("ConsumedQuantity"),
dec(
pl.col("api_requests").cast(pl.Int64).cast(pl.Float64).fill_null(0.0)
).alias("ConsumedQuantity"),
pl.lit("Requests").alias("ConsumedUnit"),
dec(pl.col("spend").fill_null(0.0)).alias("ContractedCost"),
none_str.alias("ContractedUnitPrice"),
@ -107,7 +109,9 @@ class FocusTransformer:
none_str.alias("AvailabilityZone"),
pl.lit("USD").alias("PricingCurrency"),
none_str.alias("PricingCategory"),
dec(pl.lit(1.0)).alias("PricingQuantity"),
dec(
pl.col("api_requests").cast(pl.Int64).cast(pl.Float64).fill_null(0.0)
).alias("PricingQuantity"),
none_dec.alias("PricingCurrencyContractedUnitPrice"),
dec(pl.col("spend").fill_null(0.0)).alias("PricingCurrencyEffectiveCost"),
none_dec.alias("PricingCurrencyListUnitPrice"),

View File

@ -511,6 +511,23 @@ class PrometheusLogger(CustomLogger):
labelnames=self.get_labels_for_metric("litellm_cached_tokens_metric"),
)
# Provider prompt-caching metrics
self.litellm_provider_cache_read_input_tokens_metric = self._counter_factory(
name="litellm_provider_cache_read_input_tokens_metric",
documentation="Total prompt/input tokens read from provider prompt cache (e.g. OpenAI/Anthropic/Gemini/Bedrock)",
labelnames=self.get_labels_for_metric(
"litellm_provider_cache_read_input_tokens_metric"
),
)
self.litellm_provider_cache_creation_input_tokens_metric = self._counter_factory(
name="litellm_provider_cache_creation_input_tokens_metric",
documentation="Total prompt/input tokens written to provider prompt cache (e.g. Anthropic/Bedrock)",
labelnames=self.get_labels_for_metric(
"litellm_provider_cache_creation_input_tokens_metric"
),
)
# User and Team count metrics
self.litellm_total_users_metric = self._gauge_factory(
"litellm_total_users",
@ -1458,11 +1475,11 @@ class PrometheusLogger(CustomLogger):
"""
cache_hit = standard_logging_payload.get("cache_hit")
# Only track if cache_hit has a definite value (True or False)
if cache_hit is None:
return
if cache_hit is True:
# Historically these metrics only tracked LiteLLM caching.
# Provider prompt-caching metrics are still emitted below.
pass
elif cache_hit is True:
# Increment cache hits counter
PrometheusLogger._inc_labeled_counter(
self,
@ -1493,6 +1510,51 @@ class PrometheusLogger(CustomLogger):
label_context=label_context,
)
# Provider prompt caching metrics are independent of LiteLLM cache_hit.
provider_cache_read_tokens = 0
provider_cache_creation_tokens = 0
usage_obj = (standard_logging_payload.get("metadata", {}) or {}).get(
"usage_object"
)
if isinstance(usage_obj, dict):
# Prefer explicit provider cache fields when available.
_read = usage_obj.get("cache_read_input_tokens")
_write = usage_obj.get("cache_creation_input_tokens")
if isinstance(_read, int):
provider_cache_read_tokens = _read
if isinstance(_write, int):
provider_cache_creation_tokens = _write
# Fallback to prompt_tokens_details.cached_tokens (common normalization point).
# Only fallback when the explicit field is genuinely absent (None).
if _read is None:
prompt_details = usage_obj.get("prompt_tokens_details")
if isinstance(prompt_details, dict):
cached_tokens = prompt_details.get("cached_tokens")
if isinstance(cached_tokens, int):
provider_cache_read_tokens = cached_tokens
if provider_cache_read_tokens > 0:
PrometheusLogger._inc_labeled_counter(
self,
self.litellm_provider_cache_read_input_tokens_metric,
"litellm_provider_cache_read_input_tokens_metric",
enum_values,
label_context=label_context,
amount=float(provider_cache_read_tokens),
)
if provider_cache_creation_tokens > 0:
PrometheusLogger._inc_labeled_counter(
self,
self.litellm_provider_cache_creation_input_tokens_metric,
"litellm_provider_cache_creation_input_tokens_metric",
enum_values,
label_context=label_context,
amount=float(provider_cache_creation_tokens),
)
async def _increment_remaining_budget_metrics(
self,
user_api_team: Optional[str],

View File

@ -47,8 +47,9 @@ async def async_completion_with_fallbacks(**kwargs):
completion_kwargs = safe_deep_copy(base_kwargs)
# Handle dictionary fallback configurations
if isinstance(fallback, dict):
model = fallback.pop("model", original_model)
completion_kwargs.update(fallback)
fallback_config = safe_deep_copy(dict(fallback))
model = fallback_config.pop("model", original_model)
completion_kwargs.update(fallback_config)
else:
model = fallback

View File

@ -1534,10 +1534,9 @@ class BaseAWSLLM:
)
sigv4 = SigV4Auth(credentials, service_name, aws_region_name)
if headers is not None:
headers = headers or {}
if not any(header_name.lower() == "content-type" for header_name in headers):
headers = {"Content-Type": "application/json", **headers}
else:
headers = {"Content-Type": "application/json"}
aws_signature_headers = self._filter_headers_for_aws_signature(headers)
request = AWSRequest(

View File

@ -233,6 +233,259 @@ class BedrockFilesConfig(BaseAWSLLM, BaseFilesConfig):
# example; add others here as they adopt the same schema.
CONVERSE_INVOKE_PROVIDERS = ("nova",)
# OpenAI batch URL that signals an embedding request. Per OpenAI Batch API
# spec, every JSONL record carries a `url` field; we use it as the
# authoritative signal to route the line to the embedding code path
# instead of inferring from the presence of `input` vs `messages`.
OPENAI_EMBEDDINGS_URL = "/v1/embeddings"
@staticmethod
def _is_embedding_record(openai_jsonl_record: Dict[str, Any]) -> bool:
    """
    Report whether one OpenAI batch JSONL line should take the embedding path.

    The decision is made in three steps, and the first one that applies wins:

    1. ``url`` equals ``BedrockFilesConfig.OPENAI_EMBEDDINGS_URL``
       -> embedding record.
    2. ``url`` is any other truthy value (for example a chat-completions
       path) -> not an embedding record. The explicit ``url`` is trusted
       over whatever the body looks like.
    3. ``url`` is missing or empty -> inspect ``body``. The record counts
       as an embedding request only when ``body`` is a dict that has an
       ``input`` key and no ``messages`` key, so a record carrying both
       keys is not treated as an embedding request.

    Args:
        openai_jsonl_record: One parsed line of the OpenAI batch input file.

    Returns:
        ``True`` when the record should be routed to the embedding
        transformer, ``False`` otherwise.
    """
    declared_url = openai_jsonl_record.get("url")

    # An explicit url always decides the outcome on its own.
    if declared_url == BedrockFilesConfig.OPENAI_EMBEDDINGS_URL:
        return True
    if declared_url:
        return False

    # No usable url: infer from the shape of the request body.
    payload = openai_jsonl_record.get("body", {})
    if isinstance(payload, dict):
        return "input" in payload and "messages" not in payload
    return False
# Identifier for the Bedrock Titan v2 InvokeModel body schema as stored
# in `model_prices_and_context_window.json`. Centralized so future
# embedding-schema variants can add their own value
# (e.g. `cohere_v3`, `titan_g1`, `titan_multimodal`) without touching
# the detection logic.
_TITAN_V2_INVOCATION_SCHEMA = "titan_v2"
# Substring marker used as a fallback when the registry can't resolve
# the model id - notably cross-region inference profile prefixes
# (`us.amazon.titan-embed-text-v2:0`) and Bedrock ARN forms, which
# `get_model_info` doesn't normalize today.
_TITAN_V2_EMBED_MODEL_MARKER = "titan-embed-text-v2"
# Nested field name under `provider_specific_entry` that identifies the
# Bedrock InvokeModel body schema for batch inference.
# `provider_specific_entry` is the registry's escape hatch for fields
# `get_model_info` doesn't promote to top-level - exactly what we need
# here. Documented in the `sample_spec` entry of
# `model_prices_and_context_window.json` and surfaced by
# `get_model_info` (see `ModelInfo.provider_specific_entry`).
_BEDROCK_INVOCATION_SCHEMA_FIELD = "bedrock_invocation_schema"
@staticmethod
def _is_titan_v2_embed_model(model: str) -> bool:
    """
    Tell whether ``model`` identifies Amazon Titan Text Embeddings V2.

    Two strategies are tried in order:

    1. Registry lookup. ``_lookup_provider_specific_field`` is asked for the
       ``_BEDROCK_INVOCATION_SCHEMA_FIELD`` value of ``model``. When it
       returns a value, that value alone decides the answer (it must equal
       ``_TITAN_V2_INVOCATION_SCHEMA``); no substring check is attempted.
    2. Substring fallback, used only when the lookup returns ``None``. The
       id is lower-cased, a leading ``bedrock/`` is dropped, and the first
       occurrence of ``_TITAN_V2_EMBED_MODEL_MARKER`` must be followed by
       ``:``, ``/`` or the end of the string. That boundary accepts ids
       such as ``us.amazon.titan-embed-text-v2:0`` and rejects lookalikes
       such as ``titan-embed-text-v20``.

    Args:
        model: Model identifier taken from the request body.

    Returns:
        ``True`` when the id resolves to Titan Text Embeddings V2.
    """
    owner = BedrockFilesConfig

    # Registry first: a resolved schema value is authoritative either way.
    schema = owner._lookup_provider_specific_field(
        model, owner._BEDROCK_INVOCATION_SCHEMA_FIELD
    )
    if schema is not None:
        return schema == owner._TITAN_V2_INVOCATION_SCHEMA

    # Registry had nothing to say -> match on the id text itself.
    candidate = model.lower()
    route_prefix = "bedrock/"
    if candidate.startswith(route_prefix):
        candidate = candidate[len(route_prefix) :]

    _, matched, trailing = candidate.partition(owner._TITAN_V2_EMBED_MODEL_MARKER)
    if not matched:
        return False
    # The marker must end the id or be followed by a version/path separator.
    return trailing[:1] in ("", ":", "/")
@staticmethod
def _lookup_provider_specific_field(model_id: str, field: str) -> Optional[str]:
"""
Read a nested string field from the registry entry's
`provider_specific_entry` dict via `litellm.get_model_info`.
Returns the field's string value when:
- the registry resolves `model_id`,
- the entry exposes `provider_specific_entry` as a dict, and
- that dict has `field` mapped to a non-empty string.
Otherwise returns `None`.
Isolating this means feature detectors (Titan v2 today, future
Cohere Embed / Nova Multimodal branches) share one defensive
try/except shape instead of duplicating it. The `None` return
covers every realistic failure mode: `get_model_info` raises
(cross-region profile prefixes, Bedrock ARN forms, unreleased
models), returns a non-dict, has no `provider_specific_entry`, or
the requested field is missing / non-string / empty.
"""
try:
from litellm import get_model_info
info = get_model_info(model_id)
except Exception:
return None
if not isinstance(info, dict):
return None
provider_specific = info.get("provider_specific_entry")
if not isinstance(provider_specific, dict):
return None
value = provider_specific.get(field)
return value if isinstance(value, str) and value else None
@staticmethod
def _coerce_embedding_input_to_string(raw_input: Any, model: str = "") -> str:
"""
Normalize an OpenAI /v1/embeddings `input` field into the single
string that Bedrock Titan v2 InvokeModel expects in `inputText`.
Accepts: a string, or a single-element list containing one string.
Rejects (with actionable messages):
- None / missing -> ValueError
- Multi-element string lists -> ValueError, prompts caller to
emit one JSONL line per input
- Pre-tokenized inputs (List[int], List[List[int]]) -> NotImplementedError
- Any other type -> ValueError
Extracted so the validation can be exercised in isolation and so
future embedding-provider branches (Titan G1, Cohere) can reuse it
without duplicating the type-shaping logic.
"""
if raw_input is None:
raise ValueError(
"Embedding batch record is missing required `input` field: "
f"model={model}"
)
# Bedrock InvokeModel for Titan v2 takes exactly one string `inputText`
# per call. Pre-tokenized inputs and multi-element string lists are
# explicitly unsupported so callers emit one JSONL line per embedding
# instead of relying on us to silently fan out or concatenate.
if isinstance(raw_input, list):
if len(raw_input) == 1:
candidate = raw_input[0]
else:
raise ValueError(
"Bedrock batch embedding requires one input per JSONL "
"record. Got a list with "
f"{len(raw_input)} items for model={model}; emit one "
"JSONL line per input string instead."
)
else:
candidate = raw_input
# Catches pre-tokenized inputs (List[int] from OpenAI spec, or a
# single int slipping past the list-unwrap above).
# NOTE: bool is a subclass of int but treating True/False as a token
# is meaningless either way, so the broad check is fine.
if isinstance(candidate, (list, int)):
raise NotImplementedError(
"Bedrock Titan v2 batch embedding does not support "
"pre-tokenized integer inputs. Pass `input` as a string "
f"(model={model})."
)
if not isinstance(candidate, str):
raise ValueError(
"Bedrock batch embedding `input` must be a string (or a "
"single-element list of strings). Got type "
f"{type(candidate).__name__} for model={model}."
)
return candidate
def _map_openai_embedding_to_bedrock_params(
    self,
    openai_request_body: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Build the Bedrock ``modelInput`` for one OpenAI ``/v1/embeddings`` body.

    Only Amazon Titan Text Embeddings V2 is handled. Any model that
    ``_is_titan_v2_embed_model`` does not recognise is refused up front
    instead of being shaped with the wrong schema.

    Args:
        openai_request_body: The ``body`` object of an OpenAI batch record.

    Returns:
        A new dict produced from ``AmazonTitanV2Config._transform_request``.

    Raises:
        NotImplementedError: the model is not Titan Text Embeddings V2, or
            the input is pre-tokenized.
        ValueError: the ``input`` field is missing or has an unsupported
            shape (raised by ``_coerce_embedding_input_to_string``).

    AWS docs (Titan v2 InvokeModel body):
    https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html
    """
    from litellm.llms.bedrock.embed.amazon_titan_v2_transformation import (
        AmazonTitanV2Config,
    )

    requested_model = openai_request_body.get("model", "")
    if not self._is_titan_v2_embed_model(requested_model):
        # Other embedding models use different InvokeModel schemas; refuse
        # rather than emit a batch line the provider cannot parse.
        raise NotImplementedError(
            "Bedrock batch embedding currently supports only Amazon "
            "Titan Text Embeddings V2 (model id contains "
            f"'titan-embed-text-v2'). Got model={requested_model!r}. Track other "
            "embedding models in https://github.com/BerriAI/litellm/issues."
        )

    text = self._coerce_embedding_input_to_string(
        openai_request_body.get("input"), model=requested_model
    )

    # Everything except `model` and `input` is handed to the Titan v2 config
    # for parameter mapping, the same config object the body transform uses.
    passthrough: Dict[str, Any] = {}
    for key, value in openai_request_body.items():
        if key not in ("model", "input"):
            passthrough[key] = value

    config = AmazonTitanV2Config()
    mapped_params = config.map_openai_params(
        non_default_params=passthrough,
        optional_params={},
    )
    bedrock_body = config._transform_request(
        input=text, inference_params=mapped_params
    )
    return dict(bedrock_body)
def _map_openai_to_bedrock_params(
self,
openai_request_body: Dict[str, Any],
@ -349,10 +602,19 @@ class BedrockFilesConfig(BaseAWSLLM, BaseFilesConfig):
# Determine provider from model name
provider = self.get_bedrock_invoke_provider(model)
# Transform to Bedrock modelInput format
model_input = self._map_openai_to_bedrock_params(
openai_request_body=openai_body, provider=provider
)
# Route to the embedding transformer when the OpenAI batch line
# targets /v1/embeddings; otherwise fall back to the existing
# chat-completion path. We branch here (rather than inside
# `_map_openai_to_bedrock_params`) so the chat helper keeps its
# narrow contract and the embedding helper can evolve independently.
if self._is_embedding_record(_openai_jsonl_content):
model_input = self._map_openai_embedding_to_bedrock_params(
openai_request_body=openai_body
)
else:
model_input = self._map_openai_to_bedrock_params(
openai_request_body=openai_body, provider=provider
)
# Create Bedrock batch record
record_id = _openai_jsonl_content.get(

View File

@ -5,6 +5,7 @@ Common utilities, constants, and error handling for Black Forest Labs API.
"""
from typing import Dict
from urllib.parse import urlparse
from litellm.llms.base_llm.chat.transformation import BaseLLMException
@ -18,6 +19,42 @@ class BlackForestLabsError(BaseLLMException):
# API Constants
DEFAULT_API_BASE = "https://api.bfl.ai"
# BFL uses regional subdomains (e.g. gateway.bfl.ai) for polling URLs that
# differ from the submission host (api.bfl.ai). We validate against the
# registered domain rather than doing a strict same-origin check.
_BFL_REGISTERED_DOMAIN = "bfl.ai"
def assert_bfl_polling_url(polling_url: str) -> None:
    """Validate that a polling URL points to a BFL-controlled host.

    BFL returns polling URLs on subdomains like ``gateway.bfl.ai`` that differ
    from the submission host ``api.bfl.ai``, so a strict same-origin check
    would reject legitimate URLs. Instead the URL must use ``https`` and its
    host must be ``bfl.ai`` or a subdomain of it. Callers send the ``x-key``
    credential to this URL, so anything else is refused.

    Args:
        polling_url: The polling URL returned by the upstream response.

    Raises:
        BlackForestLabsError: If the URL cannot be parsed, its scheme is not
            ``https``, or its host is not within the ``bfl.ai`` domain.
    """
    try:
        parsed = urlparse(polling_url)
        host = (parsed.hostname or "").lower()
    except ValueError as parse_err:
        # urlparse raises ValueError for malformed netlocs (e.g. an
        # unbalanced IPv6 bracket). Report it as the same rejection error
        # as every other untrusted URL instead of leaking a raw ValueError.
        raise BlackForestLabsError(
            status_code=502,
            message="Rejected polling URL: malformed URL",
        ) from parse_err

    if parsed.scheme != "https":
        raise BlackForestLabsError(
            status_code=502,
            message="Rejected polling URL: scheme must be https",
        )

    # Exact match or a dot-separated suffix; a bare ``endswith("bfl.ai")``
    # would also accept hosts such as ``evilbfl.ai``.
    if host != _BFL_REGISTERED_DOMAIN and not host.endswith(
        "." + _BFL_REGISTERED_DOMAIN
    ):
        raise BlackForestLabsError(
            status_code=502,
            message="Rejected polling URL: host is not within the bfl.ai domain",
        )
# Polling configuration
DEFAULT_POLLING_INTERVAL = 1.5 # seconds
DEFAULT_MAX_POLLING_TIME = 300 # 5 minutes

View File

@ -15,7 +15,6 @@ import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@ -29,6 +28,7 @@ from ..common_utils import (
DEFAULT_MAX_POLLING_TIME,
DEFAULT_POLLING_INTERVAL,
BlackForestLabsError,
assert_bfl_polling_url,
)
from .transformation import BlackForestLabsImageEditConfig
@ -332,16 +332,11 @@ class BlackForestLabsImageEdit:
message="No polling_url in BFL response",
)
# Reject cross-origin polling URLs — the ``x-key`` auth header
# would otherwise leak to whatever URL the upstream returns.
# VERIA-51.
try:
assert_same_origin(polling_url, str(initial_response.request.url))
except SSRFError as ssrf_err:
raise BlackForestLabsError(
status_code=502,
message=f"Rejected polling URL: {ssrf_err}",
)
# Reject polling URLs that don't belong to BFL-controlled infrastructure.
# BFL uses regional subdomains (e.g. gateway.bfl.ai) that differ from the
# submission host (api.bfl.ai), so we validate against the registered
# domain rather than doing a strict same-origin check. VERIA-51.
assert_bfl_polling_url(polling_url)
# Get just the auth header for polling
polling_headers = {"x-key": headers.get("x-key", "")}
@ -428,16 +423,11 @@ class BlackForestLabsImageEdit:
message="No polling_url in BFL response",
)
# Reject cross-origin polling URLs — the ``x-key`` auth header
# would otherwise leak to whatever URL the upstream returns.
# VERIA-51.
try:
assert_same_origin(polling_url, str(initial_response.request.url))
except SSRFError as ssrf_err:
raise BlackForestLabsError(
status_code=502,
message=f"Rejected polling URL: {ssrf_err}",
)
# Reject polling URLs that don't belong to BFL-controlled infrastructure.
# BFL uses regional subdomains (e.g. gateway.bfl.ai) that differ from the
# submission host (api.bfl.ai), so we validate against the registered
# domain rather than doing a strict same-origin check. VERIA-51.
assert_bfl_polling_url(polling_url)
# Get just the auth header for polling
polling_headers = {"x-key": headers.get("x-key", "")}

View File

@ -15,7 +15,6 @@ import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@ -29,6 +28,7 @@ from ..common_utils import (
DEFAULT_MAX_POLLING_TIME,
DEFAULT_POLLING_INTERVAL,
BlackForestLabsError,
assert_bfl_polling_url,
)
from .transformation import BlackForestLabsImageGenerationConfig
@ -172,6 +172,10 @@ class BlackForestLabsImageGeneration:
raw_response=final_response,
model_response=model_response,
logging_obj=logging_obj,
request_data=data,
optional_params=optional_params,
litellm_params=litellm_params_dict,
encoding=None,
)
async def async_image_generation(
@ -274,6 +278,10 @@ class BlackForestLabsImageGeneration:
raw_response=final_response,
model_response=model_response,
logging_obj=logging_obj,
request_data=data,
optional_params=optional_params,
litellm_params=litellm_params_dict,
encoding=None,
)
def _poll_for_result_sync(
@ -318,16 +326,11 @@ class BlackForestLabsImageGeneration:
message="No polling_url in BFL response",
)
# Reject cross-origin polling URLs — the ``x-key`` auth header
# would otherwise leak to whatever URL the upstream returns.
# VERIA-51.
try:
assert_same_origin(polling_url, str(initial_response.request.url))
except SSRFError as ssrf_err:
raise BlackForestLabsError(
status_code=502,
message=f"Rejected polling URL: {ssrf_err}",
)
# Reject polling URLs that don't belong to BFL-controlled infrastructure.
# BFL uses regional subdomains (e.g. gateway.bfl.ai) that differ from the
# submission host (api.bfl.ai), so we validate against the registered
# domain rather than doing a strict same-origin check. VERIA-51.
assert_bfl_polling_url(polling_url)
# Get just the auth header for polling
polling_headers = {"x-key": headers.get("x-key", "")}
@ -414,16 +417,11 @@ class BlackForestLabsImageGeneration:
message="No polling_url in BFL response",
)
# Reject cross-origin polling URLs — the ``x-key`` auth header
# would otherwise leak to whatever URL the upstream returns.
# VERIA-51.
try:
assert_same_origin(polling_url, str(initial_response.request.url))
except SSRFError as ssrf_err:
raise BlackForestLabsError(
status_code=502,
message=f"Rejected polling URL: {ssrf_err}",
)
# Reject polling URLs that don't belong to BFL-controlled infrastructure.
# BFL uses regional subdomains (e.g. gateway.bfl.ai) that differ from the
# submission host (api.bfl.ai), so we validate against the registered
# domain rather than doing a strict same-origin check. VERIA-51.
assert_bfl_polling_url(polling_url)
# Get just the auth header for polling
polling_headers = {"x-key": headers.get("x-key", "")}

View File

@ -376,6 +376,13 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
)
guardrailed_texts = guardrailed_inputs.get("texts", [])
returned_tool_calls = guardrailed_inputs.get("tool_calls")
guardrailed_tool_calls: List[Dict[str, Any]] = (
cast(List[Dict[str, Any]], returned_tool_calls)
if isinstance(returned_tool_calls, list)
and len(returned_tool_calls) == len(tool_calls_to_check)
else tool_calls_to_check
)
# Step 3: Map guardrail responses back to original response structure
if guardrailed_texts and texts_to_check:
@ -386,10 +393,10 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
)
# Step 4: Apply guardrailed tool calls back to response
if tool_calls_to_check:
if guardrailed_tool_calls:
await self._apply_guardrail_responses_to_output_tool_calls(
response=response,
tool_calls=tool_calls_to_check,
tool_calls=guardrailed_tool_calls,
task_mappings=tool_call_task_mappings,
)
@ -748,10 +755,11 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
task_mappings: List[Tuple[int, int]],
) -> None:
"""
Apply guardrailed tool calls back to output response.
Apply guardrailed tool calls back to the output response.
The guardrail may have modified the tool_calls list in place,
so we apply the modified tool calls back to the original response.
The guardrail may return updated tool calls (either mutated in place or as
a new list), so we apply the provided tool calls back to the original
response.
Override this method to customize how tool call responses are applied.
"""

View File

@ -114,5 +114,14 @@
"param_mappings": {
"max_completion_tokens": "max_tokens"
}
},
"tensormesh": {
"base_url": "https://serverless.tensormesh.ai/v1",
"api_key_env": "TENSORMESH_INFERENCE_API_KEY",
"api_base_env": "TENSORMESH_SERVERLESS_BASE_URL",
"base_class": "openai_gpt",
"param_mappings": {
"max_completion_tokens": "max_tokens"
}
}
}

View File

@ -0,0 +1,69 @@
from typing import TYPE_CHECKING, List, Optional, Tuple
from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig
from litellm.llms.watsonx.common_utils import IBMWatsonXMixin
if TYPE_CHECKING:
from httpx import URL
class WatsonxPassthroughConfig(IBMWatsonXMixin, BasePassthroughConfig):
    """
    Passthrough configuration for the watsonx provider.
    """

    def is_streaming_request(self, endpoint: str, request_data: dict) -> bool:
        """Return the request's ``stream`` value, or ``False`` when it is absent."""
        stream_flag = request_data.get("stream", False)
        return stream_flag

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        endpoint: str,
        request_query_params: Optional[dict],
        litellm_params: dict,
    ) -> Tuple["URL", str]:
        """
        Build the full target URL for a passthrough request.

        The base URL comes from ``get_api_base``; the endpoint and the
        caller's query parameters are combined with it by ``format_url``.

        NOTE(review): earlier documentation said the watsonx ``version``
        query parameter is always added here. This method only forwards
        ``request_query_params`` - confirm where ``version`` is injected.

        Returns:
            A ``(complete_url, base_target_url)`` pair.
        """
        target_base = str(self.get_api_base(api_base))

        # format_url merges the endpoint and query params onto the base URL.
        full_url = self.format_url(
            endpoint=endpoint,
            base_target_url=target_base,
            request_query_params=request_query_params,
        )
        return full_url, target_base

    @staticmethod
    def get_api_base(
        api_base: Optional[str] = None,
    ) -> Optional[str]:
        """Return ``api_base`` when truthy, else the mixin's resolved base URL."""
        if api_base:
            return api_base
        return IBMWatsonXMixin()._get_base_url(api_base=api_base)

    @staticmethod
    def get_api_key(
        api_key: Optional[str] = None,
    ) -> Optional[str]:
        """Return ``api_key`` when truthy, else the key from watsonx credentials."""
        if api_key:
            return api_key
        credentials = IBMWatsonXMixin.get_watsonx_credentials(
            optional_params=dict(), api_base=None, api_key=api_key
        )
        return credentials["api_key"]

    @staticmethod
    def get_base_model(model: str) -> Optional[str]:
        """Return ``model`` unchanged."""
        return model

    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        """Delegate model listing to the parent implementation."""
        return super().get_models(api_key, api_base)

View File

@ -577,7 +577,10 @@
"max_tokens": 8192,
"mode": "embedding",
"output_cost_per_token": 0.0,
"output_vector_size": 1024
"output_vector_size": 1024,
"provider_specific_entry": {
"bedrock_invocation_schema": "titan_v2"
}
},
"amazon.titan-image-generator-v1": {
"input_cost_per_image": 0.0,
@ -8899,15 +8902,16 @@
"cache_creation_input_token_cost": 3.75e-07
},
"bedrock/us-gov-east-1/anthropic.claude-sonnet-4-5-20250929-v1:0": {
"cache_creation_input_token_cost": 4.125e-06,
"cache_read_input_token_cost": 3.3e-07,
"input_cost_per_token": 3.3e-06,
"cache_creation_input_token_cost": 4.5e-06,
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
"cache_read_input_token_cost": 3.6e-07,
"input_cost_per_token": 3.6e-06,
"litellm_provider": "bedrock",
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"max_tokens": 8192,
"mode": "chat",
"output_cost_per_token": 1.65e-05,
"output_cost_per_token": 1.8e-05,
"supports_assistant_prefill": true,
"supports_computer_use": true,
"supports_function_calling": true,
@ -8920,15 +8924,16 @@
"supports_native_structured_output": true
},
"bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0": {
"cache_creation_input_token_cost": 4.125e-06,
"cache_read_input_token_cost": 3.3e-07,
"input_cost_per_token": 3.3e-06,
"cache_creation_input_token_cost": 4.5e-06,
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
"cache_read_input_token_cost": 3.6e-07,
"input_cost_per_token": 3.6e-06,
"litellm_provider": "bedrock",
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"max_tokens": 8192,
"mode": "chat",
"output_cost_per_token": 1.65e-05,
"output_cost_per_token": 1.8e-05,
"supports_assistant_prefill": true,
"supports_computer_use": true,
"supports_function_calling": true,
@ -9072,15 +9077,16 @@
"cache_creation_input_token_cost": 3.75e-07
},
"bedrock/us-gov-west-1/anthropic.claude-sonnet-4-5-20250929-v1:0": {
"cache_creation_input_token_cost": 4.125e-06,
"cache_read_input_token_cost": 3.3e-07,
"input_cost_per_token": 3.3e-06,
"cache_creation_input_token_cost": 4.5e-06,
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
"cache_read_input_token_cost": 3.6e-07,
"input_cost_per_token": 3.6e-06,
"litellm_provider": "bedrock",
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"max_tokens": 8192,
"mode": "chat",
"output_cost_per_token": 1.65e-05,
"output_cost_per_token": 1.8e-05,
"supports_assistant_prefill": true,
"supports_computer_use": true,
"supports_function_calling": true,
@ -9093,15 +9099,16 @@
"supports_native_structured_output": true
},
"bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0": {
"cache_creation_input_token_cost": 4.125e-06,
"cache_read_input_token_cost": 3.3e-07,
"input_cost_per_token": 3.3e-06,
"cache_creation_input_token_cost": 4.5e-06,
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
"cache_read_input_token_cost": 3.6e-07,
"input_cost_per_token": 3.6e-06,
"litellm_provider": "bedrock",
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"max_tokens": 8192,
"mode": "chat",
"output_cost_per_token": 1.65e-05,
"output_cost_per_token": 1.8e-05,
"supports_assistant_prefill": true,
"supports_computer_use": true,
"supports_function_calling": true,
@ -31858,19 +31865,21 @@
"supports_native_structured_output": true
},
"us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0": {
"cache_creation_input_token_cost": 4.125e-06,
"cache_read_input_token_cost": 3.3e-07,
"input_cost_per_token": 3.3e-06,
"input_cost_per_token_above_200k_tokens": 6.6e-06,
"output_cost_per_token_above_200k_tokens": 2.475e-05,
"cache_creation_input_token_cost_above_200k_tokens": 8.25e-06,
"cache_read_input_token_cost_above_200k_tokens": 6.6e-07,
"cache_creation_input_token_cost": 4.5e-06,
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
"cache_read_input_token_cost": 3.6e-07,
"input_cost_per_token": 3.6e-06,
"input_cost_per_token_above_200k_tokens": 7.2e-06,
"output_cost_per_token_above_200k_tokens": 2.7e-05,
"cache_creation_input_token_cost_above_200k_tokens": 9.0e-06,
"cache_creation_input_token_cost_above_1hr_above_200k_tokens": 1.44e-05,
"cache_read_input_token_cost_above_200k_tokens": 7.2e-07,
"litellm_provider": "bedrock_converse",
"max_input_tokens": 200000,
"max_output_tokens": 64000,
"max_tokens": 64000,
"mode": "chat",
"output_cost_per_token": 1.65e-05,
"output_cost_per_token": 1.8e-05,
"supports_assistant_prefill": true,
"supports_computer_use": true,
"supports_function_calling": true,
@ -41448,6 +41457,7 @@
},
"bedrock/us-gov-east-1/anthropic.claude-haiku-4-5-20251001-v1:0": {
"cache_creation_input_token_cost": 1.5e-06,
"cache_creation_input_token_cost_above_1hr": 2.4e-06,
"cache_read_input_token_cost": 1.2e-07,
"input_cost_per_token": 1.2e-06,
"litellm_provider": "bedrock",
@ -41470,6 +41480,7 @@
},
"bedrock/us-gov-west-1/anthropic.claude-haiku-4-5-20251001-v1:0": {
"cache_creation_input_token_cost": 1.5e-06,
"cache_creation_input_token_cost_above_1hr": 2.4e-06,
"cache_read_input_token_cost": 1.2e-07,
"input_cost_per_token": 1.2e-06,
"litellm_provider": "bedrock",

View File

@ -419,6 +419,7 @@ class LiteLLMRoutes(enum.Enum):
"/vllm",
"/mistral",
"/milvus",
"/watsonx",
]
#########################################################
@ -3901,7 +3902,9 @@ class LiteLLM_TeamMembership(LiteLLMPydanticObjectBase):
# Union so Pydantic picks Full when data has server-managed fields
# (/team/info) and Base when callers/tests construct with only
# user-settable fields.
litellm_budget_table: Optional[Union[LiteLLM_BudgetTableFull, LiteLLM_BudgetTable]]
litellm_budget_table: Optional[
Union[LiteLLM_BudgetTableFull, LiteLLM_BudgetTable]
] = None
def safe_get_team_member_rpm_limit(self) -> Optional[int]:
if self.litellm_budget_table is not None:

View File

@ -328,7 +328,24 @@ class ContentFilterGuardrail(CustomGuardrail):
return result
@staticmethod
def _resolve_category_file_path(file_path: str) -> str:
def _assert_within_categories_dir(path: str, categories_dir: str) -> None:
"""Raise ValueError if path escapes the categories directory."""
resolved = os.path.realpath(path)
allowed = os.path.realpath(categories_dir)
try:
common = os.path.commonpath([resolved, allowed])
except ValueError:
# commonpath() raises ValueError on Windows when paths span different drives
raise ValueError(
f"Category file path '{path}' is outside the allowed categories directory"
)
if common != allowed:
raise ValueError(
f"Category file path '{path}' is outside the allowed "
f"categories directory '{categories_dir}'"
)
def _resolve_category_file_path(self, file_path: str) -> str:
"""
Resolve a category file path that may be relative.
@ -339,12 +356,17 @@ class ContentFilterGuardrail(CustomGuardrail):
file isn't found.
Resolution order:
1. Return as-is if absolute or already exists.
2. Try joining the full path relative to this module's directory.
1. Return as-is if absolute or already exists (jailed to module dir).
2. Try joining the full path relative to this module's directory (jailed).
3. Progressively strip leading path components and try each suffix
relative to this module's directory (handles paths like
"litellm/proxy/.../policy_templates/file.yaml" by finding the
"policy_templates/file.yaml" suffix that exists).
relative to this module's directory (jailed).
The directory jail can be disabled for deployments that legitimately
store category files outside the package (e.g. mounted volumes) by
setting the environment variable
``LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS=true``. Use only in
trusted environments where the proxy configuration cannot be influenced
by untrusted input.
Args:
file_path: The file path to resolve (absolute or relative).
@ -352,15 +374,33 @@ class ContentFilterGuardrail(CustomGuardrail):
Returns:
The resolved absolute-ish path, or the original path if
resolution fails (caller should check existence).
"""
if os.path.isabs(file_path) or os.path.exists(file_path):
return file_path
Raises:
ValueError: If the resolved path escapes the module directory
and ``LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS`` is not set.
"""
module_dir = os.path.dirname(__file__)
allow_external = (
os.environ.get("LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS", "").lower()
== "true"
)
if os.path.isabs(file_path) or os.path.exists(file_path):
if not allow_external:
self._assert_within_categories_dir(file_path, module_dir)
else:
verbose_proxy_logger.warning(
"LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS is set — "
"skipping directory jail for category_file '%s'",
file_path,
)
return file_path
# Try the full relative path joined to the module directory
candidate = os.path.join(module_dir, file_path)
if os.path.exists(candidate):
if not allow_external:
self._assert_within_categories_dir(candidate, module_dir)
return candidate
# Progressively strip leading components to find a matching suffix
@ -369,8 +409,17 @@ class ContentFilterGuardrail(CustomGuardrail):
suffix = os.path.join(*parts[i:])
candidate = os.path.join(module_dir, suffix)
if os.path.exists(candidate):
if not allow_external:
self._assert_within_categories_dir(candidate, module_dir)
return candidate
# File not found via any resolution strategy — jail the module-relative
# path anyway to reject traversal attempts (e.g. "../../../../etc/passwd")
# regardless of CWD or whether the target file exists.
if not allow_external:
self._assert_within_categories_dir(
os.path.join(module_dir, file_path), module_dir
)
return file_path
def _load_categories(self, categories: List[ContentFilterCategoryConfig]) -> None:
@ -395,6 +444,13 @@ class ContentFilterGuardrail(CustomGuardrail):
)
continue
# Prevent path traversal via category_name (e.g. "../../etc/passwd")
if not re.match(r"^[a-zA-Z0-9_\-]+$", category_name):
verbose_proxy_logger.warning(
f"Category name '{category_name}' contains invalid characters, skipping"
)
continue
enabled = cat_config.get("enabled", True)
action = cat_config.get("action")
severity_threshold = (
@ -411,7 +467,13 @@ class ContentFilterGuardrail(CustomGuardrail):
# Load category file (custom or default)
if custom_file:
category_file_path = self._resolve_category_file_path(custom_file)
try:
category_file_path = self._resolve_category_file_path(custom_file)
except ValueError as e:
verbose_proxy_logger.warning(
f"Category {category_name}: invalid category_file path, skipping. {e}"
)
continue
else:
# Try .yaml first, then .json (e.g. harm_toxic_abuse.json)
yaml_path = os.path.join(categories_dir, f"{category_name}.yaml")

View File

@ -140,7 +140,12 @@ class PanwPrismaAirsHandler(CustomGuardrail):
)
self.fallback_on_error = fallback_on_error
self.timeout = timeout
# Coerce defensively. The dashboard UI persists this field as a JSON
# string, and Pydantic extras (the path that splats model_dump into
# this handler) preserve whatever type the user supplied. A string
# value would otherwise reach httpx, which raises TypeError on its
# internal '<=' comparison and surfaces as a misleading api_error.
self.timeout = float(timeout) if timeout is not None else 10.0
# Tri-state: None = not set (default-on for Anthropic), True = explicit on, False = explicit off
self.experimental_use_latest_role_message_only: Optional[bool] = kwargs.get(

View File

@ -0,0 +1,34 @@
from typing import TYPE_CHECKING
from litellm.types.guardrails import SupportedGuardrailIntegrations
from .vigil_guard import VigilGuardGuardrail
if TYPE_CHECKING:
from litellm.types.guardrails import Guardrail, LitellmParams
def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"):
    """Build a Vigil Guard guardrail from proxy config and register it.

    Args:
        litellm_params: Parsed ``litellm_params`` section of the guardrail
            config; supplies the API endpoint, credentials, timeout, mode
            and default-on flag.
        guardrail: The guardrail config entry; only ``guardrail_name`` is
            read from it (defaulting to an empty string).

    Returns:
        The constructed ``VigilGuardGuardrail`` instance, after it has been
        added to LiteLLM's callback manager.
    """
    import litellm

    _vigil_guard_callback = VigilGuardGuardrail(
        api_base=litellm_params.api_base,
        api_key=litellm_params.api_key,
        # Read defensively: if the attribute is absent (it is an optional
        # setting), fall back to None instead of raising AttributeError.
        # The handler maps None to its "fail_closed" default.
        unreachable_fallback=getattr(litellm_params, "unreachable_fallback", None),
        timeout=litellm_params.timeout,
        guardrail_name=guardrail.get("guardrail_name", ""),
        event_hook=litellm_params.mode,
        default_on=litellm_params.default_on,
    )
    litellm.logging_callback_manager.add_litellm_callback(_vigil_guard_callback)
    return _vigil_guard_callback
# Maps the Vigil Guard integration key to the factory that builds and
# registers the guardrail callback.
guardrail_initializer_registry = {
SupportedGuardrailIntegrations.VIGIL_GUARD.value: initialize_guardrail,
}
# Maps the same integration key to the guardrail class itself.
guardrail_class_registry = {
SupportedGuardrailIntegrations.VIGIL_GUARD.value: VigilGuardGuardrail,
}

View File

@ -0,0 +1,485 @@
from json import JSONDecodeError
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Dict,
List,
Literal,
Optional,
Protocol,
Tuple,
Type,
cast,
)
import httpx
from litellm._logging import verbose_proxy_logger
from litellm.exceptions import GuardrailRaisedException
from litellm.exceptions import Timeout as LiteLLMTimeout
from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
)
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.secret_managers.main import get_secret_str
from litellm.types.guardrails import GuardrailEventHooks
from litellm.types.utils import GenericGuardrailAPIInputs
if TYPE_CHECKING:
from litellm.litellm_core_utils.litellm_logging import (
Logging as LiteLLMLoggingObj,
)
from litellm.types.proxy.guardrails.guardrail_hooks.base import (
GuardrailConfigModel,
)
# Path appended to the configured api_base when calling the analyze API.
_ANALYZE_ENDPOINT = "/v1/guard/analyze"
# Timeout used when the guardrail is constructed without an explicit timeout:
# 10 seconds by default, with a 5 second connect timeout.
_DEFAULT_VIGIL_TIMEOUT = httpx.Timeout(10.0, connect=5.0)
# Upper bound on the length of the block reason returned to the caller.
_BLOCK_REASON_MAX_CHARS = 500
# Upper bounds applied when clamping forwarded metadata values.
_METADATA_STRING_MAX_CHARS = 500
_METADATA_ARRAY_MAX_ITEMS = 10
# Decisions the guardrail recognizes; any other value is handled as an
# unrecognized decision.
_VALID_DECISIONS = ("ALLOWED", "SANITIZED", "BLOCKED")
# HTTP status codes treated as transient (eligible for a single retry).
_TRANSIENT_STATUS_CODES = frozenset({429, 502, 503, 504})
# Only these request/metadata keys are forwarded to the analyze API.
_METADATA_ALLOWLIST = (
"model",
"model_group",
"provider",
"region",
"deployment",
"user",
"user_id",
"session_id",
"conversation_id",
"request_id",
"tenant_id",
"org_id",
)
# Policy names accepted for the unreachable_fallback setting.
_FallbackMode = Literal["fail_closed", "fail_open"]
class _AsyncPostHandler(Protocol):
    """Structural type for the async HTTP client used by the guardrail.

    Anything exposing a keyword-only ``post`` that returns an awaitable
    ``httpx.Response`` satisfies this protocol.
    """

    def post(
        self,
        *,
        url: str,
        headers: Dict[str, str],
        json: Dict[str, Any],
        timeout: httpx.Timeout,
    ) -> Awaitable[httpx.Response]:
        """Send a JSON POST request and return an awaitable response."""
        ...
class VigilGuardMissingConfig(ValueError):
    """Raised when a required Vigil Guard setting (api_base or api_key) is absent."""
class VigilGuardGuardrail(CustomGuardrail):
def __init__(
    self,
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    unreachable_fallback: Optional[str] = None,
    timeout: Optional[float] = None,
    async_handler: Optional[_AsyncPostHandler] = None,
    **kwargs: Any,
) -> None:
    """Configure the Vigil Guard guardrail.

    Args:
        api_base: Vigil Guard base URL; falls back to the
            ``VIGIL_GUARD_URL`` secret/environment variable.
        api_key: Vigil Guard API key; falls back to
            ``VIGIL_GUARD_API_KEY``.
        unreachable_fallback: ``"fail_open"`` (case-insensitive) to let
            traffic through when the backend fails; any other value,
            including ``None``, selects ``"fail_closed"``.
        timeout: Overall request timeout in seconds. ``None`` selects the
            module default. Numeric strings are accepted and coerced.
        async_handler: Optional HTTP client override; defaults to the
            shared guardrail-callback httpx client.
        **kwargs: Forwarded to ``CustomGuardrail.__init__``.

    Raises:
        VigilGuardMissingConfig: If no api_base or api_key can be resolved.
        ValueError: If ``timeout`` is a non-numeric string.
    """
    resolved_base = api_base or get_secret_str("VIGIL_GUARD_URL")
    if not resolved_base:
        raise VigilGuardMissingConfig(
            "Vigil Guard api_base is required. Set api_base in the guardrail "
            "config or the VIGIL_GUARD_URL environment variable."
        )
    # Strip trailing slashes so the endpoint path can be appended verbatim.
    self.api_base = resolved_base.rstrip("/")

    resolved_key = api_key or get_secret_str("VIGIL_GUARD_API_KEY")
    if not resolved_key:
        raise VigilGuardMissingConfig(
            "Vigil Guard api_key is required. Set api_key in the guardrail "
            "config or the VIGIL_GUARD_API_KEY environment variable."
        )
    self.api_key = resolved_key

    fallback = (unreachable_fallback or "fail_closed").lower()
    self.unreachable_fallback: _FallbackMode = (
        "fail_open" if fallback == "fail_open" else "fail_closed"
    )

    # Coerce defensively: a caller that bypasses config validation can hand
    # over a numeric string, and min(<str>, 5.0) would raise TypeError.
    timeout_seconds = None if timeout is None else float(timeout)
    self.timeout: httpx.Timeout = (
        _DEFAULT_VIGIL_TIMEOUT
        if timeout_seconds is None
        # The connect phase never waits longer than 5 seconds, nor longer
        # than the overall timeout.
        else httpx.Timeout(timeout_seconds, connect=min(timeout_seconds, 5.0))
    )

    self.async_handler: _AsyncPostHandler = async_handler or get_async_httpx_client(
        llm_provider=httpxSpecialProvider.GuardrailCallback,
    )

    # Default to both pre- and post-call hooks unless the caller chose.
    if "supported_event_hooks" not in kwargs:
        kwargs["supported_event_hooks"] = [
            GuardrailEventHooks.pre_call,
            GuardrailEventHooks.post_call,
        ]
    super().__init__(**kwargs)
@staticmethod
def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
    """Return the config model class that describes this guardrail's settings."""
    # Resolved at call time rather than at module import time.
    from litellm.types.proxy.guardrails.guardrail_hooks import (
        vigil_guard as _vigil_guard_types,
    )

    return _vigil_guard_types.VigilGuardGuardrailConfigModel
@log_guardrail_information
async def apply_guardrail(
    self,
    inputs: GenericGuardrailAPIInputs,
    request_data: dict,
    input_type: Literal["request", "response"],
    logging_obj: Optional["LiteLLMLoggingObj"] = None,
) -> GenericGuardrailAPIInputs:
    """Scan texts (and, for responses, tool-call arguments) with Vigil Guard.

    Each non-blank text is analyzed individually. For ``"response"``
    inputs, every non-blank tool-call ``arguments`` string is analyzed as
    well. Sanitized content replaces the original; a BLOCKED decision
    raises.

    Args:
        inputs: Guardrail inputs; ``texts`` and ``tool_calls`` are read.
        request_data: Original request payload, used only for metadata.
        input_type: ``"request"`` (scanned as ``user_input``) or
            ``"response"`` (scanned as ``model_output``).
        logging_obj: Optional logging object used to find the call id.

    Returns:
        ``inputs`` unchanged when there is nothing to scan; otherwise the
        result of ``_build_output`` for the scanned/sanitized content. With
        the ``fail_open`` policy, a backend failure or unrecognized
        decision returns early with the not-yet-scanned items passed
        through as-is.

    Raises:
        GuardrailRaisedException: On a BLOCKED decision, or on a backend
            failure / unrecognized decision under ``fail_closed``.
    """
    texts = inputs.get("texts") or []
    has_text = any(isinstance(text, str) and text.strip() for text in texts)
    # Tool-call arguments are only inspected on the response side.
    tool_call_args = (
        self._tool_call_arguments(inputs.get("tool_calls"))
        if input_type == "response"
        else []
    )
    if not has_text and not tool_call_args:
        return inputs

    source = "user_input" if input_type == "request" else "model_output"
    metadata = self._collect_metadata(request_data, logging_obj)
    # Failures that mean "the backend could not give a verdict".
    backend_errors = (httpx.HTTPError, LiteLLMTimeout, JSONDecodeError, OSError)

    result_texts: List[str] = []
    for index, text in enumerate(texts):
        # Blank or non-string entries are passed through without a call.
        if not isinstance(text, str) or not text.strip():
            result_texts.append(text)
            continue
        try:
            analysis = await self._analyze(
                text=text, source=source, metadata=metadata
            )
        except backend_errors as exc:
            # Under fail_open the remaining texts (this one included) go
            # through unscanned, after whatever was already processed.
            return self._handle_backend_failure(
                exc,
                inputs,
                source,
                result_texts + list(texts[index:]),
                inputs.get("tool_calls"),
            )
        decision = self._enforce_decision(analysis, source)
        if decision is None:
            # Unrecognized decision under fail_open: stop scanning here.
            return self._build_output(
                inputs,
                result_texts + list(texts[index:]),
                inputs.get("tool_calls"),
            )
        result_texts.append(
            self._resolve_sanitized_text(text, analysis)
            if decision == "SANITIZED"
            else text
        )

    result_tool_calls = inputs.get("tool_calls")
    for tc_index, arguments in tool_call_args:
        try:
            analysis = await self._analyze(
                text=arguments, source=source, metadata=metadata
            )
        except backend_errors as exc:
            return self._handle_backend_failure(
                exc, inputs, source, result_texts, result_tool_calls
            )
        decision = self._enforce_decision(analysis, source)
        if decision is None:
            return self._build_output(inputs, result_texts, result_tool_calls)
        if decision == "SANITIZED":
            result_tool_calls = self._set_tool_call_arguments(
                result_tool_calls,
                tc_index,
                self._resolve_sanitized_text(arguments, analysis),
            )

    return self._build_output(inputs, result_texts, result_tool_calls)

def _enforce_decision(self, analysis: Any, source: str) -> Optional[str]:
    """Validate a backend verdict and enforce blocking decisions.

    Returns:
        ``"ALLOWED"`` or ``"SANITIZED"`` for a recognized, non-blocking
        decision; ``None`` when the decision is unrecognized and the
        ``fail_open`` policy applies.

    Raises:
        GuardrailRaisedException: When the decision is ``"BLOCKED"``, or
            when it is unrecognized under the ``fail_closed`` policy.
    """
    decision = analysis.get("decision") if isinstance(analysis, dict) else None
    if decision not in _VALID_DECISIONS:
        verbose_proxy_logger.error(
            "Vigil Guard unrecognized decision for guardrail_name=%s "
            "source=%s: %r",
            self.guardrail_name,
            source,
            decision,
        )
        if self.unreachable_fallback == "fail_open":
            return None
        raise GuardrailRaisedException(
            guardrail_name=self.guardrail_name,
            message="Vigil Guard returned an unrecognized decision.",
            should_wrap_with_default_message=False,
        )
    if decision == "BLOCKED":
        raise GuardrailRaisedException(
            guardrail_name=self.guardrail_name,
            message=self._build_block_reason(analysis),
            should_wrap_with_default_message=False,
        )
    return decision
def _handle_backend_failure(
    self,
    exc: Exception,
    inputs: GenericGuardrailAPIInputs,
    source: str,
    final_texts: List[Any],
    final_tool_calls: Any,
) -> GenericGuardrailAPIInputs:
    """Apply the configured fallback policy after a backend failure.

    Under ``fail_open`` the failure is logged and the supplied texts and
    tool calls are returned via ``_build_output``. Under any other policy
    the failure is logged and the request is blocked.

    Raises:
        GuardrailRaisedException: When the policy is not ``fail_open``;
            chained from ``exc``.
    """
    if self.unreachable_fallback != "fail_open":
        verbose_proxy_logger.error(
            "Vigil Guard backend failure with fail_closed; blocking request. "
            "guardrail_name=%s source=%s error=%s",
            self.guardrail_name,
            source,
            str(exc),
        )
        raise GuardrailRaisedException(
            guardrail_name=self.guardrail_name,
            message="Vigil Guard backend unreachable; request blocked by fail_closed policy.",
            should_wrap_with_default_message=False,
        ) from exc

    verbose_proxy_logger.error(
        "Vigil Guard backend failure with fail_open; allowing request "
        "unscanned. guardrail_name=%s source=%s error=%s",
        self.guardrail_name,
        source,
        str(exc),
    )
    return self._build_output(inputs, final_texts, final_tool_calls)
@staticmethod
def _build_output(
    inputs: GenericGuardrailAPIInputs,
    final_texts: List[Any],
    final_tool_calls: Any,
) -> GenericGuardrailAPIInputs:
    """Assemble the guardrail result from the final texts and tool calls.

    If neither the texts nor the tool calls differ from ``inputs``, a
    shallow copy of ``inputs`` is returned so the result keeps the input
    shape (logged as "allow" rather than "mask"). Otherwise only the
    remap-relevant keys are returned: ``texts``, plus ``images`` / ``tools``
    when present, plus ``tool_calls`` when they changed. Other keys such as
    structured_messages are dropped so a stale, unsanitized payload cannot
    reach the model.
    """
    original_texts = inputs.get("texts") or []
    original_tool_calls = inputs.get("tool_calls")
    texts_differ = final_texts != original_texts
    tool_calls_differ = final_tool_calls != original_tool_calls

    if not (texts_differ or tool_calls_differ):
        return cast(GenericGuardrailAPIInputs, dict(inputs))

    rebuilt: GenericGuardrailAPIInputs = {"texts": final_texts}
    if "images" in inputs:
        rebuilt["images"] = inputs["images"]
    if "tools" in inputs:
        rebuilt["tools"] = inputs["tools"]
    if tool_calls_differ:
        rebuilt["tool_calls"] = final_tool_calls
    return rebuilt
@staticmethod
def _tool_call_arguments(tool_calls: Any) -> List[Tuple[int, str]]:
pairs: List[Tuple[int, str]] = []
if isinstance(tool_calls, list):
for index, tool_call in enumerate(tool_calls):
function = (
tool_call.get("function") if isinstance(tool_call, dict) else None
)
arguments = (
function.get("arguments") if isinstance(function, dict) else None
)
if isinstance(arguments, str) and arguments.strip():
pairs.append((index, arguments))
return pairs
@staticmethod
def _set_tool_call_arguments(
tool_calls: Any, index: int, arguments: str
) -> List[Any]:
updated = list(tool_calls)
tool_call = dict(updated[index])
function = dict(tool_call.get("function") or {})
function["arguments"] = arguments
tool_call["function"] = function
updated[index] = tool_call
return updated
async def _analyze(
    self, text: str, source: str, metadata: Dict[str, Any]
) -> Dict[str, Any]:
    """POST one text to the analyze endpoint and return the decoded JSON body.

    The request always uses ``mode="full"`` and authenticates with the
    configured API key as a bearer token.
    """
    response = await self._post_with_retry(
        f"{self.api_base}{_ANALYZE_ENDPOINT}",
        {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        },
        {
            "text": text,
            "source": source,
            "mode": "full",
            "metadata": metadata,
        },
    )
    return response.json()
async def _post_with_retry(
    self, endpoint: str, headers: Dict[str, str], payload: Dict[str, Any]
) -> httpx.Response:
    """POST ``payload`` to ``endpoint``, retrying once on a transient failure.

    A response with an error status is turned into an exception via
    ``raise_for_status``. The first failure is retried only when
    ``_is_transient`` accepts it; a second failure, or any non-transient
    failure, propagates to the caller.
    """
    already_retried = False
    while True:
        try:
            response = await self.async_handler.post(
                url=endpoint,
                headers=headers,
                json=payload,
                timeout=self.timeout,
            )
            response.raise_for_status()
        except Exception as exc:
            if already_retried or not self._is_transient(exc):
                raise
            verbose_proxy_logger.debug(
                "Vigil Guard transient failure; retrying once: %s",
                type(exc).__name__,
            )
            already_retried = True
        else:
            return response
@staticmethod
def _is_transient(exc: Exception) -> bool:
    """Tell whether ``exc`` is a failure worth retrying.

    HTTP status errors are judged solely by their status code; otherwise
    connection errors, connect/read timeouts, remote protocol errors and
    LiteLLM timeouts count as transient.
    """
    if isinstance(exc, httpx.HTTPStatusError):
        return exc.response.status_code in _TRANSIENT_STATUS_CODES

    retryable_types = (
        httpx.ConnectError,
        httpx.ConnectTimeout,
        httpx.ReadTimeout,
        httpx.RemoteProtocolError,
        LiteLLMTimeout,
    )
    return isinstance(exc, retryable_types)
@staticmethod
def _build_block_reason(analysis: Dict[str, Any]) -> str:
for key in ("blockMessage", "decisionReason"):
value = analysis.get(key)
if isinstance(value, str) and value.strip():
return value.strip()[:_BLOCK_REASON_MAX_CHARS]
categories = analysis.get("categories")
if isinstance(categories, list):
names = [c for c in categories if isinstance(c, str) and c.strip()]
if names:
return ", ".join(names)[:_BLOCK_REASON_MAX_CHARS]
return "Blocked by policy"
@staticmethod
def _resolve_sanitized_text(original: str, analysis: Dict[str, Any]) -> str:
for key in ("sanitizedText", "outputText"):
value = analysis.get(key)
if isinstance(value, str):
return value
return original
def _collect_metadata(
self, request_data: dict, logging_obj: Optional["LiteLLMLoggingObj"]
) -> Dict[str, Any]:
# Builds the metadata dict sent with each analyze call: allowlisted fields
# found in the request (clamped), plus the LiteLLM call id when available.
#
# Lookup order for each field: the request itself, then its nested
# "metadata" dict, then its nested "litellm_metadata" dict.
sources: List[dict] = []
if isinstance(request_data, dict):
sources.append(request_data)
for nested_key in ("metadata", "litellm_metadata"):
nested = request_data.get(nested_key)
if isinstance(nested, dict):
sources.append(nested)
collected: Dict[str, Any] = {}
# Only keys listed in _METADATA_ALLOWLIST are ever copied.
for field in _METADATA_ALLOWLIST:
for source in sources:
if field in source and source[field] is not None:
# Values that cannot be clamped (clamp returns None) are not stored.
clamped = self._clamp_metadata_value(source[field])
if clamped is not None:
collected[field] = clamped
# NOTE(review): indentation is not visible in this rendering, so it cannot be
# determined here whether this `break` runs only after a value was stored or
# as soon as any source carries the field — confirm against the real file.
break
# The call id is stored under its own key, outside the allowlist loop.
call_id = self._extract_call_id(request_data, logging_obj)
if call_id:
collected["litellm_call_id"] = call_id
return collected
@staticmethod
def _clamp_metadata_value(value: Any) -> Any:
if isinstance(value, bool):
return None
if isinstance(value, str):
return value[:_METADATA_STRING_MAX_CHARS]
if isinstance(value, (int, float)):
return value
if isinstance(value, list):
clamped: List[Any] = []
for item in value[:_METADATA_ARRAY_MAX_ITEMS]:
if isinstance(item, bool):
continue
if isinstance(item, str):
clamped.append(item[:_METADATA_STRING_MAX_CHARS])
elif isinstance(item, (int, float)):
clamped.append(item)
return clamped or None
return None
@staticmethod
def _extract_call_id(
request_data: dict, logging_obj: Optional["LiteLLMLoggingObj"]
) -> Optional[str]:
if logging_obj is not None:
call_id = getattr(logging_obj, "litellm_call_id", None)
if isinstance(call_id, str) and call_id:
return call_id
if isinstance(request_data, dict):
call_id = request_data.get("litellm_call_id")
if isinstance(call_id, str) and call_id:
return call_id
metadata = request_data.get("metadata")
if isinstance(metadata, dict):
nested = metadata.get("litellm_call_id")
if isinstance(nested, str) and nested:
return nested
return None

View File

@ -217,7 +217,15 @@ def initialize_panw_prisma_airs(litellm_params, guardrail):
mask_response_content=getattr(litellm_params, "mask_response_content", False),
app_name=getattr(litellm_params, "app_name", None),
fallback_on_error=getattr(litellm_params, "fallback_on_error", "block"),
timeout=float(getattr(litellm_params, "timeout", 10.0)),
# `timeout` is now declared on BaseLitellmParams (Optional[float] = None),
# so the attribute always exists. The Pydantic validator on LitellmParams
# coerces strings to float, but None still means "use handler default" —
# guard against float(None) here.
timeout=(
float(getattr(litellm_params, "timeout", None))
if getattr(litellm_params, "timeout", None) is not None
else 10.0
),
violation_message_template=litellm_params.violation_message_template,
)
litellm.logging_callback_manager.add_litellm_callback(_panw_callback)

View File

@ -2433,3 +2433,89 @@ def create_generic_websocket_passthrough_endpoint(
_forward_headers=forward_headers,
cost_per_request=cost_per_request,
)
@router.api_route(
    "/watsonx/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Watsonx Pass-through", "pass-through"],
)
async def watsonx_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Watsonx pass-through endpoint.

    Allows using Watsonx APIs with automatic IAM token management and version parameter injection.

    Example:
        POST /watsonx/ml/v1/text/tokenization
        POST /watsonx/ml/v1/text/generation
    """
    from litellm.types.utils import LlmProviders
    from litellm.utils import ProviderConfigManager

    # Look up the provider-specific passthrough configuration.
    passthrough_config = ProviderConfigManager.get_provider_passthrough_config(
        provider=LlmProviders.WATSONX,
        model="",
    )
    if passthrough_config is None:
        raise HTTPException(
            status_code=404, detail="Watsonx passthrough config not found"
        )

    # Upstream URL for the requested endpoint.
    target_url, _ = passthrough_config.get_complete_url(
        api_base=None,
        api_key=None,
        model="",
        endpoint=endpoint,
        request_query_params=None,
        litellm_params={},
    )

    # Auth headers produced by the provider config.
    upstream_headers = passthrough_config.validate_environment(
        headers={},
        model="",
        messages=[],
        optional_params={},
        litellm_params={},
        api_key=None,
        api_base=None,
    )

    # Only POST bodies are inspected for a truthy "stream" flag.
    wants_stream = False
    if request.method == "POST":
        content_type = request.headers.get("content-type", "")
        if "multipart/form-data" in content_type:
            parsed_body = await get_form_data(request)
        else:
            parsed_body = await request.json()
        wants_stream = bool(parsed_body.get("stream"))

    # Forward the caller's query string, adding the default API version
    # when the caller did not supply one.
    forwarded_query = dict(request.query_params)
    if forwarded_query.get("version") is None:
        forwarded_query["version"] = litellm.WATSONX_DEFAULT_API_VERSION

    passthrough_handler = create_pass_through_route(
        endpoint=endpoint,
        target=str(target_url),
        custom_headers=upstream_headers,
        is_streaming_request=wants_stream,
        custom_llm_provider="watsonx",
        query_params=forwarded_query,
    )
    return await passthrough_handler(
        request,
        fastapi_response,
        user_api_key_dict,
    )

View File

@ -36,9 +36,18 @@ router = APIRouter()
dependencies=[Depends(user_api_key_auth)],
include_in_schema=False,
)
async def spend_key_fn():
async def spend_key_fn(
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
View all keys created, ordered by spend
View keys created, ordered by spend.
- Admin callers (PROXY_ADMIN / PROXY_ADMIN_VIEW_ONLY) see every key in
the database.
- All other callers (INTERNAL_USER / INTERNAL_USER_VIEW_ONLY, etc.) are
scoped to keys they own (``user_id == caller``). A caller with no
``user_id`` has no scope and receives an empty list rather than the
full table.
Example Request:
```
@ -55,8 +64,17 @@ async def spend_key_fn():
"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)
key_info = await prisma_client.get_data(table_name="key", query_type="find_all")
return key_info
if _is_admin_view_safe(user_api_key_dict=user_api_key_dict):
return await prisma_client.get_data(table_name="key", query_type="find_all")
caller_user_id = user_api_key_dict.user_id
if not caller_user_id:
return []
return await prisma_client.get_data(
table_name="key",
query_type="find_all",
user_id=caller_user_id,
)
except Exception as e:
raise HTTPException(
@ -85,9 +103,19 @@ async def spend_user_fn(
default=None,
description="Get User Table row for user_id",
),
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
"""
View all users created, ordered by spend
View users created, ordered by spend.
- Admin callers (PROXY_ADMIN / PROXY_ADMIN_VIEW_ONLY) see every user, or
a specific user when ``user_id`` is supplied.
- All other callers may only read their own row. If they supply a
``user_id`` query parameter that does not match their authenticated
``user_id`` the request is rejected with HTTP 403; supplying their
own id (or none at all) returns just their row. A caller with no
``user_id`` on their key has no scope and receives an empty list
rather than the full table.
Example Request:
```
@ -109,6 +137,17 @@ async def spend_user_fn(
"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)
if not _is_admin_view_safe(user_api_key_dict=user_api_key_dict):
caller_user_id = user_api_key_dict.user_id
if not caller_user_id:
return []
if user_id is not None and user_id != caller_user_id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={"error": "Not authorized to view spend for another user."},
)
user_id = caller_user_id
if user_id is not None:
user_info = await prisma_client.get_data(
table_name="user", query_type="find_unique", user_id=user_id
@ -123,6 +162,8 @@ async def spend_user_fn(
_strip_password_from_users(result)
return result
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,

View File

@ -41,6 +41,9 @@ from litellm.types.proxy.guardrails.guardrail_hooks.hiddenlayer import (
from litellm.types.proxy.guardrails.guardrail_hooks.qohash import (
QostodianNexusConfigModel,
)
from litellm.types.proxy.guardrails.guardrail_hooks.vigil_guard import (
VigilGuardGuardrailConfigModel,
)
"""
Pydantic object defining how to set guardrails on litellm proxy
@ -103,6 +106,7 @@ class SupportedGuardrailIntegrations(Enum):
LLM_AS_A_JUDGE = "llm_as_a_judge"
QOSTODIAN_NEXUS = "qostodian_nexus"
RUBRIK = "rubrik"
VIGIL_GUARD = "vigil_guard"
class Role(Enum):
@ -758,6 +762,15 @@ class BaseLitellmParams(
description="Python-like code containing the apply_guardrail function for custom guardrail logic",
)
timeout: Optional[float] = Field(
default=None,
description=(
"Per-request timeout for the guardrail provider API call (seconds). "
"Accepts int, float, or numeric string; coerced to float on load. "
"Each guardrail handler chooses its own default when unset."
),
)
model_config = ConfigDict(extra="allow", protected_namespaces=())
@ -791,6 +804,7 @@ class LitellmParams(
BlockCodeExecutionGuardrailConfigModel,
HiddenlayerGuardrailConfigModel,
QostodianNexusConfigModel,
VigilGuardGuardrailConfigModel,
):
guardrail: str = Field(description="The type of guardrail integration to use")
mode: Union[str, List[str], Mode] = Field(
@ -814,6 +828,18 @@ class LitellmParams(
return [x.lower() if isinstance(x, str) else x for x in v]
return v
@field_validator("timeout", mode="before", check_fields=False)
@classmethod
def coerce_timeout(cls, v):
    """Accept string-valued timeouts (dashboard UI sends JSON strings)
    and coerce to float before any handler reads the value.

    ``None`` and blank strings (empty or whitespace-only) mean "unset" and
    yield ``None``. Anything else must convert to a real number.

    Raises:
        ValueError: If the value cannot be converted to a float, or if it
            converts to NaN (which is not a usable timeout).
    """
    import math

    # A whitespace-only string is as "unset" as an empty one.
    if v is None or (isinstance(v, str) and not v.strip()):
        return None
    try:
        coerced = float(v)
    except (TypeError, ValueError) as e:
        raise ValueError(f"timeout must be numeric, got {v!r}") from e
    # float("nan") parses successfully but is literally not a number.
    if math.isnan(coerced):
        raise ValueError(f"timeout must be numeric, got {v!r}")
    return coerced
def __init__(self, **kwargs):
default_on = kwargs.pop("default_on", None)
if default_on is not None:

View File

@ -238,6 +238,9 @@ DEFINED_PROMETHEUS_METRICS = Literal[
"litellm_cache_hits_metric",
"litellm_cache_misses_metric",
"litellm_cached_tokens_metric",
# Provider prompt-caching metrics (e.g. OpenAI/Anthropic/Bedrock/Gemini)
"litellm_provider_cache_read_input_tokens_metric",
"litellm_provider_cache_creation_input_tokens_metric",
"litellm_deployment_tpm_limit",
"litellm_deployment_rpm_limit",
"litellm_remaining_api_key_requests_for_model",
@ -655,6 +658,10 @@ class PrometheusMetricLabels:
litellm_cache_misses_metric = _cache_metric_labels
litellm_cached_tokens_metric = _cache_metric_labels
# Provider prompt-caching metrics - track tokens read/written to provider caches
litellm_provider_cache_read_input_tokens_metric = _cache_metric_labels
litellm_provider_cache_creation_input_tokens_metric = _cache_metric_labels
# Metrics whose emission paths supply org context (used by get_labels)
_org_label_metrics: ClassVar[frozenset] = frozenset(
{
@ -672,7 +679,6 @@ class PrometheusMetricLabels:
"litellm_output_tokens_metric",
}
)
# Managed batch metrics
_batch_user_labels = [
UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value,

View File

@ -0,0 +1,26 @@
from typing import Optional
from pydantic import Field
from .base import GuardrailConfigModel
class VigilGuardGuardrailConfigModel(GuardrailConfigModel):
    # Connection settings for the Vigil Guard guardrail. Both fields are
    # optional in config; their descriptions name the environment variable
    # each one falls back to.
    api_base: Optional[str] = Field(
        None,
        description="Vigil Guard API base URL. Falls back to the VIGIL_GUARD_URL environment variable.",
    )
    api_key: Optional[str] = Field(
        None,
        description="Vigil Guard API key. Falls back to the VIGIL_GUARD_API_KEY environment variable.",
    )

    @staticmethod
    def ui_friendly_name() -> str:
        """Return the display name for this guardrail."""
        return "Vigil Guard"

View File

@ -3365,6 +3365,7 @@ class LlmProviders(str, Enum):
POE = "poe"
CHUTES = "chutes"
XIAOMI_MIMO = "xiaomi_mimo"
TENSORMESH = "tensormesh"
LITELLM_AGENT = "litellm_agent"
CURSOR = "cursor"
BEDROCK_MANTLE = "bedrock_mantle"

View File

@ -8985,6 +8985,12 @@ class ProviderConfigManager:
)
return AzurePassthroughConfig()
elif LlmProviders.WATSONX == provider:
from litellm.llms.watsonx.passthrough.transformation import (
WatsonxPassthroughConfig,
)
return WatsonxPassthroughConfig()
return None
@staticmethod

View File

@ -577,7 +577,10 @@
"max_tokens": 8192,
"mode": "embedding",
"output_cost_per_token": 0.0,
"output_vector_size": 1024
"output_vector_size": 1024,
"provider_specific_entry": {
"bedrock_invocation_schema": "titan_v2"
}
},
"amazon.titan-image-generator-v1": {
"input_cost_per_image": 0.0,
@ -8899,15 +8902,16 @@
"cache_creation_input_token_cost": 3.75e-07
},
"bedrock/us-gov-east-1/anthropic.claude-sonnet-4-5-20250929-v1:0": {
"cache_creation_input_token_cost": 4.125e-06,
"cache_read_input_token_cost": 3.3e-07,
"input_cost_per_token": 3.3e-06,
"cache_creation_input_token_cost": 4.5e-06,
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
"cache_read_input_token_cost": 3.6e-07,
"input_cost_per_token": 3.6e-06,
"litellm_provider": "bedrock",
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"max_tokens": 8192,
"mode": "chat",
"output_cost_per_token": 1.65e-05,
"output_cost_per_token": 1.8e-05,
"supports_assistant_prefill": true,
"supports_computer_use": true,
"supports_function_calling": true,
@ -8920,15 +8924,16 @@
"supports_native_structured_output": true
},
"bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0": {
"cache_creation_input_token_cost": 4.125e-06,
"cache_read_input_token_cost": 3.3e-07,
"input_cost_per_token": 3.3e-06,
"cache_creation_input_token_cost": 4.5e-06,
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
"cache_read_input_token_cost": 3.6e-07,
"input_cost_per_token": 3.6e-06,
"litellm_provider": "bedrock",
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"max_tokens": 8192,
"mode": "chat",
"output_cost_per_token": 1.65e-05,
"output_cost_per_token": 1.8e-05,
"supports_assistant_prefill": true,
"supports_computer_use": true,
"supports_function_calling": true,
@ -9072,15 +9077,16 @@
"cache_creation_input_token_cost": 3.75e-07
},
"bedrock/us-gov-west-1/anthropic.claude-sonnet-4-5-20250929-v1:0": {
"cache_creation_input_token_cost": 4.125e-06,
"cache_read_input_token_cost": 3.3e-07,
"input_cost_per_token": 3.3e-06,
"cache_creation_input_token_cost": 4.5e-06,
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
"cache_read_input_token_cost": 3.6e-07,
"input_cost_per_token": 3.6e-06,
"litellm_provider": "bedrock",
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"max_tokens": 8192,
"mode": "chat",
"output_cost_per_token": 1.65e-05,
"output_cost_per_token": 1.8e-05,
"supports_assistant_prefill": true,
"supports_computer_use": true,
"supports_function_calling": true,
@ -9093,15 +9099,16 @@
"supports_native_structured_output": true
},
"bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0": {
"cache_creation_input_token_cost": 4.125e-06,
"cache_read_input_token_cost": 3.3e-07,
"input_cost_per_token": 3.3e-06,
"cache_creation_input_token_cost": 4.5e-06,
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
"cache_read_input_token_cost": 3.6e-07,
"input_cost_per_token": 3.6e-06,
"litellm_provider": "bedrock",
"max_input_tokens": 200000,
"max_output_tokens": 8192,
"max_tokens": 8192,
"mode": "chat",
"output_cost_per_token": 1.65e-05,
"output_cost_per_token": 1.8e-05,
"supports_assistant_prefill": true,
"supports_computer_use": true,
"supports_function_calling": true,
@ -31733,19 +31740,21 @@
"supports_native_structured_output": true
},
"us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0": {
"cache_creation_input_token_cost": 4.125e-06,
"cache_read_input_token_cost": 3.3e-07,
"input_cost_per_token": 3.3e-06,
"input_cost_per_token_above_200k_tokens": 6.6e-06,
"output_cost_per_token_above_200k_tokens": 2.475e-05,
"cache_creation_input_token_cost_above_200k_tokens": 8.25e-06,
"cache_read_input_token_cost_above_200k_tokens": 6.6e-07,
"cache_creation_input_token_cost": 4.5e-06,
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
"cache_read_input_token_cost": 3.6e-07,
"input_cost_per_token": 3.6e-06,
"input_cost_per_token_above_200k_tokens": 7.2e-06,
"output_cost_per_token_above_200k_tokens": 2.7e-05,
"cache_creation_input_token_cost_above_200k_tokens": 9.0e-06,
"cache_creation_input_token_cost_above_1hr_above_200k_tokens": 1.44e-05,
"cache_read_input_token_cost_above_200k_tokens": 7.2e-07,
"litellm_provider": "bedrock_converse",
"max_input_tokens": 200000,
"max_output_tokens": 64000,
"max_tokens": 64000,
"mode": "chat",
"output_cost_per_token": 1.65e-05,
"output_cost_per_token": 1.8e-05,
"supports_assistant_prefill": true,
"supports_computer_use": true,
"supports_function_calling": true,
@ -41332,6 +41341,7 @@
},
"bedrock/us-gov-east-1/anthropic.claude-haiku-4-5-20251001-v1:0": {
"cache_creation_input_token_cost": 1.5e-06,
"cache_creation_input_token_cost_above_1hr": 2.4e-06,
"cache_read_input_token_cost": 1.2e-07,
"input_cost_per_token": 1.2e-06,
"litellm_provider": "bedrock",
@ -41354,6 +41364,7 @@
},
"bedrock/us-gov-west-1/anthropic.claude-haiku-4-5-20251001-v1:0": {
"cache_creation_input_token_cost": 1.5e-06,
"cache_creation_input_token_cost_above_1hr": 2.4e-06,
"cache_read_input_token_cost": 1.2e-07,
"input_cost_per_token": 1.2e-06,
"litellm_provider": "bedrock",

View File

@ -2079,6 +2079,24 @@
"a2a": false
}
},
"tensormesh": {
"display_name": "Tensormesh (`tensormesh`)",
"url": "https://docs.litellm.ai/docs/providers/tensormesh",
"endpoints": {
"chat_completions": true,
"messages": true,
"responses": false,
"embeddings": false,
"image_generations": false,
"audio_transcriptions": false,
"audio_speech": false,
"moderations": false,
"batches": false,
"rerank": false,
"a2a": false,
"text_completion": true
}
},
"text-completion-codestral": {
"display_name": "Text Completion Codestral (`text-completion-codestral`)",
"url": "https://docs.litellm.ai/docs/providers/codestral",

View File

@ -0,0 +1,69 @@
"""Tests for FocusTransformer — ConsumedQuantity / PricingQuantity correctness."""
from __future__ import annotations
from decimal import Decimal
import polars as pl
from litellm.integrations.focus.transformer import FocusTransformer
def _base_row(**overrides) -> dict:
row = {
"date": "2026-05-25",
"user_id": "u1",
"api_key": "sk-test",
"api_key_alias": "my-key",
"model": "gpt-4o",
"model_group": "openai",
"custom_llm_provider": "openai",
"spend": 0.05,
"api_requests": 3,
"team_id": "team1",
"team_alias": "Engineering",
"user_email": "user@example.com",
}
row.update(overrides)
return row
def _transform(rows: list[dict]) -> pl.DataFrame:
    """Build a polars frame from *rows* and run it through ``FocusTransformer``.

    ``infer_schema_length=None`` is passed so the schema is inferred from
    every row rather than a leading sample.
    """
    source_frame = pl.DataFrame(rows, infer_schema_length=None)
    transformer = FocusTransformer()
    return transformer.transform(source_frame)
def test_consumed_quantity_reflects_api_requests():
    """ConsumedQuantity carries the row's api_requests count."""
    transformed = _transform([_base_row(api_requests=7)])
    consumed = transformed["ConsumedQuantity"][0]
    assert consumed == Decimal("7.000000")
def test_pricing_quantity_reflects_api_requests():
    """PricingQuantity carries the row's api_requests count."""
    transformed = _transform([_base_row(api_requests=7)])
    priced = transformed["PricingQuantity"][0]
    assert priced == Decimal("7.000000")
def test_null_api_requests_falls_back_to_zero_not_one():
    """A NULL api_requests value must yield a quantity of 0, never 1."""
    transformed = _transform([_base_row(api_requests=None)])
    zero = Decimal("0.000000")
    for column in ("ConsumedQuantity", "PricingQuantity"):
        assert transformed[column][0] == zero
def test_zero_api_requests_stays_zero():
    """An explicit api_requests of 0 is reported as 0 in both quantity columns."""
    transformed = _transform([_base_row(api_requests=0)])
    zero = Decimal("0.000000")
    for column in ("ConsumedQuantity", "PricingQuantity"):
        assert transformed[column][0] == zero
def test_bigint_api_requests_cast_correctly():
    """A seven-digit api_requests count converts to both quantity columns without loss.

    NOTE(review): the value used here (1_000_000) still fits in 32 bits, so
    this does not exercise a true 64-bit BigInt value - confirm whether a
    larger value should be covered.
    """
    transformed = _transform([_base_row(api_requests=1_000_000)])
    one_million = Decimal("1000000.000000")
    for column in ("ConsumedQuantity", "PricingQuantity"):
        assert transformed[column][0] == one_million
def test_consumed_and_pricing_quantity_match():
    """Both quantity columns hold the same value for a given row."""
    transformed = _transform([_base_row(api_requests=42)])
    consumed = transformed["ConsumedQuantity"][0]
    priced = transformed["PricingQuantity"][0]
    assert consumed == priced

View File

@ -35,6 +35,8 @@ class TestPrometheusCacheMetrics:
assert "litellm_cache_hits_metric" in defined_metrics
assert "litellm_cache_misses_metric" in defined_metrics
assert "litellm_cached_tokens_metric" in defined_metrics
assert "litellm_provider_cache_read_input_tokens_metric" in defined_metrics
assert "litellm_provider_cache_creation_input_tokens_metric" in defined_metrics
def test_cache_metric_labels_defined(self):
"""Test that cache metric labels are properly defined"""
@ -44,6 +46,13 @@ class TestPrometheusCacheMetrics:
assert hasattr(PrometheusMetricLabels, "litellm_cache_hits_metric")
assert hasattr(PrometheusMetricLabels, "litellm_cache_misses_metric")
assert hasattr(PrometheusMetricLabels, "litellm_cached_tokens_metric")
assert hasattr(
PrometheusMetricLabels, "litellm_provider_cache_read_input_tokens_metric"
)
assert hasattr(
PrometheusMetricLabels,
"litellm_provider_cache_creation_input_tokens_metric",
)
# Verify labels include expected keys
expected_labels = [
@ -59,6 +68,14 @@ class TestPrometheusCacheMetrics:
assert label in PrometheusMetricLabels.litellm_cache_hits_metric
assert label in PrometheusMetricLabels.litellm_cache_misses_metric
assert label in PrometheusMetricLabels.litellm_cached_tokens_metric
assert (
label
in PrometheusMetricLabels.litellm_provider_cache_read_input_tokens_metric
)
assert (
label
in PrometheusMetricLabels.litellm_provider_cache_creation_input_tokens_metric
)
def test_increment_cache_metrics_on_cache_hit(self, sample_enum_values):
"""Test that cache hit increments the correct metrics"""
@ -76,12 +93,20 @@ class TestPrometheusCacheMetrics:
"completion_tokens": 50,
"model_group": "openai",
"request_tags": [],
"metadata": {
"usage_object": {
"cache_read_input_tokens": 25,
"cache_creation_input_tokens": 10,
}
},
}
# Create mock metrics
mock_logger.litellm_cache_hits_metric = MagicMock()
mock_logger.litellm_cache_misses_metric = MagicMock()
mock_logger.litellm_cached_tokens_metric = MagicMock()
mock_logger.litellm_provider_cache_read_input_tokens_metric = MagicMock()
mock_logger.litellm_provider_cache_creation_input_tokens_metric = MagicMock()
mock_logger.get_labels_for_metric = MagicMock(
return_value=[
"model",
@ -114,6 +139,14 @@ class TestPrometheusCacheMetrics:
# Verify cache misses metric was NOT called
mock_logger.litellm_cache_misses_metric.labels.assert_not_called()
# Verify provider prompt caching metrics were incremented
mock_logger.litellm_provider_cache_read_input_tokens_metric.labels().inc.assert_called_once_with(
25
)
mock_logger.litellm_provider_cache_creation_input_tokens_metric.labels().inc.assert_called_once_with(
10
)
def test_increment_cache_metrics_on_cache_miss(self, sample_enum_values):
"""Test that cache miss increments the correct metrics"""
# Create mock for PrometheusLogger instance
@ -129,12 +162,20 @@ class TestPrometheusCacheMetrics:
"completion_tokens": 50,
"model_group": "openai",
"request_tags": [],
"metadata": {
"usage_object": {
# Explicit provider field absent -> fallback should use prompt_tokens_details.cached_tokens
"prompt_tokens_details": {"cached_tokens": 20},
}
},
}
# Create mock metrics
mock_logger.litellm_cache_hits_metric = MagicMock()
mock_logger.litellm_cache_misses_metric = MagicMock()
mock_logger.litellm_cached_tokens_metric = MagicMock()
mock_logger.litellm_provider_cache_read_input_tokens_metric = MagicMock()
mock_logger.litellm_provider_cache_creation_input_tokens_metric = MagicMock()
mock_logger.get_labels_for_metric = MagicMock(
return_value=[
"model",
@ -162,6 +203,61 @@ class TestPrometheusCacheMetrics:
mock_logger.litellm_cache_hits_metric.labels.assert_not_called()
mock_logger.litellm_cached_tokens_metric.labels.assert_not_called()
# Provider prompt caching metrics should still be emitted
mock_logger.litellm_provider_cache_read_input_tokens_metric.labels().inc.assert_called_once_with(
20
)
mock_logger.litellm_provider_cache_creation_input_tokens_metric.labels.assert_not_called()
def test_provider_cache_read_does_not_fallback_on_explicit_zero(
    self, sample_enum_values
):
    """An explicit `cache_read_input_tokens` of 0 must not fall back to `cached_tokens`.

    The usage object carries both a zero `cache_read_input_tokens` and a
    non-zero `prompt_tokens_details.cached_tokens`; the provider read
    metric's `labels` must never be called.
    """
    from litellm.integrations.prometheus import PrometheusLogger

    usage_object = {
        "cache_read_input_tokens": 0,
        "prompt_tokens_details": {"cached_tokens": 20},
    }
    standard_logging_payload = {
        "cache_hit": False,
        "total_tokens": 100,
        "prompt_tokens": 50,
        "completion_tokens": 50,
        "model_group": "openai",
        "request_tags": [],
        "metadata": {"usage_object": usage_object},
    }

    # Stand-in for a PrometheusLogger instance with every metric mocked.
    mock_logger = MagicMock()
    for metric_attr in (
        "litellm_cache_hits_metric",
        "litellm_cache_misses_metric",
        "litellm_cached_tokens_metric",
        "litellm_provider_cache_read_input_tokens_metric",
        "litellm_provider_cache_creation_input_tokens_metric",
    ):
        setattr(mock_logger, metric_attr, MagicMock())
    label_names = [
        "model",
        "hashed_api_key",
        "api_key_alias",
        "team",
        "team_alias",
        "end_user",
        "user",
    ]
    mock_logger.get_labels_for_metric = MagicMock(return_value=label_names)

    PrometheusLogger._increment_cache_metrics(
        mock_logger,
        standard_logging_payload=standard_logging_payload,
        enum_values=sample_enum_values,
    )

    # Should not emit read metric, because explicit provider value is zero.
    read_metric = mock_logger.litellm_provider_cache_read_input_tokens_metric
    read_metric.labels.assert_not_called()
def test_increment_cache_metrics_when_cache_hit_is_none(self, sample_enum_values):
"""Test that no metrics are incremented when cache_hit is None"""
# Create mock for PrometheusLogger instance
@ -177,12 +273,19 @@ class TestPrometheusCacheMetrics:
"completion_tokens": 50,
"model_group": "openai",
"request_tags": [],
"metadata": {
"usage_object": {
"cache_read_input_tokens": 25,
}
},
}
# Create mock metrics
mock_logger.litellm_cache_hits_metric = MagicMock()
mock_logger.litellm_cache_misses_metric = MagicMock()
mock_logger.litellm_cached_tokens_metric = MagicMock()
mock_logger.litellm_provider_cache_read_input_tokens_metric = MagicMock()
mock_logger.litellm_provider_cache_creation_input_tokens_metric = MagicMock()
mock_logger.get_labels_for_metric = MagicMock(
return_value=[
"model",
@ -207,6 +310,12 @@ class TestPrometheusCacheMetrics:
mock_logger.litellm_cache_misses_metric.labels.assert_not_called()
mock_logger.litellm_cached_tokens_metric.labels.assert_not_called()
# Provider prompt caching metrics should still be emitted
mock_logger.litellm_provider_cache_read_input_tokens_metric.labels().inc.assert_called_once_with(
25
)
mock_logger.litellm_provider_cache_creation_input_tokens_metric.labels.assert_not_called()
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@ -0,0 +1,43 @@
import pytest
import litellm
from litellm.litellm_core_utils.fallback_utils import async_completion_with_fallbacks
@pytest.mark.asyncio
async def test_fallback_dict_not_mutated(monkeypatch):
    """Re-using one fallback dict across calls must leave that dict untouched.

    `litellm.acompletion` is replaced with a fake that fails for the primary
    model and echoes the model/temperature otherwise, so each call goes
    primary -> fallback.
    """
    fallback_dict = {"model": "fallback-model", "temperature": 0.2}
    snapshot = fallback_dict.copy()
    attempted_models: list[str] = []

    async def _fake_acompletion(*, model: str, **kwargs):
        attempted_models.append(model)
        if model != "primary-model":
            return {"model": model, "temperature": kwargs.get("temperature")}
        raise Exception("primary failed")

    monkeypatch.setattr(litellm, "acompletion", _fake_acompletion)

    # Two identical rounds: the second re-uses the very same dict object and
    # must behave exactly like the first.
    for _ in range(2):
        response = await async_completion_with_fallbacks(
            model="primary-model",
            kwargs={"fallbacks": [fallback_dict]},
        )
        assert response["model"] == "fallback-model"
        assert fallback_dict == snapshot

    assert attempted_models == ["primary-model", "fallback-model"] * 2

View File

@ -0,0 +1,3 @@
{"recordId": "embed-1", "modelInput": {"inputText": "Hello world"}}
{"recordId": "embed-2", "modelInput": {"inputText": "Another document to embed", "dimensions": 512}}
{"recordId": "embed-3", "modelInput": {"inputText": "Single element list", "embeddingTypes": ["binary"]}}

View File

@ -0,0 +1,3 @@
{"custom_id": "embed-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "bedrock/amazon.titan-embed-text-v2:0", "input": "Hello world"}}
{"custom_id": "embed-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "bedrock/amazon.titan-embed-text-v2:0", "input": "Another document to embed", "dimensions": 512}}
{"custom_id": "embed-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "bedrock/amazon.titan-embed-text-v2:0", "input": ["Single element list"], "encoding_format": "base64"}}

View File

@ -426,7 +426,7 @@ class TestBedrockFilesTransformation:
"s3_bucket_name": "litellm-batch-352026",
"s3_region_name": "us-gov-west-1",
}
# aws_region_name set to something different s3_region_name must still win
# aws_region_name set to something different - s3_region_name must still win
optional_params = {"aws_region_name": "us-east-1"}
captured_optional_params: dict = {}
@ -482,3 +482,630 @@ class TestBedrockFilesTransformation:
assert "messages" in model_input
assert "max_tokens" in model_input
assert model_input["max_tokens"] == 10
class TestBedrockFilesEmbeddingTransformation:
"""
Tests for routing OpenAI /v1/embeddings batch JSONL records through the
Titan v2 transformer so AWS Bedrock's CreateModelInvocationJob receives
a valid modelInput body.
Scope is intentionally Titan v2 only - other embedding models will get
their own follow-up PRs/tests so each schema is exercised in isolation.
"""
def test_titan_v2_embedding_jsonl_matches_fixture(self):
    """Transform the OpenAI-style fixture and compare with the expected Bedrock fixture.

    Both fixtures sit next to this test module. They are read as UTF-8
    explicitly so the outcome does not depend on the platform's default
    text encoding.
    """
    import json
    import os

    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    def _load_jsonl(filename):
        # Parse every non-blank line of a fixture file located beside this module.
        fixture_path = os.path.join(os.path.dirname(__file__), filename)
        with open(fixture_path, encoding="utf-8") as handle:
            return [json.loads(line) for line in handle if line.strip()]

    openai_jsonl = _load_jsonl("input_batch_embeddings.jsonl")
    expected = _load_jsonl("expected_bedrock_batch_embeddings.jsonl")

    config = BedrockFilesConfig()
    result = config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
        openai_jsonl
    )
    assert result == expected
def test_titan_v2_simple_string_input(self):
    """A bare string `input` becomes `{"inputText": <str>}` and nothing else."""
    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    record = {
        "custom_id": "e1",
        "method": "POST",
        "url": "/v1/embeddings",
        "body": {
            "model": "bedrock/amazon.titan-embed-text-v2:0",
            "input": "Hello",
        },
    }
    transformed = (
        BedrockFilesConfig()._transform_openai_jsonl_content_to_bedrock_jsonl_content(
            [record]
        )
    )
    expected = [{"recordId": "e1", "modelInput": {"inputText": "Hello"}}]
    assert transformed == expected
def test_titan_v2_dimensions_and_encoding_format(self):
    """`dimensions` is carried over and `encoding_format` becomes `embeddingTypes`."""
    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    request_body = {
        "model": "bedrock/amazon.titan-embed-text-v2:0",
        "input": "Hi",
        "dimensions": 256,
        "encoding_format": "float",
    }
    record = {
        "custom_id": "e1",
        "method": "POST",
        "url": "/v1/embeddings",
        "body": request_body,
    }
    transformed = (
        BedrockFilesConfig()._transform_openai_jsonl_content_to_bedrock_jsonl_content(
            [record]
        )
    )

    model_input = transformed[0]["modelInput"]
    assert model_input["inputText"] == "Hi"
    assert model_input["dimensions"] == 256
    assert model_input["embeddingTypes"] == ["float"]
def test_embedding_routing_falls_back_to_body_shape(self):
    """A record with no `url`/`method` is still treated as an embedding when its body has `input`."""
    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    url_less_record = {
        "custom_id": "e1",
        "body": {
            "model": "bedrock/amazon.titan-embed-text-v2:0",
            "input": "Hello",
        },
    }
    transformed = (
        BedrockFilesConfig()._transform_openai_jsonl_content_to_bedrock_jsonl_content(
            [url_less_record]
        )
    )
    assert transformed[0]["modelInput"] == {"inputText": "Hello"}
def test_embedding_single_element_list_input_is_accepted(self):
    """A one-item list `input` is unwrapped to that item as `inputText`."""
    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    record = {
        "custom_id": "e1",
        "method": "POST",
        "url": "/v1/embeddings",
        "body": {
            "model": "bedrock/amazon.titan-embed-text-v2:0",
            "input": ["only one"],
        },
    }
    transformed = (
        BedrockFilesConfig()._transform_openai_jsonl_content_to_bedrock_jsonl_content(
            [record]
        )
    )
    input_text = transformed[0]["modelInput"]["inputText"]
    assert input_text == "only one"
def test_embedding_multi_input_list_raises(self):
    """An `input` list with more than one string is rejected with a ValueError."""
    import pytest

    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    config = BedrockFilesConfig()
    records = [
        {
            "custom_id": "e1",
            "method": "POST",
            "url": "/v1/embeddings",
            "body": {
                "model": "bedrock/amazon.titan-embed-text-v2:0",
                "input": ["a", "b"],
            },
        }
    ]
    with pytest.raises(ValueError, match="one input per JSONL record"):
        config._transform_openai_jsonl_content_to_bedrock_jsonl_content(records)
def test_embedding_missing_input_raises(self):
    """An /v1/embeddings record whose body has no `input` raises a ValueError."""
    import pytest

    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    config = BedrockFilesConfig()
    records = [
        {
            "custom_id": "e1",
            "method": "POST",
            "url": "/v1/embeddings",
            "body": {"model": "bedrock/amazon.titan-embed-text-v2:0"},
        }
    ]
    with pytest.raises(ValueError, match="missing required `input`"):
        config._transform_openai_jsonl_content_to_bedrock_jsonl_content(records)
def test_mixed_chat_and_embedding_in_same_batch(self):
    """A batch mixing a chat record and an embedding record transforms each by its own path."""
    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    chat_record = {
        "custom_id": "chat-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "bedrock/us.anthropic.claude-haiku-4-5-20251001-v1:0",
            "messages": [{"role": "user", "content": "Hi"}],
            "max_tokens": 5,
        },
    }
    embed_record = {
        "custom_id": "embed-1",
        "method": "POST",
        "url": "/v1/embeddings",
        "body": {
            "model": "bedrock/amazon.titan-embed-text-v2:0",
            "input": "Hi",
        },
    }

    transformed = (
        BedrockFilesConfig()._transform_openai_jsonl_content_to_bedrock_jsonl_content(
            [chat_record, embed_record]
        )
    )

    # Output order follows input order: chat first, embedding second.
    chat_out = transformed[0]
    assert chat_out["recordId"] == "chat-1"
    assert "messages" in chat_out["modelInput"]
    assert chat_out["modelInput"]["anthropic_version"] == "bedrock-2023-05-31"

    embed_out = transformed[1]
    assert embed_out["recordId"] == "embed-1"
    assert embed_out["modelInput"] == {"inputText": "Hi"}
def test_unsupported_embedding_model_raises_not_implemented(self):
    """Embedding models other than Titan v2 raise NotImplementedError naming titan-embed-text-v2."""
    import pytest

    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    config = BedrockFilesConfig()
    unsupported_models = (
        "bedrock/cohere.embed-english-v3",
        "bedrock/amazon.titan-embed-text-v1",
        "bedrock/amazon.titan-embed-image-v1",
        "bedrock/amazon.nova-2-multimodal-embeddings-v1:0",
    )
    for unsupported_model in unsupported_models:
        record = {
            "custom_id": "e1",
            "method": "POST",
            "url": "/v1/embeddings",
            "body": {"model": unsupported_model, "input": "Hi"},
        }
        with pytest.raises(NotImplementedError, match="titan-embed-text-v2"):
            config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
                [record]
            )
def test_titan_v2_model_name_variants_route_correctly(self):
    """Plain, `bedrock/`-prefixed and `us.`-prefixed Titan v2 ids all yield an embedding body."""
    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    config = BedrockFilesConfig()
    titan_v2_ids = (
        "amazon.titan-embed-text-v2:0",
        "bedrock/amazon.titan-embed-text-v2:0",
        "us.amazon.titan-embed-text-v2:0",
        "bedrock/us.amazon.titan-embed-text-v2:0",
    )
    expected_model_input = {"inputText": "Hi"}
    for model_id in titan_v2_ids:
        record = {
            "custom_id": "e1",
            "method": "POST",
            "url": "/v1/embeddings",
            "body": {"model": model_id, "input": "Hi"},
        }
        transformed = config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
            [record]
        )
        assert (
            transformed[0]["modelInput"] == expected_model_input
        ), f"model id {model_id} did not route to Titan v2 embedding path"
def test_pretokenized_input_list_of_ints_raises(self):
    """A list-of-ints `input` is rejected; either of the two error wordings is accepted."""
    import pytest

    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    config = BedrockFilesConfig()
    records = [
        {
            "custom_id": "e1",
            "method": "POST",
            "url": "/v1/embeddings",
            "body": {
                "model": "bedrock/amazon.titan-embed-text-v2:0",
                "input": [1, 2, 3],
            },
        }
    ]
    acceptable_errors = (NotImplementedError, ValueError)
    with pytest.raises(acceptable_errors, match=r"pre-tokenized|one input per"):
        config._transform_openai_jsonl_content_to_bedrock_jsonl_content(records)
def test_pretokenized_single_wrapped_list_raises(self):
    """A single nested list of ints (`[[1, 2, 3]]`) raises NotImplementedError as pre-tokenized."""
    import pytest

    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    config = BedrockFilesConfig()
    records = [
        {
            "custom_id": "e1",
            "method": "POST",
            "url": "/v1/embeddings",
            "body": {
                "model": "bedrock/amazon.titan-embed-text-v2:0",
                "input": [[1, 2, 3]],
            },
        }
    ]
    with pytest.raises(NotImplementedError, match="pre-tokenized"):
        config._transform_openai_jsonl_content_to_bedrock_jsonl_content(records)
def test_record_with_both_input_and_messages_routes_to_chat(self):
    """With no `url` and both `messages` and `input` in the body, the chat shape is produced."""
    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    ambiguous_record = {
        "custom_id": "ambiguous-1",
        "body": {
            "model": "bedrock/us.anthropic.claude-haiku-4-5-20251001-v1:0",
            "messages": [{"role": "user", "content": "Hi"}],
            "input": "this should be ignored by chat path",
            "max_tokens": 5,
        },
    }
    transformed = (
        BedrockFilesConfig()._transform_openai_jsonl_content_to_bedrock_jsonl_content(
            [ambiguous_record]
        )
    )
    model_input = transformed[0]["modelInput"]
    assert "messages" in model_input
    assert "inputText" not in model_input
def test_url_embeddings_with_missing_input_raises_not_chat_error(self):
    """When the url says embeddings but `input` is absent, the embedding-path ValueError is raised."""
    import pytest

    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    config = BedrockFilesConfig()
    input_less_record = {
        "custom_id": "e1",
        "method": "POST",
        "url": "/v1/embeddings",
        "body": {"model": "bedrock/amazon.titan-embed-text-v2:0"},
    }
    with pytest.raises(ValueError, match="missing required `input`"):
        config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
            [input_less_record]
        )
def test_titan_v2_marker_boundary_rejects_lookalikes(self):
    """Ids that merely start like Titan v2 are rejected, while genuine Titan v2 ids match."""
    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    is_titan_v2 = BedrockFilesConfig._is_titan_v2_embed_model

    # Look-alikes that must NOT route through the Titan v2 path
    lookalikes = (
        "bedrock/amazon.titan-embed-text-v20:0",
        "bedrock/amazon.titan-embed-text-v2-experimental:0",
        "bedrock/amazon.titan-embed-text-v2foo",
    )
    # Real Titan v2 ids that MUST match
    genuine_ids = (
        "amazon.titan-embed-text-v2:0",
        "bedrock/amazon.titan-embed-text-v2:0",
        "us.amazon.titan-embed-text-v2:0",
        "arn:aws:bedrock:us-east-1:123:foundation-model/amazon.titan-embed-text-v2:0",
    )

    for model in lookalikes:
        assert not is_titan_v2(
            model
        ), f"{model} unexpectedly matched the Titan v2 marker"

    for model in genuine_ids:
        assert is_titan_v2(model), f"{model} unexpectedly missed the Titan v2 marker"
def test_titan_v2_accepted_when_registry_schema_field_matches(self, mocker):
    """A registry entry whose nested `bedrock_invocation_schema` is "titan_v2" is accepted.

    `litellm.get_model_info` is patched to return that entry.
    """
    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    registry_entry = {
        "provider_specific_entry": {"bedrock_invocation_schema": "titan_v2"}
    }
    mocker.patch("litellm.get_model_info", return_value=registry_entry)

    matched = BedrockFilesConfig._is_titan_v2_embed_model(
        "amazon.titan-embed-text-v2:0"
    )
    assert matched
def test_titan_v2_rejected_when_registry_schema_field_differs(self, mocker):
    """A registry entry declaring another schema value rejects a Titan-v2-looking id.

    `litellm.get_model_info` is patched to report "cohere_v3"; the id
    passed in is a genuine Titan v2 id, yet the check must return falsy.
    """
    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    registry_entry = {
        "provider_specific_entry": {"bedrock_invocation_schema": "cohere_v3"}
    }
    mocker.patch("litellm.get_model_info", return_value=registry_entry)

    # Even though the model id looks like Titan v2, the registry says
    # otherwise and we trust it.
    matched = BedrockFilesConfig._is_titan_v2_embed_model(
        "amazon.titan-embed-text-v2:0"
    )
    assert not matched
def test_titan_v2_falls_back_to_marker_when_registry_lacks_schema_field(
    self, mocker
):
    """A registry entry without the schema field still lets a Titan v2 id match.

    Covers two shapes of patched `litellm.get_model_info` result: no
    `provider_specific_entry` at all, and one present without the
    `bedrock_invocation_schema` key.
    """
    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    schema_less_entries = (
        # No provider_specific_entry at all
        {"mode": "embedding"},
        # provider_specific_entry present but missing the schema key
        {
            "mode": "embedding",
            "provider_specific_entry": {"unrelated": "value"},
        },
    )
    for registry_entry in schema_less_entries:
        mocker.patch("litellm.get_model_info", return_value=registry_entry)
        assert BedrockFilesConfig._is_titan_v2_embed_model(
            "amazon.titan-embed-text-v2:0"
        )
def test_titan_v2_accepted_when_registry_silent(self, mocker):
    """When `litellm.get_model_info` raises, Titan v2 ids are still accepted.

    Exercised with a `us.`-prefixed id and a foundation-model ARN.
    """
    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    mocker.patch("litellm.get_model_info", side_effect=Exception("not mapped"))

    unresolvable_ids = (
        "us.amazon.titan-embed-text-v2:0",
        "arn:aws:bedrock:us-east-1:123:foundation-model/amazon.titan-embed-text-v2:0",
    )
    for model_id in unresolvable_ids:
        assert BedrockFilesConfig._is_titan_v2_embed_model(model_id)
def test_lookup_provider_specific_field_helper(self, mocker):
    """Exercise `_lookup_provider_specific_field` against each registry response shape.

    One case returns the nested string value; every other case (registry
    raising, non-dict results, missing keys, non-string or empty values)
    must yield None.
    """
    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    lookup = BedrockFilesConfig._lookup_provider_specific_field
    schema_field = "bedrock_invocation_schema"

    # Happy path: returns the nested field's string value
    mocker.patch(
        "litellm.get_model_info",
        return_value={"provider_specific_entry": {schema_field: "titan_v2"}},
    )
    assert lookup("anything", schema_field) == "titan_v2"

    # Each entry: (kwargs for mocker.patch, field name to look up). Every
    # one of these must resolve to None.
    none_cases = [
        # Registry raises -> None
        ({"side_effect": Exception("not mapped")}, "any"),
        # Registry returns non-dict -> None
        ({"return_value": "not a dict"}, "any"),
        # Registry returns dict without provider_specific_entry -> None
        ({"return_value": {"mode": "embedding"}}, schema_field),
        # provider_specific_entry exists but isn't a dict -> None
        ({"return_value": {"provider_specific_entry": "not a dict"}}, schema_field),
        # provider_specific_entry dict missing the requested field -> None
        (
            {"return_value": {"provider_specific_entry": {"unrelated": "x"}}},
            schema_field,
        ),
        # Non-string nested value -> None
        (
            {"return_value": {"provider_specific_entry": {schema_field: 42}}},
            schema_field,
        ),
        # Empty-string nested value -> None
        (
            {"return_value": {"provider_specific_entry": {schema_field: ""}}},
            schema_field,
        ),
    ]
    for patch_kwargs, field_name in none_cases:
        mocker.patch("litellm.get_model_info", **patch_kwargs)
        assert lookup("anything", field_name) is None
def test_is_embedding_record_helper(self):
    """`_is_embedding_record` recognises embeddings by url or by an `input` body key."""
    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    is_embedding = BedrockFilesConfig._is_embedding_record

    url_and_input = {"url": "/v1/embeddings", "body": {"input": "x"}}
    assert is_embedding(url_and_input)

    # body-only fallback
    input_only = {"body": {"input": "x"}}
    assert is_embedding(input_only)

    # chat shape
    chat_record = {"url": "/v1/chat/completions", "body": {"messages": []}}
    assert not is_embedding(chat_record)

    # ambiguous body without `input` is treated as not-embedding
    empty_body = {"body": {}}
    assert not is_embedding(empty_body)
def test_explicit_chat_url_with_input_body_short_circuits_to_chat(self):
    """An explicit chat-completions url is never treated as an embedding record.

    Checked twice: directly on `_is_embedding_record` with a body that has
    `input` but no `messages`, and end-to-end with a body carrying both,
    where the output must not contain `inputText`.
    """
    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    chat_model = "bedrock/us.anthropic.claude-haiku-4-5-20251001-v1:0"

    # Direct helper assertion
    input_only_chat_record = {
        "url": "/v1/chat/completions",
        "body": {
            "model": chat_model,
            "input": "this would mis-route under the old precedence",
        },
    }
    assert not BedrockFilesConfig._is_embedding_record(input_only_chat_record)

    # End-to-end: the record goes through the full transform and must not
    # come out with an embedding-style `inputText` body.
    full_chat_record = {
        "custom_id": "explicit-chat-with-input",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": chat_model,
            "messages": [{"role": "user", "content": "Hi"}],
            "input": "should not become inputText",
            "max_tokens": 5,
        },
    }
    config = BedrockFilesConfig()
    transformed = config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
        [full_chat_record]
    )
    model_input = transformed[0]["modelInput"]
    assert (
        "inputText" not in model_input
    ), "explicit chat URL must not produce an embedding-shaped modelInput"
def test_coerce_embedding_input_helper_isolated(self):
    """Exercise `_coerce_embedding_input_to_string` directly: happy and error paths."""
    import pytest

    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    coerce = BedrockFilesConfig._coerce_embedding_input_to_string

    # Happy paths: a bare string and a single-element list of strings.
    for accepted in ("hello", ["hello"]):
        assert coerce(accepted) == "hello"

    # Error path that needs the extra `model` keyword.
    with pytest.raises(ValueError, match="missing required `input`"):
        coerce(None, model="m")

    # Remaining error paths, checked in the same order as before:
    # (expected exception, message pattern, offending input)
    rejected_inputs = [
        (ValueError, "one input per JSONL record", ["a", "b"]),
        # A multi-element list of ints is also reported as "one input per
        # JSONL record" - pre-tokenized vs. "3 strings" cannot be told apart
        # without more context, so the most-actionable error wins.
        (ValueError, "one input per JSONL record", [1, 2, 3]),
        # Single-element list wrapping a token list -> pre-tokenized error.
        (NotImplementedError, "pre-tokenized", [[1, 2, 3]]),
        # Single-element list wrapping a bare int -> pre-tokenized error.
        (NotImplementedError, "pre-tokenized", [42]),
        (ValueError, "must be a string", {"unsupported": True}),
    ]
    for error_type, pattern, bad_input in rejected_inputs:
        with pytest.raises(error_type, match=pattern):
            coerce(bad_input)
def test_other_non_embedding_urls_route_to_chat(self):
    """Every url other than /v1/embeddings is classified as non-embedding."""
    from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

    # "/v1/completions" is the legacy completions endpoint; "/v1/responses"
    # stands in for an arbitrary url - the caller's explicit signal still wins
    # over the embedding-looking body.
    for url in ("/v1/completions", "/v1/responses"):
        record = {"url": url, "body": {"input": "x"}}
        assert not BedrockFilesConfig._is_embedding_record(record)

View File

@ -1,8 +1,9 @@
import json
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import httpx
import pytest
from botocore.credentials import Credentials
def _anthropic_response(url: str) -> httpx.Response:
@ -310,3 +311,54 @@ async def test_anthropic_messages_routes_bedrock_claude_platform_to_messages_api
assert requests[0]["body"]["messages"] == [{"role": "user", "content": "hello"}]
assert requests[0]["body"]["max_tokens"] == 10
assert requests[0]["body"]["model"] == "claude-sonnet-4-6"
def test_sigv4_no_duplicate_content_type_when_caller_sets_lowercase():
    """
    Regression test for duplicated content-type headers during SigV4 signing.

    get_anthropic_headers() supplies "content-type" (lowercase), while
    _sign_request() used to prepend "Content-Type" (uppercase), so both keys
    stayed in the dict. botocore joins them into
    "application/json, application/json" in the canonical string, while the
    wire request sends only one value -> 401.

    Fix: prepend with lowercase "content-type" so **headers overwrites it when
    the caller already set it.
    """
    from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM

    captured: list[dict] = []

    def fake_aws_request(method, url, data, headers):
        # Snapshot the headers handed to AWSRequest so they can be inspected
        # after signing.
        captured.append(dict(headers))
        signed_request = MagicMock()
        signed_request.headers = {"Authorization": "AWS4-HMAC-SHA256 Credential=test"}
        if isinstance(data, str):
            signed_request.body = data.encode()
        else:
            signed_request.body = data
        return signed_request

    llm = BaseAWSLLM()
    mock_credentials = Credentials("key", "secret", "token")
    mock_sigv4 = MagicMock()

    request_payload = {
        "model": "claude-sonnet-4-6",
        "messages": [],
        "max_tokens": 10,
    }

    with (
        patch("botocore.auth.SigV4Auth", return_value=mock_sigv4),
        patch("botocore.awsrequest.AWSRequest", side_effect=fake_aws_request),
        patch.object(llm, "get_credentials", return_value=mock_credentials),
        patch.object(llm, "_get_aws_region_name", return_value="us-east-1"),
    ):
        llm._sign_request(
            service_name="aws-external-anthropic",
            headers={"content-type": "application/json"},
            optional_params={"aws_region_name": "us-east-1"},
            request_data=request_payload,
            api_base="https://aws-external-anthropic.us-east-1.api.aws/v1/messages",
        )

    signed = captured[0]
    ct_keys = [k for k in signed if k.lower() == "content-type"]
    assert ct_keys == ["content-type"], (
        f"Expected exactly one 'content-type' key, got {ct_keys}. "
        "Duplicate keys produce 'application/json, application/json' in the "
        "SigV4 canonical string and cause a 401."
    )

View File

@ -0,0 +1,67 @@
"""
Tests for Black Forest Labs common_utils — specifically assert_bfl_polling_url.
BFL uses regional subdomains (e.g. gateway.bfl.ai) for polling URLs that
differ from the submission host (api.bfl.ai). These tests verify that the
domain-aware check accepts legitimate BFL subdomains while still rejecting
off-domain and non-HTTPS URLs.
"""
import pytest
from litellm.llms.black_forest_labs.common_utils import (
BlackForestLabsError,
assert_bfl_polling_url,
)
class TestAssertBflPollingUrl:
    """Accept/reject cases for `assert_bfl_polling_url`."""

    # --- should pass ---

    def test_exact_registered_domain(self):
        url = "https://bfl.ai/v1/get_result?id=abc"
        assert_bfl_polling_url(url)

    def test_api_subdomain(self):
        url = "https://api.bfl.ai/v1/get_result?id=abc"
        assert_bfl_polling_url(url)

    def test_gateway_subdomain(self):
        # BFL uses gateway.bfl.ai for polling — this was the original bug trigger
        url = "https://gateway.bfl.ai/v1/get_result?id=abc"
        assert_bfl_polling_url(url)

    def test_regional_subdomain(self):
        url = "https://eu.api.bfl.ai/v1/get_result?id=abc"
        assert_bfl_polling_url(url)

    def test_deep_subdomain(self):
        url = "https://region.gateway.bfl.ai/poll?id=xyz"
        assert_bfl_polling_url(url)

    # --- should raise BlackForestLabsError ---

    def test_rejects_http_scheme(self):
        # HTTP must be rejected — x-key would be forwarded in plaintext
        url = "http://api.bfl.ai/v1/get_result?id=abc"
        with pytest.raises(BlackForestLabsError, match="scheme must be https"):
            assert_bfl_polling_url(url)

    def test_rejects_off_domain(self):
        url = "https://evil.com/steal-key"
        with pytest.raises(BlackForestLabsError, match="host is not within"):
            assert_bfl_polling_url(url)

    def test_rejects_lookalike_domain(self):
        url = "https://notbfl.ai/v1/get_result?id=abc"
        with pytest.raises(BlackForestLabsError, match="host is not within"):
            assert_bfl_polling_url(url)

    def test_rejects_bfl_ai_as_suffix_only(self):
        # "fakebfl.ai" must not match — the check is on registered domain boundary
        url = "https://fakebfl.ai/v1/get_result?id=abc"
        with pytest.raises(BlackForestLabsError, match="host is not within"):
            assert_bfl_polling_url(url)

    def test_rejects_bfl_in_path(self):
        # The BFL domain appears only in the path, not in the host.
        url = "https://evil.com/bfl.ai/steal"
        with pytest.raises(BlackForestLabsError, match="host is not within"):
            assert_bfl_polling_url(url)

    def test_rejects_ftp_scheme(self):
        url = "ftp://api.bfl.ai/v1/get_result?id=abc"
        with pytest.raises(BlackForestLabsError, match="scheme must be https"):
            assert_bfl_polling_url(url)

    def test_rejects_javascript_scheme(self):
        url = "javascript://api.bfl.ai/alert(1)"
        with pytest.raises(BlackForestLabsError, match="scheme must be https"):
            assert_bfl_polling_url(url)

View File

@ -8,8 +8,7 @@ with guardrail transformations, including tool calls.
import json
import os
import sys
from typing import Any, List, Literal, Optional, Tuple
from unittest.mock import AsyncMock, MagicMock
from typing import Any, Literal, Optional
import pytest
@ -84,6 +83,70 @@ class MockGuardrail(CustomGuardrail):
return result
class MockCopiedToolCallGuardrail(CustomGuardrail):
    """Mock guardrail that hands back rewritten copies of the tool calls.

    The incoming tool calls are never mutated; each one is shallow-copied and
    the copy's function arguments are replaced with a redacted payload.
    """

    async def apply_guardrail(
        self,
        inputs: GenericGuardrailAPIInputs,
        request_data: dict,
        input_type: Literal["request", "response"],
        logging_obj: Optional[Any] = None,
    ) -> GenericGuardrailAPIInputs:
        copied_tool_calls = []
        for original_call in inputs.get("tool_calls", []):
            call_copy = dict(original_call)
            # Copy the nested function dict too, so the input stays untouched.
            call_copy["function"] = dict(
                call_copy["function"],
                arguments=json.dumps({"email": "[EMAIL]"}),
            )
            copied_tool_calls.append(call_copy)

        return GenericGuardrailAPIInputs(
            texts=inputs.get("texts", []),
            tool_calls=copied_tool_calls,
        )
class MockNonListToolCallGuardrail(CustomGuardrail):
    """Mock guardrail whose result carries `tool_calls` as a dict envelope.

    This mimics guardrails that assign a detection API JSON dict to
    `tool_calls` on the response path instead of a list.
    """

    async def apply_guardrail(
        self,
        inputs: GenericGuardrailAPIInputs,
        request_data: dict,
        input_type: Literal["request", "response"],
        logging_obj: Optional[Any] = None,
    ) -> GenericGuardrailAPIInputs:
        passthrough_texts = inputs.get("texts", [])
        envelope = {"verdict": "allow", "detections": []}

        result = GenericGuardrailAPIInputs(texts=passthrough_texts)
        # Deliberately the wrong shape: a dict where a list is expected.
        result["tool_calls"] = envelope  # type: ignore
        return result
class MockMisalignedToolCallGuardrail(CustomGuardrail):
    """Mock guardrail that returns at most one tool call, whatever it was given.

    With more than one input tool call the returned list is shorter than the
    input, so it cannot be applied positionally onto the response.
    """

    async def apply_guardrail(
        self,
        inputs: GenericGuardrailAPIInputs,
        request_data: dict,
        input_type: Literal["request", "response"],
        logging_obj: Optional[Any] = None,
    ) -> GenericGuardrailAPIInputs:
        incoming = inputs.get("tool_calls", [])
        if incoming:
            # Keep only a rewritten copy of the first call.
            replacement = {"name": "x", "arguments": json.dumps({"x": 1})}
            shortened = [dict(incoming[0], function=replacement)]
        else:
            shortened = []

        return GenericGuardrailAPIInputs(
            texts=inputs.get("texts", []),
            tool_calls=shortened,
        )
class TestOpenAIChatCompletionsHandlerToolsInput:
"""Test input processing with tools (function definitions)"""
@ -740,6 +803,131 @@ class TestOpenAIChatCompletionsHandlerToolCallsOutput:
assert response.model == "gpt-4o-mini"
assert response.choices[0].finish_reason == "tool_calls"
@pytest.mark.asyncio
async def test_output_response_uses_returned_guardrailed_tool_calls(self):
    """Tool calls returned by the guardrail are written back to the response,
    even when the guardrail did not mutate its inputs."""
    email_call = ChatCompletionMessageToolCall(
        id="call_email",
        type="function",
        function=Function(
            name="send_email",
            arguments=json.dumps({"email": "john@example.com"}),
        ),
    )
    assistant_message = Message(
        content=None,
        role="assistant",
        tool_calls=[email_call],
    )
    response = ModelResponse(
        id="chatcmpl-tool-copy",
        created=1234567890,
        model="gpt-4",
        object="chat.completion",
        choices=[
            Choices(
                finish_reason="tool_calls",
                index=0,
                message=assistant_message,
            )
        ],
    )

    handler = OpenAIChatCompletionsHandler()
    guardrail = MockCopiedToolCallGuardrail(guardrail_name="test")
    await handler.process_output_response(response, guardrail)

    rewritten_call = response.choices[0].message.tool_calls[0]
    assert rewritten_call.function.name == "send_email"
    assert json.loads(rewritten_call.function.arguments) == {"email": "[EMAIL]"}
@pytest.mark.asyncio
async def test_output_response_ignores_non_list_returned_tool_calls(self):
    """When the guardrail returns tool_calls as a non-list (e.g. a
    detection-API envelope dict), processing must not crash and the original
    arguments stay in place."""
    original = json.dumps({"email": "john@example.com"})

    email_call = ChatCompletionMessageToolCall(
        id="call_email",
        type="function",
        function=Function(name="send_email", arguments=original),
    )
    assistant_message = Message(
        content=None,
        role="assistant",
        tool_calls=[email_call],
    )
    response = ModelResponse(
        id="chatcmpl-nonlist",
        created=1234567890,
        model="gpt-4",
        object="chat.completion",
        choices=[
            Choices(
                finish_reason="tool_calls",
                index=0,
                message=assistant_message,
            )
        ],
    )

    handler = OpenAIChatCompletionsHandler()
    guardrail = MockNonListToolCallGuardrail(guardrail_name="test")
    await handler.process_output_response(response, guardrail)

    surviving_call = response.choices[0].message.tool_calls[0]
    assert surviving_call.function.arguments == original
@pytest.mark.asyncio
async def test_output_response_ignores_misaligned_returned_tool_calls(self):
    """A returned tool_calls list whose length differs from the input cannot be
    applied positionally; the original arguments must be preserved rather than
    written onto the wrong tool call."""
    first_args = json.dumps({"email": "a@example.com"})
    second_args = json.dumps({"email": "b@example.com"})

    # Two tool calls go in; the mock guardrail only returns one.
    original_calls = [
        ChatCompletionMessageToolCall(
            id=call_id,
            type="function",
            function=Function(name="send_email", arguments=call_args),
        )
        for call_id, call_args in (("call_1", first_args), ("call_2", second_args))
    ]
    assistant_message = Message(
        content=None,
        role="assistant",
        tool_calls=original_calls,
    )
    response = ModelResponse(
        id="chatcmpl-misaligned",
        created=1234567890,
        model="gpt-4",
        object="chat.completion",
        choices=[
            Choices(
                finish_reason="tool_calls",
                index=0,
                message=assistant_message,
            )
        ],
    )

    handler = OpenAIChatCompletionsHandler()
    guardrail = MockMisalignedToolCallGuardrail(guardrail_name="test")
    await handler.process_output_response(response, guardrail)

    tool_calls = response.choices[0].message.tool_calls
    assert tool_calls[0].function.arguments == first_args
    assert tool_calls[1].function.arguments == second_args
class MockPassThroughGuardrail(CustomGuardrail):
"""Mock guardrail that passes through without blocking - for testing streaming fallback behavior"""
@ -765,7 +953,7 @@ class TestOpenAIChatCompletionsHandlerStreamingOutput:
This test verifies the fix for the bug where accessing chunk.choices[0]
would raise IndexError when a streaming chunk has an empty choices list.
"""
from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices
from litellm.types.utils import ModelResponseStream
handler = OpenAIChatCompletionsHandler()
guardrail = MockPassThroughGuardrail(guardrail_name="test")

View File

@ -0,0 +1,84 @@
"""
Tests for Tensormesh provider configuration and integration.
"""
import litellm
class TestTensormeshProviderConfig:
    """Checks that the Tensormesh provider is registered and resolvable."""

    def test_tensormesh_in_provider_list(self):
        """The enum member exists and the provider is listed in litellm."""
        from litellm import LlmProviders

        assert hasattr(LlmProviders, "TENSORMESH")
        assert LlmProviders.TENSORMESH.value == "tensormesh"
        assert "tensormesh" in litellm.provider_list

    def test_tensormesh_json_config_exists(self):
        """The JSON provider registry holds the expected Tensormesh entry."""
        from litellm.llms.openai_like.json_loader import JSONProviderRegistry

        assert JSONProviderRegistry.exists("tensormesh")

        entry = JSONProviderRegistry.get("tensormesh")
        assert entry is not None
        assert entry.base_url == "https://serverless.tensormesh.ai/v1"
        assert entry.api_key_env == "TENSORMESH_INFERENCE_API_KEY"
        assert entry.api_base_env == "TENSORMESH_SERVERLESS_BASE_URL"
        assert entry.param_mappings.get("max_completion_tokens") == "max_tokens"

    def test_tensormesh_provider_resolution(self):
        """Resolving a tensormesh/ model yields the provider and default base URL."""
        from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider

        resolved = get_llm_provider(
            model="tensormesh/openai/gpt-oss-120b",
            custom_llm_provider=None,
            api_base=None,
            api_key=None,
        )
        model, provider, api_key, api_base = resolved

        assert model == "openai/gpt-oss-120b"
        assert provider == "tensormesh"
        assert api_base == "https://serverless.tensormesh.ai/v1"

    def test_tensormesh_api_base_override(self):
        """Explicit api_base / api_key values are returned instead of the default."""
        from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider

        custom_base = "https://custom.example.com/v1"
        custom_key = "sk-test"

        resolved = get_llm_provider(
            model="tensormesh/openai/gpt-oss-120b",
            custom_llm_provider=None,
            api_base=custom_base,
            api_key=custom_key,
        )
        model, provider, api_key, api_base = resolved

        assert provider == "tensormesh"
        assert api_base == "https://custom.example.com/v1"
        assert api_key == "sk-test"

    def test_tensormesh_text_completion_enabled(self):
        """Tensormesh is listed among the text-completion (/completions)
        compatible providers, matching the text_completion flag in
        provider_endpoints_support.json."""
        text_completion_providers = litellm.openai_text_completion_compatible_providers
        assert "tensormesh" in text_completion_providers

    def test_tensormesh_router_config(self):
        """A Router can be built with a tensormesh deployment."""
        from litellm import Router

        deployment = {
            "model_name": "tensormesh-chat",
            "litellm_params": {
                "model": "tensormesh/openai/gpt-oss-120b",
                "api_key": "test-key",
            },
        }
        router = Router(model_list=[deployment])

        assert len(router.model_list) == 1
        assert router.model_list[0]["model_name"] == "tensormesh-chat"

View File

@ -0,0 +1,282 @@
"""
Unit tests for WatsonxPassthroughConfig transformation.
Tests the Watsonx-specific passthrough configuration including URL construction,
streaming detection, and authentication handling.
"""
import os
import sys
from unittest.mock import MagicMock, patch
import httpx
import pytest
sys.path.insert(0, os.path.abspath("../../../../.."))
import litellm
from litellm.llms.watsonx.passthrough.transformation import WatsonxPassthroughConfig
class TestWatsonxPassthroughConfig:
    """Unit tests for the WatsonxPassthroughConfig helpers."""

    def test_is_streaming_request_true(self):
        """`stream=True` in the request data is reported as streaming."""
        config = WatsonxPassthroughConfig()
        payload = {"stream": True, "input": "test"}

        streaming = config.is_streaming_request(
            endpoint="ml/v1/text/generation", request_data=payload
        )

        assert streaming is True

    def test_is_streaming_request_false(self):
        """`stream=False` in the request data is reported as not streaming."""
        config = WatsonxPassthroughConfig()
        payload = {"stream": False, "input": "test"}

        streaming = config.is_streaming_request(
            endpoint="ml/v1/text/generation", request_data=payload
        )

        assert streaming is False

    def test_is_streaming_request_missing_stream_key(self):
        """Without a `stream` key the result is False."""
        config = WatsonxPassthroughConfig()
        payload = {"input": "test"}

        streaming = config.is_streaming_request(
            endpoint="ml/v1/text/generation", request_data=payload
        )

        assert streaming is False

    def test_get_complete_url_with_api_base(self):
        """An explicit api_base is used as the prefix of the built URL."""
        config = WatsonxPassthroughConfig()
        api_base = "https://us-south.ml.cloud.ibm.com"
        endpoint = "ml/v1/text/generation"

        complete_url, base_target_url = config.get_complete_url(
            api_base=api_base,
            api_key=None,
            model="ibm/granite-13b-chat-v2",
            endpoint=endpoint,
            request_query_params={"version": "2024-03-19"},
            litellm_params={},
        )

        assert isinstance(complete_url, httpx.URL)
        url_text = str(complete_url)
        assert url_text.startswith(api_base)
        assert endpoint in url_text
        assert "version=2024-03-19" in url_text
        assert base_target_url == api_base

    @patch("litellm.llms.watsonx.common_utils.get_secret_str")
    def test_get_complete_url_with_env_api_base(self, mock_get_secret):
        """With api_base=None the base comes from the (patched) secret lookup."""
        config = WatsonxPassthroughConfig()
        env_api_base = "https://eu-de.ml.cloud.ibm.com"
        mock_get_secret.return_value = env_api_base
        endpoint = "ml/v1/text/tokenization"

        complete_url, base_target_url = config.get_complete_url(
            api_base=None,
            api_key=None,
            model="ibm/granite-13b-chat-v2",
            endpoint=endpoint,
            request_query_params={"version": "2024-03-19"},
            litellm_params={},
        )

        assert isinstance(complete_url, httpx.URL)
        url_text = str(complete_url)
        assert url_text.startswith(env_api_base)
        assert endpoint in url_text
        assert base_target_url == env_api_base

    def test_get_complete_url_with_query_params(self):
        """Query parameters show up in the built URL."""
        config = WatsonxPassthroughConfig()

        complete_url, _ = config.get_complete_url(
            api_base="https://us-south.ml.cloud.ibm.com",
            api_key=None,
            model="ibm/granite-13b-chat-v2",
            endpoint="ml/v1/text/generation",
            request_query_params={
                "version": "2024-03-19",
            },
            litellm_params={},
        )

        assert "version=2024-03-19" in str(complete_url)

    def test_get_complete_url_without_query_params(self):
        """With no query parameters the URL is exactly `<api_base>/<endpoint>`."""
        config = WatsonxPassthroughConfig()
        api_base = "https://us-south.ml.cloud.ibm.com"
        endpoint = "ml/v1/models"

        complete_url, base_target_url = config.get_complete_url(
            api_base=api_base,
            api_key=None,
            model="",
            endpoint=endpoint,
            request_query_params=None,
            litellm_params={},
        )

        assert isinstance(complete_url, httpx.URL)
        url_text = str(complete_url)
        assert url_text == f"{api_base}/{endpoint}"
        assert base_target_url == api_base
        assert "version=2024-03-19" not in url_text

    @patch("litellm.llms.watsonx.common_utils.get_secret_str")
    def test_get_api_base_with_explicit_value(self, mock_get_secret):
        """An explicit api_base is returned without consulting the secret lookup."""
        explicit_base = "https://custom.watsonx.com"

        resolved = WatsonxPassthroughConfig.get_api_base(api_base=explicit_base)

        assert resolved == explicit_base
        mock_get_secret.assert_not_called()

    @patch("litellm.llms.watsonx.common_utils.get_secret_str")
    def test_get_api_base_from_environment(self, mock_get_secret):
        """With api_base=None the value is read via WATSONX_API_BASE."""
        env_base = "https://env.watsonx.com"
        mock_get_secret.return_value = env_base

        resolved = WatsonxPassthroughConfig.get_api_base(api_base=None)

        assert resolved == env_base
        mock_get_secret.assert_called_once_with("WATSONX_API_BASE")

    @patch("litellm.llms.watsonx.common_utils.get_secret_str")
    def test_get_api_key_with_explicit_value(self, mock_get_secret):
        """An explicit api_key is returned without consulting the secret lookup."""
        explicit_key = "test-api-key-123"

        resolved = WatsonxPassthroughConfig.get_api_key(api_key=explicit_key)

        assert resolved == explicit_key
        mock_get_secret.assert_not_called()

    @patch("litellm.llms.watsonx.common_utils.get_secret_str")
    def test_get_api_key_from_environment(self, mock_get_secret):
        """With api_key=None the lookup includes WATSONX_APIKEY."""
        env_key = "env-api-key-456"
        mock_get_secret.return_value = env_key

        resolved = WatsonxPassthroughConfig.get_api_key(api_key=None)

        assert resolved == env_key
        mock_get_secret.assert_any_call("WATSONX_APIKEY")

    def test_get_base_model_returns_model(self):
        """get_base_model hands a plain model id back unchanged."""
        model = "ibm/granite-13b-chat-v2"
        assert WatsonxPassthroughConfig.get_base_model(model) == model

    def test_get_base_model_with_deployment(self):
        """get_base_model hands a deployment-style id back unchanged."""
        model = "deployment/test-deployment-id"
        assert WatsonxPassthroughConfig.get_base_model(model) == model

    def test_get_complete_url_with_different_endpoints(self):
        """Each of several endpoint paths ends up in the built URL."""
        config = WatsonxPassthroughConfig()
        api_base = "https://us-south.ml.cloud.ibm.com"
        endpoint_paths = (
            "ml/v1/text/generation",
            "ml/v1/text/tokenization",
            "ml/v1/deployments/test-id/text/generation",
            "ml/v1/models",
            "ml/v1/foundation_model_specs",
        )

        for path in endpoint_paths:
            complete_url, base_target_url = config.get_complete_url(
                api_base=api_base,
                api_key=None,
                model="",
                endpoint=path,
                request_query_params={"version": "2024-03-19"},
                litellm_params={},
            )

            assert isinstance(complete_url, httpx.URL)
            assert path in str(complete_url)
            assert base_target_url == api_base

    def test_get_complete_url_preserves_query_param_order(self):
        """All supplied query parameters keep their values in the built URL."""
        config = WatsonxPassthroughConfig()
        query = {
            "version": "2024-03-19",
            "project_id": "abc-123",
            "space_id": "xyz-789",
        }

        complete_url, _ = config.get_complete_url(
            api_base="https://us-south.ml.cloud.ibm.com",
            api_key=None,
            model="",
            endpoint="ml/v1/text/generation",
            request_query_params=query,
            litellm_params={},
        )

        url_text = str(complete_url)
        # Only presence is checked here, one assertion per parameter.
        assert "version=2024-03-19" in url_text
        assert "project_id=abc-123" in url_text
        assert "space_id=xyz-789" in url_text

    def test_is_streaming_request_with_various_stream_values(self):
        """The `stream` value is handed back as-is, whatever its type."""
        config = WatsonxPassthroughConfig()

        # Real booleans come back as the same boolean.
        assert config.is_streaming_request("endpoint", {"stream": True}) is True
        assert config.is_streaming_request("endpoint", {"stream": False}) is False

        # A truthy string is returned unchanged rather than coerced to bool.
        from_string = config.is_streaming_request("endpoint", {"stream": "true"})
        assert from_string == "true"  # Returns the value as-is from .get()

        # Integers (truthy 1, falsy 0) are returned unchanged as well.
        from_one = config.is_streaming_request("endpoint", {"stream": 1})
        assert from_one == 1
        from_zero = config.is_streaming_request("endpoint", {"stream": 0})
        assert from_zero == 0

        # An explicit None stays None.
        from_none = config.is_streaming_request("endpoint", {"stream": None})
        assert from_none is None

        # A missing key falls back to False.
        assert config.is_streaming_request("endpoint", {}) is False

View File

@ -5375,5 +5375,93 @@ class TestPanwAirsDualScanIndependence:
assert mcp_call.get("content") is None
class TestPanwAirsTimeoutCoercion:
    """Regression tests for string-valued timeout handling.

    Before the fix, a string `timeout` (which is what the dashboard UI persists
    and what raw YAML preserves if quoted) survived into httpx, which raised
    `TypeError: '<=' not supported between instances of 'str' and 'int'`. The
    broad except in apply_guardrail swallowed it and the proxy returned a
    misleading 500 'Security scan failed - request blocked for safety'.
    """

    def test_handler_coerces_string_timeout_to_float(self):
        coerced = make_handler(timeout="30").timeout
        assert coerced == 30.0
        assert isinstance(coerced, float)

    def test_handler_accepts_int_timeout(self):
        assert make_handler(timeout=15).timeout == 15.0

    def test_handler_accepts_float_timeout(self):
        assert make_handler(timeout=7.5).timeout == 7.5

    def test_handler_none_timeout_falls_back_to_default(self):
        assert make_handler(timeout=None).timeout == 10.0

    def test_handler_omitted_timeout_uses_default(self):
        assert make_handler().timeout == 10.0

    def test_litellm_params_coerces_string_timeout(self):
        """Boundary validation: the Pydantic model itself should normalize
        string timeouts before any handler reads the value via model_dump()."""
        fields = {
            "guardrail": "panw_prisma_airs",
            "mode": "pre_call",
            "api_key": "test_key",
            "profile_name": "test_profile",
            "timeout": "30",
        }
        params = LitellmParams(**fields)

        assert params.timeout == 30.0
        assert isinstance(params.timeout, float)

    def test_litellm_params_rejects_garbage_timeout(self):
        """A timeout string that is not a number is rejected with ValueError."""
        fields = {
            "guardrail": "panw_prisma_airs",
            "mode": "pre_call",
            "api_key": "test_key",
            "profile_name": "test_profile",
            "timeout": "not-a-number",
        }
        with pytest.raises(ValueError):
            LitellmParams(**fields)

    def test_litellm_params_empty_string_timeout_becomes_none(self):
        """Empty-string timeout (which the dashboard form can send) should
        be coerced to None, not crash, and not produce float('')."""
        fields = {
            "guardrail": "panw_prisma_airs",
            "mode": "pre_call",
            "api_key": "test_key",
            "profile_name": "test_profile",
            "timeout": "",
        }
        params = LitellmParams(**fields)

        assert params.timeout is None

    def test_legacy_initializer_handles_unset_timeout(self):
        """Regression guard: with timeout now a declared Optional[float] = None
        on BaseLitellmParams, the legacy panw initializer at
        guardrail_initializers.py:220 must not crash on float(None) when the
        caller omits timeout entirely."""
        from litellm.proxy.guardrails.guardrail_initializers import (
            initialize_panw_prisma_airs,
        )

        # timeout intentionally omitted - field defaults to None
        fields = {
            "guardrail": "panw_prisma_airs",
            "mode": "pre_call",
            "api_key": "test_key",
            "profile_name": "test_profile",
        }
        params = LitellmParams(**fields)
        guardrail_config = {"guardrail_name": "test_legacy"}

        handler = initialize_panw_prisma_airs(params, guardrail_config)

        # Default fallback applied, not crashed on float(None)
        assert handler.timeout == 10.0
# Script entry point: when this file is executed directly, run its tests
# through pytest in verbose mode.
if __name__ == "__main__":
pytest.main([__file__, "-v"])

View File

@ -0,0 +1,900 @@
import json
import logging
import ssl
from types import SimpleNamespace
from typing import Any, List
import httpx
import pytest
from litellm.exceptions import GuardrailRaisedException
from litellm.exceptions import Timeout as LiteLLMTimeout
from litellm.proxy.guardrails.guardrail_hooks.vigil_guard import (
VigilGuardGuardrail,
guardrail_class_registry,
guardrail_initializer_registry,
initialize_guardrail,
)
from litellm.proxy.guardrails.guardrail_hooks.vigil_guard.vigil_guard import (
_DEFAULT_VIGIL_TIMEOUT,
VigilGuardMissingConfig,
)
from litellm.types.guardrails import LitellmParams, SupportedGuardrailIntegrations
from litellm.types.proxy.guardrails.guardrail_hooks.vigil_guard import (
VigilGuardGuardrailConfigModel,
)
# URL used for the httpx.Request objects attached to fake responses and
# transport errors built by the helpers in this module.
_ENDPOINT = "https://vigil.test/v1/guard/analyze"
def _resp(body: dict, status_code: int = 200) -> httpx.Response:
    """Build an httpx.Response with a JSON body, tied to a POST to `_ENDPOINT`."""
    originating_request = httpx.Request("POST", _ENDPOINT)
    return httpx.Response(
        status_code=status_code,
        json=body,
        request=originating_request,
    )
class FakeHandler:
    """Scripted stand-in for an async HTTP handler.

    It is constructed with a list of programmed outcomes. Each `post` call
    records its keyword arguments in `calls` and then consumes the next
    outcome: exception instances are raised, anything else is returned.
    """

    def __init__(self, items: List[Any]):
        # Copy the list so the caller's own list is never consumed.
        self._items = list(items)
        self.calls: List[SimpleNamespace] = []

    async def post(self, *, url, headers, json, timeout=None):  # noqa: A002
        """Record the call, then raise or return the next programmed outcome."""
        record = SimpleNamespace(url=url, headers=headers, json=json, timeout=timeout)
        self.calls.append(record)

        if self._items:
            outcome = self._items.pop(0)
            if isinstance(outcome, BaseException):
                raise outcome
            return outcome

        # More calls than programmed outcomes points at a bug in the test.
        raise AssertionError("FakeHandler ran out of programmed responses")
def _make_guardrail(
    handler: FakeHandler,
    *,
    unreachable_fallback="fail_closed",
    api_base="https://vigil.test",
    api_key="vg_secret_key_123",
    guardrail_name="vigil-guard",
    timeout=None,
) -> VigilGuardGuardrail:
    """Create a VigilGuardGuardrail wired to `handler`, with test defaults.

    The keyword-only arguments are forwarded unchanged; the guardrail is
    always created with event_hook="pre_call" and default_on=True.
    """
    constructor_kwargs = {
        "api_base": api_base,
        "api_key": api_key,
        "unreachable_fallback": unreachable_fallback,
        "timeout": timeout,
        "async_handler": handler,
        "guardrail_name": guardrail_name,
        "event_hook": "pre_call",
        "default_on": True,
    }
    return VigilGuardGuardrail(**constructor_kwargs)
def _transient_exceptions() -> List[BaseException]:
    """Return one instance of each transport/timeout error used by the tests."""
    request = httpx.Request("POST", _ENDPOINT)
    httpx_error_types = (
        httpx.ConnectError,
        httpx.ConnectTimeout,
        httpx.ReadTimeout,
        httpx.RemoteProtocolError,
    )
    errors: List[BaseException] = [
        error_type("boom", request=request) for error_type in httpx_error_types
    ]
    # litellm's own timeout type comes last, after the four httpx errors.
    errors.append(LiteLLMTimeout(message="t", model="m", llm_provider="vigil_guard"))
    return errors
def test_requires_api_base(monkeypatch):
    """Construction fails when no api_base is passed and the env vars are unset."""
    for env_var in ("VIGIL_GUARD_URL", "VIGIL_GUARD_API_KEY"):
        monkeypatch.delenv(env_var, raising=False)

    with pytest.raises(VigilGuardMissingConfig):
        VigilGuardGuardrail(api_key="k", async_handler=FakeHandler([]))
def test_requires_api_key(monkeypatch):
    """Construction fails when no api_key is passed and its env var is unset."""
    monkeypatch.delenv("VIGIL_GUARD_API_KEY", raising=False)

    with pytest.raises(VigilGuardMissingConfig):
        VigilGuardGuardrail(
            api_base="https://vigil.test",
            async_handler=FakeHandler([]),
        )
def test_trailing_slash_stripped():
    """A trailing slash on api_base is gone from the stored value."""
    handler = FakeHandler([])
    guardrail = _make_guardrail(handler, api_base="https://vigil.test/")
    assert guardrail.api_base == "https://vigil.test"
def test_env_fallback(monkeypatch):
    """api_base and api_key are taken from the environment when not passed."""
    environment = {
        "VIGIL_GUARD_URL": "https://env.vigil.test",
        "VIGIL_GUARD_API_KEY": "env_key",
    }
    for name, value in environment.items():
        monkeypatch.setenv(name, value)

    guardrail = VigilGuardGuardrail(
        async_handler=FakeHandler([]),
        guardrail_name="vg",
        event_hook="pre_call",
        default_on=True,
    )

    assert guardrail.api_base == "https://env.vigil.test"
    assert guardrail.api_key == "env_key"
def test_default_unreachable_fallback_is_fail_closed():
    """Passing unreachable_fallback=None results in "fail_closed"."""
    handler = FakeHandler([])
    guardrail = _make_guardrail(handler, unreachable_fallback=None)
    assert guardrail.unreachable_fallback == "fail_closed"
def test_explicit_fail_open_is_stored():
    """An explicit "fail_open" setting is kept as given."""
    handler = FakeHandler([])
    guardrail = _make_guardrail(handler, unreachable_fallback="fail_open")
    assert guardrail.unreachable_fallback == "fail_open"
def test_unknown_fallback_defaults_to_fail_closed():
    """An unrecognised fallback value ends up as "fail_closed"."""
    handler = FakeHandler([])
    guardrail = _make_guardrail(handler, unreachable_fallback="weird")
    assert guardrail.unreachable_fallback == "fail_closed"
async def test_allowed_preserves_full_input_shape_and_logs_allow():
    """An ALLOWED decision returns a new dict with every input key and logs "allow"."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)
    messages = [{"role": "user", "content": "hello"}]
    payload = {
        "texts": ["hello"],
        "structured_messages": messages,
        "model": "gpt-4o",
    }
    request_data = {"metadata": {}}

    result = await guardrail.apply_guardrail(
        inputs=payload,
        request_data=request_data,
        input_type="request",
        logging_obj=None,
    )

    assert result["texts"] == ["hello"]
    # The message list is passed through by identity rather than copied.
    assert result["structured_messages"] is messages
    assert result["model"] == "gpt-4o"
    # The returned dict is a fresh object and the caller's dict is untouched.
    assert result is not payload
    assert payload["structured_messages"] is messages
    assert len(fake_handler.calls) == 1
    logged = request_data["metadata"]["standard_logging_guardrail_information"]
    assert logged[0]["guardrail_response"] == "allow"
async def test_sanitized_replaces_text():
    """A SANITIZED decision swaps the scanned text for the returned sanitizedText."""
    sanitized_response = _resp(
        {"decision": "SANITIZED", "sanitizedText": "[REDACTED]"}
    )
    guardrail = _make_guardrail(FakeHandler([sanitized_response]))

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["my ssn is 123"]},
        request_data={},
        input_type="request",
    )

    assert result["texts"] == ["[REDACTED]"]
@pytest.mark.parametrize(
    "body,expected",
    [
        # sanitizedText wins when both fields are strings.
        ({"decision": "SANITIZED", "sanitizedText": "S", "outputText": "O"}, "S"),
        # outputText is used when sanitizedText is absent ...
        ({"decision": "SANITIZED", "outputText": "O"}, "O"),
        # ... or is not a string.
        ({"decision": "SANITIZED", "sanitizedText": 123, "outputText": "O"}, "O"),
        # An empty sanitizedText is still honored.
        ({"decision": "SANITIZED", "sanitizedText": ""}, ""),
        # With neither field the original text is kept.
        ({"decision": "SANITIZED"}, "orig"),
    ],
)
async def test_sanitized_precedence(body, expected):
    """The replacement text for a SANITIZED decision follows the parametrized precedence."""
    guardrail = _make_guardrail(FakeHandler([_resp(body)]))

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["orig"]},
        request_data={},
        input_type="request",
    )

    assert result["texts"] == [expected]
async def test_blocked_raises_guardrail_exception_with_400():
    """A BLOCKED decision raises GuardrailRaisedException (400) carrying blockMessage."""
    blocked_response = _resp({"decision": "BLOCKED", "blockMessage": "nope"})
    guardrail = _make_guardrail(FakeHandler([blocked_response]))

    with pytest.raises(GuardrailRaisedException) as exc_info:
        await guardrail.apply_guardrail(
            inputs={"texts": ["bad"]},
            request_data={},
            input_type="request",
        )

    raised = exc_info.value
    assert raised.status_code == 400
    assert raised.guardrail_name == "vigil-guard"
    assert raised.message == "nope"
@pytest.mark.parametrize(
    "body,expected",
    [
        # blockMessage wins over every other field.
        (
            {
                "decision": "BLOCKED",
                "blockMessage": "bm",
                "decisionReason": "dr",
                "categories": ["c1"],
            },
            "bm",
        ),
        # A whitespace-only blockMessage falls through to decisionReason.
        ({"decision": "BLOCKED", "blockMessage": " ", "decisionReason": "dr"}, "dr"),
        # decisionReason wins over categories.
        (
            {"decision": "BLOCKED", "decisionReason": "dr", "categories": ["c1", "c2"]},
            "dr",
        ),
        # Categories are joined with ", " when nothing else is present.
        ({"decision": "BLOCKED", "categories": ["c1", "c2"]}, "c1, c2"),
        # Final fallback message.
        ({"decision": "BLOCKED"}, "Blocked by policy"),
    ],
)
async def test_block_reason_precedence(body, expected):
    """The message on the raised exception follows the parametrized precedence."""
    guardrail = _make_guardrail(FakeHandler([_resp(body)]))

    with pytest.raises(GuardrailRaisedException) as exc_info:
        await guardrail.apply_guardrail(
            inputs={"texts": ["x"]},
            request_data={},
            input_type="request",
        )

    assert exc_info.value.message == expected
async def test_block_reason_is_clamped_to_500_chars():
    """A 600-character blockMessage surfaces with at most 500 consecutive characters."""
    oversized = _resp({"decision": "BLOCKED", "blockMessage": "x" * 600})
    guardrail = _make_guardrail(FakeHandler([oversized]))

    with pytest.raises(GuardrailRaisedException) as exc_info:
        await guardrail.apply_guardrail(
            inputs={"texts": ["x"]},
            request_data={},
            input_type="request",
        )

    message = exc_info.value.message
    assert "x" * 500 in message
    assert "x" * 501 not in message
async def test_empty_and_whitespace_texts_skip_analyze():
    """Empty and whitespace-only texts are returned as-is and never sent for analysis."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["", " ", "real"]},
        request_data={},
        input_type="request",
    )

    assert result["texts"] == ["", " ", "real"]
    # Only the non-blank entry produced a backend call.
    assert len(fake_handler.calls) == 1
    assert fake_handler.calls[0].json["text"] == "real"
async def test_no_scannable_text_returns_inputs_unchanged():
    """With only blank texts the very same inputs object is returned and no call is made."""
    fake_handler = FakeHandler([])
    guardrail = _make_guardrail(fake_handler)
    payload = {"texts": ["", " "], "structured_messages": [{"role": "user"}]}

    result = await guardrail.apply_guardrail(
        inputs=payload, request_data={}, input_type="request"
    )

    assert result is payload
    assert len(fake_handler.calls) == 0
async def test_multi_text_preserves_length_and_order():
    """Each text gets its own call and results keep their original positions."""
    responses = [
        _resp({"decision": "ALLOWED"}),
        _resp({"decision": "SANITIZED", "sanitizedText": "B-clean"}),
        _resp({"decision": "ALLOWED"}),
    ]
    fake_handler = FakeHandler(responses)
    guardrail = _make_guardrail(fake_handler)

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["A", "B", "C"]},
        request_data={},
        input_type="request",
    )

    assert result["texts"] == ["A", "B-clean", "C"]
    assert len(fake_handler.calls) == 3
async def test_one_blocked_text_blocks_the_whole_call():
    """A BLOCKED decision on any text raises even if an earlier text was allowed."""
    responses = [
        _resp({"decision": "ALLOWED"}),
        _resp({"decision": "BLOCKED", "blockMessage": "bad second"}),
    ]
    guardrail = _make_guardrail(FakeHandler(responses))

    with pytest.raises(GuardrailRaisedException):
        await guardrail.apply_guardrail(
            inputs={"texts": ["ok", "bad"]},
            request_data={},
            input_type="request",
        )
async def test_request_source_is_user_input():
    """Request-side scans are sent with source "user_input"."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="request",
    )

    sent_body = fake_handler.calls[0].json
    assert sent_body["source"] == "user_input"
async def test_response_source_is_model_output():
    """Response-side scans are sent with source "model_output"."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="response",
    )

    sent_body = fake_handler.calls[0].json
    assert sent_body["source"] == "model_output"
async def test_sanitized_returns_canonical_shape_and_logs_mask():
    """A SANITIZED result keeps only texts/images/tools and is logged as "mask"."""
    sanitized_response = _resp(
        {"decision": "SANITIZED", "sanitizedText": "[REDACTED]"}
    )
    guardrail = _make_guardrail(FakeHandler([sanitized_response]))
    tools = [{"type": "function", "function": {"name": "f"}}]
    payload = {
        "texts": ["my ssn is 123"],
        "images": ["img1"],
        "tools": tools,
        "tool_calls": [{"id": "1"}],
        "structured_messages": [{"role": "user", "content": "my ssn is 123"}],
        "model": "gpt-4o",
    }
    request_data = {"metadata": {}}

    result = await guardrail.apply_guardrail(
        inputs=payload,
        request_data=request_data,
        input_type="request",
    )

    assert result["texts"] == ["[REDACTED]"]
    assert result["images"] == ["img1"]
    assert result["tools"] == tools
    # tool_calls, structured_messages and model are not part of the result.
    assert set(result.keys()) == {"texts", "images", "tools"}
    logged = request_data["metadata"]["standard_logging_guardrail_information"]
    assert logged[0]["guardrail_response"] == "mask"
async def test_empty_images_and_tools_are_preserved_when_present():
    """Empty "images" and "tools" lists supplied by the caller survive in the result."""
    guardrail = _make_guardrail(FakeHandler([_resp({"decision": "ALLOWED"})]))
    payload = {"texts": ["x"], "images": [], "tools": []}

    result = await guardrail.apply_guardrail(
        inputs=payload,
        request_data={},
        input_type="request",
    )

    assert set(result.keys()) == {"texts", "images", "tools"}
    assert result["images"] == []
    assert result["tools"] == []
async def test_logging_obj_none_supported():
    """apply_guardrail accepts logging_obj=None and still returns the texts."""
    guardrail = _make_guardrail(FakeHandler([_resp({"decision": "ALLOWED"})]))

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="request",
        logging_obj=None,
    )

    assert result["texts"] == ["x"]
async def test_standard_guardrail_logging_remains_active():
    """One standard logging entry is recorded with the guardrail name and success status."""
    guardrail = _make_guardrail(FakeHandler([_resp({"decision": "ALLOWED"})]))
    request_data = {"metadata": {}}

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data=request_data,
        input_type="request",
    )

    logged = request_data["metadata"]["standard_logging_guardrail_information"]
    assert len(logged) == 1
    entry = logged[0]
    assert entry["guardrail_name"] == "vigil-guard"
    assert entry["guardrail_status"] == "success"
async def test_request_url_headers_and_body():
    """The outgoing call targets /v1/guard/analyze with bearer auth and the expected body."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(
        fake_handler, api_base="https://vigil.test", api_key="vg_secret"
    )

    await guardrail.apply_guardrail(
        inputs={"texts": ["hello"]},
        request_data={},
        input_type="request",
    )

    sent = fake_handler.calls[0]
    assert sent.url == "https://vigil.test/v1/guard/analyze"
    assert sent.headers["Authorization"] == "Bearer vg_secret"
    assert sent.headers["Content-Type"] == "application/json"
    body = sent.json
    assert body["text"] == "hello"
    assert body["mode"] == "full"
    assert set(body.keys()) == {"text", "source", "mode", "metadata"}
    assert "metadata" in body
async def test_default_timeout_forwarded_when_unset():
    """Without an explicit timeout the default is stored and passed to the handler."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)
    assert guardrail.timeout == _DEFAULT_VIGIL_TIMEOUT

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="request",
    )

    assert fake_handler.calls[0].timeout == _DEFAULT_VIGIL_TIMEOUT
async def test_configured_timeout_forwarded_to_handler():
    """timeout=30 becomes httpx.Timeout(30, connect=5.0) and is passed to the handler."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler, timeout=30)
    expected_timeout = httpx.Timeout(30, connect=5.0)
    assert guardrail.timeout == expected_timeout

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="request",
    )

    assert fake_handler.calls[0].timeout == expected_timeout
def test_short_timeout_caps_connect():
    """With timeout=2 the connect timeout is 2.0 as well."""
    guardrail = _make_guardrail(FakeHandler([]), timeout=2)

    assert guardrail.timeout == httpx.Timeout(2, connect=2.0)
def test_initialize_guardrail_forwards_timeout():
    """A timeout given as the string "30" in LitellmParams reaches the callback."""
    params = LitellmParams(
        guardrail="vigil_guard",
        mode="pre_call",
        api_base="https://vigil.test",
        api_key="k",
        timeout="30",
    )

    callback = initialize_guardrail(params, {"guardrail_name": "vg"})

    assert callback.timeout == httpx.Timeout(30, connect=5.0)
async def test_api_key_only_in_header_never_in_payload():
    """The API key appears in the Authorization header but nowhere in the JSON body."""
    secret = "super_secret_key"
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler, api_key=secret)

    await guardrail.apply_guardrail(
        inputs={"texts": ["hello"]},
        request_data={"metadata": {"user_id": "u"}},
        input_type="request",
    )

    sent = fake_handler.calls[0]
    assert "super_secret_key" not in json.dumps(sent.json)
    assert sent.headers["Authorization"] == "Bearer super_secret_key"
@pytest.mark.parametrize("code", [429, 502, 503, 504])
async def test_retry_once_on_transient_status(code):
    """A 429/502/503/504 response is retried once and the second response is used."""
    responses = [_resp({}, status_code=code), _resp({"decision": "ALLOWED"})]
    fake_handler = FakeHandler(responses)
    guardrail = _make_guardrail(fake_handler)

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="request",
    )

    assert result["texts"] == ["x"]
    assert len(fake_handler.calls) == 2
@pytest.mark.parametrize("exc", _transient_exceptions())
async def test_retry_once_on_transient_exception(exc):
    """Each transient exception is retried once and the second response is used."""
    fake_handler = FakeHandler([exc, _resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="request",
    )

    assert result["texts"] == ["x"]
    assert len(fake_handler.calls) == 2
@pytest.mark.parametrize(
    "exc, expected",
    [
        # A generic error propagates unchanged.
        (RuntimeError("boom"), RuntimeError),
        # An httpx write error surfaces as a guardrail exception.
        (
            httpx.WriteError("boom", request=httpx.Request("POST", _ENDPOINT)),
            GuardrailRaisedException,
        ),
    ],
)
async def test_no_retry_on_non_transient_exception(exc, expected):
    """Non-transient exceptions raise the expected type after a single attempt."""
    fake_handler = FakeHandler([exc])
    guardrail = _make_guardrail(fake_handler)

    with pytest.raises(expected):
        await guardrail.apply_guardrail(
            inputs={"texts": ["x"]},
            request_data={},
            input_type="request",
        )

    assert len(fake_handler.calls) == 1
@pytest.mark.parametrize("code", [400, 401, 403, 404, 422])
async def test_no_retry_on_non_429_4xx(code):
    """4xx responses other than 429 raise a 400 guardrail exception without a retry."""
    fake_handler = FakeHandler([_resp({}, status_code=code)])
    guardrail = _make_guardrail(fake_handler)

    with pytest.raises(GuardrailRaisedException) as exc_info:
        await guardrail.apply_guardrail(
            inputs={"texts": ["x"]},
            request_data={},
            input_type="request",
        )

    assert exc_info.value.status_code == 400
    assert len(fake_handler.calls) == 1
async def test_fail_closed_raises_after_exhausted_retry(caplog):
    """Two consecutive 503s under the default policy raise and log the failure."""
    responses = [_resp({}, status_code=503), _resp({}, status_code=503)]
    fake_handler = FakeHandler(responses)
    guardrail = _make_guardrail(fake_handler)

    with caplog.at_level(logging.ERROR):
        with pytest.raises(GuardrailRaisedException) as exc_info:
            await guardrail.apply_guardrail(
                inputs={"texts": ["x"]},
                request_data={},
                input_type="request",
            )

    assert exc_info.value.status_code == 400
    assert len(fake_handler.calls) == 2
    logged_messages = [record.message for record in caplog.records]
    assert any("fail_closed" in message for message in logged_messages)
    assert any("vigil-guard" in message for message in logged_messages)
@pytest.mark.parametrize("exc", _transient_exceptions())
async def test_fail_closed_raises_controlled_block_on_transport_error(exc, caplog):
    """A transient error on both attempts raises a 400 block chained to that error."""
    guardrail = _make_guardrail(FakeHandler([exc, exc]))

    with caplog.at_level(logging.ERROR):
        with pytest.raises(GuardrailRaisedException) as exc_info:
            await guardrail.apply_guardrail(
                inputs={"texts": ["x"]},
                request_data={},
                input_type="request",
            )

    raised = exc_info.value
    assert raised.status_code == 400
    assert raised.guardrail_name == "vigil-guard"
    # The transport error is kept as the explicit cause of the block.
    assert raised.__cause__ is exc
    assert any("fail_closed" in record.message for record in caplog.records)
async def test_fail_open_returns_inputs_unchanged_on_backend_error(caplog):
    """Under fail_open, two 503s return an equal copy of the inputs and log "allow"."""
    responses = [_resp({}, status_code=503), _resp({}, status_code=503)]
    fake_handler = FakeHandler(responses)
    guardrail = _make_guardrail(fake_handler, unreachable_fallback="fail_open")
    messages = [{"role": "user", "content": "x"}]
    payload = {"texts": ["x"], "structured_messages": messages}
    request_data = {"metadata": {}}

    with caplog.at_level(logging.ERROR):
        result = await guardrail.apply_guardrail(
            inputs=payload,
            request_data=request_data,
            input_type="request",
        )

    assert result is not payload
    assert result["texts"] == ["x"]
    assert result["structured_messages"] == messages
    assert len(fake_handler.calls) == 2
    logged_messages = [record.message for record in caplog.records]
    assert any("fail_open" in message for message in logged_messages)
    assert any("vigil-guard" in message for message in logged_messages)
    logged = request_data["metadata"]["standard_logging_guardrail_information"]
    assert logged[0]["guardrail_response"] == "allow"
@pytest.mark.parametrize("exc", [ssl.SSLError("tls failed"), OSError("network down")])
async def test_fail_open_returns_inputs_unchanged_on_transport_error(exc):
    """Under fail_open, an SSL/OS error passes the texts through after one attempt."""
    fake_handler = FakeHandler([exc])
    guardrail = _make_guardrail(fake_handler, unreachable_fallback="fail_open")
    payload = {"texts": ["x"]}

    result = await guardrail.apply_guardrail(
        inputs=payload, request_data={}, input_type="request"
    )

    assert result is not payload
    assert result["texts"] == ["x"]
    assert len(fake_handler.calls) == 1
@pytest.mark.parametrize(
    "exc",
    [TypeError("bug"), KeyError("bug"), AttributeError("bug")],
)
async def test_fail_open_does_not_swallow_programming_errors(exc):
    """fail_open still lets TypeError/KeyError/AttributeError propagate."""
    fake_handler = FakeHandler([exc])
    guardrail = _make_guardrail(fake_handler, unreachable_fallback="fail_open")

    with pytest.raises(type(exc)):
        await guardrail.apply_guardrail(
            inputs={"texts": ["x"]},
            request_data={},
            input_type="request",
        )

    assert len(fake_handler.calls) == 1
async def test_invalid_decision_fail_closed_raises(caplog):
    """An unknown decision raises a 400 whose message hides the raw value but logs it."""
    guardrail = _make_guardrail(FakeHandler([_resp({"decision": "MAYBE"})]))

    with caplog.at_level(logging.ERROR):
        with pytest.raises(GuardrailRaisedException) as exc_info:
            await guardrail.apply_guardrail(
                inputs={"texts": ["x"]},
                request_data={},
                input_type="request",
            )

    raised = exc_info.value
    assert raised.status_code == 400
    # The raw decision must not reach the caller-facing message ...
    assert "MAYBE" not in raised.message
    # ... but it does have to show up in the error log.
    assert any("MAYBE" in record.message for record in caplog.records)
async def test_invalid_decision_fail_open_returns_inputs():
    """Under fail_open, an unknown decision leaves the texts as they were."""
    guardrail = _make_guardrail(
        FakeHandler([_resp({"decision": "MAYBE"})]),
        unreachable_fallback="fail_open",
    )

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="request",
    )

    assert result["texts"] == ["x"]
async def test_fail_open_multi_text_preserves_earlier_sanitization():
    """Under fail_open, a later backend failure keeps an earlier sanitized text."""
    responses = [
        _resp({"decision": "SANITIZED", "sanitizedText": "[REDACTED]"}),
        _resp({}, status_code=503),
        _resp({}, status_code=503),
    ]
    fake_handler = FakeHandler(responses)
    guardrail = _make_guardrail(fake_handler, unreachable_fallback="fail_open")
    request_data = {"metadata": {}}

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["my ssn is 123", "second"]},
        request_data=request_data,
        input_type="request",
    )

    assert result["texts"] == ["[REDACTED]", "second"]
    # One call for the first text, then the original attempt plus a retry.
    assert len(fake_handler.calls) == 3
    logged = request_data["metadata"]["standard_logging_guardrail_information"]
    assert logged[0]["guardrail_response"] == "mask"
def _tool_call(arguments, name="f", tc_id="1"):
return {
"id": tc_id,
"type": "function",
"function": {"name": name, "arguments": arguments},
}
async def test_response_tool_call_arguments_allowed_unchanged():
    """On responses, tool-call arguments are scanned and returned unchanged when allowed."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)
    tool_calls = [_tool_call('{"q": "weather"}')]

    result = await guardrail.apply_guardrail(
        inputs={"texts": [], "tool_calls": tool_calls},
        request_data={},
        input_type="response",
    )

    sent_body = fake_handler.calls[0].json
    assert sent_body["text"] == '{"q": "weather"}'
    assert sent_body["source"] == "model_output"
    assert result["tool_calls"] == tool_calls
async def test_response_tool_call_arguments_sanitized_in_place():
    """Sanitized tool-call arguments replace the originals without mutating the inputs."""
    sanitized_response = _resp(
        {"decision": "SANITIZED", "sanitizedText": '{"email": "[EMAIL]"}'}
    )
    guardrail = _make_guardrail(FakeHandler([sanitized_response]))
    tool_calls = [_tool_call('{"email": "john@example.com"}', name="send_mail")]
    payload = {"texts": [], "tool_calls": tool_calls}

    result = await guardrail.apply_guardrail(
        inputs=payload, request_data={}, input_type="response"
    )

    returned_function = result["tool_calls"][0]["function"]
    assert returned_function["arguments"] == '{"email": "[EMAIL]"}'
    assert returned_function["name"] == "send_mail"
    # original inputs are not mutated in place
    assert payload["tool_calls"][0]["function"]["arguments"] == (
        '{"email": "john@example.com"}'
    )
async def test_response_tool_call_arguments_blocked_raises():
    """A BLOCKED decision on tool-call arguments raises a 400 with the block message."""
    blocked_response = _resp({"decision": "BLOCKED", "blockMessage": "tool blocked"})
    guardrail = _make_guardrail(FakeHandler([blocked_response]))
    tool_calls = [_tool_call('{"x": "bad"}')]

    with pytest.raises(GuardrailRaisedException) as exc_info:
        await guardrail.apply_guardrail(
            inputs={"texts": [], "tool_calls": tool_calls},
            request_data={},
            input_type="response",
        )

    raised = exc_info.value
    assert raised.status_code == 400
    assert raised.message == "tool blocked"
async def test_request_tool_calls_are_not_scanned():
    """On requests only the texts are scanned; tool calls produce no backend call."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)
    tool_calls = [_tool_call('{"x": "y"}')]

    await guardrail.apply_guardrail(
        inputs={"texts": ["hello"], "tool_calls": tool_calls},
        request_data={},
        input_type="request",
    )

    assert len(fake_handler.calls) == 1
    assert fake_handler.calls[0].json["text"] == "hello"
async def test_tool_call_scan_backend_failure_fail_closed_raises():
    """Two 503s while scanning tool-call arguments raise a 400 under the default policy."""
    responses = [_resp({}, status_code=503), _resp({}, status_code=503)]
    fake_handler = FakeHandler(responses)
    guardrail = _make_guardrail(fake_handler)
    tool_calls = [_tool_call('{"x": "y"}')]

    with pytest.raises(GuardrailRaisedException) as exc_info:
        await guardrail.apply_guardrail(
            inputs={"texts": [], "tool_calls": tool_calls},
            request_data={},
            input_type="response",
        )

    assert exc_info.value.status_code == 400
    assert len(fake_handler.calls) == 2
async def test_tool_call_scan_backend_failure_fail_open_passes_through():
    """Under fail_open, two 503s return the tool calls unchanged."""
    responses = [_resp({}, status_code=503), _resp({}, status_code=503)]
    guardrail = _make_guardrail(
        FakeHandler(responses), unreachable_fallback="fail_open"
    )
    tool_calls = [_tool_call('{"x": "y"}')]

    result = await guardrail.apply_guardrail(
        inputs={"texts": [], "tool_calls": tool_calls},
        request_data={},
        input_type="response",
    )

    assert result["tool_calls"] == tool_calls
async def test_response_tool_call_unrecognized_decision_fail_closed_raises():
    """An unknown decision for tool-call arguments raises a 400 under the default policy."""
    guardrail = _make_guardrail(FakeHandler([_resp({"decision": "MAYBE"})]))
    tool_calls = [_tool_call('{"x": "y"}')]

    with pytest.raises(GuardrailRaisedException) as exc_info:
        await guardrail.apply_guardrail(
            inputs={"texts": [], "tool_calls": tool_calls},
            request_data={},
            input_type="response",
        )

    assert exc_info.value.status_code == 400
async def test_response_tool_call_unrecognized_decision_fail_open_passes_through():
    """Under fail_open, an unknown decision returns the tool calls unchanged."""
    guardrail = _make_guardrail(
        FakeHandler([_resp({"decision": "MAYBE"})]),
        unreachable_fallback="fail_open",
    )
    tool_calls = [_tool_call('{"x": "y"}')]

    result = await guardrail.apply_guardrail(
        inputs={"texts": [], "tool_calls": tool_calls},
        request_data={},
        input_type="response",
    )

    assert result["tool_calls"] == tool_calls
async def test_metadata_allowlist_and_clamping():
    """Forwarded metadata is filtered to known keys and clamped in size."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)
    request_metadata = {
        "user_id": "u1",
        "tenant_id": "t1",
        "secret_unlisted": "should_not_forward",
        "session_id": "s" * 600,
        "org_id": ["a"] * 20,
        "request_id": True,
        "conversation_id": 7,
    }
    request_data = {"model": "gpt-4o", "metadata": request_metadata}

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data=request_data,
        input_type="request",
    )

    forwarded = fake_handler.calls[0].json["metadata"]
    assert forwarded["model"] == "gpt-4o"
    assert forwarded["user_id"] == "u1"
    assert forwarded["tenant_id"] == "t1"
    # Keys outside the allowlist are dropped.
    assert "secret_unlisted" not in forwarded
    # Long strings are cut to 500 characters and long lists to 10 items.
    assert len(forwarded["session_id"]) == 500
    assert len(forwarded["org_id"]) == 10
    # A boolean value is not forwarded, while an integer is.
    assert "request_id" not in forwarded
    assert forwarded["conversation_id"] == 7
async def test_metadata_source_precedence_and_litellm_metadata_fallback():
    """Top-level request keys beat nested metadata; litellm_metadata fills in the rest."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)
    request_data = {
        "user_id": "top",
        "metadata": {"user_id": "nested"},
        "litellm_metadata": {"tenant_id": "lm-tenant"},
    }

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data=request_data,
        input_type="request",
    )

    forwarded = fake_handler.calls[0].json["metadata"]
    assert forwarded["user_id"] == "top"
    assert forwarded["tenant_id"] == "lm-tenant"
async def test_metadata_uses_later_source_when_earlier_value_is_unclampable():
    """A dict-valued top-level user_id is skipped in favor of the nested metadata value."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)
    request_data = {
        "user_id": {"drop": "dicts are not forwarded"},
        "metadata": {"user_id": "nested"},
    }

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data=request_data,
        input_type="request",
    )

    forwarded = fake_handler.calls[0].json["metadata"]
    assert forwarded["user_id"] == "nested"
async def test_metadata_array_items_are_clamped_and_filtered():
    """List items are clamped to 500 chars; bool, dict and None items are dropped."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)
    mixed_items = ["z" * 600, 123, True, {"drop": 1}, None]
    request_data = {"metadata": {"org_id": mixed_items}}

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data=request_data,
        input_type="request",
    )

    forwarded = fake_handler.calls[0].json["metadata"]
    assert forwarded["org_id"] == ["z" * 500, 123]
async def test_metadata_array_with_no_supported_items_is_dropped():
    """A list holding only a dict and None results in the key being omitted."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)
    request_data = {"metadata": {"org_id": [{"drop": 1}, None]}}

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data=request_data,
        input_type="request",
    )

    forwarded = fake_handler.calls[0].json["metadata"]
    assert "org_id" not in forwarded
async def test_call_id_forwarded_from_logging_obj():
    """litellm_call_id on the logging object is forwarded in the metadata."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)
    logging_stub = SimpleNamespace(litellm_call_id="call-123")

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="request",
        logging_obj=logging_stub,
    )

    forwarded = fake_handler.calls[0].json["metadata"]
    assert forwarded["litellm_call_id"] == "call-123"
async def test_call_id_forwarded_from_request_data():
    """Without a logging object, a top-level litellm_call_id is forwarded."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={"litellm_call_id": "rd-1"},
        input_type="request",
        logging_obj=None,
    )

    forwarded = fake_handler.calls[0].json["metadata"]
    assert forwarded["litellm_call_id"] == "rd-1"
async def test_call_id_forwarded_from_request_metadata():
    """Without a logging object, litellm_call_id nested in metadata is forwarded."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={"metadata": {"litellm_call_id": "md-1"}},
        input_type="request",
        logging_obj=None,
    )

    forwarded = fake_handler.calls[0].json["metadata"]
    assert forwarded["litellm_call_id"] == "md-1"
async def test_call_id_logging_obj_takes_precedence():
    """The logging object's call id wins over the one in request_data."""
    fake_handler = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake_handler)
    logging_stub = SimpleNamespace(litellm_call_id="log-1")

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={"litellm_call_id": "rd-1"},
        input_type="request",
        logging_obj=logging_stub,
    )

    forwarded = fake_handler.calls[0].json["metadata"]
    assert forwarded["litellm_call_id"] == "log-1"
def test_enum_value():
    """The VIGIL_GUARD enum member has the string value "vigil_guard"."""
    member = SupportedGuardrailIntegrations.VIGIL_GUARD
    assert member.value == "vigil_guard"
def test_config_model_ui_name_and_instantiation():
    """The config model reports its UI name and can be built from api_base/api_key."""
    assert VigilGuardGuardrailConfigModel.ui_friendly_name() == "Vigil Guard"

    config = VigilGuardGuardrailConfigModel(api_base="https://x", api_key="k")

    assert config.api_base == "https://x"
def test_get_config_model_returns_config_model():
    """get_config_model() returns the VigilGuardGuardrailConfigModel class itself."""
    guardrail = _make_guardrail(FakeHandler([]))

    assert guardrail.get_config_model() is VigilGuardGuardrailConfigModel
def test_registries_expose_initializer_and_class():
    """Both guardrail registries contain a "vigil_guard" entry."""
    registry_key = "vigil_guard"

    assert registry_key in guardrail_initializer_registry
    assert guardrail_class_registry[registry_key] is VigilGuardGuardrail
def test_litellm_params_includes_config_model():
    """LitellmParams has the Vigil Guard config model in its MRO."""
    ancestors = LitellmParams.__mro__

    assert VigilGuardGuardrailConfigModel in ancestors
def test_config_driven_initialization_creates_callback():
    """initialize_guardrail builds a VigilGuardGuardrail with the fail_closed default."""
    params = LitellmParams(
        guardrail="vigil_guard",
        mode="pre_call",
        api_base="https://vigil.test",
        api_key="k",
    )

    callback = initialize_guardrail(params, {"guardrail_name": "vg"})

    assert isinstance(callback, VigilGuardGuardrail)
    assert callback.unreachable_fallback == "fail_closed"

View File

@ -0,0 +1,213 @@
import os
from unittest.mock import patch
import pytest
class TestContentFilterPathTraversal:
    """Tests that _resolve_category_file_path rejects path traversal."""

    # Dotted path of the module under test; used to locate its directory on disk.
    _MODULE_PATH = (
        "litellm.proxy.guardrails.guardrail_hooks"
        ".litellm_content_filter.content_filter"
    )

    @staticmethod
    def _guardrail_cls():
        """Import and return ContentFilterGuardrail lazily, at call time."""
        from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import (
            ContentFilterGuardrail,
        )

        return ContentFilterGuardrail

    def _get_guardrail(self):
        """Return a bare ContentFilterGuardrail instance, bypassing ``__init__``."""
        guardrail_cls = self._guardrail_cls()
        return guardrail_cls.__new__(guardrail_cls)

    def _get_loadable_guardrail(self):
        """Return a bare guardrail with the attributes the loader tests pre-set."""
        guardrail = self._get_guardrail()
        guardrail.loaded_categories = {}
        guardrail.severity_threshold = "medium"
        guardrail.category_keywords = {}
        guardrail.always_block_category_keywords = {}
        guardrail.conditional_categories = {}
        return guardrail

    @classmethod
    def _categories_dir(cls):
        """Return the ``categories`` directory next to the content_filter module."""
        import importlib

        module = importlib.import_module(cls._MODULE_PATH)
        return os.path.join(os.path.dirname(module.__file__), "categories")

    @classmethod
    def _first_category_yaml(cls):
        """Return the alphabetically first bundled YAML file name, or skip the test.

        Sorting makes the chosen fixture independent of ``os.listdir`` ordering.
        """
        yaml_files = sorted(
            name for name in os.listdir(cls._categories_dir()) if name.endswith(".yaml")
        )
        if not yaml_files:
            pytest.skip("No category YAML files present in this environment")
        return yaml_files[0]

    def test_traversal_via_relative_dotdot_raises(self):
        """A relative ``..`` chain is rejected with ValueError."""
        guardrail = self._get_guardrail()
        with pytest.raises(ValueError, match="outside the allowed categories"):
            guardrail._resolve_category_file_path("../../../../etc/passwd")

    def test_traversal_via_absolute_path_raises(self):
        """An absolute path outside the categories directory is rejected."""
        guardrail = self._get_guardrail()
        with pytest.raises(ValueError, match="outside the allowed categories"):
            guardrail._resolve_category_file_path("/etc/passwd")

    def test_valid_category_file_inside_categories_dir_allowed(self):
        """An absolute path to a bundled category file is returned unchanged."""
        guardrail = self._get_guardrail()
        valid_file = os.path.join(self._categories_dir(), "harmful_self_harm.yaml")
        if not os.path.exists(valid_file):
            pytest.skip("harmful_self_harm.yaml not present in this environment")
        result = guardrail._resolve_category_file_path(valid_file)
        assert result == valid_file

    def test_invalid_category_name_skipped(self):
        """A category name made of traversal characters is skipped, not loaded."""
        guardrail = self._get_loadable_guardrail()
        # category name with path traversal chars must be skipped, not crash
        guardrail._load_categories([{"category": "../../etc/passwd", "enabled": True}])
        assert "../../etc/passwd" not in guardrail.loaded_categories

    def test_category_name_with_slash_skipped(self):
        """A category name containing slashes and ``..`` is skipped, not loaded."""
        guardrail = self._get_loadable_guardrail()
        guardrail._load_categories(
            [{"category": "foo/../../etc/passwd", "enabled": True}]
        )
        assert "foo/../../etc/passwd" not in guardrail.loaded_categories

    def test_assert_within_categories_dir_blocks_parent_traversal(self):
        """_assert_within_categories_dir rejects a path outside the real directory."""
        guardrail_cls = self._guardrail_cls()
        with pytest.raises(ValueError, match="outside the allowed categories"):
            guardrail_cls._assert_within_categories_dir(
                "/etc/passwd", self._categories_dir()
            )

    def test_assert_within_categories_dir_allows_valid_file(self, tmp_path):
        """_assert_within_categories_dir accepts a path inside the given directory."""
        guardrail_cls = self._guardrail_cls()
        categories_dir = str(tmp_path)
        valid_file = str(tmp_path / "test.yaml")
        # Should not raise
        guardrail_cls._assert_within_categories_dir(valid_file, categories_dir)

    def test_assert_within_categories_dir_commonpath_raises_valueerror(self, tmp_path):
        """Cover the except-ValueError branch (Windows cross-drive paths)."""
        guardrail_cls = self._guardrail_cls()
        categories_dir = str(tmp_path)
        valid_file = str(tmp_path / "test.yaml")
        with patch(
            "os.path.commonpath", side_effect=ValueError("Paths on different drives")
        ):
            with pytest.raises(
                ValueError, match="outside the allowed categories directory"
            ):
                guardrail_cls._assert_within_categories_dir(valid_file, categories_dir)

    def test_resolve_category_file_path_direct_join_hit(self):
        """Cover the first-join-attempt success branch."""
        guardrail = self._get_guardrail()
        # "categories/<file>" joined directly to module_dir resolves to an existing file.
        relative_path = os.path.join("categories", self._first_category_yaml())
        result = guardrail._resolve_category_file_path(relative_path)
        assert os.path.isabs(result) or os.path.exists(result)

    def test_resolve_category_file_path_component_strip_hit(self):
        """Cover the component-stripping loop success branch."""
        guardrail = self._get_guardrail()
        # Prefix with a fake leading component so the first-join attempt misses,
        # but stripping that component reveals categories/<file> which exists.
        prefixed_path = "some_prefix/categories/" + self._first_category_yaml()
        result = guardrail._resolve_category_file_path(prefixed_path)
        assert os.path.isabs(result) or os.path.exists(result)

    def test_load_categories_traversal_category_file_skipped(self):
        """Cover the except-ValueError branch in _load_categories."""
        guardrail = self._get_loadable_guardrail()
        # A traversal path in category_file must be skipped (not crash) via ValueError.
        guardrail._load_categories(
            [
                {
                    "category": "valid_name",
                    "enabled": True,
                    "category_file": "../../../../etc/passwd",
                }
            ]
        )
        assert "valid_name" not in guardrail.loaded_categories

    def test_allow_external_paths_env_var_bypasses_jail(self, tmp_path):
        """LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS=true skips the directory jail."""
        guardrail = self._get_guardrail()
        # Create a real file outside the module directory (simulates mounted volume).
        external_file = tmp_path / "external_categories.yaml"
        external_file.write_text("category_name: test\n")
        with patch.dict(
            os.environ, {"LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS": "true"}
        ):
            # Should return the path without raising ValueError.
            result = guardrail._resolve_category_file_path(str(external_file))
        assert result == str(external_file)

    def test_traversal_blocked_when_allow_external_not_set(self):
        """Without the env var the jail still blocks traversal paths."""
        guardrail = self._get_guardrail()
        # patch.dict restores the environment on exit, so the pop below is temporary.
        with patch.dict(os.environ, {}, clear=False):
            os.environ.pop("LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS", None)
            with pytest.raises(ValueError, match="outside the allowed categories"):
                guardrail._resolve_category_file_path("/etc/passwd")

View File

@ -0,0 +1,444 @@
"""
Unit tests for watsonx_proxy_route endpoint.
Tests the Watsonx pass-through endpoint that handles automatic IAM token management
and version parameter injection.
"""
import json
import os
import sys
from unittest.mock import AsyncMock, MagicMock, Mock, patch
import pytest
from fastapi import HTTPException, Request, Response
sys.path.insert(
0, os.path.abspath("../../../..")
) # Adds the parent directory to the system path
import litellm
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
watsonx_proxy_route,
)
class TestWatsonxProxyRoute:
    """Tests for the Watsonx pass-through route.

    Each test calls ``watsonx_proxy_route`` with
    ``ProviderConfigManager.get_provider_passthrough_config`` and
    ``create_pass_through_route`` patched inside the
    ``llm_passthrough_endpoints`` module, then inspects the keyword arguments
    the route handed to ``create_pass_through_route``.
    """

    # Module whose attributes are patched by every test.
    _MODULE = "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints"
    # Host used for the fake upstream target URLs.
    _BASE_URL = "https://us-south.ml.cloud.ibm.com"

    @staticmethod
    def _make_request(method="POST", query_params=None, headers=None, json_body=None):
        """Build a ``Request`` mock.

        ``headers`` defaults to a JSON content type; pass ``{}`` explicitly for
        a header-less request. ``request.json`` is only replaced when
        ``json_body`` is given, so tests that never set it keep the spec'd mock.
        """
        request = MagicMock(spec=Request)
        request.method = method
        request.query_params = {} if query_params is None else query_params
        request.headers = (
            {"content-type": "application/json"} if headers is None else headers
        )
        if json_body is not None:
            request.json = AsyncMock(return_value=json_body)
        return request

    @classmethod
    def _make_provider_config(cls, endpoint, extra_headers=None):
        """Build a provider-config mock targeting ``_BASE_URL/<endpoint>``.

        ``validate_environment`` returns the bearer header plus any
        ``extra_headers``.
        """
        provider_config = MagicMock()
        provider_config.get_complete_url.return_value = (
            f"{cls._BASE_URL}/{endpoint}",
            {},
        )
        env_headers = {"Authorization": "Bearer test-iam-token"}
        if extra_headers:
            env_headers.update(extra_headers)
        provider_config.validate_environment.return_value = env_headers
        return provider_config

    @classmethod
    def _patch_provider_lookup(cls, provider_config):
        """Patch the provider-config lookup to return ``provider_config``."""
        return patch(
            f"{cls._MODULE}.ProviderConfigManager.get_provider_passthrough_config",
            return_value=provider_config,
        )

    @classmethod
    def _patch_route_factory(cls, endpoint_func):
        """Patch ``create_pass_through_route`` to return ``endpoint_func``."""
        return patch(
            f"{cls._MODULE}.create_pass_through_route",
            return_value=endpoint_func,
        )

    async def _run_route(
        self,
        endpoint,
        request,
        provider_config,
        endpoint_func,
        response=None,
        user_api_key_dict=None,
    ):
        """Call the route with both collaborators patched.

        Returns ``(result, create_route_mock)``; the mock keeps its recorded
        calls after the patches are undone.
        """
        if response is None:
            response = MagicMock(spec=Response)
        if user_api_key_dict is None:
            user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
        with (
            self._patch_provider_lookup(provider_config),
            self._patch_route_factory(endpoint_func) as create_route,
        ):
            result = await watsonx_proxy_route(
                endpoint=endpoint,
                request=request,
                fastapi_response=response,
                user_api_key_dict=user_api_key_dict,
            )
        return result, create_route

    @pytest.mark.asyncio
    async def test_watsonx_proxy_route_success_non_streaming(self):
        """Test successful non-streaming request through Watsonx proxy route."""
        endpoint = "ml/v1/text/generation"
        request = self._make_request(json_body={"stream": False, "input": "test"})
        response = MagicMock(spec=Response)
        user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
        provider_config = self._make_provider_config(endpoint)
        endpoint_func = AsyncMock(
            return_value={"model_id": "ibm/granite-13b-chat-v2", "results": []}
        )

        result, create_route = await self._run_route(
            endpoint,
            request,
            provider_config,
            endpoint_func,
            response=response,
            user_api_key_dict=user_api_key_dict,
        )

        # Provider config was consulted exactly once for URL and headers.
        provider_config.get_complete_url.assert_called_once()
        provider_config.validate_environment.assert_called_once()

        # create_pass_through_route received the expected parameters.
        create_route.assert_called_once()
        call_kwargs = create_route.call_args.kwargs
        assert call_kwargs["endpoint"] == "ml/v1/text/generation"
        assert (
            call_kwargs["target"]
            == "https://us-south.ml.cloud.ibm.com/ml/v1/text/generation"
        )
        assert (
            call_kwargs["custom_headers"]["Authorization"] == "Bearer test-iam-token"
        )
        assert call_kwargs["is_streaming_request"] is False
        assert call_kwargs["custom_llm_provider"] == "watsonx"
        assert (
            call_kwargs["query_params"]["version"]
            == litellm.WATSONX_DEFAULT_API_VERSION
        )

        # The generated endpoint function was invoked with the same objects.
        endpoint_func.assert_called_once_with(request, response, user_api_key_dict)
        assert result == {"model_id": "ibm/granite-13b-chat-v2", "results": []}

    @pytest.mark.asyncio
    async def test_watsonx_proxy_route_success_streaming(self):
        """Test successful streaming request through Watsonx proxy route."""
        endpoint = "ml/v1/text/generation_stream"
        request = self._make_request(json_body={"stream": True, "input": "test"})
        provider_config = self._make_provider_config(endpoint)
        endpoint_func = AsyncMock(return_value="streaming_response")

        result, create_route = await self._run_route(
            endpoint, request, provider_config, endpoint_func
        )

        create_route.assert_called_once()
        assert create_route.call_args.kwargs["is_streaming_request"] is True
        assert result == "streaming_response"

    @pytest.mark.asyncio
    async def test_watsonx_proxy_route_get_request(self):
        """Test GET request through Watsonx proxy route."""
        endpoint = "ml/v1/models"
        request = self._make_request(
            method="GET",
            query_params={"project_id": "test-project"},
            headers={},
        )
        provider_config = self._make_provider_config(endpoint)
        endpoint_func = AsyncMock(return_value={"resources": []})

        result, create_route = await self._run_route(
            endpoint, request, provider_config, endpoint_func
        )

        # GET requests are never treated as streaming.
        create_route.assert_called_once()
        assert create_route.call_args.kwargs["is_streaming_request"] is False
        assert result == {"resources": []}

    @pytest.mark.asyncio
    async def test_watsonx_proxy_route_multipart_form_data(self):
        """Test multipart/form-data request through Watsonx proxy route."""
        endpoint = "ml/v1/text/tokenization"
        request = self._make_request(
            headers={"content-type": "multipart/form-data; boundary=----"}
        )
        form_data = {"file": "test_file", "stream": False}
        provider_config = self._make_provider_config(endpoint)
        endpoint_func = AsyncMock(return_value={"token_count": 10})

        with patch(f"{self._MODULE}.get_form_data", return_value=form_data):
            result, create_route = await self._run_route(
                endpoint, request, provider_config, endpoint_func
            )

        # Non-streaming form data must not flip the streaming flag.
        create_route.assert_called_once()
        assert create_route.call_args.kwargs["is_streaming_request"] is False
        assert result == {"token_count": 10}

    @pytest.mark.asyncio
    async def test_watsonx_proxy_route_no_provider_config(self):
        """Test that HTTPException is raised when provider config is not found."""
        request = self._make_request()
        response = MagicMock(spec=Response)
        user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)

        # Only the lookup is patched here, exactly as in the original test.
        with self._patch_provider_lookup(None):
            with pytest.raises(HTTPException) as exc_info:
                await watsonx_proxy_route(
                    endpoint="ml/v1/text/generation",
                    request=request,
                    fastapi_response=response,
                    user_api_key_dict=user_api_key_dict,
                )

        assert exc_info.value.status_code == 404
        assert exc_info.value.detail == "Watsonx passthrough config not found"

    @pytest.mark.asyncio
    async def test_watsonx_proxy_route_version_parameter_injection(self):
        """Test that version parameter is correctly injected into query params."""
        endpoint = "ml/v1/text/generation"
        request = self._make_request(json_body={"input": "test"})
        provider_config = self._make_provider_config(endpoint)
        endpoint_func = AsyncMock(return_value={})

        _, create_route = await self._run_route(
            endpoint, request, provider_config, endpoint_func
        )

        create_route.assert_called_once()
        call_kwargs = create_route.call_args.kwargs
        assert "query_params" in call_kwargs
        assert "version" in call_kwargs["query_params"]
        assert (
            call_kwargs["query_params"]["version"]
            == litellm.WATSONX_DEFAULT_API_VERSION
        )

    @pytest.mark.asyncio
    async def test_watsonx_proxy_route_custom_headers_from_validate_environment(self):
        """Test that custom headers from validate_environment are passed through."""
        endpoint = "ml/v1/text/generation"
        request = self._make_request(json_body={"input": "test"})
        provider_config = self._make_provider_config(
            endpoint, extra_headers={"X-Custom-Header": "custom-value"}
        )
        endpoint_func = AsyncMock(return_value={})

        _, create_route = await self._run_route(
            endpoint, request, provider_config, endpoint_func
        )

        create_route.assert_called_once()
        call_kwargs = create_route.call_args.kwargs
        assert "custom_headers" in call_kwargs
        assert (
            call_kwargs["custom_headers"]["Authorization"] == "Bearer test-iam-token"
        )
        assert call_kwargs["custom_headers"]["X-Custom-Header"] == "custom-value"

    @pytest.mark.asyncio
    @pytest.mark.parametrize(
        "endpoint_path",
        [
            "ml/v1/text/generation",
            "ml/v1/text/tokenization",
            "ml/v1/deployments/test-deployment/text/generation",
            "ml/v1/models",
        ],
    )
    async def test_watsonx_proxy_route_different_endpoints(self, endpoint_path):
        """Test various Watsonx endpoint paths.

        Parametrized so each endpoint is reported as its own test case instead
        of the first failing path masking the remaining ones.
        """
        request = self._make_request(json_body={"input": "test"})
        provider_config = self._make_provider_config(endpoint_path)
        endpoint_func = AsyncMock(return_value={})

        _, create_route = await self._run_route(
            endpoint_path, request, provider_config, endpoint_func
        )

        # Endpoint and resolved target are forwarded unchanged.
        create_route.assert_called_once()
        call_kwargs = create_route.call_args.kwargs
        assert call_kwargs["endpoint"] == endpoint_path
        assert (
            call_kwargs["target"]
            == f"https://us-south.ml.cloud.ibm.com/{endpoint_path}"
        )

View File

@ -3185,3 +3185,358 @@ async def test_view_spend_logs_date_range_hashes_sk_api_key(client, monkeypatch)
assert where["api_key"] == "hashed::sk-raw-admin-token"
finally:
app.dependency_overrides.pop(ps.user_api_key_auth, None)
class _SpendScopeMockPrismaClient:
def __init__(self, get_data_returns=None, find_many_returns=None):
self._get_data_returns = (
get_data_returns if get_data_returns is not None else []
)
self._find_many_returns = (
find_many_returns if find_many_returns is not None else []
)
self.get_data_calls = []
self.find_many_calls = []
client = self
class _VerificationTokenTable:
async def find_many(self, where=None, order=None, include=None):
client.find_many_calls.append(
{"where": where, "order": order, "include": include}
)
return client._find_many_returns
class _DB:
def __init__(self):
self.litellm_verificationtoken = _VerificationTokenTable()
self.db = _DB()
async def get_data(self, table_name=None, query_type=None, **kwargs):
self.get_data_calls.append(
{"table_name": table_name, "query_type": query_type, **kwargs}
)
if query_type == "find_unique":
return self._get_data_returns[0] if self._get_data_returns else None
return self._get_data_returns
@pytest.mark.asyncio
async def test_spend_key_fn_proxy_admin_returns_all_keys(client, monkeypatch):
    """A PROXY_ADMIN caller of /spend/keys still receives the full key list."""
    all_keys = [
        {"token": "hashed-a", "user_id": "alice", "spend": 10.0},
        {"token": "hashed-b", "user_id": "bob", "spend": 5.0},
    ]
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=all_keys)
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _as_admin():
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN, user_id="admin"
        )

    app.dependency_overrides[ps.user_api_key_auth] = _as_admin
    try:
        resp = client.get("/spend/keys", headers={"Authorization": "Bearer sk-test"})
        assert resp.status_code == 200

        # Admin path: one get_data call on the key table, never find_many.
        assert len(prisma_stub.get_data_calls) == 1
        recorded = prisma_stub.get_data_calls[0]
        assert recorded["table_name"] == "key"
        assert recorded["query_type"] == "find_all"
        assert prisma_stub.find_many_calls == []
        assert resp.json() == all_keys
    finally:
        # Always drop the override so it cannot leak into other tests.
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
@pytest.mark.asyncio
async def test_spend_key_fn_proxy_admin_view_only_returns_all_keys(client, monkeypatch):
    """PROXY_ADMIN_VIEW_ONLY takes the same get_data path as a full admin."""
    all_keys = [{"token": "hashed-a", "user_id": "alice"}]
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=all_keys)
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _as_view_only_admin():
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, user_id="admin_viewer"
        )

    app.dependency_overrides[ps.user_api_key_auth] = _as_view_only_admin
    try:
        resp = client.get("/spend/keys", headers={"Authorization": "Bearer sk-test"})
        assert resp.status_code == 200
        assert prisma_stub.find_many_calls == []
        assert len(prisma_stub.get_data_calls) == 1
    finally:
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "role",
    [LitellmUserRoles.INTERNAL_USER, LitellmUserRoles.INTERNAL_USER_VIEW_ONLY],
)
async def test_spend_key_fn_internal_user_scoped_to_own_keys(client, monkeypatch, role):
    """For both internal-user roles, /spend/keys is queried with the caller's user_id."""
    own_keys = [
        {"token": "hashed-mine-1", "user_id": "alice", "spend": 2.0},
        {"token": "hashed-mine-2", "user_id": "alice", "spend": 1.0},
    ]
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=own_keys)
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _as_alice():
        return UserAPIKeyAuth(user_role=role, user_id="alice")

    app.dependency_overrides[ps.user_api_key_auth] = _as_alice
    try:
        resp = client.get("/spend/keys", headers={"Authorization": "Bearer sk-test"})
        assert resp.status_code == 200

        # Same get_data helper as the admin path, but carrying a user_id
        # scope so only the caller's rows are requested.
        assert prisma_stub.find_many_calls == []
        assert len(prisma_stub.get_data_calls) == 1
        recorded = prisma_stub.get_data_calls[0]
        assert recorded["table_name"] == "key"
        assert recorded["query_type"] == "find_all"
        assert recorded["user_id"] == "alice"
        assert resp.json() == own_keys
    finally:
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
@pytest.mark.asyncio
async def test_spend_key_fn_internal_user_without_user_id_returns_empty(
    client, monkeypatch
):
    """
    A non-admin caller with ``user_id=None`` has nothing to scope the query
    by. The endpoint must answer with an empty list and touch neither query
    path, rather than fall back to the full table.
    """
    canary_rows = [{"token": "do-not-leak"}]
    prisma_stub = _SpendScopeMockPrismaClient(
        get_data_returns=canary_rows,
        find_many_returns=[{"token": "do-not-leak"}],
    )
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _as_anonymous_internal_user():
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.INTERNAL_USER, user_id=None
        )

    app.dependency_overrides[ps.user_api_key_auth] = _as_anonymous_internal_user
    try:
        resp = client.get("/spend/keys", headers={"Authorization": "Bearer sk-test"})
        assert resp.status_code == 200
        assert resp.json() == []
        assert prisma_stub.get_data_calls == []
        assert prisma_stub.find_many_calls == []
    finally:
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
@pytest.mark.asyncio
async def test_spend_user_fn_proxy_admin_returns_all_users_without_user_id(
    client, monkeypatch
):
    """A PROXY_ADMIN calling /spend/users with no user_id gets the full user list."""
    all_users = [
        {"user_id": "alice", "user_email": "alice@example.com", "spend": 1.0},
        {"user_id": "bob", "user_email": "bob@example.com", "spend": 2.0},
    ]
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=all_users)
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _as_admin():
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN, user_id="admin"
        )

    app.dependency_overrides[ps.user_api_key_auth] = _as_admin
    try:
        resp = client.get("/spend/users", headers={"Authorization": "Bearer sk-test"})
        assert resp.status_code == 200
        assert len(prisma_stub.get_data_calls) == 1
        recorded = prisma_stub.get_data_calls[0]
        assert recorded["table_name"] == "user"
        assert recorded["query_type"] == "find_all"
        assert resp.json() == all_users
    finally:
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
@pytest.mark.asyncio
async def test_spend_user_fn_proxy_admin_can_query_specific_user_id(
    client, monkeypatch
):
    """A PROXY_ADMIN may pass ``user_id`` to look up one specific user."""
    target_user = {
        "user_id": "carol",
        "user_email": "carol@example.com",
        "spend": 7.0,
    }
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=[target_user])
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _as_admin():
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN, user_id="admin"
        )

    app.dependency_overrides[ps.user_api_key_auth] = _as_admin
    try:
        resp = client.get(
            "/spend/users",
            params={"user_id": "carol"},
            headers={"Authorization": "Bearer sk-test"},
        )
        assert resp.status_code == 200
        assert len(prisma_stub.get_data_calls) == 1
        recorded = prisma_stub.get_data_calls[0]
        assert recorded["query_type"] == "find_unique"
        assert recorded["user_id"] == "carol"
        assert resp.json() == [target_user]
    finally:
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "role",
    [LitellmUserRoles.INTERNAL_USER, LitellmUserRoles.INTERNAL_USER_VIEW_ONLY],
)
async def test_spend_user_fn_internal_user_scoped_without_user_id(
    client, monkeypatch, role
):
    """With no user_id in the request, an internal user is looked up by their own id."""
    caller_row = {"user_id": "alice", "user_email": "alice@example.com", "spend": 3.0}
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=[caller_row])
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _as_alice():
        return UserAPIKeyAuth(user_role=role, user_id="alice")

    app.dependency_overrides[ps.user_api_key_auth] = _as_alice
    try:
        resp = client.get("/spend/users", headers={"Authorization": "Bearer sk-test"})
        assert resp.status_code == 200
        assert len(prisma_stub.get_data_calls) == 1
        recorded = prisma_stub.get_data_calls[0]
        assert recorded["query_type"] == "find_unique"
        assert recorded["user_id"] == "alice"
        assert resp.json() == [caller_row]
    finally:
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
@pytest.mark.asyncio
async def test_spend_user_fn_internal_user_supplying_other_user_id_returns_403(
    client, monkeypatch
):
    """
    An internal user asking for ``user_id=victim`` is refused with a 403 and
    no database lookup happens. The request is rejected outright rather than
    silently rewritten to the caller's own id, which keeps the attempt
    observable in logs.
    """
    victim_row = {
        "user_id": "victim",
        "user_email": "victim@example.com",
        "spend": 999.0,
    }
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=[victim_row])
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _as_alice():
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.INTERNAL_USER, user_id="alice"
        )

    app.dependency_overrides[ps.user_api_key_auth] = _as_alice
    try:
        resp = client.get(
            "/spend/users",
            params={"user_id": "victim"},
            headers={"Authorization": "Bearer sk-test"},
        )
        assert resp.status_code == 403
        assert prisma_stub.get_data_calls == []
    finally:
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
@pytest.mark.asyncio
async def test_spend_user_fn_internal_user_supplying_own_user_id_is_allowed(
    client, monkeypatch
):
    """
    Passing your own user_id explicitly is fine; the 403 only fires when
    the supplied id differs from the caller's.
    """
    caller_row = {"user_id": "alice", "user_email": "alice@example.com", "spend": 3.0}
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=[caller_row])
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _as_alice():
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.INTERNAL_USER, user_id="alice"
        )

    app.dependency_overrides[ps.user_api_key_auth] = _as_alice
    try:
        resp = client.get(
            "/spend/users",
            params={"user_id": "alice"},
            headers={"Authorization": "Bearer sk-test"},
        )
        assert resp.status_code == 200
        assert len(prisma_stub.get_data_calls) == 1
        recorded = prisma_stub.get_data_calls[0]
        assert recorded["query_type"] == "find_unique"
        assert recorded["user_id"] == "alice"
        assert resp.json() == [caller_row]
    finally:
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
@pytest.mark.asyncio
async def test_spend_user_fn_internal_user_without_user_id_returns_empty(
    client, monkeypatch
):
    """
    A non-admin caller with ``user_id=None`` has nothing to scope by, so
    /spend/users must return an empty list without calling get_data; it must
    never return the full table.
    """
    prisma_stub = _SpendScopeMockPrismaClient(
        get_data_returns=[{"user_id": "do-not-leak"}]
    )
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _as_anonymous_viewer():
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, user_id=None
        )

    app.dependency_overrides[ps.user_api_key_auth] = _as_anonymous_viewer
    try:
        resp = client.get("/spend/users", headers={"Authorization": "Bearer sk-test"})
        assert resp.status_code == 200
        assert resp.json() == []
        assert prisma_stub.get_data_calls == []
    finally:
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
@pytest.mark.asyncio
async def test_spend_user_fn_strips_password_field(client, monkeypatch):
    """
    The ``password`` field must not appear in the /spend/users response on
    the scoped (internal-user) path, so the scoping change does not regress
    the existing redaction.
    """
    caller_row = {
        "user_id": "alice",
        "user_email": "alice@example.com",
        "password": "hashed-password-must-not-leak",
        "spend": 1.0,
    }
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=[caller_row])
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _as_alice():
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.INTERNAL_USER, user_id="alice"
        )

    app.dependency_overrides[ps.user_api_key_auth] = _as_alice
    try:
        resp = client.get("/spend/users", headers={"Authorization": "Bearer sk-test"})
        assert resp.status_code == 200
        payload = resp.json()
        assert len(payload) == 1
        assert "password" not in payload[0]
    finally:
        app.dependency_overrides.pop(ps.user_api_key_auth, None)

View File

@ -0,0 +1,32 @@
# tests/test_litellm/proxy/test__types.py
from litellm.proxy._types import LiteLLM_TeamMembership
def test_team_membership_budget_table_optional_no_crash():
    """
    Regression test for #28689.

    Under Pydantic v2 an ``Optional[T]`` field without a default is required.
    When ``budget_id`` is null the DB join result carries no
    ``litellm_budget_table`` key at all, and ``model_validate`` must accept
    that instead of raising 'Field required'.
    """
    # litellm_budget_table is deliberately left out of the payload.
    payload = dict(user_id="test-user", team_id="test-team", budget_id=None)

    membership = LiteLLM_TeamMembership.model_validate(payload)

    assert membership.litellm_budget_table is None
def test_team_membership_budget_table_present_still_works():
    """An explicit ``litellm_budget_table: None`` next to a set budget_id validates.

    Here the ``litellm_budget_table`` key is present in the payload (as
    ``None``) while ``budget_id`` is set; validation must succeed and leave
    the attribute as ``None``.
    """
    # NOTE(review): a populated litellm_budget_table payload is not exercised
    # by this test; only the explicit-None case is.
    data = {
        "user_id": "test-user",
        "team_id": "test-team",
        "budget_id": "some-budget-id",
        "litellm_budget_table": None,
    }
    result = LiteLLM_TeamMembership.model_validate(data)
    assert result.litellm_budget_table is None

View File

@ -0,0 +1,47 @@
"""
Validate that AWS GovCloud (Bedrock us-gov-*) Haiku 4.5 entries carry
the 1-hour cache write tier.
AWS Bedrock GovCloud pricing applies a +20% premium over global
Anthropic rates. Global Haiku 4.5 1h cache write is $2.00/MTok; us-gov
is therefore $2.40/MTok, exactly 1.6x the 5-minute rate of $1.50/MTok.
Source: https://aws.amazon.com/bedrock/pricing/
"""
import json
import os
import pytest
@pytest.fixture(scope="module")
def model_data():
    """Load ``model_prices_and_context_window.json`` once per test module.

    The file is located two directories above this test file.

    Returns:
        The parsed JSON document (indexed by model key in the tests below).
    """
    json_path = os.path.join(
        os.path.dirname(__file__), "../../model_prices_and_context_window.json"
    )
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # default text encoding.
    with open(json_path, encoding="utf-8") as f:
        return json.load(f)
# Pricing-table keys for Haiku 4.5 in the two Bedrock us-gov regions.
HAIKU_USGOV_KEYS = [
    "bedrock/us-gov-east-1/anthropic.claude-haiku-4-5-20251001-v1:0",
    "bedrock/us-gov-west-1/anthropic.claude-haiku-4-5-20251001-v1:0",
]
@pytest.mark.parametrize("model_key", HAIKU_USGOV_KEYS)
def test_usgov_haiku_4_5_1hr_cache_write(model_data, model_key):
    """Each us-gov Haiku 4.5 entry has the 5m and 1h cache-write rates at a 1.6x ratio."""
    assert model_key in model_data, f"Missing model entry: {model_key}"
    entry = model_data[model_key]

    five_minute_cost = entry["cache_creation_input_token_cost"]
    assert (
        five_minute_cost == 1.5e-06
    ), f"{model_key}: 5m cache write should be $1.50/MTok"

    one_hour_cost = entry["cache_creation_input_token_cost_above_1hr"]
    assert (
        one_hour_cost == 2.4e-06
    ), f"{model_key}: 1h cache write should be $2.40/MTok"

    # Cross-check the two tiers against each other, with float tolerance.
    ratio = one_hour_cost / five_minute_cost
    assert abs(ratio - 1.6) < 1e-9, f"{model_key}: 1h/5m ratio is {ratio}, expected 1.6"

View File

@ -0,0 +1,132 @@
"""
Validate AWS GovCloud (Bedrock us-gov-*) Anthropic pricing entries.
AWS Bedrock pricing in GovCloud carries a +20% premium over the global
Anthropic prices (not the +10% commercial-US premium). Until 2026-05-22
these entries silently mirrored commercial US, undercharging customers
by ~9%.
Source: https://aws.amazon.com/bedrock/pricing/
Sonnet 4.5 in us-gov-* (per million tokens):
input = $3.60
output = $18.00
cache write 5m = $4.50
cache write 1h = $7.20
cache read = $0.36
Reference: https://github.com/BerriAI/litellm/issues/27120
"""
import json
import os
import pytest
@pytest.fixture(scope="module")
def model_data():
    """Load ``model_prices_and_context_window.json`` once per test module.

    The file is located two directories above this test file.

    Returns:
        The parsed JSON document (indexed by model key in the tests below).
    """
    json_path = os.path.join(
        os.path.dirname(__file__), "../../model_prices_and_context_window.json"
    )
    # Read as UTF-8 explicitly so parsing does not depend on the platform's
    # default text encoding.
    with open(json_path, encoding="utf-8") as f:
        return json.load(f)
# Every us-gov Sonnet 4.5 pricing-table key checked by the tests below:
# region-prefixed Bedrock ids (with and without the "anthropic." vendor
# prefix) plus the us-gov cross-region inference profile.
SONNET_4_5_USGOV_KEYS = [
    "bedrock/us-gov-east-1/anthropic.claude-sonnet-4-5-20250929-v1:0",
    "bedrock/us-gov-west-1/anthropic.claude-sonnet-4-5-20250929-v1:0",
    "bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0",
    "bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0",
    "us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0",
]
@pytest.mark.parametrize("model_key", SONNET_4_5_USGOV_KEYS)
def test_usgov_sonnet_4_5_pricing(model_data, model_key):
    """Every us-gov Sonnet 4.5 entry carries the exact GovCloud per-token
    rates (input, output, 5m/1h cache write, cache read).
    """
    assert model_key in model_data, f"Missing model entry: {model_key}"
    entry = model_data[model_key]

    assert entry["input_cost_per_token"] == 3.6e-06, (
        f"{model_key}: input_cost_per_token should be $3.60/MTok "
        f"(got {entry['input_cost_per_token']})"
    )
    assert (
        entry["output_cost_per_token"] == 1.8e-05
    ), f"{model_key}: output_cost_per_token should be $18.00/MTok"
    assert (
        entry["cache_creation_input_token_cost"] == 4.5e-06
    ), f"{model_key}: 5m cache write should be $4.50/MTok"
    assert (
        entry["cache_creation_input_token_cost_above_1hr"] == 7.2e-06
    ), f"{model_key}: 1h cache write should be $7.20/MTok"
    assert (
        entry["cache_read_input_token_cost"] == 3.6e-07
    ), f"{model_key}: cache read should be $0.36/MTok"
def test_usgov_carries_20_percent_premium_over_global(model_data):
"""The us-gov rates must equal 1.2x the global anthropic.* rates,
matching AWS's documented GovCloud uplift.
"""
global_key = "anthropic.claude-sonnet-4-5-20250929-v1:0"
usgov_key = "bedrock/us-gov-west-1/anthropic.claude-sonnet-4-5-20250929-v1:0"
global_info = model_data[global_key]
usgov_info = model_data[usgov_key]
for field in (
"input_cost_per_token",
"output_cost_per_token",
"cache_creation_input_token_cost",
"cache_creation_input_token_cost_above_1hr",
"cache_read_input_token_cost",
):
ratio = usgov_info[field] / global_info[field]
assert (
abs(ratio - 1.2) < 1e-9
), f"{field}: us-gov / global ratio is {ratio}, expected 1.2"
# Only the us-gov.anthropic.* cross-region inference profile carries the
# 1M-context `_above_200k_tokens` pricing tier; the
# bedrock/us-gov-{east,west}-1/ entries are capped at 200k tokens.
USGOV_CROSS_REGION_KEY = "us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0"

# Expected per-token rates for that tier, keyed by pricing-table field name.
EXPECTED_USGOV_ABOVE_200K = {
    "input_cost_per_token_above_200k_tokens": 7.2e-06,
    "output_cost_per_token_above_200k_tokens": 2.7e-05,
    "cache_creation_input_token_cost_above_200k_tokens": 9.0e-06,
    "cache_creation_input_token_cost_above_1hr_above_200k_tokens": 1.44e-05,
    "cache_read_input_token_cost_above_200k_tokens": 7.2e-07,
}
@pytest.mark.parametrize("field,expected", EXPECTED_USGOV_ABOVE_200K.items())
def test_usgov_cross_region_above_200k_carries_gov_premium(model_data, field, expected):
    """Each `_above_200k_tokens` field on the us-gov cross-region profile
    must equal its expected GovCloud (+20%) rate. The original PR fixed the
    base rates but left these long-context fields at the +10% commercial-US
    rates, which undercharged long-context requests.
    """
    entry = model_data[USGOV_CROSS_REGION_KEY]
    assert field in entry, f"{USGOV_CROSS_REGION_KEY}: missing field {field}"
    actual = entry[field]
    assert (
        actual == expected
    ), f"{USGOV_CROSS_REGION_KEY}: {field} should be {expected} (got {entry[field]})"
def test_usgov_cross_region_above_200k_ratio_to_global(model_data):
    """Ratio cross-check: every `_above_200k_tokens` field on the us-gov
    cross-region profile must be 1.2x the same field on the global
    ``anthropic.*`` entry.
    """
    baseline = model_data["anthropic.claude-sonnet-4-5-20250929-v1:0"]
    govcloud = model_data[USGOV_CROSS_REGION_KEY]
    # Iterating the dict yields its keys, i.e. the field names to compare.
    for field in EXPECTED_USGOV_ABOVE_200K:
        ratio = govcloud[field] / baseline[field]
        assert (
            abs(ratio - 1.2) < 1e-9
        ), f"{field}: us-gov / global ratio is {ratio}, expected 1.2"