diff --git a/gateway/routes/allowlist.py b/gateway/routes/allowlist.py index cbbf55c987..144bb4c473 100644 --- a/gateway/routes/allowlist.py +++ b/gateway/routes/allowlist.py @@ -106,6 +106,7 @@ GATEWAY_PATH_PREFIXES: tuple[str, ...] = ( # Health & ops "/health", "/metrics", + "/watsonx" ) GATEWAY_EXACT_PATHS: frozenset[str] = frozenset( diff --git a/litellm/constants.py b/litellm/constants.py index 20625a80bf..df15050e65 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -771,6 +771,7 @@ openai_compatible_endpoints: List = [ "https://api.moonshot.ai/v1", "https://api.publicai.co/v1", "https://api.synthetic.new/openai/v1", + "https://serverless.tensormesh.ai/v1", "https://api.stima.tech/v1", "https://nano-gpt.com/api/v1", "https://api.poe.com/v1", @@ -820,6 +821,7 @@ openai_compatible_providers: List = [ "meta_llama", "publicai", # PublicAI - JSON-configured provider "synthetic", # Synthetic - JSON-configured provider + "tensormesh", # Tensormesh - JSON-configured provider "apertis", # Apertis - JSON-configured provider "nano-gpt", # Nano-GPT - JSON-configured provider "poe", # Poe - JSON-configured provider @@ -855,6 +857,7 @@ openai_text_completion_compatible_providers: List = ( "moonshot", "publicai", "synthetic", + "tensormesh", "apertis", "nano-gpt", "poe", diff --git a/litellm/integrations/focus/transformer.py b/litellm/integrations/focus/transformer.py index 6f4433b4a0..8496b7ec15 100644 --- a/litellm/integrations/focus/transformer.py +++ b/litellm/integrations/focus/transformer.py @@ -95,7 +95,9 @@ class FocusTransformer: pl.lit("Usage-Based").alias("ChargeFrequency"), fmt(pl.col("ChargePeriodEnd")).alias("ChargePeriodEnd"), fmt(pl.col("ChargePeriodStart")).alias("ChargePeriodStart"), - dec(pl.lit(1.0)).alias("ConsumedQuantity"), + dec( + pl.col("api_requests").cast(pl.Int64).cast(pl.Float64).fill_null(0.0) + ).alias("ConsumedQuantity"), pl.lit("Requests").alias("ConsumedUnit"), dec(pl.col("spend").fill_null(0.0)).alias("ContractedCost"), none_str.alias("ContractedUnitPrice"), @@ -107,7 +109,9 @@ class FocusTransformer: none_str.alias("AvailabilityZone"), pl.lit("USD").alias("PricingCurrency"), none_str.alias("PricingCategory"), - dec(pl.lit(1.0)).alias("PricingQuantity"), + dec( + pl.col("api_requests").cast(pl.Int64).cast(pl.Float64).fill_null(0.0) + ).alias("PricingQuantity"), none_dec.alias("PricingCurrencyContractedUnitPrice"), dec(pl.col("spend").fill_null(0.0)).alias("PricingCurrencyEffectiveCost"), none_dec.alias("PricingCurrencyListUnitPrice"), diff --git a/litellm/integrations/prometheus.py b/litellm/integrations/prometheus.py index 5f05284212..9fc0980736 100644 --- a/litellm/integrations/prometheus.py +++ b/litellm/integrations/prometheus.py @@ -511,6 +511,23 @@ class PrometheusLogger(CustomLogger): labelnames=self.get_labels_for_metric("litellm_cached_tokens_metric"), ) + # Provider prompt-caching metrics + self.litellm_provider_cache_read_input_tokens_metric = self._counter_factory( + name="litellm_provider_cache_read_input_tokens_metric", + documentation="Total prompt/input tokens read from provider prompt cache (e.g. OpenAI/Anthropic/Gemini/Bedrock)", + labelnames=self.get_labels_for_metric( + "litellm_provider_cache_read_input_tokens_metric" + ), + ) + + self.litellm_provider_cache_creation_input_tokens_metric = self._counter_factory( + name="litellm_provider_cache_creation_input_tokens_metric", + documentation="Total prompt/input tokens written to provider prompt cache (e.g. Anthropic/Bedrock)", + labelnames=self.get_labels_for_metric( + "litellm_provider_cache_creation_input_tokens_metric" + ), + ) + # User and Team count metrics self.litellm_total_users_metric = self._gauge_factory( "litellm_total_users", @@ -1458,11 +1475,11 @@ class PrometheusLogger(CustomLogger): """ cache_hit = standard_logging_payload.get("cache_hit") - # Only track if cache_hit has a definite value (True or False) if cache_hit is None: - return - - if cache_hit is True: + # Historically these metrics only tracked LiteLLM caching. + # Provider prompt-caching metrics are still emitted below. + pass + elif cache_hit is True: # Increment cache hits counter PrometheusLogger._inc_labeled_counter( self, @@ -1493,6 +1510,51 @@ class PrometheusLogger(CustomLogger): label_context=label_context, ) + # Provider prompt caching metrics are independent of LiteLLM cache_hit. + provider_cache_read_tokens = 0 + provider_cache_creation_tokens = 0 + usage_obj = (standard_logging_payload.get("metadata", {}) or {}).get( + "usage_object" + ) + if isinstance(usage_obj, dict): + # Prefer explicit provider cache fields when available. + _read = usage_obj.get("cache_read_input_tokens") + _write = usage_obj.get("cache_creation_input_tokens") + + if isinstance(_read, int): + provider_cache_read_tokens = _read + if isinstance(_write, int): + provider_cache_creation_tokens = _write + + # Fallback to prompt_tokens_details.cached_tokens (common normalization point). + # Only fallback when the explicit field is genuinely absent (None). + if _read is None: + prompt_details = usage_obj.get("prompt_tokens_details") + if isinstance(prompt_details, dict): + cached_tokens = prompt_details.get("cached_tokens") + if isinstance(cached_tokens, int): + provider_cache_read_tokens = cached_tokens + + if provider_cache_read_tokens > 0: + PrometheusLogger._inc_labeled_counter( + self, + self.litellm_provider_cache_read_input_tokens_metric, + "litellm_provider_cache_read_input_tokens_metric", + enum_values, + label_context=label_context, + amount=float(provider_cache_read_tokens), + ) + + if provider_cache_creation_tokens > 0: + PrometheusLogger._inc_labeled_counter( + self, + self.litellm_provider_cache_creation_input_tokens_metric, + "litellm_provider_cache_creation_input_tokens_metric", + enum_values, + label_context=label_context, + amount=float(provider_cache_creation_tokens), + ) + async def _increment_remaining_budget_metrics( self, user_api_team: Optional[str], diff --git a/litellm/litellm_core_utils/fallback_utils.py b/litellm/litellm_core_utils/fallback_utils.py index 52eb35663b..daacca85c8 100644 --- a/litellm/litellm_core_utils/fallback_utils.py +++ b/litellm/litellm_core_utils/fallback_utils.py @@ -47,8 +47,9 @@ async def async_completion_with_fallbacks(**kwargs): completion_kwargs = safe_deep_copy(base_kwargs) # Handle dictionary fallback configurations if isinstance(fallback, dict): - model = fallback.pop("model", original_model) - completion_kwargs.update(fallback) + fallback_config = safe_deep_copy(dict(fallback)) + model = fallback_config.pop("model", original_model) + completion_kwargs.update(fallback_config) else: model = fallback diff --git a/litellm/llms/bedrock/base_aws_llm.py b/litellm/llms/bedrock/base_aws_llm.py index b659c1b0a0..b1b0682938 100644 --- a/litellm/llms/bedrock/base_aws_llm.py +++ b/litellm/llms/bedrock/base_aws_llm.py @@ -1534,10 +1534,9 @@ class BaseAWSLLM: ) sigv4 = SigV4Auth(credentials, service_name, aws_region_name) - if headers is not None: + headers = headers or {} + if not any(header_name.lower() == "content-type" for header_name in headers): headers = {"Content-Type": "application/json", **headers} - else: - headers = {"Content-Type": "application/json"} aws_signature_headers = self._filter_headers_for_aws_signature(headers) request = AWSRequest( diff --git a/litellm/llms/bedrock/files/transformation.py b/litellm/llms/bedrock/files/transformation.py index 6669363093..cec2e934af 100644 --- a/litellm/llms/bedrock/files/transformation.py +++ b/litellm/llms/bedrock/files/transformation.py @@ -233,6 +233,259 @@ class BedrockFilesConfig(BaseAWSLLM, BaseFilesConfig): # example; add others here as they adopt the same schema. CONVERSE_INVOKE_PROVIDERS = ("nova",) + # OpenAI batch URL that signals an embedding request. Per OpenAI Batch API + # spec, every JSONL record carries a `url` field; we use it as the + # authoritative signal to route the line to the embedding code path + # instead of inferring from the presence of `input` vs `messages`. + OPENAI_EMBEDDINGS_URL = "/v1/embeddings" + + @staticmethod + def _is_embedding_record(openai_jsonl_record: Dict[str, Any]) -> bool: + """ + Decide whether an OpenAI batch JSONL line is an embedding request. + + Precedence (strict - any explicit `url` short-circuits): + 1. `url == "/v1/embeddings"` -> embedding. Authoritative per the + OpenAI Batch API spec. + 2. Any other non-empty `url` (e.g. `/v1/chat/completions`) -> NOT + embedding. We trust the caller's explicit signal even if the + body would otherwise suggest embedding; misrouting a chat + record into the embedding transformer would corrupt the + modelInput, while a chat-shaped body sent to the chat path + either succeeds or fails cleanly inside that transformer. + 3. `url` missing/empty -> fall back to body shape. Requires + `input` present AND `messages` absent so a malformed record + carrying both keys routes to the chat path (safer default: + Anthropic transforms ignore unknown top-level keys, whereas + the embedding transformer would silently drop the messages). + """ + url = openai_jsonl_record.get("url") + if url == BedrockFilesConfig.OPENAI_EMBEDDINGS_URL: + return True + if url: + return False + body = openai_jsonl_record.get("body", {}) + if not isinstance(body, dict): + return False + return "input" in body and "messages" not in body + + # Identifier for the Bedrock Titan v2 InvokeModel body schema as stored + # in `model_prices_and_context_window.json`. Centralized so future + # embedding-schema variants can add their own value + # (e.g. `cohere_v3`, `titan_g1`, `titan_multimodal`) without touching + # the detection logic. + _TITAN_V2_INVOCATION_SCHEMA = "titan_v2" + + # Substring marker used as a fallback when the registry can't resolve + # the model id - notably cross-region inference profile prefixes + # (`us.amazon.titan-embed-text-v2:0`) and Bedrock ARN forms, which + # `get_model_info` doesn't normalize today. + _TITAN_V2_EMBED_MODEL_MARKER = "titan-embed-text-v2" + + # Nested field name under `provider_specific_entry` that identifies the + # Bedrock InvokeModel body schema for batch inference. + # `provider_specific_entry` is the registry's escape hatch for fields + # `get_model_info` doesn't promote to top-level - exactly what we need + # here. Documented in the `sample_spec` entry of + # `model_prices_and_context_window.json` and surfaced by + # `get_model_info` (see `ModelInfo.provider_specific_entry`). + _BEDROCK_INVOCATION_SCHEMA_FIELD = "bedrock_invocation_schema" + + @staticmethod + def _is_titan_v2_embed_model(model: str) -> bool: + """ + True iff `model` refers to Amazon Titan Text Embeddings V2. + + Resolution order: + 1. `model_prices_and_context_window.json` via `get_model_info`. + The Titan v2 registry entry carries an explicit + `provider_specific_entry.bedrock_invocation_schema` discriminator + (`"titan_v2"`). When the registry resolves the id we trust that + field as the source of truth - no hardcoded model-id comparison + needed. + 2. Substring fallback (`titan-embed-text-v2` followed by `:`, `/`, + or end-of-string) for ids the registry can't normalize. This + catches cross-region inference profile prefixes + (`us.amazon.titan-embed-text-v2:0`) and Bedrock ARN forms; the + marker boundary check rejects lookalikes like + `titan-embed-text-v20` or `titan-embed-text-v2-experimental`. + + Tolerant of common id shapes: + - "amazon.titan-embed-text-v2:0" + - "bedrock/amazon.titan-embed-text-v2:0" + - "us.amazon.titan-embed-text-v2:0" (cross-region inference profile) + - ARN forms ending in ".../amazon.titan-embed-text-v2:0" + """ + # Registry-driven path: when get_model_info resolves the id we trust + # the registry's discriminator. A resolved id with a different (or + # absent) schema value here is intentionally not given a substring + # second-chance - the registry is authoritative for ids it knows. + registry_schema = BedrockFilesConfig._lookup_provider_specific_field( + model, BedrockFilesConfig._BEDROCK_INVOCATION_SCHEMA_FIELD + ) + if registry_schema is not None: + return registry_schema == BedrockFilesConfig._TITAN_V2_INVOCATION_SCHEMA + + # Registry silence -> substring fallback for unmapped ids only. + normalized = model.lower() + if normalized.startswith("bedrock/"): + normalized = normalized[len("bedrock/") :] + marker = BedrockFilesConfig._TITAN_V2_EMBED_MODEL_MARKER + idx = normalized.find(marker) + if idx < 0: + return False + end = idx + len(marker) + return end == len(normalized) or normalized[end] in (":", "/") + + @staticmethod + def _lookup_provider_specific_field(model_id: str, field: str) -> Optional[str]: + """ + Read a nested string field from the registry entry's + `provider_specific_entry` dict via `litellm.get_model_info`. + + Returns the field's string value when: + - the registry resolves `model_id`, + - the entry exposes `provider_specific_entry` as a dict, and + - that dict has `field` mapped to a non-empty string. + Otherwise returns `None`. + + Isolating this means feature detectors (Titan v2 today, future + Cohere Embed / Nova Multimodal branches) share one defensive + try/except shape instead of duplicating it. The `None` return + covers every realistic failure mode: `get_model_info` raises + (cross-region profile prefixes, Bedrock ARN forms, unreleased + models), returns a non-dict, has no `provider_specific_entry`, or + the requested field is missing / non-string / empty. + """ + try: + from litellm import get_model_info + + info = get_model_info(model_id) + except Exception: + return None + if not isinstance(info, dict): + return None + provider_specific = info.get("provider_specific_entry") + if not isinstance(provider_specific, dict): + return None + value = provider_specific.get(field) + return value if isinstance(value, str) and value else None + + @staticmethod + def _coerce_embedding_input_to_string(raw_input: Any, model: str = "") -> str: + """ + Normalize an OpenAI /v1/embeddings `input` field into the single + string that Bedrock Titan v2 InvokeModel expects in `inputText`. + + Accepts: a string, or a single-element list containing one string. + Rejects (with actionable messages): + - None / missing -> ValueError + - Multi-element string lists -> ValueError, prompts caller to + emit one JSONL line per input + - Pre-tokenized inputs (List[int], List[List[int]]) -> NotImplementedError + - Any other type -> ValueError + + Extracted so the validation can be exercised in isolation and so + future embedding-provider branches (Titan G1, Cohere) can reuse it + without duplicating the type-shaping logic. + """ + if raw_input is None: + raise ValueError( + "Embedding batch record is missing required `input` field: " + f"model={model}" + ) + + # Bedrock InvokeModel for Titan v2 takes exactly one string `inputText` + # per call. Pre-tokenized inputs and multi-element string lists are + # explicitly unsupported so callers emit one JSONL line per embedding + # instead of relying on us to silently fan out or concatenate. + if isinstance(raw_input, list): + if len(raw_input) == 1: + candidate = raw_input[0] + else: + raise ValueError( + "Bedrock batch embedding requires one input per JSONL " + "record. Got a list with " + f"{len(raw_input)} items for model={model}; emit one " + "JSONL line per input string instead." + ) + else: + candidate = raw_input + + # Catches pre-tokenized inputs (List[int] from OpenAI spec, or a + # single int slipping past the list-unwrap above). + # NOTE: bool is a subclass of int but treating True/False as a token + # is meaningless either way, so the broad check is fine. + if isinstance(candidate, (list, int)): + raise NotImplementedError( + "Bedrock Titan v2 batch embedding does not support " + "pre-tokenized integer inputs. Pass `input` as a string " + f"(model={model})." + ) + if not isinstance(candidate, str): + raise ValueError( + "Bedrock batch embedding `input` must be a string (or a " + "single-element list of strings). Got type " + f"{type(candidate).__name__} for model={model}." + ) + return candidate + + def _map_openai_embedding_to_bedrock_params( + self, + openai_request_body: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Transform an OpenAI /v1/embeddings request body into the + Bedrock InvokeModel `modelInput` for embedding models that AWS + supports via batch inference (CreateModelInvocationJob). + + Currently routes Amazon Titan Text Embeddings V2 only; other + embedding providers (Titan G1, Titan Multimodal, Cohere Embed, + Nova Multimodal Embeddings) raise NotImplementedError until they + get a dedicated branch. Splitting them keeps PR scope tight and + lets each model's request schema be exercised by its own tests. + + AWS docs (Titan v2 InvokeModel body): + https://docs.aws.amazon.com/bedrock/latest/userguide/model-parameters-titan-embed-text.html + """ + from litellm.llms.bedrock.embed.amazon_titan_v2_transformation import ( + AmazonTitanV2Config, + ) + + _model = openai_request_body.get("model", "") + if not self._is_titan_v2_embed_model(_model): + # Refuse early instead of silently shaping the body for the wrong + # provider. The synchronous /v1/embeddings path supports more + # models, but each has a different InvokeModel schema; mapping + # them here without dedicated tests would risk corrupt batches. + raise NotImplementedError( + "Bedrock batch embedding currently supports only Amazon " + "Titan Text Embeddings V2 (model id contains " + f"'titan-embed-text-v2'). Got model={_model!r}. Track other " + "embedding models in https://github.com/BerriAI/litellm/issues." + ) + + input_text = self._coerce_embedding_input_to_string( + openai_request_body.get("input"), model=_model + ) + + # Map OpenAI-style params (dimensions, encoding_format) onto the + # Titan v2 schema (dimensions, embeddingTypes) via the embed config + # so this stays in sync with the synchronous /v1/embeddings path. + non_default_params = { + k: v for k, v in openai_request_body.items() if k not in ("model", "input") + } + titan_config = AmazonTitanV2Config() + inference_params = titan_config.map_openai_params( + non_default_params=non_default_params, + optional_params={}, + ) + return dict( + titan_config._transform_request( + input=input_text, inference_params=inference_params + ) + ) + def _map_openai_to_bedrock_params( self, openai_request_body: Dict[str, Any], @@ -349,10 +602,19 @@ class BedrockFilesConfig(BaseAWSLLM, BaseFilesConfig): # Determine provider from model name provider = self.get_bedrock_invoke_provider(model) - # Transform to Bedrock modelInput format - model_input = self._map_openai_to_bedrock_params( - openai_request_body=openai_body, provider=provider - ) + # Route to the embedding transformer when the OpenAI batch line + # targets /v1/embeddings; otherwise fall back to the existing + # chat-completion path. We branch here (rather than inside + # `_map_openai_to_bedrock_params`) so the chat helper keeps its + # narrow contract and the embedding helper can evolve independently. + if self._is_embedding_record(_openai_jsonl_content): + model_input = self._map_openai_embedding_to_bedrock_params( + openai_request_body=openai_body + ) + else: + model_input = self._map_openai_to_bedrock_params( + openai_request_body=openai_body, provider=provider + ) # Create Bedrock batch record record_id = _openai_jsonl_content.get( diff --git a/litellm/llms/black_forest_labs/common_utils.py b/litellm/llms/black_forest_labs/common_utils.py index 507ef17c50..237208693f 100644 --- a/litellm/llms/black_forest_labs/common_utils.py +++ b/litellm/llms/black_forest_labs/common_utils.py @@ -5,6 +5,7 @@ Common utilities, constants, and error handling for Black Forest Labs API. """ from typing import Dict +from urllib.parse import urlparse from litellm.llms.base_llm.chat.transformation import BaseLLMException @@ -18,6 +19,42 @@ class BlackForestLabsError(BaseLLMException): # API Constants DEFAULT_API_BASE = "https://api.bfl.ai" +# BFL uses regional subdomains (e.g. gateway.bfl.ai) for polling URLs that +# differ from the submission host (api.bfl.ai). We validate against the +# registered domain rather than doing a strict same-origin check. +_BFL_REGISTERED_DOMAIN = "bfl.ai" + + +def assert_bfl_polling_url(polling_url: str) -> None: + """Validate that a polling URL points to a BFL-controlled host. + + BFL returns polling URLs on subdomains like ``gateway.bfl.ai`` that differ + from the submission host ``api.bfl.ai``. A strict same-origin check would + reject these legitimate URLs. Instead we verify the host is ``bfl.ai`` or + any subdomain of it, which keeps the SSRF guarantee (credentials only go + to BFL-controlled infrastructure) without false-positives on regional hosts. + + Raises: + BlackForestLabsError: If the polling URL scheme or host is not trusted. + """ + parsed = urlparse(polling_url) + host = (parsed.hostname or "").lower() + + if parsed.scheme != "https": + raise BlackForestLabsError( + status_code=502, + message="Rejected polling URL: scheme must be https", + ) + + if host != _BFL_REGISTERED_DOMAIN and not host.endswith( + "." + _BFL_REGISTERED_DOMAIN + ): + raise BlackForestLabsError( + status_code=502, + message="Rejected polling URL: host is not within the bfl.ai domain", + ) + + # Polling configuration DEFAULT_POLLING_INTERVAL = 1.5 # seconds DEFAULT_MAX_POLLING_TIME = 300 # 5 minutes diff --git a/litellm/llms/black_forest_labs/image_edit/handler.py b/litellm/llms/black_forest_labs/image_edit/handler.py index f5784e0836..ab191c165f 100644 --- a/litellm/llms/black_forest_labs/image_edit/handler.py +++ b/litellm/llms/black_forest_labs/image_edit/handler.py @@ -15,7 +15,6 @@ import httpx import litellm from litellm._logging import verbose_logger from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj -from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, HTTPHandler, @@ -29,6 +28,7 @@ from ..common_utils import ( DEFAULT_MAX_POLLING_TIME, DEFAULT_POLLING_INTERVAL, BlackForestLabsError, + assert_bfl_polling_url, ) from .transformation import BlackForestLabsImageEditConfig @@ -332,16 +332,11 @@ class BlackForestLabsImageEdit: message="No polling_url in BFL response", ) - # Reject cross-origin polling URLs — the ``x-key`` auth header - # would otherwise leak to whatever URL the upstream returns. - # VERIA-51. - try: - assert_same_origin(polling_url, str(initial_response.request.url)) - except SSRFError as ssrf_err: - raise BlackForestLabsError( - status_code=502, - message=f"Rejected polling URL: {ssrf_err}", - ) + # Reject polling URLs that don't belong to BFL-controlled infrastructure. + # BFL uses regional subdomains (e.g. gateway.bfl.ai) that differ from the + # submission host (api.bfl.ai), so we validate against the registered + # domain rather than doing a strict same-origin check. VERIA-51. + assert_bfl_polling_url(polling_url) # Get just the auth header for polling polling_headers = {"x-key": headers.get("x-key", "")} @@ -428,16 +423,11 @@ class BlackForestLabsImageEdit: message="No polling_url in BFL response", ) - # Reject cross-origin polling URLs — the ``x-key`` auth header - # would otherwise leak to whatever URL the upstream returns. - # VERIA-51. - try: - assert_same_origin(polling_url, str(initial_response.request.url)) - except SSRFError as ssrf_err: - raise BlackForestLabsError( - status_code=502, - message=f"Rejected polling URL: {ssrf_err}", - ) + # Reject polling URLs that don't belong to BFL-controlled infrastructure. + # BFL uses regional subdomains (e.g. gateway.bfl.ai) that differ from the + # submission host (api.bfl.ai), so we validate against the registered + # domain rather than doing a strict same-origin check. VERIA-51. + assert_bfl_polling_url(polling_url) # Get just the auth header for polling polling_headers = {"x-key": headers.get("x-key", "")} diff --git a/litellm/llms/black_forest_labs/image_generation/handler.py b/litellm/llms/black_forest_labs/image_generation/handler.py index 8af4a236fd..f797fac419 100644 --- a/litellm/llms/black_forest_labs/image_generation/handler.py +++ b/litellm/llms/black_forest_labs/image_generation/handler.py @@ -15,7 +15,6 @@ import httpx import litellm from litellm._logging import verbose_logger from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj -from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, HTTPHandler, @@ -29,6 +28,7 @@ from ..common_utils import ( DEFAULT_MAX_POLLING_TIME, DEFAULT_POLLING_INTERVAL, BlackForestLabsError, + assert_bfl_polling_url, ) from .transformation import BlackForestLabsImageGenerationConfig @@ -172,6 +172,10 @@ class BlackForestLabsImageGeneration: raw_response=final_response, model_response=model_response, logging_obj=logging_obj, + request_data=data, + optional_params=optional_params, + litellm_params=litellm_params_dict, + encoding=None, ) async def async_image_generation( @@ -274,6 +278,10 @@ class BlackForestLabsImageGeneration: raw_response=final_response, model_response=model_response, logging_obj=logging_obj, + request_data=data, + optional_params=optional_params, + litellm_params=litellm_params_dict, + encoding=None, ) def _poll_for_result_sync( @@ -318,16 +326,11 @@ class BlackForestLabsImageGeneration: message="No polling_url in BFL response", ) - # Reject cross-origin polling URLs — the ``x-key`` auth header - # would otherwise leak to whatever URL the upstream returns. - # VERIA-51. - try: - assert_same_origin(polling_url, str(initial_response.request.url)) - except SSRFError as ssrf_err: - raise BlackForestLabsError( - status_code=502, - message=f"Rejected polling URL: {ssrf_err}", - ) + # Reject polling URLs that don't belong to BFL-controlled infrastructure. + # BFL uses regional subdomains (e.g. gateway.bfl.ai) that differ from the + # submission host (api.bfl.ai), so we validate against the registered + # domain rather than doing a strict same-origin check. VERIA-51. + assert_bfl_polling_url(polling_url) # Get just the auth header for polling polling_headers = {"x-key": headers.get("x-key", "")} @@ -414,16 +417,11 @@ class BlackForestLabsImageGeneration: message="No polling_url in BFL response", ) - # Reject cross-origin polling URLs — the ``x-key`` auth header - # would otherwise leak to whatever URL the upstream returns. - # VERIA-51. - try: - assert_same_origin(polling_url, str(initial_response.request.url)) - except SSRFError as ssrf_err: - raise BlackForestLabsError( - status_code=502, - message=f"Rejected polling URL: {ssrf_err}", - ) + # Reject polling URLs that don't belong to BFL-controlled infrastructure. + # BFL uses regional subdomains (e.g. gateway.bfl.ai) that differ from the + # submission host (api.bfl.ai), so we validate against the registered + # domain rather than doing a strict same-origin check. VERIA-51. + assert_bfl_polling_url(polling_url) # Get just the auth header for polling polling_headers = {"x-key": headers.get("x-key", "")} diff --git a/litellm/llms/openai/chat/guardrail_translation/handler.py b/litellm/llms/openai/chat/guardrail_translation/handler.py index d413a24453..8c9a8228da 100644 --- a/litellm/llms/openai/chat/guardrail_translation/handler.py +++ b/litellm/llms/openai/chat/guardrail_translation/handler.py @@ -376,6 +376,13 @@ class OpenAIChatCompletionsHandler(BaseTranslation): ) guardrailed_texts = guardrailed_inputs.get("texts", []) + returned_tool_calls = guardrailed_inputs.get("tool_calls") + guardrailed_tool_calls: List[Dict[str, Any]] = ( + cast(List[Dict[str, Any]], returned_tool_calls) + if isinstance(returned_tool_calls, list) + and len(returned_tool_calls) == len(tool_calls_to_check) + else tool_calls_to_check + ) # Step 3: Map guardrail responses back to original response structure if guardrailed_texts and texts_to_check: @@ -386,10 +393,10 @@ class OpenAIChatCompletionsHandler(BaseTranslation): ) # Step 4: Apply guardrailed tool calls back to response - if tool_calls_to_check: + if guardrailed_tool_calls: await self._apply_guardrail_responses_to_output_tool_calls( response=response, - tool_calls=tool_calls_to_check, + tool_calls=guardrailed_tool_calls, task_mappings=tool_call_task_mappings, ) @@ -748,10 +755,11 @@ class OpenAIChatCompletionsHandler(BaseTranslation): task_mappings: List[Tuple[int, int]], ) -> None: """ - Apply guardrailed tool calls back to output response. + Apply guardrailed tool calls back to the output response. - The guardrail may have modified the tool_calls list in place, - so we apply the modified tool calls back to the original response. + The guardrail may return updated tool calls (either mutated in place or as + a new list), so we apply the provided tool calls back to the original + response. Override this method to customize how tool call responses are applied. """ diff --git a/litellm/llms/openai_like/providers.json b/litellm/llms/openai_like/providers.json index b5e5aa4ea2..c9257677fd 100644 --- a/litellm/llms/openai_like/providers.json +++ b/litellm/llms/openai_like/providers.json @@ -114,5 +114,14 @@ "param_mappings": { "max_completion_tokens": "max_tokens" } + }, + "tensormesh": { + "base_url": "https://serverless.tensormesh.ai/v1", + "api_key_env": "TENSORMESH_INFERENCE_API_KEY", + "api_base_env": "TENSORMESH_SERVERLESS_BASE_URL", + "base_class": "openai_gpt", + "param_mappings": { + "max_completion_tokens": "max_tokens" + } } } diff --git a/litellm/llms/watsonx/passthrough/__init__.py b/litellm/llms/watsonx/passthrough/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/litellm/llms/watsonx/passthrough/transformation.py b/litellm/llms/watsonx/passthrough/transformation.py new file mode 100644 index 0000000000..9162eef0e0 --- /dev/null +++ b/litellm/llms/watsonx/passthrough/transformation.py @@ -0,0 +1,69 @@ +from typing import TYPE_CHECKING, List, Optional, Tuple + +from litellm.llms.base_llm.passthrough.transformation import BasePassthroughConfig +from litellm.llms.watsonx.common_utils import IBMWatsonXMixin + +if TYPE_CHECKING: + from httpx import URL + + +class WatsonxPassthroughConfig(IBMWatsonXMixin, BasePassthroughConfig): + """ + Watsonx-specific passthrough configuration. + """ + + def is_streaming_request(self, endpoint: str, request_data: dict) -> bool: + """Check if request should be streamed""" + return request_data.get("stream", False) + + def get_complete_url( + self, + api_base: Optional[str], + api_key: Optional[str], + model: str, + endpoint: str, + request_query_params: Optional[dict], + litellm_params: dict, + ) -> Tuple["URL", str]: + """ + Construct complete Watsonx URL with version parameter. + + This ensures the version parameter is ALWAYS included in the URL, + solving the query parameter issue. + """ + base_target_url = str(self.get_api_base(api_base)) + + # Use the format_url helper to construct URL with query params + complete_url = self.format_url( + endpoint=endpoint, + base_target_url=base_target_url, + request_query_params=request_query_params, + ) + + return (complete_url, base_target_url) + + @staticmethod + def get_api_base( + api_base: Optional[str] = None, + ) -> Optional[str]: + return api_base or IBMWatsonXMixin()._get_base_url(api_base=api_base) + + @staticmethod + def get_api_key( + api_key: Optional[str] = None, + ) -> Optional[str]: + return ( + api_key + or IBMWatsonXMixin.get_watsonx_credentials( + optional_params=dict(), api_base=None, api_key=api_key + )["api_key"] + ) + + @staticmethod + def get_base_model(model: str) -> Optional[str]: + return model + + def get_models( + self, api_key: Optional[str] = None, api_base: Optional[str] = None + ) -> List[str]: + return super().get_models(api_key, api_base) diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index a66b72fc9f..0ddfec5f63 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -577,7 +577,10 @@ "max_tokens": 8192, "mode": "embedding", "output_cost_per_token": 0.0, - "output_vector_size": 1024 + "output_vector_size": 1024, + "provider_specific_entry": { + "bedrock_invocation_schema": "titan_v2" + } }, "amazon.titan-image-generator-v1": { "input_cost_per_image": 0.0, @@ -8899,15 +8902,16 @@ "cache_creation_input_token_cost": 3.75e-07 }, "bedrock/us-gov-east-1/anthropic.claude-sonnet-4-5-20250929-v1:0": { - "cache_creation_input_token_cost": 4.125e-06, - "cache_read_input_token_cost": 3.3e-07, - "input_cost_per_token": 3.3e-06, + "cache_creation_input_token_cost": 4.5e-06, + "cache_creation_input_token_cost_above_1hr": 7.2e-06, + "cache_read_input_token_cost": 3.6e-07, + "input_cost_per_token": 3.6e-06, "litellm_provider": "bedrock", "max_input_tokens": 200000, "max_output_tokens": 8192, "max_tokens": 8192, "mode": "chat", - "output_cost_per_token": 1.65e-05, + "output_cost_per_token": 1.8e-05, "supports_assistant_prefill": true, "supports_computer_use": true, "supports_function_calling": true, @@ -8920,15 +8924,16 @@ "supports_native_structured_output": true }, "bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0": { - "cache_creation_input_token_cost": 4.125e-06, - "cache_read_input_token_cost": 3.3e-07, - "input_cost_per_token": 3.3e-06, + "cache_creation_input_token_cost": 4.5e-06, + "cache_creation_input_token_cost_above_1hr": 7.2e-06, + "cache_read_input_token_cost": 3.6e-07, + "input_cost_per_token": 3.6e-06, "litellm_provider": "bedrock", "max_input_tokens": 200000, "max_output_tokens": 8192, "max_tokens": 8192, "mode": "chat", - "output_cost_per_token": 1.65e-05, + "output_cost_per_token": 1.8e-05, "supports_assistant_prefill": true, "supports_computer_use": true, "supports_function_calling": true, @@ -9072,15 +9077,16 @@ "cache_creation_input_token_cost": 3.75e-07 }, "bedrock/us-gov-west-1/anthropic.claude-sonnet-4-5-20250929-v1:0": { - "cache_creation_input_token_cost": 4.125e-06, - "cache_read_input_token_cost": 3.3e-07, - "input_cost_per_token": 3.3e-06, + "cache_creation_input_token_cost": 4.5e-06, + "cache_creation_input_token_cost_above_1hr": 7.2e-06, + "cache_read_input_token_cost": 3.6e-07, + "input_cost_per_token": 3.6e-06, "litellm_provider": "bedrock", "max_input_tokens": 200000, "max_output_tokens": 8192, "max_tokens": 8192, "mode": "chat", - "output_cost_per_token": 1.65e-05, + "output_cost_per_token": 1.8e-05, "supports_assistant_prefill": true, "supports_computer_use": true, "supports_function_calling": true, @@ -9093,15 +9099,16 @@ "supports_native_structured_output": true }, "bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0": { - "cache_creation_input_token_cost": 4.125e-06, - "cache_read_input_token_cost": 3.3e-07, - "input_cost_per_token": 3.3e-06, + "cache_creation_input_token_cost": 4.5e-06, + "cache_creation_input_token_cost_above_1hr": 7.2e-06, + "cache_read_input_token_cost": 3.6e-07, + "input_cost_per_token": 3.6e-06, "litellm_provider": "bedrock", "max_input_tokens": 200000, "max_output_tokens": 8192, "max_tokens": 8192, "mode": "chat", - "output_cost_per_token": 1.65e-05, + "output_cost_per_token": 1.8e-05, "supports_assistant_prefill": true, "supports_computer_use": true, "supports_function_calling": true, @@ -31858,19 +31865,21 @@ "supports_native_structured_output": true }, "us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0": { - "cache_creation_input_token_cost": 4.125e-06, - "cache_read_input_token_cost": 3.3e-07, - "input_cost_per_token": 3.3e-06, - "input_cost_per_token_above_200k_tokens": 6.6e-06, - "output_cost_per_token_above_200k_tokens": 2.475e-05, - "cache_creation_input_token_cost_above_200k_tokens": 8.25e-06, - "cache_read_input_token_cost_above_200k_tokens": 6.6e-07, + "cache_creation_input_token_cost": 4.5e-06, + "cache_creation_input_token_cost_above_1hr": 7.2e-06, + "cache_read_input_token_cost": 3.6e-07, + "input_cost_per_token": 3.6e-06, + "input_cost_per_token_above_200k_tokens": 7.2e-06, + "output_cost_per_token_above_200k_tokens": 2.7e-05, + "cache_creation_input_token_cost_above_200k_tokens": 9.0e-06, + "cache_creation_input_token_cost_above_1hr_above_200k_tokens": 1.44e-05, + "cache_read_input_token_cost_above_200k_tokens": 7.2e-07, "litellm_provider": "bedrock_converse", "max_input_tokens": 200000, "max_output_tokens": 64000, "max_tokens": 64000, "mode": "chat", - "output_cost_per_token": 1.65e-05, + "output_cost_per_token": 1.8e-05, "supports_assistant_prefill": true, "supports_computer_use": true, "supports_function_calling": true, @@ -41448,6 +41457,7 @@ }, "bedrock/us-gov-east-1/anthropic.claude-haiku-4-5-20251001-v1:0": { "cache_creation_input_token_cost": 1.5e-06, + "cache_creation_input_token_cost_above_1hr": 2.4e-06, "cache_read_input_token_cost": 1.2e-07, "input_cost_per_token": 1.2e-06, "litellm_provider": "bedrock", @@ -41470,6 +41480,7 @@ }, "bedrock/us-gov-west-1/anthropic.claude-haiku-4-5-20251001-v1:0": { "cache_creation_input_token_cost": 1.5e-06, + "cache_creation_input_token_cost_above_1hr": 2.4e-06, "cache_read_input_token_cost": 1.2e-07, "input_cost_per_token": 1.2e-06, "litellm_provider": "bedrock", diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 751f855ea3..98a17e4be9 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -419,6 +419,7 @@ class LiteLLMRoutes(enum.Enum): "/vllm", "/mistral", "/milvus", + "/watsonx", ] ######################################################### @@ -3901,7 +3902,9 @@ class LiteLLM_TeamMembership(LiteLLMPydanticObjectBase): # Union so Pydantic picks Full when data has server-managed fields # (/team/info) and Base when callers/tests construct with only # user-settable fields. - litellm_budget_table: Optional[Union[LiteLLM_BudgetTableFull, LiteLLM_BudgetTable]] + litellm_budget_table: Optional[ + Union[LiteLLM_BudgetTableFull, LiteLLM_BudgetTable] + ] = None def safe_get_team_member_rpm_limit(self) -> Optional[int]: if self.litellm_budget_table is not None: diff --git a/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/content_filter.py b/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/content_filter.py index d6065ef73f..c6dfe141ab 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/content_filter.py +++ b/litellm/proxy/guardrails/guardrail_hooks/litellm_content_filter/content_filter.py @@ -328,7 +328,24 @@ class ContentFilterGuardrail(CustomGuardrail): return result @staticmethod - def _resolve_category_file_path(file_path: str) -> str: + def _assert_within_categories_dir(path: str, categories_dir: str) -> None: + """Raise ValueError if path escapes the categories directory.""" + resolved = os.path.realpath(path) + allowed = os.path.realpath(categories_dir) + try: + common = os.path.commonpath([resolved, allowed]) + except ValueError: + # commonpath() raises ValueError on Windows when paths span different drives + raise ValueError( + f"Category file path '{path}' is outside the allowed categories directory" + ) + if common != allowed: + raise ValueError( + f"Category file path '{path}' is outside the allowed " + f"categories directory '{categories_dir}'" + ) + + def _resolve_category_file_path(self, file_path: str) -> str: """ Resolve a category file path that may be relative. @@ -339,12 +356,17 @@ class ContentFilterGuardrail(CustomGuardrail): file isn't found. Resolution order: - 1. Return as-is if absolute or already exists. - 2. Try joining the full path relative to this module's directory. + 1. Return as-is if absolute or already exists (jailed to module dir). + 2. Try joining the full path relative to this module's directory (jailed). 3. Progressively strip leading path components and try each suffix - relative to this module's directory (handles paths like - "litellm/proxy/.../policy_templates/file.yaml" by finding the - "policy_templates/file.yaml" suffix that exists). + relative to this module's directory (jailed). + + The directory jail can be disabled for deployments that legitimately + store category files outside the package (e.g. mounted volumes) by + setting the environment variable + ``LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS=true``. Use only in + trusted environments where the proxy configuration cannot be influenced + by untrusted input. Args: file_path: The file path to resolve (absolute or relative). @@ -352,15 +374,33 @@ class ContentFilterGuardrail(CustomGuardrail): Returns: The resolved absolute-ish path, or the original path if resolution fails (caller should check existence). - """ - if os.path.isabs(file_path) or os.path.exists(file_path): - return file_path + Raises: + ValueError: If the resolved path escapes the module directory + and ``LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS`` is not set. + """ module_dir = os.path.dirname(__file__) + allow_external = ( + os.environ.get("LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS", "").lower() + == "true" + ) + + if os.path.isabs(file_path) or os.path.exists(file_path): + if not allow_external: + self._assert_within_categories_dir(file_path, module_dir) + else: + verbose_proxy_logger.warning( + "LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS is set — " + "skipping directory jail for category_file '%s'", + file_path, + ) + return file_path # Try the full relative path joined to the module directory candidate = os.path.join(module_dir, file_path) if os.path.exists(candidate): + if not allow_external: + self._assert_within_categories_dir(candidate, module_dir) return candidate # Progressively strip leading components to find a matching suffix @@ -369,8 +409,17 @@ class ContentFilterGuardrail(CustomGuardrail): suffix = os.path.join(*parts[i:]) candidate = os.path.join(module_dir, suffix) if os.path.exists(candidate): + if not allow_external: + self._assert_within_categories_dir(candidate, module_dir) return candidate + # File not found via any resolution strategy — jail the module-relative + # path anyway to reject traversal attempts (e.g. "../../../../etc/passwd") + # regardless of CWD or whether the target file exists. + if not allow_external: + self._assert_within_categories_dir( + os.path.join(module_dir, file_path), module_dir + ) return file_path def _load_categories(self, categories: List[ContentFilterCategoryConfig]) -> None: @@ -395,6 +444,13 @@ class ContentFilterGuardrail(CustomGuardrail): ) continue + # Prevent path traversal via category_name (e.g. "../../etc/passwd") + if not re.match(r"^[a-zA-Z0-9_\-]+$", category_name): + verbose_proxy_logger.warning( + f"Category name '{category_name}' contains invalid characters, skipping" + ) + continue + enabled = cat_config.get("enabled", True) action = cat_config.get("action") severity_threshold = ( @@ -411,7 +467,13 @@ class ContentFilterGuardrail(CustomGuardrail): # Load category file (custom or default) if custom_file: - category_file_path = self._resolve_category_file_path(custom_file) + try: + category_file_path = self._resolve_category_file_path(custom_file) + except ValueError as e: + verbose_proxy_logger.warning( + f"Category {category_name}: invalid category_file path, skipping. {e}" + ) + continue else: # Try .yaml first, then .json (e.g. harm_toxic_abuse.json) yaml_path = os.path.join(categories_dir, f"{category_name}.yaml") diff --git a/litellm/proxy/guardrails/guardrail_hooks/panw_prisma_airs/panw_prisma_airs.py b/litellm/proxy/guardrails/guardrail_hooks/panw_prisma_airs/panw_prisma_airs.py index bbffc70ddb..e5200394b5 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/panw_prisma_airs/panw_prisma_airs.py +++ b/litellm/proxy/guardrails/guardrail_hooks/panw_prisma_airs/panw_prisma_airs.py @@ -140,7 +140,12 @@ class PanwPrismaAirsHandler(CustomGuardrail): ) self.fallback_on_error = fallback_on_error - self.timeout = timeout + # Coerce defensively. The dashboard UI persists this field as a JSON + # string, and Pydantic extras (the path that splats model_dump into + # this handler) preserve whatever type the user supplied. A string + # value would otherwise reach httpx, which raises TypeError on its + # internal '<=' comparison and surfaces as a misleading api_error. + self.timeout = float(timeout) if timeout is not None else 10.0 # Tri-state: None = not set (default-on for Anthropic), True = explicit on, False = explicit off self.experimental_use_latest_role_message_only: Optional[bool] = kwargs.get( diff --git a/litellm/proxy/guardrails/guardrail_hooks/vigil_guard/__init__.py b/litellm/proxy/guardrails/guardrail_hooks/vigil_guard/__init__.py new file mode 100644 index 0000000000..4263b798f0 --- /dev/null +++ b/litellm/proxy/guardrails/guardrail_hooks/vigil_guard/__init__.py @@ -0,0 +1,34 @@ +from typing import TYPE_CHECKING + +from litellm.types.guardrails import SupportedGuardrailIntegrations + +from .vigil_guard import VigilGuardGuardrail + +if TYPE_CHECKING: + from litellm.types.guardrails import Guardrail, LitellmParams + + +def initialize_guardrail(litellm_params: "LitellmParams", guardrail: "Guardrail"): + import litellm + + _vigil_guard_callback = VigilGuardGuardrail( + api_base=litellm_params.api_base, + api_key=litellm_params.api_key, + unreachable_fallback=litellm_params.unreachable_fallback, + timeout=litellm_params.timeout, + guardrail_name=guardrail.get("guardrail_name", ""), + event_hook=litellm_params.mode, + default_on=litellm_params.default_on, + ) + litellm.logging_callback_manager.add_litellm_callback(_vigil_guard_callback) + return _vigil_guard_callback + + +guardrail_initializer_registry = { + SupportedGuardrailIntegrations.VIGIL_GUARD.value: initialize_guardrail, +} + + +guardrail_class_registry = { + SupportedGuardrailIntegrations.VIGIL_GUARD.value: VigilGuardGuardrail, +} diff --git a/litellm/proxy/guardrails/guardrail_hooks/vigil_guard/vigil_guard.py b/litellm/proxy/guardrails/guardrail_hooks/vigil_guard/vigil_guard.py new file mode 100644 index 0000000000..337cb9a9f2 --- /dev/null +++ b/litellm/proxy/guardrails/guardrail_hooks/vigil_guard/vigil_guard.py @@ -0,0 +1,485 @@ +from json import JSONDecodeError +from typing import ( + TYPE_CHECKING, + Any, + Awaitable, + Dict, + List, + Literal, + Optional, + Protocol, + Tuple, + Type, + cast, +) + +import httpx + +from litellm._logging import verbose_proxy_logger +from litellm.exceptions import GuardrailRaisedException +from litellm.exceptions import Timeout as LiteLLMTimeout +from litellm.integrations.custom_guardrail import ( + CustomGuardrail, + log_guardrail_information, +) +from litellm.llms.custom_httpx.http_handler import ( + get_async_httpx_client, + httpxSpecialProvider, +) +from litellm.secret_managers.main import get_secret_str +from litellm.types.guardrails import GuardrailEventHooks +from litellm.types.utils import GenericGuardrailAPIInputs + +if TYPE_CHECKING: + from litellm.litellm_core_utils.litellm_logging import ( + Logging as LiteLLMLoggingObj, + ) + from litellm.types.proxy.guardrails.guardrail_hooks.base import ( + GuardrailConfigModel, + ) + + +_ANALYZE_ENDPOINT = "/v1/guard/analyze" +_DEFAULT_VIGIL_TIMEOUT = httpx.Timeout(10.0, connect=5.0) +_BLOCK_REASON_MAX_CHARS = 500 +_METADATA_STRING_MAX_CHARS = 500 +_METADATA_ARRAY_MAX_ITEMS = 10 +_VALID_DECISIONS = ("ALLOWED", "SANITIZED", "BLOCKED") +_TRANSIENT_STATUS_CODES = frozenset({429, 502, 503, 504}) +_METADATA_ALLOWLIST = ( + "model", + "model_group", + "provider", + "region", + "deployment", + "user", + "user_id", + "session_id", + "conversation_id", + "request_id", + "tenant_id", + "org_id", +) + +_FallbackMode = Literal["fail_closed", "fail_open"] + + +class _AsyncPostHandler(Protocol): + def post( + self, + *, + url: str, + headers: Dict[str, str], + json: Dict[str, Any], + timeout: httpx.Timeout, + ) -> Awaitable[httpx.Response]: ... + + +class VigilGuardMissingConfig(ValueError): + pass + + +class VigilGuardGuardrail(CustomGuardrail): + def __init__( + self, + api_base: Optional[str] = None, + api_key: Optional[str] = None, + unreachable_fallback: Optional[str] = None, + timeout: Optional[float] = None, + async_handler: Optional[_AsyncPostHandler] = None, + **kwargs: Any, + ) -> None: + resolved_base = api_base or get_secret_str("VIGIL_GUARD_URL") + if not resolved_base: + raise VigilGuardMissingConfig( + "Vigil Guard api_base is required. Set api_base in the guardrail " + "config or the VIGIL_GUARD_URL environment variable." + ) + self.api_base = resolved_base.rstrip("/") + + resolved_key = api_key or get_secret_str("VIGIL_GUARD_API_KEY") + if not resolved_key: + raise VigilGuardMissingConfig( + "Vigil Guard api_key is required. Set api_key in the guardrail " + "config or the VIGIL_GUARD_API_KEY environment variable." + ) + self.api_key = resolved_key + + fallback = (unreachable_fallback or "fail_closed").lower() + self.unreachable_fallback: _FallbackMode = ( + "fail_open" if fallback == "fail_open" else "fail_closed" + ) + + self.timeout: httpx.Timeout = ( + _DEFAULT_VIGIL_TIMEOUT + if timeout is None + else httpx.Timeout(timeout, connect=min(timeout, 5.0)) + ) + + self.async_handler: _AsyncPostHandler = async_handler or get_async_httpx_client( + llm_provider=httpxSpecialProvider.GuardrailCallback, + ) + + if "supported_event_hooks" not in kwargs: + kwargs["supported_event_hooks"] = [ + GuardrailEventHooks.pre_call, + GuardrailEventHooks.post_call, + ] + + super().__init__(**kwargs) + + @staticmethod + def get_config_model() -> Optional[Type["GuardrailConfigModel"]]: + from litellm.types.proxy.guardrails.guardrail_hooks.vigil_guard import ( + VigilGuardGuardrailConfigModel, + ) + + return VigilGuardGuardrailConfigModel + + @log_guardrail_information + async def apply_guardrail( + self, + inputs: GenericGuardrailAPIInputs, + request_data: dict, + input_type: Literal["request", "response"], + logging_obj: Optional["LiteLLMLoggingObj"] = None, + ) -> GenericGuardrailAPIInputs: + texts = inputs.get("texts") or [] + has_text = any(isinstance(text, str) and text.strip() for text in texts) + tool_call_args = ( + self._tool_call_arguments(inputs.get("tool_calls")) + if input_type == "response" + else [] + ) + if not has_text and not tool_call_args: + return inputs + + source = "user_input" if input_type == "request" else "model_output" + metadata = self._collect_metadata(request_data, logging_obj) + + result_texts: List[str] = [] + for index, text in enumerate(texts): + if not isinstance(text, str) or not text.strip(): + result_texts.append(text) + continue + + try: + analysis = await self._analyze( + text=text, source=source, metadata=metadata + ) + except ( + httpx.HTTPError, + LiteLLMTimeout, + JSONDecodeError, + OSError, + ) as exc: + return self._handle_backend_failure( + exc, + inputs, + source, + result_texts + list(texts[index:]), + inputs.get("tool_calls"), + ) + + decision = analysis.get("decision") if isinstance(analysis, dict) else None + if decision not in _VALID_DECISIONS: + verbose_proxy_logger.error( + "Vigil Guard unrecognized decision for guardrail_name=%s " + "source=%s: %r", + self.guardrail_name, + source, + decision, + ) + if self.unreachable_fallback == "fail_open": + return self._build_output( + inputs, + result_texts + list(texts[index:]), + inputs.get("tool_calls"), + ) + raise GuardrailRaisedException( + guardrail_name=self.guardrail_name, + message="Vigil Guard returned an unrecognized decision.", + should_wrap_with_default_message=False, + ) + + if decision == "BLOCKED": + raise GuardrailRaisedException( + guardrail_name=self.guardrail_name, + message=self._build_block_reason(analysis), + should_wrap_with_default_message=False, + ) + + if decision == "SANITIZED": + result_texts.append(self._resolve_sanitized_text(text, analysis)) + else: + result_texts.append(text) + + result_tool_calls = inputs.get("tool_calls") + for tc_index, arguments in tool_call_args: + try: + analysis = await self._analyze( + text=arguments, source=source, metadata=metadata + ) + except ( + httpx.HTTPError, + LiteLLMTimeout, + JSONDecodeError, + OSError, + ) as exc: + return self._handle_backend_failure( + exc, inputs, source, result_texts, result_tool_calls + ) + + decision = analysis.get("decision") if isinstance(analysis, dict) else None + if decision not in _VALID_DECISIONS: + verbose_proxy_logger.error( + "Vigil Guard unrecognized decision for guardrail_name=%s " + "source=%s: %r", + self.guardrail_name, + source, + decision, + ) + if self.unreachable_fallback == "fail_open": + return self._build_output(inputs, result_texts, result_tool_calls) + raise GuardrailRaisedException( + guardrail_name=self.guardrail_name, + message="Vigil Guard returned an unrecognized decision.", + should_wrap_with_default_message=False, + ) + + if decision == "BLOCKED": + raise GuardrailRaisedException( + guardrail_name=self.guardrail_name, + message=self._build_block_reason(analysis), + should_wrap_with_default_message=False, + ) + + if decision == "SANITIZED": + result_tool_calls = self._set_tool_call_arguments( + result_tool_calls, + tc_index, + self._resolve_sanitized_text(arguments, analysis), + ) + + return self._build_output(inputs, result_texts, result_tool_calls) + + def _handle_backend_failure( + self, + exc: Exception, + inputs: GenericGuardrailAPIInputs, + source: str, + final_texts: List[Any], + final_tool_calls: Any, + ) -> GenericGuardrailAPIInputs: + if self.unreachable_fallback == "fail_open": + verbose_proxy_logger.error( + "Vigil Guard backend failure with fail_open; allowing request " + "unscanned. guardrail_name=%s source=%s error=%s", + self.guardrail_name, + source, + str(exc), + ) + return self._build_output(inputs, final_texts, final_tool_calls) + verbose_proxy_logger.error( + "Vigil Guard backend failure with fail_closed; blocking request. " + "guardrail_name=%s source=%s error=%s", + self.guardrail_name, + source, + str(exc), + ) + raise GuardrailRaisedException( + guardrail_name=self.guardrail_name, + message="Vigil Guard backend unreachable; request blocked by fail_closed policy.", + should_wrap_with_default_message=False, + ) from exc + + @staticmethod + def _build_output( + inputs: GenericGuardrailAPIInputs, + final_texts: List[Any], + final_tool_calls: Any, + ) -> GenericGuardrailAPIInputs: + # When nothing was changed, return the input shape verbatim so the guardrail + # logs "allow" rather than "mask". When a text or a tool-call argument was + # changed (sanitized), return only the remap-relevant keys and drop + # structured_messages so a stale, unsanitized payload cannot reach the model. + texts_changed = final_texts != (inputs.get("texts") or []) + tool_calls_changed = final_tool_calls != inputs.get("tool_calls") + if not texts_changed and not tool_calls_changed: + return cast(GenericGuardrailAPIInputs, dict(inputs)) + guardrailed: GenericGuardrailAPIInputs = {"texts": final_texts} + if "images" in inputs: + guardrailed["images"] = inputs["images"] + if "tools" in inputs: + guardrailed["tools"] = inputs["tools"] + if tool_calls_changed: + guardrailed["tool_calls"] = final_tool_calls + return guardrailed + + @staticmethod + def _tool_call_arguments(tool_calls: Any) -> List[Tuple[int, str]]: + pairs: List[Tuple[int, str]] = [] + if isinstance(tool_calls, list): + for index, tool_call in enumerate(tool_calls): + function = ( + tool_call.get("function") if isinstance(tool_call, dict) else None + ) + arguments = ( + function.get("arguments") if isinstance(function, dict) else None + ) + if isinstance(arguments, str) and arguments.strip(): + pairs.append((index, arguments)) + return pairs + + @staticmethod + def _set_tool_call_arguments( + tool_calls: Any, index: int, arguments: str + ) -> List[Any]: + updated = list(tool_calls) + tool_call = dict(updated[index]) + function = dict(tool_call.get("function") or {}) + function["arguments"] = arguments + tool_call["function"] = function + updated[index] = tool_call + return updated + + async def _analyze( + self, text: str, source: str, metadata: Dict[str, Any] + ) -> Dict[str, Any]: + payload = { + "text": text, + "source": source, + "mode": "full", + "metadata": metadata, + } + endpoint = f"{self.api_base}{_ANALYZE_ENDPOINT}" + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + } + response = await self._post_with_retry(endpoint, headers, payload) + return response.json() + + async def _post_with_retry( + self, endpoint: str, headers: Dict[str, str], payload: Dict[str, Any] + ) -> httpx.Response: + for attempt in range(2): + try: + response = await self.async_handler.post( + url=endpoint, + headers=headers, + json=payload, + timeout=self.timeout, + ) + response.raise_for_status() + return response + except Exception as exc: + if attempt == 0 and self._is_transient(exc): + verbose_proxy_logger.debug( + "Vigil Guard transient failure; retrying once: %s", + type(exc).__name__, + ) + continue + raise + raise AssertionError("unreachable") # pragma: no cover + + @staticmethod + def _is_transient(exc: Exception) -> bool: + if isinstance(exc, httpx.HTTPStatusError): + return exc.response.status_code in _TRANSIENT_STATUS_CODES + return isinstance( + exc, + ( + httpx.ConnectError, + httpx.ConnectTimeout, + httpx.ReadTimeout, + httpx.RemoteProtocolError, + LiteLLMTimeout, + ), + ) + + @staticmethod + def _build_block_reason(analysis: Dict[str, Any]) -> str: + for key in ("blockMessage", "decisionReason"): + value = analysis.get(key) + if isinstance(value, str) and value.strip(): + return value.strip()[:_BLOCK_REASON_MAX_CHARS] + categories = analysis.get("categories") + if isinstance(categories, list): + names = [c for c in categories if isinstance(c, str) and c.strip()] + if names: + return ", ".join(names)[:_BLOCK_REASON_MAX_CHARS] + return "Blocked by policy" + + @staticmethod + def _resolve_sanitized_text(original: str, analysis: Dict[str, Any]) -> str: + for key in ("sanitizedText", "outputText"): + value = analysis.get(key) + if isinstance(value, str): + return value + return original + + def _collect_metadata( + self, request_data: dict, logging_obj: Optional["LiteLLMLoggingObj"] + ) -> Dict[str, Any]: + sources: List[dict] = [] + if isinstance(request_data, dict): + sources.append(request_data) + for nested_key in ("metadata", "litellm_metadata"): + nested = request_data.get(nested_key) + if isinstance(nested, dict): + sources.append(nested) + + collected: Dict[str, Any] = {} + for field in _METADATA_ALLOWLIST: + for source in sources: + if field in source and source[field] is not None: + clamped = self._clamp_metadata_value(source[field]) + if clamped is not None: + collected[field] = clamped + break + + call_id = self._extract_call_id(request_data, logging_obj) + if call_id: + collected["litellm_call_id"] = call_id + + return collected + + @staticmethod + def _clamp_metadata_value(value: Any) -> Any: + if isinstance(value, bool): + return None + if isinstance(value, str): + return value[:_METADATA_STRING_MAX_CHARS] + if isinstance(value, (int, float)): + return value + if isinstance(value, list): + clamped: List[Any] = [] + for item in value[:_METADATA_ARRAY_MAX_ITEMS]: + if isinstance(item, bool): + continue + if isinstance(item, str): + clamped.append(item[:_METADATA_STRING_MAX_CHARS]) + elif isinstance(item, (int, float)): + clamped.append(item) + return clamped or None + return None + + @staticmethod + def _extract_call_id( + request_data: dict, logging_obj: Optional["LiteLLMLoggingObj"] + ) -> Optional[str]: + if logging_obj is not None: + call_id = getattr(logging_obj, "litellm_call_id", None) + if isinstance(call_id, str) and call_id: + return call_id + if isinstance(request_data, dict): + call_id = request_data.get("litellm_call_id") + if isinstance(call_id, str) and call_id: + return call_id + metadata = request_data.get("metadata") + if isinstance(metadata, dict): + nested = metadata.get("litellm_call_id") + if isinstance(nested, str) and nested: + return nested + return None diff --git a/litellm/proxy/guardrails/guardrail_initializers.py b/litellm/proxy/guardrails/guardrail_initializers.py index 109f223716..9af4395083 100644 --- a/litellm/proxy/guardrails/guardrail_initializers.py +++ b/litellm/proxy/guardrails/guardrail_initializers.py @@ -217,7 +217,15 @@ def initialize_panw_prisma_airs(litellm_params, guardrail): mask_response_content=getattr(litellm_params, "mask_response_content", False), app_name=getattr(litellm_params, "app_name", None), fallback_on_error=getattr(litellm_params, "fallback_on_error", "block"), - timeout=float(getattr(litellm_params, "timeout", 10.0)), + # `timeout` is now declared on BaseLitellmParams (Optional[float] = None), + # so the attribute always exists. The Pydantic validator on LitellmParams + # coerces strings to float, but None still means "use handler default" — + # guard against float(None) here. + timeout=( + float(getattr(litellm_params, "timeout", None)) + if getattr(litellm_params, "timeout", None) is not None + else 10.0 + ), violation_message_template=litellm_params.violation_message_template, ) litellm.logging_callback_manager.add_litellm_callback(_panw_callback) diff --git a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py index e94f56302a..7c3a6f1901 100644 --- a/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/llm_passthrough_endpoints.py @@ -2433,3 +2433,89 @@ def create_generic_websocket_passthrough_endpoint( _forward_headers=forward_headers, cost_per_request=cost_per_request, ) + + +@router.api_route( + "/watsonx/{endpoint:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + tags=["Watsonx Pass-through", "pass-through"], +) +async def watsonx_proxy_route( + endpoint: str, + request: Request, + fastapi_response: Response, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Watsonx pass-through endpoint. + Allows using Watsonx APIs with automatic IAM token management and version parameter injection. + + Example: + POST /watsonx/ml/v1/text/tokenization + POST /watsonx/ml/v1/text/generation + """ + # Direct passthrough with WatsonxPassthroughConfig + from litellm.types.utils import LlmProviders + from litellm.utils import ProviderConfigManager + + provider_config = ProviderConfigManager.get_provider_passthrough_config( + provider=LlmProviders.WATSONX, + model="", + ) + + if provider_config is None: + raise HTTPException( + status_code=404, detail="Watsonx passthrough config not found" + ) + + # Get complete URL with version parameter + complete_url, _ = provider_config.get_complete_url( + api_base=None, + api_key=None, + model="", + endpoint=endpoint, + request_query_params=None, + litellm_params={}, + ) + + # Get auth headers with IAM token + auth_headers = provider_config.validate_environment( + headers={}, + model="", + messages=[], + optional_params={}, + litellm_params={}, + api_key=None, + api_base=None, + ) + + # Check for streaming + is_streaming_request = False + if request.method == "POST": + if "multipart/form-data" not in request.headers.get("content-type", ""): + _request_body = await request.json() + else: + _request_body = await get_form_data(request) + + if _request_body.get("stream"): + is_streaming_request = True + + request_query_params = dict(request.query_params) + if request_query_params.get("version") is None: + request_query_params["version"] = litellm.WATSONX_DEFAULT_API_VERSION + + # Create pass-through endpoint + endpoint_func = create_pass_through_route( + endpoint=endpoint, + target=str(complete_url), + custom_headers=auth_headers, + is_streaming_request=is_streaming_request, + custom_llm_provider="watsonx", + query_params=request_query_params, + ) + + return await endpoint_func( + request, + fastapi_response, + user_api_key_dict, + ) diff --git a/litellm/proxy/spend_tracking/spend_management_endpoints.py b/litellm/proxy/spend_tracking/spend_management_endpoints.py index c7ecd64c0f..6c57fd95b3 100644 --- a/litellm/proxy/spend_tracking/spend_management_endpoints.py +++ b/litellm/proxy/spend_tracking/spend_management_endpoints.py @@ -36,9 +36,18 @@ router = APIRouter() dependencies=[Depends(user_api_key_auth)], include_in_schema=False, ) -async def spend_key_fn(): +async def spend_key_fn( + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): """ - View all keys created, ordered by spend + View keys created, ordered by spend. + + - Admin callers (PROXY_ADMIN / PROXY_ADMIN_VIEW_ONLY) see every key in + the database. + - All other callers (INTERNAL_USER / INTERNAL_USER_VIEW_ONLY, etc.) are + scoped to keys they own (``user_id == caller``). A caller with no + ``user_id`` has no scope and receives an empty list rather than the + full table. Example Request: ``` @@ -55,8 +64,17 @@ async def spend_key_fn(): "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" ) - key_info = await prisma_client.get_data(table_name="key", query_type="find_all") - return key_info + if _is_admin_view_safe(user_api_key_dict=user_api_key_dict): + return await prisma_client.get_data(table_name="key", query_type="find_all") + + caller_user_id = user_api_key_dict.user_id + if not caller_user_id: + return [] + return await prisma_client.get_data( + table_name="key", + query_type="find_all", + user_id=caller_user_id, + ) except Exception as e: raise HTTPException( @@ -85,9 +103,19 @@ async def spend_user_fn( default=None, description="Get User Table row for user_id", ), + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ - View all users created, ordered by spend + View users created, ordered by spend. + + - Admin callers (PROXY_ADMIN / PROXY_ADMIN_VIEW_ONLY) see every user, or + a specific user when ``user_id`` is supplied. + - All other callers may only read their own row. If they supply a + ``user_id`` query parameter that does not match their authenticated + ``user_id`` the request is rejected with HTTP 403; supplying their + own id (or none at all) returns just their row. A caller with no + ``user_id`` on their key has no scope and receives an empty list + rather than the full table. Example Request: ``` @@ -109,6 +137,17 @@ async def spend_user_fn( "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" ) + if not _is_admin_view_safe(user_api_key_dict=user_api_key_dict): + caller_user_id = user_api_key_dict.user_id + if not caller_user_id: + return [] + if user_id is not None and user_id != caller_user_id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail={"error": "Not authorized to view spend for another user."}, + ) + user_id = caller_user_id + if user_id is not None: user_info = await prisma_client.get_data( table_name="user", query_type="find_unique", user_id=user_id @@ -123,6 +162,8 @@ async def spend_user_fn( _strip_password_from_users(result) return result + except HTTPException: + raise except Exception as e: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, diff --git a/litellm/types/guardrails.py b/litellm/types/guardrails.py index 744d467f87..25c0bcabb4 100644 --- a/litellm/types/guardrails.py +++ b/litellm/types/guardrails.py @@ -41,6 +41,9 @@ from litellm.types.proxy.guardrails.guardrail_hooks.hiddenlayer import ( from litellm.types.proxy.guardrails.guardrail_hooks.qohash import ( QostodianNexusConfigModel, ) +from litellm.types.proxy.guardrails.guardrail_hooks.vigil_guard import ( + VigilGuardGuardrailConfigModel, +) """ Pydantic object defining how to set guardrails on litellm proxy @@ -103,6 +106,7 @@ class SupportedGuardrailIntegrations(Enum): LLM_AS_A_JUDGE = "llm_as_a_judge" QOSTODIAN_NEXUS = "qostodian_nexus" RUBRIK = "rubrik" + VIGIL_GUARD = "vigil_guard" class Role(Enum): @@ -758,6 +762,15 @@ class BaseLitellmParams( description="Python-like code containing the apply_guardrail function for custom guardrail logic", ) + timeout: Optional[float] = Field( + default=None, + description=( + "Per-request timeout for the guardrail provider API call (seconds). " + "Accepts int, float, or numeric string; coerced to float on load. " + "Each guardrail handler chooses its own default when unset." + ), + ) + model_config = ConfigDict(extra="allow", protected_namespaces=()) @@ -791,6 +804,7 @@ class LitellmParams( BlockCodeExecutionGuardrailConfigModel, HiddenlayerGuardrailConfigModel, QostodianNexusConfigModel, + VigilGuardGuardrailConfigModel, ): guardrail: str = Field(description="The type of guardrail integration to use") mode: Union[str, List[str], Mode] = Field( @@ -814,6 +828,18 @@ class LitellmParams( return [x.lower() if isinstance(x, str) else x for x in v] return v + @field_validator("timeout", mode="before", check_fields=False) + @classmethod + def coerce_timeout(cls, v): + """Accept string-valued timeouts (dashboard UI sends JSON strings) + and coerce to float before any handler reads the value.""" + if v is None or v == "": + return None + try: + return float(v) + except (TypeError, ValueError) as e: + raise ValueError(f"timeout must be numeric, got {v!r}") from e + def __init__(self, **kwargs): default_on = kwargs.pop("default_on", None) if default_on is not None: diff --git a/litellm/types/integrations/prometheus.py b/litellm/types/integrations/prometheus.py index 827d10985c..55f4fc9650 100644 --- a/litellm/types/integrations/prometheus.py +++ b/litellm/types/integrations/prometheus.py @@ -238,6 +238,9 @@ DEFINED_PROMETHEUS_METRICS = Literal[ "litellm_cache_hits_metric", "litellm_cache_misses_metric", "litellm_cached_tokens_metric", + # Provider prompt-caching metrics (e.g. OpenAI/Anthropic/Bedrock/Gemini) + "litellm_provider_cache_read_input_tokens_metric", + "litellm_provider_cache_creation_input_tokens_metric", "litellm_deployment_tpm_limit", "litellm_deployment_rpm_limit", "litellm_remaining_api_key_requests_for_model", @@ -655,6 +658,10 @@ class PrometheusMetricLabels: litellm_cache_misses_metric = _cache_metric_labels litellm_cached_tokens_metric = _cache_metric_labels + # Provider prompt-caching metrics - track tokens read/written to provider caches + litellm_provider_cache_read_input_tokens_metric = _cache_metric_labels + litellm_provider_cache_creation_input_tokens_metric = _cache_metric_labels + # Metrics whose emission paths supply org context (used by get_labels) _org_label_metrics: ClassVar[frozenset] = frozenset( { @@ -672,7 +679,6 @@ class PrometheusMetricLabels: "litellm_output_tokens_metric", } ) - # Managed batch metrics _batch_user_labels = [ UserAPIKeyLabelNames.v1_LITELLM_MODEL_NAME.value, diff --git a/litellm/types/proxy/guardrails/guardrail_hooks/vigil_guard.py b/litellm/types/proxy/guardrails/guardrail_hooks/vigil_guard.py new file mode 100644 index 0000000000..6d41c24ecc --- /dev/null +++ b/litellm/types/proxy/guardrails/guardrail_hooks/vigil_guard.py @@ -0,0 +1,26 @@ +from typing import Optional + +from pydantic import Field + +from .base import GuardrailConfigModel + + +class VigilGuardGuardrailConfigModel(GuardrailConfigModel): + api_base: Optional[str] = Field( + default=None, + description=( + "Vigil Guard API base URL. " + "Falls back to the VIGIL_GUARD_URL environment variable." + ), + ) + api_key: Optional[str] = Field( + default=None, + description=( + "Vigil Guard API key. " + "Falls back to the VIGIL_GUARD_API_KEY environment variable." + ), + ) + + @staticmethod + def ui_friendly_name() -> str: + return "Vigil Guard" diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 7f22a7cc21..a0d8f78b3b 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -3365,6 +3365,7 @@ class LlmProviders(str, Enum): POE = "poe" CHUTES = "chutes" XIAOMI_MIMO = "xiaomi_mimo" + TENSORMESH = "tensormesh" LITELLM_AGENT = "litellm_agent" CURSOR = "cursor" BEDROCK_MANTLE = "bedrock_mantle" diff --git a/litellm/utils.py b/litellm/utils.py index a3a26c338b..0a2bf53228 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8985,6 +8985,12 @@ class ProviderConfigManager: ) return AzurePassthroughConfig() + elif LlmProviders.WATSONX == provider: + from litellm.llms.watsonx.passthrough.transformation import ( + WatsonxPassthroughConfig, + ) + + return WatsonxPassthroughConfig() return None @staticmethod diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 112096f9b5..b2698ab3a8 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -577,7 +577,10 @@ "max_tokens": 8192, "mode": "embedding", "output_cost_per_token": 0.0, - "output_vector_size": 1024 + "output_vector_size": 1024, + "provider_specific_entry": { + "bedrock_invocation_schema": "titan_v2" + } }, "amazon.titan-image-generator-v1": { "input_cost_per_image": 0.0, @@ -8899,15 +8902,16 @@ "cache_creation_input_token_cost": 3.75e-07 }, "bedrock/us-gov-east-1/anthropic.claude-sonnet-4-5-20250929-v1:0": { - "cache_creation_input_token_cost": 4.125e-06, - "cache_read_input_token_cost": 3.3e-07, - "input_cost_per_token": 3.3e-06, + "cache_creation_input_token_cost": 4.5e-06, + "cache_creation_input_token_cost_above_1hr": 7.2e-06, + "cache_read_input_token_cost": 3.6e-07, + "input_cost_per_token": 3.6e-06, "litellm_provider": "bedrock", "max_input_tokens": 200000, "max_output_tokens": 8192, "max_tokens": 8192, "mode": "chat", - "output_cost_per_token": 1.65e-05, + "output_cost_per_token": 1.8e-05, "supports_assistant_prefill": true, "supports_computer_use": true, "supports_function_calling": true, @@ -8920,15 +8924,16 @@ "supports_native_structured_output": true }, "bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0": { - "cache_creation_input_token_cost": 4.125e-06, - "cache_read_input_token_cost": 3.3e-07, - "input_cost_per_token": 3.3e-06, + "cache_creation_input_token_cost": 4.5e-06, + "cache_creation_input_token_cost_above_1hr": 7.2e-06, + "cache_read_input_token_cost": 3.6e-07, + "input_cost_per_token": 3.6e-06, "litellm_provider": "bedrock", "max_input_tokens": 200000, "max_output_tokens": 8192, "max_tokens": 8192, "mode": "chat", - "output_cost_per_token": 1.65e-05, + "output_cost_per_token": 1.8e-05, "supports_assistant_prefill": true, "supports_computer_use": true, "supports_function_calling": true, @@ -9072,15 +9077,16 @@ "cache_creation_input_token_cost": 3.75e-07 }, "bedrock/us-gov-west-1/anthropic.claude-sonnet-4-5-20250929-v1:0": { - "cache_creation_input_token_cost": 4.125e-06, - "cache_read_input_token_cost": 3.3e-07, - "input_cost_per_token": 3.3e-06, + "cache_creation_input_token_cost": 4.5e-06, + "cache_creation_input_token_cost_above_1hr": 7.2e-06, + "cache_read_input_token_cost": 3.6e-07, + "input_cost_per_token": 3.6e-06, "litellm_provider": "bedrock", "max_input_tokens": 200000, "max_output_tokens": 8192, "max_tokens": 8192, "mode": "chat", - "output_cost_per_token": 1.65e-05, + "output_cost_per_token": 1.8e-05, "supports_assistant_prefill": true, "supports_computer_use": true, "supports_function_calling": true, @@ -9093,15 +9099,16 @@ "supports_native_structured_output": true }, "bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0": { - "cache_creation_input_token_cost": 4.125e-06, - "cache_read_input_token_cost": 3.3e-07, - "input_cost_per_token": 3.3e-06, + "cache_creation_input_token_cost": 4.5e-06, + "cache_creation_input_token_cost_above_1hr": 7.2e-06, + "cache_read_input_token_cost": 3.6e-07, + "input_cost_per_token": 3.6e-06, "litellm_provider": "bedrock", "max_input_tokens": 200000, "max_output_tokens": 8192, "max_tokens": 8192, "mode": "chat", - "output_cost_per_token": 1.65e-05, + "output_cost_per_token": 1.8e-05, "supports_assistant_prefill": true, "supports_computer_use": true, "supports_function_calling": true, @@ -31733,19 +31740,21 @@ "supports_native_structured_output": true }, "us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0": { - "cache_creation_input_token_cost": 4.125e-06, - "cache_read_input_token_cost": 3.3e-07, - "input_cost_per_token": 3.3e-06, - "input_cost_per_token_above_200k_tokens": 6.6e-06, - "output_cost_per_token_above_200k_tokens": 2.475e-05, - "cache_creation_input_token_cost_above_200k_tokens": 8.25e-06, - "cache_read_input_token_cost_above_200k_tokens": 6.6e-07, + "cache_creation_input_token_cost": 4.5e-06, + "cache_creation_input_token_cost_above_1hr": 7.2e-06, + "cache_read_input_token_cost": 3.6e-07, + "input_cost_per_token": 3.6e-06, + "input_cost_per_token_above_200k_tokens": 7.2e-06, + "output_cost_per_token_above_200k_tokens": 2.7e-05, + "cache_creation_input_token_cost_above_200k_tokens": 9.0e-06, + "cache_creation_input_token_cost_above_1hr_above_200k_tokens": 1.44e-05, + "cache_read_input_token_cost_above_200k_tokens": 7.2e-07, "litellm_provider": "bedrock_converse", "max_input_tokens": 200000, "max_output_tokens": 64000, "max_tokens": 64000, "mode": "chat", - "output_cost_per_token": 1.65e-05, + "output_cost_per_token": 1.8e-05, "supports_assistant_prefill": true, "supports_computer_use": true, "supports_function_calling": true, @@ -41332,6 +41341,7 @@ }, "bedrock/us-gov-east-1/anthropic.claude-haiku-4-5-20251001-v1:0": { "cache_creation_input_token_cost": 1.5e-06, + "cache_creation_input_token_cost_above_1hr": 2.4e-06, "cache_read_input_token_cost": 1.2e-07, "input_cost_per_token": 1.2e-06, "litellm_provider": "bedrock", @@ -41354,6 +41364,7 @@ }, "bedrock/us-gov-west-1/anthropic.claude-haiku-4-5-20251001-v1:0": { "cache_creation_input_token_cost": 1.5e-06, + "cache_creation_input_token_cost_above_1hr": 2.4e-06, "cache_read_input_token_cost": 1.2e-07, "input_cost_per_token": 1.2e-06, "litellm_provider": "bedrock", diff --git a/provider_endpoints_support.json b/provider_endpoints_support.json index 388752b032..abd03e1c95 100644 --- a/provider_endpoints_support.json +++ b/provider_endpoints_support.json @@ -2079,6 +2079,24 @@ "a2a": false } }, + "tensormesh": { + "display_name": "Tensormesh (`tensormesh`)", + "url": "https://docs.litellm.ai/docs/providers/tensormesh", + "endpoints": { + "chat_completions": true, + "messages": true, + "responses": false, + "embeddings": false, + "image_generations": false, + "audio_transcriptions": false, + "audio_speech": false, + "moderations": false, + "batches": false, + "rerank": false, + "a2a": false, + "text_completion": true + } + }, "text-completion-codestral": { "display_name": "Text Completion Codestral (`text-completion-codestral`)", "url": "https://docs.litellm.ai/docs/providers/codestral", diff --git a/tests/test_litellm/integrations/focus/test_focus_transformer.py b/tests/test_litellm/integrations/focus/test_focus_transformer.py new file mode 100644 index 0000000000..7e90f7d0a2 --- /dev/null +++ b/tests/test_litellm/integrations/focus/test_focus_transformer.py @@ -0,0 +1,69 @@ +"""Tests for FocusTransformer — ConsumedQuantity / PricingQuantity correctness.""" + +from __future__ import annotations + +from decimal import Decimal + +import polars as pl + +from litellm.integrations.focus.transformer import FocusTransformer + + +def _base_row(**overrides) -> dict: + row = { + "date": "2026-05-25", + "user_id": "u1", + "api_key": "sk-test", + "api_key_alias": "my-key", + "model": "gpt-4o", + "model_group": "openai", + "custom_llm_provider": "openai", + "spend": 0.05, + "api_requests": 3, + "team_id": "team1", + "team_alias": "Engineering", + "user_email": "user@example.com", + } + row.update(overrides) + return row + + +def _transform(rows: list[dict]) -> pl.DataFrame: + frame = pl.DataFrame(rows, infer_schema_length=None) + return FocusTransformer().transform(frame) + + +def test_consumed_quantity_reflects_api_requests(): + result = _transform([_base_row(api_requests=7)]) + assert result["ConsumedQuantity"][0] == Decimal("7.000000") + + +def test_pricing_quantity_reflects_api_requests(): + result = _transform([_base_row(api_requests=7)]) + assert result["PricingQuantity"][0] == Decimal("7.000000") + + +def test_null_api_requests_falls_back_to_zero_not_one(): + """Rows with NULL api_requests (old schema rows) must produce 0, not 1.""" + result = _transform([_base_row(api_requests=None)]) + assert result["ConsumedQuantity"][0] == Decimal("0.000000") + assert result["PricingQuantity"][0] == Decimal("0.000000") + + +def test_zero_api_requests_stays_zero(): + result = _transform([_base_row(api_requests=0)]) + assert result["ConsumedQuantity"][0] == Decimal("0.000000") + assert result["PricingQuantity"][0] == Decimal("0.000000") + + +def test_bigint_api_requests_cast_correctly(): + """api_requests comes from Postgres as BigInt — large values must not overflow.""" + result = _transform([_base_row(api_requests=1_000_000)]) + assert result["ConsumedQuantity"][0] == Decimal("1000000.000000") + assert result["PricingQuantity"][0] == Decimal("1000000.000000") + + +def test_consumed_and_pricing_quantity_match(): + """ConsumedQuantity and PricingQuantity must always be equal.""" + result = _transform([_base_row(api_requests=42)]) + assert result["ConsumedQuantity"][0] == result["PricingQuantity"][0] diff --git a/tests/test_litellm/integrations/test_prometheus_cache_metrics.py b/tests/test_litellm/integrations/test_prometheus_cache_metrics.py index 88148ce137..6c9923322f 100644 --- a/tests/test_litellm/integrations/test_prometheus_cache_metrics.py +++ b/tests/test_litellm/integrations/test_prometheus_cache_metrics.py @@ -35,6 +35,8 @@ class TestPrometheusCacheMetrics: assert "litellm_cache_hits_metric" in defined_metrics assert "litellm_cache_misses_metric" in defined_metrics assert "litellm_cached_tokens_metric" in defined_metrics + assert "litellm_provider_cache_read_input_tokens_metric" in defined_metrics + assert "litellm_provider_cache_creation_input_tokens_metric" in defined_metrics def test_cache_metric_labels_defined(self): """Test that cache metric labels are properly defined""" @@ -44,6 +46,13 @@ class TestPrometheusCacheMetrics: assert hasattr(PrometheusMetricLabels, "litellm_cache_hits_metric") assert hasattr(PrometheusMetricLabels, "litellm_cache_misses_metric") assert hasattr(PrometheusMetricLabels, "litellm_cached_tokens_metric") + assert hasattr( + PrometheusMetricLabels, "litellm_provider_cache_read_input_tokens_metric" + ) + assert hasattr( + PrometheusMetricLabels, + "litellm_provider_cache_creation_input_tokens_metric", + ) # Verify labels include expected keys expected_labels = [ @@ -59,6 +68,14 @@ class TestPrometheusCacheMetrics: assert label in PrometheusMetricLabels.litellm_cache_hits_metric assert label in PrometheusMetricLabels.litellm_cache_misses_metric assert label in PrometheusMetricLabels.litellm_cached_tokens_metric + assert ( + label + in PrometheusMetricLabels.litellm_provider_cache_read_input_tokens_metric + ) + assert ( + label + in PrometheusMetricLabels.litellm_provider_cache_creation_input_tokens_metric + ) def test_increment_cache_metrics_on_cache_hit(self, sample_enum_values): """Test that cache hit increments the correct metrics""" @@ -76,12 +93,20 @@ class TestPrometheusCacheMetrics: "completion_tokens": 50, "model_group": "openai", "request_tags": [], + "metadata": { + "usage_object": { + "cache_read_input_tokens": 25, + "cache_creation_input_tokens": 10, + } + }, } # Create mock metrics mock_logger.litellm_cache_hits_metric = MagicMock() mock_logger.litellm_cache_misses_metric = MagicMock() mock_logger.litellm_cached_tokens_metric = MagicMock() + mock_logger.litellm_provider_cache_read_input_tokens_metric = MagicMock() + mock_logger.litellm_provider_cache_creation_input_tokens_metric = MagicMock() mock_logger.get_labels_for_metric = MagicMock( return_value=[ "model", @@ -114,6 +139,14 @@ class TestPrometheusCacheMetrics: # Verify cache misses metric was NOT called mock_logger.litellm_cache_misses_metric.labels.assert_not_called() + # Verify provider prompt caching metrics were incremented + mock_logger.litellm_provider_cache_read_input_tokens_metric.labels().inc.assert_called_once_with( + 25 + ) + mock_logger.litellm_provider_cache_creation_input_tokens_metric.labels().inc.assert_called_once_with( + 10 + ) + def test_increment_cache_metrics_on_cache_miss(self, sample_enum_values): """Test that cache miss increments the correct metrics""" # Create mock for PrometheusLogger instance @@ -129,12 +162,20 @@ class TestPrometheusCacheMetrics: "completion_tokens": 50, "model_group": "openai", "request_tags": [], + "metadata": { + "usage_object": { + # Explicit provider field absent -> fallback should use prompt_tokens_details.cached_tokens + "prompt_tokens_details": {"cached_tokens": 20}, + } + }, } # Create mock metrics mock_logger.litellm_cache_hits_metric = MagicMock() mock_logger.litellm_cache_misses_metric = MagicMock() mock_logger.litellm_cached_tokens_metric = MagicMock() + mock_logger.litellm_provider_cache_read_input_tokens_metric = MagicMock() + mock_logger.litellm_provider_cache_creation_input_tokens_metric = MagicMock() mock_logger.get_labels_for_metric = MagicMock( return_value=[ "model", @@ -162,6 +203,61 @@ class TestPrometheusCacheMetrics: mock_logger.litellm_cache_hits_metric.labels.assert_not_called() mock_logger.litellm_cached_tokens_metric.labels.assert_not_called() + # Provider prompt caching metrics should still be emitted + mock_logger.litellm_provider_cache_read_input_tokens_metric.labels().inc.assert_called_once_with( + 20 + ) + mock_logger.litellm_provider_cache_creation_input_tokens_metric.labels.assert_not_called() + + def test_provider_cache_read_does_not_fallback_on_explicit_zero( + self, sample_enum_values + ): + """Explicit cache_read_input_tokens=0 must not trigger fallback to cached_tokens.""" + mock_logger = MagicMock() + + from litellm.integrations.prometheus import PrometheusLogger + + standard_logging_payload = { + "cache_hit": False, + "total_tokens": 100, + "prompt_tokens": 50, + "completion_tokens": 50, + "model_group": "openai", + "request_tags": [], + "metadata": { + "usage_object": { + "cache_read_input_tokens": 0, + "prompt_tokens_details": {"cached_tokens": 20}, + } + }, + } + + mock_logger.litellm_cache_hits_metric = MagicMock() + mock_logger.litellm_cache_misses_metric = MagicMock() + mock_logger.litellm_cached_tokens_metric = MagicMock() + mock_logger.litellm_provider_cache_read_input_tokens_metric = MagicMock() + mock_logger.litellm_provider_cache_creation_input_tokens_metric = MagicMock() + mock_logger.get_labels_for_metric = MagicMock( + return_value=[ + "model", + "hashed_api_key", + "api_key_alias", + "team", + "team_alias", + "end_user", + "user", + ] + ) + + PrometheusLogger._increment_cache_metrics( + mock_logger, + standard_logging_payload=standard_logging_payload, + enum_values=sample_enum_values, + ) + + # Should not emit read metric, because explicit provider value is zero. + mock_logger.litellm_provider_cache_read_input_tokens_metric.labels.assert_not_called() + def test_increment_cache_metrics_when_cache_hit_is_none(self, sample_enum_values): """Test that no metrics are incremented when cache_hit is None""" # Create mock for PrometheusLogger instance @@ -177,12 +273,19 @@ class TestPrometheusCacheMetrics: "completion_tokens": 50, "model_group": "openai", "request_tags": [], + "metadata": { + "usage_object": { + "cache_read_input_tokens": 25, + } + }, } # Create mock metrics mock_logger.litellm_cache_hits_metric = MagicMock() mock_logger.litellm_cache_misses_metric = MagicMock() mock_logger.litellm_cached_tokens_metric = MagicMock() + mock_logger.litellm_provider_cache_read_input_tokens_metric = MagicMock() + mock_logger.litellm_provider_cache_creation_input_tokens_metric = MagicMock() mock_logger.get_labels_for_metric = MagicMock( return_value=[ "model", @@ -207,6 +310,12 @@ class TestPrometheusCacheMetrics: mock_logger.litellm_cache_misses_metric.labels.assert_not_called() mock_logger.litellm_cached_tokens_metric.labels.assert_not_called() + # Provider prompt caching metrics should still be emitted + mock_logger.litellm_provider_cache_read_input_tokens_metric.labels().inc.assert_called_once_with( + 25 + ) + mock_logger.litellm_provider_cache_creation_input_tokens_metric.labels.assert_not_called() + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_litellm/litellm_core_utils/test_fallback_utils.py b/tests/test_litellm/litellm_core_utils/test_fallback_utils.py new file mode 100644 index 0000000000..0c542ff6a1 --- /dev/null +++ b/tests/test_litellm/litellm_core_utils/test_fallback_utils.py @@ -0,0 +1,43 @@ +import pytest + +import litellm +from litellm.litellm_core_utils.fallback_utils import async_completion_with_fallbacks + + +@pytest.mark.asyncio +async def test_fallback_dict_not_mutated(monkeypatch): + fallback_dict = {"model": "fallback-model", "temperature": 0.2} + original_fallback_dict = dict(fallback_dict) + + attempted_models: list[str] = [] + + async def _fake_acompletion(*, model: str, **kwargs): + attempted_models.append(model) + if model == "primary-model": + raise Exception("primary failed") + return {"model": model, "temperature": kwargs.get("temperature")} + + monkeypatch.setattr(litellm, "acompletion", _fake_acompletion) + + # Call 1: primary fails, fallback dict succeeds + response_1 = await async_completion_with_fallbacks( + model="primary-model", + kwargs={"fallbacks": [fallback_dict]}, + ) + assert response_1["model"] == "fallback-model" + assert fallback_dict == original_fallback_dict + + # Call 2: re-use the same dict object; it should still work and remain unchanged + response_2 = await async_completion_with_fallbacks( + model="primary-model", + kwargs={"fallbacks": [fallback_dict]}, + ) + assert response_2["model"] == "fallback-model" + assert fallback_dict == original_fallback_dict + + assert attempted_models == [ + "primary-model", + "fallback-model", + "primary-model", + "fallback-model", + ] diff --git a/tests/test_litellm/llms/bedrock/files/expected_bedrock_batch_embeddings.jsonl b/tests/test_litellm/llms/bedrock/files/expected_bedrock_batch_embeddings.jsonl new file mode 100644 index 0000000000..e798c39b79 --- /dev/null +++ b/tests/test_litellm/llms/bedrock/files/expected_bedrock_batch_embeddings.jsonl @@ -0,0 +1,3 @@ +{"recordId": "embed-1", "modelInput": {"inputText": "Hello world"}} +{"recordId": "embed-2", "modelInput": {"inputText": "Another document to embed", "dimensions": 512}} +{"recordId": "embed-3", "modelInput": {"inputText": "Single element list", "embeddingTypes": ["binary"]}} diff --git a/tests/test_litellm/llms/bedrock/files/input_batch_embeddings.jsonl b/tests/test_litellm/llms/bedrock/files/input_batch_embeddings.jsonl new file mode 100644 index 0000000000..f87b4eba7e --- /dev/null +++ b/tests/test_litellm/llms/bedrock/files/input_batch_embeddings.jsonl @@ -0,0 +1,3 @@ +{"custom_id": "embed-1", "method": "POST", "url": "/v1/embeddings", "body": {"model": "bedrock/amazon.titan-embed-text-v2:0", "input": "Hello world"}} +{"custom_id": "embed-2", "method": "POST", "url": "/v1/embeddings", "body": {"model": "bedrock/amazon.titan-embed-text-v2:0", "input": "Another document to embed", "dimensions": 512}} +{"custom_id": "embed-3", "method": "POST", "url": "/v1/embeddings", "body": {"model": "bedrock/amazon.titan-embed-text-v2:0", "input": ["Single element list"], "encoding_format": "base64"}} diff --git a/tests/test_litellm/llms/bedrock/files/test_bedrock_files_transformation.py b/tests/test_litellm/llms/bedrock/files/test_bedrock_files_transformation.py index 5245612e9d..ba41fc47e8 100644 --- a/tests/test_litellm/llms/bedrock/files/test_bedrock_files_transformation.py +++ b/tests/test_litellm/llms/bedrock/files/test_bedrock_files_transformation.py @@ -426,7 +426,7 @@ class TestBedrockFilesTransformation: "s3_bucket_name": "litellm-batch-352026", "s3_region_name": "us-gov-west-1", } - # aws_region_name set to something different — s3_region_name must still win + # aws_region_name set to something different - s3_region_name must still win optional_params = {"aws_region_name": "us-east-1"} captured_optional_params: dict = {} @@ -482,3 +482,630 @@ class TestBedrockFilesTransformation: assert "messages" in model_input assert "max_tokens" in model_input assert model_input["max_tokens"] == 10 + + +class TestBedrockFilesEmbeddingTransformation: + """ + Tests for routing OpenAI /v1/embeddings batch JSONL records through the + Titan v2 transformer so AWS Bedrock's CreateModelInvocationJob receives + a valid modelInput body. + + Scope is intentionally Titan v2 only - other embedding models will get + their own follow-up PRs/tests so each schema is exercised in isolation. + """ + + def test_titan_v2_embedding_jsonl_matches_fixture(self): + """Round-trip the input fixture against the expected Bedrock output.""" + import json + import os + + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + here = os.path.dirname(__file__) + with open(os.path.join(here, "input_batch_embeddings.jsonl")) as f: + openai_jsonl = [json.loads(line) for line in f if line.strip()] + with open(os.path.join(here, "expected_bedrock_batch_embeddings.jsonl")) as f: + expected = [json.loads(line) for line in f if line.strip()] + + result = config._transform_openai_jsonl_content_to_bedrock_jsonl_content( + openai_jsonl + ) + + assert result == expected + + def test_titan_v2_simple_string_input(self): + """Single string `input` maps to `{"inputText": }` with no extras.""" + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + result = config._transform_openai_jsonl_content_to_bedrock_jsonl_content( + [ + { + "custom_id": "e1", + "method": "POST", + "url": "/v1/embeddings", + "body": { + "model": "bedrock/amazon.titan-embed-text-v2:0", + "input": "Hello", + }, + } + ] + ) + + assert result == [{"recordId": "e1", "modelInput": {"inputText": "Hello"}}] + + def test_titan_v2_dimensions_and_encoding_format(self): + """OpenAI `dimensions` / `encoding_format` map to Titan v2 schema.""" + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + result = config._transform_openai_jsonl_content_to_bedrock_jsonl_content( + [ + { + "custom_id": "e1", + "method": "POST", + "url": "/v1/embeddings", + "body": { + "model": "bedrock/amazon.titan-embed-text-v2:0", + "input": "Hi", + "dimensions": 256, + "encoding_format": "float", + }, + } + ] + ) + + model_input = result[0]["modelInput"] + assert model_input["inputText"] == "Hi" + assert model_input["dimensions"] == 256 + assert model_input["embeddingTypes"] == ["float"] + + def test_embedding_routing_falls_back_to_body_shape(self): + """Records without `url` still route via `input` presence.""" + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + result = config._transform_openai_jsonl_content_to_bedrock_jsonl_content( + [ + { + "custom_id": "e1", + "body": { + "model": "bedrock/amazon.titan-embed-text-v2:0", + "input": "Hello", + }, + } + ] + ) + + assert result[0]["modelInput"] == {"inputText": "Hello"} + + def test_embedding_single_element_list_input_is_accepted(self): + """A single-element list maps to the same shape as a bare string.""" + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + result = config._transform_openai_jsonl_content_to_bedrock_jsonl_content( + [ + { + "custom_id": "e1", + "method": "POST", + "url": "/v1/embeddings", + "body": { + "model": "bedrock/amazon.titan-embed-text-v2:0", + "input": ["only one"], + }, + } + ] + ) + + assert result[0]["modelInput"]["inputText"] == "only one" + + def test_embedding_multi_input_list_raises(self): + """Multi-element `input` lists are rejected with a clear message.""" + import pytest + + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + with pytest.raises(ValueError, match="one input per JSONL record"): + config._transform_openai_jsonl_content_to_bedrock_jsonl_content( + [ + { + "custom_id": "e1", + "method": "POST", + "url": "/v1/embeddings", + "body": { + "model": "bedrock/amazon.titan-embed-text-v2:0", + "input": ["a", "b"], + }, + } + ] + ) + + def test_embedding_missing_input_raises(self): + """A record routed to /v1/embeddings without `input` is an error.""" + import pytest + + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + with pytest.raises(ValueError, match="missing required `input`"): + config._transform_openai_jsonl_content_to_bedrock_jsonl_content( + [ + { + "custom_id": "e1", + "method": "POST", + "url": "/v1/embeddings", + "body": {"model": "bedrock/amazon.titan-embed-text-v2:0"}, + } + ] + ) + + def test_mixed_chat_and_embedding_in_same_batch(self): + """Chat and embedding records in the same JSONL each take their path.""" + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + result = config._transform_openai_jsonl_content_to_bedrock_jsonl_content( + [ + { + "custom_id": "chat-1", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "bedrock/us.anthropic.claude-haiku-4-5-20251001-v1:0", + "messages": [{"role": "user", "content": "Hi"}], + "max_tokens": 5, + }, + }, + { + "custom_id": "embed-1", + "method": "POST", + "url": "/v1/embeddings", + "body": { + "model": "bedrock/amazon.titan-embed-text-v2:0", + "input": "Hi", + }, + }, + ] + ) + + assert result[0]["recordId"] == "chat-1" + assert "messages" in result[0]["modelInput"] + assert result[0]["modelInput"]["anthropic_version"] == "bedrock-2023-05-31" + + assert result[1]["recordId"] == "embed-1" + assert result[1]["modelInput"] == {"inputText": "Hi"} + + def test_unsupported_embedding_model_raises_not_implemented(self): + """Cohere/Nova/Titan-G1 embed get a clear NotImplementedError, not a corrupt body.""" + import pytest + + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + for unsupported_model in ( + "bedrock/cohere.embed-english-v3", + "bedrock/amazon.titan-embed-text-v1", + "bedrock/amazon.titan-embed-image-v1", + "bedrock/amazon.nova-2-multimodal-embeddings-v1:0", + ): + with pytest.raises(NotImplementedError, match="titan-embed-text-v2"): + config._transform_openai_jsonl_content_to_bedrock_jsonl_content( + [ + { + "custom_id": "e1", + "method": "POST", + "url": "/v1/embeddings", + "body": {"model": unsupported_model, "input": "Hi"}, + } + ] + ) + + def test_titan_v2_model_name_variants_route_correctly(self): + """All common Titan v2 model id shapes route through the embedding path.""" + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + for model_id in ( + "amazon.titan-embed-text-v2:0", + "bedrock/amazon.titan-embed-text-v2:0", + "us.amazon.titan-embed-text-v2:0", + "bedrock/us.amazon.titan-embed-text-v2:0", + ): + result = config._transform_openai_jsonl_content_to_bedrock_jsonl_content( + [ + { + "custom_id": "e1", + "method": "POST", + "url": "/v1/embeddings", + "body": {"model": model_id, "input": "Hi"}, + } + ] + ) + assert result[0]["modelInput"] == { + "inputText": "Hi" + }, f"model id {model_id} did not route to Titan v2 embedding path" + + def test_pretokenized_input_list_of_ints_raises(self): + """`input: List[int]` (pre-tokenized) is rejected, not silently mis-shaped.""" + import pytest + + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + with pytest.raises( + (NotImplementedError, ValueError), match=r"pre-tokenized|one input per" + ): + config._transform_openai_jsonl_content_to_bedrock_jsonl_content( + [ + { + "custom_id": "e1", + "method": "POST", + "url": "/v1/embeddings", + "body": { + "model": "bedrock/amazon.titan-embed-text-v2:0", + "input": [1, 2, 3], + }, + } + ] + ) + + def test_pretokenized_single_wrapped_list_raises(self): + """`input: List[List[int]]` with one element is rejected as pre-tokenized.""" + import pytest + + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + with pytest.raises(NotImplementedError, match="pre-tokenized"): + config._transform_openai_jsonl_content_to_bedrock_jsonl_content( + [ + { + "custom_id": "e1", + "method": "POST", + "url": "/v1/embeddings", + "body": { + "model": "bedrock/amazon.titan-embed-text-v2:0", + "input": [[1, 2, 3]], + }, + } + ] + ) + + def test_record_with_both_input_and_messages_routes_to_chat(self): + """If a record has both fields, chat wins (safer default - see helper docstring).""" + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + result = config._transform_openai_jsonl_content_to_bedrock_jsonl_content( + [ + { + "custom_id": "ambiguous-1", + "body": { + "model": "bedrock/us.anthropic.claude-haiku-4-5-20251001-v1:0", + "messages": [{"role": "user", "content": "Hi"}], + "input": "this should be ignored by chat path", + "max_tokens": 5, + }, + } + ] + ) + + assert "messages" in result[0]["modelInput"] + assert "inputText" not in result[0]["modelInput"] + + def test_url_embeddings_with_missing_input_raises_not_chat_error(self): + """url says embed, body lacks input → embedding-path error, not chat-path crash.""" + import pytest + + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + with pytest.raises(ValueError, match="missing required `input`"): + config._transform_openai_jsonl_content_to_bedrock_jsonl_content( + [ + { + "custom_id": "e1", + "method": "POST", + "url": "/v1/embeddings", + "body": {"model": "bedrock/amazon.titan-embed-text-v2:0"}, + } + ] + ) + + def test_titan_v2_marker_boundary_rejects_lookalikes(self): + """The marker must end at `:`, `/`, or end-of-string to avoid false positives.""" + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + # Look-alikes that must NOT route through the Titan v2 path + for model in ( + "bedrock/amazon.titan-embed-text-v20:0", + "bedrock/amazon.titan-embed-text-v2-experimental:0", + "bedrock/amazon.titan-embed-text-v2foo", + ): + assert not BedrockFilesConfig._is_titan_v2_embed_model( + model + ), f"{model} unexpectedly matched the Titan v2 marker" + + # Real Titan v2 ids that MUST match + for model in ( + "amazon.titan-embed-text-v2:0", + "bedrock/amazon.titan-embed-text-v2:0", + "us.amazon.titan-embed-text-v2:0", + "arn:aws:bedrock:us-east-1:123:foundation-model/amazon.titan-embed-text-v2:0", + ): + assert BedrockFilesConfig._is_titan_v2_embed_model( + model + ), f"{model} unexpectedly missed the Titan v2 marker" + + def test_titan_v2_accepted_when_registry_schema_field_matches(self, mocker): + """Registry-driven happy path: nested + `provider_specific_entry.bedrock_invocation_schema == "titan_v2"` + is the authoritative signal.""" + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + mocker.patch( + "litellm.get_model_info", + return_value={ + "provider_specific_entry": {"bedrock_invocation_schema": "titan_v2"} + }, + ) + assert BedrockFilesConfig._is_titan_v2_embed_model( + "amazon.titan-embed-text-v2:0" + ) + + def test_titan_v2_rejected_when_registry_schema_field_differs(self, mocker): + """Registry resolves with a different schema value (e.g. a hypothetical + Cohere Embed entry) -> reject. Registry is authoritative; no substring + second-chance for ids the registry knows.""" + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + mocker.patch( + "litellm.get_model_info", + return_value={ + "provider_specific_entry": {"bedrock_invocation_schema": "cohere_v3"} + }, + ) + # Even though the model id looks like Titan v2, the registry says + # otherwise and we trust it. + assert not BedrockFilesConfig._is_titan_v2_embed_model( + "amazon.titan-embed-text-v2:0" + ) + + def test_titan_v2_falls_back_to_marker_when_registry_lacks_schema_field( + self, mocker + ): + """Registry resolves but the entry has no + `provider_specific_entry.bedrock_invocation_schema` field yet (e.g. + a stale local registry) -> fall through to substring.""" + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + # No provider_specific_entry at all + mocker.patch( + "litellm.get_model_info", + return_value={"mode": "embedding"}, + ) + assert BedrockFilesConfig._is_titan_v2_embed_model( + "amazon.titan-embed-text-v2:0" + ) + + # provider_specific_entry present but missing the schema key + mocker.patch( + "litellm.get_model_info", + return_value={ + "mode": "embedding", + "provider_specific_entry": {"unrelated": "value"}, + }, + ) + assert BedrockFilesConfig._is_titan_v2_embed_model( + "amazon.titan-embed-text-v2:0" + ) + + def test_titan_v2_accepted_when_registry_silent(self, mocker): + """Marker-only match is fine for ids the registry can't resolve + (cross-region profile prefixes, ARN forms).""" + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + mocker.patch("litellm.get_model_info", side_effect=Exception("not mapped")) + assert BedrockFilesConfig._is_titan_v2_embed_model( + "us.amazon.titan-embed-text-v2:0" + ) + assert BedrockFilesConfig._is_titan_v2_embed_model( + "arn:aws:bedrock:us-east-1:123:foundation-model/amazon.titan-embed-text-v2:0" + ) + + def test_lookup_provider_specific_field_helper(self, mocker): + """Direct coverage of the nested registry field helper.""" + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + # Happy path: returns the nested field's string value + mocker.patch( + "litellm.get_model_info", + return_value={ + "provider_specific_entry": {"bedrock_invocation_schema": "titan_v2"} + }, + ) + assert ( + BedrockFilesConfig._lookup_provider_specific_field( + "anything", "bedrock_invocation_schema" + ) + == "titan_v2" + ) + + # Registry raises -> None + mocker.patch("litellm.get_model_info", side_effect=Exception("not mapped")) + assert ( + BedrockFilesConfig._lookup_provider_specific_field("anything", "any") + is None + ) + + # Registry returns non-dict -> None + mocker.patch("litellm.get_model_info", return_value="not a dict") + assert ( + BedrockFilesConfig._lookup_provider_specific_field("anything", "any") + is None + ) + + # Registry returns dict without provider_specific_entry -> None + mocker.patch("litellm.get_model_info", return_value={"mode": "embedding"}) + assert ( + BedrockFilesConfig._lookup_provider_specific_field( + "anything", "bedrock_invocation_schema" + ) + is None + ) + + # provider_specific_entry exists but isn't a dict -> None + mocker.patch( + "litellm.get_model_info", + return_value={"provider_specific_entry": "not a dict"}, + ) + assert ( + BedrockFilesConfig._lookup_provider_specific_field( + "anything", "bedrock_invocation_schema" + ) + is None + ) + + # provider_specific_entry dict missing the requested field -> None + mocker.patch( + "litellm.get_model_info", + return_value={"provider_specific_entry": {"unrelated": "x"}}, + ) + assert ( + BedrockFilesConfig._lookup_provider_specific_field( + "anything", "bedrock_invocation_schema" + ) + is None + ) + + # Non-string nested value -> None + mocker.patch( + "litellm.get_model_info", + return_value={"provider_specific_entry": {"bedrock_invocation_schema": 42}}, + ) + assert ( + BedrockFilesConfig._lookup_provider_specific_field( + "anything", "bedrock_invocation_schema" + ) + is None + ) + + # Empty-string nested value -> None + mocker.patch( + "litellm.get_model_info", + return_value={"provider_specific_entry": {"bedrock_invocation_schema": ""}}, + ) + assert ( + BedrockFilesConfig._lookup_provider_specific_field( + "anything", "bedrock_invocation_schema" + ) + is None + ) + + def test_is_embedding_record_helper(self): + """Helper detects embeddings via `url` first, then by body shape.""" + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + assert BedrockFilesConfig._is_embedding_record( + {"url": "/v1/embeddings", "body": {"input": "x"}} + ) + # body-only fallback + assert BedrockFilesConfig._is_embedding_record({"body": {"input": "x"}}) + # chat shape + assert not BedrockFilesConfig._is_embedding_record( + {"url": "/v1/chat/completions", "body": {"messages": []}} + ) + # ambiguous body without `input` is treated as not-embedding + assert not BedrockFilesConfig._is_embedding_record({"body": {}}) + + def test_explicit_chat_url_with_input_body_short_circuits_to_chat(self): + """Explicit url=/v1/chat/completions wins even if body looks like embedding. + + Without this short-circuit, a chat record whose body happens to carry + `input` (and no `messages`) would be mis-routed to the embedding + transformer, corrupting the modelInput. + """ + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + # Direct helper assertion + assert not BedrockFilesConfig._is_embedding_record( + { + "url": "/v1/chat/completions", + "body": { + "model": "bedrock/us.anthropic.claude-haiku-4-5-20251001-v1:0", + "input": "this would mis-route under the old precedence", + }, + } + ) + + # End-to-end: a record like this routes through the chat path. We + # just need to make sure we DON'T silently produce an inputText + # body and call it a chat completion. + config = BedrockFilesConfig() + result = config._transform_openai_jsonl_content_to_bedrock_jsonl_content( + [ + { + "custom_id": "explicit-chat-with-input", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "bedrock/us.anthropic.claude-haiku-4-5-20251001-v1:0", + "messages": [{"role": "user", "content": "Hi"}], + "input": "should not become inputText", + "max_tokens": 5, + }, + } + ] + ) + + model_input = result[0]["modelInput"] + assert ( + "inputText" not in model_input + ), "explicit chat URL must not produce an embedding-shaped modelInput" + + def test_coerce_embedding_input_helper_isolated(self): + """Direct coverage of the extracted input-normalization helper.""" + import pytest + + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + # Happy paths + assert BedrockFilesConfig._coerce_embedding_input_to_string("hello") == "hello" + assert ( + BedrockFilesConfig._coerce_embedding_input_to_string(["hello"]) == "hello" + ) + + # Error paths + with pytest.raises(ValueError, match="missing required `input`"): + BedrockFilesConfig._coerce_embedding_input_to_string(None, model="m") + with pytest.raises(ValueError, match="one input per JSONL record"): + BedrockFilesConfig._coerce_embedding_input_to_string(["a", "b"]) + # A multi-element list of ints is rejected as "one input per JSONL + # record" too - we can't tell if it's pre-tokenized or "3 strings" + # without more context, so the most-actionable error wins. + with pytest.raises(ValueError, match="one input per JSONL record"): + BedrockFilesConfig._coerce_embedding_input_to_string([1, 2, 3]) + # Single-element list wrapping a token list -> pre-tokenized error. + with pytest.raises(NotImplementedError, match="pre-tokenized"): + BedrockFilesConfig._coerce_embedding_input_to_string([[1, 2, 3]]) + # Single-element list wrapping a bare int -> pre-tokenized error. + with pytest.raises(NotImplementedError, match="pre-tokenized"): + BedrockFilesConfig._coerce_embedding_input_to_string([42]) + with pytest.raises(ValueError, match="must be a string"): + BedrockFilesConfig._coerce_embedding_input_to_string({"unsupported": True}) + + def test_other_non_embedding_urls_route_to_chat(self): + """Any non-/v1/embeddings url short-circuits to chat path.""" + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + # /v1/completions (legacy completions endpoint) + assert not BedrockFilesConfig._is_embedding_record( + {"url": "/v1/completions", "body": {"input": "x"}} + ) + # Arbitrary unknown url - caller's explicit signal still wins + assert not BedrockFilesConfig._is_embedding_record( + {"url": "/v1/responses", "body": {"input": "x"}} + ) diff --git a/tests/test_litellm/llms/bedrock/test_claude_platform_provider.py b/tests/test_litellm/llms/bedrock/test_claude_platform_provider.py index 74cfbb265c..dbded8e0a2 100644 --- a/tests/test_litellm/llms/bedrock/test_claude_platform_provider.py +++ b/tests/test_litellm/llms/bedrock/test_claude_platform_provider.py @@ -1,8 +1,9 @@ import json -from unittest.mock import patch +from unittest.mock import MagicMock, patch import httpx import pytest +from botocore.credentials import Credentials def _anthropic_response(url: str) -> httpx.Response: @@ -310,3 +311,54 @@ async def test_anthropic_messages_routes_bedrock_claude_platform_to_messages_api assert requests[0]["body"]["messages"] == [{"role": "user", "content": "hello"}] assert requests[0]["body"]["max_tokens"] == 10 assert requests[0]["body"]["model"] == "claude-sonnet-4-6" + + +def test_sigv4_no_duplicate_content_type_when_caller_sets_lowercase(): + """ + Regression: get_anthropic_headers() supplies "content-type" (lowercase). + _sign_request() used to prepend "Content-Type" (uppercase), leaving both + keys in the dict. botocore joins them into "application/json, application/json" + in the canonical string, while the wire request sends only one value → 401. + + Fix: prepend with lowercase "content-type" so **headers overwrites it when + the caller already set it. + """ + from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM + + llm = BaseAWSLLM() + mock_credentials = Credentials("key", "secret", "token") + mock_sigv4 = MagicMock() + captured: list[dict] = [] + + def fake_aws_request(method, url, data, headers): + captured.append(dict(headers)) + req = MagicMock() + req.headers = {"Authorization": "AWS4-HMAC-SHA256 Credential=test"} + req.body = data.encode() if isinstance(data, str) else data + return req + + with ( + patch("botocore.auth.SigV4Auth", return_value=mock_sigv4), + patch("botocore.awsrequest.AWSRequest", side_effect=fake_aws_request), + patch.object(llm, "get_credentials", return_value=mock_credentials), + patch.object(llm, "_get_aws_region_name", return_value="us-east-1"), + ): + llm._sign_request( + service_name="aws-external-anthropic", + headers={"content-type": "application/json"}, + optional_params={"aws_region_name": "us-east-1"}, + request_data={ + "model": "claude-sonnet-4-6", + "messages": [], + "max_tokens": 10, + }, + api_base="https://aws-external-anthropic.us-east-1.api.aws/v1/messages", + ) + + signed = captured[0] + ct_keys = [k for k in signed if k.lower() == "content-type"] + assert ct_keys == ["content-type"], ( + f"Expected exactly one 'content-type' key, got {ct_keys}. " + "Duplicate keys produce 'application/json, application/json' in the " + "SigV4 canonical string and cause a 401." + ) diff --git a/tests/test_litellm/llms/black_forest_labs/test_bfl_common_utils.py b/tests/test_litellm/llms/black_forest_labs/test_bfl_common_utils.py new file mode 100644 index 0000000000..dc1d21bd03 --- /dev/null +++ b/tests/test_litellm/llms/black_forest_labs/test_bfl_common_utils.py @@ -0,0 +1,67 @@ +""" +Tests for Black Forest Labs common_utils — specifically assert_bfl_polling_url. + +BFL uses regional subdomains (e.g. gateway.bfl.ai) for polling URLs that +differ from the submission host (api.bfl.ai). These tests verify that the +domain-aware check accepts legitimate BFL subdomains while still rejecting +off-domain and non-HTTPS URLs. +""" + +import pytest + +from litellm.llms.black_forest_labs.common_utils import ( + BlackForestLabsError, + assert_bfl_polling_url, +) + + +class TestAssertBflPollingUrl: + # --- should pass --- + + def test_exact_registered_domain(self): + assert_bfl_polling_url("https://bfl.ai/v1/get_result?id=abc") + + def test_api_subdomain(self): + assert_bfl_polling_url("https://api.bfl.ai/v1/get_result?id=abc") + + def test_gateway_subdomain(self): + # BFL uses gateway.bfl.ai for polling — this was the original bug trigger + assert_bfl_polling_url("https://gateway.bfl.ai/v1/get_result?id=abc") + + def test_regional_subdomain(self): + assert_bfl_polling_url("https://eu.api.bfl.ai/v1/get_result?id=abc") + + def test_deep_subdomain(self): + assert_bfl_polling_url("https://region.gateway.bfl.ai/poll?id=xyz") + + # --- should raise BlackForestLabsError --- + + def test_rejects_http_scheme(self): + # HTTP must be rejected — x-key would be forwarded in plaintext + with pytest.raises(BlackForestLabsError, match="scheme must be https"): + assert_bfl_polling_url("http://api.bfl.ai/v1/get_result?id=abc") + + def test_rejects_off_domain(self): + with pytest.raises(BlackForestLabsError, match="host is not within"): + assert_bfl_polling_url("https://evil.com/steal-key") + + def test_rejects_lookalike_domain(self): + with pytest.raises(BlackForestLabsError, match="host is not within"): + assert_bfl_polling_url("https://notbfl.ai/v1/get_result?id=abc") + + def test_rejects_bfl_ai_as_suffix_only(self): + # "fakebfl.ai" must not match — the check is on registered domain boundary + with pytest.raises(BlackForestLabsError, match="host is not within"): + assert_bfl_polling_url("https://fakebfl.ai/v1/get_result?id=abc") + + def test_rejects_bfl_in_path(self): + with pytest.raises(BlackForestLabsError, match="host is not within"): + assert_bfl_polling_url("https://evil.com/bfl.ai/steal") + + def test_rejects_ftp_scheme(self): + with pytest.raises(BlackForestLabsError, match="scheme must be https"): + assert_bfl_polling_url("ftp://api.bfl.ai/v1/get_result?id=abc") + + def test_rejects_javascript_scheme(self): + with pytest.raises(BlackForestLabsError, match="scheme must be https"): + assert_bfl_polling_url("javascript://api.bfl.ai/alert(1)") diff --git a/tests/test_litellm/llms/openai/chat/guardrail_translation/test_openai_guardrail_handler.py b/tests/test_litellm/llms/openai/chat/guardrail_translation/test_openai_guardrail_handler.py index a2c3700294..4c268d9dfc 100644 --- a/tests/test_litellm/llms/openai/chat/guardrail_translation/test_openai_guardrail_handler.py +++ b/tests/test_litellm/llms/openai/chat/guardrail_translation/test_openai_guardrail_handler.py @@ -8,8 +8,7 @@ with guardrail transformations, including tool calls. import json import os import sys -from typing import Any, List, Literal, Optional, Tuple -from unittest.mock import AsyncMock, MagicMock +from typing import Any, Literal, Optional import pytest @@ -84,6 +83,70 @@ class MockGuardrail(CustomGuardrail): return result +class MockCopiedToolCallGuardrail(CustomGuardrail): + """Mock guardrail that returns copied tool calls instead of mutating inputs.""" + + async def apply_guardrail( + self, + inputs: GenericGuardrailAPIInputs, + request_data: dict, + input_type: Literal["request", "response"], + logging_obj: Optional[Any] = None, + ) -> GenericGuardrailAPIInputs: + tool_calls = inputs.get("tool_calls", []) + copied_tool_calls = [] + for tool_call in tool_calls: + copied = dict(tool_call) + function = dict(copied["function"]) + function["arguments"] = json.dumps({"email": "[EMAIL]"}) + copied["function"] = function + copied_tool_calls.append(copied) + + return GenericGuardrailAPIInputs( + texts=inputs.get("texts", []), + tool_calls=copied_tool_calls, + ) + + +class MockNonListToolCallGuardrail(CustomGuardrail): + """Mock guardrail that returns tool_calls as a non-list envelope on the response + path, as some released guardrails do when they assign a detection API JSON dict.""" + + async def apply_guardrail( + self, + inputs: GenericGuardrailAPIInputs, + request_data: dict, + input_type: Literal["request", "response"], + logging_obj: Optional[Any] = None, + ) -> GenericGuardrailAPIInputs: + result = GenericGuardrailAPIInputs(texts=inputs.get("texts", [])) + result["tool_calls"] = {"verdict": "allow", "detections": []} # type: ignore + return result + + +class MockMisalignedToolCallGuardrail(CustomGuardrail): + """Mock guardrail that returns a tool_calls list whose length differs from the + input, so it cannot be applied positionally onto the response.""" + + async def apply_guardrail( + self, + inputs: GenericGuardrailAPIInputs, + request_data: dict, + input_type: Literal["request", "response"], + logging_obj: Optional[Any] = None, + ) -> GenericGuardrailAPIInputs: + tool_calls = inputs.get("tool_calls", []) + shortened = [] + if tool_calls: + first = dict(tool_calls[0]) + first["function"] = {"name": "x", "arguments": json.dumps({"x": 1})} + shortened.append(first) + return GenericGuardrailAPIInputs( + texts=inputs.get("texts", []), + tool_calls=shortened, + ) + + class TestOpenAIChatCompletionsHandlerToolsInput: """Test input processing with tools (function definitions)""" @@ -740,6 +803,131 @@ class TestOpenAIChatCompletionsHandlerToolCallsOutput: assert response.model == "gpt-4o-mini" assert response.choices[0].finish_reason == "tool_calls" + @pytest.mark.asyncio + async def test_output_response_uses_returned_guardrailed_tool_calls(self): + """Test returned tool_calls are remapped even when guardrail does not mutate inputs.""" + handler = OpenAIChatCompletionsHandler() + guardrail = MockCopiedToolCallGuardrail(guardrail_name="test") + + response = ModelResponse( + id="chatcmpl-tool-copy", + created=1234567890, + model="gpt-4", + object="chat.completion", + choices=[ + Choices( + finish_reason="tool_calls", + index=0, + message=Message( + content=None, + role="assistant", + tool_calls=[ + ChatCompletionMessageToolCall( + id="call_email", + type="function", + function=Function( + name="send_email", + arguments=json.dumps({"email": "john@example.com"}), + ), + ) + ], + ), + ) + ], + ) + + await handler.process_output_response(response, guardrail) + + response_tool_call = response.choices[0].message.tool_calls[0] + assert response_tool_call.function.name == "send_email" + assert json.loads(response_tool_call.function.arguments) == {"email": "[EMAIL]"} + + @pytest.mark.asyncio + async def test_output_response_ignores_non_list_returned_tool_calls(self): + """A guardrail returning tool_calls as a non-list (e.g. a detection-API envelope + dict) must not crash the remap; the original arguments are preserved.""" + handler = OpenAIChatCompletionsHandler() + guardrail = MockNonListToolCallGuardrail(guardrail_name="test") + original = json.dumps({"email": "john@example.com"}) + response = ModelResponse( + id="chatcmpl-nonlist", + created=1234567890, + model="gpt-4", + object="chat.completion", + choices=[ + Choices( + finish_reason="tool_calls", + index=0, + message=Message( + content=None, + role="assistant", + tool_calls=[ + ChatCompletionMessageToolCall( + id="call_email", + type="function", + function=Function( + name="send_email", arguments=original + ), + ) + ], + ), + ) + ], + ) + + await handler.process_output_response(response, guardrail) + + response_tool_call = response.choices[0].message.tool_calls[0] + assert response_tool_call.function.arguments == original + + @pytest.mark.asyncio + async def test_output_response_ignores_misaligned_returned_tool_calls(self): + """A guardrail returning a tool_calls list of a different length than the input + cannot be applied positionally; the handler falls back and preserves the + original arguments instead of writing onto the wrong tool call.""" + handler = OpenAIChatCompletionsHandler() + guardrail = MockMisalignedToolCallGuardrail(guardrail_name="test") + first_args = json.dumps({"email": "a@example.com"}) + second_args = json.dumps({"email": "b@example.com"}) + response = ModelResponse( + id="chatcmpl-misaligned", + created=1234567890, + model="gpt-4", + object="chat.completion", + choices=[ + Choices( + finish_reason="tool_calls", + index=0, + message=Message( + content=None, + role="assistant", + tool_calls=[ + ChatCompletionMessageToolCall( + id="call_1", + type="function", + function=Function( + name="send_email", arguments=first_args + ), + ), + ChatCompletionMessageToolCall( + id="call_2", + type="function", + function=Function( + name="send_email", arguments=second_args + ), + ), + ], + ), + ) + ], + ) + + await handler.process_output_response(response, guardrail) + + tool_calls = response.choices[0].message.tool_calls + assert tool_calls[0].function.arguments == first_args + assert tool_calls[1].function.arguments == second_args + class MockPassThroughGuardrail(CustomGuardrail): """Mock guardrail that passes through without blocking - for testing streaming fallback behavior""" @@ -765,7 +953,7 @@ class TestOpenAIChatCompletionsHandlerStreamingOutput: This test verifies the fix for the bug where accessing chunk.choices[0] would raise IndexError when a streaming chunk has an empty choices list. """ - from litellm.types.utils import Delta, ModelResponseStream, StreamingChoices + from litellm.types.utils import ModelResponseStream handler = OpenAIChatCompletionsHandler() guardrail = MockPassThroughGuardrail(guardrail_name="test") diff --git a/tests/test_litellm/llms/openai_like/test_tensormesh_provider.py b/tests/test_litellm/llms/openai_like/test_tensormesh_provider.py new file mode 100644 index 0000000000..f81f1c00a7 --- /dev/null +++ b/tests/test_litellm/llms/openai_like/test_tensormesh_provider.py @@ -0,0 +1,84 @@ +""" +Tests for Tensormesh provider configuration and integration. +""" + +import litellm + + +class TestTensormeshProviderConfig: + """Test Tensormesh provider configuration""" + + def test_tensormesh_in_provider_list(self): + """Test that tensormesh is in the provider list""" + from litellm import LlmProviders + + assert hasattr(LlmProviders, "TENSORMESH") + assert LlmProviders.TENSORMESH.value == "tensormesh" + assert "tensormesh" in litellm.provider_list + + def test_tensormesh_json_config_exists(self): + """Test that tensormesh is configured in providers.json""" + from litellm.llms.openai_like.json_loader import JSONProviderRegistry + + assert JSONProviderRegistry.exists("tensormesh") + + tensormesh = JSONProviderRegistry.get("tensormesh") + assert tensormesh is not None + assert tensormesh.base_url == "https://serverless.tensormesh.ai/v1" + assert tensormesh.api_key_env == "TENSORMESH_INFERENCE_API_KEY" + assert tensormesh.api_base_env == "TENSORMESH_SERVERLESS_BASE_URL" + assert tensormesh.param_mappings.get("max_completion_tokens") == "max_tokens" + + def test_tensormesh_provider_resolution(self): + """Test that provider resolution finds tensormesh and the default base URL""" + from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider + + model, provider, api_key, api_base = get_llm_provider( + model="tensormesh/openai/gpt-oss-120b", + custom_llm_provider=None, + api_base=None, + api_key=None, + ) + + assert model == "openai/gpt-oss-120b" + assert provider == "tensormesh" + assert api_base == "https://serverless.tensormesh.ai/v1" + + def test_tensormesh_api_base_override(self): + """Test that an explicit api_base / api_key overrides the serverless default""" + from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider + + model, provider, api_key, api_base = get_llm_provider( + model="tensormesh/openai/gpt-oss-120b", + custom_llm_provider=None, + api_base="https://custom.example.com/v1", + api_key="sk-test", + ) + + assert provider == "tensormesh" + assert api_base == "https://custom.example.com/v1" + assert api_key == "sk-test" + + def test_tensormesh_text_completion_enabled(self): + """Tensormesh is wired for the /completions (text completion) route, + matching the text_completion flag in provider_endpoints_support.json.""" + assert "tensormesh" in litellm.openai_text_completion_compatible_providers + + def test_tensormesh_router_config(self): + """Test that tensormesh can be used in Router configuration""" + from litellm import Router + + router = Router( + model_list=[ + { + "model_name": "tensormesh-chat", + "litellm_params": { + "model": "tensormesh/openai/gpt-oss-120b", + "api_key": "test-key", + }, + } + ] + ) + + assert len(router.model_list) == 1 + assert router.model_list[0]["model_name"] == "tensormesh-chat" diff --git a/tests/test_litellm/llms/watsonx/passthrough/test_watsonx_passthrough_transformation.py b/tests/test_litellm/llms/watsonx/passthrough/test_watsonx_passthrough_transformation.py new file mode 100644 index 0000000000..d1db04f521 --- /dev/null +++ b/tests/test_litellm/llms/watsonx/passthrough/test_watsonx_passthrough_transformation.py @@ -0,0 +1,282 @@ +""" +Unit tests for WatsonxPassthroughConfig transformation. + +Tests the Watsonx-specific passthrough configuration including URL construction, +streaming detection, and authentication handling. +""" + +import os +import sys +from unittest.mock import MagicMock, patch + +import httpx +import pytest + +sys.path.insert(0, os.path.abspath("../../../../..")) + +import litellm +from litellm.llms.watsonx.passthrough.transformation import WatsonxPassthroughConfig + + +class TestWatsonxPassthroughConfig: + """Tests for WatsonxPassthroughConfig class.""" + + def test_is_streaming_request_true(self): + """Test that streaming is detected when stream=True in request data.""" + config = WatsonxPassthroughConfig() + request_data = {"stream": True, "input": "test"} + + result = config.is_streaming_request( + endpoint="ml/v1/text/generation", request_data=request_data + ) + + assert result is True + + def test_is_streaming_request_false(self): + """Test that streaming is not detected when stream=False in request data.""" + config = WatsonxPassthroughConfig() + request_data = {"stream": False, "input": "test"} + + result = config.is_streaming_request( + endpoint="ml/v1/text/generation", request_data=request_data + ) + + assert result is False + + def test_is_streaming_request_missing_stream_key(self): + """Test that streaming defaults to False when stream key is missing.""" + config = WatsonxPassthroughConfig() + request_data = {"input": "test"} + + result = config.is_streaming_request( + endpoint="ml/v1/text/generation", request_data=request_data + ) + + assert result is False + + def test_get_complete_url_with_api_base(self): + """Test URL construction with explicit api_base.""" + config = WatsonxPassthroughConfig() + api_base = "https://us-south.ml.cloud.ibm.com" + endpoint = "ml/v1/text/generation" + request_query_params = {"version": "2024-03-19"} + + complete_url, base_target_url = config.get_complete_url( + api_base=api_base, + api_key=None, + model="ibm/granite-13b-chat-v2", + endpoint=endpoint, + request_query_params=request_query_params, + litellm_params={}, + ) + + assert isinstance(complete_url, httpx.URL) + assert str(complete_url).startswith(api_base) + assert endpoint in str(complete_url) + assert "version=2024-03-19" in str(complete_url) + assert base_target_url == api_base + + @patch("litellm.llms.watsonx.common_utils.get_secret_str") + def test_get_complete_url_with_env_api_base(self, mock_get_secret): + """Test URL construction with api_base from environment.""" + config = WatsonxPassthroughConfig() + env_api_base = "https://eu-de.ml.cloud.ibm.com" + mock_get_secret.return_value = env_api_base + + endpoint = "ml/v1/text/tokenization" + request_query_params = {"version": "2024-03-19"} + + complete_url, base_target_url = config.get_complete_url( + api_base=None, + api_key=None, + model="ibm/granite-13b-chat-v2", + endpoint=endpoint, + request_query_params=request_query_params, + litellm_params={}, + ) + + assert isinstance(complete_url, httpx.URL) + assert str(complete_url).startswith(env_api_base) + assert endpoint in str(complete_url) + assert base_target_url == env_api_base + + def test_get_complete_url_with_query_params(self): + """Test that query parameters are correctly added to URL.""" + config = WatsonxPassthroughConfig() + api_base = "https://us-south.ml.cloud.ibm.com" + endpoint = "ml/v1/text/generation" + request_query_params = { + "version": "2024-03-19", + } + + complete_url, _ = config.get_complete_url( + api_base=api_base, + api_key=None, + model="ibm/granite-13b-chat-v2", + endpoint=endpoint, + request_query_params=request_query_params, + litellm_params={}, + ) + + url_str = str(complete_url) + assert "version=2024-03-19" in url_str + + def test_get_complete_url_without_query_params(self): + """Test URL construction without query parameters.""" + config = WatsonxPassthroughConfig() + api_base = "https://us-south.ml.cloud.ibm.com" + endpoint = "ml/v1/models" + + complete_url, base_target_url = config.get_complete_url( + api_base=api_base, + api_key=None, + model="", + endpoint=endpoint, + request_query_params=None, + litellm_params={}, + ) + + assert isinstance(complete_url, httpx.URL) + assert str(complete_url) == f"{api_base}/{endpoint}" + assert base_target_url == api_base + assert "version=2024-03-19" not in str(complete_url) + + @patch("litellm.llms.watsonx.common_utils.get_secret_str") + def test_get_api_base_with_explicit_value(self, mock_get_secret): + """Test get_api_base returns explicit value when provided.""" + explicit_base = "https://custom.watsonx.com" + + result = WatsonxPassthroughConfig.get_api_base(api_base=explicit_base) + + assert result == explicit_base + mock_get_secret.assert_not_called() + + @patch("litellm.llms.watsonx.common_utils.get_secret_str") + def test_get_api_base_from_environment(self, mock_get_secret): + """Test get_api_base retrieves from environment when not provided.""" + env_base = "https://env.watsonx.com" + mock_get_secret.return_value = env_base + + result = WatsonxPassthroughConfig.get_api_base(api_base=None) + + assert result == env_base + mock_get_secret.assert_called_once_with("WATSONX_API_BASE") + + @patch("litellm.llms.watsonx.common_utils.get_secret_str") + def test_get_api_key_with_explicit_value(self, mock_get_secret): + """Test get_api_key returns explicit value when provided.""" + explicit_key = "test-api-key-123" + + result = WatsonxPassthroughConfig.get_api_key(api_key=explicit_key) + + assert result == explicit_key + mock_get_secret.assert_not_called() + + @patch("litellm.llms.watsonx.common_utils.get_secret_str") + def test_get_api_key_from_environment(self, mock_get_secret): + """Test get_api_key retrieves from environment when not provided.""" + env_key = "env-api-key-456" + mock_get_secret.return_value = env_key + + result = WatsonxPassthroughConfig.get_api_key(api_key=None) + + assert result == env_key + mock_get_secret.assert_any_call("WATSONX_APIKEY") + + def test_get_base_model_returns_model(self): + """Test get_base_model returns the model as-is.""" + model = "ibm/granite-13b-chat-v2" + + result = WatsonxPassthroughConfig.get_base_model(model) + + assert result == model + + def test_get_base_model_with_deployment(self): + """Test get_base_model with deployment model.""" + model = "deployment/test-deployment-id" + + result = WatsonxPassthroughConfig.get_base_model(model) + + assert result == model + + def test_get_complete_url_with_different_endpoints(self): + """Test URL construction with various endpoint paths.""" + config = WatsonxPassthroughConfig() + api_base = "https://us-south.ml.cloud.ibm.com" + + endpoints = [ + "ml/v1/text/generation", + "ml/v1/text/tokenization", + "ml/v1/deployments/test-id/text/generation", + "ml/v1/models", + "ml/v1/foundation_model_specs", + ] + + for endpoint in endpoints: + complete_url, base_target_url = config.get_complete_url( + api_base=api_base, + api_key=None, + model="", + endpoint=endpoint, + request_query_params={"version": "2024-03-19"}, + litellm_params={}, + ) + + assert isinstance(complete_url, httpx.URL) + assert endpoint in str(complete_url) + assert base_target_url == api_base + + def test_get_complete_url_preserves_query_param_order(self): + """Test that query parameters maintain their values correctly.""" + config = WatsonxPassthroughConfig() + api_base = "https://us-south.ml.cloud.ibm.com" + endpoint = "ml/v1/text/generation" + request_query_params = { + "version": "2024-03-19", + "project_id": "abc-123", + "space_id": "xyz-789", + } + + complete_url, _ = config.get_complete_url( + api_base=api_base, + api_key=None, + model="", + endpoint=endpoint, + request_query_params=request_query_params, + litellm_params={}, + ) + + url_str = str(complete_url) + # Verify all params are present + assert "version=2024-03-19" in url_str + assert "project_id=abc-123" in url_str + assert "space_id=xyz-789" in url_str + + def test_is_streaming_request_with_various_stream_values(self): + """Test streaming detection with different stream value types.""" + config = WatsonxPassthroughConfig() + + # Test with boolean True + assert config.is_streaming_request("endpoint", {"stream": True}) is True + + # Test with boolean False + assert config.is_streaming_request("endpoint", {"stream": False}) is False + + # Test with string "true" (truthy string) + result = config.is_streaming_request("endpoint", {"stream": "true"}) + assert result == "true" # Returns the value as-is from .get() + + # Test with integer 1 (truthy) + result = config.is_streaming_request("endpoint", {"stream": 1}) + assert result == 1 + + # Test with integer 0 (falsy) + result = config.is_streaming_request("endpoint", {"stream": 0}) + assert result == 0 + + # Test with None + result = config.is_streaming_request("endpoint", {"stream": None}) + assert result is None + + # Test with empty dict (defaults to False) + assert config.is_streaming_request("endpoint", {}) is False diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_panw_prisma_airs.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_panw_prisma_airs.py index 5c60e3e2bd..431a7aa6f0 100644 --- a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_panw_prisma_airs.py +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_panw_prisma_airs.py @@ -5375,5 +5375,93 @@ class TestPanwAirsDualScanIndependence: assert mcp_call.get("content") is None +class TestPanwAirsTimeoutCoercion: + """Regression tests for string-valued timeout handling. + + Before the fix, a string `timeout` (which is what the dashboard UI persists + and what raw YAML preserves if quoted) survived into httpx, which raised + `TypeError: '<=' not supported between instances of 'str' and 'int'`. The + broad except in apply_guardrail swallowed it and the proxy returned a + misleading 500 'Security scan failed - request blocked for safety'. + """ + + def test_handler_coerces_string_timeout_to_float(self): + handler = make_handler(timeout="30") + assert handler.timeout == 30.0 + assert isinstance(handler.timeout, float) + + def test_handler_accepts_int_timeout(self): + handler = make_handler(timeout=15) + assert handler.timeout == 15.0 + + def test_handler_accepts_float_timeout(self): + handler = make_handler(timeout=7.5) + assert handler.timeout == 7.5 + + def test_handler_none_timeout_falls_back_to_default(self): + handler = make_handler(timeout=None) + assert handler.timeout == 10.0 + + def test_handler_omitted_timeout_uses_default(self): + handler = make_handler() + assert handler.timeout == 10.0 + + def test_litellm_params_coerces_string_timeout(self): + """Boundary validation: the Pydantic model itself should normalize + string timeouts before any handler reads the value via model_dump().""" + params = LitellmParams( + guardrail="panw_prisma_airs", + mode="pre_call", + api_key="test_key", + profile_name="test_profile", + timeout="30", + ) + assert params.timeout == 30.0 + assert isinstance(params.timeout, float) + + def test_litellm_params_rejects_garbage_timeout(self): + with pytest.raises(ValueError): + LitellmParams( + guardrail="panw_prisma_airs", + mode="pre_call", + api_key="test_key", + profile_name="test_profile", + timeout="not-a-number", + ) + + def test_litellm_params_empty_string_timeout_becomes_none(self): + """Empty-string timeout (which the dashboard form can send) should + be coerced to None, not crash, and not produce float('').""" + params = LitellmParams( + guardrail="panw_prisma_airs", + mode="pre_call", + api_key="test_key", + profile_name="test_profile", + timeout="", + ) + assert params.timeout is None + + def test_legacy_initializer_handles_unset_timeout(self): + """Regression guard: with timeout now a declared Optional[float] = None + on BaseLitellmParams, the legacy panw initializer at + guardrail_initializers.py:220 must not crash on float(None) when the + caller omits timeout entirely.""" + from litellm.proxy.guardrails.guardrail_initializers import ( + initialize_panw_prisma_airs, + ) + + params = LitellmParams( + guardrail="panw_prisma_airs", + mode="pre_call", + api_key="test_key", + profile_name="test_profile", + # timeout intentionally omitted - field defaults to None + ) + guardrail_config = {"guardrail_name": "test_legacy"} + handler = initialize_panw_prisma_airs(params, guardrail_config) + # Default fallback applied, not crashed on float(None) + assert handler.timeout == 10.0 + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_vigil_guard.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_vigil_guard.py new file mode 100644 index 0000000000..7ee424a2c1 --- /dev/null +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_vigil_guard.py @@ -0,0 +1,900 @@ +import json +import logging +import ssl +from types import SimpleNamespace +from typing import Any, List + +import httpx +import pytest + +from litellm.exceptions import GuardrailRaisedException +from litellm.exceptions import Timeout as LiteLLMTimeout +from litellm.proxy.guardrails.guardrail_hooks.vigil_guard import ( + VigilGuardGuardrail, + guardrail_class_registry, + guardrail_initializer_registry, + initialize_guardrail, +) +from litellm.proxy.guardrails.guardrail_hooks.vigil_guard.vigil_guard import ( + _DEFAULT_VIGIL_TIMEOUT, + VigilGuardMissingConfig, +) +from litellm.types.guardrails import LitellmParams, SupportedGuardrailIntegrations +from litellm.types.proxy.guardrails.guardrail_hooks.vigil_guard import ( + VigilGuardGuardrailConfigModel, +) + +_ENDPOINT = "https://vigil.test/v1/guard/analyze" + + +def _resp(body: dict, status_code: int = 200) -> httpx.Response: + return httpx.Response( + status_code=status_code, + json=body, + request=httpx.Request("POST", _ENDPOINT), + ) + + +class FakeHandler: + def __init__(self, items: List[Any]): + self._items = list(items) + self.calls: List[SimpleNamespace] = [] + + async def post(self, *, url, headers, json, timeout=None): # noqa: A002 + self.calls.append( + SimpleNamespace(url=url, headers=headers, json=json, timeout=timeout) + ) + if not self._items: + raise AssertionError("FakeHandler ran out of programmed responses") + item = self._items.pop(0) + if isinstance(item, BaseException): + raise item + return item + + +def _make_guardrail( + handler: FakeHandler, + *, + unreachable_fallback="fail_closed", + api_base="https://vigil.test", + api_key="vg_secret_key_123", + guardrail_name="vigil-guard", + timeout=None, +) -> VigilGuardGuardrail: + return VigilGuardGuardrail( + api_base=api_base, + api_key=api_key, + unreachable_fallback=unreachable_fallback, + timeout=timeout, + async_handler=handler, + guardrail_name=guardrail_name, + event_hook="pre_call", + default_on=True, + ) + + +def _transient_exceptions() -> List[BaseException]: + req = httpx.Request("POST", _ENDPOINT) + return [ + httpx.ConnectError("boom", request=req), + httpx.ConnectTimeout("boom", request=req), + httpx.ReadTimeout("boom", request=req), + httpx.RemoteProtocolError("boom", request=req), + LiteLLMTimeout(message="t", model="m", llm_provider="vigil_guard"), + ] + + +def test_requires_api_base(monkeypatch): + monkeypatch.delenv("VIGIL_GUARD_URL", raising=False) + monkeypatch.delenv("VIGIL_GUARD_API_KEY", raising=False) + with pytest.raises(VigilGuardMissingConfig): + VigilGuardGuardrail(api_key="k", async_handler=FakeHandler([])) + + +def test_requires_api_key(monkeypatch): + monkeypatch.delenv("VIGIL_GUARD_API_KEY", raising=False) + with pytest.raises(VigilGuardMissingConfig): + VigilGuardGuardrail( + api_base="https://vigil.test", async_handler=FakeHandler([]) + ) + + +def test_trailing_slash_stripped(): + g = _make_guardrail(FakeHandler([]), api_base="https://vigil.test/") + assert g.api_base == "https://vigil.test" + + +def test_env_fallback(monkeypatch): + monkeypatch.setenv("VIGIL_GUARD_URL", "https://env.vigil.test") + monkeypatch.setenv("VIGIL_GUARD_API_KEY", "env_key") + g = VigilGuardGuardrail( + async_handler=FakeHandler([]), + guardrail_name="vg", + event_hook="pre_call", + default_on=True, + ) + assert g.api_base == "https://env.vigil.test" + assert g.api_key == "env_key" + + +def test_default_unreachable_fallback_is_fail_closed(): + g = _make_guardrail(FakeHandler([]), unreachable_fallback=None) + assert g.unreachable_fallback == "fail_closed" + + +def test_explicit_fail_open_is_stored(): + g = _make_guardrail(FakeHandler([]), unreachable_fallback="fail_open") + assert g.unreachable_fallback == "fail_open" + + +def test_unknown_fallback_defaults_to_fail_closed(): + g = _make_guardrail(FakeHandler([]), unreachable_fallback="weird") + assert g.unreachable_fallback == "fail_closed" + + +async def test_allowed_preserves_full_input_shape_and_logs_allow(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + structured = [{"role": "user", "content": "hello"}] + inputs = {"texts": ["hello"], "structured_messages": structured, "model": "gpt-4o"} + request_data = {"metadata": {}} + out = await g.apply_guardrail( + inputs=inputs, request_data=request_data, input_type="request", logging_obj=None + ) + assert out["texts"] == ["hello"] + assert out["structured_messages"] is structured + assert out["model"] == "gpt-4o" + assert out is not inputs + assert inputs["structured_messages"] is structured + assert len(handler.calls) == 1 + entries = request_data["metadata"]["standard_logging_guardrail_information"] + assert entries[0]["guardrail_response"] == "allow" + + +async def test_sanitized_replaces_text(): + handler = FakeHandler( + [_resp({"decision": "SANITIZED", "sanitizedText": "[REDACTED]"})] + ) + g = _make_guardrail(handler) + out = await g.apply_guardrail( + inputs={"texts": ["my ssn is 123"]}, request_data={}, input_type="request" + ) + assert out["texts"] == ["[REDACTED]"] + + +@pytest.mark.parametrize( + "body,expected", + [ + ( + { + "decision": "SANITIZED", + "sanitizedText": "S", + "outputText": "O", + }, + "S", + ), + ({"decision": "SANITIZED", "outputText": "O"}, "O"), + ({"decision": "SANITIZED", "sanitizedText": 123, "outputText": "O"}, "O"), + ({"decision": "SANITIZED", "sanitizedText": ""}, ""), + ({"decision": "SANITIZED"}, "orig"), + ], +) +async def test_sanitized_precedence(body, expected): + handler = FakeHandler([_resp(body)]) + g = _make_guardrail(handler) + out = await g.apply_guardrail( + inputs={"texts": ["orig"]}, request_data={}, input_type="request" + ) + assert out["texts"] == [expected] + + +async def test_blocked_raises_guardrail_exception_with_400(): + handler = FakeHandler([_resp({"decision": "BLOCKED", "blockMessage": "nope"})]) + g = _make_guardrail(handler) + with pytest.raises(GuardrailRaisedException) as exc_info: + await g.apply_guardrail( + inputs={"texts": ["bad"]}, request_data={}, input_type="request" + ) + assert exc_info.value.status_code == 400 + assert exc_info.value.guardrail_name == "vigil-guard" + assert exc_info.value.message == "nope" + + +@pytest.mark.parametrize( + "body,expected", + [ + ( + { + "decision": "BLOCKED", + "blockMessage": "bm", + "decisionReason": "dr", + "categories": ["c1"], + }, + "bm", + ), + ({"decision": "BLOCKED", "blockMessage": " ", "decisionReason": "dr"}, "dr"), + ( + {"decision": "BLOCKED", "decisionReason": "dr", "categories": ["c1", "c2"]}, + "dr", + ), + ({"decision": "BLOCKED", "categories": ["c1", "c2"]}, "c1, c2"), + ({"decision": "BLOCKED"}, "Blocked by policy"), + ], +) +async def test_block_reason_precedence(body, expected): + handler = FakeHandler([_resp(body)]) + g = _make_guardrail(handler) + with pytest.raises(GuardrailRaisedException) as exc_info: + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data={}, input_type="request" + ) + assert exc_info.value.message == expected + + +async def test_block_reason_is_clamped_to_500_chars(): + handler = FakeHandler([_resp({"decision": "BLOCKED", "blockMessage": "x" * 600})]) + g = _make_guardrail(handler) + with pytest.raises(GuardrailRaisedException) as exc_info: + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data={}, input_type="request" + ) + assert "x" * 500 in exc_info.value.message + assert "x" * 501 not in exc_info.value.message + + +async def test_empty_and_whitespace_texts_skip_analyze(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + out = await g.apply_guardrail( + inputs={"texts": ["", " ", "real"]}, request_data={}, input_type="request" + ) + assert out["texts"] == ["", " ", "real"] + assert len(handler.calls) == 1 + assert handler.calls[0].json["text"] == "real" + + +async def test_no_scannable_text_returns_inputs_unchanged(): + handler = FakeHandler([]) + g = _make_guardrail(handler) + inputs = {"texts": ["", " "], "structured_messages": [{"role": "user"}]} + out = await g.apply_guardrail(inputs=inputs, request_data={}, input_type="request") + assert out is inputs + assert len(handler.calls) == 0 + + +async def test_multi_text_preserves_length_and_order(): + handler = FakeHandler( + [ + _resp({"decision": "ALLOWED"}), + _resp({"decision": "SANITIZED", "sanitizedText": "B-clean"}), + _resp({"decision": "ALLOWED"}), + ] + ) + g = _make_guardrail(handler) + out = await g.apply_guardrail( + inputs={"texts": ["A", "B", "C"]}, request_data={}, input_type="request" + ) + assert out["texts"] == ["A", "B-clean", "C"] + assert len(handler.calls) == 3 + + +async def test_one_blocked_text_blocks_the_whole_call(): + handler = FakeHandler( + [ + _resp({"decision": "ALLOWED"}), + _resp({"decision": "BLOCKED", "blockMessage": "bad second"}), + ] + ) + g = _make_guardrail(handler) + with pytest.raises(GuardrailRaisedException): + await g.apply_guardrail( + inputs={"texts": ["ok", "bad"]}, request_data={}, input_type="request" + ) + + +async def test_request_source_is_user_input(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data={}, input_type="request" + ) + assert handler.calls[0].json["source"] == "user_input" + + +async def test_response_source_is_model_output(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data={}, input_type="response" + ) + assert handler.calls[0].json["source"] == "model_output" + + +async def test_sanitized_returns_canonical_shape_and_logs_mask(): + handler = FakeHandler( + [_resp({"decision": "SANITIZED", "sanitizedText": "[REDACTED]"})] + ) + g = _make_guardrail(handler) + tools = [{"type": "function", "function": {"name": "f"}}] + inputs = { + "texts": ["my ssn is 123"], + "images": ["img1"], + "tools": tools, + "tool_calls": [{"id": "1"}], + "structured_messages": [{"role": "user", "content": "my ssn is 123"}], + "model": "gpt-4o", + } + request_data = {"metadata": {}} + out = await g.apply_guardrail( + inputs=inputs, request_data=request_data, input_type="request" + ) + assert out["texts"] == ["[REDACTED]"] + assert out["images"] == ["img1"] + assert out["tools"] == tools + assert set(out.keys()) == {"texts", "images", "tools"} + entries = request_data["metadata"]["standard_logging_guardrail_information"] + assert entries[0]["guardrail_response"] == "mask" + + +async def test_empty_images_and_tools_are_preserved_when_present(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + out = await g.apply_guardrail( + inputs={"texts": ["x"], "images": [], "tools": []}, + request_data={}, + input_type="request", + ) + assert set(out.keys()) == {"texts", "images", "tools"} + assert out["images"] == [] + assert out["tools"] == [] + + +async def test_logging_obj_none_supported(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + out = await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data={}, input_type="request", logging_obj=None + ) + assert out["texts"] == ["x"] + + +async def test_standard_guardrail_logging_remains_active(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + request_data = {"metadata": {}} + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data=request_data, input_type="request" + ) + entries = request_data["metadata"]["standard_logging_guardrail_information"] + assert len(entries) == 1 + assert entries[0]["guardrail_name"] == "vigil-guard" + assert entries[0]["guardrail_status"] == "success" + + +async def test_request_url_headers_and_body(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler, api_base="https://vigil.test", api_key="vg_secret") + await g.apply_guardrail( + inputs={"texts": ["hello"]}, request_data={}, input_type="request" + ) + call = handler.calls[0] + assert call.url == "https://vigil.test/v1/guard/analyze" + assert call.headers["Authorization"] == "Bearer vg_secret" + assert call.headers["Content-Type"] == "application/json" + assert call.json["text"] == "hello" + assert call.json["mode"] == "full" + assert set(call.json.keys()) == {"text", "source", "mode", "metadata"} + assert "metadata" in call.json + + +async def test_default_timeout_forwarded_when_unset(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + assert g.timeout == _DEFAULT_VIGIL_TIMEOUT + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data={}, input_type="request" + ) + assert handler.calls[0].timeout == _DEFAULT_VIGIL_TIMEOUT + + +async def test_configured_timeout_forwarded_to_handler(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler, timeout=30) + expected = httpx.Timeout(30, connect=5.0) + assert g.timeout == expected + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data={}, input_type="request" + ) + assert handler.calls[0].timeout == expected + + +def test_short_timeout_caps_connect(): + g = _make_guardrail(FakeHandler([]), timeout=2) + assert g.timeout == httpx.Timeout(2, connect=2.0) + + +def test_initialize_guardrail_forwards_timeout(): + lp = LitellmParams( + guardrail="vigil_guard", + mode="pre_call", + api_base="https://vigil.test", + api_key="k", + timeout="30", + ) + cb = initialize_guardrail(lp, {"guardrail_name": "vg"}) + assert cb.timeout == httpx.Timeout(30, connect=5.0) + + +async def test_api_key_only_in_header_never_in_payload(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler, api_key="super_secret_key") + await g.apply_guardrail( + inputs={"texts": ["hello"]}, + request_data={"metadata": {"user_id": "u"}}, + input_type="request", + ) + call = handler.calls[0] + assert "super_secret_key" not in json.dumps(call.json) + assert call.headers["Authorization"] == "Bearer super_secret_key" + + +@pytest.mark.parametrize("code", [429, 502, 503, 504]) +async def test_retry_once_on_transient_status(code): + handler = FakeHandler([_resp({}, status_code=code), _resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + out = await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data={}, input_type="request" + ) + assert out["texts"] == ["x"] + assert len(handler.calls) == 2 + + +@pytest.mark.parametrize("exc", _transient_exceptions()) +async def test_retry_once_on_transient_exception(exc): + handler = FakeHandler([exc, _resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + out = await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data={}, input_type="request" + ) + assert out["texts"] == ["x"] + assert len(handler.calls) == 2 + + +@pytest.mark.parametrize( + "exc, expected", + [ + (RuntimeError("boom"), RuntimeError), + ( + httpx.WriteError("boom", request=httpx.Request("POST", _ENDPOINT)), + GuardrailRaisedException, + ), + ], +) +async def test_no_retry_on_non_transient_exception(exc, expected): + handler = FakeHandler([exc]) + g = _make_guardrail(handler) + with pytest.raises(expected): + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data={}, input_type="request" + ) + assert len(handler.calls) == 1 + + +@pytest.mark.parametrize("code", [400, 401, 403, 404, 422]) +async def test_no_retry_on_non_429_4xx(code): + handler = FakeHandler([_resp({}, status_code=code)]) + g = _make_guardrail(handler) + with pytest.raises(GuardrailRaisedException) as exc_info: + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data={}, input_type="request" + ) + assert exc_info.value.status_code == 400 + assert len(handler.calls) == 1 + + +async def test_fail_closed_raises_after_exhausted_retry(caplog): + handler = FakeHandler([_resp({}, status_code=503), _resp({}, status_code=503)]) + g = _make_guardrail(handler) + with ( + caplog.at_level(logging.ERROR), + pytest.raises(GuardrailRaisedException) as exc_info, + ): + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data={}, input_type="request" + ) + assert exc_info.value.status_code == 400 + assert len(handler.calls) == 2 + assert any("fail_closed" in record.message for record in caplog.records) + assert any("vigil-guard" in record.message for record in caplog.records) + + +@pytest.mark.parametrize("exc", _transient_exceptions()) +async def test_fail_closed_raises_controlled_block_on_transport_error(exc, caplog): + handler = FakeHandler([exc, exc]) + g = _make_guardrail(handler) + with ( + caplog.at_level(logging.ERROR), + pytest.raises(GuardrailRaisedException) as exc_info, + ): + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data={}, input_type="request" + ) + assert exc_info.value.status_code == 400 + assert exc_info.value.guardrail_name == "vigil-guard" + assert exc_info.value.__cause__ is exc + assert any("fail_closed" in record.message for record in caplog.records) + + +async def test_fail_open_returns_inputs_unchanged_on_backend_error(caplog): + handler = FakeHandler([_resp({}, status_code=503), _resp({}, status_code=503)]) + g = _make_guardrail(handler, unreachable_fallback="fail_open") + structured = [{"role": "user", "content": "x"}] + inputs = {"texts": ["x"], "structured_messages": structured} + request_data = {"metadata": {}} + with caplog.at_level(logging.ERROR): + out = await g.apply_guardrail( + inputs=inputs, request_data=request_data, input_type="request" + ) + assert out is not inputs + assert out["texts"] == ["x"] + assert out["structured_messages"] == structured + assert len(handler.calls) == 2 + assert any("fail_open" in record.message for record in caplog.records) + assert any("vigil-guard" in record.message for record in caplog.records) + entries = request_data["metadata"]["standard_logging_guardrail_information"] + assert entries[0]["guardrail_response"] == "allow" + + +@pytest.mark.parametrize("exc", [ssl.SSLError("tls failed"), OSError("network down")]) +async def test_fail_open_returns_inputs_unchanged_on_transport_error(exc): + handler = FakeHandler([exc]) + g = _make_guardrail(handler, unreachable_fallback="fail_open") + inputs = {"texts": ["x"]} + out = await g.apply_guardrail(inputs=inputs, request_data={}, input_type="request") + assert out is not inputs + assert out["texts"] == ["x"] + assert len(handler.calls) == 1 + + +@pytest.mark.parametrize( + "exc", + [ + TypeError("bug"), + KeyError("bug"), + AttributeError("bug"), + ], +) +async def test_fail_open_does_not_swallow_programming_errors(exc): + handler = FakeHandler([exc]) + g = _make_guardrail(handler, unreachable_fallback="fail_open") + with pytest.raises(type(exc)): + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data={}, input_type="request" + ) + assert len(handler.calls) == 1 + + +async def test_invalid_decision_fail_closed_raises(caplog): + handler = FakeHandler([_resp({"decision": "MAYBE"})]) + g = _make_guardrail(handler) + with ( + caplog.at_level(logging.ERROR), + pytest.raises(GuardrailRaisedException) as exc_info, + ): + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data={}, input_type="request" + ) + assert exc_info.value.status_code == 400 + assert "MAYBE" not in exc_info.value.message + assert any("MAYBE" in record.message for record in caplog.records) + + +async def test_invalid_decision_fail_open_returns_inputs(): + handler = FakeHandler([_resp({"decision": "MAYBE"})]) + g = _make_guardrail(handler, unreachable_fallback="fail_open") + out = await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data={}, input_type="request" + ) + assert out["texts"] == ["x"] + + +async def test_fail_open_multi_text_preserves_earlier_sanitization(): + handler = FakeHandler( + [ + _resp({"decision": "SANITIZED", "sanitizedText": "[REDACTED]"}), + _resp({}, status_code=503), + _resp({}, status_code=503), + ] + ) + g = _make_guardrail(handler, unreachable_fallback="fail_open") + request_data = {"metadata": {}} + out = await g.apply_guardrail( + inputs={"texts": ["my ssn is 123", "second"]}, + request_data=request_data, + input_type="request", + ) + assert out["texts"] == ["[REDACTED]", "second"] + assert len(handler.calls) == 3 + entries = request_data["metadata"]["standard_logging_guardrail_information"] + assert entries[0]["guardrail_response"] == "mask" + + +def _tool_call(arguments, name="f", tc_id="1"): + return { + "id": tc_id, + "type": "function", + "function": {"name": name, "arguments": arguments}, + } + + +async def test_response_tool_call_arguments_allowed_unchanged(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + tcs = [_tool_call('{"q": "weather"}')] + out = await g.apply_guardrail( + inputs={"texts": [], "tool_calls": tcs}, request_data={}, input_type="response" + ) + assert handler.calls[0].json["text"] == '{"q": "weather"}' + assert handler.calls[0].json["source"] == "model_output" + assert out["tool_calls"] == tcs + + +async def test_response_tool_call_arguments_sanitized_in_place(): + handler = FakeHandler( + [_resp({"decision": "SANITIZED", "sanitizedText": '{"email": "[EMAIL]"}'})] + ) + g = _make_guardrail(handler) + tcs = [_tool_call('{"email": "john@example.com"}', name="send_mail")] + inputs = {"texts": [], "tool_calls": tcs} + out = await g.apply_guardrail(inputs=inputs, request_data={}, input_type="response") + assert out["tool_calls"][0]["function"]["arguments"] == '{"email": "[EMAIL]"}' + assert out["tool_calls"][0]["function"]["name"] == "send_mail" + # original inputs are not mutated in place + assert inputs["tool_calls"][0]["function"]["arguments"] == ( + '{"email": "john@example.com"}' + ) + + +async def test_response_tool_call_arguments_blocked_raises(): + handler = FakeHandler( + [_resp({"decision": "BLOCKED", "blockMessage": "tool blocked"})] + ) + g = _make_guardrail(handler) + tcs = [_tool_call('{"x": "bad"}')] + with pytest.raises(GuardrailRaisedException) as exc_info: + await g.apply_guardrail( + inputs={"texts": [], "tool_calls": tcs}, + request_data={}, + input_type="response", + ) + assert exc_info.value.status_code == 400 + assert exc_info.value.message == "tool blocked" + + +async def test_request_tool_calls_are_not_scanned(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + tcs = [_tool_call('{"x": "y"}')] + await g.apply_guardrail( + inputs={"texts": ["hello"], "tool_calls": tcs}, + request_data={}, + input_type="request", + ) + assert len(handler.calls) == 1 + assert handler.calls[0].json["text"] == "hello" + + +async def test_tool_call_scan_backend_failure_fail_closed_raises(): + handler = FakeHandler([_resp({}, status_code=503), _resp({}, status_code=503)]) + g = _make_guardrail(handler) + tcs = [_tool_call('{"x": "y"}')] + with pytest.raises(GuardrailRaisedException) as exc_info: + await g.apply_guardrail( + inputs={"texts": [], "tool_calls": tcs}, + request_data={}, + input_type="response", + ) + assert exc_info.value.status_code == 400 + assert len(handler.calls) == 2 + + +async def test_tool_call_scan_backend_failure_fail_open_passes_through(): + handler = FakeHandler([_resp({}, status_code=503), _resp({}, status_code=503)]) + g = _make_guardrail(handler, unreachable_fallback="fail_open") + tcs = [_tool_call('{"x": "y"}')] + out = await g.apply_guardrail( + inputs={"texts": [], "tool_calls": tcs}, request_data={}, input_type="response" + ) + assert out["tool_calls"] == tcs + + +async def test_response_tool_call_unrecognized_decision_fail_closed_raises(): + handler = FakeHandler([_resp({"decision": "MAYBE"})]) + g = _make_guardrail(handler) + tcs = [_tool_call('{"x": "y"}')] + with pytest.raises(GuardrailRaisedException) as exc_info: + await g.apply_guardrail( + inputs={"texts": [], "tool_calls": tcs}, + request_data={}, + input_type="response", + ) + assert exc_info.value.status_code == 400 + + +async def test_response_tool_call_unrecognized_decision_fail_open_passes_through(): + handler = FakeHandler([_resp({"decision": "MAYBE"})]) + g = _make_guardrail(handler, unreachable_fallback="fail_open") + tcs = [_tool_call('{"x": "y"}')] + out = await g.apply_guardrail( + inputs={"texts": [], "tool_calls": tcs}, request_data={}, input_type="response" + ) + assert out["tool_calls"] == tcs + + +async def test_metadata_allowlist_and_clamping(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + request_data = { + "model": "gpt-4o", + "metadata": { + "user_id": "u1", + "tenant_id": "t1", + "secret_unlisted": "should_not_forward", + "session_id": "s" * 600, + "org_id": ["a"] * 20, + "request_id": True, + "conversation_id": 7, + }, + } + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data=request_data, input_type="request" + ) + md = handler.calls[0].json["metadata"] + assert md["model"] == "gpt-4o" + assert md["user_id"] == "u1" + assert md["tenant_id"] == "t1" + assert "secret_unlisted" not in md + assert len(md["session_id"]) == 500 + assert len(md["org_id"]) == 10 + assert "request_id" not in md + assert md["conversation_id"] == 7 + + +async def test_metadata_source_precedence_and_litellm_metadata_fallback(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + request_data = { + "user_id": "top", + "metadata": {"user_id": "nested"}, + "litellm_metadata": {"tenant_id": "lm-tenant"}, + } + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data=request_data, input_type="request" + ) + md = handler.calls[0].json["metadata"] + assert md["user_id"] == "top" + assert md["tenant_id"] == "lm-tenant" + + +async def test_metadata_uses_later_source_when_earlier_value_is_unclampable(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + request_data = { + "user_id": {"drop": "dicts are not forwarded"}, + "metadata": {"user_id": "nested"}, + } + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data=request_data, input_type="request" + ) + assert handler.calls[0].json["metadata"]["user_id"] == "nested" + + +async def test_metadata_array_items_are_clamped_and_filtered(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + request_data = { + "metadata": { + "org_id": ["z" * 600, 123, True, {"drop": 1}, None], + }, + } + await g.apply_guardrail( + inputs={"texts": ["x"]}, request_data=request_data, input_type="request" + ) + assert handler.calls[0].json["metadata"]["org_id"] == ["z" * 500, 123] + + +async def test_metadata_array_with_no_supported_items_is_dropped(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + await g.apply_guardrail( + inputs={"texts": ["x"]}, + request_data={"metadata": {"org_id": [{"drop": 1}, None]}}, + input_type="request", + ) + assert "org_id" not in handler.calls[0].json["metadata"] + + +async def test_call_id_forwarded_from_logging_obj(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + logging_obj = SimpleNamespace(litellm_call_id="call-123") + await g.apply_guardrail( + inputs={"texts": ["x"]}, + request_data={}, + input_type="request", + logging_obj=logging_obj, + ) + assert handler.calls[0].json["metadata"]["litellm_call_id"] == "call-123" + + +async def test_call_id_forwarded_from_request_data(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + await g.apply_guardrail( + inputs={"texts": ["x"]}, + request_data={"litellm_call_id": "rd-1"}, + input_type="request", + logging_obj=None, + ) + assert handler.calls[0].json["metadata"]["litellm_call_id"] == "rd-1" + + +async def test_call_id_forwarded_from_request_metadata(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + await g.apply_guardrail( + inputs={"texts": ["x"]}, + request_data={"metadata": {"litellm_call_id": "md-1"}}, + input_type="request", + logging_obj=None, + ) + assert handler.calls[0].json["metadata"]["litellm_call_id"] == "md-1" + + +async def test_call_id_logging_obj_takes_precedence(): + handler = FakeHandler([_resp({"decision": "ALLOWED"})]) + g = _make_guardrail(handler) + logging_obj = SimpleNamespace(litellm_call_id="log-1") + await g.apply_guardrail( + inputs={"texts": ["x"]}, + request_data={"litellm_call_id": "rd-1"}, + input_type="request", + logging_obj=logging_obj, + ) + assert handler.calls[0].json["metadata"]["litellm_call_id"] == "log-1" + + +def test_enum_value(): + assert SupportedGuardrailIntegrations.VIGIL_GUARD.value == "vigil_guard" + + +def test_config_model_ui_name_and_instantiation(): + assert VigilGuardGuardrailConfigModel.ui_friendly_name() == "Vigil Guard" + model = VigilGuardGuardrailConfigModel(api_base="https://x", api_key="k") + assert model.api_base == "https://x" + + +def test_get_config_model_returns_config_model(): + g = _make_guardrail(FakeHandler([])) + assert g.get_config_model() is VigilGuardGuardrailConfigModel + + +def test_registries_expose_initializer_and_class(): + assert "vigil_guard" in guardrail_initializer_registry + assert guardrail_class_registry["vigil_guard"] is VigilGuardGuardrail + + +def test_litellm_params_includes_config_model(): + assert VigilGuardGuardrailConfigModel in LitellmParams.__mro__ + + +def test_config_driven_initialization_creates_callback(): + lp = LitellmParams( + guardrail="vigil_guard", + mode="pre_call", + api_base="https://vigil.test", + api_key="k", + ) + cb = initialize_guardrail(lp, {"guardrail_name": "vg"}) + assert isinstance(cb, VigilGuardGuardrail) + assert cb.unreachable_fallback == "fail_closed" diff --git a/tests/test_litellm/proxy/guardrails/test_content_filter_path_traversal.py b/tests/test_litellm/proxy/guardrails/test_content_filter_path_traversal.py new file mode 100644 index 0000000000..2d19fe7fe7 --- /dev/null +++ b/tests/test_litellm/proxy/guardrails/test_content_filter_path_traversal.py @@ -0,0 +1,213 @@ +import os +from unittest.mock import patch +import pytest + + +class TestContentFilterPathTraversal: + """Tests that _resolve_category_file_path rejects path traversal.""" + + def _get_guardrail(self): + from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import ( + ContentFilterGuardrail, + ) + + return ContentFilterGuardrail.__new__(ContentFilterGuardrail) + + def test_traversal_via_relative_dotdot_raises(self): + guardrail = self._get_guardrail() + with pytest.raises(ValueError, match="outside the allowed categories"): + guardrail._resolve_category_file_path("../../../../etc/passwd") + + def test_traversal_via_absolute_path_raises(self): + guardrail = self._get_guardrail() + with pytest.raises(ValueError, match="outside the allowed categories"): + guardrail._resolve_category_file_path("/etc/passwd") + + def test_valid_category_file_inside_categories_dir_allowed(self): + guardrail = self._get_guardrail() + categories_dir = os.path.join( + os.path.dirname( + __import__( + "litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter", + fromlist=["content_filter"], + ).__file__ + ), + "categories", + ) + valid_file = os.path.join(categories_dir, "harmful_self_harm.yaml") + if not os.path.exists(valid_file): + pytest.skip("harmful_self_harm.yaml not present in this environment") + result = guardrail._resolve_category_file_path(valid_file) + assert result == valid_file + + def test_invalid_category_name_skipped(self): + from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import ( + ContentFilterGuardrail, + ) + + guardrail = ContentFilterGuardrail.__new__(ContentFilterGuardrail) + guardrail.loaded_categories = {} + guardrail.severity_threshold = "medium" + guardrail.category_keywords = {} + guardrail.always_block_category_keywords = {} + guardrail.conditional_categories = {} + # category name with path traversal chars must be skipped, not crash + guardrail._load_categories([{"category": "../../etc/passwd", "enabled": True}]) + assert "../../etc/passwd" not in guardrail.loaded_categories + + def test_category_name_with_slash_skipped(self): + from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import ( + ContentFilterGuardrail, + ) + + guardrail = ContentFilterGuardrail.__new__(ContentFilterGuardrail) + guardrail.loaded_categories = {} + guardrail.severity_threshold = "medium" + guardrail.category_keywords = {} + guardrail.always_block_category_keywords = {} + guardrail.conditional_categories = {} + guardrail._load_categories( + [{"category": "foo/../../etc/passwd", "enabled": True}] + ) + assert "foo/../../etc/passwd" not in guardrail.loaded_categories + + def test_assert_within_categories_dir_blocks_parent_traversal(self): + from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import ( + ContentFilterGuardrail, + ) + + categories_dir = os.path.join( + os.path.dirname( + __import__( + "litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter", + fromlist=["content_filter"], + ).__file__ + ), + "categories", + ) + with pytest.raises(ValueError, match="outside the allowed categories"): + ContentFilterGuardrail._assert_within_categories_dir( + "/etc/passwd", categories_dir + ) + + def test_assert_within_categories_dir_allows_valid_file(self, tmp_path): + from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import ( + ContentFilterGuardrail, + ) + + categories_dir = str(tmp_path) + valid_file = str(tmp_path / "test.yaml") + # Should not raise + ContentFilterGuardrail._assert_within_categories_dir(valid_file, categories_dir) + + def test_assert_within_categories_dir_commonpath_raises_valueerror(self, tmp_path): + """Cover the except-ValueError branch (Windows cross-drive paths).""" + from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import ( + ContentFilterGuardrail, + ) + + categories_dir = str(tmp_path) + valid_file = str(tmp_path / "test.yaml") + with patch( + "os.path.commonpath", side_effect=ValueError("Paths on different drives") + ): + with pytest.raises( + ValueError, match="outside the allowed categories directory" + ): + ContentFilterGuardrail._assert_within_categories_dir( + valid_file, categories_dir + ) + + def test_resolve_category_file_path_direct_join_hit(self): + """Cover the first-join-attempt success branch (lines 383-384).""" + guardrail = self._get_guardrail() + # "categories/" joined directly to module_dir resolves to an existing file. + categories_dir = os.path.join( + os.path.dirname( + __import__( + "litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter", + fromlist=["content_filter"], + ).__file__ + ), + "categories", + ) + yaml_files = [f for f in os.listdir(categories_dir) if f.endswith(".yaml")] + if not yaml_files: + pytest.skip("No category YAML files present in this environment") + relative_path = os.path.join("categories", yaml_files[0]) + result = guardrail._resolve_category_file_path(relative_path) + assert os.path.isabs(result) or os.path.exists(result) + + def test_resolve_category_file_path_component_strip_hit(self): + """Cover the component-stripping loop success branch (lines 392-393).""" + guardrail = self._get_guardrail() + categories_dir = os.path.join( + os.path.dirname( + __import__( + "litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter", + fromlist=["content_filter"], + ).__file__ + ), + "categories", + ) + yaml_files = [f for f in os.listdir(categories_dir) if f.endswith(".yaml")] + if not yaml_files: + pytest.skip("No category YAML files present in this environment") + # Prefix with a fake leading component so the first-join attempt misses, + # but stripping that component reveals categories/ which exists. + prefixed_path = "some_prefix/categories/" + yaml_files[0] + result = guardrail._resolve_category_file_path(prefixed_path) + assert os.path.isabs(result) or os.path.exists(result) + + def test_load_categories_traversal_category_file_skipped(self): + """Cover the except-ValueError branch in _load_categories (lines 451-454).""" + from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import ( + ContentFilterGuardrail, + ) + + guardrail = ContentFilterGuardrail.__new__(ContentFilterGuardrail) + guardrail.loaded_categories = {} + guardrail.severity_threshold = "medium" + guardrail.category_keywords = {} + guardrail.always_block_category_keywords = {} + guardrail.conditional_categories = {} + # A traversal path in category_file must be skipped (not crash) via ValueError. + guardrail._load_categories( + [ + { + "category": "valid_name", + "enabled": True, + "category_file": "../../../../etc/passwd", + } + ] + ) + assert "valid_name" not in guardrail.loaded_categories + + def test_allow_external_paths_env_var_bypasses_jail(self, tmp_path): + """LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS=true skips the directory jail.""" + import os as _os + from litellm.proxy.guardrails.guardrail_hooks.litellm_content_filter.content_filter import ( + ContentFilterGuardrail, + ) + + guardrail = ContentFilterGuardrail.__new__(ContentFilterGuardrail) + # Create a real file outside the module directory (simulates mounted volume). + external_file = tmp_path / "external_categories.yaml" + external_file.write_text("category_name: test\n") + + with patch.dict( + _os.environ, {"LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS": "true"} + ): + # Should return the path without raising ValueError. + result = guardrail._resolve_category_file_path(str(external_file)) + assert result == str(external_file) + + def test_traversal_blocked_when_allow_external_not_set(self): + """Without the env var the jail still blocks traversal paths.""" + import os as _os + + guardrail = self._get_guardrail() + with patch.dict(_os.environ, {}, clear=False): + _os.environ.pop("LITELLM_CONTENT_FILTER_ALLOW_EXTERNAL_PATHS", None) + with pytest.raises(ValueError, match="outside the allowed categories"): + guardrail._resolve_category_file_path("/etc/passwd") diff --git a/tests/test_litellm/proxy/pass_through_endpoints/test_watsonx_proxy_route.py b/tests/test_litellm/proxy/pass_through_endpoints/test_watsonx_proxy_route.py new file mode 100644 index 0000000000..19a2f7a050 --- /dev/null +++ b/tests/test_litellm/proxy/pass_through_endpoints/test_watsonx_proxy_route.py @@ -0,0 +1,444 @@ +""" +Unit tests for watsonx_proxy_route endpoint. + +Tests the Watsonx pass-through endpoint that handles automatic IAM token management +and version parameter injection. +""" + +import json +import os +import sys +from unittest.mock import AsyncMock, MagicMock, Mock, patch + +import pytest +from fastapi import HTTPException, Request, Response + +sys.path.insert( + 0, os.path.abspath("../../../..") +) # Adds the parent directory to the system path + +import litellm +from litellm.proxy._types import UserAPIKeyAuth +from litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints import ( + watsonx_proxy_route, +) + + +class TestWatsonxProxyRoute: + """Tests for the Watsonx pass-through route.""" + + @pytest.mark.asyncio + async def test_watsonx_proxy_route_success_non_streaming(self): + """Test successful non-streaming request through Watsonx proxy route.""" + # Setup mocks + mock_request = MagicMock(spec=Request) + mock_request.method = "POST" + mock_request.query_params = {} + mock_request.headers = {"content-type": "application/json"} + mock_request.json = AsyncMock(return_value={"stream": False, "input": "test"}) + mock_response = MagicMock(spec=Response) + mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth) + + # Mock provider config + mock_provider_config = MagicMock() + mock_provider_config.get_complete_url.return_value = ( + "https://us-south.ml.cloud.ibm.com/ml/v1/text/generation", + {}, + ) + mock_provider_config.validate_environment.return_value = { + "Authorization": "Bearer test-iam-token" + } + + # Mock endpoint function + mock_endpoint_func = AsyncMock( + return_value={"model_id": "ibm/granite-13b-chat-v2", "results": []} + ) + + with ( + patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.ProviderConfigManager.get_provider_passthrough_config", + return_value=mock_provider_config, + ), + patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route", + return_value=mock_endpoint_func, + ) as mock_create_route, + ): + result = await watsonx_proxy_route( + endpoint="ml/v1/text/generation", + request=mock_request, + fastapi_response=mock_response, + user_api_key_dict=mock_user_api_key_dict, + ) + + # Verify provider config was called correctly + mock_provider_config.get_complete_url.assert_called_once() + mock_provider_config.validate_environment.assert_called_once() + + # Verify create_pass_through_route was called with correct parameters + mock_create_route.assert_called_once() + call_args = mock_create_route.call_args[1] + assert call_args["endpoint"] == "ml/v1/text/generation" + assert ( + call_args["target"] + == "https://us-south.ml.cloud.ibm.com/ml/v1/text/generation" + ) + assert ( + call_args["custom_headers"]["Authorization"] == "Bearer test-iam-token" + ) + assert call_args["is_streaming_request"] is False + assert call_args["custom_llm_provider"] == "watsonx" + assert ( + call_args["query_params"]["version"] + == litellm.WATSONX_DEFAULT_API_VERSION + ) + + # Verify endpoint function was called + mock_endpoint_func.assert_called_once_with( + mock_request, mock_response, mock_user_api_key_dict + ) + + assert result == {"model_id": "ibm/granite-13b-chat-v2", "results": []} + + @pytest.mark.asyncio + async def test_watsonx_proxy_route_success_streaming(self): + """Test successful streaming request through Watsonx proxy route.""" + # Setup mocks + mock_request = MagicMock(spec=Request) + mock_request.method = "POST" + mock_request.query_params = {} + mock_request.headers = {"content-type": "application/json"} + mock_request.json = AsyncMock(return_value={"stream": True, "input": "test"}) + mock_response = MagicMock(spec=Response) + mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth) + + # Mock provider config + mock_provider_config = MagicMock() + mock_provider_config.get_complete_url.return_value = ( + "https://us-south.ml.cloud.ibm.com/ml/v1/text/generation_stream", + {}, + ) + mock_provider_config.validate_environment.return_value = { + "Authorization": "Bearer test-iam-token" + } + + # Mock endpoint function + mock_endpoint_func = AsyncMock(return_value="streaming_response") + + with ( + patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.ProviderConfigManager.get_provider_passthrough_config", + return_value=mock_provider_config, + ), + patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route", + return_value=mock_endpoint_func, + ) as mock_create_route, + ): + result = await watsonx_proxy_route( + endpoint="ml/v1/text/generation_stream", + request=mock_request, + fastapi_response=mock_response, + user_api_key_dict=mock_user_api_key_dict, + ) + + # Verify create_pass_through_route was called with streaming enabled + mock_create_route.assert_called_once() + call_args = mock_create_route.call_args[1] + assert call_args["is_streaming_request"] is True + + assert result == "streaming_response" + + @pytest.mark.asyncio + async def test_watsonx_proxy_route_get_request(self): + """Test GET request through Watsonx proxy route.""" + # Setup mocks + mock_request = MagicMock(spec=Request) + mock_request.method = "GET" + mock_request.query_params = {"project_id": "test-project"} + mock_request.headers = {} + mock_response = MagicMock(spec=Response) + mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth) + + # Mock provider config + mock_provider_config = MagicMock() + mock_provider_config.get_complete_url.return_value = ( + "https://us-south.ml.cloud.ibm.com/ml/v1/models", + {}, + ) + mock_provider_config.validate_environment.return_value = { + "Authorization": "Bearer test-iam-token" + } + + # Mock endpoint function + mock_endpoint_func = AsyncMock(return_value={"resources": []}) + + with ( + patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.ProviderConfigManager.get_provider_passthrough_config", + return_value=mock_provider_config, + ), + patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route", + return_value=mock_endpoint_func, + ) as mock_create_route, + ): + result = await watsonx_proxy_route( + endpoint="ml/v1/models", + request=mock_request, + fastapi_response=mock_response, + user_api_key_dict=mock_user_api_key_dict, + ) + + # Verify is_streaming_request is False for GET requests + mock_create_route.assert_called_once() + call_args = mock_create_route.call_args[1] + assert call_args["is_streaming_request"] is False + + assert result == {"resources": []} + + @pytest.mark.asyncio + async def test_watsonx_proxy_route_multipart_form_data(self): + """Test multipart/form-data request through Watsonx proxy route.""" + # Setup mocks + mock_request = MagicMock(spec=Request) + mock_request.method = "POST" + mock_request.query_params = {} + mock_request.headers = {"content-type": "multipart/form-data; boundary=----"} + mock_response = MagicMock(spec=Response) + mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth) + + # Mock form data + mock_form_data = {"file": "test_file", "stream": False} + + # Mock provider config + mock_provider_config = MagicMock() + mock_provider_config.get_complete_url.return_value = ( + "https://us-south.ml.cloud.ibm.com/ml/v1/text/tokenization", + {}, + ) + mock_provider_config.validate_environment.return_value = { + "Authorization": "Bearer test-iam-token" + } + + # Mock endpoint function + mock_endpoint_func = AsyncMock(return_value={"token_count": 10}) + + with ( + patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.ProviderConfigManager.get_provider_passthrough_config", + return_value=mock_provider_config, + ), + patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.get_form_data", + return_value=mock_form_data, + ), + patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route", + return_value=mock_endpoint_func, + ) as mock_create_route, + ): + result = await watsonx_proxy_route( + endpoint="ml/v1/text/tokenization", + request=mock_request, + fastapi_response=mock_response, + user_api_key_dict=mock_user_api_key_dict, + ) + + # Verify is_streaming_request is False for non-streaming form data + mock_create_route.assert_called_once() + call_args = mock_create_route.call_args[1] + assert call_args["is_streaming_request"] is False + + assert result == {"token_count": 10} + + @pytest.mark.asyncio + async def test_watsonx_proxy_route_no_provider_config(self): + """Test that HTTPException is raised when provider config is not found.""" + # Setup mocks + mock_request = MagicMock(spec=Request) + mock_request.method = "POST" + mock_request.query_params = {} + mock_request.headers = {"content-type": "application/json"} + mock_response = MagicMock(spec=Response) + mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth) + + with ( + patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.ProviderConfigManager.get_provider_passthrough_config", + return_value=None, + ), + ): + with pytest.raises(HTTPException) as exc_info: + await watsonx_proxy_route( + endpoint="ml/v1/text/generation", + request=mock_request, + fastapi_response=mock_response, + user_api_key_dict=mock_user_api_key_dict, + ) + + assert exc_info.value.status_code == 404 + assert exc_info.value.detail == "Watsonx passthrough config not found" + + @pytest.mark.asyncio + async def test_watsonx_proxy_route_version_parameter_injection(self): + """Test that version parameter is correctly injected into query params.""" + # Setup mocks + mock_request = MagicMock(spec=Request) + mock_request.method = "POST" + mock_request.query_params = {} + mock_request.headers = {"content-type": "application/json"} + mock_request.json = AsyncMock(return_value={"input": "test"}) + mock_response = MagicMock(spec=Response) + mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth) + + # Mock provider config + mock_provider_config = MagicMock() + mock_provider_config.get_complete_url.return_value = ( + "https://us-south.ml.cloud.ibm.com/ml/v1/text/generation", + {}, + ) + mock_provider_config.validate_environment.return_value = { + "Authorization": "Bearer test-iam-token" + } + + # Mock endpoint function + mock_endpoint_func = AsyncMock(return_value={}) + + with ( + patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.ProviderConfigManager.get_provider_passthrough_config", + return_value=mock_provider_config, + ), + patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route", + return_value=mock_endpoint_func, + ) as mock_create_route, + ): + await watsonx_proxy_route( + endpoint="ml/v1/text/generation", + request=mock_request, + fastapi_response=mock_response, + user_api_key_dict=mock_user_api_key_dict, + ) + + # Verify version parameter is injected + mock_create_route.assert_called_once() + call_args = mock_create_route.call_args[1] + assert "query_params" in call_args + assert "version" in call_args["query_params"] + assert ( + call_args["query_params"]["version"] + == litellm.WATSONX_DEFAULT_API_VERSION + ) + + @pytest.mark.asyncio + async def test_watsonx_proxy_route_custom_headers_from_validate_environment(self): + """Test that custom headers from validate_environment are passed through.""" + # Setup mocks + mock_request = MagicMock(spec=Request) + mock_request.method = "POST" + mock_request.query_params = {} + mock_request.headers = {"content-type": "application/json"} + mock_request.json = AsyncMock(return_value={"input": "test"}) + mock_response = MagicMock(spec=Response) + mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth) + + # Mock provider config with custom headers + mock_provider_config = MagicMock() + mock_provider_config.get_complete_url.return_value = ( + "https://us-south.ml.cloud.ibm.com/ml/v1/text/generation", + {}, + ) + mock_provider_config.validate_environment.return_value = { + "Authorization": "Bearer test-iam-token", + "X-Custom-Header": "custom-value", + } + + # Mock endpoint function + mock_endpoint_func = AsyncMock(return_value={}) + + with ( + patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.ProviderConfigManager.get_provider_passthrough_config", + return_value=mock_provider_config, + ), + patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route", + return_value=mock_endpoint_func, + ) as mock_create_route, + ): + await watsonx_proxy_route( + endpoint="ml/v1/text/generation", + request=mock_request, + fastapi_response=mock_response, + user_api_key_dict=mock_user_api_key_dict, + ) + + # Verify custom headers are passed through + mock_create_route.assert_called_once() + call_args = mock_create_route.call_args[1] + assert "custom_headers" in call_args + assert ( + call_args["custom_headers"]["Authorization"] == "Bearer test-iam-token" + ) + assert call_args["custom_headers"]["X-Custom-Header"] == "custom-value" + + @pytest.mark.asyncio + async def test_watsonx_proxy_route_different_endpoints(self): + """Test various Watsonx endpoint paths.""" + endpoints = [ + "ml/v1/text/generation", + "ml/v1/text/tokenization", + "ml/v1/deployments/test-deployment/text/generation", + "ml/v1/models", + ] + + for endpoint_path in endpoints: + # Setup mocks + mock_request = MagicMock(spec=Request) + mock_request.method = "POST" + mock_request.query_params = {} + mock_request.headers = {"content-type": "application/json"} + mock_request.json = AsyncMock(return_value={"input": "test"}) + mock_response = MagicMock(spec=Response) + mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth) + + # Mock provider config + mock_provider_config = MagicMock() + mock_provider_config.get_complete_url.return_value = ( + f"https://us-south.ml.cloud.ibm.com/{endpoint_path}", + {}, + ) + mock_provider_config.validate_environment.return_value = { + "Authorization": "Bearer test-iam-token" + } + + # Mock endpoint function + mock_endpoint_func = AsyncMock(return_value={}) + + with ( + patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.ProviderConfigManager.get_provider_passthrough_config", + return_value=mock_provider_config, + ), + patch( + "litellm.proxy.pass_through_endpoints.llm_passthrough_endpoints.create_pass_through_route", + return_value=mock_endpoint_func, + ) as mock_create_route, + ): + await watsonx_proxy_route( + endpoint=endpoint_path, + request=mock_request, + fastapi_response=mock_response, + user_api_key_dict=mock_user_api_key_dict, + ) + + # Verify endpoint is passed correctly + mock_create_route.assert_called_once() + call_args = mock_create_route.call_args[1] + assert call_args["endpoint"] == endpoint_path + assert ( + call_args["target"] + == f"https://us-south.ml.cloud.ibm.com/{endpoint_path}" + ) diff --git a/tests/test_litellm/proxy/spend_tracking/test_spend_management_endpoints.py b/tests/test_litellm/proxy/spend_tracking/test_spend_management_endpoints.py index 4bcabfe853..3559008642 100644 --- a/tests/test_litellm/proxy/spend_tracking/test_spend_management_endpoints.py +++ b/tests/test_litellm/proxy/spend_tracking/test_spend_management_endpoints.py @@ -3185,3 +3185,358 @@ async def test_view_spend_logs_date_range_hashes_sk_api_key(client, monkeypatch) assert where["api_key"] == "hashed::sk-raw-admin-token" finally: app.dependency_overrides.pop(ps.user_api_key_auth, None) + + +class _SpendScopeMockPrismaClient: + + def __init__(self, get_data_returns=None, find_many_returns=None): + self._get_data_returns = ( + get_data_returns if get_data_returns is not None else [] + ) + self._find_many_returns = ( + find_many_returns if find_many_returns is not None else [] + ) + self.get_data_calls = [] + self.find_many_calls = [] + + client = self + + class _VerificationTokenTable: + async def find_many(self, where=None, order=None, include=None): + client.find_many_calls.append( + {"where": where, "order": order, "include": include} + ) + return client._find_many_returns + + class _DB: + def __init__(self): + self.litellm_verificationtoken = _VerificationTokenTable() + + self.db = _DB() + + async def get_data(self, table_name=None, query_type=None, **kwargs): + self.get_data_calls.append( + {"table_name": table_name, "query_type": query_type, **kwargs} + ) + if query_type == "find_unique": + return self._get_data_returns[0] if self._get_data_returns else None + return self._get_data_returns + + +@pytest.mark.asyncio +async def test_spend_key_fn_proxy_admin_returns_all_keys(client, monkeypatch): + """Admins keep their existing full-table view of /spend/keys.""" + mock_keys = [ + {"token": "hashed-a", "user_id": "alice", "spend": 10.0}, + {"token": "hashed-b", "user_id": "bob", "spend": 5.0}, + ] + mock_prisma = _SpendScopeMockPrismaClient(get_data_returns=mock_keys) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma) + + app.dependency_overrides[ps.user_api_key_auth] = lambda: UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, user_id="admin" + ) + try: + response = client.get( + "/spend/keys", headers={"Authorization": "Bearer sk-test"} + ) + assert response.status_code == 200 + # Admin path: goes through get_data (full table), never the scoped find_many + assert len(mock_prisma.get_data_calls) == 1 + assert mock_prisma.get_data_calls[0]["table_name"] == "key" + assert mock_prisma.get_data_calls[0]["query_type"] == "find_all" + assert mock_prisma.find_many_calls == [] + assert response.json() == mock_keys + finally: + app.dependency_overrides.pop(ps.user_api_key_auth, None) + + +@pytest.mark.asyncio +async def test_spend_key_fn_proxy_admin_view_only_returns_all_keys(client, monkeypatch): + """View-only admins are still admins for this endpoint.""" + mock_keys = [{"token": "hashed-a", "user_id": "alice"}] + mock_prisma = _SpendScopeMockPrismaClient(get_data_returns=mock_keys) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma) + + app.dependency_overrides[ps.user_api_key_auth] = lambda: UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY, user_id="admin_viewer" + ) + try: + response = client.get( + "/spend/keys", headers={"Authorization": "Bearer sk-test"} + ) + assert response.status_code == 200 + assert mock_prisma.find_many_calls == [] + assert len(mock_prisma.get_data_calls) == 1 + finally: + app.dependency_overrides.pop(ps.user_api_key_auth, None) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "role", + [LitellmUserRoles.INTERNAL_USER, LitellmUserRoles.INTERNAL_USER_VIEW_ONLY], +) +async def test_spend_key_fn_internal_user_scoped_to_own_keys(client, monkeypatch, role): + """Both internal-user roles must only see keys they own.""" + caller_owned_keys = [ + {"token": "hashed-mine-1", "user_id": "alice", "spend": 2.0}, + {"token": "hashed-mine-2", "user_id": "alice", "spend": 1.0}, + ] + mock_prisma = _SpendScopeMockPrismaClient(get_data_returns=caller_owned_keys) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma) + + app.dependency_overrides[ps.user_api_key_auth] = lambda: UserAPIKeyAuth( + user_role=role, user_id="alice" + ) + try: + response = client.get( + "/spend/keys", headers={"Authorization": "Bearer sk-test"} + ) + assert response.status_code == 200 + # Non-admin path goes through the same get_data helper as admin, + # but with a user_id scope so only the caller's rows come back. + assert mock_prisma.find_many_calls == [] + assert len(mock_prisma.get_data_calls) == 1 + call = mock_prisma.get_data_calls[0] + assert call["table_name"] == "key" + assert call["query_type"] == "find_all" + assert call["user_id"] == "alice" + assert response.json() == caller_owned_keys + finally: + app.dependency_overrides.pop(ps.user_api_key_auth, None) + + +@pytest.mark.asyncio +async def test_spend_key_fn_internal_user_without_user_id_returns_empty( + client, monkeypatch +): + """ + A non-admin key with no user_id has no tenant scope. Returning the full + table would re-introduce the leak; return an empty list instead. + """ + mock_prisma = _SpendScopeMockPrismaClient( + get_data_returns=[{"token": "do-not-leak"}], + find_many_returns=[{"token": "do-not-leak"}], + ) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma) + + app.dependency_overrides[ps.user_api_key_auth] = lambda: UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, user_id=None + ) + try: + response = client.get( + "/spend/keys", headers={"Authorization": "Bearer sk-test"} + ) + assert response.status_code == 200 + assert response.json() == [] + assert mock_prisma.get_data_calls == [] + assert mock_prisma.find_many_calls == [] + finally: + app.dependency_overrides.pop(ps.user_api_key_auth, None) + + +@pytest.mark.asyncio +async def test_spend_user_fn_proxy_admin_returns_all_users_without_user_id( + client, monkeypatch +): + """Admins keep their existing full-table view of /spend/users.""" + mock_users = [ + {"user_id": "alice", "user_email": "alice@example.com", "spend": 1.0}, + {"user_id": "bob", "user_email": "bob@example.com", "spend": 2.0}, + ] + mock_prisma = _SpendScopeMockPrismaClient(get_data_returns=mock_users) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma) + + app.dependency_overrides[ps.user_api_key_auth] = lambda: UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, user_id="admin" + ) + try: + response = client.get( + "/spend/users", headers={"Authorization": "Bearer sk-test"} + ) + assert response.status_code == 200 + assert len(mock_prisma.get_data_calls) == 1 + assert mock_prisma.get_data_calls[0]["table_name"] == "user" + assert mock_prisma.get_data_calls[0]["query_type"] == "find_all" + assert response.json() == mock_users + finally: + app.dependency_overrides.pop(ps.user_api_key_auth, None) + + +@pytest.mark.asyncio +async def test_spend_user_fn_proxy_admin_can_query_specific_user_id( + client, monkeypatch +): + """Admins can still target a specific user_id.""" + mock_user = { + "user_id": "carol", + "user_email": "carol@example.com", + "spend": 7.0, + } + mock_prisma = _SpendScopeMockPrismaClient(get_data_returns=[mock_user]) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma) + + app.dependency_overrides[ps.user_api_key_auth] = lambda: UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, user_id="admin" + ) + try: + response = client.get( + "/spend/users", + params={"user_id": "carol"}, + headers={"Authorization": "Bearer sk-test"}, + ) + assert response.status_code == 200 + assert len(mock_prisma.get_data_calls) == 1 + assert mock_prisma.get_data_calls[0]["query_type"] == "find_unique" + assert mock_prisma.get_data_calls[0]["user_id"] == "carol" + assert response.json() == [mock_user] + finally: + app.dependency_overrides.pop(ps.user_api_key_auth, None) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "role", + [LitellmUserRoles.INTERNAL_USER, LitellmUserRoles.INTERNAL_USER_VIEW_ONLY], +) +async def test_spend_user_fn_internal_user_scoped_without_user_id( + client, monkeypatch, role +): + """No user_id supplied -> must query the caller's own row, not the table.""" + own_row = {"user_id": "alice", "user_email": "alice@example.com", "spend": 3.0} + mock_prisma = _SpendScopeMockPrismaClient(get_data_returns=[own_row]) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma) + + app.dependency_overrides[ps.user_api_key_auth] = lambda: UserAPIKeyAuth( + user_role=role, user_id="alice" + ) + try: + response = client.get( + "/spend/users", headers={"Authorization": "Bearer sk-test"} + ) + assert response.status_code == 200 + assert len(mock_prisma.get_data_calls) == 1 + assert mock_prisma.get_data_calls[0]["query_type"] == "find_unique" + assert mock_prisma.get_data_calls[0]["user_id"] == "alice" + assert response.json() == [own_row] + finally: + app.dependency_overrides.pop(ps.user_api_key_auth, None) + + +@pytest.mark.asyncio +async def test_spend_user_fn_internal_user_supplying_other_user_id_returns_403( + client, monkeypatch +): + """ + An internal user passing user_id=victim must be rejected outright, not + silently rewritten. A 403 makes the attempt observable in logs. + """ + leaked_victim_row = { + "user_id": "victim", + "user_email": "victim@example.com", + "spend": 999.0, + } + mock_prisma = _SpendScopeMockPrismaClient(get_data_returns=[leaked_victim_row]) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma) + + app.dependency_overrides[ps.user_api_key_auth] = lambda: UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, user_id="alice" + ) + try: + response = client.get( + "/spend/users", + params={"user_id": "victim"}, + headers={"Authorization": "Bearer sk-test"}, + ) + assert response.status_code == 403 + assert mock_prisma.get_data_calls == [] + finally: + app.dependency_overrides.pop(ps.user_api_key_auth, None) + + +@pytest.mark.asyncio +async def test_spend_user_fn_internal_user_supplying_own_user_id_is_allowed( + client, monkeypatch +): + """ + Passing your own user_id explicitly is fine — the 403 only fires when + the supplied id differs from the caller's. + """ + own_row = {"user_id": "alice", "user_email": "alice@example.com", "spend": 3.0} + mock_prisma = _SpendScopeMockPrismaClient(get_data_returns=[own_row]) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma) + + app.dependency_overrides[ps.user_api_key_auth] = lambda: UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, user_id="alice" + ) + try: + response = client.get( + "/spend/users", + params={"user_id": "alice"}, + headers={"Authorization": "Bearer sk-test"}, + ) + assert response.status_code == 200 + assert len(mock_prisma.get_data_calls) == 1 + assert mock_prisma.get_data_calls[0]["query_type"] == "find_unique" + assert mock_prisma.get_data_calls[0]["user_id"] == "alice" + assert response.json() == [own_row] + finally: + app.dependency_overrides.pop(ps.user_api_key_auth, None) + + +@pytest.mark.asyncio +async def test_spend_user_fn_internal_user_without_user_id_returns_empty( + client, monkeypatch +): + """ + A non-admin key with no user_id has no tenant scope -> return empty, + never the full table. Same defensive contract as /spend/keys. + """ + mock_prisma = _SpendScopeMockPrismaClient( + get_data_returns=[{"user_id": "do-not-leak"}] + ) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma) + + app.dependency_overrides[ps.user_api_key_auth] = lambda: UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER_VIEW_ONLY, user_id=None + ) + try: + response = client.get( + "/spend/users", headers={"Authorization": "Bearer sk-test"} + ) + assert response.status_code == 200 + assert response.json() == [] + assert mock_prisma.get_data_calls == [] + finally: + app.dependency_overrides.pop(ps.user_api_key_auth, None) + + +@pytest.mark.asyncio +async def test_spend_user_fn_strips_password_field(client, monkeypatch): + """ + Existing password-redaction behavior must be preserved on the scoped + path so we don't regress a separate disclosure when adding the fix. + """ + own_row = { + "user_id": "alice", + "user_email": "alice@example.com", + "password": "hashed-password-must-not-leak", + "spend": 1.0, + } + mock_prisma = _SpendScopeMockPrismaClient(get_data_returns=[own_row]) + monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma) + + app.dependency_overrides[ps.user_api_key_auth] = lambda: UserAPIKeyAuth( + user_role=LitellmUserRoles.INTERNAL_USER, user_id="alice" + ) + try: + response = client.get( + "/spend/users", headers={"Authorization": "Bearer sk-test"} + ) + assert response.status_code == 200 + body = response.json() + assert len(body) == 1 + assert "password" not in body[0] + finally: + app.dependency_overrides.pop(ps.user_api_key_auth, None) diff --git a/tests/test_litellm/test__types.py b/tests/test_litellm/test__types.py new file mode 100644 index 0000000000..c6c37d748e --- /dev/null +++ b/tests/test_litellm/test__types.py @@ -0,0 +1,32 @@ +# tests/test_litellm/proxy/test__types.py + +from litellm.proxy._types import LiteLLM_TeamMembership + + +def test_team_membership_budget_table_optional_no_crash(): + """ + Regression test for #28689 + Pydantic v2: Optional[T] without default = required field. + When budget_id is null, DB join returns no litellm_budget_table key. + model_validate must NOT raise 'Field required'. + """ + data = { + "user_id": "test-user", + "team_id": "test-team", + "budget_id": None, + # litellm_budget_table intentionally absent (as DB join returns when budget_id is null) + } + result = LiteLLM_TeamMembership.model_validate(data) + assert result.litellm_budget_table is None + + +def test_team_membership_budget_table_present_still_works(): + """When budget_id exists, litellm_budget_table should still be populated.""" + data = { + "user_id": "test-user", + "team_id": "test-team", + "budget_id": "some-budget-id", + "litellm_budget_table": None, + } + result = LiteLLM_TeamMembership.model_validate(data) + assert result.litellm_budget_table is None diff --git a/tests/test_litellm/test_bedrock_usgov_haiku_1hr_cache.py b/tests/test_litellm/test_bedrock_usgov_haiku_1hr_cache.py new file mode 100644 index 0000000000..1312aa110d --- /dev/null +++ b/tests/test_litellm/test_bedrock_usgov_haiku_1hr_cache.py @@ -0,0 +1,47 @@ +""" +Validate that AWS GovCloud (Bedrock us-gov-*) Haiku 4.5 entries carry +the 1-hour cache write tier. + +AWS Bedrock GovCloud pricing applies a +20% premium over global +Anthropic rates. Global Haiku 4.5 1h cache write is $2.00/MTok; us-gov +is therefore $2.40/MTok — exactly 1.6x the 5-minute rate of $1.50/MTok. + +Source: https://aws.amazon.com/bedrock/pricing/ +""" + +import json +import os + +import pytest + + +@pytest.fixture(scope="module") +def model_data(): + json_path = os.path.join( + os.path.dirname(__file__), "../../model_prices_and_context_window.json" + ) + with open(json_path) as f: + return json.load(f) + + +HAIKU_USGOV_KEYS = [ + "bedrock/us-gov-east-1/anthropic.claude-haiku-4-5-20251001-v1:0", + "bedrock/us-gov-west-1/anthropic.claude-haiku-4-5-20251001-v1:0", +] + + +@pytest.mark.parametrize("model_key", HAIKU_USGOV_KEYS) +def test_usgov_haiku_4_5_1hr_cache_write(model_data, model_key): + assert model_key in model_data, f"Missing model entry: {model_key}" + info = model_data[model_key] + assert ( + info["cache_creation_input_token_cost"] == 1.5e-06 + ), f"{model_key}: 5m cache write should be $1.50/MTok" + assert ( + info["cache_creation_input_token_cost_above_1hr"] == 2.4e-06 + ), f"{model_key}: 1h cache write should be $2.40/MTok" + ratio = ( + info["cache_creation_input_token_cost_above_1hr"] + / info["cache_creation_input_token_cost"] + ) + assert abs(ratio - 1.6) < 1e-9, f"{model_key}: 1h/5m ratio is {ratio}, expected 1.6" diff --git a/tests/test_litellm/test_bedrock_usgov_pricing.py b/tests/test_litellm/test_bedrock_usgov_pricing.py new file mode 100644 index 0000000000..6b3312b5cc --- /dev/null +++ b/tests/test_litellm/test_bedrock_usgov_pricing.py @@ -0,0 +1,132 @@ +""" +Validate AWS GovCloud (Bedrock us-gov-*) Anthropic pricing entries. + +AWS Bedrock pricing in GovCloud carries a +20% premium over the global +Anthropic prices (not the +10% commercial-US premium). Until 2026-05-22 +these entries silently mirrored commercial US, undercharging customers +by ~9%. + +Source: https://aws.amazon.com/bedrock/pricing/ + + Sonnet 4.5 in us-gov-* (per million tokens): + input = $3.60 + output = $18.00 + cache write 5m = $4.50 + cache write 1h = $7.20 + cache read = $0.36 + +Reference: https://github.com/BerriAI/litellm/issues/27120 +""" + +import json +import os + +import pytest + + +@pytest.fixture(scope="module") +def model_data(): + json_path = os.path.join( + os.path.dirname(__file__), "../../model_prices_and_context_window.json" + ) + with open(json_path) as f: + return json.load(f) + + +SONNET_4_5_USGOV_KEYS = [ + "bedrock/us-gov-east-1/anthropic.claude-sonnet-4-5-20250929-v1:0", + "bedrock/us-gov-west-1/anthropic.claude-sonnet-4-5-20250929-v1:0", + "bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0", + "bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0", + "us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0", +] + + +@pytest.mark.parametrize("model_key", SONNET_4_5_USGOV_KEYS) +def test_usgov_sonnet_4_5_pricing(model_data, model_key): + """Each us-gov sonnet-4-5 entry must carry the +20%-over-global rates + that AWS publishes on the GovCloud pricing page. + """ + assert model_key in model_data, f"Missing model entry: {model_key}" + info = model_data[model_key] + + assert info["input_cost_per_token"] == 3.6e-06, ( + f"{model_key}: input_cost_per_token should be $3.60/MTok " + f"(got {info['input_cost_per_token']})" + ) + assert ( + info["output_cost_per_token"] == 1.8e-05 + ), f"{model_key}: output_cost_per_token should be $18.00/MTok" + assert ( + info["cache_creation_input_token_cost"] == 4.5e-06 + ), f"{model_key}: 5m cache write should be $4.50/MTok" + assert ( + info["cache_creation_input_token_cost_above_1hr"] == 7.2e-06 + ), f"{model_key}: 1h cache write should be $7.20/MTok" + assert ( + info["cache_read_input_token_cost"] == 3.6e-07 + ), f"{model_key}: cache read should be $0.36/MTok" + + +def test_usgov_carries_20_percent_premium_over_global(model_data): + """The us-gov rates must equal 1.2x the global anthropic.* rates, + matching AWS's documented GovCloud uplift. + """ + global_key = "anthropic.claude-sonnet-4-5-20250929-v1:0" + usgov_key = "bedrock/us-gov-west-1/anthropic.claude-sonnet-4-5-20250929-v1:0" + global_info = model_data[global_key] + usgov_info = model_data[usgov_key] + for field in ( + "input_cost_per_token", + "output_cost_per_token", + "cache_creation_input_token_cost", + "cache_creation_input_token_cost_above_1hr", + "cache_read_input_token_cost", + ): + ratio = usgov_info[field] / global_info[field] + assert ( + abs(ratio - 1.2) < 1e-9 + ), f"{field}: us-gov / global ratio is {ratio}, expected 1.2" + + +# The us-gov.anthropic.* cross-region inference profile is the only us-gov +# entry that carries the 1M-context `_above_200k_tokens` pricing tier — the +# bedrock/us-gov-{east,west}-1/ entries are capped at 200k tokens. +USGOV_CROSS_REGION_KEY = "us-gov.anthropic.claude-sonnet-4-5-20250929-v1:0" + +EXPECTED_USGOV_ABOVE_200K = { + "input_cost_per_token_above_200k_tokens": 7.2e-06, + "output_cost_per_token_above_200k_tokens": 2.7e-05, + "cache_creation_input_token_cost_above_200k_tokens": 9.0e-06, + "cache_creation_input_token_cost_above_1hr_above_200k_tokens": 1.44e-05, + "cache_read_input_token_cost_above_200k_tokens": 7.2e-07, +} + + +@pytest.mark.parametrize("field,expected", EXPECTED_USGOV_ABOVE_200K.items()) +def test_usgov_cross_region_above_200k_carries_gov_premium(model_data, field, expected): + """The `_above_200k_tokens` tier on the us-gov cross-region inference + profile must also carry the +20% GovCloud uplift. The original PR + corrected the base rates but left the 200k-tier fields at the +10% + commercial-US rates, undercharging long-context requests. + """ + info = model_data[USGOV_CROSS_REGION_KEY] + assert field in info, f"{USGOV_CROSS_REGION_KEY}: missing field {field}" + assert ( + info[field] == expected + ), f"{USGOV_CROSS_REGION_KEY}: {field} should be {expected} (got {info[field]})" + + +def test_usgov_cross_region_above_200k_ratio_to_global(model_data): + """Cross-check via the property-based invariant: every `_above_200k_tokens` + field on the us-gov cross-region profile must equal 1.2x the global + anthropic.* rate, the same GovCloud uplift the base tier carries. + """ + global_key = "anthropic.claude-sonnet-4-5-20250929-v1:0" + global_info = model_data[global_key] + usgov_info = model_data[USGOV_CROSS_REGION_KEY] + for field in EXPECTED_USGOV_ABOVE_200K: + ratio = usgov_info[field] / global_info[field] + assert ( + abs(ratio - 1.2) < 1e-9 + ), f"{field}: us-gov / global ratio is {ratio}, expected 1.2"