fix(caching): restore stored prompt_tokens on embedding cache hits instead of recomputing (#30046)
This commit is contained in:
parent
e15b37a18e
commit
2fe9feda71
@ -691,6 +691,7 @@ class Cache:
|
|||||||
self,
|
self,
|
||||||
embedding_response: Any,
|
embedding_response: Any,
|
||||||
model: Optional[str],
|
model: Optional[str],
|
||||||
|
prompt_tokens: Optional[int] = None,
|
||||||
prompt_tokens_details: Optional[dict] = None,
|
prompt_tokens_details: Optional[dict] = None,
|
||||||
) -> CachedEmbedding:
|
) -> CachedEmbedding:
|
||||||
"""
|
"""
|
||||||
@ -703,6 +704,7 @@ class Cache:
|
|||||||
"index": embedding_response.get("index"),
|
"index": embedding_response.get("index"),
|
||||||
"object": embedding_response.get("object"),
|
"object": embedding_response.get("object"),
|
||||||
"model": model,
|
"model": model,
|
||||||
|
"prompt_tokens": prompt_tokens,
|
||||||
"prompt_tokens_details": prompt_tokens_details,
|
"prompt_tokens_details": prompt_tokens_details,
|
||||||
}
|
}
|
||||||
elif hasattr(embedding_response, "model_dump"):
|
elif hasattr(embedding_response, "model_dump"):
|
||||||
@ -712,6 +714,7 @@ class Cache:
|
|||||||
"index": data.get("index"),
|
"index": data.get("index"),
|
||||||
"object": data.get("object"),
|
"object": data.get("object"),
|
||||||
"model": model,
|
"model": model,
|
||||||
|
"prompt_tokens": prompt_tokens,
|
||||||
"prompt_tokens_details": prompt_tokens_details,
|
"prompt_tokens_details": prompt_tokens_details,
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
@ -721,6 +724,7 @@ class Cache:
|
|||||||
"index": data.get("index"),
|
"index": data.get("index"),
|
||||||
"object": data.get("object"),
|
"object": data.get("object"),
|
||||||
"model": model,
|
"model": model,
|
||||||
|
"prompt_tokens": prompt_tokens,
|
||||||
"prompt_tokens_details": prompt_tokens_details,
|
"prompt_tokens_details": prompt_tokens_details,
|
||||||
}
|
}
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
@ -769,6 +773,29 @@ class Cache:
|
|||||||
per_item[key] = value
|
per_item[key] = value
|
||||||
return per_item if per_item else None
|
return per_item if per_item else None
|
||||||
|
|
||||||
|
def _get_per_item_prompt_tokens(
|
||||||
|
self,
|
||||||
|
result: EmbeddingResponse,
|
||||||
|
idx_in_result_data: int,
|
||||||
|
) -> Optional[int]:
|
||||||
|
"""
|
||||||
|
Extract the per-item prompt_tokens from a response for caching.
|
||||||
|
|
||||||
|
Single-item responses store the full usage.prompt_tokens. Multi-item
|
||||||
|
responses distribute it evenly (with remainder) so that summing all
|
||||||
|
per-item values on retrieval reconstructs the original total.
|
||||||
|
"""
|
||||||
|
if result.usage is None or result.usage.prompt_tokens is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
total = result.usage.prompt_tokens
|
||||||
|
num_items = len(result.data)
|
||||||
|
if num_items <= 1:
|
||||||
|
return total
|
||||||
|
|
||||||
|
quotient, remainder = divmod(total, num_items)
|
||||||
|
return quotient + (1 if idx_in_result_data < remainder else 0)
|
||||||
|
|
||||||
def add_embedding_response_to_cache(
|
def add_embedding_response_to_cache(
|
||||||
self,
|
self,
|
||||||
result: EmbeddingResponse,
|
result: EmbeddingResponse,
|
||||||
@ -780,7 +807,11 @@ class Cache:
|
|||||||
kwargs["cache_key"] = preset_cache_key
|
kwargs["cache_key"] = preset_cache_key
|
||||||
embedding_response = result.data[idx_in_result_data]
|
embedding_response = result.data[idx_in_result_data]
|
||||||
|
|
||||||
# Extract per-item prompt_tokens_details from response usage
|
# Extract per-item prompt_tokens + details from response usage
|
||||||
|
prompt_tokens = self._get_per_item_prompt_tokens(
|
||||||
|
result=result,
|
||||||
|
idx_in_result_data=idx_in_result_data,
|
||||||
|
)
|
||||||
prompt_tokens_details = self._get_per_item_prompt_tokens_details(
|
prompt_tokens_details = self._get_per_item_prompt_tokens_details(
|
||||||
result=result,
|
result=result,
|
||||||
idx_in_result_data=idx_in_result_data,
|
idx_in_result_data=idx_in_result_data,
|
||||||
@ -791,6 +822,7 @@ class Cache:
|
|||||||
embedding_dict: CachedEmbedding = self._convert_to_cached_embedding(
|
embedding_dict: CachedEmbedding = self._convert_to_cached_embedding(
|
||||||
embedding_response,
|
embedding_response,
|
||||||
model_name,
|
model_name,
|
||||||
|
prompt_tokens=prompt_tokens,
|
||||||
prompt_tokens_details=prompt_tokens_details,
|
prompt_tokens_details=prompt_tokens_details,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@ -394,7 +394,7 @@ class LLMCachingHandler:
|
|||||||
return cr["model"]
|
return cr["model"]
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _process_async_embedding_cached_response(
|
def _process_async_embedding_cached_response( # noqa: PLR0915
|
||||||
self,
|
self,
|
||||||
final_embedding_cached_response: Optional[EmbeddingResponse],
|
final_embedding_cached_response: Optional[EmbeddingResponse],
|
||||||
cached_result: List[Optional[CachedEmbedding]],
|
cached_result: List[Optional[CachedEmbedding]],
|
||||||
@ -456,7 +456,10 @@ class LLMCachingHandler:
|
|||||||
index=idx,
|
index=idx,
|
||||||
object="embedding",
|
object="embedding",
|
||||||
)
|
)
|
||||||
if isinstance(kwargs_input_as_list[idx], str):
|
cached_prompt_tokens = cr.get("prompt_tokens")
|
||||||
|
if cached_prompt_tokens is not None:
|
||||||
|
prompt_tokens += cached_prompt_tokens
|
||||||
|
elif isinstance(kwargs_input_as_list[idx], str):
|
||||||
from litellm.utils import token_counter
|
from litellm.utils import token_counter
|
||||||
|
|
||||||
prompt_tokens += token_counter(
|
prompt_tokens += token_counter(
|
||||||
|
|||||||
@ -118,4 +118,5 @@ class CachedEmbedding(TypedDict):
|
|||||||
index: Optional[int]
|
index: Optional[int]
|
||||||
object: Optional[str]
|
object: Optional[str]
|
||||||
model: Optional[str]
|
model: Optional[str]
|
||||||
|
prompt_tokens: Optional[int]
|
||||||
prompt_tokens_details: Optional[dict]
|
prompt_tokens_details: Optional[dict]
|
||||||
|
|||||||
@ -3,6 +3,7 @@ import re
|
|||||||
|
|
||||||
from litellm.caching.caching import Cache
|
from litellm.caching.caching import Cache
|
||||||
from litellm.types.caching import LiteLLMCacheType
|
from litellm.types.caching import LiteLLMCacheType
|
||||||
|
from litellm.types.utils import Embedding, EmbeddingResponse, Usage
|
||||||
|
|
||||||
|
|
||||||
def test_cache_key_debug_log_does_not_include_prompt_material(caplog):
|
def test_cache_key_debug_log_does_not_include_prompt_material(caplog):
|
||||||
@ -41,8 +42,37 @@ def test_cache_key_debug_log_does_not_include_prompt_material(caplog):
|
|||||||
assert re.fullmatch(r"[0-9a-f]{64}", cache_key)
|
assert re.fullmatch(r"[0-9a-f]{64}", cache_key)
|
||||||
|
|
||||||
created_cache_key_logs = [
|
created_cache_key_logs = [
|
||||||
record.getMessage() for record in caplog.records if "Created cache key:" in record.getMessage()
|
record.getMessage()
|
||||||
|
for record in caplog.records
|
||||||
|
if "Created cache key:" in record.getMessage()
|
||||||
]
|
]
|
||||||
assert created_cache_key_logs
|
assert created_cache_key_logs
|
||||||
assert all(prompt_marker not in message for message in created_cache_key_logs)
|
assert all(prompt_marker not in message for message in created_cache_key_logs)
|
||||||
assert any(cache_key in message for message in created_cache_key_logs)
|
assert any(cache_key in message for message in created_cache_key_logs)
|
||||||
|
|
||||||
|
|
||||||
|
def _embedding_response(prompt_tokens, num_items):
|
||||||
|
return EmbeddingResponse(
|
||||||
|
model="amazon.titan-embed-image-v1",
|
||||||
|
data=[
|
||||||
|
Embedding(embedding=[0.0], index=i, object="embedding")
|
||||||
|
for i in range(num_items)
|
||||||
|
],
|
||||||
|
usage=Usage(
|
||||||
|
prompt_tokens=prompt_tokens, completion_tokens=0, total_tokens=prompt_tokens
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_per_item_prompt_tokens_single_item_returns_full_value():
|
||||||
|
cache = Cache(type=LiteLLMCacheType.LOCAL)
|
||||||
|
result = _embedding_response(prompt_tokens=0, num_items=1)
|
||||||
|
assert cache._get_per_item_prompt_tokens(result, 0) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_per_item_prompt_tokens_distributes_with_remainder():
|
||||||
|
cache = Cache(type=LiteLLMCacheType.LOCAL)
|
||||||
|
result = _embedding_response(prompt_tokens=10, num_items=3)
|
||||||
|
per_item = [cache._get_per_item_prompt_tokens(result, i) for i in range(3)]
|
||||||
|
assert sum(per_item) == 10 # 4 + 3 + 3
|
||||||
|
assert per_item == [4, 3, 3]
|
||||||
|
|||||||
@ -436,3 +436,123 @@ def test_convert_cached_responses_legacy_stream_path():
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert isinstance(result, CachedResponsesAPIStreamingIterator)
|
assert isinstance(result, CachedResponsesAPIStreamingIterator)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_embedding_cache_restores_stored_prompt_tokens_for_image_input():
|
||||||
|
"""Image-embedding cache hit restores prompt_tokens=0 from the stored value
|
||||||
|
instead of recomputing a bogus count by tokenizing the base64 input."""
|
||||||
|
llm_caching_handler = LLMCachingHandler(
|
||||||
|
original_function=MagicMock(),
|
||||||
|
request_kwargs={},
|
||||||
|
start_time=datetime.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# base64-like blob — token_counter over this would return a large nonzero count
|
||||||
|
image_input = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAAAAC0lEQVR42mNk" * 50
|
||||||
|
|
||||||
|
cached_result = [
|
||||||
|
{
|
||||||
|
"embedding": [-0.025, -0.019],
|
||||||
|
"index": 0,
|
||||||
|
"object": "embedding",
|
||||||
|
"model": "amazon.titan-embed-image-v1",
|
||||||
|
"prompt_tokens": 0,
|
||||||
|
"prompt_tokens_details": {"image_count": 1},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_logging_obj = MagicMock()
|
||||||
|
mock_logging_obj.async_success_handler = AsyncMock()
|
||||||
|
response, cache_hit = llm_caching_handler._process_async_embedding_cached_response(
|
||||||
|
final_embedding_cached_response=None,
|
||||||
|
cached_result=cached_result,
|
||||||
|
kwargs={"model": "amazon.titan-embed-image-v1", "input": image_input},
|
||||||
|
logging_obj=mock_logging_obj,
|
||||||
|
start_time=datetime.now(),
|
||||||
|
model="amazon.titan-embed-image-v1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert cache_hit
|
||||||
|
assert response.usage is not None
|
||||||
|
assert response.usage.prompt_tokens == 0
|
||||||
|
assert response.usage.total_tokens == 0
|
||||||
|
assert response.usage.prompt_tokens_details.image_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_embedding_cache_sums_stored_prompt_tokens_across_items():
|
||||||
|
"""A multi-item cache hit sums the stored per-item prompt_tokens back to the total."""
|
||||||
|
llm_caching_handler = LLMCachingHandler(
|
||||||
|
original_function=MagicMock(),
|
||||||
|
request_kwargs={},
|
||||||
|
start_time=datetime.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
cached_result = [
|
||||||
|
{
|
||||||
|
"embedding": [-0.01],
|
||||||
|
"index": 0,
|
||||||
|
"object": "embedding",
|
||||||
|
"model": "text-embedding-3-small",
|
||||||
|
"prompt_tokens": 5,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"embedding": [-0.02],
|
||||||
|
"index": 1,
|
||||||
|
"object": "embedding",
|
||||||
|
"model": "text-embedding-3-small",
|
||||||
|
"prompt_tokens": 4,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_logging_obj = MagicMock()
|
||||||
|
mock_logging_obj.async_success_handler = AsyncMock()
|
||||||
|
response, cache_hit = llm_caching_handler._process_async_embedding_cached_response(
|
||||||
|
final_embedding_cached_response=None,
|
||||||
|
cached_result=cached_result,
|
||||||
|
kwargs={"model": "text-embedding-3-small", "input": ["hello world", "foo bar"]},
|
||||||
|
logging_obj=mock_logging_obj,
|
||||||
|
start_time=datetime.now(),
|
||||||
|
model="text-embedding-3-small",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert cache_hit
|
||||||
|
assert response.usage.prompt_tokens == 9
|
||||||
|
assert response.usage.total_tokens == 9
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_embedding_cache_falls_back_to_token_counter_for_legacy_entries():
|
||||||
|
"""Legacy cache entries with no stored prompt_tokens still recompute via token_counter
|
||||||
|
for str inputs (backward compatibility)."""
|
||||||
|
llm_caching_handler = LLMCachingHandler(
|
||||||
|
original_function=MagicMock(),
|
||||||
|
request_kwargs={},
|
||||||
|
start_time=datetime.now(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# No prompt_tokens key — pre-fix entry
|
||||||
|
cached_result = [
|
||||||
|
{
|
||||||
|
"embedding": [-0.025, -0.019],
|
||||||
|
"index": 0,
|
||||||
|
"object": "embedding",
|
||||||
|
"model": "text-embedding-ada-002",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
mock_logging_obj = MagicMock()
|
||||||
|
mock_logging_obj.async_success_handler = AsyncMock()
|
||||||
|
response, cache_hit = llm_caching_handler._process_async_embedding_cached_response(
|
||||||
|
final_embedding_cached_response=None,
|
||||||
|
cached_result=cached_result,
|
||||||
|
kwargs={"model": "text-embedding-ada-002", "input": "hello world"},
|
||||||
|
logging_obj=mock_logging_obj,
|
||||||
|
start_time=datetime.now(),
|
||||||
|
model="text-embedding-ada-002",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert cache_hit
|
||||||
|
# token_counter over "hello world" yields a nonzero count — fallback path still runs
|
||||||
|
assert response.usage.prompt_tokens > 0
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user