Litellm OSS Staging 010626 (#29422)
This commit is contained in:
parent
b7bbddbd4d
commit
5fd27141cf
@ -106,6 +106,7 @@ GATEWAY_PATH_PREFIXES: tuple[str, ...] = (
|
||||
# Health & ops
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/watsonx"
|
||||
)
|
||||
|
||||
GATEWAY_EXACT_PATHS: frozenset[str] = frozenset(
|
||||
|
||||
@ -771,6 +771,7 @@ openai_compatible_endpoints: List = [
|
||||
"https://api.moonshot.ai/v1",
|
||||
"https://api.publicai.co/v1",
|
||||
"https://api.synthetic.new/openai/v1",
|
||||
"https://serverless.tensormesh.ai/v1",
|
||||
"https://api.stima.tech/v1",
|
||||
"https://nano-gpt.com/api/v1",
|
||||
"https://api.poe.com/v1",
|
||||
@ -820,6 +821,7 @@ openai_compatible_providers: List = [
|
||||
"meta_llama",
|
||||
"publicai", # PublicAI - JSON-configured provider
|
||||
"synthetic", # Synthetic - JSON-configured provider
|
||||
"tensormesh", # Tensormesh - JSON-configured provider
|
||||
"apertis", # Apertis - JSON-configured provider
|
||||
"nano-gpt", # Nano-GPT - JSON-configured provider
|
||||
"poe", # Poe - JSON-configured provider
|
||||
@ -855,6 +857,7 @@ openai_text_completion_compatible_providers: List = (
|
||||
"moonshot",
|
||||
"publicai",
|
||||
"synthetic",
|
||||
"tensormesh",
|
||||
"apertis",
|
||||
"nano-gpt",
|
||||
"poe",
|
||||
|
||||
@ -95,7 +95,9 @@ class FocusTransformer:
|
||||
pl.lit("Usage-Based").alias("ChargeFrequency"),
|
||||
fmt(pl.col("ChargePeriodEnd")).alias("ChargePeriodEnd"),
|
||||
fmt(pl.col("ChargePeriodStart")).alias("ChargePeriodStart"),
|
||||
dec(pl.lit(1.0)).alias("ConsumedQuantity"),
|
||||
dec(
|
||||
pl.col("api_requests").cast(pl.Int64).cast(pl.Float64).fill_null(0.0)
|
||||
).alias("ConsumedQuantity"),
|
||||
pl.lit("Requests").alias("ConsumedUnit"),
|
||||
dec(pl.col("spend").fill_null(0.0)).alias("ContractedCost"),
|
||||
none_str.alias("ContractedUnitPrice"),
|
||||
@ -107,7 +109,9 @@ class FocusTransformer:
|
||||
none_str.alias("AvailabilityZone"),
|
||||
pl.lit("USD").alias("PricingCurrency"),
|
||||
none_str.alias("PricingCategory"),
|
||||
dec(pl.lit(1.0)).alias("PricingQuantity"),
|
||||
dec(
|
||||
pl.col("api_requests").cast(pl.Int64).cast(pl.Float64).fill_null(0.0)
|
||||
).alias("PricingQuantity"),
|
||||
none_dec.alias("PricingCurrencyContractedUnitPrice"),
|
||||
dec(pl.col("spend").fill_null(0.0)).alias("PricingCurrencyEffectiveCost"),
|
||||
none_dec.alias("PricingCurrencyListUnitPrice"),
|
||||
|
||||
@ -511,6 +511,23 @@ class PrometheusLogger(CustomLogger):
|
||||
labelnames=self.get_labels_for_metric("litellm_cached_tokens_metric"),
|
||||
)
|
||||
|
||||
# Provider prompt-caching metrics
|
||||
self.litellm_provider_cache_read_input_tokens_metric = self._counter_factory(
|
||||
name="litellm_provider_cache_read_input_tokens_metric",
|
||||
documentation="Total prompt/input tokens read from provider prompt cache (e.g. OpenAI/Anthropic/Gemini/Bedrock)",
|
||||
labelnames=self.get_labels_for_metric(
|
||||
"litellm_provider_cache_read_input_tokens_metric"
|
||||
),
|
||||
)
|
||||
|
||||
self.litellm_provider_cache_creation_input_tokens_metric = self._counter_factory(
|
||||
name="litellm_provider_cache_creation_input_tokens_metric",
|
||||
documentation="Total prompt/input tokens written to provider prompt cache (e.g. Anthropic/Bedrock)",
|
||||
labelnames=self.get_labels_for_metric(
|
||||
"litellm_provider_cache_creation_input_tokens_metric"
|
||||
),
|
||||
)
|
||||
|
||||
# User and Team count metrics
|
||||
self.litellm_total_users_metric = self._gauge_factory(
|
||||
"litellm_total_users",
|
||||
@ -1458,11 +1475,11 @@ class PrometheusLogger(CustomLogger):
|
||||
"""
|
||||
cache_hit = standard_logging_payload.get("cache_hit")
|
||||
|
||||
# Only track if cache_hit has a definite value (True or False)
|
||||
if cache_hit is None:
|
||||
return
|
||||
|
||||
if cache_hit is True:
|
||||
# Historically these metrics only tracked LiteLLM caching.
|
||||
# Provider prompt-caching metrics are still emitted below.
|
||||
pass
|
||||
elif cache_hit is True:
|
||||
# Increment cache hits counter
|
||||
PrometheusLogger._inc_labeled_counter(
|
||||
self,
|
||||
@ -1493,6 +1510,51 @@ class PrometheusLogger(CustomLogger):
|
||||
label_context=label_context,
|
||||
)
|
||||
|
||||
# Provider prompt caching metrics are independent of LiteLLM cache_hit.
|
||||
provider_cache_read_tokens = 0
|
||||
provider_cache_creation_tokens = 0
|
||||
usage_obj = (standard_logging_payload.get("metadata", {}) or {}).get(
|
||||
"usage_object"
|
||||
)
|
||||
if isinstance(usage_obj, dict):
|
||||
# Prefer explicit provider cache fields when available.
|
||||
_read = usage_obj.get("cache_read_input_tokens")
|
||||
_write = usage_obj.get("cache_creation_input_tokens")
|
||||
|
||||
if isinstance(_read, int):
|
||||
provider_cache_read_tokens = _read
|
||||
if isinstance(_write, int):
|
||||
provider_cache_creation_tokens = _write
|
||||
|
||||
# Fallback to prompt_tokens_details.cached_tokens (common normalization point).
|
||||
# Only fallback when the explicit field is genuinely absent (None).
|
||||
if _read is None:
|
||||
prompt_details = usage_obj.get("prompt_tokens_details")
|
||||
if isinstance(prompt_details, dict):
|
||||
cached_tokens = prompt_details.get("cached_tokens")
|
||||
if isinstance(cached_tokens, int):
|
||||
provider_cache_read_tokens = cached_tokens
|
||||
|
||||
if provider_cache_read_tokens > 0:
|
||||
PrometheusLogger._inc_labeled_counter(
|
||||
self,
|
||||
self.litellm_provider_cache_read_input_tokens_metric,
|
||||
"litellm_provider_cache_read_input_tokens_metric",
|
||||
enum_values,
|
||||
label_context=label_context,
|
||||
amount=float(provider_cache_read_tokens),
|
||||
)
|
||||
|
||||
if provider_cache_creation_tokens > 0:
|
||||
PrometheusLogger._inc_labeled_counter(
|
||||
self,
|
||||
self.litellm_provider_cache_creation_input_tokens_metric,
|
||||
"litellm_provider_cache_creation_input_tokens_metric",
|
||||
enum_values,
|
||||
label_context=label_context,
|
||||
amount=float(provider_cache_creation_tokens),
|
||||
)
|
||||
|
||||
async def _increment_remaining_budget_metrics(
|
||||
self,
|
||||
user_api_team: Optional[str],
|
||||
|
||||
@ -47,8 +47,9 @@ async def async_completion_with_fallbacks(**kwargs):
|
||||
completion_kwargs = safe_deep_copy(base_kwargs)
|
||||
# Handle dictionary fallback configurations
|
||||
if isinstance(fallback, dict):
|
||||
model = fallback.pop("model", original_model)
|
||||
completion_kwargs.update(fallback)
|
||||
fallback_config = safe_deep_copy(dict(fallback))
|
||||
model = fallback_config.pop("model", original_model)
|
||||
completion_kwargs.update(fallback_config)
|
||||
else:
|
||||
model = fallback
|
||||
|
||||
|
||||
@ -1534,10 +1534,9 @@ class BaseAWSLLM:
|
||||
)
|
||||
|
||||
sigv4 = SigV4Auth(credentials, service_name, aws_region_name)
|
||||
if headers is not None:
|
||||
headers = headers or {}
|
||||
if not any(header_name.lower() == "content-type" for header_name in headers):
|
||||
headers = {"Content-Type": "application/json", **headers}
|
||||
else:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
|
||||
aws_signature_headers = self._filter_headers_for_aws_signature(headers)
|
||||
request = AWSRequest(
|
||||
|
||||
@ -233,6 +233,259 @@ class BedrockFilesConfig(BaseAWSLLM, BaseFilesConfig):
|
||||
# example; add others here as they adopt the same schema.
|
||||
CONVERSE_INVOKE_PROVIDERS = ("nova",)
|
||||
|
||||
# OpenAI batch URL that signals an embedding request. Per OpenAI Batch API
|
||||
# spec, every JSONL record carries a `url` field; we use it as the
|
||||
# authoritative signal to route the line to the embedding code path
|
||||
# instead of inferring from the presence of `input` vs `messages`.
|
||||
OPENAI_EMBEDDINGS_URL = "/v1/embeddings"
|
||||
|
||||
@staticmethod
def _is_embedding_record(openai_jsonl_record: Dict[str, Any]) -> bool:
    """
    Report whether one OpenAI batch JSONL line should be handled as an
    embedding request.

    Decision order:
      1. A `url` equal to `OPENAI_EMBEDDINGS_URL` means embedding.
      2. Any other truthy `url` means NOT embedding; the explicit value
         wins and the body is never inspected.
      3. With a missing/empty `url`, the body shape decides: the record
         is an embedding request only when `body` is a dict that has
         an `input` key and no `messages` key. A record carrying both
         keys, or a non-dict body, is treated as non-embedding.

    Args:
        openai_jsonl_record: One parsed JSONL line (expects optional
            `url` and `body` keys).

    Returns:
        True when the line should take the embedding code path.
    """
    url_field = openai_jsonl_record.get("url")
    if url_field == BedrockFilesConfig.OPENAI_EMBEDDINGS_URL:
        return True

    if not url_field:
        # No explicit routing signal: infer from the request body instead.
        payload = openai_jsonl_record.get("body", {})
        if isinstance(payload, dict):
            return "input" in payload and "messages" not in payload

    # Either an explicit non-embedding url, or an unusable body.
    return False
|
||||
|
||||
# Identifier for the Bedrock Titan v2 InvokeModel body schema as stored
|
||||
# in `model_prices_and_context_window.json`. Centralized so future
|
||||
# embedding-schema variants can add their own value
|
||||
# (e.g. `cohere_v3`, `titan_g1`, `titan_multimodal`) without touching
|
||||
# the detection logic.
|
||||
_TITAN_V2_INVOCATION_SCHEMA = "titan_v2"
|
||||
|
||||
# Substring marker used as a fallback when the registry can't resolve
|
||||
# the model id - notably cross-region inference profile prefixes
|
||||
# (`us.amazon.titan-embed-text-v2:0`) and Bedrock ARN forms, which
|
||||
# `get_model_info` doesn't normalize today.
|
||||
_TITAN_V2_EMBED_MODEL_MARKER = "titan-embed-text-v2"
|
||||
|
||||
# Nested field name under `provider_specific_entry` that identifies the
|
||||
# Bedrock InvokeModel body schema for batch inference.
|
||||
# `provider_specific_entry` is the registry's escape hatch for fields
|
||||
# `get_model_info` doesn't promote to top-level - exactly what we need
|
||||
# here. Documented in the `sample_spec` entry of
|
||||
# `model_prices_and_context_window.json` and surfaced by
|
||||
# `get_model_info` (see `ModelInfo.provider_specific_entry`).
|
||||
_BEDROCK_INVOCATION_SCHEMA_FIELD = "bedrock_invocation_schema"
|
||||
|
||||
@staticmethod
def _is_titan_v2_embed_model(model: str) -> bool:
    """
    Tell whether `model` identifies Amazon Titan Text Embeddings V2.

    Two-stage resolution:
      1. Registry lookup. `_lookup_provider_specific_field` is asked for
         the `_BEDROCK_INVOCATION_SCHEMA_FIELD` value of `model`. If it
         returns anything other than `None`, that value alone decides:
         the result is True only when it equals
         `_TITAN_V2_INVOCATION_SCHEMA`. No substring check is attempted
         in that case.
      2. Substring fallback, used only when the lookup returned `None`.
         The id is lower-cased, a leading `bedrock/` is dropped, and the
         first occurrence of `_TITAN_V2_EMBED_MODEL_MARKER` must be
         followed by `:`, `/`, or the end of the string. That boundary
         rule rejects lookalikes such as `titan-embed-text-v20` or
         `titan-embed-text-v2-experimental`.

    Id shapes the fallback accepts include:
      - "amazon.titan-embed-text-v2:0"
      - "bedrock/amazon.titan-embed-text-v2:0"
      - "us.amazon.titan-embed-text-v2:0"
      - ids ending in ".../amazon.titan-embed-text-v2:0"

    Args:
        model: Model identifier taken from the request.

    Returns:
        True when the id is recognised as Titan Text Embeddings V2.
    """
    schema_from_registry = BedrockFilesConfig._lookup_provider_specific_field(
        model, BedrockFilesConfig._BEDROCK_INVOCATION_SCHEMA_FIELD
    )
    if schema_from_registry is not None:
        # The registry answered, so its discriminator is final.
        return schema_from_registry == BedrockFilesConfig._TITAN_V2_INVOCATION_SCHEMA

    # The registry had nothing usable: inspect the id text itself.
    route_prefix = "bedrock/"
    model_id = model.lower()
    if model_id.startswith(route_prefix):
        model_id = model_id[len(route_prefix) :]

    _, matched, trailing = model_id.partition(
        BedrockFilesConfig._TITAN_V2_EMBED_MODEL_MARKER
    )
    if not matched:
        return False
    # Only a clean boundary after the marker counts as a match.
    return trailing == "" or trailing[0] in (":", "/")
|
||||
|
||||
@staticmethod
|
||||
def _lookup_provider_specific_field(model_id: str, field: str) -> Optional[str]:
|
||||
"""
|
||||
Read a nested string field from the registry entry's
|
||||
`provider_specific_entry` dict via `litellm.get_model_info`.
|
||||
|
||||
Returns the field's string value when:
|
||||
- the registry resolves `model_id`,
|
||||
- the entry exposes `provider_specific_entry` as a dict, and
|
||||
- that dict has `field` mapped to a non-empty string.
|
||||
Otherwise returns `None`.
|
||||
|
||||
Isolating this means feature detectors (Titan v2 today, future
|
||||
Cohere Embed / Nova Multimodal branches) share one defensive
|
||||
try/except shape instead of duplicating it. The `None` return
|
||||
covers every realistic failure mode: `get_model_info` raises
|
||||
(cross-region profile prefixes, Bedrock ARN forms, unreleased
|
||||
models), returns a non-dict, has no `provider_specific_entry`, or
|
||||
the requested field is missing / non-string / empty.
|
||||
"""
|
||||
try:
|
||||
from litellm import get_model_info
|
||||
|
||||
info = get_model_info(model_id)
|
||||
except Exception:
|
||||
return None
|
||||
if not isinstance(info, dict):
|
||||
return None
|
||||
provider_specific = info.get("provider_specific_entry")
|
||||
if not isinstance(provider_specific, dict):
|
||||
return None
|
||||
value = provider_specific.get(field)
|
||||
return value if isinstance(value, str) and value else None
|
||||
|
||||
@staticmethod
|
||||
def _coerce_embedding_input_to_string(raw_input: Any, model: str = "") -> str:
|
||||
"""
|
||||
Normalize an OpenAI /v1/embeddings `input` field into the single
|
||||
string that Bedrock Titan v2 InvokeModel expects in `inputText`.
|
||||
|
||||
Accepts: a string, or a single-element list containing one string.
|
||||
Rejects (with actionable messages):
|
||||
- None / missing -> ValueError
|
||||
- Multi-element string lists -> ValueError, prompts caller to
|
||||
emit one JSONL line per input
|
||||
- Pre-tokenized inputs (List[int], List[List[int]]) -> NotImplementedError
|
||||
- Any other type -> ValueError
|
||||
|
||||
Extracted so the validation can be exercised in isolation and so
|
||||
future embedding-provider branches (Titan G1, Cohere) can reuse it
|
||||
without duplicating the type-shaping logic.
|
||||
"""
|
||||
if raw_input is None:
|
||||
raise ValueError(
|
||||
"Embedding batch record is missing required `input` field: "
|
||||
f"model={model}"
|
||||
)
|
||||
|
||||
# Bedrock InvokeModel for Titan v2 takes exactly one string `inputText`
|
||||
# per call. Pre-tokenized inputs and multi-element string lists are
|
||||
# explicitly unsupported so callers emit one JSONL line per embedding
|
||||
# instead of relying on us to silently fan out or concatenate.
|
||||
if isinstance(raw_input, list):
|
||||
if len(raw_input) == 1:
|
||||
candidate = raw_input[0]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Bedrock batch embedding requires one input per JSONL "
|
||||
"record. Got a list with "
|
||||
f"{len(raw_input)} items for model={model}; emit one "
|
||||
"JSONL line per input string instead."
|
||||
)
|
||||
else:
|
||||
candidate = raw_input
|
||||
|
||||
# Catches pre-tokenized inputs (List[int] from OpenAI spec, or a
|
||||
# single int slipping past the list-unwrap above).
|
||||
# NOTE: bool is a subclass of int but treating True/False as a token
|
||||
# is meaningless either way, so the broad check is fine.
|
||||
if isinstance(candidate, (list, int)):
|
||||
raise NotImplementedError(
|
||||
"Bedrock Titan v2 batch embedding does not support "
|
||||
"pre-tokenized integer inputs. Pass `input` as a string "
|
||||
f"(model={model})."
|
||||
)
|
||||
if not isinstance(candidate, str):
|
||||
raise ValueError(
|
||||
"Bedrock batch embedding `input` must be a string (or a "
|
||||
"single-element list of strings). Got type "
|
||||
f"{type(candidate).__name__} for model={model}."
|
||||
)
|
||||
return candidate
|
||||
|
||||
def _map_openai_embedding_to_bedrock_params(
    self,
    openai_request_body: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Build the Bedrock `modelInput` dict for one OpenAI /v1/embeddings
    request body.

    Only ids accepted by `_is_titan_v2_embed_model` are handled; any
    other model is refused up front rather than shaped with the wrong
    schema.

    Steps:
      1. Validate the model id.
      2. Collapse `input` to one string via
         `_coerce_embedding_input_to_string`.
      3. Pass every remaining body key (everything except `model` and
         `input`) to `AmazonTitanV2Config.map_openai_params`, then let
         `AmazonTitanV2Config._transform_request` assemble the body.

    AWS docs (Titan v2 InvokeModel body):
    https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html

    Args:
        openai_request_body: The `body` of an OpenAI batch JSONL line.

    Returns:
        A plain dict copy of the transformed request.

    Raises:
        NotImplementedError: the model is not recognised as Titan Text
            Embeddings V2 (also propagated from input coercion for
            pre-tokenized inputs).
        ValueError: propagated from input coercion for a missing or
            malformed `input`.
    """
    from litellm.llms.bedrock.embed.amazon_titan_v2_transformation import (
        AmazonTitanV2Config,
    )

    requested_model = openai_request_body.get("model", "")
    if not self._is_titan_v2_embed_model(requested_model):
        # Fail fast: each embedding family has its own InvokeModel schema,
        # so guessing here would risk writing a corrupt batch file.
        raise NotImplementedError(
            "Bedrock batch embedding currently supports only Amazon "
            "Titan Text Embeddings V2 (model id contains "
            f"'titan-embed-text-v2'). Got model={requested_model!r}. Track other "
            "embedding models in https://github.com/BerriAI/litellm/issues."
        )

    text = self._coerce_embedding_input_to_string(
        openai_request_body.get("input"), model=requested_model
    )

    # Everything besides the model id and the text is a tunable parameter
    # for the embed config to translate.
    extra_params: Dict[str, Any] = {}
    for key, value in openai_request_body.items():
        if key in ("model", "input"):
            continue
        extra_params[key] = value

    embed_config = AmazonTitanV2Config()
    mapped_params = embed_config.map_openai_params(
        non_default_params=extra_params,
        optional_params={},
    )
    bedrock_body = embed_config._transform_request(
        input=text, inference_params=mapped_params
    )
    return dict(bedrock_body)
|
||||
|
||||
def _map_openai_to_bedrock_params(
|
||||
self,
|
||||
openai_request_body: Dict[str, Any],
|
||||
@ -349,10 +602,19 @@ class BedrockFilesConfig(BaseAWSLLM, BaseFilesConfig):
|
||||
# Determine provider from model name
|
||||
provider = self.get_bedrock_invoke_provider(model)
|
||||
|
||||
# Transform to Bedrock modelInput format
|
||||
model_input = self._map_openai_to_bedrock_params(
|
||||
openai_request_body=openai_body, provider=provider
|
||||
)
|
||||
# Route to the embedding transformer when the OpenAI batch line
|
||||
# targets /v1/embeddings; otherwise fall back to the existing
|
||||
# chat-completion path. We branch here (rather than inside
|
||||
# `_map_openai_to_bedrock_params`) so the chat helper keeps its
|
||||
# narrow contract and the embedding helper can evolve independently.
|
||||
if self._is_embedding_record(_openai_jsonl_content):
|
||||
model_input = self._map_openai_embedding_to_bedrock_params(
|
||||
openai_request_body=openai_body
|
||||
)
|
||||
else:
|
||||
model_input = self._map_openai_to_bedrock_params(
|
||||
openai_request_body=openai_body, provider=provider
|
||||
)
|
||||
|
||||
# Create Bedrock batch record
|
||||
record_id = _openai_jsonl_content.get(
|
||||
|
||||
@ -5,6 +5,7 @@ Common utilities, constants, and error handling for Black Forest Labs API.
|
||||
"""
|
||||
|
||||
from typing import Dict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from litellm.llms.base_llm.chat.transformation import BaseLLMException
|
||||
|
||||
@ -18,6 +19,42 @@ class BlackForestLabsError(BaseLLMException):
|
||||
# API Constants
|
||||
DEFAULT_API_BASE = "https://api.bfl.ai"
|
||||
|
||||
# BFL uses regional subdomains (e.g. gateway.bfl.ai) for polling URLs that
|
||||
# differ from the submission host (api.bfl.ai). We validate against the
|
||||
# registered domain rather than doing a strict same-origin check.
|
||||
_BFL_REGISTERED_DOMAIN = "bfl.ai"
|
||||
|
||||
|
||||
def assert_bfl_polling_url(polling_url: str) -> None:
    """Validate that a polling URL points to a BFL-controlled host.

    BFL returns polling URLs on subdomains like ``gateway.bfl.ai`` that differ
    from the submission host ``api.bfl.ai``, so a strict same-origin check
    would reject legitimate URLs. Instead the URL is accepted only when its
    scheme is ``https`` and its host is the registered domain itself or a
    subdomain of it. The host comparison is case-insensitive.

    Unparseable input is rejected with the same error type as an untrusted
    URL, so callers only ever need to handle ``BlackForestLabsError``.

    Args:
        polling_url: The polling URL taken from the upstream response.

    Raises:
        BlackForestLabsError: If the value is not a string, cannot be parsed
            as a URL, or its scheme or host is not trusted. Always raised
            with ``status_code=502``.
    """
    # The value comes from upstream JSON; anything but a string cannot be a
    # URL and would otherwise fail inside urlparse with an unrelated error.
    if not isinstance(polling_url, str):
        raise BlackForestLabsError(
            status_code=502,
            message="Rejected polling URL: malformed URL",
        )

    try:
        parsed = urlparse(polling_url)
        host = (parsed.hostname or "").lower()
    except ValueError as err:
        # urlparse raises ValueError for malformed authorities such as an
        # unbalanced IPv6 bracket; report it as a rejection, not a crash.
        raise BlackForestLabsError(
            status_code=502,
            message="Rejected polling URL: malformed URL",
        ) from err

    if parsed.scheme != "https":
        raise BlackForestLabsError(
            status_code=502,
            message="Rejected polling URL: scheme must be https",
        )

    # The leading dot stops lookalike hosts such as "evilbfl.ai" from
    # passing the suffix test.
    trusted_suffix = "." + _BFL_REGISTERED_DOMAIN
    if host != _BFL_REGISTERED_DOMAIN and not host.endswith(trusted_suffix):
        raise BlackForestLabsError(
            status_code=502,
            message="Rejected polling URL: host is not within the bfl.ai domain",
        )
|
||||
|
||||
|
||||
# Polling configuration
|
||||
DEFAULT_POLLING_INTERVAL = 1.5 # seconds
|
||||
DEFAULT_MAX_POLLING_TIME = 300 # 5 minutes
|
||||
|
||||
@ -15,7 +15,6 @@ import httpx
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
@ -29,6 +28,7 @@ from ..common_utils import (
|
||||
DEFAULT_MAX_POLLING_TIME,
|
||||
DEFAULT_POLLING_INTERVAL,
|
||||
BlackForestLabsError,
|
||||
assert_bfl_polling_url,
|
||||
)
|
||||
from .transformation import BlackForestLabsImageEditConfig
|
||||
|
||||
@ -332,16 +332,11 @@ class BlackForestLabsImageEdit:
|
||||
message="No polling_url in BFL response",
|
||||
)
|
||||
|
||||
# Reject cross-origin polling URLs — the ``x-key`` auth header
|
||||
# would otherwise leak to whatever URL the upstream returns.
|
||||
# VERIA-51.
|
||||
try:
|
||||
assert_same_origin(polling_url, str(initial_response.request.url))
|
||||
except SSRFError as ssrf_err:
|
||||
raise BlackForestLabsError(
|
||||
status_code=502,
|
||||
message=f"Rejected polling URL: {ssrf_err}",
|
||||
)
|
||||
# Reject polling URLs that don't belong to BFL-controlled infrastructure.
|
||||
# BFL uses regional subdomains (e.g. gateway.bfl.ai) that differ from the
|
||||
# submission host (api.bfl.ai), so we validate against the registered
|
||||
# domain rather than doing a strict same-origin check. VERIA-51.
|
||||
assert_bfl_polling_url(polling_url)
|
||||
|
||||
# Get just the auth header for polling
|
||||
polling_headers = {"x-key": headers.get("x-key", "")}
|
||||
@ -428,16 +423,11 @@ class BlackForestLabsImageEdit:
|
||||
message="No polling_url in BFL response",
|
||||
)
|
||||
|
||||
# Reject cross-origin polling URLs — the ``x-key`` auth header
|
||||
# would otherwise leak to whatever URL the upstream returns.
|
||||
# VERIA-51.
|
||||
try:
|
||||
assert_same_origin(polling_url, str(initial_response.request.url))
|
||||
except SSRFError as ssrf_err:
|
||||
raise BlackForestLabsError(
|
||||
status_code=502,
|
||||
message=f"Rejected polling URL: {ssrf_err}",
|
||||
)
|
||||
# Reject polling URLs that don't belong to BFL-controlled infrastructure.
|
||||
# BFL uses regional subdomains (e.g. gateway.bfl.ai) that differ from the
|
||||
# submission host (api.bfl.ai), so we validate against the registered
|
||||
# domain rather than doing a strict same-origin check. VERIA-51.
|
||||
assert_bfl_polling_url(polling_url)
|
||||
|
||||
# Get just the auth header for polling
|
||||
polling_headers = {"x-key": headers.get("x-key", "")}
|
||||
|
||||
@ -15,7 +15,6 @@ import httpx
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
@ -29,6 +28,7 @@ from ..common_utils import (
|
||||
DEFAULT_MAX_POLLING_TIME,
|
||||
DEFAULT_POLLING_INTERVAL,
|
||||
BlackForestLabsError,
|
||||
assert_bfl_polling_url,
|
||||
)
|
||||
from .transformation import BlackForestLabsImageGenerationConfig
|
||||
|
||||
@ -172,6 +172,10 @@ class BlackForestLabsImageGeneration:
|
||||
raw_response=final_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=data,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
encoding=None,
|
||||
)
|
||||
|
||||
async def async_image_generation(
|
||||
@ -274,6 +278,10 @@ class BlackForestLabsImageGeneration:
|
||||
raw_response=final_response,
|
||||
model_response=model_response,
|
||||
logging_obj=logging_obj,
|
||||
request_data=data,
|
||||
optional_params=optional_params,
|
||||
litellm_params=litellm_params_dict,
|
||||
encoding=None,
|
||||
)
|
||||
|
||||
def _poll_for_result_sync(
|
||||
@ -318,16 +326,11 @@ class BlackForestLabsImageGeneration:
|
||||
message="No polling_url in BFL response",
|
||||
)
|
||||
|
||||
# Reject cross-origin polling URLs — the ``x-key`` auth header
|
||||
# would otherwise leak to whatever URL the upstream returns.
|
||||
# VERIA-51.
|
||||
try:
|
||||
assert_same_origin(polling_url, str(initial_response.request.url))
|
||||
except SSRFError as ssrf_err:
|
||||
raise BlackForestLabsError(
|
||||
status_code=502,
|
||||
message=f"Rejected polling URL: {ssrf_err}",
|
||||
)
|
||||
# Reject polling URLs that don't belong to BFL-controlled infrastructure.
|
||||
# BFL uses regional subdomains (e.g. gateway.bfl.ai) that differ from the
|
||||
# submission host (api.bfl.ai), so we validate against the registered
|
||||
# domain rather than doing a strict same-origin check. VERIA-51.
|
||||
assert_bfl_polling_url(polling_url)
|
||||
|
||||
# Get just the auth header for polling
|
||||
polling_headers = {"x-key": headers.get("x-key", "")}
|
||||
@ -414,16 +417,11 @@ class BlackForestLabsImageGeneration:
|
||||
message="No polling_url in BFL response",
|
||||
)
|
||||
|
||||
# Reject cross-origin polling URLs — the ``x-key`` auth header
|
||||
# would otherwise leak to whatever URL the upstream returns.
|
||||
# VERIA-51.
|
||||
try:
|
||||
assert_same_origin(polling_url, str(initial_response.request.url))
|
||||
except SSRFError as ssrf_err:
|
||||
raise BlackForestLabsError(
|
||||
status_code=502,
|
||||
message=f"Rejected polling URL: {ssrf_err}",
|
||||
)
|
||||
# Reject polling URLs that don't belong to BFL-controlled infrastructure.
|
||||
# BFL uses regional subdomains (e.g. gateway.bfl.ai) that differ from the
|
||||
# submission host (api.bfl.ai), so we validate against the registered
|
||||
# domain rather than doing a strict same-origin check. VERIA-51.
|
||||
assert_bfl_polling_url(polling_url)
|
||||
|
||||
# Get just the auth header for polling
|
||||
polling_headers = {"x-key": headers.get("x-key", "")}
|
||||
|
||||
@ -376,6 +376,13 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
|
||||
)
|
||||
|
||||
guardrailed_texts = guardrailed_inputs.get("texts", [])
|
||||
returned_tool_calls = guardrailed_inputs.get("tool_calls")
|
||||
guardrailed_tool_calls: List[Dict[str, Any]] = (
|
||||
cast(List[Dict[str, Any]], returned_tool_calls)
|
||||
if isinstance(returned_tool_calls, list)
|
||||
and len(returned_tool_calls) == len(tool_calls_to_check)
|
||||
else tool_calls_to_check
|
||||
)
|
||||
|
||||
# Step 3: Map guardrail responses back to original response structure
|
||||
if guardrailed_texts and texts_to_check:
|
||||
@ -386,10 +393,10 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
|
||||
)
|
||||
|
||||
# Step 4: Apply guardrailed tool calls back to response
|
||||
if tool_calls_to_check:
|
||||
if guardrailed_tool_calls:
|
||||
await self._apply_guardrail_responses_to_output_tool_calls(
|
||||
response=response,
|
||||
tool_calls=tool_calls_to_check,
|
||||
tool_calls=guardrailed_tool_calls,
|
||||
task_mappings=tool_call_task_mappings,
|
||||
)
|
||||
|
||||
@ -748,10 +755,11 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
|
||||
task_mappings: List[Tuple[int, int]],
|
||||
) -> None:
|
||||
"""
|
||||
Apply guardrailed tool calls back to output response.
|
||||
Apply guardrailed tool calls back to the output response.
|
||||
|
||||
The guardrail may have modified the tool_calls list in place,
|
||||
so we apply the modified tool calls back to the original response.
|
||||
The guardrail may return updated tool calls (either mutated in place or as
|
||||
a new list), so we apply the provided tool calls back to the original
|
||||
response.
|
||||
|
||||
Override this method to customize how tool call responses are applied.
|
||||
"""
|
||||
|
||||
@ -114,5 +114,14 @@
|
||||
"param_mappings": {
|
||||
"max_completion_tokens": "max_tokens"
|
||||
}
|
||||
},
|
||||
"tensormesh": {
|
||||
"base_url": "https://serverless.tensormesh.ai/v1",
|
||||
"api_key_env": "TENSORMESH_INFERENCE_API_KEY",
|
||||
"api_base_env": "TENSORMESH_SERVERLESS_BASE_URL",
|
||||
"base_class": "openai_gpt",
|
||||
"param_mappings": {
|
||||
"max_completion_tokens": "max_tokens"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
0
litellm/llms/watsonx/passthrough/__init__.py
Normal file
0
litellm/llms/watsonx/passthrough/__init__.py
Normal file
69
litellm/llms/watsonx/passthrough/transformation.py
Normal file
69
litellm/llms/watsonx/passthrough/transformation.py
Normal file
@ -0,0 +1,69 @@
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple
|
||||
|
||||
from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig
|
||||
from litellm.llms.watsonx.common_utils import IBMWatsonXMixin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from httpx import URL
|
||||
|
||||
|
||||
class WatsonxPassthroughConfig(IBMWatsonXMixin, BasePassthroughConfig):
    """
    Passthrough configuration for IBM watsonx.

    Combines the watsonx mixin (base URL and credential resolution) with
    the generic passthrough config interface.
    """

    def is_streaming_request(self, endpoint: str, request_data: dict) -> bool:
        """Return the request's `stream` value, defaulting to False when absent."""
        wants_stream = request_data.get("stream", False)
        return wants_stream

    def get_complete_url(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
        model: str,
        endpoint: str,
        request_query_params: Optional[dict],
        litellm_params: dict,
    ) -> Tuple["URL", str]:
        """
        Build the target URL for a passthrough call.

        The base URL is resolved through `get_api_base` and converted with
        `str`; `format_url` then combines it with `endpoint` and the
        caller's query parameters.

        NOTE(review): the previous docstring said this guarantees a
        `version` query parameter is always present. Nothing in this
        method adds one - presumably it comes from `request_query_params`
        or from `format_url`; confirm.

        Returns:
            A `(complete_url, base_target_url)` pair.
        """
        resolved_base = str(self.get_api_base(api_base))
        target_url = self.format_url(
            endpoint=endpoint,
            base_target_url=resolved_base,
            request_query_params=request_query_params,
        )
        return target_url, resolved_base

    @staticmethod
    def get_api_base(
        api_base: Optional[str] = None,
    ) -> Optional[str]:
        """Return `api_base` when truthy, otherwise the mixin's base URL."""
        if api_base:
            return api_base
        return IBMWatsonXMixin()._get_base_url(api_base=api_base)

    @staticmethod
    def get_api_key(
        api_key: Optional[str] = None,
    ) -> Optional[str]:
        """Return `api_key` when truthy, otherwise the key from the watsonx credentials."""
        if api_key:
            return api_key
        credentials = IBMWatsonXMixin.get_watsonx_credentials(
            optional_params={}, api_base=None, api_key=api_key
        )
        return credentials["api_key"]

    @staticmethod
    def get_base_model(model: str) -> Optional[str]:
        """Return the model id unchanged."""
        return model

    def get_models(
        self, api_key: Optional[str] = None, api_base: Optional[str] = None
    ) -> List[str]:
        """Delegate model listing to the parent implementation."""
        return super().get_models(api_key, api_base)
|
||||
@ -577,7 +577,10 @@
|
||||
"max_tokens": 8192,
|
||||
"mode": "embedding",
|
||||
"output_cost_per_token": 0.0,
|
||||
"output_vector_size": 1024
|
||||
"output_vector_size": 1024,
|
||||
"provider_specific_entry": {
|
||||
"bedrock_invocation_schema": "titan_v2"
|
||||
}
|
||||
},
|
||||
"amazon.titan-image-generator-v1": {
|
||||
"input_cost_per_image": 0.0,
|
||||
@ -8899,15 +8902,16 @@
|
||||
"cache_creation_input_token_cost": 3.75e-07
|
||||
},
|
||||
"bedrock/us-gov-east-1/anthropic.claude-sonnet-4-5-20250929-v1:0": {
|
||||
"cache_creation_input_token_cost": 4.125e-06,
|
||||
"cache_read_input_token_cost": 3.3e-07,
|
||||
"input_cost_per_token": 3.3e-06,
|
||||
"cache_creation_input_token_cost": 4.5e-06,
|
||||
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
|
||||
"cache_read_input_token_cost": 3.6e-07,
|
||||
"input_cost_per_token": 3.6e-06,
|
||||
"litellm_provider": "bedrock",
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 8192,
|
||||
"max_tokens": 8192,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1.65e-05,
|
||||
"output_cost_per_token": 1.8e-05,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_computer_use": true,
|
||||
"supports_function_calling": true,
|
||||
@ -8920,15 +8924,16 @@
|
||||
"supports_native_structured_output": true
|
||||
},
|
||||
"bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0": {
|
||||
"cache_creation_input_token_cost": 4.125e-06,
|
||||
"cache_read_input_token_cost": 3.3e-07,
|
||||
"input_cost_per_token": 3.3e-06,
|
||||
"cache_creation_input_token_cost": 4.5e-06,
|
||||
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
|
||||
"cache_read_input_token_cost": 3.6e-07,
|
||||
"input_cost_per_token": 3.6e-06,
|
||||
"litellm_provider": "bedrock",
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 8192,
|
||||
"max_tokens": 8192,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1.65e-05,
|
||||
"output_cost_per_token": 1.8e-05,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_computer_use": true,
|
||||
"supports_function_calling": true,
|
||||
@ -9072,15 +9077,16 @@
|
||||
"cache_creation_input_token_cost": 3.75e-07
|
||||
},
|
||||
"bedrock/us-gov-west-1/anthropic.claude-sonnet-4-5-20250929-v1:0": {
|
||||
"cache_creation_input_token_cost": 4.125e-06,
|
||||
"cache_read_input_token_cost": 3.3e-07,
|
||||
"input_cost_per_token": 3.3e-06,
|
||||
"cache_creation_input_token_cost": 4.5e-06,
|
||||
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
|
||||
"cache_read_input_token_cost": 3.6e-07,
|
||||
"input_cost_per_token": 3.6e-06,
|
||||
"litellm_provider": "bedrock",
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 8192,
|
||||
"max_tokens": 8192,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1.65e-05,
|
||||
"output_cost_per_token": 1.8e-05,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_computer_use": true,
|
||||
"supports_function_calling": true,
|
||||
@ -9093,15 +9099,16 @@
|
||||
"supports_native_structured_output": true
|
||||
},
|
||||
"bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0": {
|
||||
"cache_creation_input_token_cost": 4.125e-06,
|
||||
"cache_read_input_token_cost": 3.3e-07,
|
||||
"input_cost_per_token": 3.3e-06,
|
||||
"cache_creation_input_token_cost": 4.5e-06,
|
||||
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
|
||||
"cache_read_input_token_cost": 3.6e-07,
|
||||
"input_cost_per_token": 3.6e-06,
|
||||
"litellm_provider": "bedrock",
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 8192,
|
||||
"max_tokens": 8192,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1.65e-05,
|
||||
"output_cost_per_token": 1.8e-05,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_computer_use": true,
|
||||
"supports_function_calling": true,
|
||||
@ -31858,19 +31865,21 @@
|
||||
"supports_native_structured_output": true
|
||||
},
|
||||
"us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0": {
|
||||
"cache_creation_input_token_cost": 4.125e-06,
|
||||
"cache_read_input_token_cost": 3.3e-07,
|
||||
"input_cost_per_token": 3.3e-06,
|
||||
"input_cost_per_token_above_200k_tokens": 6.6e-06,
|
||||
"output_cost_per_token_above_200k_tokens": 2.475e-05,
|
||||
"cache_creation_input_token_cost_above_200k_tokens": 8.25e-06,
|
||||
"cache_read_input_token_cost_above_200k_tokens": 6.6e-07,
|
||||
"cache_creation_input_token_cost": 4.5e-06,
|
||||
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
|
||||
"cache_read_input_token_cost": 3.6e-07,
|
||||
"input_cost_per_token": 3.6e-06,
|
||||
"input_cost_per_token_above_200k_tokens": 7.2e-06,
|
||||
"output_cost_per_token_above_200k_tokens": 2.7e-05,
|
||||
"cache_creation_input_token_cost_above_200k_tokens": 9.0e-06,
|
||||
"cache_creation_input_token_cost_above_1hr_above_200k_tokens": 1.44e-05,
|
||||
"cache_read_input_token_cost_above_200k_tokens": 7.2e-07,
|
||||
"litellm_provider": "bedrock_converse",
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 64000,
|
||||
"max_tokens": 64000,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1.65e-05,
|
||||
"output_cost_per_token": 1.8e-05,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_computer_use": true,
|
||||
"supports_function_calling": true,
|
||||
@ -41448,6 +41457,7 @@
|
||||
},
|
||||
"bedrock/us-gov-east-1/anthropic.claude-haiku-4-5-20251001-v1:0": {
|
||||
"cache_creation_input_token_cost": 1.5e-06,
|
||||
"cache_creation_input_token_cost_above_1hr": 2.4e-06,
|
||||
"cache_read_input_token_cost": 1.2e-07,
|
||||
"input_cost_per_token": 1.2e-06,
|
||||
"litellm_provider": "bedrock",
|
||||
@ -41470,6 +41480,7 @@
|
||||
},
|
||||
"bedrock/us-gov-west-1/anthropic.claude-haiku-4-5-20251001-v1:0": {
|
||||
"cache_creation_input_token_cost": 1.5e-06,
|
||||
"cache_creation_input_token_cost_above_1hr": 2.4e-06,
|
||||
"cache_read_input_token_cost": 1.2e-07,
|
||||
"input_cost_per_token": 1.2e-06,
|
||||
"litellm_provider": "bedrock",
|
||||
|
||||
@ -419,6 +419,7 @@ class LiteLLMRoutes(enum.Enum):
|
||||
"/vllm",
|
||||
"/mistral",
|
||||
"/milvus",
|
||||
"/watsonx",
|
||||
]
|
||||
|
||||
#########################################################
|
||||
@ -3901,7 +3902,9 @@ class LiteLLM_TeamMembership(LiteLLMPydanticObjectBase):
|
||||
# Union so Pydantic picks Full when data has server-managed fields
|
||||
# (/team/info) and Base when callers/tests construct with only
|
||||
# user-settable fields.
|
||||
litellm_budget_table: Optional[Union[LiteLLM_BudgetTableFull, LiteLLM_BudgetTable]]
|
||||
litellm_budget_table: Optional[
|
||||
Union[LiteLLM_BudgetTableFull, LiteLLM_BudgetTable]
|
||||
] = None
|
||||
|
||||
def safe_get_team_member_rpm_limit(self) -> Optional[int]:
|
||||
if self.litellm_budget_table is not None:
|
||||
|
||||
@ -328,7 +328,24 @@ class ContentFilterGuardrail(CustomGuardrail):
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _resolve_category_file_path(file_path: str) -> str:
|
||||
def _assert_within_categories_dir(path: str, categories_dir: str) -> None:
|
||||
"""Raise ValueError if path escapes the categories directory."""
|
||||
resolved = os.path.realpath(path)
|
||||
allowed = os.path.realpath(categories_dir)
|
||||
try:
|
||||
common = os.path.commonpath([resolved, allowed])
|
||||
except ValueError:
|
||||
# commonpath() raises ValueError on Windows when paths span different drives
|
||||
raise ValueError(
|
||||
f"Category file path '{path}' is outside the allowed categories directory"
|
||||
)
|
||||
if common != allowed:
|
||||
raise ValueError(
|
||||
f"Category file path '{path}' is outside the allowed "
|
||||
f"categories directory '{categories_dir}'"
|
||||
)
|
||||
|
||||
def _resolve_category_file_path(self, file_path: str) -> str:
|
||||
"""
|
||||
Resolve a category file path that may be relative.
|
||||
|
||||
@ -339,12 +356,17 @@ class ContentFilterGuardrail(CustomGuardrail):
|
||||
file isn't found.
|
||||
|
||||
Resolution order:
|
||||
1. Return as-is if absolute or already exists.
|
||||
2. Try joining the full path relative to this module's directory.
|
||||
1. Return as-is if absolute or already exists (jailed to module dir).
|
||||
2. Try joining the full path relative to this module's directory (jailed).
|
||||
3. Progressively strip leading path components and try each suffix
|
||||
relative to this module's directory (handles paths like
|
||||
"litellm/proxy/.../policy_templates/file.yaml" by finding the
|
||||
"policy_templates/file.yaml" suffix that exists).
|
||||
relative to this module's directory (jailed).
|
||||
|
||||
The directory jail can be disabled for deployments that legitimately
|
||||
store category files outside the package (e.g. mounted volumes) by
|
||||
setting the environment variable
|
||||
``LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS=true``. Use only in
|
||||
trusted environments where the proxy configuration cannot be influenced
|
||||
by untrusted input.
|
||||
|
||||
Args:
|
||||
file_path: The file path to resolve (absolute or relative).
|
||||
@ -352,15 +374,33 @@ class ContentFilterGuardrail(CustomGuardrail):
|
||||
Returns:
|
||||
The resolved absolute-ish path, or the original path if
|
||||
resolution fails (caller should check existence).
|
||||
"""
|
||||
if os.path.isabs(file_path) or os.path.exists(file_path):
|
||||
return file_path
|
||||
|
||||
Raises:
|
||||
ValueError: If the resolved path escapes the module directory
|
||||
and ``LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS`` is not set.
|
||||
"""
|
||||
module_dir = os.path.dirname(__file__)
|
||||
allow_external = (
|
||||
os.environ.get("LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS", "").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
if os.path.isabs(file_path) or os.path.exists(file_path):
|
||||
if not allow_external:
|
||||
self._assert_within_categories_dir(file_path, module_dir)
|
||||
else:
|
||||
verbose_proxy_logger.warning(
|
||||
"LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS is set — "
|
||||
"skipping directory jail for category_file '%s'",
|
||||
file_path,
|
||||
)
|
||||
return file_path
|
||||
|
||||
# Try the full relative path joined to the module directory
|
||||
candidate = os.path.join(module_dir, file_path)
|
||||
if os.path.exists(candidate):
|
||||
if not allow_external:
|
||||
self._assert_within_categories_dir(candidate, module_dir)
|
||||
return candidate
|
||||
|
||||
# Progressively strip leading components to find a matching suffix
|
||||
@ -369,8 +409,17 @@ class ContentFilterGuardrail(CustomGuardrail):
|
||||
suffix = os.path.join(*parts[i:])
|
||||
candidate = os.path.join(module_dir, suffix)
|
||||
if os.path.exists(candidate):
|
||||
if not allow_external:
|
||||
self._assert_within_categories_dir(candidate, module_dir)
|
||||
return candidate
|
||||
|
||||
# File not found via any resolution strategy — jail the module-relative
|
||||
# path anyway to reject traversal attempts (e.g. "../../../../etc/passwd")
|
||||
# regardless of CWD or whether the target file exists.
|
||||
if not allow_external:
|
||||
self._assert_within_categories_dir(
|
||||
os.path.join(module_dir, file_path), module_dir
|
||||
)
|
||||
return file_path
|
||||
|
||||
def _load_categories(self, categories: List[ContentFilterCategoryConfig]) -> None:
|
||||
@ -395,6 +444,13 @@ class ContentFilterGuardrail(CustomGuardrail):
|
||||
)
|
||||
continue
|
||||
|
||||
# Prevent path traversal via category_name (e.g. "../../etc/passwd")
|
||||
if not re.match(r"^[a-zA-Z0-9_\-]+$", category_name):
|
||||
verbose_proxy_logger.warning(
|
||||
f"Category name '{category_name}' contains invalid characters, skipping"
|
||||
)
|
||||
continue
|
||||
|
||||
enabled = cat_config.get("enabled", True)
|
||||
action = cat_config.get("action")
|
||||
severity_threshold = (
|
||||
@ -411,7 +467,13 @@ class ContentFilterGuardrail(CustomGuardrail):
|
||||
|
||||
# Load category file (custom or default)
|
||||
if custom_file:
|
||||
category_file_path = self._resolve_category_file_path(custom_file)
|
||||
try:
|
||||
category_file_path = self._resolve_category_file_path(custom_file)
|
||||
except ValueError as e:
|
||||
verbose_proxy_logger.warning(
|
||||
f"Category {category_name}: invalid category_file path, skipping. {e}"
|
||||
)
|
||||
continue
|
||||
else:
|
||||
# Try .yaml first, then .json (e.g. harm_toxic_abuse.json)
|
||||
yaml_path = os.path.join(categories_dir, f"{category_name}.yaml")
|
||||
|
||||
@ -140,7 +140,12 @@ class PanwPrismaAirsHandler(CustomGuardrail):
|
||||
)
|
||||
|
||||
self.fallback_on_error = fallback_on_error
|
||||
self.timeout = timeout
|
||||
# Coerce defensively. The dashboard UI persists this field as a JSON
|
||||
# string, and Pydantic extras (the path that splats model_dump into
|
||||
# this handler) preserve whatever type the user supplied. A string
|
||||
# value would otherwise reach httpx, which raises TypeError on its
|
||||
# internal '<=' comparison and surfaces as a misleading api_error.
|
||||
self.timeout = float(timeout) if timeout is not None else 10.0
|
||||
|
||||
# Tri-state: None = not set (default-on for Anthropic), True = explicit on, False = explicit off
|
||||
self.experimental_use_latest_role_message_only: Optional[bool] = kwargs.get(
|
||||
|
||||
@ -0,0 +1,34 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from litellm.types.guardrails import SupportedGuardrailIntegrations
|
||||
|
||||
from .vigil_guard import VigilGuardGuardrail
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.types.guardrails import Guardrail, LitellmParams
|
||||
|
||||
|
||||
def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"):
    """Create the Vigil Guard guardrail and register it as a LiteLLM callback.

    Args:
        litellm_params: Guardrail parameters; ``api_base``, ``api_key``,
            ``unreachable_fallback``, ``timeout``, ``mode`` and ``default_on``
            are read from it.
        guardrail: Guardrail definition; only ``guardrail_name`` is read.

    Returns:
        The constructed ``VigilGuardGuardrail`` instance.
    """
    import litellm

    callback_kwargs = dict(
        api_base=litellm_params.api_base,
        api_key=litellm_params.api_key,
        unreachable_fallback=litellm_params.unreachable_fallback,
        timeout=litellm_params.timeout,
        guardrail_name=guardrail.get("guardrail_name", ""),
        event_hook=litellm_params.mode,
        default_on=litellm_params.default_on,
    )
    vigil_guard = VigilGuardGuardrail(**callback_kwargs)
    litellm.logging_callback_manager.add_litellm_callback(vigil_guard)
    return vigil_guard
|
||||
|
||||
|
||||
# Integration name -> factory that builds and registers the guardrail.
guardrail_initializer_registry = {
    SupportedGuardrailIntegrations.VIGIL_GUARD.value: initialize_guardrail
}

# Integration name -> guardrail class.
guardrail_class_registry = {
    SupportedGuardrailIntegrations.VIGIL_GUARD.value: VigilGuardGuardrail
}
|
||||
@ -0,0 +1,485 @@
|
||||
from json import JSONDecodeError
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Awaitable,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Protocol,
|
||||
Tuple,
|
||||
Type,
|
||||
cast,
|
||||
)
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.exceptions import GuardrailRaisedException
|
||||
from litellm.exceptions import Timeout as LiteLLMTimeout
|
||||
from litellm.integrations.custom_guardrail import (
|
||||
CustomGuardrail,
|
||||
log_guardrail_information,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
Logging as LiteLLMLoggingObj,
|
||||
)
|
||||
from litellm.types.proxy.guardrails.guardrail_hooks.base import (
|
||||
GuardrailConfigModel,
|
||||
)
|
||||
|
||||
|
||||
# Path appended to the configured api_base for every analysis request.
_ANALYZE_ENDPOINT = "/v1/guard/analyze"
# Timeout used when the caller does not configure one.
_DEFAULT_VIGIL_TIMEOUT = httpx.Timeout(timeout=10.0, connect=5.0)
# Upper bound on the block reason surfaced in the raised exception.
_BLOCK_REASON_MAX_CHARS = 500
# Bounds applied to metadata values before they are forwarded.
_METADATA_STRING_MAX_CHARS = 500
_METADATA_ARRAY_MAX_ITEMS = 10
# The only decisions accepted from the analysis response.
_VALID_DECISIONS = ("ALLOWED", "SANITIZED", "BLOCKED")
# HTTP status codes that qualify a failed request for the single retry.
_TRANSIENT_STATUS_CODES = frozenset((429, 502, 503, 504))
# Request / metadata fields that may be forwarded with an analysis request.
_METADATA_ALLOWLIST = (
    "model",
    "model_group",
    "provider",
    "region",
    "deployment",
    "user",
    "user_id",
    "session_id",
    "conversation_id",
    "request_id",
    "tenant_id",
    "org_id",
)

# Policy applied when the backend fails or answers with an unknown decision.
_FallbackMode = Literal["fail_closed", "fail_open"]
|
||||
|
||||
|
||||
class _AsyncPostHandler(Protocol):
    """Structural type for the async HTTP client used by the guardrail.

    Any object exposing a keyword-only ``post`` that yields an awaitable
    ``httpx.Response`` satisfies it, which allows a stand-in handler to be
    injected through the guardrail constructor.
    """

    def post(
        self,
        *,
        url: str,
        headers: Dict[str, str],
        json: Dict[str, Any],
        timeout: httpx.Timeout,
    ) -> Awaitable[httpx.Response]:
        """Send a POST request and return an awaitable response."""
        ...
|
||||
|
||||
|
||||
class VigilGuardMissingConfig(ValueError):
    """Raised when a required Vigil Guard setting (api_base or api_key) is absent."""
|
||||
|
||||
|
||||
class VigilGuardGuardrail(CustomGuardrail):
|
||||
def __init__(
    self,
    api_base: Optional[str] = None,
    api_key: Optional[str] = None,
    unreachable_fallback: Optional[str] = None,
    timeout: Optional[float] = None,
    async_handler: Optional[_AsyncPostHandler] = None,
    **kwargs: Any,
) -> None:
    """Configure the Vigil Guard guardrail.

    Args:
        api_base: Backend base URL; falls back to the ``VIGIL_GUARD_URL``
            secret. A trailing slash is stripped.
        api_key: Bearer token; falls back to the ``VIGIL_GUARD_API_KEY``
            secret.
        unreachable_fallback: ``"fail_open"`` (case-insensitive) selects
            fail-open; any other value, or ``None``, selects ``"fail_closed"``.
        timeout: Overall request timeout in seconds. ``None`` selects the
            module default. The connect timeout is capped at 5 seconds.
        async_handler: Optional HTTP client override; a shared async client
            is obtained when omitted.
        **kwargs: Forwarded to ``CustomGuardrail.__init__``.

    Raises:
        VigilGuardMissingConfig: If no api_base or api_key can be resolved.
        ValueError: If ``timeout`` cannot be converted to a float.
    """
    resolved_base = api_base or get_secret_str("VIGIL_GUARD_URL")
    if not resolved_base:
        raise VigilGuardMissingConfig(
            "Vigil Guard api_base is required. Set api_base in the guardrail "
            "config or the VIGIL_GUARD_URL environment variable."
        )
    self.api_base = resolved_base.rstrip("/")

    resolved_key = api_key or get_secret_str("VIGIL_GUARD_API_KEY")
    if not resolved_key:
        raise VigilGuardMissingConfig(
            "Vigil Guard api_key is required. Set api_key in the guardrail "
            "config or the VIGIL_GUARD_API_KEY environment variable."
        )
    self.api_key = resolved_key

    fallback = (unreachable_fallback or "fail_closed").lower()
    self.unreachable_fallback: _FallbackMode = (
        "fail_open" if fallback == "fail_open" else "fail_closed"
    )

    resolved_timeout = _DEFAULT_VIGIL_TIMEOUT
    if timeout is not None:
        # Coerce defensively: a timeout that arrives as a string (e.g. from
        # persisted JSON config) would make min() raise TypeError when it
        # compares str with float.
        total_seconds = float(timeout)
        resolved_timeout = httpx.Timeout(
            total_seconds, connect=min(total_seconds, 5.0)
        )
    self.timeout: httpx.Timeout = resolved_timeout

    self.async_handler: _AsyncPostHandler = async_handler or get_async_httpx_client(
        llm_provider=httpxSpecialProvider.GuardrailCallback,
    )

    # Default to both hooks unless the caller restricted them explicitly.
    if "supported_event_hooks" not in kwargs:
        kwargs["supported_event_hooks"] = [
            GuardrailEventHooks.pre_call,
            GuardrailEventHooks.post_call,
        ]

    super().__init__(**kwargs)
|
||||
|
||||
@staticmethod
def get_config_model() -> Optional[Type["GuardrailConfigModel"]]:
    """Return the config model class that describes this guardrail's settings."""
    # Imported at call time rather than at module import.
    from litellm.types.proxy.guardrails.guardrail_hooks.vigil_guard import (
        VigilGuardGuardrailConfigModel as config_model,
    )

    return config_model
|
||||
|
||||
@log_guardrail_information
async def apply_guardrail(
    self,
    inputs: GenericGuardrailAPIInputs,
    request_data: dict,
    input_type: Literal["request", "response"],
    logging_obj: Optional["LiteLLMLoggingObj"] = None,
) -> GenericGuardrailAPIInputs:
    """Scan texts (and, on responses, tool-call arguments) with Vigil Guard.

    Every non-blank text is analysed individually. For ``"response"`` inputs
    each non-blank tool-call ``arguments`` string is analysed as well.

    Args:
        inputs: Guardrail inputs; ``texts`` and ``tool_calls`` are read.
        request_data: Original request payload, used for metadata only.
        input_type: ``"request"`` scans as ``user_input``; anything else
            scans as ``model_output``.
        logging_obj: Optional logging object, used to find the call id.

    Returns:
        ``inputs`` itself when there is nothing to scan, otherwise the
        result of ``_build_output``.

    Raises:
        GuardrailRaisedException: On a ``BLOCKED`` decision, or on a backend
            failure / unrecognized decision under the ``fail_closed`` policy.
    """
    texts = inputs.get("texts") or []
    has_text = any(isinstance(text, str) and text.strip() for text in texts)
    tool_call_args = (
        self._tool_call_arguments(inputs.get("tool_calls"))
        if input_type == "response"
        else []
    )
    if not has_text and not tool_call_args:
        return inputs

    source = "user_input" if input_type == "request" else "model_output"
    metadata = self._collect_metadata(request_data, logging_obj)

    result_texts: List[Any] = []
    for index, text in enumerate(texts):
        if not isinstance(text, str) or not text.strip():
            # Non-string / blank entries are passed through untouched.
            result_texts.append(text)
            continue

        replacement, early_output = await self._scan_and_resolve(
            content=text,
            source=source,
            metadata=metadata,
            inputs=inputs,
            # Fail-open snapshot: texts handled so far keep their result and
            # the remainder (this one included) pass through unscanned.
            fallback_texts=result_texts + list(texts[index:]),
            fallback_tool_calls=inputs.get("tool_calls"),
        )
        if early_output is not None:
            return early_output
        result_texts.append(replacement)

    result_tool_calls = inputs.get("tool_calls")
    for tc_index, arguments in tool_call_args:
        replacement, early_output = await self._scan_and_resolve(
            content=arguments,
            source=source,
            metadata=metadata,
            inputs=inputs,
            fallback_texts=result_texts,
            fallback_tool_calls=result_tool_calls,
        )
        if early_output is not None:
            return early_output
        if replacement != arguments:
            result_tool_calls = self._set_tool_call_arguments(
                result_tool_calls, tc_index, replacement
            )

    return self._build_output(inputs, result_texts, result_tool_calls)

async def _scan_and_resolve(
    self,
    content: str,
    source: str,
    metadata: Dict[str, Any],
    inputs: GenericGuardrailAPIInputs,
    fallback_texts: List[Any],
    fallback_tool_calls: Any,
) -> Tuple[Optional[str], Optional[GenericGuardrailAPIInputs]]:
    """Analyse one piece of content and apply the decision / failure policy.

    Returns:
        ``(replacement, None)`` when processing should continue, where
        ``replacement`` is the sanitized text or ``content`` unchanged; or
        ``(None, output)`` when the fail-open policy ends the scan early and
        ``output`` must be returned to the caller as-is.

    Raises:
        GuardrailRaisedException: On ``BLOCKED``, or under ``fail_closed``
            when the backend fails or returns an unrecognized decision.
    """
    try:
        analysis = await self._analyze(
            text=content, source=source, metadata=metadata
        )
    except (
        httpx.HTTPError,
        LiteLLMTimeout,
        JSONDecodeError,
        OSError,
    ) as exc:
        return None, self._handle_backend_failure(
            exc, inputs, source, fallback_texts, fallback_tool_calls
        )

    decision = analysis.get("decision") if isinstance(analysis, dict) else None
    if decision not in _VALID_DECISIONS:
        verbose_proxy_logger.error(
            "Vigil Guard unrecognized decision for guardrail_name=%s "
            "source=%s: %r",
            self.guardrail_name,
            source,
            decision,
        )
        if self.unreachable_fallback == "fail_open":
            return None, self._build_output(
                inputs, fallback_texts, fallback_tool_calls
            )
        raise GuardrailRaisedException(
            guardrail_name=self.guardrail_name,
            message="Vigil Guard returned an unrecognized decision.",
            should_wrap_with_default_message=False,
        )

    if decision == "BLOCKED":
        raise GuardrailRaisedException(
            guardrail_name=self.guardrail_name,
            message=self._build_block_reason(analysis),
            should_wrap_with_default_message=False,
        )

    if decision == "SANITIZED":
        return self._resolve_sanitized_text(content, analysis), None
    return content, None
|
||||
|
||||
def _handle_backend_failure(
    self,
    exc: Exception,
    inputs: GenericGuardrailAPIInputs,
    source: str,
    final_texts: List[Any],
    final_tool_calls: Any,
) -> GenericGuardrailAPIInputs:
    """Apply the configured fallback policy after a backend failure.

    Under ``fail_open`` the failure is logged and the supplied texts and
    tool calls are returned through ``_build_output``. Otherwise the
    failure is logged and the request is blocked.

    Raises:
        GuardrailRaisedException: When the policy is not ``fail_open``;
            chained from ``exc``.
    """
    failure_detail = str(exc)
    if self.unreachable_fallback != "fail_open":
        verbose_proxy_logger.error(
            "Vigil Guard backend failure with fail_closed; blocking request. "
            "guardrail_name=%s source=%s error=%s",
            self.guardrail_name,
            source,
            failure_detail,
        )
        raise GuardrailRaisedException(
            guardrail_name=self.guardrail_name,
            message="Vigil Guard backend unreachable; request blocked by fail_closed policy.",
            should_wrap_with_default_message=False,
        ) from exc

    verbose_proxy_logger.error(
        "Vigil Guard backend failure with fail_open; allowing request "
        "unscanned. guardrail_name=%s source=%s error=%s",
        self.guardrail_name,
        source,
        failure_detail,
    )
    return self._build_output(inputs, final_texts, final_tool_calls)
|
||||
|
||||
@staticmethod
def _build_output(
    inputs: GenericGuardrailAPIInputs,
    final_texts: List[Any],
    final_tool_calls: Any,
) -> GenericGuardrailAPIInputs:
    """Assemble the guardrail result from the processed texts and tool calls.

    If neither the texts nor the tool calls differ from ``inputs``, a shallow
    copy of ``inputs`` is returned so the result has the input's shape.
    Otherwise only ``texts``, plus ``images`` / ``tools`` when present and
    ``tool_calls`` when they changed, are returned; every other key
    (e.g. structured_messages) is dropped so a stale, unsanitized payload
    cannot be forwarded.
    """
    texts_differ = final_texts != (inputs.get("texts") or [])
    tool_calls_differ = final_tool_calls != inputs.get("tool_calls")

    if texts_differ or tool_calls_differ:
        guardrailed: GenericGuardrailAPIInputs = {"texts": final_texts}
        if "images" in inputs:
            guardrailed["images"] = inputs["images"]
        if "tools" in inputs:
            guardrailed["tools"] = inputs["tools"]
        if tool_calls_differ:
            guardrailed["tool_calls"] = final_tool_calls
        return guardrailed

    # Nothing changed: hand back the input shape verbatim.
    return cast(GenericGuardrailAPIInputs, dict(inputs))
|
||||
|
||||
@staticmethod
|
||||
def _tool_call_arguments(tool_calls: Any) -> List[Tuple[int, str]]:
|
||||
pairs: List[Tuple[int, str]] = []
|
||||
if isinstance(tool_calls, list):
|
||||
for index, tool_call in enumerate(tool_calls):
|
||||
function = (
|
||||
tool_call.get("function") if isinstance(tool_call, dict) else None
|
||||
)
|
||||
arguments = (
|
||||
function.get("arguments") if isinstance(function, dict) else None
|
||||
)
|
||||
if isinstance(arguments, str) and arguments.strip():
|
||||
pairs.append((index, arguments))
|
||||
return pairs
|
||||
|
||||
@staticmethod
|
||||
def _set_tool_call_arguments(
|
||||
tool_calls: Any, index: int, arguments: str
|
||||
) -> List[Any]:
|
||||
updated = list(tool_calls)
|
||||
tool_call = dict(updated[index])
|
||||
function = dict(tool_call.get("function") or {})
|
||||
function["arguments"] = arguments
|
||||
tool_call["function"] = function
|
||||
updated[index] = tool_call
|
||||
return updated
|
||||
|
||||
async def _analyze(
    self, text: str, source: str, metadata: Dict[str, Any]
) -> Dict[str, Any]:
    """POST one piece of content to the analyze endpoint and return its JSON.

    Args:
        text: Content to analyse.
        source: Origin label sent with the request.
        metadata: Metadata forwarded with the request.

    Returns:
        The decoded JSON body of the response.
    """
    reply = await self._post_with_retry(
        f"{self.api_base}{_ANALYZE_ENDPOINT}",
        {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        },
        {
            "text": text,
            "source": source,
            "mode": "full",
            "metadata": metadata,
        },
    )
    return reply.json()
|
||||
|
||||
async def _post_with_retry(
    self, endpoint: str, headers: Dict[str, str], payload: Dict[str, Any]
) -> httpx.Response:
    """POST ``payload`` and retry exactly once after a transient failure.

    A first-attempt failure that ``_is_transient`` accepts is logged at debug
    level and the request is sent a second time. Any other failure, and any
    failure of the second attempt, propagates to the caller.

    Returns:
        The response, after ``raise_for_status`` has passed.
    """

    async def send_once() -> httpx.Response:
        reply = await self.async_handler.post(
            url=endpoint,
            headers=headers,
            json=payload,
            timeout=self.timeout,
        )
        reply.raise_for_status()
        return reply

    try:
        return await send_once()
    except Exception as first_error:
        if not self._is_transient(first_error):
            raise
        verbose_proxy_logger.debug(
            "Vigil Guard transient failure; retrying once: %s",
            type(first_error).__name__,
        )

    # Second and final attempt; errors are not intercepted here.
    return await send_once()
|
||||
|
||||
@staticmethod
def _is_transient(exc: Exception) -> bool:
    """Tell whether ``exc`` is a failure worth one retry.

    HTTP status errors qualify when their status code is in
    ``_TRANSIENT_STATUS_CODES``; otherwise only connection errors, connect /
    read timeouts, remote protocol errors and LiteLLM timeouts qualify.
    """
    if not isinstance(exc, httpx.HTTPStatusError):
        retryable_types = (
            httpx.ConnectError,
            httpx.ConnectTimeout,
            httpx.ReadTimeout,
            httpx.RemoteProtocolError,
            LiteLLMTimeout,
        )
        return isinstance(exc, retryable_types)
    return exc.response.status_code in _TRANSIENT_STATUS_CODES
|
||||
|
||||
@staticmethod
|
||||
def _build_block_reason(analysis: Dict[str, Any]) -> str:
|
||||
for key in ("blockMessage", "decisionReason"):
|
||||
value = analysis.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()[:_BLOCK_REASON_MAX_CHARS]
|
||||
categories = analysis.get("categories")
|
||||
if isinstance(categories, list):
|
||||
names = [c for c in categories if isinstance(c, str) and c.strip()]
|
||||
if names:
|
||||
return ", ".join(names)[:_BLOCK_REASON_MAX_CHARS]
|
||||
return "Blocked by policy"
|
||||
|
||||
@staticmethod
|
||||
def _resolve_sanitized_text(original: str, analysis: Dict[str, Any]) -> str:
|
||||
for key in ("sanitizedText", "outputText"):
|
||||
value = analysis.get(key)
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return original
|
||||
|
||||
def _collect_metadata(
    self, request_data: dict, logging_obj: Optional["LiteLLMLoggingObj"]
) -> Dict[str, Any]:
    """Gather the allowlisted metadata forwarded with each analysis request.

    Each field in ``_METADATA_ALLOWLIST`` is looked up in ``request_data``,
    then in its ``metadata`` and ``litellm_metadata`` dicts; the first
    value that survives ``_clamp_metadata_value`` wins. The call id, when
    one can be found, is added under ``litellm_call_id``.
    """
    lookup_order: List[dict] = []
    if isinstance(request_data, dict):
        lookup_order.append(request_data)
        lookup_order.extend(
            nested
            for nested in (
                request_data.get("metadata"),
                request_data.get("litellm_metadata"),
            )
            if isinstance(nested, dict)
        )

    collected: Dict[str, Any] = {}
    for field in _METADATA_ALLOWLIST:
        for candidate_source in lookup_order:
            raw_value = candidate_source.get(field)
            if raw_value is None:
                continue
            clamped = self._clamp_metadata_value(raw_value)
            if clamped is None:
                # Unusable here; a later source may still supply the field.
                continue
            collected[field] = clamped
            break

    call_id = self._extract_call_id(request_data, logging_obj)
    if call_id:
        collected["litellm_call_id"] = call_id
    return collected
|
||||
|
||||
@staticmethod
|
||||
def _clamp_metadata_value(value: Any) -> Any:
|
||||
if isinstance(value, bool):
|
||||
return None
|
||||
if isinstance(value, str):
|
||||
return value[:_METADATA_STRING_MAX_CHARS]
|
||||
if isinstance(value, (int, float)):
|
||||
return value
|
||||
if isinstance(value, list):
|
||||
clamped: List[Any] = []
|
||||
for item in value[:_METADATA_ARRAY_MAX_ITEMS]:
|
||||
if isinstance(item, bool):
|
||||
continue
|
||||
if isinstance(item, str):
|
||||
clamped.append(item[:_METADATA_STRING_MAX_CHARS])
|
||||
elif isinstance(item, (int, float)):
|
||||
clamped.append(item)
|
||||
return clamped or None
|
||||
return None
|
||||
|
||||
@staticmethod
def _extract_call_id(
    request_data: dict, logging_obj: Optional["LiteLLMLoggingObj"]
) -> Optional[str]:
    """Return the first non-empty string ``litellm_call_id`` that can be found.

    Lookup order: the ``litellm_call_id`` attribute of ``logging_obj``, then
    ``request_data["litellm_call_id"]``, then
    ``request_data["metadata"]["litellm_call_id"]``. Non-string or empty
    values are skipped; ``None`` is returned when nothing qualifies.
    """

    def _usable(candidate: object) -> bool:
        return isinstance(candidate, str) and bool(candidate)

    if logging_obj is not None:
        from_logger = getattr(logging_obj, "litellm_call_id", None)
        if _usable(from_logger):
            return from_logger

    if not isinstance(request_data, dict):
        return None

    top_level = request_data.get("litellm_call_id")
    if _usable(top_level):
        return top_level

    metadata = request_data.get("metadata")
    nested = metadata.get("litellm_call_id") if isinstance(metadata, dict) else None
    return nested if _usable(nested) else None
|
||||
@ -217,7 +217,15 @@ def initialize_panw_prisma_airs(litellm_params, guardrail):
|
||||
mask_response_content=getattr(litellm_params, "mask_response_content", False),
|
||||
app_name=getattr(litellm_params, "app_name", None),
|
||||
fallback_on_error=getattr(litellm_params, "fallback_on_error", "block"),
|
||||
timeout=float(getattr(litellm_params, "timeout", 10.0)),
|
||||
# `timeout` is now declared on BaseLitellmParams (Optional[float] = None),
|
||||
# so the attribute always exists. The Pydantic validator on LitellmParams
|
||||
# coerces strings to float, but None still means "use handler default" —
|
||||
# guard against float(None) here.
|
||||
timeout=(
|
||||
float(getattr(litellm_params, "timeout", None))
|
||||
if getattr(litellm_params, "timeout", None) is not None
|
||||
else 10.0
|
||||
),
|
||||
violation_message_template=litellm_params.violation_message_template,
|
||||
)
|
||||
litellm.logging_callback_manager.add_litellm_callback(_panw_callback)
|
||||
|
||||
@ -2433,3 +2433,89 @@ def create_generic_websocket_passthrough_endpoint(
|
||||
_forward_headers=forward_headers,
|
||||
cost_per_request=cost_per_request,
|
||||
)
|
||||
|
||||
|
||||
@router.api_route(
    "/watsonx/{endpoint:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH"],
    tags=["Watsonx Pass-through", "pass-through"],
)
async def watsonx_proxy_route(
    endpoint: str,
    request: Request,
    fastapi_response: Response,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Watsonx pass-through endpoint.
    Allows using Watsonx APIs with automatic IAM token management and version parameter injection.

    Example:
        POST /watsonx/ml/v1/text/tokenization
        POST /watsonx/ml/v1/text/generation
    """
    # Imported locally, as in the original block.
    from litellm.types.utils import LlmProviders
    from litellm.utils import ProviderConfigManager

    watsonx_config = ProviderConfigManager.get_provider_passthrough_config(
        provider=LlmProviders.WATSONX,
        model="",
    )
    if watsonx_config is None:
        raise HTTPException(
            status_code=404, detail="Watsonx passthrough config not found"
        )

    # Resolve the upstream URL; the second element of the returned pair is
    # not needed here.
    target_url, _ = watsonx_config.get_complete_url(
        api_base=None,
        api_key=None,
        model="",
        endpoint=endpoint,
        request_query_params=None,
        litellm_params={},
    )

    # Headers produced by the provider config are forwarded as custom headers.
    upstream_headers = watsonx_config.validate_environment(
        headers={},
        model="",
        messages=[],
        optional_params={},
        litellm_params={},
        api_key=None,
        api_base=None,
    )

    # Only POST bodies are inspected for a truthy "stream" flag.
    wants_stream = False
    if request.method == "POST":
        content_type = request.headers.get("content-type", "")
        if "multipart/form-data" in content_type:
            payload = await get_form_data(request)
        else:
            payload = await request.json()
        wants_stream = bool(payload.get("stream"))

    # Forward the caller's query string, filling in the default API version
    # when the caller did not supply one.
    forwarded_query = dict(request.query_params)
    if forwarded_query.get("version") is None:
        forwarded_query["version"] = litellm.WATSONX_DEFAULT_API_VERSION

    passthrough = create_pass_through_route(
        endpoint=endpoint,
        target=str(target_url),
        custom_headers=upstream_headers,
        is_streaming_request=wants_stream,
        custom_llm_provider="watsonx",
        query_params=forwarded_query,
    )

    return await passthrough(request, fastapi_response, user_api_key_dict)
|
||||
|
||||
@ -36,9 +36,18 @@ router = APIRouter()
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def spend_key_fn():
|
||||
async def spend_key_fn(
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
View all keys created, ordered by spend
|
||||
View keys created, ordered by spend.
|
||||
|
||||
- Admin callers (PROXY_ADMIN / PROXY_ADMIN_VIEW_ONLY) see every key in
|
||||
the database.
|
||||
- All other callers (INTERNAL_USER / INTERNAL_USER_VIEW_ONLY, etc.) are
|
||||
scoped to keys they own (``user_id == caller``). A caller with no
|
||||
``user_id`` has no scope and receives an empty list rather than the
|
||||
full table.
|
||||
|
||||
Example Request:
|
||||
```
|
||||
@ -55,8 +64,17 @@ async def spend_key_fn():
|
||||
"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
|
||||
)
|
||||
|
||||
key_info = await prisma_client.get_data(table_name="key", query_type="find_all")
|
||||
return key_info
|
||||
if _is_admin_view_safe(user_api_key_dict=user_api_key_dict):
|
||||
return await prisma_client.get_data(table_name="key", query_type="find_all")
|
||||
|
||||
caller_user_id = user_api_key_dict.user_id
|
||||
if not caller_user_id:
|
||||
return []
|
||||
return await prisma_client.get_data(
|
||||
table_name="key",
|
||||
query_type="find_all",
|
||||
user_id=caller_user_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
@ -85,9 +103,19 @@ async def spend_user_fn(
|
||||
default=None,
|
||||
description="Get User Table row for user_id",
|
||||
),
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
View all users created, ordered by spend
|
||||
View users created, ordered by spend.
|
||||
|
||||
- Admin callers (PROXY_ADMIN / PROXY_ADMIN_VIEW_ONLY) see every user, or
|
||||
a specific user when ``user_id`` is supplied.
|
||||
- All other callers may only read their own row. If they supply a
|
||||
``user_id`` query parameter that does not match their authenticated
|
||||
``user_id`` the request is rejected with HTTP 403; supplying their
|
||||
own id (or none at all) returns just their row. A caller with no
|
||||
``user_id`` on their key has no scope and receives an empty list
|
||||
rather than the full table.
|
||||
|
||||
Example Request:
|
||||
```
|
||||
@ -109,6 +137,17 @@ async def spend_user_fn(
|
||||
"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
|
||||
)
|
||||
|
||||
if not _is_admin_view_safe(user_api_key_dict=user_api_key_dict):
|
||||
caller_user_id = user_api_key_dict.user_id
|
||||
if not caller_user_id:
|
||||
return []
|
||||
if user_id is not None and user_id != caller_user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail={"error": "Not authorized to view spend for another user."},
|
||||
)
|
||||
user_id = caller_user_id
|
||||
|
||||
if user_id is not None:
|
||||
user_info = await prisma_client.get_data(
|
||||
table_name="user", query_type="find_unique", user_id=user_id
|
||||
@ -123,6 +162,8 @@ async def spend_user_fn(
|
||||
_strip_password_from_users(result)
|
||||
return result
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
|
||||
@ -41,6 +41,9 @@ from litellm.types.proxy.guardrails.guardrail_hooks.hiddenlayer import (
|
||||
from litellm.types.proxy.guardrails.guardrail_hooks.qohash import (
|
||||
QostodianNexusConfigModel,
|
||||
)
|
||||
from litellm.types.proxy.guardrails.guardrail_hooks.vigil_guard import (
|
||||
VigilGuardGuardrailConfigModel,
|
||||
)
|
||||
|
||||
"""
|
||||
Pydantic object defining how to set guardrails on litellm proxy
|
||||
@ -103,6 +106,7 @@ class SupportedGuardrailIntegrations(Enum):
|
||||
LLM_AS_A_JUDGE = "llm_as_a_judge"
|
||||
QOSTODIAN_NEXUS = "qostodian_nexus"
|
||||
RUBRIK = "rubrik"
|
||||
VIGIL_GUARD = "vigil_guard"
|
||||
|
||||
|
||||
class Role(Enum):
|
||||
@ -758,6 +762,15 @@ class BaseLitellmParams(
|
||||
description="Python-like code containing the apply_guardrail function for custom guardrail logic",
|
||||
)
|
||||
|
||||
timeout: Optional[float] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Per-request timeout for the guardrail provider API call (seconds). "
|
||||
"Accepts int, float, or numeric string; coerced to float on load. "
|
||||
"Each guardrail handler chooses its own default when unset."
|
||||
),
|
||||
)
|
||||
|
||||
model_config = ConfigDict(extra="allow", protected_namespaces=())
|
||||
|
||||
|
||||
@ -791,6 +804,7 @@ class LitellmParams(
|
||||
BlockCodeExecutionGuardrailConfigModel,
|
||||
HiddenlayerGuardrailConfigModel,
|
||||
QostodianNexusConfigModel,
|
||||
VigilGuardGuardrailConfigModel,
|
||||
):
|
||||
guardrail: str = Field(description="The type of guardrail integration to use")
|
||||
mode: Union[str, List[str], Mode] = Field(
|
||||
@ -814,6 +828,18 @@ class LitellmParams(
|
||||
return [x.lower() if isinstance(x, str) else x for x in v]
|
||||
return v
|
||||
|
||||
@field_validator("timeout", mode="before", check_fields=False)
@classmethod
def coerce_timeout(cls, v):
    """Normalise a raw ``timeout`` before field validation runs.

    The dashboard UI sends JSON strings, so string values are accepted here
    and converted to ``float`` before any handler reads them. ``None`` and
    the empty string both mean "unset" and come back as ``None``.

    Raises:
        ValueError: if the value cannot be converted with ``float()``.
    """
    is_unset = v is None or v == ""
    if is_unset:
        return None
    try:
        seconds = float(v)
    except (TypeError, ValueError) as exc:
        raise ValueError(f"timeout must be numeric, got {v!r}") from exc
    return seconds
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
default_on = kwargs.pop("default_on", None)
|
||||
if default_on is not None:
|
||||
|
||||
@ -238,6 +238,9 @@ DEFINED_PROMETHEUS_METRICS = Literal[
|
||||
"litellm_cache_hits_metric",
|
||||
"litellm_cache_misses_metric",
|
||||
"litellm_cached_tokens_metric",
|
||||
# Provider prompt-caching metrics (e.g. OpenAI/Anthropic/Bedrock/Gemini)
|
||||
"litellm_provider_cache_read_input_tokens_metric",
|
||||
"litellm_provider_cache_creation_input_tokens_metric",
|
||||
"litellm_deployment_tpm_limit",
|
||||
"litellm_deployment_rpm_limit",
|
||||
"litellm_remaining_api_key_requests_for_model",
|
||||
@ -655,6 +658,10 @@ class PrometheusMetricLabels:
|
||||
litellm_cache_misses_metric = _cache_metric_labels
|
||||
litellm_cached_tokens_metric = _cache_metric_labels
|
||||
|
||||
# Provider prompt-caching metrics - track tokens read/written to provider caches
|
||||
litellm_provider_cache_read_input_tokens_metric = _cache_metric_labels
|
||||
litellm_provider_cache_creation_input_tokens_metric = _cache_metric_labels
|
||||
|
||||
# Metrics whose emission paths supply org context (used by get_labels)
|
||||
_org_label_metrics: ClassVar[frozenset] = frozenset(
|
||||
{
|
||||
@ -672,7 +679,6 @@ class PrometheusMetricLabels:
|
||||
"litellm_output_tokens_metric",
|
||||
}
|
||||
)
|
||||
|
||||
# Managed batch metrics
|
||||
_batch_user_labels = [
|
||||
UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value,
|
||||
|
||||
@ -0,0 +1,26 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from .base import GuardrailConfigModel
|
||||
|
||||
|
||||
class VigilGuardGuardrailConfigModel(GuardrailConfigModel):
    # Connection settings for the Vigil Guard guardrail. Both fields default
    # to None; each description names the environment variable used as the
    # fallback. (Kept as a comment rather than a class docstring so the
    # generated model schema is left untouched.)

    # Base URL of the Vigil Guard API.
    api_base: Optional[str] = Field(
        default=None,
        description="Vigil Guard API base URL. Falls back to the VIGIL_GUARD_URL environment variable.",
    )

    # Credential for the Vigil Guard API.
    api_key: Optional[str] = Field(
        default=None,
        description="Vigil Guard API key. Falls back to the VIGIL_GUARD_API_KEY environment variable.",
    )

    @staticmethod
    def ui_friendly_name() -> str:
        """Return the display name for this guardrail: ``"Vigil Guard"``."""
        return "Vigil Guard"
|
||||
@ -3365,6 +3365,7 @@ class LlmProviders(str, Enum):
|
||||
POE = "poe"
|
||||
CHUTES = "chutes"
|
||||
XIAOMI_MIMO = "xiaomi_mimo"
|
||||
TENSORMESH = "tensormesh"
|
||||
LITELLM_AGENT = "litellm_agent"
|
||||
CURSOR = "cursor"
|
||||
BEDROCK_MANTLE = "bedrock_mantle"
|
||||
|
||||
@ -8985,6 +8985,12 @@ class ProviderConfigManager:
|
||||
)
|
||||
|
||||
return AzurePassthroughConfig()
|
||||
elif LlmProviders.WATSONX == provider:
|
||||
from litellm.llms.watsonx.passthrough.transformation import (
|
||||
WatsonxPassthroughConfig,
|
||||
)
|
||||
|
||||
return WatsonxPassthroughConfig()
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
|
||||
@ -577,7 +577,10 @@
|
||||
"max_tokens": 8192,
|
||||
"mode": "embedding",
|
||||
"output_cost_per_token": 0.0,
|
||||
"output_vector_size": 1024
|
||||
"output_vector_size": 1024,
|
||||
"provider_specific_entry": {
|
||||
"bedrock_invocation_schema": "titan_v2"
|
||||
}
|
||||
},
|
||||
"amazon.titan-image-generator-v1": {
|
||||
"input_cost_per_image": 0.0,
|
||||
@ -8899,15 +8902,16 @@
|
||||
"cache_creation_input_token_cost": 3.75e-07
|
||||
},
|
||||
"bedrock/us-gov-east-1/anthropic.claude-sonnet-4-5-20250929-v1:0": {
|
||||
"cache_creation_input_token_cost": 4.125e-06,
|
||||
"cache_read_input_token_cost": 3.3e-07,
|
||||
"input_cost_per_token": 3.3e-06,
|
||||
"cache_creation_input_token_cost": 4.5e-06,
|
||||
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
|
||||
"cache_read_input_token_cost": 3.6e-07,
|
||||
"input_cost_per_token": 3.6e-06,
|
||||
"litellm_provider": "bedrock",
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 8192,
|
||||
"max_tokens": 8192,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1.65e-05,
|
||||
"output_cost_per_token": 1.8e-05,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_computer_use": true,
|
||||
"supports_function_calling": true,
|
||||
@ -8920,15 +8924,16 @@
|
||||
"supports_native_structured_output": true
|
||||
},
|
||||
"bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0": {
|
||||
"cache_creation_input_token_cost": 4.125e-06,
|
||||
"cache_read_input_token_cost": 3.3e-07,
|
||||
"input_cost_per_token": 3.3e-06,
|
||||
"cache_creation_input_token_cost": 4.5e-06,
|
||||
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
|
||||
"cache_read_input_token_cost": 3.6e-07,
|
||||
"input_cost_per_token": 3.6e-06,
|
||||
"litellm_provider": "bedrock",
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 8192,
|
||||
"max_tokens": 8192,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1.65e-05,
|
||||
"output_cost_per_token": 1.8e-05,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_computer_use": true,
|
||||
"supports_function_calling": true,
|
||||
@ -9072,15 +9077,16 @@
|
||||
"cache_creation_input_token_cost": 3.75e-07
|
||||
},
|
||||
"bedrock/us-gov-west-1/anthropic.claude-sonnet-4-5-20250929-v1:0": {
|
||||
"cache_creation_input_token_cost": 4.125e-06,
|
||||
"cache_read_input_token_cost": 3.3e-07,
|
||||
"input_cost_per_token": 3.3e-06,
|
||||
"cache_creation_input_token_cost": 4.5e-06,
|
||||
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
|
||||
"cache_read_input_token_cost": 3.6e-07,
|
||||
"input_cost_per_token": 3.6e-06,
|
||||
"litellm_provider": "bedrock",
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 8192,
|
||||
"max_tokens": 8192,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1.65e-05,
|
||||
"output_cost_per_token": 1.8e-05,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_computer_use": true,
|
||||
"supports_function_calling": true,
|
||||
@ -9093,15 +9099,16 @@
|
||||
"supports_native_structured_output": true
|
||||
},
|
||||
"bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0": {
|
||||
"cache_creation_input_token_cost": 4.125e-06,
|
||||
"cache_read_input_token_cost": 3.3e-07,
|
||||
"input_cost_per_token": 3.3e-06,
|
||||
"cache_creation_input_token_cost": 4.5e-06,
|
||||
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
|
||||
"cache_read_input_token_cost": 3.6e-07,
|
||||
"input_cost_per_token": 3.6e-06,
|
||||
"litellm_provider": "bedrock",
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 8192,
|
||||
"max_tokens": 8192,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1.65e-05,
|
||||
"output_cost_per_token": 1.8e-05,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_computer_use": true,
|
||||
"supports_function_calling": true,
|
||||
@ -31733,19 +31740,21 @@
|
||||
"supports_native_structured_output": true
|
||||
},
|
||||
"us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0": {
|
||||
"cache_creation_input_token_cost": 4.125e-06,
|
||||
"cache_read_input_token_cost": 3.3e-07,
|
||||
"input_cost_per_token": 3.3e-06,
|
||||
"input_cost_per_token_above_200k_tokens": 6.6e-06,
|
||||
"output_cost_per_token_above_200k_tokens": 2.475e-05,
|
||||
"cache_creation_input_token_cost_above_200k_tokens": 8.25e-06,
|
||||
"cache_read_input_token_cost_above_200k_tokens": 6.6e-07,
|
||||
"cache_creation_input_token_cost": 4.5e-06,
|
||||
"cache_creation_input_token_cost_above_1hr": 7.2e-06,
|
||||
"cache_read_input_token_cost": 3.6e-07,
|
||||
"input_cost_per_token": 3.6e-06,
|
||||
"input_cost_per_token_above_200k_tokens": 7.2e-06,
|
||||
"output_cost_per_token_above_200k_tokens": 2.7e-05,
|
||||
"cache_creation_input_token_cost_above_200k_tokens": 9.0e-06,
|
||||
"cache_creation_input_token_cost_above_1hr_above_200k_tokens": 1.44e-05,
|
||||
"cache_read_input_token_cost_above_200k_tokens": 7.2e-07,
|
||||
"litellm_provider": "bedrock_converse",
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 64000,
|
||||
"max_tokens": 64000,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1.65e-05,
|
||||
"output_cost_per_token": 1.8e-05,
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_computer_use": true,
|
||||
"supports_function_calling": true,
|
||||
@ -41332,6 +41341,7 @@
|
||||
},
|
||||
"bedrock/us-gov-east-1/anthropic.claude-haiku-4-5-20251001-v1:0": {
|
||||
"cache_creation_input_token_cost": 1.5e-06,
|
||||
"cache_creation_input_token_cost_above_1hr": 2.4e-06,
|
||||
"cache_read_input_token_cost": 1.2e-07,
|
||||
"input_cost_per_token": 1.2e-06,
|
||||
"litellm_provider": "bedrock",
|
||||
@ -41354,6 +41364,7 @@
|
||||
},
|
||||
"bedrock/us-gov-west-1/anthropic.claude-haiku-4-5-20251001-v1:0": {
|
||||
"cache_creation_input_token_cost": 1.5e-06,
|
||||
"cache_creation_input_token_cost_above_1hr": 2.4e-06,
|
||||
"cache_read_input_token_cost": 1.2e-07,
|
||||
"input_cost_per_token": 1.2e-06,
|
||||
"litellm_provider": "bedrock",
|
||||
|
||||
@ -2079,6 +2079,24 @@
|
||||
"a2a": false
|
||||
}
|
||||
},
|
||||
"tensormesh": {
|
||||
"display_name": "Tensormesh (`tensormesh`)",
|
||||
"url": "https://docs.litellm.ai/docs/providers/tensormesh",
|
||||
"endpoints": {
|
||||
"chat_completions": true,
|
||||
"messages": true,
|
||||
"responses": false,
|
||||
"embeddings": false,
|
||||
"image_generations": false,
|
||||
"audio_transcriptions": false,
|
||||
"audio_speech": false,
|
||||
"moderations": false,
|
||||
"batches": false,
|
||||
"rerank": false,
|
||||
"a2a": false,
|
||||
"text_completion": true
|
||||
}
|
||||
},
|
||||
"text-completion-codestral": {
|
||||
"display_name": "Text Completion Codestral (`text-completion-codestral`)",
|
||||
"url": "https://docs.litellm.ai/docs/providers/codestral",
|
||||
|
||||
@ -0,0 +1,69 @@
|
||||
"""Tests for FocusTransformer — ConsumedQuantity / PricingQuantity correctness."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
import polars as pl
|
||||
|
||||
from litellm.integrations.focus.transformer import FocusTransformer
|
||||
|
||||
|
||||
def _base_row(**overrides) -> dict:
    """Return a sample spend row with ``overrides`` applied on top.

    Keys given in ``overrides`` replace the defaults in place; unknown keys
    are appended after the default ones.
    """
    defaults = {
        "date": "2026-05-25",
        "user_id": "u1",
        "api_key": "sk-test",
        "api_key_alias": "my-key",
        "model": "gpt-4o",
        "model_group": "openai",
        "custom_llm_provider": "openai",
        "spend": 0.05,
        "api_requests": 3,
        "team_id": "team1",
        "team_alias": "Engineering",
        "user_email": "user@example.com",
    }
    return {**defaults, **overrides}
|
||||
|
||||
|
||||
def _transform(rows: list[dict]) -> pl.DataFrame:
    """Build a polars frame from ``rows`` and run it through ``FocusTransformer``.

    ``infer_schema_length=None`` is passed to the frame constructor, as in
    every test in this module.
    """
    source = pl.DataFrame(rows, infer_schema_length=None)
    transformer = FocusTransformer()
    return transformer.transform(source)
|
||||
|
||||
|
||||
def test_consumed_quantity_reflects_api_requests():
    """ConsumedQuantity equals the row's ``api_requests`` value."""
    frame = _transform([_base_row(api_requests=7)])
    consumed = frame["ConsumedQuantity"][0]
    assert consumed == Decimal("7.000000")
|
||||
|
||||
|
||||
def test_pricing_quantity_reflects_api_requests():
    """PricingQuantity equals the row's ``api_requests`` value."""
    frame = _transform([_base_row(api_requests=7)])
    priced = frame["PricingQuantity"][0]
    assert priced == Decimal("7.000000")
|
||||
|
||||
|
||||
def test_null_api_requests_falls_back_to_zero_not_one():
    """A NULL ``api_requests`` (old-schema rows) must yield 0, not 1."""
    frame = _transform([_base_row(api_requests=None)])
    zero = Decimal("0.000000")
    assert frame["ConsumedQuantity"][0] == zero
    assert frame["PricingQuantity"][0] == zero
|
||||
|
||||
|
||||
def test_zero_api_requests_stays_zero():
    """An explicit ``api_requests`` of 0 is reported as 0 in both columns."""
    frame = _transform([_base_row(api_requests=0)])
    zero = Decimal("0.000000")
    assert frame["ConsumedQuantity"][0] == zero
    assert frame["PricingQuantity"][0] == zero
|
||||
|
||||
|
||||
def test_bigint_api_requests_cast_correctly():
    """``api_requests`` arrives from Postgres as BigInt; large counts must not overflow."""
    frame = _transform([_base_row(api_requests=1_000_000)])
    one_million = Decimal("1000000.000000")
    assert frame["ConsumedQuantity"][0] == one_million
    assert frame["PricingQuantity"][0] == one_million
|
||||
|
||||
|
||||
def test_consumed_and_pricing_quantity_match():
    """ConsumedQuantity and PricingQuantity are always the same value."""
    frame = _transform([_base_row(api_requests=42)])
    consumed = frame["ConsumedQuantity"][0]
    priced = frame["PricingQuantity"][0]
    assert consumed == priced
|
||||
@ -35,6 +35,8 @@ class TestPrometheusCacheMetrics:
|
||||
assert "litellm_cache_hits_metric" in defined_metrics
|
||||
assert "litellm_cache_misses_metric" in defined_metrics
|
||||
assert "litellm_cached_tokens_metric" in defined_metrics
|
||||
assert "litellm_provider_cache_read_input_tokens_metric" in defined_metrics
|
||||
assert "litellm_provider_cache_creation_input_tokens_metric" in defined_metrics
|
||||
|
||||
def test_cache_metric_labels_defined(self):
|
||||
"""Test that cache metric labels are properly defined"""
|
||||
@ -44,6 +46,13 @@ class TestPrometheusCacheMetrics:
|
||||
assert hasattr(PrometheusMetricLabels, "litellm_cache_hits_metric")
|
||||
assert hasattr(PrometheusMetricLabels, "litellm_cache_misses_metric")
|
||||
assert hasattr(PrometheusMetricLabels, "litellm_cached_tokens_metric")
|
||||
assert hasattr(
|
||||
PrometheusMetricLabels, "litellm_provider_cache_read_input_tokens_metric"
|
||||
)
|
||||
assert hasattr(
|
||||
PrometheusMetricLabels,
|
||||
"litellm_provider_cache_creation_input_tokens_metric",
|
||||
)
|
||||
|
||||
# Verify labels include expected keys
|
||||
expected_labels = [
|
||||
@ -59,6 +68,14 @@ class TestPrometheusCacheMetrics:
|
||||
assert label in PrometheusMetricLabels.litellm_cache_hits_metric
|
||||
assert label in PrometheusMetricLabels.litellm_cache_misses_metric
|
||||
assert label in PrometheusMetricLabels.litellm_cached_tokens_metric
|
||||
assert (
|
||||
label
|
||||
in PrometheusMetricLabels.litellm_provider_cache_read_input_tokens_metric
|
||||
)
|
||||
assert (
|
||||
label
|
||||
in PrometheusMetricLabels.litellm_provider_cache_creation_input_tokens_metric
|
||||
)
|
||||
|
||||
def test_increment_cache_metrics_on_cache_hit(self, sample_enum_values):
|
||||
"""Test that cache hit increments the correct metrics"""
|
||||
@ -76,12 +93,20 @@ class TestPrometheusCacheMetrics:
|
||||
"completion_tokens": 50,
|
||||
"model_group": "openai",
|
||||
"request_tags": [],
|
||||
"metadata": {
|
||||
"usage_object": {
|
||||
"cache_read_input_tokens": 25,
|
||||
"cache_creation_input_tokens": 10,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Create mock metrics
|
||||
mock_logger.litellm_cache_hits_metric = MagicMock()
|
||||
mock_logger.litellm_cache_misses_metric = MagicMock()
|
||||
mock_logger.litellm_cached_tokens_metric = MagicMock()
|
||||
mock_logger.litellm_provider_cache_read_input_tokens_metric = MagicMock()
|
||||
mock_logger.litellm_provider_cache_creation_input_tokens_metric = MagicMock()
|
||||
mock_logger.get_labels_for_metric = MagicMock(
|
||||
return_value=[
|
||||
"model",
|
||||
@ -114,6 +139,14 @@ class TestPrometheusCacheMetrics:
|
||||
# Verify cache misses metric was NOT called
|
||||
mock_logger.litellm_cache_misses_metric.labels.assert_not_called()
|
||||
|
||||
# Verify provider prompt caching metrics were incremented
|
||||
mock_logger.litellm_provider_cache_read_input_tokens_metric.labels().inc.assert_called_once_with(
|
||||
25
|
||||
)
|
||||
mock_logger.litellm_provider_cache_creation_input_tokens_metric.labels().inc.assert_called_once_with(
|
||||
10
|
||||
)
|
||||
|
||||
def test_increment_cache_metrics_on_cache_miss(self, sample_enum_values):
|
||||
"""Test that cache miss increments the correct metrics"""
|
||||
# Create mock for PrometheusLogger instance
|
||||
@ -129,12 +162,20 @@ class TestPrometheusCacheMetrics:
|
||||
"completion_tokens": 50,
|
||||
"model_group": "openai",
|
||||
"request_tags": [],
|
||||
"metadata": {
|
||||
"usage_object": {
|
||||
# Explicit provider field absent -> fallback should use prompt_tokens_details.cached_tokens
|
||||
"prompt_tokens_details": {"cached_tokens": 20},
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Create mock metrics
|
||||
mock_logger.litellm_cache_hits_metric = MagicMock()
|
||||
mock_logger.litellm_cache_misses_metric = MagicMock()
|
||||
mock_logger.litellm_cached_tokens_metric = MagicMock()
|
||||
mock_logger.litellm_provider_cache_read_input_tokens_metric = MagicMock()
|
||||
mock_logger.litellm_provider_cache_creation_input_tokens_metric = MagicMock()
|
||||
mock_logger.get_labels_for_metric = MagicMock(
|
||||
return_value=[
|
||||
"model",
|
||||
@ -162,6 +203,61 @@ class TestPrometheusCacheMetrics:
|
||||
mock_logger.litellm_cache_hits_metric.labels.assert_not_called()
|
||||
mock_logger.litellm_cached_tokens_metric.labels.assert_not_called()
|
||||
|
||||
# Provider prompt caching metrics should still be emitted
|
||||
mock_logger.litellm_provider_cache_read_input_tokens_metric.labels().inc.assert_called_once_with(
|
||||
20
|
||||
)
|
||||
mock_logger.litellm_provider_cache_creation_input_tokens_metric.labels.assert_not_called()
|
||||
|
||||
def test_provider_cache_read_does_not_fallback_on_explicit_zero(
    self, sample_enum_values
):
    """An explicit ``cache_read_input_tokens`` of 0 must not fall back to ``cached_tokens``."""
    from litellm.integrations.prometheus import PrometheusLogger

    # The provider field is present but zero, alongside a non-zero
    # prompt_tokens_details.cached_tokens that must be ignored.
    usage = {
        "cache_read_input_tokens": 0,
        "prompt_tokens_details": {"cached_tokens": 20},
    }
    payload = {
        "cache_hit": False,
        "total_tokens": 100,
        "prompt_tokens": 50,
        "completion_tokens": 50,
        "model_group": "openai",
        "request_tags": [],
        "metadata": {"usage_object": usage},
    }

    logger_double = MagicMock()
    for metric_name in (
        "litellm_cache_hits_metric",
        "litellm_cache_misses_metric",
        "litellm_cached_tokens_metric",
        "litellm_provider_cache_read_input_tokens_metric",
        "litellm_provider_cache_creation_input_tokens_metric",
    ):
        setattr(logger_double, metric_name, MagicMock())
    logger_double.get_labels_for_metric = MagicMock(
        return_value=[
            "model",
            "hashed_api_key",
            "api_key_alias",
            "team",
            "team_alias",
            "end_user",
            "user",
        ]
    )

    PrometheusLogger._increment_cache_metrics(
        logger_double,
        standard_logging_payload=payload,
        enum_values=sample_enum_values,
    )

    # Should not emit read metric, because explicit provider value is zero.
    logger_double.litellm_provider_cache_read_input_tokens_metric.labels.assert_not_called()
|
||||
|
||||
def test_increment_cache_metrics_when_cache_hit_is_none(self, sample_enum_values):
|
||||
"""Test that no metrics are incremented when cache_hit is None"""
|
||||
# Create mock for PrometheusLogger instance
|
||||
@ -177,12 +273,19 @@ class TestPrometheusCacheMetrics:
|
||||
"completion_tokens": 50,
|
||||
"model_group": "openai",
|
||||
"request_tags": [],
|
||||
"metadata": {
|
||||
"usage_object": {
|
||||
"cache_read_input_tokens": 25,
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
# Create mock metrics
|
||||
mock_logger.litellm_cache_hits_metric = MagicMock()
|
||||
mock_logger.litellm_cache_misses_metric = MagicMock()
|
||||
mock_logger.litellm_cached_tokens_metric = MagicMock()
|
||||
mock_logger.litellm_provider_cache_read_input_tokens_metric = MagicMock()
|
||||
mock_logger.litellm_provider_cache_creation_input_tokens_metric = MagicMock()
|
||||
mock_logger.get_labels_for_metric = MagicMock(
|
||||
return_value=[
|
||||
"model",
|
||||
@ -207,6 +310,12 @@ class TestPrometheusCacheMetrics:
|
||||
mock_logger.litellm_cache_misses_metric.labels.assert_not_called()
|
||||
mock_logger.litellm_cached_tokens_metric.labels.assert_not_called()
|
||||
|
||||
# Provider prompt caching metrics should still be emitted
|
||||
mock_logger.litellm_provider_cache_read_input_tokens_metric.labels().inc.assert_called_once_with(
|
||||
25
|
||||
)
|
||||
mock_logger.litellm_provider_cache_creation_input_tokens_metric.labels.assert_not_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
43
tests/test_litellm/litellm_core_utils/test_fallback_utils.py
Normal file
43
tests/test_litellm/litellm_core_utils/test_fallback_utils.py
Normal file
@ -0,0 +1,43 @@
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.fallback_utils import async_completion_with_fallbacks
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fallback_dict_not_mutated(monkeypatch):
    """A fallback dict passed by the caller must survive repeated use unchanged."""
    shared_fallback = {"model": "fallback-model", "temperature": 0.2}
    snapshot = dict(shared_fallback)
    call_log: list[str] = []

    async def _fake_acompletion(*, model: str, **kwargs):
        # Record every attempt; only the primary model fails.
        call_log.append(model)
        if model == "primary-model":
            raise Exception("primary failed")
        return {"model": model, "temperature": kwargs.get("temperature")}

    monkeypatch.setattr(litellm, "acompletion", _fake_acompletion)

    # Two identical rounds re-using the same dict object: each time the
    # primary fails, the fallback answers, and the dict stays as it was.
    for _ in range(2):
        response = await async_completion_with_fallbacks(
            model="primary-model",
            kwargs={"fallbacks": [shared_fallback]},
        )
        assert response["model"] == "fallback-model"
        assert shared_fallback == snapshot

    assert call_log == ["primary-model", "fallback-model"] * 2
|
||||
@ -0,0 +1,3 @@
|
||||
{"recordId": "embed-1", "modelInput": {"inputText": "Hello world"}}
|
||||
{"recordId": "embed-2", "modelInput": {"inputText": "Another document to embed", "dimensions": 512}}
|
||||
{"recordId": "embed-3", "modelInput": {"inputText": "Single element list", "embeddingTypes": ["binary"]}}
|
||||
@ -0,0 +1,3 @@
|
||||
{"custom_id": "embed-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "bedrock/amazon.titan-embed-text-v2:0", "input": "Hello world"}}
|
||||
{"custom_id": "embed-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "bedrock/amazon.titan-embed-text-v2:0", "input": "Another document to embed", "dimensions": 512}}
|
||||
{"custom_id": "embed-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "bedrock/amazon.titan-embed-text-v2:0", "input": ["Single element list"], "encoding_format": "base64"}}
|
||||
@ -426,7 +426,7 @@ class TestBedrockFilesTransformation:
|
||||
"s3_bucket_name": "litellm-batch-352026",
|
||||
"s3_region_name": "us-gov-west-1",
|
||||
}
|
||||
# aws_region_name set to something different — s3_region_name must still win
|
||||
# aws_region_name set to something different - s3_region_name must still win
|
||||
optional_params = {"aws_region_name": "us-east-1"}
|
||||
|
||||
captured_optional_params: dict = {}
|
||||
@ -482,3 +482,630 @@ class TestBedrockFilesTransformation:
|
||||
assert "messages" in model_input
|
||||
assert "max_tokens" in model_input
|
||||
assert model_input["max_tokens"] == 10
|
||||
|
||||
|
||||
class TestBedrockFilesEmbeddingTransformation:
    """
    Covers the conversion of OpenAI ``/v1/embeddings`` batch JSONL records into
    the ``modelInput`` body that AWS Bedrock's CreateModelInvocationJob expects,
    via the Titan v2 transformer.

    Only Titan v2 is in scope here on purpose - other embedding models are
    meant to get their own follow-up PRs/tests so each schema is exercised in
    isolation.
    """

    _TITAN_V2 = "bedrock/amazon.titan-embed-text-v2:0"
    _CLAUDE_HAIKU = "bedrock/us.anthropic.claude-haiku-4-5-20251001-v1:0"

    @staticmethod
    def _files_config():
        """Return the ``BedrockFilesConfig`` class, imported at call time."""
        from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

        return BedrockFilesConfig

    @staticmethod
    def _embeddings_record(body, custom_id="e1"):
        """Build one batch record addressed to ``/v1/embeddings`` with ``body``."""
        return {
            "custom_id": custom_id,
            "method": "POST",
            "url": "/v1/embeddings",
            "body": body,
        }

    def _transform(self, records):
        """Run ``records`` through a freshly constructed ``BedrockFilesConfig``."""
        config = self._files_config()()
        return config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
            records
        )

    def test_titan_v2_embedding_jsonl_matches_fixture(self):
        """The input fixture must transform into exactly the expected fixture."""
        import json
        import os

        config = self._files_config()()
        fixture_dir = os.path.dirname(__file__)

        def read_jsonl(filename):
            # Blank lines in the fixture are ignored.
            with open(os.path.join(fixture_dir, filename)) as handle:
                return [json.loads(row) for row in handle if row.strip()]

        openai_jsonl = read_jsonl("input_batch_embeddings.jsonl")
        expected = read_jsonl("expected_bedrock_batch_embeddings.jsonl")

        result = config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
            openai_jsonl
        )

        assert result == expected

    def test_titan_v2_simple_string_input(self):
        """A bare string `input` becomes `{"inputText": <str>}` and nothing else."""
        records = [
            self._embeddings_record({"model": self._TITAN_V2, "input": "Hello"})
        ]

        assert self._transform(records) == [
            {"recordId": "e1", "modelInput": {"inputText": "Hello"}}
        ]

    def test_titan_v2_dimensions_and_encoding_format(self):
        """`dimensions` / `encoding_format` are carried over in Titan v2 form."""
        body = {
            "model": self._TITAN_V2,
            "input": "Hi",
            "dimensions": 256,
            "encoding_format": "float",
        }

        model_input = self._transform([self._embeddings_record(body)])[0][
            "modelInput"
        ]

        assert model_input["inputText"] == "Hi"
        assert model_input["dimensions"] == 256
        assert model_input["embeddingTypes"] == ["float"]

    def test_embedding_routing_falls_back_to_body_shape(self):
        """With no `url`, the presence of `input` alone selects the embedding path."""
        record = {
            "custom_id": "e1",
            "body": {"model": self._TITAN_V2, "input": "Hello"},
        }

        result = self._transform([record])

        assert result[0]["modelInput"] == {"inputText": "Hello"}

    def test_embedding_single_element_list_input_is_accepted(self):
        """A one-item list is unwrapped to the same shape as a plain string."""
        records = [
            self._embeddings_record(
                {"model": self._TITAN_V2, "input": ["only one"]}
            )
        ]

        result = self._transform(records)

        assert result[0]["modelInput"]["inputText"] == "only one"

    def test_embedding_multi_input_list_raises(self):
        """An `input` list with several items is refused with a clear message."""
        import pytest

        config = self._files_config()()
        with pytest.raises(ValueError, match="one input per JSONL record"):
            config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
                [
                    self._embeddings_record(
                        {"model": self._TITAN_V2, "input": ["a", "b"]}
                    )
                ]
            )

    def test_embedding_missing_input_raises(self):
        """An /v1/embeddings record that carries no `input` is an error."""
        import pytest

        config = self._files_config()()
        with pytest.raises(ValueError, match="missing required `input`"):
            config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
                [self._embeddings_record({"model": self._TITAN_V2})]
            )

    def test_mixed_chat_and_embedding_in_same_batch(self):
        """Chat and embedding records in one batch are each handled on their own path."""
        chat_record = {
            "custom_id": "chat-1",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": self._CLAUDE_HAIKU,
                "messages": [{"role": "user", "content": "Hi"}],
                "max_tokens": 5,
            },
        }
        embed_record = self._embeddings_record(
            {"model": self._TITAN_V2, "input": "Hi"}, custom_id="embed-1"
        )

        chat_out, embed_out = self._transform([chat_record, embed_record])[:2]

        assert chat_out["recordId"] == "chat-1"
        assert "messages" in chat_out["modelInput"]
        assert chat_out["modelInput"]["anthropic_version"] == "bedrock-2023-05-31"

        assert embed_out["recordId"] == "embed-1"
        assert embed_out["modelInput"] == {"inputText": "Hi"}

    def test_unsupported_embedding_model_raises_not_implemented(self):
        """Cohere/Nova/Titan-G1 embed ids raise NotImplementedError instead of yielding a corrupt body."""
        import pytest

        config = self._files_config()()
        unsupported_models = (
            "bedrock/cohere.embed-english-v3",
            "bedrock/amazon.titan-embed-text-v1",
            "bedrock/amazon.titan-embed-image-v1",
            "bedrock/amazon.nova-2-multimodal-embeddings-v1:0",
        )
        for unsupported_model in unsupported_models:
            with pytest.raises(NotImplementedError, match="titan-embed-text-v2"):
                config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
                    [
                        self._embeddings_record(
                            {"model": unsupported_model, "input": "Hi"}
                        )
                    ]
                )

    def test_titan_v2_model_name_variants_route_correctly(self):
        """Every common spelling of the Titan v2 model id reaches the embedding path."""
        config = self._files_config()()
        titan_v2_ids = (
            "amazon.titan-embed-text-v2:0",
            "bedrock/amazon.titan-embed-text-v2:0",
            "us.amazon.titan-embed-text-v2:0",
            "bedrock/us.amazon.titan-embed-text-v2:0",
        )
        for model_id in titan_v2_ids:
            result = config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
                [self._embeddings_record({"model": model_id, "input": "Hi"})]
            )
            assert result[0]["modelInput"] == {
                "inputText": "Hi"
            }, f"model id {model_id} did not route to Titan v2 embedding path"

    def test_pretokenized_input_list_of_ints_raises(self):
        """`input: List[int]` (pre-tokenized) must raise rather than be silently mis-shaped."""
        import pytest

        config = self._files_config()()
        accepted_errors = (NotImplementedError, ValueError)
        with pytest.raises(accepted_errors, match=r"pre-tokenized|one input per"):
            config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
                [
                    self._embeddings_record(
                        {"model": self._TITAN_V2, "input": [1, 2, 3]}
                    )
                ]
            )

    def test_pretokenized_single_wrapped_list_raises(self):
        """A one-element `List[List[int]]` input is refused as pre-tokenized."""
        import pytest

        config = self._files_config()()
        with pytest.raises(NotImplementedError, match="pre-tokenized"):
            config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
                [
                    self._embeddings_record(
                        {"model": self._TITAN_V2, "input": [[1, 2, 3]]}
                    )
                ]
            )

    def test_record_with_both_input_and_messages_routes_to_chat(self):
        """When a body has both `messages` and `input`, the chat path is taken."""
        ambiguous_record = {
            "custom_id": "ambiguous-1",
            "body": {
                "model": self._CLAUDE_HAIKU,
                "messages": [{"role": "user", "content": "Hi"}],
                "input": "this should be ignored by chat path",
                "max_tokens": 5,
            },
        }

        model_input = self._transform([ambiguous_record])[0]["modelInput"]

        assert "messages" in model_input
        assert "inputText" not in model_input

    def test_url_embeddings_with_missing_input_raises_not_chat_error(self):
        """An embeddings url without `input` fails with the embedding-path error, not a chat-path crash."""
        import pytest

        config = self._files_config()()
        with pytest.raises(ValueError, match="missing required `input`"):
            config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
                [self._embeddings_record({"model": self._TITAN_V2})]
            )

    def test_titan_v2_marker_boundary_rejects_lookalikes(self):
        """The marker only counts when followed by `:`, `/`, or end-of-string."""
        is_titan_v2 = self._files_config()._is_titan_v2_embed_model

        # Look-alikes that must NOT route through the Titan v2 path
        lookalikes = (
            "bedrock/amazon.titan-embed-text-v20:0",
            "bedrock/amazon.titan-embed-text-v2-experimental:0",
            "bedrock/amazon.titan-embed-text-v2foo",
        )
        for model in lookalikes:
            assert not is_titan_v2(
                model
            ), f"{model} unexpectedly matched the Titan v2 marker"

        # Real Titan v2 ids that MUST match
        genuine_ids = (
            "amazon.titan-embed-text-v2:0",
            "bedrock/amazon.titan-embed-text-v2:0",
            "us.amazon.titan-embed-text-v2:0",
            "arn:aws:bedrock:us-east-1:123:foundation-model/amazon.titan-embed-text-v2:0",
        )
        for model in genuine_ids:
            assert is_titan_v2(
                model
            ), f"{model} unexpectedly missed the Titan v2 marker"

    def test_titan_v2_accepted_when_registry_schema_field_matches(self, mocker):
        """A registry entry whose nested
        `provider_specific_entry.bedrock_invocation_schema` is "titan_v2"
        is accepted - this is the registry-driven happy path."""
        config_cls = self._files_config()

        registry_entry = {
            "provider_specific_entry": {"bedrock_invocation_schema": "titan_v2"}
        }
        mocker.patch("litellm.get_model_info", return_value=registry_entry)

        assert config_cls._is_titan_v2_embed_model("amazon.titan-embed-text-v2:0")

    def test_titan_v2_rejected_when_registry_schema_field_differs(self, mocker):
        """A registry entry carrying a different schema value (e.g. a
        hypothetical Cohere Embed entry) leads to rejection: the registry
        value is trusted and the substring check gets no second chance."""
        config_cls = self._files_config()

        registry_entry = {
            "provider_specific_entry": {"bedrock_invocation_schema": "cohere_v3"}
        }
        mocker.patch("litellm.get_model_info", return_value=registry_entry)

        # The id looks like Titan v2, but the registry says otherwise and
        # the registry is what we trust.
        assert not config_cls._is_titan_v2_embed_model(
            "amazon.titan-embed-text-v2:0"
        )

    def test_titan_v2_falls_back_to_marker_when_registry_lacks_schema_field(
        self, mocker
    ):
        """If the registry resolves but the entry has no
        `provider_specific_entry.bedrock_invocation_schema` (e.g. a stale
        local registry), the substring marker decides."""
        config_cls = self._files_config()

        registry_entries = (
            # No provider_specific_entry at all
            {"mode": "embedding"},
            # provider_specific_entry present but missing the schema key
            {
                "mode": "embedding",
                "provider_specific_entry": {"unrelated": "value"},
            },
        )
        for registry_entry in registry_entries:
            mocker.patch("litellm.get_model_info", return_value=registry_entry)
            assert config_cls._is_titan_v2_embed_model(
                "amazon.titan-embed-text-v2:0"
            )

    def test_titan_v2_accepted_when_registry_silent(self, mocker):
        """Ids the registry cannot resolve (cross-region profile prefixes,
        ARN forms) are accepted on the marker match alone."""
        config_cls = self._files_config()

        mocker.patch("litellm.get_model_info", side_effect=Exception("not mapped"))

        unresolvable_ids = (
            "us.amazon.titan-embed-text-v2:0",
            "arn:aws:bedrock:us-east-1:123:foundation-model/amazon.titan-embed-text-v2:0",
        )
        for model_id in unresolvable_ids:
            assert config_cls._is_titan_v2_embed_model(model_id)

    def test_lookup_provider_specific_field_helper(self, mocker):
        """Exercises `_lookup_provider_specific_field` against each registry shape."""
        config_cls = self._files_config()
        schema_field = "bedrock_invocation_schema"

        def nested(value):
            return {"provider_specific_entry": value}

        # (label, mocker.patch kwargs, field requested, expected result)
        scenarios = [
            (
                "nested string value is returned",
                {"return_value": nested({schema_field: "titan_v2"})},
                schema_field,
                "titan_v2",
            ),
            (
                "registry raises",
                {"side_effect": Exception("not mapped")},
                "any",
                None,
            ),
            (
                "registry returns non-dict",
                {"return_value": "not a dict"},
                "any",
                None,
            ),
            (
                "dict without provider_specific_entry",
                {"return_value": {"mode": "embedding"}},
                schema_field,
                None,
            ),
            (
                "provider_specific_entry is not a dict",
                {"return_value": nested("not a dict")},
                schema_field,
                None,
            ),
            (
                "requested field missing",
                {"return_value": nested({"unrelated": "x"})},
                schema_field,
                None,
            ),
            (
                "non-string nested value",
                {"return_value": nested({schema_field: 42})},
                schema_field,
                None,
            ),
            (
                "empty-string nested value",
                {"return_value": nested({schema_field: ""})},
                schema_field,
                None,
            ),
        ]

        for label, patch_kwargs, field, expected in scenarios:
            mocker.patch("litellm.get_model_info", **patch_kwargs)
            found = config_cls._lookup_provider_specific_field("anything", field)
            if expected is None:
                assert found is None, label
            else:
                assert found == expected, label

    def test_is_embedding_record_helper(self):
        """`_is_embedding_record` looks at `url` first and at the body shape second."""
        is_embedding = self._files_config()._is_embedding_record

        assert is_embedding({"url": "/v1/embeddings", "body": {"input": "x"}})
        # body-only fallback
        assert is_embedding({"body": {"input": "x"}})
        # chat shape
        assert not is_embedding(
            {"url": "/v1/chat/completions", "body": {"messages": []}}
        )
        # ambiguous body without `input` is treated as not-embedding
        assert not is_embedding({"body": {}})

    def test_explicit_chat_url_with_input_body_short_circuits_to_chat(self):
        """An explicit url=/v1/chat/completions wins even when the body looks like an embedding.

        If the url did not short-circuit, a chat record whose body carries
        `input` (and no `messages`) would end up in the embedding
        transformer and produce a corrupted modelInput.
        """
        config_cls = self._files_config()

        # Direct helper assertion
        input_only_chat_record = {
            "url": "/v1/chat/completions",
            "body": {
                "model": self._CLAUDE_HAIKU,
                "input": "this would mis-route under the old precedence",
            },
        }
        assert not config_cls._is_embedding_record(input_only_chat_record)

        # End-to-end: the record goes through the chat path. All that is
        # checked is that no inputText body is silently produced and passed
        # off as a chat completion.
        chat_record = {
            "custom_id": "explicit-chat-with-input",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": self._CLAUDE_HAIKU,
                "messages": [{"role": "user", "content": "Hi"}],
                "input": "should not become inputText",
                "max_tokens": 5,
            },
        }
        config = config_cls()
        result = config._transform_openai_jsonl_content_to_bedrock_jsonl_content(
            [chat_record]
        )

        model_input = result[0]["modelInput"]
        assert (
            "inputText" not in model_input
        ), "explicit chat URL must not produce an embedding-shaped modelInput"

    def test_coerce_embedding_input_helper_isolated(self):
        """Exercises `_coerce_embedding_input_to_string` on its own."""
        import pytest

        coerce = self._files_config()._coerce_embedding_input_to_string

        # Happy paths
        assert coerce("hello") == "hello"
        assert coerce(["hello"]) == "hello"

        # Error paths
        with pytest.raises(ValueError, match="missing required `input`"):
            coerce(None, model="m")
        with pytest.raises(ValueError, match="one input per JSONL record"):
            coerce(["a", "b"])
        # A multi-element list of ints also gets the "one input per JSONL
        # record" error: without more context it could be pre-tokenized or
        # "3 strings", so the most actionable message is used.
        with pytest.raises(ValueError, match="one input per JSONL record"):
            coerce([1, 2, 3])
        # Single-element list wrapping a token list -> pre-tokenized error.
        with pytest.raises(NotImplementedError, match="pre-tokenized"):
            coerce([[1, 2, 3]])
        # Single-element list wrapping a bare int -> pre-tokenized error.
        with pytest.raises(NotImplementedError, match="pre-tokenized"):
            coerce([42])
        with pytest.raises(ValueError, match="must be a string"):
            coerce({"unsupported": True})

    def test_other_non_embedding_urls_route_to_chat(self):
        """Any url other than /v1/embeddings is treated as not-embedding."""
        is_embedding = self._files_config()._is_embedding_record

        # /v1/completions (legacy completions endpoint)
        assert not is_embedding({"url": "/v1/completions", "body": {"input": "x"}})
        # Arbitrary unknown url - caller's explicit signal still wins
        assert not is_embedding({"url": "/v1/responses", "body": {"input": "x"}})
|
||||
|
||||
@ -1,8 +1,9 @@
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from botocore.credentials import Credentials
|
||||
|
||||
|
||||
def _anthropic_response(url: str) -> httpx.Response:
|
||||
@ -310,3 +311,54 @@ async def test_anthropic_messages_routes_bedrock_claude_platform_to_messages_api
|
||||
assert requests[0]["body"]["messages"] == [{"role": "user", "content": "hello"}]
|
||||
assert requests[0]["body"]["max_tokens"] == 10
|
||||
assert requests[0]["body"]["model"] == "claude-sonnet-4-6"
|
||||
|
||||
|
||||
def test_sigv4_no_duplicate_content_type_when_caller_sets_lowercase():
    """
    Regression guard for duplicated content-type headers during SigV4 signing.

    get_anthropic_headers() provides a lowercase "content-type". _sign_request()
    previously prepended an uppercase "Content-Type", so both keys stayed in the
    dict. botocore then joined them into "application/json, application/json"
    in the canonical string while the wire request carried a single value → 401.

    Fix under test: the prepended key is the lowercase "content-type", so the
    caller's **headers overwrites it when the caller already set one.
    """
    from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM

    # Snapshots of the headers handed to the (patched) AWSRequest constructor.
    seen_headers: list[dict] = []

    def fake_aws_request(method, url, data, headers):
        seen_headers.append(dict(headers))
        body = data.encode() if isinstance(data, str) else data
        request = MagicMock()
        request.headers = {"Authorization": "AWS4-HMAC-SHA256 Credential=test"}
        request.body = body
        return request

    llm = BaseAWSLLM()
    mock_credentials = Credentials("key", "secret", "token")
    mock_sigv4 = MagicMock()
    payload = {
        "model": "claude-sonnet-4-6",
        "messages": [],
        "max_tokens": 10,
    }

    with (
        patch("botocore.auth.SigV4Auth", return_value=mock_sigv4),
        patch("botocore.awsrequest.AWSRequest", side_effect=fake_aws_request),
        patch.object(llm, "get_credentials", return_value=mock_credentials),
        patch.object(llm, "_get_aws_region_name", return_value="us-east-1"),
    ):
        llm._sign_request(
            service_name="aws-external-anthropic",
            headers={"content-type": "application/json"},
            optional_params={"aws_region_name": "us-east-1"},
            request_data=payload,
            api_base="https://aws-external-anthropic.us-east-1.api.aws/v1/messages",
        )

    first_signed = seen_headers[0]
    ct_keys = [name for name in first_signed if name.lower() == "content-type"]
    assert ct_keys == ["content-type"], (
        f"Expected exactly one 'content-type' key, got {ct_keys}. "
        "Duplicate keys produce 'application/json, application/json' in the "
        "SigV4 canonical string and cause a 401."
    )
|
||||
|
||||
@ -0,0 +1,67 @@
|
||||
"""
|
||||
Tests for Black Forest Labs common_utils — specifically assert_bfl_polling_url.
|
||||
|
||||
BFL uses regional subdomains (e.g. gateway.bfl.ai) for polling URLs that
|
||||
differ from the submission host (api.bfl.ai). These tests verify that the
|
||||
domain-aware check accepts legitimate BFL subdomains while still rejecting
|
||||
off-domain and non-HTTPS URLs.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.llms.black_forest_labs.common_utils import (
|
||||
BlackForestLabsError,
|
||||
assert_bfl_polling_url,
|
||||
)
|
||||
|
||||
|
||||
class TestAssertBflPollingUrl:
    """Accept/reject cases for ``assert_bfl_polling_url``."""

    @staticmethod
    def _expect_rejected(url, message_pattern):
        """Assert that ``url`` raises BlackForestLabsError matching ``message_pattern``."""
        with pytest.raises(BlackForestLabsError, match=message_pattern):
            assert_bfl_polling_url(url)

    # --- URLs that must be accepted (no exception) ---

    def test_exact_registered_domain(self):
        url = "https://bfl.ai/v1/get_result?id=abc"
        assert_bfl_polling_url(url)

    def test_api_subdomain(self):
        url = "https://api.bfl.ai/v1/get_result?id=abc"
        assert_bfl_polling_url(url)

    def test_gateway_subdomain(self):
        # gateway.bfl.ai is what BFL uses for polling; it triggered the original bug
        url = "https://gateway.bfl.ai/v1/get_result?id=abc"
        assert_bfl_polling_url(url)

    def test_regional_subdomain(self):
        url = "https://eu.api.bfl.ai/v1/get_result?id=abc"
        assert_bfl_polling_url(url)

    def test_deep_subdomain(self):
        url = "https://region.gateway.bfl.ai/poll?id=xyz"
        assert_bfl_polling_url(url)

    # --- URLs that must raise BlackForestLabsError ---

    def test_rejects_http_scheme(self):
        # Plain HTTP is refused: the x-key would travel in plaintext
        self._expect_rejected(
            "http://api.bfl.ai/v1/get_result?id=abc", "scheme must be https"
        )

    def test_rejects_off_domain(self):
        self._expect_rejected("https://evil.com/steal-key", "host is not within")

    def test_rejects_lookalike_domain(self):
        self._expect_rejected(
            "https://notbfl.ai/v1/get_result?id=abc", "host is not within"
        )

    def test_rejects_bfl_ai_as_suffix_only(self):
        # "fakebfl.ai" only shares a suffix; matching is on the registered domain boundary
        self._expect_rejected(
            "https://fakebfl.ai/v1/get_result?id=abc", "host is not within"
        )

    def test_rejects_bfl_in_path(self):
        self._expect_rejected("https://evil.com/bfl.ai/steal", "host is not within")

    def test_rejects_ftp_scheme(self):
        self._expect_rejected(
            "ftp://api.bfl.ai/v1/get_result?id=abc", "scheme must be https"
        )

    def test_rejects_javascript_scheme(self):
        self._expect_rejected(
            "javascript://api.bfl.ai/alert(1)", "scheme must be https"
        )
|
||||
@ -8,8 +8,7 @@ with guardrail transformations, including tool calls.
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Any, List, Literal, Optional, Tuple
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
import pytest
|
||||
|
||||
@ -84,6 +83,70 @@ class MockGuardrail(CustomGuardrail):
|
||||
return result
|
||||
|
||||
|
||||
class MockCopiedToolCallGuardrail(CustomGuardrail):
    """Test double whose result holds redacted *copies* of the tool calls; the
    incoming tool-call objects themselves are left unmodified."""

    async def apply_guardrail(
        self,
        inputs: GenericGuardrailAPIInputs,
        request_data: dict,
        input_type: Literal["request", "response"],
        logging_obj: Optional[Any] = None,
    ) -> GenericGuardrailAPIInputs:
        """Return the input texts plus a copy of every tool call whose
        ``function.arguments`` is replaced by a JSON string with a masked email."""

        def redacted_copy(tool_call):
            # Copy the tool call and its "function" entry so the replacement
            # below never touches the objects that were passed in.
            clone = dict(tool_call)
            clone["function"] = dict(
                clone["function"], arguments=json.dumps({"email": "[EMAIL]"})
            )
            return clone

        copied_tool_calls = [
            redacted_copy(tool_call) for tool_call in inputs.get("tool_calls", [])
        ]

        return GenericGuardrailAPIInputs(
            texts=inputs.get("texts", []),
            tool_calls=copied_tool_calls,
        )
|
||||
|
||||
|
||||
class MockNonListToolCallGuardrail(CustomGuardrail):
    """Test double that puts a non-list envelope under ``tool_calls`` on the
    response path, as some released guardrails do when they assign a detection
    API JSON dict."""

    async def apply_guardrail(
        self,
        inputs: GenericGuardrailAPIInputs,
        request_data: dict,
        input_type: Literal["request", "response"],
        logging_obj: Optional[Any] = None,
    ) -> GenericGuardrailAPIInputs:
        """Return the input texts with ``tool_calls`` set to a dict, not a list."""
        texts = inputs.get("texts", [])
        envelope = GenericGuardrailAPIInputs(texts=texts)
        # Deliberately the wrong type for this key - hence the ignore.
        envelope["tool_calls"] = {"verdict": "allow", "detections": []}  # type: ignore
        return envelope
|
||||
|
||||
|
||||
class MockMisalignedToolCallGuardrail(CustomGuardrail):
    """Guardrail double that returns at most one tool call.

    When the input carries several tool calls the returned list is shorter
    than the input, so it cannot be mapped back position by position.
    """

    async def apply_guardrail(
        self,
        inputs: GenericGuardrailAPIInputs,
        request_data: dict,
        input_type: Literal["request", "response"],
        logging_obj: Optional[Any] = None,
    ) -> GenericGuardrailAPIInputs:
        """Return the texts plus a rewritten copy of only the first tool call."""
        incoming_calls = inputs.get("tool_calls", [])
        replacement_function = {"name": "x", "arguments": json.dumps({"x": 1})}

        # Keep only a copy of the first call (if any) with its function swapped.
        truncated = (
            [dict(incoming_calls[0], function=replacement_function)]
            if incoming_calls
            else []
        )

        return GenericGuardrailAPIInputs(
            texts=inputs.get("texts", []),
            tool_calls=truncated,
        )
|
||||
|
||||
|
||||
class TestOpenAIChatCompletionsHandlerToolsInput:
|
||||
"""Test input processing with tools (function definitions)"""
|
||||
|
||||
@ -740,6 +803,131 @@ class TestOpenAIChatCompletionsHandlerToolCallsOutput:
|
||||
assert response.model == "gpt-4o-mini"
|
||||
assert response.choices[0].finish_reason == "tool_calls"
|
||||
|
||||
@pytest.mark.asyncio
async def test_output_response_uses_returned_guardrailed_tool_calls(self):
    """Tool calls returned by the guardrail are written back onto the response,
    even though the guardrail never mutated the inputs it was given."""
    email_tool_call = ChatCompletionMessageToolCall(
        id="call_email",
        type="function",
        function=Function(
            name="send_email",
            arguments=json.dumps({"email": "john@example.com"}),
        ),
    )
    assistant_message = Message(
        content=None,
        role="assistant",
        tool_calls=[email_tool_call],
    )
    response = ModelResponse(
        id="chatcmpl-tool-copy",
        created=1234567890,
        model="gpt-4",
        object="chat.completion",
        choices=[
            Choices(
                finish_reason="tool_calls",
                index=0,
                message=assistant_message,
            )
        ],
    )

    handler = OpenAIChatCompletionsHandler()
    guardrail = MockCopiedToolCallGuardrail(guardrail_name="test")
    await handler.process_output_response(response, guardrail)

    rewritten_call = response.choices[0].message.tool_calls[0]
    # The function name is untouched; only the arguments are redacted.
    assert rewritten_call.function.name == "send_email"
    assert json.loads(rewritten_call.function.arguments) == {"email": "[EMAIL]"}
|
||||
|
||||
@pytest.mark.asyncio
async def test_output_response_ignores_non_list_returned_tool_calls(self):
    """When the guardrail returns ``tool_calls`` as a non-list value (such as a
    detection-API envelope dict), processing must not crash and the original
    arguments stay in place."""
    untouched_arguments = json.dumps({"email": "john@example.com"})
    email_tool_call = ChatCompletionMessageToolCall(
        id="call_email",
        type="function",
        function=Function(name="send_email", arguments=untouched_arguments),
    )
    response = ModelResponse(
        id="chatcmpl-nonlist",
        created=1234567890,
        model="gpt-4",
        object="chat.completion",
        choices=[
            Choices(
                finish_reason="tool_calls",
                index=0,
                message=Message(
                    content=None,
                    role="assistant",
                    tool_calls=[email_tool_call],
                ),
            )
        ],
    )

    handler = OpenAIChatCompletionsHandler()
    guardrail = MockNonListToolCallGuardrail(guardrail_name="test")
    await handler.process_output_response(response, guardrail)

    surviving_call = response.choices[0].message.tool_calls[0]
    assert surviving_call.function.arguments == untouched_arguments
|
||||
|
||||
@pytest.mark.asyncio
async def test_output_response_ignores_misaligned_returned_tool_calls(self):
    """A returned ``tool_calls`` list whose length differs from the input cannot
    be applied by position, so the handler must leave every original
    argument string as it was rather than write onto the wrong call."""
    arguments_by_id = {
        "call_1": json.dumps({"email": "a@example.com"}),
        "call_2": json.dumps({"email": "b@example.com"}),
    }
    original_calls = [
        ChatCompletionMessageToolCall(
            id=call_id,
            type="function",
            function=Function(name="send_email", arguments=call_arguments),
        )
        for call_id, call_arguments in arguments_by_id.items()
    ]
    response = ModelResponse(
        id="chatcmpl-misaligned",
        created=1234567890,
        model="gpt-4",
        object="chat.completion",
        choices=[
            Choices(
                finish_reason="tool_calls",
                index=0,
                message=Message(
                    content=None,
                    role="assistant",
                    tool_calls=original_calls,
                ),
            )
        ],
    )

    handler = OpenAIChatCompletionsHandler()
    guardrail = MockMisalignedToolCallGuardrail(guardrail_name="test")
    await handler.process_output_response(response, guardrail)

    resulting_calls = response.choices[0].message.tool_calls
    assert resulting_calls[0].function.arguments == arguments_by_id["call_1"]
    assert resulting_calls[1].function.arguments == arguments_by_id["call_2"]
|
||||
|
||||
|
||||
class MockPassThroughGuardrail(CustomGuardrail):
|
||||
"""Mock guardrail that passes through without blocking - for testing streaming fallback behavior"""
|
||||
@ -765,7 +953,7 @@ class TestOpenAIChatCompletionsHandlerStreamingOutput:
|
||||
This test verifies the fix for the bug where accessing chunk.choices[0]
|
||||
would raise IndexError when a streaming chunk has an empty choices list.
|
||||
"""
|
||||
from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices
|
||||
from litellm.types.utils import ModelResponseStream
|
||||
|
||||
handler = OpenAIChatCompletionsHandler()
|
||||
guardrail = MockPassThroughGuardrail(guardrail_name="test")
|
||||
|
||||
@ -0,0 +1,84 @@
|
||||
"""
|
||||
Tests for Tensormesh provider configuration and integration.
|
||||
"""
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
class TestTensormeshProviderConfig:
    """Test Tensormesh provider configuration"""

    def test_tensormesh_in_provider_list(self):
        """Test that tensormesh is in the provider list"""
        from litellm import LlmProviders

        assert hasattr(LlmProviders, "TENSORMESH")
        assert LlmProviders.TENSORMESH.value == "tensormesh"
        assert "tensormesh" in litellm.provider_list

    def test_tensormesh_json_config_exists(self):
        """Test that tensormesh is configured in providers.json"""
        from litellm.llms.openai_like.json_loader import JSONProviderRegistry

        assert JSONProviderRegistry.exists("tensormesh")

        tensormesh = JSONProviderRegistry.get("tensormesh")
        assert tensormesh is not None
        assert tensormesh.base_url == "https://serverless.tensormesh.ai/v1"
        assert tensormesh.api_key_env == "TENSORMESH_INFERENCE_API_KEY"
        assert tensormesh.api_base_env == "TENSORMESH_SERVERLESS_BASE_URL"
        assert tensormesh.param_mappings.get("max_completion_tokens") == "max_tokens"

    def test_tensormesh_provider_resolution(self, monkeypatch):
        """Test that provider resolution finds tensormesh and the default base URL"""
        from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider

        # The provider config names TENSORMESH_SERVERLESS_BASE_URL as its
        # api_base override variable. Clear it so an ambient value on a
        # developer machine or CI runner cannot displace the default URL
        # this test asserts.
        monkeypatch.delenv("TENSORMESH_SERVERLESS_BASE_URL", raising=False)

        model, provider, api_key, api_base = get_llm_provider(
            model="tensormesh/openai/gpt-oss-120b",
            custom_llm_provider=None,
            api_base=None,
            api_key=None,
        )

        assert model == "openai/gpt-oss-120b"
        assert provider == "tensormesh"
        assert api_base == "https://serverless.tensormesh.ai/v1"

    def test_tensormesh_api_base_override(self):
        """Test that an explicit api_base / api_key overrides the serverless default"""
        from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider

        model, provider, api_key, api_base = get_llm_provider(
            model="tensormesh/openai/gpt-oss-120b",
            custom_llm_provider=None,
            api_base="https://custom.example.com/v1",
            api_key="sk-test",
        )

        assert provider == "tensormesh"
        assert api_base == "https://custom.example.com/v1"
        assert api_key == "sk-test"

    def test_tensormesh_text_completion_enabled(self):
        """Tensormesh is wired for the /completions (text completion) route,
        matching the text_completion flag in provider_endpoints_support.json."""
        assert "tensormesh" in litellm.openai_text_completion_compatible_providers

    def test_tensormesh_router_config(self):
        """Test that tensormesh can be used in Router configuration"""
        from litellm import Router

        router = Router(
            model_list=[
                {
                    "model_name": "tensormesh-chat",
                    "litellm_params": {
                        "model": "tensormesh/openai/gpt-oss-120b",
                        "api_key": "test-key",
                    },
                }
            ]
        )

        assert len(router.model_list) == 1
        assert router.model_list[0]["model_name"] == "tensormesh-chat"
|
||||
@ -0,0 +1,282 @@
|
||||
"""
|
||||
Unit tests for WatsonxPassthroughConfig transformation.
|
||||
|
||||
Tests the Watsonx-specific passthrough configuration including URL construction,
|
||||
streaming detection, and authentication handling.
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../../../../.."))
|
||||
|
||||
import litellm
|
||||
from litellm.llms.watsonx.passthrough.transformation import WatsonxPassthroughConfig
|
||||
|
||||
|
||||
class TestWatsonxPassthroughConfig:
    """Unit tests for the helpers exposed by WatsonxPassthroughConfig."""

    def test_is_streaming_request_true(self):
        """stream=True in the request body is reported as streaming."""
        config = WatsonxPassthroughConfig()
        payload = {"stream": True, "input": "test"}

        is_streaming = config.is_streaming_request(
            endpoint="ml/v1/text/generation", request_data=payload
        )

        assert is_streaming is True

    def test_is_streaming_request_false(self):
        """stream=False in the request body is reported as non-streaming."""
        config = WatsonxPassthroughConfig()
        payload = {"stream": False, "input": "test"}

        is_streaming = config.is_streaming_request(
            endpoint="ml/v1/text/generation", request_data=payload
        )

        assert is_streaming is False

    def test_is_streaming_request_missing_stream_key(self):
        """A body without a stream key is treated as non-streaming."""
        config = WatsonxPassthroughConfig()
        payload = {"input": "test"}

        is_streaming = config.is_streaming_request(
            endpoint="ml/v1/text/generation", request_data=payload
        )

        assert is_streaming is False

    def test_get_complete_url_with_api_base(self):
        """An explicit api_base is used as the prefix of the built URL."""
        config = WatsonxPassthroughConfig()
        base = "https://us-south.ml.cloud.ibm.com"
        path = "ml/v1/text/generation"

        built_url, target_base = config.get_complete_url(
            api_base=base,
            api_key=None,
            model="ibm/granite-13b-chat-v2",
            endpoint=path,
            request_query_params={"version": "2024-03-19"},
            litellm_params={},
        )

        rendered = str(built_url)
        assert isinstance(built_url, httpx.URL)
        assert rendered.startswith(base)
        assert path in rendered
        assert "version=2024-03-19" in rendered
        assert target_base == base

    @patch("litellm.llms.watsonx.common_utils.get_secret_str")
    def test_get_complete_url_with_env_api_base(self, mock_get_secret):
        """With api_base=None the base comes from the (patched) secret lookup."""
        config = WatsonxPassthroughConfig()
        base_from_env = "https://eu-de.ml.cloud.ibm.com"
        mock_get_secret.return_value = base_from_env
        path = "ml/v1/text/tokenization"

        built_url, target_base = config.get_complete_url(
            api_base=None,
            api_key=None,
            model="ibm/granite-13b-chat-v2",
            endpoint=path,
            request_query_params={"version": "2024-03-19"},
            litellm_params={},
        )

        rendered = str(built_url)
        assert isinstance(built_url, httpx.URL)
        assert rendered.startswith(base_from_env)
        assert path in rendered
        assert target_base == base_from_env

    def test_get_complete_url_with_query_params(self):
        """Supplied query parameters appear in the built URL."""
        config = WatsonxPassthroughConfig()
        query = {
            "version": "2024-03-19",
        }

        built_url, _ = config.get_complete_url(
            api_base="https://us-south.ml.cloud.ibm.com",
            api_key=None,
            model="ibm/granite-13b-chat-v2",
            endpoint="ml/v1/text/generation",
            request_query_params=query,
            litellm_params={},
        )

        assert "version=2024-03-19" in str(built_url)

    def test_get_complete_url_without_query_params(self):
        """With no query parameters the URL is exactly base + "/" + endpoint."""
        config = WatsonxPassthroughConfig()
        base = "https://us-south.ml.cloud.ibm.com"
        path = "ml/v1/models"

        built_url, target_base = config.get_complete_url(
            api_base=base,
            api_key=None,
            model="",
            endpoint=path,
            request_query_params=None,
            litellm_params={},
        )

        rendered = str(built_url)
        assert isinstance(built_url, httpx.URL)
        assert rendered == f"{base}/{path}"
        assert target_base == base
        assert "version=2024-03-19" not in rendered

    @patch("litellm.llms.watsonx.common_utils.get_secret_str")
    def test_get_api_base_with_explicit_value(self, mock_get_secret):
        """An explicit api_base is returned without consulting secrets."""
        supplied = "https://custom.watsonx.com"

        resolved = WatsonxPassthroughConfig.get_api_base(api_base=supplied)

        assert resolved == supplied
        mock_get_secret.assert_not_called()

    @patch("litellm.llms.watsonx.common_utils.get_secret_str")
    def test_get_api_base_from_environment(self, mock_get_secret):
        """With api_base=None the value is looked up under WATSONX_API_BASE."""
        from_env = "https://env.watsonx.com"
        mock_get_secret.return_value = from_env

        resolved = WatsonxPassthroughConfig.get_api_base(api_base=None)

        assert resolved == from_env
        mock_get_secret.assert_called_once_with("WATSONX_API_BASE")

    @patch("litellm.llms.watsonx.common_utils.get_secret_str")
    def test_get_api_key_with_explicit_value(self, mock_get_secret):
        """An explicit api_key is returned without consulting secrets."""
        supplied = "test-api-key-123"

        resolved = WatsonxPassthroughConfig.get_api_key(api_key=supplied)

        assert resolved == supplied
        mock_get_secret.assert_not_called()

    @patch("litellm.llms.watsonx.common_utils.get_secret_str")
    def test_get_api_key_from_environment(self, mock_get_secret):
        """With api_key=None the lookup includes WATSONX_APIKEY."""
        from_env = "env-api-key-456"
        mock_get_secret.return_value = from_env

        resolved = WatsonxPassthroughConfig.get_api_key(api_key=None)

        assert resolved == from_env
        mock_get_secret.assert_any_call("WATSONX_APIKEY")

    def test_get_base_model_returns_model(self):
        """get_base_model hands back a foundation-model id unchanged."""
        model_id = "ibm/granite-13b-chat-v2"

        assert WatsonxPassthroughConfig.get_base_model(model_id) == model_id

    def test_get_base_model_with_deployment(self):
        """get_base_model hands back a deployment-style id unchanged."""
        model_id = "deployment/test-deployment-id"

        assert WatsonxPassthroughConfig.get_base_model(model_id) == model_id

    def test_get_complete_url_with_different_endpoints(self):
        """Several endpoint paths each end up inside the built URL."""
        config = WatsonxPassthroughConfig()
        base = "https://us-south.ml.cloud.ibm.com"
        candidate_paths = (
            "ml/v1/text/generation",
            "ml/v1/text/tokenization",
            "ml/v1/deployments/test-id/text/generation",
            "ml/v1/models",
            "ml/v1/foundation_model_specs",
        )

        for path in candidate_paths:
            built_url, target_base = config.get_complete_url(
                api_base=base,
                api_key=None,
                model="",
                endpoint=path,
                request_query_params={"version": "2024-03-19"},
                litellm_params={},
            )

            assert isinstance(built_url, httpx.URL)
            assert path in str(built_url)
            assert target_base == base

    def test_get_complete_url_preserves_query_param_order(self):
        """Every supplied query parameter shows up with its value in the URL."""
        config = WatsonxPassthroughConfig()
        query = {
            "version": "2024-03-19",
            "project_id": "abc-123",
            "space_id": "xyz-789",
        }

        built_url, _ = config.get_complete_url(
            api_base="https://us-south.ml.cloud.ibm.com",
            api_key=None,
            model="",
            endpoint="ml/v1/text/generation",
            request_query_params=query,
            litellm_params={},
        )

        rendered = str(built_url)
        # Only presence of each key=value pair is checked here.
        assert "version=2024-03-19" in rendered
        assert "project_id=abc-123" in rendered
        assert "space_id=xyz-789" in rendered

    def test_is_streaming_request_with_various_stream_values(self):
        """The stream value is returned as stored, whatever its type."""
        config = WatsonxPassthroughConfig()

        # Real booleans come back as the same boolean.
        assert config.is_streaming_request("endpoint", {"stream": True}) is True
        assert config.is_streaming_request("endpoint", {"stream": False}) is False

        # Non-boolean values are not coerced: the raw value is returned.
        truthy_string = config.is_streaming_request("endpoint", {"stream": "true"})
        assert truthy_string == "true"

        truthy_int = config.is_streaming_request("endpoint", {"stream": 1})
        assert truthy_int == 1

        falsy_int = config.is_streaming_request("endpoint", {"stream": 0})
        assert falsy_int == 0

        explicit_none = config.is_streaming_request("endpoint", {"stream": None})
        assert explicit_none is None

        # A missing key falls back to False.
        assert config.is_streaming_request("endpoint", {}) is False
|
||||
@ -5375,5 +5375,93 @@ class TestPanwAirsDualScanIndependence:
|
||||
assert mcp_call.get("content") is None
|
||||
|
||||
|
||||
class TestPanwAirsTimeoutCoercion:
    """Regression coverage for string-valued ``timeout`` settings.

    A ``timeout`` stored as a string (the dashboard UI persists it that way,
    and quoted YAML preserves it) used to reach httpx unchanged, which raised
    `TypeError: '<=' not supported between instances of 'str' and 'int'`.
    The broad except in apply_guardrail swallowed that error and the proxy
    answered with a misleading 500 'Security scan failed - request blocked
    for safety'.
    """

    def test_handler_coerces_string_timeout_to_float(self):
        """A numeric string is stored on the handler as a float."""
        built = make_handler(timeout="30")
        assert built.timeout == 30.0
        assert isinstance(built.timeout, float)

    def test_handler_accepts_int_timeout(self):
        """An int timeout is accepted."""
        built = make_handler(timeout=15)
        assert built.timeout == 15.0

    def test_handler_accepts_float_timeout(self):
        """A float timeout is kept as given."""
        built = make_handler(timeout=7.5)
        assert built.timeout == 7.5

    def test_handler_none_timeout_falls_back_to_default(self):
        """An explicit None resolves to the 10.0 default."""
        built = make_handler(timeout=None)
        assert built.timeout == 10.0

    def test_handler_omitted_timeout_uses_default(self):
        """Leaving timeout out resolves to the 10.0 default."""
        built = make_handler()
        assert built.timeout == 10.0

    def test_litellm_params_coerces_string_timeout(self):
        """Boundary validation: the Pydantic model itself normalizes string
        timeouts before any handler reads the value via model_dump()."""
        validated = LitellmParams(
            guardrail="panw_prisma_airs",
            mode="pre_call",
            api_key="test_key",
            profile_name="test_profile",
            timeout="30",
        )
        assert validated.timeout == 30.0
        assert isinstance(validated.timeout, float)

    def test_litellm_params_rejects_garbage_timeout(self):
        """A non-numeric timeout string is rejected with ValueError."""
        with pytest.raises(ValueError):
            LitellmParams(
                guardrail="panw_prisma_airs",
                mode="pre_call",
                api_key="test_key",
                profile_name="test_profile",
                timeout="not-a-number",
            )

    def test_litellm_params_empty_string_timeout_becomes_none(self):
        """An empty-string timeout (which the dashboard form can send) is
        coerced to None instead of crashing or producing float('')."""
        validated = LitellmParams(
            guardrail="panw_prisma_airs",
            mode="pre_call",
            api_key="test_key",
            profile_name="test_profile",
            timeout="",
        )
        assert validated.timeout is None

    def test_legacy_initializer_handles_unset_timeout(self):
        """Regression guard: timeout is now a declared Optional[float] = None
        on BaseLitellmParams, so the legacy panw initializer must not crash
        on float(None) when the caller omits timeout entirely."""
        from litellm.proxy.guardrails.guardrail_initializers import (
            initialize_panw_prisma_airs,
        )

        # timeout intentionally omitted - field defaults to None
        validated = LitellmParams(
            guardrail="panw_prisma_airs",
            mode="pre_call",
            api_key="test_key",
            profile_name="test_profile",
        )
        built = initialize_panw_prisma_airs(
            validated, {"guardrail_name": "test_legacy"}
        )
        # Default fallback applied, not crashed on float(None)
        assert built.timeout == 10.0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v"])
|
||||
|
||||
@ -0,0 +1,900 @@
|
||||
import json
|
||||
import logging
|
||||
import ssl
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, List
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from litellm.exceptions import GuardrailRaisedException
|
||||
from litellm.exceptions import Timeout as LiteLLMTimeout
|
||||
from litellm.proxy.guardrails.guardrail_hooks.vigil_guard import (
|
||||
VigilGuardGuardrail,
|
||||
guardrail_class_registry,
|
||||
guardrail_initializer_registry,
|
||||
initialize_guardrail,
|
||||
)
|
||||
from litellm.proxy.guardrails.guardrail_hooks.vigil_guard.vigil_guard import (
|
||||
_DEFAULT_VIGIL_TIMEOUT,
|
||||
VigilGuardMissingConfig,
|
||||
)
|
||||
from litellm.types.guardrails import LitellmParams, SupportedGuardrailIntegrations
|
||||
from litellm.types.proxy.guardrails.guardrail_hooks.vigil_guard import (
|
||||
VigilGuardGuardrailConfigModel,
|
||||
)
|
||||
|
||||
_ENDPOINT = "https://vigil.test/v1/guard/analyze"
|
||||
|
||||
|
||||
def _resp(body: dict, status_code: int = 200) -> httpx.Response:
    """Build an httpx.Response with a JSON body, tied to a POST to ``_ENDPOINT``."""
    originating_request = httpx.Request("POST", _ENDPOINT)
    return httpx.Response(
        status_code=status_code,
        json=body,
        request=originating_request,
    )
|
||||
|
||||
|
||||
class FakeHandler:
    """Scripted stand-in for an async HTTP handler.

    It is seeded with an ordered list of outcomes. Each ``post`` call records
    its arguments in ``calls`` and then consumes the next outcome: exception
    instances are raised, anything else is returned.
    """

    def __init__(self, items: List[Any]):
        # Copy so the caller's list is not consumed as outcomes are popped.
        self._items = list(items)
        self.calls: List[SimpleNamespace] = []

    async def post(self, *, url, headers, json, timeout=None):  # noqa: A002
        """Record the call, then raise or return the next scripted outcome.

        Raises AssertionError when no scripted outcomes remain.
        """
        recorded_call = SimpleNamespace(
            url=url, headers=headers, json=json, timeout=timeout
        )
        self.calls.append(recorded_call)

        if self._items:
            outcome = self._items.pop(0)
            if isinstance(outcome, BaseException):
                raise outcome
            return outcome

        raise AssertionError("FakeHandler ran out of programmed responses")
|
||||
|
||||
|
||||
def _make_guardrail(
    handler: FakeHandler,
    *,
    unreachable_fallback="fail_closed",
    api_base="https://vigil.test",
    api_key="vg_secret_key_123",
    guardrail_name="vigil-guard",
    timeout=None,
) -> VigilGuardGuardrail:
    """Construct a VigilGuardGuardrail wired to ``handler`` with test defaults.

    The guardrail is always created with ``event_hook="pre_call"`` and
    ``default_on=True``; the keyword arguments override the remaining settings.
    """
    constructor_kwargs = {
        "api_base": api_base,
        "api_key": api_key,
        "unreachable_fallback": unreachable_fallback,
        "timeout": timeout,
        "async_handler": handler,
        "guardrail_name": guardrail_name,
        "event_hook": "pre_call",
        "default_on": True,
    }
    return VigilGuardGuardrail(**constructor_kwargs)
|
||||
|
||||
|
||||
def _transient_exceptions() -> List[BaseException]:
    """Return one instance of each transport/timeout error used by the tests.

    The four httpx errors are bound to a POST request for ``_ENDPOINT``; the
    LiteLLM timeout is appended last.
    """
    pending_request = httpx.Request("POST", _ENDPOINT)
    httpx_error_types = (
        httpx.ConnectError,
        httpx.ConnectTimeout,
        httpx.ReadTimeout,
        httpx.RemoteProtocolError,
    )
    failures: List[BaseException] = [
        error_type("boom", request=pending_request)
        for error_type in httpx_error_types
    ]
    failures.append(
        LiteLLMTimeout(message="t", model="m", llm_provider="vigil_guard")
    )
    return failures
|
||||
|
||||
|
||||
def test_requires_api_base(monkeypatch):
    """Constructing without an api_base (argument or env) raises VigilGuardMissingConfig."""
    for env_name in ("VIGIL_GUARD_URL", "VIGIL_GUARD_API_KEY"):
        monkeypatch.delenv(env_name, raising=False)

    idle_handler = FakeHandler([])
    with pytest.raises(VigilGuardMissingConfig):
        VigilGuardGuardrail(api_key="k", async_handler=idle_handler)
|
||||
|
||||
|
||||
def test_requires_api_key(monkeypatch):
    """Constructing without an api_key (argument or env) raises VigilGuardMissingConfig."""
    monkeypatch.delenv("VIGIL_GUARD_API_KEY", raising=False)

    idle_handler = FakeHandler([])
    with pytest.raises(VigilGuardMissingConfig):
        VigilGuardGuardrail(api_base="https://vigil.test", async_handler=idle_handler)
|
||||
|
||||
|
||||
def test_trailing_slash_stripped():
    """A trailing slash on api_base is removed when the guardrail stores it."""
    guardrail = _make_guardrail(FakeHandler([]), api_base="https://vigil.test/")
    stored_base = guardrail.api_base
    assert stored_base == "https://vigil.test"
|
||||
|
||||
|
||||
def test_env_fallback(monkeypatch):
    """api_base and api_key are read from the environment when not passed in."""
    env_values = {
        "VIGIL_GUARD_URL": "https://env.vigil.test",
        "VIGIL_GUARD_API_KEY": "env_key",
    }
    for env_name, env_value in env_values.items():
        monkeypatch.setenv(env_name, env_value)

    guardrail = VigilGuardGuardrail(
        async_handler=FakeHandler([]),
        guardrail_name="vg",
        event_hook="pre_call",
        default_on=True,
    )

    assert guardrail.api_base == "https://env.vigil.test"
    assert guardrail.api_key == "env_key"
|
||||
|
||||
|
||||
def test_default_unreachable_fallback_is_fail_closed():
    """Passing unreachable_fallback=None yields the "fail_closed" policy."""
    guardrail = _make_guardrail(FakeHandler([]), unreachable_fallback=None)
    stored_policy = guardrail.unreachable_fallback
    assert stored_policy == "fail_closed"
|
||||
|
||||
|
||||
def test_explicit_fail_open_is_stored():
    """An explicit "fail_open" policy is kept as given."""
    guardrail = _make_guardrail(FakeHandler([]), unreachable_fallback="fail_open")
    stored_policy = guardrail.unreachable_fallback
    assert stored_policy == "fail_open"
|
||||
|
||||
|
||||
def test_unknown_fallback_defaults_to_fail_closed():
    """An unrecognized policy value is replaced by "fail_closed"."""
    guardrail = _make_guardrail(FakeHandler([]), unreachable_fallback="weird")
    stored_policy = guardrail.unreachable_fallback
    assert stored_policy == "fail_closed"
|
||||
|
||||
|
||||
async def test_allowed_preserves_full_input_shape_and_logs_allow():
    """An ALLOWED decision returns a new mapping that keeps every input field,
    leaves the caller's inputs intact, and logs an "allow" guardrail response."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)
    structured = [{"role": "user", "content": "hello"}]
    original_inputs = {
        "texts": ["hello"],
        "structured_messages": structured,
        "model": "gpt-4o",
    }
    request_data = {"metadata": {}}

    result = await guardrail.apply_guardrail(
        inputs=original_inputs,
        request_data=request_data,
        input_type="request",
        logging_obj=None,
    )

    # Output carries all fields, with structured_messages passed by identity.
    assert result["texts"] == ["hello"]
    assert result["structured_messages"] is structured
    assert result["model"] == "gpt-4o"
    # The result is a distinct object and the input was not rewired.
    assert result is not original_inputs
    assert original_inputs["structured_messages"] is structured
    assert len(fake.calls) == 1

    logged = request_data["metadata"]["standard_logging_guardrail_information"]
    assert logged[0]["guardrail_response"] == "allow"
|
||||
|
||||
|
||||
async def test_sanitized_replaces_text():
    """A SANITIZED decision swaps the text for the returned sanitizedText."""
    sanitized_reply = _resp({"decision": "SANITIZED", "sanitizedText": "[REDACTED]"})
    guardrail = _make_guardrail(FakeHandler([sanitized_reply]))

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["my ssn is 123"]},
        request_data={},
        input_type="request",
    )

    assert result["texts"] == ["[REDACTED]"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("body", "expected"),
    [
        # sanitizedText wins when both fields are present.
        ({"decision": "SANITIZED", "sanitizedText": "S", "outputText": "O"}, "S"),
        # outputText is used when sanitizedText is absent.
        ({"decision": "SANITIZED", "outputText": "O"}, "O"),
        # A non-string sanitizedText is skipped in favour of outputText.
        ({"decision": "SANITIZED", "sanitizedText": 123, "outputText": "O"}, "O"),
        # An empty sanitizedText is still honoured.
        ({"decision": "SANITIZED", "sanitizedText": ""}, ""),
        # With neither field the original text is kept.
        ({"decision": "SANITIZED"}, "orig"),
    ],
)
async def test_sanitized_precedence(body, expected):
    """Check which response field supplies the replacement text on SANITIZED."""
    guardrail = _make_guardrail(FakeHandler([_resp(body)]))

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["orig"]},
        request_data={},
        input_type="request",
    )

    assert result["texts"] == [expected]
|
||||
|
||||
|
||||
async def test_blocked_raises_guardrail_exception_with_400():
    """A BLOCKED decision raises GuardrailRaisedException with status 400,
    the guardrail's name, and the service's blockMessage."""
    blocked_reply = _resp({"decision": "BLOCKED", "blockMessage": "nope"})
    guardrail = _make_guardrail(FakeHandler([blocked_reply]))

    with pytest.raises(GuardrailRaisedException) as caught:
        await guardrail.apply_guardrail(
            inputs={"texts": ["bad"]},
            request_data={},
            input_type="request",
        )

    error = caught.value
    assert error.status_code == 400
    assert error.guardrail_name == "vigil-guard"
    assert error.message == "nope"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("body", "expected"),
    [
        # blockMessage wins over every other field.
        (
            {
                "decision": "BLOCKED",
                "blockMessage": "bm",
                "decisionReason": "dr",
                "categories": ["c1"],
            },
            "bm",
        ),
        # A whitespace-only blockMessage falls through to decisionReason.
        ({"decision": "BLOCKED", "blockMessage": " ", "decisionReason": "dr"}, "dr"),
        # decisionReason wins over categories.
        (
            {"decision": "BLOCKED", "decisionReason": "dr", "categories": ["c1", "c2"]},
            "dr",
        ),
        # Categories alone are joined with ", ".
        ({"decision": "BLOCKED", "categories": ["c1", "c2"]}, "c1, c2"),
        # Nothing usable: generic message.
        ({"decision": "BLOCKED"}, "Blocked by policy"),
    ],
)
async def test_block_reason_precedence(body, expected):
    """Check which response field supplies the exception message on BLOCKED."""
    guardrail = _make_guardrail(FakeHandler([_resp(body)]))

    with pytest.raises(GuardrailRaisedException) as caught:
        await guardrail.apply_guardrail(
            inputs={"texts": ["x"]},
            request_data={},
            input_type="request",
        )

    assert caught.value.message == expected
|
||||
|
||||
|
||||
async def test_block_reason_is_clamped_to_500_chars():
    """A 600-character blockMessage shows up truncated to 500 characters."""
    oversized_reply = _resp({"decision": "BLOCKED", "blockMessage": "x" * 600})
    guardrail = _make_guardrail(FakeHandler([oversized_reply]))

    with pytest.raises(GuardrailRaisedException) as caught:
        await guardrail.apply_guardrail(
            inputs={"texts": ["x"]},
            request_data={},
            input_type="request",
        )

    reported = caught.value.message
    # A run of 500 is present, a run of 501 is not.
    assert "x" * 500 in reported
    assert "x" * 501 not in reported
|
||||
|
||||
|
||||
async def test_empty_and_whitespace_texts_skip_analyze():
    """Only the non-blank text is sent to the backend; blank entries are returned as-is."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["", " ", "real"]},
        request_data={},
        input_type="request",
    )

    assert result["texts"] == ["", " ", "real"]
    # Exactly one analyze call, and it carried the real text.
    assert len(fake.calls) == 1
    assert fake.calls[0].json["text"] == "real"
|
||||
|
||||
|
||||
async def test_no_scannable_text_returns_inputs_unchanged():
    """With only blank texts the very same inputs object comes back and no call is made."""
    fake = FakeHandler([])
    guardrail = _make_guardrail(fake)
    inputs = {
        "texts": ["", " "],
        "structured_messages": [{"role": "user"}],
    }

    result = await guardrail.apply_guardrail(
        inputs=inputs, request_data={}, input_type="request"
    )

    assert result is inputs
    assert len(fake.calls) == 0
|
||||
|
||||
|
||||
async def test_multi_text_preserves_length_and_order():
    """Three texts produce three calls; the SANITIZED one is replaced at its own index."""
    queued = [
        _resp({"decision": "ALLOWED"}),
        _resp({"decision": "SANITIZED", "sanitizedText": "B-clean"}),
        _resp({"decision": "ALLOWED"}),
    ]
    fake = FakeHandler(queued)
    guardrail = _make_guardrail(fake)

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["A", "B", "C"]},
        request_data={},
        input_type="request",
    )

    assert result["texts"] == ["A", "B-clean", "C"]
    assert len(fake.calls) == 3
|
||||
|
||||
|
||||
async def test_one_blocked_text_blocks_the_whole_call():
    """A BLOCKED verdict on the second text raises even though the first was ALLOWED."""
    queued = [
        _resp({"decision": "ALLOWED"}),
        _resp({"decision": "BLOCKED", "blockMessage": "bad second"}),
    ]
    guardrail = _make_guardrail(FakeHandler(queued))

    with pytest.raises(GuardrailRaisedException):
        await guardrail.apply_guardrail(
            inputs={"texts": ["ok", "bad"]},
            request_data={},
            input_type="request",
        )
|
||||
|
||||
|
||||
async def test_request_source_is_user_input():
    """For input_type="request" the payload's source field is "user_input"."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="request",
    )

    sent = fake.calls[0].json
    assert sent["source"] == "user_input"
|
||||
|
||||
|
||||
async def test_response_source_is_model_output():
    """For input_type="response" the payload's source field is "model_output"."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="response",
    )

    sent = fake.calls[0].json
    assert sent["source"] == "model_output"
|
||||
|
||||
|
||||
async def test_sanitized_returns_canonical_shape_and_logs_mask():
    """A SANITIZED result returns only texts/images/tools and is logged as "mask"."""
    sanitized = _resp({"decision": "SANITIZED", "sanitizedText": "[REDACTED]"})
    fake = FakeHandler([sanitized])
    guardrail = _make_guardrail(fake)

    tools = [{"type": "function", "function": {"name": "f"}}]
    request_data = {"metadata": {}}
    inputs = {
        "texts": ["my ssn is 123"],
        "images": ["img1"],
        "tools": tools,
        "tool_calls": [{"id": "1"}],
        "structured_messages": [{"role": "user", "content": "my ssn is 123"}],
        "model": "gpt-4o",
    }

    result = await guardrail.apply_guardrail(
        inputs=inputs,
        request_data=request_data,
        input_type="request",
    )

    # Text is replaced, images and tools pass through, every other key is absent.
    assert result["texts"] == ["[REDACTED]"]
    assert result["images"] == ["img1"]
    assert result["tools"] == tools
    assert set(result.keys()) == {"texts", "images", "tools"}

    logged = request_data["metadata"]["standard_logging_guardrail_information"]
    assert logged[0]["guardrail_response"] == "mask"
|
||||
|
||||
|
||||
async def test_empty_images_and_tools_are_preserved_when_present():
    """Empty "images" and "tools" lists supplied by the caller are kept in the output."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)
    inputs = {"texts": ["x"], "images": [], "tools": []}

    result = await guardrail.apply_guardrail(
        inputs=inputs,
        request_data={},
        input_type="request",
    )

    assert set(result.keys()) == {"texts", "images", "tools"}
    assert result["images"] == []
    assert result["tools"] == []
|
||||
|
||||
|
||||
async def test_logging_obj_none_supported():
    """Passing logging_obj=None explicitly still yields the unchanged texts."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="request",
        logging_obj=None,
    )

    assert result["texts"] == ["x"]
|
||||
|
||||
|
||||
async def test_standard_guardrail_logging_remains_active():
    """An ALLOWED call records exactly one "success" entry for "vigil-guard" in metadata."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)
    request_data = {"metadata": {}}

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data=request_data,
        input_type="request",
    )

    logged = request_data["metadata"]["standard_logging_guardrail_information"]
    assert len(logged) == 1
    first = logged[0]
    assert first["guardrail_name"] == "vigil-guard"
    assert first["guardrail_status"] == "success"
|
||||
|
||||
|
||||
async def test_request_url_headers_and_body():
    """The outgoing call targets /v1/guard/analyze with bearer auth and the expected JSON keys."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(
        fake, api_base="https://vigil.test", api_key="vg_secret"
    )

    await guardrail.apply_guardrail(
        inputs={"texts": ["hello"]},
        request_data={},
        input_type="request",
    )

    sent = fake.calls[0]

    # Endpoint and headers
    assert sent.url == "https://vigil.test/v1/guard/analyze"
    assert sent.headers["Authorization"] == "Bearer vg_secret"
    assert sent.headers["Content-Type"] == "application/json"

    # Body
    assert sent.json["text"] == "hello"
    assert sent.json["mode"] == "full"
    assert set(sent.json.keys()) == {"text", "source", "mode", "metadata"}
    assert "metadata" in sent.json
|
||||
|
||||
|
||||
async def test_default_timeout_forwarded_when_unset():
    """Without a timeout argument the guardrail holds and forwards _DEFAULT_VIGIL_TIMEOUT."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)
    assert guardrail.timeout == _DEFAULT_VIGIL_TIMEOUT

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="request",
    )

    forwarded = fake.calls[0].timeout
    assert forwarded == _DEFAULT_VIGIL_TIMEOUT
|
||||
|
||||
|
||||
async def test_configured_timeout_forwarded_to_handler():
    """timeout=30 becomes httpx.Timeout(30, connect=5.0) on the guardrail and on the call."""
    want = httpx.Timeout(30, connect=5.0)
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake, timeout=30)
    assert guardrail.timeout == want

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="request",
    )

    assert fake.calls[0].timeout == want
