fix(agentcore): simplify agentcore streaming (#17141)
* fix(agentcore): simplify agentcore streaming * fix(agentcore): move CustomStreamWrapper import to module level The deferred imports inside streaming methods caused initialization delays during health check requests, leading to timeouts in ECS deployments. - Move CustomStreamWrapper import to module-level (line 19) - Remove deferred imports from get_sync_custom_stream_wrapper (line 588) - Remove deferred import from get_async_custom_stream_wrapper (line 747) - Remove from TYPE_CHECKING block to use actual import This ensures the import happens at module load time rather than during first request processing, preventing health check endpoint blocking. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com> * test(agentcore): ensure sync response * chore: upgrade boto3 to 1.40.76 in pyproject.toml * chore: added taplo.toml * fix(types): correct annotation type hint for MyPy compatibility Update _convert_annotations_to_chat_format return type from Dict[str, Any] to ChatCompletionAnnotation TypedDict to match the Message class's expected type signature. Co-Authored-By: Claude <noreply@anthropic.com> --------- Co-authored-by: Claude <noreply@anthropic.com> Co-authored-by: Benedikt Óskarsson <bensi94@hotmail.com>
This commit is contained in:
parent
6cd4b3603f
commit
5db0e3289a
@ -779,10 +779,10 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
|
||||
@staticmethod
|
||||
def _convert_annotations_to_chat_format(
|
||||
annotations: Optional[List[Any]],
|
||||
) -> Optional[List["ChatCompletionAnnotation"]]:
|
||||
) -> Optional[List[ChatCompletionAnnotation]]:
|
||||
"""
|
||||
Convert annotations from Responses API to Chat Completions format.
|
||||
|
||||
|
||||
Annotations are already in compatible format between both APIs,
|
||||
so we just need to convert Pydantic models to dicts.
|
||||
"""
|
||||
|
||||
@ -1,252 +0,0 @@
|
||||
"""
|
||||
SSE Stream Iterator for Bedrock AgentCore.
|
||||
|
||||
Handles Server-Sent Events (SSE) streaming responses from AgentCore.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm._uuid import uuid
|
||||
from litellm.types.llms.bedrock_agentcore import AgentCoreUsage
|
||||
from litellm.types.utils import Delta, ModelResponse, StreamingChoices, Usage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class AgentCoreSSEStreamIterator:
    """
    Iterator for AgentCore SSE streaming responses.
    Supports both sync and async iteration.

    CRITICAL: The line iterators are created lazily on first access and reused.
    We must NOT create new iterators in __aiter__/__iter__ because
    CustomStreamWrapper calls __aiter__ on every call to its __anext__,
    which would create new iterators and cause StreamConsumed errors.
    """

    def __init__(self, response: httpx.Response, model: str):
        """
        Args:
            response: Streaming HTTP response whose lines are SSE events.
            model: Model name stamped on every emitted chunk.
        """
        self.response = response
        self.model = model
        # True once a finish_reason="stop" chunk has been produced.
        self.finished = False
        self._sync_iter: Any = None
        self._async_iter: Any = None
        self._sync_iter_initialized = False
        self._async_iter_initialized = False

    def __iter__(self):
        """Initialize sync iteration - create iterator lazily on first call only."""
        if not self._sync_iter_initialized:
            self._sync_iter = iter(self.response.iter_lines())
            self._sync_iter_initialized = True
        return self

    def __aiter__(self):
        """Initialize async iteration - create iterator lazily on first call only."""
        if not self._async_iter_initialized:
            self._async_iter = self.response.aiter_lines().__aiter__()
            self._async_iter_initialized = True
        return self

    def _build_chunk(self, finish_reason: Optional[str], delta: Delta) -> ModelResponse:
        """Build a single-choice ``chat.completion.chunk`` for this stream's model."""
        chunk = ModelResponse(
            id=f"chatcmpl-{uuid.uuid4()}",
            created=0,
            model=self.model,
            object="chat.completion.chunk",
        )
        chunk.choices = [
            StreamingChoices(
                finish_reason=finish_reason,
                index=0,
                delta=delta,
            )
        ]
        return chunk

    def _parse_sse_line(self, line: str) -> Optional[ModelResponse]:
        """
        Parse a single SSE line and return a ModelResponse chunk if applicable.

        AgentCore SSE format:
        - data: {"event": {"contentBlockDelta": {"delta": {"text": "..."}}}}
        - data: {"event": {"metadata": {"usage": {...}}}}
        - data: {"message": {...}}

        Returns:
            A chunk for a text delta, a usage/metadata event or a final
            message; None for every other line (blank, non-``data:``,
            non-JSON, or JSON with an unexpected shape).
        """
        line = line.strip()
        if not line or not line.startswith("data:"):
            return None

        json_str = line[5:].strip()
        if not json_str:
            return None

        try:
            data = json.loads(json_str)
        except json.JSONDecodeError:
            verbose_logger.debug(f"Skipping non-JSON SSE line: {line[:100]}")
            return None

        # Skip non-dict data (some lines contain Python repr strings)
        if not isinstance(data, dict):
            return None

        event_payload = data.get("event")
        if isinstance(event_payload, dict):
            # Process content delta events. Every level is type-checked so a
            # malformed event is skipped instead of raising AttributeError,
            # which the callers' broad handler would turn into a silent
            # end of stream.
            content_block_delta = event_payload.get("contentBlockDelta")
            if isinstance(content_block_delta, dict):
                delta = content_block_delta.get("delta")
                text = delta.get("text", "") if isinstance(delta, dict) else ""
                if text:
                    return self._build_chunk(
                        None, Delta(content=text, role="assistant")
                    )

            # Check for metadata/usage - this signals the end
            metadata = event_payload.get("metadata")
            usage_data = metadata.get("usage") if isinstance(metadata, dict) else None
            if isinstance(usage_data, dict):
                chunk = self._build_chunk("stop", Delta())
                setattr(
                    chunk,
                    "usage",
                    Usage(
                        prompt_tokens=usage_data.get("inputTokens", 0),
                        completion_tokens=usage_data.get("outputTokens", 0),
                        total_tokens=usage_data.get("totalTokens", 0),
                    ),
                )
                self.finished = True
                return chunk

        # Check for final message (alternative finish signal); ignored when a
        # stop chunk was already produced.
        if isinstance(data.get("message"), dict) and not self.finished:
            self.finished = True
            return self._build_chunk("stop", Delta())

        return None

    def _create_final_chunk(self) -> ModelResponse:
        """Create a final chunk to signal stream completion."""
        return self._build_chunk("stop", Delta())

    def __next__(self) -> ModelResponse:
        """
        Sync iteration - parse SSE events and yield ModelResponse chunks.

        Uses next() on the stored iterator to properly resume between calls.

        Raises:
            StopIteration: when the stream is exhausted, already consumed or
                closed, or when an unexpected error occurs (logged first).
        """
        try:
            # Create the line iterator here as well, so next() works on an
            # instance that was never passed through iter().
            if not self._sync_iter_initialized:
                self.__iter__()

            # Keep getting lines until we have a result to return
            while True:
                try:
                    line = next(self._sync_iter)
                except StopIteration:
                    # Stream ended - send final chunk if not already finished
                    if not self.finished:
                        self.finished = True
                        return self._create_final_chunk()
                    raise

                result = self._parse_sse_line(line)
                if result is not None:
                    return result

        except StopIteration:
            raise
        except (httpx.StreamConsumed, httpx.StreamClosed):
            raise StopIteration
        except Exception as e:
            verbose_logger.error(f"Error in AgentCore SSE stream: {str(e)}")
            raise StopIteration

    async def __anext__(self) -> ModelResponse:
        """
        Async iteration - parse SSE events and yield ModelResponse chunks.

        Uses __anext__() on the stored iterator to properly resume between calls.

        Raises:
            StopAsyncIteration: when the stream is exhausted, already consumed
                or closed, or when an unexpected error occurs (logged first).
        """
        try:
            # Create the line iterator here as well, so iteration works even
            # if __aiter__ was never called on this instance.
            if not self._async_iter_initialized:
                self.__aiter__()

            # Keep getting lines until we have a result to return
            while True:
                try:
                    line = await self._async_iter.__anext__()
                except StopAsyncIteration:
                    # Stream ended - send final chunk if not already finished
                    if not self.finished:
                        self.finished = True
                        return self._create_final_chunk()
                    raise

                result = self._parse_sse_line(line)
                if result is not None:
                    return result

        except StopAsyncIteration:
            raise
        except (httpx.StreamConsumed, httpx.StreamClosed):
            raise StopAsyncIteration
        except Exception as e:
            verbose_logger.error(f"Error in AgentCore SSE stream: {str(e)}")
            raise StopAsyncIteration
