Embedding caching fixes - handle str -> list cache, set usage tokens for cache hits, combine usage tokens on partial cache hits (#10424)

* build(model_prices_and_context_window.json): add fireworks ai new 0-4b pricing tier

* build(model_prices_and_context_window.json): add more fireworks ai models

* test: update testing

* fix(caching_handler.py): handle str + list cache

Fixes issue on cache hits for embedding when initial cached input was str

* test(test_caching.py): add e2e test on caching with individual item and then list

* fix(caching_handler.py): set usage tokens for cache hits

enables token counting to work

* fix(caching_handler.py): combine usage between cached result and embedding response

Handles case of new input to embedding response

* fix: cleanup

* test: move to gpt-4o-new-test

* test: update test
This commit is contained in:
Krish Dholakia 2025-04-29 21:21:28 -07:00 committed by GitHub
parent 290e2528cd
commit 9e35ca2010
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 129 additions and 35 deletions

View File

@ -13,7 +13,7 @@ import json
import time
import traceback
from enum import Enum
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union
from pydantic import BaseModel
@ -22,7 +22,7 @@ from litellm._logging import verbose_logger
from litellm.constants import CACHED_STREAMING_CHUNK_DELAY
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
from litellm.types.caching import *
from litellm.types.utils import all_litellm_params
from litellm.types.utils import EmbeddingResponse, all_litellm_params
from .base_cache import BaseCache
from .disk_cache import DiskCache
@ -582,6 +582,22 @@ class Cache:
except Exception as e:
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
def add_embedding_response_to_cache(
    self,
    result: EmbeddingResponse,
    input: str,
    kwargs: dict,
    idx_in_result_data: int = 0,
) -> Tuple[str, dict, dict]:
    """
    Build the cache entry for a single embedding input.

    Args:
        result: Embedding response whose ``data`` list holds the per-input
            embeddings.
        input: The single input item the cache key is derived for.
        kwargs: Call kwargs. NOTE: mutated in place - ``"cache_key"`` is
            set on the caller's dict before it is forwarded.
        idx_in_result_data: Position in ``result.data`` of the embedding
            that belongs to ``input`` (defaults to the first element).

    Returns:
        The ``(cache_key, cached_data, kwargs)`` triple produced by
        ``self._add_cache_logic``.
    """
    # Derive the key from the call kwargs with "input" overridden by this
    # one item, so each element of a list input gets its own key.
    key_params = dict(kwargs)
    key_params["input"] = input
    kwargs["cache_key"] = self.get_cache_key(**key_params)

    item_embedding = result.data[idx_in_result_data]
    cache_key, cached_data, updated_kwargs = self._add_cache_logic(
        result=item_embedding, **kwargs
    )
    return cache_key, cached_data, updated_kwargs
async def async_add_cache_pipeline(self, result, **kwargs):
"""
Async implementation of add_cache for Embedding calls
@ -597,13 +613,17 @@ class Cache:
kwargs["ttl"] = self.ttl
cache_list = []
for idx, i in enumerate(kwargs["input"]):
preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
kwargs["cache_key"] = preset_cache_key
embedding_response = result.data[idx]
cache_key, cached_data, kwargs = self._add_cache_logic(
result=embedding_response,
**kwargs,
if isinstance(kwargs["input"], list):
for idx, i in enumerate(kwargs["input"]):
(
cache_key,
cached_data,
kwargs,
) = self.add_embedding_response_to_cache(result, i, kwargs, idx)
cache_list.append((cache_key, cached_data))
elif isinstance(kwargs["input"], str):
cache_key, cached_data, kwargs = self.add_embedding_response_to_cache(
result, kwargs["input"], kwargs
)
cache_list.append((cache_key, cached_data))

View File

@ -47,6 +47,7 @@ from litellm.types.utils import (
ModelResponse,
TextCompletionResponse,
TranscriptionResponse,
Usage,
)
if TYPE_CHECKING:
@ -129,7 +130,7 @@ class LLMCachingHandler:
if litellm.cache is not None and self._is_call_type_supported_by_cache(
original_function=original_function
):
verbose_logger.debug("Checking Cache")
verbose_logger.debug("Checking Async Cache")
cached_result = await self._retrieve_from_cache(
call_type=call_type,
kwargs=kwargs,
@ -237,7 +238,7 @@ class LLMCachingHandler:
if litellm.cache is not None and self._is_call_type_supported_by_cache(
original_function=original_function
):
print_verbose("Checking Cache")
print_verbose("Checking Sync Cache")
cached_result = litellm.cache.get_cache(**new_kwargs)
if cached_result is not None:
if "detail" in cached_result:
@ -339,6 +340,7 @@ class LLMCachingHandler:
)
final_embedding_cached_response._hidden_params["cache_hit"] = True
prompt_tokens = 0
for val in non_null_list:
idx, cr = val # (idx, cr) tuple
if cr is not None:
@ -347,6 +349,19 @@ class LLMCachingHandler:
index=idx,
object="embedding",
)
if isinstance(original_kwargs_input[idx], str):
from litellm.utils import token_counter
prompt_tokens += token_counter(
text=original_kwargs_input[idx], count_response_tokens=True
)
## USAGE
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=0,
total_tokens=prompt_tokens,
)
final_embedding_cached_response.usage = usage
if len(remaining_list) == 0:
# LOG SUCCESS
cache_hit = True
@ -382,6 +397,13 @@ class LLMCachingHandler:
return final_embedding_cached_response, embedding_all_elements_cache_hit
return final_embedding_cached_response, embedding_all_elements_cache_hit
def combine_usage(self, usage1: Usage, usage2: Usage) -> Usage:
    """
    Return a new ``Usage`` whose prompt, completion and total token counts
    are the field-wise sums of ``usage1`` and ``usage2``.

    Neither input object is modified.
    """
    summed_prompt = usage1.prompt_tokens + usage2.prompt_tokens
    summed_completion = usage1.completion_tokens + usage2.completion_tokens
    summed_total = usage1.total_tokens + usage2.total_tokens
    return Usage(
        prompt_tokens=summed_prompt,
        completion_tokens=summed_completion,
        total_tokens=summed_total,
    )
def _combine_cached_embedding_response_with_api_result(
self,
_caching_handler_response: CachingHandlerResponse,
@ -421,6 +443,17 @@ class LLMCachingHandler:
_caching_handler_response.final_embedding_cached_response._response_ms = (
end_time - start_time
).total_seconds() * 1000
## USAGE
if (
_caching_handler_response.final_embedding_cached_response.usage is not None
and embedding_response.usage is not None
):
_caching_handler_response.final_embedding_cached_response.usage = self.combine_usage(
usage1=_caching_handler_response.final_embedding_cached_response.usage,
usage2=embedding_response.usage,
)
return _caching_handler_response.final_embedding_cached_response
def _async_log_cache_hit_on_callbacks(
@ -689,7 +722,6 @@ class LLMCachingHandler:
):
if (
isinstance(result, EmbeddingResponse)
and isinstance(new_kwargs["input"], list)
and litellm.cache is not None
and not isinstance(
litellm.cache.cache, S3Cache

View File

@ -49,10 +49,15 @@ model_list:
litellm_params:
model: xai/*
api_key: os.environ/XAI_API_KEY
- model_name: "text-embedding-ada-002"
litellm_params:
model: text-embedding-ada-002
api_key: os.environ/OPENAI_API_KEY
litellm_settings:
num_retries: 0
check_provider_endpoint: true
cache: true
files_settings:
- custom_llm_provider: gemini

View File

@ -524,9 +524,9 @@ def function_setup( # noqa: PLR0915
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
## DYNAMIC CALLBACKS ##
dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = (
kwargs.pop("callbacks", None)
)
dynamic_callbacks: Optional[
List[Union[str, Callable, CustomLogger]]
] = kwargs.pop("callbacks", None)
all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks)
if len(all_callbacks) > 0:
@ -1210,9 +1210,9 @@ def client(original_function): # noqa: PLR0915
exception=e,
retry_policy=kwargs.get("retry_policy"),
)
kwargs["retry_policy"] = (
reset_retry_policy()
) # prevent infinite loops
kwargs[
"retry_policy"
] = reset_retry_policy() # prevent infinite loops
litellm.num_retries = (
None # set retries to None to prevent infinite loops
)
@ -1308,6 +1308,7 @@ def client(original_function): # noqa: PLR0915
args=args,
)
)
if (
_caching_handler_response.cached_result is not None
and _caching_handler_response.final_embedding_cached_response is None
@ -3036,16 +3037,16 @@ def get_optional_params( # noqa: PLR0915
True # so that main.py adds the function call to the prompt
)
if "tools" in non_default_params:
optional_params["functions_unsupported_model"] = (
non_default_params.pop("tools")
)
optional_params[
"functions_unsupported_model"
] = non_default_params.pop("tools")
non_default_params.pop(
"tool_choice", None
) # causes ollama requests to hang
elif "functions" in non_default_params:
optional_params["functions_unsupported_model"] = (
non_default_params.pop("functions")
)
optional_params[
"functions_unsupported_model"
] = non_default_params.pop("functions")
elif (
litellm.add_function_to_prompt
): # if user opts to add it to prompt instead
@ -3068,10 +3069,10 @@ def get_optional_params( # noqa: PLR0915
if "response_format" in non_default_params:
if provider_config is not None:
non_default_params["response_format"] = (
provider_config.get_json_schema_from_pydantic_object(
response_format=non_default_params["response_format"]
)
non_default_params[
"response_format"
] = provider_config.get_json_schema_from_pydantic_object(
response_format=non_default_params["response_format"]
)
else:
non_default_params["response_format"] = type_to_response_format_param(
@ -4087,9 +4088,9 @@ def _count_characters(text: str) -> int:
def get_response_string(response_obj: Union[ModelResponse, ModelResponseStream]) -> str:
_choices: Union[List[Union[Choices, StreamingChoices]], List[StreamingChoices]] = (
response_obj.choices
)
_choices: Union[
List[Union[Choices, StreamingChoices]], List[StreamingChoices]
] = response_obj.choices
response_str = ""
for choice in _choices:

View File

@ -449,6 +449,41 @@ def test_embedding_caching():
# test_embedding_caching()
@pytest.mark.asyncio
async def test_embedding_caching_individual_items_and_then_list():
    """
    Embed two strings one at a time, then embed them together as a list.

    Checks that:
    - the list call is served from the entries cached by the single-string
      calls (same vectors, ``cache_hit`` flag set),
    - a fully cached response still reports non-zero prompt tokens,
    - adding a new, uncached item to the list increases the reported
      prompt tokens.
    """
    litellm._turn_on_debug()
    litellm.cache = Cache()
    text_to_embed = [
        "hello",
        "world",
    ]

    embedding1 = await aembedding(
        model="text-embedding-ada-002", input=text_to_embed[0], caching=True
    )
    # Pause after each call before the next one reads the cache.
    await asyncio.sleep(1)
    embedding2 = await aembedding(
        model="text-embedding-ada-002", input=text_to_embed[1], caching=True
    )
    await asyncio.sleep(1)

    embedding3 = await aembedding(
        model="text-embedding-ada-002", input=text_to_embed, caching=True
    )
    cached_prompt_tokens = embedding3.usage.prompt_tokens

    assert embedding3["data"][0]["embedding"] == embedding1["data"][0]["embedding"]
    assert embedding3["data"][1]["embedding"] == embedding2["data"][0]["embedding"]
    assert embedding3._hidden_params["cache_hit"] is True
    assert cached_prompt_tokens != 0

    ## with new input, check that prompt tokens increase
    additional_text = "this is a new text"
    text_to_embed.append(additional_text)

    embedding4 = await aembedding(
        model="text-embedding-ada-002", input=text_to_embed, caching=True
    )
    assert embedding4.usage.prompt_tokens > cached_prompt_tokens
def test_embedding_caching_azure():
print("Testing azure embedding caching")
@ -2668,4 +2703,5 @@ def test_caching_thinking_args_hit(): # test in memory cache
assert response1.id == response2.id
except Exception as e:
print(f"error occurred: {traceback.format_exc()}")
pytest.fail(f"Error occurred: {e}")
pytest.fail(f"Error occurred: {e}")

View File

@ -2862,7 +2862,7 @@ def test_completion_azure_deployment_id():
try:
litellm.set_verbose = True
response = completion(
deployment_id="chatgpt-v-3",
deployment_id="gpt-4o-new-test",
model="gpt-3.5-turbo",
messages=messages,
)

View File

@ -104,12 +104,12 @@ async def test_router_with_caching():
model_list = [
{
"model_name": "azure/gpt-4",
"litellm_params": get_azure_params("chatgpt-v-3"),
"litellm_params": get_azure_params("gpt-4o-new-test"),
"tpm": 100,
},
{
"model_name": "azure/gpt-4",
"litellm_params": get_azure_params("chatgpt-v-3"),
"litellm_params": get_azure_params("gpt-4o-new-test"),
"tpm": 1000,
},
]