fix(caching): restore stored prompt_tokens on embedding cache hits instead of recomputing (#30046)

This commit is contained in:
michelligabriele 2026-06-10 12:19:20 +02:00 committed by GitHub
parent e15b37a18e
commit 2fe9feda71
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 190 additions and 4 deletions

View File

@ -691,6 +691,7 @@ class Cache:
self,
embedding_response: Any,
model: Optional[str],
prompt_tokens: Optional[int] = None,
prompt_tokens_details: Optional[dict] = None,
) -> CachedEmbedding:
"""
@ -703,6 +704,7 @@ class Cache:
"index": embedding_response.get("index"),
"object": embedding_response.get("object"),
"model": model,
"prompt_tokens": prompt_tokens,
"prompt_tokens_details": prompt_tokens_details,
}
elif hasattr(embedding_response, "model_dump"):
@ -712,6 +714,7 @@ class Cache:
"index": data.get("index"),
"object": data.get("object"),
"model": model,
"prompt_tokens": prompt_tokens,
"prompt_tokens_details": prompt_tokens_details,
}
else:
@ -721,6 +724,7 @@ class Cache:
"index": data.get("index"),
"object": data.get("object"),
"model": model,
"prompt_tokens": prompt_tokens,
"prompt_tokens_details": prompt_tokens_details,
}
except KeyError as e:
@ -769,6 +773,29 @@ class Cache:
per_item[key] = value
return per_item if per_item else None
def _get_per_item_prompt_tokens(
self,
result: EmbeddingResponse,
idx_in_result_data: int,
) -> Optional[int]:
"""
Extract the per-item prompt_tokens from a response for caching.
Single-item responses store the full usage.prompt_tokens. Multi-item
responses distribute it evenly (with remainder) so that summing all
per-item values on retrieval reconstructs the original total.
"""
if result.usage is None or result.usage.prompt_tokens is None:
return None
total = result.usage.prompt_tokens
num_items = len(result.data)
if num_items <= 1:
return total
quotient, remainder = divmod(total, num_items)
return quotient + (1 if idx_in_result_data < remainder else 0)
def add_embedding_response_to_cache(
self,
result: EmbeddingResponse,
@ -780,7 +807,11 @@ class Cache:
kwargs["cache_key"] = preset_cache_key
embedding_response = result.data[idx_in_result_data]
# Extract per-item prompt_tokens_details from response usage
# Extract per-item prompt_tokens + details from response usage
prompt_tokens = self._get_per_item_prompt_tokens(
result=result,
idx_in_result_data=idx_in_result_data,
)
prompt_tokens_details = self._get_per_item_prompt_tokens_details(
result=result,
idx_in_result_data=idx_in_result_data,
@ -791,6 +822,7 @@ class Cache:
embedding_dict: CachedEmbedding = self._convert_to_cached_embedding(
embedding_response,
model_name,
prompt_tokens=prompt_tokens,
prompt_tokens_details=prompt_tokens_details,
)

View File

@ -394,7 +394,7 @@ class LLMCachingHandler:
return cr["model"]
return None
def _process_async_embedding_cached_response(
def _process_async_embedding_cached_response( # noqa: PLR0915
self,
final_embedding_cached_response: Optional[EmbeddingResponse],
cached_result: List[Optional[CachedEmbedding]],
@ -456,7 +456,10 @@ class LLMCachingHandler:
index=idx,
object="embedding",
)
if isinstance(kwargs_input_as_list[idx], str):
cached_prompt_tokens = cr.get("prompt_tokens")
if cached_prompt_tokens is not None:
prompt_tokens += cached_prompt_tokens
elif isinstance(kwargs_input_as_list[idx], str):
from litellm.utils import token_counter
prompt_tokens += token_counter(

View File

@ -118,4 +118,5 @@ class CachedEmbedding(TypedDict):
index: Optional[int]
object: Optional[str]
model: Optional[str]
prompt_tokens: Optional[int]
prompt_tokens_details: Optional[dict]

View File

@ -3,6 +3,7 @@ import re
from litellm.caching.caching import Cache
from litellm.types.caching import LiteLLMCacheType
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
def test_cache_key_debug_log_does_not_include_prompt_material(caplog):
@ -41,8 +42,37 @@ def test_cache_key_debug_log_does_not_include_prompt_material(caplog):
assert re.fullmatch(r"[0-9a-f]{64}", cache_key)
created_cache_key_logs = [
record.getMessage() for record in caplog.records if "Created cache key:" in record.getMessage()
record.getMessage()
for record in caplog.records
if "Created cache key:" in record.getMessage()
]
assert created_cache_key_logs
assert all(prompt_marker not in message for message in created_cache_key_logs)
assert any(cache_key in message for message in created_cache_key_logs)
def _embedding_response(prompt_tokens, num_items):
    """Build an EmbeddingResponse with num_items placeholder embeddings.

    The usage reports prompt_tokens as both prompt and total tokens, with
    zero completion tokens.
    """
    embeddings = []
    for position in range(num_items):
        embeddings.append(
            Embedding(embedding=[0.0], index=position, object="embedding")
        )
    token_usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=0,
        total_tokens=prompt_tokens,
    )
    return EmbeddingResponse(
        model="amazon.titan-embed-image-v1",
        data=embeddings,
        usage=token_usage,
    )
def test_get_per_item_prompt_tokens_single_item_returns_full_value():
    """A one-item response keeps its whole prompt_tokens value, even when that value is 0."""
    local_cache = Cache(type=LiteLLMCacheType.LOCAL)
    single_item_response = _embedding_response(prompt_tokens=0, num_items=1)
    stored_tokens = local_cache._get_per_item_prompt_tokens(single_item_response, 0)
    assert stored_tokens == 0
def test_get_per_item_prompt_tokens_distributes_with_remainder():
    """10 tokens over 3 items split as 4 + 3 + 3: the remainder lands on the first index."""
    local_cache = Cache(type=LiteLLMCacheType.LOCAL)
    three_item_response = _embedding_response(prompt_tokens=10, num_items=3)
    shares = []
    for position in range(3):
        shares.append(
            local_cache._get_per_item_prompt_tokens(three_item_response, position)
        )
    # The per-item shares must add back up to the original total.
    assert sum(shares) == 10
    assert shares == [4, 3, 3]

View File

@ -436,3 +436,123 @@ def test_convert_cached_responses_legacy_stream_path():
)
assert isinstance(result, CachedResponsesAPIStreamingIterator)
@pytest.mark.asyncio
async def test_embedding_cache_restores_stored_prompt_tokens_for_image_input():
    """A cache hit for an image embedding reports the stored prompt_tokens of 0.

    The usage must come from the cached entry rather than from tokenizing
    the base64 input, which would produce a bogus count.
    """
    model_name = "amazon.titan-embed-image-v1"
    handler = LLMCachingHandler(
        original_function=MagicMock(),
        request_kwargs={},
        start_time=datetime.now(),
    )

    # base64-like blob: token_counter over this would return a large nonzero count
    base64_blob = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk" * 50
    stored_entry = {
        "embedding": [-0.025, -0.019],
        "index": 0,
        "object": "embedding",
        "model": "amazon.titan-embed-image-v1",
        "prompt_tokens": 0,
        "prompt_tokens_details": {"image_count": 1},
    }

    logging_stub = MagicMock()
    logging_stub.async_success_handler = AsyncMock()

    outcome = handler._process_async_embedding_cached_response(
        final_embedding_cached_response=None,
        cached_result=[stored_entry],
        kwargs={"model": model_name, "input": base64_blob},
        logging_obj=logging_stub,
        start_time=datetime.now(),
        model=model_name,
    )
    response, cache_hit = outcome

    assert cache_hit
    assert response.usage is not None
    reported_usage = response.usage
    assert reported_usage.prompt_tokens == 0
    assert reported_usage.total_tokens == 0
    assert reported_usage.prompt_tokens_details.image_count == 1
@pytest.mark.asyncio
async def test_embedding_cache_sums_stored_prompt_tokens_across_items():
    """A cache hit over several items adds their stored prompt_tokens (5 + 4) into one total."""
    model_name = "text-embedding-3-small"
    handler = LLMCachingHandler(
        original_function=MagicMock(),
        request_kwargs={},
        start_time=datetime.now(),
    )

    first_entry = {
        "embedding": [-0.01],
        "index": 0,
        "object": "embedding",
        "model": "text-embedding-3-small",
        "prompt_tokens": 5,
    }
    second_entry = {
        "embedding": [-0.02],
        "index": 1,
        "object": "embedding",
        "model": "text-embedding-3-small",
        "prompt_tokens": 4,
    }

    logging_stub = MagicMock()
    logging_stub.async_success_handler = AsyncMock()

    outcome = handler._process_async_embedding_cached_response(
        final_embedding_cached_response=None,
        cached_result=[first_entry, second_entry],
        kwargs={"model": model_name, "input": ["hello world", "foo bar"]},
        logging_obj=logging_stub,
        start_time=datetime.now(),
        model=model_name,
    )
    response, cache_hit = outcome

    assert cache_hit
    assert response.usage.prompt_tokens == 9
    assert response.usage.total_tokens == 9
@pytest.mark.asyncio
async def test_embedding_cache_falls_back_to_token_counter_for_legacy_entries():
    """Cached entries written without prompt_tokens still get a recomputed count.

    For str inputs the handler is expected to fall back to token_counter,
    keeping older cache entries usable (backward compatibility).
    """
    model_name = "text-embedding-ada-002"
    handler = LLMCachingHandler(
        original_function=MagicMock(),
        request_kwargs={},
        start_time=datetime.now(),
    )

    # Pre-fix entry: it has no "prompt_tokens" key at all
    legacy_entry = {
        "embedding": [-0.025, -0.019],
        "index": 0,
        "object": "embedding",
        "model": "text-embedding-ada-002",
    }

    logging_stub = MagicMock()
    logging_stub.async_success_handler = AsyncMock()

    outcome = handler._process_async_embedding_cached_response(
        final_embedding_cached_response=None,
        cached_result=[legacy_entry],
        kwargs={"model": model_name, "input": "hello world"},
        logging_obj=logging_stub,
        start_time=datetime.now(),
        model=model_name,
    )
    response, cache_hit = outcome

    assert cache_hit
    # token_counter over "hello world" yields a nonzero count, so the fallback path still runs
    assert response.usage.prompt_tokens > 0