|
||||
@ -5,6 +5,7 @@ https://docs.aws.amazon.com/bedrock/latest/APIReference/API_agentcore_InvokeAgen
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
from urllib.parse import quote
|
||||
|
||||
@ -15,9 +16,9 @@ from litellm._uuid import uuid
|
||||
from litellm.litellm_core_utils.prompt_templates.common_utils import (
|
||||
convert_content_list_to_str,
|
||||
)
|
||||
from litellm.litellm_core_utils.streaming_handler import CustomStreamWrapper
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
|
||||
from litellm.llms.bedrock.chat.agentcore.sse_iterator import AgentCoreSSEStreamIterator
|
||||
from litellm.llms.bedrock.common_utils import BedrockError
|
||||
from litellm.types.llms.bedrock_agentcore import (
|
||||
AgentCoreMessage,
|
||||
@ -25,19 +26,17 @@ from litellm.types.llms.bedrock_agentcore import (
|
||||
AgentCoreUsage,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues
|
||||
from litellm.types.utils import Choices, Message, ModelResponse, Usage
|
||||
from litellm.types.utils import Choices, Delta, Message, ModelResponse, StreamingChoices, Usage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as _LiteLLMLoggingObj
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
LiteLLMLoggingObj = _LiteLLMLoggingObj
|
||||
else:
|
||||
LiteLLMLoggingObj = Any
|
||||
HTTPHandler = Any
|
||||
AsyncHTTPHandler = Any
|
||||
CustomStreamWrapper = Any
|
||||
|
||||
|
||||
class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
|
||||
@ -116,7 +115,8 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
|
||||
fake_stream: Optional[bool] = None,
|
||||
) -> Tuple[dict, Optional[bytes]]:
|
||||
# Check if api_key (bearer token) is provided for Cognito authentication
|
||||
jwt_token = optional_params.get("api_key")
|
||||
# Priority: api_key parameter first, then optional_params
|
||||
jwt_token = api_key or optional_params.get("api_key")
|
||||
if jwt_token:
|
||||
verbose_logger.debug(
|
||||
f"AgentCore: Using Bearer token authentication (Cognito/JWT) - token: {jwt_token[:50]}..."
|
||||
@ -437,22 +437,104 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
|
||||
content=content, usage=usage_data, final_message=final_message
|
||||
)
|
||||
|
||||
def get_streaming_response(
    self,
    model: str,
    raw_response: httpx.Response,
) -> AgentCoreSSEStreamIterator:
    """
    Return a streaming iterator for SSE responses.

    Args:
        model: The model name
        raw_response: Raw HTTP response with streaming data

    Returns:
        AgentCoreSSEStreamIterator: Iterator that yields ModelResponse chunks
    """
    return AgentCoreSSEStreamIterator(response=raw_response, model=model)

def _stream_agentcore_response_sync(
    self,
    response: httpx.Response,
    model: str,
):
    """
    Internal sync generator that parses SSE and yields ModelResponse chunks.

    Text from ``response.iter_text()`` is buffered and split on newlines.
    Each ``data:`` line carrying a JSON object can produce:
    - a content chunk for ``event.contentBlockDelta.delta.text``
    - a ``finish_reason="stop"`` chunk with usage for ``event.metadata.usage``
    - a ``finish_reason="stop"`` chunk for a top-level ``message`` object,
      unless a stop chunk has already been yielded

    Blank lines, non-``data:`` lines, non-JSON payloads and payloads with an
    unexpected shape are skipped.

    Args:
        response: Streaming HTTP response to read from.
        model: Model name stamped on every emitted chunk.

    Yields:
        ModelResponse chunks with ``object="chat.completion.chunk"``.
    """
    stop_sent = False

    def _make_chunk(finish_reason: Optional[str], delta: Delta) -> ModelResponse:
        """Build a single-choice chunk for ``model``."""
        chunk = ModelResponse(
            id=f"chatcmpl-{uuid.uuid4()}",
            created=0,
            model=model,
            object="chat.completion.chunk",
        )
        chunk.choices = [
            StreamingChoices(
                finish_reason=finish_reason,
                index=0,
                delta=delta,
            )
        ]
        return chunk

    def _chunks_from_line(raw_line: str) -> List[ModelResponse]:
        """Translate one SSE line into zero or more chunks, in emit order."""
        nonlocal stop_sent

        line = raw_line.strip()
        if not line.startswith("data:"):
            return []

        json_str = line[5:].strip()
        if not json_str:
            return []

        try:
            data_obj = json.loads(json_str)
        except json.JSONDecodeError:
            verbose_logger.debug(f"Skipping non-JSON SSE line: {line[:100]}")
            return []
        if not isinstance(data_obj, dict):
            return []

        chunks: List[ModelResponse] = []

        event_payload = data_obj.get("event")
        if isinstance(event_payload, dict):
            # Process contentBlockDelta events; non-dict levels are skipped
            # rather than raising AttributeError out of the generator.
            content_block_delta = event_payload.get("contentBlockDelta")
            if isinstance(content_block_delta, dict):
                delta = content_block_delta.get("delta")
                text = delta.get("text", "") if isinstance(delta, dict) else ""
                if text:
                    chunks.append(
                        _make_chunk(None, Delta(content=text, role="assistant"))
                    )

            # Process metadata/usage
            metadata = event_payload.get("metadata")
            usage_data = metadata.get("usage") if isinstance(metadata, dict) else None
            if isinstance(usage_data, dict):
                usage_chunk = _make_chunk("stop", Delta())
                setattr(
                    usage_chunk,
                    "usage",
                    Usage(
                        prompt_tokens=usage_data.get("inputTokens", 0),
                        completion_tokens=usage_data.get("outputTokens", 0),
                        total_tokens=usage_data.get("totalTokens", 0),
                    ),
                )
                chunks.append(usage_chunk)
                stop_sent = True

        # Process final message. It is only an alternative finish signal, so
        # it is not repeated after a stop chunk has already been emitted.
        if isinstance(data_obj.get("message"), dict) and not stop_sent:
            chunks.append(_make_chunk("stop", Delta()))
            stop_sent = True

        return chunks

    buffer = ""
    for text_chunk in response.iter_text():
        buffer += text_chunk

        # Process complete lines
        while "\n" in buffer:
            line, buffer = buffer.split("\n", 1)
            yield from _chunks_from_line(line)

    # A body that ends without a trailing newline leaves its last line in the
    # buffer; process it so the final event is not dropped.
    if buffer.strip():
        yield from _chunks_from_line(buffer)
|
||||
|
||||
def get_sync_custom_stream_wrapper(
|
||||
self,
|
||||
@ -466,17 +548,14 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
|
||||
client: Optional[Union[HTTPHandler, "AsyncHTTPHandler"]] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
signed_json_body: Optional[bytes] = None,
|
||||
) -> CustomStreamWrapper:
|
||||
) -> "CustomStreamWrapper":
|
||||
"""
|
||||
Get a CustomStreamWrapper for synchronous streaming.
|
||||
|
||||
This is called when stream=True is passed to completion().
|
||||
Simplified sync streaming - returns a generator that yields ModelResponse chunks.
|
||||
"""
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
HTTPHandler,
|
||||
_get_httpx_client,
|
||||
)
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
if client is None or not isinstance(client, HTTPHandler):
|
||||
client = _get_httpx_client(params={})
|
||||
@ -488,7 +567,7 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=signed_json_body if signed_json_body else json.dumps(data),
|
||||
stream=True, # THIS IS KEY - tells httpx to not buffer
|
||||
stream=True,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
@ -497,18 +576,6 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
|
||||
status_code=response.status_code, message=str(response.read())
|
||||
)
|
||||
|
||||
# Create iterator for SSE stream
|
||||
completion_stream = self.get_streaming_response(
|
||||
model=model, raw_response=response
|
||||
)
|
||||
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
@ -517,7 +584,112 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return streaming_response
|
||||
# Wrap the generator in CustomStreamWrapper
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=self._stream_agentcore_response_sync(response, model),
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
async def _stream_agentcore_response(
    self,
    response: httpx.Response,
    model: str,
) -> AsyncGenerator[ModelResponse, None]:
    """
    Internal async generator that parses SSE and yields ModelResponse chunks.

    Text from ``response.aiter_text()`` is buffered and split on newlines.
    Each ``data:`` line carrying a JSON object can produce:
    - a content chunk for ``event.contentBlockDelta.delta.text``
    - a ``finish_reason="stop"`` chunk with usage for ``event.metadata.usage``
    - a ``finish_reason="stop"`` chunk for a top-level ``message`` object,
      unless a stop chunk has already been yielded

    Blank lines, non-``data:`` lines, non-JSON payloads and payloads with an
    unexpected shape are skipped.

    Args:
        response: Streaming HTTP response to read from.
        model: Model name stamped on every emitted chunk.

    Yields:
        ModelResponse chunks with ``object="chat.completion.chunk"``.
    """
    stop_sent = False

    def _make_chunk(finish_reason: Optional[str], delta: Delta) -> ModelResponse:
        """Build a single-choice chunk for ``model``."""
        chunk = ModelResponse(
            id=f"chatcmpl-{uuid.uuid4()}",
            created=0,
            model=model,
            object="chat.completion.chunk",
        )
        chunk.choices = [
            StreamingChoices(
                finish_reason=finish_reason,
                index=0,
                delta=delta,
            )
        ]
        return chunk

    def _chunks_from_line(raw_line: str) -> List[ModelResponse]:
        """Translate one SSE line into zero or more chunks, in emit order."""
        nonlocal stop_sent

        line = raw_line.strip()
        if not line.startswith("data:"):
            return []

        json_str = line[5:].strip()
        if not json_str:
            return []

        try:
            data_obj = json.loads(json_str)
        except json.JSONDecodeError:
            verbose_logger.debug(f"Skipping non-JSON SSE line: {line[:100]}")
            return []
        if not isinstance(data_obj, dict):
            return []

        chunks: List[ModelResponse] = []

        event_payload = data_obj.get("event")
        if isinstance(event_payload, dict):
            # Process contentBlockDelta events; non-dict levels are skipped
            # rather than raising AttributeError out of the generator.
            content_block_delta = event_payload.get("contentBlockDelta")
            if isinstance(content_block_delta, dict):
                delta = content_block_delta.get("delta")
                text = delta.get("text", "") if isinstance(delta, dict) else ""
                if text:
                    chunks.append(
                        _make_chunk(None, Delta(content=text, role="assistant"))
                    )

            # Process metadata/usage
            metadata = event_payload.get("metadata")
            usage_data = metadata.get("usage") if isinstance(metadata, dict) else None
            if isinstance(usage_data, dict):
                usage_chunk = _make_chunk("stop", Delta())
                setattr(
                    usage_chunk,
                    "usage",
                    Usage(
                        prompt_tokens=usage_data.get("inputTokens", 0),
                        completion_tokens=usage_data.get("outputTokens", 0),
                        total_tokens=usage_data.get("totalTokens", 0),
                    ),
                )
                chunks.append(usage_chunk)
                stop_sent = True

        # Process final message. It is only an alternative finish signal, so
        # it is not repeated after a stop chunk has already been emitted.
        if isinstance(data_obj.get("message"), dict) and not stop_sent:
            chunks.append(_make_chunk("stop", Delta()))
            stop_sent = True

        return chunks

    buffer = ""
    async for text_chunk in response.aiter_text():
        buffer += text_chunk

        # Process complete lines
        while "\n" in buffer:
            line, buffer = buffer.split("\n", 1)
            for parsed_chunk in _chunks_from_line(line):
                yield parsed_chunk

    # A body that ends without a trailing newline leaves its last line in the
    # buffer; process it so the final event is not dropped.
    if buffer.strip():
        for parsed_chunk in _chunks_from_line(buffer):
            yield parsed_chunk