|
||||
|
||||
|
||||
def test_short_timeout_caps_connect():
    """With timeout=2 the connect timeout is 2.0 as well."""
    guardrail = _make_guardrail(FakeHandler([]), timeout=2)
    want = httpx.Timeout(2, connect=2.0)
    assert guardrail.timeout == want
|
||||
|
||||
|
||||
def test_initialize_guardrail_forwards_timeout():
    """A string timeout ("30") in LitellmParams reaches the callback as httpx.Timeout(30, connect=5.0)."""
    params = LitellmParams(
        guardrail="vigil_guard",
        mode="pre_call",
        api_base="https://vigil.test",
        api_key="k",
        timeout="30",
    )
    callback = initialize_guardrail(params, {"guardrail_name": "vg"})
    want = httpx.Timeout(30, connect=5.0)
    assert callback.timeout == want
|
||||
|
||||
|
||||
async def test_api_key_only_in_header_never_in_payload():
    """The API key is sent in the Authorization header and is absent from the JSON body."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake, api_key="super_secret_key")

    await guardrail.apply_guardrail(
        inputs={"texts": ["hello"]},
        request_data={"metadata": {"user_id": "u"}},
        input_type="request",
    )

    sent = fake.calls[0]
    serialized_body = json.dumps(sent.json)
    assert "super_secret_key" not in serialized_body
    assert sent.headers["Authorization"] == "Bearer super_secret_key"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("code", [429, 502, 503, 504])
async def test_retry_once_on_transient_status(code):
    """A 429/502/503/504 first response is followed by a second, successful attempt."""
    first_attempt = _resp({}, status_code=code)
    second_attempt = _resp({"decision": "ALLOWED"})
    fake = FakeHandler([first_attempt, second_attempt])
    guardrail = _make_guardrail(fake)

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="request",
    )

    assert result["texts"] == ["x"]
    assert len(fake.calls) == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exc", _transient_exceptions())
async def test_retry_once_on_transient_exception(exc):
    """Each exception from _transient_exceptions() triggers one retry that then succeeds."""
    fake = FakeHandler([exc, _resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="request",
    )

    assert result["texts"] == ["x"]
    assert len(fake.calls) == 2
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "exc, expected",
    [
        # a plain RuntimeError is expected to propagate as itself
        (RuntimeError("boom"), RuntimeError),
        # an httpx.WriteError is expected to surface as a guardrail exception
        (
            httpx.WriteError("boom", request=httpx.Request("POST", _ENDPOINT)),
            GuardrailRaisedException,
        ),
    ],
)
async def test_no_retry_on_non_transient_exception(exc, expected):
    """These exceptions end the call after a single attempt with the expected exception type."""
    fake = FakeHandler([exc])
    guardrail = _make_guardrail(fake)

    with pytest.raises(expected):
        await guardrail.apply_guardrail(
            inputs={"texts": ["x"]},
            request_data={},
            input_type="request",
        )

    assert len(fake.calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize("code", [400, 401, 403, 404, 422])
async def test_no_retry_on_non_429_4xx(code):
    """A 4xx status other than 429 raises a 400 guardrail exception after a single call."""
    fake = FakeHandler([_resp({}, status_code=code)])
    guardrail = _make_guardrail(fake)

    with pytest.raises(GuardrailRaisedException) as exc_info:
        await guardrail.apply_guardrail(
            inputs={"texts": ["x"]},
            request_data={},
            input_type="request",
        )

    assert exc_info.value.status_code == 400
    assert len(fake.calls) == 1
|
||||
|
||||
|
||||
async def test_fail_closed_raises_after_exhausted_retry(caplog):
    """Two consecutive 503s under the default fallback raise a 400 and log "fail_closed"."""
    fake = FakeHandler([_resp({}, status_code=503), _resp({}, status_code=503)])
    guardrail = _make_guardrail(fake)

    with caplog.at_level(logging.ERROR):
        with pytest.raises(GuardrailRaisedException) as exc_info:
            await guardrail.apply_guardrail(
                inputs={"texts": ["x"]},
                request_data={},
                input_type="request",
            )

    assert exc_info.value.status_code == 400
    assert len(fake.calls) == 2
    # The error log names both the fallback mode and the guardrail.
    assert any("fail_closed" in record.message for record in caplog.records)
    assert any("vigil-guard" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exc", _transient_exceptions())
async def test_fail_closed_raises_controlled_block_on_transport_error(exc, caplog):
    """A transient exception on both attempts becomes a 400 guardrail exception chained to it."""
    fake = FakeHandler([exc, exc])
    guardrail = _make_guardrail(fake)

    with caplog.at_level(logging.ERROR):
        with pytest.raises(GuardrailRaisedException) as exc_info:
            await guardrail.apply_guardrail(
                inputs={"texts": ["x"]},
                request_data={},
                input_type="request",
            )

    raised = exc_info.value
    assert raised.status_code == 400
    assert raised.guardrail_name == "vigil-guard"
    # The original transport error is kept as the explicit cause.
    assert raised.__cause__ is exc
    assert any("fail_closed" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
async def test_fail_open_returns_inputs_unchanged_on_backend_error(caplog):
    """With fail_open, two 503s return equal inputs in a new object, logged as "allow"."""
    fake = FakeHandler([_resp({}, status_code=503), _resp({}, status_code=503)])
    guardrail = _make_guardrail(fake, unreachable_fallback="fail_open")

    structured = [{"role": "user", "content": "x"}]
    request_data = {"metadata": {}}
    inputs = {"texts": ["x"], "structured_messages": structured}

    with caplog.at_level(logging.ERROR):
        result = await guardrail.apply_guardrail(
            inputs=inputs,
            request_data=request_data,
            input_type="request",
        )

    # Same content, but not the caller's own dict object.
    assert result is not inputs
    assert result["texts"] == ["x"]
    assert result["structured_messages"] == structured
    assert len(fake.calls) == 2

    assert any("fail_open" in record.message for record in caplog.records)
    assert any("vigil-guard" in record.message for record in caplog.records)

    logged = request_data["metadata"]["standard_logging_guardrail_information"]
    assert logged[0]["guardrail_response"] == "allow"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("exc", [ssl.SSLError("tls failed"), OSError("network down")])
async def test_fail_open_returns_inputs_unchanged_on_transport_error(exc):
    """With fail_open, an SSLError/OSError yields the same texts in a new object after one call."""
    fake = FakeHandler([exc])
    guardrail = _make_guardrail(fake, unreachable_fallback="fail_open")
    inputs = {"texts": ["x"]}

    result = await guardrail.apply_guardrail(
        inputs=inputs, request_data={}, input_type="request"
    )

    assert result is not inputs
    assert result["texts"] == ["x"]
    assert len(fake.calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "exc",
    [
        TypeError("bug"),
        KeyError("bug"),
        AttributeError("bug"),
    ],
)
async def test_fail_open_does_not_swallow_programming_errors(exc):
    """Even with fail_open, TypeError/KeyError/AttributeError propagate after one call."""
    fake = FakeHandler([exc])
    guardrail = _make_guardrail(fake, unreachable_fallback="fail_open")
    expected_type = type(exc)

    with pytest.raises(expected_type):
        await guardrail.apply_guardrail(
            inputs={"texts": ["x"]},
            request_data={},
            input_type="request",
        )

    assert len(fake.calls) == 1
|
||||
|
||||
|
||||
async def test_invalid_decision_fail_closed_raises(caplog):
    """An unknown decision raises a 400; the raw value is logged but kept out of the message."""
    fake = FakeHandler([_resp({"decision": "MAYBE"})])
    guardrail = _make_guardrail(fake)

    with caplog.at_level(logging.ERROR):
        with pytest.raises(GuardrailRaisedException) as exc_info:
            await guardrail.apply_guardrail(
                inputs={"texts": ["x"]},
                request_data={},
                input_type="request",
            )

    raised = exc_info.value
    assert raised.status_code == 400
    assert "MAYBE" not in raised.message
    assert any("MAYBE" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
async def test_invalid_decision_fail_open_returns_inputs():
    """With fail_open, an unknown decision lets the original texts through."""
    fake = FakeHandler([_resp({"decision": "MAYBE"})])
    guardrail = _make_guardrail(fake, unreachable_fallback="fail_open")

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="request",
    )

    assert result["texts"] == ["x"]
|
||||
|
||||
|
||||
async def test_fail_open_multi_text_preserves_earlier_sanitization():
    """With fail_open, a sanitized first text is kept when the second text's scan fails twice."""
    queued = [
        _resp({"decision": "SANITIZED", "sanitizedText": "[REDACTED]"}),
        _resp({}, status_code=503),
        _resp({}, status_code=503),
    ]
    fake = FakeHandler(queued)
    guardrail = _make_guardrail(fake, unreachable_fallback="fail_open")
    request_data = {"metadata": {}}

    result = await guardrail.apply_guardrail(
        inputs={"texts": ["my ssn is 123", "second"]},
        request_data=request_data,
        input_type="request",
    )

    assert result["texts"] == ["[REDACTED]", "second"]
    # One call for the first text, two for the second.
    assert len(fake.calls) == 3

    logged = request_data["metadata"]["standard_logging_guardrail_information"]
    assert logged[0]["guardrail_response"] == "mask"
