Embedding caching fixes - handle str -> list cache, set usage tokens for cache hits, combine usage tokens on partial cache hits (#10424)
* build(model_prices_and_context_window.json): add fireworks ai new 0-4b pricing tier
* build(model_prices_and_context_window.json): add more fireworks ai models
* test: update testing
* fix(caching_handler.py): handle str + list cache
  Fixes issue on cache hits for embedding when initial cached input was str
* test(test_caching.py): add e2e test on caching with individual item and then list
* fix(caching_handler.py): set usage tokens for cache hits
  Enables token counting to work
* fix(caching_handler.py): combine usage between cached result and embedding response
  Handles case of new input to embedding response
* fix: cleanup
* test: move to gpt-4o-new-test
* test: update test
This commit is contained in:
parent
290e2528cd
commit
9e35ca2010
@ -13,7 +13,7 @@ import json
|
||||
import time
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@ -22,7 +22,7 @@ from litellm._logging import verbose_logger
|
||||
from litellm.constants import CACHED_STREAMING_CHUNK_DELAY
|
||||
from litellm.litellm_core_utils.model_param_helper import ModelParamHelper
|
||||
from litellm.types.caching import *
|
||||
from litellm.types.utils import all_litellm_params
|
||||
from litellm.types.utils import EmbeddingResponse, all_litellm_params
|
||||
|
||||
from .base_cache import BaseCache
|
||||
from .disk_cache import DiskCache
|
||||
@ -582,6 +582,22 @@ class Cache:
|
||||
except Exception as e:
|
||||
verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}")
|
||||
|
||||
def add_embedding_response_to_cache(
    self,
    result: EmbeddingResponse,
    input: str,
    kwargs: dict,
    idx_in_result_data: int = 0,
) -> Tuple[str, dict, dict]:
    """
    Prepare the cache entry for a single item of an embedding response.

    The cache key is derived from `kwargs` with its "input" entry replaced by
    `input`, stored under `kwargs["cache_key"]` (the caller's dict is modified
    in place), and the selected item is then passed to `_add_cache_logic`.

    Args:
        result: Embedding response that holds the item to cache.
        input: The single input string the selected item corresponds to.
        kwargs: Call kwargs used to build the cache key and forwarded to
            `_add_cache_logic`.
        idx_in_result_data: Position of the item inside `result.data`.

    Returns:
        The `(cache_key, cached_data, kwargs)` triple produced by
        `_add_cache_logic`.
    """
    # Key on the individual input, not on whatever "input" the call carried.
    key_kwargs = dict(kwargs)
    key_kwargs["input"] = input
    kwargs["cache_key"] = self.get_cache_key(**key_kwargs)

    item_to_cache = result.data[idx_in_result_data]
    key, data, updated_kwargs = self._add_cache_logic(
        result=item_to_cache,
        **kwargs,
    )
    return key, data, updated_kwargs
