fix(caching): restore stored prompt_tokens on embedding cache hits instead of recomputing (#30046)
This commit is contained in:
parent
e15b37a18e
commit
2fe9feda71
@ -691,6 +691,7 @@ class Cache:
|
||||
self,
|
||||
embedding_response: Any,
|
||||
model: Optional[str],
|
||||
prompt_tokens: Optional[int] = None,
|
||||
prompt_tokens_details: Optional[dict] = None,
|
||||
) -> CachedEmbedding:
|
||||
"""
|
||||
@ -703,6 +704,7 @@ class Cache:
|
||||
"index": embedding_response.get("index"),
|
||||
"object": embedding_response.get("object"),
|
||||
"model": model,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"prompt_tokens_details": prompt_tokens_details,
|
||||
}
|
||||
elif hasattr(embedding_response, "model_dump"):
|
||||
@ -712,6 +714,7 @@ class Cache:
|
||||
"index": data.get("index"),
|
||||
"object": data.get("object"),
|
||||
"model": model,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"prompt_tokens_details": prompt_tokens_details,
|
||||
}
|
||||
else:
|
||||
@ -721,6 +724,7 @@ class Cache:
|
||||
"index": data.get("index"),
|
||||
"object": data.get("object"),
|
||||
"model": model,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"prompt_tokens_details": prompt_tokens_details,
|
||||
}
|
||||
except KeyError as e:
|
||||
@ -769,6 +773,29 @@ class Cache:
|
||||
per_item[key] = value
|
||||
return per_item if per_item else None
|
||||
|
||||
def _get_per_item_prompt_tokens(
    self,
    result: EmbeddingResponse,
    idx_in_result_data: int,
) -> Optional[int]:
    """
    Compute the prompt_tokens value to cache for one item of an embedding response.

    Args:
        result: The embedding response whose ``usage.prompt_tokens`` is split.
        idx_in_result_data: Position of the item inside ``result.data``.

    Returns:
        ``None`` when the response has no usage or no ``prompt_tokens``.
        The full ``usage.prompt_tokens`` when ``result.data`` holds at most
        one item. Otherwise an even share of the total, where the leading
        ``total % len(result.data)`` items receive one extra token so the
        shares of all items add up to the original total.
    """
    usage = result.usage
    if usage is None:
        return None

    total_tokens = usage.prompt_tokens
    if total_tokens is None:
        return None

    item_count = len(result.data)
    if item_count > 1:
        base_share, leftover = divmod(total_tokens, item_count)
        # The first `leftover` items absorb the remainder, one token each.
        if idx_in_result_data < leftover:
            return base_share + 1
        return base_share + 0

    # Zero or one item: nothing to distribute.
    return total_tokens
|
||||
|
||||
def add_embedding_response_to_cache(
|
||||
self,
|
||||
result: EmbeddingResponse,
|
||||
@ -780,7 +807,11 @@ class Cache:
|
||||
kwargs["cache_key"] = preset_cache_key
|
||||
embedding_response = result.data[idx_in_result_data]
|
||||
|
||||
# Extract per-item prompt_tokens_details from response usage
|
||||
# Extract per-item prompt_tokens + details from response usage
|
||||
prompt_tokens = self._get_per_item_prompt_tokens(
|
||||
result=result,
|
||||
idx_in_result_data=idx_in_result_data,
|
||||
)
|
||||
prompt_tokens_details = self._get_per_item_prompt_tokens_details(
|
||||
result=result,
|
||||
idx_in_result_data=idx_in_result_data,
|
||||
@ -791,6 +822,7 @@ class Cache:
|
||||
embedding_dict: CachedEmbedding = self._convert_to_cached_embedding(
|
||||
embedding_response,
|
||||
model_name,
|
||||
prompt_tokens=prompt_tokens,
|
||||
prompt_tokens_details=prompt_tokens_details,
|
||||
)
|
||||
|
||||
|
||||
@ -394,7 +394,7 @@ class LLMCachingHandler:
|
||||
return cr["model"]
|
||||
return None
|
||||
|
||||
def _process_async_embedding_cached_response(
|
||||
def _process_async_embedding_cached_response( # noqa: PLR0915
|
||||
self,
|
||||
final_embedding_cached_response: Optional[EmbeddingResponse],
|
||||
cached_result: List[Optional[CachedEmbedding]],
|
||||
@ -456,7 +456,10 @@ class LLMCachingHandler:
|
||||
index=idx,
|
||||
object="embedding",
|
||||
)
|
||||
if isinstance(kwargs_input_as_list[idx], str):
|
||||
cached_prompt_tokens = cr.get("prompt_tokens")
|
||||
if cached_prompt_tokens is not None:
|
||||
prompt_tokens += cached_prompt_tokens
|
||||
elif isinstance(kwargs_input_as_list[idx], str):
|
||||
from litellm.utils import token_counter
|
||||
|
||||
prompt_tokens += token_counter(
|
||||
|
||||
@ -118,4 +118,5 @@ class CachedEmbedding(TypedDict):
|
||||
index: Optional[int]
|
||||
object: Optional[str]
|
||||
model: Optional[str]
|
||||
prompt_tokens: Optional[int]
|
||||
prompt_tokens_details: Optional[dict]
|
||||
|
||||
@ -3,6 +3,7 @@ import re
|
||||
|
||||
from litellm.caching.caching import Cache
|
||||
from litellm.types.caching import LiteLLMCacheType
|
||||
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
|
||||
|
||||
|
||||
def test_cache_key_debug_log_does_not_include_prompt_material(caplog):
|
||||
@ -41,8 +42,37 @@ def test_cache_key_debug_log_does_not_include_prompt_material(caplog):
|
||||
assert re.fullmatch(r"[0-9a-f]{64}", cache_key)
|
||||
|
||||
created_cache_key_logs = [
|
||||
record.getMessage() for record in caplog.records if "Created cache key:" in record.getMessage()
|
||||
record.getMessage()
|
||||
for record in caplog.records
|
||||
if "Created cache key:" in record.getMessage()
|
||||
]
|
||||
assert created_cache_key_logs
|
||||
assert all(prompt_marker not in message for message in created_cache_key_logs)
|
||||
assert any(cache_key in message for message in created_cache_key_logs)
|
||||
|
||||
|
||||
def _embedding_response(prompt_tokens, num_items):
    """Build an EmbeddingResponse with ``num_items`` placeholder embeddings.

    Every embedding is ``[0.0]`` with its position as ``index``; the usage
    reports ``prompt_tokens`` as both prompt and total tokens, with zero
    completion tokens.
    """
    items = []
    for position in range(num_items):
        items.append(Embedding(embedding=[0.0], index=position, object="embedding"))

    usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=0,
        total_tokens=prompt_tokens,
    )
    return EmbeddingResponse(
        model="amazon.titan-embed-image-v1",
        data=items,
        usage=usage,
    )
|
||||
|
||||
|
||||
def test_get_per_item_prompt_tokens_single_item_returns_full_value():
    """A one-item response yields usage.prompt_tokens unchanged (0 stays 0, not None)."""
    single_item_response = _embedding_response(prompt_tokens=0, num_items=1)
    local_cache = Cache(type=LiteLLMCacheType.LOCAL)

    stored_tokens = local_cache._get_per_item_prompt_tokens(single_item_response, 0)

    assert stored_tokens == 0
|
||||
|
||||
|
||||
def test_get_per_item_prompt_tokens_distributes_with_remainder():
    """10 tokens over 3 items split as 4/3/3, so the shares sum back to 10."""
    local_cache = Cache(type=LiteLLMCacheType.LOCAL)
    three_item_response = _embedding_response(prompt_tokens=10, num_items=3)

    shares = []
    for item_idx in range(3):
        shares.append(
            local_cache._get_per_item_prompt_tokens(three_item_response, item_idx)
        )

    assert sum(shares) == 10  # 4 + 3 + 3
    assert shares == [4, 3, 3]
|
||||
|
||||
@ -436,3 +436,123 @@ def test_convert_cached_responses_legacy_stream_path():
|
||||
)
|
||||
|
||||
assert isinstance(result, CachedResponsesAPIStreamingIterator)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_embedding_cache_restores_stored_prompt_tokens_for_image_input():
    """On an image-embedding cache hit, usage is rebuilt from the stored
    prompt_tokens (0 here) rather than by tokenizing the base64 input."""
    model_name = "amazon.titan-embed-image-v1"

    # base64-like blob — token_counter over this would return a large nonzero count
    image_input = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk" * 50

    stored_entry = {
        "embedding": [-0.025, -0.019],
        "index": 0,
        "object": "embedding",
        "model": "amazon.titan-embed-image-v1",
        "prompt_tokens": 0,
        "prompt_tokens_details": {"image_count": 1},
    }

    handler = LLMCachingHandler(
        original_function=MagicMock(),
        request_kwargs={},
        start_time=datetime.now(),
    )

    logging_stub = MagicMock()
    logging_stub.async_success_handler = AsyncMock()

    response, cache_hit = handler._process_async_embedding_cached_response(
        final_embedding_cached_response=None,
        cached_result=[stored_entry],
        kwargs={"model": model_name, "input": image_input},
        logging_obj=logging_stub,
        start_time=datetime.now(),
        model=model_name,
    )

    assert cache_hit
    assert response.usage is not None
    usage = response.usage
    assert usage.prompt_tokens == 0
    assert usage.total_tokens == 0
    assert usage.prompt_tokens_details.image_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_embedding_cache_sums_stored_prompt_tokens_across_items():
    """With every item served from cache, usage.prompt_tokens is the sum of the stored per-item values."""
    model_name = "text-embedding-3-small"

    # (embedding value, stored prompt_tokens) for each cached item, in index order
    stored_items = [(-0.01, 5), (-0.02, 4)]
    cached_entries = [
        {
            "embedding": [vector_value],
            "index": position,
            "object": "embedding",
            "model": "text-embedding-3-small",
            "prompt_tokens": stored_tokens,
        }
        for position, (vector_value, stored_tokens) in enumerate(stored_items)
    ]

    handler = LLMCachingHandler(
        original_function=MagicMock(),
        request_kwargs={},
        start_time=datetime.now(),
    )

    logging_stub = MagicMock()
    logging_stub.async_success_handler = AsyncMock()

    response, cache_hit = handler._process_async_embedding_cached_response(
        final_embedding_cached_response=None,
        cached_result=cached_entries,
        kwargs={"model": model_name, "input": ["hello world", "foo bar"]},
        logging_obj=logging_stub,
        start_time=datetime.now(),
        model=model_name,
    )

    assert cache_hit
    assert response.usage.prompt_tokens == 9
    assert response.usage.total_tokens == 9
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_embedding_cache_falls_back_to_token_counter_for_legacy_entries():
    """A cached entry lacking prompt_tokens (written before the fix) still gets a
    recomputed, nonzero count for a str input (backward compatibility)."""
    model_name = "text-embedding-ada-002"

    # No prompt_tokens key — pre-fix entry
    legacy_entry = {
        "embedding": [-0.025, -0.019],
        "index": 0,
        "object": "embedding",
        "model": "text-embedding-ada-002",
    }

    handler = LLMCachingHandler(
        original_function=MagicMock(),
        request_kwargs={},
        start_time=datetime.now(),
    )

    logging_stub = MagicMock()
    logging_stub.async_success_handler = AsyncMock()

    response, cache_hit = handler._process_async_embedding_cached_response(
        final_embedding_cached_response=None,
        cached_result=[legacy_entry],
        kwargs={"model": model_name, "input": "hello world"},
        logging_obj=logging_stub,
        start_time=datetime.now(),
        model=model_name,
    )

    assert cache_hit
    # token_counter over "hello world" yields a nonzero count — fallback path still runs
    assert response.usage.prompt_tokens > 0
|
||||
|
||||
Loading…
Reference in New Issue
Block a user