|
||||
|
||||
async def get_async_custom_stream_wrapper(
|
||||
self,
|
||||
@ -531,17 +703,14 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
|
||||
client: Optional["AsyncHTTPHandler"] = None,
|
||||
json_mode: Optional[bool] = None,
|
||||
signed_json_body: Optional[bytes] = None,
|
||||
) -> CustomStreamWrapper:
|
||||
) -> "CustomStreamWrapper":
|
||||
"""
|
||||
Get a CustomStreamWrapper for asynchronous streaming.
|
||||
|
||||
This is called when stream=True is passed to acompletion().
|
||||
Simplified async streaming - returns an async generator that yields ModelResponse chunks.
|
||||
"""
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
|
||||
if client is None or not isinstance(client, AsyncHTTPHandler):
|
||||
client = get_async_httpx_client(
|
||||
@ -555,7 +724,7 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
|
||||
api_base,
|
||||
headers=headers,
|
||||
data=signed_json_body if signed_json_body else json.dumps(data),
|
||||
stream=True, # THIS IS KEY - tells httpx to not buffer
|
||||
stream=True,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
@ -564,18 +733,6 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
|
||||
status_code=response.status_code, message=str(await response.aread())
|
||||
)
|
||||
|
||||
# Create iterator for SSE stream
|
||||
completion_stream = self.get_streaming_response(
|
||||
model=model, raw_response=response
|
||||
)
|
||||
|
||||
streaming_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model=model,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
# LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
@ -584,7 +741,13 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
|
||||
additional_args={"complete_input_dict": data},
|
||||
)
|
||||
|
||||
return streaming_response
|
||||
# Wrap the async generator in CustomStreamWrapper
|
||||
return CustomStreamWrapper(
|
||||
completion_stream=self._stream_agentcore_response(response, model),
|
||||
model=model,
|
||||
custom_llm_provider="bedrock",
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
|
||||
@property
|
||||
def has_custom_stream_wrapper(self) -> bool:
|
||||
@ -692,4 +855,5 @@ class AmazonAgentCoreConfig(BaseConfig, BaseAWSLLM):
|
||||
stream: Optional[bool],
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
) -> bool:
|
||||
return True
|
||||
# AgentCore supports true streaming - don't buffer
|
||||
return False
|
||||
|
||||
@ -56,7 +56,7 @@ google-cloud-iam = {version = "^2.19.1", optional = true}
|
||||
resend = {version = ">=0.8.0", optional = true}
|
||||
pynacl = {version = "^1.5.0", optional = true}
|
||||
websockets = {version = "^15.0.1", optional = true}
|
||||
boto3 = {version = "1.36.0", optional = true}
|
||||
boto3 = { version = "1.40.76", optional = true }
|
||||
redisvl = {version = "^0.4.1", optional = true, markers = "python_version >= '3.9' and python_version < '3.14'"}
|
||||
mcp = {version = "^1.21.2", optional = true, python = ">=3.10"}
|
||||
litellm-proxy-extras = {version = "0.4.23", optional = true}
|
||||
|
||||
@ -10,7 +10,7 @@ uvicorn==0.31.1 # server dep
|
||||
gunicorn==23.0.0 # server dep
|
||||
fastuuid==0.13.5 # for uuid4
|
||||
uvloop==0.21.0 # uvicorn dep, gives us much better performance under load
|
||||
boto3==1.36.0 # aws bedrock/sagemaker calls
|
||||
boto3==1.40.53 # aws bedrock/sagemaker calls (has bedrock-agentcore-control, compatible with aioboto3)
|
||||
redis==5.2.1 # redis caching
|
||||
prisma==0.11.0 # for db
|
||||
nodejs-wheel-binaries==24.12.0 ## required by prisma for migrations, prevents runtime download (updated from nodejs-bin for security fixes)
|
||||
@ -58,8 +58,8 @@ tokenizers==0.20.2 # for calculating usage
|
||||
click==8.1.7 # for proxy cli
|
||||
rich==13.7.1 # for litellm proxy cli
|
||||
jinja2==3.1.6 # for prompt templates
|
||||
aioboto3==15.5.0 # for async sagemaker calls (updated to match boto3 1.40.73)
|
||||
aiohttp==3.13.3 # for network calls
|
||||
aioboto3==13.4.0 # for async sagemaker calls
|
||||
tenacity==8.5.0 # for retrying requests, when litellm.num_retries set
|
||||
pydantic>=2.11,<3 # proxy + openai req. + mcp
|
||||
jsonschema>=4.23.0,<5.0.0 # validating json schema - aligned with openapi-core + mcp
|
||||
|
||||
23
taplo.toml
Normal file
23
taplo.toml
Normal file
@ -0,0 +1,23 @@
|
||||
[formatting]
|
||||
|
||||
# Keep your table/key order as written (project, project.scripts, dependency-groups, tool.uv, ...)
|
||||
reorder_keys = false
|
||||
|
||||
# Force arrays to stay multiline once expanded
|
||||
array_auto_expand = true
|
||||
array_auto_collapse = false
|
||||
|
||||
# Keep nice spacing inside arrays and inline tables
|
||||
compact_arrays = false
|
||||
compact_inline_tables = true
|
||||
|
||||
# Align `=` vertically across consecutive entries
|
||||
align_entries = true
|
||||
|
||||
# Reasonable defaults
|
||||
align_comments = true
|
||||
trailing_newline = true
|
||||
reorder_arrays = true
|
||||
reorder_inline_tables = false
|
||||
allowed_blank_lines = 2
|
||||
crlf = false
|
||||
@ -12,10 +12,9 @@ sys.path.insert(
|
||||
)
|
||||
|
||||
import litellm
|
||||
from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
import pytest
|
||||
import httpx
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model", [
|
||||
@ -367,3 +366,279 @@ def test_bedrock_agentcore_without_api_key_uses_sigv4():
|
||||
assert "X-Amzn-Bedrock-AgentCore-Runtime-Session-Id" in headers
|
||||
assert headers["X-Amzn-Bedrock-AgentCore-Runtime-Session-Id"] == "sigv4-test-session"
|
||||
|
||||
|
||||
def test_agentcore_parse_json_response():
    """
    Unit test for JSON response parsing (non-streaming).

    Gives `_get_parsed_response` a mocked `application/json` response and
    checks the extracted content, the absent usage and the final message.
    """
    from litellm.llms.bedrock.chat.agentcore.transformation import AmazonAgentCoreConfig

    # Body the mocked response returns from .json()
    json_body = {
        "result": {
            "role": "assistant",
            "content": [{"text": "Hello from JSON response"}]
        }
    }

    json_response = Mock(spec=httpx.Response)
    json_response.headers = {"content-type": "application/json"}
    json_response.json.return_value = json_body

    parsed = AmazonAgentCoreConfig()._get_parsed_response(json_response)

    # Text is pulled out of the result's content list
    assert parsed["content"] == "Hello from JSON response"
    # JSON responses don't include usage data
    assert parsed["usage"] is None
    # The final message is the "result" object itself
    assert parsed["final_message"] == json_body["result"]
