litellm/tests/test_litellm/conftest.py
2026-04-17 13:02:59 -07:00

475 lines
17 KiB
Python

# conftest.py - IMPROVED VERSION
#
# Key changes:
# 1. Moved LiteLLM state isolation from 'module' scope to 'function' scope for better isolation (module reload itself stays module-scoped and is skipped under pytest-xdist)
# 2. Made cache flushing happen per-function instead of per-module
# 3. Removed manual event loop creation (let pytest-asyncio handle it)
# 4. Added proper cleanup in fixtures
# 5. Added worker-specific isolation for parallel execution
import importlib
import os
import sys
from pathlib import Path
import pytest
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import asyncio
import litellm
from litellm._logging import ALL_LOGGERS
from litellm.litellm_core_utils.prompt_templates import (
image_handling as image_handling_module,
)
from litellm.llms.custom_httpx.async_client_cleanup import (
close_litellm_async_clients,
)
from litellm.proxy.db import tool_registry_writer as tool_registry_writer_module
def _reset_module_level_aws_auth_caches():
    """
    Drop AWS auth state that lives at module level and can leak across tests.

    Bedrock/SageMaker handlers are created once at import time and cache the
    credentials they resolve on the handler instance. A test that resolves an
    invalid or different auth flow would otherwise hand that cached state to
    later tests, letting them bypass their own monkeypatched environment.

    Every attribute of the listed modules that exposes an ``iam_cache`` with a
    callable ``flush_cache`` gets flushed, and boto3's default session is
    cleared when boto3 can be imported.
    """
    handler_modules = (
        "litellm.main",
        "litellm.files.main",
        "litellm.rerank_api.main",
        "litellm.realtime_api.main",
    )
    for dotted_name in handler_modules:
        try:
            loaded = importlib.import_module(dotted_name)
        except Exception:
            # Best effort: a module that cannot be imported has nothing to reset.
            continue

        for member_name in dir(loaded):
            member = getattr(loaded, member_name)
            credential_cache = getattr(member, "iam_cache", None)
            flusher = (
                None
                if credential_cache is None
                else getattr(credential_cache, "flush_cache", None)
            )
            if callable(flusher):
                flusher()

    try:
        import boto3

        boto3.DEFAULT_SESSION = None
    except Exception:
        # boto3 is optional here; nothing to clear when it is unavailable.
        pass
@pytest.fixture(scope="session")
def isolated_aws_credentials_dir(tmp_path_factory):
    """
    Create one empty AWS ``credentials`` and ``config`` file per test session.

    Returns a dict with the keys ``"credentials"`` and ``"config"`` mapping to
    the string paths of the two empty files inside a fresh temp directory.
    """
    aws_root = Path(tmp_path_factory.mktemp("aws-config"))
    file_paths = {}
    for file_name in ("credentials", "config"):
        target = aws_root / file_name
        target.write_text("", encoding="utf-8")
        file_paths[file_name] = str(target)
    return file_paths
@pytest.fixture(scope="function", autouse=True)
def isolate_host_aws_config(monkeypatch, isolated_aws_credentials_dir):
    """Prevent botocore from reading host AWS profiles during unit tests."""
    # Point the shared credentials/config lookups at the empty session files
    # and switch off the EC2 metadata lookup.
    forced_values = {
        "AWS_SHARED_CREDENTIALS_FILE": isolated_aws_credentials_dir["credentials"],
        "AWS_CONFIG_FILE": isolated_aws_credentials_dir["config"],
        "AWS_EC2_METADATA_DISABLED": "true",
    }
    for variable, value in forced_values.items():
        monkeypatch.setenv(variable, value)

    # Remove every host-provided variable that could select a profile,
    # credential source, or region.
    scrubbed_variables = (
        "AWS_PROFILE",
        "AWS_DEFAULT_PROFILE",
        "AWS_CONTAINER_CREDENTIALS_FULL_URI",
        "AWS_CONTAINER_CREDENTIALS_RELATIVE_URI",
        "AWS_SESSION_TOKEN",
        "AWS_ROLE_ARN",
        "AWS_WEB_IDENTITY_TOKEN_FILE",
        "AWS_BEARER_TOKEN_BEDROCK",
        "AWS_REGION_NAME",
        "AWS_DEFAULT_REGION",
    )
    for variable in scrubbed_variables:
        monkeypatch.delenv(variable, raising=False)
