Merge pull request #26270 from BerriAI/litellm_/lucid-kowalevski-de832f
[Fix] Stabilize flaky spend accuracy tests + patch Redis buffer data-loss path
This commit is contained in:
commit
b6fdd46636
@ -1802,6 +1802,7 @@ jobs:
|
||||
-e DD_API_KEY=$DD_API_KEY \
|
||||
-e DD_SITE=$DD_SITE \
|
||||
-e AWS_REGION_NAME=$AWS_REGION_NAME \
|
||||
-e PROXY_BATCH_WRITE_AT=2 \
|
||||
--add-host host.docker.internal:host-gateway \
|
||||
--name my-app \
|
||||
-v $(pwd)/litellm/proxy/example_config_yaml/spend_tracking_config.yaml:/app/config.yaml \
|
||||
|
||||
@ -22,6 +22,7 @@ from litellm.constants import (
|
||||
)
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy._types import (
|
||||
BaseDailySpendTransaction,
|
||||
DailyAgentSpendTransaction,
|
||||
DailyEndUserSpendTransaction,
|
||||
DailyOrganizationSpendTransaction,
|
||||
@ -29,6 +30,8 @@ from litellm.proxy._types import (
|
||||
DailyTeamSpendTransaction,
|
||||
DailyUserSpendTransaction,
|
||||
DBSpendUpdateTransactions,
|
||||
Litellm_EntityType,
|
||||
SpendUpdateQueueItem,
|
||||
)
|
||||
from litellm.proxy.db.db_transaction_queue.base_update_queue import service_logger_obj
|
||||
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import (
|
||||
@ -259,9 +262,36 @@ class RedisUpdateBuffer:
|
||||
if len(rpush_list) == 0:
|
||||
return
|
||||
|
||||
result_lengths = await self.redis_cache.async_rpush_pipeline(
|
||||
rpush_list=rpush_list,
|
||||
)
|
||||
try:
|
||||
result_lengths = await self.redis_cache.async_rpush_pipeline(
|
||||
rpush_list=rpush_list,
|
||||
)
|
||||
except Exception as e:
|
||||
# The in-memory queues were already drained above. If we let the
|
||||
# exception propagate without restoring, the aggregated spend is
|
||||
# permanently lost. Re-enqueue so the next scheduler tick retries.
|
||||
verbose_proxy_logger.error(
|
||||
"Spend tracking - failed to push aggregated spend updates to Redis. "
|
||||
"Restoring %d transaction sets to in-memory queues for retry on next tick. "
|
||||
"Error: %s",
|
||||
len(rpush_list),
|
||||
str(e),
|
||||
)
|
||||
await self._restore_spend_updates_to_in_memory_queues(
|
||||
db_spend_update_transactions=db_spend_update_transactions,
|
||||
daily_spend_update_transactions=daily_spend_update_transactions,
|
||||
daily_team_spend_update_transactions=daily_team_spend_update_transactions,
|
||||
daily_org_spend_update_transactions=daily_org_spend_update_transactions,
|
||||
daily_end_user_spend_update_transactions=daily_end_user_spend_update_transactions,
|
||||
daily_agent_spend_update_transactions=daily_agent_spend_update_transactions,
|
||||
spend_update_queue=spend_update_queue,
|
||||
daily_spend_update_queue=daily_spend_update_queue,
|
||||
daily_team_spend_update_queue=daily_team_spend_update_queue,
|
||||
daily_org_spend_update_queue=daily_org_spend_update_queue,
|
||||
daily_end_user_spend_update_queue=daily_end_user_spend_update_queue,
|
||||
daily_agent_spend_update_queue=daily_agent_spend_update_queue,
|
||||
)
|
||||
return
|
||||
|
||||
# Emit gauge events for each queue
|
||||
for i, queue_size in enumerate(result_lengths):
|
||||
@ -271,6 +301,101 @@ class RedisUpdateBuffer:
|
||||
service=service_types[i],
|
||||
)
|
||||
|
||||
@staticmethod
async def _restore_spend_updates_to_in_memory_queues(
    db_spend_update_transactions: Optional[DBSpendUpdateTransactions],
    daily_spend_update_transactions: Optional[Dict[str, BaseDailySpendTransaction]],
    daily_team_spend_update_transactions: Optional[
        Dict[str, BaseDailySpendTransaction]
    ],
    daily_org_spend_update_transactions: Optional[
        Dict[str, BaseDailySpendTransaction]
    ],
    daily_end_user_spend_update_transactions: Optional[
        Dict[str, BaseDailySpendTransaction]
    ],
    daily_agent_spend_update_transactions: Optional[
        Dict[str, BaseDailySpendTransaction]
    ],
    spend_update_queue: SpendUpdateQueue,
    daily_spend_update_queue: DailySpendUpdateQueue,
    daily_team_spend_update_queue: DailySpendUpdateQueue,
    daily_org_spend_update_queue: DailySpendUpdateQueue,
    daily_end_user_spend_update_queue: DailySpendUpdateQueue,
    daily_agent_spend_update_queue: DailySpendUpdateQueue,
) -> None:
    """
    Re-enqueue transactions that were drained from memory but never pushed.

    Invoked from the except-branch around the Redis rpush pipeline. The
    in-memory queues have already been flushed at that point, so the
    aggregated data passed in here is the only remaining copy; putting it
    back lets a later flush retry it.

    Args:
        db_spend_update_transactions: Aggregated per-entity spend maps
            ({entity_id: cost}) keyed by "<entity>_list_transactions", or None.
        daily_*_spend_update_transactions: Aggregated daily transaction dicts
            for the matching daily queue, or None.
        spend_update_queue: Queue that receives one SpendUpdateQueueItem per
            (entity type, entity id) pair.
        daily_*_spend_update_queue: Queues whose ``update_queue`` receives the
            corresponding aggregated daily dict as a single item.
    """
    if db_spend_update_transactions is not None:
        # Insertion order of this mapping is the order in which entity types
        # are re-enqueued.
        costs_by_entity_type: Dict[Litellm_EntityType, Optional[Dict[str, float]]] = {
            Litellm_EntityType.USER: db_spend_update_transactions.get(
                "user_list_transactions"
            ),
            Litellm_EntityType.END_USER: db_spend_update_transactions.get(
                "end_user_list_transactions"
            ),
            Litellm_EntityType.KEY: db_spend_update_transactions.get(
                "key_list_transactions"
            ),
            Litellm_EntityType.TEAM: db_spend_update_transactions.get(
                "team_list_transactions"
            ),
            Litellm_EntityType.TEAM_MEMBER: db_spend_update_transactions.get(
                "team_member_list_transactions"
            ),
            Litellm_EntityType.ORGANIZATION: db_spend_update_transactions.get(
                "org_list_transactions"
            ),
            Litellm_EntityType.TAG: db_spend_update_transactions.get(
                "tag_list_transactions"
            ),
            Litellm_EntityType.AGENT: db_spend_update_transactions.get(
                "agent_list_transactions"
            ),
        }
        for kind, cost_by_id in costs_by_entity_type.items():
            if cost_by_id:
                for entity_id, cost in cost_by_id.items():
                    restored_item = SpendUpdateQueueItem(
                        entity_type=kind,
                        entity_id=entity_id,
                        response_cost=cost,
                    )
                    await spend_update_queue.add_update(restored_item)

    drained_daily: Tuple[Optional[Dict[str, BaseDailySpendTransaction]], ...] = (
        daily_spend_update_transactions,
        daily_team_spend_update_transactions,
        daily_org_spend_update_transactions,
        daily_end_user_spend_update_transactions,
        daily_agent_spend_update_transactions,
    )
    target_queues: Tuple[DailySpendUpdateQueue, ...] = (
        daily_spend_update_queue,
        daily_team_spend_update_queue,
        daily_org_spend_update_queue,
        daily_end_user_spend_update_queue,
        daily_agent_spend_update_queue,
    )
    for drained, target in zip(drained_daily, target_queues):
        if not drained:
            # Nothing was drained for this queue (None or empty dict).
            continue
        # The already-aggregated dict goes back as a single queue item.
        await target.update_queue.put(drained)
|
||||
|
||||
@staticmethod
|
||||
def _number_of_transactions_to_store_in_redis(
|
||||
db_spend_update_transactions: DBSpendUpdateTransactions,
|
||||
|
||||
@ -1,10 +1,9 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
import time
|
||||
from httpx import AsyncClient
|
||||
from typing import Any, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._uuid import uuid
|
||||
|
||||
"""
|
||||
@ -12,15 +11,13 @@ Tests to run
|
||||
|
||||
Basic Tests:
|
||||
1. Basic Spend Accuracy Test:
|
||||
- Make 1 calibration request, poll for spend to derive SPEND_PER_REQUEST
|
||||
- Make N-1 more requests (N total)
|
||||
- Expect the spend for each of the following to be N * SPEND_PER_REQUEST
|
||||
Key, Team, User, Org (call /info endpoint for each object to validate)
|
||||
- Make N requests, compute expected total spend locally from each response's usage
|
||||
- Poll until batch writer has flushed spend to the DB
|
||||
- Expect spend for Key, Team, User, Org (/info endpoints) to equal the computed total
|
||||
|
||||
2. Long term spend accuracy test (with 2 bursts of requests)
|
||||
- Burst 1: Make requests, derive SPEND_PER_REQUEST from first request
|
||||
- Burst 2: Make more requests
|
||||
- Verify total spend = (burst1 + burst2) * SPEND_PER_REQUEST
|
||||
- Burst 1: compute expected from responses, verify
|
||||
- Burst 2: compute expected from responses, verify total = burst1 + burst2
|
||||
|
||||
Additional Test Scenarios:
|
||||
|
||||
@ -38,6 +35,18 @@ Additional Test Scenarios:
|
||||
- Verify accurate total spend calculation
|
||||
"""
|
||||
|
||||
# Upstream model the proxy is configured with (spend_tracking_config.yaml).
|
||||
# The proxy computes spend using this model's pricing; the local ground-truth
|
||||
# calculation uses the same pricing table via litellm.cost_per_token.
|
||||
UPSTREAM_MODEL = "gpt-3.5-turbo"
|
||||
|
||||
# Batch writer flush cadence in CI is ~2-7s (PROXY_BATCH_WRITE_AT=2 + up to 5s jitter).
|
||||
# Poll every 2s for 60s — plenty of headroom for multiple ticks to land.
|
||||
POLL_INTERVAL_SECONDS = 2
|
||||
POLL_TIMEOUT_SECONDS = 60
|
||||
|
||||
TOLERANCE = 1e-10
|
||||
|
||||
|
||||
async def create_organization(session, organization_alias: str):
|
||||
"""Helper function to create a new organization"""
|
||||
@ -102,118 +111,135 @@ async def get_spend_info(session, entity_type: str, entity_id: str):
|
||||
return await response.json()
|
||||
|
||||
|
||||
async def poll_key_spend_until_nonzero(
|
||||
session, key: str, timeout: int = 120, interval: int = 10
|
||||
):
|
||||
"""Poll key spend until it becomes non-zero or timeout is reached."""
|
||||
async def get_proxy_readiness(session):
    """
    Fetch the proxy's /health/readiness endpoint.

    Used both as a fail-fast gate and as a diagnostic on poll timeout, so it
    must not raise just because the reply is not JSON (e.g. an HTML or
    plain-text error page): that would replace the real failure message with
    a decode error.

    Args:
        session: aiohttp-style client session exposing ``get(url, headers=...)``
            as an async context manager.

    Returns:
        Tuple ``(status, body)``. ``body`` is the decoded JSON object when the
        reply is a JSON object; otherwise ``{"raw_body": <payload>}`` where the
        payload is the raw response text (undecodable reply) or the decoded
        non-object value (including ``None`` for an empty body).
    """
    url = "http://0.0.0.0:4000/health/readiness"
    headers = {"Authorization": "Bearer sk-1234"}
    async with session.get(url, headers=headers) as response:
        try:
            # content_type=None disables the Content-Type check so a JSON body
            # served with an unexpected header is still decoded.
            body = await response.json(content_type=None)
        except ValueError:
            # json.JSONDecodeError / UnicodeDecodeError are ValueErrors: keep
            # the raw text so it still shows up in diagnostics.
            body = {"raw_body": await response.text()}
        if not isinstance(body, dict):
            # Callers use body.get(...); always hand back a mapping.
            body = {"raw_body": body}
        return response.status, body
|
||||
|
||||
|
||||
async def assert_proxy_healthy(session):
    """
    Abort the test early when the proxy is not ready.

    The proxy counts as ready only when /health/readiness answers HTTP 200 and
    its payload reports ``db == "connected"``. On success the readiness payload
    is printed; otherwise the test fails immediately via ``pytest.fail`` with
    the status and payload in the message.
    """
    status, body = await get_proxy_readiness(session)

    # The payload is only inspected when the HTTP status is already 200.
    proxy_ready = status == 200 and body.get("db") == "connected"
    if proxy_ready:
        print(f"Proxy readiness OK: {body}")
        return

    pytest.fail(
        f"Proxy /health/readiness unhealthy (status={status}). "
        f"Cannot run spend accuracy test. Response: {body}"
    )
|
||||
|
||||
|
||||
def compute_expected_spend(responses) -> float:
    """
    Return the locally computed total cost of ``responses``.

    Each response's ``usage.prompt_tokens`` and ``usage.completion_tokens`` are
    priced with ``litellm.cost_per_token`` for ``UPSTREAM_MODEL``; the prompt
    and completion costs of every response are added together. The result is
    the ground truth the proxy-reported spend is compared against. An empty
    iterable yields ``0.0``.
    """

    def _cost_of(usage) -> float:
        # cost_per_token returns a (prompt_cost, completion_cost) pair.
        prompt_cost, completion_cost = litellm.cost_per_token(
            model=UPSTREAM_MODEL,
            prompt_tokens=usage.prompt_tokens,
            completion_tokens=usage.completion_tokens,
        )
        return prompt_cost + completion_cost

    running_total = 0.0
    for response in responses:
        running_total += _cost_of(response.usage)
    return running_total
|
||||
|
||||
|
||||
async def poll_key_spend_until(session, key: str, expected: float) -> float:
|
||||
"""
|
||||
Poll key spend until it matches `expected` within TOLERANCE, or timeout.
|
||||
Returns the last observed spend either way; caller decides how to report.
|
||||
"""
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
last_spend = 0.0
|
||||
while time.time() - start < POLL_TIMEOUT_SECONDS:
|
||||
key_info = await get_spend_info(session, "key", key)
|
||||
spend = key_info["info"]["spend"]
|
||||
if spend > 0:
|
||||
last_spend = key_info["info"]["spend"]
|
||||
if abs(last_spend - expected) < TOLERANCE:
|
||||
print(
|
||||
f"Key spend became non-zero ({spend}) after {time.time() - start:.1f}s"
|
||||
f"Key spend reached expected {expected} after {time.time() - start:.1f}s"
|
||||
)
|
||||
return spend
|
||||
print(f"Key spend still 0.0, waiting... ({time.time() - start:.1f}s elapsed)")
|
||||
await asyncio.sleep(interval)
|
||||
raise TimeoutError(
|
||||
f"Key spend remained 0.0 after {timeout}s — batch writer may not be running"
|
||||
return last_spend
|
||||
print(
|
||||
f"Key spend {last_spend}, expected {expected}, waiting... "
|
||||
f"({time.time() - start:.1f}s elapsed)"
|
||||
)
|
||||
await asyncio.sleep(POLL_INTERVAL_SECONDS)
|
||||
return last_spend
|
||||
|
||||
|
||||
async def fail_with_diagnostics(session, stage: str, expected: float, observed: float):
    """
    Fail the current test with the spend mismatch plus the proxy readiness state.

    The readiness lookup is best-effort: if the proxy cannot be reached at all
    (which may be the very reason the spend never arrived), the lookup error is
    embedded in the failure message instead of replacing it, so the
    expected/observed figures are always reported.

    Args:
        session: Client session passed through to ``get_proxy_readiness``.
        stage: Label identifying which check failed; prefixes the message.
        expected: Spend the test expected the key to reach.
        observed: Last spend value observed for the key.

    Raises:
        Always fails the test via ``pytest.fail``.
    """
    try:
        _, readiness = await get_proxy_readiness(session)
    except Exception as exc:
        # Deliberately broad: this runs on a failure path and must never mask
        # the spend mismatch with a secondary connection/decoding error.
        readiness = f"<unavailable: {exc!r}>"
    pytest.fail(
        f"{stage}: key spend did not match expected after {POLL_TIMEOUT_SECONDS}s poll. "
        f"expected={expected}, observed={observed}, diff={expected - observed}. "
        f"Proxy readiness: {readiness}"
    )
|
||||
|
||||
|
||||
async def calibrate_spend_per_request(session, key: str, max_retries: int = 5):
|
||||
"""
|
||||
Make a single calibration request and poll for its spend to derive SPEND_PER_REQUEST.
|
||||
Fails fast with pytest.fail() if spend cannot be determined.
|
||||
"""
|
||||
response = await chat_completion(session, key)
|
||||
print(f"Calibration request completed: {response}")
|
||||
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
spend = await poll_key_spend_until_nonzero(
|
||||
session, key, timeout=120, interval=10
|
||||
)
|
||||
print(
|
||||
f"Calibrated SPEND_PER_REQUEST = {spend} "
|
||||
f"(attempt {attempt}/{max_retries})"
|
||||
)
|
||||
return spend
|
||||
except TimeoutError:
|
||||
if attempt < max_retries:
|
||||
print(
|
||||
f"Calibration attempt {attempt}/{max_retries} timed out, retrying..."
|
||||
)
|
||||
else:
|
||||
pytest.fail(
|
||||
f"Failed to calibrate SPEND_PER_REQUEST after {max_retries} attempts. "
|
||||
"The batch writer may not be running or the model may have 0 cost."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_spend_accuracy():
|
||||
"""
|
||||
Test basic spend accuracy across different entities:
|
||||
1. Create org, team, user, and key
|
||||
2. Make 1 calibration request to derive SPEND_PER_REQUEST
|
||||
3. Make remaining requests (NUM_LLM_REQUESTS total)
|
||||
4. Verify spend accuracy for key, team, user, and org
|
||||
2. Make N requests, keeping each response
|
||||
3. Compute expected spend locally from response usage (independent ground truth)
|
||||
4. Poll until proxy-reported spend matches expected
|
||||
5. Verify spend is consistent across key, team, user, and org entities
|
||||
"""
|
||||
NUM_LLM_REQUESTS = 20
|
||||
TOLERANCE = 1e-10
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Create organization
|
||||
await assert_proxy_healthy(session)
|
||||
|
||||
org_response = await create_organization(
|
||||
session=session, organization_alias=f"test-org-{uuid.uuid4()}"
|
||||
)
|
||||
print("org_response: ", org_response)
|
||||
org_id = org_response["organization_id"]
|
||||
|
||||
# Create team under organization
|
||||
team_response = await create_team(session, org_id)
|
||||
print("team_response: ", team_response)
|
||||
team_id = team_response["team_id"]
|
||||
|
||||
# Create user
|
||||
user_response = await create_user(session, org_id)
|
||||
print("user_response: ", user_response)
|
||||
user_id = user_response["user_id"]
|
||||
|
||||
# Generate key
|
||||
key_response = await generate_key(session, user_id, team_id)
|
||||
print("key_response: ", key_response)
|
||||
key = key_response["key"]
|
||||
|
||||
# Calibrate: make 1 request and derive SPEND_PER_REQUEST
|
||||
spend_per_request = await calibrate_spend_per_request(session, key)
|
||||
expected_spend = NUM_LLM_REQUESTS * spend_per_request
|
||||
print(f"SPEND_PER_REQUEST={spend_per_request}, expected_spend={expected_spend}")
|
||||
|
||||
# Make remaining requests (1 already made during calibration)
|
||||
for i in range(NUM_LLM_REQUESTS - 1):
|
||||
responses = []
|
||||
for i in range(NUM_LLM_REQUESTS):
|
||||
response = await chat_completion(session, key)
|
||||
print(f"Request {i + 2}/{NUM_LLM_REQUESTS} completed")
|
||||
responses.append(response)
|
||||
print(f"Request {i + 1}/{NUM_LLM_REQUESTS} completed")
|
||||
|
||||
# Poll until batch writer has flushed all spend
|
||||
start = time.time()
|
||||
while time.time() - start < 120:
|
||||
key_info = await get_spend_info(session, "key", key)
|
||||
current_spend = key_info["info"]["spend"]
|
||||
if abs(current_spend - expected_spend) < TOLERANCE:
|
||||
print(
|
||||
f"Key spend reached expected {expected_spend} after {time.time() - start:.1f}s"
|
||||
)
|
||||
break
|
||||
print(f"Key spend {current_spend}, expected {expected_spend}, waiting...")
|
||||
await asyncio.sleep(10)
|
||||
expected_spend = compute_expected_spend(responses)
|
||||
assert expected_spend > 0, (
|
||||
f"Locally computed expected spend is {expected_spend}. Either cost calc "
|
||||
f"is broken or upstream returned zero tokens. "
|
||||
f"Usage: {[r.usage.model_dump() for r in responses]}"
|
||||
)
|
||||
print(f"Expected total spend (local ground truth): {expected_spend}")
|
||||
|
||||
# Allow extra time for all entity spend aggregations to complete
|
||||
final_spend = await poll_key_spend_until(session, key, expected_spend)
|
||||
if abs(final_spend - expected_spend) >= TOLERANCE:
|
||||
await fail_with_diagnostics(
|
||||
session,
|
||||
stage="test_basic_spend_accuracy",
|
||||
expected=expected_spend,
|
||||
observed=final_spend,
|
||||
)
|
||||
|
||||
# Allow a final scheduler tick for team/user/org aggregations to settle
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# Get spend information for each entity
|
||||
key_info = await get_spend_info(session, "key", key)
|
||||
print("key_info: ", key_info)
|
||||
team_info = await get_spend_info(session, "team", team_id)
|
||||
@ -223,7 +249,6 @@ async def test_basic_spend_accuracy():
|
||||
org_info = await get_spend_info(session, "organization", org_id)
|
||||
print("org_info: ", org_info)
|
||||
|
||||
# Verify spend for each entity
|
||||
assert (
|
||||
abs(key_info["info"]["spend"] - expected_spend) < TOLERANCE
|
||||
), f"Key spend {key_info['info']['spend']} does not match expected {expected_spend}"
|
||||
@ -246,91 +271,78 @@ async def test_long_term_spend_accuracy_with_bursts():
|
||||
"""
|
||||
Test long-term spend accuracy with multiple bursts of requests:
|
||||
1. Create org, team, user, and key
|
||||
2. Calibrate SPEND_PER_REQUEST from first request
|
||||
3. Burst 1: Make remaining requests
|
||||
4. Burst 2: Make more requests
|
||||
5. Verify the total spend is tracked accurately across all entities
|
||||
2. Burst 1: make requests, compute expected locally, verify proxy matches
|
||||
3. Burst 2: make more requests, verify proxy total == burst1 + burst2
|
||||
4. Verify total spend is consistent across all entities
|
||||
"""
|
||||
BURST_1_REQUESTS = 22
|
||||
BURST_2_REQUESTS = 12
|
||||
TOTAL_REQUESTS = BURST_1_REQUESTS + BURST_2_REQUESTS
|
||||
TOLERANCE = 1e-10
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Create organization
|
||||
await assert_proxy_healthy(session)
|
||||
|
||||
org_response = await create_organization(
|
||||
session=session, organization_alias=f"test-org-{uuid.uuid4()}"
|
||||
)
|
||||
print("org_response: ", org_response)
|
||||
org_id = org_response["organization_id"]
|
||||
|
||||
# Create team under organization
|
||||
team_response = await create_team(session, org_id)
|
||||
print("team_response: ", team_response)
|
||||
team_id = team_response["team_id"]
|
||||
|
||||
# Create user
|
||||
user_response = await create_user(session, org_id)
|
||||
print("user_response: ", user_response)
|
||||
user_id = user_response["user_id"]
|
||||
|
||||
# Generate key
|
||||
key_response = await generate_key(session, user_id, team_id)
|
||||
print("key_response: ", key_response)
|
||||
key = key_response["key"]
|
||||
|
||||
# Calibrate: make 1 request and derive SPEND_PER_REQUEST
|
||||
spend_per_request = await calibrate_spend_per_request(session, key)
|
||||
expected_spend = TOTAL_REQUESTS * spend_per_request
|
||||
print(f"SPEND_PER_REQUEST={spend_per_request}, expected_spend={expected_spend}")
|
||||
|
||||
# First burst: remaining requests (1 already made during calibration)
|
||||
print(f"Starting first burst ({BURST_1_REQUESTS - 1} remaining requests)...")
|
||||
for i in range(BURST_1_REQUESTS - 1):
|
||||
print(f"Starting first burst of {BURST_1_REQUESTS} requests...")
|
||||
burst_1_responses = []
|
||||
for i in range(BURST_1_REQUESTS):
|
||||
response = await chat_completion(session, key)
|
||||
print(f"Burst 1 - Request {i + 2}/{BURST_1_REQUESTS} completed")
|
||||
burst_1_responses.append(response)
|
||||
print(f"Burst 1 - Request {i + 1}/{BURST_1_REQUESTS} completed")
|
||||
|
||||
# Poll until batch writer has flushed burst 1 spend
|
||||
burst_1_expected = BURST_1_REQUESTS * spend_per_request
|
||||
start = time.time()
|
||||
while time.time() - start < 120:
|
||||
key_info_check = await get_spend_info(session, "key", key)
|
||||
current_spend = key_info_check["info"]["spend"]
|
||||
if abs(current_spend - burst_1_expected) < TOLERANCE:
|
||||
print(
|
||||
f"Burst 1 spend reached expected {burst_1_expected} after {time.time() - start:.1f}s"
|
||||
)
|
||||
break
|
||||
print(f"Key spend {current_spend}, expected {burst_1_expected}, waiting...")
|
||||
await asyncio.sleep(10)
|
||||
burst_1_expected = compute_expected_spend(burst_1_responses)
|
||||
assert burst_1_expected > 0, (
|
||||
f"Burst 1 expected spend is {burst_1_expected}. "
|
||||
f"Usage: {[r.usage.model_dump() for r in burst_1_responses]}"
|
||||
)
|
||||
print(f"Burst 1 expected spend: {burst_1_expected}")
|
||||
|
||||
# Check intermediate spend
|
||||
intermediate_key_info = await get_spend_info(session, "key", key)
|
||||
print(f"After Burst 1 - Key spend: {intermediate_key_info['info']['spend']}")
|
||||
final_burst_1 = await poll_key_spend_until(session, key, burst_1_expected)
|
||||
if abs(final_burst_1 - burst_1_expected) >= TOLERANCE:
|
||||
await fail_with_diagnostics(
|
||||
session,
|
||||
stage="test_long_term_spend_accuracy burst 1",
|
||||
expected=burst_1_expected,
|
||||
observed=final_burst_1,
|
||||
)
|
||||
|
||||
# Second burst
|
||||
print(f"Starting second burst of {BURST_2_REQUESTS} requests...")
|
||||
burst_2_responses = []
|
||||
for i in range(BURST_2_REQUESTS):
|
||||
response = await chat_completion(session, key)
|
||||
burst_2_responses.append(response)
|
||||
print(f"Burst 2 - Request {i + 1}/{BURST_2_REQUESTS} completed")
|
||||
|
||||
# Poll until key spend reaches expected total (burst 1 + burst 2)
|
||||
start = time.time()
|
||||
while time.time() - start < 120:
|
||||
key_info_check = await get_spend_info(session, "key", key)
|
||||
current_spend = key_info_check["info"]["spend"]
|
||||
if abs(current_spend - expected_spend) < TOLERANCE:
|
||||
print(
|
||||
f"Total spend reached expected {expected_spend} after {time.time() - start:.1f}s"
|
||||
)
|
||||
break
|
||||
print(f"Key spend {current_spend}, expected {expected_spend}, waiting...")
|
||||
await asyncio.sleep(10)
|
||||
total_expected = burst_1_expected + compute_expected_spend(burst_2_responses)
|
||||
print(f"Total expected spend (burst 1 + burst 2): {total_expected}")
|
||||
|
||||
final_total = await poll_key_spend_until(session, key, total_expected)
|
||||
if abs(final_total - total_expected) >= TOLERANCE:
|
||||
await fail_with_diagnostics(
|
||||
session,
|
||||
stage="test_long_term_spend_accuracy total",
|
||||
expected=total_expected,
|
||||
observed=final_total,
|
||||
)
|
||||
|
||||
# Allow extra time for all entity spend aggregations
|
||||
await asyncio.sleep(5)
|
||||
|
||||
# Get final spend information for each entity
|
||||
key_info = await get_spend_info(session, "key", key)
|
||||
team_info = await get_spend_info(session, "team", team_id)
|
||||
user_info = await get_spend_info(session, "user", user_id)
|
||||
@ -341,19 +353,18 @@ async def test_long_term_spend_accuracy_with_bursts():
|
||||
print(f"Final user spend: {user_info['user_info']['spend']}")
|
||||
print(f"Final org spend: {org_info['spend']}")
|
||||
|
||||
# Verify total spend for each entity
|
||||
assert (
|
||||
abs(key_info["info"]["spend"] - expected_spend) < TOLERANCE
|
||||
), f"Key spend {key_info['info']['spend']} does not match expected {expected_spend}"
|
||||
abs(key_info["info"]["spend"] - total_expected) < TOLERANCE
|
||||
), f"Key spend {key_info['info']['spend']} does not match expected {total_expected}"
|
||||
|
||||
assert (
|
||||
abs(user_info["user_info"]["spend"] - expected_spend) < TOLERANCE
|
||||
), f"User spend {user_info['user_info']['spend']} does not match expected {expected_spend}"
|
||||
abs(user_info["user_info"]["spend"] - total_expected) < TOLERANCE
|
||||
), f"User spend {user_info['user_info']['spend']} does not match expected {total_expected}"
|
||||
|
||||
assert (
|
||||
abs(team_info["team_info"]["spend"] - expected_spend) < TOLERANCE
|
||||
), f"Team spend {team_info['team_info']['spend']} does not match expected {expected_spend}"
|
||||
abs(team_info["team_info"]["spend"] - total_expected) < TOLERANCE
|
||||
), f"Team spend {team_info['team_info']['spend']} does not match expected {total_expected}"
|
||||
|
||||
assert (
|
||||
abs(org_info["spend"] - expected_spend) < TOLERANCE
|
||||
), f"Organization spend {org_info['spend']} does not match expected {expected_spend}"
|
||||
abs(org_info["spend"] - total_expected) < TOLERANCE
|
||||
), f"Organization spend {org_info['spend']} does not match expected {total_expected}"
|
||||
|
||||
@ -87,6 +87,89 @@ async def test_store_in_memory_spend_updates_uses_pipeline(
|
||||
assert len(rpush_list) == 3
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_store_in_memory_spend_updates_restores_on_rpush_failure(
    redis_update_buffer, mock_redis_cache
):
    """
    If async_rpush_pipeline raises, the already-drained transactions must be
    put back into the in-memory queues so a later flush can retry them.
    Without this, any transient Redis hiccup silently loses spend data.

    The test seeds real queues, makes the Redis pipeline raise, and then
    checks that (a) the push was actually attempted and (b) flushing the
    queues again returns exactly the seeded, aggregated data.
    """
    from litellm.proxy._types import Litellm_EntityType
    from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import (
        DailySpendUpdateQueue,
    )
    from litellm.proxy.db.db_transaction_queue.spend_update_queue import (
        SpendUpdateQueue,
    )

    mock_redis_cache.async_rpush_pipeline = AsyncMock(
        side_effect=ConnectionError("redis went away")
    )

    spend_queue = SpendUpdateQueue()
    daily_user_queue = DailySpendUpdateQueue()
    daily_team_queue = DailySpendUpdateQueue()
    daily_org_queue = DailySpendUpdateQueue()
    daily_end_user_queue = DailySpendUpdateQueue()
    daily_agent_queue = DailySpendUpdateQueue()

    # Seed real queues with data so flush_and_get_aggregated returns it
    await spend_queue.add_update(
        {
            "entity_type": Litellm_EntityType.KEY,
            "entity_id": "key-abc",
            "response_cost": 1.5,
        }
    )
    await spend_queue.add_update(
        {
            "entity_type": Litellm_EntityType.TEAM,
            "entity_id": "team-xyz",
            "response_cost": 2.5,
        }
    )
    await daily_user_queue.add_update(
        {
            "user1_day_model": {
                "spend": 1.0,
                "prompt_tokens": 10,
                "completion_tokens": 20,
            }
        }
    )

    # Must not raise: the failure is handled by restoring the queues.
    await redis_update_buffer.store_in_memory_spend_updates_in_redis(
        spend_update_queue=spend_queue,
        daily_spend_update_queue=daily_user_queue,
        daily_team_spend_update_queue=daily_team_queue,
        daily_org_spend_update_queue=daily_org_queue,
        daily_end_user_spend_update_queue=daily_end_user_queue,
        daily_agent_spend_update_queue=daily_agent_queue,
    )

    # Guard against a vacuous pass: if the method had returned before draining
    # the queues, the assertions below would still hold on the untouched data.
    mock_redis_cache.async_rpush_pipeline.assert_awaited_once()

    # After restore, the main spend queue should hold one item per
    # (entity_type, entity_id) pair with the aggregated cost
    restored_spend = (
        await spend_queue.flush_and_get_aggregated_db_spend_update_transactions()
    )
    assert restored_spend["key_list_transactions"] == {"key-abc": 1.5}
    assert restored_spend["team_list_transactions"] == {"team-xyz": 2.5}

    # Daily user queue should hold the same aggregated dict
    restored_daily = (
        await daily_user_queue.flush_and_get_aggregated_daily_spend_update_transactions()
    )
    assert restored_daily == {
        "user1_day_model": {
            "spend": 1.0,
            "prompt_tokens": 10,
            "completion_tokens": 20,
        }
    }
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_store_in_memory_spend_updates_all_empty_returns_early(
|
||||
redis_update_buffer, mock_redis_cache
|
||||
|
||||
Loading…
Reference in New Issue
Block a user