litellm/scripts/benchmark_chat_completions_perf.py
Yassin Kortam a6494e6fe3
perf: eliminate per-request callback scanning on proxy hot path (#27858)
- Introduce `_CallbackCapabilities` dataclass and `ProxyLogging._callback_capabilities()` static method that inspects `litellm.callbacks` once and caches capability flags keyed on (list length, member ids); the cache invalidates automatically when the callback list mutates, with no per-request iteration overhead
- Replace O(n) `litellm.callbacks` walks in `async_pre_call_hook`, `during_call_hook`, `async_post_call_streaming_iterator_hook`, `async_post_call_streaming_hook`, and `post_call_response_headers_hook` with fast-path exits when no relevant callbacks are registered
- Add `needs_iterator_wrap()` and `needs_per_chunk_streaming_hook()` instance methods to decouple iterator-level wrapping from per-chunk hook execution; avoids `get_response_string` materialization per chunk when no guardrail or chunk-hook callback is active
- Introduce `_fast_serialize_simple_model_response_stream()` using `orjson` for common single-choice text streaming chunks, bypassing the full Pydantic serializer; falls back to `model_dump_json` for tool calls, logprobs, usage, and provider-specific fields
- Add early-return in `_restamp_streaming_chunk_model` when downstream model already matches the requested model, avoiding unnecessary string comparisons on every chunk
- Fix stale zero-cost cache bug in `_is_model_cost_zero`: move the per-router `_zero_cost_cache` dict onto the `Router` instance and clear it in `_invalidate_model_group_info_cache` so in-place pricing updates via `upsert_deployment` immediately resume budget enforcement
- Add `scripts/benchmark_chat_completions_perf.py`: standalone async benchmarking tool with a mock OpenAI provider, LiteLLM proxy process management, non-streaming RPS, streaming TTFT, and full-stream latency measurements with repeat/median run support
- Add comprehensive unit tests covering capability detection, cache invalidation, fast-path correctness, zero-cost cache regression, and the no-callback streaming fast path

Co-authored-by: Yassin Kortam <yassinkortam@g.ucla.edu>
2026-05-14 09:28:31 -07:00

843 lines
29 KiB
Python

#!/usr/bin/env python3
"""Benchmark LiteLLM proxy /v1/chat/completions overhead and streaming TTFT.
The script can run a local OpenAI-compatible mock provider plus a LiteLLM proxy
from any checkout. That makes it useful for comparing tags/commits without
depending on real provider latency.
Example:
uv run python scripts/benchmark_chat_completions_perf.py \
--label current --requests 500 --concurrency 100
Compare another checkout:
uv run python scripts/benchmark_chat_completions_perf.py \
--label v1.83.14-stable --litellm-dir /tmp/litellm-v1.83.14-stable
"""
from __future__ import annotations
import argparse
import asyncio
import json
import os
import shlex
import signal
import statistics
import subprocess
import sys
import tempfile
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional
import aiohttp
from aiohttp import web
# Model name written into the generated proxy config and sent in every
# benchmark request payload; the mock provider also falls back to it when a
# request body carries no "model" field.
DEFAULT_MODEL = "perf-test-model"
# Default for --api-key; used as the proxy master key in the generated config
# and as the bearer token on requests sent to the proxy.
DEFAULT_API_KEY = "sk-1234"
@dataclass
class RequestSample:
    """Outcome of one benchmark request as recorded by a measuring coroutine."""

    # Whether the request met the measuring coroutine's success criteria.
    success: bool
    # Elapsed time in milliseconds up to the point the sample was recorded.
    latency_ms: float
    # HTTP status of the response; the measuring coroutines use 0 when an
    # exception prevented a sample from being built from a response.
    status_code: int
    # Parsed ``x-litellm-overhead-duration-ms`` response header, if present.
    overhead_header_ms: Optional[float] = None
    # Short failure description; empty for successful samples.
    error: str = ""
@dataclass
class SummaryStats:
    """Aggregate statistics for one benchmark phase, as built by ``summarize``."""

    # Total number of samples collected (successes and failures).
    requests: int
    # Number of samples that were not successful.
    failures: int
    # Successful requests per second of wall-clock time.
    rps: float
    # Latency statistics in milliseconds over successful samples only.
    mean_ms: float
    p50_ms: float
    p95_ms: float
    p99_ms: float
    # Statistics over the overhead header values reported on successful
    # samples; left as None when no sample carried the header.
    overhead_header_mean_ms: Optional[float] = None
    overhead_header_p50_ms: Optional[float] = None
    overhead_header_p95_ms: Optional[float] = None
class MockOpenAIProvider:
    """In-process aiohttp server that answers ``POST /v1/chat/completions``.

    Non-streaming requests get a fixed single-choice completion; streaming
    requests get a server-sent-event stream made of one role chunk, a
    configurable number of identical content chunks, a final ``stop`` chunk
    and the ``[DONE]`` sentinel.
    """

    def __init__(
        self,
        host: str,
        port: int,
        first_token_delay_ms: float,
        stream_content_chunks: int,
    ) -> None:
        """Store the listen address and streaming parameters; nothing is bound yet."""
        self.host = host
        self.port = port
        self.first_token_delay_ms = first_token_delay_ms
        self.stream_content_chunks = stream_content_chunks
        # Populated by start(); stays None until the server has been started.
        self.runner: Optional[web.AppRunner] = None

    @property
    def base_url(self) -> str:
        """Root URL of the provider, without the ``/v1`` suffix."""
        return "http://{}:{}".format(self.host, self.port)

    async def start(self) -> None:
        """Create the aiohttp application and start listening on host:port."""
        application = web.Application()
        application.router.add_post(
            "/v1/chat/completions", self.handle_chat_completions
        )
        # access_log=None: no per-request access logging from the mock.
        runner = web.AppRunner(application, access_log=None)
        self.runner = runner
        await runner.setup()
        await web.TCPSite(runner, self.host, self.port).start()

    async def stop(self) -> None:
        """Clean up the runner if the server was started; otherwise do nothing."""
        runner = self.runner
        if runner is None:
            return
        await runner.cleanup()

    async def handle_chat_completions(self, request: web.Request) -> web.StreamResponse:
        """Dispatch to the streaming or JSON reply based on the body's ``stream`` flag."""
        body = await request.json()
        if not body.get("stream"):
            return self._json_response(body)
        return await self._streaming_response(request=request, body=body)

    def _json_response(self, body: dict[str, Any]) -> web.Response:
        """Build the fixed non-streaming completion, echoing the requested model."""
        assistant_choice = {
            "index": 0,
            "message": {"role": "assistant", "content": "hello"},
            "finish_reason": "stop",
        }
        token_usage = {
            "prompt_tokens": 1,
            "completion_tokens": 1,
            "total_tokens": 2,
        }
        return web.json_response(
            {
                "id": "chatcmpl-perf",
                "object": "chat.completion",
                "created": int(time.time()),
                "model": body.get("model", DEFAULT_MODEL),
                "choices": [assistant_choice],
                "usage": token_usage,
            }
        )

    async def _streaming_response(
        self, request: web.Request, body: dict[str, Any]
    ) -> web.StreamResponse:
        """Stream the SSE reply: role chunk, content chunks, stop chunk, ``[DONE]``."""
        response = web.StreamResponse(
            status=200,
            headers={
                "Content-Type": "text/event-stream",
                "Cache-Control": "no-cache",
            },
        )
        # Headers go out before the optional delay, so the delay applies to the
        # first chunk rather than to the response headers.
        await response.prepare(request)
        delay_ms = self.first_token_delay_ms
        if delay_ms > 0:
            await asyncio.sleep(delay_ms / 1000)

        created = int(time.time())
        model = body.get("model", DEFAULT_MODEL)

        def sse_frame(delta: dict[str, Any], finish_reason: Optional[str]) -> bytes:
            # One "data: <json>" server-sent-event frame for a single chunk.
            event = {
                "id": "chatcmpl-perf",
                "object": "chat.completion.chunk",
                "created": created,
                "model": model,
                "choices": [
                    {"index": 0, "delta": delta, "finish_reason": finish_reason}
                ],
            }
            return f"data: {json.dumps(event)}\n\n".encode()

        deltas = [{"role": "assistant"}] + [
            {"content": "hello"} for _ in range(self.stream_content_chunks)
        ]
        for delta in deltas:
            await response.write(sse_frame(delta, None))
        await response.write(sse_frame({}, "stop"))
        await response.write(b"data: [DONE]\n\n")
        await response.write_eof()
        return response
def percentile(values: list[float], pct: float) -> float:
    """Return the element at rank ``floor(len(values) * pct / 100)`` of the sorted values.

    The rank is clamped to the last element, and an empty input yields 0.0.
    """
    count = len(values)
    if count == 0:
        return 0.0
    rank = int(count * pct / 100)
    ordered = sorted(values)
    return ordered[min(rank, count - 1)]
def summarize(samples: list[RequestSample], wall_time_s: float) -> SummaryStats:
    """Aggregate request samples into a ``SummaryStats``.

    Latency and overhead statistics cover successful samples only. RPS is the
    number of successful samples divided by ``wall_time_s`` (0.0 when the wall
    time is not positive). Overhead fields stay ``None`` when no successful
    sample reported the overhead header.
    """
    succeeded = [sample for sample in samples if sample.success]
    latencies = [sample.latency_ms for sample in succeeded]
    overheads = [
        sample.overhead_header_ms
        for sample in succeeded
        if sample.overhead_header_ms is not None
    ]

    if wall_time_s > 0:
        rps = len(latencies) / wall_time_s
    else:
        rps = 0.0

    if overheads:
        overhead_mean: Optional[float] = statistics.mean(overheads)
        overhead_p50: Optional[float] = percentile(overheads, 50)
        overhead_p95: Optional[float] = percentile(overheads, 95)
    else:
        overhead_mean = overhead_p50 = overhead_p95 = None

    return SummaryStats(
        requests=len(samples),
        failures=len(samples) - len(latencies),
        rps=rps,
        mean_ms=statistics.mean(latencies) if latencies else 0.0,
        p50_ms=percentile(latencies, 50),
        p95_ms=percentile(latencies, 95),
        p99_ms=percentile(latencies, 99),
        overhead_header_mean_ms=overhead_mean,
        overhead_header_p50_ms=overhead_p50,
        overhead_header_p95_ms=overhead_p95,
    )
def format_optional_ms(value: Optional[float]) -> str:
    """Format a millisecond value with two decimals, or ``"n/a"`` when it is None."""
    if value is None:
        return "n/a"
    return f"{value:.2f}"
def get_git_revision(litellm_dir: Path) -> str:
    """Return the short ``HEAD`` hash of the checkout at ``litellm_dir``.

    Any failure (git missing, directory missing, not a repository, ...) yields
    the literal string ``"unknown"`` instead of raising.
    """
    git_command = ["git", "rev-parse", "--short", "HEAD"]
    try:
        completed = subprocess.run(
            git_command,
            cwd=litellm_dir,
            check=True,
            capture_output=True,
            text=True,
        )
    except Exception:
        # Best effort: the revision is informational only.
        return "unknown"
    return completed.stdout.strip()
def write_proxy_config(config_path: Path, provider_base_url: str, api_key: str) -> None:
    """Write a minimal LiteLLM proxy config that routes to the benchmark provider.

    The config declares a single model (``DEFAULT_MODEL``) served through the
    OpenAI-compatible endpoint at ``{provider_base_url}/v1``, sets ``api_key``
    as the proxy master key, and disables telemetry.

    Args:
        config_path: Destination file; overwritten if it already exists.
        provider_base_url: Provider root URL without the ``/v1`` suffix.
        api_key: Master key clients must present to the proxy.
    """
    # Nesting is significant in YAML: litellm_params has to sit inside the
    # model_list entry and master_key under general_settings, so every line
    # carries its indentation explicitly.
    config_lines = [
        "model_list:",
        f"  - model_name: {DEFAULT_MODEL}",
        "    litellm_params:",
        f"      model: openai/{DEFAULT_MODEL}",
        "      api_key: fake-provider-key",
        f"      api_base: {provider_base_url}/v1",
        "general_settings:",
        # A JSON string is a valid YAML double-quoted scalar; quoting keeps
        # keys that look like numbers or contain ':' / '#' intact as strings.
        f"  master_key: {json.dumps(api_key)}",
        "litellm_settings:",
        "  drop_params: true",
        "  telemetry: false",
        "",
    ]
    config_path.write_text("\n".join(config_lines), encoding="utf-8")
async def wait_for_proxy(base_url: str, timeout_s: float) -> None:
    """Poll ``{base_url}/health`` until the proxy answers or ``timeout_s`` elapses.

    Any response with a status below 500 counts as "up". Server errors and
    connection failures are retried every 0.5 seconds.

    Args:
        base_url: Proxy root URL, without a trailing slash.
        timeout_s: Overall time budget in seconds.

    Raises:
        TimeoutError: If no acceptable response arrives within the budget; the
            message includes the most recent failure.
    """
    deadline = time.perf_counter() + timeout_s
    last_error = ""
    async with aiohttp.ClientSession() as session:
        while (remaining_s := deadline - time.perf_counter()) > 0:
            try:
                # Cap each probe at the time left so one hung connection
                # cannot push the wait past the overall deadline.
                probe_timeout = aiohttp.ClientTimeout(total=remaining_s)
                async with session.get(
                    f"{base_url}/health", timeout=probe_timeout
                ) as response:
                    if response.status < 500:
                        return
                    last_error = f"HTTP {response.status}: {await response.text()}"
            except Exception as exc:
                # Timeout exceptions stringify to "", so fall back to the
                # exception type to keep the final message informative.
                last_error = str(exc) or type(exc).__name__
            await asyncio.sleep(0.5)
    raise TimeoutError(f"Timed out waiting for proxy at {base_url}: {last_error}")
def start_proxy_process(
    litellm_dir: Path,
    proxy_command: str,
    config_path: Path,
    port: int,
    log_path: Path,
) -> subprocess.Popen:
    """Launch the proxy as a child process, logging its output to ``log_path``.

    ``proxy_command`` is split with shell-like rules and gets
    ``--config <config_path> --port <port>`` appended. The child runs with
    ``litellm_dir`` as its working directory, in its own session, with stdout
    and stderr both redirected to the log file.

    Args:
        litellm_dir: Working directory for the child process.
        proxy_command: Command line used to start the proxy.
        config_path: Proxy config file passed via ``--config``.
        port: Port passed via ``--port``.
        log_path: File that receives the child's combined output (truncated).

    Returns:
        The running ``subprocess.Popen`` handle.
    """
    command = shlex.split(proxy_command) + [
        "--config",
        str(config_path),
        "--port",
        str(port),
    ]
    env = {
        **os.environ,
        "LITELLM_TELEMETRY": "False",
        "PYTHONUNBUFFERED": "1",
    }
    # The child receives its own copy of the log descriptor, so the parent's
    # handle can be closed as soon as Popen returns (or fails) instead of
    # staying open for the rest of the run.
    with log_path.open("w", encoding="utf-8") as log_file:
        return subprocess.Popen(
            command,
            cwd=litellm_dir,
            env=env,
            stdout=log_file,
            stderr=subprocess.STDOUT,
            start_new_session=True,
        )
def stop_proxy_process(process: subprocess.Popen) -> None:
    """Stop the proxy process (and its process group) and reap it, best effort.

    Sends SIGTERM to the process group and waits up to 10 seconds. If that
    fails or times out, escalates to SIGKILL and waits again so the child does
    not linger as a zombie. If the group cannot be signalled, the direct child
    is killed instead. Does nothing when the process has already exited.

    Args:
        process: Handle returned by ``start_proxy_process``.
    """
    if process.poll() is not None:
        return
    try:
        os.killpg(process.pid, signal.SIGTERM)
        process.wait(timeout=10)
        return
    except (AttributeError, OSError, subprocess.TimeoutExpired):
        # Graceful shutdown failed or timed out (or process groups are not
        # available on this platform); escalate below.
        pass
    try:
        os.killpg(process.pid, signal.SIGKILL)
    except (AttributeError, OSError):
        # No such group (e.g. the child is not a group leader) or no killpg:
        # fall back to killing the direct child.
        process.kill()
    try:
        # Reap the child so it is not left as a zombie after the kill.
        process.wait(timeout=5)
    except subprocess.TimeoutExpired:
        pass
def extract_overhead_header(headers: aiohttp.typedefs.LooseHeaders) -> Optional[float]:
    """Parse the ``x-litellm-overhead-duration-ms`` header as a float.

    Returns ``None`` when the header is absent or its value is not a number.
    """
    header_value = headers.get("x-litellm-overhead-duration-ms")  # type: ignore[union-attr]
    if header_value is not None:
        try:
            return float(header_value)
        except ValueError:
            pass
    return None
async def post_non_streaming(
    session: aiohttp.ClientSession,
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any],
    semaphore: asyncio.Semaphore,
) -> RequestSample:
    """POST one non-streaming request and time it until the body is fully read.

    The semaphore bounds in-flight requests; timing starts after it has been
    acquired. A non-200 status produces a failed sample carrying the first 200
    characters of the body, and any exception produces a failed sample with
    status code 0.
    """
    async with semaphore:
        started = time.perf_counter()

        def elapsed_ms() -> float:
            return (time.perf_counter() - started) * 1000

        try:
            async with session.post(url, headers=headers, json=payload) as response:
                raw_body = await response.read()
                latency_ms = elapsed_ms()
                if response.status == 200:
                    return RequestSample(
                        success=True,
                        latency_ms=latency_ms,
                        status_code=response.status,
                        overhead_header_ms=extract_overhead_header(response.headers),
                    )
                return RequestSample(
                    success=False,
                    latency_ms=latency_ms,
                    status_code=response.status,
                    error=raw_body.decode("utf-8", errors="ignore")[:200],
                )
        except Exception as exc:
            return RequestSample(
                success=False,
                latency_ms=elapsed_ms(),
                status_code=0,
                error=str(exc)[:200],
            )