|
||||
|
||||
|
||||
def _tool_call(arguments, name="f", tc_id="1"):
|
||||
return {
|
||||
"id": tc_id,
|
||||
"type": "function",
|
||||
"function": {"name": name, "arguments": arguments},
|
||||
}
|
||||
|
||||
|
||||
async def test_response_tool_call_arguments_allowed_unchanged():
    """On a response, tool-call arguments are sent as model_output and returned as-is when ALLOWED."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)
    tool_calls = [_tool_call('{"q": "weather"}')]

    result = await guardrail.apply_guardrail(
        inputs={"texts": [], "tool_calls": tool_calls},
        request_data={},
        input_type="response",
    )

    sent = fake.calls[0].json
    assert sent["text"] == '{"q": "weather"}'
    assert sent["source"] == "model_output"
    assert result["tool_calls"] == tool_calls
|
||||
|
||||
|
||||
async def test_response_tool_call_arguments_sanitized_in_place():
    """Sanitized arguments replace the originals in the output; the caller's inputs are untouched."""
    sanitized = _resp(
        {"decision": "SANITIZED", "sanitizedText": '{"email": "[EMAIL]"}'}
    )
    fake = FakeHandler([sanitized])
    guardrail = _make_guardrail(fake)

    tool_calls = [_tool_call('{"email": "john@example.com"}', name="send_mail")]
    inputs = {"texts": [], "tool_calls": tool_calls}

    result = await guardrail.apply_guardrail(
        inputs=inputs, request_data={}, input_type="response"
    )

    returned_fn = result["tool_calls"][0]["function"]
    assert returned_fn["arguments"] == '{"email": "[EMAIL]"}'
    assert returned_fn["name"] == "send_mail"

    # original inputs are not mutated in place
    assert inputs["tool_calls"][0]["function"]["arguments"] == (
        '{"email": "john@example.com"}'
    )