# Strong references to cleanup tasks scheduled on an already-running loop.
# asyncio itself only holds weak references to tasks, so without this set a
# scheduled cleanup could be garbage-collected before it ever runs.
_PENDING_CLEANUP_TASKS = set()


def _run_coroutine_if_needed(result):
    """
    Drive ``result`` to completion when it is a coroutine; ignore anything else.

    Args:
        result: Return value of a ``close()``-style call. Only coroutine
            objects are acted on.

    Behaviour:
        - No running event loop: the coroutine is executed with
          ``asyncio.run``. Any exception it raises is swallowed because this
          is best-effort cleanup.
        - A loop is already running in this thread (``asyncio.run`` raises
          ``RuntimeError``): the coroutine is scheduled on that loop as a task
          and a reference is kept until the task finishes.

    Returns:
        None.
    """
    if not asyncio.iscoroutine(result):
        return
    try:
        asyncio.run(result)
    except RuntimeError:
        # If pytest-asyncio already has a running loop, best-effort scheduling is
        # still better than leaking the client entirely.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No loop is running, so there is nowhere to schedule the coroutine.
            return
        # get_running_loop() only ever returns a running loop, so no further
        # is_running() check is needed before scheduling.
        task = loop.create_task(result)
        _PENDING_CLEANUP_TASKS.add(task)
        task.add_done_callback(_PENDING_CLEANUP_TASKS.discard)
    except Exception:
        # Cleanup failures must never fail the test run.
        pass
def _close_handler_if_needed(handler):
    """
    Best-effort close of ``handler``.

    Does nothing when ``handler`` is ``None`` or has no callable ``close``.
    Otherwise ``close()`` is invoked and its return value is handed to
    ``_run_coroutine_if_needed`` so that async closers are also driven.
    Exceptions raised while closing are swallowed.
    """
    closer = None if handler is None else getattr(handler, "close", None)
    if closer is None or not callable(closer):
        return
    try:
        _run_coroutine_if_needed(closer())
    except Exception:
        # Closing is opportunistic cleanup; never let it break the test run.
        pass
