refactor: trim explanatory comments from streaming-flush fix
Strip module-level docstrings and per-test/per-block prose from the LIT-2642 fix and tests. Keep one short comment in each streaming site that flags the GeneratorExit-vs-Exception subtlety, since that's the non-obvious reason the flush lives in `finally` rather than after the loop.

Pure cleanup; no behavior change. All 12 regression tests still pass.

Co-authored-by: Mateo Wang <mateo-berri@users.noreply.github.com>
This commit is contained in:
parent
1ef034bff6
commit
8759413312
@ -397,10 +397,6 @@ def _sync_streaming(
|
||||
raw_bytes.append(chunk)
|
||||
yield chunk
|
||||
finally:
|
||||
# Always flush collected chunks for spend tracking, even if the
|
||||
# consumer terminates the generator early (GeneratorExit). Without
|
||||
# this, an interrupted stream loses all per-chunk usage data
|
||||
# because the post-loop flush never runs. See LIT-2642.
|
||||
if not flush_scheduled and raw_bytes:
|
||||
flush_scheduled = True
|
||||
try:
|
||||
@ -410,8 +406,6 @@ def _sync_streaming(
|
||||
provider_config=provider_config,
|
||||
)
|
||||
except Exception:
|
||||
# Don't mask the original exception (incl. GeneratorExit)
|
||||
# if scheduling the flush itself fails.
|
||||
pass
|
||||
|
||||
|
||||
@ -422,8 +416,6 @@ async def _async_streaming(
|
||||
):
|
||||
iter_response = await response
|
||||
|
||||
# Validate response status before consuming the body so 4xx/5xx
|
||||
# responses raise without entering the chunk-collection path.
|
||||
try:
|
||||
iter_response.raise_for_status()
|
||||
except Exception:
|
||||
@ -446,12 +438,9 @@ async def _async_streaming(
|
||||
pass
|
||||
raise
|
||||
finally:
|
||||
# Always flush collected chunks for spend tracking, even if the
|
||||
# client disconnects mid-stream. On disconnect, Starlette calls
|
||||
# aclose() on this generator, which raises GeneratorExit at the
|
||||
# suspended `yield` — `except Exception` does not catch it, so
|
||||
# the post-loop flush would otherwise be skipped and all
|
||||
# captured per-chunk usage data lost. See LIT-2642.
|
||||
# GeneratorExit (raised on client disconnect) is not caught by
|
||||
# `except Exception`; the finally block ensures partial usage
|
||||
# still gets flushed for spend tracking. See LIT-2642.
|
||||
if not flush_scheduled and raw_bytes:
|
||||
flush_scheduled = True
|
||||
try:
|
||||
@ -462,6 +451,4 @@ async def _async_streaming(
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
# Don't mask the original exception (incl. GeneratorExit)
|
||||
# if scheduling the flush itself fails.
|
||||
pass
|
||||
|
||||
@ -36,14 +36,8 @@ class PassThroughStreamingHandler:
|
||||
passthrough_success_handler_obj: PassThroughEndpointLogging,
|
||||
url_route: str,
|
||||
):
|
||||
"""
|
||||
- Yields chunks from the response
|
||||
- Collect non-empty chunks for post-processing (logging)
|
||||
- Inject cost into chunks if include_cost_in_streaming_usage is enabled
|
||||
"""
|
||||
raw_bytes: List[bytes] = []
|
||||
logging_scheduled = False
|
||||
# Extract model name for cost injection
|
||||
model_name = PassThroughStreamingHandler._extract_model_for_cost_injection(
|
||||
request_body=request_body,
|
||||
url_route=url_route,
|
||||
@ -59,7 +53,6 @@ class PassThroughStreamingHandler:
|
||||
and model_name
|
||||
):
|
||||
if endpoint_type == EndpointType.VERTEX_AI:
|
||||
# Only handle streamRawPredict (uses Anthropic format)
|
||||
if "streamRawPredict" in url_route or "rawPredict" in url_route:
|
||||
modified_chunk = ProxyBaseLLMRequestProcessing._process_chunk_with_cost_injection(
|
||||
chunk, model_name
|
||||
@ -78,17 +71,12 @@ class PassThroughStreamingHandler:
|
||||
verbose_proxy_logger.error(f"Error in chunk_processor: {str(e)}")
|
||||
raise
|
||||
finally:
|
||||
# Always log collected chunks for spend tracking, even if the
|
||||
# client disconnects mid-stream. On disconnect, Starlette calls
|
||||
# aclose() on this async generator, which raises GeneratorExit
|
||||
# at the suspended `yield` — `except Exception` does not catch
|
||||
# it, so post-loop logging would otherwise be skipped and all
|
||||
# captured per-chunk usage data lost (e.g. for interrupted
|
||||
# Bedrock streams). See LIT-2642.
|
||||
# GeneratorExit (raised on client disconnect) is not caught by
|
||||
# `except Exception`; the finally block ensures partial usage
|
||||
# still gets logged for spend tracking. See LIT-2642.
|
||||
if not logging_scheduled and raw_bytes:
|
||||
logging_scheduled = True
|
||||
try:
|
||||
end_time = datetime.now()
|
||||
asyncio.create_task(
|
||||
PassThroughStreamingHandler._route_streaming_logging_to_handler(
|
||||
litellm_logging_obj=litellm_logging_obj,
|
||||
@ -98,7 +86,7 @@ class PassThroughStreamingHandler:
|
||||
endpoint_type=endpoint_type,
|
||||
start_time=start_time,
|
||||
raw_bytes=raw_bytes,
|
||||
end_time=end_time,
|
||||
end_time=datetime.now(),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@ -1,27 +1,14 @@
|
||||
"""
|
||||
Regression tests for LIT-2642 — interrupted streaming responses must still
|
||||
flush collected chunks so spend is tracked even when the client disconnects
|
||||
mid-stream.
|
||||
|
||||
Bedrock invoke streaming was the reported reproducer: the proxy passes the
|
||||
upstream stream through `_async_streaming` in `litellm/passthrough/main.py`,
|
||||
which collects bytes and triggers `async_flush_passthrough_collected_chunks`
|
||||
once the loop completes. When a FastAPI client disconnects mid-stream,
|
||||
Starlette calls `aclose()` on the async generator and raises `GeneratorExit`
|
||||
at the suspended `yield`. The previous `except Exception` branch did not
|
||||
catch `GeneratorExit`, so the post-loop flush was skipped and all per-chunk
|
||||
usage data was dropped.
|
||||
"""
|
||||
"""Regression tests for LIT-2642 — interrupted streams must still flush usage."""
|
||||
|
||||
import asyncio
|
||||
from typing import List
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
def _make_streaming_response(chunks: List[bytes]):
|
||||
"""Build a mock httpx.Response that streams the given chunks via aiter_bytes."""
|
||||
mock = MagicMock(spec=httpx.Response)
|
||||
mock.status_code = 200
|
||||
mock.headers = httpx.Headers({"content-type": "application/vnd.amazon.eventstream"})
|
||||
@ -42,9 +29,13 @@ def _make_logging_obj():
|
||||
return mock
|
||||
|
||||
|
||||
class _ImmediateExecutor:
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
fn(*args, **kwargs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_streaming_flushes_on_normal_completion():
|
||||
"""Baseline: full stream consumption flushes collected chunks once."""
|
||||
from litellm.passthrough.main import _async_streaming
|
||||
|
||||
chunks = [b"chunk-1", b"chunk-2", b"chunk-3"]
|
||||
@ -66,9 +57,6 @@ async def test_async_streaming_flushes_on_normal_completion():
|
||||
|
||||
assert received == chunks
|
||||
|
||||
# Allow the scheduled task to run.
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
mock_logging_obj.async_flush_passthrough_collected_chunks.assert_called_once()
|
||||
@ -81,11 +69,6 @@ async def test_async_streaming_flushes_on_normal_completion():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_streaming_flushes_on_client_disconnect():
|
||||
"""
|
||||
LIT-2642 regression: GeneratorExit (raised when the consumer disconnects
|
||||
mid-stream) must still flush whatever chunks we already collected so
|
||||
spend tracking captures the partial usage.
|
||||
"""
|
||||
from litellm.passthrough.main import _async_streaming
|
||||
|
||||
chunks = [
|
||||
@ -107,31 +90,22 @@ async def test_async_streaming_flushes_on_client_disconnect():
|
||||
provider_config=provider_config,
|
||||
)
|
||||
|
||||
# Pull one chunk, then close the generator early — mirrors what
|
||||
# Starlette does when the HTTP client disconnects mid-stream.
|
||||
received = [await gen.__anext__()]
|
||||
await gen.aclose()
|
||||
|
||||
assert received == [chunks[0]]
|
||||
|
||||
# Allow the scheduled flush task to run.
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
mock_logging_obj.async_flush_passthrough_collected_chunks.assert_called_once()
|
||||
call_kwargs = (
|
||||
mock_logging_obj.async_flush_passthrough_collected_chunks.call_args.kwargs
|
||||
)
|
||||
# Only the first chunk was consumed before disconnect; that's what we
|
||||
# must hand off to the cost-tracking flush so partial usage isn't
|
||||
# silently dropped.
|
||||
assert call_kwargs["raw_bytes"] == [chunks[0]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_streaming_does_not_flush_on_4xx():
|
||||
"""Error responses must still raise without entering the flush path."""
|
||||
from litellm.passthrough.main import _async_streaming
|
||||
|
||||
err_response = MagicMock(spec=httpx.Response)
|
||||
@ -162,17 +136,11 @@ async def test_async_streaming_does_not_flush_on_4xx():
|
||||
):
|
||||
pass
|
||||
|
||||
# No bytes were collected, so no flush should have been scheduled.
|
||||
mock_logging_obj.async_flush_passthrough_collected_chunks.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_streaming_flushes_on_upstream_exception_with_partial_data():
|
||||
"""
|
||||
If the upstream connection drops mid-stream and aiter_bytes raises,
|
||||
we still surface the exception, but partial chunks already collected
|
||||
are flushed so spend tracking isn't fully lost.
|
||||
"""
|
||||
from litellm.passthrough.main import _async_streaming
|
||||
|
||||
partial_chunks = [b"partial-chunk-1", b"partial-chunk-2"]
|
||||
@ -206,8 +174,6 @@ async def test_async_streaming_flushes_on_upstream_exception_with_partial_data()
|
||||
|
||||
assert received == partial_chunks
|
||||
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
mock_logging_obj.async_flush_passthrough_collected_chunks.assert_called_once()
|
||||
@ -218,7 +184,6 @@ async def test_async_streaming_flushes_on_upstream_exception_with_partial_data()
|
||||
|
||||
|
||||
def test_sync_streaming_flushes_on_normal_completion():
|
||||
"""Baseline for the sync codepath."""
|
||||
from litellm.passthrough.main import _sync_streaming
|
||||
|
||||
chunks = [b"a", b"b", b"c"]
|
||||
@ -234,13 +199,6 @@ def test_sync_streaming_flushes_on_normal_completion():
|
||||
mock_logging_obj.flush_passthrough_collected_chunks = MagicMock()
|
||||
provider_config = MagicMock()
|
||||
|
||||
# Use a synchronous in-process executor so we can assert immediately.
|
||||
class _ImmediateExecutor:
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
fn(*args, **kwargs)
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch("litellm.utils.executor", _ImmediateExecutor()):
|
||||
received = list(
|
||||
_sync_streaming(
|
||||
@ -255,10 +213,6 @@ def test_sync_streaming_flushes_on_normal_completion():
|
||||
|
||||
|
||||
def test_sync_streaming_flushes_on_early_close():
|
||||
"""
|
||||
Sync analog of LIT-2642: closing the generator early must still flush
|
||||
so per-chunk usage data is not silently dropped.
|
||||
"""
|
||||
from litellm.passthrough.main import _sync_streaming
|
||||
|
||||
chunks = [b"first", b"second", b"third"]
|
||||
@ -274,12 +228,6 @@ def test_sync_streaming_flushes_on_early_close():
|
||||
mock_logging_obj.flush_passthrough_collected_chunks = MagicMock()
|
||||
provider_config = MagicMock()
|
||||
|
||||
class _ImmediateExecutor:
|
||||
def submit(self, fn, *args, **kwargs):
|
||||
fn(*args, **kwargs)
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch("litellm.utils.executor", _ImmediateExecutor()):
|
||||
gen = _sync_streaming(
|
||||
response=mock_response,
|
||||
@ -287,7 +235,6 @@ def test_sync_streaming_flushes_on_early_close():
|
||||
provider_config=provider_config,
|
||||
)
|
||||
|
||||
# Consume one chunk, then close — analog of a client disconnect.
|
||||
first = next(gen)
|
||||
gen.close()
|
||||
|
||||
|
||||
@ -1,15 +1,6 @@
|
||||
"""
|
||||
Regression tests for LIT-2642 — interrupted pass-through streams must still
|
||||
trigger logging so spend is tracked.
|
||||
|
||||
`PassThroughStreamingHandler.chunk_processor` collects bytes from the
|
||||
upstream response and schedules `_route_streaming_logging_to_handler` once
|
||||
the chunk loop completes. When a FastAPI client disconnects mid-stream,
|
||||
Starlette calls `aclose()` on the async generator and raises `GeneratorExit`
|
||||
at the suspended `yield`. The previous `except Exception` branch did not
|
||||
catch `GeneratorExit`, so the post-loop logging task was never scheduled.
|
||||
"""
|
||||
"""Regression tests for LIT-2642 — interrupted pass-through streams must still log usage."""
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
@ -36,7 +27,6 @@ def _make_streaming_response(chunks):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_processor_logs_on_normal_completion():
|
||||
"""Baseline: full consumption schedules logging exactly once."""
|
||||
chunks = [b"chunk-1", b"chunk-2", b"chunk-3"]
|
||||
response = _make_streaming_response(chunks)
|
||||
|
||||
@ -60,8 +50,6 @@ async def test_chunk_processor_logs_on_normal_completion():
|
||||
):
|
||||
received.append(chunk)
|
||||
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert received == chunks
|
||||
@ -72,11 +60,6 @@ async def test_chunk_processor_logs_on_normal_completion():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_processor_logs_on_client_disconnect():
|
||||
"""
|
||||
LIT-2642 regression: closing the generator early (e.g. client
|
||||
disconnect) must still schedule logging so per-chunk spend data
|
||||
isn't dropped.
|
||||
"""
|
||||
chunks = [b"event-1", b"event-2", b"event-3"]
|
||||
response = _make_streaming_response(chunks)
|
||||
|
||||
@ -98,26 +81,19 @@ async def test_chunk_processor_logs_on_client_disconnect():
|
||||
url_route="/bedrock/model/claude/invoke-with-response-stream",
|
||||
)
|
||||
|
||||
# Consume one chunk, then close the generator — same path Starlette
|
||||
# takes when the HTTP client disconnects mid-stream.
|
||||
first = await gen.__anext__()
|
||||
await gen.aclose()
|
||||
|
||||
import asyncio
|
||||
|
||||
await asyncio.sleep(0)
|
||||
|
||||
assert first == chunks[0]
|
||||
mock_route.assert_called_once()
|
||||
call_kwargs = mock_route.call_args.kwargs
|
||||
# Only one chunk made it through before disconnect — that is what
|
||||
# the logging handler must be given so partial usage is captured.
|
||||
assert call_kwargs["raw_bytes"] == [chunks[0]]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chunk_processor_does_not_schedule_logging_when_no_chunks():
|
||||
"""If no chunks were ever received, don't schedule a no-op logging task."""
|
||||
response = _make_streaming_response([])
|
||||
|
||||
mock_logging_obj = MagicMock()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user