|
||||
|
||||
|
||||
def test_agentcore_parse_sse_response():
    """
    Unit test for SSE response parsing (streaming response consumed as text).

    Gives `_get_parsed_response` a mocked `text/event-stream` response holding
    two content deltas, a usage metadata event and a final message, then
    checks the content, usage and final message it returns.
    """
    from litellm.llms.bedrock.chat.agentcore.transformation import AmazonAgentCoreConfig

    # Event stream body: two deltas, usage metadata, final assistant message
    event_stream = """data: {"event":{"contentBlockDelta":{"delta":{"text":"Hello "}}}}

data: {"event":{"contentBlockDelta":{"delta":{"text":"from SSE"}}}}

data: {"event":{"metadata":{"usage":{"inputTokens":10,"outputTokens":5,"totalTokens":15}}}}

data: {"message":{"role":"assistant","content":[{"text":"Hello from SSE"}]}}
"""

    sse_response = Mock(spec=httpx.Response)
    sse_response.headers = {"content-type": "text/event-stream"}
    sse_response.text = event_stream

    parsed = AmazonAgentCoreConfig()._get_parsed_response(sse_response)

    # Content comes from the final message
    assert parsed["content"] == "Hello from SSE"

    # Usage is taken from the metadata event
    usage = parsed["usage"]
    assert usage is not None
    assert usage["inputTokens"] == 10
    assert usage["outputTokens"] == 5
    assert usage["totalTokens"] == 15

    # The final message is surfaced as well
    final_message = parsed["final_message"]
    assert final_message is not None
    assert final_message["role"] == "assistant"
