diff --git a/litellm/caching/caching.py b/litellm/caching/caching.py index c1afde1625..b6cfc8e790 100644 --- a/litellm/caching/caching.py +++ b/litellm/caching/caching.py @@ -691,6 +691,7 @@ class Cache: self, embedding_response: Any, model: Optional[str], + prompt_tokens: Optional[int] = None, prompt_tokens_details: Optional[dict] = None, ) -> CachedEmbedding: """ @@ -703,6 +704,7 @@ class Cache: "index": embedding_response.get("index"), "object": embedding_response.get("object"), "model": model, + "prompt_tokens": prompt_tokens, "prompt_tokens_details": prompt_tokens_details, } elif hasattr(embedding_response, "model_dump"): @@ -712,6 +714,7 @@ class Cache: "index": data.get("index"), "object": data.get("object"), "model": model, + "prompt_tokens": prompt_tokens, "prompt_tokens_details": prompt_tokens_details, } else: @@ -721,6 +724,7 @@ class Cache: "index": data.get("index"), "object": data.get("object"), "model": model, + "prompt_tokens": prompt_tokens, "prompt_tokens_details": prompt_tokens_details, } except KeyError as e: @@ -769,6 +773,29 @@ class Cache: per_item[key] = value return per_item if per_item else None + def _get_per_item_prompt_tokens( + self, + result: EmbeddingResponse, + idx_in_result_data: int, + ) -> Optional[int]: + """ + Extract the per-item prompt_tokens from a response for caching. + + Single-item responses store the full usage.prompt_tokens. Multi-item + responses distribute it evenly (with remainder) so that summing all + per-item values on retrieval reconstructs the original total. + """ + if result.usage is None or result.usage.prompt_tokens is None: + return None + + total = result.usage.prompt_tokens + num_items = len(result.data) + if num_items <= 1: + return total + + quotient, remainder = divmod(total, num_items) + return quotient + (1 if idx_in_result_data < remainder else 0) + def add_embedding_response_to_cache( self, result: EmbeddingResponse, @@ -780,7 +807,11 @@ class Cache: kwargs["cache_key"] = preset_cache_key embedding_response = result.data[idx_in_result_data] - # Extract per-item prompt_tokens_details from response usage + # Extract per-item prompt_tokens + details from response usage + prompt_tokens = self._get_per_item_prompt_tokens( + result=result, + idx_in_result_data=idx_in_result_data, + ) prompt_tokens_details = self._get_per_item_prompt_tokens_details( result=result, idx_in_result_data=idx_in_result_data, @@ -791,6 +822,7 @@ class Cache: embedding_dict: CachedEmbedding = self._convert_to_cached_embedding( embedding_response, model_name, + prompt_tokens=prompt_tokens, prompt_tokens_details=prompt_tokens_details, ) diff --git a/litellm/caching/caching_handler.py b/litellm/caching/caching_handler.py index 3f4e54382c..48691335b4 100644 --- a/litellm/caching/caching_handler.py +++ b/litellm/caching/caching_handler.py @@ -394,7 +394,7 @@ class LLMCachingHandler: return cr["model"] return None - def _process_async_embedding_cached_response( + def _process_async_embedding_cached_response( # noqa: PLR0915 self, final_embedding_cached_response: Optional[EmbeddingResponse], cached_result: List[Optional[CachedEmbedding]], @@ -456,7 +456,10 @@ class LLMCachingHandler: index=idx, object="embedding", ) - if isinstance(kwargs_input_as_list[idx], str): + cached_prompt_tokens = cr.get("prompt_tokens") + if cached_prompt_tokens is not None: + prompt_tokens += cached_prompt_tokens + elif isinstance(kwargs_input_as_list[idx], str): from litellm.utils import token_counter prompt_tokens += token_counter( diff --git a/litellm/types/caching.py b/litellm/types/caching.py index f8050b292c..10453c74a1 100644 --- a/litellm/types/caching.py +++ b/litellm/types/caching.py @@ -118,4 +118,5 @@ class CachedEmbedding(TypedDict): index: Optional[int] object: Optional[str] model: Optional[str] + prompt_tokens: Optional[int] prompt_tokens_details: Optional[dict] diff --git a/tests/test_litellm/caching/test_caching.py b/tests/test_litellm/caching/test_caching.py index 02d62a1915..20614103ed 100644 --- a/tests/test_litellm/caching/test_caching.py +++ b/tests/test_litellm/caching/test_caching.py @@ -3,6 +3,7 @@ import re from litellm.caching.caching import Cache from litellm.types.caching import LiteLLMCacheType +from litellm.types.utils import Embedding, EmbeddingResponse, Usage def test_cache_key_debug_log_does_not_include_prompt_material(caplog): @@ -41,8 +42,37 @@ def test_cache_key_debug_log_does_not_include_prompt_material(caplog): assert re.fullmatch(r"[0-9a-f]{64}", cache_key) created_cache_key_logs = [ - record.getMessage() for record in caplog.records if "Created cache key:" in record.getMessage() + record.getMessage() + for record in caplog.records + if "Created cache key:" in record.getMessage() ] assert created_cache_key_logs assert all(prompt_marker not in message for message in created_cache_key_logs) assert any(cache_key in message for message in created_cache_key_logs) + + +def _embedding_response(prompt_tokens, num_items): + return EmbeddingResponse( + model="amazon.titan-embed-image-v1", + data=[ + Embedding(embedding=[0.0], index=i, object="embedding") + for i in range(num_items) + ], + usage=Usage( + prompt_tokens=prompt_tokens, completion_tokens=0, total_tokens=prompt_tokens + ), + ) + + +def test_get_per_item_prompt_tokens_single_item_returns_full_value(): + cache = Cache(type=LiteLLMCacheType.LOCAL) + result = _embedding_response(prompt_tokens=0, num_items=1) + assert cache._get_per_item_prompt_tokens(result, 0) == 0 + + +def test_get_per_item_prompt_tokens_distributes_with_remainder(): + cache = Cache(type=LiteLLMCacheType.LOCAL) + result = _embedding_response(prompt_tokens=10, num_items=3) + per_item = [cache._get_per_item_prompt_tokens(result, i) for i in range(3)] + assert sum(per_item) == 10 # 4 + 3 + 3 + assert per_item == [4, 3, 3] diff --git a/tests/test_litellm/caching/test_caching_handler.py b/tests/test_litellm/caching/test_caching_handler.py index 3eb949d7f2..0132752941 100644 --- a/tests/test_litellm/caching/test_caching_handler.py +++ b/tests/test_litellm/caching/test_caching_handler.py @@ -436,3 +436,123 @@ def test_convert_cached_responses_legacy_stream_path(): ) assert isinstance(result, CachedResponsesAPIStreamingIterator) + + +@pytest.mark.asyncio +async def test_embedding_cache_restores_stored_prompt_tokens_for_image_input(): + """Image-embedding cache hit restores prompt_tokens=0 from the stored value + instead of recomputing a bogus count by tokenizing the base64 input.""" + llm_caching_handler = LLMCachingHandler( + original_function=MagicMock(), + request_kwargs={}, + start_time=datetime.now(), + ) + + # base64-like blob — token_counter over this would return a large nonzero count + image_input = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk" * 50 + + cached_result = [ + { + "embedding": [-0.025, -0.019], + "index": 0, + "object": "embedding", + "model": "amazon.titan-embed-image-v1", + "prompt_tokens": 0, + "prompt_tokens_details": {"image_count": 1}, + } + ] + + mock_logging_obj = MagicMock() + mock_logging_obj.async_success_handler = AsyncMock() + response, cache_hit = llm_caching_handler._process_async_embedding_cached_response( + final_embedding_cached_response=None, + cached_result=cached_result, + kwargs={"model": "amazon.titan-embed-image-v1", "input": image_input}, + logging_obj=mock_logging_obj, + start_time=datetime.now(), + model="amazon.titan-embed-image-v1", + ) + + assert cache_hit + assert response.usage is not None + assert response.usage.prompt_tokens == 0 + assert response.usage.total_tokens == 0 + assert response.usage.prompt_tokens_details.image_count == 1 + + +@pytest.mark.asyncio +async def test_embedding_cache_sums_stored_prompt_tokens_across_items(): + """A multi-item cache hit sums the stored per-item prompt_tokens back to the total.""" + llm_caching_handler = LLMCachingHandler( + original_function=MagicMock(), + request_kwargs={}, + start_time=datetime.now(), + ) + + cached_result = [ + { + "embedding": [-0.01], + "index": 0, + "object": "embedding", + "model": "text-embedding-3-small", + "prompt_tokens": 5, + }, + { + "embedding": [-0.02], + "index": 1, + "object": "embedding", + "model": "text-embedding-3-small", + "prompt_tokens": 4, + }, + ] + + mock_logging_obj = MagicMock() + mock_logging_obj.async_success_handler = AsyncMock() + response, cache_hit = llm_caching_handler._process_async_embedding_cached_response( + final_embedding_cached_response=None, + cached_result=cached_result, + kwargs={"model": "text-embedding-3-small", "input": ["hello world", "foo bar"]}, + logging_obj=mock_logging_obj, + start_time=datetime.now(), + model="text-embedding-3-small", + ) + + assert cache_hit + assert response.usage.prompt_tokens == 9 + assert response.usage.total_tokens == 9 + + +@pytest.mark.asyncio +async def test_embedding_cache_falls_back_to_token_counter_for_legacy_entries(): + """Legacy cache entries with no stored prompt_tokens still recompute via token_counter + for str inputs (backward compatibility).""" + llm_caching_handler = LLMCachingHandler( + original_function=MagicMock(), + request_kwargs={}, + start_time=datetime.now(), + ) + + # No prompt_tokens key — pre-fix entry + cached_result = [ + { + "embedding": [-0.025, -0.019], + "index": 0, + "object": "embedding", + "model": "text-embedding-ada-002", + }, + ] + + mock_logging_obj = MagicMock() + mock_logging_obj.async_success_handler = AsyncMock() + response, cache_hit = llm_caching_handler._process_async_embedding_cached_response( + final_embedding_cached_response=None, + cached_result=cached_result, + kwargs={"model": "text-embedding-ada-002", "input": "hello world"}, + logging_obj=mock_logging_obj, + start_time=datetime.now(), + model="text-embedding-ada-002", + ) + + assert cache_hit + # token_counter over "hello world" yields a nonzero count — fallback path still runs + assert response.usage.prompt_tokens > 0