|
||||
|
||||
|
||||
async def test_response_tool_call_arguments_blocked_raises():
    """A BLOCKED verdict on tool-call arguments raises a 400 with the backend's message."""
    blocked = _resp({"decision": "BLOCKED", "blockMessage": "tool blocked"})
    guardrail = _make_guardrail(FakeHandler([blocked]))
    tool_calls = [_tool_call('{"x": "bad"}')]

    with pytest.raises(GuardrailRaisedException) as exc_info:
        await guardrail.apply_guardrail(
            inputs={"texts": [], "tool_calls": tool_calls},
            request_data={},
            input_type="response",
        )

    raised = exc_info.value
    assert raised.status_code == 400
    assert raised.message == "tool blocked"
|
||||
|
||||
|
||||
async def test_request_tool_calls_are_not_scanned():
    """On a request, only the text is analyzed; tool calls cause no extra backend call."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)
    tool_calls = [_tool_call('{"x": "y"}')]

    await guardrail.apply_guardrail(
        inputs={"texts": ["hello"], "tool_calls": tool_calls},
        request_data={},
        input_type="request",
    )

    assert len(fake.calls) == 1
    assert fake.calls[0].json["text"] == "hello"
|
||||
|
||||
|
||||
async def test_tool_call_scan_backend_failure_fail_closed_raises():
    """Two 503s while scanning tool-call arguments raise a 400 under the default fallback."""
    fake = FakeHandler([_resp({}, status_code=503), _resp({}, status_code=503)])
    guardrail = _make_guardrail(fake)
    tool_calls = [_tool_call('{"x": "y"}')]

    with pytest.raises(GuardrailRaisedException) as exc_info:
        await guardrail.apply_guardrail(
            inputs={"texts": [], "tool_calls": tool_calls},
            request_data={},
            input_type="response",
        )

    assert exc_info.value.status_code == 400
    assert len(fake.calls) == 2
|
||||
|
||||
|
||||
async def test_tool_call_scan_backend_failure_fail_open_passes_through():
    """With fail_open, two 503s while scanning tool calls return the tool calls unchanged."""
    fake = FakeHandler([_resp({}, status_code=503), _resp({}, status_code=503)])
    guardrail = _make_guardrail(fake, unreachable_fallback="fail_open")
    tool_calls = [_tool_call('{"x": "y"}')]

    result = await guardrail.apply_guardrail(
        inputs={"texts": [], "tool_calls": tool_calls},
        request_data={},
        input_type="response",
    )

    assert result["tool_calls"] == tool_calls
|
||||
|
||||
|
||||
async def test_response_tool_call_unrecognized_decision_fail_closed_raises():
    """An unknown decision for tool-call arguments raises a 400 under the default fallback."""
    fake = FakeHandler([_resp({"decision": "MAYBE"})])
    guardrail = _make_guardrail(fake)
    tool_calls = [_tool_call('{"x": "y"}')]

    with pytest.raises(GuardrailRaisedException) as exc_info:
        await guardrail.apply_guardrail(
            inputs={"texts": [], "tool_calls": tool_calls},
            request_data={},
            input_type="response",
        )

    assert exc_info.value.status_code == 400
|
||||
|
||||
|
||||
async def test_response_tool_call_unrecognized_decision_fail_open_passes_through():
    """With fail_open, an unknown decision for tool-call arguments returns them unchanged."""
    fake = FakeHandler([_resp({"decision": "MAYBE"})])
    guardrail = _make_guardrail(fake, unreachable_fallback="fail_open")
    tool_calls = [_tool_call('{"x": "y"}')]

    result = await guardrail.apply_guardrail(
        inputs={"texts": [], "tool_calls": tool_calls},
        request_data={},
        input_type="response",
    )

    assert result["tool_calls"] == tool_calls
|
||||
|
||||
|
||||
async def test_metadata_allowlist_and_clamping():
    """Forwarded metadata drops unlisted keys and booleans, and clamps long strings and lists."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)
    request_data = {
        "model": "gpt-4o",
        "metadata": {
            "user_id": "u1",
            "tenant_id": "t1",
            "secret_unlisted": "should_not_forward",
            "session_id": "s" * 600,
            "org_id": ["a"] * 20,
            "request_id": True,
            "conversation_id": 7,
        },
    }

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data=request_data,
        input_type="request",
    )

    forwarded = fake.calls[0].json["metadata"]

    # Values forwarded verbatim
    assert forwarded["model"] == "gpt-4o"
    assert forwarded["user_id"] == "u1"
    assert forwarded["tenant_id"] == "t1"

    # Key that must not be forwarded
    assert "secret_unlisted" not in forwarded

    # 600-char string arrives as 500 chars, 20-item list arrives as 10 items
    assert len(forwarded["session_id"]) == 500
    assert len(forwarded["org_id"]) == 10

    # The boolean value is dropped, the integer is kept
    assert "request_id" not in forwarded
    assert forwarded["conversation_id"] == 7
