[internal copy of #29003] fix(vertex_ai): use user-supplied api_base as is for Model Garden OpenAI-compat path (#29530)
* fix(vertex_ai): use user-supplied api_base as is for Model Garden OpenAI-compat path
* chore(tests): url assertions and outputs
* fix(tests): fixing reference to unused test
* fix(aiohttp): drop octet-stream content-type on bodyless requests
The aiohttp transport forwarded httpx's empty request body straight
to aiohttp, which attaches a default Content-Type: application/octet-stream
for any bytes payload. Bodyless requests such as DELETE /responses/{id} then
hit OpenAI with that header and were rejected with unsupported_content_type,
breaking the e2e_openai_endpoints test_basic_response check. Coercing an
empty body to None makes aiohttp behave like the httpx transport and send no
content-type for bodyless requests.
---------
Co-authored-by: Steven Kessler <9701252+stvnksslr@users.noreply.github.com>
This commit is contained in:
parent
6a9f542f81
commit
a5ccd96152
@ -256,7 +256,10 @@ class LiteLLMAiohttpTransport(AiohttpTransport):
|
||||
from yarl import URL as YarlURL
|
||||
|
||||
try:
|
||||
data = request.content
|
||||
# Coerce an empty body to None so aiohttp does not attach a
|
||||
# `Content-Type: application/octet-stream` header for bodyless
|
||||
# requests (e.g. DELETE /responses/{id}), which upstream APIs reject.
|
||||
data = request.content or None
|
||||
except httpx.RequestNotRead:
|
||||
data = request.stream # type: ignore
|
||||
request.headers.pop("transfer-encoding", None) # handled by aiohttp
|
||||
|
||||
@ -114,33 +114,18 @@ class VertexAIModelGardenModels(VertexBase):
|
||||
openai_like_chat_completions = OpenAILikeChatHandler()
|
||||
|
||||
## CONSTRUCT API BASE
|
||||
# Skip _check_custom_proxy: its ":verb" URL construction corrupts a
|
||||
# user-supplied api_base (e.g. Vertex MG dedicated endpoint), and
|
||||
# OpenAILikeChatHandler already appends "/chat/completions".
|
||||
stream: bool = optional_params.get("stream", False) or False
|
||||
optional_params["stream"] = stream
|
||||
default_api_base = create_vertex_url(
|
||||
vertex_location=vertex_location or "us-central1",
|
||||
vertex_project=vertex_project or project_id,
|
||||
stream=stream,
|
||||
model=model,
|
||||
)
|
||||
|
||||
if len(default_api_base.split(":")) > 1:
|
||||
endpoint = default_api_base.split(":")[-1]
|
||||
else:
|
||||
endpoint = ""
|
||||
|
||||
_, api_base = self._check_custom_proxy(
|
||||
api_base=api_base,
|
||||
custom_llm_provider="vertex_ai",
|
||||
gemini_api_key=None,
|
||||
endpoint=endpoint,
|
||||
stream=stream,
|
||||
auth_header=None,
|
||||
url=default_api_base,
|
||||
model=model,
|
||||
vertex_project=vertex_project or project_id,
|
||||
vertex_location=vertex_location or "us-central1",
|
||||
vertex_api_version="v1beta1",
|
||||
)
|
||||
if api_base is None:
|
||||
api_base = create_vertex_url(
|
||||
vertex_location=vertex_location or "us-central1",
|
||||
vertex_project=vertex_project or project_id,
|
||||
stream=stream,
|
||||
model=model,
|
||||
)
|
||||
# Publisher/catalog models: model id must be sent in the JSON body (OpenAPI route).
|
||||
# Single-segment endpoint ids: model is encoded in the URL path; body model stays empty.
|
||||
if not _vertex_model_garden_model_id_in_json_body(model):
|
||||
|
||||
@ -262,6 +262,61 @@ async def test_handle_async_request_uses_env_proxy(monkeypatch):
|
||||
assert captured["proxy"] == proxy_url
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_async_request_empty_body_sends_no_data():
    """
    A bodyless request (e.g. DELETE /responses/{id}) must reach aiohttp with
    data=None. Passing the empty `b""` httpx content makes aiohttp attach a
    `Content-Type: application/octet-stream` header, which providers like
    OpenAI reject with `unsupported_content_type`.
    """
    # Records the `data` kwarg the transport hands to the session on each call.
    captured = {}

    class FakeSession:
        """Minimal session double: remembers `data` and returns a canned 200."""

        def __init__(self):
            self.closed = False
            try:
                self._loop = asyncio.get_running_loop()
            except RuntimeError:
                self._loop = None

        def request(self, *args, **kwargs):
            captured["data"] = kwargs.get("data")

            class Resp:
                status = 200
                headers = {}

                async def __aenter__(self):
                    return self

                async def __aexit__(self, exc_type, exc, tb):
                    pass

                @property
                def content(self):
                    class C:
                        async def iter_chunked(self, size):
                            yield b""

                    return C()

            return Resp()

    transport = LiteLLMAiohttpTransport(client=lambda: FakeSession())  # type: ignore

    # Bodyless request: the empty httpx body must be coerced to None.
    empty_request = httpx.Request("DELETE", "http://example.com/responses/resp_123")
    await transport.handle_async_request(empty_request)
    assert captured["data"] is None

    # Request with a body: the payload must be forwarded untouched.
    body_request = httpx.Request(
        "POST", "http://example.com/responses", json={"input": "ping"}
    )
    await transport.handle_async_request(body_request)
    assert captured["data"] == body_request.content
    assert captured["data"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_async_request_uses_env_proxy_per_url(monkeypatch):
|
||||
"""Aiohttp transport should honor HTTP(S)_PROXY env vars unless NO_PROXY matches"""
|
||||
|
||||
@ -1,7 +1,13 @@
|
||||
"""Vertex Model Garden: OpenAPI base URL for publisher/model ids vs per-endpoint path."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm.llms.vertex_ai.vertex_model_garden.main import (
|
||||
_vertex_model_garden_model_id_in_json_body,
|
||||
create_vertex_url,
|
||||
@ -37,5 +43,198 @@ def test_create_vertex_url_openapi_vs_deployed_endpoint(
|
||||
|
||||
|
||||
def test_model_id_in_json_body_heuristic() -> None:
    """Check the helper's verdict for a multi-segment id and a numeric id."""
    # Multi-segment publisher/model id -> helper returns True.
    assert (
        _vertex_model_garden_model_id_in_json_body("xai/grok-4.1-fast-reasoning")
        is True
    )
    # Single-segment numeric endpoint id -> helper returns False.
    assert _vertex_model_garden_model_id_in_json_body("5464397967697903616") is False
|
||||
|
||||
|
||||
@pytest.fixture
def _reset_litellm_http_client_cache():
    """Flush litellm's in-memory LLM client cache before and after the test."""
    from litellm import in_memory_llm_clients_cache

    in_memory_llm_clients_cache.flush_cache()
    yield
    in_memory_llm_clients_cache.flush_cache()
|
||||
|
||||
|
||||
@pytest.fixture
def clean_vertex_env():
    """Run the test with the listed Google/Vertex environment variables unset.

    The variables are removed before the test and their original values are
    put back afterwards. Any of the listed variables that the test itself
    sets is cleared on teardown so it cannot leak into later tests.
    """
    env_vars_to_clear = [
        "GOOGLE_APPLICATION_CREDENTIALS",
        "GOOGLE_CLOUD_PROJECT",
        "VERTEXAI_PROJECT",
        "VERTEXAI_LOCATION",
        "VERTEXAI_CREDENTIALS",
        "VERTEX_PROJECT",
        "VERTEX_LOCATION",
        "VERTEX_AI_PROJECT",
    ]

    # Remove each variable, remembering only the ones that were actually set.
    saved_env = {}
    for var in env_vars_to_clear:
        value = os.environ.pop(var, None)
        if value is not None:
            saved_env[var] = value

    try:
        yield
    finally:
        # Drop anything the test set, then restore the pre-test values.
        for var in env_vars_to_clear:
            os.environ.pop(var, None)
        os.environ.update(saved_env)
|
||||
|
||||
|
||||
def _mock_chat_completion_response(model_in_response: str) -> MagicMock:
|
||||
response = MagicMock()
|
||||
response.status_code = 200
|
||||
response.headers = {}
|
||||
response.json.return_value = {
|
||||
"id": "chatcmpl-test",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": model_in_response,
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": "hi"},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2},
|
||||
}
|
||||
return response
|
||||
|
||||
|
||||
async def _invoke_model_garden_completion(
    *,
    model: str,
    api_base,
    mock_response: MagicMock,
):
    """Drive litellm.acompletion through the Vertex Model Garden route and return
    the patched AsyncHTTPHandler so callers can inspect the outbound HTTP call.

    Args:
        model: Model string passed to ``litellm.acompletion``.
        api_base: Optional base URL; only forwarded when it is not ``None``.
        mock_response: Object returned by the patched handler's ``post``.
    """
    # Stub the `vertexai` package in sys.modules for the duration of the call.
    mock_vertexai = MagicMock()
    mock_vertexai.preview = MagicMock()
    mock_vertexai.preview.language_models = MagicMock()

    with (
        patch(
            "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler"
        ) as mock_http_handler,
        patch(
            "litellm.llms.vertex_ai.vertex_model_garden.main.VertexAIModelGardenModels._ensure_access_token",
            return_value=("fake-token", "test-project"),
        ),
        patch.dict(
            sys.modules,
            {"vertexai": mock_vertexai, "vertexai.preview": mock_vertexai.preview},
        ),
    ):
        mock_http_handler.return_value.post = AsyncMock(return_value=mock_response)

        kwargs = {
            "model": model,
            "messages": [{"role": "user", "content": "hello"}],
            "vertex_ai_location": "us-central1",
            "vertex_ai_project": "test-project",
        }
        # Omit api_base entirely when the caller did not supply one.
        if api_base is not None:
            kwargs["api_base"] = api_base

        await litellm.acompletion(**kwargs)

    return mock_http_handler
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_user_supplied_api_base_passes_through_unchanged(
    clean_vertex_env, _reset_litellm_http_client_cache
):
    """A user-supplied api_base must reach the OpenAI-like handler unchanged,
    with only its own '/chat/completions' suffix appended."""
    user_api_base = "https://my-endpoint.example.com/v1"
    mock_http_handler = await _invoke_model_garden_completion(
        model="vertex_ai/openai/5464397967697903616",
        api_base=user_api_base,
        mock_response=_mock_chat_completion_response("5464397967697903616"),
    )

    mock_http_handler.return_value.post.assert_called_once()
    call_args = mock_http_handler.return_value.post.call_args
    # The URL may have been passed by keyword or as the first positional arg.
    called_url = call_args.kwargs.get("url") or call_args.args[0]
    request_body = json.loads(call_args.kwargs["data"])

    assert called_url == f"{user_api_base}/chat/completions"
    # No ":verb" suffix may be appended after the scheme.
    assert ":" not in called_url.replace("https://", "")
    assert "aiplatform.googleapis.com" not in called_url
    assert request_body["model"] == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_user_supplied_api_base_passthrough_for_publisher_model(
    clean_vertex_env, _reset_litellm_http_client_cache
):
    """User-supplied api_base is forwarded unchanged for publisher/catalog
    models too; the publisher model id stays in the JSON body."""
    user_api_base = "https://my-endpoint.example.com/v1"
    mock_http_handler = await _invoke_model_garden_completion(
        model="vertex_ai/openai/xai/grok-4.1-fast-reasoning",
        api_base=user_api_base,
        mock_response=_mock_chat_completion_response("xai/grok-4.1-fast-reasoning"),
    )

    mock_http_handler.return_value.post.assert_called_once()
    call_args = mock_http_handler.return_value.post.call_args
    # The URL may have been passed by keyword or as the first positional arg.
    called_url = call_args.kwargs.get("url") or call_args.args[0]
    request_body = json.loads(call_args.kwargs["data"])

    assert called_url == f"{user_api_base}/chat/completions"
    assert "aiplatform.googleapis.com" not in called_url
    assert request_body["model"] == "xai/grok-4.1-fast-reasoning"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_default_api_base_when_none_provided_single_segment(
    clean_vertex_env, _reset_litellm_http_client_cache
):
    """With no api_base, single-segment endpoint ids must hit the per-endpoint
    Vertex URL and send an empty model field in the body."""
    mock_http_handler = await _invoke_model_garden_completion(
        model="vertex_ai/openai/5464397967697903616",
        api_base=None,
        mock_response=_mock_chat_completion_response("5464397967697903616"),
    )

    mock_http_handler.return_value.post.assert_called_once()
    call_args = mock_http_handler.return_value.post.call_args
    # The URL may have been passed by keyword or as the first positional arg.
    called_url = call_args.kwargs.get("url") or call_args.args[0]
    request_body = json.loads(call_args.kwargs["data"])

    assert called_url == (
        "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/"
        "test-project/locations/us-central1/endpoints/5464397967697903616/chat/completions"
    )
    assert request_body["model"] == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_default_api_base_when_none_provided_publisher_model(
    clean_vertex_env, _reset_litellm_http_client_cache
):
    """With no api_base, publisher/catalog models must hit the shared OpenAPI
    URL and send the publisher model id in the body."""
    mock_http_handler = await _invoke_model_garden_completion(
        model="vertex_ai/openai/xai/grok-4.1-fast-reasoning",
        api_base=None,
        mock_response=_mock_chat_completion_response("xai/grok-4.1-fast-reasoning"),
    )

    mock_http_handler.return_value.post.assert_called_once()
    call_args = mock_http_handler.return_value.post.call_args
    # The URL may have been passed by keyword or as the first positional arg.
    called_url = call_args.kwargs.get("url") or call_args.args[0]
    request_body = json.loads(call_args.kwargs["data"])

    assert called_url == (
        "https://us-central1-aiplatform.googleapis.com/v1/projects/"
        "test-project/locations/us-central1/endpoints/openapi/chat/completions"
    )
    assert request_body["model"] == "xai/grok-4.1-fast-reasoning"
|
||||
|
||||
Loading…
Reference in New Issue
Block a user