|
||||
|
||||
async def async_add_cache_pipeline(self, result, **kwargs):
|
||||
"""
|
||||
Async implementation of add_cache for Embedding calls
|
||||
@ -597,13 +613,17 @@ class Cache:
|
||||
kwargs["ttl"] = self.ttl
|
||||
|
||||
cache_list = []
|
||||
for idx, i in enumerate(kwargs["input"]):
|
||||
preset_cache_key = self.get_cache_key(**{**kwargs, "input": i})
|
||||
kwargs["cache_key"] = preset_cache_key
|
||||
embedding_response = result.data[idx]
|
||||
cache_key, cached_data, kwargs = self._add_cache_logic(
|
||||
result=embedding_response,
|
||||
**kwargs,
|
||||
if isinstance(kwargs["input"], list):
|
||||
for idx, i in enumerate(kwargs["input"]):
|
||||
(
|
||||
cache_key,
|
||||
cached_data,
|
||||
kwargs,
|
||||
) = self.add_embedding_response_to_cache(result, i, kwargs, idx)
|
||||
cache_list.append((cache_key, cached_data))
|
||||
elif isinstance(kwargs["input"], str):
|
||||
cache_key, cached_data, kwargs = self.add_embedding_response_to_cache(
|
||||
result, kwargs["input"], kwargs
|
||||
)
|
||||
cache_list.append((cache_key, cached_data))
|
||||
|
||||
|
||||
@ -47,6 +47,7 @@ from litellm.types.utils import (
|
||||
ModelResponse,
|
||||
TextCompletionResponse,
|
||||
TranscriptionResponse,
|
||||
Usage,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -129,7 +130,7 @@ class LLMCachingHandler:
|
||||
if litellm.cache is not None and self._is_call_type_supported_by_cache(
|
||||
original_function=original_function
|
||||
):
|
||||
verbose_logger.debug("Checking Cache")
|
||||
verbose_logger.debug("Checking Async Cache")
|
||||
cached_result = await self._retrieve_from_cache(
|
||||
call_type=call_type,
|
||||
kwargs=kwargs,
|
||||
@ -237,7 +238,7 @@ class LLMCachingHandler:
|
||||
if litellm.cache is not None and self._is_call_type_supported_by_cache(
|
||||
original_function=original_function
|
||||
):
|
||||
print_verbose("Checking Cache")
|
||||
print_verbose("Checking Sync Cache")
|
||||
cached_result = litellm.cache.get_cache(**new_kwargs)
|
||||
if cached_result is not None:
|
||||
if "detail" in cached_result:
|
||||
@ -339,6 +340,7 @@ class LLMCachingHandler:
|
||||
)
|
||||
final_embedding_cached_response._hidden_params["cache_hit"] = True
|
||||
|
||||
prompt_tokens = 0
|
||||
for val in non_null_list:
|
||||
idx, cr = val # (idx, cr) tuple
|
||||
if cr is not None:
|
||||
@ -347,6 +349,19 @@ class LLMCachingHandler:
|
||||
index=idx,
|
||||
object="embedding",
|
||||
)
|
||||
if isinstance(original_kwargs_input[idx], str):
|
||||
from litellm.utils import token_counter
|
||||
|
||||
prompt_tokens += token_counter(
|
||||
text=original_kwargs_input[idx], count_response_tokens=True
|
||||
)
|
||||
## USAGE
|
||||
usage = Usage(
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=0,
|
||||
total_tokens=prompt_tokens,
|
||||
)
|
||||
final_embedding_cached_response.usage = usage
|
||||
if len(remaining_list) == 0:
|
||||
# LOG SUCCESS
|
||||
cache_hit = True
|
||||
@ -382,6 +397,13 @@ class LLMCachingHandler:
|
||||
return final_embedding_cached_response, embedding_all_elements_cache_hit
|
||||
return final_embedding_cached_response, embedding_all_elements_cache_hit
|
||||
|
||||
def combine_usage(self, usage1: Usage, usage2: Usage) -> Usage:
    """
    Merge two usage records into a new `Usage`.

    Each of prompt, completion and total tokens is the sum of the matching
    fields on `usage1` and `usage2`; neither argument is modified.
    """
    prompt_total = usage1.prompt_tokens + usage2.prompt_tokens
    completion_total = usage1.completion_tokens + usage2.completion_tokens
    overall_total = usage1.total_tokens + usage2.total_tokens
    return Usage(
        prompt_tokens=prompt_total,
        completion_tokens=completion_total,
        total_tokens=overall_total,
    )
|
||||
|
||||
def _combine_cached_embedding_response_with_api_result(
|
||||
self,
|
||||
_caching_handler_response: CachingHandlerResponse,
|
||||
@ -421,6 +443,17 @@ class LLMCachingHandler:
|
||||
_caching_handler_response.final_embedding_cached_response._response_ms = (
|
||||
end_time - start_time
|
||||
).total_seconds() * 1000
|
||||
|
||||
## USAGE
|
||||
if (
|
||||
_caching_handler_response.final_embedding_cached_response.usage is not None
|
||||
and embedding_response.usage is not None
|
||||
):
|
||||
_caching_handler_response.final_embedding_cached_response.usage = self.combine_usage(
|
||||
usage1=_caching_handler_response.final_embedding_cached_response.usage,
|
||||
usage2=embedding_response.usage,
|
||||
)
|
||||
|
||||
return _caching_handler_response.final_embedding_cached_response
|
||||
|
||||
def _async_log_cache_hit_on_callbacks(
|
||||
@ -689,7 +722,6 @@ class LLMCachingHandler:
|
||||
):
|
||||
if (
|
||||
isinstance(result, EmbeddingResponse)
|
||||
and isinstance(new_kwargs["input"], list)
|
||||
and litellm.cache is not None
|
||||
and not isinstance(
|
||||
litellm.cache.cache, S3Cache
|
||||
|
||||
@ -49,10 +49,15 @@ model_list:
|
||||
litellm_params:
|
||||
model: xai/*
|
||||
api_key: os.environ/XAI_API_KEY
|
||||
- model_name: "text-embedding-ada-002"
|
||||
litellm_params:
|
||||
model: text-embedding-ada-002
|
||||
api_key: os.environ/OPENAI_API_KEY
|
||||
|
||||
litellm_settings:
|
||||
num_retries: 0
|
||||
check_provider_endpoint: true
|
||||
cache: true
|
||||
|
||||
files_settings:
|
||||
- custom_llm_provider: gemini
|
||||
|
||||
@ -524,9 +524,9 @@ def function_setup( # noqa: PLR0915
|
||||
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
|
||||
|
||||
## DYNAMIC CALLBACKS ##
|
||||
dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = (
|
||||
kwargs.pop("callbacks", None)
|
||||
)
|
||||
dynamic_callbacks: Optional[
|
||||
List[Union[str, Callable, CustomLogger]]
|
||||
] = kwargs.pop("callbacks", None)
|
||||
all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks)
|
||||
|
||||
if len(all_callbacks) > 0:
|
||||
@ -1210,9 +1210,9 @@ def client(original_function): # noqa: PLR0915
|
||||
exception=e,
|
||||
retry_policy=kwargs.get("retry_policy"),
|
||||
)
|
||||
kwargs["retry_policy"] = (
|
||||
reset_retry_policy()
|
||||
) # prevent infinite loops
|
||||
kwargs[
|
||||
"retry_policy"
|
||||
] = reset_retry_policy() # prevent infinite loops
|
||||
litellm.num_retries = (
|
||||
None # set retries to None to prevent infinite loops
|
||||
)
|
||||
@ -1308,6 +1308,7 @@ def client(original_function): # noqa: PLR0915
|
||||
args=args,
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
_caching_handler_response.cached_result is not None
|
||||
and _caching_handler_response.final_embedding_cached_response is None
|
||||
@ -3036,16 +3037,16 @@ def get_optional_params( # noqa: PLR0915
|
||||
True # so that main.py adds the function call to the prompt
|
||||
)
|
||||
if "tools" in non_default_params:
|
||||
optional_params["functions_unsupported_model"] = (
|
||||
non_default_params.pop("tools")
|
||||
)
|
||||
optional_params[
|
||||
"functions_unsupported_model"
|
||||
] = non_default_params.pop("tools")
|
||||
non_default_params.pop(
|
||||
"tool_choice", None
|
||||
) # causes ollama requests to hang
|
||||
elif "functions" in non_default_params:
|
||||
optional_params["functions_unsupported_model"] = (
|
||||
non_default_params.pop("functions")
|
||||
)
|
||||
optional_params[
|
||||
"functions_unsupported_model"
|
||||
] = non_default_params.pop("functions")
|
||||
elif (
|
||||
litellm.add_function_to_prompt
|
||||
): # if user opts to add it to prompt instead
|
||||
@ -3068,10 +3069,10 @@ def get_optional_params( # noqa: PLR0915
|
||||
|
||||
if "response_format" in non_default_params:
|
||||
if provider_config is not None:
|
||||
non_default_params["response_format"] = (
|
||||
provider_config.get_json_schema_from_pydantic_object(
|
||||
response_format=non_default_params["response_format"]
|
||||
)
|
||||
non_default_params[
|
||||
"response_format"
|
||||
] = provider_config.get_json_schema_from_pydantic_object(
|
||||
response_format=non_default_params["response_format"]
|
||||
)
|
||||
else:
|
||||
non_default_params["response_format"] = type_to_response_format_param(
|
||||
@ -4087,9 +4088,9 @@ def _count_characters(text: str) -> int:
|
||||
|
||||
|
||||
def get_response_string(response_obj: Union[ModelResponse, ModelResponseStream]) -> str:
|
||||
_choices: Union[List[Union[Choices, StreamingChoices]], List[StreamingChoices]] = (
|
||||
response_obj.choices
|
||||
)
|
||||
_choices: Union[
|
||||
List[Union[Choices, StreamingChoices]], List[StreamingChoices]
|
||||
] = response_obj.choices
|
||||
|
||||
response_str = ""
|
||||
for choice in _choices:
|
||||
|
||||
@ -449,6 +449,41 @@ def test_embedding_caching():
|
||||
|
||||
# test_embedding_caching()
|
||||
|
||||
@pytest.mark.asyncio
async def test_embedding_caching_individual_items_and_then_list():
    """
    Items embedded (and cached) one string at a time must be served from the
    cache when the same strings are later requested together as a list.

    Also checks that a cache hit reports non-zero prompt tokens, and that
    adding an uncached input to the list raises the reported prompt tokens.
    """
    litellm._turn_on_debug()
    litellm.cache = Cache()
    text_to_embed = ["hello", "world"]
    model = "text-embedding-ada-002"

    # Seed the cache with each string individually.
    embedding1 = await aembedding(model=model, input=text_to_embed[0], caching=True)
    # NOTE(review): the pauses presumably give the cache write time to finish
    # before the next lookup - confirm against the caching handler.
    await asyncio.sleep(1)
    embedding2 = await aembedding(model=model, input=text_to_embed[1], caching=True)
    await asyncio.sleep(1)

    # Request both strings as a list: every element should be a cache hit.
    embedding3 = await aembedding(model=model, input=text_to_embed, caching=True)
    assert embedding3["data"][0]["embedding"] == embedding1["data"][0]["embedding"]
    assert embedding3["data"][1]["embedding"] == embedding2["data"][0]["embedding"]
    assert embedding3._hidden_params["cache_hit"] is True
    assert embedding3.usage.prompt_tokens != 0

    ## with new input, check that prompt tokens increase
    additional_text = "this is a new text"
    text_to_embed.append(additional_text)
    embedding4 = await aembedding(model=model, input=text_to_embed, caching=True)
    assert embedding4.usage.prompt_tokens > embedding3.usage.prompt_tokens
|
||||
|
||||
|
||||
|
||||
def test_embedding_caching_azure():
|
||||
print("Testing azure embedding caching")
|
||||
@ -2668,4 +2703,5 @@ def test_caching_thinking_args_hit(): # test in memory cache
|
||||
assert response1.id == response2.id
|
||||
except Exception as e:
|
||||
print(f"error occurred: {traceback.format_exc()}")
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@ -2862,7 +2862,7 @@ def test_completion_azure_deployment_id():
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
response = completion(
|
||||
deployment_id="chatgpt-v-3",
|
||||
deployment_id="gpt-4o-new-test",
|
||||
model="gpt-3.5-turbo",
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
@ -104,12 +104,12 @@ async def test_router_with_caching():
|
||||
model_list = [
|
||||
{
|
||||
"model_name": "azure/gpt-4",
|
||||
"litellm_params": get_azure_params("chatgpt-v-3"),
|
||||
"litellm_params": get_azure_params("gpt-4o-new-test"),
|
||||
"tpm": 100,
|
||||
},
|
||||
{
|
||||
"model_name": "azure/gpt-4",
|
||||
"litellm_params": get_azure_params("chatgpt-v-3"),
|
||||
"litellm_params": get_azure_params("gpt-4o-new-test"),
|
||||
"tpm": 1000,
|
||||
},
|
||||
]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user