|
||||
|
||||
|
||||
async def test_metadata_source_precedence_and_litellm_metadata_fallback():
    """A top-level user_id beats the nested one; tenant_id is taken from litellm_metadata."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)
    request_data = {
        "user_id": "top",
        "metadata": {"user_id": "nested"},
        "litellm_metadata": {"tenant_id": "lm-tenant"},
    }

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data=request_data,
        input_type="request",
    )

    forwarded = fake.calls[0].json["metadata"]
    assert forwarded["user_id"] == "top"
    assert forwarded["tenant_id"] == "lm-tenant"
|
||||
|
||||
|
||||
async def test_metadata_uses_later_source_when_earlier_value_is_unclampable():
    """When the top-level user_id is a dict, the nested metadata user_id is forwarded instead."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)
    request_data = {
        "user_id": {"drop": "dicts are not forwarded"},
        "metadata": {"user_id": "nested"},
    }

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data=request_data,
        input_type="request",
    )

    forwarded = fake.calls[0].json["metadata"]
    assert forwarded["user_id"] == "nested"
|
||||
|
||||
|
||||
async def test_metadata_array_items_are_clamped_and_filtered():
    """In a list value, long strings are cut to 500 chars and bool/dict/None items are dropped."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)
    mixed_items = ["z" * 600, 123, True, {"drop": 1}, None]
    request_data = {
        "metadata": {
            "org_id": mixed_items,
        },
    }

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data=request_data,
        input_type="request",
    )

    forwarded = fake.calls[0].json["metadata"]
    assert forwarded["org_id"] == ["z" * 500, 123]
|
||||
|
||||
|
||||
async def test_metadata_array_with_no_supported_items_is_dropped():
    """A list holding only a dict and None leaves its key out of the forwarded metadata."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)
    request_data = {"metadata": {"org_id": [{"drop": 1}, None]}}

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data=request_data,
        input_type="request",
    )

    forwarded = fake.calls[0].json["metadata"]
    assert "org_id" not in forwarded