async def run_non_streaming_benchmark(
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any],
    requests: int,
    concurrency: int,
    warmup: int,
    timeout_s: float,
) -> SummaryStats:
    """Fire ``requests`` non-streaming POSTs at ``url`` and summarize them.

    An optional warmup batch is sent first and its samples are discarded.
    ``concurrency`` bounds the number of requests in flight, and wall time
    covers the measured batch only.
    """
    connector = aiohttp.TCPConnector(
        limit=max(concurrency * 2, 10),
        limit_per_host=max(concurrency, 10),
        force_close=False,
    )
    client_timeout = aiohttp.ClientTimeout(total=timeout_s)
    gate = asyncio.Semaphore(concurrency)

    async with aiohttp.ClientSession(
        connector=connector, timeout=client_timeout
    ) as session:

        def fire(count: int):
            # Schedule `count` requests at once; the semaphore throttles them.
            return asyncio.gather(
                *(
                    post_non_streaming(session, url, headers, payload, gate)
                    for _ in range(count)
                )
            )

        if warmup > 0:
            await fire(warmup)
        began = time.perf_counter()
        samples = await fire(requests)
        wall_time_s = time.perf_counter() - began
        return summarize(samples, wall_time_s)
async def measure_stream_ttft(
    session: aiohttp.ClientSession,
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any],
    semaphore: asyncio.Semaphore,
) -> RequestSample:
    """Measure time to the first streamed content token for one request.

    Reads the SSE response line by line and returns a successful sample as
    soon as an event carries non-empty ``delta.content`` (or ``text``). The
    sample fails on a non-200 status, when the stream ends or reaches
    ``[DONE]`` before any content, or on any exception (status code 0).
    """
    async with semaphore:
        started = time.perf_counter()

        def elapsed_ms() -> float:
            return (time.perf_counter() - started) * 1000

        try:
            async with session.post(url, headers=headers, json=payload) as response:
                if response.status != 200:
                    raw_body = await response.read()
                    return RequestSample(
                        success=False,
                        latency_ms=elapsed_ms(),
                        status_code=response.status,
                        error=raw_body.decode("utf-8", errors="ignore")[:200],
                    )

                reader = response.content
                while True:
                    raw_line = await reader.readline()
                    if not raw_line:
                        # End of stream.
                        break
                    stripped = raw_line.strip()
                    # Blank separators and non-data SSE fields carry no tokens.
                    if not stripped.startswith(b"data:"):
                        continue
                    data = stripped[len(b"data:"):].strip()
                    if data == b"[DONE]":
                        break
                    event = json.loads(data)
                    choices = event.get("choices") or [{}]
                    first_choice = choices[0]
                    delta = first_choice.get("delta") or {}
                    token = delta.get("content") or first_choice.get("text")
                    if not token:
                        continue
                    return RequestSample(
                        success=True,
                        latency_ms=elapsed_ms(),
                        status_code=response.status,
                        overhead_header_ms=extract_overhead_header(
                            response.headers
                        ),
                    )

                return RequestSample(
                    success=False,
                    latency_ms=elapsed_ms(),
                    status_code=response.status,
                    error="stream ended before a content token",
                )
        except Exception as exc:
            return RequestSample(
                success=False,
                latency_ms=elapsed_ms(),
                status_code=0,
                error=str(exc)[:200],
            )
