diff --git a/litellm/caching/qdrant_semantic_cache.py b/litellm/caching/qdrant_semantic_cache.py index 5e3713e5a1..cb521efca0 100644 --- a/litellm/caching/qdrant_semantic_cache.py +++ b/litellm/caching/qdrant_semantic_cache.py @@ -11,17 +11,23 @@ Has 4 methods: import ast import asyncio import json -from typing import Any, cast +import os +from typing import Any, Dict, cast import litellm from litellm._logging import print_verbose from litellm.constants import QDRANT_SCALAR_QUANTILE, QDRANT_VECTOR_SIZE +from litellm.litellm_core_utils.prompt_templates.common_utils import ( + get_str_from_messages, +) from litellm.types.utils import EmbeddingResponse from .base_cache import BaseCache class QdrantSemanticCache(BaseCache): + CACHE_KEY_FIELD_NAME = "litellm_cache_key" + def __init__( # noqa: PLR0915 self, qdrant_api_base=None, @@ -33,8 +39,6 @@ class QdrantSemanticCache(BaseCache): host_type=None, vector_size=None, ): - import os - from litellm.llms.custom_httpx.http_handler import ( _get_httpx_client, get_async_httpx_client, @@ -115,7 +119,9 @@ class QdrantSemanticCache(BaseCache): print_verbose( f"Collection already exists.\nCollection details:{self.collection_info}" ) + self._ensure_cache_key_payload_index() else: + quantization_params: Dict[str, Any] if quantization_config is None or quantization_config == "binary": quantization_params = { "binary": { @@ -156,6 +162,7 @@ class QdrantSemanticCache(BaseCache): print_verbose( f"New collection created.\nCollection details:{self.collection_info}" ) + self._ensure_cache_key_payload_index() else: raise Exception("Error while creating new collection") @@ -170,15 +177,94 @@ class QdrantSemanticCache(BaseCache): cached_response = ast.literal_eval(cached_response) return cached_response + def _get_qdrant_cache_key_filter(self, key: str) -> dict: + return { + "must": [ + { + "key": self.CACHE_KEY_FIELD_NAME, + "match": {"value": str(key)}, + } + ] + } + + def _add_cache_key_filter_to_search_data(self, data: dict, key: str) -> 
None: + data["filter"] = self._get_qdrant_cache_key_filter(key) + + def _ensure_cache_key_payload_index(self) -> None: + try: + response = self.sync_client.put( + url=f"{self.qdrant_api_base}/collections/{self.collection_name}/index", + headers=self.headers, + json={ + "field_name": self.CACHE_KEY_FIELD_NAME, + "field_schema": "keyword", + }, + ) + if response.status_code not in (200, 201): + print_verbose( + "Qdrant semantic-cache could not create cache-key payload index: " + f"{response.text}" + ) + except Exception as exc: + print_verbose( + "Qdrant semantic-cache could not create cache-key payload index: " + f"{str(exc)}" + ) + + def _payload_matches_cache_key(self, payload: dict, key: str) -> bool: + # Pre-isolation points stored only prompt + response with no cache-key + # payload field. Reassigning them to a caller's key would risk + # cross-scope hits, so they're treated as misses and re-populated on + # the next set_cache. + cached_key = payload.get(self.CACHE_KEY_FIELD_NAME) + return cached_key is not None and str(cached_key) == str(key) + + async def _get_async_embedding(self, prompt: str, **kwargs) -> Any: + llm_model_list = None + llm_router = None + + try: + from litellm.proxy.proxy_server import ( + llm_model_list as proxy_llm_model_list, + llm_router as proxy_llm_router, + ) + + llm_model_list = proxy_llm_model_list + llm_router = proxy_llm_router + except ImportError: + pass + + router_model_names = ( + [m["model_name"] for m in llm_model_list] + if llm_model_list is not None + else [] + ) + if llm_router is not None and self.embedding_model in router_model_names: + user_api_key = kwargs.get("metadata", {}).get("user_api_key", "") + return await llm_router.aembedding( + model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + metadata={ + "user_api_key": user_api_key, + "semantic-cache-embedding": True, + "trace_id": kwargs.get("metadata", {}).get("trace_id", None), + }, + ) + + return await litellm.aembedding( + 
model=self.embedding_model, + input=prompt, + cache={"no-store": True, "no-cache": True}, + ) + def set_cache(self, key, value, **kwargs): print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}") from litellm._uuid import uuid # get the prompt messages = kwargs["messages"] - prompt = "" - for message in messages: - prompt += message["content"] + prompt = get_str_from_messages(messages) # create an embedding for prompt embedding_response = cast( @@ -202,6 +288,7 @@ class QdrantSemanticCache(BaseCache): "id": str(uuid.uuid4()), "vector": embedding, "payload": { + self.CACHE_KEY_FIELD_NAME: str(key), "text": prompt, "response": value, }, @@ -220,9 +307,7 @@ class QdrantSemanticCache(BaseCache): # get the messages messages = kwargs["messages"] - prompt = "" - for message in messages: - prompt += message["content"] + prompt = get_str_from_messages(messages) # convert to embedding embedding_response = cast( @@ -249,6 +334,7 @@ class QdrantSemanticCache(BaseCache): "limit": 1, "with_payload": True, } + self._add_cache_key_filter_to_search_data(data=data, key=key) search_response = self.sync_client.post( url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search", @@ -258,21 +344,33 @@ class QdrantSemanticCache(BaseCache): results = search_response.json()["result"] if results is None: + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 return None if isinstance(results, list): if len(results) == 0: + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 return None similarity = results[0]["score"] - cached_prompt = results[0]["payload"]["text"] + payload = results[0]["payload"] + if not self._payload_matches_cache_key(payload=payload, key=key): + print_verbose("Qdrant semantic-cache hit did not match cache key scope") + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 + return None + + cached_prompt = payload["text"] # check similarity, if more than self.similarity_threshold, return results print_verbose( 
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}" ) + + # update kwargs["metadata"] with similarity, don't rewrite the original metadata + kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity + if similarity >= self.similarity_threshold: # cache hit ! - cached_value = results[0]["payload"]["response"] + cached_value = payload["response"] print_verbose( f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}" ) @@ -285,40 +383,12 @@ class QdrantSemanticCache(BaseCache): async def async_set_cache(self, key, value, **kwargs): from litellm._uuid import uuid - from litellm.proxy.proxy_server import llm_model_list, llm_router - print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}") # get the prompt messages = kwargs["messages"] - prompt = "" - for message in messages: - prompt += message["content"] - # create an embedding for prompt - router_model_names = ( - [m["model_name"] for m in llm_model_list] - if llm_model_list is not None - else [] - ) - if llm_router is not None and self.embedding_model in router_model_names: - user_api_key = kwargs.get("metadata", {}).get("user_api_key", "") - embedding_response = await llm_router.aembedding( - model=self.embedding_model, - input=prompt, - cache={"no-store": True, "no-cache": True}, - metadata={ - "user_api_key": user_api_key, - "semantic-cache-embedding": True, - "trace_id": kwargs.get("metadata", {}).get("trace_id", None), - }, - ) - else: - # convert to embedding - embedding_response = await litellm.aembedding( - model=self.embedding_model, - input=prompt, - cache={"no-store": True, "no-cache": True}, - ) + prompt = get_str_from_messages(messages) + embedding_response = await self._get_async_embedding(prompt, **kwargs) # get the embedding embedding = embedding_response["data"][0]["embedding"] @@ -332,6 +402,7 @@ class 
QdrantSemanticCache(BaseCache): "id": str(uuid.uuid4()), "vector": embedding, "payload": { + self.CACHE_KEY_FIELD_NAME: str(key), "text": prompt, "response": value, }, @@ -348,38 +419,12 @@ class QdrantSemanticCache(BaseCache): async def async_get_cache(self, key, **kwargs): print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}") - from litellm.proxy.proxy_server import llm_model_list, llm_router # get the messages messages = kwargs["messages"] - prompt = "" - for message in messages: - prompt += message["content"] + prompt = get_str_from_messages(messages) - router_model_names = ( - [m["model_name"] for m in llm_model_list] - if llm_model_list is not None - else [] - ) - if llm_router is not None and self.embedding_model in router_model_names: - user_api_key = kwargs.get("metadata", {}).get("user_api_key", "") - embedding_response = await llm_router.aembedding( - model=self.embedding_model, - input=prompt, - cache={"no-store": True, "no-cache": True}, - metadata={ - "user_api_key": user_api_key, - "semantic-cache-embedding": True, - "trace_id": kwargs.get("metadata", {}).get("trace_id", None), - }, - ) - else: - # convert to embedding - embedding_response = await litellm.aembedding( - model=self.embedding_model, - input=prompt, - cache={"no-store": True, "no-cache": True}, - ) + embedding_response = await self._get_async_embedding(prompt, **kwargs) # get the embedding embedding = embedding_response["data"][0]["embedding"] @@ -396,6 +441,7 @@ class QdrantSemanticCache(BaseCache): "limit": 1, "with_payload": True, } + self._add_cache_key_filter_to_search_data(data=data, key=key) search_response = await self.async_client.post( url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search", @@ -414,7 +460,13 @@ class QdrantSemanticCache(BaseCache): return None similarity = results[0]["score"] - cached_prompt = results[0]["payload"]["text"] + payload = results[0]["payload"] + if not self._payload_matches_cache_key(payload=payload, 
key=key): + print_verbose("Qdrant semantic-cache hit did not match cache key scope") + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 + return None + + cached_prompt = payload["text"] # check similarity, if more than self.similarity_threshold, return results print_verbose( @@ -426,7 +478,7 @@ class QdrantSemanticCache(BaseCache): if similarity >= self.similarity_threshold: # cache hit ! - cached_value = results[0]["payload"]["response"] + cached_value = payload["response"] print_verbose( f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}" ) diff --git a/litellm/caching/redis_semantic_cache.py b/litellm/caching/redis_semantic_cache.py index c76f27377d..da9e7b1e58 100644 --- a/litellm/caching/redis_semantic_cache.py +++ b/litellm/caching/redis_semantic_cache.py @@ -35,6 +35,7 @@ class RedisSemanticCache(BaseCache): """ DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index" + CACHE_KEY_FIELD_NAME: str = "litellm_cache_key" def __init__( self, @@ -66,8 +67,8 @@ class RedisSemanticCache(BaseCache): Exception: If similarity_threshold is not provided or required Redis connection information is missing """ - from redisvl.extensions.llmcache import SemanticCache - from redisvl.utils.vectorize import CustomTextVectorizer + from redisvl.extensions.llmcache import SemanticCache # type: ignore[import-not-found, import-untyped] + from redisvl.utils.vectorize import CustomTextVectorizer # type: ignore[import-not-found, import-untyped] if index_name is None: index_name = self.DEFAULT_REDIS_INDEX_NAME @@ -109,14 +110,94 @@ class RedisSemanticCache(BaseCache): # Initialize the Redis vectorizer and cache cache_vectorizer = CustomTextVectorizer(self._get_embedding) - self.llmcache = SemanticCache( - name=index_name, + self.llmcache = self._init_semantic_cache( + semantic_cache_cls=SemanticCache, + index_name=index_name, redis_url=redis_url, - vectorizer=cache_vectorizer, - 
distance_threshold=self.distance_threshold, - overwrite=False, + cache_vectorizer=cache_vectorizer, ) + @classmethod + def _cache_key_filterable_field(cls) -> Dict[str, str]: + return { + "name": cls.CACHE_KEY_FIELD_NAME, + "type": "tag", + } + + def _init_semantic_cache( + self, + semantic_cache_cls: Any, + index_name: str, + redis_url: str, + cache_vectorizer: Any, + ) -> Any: + def _is_schema_mismatch(exc: ValueError) -> bool: + error_message = str(exc).lower() + return any( + phrase in error_message + for phrase in ("schema does not match", "index schema") + ) + + try: + return semantic_cache_cls( + name=index_name, + redis_url=redis_url, + vectorizer=cache_vectorizer, + distance_threshold=self.distance_threshold, + filterable_fields=[self._cache_key_filterable_field()], + overwrite=False, + ) + except ValueError as exc: + if not _is_schema_mismatch(exc): + raise + + isolated_index_name = f"{index_name}_isolated" + print_verbose( + "Redis semantic-cache existing index schema is not isolated; " + f"using isolated index - {isolated_index_name}" + ) + try: + return semantic_cache_cls( + name=isolated_index_name, + redis_url=redis_url, + vectorizer=cache_vectorizer, + distance_threshold=self.distance_threshold, + filterable_fields=[self._cache_key_filterable_field()], + overwrite=False, + ) + except ValueError as isolated_exc: + if not _is_schema_mismatch(isolated_exc): + raise + + print_verbose( + "Redis semantic-cache isolated index schema is stale; " + f"recreating isolated index - {isolated_index_name}" + ) + return semantic_cache_cls( + name=isolated_index_name, + redis_url=redis_url, + vectorizer=cache_vectorizer, + distance_threshold=self.distance_threshold, + filterable_fields=[self._cache_key_filterable_field()], + overwrite=True, + ) + + def _get_cache_filters(self, key: str) -> Dict[str, str]: + return {self.CACHE_KEY_FIELD_NAME: str(key)} + + def _get_cache_key_filter_expression(self, key: str) -> Any: + from redisvl.query.filter import Tag # type: 
ignore[import-not-found, import-untyped] + + return Tag(self.CACHE_KEY_FIELD_NAME) == str(key) + + def _cache_hit_matches_key(self, cache_hit: Dict[str, Any], key: str) -> bool: + # Pre-isolation entries with no ``litellm_cache_key`` field cannot be + # safely reassigned to a caller's scope and are treated as misses. + cached_key = cache_hit.get(self.CACHE_KEY_FIELD_NAME) + if isinstance(cached_key, bytes): + cached_key = cached_key.decode("utf-8") + return cached_key is not None and str(cached_key) == str(key) + def _get_ttl(self, **kwargs) -> Optional[int]: """ Get the TTL (time-to-live) value for cache entries. @@ -188,7 +269,7 @@ class RedisSemanticCache(BaseCache): Store a value in the semantic cache. Args: - key: The cache key (not directly used in semantic caching) + key: The cache key used to isolate semantic cache entries value: The response value to cache **kwargs: Additional arguments including 'messages' for the prompt and optional 'ttl' for time-to-live @@ -206,12 +287,15 @@ class RedisSemanticCache(BaseCache): prompt = get_str_from_messages(messages) value_str = str(value) + store_kwargs: Dict[str, Any] = { + "filters": self._get_cache_filters(key), + } + # Get TTL and store in Redis semantic cache ttl = self._get_ttl(**kwargs) if ttl is not None: - self.llmcache.store(prompt, value_str, ttl=int(ttl)) - else: - self.llmcache.store(prompt, value_str) + store_kwargs["ttl"] = int(ttl) + self.llmcache.store(prompt, value_str, **store_kwargs) except Exception as e: print_verbose( f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}" @@ -222,7 +306,7 @@ class RedisSemanticCache(BaseCache): Retrieve a semantically similar cached response. 
Args: - key: The cache key (not directly used in semantic caching) + key: The cache key used to isolate semantic cache entries **kwargs: Additional arguments including 'messages' for the prompt Returns: @@ -235,18 +319,29 @@ class RedisSemanticCache(BaseCache): messages = kwargs.get("messages", []) if not messages: print_verbose("No messages provided for semantic cache lookup") + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 return None prompt = get_str_from_messages(messages) - # Check the cache for semantically similar prompts - results = self.llmcache.check(prompt=prompt) + # Check the cache for semantically similar prompts in this exact + # LiteLLM cache-key scope. + check_kwargs: Dict[str, Any] = { + "prompt": prompt, + "filter_expression": self._get_cache_key_filter_expression(key), + } + results = self.llmcache.check(**check_kwargs) # Return None if no similar prompts found if not results: + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 return None # Process the best matching result cache_hit = results[0] + if not self._cache_hit_matches_key(cache_hit=cache_hit, key=key): + print_verbose("Redis semantic-cache hit did not match cache key scope") + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 + return None vector_distance = float(cache_hit["vector_distance"]) # Convert vector distance back to similarity score @@ -257,6 +352,9 @@ class RedisSemanticCache(BaseCache): cached_prompt = cache_hit["prompt"] cached_response = cache_hit["response"] + # update kwargs["metadata"] with similarity, don't rewrite the original metadata + kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity + print_verbose( f"Cache hit: similarity threshold: {self.similarity_threshold}, " f"actual similarity: {similarity}, " @@ -267,6 +365,7 @@ class RedisSemanticCache(BaseCache): return self._get_cache_logic(cached_response=cached_response) except Exception as e: print_verbose(f"Error retrieving from Redis semantic cache: 
{str(e)}") + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]: """ @@ -321,7 +420,7 @@ class RedisSemanticCache(BaseCache): Asynchronously store a value in the semantic cache. Args: - key: The cache key (not directly used in semantic caching) + key: The cache key used to isolate semantic cache entries value: The response value to cache **kwargs: Additional arguments including 'messages' for the prompt and optional 'ttl' for time-to-live @@ -341,21 +440,20 @@ class RedisSemanticCache(BaseCache): # Generate embedding for the value (response) to cache prompt_embedding = await self._get_async_embedding(prompt, **kwargs) + store_kwargs: Dict[str, Any] = { + "vector": prompt_embedding, + "filters": self._get_cache_filters(key), + } + # Get TTL and store in Redis semantic cache ttl = self._get_ttl(**kwargs) if ttl is not None: - await self.llmcache.astore( - prompt, - value_str, - vector=prompt_embedding, # Pass through custom embedding - ttl=ttl, - ) - else: - await self.llmcache.astore( - prompt, - value_str, - vector=prompt_embedding, # Pass through custom embedding - ) + store_kwargs["ttl"] = ttl + await self.llmcache.astore( + prompt, + value_str, + **store_kwargs, + ) except Exception as e: print_verbose(f"Error in async_set_cache: {str(e)}") @@ -364,7 +462,7 @@ class RedisSemanticCache(BaseCache): Asynchronously retrieve a semantically similar cached response. 
Args: - key: The cache key (not directly used in semantic caching) + key: The cache key used to isolate semantic cache entries **kwargs: Additional arguments including 'messages' for the prompt Returns: @@ -385,17 +483,25 @@ class RedisSemanticCache(BaseCache): # Generate embedding for the prompt prompt_embedding = await self._get_async_embedding(prompt, **kwargs) - # Check the cache for semantically similar prompts - results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding) + # Check the cache for semantically similar prompts in this exact + # LiteLLM cache-key scope. + check_kwargs: Dict[str, Any] = { + "prompt": prompt, + "vector": prompt_embedding, + "filter_expression": self._get_cache_key_filter_expression(key), + } + results = await self.llmcache.acheck(**check_kwargs) # handle results / cache hit if not results: - kwargs.setdefault("metadata", {})[ - "semantic-similarity" - ] = 0.0 # TODO why here but not above?? + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 return None cache_hit = results[0] + if not self._cache_hit_matches_key(cache_hit=cache_hit, key=key): + print_verbose("Redis semantic-cache hit did not match cache key scope") + kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 + return None vector_distance = float(cache_hit["vector_distance"]) # Convert vector distance back to similarity diff --git a/litellm/proxy/auth/auth_utils.py b/litellm/proxy/auth/auth_utils.py index 51108827f6..9a6fc95f14 100644 --- a/litellm/proxy/auth/auth_utils.py +++ b/litellm/proxy/auth/auth_utils.py @@ -2,7 +2,7 @@ import os import re import sys from functools import lru_cache -from typing import Any, Dict, List, Mapping, Optional, Tuple, Union +from typing import Any, Dict, FrozenSet, List, Mapping, Optional, Tuple, Union from fastapi import HTTPException, Request, status @@ -173,9 +173,67 @@ def _allow_model_level_clientside_configurable_parameters( # threat shape should be added here. 
_NESTED_CONFIG_KEYS: Tuple[str, ...] = ("litellm_embedding_config",) -# Banned root-level params. Same list applies to every entry in -# ``_NESTED_CONFIG_KEYS`` because those dicts get spread as ``**kwargs`` -# into the same outbound calls. +# Metadata containers that carry per-request configuration consumed by the +# observability callbacks. The same banned-param list applies — a value +# under ``metadata.langfuse_host`` redirects the same Langfuse client and +# leaks the same credentials as the root-level ``langfuse_host``, but the +# original check only walked the request-body root, so the metadata path +# was an unintentional bypass. +_NESTED_METADATA_KEYS: Tuple[str, ...] = ("metadata", "litellm_metadata") + +# ``_BANNED_REQUEST_BODY_PARAMS`` (defined below) applies to every entry in +# ``_NESTED_CONFIG_KEYS`` (dicts spread as ``**kwargs`` into outbound +# calls) and ``_NESTED_METADATA_KEYS`` (dicts read directly by integration +# callbacks), so a single banned name is enforced wherever the field can +# reach the call path from. +# Per-request observability params that are SAFE to accept from clients. +# These describe the request being logged (prompt version, sampling rate) +# without choosing the destination or the credentials, so they don't +# contribute to the data-exfil primitive that the rest of +# ``_supported_callback_params`` does. +_SAFE_CLIENT_CALLBACK_PARAMS: FrozenSet[str] = frozenset( + { + "langfuse_prompt_version", + "langsmith_sampling_rate", + } +) + +# Observability fields that integrations read from the request body or +# metadata but that are not (yet) listed in ``_supported_callback_params``. +# Listed here so the proxy bans them today; the long-term cleanup is to +# fold these into the canonical allowlist so they share one source of +# truth with the rest.
+_EXTRA_BANNED_OBSERVABILITY_PARAMS: FrozenSet[str] = frozenset( + { + "posthog_api_url", + "phoenix_project_name", + "wandb_api_key", + "weave_project_id", + } +) + + +def _build_banned_observability_params() -> FrozenSet[str]: + """Derive the observability ban list from the canonical allowlist. + + ``_supported_callback_params`` in + ``litellm/litellm_core_utils/initialize_dynamic_callback_params.py`` is + the single place that enumerates every observability field + integrations resolve from kwargs/metadata. Subtract the small set of + informational fields (``_SAFE_CLIENT_CALLBACK_PARAMS``) and union with + the extras the canonical allowlist hasn't caught up to yet. New + integrations added to the canonical allowlist are banned by default, + which is the safe failure mode. + """ + from litellm.litellm_core_utils.initialize_dynamic_callback_params import ( + _supported_callback_params, + ) + + return ( + frozenset(_supported_callback_params) - _SAFE_CLIENT_CALLBACK_PARAMS + ) | _EXTRA_BANNED_OBSERVABILITY_PARAMS + + _BANNED_REQUEST_BODY_PARAMS: Tuple[str, ...] = ( "api_base", "base_url", @@ -190,11 +248,6 @@ _BANNED_REQUEST_BODY_PARAMS: Tuple[str, ...] = ( # tokens) to the attacker's host, or coerces the proxy into # authenticating against the attacker's host with admin secrets. "aws_bedrock_runtime_endpoint", - "langsmith_base_url", - "langfuse_host", - "posthog_host", - "braintrust_host", - "slack_webhook_url", # Provider-specific endpoint overrides that flow into the outbound # request via ``optional_params``. Same threat as ``api_base``: # ``s3_endpoint_url`` redirects Bedrock file uploads to attacker @@ -203,6 +256,11 @@ _BANNED_REQUEST_BODY_PARAMS: Tuple[str, ...] = ( "s3_endpoint_url", "sagemaker_base_url", "deployment_url", + # Observability credentials, hosts, and project identifiers: derived + # from the canonical ``_supported_callback_params`` allowlist so new + # integrations are covered automatically. 
Sorted for stable iteration + # order and reviewable diffs. + *sorted(_build_banned_observability_params()), ) @@ -221,6 +279,8 @@ def _check_banned_params( if param not in body: continue if general_settings.get("allow_client_side_credentials") is True: + # Proxy-wide opt-in: every banned param is permitted, exit + # entirely so the rest of the loop doesn't waste work. return if ( _allow_model_level_clientside_configurable_parameters( @@ -231,7 +291,12 @@ def _check_banned_params( ) is True ): - return + # Per-param opt-in: only THIS param is permitted by the + # deployment's ``configurable_clientside_auth_params``. Skip + # to the next banned param so a body that pairs an allowed + # ``api_base`` with an unallowed ``langfuse_host`` is still + # rejected for the second field. + continue raise ValueError( f"Rejected Request: {param} is not allowed in request body. " "Clientside passthrough requires explicit admin opt-in via " @@ -275,9 +340,33 @@ def is_request_body_safe( nested = request_body.get(nested_key) if isinstance(nested, dict): _check_banned_params(nested, general_settings, llm_router, model) + for metadata_key in _NESTED_METADATA_KEYS: + metadata = _coerce_metadata_to_dict(request_body.get(metadata_key)) + if metadata is not None: + _check_banned_params(metadata, general_settings, llm_router, model) return True +def _coerce_metadata_to_dict(value: Any) -> Optional[Dict[str, Any]]: + """Return ``value`` as a dict, parsing it from JSON if delivered as a string. + + Multipart/form-data and ``extra_body`` callers send ``litellm_metadata`` + as a JSON-encoded string; the proxy parses it into a dict later in + ``add_litellm_data_to_request``, but the auth-time bouncer runs first + and would otherwise miss the banned-param check on a still-stringified + metadata blob. 
+ """ + if isinstance(value, dict): + return value + if isinstance(value, str): + from litellm.litellm_core_utils.safe_json_loads import safe_json_loads + + parsed = safe_json_loads(value) + if isinstance(parsed, dict): + return parsed + return None + + async def pre_db_read_auth_checks( request: Request, request_data: dict, diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index b05f69b439..3b22397559 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -121,6 +121,19 @@ _UNTRUSTED_ROOT_CONTROL_FIELDS = ( "pillar_response_headers", "_guardrail_pipelines", "_pipeline_managed_guardrails", + # Callback-registration fields. ``callbacks``, ``service_callback``, + # and ``logger_fn`` are read by ``litellm.utils.function_setup`` and + # appended to process-wide ``litellm.{input,success,failure,_async_*, + # service}_callback`` lists / ``litellm.user_logger_fn`` — one request + # poisons the worker for every subsequent caller. + # ``litellm_disabled_callbacks`` is the inverse primitive: the + # legitimate path reads it from key/team metadata, the request-body + # version silently turns off admin-configured audit/observability + # for the caller's request. 
+ "callbacks", + "service_callback", + "logger_fn", + "litellm_disabled_callbacks", ) _UNTRUSTED_METADATA_CONTROL_FIELDS = ( diff --git a/litellm/proxy/vector_store_endpoints/endpoints.py b/litellm/proxy/vector_store_endpoints/endpoints.py index 86e316e7f4..ccf15c206b 100644 --- a/litellm/proxy/vector_store_endpoints/endpoints.py +++ b/litellm/proxy/vector_store_endpoints/endpoints.py @@ -8,6 +8,9 @@ from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing from litellm.proxy.utils import jsonify_object +from litellm.proxy.vector_store_endpoints.management_endpoints import ( + _resolve_embedding_config, +) from litellm.proxy.vector_store_endpoints.utils import ( assert_user_can_access_vector_store, get_litellm_managed_vector_store, @@ -56,6 +59,30 @@ async def _update_request_data_with_litellm_managed_vector_store_registry( if "litellm_params" in vector_store_to_run: litellm_params = vector_store_to_run.get("litellm_params", {}) or {} + # Resolve ``litellm_embedding_config`` here, at request-handling + # time, instead of at row-creation time. The resolved + # ``api_key`` / ``api_base`` / ``api_version`` lives only in + # this per-request ``data`` dict and is never persisted. + # Legacy rows that already carry a resolved (cleartext) + # ``litellm_embedding_config`` skip the lookup and pass through + # unchanged so the embed call keeps working. 
+ embedding_model = litellm_params.get("litellm_embedding_model") + if embedding_model and not litellm_params.get("litellm_embedding_config"): + from litellm.proxy.proxy_server import prisma_client + + resolved_config = await _resolve_embedding_config( + embedding_model=embedding_model, prisma_client=prisma_client + ) + if resolved_config: + # Build a fresh dict via spread instead of mutating + # ``litellm_params`` in place — the registry hands back + # a reference to its cached object, so an in-place + # update would persist the resolved cleartext into the + # in-memory cache for the lifetime of the process. + litellm_params = { + **litellm_params, + "litellm_embedding_config": resolved_config, + } data.update(litellm_params) return data diff --git a/litellm/proxy/vector_store_endpoints/management_endpoints.py b/litellm/proxy/vector_store_endpoints/management_endpoints.py index 99a2085bfc..cbb3d92718 100644 --- a/litellm/proxy/vector_store_endpoints/management_endpoints.py +++ b/litellm/proxy/vector_store_endpoints/management_endpoints.py @@ -16,6 +16,7 @@ from fastapi import APIRouter, Depends, HTTPException import litellm from litellm._logging import verbose_proxy_logger +from litellm.caching.in_memory_cache import InMemoryCache from litellm.constants import REDACTED_BY_LITELM_STRING from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker @@ -45,6 +46,28 @@ _LITELLM_PARAMS_MASKER = SensitiveDataMasker() _REDACT_LITELLM_PARAMS_MAX_DEPTH = 10 +# Use-time embedding-config resolution runs on every vector-store request +# whose persisted row carries only a model reference (the post-fix shape). +# Without a cache, that's one ``litellm_proxymodeltable.find_first`` per +# request, which violates the no-DB-in-critical-path rule. Hold the resolved config in +# memory for a short TTL so a hot model name pays the DB lookup at most +# once per ``_EMBEDDING_CONFIG_CACHE_TTL`` seconds.
Cleartext credentials +# only ever live in process memory (never persisted, never echoed in +# management responses), so the cache doesn't widen the disclosure surface. +_EMBEDDING_CONFIG_CACHE_TTL = 60 +_EMBEDDING_CONFIG_CACHE_MAX_SIZE = 256 +_embedding_config_cache: Optional[InMemoryCache] = None + + +def _get_embedding_config_cache() -> InMemoryCache: + global _embedding_config_cache + if _embedding_config_cache is None: + _embedding_config_cache = InMemoryCache( + max_size_in_memory=_EMBEDDING_CONFIG_CACHE_MAX_SIZE, + default_ttl=_EMBEDDING_CONFIG_CACHE_TTL, + ) + return _embedding_config_cache + def _redact_sensitive_litellm_params(litellm_params: Any, _depth: int = 0) -> Any: """ @@ -303,6 +326,11 @@ async def _resolve_embedding_config( This function first checks the router for config-defined models, then falls back to the database. This allows users to use models defined in either location. + Results are cached in process memory for ``_EMBEDDING_CONFIG_CACHE_TTL`` + seconds so the request-handling path doesn't hit the database on every + vector-store call. Negative results (model not found) are intentionally + not cached to avoid blocking a freshly-added model behind the TTL. 
+ Args: embedding_model: The embedding model string (e.g., "text-embedding-ada-002" or "azure/text-embedding-3-large") prisma_client: The Prisma client instance @@ -314,6 +342,11 @@ async def _resolve_embedding_config( if not embedding_model: return None + cache = _get_embedding_config_cache() + cached = cache.get_cache(embedding_model) + if cached is not None: + return cached + # Import llm_router if not provided if llm_router is None: try: @@ -330,6 +363,7 @@ async def _resolve_embedding_config( verbose_proxy_logger.debug( f"Resolved embedding config from router for model {embedding_model}" ) + cache.set_cache(embedding_model, router_config) return router_config # Fall back to database @@ -341,6 +375,7 @@ async def _resolve_embedding_config( verbose_proxy_logger.debug( f"Resolved embedding config from database for model {embedding_model}" ) + cache.set_cache(embedding_model, db_config) return db_config verbose_proxy_logger.debug( @@ -432,20 +467,17 @@ async def create_vector_store_in_db( if user_id is not None: data_to_create["user_id"] = user_id - # Handle litellm_params - always provide at least an empty dict + # Handle litellm_params - always provide at least an empty dict. + # The earlier behaviour resolved ``litellm_embedding_config`` from the + # admin-configured router/DB model and persisted the cleartext result + # (``api_key``, ``api_base``, ``api_version``) into this row. That + # exposed every env-stored embedding-model credential on the + # ``/vector_store/{new,info,update,list}`` responses. Keep the user's + # raw ``litellm_embedding_model`` reference; resolution now happens in + # ``_update_request_data_with_litellm_managed_vector_store_registry`` + # at request-handling time so the cleartext config exists only in + # per-request memory and never reaches the database. 
if litellm_params: - # Auto-resolve embedding config if embedding model is provided but config is not - embedding_model = litellm_params.get("litellm_embedding_model") - if embedding_model and not litellm_params.get("litellm_embedding_config"): - resolved_config = await _resolve_embedding_config( - embedding_model=embedding_model, prisma_client=prisma_client - ) - if resolved_config: - litellm_params["litellm_embedding_config"] = resolved_config - verbose_proxy_logger.info( - f"Auto-resolved embedding config for model {embedding_model}" - ) - litellm_params_dict = GenericLiteLLMParams(**litellm_params).model_dump( exclude_none=True ) @@ -531,10 +563,19 @@ async def new_vector_store( user_id=user_api_key_dict.user_id, ) + # Apply the same litellm_params redaction the list / info / update + # endpoints already use, so a caller-supplied credential or a + # cleartext value persisted by an earlier proxy version doesn't + # come back in the response. + response_vs = LiteLLM_ManagedVectorStore(**new_vector_store) + response_vs["litellm_params"] = _redact_sensitive_litellm_params( + new_vector_store.get("litellm_params") + ) + return { "status": "success", "message": f"Vector store {vector_store.get('vector_store_id')} created successfully", - "vector_store": new_vector_store, + "vector_store": response_vs, } except Exception as e: verbose_proxy_logger.exception(f"Error creating vector store: {str(e)}") @@ -865,24 +906,15 @@ async def update_vector_store( update_data["vector_store_metadata"] ) - # Handle litellm_params if provided + # Handle litellm_params if provided. As with the create path, the + # embedding-config auto-resolve previously persisted cleartext + # credentials into the row; resolution now happens at request- + # handling time in + # ``_update_request_data_with_litellm_managed_vector_store_registry`` + # so this row only ever stores the user-supplied + # ``litellm_embedding_model`` reference. 
if "litellm_params" in update_data: _input_litellm_params: dict = update_data.get("litellm_params", {}) or {} - - # Auto-resolve embedding config if embedding model is provided but config is not - embedding_model = _input_litellm_params.get("litellm_embedding_model") - if embedding_model and not _input_litellm_params.get( - "litellm_embedding_config" - ): - resolved_config = await _resolve_embedding_config( - embedding_model=embedding_model, prisma_client=prisma_client - ) - if resolved_config: - _input_litellm_params["litellm_embedding_config"] = resolved_config - verbose_proxy_logger.info( - f"Auto-resolved embedding config for model {embedding_model}" - ) - litellm_params_dict = GenericLiteLLMParams( **_input_litellm_params ).model_dump(exclude_none=True) diff --git a/tests/test_litellm/caching/test_qdrant_semantic_cache.py b/tests/test_litellm/caching/test_qdrant_semantic_cache.py index 13dc4b5812..949e6ccc29 100644 --- a/tests/test_litellm/caching/test_qdrant_semantic_cache.py +++ b/tests/test_litellm/caching/test_qdrant_semantic_cache.py @@ -19,9 +19,7 @@ def test_qdrant_semantic_cache_initialization(monkeypatch): patch( "litellm.llms.custom_httpx.http_handler._get_httpx_client" ) as mock_sync_client, - patch( - "litellm.llms.custom_httpx.http_handler.get_async_httpx_client" - ) as mock_async_client, + patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"), ): # Mock the collection exists check @@ -31,6 +29,9 @@ def test_qdrant_semantic_cache_initialization(monkeypatch): mock_sync_client_instance = MagicMock() mock_sync_client_instance.get.return_value = mock_response + mock_index_response = MagicMock() + mock_index_response.status_code = 200 + mock_sync_client_instance.put.return_value = mock_index_response mock_sync_client.return_value = mock_sync_client_instance from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache @@ -48,6 +49,17 @@ def test_qdrant_semantic_cache_initialization(monkeypatch): assert 
qdrant_cache.qdrant_api_base == "http://test.qdrant.local" assert qdrant_cache.qdrant_api_key == "test_key" assert qdrant_cache.similarity_threshold == 0.8 + mock_sync_client_instance.put.assert_called_once_with( + url="http://test.qdrant.local/collections/test_collection/index", + headers={ + "Content-Type": "application/json", + "api-key": "test_key", + }, + json={ + "field_name": QdrantSemanticCache.CACHE_KEY_FIELD_NAME, + "field_schema": "keyword", + }, + ) # Test initialization with missing similarity_threshold with pytest.raises(Exception, match="similarity_threshold must be provided"): @@ -67,9 +79,7 @@ def test_qdrant_semantic_cache_get_cache_hit(): patch( "litellm.llms.custom_httpx.http_handler._get_httpx_client" ) as mock_sync_client, - patch( - "litellm.llms.custom_httpx.http_handler.get_async_httpx_client" - ) as mock_async_client, + patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"), ): # Mock the collection exists check @@ -98,6 +108,7 @@ def test_qdrant_semantic_cache_get_cache_hit(): "result": [ { "payload": { + QdrantSemanticCache.CACHE_KEY_FIELD_NAME: "test_key", "text": "What is the capital of France?", # Original prompt "response": '{"id": "test-123", "choices": [{"message": {"content": "Paris is the capital of France."}}]}', }, @@ -127,6 +138,177 @@ def test_qdrant_semantic_cache_get_cache_hit(): # Verify search was called qdrant_cache.sync_client.post.assert_called() + assert qdrant_cache.sync_client.post.call_args.kwargs["json"]["filter"] == { + "must": [ + { + "key": QdrantSemanticCache.CACHE_KEY_FIELD_NAME, + "match": {"value": "test_key"}, + } + ] + } + + +def test_qdrant_semantic_cache_rejects_unscoped_cache_hit(): + """ + Test QDRANT semantic cache rejects old or unscoped cache hits. + + Legacy points have only text and response payloads, so they cannot be + safely migrated to a generated LiteLLM cache key. 
+ """ + with ( + patch( + "litellm.llms.custom_httpx.http_handler._get_httpx_client" + ) as mock_sync_client, + patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"), + ): + + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"result": {"exists": True}} + + mock_sync_client_instance = MagicMock() + mock_sync_client_instance.get.return_value = mock_response + mock_sync_client.return_value = mock_sync_client_instance + + from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache + + qdrant_cache = QdrantSemanticCache( + collection_name="test_collection", + qdrant_api_base="http://test.qdrant.local", + qdrant_api_key="test_key", + similarity_threshold=0.8, + ) + + mock_search_response = MagicMock() + mock_search_response.status_code = 200 + mock_search_response.json.return_value = { + "result": [ + { + "payload": { + "text": "What is the capital of France?", + "response": '{"id": "test-123"}', + }, + "score": 0.9, + } + ] + } + qdrant_cache.sync_client.post = MagicMock(return_value=mock_search_response) + + with patch( + "litellm.embedding", return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]} + ): + metadata = {} + result = qdrant_cache.get_cache( + key="test_key", + messages=[{"content": "What is the capital of France?"}], + metadata=metadata, + ) + + assert result is None + assert metadata["semantic-similarity"] == 0.0 + + +def test_qdrant_semantic_cache_payload_index_failure_is_non_blocking(): + from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache + + qdrant_cache = QdrantSemanticCache.__new__(QdrantSemanticCache) + qdrant_cache.qdrant_api_base = "http://test.qdrant.local" + qdrant_cache.collection_name = "test_collection" + qdrant_cache.headers = {"Content-Type": "application/json"} + qdrant_cache.sync_client = MagicMock() + response = MagicMock() + response.status_code = 400 + response.text = "bad index" + qdrant_cache.sync_client.put.return_value = response + + 
qdrant_cache._ensure_cache_key_payload_index() + + qdrant_cache.sync_client.put.assert_called_once() + + +def test_qdrant_semantic_cache_payload_index_exception_is_non_blocking(): + from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache + + qdrant_cache = QdrantSemanticCache.__new__(QdrantSemanticCache) + qdrant_cache.qdrant_api_base = "http://test.qdrant.local" + qdrant_cache.collection_name = "test_collection" + qdrant_cache.headers = {"Content-Type": "application/json"} + qdrant_cache.sync_client = MagicMock() + qdrant_cache.sync_client.put.side_effect = Exception("boom") + + qdrant_cache._ensure_cache_key_payload_index() + + qdrant_cache.sync_client.put.assert_called_once() + + +def _mock_qdrant_get_cache_result(qdrant_result): + from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache + + qdrant_cache = QdrantSemanticCache.__new__(QdrantSemanticCache) + qdrant_cache.embedding_model = "text-embedding-ada-002" + qdrant_cache.qdrant_api_base = "http://test.qdrant.local" + qdrant_cache.collection_name = "test_collection" + qdrant_cache.headers = { + "Content-Type": "application/json", + "api-key": "test_key", + } + qdrant_cache.similarity_threshold = 0.8 + qdrant_cache.sync_client = MagicMock() + + mock_search_response = MagicMock() + mock_search_response.status_code = 200 + mock_search_response.json.return_value = {"result": qdrant_result} + qdrant_cache.sync_client.post.return_value = mock_search_response + + return qdrant_cache, QdrantSemanticCache + + +@pytest.mark.parametrize("qdrant_result", [None, []]) +def test_qdrant_semantic_cache_get_cache_sets_metadata_on_empty_miss(qdrant_result): + qdrant_cache, _ = _mock_qdrant_get_cache_result(qdrant_result) + metadata = {} + + with patch( + "litellm.embedding", return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]} + ): + result = qdrant_cache.get_cache( + key="test_key", + messages=[{"content": "What is the capital of Spain?"}], + metadata=metadata, + ) + + assert result is None + 
assert metadata["semantic-similarity"] == 0.0 + + +def test_qdrant_semantic_cache_get_cache_sets_metadata_on_below_threshold_miss(): + from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache + + qdrant_cache, _ = _mock_qdrant_get_cache_result( + [ + { + "payload": { + QdrantSemanticCache.CACHE_KEY_FIELD_NAME: "test_key", + "text": "What is the capital of Spain?", + "response": '{"id": "test-456"}', + }, + "score": 0.7, + } + ] + ) + metadata = {} + + with patch( + "litellm.embedding", return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]} + ): + result = qdrant_cache.get_cache( + key="test_key", + messages=[{"content": "What is the capital of Spain?"}], + metadata=metadata, + ) + + assert result is None + assert metadata["semantic-similarity"] == 0.7 def test_qdrant_semantic_cache_get_cache_miss(): @@ -138,9 +320,7 @@ def test_qdrant_semantic_cache_get_cache_miss(): patch( "litellm.llms.custom_httpx.http_handler._get_httpx_client" ) as mock_sync_client, - patch( - "litellm.llms.custom_httpx.http_handler.get_async_httpx_client" - ) as mock_async_client, + patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"), ): # Mock the collection exists check @@ -230,6 +410,7 @@ async def test_qdrant_semantic_cache_async_get_cache_hit(): "result": [ { "payload": { + QdrantSemanticCache.CACHE_KEY_FIELD_NAME: "test_key", "text": "What is the capital of Spain?", # Original prompt "response": '{"id": "test-456", "choices": [{"message": {"content": "Madrid is the capital of Spain."}}]}', }, @@ -262,6 +443,16 @@ async def test_qdrant_semantic_cache_async_get_cache_hit(): # Verify async search was called qdrant_cache.async_client.post.assert_called() + assert qdrant_cache.async_client.post.call_args.kwargs["json"][ + "filter" + ] == { + "must": [ + { + "key": QdrantSemanticCache.CACHE_KEY_FIELD_NAME, + "match": {"value": "test_key"}, + } + ] + } @pytest.mark.asyncio @@ -336,9 +527,7 @@ def test_qdrant_semantic_cache_set_cache(): patch( 
"litellm.llms.custom_httpx.http_handler._get_httpx_client" ) as mock_sync_client, - patch( - "litellm.llms.custom_httpx.http_handler.get_async_httpx_client" - ) as mock_async_client, + patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"), ): # Mock the collection exists check @@ -384,6 +573,12 @@ def test_qdrant_semantic_cache_set_cache(): # Verify upsert was called qdrant_cache.sync_client.put.assert_called() + upsert_payload = qdrant_cache.sync_client.put.call_args.kwargs["json"][ + "points" + ][0]["payload"] + assert ( + upsert_payload[QdrantSemanticCache.CACHE_KEY_FIELD_NAME] == "test_key" + ) @pytest.mark.asyncio @@ -450,6 +645,12 @@ async def test_qdrant_semantic_cache_async_set_cache(): # Verify async upsert was called qdrant_cache.async_client.put.assert_called() + upsert_payload = qdrant_cache.async_client.put.call_args.kwargs["json"][ + "points" + ][0]["payload"] + assert ( + upsert_payload[QdrantSemanticCache.CACHE_KEY_FIELD_NAME] == "test_key" + ) def test_qdrant_semantic_cache_custom_vector_size(): @@ -462,9 +663,7 @@ def test_qdrant_semantic_cache_custom_vector_size(): patch( "litellm.llms.custom_httpx.http_handler._get_httpx_client" ) as mock_sync_client, - patch( - "litellm.llms.custom_httpx.http_handler.get_async_httpx_client" - ) as mock_async_client, + patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"), ): # Mock the collection does NOT exist (so it will be created) @@ -505,9 +704,13 @@ def test_qdrant_semantic_cache_custom_vector_size(): assert qdrant_cache.vector_size == 768 # Verify the PUT call to create the collection used vector_size=768 - put_call = mock_sync_client_instance.put.call_args - assert put_call is not None - create_payload = put_call.kwargs.get("json") or put_call[1].get("json") + put_call = next( + call + for call in mock_sync_client_instance.put.call_args_list + if call.kwargs["url"] + == "http://test.qdrant.local/collections/test_collection_768" + ) + create_payload = 
put_call.kwargs["json"] assert create_payload["vectors"]["size"] == 768 assert create_payload["vectors"]["distance"] == "Cosine" @@ -521,9 +724,7 @@ def test_qdrant_semantic_cache_default_vector_size(): patch( "litellm.llms.custom_httpx.http_handler._get_httpx_client" ) as mock_sync_client, - patch( - "litellm.llms.custom_httpx.http_handler.get_async_httpx_client" - ) as mock_async_client, + patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"), ): # Mock the collection exists check @@ -559,9 +760,7 @@ def test_qdrant_semantic_cache_large_vector_size(): patch( "litellm.llms.custom_httpx.http_handler._get_httpx_client" ) as mock_sync_client, - patch( - "litellm.llms.custom_httpx.http_handler.get_async_httpx_client" - ) as mock_async_client, + patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"), ): # Mock the collection does NOT exist (so it will be created) @@ -599,6 +798,11 @@ def test_qdrant_semantic_cache_large_vector_size(): assert qdrant_cache.vector_size == 4096 # Verify the collection was created with 4096 - put_call = mock_sync_client_instance.put.call_args - create_payload = put_call.kwargs.get("json") or put_call[1].get("json") + put_call = next( + call + for call in mock_sync_client_instance.put.call_args_list + if call.kwargs["url"] + == "http://test.qdrant.local/collections/test_collection_4096" + ) + create_payload = put_call.kwargs["json"] assert create_payload["vectors"]["size"] == 4096 diff --git a/tests/test_litellm/caching/test_redis_semantic_cache.py b/tests/test_litellm/caching/test_redis_semantic_cache.py index f9946e266f..b50a35ef50 100644 --- a/tests/test_litellm/caching/test_redis_semantic_cache.py +++ b/tests/test_litellm/caching/test_redis_semantic_cache.py @@ -72,24 +72,301 @@ def test_redis_semantic_cache_get_cache(monkeypatch): "prompt": "What is the capital of France?", "response": '{"content": "Paris is the capital of France."}', "vector_distance": 0.1, # Distance of 0.1 means similarity of 0.9 + 
RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key", } ] redis_semantic_cache.llmcache.check = MagicMock(return_value=mock_result) # Mock the embedding function - with patch( - "litellm.embedding", return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]} + with ( + patch( + "litellm.embedding", + return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]}, + ), + patch.object( + redis_semantic_cache, + "_get_cache_key_filter_expression", + return_value="cache-key-filter", + ), ): # Test get_cache with a message + metadata = {} result = redis_semantic_cache.get_cache( - key="test_key", messages=[{"content": "What is the capital of France?"}] + key="test_key", + messages=[{"content": "What is the capital of France?"}], + metadata=metadata, ) # Verify result is properly parsed assert result == {"content": "Paris is the capital of France."} + assert metadata["semantic-similarity"] == pytest.approx(0.9) # Verify llmcache.check was called - redis_semantic_cache.llmcache.check.assert_called_once() + redis_semantic_cache.llmcache.check.assert_called_once_with( + prompt="What is the capital of France?", + filter_expression="cache-key-filter", + ) + + +def test_redis_semantic_cache_rejects_unscoped_cache_hit(monkeypatch): + semantic_cache_mock = MagicMock() + custom_vectorizer_mock = MagicMock() + + with patch.dict( + "sys.modules", + { + "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock), + "redisvl.utils.vectorize": MagicMock( + CustomTextVectorizer=custom_vectorizer_mock + ), + }, + ): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + monkeypatch.setenv("REDIS_HOST", "localhost") + monkeypatch.setenv("REDIS_PORT", "6379") + monkeypatch.setenv("REDIS_PASSWORD", "test_password") + + redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8) + redis_semantic_cache.llmcache.check = MagicMock( + return_value=[ + { + "prompt": "What is the capital of France?", + "response": '{"content": "Paris"}', + "vector_distance": 0.1, + 
} + ] + ) + + with patch.object( + redis_semantic_cache, + "_get_cache_key_filter_expression", + return_value="cache-key-filter", + ): + metadata = {} + result = redis_semantic_cache.get_cache( + key="test_key", + messages=[{"content": "What is the capital of France?"}], + metadata=metadata, + ) + + assert result is None + assert metadata["semantic-similarity"] == 0.0 + + +def test_redis_semantic_cache_set_cache_stores_cache_key_filter(monkeypatch): + semantic_cache_mock = MagicMock() + custom_vectorizer_mock = MagicMock() + + with patch.dict( + "sys.modules", + { + "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock), + "redisvl.utils.vectorize": MagicMock( + CustomTextVectorizer=custom_vectorizer_mock + ), + }, + ): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + monkeypatch.setenv("REDIS_HOST", "localhost") + monkeypatch.setenv("REDIS_PORT", "6379") + monkeypatch.setenv("REDIS_PASSWORD", "test_password") + + redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8) + redis_semantic_cache.llmcache.store = MagicMock() + + redis_semantic_cache.set_cache( + key="test_key", + value={"content": "Paris"}, + messages=[{"content": "What is the capital of France?"}], + ttl=60, + ) + + redis_semantic_cache.llmcache.store.assert_called_once_with( + "What is the capital of France?", + "{'content': 'Paris'}", + filters={RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key"}, + ttl=60, + ) + + +def test_redis_semantic_cache_uses_isolated_index_for_old_schema(monkeypatch): + fallback_cache_mock = MagicMock() + semantic_cache_mock = MagicMock( + side_effect=[ + ValueError("stored index schema differs from requested fields"), + fallback_cache_mock, + ] + ) + custom_vectorizer_mock = MagicMock() + + with patch.dict( + "sys.modules", + { + "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock), + "redisvl.utils.vectorize": MagicMock( + CustomTextVectorizer=custom_vectorizer_mock + ), + }, + ): + from 
litellm.caching.redis_semantic_cache import RedisSemanticCache + + monkeypatch.setenv("REDIS_HOST", "localhost") + monkeypatch.setenv("REDIS_PORT", "6379") + monkeypatch.setenv("REDIS_PASSWORD", "test_password") + + redis_semantic_cache = RedisSemanticCache( + similarity_threshold=0.8, + index_name="existing_index", + ) + + assert redis_semantic_cache.llmcache is fallback_cache_mock + assert semantic_cache_mock.call_args_list[0].kwargs["name"] == "existing_index" + assert ( + semantic_cache_mock.call_args_list[1].kwargs["name"] + == "existing_index_isolated" + ) + assert semantic_cache_mock.call_args_list[1].kwargs["filterable_fields"] == [ + RedisSemanticCache._cache_key_filterable_field() + ] + + +def test_redis_semantic_cache_overwrites_stale_isolated_index(monkeypatch): + fallback_cache_mock = MagicMock() + semantic_cache_mock = MagicMock( + side_effect=[ + ValueError("Existing index schema does not match"), + ValueError("Existing index schema does not match"), + fallback_cache_mock, + ] + ) + custom_vectorizer_mock = MagicMock() + + with patch.dict( + "sys.modules", + { + "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock), + "redisvl.utils.vectorize": MagicMock( + CustomTextVectorizer=custom_vectorizer_mock + ), + }, + ): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + monkeypatch.setenv("REDIS_HOST", "localhost") + monkeypatch.setenv("REDIS_PORT", "6379") + monkeypatch.setenv("REDIS_PASSWORD", "test_password") + + redis_semantic_cache = RedisSemanticCache( + similarity_threshold=0.8, + index_name="existing_index", + ) + + assert redis_semantic_cache.llmcache is fallback_cache_mock + assert ( + semantic_cache_mock.call_args_list[2].kwargs["name"] + == "existing_index_isolated" + ) + assert semantic_cache_mock.call_args_list[2].kwargs["overwrite"] is True + assert semantic_cache_mock.call_args_list[2].kwargs["filterable_fields"] == [ + RedisSemanticCache._cache_key_filterable_field() + ] + + +def 
test_redis_semantic_cache_reraises_unexpected_isolated_index_error(monkeypatch): + semantic_cache_mock = MagicMock( + side_effect=[ + ValueError("Existing index schema does not match"), + ValueError("connection failed"), + ] + ) + custom_vectorizer_mock = MagicMock() + + with patch.dict( + "sys.modules", + { + "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock), + "redisvl.utils.vectorize": MagicMock( + CustomTextVectorizer=custom_vectorizer_mock + ), + }, + ): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + monkeypatch.setenv("REDIS_HOST", "localhost") + monkeypatch.setenv("REDIS_PORT", "6379") + monkeypatch.setenv("REDIS_PASSWORD", "test_password") + + with pytest.raises(ValueError, match="connection failed"): + RedisSemanticCache( + similarity_threshold=0.8, + index_name="existing_index", + ) + + +def test_redis_semantic_cache_reraises_unexpected_index_error(): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + redis_semantic_cache = RedisSemanticCache.__new__(RedisSemanticCache) + redis_semantic_cache.distance_threshold = 0.2 + semantic_cache_mock = MagicMock(side_effect=ValueError("connection failed")) + + with pytest.raises(ValueError, match="connection failed"): + redis_semantic_cache._init_semantic_cache( + semantic_cache_cls=semantic_cache_mock, + index_name="existing_index", + redis_url="redis://localhost:6379", + cache_vectorizer=MagicMock(), + ) + + +def test_redis_semantic_cache_matches_bytes_cache_key(): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + redis_semantic_cache = RedisSemanticCache.__new__(RedisSemanticCache) + + assert redis_semantic_cache._cache_hit_matches_key( + cache_hit={RedisSemanticCache.CACHE_KEY_FIELD_NAME: b"test_key"}, + key="test_key", + ) + + +def test_redis_semantic_cache_rejects_pre_isolation_unscoped_hit(): + """Pre-isolation entries with no cache-key field cannot be safely + reassigned to a caller's scope and are treated as 
misses.""" + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + redis_semantic_cache = RedisSemanticCache.__new__(RedisSemanticCache) + + cache_hit = { + "prompt": "What is the capital of France?", + "response": '{"content": "Paris"}', + "vector_distance": 0.1, + } + assert not redis_semantic_cache._cache_hit_matches_key( + cache_hit=cache_hit, + key="test_key", + ) + + +def test_redis_semantic_cache_builds_filter_expression(monkeypatch): + class FakeTag: + def __init__(self, field_name): + self.field_name = field_name + + def __eq__(self, value): + return (self.field_name, value) + + with patch.dict("sys.modules", {"redisvl.query.filter": MagicMock(Tag=FakeTag)}): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + redis_semantic_cache = RedisSemanticCache.__new__(RedisSemanticCache) + + assert redis_semantic_cache._get_cache_key_filter_expression("test_key") == ( + RedisSemanticCache.CACHE_KEY_FIELD_NAME, + "test_key", + ) @pytest.mark.asyncio @@ -123,6 +400,7 @@ async def test_redis_semantic_cache_async_get_cache(monkeypatch): "prompt": "What is the capital of France?", "response": '{"content": "Paris is the capital of France."}', "vector_distance": 0.1, # Distance of 0.1 means similarity of 0.9 + RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key", } ] @@ -131,16 +409,117 @@ async def test_redis_semantic_cache_async_get_cache(monkeypatch): return_value=[0.1, 0.2, 0.3] ) - # Test async_get_cache with a message - result = await redis_semantic_cache.async_get_cache( - key="test_key", - messages=[{"content": "What is the capital of France?"}], - metadata={}, - ) + with patch.object( + redis_semantic_cache, + "_get_cache_key_filter_expression", + return_value="cache-key-filter", + ): + # Test async_get_cache with a message + result = await redis_semantic_cache.async_get_cache( + key="test_key", + messages=[{"content": "What is the capital of France?"}], + metadata={}, + ) # Verify result is properly parsed assert result == 
{"content": "Paris is the capital of France."} # Verify methods were called redis_semantic_cache._get_async_embedding.assert_called_once() - redis_semantic_cache.llmcache.acheck.assert_called_once() + redis_semantic_cache.llmcache.acheck.assert_called_once_with( + prompt="What is the capital of France?", + vector=[0.1, 0.2, 0.3], + filter_expression="cache-key-filter", + ) + + +@pytest.mark.asyncio +async def test_redis_semantic_cache_async_get_cache_rejects_unscoped_hit(monkeypatch): + semantic_cache_mock = MagicMock() + custom_vectorizer_mock = MagicMock() + + with patch.dict( + "sys.modules", + { + "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock), + "redisvl.utils.vectorize": MagicMock( + CustomTextVectorizer=custom_vectorizer_mock + ), + }, + ): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + monkeypatch.setenv("REDIS_HOST", "localhost") + monkeypatch.setenv("REDIS_PORT", "6379") + monkeypatch.setenv("REDIS_PASSWORD", "test_password") + + redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8) + redis_semantic_cache.llmcache.acheck = AsyncMock( + return_value=[ + { + "prompt": "What is the capital of France?", + "response": '{"content": "Paris"}', + "vector_distance": 0.1, + } + ] + ) + redis_semantic_cache._get_async_embedding = AsyncMock( + return_value=[0.1, 0.2, 0.3] + ) + + with patch.object( + redis_semantic_cache, + "_get_cache_key_filter_expression", + return_value="cache-key-filter", + ): + result = await redis_semantic_cache.async_get_cache( + key="test_key", + messages=[{"content": "What is the capital of France?"}], + metadata={}, + ) + + assert result is None + + +@pytest.mark.asyncio +async def test_redis_semantic_cache_async_set_cache_stores_cache_key_filter( + monkeypatch, +): + semantic_cache_mock = MagicMock() + custom_vectorizer_mock = MagicMock() + + with patch.dict( + "sys.modules", + { + "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock), + 
"redisvl.utils.vectorize": MagicMock( + CustomTextVectorizer=custom_vectorizer_mock + ), + }, + ): + from litellm.caching.redis_semantic_cache import RedisSemanticCache + + monkeypatch.setenv("REDIS_HOST", "localhost") + monkeypatch.setenv("REDIS_PORT", "6379") + monkeypatch.setenv("REDIS_PASSWORD", "test_password") + + redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8) + redis_semantic_cache.llmcache.astore = AsyncMock() + redis_semantic_cache._get_async_embedding = AsyncMock( + return_value=[0.1, 0.2, 0.3] + ) + + await redis_semantic_cache.async_set_cache( + key="test_key", + value={"content": "Paris"}, + messages=[{"content": "What is the capital of France?"}], + ttl=60, + ) + + redis_semantic_cache.llmcache.astore.assert_called_once_with( + "What is the capital of France?", + "{'content': 'Paris'}", + vector=[0.1, 0.2, 0.3], + filters={RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key"}, + ttl=60, + ) diff --git a/tests/test_litellm/proxy/auth/test_auth_utils.py b/tests/test_litellm/proxy/auth/test_auth_utils.py index c146b5ded5..7c04a4f61f 100644 --- a/tests/test_litellm/proxy/auth/test_auth_utils.py +++ b/tests/test_litellm/proxy/auth/test_auth_utils.py @@ -1292,3 +1292,219 @@ class TestIsRequestBodySafeNestedConfig: ) is True ) + + +# ── observability-callback ban (root + metadata) ─────────────────────────── + + +class TestObservabilityCallbackBans: + """The proxy must reject observability credentials, hosts, and project + identifiers regardless of whether they arrive at the request body root, + in ``metadata`` / ``litellm_metadata``, or in a JSON-string-encoded + metadata blob (multipart/``extra_body`` path). + + The ban list is derived from + ``litellm.litellm_core_utils.initialize_dynamic_callback_params._supported_callback_params`` + minus a small ``_SAFE_CLIENT_CALLBACK_PARAMS`` allow-list, plus + ``_EXTRA_BANNED_OBSERVABILITY_PARAMS`` for fields integrations read but + that are not yet in the canonical allow-list. 
The derivation keeps the + proxy in sync as new integrations are added. + """ + + @pytest.fixture(autouse=True) + def _disable_url_validation(self, monkeypatch): + import litellm + + monkeypatch.setattr(litellm, "user_url_validation", False, raising=False) + + @pytest.mark.parametrize( + "field", + [ + "langfuse_public_key", + "langfuse_secret", + "langfuse_secret_key", + "langsmith_api_key", + "langsmith_project", + "langsmith_tenant_id", + "arize_api_key", + "arize_space_key", + "arize_space_id", + "posthog_api_key", + "posthog_api_url", + "braintrust_api_key", + "braintrust_project", + "phoenix_project_name", + "wandb_api_key", + "weave_project_id", + "gcs_bucket_name", + "gcs_path_service_account", + "humanloop_api_key", + "lunary_public_key", + ], + ) + def test_observability_field_in_request_body_root_is_rejected(self, field): + with pytest.raises(ValueError) as exc: + is_request_body_safe( + request_body={"model": "gpt-4", field: "attacker-value"}, + general_settings={}, + llm_router=None, + model="gpt-4", + ) + assert field in str(exc.value) + + @pytest.mark.parametrize( + "metadata_key", + ["metadata", "litellm_metadata"], + ) + @pytest.mark.parametrize( + "field", + [ + "langfuse_host", + "langfuse_secret_key", + "langsmith_api_key", + "posthog_api_url", + "braintrust_project", + "phoenix_project_name", + ], + ) + def test_observability_field_in_metadata_dict_is_rejected( + self, metadata_key, field + ): + # Verifies the metadata walk: a value smuggled inside ``metadata`` + # or ``litellm_metadata`` is just as dangerous as the same field + # at the body root, and must hit the same gate. 
+ with pytest.raises(ValueError) as exc: + is_request_body_safe( + request_body={ + "model": "gpt-4", + metadata_key: {field: "attacker-value"}, + }, + general_settings={}, + llm_router=None, + model="gpt-4", + ) + assert field in str(exc.value) + + @pytest.mark.parametrize( + "metadata_key", + ["metadata", "litellm_metadata"], + ) + def test_observability_field_in_json_string_metadata_is_rejected( + self, metadata_key + ): + # Multipart/form-data and ``extra_body`` callers send metadata as a + # JSON-encoded string. The bouncer parses it before applying the + # banned-params check so the JSON-string path can't smuggle past + # the ``isinstance(dict)`` guard. + import json + + with pytest.raises(ValueError) as exc: + is_request_body_safe( + request_body={ + "model": "gpt-4", + metadata_key: json.dumps( + {"langfuse_host": "https://attacker.example"} + ), + }, + general_settings={}, + llm_router=None, + model="gpt-4", + ) + assert "langfuse_host" in str(exc.value) + + def test_admin_opt_in_allows_metadata_credential_passthrough(self): + # The opt-in gate covers the metadata path the same way it covers + # the root path — operators running BYO observability with + # clientside creds flip a single flag and both paths work. + assert ( + is_request_body_safe( + request_body={ + "model": "gpt-4", + "metadata": { + "langfuse_host": "https://my-langfuse.example", + "langfuse_public_key": "pk-mine", + "langfuse_secret_key": "sk-mine", + }, + }, + general_settings={"allow_client_side_credentials": True}, + llm_router=None, + model="gpt-4", + ) + is True + ) + + def test_safe_per_request_observability_metadata_is_allowed(self): + # Informational fields (sampling rate, prompt version) describe + # the request being logged — they don't choose the destination or + # credentials, so they must remain accepted from clients without + # the opt-in flag. 
+ assert ( + is_request_body_safe( + request_body={ + "model": "gpt-4", + "metadata": { + "langfuse_prompt_version": "v2", + "langsmith_sampling_rate": 0.1, + }, + }, + general_settings={}, + llm_router=None, + model="gpt-4", + ) + is True + ) + + +def test_model_level_allow_does_not_skip_subsequent_banned_params(monkeypatch): + """Greptile P1: ``_check_banned_params`` previously ``return``-ed when a + deployment's ``configurable_clientside_auth_params`` permitted one + banned field, exiting before any later banned field in the same body + was checked. The metadata walk this PR adds multiplies the surface + where that bypass matters: a body pairing a model-level-allowed + ``api_base`` with an observability credential like ``langfuse_host`` + must still reject on the second field, not silently pass.""" + from litellm.proxy.auth import auth_utils + + monkeypatch.setattr( + auth_utils, + "_allow_model_level_clientside_configurable_parameters", + lambda model, param, request_body_value, llm_router: param == "api_base", + ) + + with pytest.raises(ValueError) as exc: + is_request_body_safe( + request_body={ + "model": "gpt-4", + "api_base": "https://allowed-by-deployment.example", + "langfuse_host": "https://attacker.example", + }, + general_settings={}, + llm_router=None, + model="gpt-4", + ) + assert "langfuse_host" in str(exc.value) + + +def test_observability_ban_covers_canonical_supported_callback_params(): + """Guard test: every entry in the canonical + ``_supported_callback_params`` allow-list must end up either banned by + the proxy or explicitly safe-listed. 
New integrations added to that + list are banned by default (the safe failure mode); flagging them as + safe is an explicit decision recorded in + ``_SAFE_CLIENT_CALLBACK_PARAMS``.""" + from litellm.litellm_core_utils.initialize_dynamic_callback_params import ( + _supported_callback_params, + ) + from litellm.proxy.auth.auth_utils import ( + _BANNED_REQUEST_BODY_PARAMS, + _SAFE_CLIENT_CALLBACK_PARAMS, + ) + + banned = set(_BANNED_REQUEST_BODY_PARAMS) + for param in _supported_callback_params: + assert param in banned or param in _SAFE_CLIENT_CALLBACK_PARAMS, ( + f"{param} is in _supported_callback_params but neither banned nor " + f"safe-listed. Add it to _SAFE_CLIENT_CALLBACK_PARAMS if it is an " + f"informational per-request field; otherwise the derivation will " + f"ban it automatically." + ) diff --git a/tests/test_litellm/proxy/test_litellm_pre_call_utils.py b/tests/test_litellm/proxy/test_litellm_pre_call_utils.py index 4e24d8af65..d2a1468be2 100644 --- a/tests/test_litellm/proxy/test_litellm_pre_call_utils.py +++ b/tests/test_litellm/proxy/test_litellm_pre_call_utils.py @@ -596,6 +596,59 @@ async def test_add_litellm_data_to_request_strips_user_control_fields(): assert "pillar_response_headers" not in snapshot_body["metadata"] +@pytest.mark.asyncio +@pytest.mark.parametrize( + "control_field", + ["callbacks", "service_callback", "logger_fn", "litellm_disabled_callbacks"], +) +async def test_add_litellm_data_to_request_strips_callback_control_fields( + control_field, +): + """``callbacks`` / ``service_callback`` / ``logger_fn`` get appended to + the worker-wide ``litellm.{input,success,failure,_async_*,service}_callback`` + lists and ``litellm.user_logger_fn`` from inside ``function_setup`` — + one request poisons every subsequent caller in that worker. + ``litellm_disabled_callbacks`` is the inverse: a request-body value + silently disables admin-configured audit/observability for the call. 
+ None has a documented per-request use, so all four are stripped at + the proxy boundary alongside the existing internal-only fields.""" + request_mock = MagicMock(spec=Request) + request_mock.url.path = "/v1/chat/completions" + request_mock.url = MagicMock() + request_mock.url.__str__.return_value = "http://localhost/v1/chat/completions" + request_mock.method = "POST" + request_mock.query_params = {} + request_mock.headers = {"Content-Type": "application/json"} + request_mock.client = MagicMock() + request_mock.client.host = "127.0.0.1" + + sample_value = ( + ["langfuse"] + if control_field + in ("callbacks", "service_callback", "litellm_disabled_callbacks") + else "module.func" + ) + + updated = await add_litellm_data_to_request( + data={ + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "hi"}], + control_field: sample_value, + }, + request=request_mock, + user_api_key_dict=UserAPIKeyAuth(api_key="hashed-key"), + proxy_config=MagicMock(), + general_settings={}, + version="test-version", + ) + + assert control_field not in updated + # The post-strip body snapshot used by audit/spend logging must also + # not retain the attacker-injected control field. 
+ snapshot_body = updated["proxy_server_request"]["body"] + assert control_field not in snapshot_body + + @pytest.mark.asyncio async def test_add_litellm_data_to_request_allows_client_mock_response_with_admin_opt_in(): request_mock = MagicMock(spec=Request) diff --git a/tests/test_litellm/proxy/vector_store_endpoints/test_vector_store_endpoints.py b/tests/test_litellm/proxy/vector_store_endpoints/test_vector_store_endpoints.py index e67a04c749..81b67e8bc5 100644 --- a/tests/test_litellm/proxy/vector_store_endpoints/test_vector_store_endpoints.py +++ b/tests/test_litellm/proxy/vector_store_endpoints/test_vector_store_endpoints.py @@ -36,6 +36,33 @@ from litellm.proxy.vector_store_endpoints.utils import ( from litellm.types.utils import LlmProviders +def _serialize_litellm_params(litellm_params): + """Serialize ``litellm_params`` to a string for substring assertions. + + The redact helper preserves the persisted shape — string in, string + out; dict in, dict out — so callers that just want to assert "this + secret never appears" need a single text representation either way. + """ + import json + + if isinstance(litellm_params, str): + return litellm_params + return json.dumps(litellm_params or {}) + + +@pytest.fixture(autouse=True) +def _reset_embedding_config_cache(): + """The use-time embedding-config resolver caches results in process + memory across calls. 
Reset it before every test so the resolver + actually exercises the router/DB path under test instead of returning + a value cached by an earlier test.""" + from litellm.proxy.vector_store_endpoints import management_endpoints + + management_endpoints._embedding_config_cache = None + yield + management_endpoints._embedding_config_cache = None + + @pytest.mark.asyncio async def test_router_avector_store_search_passes_correct_args(): """ @@ -170,6 +197,93 @@ async def test_update_request_data_with_litellm_managed_vector_store_registry(): assert result == original_data +@pytest.mark.asyncio +async def test_update_request_data_resolves_embedding_config_at_use_time(): + """When the persisted vector store row carries only a + ``litellm_embedding_model`` reference (the new behaviour after + moving the auto-resolve out of write time), the request-handling + layer must resolve the embedding config so the downstream embed + call still has ``api_key`` / ``api_base`` / ``api_version``. The + resolved config lives in this per-request data dict only — never + persisted.""" + mock_vector_store: LiteLLM_ManagedVectorStore = { + "vector_store_id": "test_store", + "custom_llm_provider": "azure_ai", + "litellm_params": { + "litellm_embedding_model": "azure/text-embedding-3-large", + # Note: no litellm_embedding_config persisted + }, + } + + mock_registry = MagicMock() + mock_registry.get_litellm_managed_vector_store_from_registry.return_value = ( + mock_vector_store + ) + + resolved = { + "api_key": "use-time-resolved-key", + "api_base": "https://my-azure.example", + "api_version": "2024-09-01", + } + + with ( + patch.object(litellm, "vector_store_registry", mock_registry), + patch( + "litellm.proxy.vector_store_endpoints.endpoints._resolve_embedding_config", + new=AsyncMock(return_value=resolved), + ), + ): + result = await _update_request_data_with_litellm_managed_vector_store_registry( + data={}, vector_store_id="test_store" + ) + + assert result["litellm_embedding_model"] == 
"azure/text-embedding-3-large" + assert result["litellm_embedding_config"] == resolved + + +@pytest.mark.asyncio +async def test_update_request_data_passes_through_legacy_embedding_config(): + """A vector store row created by an older proxy version may already + carry a fully-resolved ``litellm_embedding_config`` in its persisted + ``litellm_params`` (the very leak this PR closes). Those legacy rows + must still work — the use-time resolver skips re-resolution when + the config is already present so the embed call keeps succeeding.""" + legacy_config = { + "api_key": "legacy-cleartext-key", + "api_base": "https://legacy-azure.example", + "api_version": "2024-01-01", + } + mock_vector_store: LiteLLM_ManagedVectorStore = { + "vector_store_id": "legacy_store", + "custom_llm_provider": "azure_ai", + "litellm_params": { + "litellm_embedding_model": "azure/text-embedding-3-large", + "litellm_embedding_config": legacy_config, + }, + } + + mock_registry = MagicMock() + mock_registry.get_litellm_managed_vector_store_from_registry.return_value = ( + mock_vector_store + ) + + resolve_mock = AsyncMock() + + with ( + patch.object(litellm, "vector_store_registry", mock_registry), + patch( + "litellm.proxy.vector_store_endpoints.endpoints._resolve_embedding_config", + new=resolve_mock, + ), + ): + result = await _update_request_data_with_litellm_managed_vector_store_registry( + data={}, vector_store_id="legacy_store" + ) + + assert result["litellm_embedding_config"] == legacy_config + resolve_mock.assert_not_awaited() + + class TestCheckVectorStorePermission: """Test suite for check_vector_store_permission function.""" @@ -1417,20 +1531,25 @@ async def test_new_vector_store_auto_resolves_embedding_config(): ) assert result["status"] == "success" - # Verify that embedding config was resolved and included in the create call + # Auto-resolve no longer happens at create time — the persisted row + # carries only the model reference, never the resolved cleartext + # credential. 
Resolution now happens at request-handling time inside + # ``_update_request_data_with_litellm_managed_vector_store_registry``, + # where the resolved config lives in per-request memory and is never + # written to the database. litellm_params_json = captured_create_data.get("litellm_params") assert litellm_params_json is not None litellm_params_dict = json.loads(litellm_params_json) - assert "litellm_embedding_config" in litellm_params_dict - assert ( - litellm_params_dict["litellm_embedding_config"]["api_key"] == "resolved-api-key" - ) - assert ( - litellm_params_dict["litellm_embedding_config"]["api_base"] - == "https://api.openai.com" - ) - assert ( - litellm_params_dict["litellm_embedding_config"]["api_version"] == "2024-01-01" + assert "litellm_embedding_config" not in litellm_params_dict + assert litellm_params_dict["litellm_embedding_model"] == "text-embedding-ada-002" + + # The response must also not echo a cleartext credential — even on + # the create response, where redaction guards against caller-supplied + # cleartext or pre-existing rows that were created by an earlier + # proxy version. 
+ response_vs = result["vector_store"] + assert "resolved-api-key" not in _serialize_litellm_params( + response_vs.get("litellm_params") ) @@ -1578,6 +1697,43 @@ async def test_resolve_embedding_config_tries_router_then_db(): mock_prisma_client.db.litellm_proxymodeltable.find_first.assert_not_called() +@pytest.mark.asyncio +async def test_resolve_embedding_config_caches_result(): + """The first lookup should hit the router/DB; subsequent lookups for + the same model name should return the cached value without touching + the router or the database.""" + from litellm.types.router import Deployment, LiteLLM_Params + + mock_prisma_client = MagicMock() + mock_router = MagicMock() + + mock_litellm_params = MagicMock(spec=LiteLLM_Params) + mock_litellm_params.api_key = "router-api-key" + mock_litellm_params.api_base = "https://router-api-base.com" + mock_litellm_params.api_version = None + + mock_deployment = MagicMock(spec=Deployment) + mock_deployment.litellm_params = mock_litellm_params + mock_router.get_deployment_by_model_group_name.return_value = mock_deployment + + first = await _resolve_embedding_config( + embedding_model="cached-model", + prisma_client=mock_prisma_client, + llm_router=mock_router, + ) + assert first is not None + assert mock_router.get_deployment_by_model_group_name.call_count == 1 + + second = await _resolve_embedding_config( + embedding_model="cached-model", + prisma_client=mock_prisma_client, + llm_router=mock_router, + ) + assert second == first + # Router (and by extension the DB) was not consulted again. 
+ assert mock_router.get_deployment_by_model_group_name.call_count == 1 + + @pytest.mark.asyncio async def test_resolve_embedding_config_falls_back_to_db(): """Test that _resolve_embedding_config falls back to DB when router doesn't have the model.""" @@ -1687,21 +1843,18 @@ async def test_new_vector_store_auto_resolves_from_router(): ) assert result["status"] == "success" - # Verify that embedding config was resolved from router and included in the create call + # Resolution against the router happens at request-handling time now, + # not at row creation. The persisted ``litellm_params`` carries only + # the model reference, never the cleartext credential. litellm_params_json = captured_create_data.get("litellm_params") assert litellm_params_json is not None litellm_params_dict = json.loads(litellm_params_json) - assert "litellm_embedding_config" in litellm_params_dict - assert ( - litellm_params_dict["litellm_embedding_config"]["api_key"] - == "router-resolved-api-key" - ) - assert ( - litellm_params_dict["litellm_embedding_config"]["api_base"] - == "https://router-resolved-base.com" - ) - assert ( - litellm_params_dict["litellm_embedding_config"]["api_version"] == "2024-03-01" + assert "litellm_embedding_config" not in litellm_params_dict + assert litellm_params_dict["litellm_embedding_model"] == "config-embedding-model" + + response_vs = result["vector_store"] + assert "router-resolved-api-key" not in _serialize_litellm_params( + response_vs.get("litellm_params") )