|
||||
|
||||
|
||||
async def test_call_id_forwarded_from_logging_obj():
    """litellm_call_id on the logging object is forwarded in the payload metadata."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)
    logging_stub = SimpleNamespace(litellm_call_id="call-123")

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={},
        input_type="request",
        logging_obj=logging_stub,
    )

    forwarded = fake.calls[0].json["metadata"]
    assert forwarded["litellm_call_id"] == "call-123"
|
||||
|
||||
|
||||
async def test_call_id_forwarded_from_request_data():
    """Without a logging object, a top-level litellm_call_id in request_data is forwarded."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)
    request_data = {"litellm_call_id": "rd-1"}

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data=request_data,
        input_type="request",
        logging_obj=None,
    )

    forwarded = fake.calls[0].json["metadata"]
    assert forwarded["litellm_call_id"] == "rd-1"
|
||||
|
||||
|
||||
async def test_call_id_forwarded_from_request_metadata():
    """Without a logging object, litellm_call_id nested under metadata is forwarded."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)
    request_data = {"metadata": {"litellm_call_id": "md-1"}}

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data=request_data,
        input_type="request",
        logging_obj=None,
    )

    forwarded = fake.calls[0].json["metadata"]
    assert forwarded["litellm_call_id"] == "md-1"
|
||||
|
||||
|
||||
async def test_call_id_logging_obj_takes_precedence():
    """When both sources carry a call id, the logging object's value is the one forwarded."""
    fake = FakeHandler([_resp({"decision": "ALLOWED"})])
    guardrail = _make_guardrail(fake)
    logging_stub = SimpleNamespace(litellm_call_id="log-1")

    await guardrail.apply_guardrail(
        inputs={"texts": ["x"]},
        request_data={"litellm_call_id": "rd-1"},
        input_type="request",
        logging_obj=logging_stub,
    )

    forwarded = fake.calls[0].json["metadata"]
    assert forwarded["litellm_call_id"] == "log-1"