|
||||
|
||||
|
||||
def test_agentcore_parse_sse_response_without_final_message():
    """
    Unit test for SSE response parsing when only deltas are present.

    With no final message in the stream, the content is expected to be the
    concatenation of the delta texts and `final_message` to be None.
    """
    from litellm.llms.bedrock.chat.agentcore.transformation import AmazonAgentCoreConfig

    # Event stream body made of content deltas only
    deltas_only_stream = """data: {"event":{"contentBlockDelta":{"delta":{"text":"First "}}}}

data: {"event":{"contentBlockDelta":{"delta":{"text":"second "}}}}

data: {"event":{"contentBlockDelta":{"delta":{"text":"third"}}}}
"""

    sse_response = Mock(spec=httpx.Response)
    sse_response.headers = {"content-type": "text/event-stream"}
    sse_response.text = deltas_only_stream

    parsed = AmazonAgentCoreConfig()._get_parsed_response(sse_response)

    # Deltas are joined in arrival order
    assert parsed["content"] == "First second third"
    # Nothing to report as a final message
    assert parsed["final_message"] is None
|
||||
|
||||
|
||||
def test_agentcore_transform_response_json():
    """
    Integration test for transform_response with a JSON response.

    Runs a mocked `application/json` response through `transform_response`
    and checks the single choice on the resulting ModelResponse.
    """
    from litellm.llms.bedrock.chat.agentcore.transformation import AmazonAgentCoreConfig
    from litellm.types.utils import ModelResponse

    # Mocked HTTP response carrying a JSON body
    json_response = Mock(spec=httpx.Response)
    json_response.status_code = 200
    json_response.headers = {"content-type": "application/json"}
    json_response.json.return_value = {
        "result": {
            "role": "assistant",
            "content": [{"text": "Response from transform_response"}]
        }
    }

    result = AmazonAgentCoreConfig().transform_response(
        model="bedrock/agentcore/arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/test",
        raw_response=json_response,
        model_response=ModelResponse(),
        logging_obj=MagicMock(),
        request_data={},
        messages=[{"role": "user", "content": "test"}],
        optional_params={},
        litellm_params={},
        encoding=None,
    )

    # Exactly one choice, carrying the assistant's text and a stop reason
    assert len(result.choices) == 1
    only_choice = result.choices[0]
    assert only_choice.message.content == "Response from transform_response"
    assert only_choice.message.role == "assistant"
    assert only_choice.finish_reason == "stop"
    assert only_choice.index == 0
