fix: improve bedrock streaming hot path perf (#28720)

This commit is contained in:
Yassin Kortam 2026-05-28 11:31:37 -07:00 committed by GitHub
parent 1cb19b155e
commit d5d6b26a72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 1141 additions and 51 deletions

View File

@ -59,6 +59,8 @@ FUNCTION_CALL_ATTRIBUTE = "function_call"
_SYNC_ITER_EXHAUSTED = object()
_GCHUNK_FIELDS: frozenset = frozenset(GChunk.__annotations__)
def _next_sync_or_exhausted(it: Any) -> Any:
"""
@ -181,6 +183,30 @@ class CustomStreamWrapper:
self.created: Optional[int] = None
self._last_returned_hidden_params: Optional[dict] = None
_cached_logging_provider = self.logging_obj.model_call_details.get(
"custom_llm_provider", None
)
self._cached_logging_llm_provider: Optional[str] = _cached_logging_provider
_effective_model = model or ""
if (
custom_llm_provider == "openai"
and custom_llm_provider != _cached_logging_provider
):
_effective_model = "{}/{}".format(
_cached_logging_provider, _effective_model
)
self._cached_model_name: str = _effective_model
# Snapshot assumes self._hidden_params is populated from litellm_params
# at init and never mutated during the stream. If that ever changes,
# this cache must be removed.
self._base_hidden_params: Dict[str, Any] = {
**self._hidden_params,
"response_cost": None,
}
self._post_streaming_hooks: Optional[List] = None
def _check_max_streaming_duration(self) -> None:
"""Raise litellm.Timeout if the stream has exceeded LITELLM_MAX_STREAMING_DURATION_SECONDS."""
from litellm.constants import LITELLM_MAX_STREAMING_DURATION_SECONDS
@ -681,29 +707,16 @@ class CustomStreamWrapper:
def model_response_creator(
self, chunk: Optional[dict] = None, hidden_params: Optional[dict] = None
):
_model = self.model
_received_llm_provider = self.custom_llm_provider
_logging_obj_llm_provider = self.logging_obj.model_call_details.get("custom_llm_provider", None) # type: ignore
if (
_received_llm_provider == "openai"
and _received_llm_provider != _logging_obj_llm_provider
):
_model = "{}/{}".format(_logging_obj_llm_provider, _model)
_model = self._cached_model_name
_logging_obj_llm_provider = self._cached_logging_llm_provider
if chunk is None:
chunk = {}
args: Dict[str, Any] = {"model": _model}
else:
# pop model keyword
chunk.pop("model", None)
chunk_dict = {}
for key, value in chunk.items():
if key != "stream":
chunk_dict[key] = value
args = {
"model": _model,
**chunk_dict,
}
args = {"model": _model}
if chunk:
args.update({k: v for k, v in chunk.items() if k != "stream"})
model_response = ModelResponseStream(**args)
if self.response_id is not None:
@ -717,15 +730,23 @@ class CustomStreamWrapper:
model_response.created = self.created
else:
self.created = model_response.created
# Spread order is load-bearing: _base_hidden_params (model_id, api_base, ...)
# must win over both caller-supplied hidden_params and the computed
# custom_llm_provider/created_at values, so it comes last.
if hidden_params is not None:
model_response._hidden_params = hidden_params
model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider
model_response._hidden_params["created_at"] = time.time()
model_response._hidden_params = {
**model_response._hidden_params,
**self._hidden_params,
"response_cost": None,
}
model_response._hidden_params = {
**hidden_params,
"custom_llm_provider": _logging_obj_llm_provider,
"created_at": time.time(),
**self._base_hidden_params,
}
else:
model_response._hidden_params = {
"custom_llm_provider": _logging_obj_llm_provider,
"created_at": time.time(),
**self._base_hidden_params,
}
if (
len(model_response.choices) > 0
@ -1627,7 +1648,17 @@ class CustomStreamWrapper:
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import CallTypes
# Get request kwargs from logging object
if self._post_streaming_hooks is None:
self._post_streaming_hooks = [
cb
for cb in litellm.callbacks
if isinstance(cb, CustomLogger)
and hasattr(cb, "async_post_call_streaming_deployment_hook")
]
if not self._post_streaming_hooks:
return chunk
request_data = self.logging_obj.model_call_details
call_type_str = self.logging_obj.call_type
@ -1636,18 +1667,14 @@ class CustomStreamWrapper:
except ValueError:
typed_call_type = None
# Call hooks for all callbacks
for callback in litellm.callbacks:
if isinstance(callback, CustomLogger) and hasattr(
callback, "async_post_call_streaming_deployment_hook"
):
result = await callback.async_post_call_streaming_deployment_hook(
request_data=request_data,
response_chunk=chunk,
call_type=typed_call_type,
)
if result is not None:
chunk = result
for callback in self._post_streaming_hooks:
result = await callback.async_post_call_streaming_deployment_hook(
request_data=request_data,
response_chunk=chunk,
call_type=typed_call_type,
)
if result is not None:
chunk = result
return chunk
except Exception as e:
@ -1888,17 +1915,15 @@ class CustomStreamWrapper:
response = self._add_mcp_list_tools_to_first_chunk(response)
self.sent_first_chunk = True
if hasattr(
response, "usage"
): # remove usage from chunk, only send on final chunk
# Convert the object to a dictionary
# ModelResponseStream declares `usage` as a field, so
# hasattr(response, "usage") is always True — must check
# `is not None` to avoid running this path on every chunk.
if getattr(response, "usage", None) is not None:
obj_dict = response.model_dump()
# Remove an attribute (e.g., 'attr2')
if "usage" in obj_dict:
del obj_dict["usage"]
# Create a new object without the removed attribute
response = self.model_response_creator(
chunk=obj_dict, hidden_params=response._hidden_params
)
@ -2398,10 +2423,7 @@ def generic_chunk_has_all_required_fields(chunk: dict) -> bool:
:param chunk: The dictionary to check.
:return: True if all required fields are present, False otherwise.
"""
_all_fields = GChunk.__annotations__
decision = all(key in _all_fields for key in chunk)
return decision
return all(key in _GCHUNK_FIELDS for key in chunk)
def convert_generic_chunk_to_model_response_stream(

View File

@ -0,0 +1,191 @@
#!/usr/bin/env python3
"""Tight microbenchmark for CustomStreamWrapper.model_response_creator.
Calls model_response_creator() in a tight loop on a pre-built wrapper to
isolate per-call cost. Driving the full wrapper adds threadpool logging,
gc, and other noise that swamps microsecond-scale changes here.
Example:
uv run python scripts/benchmark_model_response_creator.py --label baseline
uv run python scripts/benchmark_model_response_creator.py --label optimized
"""
from __future__ import annotations
import argparse
import gc
import json
import logging
import os
import statistics
import time
from dataclasses import asdict, dataclass
from typing import List
from unittest.mock import MagicMock
os.environ.setdefault("LITELLM_LOG", "ERROR")
logging.getLogger("LiteLLM").setLevel(logging.ERROR)
import litellm # noqa: E402
litellm.suppress_debug_info = True
from litellm.litellm_core_utils.streaming_handler import (
CustomStreamWrapper,
) # noqa: E402
def _make_logging_obj(provider: str) -> MagicMock:
logging_obj = MagicMock()
logging_obj.model_call_details = {
"custom_llm_provider": provider,
"litellm_params": {},
}
logging_obj.call_type = "completion"
logging_obj.stream_options = None
logging_obj.messages = [{"role": "user", "content": "hi"}]
logging_obj.completion_start_time = None
logging_obj._llm_caching_handler = None
return logging_obj
def _make_wrapper(provider: str, model: str) -> CustomStreamWrapper:
return CustomStreamWrapper(
completion_stream=iter([]),
model=model,
logging_obj=_make_logging_obj(provider),
custom_llm_provider=provider,
)
@dataclass
class Result:
label: str
scenario: str
iterations: int
elapsed_min_s: float
elapsed_median_s: float
per_call_us: float
calls_per_sec: float
SCENARIOS = {
"no_chunk": {
"description": "model_response_creator() — no chunk arg (most common path)",
"chunk_factory": lambda i: None,
},
"text_chunk": {
"description": "model_response_creator(chunk={'text': '...'}) — text delta path",
"chunk_factory": lambda i: {"text": f"token{i}"},
},
"rich_chunk": {
"description": "model_response_creator(chunk={...}) — full chunk dict path",
"chunk_factory": lambda i: {
"id": f"id-{i}",
"object": "chat.completion.chunk",
"created": 1234567890,
},
},
}
def bench_no_chunk(wrapper: CustomStreamWrapper, iterations: int) -> float:
gc.collect()
gc.disable()
try:
start = time.perf_counter()
for _ in range(iterations):
wrapper.model_response_creator()
elapsed = time.perf_counter() - start
finally:
gc.enable()
return elapsed
def bench_with_chunk(wrapper: CustomStreamWrapper, factory, iterations: int) -> float:
# Pre-build chunks so we don't measure their construction cost.
chunks = [factory(i) for i in range(iterations)]
gc.collect()
gc.disable()
try:
start = time.perf_counter()
for chunk in chunks:
wrapper.model_response_creator(chunk=dict(chunk)) # copy because mutated
elapsed = time.perf_counter() - start
finally:
gc.enable()
return elapsed
def run_scenario(
label: str,
scenario_key: str,
iterations: int,
repeats: int,
warmup: int,
) -> Result:
spec = SCENARIOS[scenario_key]
wrapper = _make_wrapper(provider="anthropic", model="claude-3-5-sonnet")
if scenario_key == "no_chunk":
runner = lambda: bench_no_chunk(wrapper, iterations) # noqa: E731
else:
runner = lambda: bench_with_chunk(
wrapper, spec["chunk_factory"], iterations
) # noqa: E731
for _ in range(warmup):
runner()
samples = [runner() for _ in range(repeats)]
elapsed_min = min(samples)
elapsed_median = statistics.median(samples)
per_call_us = (elapsed_min * 1_000_000) / iterations
calls_per_sec = iterations / elapsed_min if elapsed_min > 0 else 0.0
return Result(
label=label,
scenario=scenario_key,
iterations=iterations,
elapsed_min_s=elapsed_min,
elapsed_median_s=elapsed_median,
per_call_us=per_call_us,
calls_per_sec=calls_per_sec,
)
def main() -> None:
ap = argparse.ArgumentParser(description=__doc__.splitlines()[0])
ap.add_argument("--label", required=True)
ap.add_argument("--iterations", type=int, default=200_000)
ap.add_argument("--warmup", type=int, default=2)
ap.add_argument("--repeats", type=int, default=8)
ap.add_argument("--json", dest="json_out")
args = ap.parse_args()
print(
f"\n=== label={args.label} iterations={args.iterations:,} "
f"warmup={args.warmup} repeats={args.repeats} (min reported) ==="
)
results: List[Result] = []
for scenario in SCENARIOS:
r = run_scenario(
args.label, scenario, args.iterations, args.repeats, args.warmup
)
results.append(r)
print(
f" {r.scenario:12s}: "
f"min={r.elapsed_min_s*1000:8.2f} ms "
f"median={r.elapsed_median_s*1000:8.2f} ms "
f"per-call={r.per_call_us:7.3f} μs "
f"calls/s={r.calls_per_sec:>12,.0f}"
)
if args.json_out:
with open(args.json_out, "w", encoding="utf-8") as f:
json.dump([asdict(r) for r in results], f, indent=2)
print(f"\nWrote {len(results)} results to {args.json_out}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,369 @@
#!/usr/bin/env python3
"""Benchmark CustomStreamWrapper per-chunk overhead.
Drives CustomStreamWrapper directly with synthetic in-memory chunks for
Anthropic (GenericStreamingChunk), Bedrock Invoke (GenericStreamingChunk),
and Bedrock Converse (ModelResponseStream). A full proxy benchmark adds
FastAPI, HTTP, and TCP latency, which dilutes the per-chunk CPU signal.
Example:
uv run python scripts/benchmark_streaming_chunk_overhead.py \\
--streams 500 --chunks 200 --warmup 50 --repeats 5
"""
from __future__ import annotations
import argparse
import asyncio
import gc
import json
import logging
import os
import statistics
import time
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional
from unittest.mock import MagicMock
# Silence litellm's "Provider List" warnings emitted by get_llm_provider
# when it sees synthetic model names — we're not exercising provider
# routing, only the per-chunk wrapper hot path.
os.environ.setdefault("LITELLM_LOG", "ERROR")
logging.getLogger("LiteLLM").setLevel(logging.ERROR)
import litellm # noqa: E402
litellm.suppress_debug_info = True
from litellm.litellm_core_utils.streaming_handler import (
CustomStreamWrapper,
) # noqa: E402
from litellm.types.utils import ( # noqa: E402
Delta,
GenericStreamingChunk as GChunk,
ModelResponseStream,
StreamingChoices,
Usage,
)
# ---------------------------------------------------------------------------
# Synthetic chunk fixtures
# ---------------------------------------------------------------------------
def _make_logging_obj(provider: str) -> MagicMock:
logging_obj = MagicMock()
logging_obj.model_call_details = {
"custom_llm_provider": provider,
"litellm_params": {},
}
logging_obj.call_type = "completion"
logging_obj.stream_options = None
logging_obj.messages = [{"role": "user", "content": "hi"}]
logging_obj.completion_start_time = None
logging_obj._llm_caching_handler = None
return logging_obj
def _make_generic_chunk(
text: str,
is_finished: bool = False,
finish_reason: str = "",
usage: Optional[dict] = None,
) -> GChunk:
return GChunk(
text=text,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=0,
tool_use=None,
)
def _make_converse_chunk(
text: str = "",
finish_reason: str = "",
usage: Optional[Usage] = None,
) -> ModelResponseStream:
return ModelResponseStream(
choices=[
StreamingChoices(
finish_reason=finish_reason or None,
index=0,
delta=Delta(content=text, role="assistant"),
)
],
id="msg-bench",
model="anthropic.claude-3-5-sonnet",
usage=usage,
)
# ---------------------------------------------------------------------------
# Provider stream factories
# ---------------------------------------------------------------------------
def anthropic_chunks(n: int) -> List[GChunk]:
out: List[GChunk] = [_make_generic_chunk(f"tok{i} ") for i in range(n)]
out.append(
_make_generic_chunk(
"",
is_finished=True,
finish_reason="stop",
usage={"prompt_tokens": 10, "completion_tokens": n, "total_tokens": 10 + n},
)
)
return out
def bedrock_invoke_chunks(n: int) -> List[GChunk]:
# Bedrock Invoke surfaces GChunk-shaped dicts, same shape as Anthropic.
return anthropic_chunks(n)
def bedrock_converse_chunks(n: int) -> List[ModelResponseStream]:
out: List[ModelResponseStream] = [
_make_converse_chunk(f"tok{i} ") for i in range(n)
]
out.append(
_make_converse_chunk(
text="",
finish_reason="stop",
usage=Usage(prompt_tokens=10, completion_tokens=n, total_tokens=10 + n),
)
)
return out
PROVIDERS: dict[str, tuple[str, Callable[[int], list]]] = {
"anthropic": ("anthropic", anthropic_chunks),
"bedrock_invoke": ("bedrock", bedrock_invoke_chunks),
"bedrock_converse": ("bedrock", bedrock_converse_chunks),
}
# ---------------------------------------------------------------------------
# Drive a single stream end-to-end
# ---------------------------------------------------------------------------
def _make_wrapper(
chunks: list, provider: str, async_stream: bool
) -> CustomStreamWrapper:
logging_obj = _make_logging_obj(provider)
if async_stream:
async def _agen():
for c in chunks:
yield c
stream = _agen()
else:
stream = iter(chunks)
return CustomStreamWrapper(
completion_stream=stream,
model="claude-3-5-sonnet",
logging_obj=logging_obj,
custom_llm_provider=provider,
)
def drive_sync(provider_key: str, chunks_per_stream: int, n_streams: int) -> float:
provider, factory = PROVIDERS[provider_key]
# Pre-build the chunk lists; we only measure wrapper iteration cost.
chunk_lists = [factory(chunks_per_stream) for _ in range(n_streams)]
gc.collect()
gc.disable()
try:
start = time.perf_counter()
for chunks in chunk_lists:
wrapper = _make_wrapper(chunks, provider, async_stream=False)
for _ in wrapper:
pass
elapsed = time.perf_counter() - start
finally:
gc.enable()
return elapsed
async def drive_async(
provider_key: str, chunks_per_stream: int, n_streams: int
) -> float:
provider, factory = PROVIDERS[provider_key]
chunk_lists = [factory(chunks_per_stream) for _ in range(n_streams)]
gc.collect()
gc.disable()
try:
start = time.perf_counter()
for chunks in chunk_lists:
wrapper = _make_wrapper(chunks, provider, async_stream=True)
async for _ in wrapper:
pass
elapsed = time.perf_counter() - start
finally:
gc.enable()
return elapsed
# ---------------------------------------------------------------------------
# Repeat × take-min runner
# ---------------------------------------------------------------------------
@dataclass
class Result:
label: str
provider: str
mode: str
streams: int
chunks_per_stream: int
total_chunks: int
elapsed_min_s: float
elapsed_median_s: float
per_chunk_us: float
chunks_per_sec: float
streams_per_sec: float
def run_case(
label: str,
provider_key: str,
mode: str,
chunks_per_stream: int,
n_streams: int,
repeats: int,
warmup: int,
) -> Result:
if mode == "sync":
# Warmup runs amortize import-time and JIT-y caches.
for _ in range(warmup):
drive_sync(provider_key, chunks_per_stream, max(1, n_streams // 10))
samples = [
drive_sync(provider_key, chunks_per_stream, n_streams)
for _ in range(repeats)
]
elif mode == "async":
async def _warm():
for _ in range(warmup):
await drive_async(
provider_key, chunks_per_stream, max(1, n_streams // 10)
)
asyncio.run(_warm())
samples = [
asyncio.run(drive_async(provider_key, chunks_per_stream, n_streams))
for _ in range(repeats)
]
else:
raise ValueError(f"unknown mode {mode!r}")
elapsed_min = min(samples)
elapsed_median = statistics.median(samples)
# Each stream emits chunks_per_stream text chunks + 1 finish/usage chunk.
total_chunks = n_streams * (chunks_per_stream + 1)
per_chunk_us = (elapsed_min * 1_000_000) / total_chunks
chunks_per_sec = total_chunks / elapsed_min if elapsed_min > 0 else 0.0
streams_per_sec = n_streams / elapsed_min if elapsed_min > 0 else 0.0
return Result(
label=label,
provider=provider_key,
mode=mode,
streams=n_streams,
chunks_per_stream=chunks_per_stream,
total_chunks=total_chunks,
elapsed_min_s=elapsed_min,
elapsed_median_s=elapsed_median,
per_chunk_us=per_chunk_us,
chunks_per_sec=chunks_per_sec,
streams_per_sec=streams_per_sec,
)
def format_result(r: Result) -> str:
return (
f" {r.provider:18s} {r.mode:5s}: "
f"min={r.elapsed_min_s*1000:8.2f} ms "
f"median={r.elapsed_median_s*1000:8.2f} ms "
f"per-chunk={r.per_chunk_us:7.2f} μs "
f"chunks/s={r.chunks_per_sec:>10,.0f} "
f"streams/s={r.streams_per_sec:>8,.1f}"
)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main() -> None:
ap = argparse.ArgumentParser(description=__doc__.splitlines()[0])
ap.add_argument(
"--label", required=True, help="Run label (e.g. baseline / optimized)"
)
ap.add_argument("--streams", type=int, default=500, help="Streams per run")
ap.add_argument(
"--chunks",
type=int,
default=200,
help="Text chunks per stream (excl. finish chunk)",
)
ap.add_argument("--warmup", type=int, default=2, help="Warmup runs")
ap.add_argument(
"--repeats", type=int, default=5, help="Measured runs (we report min)"
)
ap.add_argument(
"--providers",
default="anthropic,bedrock_invoke,bedrock_converse",
help="Comma-separated provider list",
)
ap.add_argument(
"--modes",
default="sync,async",
help="Comma-separated iteration modes (sync/async)",
)
ap.add_argument(
"--json", dest="json_out", help="Write results as JSON to this path"
)
args = ap.parse_args()
providers = [p.strip() for p in args.providers.split(",") if p.strip()]
modes = [m.strip() for m in args.modes.split(",") if m.strip()]
for p in providers:
if p not in PROVIDERS:
raise SystemExit(f"unknown provider {p!r}; choose from {list(PROVIDERS)}")
for m in modes:
if m not in {"sync", "async"}:
raise SystemExit(f"unknown mode {m!r}; choose from sync/async")
print(
f"\n=== label={args.label} streams={args.streams} chunks/stream={args.chunks} "
f"warmup={args.warmup} repeats={args.repeats} (min reported) ==="
)
results: List[Result] = []
for provider_key in providers:
for mode in modes:
r = run_case(
label=args.label,
provider_key=provider_key,
mode=mode,
chunks_per_stream=args.chunks,
n_streams=args.streams,
repeats=args.repeats,
warmup=args.warmup,
)
results.append(r)
print(format_result(r))
if args.json_out:
with open(args.json_out, "w", encoding="utf-8") as f:
json.dump([asdict(r) for r in results], f, indent=2)
print(f"\nWrote {len(results)} results to {args.json_out}")
if __name__ == "__main__":
main()

View File

@ -0,0 +1,508 @@
"""
Tests for CustomStreamWrapper per-chunk behavior across Anthropic,
Bedrock Invoke, and Bedrock Converse: text passthrough, usage stripping,
hidden_params propagation, finish_reason, sync/async parity, and the
per-stream caches (_GCHUNK_FIELDS, _post_streaming_hooks).
"""
import asyncio
import time
from typing import List, Optional
from unittest.mock import MagicMock, patch
import litellm
from litellm.litellm_core_utils.streaming_handler import (
CustomStreamWrapper,
_GCHUNK_FIELDS,
generic_chunk_has_all_required_fields,
)
from litellm.types.utils import (
Delta,
GenericStreamingChunk as GChunk,
ModelResponseStream,
StreamingChoices,
Usage,
)
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
def _make_logging_obj(provider: str = "anthropic") -> MagicMock:
logging_obj = MagicMock()
logging_obj.model_call_details = {
"custom_llm_provider": provider,
"litellm_params": {},
}
logging_obj.call_type = "completion"
logging_obj.stream_options = None
logging_obj.messages = [{"role": "user", "content": "hi"}]
logging_obj.completion_start_time = None
logging_obj._llm_caching_handler = None
return logging_obj
def _make_generic_chunk(
text: str,
is_finished: bool = False,
finish_reason: str = "",
usage: Optional[dict] = None,
) -> GChunk:
return GChunk(
text=text,
is_finished=is_finished,
finish_reason=finish_reason,
usage=usage,
index=0,
tool_use=None,
)
def _make_bedrock_converse_chunk(
text: str = "",
finish_reason: str = "",
usage: Optional[Usage] = None,
) -> ModelResponseStream:
"""Simulate what AWSEventStreamDecoder.converse_chunk_parser returns."""
return ModelResponseStream(
choices=[
StreamingChoices(
finish_reason=finish_reason or None,
index=0,
delta=Delta(content=text, role="assistant"),
)
],
id="msg-test",
model="anthropic.claude-3-5-sonnet",
usage=usage,
)
async def _async_iter(chunks: list):
"""Wrap a list as a proper async iterator for use in __anext__ async branch."""
for chunk in chunks:
yield chunk
def _make_wrapper(
chunks: list,
provider: str = "anthropic",
async_stream: bool = False,
) -> CustomStreamWrapper:
logging_obj = _make_logging_obj(provider)
stream = _async_iter(chunks) if async_stream else iter(chunks)
wrapper = CustomStreamWrapper(
completion_stream=stream,
model="claude-3-5-sonnet",
logging_obj=logging_obj,
custom_llm_provider=provider,
)
return wrapper
def _drain_sync(wrapper: CustomStreamWrapper) -> List[ModelResponseStream]:
results = []
for chunk in wrapper:
results.append(chunk)
return results
async def _drain_async(wrapper: CustomStreamWrapper) -> List[ModelResponseStream]:
results = []
async for chunk in wrapper:
results.append(chunk)
return results
# ---------------------------------------------------------------------------
# 1. Module-level _GCHUNK_FIELDS constant
# ---------------------------------------------------------------------------
def test_gchunk_fields_is_frozenset():
"""_GCHUNK_FIELDS must be a frozenset built from GChunk.__annotations__."""
assert isinstance(_GCHUNK_FIELDS, frozenset)
assert _GCHUNK_FIELDS == frozenset(GChunk.__annotations__)
def test_generic_chunk_has_all_required_fields_uses_module_constant(monkeypatch):
"""generic_chunk_has_all_required_fields must use _GCHUNK_FIELDS, not __annotations__.
The check semantics: every key in `chunk` must be a known GChunk field.
This identifies GChunk-shaped dicts (all keys are valid GChunk fields).
"""
valid_chunk = _make_generic_chunk("hello")
assert generic_chunk_has_all_required_fields(valid_chunk) is True
# A dict with an extra unknown key should return False — the unknown key
# is not a GChunk field, so the chunk is not a pure GChunk.
extra_key_chunk = dict(valid_chunk)
extra_key_chunk["unknown_extra_key"] = "value"
assert generic_chunk_has_all_required_fields(extra_key_chunk) is False
# A dict with only known GChunk fields but fewer keys still passes because
# all its keys are valid (subset of GChunk fields).
partial_chunk = {"text": "hi", "is_finished": False}
assert generic_chunk_has_all_required_fields(partial_chunk) is True
# ---------------------------------------------------------------------------
# 2. Cached model name and provider at init time
# ---------------------------------------------------------------------------
def test_cached_model_name_simple():
"""For non-openai providers the cached model name must match the model arg."""
wrapper = _make_wrapper([], provider="anthropic")
assert wrapper._cached_model_name == "claude-3-5-sonnet"
assert wrapper._cached_logging_llm_provider == "anthropic"
def test_cached_model_name_openai_prefix():
"""For openai provider when logging provider differs, model name is prefixed."""
logging_obj = _make_logging_obj(provider="azure")
wrapper = CustomStreamWrapper(
completion_stream=iter([]),
model="gpt-4o",
logging_obj=logging_obj,
custom_llm_provider="openai",
)
assert wrapper._cached_model_name == "azure/gpt-4o"
assert wrapper._cached_logging_llm_provider == "azure"
def test_base_hidden_params_precomputed():
"""_base_hidden_params must be pre-built from _hidden_params at init."""
wrapper = _make_wrapper([], provider="anthropic")
assert "response_cost" in wrapper._base_hidden_params
assert wrapper._base_hidden_params["response_cost"] is None
# Must include all keys from _hidden_params
for k in wrapper._hidden_params:
assert k in wrapper._base_hidden_params
# ---------------------------------------------------------------------------
# 3. Sync path: model_dump() is NOT called on non-usage chunks
# ---------------------------------------------------------------------------
def test_sync_path_no_model_dump_on_text_chunks():
"""
The sync __next__ must NOT call model_dump() on chunks that have no usage.
ModelResponseStream declares `usage` as a field, so a `hasattr` check
would always succeed and trigger the model_dump()+recreate path on every
chunk. The wrapper must check `is not None` instead.
"""
chunks = [
_make_generic_chunk("Hello"),
_make_generic_chunk(" world"),
_make_generic_chunk("", is_finished=True, finish_reason="stop"),
]
wrapper = _make_wrapper(chunks)
model_dump_call_count = 0
original_model_dump = ModelResponseStream.model_dump
def counting_model_dump(self, **kwargs):
nonlocal model_dump_call_count
model_dump_call_count += 1
return original_model_dump(self, **kwargs)
with patch.object(ModelResponseStream, "model_dump", counting_model_dump):
results = _drain_sync(wrapper)
text_chunks = [r for r in results if r.choices and r.choices[0].delta.content]
assert len(text_chunks) >= 2, "Expected at least 2 text chunks"
assert model_dump_call_count <= 1, (
f"model_dump() called {model_dump_call_count} times — "
"usage check is firing on every chunk"
)
# ---------------------------------------------------------------------------
# 4. Sync path: usage chunk is stripped from body but preserved in hidden_params
# ---------------------------------------------------------------------------
def test_sync_path_usage_stripped_from_body_preserved_in_hidden_params():
"""Usage data must be removed from the returned chunk but added to _hidden_params."""
usage_dict = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
chunks = [
_make_generic_chunk("Hello"),
_make_generic_chunk(
"", is_finished=True, finish_reason="stop", usage=usage_dict
),
]
wrapper = _make_wrapper(chunks)
results = _drain_sync(wrapper)
# The usage chunk must be returned (not silently dropped)
finish_chunks = [
r for r in results if r.choices and r.choices[0].finish_reason == "stop"
]
assert finish_chunks, "Finish-reason chunk was not returned"
# The final chunk must carry usage in _hidden_params
final = results[-1]
assert "usage" in final._hidden_params, "usage missing from _hidden_params"
hidden_usage = final._hidden_params["usage"]
assert hidden_usage is not None
# ---------------------------------------------------------------------------
# 5. Async path: usage chunk is stripped from body but preserved in hidden_params
# ---------------------------------------------------------------------------
def test_async_path_usage_stripped_from_body_preserved_in_hidden_params():
"""Async path mirrors sync path for usage handling."""
usage_dict = {"prompt_tokens": 5, "completion_tokens": 15, "total_tokens": 20}
chunks = [
_make_generic_chunk("Hi"),
_make_generic_chunk(
"", is_finished=True, finish_reason="stop", usage=usage_dict
),
]
async def _run():
# async_stream=True forces the real async-for branch of __anext__
wrapper = _make_wrapper(chunks, async_stream=True)
return await _drain_async(wrapper)
results = asyncio.run(_run())
final = results[-1]
assert "usage" in final._hidden_params
assert final._hidden_params["usage"] is not None
# ---------------------------------------------------------------------------
# 6. Bedrock Converse: ModelResponseStream chunks pass through correctly
# ---------------------------------------------------------------------------
def test_bedrock_converse_text_chunks_pass_through():
"""
Bedrock Converse returns ModelResponseStream objects directly.
They should pass through chunk_creator and appear in output unchanged.
"""
chunks = [
_make_bedrock_converse_chunk("Hello"),
_make_bedrock_converse_chunk(" world"),
_make_bedrock_converse_chunk("", finish_reason="end_turn"),
]
wrapper = _make_wrapper(chunks, provider="bedrock")
results = _drain_sync(wrapper)
texts = [
r.choices[0].delta.content
for r in results
if r.choices and r.choices[0].delta.content
]
assert "Hello" in texts or any("Hello" in (t or "") for t in texts)
def test_bedrock_converse_usage_chunk_stripped_and_in_hidden_params():
"""Usage in a Bedrock Converse ModelResponseStream chunk is handled correctly."""
usage = Usage(prompt_tokens=8, completion_tokens=12, total_tokens=20)
chunks = [
_make_bedrock_converse_chunk("Hi"),
_make_bedrock_converse_chunk("", finish_reason="end_turn", usage=usage),
]
wrapper = _make_wrapper(chunks, provider="bedrock")
results = _drain_sync(wrapper)
final = results[-1]
assert "usage" in final._hidden_params
assert final._hidden_params["usage"] is not None
# ---------------------------------------------------------------------------
# 7. Anthropic generic chunk (GChunk) path
# ---------------------------------------------------------------------------
def test_anthropic_generic_chunks_text_pass_through():
"""GChunk text chunks must arrive in the output with correct content."""
chunks = [
_make_generic_chunk("The"),
_make_generic_chunk(" answer"),
_make_generic_chunk("", is_finished=True, finish_reason="stop"),
]
wrapper = _make_wrapper(chunks, provider="anthropic")
results = _drain_sync(wrapper)
texts = [
r.choices[0].delta.content
for r in results
if r.choices and r.choices[0].delta.content
]
assert len(texts) >= 2
def test_anthropic_finish_reason_propagated():
"""finish_reason must be set on the final streaming chunk."""
chunks = [
_make_generic_chunk("Hi"),
_make_generic_chunk("", is_finished=True, finish_reason="stop"),
]
wrapper = _make_wrapper(chunks, provider="anthropic")
results = _drain_sync(wrapper)
finish_reasons = [
r.choices[0].finish_reason
for r in results
if r.choices and r.choices[0].finish_reason
]
assert "stop" in finish_reasons
# ---------------------------------------------------------------------------
# 8. Callback caching: _post_streaming_hooks resolved once per stream
# ---------------------------------------------------------------------------
def test_post_streaming_hooks_cached_after_first_call():
"""
_post_streaming_hooks must be None before the first hook call and a list after.
The same list object must be reused on subsequent calls (not re-built).
"""
wrapper = _make_wrapper([], provider="anthropic")
assert wrapper._post_streaming_hooks is None, "Must be None before first call"
async def _run():
# Simulate hook resolution with an empty callback list
with patch.object(litellm, "callbacks", []):
await wrapper._call_post_streaming_deployment_hook(
MagicMock(spec=ModelResponseStream)
)
first_list = wrapper._post_streaming_hooks
assert isinstance(first_list, list)
# Second call must reuse the same list object
with patch.object(litellm, "callbacks", []):
await wrapper._call_post_streaming_deployment_hook(
MagicMock(spec=ModelResponseStream)
)
assert (
wrapper._post_streaming_hooks is first_list
), "_post_streaming_hooks was rebuilt on second call — caching broken"
asyncio.run(_run())
def test_post_streaming_hooks_filters_correctly():
"""
Only CustomLogger instances must be included; plain callables are excluded.
Note: CustomLogger's base class already defines
async_post_call_streaming_deployment_hook, so ALL CustomLogger subclasses
pass the hasattr() check regardless of whether they override the method.
The filter therefore keeps any CustomLogger instance and drops anything else.
"""
from litellm.integrations.custom_logger import CustomLogger
class MyLogger(CustomLogger):
pass
plain_callable = MagicMock()
wrapper = _make_wrapper([], provider="anthropic")
async def _run():
with patch.object(litellm, "callbacks", [MyLogger(), plain_callable]):
await wrapper._call_post_streaming_deployment_hook(
MagicMock(spec=ModelResponseStream)
)
# plain_callable must be excluded; MyLogger (CustomLogger subclass) included
assert len(wrapper._post_streaming_hooks) == 1
assert isinstance(wrapper._post_streaming_hooks[0], MyLogger)
asyncio.run(_run())
# ---------------------------------------------------------------------------
# 9. model_response_creator: hidden_params built correctly
# ---------------------------------------------------------------------------
def test_model_response_creator_hidden_params_no_chunk():
"""model_response_creator() with no args must include all _base_hidden_params."""
wrapper = _make_wrapper([], provider="anthropic")
response = wrapper.model_response_creator()
assert response._hidden_params.get("response_cost") is None
assert response._hidden_params.get("custom_llm_provider") == "anthropic"
assert "created_at" in response._hidden_params
def test_model_response_creator_hidden_params_caller_merged():
"""When hidden_params are passed by caller, they must be included in result."""
wrapper = _make_wrapper([], provider="anthropic")
caller_params = {"some_key": "some_value"}
response = wrapper.model_response_creator(hidden_params=caller_params)
assert response._hidden_params.get("some_key") == "some_value"
assert response._hidden_params.get("response_cost") is None
def test_model_response_creator_stream_key_stripped():
"""The 'stream' key must be removed from chunk before constructing ModelResponseStream."""
wrapper = _make_wrapper([], provider="anthropic")
chunk = {"stream": True, "choices": []}
# Should not raise even if 'stream' would be an invalid ModelResponseStream field
response = wrapper.model_response_creator(chunk=chunk)
assert response is not None
# ---------------------------------------------------------------------------
# 10. Per-chunk overhead regression: sync path must not regress
# ---------------------------------------------------------------------------
def test_sync_streaming_overhead_not_regressed():
"""
Micro-benchmark: the sync hot path must process 200 text chunks in < 2 s.
This test acts as a canary for gross per-chunk overhead regressions.
It is intentionally generous (2 s) to avoid flakiness on slow CI runners.
"""
n_chunks = 200
chunks = [_make_generic_chunk(f"token-{i}") for i in range(n_chunks)]
chunks.append(_make_generic_chunk("", is_finished=True, finish_reason="stop"))
wrapper = _make_wrapper(chunks, provider="anthropic")
start = time.monotonic()
results = _drain_sync(wrapper)
elapsed = time.monotonic() - start
assert len(results) > 0, "No chunks returned"
assert elapsed < 2.0, (
f"Sync streaming of {n_chunks} chunks took {elapsed:.3f}s — "
"per-chunk overhead regression detected"
)
def test_async_streaming_overhead_not_regressed():
"""
Micro-benchmark for the async path: 200 text chunks in < 2 s.
"""
n_chunks = 200
chunks = [_make_generic_chunk(f"token-{i}") for i in range(n_chunks)]
chunks.append(_make_generic_chunk("", is_finished=True, finish_reason="stop"))
async def _run():
wrapper = _make_wrapper(chunks, provider="anthropic")
start = time.monotonic()
results = await _drain_async(wrapper)
return results, time.monotonic() - start
results, elapsed = asyncio.run(_run())
assert len(results) > 0
assert elapsed < 2.0, (
f"Async streaming of {n_chunks} chunks took {elapsed:.3f}s — "
"per-chunk overhead regression detected"
)