|
||||
|
||||
|
||||
def test_enum_value():
    """The VIGIL_GUARD enum member has the string value "vigil_guard"."""
    member = SupportedGuardrailIntegrations.VIGIL_GUARD
    assert member.value == "vigil_guard"
|
||||
|
||||
|
||||
def test_config_model_ui_name_and_instantiation():
    """The config model reports "Vigil Guard" as its UI name and stores api_base."""
    assert VigilGuardGuardrailConfigModel.ui_friendly_name() == "Vigil Guard"

    config = VigilGuardGuardrailConfigModel(api_base="https://x", api_key="k")
    assert config.api_base == "https://x"
|
||||
|
||||
|
||||
def test_get_config_model_returns_config_model():
    """get_config_model() returns the VigilGuardGuardrailConfigModel class itself."""
    guardrail = _make_guardrail(FakeHandler([]))
    returned = guardrail.get_config_model()
    assert returned is VigilGuardGuardrailConfigModel
|
||||
|
||||
|
||||
def test_registries_expose_initializer_and_class():
    """Both registries have a "vigil_guard" entry; the class registry maps it to VigilGuardGuardrail."""
    key = "vigil_guard"
    assert key in guardrail_initializer_registry
    assert guardrail_class_registry[key] is VigilGuardGuardrail
|
||||
|
||||
|
||||
def test_litellm_params_includes_config_model():
    """VigilGuardGuardrailConfigModel appears in the MRO of LitellmParams."""
    ancestry = LitellmParams.__mro__
    assert VigilGuardGuardrailConfigModel in ancestry
|
||||
|
||||
|
||||
def test_config_driven_initialization_creates_callback():
    """initialize_guardrail builds a VigilGuardGuardrail whose fallback is "fail_closed"."""
    params = LitellmParams(
        guardrail="vigil_guard",
        mode="pre_call",
        api_base="https://vigil.test",
        api_key="k",
    )

    callback = initialize_guardrail(params, {"guardrail_name": "vg"})

    assert isinstance(callback, VigilGuardGuardrail)
    assert callback.unreachable_fallback == "fail_closed"
|
||||
@ -0,0 +1,213 @@
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
|
||||
class TestContentFilterPathTraversal:
|
||||
"""Tests that _resolve_category_file_path rejects path traversal."""
|
||||
|
||||
def _get_guardrail(self):
    """Return a ContentFilterGuardrail created via ``__new__``, so ``__init__`` never runs."""
    from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import (
        ContentFilterGuardrail,
    )

    guardrail_cls = ContentFilterGuardrail
    return guardrail_cls.__new__(guardrail_cls)
|
||||
|
||||
def test_traversal_via_relative_dotdot_raises(self):
    """A relative ``../`` path to /etc/passwd is rejected with ValueError."""
    hostile_path = "../../../../etc/passwd"
    guardrail = self._get_guardrail()
    with pytest.raises(ValueError, match="outside the allowed categories"):
        guardrail._resolve_category_file_path(hostile_path)
|
||||
|
||||
def test_traversal_via_absolute_path_raises(self):
    """The absolute path /etc/passwd is rejected with ValueError."""
    hostile_path = "/etc/passwd"
    guardrail = self._get_guardrail()
    with pytest.raises(ValueError, match="outside the allowed categories"):
        guardrail._resolve_category_file_path(hostile_path)
|
||||
|
||||
def test_valid_category_file_inside_categories_dir_allowed(self):
    """An existing file inside the module's ``categories`` directory resolves to itself.

    Skipped when harmful_self_harm.yaml is not present.
    """
    guardrail = self._get_guardrail()

    # Locate the ``categories`` directory next to the content_filter module.
    module = __import__(
        "litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter",
        fromlist=["content_filter"],
    )
    module_dir = os.path.dirname(module.__file__)
    categories_dir = os.path.join(module_dir, "categories")
    valid_file = os.path.join(categories_dir, "harmful_self_harm.yaml")

    if not os.path.exists(valid_file):
        pytest.skip("harmful_self_harm.yaml not present in this environment")

    resolved = guardrail._resolve_category_file_path(valid_file)
    assert resolved == valid_file
|
||||
|
||||
def test_invalid_category_name_skipped(self):
    """A category name made of ``../`` segments is not loaded and does not raise."""
    from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import (
        ContentFilterGuardrail,
    )

    hostile_name = "../../etc/passwd"

    # Bare instance: set only the attributes this test assigns.
    guardrail = ContentFilterGuardrail.__new__(ContentFilterGuardrail)
    guardrail.loaded_categories = {}
    guardrail.severity_threshold = "medium"
    guardrail.category_keywords = {}
    guardrail.always_block_category_keywords = {}
    guardrail.conditional_categories = {}

    # category name with path traversal chars must be skipped, not crash
    guardrail._load_categories([{"category": hostile_name, "enabled": True}])

    assert "../../etc/passwd" not in guardrail.loaded_categories
|
||||
|
||||
def test_category_name_with_slash_skipped(self):
    """A category name containing slashes and ``..`` is not loaded and does not raise."""
    from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import (
        ContentFilterGuardrail,
    )

    hostile_name = "foo/../../etc/passwd"

    # Bare instance: set only the attributes this test assigns.
    guardrail = ContentFilterGuardrail.__new__(ContentFilterGuardrail)
    guardrail.loaded_categories = {}
    guardrail.severity_threshold = "medium"
    guardrail.category_keywords = {}
    guardrail.always_block_category_keywords = {}
    guardrail.conditional_categories = {}

    category_configs = [{"category": hostile_name, "enabled": True}]
    guardrail._load_categories(category_configs)

    assert "foo/../../etc/passwd" not in guardrail.loaded_categories
|
||||
|
||||
def test_assert_within_categories_dir_blocks_parent_traversal(self):
    """_assert_within_categories_dir raises ValueError for /etc/passwd against the real categories dir."""
    from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import (
        ContentFilterGuardrail,
    )

    # Locate the ``categories`` directory next to the content_filter module.
    module = __import__(
        "litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter",
        fromlist=["content_filter"],
    )
    module_dir = os.path.dirname(module.__file__)
    categories_dir = os.path.join(module_dir, "categories")

    with pytest.raises(ValueError, match="outside the allowed categories"):
        ContentFilterGuardrail._assert_within_categories_dir(
            "/etc/passwd", categories_dir
        )
|
||||
|
||||
def test_assert_within_categories_dir_allows_valid_file(self, tmp_path):
    """A path directly under the categories directory passes the check without raising."""
    from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import (
        ContentFilterGuardrail,
    )

    allowed_dir = str(tmp_path)
    inside_path = str(tmp_path / "test.yaml")

    # Should not raise
    ContentFilterGuardrail._assert_within_categories_dir(inside_path, allowed_dir)
|
||||
|
||||
def test_assert_within_categories_dir_commonpath_raises_valueerror(self, tmp_path):
    """A ValueError from os.path.commonpath is reported as an outside-the-directory error.

    os.path.commonpath is patched to raise, simulating Windows cross-drive paths.
    """
    from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import (
        ContentFilterGuardrail,
    )

    allowed_dir = str(tmp_path)
    inside_path = str(tmp_path / "test.yaml")
    failing_commonpath = patch(
        "os.path.commonpath", side_effect=ValueError("Paths on different drives")
    )

    with failing_commonpath:
        with pytest.raises(
            ValueError, match="outside the allowed categories directory"
        ):
            ContentFilterGuardrail._assert_within_categories_dir(
                inside_path, allowed_dir
            )
|
||||
|
||||
def test_resolve_category_file_path_direct_join_hit(self):
    """A ``categories/<file>`` relative path resolves without raising.

    Skipped when the categories directory holds no ``.yaml`` file.
    """
    guardrail = self._get_guardrail()

    # Locate the ``categories`` directory next to the content_filter module.
    module = __import__(
        "litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter",
        fromlist=["content_filter"],
    )
    module_dir = os.path.dirname(module.__file__)
    categories_dir = os.path.join(module_dir, "categories")

    yaml_files = [
        entry for entry in os.listdir(categories_dir) if entry.endswith(".yaml")
    ]
    if not yaml_files:
        pytest.skip("No category YAML files present in this environment")

    # "categories/<file>" joined directly to module_dir resolves to an existing file.
    relative_path = os.path.join("categories", yaml_files[0])
    resolved = guardrail._resolve_category_file_path(relative_path)

    assert os.path.isabs(resolved) or os.path.exists(resolved)
|
||||
|
||||
def test_resolve_category_file_path_component_strip_hit(self):
    """A path with an extra leading component before ``categories/<file>`` still resolves.

    Skipped when the categories directory holds no ``.yaml`` file.
    """
    guardrail = self._get_guardrail()

    # Locate the ``categories`` directory next to the content_filter module.
    module = __import__(
        "litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter",
        fromlist=["content_filter"],
    )
    module_dir = os.path.dirname(module.__file__)
    categories_dir = os.path.join(module_dir, "categories")

    yaml_files = [
        entry for entry in os.listdir(categories_dir) if entry.endswith(".yaml")
    ]
    if not yaml_files:
        pytest.skip("No category YAML files present in this environment")

    # Prefix with a fake leading component so the first-join attempt misses,
    # but stripping that component reveals categories/<file> which exists.
    prefixed_path = "some_prefix/categories/" + yaml_files[0]
    resolved = guardrail._resolve_category_file_path(prefixed_path)

    assert os.path.isabs(resolved) or os.path.exists(resolved)
|
||||
|
||||
def test_load_categories_traversal_category_file_skipped(self):
    """A category whose ``category_file`` is a traversal path is not loaded.

    ``_load_categories`` must return normally and leave the category out of
    ``loaded_categories`` rather than crash.
    """
    from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import (
        ContentFilterGuardrail,
    )

    # Bare instance (no __init__); only the attributes set below are populated.
    guardrail = ContentFilterGuardrail.__new__(ContentFilterGuardrail)
    guardrail.loaded_categories = {}
    guardrail.severity_threshold = "medium"
    guardrail.category_keywords = {}
    guardrail.always_block_category_keywords = {}
    guardrail.conditional_categories = {}

    traversal_entry = {
        "category": "valid_name",
        "enabled": True,
        "category_file": "../../../../etc/passwd",
    }
    guardrail._load_categories([traversal_entry])

    assert "valid_name" not in guardrail.loaded_categories
|
||||
|
||||
def test_allow_external_paths_env_var_bypasses_jail(self, tmp_path):
    """With LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS=true an external file resolves."""
    import os as _os

    from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import (
        ContentFilterGuardrail,
    )

    guardrail = ContentFilterGuardrail.__new__(ContentFilterGuardrail)

    # A real file under tmp_path stands in for e.g. a mounted volume.
    outside_file = tmp_path / "external_categories.yaml"
    outside_file.write_text("category_name: test\n")
    outside_path = str(outside_file)

    env_override = {"LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS": "true"}
    with patch.dict(_os.environ, env_override):
        # The path must come back unchanged instead of raising ValueError.
        resolved = guardrail._resolve_category_file_path(outside_path)
        assert resolved == outside_path
|
||||
|
||||
def test_traversal_blocked_when_allow_external_not_set(self):
    """With the allow-external env var absent, ``/etc/passwd`` is rejected."""
    import os as _os

    env_flag = "LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS"
    guardrail = self._get_guardrail()

    # patch.dict restores os.environ when the block exits, so removing the
    # flag here does not leak into other tests.
    with patch.dict(_os.environ, {}, clear=False):
        _os.environ.pop(env_flag, None)
        with pytest.raises(ValueError, match="outside the allowed categories"):
            guardrail._resolve_category_file_path("/etc/passwd")