async def run_streaming_ttft_benchmark(
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any],
    requests: int,
    concurrency: int,
    warmup: int,
    timeout_s: float,
) -> SummaryStats:
    """Run ``requests`` time-to-first-token measurements against ``url``.

    An optional warmup batch is sent first and its samples are discarded.
    ``concurrency`` bounds the number of streams in flight, and wall time
    covers the measured batch only.
    """
    connector = aiohttp.TCPConnector(
        limit=max(concurrency * 2, 10),
        limit_per_host=max(concurrency, 10),
        force_close=False,
    )
    client_timeout = aiohttp.ClientTimeout(total=timeout_s)
    gate = asyncio.Semaphore(concurrency)

    async with aiohttp.ClientSession(
        connector=connector, timeout=client_timeout
    ) as session:

        def fire(count: int):
            # Schedule `count` streams at once; the semaphore throttles them.
            return asyncio.gather(
                *(
                    measure_stream_ttft(session, url, headers, payload, gate)
                    for _ in range(count)
                )
            )

        if warmup > 0:
            await fire(warmup)
        began = time.perf_counter()
        samples = await fire(requests)
        wall_time_s = time.perf_counter() - began
        return summarize(samples, wall_time_s)
async def measure_stream_full_response(
    session: aiohttp.ClientSession,
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any],
    semaphore: asyncio.Semaphore,
) -> RequestSample:
    """Measure the time to consume a whole streamed response up to ``[DONE]``.

    Events are not JSON-decoded; content is detected by looking for a
    ``"content"`` or ``"text"`` key in the raw event bytes. The sample succeeds
    only when ``[DONE]`` arrives after at least one such event. It fails on a
    non-200 status, when the stream ends without ``[DONE]``, or on any
    exception (status code 0).
    """
    async with semaphore:
        started = time.perf_counter()

        def elapsed_ms() -> float:
            return (time.perf_counter() - started) * 1000

        try:
            async with session.post(url, headers=headers, json=payload) as response:
                if response.status != 200:
                    raw_body = await response.read()
                    return RequestSample(
                        success=False,
                        latency_ms=elapsed_ms(),
                        status_code=response.status,
                        error=raw_body.decode("utf-8", errors="ignore")[:200],
                    )

                content_seen = False
                reader = response.content
                while True:
                    raw_line = await reader.readline()
                    if not raw_line:
                        # End of stream without the [DONE] sentinel.
                        break
                    stripped = raw_line.strip()
                    if not stripped.startswith(b"data:"):
                        continue
                    data = stripped[len(b"data:"):].strip()
                    if data != b"[DONE]":
                        # Cheap byte scan instead of JSON parsing per chunk.
                        if b'"content"' in data or b'"text"' in data:
                            content_seen = True
                        continue
                    return RequestSample(
                        success=content_seen,
                        latency_ms=elapsed_ms(),
                        status_code=response.status,
                        overhead_header_ms=extract_overhead_header(
                            response.headers
                        ),
                        error="" if content_seen else "stream ended without content",
                    )

                return RequestSample(
                    success=False,
                    latency_ms=elapsed_ms(),
                    status_code=response.status,
                    error="stream ended before [DONE]",
                )
        except Exception as exc:
            return RequestSample(
                success=False,
                latency_ms=elapsed_ms(),
                status_code=0,
                error=str(exc)[:200],
            )
