* feat(xai): add grok-4.20 beta 2 models with pricing (#23900)
Add three grok-4.20 beta 2 model variants from xAI:
- grok-4.20-multi-agent-beta-0309 (reasoning + multi-agent)
- grok-4.20-beta-0309-reasoning (reasoning)
- grok-4.20-beta-0309-non-reasoning
Pricing (from https://docs.x.ai/docs/models):
- Input: $2.00/1M tokens ($0.20/1M cached)
- Output: $6.00/1M tokens
- Context: 2M tokens
All variants support vision, function calling, tool choice, and web search.
Closes LIT-2171
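This sketch turns the listed rates into a per-request cost. The function name is hypothetical and is not part of LiteLLM. It assumes cached input tokens are a subset of the input tokens and are billed at the cached rate instead of the full input rate.

```python
# Illustrative per-token rates derived from the grok-4.20 beta 2 prices above (USD).
INPUT_COST_PER_TOKEN = 2.00 / 1_000_000
CACHED_INPUT_COST_PER_TOKEN = 0.20 / 1_000_000
OUTPUT_COST_PER_TOKEN = 6.00 / 1_000_000


def estimate_request_cost(
    input_tokens: int, output_tokens: int, cached_tokens: int = 0
) -> float:
    """Hypothetical helper: estimate the USD cost of one request.

    Assumes `cached_tokens` is the cached portion of `input_tokens`.
    """
    if not 0 <= cached_tokens <= input_tokens:
        raise ValueError("cached_tokens must be between 0 and input_tokens")
    uncached_tokens = input_tokens - cached_tokens
    return (
        uncached_tokens * INPUT_COST_PER_TOKEN
        + cached_tokens * CACHED_INPUT_COST_PER_TOKEN
        + output_tokens * OUTPUT_COST_PER_TOKEN
    )


# 100k input tokens (40k served from cache) plus 10k output tokens:
# 60k * $2/1M + 40k * $0.20/1M + 10k * $6/1M = $0.12 + $0.008 + $0.06
print(f"${estimate_request_cost(100_000, 10_000, cached_tokens=40_000):.3f}")  # $0.188
```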
* docs: add Quick Install section for litellm --setup wizard (#23905)
* docs: add Quick Install section for litellm --setup wizard
* docs: clarify setup wizard is for local/beginner use
* feat(setup): interactive setup wizard + install.sh (#23644)
* feat(setup): add interactive setup wizard + install.sh
Adds `litellm --setup` — a Claude Code-style TUI onboarding wizard that
guides users through provider selection, API key entry, and proxy config
generation, then optionally starts the proxy immediately.
- litellm/setup_wizard.py: wizard with ASCII art, numbered provider menu
(OpenAI, Anthropic, Azure, Gemini, Bedrock, Ollama), API key prompts,
port/master-key config, and litellm_config.yaml generation
- litellm/proxy/proxy_cli.py: adds --setup flag that invokes the wizard
- scripts/install.sh: curl-installable script (detect OS/Python, pip
install litellm[proxy], launch wizard)
Usage:
curl -fsSL https://raw.githubusercontent.com/BerriAI/litellm/main/scripts/install.sh | sh
litellm --setup
* fix(install.sh): remove orange color, add LITELLM_BRANCH env var for branch installs
* fix(install.sh): install from git branch so --setup is available for QA
* fix(install.sh): remove stale LITELLM_BRANCH reference that caused unbound variable error
* fix(install.sh): force-reinstall from git to bypass cached PyPI version
* fix(install.sh): show pip progress bar during install
* fix(install.sh): always launch wizard via $PYTHON_BIN -m litellm, not PATH binary
* fix(install.sh): use litellm.proxy.proxy_cli module (no __main__.py exists)
* fix(install.sh): suppress RuntimeWarning from module invocation
* fix(install.sh): use Python bin-dir litellm binary to avoid CWD sys.path shadowing
* fix(install.sh): use sysconfig.get_path('scripts') to find pip-installed litellm binary
* fix(install.sh): redirect stdin from /dev/tty on exec so wizard gets terminal, not exhausted pipe
* fix(install.sh): warn about git clone duration, drop --no-cache-dir so re-runs are faster
* feat(setup_wizard): arrow-key selector, updated model names
* fix(setup_wizard): use sysconfig binary to start proxy, not python -m litellm
* feat(setup_wizard): credential validation after key entry + clear next-steps after proxy start
* style(install.sh): show git clone warning in blue
* refactor(setup_wizard): class with static methods, use check_valid_key from litellm.utils
* address greptile review: fix yaml escaping, port validation, display name collisions, tests
- setup_wizard.py: add _yaml_escape() for safe YAML embedding of API keys
- setup_wizard.py: add _styled_input() with readline ANSI ignore markers
- setup_wizard.py: change DIVIDER to _divider() fn to avoid import-time color capture
- setup_wizard.py: validate port range 1-65535, initialize before loop
- setup_wizard.py: qualify azure display names (azure-gpt-4o) to avoid collision with openai
- setup_wizard.py: work on env_copy in _build_config to avoid mutating caller's dict
- setup_wizard.py: skip model_list entries for providers with no credentials
- setup_wizard.py: prompt for azure deployment name
- setup_wizard.py: wrap os.execlp in try/except with friendly fallback
- setup_wizard.py: wrap config write in try/except OSError
- setup_wizard.py: fix _validate_and_report to use two print lines (no \r overwrite)
- setup_wizard.py: add .gitignore tip next to key storage notice
- setup_wizard.py: fix run_setup_wizard() return type annotation to None
- scripts/install.sh: drop pipefail (not supported by dash on Ubuntu when invoked as sh)
- scripts/install.sh: use litellm[proxy] from PyPI (not hardcoded dev branch)
- scripts/install.sh: guard /dev/tty read with -r check for Docker/CI compat
- scripts/install.sh: remove --force-reinstall to avoid downgrading dependencies
- tests/test_litellm/test_setup_wizard.py: 13 unit tests for _build_config and _yaml_escape
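The port fix above (validate 1-65535 and initialize the variable before the loop) follows a common prompt-loop pattern, sketched below. Three details are assumptions for illustration and not the wizard's actual code: the function name, the injectable `read_line` callable, and the default of 4000 (the port the proxy commonly listens on). In an interactive run, `read_line` would simply be `input`.

```python
from typing import Callable


def prompt_port(read_line: Callable[[str], str], default: int = 4000) -> int:
    """Hypothetical sketch: re-prompt until the answer is blank or a port in 1-65535."""
    port = default  # bound before the loop, so a blank answer keeps the default
    while True:
        raw = read_line(f"Proxy port [{default}]: ").strip()
        if not raw:
            break
        try:
            candidate = int(raw)
        except ValueError:
            candidate = 0  # non-numeric input is treated as out of range below
        if 1 <= candidate <= 65535:
            port = candidate
            break
        print("Port must be a number between 1 and 65535.")
    return port


# Simulate a user who types two invalid values before a valid one.
answers = iter(["abc", "70000", "8080"])
print(prompt_port(lambda _prompt: next(answers)))  # prints two warnings, then 8080
```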
* style: black format setup_wizard.py
* fix: address remaining greptile issues - Windows compat, YAML quoting, credential flow
- guard termios/tty imports with try/except ImportError for Windows compat
- quote master_key as YAML double-quoted scalar (same as env vars)
- remove unused port param from _build_config signature
- _validate_and_report now returns the final key so re-entered creds are stored
- add test for master_key YAML quoting
* fix: add --port to suggested command, guard /dev/tty exec in install.sh
* fix: quote api_base in YAML, skip azure if no deployment, only redraw on state change
* fix: address greptile review comments
- _yaml_escape: add control character escaping (\n, \r, \t)
- test: fix tautological assertion in test_build_config_azure_no_deployment_skipped
- test: add tests for control character escaping in _yaml_escape
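The commit describes `_yaml_escape` as embedding API keys in YAML double-quoted scalars and escaping `\n`, `\r`, and `\t`. This sketch shows that approach under the assumption that the real helper works the same way; the function name is a hypothetical stand-in. Control characters other than those three are not handled here. YAML would need `\xNN` escapes for them.

```python
def yaml_double_quote(value: str) -> str:
    """Hypothetical stand-in for _yaml_escape: wrap value as a YAML double-quoted scalar."""
    escaped = (
        value.replace("\\", "\\\\")  # backslashes first, so later escapes are not doubled
        .replace('"', '\\"')
        .replace("\n", "\\n")
        .replace("\r", "\\r")
        .replace("\t", "\\t")
    )
    return f'"{escaped}"'


api_key = 'sk-"test"\\x'
print(f"api_key: {yaml_double_quote(api_key)}")  # api_key: "sk-\"test\"\\x"
```

Escaping backslashes before anything else matters. Otherwise the backslash added in front of an escaped quote would itself be doubled.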
* feat(ui): remove Chat UI page link and banner from sidebar and playground (#23908)
* feat(guardrails): MCPJWTSigner - built-in guardrail for zero trust MCP auth (#23897)
* Allow pre_mcp_call guardrail hooks to mutate outbound MCP headers
* Enhance MCPServerManager to support hook-modified arguments and extra headers. Update tests to validate argument mutation and header injection behavior, including warnings for OpenAPI-backed servers when headers are present.
* Refactor MCPServerManager to raise HTTPException for extra headers in OpenAPI-backed servers. Update tests to reflect this change, ensuring proper exception handling instead of logging warnings.
* feat(guardrails): add MCPJWTSigner built-in guardrail for zero trust MCP auth
Signs outbound MCP tool calls with a LiteLLM-issued RS256 JWT so MCP servers
can trust a single signing authority instead of every upstream IdP.
Enable in config.yaml:
  guardrails:
    - guardrail_name: mcp-jwt-signer
      litellm_params:
        guardrail: mcp_jwt_signer
        mode: pre_mcp_call
        default_on: true
JWT carries sub (user_id), act.sub (team_id, RFC 8693), tool-level scope, iss,
aud, iat/exp/nbf. RSA-2048 keypair auto-generated at startup unless
MCP_JWT_SIGNING_KEY env var is set.
Adds /.well-known/jwks.json endpoint and jwks_uri to /.well-known/openid-configuration
so MCP servers can verify LiteLLM-issued tokens via OIDC discovery.
* Update MCPServerManager to raise HTTPException with status code 400 for extra headers in OpenAPI-backed servers. Adjust tests to verify the correct status code and exception message.
* fix: address P1 issues in MCPJWTSigner
- OpenAPI servers: warn + skip header injection instead of 500
- JWKS Cache-Control: 5min for auto-generated keys, 1h for persistent
- sub claim: fallback to apikey:{token_hash} for anonymous callers
- ttl_seconds: validate > 0 at init time
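This sketch assembles the claim set described for MCPJWTSigner, together with the `apikey:{token_hash}` fallback and the `ttl_seconds > 0` check from the P1 fixes. Several parts are hypothetical: the function name, the default TTL of 300 seconds, and the scope string format. The dict is left unsigned to keep the sketch dependency-free. With PyJWT, signing it would look like `jwt.encode(claims, private_key, algorithm="RS256", headers={"kid": kid})`.

```python
import time
from typing import Any, Dict, Optional


def build_mcp_claims(
    *,
    issuer: str,
    audience: str,
    tool_name: str,
    user_id: Optional[str],
    team_id: Optional[str],
    token_hash: str,
    ttl_seconds: int = 300,  # hypothetical default
    now: Optional[int] = None,
) -> Dict[str, Any]:
    """Hypothetical sketch of the claim set described for MCPJWTSigner (unsigned)."""
    if ttl_seconds <= 0:
        raise ValueError("ttl_seconds must be > 0")
    issued_at = int(time.time()) if now is None else now
    claims: Dict[str, Any] = {
        "iss": issuer,
        "aud": audience,
        # Anonymous callers fall back to a key-derived subject.
        "sub": user_id or f"apikey:{token_hash}",
        "iat": issued_at,
        "nbf": issued_at,
        "exp": issued_at + ttl_seconds,
        # Assumed tool-level scope format; the real scope strings may differ.
        "scope": f"mcp:tools/call:{tool_name}",
    }
    if team_id:
        claims["act"] = {"sub": team_id}  # RFC 8693 actor claim
    return claims


claims = build_mcp_claims(
    issuer="https://litellm.example.com",
    audience="mcp-servers",
    tool_name="search",
    user_id=None,
    team_id="team-1",
    token_hash="3f2a",
    ttl_seconds=60,
    now=1_700_000_000,
)
print(claims["sub"], claims["exp"])  # apikey:3f2a 1700000060
```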
* docs: add MCP zero trust auth guide with architecture diagram
* docs: add FastMCP JWT verification guide to zero trust doc
* fix: address remaining Greptile review issues (round 2)
- mcp_server_manager: warn when hook Authorization overwrites existing header
- __init__: remove _mcp_jwt_signer_instance from __all__ (private internal)
- discoverable_endpoints: copy dict instead of mutating in-place on OIDC augmentation
- test docstring: reflect warn-and-continue behavior for OpenAPI servers
- test: update scope assertions for least-privilege (no mcp:tools/list on tool-call JWTs)
* fix: address Greptile round 3 feedback
- initialize_guardrail: validate mode='pre_mcp_call' at init time — misconfigured
mode silently bypasses JWT injection, which is a zero-trust bypass
- _build_claims: remove duplicate inline 'import re' (module-level import already present)
- _types.py: add TODO comment explaining jwt_claims is forward-compat plumbing
for a follow-up PR that will forward upstream IdP claims into outbound MCP JWTs
* feat(mcp_jwt_signer): add verify+re-sign, claim ops, two-token model, configurable scopes
Addresses all missing pieces from the scoping doc review:
FR-5 (Verify + re-sign): MCPJWTSigner now accepts access_token_discovery_uri
and token_introspection_endpoint. When set, the incoming Bearer token is
extracted from raw_headers (threaded through pre_call_tool_check), verified
against the IdP's JWKS (JWT) or introspected (opaque), and only re-signed if
valid. Falls back to user_api_key_dict.jwt_claims for LiteLLM JWT-auth mode.
FR-12 (Configurable end-user identity mapping): end_user_claim_sources
ordered list drives sub resolution — sources: token:<claim>, litellm:user_id,
litellm:email, litellm:end_user_id, litellm:team_id.
FR-13 (Claim operations): add_claims (insert-if-absent), set_claims (always
override), remove_claims (delete) applied in that order.
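FR-12 and FR-13 describe two small, order-sensitive algorithms. The sketch below shows one way to implement them, under several assumptions:

- The function names are hypothetical.
- LiteLLM-side identity is passed as a plain dict; the real code reads it from the authenticated key.
- A source resolves only when its value is non-empty.
- `token:` sources are only meaningful once the incoming token has been verified (FR-5).

```python
from typing import Any, Dict, Iterable, Mapping, Optional


def resolve_sub(
    sources: Iterable[str],
    token_claims: Mapping[str, Any],
    litellm_identity: Mapping[str, Any],
) -> Optional[str]:
    """Return the first non-empty value from an ordered end_user_claim_sources list."""
    for source in sources:
        kind, _, name = source.partition(":")
        if kind == "token":
            value = token_claims.get(name)
        elif kind == "litellm":
            value = litellm_identity.get(name)
        else:
            continue  # unknown source kinds are skipped in this sketch
        if value:
            return str(value)
    return None


def apply_claim_ops(
    claims: Mapping[str, Any],
    add_claims: Optional[Mapping[str, Any]] = None,
    set_claims: Optional[Mapping[str, Any]] = None,
    remove_claims: Optional[Iterable[str]] = None,
) -> Dict[str, Any]:
    """Apply add (insert-if-absent), then set (override), then remove."""
    result = dict(claims)  # copy so the caller's claims are not mutated
    for name, value in (add_claims or {}).items():
        result.setdefault(name, value)
    result.update(set_claims or {})
    for name in remove_claims or []:
        result.pop(name, None)
    return result


identity = {"user_id": "u-42", "email": "dev@example.com"}
print(resolve_sub(["token:email", "litellm:user_id"], {}, identity))  # u-42
```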
FR-14 (Two-token model): channel_token_audience + channel_token_ttl issue a
second JWT injected as x-mcp-channel-token: Bearer <token>.
FR-15 (Incoming claim validation): required_claims raises HTTP 403 when any
listed claim is absent; optional_claims passes listed claims from verified
token into the outbound JWT.
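This is a minimal sketch of the FR-15 check, assuming that "absent" means the claim key is missing from the verified token. The exception class and function name are hypothetical. The exception only carries the 403 status that the signer is described as returning.

```python
from typing import Any, Dict, Iterable, Mapping


class MissingClaimsError(Exception):
    """Hypothetical stand-in for the HTTP 403 response described in FR-15."""

    status_code = 403


def check_incoming_claims(
    token_claims: Mapping[str, Any],
    required_claims: Iterable[str],
    optional_claims: Iterable[str],
) -> Dict[str, Any]:
    """Reject tokens missing any required claim; return optional claims to forward."""
    missing = [name for name in required_claims if name not in token_claims]
    if missing:
        raise MissingClaimsError(f"Missing required claims: {', '.join(missing)}")
    return {name: token_claims[name] for name in optional_claims if name in token_claims}


verified = {"sub": "alice", "groups": ["eng"]}
print(check_incoming_claims(verified, ["sub"], ["groups", "department"]))
# {'groups': ['eng']}
```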
FR-9 (Debug headers): debug_headers: true emits x-litellm-mcp-debug with kid,
sub, iss, exp, scope.
FR-10 (Configurable scopes): allowed_scopes replaces auto-generation. Also
fixed: tool-call JWTs no longer grant mcp:tools/list (overpermission).
P1 fixes:
- proxy/utils.py: _convert_mcp_hook_response_to_kwargs merges rather than
replaces extra_headers, preserving headers from prior guardrails.
- mcp_server_manager.py: warns when hook injects Authorization alongside a
server-configured authentication_token (previously silent).
- mcp_server_manager.py: pre_call_tool_check now accepts raw_headers and
extracts incoming_bearer_token so FR-5 verification has the raw token.
- proxy/utils.py: remove stray inline import inspect inside loop (pre-existing
lint error, now cleaned up).
Tests: 43 passing (28 new tests covering all FR flags + P1 fixes).
* feat(mcp_jwt_signer): add verify+re-sign, claim ops, two-token model, configurable scopes (core)
Remaining files from the FR implementation:
mcp_jwt_signer.py — full rewrite with all new params:
FR-5: access_token_discovery_uri, token_introspection_endpoint,
verify_issuer, verify_audience + _verify_incoming_jwt(),
_introspect_opaque_token()
FR-12: end_user_claim_sources ordered resolution chain
FR-13: add_claims, set_claims, remove_claims
FR-14: channel_token_audience, channel_token_ttl → x-mcp-channel-token
FR-15: required_claims (raises 403), optional_claims (passthrough)
FR-9: debug_headers → x-litellm-mcp-debug
FR-10: allowed_scopes; tool-call JWTs no longer over-grant tools/list
mcp_server_manager.py:
- pre_call_tool_check gains raw_headers param to extract incoming_bearer_token
- Silent Authorization override warning fixed: now fires when server has
authentication_token AND hook injects Authorization
tests/test_mcp_jwt_signer.py:
28 new tests covering all FR flags + P1 fixes (43 total, all passing)
* fix(mcp_jwt_signer): address pre-landing review issues
- Remove stale TODO comment on UserAPIKeyAuth.jwt_claims — the field is
already populated and consumed by MCPJWTSigner in the same PR
- Fix _get_oidc_discovery to only cache the OIDC discovery doc when
jwks_uri is present; a malformed/empty doc now retries on the next
request instead of being permanently cached until proxy restart
- Add FR-5 test coverage for _fetch_jwks (cache hit/miss),
_get_oidc_discovery (cache/no-cache on bad doc), _verify_incoming_jwt
(valid token, expired token), _introspect_opaque_token (active,
inactive, no endpoint), and the end-to-end 401 hook path — 53 tests
total, all passing
* docs(mcp_zero_trust): rewrite as use-case guide covering all new JWT signer features
Add scenario-driven sections for each new config area:
- Verify+re-sign with Okta/Azure AD (access_token_discovery_uri,
end_user_claim_sources, token_introspection_endpoint)
- Enforcing caller attributes with required_claims / optional_claims
- Adding metadata via add_claims / set_claims / remove_claims
- Two-token model for AWS Bedrock AgentCore Gateway
(channel_token_audience / channel_token_ttl)
- Controlling scopes with allowed_scopes
- Debugging JWT rejections with debug_headers
Update JWT claims table to reflect configurable sub (end_user_claim_sources)
* fix(mcp_jwt_signer): wire all config.yaml params through initialize_guardrail
The factory was only passing issuer/audience/ttl_seconds to MCPJWTSigner.
All FR-5/9/10/12/13/14/15 params (access_token_discovery_uri,
end_user_claim_sources, add/set/remove_claims, channel_token_audience,
required/optional_claims, debug_headers, allowed_scopes, etc.) were
silently dropped, making every advertised advanced feature non-functional
when loaded from config.yaml.
Add regression test that asserts every param is wired through correctly.
* docs(mcp_zero_trust): add hero image
* docs(mcp_zero_trust): apply Linear-style edits
- Lead with the problem (unsigned direct calls bypass access controls)
- Shorter statement section headers instead of question-form headers
- Move diagram/OIDC discovery block after the reader is bought in
- Add 'read further only if you need to' callout after basic setup
- Two-token section now opens from the user problem not product jargon
- Add concrete 403 error response example in required_claims section
- Debug section opens from the symptom (MCP server returning 401)
- Lowercase claims reference header for consistency
* fix(mcp_jwt_signer): fix algorithm confusion attack + add OIDC discovery 24h TTL
- Stop reading alg from the unverified JWT header; use signing_jwk.algorithm_name from the matching JWKS key instead.
Trusting alg from an attacker-controlled header enables alg:none and RS256/HS256 key-confusion attacks.
- Add _oidc_discovery_fetched_at timestamp and _OIDC_DISCOVERY_TTL = 86400 (24h).
Without a TTL the cached discovery doc never refreshes, so IdP key rotation is invisible.
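This PR contains two discovery-cache fixes:

- Cache the discovery document only when `jwks_uri` is present.
- Refresh it after `_OIDC_DISCOVERY_TTL = 86400`.

Together they form the small pattern below. The class is a hypothetical sketch, not the signer's actual `_get_oidc_discovery`. It takes an injected fetcher and clock so it can be exercised without network access. It defaults to `time.monotonic`, so wall-clock adjustments do not affect expiry.

```python
import time
from typing import Any, Callable, Dict, Optional

_OIDC_DISCOVERY_TTL = 86400  # 24h, value taken from the commit message


class OIDCDiscoveryCache:
    """Hypothetical sketch: cache a usable OIDC discovery doc for a fixed TTL."""

    def __init__(
        self,
        fetch: Callable[[], Dict[str, Any]],
        clock: Callable[[], float] = time.monotonic,
        ttl: float = _OIDC_DISCOVERY_TTL,
    ) -> None:
        self._fetch = fetch
        self._clock = clock
        self._ttl = ttl
        self._doc: Optional[Dict[str, Any]] = None
        self._fetched_at = 0.0

    def get(self) -> Dict[str, Any]:
        if self._doc is not None and self._clock() - self._fetched_at < self._ttl:
            return self._doc  # fresh cached copy
        doc = self._fetch()
        if doc.get("jwks_uri"):
            # Only cache usable docs; an empty or malformed doc is refetched next call.
            self._doc = doc
            self._fetched_at = self._clock()
        return doc
```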
---------
Co-authored-by: Noah Nistler <60981020+noahnistler@users.noreply.github.com>
* fix(ci): stabilize CI - formatting, type errors, test polling, security CVEs, router bug, batch resolution
Fix 1: Run Black formatter on 35 files
Fix 2: Fix MyPy type errors:
- setup_wizard.py: add type annotation for 'selected' set variable
- user_api_key_auth.py: remove redundant type annotation on jwt_claims reassignment
Fix 3: Fix spend accuracy test burst 2 polling to wait for expected total
spend instead of just 'any increase' from burst 2
Fix 4: Bump Next.js 16.1.6 -> 16.1.7 to fix CVE-2026-27978, CVE-2026-27979,
CVE-2026-27980, CVE-2026-29057
Fix 5: Fix router _pre_call_checks model variable being overwritten inside
loop, causing wrong model lookups on subsequent deployments. Use local
_deployment_model variable instead.
Fix 6: Add missing resolve_output_file_ids_to_unified call in batch retrieve
non-terminal-to-terminal path (matching the terminal path behavior)
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
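Fix 5 is an instance of a general Python pitfall: reassigning a function parameter inside a loop changes what later iterations see. The toy functions below reproduce the pattern and the `_deployment_model` fix. Their names and data are hypothetical, not the actual `_pre_call_checks` code.

```python
from typing import Dict, List


def lookup_limits_buggy(model: str, deployments: List[Dict[str, str]], limits: Dict[str, int]) -> List[int]:
    results = []
    for deployment in deployments:
        # BUG: rebinding the outer `model` leaks into the next iteration, so a
        # deployment without its own "model" reuses the previous deployment's value.
        model = deployment.get("model") or model
        results.append(limits[model])
    return results


def lookup_limits_fixed(model: str, deployments: List[Dict[str, str]], limits: Dict[str, int]) -> List[int]:
    results = []
    for deployment in deployments:
        # Fix: a per-iteration local leaves the caller's `model` untouched.
        _deployment_model = deployment.get("model") or model
        results.append(limits[_deployment_model])
    return results


limits = {"my-group": 8_000, "azure/gpt-4o": 128_000}  # illustrative numbers only
deployments = [{"model": "azure/gpt-4o"}, {}]
print(lookup_limits_buggy("my-group", deployments, limits))  # [128000, 128000]
print(lookup_limits_fixed("my-group", deployments, limits))  # [128000, 8000]
```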
* chore: regenerate poetry.lock to sync with pyproject.toml
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
* fix: format merged files from main and regenerate poetry.lock
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
* fix(mypy): annotate jwt_claims as Optional[dict] to fix type incompatibility
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
* fix(ci): update router region test to use gpt-4.1-mini (fix flaky model lookup)
Replace deprecated gpt-3.5-turbo-1106 with gpt-4.1-mini + mock_response in
test_router_region_pre_call_check, following the same pattern used in commit
717d37cc5b for test_router_context_window_check_pre_call_check_out_group.
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
* ci: retry flaky logging_testing (async event loop race condition)
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
* fix(ci): aggregate all mock calls in langfuse e2e test to fix race condition
The _verify_langfuse_call helper only inspected the last mock call
(mock_post.call_args), but the Langfuse SDK may split trace-create and
generation-create events across separate HTTP flush cycles. This caused
an IndexError when the last call's batch contained only one event type.
Fix: iterate over mock_post.call_args_list to collect batch items from
ALL calls. Also add a safety assertion after filtering by trace_id and
mark all langfuse e2e tests with @pytest.mark.flaky(retries=3) as an
extra safety net for any residual timing issues.
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
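This is a minimal sketch of aggregating batch events across every recorded call rather than only `mock_post.call_args`. It rests on two assumptions:

- The mocked POST receives the ingestion payload as a `json=` keyword.
- Trace-create bodies carry the trace id as `id`, while other events reference it as `traceId`.

The helper names are hypothetical. `call.kwargs` requires Python 3.8 or newer.

```python
from typing import Any, Dict, List
from unittest.mock import MagicMock


def _trace_id_of(event: Dict[str, Any]) -> Any:
    body = event.get("body", {})
    # Assumed payload shape: trace-create uses "id", other events use "traceId".
    return body.get("traceId", body.get("id"))


def collect_batch_items(mock_post: MagicMock, trace_id: str) -> List[Dict[str, Any]]:
    """Gather batch events from every recorded POST, not just the last one."""
    items: List[Dict[str, Any]] = []
    for recorded_call in mock_post.call_args_list:
        payload = recorded_call.kwargs.get("json") or {}
        items.extend(payload.get("batch", []))
    matching = [item for item in items if _trace_id_of(item) == trace_id]
    assert matching, f"no Langfuse events found for trace_id={trace_id}"
    return matching


# Two flush cycles split the events for trace "t1" across separate calls.
mock_post = MagicMock()
mock_post(json={"batch": [{"type": "trace-create", "body": {"id": "t1"}}]})
mock_post(json={"batch": [{"type": "generation-create", "body": {"traceId": "t1"}}]})
print([e["type"] for e in collect_batch_items(mock_post, "t1")])
```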
* fix(ci): black formatting + update OpenAPI compliance tests for spec changes
- Apply Black 26.x formatting to litellm_logging.py (parenthesized style)
- Update test_input_types_match_spec to follow $ref to InteractionsInput schema
(Google updated their OpenAPI spec to use $ref instead of inline oneOf)
- Update test_content_schema_uses_discriminator to handle discriminator without
explicit mapping (Google removed the mapping key from Content discriminator)
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
* revert: undo incorrect Black 26.x formatting on litellm_logging.py
The file was correctly formatted for Black 23.12.1 (the version pinned
in pyproject.toml). The previous commit applied Black 26.x formatting
which was incompatible with the CI's Black version.
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
* fix(ci): deduplicate and sort langfuse batch events after aggregation
The Langfuse SDK may send the same event (e.g., trace-create) in
multiple flush cycles, causing duplicates when we aggregate from all
mock calls. After filtering by trace_id, deduplicate by keeping only
the first event of each type, then sort to ensure trace-create is at
index 0 and generation-create at index 1.
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
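The dedupe-then-sort step can be sketched as follows. It assumes each event dict has a `type` key. The `EVENT_ORDER` mapping and the function name are hypothetical.

```python
from typing import Any, Dict, List

# Desired position of each event type after sorting (trace-create first).
EVENT_ORDER = {"trace-create": 0, "generation-create": 1}


def dedupe_and_sort(events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep the first event of each type, then order them by EVENT_ORDER."""
    first_by_type: Dict[str, Dict[str, Any]] = {}
    for event in events:
        first_by_type.setdefault(event["type"], event)
    # Unknown types sort after the known ones.
    return sorted(
        first_by_type.values(),
        key=lambda event: EVENT_ORDER.get(event["type"], len(EVENT_ORDER)),
    )


events = [
    {"type": "generation-create", "id": "g1"},
    {"type": "trace-create", "id": "t1"},
    {"type": "trace-create", "id": "t1-resent"},  # duplicate from a later flush
]
print([e["id"] for e in dedupe_and_sort(events)])  # ['t1', 'g1']
```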
---------
Co-authored-by: Noah Nistler <60981020+noahnistler@users.noreply.github.com>
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
import pytest
import asyncio
import aiohttp
import json
import time
from httpx import AsyncClient
from typing import Any, Optional
from litellm._uuid import uuid

"""
|
|
Tests to run
|
|
|
|
Basic Tests:
|
|
1. Basic Spend Accuracy Test:
|
|
- Make 1 calibration request, poll for spend to derive SPEND_PER_REQUEST
|
|
- Make N-1 more requests (N total)
|
|
- Expect the spend for each of the following to be N * SPEND_PER_REQUEST
|
|
Key, Team, User, Org (call /info endpoint for each object to validate)
|
|
|
|
2. Long term spend accuracy test (with 2 bursts of requests)
|
|
- Burst 1: Make requests, derive SPEND_PER_REQUEST from first request
|
|
- Burst 2: Make more requests
|
|
- Verify total spend = (burst1 + burst2) * SPEND_PER_REQUEST
|
|
|
|
Additional Test Scenarios:
|
|
|
|
3. Concurrent Request Accuracy Test:
|
|
- Make 20 concurrent requests
|
|
- Check for race conditions in spend tracking
|
|
|
|
4. Error Case Test:
|
|
- Make 10 successful requests
|
|
- Make 5 failed requests
|
|
- Verify spend is only counted for successful requests
|
|
|
|
5. Mixed Request Type Test:
|
|
- Make different types of requests with varying costs
|
|
- Verify accurate total spend calculation
|
|
"""
|
|
|
|
|
|
async def create_organization(session, organization_alias: str):
    """Helper function to create a new organization"""
    url = "http://0.0.0.0:4000/organization/new"
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
    data = {"organization_alias": organization_alias}
    async with session.post(url, headers=headers, json=data) as response:
        return await response.json()


async def create_team(session, org_id: str):
    """Helper function to create a new team under an organization"""
    url = "http://0.0.0.0:4000/team/new"
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
    data = {"organization_id": org_id, "team_alias": f"test-team-{uuid.uuid4()}"}
    async with session.post(url, headers=headers, json=data) as response:
        return await response.json()


async def create_user(session, org_id: str):
    """Helper function to create a new user"""
    url = "http://0.0.0.0:4000/user/new"
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
    data = {"user_name": f"test-user-{uuid.uuid4()}"}
    async with session.post(url, headers=headers, json=data) as response:
        return await response.json()


async def generate_key(session, user_id: str, team_id: str):
    """Helper function to generate a key for a specific user and team"""
    url = "http://0.0.0.0:4000/key/generate"
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
    data = {"user_id": user_id, "team_id": team_id}
    async with session.post(url, headers=headers, json=data) as response:
        return await response.json()


async def chat_completion(session, key: str):
    """Make a chat completion request"""
    # `session` is unused here; the OpenAI SDK manages its own HTTP client.
    # `uuid` comes from the module-level import, so no inner re-import is needed.
    from openai import AsyncOpenAI

    client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000/v1")

    response = await client.chat.completions.create(
        model="fake-openai-endpoint",
        messages=[{"role": "user", "content": f"Test message {uuid.uuid4()}"}],
    )
    return response


async def get_spend_info(session, entity_type: str, entity_id: str):
    """Helper function to get spend information for an entity"""
    url = f"http://0.0.0.0:4000/{entity_type}/info"
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
    if entity_type == "key":
        data = {"key": entity_id}
    else:
        data = {f"{entity_type}_id": entity_id}

    async with session.get(url, headers=headers, params=data) as response:
        return await response.json()


async def poll_key_spend_until_nonzero(
    session, key: str, timeout: int = 120, interval: int = 10
):
    """Poll key spend until it becomes non-zero or timeout is reached."""
    start = time.time()
    while time.time() - start < timeout:
        key_info = await get_spend_info(session, "key", key)
        spend = key_info["info"]["spend"]
        if spend > 0:
            print(
                f"Key spend became non-zero ({spend}) after {time.time() - start:.1f}s"
            )
            return spend
        print(f"Key spend still 0.0, waiting... ({time.time() - start:.1f}s elapsed)")
        await asyncio.sleep(interval)
    raise TimeoutError(
        f"Key spend remained 0.0 after {timeout}s; batch writer may not be running"
    )


async def calibrate_spend_per_request(session, key: str, max_retries: int = 5):
    """
    Make a single calibration request and poll for its spend to derive SPEND_PER_REQUEST.
    Fails fast with pytest.fail() if spend cannot be determined.
    """
    response = await chat_completion(session, key)
    print(f"Calibration request completed: {response}")

    for attempt in range(1, max_retries + 1):
        try:
            spend = await poll_key_spend_until_nonzero(
                session, key, timeout=120, interval=10
            )
            print(
                f"Calibrated SPEND_PER_REQUEST = {spend} "
                f"(attempt {attempt}/{max_retries})"
            )
            return spend
        except TimeoutError:
            if attempt < max_retries:
                print(
                    f"Calibration attempt {attempt}/{max_retries} timed out, retrying..."
                )
            else:
                pytest.fail(
                    f"Failed to calibrate SPEND_PER_REQUEST after {max_retries} attempts. "
                    "The batch writer may not be running or the model may have 0 cost."
                )


@pytest.mark.asyncio
async def test_basic_spend_accuracy():
    """
    Test basic spend accuracy across different entities:
    1. Create org, team, user, and key
    2. Make 1 calibration request to derive SPEND_PER_REQUEST
    3. Make remaining requests (NUM_LLM_REQUESTS total)
    4. Verify spend accuracy for key, team, user, and org
    """
    NUM_LLM_REQUESTS = 20
    TOLERANCE = 1e-10

    async with aiohttp.ClientSession() as session:
        # Create organization
        org_response = await create_organization(
            session=session, organization_alias=f"test-org-{uuid.uuid4()}"
        )
        print("org_response: ", org_response)
        org_id = org_response["organization_id"]

        # Create team under organization
        team_response = await create_team(session, org_id)
        print("team_response: ", team_response)
        team_id = team_response["team_id"]

        # Create user
        user_response = await create_user(session, org_id)
        print("user_response: ", user_response)
        user_id = user_response["user_id"]

        # Generate key
        key_response = await generate_key(session, user_id, team_id)
        print("key_response: ", key_response)
        key = key_response["key"]

        # Calibrate: make 1 request and derive SPEND_PER_REQUEST
        spend_per_request = await calibrate_spend_per_request(session, key)
        expected_spend = NUM_LLM_REQUESTS * spend_per_request
        print(
            f"SPEND_PER_REQUEST={spend_per_request}, expected_spend={expected_spend}"
        )

        # Make remaining requests (1 already made during calibration)
        for i in range(NUM_LLM_REQUESTS - 1):
            response = await chat_completion(session, key)
            print(f"Request {i + 2}/{NUM_LLM_REQUESTS} completed")

        # Poll until batch writer has flushed all spend
        start = time.time()
        while time.time() - start < 120:
            key_info = await get_spend_info(session, "key", key)
            current_spend = key_info["info"]["spend"]
            if abs(current_spend - expected_spend) < TOLERANCE:
                print(
                    f"Key spend reached expected {expected_spend} "
                    f"after {time.time() - start:.1f}s"
                )
                break
            print(f"Key spend {current_spend}, expected {expected_spend}, waiting...")
            await asyncio.sleep(10)

        # Allow extra time for all entity spend aggregations to complete
        await asyncio.sleep(5)

        # Get spend information for each entity
        key_info = await get_spend_info(session, "key", key)
        print("key_info: ", key_info)
        team_info = await get_spend_info(session, "team", team_id)
        print("team_info: ", team_info)
        user_info = await get_spend_info(session, "user", user_id)
        print("user_info: ", user_info)
        org_info = await get_spend_info(session, "organization", org_id)
        print("org_info: ", org_info)

        # Verify spend for each entity
        assert (
            abs(key_info["info"]["spend"] - expected_spend) < TOLERANCE
        ), f"Key spend {key_info['info']['spend']} does not match expected {expected_spend}"

        assert (
            abs(user_info["user_info"]["spend"] - expected_spend) < TOLERANCE
        ), f"User spend {user_info['user_info']['spend']} does not match expected {expected_spend}"

        assert (
            abs(team_info["team_info"]["spend"] - expected_spend) < TOLERANCE
        ), f"Team spend {team_info['team_info']['spend']} does not match expected {expected_spend}"

        assert (
            abs(org_info["spend"] - expected_spend) < TOLERANCE
        ), f"Organization spend {org_info['spend']} does not match expected {expected_spend}"


@pytest.mark.asyncio
async def test_long_term_spend_accuracy_with_bursts():
    """
    Test long-term spend accuracy with multiple bursts of requests:
    1. Create org, team, user, and key
    2. Calibrate SPEND_PER_REQUEST from first request
    3. Burst 1: Make remaining requests
    4. Burst 2: Make more requests
    5. Verify the total spend is tracked accurately across all entities
    """
    BURST_1_REQUESTS = 22
    BURST_2_REQUESTS = 12
    TOTAL_REQUESTS = BURST_1_REQUESTS + BURST_2_REQUESTS
    TOLERANCE = 1e-10

    async with aiohttp.ClientSession() as session:
        # Create organization
        org_response = await create_organization(
            session=session, organization_alias=f"test-org-{uuid.uuid4()}"
        )
        print("org_response: ", org_response)
        org_id = org_response["organization_id"]

        # Create team under organization
        team_response = await create_team(session, org_id)
        print("team_response: ", team_response)
        team_id = team_response["team_id"]

        # Create user
        user_response = await create_user(session, org_id)
        print("user_response: ", user_response)
        user_id = user_response["user_id"]

        # Generate key
        key_response = await generate_key(session, user_id, team_id)
        print("key_response: ", key_response)
        key = key_response["key"]

        # Calibrate: make 1 request and derive SPEND_PER_REQUEST
        spend_per_request = await calibrate_spend_per_request(session, key)
        expected_spend = TOTAL_REQUESTS * spend_per_request
        print(
            f"SPEND_PER_REQUEST={spend_per_request}, expected_spend={expected_spend}"
        )

        # First burst: remaining requests (1 already made during calibration)
        print(f"Starting first burst ({BURST_1_REQUESTS - 1} remaining requests)...")
        for i in range(BURST_1_REQUESTS - 1):
            response = await chat_completion(session, key)
            print(f"Burst 1 - Request {i + 2}/{BURST_1_REQUESTS} completed")

        # Poll until batch writer has flushed burst 1 spend
        burst_1_expected = BURST_1_REQUESTS * spend_per_request
        start = time.time()
        while time.time() - start < 120:
            key_info_check = await get_spend_info(session, "key", key)
            current_spend = key_info_check["info"]["spend"]
            if abs(current_spend - burst_1_expected) < TOLERANCE:
                print(
                    f"Burst 1 spend reached expected {burst_1_expected} "
                    f"after {time.time() - start:.1f}s"
                )
                break
            print(f"Key spend {current_spend}, expected {burst_1_expected}, waiting...")
            await asyncio.sleep(10)

        # Check intermediate spend
        intermediate_key_info = await get_spend_info(session, "key", key)
        print(f"After Burst 1 - Key spend: {intermediate_key_info['info']['spend']}")

        # Second burst
        print(f"Starting second burst of {BURST_2_REQUESTS} requests...")
        for i in range(BURST_2_REQUESTS):
            response = await chat_completion(session, key)
            print(f"Burst 2 - Request {i + 1}/{BURST_2_REQUESTS} completed")

        # Poll until key spend reaches expected total (burst 1 + burst 2)
        start = time.time()
        while time.time() - start < 120:
            key_info_check = await get_spend_info(session, "key", key)
            current_spend = key_info_check["info"]["spend"]
            if abs(current_spend - expected_spend) < TOLERANCE:
                print(
                    f"Total spend reached expected {expected_spend} "
                    f"after {time.time() - start:.1f}s"
                )
                break
            print(
                f"Key spend {current_spend}, expected {expected_spend}, waiting..."
            )
            await asyncio.sleep(10)

        # Allow extra time for all entity spend aggregations
        await asyncio.sleep(5)

        # Get final spend information for each entity
        key_info = await get_spend_info(session, "key", key)
        team_info = await get_spend_info(session, "team", team_id)
        user_info = await get_spend_info(session, "user", user_id)
        org_info = await get_spend_info(session, "organization", org_id)

        print(f"Final key spend: {key_info['info']['spend']}")
        print(f"Final team spend: {team_info['team_info']['spend']}")
        print(f"Final user spend: {user_info['user_info']['spend']}")
        print(f"Final org spend: {org_info['spend']}")

        # Verify total spend for each entity
        assert (
            abs(key_info["info"]["spend"] - expected_spend) < TOLERANCE
        ), f"Key spend {key_info['info']['spend']} does not match expected {expected_spend}"

        assert (
            abs(user_info["user_info"]["spend"] - expected_spend) < TOLERANCE
        ), f"User spend {user_info['user_info']['spend']} does not match expected {expected_spend}"

        assert (
            abs(team_info["team_info"]["spend"] - expected_spend) < TOLERANCE
        ), f"Team spend {team_info['team_info']['spend']} does not match expected {expected_spend}"

        assert (
            abs(org_info["spend"] - expected_spend) < TOLERANCE
        ), f"Organization spend {org_info['spend']} does not match expected {expected_spend}"