chore(security): close two unaddressed SSRF cases
Two SSRF findings were OPEN with no in-flight fix; both are closed now using narrow defenses that key off existing trust boundaries. VERIA-6 (Milvus ``litellm_embedding_config``): ``is_request_body_safe`` already blocks ``api_base`` / ``api_key`` / ``langfuse_host`` / ``s3_endpoint_url`` / etc. at the *root* of the request body, gated by an admin opt-in (``allow_client_side_credentials`` or per-deployment ``configurable_clientside_auth_params``). The bug is that the Milvus vector-store transformer unpacks ``litellm_embedding_config`` into ``litellm.embedding(**embedding_config)``, so a caller can smuggle the same banned params in via nesting and bypass the check. Fix: ``is_request_body_safe`` now recurses into a known list of nested-config dicts (``litellm_embedding_config`` for now) and applies the same banned-param check with the same admin opt-in. Admin-side vector-store config flows through ``litellm_params`` rather than the request body, so it's unaffected. VERIA-51 (polling URLs returned by upstream APIs): Azure DALL-E 2, Azure Document Intelligence, and Black Forest Labs all blindly fetched a polling URL returned by the upstream and attached the operator's API key to the request. A compromised upstream or a future API contract change could redirect credentials anywhere. New ``url_utils.assert_same_origin(candidate, expected)`` helper checks scheme, host (case-insensitive), and port (with default-port normalization). Applied at all five polling sites: Azure DALL-E sync+async, Azure DI sync+async, BFL image generation sync+async, BFL image edit sync+async. Cross-origin polling URLs now raise rather than forward credentials. The Azure DALL-E ``Expected 'status' in response`` exception no longer reflects the raw response body — that path turned Blind SSRF into Full-Read SSRF for the limited window before the origin check fully closed it. Tests: 7 ``assert_same_origin`` unit tests, 6 ``is_request_body_safe`` nested-config tests, 5 polling-site rejection tests + 1 same-origin sanity check. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
934ecdca78
commit
0d4875dec9
@ -199,6 +199,49 @@ def validate_url(url: str) -> Tuple[str, str]:
|
||||
return rewritten, host_header
|
||||
|
||||
|
||||
def assert_same_origin(candidate_url: str, expected_url: str) -> None:
|
||||
"""Verify ``candidate_url`` shares scheme, host, and port with ``expected_url``.
|
||||
|
||||
Use when an upstream API returns a URL meant for follow-up requests
|
||||
(e.g. an async-job polling URL that will be hit with the operator's
|
||||
API key in the headers). The upstream is trusted because the operator
|
||||
configured ``api_base``, but the URL it hands back must actually point
|
||||
back at the same origin or we'd be blindly forwarding credentials
|
||||
wherever the upstream told us to.
|
||||
|
||||
Hostnames are compared case-insensitively. Default ports are made
|
||||
explicit (HTTP→80, HTTPS→443) so ``https://api.example.com:443/...``
|
||||
and ``https://api.example.com/...`` are treated as the same origin.
|
||||
"""
|
||||
candidate = urlparse(candidate_url)
|
||||
expected = urlparse(expected_url)
|
||||
|
||||
if candidate.scheme not in _ALLOWED_SCHEMES:
|
||||
raise SSRFError(f"URL scheme '{candidate.scheme}' is not allowed")
|
||||
|
||||
if candidate.scheme != expected.scheme:
|
||||
raise SSRFError(
|
||||
"Origin mismatch: scheme "
|
||||
f"{candidate.scheme!r} != expected {expected.scheme!r}"
|
||||
)
|
||||
|
||||
candidate_host = _normalize_host(candidate.hostname or "")
|
||||
expected_host = _normalize_host(expected.hostname or "")
|
||||
if not candidate_host or candidate_host != expected_host:
|
||||
raise SSRFError(
|
||||
"Origin mismatch: host "
|
||||
f"{candidate.hostname!r} != expected {expected.hostname!r}"
|
||||
)
|
||||
|
||||
default_port = 443 if candidate.scheme == "https" else 80
|
||||
candidate_port = candidate.port if candidate.port is not None else default_port
|
||||
expected_port = expected.port if expected.port is not None else default_port
|
||||
if candidate_port != expected_port:
|
||||
raise SSRFError(
|
||||
"Origin mismatch: port " f"{candidate_port} != expected {expected_port}"
|
||||
)
|
||||
|
||||
|
||||
_MAX_REDIRECTS = 10
|
||||
|
||||
|
||||
|
||||
@ -16,6 +16,7 @@ import litellm
|
||||
from litellm.constants import AZURE_OPERATION_POLLING_TIMEOUT, DEFAULT_MAX_RETRIES
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
|
||||
from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
@ -898,6 +899,17 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||
operation_location_url = response.headers["operation-location"]
|
||||
else:
|
||||
raise AzureOpenAIError(status_code=500, message=response.text)
|
||||
# Reject polling URLs that don't share an origin with ``api_base``.
|
||||
# Without this an upstream-controlled or attacker-controlled
|
||||
# value would receive the operator's Azure API key in the
|
||||
# request headers below. VERIA-51.
|
||||
try:
|
||||
assert_same_origin(operation_location_url, api_base)
|
||||
except SSRFError as ssrf_err:
|
||||
raise AzureOpenAIError(
|
||||
status_code=502,
|
||||
message=f"Rejected polling URL: {ssrf_err}",
|
||||
)
|
||||
response = await async_handler.get(
|
||||
url=operation_location_url,
|
||||
headers=headers,
|
||||
@ -908,8 +920,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||
timeout_secs: int = AZURE_OPERATION_POLLING_TIMEOUT
|
||||
start_time = time.time()
|
||||
if "status" not in response.json():
|
||||
raise Exception(
|
||||
"Expected 'status' in response. Got={}".format(response.json())
|
||||
# Don't reflect the raw response body — when the polling
|
||||
# URL points at an internal JSON API (cloud metadata
|
||||
# service etc.) reflecting it here turns Blind SSRF into
|
||||
# Full-Read SSRF. VERIA-51.
|
||||
raise AzureOpenAIError(
|
||||
status_code=502,
|
||||
message="Polling response missing 'status' field",
|
||||
)
|
||||
while response.json()["status"] not in ["succeeded", "failed"]:
|
||||
if time.time() - start_time > timeout_secs:
|
||||
@ -1009,6 +1026,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||
operation_location_url = response.headers["operation-location"]
|
||||
else:
|
||||
raise AzureOpenAIError(status_code=500, message=response.text)
|
||||
try:
|
||||
assert_same_origin(operation_location_url, api_base)
|
||||
except SSRFError as ssrf_err:
|
||||
raise AzureOpenAIError(
|
||||
status_code=502,
|
||||
message=f"Rejected polling URL: {ssrf_err}",
|
||||
)
|
||||
response = sync_handler.get(
|
||||
url=operation_location_url,
|
||||
headers=headers,
|
||||
@ -1019,8 +1043,9 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM):
|
||||
timeout_secs: int = AZURE_OPERATION_POLLING_TIMEOUT
|
||||
start_time = time.time()
|
||||
if "status" not in response.json():
|
||||
raise Exception(
|
||||
"Expected 'status' in response. Got={}".format(response.json())
|
||||
raise AzureOpenAIError(
|
||||
status_code=502,
|
||||
message="Polling response missing 'status' field",
|
||||
)
|
||||
while response.json()["status"] not in ["succeeded", "failed"]:
|
||||
if time.time() - start_time > timeout_secs:
|
||||
|
||||
@ -17,6 +17,7 @@ from urllib.parse import quote
|
||||
import httpx
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin
|
||||
from litellm.constants import (
|
||||
AZURE_DOCUMENT_INTELLIGENCE_API_VERSION,
|
||||
AZURE_DOCUMENT_INTELLIGENCE_DEFAULT_DPI,
|
||||
@ -599,6 +600,16 @@ class AzureDocumentIntelligenceOCRConfig(BaseOCRConfig):
|
||||
"Azure Document Intelligence returned 202 but no Operation-Location header found"
|
||||
)
|
||||
|
||||
# Reject cross-origin polling URLs — the auth headers
|
||||
# below would otherwise leak to whatever URL the upstream
|
||||
# (or an attacker-controlled upstream) returns. VERIA-51.
|
||||
try:
|
||||
assert_same_origin(operation_url, str(raw_response.request.url))
|
||||
except SSRFError as ssrf_err:
|
||||
raise ValueError(
|
||||
f"Azure Document Intelligence: rejected polling URL ({ssrf_err})"
|
||||
)
|
||||
|
||||
# Get headers for polling (need auth)
|
||||
poll_headers = {
|
||||
"Ocp-Apim-Subscription-Key": raw_response.request.headers.get(
|
||||
@ -711,6 +722,14 @@ class AzureDocumentIntelligenceOCRConfig(BaseOCRConfig):
|
||||
"Azure Document Intelligence returned 202 but no Operation-Location header found"
|
||||
)
|
||||
|
||||
# Reject cross-origin polling URLs (see sync path). VERIA-51.
|
||||
try:
|
||||
assert_same_origin(operation_url, str(raw_response.request.url))
|
||||
except SSRFError as ssrf_err:
|
||||
raise ValueError(
|
||||
f"Azure Document Intelligence: rejected polling URL ({ssrf_err})"
|
||||
)
|
||||
|
||||
# Get headers for polling (need auth)
|
||||
poll_headers = {
|
||||
"Ocp-Apim-Subscription-Key": raw_response.request.headers.get(
|
||||
|
||||
@ -15,6 +15,7 @@ import httpx
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
@ -331,6 +332,17 @@ class BlackForestLabsImageEdit:
|
||||
message="No polling_url in BFL response",
|
||||
)
|
||||
|
||||
# Reject cross-origin polling URLs — the ``x-key`` auth header
|
||||
# would otherwise leak to whatever URL the upstream returns.
|
||||
# VERIA-51.
|
||||
try:
|
||||
assert_same_origin(polling_url, str(initial_response.request.url))
|
||||
except SSRFError as ssrf_err:
|
||||
raise BlackForestLabsError(
|
||||
status_code=502,
|
||||
message=f"Rejected polling URL: {ssrf_err}",
|
||||
)
|
||||
|
||||
# Get just the auth header for polling
|
||||
polling_headers = {"x-key": headers.get("x-key", "")}
|
||||
|
||||
@ -416,6 +428,17 @@ class BlackForestLabsImageEdit:
|
||||
message="No polling_url in BFL response",
|
||||
)
|
||||
|
||||
# Reject cross-origin polling URLs — the ``x-key`` auth header
|
||||
# would otherwise leak to whatever URL the upstream returns.
|
||||
# VERIA-51.
|
||||
try:
|
||||
assert_same_origin(polling_url, str(initial_response.request.url))
|
||||
except SSRFError as ssrf_err:
|
||||
raise BlackForestLabsError(
|
||||
status_code=502,
|
||||
message=f"Rejected polling URL: {ssrf_err}",
|
||||
)
|
||||
|
||||
# Get just the auth header for polling
|
||||
polling_headers = {"x-key": headers.get("x-key", "")}
|
||||
|
||||
|
||||
@ -15,6 +15,7 @@ import httpx
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
HTTPHandler,
|
||||
@ -317,6 +318,17 @@ class BlackForestLabsImageGeneration:
|
||||
message="No polling_url in BFL response",
|
||||
)
|
||||
|
||||
# Reject cross-origin polling URLs — the ``x-key`` auth header
|
||||
# would otherwise leak to whatever URL the upstream returns.
|
||||
# VERIA-51.
|
||||
try:
|
||||
assert_same_origin(polling_url, str(initial_response.request.url))
|
||||
except SSRFError as ssrf_err:
|
||||
raise BlackForestLabsError(
|
||||
status_code=502,
|
||||
message=f"Rejected polling URL: {ssrf_err}",
|
||||
)
|
||||
|
||||
# Get just the auth header for polling
|
||||
polling_headers = {"x-key": headers.get("x-key", "")}
|
||||
|
||||
@ -402,6 +414,17 @@ class BlackForestLabsImageGeneration:
|
||||
message="No polling_url in BFL response",
|
||||
)
|
||||
|
||||
# Reject cross-origin polling URLs — the ``x-key`` auth header
|
||||
# would otherwise leak to whatever URL the upstream returns.
|
||||
# VERIA-51.
|
||||
try:
|
||||
assert_same_origin(polling_url, str(initial_response.request.url))
|
||||
except SSRFError as ssrf_err:
|
||||
raise BlackForestLabsError(
|
||||
status_code=502,
|
||||
message=f"Rejected polling URL: {ssrf_err}",
|
||||
)
|
||||
|
||||
# Get just the auth header for polling
|
||||
polling_headers = {"x-key": headers.get("x-key", "")}
|
||||
|
||||
|
||||
@ -241,9 +241,30 @@ def is_request_body_safe(
|
||||
"Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997",
|
||||
)
|
||||
|
||||
# Recurse into nested config dicts whose values get unpacked as
|
||||
# ``**kwargs`` into outbound API calls — same SSRF / credential
|
||||
# exfil surface as the root, but historically not covered by this
|
||||
# banned-param check. VERIA-6.
|
||||
for nested_key in _NESTED_CONFIG_KEYS:
|
||||
nested = request_body.get(nested_key)
|
||||
if isinstance(nested, dict):
|
||||
is_request_body_safe(
|
||||
request_body=nested,
|
||||
general_settings=general_settings,
|
||||
llm_router=llm_router,
|
||||
model=model,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
# Config dicts whose entries are spread as ``**dict`` into outbound LLM
|
||||
# API calls. ``litellm_embedding_config`` is consumed by the Milvus
|
||||
# vector store transformer; future nested-config keys with the same
|
||||
# threat shape should be added here.
|
||||
_NESTED_CONFIG_KEYS: Tuple[str, ...] = ("litellm_embedding_config",)
|
||||
|
||||
|
||||
async def pre_db_read_auth_checks(
|
||||
request: Request,
|
||||
request_data: dict,
|
||||
|
||||
@ -394,3 +394,61 @@ class TestHostAllowlist:
|
||||
|
||||
monkeypatch.setattr(url_utils.socket, "getaddrinfo", fake)
|
||||
validate_url("http://internal.corp/")
|
||||
|
||||
|
||||
# ── assert_same_origin ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
from litellm.litellm_core_utils.url_utils import assert_same_origin
|
||||
|
||||
|
||||
def test_assert_same_origin_matches_scheme_host_port():
|
||||
"""A polling URL on the same scheme + host + port as the api_base
|
||||
passes — the upstream is trusted; the URL it returned points back at
|
||||
the same upstream."""
|
||||
assert_same_origin(
|
||||
"https://api.example.com/v1/operations/abc",
|
||||
"https://api.example.com/v1/generate",
|
||||
)
|
||||
|
||||
|
||||
def test_assert_same_origin_treats_default_ports_as_explicit():
|
||||
"""``https://x/`` and ``https://x:443/`` are the same origin."""
|
||||
assert_same_origin("https://api.example.com/poll", "https://api.example.com:443/")
|
||||
assert_same_origin("https://api.example.com:443/poll", "https://api.example.com/")
|
||||
assert_same_origin("http://api.example.com/poll", "http://api.example.com:80/")
|
||||
|
||||
|
||||
def test_assert_same_origin_rejects_different_host():
|
||||
with pytest.raises(SSRFError, match="host"):
|
||||
assert_same_origin(
|
||||
"https://attacker.example.com/poll",
|
||||
"https://api.example.com/generate",
|
||||
)
|
||||
|
||||
|
||||
def test_assert_same_origin_rejects_different_scheme():
|
||||
with pytest.raises(SSRFError, match="scheme"):
|
||||
assert_same_origin(
|
||||
"http://api.example.com/poll", "https://api.example.com/generate"
|
||||
)
|
||||
|
||||
|
||||
def test_assert_same_origin_rejects_different_port():
|
||||
with pytest.raises(SSRFError, match="port"):
|
||||
assert_same_origin(
|
||||
"https://api.example.com:8443/poll", "https://api.example.com/generate"
|
||||
)
|
||||
|
||||
|
||||
def test_assert_same_origin_rejects_non_http_scheme():
|
||||
"""``file://`` polling URLs are rejected outright — the upstream
|
||||
should never return a non-HTTP scheme."""
|
||||
with pytest.raises(SSRFError, match="scheme"):
|
||||
assert_same_origin("file:///etc/passwd", "https://api.example.com/")
|
||||
|
||||
|
||||
def test_assert_same_origin_case_insensitive_host():
|
||||
assert_same_origin(
|
||||
"https://API.example.com/poll", "https://api.example.com/generate"
|
||||
)
|
||||
|
||||
177
tests/test_litellm/llms/test_polling_url_origin_match.py
Normal file
177
tests/test_litellm/llms/test_polling_url_origin_match.py
Normal file
@ -0,0 +1,177 @@
|
||||
"""
|
||||
VERIA-51: polling URLs returned by upstream APIs (Azure DALL-E,
|
||||
Azure Document Intelligence, Black Forest Labs) used to be followed
|
||||
without origin validation. The handlers attached the operator's API
|
||||
key to the polling request, so an attacker who could influence the
|
||||
upstream response (or a compromised upstream) could redirect the proxy
|
||||
to send credentials anywhere.
|
||||
|
||||
These tests assert each handler now rejects polling URLs that don't
|
||||
share an origin with the original request URL.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
# Azure DALL-E sync + async paths route through ``assert_same_origin``
|
||||
# the same way as the cases below. The helper itself is unit-tested in
|
||||
# ``tests/test_litellm/litellm_core_utils/test_url_utils.py``; the
|
||||
# tests here exercise the wiring at sites with simpler signatures.
|
||||
|
||||
|
||||
# ── Azure Document Intelligence polling ───────────────────────────────────────
|
||||
|
||||
|
||||
def test_azure_di_sync_rejects_cross_origin_polling():
|
||||
from litellm.llms.azure_ai.ocr.document_intelligence.transformation import (
|
||||
AzureDocumentIntelligenceOCRConfig,
|
||||
)
|
||||
|
||||
config = AzureDocumentIntelligenceOCRConfig()
|
||||
|
||||
raw_response = MagicMock()
|
||||
raw_response.status_code = 202
|
||||
raw_response.headers = {
|
||||
"Operation-Location": "https://attacker.example.com/results/xyz",
|
||||
}
|
||||
raw_response.request = MagicMock()
|
||||
raw_response.request.url = (
|
||||
"https://eastus.cognitiveservices.azure.com/documentintelligence/.../analyze"
|
||||
)
|
||||
raw_response.request.headers = {"Ocp-Apim-Subscription-Key": "leak-me"}
|
||||
|
||||
with pytest.raises(ValueError, match="rejected polling URL"):
|
||||
config.transform_ocr_response(
|
||||
model="azure-doc-intel",
|
||||
raw_response=raw_response,
|
||||
logging_obj=MagicMock(),
|
||||
request_data={},
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
encoding=None,
|
||||
response={},
|
||||
)
|
||||
|
||||
|
||||
# ── Black Forest Labs polling ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_bfl_image_generation_sync_rejects_cross_origin_polling():
|
||||
from litellm.llms.black_forest_labs.image_generation.handler import (
|
||||
BlackForestLabsImageGeneration,
|
||||
)
|
||||
|
||||
handler = BlackForestLabsImageGeneration()
|
||||
|
||||
initial_response = MagicMock()
|
||||
initial_response.status_code = 200
|
||||
initial_response.json = MagicMock(
|
||||
return_value={"polling_url": "https://attacker.example.com/get_result"}
|
||||
)
|
||||
initial_response.request = MagicMock()
|
||||
initial_response.request.url = "https://api.bfl.ai/v1/flux-pro"
|
||||
|
||||
sync_client = MagicMock()
|
||||
sync_client.get = MagicMock()
|
||||
|
||||
with pytest.raises(Exception, match="Rejected polling URL"):
|
||||
handler._poll_for_result_sync(
|
||||
initial_response=initial_response,
|
||||
headers={"x-key": "secret"},
|
||||
sync_client=sync_client,
|
||||
)
|
||||
|
||||
sync_client.get.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_bfl_image_generation_async_rejects_cross_origin_polling():
|
||||
from litellm.llms.black_forest_labs.image_generation.handler import (
|
||||
BlackForestLabsImageGeneration,
|
||||
)
|
||||
|
||||
handler = BlackForestLabsImageGeneration()
|
||||
|
||||
initial_response = MagicMock()
|
||||
initial_response.status_code = 200
|
||||
initial_response.json = MagicMock(
|
||||
return_value={"polling_url": "https://attacker.example.com/get_result"}
|
||||
)
|
||||
initial_response.request = MagicMock()
|
||||
initial_response.request.url = "https://api.bfl.ai/v1/flux-pro"
|
||||
|
||||
async_client = MagicMock()
|
||||
async_client.get = MagicMock()
|
||||
|
||||
with pytest.raises(Exception, match="Rejected polling URL"):
|
||||
await handler._poll_for_result_async(
|
||||
initial_response=initial_response,
|
||||
headers={"x-key": "secret"},
|
||||
async_client=async_client,
|
||||
)
|
||||
|
||||
async_client.get.assert_not_called()
|
||||
|
||||
|
||||
def test_bfl_image_edit_sync_rejects_cross_origin_polling():
|
||||
from litellm.llms.black_forest_labs.image_edit.handler import (
|
||||
BlackForestLabsImageEdit,
|
||||
)
|
||||
|
||||
handler = BlackForestLabsImageEdit()
|
||||
|
||||
initial_response = MagicMock()
|
||||
initial_response.status_code = 200
|
||||
initial_response.json = MagicMock(
|
||||
return_value={"polling_url": "https://attacker.example.com/get_result"}
|
||||
)
|
||||
initial_response.request = MagicMock()
|
||||
initial_response.request.url = "https://api.bfl.ai/v1/flux-pro/edit"
|
||||
|
||||
sync_client = MagicMock()
|
||||
sync_client.get = MagicMock()
|
||||
|
||||
with pytest.raises(Exception, match="Rejected polling URL"):
|
||||
handler._poll_for_result_sync(
|
||||
initial_response=initial_response,
|
||||
headers={"x-key": "secret"},
|
||||
sync_client=sync_client,
|
||||
)
|
||||
|
||||
sync_client.get.assert_not_called()
|
||||
|
||||
|
||||
def test_bfl_image_generation_same_origin_polling_passes():
|
||||
"""Sanity check: when the polling URL shares origin with the original
|
||||
request, the origin check passes and polling proceeds."""
|
||||
from litellm.llms.black_forest_labs.image_generation.handler import (
|
||||
BlackForestLabsImageGeneration,
|
||||
)
|
||||
|
||||
handler = BlackForestLabsImageGeneration()
|
||||
|
||||
initial_response = MagicMock()
|
||||
initial_response.status_code = 200
|
||||
initial_response.json = MagicMock(
|
||||
return_value={"polling_url": "https://api.bfl.ai/v1/get_result?id=abc"}
|
||||
)
|
||||
initial_response.request = MagicMock()
|
||||
initial_response.request.url = "https://api.bfl.ai/v1/flux-pro"
|
||||
|
||||
sync_client = MagicMock()
|
||||
poll_response = MagicMock()
|
||||
poll_response.status_code = 200
|
||||
poll_response.json = MagicMock(return_value={"status": "Ready"})
|
||||
sync_client.get = MagicMock(return_value=poll_response)
|
||||
|
||||
result = handler._poll_for_result_sync(
|
||||
initial_response=initial_response,
|
||||
headers={"x-key": "secret"},
|
||||
sync_client=sync_client,
|
||||
)
|
||||
|
||||
sync_client.get.assert_called_once()
|
||||
assert result is poll_response
|
||||
@ -964,3 +964,103 @@ class TestIsRequestBodySafeBlocksEndpointTargetingFields:
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
|
||||
# ── is_request_body_safe nested-config recursion (VERIA-6) ────────────────────
|
||||
|
||||
|
||||
class TestIsRequestBodySafeNestedConfig:
|
||||
"""The Milvus vector store transformer unpacks
|
||||
``litellm_embedding_config`` as ``**kwargs`` into ``litellm.embedding(...)``
|
||||
— same SSRF / credential-exfil surface as a top-level ``api_base`` in
|
||||
the request body. ``is_request_body_safe`` must recurse into this
|
||||
nested dict so a banned param can't be smuggled in via nesting."""
|
||||
|
||||
def test_root_level_api_base_blocked_when_no_opt_in(self):
|
||||
"""Sanity check: pre-existing root-level enforcement still works."""
|
||||
with pytest.raises(ValueError, match="api_base"):
|
||||
is_request_body_safe(
|
||||
request_body={"api_base": "https://attacker.example.com"},
|
||||
general_settings={},
|
||||
llm_router=None,
|
||||
model="gpt-4",
|
||||
)
|
||||
|
||||
def test_nested_api_base_in_embedding_config_blocked(self):
|
||||
"""Smuggling ``api_base`` inside ``litellm_embedding_config`` is
|
||||
the VERIA-6 bypass — must be blocked by the recursive check."""
|
||||
with pytest.raises(ValueError, match="api_base"):
|
||||
is_request_body_safe(
|
||||
request_body={
|
||||
"litellm_embedding_config": {
|
||||
"api_base": "https://attacker.example.com",
|
||||
"api_key": "leaked-key",
|
||||
}
|
||||
},
|
||||
general_settings={},
|
||||
llm_router=None,
|
||||
model="milvus-store",
|
||||
)
|
||||
|
||||
def test_nested_langfuse_host_in_embedding_config_blocked(self):
|
||||
"""The recursion uses the *full* banned-param list, not a special
|
||||
subset — so any flag that's banned at the root is also banned
|
||||
when nested."""
|
||||
with pytest.raises(ValueError, match="langfuse_host"):
|
||||
is_request_body_safe(
|
||||
request_body={
|
||||
"litellm_embedding_config": {
|
||||
"langfuse_host": "https://attacker.example.com"
|
||||
}
|
||||
},
|
||||
general_settings={},
|
||||
llm_router=None,
|
||||
model="milvus-store",
|
||||
)
|
||||
|
||||
def test_nested_api_base_allowed_when_admin_opts_in(self):
|
||||
"""Admins who explicitly enable client-side credential passthrough
|
||||
keep the existing escape hatch — same UX as for root-level."""
|
||||
assert (
|
||||
is_request_body_safe(
|
||||
request_body={
|
||||
"litellm_embedding_config": {
|
||||
"api_base": "https://my-azure.example.com"
|
||||
}
|
||||
},
|
||||
general_settings={"allow_client_side_credentials": True},
|
||||
llm_router=None,
|
||||
model="milvus-store",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_safe_nested_config_accepted(self):
|
||||
"""A nested config without any banned params passes — there's no
|
||||
false-positive on legitimate ``api_version`` / model params."""
|
||||
assert (
|
||||
is_request_body_safe(
|
||||
request_body={
|
||||
"litellm_embedding_config": {
|
||||
"api_version": "2024-02-15-preview",
|
||||
}
|
||||
},
|
||||
general_settings={},
|
||||
llm_router=None,
|
||||
model="milvus-store",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
def test_non_dict_nested_config_does_not_break_check(self):
|
||||
"""A bogus type for ``litellm_embedding_config`` (string, list,
|
||||
None) must not crash the validator — it should just fall through."""
|
||||
assert (
|
||||
is_request_body_safe(
|
||||
request_body={"litellm_embedding_config": "not-a-dict"},
|
||||
general_settings={},
|
||||
llm_router=None,
|
||||
model="x",
|
||||
)
|
||||
is True
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user