diff --git a/litellm/caching/caching.py b/litellm/caching/caching.py index 6a7c93e3fe..7adede7961 100644 --- a/litellm/caching/caching.py +++ b/litellm/caching/caching.py @@ -13,7 +13,7 @@ import json import time import traceback from enum import Enum -from typing import Any, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Tuple, Union from pydantic import BaseModel @@ -22,7 +22,7 @@ from litellm._logging import verbose_logger from litellm.constants import CACHED_STREAMING_CHUNK_DELAY from litellm.litellm_core_utils.model_param_helper import ModelParamHelper from litellm.types.caching import * -from litellm.types.utils import all_litellm_params +from litellm.types.utils import EmbeddingResponse, all_litellm_params from .base_cache import BaseCache from .disk_cache import DiskCache @@ -582,6 +582,22 @@ class Cache: except Exception as e: verbose_logger.exception(f"LiteLLM Cache: Excepton add_cache: {str(e)}") + def add_embedding_response_to_cache( + self, + result: EmbeddingResponse, + input: str, + kwargs: dict, + idx_in_result_data: int = 0, + ) -> Tuple[str, dict, dict]: + preset_cache_key = self.get_cache_key(**{**kwargs, "input": input}) + kwargs["cache_key"] = preset_cache_key + embedding_response = result.data[idx_in_result_data] + cache_key, cached_data, kwargs = self._add_cache_logic( + result=embedding_response, + **kwargs, + ) + return cache_key, cached_data, kwargs + async def async_add_cache_pipeline(self, result, **kwargs): """ Async implementation of add_cache for Embedding calls @@ -597,13 +613,17 @@ class Cache: kwargs["ttl"] = self.ttl cache_list = [] - for idx, i in enumerate(kwargs["input"]): - preset_cache_key = self.get_cache_key(**{**kwargs, "input": i}) - kwargs["cache_key"] = preset_cache_key - embedding_response = result.data[idx] - cache_key, cached_data, kwargs = self._add_cache_logic( - result=embedding_response, - **kwargs, + if isinstance(kwargs["input"], list): + for idx, i in enumerate(kwargs["input"]): + ( + cache_key, + cached_data, + kwargs, + ) = self.add_embedding_response_to_cache(result, i, kwargs, idx) + cache_list.append((cache_key, cached_data)) + elif isinstance(kwargs["input"], str): + cache_key, cached_data, kwargs = self.add_embedding_response_to_cache( + result, kwargs["input"], kwargs ) cache_list.append((cache_key, cached_data)) diff --git a/litellm/caching/caching_handler.py b/litellm/caching/caching_handler.py index 14278de9cd..afe96c7f46 100644 --- a/litellm/caching/caching_handler.py +++ b/litellm/caching/caching_handler.py @@ -47,6 +47,7 @@ from litellm.types.utils import ( ModelResponse, TextCompletionResponse, TranscriptionResponse, + Usage, ) if TYPE_CHECKING: @@ -129,7 +130,7 @@ class LLMCachingHandler: if litellm.cache is not None and self._is_call_type_supported_by_cache( original_function=original_function ): - verbose_logger.debug("Checking Cache") + verbose_logger.debug("Checking Async Cache") cached_result = await self._retrieve_from_cache( call_type=call_type, kwargs=kwargs, @@ -237,7 +238,7 @@ class LLMCachingHandler: if litellm.cache is not None and self._is_call_type_supported_by_cache( original_function=original_function ): - print_verbose("Checking Cache") + print_verbose("Checking Sync Cache") cached_result = litellm.cache.get_cache(**new_kwargs) if cached_result is not None: if "detail" in cached_result: @@ -339,6 +340,7 @@ class LLMCachingHandler: ) final_embedding_cached_response._hidden_params["cache_hit"] = True + prompt_tokens = 0 for val in non_null_list: idx, cr = val # (idx, cr) tuple if cr is not None: @@ -347,6 +349,19 @@ class LLMCachingHandler: index=idx, object="embedding", ) + if isinstance(original_kwargs_input[idx], str): + from litellm.utils import token_counter + + prompt_tokens += token_counter( + text=original_kwargs_input[idx], count_response_tokens=True + ) + ## USAGE + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=0, + total_tokens=prompt_tokens, + ) + final_embedding_cached_response.usage = usage if len(remaining_list) == 0: # LOG SUCCESS cache_hit = True @@ -382,6 +397,13 @@ class LLMCachingHandler: return final_embedding_cached_response, embedding_all_elements_cache_hit return final_embedding_cached_response, embedding_all_elements_cache_hit + def combine_usage(self, usage1: Usage, usage2: Usage) -> Usage: + return Usage( + prompt_tokens=usage1.prompt_tokens + usage2.prompt_tokens, + completion_tokens=usage1.completion_tokens + usage2.completion_tokens, + total_tokens=usage1.total_tokens + usage2.total_tokens, + ) + def _combine_cached_embedding_response_with_api_result( self, _caching_handler_response: CachingHandlerResponse, @@ -421,6 +443,17 @@ class LLMCachingHandler: _caching_handler_response.final_embedding_cached_response._response_ms = ( end_time - start_time ).total_seconds() * 1000 + + ## USAGE + if ( + _caching_handler_response.final_embedding_cached_response.usage is not None + and embedding_response.usage is not None + ): + _caching_handler_response.final_embedding_cached_response.usage = self.combine_usage( + usage1=_caching_handler_response.final_embedding_cached_response.usage, + usage2=embedding_response.usage, + ) + return _caching_handler_response.final_embedding_cached_response def _async_log_cache_hit_on_callbacks( @@ -689,7 +722,6 @@ class LLMCachingHandler: ): if ( isinstance(result, EmbeddingResponse) - and isinstance(new_kwargs["input"], list) and litellm.cache is not None and not isinstance( litellm.cache.cache, S3Cache diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 0c78332a8f..631dbee625 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -49,10 +49,15 @@ model_list: litellm_params: model: xai/* api_key: os.environ/XAI_API_KEY + - model_name: "text-embedding-ada-002" + litellm_params: + model: text-embedding-ada-002 + api_key: os.environ/OPENAI_API_KEY litellm_settings: num_retries: 0 check_provider_endpoint: true + cache: true files_settings: - custom_llm_provider: gemini diff --git a/litellm/utils.py b/litellm/utils.py index 98a9c34b47..1bcfbef169 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -524,9 +524,9 @@ def function_setup( # noqa: PLR0915 function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None ## DYNAMIC CALLBACKS ## - dynamic_callbacks: Optional[List[Union[str, Callable, CustomLogger]]] = ( - kwargs.pop("callbacks", None) - ) + dynamic_callbacks: Optional[ + List[Union[str, Callable, CustomLogger]] + ] = kwargs.pop("callbacks", None) all_callbacks = get_dynamic_callbacks(dynamic_callbacks=dynamic_callbacks) if len(all_callbacks) > 0: @@ -1210,9 +1210,9 @@ def client(original_function): # noqa: PLR0915 exception=e, retry_policy=kwargs.get("retry_policy"), ) - kwargs["retry_policy"] = ( - reset_retry_policy() - ) # prevent infinite loops + kwargs[ + "retry_policy" + ] = reset_retry_policy() # prevent infinite loops litellm.num_retries = ( None # set retries to None to prevent infinite loops ) @@ -1308,6 +1308,7 @@ def client(original_function): # noqa: PLR0915 args=args, ) ) + if ( _caching_handler_response.cached_result is not None and _caching_handler_response.final_embedding_cached_response is None @@ -3036,16 +3037,16 @@ def get_optional_params( # noqa: PLR0915 True # so that main.py adds the function call to the prompt ) if "tools" in non_default_params: - optional_params["functions_unsupported_model"] = ( - non_default_params.pop("tools") - ) + optional_params[ + "functions_unsupported_model" + ] = non_default_params.pop("tools") non_default_params.pop( "tool_choice", None ) # causes ollama requests to hang elif "functions" in non_default_params: - optional_params["functions_unsupported_model"] = ( - non_default_params.pop("functions") - ) + optional_params[ + "functions_unsupported_model" + ] = non_default_params.pop("functions") elif ( litellm.add_function_to_prompt ): # if user opts to add it to prompt instead @@ -3068,10 +3069,10 @@ def get_optional_params( # noqa: PLR0915 if "response_format" in non_default_params: if provider_config is not None: - non_default_params["response_format"] = ( - provider_config.get_json_schema_from_pydantic_object( - response_format=non_default_params["response_format"] - ) + non_default_params[ + "response_format" + ] = provider_config.get_json_schema_from_pydantic_object( + response_format=non_default_params["response_format"] ) else: non_default_params["response_format"] = type_to_response_format_param( @@ -4087,9 +4088,9 @@ def _count_characters(text: str) -> int: def get_response_string(response_obj: Union[ModelResponse, ModelResponseStream]) -> str: - _choices: Union[List[Union[Choices, StreamingChoices]], List[StreamingChoices]] = ( - response_obj.choices - ) + _choices: Union[ + List[Union[Choices, StreamingChoices]], List[StreamingChoices] + ] = response_obj.choices response_str = "" for choice in _choices: diff --git a/tests/local_testing/test_caching.py b/tests/local_testing/test_caching.py index 8c12f3fd9b..0acfecfe33 100644 --- a/tests/local_testing/test_caching.py +++ b/tests/local_testing/test_caching.py @@ -449,6 +449,41 @@ def test_embedding_caching(): # test_embedding_caching() +@pytest.mark.asyncio +async def test_embedding_caching_individual_items_and_then_list(): + litellm._turn_on_debug() + litellm.cache = Cache() + text_to_embed = [ + "hello", + "world", + ] + embedding1 = await aembedding( + model="text-embedding-ada-002", input=text_to_embed[0], caching=True + ) + initial_prompt_tokens = embedding1.usage.prompt_tokens + await asyncio.sleep(1) + embedding2 = await aembedding( + model="text-embedding-ada-002", input=text_to_embed[1], caching=True + ) + await asyncio.sleep(1) + embedding3 = await aembedding( + model="text-embedding-ada-002", input=text_to_embed, caching=True + ) + final_prompt_tokens = embedding3.usage.prompt_tokens + assert embedding3["data"][0]["embedding"] == embedding1["data"][0]["embedding"] + assert embedding3["data"][1]["embedding"] == embedding2["data"][0]["embedding"] + assert embedding3._hidden_params["cache_hit"] == True + assert embedding3.usage.prompt_tokens != 0 + + ## with new input, check that prompt tokens increase + additional_text = "this is a new text" + text_to_embed.append(additional_text) + embedding4 = await aembedding( + model="text-embedding-ada-002", input=text_to_embed, caching=True + ) + assert embedding4.usage.prompt_tokens > embedding3.usage.prompt_tokens + + def test_embedding_caching_azure(): print("Testing azure embedding caching") @@ -2668,4 +2703,5 @@ def test_caching_thinking_args_hit(): # test in memory cache assert response1.id == response2.id except Exception as e: print(f"error occurred: {traceback.format_exc()}") - pytest.fail(f"Error occurred: {e}") \ No newline at end of file + pytest.fail(f"Error occurred: {e}") + diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py index 3f4204f7e8..86968866b1 100644 --- a/tests/local_testing/test_completion.py +++ b/tests/local_testing/test_completion.py @@ -2862,7 +2862,7 @@ def test_completion_azure_deployment_id(): try: litellm.set_verbose = True response = completion( - deployment_id="chatgpt-v-3", + deployment_id="gpt-4o-new-test", model="gpt-3.5-turbo", messages=messages, ) diff --git a/tests/local_testing/test_prometheus_service.py b/tests/local_testing/test_prometheus_service.py index cfbd6a1a83..be620a0c7c 100644 --- a/tests/local_testing/test_prometheus_service.py +++ b/tests/local_testing/test_prometheus_service.py @@ -104,12 +104,12 @@ async def test_router_with_caching(): model_list = [ { "model_name": "azure/gpt-4", - "litellm_params": get_azure_params("chatgpt-v-3"), + "litellm_params": get_azure_params("gpt-4o-new-test"), "tpm": 100, }, { "model_name": "azure/gpt-4", - "litellm_params": get_azure_params("chatgpt-v-3"), + "litellm_params": get_azure_params("gpt-4o-new-test"), "tpm": 1000, }, ]