chore(security): close two unaddressed SSRF cases

Two SSRF findings were OPEN with no in-flight fix; both are closed
now using narrow defenses that key off existing trust boundaries.

VERIA-6 (Milvus ``litellm_embedding_config``):
``is_request_body_safe`` already blocks ``api_base`` / ``api_key`` /
``langfuse_host`` / ``s3_endpoint_url`` / etc. at the *root* of the
request body, gated by an admin opt-in (``allow_client_side_credentials``
or per-deployment ``configurable_clientside_auth_params``). The bug is
that the Milvus vector-store transformer unpacks
``litellm_embedding_config`` into ``litellm.embedding(**embedding_config)``,
so a caller can smuggle the same banned params in via nesting and bypass
the check. Fix: ``is_request_body_safe`` now recurses into a known list
of nested-config dicts (``litellm_embedding_config`` for now) and applies
the same banned-param check with the same admin opt-in. Admin-side
vector-store config flows through ``litellm_params`` rather than the
request body, so it's unaffected.

VERIA-51 (polling URLs returned by upstream APIs):
Azure DALL-E 2, Azure Document Intelligence, and Black Forest Labs
all blindly fetched a polling URL returned by the upstream and
attached the operator's API key to the request. A compromised upstream
or a future API contract change could redirect credentials anywhere.
New ``url_utils.assert_same_origin(candidate, expected)`` helper checks
scheme, host (case-insensitive), and port (with default-port
normalization). Applied at all five polling sites: Azure DALL-E
sync+async, Azure DI sync+async, BFL image generation sync+async, BFL
image edit sync+async. Cross-origin polling URLs now raise rather than
forward credentials. The Azure DALL-E ``Expected 'status' in response``
exception no longer reflects the raw response body — that path turned
Blind SSRF into Full-Read SSRF for the limited window before the
origin check fully closed it.

Tests: 7 ``assert_same_origin`` unit tests, 6 ``is_request_body_safe``
nested-config tests, 5 polling-site rejection tests + 1 same-origin
sanity check.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
user 2026-05-01 18:43:47 +00:00
parent 934ecdca78
commit 0d4875dec9
No known key found for this signature in database
43 changed files with 493 additions and 4 deletions

View File

@ -199,6 +199,49 @@ def validate_url(url: str) -> Tuple[str, str]:
return rewritten, host_header
def assert_same_origin(candidate_url: str, expected_url: str) -> None:
"""Verify ``candidate_url`` shares scheme, host, and port with ``expected_url``.
Use when an upstream API returns a URL meant for follow-up requests
(e.g. an async-job polling URL that will be hit with the operator's
API key in the headers). The upstream is trusted because the operator
configured ``api_base``, but the URL it hands back must actually point
back at the same origin or we'd be blindly forwarding credentials
wherever the upstream told us to.
Hostnames are compared case-insensitively. Default ports are made
explicit (HTTP80, HTTPS443) so ``https://api.example.com:443/...``
and ``https://api.example.com/...`` are treated as the same origin.
"""
candidate = urlparse(candidate_url)
expected = urlparse(expected_url)
if candidate.scheme not in _ALLOWED_SCHEMES:
raise SSRFError(f"URL scheme '{candidate.scheme}' is not allowed")
if candidate.scheme != expected.scheme:
raise SSRFError(
"Origin mismatch: scheme "
f"{candidate.scheme!r} != expected {expected.scheme!r}"
)
candidate_host = _normalize_host(candidate.hostname or "")
expected_host = _normalize_host(expected.hostname or "")
if not candidate_host or candidate_host != expected_host:
raise SSRFError(
"Origin mismatch: host "
f"{candidate.hostname!r} != expected {expected.hostname!r}"
)
default_port = 443 if candidate.scheme == "https" else 80
candidate_port = candidate.port if candidate.port is not None else default_port
expected_port = expected.port if expected.port is not None else default_port
if candidate_port != expected_port:
raise SSRFError(
"Origin mismatch: port " f"{candidate_port} != expected {expected_port}"
)
_MAX_REDIRECTS = 10

View File

