Merge remote-tracking branch 'upstream/litellm_internal_staging' into codex/skills-containers-tenant-guard

This commit is contained in:
user 2026-05-04 23:50:29 +00:00
commit b5a14f22d6
No known key found for this signature in database
11 changed files with 1528 additions and 204 deletions

View File

@ -11,17 +11,23 @@ Has 4 methods:
import ast
import asyncio
import json
from typing import Any, cast
import os
from typing import Any, Dict, cast
import litellm
from litellm._logging import print_verbose
from litellm.constants import QDRANT_SCALAR_QUANTILE, QDRANT_VECTOR_SIZE
from litellm.litellm_core_utils.prompt_templates.common_utils import (
get_str_from_messages,
)
from litellm.types.utils import EmbeddingResponse
from .base_cache import BaseCache
class QdrantSemanticCache(BaseCache):
CACHE_KEY_FIELD_NAME = "litellm_cache_key"
def __init__( # noqa: PLR0915
self,
qdrant_api_base=None,
@ -33,8 +39,6 @@ class QdrantSemanticCache(BaseCache):
host_type=None,
vector_size=None,
):
import os
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
@ -115,7 +119,9 @@ class QdrantSemanticCache(BaseCache):
print_verbose(
f"Collection already exists.\nCollection details:{self.collection_info}"
)
self._ensure_cache_key_payload_index()
else:
quantization_params: Dict[str, Any]
if quantization_config is None or quantization_config == "binary":
quantization_params = {
"binary": {
@ -156,6 +162,7 @@ class QdrantSemanticCache(BaseCache):
print_verbose(
f"New collection created.\nCollection details:{self.collection_info}"
)
self._ensure_cache_key_payload_index()
else:
raise Exception("Error while creating new collection")
@ -170,15 +177,94 @@ class QdrantSemanticCache(BaseCache):
cached_response = ast.literal_eval(cached_response)
return cached_response
def _get_qdrant_cache_key_filter(self, key: str) -> dict:
return {
"must": [
{
"key": self.CACHE_KEY_FIELD_NAME,
"match": {"value": str(key)},
}
]
}
def _add_cache_key_filter_to_search_data(self, data: dict, key: str) -> None:
data["filter"] = self._get_qdrant_cache_key_filter(key)
def _ensure_cache_key_payload_index(self) -> None:
    """Ask Qdrant to create a ``keyword`` payload index on the cache-key field.

    Best-effort: a response whose status code is not 200/201, or any
    exception raised while sending the request, is only reported through
    ``print_verbose``. Nothing is raised to the caller.
    """
    failure_prefix = "Qdrant semantic-cache could not create cache-key payload index: "
    try:
        put = self.sync_client.put
        index_url = f"{self.qdrant_api_base}/collections/{self.collection_name}/index"
        response = put(
            url=index_url,
            headers=self.headers,
            json={
                "field_name": self.CACHE_KEY_FIELD_NAME,
                "field_schema": "keyword",
            },
        )
        if response.status_code in (200, 201):
            return
        print_verbose(f"{failure_prefix}{response.text}")
    except Exception as exc:
        # Index creation must never break cache construction; log and move on.
        print_verbose(f"{failure_prefix}{str(exc)}")
def _payload_matches_cache_key(self, payload: dict, key: str) -> bool:
# Pre-isolation points stored only prompt + response with no cache-key
# payload field. Reassigning them to a caller's key would risk
# cross-scope hits, so they're treated as misses and re-populated on
# the next set_cache.
cached_key = payload.get(self.CACHE_KEY_FIELD_NAME)
return cached_key is not None and str(cached_key) == str(key)
async def _get_async_embedding(self, prompt: str, **kwargs) -> Any:
    """Embed ``prompt`` with ``self.embedding_model``.

    If ``litellm.proxy.proxy_server`` is importable, exposes a router, and
    ``self.embedding_model`` is one of the names in its model list, the call
    goes through ``llm_router.aembedding`` and forwards the caller's
    ``user_api_key`` and ``trace_id`` in the request metadata. Otherwise it
    falls back to ``litellm.aembedding``. Both paths pass
    ``cache={"no-store": True, "no-cache": True}``.

    Args:
        prompt: Text to embed.
        **kwargs: Only ``metadata`` is read; it may be absent or ``None``.

    Returns:
        Whatever the selected ``aembedding`` call returns.
    """
    llm_model_list = None
    llm_router = None
    try:
        from litellm.proxy.proxy_server import (
            llm_model_list as proxy_llm_model_list,
            llm_router as proxy_llm_router,
        )

        llm_model_list = proxy_llm_model_list
        llm_router = proxy_llm_router
    except ImportError:
        # Import failure is tolerated: with no router to consult, the direct
        # ``litellm.aembedding`` path below is used.
        pass

    router_model_names = (
        [m["model_name"] for m in llm_model_list]
        if llm_model_list is not None
        else []
    )
    if llm_router is not None and self.embedding_model in router_model_names:
        # ``kwargs.get("metadata", {})`` only falls back when the key is
        # absent; an explicit ``metadata=None`` would crash on ``.get``.
        metadata = kwargs.get("metadata") or {}
        return await llm_router.aembedding(
            model=self.embedding_model,
            input=prompt,
            cache={"no-store": True, "no-cache": True},
            metadata={
                "user_api_key": metadata.get("user_api_key", ""),
                "semantic-cache-embedding": True,
                "trace_id": metadata.get("trace_id", None),
            },
        )

    return await litellm.aembedding(
        model=self.embedding_model,
        input=prompt,
        cache={"no-store": True, "no-cache": True},
    )
def set_cache(self, key, value, **kwargs):
print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
from litellm._uuid import uuid
# get the prompt
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
prompt = get_str_from_messages(messages)
# create an embedding for prompt
embedding_response = cast(
@ -202,6 +288,7 @@ class QdrantSemanticCache(BaseCache):
"id": str(uuid.uuid4()),
"vector": embedding,
"payload": {
self.CACHE_KEY_FIELD_NAME: str(key),
"text": prompt,
"response": value,
},
@ -220,9 +307,7 @@ class QdrantSemanticCache(BaseCache):
# get the messages
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
prompt = get_str_from_messages(messages)
# convert to embedding
embedding_response = cast(
@ -249,6 +334,7 @@ class QdrantSemanticCache(BaseCache):
"limit": 1,
"with_payload": True,
}
self._add_cache_key_filter_to_search_data(data=data, key=key)
search_response = self.sync_client.post(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
@ -258,21 +344,33 @@ class QdrantSemanticCache(BaseCache):
results = search_response.json()["result"]
if results is None:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
if isinstance(results, list):
if len(results) == 0:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
similarity = results[0]["score"]
cached_prompt = results[0]["payload"]["text"]
payload = results[0]["payload"]
if not self._payload_matches_cache_key(payload=payload, key=key):
print_verbose("Qdrant semantic-cache hit did not match cache key scope")
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
cached_prompt = payload["text"]
# check similarity, if more than self.similarity_threshold, return results
print_verbose(
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
)
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
if similarity >= self.similarity_threshold:
# cache hit !
cached_value = results[0]["payload"]["response"]
cached_value = payload["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)
@ -285,40 +383,12 @@ class QdrantSemanticCache(BaseCache):
async def async_set_cache(self, key, value, **kwargs):
from litellm._uuid import uuid
from litellm.proxy.proxy_server import llm_model_list, llm_router
print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")
# get the prompt
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
# create an embedding for prompt
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
else []
)
if llm_router is not None and self.embedding_model in router_model_names:
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
embedding_response = await llm_router.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
metadata={
"user_api_key": user_api_key,
"semantic-cache-embedding": True,
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
},
)
else:
# convert to embedding
embedding_response = await litellm.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
prompt = get_str_from_messages(messages)
embedding_response = await self._get_async_embedding(prompt, **kwargs)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
@ -332,6 +402,7 @@ class QdrantSemanticCache(BaseCache):
"id": str(uuid.uuid4()),
"vector": embedding,
"payload": {
self.CACHE_KEY_FIELD_NAME: str(key),
"text": prompt,
"response": value,
},
@ -348,38 +419,12 @@ class QdrantSemanticCache(BaseCache):
async def async_get_cache(self, key, **kwargs):
print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
from litellm.proxy.proxy_server import llm_model_list, llm_router
# get the messages
messages = kwargs["messages"]
prompt = ""
for message in messages:
prompt += message["content"]
prompt = get_str_from_messages(messages)
router_model_names = (
[m["model_name"] for m in llm_model_list]
if llm_model_list is not None
else []
)
if llm_router is not None and self.embedding_model in router_model_names:
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
embedding_response = await llm_router.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
metadata={
"user_api_key": user_api_key,
"semantic-cache-embedding": True,
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
},
)
else:
# convert to embedding
embedding_response = await litellm.aembedding(
model=self.embedding_model,
input=prompt,
cache={"no-store": True, "no-cache": True},
)
embedding_response = await self._get_async_embedding(prompt, **kwargs)
# get the embedding
embedding = embedding_response["data"][0]["embedding"]
@ -396,6 +441,7 @@ class QdrantSemanticCache(BaseCache):
"limit": 1,
"with_payload": True,
}
self._add_cache_key_filter_to_search_data(data=data, key=key)
search_response = await self.async_client.post(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
@ -414,7 +460,13 @@ class QdrantSemanticCache(BaseCache):
return None
similarity = results[0]["score"]
cached_prompt = results[0]["payload"]["text"]
payload = results[0]["payload"]
if not self._payload_matches_cache_key(payload=payload, key=key):
print_verbose("Qdrant semantic-cache hit did not match cache key scope")
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
cached_prompt = payload["text"]
# check similarity, if more than self.similarity_threshold, return results
print_verbose(
@ -426,7 +478,7 @@ class QdrantSemanticCache(BaseCache):
if similarity >= self.similarity_threshold:
# cache hit !
cached_value = results[0]["payload"]["response"]
cached_value = payload["response"]
print_verbose(
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
)

View File

@ -35,6 +35,7 @@ class RedisSemanticCache(BaseCache):
"""
DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index"
CACHE_KEY_FIELD_NAME: str = "litellm_cache_key"
def __init__(
self,
@ -66,8 +67,8 @@ class RedisSemanticCache(BaseCache):
Exception: If similarity_threshold is not provided or required Redis
connection information is missing
"""
from redisvl.extensions.llmcache import SemanticCache
from redisvl.utils.vectorize import CustomTextVectorizer
from redisvl.extensions.llmcache import SemanticCache # type: ignore[import-not-found, import-untyped]
from redisvl.utils.vectorize import CustomTextVectorizer # type: ignore[import-not-found, import-untyped]
if index_name is None:
index_name = self.DEFAULT_REDIS_INDEX_NAME
@ -109,14 +110,94 @@ class RedisSemanticCache(BaseCache):
# Initialize the Redis vectorizer and cache
cache_vectorizer = CustomTextVectorizer(self._get_embedding)
self.llmcache = SemanticCache(
name=index_name,
self.llmcache = self._init_semantic_cache(
semantic_cache_cls=SemanticCache,
index_name=index_name,
redis_url=redis_url,
vectorizer=cache_vectorizer,
distance_threshold=self.distance_threshold,
overwrite=False,
cache_vectorizer=cache_vectorizer,
)
@classmethod
def _cache_key_filterable_field(cls) -> Dict[str, str]:
return {
"name": cls.CACHE_KEY_FIELD_NAME,
"type": "tag",
}
def _init_semantic_cache(
self,
semantic_cache_cls: Any,
index_name: str,
redis_url: str,
cache_vectorizer: Any,
) -> Any:
def _is_schema_mismatch(exc: ValueError) -> bool:
error_message = str(exc).lower()
return any(
phrase in error_message
for phrase in ("schema does not match", "index schema")
)
try:
return semantic_cache_cls(
name=index_name,
redis_url=redis_url,
vectorizer=cache_vectorizer,
distance_threshold=self.distance_threshold,
filterable_fields=[self._cache_key_filterable_field()],
overwrite=False,
)
except ValueError as exc:
if not _is_schema_mismatch(exc):
raise
isolated_index_name = f"{index_name}_isolated"
print_verbose(
"Redis semantic-cache existing index schema is not isolated; "
f"using isolated index - {isolated_index_name}"
)
try:
return semantic_cache_cls(
name=isolated_index_name,
redis_url=redis_url,
vectorizer=cache_vectorizer,
distance_threshold=self.distance_threshold,
filterable_fields=[self._cache_key_filterable_field()],
overwrite=False,
)
except ValueError as isolated_exc:
if not _is_schema_mismatch(isolated_exc):
raise
print_verbose(
"Redis semantic-cache isolated index schema is stale; "
f"recreating isolated index - {isolated_index_name}"
)
return semantic_cache_cls(
name=isolated_index_name,
redis_url=redis_url,
vectorizer=cache_vectorizer,
distance_threshold=self.distance_threshold,
filterable_fields=[self._cache_key_filterable_field()],
overwrite=True,
)
def _get_cache_filters(self, key: str) -> Dict[str, str]:
return {self.CACHE_KEY_FIELD_NAME: str(key)}
def _get_cache_key_filter_expression(self, key: str) -> Any:
    """Build the redisvl ``Tag`` filter expression for one cache key.

    Args:
        key: Cache key; coerced with ``str()``.

    Returns:
        The result of comparing ``Tag(CACHE_KEY_FIELD_NAME)`` with
        ``str(key)`` using ``==``.
    """
    from redisvl.query.filter import Tag  # type: ignore[import-not-found, import-untyped]

    cache_key_tag = Tag(self.CACHE_KEY_FIELD_NAME)
    return cache_key_tag == str(key)
def _cache_hit_matches_key(self, cache_hit: Dict[str, Any], key: str) -> bool:
# Pre-isolation entries with no ``litellm_cache_key`` field cannot be
# safely reassigned to a caller's scope and are treated as misses.
cached_key = cache_hit.get(self.CACHE_KEY_FIELD_NAME)
if isinstance(cached_key, bytes):
cached_key = cached_key.decode("utf-8")
return cached_key is not None and str(cached_key) == str(key)
def _get_ttl(self, **kwargs) -> Optional[int]:
"""
Get the TTL (time-to-live) value for cache entries.
@ -188,7 +269,7 @@ class RedisSemanticCache(BaseCache):
Store a value in the semantic cache.
Args:
key: The cache key (not directly used in semantic caching)
key: The cache key used to isolate semantic cache entries
value: The response value to cache
**kwargs: Additional arguments including 'messages' for the prompt
and optional 'ttl' for time-to-live
@ -206,12 +287,15 @@ class RedisSemanticCache(BaseCache):
prompt = get_str_from_messages(messages)
value_str = str(value)
store_kwargs: Dict[str, Any] = {
"filters": self._get_cache_filters(key),
}
# Get TTL and store in Redis semantic cache
ttl = self._get_ttl(**kwargs)
if ttl is not None:
self.llmcache.store(prompt, value_str, ttl=int(ttl))
else:
self.llmcache.store(prompt, value_str)
store_kwargs["ttl"] = int(ttl)
self.llmcache.store(prompt, value_str, **store_kwargs)
except Exception as e:
print_verbose(
f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
@ -222,7 +306,7 @@ class RedisSemanticCache(BaseCache):
Retrieve a semantically similar cached response.
Args:
key: The cache key (not directly used in semantic caching)
key: The cache key used to isolate semantic cache entries
**kwargs: Additional arguments including 'messages' for the prompt
Returns:
@ -235,18 +319,29 @@ class RedisSemanticCache(BaseCache):
messages = kwargs.get("messages", [])
if not messages:
print_verbose("No messages provided for semantic cache lookup")
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
prompt = get_str_from_messages(messages)
# Check the cache for semantically similar prompts
results = self.llmcache.check(prompt=prompt)
# Check the cache for semantically similar prompts in this exact
# LiteLLM cache-key scope.
check_kwargs: Dict[str, Any] = {
"prompt": prompt,
"filter_expression": self._get_cache_key_filter_expression(key),
}
results = self.llmcache.check(**check_kwargs)
# Return None if no similar prompts found
if not results:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
# Process the best matching result
cache_hit = results[0]
if not self._cache_hit_matches_key(cache_hit=cache_hit, key=key):
print_verbose("Redis semantic-cache hit did not match cache key scope")
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
vector_distance = float(cache_hit["vector_distance"])
# Convert vector distance back to similarity score
@ -257,6 +352,9 @@ class RedisSemanticCache(BaseCache):
cached_prompt = cache_hit["prompt"]
cached_response = cache_hit["response"]
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
print_verbose(
f"Cache hit: similarity threshold: {self.similarity_threshold}, "
f"actual similarity: {similarity}, "
@ -267,6 +365,7 @@ class RedisSemanticCache(BaseCache):
return self._get_cache_logic(cached_response=cached_response)
except Exception as e:
print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}")
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
"""
@ -321,7 +420,7 @@ class RedisSemanticCache(BaseCache):
Asynchronously store a value in the semantic cache.
Args:
key: The cache key (not directly used in semantic caching)
key: The cache key used to isolate semantic cache entries
value: The response value to cache
**kwargs: Additional arguments including 'messages' for the prompt
and optional 'ttl' for time-to-live
@ -341,21 +440,20 @@ class RedisSemanticCache(BaseCache):
# Generate embedding for the value (response) to cache
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
store_kwargs: Dict[str, Any] = {
"vector": prompt_embedding,
"filters": self._get_cache_filters(key),
}
# Get TTL and store in Redis semantic cache
ttl = self._get_ttl(**kwargs)
if ttl is not None:
await self.llmcache.astore(
prompt,
value_str,
vector=prompt_embedding, # Pass through custom embedding
ttl=ttl,
)
else:
await self.llmcache.astore(
prompt,
value_str,
vector=prompt_embedding, # Pass through custom embedding
)
store_kwargs["ttl"] = ttl
await self.llmcache.astore(
prompt,
value_str,
**store_kwargs,
)
except Exception as e:
print_verbose(f"Error in async_set_cache: {str(e)}")
@ -364,7 +462,7 @@ class RedisSemanticCache(BaseCache):
Asynchronously retrieve a semantically similar cached response.
Args:
key: The cache key (not directly used in semantic caching)
key: The cache key used to isolate semantic cache entries
**kwargs: Additional arguments including 'messages' for the prompt
Returns:
@ -385,17 +483,25 @@ class RedisSemanticCache(BaseCache):
# Generate embedding for the prompt
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
# Check the cache for semantically similar prompts
results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)
# Check the cache for semantically similar prompts in this exact
# LiteLLM cache-key scope.
check_kwargs: Dict[str, Any] = {
"prompt": prompt,
"vector": prompt_embedding,
"filter_expression": self._get_cache_key_filter_expression(key),
}
results = await self.llmcache.acheck(**check_kwargs)
# handle results / cache hit
if not results:
kwargs.setdefault("metadata", {})[
"semantic-similarity"
] = 0.0 # TODO why here but not above??
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
cache_hit = results[0]
if not self._cache_hit_matches_key(cache_hit=cache_hit, key=key):
print_verbose("Redis semantic-cache hit did not match cache key scope")
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
vector_distance = float(cache_hit["vector_distance"])
# Convert vector distance back to similarity

View File

@ -2,7 +2,7 @@ import os
import re
import sys
from functools import lru_cache
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
from typing import Any, Dict, FrozenSet, List, Mapping, Optional, Tuple, Union
from fastapi import HTTPException, Request, status
@ -173,9 +173,67 @@ def _allow_model_level_clientside_configurable_parameters(
# threat shape should be added here.
_NESTED_CONFIG_KEYS: Tuple[str, ...] = ("litellm_embedding_config",)
# Banned root-level params. Same list applies to every entry in
# ``_NESTED_CONFIG_KEYS`` because those dicts get spread as ``**kwargs``
# into the same outbound calls.
# Metadata containers that carry per-request configuration consumed by the
# observability callbacks. The same banned-param list applies — a value
# under ``metadata.langfuse_host`` redirects the same Langfuse client and
# leaks the same credentials as the root-level ``langfuse_host``, but the
# original check only walked the request-body root, so the metadata path
# was an unintentional bypass.
_NESTED_METADATA_KEYS: Tuple[str, ...] = ("metadata", "litellm_metadata")
# Banned request-body params. The same list applies to every entry in
# ``_NESTED_CONFIG_KEYS`` (dicts spread as ``**kwargs`` into outbound
# calls) and ``_NESTED_METADATA_KEYS`` (dicts read directly by integration
# callbacks), so a single banned name is enforced wherever the field can
# reach the call path from.
# Per-request observability params that are SAFE to accept from clients.
# These describe the request being logged (prompt version, sampling rate)
# without choosing the destination or the credentials, so they don't
# contribute to the data-exfil primitive that the rest of
# ``_supported_callback_params`` does.
_SAFE_CLIENT_CALLBACK_PARAMS: FrozenSet[str] = frozenset(
{
"langfuse_prompt_version",
"langsmith_sampling_rate",
}
)
# Observability fields that integrations read from the request body or
# metadata but that are not (yet) listed in ``_supported_callback_params``.
# Listed here so the proxy bans them today; the long-term cleanup is to
# fold these into the canonical allowlist so they share one source of
# truth with the rest.
_EXTRA_BANNED_OBSERVABILITY_PARAMS: FrozenSet[str] = frozenset(
{
"posthog_api_url",
"phoenix_project_name",
"wandb_api_key",
"weave_project_id",
}
)
def _build_banned_observability_params() -> FrozenSet[str]:
    """Compute the set of observability params clients may not send.

    The starting point is ``_supported_callback_params`` from
    ``litellm/litellm_core_utils/initialize_dynamic_callback_params.py``.
    The informational fields in ``_SAFE_CLIENT_CALLBACK_PARAMS`` are removed
    from it, and the fields in ``_EXTRA_BANNED_OBSERVABILITY_PARAMS`` are
    added. A name newly added to ``_supported_callback_params`` therefore
    ends up banned unless it is explicitly marked safe.

    Returns:
        The resulting ban list as a frozenset of param names.
    """
    from litellm.litellm_core_utils.initialize_dynamic_callback_params import (
        _supported_callback_params,
    )

    callback_params = frozenset(_supported_callback_params)
    banned = callback_params.difference(_SAFE_CLIENT_CALLBACK_PARAMS)
    return banned.union(_EXTRA_BANNED_OBSERVABILITY_PARAMS)
_BANNED_REQUEST_BODY_PARAMS: Tuple[str, ...] = (
"api_base",
"base_url",
@ -190,11 +248,6 @@ _BANNED_REQUEST_BODY_PARAMS: Tuple[str, ...] = (
# tokens) to the attacker's host, or coerces the proxy into
# authenticating against the attacker's host with admin secrets.
"aws_bedrock_runtime_endpoint",
"langsmith_base_url",
"langfuse_host",
"posthog_host",
"braintrust_host",
"slack_webhook_url",
# Provider-specific endpoint overrides that flow into the outbound
# request via ``optional_params``. Same threat as ``api_base``:
# ``s3_endpoint_url`` redirects Bedrock file uploads to attacker
@ -203,6 +256,11 @@ _BANNED_REQUEST_BODY_PARAMS: Tuple[str, ...] = (
"s3_endpoint_url",
"sagemaker_base_url",
"deployment_url",
# Observability credentials, hosts, and project identifiers: derived
# from the canonical ``_supported_callback_params`` allowlist so new
# integrations are covered automatically. Sorted for stable iteration
# order and reviewable diffs.
*sorted(_build_banned_observability_params()),
)
@ -221,6 +279,8 @@ def _check_banned_params(
if param not in body:
continue
if general_settings.get("allow_client_side_credentials") is True:
# Proxy-wide opt-in: every banned param is permitted, exit
# entirely so the rest of the loop doesn't waste work.
return
if (
_allow_model_level_clientside_configurable_parameters(
@ -231,7 +291,12 @@ def _check_banned_params(
)
is True
):
return
# Per-param opt-in: only THIS param is permitted by the
# deployment's ``configurable_clientside_auth_params``. Skip
# to the next banned param so a body that pairs an allowed
# ``api_base`` with an unallowed ``langfuse_host`` is still
# rejected for the second field.
continue
raise ValueError(
f"Rejected Request: {param} is not allowed in request body. "
"Clientside passthrough requires explicit admin opt-in via "
@ -275,9 +340,33 @@ def is_request_body_safe(
nested = request_body.get(nested_key)
if isinstance(nested, dict):
_check_banned_params(nested, general_settings, llm_router, model)
for metadata_key in _NESTED_METADATA_KEYS:
metadata = _coerce_metadata_to_dict(request_body.get(metadata_key))
if metadata is not None:
_check_banned_params(metadata, general_settings, llm_router, model)
return True
def _coerce_metadata_to_dict(value: Any) -> Optional[Dict[str, Any]]:
"""Return ``value`` as a dict, parsing it from JSON if delivered as a string.
Multipart/form-data and ``extra_body`` callers send ``litellm_metadata``
as a JSON-encoded string; the proxy parses it into a dict later in
``add_litellm_data_to_request``, but the auth-time bouncer runs first
and would otherwise miss the banned-param check on a still-stringified
metadata blob.
"""
if isinstance(value, dict):
return value
if isinstance(value, str):
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
parsed = safe_json_loads(value)
if isinstance(parsed, dict):
return parsed
return None
async def pre_db_read_auth_checks(
request: Request,
request_data: dict,

View File

@ -121,6 +121,19 @@ _UNTRUSTED_ROOT_CONTROL_FIELDS = (
"pillar_response_headers",
"_guardrail_pipelines",
"_pipeline_managed_guardrails",
# Callback-registration fields. ``callbacks``, ``service_callback``,
# and ``logger_fn`` are read by ``litellm.utils.function_setup`` and
# appended to process-wide ``litellm.{input,success,failure,_async_*,
# service}_callback`` lists / ``litellm.user_logger_fn`` — one request
# poisons the worker for every subsequent caller.
# ``litellm_disabled_callbacks`` is the inverse primitive: the
# legitimate path reads it from key/team metadata, the request-body
# version silently turns off admin-configured audit/observability
# for the caller's request.
"callbacks",
"service_callback",
"logger_fn",
"litellm_disabled_callbacks",
)
_UNTRUSTED_METADATA_CONTROL_FIELDS = (

View File

@ -8,6 +8,9 @@ from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
from litellm.proxy.utils import jsonify_object
from litellm.proxy.vector_store_endpoints.management_endpoints import (
_resolve_embedding_config,
)
from litellm.proxy.vector_store_endpoints.utils import (
assert_user_can_access_vector_store,
get_litellm_managed_vector_store,
@ -56,6 +59,30 @@ async def _update_request_data_with_litellm_managed_vector_store_registry(
if "litellm_params" in vector_store_to_run:
litellm_params = vector_store_to_run.get("litellm_params", {}) or {}
# Resolve ``litellm_embedding_config`` here, at request-handling
# time, instead of at row-creation time. The resolved
# ``api_key`` / ``api_base`` / ``api_version`` lives only in
# this per-request ``data`` dict and is never persisted.
# Legacy rows that already carry a resolved (cleartext)
# ``litellm_embedding_config`` skip the lookup and pass through
# unchanged so the embed call keeps working.
embedding_model = litellm_params.get("litellm_embedding_model")
if embedding_model and not litellm_params.get("litellm_embedding_config"):
from litellm.proxy.proxy_server import prisma_client
resolved_config = await _resolve_embedding_config(
embedding_model=embedding_model, prisma_client=prisma_client
)
if resolved_config:
# Build a fresh dict via spread instead of mutating
# ``litellm_params`` in place — the registry hands back
# a reference to its cached object, so an in-place
# update would persist the resolved cleartext into the
# in-memory cache for the lifetime of the process.
litellm_params = {
**litellm_params,
"litellm_embedding_config": resolved_config,
}
data.update(litellm_params)
return data

View File

@ -16,6 +16,7 @@ from fastapi import APIRouter, Depends, HTTPException
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.caching.in_memory_cache import InMemoryCache
from litellm.constants import REDACTED_BY_LITELM_STRING
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
@ -45,6 +46,28 @@ _LITELLM_PARAMS_MASKER = SensitiveDataMasker()
_REDACT_LITELLM_PARAMS_MAX_DEPTH = 10
# Use-time embedding-config resolution runs on every vector-store request
# whose persisted row carries only a model reference (the post-fix shape).
# Without a cache, that's one ``litellm_proxymodeltable.find_first`` per
# request, which breaks the no-DB-in-critical-path rule. Hold the resolved
# config in memory for a short TTL so a hot model name pays the DB lookup at
# most once per ``_EMBEDDING_CONFIG_CACHE_TTL`` seconds. Cleartext credentials
# only ever live in process memory (never persisted, never echoed in
# management responses), so the cache doesn't widen the disclosure surface.
_EMBEDDING_CONFIG_CACHE_TTL = 60
_EMBEDDING_CONFIG_CACHE_MAX_SIZE = 256
_embedding_config_cache: Optional[InMemoryCache] = None
def _get_embedding_config_cache() -> InMemoryCache:
    """Return the module-wide embedding-config cache, creating it on first use.

    The cache is built with ``max_size_in_memory`` set to
    ``_EMBEDDING_CONFIG_CACHE_MAX_SIZE`` and ``default_ttl`` set to
    ``_EMBEDDING_CONFIG_CACHE_TTL``, then stored in the module global
    ``_embedding_config_cache`` so later calls reuse the same object.
    """
    global _embedding_config_cache
    if _embedding_config_cache is not None:
        return _embedding_config_cache

    _embedding_config_cache = InMemoryCache(
        max_size_in_memory=_EMBEDDING_CONFIG_CACHE_MAX_SIZE,
        default_ttl=_EMBEDDING_CONFIG_CACHE_TTL,
    )
    return _embedding_config_cache
def _redact_sensitive_litellm_params(litellm_params: Any, _depth: int = 0) -> Any:
"""
@ -303,6 +326,11 @@ async def _resolve_embedding_config(
This function first checks the router for config-defined models, then falls back
to the database. This allows users to use models defined in either location.
Results are cached in process memory for ``_EMBEDDING_CONFIG_CACHE_TTL``
seconds so the request-handling path doesn't hit the database on every
vector-store call. Negative results (model not found) are intentionally
not cached to avoid blocking a freshly-added model behind the TTL.
Args:
embedding_model: The embedding model string (e.g., "text-embedding-ada-002" or "azure/text-embedding-3-large")
prisma_client: The Prisma client instance
@ -314,6 +342,11 @@ async def _resolve_embedding_config(
if not embedding_model:
return None
cache = _get_embedding_config_cache()
cached = cache.get_cache(embedding_model)
if cached is not None:
return cached
# Import llm_router if not provided
if llm_router is None:
try:
@ -330,6 +363,7 @@ async def _resolve_embedding_config(
verbose_proxy_logger.debug(
f"Resolved embedding config from router for model {embedding_model}"
)
cache.set_cache(embedding_model, router_config)
return router_config
# Fall back to database
@ -341,6 +375,7 @@ async def _resolve_embedding_config(
verbose_proxy_logger.debug(
f"Resolved embedding config from database for model {embedding_model}"
)
cache.set_cache(embedding_model, db_config)
return db_config
verbose_proxy_logger.debug(
@ -432,20 +467,17 @@ async def create_vector_store_in_db(
if user_id is not None:
data_to_create["user_id"] = user_id
# Handle litellm_params - always provide at least an empty dict
# Handle litellm_params - always provide at least an empty dict.
# The earlier behaviour resolved ``litellm_embedding_config`` from the
# admin-configured router/DB model and persisted the cleartext result
# (``api_key``, ``api_base``, ``api_version``) into this row. That
# exposed every env-stored embedding-model credential on the
# ``/vector_store/{new,info,update,list}`` responses. Keep the user's
# raw ``litellm_embedding_model`` reference; resolution now happens in
# ``_update_request_data_with_litellm_managed_vector_store_registry``
# at request-handling time so the cleartext config exists only in
# per-request memory and never reaches the database.
if litellm_params:
# Auto-resolve embedding config if embedding model is provided but config is not
embedding_model = litellm_params.get("litellm_embedding_model")
if embedding_model and not litellm_params.get("litellm_embedding_config"):
resolved_config = await _resolve_embedding_config(
embedding_model=embedding_model, prisma_client=prisma_client
)
if resolved_config:
litellm_params["litellm_embedding_config"] = resolved_config
verbose_proxy_logger.info(
f"Auto-resolved embedding config for model {embedding_model}"
)
litellm_params_dict = GenericLiteLLMParams(**litellm_params).model_dump(
exclude_none=True
)
@ -531,10 +563,19 @@ async def new_vector_store(
user_id=user_api_key_dict.user_id,
)
# Apply the same litellm_params redaction the list / info / update
# endpoints already use, so a caller-supplied credential or a
# cleartext value persisted by an earlier proxy version doesn't
# come back in the response.
response_vs = LiteLLM_ManagedVectorStore(**new_vector_store)
response_vs["litellm_params"] = _redact_sensitive_litellm_params(
new_vector_store.get("litellm_params")
)
return {
"status": "success",
"message": f"Vector store {vector_store.get('vector_store_id')} created successfully",
"vector_store": new_vector_store,
"vector_store": response_vs,
}
except Exception as e:
verbose_proxy_logger.exception(f"Error creating vector store: {str(e)}")
@ -865,24 +906,15 @@ async def update_vector_store(
update_data["vector_store_metadata"]
)
# Handle litellm_params if provided
# Handle litellm_params if provided. As with the create path, the
# embedding-config auto-resolve previously persisted cleartext
# credentials into the row; resolution now happens at request-
# handling time in
# ``_update_request_data_with_litellm_managed_vector_store_registry``
# so this row only ever stores the user-supplied
# ``litellm_embedding_model`` reference.
if "litellm_params" in update_data:
_input_litellm_params: dict = update_data.get("litellm_params", {}) or {}
# Auto-resolve embedding config if embedding model is provided but config is not
embedding_model = _input_litellm_params.get("litellm_embedding_model")
if embedding_model and not _input_litellm_params.get(
"litellm_embedding_config"
):
resolved_config = await _resolve_embedding_config(
embedding_model=embedding_model, prisma_client=prisma_client
)
if resolved_config:
_input_litellm_params["litellm_embedding_config"] = resolved_config
verbose_proxy_logger.info(
f"Auto-resolved embedding config for model {embedding_model}"
)
litellm_params_dict = GenericLiteLLMParams(
**_input_litellm_params
).model_dump(exclude_none=True)

View File

@ -19,9 +19,7 @@ def test_qdrant_semantic_cache_initialization(monkeypatch):
patch(
"litellm.llms.custom_httpx.http_handler._get_httpx_client"
) as mock_sync_client,
patch(
"litellm.llms.custom_httpx.http_handler.get_async_httpx_client"
) as mock_async_client,
patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"),
):
# Mock the collection exists check
@ -31,6 +29,9 @@ def test_qdrant_semantic_cache_initialization(monkeypatch):
mock_sync_client_instance = MagicMock()
mock_sync_client_instance.get.return_value = mock_response
mock_index_response = MagicMock()
mock_index_response.status_code = 200
mock_sync_client_instance.put.return_value = mock_index_response
mock_sync_client.return_value = mock_sync_client_instance
from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache
@ -48,6 +49,17 @@ def test_qdrant_semantic_cache_initialization(monkeypatch):
assert qdrant_cache.qdrant_api_base == "http://test.qdrant.local"
assert qdrant_cache.qdrant_api_key == "test_key"
assert qdrant_cache.similarity_threshold == 0.8
mock_sync_client_instance.put.assert_called_once_with(
url="http://test.qdrant.local/collections/test_collection/index",
headers={
"Content-Type": "application/json",
"api-key": "test_key",
},
json={
"field_name": QdrantSemanticCache.CACHE_KEY_FIELD_NAME,
"field_schema": "keyword",
},
)
# Test initialization with missing similarity_threshold
with pytest.raises(Exception, match="similarity_threshold must be provided"):
@ -67,9 +79,7 @@ def test_qdrant_semantic_cache_get_cache_hit():
patch(
"litellm.llms.custom_httpx.http_handler._get_httpx_client"
) as mock_sync_client,
patch(
"litellm.llms.custom_httpx.http_handler.get_async_httpx_client"
) as mock_async_client,
patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"),
):
# Mock the collection exists check
@ -98,6 +108,7 @@ def test_qdrant_semantic_cache_get_cache_hit():
"result": [
{
"payload": {
QdrantSemanticCache.CACHE_KEY_FIELD_NAME: "test_key",
"text": "What is the capital of France?", # Original prompt
"response": '{"id": "test-123", "choices": [{"message": {"content": "Paris is the capital of France."}}]}',
},
@ -127,6 +138,177 @@ def test_qdrant_semantic_cache_get_cache_hit():
# Verify search was called
qdrant_cache.sync_client.post.assert_called()
assert qdrant_cache.sync_client.post.call_args.kwargs["json"]["filter"] == {
"must": [
{
"key": QdrantSemanticCache.CACHE_KEY_FIELD_NAME,
"match": {"value": "test_key"},
}
]
}
def test_qdrant_semantic_cache_rejects_unscoped_cache_hit():
    """
    A Qdrant search hit without a LiteLLM cache key in its payload is a miss.

    Legacy points hold only ``text`` and ``response`` payload fields, so they
    cannot be safely migrated to a generated LiteLLM cache key.
    """
    legacy_point = {
        "payload": {
            "text": "What is the capital of France?",
            "response": '{"id": "test-123"}',
        },
        "score": 0.9,
    }
    embedding_reply = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}

    with (
        patch(
            "litellm.llms.custom_httpx.http_handler._get_httpx_client"
        ) as httpx_client_factory,
        patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"),
    ):
        # Reply served to the sync client's GET during construction.
        exists_reply = MagicMock()
        exists_reply.status_code = 200
        exists_reply.json.return_value = {"result": {"exists": True}}

        sync_client = MagicMock()
        sync_client.get.return_value = exists_reply
        httpx_client_factory.return_value = sync_client

        from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

        cache = QdrantSemanticCache(
            collection_name="test_collection",
            qdrant_api_base="http://test.qdrant.local",
            qdrant_api_key="test_key",
            similarity_threshold=0.8,
        )

        # The search returns a high-scoring point that has no cache-key field.
        search_reply = MagicMock()
        search_reply.status_code = 200
        search_reply.json.return_value = {"result": [legacy_point]}
        cache.sync_client.post = MagicMock(return_value=search_reply)

        with patch("litellm.embedding", return_value=embedding_reply):
            metadata = {}
            result = cache.get_cache(
                key="test_key",
                messages=[{"content": "What is the capital of France?"}],
                metadata=metadata,
            )

        assert result is None
        assert metadata["semantic-similarity"] == 0.0
def test_qdrant_semantic_cache_payload_index_failure_is_non_blocking():
    """A 400 reply to the index PUT must not raise out of the helper."""
    from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

    rejected_put = MagicMock()
    rejected_put.status_code = 400
    rejected_put.text = "bad index"

    # Skip __init__ and attach the attributes by hand.
    cache = QdrantSemanticCache.__new__(QdrantSemanticCache)
    cache.qdrant_api_base = "http://test.qdrant.local"
    cache.collection_name = "test_collection"
    cache.headers = {"Content-Type": "application/json"}
    cache.sync_client = MagicMock()
    cache.sync_client.put.return_value = rejected_put

    cache._ensure_cache_key_payload_index()

    cache.sync_client.put.assert_called_once()
def test_qdrant_semantic_cache_payload_index_exception_is_non_blocking():
    """An exception raised by the index PUT must not escape the helper."""
    from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

    exploding_client = MagicMock()
    exploding_client.put.side_effect = Exception("boom")

    # Skip __init__ and attach the attributes by hand.
    cache = QdrantSemanticCache.__new__(QdrantSemanticCache)
    cache.qdrant_api_base = "http://test.qdrant.local"
    cache.collection_name = "test_collection"
    cache.headers = {"Content-Type": "application/json"}
    cache.sync_client = exploding_client

    cache._ensure_cache_key_payload_index()

    cache.sync_client.put.assert_called_once()
def _mock_qdrant_get_cache_result(qdrant_result):
    """Build a bare ``QdrantSemanticCache`` whose sync search yields ``qdrant_result``.

    ``__init__`` is bypassed; the attributes are assigned directly and the
    sync client's ``post`` returns a 200 reply whose JSON body is
    ``{"result": qdrant_result}``.

    Returns:
        A ``(cache_instance, QdrantSemanticCache)`` tuple.
    """
    from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

    search_reply = MagicMock()
    search_reply.status_code = 200
    search_reply.json.return_value = {"result": qdrant_result}

    stub_client = MagicMock()
    stub_client.post.return_value = search_reply

    cache = QdrantSemanticCache.__new__(QdrantSemanticCache)
    cache.embedding_model = "text-embedding-ada-002"
    cache.qdrant_api_base = "http://test.qdrant.local"
    cache.collection_name = "test_collection"
    cache.headers = {
        "Content-Type": "application/json",
        "api-key": "test_key",
    }
    cache.similarity_threshold = 0.8
    cache.sync_client = stub_client

    return cache, QdrantSemanticCache
@pytest.mark.parametrize("qdrant_result", [None, []])
def test_qdrant_semantic_cache_get_cache_sets_metadata_on_empty_miss(qdrant_result):
    """An empty or missing search result is a miss recorded with similarity 0.0."""
    cache, _ = _mock_qdrant_get_cache_result(qdrant_result)
    embedding_reply = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
    metadata = {}

    with patch("litellm.embedding", return_value=embedding_reply):
        result = cache.get_cache(
            key="test_key",
            messages=[{"content": "What is the capital of Spain?"}],
            metadata=metadata,
        )

    assert result is None
    assert metadata["semantic-similarity"] == 0.0
def test_qdrant_semantic_cache_get_cache_sets_metadata_on_below_threshold_miss():
    """A key-matching point scoring 0.7 (threshold 0.8) is a miss; the score is still recorded."""
    from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

    low_score_point = {
        "payload": {
            QdrantSemanticCache.CACHE_KEY_FIELD_NAME: "test_key",
            "text": "What is the capital of Spain?",
            "response": '{"id": "test-456"}',
        },
        "score": 0.7,
    }
    cache, _ = _mock_qdrant_get_cache_result([low_score_point])
    embedding_reply = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
    metadata = {}

    with patch("litellm.embedding", return_value=embedding_reply):
        result = cache.get_cache(
            key="test_key",
            messages=[{"content": "What is the capital of Spain?"}],
            metadata=metadata,
        )

    assert result is None
    assert metadata["semantic-similarity"] == 0.7
def test_qdrant_semantic_cache_get_cache_miss():
@ -138,9 +320,7 @@ def test_qdrant_semantic_cache_get_cache_miss():
patch(
"litellm.llms.custom_httpx.http_handler._get_httpx_client"
) as mock_sync_client,
patch(
"litellm.llms.custom_httpx.http_handler.get_async_httpx_client"
) as mock_async_client,
patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"),
):
# Mock the collection exists check
@ -230,6 +410,7 @@ async def test_qdrant_semantic_cache_async_get_cache_hit():
"result": [
{
"payload": {
QdrantSemanticCache.CACHE_KEY_FIELD_NAME: "test_key",
"text": "What is the capital of Spain?", # Original prompt
"response": '{"id": "test-456", "choices": [{"message": {"content": "Madrid is the capital of Spain."}}]}',
},
@ -262,6 +443,16 @@ async def test_qdrant_semantic_cache_async_get_cache_hit():
# Verify async search was called
qdrant_cache.async_client.post.assert_called()
assert qdrant_cache.async_client.post.call_args.kwargs["json"][
"filter"
] == {
"must": [
{
"key": QdrantSemanticCache.CACHE_KEY_FIELD_NAME,
"match": {"value": "test_key"},
}
]
}
@pytest.mark.asyncio
@ -336,9 +527,7 @@ def test_qdrant_semantic_cache_set_cache():
patch(
"litellm.llms.custom_httpx.http_handler._get_httpx_client"
) as mock_sync_client,
patch(
"litellm.llms.custom_httpx.http_handler.get_async_httpx_client"
) as mock_async_client,
patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"),
):
# Mock the collection exists check
@ -384,6 +573,12 @@ def test_qdrant_semantic_cache_set_cache():
# Verify upsert was called
qdrant_cache.sync_client.put.assert_called()
upsert_payload = qdrant_cache.sync_client.put.call_args.kwargs["json"][
"points"
][0]["payload"]
assert (
upsert_payload[QdrantSemanticCache.CACHE_KEY_FIELD_NAME] == "test_key"
)
@pytest.mark.asyncio
@ -450,6 +645,12 @@ async def test_qdrant_semantic_cache_async_set_cache():
# Verify async upsert was called
qdrant_cache.async_client.put.assert_called()
upsert_payload = qdrant_cache.async_client.put.call_args.kwargs["json"][
"points"
][0]["payload"]
assert (
upsert_payload[QdrantSemanticCache.CACHE_KEY_FIELD_NAME] == "test_key"
)
def test_qdrant_semantic_cache_custom_vector_size():
@ -462,9 +663,7 @@ def test_qdrant_semantic_cache_custom_vector_size():
patch(
"litellm.llms.custom_httpx.http_handler._get_httpx_client"
) as mock_sync_client,
patch(
"litellm.llms.custom_httpx.http_handler.get_async_httpx_client"
) as mock_async_client,
patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"),
):
# Mock the collection does NOT exist (so it will be created)
@ -505,9 +704,13 @@ def test_qdrant_semantic_cache_custom_vector_size():
assert qdrant_cache.vector_size == 768
# Verify the PUT call to create the collection used vector_size=768
put_call = mock_sync_client_instance.put.call_args
assert put_call is not None
create_payload = put_call.kwargs.get("json") or put_call[1].get("json")
put_call = next(
call
for call in mock_sync_client_instance.put.call_args_list
if call.kwargs["url"]
== "http://test.qdrant.local/collections/test_collection_768"
)
create_payload = put_call.kwargs["json"]
assert create_payload["vectors"]["size"] == 768
assert create_payload["vectors"]["distance"] == "Cosine"
@ -521,9 +724,7 @@ def test_qdrant_semantic_cache_default_vector_size():
patch(
"litellm.llms.custom_httpx.http_handler._get_httpx_client"
) as mock_sync_client,
patch(
"litellm.llms.custom_httpx.http_handler.get_async_httpx_client"
) as mock_async_client,
patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"),
):
# Mock the collection exists check
@ -559,9 +760,7 @@ def test_qdrant_semantic_cache_large_vector_size():
patch(
"litellm.llms.custom_httpx.http_handler._get_httpx_client"
) as mock_sync_client,
patch(
"litellm.llms.custom_httpx.http_handler.get_async_httpx_client"
) as mock_async_client,
patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"),
):
# Mock the collection does NOT exist (so it will be created)
@ -599,6 +798,11 @@ def test_qdrant_semantic_cache_large_vector_size():
assert qdrant_cache.vector_size == 4096
# Verify the collection was created with 4096
put_call = mock_sync_client_instance.put.call_args
create_payload = put_call.kwargs.get("json") or put_call[1].get("json")
put_call = next(
call
for call in mock_sync_client_instance.put.call_args_list
if call.kwargs["url"]
== "http://test.qdrant.local/collections/test_collection_4096"
)
create_payload = put_call.kwargs["json"]
assert create_payload["vectors"]["size"] == 4096

View File

@ -72,24 +72,301 @@ def test_redis_semantic_cache_get_cache(monkeypatch):
"prompt": "What is the capital of France?",
"response": '{"content": "Paris is the capital of France."}',
"vector_distance": 0.1, # Distance of 0.1 means similarity of 0.9
RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key",
}
]
redis_semantic_cache.llmcache.check = MagicMock(return_value=mock_result)
# Mock the embedding function
with patch(
"litellm.embedding", return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]}
with (
patch(
"litellm.embedding",
return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]},
),
patch.object(
redis_semantic_cache,
"_get_cache_key_filter_expression",
return_value="cache-key-filter",
),
):
# Test get_cache with a message
metadata = {}
result = redis_semantic_cache.get_cache(
key="test_key", messages=[{"content": "What is the capital of France?"}]
key="test_key",
messages=[{"content": "What is the capital of France?"}],
metadata=metadata,
)
# Verify result is properly parsed
assert result == {"content": "Paris is the capital of France."}
assert metadata["semantic-similarity"] == pytest.approx(0.9)
# Verify llmcache.check was called
redis_semantic_cache.llmcache.check.assert_called_once()
redis_semantic_cache.llmcache.check.assert_called_once_with(
prompt="What is the capital of France?",
filter_expression="cache-key-filter",
)
def test_redis_semantic_cache_rejects_unscoped_cache_hit(monkeypatch):
    """A Redis hit that carries no cache-key field is reported as a miss."""
    semantic_cache_cls = MagicMock()
    vectorizer_cls = MagicMock()
    stubbed_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_cls),
        "redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=vectorizer_cls),
    }
    unscoped_hit = {
        "prompt": "What is the capital of France?",
        "response": '{"content": "Paris"}',
        "vector_distance": 0.1,
    }

    with patch.dict("sys.modules", stubbed_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        for env_name, env_value in (
            ("REDIS_HOST", "localhost"),
            ("REDIS_PORT", "6379"),
            ("REDIS_PASSWORD", "test_password"),
        ):
            monkeypatch.setenv(env_name, env_value)

        cache = RedisSemanticCache(similarity_threshold=0.8)
        cache.llmcache.check = MagicMock(return_value=[unscoped_hit])

        with patch.object(
            cache,
            "_get_cache_key_filter_expression",
            return_value="cache-key-filter",
        ):
            metadata = {}
            result = cache.get_cache(
                key="test_key",
                messages=[{"content": "What is the capital of France?"}],
                metadata=metadata,
            )

        assert result is None
        assert metadata["semantic-similarity"] == 0.0
def test_redis_semantic_cache_set_cache_stores_cache_key_filter(monkeypatch):
    """``set_cache`` forwards the caller's key to ``llmcache.store`` via ``filters``."""
    semantic_cache_cls = MagicMock()
    vectorizer_cls = MagicMock()
    stubbed_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_cls),
        "redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=vectorizer_cls),
    }

    with patch.dict("sys.modules", stubbed_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        for env_name, env_value in (
            ("REDIS_HOST", "localhost"),
            ("REDIS_PORT", "6379"),
            ("REDIS_PASSWORD", "test_password"),
        ):
            monkeypatch.setenv(env_name, env_value)

        cache = RedisSemanticCache(similarity_threshold=0.8)
        store_spy = MagicMock()
        cache.llmcache.store = store_spy

        cache.set_cache(
            key="test_key",
            value={"content": "Paris"},
            messages=[{"content": "What is the capital of France?"}],
            ttl=60,
        )

        # The value reaches the store as the dict's str() form.
        expected_filters = {RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key"}
        cache.llmcache.store.assert_called_once_with(
            "What is the capital of France?",
            "{'content': 'Paris'}",
            filters=expected_filters,
            ttl=60,
        )
def test_redis_semantic_cache_uses_isolated_index_for_old_schema(monkeypatch):
    """A schema-mismatch ``ValueError`` on the requested index triggers a second
    construction under ``<index_name>_isolated`` with the cache-key filterable field."""
    isolated_cache = MagicMock()
    semantic_cache_cls = MagicMock(
        side_effect=[
            ValueError("stored index schema differs from requested fields"),
            isolated_cache,
        ]
    )
    vectorizer_cls = MagicMock()
    stubbed_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_cls),
        "redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=vectorizer_cls),
    }

    with patch.dict("sys.modules", stubbed_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        for env_name, env_value in (
            ("REDIS_HOST", "localhost"),
            ("REDIS_PORT", "6379"),
            ("REDIS_PASSWORD", "test_password"),
        ):
            monkeypatch.setenv(env_name, env_value)

        cache = RedisSemanticCache(
            similarity_threshold=0.8,
            index_name="existing_index",
        )

        assert cache.llmcache is isolated_cache

        constructions = semantic_cache_cls.call_args_list
        assert constructions[0].kwargs["name"] == "existing_index"
        assert constructions[1].kwargs["name"] == "existing_index_isolated"
        assert constructions[1].kwargs["filterable_fields"] == [
            RedisSemanticCache._cache_key_filterable_field()
        ]
def test_redis_semantic_cache_overwrites_stale_isolated_index(monkeypatch):
    """When the isolated index also reports a schema mismatch, a third
    construction is made for it with ``overwrite=True``."""
    rebuilt_cache = MagicMock()
    semantic_cache_cls = MagicMock(
        side_effect=[
            ValueError("Existing index schema does not match"),
            ValueError("Existing index schema does not match"),
            rebuilt_cache,
        ]
    )
    vectorizer_cls = MagicMock()
    stubbed_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_cls),
        "redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=vectorizer_cls),
    }

    with patch.dict("sys.modules", stubbed_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        for env_name, env_value in (
            ("REDIS_HOST", "localhost"),
            ("REDIS_PORT", "6379"),
            ("REDIS_PASSWORD", "test_password"),
        ):
            monkeypatch.setenv(env_name, env_value)

        cache = RedisSemanticCache(
            similarity_threshold=0.8,
            index_name="existing_index",
        )

        assert cache.llmcache is rebuilt_cache

        constructions = semantic_cache_cls.call_args_list
        assert constructions[2].kwargs["name"] == "existing_index_isolated"
        assert constructions[2].kwargs["overwrite"] is True
        assert constructions[2].kwargs["filterable_fields"] == [
            RedisSemanticCache._cache_key_filterable_field()
        ]
def test_redis_semantic_cache_reraises_unexpected_isolated_index_error(monkeypatch):
    """A non-schema ``ValueError`` raised while building the isolated index propagates."""
    semantic_cache_cls = MagicMock(
        side_effect=[
            ValueError("Existing index schema does not match"),
            ValueError("connection failed"),
        ]
    )
    vectorizer_cls = MagicMock()
    stubbed_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_cls),
        "redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=vectorizer_cls),
    }

    with patch.dict("sys.modules", stubbed_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        for env_name, env_value in (
            ("REDIS_HOST", "localhost"),
            ("REDIS_PORT", "6379"),
            ("REDIS_PASSWORD", "test_password"),
        ):
            monkeypatch.setenv(env_name, env_value)

        with pytest.raises(ValueError, match="connection failed"):
            RedisSemanticCache(
                similarity_threshold=0.8,
                index_name="existing_index",
            )
def test_redis_semantic_cache_reraises_unexpected_index_error():
    """``_init_semantic_cache`` lets a ``ValueError("connection failed")`` propagate."""
    from litellm.caching.redis_semantic_cache import RedisSemanticCache

    failing_cache_cls = MagicMock(side_effect=ValueError("connection failed"))

    # Skip __init__; this test assigns only distance_threshold.
    cache = RedisSemanticCache.__new__(RedisSemanticCache)
    cache.distance_threshold = 0.2

    with pytest.raises(ValueError, match="connection failed"):
        cache._init_semantic_cache(
            semantic_cache_cls=failing_cache_cls,
            index_name="existing_index",
            redis_url="redis://localhost:6379",
            cache_vectorizer=MagicMock(),
        )
def test_redis_semantic_cache_matches_bytes_cache_key():
    """A cache-key value stored as ``bytes`` matches the same key given as ``str``."""
    from litellm.caching.redis_semantic_cache import RedisSemanticCache

    cache = RedisSemanticCache.__new__(RedisSemanticCache)
    bytes_keyed_hit = {RedisSemanticCache.CACHE_KEY_FIELD_NAME: b"test_key"}

    assert cache._cache_hit_matches_key(
        cache_hit=bytes_keyed_hit,
        key="test_key",
    )
def test_redis_semantic_cache_rejects_pre_isolation_unscoped_hit():
    """Entries written before isolation have no cache-key field; they cannot be
    safely assigned to a caller's scope, so they count as misses."""
    from litellm.caching.redis_semantic_cache import RedisSemanticCache

    cache = RedisSemanticCache.__new__(RedisSemanticCache)
    legacy_hit = {
        "prompt": "What is the capital of France?",
        "response": '{"content": "Paris"}',
        "vector_distance": 0.1,
    }

    matched = cache._cache_hit_matches_key(
        cache_hit=legacy_hit,
        key="test_key",
    )
    assert not matched
def test_redis_semantic_cache_builds_filter_expression(monkeypatch):
    """The filter expression is ``Tag(<cache-key field>) == <key>``."""

    class FakeTag:
        # Stand-in for redisvl's Tag: ``==`` returns the operands as a tuple
        # so the test can inspect what was compared.
        def __init__(self, field_name):
            self.field_name = field_name

        def __eq__(self, value):
            return (self.field_name, value)

    fake_filter_module = MagicMock(Tag=FakeTag)

    with patch.dict("sys.modules", {"redisvl.query.filter": fake_filter_module}):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        cache = RedisSemanticCache.__new__(RedisSemanticCache)
        expected = (
            RedisSemanticCache.CACHE_KEY_FIELD_NAME,
            "test_key",
        )

        assert cache._get_cache_key_filter_expression("test_key") == expected
@pytest.mark.asyncio
@ -123,6 +400,7 @@ async def test_redis_semantic_cache_async_get_cache(monkeypatch):
"prompt": "What is the capital of France?",
"response": '{"content": "Paris is the capital of France."}',
"vector_distance": 0.1, # Distance of 0.1 means similarity of 0.9
RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key",
}
]
@ -131,16 +409,117 @@ async def test_redis_semantic_cache_async_get_cache(monkeypatch):
return_value=[0.1, 0.2, 0.3]
)
# Test async_get_cache with a message
result = await redis_semantic_cache.async_get_cache(
key="test_key",
messages=[{"content": "What is the capital of France?"}],
metadata={},
)
with patch.object(
redis_semantic_cache,
"_get_cache_key_filter_expression",
return_value="cache-key-filter",
):
# Test async_get_cache with a message
result = await redis_semantic_cache.async_get_cache(
key="test_key",
messages=[{"content": "What is the capital of France?"}],
metadata={},
)
# Verify result is properly parsed
assert result == {"content": "Paris is the capital of France."}
# Verify methods were called
redis_semantic_cache._get_async_embedding.assert_called_once()
redis_semantic_cache.llmcache.acheck.assert_called_once()
redis_semantic_cache.llmcache.acheck.assert_called_once_with(
prompt="What is the capital of France?",
vector=[0.1, 0.2, 0.3],
filter_expression="cache-key-filter",
)
@pytest.mark.asyncio
async def test_redis_semantic_cache_async_get_cache_rejects_unscoped_hit(monkeypatch):
    """Async lookup: a hit that carries no cache-key field is returned as ``None``."""
    semantic_cache_cls = MagicMock()
    vectorizer_cls = MagicMock()
    stubbed_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_cls),
        "redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=vectorizer_cls),
    }
    unscoped_hit = {
        "prompt": "What is the capital of France?",
        "response": '{"content": "Paris"}',
        "vector_distance": 0.1,
    }

    with patch.dict("sys.modules", stubbed_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        for env_name, env_value in (
            ("REDIS_HOST", "localhost"),
            ("REDIS_PORT", "6379"),
            ("REDIS_PASSWORD", "test_password"),
        ):
            monkeypatch.setenv(env_name, env_value)

        cache = RedisSemanticCache(similarity_threshold=0.8)
        cache.llmcache.acheck = AsyncMock(return_value=[unscoped_hit])
        cache._get_async_embedding = AsyncMock(return_value=[0.1, 0.2, 0.3])

        with patch.object(
            cache,
            "_get_cache_key_filter_expression",
            return_value="cache-key-filter",
        ):
            result = await cache.async_get_cache(
                key="test_key",
                messages=[{"content": "What is the capital of France?"}],
                metadata={},
            )

        assert result is None
@pytest.mark.asyncio
async def test_redis_semantic_cache_async_set_cache_stores_cache_key_filter(
    monkeypatch,
):
    """``async_set_cache`` forwards the caller's key to ``llmcache.astore`` via ``filters``."""
    semantic_cache_cls = MagicMock()
    vectorizer_cls = MagicMock()
    stubbed_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_cls),
        "redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=vectorizer_cls),
    }
    prompt_vector = [0.1, 0.2, 0.3]

    with patch.dict("sys.modules", stubbed_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        for env_name, env_value in (
            ("REDIS_HOST", "localhost"),
            ("REDIS_PORT", "6379"),
            ("REDIS_PASSWORD", "test_password"),
        ):
            monkeypatch.setenv(env_name, env_value)

        cache = RedisSemanticCache(similarity_threshold=0.8)
        cache.llmcache.astore = AsyncMock()
        cache._get_async_embedding = AsyncMock(return_value=prompt_vector)

        await cache.async_set_cache(
            key="test_key",
            value={"content": "Paris"},
            messages=[{"content": "What is the capital of France?"}],
            ttl=60,
        )

        # The value reaches the store as the dict's str() form.
        expected_filters = {RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key"}
        cache.llmcache.astore.assert_called_once_with(
            "What is the capital of France?",
            "{'content': 'Paris'}",
            vector=[0.1, 0.2, 0.3],
            filters=expected_filters,
            ttl=60,
        )

View File

@ -1292,3 +1292,219 @@ class TestIsRequestBodySafeNestedConfig:
)
is True
)
# ── observability-callback ban (root + metadata) ───────────────────────────
class TestObservabilityCallbackBans:
    """Observability credentials, hosts and project identifiers must be
    rejected by the proxy wherever they arrive: at the request body root,
    inside ``metadata`` / ``litellm_metadata``, or inside a
    JSON-string-encoded metadata blob (multipart/``extra_body`` path).

    The ban list is derived from
    ``litellm.litellm_core_utils.initialize_dynamic_callback_params._supported_callback_params``
    minus a small ``_SAFE_CLIENT_CALLBACK_PARAMS`` allow-list, plus
    ``_EXTRA_BANNED_OBSERVABILITY_PARAMS`` for fields that integrations read
    but that are not yet in the canonical allow-list. Deriving it keeps the
    proxy in sync as new integrations are added.
    """

    @pytest.fixture(autouse=True)
    def _disable_url_validation(self, monkeypatch):
        import litellm

        monkeypatch.setattr(litellm, "user_url_validation", False, raising=False)

    @pytest.mark.parametrize(
        "field",
        [
            "langfuse_public_key",
            "langfuse_secret",
            "langfuse_secret_key",
            "langsmith_api_key",
            "langsmith_project",
            "langsmith_tenant_id",
            "arize_api_key",
            "arize_space_key",
            "arize_space_id",
            "posthog_api_key",
            "posthog_api_url",
            "braintrust_api_key",
            "braintrust_project",
            "phoenix_project_name",
            "wandb_api_key",
            "weave_project_id",
            "gcs_bucket_name",
            "gcs_path_service_account",
            "humanloop_api_key",
            "lunary_public_key",
        ],
    )
    def test_observability_field_in_request_body_root_is_rejected(self, field):
        body = {"model": "gpt-4", field: "attacker-value"}

        with pytest.raises(ValueError) as excinfo:
            is_request_body_safe(
                request_body=body,
                general_settings={},
                llm_router=None,
                model="gpt-4",
            )

        # The error message must name the offending field.
        assert field in str(excinfo.value)

    @pytest.mark.parametrize(
        "metadata_key",
        ["metadata", "litellm_metadata"],
    )
    @pytest.mark.parametrize(
        "field",
        [
            "langfuse_host",
            "langfuse_secret_key",
            "langsmith_api_key",
            "posthog_api_url",
            "braintrust_project",
            "phoenix_project_name",
        ],
    )
    def test_observability_field_in_metadata_dict_is_rejected(
        self, metadata_key, field
    ):
        # Exercises the metadata walk: a value tucked inside ``metadata`` or
        # ``litellm_metadata`` is as dangerous as the same field at the body
        # root, so it has to hit the same gate.
        body = {
            "model": "gpt-4",
            metadata_key: {field: "attacker-value"},
        }

        with pytest.raises(ValueError) as excinfo:
            is_request_body_safe(
                request_body=body,
                general_settings={},
                llm_router=None,
                model="gpt-4",
            )

        assert field in str(excinfo.value)

    @pytest.mark.parametrize(
        "metadata_key",
        ["metadata", "litellm_metadata"],
    )
    def test_observability_field_in_json_string_metadata_is_rejected(
        self, metadata_key
    ):
        # Multipart/form-data and ``extra_body`` callers deliver metadata as
        # a JSON-encoded string. The bouncer parses it before running the
        # banned-params check, so the string form cannot slip past the
        # ``isinstance(dict)`` guard.
        import json

        encoded_metadata = json.dumps({"langfuse_host": "https://attacker.example"})
        body = {
            "model": "gpt-4",
            metadata_key: encoded_metadata,
        }

        with pytest.raises(ValueError) as excinfo:
            is_request_body_safe(
                request_body=body,
                general_settings={},
                llm_router=None,
                model="gpt-4",
            )

        assert "langfuse_host" in str(excinfo.value)

    def test_admin_opt_in_allows_metadata_credential_passthrough(self):
        # The opt-in gate applies to the metadata path exactly as it does to
        # the root path: operators running BYO observability with clientside
        # creds flip one flag and both paths work.
        byo_metadata = {
            "langfuse_host": "https://my-langfuse.example",
            "langfuse_public_key": "pk-mine",
            "langfuse_secret_key": "sk-mine",
        }

        verdict = is_request_body_safe(
            request_body={
                "model": "gpt-4",
                "metadata": byo_metadata,
            },
            general_settings={"allow_client_side_credentials": True},
            llm_router=None,
            model="gpt-4",
        )

        assert verdict is True

    def test_safe_per_request_observability_metadata_is_allowed(self):
        # Informational fields (sampling rate, prompt version) describe the
        # request being logged; they do not pick the destination or the
        # credentials, so clients may send them without the opt-in flag.
        informational_metadata = {
            "langfuse_prompt_version": "v2",
            "langsmith_sampling_rate": 0.1,
        }

        verdict = is_request_body_safe(
            request_body={
                "model": "gpt-4",
                "metadata": informational_metadata,
            },
            general_settings={},
            llm_router=None,
            model="gpt-4",
        )

        assert verdict is True
def test_model_level_allow_does_not_skip_subsequent_banned_params(monkeypatch):
    """A model-level allow for one banned field must not exempt later ones.

    Regression guard (Greptile P1): ``_check_banned_params`` previously
    ``return``-ed as soon as a deployment's
    ``configurable_clientside_auth_params`` permitted one banned field, so
    any banned field appearing later in the same body was never checked.
    Here ``api_base`` is allowed at model level while ``langfuse_host`` is
    not; the request must still be rejected because of ``langfuse_host``.
    """
    from litellm.proxy.auth import auth_utils

    def _allow_only_api_base(model, param, request_body_value, llm_router):
        # Stand-in for the model-level allow check: permits ``api_base`` only.
        return param == "api_base"

    monkeypatch.setattr(
        auth_utils,
        "_allow_model_level_clientside_configurable_parameters",
        _allow_only_api_base,
    )

    mixed_body = {
        "model": "gpt-4",
        "api_base": "https://allowed-by-deployment.example",
        "langfuse_host": "https://attacker.example",
    }

    with pytest.raises(ValueError) as exc:
        is_request_body_safe(
            request_body=mixed_body,
            general_settings={},
            llm_router=None,
            model="gpt-4",
        )

    assert "langfuse_host" in str(exc.value)
def test_observability_ban_covers_canonical_supported_callback_params():
    """Every supported callback param is either banned or explicitly safe-listed.

    Guard test over the canonical ``_supported_callback_params`` list: each
    entry must appear in ``_BANNED_REQUEST_BODY_PARAMS`` or in
    ``_SAFE_CLIENT_CALLBACK_PARAMS``. New integrations are meant to be banned
    by default (the safe failure mode); marking one as safe is an explicit
    decision recorded in ``_SAFE_CLIENT_CALLBACK_PARAMS``.
    """
    from litellm.litellm_core_utils.initialize_dynamic_callback_params import (
        _supported_callback_params,
    )
    from litellm.proxy.auth.auth_utils import (
        _BANNED_REQUEST_BODY_PARAMS,
        _SAFE_CLIENT_CALLBACK_PARAMS,
    )

    banned_params = set(_BANNED_REQUEST_BODY_PARAMS)

    for callback_param in _supported_callback_params:
        is_covered = (
            callback_param in banned_params
            or callback_param in _SAFE_CLIENT_CALLBACK_PARAMS
        )
        failure_message = (
            f"{callback_param} is in _supported_callback_params but neither banned nor "
            f"safe-listed. Add it to _SAFE_CLIENT_CALLBACK_PARAMS if it is an "
            f"informational per-request field; otherwise the derivation will "
            f"ban it automatically."
        )
        assert is_covered, failure_message

View File

@ -596,6 +596,59 @@ async def test_add_litellm_data_to_request_strips_user_control_fields():
assert "pillar_response_headers" not in snapshot_body["metadata"]
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "control_field",
    ["callbacks", "service_callback", "logger_fn", "litellm_disabled_callbacks"],
)
async def test_add_litellm_data_to_request_strips_callback_control_fields(
    control_field,
):
    """Callback-control fields in the request body are stripped at the proxy.

    ``callbacks`` / ``service_callback`` / ``logger_fn`` are appended to the
    worker-wide ``litellm.{input,success,failure,_async_*,service}_callback``
    lists and ``litellm.user_logger_fn`` from inside ``function_setup``, so
    a single request could poison every later caller in that worker.
    ``litellm_disabled_callbacks`` is the inverse: a request-body value that
    silently disables admin-configured audit/observability for the call.
    None of the four has a documented per-request use, so each must be
    removed both from the returned data and from the body snapshot.
    """
    request_mock = MagicMock(spec=Request)
    # NOTE: this path assignment is superseded by the fresh ``MagicMock``
    # bound to ``url`` on the next statement; the order is kept as-is.
    request_mock.url.path = "/v1/chat/completions"
    request_mock.url = MagicMock()
    request_mock.url.__str__.return_value = "http://localhost/v1/chat/completions"
    request_mock.client = MagicMock()
    request_mock.client.host = "127.0.0.1"
    request_mock.headers = {"Content-Type": "application/json"}
    request_mock.query_params = {}
    request_mock.method = "POST"

    # Three of the fields take a list of callback names; ``logger_fn`` takes
    # a single string here.
    list_valued_fields = (
        "callbacks",
        "service_callback",
        "litellm_disabled_callbacks",
    )
    if control_field in list_valued_fields:
        sample_value = ["langfuse"]
    else:
        sample_value = "module.func"

    payload = {
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hi"}],
        control_field: sample_value,
    }

    updated = await add_litellm_data_to_request(
        data=payload,
        request=request_mock,
        user_api_key_dict=UserAPIKeyAuth(api_key="hashed-key"),
        proxy_config=MagicMock(),
        general_settings={},
        version="test-version",
    )

    assert control_field not in updated

    # The post-strip body snapshot used by audit/spend logging must not
    # retain the attacker-injected control field either.
    snapshot_body = updated["proxy_server_request"]["body"]
    assert control_field not in snapshot_body
@pytest.mark.asyncio
async def test_add_litellm_data_to_request_allows_client_mock_response_with_admin_opt_in():
request_mock = MagicMock(spec=Request)

View File

@ -36,6 +36,33 @@ from litellm.proxy.vector_store_endpoints.utils import (
from litellm.types.utils import LlmProviders
def _serialize_litellm_params(litellm_params):
"""Serialize ``litellm_params`` to a string for substring assertions.
The redact helper preserves the persisted shape string in, string
out; dict in, dict out so callers that just want to assert "this
secret never appears" need a single text representation either way.
"""
import json
if isinstance(litellm_params, str):
return litellm_params
return json.dumps(litellm_params or {})
@pytest.fixture(autouse=True)
def _reset_embedding_config_cache():
    """Clear the embedding-config cache before and after every test.

    The use-time embedding-config resolver caches results in process memory
    across calls. Resetting it around each test makes the resolver exercise
    the router/DB path under test instead of returning a value cached by an
    earlier test.
    """
    from litellm.proxy.vector_store_endpoints import (
        management_endpoints as endpoints_module,
    )

    def _clear_cache():
        endpoints_module._embedding_config_cache = None

    _clear_cache()
    yield
    _clear_cache()
@pytest.mark.asyncio
async def test_router_avector_store_search_passes_correct_args():
"""
@ -170,6 +197,93 @@ async def test_update_request_data_with_litellm_managed_vector_store_registry():
assert result == original_data
@pytest.mark.asyncio
async def test_update_request_data_resolves_embedding_config_at_use_time():
    """A row holding only a model reference gets its config resolved per request.

    When the persisted vector store row carries just
    ``litellm_embedding_model`` (the behaviour after moving auto-resolve out
    of write time), the request-handling layer must resolve the embedding
    config so the downstream embed call still has ``api_key`` / ``api_base``
    / ``api_version``. The resolved config lives only in the per-request
    data dict and is never persisted.
    """
    embedding_model = "azure/text-embedding-3-large"
    mock_vector_store: LiteLLM_ManagedVectorStore = {
        "vector_store_id": "test_store",
        "custom_llm_provider": "azure_ai",
        # Only the model reference is stored; no litellm_embedding_config.
        "litellm_params": {"litellm_embedding_model": embedding_model},
    }
    resolved = {
        "api_key": "use-time-resolved-key",
        "api_base": "https://my-azure.example",
        "api_version": "2024-09-01",
    }

    mock_registry = MagicMock()
    registry_lookup = mock_registry.get_litellm_managed_vector_store_from_registry
    registry_lookup.return_value = mock_vector_store

    registry_patch = patch.object(litellm, "vector_store_registry", mock_registry)
    resolver_patch = patch(
        "litellm.proxy.vector_store_endpoints.endpoints._resolve_embedding_config",
        new=AsyncMock(return_value=resolved),
    )
    with registry_patch, resolver_patch:
        result = await _update_request_data_with_litellm_managed_vector_store_registry(
            data={}, vector_store_id="test_store"
        )

    assert result["litellm_embedding_model"] == embedding_model
    assert result["litellm_embedding_config"] == resolved
@pytest.mark.asyncio
async def test_update_request_data_passes_through_legacy_embedding_config():
    """A row that already stores a resolved config is used without re-resolving.

    Vector store rows created by an older proxy version may carry a fully
    resolved ``litellm_embedding_config`` in their persisted
    ``litellm_params``. Those legacy rows must keep working: the stored
    config is passed through unchanged and the resolver is never awaited.
    """
    legacy_config = {
        "api_key": "legacy-cleartext-key",
        "api_base": "https://legacy-azure.example",
        "api_version": "2024-01-01",
    }
    mock_vector_store: LiteLLM_ManagedVectorStore = {
        "vector_store_id": "legacy_store",
        "custom_llm_provider": "azure_ai",
        "litellm_params": {
            "litellm_embedding_model": "azure/text-embedding-3-large",
            "litellm_embedding_config": legacy_config,
        },
    }

    mock_registry = MagicMock()
    registry_lookup = mock_registry.get_litellm_managed_vector_store_from_registry
    registry_lookup.return_value = mock_vector_store

    resolve_mock = AsyncMock()
    registry_patch = patch.object(litellm, "vector_store_registry", mock_registry)
    resolver_patch = patch(
        "litellm.proxy.vector_store_endpoints.endpoints._resolve_embedding_config",
        new=resolve_mock,
    )
    with registry_patch, resolver_patch:
        result = await _update_request_data_with_litellm_managed_vector_store_registry(
            data={}, vector_store_id="legacy_store"
        )

    assert result["litellm_embedding_config"] == legacy_config
    resolve_mock.assert_not_awaited()
class TestCheckVectorStorePermission:
"""Test suite for check_vector_store_permission function."""
@ -1417,20 +1531,25 @@ async def test_new_vector_store_auto_resolves_embedding_config():
)
assert result["status"] == "success"
# Verify that embedding config was resolved and included in the create call
# Auto-resolve no longer happens at create time — the persisted row
# carries only the model reference, never the resolved cleartext
# credential. Resolution now happens at request-handling time inside
# ``_update_request_data_with_litellm_managed_vector_store_registry``,
# where the resolved config lives in per-request memory and is never
# written to the database.
litellm_params_json = captured_create_data.get("litellm_params")
assert litellm_params_json is not None
litellm_params_dict = json.loads(litellm_params_json)
assert "litellm_embedding_config" in litellm_params_dict
assert (
litellm_params_dict["litellm_embedding_config"]["api_key"] == "resolved-api-key"
)
assert (
litellm_params_dict["litellm_embedding_config"]["api_base"]
== "https://api.openai.com"
)
assert (
litellm_params_dict["litellm_embedding_config"]["api_version"] == "2024-01-01"
assert "litellm_embedding_config" not in litellm_params_dict
assert litellm_params_dict["litellm_embedding_model"] == "text-embedding-ada-002"
# The response must also not echo a cleartext credential — even on
# the create response, where redaction guards against caller-supplied
# cleartext or pre-existing rows that were created by an earlier
# proxy version.
response_vs = result["vector_store"]
assert "resolved-api-key" not in _serialize_litellm_params(
response_vs.get("litellm_params")
)
@ -1578,6 +1697,43 @@ async def test_resolve_embedding_config_tries_router_then_db():
mock_prisma_client.db.litellm_proxymodeltable.find_first.assert_not_called()
@pytest.mark.asyncio
async def test_resolve_embedding_config_caches_result():
    """A repeated lookup for the same model is served from the cache.

    The first call is expected to consult the router; the second call for
    the same model name must return an equal value while the router's call
    count stays at one.
    """
    from litellm.types.router import Deployment, LiteLLM_Params

    mock_litellm_params = MagicMock(spec=LiteLLM_Params)
    mock_litellm_params.api_key = "router-api-key"
    mock_litellm_params.api_base = "https://router-api-base.com"
    mock_litellm_params.api_version = None

    mock_deployment = MagicMock(spec=Deployment)
    mock_deployment.litellm_params = mock_litellm_params

    mock_prisma_client = MagicMock()
    mock_router = MagicMock()
    router_lookup = mock_router.get_deployment_by_model_group_name
    router_lookup.return_value = mock_deployment

    async def _lookup():
        return await _resolve_embedding_config(
            embedding_model="cached-model",
            prisma_client=mock_prisma_client,
            llm_router=mock_router,
        )

    first = await _lookup()
    assert first is not None
    assert router_lookup.call_count == 1

    second = await _lookup()
    assert second == first
    # Router (and by extension the DB) was not consulted again.
    assert router_lookup.call_count == 1
@pytest.mark.asyncio
async def test_resolve_embedding_config_falls_back_to_db():
"""Test that _resolve_embedding_config falls back to DB when router doesn't have the model."""
@ -1687,21 +1843,18 @@ async def test_new_vector_store_auto_resolves_from_router():
)
assert result["status"] == "success"
# Verify that embedding config was resolved from router and included in the create call
# Resolution against the router happens at request-handling time now,
# not at row creation. The persisted ``litellm_params`` carries only
# the model reference, never the cleartext credential.
litellm_params_json = captured_create_data.get("litellm_params")
assert litellm_params_json is not None
litellm_params_dict = json.loads(litellm_params_json)
assert "litellm_embedding_config" in litellm_params_dict
assert (
litellm_params_dict["litellm_embedding_config"]["api_key"]
== "router-resolved-api-key"
)
assert (
litellm_params_dict["litellm_embedding_config"]["api_base"]
== "https://router-resolved-base.com"
)
assert (
litellm_params_dict["litellm_embedding_config"]["api_version"] == "2024-03-01"
assert "litellm_embedding_config" not in litellm_params_dict
assert litellm_params_dict["litellm_embedding_model"] == "config-embedding-model"
response_vs = result["vector_store"]
assert "router-resolved-api-key" not in _serialize_litellm_params(
response_vs.get("litellm_params")
)