async def run_streaming_full_benchmark(
    url: str,
    headers: dict[str, str],
    payload: dict[str, Any],
    requests: int,
    concurrency: int,
    warmup: int,
    timeout_s: float,
) -> SummaryStats:
    """Run ``requests`` full-stream consumption measurements against ``url``.

    An optional warmup batch is sent first and its samples are discarded.
    ``concurrency`` bounds the number of streams in flight, and wall time
    covers the measured batch only.
    """
    connector = aiohttp.TCPConnector(
        limit=max(concurrency * 2, 10),
        limit_per_host=max(concurrency, 10),
        force_close=False,
    )
    client_timeout = aiohttp.ClientTimeout(total=timeout_s)
    gate = asyncio.Semaphore(concurrency)

    async with aiohttp.ClientSession(
        connector=connector, timeout=client_timeout
    ) as session:

        def fire(count: int):
            # Schedule `count` streams at once; the semaphore throttles them.
            return asyncio.gather(
                *(
                    measure_stream_full_response(
                        session, url, headers, payload, gate
                    )
                    for _ in range(count)
                )
            )

        if warmup > 0:
            await fire(warmup)
        began = time.perf_counter()
        samples = await fire(requests)
        wall_time_s = time.perf_counter() - began
        return summarize(samples, wall_time_s)
