Merge pull request #26270 from BerriAI/litellm_/lucid-kowalevski-de832f

[Fix] Stabilize flaky spend accuracy tests + patch Redis buffer data-loss path
This commit is contained in:
shin-berri 2026-04-22 15:02:24 -07:00 committed by GitHub
commit b6fdd46636
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 366 additions and 146 deletions

View File

@ -1802,6 +1802,7 @@ jobs:
-e DD_API_KEY=$DD_API_KEY \
-e DD_SITE=$DD_SITE \
-e AWS_REGION_NAME=$AWS_REGION_NAME \
-e PROXY_BATCH_WRITE_AT=2 \
--add-host host.docker.internal:host-gateway \
--name my-app \
-v $(pwd)/litellm/proxy/example_config_yaml/spend_tracking_config.yaml:/app/config.yaml \

View File

@ -22,6 +22,7 @@ from litellm.constants import (
)
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy._types import (
BaseDailySpendTransaction,
DailyAgentSpendTransaction,
DailyEndUserSpendTransaction,
DailyOrganizationSpendTransaction,
@ -29,6 +30,8 @@ from litellm.proxy._types import (
DailyTeamSpendTransaction,
DailyUserSpendTransaction,
DBSpendUpdateTransactions,
Litellm_EntityType,
SpendUpdateQueueItem,
)
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import (
@ -259,9 +262,36 @@ class RedisUpdateBuffer:
if len(rpush_list) == 0:
return
result_lengths = await self.redis_cache.async_rpush_pipeline(
rpush_list=rpush_list,
)
try:
result_lengths = await self.redis_cache.async_rpush_pipeline(
rpush_list=rpush_list,
)
except Exception as e:
# The in-memory queues were already drained above. If we let the
# exception propagate without restoring, the aggregated spend is
# permanently lost. Re-enqueue so the next scheduler tick retries.
verbose_proxy_logger.error(
"Spend tracking - failed to push aggregated spend updates to Redis. "
"Restoring %d transaction sets to in-memory queues for retry on next tick. "
"Error: %s",
len(rpush_list),
str(e),
)
await self._restore_spend_updates_to_in_memory_queues(
db_spend_update_transactions=db_spend_update_transactions,
daily_spend_update_transactions=daily_spend_update_transactions,
daily_team_spend_update_transactions=daily_team_spend_update_transactions,
daily_org_spend_update_transactions=daily_org_spend_update_transactions,
daily_end_user_spend_update_transactions=daily_end_user_spend_update_transactions,
daily_agent_spend_update_transactions=daily_agent_spend_update_transactions,
spend_update_queue=spend_update_queue,
daily_spend_update_queue=daily_spend_update_queue,
daily_team_spend_update_queue=daily_team_spend_update_queue,
daily_org_spend_update_queue=daily_org_spend_update_queue,
daily_end_user_spend_update_queue=daily_end_user_spend_update_queue,
daily_agent_spend_update_queue=daily_agent_spend_update_queue,
)
return
# Emit gauge events for each queue
for i, queue_size in enumerate(result_lengths):
@ -271,6 +301,101 @@ class RedisUpdateBuffer:
service=service_types[i],
)
@staticmethod
async def _restore_spend_updates_to_in_memory_queues(
db_spend_update_transactions: Optional[DBSpendUpdateTransactions],
daily_spend_update_transactions: Optional[Dict[str, BaseDailySpendTransaction]],
daily_team_spend_update_transactions: Optional[
Dict[str, BaseDailySpendTransaction]
],
daily_org_spend_update_transactions: Optional[
Dict[str, BaseDailySpendTransaction]
],
daily_end_user_spend_update_transactions: Optional[
Dict[str, BaseDailySpendTransaction]
],
daily_agent_spend_update_transactions: Optional[
Dict[str, BaseDailySpendTransaction]
],
spend_update_queue: SpendUpdateQueue,
daily_spend_update_queue: DailySpendUpdateQueue,
daily_team_spend_update_queue: DailySpendUpdateQueue,
daily_org_spend_update_queue: DailySpendUpdateQueue,
daily_end_user_spend_update_queue: DailySpendUpdateQueue,
daily_agent_spend_update_queue: DailySpendUpdateQueue,
) -> None:
"""
Put drained-but-unpushed transactions back into in-memory queues.
Called when the Redis rpush pipeline raises. Without this, all spend
data aggregated during the current scheduler tick is permanently lost
because the source queues were already drained before the rpush.
"""
if db_spend_update_transactions is not None:
entity_entries: List[
Tuple[Litellm_EntityType, Optional[Dict[str, float]]]
] = [
(
Litellm_EntityType.USER,
db_spend_update_transactions.get("user_list_transactions"),
),
(
Litellm_EntityType.END_USER,
db_spend_update_transactions.get("end_user_list_transactions"),
),
(
Litellm_EntityType.KEY,
db_spend_update_transactions.get("key_list_transactions"),
),
(
Litellm_EntityType.TEAM,
db_spend_update_transactions.get("team_list_transactions"),
),
(
Litellm_EntityType.TEAM_MEMBER,
db_spend_update_transactions.get("team_member_list_transactions"),
),
(
Litellm_EntityType.ORGANIZATION,
db_spend_update_transactions.get("org_list_transactions"),
),
(
Litellm_EntityType.TAG,
db_spend_update_transactions.get("tag_list_transactions"),
),
(
Litellm_EntityType.AGENT,
db_spend_update_transactions.get("agent_list_transactions"),
),
]
for entity_type, entities in entity_entries:
if not entities:
continue
for entity_id, cost in entities.items():
await spend_update_queue.add_update(
SpendUpdateQueueItem(
entity_type=entity_type,
entity_id=entity_id,
response_cost=cost,
)
)
daily_pairs: List[
Tuple[Optional[Dict[str, BaseDailySpendTransaction]], DailySpendUpdateQueue]
] = [
(daily_spend_update_transactions, daily_spend_update_queue),
(daily_team_spend_update_transactions, daily_team_spend_update_queue),
(daily_org_spend_update_transactions, daily_org_spend_update_queue),
(
daily_end_user_spend_update_transactions,
daily_end_user_spend_update_queue,
),
(daily_agent_spend_update_transactions, daily_agent_spend_update_queue),
]
for daily_txns, daily_queue in daily_pairs:
if daily_txns:
await daily_queue.update_queue.put(daily_txns)
@staticmethod
def _number_of_transactions_to_store_in_redis(
db_spend_update_transactions: DBSpendUpdateTransactions,

View File

@ -1,10 +1,9 @@
import pytest
import asyncio
import aiohttp
import json
import time
from httpx import AsyncClient
from typing import Any, Optional
import litellm
from litellm._uuid import uuid
"""
@ -12,15 +11,13 @@ Tests to run
Basic Tests:
1. Basic Spend Accuracy Test:
- Make 1 calibration request, poll for spend to derive SPEND_PER_REQUEST
- Make N-1 more requests (N total)
- Expect the spend for each of the following to be N * SPEND_PER_REQUEST
Key, Team, User, Org (call /info endpoint for each object to validate)
- Make N requests, compute expected total spend locally from each response's usage
- Poll until batch writer has flushed spend to the DB
- Expect spend for Key, Team, User, Org (/info endpoints) to equal the computed total
2. Long term spend accuracy test (with 2 bursts of requests)
- Burst 1: Make requests, derive SPEND_PER_REQUEST from first request
- Burst 2: Make more requests
- Verify total spend = (burst1 + burst2) * SPEND_PER_REQUEST
- Burst 1: compute expected from responses, verify
- Burst 2: compute expected from responses, verify total = burst1 + burst2
Additional Test Scenarios:
@ -38,6 +35,18 @@ Additional Test Scenarios:
- Verify accurate total spend calculation
"""
# Upstream model the proxy is configured with (spend_tracking_config.yaml).
# The proxy computes spend using this model's pricing; the local ground-truth
# calculation uses the same pricing table via litellm.cost_per_token.
UPSTREAM_MODEL = "gpt-3.5-turbo"
# Batch writer flush cadence in CI is ~2-7s (PROXY_BATCH_WRITE_AT=2 + up to 5s jitter).
# Poll every 2s for 60s — plenty of headroom for multiple ticks to land.
POLL_INTERVAL_SECONDS = 2
POLL_TIMEOUT_SECONDS = 60
TOLERANCE = 1e-10
async def create_organization(session, organization_alias: str):
"""Helper function to create a new organization"""
@ -102,118 +111,135 @@ async def get_spend_info(session, entity_type: str, entity_id: str):
return await response.json()
async def poll_key_spend_until_nonzero(
session, key: str, timeout: int = 120, interval: int = 10
):
"""Poll key spend until it becomes non-zero or timeout is reached."""
async def get_proxy_readiness(session):
"""Fetch /health/readiness. Used both as a fail-fast gate and as a diagnostic on poll timeout."""
url = "http://0.0.0.0:4000/health/readiness"
headers = {"Authorization": "Bearer sk-1234"}
async with session.get(url, headers=headers) as response:
return response.status, await response.json()
async def assert_proxy_healthy(session):
"""Fail fast if the proxy's DB or cache is not reachable — no point running the test."""
status, body = await get_proxy_readiness(session)
if status != 200 or body.get("db") != "connected":
pytest.fail(
f"Proxy /health/readiness unhealthy (status={status}). "
f"Cannot run spend accuracy test. Response: {body}"
)
print(f"Proxy readiness OK: {body}")
def compute_expected_spend(responses) -> float:
"""
Compute the expected total spend locally from each response's usage tokens,
using the same pricing table the proxy uses. This is the independent ground
truth we compare the proxy's reported spend against.
"""
total = 0.0
for r in responses:
usage = r.usage
prompt_cost, completion_cost = litellm.cost_per_token(
model=UPSTREAM_MODEL,
prompt_tokens=usage.prompt_tokens,
completion_tokens=usage.completion_tokens,
)
total += prompt_cost + completion_cost
return total
async def poll_key_spend_until(session, key: str, expected: float) -> float:
"""
Poll key spend until it matches `expected` within TOLERANCE, or timeout.
Returns the last observed spend either way; caller decides how to report.
"""
start = time.time()
while time.time() - start < timeout:
last_spend = 0.0
while time.time() - start < POLL_TIMEOUT_SECONDS:
key_info = await get_spend_info(session, "key", key)
spend = key_info["info"]["spend"]
if spend > 0:
last_spend = key_info["info"]["spend"]
if abs(last_spend - expected) < TOLERANCE:
print(
f"Key spend became non-zero ({spend}) after {time.time() - start:.1f}s"
f"Key spend reached expected {expected} after {time.time() - start:.1f}s"
)
return spend
print(f"Key spend still 0.0, waiting... ({time.time() - start:.1f}s elapsed)")
await asyncio.sleep(interval)
raise TimeoutError(
f"Key spend remained 0.0 after {timeout}s — batch writer may not be running"
return last_spend
print(
f"Key spend {last_spend}, expected {expected}, waiting... "
f"({time.time() - start:.1f}s elapsed)"
)
await asyncio.sleep(POLL_INTERVAL_SECONDS)
return last_spend
async def fail_with_diagnostics(session, stage: str, expected: float, observed: float):
"""Emit a failure with readiness state so CI output points at the real cause."""
_, readiness = await get_proxy_readiness(session)
pytest.fail(
f"{stage}: key spend did not match expected after {POLL_TIMEOUT_SECONDS}s poll. "
f"expected={expected}, observed={observed}, diff={expected - observed}. "
f"Proxy readiness: {readiness}"
)
async def calibrate_spend_per_request(session, key: str, max_retries: int = 5):
"""
Make a single calibration request and poll for its spend to derive SPEND_PER_REQUEST.
Fails fast with pytest.fail() if spend cannot be determined.
"""
response = await chat_completion(session, key)
print(f"Calibration request completed: {response}")
for attempt in range(1, max_retries + 1):
try:
spend = await poll_key_spend_until_nonzero(
session, key, timeout=120, interval=10
)
print(
f"Calibrated SPEND_PER_REQUEST = {spend} "
f"(attempt {attempt}/{max_retries})"
)
return spend
except TimeoutError:
if attempt < max_retries:
print(
f"Calibration attempt {attempt}/{max_retries} timed out, retrying..."
)
else:
pytest.fail(
f"Failed to calibrate SPEND_PER_REQUEST after {max_retries} attempts. "
"The batch writer may not be running or the model may have 0 cost."
)
@pytest.mark.asyncio
async def test_basic_spend_accuracy():
"""
Test basic spend accuracy across different entities:
1. Create org, team, user, and key
2. Make 1 calibration request to derive SPEND_PER_REQUEST
3. Make remaining requests (NUM_LLM_REQUESTS total)
4. Verify spend accuracy for key, team, user, and org
2. Make N requests, keeping each response
3. Compute expected spend locally from response usage (independent ground truth)
4. Poll until proxy-reported spend matches expected
5. Verify spend is consistent across key, team, user, and org entities
"""
NUM_LLM_REQUESTS = 20
TOLERANCE = 1e-10
async with aiohttp.ClientSession() as session:
# Create organization
await assert_proxy_healthy(session)
org_response = await create_organization(
session=session, organization_alias=f"test-org-{uuid.uuid4()}"
)
print("org_response: ", org_response)
org_id = org_response["organization_id"]
# Create team under organization
team_response = await create_team(session, org_id)
print("team_response: ", team_response)
team_id = team_response["team_id"]
# Create user
user_response = await create_user(session, org_id)
print("user_response: ", user_response)
user_id = user_response["user_id"]
# Generate key
key_response = await generate_key(session, user_id, team_id)
print("key_response: ", key_response)
key = key_response["key"]
# Calibrate: make 1 request and derive SPEND_PER_REQUEST
spend_per_request = await calibrate_spend_per_request(session, key)
expected_spend = NUM_LLM_REQUESTS * spend_per_request
print(f"SPEND_PER_REQUEST={spend_per_request}, expected_spend={expected_spend}")
# Make remaining requests (1 already made during calibration)
for i in range(NUM_LLM_REQUESTS - 1):
responses = []
for i in range(NUM_LLM_REQUESTS):
response = await chat_completion(session, key)
print(f"Request {i + 2}/{NUM_LLM_REQUESTS} completed")
responses.append(response)
print(f"Request {i + 1}/{NUM_LLM_REQUESTS} completed")
# Poll until batch writer has flushed all spend
start = time.time()
while time.time() - start < 120:
key_info = await get_spend_info(session, "key", key)
current_spend = key_info["info"]["spend"]
if abs(current_spend - expected_spend) < TOLERANCE:
print(
f"Key spend reached expected {expected_spend} after {time.time() - start:.1f}s"
)
break
print(f"Key spend {current_spend}, expected {expected_spend}, waiting...")
await asyncio.sleep(10)
expected_spend = compute_expected_spend(responses)
assert expected_spend > 0, (
f"Locally computed expected spend is {expected_spend}. Either cost calc "
f"is broken or upstream returned zero tokens. "
f"Usage: {[r.usage.model_dump() for r in responses]}"
)
print(f"Expected total spend (local ground truth): {expected_spend}")
# Allow extra time for all entity spend aggregations to complete
final_spend = await poll_key_spend_until(session, key, expected_spend)
if abs(final_spend - expected_spend) >= TOLERANCE:
await fail_with_diagnostics(
session,
stage="test_basic_spend_accuracy",
expected=expected_spend,
observed=final_spend,
)
# Allow a final scheduler tick for team/user/org aggregations to settle
await asyncio.sleep(5)
# Get spend information for each entity
key_info = await get_spend_info(session, "key", key)
print("key_info: ", key_info)
team_info = await get_spend_info(session, "team", team_id)
@ -223,7 +249,6 @@ async def test_basic_spend_accuracy():
org_info = await get_spend_info(session, "organization", org_id)
print("org_info: ", org_info)
# Verify spend for each entity
assert (
abs(key_info["info"]["spend"] - expected_spend) < TOLERANCE
), f"Key spend {key_info['info']['spend']} does not match expected {expected_spend}"
@ -246,91 +271,78 @@ async def test_long_term_spend_accuracy_with_bursts():
"""
Test long-term spend accuracy with multiple bursts of requests:
1. Create org, team, user, and key
2. Calibrate SPEND_PER_REQUEST from first request
3. Burst 1: Make remaining requests
4. Burst 2: Make more requests
5. Verify the total spend is tracked accurately across all entities
2. Burst 1: make requests, compute expected locally, verify proxy matches
3. Burst 2: make more requests, verify proxy total == burst1 + burst2
4. Verify total spend is consistent across all entities
"""
BURST_1_REQUESTS = 22
BURST_2_REQUESTS = 12
TOTAL_REQUESTS = BURST_1_REQUESTS + BURST_2_REQUESTS
TOLERANCE = 1e-10
async with aiohttp.ClientSession() as session:
# Create organization
await assert_proxy_healthy(session)
org_response = await create_organization(
session=session, organization_alias=f"test-org-{uuid.uuid4()}"
)
print("org_response: ", org_response)
org_id = org_response["organization_id"]
# Create team under organization
team_response = await create_team(session, org_id)
print("team_response: ", team_response)
team_id = team_response["team_id"]
# Create user
user_response = await create_user(session, org_id)
print("user_response: ", user_response)
user_id = user_response["user_id"]
# Generate key
key_response = await generate_key(session, user_id, team_id)
print("key_response: ", key_response)
key = key_response["key"]
# Calibrate: make 1 request and derive SPEND_PER_REQUEST
spend_per_request = await calibrate_spend_per_request(session, key)
expected_spend = TOTAL_REQUESTS * spend_per_request
print(f"SPEND_PER_REQUEST={spend_per_request}, expected_spend={expected_spend}")
# First burst: remaining requests (1 already made during calibration)
print(f"Starting first burst ({BURST_1_REQUESTS - 1} remaining requests)...")
for i in range(BURST_1_REQUESTS - 1):
print(f"Starting first burst of {BURST_1_REQUESTS} requests...")
burst_1_responses = []
for i in range(BURST_1_REQUESTS):
response = await chat_completion(session, key)
print(f"Burst 1 - Request {i + 2}/{BURST_1_REQUESTS} completed")
burst_1_responses.append(response)
print(f"Burst 1 - Request {i + 1}/{BURST_1_REQUESTS} completed")
# Poll until batch writer has flushed burst 1 spend
burst_1_expected = BURST_1_REQUESTS * spend_per_request
start = time.time()
while time.time() - start < 120:
key_info_check = await get_spend_info(session, "key", key)
current_spend = key_info_check["info"]["spend"]
if abs(current_spend - burst_1_expected) < TOLERANCE:
print(
f"Burst 1 spend reached expected {burst_1_expected} after {time.time() - start:.1f}s"
)
break
print(f"Key spend {current_spend}, expected {burst_1_expected}, waiting...")
await asyncio.sleep(10)
burst_1_expected = compute_expected_spend(burst_1_responses)
assert burst_1_expected > 0, (
f"Burst 1 expected spend is {burst_1_expected}. "
f"Usage: {[r.usage.model_dump() for r in burst_1_responses]}"
)
print(f"Burst 1 expected spend: {burst_1_expected}")
# Check intermediate spend
intermediate_key_info = await get_spend_info(session, "key", key)
print(f"After Burst 1 - Key spend: {intermediate_key_info['info']['spend']}")
final_burst_1 = await poll_key_spend_until(session, key, burst_1_expected)
if abs(final_burst_1 - burst_1_expected) >= TOLERANCE:
await fail_with_diagnostics(
session,
stage="test_long_term_spend_accuracy burst 1",
expected=burst_1_expected,
observed=final_burst_1,
)
# Second burst
print(f"Starting second burst of {BURST_2_REQUESTS} requests...")
burst_2_responses = []
for i in range(BURST_2_REQUESTS):
response = await chat_completion(session, key)
burst_2_responses.append(response)
print(f"Burst 2 - Request {i + 1}/{BURST_2_REQUESTS} completed")
# Poll until key spend reaches expected total (burst 1 + burst 2)
start = time.time()
while time.time() - start < 120:
key_info_check = await get_spend_info(session, "key", key)
current_spend = key_info_check["info"]["spend"]
if abs(current_spend - expected_spend) < TOLERANCE:
print(
f"Total spend reached expected {expected_spend} after {time.time() - start:.1f}s"
)
break
print(f"Key spend {current_spend}, expected {expected_spend}, waiting...")
await asyncio.sleep(10)
total_expected = burst_1_expected + compute_expected_spend(burst_2_responses)
print(f"Total expected spend (burst 1 + burst 2): {total_expected}")
final_total = await poll_key_spend_until(session, key, total_expected)
if abs(final_total - total_expected) >= TOLERANCE:
await fail_with_diagnostics(
session,
stage="test_long_term_spend_accuracy total",
expected=total_expected,
observed=final_total,
)
# Allow extra time for all entity spend aggregations
await asyncio.sleep(5)
# Get final spend information for each entity
key_info = await get_spend_info(session, "key", key)
team_info = await get_spend_info(session, "team", team_id)
user_info = await get_spend_info(session, "user", user_id)
@ -341,19 +353,18 @@ async def test_long_term_spend_accuracy_with_bursts():
print(f"Final user spend: {user_info['user_info']['spend']}")
print(f"Final org spend: {org_info['spend']}")
# Verify total spend for each entity
assert (
abs(key_info["info"]["spend"] - expected_spend) < TOLERANCE
), f"Key spend {key_info['info']['spend']} does not match expected {expected_spend}"
abs(key_info["info"]["spend"] - total_expected) < TOLERANCE
), f"Key spend {key_info['info']['spend']} does not match expected {total_expected}"
assert (
abs(user_info["user_info"]["spend"] - expected_spend) < TOLERANCE
), f"User spend {user_info['user_info']['spend']} does not match expected {expected_spend}"
abs(user_info["user_info"]["spend"] - total_expected) < TOLERANCE
), f"User spend {user_info['user_info']['spend']} does not match expected {total_expected}"
assert (
abs(team_info["team_info"]["spend"] - expected_spend) < TOLERANCE
), f"Team spend {team_info['team_info']['spend']} does not match expected {expected_spend}"
abs(team_info["team_info"]["spend"] - total_expected) < TOLERANCE
), f"Team spend {team_info['team_info']['spend']} does not match expected {total_expected}"
assert (
abs(org_info["spend"] - expected_spend) < TOLERANCE
), f"Organization spend {org_info['spend']} does not match expected {expected_spend}"
abs(org_info["spend"] - total_expected) < TOLERANCE
), f"Organization spend {org_info['spend']} does not match expected {total_expected}"