|
||||
|
||||
|
||||
def test_agentcore_transform_response_sse():
    """
    Integration test for transform_response with an SSE response.

    Runs a mocked `text/event-stream` response through `transform_response`
    and checks both the single choice and the usage on the ModelResponse.
    """
    from litellm.llms.bedrock.chat.agentcore.transformation import AmazonAgentCoreConfig
    from litellm.types.utils import ModelResponse

    # Event stream body: two deltas, usage metadata, final assistant message
    event_stream = """data: {"event":{"contentBlockDelta":{"delta":{"text":"SSE "}}}}

data: {"event":{"contentBlockDelta":{"delta":{"text":"response"}}}}

data: {"event":{"metadata":{"usage":{"inputTokens":20,"outputTokens":10,"totalTokens":30}}}}

data: {"message":{"role":"assistant","content":[{"text":"SSE response"}]}}
"""

    sse_response = Mock(spec=httpx.Response)
    sse_response.status_code = 200
    sse_response.headers = {"content-type": "text/event-stream"}
    sse_response.text = event_stream

    result = AmazonAgentCoreConfig().transform_response(
        model="bedrock/agentcore/arn:aws:bedrock-agentcore:us-west-2:123456789012:runtime/test",
        raw_response=sse_response,
        model_response=ModelResponse(),
        logging_obj=MagicMock(),
        request_data={},
        messages=[{"role": "user", "content": "test"}],
        optional_params={},
        litellm_params={},
        encoding=None,
    )

    # Exactly one choice, carrying the assistant's text and a stop reason
    assert len(result.choices) == 1
    only_choice = result.choices[0]
    assert only_choice.message.content == "SSE response"
    assert only_choice.message.role == "assistant"
    assert only_choice.finish_reason == "stop"

    # Usage mirrors the SSE metadata event
    assert hasattr(result, "usage")
    assert result.usage.prompt_tokens == 20
    assert result.usage.completion_tokens == 10
    assert result.usage.total_tokens == 30