def stats_to_dict(stats: SummaryStats) -> dict[str, Any]:
    """Convert a ``SummaryStats`` into a plain dict keyed by field name."""
    exported_fields = (
        "requests",
        "failures",
        "rps",
        "mean_ms",
        "p50_ms",
        "p95_ms",
        "p99_ms",
        "overhead_header_mean_ms",
        "overhead_header_p50_ms",
        "overhead_header_p95_ms",
    )
    return {name: getattr(stats, name) for name in exported_fields}
def _median_run(
runs: list[tuple[SummaryStats, SummaryStats, SummaryStats, Optional[SummaryStats]]],
) -> tuple[SummaryStats, SummaryStats, SummaryStats, Optional[SummaryStats]]:
# Pick the run whose proxy non-stream p50 is the median across repeats.
# Choosing a single representative run (rather than aggregating each metric
# separately) keeps related metrics from the same execution context so
# client-overhead deltas stay internally consistent.
sorted_runs = sorted(runs, key=lambda r: r[1].p50_ms)
return sorted_runs[len(sorted_runs) // 2]
def print_summary(
    label: str,
    revision: str,
    direct: SummaryStats,
    proxy: SummaryStats,
    stream: SummaryStats,
    stream_full: Optional[SummaryStats],
) -> None:
    """Print the human-readable benchmark summary followed by a Markdown table row.

    Client-observed overhead is the proxy latency minus the direct-provider
    latency at the same percentile. The full-stream lines are omitted, and the
    matching Markdown cells read ``n/a``, when ``stream_full`` is ``None``.
    """
    client_overhead_p50 = proxy.p50_ms - direct.p50_ms
    client_overhead_p95 = proxy.p95_ms - direct.p95_ms
    header_p50_text = format_optional_ms(proxy.overhead_header_p50_ms)

    report_lines = [
        "\n=== Benchmark summary ===",
        f"Label: {label}",
        f"Revision: {revision}",
        f"Direct provider non-stream p50: {direct.p50_ms:.2f} ms",
        f"Proxy non-stream p50: {proxy.p50_ms:.2f} ms",
        f"Proxy non-stream p95: {proxy.p95_ms:.2f} ms",
        f"Proxy non-stream RPS: {proxy.rps:.2f}",
        f"Client-observed overhead p50: {client_overhead_p50:.2f} ms",
        f"Client-observed overhead p95: {client_overhead_p95:.2f} ms",
        "x-litellm-overhead-duration-ms p50: " f"{header_p50_text} ms",
        f"Streaming TTFT p50: {stream.p50_ms:.2f} ms",
        f"Streaming TTFT p95: {stream.p95_ms:.2f} ms",
        f"Streaming TTFT RPS: {stream.rps:.2f}",
    ]

    if stream_full is None:
        full_p50_cell = "n/a"
        full_rps_cell = "n/a"
    else:
        full_p50_cell = f"{stream_full.p50_ms:.2f}"
        full_rps_cell = f"{stream_full.rps:.2f}"
        report_lines.append(f"Streaming full response p50: {stream_full.p50_ms:.2f} ms")
        report_lines.append(f"Streaming full response p95: {stream_full.p95_ms:.2f} ms")
        report_lines.append(f"Streaming full response RPS: {stream_full.rps:.2f}")

    markdown_cells = [
        label,
        revision,
        f"{stream.p50_ms:.2f}",
        f"{stream.p95_ms:.2f}",
        f"{proxy.rps:.2f}",
        f"{client_overhead_p50:.2f}",
        f"{client_overhead_p95:.2f}",
        header_p50_text,
        full_p50_cell,
        full_rps_cell,
    ]
    report_lines.append("\nMarkdown row:")
    report_lines.append("| " + " | ".join(markdown_cells) + " |")

    for report_line in report_lines:
        print(report_line)
def _positive_int(raw: str) -> int:
    """argparse ``type`` callable that accepts only integers >= 1."""
    try:
        value = int(raw)
    except ValueError:
        raise argparse.ArgumentTypeError(
            f"expected an integer, got {raw!r}"
        ) from None
    if value < 1:
        raise argparse.ArgumentTypeError(f"must be >= 1, got {value}")
    return value


def parse_args(argv: Optional[list[str]] = None) -> argparse.Namespace:
    """Parse the benchmark's command-line options.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behavior existing callers get.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--label", default="current", help="Label for this run")
    parser.add_argument(
        "--litellm-dir",
        default=str(Path.cwd()),
        help="Checkout directory used to start the LiteLLM proxy",
    )
    parser.add_argument(
        "--proxy-command",
        default="uv run litellm",
        help="Command used to start the proxy inside --litellm-dir",
    )
    parser.add_argument("--proxy-host", default="127.0.0.1")
    parser.add_argument("--proxy-port", type=int, default=4000)
    parser.add_argument("--provider-host", default="127.0.0.1")
    parser.add_argument("--provider-port", type=int, default=8099)
    parser.add_argument("--api-key", default=DEFAULT_API_KEY)
    parser.add_argument("--requests", type=int, default=500)
    # Concurrency sizes an asyncio.Semaphore: 0 would block every request
    # forever and a negative value raises, so reject both at parse time.
    parser.add_argument("--concurrency", type=_positive_int, default=100)
    parser.add_argument("--stream-requests", type=int, default=200)
    parser.add_argument("--stream-concurrency", type=_positive_int, default=20)
    parser.add_argument("--warmup", type=int, default=100)
    parser.add_argument("--stream-warmup", type=int, default=20)
    parser.add_argument("--timeout", type=float, default=30)
    parser.add_argument("--proxy-start-timeout", type=float, default=90)
    parser.add_argument("--provider-first-token-delay-ms", type=float, default=0)
    parser.add_argument(
        "--provider-stream-content-chunks",
        type=int,
        default=20,
        help="Streaming chunks the mock provider emits. Default 20 (realistic).",
    )
    # Paired on/off flags sharing one destination; the measurement is on
    # unless --no-measure-full-stream is given.
    parser.add_argument(
        "--measure-full-stream",
        action="store_true",
        default=True,
        help="Measure time to consume the complete streaming response (on by default).",
    )
    parser.add_argument(
        "--no-measure-full-stream",
        dest="measure_full_stream",
        action="store_false",
        help="Skip the full-stream RPS measurement.",
    )
    parser.add_argument(
        "--repeats",
        type=int,
        default=1,
        help="Run the entire suite N times against the same proxy and report the median run.",
    )
    parser.add_argument(
        "--no-start-proxy",
        action="store_true",
        help="Benchmark an already-running proxy at --proxy-host/--proxy-port",
    )
    parser.add_argument(
        "--provider-url",
        help="Use an already-running provider instead of starting the mock provider",
    )
    parser.add_argument("--output-json", help="Write machine-readable results")
    return parser.parse_args(argv)
def _print_proxy_log_tail(log_path: Path, max_lines: int = 40) -> None:
    """Echo the last ``max_lines`` lines of the proxy log to stderr, if readable."""
    try:
        log_lines = log_path.read_text(encoding="utf-8", errors="replace").splitlines()
    except OSError:
        return
    if not log_lines:
        return
    shown = log_lines[-max_lines:]
    print(f"\n--- Last {len(shown)} lines of proxy log ---", file=sys.stderr)
    for log_line in shown:
        print(log_line, file=sys.stderr)


async def async_main() -> None:
    """Run the whole benchmark: start services, measure, report, clean up.

    Steps:
      1. Start the mock provider unless ``--provider-url`` is given.
      2. Write a proxy config and start the proxy unless ``--no-start-proxy``
         is given, then wait for it to answer.
      3. For each repeat, measure direct non-streaming, proxy non-streaming,
         streaming TTFT and (optionally) full-stream latency.
      4. Print the median run and optionally write it as JSON.

    The proxy and the mock provider are always stopped, including on failure.
    When a proxy was started and the run fails, the tail of its log is printed
    to stderr before the temporary directory holding it is removed.
    """
    args = parse_args()
    litellm_dir = Path(args.litellm_dir).resolve()
    revision = get_git_revision(litellm_dir)
    proxy_base_url = f"http://{args.proxy_host}:{args.proxy_port}"
    proxy_url = f"{proxy_base_url}/v1/chat/completions"
    headers = {
        "Authorization": f"Bearer {args.api_key}",
        "Content-Type": "application/json",
    }
    provider_headers = {
        "Authorization": "Bearer fake-provider-key",
        "Content-Type": "application/json",
    }
    non_stream_payload = {
        "model": DEFAULT_MODEL,
        "messages": [{"role": "user", "content": "hi"}],
        "max_tokens": 1,
    }
    stream_payload = {**non_stream_payload, "stream": True}
    provider: Optional[MockOpenAIProvider] = None
    proxy_process: Optional[subprocess.Popen] = None

    with tempfile.TemporaryDirectory(prefix="litellm-perf-") as tmp_dir_name:
        tmp_dir = Path(tmp_dir_name)
        proxy_log_path = tmp_dir / "proxy.log"
        try:
            # Provider startup lives inside the try so a failure while binding
            # or while writing the config still reaches the cleanup below.
            if args.provider_url:
                provider_base_url = args.provider_url.rstrip("/")
            else:
                provider = MockOpenAIProvider(
                    host=args.provider_host,
                    port=args.provider_port,
                    first_token_delay_ms=args.provider_first_token_delay_ms,
                    stream_content_chunks=args.provider_stream_content_chunks,
                )
                await provider.start()
                provider_base_url = provider.base_url
            config_path = tmp_dir / "config.yaml"
            write_proxy_config(config_path, provider_base_url, args.api_key)

            if not args.no_start_proxy:
                proxy_process = start_proxy_process(
                    litellm_dir=litellm_dir,
                    proxy_command=args.proxy_command,
                    config_path=config_path,
                    port=args.proxy_port,
                    log_path=proxy_log_path,
                )
            await wait_for_proxy(proxy_base_url, args.proxy_start_timeout)

            runs: list[
                tuple[
                    SummaryStats,
                    SummaryStats,
                    SummaryStats,
                    Optional[SummaryStats],
                ]
            ] = []
            for run_idx in range(max(1, args.repeats)):
                if args.repeats > 1:
                    print(f"\n--- Run {run_idx + 1}/{args.repeats} ---")
                # Baseline: hit the provider directly, bypassing the proxy.
                _direct = await run_non_streaming_benchmark(
                    url=f"{provider_base_url}/v1/chat/completions",
                    headers=provider_headers,
                    payload=non_stream_payload,
                    requests=args.requests,
                    concurrency=args.concurrency,
                    warmup=args.warmup,
                    timeout_s=args.timeout,
                )
                _proxy = await run_non_streaming_benchmark(
                    url=proxy_url,
                    headers=headers,
                    payload=non_stream_payload,
                    requests=args.requests,
                    concurrency=args.concurrency,
                    warmup=args.warmup,
                    timeout_s=args.timeout,
                )
                _stream = await run_streaming_ttft_benchmark(
                    url=proxy_url,
                    headers=headers,
                    payload=stream_payload,
                    requests=args.stream_requests,
                    concurrency=args.stream_concurrency,
                    warmup=args.stream_warmup,
                    timeout_s=args.timeout,
                )
                _stream_full = (
                    await run_streaming_full_benchmark(
                        url=proxy_url,
                        headers=headers,
                        payload=stream_payload,
                        requests=args.stream_requests,
                        concurrency=args.stream_concurrency,
                        warmup=args.stream_warmup,
                        timeout_s=args.timeout,
                    )
                    if args.measure_full_stream
                    else None
                )
                runs.append((_direct, _proxy, _stream, _stream_full))
                if args.repeats > 1:
                    print(
                        f" run {run_idx + 1}: non-stream p50={_proxy.p50_ms:.2f}ms "
                        f"rps={_proxy.rps:.2f} | TTFT p50={_stream.p50_ms:.2f}ms "
                        f"full RPS="
                        + (f"{_stream_full.rps:.2f}" if _stream_full else "n/a")
                    )
            direct, proxy, stream, stream_full = _median_run(runs)
        except Exception:
            # The log sits in the temporary directory that is about to be
            # deleted, so show its tail now while it can still be read.
            if proxy_process is not None:
                _print_proxy_log_tail(proxy_log_path)
            raise
        finally:
            if proxy_process is not None:
                stop_proxy_process(proxy_process)
            if provider is not None:
                await provider.stop()

        print_summary(args.label, revision, direct, proxy, stream, stream_full)
        if args.output_json:
            output = {
                "label": args.label,
                "revision": revision,
                "direct_non_streaming": stats_to_dict(direct),
                "proxy_non_streaming": stats_to_dict(proxy),
                "proxy_streaming_ttft": stats_to_dict(stream),
                "proxy_streaming_full": (
                    stats_to_dict(stream_full) if stream_full is not None else None
                ),
                "client_observed_overhead_p50_ms": proxy.p50_ms - direct.p50_ms,
                "client_observed_overhead_p95_ms": proxy.p95_ms - direct.p95_ms,
                # NOTE: this path is inside the temporary directory, which is
                # removed when this function returns; the key is kept for
                # output compatibility.
                "proxy_log_path": str(proxy_log_path),
            }
            Path(args.output_json).write_text(
                json.dumps(output, indent=2, sort_keys=True), encoding="utf-8"
            )
def main() -> None:
    """Synchronous entry point: run the async benchmark to completion."""
    benchmark = async_main()
    asyncio.run(benchmark)


if __name__ == "__main__":
    main()