@ -16,6 +16,7 @@ import litellm
from litellm.constants import AZURE_OPERATION_POLLING_TIMEOUT, DEFAULT_MAX_RETRIES
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@ -898,6 +899,17 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
operation_location_url = response.headers["operation-location"]
else:
raise AzureOpenAIError(status_code=500, message=response.text)
# Reject polling URLs that don't share an origin with ``api_base``.
# Without this an upstream-controlled or attacker-controlled
# value would receive the operator's Azure API key in the
# request headers below. VERIA-51.
try:
assert_same_origin(operation_location_url, api_base)
except SSRFError as ssrf_err:
raise AzureOpenAIError(
status_code=502,
message=f"Rejected polling URL: {ssrf_err}",
)
response = await async_handler.get(
url=operation_location_url,
headers=headers,
@ -908,8 +920,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
timeout_secs: int = AZURE_OPERATION_POLLING_TIMEOUT
start_time = time.time()
if "status" not in response.json():
raise Exception(
"Expected 'status' in response. Got={}".format(response.json())
# Don't reflect the raw response body — when the polling
# URL points at an internal JSON API (cloud metadata
# service etc.) reflecting it here turns Blind SSRF into
# Full-Read SSRF. VERIA-51.
raise AzureOpenAIError(
status_code=502,
message="Polling response missing 'status' field",
)
while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs:
@ -1009,6 +1026,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
operation_location_url = response.headers["operation-location"]
else:
raise AzureOpenAIError(status_code=500, message=response.text)
try:
assert_same_origin(operation_location_url, api_base)
except SSRFError as ssrf_err:
raise AzureOpenAIError(
status_code=502,
message=f"Rejected polling URL: {ssrf_err}",
)
response = sync_handler.get(
url=operation_location_url,
headers=headers,
@ -1019,8 +1043,9 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
timeout_secs: int = AZURE_OPERATION_POLLING_TIMEOUT
start_time = time.time()
if "status" not in response.json():
raise Exception(
"Expected 'status' in response. Got={}".format(response.json())
raise AzureOpenAIError(
status_code=502,
message="Polling response missing 'status' field",
)
while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs:

View File

@ -17,6 +17,7 @@ from urllib.parse import quote
import httpx
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin
from litellm.constants import (
AZURE_DOCUMENT_INTELLIGENCE_API_VERSION,
AZURE_DOCUMENT_INTELLIGENCE_DEFAULT_DPI,
@ -599,6 +600,16 @@ class AzureDocumentIntelligenceOCRConfig(BaseOCRConfig):
"Azure Document Intelligence returned 202 but no Operation-Location header found"
)
# Reject cross-origin polling URLs — the auth headers
# below would otherwise leak to whatever URL the upstream
# (or an attacker-controlled upstream) returns. VERIA-51.
try:
assert_same_origin(operation_url, str(raw_response.request.url))
except SSRFError as ssrf_err:
raise ValueError(
f"Azure Document Intelligence: rejected polling URL ({ssrf_err})"
)
# Get headers for polling (need auth)
poll_headers = {
"Ocp-Apim-Subscription-Key": raw_response.request.headers.get(
@ -711,6 +722,14 @@ class AzureDocumentIntelligenceOCRConfig(BaseOCRConfig):
"Azure Document Intelligence returned 202 but no Operation-Location header found"
)
# Reject cross-origin polling URLs (see sync path). VERIA-51.
try:
assert_same_origin(operation_url, str(raw_response.request.url))
except SSRFError as ssrf_err:
raise ValueError(
f"Azure Document Intelligence: rejected polling URL ({ssrf_err})"
)
# Get headers for polling (need auth)
poll_headers = {
"Ocp-Apim-Subscription-Key": raw_response.request.headers.get(

View File

@ -15,6 +15,7 @@ import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@ -331,6 +332,17 @@ class BlackForestLabsImageEdit:
message="No polling_url in BFL response",
)
# Reject cross-origin polling URLs — the ``x-key`` auth header
# would otherwise leak to whatever URL the upstream returns.
# VERIA-51.
try:
assert_same_origin(polling_url, str(initial_response.request.url))
except SSRFError as ssrf_err:
raise BlackForestLabsError(
status_code=502,
message=f"Rejected polling URL: {ssrf_err}",
)
# Get just the auth header for polling
polling_headers = {"x-key": headers.get("x-key", "")}
@ -416,6 +428,17 @@ class BlackForestLabsImageEdit:
message="No polling_url in BFL response",
)
# Reject cross-origin polling URLs — the ``x-key`` auth header
# would otherwise leak to whatever URL the upstream returns.
# VERIA-51.
try:
assert_same_origin(polling_url, str(initial_response.request.url))
except SSRFError as ssrf_err:
raise BlackForestLabsError(
status_code=502,
message=f"Rejected polling URL: {ssrf_err}",
)
# Get just the auth header for polling
polling_headers = {"x-key": headers.get("x-key", "")}

View File

@ -15,6 +15,7 @@ import httpx
import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
HTTPHandler,
@ -317,6 +318,17 @@ class BlackForestLabsImageGeneration:
message="No polling_url in BFL response",
)
# Reject cross-origin polling URLs — the ``x-key`` auth header
# would otherwise leak to whatever URL the upstream returns.
# VERIA-51.
try:
assert_same_origin(polling_url, str(initial_response.request.url))
except SSRFError as ssrf_err:
raise BlackForestLabsError(
status_code=502,
message=f"Rejected polling URL: {ssrf_err}",
)
# Get just the auth header for polling
polling_headers = {"x-key": headers.get("x-key", "")}
@ -402,6 +414,17 @@ class BlackForestLabsImageGeneration:
message="No polling_url in BFL response",
)
# Reject cross-origin polling URLs — the ``x-key`` auth header
# would otherwise leak to whatever URL the upstream returns.
# VERIA-51.
try:
assert_same_origin(polling_url, str(initial_response.request.url))
except SSRFError as ssrf_err:
raise BlackForestLabsError(
status_code=502,
message=f"Rejected polling URL: {ssrf_err}",
)
# Get just the auth header for polling
polling_headers = {"x-key": headers.get("x-key", "")}

View File

@ -241,9 +241,30 @@ def is_request_body_safe(
"Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
)
# Recurse into nested config dicts whose values get unpacked as
# ``**kwargs`` into outbound API calls — same SSRF / credential
# exfil surface as the root, but historically not covered by this
# banned-param check. VERIA-6.
for nested_key in _NESTED_CONFIG_KEYS:
nested = request_body.get(nested_key)
if isinstance(nested, dict):
is_request_body_safe(
request_body=nested,
general_settings=general_settings,
llm_router=llm_router,
model=model,
)
return True
# Config dicts whose entries are spread as ``**dict`` into outbound LLM
# API calls. ``litellm_embedding_config`` is consumed by the Milvus
# vector store transformer; future nested-config keys with the same
# threat shape should be added here.
_NESTED_CONFIG_KEYS: Tuple[str, ...] = ("litellm_embedding_config",)
async def pre_db_read_auth_checks(
request: Request,
request_data: dict,

View File

@ -394,3 +394,61 @@ class TestHostAllowlist:
monkeypatch.setattr(url_utils.socket, "getaddrinfo", fake)
validate_url("http://internal.corp/")
# ── assert_same_origin ────────────────────────────────────────────────────────
from litellm.litellm_core_utils.url_utils import assert_same_origin
def test_assert_same_origin_matches_scheme_host_port():
"""A polling URL on the same scheme + host + port as the api_base
passes the upstream is trusted; the URL it returned points back at
the same upstream."""
assert_same_origin(
"https://api.example.com/v1/operations/abc",
"https://api.example.com/v1/generate",
)
def test_assert_same_origin_treats_default_ports_as_explicit():
"""``https://x/`` and ``https://x:443/`` are the same origin."""
assert_same_origin("https://api.example.com/poll", "https://api.example.com:443/")
assert_same_origin("https://api.example.com:443/poll", "https://api.example.com/")
assert_same_origin("http://api.example.com/poll", "http://api.example.com:80/")
def test_assert_same_origin_rejects_different_host():
with pytest.raises(SSRFError, match="host"):
assert_same_origin(
"https://attacker.example.com/poll",
"https://api.example.com/generate",
)
def test_assert_same_origin_rejects_different_scheme():
with pytest.raises(SSRFError, match="scheme"):
assert_same_origin(
"http://api.example.com/poll", "https://api.example.com/generate"
)
def test_assert_same_origin_rejects_different_port():
with pytest.raises(SSRFError, match="port"):
assert_same_origin(
"https://api.example.com:8443/poll", "https://api.example.com/generate"
)
def test_assert_same_origin_rejects_non_http_scheme():
"""``file://`` polling URLs are rejected outright — the upstream
should never return a non-HTTP scheme."""
with pytest.raises(SSRFError, match="scheme"):
assert_same_origin("file:///etc/passwd", "https://api.example.com/")
def test_assert_same_origin_case_insensitive_host():
assert_same_origin(
"https://API.example.com/poll", "https://api.example.com/generate"
)

View File

@ -0,0 +1,177 @@
"""
VERIA-51: polling URLs returned by upstream APIs (Azure DALL-E,
Azure Document Intelligence, Black Forest Labs) used to be followed
without origin validation. The handlers attached the operator's API
key to the polling request, so an attacker who could influence the
upstream response (or a compromised upstream) could redirect the proxy
to send credentials anywhere.
These tests assert each handler now rejects polling URLs that don't
share an origin with the original request URL.
"""
from unittest.mock import MagicMock, patch
import httpx
import pytest
# Azure DALL-E sync + async paths route through ``assert_same_origin``
# the same way as the cases below. The helper itself is unit-tested in
# ``tests/test_litellm/litellm_core_utils/test_url_utils.py``; the
# tests here exercise the wiring at sites with simpler signatures.
# ── Azure Document Intelligence polling ───────────────────────────────────────
def test_azure_di_sync_rejects_cross_origin_polling():
from litellm.llms.azure_ai.ocr.document_intelligence.transformation import (
AzureDocumentIntelligenceOCRConfig,
)
config = AzureDocumentIntelligenceOCRConfig()
raw_response = MagicMock()
raw_response.status_code = 202
raw_response.headers = {
"Operation-Location": "https://attacker.example.com/results/xyz",
}
raw_response.request = MagicMock()
raw_response.request.url = (
"https://eastus.cognitiveservices.azure.com/documentintelligence/.../analyze"
)
raw_response.request.headers = {"Ocp-Apim-Subscription-Key": "leak-me"}
with pytest.raises(ValueError, match="rejected polling URL"):
config.transform_ocr_response(
model="azure-doc-intel",
raw_response=raw_response,
logging_obj=MagicMock(),
request_data={},
optional_params={},
litellm_params={},
encoding=None,
response={},
)
# ── Black Forest Labs polling ─────────────────────────────────────────────────
def test_bfl_image_generation_sync_rejects_cross_origin_polling():
from litellm.llms.black_forest_labs.image_generation.handler import (
BlackForestLabsImageGeneration,
)
handler = BlackForestLabsImageGeneration()
initial_response = MagicMock()
initial_response.status_code = 200
initial_response.json = MagicMock(
return_value={"polling_url": "https://attacker.example.com/get_result"}
)
initial_response.request = MagicMock()
initial_response.request.url = "https://api.bfl.ai/v1/flux-pro"
sync_client = MagicMock()
sync_client.get = MagicMock()
with pytest.raises(Exception, match="Rejected polling URL"):
handler._poll_for_result_sync(
initial_response=initial_response,
headers={"x-key": "secret"},
sync_client=sync_client,
)
sync_client.get.assert_not_called()
@pytest.mark.asyncio
async def test_bfl_image_generation_async_rejects_cross_origin_polling():
from litellm.llms.black_forest_labs.image_generation.handler import (
BlackForestLabsImageGeneration,
)
handler = BlackForestLabsImageGeneration()
initial_response = MagicMock()
initial_response.status_code = 200
initial_response.json = MagicMock(
return_value={"polling_url": "https://attacker.example.com/get_result"}
)
initial_response.request = MagicMock()
initial_response.request.url = "https://api.bfl.ai/v1/flux-pro"
async_client = MagicMock()
async_client.get = MagicMock()
with pytest.raises(Exception, match="Rejected polling URL"):
await handler._poll_for_result_async(
initial_response=initial_response,
headers={"x-key": "secret"},
async_client=async_client,
)
async_client.get.assert_not_called()
def test_bfl_image_edit_sync_rejects_cross_origin_polling():
from litellm.llms.black_forest_labs.image_edit.handler import (
BlackForestLabsImageEdit,
)
handler = BlackForestLabsImageEdit()
initial_response = MagicMock()
initial_response.status_code = 200
initial_response.json = MagicMock(
return_value={"polling_url": "https://attacker.example.com/get_result"}
)
initial_response.request = MagicMock()
initial_response.request.url = "https://api.bfl.ai/v1/flux-pro/edit"
sync_client = MagicMock()
sync_client.get = MagicMock()
with pytest.raises(Exception, match="Rejected polling URL"):
handler._poll_for_result_sync(
initial_response=initial_response,
headers={"x-key": "secret"},
sync_client=sync_client,
)
sync_client.get.assert_not_called()
def test_bfl_image_generation_same_origin_polling_passes():
"""Sanity check: when the polling URL shares origin with the original
request, the origin check passes and polling proceeds."""
from litellm.llms.black_forest_labs.image_generation.handler import (
BlackForestLabsImageGeneration,
)
handler = BlackForestLabsImageGeneration()
initial_response = MagicMock()
initial_response.status_code = 200
initial_response.json = MagicMock(
return_value={"polling_url": "https://api.bfl.ai/v1/get_result?id=abc"}
)
initial_response.request = MagicMock()
initial_response.request.url = "https://api.bfl.ai/v1/flux-pro"
sync_client = MagicMock()
poll_response = MagicMock()
poll_response.status_code = 200
poll_response.json = MagicMock(return_value={"status": "Ready"})
sync_client.get = MagicMock(return_value=poll_response)
result = handler._poll_for_result_sync(
initial_response=initial_response,
headers={"x-key": "secret"},
sync_client=sync_client,
)
sync_client.get.assert_called_once()
assert result is poll_response

View File

@ -964,3 +964,103 @@ class TestIsRequestBodySafeBlocksEndpointTargetingFields:
)
is True
)
# ── is_request_body_safe nested-config recursion (VERIA-6) ────────────────────
class TestIsRequestBodySafeNestedConfig:
"""The Milvus vector store transformer unpacks
``litellm_embedding_config`` as ``**kwargs`` into ``litellm.embedding(...)``
same SSRF / credential-exfil surface as a top-level ``api_base`` in
the request body. ``is_request_body_safe`` must recurse into this
nested dict so a banned param can't be smuggled in via nesting."""
def test_root_level_api_base_blocked_when_no_opt_in(self):
"""Sanity check: pre-existing root-level enforcement still works."""
with pytest.raises(ValueError, match="api_base"):
is_request_body_safe(
request_body={"api_base": "https://attacker.example.com"},
general_settings={},
llm_router=None,
model="gpt-4",
)
def test_nested_api_base_in_embedding_config_blocked(self):
"""Smuggling ``api_base`` inside ``litellm_embedding_config`` is
the VERIA-6 bypass must be blocked by the recursive check."""
with pytest.raises(ValueError, match="api_base"):
is_request_body_safe(
request_body={
"litellm_embedding_config": {
"api_base": "https://attacker.example.com",
"api_key": "leaked-key",
}
},
general_settings={},
llm_router=None,
model="milvus-store",
)
def test_nested_langfuse_host_in_embedding_config_blocked(self):
"""The recursion uses the *full* banned-param list, not a special
subset so any flag that's banned at the root is also banned
when nested."""
with pytest.raises(ValueError, match="langfuse_host"):
is_request_body_safe(
request_body={
"litellm_embedding_config": {
"langfuse_host": "https://attacker.example.com"
}
},
general_settings={},
llm_router=None,
model="milvus-store",
)
def test_nested_api_base_allowed_when_admin_opts_in(self):
"""Admins who explicitly enable client-side credential passthrough
keep the existing escape hatch same UX as for root-level."""
assert (
is_request_body_safe(
request_body={
"litellm_embedding_config": {
"api_base": "https://my-azure.example.com"
}
},
general_settings={"allow_client_side_credentials": True},
llm_router=None,
model="milvus-store",
)
is True
)
def test_safe_nested_config_accepted(self):
"""A nested config without any banned params passes — there's no
false-positive on legitimate ``api_version`` / model params."""
assert (
is_request_body_safe(
request_body={
"litellm_embedding_config": {
"api_version": "2024-02-15-preview",
}
},
general_settings={},
llm_router=None,
model="milvus-store",
)
is True
)
def test_non_dict_nested_config_does_not_break_check(self):
"""A bogus type for ``litellm_embedding_config`` (string, list,
None) must not crash the validator it should just fall through."""
assert (
is_request_body_safe(
request_body={"litellm_embedding_config": "not-a-dict"},
general_settings={},
llm_router=None,
model="x",
)
is True
)