diff --git a/litellm/litellm_core_utils/url_utils.py b/litellm/litellm_core_utils/url_utils.py index a65d0892aa..4dd0b5142e 100644 --- a/litellm/litellm_core_utils/url_utils.py +++ b/litellm/litellm_core_utils/url_utils.py @@ -199,6 +199,71 @@ def validate_url(url: str) -> Tuple[str, str]: return rewritten, host_header
+def assert_same_origin(candidate_url: str, expected_url: str) -> None:
+    """Verify ``candidate_url`` shares scheme, host, and port with ``expected_url``.
+
+    Use when an upstream API returns a URL meant for follow-up requests
+    (e.g. an async-job polling URL that will be hit with the operator's
+    API key in the headers). The upstream is trusted because the operator
+    configured ``api_base``, but the URL it hands back must actually point
+    back at the same origin or we'd be blindly forwarding credentials
+    wherever the upstream told us to.
+
+    Hostnames are compared case-insensitively. Default ports are made
+    explicit (HTTP→80, HTTPS→443) so ``https://api.example.com:443/...``
+    and ``https://api.example.com/...`` are treated as the same origin.
+
+    Args:
+        candidate_url: URL handed back by the upstream; treated as untrusted.
+        expected_url: URL whose origin the candidate must match.
+
+    Raises:
+        SSRFError: If either URL is malformed, the candidate scheme is not
+            allowed, or the scheme, host, or port differ.
+    """
+    # ``urlparse`` and the lazy ``.port`` accessor raise ``ValueError`` on
+    # malformed input (bad IPv6 literal, non-numeric or out-of-range port).
+    # Convert that to ``SSRFError`` so callers that catch only ``SSRFError``
+    # still see the rejection instead of an unrelated ``ValueError``.
+    try:
+        candidate = urlparse(candidate_url)
+        expected = urlparse(expected_url)
+        candidate_explicit_port = candidate.port
+        expected_explicit_port = expected.port
+    except ValueError as parse_err:
+        raise SSRFError(f"Origin mismatch: malformed URL ({parse_err})") from parse_err
+
+    if candidate.scheme not in _ALLOWED_SCHEMES:
+        raise SSRFError(f"URL scheme '{candidate.scheme}' is not allowed")
+
+    if candidate.scheme != expected.scheme:
+        raise SSRFError(
+            "Origin mismatch: scheme "
+            f"{candidate.scheme!r} != expected {expected.scheme!r}"
+        )
+
+    candidate_host = _normalize_host(candidate.hostname or "")
+    expected_host = _normalize_host(expected.hostname or "")
+    if not candidate_host or candidate_host != expected_host:
+        raise SSRFError(
+            "Origin mismatch: host "
+            f"{candidate.hostname!r} != expected {expected.hostname!r}"
+        )
+
+    # Schemes are equal at this point, so one default covers both URLs.
+    default_port = 443 if candidate.scheme == "https" else 80
+    candidate_port = (
+        default_port if candidate_explicit_port is None else candidate_explicit_port
+    )
+    expected_port = (
+        default_port if expected_explicit_port is None else expected_explicit_port
+    )
+    if candidate_port != expected_port:
+        raise SSRFError(
+            f"Origin mismatch: port {candidate_port} != expected {expected_port}"
+        )
+
+
_MAX_REDIRECTS = 10 diff --git a/litellm/llms/azure/azure.py b/litellm/llms/azure/azure.py index 61cfd54b56..877a7d3c84 100644 --- a/litellm/llms/azure/azure.py +++ b/litellm/llms/azure/azure.py @@ -16,6 +16,7 @@ import litellm from litellm.constants import AZURE_OPERATION_POLLING_TIMEOUT, DEFAULT_MAX_RETRIES from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.litellm_core_utils.logging_utils import track_llm_api_timing +from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, HTTPHandler, @@ -898,6 +899,17 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): operation_location_url = response.headers["operation-location"] else: raise AzureOpenAIError(status_code=500, message=response.text) + # Reject polling URLs that don't share an origin with ``api_base``. 
+ # Without this an upstream-controlled or attacker-controlled + # value would receive the operator's Azure API key in the + # request headers below. VERIA-51. + try: + assert_same_origin(operation_location_url, api_base) + except SSRFError as ssrf_err: + raise AzureOpenAIError( + status_code=502, + message=f"Rejected polling URL: {ssrf_err}", + ) response = await async_handler.get( url=operation_location_url, headers=headers, @@ -908,8 +920,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): timeout_secs: int = AZURE_OPERATION_POLLING_TIMEOUT start_time = time.time() if "status" not in response.json(): - raise Exception( - "Expected 'status' in response. Got={}".format(response.json()) + # Don't reflect the raw response body — when the polling + # URL points at an internal JSON API (cloud metadata + # service etc.) reflecting it here turns Blind SSRF into + # Full-Read SSRF. VERIA-51. + raise AzureOpenAIError( + status_code=502, + message="Polling response missing 'status' field", ) while response.json()["status"] not in ["succeeded", "failed"]: if time.time() - start_time > timeout_secs: @@ -1009,6 +1026,13 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): operation_location_url = response.headers["operation-location"] else: raise AzureOpenAIError(status_code=500, message=response.text) + try: + assert_same_origin(operation_location_url, api_base) + except SSRFError as ssrf_err: + raise AzureOpenAIError( + status_code=502, + message=f"Rejected polling URL: {ssrf_err}", + ) response = sync_handler.get( url=operation_location_url, headers=headers, @@ -1019,8 +1043,9 @@ class AzureChatCompletion(BaseAzureLLM, BaseLLM): timeout_secs: int = AZURE_OPERATION_POLLING_TIMEOUT start_time = time.time() if "status" not in response.json(): - raise Exception( - "Expected 'status' in response. 
Got={}".format(response.json()) + raise AzureOpenAIError( + status_code=502, + message="Polling response missing 'status' field", ) while response.json()["status"] not in ["succeeded", "failed"]: if time.time() - start_time > timeout_secs: diff --git a/litellm/llms/azure_ai/ocr/document_intelligence/transformation.py b/litellm/llms/azure_ai/ocr/document_intelligence/transformation.py index 76c247aea8..bb5ebaa6a5 100644 --- a/litellm/llms/azure_ai/ocr/document_intelligence/transformation.py +++ b/litellm/llms/azure_ai/ocr/document_intelligence/transformation.py @@ -17,6 +17,7 @@ from urllib.parse import quote import httpx from litellm._logging import verbose_logger +from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin from litellm.constants import ( AZURE_DOCUMENT_INTELLIGENCE_API_VERSION, AZURE_DOCUMENT_INTELLIGENCE_DEFAULT_DPI, @@ -599,6 +600,16 @@ class AzureDocumentIntelligenceOCRConfig(BaseOCRConfig): "Azure Document Intelligence returned 202 but no Operation-Location header found" ) + # Reject cross-origin polling URLs — the auth headers + # below would otherwise leak to whatever URL the upstream + # (or an attacker-controlled upstream) returns. VERIA-51. + try: + assert_same_origin(operation_url, str(raw_response.request.url)) + except SSRFError as ssrf_err: + raise ValueError( + f"Azure Document Intelligence: rejected polling URL ({ssrf_err})" + ) + # Get headers for polling (need auth) poll_headers = { "Ocp-Apim-Subscription-Key": raw_response.request.headers.get( @@ -711,6 +722,14 @@ class AzureDocumentIntelligenceOCRConfig(BaseOCRConfig): "Azure Document Intelligence returned 202 but no Operation-Location header found" ) + # Reject cross-origin polling URLs (see sync path). VERIA-51. 
+ try: + assert_same_origin(operation_url, str(raw_response.request.url)) + except SSRFError as ssrf_err: + raise ValueError( + f"Azure Document Intelligence: rejected polling URL ({ssrf_err})" + ) + # Get headers for polling (need auth) poll_headers = { "Ocp-Apim-Subscription-Key": raw_response.request.headers.get( diff --git a/litellm/llms/black_forest_labs/image_edit/handler.py b/litellm/llms/black_forest_labs/image_edit/handler.py index dea2683a04..f5784e0836 100644 --- a/litellm/llms/black_forest_labs/image_edit/handler.py +++ b/litellm/llms/black_forest_labs/image_edit/handler.py @@ -15,6 +15,7 @@ import httpx import litellm from litellm._logging import verbose_logger from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, HTTPHandler, @@ -331,6 +332,17 @@ class BlackForestLabsImageEdit: message="No polling_url in BFL response", ) + # Reject cross-origin polling URLs — the ``x-key`` auth header + # would otherwise leak to whatever URL the upstream returns. + # VERIA-51. + try: + assert_same_origin(polling_url, str(initial_response.request.url)) + except SSRFError as ssrf_err: + raise BlackForestLabsError( + status_code=502, + message=f"Rejected polling URL: {ssrf_err}", + ) + # Get just the auth header for polling polling_headers = {"x-key": headers.get("x-key", "")} @@ -416,6 +428,17 @@ class BlackForestLabsImageEdit: message="No polling_url in BFL response", ) + # Reject cross-origin polling URLs — the ``x-key`` auth header + # would otherwise leak to whatever URL the upstream returns. + # VERIA-51. 
+ try: + assert_same_origin(polling_url, str(initial_response.request.url)) + except SSRFError as ssrf_err: + raise BlackForestLabsError( + status_code=502, + message=f"Rejected polling URL: {ssrf_err}", + ) + # Get just the auth header for polling polling_headers = {"x-key": headers.get("x-key", "")} diff --git a/litellm/llms/black_forest_labs/image_generation/handler.py b/litellm/llms/black_forest_labs/image_generation/handler.py index 5a1d885e52..8af4a236fd 100644 --- a/litellm/llms/black_forest_labs/image_generation/handler.py +++ b/litellm/llms/black_forest_labs/image_generation/handler.py @@ -15,6 +15,7 @@ import httpx import litellm from litellm._logging import verbose_logger from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.litellm_core_utils.url_utils import SSRFError, assert_same_origin from litellm.llms.custom_httpx.http_handler import ( AsyncHTTPHandler, HTTPHandler, @@ -317,6 +318,17 @@ class BlackForestLabsImageGeneration: message="No polling_url in BFL response", ) + # Reject cross-origin polling URLs — the ``x-key`` auth header + # would otherwise leak to whatever URL the upstream returns. + # VERIA-51. + try: + assert_same_origin(polling_url, str(initial_response.request.url)) + except SSRFError as ssrf_err: + raise BlackForestLabsError( + status_code=502, + message=f"Rejected polling URL: {ssrf_err}", + ) + # Get just the auth header for polling polling_headers = {"x-key": headers.get("x-key", "")} @@ -402,6 +414,17 @@ class BlackForestLabsImageGeneration: message="No polling_url in BFL response", ) + # Reject cross-origin polling URLs — the ``x-key`` auth header + # would otherwise leak to whatever URL the upstream returns. + # VERIA-51. 
+ try: + assert_same_origin(polling_url, str(initial_response.request.url)) + except SSRFError as ssrf_err: + raise BlackForestLabsError( + status_code=502, + message=f"Rejected polling URL: {ssrf_err}", + ) + # Get just the auth header for polling polling_headers = {"x-key": headers.get("x-key", "")} diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404/index.html similarity index 100% rename from litellm/proxy/_experimental/out/404.html rename to litellm/proxy/_experimental/out/404/index.html diff --git a/litellm/proxy/_experimental/out/_not-found.html b/litellm/proxy/_experimental/out/_not-found/index.html similarity index 100% rename from litellm/proxy/_experimental/out/_not-found.html rename to litellm/proxy/_experimental/out/_not-found/index.html diff --git a/litellm/proxy/_experimental/out/api-reference.html b/litellm/proxy/_experimental/out/api-reference/index.html similarity index 100% rename from litellm/proxy/_experimental/out/api-reference.html rename to litellm/proxy/_experimental/out/api-reference/index.html diff --git a/litellm/proxy/_experimental/out/experimental/api-playground.html b/litellm/proxy/_experimental/out/experimental/api-playground/index.html similarity index 100% rename from litellm/proxy/_experimental/out/experimental/api-playground.html rename to litellm/proxy/_experimental/out/experimental/api-playground/index.html diff --git a/litellm/proxy/_experimental/out/experimental/budgets.html b/litellm/proxy/_experimental/out/experimental/budgets/index.html similarity index 100% rename from litellm/proxy/_experimental/out/experimental/budgets.html rename to litellm/proxy/_experimental/out/experimental/budgets/index.html diff --git a/litellm/proxy/_experimental/out/experimental/caching.html b/litellm/proxy/_experimental/out/experimental/caching/index.html similarity index 100% rename from litellm/proxy/_experimental/out/experimental/caching.html rename to 
litellm/proxy/_experimental/out/experimental/caching/index.html diff --git a/litellm/proxy/_experimental/out/experimental/claude-code-plugins.html b/litellm/proxy/_experimental/out/experimental/claude-code-plugins/index.html similarity index 100% rename from litellm/proxy/_experimental/out/experimental/claude-code-plugins.html rename to litellm/proxy/_experimental/out/experimental/claude-code-plugins/index.html diff --git a/litellm/proxy/_experimental/out/experimental/old-usage.html b/litellm/proxy/_experimental/out/experimental/old-usage/index.html similarity index 100% rename from litellm/proxy/_experimental/out/experimental/old-usage.html rename to litellm/proxy/_experimental/out/experimental/old-usage/index.html diff --git a/litellm/proxy/_experimental/out/experimental/prompts.html b/litellm/proxy/_experimental/out/experimental/prompts/index.html similarity index 100% rename from litellm/proxy/_experimental/out/experimental/prompts.html rename to litellm/proxy/_experimental/out/experimental/prompts/index.html diff --git a/litellm/proxy/_experimental/out/experimental/tag-management.html b/litellm/proxy/_experimental/out/experimental/tag-management/index.html similarity index 100% rename from litellm/proxy/_experimental/out/experimental/tag-management.html rename to litellm/proxy/_experimental/out/experimental/tag-management/index.html diff --git a/litellm/proxy/_experimental/out/guardrails.html b/litellm/proxy/_experimental/out/guardrails/index.html similarity index 100% rename from litellm/proxy/_experimental/out/guardrails.html rename to litellm/proxy/_experimental/out/guardrails/index.html diff --git a/litellm/proxy/_experimental/out/login.html b/litellm/proxy/_experimental/out/login/index.html similarity index 100% rename from litellm/proxy/_experimental/out/login.html rename to litellm/proxy/_experimental/out/login/index.html diff --git a/litellm/proxy/_experimental/out/logs.html b/litellm/proxy/_experimental/out/logs/index.html similarity index 100% rename 
from litellm/proxy/_experimental/out/logs.html rename to litellm/proxy/_experimental/out/logs/index.html diff --git a/litellm/proxy/_experimental/out/mcp/oauth/callback.html b/litellm/proxy/_experimental/out/mcp/oauth/callback/index.html similarity index 100% rename from litellm/proxy/_experimental/out/mcp/oauth/callback.html rename to litellm/proxy/_experimental/out/mcp/oauth/callback/index.html diff --git a/litellm/proxy/_experimental/out/model-hub.html b/litellm/proxy/_experimental/out/model-hub/index.html similarity index 100% rename from litellm/proxy/_experimental/out/model-hub.html rename to litellm/proxy/_experimental/out/model-hub/index.html diff --git a/litellm/proxy/_experimental/out/model_hub.html b/litellm/proxy/_experimental/out/model_hub/index.html similarity index 100% rename from litellm/proxy/_experimental/out/model_hub.html rename to litellm/proxy/_experimental/out/model_hub/index.html diff --git a/litellm/proxy/_experimental/out/model_hub_table.html b/litellm/proxy/_experimental/out/model_hub_table/index.html similarity index 100% rename from litellm/proxy/_experimental/out/model_hub_table.html rename to litellm/proxy/_experimental/out/model_hub_table/index.html diff --git a/litellm/proxy/_experimental/out/models-and-endpoints.html b/litellm/proxy/_experimental/out/models-and-endpoints/index.html similarity index 100% rename from litellm/proxy/_experimental/out/models-and-endpoints.html rename to litellm/proxy/_experimental/out/models-and-endpoints/index.html diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding/index.html similarity index 100% rename from litellm/proxy/_experimental/out/onboarding.html rename to litellm/proxy/_experimental/out/onboarding/index.html diff --git a/litellm/proxy/_experimental/out/organizations.html b/litellm/proxy/_experimental/out/organizations/index.html similarity index 100% rename from litellm/proxy/_experimental/out/organizations.html rename to 
litellm/proxy/_experimental/out/organizations/index.html diff --git a/litellm/proxy/_experimental/out/playground.html b/litellm/proxy/_experimental/out/playground/index.html similarity index 100% rename from litellm/proxy/_experimental/out/playground.html rename to litellm/proxy/_experimental/out/playground/index.html diff --git a/litellm/proxy/_experimental/out/policies.html b/litellm/proxy/_experimental/out/policies/index.html similarity index 100% rename from litellm/proxy/_experimental/out/policies.html rename to litellm/proxy/_experimental/out/policies/index.html diff --git a/litellm/proxy/_experimental/out/settings/admin-settings.html b/litellm/proxy/_experimental/out/settings/admin-settings/index.html similarity index 100% rename from litellm/proxy/_experimental/out/settings/admin-settings.html rename to litellm/proxy/_experimental/out/settings/admin-settings/index.html diff --git a/litellm/proxy/_experimental/out/settings/logging-and-alerts.html b/litellm/proxy/_experimental/out/settings/logging-and-alerts/index.html similarity index 100% rename from litellm/proxy/_experimental/out/settings/logging-and-alerts.html rename to litellm/proxy/_experimental/out/settings/logging-and-alerts/index.html diff --git a/litellm/proxy/_experimental/out/settings/router-settings.html b/litellm/proxy/_experimental/out/settings/router-settings/index.html similarity index 100% rename from litellm/proxy/_experimental/out/settings/router-settings.html rename to litellm/proxy/_experimental/out/settings/router-settings/index.html diff --git a/litellm/proxy/_experimental/out/settings/ui-theme.html b/litellm/proxy/_experimental/out/settings/ui-theme/index.html similarity index 100% rename from litellm/proxy/_experimental/out/settings/ui-theme.html rename to litellm/proxy/_experimental/out/settings/ui-theme/index.html diff --git a/litellm/proxy/_experimental/out/skills.html b/litellm/proxy/_experimental/out/skills/index.html similarity index 100% rename from 
litellm/proxy/_experimental/out/skills.html rename to litellm/proxy/_experimental/out/skills/index.html diff --git a/litellm/proxy/_experimental/out/teams.html b/litellm/proxy/_experimental/out/teams/index.html similarity index 100% rename from litellm/proxy/_experimental/out/teams.html rename to litellm/proxy/_experimental/out/teams/index.html diff --git a/litellm/proxy/_experimental/out/test-key.html b/litellm/proxy/_experimental/out/test-key/index.html similarity index 100% rename from litellm/proxy/_experimental/out/test-key.html rename to litellm/proxy/_experimental/out/test-key/index.html diff --git a/litellm/proxy/_experimental/out/tools/mcp-servers.html b/litellm/proxy/_experimental/out/tools/mcp-servers/index.html similarity index 100% rename from litellm/proxy/_experimental/out/tools/mcp-servers.html rename to litellm/proxy/_experimental/out/tools/mcp-servers/index.html diff --git a/litellm/proxy/_experimental/out/tools/vector-stores.html b/litellm/proxy/_experimental/out/tools/vector-stores/index.html similarity index 100% rename from litellm/proxy/_experimental/out/tools/vector-stores.html rename to litellm/proxy/_experimental/out/tools/vector-stores/index.html diff --git a/litellm/proxy/_experimental/out/usage.html b/litellm/proxy/_experimental/out/usage/index.html similarity index 100% rename from litellm/proxy/_experimental/out/usage.html rename to litellm/proxy/_experimental/out/usage/index.html diff --git a/litellm/proxy/_experimental/out/users.html b/litellm/proxy/_experimental/out/users/index.html similarity index 100% rename from litellm/proxy/_experimental/out/users.html rename to litellm/proxy/_experimental/out/users/index.html diff --git a/litellm/proxy/_experimental/out/virtual-keys.html b/litellm/proxy/_experimental/out/virtual-keys/index.html similarity index 100% rename from litellm/proxy/_experimental/out/virtual-keys.html rename to litellm/proxy/_experimental/out/virtual-keys/index.html diff --git a/litellm/proxy/auth/auth_utils.py 
b/litellm/proxy/auth/auth_utils.py index 91c8f2dd7c..e395e03def 100644 --- a/litellm/proxy/auth/auth_utils.py +++ b/litellm/proxy/auth/auth_utils.py @@ -241,9 +241,30 @@ def is_request_body_safe( "Relevant Issue: https://huntr.com/bounties/4001e1a2-7b7a-4776-a3ae-e6692ec3d997", ) + # Recurse into nested config dicts whose values get unpacked as + # ``**kwargs`` into outbound API calls — same SSRF / credential + # exfil surface as the root, but historically not covered by this + # banned-param check. VERIA-6. + for nested_key in _NESTED_CONFIG_KEYS: + nested = request_body.get(nested_key) + if isinstance(nested, dict): + is_request_body_safe( + request_body=nested, + general_settings=general_settings, + llm_router=llm_router, + model=model, + ) + return True +# Config dicts whose entries are spread as ``**dict`` into outbound LLM +# API calls. ``litellm_embedding_config`` is consumed by the Milvus +# vector store transformer; future nested-config keys with the same +# threat shape should be added here. +_NESTED_CONFIG_KEYS: Tuple[str, ...] 
= ("litellm_embedding_config",) + + async def pre_db_read_auth_checks( request: Request, request_data: dict, diff --git a/tests/test_litellm/litellm_core_utils/test_url_utils.py b/tests/test_litellm/litellm_core_utils/test_url_utils.py index 4579c20321..ff91c41885 100644 --- a/tests/test_litellm/litellm_core_utils/test_url_utils.py +++ b/tests/test_litellm/litellm_core_utils/test_url_utils.py @@ -394,3 +394,61 @@ class TestHostAllowlist: monkeypatch.setattr(url_utils.socket, "getaddrinfo", fake) validate_url("http://internal.corp/") + + +# ── assert_same_origin ──────────────────────────────────────────────────────── + + +from litellm.litellm_core_utils.url_utils import assert_same_origin + + +def test_assert_same_origin_matches_scheme_host_port(): + """A polling URL on the same scheme + host + port as the api_base + passes — the upstream is trusted; the URL it returned points back at + the same upstream.""" + assert_same_origin( + "https://api.example.com/v1/operations/abc", + "https://api.example.com/v1/generate", + ) + + +def test_assert_same_origin_treats_default_ports_as_explicit(): + """``https://x/`` and ``https://x:443/`` are the same origin.""" + assert_same_origin("https://api.example.com/poll", "https://api.example.com:443/") + assert_same_origin("https://api.example.com:443/poll", "https://api.example.com/") + assert_same_origin("http://api.example.com/poll", "http://api.example.com:80/") + + +def test_assert_same_origin_rejects_different_host(): + with pytest.raises(SSRFError, match="host"): + assert_same_origin( + "https://attacker.example.com/poll", + "https://api.example.com/generate", + ) + + +def test_assert_same_origin_rejects_different_scheme(): + with pytest.raises(SSRFError, match="scheme"): + assert_same_origin( + "http://api.example.com/poll", "https://api.example.com/generate" + ) + + +def test_assert_same_origin_rejects_different_port(): + with pytest.raises(SSRFError, match="port"): + assert_same_origin( + 
"https://api.example.com:8443/poll", "https://api.example.com/generate" + ) + + +def test_assert_same_origin_rejects_non_http_scheme(): + """``file://`` polling URLs are rejected outright — the upstream + should never return a non-HTTP scheme.""" + with pytest.raises(SSRFError, match="scheme"): + assert_same_origin("file:///etc/passwd", "https://api.example.com/") + + +def test_assert_same_origin_case_insensitive_host(): + assert_same_origin( + "https://API.example.com/poll", "https://api.example.com/generate" + ) diff --git a/tests/test_litellm/llms/test_polling_url_origin_match.py b/tests/test_litellm/llms/test_polling_url_origin_match.py new file mode 100644 index 0000000000..f1f910bc73 --- /dev/null +++ b/tests/test_litellm/llms/test_polling_url_origin_match.py @@ -0,0 +1,177 @@ +""" +VERIA-51: polling URLs returned by upstream APIs (Azure DALL-E, +Azure Document Intelligence, Black Forest Labs) used to be followed +without origin validation. The handlers attached the operator's API +key to the polling request, so an attacker who could influence the +upstream response (or a compromised upstream) could redirect the proxy +to send credentials anywhere. + +These tests assert each handler now rejects polling URLs that don't +share an origin with the original request URL. +""" + +from unittest.mock import MagicMock, patch + +import httpx +import pytest + + +# Azure DALL-E sync + async paths route through ``assert_same_origin`` +# the same way as the cases below. The helper itself is unit-tested in +# ``tests/test_litellm/litellm_core_utils/test_url_utils.py``; the +# tests here exercise the wiring at sites with simpler signatures. 
+ + +# ── Azure Document Intelligence polling ─────────────────────────────────────── + + +def test_azure_di_sync_rejects_cross_origin_polling(): + from litellm.llms.azure_ai.ocr.document_intelligence.transformation import ( + AzureDocumentIntelligenceOCRConfig, + ) + + config = AzureDocumentIntelligenceOCRConfig() + + raw_response = MagicMock() + raw_response.status_code = 202 + raw_response.headers = { + "Operation-Location": "https://attacker.example.com/results/xyz", + } + raw_response.request = MagicMock() + raw_response.request.url = ( + "https://eastus.cognitiveservices.azure.com/documentintelligence/.../analyze" + ) + raw_response.request.headers = {"Ocp-Apim-Subscription-Key": "leak-me"} + + with pytest.raises(ValueError, match="rejected polling URL"): + config.transform_ocr_response( + model="azure-doc-intel", + raw_response=raw_response, + logging_obj=MagicMock(), + request_data={}, + optional_params={}, + litellm_params={}, + encoding=None, + response={}, + ) + + +# ── Black Forest Labs polling ───────────────────────────────────────────────── + + +def test_bfl_image_generation_sync_rejects_cross_origin_polling(): + from litellm.llms.black_forest_labs.image_generation.handler import ( + BlackForestLabsImageGeneration, + ) + + handler = BlackForestLabsImageGeneration() + + initial_response = MagicMock() + initial_response.status_code = 200 + initial_response.json = MagicMock( + return_value={"polling_url": "https://attacker.example.com/get_result"} + ) + initial_response.request = MagicMock() + initial_response.request.url = "https://api.bfl.ai/v1/flux-pro" + + sync_client = MagicMock() + sync_client.get = MagicMock() + + with pytest.raises(Exception, match="Rejected polling URL"): + handler._poll_for_result_sync( + initial_response=initial_response, + headers={"x-key": "secret"}, + sync_client=sync_client, + ) + + sync_client.get.assert_not_called() + + +@pytest.mark.asyncio +async def test_bfl_image_generation_async_rejects_cross_origin_polling(): + 
from litellm.llms.black_forest_labs.image_generation.handler import ( + BlackForestLabsImageGeneration, + ) + + handler = BlackForestLabsImageGeneration() + + initial_response = MagicMock() + initial_response.status_code = 200 + initial_response.json = MagicMock( + return_value={"polling_url": "https://attacker.example.com/get_result"} + ) + initial_response.request = MagicMock() + initial_response.request.url = "https://api.bfl.ai/v1/flux-pro" + + async_client = MagicMock() + async_client.get = MagicMock() + + with pytest.raises(Exception, match="Rejected polling URL"): + await handler._poll_for_result_async( + initial_response=initial_response, + headers={"x-key": "secret"}, + async_client=async_client, + ) + + async_client.get.assert_not_called() + + +def test_bfl_image_edit_sync_rejects_cross_origin_polling(): + from litellm.llms.black_forest_labs.image_edit.handler import ( + BlackForestLabsImageEdit, + ) + + handler = BlackForestLabsImageEdit() + + initial_response = MagicMock() + initial_response.status_code = 200 + initial_response.json = MagicMock( + return_value={"polling_url": "https://attacker.example.com/get_result"} + ) + initial_response.request = MagicMock() + initial_response.request.url = "https://api.bfl.ai/v1/flux-pro/edit" + + sync_client = MagicMock() + sync_client.get = MagicMock() + + with pytest.raises(Exception, match="Rejected polling URL"): + handler._poll_for_result_sync( + initial_response=initial_response, + headers={"x-key": "secret"}, + sync_client=sync_client, + ) + + sync_client.get.assert_not_called() + + +def test_bfl_image_generation_same_origin_polling_passes(): + """Sanity check: when the polling URL shares origin with the original + request, the origin check passes and polling proceeds.""" + from litellm.llms.black_forest_labs.image_generation.handler import ( + BlackForestLabsImageGeneration, + ) + + handler = BlackForestLabsImageGeneration() + + initial_response = MagicMock() + initial_response.status_code = 200 + 
initial_response.json = MagicMock( + return_value={"polling_url": "https://api.bfl.ai/v1/get_result?id=abc"} + ) + initial_response.request = MagicMock() + initial_response.request.url = "https://api.bfl.ai/v1/flux-pro" + + sync_client = MagicMock() + poll_response = MagicMock() + poll_response.status_code = 200 + poll_response.json = MagicMock(return_value={"status": "Ready"}) + sync_client.get = MagicMock(return_value=poll_response) + + result = handler._poll_for_result_sync( + initial_response=initial_response, + headers={"x-key": "secret"}, + sync_client=sync_client, + ) + + sync_client.get.assert_called_once() + assert result is poll_response diff --git a/tests/test_litellm/proxy/auth/test_auth_utils.py b/tests/test_litellm/proxy/auth/test_auth_utils.py index 91f300b88c..81826ec864 100644 --- a/tests/test_litellm/proxy/auth/test_auth_utils.py +++ b/tests/test_litellm/proxy/auth/test_auth_utils.py @@ -964,3 +964,103 @@ class TestIsRequestBodySafeBlocksEndpointTargetingFields: ) is True ) + + +# ── is_request_body_safe nested-config recursion (VERIA-6) ──────────────────── + + +class TestIsRequestBodySafeNestedConfig: + """The Milvus vector store transformer unpacks + ``litellm_embedding_config`` as ``**kwargs`` into ``litellm.embedding(...)`` + — same SSRF / credential-exfil surface as a top-level ``api_base`` in + the request body. 
``is_request_body_safe`` must recurse into this + nested dict so a banned param can't be smuggled in via nesting.""" + + def test_root_level_api_base_blocked_when_no_opt_in(self): + """Sanity check: pre-existing root-level enforcement still works.""" + with pytest.raises(ValueError, match="api_base"): + is_request_body_safe( + request_body={"api_base": "https://attacker.example.com"}, + general_settings={}, + llm_router=None, + model="gpt-4", + ) + + def test_nested_api_base_in_embedding_config_blocked(self): + """Smuggling ``api_base`` inside ``litellm_embedding_config`` is + the VERIA-6 bypass — must be blocked by the recursive check.""" + with pytest.raises(ValueError, match="api_base"): + is_request_body_safe( + request_body={ + "litellm_embedding_config": { + "api_base": "https://attacker.example.com", + "api_key": "leaked-key", + } + }, + general_settings={}, + llm_router=None, + model="milvus-store", + ) + + def test_nested_langfuse_host_in_embedding_config_blocked(self): + """The recursion uses the *full* banned-param list, not a special + subset — so any flag that's banned at the root is also banned + when nested.""" + with pytest.raises(ValueError, match="langfuse_host"): + is_request_body_safe( + request_body={ + "litellm_embedding_config": { + "langfuse_host": "https://attacker.example.com" + } + }, + general_settings={}, + llm_router=None, + model="milvus-store", + ) + + def test_nested_api_base_allowed_when_admin_opts_in(self): + """Admins who explicitly enable client-side credential passthrough + keep the existing escape hatch — same UX as for root-level.""" + assert ( + is_request_body_safe( + request_body={ + "litellm_embedding_config": { + "api_base": "https://my-azure.example.com" + } + }, + general_settings={"allow_client_side_credentials": True}, + llm_router=None, + model="milvus-store", + ) + is True + ) + + def test_safe_nested_config_accepted(self): + """A nested config without any banned params passes — there's no + false-positive on 
legitimate ``api_version`` / model params.""" + assert ( + is_request_body_safe( + request_body={ + "litellm_embedding_config": { + "api_version": "2024-02-15-preview", + } + }, + general_settings={}, + llm_router=None, + model="milvus-store", + ) + is True + ) + + def test_non_dict_nested_config_does_not_break_check(self): + """A bogus type for ``litellm_embedding_config`` (string, list, + None) must not crash the validator — it should just fall through.""" + assert ( + is_request_body_safe( + request_body={"litellm_embedding_config": "not-a-dict"}, + general_settings={}, + llm_router=None, + model="x", + ) + is True + )