fix(proxy): Bedrock Knowledge Base pass-through: preserve SigV4 headers and signed request body (#27526)
* Fix Bedrock KB pass-through SigV4 headers and signed body Coerce botocore HeadersDict to a dict for pass-through routes. When forward_headers is true, drop request headers that collide case-insensitively with signed headers so client Bearer auth does not shadow AWS SigV4. Send prepped.body as raw content so the outbound payload matches the signature after logging hooks mutate the parsed dict. Co-authored-by: Cursor <cursoragent@cursor.com> * Simplify pass-through raw body handling Read the SigV4-signed bytes directly from request.state inside pass_through_request instead of threading a custom_raw_body argument through three functions. Helper methods are restored to their original signatures, and the new branch lives in one place at each httpx call site. Co-authored-by: Cursor <cursoragent@cursor.com> * Harden pass-through raw body read from request.state Guard missing request.state (test fixtures) and ignore non-bytes/str values so MagicMock does not trigger the SigV4 raw-body path. Co-authored-by: Cursor <cursoragent@cursor.com> * Test pass_through_request state_raw_body uses httpx content= Cover non-streaming (async_client.request) and streaming (build_request) paths so SigV4 bytes on request.state are not replaced by json= of a hook-mutated dict. Co-authored-by: Cursor <cursoragent@cursor.com> --------- Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
4148667671
commit
f45909cb81
@ -71,6 +71,11 @@ class BasePassthroughUtils:
|
||||
request_headers.pop("content-length", None)
|
||||
request_headers.pop("host", None)
|
||||
|
||||
custom_header_names = {header_name.lower() for header_name in headers}
|
||||
for header_name in list(request_headers.keys()):
|
||||
if header_name.lower() in custom_header_names:
|
||||
request_headers.pop(header_name, None)
|
||||
|
||||
# Combine request headers with custom headers
|
||||
headers = {**request_headers, **headers}
|
||||
|
||||
|
||||
@ -44,6 +44,7 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
|
||||
)
|
||||
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
|
||||
LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY,
|
||||
LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY,
|
||||
)
|
||||
from litellm.proxy.utils import is_known_model
|
||||
from litellm.proxy.vector_store_endpoints.utils import (
|
||||
@ -1123,6 +1124,9 @@ async def bedrock_proxy_route(
|
||||
_forward_headers=True,
|
||||
) # dynamically construct pass-through endpoint based on incoming path
|
||||
setattr(request.state, LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY, data)
|
||||
# SigV4 signs an exact payload; pass-through must send prepped.body, not json.dumps
|
||||
# of a dict that hooks may mutate (logging_obj, metadata, etc.).
|
||||
setattr(request.state, LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY, prepped.body)
|
||||
received_value = await endpoint_func(
|
||||
request,
|
||||
fastapi_response,
|
||||
|
||||
@ -6,7 +6,7 @@ import posixpath
|
||||
import traceback
|
||||
from base64 import b64encode
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union, cast
|
||||
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union, cast
|
||||
from urllib.parse import urlencode, urlparse
|
||||
|
||||
import httpx
|
||||
@ -62,6 +62,7 @@ from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
|
||||
EndpointType,
|
||||
LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY,
|
||||
LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY,
|
||||
PassthroughStandardLoggingPayload,
|
||||
)
|
||||
|
||||
@ -735,6 +736,22 @@ async def pass_through_request( # noqa: PLR0915
|
||||
str(url)
|
||||
)
|
||||
|
||||
# SigV4-signed callers (e.g. Bedrock) attach the exact bytes that were
|
||||
# signed via request.state; we must send those instead of re-encoding the
|
||||
# parsed dict (hooks mutate it, breaking the signature / Content-Length).
|
||||
# Tolerate request objects without `state` (test fixtures) and only honor
|
||||
# values httpx accepts for `content=`.
|
||||
_request_state = getattr(request, "state", None)
|
||||
state_raw_body: Optional[Union[str, bytes]] = (
|
||||
getattr(_request_state, LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY, None)
|
||||
if _request_state is not None
|
||||
else None
|
||||
)
|
||||
if state_raw_body is not None and not isinstance(
|
||||
state_raw_body, (str, bytes, bytearray)
|
||||
):
|
||||
state_raw_body = None
|
||||
|
||||
# Skip body parsing for multipart requests - make_multipart_http_request will handle it
|
||||
# But if custom_body is provided (e.g., JSON parsed despite multipart content-type), use it
|
||||
is_multipart = (
|
||||
@ -883,12 +900,19 @@ async def pass_through_request( # noqa: PLR0915
|
||||
)
|
||||
)
|
||||
else:
|
||||
# SigV4-signed callers (Bedrock) supply the exact pre-signed bytes;
|
||||
# otherwise httpx encodes the parsed JSON dict as before.
|
||||
body_kwargs: Dict[str, Any] = (
|
||||
{"content": state_raw_body}
|
||||
if state_raw_body is not None
|
||||
else {"json": _parsed_body}
|
||||
)
|
||||
req = async_client.build_request(
|
||||
"POST",
|
||||
url,
|
||||
json=_parsed_body,
|
||||
params=requested_query_params,
|
||||
headers=headers,
|
||||
**body_kwargs,
|
||||
)
|
||||
|
||||
response = await async_client.send(req, stream=stream)
|
||||
@ -917,17 +941,28 @@ async def pass_through_request( # noqa: PLR0915
|
||||
status_code=response.status_code,
|
||||
)
|
||||
|
||||
response = (
|
||||
await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler(
|
||||
request=request,
|
||||
async_client=async_client,
|
||||
if state_raw_body is not None:
|
||||
# SigV4-signed callers (Bedrock) require the exact pre-signed bytes
|
||||
# to be forwarded so the signature/Content-Length stay valid.
|
||||
response = await async_client.request(
|
||||
method=request.method,
|
||||
url=url,
|
||||
headers=headers,
|
||||
requested_query_params=requested_query_params,
|
||||
_parsed_body=_parsed_body,
|
||||
forward_multipart=is_multipart,
|
||||
params=requested_query_params,
|
||||
content=state_raw_body,
|
||||
)
|
||||
else:
|
||||
response = (
|
||||
await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler(
|
||||
request=request,
|
||||
async_client=async_client,
|
||||
url=url,
|
||||
headers=headers,
|
||||
requested_query_params=requested_query_params,
|
||||
_parsed_body=_parsed_body,
|
||||
forward_multipart=is_multipart,
|
||||
)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug("response.headers= %s", response.headers)
|
||||
|
||||
if _is_streaming_response(response) is True:
|
||||
@ -1225,7 +1260,7 @@ async def _parse_request_data_by_content_type(
|
||||
def create_pass_through_route(
|
||||
endpoint,
|
||||
target: str,
|
||||
custom_headers: Optional[dict] = None,
|
||||
custom_headers: Optional[Mapping[str, Any]] = None,
|
||||
_forward_headers: Optional[bool] = False,
|
||||
_merge_query_params: Optional[bool] = False,
|
||||
dependencies: Optional[List] = None,
|
||||
@ -1335,9 +1370,12 @@ def create_pass_through_route(
|
||||
)
|
||||
)
|
||||
|
||||
# Ensure custom_headers is a dict
|
||||
# Ensure custom_headers is a dict. Botocore returns a HeadersDict
|
||||
# for SigV4-prepared requests, which is a Mapping but not a dict.
|
||||
headers_dict = (
|
||||
param_custom_headers if isinstance(param_custom_headers, dict) else {}
|
||||
dict(param_custom_headers)
|
||||
if isinstance(param_custom_headers, Mapping)
|
||||
else {}
|
||||
)
|
||||
|
||||
# Ensure query_params and custom_body are dicts or None
|
||||
@ -1380,6 +1418,8 @@ def create_pass_through_route(
|
||||
finally:
|
||||
if hasattr(request.state, LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY):
|
||||
delattr(request.state, LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY)
|
||||
if hasattr(request.state, LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY):
|
||||
delattr(request.state, LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY)
|
||||
|
||||
return endpoint_func
|
||||
|
||||
|
||||
@ -7,6 +7,10 @@ from typing_extensions import TypedDict
|
||||
# JSON without a FastAPI `custom_body` parameter (which would consume the HTTP body).
|
||||
LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY = "litellm_pass_through_custom_body"
|
||||
|
||||
# Request.state key for programmatic pass-through callers that must preserve an
|
||||
# exact byte/string body, such as AWS SigV4-signed requests.
|
||||
LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY = "litellm_pass_through_raw_body"
|
||||
|
||||
|
||||
class EndpointType(str, Enum):
|
||||
VERTEX_AI = "vertex-ai"
|
||||
|
||||
@ -20,6 +20,9 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
|
||||
LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY,
|
||||
pass_through_request,
|
||||
)
|
||||
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
|
||||
LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY,
|
||||
)
|
||||
from litellm.proxy.pass_through_endpoints.success_handler import (
|
||||
PassThroughEndpointLogging,
|
||||
)
|
||||
@ -2153,7 +2156,12 @@ async def test_create_pass_through_route_custom_body_url_target():
|
||||
endpoint_func = create_pass_through_route(
|
||||
endpoint=unique_path,
|
||||
target="https://bedrock-agent-runtime.us-east-1.amazonaws.com",
|
||||
custom_headers={"Content-Type": "application/json"},
|
||||
custom_headers=Headers(
|
||||
{
|
||||
"Authorization": "AWS4-HMAC-SHA256 signed",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
),
|
||||
_forward_headers=True,
|
||||
)
|
||||
|
||||
@ -2213,6 +2221,147 @@ async def test_create_pass_through_route_custom_body_url_target():
|
||||
# The critical assertion: custom_body takes precedence over
|
||||
# the body parsed from the raw request
|
||||
assert call_kwargs["custom_body"] == bedrock_body
|
||||
# HeadersDict-like custom_headers (e.g. botocore SigV4) must be coerced
|
||||
# to a plain dict so signed headers actually reach the upstream.
|
||||
assert call_kwargs["custom_headers"] == {
|
||||
"authorization": "AWS4-HMAC-SHA256 signed",
|
||||
"content-type": "application/json",
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pass_through_request_non_streaming_uses_content_for_state_raw_body():
|
||||
"""
|
||||
Bedrock SigV4 path: exact signed bytes live on request.state; upstream must receive
|
||||
content=... even if pre_call_hook mutates the parsed dict (would change json=).
|
||||
"""
|
||||
# Bytes that were signed (simulated); parsed body + hook will diverge on purpose.
|
||||
raw_signed = b'{"retrievalQuery":{"text":"signed"},"sig":"intact"}'
|
||||
parsed_from_wire = {"retrievalQuery": {"text": "signed"}, "sig": "intact"}
|
||||
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.method = "POST"
|
||||
mock_request.query_params = QueryParams({})
|
||||
mock_request.headers = Headers({"Content-Type": "application/json"})
|
||||
mock_request.state = SimpleNamespace()
|
||||
setattr(mock_request.state, LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY, raw_signed)
|
||||
mock_request.body = AsyncMock(
|
||||
return_value=json.dumps(parsed_from_wire).encode("utf-8")
|
||||
)
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.api_key = "sk-test"
|
||||
|
||||
upstream = httpx.Response(
|
||||
status_code=200,
|
||||
headers={"content-type": "application/json"},
|
||||
content=b'{"ok": true}',
|
||||
request=httpx.Request(
|
||||
"POST",
|
||||
"https://bedrock-agent-runtime.us-east-1.amazonaws.com/knowledgebases/KB/retrieve",
|
||||
),
|
||||
)
|
||||
|
||||
mock_async_client = AsyncMock()
|
||||
mock_async_client.request = AsyncMock(return_value=upstream)
|
||||
mock_client_obj = MagicMock()
|
||||
mock_client_obj.client = mock_async_client
|
||||
|
||||
async def _hook_mutates_body(**kwargs):
|
||||
data = kwargs["data"]
|
||||
if isinstance(data, dict):
|
||||
data["hook_mutated"] = True
|
||||
return data
|
||||
|
||||
with (
|
||||
patch(
|
||||
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_async_httpx_client",
|
||||
return_value=mock_client_obj,
|
||||
),
|
||||
patch(
|
||||
"litellm.proxy.proxy_server.proxy_logging_obj.pre_call_hook",
|
||||
new=AsyncMock(side_effect=_hook_mutates_body),
|
||||
),
|
||||
patch(
|
||||
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_endpoint_logging.pass_through_async_success_handler",
|
||||
new=AsyncMock(),
|
||||
),
|
||||
):
|
||||
await pass_through_request(
|
||||
request=mock_request,
|
||||
target="https://bedrock-agent-runtime.us-east-1.amazonaws.com/knowledgebases/KB/retrieve",
|
||||
custom_headers={"content-type": "application/json"},
|
||||
user_api_key_dict=mock_user,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
mock_async_client.request.assert_called_once()
|
||||
req_kw = mock_async_client.request.call_args[1]
|
||||
assert req_kw.get("content") == raw_signed
|
||||
assert "json" not in req_kw
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pass_through_request_streaming_uses_content_for_state_raw_body():
|
||||
"""Streaming pass-through with state raw body must use build_request(..., content=...)."""
|
||||
raw_signed = b'{"model":"m","stream":true}'
|
||||
parsed_from_wire = {"model": "m", "stream": True}
|
||||
|
||||
mock_request = MagicMock(spec=Request)
|
||||
mock_request.method = "POST"
|
||||
mock_request.query_params = QueryParams({})
|
||||
mock_request.headers = Headers({"Content-Type": "application/json"})
|
||||
mock_request.state = SimpleNamespace()
|
||||
setattr(mock_request.state, LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY, raw_signed)
|
||||
mock_request.body = AsyncMock(
|
||||
return_value=json.dumps(parsed_from_wire).encode("utf-8")
|
||||
)
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.api_key = "sk-test"
|
||||
|
||||
mock_built = MagicMock()
|
||||
mock_async_client = AsyncMock()
|
||||
mock_async_client.build_request = MagicMock(return_value=mock_built)
|
||||
stream_resp = httpx.Response(
|
||||
status_code=200,
|
||||
headers={"content-type": "text/event-stream"},
|
||||
content=b"data: {}\n\n",
|
||||
request=httpx.Request("POST", "https://example.com/v1/messages"),
|
||||
)
|
||||
mock_async_client.send = AsyncMock(return_value=stream_resp)
|
||||
mock_client_obj = MagicMock()
|
||||
mock_client_obj.client = mock_async_client
|
||||
|
||||
with (
|
||||
patch(
|
||||
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_async_httpx_client",
|
||||
return_value=mock_client_obj,
|
||||
),
|
||||
patch(
|
||||
"litellm.proxy.proxy_server.proxy_logging_obj.pre_call_hook",
|
||||
new=AsyncMock(side_effect=lambda **kw: kw["data"]),
|
||||
),
|
||||
patch(
|
||||
"litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_endpoint_logging.pass_through_async_success_handler",
|
||||
new=AsyncMock(),
|
||||
),
|
||||
):
|
||||
response = await pass_through_request(
|
||||
request=mock_request,
|
||||
target="https://example.com/v1/messages",
|
||||
custom_headers={"Authorization": "Bearer x"},
|
||||
user_api_key_dict=mock_user,
|
||||
stream=None,
|
||||
)
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
assert isinstance(response, StreamingResponse)
|
||||
mock_async_client.build_request.assert_called_once()
|
||||
br_kw = mock_async_client.build_request.call_args[1]
|
||||
assert br_kw.get("content") == raw_signed
|
||||
assert "json" not in br_kw
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
||||
@ -538,6 +538,36 @@ def test_forward_headers_from_request_protected_headers_not_overwritten():
|
||||
assert "Anthropic-Beta" not in result
|
||||
|
||||
|
||||
def test_forward_headers_custom_wins_case_insensitive_over_request_authorization():
|
||||
"""
|
||||
When forwarding request headers, provider-signed/custom headers must win
|
||||
even if the incoming request uses a different case for the same header name.
|
||||
"""
|
||||
from litellm.passthrough.utils import BasePassthroughUtils
|
||||
|
||||
request_headers = {
|
||||
"authorization": "Bearer sk-litellm-key",
|
||||
"content-type": "application/json",
|
||||
"x-request-id": "req-123",
|
||||
}
|
||||
signed_headers = {
|
||||
"Authorization": "AWS4-HMAC-SHA256 signed",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
result = BasePassthroughUtils.forward_headers_from_request(
|
||||
request_headers=request_headers,
|
||||
headers=signed_headers.copy(),
|
||||
forward_headers=True,
|
||||
)
|
||||
|
||||
assert result["Authorization"] == "AWS4-HMAC-SHA256 signed"
|
||||
assert "authorization" not in result
|
||||
assert result["Content-Type"] == "application/json"
|
||||
assert "content-type" not in result
|
||||
assert result["x-request-id"] == "req-123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vertex_passthrough_custom_model_name_replaced_in_url():
|
||||
"""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user