fix: preserve safe provider model path segments
This commit is contained in:
parent
d4dd865b1a
commit
1d7778673a
@ -66,6 +66,22 @@ def encode_url_path_segment(value: Any, *, field_name: str = "path parameter") -
|
||||
return quote(value_str, safe="")
|
||||
|
||||
|
||||
def encode_url_path_segments(value: Any, *, field_name: str = "path") -> str:
|
||||
"""Percent-encode a user-controlled URL path made of multiple segments."""
|
||||
if value is None:
|
||||
raise ValueError(f"{field_name} is required")
|
||||
|
||||
value_str = str(value)
|
||||
if value_str == "":
|
||||
raise ValueError(f"{field_name} is required")
|
||||
|
||||
encoded_segments = []
|
||||
for segment in value_str.split("/"):
|
||||
encoded_segments.append(encode_url_path_segment(segment, field_name=field_name))
|
||||
|
||||
return "/".join(encoded_segments)
|
||||
|
||||
|
||||
def _is_blocked_ip(addr: str) -> bool:
|
||||
"""Return True for any IP not safe to reach from a user-supplied URL.
|
||||
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
|
||||
from litellm.litellm_core_utils.url_utils import encode_url_path_segment
|
||||
from litellm.litellm_core_utils.url_utils import encode_url_path_segments
|
||||
from litellm.litellm_core_utils.exception_mapping_utils import exception_type
|
||||
from litellm.litellm_core_utils.logging_utils import track_llm_api_timing
|
||||
from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
|
||||
@ -150,7 +150,7 @@ class BytezChatConfig(BaseConfig):
|
||||
litellm_params: dict,
|
||||
stream: Optional[bool] = None,
|
||||
) -> str:
|
||||
encoded_model = encode_url_path_segment(model, field_name="model")
|
||||
encoded_model = encode_url_path_segments(model, field_name="model")
|
||||
return f"{API_BASE}/{encoded_model}"
|
||||
|
||||
def transform_request(
|
||||
|
||||
@ -5,7 +5,7 @@ from typing import AsyncIterator, Iterator, List, Optional, Union
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.url_utils import encode_url_path_segment
|
||||
from litellm.litellm_core_utils.url_utils import encode_url_path_segments
|
||||
from litellm.llms.base_llm.base_model_iterator import BaseModelResponseIterator
|
||||
from litellm.llms.base_llm.chat.transformation import (
|
||||
BaseConfig,
|
||||
@ -90,7 +90,7 @@ class CloudflareChatConfig(BaseConfig):
|
||||
api_base = (
|
||||
f"https://api.cloudflare.com/client/v4/accounts/{account_id}/ai/run/"
|
||||
)
|
||||
encoded_model = encode_url_path_segment(model, field_name="model")
|
||||
encoded_model = encode_url_path_segments(model, field_name="model")
|
||||
return api_base + encoded_model
|
||||
|
||||
def get_supported_openai_params(self, model: str) -> List[str]:
|
||||
|
||||
@ -8,6 +8,7 @@ from litellm.litellm_core_utils.url_utils import (
|
||||
SSRFError,
|
||||
_is_blocked_ip,
|
||||
encode_url_path_segment,
|
||||
encode_url_path_segments,
|
||||
validate_url,
|
||||
)
|
||||
|
||||
@ -91,11 +92,21 @@ class TestEncodeUrlPathSegment:
|
||||
|
||||
assert encoded == "..%2F..%2Fv1%2Ffiles%3Flimit%3D1%23frag"
|
||||
|
||||
def test_encodes_path_segments_without_collapsing_valid_model_paths(self):
|
||||
encoded = encode_url_path_segments("@cf/meta/model?debug=1")
|
||||
|
||||
assert encoded == "%40cf/meta/model%3Fdebug%3D1"
|
||||
|
||||
@pytest.mark.parametrize("value", ["", ".", "..", None])
|
||||
def test_rejects_empty_and_dot_segments(self, value):
|
||||
with pytest.raises(ValueError):
|
||||
encode_url_path_segment(value, field_name="resource_id")
|
||||
|
||||
@pytest.mark.parametrize("value", ["../model", "model/../other", "/model"])
|
||||
def test_rejects_dot_segments_in_multi_segment_paths(self, value):
|
||||
with pytest.raises(ValueError):
|
||||
encode_url_path_segments(value, field_name="model")
|
||||
|
||||
|
||||
class TestValidateUrl:
|
||||
def test_blocks_loopback(self):
|
||||
|
||||
@ -2,7 +2,6 @@ import os
|
||||
import sys
|
||||
import pytest
|
||||
import json
|
||||
from urllib.parse import quote
|
||||
|
||||
# Adds the parent directory to the system path
|
||||
sys.path.insert(0, os.path.abspath("../../../../.."))
|
||||
@ -70,7 +69,7 @@ class TestBytezChatConfig:
|
||||
}
|
||||
|
||||
# Mock the HTTP request
|
||||
respx_mock.post(f"{API_BASE}/{quote(TEST_MODEL_NAME, safe='')}").respond(
|
||||
respx_mock.post(f"{API_BASE}/{TEST_MODEL_NAME}").respond(
|
||||
json={
|
||||
"error": None,
|
||||
"output": output,
|
||||
@ -91,15 +90,25 @@ class TestBytezChatConfig:
|
||||
def test_get_complete_url_encodes_model_path_segment(self):
|
||||
config = BytezChatConfig()
|
||||
|
||||
url = config.get_complete_url(
|
||||
api_base=API_BASE,
|
||||
api_key=TEST_API_KEY,
|
||||
model="../../models/other?x=1#frag",
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
assert (
|
||||
config.get_complete_url(
|
||||
api_base=API_BASE,
|
||||
api_key=TEST_API_KEY,
|
||||
model="google/gemma?x=1#frag",
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
)
|
||||
== f"{API_BASE}/google/gemma%3Fx%3D1%23frag"
|
||||
)
|
||||
|
||||
assert url == f"{API_BASE}/..%2F..%2Fmodels%2Fother%3Fx%3D1%23frag"
|
||||
with pytest.raises(ValueError, match="dot path segment"):
|
||||
config.get_complete_url(
|
||||
api_base=API_BASE,
|
||||
api_key=TEST_API_KEY,
|
||||
model="../../models/other",
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
)
|
||||
|
||||
def test_bytez_messages_adaptation(self):
|
||||
cases = [
|
||||
|
||||
@ -1,18 +1,27 @@
|
||||
import pytest
|
||||
|
||||
from litellm.llms.cloudflare.chat.transformation import CloudflareChatConfig
|
||||
|
||||
|
||||
def test_get_complete_url_encodes_model_path_segment():
|
||||
config = CloudflareChatConfig()
|
||||
|
||||
url = config.get_complete_url(
|
||||
api_base="https://api.cloudflare.com/client/v4/accounts/acct/ai/run/",
|
||||
api_key="cf-key",
|
||||
model="../../accounts/other?x=1#frag",
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
assert (
|
||||
config.get_complete_url(
|
||||
api_base="https://api.cloudflare.com/client/v4/accounts/acct/ai/run/",
|
||||
api_key="cf-key",
|
||||
model="@cf/meta/llama?x=1#frag",
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
)
|
||||
== "https://api.cloudflare.com/client/v4/accounts/acct/ai/run/%40cf/meta/llama%3Fx%3D1%23frag"
|
||||
)
|
||||
|
||||
assert (
|
||||
url
|
||||
== "https://api.cloudflare.com/client/v4/accounts/acct/ai/run/..%2F..%2Faccounts%2Fother%3Fx%3D1%23frag"
|
||||
)
|
||||
with pytest.raises(ValueError, match="dot path segment"):
|
||||
config.get_complete_url(
|
||||
api_base="https://api.cloudflare.com/client/v4/accounts/acct/ai/run/",
|
||||
api_key="cf-key",
|
||||
model="../../accounts/other",
|
||||
optional_params={},
|
||||
litellm_params={},
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user