Merge remote-tracking branch 'upstream/litellm_internal_staging' into codex/skills-containers-tenant-guard
This commit is contained in:
commit
b5a14f22d6
@ -11,17 +11,23 @@ Has 4 methods:
|
||||
import ast
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, cast
|
||||
import os
|
||||
from typing import Any, Dict, cast
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose
|
||||
from litellm.constants import QDRANT_SCALAR_QUANTILE, QDRANT_VECTOR_SIZE
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
get_str_from_messages,
|
||||
)
|
||||
from litellm.types.utils import EmbeddingResponse
|
||||
|
||||
from .base_cache import BaseCache
|
||||
|
||||
|
||||
class QdrantSemanticCache(BaseCache):
|
||||
CACHE_KEY_FIELD_NAME = "litellm_cache_key"
|
||||
|
||||
def __init__( # noqa: PLR0915
|
||||
self,
|
||||
qdrant_api_base=None,
|
||||
@ -33,8 +39,6 @@ class QdrantSemanticCache(BaseCache):
|
||||
host_type=None,
|
||||
vector_size=None,
|
||||
):
|
||||
import os
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
_get_httpx_client,
|
||||
get_async_httpx_client,
|
||||
@ -115,7 +119,9 @@ class QdrantSemanticCache(BaseCache):
|
||||
print_verbose(
|
||||
f"Collection already exists.\nCollection details:{self.collection_info}"
|
||||
)
|
||||
self._ensure_cache_key_payload_index()
|
||||
else:
|
||||
quantization_params: Dict[str, Any]
|
||||
if quantization_config is None or quantization_config == "binary":
|
||||
quantization_params = {
|
||||
"binary": {
|
||||
@ -156,6 +162,7 @@ class QdrantSemanticCache(BaseCache):
|
||||
print_verbose(
|
||||
f"New collection created.\nCollection details:{self.collection_info}"
|
||||
)
|
||||
self._ensure_cache_key_payload_index()
|
||||
else:
|
||||
raise Exception("Error while creating new collection")
|
||||
|
||||
@ -170,15 +177,94 @@ class QdrantSemanticCache(BaseCache):
|
||||
cached_response = ast.literal_eval(cached_response)
|
||||
return cached_response
|
||||
|
||||
def _get_qdrant_cache_key_filter(self, key: str) -> dict:
|
||||
return {
|
||||
"must": [
|
||||
{
|
||||
"key": self.CACHE_KEY_FIELD_NAME,
|
||||
"match": {"value": str(key)},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
def _add_cache_key_filter_to_search_data(self, data: dict, key: str) -> None:
|
||||
data["filter"] = self._get_qdrant_cache_key_filter(key)
|
||||
|
||||
def _ensure_cache_key_payload_index(self) -> None:
|
||||
try:
|
||||
response = self.sync_client.put(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/index",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"field_name": self.CACHE_KEY_FIELD_NAME,
|
||||
"field_schema": "keyword",
|
||||
},
|
||||
)
|
||||
if response.status_code not in (200, 201):
|
||||
print_verbose(
|
||||
"Qdrant semantic-cache could not create cache-key payload index: "
|
||||
f"{response.text}"
|
||||
)
|
||||
except Exception as exc:
|
||||
print_verbose(
|
||||
"Qdrant semantic-cache could not create cache-key payload index: "
|
||||
f"{str(exc)}"
|
||||
)
|
||||
|
||||
def _payload_matches_cache_key(self, payload: dict, key: str) -> bool:
|
||||
# Pre-isolation points stored only prompt + response with no cache-key
|
||||
# payload field. Reassigning them to a caller's key would risk
|
||||
# cross-scope hits, so they're treated as misses and re-populated on
|
||||
# the next set_cache.
|
||||
cached_key = payload.get(self.CACHE_KEY_FIELD_NAME)
|
||||
return cached_key is not None and str(cached_key) == str(key)
|
||||
|
||||
async def _get_async_embedding(self, prompt: str, **kwargs) -> Any:
|
||||
llm_model_list = None
|
||||
llm_router = None
|
||||
|
||||
try:
|
||||
from litellm.proxy.proxy_server import (
|
||||
llm_model_list as proxy_llm_model_list,
|
||||
llm_router as proxy_llm_router,
|
||||
)
|
||||
|
||||
llm_model_list = proxy_llm_model_list
|
||||
llm_router = proxy_llm_router
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
router_model_names = (
|
||||
[m["model_name"] for m in llm_model_list]
|
||||
if llm_model_list is not None
|
||||
else []
|
||||
)
|
||||
if llm_router is not None and self.embedding_model in router_model_names:
|
||||
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
|
||||
return await llm_router.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
metadata={
|
||||
"user_api_key": user_api_key,
|
||||
"semantic-cache-embedding": True,
|
||||
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
|
||||
},
|
||||
)
|
||||
|
||||
return await litellm.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
)
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
print_verbose(f"qdrant semantic-cache set_cache, kwargs: {kwargs}")
|
||||
from litellm._uuid import uuid
|
||||
|
||||
# get the prompt
|
||||
messages = kwargs["messages"]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
prompt += message["content"]
|
||||
prompt = get_str_from_messages(messages)
|
||||
|
||||
# create an embedding for prompt
|
||||
embedding_response = cast(
|
||||
@ -202,6 +288,7 @@ class QdrantSemanticCache(BaseCache):
|
||||
"id": str(uuid.uuid4()),
|
||||
"vector": embedding,
|
||||
"payload": {
|
||||
self.CACHE_KEY_FIELD_NAME: str(key),
|
||||
"text": prompt,
|
||||
"response": value,
|
||||
},
|
||||
@ -220,9 +307,7 @@ class QdrantSemanticCache(BaseCache):
|
||||
|
||||
# get the messages
|
||||
messages = kwargs["messages"]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
prompt += message["content"]
|
||||
prompt = get_str_from_messages(messages)
|
||||
|
||||
# convert to embedding
|
||||
embedding_response = cast(
|
||||
@ -249,6 +334,7 @@ class QdrantSemanticCache(BaseCache):
|
||||
"limit": 1,
|
||||
"with_payload": True,
|
||||
}
|
||||
self._add_cache_key_filter_to_search_data(data=data, key=key)
|
||||
|
||||
search_response = self.sync_client.post(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
|
||||
@ -258,21 +344,33 @@ class QdrantSemanticCache(BaseCache):
|
||||
results = search_response.json()["result"]
|
||||
|
||||
if results is None:
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
if isinstance(results, list):
|
||||
if len(results) == 0:
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
|
||||
similarity = results[0]["score"]
|
||||
cached_prompt = results[0]["payload"]["text"]
|
||||
payload = results[0]["payload"]
|
||||
if not self._payload_matches_cache_key(payload=payload, key=key):
|
||||
print_verbose("Qdrant semantic-cache hit did not match cache key scope")
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
|
||||
cached_prompt = payload["text"]
|
||||
|
||||
# check similarity, if more than self.similarity_threshold, return results
|
||||
print_verbose(
|
||||
f"semantic cache: similarity threshold: {self.similarity_threshold}, similarity: {similarity}, prompt: {prompt}, closest_cached_prompt: {cached_prompt}"
|
||||
)
|
||||
|
||||
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
|
||||
|
||||
if similarity >= self.similarity_threshold:
|
||||
# cache hit !
|
||||
cached_value = results[0]["payload"]["response"]
|
||||
cached_value = payload["response"]
|
||||
print_verbose(
|
||||
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
|
||||
)
|
||||
@ -285,40 +383,12 @@ class QdrantSemanticCache(BaseCache):
|
||||
async def async_set_cache(self, key, value, **kwargs):
|
||||
from litellm._uuid import uuid
|
||||
|
||||
from litellm.proxy.proxy_server import llm_model_list, llm_router
|
||||
|
||||
print_verbose(f"async qdrant semantic-cache set_cache, kwargs: {kwargs}")
|
||||
|
||||
# get the prompt
|
||||
messages = kwargs["messages"]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
prompt += message["content"]
|
||||
# create an embedding for prompt
|
||||
router_model_names = (
|
||||
[m["model_name"] for m in llm_model_list]
|
||||
if llm_model_list is not None
|
||||
else []
|
||||
)
|
||||
if llm_router is not None and self.embedding_model in router_model_names:
|
||||
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
|
||||
embedding_response = await llm_router.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
metadata={
|
||||
"user_api_key": user_api_key,
|
||||
"semantic-cache-embedding": True,
|
||||
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
|
||||
},
|
||||
)
|
||||
else:
|
||||
# convert to embedding
|
||||
embedding_response = await litellm.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
)
|
||||
prompt = get_str_from_messages(messages)
|
||||
embedding_response = await self._get_async_embedding(prompt, **kwargs)
|
||||
|
||||
# get the embedding
|
||||
embedding = embedding_response["data"][0]["embedding"]
|
||||
@ -332,6 +402,7 @@ class QdrantSemanticCache(BaseCache):
|
||||
"id": str(uuid.uuid4()),
|
||||
"vector": embedding,
|
||||
"payload": {
|
||||
self.CACHE_KEY_FIELD_NAME: str(key),
|
||||
"text": prompt,
|
||||
"response": value,
|
||||
},
|
||||
@ -348,38 +419,12 @@ class QdrantSemanticCache(BaseCache):
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
print_verbose(f"async qdrant semantic-cache get_cache, kwargs: {kwargs}")
|
||||
from litellm.proxy.proxy_server import llm_model_list, llm_router
|
||||
|
||||
# get the messages
|
||||
messages = kwargs["messages"]
|
||||
prompt = ""
|
||||
for message in messages:
|
||||
prompt += message["content"]
|
||||
prompt = get_str_from_messages(messages)
|
||||
|
||||
router_model_names = (
|
||||
[m["model_name"] for m in llm_model_list]
|
||||
if llm_model_list is not None
|
||||
else []
|
||||
)
|
||||
if llm_router is not None and self.embedding_model in router_model_names:
|
||||
user_api_key = kwargs.get("metadata", {}).get("user_api_key", "")
|
||||
embedding_response = await llm_router.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
metadata={
|
||||
"user_api_key": user_api_key,
|
||||
"semantic-cache-embedding": True,
|
||||
"trace_id": kwargs.get("metadata", {}).get("trace_id", None),
|
||||
},
|
||||
)
|
||||
else:
|
||||
# convert to embedding
|
||||
embedding_response = await litellm.aembedding(
|
||||
model=self.embedding_model,
|
||||
input=prompt,
|
||||
cache={"no-store": True, "no-cache": True},
|
||||
)
|
||||
embedding_response = await self._get_async_embedding(prompt, **kwargs)
|
||||
|
||||
# get the embedding
|
||||
embedding = embedding_response["data"][0]["embedding"]
|
||||
@ -396,6 +441,7 @@ class QdrantSemanticCache(BaseCache):
|
||||
"limit": 1,
|
||||
"with_payload": True,
|
||||
}
|
||||
self._add_cache_key_filter_to_search_data(data=data, key=key)
|
||||
|
||||
search_response = await self.async_client.post(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points/search",
|
||||
@ -414,7 +460,13 @@ class QdrantSemanticCache(BaseCache):
|
||||
return None
|
||||
|
||||
similarity = results[0]["score"]
|
||||
cached_prompt = results[0]["payload"]["text"]
|
||||
payload = results[0]["payload"]
|
||||
if not self._payload_matches_cache_key(payload=payload, key=key):
|
||||
print_verbose("Qdrant semantic-cache hit did not match cache key scope")
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
|
||||
cached_prompt = payload["text"]
|
||||
|
||||
# check similarity, if more than self.similarity_threshold, return results
|
||||
print_verbose(
|
||||
@ -426,7 +478,7 @@ class QdrantSemanticCache(BaseCache):
|
||||
|
||||
if similarity >= self.similarity_threshold:
|
||||
# cache hit !
|
||||
cached_value = results[0]["payload"]["response"]
|
||||
cached_value = payload["response"]
|
||||
print_verbose(
|
||||
f"got a cache hit, similarity: {similarity}, Current prompt: {prompt}, cached_prompt: {cached_prompt}"
|
||||
)
|
||||
|
||||
@ -35,6 +35,7 @@ class RedisSemanticCache(BaseCache):
|
||||
"""
|
||||
|
||||
DEFAULT_REDIS_INDEX_NAME: str = "litellm_semantic_cache_index"
|
||||
CACHE_KEY_FIELD_NAME: str = "litellm_cache_key"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@ -66,8 +67,8 @@ class RedisSemanticCache(BaseCache):
|
||||
Exception: If similarity_threshold is not provided or required Redis
|
||||
connection information is missing
|
||||
"""
|
||||
from redisvl.extensions.llmcache import SemanticCache
|
||||
from redisvl.utils.vectorize import CustomTextVectorizer
|
||||
from redisvl.extensions.llmcache import SemanticCache # type: ignore[import-not-found, import-untyped]
|
||||
from redisvl.utils.vectorize import CustomTextVectorizer # type: ignore[import-not-found, import-untyped]
|
||||
|
||||
if index_name is None:
|
||||
index_name = self.DEFAULT_REDIS_INDEX_NAME
|
||||
@ -109,14 +110,94 @@ class RedisSemanticCache(BaseCache):
|
||||
# Initialize the Redis vectorizer and cache
|
||||
cache_vectorizer = CustomTextVectorizer(self._get_embedding)
|
||||
|
||||
self.llmcache = SemanticCache(
|
||||
name=index_name,
|
||||
self.llmcache = self._init_semantic_cache(
|
||||
semantic_cache_cls=SemanticCache,
|
||||
index_name=index_name,
|
||||
redis_url=redis_url,
|
||||
vectorizer=cache_vectorizer,
|
||||
distance_threshold=self.distance_threshold,
|
||||
overwrite=False,
|
||||
cache_vectorizer=cache_vectorizer,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _cache_key_filterable_field(cls) -> Dict[str, str]:
|
||||
return {
|
||||
"name": cls.CACHE_KEY_FIELD_NAME,
|
||||
"type": "tag",
|
||||
}
|
||||
|
||||
def _init_semantic_cache(
|
||||
self,
|
||||
semantic_cache_cls: Any,
|
||||
index_name: str,
|
||||
redis_url: str,
|
||||
cache_vectorizer: Any,
|
||||
) -> Any:
|
||||
def _is_schema_mismatch(exc: ValueError) -> bool:
|
||||
error_message = str(exc).lower()
|
||||
return any(
|
||||
phrase in error_message
|
||||
for phrase in ("schema does not match", "index schema")
|
||||
)
|
||||
|
||||
try:
|
||||
return semantic_cache_cls(
|
||||
name=index_name,
|
||||
redis_url=redis_url,
|
||||
vectorizer=cache_vectorizer,
|
||||
distance_threshold=self.distance_threshold,
|
||||
filterable_fields=[self._cache_key_filterable_field()],
|
||||
overwrite=False,
|
||||
)
|
||||
except ValueError as exc:
|
||||
if not _is_schema_mismatch(exc):
|
||||
raise
|
||||
|
||||
isolated_index_name = f"{index_name}_isolated"
|
||||
print_verbose(
|
||||
"Redis semantic-cache existing index schema is not isolated; "
|
||||
f"using isolated index - {isolated_index_name}"
|
||||
)
|
||||
try:
|
||||
return semantic_cache_cls(
|
||||
name=isolated_index_name,
|
||||
redis_url=redis_url,
|
||||
vectorizer=cache_vectorizer,
|
||||
distance_threshold=self.distance_threshold,
|
||||
filterable_fields=[self._cache_key_filterable_field()],
|
||||
overwrite=False,
|
||||
)
|
||||
except ValueError as isolated_exc:
|
||||
if not _is_schema_mismatch(isolated_exc):
|
||||
raise
|
||||
|
||||
print_verbose(
|
||||
"Redis semantic-cache isolated index schema is stale; "
|
||||
f"recreating isolated index - {isolated_index_name}"
|
||||
)
|
||||
return semantic_cache_cls(
|
||||
name=isolated_index_name,
|
||||
redis_url=redis_url,
|
||||
vectorizer=cache_vectorizer,
|
||||
distance_threshold=self.distance_threshold,
|
||||
filterable_fields=[self._cache_key_filterable_field()],
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
def _get_cache_filters(self, key: str) -> Dict[str, str]:
|
||||
return {self.CACHE_KEY_FIELD_NAME: str(key)}
|
||||
|
||||
def _get_cache_key_filter_expression(self, key: str) -> Any:
    """Build the redisvl tag expression used to scope a lookup to ``key``.

    Args:
        key: Cache key; compared in its string form.

    Returns:
        The result of comparing ``Tag(CACHE_KEY_FIELD_NAME)`` with
        ``str(key)`` using ``==``.
    """
    # Imported lazily so redisvl is only required when this method is used.
    from redisvl.query.filter import Tag  # type: ignore[import-not-found, import-untyped]

    cache_key_tag = Tag(self.CACHE_KEY_FIELD_NAME)
    return cache_key_tag == str(key)
|
||||
|
||||
def _cache_hit_matches_key(self, cache_hit: Dict[str, Any], key: str) -> bool:
|
||||
# Pre-isolation entries with no ``litellm_cache_key`` field cannot be
|
||||
# safely reassigned to a caller's scope and are treated as misses.
|
||||
cached_key = cache_hit.get(self.CACHE_KEY_FIELD_NAME)
|
||||
if isinstance(cached_key, bytes):
|
||||
cached_key = cached_key.decode("utf-8")
|
||||
return cached_key is not None and str(cached_key) == str(key)
|
||||
|
||||
def _get_ttl(self, **kwargs) -> Optional[int]:
|
||||
"""
|
||||
Get the TTL (time-to-live) value for cache entries.
|
||||
@ -188,7 +269,7 @@ class RedisSemanticCache(BaseCache):
|
||||
Store a value in the semantic cache.
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
key: The cache key used to isolate semantic cache entries
|
||||
value: The response value to cache
|
||||
**kwargs: Additional arguments including 'messages' for the prompt
|
||||
and optional 'ttl' for time-to-live
|
||||
@ -206,12 +287,15 @@ class RedisSemanticCache(BaseCache):
|
||||
prompt = get_str_from_messages(messages)
|
||||
value_str = str(value)
|
||||
|
||||
store_kwargs: Dict[str, Any] = {
|
||||
"filters": self._get_cache_filters(key),
|
||||
}
|
||||
|
||||
# Get TTL and store in Redis semantic cache
|
||||
ttl = self._get_ttl(**kwargs)
|
||||
if ttl is not None:
|
||||
self.llmcache.store(prompt, value_str, ttl=int(ttl))
|
||||
else:
|
||||
self.llmcache.store(prompt, value_str)
|
||||
store_kwargs["ttl"] = int(ttl)
|
||||
self.llmcache.store(prompt, value_str, **store_kwargs)
|
||||
except Exception as e:
|
||||
print_verbose(
|
||||
f"Error setting {value_str or value} in the Redis semantic cache: {str(e)}"
|
||||
@ -222,7 +306,7 @@ class RedisSemanticCache(BaseCache):
|
||||
Retrieve a semantically similar cached response.
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
key: The cache key used to isolate semantic cache entries
|
||||
**kwargs: Additional arguments including 'messages' for the prompt
|
||||
|
||||
Returns:
|
||||
@ -235,18 +319,29 @@ class RedisSemanticCache(BaseCache):
|
||||
messages = kwargs.get("messages", [])
|
||||
if not messages:
|
||||
print_verbose("No messages provided for semantic cache lookup")
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
|
||||
prompt = get_str_from_messages(messages)
|
||||
# Check the cache for semantically similar prompts
|
||||
results = self.llmcache.check(prompt=prompt)
|
||||
# Check the cache for semantically similar prompts in this exact
|
||||
# LiteLLM cache-key scope.
|
||||
check_kwargs: Dict[str, Any] = {
|
||||
"prompt": prompt,
|
||||
"filter_expression": self._get_cache_key_filter_expression(key),
|
||||
}
|
||||
results = self.llmcache.check(**check_kwargs)
|
||||
|
||||
# Return None if no similar prompts found
|
||||
if not results:
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
|
||||
# Process the best matching result
|
||||
cache_hit = results[0]
|
||||
if not self._cache_hit_matches_key(cache_hit=cache_hit, key=key):
|
||||
print_verbose("Redis semantic-cache hit did not match cache key scope")
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
vector_distance = float(cache_hit["vector_distance"])
|
||||
|
||||
# Convert vector distance back to similarity score
|
||||
@ -257,6 +352,9 @@ class RedisSemanticCache(BaseCache):
|
||||
cached_prompt = cache_hit["prompt"]
|
||||
cached_response = cache_hit["response"]
|
||||
|
||||
# update kwargs["metadata"] with similarity, don't rewrite the original metadata
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = similarity
|
||||
|
||||
print_verbose(
|
||||
f"Cache hit: similarity threshold: {self.similarity_threshold}, "
|
||||
f"actual similarity: {similarity}, "
|
||||
@ -267,6 +365,7 @@ class RedisSemanticCache(BaseCache):
|
||||
return self._get_cache_logic(cached_response=cached_response)
|
||||
except Exception as e:
|
||||
print_verbose(f"Error retrieving from Redis semantic cache: {str(e)}")
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
|
||||
async def _get_async_embedding(self, prompt: str, **kwargs) -> List[float]:
|
||||
"""
|
||||
@ -321,7 +420,7 @@ class RedisSemanticCache(BaseCache):
|
||||
Asynchronously store a value in the semantic cache.
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
key: The cache key used to isolate semantic cache entries
|
||||
value: The response value to cache
|
||||
**kwargs: Additional arguments including 'messages' for the prompt
|
||||
and optional 'ttl' for time-to-live
|
||||
@ -341,21 +440,20 @@ class RedisSemanticCache(BaseCache):
|
||||
# Generate embedding for the value (response) to cache
|
||||
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
|
||||
|
||||
store_kwargs: Dict[str, Any] = {
|
||||
"vector": prompt_embedding,
|
||||
"filters": self._get_cache_filters(key),
|
||||
}
|
||||
|
||||
# Get TTL and store in Redis semantic cache
|
||||
ttl = self._get_ttl(**kwargs)
|
||||
if ttl is not None:
|
||||
await self.llmcache.astore(
|
||||
prompt,
|
||||
value_str,
|
||||
vector=prompt_embedding, # Pass through custom embedding
|
||||
ttl=ttl,
|
||||
)
|
||||
else:
|
||||
await self.llmcache.astore(
|
||||
prompt,
|
||||
value_str,
|
||||
vector=prompt_embedding, # Pass through custom embedding
|
||||
)
|
||||
store_kwargs["ttl"] = ttl
|
||||
await self.llmcache.astore(
|
||||
prompt,
|
||||
value_str,
|
||||
**store_kwargs,
|
||||
)
|
||||
except Exception as e:
|
||||
print_verbose(f"Error in async_set_cache: {str(e)}")
|
||||
|
||||
@ -364,7 +462,7 @@ class RedisSemanticCache(BaseCache):
|
||||
Asynchronously retrieve a semantically similar cached response.
|
||||
|
||||
Args:
|
||||
key: The cache key (not directly used in semantic caching)
|
||||
key: The cache key used to isolate semantic cache entries
|
||||
**kwargs: Additional arguments including 'messages' for the prompt
|
||||
|
||||
Returns:
|
||||
@ -385,17 +483,25 @@ class RedisSemanticCache(BaseCache):
|
||||
# Generate embedding for the prompt
|
||||
prompt_embedding = await self._get_async_embedding(prompt, **kwargs)
|
||||
|
||||
# Check the cache for semantically similar prompts
|
||||
results = await self.llmcache.acheck(prompt=prompt, vector=prompt_embedding)
|
||||
# Check the cache for semantically similar prompts in this exact
|
||||
# LiteLLM cache-key scope.
|
||||
check_kwargs: Dict[str, Any] = {
|
||||
"prompt": prompt,
|
||||
"vector": prompt_embedding,
|
||||
"filter_expression": self._get_cache_key_filter_expression(key),
|
||||
}
|
||||
results = await self.llmcache.acheck(**check_kwargs)
|
||||
|
||||
# handle results / cache hit
|
||||
if not results:
|
||||
kwargs.setdefault("metadata", {})[
|
||||
"semantic-similarity"
|
||||
] = 0.0 # TODO why here but not above??
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
|
||||
cache_hit = results[0]
|
||||
if not self._cache_hit_matches_key(cache_hit=cache_hit, key=key):
|
||||
print_verbose("Redis semantic-cache hit did not match cache key scope")
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
vector_distance = float(cache_hit["vector_distance"])
|
||||
|
||||
# Convert vector distance back to similarity
|
||||
|
||||
@ -2,7 +2,7 @@ import os
|
||||
import re
|
||||
import sys
|
||||
from functools import lru_cache
|
||||
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
|
||||
from typing import Any, Dict, FrozenSet, List, Mapping, Optional, Tuple, Union
|
||||
|
||||
from fastapi import HTTPException, Request, status
|
||||
|
||||
@ -173,9 +173,67 @@ def _allow_model_level_clientside_configurable_parameters(
|
||||
# threat shape should be added here.
|
||||
_NESTED_CONFIG_KEYS: Tuple[str, ...] = ("litellm_embedding_config",)
|
||||
|
||||
# Banned root-level params. Same list applies to every entry in
|
||||
# ``_NESTED_CONFIG_KEYS`` because those dicts get spread as ``**kwargs``
|
||||
# into the same outbound calls.
|
||||
# Metadata containers that carry per-request configuration consumed by the
|
||||
# observability callbacks. The same banned-param list applies — a value
|
||||
# under ``metadata.langfuse_host`` redirects the same Langfuse client and
|
||||
# leaks the same credentials as the root-level ``langfuse_host``, but the
|
||||
# original check only walked the request-body root, so the metadata path
|
||||
# was an unintentional bypass.
|
||||
_NESTED_METADATA_KEYS: Tuple[str, ...] = ("metadata", "litellm_metadata")
|
||||
|
||||
# Banned request-body params. The same list applies to every entry in
|
||||
# ``_NESTED_CONFIG_KEYS`` (dicts spread as ``**kwargs`` into outbound
|
||||
# calls) and ``_NESTED_METADATA_KEYS`` (dicts read directly by integration
|
||||
# callbacks), so a single banned name is enforced wherever the field can
|
||||
# reach the call path from.
|
||||
# Per-request observability params that are SAFE to accept from clients.
|
||||
# These describe the request being logged (prompt version, sampling rate)
|
||||
# without choosing the destination or the credentials, so they don't
|
||||
# contribute to the data-exfil primitive that the rest of
|
||||
# ``_supported_callback_params`` does.
|
||||
_SAFE_CLIENT_CALLBACK_PARAMS: FrozenSet[str] = frozenset(
|
||||
{
|
||||
"langfuse_prompt_version",
|
||||
"langsmith_sampling_rate",
|
||||
}
|
||||
)
|
||||
|
||||
# Observability fields that integrations read from the request body or
|
||||
# metadata but that are not (yet) listed in ``_supported_callback_params``.
|
||||
# Listed here so the proxy bans them today; the long-term cleanup is to
|
||||
# fold these into the canonical allowlist so they share one source of
|
||||
# truth with the rest.
|
||||
_EXTRA_BANNED_OBSERVABILITY_PARAMS: FrozenSet[str] = frozenset(
|
||||
{
|
||||
"posthog_api_url",
|
||||
"phoenix_project_name",
|
||||
"wandb_api_key",
|
||||
"weave_project_id",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _build_banned_observability_params() -> FrozenSet[str]:
    """Derive the observability ban list from the canonical allowlist.

    Starts from ``_supported_callback_params`` (defined in
    ``litellm/litellm_core_utils/initialize_dynamic_callback_params.py``),
    removes the informational fields in ``_SAFE_CLIENT_CALLBACK_PARAMS`` and
    adds ``_EXTRA_BANNED_OBSERVABILITY_PARAMS``. Anything newly added to the
    canonical allowlist is therefore banned unless it is explicitly marked
    safe.

    Returns:
        The frozenset of banned observability parameter names.
    """
    from litellm.litellm_core_utils.initialize_dynamic_callback_params import (
        _supported_callback_params,
    )

    canonical_params = frozenset(_supported_callback_params)
    banned_from_allowlist = canonical_params - _SAFE_CLIENT_CALLBACK_PARAMS
    return banned_from_allowlist | _EXTRA_BANNED_OBSERVABILITY_PARAMS
|
||||
|
||||
|
||||
_BANNED_REQUEST_BODY_PARAMS: Tuple[str, ...] = (
|
||||
"api_base",
|
||||
"base_url",
|
||||
@ -190,11 +248,6 @@ _BANNED_REQUEST_BODY_PARAMS: Tuple[str, ...] = (
|
||||
# tokens) to the attacker's host, or coerces the proxy into
|
||||
# authenticating against the attacker's host with admin secrets.
|
||||
"aws_bedrock_runtime_endpoint",
|
||||
"langsmith_base_url",
|
||||
"langfuse_host",
|
||||
"posthog_host",
|
||||
"braintrust_host",
|
||||
"slack_webhook_url",
|
||||
# Provider-specific endpoint overrides that flow into the outbound
|
||||
# request via ``optional_params``. Same threat as ``api_base``:
|
||||
# ``s3_endpoint_url`` redirects Bedrock file uploads to attacker
|
||||
@ -203,6 +256,11 @@ _BANNED_REQUEST_BODY_PARAMS: Tuple[str, ...] = (
|
||||
"s3_endpoint_url",
|
||||
"sagemaker_base_url",
|
||||
"deployment_url",
|
||||
# Observability credentials, hosts, and project identifiers: derived
|
||||
# from the canonical ``_supported_callback_params`` allowlist so new
|
||||
# integrations are covered automatically. Sorted for stable iteration
|
||||
# order and reviewable diffs.
|
||||
*sorted(_build_banned_observability_params()),
|
||||
)
|
||||
|
||||
|
||||
@ -221,6 +279,8 @@ def _check_banned_params(
|
||||
if param not in body:
|
||||
continue
|
||||
if general_settings.get("allow_client_side_credentials") is True:
|
||||
# Proxy-wide opt-in: every banned param is permitted, exit
|
||||
# entirely so the rest of the loop doesn't waste work.
|
||||
return
|
||||
if (
|
||||
_allow_model_level_clientside_configurable_parameters(
|
||||
@ -231,7 +291,12 @@ def _check_banned_params(
|
||||
)
|
||||
is True
|
||||
):
|
||||
return
|
||||
# Per-param opt-in: only THIS param is permitted by the
|
||||
# deployment's ``configurable_clientside_auth_params``. Skip
|
||||
# to the next banned param so a body that pairs an allowed
|
||||
# ``api_base`` with an unallowed ``langfuse_host`` is still
|
||||
# rejected for the second field.
|
||||
continue
|
||||
raise ValueError(
|
||||
f"Rejected Request: {param} is not allowed in request body. "
|
||||
"Clientside passthrough requires explicit admin opt-in via "
|
||||
@ -275,9 +340,33 @@ def is_request_body_safe(
|
||||
nested = request_body.get(nested_key)
|
||||
if isinstance(nested, dict):
|
||||
_check_banned_params(nested, general_settings, llm_router, model)
|
||||
for metadata_key in _NESTED_METADATA_KEYS:
|
||||
metadata = _coerce_metadata_to_dict(request_body.get(metadata_key))
|
||||
if metadata is not None:
|
||||
_check_banned_params(metadata, general_settings, llm_router, model)
|
||||
return True
|
||||
|
||||
|
||||
def _coerce_metadata_to_dict(value: Any) -> Optional[Dict[str, Any]]:
|
||||
"""Return ``value`` as a dict, parsing it from JSON if delivered as a string.
|
||||
|
||||
Multipart/form-data and ``extra_body`` callers send ``litellm_metadata``
|
||||
as a JSON-encoded string; the proxy parses it into a dict later in
|
||||
``add_litellm_data_to_request``, but the auth-time bouncer runs first
|
||||
and would otherwise miss the banned-param check on a still-stringified
|
||||
metadata blob.
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
|
||||
|
||||
parsed = safe_json_loads(value)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
return None
|
||||
|
||||
|
||||
async def pre_db_read_auth_checks(
|
||||
request: Request,
|
||||
request_data: dict,
|
||||
|
||||
@ -121,6 +121,19 @@ _UNTRUSTED_ROOT_CONTROL_FIELDS = (
|
||||
"pillar_response_headers",
|
||||
"_guardrail_pipelines",
|
||||
"_pipeline_managed_guardrails",
|
||||
# Callback-registration fields. ``callbacks``, ``service_callback``,
|
||||
# and ``logger_fn`` are read by ``litellm.utils.function_setup`` and
|
||||
# appended to process-wide ``litellm.{input,success,failure,_async_*,
|
||||
# service}_callback`` lists / ``litellm.user_logger_fn`` — one request
|
||||
# poisons the worker for every subsequent caller.
|
||||
# ``litellm_disabled_callbacks`` is the inverse primitive: the
|
||||
# legitimate path reads it from key/team metadata, the request-body
|
||||
# version silently turns off admin-configured audit/observability
|
||||
# for the caller's request.
|
||||
"callbacks",
|
||||
"service_callback",
|
||||
"logger_fn",
|
||||
"litellm_disabled_callbacks",
|
||||
)
|
||||
|
||||
_UNTRUSTED_METADATA_CONTROL_FIELDS = (
|
||||
|
||||
@ -8,6 +8,9 @@ from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_request_processing import ProxyBaseLLMRequestProcessing
|
||||
from litellm.proxy.utils import jsonify_object
|
||||
from litellm.proxy.vector_store_endpoints.management_endpoints import (
|
||||
_resolve_embedding_config,
|
||||
)
|
||||
from litellm.proxy.vector_store_endpoints.utils import (
|
||||
assert_user_can_access_vector_store,
|
||||
get_litellm_managed_vector_store,
|
||||
@ -56,6 +59,30 @@ async def _update_request_data_with_litellm_managed_vector_store_registry(
|
||||
|
||||
if "litellm_params" in vector_store_to_run:
|
||||
litellm_params = vector_store_to_run.get("litellm_params", {}) or {}
|
||||
# Resolve ``litellm_embedding_config`` here, at request-handling
|
||||
# time, instead of at row-creation time. The resolved
|
||||
# ``api_key`` / ``api_base`` / ``api_version`` lives only in
|
||||
# this per-request ``data`` dict and is never persisted.
|
||||
# Legacy rows that already carry a resolved (cleartext)
|
||||
# ``litellm_embedding_config`` skip the lookup and pass through
|
||||
# unchanged so the embed call keeps working.
|
||||
embedding_model = litellm_params.get("litellm_embedding_model")
|
||||
if embedding_model and not litellm_params.get("litellm_embedding_config"):
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
resolved_config = await _resolve_embedding_config(
|
||||
embedding_model=embedding_model, prisma_client=prisma_client
|
||||
)
|
||||
if resolved_config:
|
||||
# Build a fresh dict via spread instead of mutating
|
||||
# ``litellm_params`` in place — the registry hands back
|
||||
# a reference to its cached object, so an in-place
|
||||
# update would persist the resolved cleartext into the
|
||||
# in-memory cache for the lifetime of the process.
|
||||
litellm_params = {
|
||||
**litellm_params,
|
||||
"litellm_embedding_config": resolved_config,
|
||||
}
|
||||
data.update(litellm_params)
|
||||
return data
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ from fastapi import APIRouter, Depends, HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.in_memory_cache import InMemoryCache
|
||||
from litellm.constants import REDACTED_BY_LITELM_STRING
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
|
||||
@ -45,6 +46,28 @@ _LITELLM_PARAMS_MASKER = SensitiveDataMasker()
|
||||
|
||||
_REDACT_LITELLM_PARAMS_MAX_DEPTH = 10
|
||||
|
||||
# Use-time embedding-config resolution runs on every vector-store request
|
||||
# whose persisted row carries only a model reference (the post-fix shape).
|
||||
# Without a cache, that's one ``litellm_proxymodeltable.find_first`` per
|
||||
# request, which breaks the no-DB-in-critical-path rule. Hold the resolved config in
|
||||
# memory for a short TTL so a hot model name pays the DB lookup at most
|
||||
# once per ``_EMBEDDING_CONFIG_CACHE_TTL`` seconds. Cleartext credentials
|
||||
# only ever live in process memory (never persisted, never echoed in
|
||||
# management responses), so the cache doesn't widen the disclosure surface.
|
||||
_EMBEDDING_CONFIG_CACHE_TTL = 60
|
||||
_EMBEDDING_CONFIG_CACHE_MAX_SIZE = 256
|
||||
_embedding_config_cache: Optional[InMemoryCache] = None
|
||||
|
||||
|
||||
def _get_embedding_config_cache() -> InMemoryCache:
    """Return the module-level embedding-config cache, creating it on first use.

    The cache is built with ``_EMBEDDING_CONFIG_CACHE_MAX_SIZE`` as its size
    limit and ``_EMBEDDING_CONFIG_CACHE_TTL`` as its default TTL, then stored
    in the ``_embedding_config_cache`` global so later calls reuse it.
    """
    global _embedding_config_cache
    cache = _embedding_config_cache
    if cache is None:
        cache = InMemoryCache(
            max_size_in_memory=_EMBEDDING_CONFIG_CACHE_MAX_SIZE,
            default_ttl=_EMBEDDING_CONFIG_CACHE_TTL,
        )
        _embedding_config_cache = cache
    return cache
|
||||
|
||||
|
||||
def _redact_sensitive_litellm_params(litellm_params: Any, _depth: int = 0) -> Any:
|
||||
"""
|
||||
@ -303,6 +326,11 @@ async def _resolve_embedding_config(
|
||||
This function first checks the router for config-defined models, then falls back
|
||||
to the database. This allows users to use models defined in either location.
|
||||
|
||||
Results are cached in process memory for ``_EMBEDDING_CONFIG_CACHE_TTL``
|
||||
seconds so the request-handling path doesn't hit the database on every
|
||||
vector-store call. Negative results (model not found) are intentionally
|
||||
not cached to avoid blocking a freshly-added model behind the TTL.
|
||||
|
||||
Args:
|
||||
embedding_model: The embedding model string (e.g., "text-embedding-ada-002" or "azure/text-embedding-3-large")
|
||||
prisma_client: The Prisma client instance
|
||||
@ -314,6 +342,11 @@ async def _resolve_embedding_config(
|
||||
if not embedding_model:
|
||||
return None
|
||||
|
||||
cache = _get_embedding_config_cache()
|
||||
cached = cache.get_cache(embedding_model)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
# Import llm_router if not provided
|
||||
if llm_router is None:
|
||||
try:
|
||||
@ -330,6 +363,7 @@ async def _resolve_embedding_config(
|
||||
verbose_proxy_logger.debug(
|
||||
f"Resolved embedding config from router for model {embedding_model}"
|
||||
)
|
||||
cache.set_cache(embedding_model, router_config)
|
||||
return router_config
|
||||
|
||||
# Fall back to database
|
||||
@ -341,6 +375,7 @@ async def _resolve_embedding_config(
|
||||
verbose_proxy_logger.debug(
|
||||
f"Resolved embedding config from database for model {embedding_model}"
|
||||
)
|
||||
cache.set_cache(embedding_model, db_config)
|
||||
return db_config
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
@ -432,20 +467,17 @@ async def create_vector_store_in_db(
|
||||
if user_id is not None:
|
||||
data_to_create["user_id"] = user_id
|
||||
|
||||
# Handle litellm_params - always provide at least an empty dict
|
||||
# Handle litellm_params - always provide at least an empty dict.
|
||||
# The earlier behaviour resolved ``litellm_embedding_config`` from the
|
||||
# admin-configured router/DB model and persisted the cleartext result
|
||||
# (``api_key``, ``api_base``, ``api_version``) into this row. That
|
||||
# exposed every env-stored embedding-model credential on the
|
||||
# ``/vector_store/{new,info,update,list}`` responses. Keep the user's
|
||||
# raw ``litellm_embedding_model`` reference; resolution now happens in
|
||||
# ``_update_request_data_with_litellm_managed_vector_store_registry``
|
||||
# at request-handling time so the cleartext config exists only in
|
||||
# per-request memory and never reaches the database.
|
||||
if litellm_params:
|
||||
# Auto-resolve embedding config if embedding model is provided but config is not
|
||||
embedding_model = litellm_params.get("litellm_embedding_model")
|
||||
if embedding_model and not litellm_params.get("litellm_embedding_config"):
|
||||
resolved_config = await _resolve_embedding_config(
|
||||
embedding_model=embedding_model, prisma_client=prisma_client
|
||||
)
|
||||
if resolved_config:
|
||||
litellm_params["litellm_embedding_config"] = resolved_config
|
||||
verbose_proxy_logger.info(
|
||||
f"Auto-resolved embedding config for model {embedding_model}"
|
||||
)
|
||||
|
||||
litellm_params_dict = GenericLiteLLMParams(**litellm_params).model_dump(
|
||||
exclude_none=True
|
||||
)
|
||||
@ -531,10 +563,19 @@ async def new_vector_store(
|
||||
user_id=user_api_key_dict.user_id,
|
||||
)
|
||||
|
||||
# Apply the same litellm_params redaction the list / info / update
|
||||
# endpoints already use, so a caller-supplied credential or a
|
||||
# cleartext value persisted by an earlier proxy version doesn't
|
||||
# come back in the response.
|
||||
response_vs = LiteLLM_ManagedVectorStore(**new_vector_store)
|
||||
response_vs["litellm_params"] = _redact_sensitive_litellm_params(
|
||||
new_vector_store.get("litellm_params")
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Vector store {vector_store.get('vector_store_id')} created successfully",
|
||||
"vector_store": new_vector_store,
|
||||
"vector_store": response_vs,
|
||||
}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error creating vector store: {str(e)}")
|
||||
@ -865,24 +906,15 @@ async def update_vector_store(
|
||||
update_data["vector_store_metadata"]
|
||||
)
|
||||
|
||||
# Handle litellm_params if provided
|
||||
# Handle litellm_params if provided. As with the create path, the
|
||||
# embedding-config auto-resolve previously persisted cleartext
|
||||
# credentials into the row; resolution now happens at request-
|
||||
# handling time in
|
||||
# ``_update_request_data_with_litellm_managed_vector_store_registry``
|
||||
# so this row only ever stores the user-supplied
|
||||
# ``litellm_embedding_model`` reference.
|
||||
if "litellm_params" in update_data:
|
||||
_input_litellm_params: dict = update_data.get("litellm_params", {}) or {}
|
||||
|
||||
# Auto-resolve embedding config if embedding model is provided but config is not
|
||||
embedding_model = _input_litellm_params.get("litellm_embedding_model")
|
||||
if embedding_model and not _input_litellm_params.get(
|
||||
"litellm_embedding_config"
|
||||
):
|
||||
resolved_config = await _resolve_embedding_config(
|
||||
embedding_model=embedding_model, prisma_client=prisma_client
|
||||
)
|
||||
if resolved_config:
|
||||
_input_litellm_params["litellm_embedding_config"] = resolved_config
|
||||
verbose_proxy_logger.info(
|
||||
f"Auto-resolved embedding config for model {embedding_model}"
|
||||
)
|
||||
|
||||
litellm_params_dict = GenericLiteLLMParams(
|
||||
**_input_litellm_params
|
||||
).model_dump(exclude_none=True)
|
||||
|
||||
@ -19,9 +19,7 @@ def test_qdrant_semantic_cache_initialization(monkeypatch):
|
||||
patch(
|
||||
"litellm.llms.custom_httpx.http_handler._get_httpx_client"
|
||||
) as mock_sync_client,
|
||||
patch(
|
||||
"litellm.llms.custom_httpx.http_handler.get_async_httpx_client"
|
||||
) as mock_async_client,
|
||||
patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"),
|
||||
):
|
||||
|
||||
# Mock the collection exists check
|
||||
@ -31,6 +29,9 @@ def test_qdrant_semantic_cache_initialization(monkeypatch):
|
||||
|
||||
mock_sync_client_instance = MagicMock()
|
||||
mock_sync_client_instance.get.return_value = mock_response
|
||||
mock_index_response = MagicMock()
|
||||
mock_index_response.status_code = 200
|
||||
mock_sync_client_instance.put.return_value = mock_index_response
|
||||
mock_sync_client.return_value = mock_sync_client_instance
|
||||
|
||||
from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache
|
||||
@ -48,6 +49,17 @@ def test_qdrant_semantic_cache_initialization(monkeypatch):
|
||||
assert qdrant_cache.qdrant_api_base == "http://test.qdrant.local"
|
||||
assert qdrant_cache.qdrant_api_key == "test_key"
|
||||
assert qdrant_cache.similarity_threshold == 0.8
|
||||
mock_sync_client_instance.put.assert_called_once_with(
|
||||
url="http://test.qdrant.local/collections/test_collection/index",
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"api-key": "test_key",
|
||||
},
|
||||
json={
|
||||
"field_name": QdrantSemanticCache.CACHE_KEY_FIELD_NAME,
|
||||
"field_schema": "keyword",
|
||||
},
|
||||
)
|
||||
|
||||
# Test initialization with missing similarity_threshold
|
||||
with pytest.raises(Exception, match="similarity_threshold must be provided"):
|
||||
@ -67,9 +79,7 @@ def test_qdrant_semantic_cache_get_cache_hit():
|
||||
patch(
|
||||
"litellm.llms.custom_httpx.http_handler._get_httpx_client"
|
||||
) as mock_sync_client,
|
||||
patch(
|
||||
"litellm.llms.custom_httpx.http_handler.get_async_httpx_client"
|
||||
) as mock_async_client,
|
||||
patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"),
|
||||
):
|
||||
|
||||
# Mock the collection exists check
|
||||
@ -98,6 +108,7 @@ def test_qdrant_semantic_cache_get_cache_hit():
|
||||
"result": [
|
||||
{
|
||||
"payload": {
|
||||
QdrantSemanticCache.CACHE_KEY_FIELD_NAME: "test_key",
|
||||
"text": "What is the capital of France?", # Original prompt
|
||||
"response": '{"id": "test-123", "choices": [{"message": {"content": "Paris is the capital of France."}}]}',
|
||||
},
|
||||
@ -127,6 +138,177 @@ def test_qdrant_semantic_cache_get_cache_hit():
|
||||
|
||||
# Verify search was called
|
||||
qdrant_cache.sync_client.post.assert_called()
|
||||
assert qdrant_cache.sync_client.post.call_args.kwargs["json"]["filter"] == {
|
||||
"must": [
|
||||
{
|
||||
"key": QdrantSemanticCache.CACHE_KEY_FIELD_NAME,
|
||||
"match": {"value": "test_key"},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_qdrant_semantic_cache_rejects_unscoped_cache_hit():
    """
    Test QDRANT semantic cache treats a hit without a cache-key payload as a miss.

    The mocked search result mimics a legacy point: its payload holds only
    ``text`` and ``response`` and no LiteLLM cache key, so it cannot be tied
    to the caller's key. Although its score (0.9) is above the configured
    threshold (0.8), ``get_cache`` must return ``None`` and record a
    ``semantic-similarity`` of 0.0 in the metadata.
    """
    with (
        patch(
            "litellm.llms.custom_httpx.http_handler._get_httpx_client"
        ) as mock_sync_client,
        patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"),
    ):
        # Response handed back for GET requests issued while the cache initializes.
        collection_response = MagicMock()
        collection_response.status_code = 200
        collection_response.json.return_value = {"result": {"exists": True}}

        sync_client = MagicMock()
        sync_client.get.return_value = collection_response
        mock_sync_client.return_value = sync_client

        from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

        qdrant_cache = QdrantSemanticCache(
            collection_name="test_collection",
            qdrant_api_base="http://test.qdrant.local",
            qdrant_api_key="test_key",
            similarity_threshold=0.8,
        )

        # Search hit whose payload lacks the cache-key field.
        unscoped_point = {
            "payload": {
                "text": "What is the capital of France?",
                "response": '{"id": "test-123"}',
            },
            "score": 0.9,
        }
        search_response = MagicMock()
        search_response.status_code = 200
        search_response.json.return_value = {"result": [unscoped_point]}
        qdrant_cache.sync_client.post = MagicMock(return_value=search_response)

        metadata = {}
        with patch(
            "litellm.embedding", return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]}
        ):
            result = qdrant_cache.get_cache(
                key="test_key",
                messages=[{"content": "What is the capital of France?"}],
                metadata=metadata,
            )

        assert result is None
        assert metadata["semantic-similarity"] == 0.0
|
||||
|
||||
|
||||
def test_qdrant_semantic_cache_payload_index_failure_is_non_blocking():
    """A 400 response to the payload-index PUT must not raise.

    The cache instance is built with ``__new__`` so ``__init__`` is skipped;
    only the attributes assigned below are available on it.
    """
    from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

    failed_put = MagicMock()
    failed_put.status_code = 400
    failed_put.text = "bad index"

    qdrant_cache = QdrantSemanticCache.__new__(QdrantSemanticCache)
    qdrant_cache.qdrant_api_base = "http://test.qdrant.local"
    qdrant_cache.collection_name = "test_collection"
    qdrant_cache.headers = {"Content-Type": "application/json"}
    qdrant_cache.sync_client = MagicMock()
    qdrant_cache.sync_client.put.return_value = failed_put

    # Should return normally despite the error status.
    qdrant_cache._ensure_cache_key_payload_index()

    qdrant_cache.sync_client.put.assert_called_once()
|
||||
|
||||
|
||||
def test_qdrant_semantic_cache_payload_index_exception_is_non_blocking():
    """An exception raised by the payload-index PUT must not propagate.

    The cache instance is built with ``__new__`` so ``__init__`` is skipped;
    only the attributes assigned below are available on it.
    """
    from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

    raising_client = MagicMock()
    raising_client.put.side_effect = Exception("boom")

    qdrant_cache = QdrantSemanticCache.__new__(QdrantSemanticCache)
    qdrant_cache.qdrant_api_base = "http://test.qdrant.local"
    qdrant_cache.collection_name = "test_collection"
    qdrant_cache.headers = {"Content-Type": "application/json"}
    qdrant_cache.sync_client = raising_client

    # Should return normally even though ``put`` raises.
    qdrant_cache._ensure_cache_key_payload_index()

    qdrant_cache.sync_client.put.assert_called_once()
|
||||
|
||||
|
||||
def _mock_qdrant_get_cache_result(qdrant_result):
    """Build a ``QdrantSemanticCache`` whose sync search returns ``qdrant_result``.

    The instance is created with ``__new__`` (``__init__`` is skipped) and
    given the attributes set below. Its ``sync_client.post`` answers with a
    200 response whose JSON body is ``{"result": qdrant_result}``.

    Returns:
        A ``(cache_instance, QdrantSemanticCache)`` tuple.
    """
    from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

    search_response = MagicMock()
    search_response.status_code = 200
    search_response.json.return_value = {"result": qdrant_result}

    cache = QdrantSemanticCache.__new__(QdrantSemanticCache)
    cache.embedding_model = "text-embedding-ada-002"
    cache.qdrant_api_base = "http://test.qdrant.local"
    cache.collection_name = "test_collection"
    cache.headers = {
        "Content-Type": "application/json",
        "api-key": "test_key",
    }
    cache.similarity_threshold = 0.8
    cache.sync_client = MagicMock()
    cache.sync_client.post.return_value = search_response

    return cache, QdrantSemanticCache
|
||||
|
||||
|
||||
@pytest.mark.parametrize("qdrant_result", [None, []])
def test_qdrant_semantic_cache_get_cache_sets_metadata_on_empty_miss(qdrant_result):
    """A ``None`` or empty search result is a miss with similarity 0.0."""
    qdrant_cache, _ = _mock_qdrant_get_cache_result(qdrant_result)
    metadata = {}
    embedding_patch = patch(
        "litellm.embedding", return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]}
    )

    with embedding_patch:
        result = qdrant_cache.get_cache(
            key="test_key",
            messages=[{"content": "What is the capital of Spain?"}],
            metadata=metadata,
        )

    assert result is None
    assert metadata["semantic-similarity"] == 0.0
|
||||
|
||||
|
||||
def test_qdrant_semantic_cache_get_cache_sets_metadata_on_below_threshold_miss():
    """A correctly keyed hit scoring below the threshold is still a miss.

    The point carries the caller's cache key but scores 0.7 against the 0.8
    threshold configured by ``_mock_qdrant_get_cache_result``; the observed
    score must be written to ``metadata["semantic-similarity"]``.
    """
    from litellm.caching.qdrant_semantic_cache import QdrantSemanticCache

    below_threshold_point = {
        "payload": {
            QdrantSemanticCache.CACHE_KEY_FIELD_NAME: "test_key",
            "text": "What is the capital of Spain?",
            "response": '{"id": "test-456"}',
        },
        "score": 0.7,
    }
    qdrant_cache, _ = _mock_qdrant_get_cache_result([below_threshold_point])
    metadata = {}
    embedding_patch = patch(
        "litellm.embedding", return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]}
    )

    with embedding_patch:
        result = qdrant_cache.get_cache(
            key="test_key",
            messages=[{"content": "What is the capital of Spain?"}],
            metadata=metadata,
        )

    assert result is None
    assert metadata["semantic-similarity"] == 0.7
|
||||
|
||||
|
||||
def test_qdrant_semantic_cache_get_cache_miss():
|
||||
@ -138,9 +320,7 @@ def test_qdrant_semantic_cache_get_cache_miss():
|
||||
patch(
|
||||
"litellm.llms.custom_httpx.http_handler._get_httpx_client"
|
||||
) as mock_sync_client,
|
||||
patch(
|
||||
"litellm.llms.custom_httpx.http_handler.get_async_httpx_client"
|
||||
) as mock_async_client,
|
||||
patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"),
|
||||
):
|
||||
|
||||
# Mock the collection exists check
|
||||
@ -230,6 +410,7 @@ async def test_qdrant_semantic_cache_async_get_cache_hit():
|
||||
"result": [
|
||||
{
|
||||
"payload": {
|
||||
QdrantSemanticCache.CACHE_KEY_FIELD_NAME: "test_key",
|
||||
"text": "What is the capital of Spain?", # Original prompt
|
||||
"response": '{"id": "test-456", "choices": [{"message": {"content": "Madrid is the capital of Spain."}}]}',
|
||||
},
|
||||
@ -262,6 +443,16 @@ async def test_qdrant_semantic_cache_async_get_cache_hit():
|
||||
|
||||
# Verify async search was called
|
||||
qdrant_cache.async_client.post.assert_called()
|
||||
assert qdrant_cache.async_client.post.call_args.kwargs["json"][
|
||||
"filter"
|
||||
] == {
|
||||
"must": [
|
||||
{
|
||||
"key": QdrantSemanticCache.CACHE_KEY_FIELD_NAME,
|
||||
"match": {"value": "test_key"},
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -336,9 +527,7 @@ def test_qdrant_semantic_cache_set_cache():
|
||||
patch(
|
||||
"litellm.llms.custom_httpx.http_handler._get_httpx_client"
|
||||
) as mock_sync_client,
|
||||
patch(
|
||||
"litellm.llms.custom_httpx.http_handler.get_async_httpx_client"
|
||||
) as mock_async_client,
|
||||
patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"),
|
||||
):
|
||||
|
||||
# Mock the collection exists check
|
||||
@ -384,6 +573,12 @@ def test_qdrant_semantic_cache_set_cache():
|
||||
|
||||
# Verify upsert was called
|
||||
qdrant_cache.sync_client.put.assert_called()
|
||||
upsert_payload = qdrant_cache.sync_client.put.call_args.kwargs["json"][
|
||||
"points"
|
||||
][0]["payload"]
|
||||
assert (
|
||||
upsert_payload[QdrantSemanticCache.CACHE_KEY_FIELD_NAME] == "test_key"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -450,6 +645,12 @@ async def test_qdrant_semantic_cache_async_set_cache():
|
||||
|
||||
# Verify async upsert was called
|
||||
qdrant_cache.async_client.put.assert_called()
|
||||
upsert_payload = qdrant_cache.async_client.put.call_args.kwargs["json"][
|
||||
"points"
|
||||
][0]["payload"]
|
||||
assert (
|
||||
upsert_payload[QdrantSemanticCache.CACHE_KEY_FIELD_NAME] == "test_key"
|
||||
)
|
||||
|
||||
|
||||
def test_qdrant_semantic_cache_custom_vector_size():
|
||||
@ -462,9 +663,7 @@ def test_qdrant_semantic_cache_custom_vector_size():
|
||||
patch(
|
||||
"litellm.llms.custom_httpx.http_handler._get_httpx_client"
|
||||
) as mock_sync_client,
|
||||
patch(
|
||||
"litellm.llms.custom_httpx.http_handler.get_async_httpx_client"
|
||||
) as mock_async_client,
|
||||
patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"),
|
||||
):
|
||||
|
||||
# Mock the collection does NOT exist (so it will be created)
|
||||
@ -505,9 +704,13 @@ def test_qdrant_semantic_cache_custom_vector_size():
|
||||
assert qdrant_cache.vector_size == 768
|
||||
|
||||
# Verify the PUT call to create the collection used vector_size=768
|
||||
put_call = mock_sync_client_instance.put.call_args
|
||||
assert put_call is not None
|
||||
create_payload = put_call.kwargs.get("json") or put_call[1].get("json")
|
||||
put_call = next(
|
||||
call
|
||||
for call in mock_sync_client_instance.put.call_args_list
|
||||
if call.kwargs["url"]
|
||||
== "http://test.qdrant.local/collections/test_collection_768"
|
||||
)
|
||||
create_payload = put_call.kwargs["json"]
|
||||
assert create_payload["vectors"]["size"] == 768
|
||||
assert create_payload["vectors"]["distance"] == "Cosine"
|
||||
|
||||
@ -521,9 +724,7 @@ def test_qdrant_semantic_cache_default_vector_size():
|
||||
patch(
|
||||
"litellm.llms.custom_httpx.http_handler._get_httpx_client"
|
||||
) as mock_sync_client,
|
||||
patch(
|
||||
"litellm.llms.custom_httpx.http_handler.get_async_httpx_client"
|
||||
) as mock_async_client,
|
||||
patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"),
|
||||
):
|
||||
|
||||
# Mock the collection exists check
|
||||
@ -559,9 +760,7 @@ def test_qdrant_semantic_cache_large_vector_size():
|
||||
patch(
|
||||
"litellm.llms.custom_httpx.http_handler._get_httpx_client"
|
||||
) as mock_sync_client,
|
||||
patch(
|
||||
"litellm.llms.custom_httpx.http_handler.get_async_httpx_client"
|
||||
) as mock_async_client,
|
||||
patch("litellm.llms.custom_httpx.http_handler.get_async_httpx_client"),
|
||||
):
|
||||
|
||||
# Mock the collection does NOT exist (so it will be created)
|
||||
@ -599,6 +798,11 @@ def test_qdrant_semantic_cache_large_vector_size():
|
||||
assert qdrant_cache.vector_size == 4096
|
||||
|
||||
# Verify the collection was created with 4096
|
||||
put_call = mock_sync_client_instance.put.call_args
|
||||
create_payload = put_call.kwargs.get("json") or put_call[1].get("json")
|
||||
put_call = next(
|
||||
call
|
||||
for call in mock_sync_client_instance.put.call_args_list
|
||||
if call.kwargs["url"]
|
||||
== "http://test.qdrant.local/collections/test_collection_4096"
|
||||
)
|
||||
create_payload = put_call.kwargs["json"]
|
||||
assert create_payload["vectors"]["size"] == 4096
|
||||
|
||||
@ -72,24 +72,301 @@ def test_redis_semantic_cache_get_cache(monkeypatch):
|
||||
"prompt": "What is the capital of France?",
|
||||
"response": '{"content": "Paris is the capital of France."}',
|
||||
"vector_distance": 0.1, # Distance of 0.1 means similarity of 0.9
|
||||
RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key",
|
||||
}
|
||||
]
|
||||
redis_semantic_cache.llmcache.check = MagicMock(return_value=mock_result)
|
||||
|
||||
# Mock the embedding function
|
||||
with patch(
|
||||
"litellm.embedding", return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]}
|
||||
with (
|
||||
patch(
|
||||
"litellm.embedding",
|
||||
return_value={"data": [{"embedding": [0.1, 0.2, 0.3]}]},
|
||||
),
|
||||
patch.object(
|
||||
redis_semantic_cache,
|
||||
"_get_cache_key_filter_expression",
|
||||
return_value="cache-key-filter",
|
||||
),
|
||||
):
|
||||
# Test get_cache with a message
|
||||
metadata = {}
|
||||
result = redis_semantic_cache.get_cache(
|
||||
key="test_key", messages=[{"content": "What is the capital of France?"}]
|
||||
key="test_key",
|
||||
messages=[{"content": "What is the capital of France?"}],
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# Verify result is properly parsed
|
||||
assert result == {"content": "Paris is the capital of France."}
|
||||
assert metadata["semantic-similarity"] == pytest.approx(0.9)
|
||||
|
||||
# Verify llmcache.check was called
|
||||
redis_semantic_cache.llmcache.check.assert_called_once()
|
||||
redis_semantic_cache.llmcache.check.assert_called_once_with(
|
||||
prompt="What is the capital of France?",
|
||||
filter_expression="cache-key-filter",
|
||||
)
|
||||
|
||||
|
||||
def test_redis_semantic_cache_rejects_unscoped_cache_hit(monkeypatch):
    """A semantic hit that carries no cache-key field is reported as a miss.

    ``llmcache.check`` returns an entry with only ``prompt``, ``response`` and
    ``vector_distance``; ``get_cache`` must return ``None`` and set
    ``semantic-similarity`` to 0.0.
    """
    semantic_cache_mock = MagicMock()
    custom_vectorizer_mock = MagicMock()
    # Stand-ins for the redisvl modules, installed into ``sys.modules``.
    fake_redisvl_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
        "redisvl.utils.vectorize": MagicMock(
            CustomTextVectorizer=custom_vectorizer_mock
        ),
    }

    with patch.dict("sys.modules", fake_redisvl_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        monkeypatch.setenv("REDIS_HOST", "localhost")
        monkeypatch.setenv("REDIS_PORT", "6379")
        monkeypatch.setenv("REDIS_PASSWORD", "test_password")

        redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8)

        unscoped_hit = {
            "prompt": "What is the capital of France?",
            "response": '{"content": "Paris"}',
            "vector_distance": 0.1,
        }
        redis_semantic_cache.llmcache.check = MagicMock(return_value=[unscoped_hit])

        metadata = {}
        with patch.object(
            redis_semantic_cache,
            "_get_cache_key_filter_expression",
            return_value="cache-key-filter",
        ):
            result = redis_semantic_cache.get_cache(
                key="test_key",
                messages=[{"content": "What is the capital of France?"}],
                metadata=metadata,
            )

        assert result is None
        assert metadata["semantic-similarity"] == 0.0
|
||||
|
||||
|
||||
def test_redis_semantic_cache_set_cache_stores_cache_key_filter(monkeypatch):
    """``set_cache`` must pass the cache key to ``llmcache.store`` as a filter.

    The store call is expected to receive the prompt text, the stringified
    value, ``filters={CACHE_KEY_FIELD_NAME: key}`` and the given TTL.
    """
    semantic_cache_mock = MagicMock()
    custom_vectorizer_mock = MagicMock()
    # Stand-ins for the redisvl modules, installed into ``sys.modules``.
    fake_redisvl_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
        "redisvl.utils.vectorize": MagicMock(
            CustomTextVectorizer=custom_vectorizer_mock
        ),
    }

    with patch.dict("sys.modules", fake_redisvl_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        monkeypatch.setenv("REDIS_HOST", "localhost")
        monkeypatch.setenv("REDIS_PORT", "6379")
        monkeypatch.setenv("REDIS_PASSWORD", "test_password")

        redis_semantic_cache = RedisSemanticCache(similarity_threshold=0.8)
        store_mock = MagicMock()
        redis_semantic_cache.llmcache.store = store_mock

        redis_semantic_cache.set_cache(
            key="test_key",
            value={"content": "Paris"},
            messages=[{"content": "What is the capital of France?"}],
            ttl=60,
        )

        expected_filters = {RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key"}
        redis_semantic_cache.llmcache.store.assert_called_once_with(
            "What is the capital of France?",
            "{'content': 'Paris'}",
            filters=expected_filters,
            ttl=60,
        )
|
||||
|
||||
|
||||
def test_redis_semantic_cache_uses_isolated_index_for_old_schema(monkeypatch):
    """A schema-mismatch error on the requested index falls back to ``<name>_isolated``.

    The first ``SemanticCache`` construction raises a schema-difference
    ``ValueError``; the second succeeds and must target
    ``existing_index_isolated`` with the cache-key filterable field.
    """
    fallback_cache_mock = MagicMock()
    constructor_outcomes = [
        ValueError("stored index schema differs from requested fields"),
        fallback_cache_mock,
    ]
    semantic_cache_mock = MagicMock(side_effect=constructor_outcomes)
    custom_vectorizer_mock = MagicMock()
    # Stand-ins for the redisvl modules, installed into ``sys.modules``.
    fake_redisvl_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
        "redisvl.utils.vectorize": MagicMock(
            CustomTextVectorizer=custom_vectorizer_mock
        ),
    }

    with patch.dict("sys.modules", fake_redisvl_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        monkeypatch.setenv("REDIS_HOST", "localhost")
        monkeypatch.setenv("REDIS_PORT", "6379")
        monkeypatch.setenv("REDIS_PASSWORD", "test_password")

        redis_semantic_cache = RedisSemanticCache(
            similarity_threshold=0.8,
            index_name="existing_index",
        )

        assert redis_semantic_cache.llmcache is fallback_cache_mock

        first_call, second_call = semantic_cache_mock.call_args_list[:2]
        assert first_call.kwargs["name"] == "existing_index"
        assert second_call.kwargs["name"] == "existing_index_isolated"
        assert second_call.kwargs["filterable_fields"] == [
            RedisSemanticCache._cache_key_filterable_field()
        ]
|
||||
|
||||
|
||||
def test_redis_semantic_cache_overwrites_stale_isolated_index(monkeypatch):
    """Two consecutive schema-mismatch errors lead to an overwriting third attempt.

    The first two ``SemanticCache`` constructions raise a schema-mismatch
    ``ValueError``. The third must target ``existing_index_isolated`` with
    ``overwrite=True`` and the cache-key filterable field.
    """
    fallback_cache_mock = MagicMock()
    constructor_outcomes = [
        ValueError("Existing index schema does not match"),
        ValueError("Existing index schema does not match"),
        fallback_cache_mock,
    ]
    semantic_cache_mock = MagicMock(side_effect=constructor_outcomes)
    custom_vectorizer_mock = MagicMock()
    # Stand-ins for the redisvl modules, installed into ``sys.modules``.
    fake_redisvl_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
        "redisvl.utils.vectorize": MagicMock(
            CustomTextVectorizer=custom_vectorizer_mock
        ),
    }

    with patch.dict("sys.modules", fake_redisvl_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        monkeypatch.setenv("REDIS_HOST", "localhost")
        monkeypatch.setenv("REDIS_PORT", "6379")
        monkeypatch.setenv("REDIS_PASSWORD", "test_password")

        redis_semantic_cache = RedisSemanticCache(
            similarity_threshold=0.8,
            index_name="existing_index",
        )

        assert redis_semantic_cache.llmcache is fallback_cache_mock

        third_call = semantic_cache_mock.call_args_list[2]
        assert third_call.kwargs["name"] == "existing_index_isolated"
        assert third_call.kwargs["overwrite"] is True
        assert third_call.kwargs["filterable_fields"] == [
            RedisSemanticCache._cache_key_filterable_field()
        ]
|
||||
|
||||
|
||||
def test_redis_semantic_cache_reraises_unexpected_isolated_index_error(monkeypatch):
    """A non-schema error on the second construction attempt propagates.

    The first ``SemanticCache`` construction raises a schema-mismatch
    ``ValueError``; the second raises ``ValueError("connection failed")``,
    which ``RedisSemanticCache(...)`` is expected to re-raise.
    """
    constructor_outcomes = [
        ValueError("Existing index schema does not match"),
        ValueError("connection failed"),
    ]
    semantic_cache_mock = MagicMock(side_effect=constructor_outcomes)
    custom_vectorizer_mock = MagicMock()
    # Stand-ins for the redisvl modules, installed into ``sys.modules``.
    fake_redisvl_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_mock),
        "redisvl.utils.vectorize": MagicMock(
            CustomTextVectorizer=custom_vectorizer_mock
        ),
    }

    with patch.dict("sys.modules", fake_redisvl_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        monkeypatch.setenv("REDIS_HOST", "localhost")
        monkeypatch.setenv("REDIS_PORT", "6379")
        monkeypatch.setenv("REDIS_PASSWORD", "test_password")

        with pytest.raises(ValueError, match="connection failed"):
            RedisSemanticCache(
                similarity_threshold=0.8,
                index_name="existing_index",
            )
|
||||
|
||||
|
||||
def test_redis_semantic_cache_reraises_unexpected_index_error():
    """``_init_semantic_cache`` must not swallow an unrelated ``ValueError``.

    The instance is created with ``__new__`` so ``__init__`` never runs; only
    ``distance_threshold`` is set before calling the helper directly.
    """
    from litellm.caching.redis_semantic_cache import RedisSemanticCache

    cache = RedisSemanticCache.__new__(RedisSemanticCache)
    cache.distance_threshold = 0.2
    failing_cache_cls = MagicMock(side_effect=ValueError("connection failed"))

    with pytest.raises(ValueError, match="connection failed"):
        cache._init_semantic_cache(
            semantic_cache_cls=failing_cache_cls,
            index_name="existing_index",
            redis_url="redis://localhost:6379",
            cache_vectorizer=MagicMock(),
        )
|
||||
|
||||
|
||||
def test_redis_semantic_cache_matches_bytes_cache_key():
    """A cache hit whose stored key is ``bytes`` matches the same ``str`` key."""
    from litellm.caching.redis_semantic_cache import RedisSemanticCache

    # Bypass __init__: the method under test is called on a bare instance.
    cache = RedisSemanticCache.__new__(RedisSemanticCache)
    hit_with_bytes_key = {RedisSemanticCache.CACHE_KEY_FIELD_NAME: b"test_key"}

    assert cache._cache_hit_matches_key(
        cache_hit=hit_with_bytes_key,
        key="test_key",
    )
|
||||
|
||||
|
||||
def test_redis_semantic_cache_rejects_pre_isolation_unscoped_hit():
    """Hits lacking the cache-key field are treated as misses.

    Entries written before isolation carry no cache-key field, so they cannot
    be safely attributed to the caller's scope.
    """
    from litellm.caching.redis_semantic_cache import RedisSemanticCache

    cache = RedisSemanticCache.__new__(RedisSemanticCache)

    # Deliberately omits RedisSemanticCache.CACHE_KEY_FIELD_NAME.
    unscoped_hit = {
        "prompt": "What is the capital of France?",
        "response": '{"content": "Paris"}',
        "vector_distance": 0.1,
    }

    matched = cache._cache_hit_matches_key(
        cache_hit=unscoped_hit,
        key="test_key",
    )
    assert not matched
|
||||
|
||||
|
||||
def test_redis_semantic_cache_builds_filter_expression(monkeypatch):
    """The filter expression is ``Tag(<cache-key field>) == <key>``.

    ``redisvl.query.filter.Tag`` is replaced by a stub whose ``==`` returns a
    ``(field_name, value)`` tuple, which makes the built expression directly
    comparable.
    """

    class FakeTag:
        """Stand-in for redisvl's ``Tag`` that records the comparison operands."""

        def __init__(self, field_name):
            self.field_name = field_name

        def __eq__(self, value):
            return (self.field_name, value)

    stubbed_modules = {"redisvl.query.filter": MagicMock(Tag=FakeTag)}

    with patch.dict("sys.modules", stubbed_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        cache = RedisSemanticCache.__new__(RedisSemanticCache)
        expected_expression = (
            RedisSemanticCache.CACHE_KEY_FIELD_NAME,
            "test_key",
        )

        assert cache._get_cache_key_filter_expression("test_key") == expected_expression
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -123,6 +400,7 @@ async def test_redis_semantic_cache_async_get_cache(monkeypatch):
|
||||
"prompt": "What is the capital of France?",
|
||||
"response": '{"content": "Paris is the capital of France."}',
|
||||
"vector_distance": 0.1, # Distance of 0.1 means similarity of 0.9
|
||||
RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key",
|
||||
}
|
||||
]
|
||||
|
||||
@ -131,16 +409,117 @@ async def test_redis_semantic_cache_async_get_cache(monkeypatch):
|
||||
return_value=[0.1, 0.2, 0.3]
|
||||
)
|
||||
|
||||
# Test async_get_cache with a message
|
||||
result = await redis_semantic_cache.async_get_cache(
|
||||
key="test_key",
|
||||
messages=[{"content": "What is the capital of France?"}],
|
||||
metadata={},
|
||||
)
|
||||
with patch.object(
|
||||
redis_semantic_cache,
|
||||
"_get_cache_key_filter_expression",
|
||||
return_value="cache-key-filter",
|
||||
):
|
||||
# Test async_get_cache with a message
|
||||
result = await redis_semantic_cache.async_get_cache(
|
||||
key="test_key",
|
||||
messages=[{"content": "What is the capital of France?"}],
|
||||
metadata={},
|
||||
)
|
||||
|
||||
# Verify result is properly parsed
|
||||
assert result == {"content": "Paris is the capital of France."}
|
||||
|
||||
# Verify methods were called
|
||||
redis_semantic_cache._get_async_embedding.assert_called_once()
|
||||
redis_semantic_cache.llmcache.acheck.assert_called_once()
|
||||
redis_semantic_cache.llmcache.acheck.assert_called_once_with(
|
||||
prompt="What is the capital of France?",
|
||||
vector=[0.1, 0.2, 0.3],
|
||||
filter_expression="cache-key-filter",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_redis_semantic_cache_async_get_cache_rejects_unscoped_hit(monkeypatch):
    """``async_get_cache`` returns ``None`` when the hit has no cache-key field."""
    semantic_cache_cls = MagicMock()
    vectorizer_cls = MagicMock()
    stubbed_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_cls),
        "redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=vectorizer_cls),
    }

    with patch.dict("sys.modules", stubbed_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        monkeypatch.setenv("REDIS_HOST", "localhost")
        monkeypatch.setenv("REDIS_PORT", "6379")
        monkeypatch.setenv("REDIS_PASSWORD", "test_password")

        cache = RedisSemanticCache(similarity_threshold=0.8)

        # The backend reports a close match that carries no cache-key field.
        unscoped_hit = {
            "prompt": "What is the capital of France?",
            "response": '{"content": "Paris"}',
            "vector_distance": 0.1,
        }
        cache.llmcache.acheck = AsyncMock(return_value=[unscoped_hit])
        cache._get_async_embedding = AsyncMock(return_value=[0.1, 0.2, 0.3])

        with patch.object(
            cache,
            "_get_cache_key_filter_expression",
            return_value="cache-key-filter",
        ):
            result = await cache.async_get_cache(
                key="test_key",
                messages=[{"content": "What is the capital of France?"}],
                metadata={},
            )

        assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_redis_semantic_cache_async_set_cache_stores_cache_key_filter(
    monkeypatch,
):
    """``async_set_cache`` forwards the cache key to ``astore`` via ``filters``."""
    semantic_cache_cls = MagicMock()
    vectorizer_cls = MagicMock()
    stubbed_modules = {
        "redisvl.extensions.llmcache": MagicMock(SemanticCache=semantic_cache_cls),
        "redisvl.utils.vectorize": MagicMock(CustomTextVectorizer=vectorizer_cls),
    }

    with patch.dict("sys.modules", stubbed_modules):
        from litellm.caching.redis_semantic_cache import RedisSemanticCache

        monkeypatch.setenv("REDIS_HOST", "localhost")
        monkeypatch.setenv("REDIS_PORT", "6379")
        monkeypatch.setenv("REDIS_PASSWORD", "test_password")

        cache = RedisSemanticCache(similarity_threshold=0.8)
        cache.llmcache.astore = AsyncMock()
        cache._get_async_embedding = AsyncMock(return_value=[0.1, 0.2, 0.3])

        await cache.async_set_cache(
            key="test_key",
            value={"content": "Paris"},
            messages=[{"content": "What is the capital of France?"}],
            ttl=60,
        )

        expected_filters = {RedisSemanticCache.CACHE_KEY_FIELD_NAME: "test_key"}
        cache.llmcache.astore.assert_called_once_with(
            "What is the capital of France?",
            "{'content': 'Paris'}",
            vector=[0.1, 0.2, 0.3],
            filters=expected_filters,
            ttl=60,
        )
|
||||
|
||||
@ -1292,3 +1292,219 @@ class TestIsRequestBodySafeNestedConfig:
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
# ── observability-callback ban (root + metadata) ───────────────────────────
|
||||
|
||||
|
||||
class TestObservabilityCallbackBans:
    """Observability credentials, hosts and project identifiers are rejected.

    The rejection must hold wherever the field arrives: at the request body
    root, inside ``metadata`` / ``litellm_metadata``, or inside a
    JSON-string-encoded metadata blob (multipart/``extra_body`` path).

    The ban list is derived from
    ``litellm.litellm_core_utils.initialize_dynamic_callback_params._supported_callback_params``
    minus a small ``_SAFE_CLIENT_CALLBACK_PARAMS`` allow-list, plus
    ``_EXTRA_BANNED_OBSERVABILITY_PARAMS`` for fields integrations read but
    that are not yet in the canonical allow-list. Deriving it keeps the proxy
    in sync as new integrations are added.
    """

    @pytest.fixture(autouse=True)
    def _disable_url_validation(self, monkeypatch):
        """Force ``litellm.user_url_validation`` to ``False`` for every test here."""
        import litellm

        monkeypatch.setattr(litellm, "user_url_validation", False, raising=False)

    @pytest.mark.parametrize(
        "field",
        [
            "langfuse_public_key",
            "langfuse_secret",
            "langfuse_secret_key",
            "langsmith_api_key",
            "langsmith_project",
            "langsmith_tenant_id",
            "arize_api_key",
            "arize_space_key",
            "arize_space_id",
            "posthog_api_key",
            "posthog_api_url",
            "braintrust_api_key",
            "braintrust_project",
            "phoenix_project_name",
            "wandb_api_key",
            "weave_project_id",
            "gcs_bucket_name",
            "gcs_path_service_account",
            "humanloop_api_key",
            "lunary_public_key",
        ],
    )
    def test_observability_field_in_request_body_root_is_rejected(self, field):
        """A banned field at the body root raises and is named in the error."""
        body = {"model": "gpt-4", field: "attacker-value"}

        with pytest.raises(ValueError) as excinfo:
            is_request_body_safe(
                request_body=body,
                general_settings={},
                llm_router=None,
                model="gpt-4",
            )

        assert field in str(excinfo.value)

    @pytest.mark.parametrize(
        "metadata_key",
        ["metadata", "litellm_metadata"],
    )
    @pytest.mark.parametrize(
        "field",
        [
            "langfuse_host",
            "langfuse_secret_key",
            "langsmith_api_key",
            "posthog_api_url",
            "braintrust_project",
            "phoenix_project_name",
        ],
    )
    def test_observability_field_in_metadata_dict_is_rejected(
        self, metadata_key, field
    ):
        """A banned field nested in a metadata dict hits the same gate.

        A value smuggled inside ``metadata`` or ``litellm_metadata`` is just
        as dangerous as the same field at the body root.
        """
        body = {
            "model": "gpt-4",
            metadata_key: {field: "attacker-value"},
        }

        with pytest.raises(ValueError) as excinfo:
            is_request_body_safe(
                request_body=body,
                general_settings={},
                llm_router=None,
                model="gpt-4",
            )

        assert field in str(excinfo.value)

    @pytest.mark.parametrize(
        "metadata_key",
        ["metadata", "litellm_metadata"],
    )
    def test_observability_field_in_json_string_metadata_is_rejected(
        self, metadata_key
    ):
        """A banned field inside JSON-string metadata is still rejected.

        Multipart/form-data and ``extra_body`` callers send metadata as a
        JSON-encoded string. It has to be parsed before the banned-params
        check so the string form cannot slip past the ``isinstance(dict)``
        guard.
        """
        import json

        encoded_metadata = json.dumps({"langfuse_host": "https://attacker.example"})
        body = {
            "model": "gpt-4",
            metadata_key: encoded_metadata,
        }

        with pytest.raises(ValueError) as excinfo:
            is_request_body_safe(
                request_body=body,
                general_settings={},
                llm_router=None,
                model="gpt-4",
            )

        assert "langfuse_host" in str(excinfo.value)

    def test_admin_opt_in_allows_metadata_credential_passthrough(self):
        """``allow_client_side_credentials`` also unlocks the metadata path.

        Operators running BYO observability with clientside credentials flip
        a single flag and both the root and metadata paths work.
        """
        body = {
            "model": "gpt-4",
            "metadata": {
                "langfuse_host": "https://my-langfuse.example",
                "langfuse_public_key": "pk-mine",
                "langfuse_secret_key": "sk-mine",
            },
        }

        verdict = is_request_body_safe(
            request_body=body,
            general_settings={"allow_client_side_credentials": True},
            llm_router=None,
            model="gpt-4",
        )

        assert verdict is True

    def test_safe_per_request_observability_metadata_is_allowed(self):
        """Informational per-request fields pass without the opt-in flag.

        Fields such as sampling rate or prompt version describe the request
        being logged; they do not choose the destination or credentials.
        """
        body = {
            "model": "gpt-4",
            "metadata": {
                "langfuse_prompt_version": "v2",
                "langsmith_sampling_rate": 0.1,
            },
        }

        verdict = is_request_body_safe(
            request_body=body,
            general_settings={},
            llm_router=None,
            model="gpt-4",
        )

        assert verdict is True
|
||||
|
||||
|
||||
def test_model_level_allow_does_not_skip_subsequent_banned_params(monkeypatch):
    """One model-level-allowed field must not exempt later banned fields.

    Regression guard: ``_check_banned_params`` previously ``return``-ed as
    soon as a deployment's ``configurable_clientside_auth_params`` permitted
    one banned field, so later banned fields in the same body were never
    checked. A body pairing a model-level-allowed ``api_base`` with an
    observability credential such as ``langfuse_host`` must still be rejected
    on the second field.
    """
    from litellm.proxy.auth import auth_utils

    def _allow_only_api_base(model, param, request_body_value, llm_router):
        # Simulates a deployment that permits client-side ``api_base`` only.
        return param == "api_base"

    monkeypatch.setattr(
        auth_utils,
        "_allow_model_level_clientside_configurable_parameters",
        _allow_only_api_base,
    )

    body = {
        "model": "gpt-4",
        "api_base": "https://allowed-by-deployment.example",
        "langfuse_host": "https://attacker.example",
    }

    with pytest.raises(ValueError) as excinfo:
        is_request_body_safe(
            request_body=body,
            general_settings={},
            llm_router=None,
            model="gpt-4",
        )

    assert "langfuse_host" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_observability_ban_covers_canonical_supported_callback_params():
    """Every canonical callback param is either banned or explicitly safe-listed.

    New integrations added to ``_supported_callback_params`` are banned by
    default (the safe failure mode); marking one as safe is an explicit
    decision recorded in ``_SAFE_CLIENT_CALLBACK_PARAMS``.
    """
    from litellm.litellm_core_utils.initialize_dynamic_callback_params import (
        _supported_callback_params,
    )
    from litellm.proxy.auth.auth_utils import (
        _BANNED_REQUEST_BODY_PARAMS,
        _SAFE_CLIENT_CALLBACK_PARAMS,
    )

    banned = set(_BANNED_REQUEST_BODY_PARAMS)

    for param in _supported_callback_params:
        is_covered = param in banned or param in _SAFE_CLIENT_CALLBACK_PARAMS
        assert is_covered, (
            f"{param} is in _supported_callback_params but neither banned nor "
            f"safe-listed. Add it to _SAFE_CLIENT_CALLBACK_PARAMS if it is an "
            f"informational per-request field; otherwise the derivation will "
            f"ban it automatically."
        )
|
||||
|
||||
@ -596,6 +596,59 @@ async def test_add_litellm_data_to_request_strips_user_control_fields():
|
||||
assert "pillar_response_headers" not in snapshot_body["metadata"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "control_field",
    ["callbacks", "service_callback", "logger_fn", "litellm_disabled_callbacks"],
)
async def test_add_litellm_data_to_request_strips_callback_control_fields(
    control_field,
):
    """Callback-control fields in the request body are stripped by the proxy.

    ``callbacks`` / ``service_callback`` / ``logger_fn`` get appended to the
    worker-wide ``litellm.{input,success,failure,_async_*,service}_callback``
    lists and ``litellm.user_logger_fn`` from inside ``function_setup`` — one
    request poisons every subsequent caller in that worker.
    ``litellm_disabled_callbacks`` is the inverse: a request-body value
    silently disables admin-configured audit/observability for the call.
    None has a documented per-request use, so all four are stripped at the
    proxy boundary alongside the existing internal-only fields.
    """
    request_mock = MagicMock(spec=Request)
    # Install the URL mock first and configure it afterwards. Assigning
    # ``url.path`` before replacing ``url`` would discard the path.
    request_mock.url = MagicMock()
    request_mock.url.path = "/v1/chat/completions"
    request_mock.url.__str__.return_value = "http://localhost/v1/chat/completions"
    request_mock.method = "POST"
    request_mock.query_params = {}
    request_mock.headers = {"Content-Type": "application/json"}
    request_mock.client = MagicMock()
    request_mock.client.host = "127.0.0.1"

    # Three of the fields take a list of callback names; ``logger_fn`` takes
    # a single dotted-path string.
    list_valued_fields = (
        "callbacks",
        "service_callback",
        "litellm_disabled_callbacks",
    )
    sample_value = (
        ["langfuse"] if control_field in list_valued_fields else "module.func"
    )

    updated = await add_litellm_data_to_request(
        data={
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "hi"}],
            control_field: sample_value,
        },
        request=request_mock,
        user_api_key_dict=UserAPIKeyAuth(api_key="hashed-key"),
        proxy_config=MagicMock(),
        general_settings={},
        version="test-version",
    )

    assert control_field not in updated
    # The post-strip body snapshot used by audit/spend logging must also
    # not retain the attacker-injected control field.
    snapshot_body = updated["proxy_server_request"]["body"]
    assert control_field not in snapshot_body
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_litellm_data_to_request_allows_client_mock_response_with_admin_opt_in():
|
||||
request_mock = MagicMock(spec=Request)
|
||||
|
||||
@ -36,6 +36,33 @@ from litellm.proxy.vector_store_endpoints.utils import (
|
||||
from litellm.types.utils import LlmProviders
|
||||
|
||||
|
||||
def _serialize_litellm_params(litellm_params):
|
||||
"""Serialize ``litellm_params`` to a string for substring assertions.
|
||||
|
||||
The redact helper preserves the persisted shape — string in, string
|
||||
out; dict in, dict out — so callers that just want to assert "this
|
||||
secret never appears" need a single text representation either way.
|
||||
"""
|
||||
import json
|
||||
|
||||
if isinstance(litellm_params, str):
|
||||
return litellm_params
|
||||
return json.dumps(litellm_params or {})
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _reset_embedding_config_cache():
    """Clear the embedding-config cache before and after each test.

    The use-time embedding-config resolver caches results in process memory
    across calls. Resetting it makes each test exercise the router/DB path
    under test instead of reading a value cached by an earlier test.
    """
    from litellm.proxy.vector_store_endpoints import (
        management_endpoints as endpoints_module,
    )

    def _clear():
        endpoints_module._embedding_config_cache = None

    _clear()
    yield
    _clear()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_avector_store_search_passes_correct_args():
|
||||
"""
|
||||
@ -170,6 +197,93 @@ async def test_update_request_data_with_litellm_managed_vector_store_registry():
|
||||
assert result == original_data
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_request_data_resolves_embedding_config_at_use_time():
    """A row holding only the model reference gets its config resolved per request.

    When the persisted vector store row carries only a
    ``litellm_embedding_model`` reference, the request-handling layer must
    resolve the embedding config so the downstream embed call still has
    ``api_key`` / ``api_base`` / ``api_version``. The resolved config lives in
    the per-request data dict only and is never persisted.
    """
    stored_row: LiteLLM_ManagedVectorStore = {
        "vector_store_id": "test_store",
        "custom_llm_provider": "azure_ai",
        "litellm_params": {
            "litellm_embedding_model": "azure/text-embedding-3-large",
            # Note: no litellm_embedding_config persisted
        },
    }

    registry = MagicMock()
    registry.get_litellm_managed_vector_store_from_registry.return_value = stored_row

    resolved = {
        "api_key": "use-time-resolved-key",
        "api_base": "https://my-azure.example",
        "api_version": "2024-09-01",
    }
    resolver_stub = AsyncMock(return_value=resolved)

    with patch.object(litellm, "vector_store_registry", registry):
        with patch(
            "litellm.proxy.vector_store_endpoints.endpoints._resolve_embedding_config",
            new=resolver_stub,
        ):
            result = (
                await _update_request_data_with_litellm_managed_vector_store_registry(
                    data={}, vector_store_id="test_store"
                )
            )

    assert result["litellm_embedding_model"] == "azure/text-embedding-3-large"
    assert result["litellm_embedding_config"] == resolved
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_request_data_passes_through_legacy_embedding_config():
    """Rows that already carry a resolved config are used without re-resolving.

    A vector store row created by an older proxy version may already hold a
    fully-resolved ``litellm_embedding_config`` in its persisted
    ``litellm_params``. Those legacy rows must keep working: the use-time
    resolver is skipped when the config is already present.
    """
    legacy_config = {
        "api_key": "legacy-cleartext-key",
        "api_base": "https://legacy-azure.example",
        "api_version": "2024-01-01",
    }
    stored_row: LiteLLM_ManagedVectorStore = {
        "vector_store_id": "legacy_store",
        "custom_llm_provider": "azure_ai",
        "litellm_params": {
            "litellm_embedding_model": "azure/text-embedding-3-large",
            "litellm_embedding_config": legacy_config,
        },
    }

    registry = MagicMock()
    registry.get_litellm_managed_vector_store_from_registry.return_value = stored_row

    resolve_mock = AsyncMock()

    with patch.object(litellm, "vector_store_registry", registry):
        with patch(
            "litellm.proxy.vector_store_endpoints.endpoints._resolve_embedding_config",
            new=resolve_mock,
        ):
            result = (
                await _update_request_data_with_litellm_managed_vector_store_registry(
                    data={}, vector_store_id="legacy_store"
                )
            )

    assert result["litellm_embedding_config"] == legacy_config
    resolve_mock.assert_not_awaited()
|
||||
|
||||
|
||||
class TestCheckVectorStorePermission:
|
||||
"""Test suite for check_vector_store_permission function."""
|
||||
|
||||
@ -1417,20 +1531,25 @@ async def test_new_vector_store_auto_resolves_embedding_config():
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
# Verify that embedding config was resolved and included in the create call
|
||||
# Auto-resolve no longer happens at create time — the persisted row
|
||||
# carries only the model reference, never the resolved cleartext
|
||||
# credential. Resolution now happens at request-handling time inside
|
||||
# ``_update_request_data_with_litellm_managed_vector_store_registry``,
|
||||
# where the resolved config lives in per-request memory and is never
|
||||
# written to the database.
|
||||
litellm_params_json = captured_create_data.get("litellm_params")
|
||||
assert litellm_params_json is not None
|
||||
litellm_params_dict = json.loads(litellm_params_json)
|
||||
assert "litellm_embedding_config" in litellm_params_dict
|
||||
assert (
|
||||
litellm_params_dict["litellm_embedding_config"]["api_key"] == "resolved-api-key"
|
||||
)
|
||||
assert (
|
||||
litellm_params_dict["litellm_embedding_config"]["api_base"]
|
||||
== "https://api.openai.com"
|
||||
)
|
||||
assert (
|
||||
litellm_params_dict["litellm_embedding_config"]["api_version"] == "2024-01-01"
|
||||
assert "litellm_embedding_config" not in litellm_params_dict
|
||||
assert litellm_params_dict["litellm_embedding_model"] == "text-embedding-ada-002"
|
||||
|
||||
# The response must also not echo a cleartext credential — even on
|
||||
# the create response, where redaction guards against caller-supplied
|
||||
# cleartext or pre-existing rows that were created by an earlier
|
||||
# proxy version.
|
||||
response_vs = result["vector_store"]
|
||||
assert "resolved-api-key" not in _serialize_litellm_params(
|
||||
response_vs.get("litellm_params")
|
||||
)
|
||||
|
||||
|
||||
@ -1578,6 +1697,43 @@ async def test_resolve_embedding_config_tries_router_then_db():
|
||||
mock_prisma_client.db.litellm_proxymodeltable.find_first.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_resolve_embedding_config_caches_result():
    """A repeated lookup for the same model is served from the cache.

    The first lookup should hit the router/DB; a second lookup for the same
    model name should return the cached value without touching the router or
    the database.
    """
    from litellm.types.router import Deployment, LiteLLM_Params

    mock_prisma_client = MagicMock()
    mock_router = MagicMock()

    deployment_params = MagicMock(spec=LiteLLM_Params)
    deployment_params.api_key = "router-api-key"
    deployment_params.api_base = "https://router-api-base.com"
    deployment_params.api_version = None

    deployment = MagicMock(spec=Deployment)
    deployment.litellm_params = deployment_params
    mock_router.get_deployment_by_model_group_name.return_value = deployment

    async def _lookup():
        # Both lookups use identical arguments so the second can hit the cache.
        return await _resolve_embedding_config(
            embedding_model="cached-model",
            prisma_client=mock_prisma_client,
            llm_router=mock_router,
        )

    first = await _lookup()
    assert first is not None
    assert mock_router.get_deployment_by_model_group_name.call_count == 1

    second = await _lookup()
    assert second == first
    # Router (and by extension the DB) was not consulted again.
    assert mock_router.get_deployment_by_model_group_name.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resolve_embedding_config_falls_back_to_db():
|
||||
"""Test that _resolve_embedding_config falls back to DB when router doesn't have the model."""
|
||||
@ -1687,21 +1843,18 @@ async def test_new_vector_store_auto_resolves_from_router():
|
||||
)
|
||||
|
||||
assert result["status"] == "success"
|
||||
# Verify that embedding config was resolved from router and included in the create call
|
||||
# Resolution against the router happens at request-handling time now,
|
||||
# not at row creation. The persisted ``litellm_params`` carries only
|
||||
# the model reference, never the cleartext credential.
|
||||
litellm_params_json = captured_create_data.get("litellm_params")
|
||||
assert litellm_params_json is not None
|
||||
litellm_params_dict = json.loads(litellm_params_json)
|
||||
assert "litellm_embedding_config" in litellm_params_dict
|
||||
assert (
|
||||
litellm_params_dict["litellm_embedding_config"]["api_key"]
|
||||
== "router-resolved-api-key"
|
||||
)
|
||||
assert (
|
||||
litellm_params_dict["litellm_embedding_config"]["api_base"]
|
||||
== "https://router-resolved-base.com"
|
||||
)
|
||||
assert (
|
||||
litellm_params_dict["litellm_embedding_config"]["api_version"] == "2024-03-01"
|
||||
assert "litellm_embedding_config" not in litellm_params_dict
|
||||
assert litellm_params_dict["litellm_embedding_model"] == "config-embedding-model"
|
||||
|
||||
response_vs = result["vector_store"]
|
||||
assert "router-resolved-api-key" not in _serialize_litellm_params(
|
||||
response_vs.get("litellm_params")
|
||||
)
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user