@pytest.fixture(scope="function", autouse=True)
def isolate_litellm_state():
    """
    Per-function isolation fixture: snapshot, reset, and later restore LiteLLM
    global state around every test.

    Before the test:
    - remember the callback lists (as copies), a fixed set of module globals,
      the LiteLLM logger configuration, the tool-policy registry and any
      module-level HTTP clients
    - flush the LLM-client, image-handling and AWS auth caches
    - empty the callback lists and reset ``model_fallbacks``, ``cache``, the
      module-level clients and the tool-policy registry

    After the test:
    - flush the same caches again
    - put every remembered value back
    - close module-level clients that differ from the ones seen at setup
    """
    callback_attrs = (
        "callbacks",
        "success_callback",
        "failure_callback",
        "input_callback",
        "_async_success_callback",
        "_async_failure_callback",
        "_async_input_callback",
    )
    # Globals remembered by reference (not copied) and reassigned on teardown.
    snapshot_attrs = (
        # Routing globals — leaked model_fallbacks causes tests to route
        # through async_completion_with_fallbacks / Router, bypassing HTTP mocks
        "model_fallbacks",
        # Transport/network globals — many tests set these without restoring,
        # causing subsequent tests to get None from _create_async_transport()
        "disable_aiohttp_transport",
        "force_ipv4",
        # Request-mapping globals that are frequently mutated in tests.
        "drop_params",
        "cache",
        # Secret-manager globals. Several tests swap these out, which changes
        # get_secret() behavior for later env-driven tests (for example Redis config).
        "secret_manager_client",
        "_key_management_system",
        "_key_management_settings",
        # Other commonly-mutated LiteLLM globals that affect provider routing,
        # auth, and request shaping during larger suite runs.
        "api_base",
        "num_retries",
        "modify_params",
        "ssl_verify",
        "credential_list",
        "model_group_settings",
        "default_internal_user_params",
        "default_team_params",
        "prometheus_emit_stream_label",
        "vector_store_registry",
        "model_cost",
        "cost_margin_config",
        "cost_discount_config",
        "disable_hf_tokenizer_download",
        "disable_copilot_system_to_assistant",
        "cohere_models",
        "anthropic_models",
        "token_counter",
        "initialized_langfuse_clients",
    )
    module_client_keys = ("module_level_client", "module_level_aclient")

    def _flush_shared_caches():
        # Same flush sequence is needed on both sides of the test.
        if hasattr(litellm, "in_memory_llm_clients_cache"):
            litellm.in_memory_llm_clients_cache.flush_cache()
        image_handling_module.in_memory_cache.flush_cache()
        _reset_module_level_aws_auth_caches()

    # ---- snapshot -------------------------------------------------------
    saved_globals = {}
    for name in callback_attrs:
        if hasattr(litellm, name):
            registered = getattr(litellm, name)
            saved_globals[name] = registered.copy() if registered else []
    for name in snapshot_attrs:
        if hasattr(litellm, name):
            saved_globals[name] = getattr(litellm, name)

    # LiteLLM logger state. Some tests reconfigure handlers/propagation for
    # JSON logging and do not restore them, which breaks later caplog-based tests.
    saved_loggers = {
        lg.name: {
            "level": lg.level,
            "disabled": lg.disabled,
            "propagate": lg.propagate,
            "handlers": list(lg.handlers),
            "filters": list(lg.filters),
        }
        for lg in ALL_LOGGERS
    }

    # Singleton registries that are lazily initialized during tests and can
    # change endpoint behavior later in the suite.
    saved_tool_policy_registry = tool_registry_writer_module._tool_policy_registry
    # Only keys that were actually present are recorded, so teardown can tell
    # "was set to None" apart from "was absent".
    saved_clients = {
        key: litellm.__dict__[key]
        for key in module_client_keys
        if key in litellm.__dict__
    }

    # ---- reset ----------------------------------------------------------
    # Flush cache before test (critical for respx mocks)
    _flush_shared_caches()

    # Empty all callback lists to prevent cross-test contamination
    for name in callback_attrs:
        if hasattr(litellm, name):
            setattr(litellm, name, [])

    # Clear routing globals
    for name in ("model_fallbacks", "cache"):
        if hasattr(litellm, name):
            setattr(litellm, name, None)
    for key in module_client_keys:
        litellm.__dict__.pop(key, None)
    tool_registry_writer_module._tool_policy_registry = None

    yield

    # ---- teardown -------------------------------------------------------
    _flush_shared_caches()
    leftover_clients = {key: litellm.__dict__.get(key) for key in module_client_keys}

    # Put back every remembered global (callback lists included)
    for name, value in saved_globals.items():
        if hasattr(litellm, name):
            setattr(litellm, name, value)

    # Restore logger configuration mutated by logging-focused tests.
    for lg in ALL_LOGGERS:
        previous = saved_loggers.get(lg.name)
        if previous is None:
            continue
        lg.setLevel(previous["level"])
        lg.disabled = previous["disabled"]
        lg.propagate = previous["propagate"]
        lg.handlers = list(previous["handlers"])
        lg.filters = list(previous["filters"])

    tool_registry_writer_module._tool_policy_registry = saved_tool_policy_registry

    # Close clients the test left behind, but never the ones seen at setup.
    for key in module_client_keys:
        leftover = leftover_clients[key]
        if leftover is not saved_clients.get(key):
            _close_handler_if_needed(leftover)

    for key in module_client_keys:
        if key in saved_clients:
            litellm.__dict__[key] = saved_clients[key]
        else:
            litellm.__dict__.pop(key, None)
@pytest.fixture(scope="module", autouse=True)
def setup_and_teardown():
    """
    Module-scoped setup/teardown for heavy initialization.

    When no pytest-xdist worker is detected (``PYTEST_XDIST_WORKER`` unset),
    ``litellm`` is reloaded, ``litellm.proxy.proxy_server`` is reloaded if it
    is already reachable as an attribute, and the LLM client cache is flushed.
    Under xdist no reload happens.

    Use this sparingly - most state should be handled by isolate_litellm_state.
    """
    sys.path.insert(0, os.path.abspath("../.."))

    import litellm

    xdist_worker = os.environ.get("PYTEST_XDIST_WORKER")
    worker_label = xdist_worker or "master"

    # Only reload if NOT running in parallel (module reload + parallel = bad)
    if xdist_worker is None:
        importlib.reload(litellm)

        try:
            if hasattr(litellm, "proxy"):
                if hasattr(litellm.proxy, "proxy_server"):
                    import litellm.proxy.proxy_server

                    importlib.reload(litellm.proxy.proxy_server)
        except Exception as exc:
            print(f"Error reloading litellm.proxy.proxy_server: {exc}")

        # Flush cache after reload (prevents stale client instances)
        if hasattr(litellm, "in_memory_llm_clients_cache"):
            litellm.in_memory_llm_clients_cache.flush_cache()

    print(f"[conftest] Module setup complete (worker: {worker_label})")

    yield

    # Teardown - no need to manually manage event loops with pytest-asyncio auto mode
    print(f"[conftest] Module teardown complete (worker: {worker_label})")
