diff --git a/.circleci/config.yml b/.circleci/config.yml index ba9d452000..0a59b7ef0d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1802,6 +1802,7 @@ jobs: -e DD_API_KEY=$DD_API_KEY \ -e DD_SITE=$DD_SITE \ -e AWS_REGION_NAME=$AWS_REGION_NAME \ + -e PROXY_BATCH_WRITE_AT=2 \ --add-host host.docker.internal:host-gateway \ --name my-app \ -v $(pwd)/litellm/proxy/example_config_yaml/spend_tracking_config.yaml:/app/config.yaml \ diff --git a/litellm/proxy/db/db_transaction_queue/redis_update_buffer.py b/litellm/proxy/db/db_transaction_queue/redis_update_buffer.py index b8537c2be9..1e3014dbf3 100644 --- a/litellm/proxy/db/db_transaction_queue/redis_update_buffer.py +++ b/litellm/proxy/db/db_transaction_queue/redis_update_buffer.py @@ -22,6 +22,7 @@ from litellm.constants import ( ) from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.proxy._types import ( + BaseDailySpendTransaction, DailyAgentSpendTransaction, DailyEndUserSpendTransaction, DailyOrganizationSpendTransaction, @@ -29,6 +30,8 @@ from litellm.proxy._types import ( DailyTeamSpendTransaction, DailyUserSpendTransaction, DBSpendUpdateTransactions, + Litellm_EntityType, + SpendUpdateQueueItem, ) from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import ( @@ -259,9 +262,36 @@ class RedisUpdateBuffer: if len(rpush_list) == 0: return - result_lengths = await self.redis_cache.async_rpush_pipeline( - rpush_list=rpush_list, - ) + try: + result_lengths = await self.redis_cache.async_rpush_pipeline( + rpush_list=rpush_list, + ) + except Exception as e: + # The in-memory queues were already drained above. If we let the + # exception propagate without restoring, the aggregated spend is + # permanently lost. Re-enqueue so the next scheduler tick retries. + verbose_proxy_logger.error( + "Spend tracking - failed to push aggregated spend updates to Redis. " + "Restoring %d transaction sets to in-memory queues for retry on next tick. " + "Error: %s", + len(rpush_list), + str(e), + ) + await self._restore_spend_updates_to_in_memory_queues( + db_spend_update_transactions=db_spend_update_transactions, + daily_spend_update_transactions=daily_spend_update_transactions, + daily_team_spend_update_transactions=daily_team_spend_update_transactions, + daily_org_spend_update_transactions=daily_org_spend_update_transactions, + daily_end_user_spend_update_transactions=daily_end_user_spend_update_transactions, + daily_agent_spend_update_transactions=daily_agent_spend_update_transactions, + spend_update_queue=spend_update_queue, + daily_spend_update_queue=daily_spend_update_queue, + daily_team_spend_update_queue=daily_team_spend_update_queue, + daily_org_spend_update_queue=daily_org_spend_update_queue, + daily_end_user_spend_update_queue=daily_end_user_spend_update_queue, + daily_agent_spend_update_queue=daily_agent_spend_update_queue, + ) + return # Emit gauge events for each queue for i, queue_size in enumerate(result_lengths): @@ -271,6 +301,101 @@ class RedisUpdateBuffer: service=service_types[i], ) + @staticmethod + async def _restore_spend_updates_to_in_memory_queues( + db_spend_update_transactions: Optional[DBSpendUpdateTransactions], + daily_spend_update_transactions: Optional[Dict[str, BaseDailySpendTransaction]], + daily_team_spend_update_transactions: Optional[ + Dict[str, BaseDailySpendTransaction] + ], + daily_org_spend_update_transactions: Optional[ + Dict[str, BaseDailySpendTransaction] + ], + daily_end_user_spend_update_transactions: Optional[ + Dict[str, BaseDailySpendTransaction] + ], + daily_agent_spend_update_transactions: Optional[ + Dict[str, BaseDailySpendTransaction] + ], + spend_update_queue: SpendUpdateQueue, + daily_spend_update_queue: DailySpendUpdateQueue, + daily_team_spend_update_queue: DailySpendUpdateQueue, + daily_org_spend_update_queue: DailySpendUpdateQueue, + daily_end_user_spend_update_queue: DailySpendUpdateQueue, + daily_agent_spend_update_queue: DailySpendUpdateQueue, + ) -> None: + """ + Put drained-but-unpushed transactions back into in-memory queues. + + Called when the Redis rpush pipeline raises. Without this, all spend + data aggregated during the current scheduler tick is permanently lost + because the source queues were already drained before the rpush. + """ + if db_spend_update_transactions is not None: + entity_entries: List[ + Tuple[Litellm_EntityType, Optional[Dict[str, float]]] + ] = [ + ( + Litellm_EntityType.USER, + db_spend_update_transactions.get("user_list_transactions"), + ), + ( + Litellm_EntityType.END_USER, + db_spend_update_transactions.get("end_user_list_transactions"), + ), + ( + Litellm_EntityType.KEY, + db_spend_update_transactions.get("key_list_transactions"), + ), + ( + Litellm_EntityType.TEAM, + db_spend_update_transactions.get("team_list_transactions"), + ), + ( + Litellm_EntityType.TEAM_MEMBER, + db_spend_update_transactions.get("team_member_list_transactions"), + ), + ( + Litellm_EntityType.ORGANIZATION, + db_spend_update_transactions.get("org_list_transactions"), + ), + ( + Litellm_EntityType.TAG, + db_spend_update_transactions.get("tag_list_transactions"), + ), + ( + Litellm_EntityType.AGENT, + db_spend_update_transactions.get("agent_list_transactions"), + ), + ] + for entity_type, entities in entity_entries: + if not entities: + continue + for entity_id, cost in entities.items(): + await spend_update_queue.add_update( + SpendUpdateQueueItem( + entity_type=entity_type, + entity_id=entity_id, + response_cost=cost, + ) + ) + + daily_pairs: List[ + Tuple[Optional[Dict[str, BaseDailySpendTransaction]], DailySpendUpdateQueue] + ] = [ + (daily_spend_update_transactions, daily_spend_update_queue), + (daily_team_spend_update_transactions, daily_team_spend_update_queue), + (daily_org_spend_update_transactions, daily_org_spend_update_queue), + ( + daily_end_user_spend_update_transactions, + daily_end_user_spend_update_queue, + ), + (daily_agent_spend_update_transactions, daily_agent_spend_update_queue), + ] + for daily_txns, daily_queue in daily_pairs: + if daily_txns: + await daily_queue.update_queue.put(daily_txns) + @staticmethod def _number_of_transactions_to_store_in_redis( db_spend_update_transactions: DBSpendUpdateTransactions, diff --git a/tests/spend_tracking_tests/test_spend_accuracy_tests.py b/tests/spend_tracking_tests/test_spend_accuracy_tests.py index 18527b525e..8c523e9191 100644 --- a/tests/spend_tracking_tests/test_spend_accuracy_tests.py +++ b/tests/spend_tracking_tests/test_spend_accuracy_tests.py @@ -1,10 +1,9 @@ import pytest import asyncio import aiohttp -import json import time -from httpx import AsyncClient -from typing import Any, Optional + +import litellm from litellm._uuid import uuid """ @@ -12,15 +11,13 @@ Tests to run Basic Tests: 1. Basic Spend Accuracy Test: - - Make 1 calibration request, poll for spend to derive SPEND_PER_REQUEST - - Make N-1 more requests (N total) - - Expect the spend for each of the following to be N * SPEND_PER_REQUEST - Key, Team, User, Org (call /info endpoint for each object to validate) + - Make N requests, compute expected total spend locally from each response's usage + - Poll until batch writer has flushed spend to the DB + - Expect spend for Key, Team, User, Org (/info endpoints) to equal the computed total 2. Long term spend accuracy test (with 2 bursts of requests) - - Burst 1: Make requests, derive SPEND_PER_REQUEST from first request - - Burst 2: Make more requests - - Verify total spend = (burst1 + burst2) * SPEND_PER_REQUEST + - Burst 1: compute expected from responses, verify + - Burst 2: compute expected from responses, verify total = burst1 + burst2 Additional Test Scenarios: @@ -38,6 +35,18 @@ Additional Test Scenarios: - Verify accurate total spend calculation """ +# Upstream model the proxy is configured with (spend_tracking_config.yaml). +# The proxy computes spend using this model's pricing; the local ground-truth +# calculation uses the same pricing table via litellm.cost_per_token. +UPSTREAM_MODEL = "gpt-3.5-turbo" + +# Batch writer flush cadence in CI is ~2-7s (PROXY_BATCH_WRITE_AT=2 + up to 5s jitter). +# Poll every 2s for 60s — plenty of headroom for multiple ticks to land. +POLL_INTERVAL_SECONDS = 2 +POLL_TIMEOUT_SECONDS = 60 + +TOLERANCE = 1e-10 + async def create_organization(session, organization_alias: str): """Helper function to create a new organization""" @@ -102,118 +111,135 @@ async def get_spend_info(session, entity_type: str, entity_id: str): return await response.json() -async def poll_key_spend_until_nonzero( - session, key: str, timeout: int = 120, interval: int = 10 -): - """Poll key spend until it becomes non-zero or timeout is reached.""" +async def get_proxy_readiness(session): + """Fetch /health/readiness. Used both as a fail-fast gate and as a diagnostic on poll timeout.""" + url = "http://0.0.0.0:4000/health/readiness" + headers = {"Authorization": "Bearer sk-1234"} + async with session.get(url, headers=headers) as response: + return response.status, await response.json() + + +async def assert_proxy_healthy(session): + """Fail fast if the proxy's DB or cache is not reachable — no point running the test.""" + status, body = await get_proxy_readiness(session) + if status != 200 or body.get("db") != "connected": + pytest.fail( + f"Proxy /health/readiness unhealthy (status={status}). " + f"Cannot run spend accuracy test. Response: {body}" + ) + print(f"Proxy readiness OK: {body}") + + +def compute_expected_spend(responses) -> float: + """ + Compute the expected total spend locally from each response's usage tokens, + using the same pricing table the proxy uses. This is the independent ground + truth we compare the proxy's reported spend against. + """ + total = 0.0 + for r in responses: + usage = r.usage + prompt_cost, completion_cost = litellm.cost_per_token( + model=UPSTREAM_MODEL, + prompt_tokens=usage.prompt_tokens, + completion_tokens=usage.completion_tokens, + ) + total += prompt_cost + completion_cost + return total + + +async def poll_key_spend_until(session, key: str, expected: float) -> float: + """ + Poll key spend until it matches `expected` within TOLERANCE, or timeout. + Returns the last observed spend either way; caller decides how to report. + """ start = time.time() - while time.time() - start < timeout: + last_spend = 0.0 + while time.time() - start < POLL_TIMEOUT_SECONDS: key_info = await get_spend_info(session, "key", key) - spend = key_info["info"]["spend"] - if spend > 0: + last_spend = key_info["info"]["spend"] + if abs(last_spend - expected) < TOLERANCE: print( - f"Key spend became non-zero ({spend}) after {time.time() - start:.1f}s" + f"Key spend reached expected {expected} after {time.time() - start:.1f}s" ) - return spend - print(f"Key spend still 0.0, waiting... ({time.time() - start:.1f}s elapsed)") - await asyncio.sleep(interval) - raise TimeoutError( - f"Key spend remained 0.0 after {timeout}s — batch writer may not be running" + return last_spend + print( + f"Key spend {last_spend}, expected {expected}, waiting... " + f"({time.time() - start:.1f}s elapsed)" + ) + await asyncio.sleep(POLL_INTERVAL_SECONDS) + return last_spend + + +async def fail_with_diagnostics(session, stage: str, expected: float, observed: float): + """Emit a failure with readiness state so CI output points at the real cause.""" + _, readiness = await get_proxy_readiness(session) + pytest.fail( + f"{stage}: key spend did not match expected after {POLL_TIMEOUT_SECONDS}s poll. " + f"expected={expected}, observed={observed}, diff={expected - observed}. " + f"Proxy readiness: {readiness}" ) -async def calibrate_spend_per_request(session, key: str, max_retries: int = 5): - """ - Make a single calibration request and poll for its spend to derive SPEND_PER_REQUEST. - Fails fast with pytest.fail() if spend cannot be determined. - """ - response = await chat_completion(session, key) - print(f"Calibration request completed: {response}") - - for attempt in range(1, max_retries + 1): - try: - spend = await poll_key_spend_until_nonzero( - session, key, timeout=120, interval=10 - ) - print( - f"Calibrated SPEND_PER_REQUEST = {spend} " - f"(attempt {attempt}/{max_retries})" - ) - return spend - except TimeoutError: - if attempt < max_retries: - print( - f"Calibration attempt {attempt}/{max_retries} timed out, retrying..." - ) - else: - pytest.fail( - f"Failed to calibrate SPEND_PER_REQUEST after {max_retries} attempts. " - "The batch writer may not be running or the model may have 0 cost." - ) - - @pytest.mark.asyncio async def test_basic_spend_accuracy(): """ Test basic spend accuracy across different entities: 1. Create org, team, user, and key - 2. Make 1 calibration request to derive SPEND_PER_REQUEST - 3. Make remaining requests (NUM_LLM_REQUESTS total) - 4. Verify spend accuracy for key, team, user, and org + 2. Make N requests, keeping each response + 3. Compute expected spend locally from response usage (independent ground truth) + 4. Poll until proxy-reported spend matches expected + 5. Verify spend is consistent across key, team, user, and org entities """ NUM_LLM_REQUESTS = 20 - TOLERANCE = 1e-10 async with aiohttp.ClientSession() as session: - # Create organization + await assert_proxy_healthy(session) + org_response = await create_organization( session=session, organization_alias=f"test-org-{uuid.uuid4()}" ) print("org_response: ", org_response) org_id = org_response["organization_id"] - # Create team under organization team_response = await create_team(session, org_id) print("team_response: ", team_response) team_id = team_response["team_id"] - # Create user user_response = await create_user(session, org_id) print("user_response: ", user_response) user_id = user_response["user_id"] - # Generate key key_response = await generate_key(session, user_id, team_id) print("key_response: ", key_response) key = key_response["key"] - # Calibrate: make 1 request and derive SPEND_PER_REQUEST - spend_per_request = await calibrate_spend_per_request(session, key) - expected_spend = NUM_LLM_REQUESTS * spend_per_request - print(f"SPEND_PER_REQUEST={spend_per_request}, expected_spend={expected_spend}") - - # Make remaining requests (1 already made during calibration) - for i in range(NUM_LLM_REQUESTS - 1): + responses = [] + for i in range(NUM_LLM_REQUESTS): response = await chat_completion(session, key) - print(f"Request {i + 2}/{NUM_LLM_REQUESTS} completed") + responses.append(response) + print(f"Request {i + 1}/{NUM_LLM_REQUESTS} completed") - # Poll until batch writer has flushed all spend - start = time.time() - while time.time() - start < 120: - key_info = await get_spend_info(session, "key", key) - current_spend = key_info["info"]["spend"] - if abs(current_spend - expected_spend) < TOLERANCE: - print( - f"Key spend reached expected {expected_spend} after {time.time() - start:.1f}s" - ) - break - print(f"Key spend {current_spend}, expected {expected_spend}, waiting...") - await asyncio.sleep(10) + expected_spend = compute_expected_spend(responses) + assert expected_spend > 0, ( + f"Locally computed expected spend is {expected_spend}. Either cost calc " + f"is broken or upstream returned zero tokens. " + f"Usage: {[r.usage.model_dump() for r in responses]}" + ) + print(f"Expected total spend (local ground truth): {expected_spend}") - # Allow extra time for all entity spend aggregations to complete + final_spend = await poll_key_spend_until(session, key, expected_spend) + if abs(final_spend - expected_spend) >= TOLERANCE: + await fail_with_diagnostics( + session, + stage="test_basic_spend_accuracy", + expected=expected_spend, + observed=final_spend, + ) + + # Allow a final scheduler tick for team/user/org aggregations to settle await asyncio.sleep(5) - # Get spend information for each entity key_info = await get_spend_info(session, "key", key) print("key_info: ", key_info) team_info = await get_spend_info(session, "team", team_id) @@ -223,7 +249,6 @@ async def test_basic_spend_accuracy(): org_info = await get_spend_info(session, "organization", org_id) print("org_info: ", org_info) - # Verify spend for each entity assert ( abs(key_info["info"]["spend"] - expected_spend) < TOLERANCE ), f"Key spend {key_info['info']['spend']} does not match expected {expected_spend}" @@ -246,91 +271,78 @@ async def test_long_term_spend_accuracy_with_bursts(): """ Test long-term spend accuracy with multiple bursts of requests: 1. Create org, team, user, and key - 2. Calibrate SPEND_PER_REQUEST from first request - 3. Burst 1: Make remaining requests - 4. Burst 2: Make more requests - 5. Verify the total spend is tracked accurately across all entities + 2. Burst 1: make requests, compute expected locally, verify proxy matches + 3. Burst 2: make more requests, verify proxy total == burst1 + burst2 + 4. Verify total spend is consistent across all entities """ BURST_1_REQUESTS = 22 BURST_2_REQUESTS = 12 - TOTAL_REQUESTS = BURST_1_REQUESTS + BURST_2_REQUESTS - TOLERANCE = 1e-10 async with aiohttp.ClientSession() as session: - # Create organization + await assert_proxy_healthy(session) + org_response = await create_organization( session=session, organization_alias=f"test-org-{uuid.uuid4()}" ) print("org_response: ", org_response) org_id = org_response["organization_id"] - # Create team under organization team_response = await create_team(session, org_id) print("team_response: ", team_response) team_id = team_response["team_id"] - # Create user user_response = await create_user(session, org_id) print("user_response: ", user_response) user_id = user_response["user_id"] - # Generate key key_response = await generate_key(session, user_id, team_id) print("key_response: ", key_response) key = key_response["key"] - # Calibrate: make 1 request and derive SPEND_PER_REQUEST - spend_per_request = await calibrate_spend_per_request(session, key) - expected_spend = TOTAL_REQUESTS * spend_per_request - print(f"SPEND_PER_REQUEST={spend_per_request}, expected_spend={expected_spend}") - - # First burst: remaining requests (1 already made during calibration) - print(f"Starting first burst ({BURST_1_REQUESTS - 1} remaining requests)...") - for i in range(BURST_1_REQUESTS - 1): + print(f"Starting first burst of {BURST_1_REQUESTS} requests...") + burst_1_responses = [] + for i in range(BURST_1_REQUESTS): response = await chat_completion(session, key) - print(f"Burst 1 - Request {i + 2}/{BURST_1_REQUESTS} completed") + burst_1_responses.append(response) + print(f"Burst 1 - Request {i + 1}/{BURST_1_REQUESTS} completed") - # Poll until batch writer has flushed burst 1 spend - burst_1_expected = BURST_1_REQUESTS * spend_per_request - start = time.time() - while time.time() - start < 120: - key_info_check = await get_spend_info(session, "key", key) - current_spend = key_info_check["info"]["spend"] - if abs(current_spend - burst_1_expected) < TOLERANCE: - print( - f"Burst 1 spend reached expected {burst_1_expected} after {time.time() - start:.1f}s" - ) - break - print(f"Key spend {current_spend}, expected {burst_1_expected}, waiting...") - await asyncio.sleep(10) + burst_1_expected = compute_expected_spend(burst_1_responses) + assert burst_1_expected > 0, ( + f"Burst 1 expected spend is {burst_1_expected}. " + f"Usage: {[r.usage.model_dump() for r in burst_1_responses]}" + ) + print(f"Burst 1 expected spend: {burst_1_expected}") - # Check intermediate spend - intermediate_key_info = await get_spend_info(session, "key", key) - print(f"After Burst 1 - Key spend: {intermediate_key_info['info']['spend']}") + final_burst_1 = await poll_key_spend_until(session, key, burst_1_expected) + if abs(final_burst_1 - burst_1_expected) >= TOLERANCE: + await fail_with_diagnostics( + session, + stage="test_long_term_spend_accuracy burst 1", + expected=burst_1_expected, + observed=final_burst_1, + ) - # Second burst print(f"Starting second burst of {BURST_2_REQUESTS} requests...") + burst_2_responses = [] for i in range(BURST_2_REQUESTS): response = await chat_completion(session, key) + burst_2_responses.append(response) print(f"Burst 2 - Request {i + 1}/{BURST_2_REQUESTS} completed") - # Poll until key spend reaches expected total (burst 1 + burst 2) - start = time.time() - while time.time() - start < 120: - key_info_check = await get_spend_info(session, "key", key) - current_spend = key_info_check["info"]["spend"] - if abs(current_spend - expected_spend) < TOLERANCE: - print( - f"Total spend reached expected {expected_spend} after {time.time() - start:.1f}s" - ) - break - print(f"Key spend {current_spend}, expected {expected_spend}, waiting...") - await asyncio.sleep(10) + total_expected = burst_1_expected + compute_expected_spend(burst_2_responses) + print(f"Total expected spend (burst 1 + burst 2): {total_expected}") + + final_total = await poll_key_spend_until(session, key, total_expected) + if abs(final_total - total_expected) >= TOLERANCE: + await fail_with_diagnostics( + session, + stage="test_long_term_spend_accuracy total", + expected=total_expected, + observed=final_total, + ) - # Allow extra time for all entity spend aggregations await asyncio.sleep(5) - # Get final spend information for each entity key_info = await get_spend_info(session, "key", key) team_info = await get_spend_info(session, "team", team_id) user_info = await get_spend_info(session, "user", user_id) @@ -341,19 +353,18 @@ async def test_long_term_spend_accuracy_with_bursts(): print(f"Final user spend: {user_info['user_info']['spend']}") print(f"Final org spend: {org_info['spend']}") - # Verify total spend for each entity assert ( - abs(key_info["info"]["spend"] - expected_spend) < TOLERANCE - ), f"Key spend {key_info['info']['spend']} does not match expected {expected_spend}" + abs(key_info["info"]["spend"] - total_expected) < TOLERANCE + ), f"Key spend {key_info['info']['spend']} does not match expected {total_expected}" assert ( - abs(user_info["user_info"]["spend"] - expected_spend) < TOLERANCE - ), f"User spend {user_info['user_info']['spend']} does not match expected {expected_spend}" + abs(user_info["user_info"]["spend"] - total_expected) < TOLERANCE + ), f"User spend {user_info['user_info']['spend']} does not match expected {total_expected}" assert ( - abs(team_info["team_info"]["spend"] - expected_spend) < TOLERANCE - ), f"Team spend {team_info['team_info']['spend']} does not match expected {expected_spend}" + abs(team_info["team_info"]["spend"] - total_expected) < TOLERANCE + ), f"Team spend {team_info['team_info']['spend']} does not match expected {total_expected}" assert ( - abs(org_info["spend"] - expected_spend) < TOLERANCE - ), f"Organization spend {org_info['spend']} does not match expected {expected_spend}" + abs(org_info["spend"] - total_expected) < TOLERANCE + ), f"Organization spend {org_info['spend']} does not match expected {total_expected}" diff --git a/tests/test_litellm/proxy/db/db_transaction_queue/test_redis_update_buffer.py b/tests/test_litellm/proxy/db/db_transaction_queue/test_redis_update_buffer.py index 44602125ff..0587e3bce1 100644 --- a/tests/test_litellm/proxy/db/db_transaction_queue/test_redis_update_buffer.py +++ b/tests/test_litellm/proxy/db/db_transaction_queue/test_redis_update_buffer.py @@ -87,6 +87,89 @@ async def test_store_in_memory_spend_updates_uses_pipeline( assert len(rpush_list) == 3 +@pytest.mark.asyncio +async def test_store_in_memory_spend_updates_restores_on_rpush_failure( + redis_update_buffer, mock_redis_cache +): + """ + If async_rpush_pipeline raises, the already-drained transactions must be + put back into the in-memory queues so the next scheduler tick retries. + Without this, any transient Redis hiccup silently loses spend data. + """ + from litellm.proxy._types import Litellm_EntityType + from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import ( + DailySpendUpdateQueue, + ) + from litellm.proxy.db.db_transaction_queue.spend_update_queue import ( + SpendUpdateQueue, + ) + + mock_redis_cache.async_rpush_pipeline = AsyncMock( + side_effect=ConnectionError("redis went away") + ) + + spend_queue = SpendUpdateQueue() + daily_user_queue = DailySpendUpdateQueue() + daily_team_queue = DailySpendUpdateQueue() + daily_org_queue = DailySpendUpdateQueue() + daily_end_user_queue = DailySpendUpdateQueue() + daily_agent_queue = DailySpendUpdateQueue() + + # Seed real queues with data so flush_and_get_aggregated returns it + await spend_queue.add_update( + { + "entity_type": Litellm_EntityType.KEY, + "entity_id": "key-abc", + "response_cost": 1.5, + } + ) + await spend_queue.add_update( + { + "entity_type": Litellm_EntityType.TEAM, + "entity_id": "team-xyz", + "response_cost": 2.5, + } + ) + await daily_user_queue.add_update( + { + "user1_day_model": { + "spend": 1.0, + "prompt_tokens": 10, + "completion_tokens": 20, + } + } + ) + + await redis_update_buffer.store_in_memory_spend_updates_in_redis( + spend_update_queue=spend_queue, + daily_spend_update_queue=daily_user_queue, + daily_team_spend_update_queue=daily_team_queue, + daily_org_spend_update_queue=daily_org_queue, + daily_end_user_spend_update_queue=daily_end_user_queue, + daily_agent_spend_update_queue=daily_agent_queue, + ) + + # After restore, the main spend queue should hold one item per + # (entity_type, entity_id) pair with the aggregated cost + restored_spend = ( + await spend_queue.flush_and_get_aggregated_db_spend_update_transactions() + ) + assert restored_spend["key_list_transactions"] == {"key-abc": 1.5} + assert restored_spend["team_list_transactions"] == {"team-xyz": 2.5} + + # Daily user queue should hold the same aggregated dict + restored_daily = ( + await daily_user_queue.flush_and_get_aggregated_daily_spend_update_transactions() + ) + assert restored_daily == { + "user1_day_model": { + "spend": 1.0, + "prompt_tokens": 10, + "completion_tokens": 20, + } + } + + @pytest.mark.asyncio async def test_store_in_memory_spend_updates_all_empty_returns_early( redis_update_buffer, mock_redis_cache