|
||||
@ -0,0 +1,444 @@
|
||||
"""
|
||||
Unit tests for watsonx_proxy_route endpoint.
|
||||
|
||||
Tests the Watsonx pass-through endpoint that handles automatic IAM token management
|
||||
and version parameter injection.
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException, Request, Response
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
import litellm
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import (
|
||||
watsonx_proxy_route,
|
||||
)
|
||||
|
||||
|
||||
class TestWatsonxProxyRoute:
    """Exercise ``watsonx_proxy_route`` with its collaborators replaced by mocks."""

    # Full dotted patch targets, spelled out so they stay greppable.
    _PROVIDER_CONFIG_TARGET = "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.ProviderConfigManager.get_provider_passthrough_config"
    _CREATE_ROUTE_TARGET = "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route"
    _FORM_DATA_TARGET = "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.get_form_data"

    @staticmethod
    def _build_request(method="POST", query_params=None, headers=None, json_body=None):
        """Return a ``Request`` mock.

        ``headers`` defaults to a JSON content type and ``query_params`` to an
        empty dict. ``request.json`` is only stubbed when ``json_body`` is given.
        """
        request = MagicMock(spec=Request)
        request.method = method
        request.query_params = {} if query_params is None else query_params
        request.headers = (
            {"content-type": "application/json"} if headers is None else headers
        )
        if json_body is not None:
            request.json = AsyncMock(return_value=json_body)
        return request

    @staticmethod
    def _build_provider_config(target_url, env_headers=None):
        """Return a provider-config mock yielding ``target_url`` and auth headers."""
        provider_config = MagicMock()
        provider_config.get_complete_url.return_value = (target_url, {})
        provider_config.validate_environment.return_value = (
            {"Authorization": "Bearer test-iam-token"}
            if env_headers is None
            else env_headers
        )
        return provider_config

    async def _run_route(self, endpoint, request, provider_config, route_return):
        """Call the route with the provider lookup and route factory patched.

        Returns ``(result, create_route_mock, endpoint_func, response, user)``
        so callers can assert on whichever collaborator they care about.
        """
        response = MagicMock(spec=Response)
        user = MagicMock(spec=UserAPIKeyAuth)
        endpoint_func = AsyncMock(return_value=route_return)

        with (
            patch(self._PROVIDER_CONFIG_TARGET, return_value=provider_config),
            patch(
                self._CREATE_ROUTE_TARGET, return_value=endpoint_func
            ) as create_route,
        ):
            result = await watsonx_proxy_route(
                endpoint=endpoint,
                request=request,
                fastapi_response=response,
                user_api_key_dict=user,
            )
        return result, create_route, endpoint_func, response, user

    @pytest.mark.asyncio
    async def test_watsonx_proxy_route_success_non_streaming(self):
        """A JSON POST with ``stream: False`` is forwarded as non-streaming."""
        request = self._build_request(json_body={"stream": False, "input": "test"})
        provider_config = self._build_provider_config(
            "https://us-south.ml.cloud.ibm.com/ml/v1/text/generation"
        )

        result, create_route, endpoint_func, response, user = await self._run_route(
            "ml/v1/text/generation",
            request,
            provider_config,
            {"model_id": "ibm/granite-13b-chat-v2", "results": []},
        )

        # The provider config is consulted exactly once for URL and headers.
        provider_config.get_complete_url.assert_called_once()
        provider_config.validate_environment.assert_called_once()

        # The route factory receives the resolved target, headers and version.
        create_route.assert_called_once()
        route_kwargs = create_route.call_args[1]
        assert route_kwargs["endpoint"] == "ml/v1/text/generation"
        assert (
            route_kwargs["target"]
            == "https://us-south.ml.cloud.ibm.com/ml/v1/text/generation"
        )
        assert (
            route_kwargs["custom_headers"]["Authorization"] == "Bearer test-iam-token"
        )
        assert route_kwargs["is_streaming_request"] is False
        assert route_kwargs["custom_llm_provider"] == "watsonx"
        assert (
            route_kwargs["query_params"]["version"]
            == litellm.WATSONX_DEFAULT_API_VERSION
        )

        # The generated endpoint function is invoked with the original objects.
        endpoint_func.assert_called_once_with(request, response, user)

        assert result == {"model_id": "ibm/granite-13b-chat-v2", "results": []}

    @pytest.mark.asyncio
    async def test_watsonx_proxy_route_success_streaming(self):
        """A JSON POST with ``stream: True`` is forwarded as streaming."""
        request = self._build_request(json_body={"stream": True, "input": "test"})
        provider_config = self._build_provider_config(
            "https://us-south.ml.cloud.ibm.com/ml/v1/text/generation_stream"
        )

        result, create_route, _, _, _ = await self._run_route(
            "ml/v1/text/generation_stream",
            request,
            provider_config,
            "streaming_response",
        )

        create_route.assert_called_once()
        route_kwargs = create_route.call_args[1]
        assert route_kwargs["is_streaming_request"] is True

        assert result == "streaming_response"

    @pytest.mark.asyncio
    async def test_watsonx_proxy_route_get_request(self):
        """A GET request is never treated as a streaming request."""
        request = self._build_request(
            method="GET",
            query_params={"project_id": "test-project"},
            headers={},
        )
        provider_config = self._build_provider_config(
            "https://us-south.ml.cloud.ibm.com/ml/v1/models"
        )

        result, create_route, _, _, _ = await self._run_route(
            "ml/v1/models", request, provider_config, {"resources": []}
        )

        create_route.assert_called_once()
        route_kwargs = create_route.call_args[1]
        assert route_kwargs["is_streaming_request"] is False

        assert result == {"resources": []}

    @pytest.mark.asyncio
    async def test_watsonx_proxy_route_multipart_form_data(self):
        """A multipart POST whose form has ``stream: False`` is non-streaming."""
        request = self._build_request(
            headers={"content-type": "multipart/form-data; boundary=----"}
        )
        response = MagicMock(spec=Response)
        user = MagicMock(spec=UserAPIKeyAuth)

        form_payload = {"file": "test_file", "stream": False}
        provider_config = self._build_provider_config(
            "https://us-south.ml.cloud.ibm.com/ml/v1/text/tokenization"
        )
        endpoint_func = AsyncMock(return_value={"token_count": 10})

        # get_form_data is patched here in addition to the two usual targets.
        with (
            patch(self._PROVIDER_CONFIG_TARGET, return_value=provider_config),
            patch(self._FORM_DATA_TARGET, return_value=form_payload),
            patch(
                self._CREATE_ROUTE_TARGET, return_value=endpoint_func
            ) as create_route,
        ):
            result = await watsonx_proxy_route(
                endpoint="ml/v1/text/tokenization",
                request=request,
                fastapi_response=response,
                user_api_key_dict=user,
            )

        create_route.assert_called_once()
        route_kwargs = create_route.call_args[1]
        assert route_kwargs["is_streaming_request"] is False

        assert result == {"token_count": 10}

    @pytest.mark.asyncio
    async def test_watsonx_proxy_route_no_provider_config(self):
        """A missing provider config surfaces as an HTTP 404."""
        request = self._build_request()
        response = MagicMock(spec=Response)
        user = MagicMock(spec=UserAPIKeyAuth)

        with patch(self._PROVIDER_CONFIG_TARGET, return_value=None):
            with pytest.raises(HTTPException) as exc_info:
                await watsonx_proxy_route(
                    endpoint="ml/v1/text/generation",
                    request=request,
                    fastapi_response=response,
                    user_api_key_dict=user,
                )

        raised = exc_info.value
        assert raised.status_code == 404
        assert raised.detail == "Watsonx passthrough config not found"

    @pytest.mark.asyncio
    async def test_watsonx_proxy_route_version_parameter_injection(self):
        """The default Watsonx API version is added to the forwarded query params."""
        request = self._build_request(json_body={"input": "test"})
        provider_config = self._build_provider_config(
            "https://us-south.ml.cloud.ibm.com/ml/v1/text/generation"
        )

        _, create_route, _, _, _ = await self._run_route(
            "ml/v1/text/generation", request, provider_config, {}
        )

        create_route.assert_called_once()
        route_kwargs = create_route.call_args[1]
        assert "query_params" in route_kwargs
        assert "version" in route_kwargs["query_params"]
        assert (
            route_kwargs["query_params"]["version"]
            == litellm.WATSONX_DEFAULT_API_VERSION
        )

    @pytest.mark.asyncio
    async def test_watsonx_proxy_route_custom_headers_from_validate_environment(self):
        """Every header returned by ``validate_environment`` reaches the route factory."""
        request = self._build_request(json_body={"input": "test"})
        provider_config = self._build_provider_config(
            "https://us-south.ml.cloud.ibm.com/ml/v1/text/generation",
            env_headers={
                "Authorization": "Bearer test-iam-token",
                "X-Custom-Header": "custom-value",
            },
        )

        _, create_route, _, _, _ = await self._run_route(
            "ml/v1/text/generation", request, provider_config, {}
        )

        create_route.assert_called_once()
        route_kwargs = create_route.call_args[1]
        assert "custom_headers" in route_kwargs
        forwarded = route_kwargs["custom_headers"]
        assert forwarded["Authorization"] == "Bearer test-iam-token"
        assert forwarded["X-Custom-Header"] == "custom-value"

    @pytest.mark.asyncio
    async def test_watsonx_proxy_route_different_endpoints(self):
        """Several endpoint paths are each forwarded with a matching target URL."""
        endpoints = [
            "ml/v1/text/generation",
            "ml/v1/text/tokenization",
            "ml/v1/deployments/test-deployment/text/generation",
            "ml/v1/models",
        ]

        for endpoint_path in endpoints:
            # Fresh mocks per endpoint so call counts do not bleed across iterations.
            request = self._build_request(json_body={"input": "test"})
            provider_config = self._build_provider_config(
                f"https://us-south.ml.cloud.ibm.com/{endpoint_path}"
            )

            _, create_route, _, _, _ = await self._run_route(
                endpoint_path, request, provider_config, {}
            )

            create_route.assert_called_once()
            route_kwargs = create_route.call_args[1]
            assert route_kwargs["endpoint"] == endpoint_path
            assert (
                route_kwargs["target"]
                == f"https://us-south.ml.cloud.ibm.com/{endpoint_path}"
            )
|
||||
@ -3185,3 +3185,358 @@ async def test_view_spend_logs_date_range_hashes_sk_api_key(client, monkeypatch)
|
||||
assert where["api_key"] == "hashed::sk-raw-admin-token"
|
||||
finally:
|
||||
app.dependency_overrides.pop(ps.user_api_key_auth, None)
|
||||
|
||||
|
||||
class _SpendScopeMockPrismaClient:
    """Prisma-client stand-in that records the queries made against it.

    Every ``get_data`` call is appended to ``get_data_calls`` and every
    ``db.litellm_verificationtoken.find_many`` call to ``find_many_calls``.
    Both answer with the canned values supplied at construction time.
    """

    def __init__(self, get_data_returns=None, find_many_returns=None):
        self._get_data_returns = [] if get_data_returns is None else get_data_returns
        self._find_many_returns = (
            [] if find_many_returns is None else find_many_returns
        )
        self.get_data_calls = []
        self.find_many_calls = []

        # The nested table class closes over the owning client so that it can
        # record calls on it and serve its canned rows.
        owner = self

        class _VerificationTokenTable:
            async def find_many(self, where=None, order=None, include=None):
                recorded = {"where": where, "order": order, "include": include}
                owner.find_many_calls.append(recorded)
                return owner._find_many_returns

        class _DB:
            def __init__(self):
                self.litellm_verificationtoken = _VerificationTokenTable()

        self.db = _DB()

    async def get_data(self, table_name=None, query_type=None, **kwargs):
        """Record the call; return the first canned row for ``find_unique``, else all."""
        recorded = {"table_name": table_name, "query_type": query_type, **kwargs}
        self.get_data_calls.append(recorded)

        canned = self._get_data_returns
        if query_type != "find_unique":
            return canned
        if canned:
            return canned[0]
        return None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_spend_key_fn_proxy_admin_returns_all_keys(client, monkeypatch):
    """A PROXY_ADMIN caller gets the ``find_all`` view of /spend/keys."""
    all_keys = [
        {"token": "hashed-a", "user_id": "alice", "spend": 10.0},
        {"token": "hashed-b", "user_id": "bob", "spend": 5.0},
    ]
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=all_keys)
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _admin_auth():
        return UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN, user_id="admin")

    app.dependency_overrides[ps.user_api_key_auth] = _admin_auth
    try:
        resp = client.get("/spend/keys", headers={"Authorization": "Bearer sk-test"})
        assert resp.status_code == 200

        # Admin path: one get_data call on the key table, never find_many.
        recorded = prisma_stub.get_data_calls
        assert len(recorded) == 1
        assert recorded[0]["table_name"] == "key"
        assert recorded[0]["query_type"] == "find_all"
        assert prisma_stub.find_many_calls == []
        assert resp.json() == all_keys
    finally:
        # Always drop the override so later tests see the real dependency.
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_spend_key_fn_proxy_admin_view_only_returns_all_keys(client, monkeypatch):
    """PROXY_ADMIN_VIEW_ONLY takes the same get_data path as a full admin."""
    all_keys = [{"token": "hashed-a", "user_id": "alice"}]
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=all_keys)
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _view_only_admin_auth():
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, user_id="admin_viewer"
        )

    app.dependency_overrides[ps.user_api_key_auth] = _view_only_admin_auth
    try:
        resp = client.get("/spend/keys", headers={"Authorization": "Bearer sk-test"})
        assert resp.status_code == 200
        assert prisma_stub.find_many_calls == []
        assert len(prisma_stub.get_data_calls) == 1
    finally:
        # Always drop the override so later tests see the real dependency.
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "role",
    [LitellmUserRoles.INTERNAL_USER, LitellmUserRoles.INTERNAL_USER_VIEW_ONLY],
)
async def test_spend_key_fn_internal_user_scoped_to_own_keys(client, monkeypatch, role):
    """For either internal-user role, /spend/keys is queried with the caller's user_id."""
    own_keys = [
        {"token": "hashed-mine-1", "user_id": "alice", "spend": 2.0},
        {"token": "hashed-mine-2", "user_id": "alice", "spend": 1.0},
    ]
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=own_keys)
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _internal_user_auth():
        return UserAPIKeyAuth(user_role=role, user_id="alice")

    app.dependency_overrides[ps.user_api_key_auth] = _internal_user_auth
    try:
        resp = client.get("/spend/keys", headers={"Authorization": "Bearer sk-test"})
        assert resp.status_code == 200

        # Same get_data helper as the admin path, but carrying a user_id
        # scope so only the caller's rows are requested.
        assert prisma_stub.find_many_calls == []
        assert len(prisma_stub.get_data_calls) == 1
        recorded = prisma_stub.get_data_calls[0]
        assert recorded["table_name"] == "key"
        assert recorded["query_type"] == "find_all"
        assert recorded["user_id"] == "alice"
        assert resp.json() == own_keys
    finally:
        # Always drop the override so later tests see the real dependency.
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_spend_key_fn_internal_user_without_user_id_returns_empty(
    client, monkeypatch
):
    """
    A non-admin caller without a user_id has nothing to scope the query by.
    The endpoint must answer with an empty list and must not touch the
    database at all, otherwise the full table would leak again.
    """
    prisma_stub = _SpendScopeMockPrismaClient(
        get_data_returns=[{"token": "do-not-leak"}],
        find_many_returns=[{"token": "do-not-leak"}],
    )
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _anonymous_internal_auth():
        return UserAPIKeyAuth(user_role=LitellmUserRoles.INTERNAL_USER, user_id=None)

    app.dependency_overrides[ps.user_api_key_auth] = _anonymous_internal_auth
    try:
        resp = client.get("/spend/keys", headers={"Authorization": "Bearer sk-test"})
        assert resp.status_code == 200
        assert resp.json() == []
        assert prisma_stub.get_data_calls == []
        assert prisma_stub.find_many_calls == []
    finally:
        # Always drop the override so later tests see the real dependency.
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_spend_user_fn_proxy_admin_returns_all_users_without_user_id(
    client, monkeypatch
):
    """A PROXY_ADMIN calling /spend/users with no user_id gets the ``find_all`` view."""
    all_users = [
        {"user_id": "alice", "user_email": "alice@example.com", "spend": 1.0},
        {"user_id": "bob", "user_email": "bob@example.com", "spend": 2.0},
    ]
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=all_users)
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _admin_auth():
        return UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN, user_id="admin")

    app.dependency_overrides[ps.user_api_key_auth] = _admin_auth
    try:
        resp = client.get("/spend/users", headers={"Authorization": "Bearer sk-test"})
        assert resp.status_code == 200

        recorded = prisma_stub.get_data_calls
        assert len(recorded) == 1
        assert recorded[0]["table_name"] == "user"
        assert recorded[0]["query_type"] == "find_all"
        assert resp.json() == all_users
    finally:
        # Always drop the override so later tests see the real dependency.
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_spend_user_fn_proxy_admin_can_query_specific_user_id(
    client, monkeypatch
):
    """A PROXY_ADMIN may pass any user_id and gets a ``find_unique`` lookup for it."""
    target_user = {
        "user_id": "carol",
        "user_email": "carol@example.com",
        "spend": 7.0,
    }
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=[target_user])
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _admin_auth():
        return UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN, user_id="admin")

    app.dependency_overrides[ps.user_api_key_auth] = _admin_auth
    try:
        resp = client.get(
            "/spend/users",
            params={"user_id": "carol"},
            headers={"Authorization": "Bearer sk-test"},
        )
        assert resp.status_code == 200

        recorded = prisma_stub.get_data_calls
        assert len(recorded) == 1
        assert recorded[0]["query_type"] == "find_unique"
        assert recorded[0]["user_id"] == "carol"
        assert resp.json() == [target_user]
    finally:
        # Always drop the override so later tests see the real dependency.
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "role",
    [LitellmUserRoles.INTERNAL_USER, LitellmUserRoles.INTERNAL_USER_VIEW_ONLY],
)
async def test_spend_user_fn_internal_user_scoped_without_user_id(
    client, monkeypatch, role
):
    """With no user_id in the request, an internal user is looked up by their own id."""
    caller_row = {"user_id": "alice", "user_email": "alice@example.com", "spend": 3.0}
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=[caller_row])
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _internal_user_auth():
        return UserAPIKeyAuth(user_role=role, user_id="alice")

    app.dependency_overrides[ps.user_api_key_auth] = _internal_user_auth
    try:
        resp = client.get("/spend/users", headers={"Authorization": "Bearer sk-test"})
        assert resp.status_code == 200

        recorded = prisma_stub.get_data_calls
        assert len(recorded) == 1
        assert recorded[0]["query_type"] == "find_unique"
        assert recorded[0]["user_id"] == "alice"
        assert resp.json() == [caller_row]
    finally:
        # Always drop the override so later tests see the real dependency.
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_spend_user_fn_internal_user_supplying_other_user_id_returns_403(
    client, monkeypatch
):
    """
    An internal user asking for someone else's user_id gets a 403 and no
    database query is issued. The request is rejected rather than silently
    rewritten to the caller's own id, so the attempt shows up in logs.
    """
    victim_row = {
        "user_id": "victim",
        "user_email": "victim@example.com",
        "spend": 999.0,
    }
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=[victim_row])
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _internal_user_auth():
        return UserAPIKeyAuth(user_role=LitellmUserRoles.INTERNAL_USER, user_id="alice")

    app.dependency_overrides[ps.user_api_key_auth] = _internal_user_auth
    try:
        resp = client.get(
            "/spend/users",
            params={"user_id": "victim"},
            headers={"Authorization": "Bearer sk-test"},
        )
        assert resp.status_code == 403
        assert prisma_stub.get_data_calls == []
    finally:
        # Always drop the override so later tests see the real dependency.
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_spend_user_fn_internal_user_supplying_own_user_id_is_allowed(
    client, monkeypatch
):
    """
    Explicitly passing your own user_id succeeds: the request returns 200
    and a ``find_unique`` lookup is made for that id.
    """
    caller_row = {"user_id": "alice", "user_email": "alice@example.com", "spend": 3.0}
    prisma_stub = _SpendScopeMockPrismaClient(get_data_returns=[caller_row])
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _internal_user_auth():
        return UserAPIKeyAuth(user_role=LitellmUserRoles.INTERNAL_USER, user_id="alice")

    app.dependency_overrides[ps.user_api_key_auth] = _internal_user_auth
    try:
        resp = client.get(
            "/spend/users",
            params={"user_id": "alice"},
            headers={"Authorization": "Bearer sk-test"},
        )
        assert resp.status_code == 200

        recorded = prisma_stub.get_data_calls
        assert len(recorded) == 1
        assert recorded[0]["query_type"] == "find_unique"
        assert recorded[0]["user_id"] == "alice"
        assert resp.json() == [caller_row]
    finally:
        # Always drop the override so later tests see the real dependency.
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_spend_user_fn_internal_user_without_user_id_returns_empty(
    client, monkeypatch
):
    """
    A non-admin caller without a user_id has nothing to scope the query by,
    so /spend/users answers with an empty list and never queries the table.
    This mirrors the defensive contract of /spend/keys.
    """
    prisma_stub = _SpendScopeMockPrismaClient(
        get_data_returns=[{"user_id": "do-not-leak"}]
    )
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", prisma_stub)

    def _anonymous_view_only_auth():
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, user_id=None
        )

    app.dependency_overrides[ps.user_api_key_auth] = _anonymous_view_only_auth
    try:
        resp = client.get("/spend/users", headers={"Authorization": "Bearer sk-test"})
        assert resp.status_code == 200
        assert resp.json() == []
        assert prisma_stub.get_data_calls == []
    finally:
        # Always drop the override so later tests see the real dependency.
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_spend_user_fn_strips_password_field(client, monkeypatch):
    """
    The scoped lookup must keep redacting ``password``.

    The mock serves a row that contains a password hash; the response for
    the authenticated user must contain exactly one row and that row must
    not expose the ``password`` key.
    """
    stored_row = {
        "user_id": "alice",
        "user_email": "alice@example.com",
        "password": "hashed-password-must-not-leak",
        "spend": 1.0,
    }
    fake_prisma = _SpendScopeMockPrismaClient(get_data_returns=[stored_row])
    monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", fake_prisma)

    def _authenticate_as_alice():
        return UserAPIKeyAuth(
            user_role=LitellmUserRoles.INTERNAL_USER, user_id="alice"
        )

    app.dependency_overrides[ps.user_api_key_auth] = _authenticate_as_alice
    try:
        response = client.get(
            "/spend/users", headers={"Authorization": "Bearer sk-test"}
        )
        assert response.status_code == 200

        rows = response.json()
        assert len(rows) == 1
        assert "password" not in rows[0]
    finally:
        # Drop the auth override even when an assertion fails.
        app.dependency_overrides.pop(ps.user_api_key_auth, None)
|
||||
|
||||
32
tests/test_litellm/test__types.py
Normal file
32
tests/test_litellm/test__types.py
Normal file
@ -0,0 +1,32 @@
|
||||
# tests/test_litellm/test__types.py
|
||||
|
||||
from litellm.proxy._types import LiteLLM_TeamMembership
|
||||
|
||||
|
||||
def test_team_membership_budget_table_optional_no_crash():
    """
    Regression test for #28689.

    Under Pydantic v2 an ``Optional[T]`` field without a default is still a
    required field. When budget_id is null the DB join returns no
    ``litellm_budget_table`` key, so ``model_validate`` must accept such a
    payload instead of raising 'Field required'.
    """
    # "litellm_budget_table" is deliberately omitted, as the DB join does
    # when budget_id is null.
    membership_row = dict(user_id="test-user", team_id="test-team", budget_id=None)
    membership = LiteLLM_TeamMembership.model_validate(membership_row)
    assert membership.litellm_budget_table is None
|
||||
|
||||
|
||||
def test_team_membership_budget_table_present_still_works():
    """
    A payload that carries the ``litellm_budget_table`` key must validate.

    Two shapes are covered: the key present but explicitly null, and the
    key present with a budget payload. The second case is what actually
    shows the field is still populated when a budget is attached.
    """
    base_row = {
        "user_id": "test-user",
        "team_id": "test-team",
        "budget_id": "some-budget-id",
    }

    explicit_null = LiteLLM_TeamMembership.model_validate(
        {**base_row, "litellm_budget_table": None}
    )
    assert explicit_null.litellm_budget_table is None

    # NOTE(review): assumes the budget-table model exposes an optional
    # `max_budget` field and has no required fields - confirm against
    # litellm.proxy._types before merging.
    populated = LiteLLM_TeamMembership.model_validate(
        {**base_row, "litellm_budget_table": {"max_budget": 10.0}}
    )
    assert populated.litellm_budget_table is not None
    assert populated.litellm_budget_table.max_budget == 10.0
|
||||
47
tests/test_litellm/test_bedrock_usgov_haiku_1hr_cache.py
Normal file
47
tests/test_litellm/test_bedrock_usgov_haiku_1hr_cache.py
Normal file
@ -0,0 +1,47 @@
|
||||
"""
|
||||
Validate that AWS GovCloud (Bedrock us-gov-*) Haiku 4.5 entries carry
|
||||
the 1-hour cache write tier.
|
||||
|
||||
AWS Bedrock GovCloud pricing applies a +20% premium over global
|
||||
Anthropic rates. Global Haiku 4.5 1h cache write is $2.00/MTok; us-gov
|
||||
is therefore $2.40/MTok — exactly 1.6x the 5-minute rate of $1.50/MTok.
|
||||
|
||||
Source: https://aws.amazon.com/bedrock/pricing/
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def model_data():
    """Return the parsed model_prices_and_context_window.json.

    The file is resolved two directories above this test module and is
    loaded once per test module (fixture scope is "module").
    """
    json_path = os.path.join(
        os.path.dirname(__file__), "../../model_prices_and_context_window.json"
    )
    # JSON is UTF-8 by specification; do not depend on the platform's
    # locale encoding when reading it.
    with open(json_path, encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
# One pricing-map key per GovCloud region for the Haiku 4.5 model id.
HAIKU_USGOV_KEYS = [
    f"bedrock/{region}/anthropic.claude-haiku-4-5-20251001-v1:0"
    for region in ("us-gov-east-1", "us-gov-west-1")
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_key", HAIKU_USGOV_KEYS)
def test_usgov_haiku_4_5_1hr_cache_write(model_data, model_key):
    """Check the 5-minute and 1-hour cache-write rates of a us-gov Haiku 4.5 entry.

    Both rates are pinned to their exact values, and the 1h/5m ratio is
    cross-checked against 1.6.
    """
    assert model_key in model_data, f"Missing model entry: {model_key}"
    entry = model_data[model_key]

    five_minute_rate = entry["cache_creation_input_token_cost"]
    assert (
        five_minute_rate == 1.5e-06
    ), f"{model_key}: 5m cache write should be $1.50/MTok"

    one_hour_rate = entry["cache_creation_input_token_cost_above_1hr"]
    assert (
        one_hour_rate == 2.4e-06
    ), f"{model_key}: 1h cache write should be $2.40/MTok"

    ratio = one_hour_rate / five_minute_rate
    assert abs(ratio - 1.6) < 1e-9, f"{model_key}: 1h/5m ratio is {ratio}, expected 1.6"
|
||||
132
tests/test_litellm/test_bedrock_usgov_pricing.py
Normal file
132
tests/test_litellm/test_bedrock_usgov_pricing.py
Normal file
@ -0,0 +1,132 @@
|
||||
"""
|
||||
Validate AWS GovCloud (Bedrock us-gov-*) Anthropic pricing entries.
|
||||
|
||||
AWS Bedrock pricing in GovCloud carries a +20% premium over the global
|
||||
Anthropic prices (not the +10% commercial-US premium). Until 2026-05-22
|
||||
these entries silently mirrored commercial US, undercharging customers
|
||||
by ~9%.
|
||||
|
||||
Source: https://aws.amazon.com/bedrock/pricing/
|
||||
|
||||
Sonnet 4.5 in us-gov-* (per million tokens):
|
||||
input = $3.60
|
||||
output = $18.00
|
||||
cache write 5m = $4.50
|
||||
cache write 1h = $7.20
|
||||
cache read = $0.36
|
||||
|
||||
Reference: https://github.com/BerriAI/litellm/issues/27120
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def model_data():
    """Return the parsed model_prices_and_context_window.json.

    The file is resolved two directories above this test module and is
    loaded once per test module (fixture scope is "module").
    """
    json_path = os.path.join(
        os.path.dirname(__file__), "../../model_prices_and_context_window.json"
    )
    # JSON is UTF-8 by specification; do not depend on the platform's
    # locale encoding when reading it.
    with open(json_path, encoding="utf-8") as f:
        return json.load(f)
|
||||
|
||||
|
||||
# Every pricing-map key under which Sonnet 4.5 is listed for GovCloud:
# per-region keys with and without the "anthropic." vendor prefix, plus the
# us-gov cross-region inference profile.
SONNET_4_5_USGOV_KEYS = [
    prefix + "claude-sonnet-4-5-20250929-v1:0"
    for prefix in (
        "bedrock/us-gov-east-1/anthropic.",
        "bedrock/us-gov-west-1/anthropic.",
        "bedrock/us-gov-east-1/",
        "bedrock/us-gov-west-1/",
        "us-gov.anthropic.",
    )
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model_key", SONNET_4_5_USGOV_KEYS)
def test_usgov_sonnet_4_5_pricing(model_data, model_key):
    """Pin every base-tier rate of a us-gov Sonnet 4.5 entry.

    The expected figures are the +20%-over-global rates that AWS publishes
    on the GovCloud pricing page.
    """
    assert model_key in model_data, f"Missing model entry: {model_key}"
    pricing = model_data[model_key]

    input_rate = pricing["input_cost_per_token"]
    assert input_rate == 3.6e-06, (
        f"{model_key}: input_cost_per_token should be $3.60/MTok "
        f"(got {input_rate})"
    )

    output_rate = pricing["output_cost_per_token"]
    assert (
        output_rate == 1.8e-05
    ), f"{model_key}: output_cost_per_token should be $18.00/MTok"

    cache_write_5m = pricing["cache_creation_input_token_cost"]
    assert (
        cache_write_5m == 4.5e-06
    ), f"{model_key}: 5m cache write should be $4.50/MTok"

    cache_write_1h = pricing["cache_creation_input_token_cost_above_1hr"]
    assert (
        cache_write_1h == 7.2e-06
    ), f"{model_key}: 1h cache write should be $7.20/MTok"

    cache_read = pricing["cache_read_input_token_cost"]
    assert (
        cache_read == 3.6e-07
    ), f"{model_key}: cache read should be $0.36/MTok"
|
||||
|
||||
|
||||
def test_usgov_carries_20_percent_premium_over_global(model_data):
    """The us-gov rates must equal 1.2x the global anthropic.* rates,
    matching AWS's documented GovCloud uplift.

    Compares the us-gov-west-1 Sonnet 4.5 entry with the global entry,
    field by field. Entries and fields are checked for presence first so a
    gap in the pricing map is reported with a descriptive message rather
    than a bare KeyError, as the sibling tests in this module already do.
    """
    global_key = "anthropic.claude-sonnet-4-5-20250929-v1:0"
    usgov_key = "bedrock/us-gov-west-1/anthropic.claude-sonnet-4-5-20250929-v1:0"
    for key in (global_key, usgov_key):
        assert key in model_data, f"Missing model entry: {key}"
    global_info = model_data[global_key]
    usgov_info = model_data[usgov_key]

    for field in (
        "input_cost_per_token",
        "output_cost_per_token",
        "cache_creation_input_token_cost",
        "cache_creation_input_token_cost_above_1hr",
        "cache_read_input_token_cost",
    ):
        assert field in global_info, f"{global_key}: missing field {field}"
        assert field in usgov_info, f"{usgov_key}: missing field {field}"
        ratio = usgov_info[field] / global_info[field]
        assert (
            abs(ratio - 1.2) < 1e-9
        ), f"{field}: us-gov / global ratio is {ratio}, expected 1.2"
|
||||
|
||||
|
||||
# The us-gov.anthropic.* cross-region inference profile is the only us-gov
|
||||
# entry that carries the 1M-context `_above_200k_tokens` pricing tier — the
|
||||
# bedrock/us-gov-{east,west}-1/ entries are capped at 200k tokens.
|
||||
USGOV_CROSS_REGION_KEY = "us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0"

# Expected per-token rates for the `_above_200k_tokens` tier of the profile
# above; insertion order is the order the parametrized test runs in.
EXPECTED_USGOV_ABOVE_200K = dict(
    input_cost_per_token_above_200k_tokens=7.2e-06,
    output_cost_per_token_above_200k_tokens=2.7e-05,
    cache_creation_input_token_cost_above_200k_tokens=9.0e-06,
    cache_creation_input_token_cost_above_1hr_above_200k_tokens=1.44e-05,
    cache_read_input_token_cost_above_200k_tokens=7.2e-07,
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("field,expected", EXPECTED_USGOV_ABOVE_200K.items())
def test_usgov_cross_region_above_200k_carries_gov_premium(model_data, field, expected):
    """Pin one `_above_200k_tokens` rate of the us-gov cross-region profile.

    The long-context tier must carry the +20% GovCloud uplift too. The
    original PR corrected the base rates but left these fields at the +10%
    commercial-US rates, undercharging long-context requests.
    """
    profile = model_data[USGOV_CROSS_REGION_KEY]
    assert field in profile, f"{USGOV_CROSS_REGION_KEY}: missing field {field}"

    actual = profile[field]
    assert (
        actual == expected
    ), f"{USGOV_CROSS_REGION_KEY}: {field} should be {expected} (got {actual})"
|
||||
|
||||
|
||||
def test_usgov_cross_region_above_200k_ratio_to_global(model_data):
    """Cross-check via the property-based invariant: every `_above_200k_tokens`
    field on the us-gov cross-region profile must equal 1.2x the global
    anthropic.* rate, the same GovCloud uplift the base tier carries.

    Entries and fields are checked for presence first so a gap in the
    pricing map is reported with a descriptive message rather than a bare
    KeyError.
    """
    global_key = "anthropic.claude-sonnet-4-5-20250929-v1:0"
    for key in (global_key, USGOV_CROSS_REGION_KEY):
        assert key in model_data, f"Missing model entry: {key}"
    global_info = model_data[global_key]
    usgov_info = model_data[USGOV_CROSS_REGION_KEY]

    for field in EXPECTED_USGOV_ABOVE_200K:
        assert field in global_info, f"{global_key}: missing field {field}"
        assert field in usgov_info, f"{USGOV_CROSS_REGION_KEY}: missing field {field}"
        ratio = usgov_info[field] / global_info[field]
        assert (
            abs(ratio - 1.2) < 1e-9
        ), f"{field}: us-gov / global ratio is {ratio}, expected 1.2"
|
||||
Loading…
Reference in New Issue
Block a user