def pytest_collection_modifyitems(config, items):
    """
    Customize test collection order (``items`` is reordered in place).

    Final order:
    1. tests whose parent name contains ``custom_logger`` (they tend to
       interfere with other tests)
    2. all remaining parallelizable tests
    3. tests marked ``no_parallel``

    A ``no_parallel`` mark wins over the ``custom_logger`` grouping. Each
    group is sorted by test name.

    Args:
        config: pytest config object (unused).
        items: collected test items; mutated via slice assignment so pytest
            sees the new order.
    """
    custom_logger_tests = []
    other_tests = []
    no_parallel_tests = []

    # Single pass, three-way partition. This avoids repeated
    # ``item not in <list>`` scans, which are quadratic on large suites.
    for item in items:
        if any(mark.name == "no_parallel" for mark in item.iter_markers()):
            no_parallel_tests.append(item)
        elif "custom_logger" in item.parent.name:
            custom_logger_tests.append(item)
        else:
            other_tests.append(item)

    for group in (custom_logger_tests, other_tests, no_parallel_tests):
        group.sort(key=lambda test_item: test_item.name)

    # Reorder: custom_logger first (isolated), then other tests, then no_parallel tests last
    items[:] = custom_logger_tests + other_tests + no_parallel_tests
def pytest_configure(config):
    """
    Configure pytest with custom settings.

    Registers the ``flaky`` marker and prints a notice when either the ``CI``
    or ``LITELLM_CI`` environment variable equals ``"true"``.
    """
    # Add marker for flaky tests (for documentation purposes)
    flaky_marker = "flaky: mark test as potentially flaky (should use --reruns)"
    config.addinivalue_line("markers", flaky_marker)

    # Detect if running in CI
    running_in_ci = any(
        os.environ.get(flag) == "true" for flag in ("CI", "LITELLM_CI")
    )
    if running_in_ci:
        print("[conftest] Running in CI mode - enabling stricter test isolation")
# Optional: Add a fixture for tests that need even stricter isolation
@pytest.fixture
def strict_isolation():
    """
    Opt-in fixture for tests that need extra strict isolation.

    Flushes the LLM client cache before and after the test, and forces
    ``litellm.disable_aiohttp_transport`` and ``litellm.set_verbose`` to
    ``False`` for the duration of the test (when those attributes exist).
    Previous values are put back afterwards unless the previous value was
    ``None``.

    Example:
        def test_something(strict_isolation):
            # Test code with guaranteed clean state
            pass
    """
    forced_off = ("disable_aiohttp_transport", "set_verbose")

    # Force flush all caches
    if hasattr(litellm, "in_memory_llm_clients_cache"):
        litellm.in_memory_llm_clients_cache.flush_cache()

    # Reset all global state, remembering what was there before
    previous_values = {}
    for flag in forced_off:
        if hasattr(litellm, flag):
            previous_values[flag] = getattr(litellm, flag)
            setattr(litellm, flag, False)

    yield

    # Restore original state; a remembered None is treated as "nothing to restore"
    for flag, value in previous_values.items():
        if value is not None:
            setattr(litellm, flag, value)

    # Final cache flush
    if hasattr(litellm, "in_memory_llm_clients_cache"):
        litellm.in_memory_llm_clients_cache.flush_cache()
def pytest_sessionfinish(session, exitstatus):
    """Close any globally cached HTTP clients so xdist workers exit cleanly."""
    module_client_keys = ("module_level_client", "module_level_aclient")

    # Module-level clients live directly in the module dict: close both
    # first, then remove both entries.
    for key in module_client_keys:
        _close_handler_if_needed(litellm.__dict__.get(key))
    for key in module_client_keys:
        litellm.__dict__.pop(key, None)

    # Remaining shared handlers are closed in place; missing ones are skipped
    # by the helper because they resolve to None.
    shared_handler_attrs = (
        "base_llm_aiohttp_handler",
        "httpx_client",
        "aclient",
        "client",
    )
    for attr in shared_handler_attrs:
        _close_handler_if_needed(getattr(litellm, attr, None))

    _run_coroutine_if_needed(close_litellm_async_clients())