View File

@ -87,6 +87,89 @@ async def test_store_in_memory_spend_updates_uses_pipeline(
assert len(rpush_list) == 3
@pytest.mark.asyncio
async def test_store_in_memory_spend_updates_restores_on_rpush_failure(
redis_update_buffer, mock_redis_cache
):
"""
If async_rpush_pipeline raises, the already-drained transactions must be
put back into the in-memory queues so the next scheduler tick retries.
Without this, any transient Redis hiccup silently loses spend data.
"""
from litellm.proxy._types import Litellm_EntityType
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import (
DailySpendUpdateQueue,
)
from litellm.proxy.db.db_transaction_queue.spend_update_queue import (
SpendUpdateQueue,
)
mock_redis_cache.async_rpush_pipeline = AsyncMock(
side_effect=ConnectionError("redis went away")
)
spend_queue = SpendUpdateQueue()
daily_user_queue = DailySpendUpdateQueue()
daily_team_queue = DailySpendUpdateQueue()
daily_org_queue = DailySpendUpdateQueue()
daily_end_user_queue = DailySpendUpdateQueue()
daily_agent_queue = DailySpendUpdateQueue()
# Seed real queues with data so flush_and_get_aggregated returns it
await spend_queue.add_update(
{
"entity_type": Litellm_EntityType.KEY,
"entity_id": "key-abc",
"response_cost": 1.5,
}
)
await spend_queue.add_update(
{
"entity_type": Litellm_EntityType.TEAM,
"entity_id": "team-xyz",
"response_cost": 2.5,
}
)
await daily_user_queue.add_update(
{
"user1_day_model": {
"spend": 1.0,
"prompt_tokens": 10,
"completion_tokens": 20,
}
}
)
await redis_update_buffer.store_in_memory_spend_updates_in_redis(
spend_update_queue=spend_queue,
daily_spend_update_queue=daily_user_queue,
daily_team_spend_update_queue=daily_team_queue,
daily_org_spend_update_queue=daily_org_queue,
daily_end_user_spend_update_queue=daily_end_user_queue,
daily_agent_spend_update_queue=daily_agent_queue,
)
# After restore, the main spend queue should hold one item per
# (entity_type, entity_id) pair with the aggregated cost
restored_spend = (
await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()
)
assert restored_spend["key_list_transactions"] == {"key-abc": 1.5}
assert restored_spend["team_list_transactions"] == {"team-xyz": 2.5}
# Daily user queue should hold the same aggregated dict
restored_daily = (
await daily_user_queue.flush_and_get_aggregated_daily_spend_update_transactions()
)
assert restored_daily == {
"user1_day_model": {
"spend": 1.0,
"prompt_tokens": 10,
"completion_tokens": 20,
}
}
@pytest.mark.asyncio
async def test_store_in_memory_spend_updates_all_empty_returns_early(
redis_update_buffer, mock_redis_cache