|
||||
|
||||
|
||||
def test_agentcore_synchronous_non_streaming_response():
    """
    Regression test for the non-streaming AgentCore path of litellm.completion.

    The ``post`` method of an ``HTTPHandler`` is patched to return a mocked
    JSON response, and ``litellm.completion`` is called with ``stream=False``
    using that handler as ``client``. The test requires that:

    1. a response with at least one choice comes back,
    2. the first choice carries the mocked text with role "assistant",
       finish reason "stop" and index 0,
    3. usage is present with positive prompt, completion and total token
       counts, even though the mocked payload contains no usage block.
    """
    from litellm.llms.custom_httpx.http_handler import HTTPHandler

    litellm._turn_on_debug()
    client = HTTPHandler()

    # JSON-typed stand-in for the HTTP response returned by client.post
    mock_response = Mock(spec=httpx.Response)
    mock_response.status_code = 200
    mock_response.headers = {"content-type": "application/json"}
    mock_response.json.return_value = {
        "result": {
            "role": "assistant",
            "content": [{"text": "This is a synchronous response from AgentCore."}]
        }
    }

    user_messages = [
        {
            "role": "user",
            "content": "Test synchronous response",
        }
    ]

    with patch.object(client, "post", return_value=mock_response):
        # Synchronous (non-streaming) completion call through the patched client
        response = litellm.completion(
            model="bedrock/agentcore/arn:aws:bedrock-agentcore:us-west-2:888602223428:runtime/hosted_agent_r9jvp-3ySZuRHjLC",
            messages=user_messages,
            stream=False,  # Explicitly disable streaming
            client=client,
        )

        # Response shape
        assert response is not None
        assert hasattr(response, "choices")
        assert len(response.choices) > 0

        # Message content
        message = response.choices[0].message
        assert message is not None
        assert message.content == "This is a synchronous response from AgentCore."
        assert message.role == "assistant"

        # Completion metadata
        assert response.choices[0].finish_reason == "stop"
        assert response.choices[0].index == 0

        # Usage must exist and be positive (either from API or calculated)
        assert hasattr(response, "usage")
        usage = response.usage
        assert usage is not None
        assert usage.prompt_tokens > 0
        assert usage.completion_tokens > 0
        assert usage.total_tokens > 0

        print(f"Synchronous response: {response}")
        print(f"Content: {message.content}")
        print(
            f"Usage: prompt={usage.prompt_tokens}, "
            f"completion={usage.completion_tokens}, "
            f"total={usage.total_tokens}"
        )
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user