diff --git a/litellm/caching/caching.py b/litellm/caching/caching.py index 11733ce4ce..c1afde1625 100644 --- a/litellm/caching/caching.py +++ b/litellm/caching/caching.py @@ -309,9 +309,13 @@ class Cache: param_value = kwargs[param] cache_key += f"{str(param)}: {str(param_value)}" - verbose_logger.debug("\nCreated cache key: %s", cache_key) hashed_cache_key = Cache._get_hashed_cache_key(cache_key) hashed_cache_key = self._add_namespace_to_cache_key(hashed_cache_key, **kwargs) + verbose_logger.debug( + "\nCreated cache key: %s (source material length: %d)", + hashed_cache_key, + len(cache_key), + ) # Remove preset_cache_key from kwargs to avoid "got multiple values" TypeError # when kwargs already contains preset_cache_key from upstream callers kwargs_for_preset = {k: v for k, v in kwargs.items() if k != "preset_cache_key"} @@ -497,6 +501,34 @@ class Cache: return cached_response return cached_result + @staticmethod + def _get_safe_cache_lookup_kwargs(kwargs: Dict[str, Any]) -> Dict[str, Any]: + cache_lookup_kwargs: Dict[str, Any] = {} + for prompt_kwarg in ("messages", "input"): + if prompt_kwarg in kwargs: + cache_lookup_kwargs[prompt_kwarg] = kwargs[prompt_kwarg] + + if isinstance(kwargs.get("metadata"), dict): + cache_lookup_kwargs["metadata"] = {} + + return cache_lookup_kwargs + + @staticmethod + def _update_metadata_from_cache_lookup_kwargs( + original_kwargs: Dict[str, Any], cache_lookup_kwargs: Dict[str, Any] + ) -> None: + original_metadata = original_kwargs.get("metadata") + cache_lookup_metadata = cache_lookup_kwargs.get("metadata") + if not isinstance(original_metadata, dict) or not isinstance( + cache_lookup_metadata, dict + ): + return + + if "semantic-similarity" in cache_lookup_metadata: + original_metadata["semantic-similarity"] = cache_lookup_metadata[ + "semantic-similarity" + ] + def get_cache(self, dynamic_cache_object: Optional[BaseCache] = None, **kwargs): """ Retrieves the cached result for the given arguments. @@ -511,7 +543,6 @@ class Cache: try: # never block execution if self.should_use_cache(**kwargs) is not True: return - messages = kwargs.get("messages", []) if "cache_key" in kwargs: cache_key = kwargs["cache_key"] else: @@ -523,12 +554,19 @@ class Cache: or cache_control_args.get("s-max-age") or float("inf") ) + cache_lookup_kwargs = self._get_safe_cache_lookup_kwargs(kwargs) if dynamic_cache_object is not None: cached_result = dynamic_cache_object.get_cache( - cache_key, messages=messages + cache_key, **cache_lookup_kwargs ) else: - cached_result = self.cache.get_cache(cache_key, messages=messages) + cached_result = self.cache.get_cache( + cache_key, **cache_lookup_kwargs + ) + self._update_metadata_from_cache_lookup_kwargs( + original_kwargs=kwargs, + cache_lookup_kwargs=cache_lookup_kwargs, + ) return self._get_cache_logic( cached_result=cached_result, max_age=max_age ) @@ -549,7 +587,6 @@ class Cache: if self.should_use_cache(**kwargs) is not True: return - kwargs.get("messages", []) if "cache_key" in kwargs: cache_key = kwargs["cache_key"] else: diff --git a/litellm/caching/redis_semantic_cache.py b/litellm/caching/redis_semantic_cache.py index da9e7b1e58..cce4b75795 100644 --- a/litellm/caching/redis_semantic_cache.py +++ b/litellm/caching/redis_semantic_cache.py @@ -213,6 +213,78 @@ class RedisSemanticCache(BaseCache): ttl = int(ttl) return ttl + @classmethod + def _get_prompt_from_kwargs(cls, **kwargs) -> Optional[str]: + """ + Extract a semantic-cache prompt from chat or Responses API request kwargs. + """ + messages = kwargs.get("messages") + if messages: + return get_str_from_messages(messages) + + if "input" not in kwargs: + return None + + prompt_parts: List[str] = [] + cls._collect_responses_input_text(kwargs.get("input"), prompt_parts) + prompt = "\n".join(prompt_parts).strip() + return prompt or None + + @classmethod + def _collect_responses_input_text(cls, value: Any, prompt_parts: List[str]) -> None: + value = cls._coerce_response_input_value(value) + if value is None: + return + + if isinstance(value, str): + stripped_value = value.strip() + if stripped_value: + prompt_parts.append(stripped_value) + return + + if isinstance(value, (list, tuple)): + for item in value: + cls._collect_responses_input_text(item, prompt_parts) + return + + if isinstance(value, dict): + content = value.get("content") + if content is not None: + cls._collect_responses_input_text(content, prompt_parts) + return + + for text_key in ("text", "output", "input_text", "output_text"): + text_value = value.get(text_key) + if isinstance(text_value, str): + stripped_text = text_value.strip() + if stripped_text: + prompt_parts.append(stripped_text) + return + return + + content = getattr(value, "content", None) + if content is not None: + cls._collect_responses_input_text(content, prompt_parts) + return + + for text_key in ("text", "output", "input_text", "output_text"): + text_value = getattr(value, text_key, None) + if isinstance(text_value, str): + stripped_text = text_value.strip() + if stripped_text: + prompt_parts.append(stripped_text) + return + + @staticmethod + def _coerce_response_input_value(value: Any) -> Any: + model_dump = getattr(value, "model_dump", None) + if callable(model_dump): + return model_dump() + dict_method = getattr(value, "dict", None) + if callable(dict_method): + return dict_method() + return value + def _get_embedding(self, prompt: str) -> List[float]: """ Generate an embedding vector for the given prompt using the configured embedding model. @@ -278,13 +350,11 @@ class RedisSemanticCache(BaseCache): value_str: Optional[str] = None try: - # Extract the prompt from messages - messages = kwargs.get("messages", []) - if not messages: - print_verbose("No messages provided for semantic caching") + prompt = self._get_prompt_from_kwargs(**kwargs) + if prompt is None: + print_verbose("No prompt provided for semantic caching") return - prompt = get_str_from_messages(messages) value_str = str(value) store_kwargs: Dict[str, Any] = { @@ -315,14 +385,12 @@ class RedisSemanticCache(BaseCache): print_verbose(f"Redis semantic-cache get_cache, kwargs: {kwargs}") try: - # Extract the prompt from messages - messages = kwargs.get("messages", []) - if not messages: - print_verbose("No messages provided for semantic cache lookup") + prompt = self._get_prompt_from_kwargs(**kwargs) + if prompt is None: + print_verbose("No prompt provided for semantic cache lookup") kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 return None - prompt = get_str_from_messages(messages) # Check the cache for semantically similar prompts in this exact # LiteLLM cache-key scope. check_kwargs: Dict[str, Any] = { @@ -428,13 +496,11 @@ class RedisSemanticCache(BaseCache): print_verbose(f"Async Redis semantic-cache set_cache, kwargs: {kwargs}") try: - # Extract the prompt from messages - messages = kwargs.get("messages", []) - if not messages: - print_verbose("No messages provided for semantic caching") + prompt = self._get_prompt_from_kwargs(**kwargs) + if prompt is None: + print_verbose("No prompt provided for semantic caching") return - prompt = get_str_from_messages(messages) value_str = str(value) # Generate embedding for the value (response) to cache @@ -471,15 +537,12 @@ class RedisSemanticCache(BaseCache): print_verbose(f"Async Redis semantic-cache get_cache, kwargs: {kwargs}") try: - # Extract the prompt from messages - messages = kwargs.get("messages", []) - if not messages: - print_verbose("No messages provided for semantic cache lookup") + prompt = self._get_prompt_from_kwargs(**kwargs) + if prompt is None: + print_verbose("No prompt provided for semantic cache lookup") kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 return None - prompt = get_str_from_messages(messages) - # Generate embedding for the prompt prompt_embedding = await self._get_async_embedding(prompt, **kwargs) diff --git a/litellm/constants.py b/litellm/constants.py index 36e578bd32..f10cec034f 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -831,6 +831,7 @@ openai_compatible_providers: List = [ "nano-gpt", # Nano-GPT - JSON-configured provider "poe", # Poe - JSON-configured provider "chutes", # Chutes - JSON-configured provider + "parasail", # Parasail - JSON-configured provider "featherless_ai", "nscale", "nebius", diff --git a/litellm/integrations/focus/destinations/__init__.py b/litellm/integrations/focus/destinations/__init__.py index 775d3a259d..e0cd90c1d6 100644 --- a/litellm/integrations/focus/destinations/__init__.py +++ b/litellm/integrations/focus/destinations/__init__.py @@ -2,12 +2,14 @@ from .base import FocusDestination, FocusTimeWindow from .factory import FocusDestinationFactory +from .gcs_destination import FocusGCSDestination from .s3_destination import FocusS3Destination from .vantage_destination import FocusVantageDestination __all__ = [ "FocusDestination", "FocusDestinationFactory", + "FocusGCSDestination", "FocusTimeWindow", "FocusS3Destination", "FocusVantageDestination", diff --git a/litellm/integrations/focus/destinations/factory.py b/litellm/integrations/focus/destinations/factory.py index 706e10624c..7ce21d4040 100644 --- a/litellm/integrations/focus/destinations/factory.py +++ b/litellm/integrations/focus/destinations/factory.py @@ -6,6 +6,7 @@ import os from typing import Any, Dict, Optional from .base import FocusDestination +from .gcs_destination import FocusGCSDestination from .s3_destination import FocusS3Destination from .vantage_destination import FocusVantageDestination @@ -29,6 +30,8 @@ class FocusDestinationFactory: return FocusS3Destination(prefix=prefix, config=normalized_config) if provider_lower == "vantage": return FocusVantageDestination(prefix=prefix, config=normalized_config) + if provider_lower == "gcs": + return FocusGCSDestination(prefix=prefix, config=normalized_config) raise NotImplementedError( f"Provider '{provider}' not supported for Focus export" ) @@ -72,6 +75,18 @@ class FocusDestinationFactory: "VANTAGE_INTEGRATION_TOKEN must be provided for Vantage exports" ) return {k: v for k, v in resolved.items() if v is not None} + if provider == "gcs": + resolved = { + "bucket_name": overrides.get("bucket_name") + or os.getenv("FOCUS_GCS_BUCKET_NAME"), + "service_account_json": overrides.get("service_account_json") + or os.getenv("FOCUS_GCS_PATH_SERVICE_ACCOUNT"), + } + if not resolved.get("bucket_name"): + raise ValueError( + "FOCUS_GCS_BUCKET_NAME must be provided for GCS exports" + ) + return {k: v for k, v in resolved.items() if v is not None} raise NotImplementedError( f"Provider '{provider}' not supported for Focus export configuration" ) diff --git a/litellm/integrations/focus/destinations/gcs_destination.py b/litellm/integrations/focus/destinations/gcs_destination.py new file mode 100644 index 0000000000..b04c16c9d3 --- /dev/null +++ b/litellm/integrations/focus/destinations/gcs_destination.py @@ -0,0 +1,74 @@ +"""GCS destination for Focus export — reuses GCSBucketBase auth and httpx client.""" + +from __future__ import annotations + +from datetime import timezone +from typing import Any, Optional + +from litellm._logging import verbose_logger +from litellm.integrations.gcs_bucket.gcs_bucket_base import GCSBucketBase +from litellm.litellm_core_utils.cloud_storage_security import ( + encode_gcs_object_name_for_url, +) + +from .base import FocusDestination, FocusTimeWindow + + +class FocusGCSDestination(GCSBucketBase, FocusDestination): + """Upload serialized Focus exports to GCS using the GCS JSON API.""" + + def __init__( + self, + *, + prefix: str, + config: Optional[dict[str, Any]] = None, + ) -> None: + config = config or {} + bucket_name = config.get("bucket_name") + if not bucket_name: + raise ValueError("bucket_name must be provided for GCS destination") + super().__init__(bucket_name=bucket_name) + service_account_json = config.get("service_account_json") + if service_account_json is not None: + self.path_service_account_json = service_account_json + self.prefix = prefix.rstrip("/") + + async def deliver( + self, + *, + content: bytes, + time_window: FocusTimeWindow, + filename: str, + ) -> None: + object_name = self._build_object_key(time_window=time_window, filename=filename) + headers = await self.construct_request_headers( + service_account_json=self.path_service_account_json + ) + headers["Content-Type"] = "application/octet-stream" + encoded_name = encode_gcs_object_name_for_url(object_name) + url = ( + f"https://storage.googleapis.com/upload/storage/v1/b/" + f"{self.BUCKET_NAME}/o?uploadType=media&name={encoded_name}" + ) + response = await self.async_httpx_client.post( + url=url, headers=headers, data=content + ) + if response.status_code != 200: + raise RuntimeError( + f"GCS upload failed: status={response.status_code} body={response.text}" + ) + verbose_logger.debug( + "Focus GCS: uploaded %d bytes to gs://%s/%s", + len(content), + self.BUCKET_NAME, + object_name, + ) + + def _build_object_key(self, *, time_window: FocusTimeWindow, filename: str) -> str: + start_utc = time_window.start_time.astimezone(timezone.utc) + date_component = f"date={start_utc.strftime('%Y-%m-%d')}" + parts = [self.prefix, date_component] + if time_window.frequency == "hourly": + parts.append(f"hour={start_utc.strftime('%H')}") + key_prefix = "/".join(filter(None, parts)) + return f"{key_prefix}/{filename}" if key_prefix else filename diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py index 7949e150c2..3f30d5d680 100644 --- a/litellm/llms/anthropic/chat/transformation.py +++ b/litellm/llms/anthropic/chat/transformation.py @@ -1607,6 +1607,15 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): ) return _tool + def should_strip_billing_metadata(self) -> bool: + """ + Whether to drop x-anthropic-billing-header system blocks before sending upstream. + + The first-party Anthropic API uses these blocks for Claude Code attribution, so the + base config keeps them. Providers that reject them (e.g. Bedrock) override this to True. + """ + return False + def translate_system_message( self, messages: List[AllMessageValues] ) -> List[AnthropicSystemMessageContent]: @@ -1614,7 +1623,7 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): Translate system message to anthropic format. Removes system message from the original list and returns a new list of anthropic system message content. - Filters out system messages containing x-anthropic-billing-header metadata. + When should_strip_billing_metadata() is True, x-anthropic-billing-header system blocks are dropped. """ system_prompt_indices = [] anthropic_system_message_list: List[AnthropicSystemMessageContent] = [] @@ -1626,10 +1635,9 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): # Skip empty text blocks - Anthropic API raises errors for empty text if not system_message_block["content"]: continue - # Skip system messages containing x-anthropic-billing-header metadata - if system_message_block["content"].startswith( - "x-anthropic-billing-header:" - ): + if self.should_strip_billing_metadata() and system_message_block[ + "content" + ].startswith("x-anthropic-billing-header:"): continue anthropic_system_message_content = AnthropicSystemMessageContent( type="text", @@ -1648,9 +1656,9 @@ class AnthropicConfig(AnthropicModelInfo, BaseConfig): text_value = _content.get("text") if _content.get("type") == "text" and not text_value: continue - # Skip system messages containing x-anthropic-billing-header metadata if ( - _content.get("type") == "text" + self.should_strip_billing_metadata() + and _content.get("type") == "text" and text_value and text_value.startswith("x-anthropic-billing-header:") ): diff --git a/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py b/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py index 3a2c09f218..07e8270b49 100644 --- a/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py +++ b/litellm/llms/anthropic/experimental_pass_through/messages/transformation.py @@ -84,6 +84,15 @@ class AnthropicMessagesConfig(BaseAnthropicMessagesConfig): if isinstance(content, list): _process_content_list(content) + def should_strip_billing_metadata(self) -> bool: + """ + Whether to drop x-anthropic-billing-header system blocks before sending upstream. + + The first-party Anthropic API uses these blocks for Claude Code attribution, so the + base config keeps them. Providers that reject them override this to True. + """ + return False + @staticmethod def _filter_billing_headers_from_system(system_param): """ @@ -286,14 +295,12 @@ class AnthropicMessagesConfig(BaseAnthropicMessagesConfig): optional_params=anthropic_messages_optional_request_params, ) - # Filter out x-anthropic-billing-header from system messages system_param = anthropic_messages_optional_request_params.get("system") - if system_param is not None: + if self.should_strip_billing_metadata() and system_param is not None: filtered_system = self._filter_billing_headers_from_system(system_param) if filtered_system is not None and len(filtered_system) > 0: anthropic_messages_optional_request_params["system"] = filtered_system else: - # Remove system parameter if all content was filtered out anthropic_messages_optional_request_params.pop("system", None) # Transform context_management from OpenAI format to Anthropic format if needed diff --git a/litellm/llms/azure_ai/anthropic/messages_transformation.py b/litellm/llms/azure_ai/anthropic/messages_transformation.py index a81218ab76..59b6ee2b42 100644 --- a/litellm/llms/azure_ai/anthropic/messages_transformation.py +++ b/litellm/llms/azure_ai/anthropic/messages_transformation.py @@ -21,6 +21,9 @@ class AzureAnthropicMessagesConfig(AnthropicMessagesConfig): and Azure endpoint format. """ + def should_strip_billing_metadata(self) -> bool: + return True + def validate_anthropic_messages_environment( self, headers: dict, diff --git a/litellm/llms/azure_ai/anthropic/transformation.py b/litellm/llms/azure_ai/anthropic/transformation.py index e176a4d860..367ca75c19 100644 --- a/litellm/llms/azure_ai/anthropic/transformation.py +++ b/litellm/llms/azure_ai/anthropic/transformation.py @@ -40,6 +40,9 @@ class AzureAnthropicConfig(AnthropicConfig): def custom_llm_provider(self) -> Optional[str]: return "azure_ai" + def should_strip_billing_metadata(self) -> bool: + return True + def validate_environment( self, headers: dict, diff --git a/litellm/llms/base_llm/responses/transformation.py b/litellm/llms/base_llm/responses/transformation.py index 853eb28275..407d5ad814 100644 --- a/litellm/llms/base_llm/responses/transformation.py +++ b/litellm/llms/base_llm/responses/transformation.py @@ -62,6 +62,26 @@ class BaseResponsesAPIConfig(ABC): """ return False + def sign_request( + self, + headers: dict, + optional_params: dict, + request_data: dict, + api_base: str, + api_key: Optional[str] = None, + model: Optional[str] = None, + stream: Optional[bool] = None, + fake_stream: Optional[bool] = None, + ) -> Tuple[dict, Optional[bytes]]: + """Sign the request after the body is finalized. + + Default is a no-op (returns headers unchanged, no signed body). Providers + whose endpoint requires request signing (e.g. Bedrock Mantle SigV4) + override this and return the signed body bytes so the handler sends those + exact bytes. + """ + return headers, None + @abstractmethod def get_supported_openai_params(self, model: str) -> list: pass diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py index 90dfa13e93..ea0326dffd 100644 --- a/litellm/llms/bedrock/chat/converse_transformation.py +++ b/litellm/llms/bedrock/chat/converse_transformation.py @@ -1649,12 +1649,14 @@ class AmazonConverseConfig(BaseConfig): bedrock_tool_config["toolChoice"] = tool_choice_values data: CommonRequestObject = { - "additionalModelRequestFields": additional_request_params, - "system": system_content_blocks, "inferenceConfig": self._transform_inference_params( inference_params=inference_params ), } + if additional_request_params: + data["additionalModelRequestFields"] = additional_request_params + if system_content_blocks: + data["system"] = system_content_blocks # Handle all config blocks for config_name, config_class in self.get_config_blocks().items(): diff --git a/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py index a13336b6c8..4887cbd23b 100644 --- a/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py +++ b/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py @@ -60,6 +60,9 @@ class AmazonAnthropicClaudeConfig(AmazonInvokeConfig, AnthropicConfig): def custom_llm_provider(self) -> Optional[str]: return "bedrock" + def should_strip_billing_metadata(self) -> bool: + return True + def get_supported_openai_params(self, model: str) -> List[str]: return AnthropicConfig.get_supported_openai_params(self, model) diff --git a/litellm/llms/bedrock/claude_platform/transformation.py b/litellm/llms/bedrock/claude_platform/transformation.py index 0167c457c9..c20dc63444 100644 --- a/litellm/llms/bedrock/claude_platform/transformation.py +++ b/litellm/llms/bedrock/claude_platform/transformation.py @@ -17,6 +17,9 @@ class BedrockClaudePlatformConfig(BedrockClaudePlatformMixin, AnthropicConfig): def custom_llm_provider(self) -> Optional[str]: return "bedrock" + def should_strip_billing_metadata(self) -> bool: + return True + def validate_environment( self, headers: dict, diff --git a/litellm/llms/bedrock_mantle/responses/transformation.py b/litellm/llms/bedrock_mantle/responses/transformation.py index b63fd0ecdb..df21909107 100644 --- a/litellm/llms/bedrock_mantle/responses/transformation.py +++ b/litellm/llms/bedrock_mantle/responses/transformation.py @@ -4,14 +4,26 @@ Amazon Bedrock Mantle - Responses API backend. gpt-5.5 / gpt-5.4 on Mantle are exposed ONLY on the `/openai/v1/responses` path (not the standard `/v1/responses`). Payloads and SSE follow the OpenAI Responses spec, so this config inherits OpenAIResponsesAPIConfig and overrides -only the endpoint URL and Bearer authentication. +only the endpoint URL and authentication. -Auth: AWS Bedrock API key as Bearer token (BEDROCK_MANTLE_API_KEY or the -standard AWS_BEARER_TOKEN_BEDROCK), NOT SigV4. +Auth: Bearer token (BEDROCK_MANTLE_API_KEY or the standard +AWS_BEARER_TOKEN_BEDROCK, or litellm_params.api_key) when present; otherwise +AWS SigV4 (service name "bedrock") using the standard credential chain (IAM +role / access key / profile / web identity), signed via the shared +BaseAWSLLM._sign_request after the request body is finalized. """ -from typing import Optional +import re +from typing import Optional, Tuple +from botocore.exceptions import ( + CredentialRetrievalError, + NoCredentialsError, + PartialCredentialsError, + ProfileNotFound, +) + +from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig from litellm.secret_managers.main import get_secret_str from litellm.types.router import GenericLiteLLMParams @@ -29,22 +41,44 @@ _BASE_SUFFIXES_TO_STRIP = ( "/v1", ) +# Standard Mantle host: https://bedrock-mantle..api.aws (group 1 = region). +_MANTLE_HOST_RE = re.compile( + r"^https?://bedrock-mantle\.([^/.]+)\.api\.aws", re.IGNORECASE +) + class BedrockMantleResponsesAPIConfig(OpenAIResponsesAPIConfig): + def __init__(self, aws_signer: Optional[BaseAWSLLM] = None): + super().__init__() + self._aws_signer = aws_signer or BaseAWSLLM() + @property def custom_llm_provider(self) -> LlmProviders: return LlmProviders.BEDROCK_MANTLE + @staticmethod + def _resolve_region(params: dict) -> str: + region = params.get("aws_region_name") + if region: + return region + base = params.get("api_base") or get_secret_str("BEDROCK_MANTLE_API_BASE") + if base: + match = _MANTLE_HOST_RE.match(base.rstrip("/")) + if match: + return match.group(1) + return ( + get_secret_str("BEDROCK_MANTLE_REGION") + or get_secret_str("AWS_REGION_NAME") + or get_secret_str("AWS_REGION") + or BEDROCK_MANTLE_DEFAULT_REGION + ) + def get_complete_url( self, api_base: Optional[str], litellm_params: dict, ) -> str: - region = ( - get_secret_str("BEDROCK_MANTLE_REGION") - or get_secret_str("AWS_REGION") - or BEDROCK_MANTLE_DEFAULT_REGION - ) + region = self._resolve_region({**litellm_params, "api_base": api_base}) base = ( api_base or get_secret_str("BEDROCK_MANTLE_API_BASE") @@ -55,6 +89,11 @@ class BedrockMantleResponsesAPIConfig(OpenAIResponsesAPIConfig): if base.endswith(suffix): base = base[: -len(suffix)] break + # For the standard Mantle host (including the default-region base that + # responses/main.py auto-injects into litellm_params.api_base), pin to the + # single resolved region so aws_region_name wins; preserve custom proxy hosts. + if _MANTLE_HOST_RE.match(base): + base = f"https://bedrock-mantle.{region}.api.aws" return f"{base}/openai/v1/responses" def validate_environment( @@ -66,12 +105,8 @@ class BedrockMantleResponsesAPIConfig(OpenAIResponsesAPIConfig): or get_secret_str("BEDROCK_MANTLE_API_KEY") or get_secret_str("AWS_BEARER_TOKEN_BEDROCK") ) - if not api_key: - raise ValueError( - "Bedrock Mantle API key is required. Set BEDROCK_MANTLE_API_KEY " - "(or AWS_BEARER_TOKEN_BEDROCK) or pass api_key." - ) - headers["Authorization"] = f"Bearer {api_key}" + if api_key: + headers["Authorization"] = f"Bearer {api_key}" return headers def supports_native_file_search(self) -> bool: @@ -79,3 +114,58 @@ class BedrockMantleResponsesAPIConfig(OpenAIResponsesAPIConfig): def supports_native_websocket(self) -> bool: return False + + def sign_request( + self, + headers: dict, + optional_params: dict, + request_data: dict, + api_base: str, + api_key: Optional[str] = None, + model: Optional[str] = None, + stream: Optional[bool] = None, + fake_stream: Optional[bool] = None, + ) -> Tuple[dict, Optional[bytes]]: + bearer = ( + api_key + or get_secret_str("BEDROCK_MANTLE_API_KEY") + or get_secret_str("AWS_BEARER_TOKEN_BEDROCK") + ) + if not bearer: + # SigV4 path. Pin the credential-scope region to the region of the actual + # signing URL (api_base, already region-resolved by get_complete_url) so the + # SigV4 scope and the URL host can never disagree. Resolve from api_base first, + # then fall back to the regular precedence. Also drop any caller Authorization + # so _sign_request's restore-original-Authorization step cannot override the + # SigV4 header. + optional_params = { + **optional_params, + "aws_region_name": self._resolve_region( + {**optional_params, "api_base": api_base} + ), + } + headers = {k: v for k, v in headers.items() if k.lower() != "authorization"} + try: + return self._aws_signer._sign_request( + service_name="bedrock", + headers=headers, + optional_params=optional_params, + request_data=request_data, + api_base=api_base, + api_key=bearer, + model=model, + stream=stream, + fake_stream=fake_stream, + ) + except ( + NoCredentialsError, + PartialCredentialsError, + ProfileNotFound, + CredentialRetrievalError, + ) as e: + raise ValueError( + "Bedrock Mantle auth failed: no Bearer token and no usable AWS " + "credentials. Set BEDROCK_MANTLE_API_KEY (or AWS_BEARER_TOKEN_BEDROCK) " + "or pass api_key for Bearer auth, or provide AWS credentials " + "(IAM role / access key / profile / web identity) for SigV4." + ) from e diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py index 31c772510b..25424feaeb 100644 --- a/litellm/llms/custom_httpx/llm_http_handler.py +++ b/litellm/llms/custom_httpx/llm_http_handler.py @@ -2318,6 +2318,31 @@ class BaseLLMHTTPHandler: # but never included in the outbound provider payload. request_context["litellm_params"] = dict(litellm_params) + is_stream_request = bool(stream) + if is_stream_request and fake_stream is True: + stream, data = self._prepare_fake_stream_request( + stream=stream, + data=data, + fake_stream=fake_stream, + ) + + # Sign after the body is final (post-transform/normalize/extra_body and post + # fake-stream prep) so signed bytes match what we send. No-op for providers + # that inherit the default sign_request. + headers, signed_body = responses_api_provider_config.sign_request( + headers=headers, + optional_params=dict(litellm_params), + request_data=data, + api_base=api_base, + api_key=litellm_params.api_key, + model=model, + stream=stream, + fake_stream=fake_stream, + ) + body_kwargs: Dict[str, Any] = ( + {"data": signed_body} if signed_body is not None else {"json": data} + ) + ## LOGGING logging_obj.pre_call( input=input, @@ -2330,22 +2355,14 @@ class BaseLLMHTTPHandler: ) try: - if stream: - # For streaming, use stream=True in the request - if fake_stream is True: - stream, data = self._prepare_fake_stream_request( - stream=stream, - data=data, - fake_stream=fake_stream, - ) - + if is_stream_request: response = sync_httpx_client.post( url=api_base, headers=headers, - json=data, timeout=timeout or float(response_api_optional_request_params.get("timeout", 0)), stream=stream, + **body_kwargs, ) if fake_stream is True: return MockResponsesAPIStreamingIterator( @@ -2370,13 +2387,12 @@ class BaseLLMHTTPHandler: call_type=CallTypes.responses.value, ) else: - # For non-streaming requests response = sync_httpx_client.post( url=api_base, headers=headers, - json=data, timeout=timeout or float(response_api_optional_request_params.get("timeout", 0)), + **body_kwargs, ) except Exception as e: raise self._handle_error( @@ -2464,6 +2480,28 @@ class BaseLLMHTTPHandler: # but never included in the outbound provider payload. request_context["litellm_params"] = dict(litellm_params) + is_stream_request = bool(stream) + if is_stream_request and fake_stream is True: + stream, data = self._prepare_fake_stream_request( + stream=stream, + data=data, + fake_stream=fake_stream, + ) + + headers, signed_body = responses_api_provider_config.sign_request( + headers=headers, + optional_params=dict(litellm_params), + request_data=data, + api_base=api_base, + api_key=litellm_params.api_key, + model=model, + stream=stream, + fake_stream=fake_stream, + ) + body_kwargs: Dict[str, Any] = ( + {"data": signed_body} if signed_body is not None else {"json": data} + ) + ## LOGGING logging_obj.pre_call( input=input, @@ -2476,22 +2514,14 @@ class BaseLLMHTTPHandler: ) try: - if stream: - # For streaming, we need to use stream=True in the request - if fake_stream is True: - stream, data = self._prepare_fake_stream_request( - stream=stream, - data=data, - fake_stream=fake_stream, - ) - + if is_stream_request: response = await async_httpx_client.post( url=api_base, headers=headers, - json=data, timeout=timeout or float(response_api_optional_request_params.get("timeout", 0)), stream=stream, + **body_kwargs, ) if fake_stream is True: @@ -2518,13 +2548,12 @@ class BaseLLMHTTPHandler: call_type=CallTypes.responses.value, ) else: - # For non-streaming, proceed as before response = await async_httpx_client.post( url=api_base, headers=headers, - json=data, timeout=timeout or float(response_api_optional_request_params.get("timeout", 0)), + **body_kwargs, ) except Exception as e: @@ -4005,6 +4034,18 @@ class BaseLLMHTTPHandler: ) data = BaseResponsesAPIConfig.normalize_responses_api_request_dict(data) + headers, signed_body = responses_api_provider_config.sign_request( + headers=headers, + optional_params=dict(litellm_params), + request_data=data, + api_base=url, + api_key=litellm_params.api_key, + model=model, + ) + body_kwargs: Dict[str, Any] = ( + {"data": signed_body} if signed_body is not None else {"json": data} + ) + ## LOGGING logging_obj.pre_call( input=input, @@ -4018,7 +4059,7 @@ class BaseLLMHTTPHandler: try: response = sync_httpx_client.post( - url=url, headers=headers, json=data, timeout=timeout + url=url, headers=headers, timeout=timeout, **body_kwargs ) except Exception as e: @@ -4088,6 +4129,18 @@ class BaseLLMHTTPHandler: ) data = BaseResponsesAPIConfig.normalize_responses_api_request_dict(data) + headers, signed_body = responses_api_provider_config.sign_request( + headers=headers, + optional_params=dict(litellm_params), + request_data=data, + api_base=url, + api_key=litellm_params.api_key, + model=model, + ) + body_kwargs: Dict[str, Any] = ( + {"data": signed_body} if signed_body is not None else {"json": data} + ) + ## LOGGING logging_obj.pre_call( input=input, @@ -4101,7 +4154,7 @@ class BaseLLMHTTPHandler: try: response = await async_httpx_client.post( - url=url, headers=headers, json=data, timeout=timeout + url=url, headers=headers, timeout=timeout, **body_kwargs ) except Exception as e: diff --git a/litellm/llms/deepseek/messages/transformation.py b/litellm/llms/deepseek/messages/transformation.py index ad60478960..63b736ffd1 100644 --- a/litellm/llms/deepseek/messages/transformation.py +++ b/litellm/llms/deepseek/messages/transformation.py @@ -26,6 +26,9 @@ class DeepSeekAnthropicMessagesConfig(AnthropicMessagesConfig): def custom_llm_provider(self) -> Optional[str]: return "deepseek" + def should_strip_billing_metadata(self) -> bool: + return True + @staticmethod def get_api_key(api_key: Optional[str] = None) -> Optional[str]: return api_key or get_secret_str("DEEPSEEK_API_KEY") or litellm.api_key diff --git a/litellm/llms/github_copilot/responses/transformation.py b/litellm/llms/github_copilot/responses/transformation.py index 0929f95cf4..3406538c77 100644 --- a/litellm/llms/github_copilot/responses/transformation.py +++ b/litellm/llms/github_copilot/responses/transformation.py @@ -2,7 +2,7 @@ GitHub Copilot Responses API Configuration. This module provides the configuration for GitHub Copilot's Responses API, -which is required for models like gpt-5.1-codex that only support the /responses endpoint. +which is required for models like gpt-5.3-codex that only support the /responses endpoint. Implementation based on analysis of the copilot-api project by caozhiyuan: https://github.com/caozhiyuan/copilot-api @@ -12,6 +12,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional, Union import os +import litellm from litellm._logging import verbose_logger from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH from litellm.exceptions import AuthenticationError @@ -22,6 +23,7 @@ from litellm.types.llms.openai import ( ) from litellm.types.router import GenericLiteLLMParams from litellm.types.utils import LlmProviders +from litellm.utils import _cached_get_model_info_helper from ..authenticator import Authenticator from ..common_utils import ( @@ -38,6 +40,47 @@ else: LiteLLMLoggingObj = Any +def github_copilot_supports_responses_api(model: str) -> bool: + """ + Gate native /v1/responses dispatch per github_copilot model. + + Resolution (first match wins): mode "responses" -> True; mode "chat" -> + False (opt-out wins for dual-endpoint models); "/v1/responses" in + supported_endpoints -> True; else False. Unknown model -> False (the bridge + always works since every Copilot model supports /chat/completions). + + Reads merged model info (per-deployment model_info applied via the router's + register_model, which also clears the cache used here). + """ + try: + info = _cached_get_model_info_helper( + model=model, custom_llm_provider="github_copilot" + ) + except Exception as e: + verbose_logger.debug( + "github_copilot_supports_responses_api: get_model_info failed " + "for %s: %s", + model, + e, + ) + return False + + mode = info.get("mode") + if mode == "responses": + return True + if mode == "chat": + return False + + # supported_endpoints is dropped by ModelInfoBase; read it from the raw + # model_cost entry via the resolved key. + key = info.get("key") + raw_info = litellm.model_cost.get(key) if isinstance(key, str) else None + endpoints = ( + raw_info.get("supported_endpoints") if isinstance(raw_info, dict) else None + ) + return isinstance(endpoints, list) and "/v1/responses" in endpoints + + class GithubCopilotResponsesAPIConfig(OpenAIResponsesAPIConfig): """ Configuration for GitHub Copilot's Responses API. diff --git a/litellm/llms/minimax/messages/transformation.py b/litellm/llms/minimax/messages/transformation.py index 3190a5f541..57cfcbf062 100644 --- a/litellm/llms/minimax/messages/transformation.py +++ b/litellm/llms/minimax/messages/transformation.py @@ -28,6 +28,9 @@ class MinimaxMessagesConfig(AnthropicMessagesConfig): def custom_llm_provider(self) -> Optional[str]: return "minimax" + def should_strip_billing_metadata(self) -> bool: + return True + @staticmethod def get_api_key(api_key: Optional[str] = None) -> Optional[str]: """ diff --git a/litellm/llms/openai_like/dynamic_config.py b/litellm/llms/openai_like/dynamic_config.py index fac453447f..9ed9734eda 100644 --- a/litellm/llms/openai_like/dynamic_config.py +++ b/litellm/llms/openai_like/dynamic_config.py @@ -187,6 +187,7 @@ def create_responses_config_class(provider: SimpleProviderConfig): from litellm.llms.openai_like.responses.transformation import ( OpenAILikeResponsesConfig, ) + from litellm.types.llms.openai import ResponseInputParam from litellm.types.router import GenericLiteLLMParams class JSONProviderResponsesConfig(OpenAILikeResponsesConfig): @@ -223,5 +224,23 @@ def create_responses_config_class(provider: SimpleProviderConfig): api_base = api_base.rstrip("/") return f"{api_base}/responses" + def transform_responses_api_request( + self, + model: str, + input: Union[str, ResponseInputParam], + response_api_optional_request_params: dict, + litellm_params: GenericLiteLLMParams, + headers: dict, + ) -> dict: + if provider.special_handling.get("force_store_false"): + response_api_optional_request_params["store"] = False + return super().transform_responses_api_request( + model=model, + input=input, + response_api_optional_request_params=response_api_optional_request_params, + litellm_params=litellm_params, + headers=headers, + ) + _responses_config_cache[provider.slug] = JSONProviderResponsesConfig return JSONProviderResponsesConfig diff --git a/litellm/llms/openai_like/providers.json b/litellm/llms/openai_like/providers.json index 49b3801c82..13d2248883 100644 --- a/litellm/llms/openai_like/providers.json +++ b/litellm/llms/openai_like/providers.json @@ -132,5 +132,14 @@ "param_mappings": { "max_completion_tokens": "max_tokens" } + }, + "parasail": { + "base_url": "https://api.parasail.io/v1", + "api_key_env": "PARASAIL_API_KEY", + "api_base_env": "PARASAIL_API_BASE", + "supported_endpoints": ["/v1/chat/completions", "/v1/responses"], + "special_handling": { + "force_store_false": true + } } } diff --git a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py index 5cd02293f1..7d355a2e90 100644 --- a/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py +++ b/litellm/llms/vertex_ai/gemini/vertex_and_google_ai_studio_gemini.py @@ -1111,6 +1111,7 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): { "voice": "alloy", "format": "mp3", + "language_code": "en-US", } Expected output: @@ -1119,7 +1120,8 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): prebuiltVoiceConfig: { voiceName: "alloy", } - } + }, + languageCode: "en-US", } """ from litellm.types.llms.vertex_ai import ( @@ -1145,6 +1147,9 @@ class VertexGeminiConfig(VertexAIBaseConfig, BaseConfig): voice_config: VoiceConfig = {"prebuiltVoiceConfig": prebuilt_voice_config} speech_config["voiceConfig"] = voice_config + if "language_code" in value: + speech_config["languageCode"] = value["language_code"] + return cast(dict, speech_config) @staticmethod diff --git a/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/experimental_pass_through/transformation.py b/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/experimental_pass_through/transformation.py index 1e92754857..8a92e7ec4a 100644 --- a/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/experimental_pass_through/transformation.py +++ b/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/experimental_pass_through/transformation.py @@ -17,6 +17,9 @@ from ..output_params_utils import sanitize_vertex_anthropic_output_params class VertexAIPartnerModelsAnthropicMessagesConfig(AnthropicMessagesConfig, VertexBase): + def should_strip_billing_metadata(self) -> bool: + return True + def validate_anthropic_messages_environment( self, headers: dict, diff --git a/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/transformation.py b/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/transformation.py index c852909d47..ae8bdc5544 100644 --- a/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/transformation.py +++ b/litellm/llms/vertex_ai/vertex_ai_partner_models/anthropic/transformation.py @@ -52,6 +52,9 @@ class VertexAIAnthropicConfig(AnthropicConfig): def custom_llm_provider(self) -> Optional[str]: return "vertex_ai" + def should_strip_billing_metadata(self) -> bool: + return True + def _add_context_management_beta_headers( self, beta_set: set, context_management: dict ) -> None: diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 57a7d860ba..fba968dd95 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2312,6 +2312,24 @@ class ConfigGeneralSettings(LiteLLMPydanticObjectBase): None, description="List of MCP server fields that must be filled in for a submission to pass standards checks (e.g. ['description', 'source_url', 'alias']).", ) + disable_budget_reservation: Optional[bool] = Field( + None, + description=( + "If True, disables the optimistic per-request budget reservation " + "introduced in v1.84.0. " + "WARNING: This weakens hard budget enforcement. Without the reservation, " + "a burst of concurrent requests from a single key can each pass the " + "read-time spend check before any of them is charged, allowing a " + "configured budget to be exceeded under high concurrency. " + "Budgets are still evaluated on every request at read time, so " + "an already-exhausted budget is still rejected. " + "Enable only if your deployment is experiencing phantom " + "BudgetExceededError responses caused by leaked reservations " + "(see GitHub issue #27639). " + "A proxy-level WARNING is logged on every request while this flag " + "is active as a reminder that hard enforcement is relaxed." + ), + ) class ConfigYAML(LiteLLMPydanticObjectBase): diff --git a/litellm/proxy/auth/user_api_key_auth.py b/litellm/proxy/auth/user_api_key_auth.py index a970e0ddee..666c01562b 100644 --- a/litellm/proxy/auth/user_api_key_auth.py +++ b/litellm/proxy/auth/user_api_key_auth.py @@ -2425,6 +2425,7 @@ async def _run_centralized_common_checks( # noqa: PLR0915 user_api_key_cache=user_api_key_cache, proxy_logging_obj=proxy_logging_obj, skip_budget_checks=skip_budget_checks, + general_settings=general_settings, ) @@ -2445,12 +2446,23 @@ async def _reserve_budget_after_common_checks( user_api_key_cache: UserApiKeyCache, proxy_logging_obj: ProxyLogging, skip_budget_checks: bool, + general_settings: dict, end_user_id: Optional[str] = None, end_user_object: Optional[LiteLLM_EndUserTable] = None, ) -> None: user_api_key_auth_obj.budget_reservation = None if skip_budget_checks: return + if general_settings.get("disable_budget_reservation") is True: + verbose_proxy_logger.warning( + "disable_budget_reservation is enabled: skipping optimistic budget " + "reservation. Budget enforcement is read-time only — concurrent " + "requests can each pass the spend check before their cost is recorded, " + "so a configured budget may be briefly exceeded under high concurrency. " + "Set disable_budget_reservation to False or remove it to restore " + "hard per-request budget enforcement." + ) + return from litellm.proxy.spend_tracking.budget_reservation import ( reserve_budget_for_request, diff --git a/litellm/proxy/guardrails/guardrail_hooks/crowdstrike_aidr/crowdstrike_aidr.py b/litellm/proxy/guardrails/guardrail_hooks/crowdstrike_aidr/crowdstrike_aidr.py index 14d950ecdf..d1ef165b46 100644 --- a/litellm/proxy/guardrails/guardrail_hooks/crowdstrike_aidr/crowdstrike_aidr.py +++ b/litellm/proxy/guardrails/guardrail_hooks/crowdstrike_aidr/crowdstrike_aidr.py @@ -312,11 +312,27 @@ class CrowdStrikeAIDRHandler(CustomGuardrail): event_type = "output" hook_name = "apply_guardrail (response)" - ai_guard_payload = { + ai_guard_payload: dict[str, Any] = { "guard_input": guard_input.model_dump(mode="json"), "event_type": event_type, } + model = inputs.get("model") + if model: + ai_guard_payload["model"] = model + + metadata = request_data.get("litellm_metadata", request_data.get("metadata")) + if isinstance(metadata, Mapping): + user_id = metadata.get("user_api_key_user_id") + if user_id: + ai_guard_payload["user_id"] = user_id + + extra_info: dict[str, str] = {} + user_email = metadata.get("user_api_key_user_email") + if user_email: + extra_info["user_name"] = user_email + ai_guard_payload["extra_info"] = extra_info + ai_guard_response = await self._call_crowdstrike_aidr_guard( ai_guard_payload, hook_name ) diff --git a/litellm/proxy/management_endpoints/ui_sso.py b/litellm/proxy/management_endpoints/ui_sso.py index a957061685..4812bed2f2 100644 --- a/litellm/proxy/management_endpoints/ui_sso.py +++ b/litellm/proxy/management_endpoints/ui_sso.py @@ -1220,7 +1220,7 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]: generic_role_mappings_group_claim = os.getenv( "GENERIC_ROLE_MAPPINGS_GROUP_CLAIM", None ) - generic_role_mappoings_default_role = os.getenv( + generic_role_mappings_default_role = os.getenv( "GENERIC_ROLE_MAPPINGS_DEFAULT_ROLE", None ) if generic_role_mappings is not None: @@ -1239,7 +1239,7 @@ async def _setup_role_mappings() -> Optional["RoleMappings"]: role_mappings_data = { "provider": "generic", "group_claim": generic_role_mappings_group_claim, - "default_role": generic_role_mappoings_default_role, + "default_role": generic_role_mappings_default_role, "roles": generic_user_role_mappings_data, } diff --git a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/openai_passthrough_logging_handler.py b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/openai_passthrough_logging_handler.py index 29bbb37501..6dd1f8548e 100644 --- a/litellm/proxy/pass_through_endpoints/llm_provider_handlers/openai_passthrough_logging_handler.py +++ b/litellm/proxy/pass_through_endpoints/llm_provider_handlers/openai_passthrough_logging_handler.py @@ -32,6 +32,71 @@ from litellm.types.passthrough_endpoints.pass_through_endpoints import ( from litellm.types.utils import ImageResponse, LlmProviders, PassthroughCallTypes from litellm.utils import ModelResponse, TextCompletionResponse +# Hostnames that route to OpenAI-compatible APIs. +# +# `api.openai.com` is OpenAI proper. The two Azure domains below are *shared by +# every Azure Cognitive Service* (Speech, Vision, Language, ...), not just Azure +# OpenAI: `openai.azure.com` is the classic Azure OpenAI domain, while +# `cognitiveservices.azure.com` is used by newer "Azure AI Foundry" / +# Cognitive Services-hosted Azure OpenAI deployments. Because the hostname alone +# cannot tell Azure OpenAI apart from the other Cognitive Services on those +# domains, requests there must additionally carry an OpenAI-style path segment. +_OPENAI_HOSTNAMES = ("api.openai.com",) +_AZURE_OPENAI_HOSTNAMES = ("openai.azure.com", "cognitiveservices.azure.com") +# Path markers that identify an Azure request as Azure OpenAI rather than Speech +# / Vision / Language / ... `/openai/` is the native Azure OpenAI path prefix; +# `/v1/` is the OpenAI-v1 surface used by LiteLLM's pass-through routing. Other +# Cognitive Services use service-named prefixes and versions like `/v3.1/`, +# `/v1.0/`, so they do not collide with these markers. +_AZURE_OPENAI_PATH_MARKERS = ("/openai/", "/v1/") + + +def _hostname_matches(hostname: str, suffixes: tuple) -> bool: + """True if hostname equals one of `suffixes` or is a subdomain of it. + + Uses suffix matching (not a bare substring test) so look-alikes such as + `cognitiveservices.azure.com.attacker.example` are not accepted. + """ + return any( + hostname == suffix or hostname.endswith("." + suffix) for suffix in suffixes + ) + + +def _is_openai_compatible_host(hostname: Optional[str]) -> bool: + """True if the hostname is OpenAI proper or one of the Azure OpenAI domains. + + Hostname-only check, kept for the route-level helpers that additionally + require a specific OpenAI path (e.g. `/v1/chat/completions`). When only the + hostname would otherwise gate dispatch, use `_is_openai_compatible_url` so + non-OpenAI Azure Cognitive Services on the shared domains are excluded. + """ + if not hostname: + return False + return _hostname_matches(hostname, _OPENAI_HOSTNAMES) or _hostname_matches( + hostname, _AZURE_OPENAI_HOSTNAMES + ) + + +def _is_openai_compatible_url(url_route: Optional[str]) -> bool: + """True if the URL targets an OpenAI-compatible API surface. + + For the shared Azure Cognitive Services domains we additionally require an + OpenAI-style path segment (`/openai/` or `/v1/`) so non-OpenAI Azure services + (Speech, Vision, Language, ...) on the same domain are not misclassified as + OpenAI routes. + """ + if not url_route: + return False + parsed_url = urlparse(url_route) + hostname = parsed_url.hostname + if not hostname: + return False + if _hostname_matches(hostname, _OPENAI_HOSTNAMES): + return True + if _hostname_matches(hostname, _AZURE_OPENAI_HOSTNAMES): + return any(marker in parsed_url.path for marker in _AZURE_OPENAI_PATH_MARKERS) + return False + class OpenAIPassthroughLoggingHandler(BasePassthroughLoggingHandler): """ @@ -52,12 +117,8 @@ class OpenAIPassthroughLoggingHandler(BasePassthroughLoggingHandler): if not url_route: return False parsed_url = urlparse(url_route) - return bool( - parsed_url.hostname - and ( - "api.openai.com" in parsed_url.hostname - or "openai.azure.com" in parsed_url.hostname - ) + return ( + _is_openai_compatible_host(parsed_url.hostname) and "/v1/chat/completions" in parsed_url.path ) @@ -67,12 +128,8 @@ class OpenAIPassthroughLoggingHandler(BasePassthroughLoggingHandler): if not url_route: return False parsed_url = urlparse(url_route) - return bool( - parsed_url.hostname - and ( - "api.openai.com" in parsed_url.hostname - or "openai.azure.com" in parsed_url.hostname - ) + return ( + _is_openai_compatible_host(parsed_url.hostname) and "/v1/images/generations" in parsed_url.path ) @@ -82,12 +139,8 @@ class OpenAIPassthroughLoggingHandler(BasePassthroughLoggingHandler): if not url_route: return False parsed_url = urlparse(url_route) - return bool( - parsed_url.hostname - and ( - "api.openai.com" in parsed_url.hostname - or "openai.azure.com" in parsed_url.hostname - ) + return ( + _is_openai_compatible_host(parsed_url.hostname) and "/v1/images/edits" in parsed_url.path ) @@ -97,13 +150,8 @@ class OpenAIPassthroughLoggingHandler(BasePassthroughLoggingHandler): if not url_route: return False parsed_url = urlparse(url_route) - return bool( - parsed_url.hostname - and ( - "api.openai.com" in parsed_url.hostname - or "openai.azure.com" in parsed_url.hostname - ) - and ("/v1/responses" in parsed_url.path or "/responses" in parsed_url.path) + return _is_openai_compatible_host(parsed_url.hostname) and ( + "/v1/responses" in parsed_url.path or "/responses" in parsed_url.path ) def _get_user_from_metadata( diff --git a/litellm/proxy/pass_through_endpoints/success_handler.py b/litellm/proxy/pass_through_endpoints/success_handler.py index 292871bae6..af1d39da02 100644 --- a/litellm/proxy/pass_through_endpoints/success_handler.py +++ b/litellm/proxy/pass_through_endpoints/success_handler.py @@ -434,15 +434,20 @@ class PassThroughEndpointLogging: return False def is_openai_route(self, url_route: str): - """Check if the URL route is an OpenAI API route.""" + """Check if the URL route is an OpenAI API route. + + Uses the URL-aware helper so that non-OpenAI Azure Cognitive Services + (Speech, Vision, Language, ...) sharing the `*.cognitiveservices.azure.com` + / `*.openai.azure.com` domains are not misclassified as OpenAI routes. + """ if not url_route: return False - parsed_url = urlparse(url_route) - return parsed_url.hostname and ( - "api.openai.com" in parsed_url.hostname - or "openai.azure.com" in parsed_url.hostname + from .llm_provider_handlers.openai_passthrough_logging_handler import ( + _is_openai_compatible_url, ) + return _is_openai_compatible_url(url_route) + def is_gemini_route( self, url_route: str, custom_llm_provider: Optional[str] = None ): diff --git a/litellm/responses/litellm_completion_transformation/transformation.py b/litellm/responses/litellm_completion_transformation/transformation.py index e2ba835359..d3d3064221 100644 --- a/litellm/responses/litellm_completion_transformation/transformation.py +++ b/litellm/responses/litellm_completion_transformation/transformation.py @@ -148,7 +148,9 @@ class LiteLLMCompletionResponsesConfig: # which is equivalent to "required" in OpenAI format return "required" elif tool_choice_type == "function": - # function type without name - fall back to required + function_name = tool_choice.get("name") + if function_name: + return {"type": "function", "function": {"name": function_name}} return "required" # Return as-is for unknown formats diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index 51429d0769..00db7b199b 100644 --- a/litellm/types/llms/vertex_ai.py +++ b/litellm/types/llms/vertex_ai.py @@ -232,6 +232,7 @@ class VoiceConfig(TypedDict): class SpeechConfig(TypedDict, total=False): voiceConfig: VoiceConfig + languageCode: str class GenerationConfig(TypedDict, total=False): diff --git a/litellm/types/utils.py b/litellm/types/utils.py index 9633cecf96..c3ea605dd9 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1540,6 +1540,11 @@ class ServerToolUse(BaseModel): web_search_requests: Optional[int] = None tool_search_requests: Optional[int] = None + def __getitem__(self, key: str) -> Optional[int]: + if key not in self.__class__.model_fields: + raise KeyError(key) + return getattr(self, key) + class Usage(SafeAttributeModel, CompletionUsage): _cache_creation_input_tokens: int = PrivateAttr( @@ -1570,7 +1575,7 @@ class Usage(SafeAttributeModel, CompletionUsage): completion_tokens_details: Optional[ Union[CompletionTokensDetailsWrapper, dict] ] = None, - server_tool_use: Optional[ServerToolUse] = None, + server_tool_use: Optional[Union[ServerToolUse, dict]] = None, cost: Optional[float] = None, **params, ): @@ -1671,6 +1676,9 @@ class Usage(SafeAttributeModel, CompletionUsage): prompt_tokens_details=_prompt_tokens_details or None, ) + if isinstance(server_tool_use, dict): + server_tool_use = ServerToolUse(**server_tool_use) + if server_tool_use is not None: self.server_tool_use = server_tool_use else: # maintain openai compatibility in usage object if possible @@ -3392,6 +3400,7 @@ class LlmProviders(str, Enum): POE = "poe" CHUTES = "chutes" NEOSANTARA = "neosantara" + PARASAIL = "parasail" XIAOMI_MIMO = "xiaomi_mimo" TENSORMESH = "tensormesh" LITELLM_AGENT = "litellm_agent" diff --git a/litellm/utils.py b/litellm/utils.py index 7312e71bbd..8d9d0a409c 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8895,7 +8895,13 @@ class ProviderConfigManager: elif litellm.LlmProviders.XAI == provider: return litellm.XAIResponsesAPIConfig() elif litellm.LlmProviders.GITHUB_COPILOT == provider: - return litellm.GithubCopilotResponsesAPIConfig() + from litellm.llms.github_copilot.responses.transformation import ( + github_copilot_supports_responses_api, + ) + + if model is None or github_copilot_supports_responses_api(model=model): + return litellm.GithubCopilotResponsesAPIConfig() + return None elif litellm.LlmProviders.CHATGPT == provider: return litellm.ChatGPTResponsesAPIConfig() elif litellm.LlmProviders.LITELLM_PROXY == provider: diff --git a/provider_endpoints_support.json b/provider_endpoints_support.json index a1ad20fffd..6caab585ac 100644 --- a/provider_endpoints_support.json +++ b/provider_endpoints_support.json @@ -1834,6 +1834,23 @@ "search": true } }, + "parasail": { + "display_name": "Parasail (`parasail`)", + "url": "https://docs.litellm.ai/docs/providers/parasail", + "endpoints": { + "chat_completions": true, + "messages": false, + "responses": true, + "embeddings": false, + "image_generations": false, + "audio_transcriptions": false, + "audio_speech": false, + "moderations": false, + "batches": false, + "rerank": false, + "a2a": false + } + }, "perplexity": { "display_name": "Perplexity AI (`perplexity`)", "url": "https://docs.litellm.ai/docs/providers/perplexity", diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py index 9cf253c379..fa22ff6b39 100644 --- a/tests/llm_translation/test_bedrock_completion.py +++ b/tests/llm_translation/test_bedrock_completion.py @@ -2712,6 +2712,10 @@ def test_bedrock_top_k_param(model, expected_params): data = json.loads(mock_post.call_args.kwargs["data"]) if "mistral" in model: assert data["top_k"] == 2 + elif expected_params == {}: + # Models that don't support top_k produce no additionalModelRequestFields; + # the empty block is now omitted entirely rather than sent as `{}`. + assert "additionalModelRequestFields" not in data else: assert data["additionalModelRequestFields"] == expected_params @@ -3059,8 +3063,6 @@ async def test_bedrock_max_completion_tokens(model: str): assert request_body == { "messages": [{"role": "user", "content": [{"text": "Hello!"}]}], - "additionalModelRequestFields": {}, - "system": [], "inferenceConfig": {"maxTokens": 10}, } diff --git a/tests/test_litellm/caching/test_caching.py b/tests/test_litellm/caching/test_caching.py new file mode 100644 index 0000000000..02d62a1915 --- /dev/null +++ b/tests/test_litellm/caching/test_caching.py @@ -0,0 +1,48 @@ +import logging +import re + +from litellm.caching.caching import Cache +from litellm.types.caching import LiteLLMCacheType + + +def test_cache_key_debug_log_does_not_include_prompt_material(caplog): + cache = Cache(type=LiteLLMCacheType.LOCAL) + prompt_marker = "secret prompt material " + + with caplog.at_level(logging.DEBUG, logger="LiteLLM"): + cache_key = cache.get_cache_key( + model="gpt-4.1-mini", + messages=[ + {"role": "system", "content": prompt_marker * 100}, + {"role": "user", "content": "hello"}, + ], + tools=[ + { + "type": "function", + "function": { + "name": "lookup", + "parameters": { + "type": "object", + "properties": {"query": {"type": "string"}}, + }, + }, + } + ], + response_format={ + "type": "json_schema", + "json_schema": { + "name": "lookup_response", + "schema": {"type": "object"}, + }, + }, + stream=True, + ) + + assert re.fullmatch(r"[0-9a-f]{64}", cache_key) + + created_cache_key_logs = [ + record.getMessage() for record in caplog.records if "Created cache key:" in record.getMessage() + ] + assert created_cache_key_logs + assert all(prompt_marker not in message for message in created_cache_key_logs) + assert any(cache_key in message for message in created_cache_key_logs) diff --git a/tests/test_litellm/caching/test_redis_semantic_cache.py b/tests/test_litellm/caching/test_redis_semantic_cache.py index b50a35ef50..13f9d00136 100644 --- a/tests/test_litellm/caching/test_redis_semantic_cache.py +++ b/tests/test_litellm/caching/test_redis_semantic_cache.py @@ -523,3 +523,468 @@ async def test_redis_semantic_cache_async_set_cache_stores_cache_key_filter( filters={RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key"}, ttl=60, ) + + +def test_redis_semantic_cache_set_cache_uses_responses_string_input(): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + redis_semantic_cache = RedisSemanticCache.__new__(RedisSemanticCache) + redis_semantic_cache.llmcache = MagicMock() + redis_semantic_cache._get_cache_filters = MagicMock( + return_value={RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key"} + ) + redis_semantic_cache._get_ttl = MagicMock(return_value=None) + + redis_semantic_cache.set_cache( + key="test_key", + value={"content": "Paris"}, + input="What is the capital of France?", + ) + + redis_semantic_cache.llmcache.store.assert_called_once_with( + "What is the capital of France?", + "{'content': 'Paris'}", + filters={RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key"}, + ) + + +def test_redis_semantic_cache_get_cache_uses_responses_string_input(): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + redis_semantic_cache = RedisSemanticCache.__new__(RedisSemanticCache) + redis_semantic_cache.similarity_threshold = 0.8 + redis_semantic_cache.llmcache = MagicMock() + redis_semantic_cache.llmcache.check = MagicMock( + return_value=[ + { + "prompt": "What is the capital of France?", + "response": '{"content": "Paris"}', + "vector_distance": 0.1, + RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key", + } + ] + ) + + with patch.object( + redis_semantic_cache, + "_get_cache_key_filter_expression", + return_value="cache-key-filter", + ): + metadata = {} + result = redis_semantic_cache.get_cache( + key="test_key", + input="What is the capital of France?", + metadata=metadata, + ) + + assert result == {"content": "Paris"} + assert metadata["semantic-similarity"] == pytest.approx(0.9) + redis_semantic_cache.llmcache.check.assert_called_once_with( + prompt="What is the capital of France?", + filter_expression="cache-key-filter", + ) + + +def test_redis_semantic_cache_set_cache_flattens_structured_responses_input(): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + redis_semantic_cache = RedisSemanticCache.__new__(RedisSemanticCache) + redis_semantic_cache.llmcache = MagicMock() + redis_semantic_cache._get_cache_filters = MagicMock( + return_value={RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key"} + ) + redis_semantic_cache._get_ttl = MagicMock(return_value=None) + + redis_semantic_cache.set_cache( + key="test_key", + value={"content": "Paris"}, + input=[ + { + "role": "user", + "content": [ + {"type": "input_text", "text": "What is the capital of France?"}, + {"type": "input_text", "text": "Answer briefly."}, + { + "type": "input_image", + "image_url": "https://example.com/paris.png", + }, + ], + } + ], + ) + + redis_semantic_cache.llmcache.store.assert_called_once_with( + "What is the capital of France?\nAnswer briefly.", + "{'content': 'Paris'}", + filters={RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key"}, + ) + + +def test_redis_semantic_cache_prompt_extraction_prefers_messages(): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + prompt = RedisSemanticCache._get_prompt_from_kwargs( + messages=[{"content": "message prompt"}], + input="responses prompt", + ) + + assert prompt == "message prompt" + + +def test_redis_semantic_cache_prompt_extraction_handles_model_objects(): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + class ModelDumpInput: + def model_dump(self): + return {"content": [{"text": "model dump prompt"}]} + + class DictInput: + def dict(self): + return {"content": [{"output_text": "dict prompt"}]} + + prompt = RedisSemanticCache._get_prompt_from_kwargs( + input=[ + ModelDumpInput(), + DictInput(), + {"content": [{"input_text": "inline prompt"}]}, + {"content": [{"type": "input_image", "image_url": "https://example.com"}]}, + ] + ) + + assert prompt == "model dump prompt\ndict prompt\ninline prompt" + + +def test_redis_semantic_cache_prompt_extraction_returns_none_without_text(): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + assert RedisSemanticCache._get_prompt_from_kwargs() is None + assert RedisSemanticCache._get_prompt_from_kwargs(input=None) is None + assert RedisSemanticCache._get_prompt_from_kwargs(input=" ") is None + assert ( + RedisSemanticCache._get_prompt_from_kwargs( + input=[{"type": "input_image", "image_url": "https://example.com"}] + ) + is None + ) + + +def test_redis_semantic_cache_prompt_extraction_skips_blank_dict_text_keys(): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + prompt = RedisSemanticCache._get_prompt_from_kwargs( + input={"text": " ", "input_text": "fallback prompt"} + ) + + assert prompt == "fallback prompt" + + +def test_redis_semantic_cache_prompt_extraction_skips_blank_object_text_keys(): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + class ResponseInput: + text = " " + input_text = "fallback prompt" + + prompt = RedisSemanticCache._get_prompt_from_kwargs(input=ResponseInput()) + + assert prompt == "fallback prompt" + + +def test_redis_semantic_cache_prompt_extraction_handles_object_content(): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + class ResponseInput: + content = [{"text": "object content prompt"}] + + prompt = RedisSemanticCache._get_prompt_from_kwargs(input=ResponseInput()) + + assert prompt == "object content prompt" + + +def test_redis_semantic_cache_set_cache_skips_blank_responses_input(): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + redis_semantic_cache = RedisSemanticCache.__new__(RedisSemanticCache) + redis_semantic_cache.llmcache = MagicMock() + + redis_semantic_cache.set_cache( + key="test_key", + value={"content": "Paris"}, + input=" ", + ) + + redis_semantic_cache.llmcache.store.assert_not_called() + + +def test_redis_semantic_cache_get_cache_sets_similarity_on_blank_responses_input(): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + redis_semantic_cache = RedisSemanticCache.__new__(RedisSemanticCache) + redis_semantic_cache.llmcache = MagicMock() + metadata = {} + + result = redis_semantic_cache.get_cache( + key="test_key", + input=" ", + metadata=metadata, + ) + + assert result is None + assert metadata["semantic-similarity"] == 0.0 + redis_semantic_cache.llmcache.check.assert_not_called() + + +def test_redis_semantic_cache_get_cache_sets_similarity_when_no_results(): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + redis_semantic_cache = RedisSemanticCache.__new__(RedisSemanticCache) + redis_semantic_cache.llmcache = MagicMock() + redis_semantic_cache.llmcache.check = MagicMock(return_value=[]) + + with patch.object( + redis_semantic_cache, + "_get_cache_key_filter_expression", + return_value="cache-key-filter", + ): + metadata = {} + result = redis_semantic_cache.get_cache( + key="test_key", + input="What is the capital of France?", + metadata=metadata, + ) + + assert result is None + assert metadata["semantic-similarity"] == 0.0 + redis_semantic_cache.llmcache.check.assert_called_once_with( + prompt="What is the capital of France?", + filter_expression="cache-key-filter", + ) + + +@pytest.mark.asyncio +async def test_redis_semantic_cache_async_paths_use_responses_string_input(): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + redis_semantic_cache = RedisSemanticCache.__new__(RedisSemanticCache) + redis_semantic_cache.similarity_threshold = 0.8 + redis_semantic_cache.llmcache = MagicMock() + redis_semantic_cache.llmcache.astore = AsyncMock() + redis_semantic_cache.llmcache.acheck = AsyncMock( + return_value=[ + { + "prompt": "What is the capital of France?", + "response": '{"content": "Paris"}', + "vector_distance": 0.1, + RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key", + } + ] + ) + redis_semantic_cache._get_cache_filters = MagicMock( + return_value={RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key"} + ) + redis_semantic_cache._get_ttl = MagicMock(return_value=None) + redis_semantic_cache._get_async_embedding = AsyncMock(return_value=[0.1, 0.2, 0.3]) + + await redis_semantic_cache.async_set_cache( + key="test_key", + value={"content": "Paris"}, + input="What is the capital of France?", + ) + + with patch.object( + redis_semantic_cache, + "_get_cache_key_filter_expression", + return_value="cache-key-filter", + ): + metadata = {} + result = await redis_semantic_cache.async_get_cache( + key="test_key", + input="What is the capital of France?", + metadata=metadata, + ) + + redis_semantic_cache.llmcache.astore.assert_called_once_with( + "What is the capital of France?", + "{'content': 'Paris'}", + vector=[0.1, 0.2, 0.3], + filters={RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key"}, + ) + assert result == {"content": "Paris"} + assert metadata["semantic-similarity"] == pytest.approx(0.9) + redis_semantic_cache.llmcache.acheck.assert_called_once_with( + prompt="What is the capital of France?", + vector=[0.1, 0.2, 0.3], + filter_expression="cache-key-filter", + ) + + +@pytest.mark.asyncio +async def test_redis_semantic_cache_async_paths_set_similarity_on_misses(): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + redis_semantic_cache = RedisSemanticCache.__new__(RedisSemanticCache) + redis_semantic_cache.llmcache = MagicMock() + redis_semantic_cache.llmcache.astore = AsyncMock() + redis_semantic_cache.llmcache.acheck = AsyncMock(return_value=[]) + redis_semantic_cache._get_async_embedding = AsyncMock(return_value=[0.1, 0.2, 0.3]) + + await redis_semantic_cache.async_set_cache( + key="test_key", + value={"content": "Paris"}, + input=" ", + ) + + redis_semantic_cache.llmcache.astore.assert_not_called() + redis_semantic_cache._get_async_embedding.assert_not_called() + + blank_metadata = {} + blank_result = await redis_semantic_cache.async_get_cache( + key="test_key", + input=" ", + metadata=blank_metadata, + ) + + assert blank_result is None + assert blank_metadata["semantic-similarity"] == 0.0 + redis_semantic_cache.llmcache.acheck.assert_not_called() + redis_semantic_cache._get_async_embedding.assert_not_called() + + with patch.object( + redis_semantic_cache, + "_get_cache_key_filter_expression", + return_value="cache-key-filter", + ): + miss_metadata = {} + miss_result = await redis_semantic_cache.async_get_cache( + key="test_key", + input="What is the capital of France?", + metadata=miss_metadata, + ) + + assert miss_result is None + assert miss_metadata["semantic-similarity"] == 0.0 + redis_semantic_cache.llmcache.acheck.assert_called_once_with( + prompt="What is the capital of France?", + vector=[0.1, 0.2, 0.3], + filter_expression="cache-key-filter", + ) + + +def test_cache_get_cache_passes_responses_input_to_backend_cache(): + from litellm.caching.caching import Cache + + cache = Cache.__new__(Cache) + cache.cache = MagicMock() + cache.cache.get_cache = MagicMock(return_value=None) + cache.should_use_cache = MagicMock(return_value=True) + cache.get_cache_key = MagicMock(return_value="test_key") + + metadata = {} + cache.get_cache( + input="What is the capital of France?", + metadata=metadata, + cache={}, + ) + + cache.cache.get_cache.assert_called_once_with( + "test_key", + input="What is the capital of France?", + metadata=metadata, + ) + + +def test_cache_get_cache_filters_sensitive_kwargs_from_backend_cache(): + from litellm.caching.caching import Cache + + cache = Cache.__new__(Cache) + cache.cache = MagicMock() + cache.should_use_cache = MagicMock(return_value=True) + cache.get_cache_key = MagicMock(return_value="test_key") + cache._get_cache_logic = MagicMock(return_value={"content": "Paris"}) + + def _cache_hit(_cache_key, **cache_kwargs): + cache_kwargs["metadata"]["semantic-similarity"] = 0.7 + return {"content": "Paris"} + + cache.cache.get_cache = MagicMock(side_effect=_cache_hit) + + metadata = {"user_api_key": "sk-secret", "trace_id": "trace-id"} + result = cache.get_cache( + input="What is the capital of France?", + metadata=metadata, + cache={"s-maxage": 10}, + api_key="sk-secret", + headers={"authorization": "Bearer sk-secret"}, + ) + + assert result == {"content": "Paris"} + assert metadata == { + "user_api_key": "sk-secret", + "trace_id": "trace-id", + "semantic-similarity": 0.7, + } + + forwarded_kwargs = cache.cache.get_cache.call_args.kwargs + assert forwarded_kwargs == { + "input": "What is the capital of France?", + "metadata": {"semantic-similarity": 0.7}, + } + assert forwarded_kwargs["metadata"] is not metadata + cache._get_cache_logic.assert_called_once_with( + cached_result={"content": "Paris"}, + max_age=10, + ) + + +def test_cache_get_cache_filters_sensitive_kwargs_without_metadata(): + from litellm.caching.caching import Cache + + cache = Cache.__new__(Cache) + cache.cache = MagicMock() + cache.cache.get_cache = MagicMock(return_value={"content": "Paris"}) + cache.should_use_cache = MagicMock(return_value=True) + cache.get_cache_key = MagicMock(return_value="test_key") + cache._get_cache_logic = MagicMock(return_value={"content": "Paris"}) + + result = cache.get_cache( + input="What is the capital of France?", + cache={"s-maxage": 10}, + api_key="sk-secret", + headers={"authorization": "Bearer sk-secret"}, + ) + + assert result == {"content": "Paris"} + cache.cache.get_cache.assert_called_once_with( + "test_key", + input="What is the capital of France?", + ) + + +def test_cache_get_cache_passes_responses_input_to_dynamic_cache(): + from litellm.caching.caching import Cache + + cache = Cache.__new__(Cache) + cache.should_use_cache = MagicMock(return_value=True) + cache.get_cache_key = MagicMock(return_value="test_key") + cache._get_cache_logic = MagicMock(return_value={"content": "Paris"}) + dynamic_cache_object = MagicMock() + dynamic_cache_object.get_cache = MagicMock(return_value={"content": "Paris"}) + + metadata = {} + result = cache.get_cache( + dynamic_cache_object=dynamic_cache_object, + input="What is the capital of France?", + metadata=metadata, + cache={}, + ) + + assert result == {"content": "Paris"} + dynamic_cache_object.get_cache.assert_called_once_with( + "test_key", + input="What is the capital of France?", + metadata=metadata, + ) + cache._get_cache_logic.assert_called_once_with( + cached_result={"content": "Paris"}, + max_age=float("inf"), + ) diff --git a/tests/test_litellm/integrations/focus/test_focus_gcs_destination.py b/tests/test_litellm/integrations/focus/test_focus_gcs_destination.py new file mode 100644 index 0000000000..35cdb18326 --- /dev/null +++ b/tests/test_litellm/integrations/focus/test_focus_gcs_destination.py @@ -0,0 +1,180 @@ +"""Tests for FocusGCSDestination.""" + +from __future__ import annotations + +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from litellm.integrations.focus.destinations.base import FocusTimeWindow + + +def _make_window(frequency: str = "hourly") -> FocusTimeWindow: + return FocusTimeWindow( + start_time=datetime(2026, 1, 1, 10, 0, 0, tzinfo=timezone.utc), + end_time=datetime(2026, 1, 1, 11, 0, 0, tzinfo=timezone.utc), + frequency=frequency, + ) + + +@pytest.mark.asyncio +async def test_deliver_posts_to_gcs_upload_endpoint(): + """deliver() must POST raw bytes to the GCS upload endpoint.""" + from litellm.integrations.focus.destinations.gcs_destination import ( + FocusGCSDestination, + ) + + dest = FocusGCSDestination( + prefix="focus_exports", + config={"bucket_name": "my-bucket", "service_account_json": None}, + ) + + mock_response = MagicMock() + mock_response.status_code = 200 + + mock_client = MagicMock() + mock_client.post = AsyncMock(return_value=mock_response) + dest.async_httpx_client = mock_client + + with patch.object( + dest, + "construct_request_headers", + new=AsyncMock(return_value={"Authorization": "Bearer tok-123"}), + ): + await dest.deliver( + content=b"col1,col2\nval1,val2\n", + time_window=_make_window(), + filename="usage_20260101T100000Z_20260101T110000Z.csv", + ) + + mock_client.post.assert_called_once() + call_kwargs = mock_client.post.call_args + url = call_kwargs.kwargs.get("url") or call_kwargs.args[0] + assert "my-bucket" in url + assert "uploadType=media" in url + headers = call_kwargs.kwargs["headers"] + assert headers["Authorization"] == "Bearer tok-123" + + +@pytest.mark.asyncio +async def test_deliver_raises_on_gcs_error(): + """deliver() must raise RuntimeError when GCS returns non-200.""" + from litellm.integrations.focus.destinations.gcs_destination import ( + FocusGCSDestination, + ) + + dest = FocusGCSDestination( + prefix="focus_exports", + config={"bucket_name": "my-bucket"}, + ) + + mock_response = MagicMock() + mock_response.status_code = 403 + mock_response.text = "Permission denied" + + mock_client = MagicMock() + mock_client.post = AsyncMock(return_value=mock_response) + dest.async_httpx_client = mock_client + + with patch.object( + dest, + "construct_request_headers", + new=AsyncMock(return_value={"Authorization": "Bearer tok-bad"}), + ): + with pytest.raises(RuntimeError, match="GCS upload failed"): + await dest.deliver( + content=b"data", + time_window=_make_window(), + filename="usage.csv", + ) + + +def test_build_object_key_hourly(): + """Hourly key must include date= and hour= components.""" + from litellm.integrations.focus.destinations.gcs_destination import ( + FocusGCSDestination, + ) + + dest = FocusGCSDestination(prefix="focus_exports", config={"bucket_name": "b"}) + key = dest._build_object_key( + time_window=_make_window("hourly"), filename="usage.parquet" + ) + + assert key == "focus_exports/date=2026-01-01/hour=10/usage.parquet" + + +def test_build_object_key_daily(): + """Daily key must include date= but not hour=.""" + from litellm.integrations.focus.destinations.gcs_destination import ( + FocusGCSDestination, + ) + + dest = FocusGCSDestination(prefix="focus_exports", config={"bucket_name": "b"}) + window = FocusTimeWindow( + start_time=datetime(2026, 1, 1, 0, 0, 0, tzinfo=timezone.utc), + end_time=datetime(2026, 1, 2, 0, 0, 0, tzinfo=timezone.utc), + frequency="daily", + ) + key = dest._build_object_key(time_window=window, filename="usage.parquet") + + assert key == "focus_exports/date=2026-01-01/usage.parquet" + + +def test_missing_bucket_name_raises(): + """Constructing without bucket_name must raise ValueError.""" + from litellm.integrations.focus.destinations.gcs_destination import ( + FocusGCSDestination, + ) + + with pytest.raises(ValueError, match="bucket_name"): + FocusGCSDestination(prefix="focus_exports", config={}) + + +def test_global_gcs_service_account_not_overwritten_when_absent(monkeypatch): + """service_account_json absent from config must not overwrite GCS_PATH_SERVICE_ACCOUNT. + + GCSBucketBase sets self.path_service_account_json from GCS_PATH_SERVICE_ACCOUNT. + If config has no service_account_json key, we must leave the parent value intact + so deployments using the global credential don't silently fall back to ADC. + """ + monkeypatch.setenv("GCS_PATH_SERVICE_ACCOUNT", "/global/sa.json") + + from litellm.integrations.focus.destinations.gcs_destination import ( + FocusGCSDestination, + ) + + dest = FocusGCSDestination(prefix="focus_exports", config={"bucket_name": "b"}) + + assert dest.path_service_account_json == "/global/sa.json" + + +def test_explicit_service_account_overrides_global(monkeypatch): + """Explicit service_account_json in config must take precedence over GCS_PATH_SERVICE_ACCOUNT.""" + monkeypatch.setenv("GCS_PATH_SERVICE_ACCOUNT", "/global/sa.json") + + from litellm.integrations.focus.destinations.gcs_destination import ( + FocusGCSDestination, + ) + + dest = FocusGCSDestination( + prefix="focus_exports", + config={"bucket_name": "b", "service_account_json": "/focus/sa.json"}, + ) + + assert dest.path_service_account_json == "/focus/sa.json" + + +def test_factory_creates_gcs_destination(monkeypatch): + """FocusDestinationFactory.create(provider='gcs') must return FocusGCSDestination.""" + monkeypatch.setenv("FOCUS_GCS_BUCKET_NAME", "env-bucket") + + from litellm.integrations.focus.destinations.factory import FocusDestinationFactory + from litellm.integrations.focus.destinations.gcs_destination import ( + FocusGCSDestination, + ) + + dest = FocusDestinationFactory.create(provider="gcs", prefix="focus_exports") + + assert isinstance(dest, FocusGCSDestination) + assert dest.BUCKET_NAME == "env-bucket" diff --git a/tests/test_litellm/litellm_core_utils/llm_cost_calc/test_tool_call_cost_tracking.py b/tests/test_litellm/litellm_core_utils/llm_cost_calc/test_tool_call_cost_tracking.py index a04f6407e4..c43291566b 100644 --- a/tests/test_litellm/litellm_core_utils/llm_cost_calc/test_tool_call_cost_tracking.py +++ b/tests/test_litellm/litellm_core_utils/llm_cost_calc/test_tool_call_cost_tracking.py @@ -1,17 +1,14 @@ -import json import os import sys -from unittest.mock import MagicMock import pytest -from fastapi.testclient import TestClient import litellm from litellm.litellm_core_utils.llm_cost_calc.tool_call_cost_tracking import ( StandardBuiltInToolCostTracking, ) from litellm.types.llms.openai import FileSearchTool, WebSearchOptions -from litellm.types.utils import ModelInfo, ModelResponse, StandardBuiltInToolsParams +from litellm.types.utils import ModelResponse, StandardBuiltInToolsParams sys.path.insert( 0, os.path.abspath("../../..") @@ -139,6 +136,22 @@ def test_get_cost_for_anthropic_web_search(): assert cost > 0.0 +def test_get_cost_for_anthropic_web_search_with_server_tool_use_dict(): + """ + Anthropic-compatible passthrough responses can construct Usage from a raw + usage payload. Ensure dict server_tool_use values are normalized before + built-in tool cost tracking reads server_tool_use.web_search_requests. + """ + from litellm.types.utils import ServerToolUse, Usage + + usage = Usage(server_tool_use={"web_search_requests": 1}) + + assert isinstance(usage.server_tool_use, ServerToolUse) + assert StandardBuiltInToolCostTracking.response_object_includes_web_search_call( + response_object=None, usage=usage + ) + + @pytest.mark.parametrize( "model", ["gemini/gemini-2.0-flash-001", "gemini-2.0-flash-001"] ) diff --git a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py index 4c33031293..75038574c6 100644 --- a/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py +++ b/tests/test_litellm/llms/anthropic/chat/test_anthropic_chat_transformation.py @@ -5092,6 +5092,175 @@ def test_map_tool_helper_collision_prefers_definitions_over_components_schemas() assert transformed["input_schema"]["properties"]["from_components"] == expected +BILLING_HEADER_BLOCK = { + "type": "text", + "text": "x-anthropic-billing-header: cc_version=1.0.abc; cc_entrypoint=cli; cch=00000;", +} + + +def _system_with_billing_header(real_text: str) -> list: + return [ + { + "role": "system", + "content": [BILLING_HEADER_BLOCK, {"type": "text", "text": real_text}], + } + ] + + +def test_translate_system_message_keeps_billing_header_for_first_party_anthropic(): + config = AnthropicConfig() + assert config.should_strip_billing_metadata() is False + + result = config.translate_system_message( + messages=_system_with_billing_header( + "You are Claude Code, Anthropic's official CLI for Claude." + ) + ) + + texts = [block["text"] for block in result] + assert any(t.startswith("x-anthropic-billing-header:") for t in texts) + assert "You are Claude Code, Anthropic's official CLI for Claude." in texts + + +def test_translate_system_message_strips_billing_header_for_bedrock(): + from litellm.llms.bedrock.claude_platform.transformation import ( + BedrockClaudePlatformConfig, + ) + + config = BedrockClaudePlatformConfig() + assert config.should_strip_billing_metadata() is True + + result = config.translate_system_message( + messages=_system_with_billing_header("real system prompt") + ) + + texts = [block["text"] for block in result] + assert all(not t.startswith("x-anthropic-billing-header:") for t in texts) + assert "real system prompt" in texts + + +def test_anthropic_messages_request_keeps_billing_header_for_first_party(): + from litellm.types.router import GenericLiteLLMParams + + config = AnthropicMessagesConfig() + assert config.should_strip_billing_metadata() is False + + optional_params = { + "max_tokens": 16, + "system": [ + BILLING_HEADER_BLOCK, + {"type": "text", "text": "real system prompt"}, + ], + } + result = config.transform_anthropic_messages_request( + model="claude-3-5-sonnet-latest", + messages=[{"role": "user", "content": "hi"}], + anthropic_messages_optional_request_params=optional_params, + litellm_params=GenericLiteLLMParams(), + headers={}, + ) + + texts = [block["text"] for block in result["system"]] + assert any(t.startswith("x-anthropic-billing-header:") for t in texts) + + +def test_anthropic_messages_request_strips_billing_header_for_minimax(): + from litellm.llms.minimax.messages.transformation import MinimaxMessagesConfig + from litellm.types.router import GenericLiteLLMParams + + config = MinimaxMessagesConfig() + assert config.should_strip_billing_metadata() is True + + optional_params = { + "max_tokens": 16, + "system": [ + BILLING_HEADER_BLOCK, + {"type": "text", "text": "real system prompt"}, + ], + } + result = config.transform_anthropic_messages_request( + model="MiniMax-M2", + messages=[{"role": "user", "content": "hi"}], + anthropic_messages_optional_request_params=optional_params, + litellm_params=GenericLiteLLMParams(), + headers={}, + ) + + texts = [block["text"] for block in result.get("system", [])] + assert all(not t.startswith("x-anthropic-billing-header:") for t in texts) + + +def test_translate_system_message_strips_billing_header_for_bedrock_invoke(): + from litellm.llms.bedrock.chat.invoke_transformations.anthropic_claude3_transformation import ( + AmazonAnthropicClaudeConfig, + ) + + config = AmazonAnthropicClaudeConfig() + assert config.should_strip_billing_metadata() is True + + result = config.translate_system_message( + messages=_system_with_billing_header("real system prompt") + ) + + texts = [block["text"] for block in result] + assert all(not t.startswith("x-anthropic-billing-header:") for t in texts) + assert "real system prompt" in texts + + +@pytest.mark.parametrize( + "module_path, class_name, expected_strip", + [ + ("litellm.llms.anthropic.chat.transformation", "AnthropicConfig", False), + ( + "litellm.llms.anthropic.experimental_pass_through.messages.transformation", + "AnthropicMessagesConfig", + False, + ), + ( + "litellm.llms.bedrock.claude_platform.transformation", + "BedrockClaudePlatformConfig", + True, + ), + ( + "litellm.llms.bedrock.chat.invoke_transformations.anthropic_claude3_transformation", + "AmazonAnthropicClaudeConfig", + True, + ), + ( + "litellm.llms.vertex_ai.vertex_ai_partner_models.anthropic.transformation", + "VertexAIAnthropicConfig", + True, + ), + ( + "litellm.llms.azure_ai.anthropic.transformation", + "AzureAnthropicConfig", + True, + ), + ("litellm.llms.minimax.messages.transformation", "MinimaxMessagesConfig", True), + ( + "litellm.llms.azure_ai.anthropic.messages_transformation", + "AzureAnthropicMessagesConfig", + True, + ), + ( + "litellm.llms.deepseek.messages.transformation", + "DeepSeekAnthropicMessagesConfig", + True, + ), + ( + "litellm.llms.vertex_ai.vertex_ai_partner_models.anthropic.experimental_pass_through.transformation", + "VertexAIPartnerModelsAnthropicMessagesConfig", + True, + ), + ], +) +def test_should_strip_billing_metadata_by_provider( + module_path, class_name, expected_strip +): + import importlib + + config_cls = getattr(importlib.import_module(module_path), class_name) + assert config_cls().should_strip_billing_metadata() is expected_strip def test_namespace_tool_flat_nested_tools_are_extracted(): """Codex sends nested tools in flat format {type, name, description, parameters} with no 'function' wrapper. These must be normalized and mapped without raising KeyError: 'function'.""" diff --git a/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py b/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py index a6aa35ee6d..ed978113b8 100644 --- a/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py +++ b/tests/test_litellm/llms/bedrock/chat/test_converse_transformation.py @@ -1467,11 +1467,10 @@ def test_transform_request_with_function_tool(): ) # Verify the structure - assert "additionalModelRequestFields" in request_data - additional_fields = request_data["additionalModelRequestFields"] + # Function tools are not computer use tools, so they don't get anthropic_beta — + # additionalModelRequestFields should be absent (not serialized as empty {}) + assert "additionalModelRequestFields" not in request_data - # Function tools are not computer use tools, so they don't get anthropic_beta - # They are processed through the regular tool config assert "toolConfig" in request_data assert "tools" in request_data["toolConfig"] assert len(request_data["toolConfig"]["tools"]) == 1 diff --git a/tests/test_litellm/llms/bedrock/files/test_bedrock_files_transformation.py b/tests/test_litellm/llms/bedrock/files/test_bedrock_files_transformation.py index ba41fc47e8..4731be13e7 100644 --- a/tests/test_litellm/llms/bedrock/files/test_bedrock_files_transformation.py +++ b/tests/test_litellm/llms/bedrock/files/test_bedrock_files_transformation.py @@ -128,6 +128,70 @@ class TestBedrockFilesTransformation: # Must have messages assert "messages" in model_input + # Nova Pro rejects empty additionalModelRequestFields / system — they must be absent + assert ( + "additionalModelRequestFields" not in model_input + ), "Nova: empty additionalModelRequestFields must be omitted, not serialized as {}" + assert ( + "system" not in model_input + ), "Nova: empty system must be omitted, not serialized as []" + + def test_nova_batch_jsonl_omits_empty_converse_fields(self): + """ + Regression test: Amazon Nova Pro returns 400 Malformed input request when + additionalModelRequestFields or system are present but empty in the Converse + API payload. The proxy must strip these keys when they carry no data. + """ + from litellm.llms.bedrock.files.transformation import BedrockFilesConfig + + config = BedrockFilesConfig() + + openai_jsonl_content = [ + { + "custom_id": "req-0", + "method": "POST", + "url": "/v1/chat/completions", + "body": { + "model": "us.amazon.nova-pro-v1:0", + "messages": [ + { + "role": "user", + "content": "What is 1 + 1? Answer with just the number.", + } + ], + "max_tokens": 16, + }, + } + ] + + result = config._transform_openai_jsonl_content_to_bedrock_jsonl_content( + openai_jsonl_content + ) + + assert len(result) == 1 + model_input = result[0]["modelInput"] + + assert ( + "additionalModelRequestFields" not in model_input + or model_input["additionalModelRequestFields"] + ), "additionalModelRequestFields must be absent or non-empty — Nova rejects {}" + assert ( + "system" not in model_input or model_input["system"] + ), "system must be absent or non-empty — Nova rejects []" + + # Validate the exact shape AWS accepts + assert model_input == { + "messages": [ + { + "role": "user", + "content": [ + {"text": "What is 1 + 1? Answer with just the number."} + ], + } + ], + "inferenceConfig": {"maxTokens": 16}, + } + def test_nova_image_content_uses_converse_image_blocks(self): """ Test that image_url content blocks are converted to Bedrock Converse diff --git a/tests/test_litellm/llms/bedrock_mantle/test_bedrock_mantle_responses_transformation.py b/tests/test_litellm/llms/bedrock_mantle/test_bedrock_mantle_responses_transformation.py index e2133d56f8..92b5ca7b10 100644 --- a/tests/test_litellm/llms/bedrock_mantle/test_bedrock_mantle_responses_transformation.py +++ b/tests/test_litellm/llms/bedrock_mantle/test_bedrock_mantle_responses_transformation.py @@ -12,6 +12,11 @@ import sys sys.path.insert(0, os.path.abspath("../../../../..")) import pytest +from botocore.exceptions import ( + ConnectTimeoutError, + PartialCredentialsError, + ProfileNotFound, +) import litellm from litellm.llms.bedrock_mantle.responses.transformation import ( @@ -114,16 +119,15 @@ class TestBedrockMantleResponsesAuth: ) assert headers["Authorization"] == "Bearer bearer-key" - def test_missing_key_raises(self, monkeypatch): + def test_missing_bearer_does_not_raise_in_validate_environment(self, monkeypatch): + # SigV4 may still apply, so validate_environment must defer instead of raising. monkeypatch.delenv("BEDROCK_MANTLE_API_KEY", raising=False) monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) cfg = BedrockMantleResponsesAPIConfig() - with pytest.raises(ValueError, match="Bedrock Mantle API key"): - cfg.validate_environment( - headers={}, - model="openai.gpt-5.5", - litellm_params=GenericLiteLLMParams(), - ) + headers = cfg.validate_environment( + headers={}, model="openai.gpt-5.5", litellm_params=GenericLiteLLMParams() + ) + assert "Authorization" not in headers def test_custom_llm_provider(self): cfg = BedrockMantleResponsesAPIConfig() @@ -261,6 +265,386 @@ def local_cost_map(monkeypatch): litellm.get_model_info.cache_clear() +class TestBedrockMantleResponsesSigV4: + def test_bearer_short_circuits_without_credentials(self, monkeypatch): + from unittest.mock import MagicMock + from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM + + monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) + monkeypatch.delenv("BEDROCK_MANTLE_API_KEY", raising=False) + + signer = BaseAWSLLM() + signer.get_credentials = MagicMock( + side_effect=AssertionError("get_credentials must not run for bearer auth") + ) + cfg = BedrockMantleResponsesAPIConfig(aws_signer=signer) + + headers, signed_body = cfg.sign_request( + headers={}, + optional_params={}, + request_data={"input": "hi"}, + api_base="https://bedrock-mantle.us-east-2.api.aws/openai/v1/responses", + api_key="bearer-from-config", + ) + assert headers["Authorization"] == "Bearer bearer-from-config" + assert signed_body == b'{"input": "hi"}' + signer.get_credentials.assert_not_called() + + def test_bearer_resolved_from_mantle_env_key(self, monkeypatch): + from unittest.mock import MagicMock + from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM + + monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) + monkeypatch.setenv("BEDROCK_MANTLE_API_KEY", "env-bearer") + + signer = BaseAWSLLM() + signer.get_credentials = MagicMock( + side_effect=AssertionError("get_credentials must not run for bearer auth") + ) + cfg = BedrockMantleResponsesAPIConfig(aws_signer=signer) + + headers, _ = cfg.sign_request( + headers={}, + optional_params={}, + request_data={"input": "hi"}, + api_base="https://bedrock-mantle.us-east-2.api.aws/openai/v1/responses", + api_key=None, + ) + assert headers["Authorization"] == "Bearer env-bearer" + + def test_bearer_arg_takes_priority_over_mantle_env_key(self, monkeypatch): + # The passed api_key (e.g. litellm_params.api_key) must win over the env + # bearer; a reordered precedence chain would silently use the wrong token. + from unittest.mock import MagicMock + from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM + + monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) + monkeypatch.setenv("BEDROCK_MANTLE_API_KEY", "env-bearer") + + signer = BaseAWSLLM() + signer.get_credentials = MagicMock( + side_effect=AssertionError("get_credentials must not run for bearer auth") + ) + cfg = BedrockMantleResponsesAPIConfig(aws_signer=signer) + + headers, _ = cfg.sign_request( + headers={}, + optional_params={}, + request_data={"input": "hi"}, + api_base="https://bedrock-mantle.us-east-2.api.aws/openai/v1/responses", + api_key="arg-bearer", + ) + assert headers["Authorization"] == "Bearer arg-bearer" + signer.get_credentials.assert_not_called() + + def test_access_key_produces_sigv4_headers(self, monkeypatch): + from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM + + monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) + monkeypatch.delenv("BEDROCK_MANTLE_API_KEY", raising=False) + + cfg = BedrockMantleResponsesAPIConfig(aws_signer=BaseAWSLLM()) + headers, signed_body = cfg.sign_request( + headers={}, + optional_params={ + "aws_access_key_id": "AKIAEXAMPLE", + "aws_secret_access_key": "c2VjcmV0LXRlc3Qtc2VjcmV0LXRlc3Qtc2VjcmV0", + "aws_session_token": "session-token-test", + "aws_region_name": "us-east-2", + }, + request_data={"input": "hi"}, + api_base="https://bedrock-mantle.us-east-2.api.aws/openai/v1/responses", + api_key=None, + ) + assert headers["Authorization"].startswith("AWS4-HMAC-SHA256") + assert "Credential=AKIAEXAMPLE/" in headers["Authorization"] + assert "/us-east-2/bedrock/aws4_request" in headers["Authorization"] + assert "X-Amz-Date" in headers + assert headers["X-Amz-Security-Token"] == "session-token-test" + assert signed_body == b'{"input": "hi"}' + + def test_assume_role_path_produces_sigv4_headers(self, monkeypatch): + from unittest.mock import MagicMock + from botocore.credentials import Credentials + from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM + + monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) + monkeypatch.delenv("BEDROCK_MANTLE_API_KEY", raising=False) + + signer = BaseAWSLLM() + signer.get_credentials = MagicMock( + return_value=Credentials( + access_key="ASIAEXAMPLE", + secret_key="YXNzdW1lZC1yb2xlLXNlY3JldC1hc3N1bWVk", + token="assumed-session-token", + ) + ) + cfg = BedrockMantleResponsesAPIConfig(aws_signer=signer) + + headers, _ = cfg.sign_request( + headers={}, + optional_params={ + "aws_role_name": "arn:aws:iam::000000000000:role/test-role", + "aws_session_name": "litellm-test", + "aws_region_name": "us-east-2", + }, + request_data={"input": "hi"}, + api_base="https://bedrock-mantle.us-east-2.api.aws/openai/v1/responses", + api_key=None, + ) + signer.get_credentials.assert_called_once() + call = signer.get_credentials.call_args.kwargs + assert call["aws_role_name"] == "arn:aws:iam::000000000000:role/test-role" + assert call["aws_session_name"] == "litellm-test" + assert headers["Authorization"].startswith("AWS4-HMAC-SHA256") + assert "/us-east-2/bedrock/aws4_request" in headers["Authorization"] + + def test_signed_body_matches_final_data_after_normalize(self, monkeypatch): + """Core regression: the signed bytes must equal the bytes actually sent. + + Sign the *final* data dict and assert the returned signed_body decodes to + exactly that dict, so a later change to the data would break the SigV4 hash. + """ + import json + from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM + + monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) + monkeypatch.delenv("BEDROCK_MANTLE_API_KEY", raising=False) + + final_data = {"model": "openai.gpt-5.5", "input": "hi", "max_output_tokens": 16} + cfg = BedrockMantleResponsesAPIConfig(aws_signer=BaseAWSLLM()) + _, signed_body = cfg.sign_request( + headers={}, + optional_params={ + "aws_access_key_id": "AKIAEXAMPLE", + "aws_secret_access_key": "c2VjcmV0LXRlc3Qtc2VjcmV0LXRlc3Qtc2VjcmV0", + "aws_region_name": "us-east-2", + }, + request_data=final_data, + api_base="https://bedrock-mantle.us-east-2.api.aws/openai/v1/responses", + api_key=None, + ) + assert signed_body is not None + assert json.loads(signed_body) == final_data + + def test_region_comes_from_optional_params(self, monkeypatch): + from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM + + monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) + monkeypatch.delenv("BEDROCK_MANTLE_API_KEY", raising=False) + monkeypatch.delenv("AWS_REGION", raising=False) + monkeypatch.delenv("AWS_REGION_NAME", raising=False) + + cfg = BedrockMantleResponsesAPIConfig(aws_signer=BaseAWSLLM()) + headers, _ = cfg.sign_request( + headers={}, + optional_params={ + "aws_access_key_id": "AKIAEXAMPLE", + "aws_secret_access_key": "c2VjcmV0LXRlc3Qtc2VjcmV0LXRlc3Qtc2VjcmV0", + "aws_region_name": "eu-west-1", + }, + request_data={"input": "hi"}, + api_base="https://bedrock-mantle.eu-west-1.api.aws/openai/v1/responses", + api_key=None, + ) + assert "/eu-west-1/bedrock/aws4_request" in headers["Authorization"] + + def test_url_region_and_sigv4_region_agree_from_litellm_params(self, monkeypatch): + """Adversarial-review regression: a caller-supplied aws_region_name (no region + env set) must shape BOTH the URL host and the SigV4 credential scope, or the + request is signed for one region and sent to another -> 401. + """ + monkeypatch.delenv("BEDROCK_MANTLE_REGION", raising=False) + monkeypatch.delenv("BEDROCK_MANTLE_API_BASE", raising=False) + monkeypatch.delenv("AWS_REGION", raising=False) + monkeypatch.delenv("AWS_REGION_NAME", raising=False) + monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) + monkeypatch.delenv("BEDROCK_MANTLE_API_KEY", raising=False) + + from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM + + params = { + "aws_region_name": "ap-southeast-2", + "aws_access_key_id": "AKIAEXAMPLE", + "aws_secret_access_key": "c2VjcmV0LXRlc3Qtc2VjcmV0LXRlc3Qtc2VjcmV0", + } + cfg = BedrockMantleResponsesAPIConfig(aws_signer=BaseAWSLLM()) + url = cfg.get_complete_url(api_base=None, litellm_params=params) + assert ( + url == "https://bedrock-mantle.ap-southeast-2.api.aws/openai/v1/responses" + ) + + headers, _ = cfg.sign_request( + headers={}, + optional_params=params, + request_data={"input": "hi"}, + api_base=url, + api_key=None, + ) + assert "/ap-southeast-2/bedrock/aws4_request" in headers["Authorization"] + + def test_injected_default_region_base_does_not_override_aws_region_name( + self, monkeypatch + ): + """2nd-round adversarial regression: responses/main.py auto-injects + litellm_params.api_base = https://bedrock-mantle..api.aws/v1 (default + region, ignoring aws_region_name). The config must still pin BOTH the URL host + and the SigV4 scope to aws_region_name, or the IAM deployment 401s. A naive + 'resolve region only when api_base is None' fix would fail this test. + """ + monkeypatch.delenv("BEDROCK_MANTLE_REGION", raising=False) + monkeypatch.delenv("BEDROCK_MANTLE_API_BASE", raising=False) + monkeypatch.delenv("AWS_REGION", raising=False) + monkeypatch.delenv("AWS_REGION_NAME", raising=False) + monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) + monkeypatch.delenv("BEDROCK_MANTLE_API_KEY", raising=False) + + from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM + + injected_base = "https://bedrock-mantle.us-east-1.api.aws/v1" # default region + params = { + "aws_region_name": "us-east-2", # what the caller actually wants + "api_base": injected_base, + "aws_access_key_id": "AKIAEXAMPLE", + "aws_secret_access_key": "c2VjcmV0LXRlc3Qtc2VjcmV0LXRlc3Qtc2VjcmV0", + } + cfg = BedrockMantleResponsesAPIConfig(aws_signer=BaseAWSLLM()) + url = cfg.get_complete_url(api_base=injected_base, litellm_params=params) + assert url == "https://bedrock-mantle.us-east-2.api.aws/openai/v1/responses" + + headers, _ = cfg.sign_request( + headers={}, + optional_params=params, + request_data={"input": "hi"}, + api_base=url, + api_key=None, + ) + assert "/us-east-2/bedrock/aws4_request" in headers["Authorization"] + assert "us-east-1" not in headers["Authorization"] + + def test_custom_proxy_host_is_preserved(self, monkeypatch): + """A genuinely custom (non-Mantle) api_base host must be preserved, not rewritten + to a bedrock-mantle host. Only standard Mantle hosts are region-pinned. + """ + monkeypatch.delenv("BEDROCK_MANTLE_API_BASE", raising=False) + cfg = BedrockMantleResponsesAPIConfig() + url = cfg.get_complete_url( + api_base="https://mantle-proxy.internal.example/openai/v1", + litellm_params={"aws_region_name": "us-east-2"}, + ) + assert url == "https://mantle-proxy.internal.example/openai/v1/responses" + + def test_caller_authorization_does_not_override_sigv4(self, monkeypatch): + """Adversarial-review regression: a caller-supplied Authorization header (e.g. + from extra_headers, surviving the relaxed validate_environment) must not clobber + the SigV4 Authorization that _sign_request would otherwise restore. + """ + monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) + monkeypatch.delenv("BEDROCK_MANTLE_API_KEY", raising=False) + + from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM + + cfg = BedrockMantleResponsesAPIConfig(aws_signer=BaseAWSLLM()) + headers, _ = cfg.sign_request( + headers={"Authorization": "Bearer stale-caller-token"}, + optional_params={ + "aws_access_key_id": "AKIAEXAMPLE", + "aws_secret_access_key": "c2VjcmV0LXRlc3Qtc2VjcmV0LXRlc3Qtc2VjcmV0", + "aws_region_name": "us-east-2", + }, + request_data={"input": "hi"}, + api_base="https://bedrock-mantle.us-east-2.api.aws/openai/v1/responses", + api_key=None, + ) + assert headers["Authorization"].startswith("AWS4-HMAC-SHA256") + assert "Bearer stale-caller-token" not in headers["Authorization"] + + def test_no_bearer_and_no_credentials_raises_both_paths(self, monkeypatch): + from unittest.mock import MagicMock + from botocore.exceptions import NoCredentialsError + from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM + + monkeypatch.delenv("BEDROCK_MANTLE_API_KEY", raising=False) + monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) + + signer = BaseAWSLLM() + signer.get_credentials = MagicMock(side_effect=NoCredentialsError()) + cfg = BedrockMantleResponsesAPIConfig(aws_signer=signer) + + with pytest.raises(ValueError) as exc: + cfg.sign_request( + headers={}, + optional_params={"aws_region_name": "us-east-2"}, + request_data={"input": "hi"}, + api_base="https://bedrock-mantle.us-east-2.api.aws/openai/v1/responses", + api_key=None, + ) + msg = str(exc.value) + assert "Bearer" in msg + assert "SigV4" in msg or "IAM" in msg + + @pytest.mark.parametrize( + "cred_error", + [ + PartialCredentialsError(provider="env", cred_var="aws_secret_access_key"), + ProfileNotFound(profile="missing-profile"), + ], + ) + def test_partial_credentials_raises_both_paths(self, monkeypatch, cred_error): + from unittest.mock import MagicMock + from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM + + monkeypatch.delenv("BEDROCK_MANTLE_API_KEY", raising=False) + monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) + + signer = BaseAWSLLM() + signer.get_credentials = MagicMock(side_effect=cred_error) + cfg = BedrockMantleResponsesAPIConfig(aws_signer=signer) + + with pytest.raises(ValueError) as exc: + cfg.sign_request( + headers={}, + optional_params={"aws_region_name": "us-east-2"}, + request_data={"input": "hi"}, + api_base="https://bedrock-mantle.us-east-2.api.aws/openai/v1/responses", + api_key=None, + ) + msg = str(exc.value) + assert "Bearer" in msg + assert "SigV4" in msg or "IAM" in msg + + def test_sts_transport_error_is_not_masked_as_credentials(self, monkeypatch): + # An AssumeRole / web-identity flow hits STS over the network, so a transient + # connection error must surface as itself, not be rewritten into the + # "no usable AWS credentials" message that would send the user to fix the + # wrong thing. + from unittest.mock import MagicMock + from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM + + monkeypatch.delenv("BEDROCK_MANTLE_API_KEY", raising=False) + monkeypatch.delenv("AWS_BEARER_TOKEN_BEDROCK", raising=False) + + signer = BaseAWSLLM() + signer.get_credentials = MagicMock( + side_effect=ConnectTimeoutError( + endpoint_url="https://sts.us-east-2.amazonaws.com" + ) + ) + cfg = BedrockMantleResponsesAPIConfig(aws_signer=signer) + + with pytest.raises(ConnectTimeoutError): + cfg.sign_request( + headers={}, + optional_params={ + "aws_role_name": "arn:aws:iam::000000000000:role/test-role", + "aws_region_name": "us-east-2", + }, + request_data={"input": "hi"}, + api_base="https://bedrock-mantle.us-east-2.api.aws/openai/v1/responses", + api_key=None, + ) + + class TestBedrockMantleResponsesPricing: def test_gpt_5_5_pricing_and_mode(self, local_cost_map): info = litellm.get_model_info("bedrock_mantle/openai.gpt-5.5") diff --git a/tests/test_litellm/llms/custom_httpx/test_llm_http_handler.py b/tests/test_litellm/llms/custom_httpx/test_llm_http_handler.py index 279e9730e6..7321abcee4 100644 --- a/tests/test_litellm/llms/custom_httpx/test_llm_http_handler.py +++ b/tests/test_litellm/llms/custom_httpx/test_llm_http_handler.py @@ -742,3 +742,241 @@ async def test_anthropic_post_retry_reserializes_mutated_body(): assert first_sent == prebuilt # attempt 0 used prebuilt assert second_sent == _json.dumps(request_body) # attempt 1 re-serialized assert "MUTATED" in second_sent # ... the mutated body + + +def test_base_responses_config_sign_request_is_noop_by_default(): + """Default responses sign_request must be a no-op: unchanged headers, no signed body. + + Guards the 15 existing responses providers from accidental signing when the + handler starts calling sign_request. + """ + from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig + + cfg = OpenAIResponsesAPIConfig() + headers = {"Authorization": "Bearer sk-existing"} + out_headers, signed_body = cfg.sign_request( + headers=headers, + optional_params={}, + request_data={"input": "hi"}, + api_base="https://api.openai.com/v1/responses", + ) + assert out_headers == {"Authorization": "Bearer sk-existing"} + assert signed_body is None + + +def _make_responses_handler_call(signed_body): + """Drive BaseLLMHTTPHandler.response_api_handler with a fully mocked provider + config + sync client, returning the kwargs the client.post was called with. + + signed_body=None simulates a no-op (non-signing) provider; bytes simulates a + signing provider (e.g. Bedrock Mantle). + """ + from unittest.mock import MagicMock + from litellm.llms.custom_httpx.http_handler import HTTPHandler + from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler + from litellm.types.router import GenericLiteLLMParams + + provider_config = MagicMock() + provider_config.validate_environment.return_value = {} + provider_config.get_complete_url.return_value = ( + "https://bedrock-mantle.us-east-2.api.aws/openai/v1/responses" + ) + provider_config.transform_responses_api_request.return_value = {"input": "hi"} + provider_config.should_fake_stream.return_value = False + provider_config.sign_request.return_value = ({"X-Signed": "1"}, signed_body) + + mock_client = MagicMock(spec=HTTPHandler) + mock_client.post.return_value = MagicMock() + + handler = BaseLLMHTTPHandler() + handler.response_api_handler( + model="openai.gpt-5.5", + input="hi", + responses_api_provider_config=provider_config, + response_api_optional_request_params={}, + custom_llm_provider="bedrock_mantle", + litellm_params=GenericLiteLLMParams(aws_region_name="us-east-2"), + logging_obj=MagicMock(), + client=mock_client, + _is_async=False, + ) + return mock_client.post.call_args.kwargs + + +def test_responses_handler_sends_json_when_not_signed(): + """No-op provider (signed_body is None) -> handler posts json=data, no data= bytes.""" + kwargs = _make_responses_handler_call(signed_body=None) + assert kwargs.get("json") == {"input": "hi"} + assert "data" not in kwargs + + +def test_responses_handler_sends_signed_bytes_when_signed(): + """Signing provider -> handler posts the exact signed bytes via data=, not json=.""" + kwargs = _make_responses_handler_call(signed_body=b'{"input": "hi"}') + assert kwargs.get("data") == b'{"input": "hi"}' + assert "json" not in kwargs + assert kwargs["headers"] == {"X-Signed": "1"} + + +def test_responses_handler_signs_after_fake_stream_prep_strips_stream(): + """Fake-stream signing-order invariant: the bytes SIGNED must equal the bytes SENT. + + In the streaming + fake-stream path the handler first runs + _prepare_fake_stream_request, which pops "stream" out of the body, and only + then calls sign_request. If signing ran before that pop, the signed body + would still carry "stream" while the body sent over the wire would not, + producing a SigV4 payload-hash mismatch (401) for a real Mantle deployment. + We snapshot request_data at sign time and assert "stream" is already gone. + """ + from unittest.mock import MagicMock + from litellm.llms.custom_httpx.http_handler import HTTPHandler + from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler + from litellm.types.llms.openai import ResponsesAPIResponse + from litellm.types.router import GenericLiteLLMParams + + provider_config = MagicMock() + provider_config.validate_environment.return_value = {} + provider_config.get_complete_url.return_value = ( + "https://bedrock-mantle.us-east-2.api.aws/openai/v1/responses" + ) + provider_config.transform_responses_api_request.return_value = { + "input": "hi", + "stream": True, + } + provider_config.should_fake_stream.return_value = True + provider_config.transform_response_api_response.return_value = ResponsesAPIResponse( + id="resp_1", + created_at=0, + output=[], + status="completed", + model="openai.gpt-5.5", + ) + + captured = {} + + def _capture_sign(**kwargs): + captured["request_data"] = dict(kwargs["request_data"]) + return ({"X-Signed": "1"}, b'{"input": "hi"}') + + provider_config.sign_request.side_effect = _capture_sign + + mock_client = MagicMock(spec=HTTPHandler) + mock_client.post.return_value = MagicMock() + + handler = BaseLLMHTTPHandler() + handler.response_api_handler( + model="openai.gpt-5.5", + input="hi", + responses_api_provider_config=provider_config, + response_api_optional_request_params={"stream": True}, + custom_llm_provider="bedrock_mantle", + litellm_params=GenericLiteLLMParams(aws_region_name="us-east-2"), + logging_obj=MagicMock(), + client=mock_client, + _is_async=False, + fake_stream=True, + ) + + assert "stream" not in captured["request_data"] + assert "input" in captured["request_data"] + + post_kwargs = mock_client.post.call_args.kwargs + assert post_kwargs.get("data") == b'{"input": "hi"}' + assert "json" not in post_kwargs + assert "stream" in post_kwargs + + +def _make_compact_handler_call(signed_body, is_async): + """Drive (async_)compact_response_api_handler with a fully mocked provider config + + client, returning the kwargs the client.post was called with. + + signed_body=None simulates a no-op (non-signing) provider; bytes simulates a + signing provider (e.g. Bedrock Mantle SigV4 / bearer). + """ + from unittest.mock import MagicMock + from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler + from litellm.llms.custom_httpx.llm_http_handler import BaseLLMHTTPHandler + from litellm.types.router import GenericLiteLLMParams + + compact_url = "https://bedrock-mantle.us-east-2.api.aws/openai/v1/responses/compact" + provider_config = MagicMock() + provider_config.validate_environment.return_value = {} + provider_config.get_complete_url.return_value = ( + "https://bedrock-mantle.us-east-2.api.aws/openai/v1/responses" + ) + provider_config.transform_compact_response_api_request.return_value = ( + compact_url, + {"model": "openai.gpt-5.5", "input": "hi"}, + ) + provider_config.sign_request.return_value = ({"X-Signed": "1"}, signed_body) + provider_config.transform_compact_response_api_response.return_value = "ok" + + spec = AsyncHTTPHandler if is_async else HTTPHandler + mock_client = MagicMock(spec=spec) + if is_async: + mock_client.post = AsyncMock(return_value=MagicMock()) + else: + mock_client.post.return_value = MagicMock() + + handler = BaseLLMHTTPHandler() + result = handler.compact_response_api_handler( + model="openai.gpt-5.5", + input="hi", + responses_api_provider_config=provider_config, + response_api_optional_request_params={}, + custom_llm_provider="bedrock_mantle", + litellm_params=GenericLiteLLMParams(aws_region_name="us-east-2"), + logging_obj=MagicMock(), + client=mock_client, + _is_async=is_async, + ) + if is_async: + asyncio.run(result) + return provider_config, mock_client.post.call_args.kwargs + + +def test_compact_handler_sends_json_when_not_signed(): + """No-op provider on compact (signed_body is None) -> posts json=data, no data= bytes.""" + provider_config, kwargs = _make_compact_handler_call( + signed_body=None, is_async=False + ) + provider_config.sign_request.assert_called_once() + assert kwargs.get("json") == {"model": "openai.gpt-5.5", "input": "hi"} + assert "data" not in kwargs + + +def test_compact_handler_sends_signed_bytes_when_signed(): + """Signing provider on compact -> posts the signed bytes via data=, not json=. + + Regression for the adversarial-review finding that /responses/compact bypassed + the SigV4 signing hook, so IAM-only Mantle callers sent unsigned bodies. + """ + provider_config, kwargs = _make_compact_handler_call( + signed_body=b'{"model": "openai.gpt-5.5", "input": "hi"}', is_async=False + ) + assert kwargs.get("data") == b'{"model": "openai.gpt-5.5", "input": "hi"}' + assert "json" not in kwargs + assert kwargs["headers"] == {"X-Signed": "1"} + # signing must use the compact endpoint as api_base, not the create URL + assert provider_config.sign_request.call_args.kwargs["api_base"].endswith( + "/openai/v1/responses/compact" + ) + + +def test_async_compact_handler_sends_signed_bytes_when_signed(): + """Async compact must sign identically to sync (same omission in the async twin).""" + provider_config, kwargs = _make_compact_handler_call( + signed_body=b'{"model": "openai.gpt-5.5", "input": "hi"}', is_async=True + ) + assert kwargs.get("data") == b'{"model": "openai.gpt-5.5", "input": "hi"}' + assert "json" not in kwargs + assert kwargs["headers"] == {"X-Signed": "1"} + + +def test_async_compact_handler_sends_json_when_not_signed(): + """Async no-op provider on compact -> posts json=data, no data= bytes.""" + _provider_config, kwargs = _make_compact_handler_call( + signed_body=None, is_async=True + ) + assert kwargs.get("json") == {"model": "openai.gpt-5.5", "input": "hi"} + assert "data" not in kwargs diff --git a/tests/test_litellm/llms/gemini/test_gemini_tts.py b/tests/test_litellm/llms/gemini/test_gemini_tts.py index 65eefca5af..98f3ac0f4e 100644 --- a/tests/test_litellm/llms/gemini/test_gemini_tts.py +++ b/tests/test_litellm/llms/gemini/test_gemini_tts.py @@ -80,6 +80,46 @@ class TestGeminiTTSTransformation: assert "responseModalities" in result assert "AUDIO" in result["responseModalities"] + def test_gemini_tts_audio_parameter_mapping_with_language_code(self): + config = GoogleAIStudioGeminiConfig() + + non_default_params = { + "audio": {"voice": "Kore", "format": "pcm16", "language_code": "en-US"} + } + optional_params = {} + + result = config.map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model="gemini-2.5-flash-preview-tts", + drop_params=False, + ) + + assert "speechConfig" in result + assert result["speechConfig"]["languageCode"] == "en-US" + assert ( + result["speechConfig"]["voiceConfig"]["prebuiltVoiceConfig"]["voiceName"] + == "Kore" + ) + + def test_map_audio_params_language_code(self): + config = GoogleAIStudioGeminiConfig() + + result = config._map_audio_params( + {"voice": "Kore", "format": "pcm16", "language_code": "de-DE"} + ) + + assert result["languageCode"] == "de-DE" + assert result["voiceConfig"]["prebuiltVoiceConfig"]["voiceName"] == "Kore" + + def test_map_audio_params_no_language_code(self): + config = GoogleAIStudioGeminiConfig() + + result = config._map_audio_params({"voice": "Kore", "format": "pcm16"}) + + assert "languageCode" not in result + assert result["voiceConfig"]["prebuiltVoiceConfig"]["voiceName"] == "Kore" + def test_gemini_tts_audio_parameter_with_existing_modalities(self): """Test audio parameter mapping when modalities already exist""" config = GoogleAIStudioGeminiConfig() @@ -328,5 +368,57 @@ class TestGeminiTTSSpeechConfigInRequestBody: assert "AUDIO" in generation_config["responseModalities"] + @pytest.mark.parametrize( + "model,custom_llm_provider", + [ + ("gemini-2.5-flash-tts", "vertex_ai"), + ("gemini-2.5-flash-tts", "gemini"), + ("gemini-2.5-flash-preview-tts", "vertex_ai"), + ], + ) + def test_language_code_end_to_end_mapping(self, model, custom_llm_provider): + from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import ( + VertexGeminiConfig, + ) + from litellm.llms.vertex_ai.gemini.transformation import ( + _transform_request_body, + ) + + config = VertexGeminiConfig() + + non_default_params = { + "audio": {"voice": "Puck", "format": "pcm16", "language_code": "pt-BR"} + } + optional_params = {} + + mapped_params = config.map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + model=model, + drop_params=False, + ) + + assert mapped_params["speechConfig"]["languageCode"] == "pt-BR" + + request_body = _transform_request_body( + messages=[{"role": "user", "content": "Hello world"}], + model=model, + optional_params=mapped_params, + custom_llm_provider=custom_llm_provider, + litellm_params={}, + cached_content=None, + ) + + generation_config = request_body["generationConfig"] + assert generation_config["speechConfig"]["languageCode"] == "pt-BR" + assert ( + generation_config["speechConfig"]["voiceConfig"]["prebuiltVoiceConfig"][ + "voiceName" + ] + == "Puck" + ) + assert "AUDIO" in generation_config["responseModalities"] + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_litellm/llms/github_copilot/responses/test_github_copilot_responses_transformation.py b/tests/test_litellm/llms/github_copilot/responses/test_github_copilot_responses_transformation.py index 54e7170bb2..17373f24a9 100644 --- a/tests/test_litellm/llms/github_copilot/responses/test_github_copilot_responses_transformation.py +++ b/tests/test_litellm/llms/github_copilot/responses/test_github_copilot_responses_transformation.py @@ -14,6 +14,8 @@ from unittest.mock import patch, MagicMock sys.path.insert(0, os.path.abspath("../../../../..")) import pytest +import litellm +from litellm.litellm_core_utils.get_model_cost_map import get_model_cost_map from litellm.types.utils import LlmProviders from litellm.utils import ProviderConfigManager from litellm.llms.github_copilot.responses.transformation import ( @@ -22,13 +24,26 @@ from litellm.llms.github_copilot.responses.transformation import ( from litellm.types.llms.openai import ResponsesAPIOptionalRequestParams +@pytest.fixture(autouse=True) +def use_local_model_cost_map(monkeypatch: pytest.MonkeyPatch): + """Pin litellm.model_cost to the bundled local backup so tests don't depend + on remote catalog fetches (and don't change behavior across remote refreshes).""" + monkeypatch.setenv("LITELLM_LOCAL_MODEL_COST_MAP", "True") + monkeypatch.setattr( + litellm, "model_cost", get_model_cost_map(url=litellm.model_cost_map_url) + ) + litellm.add_known_models(model_cost_map=litellm.model_cost) + + class TestGithubCopilotResponsesAPITransformation: """Test GitHub Copilot Responses API configuration and transformations""" def test_github_copilot_provider_config_registration(self): - """Test that GitHub Copilot provider returns GithubCopilotResponsesAPIConfig""" + """Test that GitHub Copilot provider returns the native Responses API + config for a Responses-capable catalog model. Exercises the full stack: + catalog lookup -> github_copilot_supports_responses_api -> native config.""" config = ProviderConfigManager.get_provider_responses_api_config( - model="github_copilot/gpt-5.1-codex", + model="github_copilot/gpt-5.3-codex", provider=LlmProviders.GITHUB_COPILOT, ) @@ -373,3 +388,200 @@ class TestGithubCopilotResponsesAPITransformation: # Non-reasoning items should pass through unchanged assert result == message_item + + +class TestGithubCopilotResponsesAPIRouting: + """``ProviderConfigManager.get_provider_responses_api_config`` for github_copilot + returns the native Responses config only when the model has ``mode=responses`` + in the (already-merged) model info; otherwise returns None so the dispatcher + routes through the chat-completions translation bridge.""" + + @patch( + "litellm.llms.github_copilot.responses.transformation._cached_get_model_info_helper" + ) + def test_returns_config_when_mode_is_responses(self, mock_get_info): + """``mode=responses`` returns native config.""" + mock_get_info.return_value = {"mode": "responses"} + config = ProviderConfigManager.get_provider_responses_api_config( + model="github_copilot/some-responses-model", + provider=LlmProviders.GITHUB_COPILOT, + ) + assert isinstance(config, GithubCopilotResponsesAPIConfig) + + @patch( + "litellm.llms.github_copilot.responses.transformation._cached_get_model_info_helper" + ) + def test_returns_none_when_mode_is_chat(self, mock_get_info): + """``mode=chat`` returns None so dispatcher uses bridge.""" + mock_get_info.return_value = {"mode": "chat"} + config = ProviderConfigManager.get_provider_responses_api_config( + model="github_copilot/some-chat-only-model", + provider=LlmProviders.GITHUB_COPILOT, + ) + assert config is None + + @patch( + "litellm.llms.github_copilot.responses.transformation._cached_get_model_info_helper" + ) + def test_returns_none_when_mode_is_unset_and_no_endpoints(self, mock_get_info): + """Entry without ``mode`` and without ``supported_endpoints`` returns None + (conservative default).""" + mock_get_info.return_value = {} + config = ProviderConfigManager.get_provider_responses_api_config( + model="github_copilot/some-model", + provider=LlmProviders.GITHUB_COPILOT, + ) + assert config is None + + def test_returns_config_when_mode_unset_but_endpoints_have_responses(self): + """``mode`` unset but ``supported_endpoints`` declaring /v1/responses + returns native config (endpoint-list fallback for stale-but-correct + catalog entries that lack ``mode``). + + Exercises the real ``_cached_get_model_info_helper`` plumbing via + ``register_model`` (no mock). ``supported_endpoints`` is not carried on + the normalized ``ModelInfoBase`` the helper returns, so the gate must + read it from the raw ``litellm.model_cost`` entry; a mock-based test + would mask that. + """ + litellm.register_model( + { + "github_copilot/test-endpoints-only-model": { + "litellm_provider": "github_copilot", + "max_tokens": 1, + "input_cost_per_token": 0, + "output_cost_per_token": 0, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/responses", + ], + } + } + ) + config = ProviderConfigManager.get_provider_responses_api_config( + model="github_copilot/test-endpoints-only-model", + provider=LlmProviders.GITHUB_COPILOT, + ) + assert isinstance(config, GithubCopilotResponsesAPIConfig) + + def test_mode_chat_overrides_endpoints_with_responses(self): + """``mode=chat`` is a hard opt-out: forces bridge even when + ``supported_endpoints`` includes /v1/responses. Lets users force the + bridge for dual-endpoint models without clearing endpoint metadata. + + Exercises the real ``_cached_get_model_info_helper`` plumbing via + ``register_model`` (no mock) so the ``mode``-over-endpoints precedence + is verified against the actual model-info resolution. + """ + litellm.register_model( + { + "github_copilot/test-chat-override-model": { + "litellm_provider": "github_copilot", + "max_tokens": 1, + "input_cost_per_token": 0, + "output_cost_per_token": 0, + "mode": "chat", + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/responses", + ], + } + } + ) + config = ProviderConfigManager.get_provider_responses_api_config( + model="github_copilot/test-chat-override-model", + provider=LlmProviders.GITHUB_COPILOT, + ) + assert config is None + + def test_returns_config_when_model_is_none(self): + """Follow-up GET/DELETE operations pass model=None and keep the native + config path (no per-model lookup is possible).""" + config = ProviderConfigManager.get_provider_responses_api_config( + model=None, + provider=LlmProviders.GITHUB_COPILOT, + ) + assert isinstance(config, GithubCopilotResponsesAPIConfig) + + @patch( + "litellm.llms.github_copilot.responses.transformation._cached_get_model_info_helper" + ) + def test_returns_none_when_get_model_info_raises(self, mock_get_info): + """Catalog lookup failure (model not registered) returns None + (conservative default; bridge handles unknown models safely).""" + mock_get_info.side_effect = Exception("model not in catalog") + config = ProviderConfigManager.get_provider_responses_api_config( + model="github_copilot/never-seen-model", + provider=LlmProviders.GITHUB_COPILOT, + ) + assert config is None + + @patch( + "litellm.llms.github_copilot.responses.transformation._cached_get_model_info_helper" + ) + def test_user_override_via_register_model(self, mock_get_info): + """User-supplied per-deployment ``model_info`` flows through + ``litellm.register_model`` (called by the router) into the merged + catalog read by ``_cached_get_model_info_helper``. Setting ``mode=responses`` + for a model whose catalog entry says ``mode=chat`` therefore opts in + to native dispatch without any per-call argument plumbing.""" + mock_get_info.return_value = {"mode": "responses"} + config = ProviderConfigManager.get_provider_responses_api_config( + model="github_copilot/some-chat-only-model", + provider=LlmProviders.GITHUB_COPILOT, + ) + assert isinstance(config, GithubCopilotResponsesAPIConfig) + + @patch( + "litellm.llms.github_copilot.responses.transformation._cached_get_model_info_helper" + ) + def test_realistic_chat_only_entry_returns_none(self, mock_get_info): + """Realistic ``model_prices_and_context_window.json`` shape for a + chat-only Copilot model (e.g. github_copilot/gemini-3.1-pro-preview) + returns None so /v1/responses calls fall back to the bridge.""" + mock_get_info.return_value = { + "litellm_provider": "github_copilot", + "max_input_tokens": 136000, + "max_output_tokens": 64000, + "max_tokens": 64000, + "mode": "chat", + "supported_endpoints": ["/v1/chat/completions"], + "supports_function_calling": True, + "supports_tool_choice": True, + "supports_parallel_function_calling": True, + "supports_vision": True, + "supports_reasoning": True, + } + config = ProviderConfigManager.get_provider_responses_api_config( + model="github_copilot/some-chat-only-model", + provider=LlmProviders.GITHUB_COPILOT, + ) + assert config is None + + @patch( + "litellm.llms.github_copilot.responses.transformation._cached_get_model_info_helper" + ) + def test_realistic_responses_only_entry_returns_config(self, mock_get_info): + """Realistic catalog entry for a Responses-only Copilot model + (e.g. github_copilot/gpt-5.5) returns the native config.""" + mock_get_info.return_value = { + "litellm_provider": "github_copilot", + "max_input_tokens": 272000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "responses", + "supported_endpoints": ["/v1/responses"], + "supports_function_calling": True, + "supports_tool_choice": True, + "supports_parallel_function_calling": True, + "supports_response_schema": True, + "supports_vision": True, + "supports_reasoning": True, + "supports_none_reasoning_effort": True, + "supports_xhigh_reasoning_effort": True, + } + config = ProviderConfigManager.get_provider_responses_api_config( + model="github_copilot/some-responses-only-model", + provider=LlmProviders.GITHUB_COPILOT, + ) + assert isinstance(config, GithubCopilotResponsesAPIConfig) diff --git a/tests/test_litellm/llms/parasail/test_parasail.py b/tests/test_litellm/llms/parasail/test_parasail.py new file mode 100644 index 0000000000..8fb9b22b5f --- /dev/null +++ b/tests/test_litellm/llms/parasail/test_parasail.py @@ -0,0 +1,172 @@ +import os +from unittest.mock import patch + +PARASAIL_API_BASE = "https://api.parasail.io/v1" +PARASAIL_RESPONSES_GATEWAY = "https://api-webflux.saas.parasail.io/v1" + + +def test_parasail_json_registry(): + import litellm + from litellm.llms.openai_like.json_loader import JSONProviderRegistry + + assert litellm.LlmProviders.PARASAIL.value == "parasail" + assert litellm.LlmProviders("parasail") == litellm.LlmProviders.PARASAIL + assert JSONProviderRegistry.exists("parasail") + config = JSONProviderRegistry.get("parasail") + assert config is not None + assert config.base_url == PARASAIL_API_BASE + assert config.api_key_env == "PARASAIL_API_KEY" + assert config.api_base_env == "PARASAIL_API_BASE" + assert "/v1/chat/completions" in config.supported_endpoints + assert "/v1/responses" in config.supported_endpoints + assert config.special_handling.get("force_store_false") is True + + +def test_parasail_listed_in_openai_compatible_providers(): + from litellm.constants import openai_compatible_providers + + assert "parasail" in openai_compatible_providers + + +def test_parasail_dynamic_config_env_vars(): + from litellm.llms.openai_like.dynamic_config import create_config_class + from litellm.llms.openai_like.json_loader import JSONProviderRegistry + + config = create_config_class(JSONProviderRegistry.get("parasail"))() + + with patch.dict( + os.environ, + { + "PARASAIL_API_KEY": "test-key", + "PARASAIL_API_BASE": PARASAIL_RESPONSES_GATEWAY, + }, + ): + api_base, api_key = config._get_openai_compatible_provider_info(None, None) + + assert api_base == PARASAIL_RESPONSES_GATEWAY + assert api_key == "test-key" + + +def test_parasail_provider_detection_by_prefix(): + from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider + + model, provider, _, api_base = get_llm_provider( + "parasail/parasail-llama-33-70b-fp8" + ) + + assert model == "parasail-llama-33-70b-fp8" + assert provider == "parasail" + assert api_base == PARASAIL_API_BASE + + +def test_parasail_chat_complete_url(): + from litellm.llms.openai_like.dynamic_config import create_config_class + from litellm.llms.openai_like.json_loader import JSONProviderRegistry + + config = create_config_class(JSONProviderRegistry.get("parasail"))() + + assert ( + config.get_complete_url( + api_base=None, + api_key=None, + model="parasail-llama-33-70b-fp8", + optional_params={}, + litellm_params={}, + ) + == f"{PARASAIL_API_BASE}/chat/completions" + ) + + +def test_parasail_responses_api_config(): + from litellm.llms.openai.responses.transformation import OpenAIResponsesAPIConfig + from litellm.utils import ProviderConfigManager + + config = ProviderConfigManager.get_provider_responses_api_config( + provider="parasail", + model="parasail-kimi-k25-elicit", + ) + + assert isinstance(config, OpenAIResponsesAPIConfig) + assert config.custom_llm_provider == "parasail" + assert ( + config.get_complete_url(api_base=None, litellm_params={}) + == f"{PARASAIL_API_BASE}/responses" + ) + + +def test_parasail_responses_api_honors_api_base_override(): + from litellm.utils import ProviderConfigManager + + config = ProviderConfigManager.get_provider_responses_api_config( + provider="parasail", + model="parasail-kimi-k25-elicit", + ) + + with patch.dict( + os.environ, + {"PARASAIL_API_BASE": PARASAIL_RESPONSES_GATEWAY}, + ): + url = config.get_complete_url(api_base=None, litellm_params={}) + + assert url == f"{PARASAIL_RESPONSES_GATEWAY}/responses" + + +def test_parasail_responses_api_forces_store_false_when_caller_sets_true(): + from litellm.types.router import GenericLiteLLMParams + from litellm.utils import ProviderConfigManager + + config = ProviderConfigManager.get_provider_responses_api_config( + provider="parasail", + model="parasail-kimi-k25-elicit", + ) + + request_params: dict = {"store": True, "temperature": 0.2} + transformed = config.transform_responses_api_request( + model="parasail-kimi-k25-elicit", + input="hello", + response_api_optional_request_params=request_params, + litellm_params=GenericLiteLLMParams(), + headers={}, + ) + + assert transformed["store"] is False + assert transformed["temperature"] == 0.2 + + +def test_parasail_responses_api_forces_store_false_when_caller_omits_store(): + from litellm.types.router import GenericLiteLLMParams + from litellm.utils import ProviderConfigManager + + config = ProviderConfigManager.get_provider_responses_api_config( + provider="parasail", + model="parasail-kimi-k25-elicit", + ) + + transformed = config.transform_responses_api_request( + model="parasail-kimi-k25-elicit", + input="hello", + response_api_optional_request_params={}, + litellm_params=GenericLiteLLMParams(), + headers={}, + ) + + assert transformed["store"] is False + + +def test_parasail_responses_api_validate_environment_sets_bearer_token(): + from litellm.types.router import GenericLiteLLMParams + from litellm.utils import ProviderConfigManager + + config = ProviderConfigManager.get_provider_responses_api_config( + provider="parasail", + model="parasail-kimi-k25-elicit", + ) + + with patch.dict(os.environ, {"PARASAIL_API_KEY": "secret-from-env"}): + headers = config.validate_environment( + headers={}, + model="parasail-kimi-k25-elicit", + litellm_params=GenericLiteLLMParams(), + ) + + assert headers["Authorization"] == "Bearer secret-from-env" diff --git a/tests/test_litellm/proxy/auth/test_user_api_key_auth.py b/tests/test_litellm/proxy/auth/test_user_api_key_auth.py index fa6cc8bed1..0236646c79 100644 --- a/tests/test_litellm/proxy/auth/test_user_api_key_auth.py +++ b/tests/test_litellm/proxy/auth/test_user_api_key_auth.py @@ -112,11 +112,71 @@ async def test_should_clear_stale_budget_reservation_when_budget_checks_skip(): user_api_key_cache=MagicMock(), proxy_logging_obj=MagicMock(), skip_budget_checks=True, + general_settings={}, ) assert user_api_key_auth_obj.budget_reservation is None +@pytest.mark.asyncio +async def test_disable_budget_reservation_skips_reservation(): + """#27639: general_settings.disable_budget_reservation turns off the optimistic Redis + reservation so operators hit by phantom BudgetExceededError can opt out of it.""" + user_api_key_auth_obj = UserAPIKeyAuth(token="test_token") + + with patch( + "litellm.proxy.spend_tracking.budget_reservation.reserve_budget_for_request", + new=AsyncMock(return_value={"reserved_cost": 0.5, "entries": []}), + ) as mock_reserve: + await _reserve_budget_after_common_checks( + user_api_key_auth_obj=user_api_key_auth_obj, + request_data={"model": "gpt-4o"}, + route="/v1/chat/completions", + llm_router=None, + team_object=None, + user_object=None, + prisma_client=None, + user_api_key_cache=MagicMock(), + proxy_logging_obj=MagicMock(), + skip_budget_checks=False, + general_settings={"disable_budget_reservation": True}, + ) + + mock_reserve.assert_not_called() + assert user_api_key_auth_obj.budget_reservation is None + + +@pytest.mark.asyncio +async def test_budget_reservation_runs_when_not_disabled(): + """Control for #27639: with the flag absent, the reservation still runs and is stored.""" + user_api_key_auth_obj = UserAPIKeyAuth(token="test_token") + reservation = { + "reserved_cost": 0.5, + "entries": [{"counter_key": "spend:key:test_token"}], + } + + with patch( + "litellm.proxy.spend_tracking.budget_reservation.reserve_budget_for_request", + new=AsyncMock(return_value=reservation), + ) as mock_reserve: + await _reserve_budget_after_common_checks( + user_api_key_auth_obj=user_api_key_auth_obj, + request_data={"model": "gpt-4o"}, + route="/v1/chat/completions", + llm_router=None, + team_object=None, + user_object=None, + prisma_client=None, + user_api_key_cache=MagicMock(), + proxy_logging_obj=MagicMock(), + skip_budget_checks=False, + general_settings={}, + ) + + mock_reserve.assert_awaited_once() + assert user_api_key_auth_obj.budget_reservation == reservation + + @pytest.mark.asyncio async def test_should_not_reuse_cached_key_object_for_request_state(): key_cache = DualCache() diff --git a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_crowdstrike_aidr.py b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_crowdstrike_aidr.py index c58c94cbbc..e7f72ff7a3 100644 --- a/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_crowdstrike_aidr.py +++ b/tests/test_litellm/proxy/guardrails/guardrail_hooks/test_crowdstrike_aidr.py @@ -41,7 +41,8 @@ def test_crowdstrike_aidr_guardrail_config() -> None: ) -def test_crowdstrike_aidr_guardrail_config_no_api_key() -> None: +def test_crowdstrike_aidr_guardrail_config_no_api_key(monkeypatch) -> None: + monkeypatch.delenv("CS_AIDR_TOKEN", raising=False) with pytest.raises(CrowdStrikeAIDRGuardrailMissingSecrets): init_guardrails_v2( all_guardrails=[ @@ -59,7 +60,8 @@ def test_crowdstrike_aidr_guardrail_config_no_api_key() -> None: ) -def test_crowdstrike_aidr_guardrail_config_no_api_base() -> None: +def test_crowdstrike_aidr_guardrail_config_no_api_base(monkeypatch) -> None: + monkeypatch.delenv("CS_AIDR_BASE_URL", raising=False) with pytest.raises(CrowdStrikeAIDRGuardrailMissingSecrets): init_guardrails_v2( all_guardrails=[ @@ -412,6 +414,121 @@ async def test_apply_guardrail_response_ok( assert result["texts"] == inputs["texts"] +@pytest.mark.asyncio +async def test_apply_guardrail_sends_user_id_model_and_extra_info( + crowdstrike_aidr_guardrail: CrowdStrikeAIDRHandler, +) -> None: + inputs: GenericGuardrailAPIInputs = { + "texts": ["Hello"], + "structured_messages": [{"role": "user", "content": "Hello"}], + "model": "gpt-4o", + } + request_data = { + "messages": inputs["structured_messages"], + "model": "gpt-4o", + "litellm_metadata": { + "user_api_key_user_id": "uid-abc", + "user_api_key_user_email": "alice@example.com", + }, + } + guardrail_endpoint = ( + f"{crowdstrike_aidr_guardrail.api_base}/v1/guard_chat_completions" + ) + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + return_value=httpx.Response( + status_code=200, + json={"result": {"blocked": False, "transformed": False}}, + request=httpx.Request(method="POST", url=guardrail_endpoint), + ), + ) as mock_method: + await crowdstrike_aidr_guardrail.apply_guardrail( + inputs=inputs, + request_data=request_data, + input_type="request", + ) + + payload = mock_method.call_args.kwargs["json"] + assert payload["user_id"] == "uid-abc" + assert payload["model"] == "gpt-4o" + assert payload["extra_info"] == {"user_name": "alice@example.com"} + + +@pytest.mark.asyncio +async def test_apply_guardrail_empty_extra_info_when_no_email( + crowdstrike_aidr_guardrail: CrowdStrikeAIDRHandler, +) -> None: + inputs: GenericGuardrailAPIInputs = { + "texts": ["Hello"], + "structured_messages": [{"role": "user", "content": "Hello"}], + "model": "gemini-flash", + } + request_data = { + "messages": inputs["structured_messages"], + "model": "gemini-flash", + "litellm_metadata": { + "user_api_key_user_id": "uid-no-email", + "user_api_key_user_email": None, + }, + } + guardrail_endpoint = ( + f"{crowdstrike_aidr_guardrail.api_base}/v1/guard_chat_completions" + ) + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + return_value=httpx.Response( + status_code=200, + json={"result": {"blocked": False, "transformed": False}}, + request=httpx.Request(method="POST", url=guardrail_endpoint), + ), + ) as mock_method: + await crowdstrike_aidr_guardrail.apply_guardrail( + inputs=inputs, + request_data=request_data, + input_type="request", + ) + + payload = mock_method.call_args.kwargs["json"] + assert payload["user_id"] == "uid-no-email" + assert payload["model"] == "gemini-flash" + assert payload["extra_info"] == {} + + +@pytest.mark.asyncio +async def test_apply_guardrail_no_metadata_skips_user_fields( + crowdstrike_aidr_guardrail: CrowdStrikeAIDRHandler, +) -> None: + inputs: GenericGuardrailAPIInputs = { + "texts": ["Hello"], + "structured_messages": [{"role": "user", "content": "Hello"}], + } + request_data = {"messages": inputs["structured_messages"]} + guardrail_endpoint = ( + f"{crowdstrike_aidr_guardrail.api_base}/v1/guard_chat_completions" + ) + + with patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post", + return_value=httpx.Response( + status_code=200, + json={"result": {"blocked": False, "transformed": False}}, + request=httpx.Request(method="POST", url=guardrail_endpoint), + ), + ) as mock_method: + await crowdstrike_aidr_guardrail.apply_guardrail( + inputs=inputs, + request_data=request_data, + input_type="request", + ) + + payload = mock_method.call_args.kwargs["json"] + assert "user_id" not in payload + assert "model" not in payload + assert "extra_info" not in payload + + @pytest.mark.asyncio async def test_apply_guardrail_request_skipped_messages_stay_aligned( crowdstrike_aidr_guardrail: CrowdStrikeAIDRHandler, diff --git a/tests/test_litellm/proxy/pass_through_endpoints/llm_provider_handlers/test_anthropic_passthrough_logging_handler.py b/tests/test_litellm/proxy/pass_through_endpoints/llm_provider_handlers/test_anthropic_passthrough_logging_handler.py index f8b6fbde3d..1114b3df0c 100644 --- a/tests/test_litellm/proxy/pass_through_endpoints/llm_provider_handlers/test_anthropic_passthrough_logging_handler.py +++ b/tests/test_litellm/proxy/pass_through_endpoints/llm_provider_handlers/test_anthropic_passthrough_logging_handler.py @@ -321,6 +321,44 @@ class TestAzureAnthropicCostCalculation: assert call_kwargs["model"] == "azure_ai/claude-sonnet-4-5_gb_20250929" assert call_kwargs["custom_llm_provider"] == "azure_ai" + def test_passthrough_logging_sets_response_cost_with_server_tool_use_dict(self): + from litellm.types.utils import Choices, Message, ModelResponse + + logging_obj = self._create_mock_logging_obj(model="claude-3-7-sonnet-20250219") + logging_obj.get_router_model_id.return_value = None + logging_obj.litellm_params = {} + + response = ModelResponse( + id="test-id", + choices=[ + Choices( + finish_reason="stop", + index=0, + message=Message(content="test", role="assistant"), + ) + ], + created=1234567890, + model="claude-3-7-sonnet-20250219", + usage={ + "prompt_tokens": 10, + "completion_tokens": 5, + "total_tokens": 15, + "server_tool_use": {"web_search_requests": 1}, + }, + ) + + kwargs = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload( + litellm_model_response=response, + model="claude-3-7-sonnet-20250219", + kwargs={}, + start_time=datetime.now(), + end_time=datetime.now(), + logging_obj=logging_obj, + ) + + assert "response_cost" in kwargs + assert kwargs["response_cost"] > 0 + class TestAnthropicBatchPassthroughCostTracking: """Test cases for Anthropic batch passthrough cost tracking functionality""" diff --git a/tests/test_litellm/proxy/pass_through_endpoints/llm_provider_handlers/test_openai_passthrough_logging_handler.py b/tests/test_litellm/proxy/pass_through_endpoints/llm_provider_handlers/test_openai_passthrough_logging_handler.py index bfcaaafd33..3c6af3e528 100644 --- a/tests/test_litellm/proxy/pass_through_endpoints/llm_provider_handlers/test_openai_passthrough_logging_handler.py +++ b/tests/test_litellm/proxy/pass_through_endpoints/llm_provider_handlers/test_openai_passthrough_logging_handler.py @@ -257,6 +257,64 @@ class TestOpenAIPassthroughLoggingHandler: ) assert OpenAIPassthroughLoggingHandler.is_openai_responses_route("") == False + def test_is_openai_route_recognizes_cognitiveservices_azure_com(self): + """Azure OpenAI resources created via the newer "Azure AI Foundry" / + Cognitive Services pathway live on `*.cognitiveservices.azure.com` + subdomains rather than the older `openai.azure.com`. All four + is_openai_*_route methods must recognize both Azure subdomains so + cost tracking applies regardless of which Azure naming the user's + resource happens to be on. + """ + cognitive_chat = ( + "https://my-resource.cognitiveservices.azure.com/v1/chat/completions" + ) + cognitive_images_gen = ( + "https://my-resource.cognitiveservices.azure.com/v1/images/generations" + ) + cognitive_images_edit = ( + "https://my-resource.cognitiveservices.azure.com/v1/images/edits" + ) + cognitive_responses = ( + "https://my-resource.cognitiveservices.azure.com/v1/responses" + ) + + assert ( + OpenAIPassthroughLoggingHandler.is_openai_chat_completions_route( + cognitive_chat + ) + is True + ) + assert ( + OpenAIPassthroughLoggingHandler.is_openai_image_generation_route( + cognitive_images_gen + ) + is True + ) + assert ( + OpenAIPassthroughLoggingHandler.is_openai_image_editing_route( + cognitive_images_edit + ) + is True + ) + assert ( + OpenAIPassthroughLoggingHandler.is_openai_responses_route( + cognitive_responses + ) + is True + ) + + # Cross-route negatives still hold for cognitiveservices hosts. + assert ( + OpenAIPassthroughLoggingHandler.is_openai_chat_completions_route( + cognitive_responses + ) + is False + ) + assert ( + OpenAIPassthroughLoggingHandler.is_openai_responses_route(cognitive_chat) + is False + ) + @patch("litellm.completion_cost") @patch( "litellm.litellm_core_utils.litellm_logging.get_standard_logging_object_payload" @@ -766,6 +824,14 @@ class TestOpenAIPassthroughIntegration: == True ) assert self.handler.is_openai_route("https://api.openai.com/v1/models") == True + # Azure OpenAI on the shared Cognitive Services domain, identified by an + # OpenAI-style path segment. + assert ( + self.handler.is_openai_route( + "https://my-resource.cognitiveservices.azure.com/v1/chat/completions" + ) + == True + ) # Negative cases assert ( @@ -782,6 +848,28 @@ class TestOpenAIPassthroughIntegration: self.handler.is_openai_route("https://api.assemblyai.com/v2/transcript") == False ) + # Non-OpenAI Azure Cognitive Services share the `cognitiveservices.azure.com` + # domain but must NOT be classified as OpenAI routes (no OpenAI path segment). + assert ( + self.handler.is_openai_route( + "https://my-resource.cognitiveservices.azure.com/speechtotext/v3.1/recognize" + ) + == False + ) + assert ( + self.handler.is_openai_route( + "https://my-resource.cognitiveservices.azure.com/vision/v3.2/analyze" + ) + == False + ) + # A look-alike domain that merely contains an OpenAI host as a substring + # must be rejected by the suffix-based hostname match. + assert ( + self.handler.is_openai_route( + "https://cognitiveservices.azure.com.attacker.example/v1/chat/completions" + ) + == False + ) assert self.handler.is_openai_route("") == False @patch( diff --git a/tests/test_litellm/responses/litellm_completion_transformation/test_litellm_completion_responses.py b/tests/test_litellm/responses/litellm_completion_transformation/test_litellm_completion_responses.py index 503a610e01..960fca205c 100644 --- a/tests/test_litellm/responses/litellm_completion_transformation/test_litellm_completion_responses.py +++ b/tests/test_litellm/responses/litellm_completion_transformation/test_litellm_completion_responses.py @@ -949,6 +949,28 @@ class TestToolChoiceTransformation: result = LiteLLMCompletionResponsesConfig._transform_tool_choice(tool_choice) assert result == tool_choice + def test_transform_tool_choice_responses_flat_function_name(self): + """Responses-API forced-function with a top-level name maps to the nested Chat + Completions shape instead of degrading to required and dropping the name""" + result = LiteLLMCompletionResponsesConfig._transform_tool_choice( + {"type": "function", "name": "get_weather"} + ) + assert result == {"type": "function", "function": {"name": "get_weather"}} + + def test_transform_tool_choice_function_without_name_falls_back_to_required(self): + """A function-type dict with no name still falls back to required""" + result = LiteLLMCompletionResponsesConfig._transform_tool_choice( + {"type": "function"} + ) + assert result == "required" + + def test_transform_tool_choice_function_empty_name_falls_back_to_required(self): + """An empty top-level name is falsy and must not produce an empty function name""" + result = LiteLLMCompletionResponsesConfig._transform_tool_choice( + {"type": "function", "name": ""} + ) + assert result == "required" + class TestContentTypeTransformation: """Test content type transformation from Responses API to Chat Completion format""" diff --git a/tests/test_litellm/types/test_types_utils.py b/tests/test_litellm/types/test_types_utils.py index c146847f39..a4074ccdaa 100644 --- a/tests/test_litellm/types/test_types_utils.py +++ b/tests/test_litellm/types/test_types_utils.py @@ -1,13 +1,9 @@ -import asyncio import os import sys -from typing import Optional -from unittest.mock import AsyncMock, patch import pytest sys.path.insert(0, os.path.abspath("../..")) -import json from litellm.types.utils import HiddenParams @@ -75,6 +71,48 @@ def test_usage_dump(): assert new_usage.prompt_tokens_details.web_search_requests == 1 +def test_usage_server_tool_use_dict_is_coerced_and_round_trips(): + from litellm.types.utils import ServerToolUse, Usage + + current_usage = Usage( + completion_tokens=1, + prompt_tokens=1, + total_tokens=2, + server_tool_use={"web_search_requests": 1}, + ) + + assert isinstance(current_usage.server_tool_use, ServerToolUse) + assert current_usage.server_tool_use.web_search_requests == 1 + + new_usage = Usage(**current_usage.model_dump()) + assert isinstance(new_usage.server_tool_use, ServerToolUse) + assert new_usage.server_tool_use.web_search_requests == 1 + + +def test_usage_converts_server_tool_use_dict(): + from litellm.types.utils import ServerToolUse, Usage + + usage = Usage( + completion_tokens=2, + prompt_tokens=1, + total_tokens=3, + server_tool_use={"web_search_requests": 4, "tool_search_requests": 1}, + ) + + assert isinstance(usage.server_tool_use, ServerToolUse) + assert usage.server_tool_use.web_search_requests == 4 + assert usage.server_tool_use["web_search_requests"] == 4 + assert usage.server_tool_use.tool_search_requests == 1 + with pytest.raises(KeyError): + usage.server_tool_use["unknown_metric"] + + round_trip = Usage(**usage.model_dump()) + assert isinstance(round_trip.server_tool_use, ServerToolUse) + assert round_trip.server_tool_use.web_search_requests == 4 + assert round_trip.server_tool_use["web_search_requests"] == 4 + assert round_trip.server_tool_use.tool_search_requests == 1 + + def test_usage_completion_tokens_details_text_tokens(): from litellm.types.utils import Usage diff --git a/ui/litellm-dashboard/src/lib/http/schema.d.ts b/ui/litellm-dashboard/src/lib/http/schema.d.ts index 75a14e0985..d9992680c8 100644 --- a/ui/litellm-dashboard/src/lib/http/schema.d.ts +++ b/ui/litellm-dashboard/src/lib/http/schema.d.ts @@ -22034,6 +22034,11 @@ export interface components { * @description connect to a postgres db - needed for generating temporary keys + tracking spend / key */ database_url?: string | null; + /** + * Disable Budget Reservation + * @description If True, disables the optimistic per-request budget reservation introduced in v1.84.0. WARNING: This weakens hard budget enforcement. Without the reservation, a burst of concurrent requests from a single key can each pass the read-time spend check before any of them is charged, allowing a configured budget to be exceeded under high concurrency. Budgets are still evaluated on every request at read time, so an already-exhausted budget is still rejected. Enable only if your deployment is experiencing phantom BudgetExceededError responses caused by leaked reservations (see GitHub issue #27639). A proxy-level WARNING is logged on every request while this flag is active as a reminder that hard enforcement is relaxed. + */ + disable_budget_reservation?: boolean | null; /** * Enable Public Model Hub * @description Public model hub for users to see what models they have access to, supported openai params, etc.