fix: improve bedrock streaming hot path perf (#28720)
This commit is contained in:
parent
1cb19b155e
commit
d5d6b26a72
@ -59,6 +59,8 @@ FUNCTION_CALL_ATTRIBUTE = "function_call"
|
||||
|
||||
_SYNC_ITER_EXHAUSTED = object()
|
||||
|
||||
_GCHUNK_FIELDS: frozenset = frozenset(GChunk.__annotations__)
|
||||
|
||||
|
||||
def _next_sync_or_exhausted(it: Any) -> Any:
|
||||
"""
|
||||
@ -181,6 +183,30 @@ class CustomStreamWrapper:
|
||||
self.created: Optional[int] = None
|
||||
self._last_returned_hidden_params: Optional[dict] = None
|
||||
|
||||
_cached_logging_provider = self.logging_obj.model_call_details.get(
|
||||
"custom_llm_provider", None
|
||||
)
|
||||
self._cached_logging_llm_provider: Optional[str] = _cached_logging_provider
|
||||
_effective_model = model or ""
|
||||
if (
|
||||
custom_llm_provider == "openai"
|
||||
and custom_llm_provider != _cached_logging_provider
|
||||
):
|
||||
_effective_model = "{}/{}".format(
|
||||
_cached_logging_provider, _effective_model
|
||||
)
|
||||
self._cached_model_name: str = _effective_model
|
||||
|
||||
# Snapshot assumes self._hidden_params is populated from litellm_params
|
||||
# at init and never mutated during the stream. If that ever changes,
|
||||
# this cache must be removed.
|
||||
self._base_hidden_params: Dict[str, Any] = {
|
||||
**self._hidden_params,
|
||||
"response_cost": None,
|
||||
}
|
||||
|
||||
self._post_streaming_hooks: Optional[List] = None
|
||||
|
||||
def _check_max_streaming_duration(self) -> None:
|
||||
"""Raise litellm.Timeout if the stream has exceeded LITELLM_MAX_STREAMING_DURATION_SECONDS."""
|
||||
from litellm.constants import LITELLM_MAX_STREAMING_DURATION_SECONDS
|
||||
@ -681,29 +707,16 @@ class CustomStreamWrapper:
|
||||
def model_response_creator(
|
||||
self, chunk: Optional[dict] = None, hidden_params: Optional[dict] = None
|
||||
):
|
||||
_model = self.model
|
||||
_received_llm_provider = self.custom_llm_provider
|
||||
_logging_obj_llm_provider = self.logging_obj.model_call_details.get("custom_llm_provider", None) # type: ignore
|
||||
if (
|
||||
_received_llm_provider == "openai"
|
||||
and _received_llm_provider != _logging_obj_llm_provider
|
||||
):
|
||||
_model = "{}/{}".format(_logging_obj_llm_provider, _model)
|
||||
_model = self._cached_model_name
|
||||
_logging_obj_llm_provider = self._cached_logging_llm_provider
|
||||
|
||||
if chunk is None:
|
||||
chunk = {}
|
||||
args: Dict[str, Any] = {"model": _model}
|
||||
else:
|
||||
# pop model keyword
|
||||
chunk.pop("model", None)
|
||||
|
||||
chunk_dict = {}
|
||||
for key, value in chunk.items():
|
||||
if key != "stream":
|
||||
chunk_dict[key] = value
|
||||
|
||||
args = {
|
||||
"model": _model,
|
||||
**chunk_dict,
|
||||
}
|
||||
args = {"model": _model}
|
||||
if chunk:
|
||||
args.update({k: v for k, v in chunk.items() if k != "stream"})
|
||||
|
||||
model_response = ModelResponseStream(**args)
|
||||
if self.response_id is not None:
|
||||
@ -717,15 +730,23 @@ class CustomStreamWrapper:
|
||||
model_response.created = self.created
|
||||
else:
|
||||
self.created = model_response.created
|
||||
|
||||
# Spread order is load-bearing: _base_hidden_params (model_id, api_base, ...)
|
||||
# must win over both caller-supplied hidden_params and the computed
|
||||
# custom_llm_provider/created_at values, so it comes last.
|
||||
if hidden_params is not None:
|
||||
model_response._hidden_params = hidden_params
|
||||
model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
|
||||
model_response._hidden_params["created_at"] = time.time()
|
||||
model_response._hidden_params = {
|
||||
**model_response._hidden_params,
|
||||
**self._hidden_params,
|
||||
"response_cost": None,
|
||||
}
|
||||
model_response._hidden_params = {
|
||||
**hidden_params,
|
||||
"custom_llm_provider": _logging_obj_llm_provider,
|
||||
"created_at": time.time(),
|
||||
**self._base_hidden_params,
|
||||
}
|
||||
else:
|
||||
model_response._hidden_params = {
|
||||
"custom_llm_provider": _logging_obj_llm_provider,
|
||||
"created_at": time.time(),
|
||||
**self._base_hidden_params,
|
||||
}
|
||||
|
||||
if (
|
||||
len(model_response.choices) > 0
|
||||
@ -1627,7 +1648,17 @@ class CustomStreamWrapper:
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
# Get request kwargs from logging object
|
||||
if self._post_streaming_hooks is None:
|
||||
self._post_streaming_hooks = [
|
||||
cb
|
||||
for cb in litellm.callbacks
|
||||
if isinstance(cb, CustomLogger)
|
||||
and hasattr(cb, "async_post_call_streaming_deployment_hook")
|
||||
]
|
||||
|
||||
if not self._post_streaming_hooks:
|
||||
return chunk
|
||||
|
||||
request_data = self.logging_obj.model_call_details
|
||||
call_type_str = self.logging_obj.call_type
|
||||
|
||||
@ -1636,18 +1667,14 @@ class CustomStreamWrapper:
|
||||
except ValueError:
|
||||
typed_call_type = None
|
||||
|
||||
# Call hooks for all callbacks
|
||||
for callback in litellm.callbacks:
|
||||
if isinstance(callback, CustomLogger) and hasattr(
|
||||
callback, "async_post_call_streaming_deployment_hook"
|
||||
):
|
||||
result = await callback.async_post_call_streaming_deployment_hook(
|
||||
request_data=request_data,
|
||||
response_chunk=chunk,
|
||||
call_type=typed_call_type,
|
||||
)
|
||||
if result is not None:
|
||||
chunk = result
|
||||
for callback in self._post_streaming_hooks:
|
||||
result = await callback.async_post_call_streaming_deployment_hook(
|
||||
request_data=request_data,
|
||||
response_chunk=chunk,
|
||||
call_type=typed_call_type,
|
||||
)
|
||||
if result is not None:
|
||||
chunk = result
|
||||
|
||||
return chunk
|
||||
except Exception as e:
|
||||
@ -1888,17 +1915,15 @@ class CustomStreamWrapper:
|
||||
response = self._add_mcp_list_tools_to_first_chunk(response)
|
||||
self.sent_first_chunk = True
|
||||
|
||||
if hasattr(
|
||||
response, "usage"
|
||||
): # remove usage from chunk, only send on final chunk
|
||||
# Convert the object to a dictionary
|
||||
# ModelResponseStream declares `usage` as a field, so
|
||||
# hasattr(response, "usage") is always True — must check
|
||||
# `is not None` to avoid running this path on every chunk.
|
||||
if getattr(response, "usage", None) is not None:
|
||||
obj_dict = response.model_dump()
|
||||
|
||||
# Remove an attribute (e.g., 'attr2')
|
||||
if "usage" in obj_dict:
|
||||
del obj_dict["usage"]
|
||||
|
||||
# Create a new object without the removed attribute
|
||||
response = self.model_response_creator(
|
||||
chunk=obj_dict, hidden_params=response._hidden_params
|
||||
)
|
||||
@ -2398,10 +2423,7 @@ def generic_chunk_has_all_required_fields(chunk: dict) -> bool:
|
||||
:param chunk: The dictionary to check.
|
||||
:return: True if all required fields are present, False otherwise.
|
||||
"""
|
||||
_all_fields = GChunk.__annotations__
|
||||
|
||||
decision = all(key in _all_fields for key in chunk)
|
||||
return decision
|
||||
return all(key in _GCHUNK_FIELDS for key in chunk)
|
||||
|
||||
|
||||
def convert_generic_chunk_to_model_response_stream(
|
||||
|
||||
191
scripts/benchmark_model_response_creator.py
Normal file
191
scripts/benchmark_model_response_creator.py
Normal file
@ -0,0 +1,191 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Tight microbenchmark for CustomStreamWrapper.model_response_creator.
|
||||
|
||||
Calls model_response_creator() in a tight loop on a pre-built wrapper to
|
||||
isolate per-call cost. Driving the full wrapper adds threadpool logging,
|
||||
gc, and other noise that swamps microsecond-scale changes here.
|
||||
|
||||
Example:
|
||||
uv run python scripts/benchmark_model_response_creator.py --label baseline
|
||||
uv run python scripts/benchmark_model_response_creator.py --label optimized
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import gc
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import statistics
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
os.environ.setdefault("LITELLM_LOG", "ERROR")
|
||||
logging.getLogger("LiteLLM").setLevel(logging.ERROR)
|
||||
|
||||
import litellm # noqa: E402
|
||||
|
||||
litellm.suppress_debug_info = True
|
||||
|
||||
from litellm.litellm_core_utils.streaming_handler import (
|
||||
CustomStreamWrapper,
|
||||
) # noqa: E402
|
||||
|
||||
|
||||
def _make_logging_obj(provider: str) -> MagicMock:
    """Build a ``MagicMock`` standing in for the wrapper's logging object.

    Args:
        provider: Value stored under ``model_call_details["custom_llm_provider"]``.

    Returns:
        A ``MagicMock`` with the attributes assigned below set to fixed values.
    """
    logging_obj = MagicMock()
    logging_obj.model_call_details = {
        "custom_llm_provider": provider,
        "litellm_params": {},
    }
    logging_obj.call_type = "completion"
    logging_obj.stream_options = None
    logging_obj.messages = [{"role": "user", "content": "hi"}]
    logging_obj.completion_start_time = None
    logging_obj._llm_caching_handler = None
    return logging_obj
|
||||
|
||||
|
||||
def _make_wrapper(provider: str, model: str) -> CustomStreamWrapper:
    """Create a ``CustomStreamWrapper`` over an empty stream.

    The stream is never iterated by this script; the wrapper exists only so
    ``model_response_creator`` can be called on it directly.

    Args:
        provider: Passed as ``custom_llm_provider`` and to the logging stub.
        model: Model name handed to the wrapper.
    """
    return CustomStreamWrapper(
        completion_stream=iter([]),
        model=model,
        logging_obj=_make_logging_obj(provider),
        custom_llm_provider=provider,
    )
|
||||
|
||||
|
||||
@dataclass
class Result:
    """One benchmark scenario's measurements, serialisable via ``asdict``."""

    label: str  # run label supplied on the command line
    scenario: str  # key into SCENARIOS
    iterations: int  # calls per timed sample
    elapsed_min_s: float  # fastest sample, seconds
    elapsed_median_s: float  # median sample, seconds
    per_call_us: float  # derived from the fastest sample
    calls_per_sec: float  # derived from the fastest sample
|
||||
|
||||
|
||||
# Benchmark scenarios keyed by name. Each entry carries a human-readable
# description and a ``chunk_factory(i)`` that builds the ``chunk`` argument for
# the i-th call (``None`` means "call with no chunk").
SCENARIOS = {
    "no_chunk": {
        "description": "model_response_creator() — no chunk arg (most common path)",
        "chunk_factory": lambda i: None,
    },
    "text_chunk": {
        "description": "model_response_creator(chunk={'text': '...'}) — text delta path",
        "chunk_factory": lambda i: {"text": f"token{i}"},
    },
    "rich_chunk": {
        "description": "model_response_creator(chunk={...}) — full chunk dict path",
        "chunk_factory": lambda i: {
            "id": f"id-{i}",
            "object": "chat.completion.chunk",
            "created": 1234567890,
        },
    },
}
|
||||
|
||||
|
||||
def bench_no_chunk(wrapper: CustomStreamWrapper, iterations: int) -> float:
    """Time ``iterations`` consecutive ``model_response_creator()`` calls.

    The garbage collector is paused during the timed region and afterwards
    returned to whatever state (enabled or disabled) it was in on entry.

    Args:
        wrapper: Object exposing ``model_response_creator()``.
        iterations: Number of calls made inside the timed region.

    Returns:
        Elapsed seconds as measured by ``time.perf_counter``.
    """
    gc_was_enabled = gc.isenabled()
    gc.collect()
    gc.disable()
    try:
        start = time.perf_counter()
        for _ in range(iterations):
            wrapper.model_response_creator()
        elapsed = time.perf_counter() - start
    finally:
        # Restore rather than force-enable so a caller that deliberately
        # disabled GC does not have it switched back on behind its back.
        if gc_was_enabled:
            gc.enable()
    return elapsed
|
||||
|
||||
|
||||
def bench_with_chunk(wrapper: CustomStreamWrapper, factory, iterations: int) -> float:
    """Time ``model_response_creator(chunk=...)`` over pre-built chunks.

    Args:
        wrapper: Object exposing ``model_response_creator(chunk=...)``.
        factory: Callable ``factory(i)`` returning the i-th chunk mapping.
        iterations: Number of chunks to build and feed through the wrapper.

    Returns:
        Elapsed seconds as measured by ``time.perf_counter``.
    """
    # Pre-build chunks so we don't measure their construction cost.
    chunks = [factory(i) for i in range(iterations)]
    gc_was_enabled = gc.isenabled()
    gc.collect()
    gc.disable()
    try:
        start = time.perf_counter()
        for chunk in chunks:
            wrapper.model_response_creator(chunk=dict(chunk))  # copy because mutated
        elapsed = time.perf_counter() - start
    finally:
        # Put GC back the way the caller had it instead of always enabling.
        if gc_was_enabled:
            gc.enable()
    return elapsed
|
||||
|
||||
|
||||
def run_scenario(
    label: str,
    scenario_key: str,
    iterations: int,
    repeats: int,
    warmup: int,
) -> Result:
    """Run one scenario and summarise its timings.

    Args:
        label: Run label copied into the result.
        scenario_key: Key into ``SCENARIOS``.
        iterations: Calls per timed sample; must be at least 1.
        repeats: Number of timed samples; must be at least 1.
        warmup: Number of untimed samples run first.

    Returns:
        A ``Result`` whose per-call figures are derived from the fastest sample.

    Raises:
        ValueError: If ``iterations`` or ``repeats`` is less than 1.
        KeyError: If ``scenario_key`` is not in ``SCENARIOS``.
    """
    # Without this guard a zero value surfaces later as an opaque
    # ZeroDivisionError / "min() arg is an empty sequence".
    if iterations < 1 or repeats < 1:
        raise ValueError("iterations and repeats must both be >= 1")

    spec = SCENARIOS[scenario_key]
    wrapper = _make_wrapper(provider="anthropic", model="claude-3-5-sonnet")

    if scenario_key == "no_chunk":

        def runner() -> float:
            return bench_no_chunk(wrapper, iterations)

    else:

        def runner() -> float:
            return bench_with_chunk(wrapper, spec["chunk_factory"], iterations)

    for _ in range(warmup):
        runner()
    samples = [runner() for _ in range(repeats)]

    elapsed_min = min(samples)
    elapsed_median = statistics.median(samples)
    per_call_us = (elapsed_min * 1_000_000) / iterations
    calls_per_sec = iterations / elapsed_min if elapsed_min > 0 else 0.0

    return Result(
        label=label,
        scenario=scenario_key,
        iterations=iterations,
        elapsed_min_s=elapsed_min,
        elapsed_median_s=elapsed_median,
        per_call_us=per_call_us,
        calls_per_sec=calls_per_sec,
    )
|
||||
|
||||
|
||||
def main() -> None:
    """Parse CLI options, run every scenario, print a table, optionally dump JSON."""
    # ``__doc__`` is None when docstrings are stripped (python -OO); fall back
    # to an empty description instead of raising AttributeError.
    description = (__doc__ or "").strip().partition("\n")[0]
    ap = argparse.ArgumentParser(description=description)
    ap.add_argument("--label", required=True)
    ap.add_argument("--iterations", type=int, default=200_000)
    ap.add_argument("--warmup", type=int, default=2)
    ap.add_argument("--repeats", type=int, default=8)
    ap.add_argument("--json", dest="json_out")
    args = ap.parse_args()

    print(
        f"\n=== label={args.label} iterations={args.iterations:,} "
        f"warmup={args.warmup} repeats={args.repeats} (min reported) ==="
    )
    results: List[Result] = []
    for scenario in SCENARIOS:
        r = run_scenario(
            args.label, scenario, args.iterations, args.repeats, args.warmup
        )
        results.append(r)
        print(
            f" {r.scenario:12s}: "
            f"min={r.elapsed_min_s*1000:8.2f} ms "
            f"median={r.elapsed_median_s*1000:8.2f} ms "
            f"per-call={r.per_call_us:7.3f} μs "
            f"calls/s={r.calls_per_sec:>12,.0f}"
        )

    if args.json_out:
        with open(args.json_out, "w", encoding="utf-8") as f:
            json.dump([asdict(r) for r in results], f, indent=2)
        print(f"\nWrote {len(results)} results to {args.json_out}")
|
||||
|
||||
|
||||
# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
369
scripts/benchmark_streaming_chunk_overhead.py
Normal file
369
scripts/benchmark_streaming_chunk_overhead.py
Normal file
@ -0,0 +1,369 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Benchmark CustomStreamWrapper per-chunk overhead.
|
||||
|
||||
Drives CustomStreamWrapper directly with synthetic in-memory chunks for
|
||||
Anthropic (GenericStreamingChunk), Bedrock Invoke (GenericStreamingChunk),
|
||||
and Bedrock Converse (ModelResponseStream). A full proxy benchmark adds
|
||||
FastAPI, HTTP, and TCP latency, which dilutes the per-chunk CPU signal.
|
||||
|
||||
Example:
|
||||
uv run python scripts/benchmark_streaming_chunk_overhead.py \\
|
||||
--streams 500 --chunks 200 --warmup 50 --repeats 5
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import gc
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import statistics
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from typing import Callable, List, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
# Silence litellm's "Provider List" warnings emitted by get_llm_provider
|
||||
# when it sees synthetic model names — we're not exercising provider
|
||||
# routing, only the per-chunk wrapper hot path.
|
||||
os.environ.setdefault("LITELLM_LOG", "ERROR")
|
||||
logging.getLogger("LiteLLM").setLevel(logging.ERROR)
|
||||
|
||||
import litellm # noqa: E402
|
||||
|
||||
litellm.suppress_debug_info = True
|
||||
|
||||
from litellm.litellm_core_utils.streaming_handler import (
|
||||
CustomStreamWrapper,
|
||||
) # noqa: E402
|
||||
from litellm.types.utils import ( # noqa: E402
|
||||
Delta,
|
||||
GenericStreamingChunk as GChunk,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
Usage,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Synthetic chunk fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_logging_obj(provider: str) -> MagicMock:
    """Build a ``MagicMock`` standing in for the wrapper's logging object.

    Args:
        provider: Value stored under ``model_call_details["custom_llm_provider"]``.

    Returns:
        A ``MagicMock`` with the attributes assigned below set to fixed values.
    """
    logging_obj = MagicMock()
    logging_obj.model_call_details = {
        "custom_llm_provider": provider,
        "litellm_params": {},
    }
    logging_obj.call_type = "completion"
    logging_obj.stream_options = None
    logging_obj.messages = [{"role": "user", "content": "hi"}]
    logging_obj.completion_start_time = None
    logging_obj._llm_caching_handler = None
    return logging_obj
|
||||
|
||||
|
||||
def _make_generic_chunk(
    text: str,
    is_finished: bool = False,
    finish_reason: str = "",
    usage: Optional[dict] = None,
) -> GChunk:
    """Build a synthetic ``GChunk`` with ``index=0`` and no tool use.

    Args:
        text: Text payload of the chunk.
        is_finished: Value for the chunk's ``is_finished`` field.
        finish_reason: Value for the chunk's ``finish_reason`` field.
        usage: Optional usage mapping attached to the chunk.
    """
    return GChunk(
        text=text,
        is_finished=is_finished,
        finish_reason=finish_reason,
        usage=usage,
        index=0,
        tool_use=None,
    )
|
||||
|
||||
|
||||
def _make_converse_chunk(
    text: str = "",
    finish_reason: str = "",
    usage: Optional[Usage] = None,
) -> ModelResponseStream:
    """Build a synthetic single-choice ``ModelResponseStream`` chunk.

    Args:
        text: Content placed in the choice's assistant delta.
        finish_reason: Finish reason; an empty string is passed on as ``None``.
        usage: Optional ``Usage`` attached to the chunk.
    """
    return ModelResponseStream(
        choices=[
            StreamingChoices(
                finish_reason=finish_reason or None,
                index=0,
                delta=Delta(content=text, role="assistant"),
            )
        ],
        id="msg-bench",
        model="anthropic.claude-3-5-sonnet",
        usage=usage,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Provider stream factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def anthropic_chunks(n: int) -> List[GChunk]:
    """Return ``n`` text chunks followed by one finish chunk carrying usage.

    The resulting list therefore has ``n + 1`` entries.
    """
    out: List[GChunk] = [_make_generic_chunk(f"tok{i} ") for i in range(n)]
    out.append(
        _make_generic_chunk(
            "",
            is_finished=True,
            finish_reason="stop",
            usage={"prompt_tokens": 10, "completion_tokens": n, "total_tokens": 10 + n},
        )
    )
    return out
|
||||
|
||||
|
||||
def bedrock_invoke_chunks(n: int) -> List[GChunk]:
    """Return the same chunk sequence as ``anthropic_chunks(n)``."""
    # Bedrock Invoke surfaces GChunk-shaped dicts, same shape as Anthropic.
    return anthropic_chunks(n)
|
||||
|
||||
|
||||
def bedrock_converse_chunks(n: int) -> List[ModelResponseStream]:
    """Return ``n`` text chunks followed by one finish chunk carrying usage.

    The resulting list therefore has ``n + 1`` entries.
    """
    out: List[ModelResponseStream] = [
        _make_converse_chunk(f"tok{i} ") for i in range(n)
    ]
    out.append(
        _make_converse_chunk(
            text="",
            finish_reason="stop",
            usage=Usage(prompt_tokens=10, completion_tokens=n, total_tokens=10 + n),
        )
    )
    return out
|
||||
|
||||
|
||||
# Benchmark key -> (custom_llm_provider passed to the wrapper, chunk-list factory).
PROVIDERS: dict[str, tuple[str, Callable[[int], list]]] = {
    "anthropic": ("anthropic", anthropic_chunks),
    "bedrock_invoke": ("bedrock", bedrock_invoke_chunks),
    "bedrock_converse": ("bedrock", bedrock_converse_chunks),
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Drive a single stream end-to-end
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_wrapper(
    chunks: list, provider: str, async_stream: bool
) -> CustomStreamWrapper:
    """Wrap ``chunks`` in a ``CustomStreamWrapper``.

    Args:
        chunks: Pre-built chunk objects to replay.
        provider: Passed as ``custom_llm_provider`` and to the logging stub.
        async_stream: When true the chunks are exposed through an async
            generator; otherwise through a plain iterator.
    """
    logging_obj = _make_logging_obj(provider)
    if async_stream:

        async def _agen():
            for c in chunks:
                yield c

        stream = _agen()
    else:
        stream = iter(chunks)
    return CustomStreamWrapper(
        completion_stream=stream,
        model="claude-3-5-sonnet",
        logging_obj=logging_obj,
        custom_llm_provider=provider,
    )
|
||||
|
||||
|
||||
def drive_sync(provider_key: str, chunks_per_stream: int, n_streams: int) -> float:
    """Iterate ``n_streams`` wrappers synchronously and return elapsed seconds.

    The garbage collector is paused during the timed region and afterwards
    returned to whatever state it was in on entry.

    Args:
        provider_key: Key into ``PROVIDERS``.
        chunks_per_stream: Argument passed to the provider's chunk factory.
        n_streams: Number of streams to build and drain.
    """
    provider, factory = PROVIDERS[provider_key]
    # Pre-build the chunk lists; we only measure wrapper iteration cost.
    chunk_lists = [factory(chunks_per_stream) for _ in range(n_streams)]
    gc_was_enabled = gc.isenabled()
    gc.collect()
    gc.disable()
    try:
        start = time.perf_counter()
        for chunks in chunk_lists:
            wrapper = _make_wrapper(chunks, provider, async_stream=False)
            for _ in wrapper:
                pass
        elapsed = time.perf_counter() - start
    finally:
        # Restore rather than force-enable the collector.
        if gc_was_enabled:
            gc.enable()
    return elapsed
|
||||
|
||||
|
||||
async def drive_async(
    provider_key: str, chunks_per_stream: int, n_streams: int
) -> float:
    """Iterate ``n_streams`` wrappers with ``async for`` and return elapsed seconds.

    Streams are drained one after another, not concurrently. The garbage
    collector is paused during the timed region and afterwards returned to
    whatever state it was in on entry.

    Args:
        provider_key: Key into ``PROVIDERS``.
        chunks_per_stream: Argument passed to the provider's chunk factory.
        n_streams: Number of streams to build and drain.
    """
    provider, factory = PROVIDERS[provider_key]
    chunk_lists = [factory(chunks_per_stream) for _ in range(n_streams)]
    gc_was_enabled = gc.isenabled()
    gc.collect()
    gc.disable()
    try:
        start = time.perf_counter()
        for chunks in chunk_lists:
            wrapper = _make_wrapper(chunks, provider, async_stream=True)
            async for _ in wrapper:
                pass
        elapsed = time.perf_counter() - start
    finally:
        # Restore rather than force-enable the collector.
        if gc_was_enabled:
            gc.enable()
    return elapsed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Repeat × take-min runner
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class Result:
    """One (provider, mode) benchmark case, serialisable via ``asdict``."""

    label: str  # run label supplied on the command line
    provider: str  # key into PROVIDERS
    mode: str  # "sync" or "async"
    streams: int  # streams per timed sample
    chunks_per_stream: int  # text chunks per stream, excluding the finish chunk
    total_chunks: int  # streams * (chunks_per_stream + 1)
    elapsed_min_s: float  # fastest sample, seconds
    elapsed_median_s: float  # median sample, seconds
    per_chunk_us: float  # derived from the fastest sample
    chunks_per_sec: float  # derived from the fastest sample
    streams_per_sec: float  # derived from the fastest sample
|
||||
|
||||
|
||||
def run_case(
    label: str,
    provider_key: str,
    mode: str,
    chunks_per_stream: int,
    n_streams: int,
    repeats: int,
    warmup: int,
) -> Result:
    """Run one (provider, mode) case and summarise its timings.

    Args:
        label: Run label copied into the result.
        provider_key: Key into ``PROVIDERS``.
        mode: ``"sync"`` or ``"async"``.
        chunks_per_stream: Text chunks per stream (the finish chunk is extra).
        n_streams: Streams per timed sample; must be at least 1.
        repeats: Number of timed samples; must be at least 1.
        warmup: Number of untimed warmup runs, each over a tenth of the
            streams (minimum one stream).

    Returns:
        A ``Result`` whose rate figures are derived from the fastest sample.

    Raises:
        ValueError: If ``mode`` is unknown, or ``repeats`` / ``n_streams`` is
            less than 1.
    """
    if mode not in ("sync", "async"):
        raise ValueError(f"unknown mode {mode!r}")
    # Without this guard a zero value surfaces later as an opaque
    # "min() arg is an empty sequence" or ZeroDivisionError.
    if repeats < 1 or n_streams < 1:
        raise ValueError("repeats and n_streams must both be >= 1")

    warmup_streams = max(1, n_streams // 10)

    if mode == "sync":
        # Warmup runs amortize import-time and JIT-y caches.
        for _ in range(warmup):
            drive_sync(provider_key, chunks_per_stream, warmup_streams)
        samples = [
            drive_sync(provider_key, chunks_per_stream, n_streams)
            for _ in range(repeats)
        ]
    else:

        async def _warm():
            for _ in range(warmup):
                await drive_async(provider_key, chunks_per_stream, warmup_streams)

        asyncio.run(_warm())
        samples = [
            asyncio.run(drive_async(provider_key, chunks_per_stream, n_streams))
            for _ in range(repeats)
        ]

    elapsed_min = min(samples)
    elapsed_median = statistics.median(samples)
    # Each stream emits chunks_per_stream text chunks + 1 finish/usage chunk.
    total_chunks = n_streams * (chunks_per_stream + 1)
    per_chunk_us = (elapsed_min * 1_000_000) / total_chunks
    chunks_per_sec = total_chunks / elapsed_min if elapsed_min > 0 else 0.0
    streams_per_sec = n_streams / elapsed_min if elapsed_min > 0 else 0.0

    return Result(
        label=label,
        provider=provider_key,
        mode=mode,
        streams=n_streams,
        chunks_per_stream=chunks_per_stream,
        total_chunks=total_chunks,
        elapsed_min_s=elapsed_min,
        elapsed_median_s=elapsed_median,
        per_chunk_us=per_chunk_us,
        chunks_per_sec=chunks_per_sec,
        streams_per_sec=streams_per_sec,
    )
|
||||
|
||||
|
||||
def format_result(r: Result) -> str:
    """Render one result as a single fixed-width table row.

    Elapsed times are shown in milliseconds; the remaining figures are taken
    from the result unchanged.
    """
    return (
        f" {r.provider:18s} {r.mode:5s}: "
        f"min={r.elapsed_min_s*1000:8.2f} ms "
        f"median={r.elapsed_median_s*1000:8.2f} ms "
        f"per-chunk={r.per_chunk_us:7.2f} μs "
        f"chunks/s={r.chunks_per_sec:>10,.0f} "
        f"streams/s={r.streams_per_sec:>8,.1f}"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def main() -> None:
    """Parse CLI options, run every provider/mode case, print and optionally dump JSON.

    Raises:
        SystemExit: If an unknown provider or mode is requested.
    """
    # ``__doc__`` is None when docstrings are stripped (python -OO); fall back
    # to an empty description instead of raising AttributeError.
    description = (__doc__ or "").strip().partition("\n")[0]
    ap = argparse.ArgumentParser(description=description)
    ap.add_argument(
        "--label", required=True, help="Run label (e.g. baseline / optimized)"
    )
    ap.add_argument("--streams", type=int, default=500, help="Streams per run")
    ap.add_argument(
        "--chunks",
        type=int,
        default=200,
        help="Text chunks per stream (excl. finish chunk)",
    )
    ap.add_argument("--warmup", type=int, default=2, help="Warmup runs")
    ap.add_argument(
        "--repeats", type=int, default=5, help="Measured runs (we report min)"
    )
    ap.add_argument(
        "--providers",
        default="anthropic,bedrock_invoke,bedrock_converse",
        help="Comma-separated provider list",
    )
    ap.add_argument(
        "--modes",
        default="sync,async",
        help="Comma-separated iteration modes (sync/async)",
    )
    ap.add_argument(
        "--json", dest="json_out", help="Write results as JSON to this path"
    )
    args = ap.parse_args()

    providers = [p.strip() for p in args.providers.split(",") if p.strip()]
    modes = [m.strip() for m in args.modes.split(",") if m.strip()]

    # Reject bad selections before any (slow) benchmark work starts.
    for p in providers:
        if p not in PROVIDERS:
            raise SystemExit(f"unknown provider {p!r}; choose from {list(PROVIDERS)}")
    for m in modes:
        if m not in {"sync", "async"}:
            raise SystemExit(f"unknown mode {m!r}; choose from sync/async")

    print(
        f"\n=== label={args.label} streams={args.streams} chunks/stream={args.chunks} "
        f"warmup={args.warmup} repeats={args.repeats} (min reported) ==="
    )
    results: List[Result] = []
    for provider_key in providers:
        for mode in modes:
            r = run_case(
                label=args.label,
                provider_key=provider_key,
                mode=mode,
                chunks_per_stream=args.chunks,
                n_streams=args.streams,
                repeats=args.repeats,
                warmup=args.warmup,
            )
            results.append(r)
            print(format_result(r))

    if args.json_out:
        with open(args.json_out, "w", encoding="utf-8") as f:
            json.dump([asdict(r) for r in results], f, indent=2)
        print(f"\nWrote {len(results)} results to {args.json_out}")
|
||||
|
||||
|
||||
# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
508
tests/test_litellm/litellm_core_utils/test_streaming_overhead.py
Normal file
508
tests/test_litellm/litellm_core_utils/test_streaming_overhead.py
Normal file
@ -0,0 +1,508 @@
|
||||
"""
|
||||
Tests for CustomStreamWrapper per-chunk behavior across Anthropic,
|
||||
Bedrock Invoke, and Bedrock Converse: text passthrough, usage stripping,
|
||||
hidden_params propagation, finish_reason, sync/async parity, and the
|
||||
per-stream caches (_GCHUNK_FIELDS, _post_streaming_hooks).
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import List, Optional
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.streaming_handler import (
|
||||
CustomStreamWrapper,
|
||||
_GCHUNK_FIELDS,
|
||||
generic_chunk_has_all_required_fields,
|
||||
)
|
||||
from litellm.types.utils import (
|
||||
Delta,
|
||||
GenericStreamingChunk as GChunk,
|
||||
ModelResponseStream,
|
||||
StreamingChoices,
|
||||
Usage,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_logging_obj(provider: str = "anthropic") -> MagicMock:
    """Build a ``MagicMock`` standing in for the wrapper's logging object.

    Args:
        provider: Value stored under ``model_call_details["custom_llm_provider"]``.

    Returns:
        A ``MagicMock`` with the attributes assigned below set to fixed values.
    """
    logging_obj = MagicMock()
    logging_obj.model_call_details = {
        "custom_llm_provider": provider,
        "litellm_params": {},
    }
    logging_obj.call_type = "completion"
    logging_obj.stream_options = None
    logging_obj.messages = [{"role": "user", "content": "hi"}]
    logging_obj.completion_start_time = None
    logging_obj._llm_caching_handler = None
    return logging_obj
|
||||
|
||||
|
||||
def _make_generic_chunk(
    text: str,
    is_finished: bool = False,
    finish_reason: str = "",
    usage: Optional[dict] = None,
) -> GChunk:
    """Build a ``GChunk`` with ``index=0`` and no tool use.

    Args:
        text: Text payload of the chunk.
        is_finished: Value for the chunk's ``is_finished`` field.
        finish_reason: Value for the chunk's ``finish_reason`` field.
        usage: Optional usage mapping attached to the chunk.
    """
    return GChunk(
        text=text,
        is_finished=is_finished,
        finish_reason=finish_reason,
        usage=usage,
        index=0,
        tool_use=None,
    )
|
||||
|
||||
|
||||
def _make_bedrock_converse_chunk(
    text: str = "",
    finish_reason: str = "",
    usage: Optional[Usage] = None,
) -> ModelResponseStream:
    """Simulate what AWSEventStreamDecoder.converse_chunk_parser returns.

    Args:
        text: Content placed in the choice's assistant delta.
        finish_reason: Finish reason; an empty string is passed on as ``None``.
        usage: Optional ``Usage`` attached to the chunk.
    """
    return ModelResponseStream(
        choices=[
            StreamingChoices(
                finish_reason=finish_reason or None,
                index=0,
                delta=Delta(content=text, role="assistant"),
            )
        ],
        id="msg-test",
        model="anthropic.claude-3-5-sonnet",
        usage=usage,
    )
|
||||
|
||||
|
||||
async def _async_iter(chunks: list):
    """Wrap a list as a proper async iterator for use in __anext__ async branch.

    Yields each element of ``chunks`` in order.
    """
    for chunk in chunks:
        yield chunk
|
||||
|
||||
|
||||
def _make_wrapper(
    chunks: list,
    provider: str = "anthropic",
    async_stream: bool = False,
) -> CustomStreamWrapper:
    """Wrap ``chunks`` in a ``CustomStreamWrapper`` for the given provider.

    Args:
        chunks: Chunk objects to replay through the wrapper.
        provider: Passed as ``custom_llm_provider`` and to the logging stub.
        async_stream: When true the chunks are exposed via ``_async_iter``;
            otherwise via a plain iterator.
    """
    logging_obj = _make_logging_obj(provider)
    stream = _async_iter(chunks) if async_stream else iter(chunks)
    return CustomStreamWrapper(
        completion_stream=stream,
        model="claude-3-5-sonnet",
        logging_obj=logging_obj,
        custom_llm_provider=provider,
    )
|
||||
|
||||
|
||||
def _drain_sync(wrapper: CustomStreamWrapper) -> List[ModelResponseStream]:
    """Iterate ``wrapper`` to exhaustion and return everything it yielded."""
    return list(wrapper)
|
||||
|
||||
|
||||
async def _drain_async(wrapper: CustomStreamWrapper) -> List[ModelResponseStream]:
    """Iterate ``wrapper`` with ``async for`` to exhaustion and return the items."""
    return [chunk async for chunk in wrapper]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Module-level _GCHUNK_FIELDS constant
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_gchunk_fields_is_frozenset():
    """_GCHUNK_FIELDS must be a frozenset built from GChunk.__annotations__."""
    assert isinstance(_GCHUNK_FIELDS, frozenset)
    assert _GCHUNK_FIELDS == frozenset(GChunk.__annotations__)
|
||||
|
||||
|
||||
def test_generic_chunk_has_all_required_fields_uses_module_constant(monkeypatch):
    """generic_chunk_has_all_required_fields must use _GCHUNK_FIELDS, not __annotations__.

    The check semantics: every key in `chunk` must be a known GChunk field.
    This identifies GChunk-shaped dicts (all keys are valid GChunk fields).
    """
    from litellm.litellm_core_utils import streaming_handler

    valid_chunk = _make_generic_chunk("hello")
    assert generic_chunk_has_all_required_fields(valid_chunk) is True

    # A dict with an extra unknown key should return False — the unknown key
    # is not a GChunk field, so the chunk is not a pure GChunk.
    extra_key_chunk = dict(valid_chunk)
    extra_key_chunk["unknown_extra_key"] = "value"
    assert generic_chunk_has_all_required_fields(extra_key_chunk) is False

    # A dict with only known GChunk fields but fewer keys still passes because
    # all its keys are valid (subset of GChunk fields).
    partial_chunk = {"text": "hi", "is_finished": False}
    assert generic_chunk_has_all_required_fields(partial_chunk) is True

    # Prove the function actually consults the module constant: widening the
    # constant (GChunk.__annotations__ is untouched) must flip the verdict for
    # the chunk carrying the extra key.
    monkeypatch.setattr(
        streaming_handler,
        "_GCHUNK_FIELDS",
        _GCHUNK_FIELDS | {"unknown_extra_key"},
    )
    assert generic_chunk_has_all_required_fields(extra_key_chunk) is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Cached model name and provider at init time
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_cached_model_name_simple():
    """For non-openai providers the cached model name must match the model arg."""
    wrapper = _make_wrapper([], provider="anthropic")
    assert wrapper._cached_model_name == "claude-3-5-sonnet"
    assert wrapper._cached_logging_llm_provider == "anthropic"
|
||||
|
||||
|
||||
def test_cached_model_name_openai_prefix():
    """For openai provider when logging provider differs, model name is prefixed."""
    logging_obj = _make_logging_obj(provider="azure")
    wrapper = CustomStreamWrapper(
        completion_stream=iter([]),
        model="gpt-4o",
        logging_obj=logging_obj,
        custom_llm_provider="openai",
    )
    assert wrapper._cached_model_name == "azure/gpt-4o"
    assert wrapper._cached_logging_llm_provider == "azure"
|
||||
|
||||
|
||||
def test_base_hidden_params_precomputed():
    """``_base_hidden_params`` is available right after init and covers ``_hidden_params``.

    Checks that the snapshot carries a ``response_cost`` entry set to ``None``
    and that every key of ``_hidden_params`` is also present in it.
    """
    stream_wrapper = _make_wrapper([], provider="anthropic")
    snapshot = stream_wrapper._base_hidden_params

    assert "response_cost" in snapshot
    assert snapshot["response_cost"] is None

    # Every key of _hidden_params must be carried over into the snapshot.
    for hidden_key in stream_wrapper._hidden_params:
        assert hidden_key in snapshot
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Sync path: model_dump() is NOT called on non-usage chunks
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_sync_path_no_model_dump_on_text_chunks():
    """
    Draining a usage-free stream synchronously must not call model_dump() per chunk.

    ModelResponseStream declares `usage` as a field, so a `hasattr` check
    would always succeed and trigger the model_dump()+recreate path on every
    chunk. The wrapper must check `is not None` instead.

    The test swaps ModelResponseStream.model_dump for a recording wrapper,
    drains three generic chunks and tolerates at most one recorded call.
    """
    wrapper = _make_wrapper(
        [
            _make_generic_chunk("Hello"),
            _make_generic_chunk(" world"),
            _make_generic_chunk("", is_finished=True, finish_reason="stop"),
        ]
    )

    # One entry is appended per model_dump() invocation; a list avoids `nonlocal`.
    dump_calls = []
    real_model_dump = ModelResponseStream.model_dump

    def recording_model_dump(self, **kwargs):
        dump_calls.append(1)
        return real_model_dump(self, **kwargs)

    with patch.object(ModelResponseStream, "model_dump", recording_model_dump):
        streamed = _drain_sync(wrapper)

    text_chunks = [
        item for item in streamed if item.choices and item.choices[0].delta.content
    ]
    assert len(text_chunks) >= 2, "Expected at least 2 text chunks"

    model_dump_call_count = len(dump_calls)
    assert model_dump_call_count <= 1, (
        f"model_dump() called {model_dump_call_count} times — "
        "usage check is firing on every chunk"
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Sync path: usage chunk is stripped from body but preserved in hidden_params
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_sync_path_usage_stripped_from_body_preserved_in_hidden_params():
    """Sync drain: the finish chunk is returned and usage lands in _hidden_params.

    NOTE(review): despite the test name, nothing here asserts that usage is
    absent from the returned chunk body; only the finish-reason chunk and the
    ``_hidden_params["usage"]`` entry of the last result are checked.
    """
    usage_payload = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
    wrapper = _make_wrapper(
        [
            _make_generic_chunk("Hello"),
            _make_generic_chunk(
                "", is_finished=True, finish_reason="stop", usage=usage_payload
            ),
        ]
    )
    streamed = _drain_sync(wrapper)

    # The chunk that carried usage must still be emitted, not silently dropped.
    stop_chunks = [
        item
        for item in streamed
        if item.choices and item.choices[0].finish_reason == "stop"
    ]
    assert stop_chunks, "Finish-reason chunk was not returned"

    # Usage must be reachable through the last result's hidden params.
    last_hidden = streamed[-1]._hidden_params
    assert "usage" in last_hidden, "usage missing from _hidden_params"
    assert last_hidden["usage"] is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Async path: usage chunk is stripped from body but preserved in hidden_params
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_async_path_usage_stripped_from_body_preserved_in_hidden_params():
    """Async drain: usage from the final chunk must appear in _hidden_params.

    Mirrors the sync usage test, but drains the wrapper through the async
    helper inside ``asyncio.run``.
    """
    usage_payload = {"prompt_tokens": 5, "completion_tokens": 15, "total_tokens": 20}
    stream_chunks = [
        _make_generic_chunk("Hi"),
        _make_generic_chunk(
            "", is_finished=True, finish_reason="stop", usage=usage_payload
        ),
    ]

    async def _collect():
        # async_stream=True forces the real async-for branch of __anext__
        async_wrapper = _make_wrapper(stream_chunks, async_stream=True)
        return await _drain_async(async_wrapper)

    streamed = asyncio.run(_collect())
    last_hidden = streamed[-1]._hidden_params
    assert "usage" in last_hidden
    assert last_hidden["usage"] is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. Bedrock Converse: ModelResponseStream chunks pass through correctly
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_bedrock_converse_text_chunks_pass_through():
    """
    Bedrock Converse chunks drained synchronously must surface their text.

    Three chunks built by _make_bedrock_converse_chunk are fed through a
    "bedrock" wrapper; the test passes when at least one emitted chunk has
    delta content containing "Hello".
    """
    wrapper = _make_wrapper(
        [
            _make_bedrock_converse_chunk("Hello"),
            _make_bedrock_converse_chunk(" world"),
            _make_bedrock_converse_chunk("", finish_reason="end_turn"),
        ],
        provider="bedrock",
    )
    streamed = _drain_sync(wrapper)

    # Only non-empty delta contents are collected, so each entry is truthy.
    texts = []
    for item in streamed:
        if item.choices and item.choices[0].delta.content:
            texts.append(item.choices[0].delta.content)

    # An exact "Hello" entry also satisfies the substring check, so a single
    # any() covers both cases.
    assert any("Hello" in text for text in texts)
|
||||
|
||||
|
||||
def test_bedrock_converse_usage_chunk_stripped_and_in_hidden_params():
    """A Bedrock Converse chunk carrying Usage must expose it via _hidden_params.

    Only the ``_hidden_params["usage"]`` entry of the last drained result is
    asserted.
    """
    token_usage = Usage(prompt_tokens=8, completion_tokens=12, total_tokens=20)
    wrapper = _make_wrapper(
        [
            _make_bedrock_converse_chunk("Hi"),
            _make_bedrock_converse_chunk(
                "", finish_reason="end_turn", usage=token_usage
            ),
        ],
        provider="bedrock",
    )

    last_hidden = _drain_sync(wrapper)[-1]._hidden_params
    assert "usage" in last_hidden
    assert last_hidden["usage"] is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 7. Anthropic generic chunk (GChunk) path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_anthropic_generic_chunks_text_pass_through():
    """Generic text chunks fed to an "anthropic" wrapper must come out as text.

    Two text chunks plus a finish chunk go in; at least two emitted chunks must
    carry non-empty delta content.
    """
    wrapper = _make_wrapper(
        [
            _make_generic_chunk("The"),
            _make_generic_chunk(" answer"),
            _make_generic_chunk("", is_finished=True, finish_reason="stop"),
        ],
        provider="anthropic",
    )
    streamed = _drain_sync(wrapper)

    texts = []
    for item in streamed:
        if item.choices and item.choices[0].delta.content:
            texts.append(item.choices[0].delta.content)

    assert len(texts) >= 2
|
||||
|
||||
|
||||
def test_anthropic_finish_reason_propagated():
    """The "stop" finish_reason of the last generic chunk must reach the output."""
    wrapper = _make_wrapper(
        [
            _make_generic_chunk("Hi"),
            _make_generic_chunk("", is_finished=True, finish_reason="stop"),
        ],
        provider="anthropic",
    )
    streamed = _drain_sync(wrapper)

    # Gather every non-empty finish_reason seen across the emitted chunks.
    seen_finish_reasons = []
    for item in streamed:
        if item.choices and item.choices[0].finish_reason:
            seen_finish_reasons.append(item.choices[0].finish_reason)

    assert "stop" in seen_finish_reasons
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 8. Callback caching: _post_streaming_hooks resolved once per stream
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_post_streaming_hooks_cached_after_first_call():
    """
    The hook list is resolved lazily once and then reused.

    ``_post_streaming_hooks`` must be None on a fresh wrapper, become a list
    after the first ``_call_post_streaming_deployment_hook`` call, and be the
    very same list object after a second call.
    """
    wrapper = _make_wrapper([], provider="anthropic")
    assert wrapper._post_streaming_hooks is None, "Must be None before first call"

    async def _invoke_hook_with_no_callbacks():
        # Simulate hook resolution with an empty callback list
        with patch.object(litellm, "callbacks", []):
            await wrapper._call_post_streaming_deployment_hook(
                MagicMock(spec=ModelResponseStream)
            )

    async def _scenario():
        await _invoke_hook_with_no_callbacks()
        cached_hooks = wrapper._post_streaming_hooks
        assert isinstance(cached_hooks, list)

        # Second call must reuse the same list object
        await _invoke_hook_with_no_callbacks()
        assert (
            wrapper._post_streaming_hooks is cached_hooks
        ), "_post_streaming_hooks was rebuilt on second call — caching broken"

    asyncio.run(_scenario())
|
||||
|
||||
|
||||
def test_post_streaming_hooks_filters_correctly():
    """
    The resolved hook list keeps CustomLogger instances and drops plain callables.

    Note: CustomLogger's base class already defines
    async_post_call_streaming_deployment_hook, so ALL CustomLogger subclasses
    pass the hasattr() check regardless of whether they override the method.
    The filter therefore keeps any CustomLogger instance and drops anything else.
    """
    from litellm.integrations.custom_logger import CustomLogger

    class MyLogger(CustomLogger):
        pass

    registered_callbacks = [MyLogger(), MagicMock()]
    wrapper = _make_wrapper([], provider="anthropic")

    async def _scenario():
        with patch.object(litellm, "callbacks", registered_callbacks):
            await wrapper._call_post_streaming_deployment_hook(
                MagicMock(spec=ModelResponseStream)
            )

        # The MagicMock callable must be excluded; the MyLogger instance kept.
        resolved_hooks = wrapper._post_streaming_hooks
        assert len(resolved_hooks) == 1
        assert isinstance(resolved_hooks[0], MyLogger)

    asyncio.run(_scenario())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 9. model_response_creator: hidden_params built correctly
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_model_response_creator_hidden_params_no_chunk():
    """model_response_creator() without arguments must populate _hidden_params.

    Checks three entries on the returned response: ``response_cost`` reads as
    None, ``custom_llm_provider`` equals "anthropic", and ``created_at`` exists.
    """
    wrapper = _make_wrapper([], provider="anthropic")
    hidden = wrapper.model_response_creator()._hidden_params

    assert hidden.get("response_cost") is None
    assert hidden.get("custom_llm_provider") == "anthropic"
    assert "created_at" in hidden
|
||||
|
||||
|
||||
def test_model_response_creator_hidden_params_caller_merged():
    """hidden_params supplied by the caller must show up on the created response.

    The caller's key is expected alongside a ``response_cost`` that still reads
    as None.
    """
    wrapper = _make_wrapper([], provider="anthropic")
    created = wrapper.model_response_creator(hidden_params={"some_key": "some_value"})
    hidden = created._hidden_params

    assert hidden.get("some_key") == "some_value"
    assert hidden.get("response_cost") is None
|
||||
|
||||
|
||||
def test_model_response_creator_stream_key_stripped():
    """A chunk dict containing a 'stream' key must still yield a response object.

    The call is expected not to raise even if 'stream' would be an invalid
    ModelResponseStream field; the test only checks that something other than
    None is returned.
    """
    wrapper = _make_wrapper([], provider="anthropic")
    chunk_with_stream_flag = {"stream": True, "choices": []}

    created = wrapper.model_response_creator(chunk=chunk_with_stream_flag)
    assert created is not None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 10. Per-chunk overhead regression: sync path must not regress
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_sync_streaming_overhead_not_regressed():
    """
    Canary benchmark: draining 200 text chunks synchronously must take < 2 s.

    The budget is deliberately loose so that only gross per-chunk overhead
    regressions trip it, not ordinary jitter on slow CI runners.
    """
    n_chunks = 200
    text_chunks = [_make_generic_chunk(f"token-{i}") for i in range(n_chunks)]
    finish_chunk = _make_generic_chunk("", is_finished=True, finish_reason="stop")
    wrapper = _make_wrapper([*text_chunks, finish_chunk], provider="anthropic")

    # Only the drain itself is timed; chunk and wrapper construction are excluded.
    started_at = time.monotonic()
    streamed = _drain_sync(wrapper)
    elapsed = time.monotonic() - started_at

    assert len(streamed) > 0, "No chunks returned"
    assert elapsed < 2.0, (
        f"Sync streaming of {n_chunks} chunks took {elapsed:.3f}s — "
        "per-chunk overhead regression detected"
    )
|
||||
|
||||
|
||||
def test_async_streaming_overhead_not_regressed():
    """
    Micro-benchmark for the async path: 200 text chunks in < 2 s.

    The wrapper is built with ``async_stream=True`` so that the timed drain
    goes through the async-for branch of ``__anext__`` (the same switch the
    async usage test in this file relies on). Without it the stream is a plain
    sync iterable and the benchmark would presumably time the sync fallback
    instead of the async hot path — TODO confirm against ``_make_wrapper``.
    The 2 s budget is intentionally generous to avoid flakiness on slow CI.
    """
    n_chunks = 200
    chunks = [_make_generic_chunk(f"token-{i}") for i in range(n_chunks)]
    chunks.append(_make_generic_chunk("", is_finished=True, finish_reason="stop"))

    async def _run():
        # async_stream=True forces the real async-for branch of __anext__
        wrapper = _make_wrapper(chunks, provider="anthropic", async_stream=True)
        # Only the drain is timed; wrapper construction is excluded.
        start = time.monotonic()
        results = await _drain_async(wrapper)
        return results, time.monotonic() - start

    results, elapsed = asyncio.run(_run())
    assert len(results) > 0
    assert elapsed < 2.0, (
        f"Async streaming of {n_chunks} chunks took {elapsed:.3f}s — "
        "per-chunk overhead regression detected"
    )
|
||||
Loading…
Reference in New Issue
Block a user