fix(proxy): Bedrock Knowledge Base pass-through: preserve SigV4 headers and signed request body (#27526)

* Fix Bedrock KB pass-through SigV4 headers and signed body

Coerce botocore HeadersDict to a dict for pass-through routes. When
forward_headers is true, drop request headers that collide case-insensitively
with signed headers so client Bearer auth does not shadow AWS SigV4.
Send prepped.body as raw content so the outbound payload still matches the
signature even when logging hooks mutate the parsed dict.

Co-authored-by: Cursor <cursoragent@cursor.com>

* Simplify pass-through raw body handling

Read the SigV4-signed bytes directly from request.state inside
pass_through_request instead of threading a custom_raw_body argument
through three functions. Helper methods are restored to their original
signatures, and the new branch lives in one place at each httpx call site.

Co-authored-by: Cursor <cursoragent@cursor.com>

* Harden pass-through raw body read from request.state

Guard missing request.state (test fixtures) and ignore non-bytes/str
values so MagicMock does not trigger the SigV4 raw-body path.

Co-authored-by: Cursor <cursoragent@cursor.com>

* Test pass_through_request state_raw_body uses httpx content=

Cover non-streaming (async_client.request) and streaming (build_request)
paths so SigV4 bytes on request.state are not replaced by json= of a
hook-mutated dict.

Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
milan-berri 2026-05-25 16:51:55 +03:00 committed by GitHub
parent 4148667671
commit f45909cb81
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 246 additions and 14 deletions

View File

@ -71,6 +71,11 @@ class BasePassthroughUtils:
request_headers.pop("content-length", None)
request_headers.pop("host", None)
custom_header_names = {header_name.lower() for header_name in headers}
for header_name in list(request_headers.keys()):
if header_name.lower() in custom_header_names:
request_headers.pop(header_name, None)
# Combine request headers with custom headers
headers = {**request_headers, **headers}

View File

@ -44,6 +44,7 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
)
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY,
LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY,
)
from litellm.proxy.utils import is_known_model
from litellm.proxy.vector_store_endpoints.utils import (
@ -1123,6 +1124,9 @@ async def bedrock_proxy_route(
_forward_headers=True,
) # dynamically construct pass-through endpoint based on incoming path
setattr(request.state, LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY, data)
# SigV4 signs an exact payload; pass-through must send prepped.body, not json.dumps
# of a dict that hooks may mutate (logging_obj, metadata, etc.).
setattr(request.state, LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY, prepped.body)
received_value = await endpoint_func(
request,
fastapi_response,

View File

@ -6,7 +6,7 @@ import posixpath
import traceback
from base64 import b64encode
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union, cast
from urllib.parse import urlencode, urlparse
import httpx
@ -62,6 +62,7 @@ from litellm.types.llms.custom_http import httpxSpecialProvider
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
EndpointType,
LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY,
LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY,
PassthroughStandardLoggingPayload,
)
@ -735,6 +736,22 @@ async def pass_through_request( # noqa: PLR0915
str(url)
)
# SigV4-signed callers (e.g. Bedrock) attach the exact bytes that were
# signed via request.state; we must send those instead of re-encoding the
# parsed dict (hooks mutate it, breaking the signature / Content-Length).
# Tolerate request objects without `state` (test fixtures) and only honor
# values httpx accepts for `content=`.
_request_state = getattr(request, "state", None)
state_raw_body: Optional[Union[str, bytes]] = (
getattr(_request_state, LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY, None)
if _request_state is not None
else None
)
if state_raw_body is not None and not isinstance(
state_raw_body, (str, bytes, bytearray)
):
state_raw_body = None
# Skip body parsing for multipart requests - make_multipart_http_request will handle it
# But if custom_body is provided (e.g., JSON parsed despite multipart content-type), use it
is_multipart = (
@ -883,12 +900,19 @@ async def pass_through_request( # noqa: PLR0915
)
)
else:
# SigV4-signed callers (Bedrock) supply the exact pre-signed bytes;
# otherwise httpx encodes the parsed JSON dict as before.
body_kwargs: Dict[str, Any] = (
{"content": state_raw_body}
if state_raw_body is not None
else {"json": _parsed_body}
)
req = async_client.build_request(
"POST",
url,
json=_parsed_body,
params=requested_query_params,
headers=headers,
**body_kwargs,
)
response = await async_client.send(req, stream=stream)
@ -917,17 +941,28 @@ async def pass_through_request( # noqa: PLR0915
status_code=response.status_code,
)
response = (
await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler(
request=request,
async_client=async_client,
if state_raw_body is not None:
# SigV4-signed callers (Bedrock) require the exact pre-signed bytes
# to be forwarded so the signature/Content-Length stay valid.
response = await async_client.request(
method=request.method,
url=url,
headers=headers,
requested_query_params=requested_query_params,
_parsed_body=_parsed_body,
forward_multipart=is_multipart,
params=requested_query_params,
content=state_raw_body,
)
else:
response = (
await HttpPassThroughEndpointHelpers.non_streaming_http_request_handler(
request=request,
async_client=async_client,
url=url,
headers=headers,
requested_query_params=requested_query_params,
_parsed_body=_parsed_body,
forward_multipart=is_multipart,
)
)
)
verbose_proxy_logger.debug("response.headers= %s", response.headers)
if _is_streaming_response(response) is True:
@ -1225,7 +1260,7 @@ async def _parse_request_data_by_content_type(
def create_pass_through_route(
endpoint,
target: str,
custom_headers: Optional[dict] = None,
custom_headers: Optional[Mapping[str, Any]] = None,
_forward_headers: Optional[bool] = False,
_merge_query_params: Optional[bool] = False,
dependencies: Optional[List] = None,
@ -1335,9 +1370,12 @@ def create_pass_through_route(
)
)
# Ensure custom_headers is a dict
# Ensure custom_headers is a dict. Botocore returns a HeadersDict
# for SigV4-prepared requests, which is a Mapping but not a dict.
headers_dict = (
param_custom_headers if isinstance(param_custom_headers, dict) else {}
dict(param_custom_headers)
if isinstance(param_custom_headers, Mapping)
else {}
)
# Ensure query_params and custom_body are dicts or None
@ -1380,6 +1418,8 @@ def create_pass_through_route(
finally:
if hasattr(request.state, LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY):
delattr(request.state, LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY)
if hasattr(request.state, LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY):
delattr(request.state, LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY)
return endpoint_func

View File

@ -7,6 +7,10 @@ from typing_extensions import TypedDict
# JSON without a FastAPI `custom_body` parameter (which would consume the HTTP body).
LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY = "litellm_pass_through_custom_body"
# Request.state key for programmatic pass-through callers that must preserve an
# exact byte/string body, such as AWS SigV4-signed requests.
LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY = "litellm_pass_through_raw_body"
class EndpointType(str, Enum):
VERTEX_AI = "vertex-ai"

View File

@ -20,6 +20,9 @@ from litellm.proxy.pass_through_endpoints.pass_through_endpoints import (
LITELLM_PASS_THROUGH_CUSTOM_BODY_STATE_KEY,
pass_through_request,
)
from litellm.types.passthrough_endpoints.pass_through_endpoints import (
LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY,
)
from litellm.proxy.pass_through_endpoints.success_handler import (
PassThroughEndpointLogging,
)
@ -2153,7 +2156,12 @@ async def test_create_pass_through_route_custom_body_url_target():
endpoint_func = create_pass_through_route(
endpoint=unique_path,
target="https://bedrock-agent-runtime.us-east-1.amazonaws.com",
custom_headers={"Content-Type": "application/json"},
custom_headers=Headers(
{
"Authorization": "AWS4-HMAC-SHA256 signed",
"Content-Type": "application/json",
}
),
_forward_headers=True,
)
@ -2213,6 +2221,147 @@ async def test_create_pass_through_route_custom_body_url_target():
# The critical assertion: custom_body takes precedence over
# the body parsed from the raw request
assert call_kwargs["custom_body"] == bedrock_body
# HeadersDict-like custom_headers (e.g. botocore SigV4) must be coerced
# to a plain dict so signed headers actually reach the upstream.
assert call_kwargs["custom_headers"] == {
"authorization": "AWS4-HMAC-SHA256 signed",
"content-type": "application/json",
}
@pytest.mark.asyncio
async def test_pass_through_request_non_streaming_uses_content_for_state_raw_body():
    """
    Non-streaming pass-through must forward the raw body stored on request.state.

    The raw bytes stand in for a SigV4-signed payload. The patched pre_call_hook
    adds a key to the parsed dict on purpose, so a ``json=`` upload would no longer
    equal the signed bytes; the upstream call is expected to use ``content=``.
    """
    target_url = "https://bedrock-agent-runtime.us-east-1.amazonaws.com/knowledgebases/KB/retrieve"

    # Simulated signed bytes, plus the dict a JSON parser would produce from the wire.
    raw_signed = b'{"retrievalQuery":{"text":"signed"},"sig":"intact"}'
    parsed_from_wire = {"retrievalQuery": {"text": "signed"}, "sig": "intact"}
    wire_bytes = json.dumps(parsed_from_wire).encode("utf-8")

    # Incoming request whose state already carries the raw signed body.
    incoming = MagicMock(spec=Request)
    incoming.method = "POST"
    incoming.query_params = QueryParams({})
    incoming.headers = Headers({"Content-Type": "application/json"})
    incoming.state = SimpleNamespace(
        **{LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY: raw_signed}
    )
    incoming.body = AsyncMock(return_value=wire_bytes)

    caller = MagicMock()
    caller.api_key = "sk-test"

    # Canned upstream reply returned by the mocked httpx client.
    upstream_reply = httpx.Response(
        status_code=200,
        headers={"content-type": "application/json"},
        content=b'{"ok": true}',
        request=httpx.Request("POST", target_url),
    )
    http_client = AsyncMock()
    http_client.request = AsyncMock(return_value=upstream_reply)
    client_wrapper = MagicMock()
    client_wrapper.client = http_client

    async def _hook_mutates_body(**kwargs):
        # Mimic a hook that decorates the parsed body in place.
        payload = kwargs["data"]
        if isinstance(payload, dict):
            payload["hook_mutated"] = True
        return payload

    with (
        patch(
            "litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_async_httpx_client",
            return_value=client_wrapper,
        ),
        patch(
            "litellm.proxy.proxy_server.proxy_logging_obj.pre_call_hook",
            new=AsyncMock(side_effect=_hook_mutates_body),
        ),
        patch(
            "litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_endpoint_logging.pass_through_async_success_handler",
            new=AsyncMock(),
        ),
    ):
        await pass_through_request(
            request=incoming,
            target=target_url,
            custom_headers={"content-type": "application/json"},
            user_api_key_dict=caller,
            stream=False,
        )

    # Exactly one upstream call, carrying the raw bytes and no json= payload.
    http_client.request.assert_called_once()
    sent_kwargs = http_client.request.call_args.kwargs
    assert sent_kwargs.get("content") == raw_signed
    assert "json" not in sent_kwargs
@pytest.mark.asyncio
async def test_pass_through_request_streaming_uses_content_for_state_raw_body():
    """
    Streaming pass-through must hand the request.state raw body to build_request.

    The mocked upstream answers with ``text/event-stream``; the test then checks
    that ``build_request`` was called with ``content=<raw bytes>`` and no ``json=``.
    """
    from fastapi.responses import StreamingResponse

    target_url = "https://example.com/v1/messages"
    raw_signed = b'{"model":"m","stream":true}'
    parsed_from_wire = {"model": "m", "stream": True}
    wire_bytes = json.dumps(parsed_from_wire).encode("utf-8")

    # Incoming request whose state already carries the raw body.
    incoming = MagicMock(spec=Request)
    incoming.method = "POST"
    incoming.query_params = QueryParams({})
    incoming.headers = Headers({"Content-Type": "application/json"})
    incoming.state = SimpleNamespace(
        **{LITELLM_PASS_THROUGH_RAW_BODY_STATE_KEY: raw_signed}
    )
    incoming.body = AsyncMock(return_value=wire_bytes)

    caller = MagicMock()
    caller.api_key = "sk-test"

    # httpx client double: build_request is synchronous, send is awaited.
    built_request = MagicMock()
    sse_reply = httpx.Response(
        status_code=200,
        headers={"content-type": "text/event-stream"},
        content=b"data: {}\n\n",
        request=httpx.Request("POST", "https://example.com/v1/messages"),
    )
    http_client = AsyncMock()
    http_client.build_request = MagicMock(return_value=built_request)
    http_client.send = AsyncMock(return_value=sse_reply)
    client_wrapper = MagicMock()
    client_wrapper.client = http_client

    with (
        patch(
            "litellm.proxy.pass_through_endpoints.pass_through_endpoints.get_async_httpx_client",
            return_value=client_wrapper,
        ),
        patch(
            "litellm.proxy.proxy_server.proxy_logging_obj.pre_call_hook",
            new=AsyncMock(side_effect=lambda **hook_kwargs: hook_kwargs["data"]),
        ),
        patch(
            "litellm.proxy.pass_through_endpoints.pass_through_endpoints.pass_through_endpoint_logging.pass_through_async_success_handler",
            new=AsyncMock(),
        ),
    ):
        response = await pass_through_request(
            request=incoming,
            target=target_url,
            custom_headers={"Authorization": "Bearer x"},
            user_api_key_dict=caller,
            stream=None,
        )

    assert isinstance(response, StreamingResponse)

    # The streaming branch builds the request itself; inspect what it was given.
    http_client.build_request.assert_called_once()
    build_kwargs = http_client.build_request.call_args.kwargs
    assert build_kwargs.get("content") == raw_signed
    assert "json" not in build_kwargs
@pytest.mark.asyncio

View File

@ -538,6 +538,36 @@ def test_forward_headers_from_request_protected_headers_not_overwritten():
assert "Anthropic-Beta" not in result
def test_forward_headers_custom_wins_case_insensitive_over_request_authorization():
    """
    Custom (provider-signed) headers must override forwarded request headers
    even when the two sides spell the same header name with different casing.
    """
    from litellm.passthrough.utils import BasePassthroughUtils

    signed_headers = {
        "Authorization": "AWS4-HMAC-SHA256 signed",
        "Content-Type": "application/json",
    }
    incoming_headers = {
        "authorization": "Bearer sk-litellm-key",
        "content-type": "application/json",
        "x-request-id": "req-123",
    }

    merged = BasePassthroughUtils.forward_headers_from_request(
        request_headers=incoming_headers,
        headers=dict(signed_headers),
        forward_headers=True,
    )

    # Signed values survive under their original casing ...
    assert merged["Authorization"] == "AWS4-HMAC-SHA256 signed"
    assert "authorization" not in merged
    assert merged["Content-Type"] == "application/json"
    assert "content-type" not in merged
    # ... and request headers that do not collide are still forwarded.
    assert merged["x-request-id"] == "req-123"
@pytest.mark.asyncio
async def test_vertex_passthrough_custom_model_name_replaced_in_url():
"""