diff --git a/litellm/litellm_core_utils/streaming_handler.py b/litellm/litellm_core_utils/streaming_handler.py index fa7faf3035..4642201ca6 100644 --- a/litellm/litellm_core_utils/streaming_handler.py +++ b/litellm/litellm_core_utils/streaming_handler.py @@ -59,6 +59,8 @@ FUNCTION_CALL_ATTRIBUTE = "function_call" _SYNC_ITER_EXHAUSTED = object() +_GCHUNK_FIELDS: frozenset = frozenset(GChunk.__annotations__) + def _next_sync_or_exhausted(it: Any) -> Any: """ @@ -181,6 +183,30 @@ class CustomStreamWrapper: self.created: Optional[int] = None self._last_returned_hidden_params: Optional[dict] = None + _cached_logging_provider = self.logging_obj.model_call_details.get( + "custom_llm_provider", None + ) + self._cached_logging_llm_provider: Optional[str] = _cached_logging_provider + _effective_model = model or "" + if ( + custom_llm_provider == "openai" + and custom_llm_provider != _cached_logging_provider + ): + _effective_model = "{}/{}".format( + _cached_logging_provider, _effective_model + ) + self._cached_model_name: str = _effective_model + + # Snapshot assumes self._hidden_params is populated from litellm_params + # at init and never mutated during the stream. If that ever changes, + # this cache must be removed. 
+ self._base_hidden_params: Dict[str, Any] = { + **self._hidden_params, + "response_cost": None, + } + + self._post_streaming_hooks: Optional[List] = None + def _check_max_streaming_duration(self) -> None: """Raise litellm.Timeout if the stream has exceeded LITELLM_MAX_STREAMING_DURATION_SECONDS.""" from litellm.constants import LITELLM_MAX_STREAMING_DURATION_SECONDS @@ -681,29 +707,16 @@ class CustomStreamWrapper: def model_response_creator( self, chunk: Optional[dict] = None, hidden_params: Optional[dict] = None ): - _model = self.model - _received_llm_provider = self.custom_llm_provider - _logging_obj_llm_provider = self.logging_obj.model_call_details.get("custom_llm_provider", None) # type: ignore - if ( - _received_llm_provider == "openai" - and _received_llm_provider != _logging_obj_llm_provider - ): - _model = "{}/{}".format(_logging_obj_llm_provider, _model) + _model = self._cached_model_name + _logging_obj_llm_provider = self._cached_logging_llm_provider + if chunk is None: - chunk = {} + args: Dict[str, Any] = {"model": _model} else: - # pop model keyword chunk.pop("model", None) - - chunk_dict = {} - for key, value in chunk.items(): - if key != "stream": - chunk_dict[key] = value - - args = { - "model": _model, - **chunk_dict, - } + args = {"model": _model} + if chunk: + args.update({k: v for k, v in chunk.items() if k != "stream"}) model_response = ModelResponseStream(**args) if self.response_id is not None: @@ -717,15 +730,23 @@ class CustomStreamWrapper: model_response.created = self.created else: self.created = model_response.created + + # Spread order is load-bearing: _base_hidden_params (model_id, api_base, ...) + # must win over both caller-supplied hidden_params and the computed + # custom_llm_provider/created_at values, so it comes last. 
if hidden_params is not None: - model_response._hidden_params = hidden_params - model_response._hidden_params["custom_llm_provider"] = _logging_obj_llm_provider - model_response._hidden_params["created_at"] = time.time() - model_response._hidden_params = { - **model_response._hidden_params, - **self._hidden_params, - "response_cost": None, - } + model_response._hidden_params = { + **hidden_params, + "custom_llm_provider": _logging_obj_llm_provider, + "created_at": time.time(), + **self._base_hidden_params, + } + else: + model_response._hidden_params = { + "custom_llm_provider": _logging_obj_llm_provider, + "created_at": time.time(), + **self._base_hidden_params, + } if ( len(model_response.choices) > 0 @@ -1627,7 +1648,17 @@ class CustomStreamWrapper: from litellm.integrations.custom_logger import CustomLogger from litellm.types.utils import CallTypes - # Get request kwargs from logging object + if self._post_streaming_hooks is None: + self._post_streaming_hooks = [ + cb + for cb in litellm.callbacks + if isinstance(cb, CustomLogger) + and hasattr(cb, "async_post_call_streaming_deployment_hook") + ] + + if not self._post_streaming_hooks: + return chunk + request_data = self.logging_obj.model_call_details call_type_str = self.logging_obj.call_type @@ -1636,18 +1667,14 @@ class CustomStreamWrapper: except ValueError: typed_call_type = None - # Call hooks for all callbacks - for callback in litellm.callbacks: - if isinstance(callback, CustomLogger) and hasattr( - callback, "async_post_call_streaming_deployment_hook" - ): - result = await callback.async_post_call_streaming_deployment_hook( - request_data=request_data, - response_chunk=chunk, - call_type=typed_call_type, - ) - if result is not None: - chunk = result + for callback in self._post_streaming_hooks: + result = await callback.async_post_call_streaming_deployment_hook( + request_data=request_data, + response_chunk=chunk, + call_type=typed_call_type, + ) + if result is not None: + chunk = result return chunk 
except Exception as e: @@ -1888,17 +1915,15 @@ class CustomStreamWrapper: response = self._add_mcp_list_tools_to_first_chunk(response) self.sent_first_chunk = True - if hasattr( - response, "usage" - ): # remove usage from chunk, only send on final chunk - # Convert the object to a dictionary + # ModelResponseStream declares `usage` as a field, so + # hasattr(response, "usage") is always True — must check + # `is not None` to avoid running this path on every chunk. + if getattr(response, "usage", None) is not None: obj_dict = response.model_dump() - # Remove an attribute (e.g., 'attr2') if "usage" in obj_dict: del obj_dict["usage"] - # Create a new object without the removed attribute response = self.model_response_creator( chunk=obj_dict, hidden_params=response._hidden_params ) @@ -2398,10 +2423,7 @@ def generic_chunk_has_all_required_fields(chunk: dict) -> bool: :param chunk: The dictionary to check. :return: True if all required fields are present, False otherwise. """ - _all_fields = GChunk.__annotations__ - - decision = all(key in _all_fields for key in chunk) - return decision + return all(key in _GCHUNK_FIELDS for key in chunk) def convert_generic_chunk_to_model_response_stream( diff --git a/scripts/benchmark_model_response_creator.py b/scripts/benchmark_model_response_creator.py new file mode 100644 index 0000000000..881870d385 --- /dev/null +++ b/scripts/benchmark_model_response_creator.py @@ -0,0 +1,191 @@ +#!/usr/bin/env python3 +"""Tight microbenchmark for CustomStreamWrapper.model_response_creator. + +Calls model_response_creator() in a tight loop on a pre-built wrapper to +isolate per-call cost. Driving the full wrapper adds threadpool logging, +gc, and other noise that swamps microsecond-scale changes here. 
+ +Example: + uv run python scripts/benchmark_model_response_creator.py --label baseline + uv run python scripts/benchmark_model_response_creator.py --label optimized +""" + +from __future__ import annotations + +import argparse +import gc +import json +import logging +import os +import statistics +import time +from dataclasses import asdict, dataclass +from typing import List +from unittest.mock import MagicMock + +os.environ.setdefault("LITELLM_LOG", "ERROR") +logging.getLogger("LiteLLM").setLevel(logging.ERROR) + +import litellm # noqa: E402 + +litellm.suppress_debug_info = True + +from litellm.litellm_core_utils.streaming_handler import ( + CustomStreamWrapper, +) # noqa: E402 + + +def _make_logging_obj(provider: str) -> MagicMock: + logging_obj = MagicMock() + logging_obj.model_call_details = { + "custom_llm_provider": provider, + "litellm_params": {}, + } + logging_obj.call_type = "completion" + logging_obj.stream_options = None + logging_obj.messages = [{"role": "user", "content": "hi"}] + logging_obj.completion_start_time = None + logging_obj._llm_caching_handler = None + return logging_obj + + +def _make_wrapper(provider: str, model: str) -> CustomStreamWrapper: + return CustomStreamWrapper( + completion_stream=iter([]), + model=model, + logging_obj=_make_logging_obj(provider), + custom_llm_provider=provider, + ) + + +@dataclass +class Result: + label: str + scenario: str + iterations: int + elapsed_min_s: float + elapsed_median_s: float + per_call_us: float + calls_per_sec: float + + +SCENARIOS = { + "no_chunk": { + "description": "model_response_creator() — no chunk arg (most common path)", + "chunk_factory": lambda i: None, + }, + "text_chunk": { + "description": "model_response_creator(chunk={'text': '...'}) — text delta path", + "chunk_factory": lambda i: {"text": f"token{i}"}, + }, + "rich_chunk": { + "description": "model_response_creator(chunk={...}) — full chunk dict path", + "chunk_factory": lambda i: { + "id": f"id-{i}", + "object": 
"chat.completion.chunk", + "created": 1234567890, + }, + }, +} + + +def bench_no_chunk(wrapper: CustomStreamWrapper, iterations: int) -> float: + gc.collect() + gc.disable() + try: + start = time.perf_counter() + for _ in range(iterations): + wrapper.model_response_creator() + elapsed = time.perf_counter() - start + finally: + gc.enable() + return elapsed + + +def bench_with_chunk(wrapper: CustomStreamWrapper, factory, iterations: int) -> float: + # Pre-build chunks so we don't measure their construction cost. + chunks = [factory(i) for i in range(iterations)] + gc.collect() + gc.disable() + try: + start = time.perf_counter() + for chunk in chunks: + wrapper.model_response_creator(chunk=dict(chunk)) # copy because mutated + elapsed = time.perf_counter() - start + finally: + gc.enable() + return elapsed + + +def run_scenario( + label: str, + scenario_key: str, + iterations: int, + repeats: int, + warmup: int, +) -> Result: + spec = SCENARIOS[scenario_key] + wrapper = _make_wrapper(provider="anthropic", model="claude-3-5-sonnet") + + if scenario_key == "no_chunk": + runner = lambda: bench_no_chunk(wrapper, iterations) # noqa: E731 + else: + runner = lambda: bench_with_chunk( + wrapper, spec["chunk_factory"], iterations + ) # noqa: E731 + + for _ in range(warmup): + runner() + samples = [runner() for _ in range(repeats)] + + elapsed_min = min(samples) + elapsed_median = statistics.median(samples) + per_call_us = (elapsed_min * 1_000_000) / iterations + calls_per_sec = iterations / elapsed_min if elapsed_min > 0 else 0.0 + + return Result( + label=label, + scenario=scenario_key, + iterations=iterations, + elapsed_min_s=elapsed_min, + elapsed_median_s=elapsed_median, + per_call_us=per_call_us, + calls_per_sec=calls_per_sec, + ) + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + ap.add_argument("--label", required=True) + ap.add_argument("--iterations", type=int, default=200_000) + ap.add_argument("--warmup", type=int, 
default=2) + ap.add_argument("--repeats", type=int, default=8) + ap.add_argument("--json", dest="json_out") + args = ap.parse_args() + + print( + f"\n=== label={args.label} iterations={args.iterations:,} " + f"warmup={args.warmup} repeats={args.repeats} (min reported) ===" + ) + results: List[Result] = [] + for scenario in SCENARIOS: + r = run_scenario( + args.label, scenario, args.iterations, args.repeats, args.warmup + ) + results.append(r) + print( + f" {r.scenario:12s}: " + f"min={r.elapsed_min_s*1000:8.2f} ms " + f"median={r.elapsed_median_s*1000:8.2f} ms " + f"per-call={r.per_call_us:7.3f} μs " + f"calls/s={r.calls_per_sec:>12,.0f}" + ) + + if args.json_out: + with open(args.json_out, "w", encoding="utf-8") as f: + json.dump([asdict(r) for r in results], f, indent=2) + print(f"\nWrote {len(results)} results to {args.json_out}") + + +if __name__ == "__main__": + main() diff --git a/scripts/benchmark_streaming_chunk_overhead.py b/scripts/benchmark_streaming_chunk_overhead.py new file mode 100644 index 0000000000..948be096be --- /dev/null +++ b/scripts/benchmark_streaming_chunk_overhead.py @@ -0,0 +1,369 @@ +#!/usr/bin/env python3 +"""Benchmark CustomStreamWrapper per-chunk overhead. + +Drives CustomStreamWrapper directly with synthetic in-memory chunks for +Anthropic (GenericStreamingChunk), Bedrock Invoke (GenericStreamingChunk), +and Bedrock Converse (ModelResponseStream). A full proxy benchmark adds +FastAPI, HTTP, and TCP latency, which dilutes the per-chunk CPU signal. 
+ +Example: + uv run python scripts/benchmark_streaming_chunk_overhead.py \\ + --streams 500 --chunks 200 --warmup 50 --repeats 5 +""" + +from __future__ import annotations + +import argparse +import asyncio +import gc +import json +import logging +import os +import statistics +import time +from dataclasses import asdict, dataclass +from typing import Callable, List, Optional +from unittest.mock import MagicMock + +# Silence litellm's "Provider List" warnings emitted by get_llm_provider +# when it sees synthetic model names — we're not exercising provider +# routing, only the per-chunk wrapper hot path. +os.environ.setdefault("LITELLM_LOG", "ERROR") +logging.getLogger("LiteLLM").setLevel(logging.ERROR) + +import litellm # noqa: E402 + +litellm.suppress_debug_info = True + +from litellm.litellm_core_utils.streaming_handler import ( + CustomStreamWrapper, +) # noqa: E402 +from litellm.types.utils import ( # noqa: E402 + Delta, + GenericStreamingChunk as GChunk, + ModelResponseStream, + StreamingChoices, + Usage, +) + +# --------------------------------------------------------------------------- +# Synthetic chunk fixtures +# --------------------------------------------------------------------------- + + +def _make_logging_obj(provider: str) -> MagicMock: + logging_obj = MagicMock() + logging_obj.model_call_details = { + "custom_llm_provider": provider, + "litellm_params": {}, + } + logging_obj.call_type = "completion" + logging_obj.stream_options = None + logging_obj.messages = [{"role": "user", "content": "hi"}] + logging_obj.completion_start_time = None + logging_obj._llm_caching_handler = None + return logging_obj + + +def _make_generic_chunk( + text: str, + is_finished: bool = False, + finish_reason: str = "", + usage: Optional[dict] = None, +) -> GChunk: + return GChunk( + text=text, + is_finished=is_finished, + finish_reason=finish_reason, + usage=usage, + index=0, + tool_use=None, + ) + + +def _make_converse_chunk( + text: str = "", + finish_reason: str = "", 
+ usage: Optional[Usage] = None, +) -> ModelResponseStream: + return ModelResponseStream( + choices=[ + StreamingChoices( + finish_reason=finish_reason or None, + index=0, + delta=Delta(content=text, role="assistant"), + ) + ], + id="msg-bench", + model="anthropic.claude-3-5-sonnet", + usage=usage, + ) + + +# --------------------------------------------------------------------------- +# Provider stream factories +# --------------------------------------------------------------------------- + + +def anthropic_chunks(n: int) -> List[GChunk]: + out: List[GChunk] = [_make_generic_chunk(f"tok{i} ") for i in range(n)] + out.append( + _make_generic_chunk( + "", + is_finished=True, + finish_reason="stop", + usage={"prompt_tokens": 10, "completion_tokens": n, "total_tokens": 10 + n}, + ) + ) + return out + + +def bedrock_invoke_chunks(n: int) -> List[GChunk]: + # Bedrock Invoke surfaces GChunk-shaped dicts, same shape as Anthropic. + return anthropic_chunks(n) + + +def bedrock_converse_chunks(n: int) -> List[ModelResponseStream]: + out: List[ModelResponseStream] = [ + _make_converse_chunk(f"tok{i} ") for i in range(n) + ] + out.append( + _make_converse_chunk( + text="", + finish_reason="stop", + usage=Usage(prompt_tokens=10, completion_tokens=n, total_tokens=10 + n), + ) + ) + return out + + +PROVIDERS: dict[str, tuple[str, Callable[[int], list]]] = { + "anthropic": ("anthropic", anthropic_chunks), + "bedrock_invoke": ("bedrock", bedrock_invoke_chunks), + "bedrock_converse": ("bedrock", bedrock_converse_chunks), +} + + +# --------------------------------------------------------------------------- +# Drive a single stream end-to-end +# --------------------------------------------------------------------------- + + +def _make_wrapper( + chunks: list, provider: str, async_stream: bool +) -> CustomStreamWrapper: + logging_obj = _make_logging_obj(provider) + if async_stream: + + async def _agen(): + for c in chunks: + yield c + + stream = _agen() + else: + stream = iter(chunks) 
+ return CustomStreamWrapper( + completion_stream=stream, + model="claude-3-5-sonnet", + logging_obj=logging_obj, + custom_llm_provider=provider, + ) + + +def drive_sync(provider_key: str, chunks_per_stream: int, n_streams: int) -> float: + provider, factory = PROVIDERS[provider_key] + # Pre-build the chunk lists; we only measure wrapper iteration cost. + chunk_lists = [factory(chunks_per_stream) for _ in range(n_streams)] + gc.collect() + gc.disable() + try: + start = time.perf_counter() + for chunks in chunk_lists: + wrapper = _make_wrapper(chunks, provider, async_stream=False) + for _ in wrapper: + pass + elapsed = time.perf_counter() - start + finally: + gc.enable() + return elapsed + + +async def drive_async( + provider_key: str, chunks_per_stream: int, n_streams: int +) -> float: + provider, factory = PROVIDERS[provider_key] + chunk_lists = [factory(chunks_per_stream) for _ in range(n_streams)] + gc.collect() + gc.disable() + try: + start = time.perf_counter() + for chunks in chunk_lists: + wrapper = _make_wrapper(chunks, provider, async_stream=True) + async for _ in wrapper: + pass + elapsed = time.perf_counter() - start + finally: + gc.enable() + return elapsed + + +# --------------------------------------------------------------------------- +# Repeat × take-min runner +# --------------------------------------------------------------------------- + + +@dataclass +class Result: + label: str + provider: str + mode: str + streams: int + chunks_per_stream: int + total_chunks: int + elapsed_min_s: float + elapsed_median_s: float + per_chunk_us: float + chunks_per_sec: float + streams_per_sec: float + + +def run_case( + label: str, + provider_key: str, + mode: str, + chunks_per_stream: int, + n_streams: int, + repeats: int, + warmup: int, +) -> Result: + if mode == "sync": + # Warmup runs amortize import-time and JIT-y caches. 
+ for _ in range(warmup): + drive_sync(provider_key, chunks_per_stream, max(1, n_streams // 10)) + samples = [ + drive_sync(provider_key, chunks_per_stream, n_streams) + for _ in range(repeats) + ] + elif mode == "async": + + async def _warm(): + for _ in range(warmup): + await drive_async( + provider_key, chunks_per_stream, max(1, n_streams // 10) + ) + + asyncio.run(_warm()) + samples = [ + asyncio.run(drive_async(provider_key, chunks_per_stream, n_streams)) + for _ in range(repeats) + ] + else: + raise ValueError(f"unknown mode {mode!r}") + + elapsed_min = min(samples) + elapsed_median = statistics.median(samples) + # Each stream emits chunks_per_stream text chunks + 1 finish/usage chunk. + total_chunks = n_streams * (chunks_per_stream + 1) + per_chunk_us = (elapsed_min * 1_000_000) / total_chunks + chunks_per_sec = total_chunks / elapsed_min if elapsed_min > 0 else 0.0 + streams_per_sec = n_streams / elapsed_min if elapsed_min > 0 else 0.0 + + return Result( + label=label, + provider=provider_key, + mode=mode, + streams=n_streams, + chunks_per_stream=chunks_per_stream, + total_chunks=total_chunks, + elapsed_min_s=elapsed_min, + elapsed_median_s=elapsed_median, + per_chunk_us=per_chunk_us, + chunks_per_sec=chunks_per_sec, + streams_per_sec=streams_per_sec, + ) + + +def format_result(r: Result) -> str: + return ( + f" {r.provider:18s} {r.mode:5s}: " + f"min={r.elapsed_min_s*1000:8.2f} ms " + f"median={r.elapsed_median_s*1000:8.2f} ms " + f"per-chunk={r.per_chunk_us:7.2f} μs " + f"chunks/s={r.chunks_per_sec:>10,.0f} " + f"streams/s={r.streams_per_sec:>8,.1f}" + ) + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + + +def main() -> None: + ap = argparse.ArgumentParser(description=__doc__.splitlines()[0]) + ap.add_argument( + "--label", required=True, help="Run label (e.g. 
baseline / optimized)" + ) + ap.add_argument("--streams", type=int, default=500, help="Streams per run") + ap.add_argument( + "--chunks", + type=int, + default=200, + help="Text chunks per stream (excl. finish chunk)", + ) + ap.add_argument("--warmup", type=int, default=2, help="Warmup runs") + ap.add_argument( + "--repeats", type=int, default=5, help="Measured runs (we report min)" + ) + ap.add_argument( + "--providers", + default="anthropic,bedrock_invoke,bedrock_converse", + help="Comma-separated provider list", + ) + ap.add_argument( + "--modes", + default="sync,async", + help="Comma-separated iteration modes (sync/async)", + ) + ap.add_argument( + "--json", dest="json_out", help="Write results as JSON to this path" + ) + args = ap.parse_args() + + providers = [p.strip() for p in args.providers.split(",") if p.strip()] + modes = [m.strip() for m in args.modes.split(",") if m.strip()] + + for p in providers: + if p not in PROVIDERS: + raise SystemExit(f"unknown provider {p!r}; choose from {list(PROVIDERS)}") + for m in modes: + if m not in {"sync", "async"}: + raise SystemExit(f"unknown mode {m!r}; choose from sync/async") + + print( + f"\n=== label={args.label} streams={args.streams} chunks/stream={args.chunks} " + f"warmup={args.warmup} repeats={args.repeats} (min reported) ===" + ) + results: List[Result] = [] + for provider_key in providers: + for mode in modes: + r = run_case( + label=args.label, + provider_key=provider_key, + mode=mode, + chunks_per_stream=args.chunks, + n_streams=args.streams, + repeats=args.repeats, + warmup=args.warmup, + ) + results.append(r) + print(format_result(r)) + + if args.json_out: + with open(args.json_out, "w", encoding="utf-8") as f: + json.dump([asdict(r) for r in results], f, indent=2) + print(f"\nWrote {len(results)} results to {args.json_out}") + + +if __name__ == "__main__": + main() diff --git a/tests/test_litellm/litellm_core_utils/test_streaming_overhead.py 
b/tests/test_litellm/litellm_core_utils/test_streaming_overhead.py new file mode 100644 index 0000000000..8fb0659ab5 --- /dev/null +++ b/tests/test_litellm/litellm_core_utils/test_streaming_overhead.py @@ -0,0 +1,508 @@ +""" +Tests for CustomStreamWrapper per-chunk behavior across Anthropic, +Bedrock Invoke, and Bedrock Converse: text passthrough, usage stripping, +hidden_params propagation, finish_reason, sync/async parity, and the +per-stream caches (_GCHUNK_FIELDS, _post_streaming_hooks). +""" + +import asyncio +import time +from typing import List, Optional +from unittest.mock import MagicMock, patch + +import litellm +from litellm.litellm_core_utils.streaming_handler import ( + CustomStreamWrapper, + _GCHUNK_FIELDS, + generic_chunk_has_all_required_fields, +) +from litellm.types.utils import ( + Delta, + GenericStreamingChunk as GChunk, + ModelResponseStream, + StreamingChoices, + Usage, +) + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + + +def _make_logging_obj(provider: str = "anthropic") -> MagicMock: + logging_obj = MagicMock() + logging_obj.model_call_details = { + "custom_llm_provider": provider, + "litellm_params": {}, + } + logging_obj.call_type = "completion" + logging_obj.stream_options = None + logging_obj.messages = [{"role": "user", "content": "hi"}] + logging_obj.completion_start_time = None + logging_obj._llm_caching_handler = None + return logging_obj + + +def _make_generic_chunk( + text: str, + is_finished: bool = False, + finish_reason: str = "", + usage: Optional[dict] = None, +) -> GChunk: + return GChunk( + text=text, + is_finished=is_finished, + finish_reason=finish_reason, + usage=usage, + index=0, + tool_use=None, + ) + + +def _make_bedrock_converse_chunk( + text: str = "", + finish_reason: str = "", + usage: Optional[Usage] = None, +) -> ModelResponseStream: + """Simulate what 
AWSEventStreamDecoder.converse_chunk_parser returns.""" + return ModelResponseStream( + choices=[ + StreamingChoices( + finish_reason=finish_reason or None, + index=0, + delta=Delta(content=text, role="assistant"), + ) + ], + id="msg-test", + model="anthropic.claude-3-5-sonnet", + usage=usage, + ) + + +async def _async_iter(chunks: list): + """Wrap a list as a proper async iterator for use in __anext__ async branch.""" + for chunk in chunks: + yield chunk + + +def _make_wrapper( + chunks: list, + provider: str = "anthropic", + async_stream: bool = False, +) -> CustomStreamWrapper: + logging_obj = _make_logging_obj(provider) + stream = _async_iter(chunks) if async_stream else iter(chunks) + wrapper = CustomStreamWrapper( + completion_stream=stream, + model="claude-3-5-sonnet", + logging_obj=logging_obj, + custom_llm_provider=provider, + ) + return wrapper + + +def _drain_sync(wrapper: CustomStreamWrapper) -> List[ModelResponseStream]: + results = [] + for chunk in wrapper: + results.append(chunk) + return results + + +async def _drain_async(wrapper: CustomStreamWrapper) -> List[ModelResponseStream]: + results = [] + async for chunk in wrapper: + results.append(chunk) + return results + + +# --------------------------------------------------------------------------- +# 1. Module-level _GCHUNK_FIELDS constant +# --------------------------------------------------------------------------- + + +def test_gchunk_fields_is_frozenset(): + """_GCHUNK_FIELDS must be a frozenset built from GChunk.__annotations__.""" + assert isinstance(_GCHUNK_FIELDS, frozenset) + assert _GCHUNK_FIELDS == frozenset(GChunk.__annotations__) + + +def test_generic_chunk_has_all_required_fields_uses_module_constant(monkeypatch): + """generic_chunk_has_all_required_fields must use _GCHUNK_FIELDS, not __annotations__. + + The check semantics: every key in `chunk` must be a known GChunk field. + This identifies GChunk-shaped dicts (all keys are valid GChunk fields). 
+ """ + valid_chunk = _make_generic_chunk("hello") + assert generic_chunk_has_all_required_fields(valid_chunk) is True + + # A dict with an extra unknown key should return False — the unknown key + # is not a GChunk field, so the chunk is not a pure GChunk. + extra_key_chunk = dict(valid_chunk) + extra_key_chunk["unknown_extra_key"] = "value" + assert generic_chunk_has_all_required_fields(extra_key_chunk) is False + + # A dict with only known GChunk fields but fewer keys still passes because + # all its keys are valid (subset of GChunk fields). + partial_chunk = {"text": "hi", "is_finished": False} + assert generic_chunk_has_all_required_fields(partial_chunk) is True + + +# --------------------------------------------------------------------------- +# 2. Cached model name and provider at init time +# --------------------------------------------------------------------------- + + +def test_cached_model_name_simple(): + """For non-openai providers the cached model name must match the model arg.""" + wrapper = _make_wrapper([], provider="anthropic") + assert wrapper._cached_model_name == "claude-3-5-sonnet" + assert wrapper._cached_logging_llm_provider == "anthropic" + + +def test_cached_model_name_openai_prefix(): + """For openai provider when logging provider differs, model name is prefixed.""" + logging_obj = _make_logging_obj(provider="azure") + wrapper = CustomStreamWrapper( + completion_stream=iter([]), + model="gpt-4o", + logging_obj=logging_obj, + custom_llm_provider="openai", + ) + assert wrapper._cached_model_name == "azure/gpt-4o" + assert wrapper._cached_logging_llm_provider == "azure" + + +def test_base_hidden_params_precomputed(): + """_base_hidden_params must be pre-built from _hidden_params at init.""" + wrapper = _make_wrapper([], provider="anthropic") + assert "response_cost" in wrapper._base_hidden_params + assert wrapper._base_hidden_params["response_cost"] is None + # Must include all keys from _hidden_params + for k in wrapper._hidden_params: + 
assert k in wrapper._base_hidden_params + + +# --------------------------------------------------------------------------- +# 3. Sync path: model_dump() is NOT called on non-usage chunks +# --------------------------------------------------------------------------- + + +def test_sync_path_no_model_dump_on_text_chunks(): + """ + The sync __next__ must NOT call model_dump() on chunks that have no usage. + + ModelResponseStream declares `usage` as a field, so a `hasattr` check + would always succeed and trigger the model_dump()+recreate path on every + chunk. The wrapper must check `is not None` instead. + """ + chunks = [ + _make_generic_chunk("Hello"), + _make_generic_chunk(" world"), + _make_generic_chunk("", is_finished=True, finish_reason="stop"), + ] + wrapper = _make_wrapper(chunks) + + model_dump_call_count = 0 + original_model_dump = ModelResponseStream.model_dump + + def counting_model_dump(self, **kwargs): + nonlocal model_dump_call_count + model_dump_call_count += 1 + return original_model_dump(self, **kwargs) + + with patch.object(ModelResponseStream, "model_dump", counting_model_dump): + results = _drain_sync(wrapper) + + text_chunks = [r for r in results if r.choices and r.choices[0].delta.content] + assert len(text_chunks) >= 2, "Expected at least 2 text chunks" + assert model_dump_call_count <= 1, ( + f"model_dump() called {model_dump_call_count} times — " + "usage check is firing on every chunk" + ) + + +# --------------------------------------------------------------------------- +# 4. 
Sync path: usage chunk is stripped from body but preserved in hidden_params +# --------------------------------------------------------------------------- + + +def test_sync_path_usage_stripped_from_body_preserved_in_hidden_params(): + """Usage data must be removed from the returned chunk but added to _hidden_params.""" + usage_dict = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30} + chunks = [ + _make_generic_chunk("Hello"), + _make_generic_chunk( + "", is_finished=True, finish_reason="stop", usage=usage_dict + ), + ] + wrapper = _make_wrapper(chunks) + results = _drain_sync(wrapper) + + # The usage chunk must be returned (not silently dropped) + finish_chunks = [ + r for r in results if r.choices and r.choices[0].finish_reason == "stop" + ] + assert finish_chunks, "Finish-reason chunk was not returned" + + # The final chunk must carry usage in _hidden_params + final = results[-1] + assert "usage" in final._hidden_params, "usage missing from _hidden_params" + hidden_usage = final._hidden_params["usage"] + assert hidden_usage is not None + + +# --------------------------------------------------------------------------- +# 5. 
Async path: usage chunk is stripped from body but preserved in hidden_params +# --------------------------------------------------------------------------- + + +def test_async_path_usage_stripped_from_body_preserved_in_hidden_params(): + """Async path mirrors sync path for usage handling.""" + usage_dict = {"prompt_tokens": 5, "completion_tokens": 15, "total_tokens": 20} + chunks = [ + _make_generic_chunk("Hi"), + _make_generic_chunk( + "", is_finished=True, finish_reason="stop", usage=usage_dict + ), + ] + + async def _run(): + # async_stream=True forces the real async-for branch of __anext__ + wrapper = _make_wrapper(chunks, async_stream=True) + return await _drain_async(wrapper) + + results = asyncio.run(_run()) + final = results[-1] + assert "usage" in final._hidden_params + assert final._hidden_params["usage"] is not None + + +# --------------------------------------------------------------------------- +# 6. Bedrock Converse: ModelResponseStream chunks pass through correctly +# --------------------------------------------------------------------------- + + +def test_bedrock_converse_text_chunks_pass_through(): + """ + Bedrock Converse returns ModelResponseStream objects directly. + They should pass through chunk_creator and appear in output unchanged. 
+ """ + chunks = [ + _make_bedrock_converse_chunk("Hello"), + _make_bedrock_converse_chunk(" world"), + _make_bedrock_converse_chunk("", finish_reason="end_turn"), + ] + wrapper = _make_wrapper(chunks, provider="bedrock") + results = _drain_sync(wrapper) + + texts = [ + r.choices[0].delta.content + for r in results + if r.choices and r.choices[0].delta.content + ] + assert "Hello" in texts or any("Hello" in (t or "") for t in texts) + + +def test_bedrock_converse_usage_chunk_stripped_and_in_hidden_params(): + """Usage in a Bedrock Converse ModelResponseStream chunk is handled correctly.""" + usage = Usage(prompt_tokens=8, completion_tokens=12, total_tokens=20) + chunks = [ + _make_bedrock_converse_chunk("Hi"), + _make_bedrock_converse_chunk("", finish_reason="end_turn", usage=usage), + ] + wrapper = _make_wrapper(chunks, provider="bedrock") + results = _drain_sync(wrapper) + + final = results[-1] + assert "usage" in final._hidden_params + assert final._hidden_params["usage"] is not None + + +# --------------------------------------------------------------------------- +# 7. 
# Anthropic generic chunk (GChunk) path
# ---------------------------------------------------------------------------


def test_anthropic_generic_chunks_text_pass_through():
    """At least two non-empty text deltas must come out of the sync drain."""
    chunks = [_make_generic_chunk(piece) for piece in ("The", " answer")]
    chunks.append(_make_generic_chunk("", is_finished=True, finish_reason="stop"))

    results = _drain_sync(_make_wrapper(chunks, provider="anthropic"))

    # Keep only truthy delta contents from chunks that carry choices.
    texts = []
    for item in results:
        if item.choices and item.choices[0].delta.content:
            texts.append(item.choices[0].delta.content)

    assert len(texts) >= 2


def test_anthropic_finish_reason_propagated():
    """Some drained chunk must report the "stop" finish_reason."""
    opening_chunk = _make_generic_chunk("Hi")
    closing_chunk = _make_generic_chunk("", is_finished=True, finish_reason="stop")

    stream = _make_wrapper([opening_chunk, closing_chunk], provider="anthropic")
    results = _drain_sync(stream)

    # Gather every truthy finish_reason seen on the first choice of a chunk.
    finish_reasons = []
    for item in results:
        if item.choices and item.choices[0].finish_reason:
            finish_reasons.append(item.choices[0].finish_reason)

    assert "stop" in finish_reasons


# ---------------------------------------------------------------------------
# 8. Callback caching: _post_streaming_hooks resolved once per stream
# ---------------------------------------------------------------------------


def test_post_streaming_hooks_cached_after_first_call():
    """
    ``_post_streaming_hooks`` starts as None and becomes a list after the first
    hook call; a second call must leave the very same list object in place.
    """
    wrapper = _make_wrapper([], provider="anthropic")
    assert wrapper._post_streaming_hooks is None, "Must be None before first call"

    async def _invoke_hook():
        # Each invocation runs against a freshly patched, empty callback list.
        with patch.object(litellm, "callbacks", []):
            await wrapper._call_post_streaming_deployment_hook(
                MagicMock(spec=ModelResponseStream)
            )

    async def _scenario():
        await _invoke_hook()
        cached = wrapper._post_streaming_hooks
        assert isinstance(cached, list)

        # Identity (not equality) is what proves the list was not rebuilt.
        await _invoke_hook()
        assert (
            wrapper._post_streaming_hooks is cached
        ), "_post_streaming_hooks was rebuilt on second call — caching broken"

    asyncio.run(_scenario())


def test_post_streaming_hooks_filters_correctly():
    """
    Only CustomLogger instances must be included; plain callables are excluded.

    Note: CustomLogger's base class already defines
    async_post_call_streaming_deployment_hook, so ALL CustomLogger subclasses
    pass the hasattr() check regardless of whether they override the method.
    The filter therefore keeps any CustomLogger instance and drops anything else.
    """
    from litellm.integrations.custom_logger import CustomLogger

    class MyLogger(CustomLogger):
        pass

    plain_callable = MagicMock()
    wrapper = _make_wrapper([], provider="anthropic")

    async def _scenario():
        registered = [MyLogger(), plain_callable]
        with patch.object(litellm, "callbacks", registered):
            await wrapper._call_post_streaming_deployment_hook(
                MagicMock(spec=ModelResponseStream)
            )

        # plain_callable must be excluded; MyLogger (CustomLogger subclass) included
        assert len(wrapper._post_streaming_hooks) == 1
        assert isinstance(wrapper._post_streaming_hooks[0], MyLogger)

    asyncio.run(_scenario())


# ---------------------------------------------------------------------------
# 9.
# model_response_creator: hidden_params built correctly
# ---------------------------------------------------------------------------


def test_model_response_creator_hidden_params_no_chunk():
    """model_response_creator() with no args must populate the base hidden params.

    Checks an explicit ``response_cost`` of None, the "anthropic" provider and
    the presence of ``created_at``.
    """
    wrapper = _make_wrapper([], provider="anthropic")
    hidden = wrapper.model_response_creator()._hidden_params

    # dict.get() returns None for a missing key too, so membership is asserted
    # separately; otherwise a dropped "response_cost" entry would go unnoticed.
    assert "response_cost" in hidden
    assert hidden["response_cost"] is None
    assert hidden.get("custom_llm_provider") == "anthropic"
    assert "created_at" in hidden


def test_model_response_creator_hidden_params_caller_merged():
    """Caller-supplied hidden_params must be merged alongside the base entries."""
    wrapper = _make_wrapper([], provider="anthropic")
    caller_params = {"some_key": "some_value"}
    hidden = wrapper.model_response_creator(hidden_params=caller_params)._hidden_params

    assert hidden.get("some_key") == "some_value"
    # Presence plus value: .get() alone cannot tell "missing" from "None".
    assert "response_cost" in hidden
    assert hidden["response_cost"] is None


def test_model_response_creator_stream_key_stripped():
    """A chunk carrying a 'stream' key must still build a response without raising."""
    wrapper = _make_wrapper([], provider="anthropic")
    chunk = {"stream": True, "choices": []}
    # Should not raise even if 'stream' would be an invalid ModelResponseStream field
    response = wrapper.model_response_creator(chunk=chunk)
    assert response is not None


# ---------------------------------------------------------------------------
# 10. Per-chunk overhead regression: sync path must not regress
# ---------------------------------------------------------------------------


def test_sync_streaming_overhead_not_regressed():
    """
    Micro-benchmark: the sync hot path must process 200 text chunks in < 2 s.

    This test acts as a canary for gross per-chunk overhead regressions.
    It is intentionally generous (2 s) to avoid flakiness on slow CI runners.
    """
    n_chunks = 200
    chunks = [_make_generic_chunk(f"token-{i}") for i in range(n_chunks)]
    chunks.append(_make_generic_chunk("", is_finished=True, finish_reason="stop"))

    wrapper = _make_wrapper(chunks, provider="anthropic")

    # Only the drain itself is timed; chunk and wrapper construction are excluded.
    start = time.monotonic()
    results = _drain_sync(wrapper)
    elapsed = time.monotonic() - start

    assert len(results) > 0, "No chunks returned"
    assert elapsed < 2.0, (
        f"Sync streaming of {n_chunks} chunks took {elapsed:.3f}s — "
        "per-chunk overhead regression detected"
    )


def test_async_streaming_overhead_not_regressed():
    """
    Micro-benchmark for the async path: 200 text chunks in < 2 s.
    """
    n_chunks = 200
    chunks = [_make_generic_chunk(f"token-{i}") for i in range(n_chunks)]
    chunks.append(_make_generic_chunk("", is_finished=True, finish_reason="stop"))

    async def _run():
        # The wrapper is built before the clock starts so only the drain is timed.
        wrapper = _make_wrapper(chunks, provider="anthropic")
        start = time.monotonic()
        results = await _drain_async(wrapper)
        return results, time.monotonic() - start

    results, elapsed = asyncio.run(_run())
    assert len(results) > 0
    assert elapsed < 2.0, (
        f"Async streaming of {n_chunks} chunks took {elapsed:.3f}s — "
        "per-chunk overhead regression detected"
    )