diff --git a/tests/test_litellm/proxy/db/test_db_url_settings.py b/tests/test_litellm/proxy/db/test_db_url_settings.py index 9e348c3988..b2212068a5 100644 --- a/tests/test_litellm/proxy/db/test_db_url_settings.py +++ b/tests/test_litellm/proxy/db/test_db_url_settings.py @@ -24,29 +24,47 @@ def _apply() -> bool: return DatabaseURLSettings.from_env().apply_to_env() +_MANAGED_DB_ENV_VARS = ( + "IAM_TOKEN_DB_AUTH", + "DATABASE_URL", + "DATABASE_URL_READ_REPLICA", + "DATABASE_HOST", + "DATABASE_PORT", + "DATABASE_USER", + "DATABASE_USERNAME", + "DATABASE_NAME", + "DATABASE_SCHEMA", + "DATABASE_PASSWORD", + "DATABASE_HOST_READ_REPLICA", + "DATABASE_PORT_READ_REPLICA", + "DATABASE_USER_READ_REPLICA", + "DATABASE_USERNAME_READ_REPLICA", + "DATABASE_NAME_READ_REPLICA", + "DATABASE_SCHEMA_READ_REPLICA", + "DATABASE_PASSWORD_READ_REPLICA", +) + + @pytest.fixture(autouse=True) -def _scrub_db_env(monkeypatch): - """Remove every env var the model reads so tests start from a clean slate.""" - for var in ( - "IAM_TOKEN_DB_AUTH", - "DATABASE_URL", - "DATABASE_URL_READ_REPLICA", - "DATABASE_HOST", - "DATABASE_PORT", - "DATABASE_USER", - "DATABASE_USERNAME", - "DATABASE_NAME", - "DATABASE_SCHEMA", - "DATABASE_PASSWORD", - "DATABASE_HOST_READ_REPLICA", - "DATABASE_PORT_READ_REPLICA", - "DATABASE_USER_READ_REPLICA", - "DATABASE_USERNAME_READ_REPLICA", - "DATABASE_NAME_READ_REPLICA", - "DATABASE_SCHEMA_READ_REPLICA", - "DATABASE_PASSWORD_READ_REPLICA", - ): - monkeypatch.delenv(var, raising=False) +def _scrub_db_env(): + """Start each test from a clean slate and restore the original env afterward. + + ``apply_to_env`` writes ``DATABASE_URL`` straight into ``os.environ``, which + ``monkeypatch`` cannot undo. Snapshotting and restoring here keeps a + synthesized URL (e.g. ``writer.example.com``) from leaking into later tests + that read ``DATABASE_URL`` to decide whether to hit a real database. + """ + saved = {var: os.environ.get(var) for var in _MANAGED_DB_ENV_VARS} + for var in _MANAGED_DB_ENV_VARS: + os.environ.pop(var, None) + try: + yield + finally: + for var, value in saved.items(): + if value is None: + os.environ.pop(var, None) + else: + os.environ[var] = value def _stub_iam_token(token: str = "FAKE_TOKEN"): diff --git a/tests/unified_google_tests/base_google_genai_proxy_sdk_test.py b/tests/unified_google_tests/base_google_genai_proxy_sdk_test.py new file mode 100644 index 0000000000..1143183b86 --- /dev/null +++ b/tests/unified_google_tests/base_google_genai_proxy_sdk_test.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import os +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +import pytest + +try: + from google import genai + from google.genai import types + + GOOGLE_GENAI_SDK_AVAILABLE = True +except ImportError: + GOOGLE_GENAI_SDK_AVAILABLE = False + +MASTER_KEY = "sk-1234" +PROMPT = "Reply with only the single word: pong" + + +def has_vertex_credentials() -> bool: + credentials_file = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", "") + if credentials_file and os.path.isfile(credentials_file): + return True + return bool( + os.environ.get("VERTEX_AI_PRIVATE_KEY", "") + and os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "") + ) + + +def _make_client(proxy_url: str) -> "genai.Client": + return genai.Client( + api_key=MASTER_KEY, + http_options={"base_url": proxy_url}, + ) + + +def _generation_config() -> "types.GenerateContentConfig": + return types.GenerateContentConfig( + temperature=0, + top_p=0.95, + top_k=20, + ) + + +def _collect_stream_text(chunks: List["types.GenerateContentResponse"]) -> str: + return "".join(chunk.text for chunk in chunks if chunk.text) + + +class BaseGoogleGenAIProxySDKTest(ABC): + @property + @abstractmethod + def proxy_model_name(self) -> str: ... + + @property + @abstractmethod + def model_config(self) -> Dict[str, Any]: ... + + def _skip_reason_if_credentials_missing(self) -> Optional[str]: + model = self.model_config.get("model", "") + if model.startswith("gemini/"): + if not os.getenv("GEMINI_API_KEY"): + return "GEMINI_API_KEY not set — skipping Gemini proxy SDK tests" + return None + + if "vertex_ai" in model: + if has_vertex_credentials(): + return None + return "Vertex AI credentials not set — skipping Vertex AI proxy SDK tests" + + return f"Unsupported model for proxy SDK tests: {model}" + + def _require_proxy_sdk(self) -> None: + if not GOOGLE_GENAI_SDK_AVAILABLE: + pytest.skip("google-genai SDK not installed") + reason = self._skip_reason_if_credentials_missing() + if reason: + pytest.skip(reason) + + def test_proxy_genai_sdk_non_streaming(self, google_genai_proxy_url: str) -> None: + self._require_proxy_sdk() + + client = _make_client(google_genai_proxy_url) + response = client.models.generate_content( + model=self.proxy_model_name, + contents=types.Part.from_text(text=PROMPT), + config=_generation_config(), + ) + + assert response is not None + assert response.text is not None + assert len(response.text.strip()) > 0 + + def test_proxy_genai_sdk_streaming_completes_without_errors( + self, google_genai_proxy_url: str + ) -> None: + self._require_proxy_sdk() + + client = _make_client(google_genai_proxy_url) + stream = client.models.generate_content_stream( + model=self.proxy_model_name, + contents=types.Part.from_text(text=PROMPT), + config=_generation_config(), + ) + + chunks: List[types.GenerateContentResponse] = [] + stream_error: Optional[Exception] = None + + try: + for chunk in stream: + chunks.append(chunk) + except Exception as exc: + stream_error = exc + + assert ( + stream_error is None + ), f"Streaming raised {type(stream_error).__name__}: {stream_error}" + assert len(chunks) > 0, "Expected at least one streaming chunk" + assert _collect_stream_text(chunks).strip(), "Expected non-empty streamed text" + + def test_proxy_genai_sdk_streaming_dict_style( + self, google_genai_proxy_url: str + ) -> None: + self._require_proxy_sdk() + + client = _make_client(google_genai_proxy_url) + stream = client.models.generate_content_stream( + model=self.proxy_model_name, + contents={"text": PROMPT}, + config={ + "temperature": 0, + "top_p": 0.95, + "top_k": 20, + }, + ) + + chunks = list(stream) + assert len(chunks) > 0 + assert _collect_stream_text(chunks).strip() diff --git a/tests/unified_google_tests/conftest.py b/tests/unified_google_tests/conftest.py index 5b4f57b803..c6b3fb82d0 100644 --- a/tests/unified_google_tests/conftest.py +++ b/tests/unified_google_tests/conftest.py @@ -3,9 +3,18 @@ import asyncio import importlib import os +import socket import sys +import threading +import time +from pathlib import Path +from typing import Iterator, Tuple import pytest +import uvicorn +from dotenv import load_dotenv + +load_dotenv() sys.path.insert( 0, os.path.abspath("../..") @@ -28,6 +37,99 @@ from tests._vcr_conftest_common import ( # noqa: E402,F401 _verbose_state = VerboseReporterState() +PROXY_CONFIG_PATH = Path(__file__).parent / "google_genai_proxy_test_config.yaml" +PROXY_MASTER_KEY = "sk-1234" +PROXY_START_TIMEOUT_S = 30.0 + + +def _start_proxy_server( + config_path: str, +) -> Tuple[str, uvicorn.Server, threading.Thread, socket.socket]: + from litellm.proxy.proxy_server import ( + app as proxy_app, + cleanup_router_config_variables, + initialize, + ) + + cleanup_router_config_variables() + + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.bind(("127.0.0.1", 0)) + host, port = sock.getsockname() + + config = uvicorn.Config(proxy_app, host=host, port=port, log_level="warning") + server = uvicorn.Server(config) + + def _run() -> None: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(initialize(config=config_path, debug=True)) + loop.run_until_complete(server.serve(sockets=[sock])) + + thread = threading.Thread(target=_run, daemon=True) + thread.start() + + start_time = time.time() + while not server.started: + if not thread.is_alive(): + raise RuntimeError("LiteLLM proxy failed to start") + if time.time() - start_time > PROXY_START_TIMEOUT_S: + raise TimeoutError("LiteLLM proxy did not start in time") + time.sleep(0.05) + + return f"http://{host}:{port}", server, thread, sock + + +@pytest.fixture(scope="session") +def google_genai_proxy_url() -> Iterator[str]: + from base_google_genai_proxy_sdk_test import has_vertex_credentials + from base_google_test import load_vertex_ai_credentials + + saved_env = { + key: os.environ.get(key) + for key in ( + "DATABASE_URL", + "DIRECT_URL", + "LITELLM_MASTER_KEY", + "STORE_MODEL_IN_DB", + "GOOGLE_APPLICATION_CREDENTIALS", + ) + } + temp_credentials_path: str | None = None + os.environ.pop("DATABASE_URL", None) + os.environ.pop("DIRECT_URL", None) + os.environ["LITELLM_MASTER_KEY"] = PROXY_MASTER_KEY + os.environ["STORE_MODEL_IN_DB"] = "False" + + if has_vertex_credentials(): + credentials_file = os.environ.get("GOOGLE_APPLICATION_CREDENTIALS", "") + if not (credentials_file and os.path.isfile(credentials_file)): + vertex_credentials_path = load_vertex_ai_credentials( + model="vertex_ai/gemini-2.5-flash-lite" + ) + if vertex_credentials_path: + temp_credentials_path = vertex_credentials_path + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = vertex_credentials_path + + server_url, server, thread, sock = _start_proxy_server(str(PROXY_CONFIG_PATH)) + try: + yield server_url + finally: + server.should_exit = True + thread.join(timeout=10) + sock.close() + if temp_credentials_path: + try: + os.unlink(temp_credentials_path) + except OSError: + pass + for key, value in saved_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value + @pytest.fixture(scope="session") def event_loop(): @@ -40,7 +142,7 @@ def event_loop(): @pytest.fixture(scope="function", autouse=True) -def setup_and_teardown(): +def setup_and_teardown(request): """ This fixture reloads litellm before every function. To speed up testing by removing callbacks being chained. """ @@ -50,7 +152,8 @@ def setup_and_teardown(): import litellm - importlib.reload(litellm) + if "google_genai_proxy_url" not in request.fixturenames: + importlib.reload(litellm) loop = asyncio.get_event_loop_policy().new_event_loop() asyncio.set_event_loop(loop) @@ -95,7 +198,14 @@ def pytest_runtest_logreport(report): def pytest_collection_modifyitems(config, items): - apply_vcr_auto_marker_to_items(items) + apply_vcr_auto_marker_to_items( + items, + skip_nodeid_suffixes=( + "test_proxy_genai_sdk_non_streaming", + "test_proxy_genai_sdk_streaming_completes_without_errors", + "test_proxy_genai_sdk_streaming_dict_style", + ), + ) # Separate tests in 'test_amazing_proxy_custom_logger.py' and other tests custom_logger_tests = [ diff --git a/tests/unified_google_tests/google_genai_proxy_test_config.yaml b/tests/unified_google_tests/google_genai_proxy_test_config.yaml new file mode 100644 index 0000000000..9913c05d43 --- /dev/null +++ b/tests/unified_google_tests/google_genai_proxy_test_config.yaml @@ -0,0 +1,16 @@ +model_list: + - model_name: gemini-2.5-flash-lite + litellm_params: + model: gemini/gemini-2.5-flash-lite + api_key: os.environ/GEMINI_API_KEY + + - model_name: vertex-gemini-2.5-flash-lite + litellm_params: + model: vertex_ai/gemini-2.5-flash-lite + +general_settings: + master_key: sk-1234 + store_model_in_db: false + +litellm_settings: + drop_params: true diff --git a/tests/unified_google_tests/test_google_ai_studio.py b/tests/unified_google_tests/test_google_ai_studio.py index 2d80f4bc45..afe237a4e5 100644 --- a/tests/unified_google_tests/test_google_ai_studio.py +++ b/tests/unified_google_tests/test_google_ai_studio.py @@ -1,3 +1,4 @@ +from base_google_genai_proxy_sdk_test import BaseGoogleGenAIProxySDKTest from base_google_test import BaseGoogleGenAITest import sys import os @@ -11,7 +12,7 @@ import unittest.mock import json -class TestGoogleGenAIStudio(BaseGoogleGenAITest): +class TestGoogleGenAIStudio(BaseGoogleGenAITest, BaseGoogleGenAIProxySDKTest): """Test Google GenAI Studio""" @property @@ -20,6 +21,10 @@ class TestGoogleGenAIStudio(BaseGoogleGenAITest): "model": "gemini/gemini-2.5-flash-lite", } + @property + def proxy_model_name(self) -> str: + return "gemini-2.5-flash-lite" + @pytest.mark.asyncio async def test_mock_stream_generate_content_with_tools(): diff --git a/tests/unified_google_tests/test_vertex_ai_native.py b/tests/unified_google_tests/test_vertex_ai_native.py index c390d5e728..640157bc33 100644 --- a/tests/unified_google_tests/test_vertex_ai_native.py +++ b/tests/unified_google_tests/test_vertex_ai_native.py @@ -1,7 +1,8 @@ +from base_google_genai_proxy_sdk_test import BaseGoogleGenAIProxySDKTest from base_google_test import BaseGoogleGenAITest -class TestVertexAIGenerateContent(BaseGoogleGenAITest): +class TestVertexAIGenerateContent(BaseGoogleGenAITest, BaseGoogleGenAIProxySDKTest): """Test Vertex AI""" @property @@ -9,3 +10,7 @@ class TestVertexAIGenerateContent(BaseGoogleGenAITest): return { "model": "vertex_ai/gemini-2.5-flash-lite", } + + @property + def proxy_model_name(self) -> str: + return "vertex-gemini-2.5-flash-lite"