fix(proxy): self-heal startup/reload prisma reads on engine disconnect (#28803)

This commit is contained in:
michelligabriele 2026-06-10 20:16:58 +02:00 committed by GitHub
parent 3b40ac987f
commit f9293d40c4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 295 additions and 14 deletions

View File

@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import ToolDiscoveryQueueItem from litellm.proxy._types import ToolDiscoveryQueueItem
from litellm.proxy.db.exception_handler import call_with_db_reconnect_retry
from litellm.repositories.object_permission_repository import ObjectPermissionRepository from litellm.repositories.object_permission_repository import ObjectPermissionRepository
from litellm.repositories.table_repositories import ToolRepository from litellm.repositories.table_repositories import ToolRepository
from litellm.types.tool_management import ( from litellm.types.tool_management import (
@ -309,7 +310,11 @@ class ToolPolicyRegistry:
async def sync_tool_policy_from_db(self, prisma_client: "PrismaClient") -> None: async def sync_tool_policy_from_db(self, prisma_client: "PrismaClient") -> None:
"""Load all tool policies and object-permission blocked_tools from DB.""" """Load all tool policies and object-permission blocked_tools from DB."""
try: try:
tools = await ToolRepository(prisma_client).table.find_many() tools = await call_with_db_reconnect_retry(
prisma_client,
lambda: ToolRepository(prisma_client).table.find_many(),
reason="sync_tool_policy_from_db_tools_lookup_failure",
)
self._tool_input_policies = { self._tool_input_policies = {
row.tool_name: getattr(row, "input_policy", "untrusted") or "untrusted" row.tool_name: getattr(row, "input_policy", "untrusted") or "untrusted"
for row in tools for row in tools
@ -319,7 +324,11 @@ class ToolPolicyRegistry:
for row in tools for row in tools
} }
perms = await ObjectPermissionRepository(prisma_client).table.find_many() perms = await call_with_db_reconnect_retry(
prisma_client,
lambda: ObjectPermissionRepository(prisma_client).table.find_many(),
reason="sync_tool_policy_from_db_perms_lookup_failure",
)
self._blocked_tools_by_op_id = {} self._blocked_tools_by_op_id = {}
for row in perms: for row in perms:
op_id = getattr(row, "object_permission_id", None) op_id = getattr(row, "object_permission_id", None)

View File

@ -27,6 +27,7 @@ from litellm.proxy._types import (
UserAPIKeyAuth, UserAPIKeyAuth,
) )
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.db.exception_handler import call_with_db_reconnect_retry
from litellm.repositories.table_repositories import CacheConfigRepository from litellm.repositories.table_repositories import CacheConfigRepository
from litellm.types.management_endpoints import ( from litellm.types.management_endpoints import (
CACHE_SETTINGS_FIELDS, CACHE_SETTINGS_FIELDS,
@ -160,8 +161,12 @@ class CacheSettingsManager:
import json import json
try: try:
cache_config = await CacheConfigRepository(prisma_client).table.find_unique( cache_config = await call_with_db_reconnect_retry(
where={"id": "cache_config"} prisma_client,
lambda: CacheConfigRepository(prisma_client).table.find_unique(
where={"id": "cache_config"}
),
reason="init_cache_settings_in_db_lookup_failure",
) )
if cache_config is not None and cache_config.cache_settings: if cache_config is not None and cache_config.cache_settings:
# Parse cache settings JSON # Parse cache settings JSON

View File

@ -311,7 +311,10 @@ from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
from litellm.proxy.container_endpoints.endpoints import router as container_router from litellm.proxy.container_endpoints.endpoints import router as container_router
from litellm.proxy.credential_endpoints.endpoints import router as credential_router from litellm.proxy.credential_endpoints.endpoints import router as credential_router
from litellm.proxy.db.db_transaction_queue.spend_log_cleanup import SpendLogCleanup from litellm.proxy.db.db_transaction_queue.spend_log_cleanup import SpendLogCleanup
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler from litellm.proxy.db.exception_handler import (
PrismaDBExceptionHandler,
call_with_db_reconnect_retry,
)
from litellm.proxy.db.spend_counter_reseed import SpendCounterReseed from litellm.proxy.db.spend_counter_reseed import SpendCounterReseed
from litellm.proxy.discovery_endpoints import ui_discovery_endpoints_router from litellm.proxy.discovery_endpoints import ui_discovery_endpoints_router
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
@ -5984,8 +5987,12 @@ class ProxyConfig:
""" """
try: try:
sso_settings = await SSOConfigRepository(prisma_client).table.find_unique( sso_settings = await call_with_db_reconnect_retry(
where={"id": "sso_config"} prisma_client,
lambda: SSOConfigRepository(prisma_client).table.find_unique(
where={"id": "sso_config"}
),
reason="init_sso_settings_in_db_lookup_failure",
) )
if sso_settings is not None: if sso_settings is not None:
sso_settings.sso_settings.pop("role_mappings", None) sso_settings.sso_settings.pop("role_mappings", None)
@ -6020,9 +6027,13 @@ class ProxyConfig:
) )
try: try:
db_record = await ConfigOverridesRepository( db_record = await call_with_db_reconnect_retry(
prisma_client prisma_client,
).table.find_unique(where={"config_type": "hashicorp_vault"}) lambda: ConfigOverridesRepository(prisma_client).table.find_unique(
where={"config_type": "hashicorp_vault"}
),
reason="init_hashicorp_vault_config_override_lookup_failure",
)
if db_record is None or db_record.config_value is None: if db_record is None or db_record.config_value is None:
if self._last_hashicorp_vault_config is not None: if self._last_hashicorp_vault_config is not None:

View File

@ -7,6 +7,7 @@ from typing import List, Optional
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.proxy.db.exception_handler import call_with_db_reconnect_retry
from litellm.proxy.utils import PrismaClient from litellm.proxy.utils import PrismaClient
from litellm.repositories.table_repositories import SearchToolsRepository from litellm.repositories.table_repositories import SearchToolsRepository
from litellm.types.search import SearchTool from litellm.types.search import SearchTool
@ -180,10 +181,12 @@ class SearchToolRegistry:
List of search tool configurations List of search tool configurations
""" """
try: try:
search_tools_from_db = await SearchToolsRepository( search_tools_from_db = await call_with_db_reconnect_retry(
prisma_client prisma_client,
).table.find_many( lambda: SearchToolsRepository(prisma_client).table.find_many(
order={"created_at": "desc"}, order={"created_at": "desc"},
),
reason="get_all_search_tools_from_db_lookup_failure",
) )
search_tools: List[SearchTool] = [] search_tools: List[SearchTool] = []

View File

@ -291,3 +291,77 @@ async def test_tool_policy_registry_not_initialized_returns_untrusted():
assert not registry.is_initialized() assert not registry.is_initialized()
result = registry.get_effective_policies(["unknown_tool"]) result = registry.get_effective_policies(["unknown_tool"])
assert result == {"unknown_tool": "untrusted"} assert result == {"unknown_tool": "untrusted"}
@pytest.mark.asyncio
async def test_sync_tool_policy_from_db_retries_on_transport_error_first_read():
    """Tools read raises ClientNotConnectedError once, then succeeds.

    The mocked reconnect reports success, so the test expects:
    - the tools `find_many` to be attempted exactly twice,
    - exactly one reconnect, tagged with the tools-lookup reason,
    - the registry to report itself as initialized afterwards.
    """
    import prisma as prisma_pkg

    registry = ToolPolicyRegistry()
    tools_read_attempts = 0

    async def _fail_once_then_empty():
        # First call simulates a dropped engine connection; later calls
        # return an empty table.
        nonlocal tools_read_attempts
        tools_read_attempts += 1
        if tools_read_attempts == 1:
            raise prisma_pkg.errors.ClientNotConnectedError()
        return []

    db_client = MagicMock()
    db_client.db.litellm_tooltable.find_many = AsyncMock(
        side_effect=_fail_once_then_empty
    )
    db_client.db.litellm_objectpermissiontable.find_many = AsyncMock(
        return_value=[]
    )
    db_client.attempt_db_reconnect = AsyncMock(return_value=True)
    db_client._db_auth_reconnect_timeout_seconds = 2.0
    db_client._db_auth_reconnect_lock_timeout_seconds = 0.1

    await registry.sync_tool_policy_from_db(db_client)

    assert tools_read_attempts == 2
    db_client.attempt_db_reconnect.assert_awaited_once()
    reason_tag = db_client.attempt_db_reconnect.await_args.kwargs["reason"]
    assert reason_tag == "sync_tool_policy_from_db_tools_lookup_failure"
    assert registry.is_initialized()
@pytest.mark.asyncio
async def test_sync_tool_policy_from_db_retries_on_transport_error_second_read():
    """Object-permission read raises ClientNotConnectedError once, then succeeds.

    The tools read succeeds immediately; only the perms read is flaky. The
    test expects two perms reads and a single reconnect carrying the
    perms-specific reason tag (distinct from the tools-read tag).
    """
    import prisma as prisma_pkg

    registry = ToolPolicyRegistry()
    perms_read_attempts = 0

    async def _fail_once_then_empty():
        # First call simulates a dropped engine connection; later calls
        # return an empty table.
        nonlocal perms_read_attempts
        perms_read_attempts += 1
        if perms_read_attempts == 1:
            raise prisma_pkg.errors.ClientNotConnectedError()
        return []

    db_client = MagicMock()
    db_client.db.litellm_tooltable.find_many = AsyncMock(return_value=[])
    db_client.db.litellm_objectpermissiontable.find_many = AsyncMock(
        side_effect=_fail_once_then_empty
    )
    db_client.attempt_db_reconnect = AsyncMock(return_value=True)
    db_client._db_auth_reconnect_timeout_seconds = 2.0
    db_client._db_auth_reconnect_lock_timeout_seconds = 0.1

    await registry.sync_tool_policy_from_db(db_client)

    assert perms_read_attempts == 2
    db_client.attempt_db_reconnect.assert_awaited_once()
    reason_tag = db_client.attempt_db_reconnect.await_args.kwargs["reason"]
    assert reason_tag == "sync_tool_policy_from_db_perms_lookup_failure"

View File

@ -611,6 +611,45 @@ async def test_list_search_tools_db_masking_sensitive_values(monkeypatch):
app.dependency_overrides.pop(user_api_key_auth, None) app.dependency_overrides.pop(user_api_key_auth, None)
@pytest.mark.asyncio
async def test_get_all_search_tools_from_db_retries_on_transport_error():
    """Search-tools read raises ClientNotConnectedError once, then succeeds.

    With the reconnect mocked to succeed, the test expects an empty result,
    two `find_many` attempts, and one reconnect tagged with the
    search-tools lookup reason.
    """
    import prisma

    from litellm.proxy.search_endpoints.search_tool_registry import (
        SearchToolRegistry,
    )

    read_attempts = 0

    async def _fail_once_then_empty(**_query):
        # First call simulates a dropped engine connection; later calls
        # return no rows.
        nonlocal read_attempts
        read_attempts += 1
        if read_attempts == 1:
            raise prisma.errors.ClientNotConnectedError()
        return []

    db_client = MagicMock()
    db_client.db.litellm_searchtoolstable.find_many = AsyncMock(
        side_effect=_fail_once_then_empty
    )
    db_client.attempt_db_reconnect = AsyncMock(return_value=True)
    db_client._db_auth_reconnect_timeout_seconds = 2.0
    db_client._db_auth_reconnect_lock_timeout_seconds = 0.1

    tools = await SearchToolRegistry.get_all_search_tools_from_db(
        prisma_client=db_client
    )

    assert tools == []
    assert read_attempts == 2
    db_client.attempt_db_reconnect.assert_awaited_once()
    reason_tag = db_client.attempt_db_reconnect.await_args.kwargs["reason"]
    assert reason_tag == "get_all_search_tools_from_db_lookup_failure"
@contextlib.contextmanager @contextlib.contextmanager
def _mock_search_tool_backend(db_tools): def _mock_search_tool_backend(db_tools):
"""Patch the DB registry, prisma client, and config so /search_tools/list """Patch the DB registry, prisma client, and config so /search_tools/list

View File

@ -259,6 +259,41 @@ class TestCacheSettingsManager:
mock_proxy_config._init_cache.assert_not_called() mock_proxy_config._init_cache.assert_not_called()
mock_proxy_config.switch_on_llm_response_caching.assert_not_called() mock_proxy_config.switch_on_llm_response_caching.assert_not_called()
@pytest.mark.asyncio
async def test_init_cache_settings_in_db_retries_on_transport_error(self):
    """Cache-config read raises ClientNotConnectedError once, then returns None.

    With the reconnect mocked to succeed, the test expects two
    `find_unique` attempts and one reconnect tagged with the
    cache-settings lookup reason.
    """
    import prisma

    read_attempts = 0

    async def _fail_once_then_missing(**_query):
        # First call simulates a dropped engine connection; the retry
        # reports that no cache config row exists.
        nonlocal read_attempts
        read_attempts += 1
        if read_attempts == 1:
            raise prisma.errors.ClientNotConnectedError()
        return None

    db_client = MagicMock()
    db_client.db.litellm_cacheconfig.find_unique = AsyncMock(
        side_effect=_fail_once_then_missing
    )
    db_client.attempt_db_reconnect = AsyncMock(return_value=True)
    db_client._db_auth_reconnect_timeout_seconds = 2.0
    db_client._db_auth_reconnect_lock_timeout_seconds = 0.1
    stub_proxy_config = MagicMock()

    await CacheSettingsManager.init_cache_settings_in_db(
        prisma_client=db_client, proxy_config=stub_proxy_config
    )

    assert read_attempts == 2
    db_client.attempt_db_reconnect.assert_awaited_once()
    reason_tag = db_client.attempt_db_reconnect.await_args.kwargs["reason"]
    assert reason_tag == "init_cache_settings_in_db_lookup_failure"
# ── Audit-log emission for /cache/settings ──────────────────────────────────── # ── Audit-log emission for /cache/settings ────────────────────────────────────

View File

@ -4324,6 +4324,111 @@ async def test_init_sso_settings_in_db_empty_settings():
assert uppercased_settings == {} assert uppercased_settings == {}
@pytest.mark.asyncio
async def test_init_sso_settings_in_db_retries_on_transport_error():
    """SSO-config read raises ClientNotConnectedError once, then returns a row.

    With the reconnect mocked to succeed, the test expects two
    `find_unique` attempts, one reconnect tagged with the SSO lookup
    reason, and a single call to `_decrypt_and_set_db_env_variables`.
    """
    import prisma

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    sso_row = MagicMock()
    sso_row.sso_settings = {"GOOGLE_CLIENT_ID": "xxx"}

    read_attempts = 0

    async def _fail_once_then_row(**_query):
        # First call simulates a dropped engine connection; the retry
        # returns the stored SSO row.
        nonlocal read_attempts
        read_attempts += 1
        if read_attempts == 1:
            raise prisma.errors.ClientNotConnectedError()
        return sso_row

    db_client = MagicMock()
    db_client.db.litellm_ssoconfig.find_unique = AsyncMock(
        side_effect=_fail_once_then_row
    )
    db_client.attempt_db_reconnect = AsyncMock(return_value=True)
    db_client._db_auth_reconnect_timeout_seconds = 2.0
    db_client._db_auth_reconnect_lock_timeout_seconds = 0.1

    with patch.object(
        proxy_config, "_decrypt_and_set_db_env_variables"
    ) as decrypt_spy:
        await proxy_config._init_sso_settings_in_db(prisma_client=db_client)

    assert read_attempts == 2
    db_client.attempt_db_reconnect.assert_awaited_once()
    reason_tag = db_client.attempt_db_reconnect.await_args.kwargs["reason"]
    assert reason_tag == "init_sso_settings_in_db_lookup_failure"
    decrypt_spy.assert_called_once()
@pytest.mark.asyncio
async def test_init_sso_settings_in_db_propagates_when_reconnect_fails():
    """SSO-config read always fails and the mocked reconnect returns False.

    The call under test is expected to complete without raising, and the
    reconnect must have been attempted exactly once (no repeated retries).
    """
    import prisma

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()

    db_client = MagicMock()
    # Every read raises: there is no successful retry in this scenario.
    db_client.db.litellm_ssoconfig.find_unique = AsyncMock(
        side_effect=prisma.errors.ClientNotConnectedError()
    )
    db_client.attempt_db_reconnect = AsyncMock(return_value=False)
    db_client._db_auth_reconnect_timeout_seconds = 2.0
    db_client._db_auth_reconnect_lock_timeout_seconds = 0.1

    # Must not raise: an exception escaping here fails the test.
    await proxy_config._init_sso_settings_in_db(prisma_client=db_client)

    db_client.attempt_db_reconnect.assert_awaited_once()
@pytest.mark.asyncio
async def test_init_hashicorp_vault_config_override_retries_on_transport_error():
    """Vault-override read raises ClientNotConnectedError once, then returns None.

    With the reconnect mocked to succeed, the test expects two
    `find_unique` attempts and one reconnect tagged with the
    hashicorp-vault override lookup reason.
    """
    import prisma

    from litellm.proxy.proxy_server import ProxyConfig

    proxy_config = ProxyConfig()
    proxy_config._last_hashicorp_vault_config = None

    read_attempts = 0

    async def _fail_once_then_missing(**_query):
        # First call simulates a dropped engine connection; the retry
        # reports that no override row exists.
        nonlocal read_attempts
        read_attempts += 1
        if read_attempts == 1:
            raise prisma.errors.ClientNotConnectedError()
        return None

    db_client = MagicMock()
    db_client.db.litellm_configoverrides.find_unique = AsyncMock(
        side_effect=_fail_once_then_missing
    )
    db_client.attempt_db_reconnect = AsyncMock(return_value=True)
    db_client._db_auth_reconnect_timeout_seconds = 2.0
    db_client._db_auth_reconnect_lock_timeout_seconds = 0.1

    await proxy_config._init_hashicorp_vault_config_override(
        prisma_client=db_client
    )

    assert read_attempts == 2
    db_client.attempt_db_reconnect.assert_awaited_once()
    reason_tag = db_client.attempt_db_reconnect.await_args.kwargs["reason"]
    assert reason_tag == "init_hashicorp_vault_config_override_lookup_failure"
def test_update_config_fields_uppercases_env_vars(monkeypatch): def test_update_config_fields_uppercases_env_vars(monkeypatch):
""" """
Ensure environment variables pulled from DB are uppercased when applied so Ensure environment variables pulled from DB are uppercased when applied so