fix(proxy): self-heal startup/reload prisma reads on engine disconnect (#28803)
This commit is contained in:
parent
3b40ac987f
commit
f9293d40c4
@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import ToolDiscoveryQueueItem
|
||||
from litellm.proxy.db.exception_handler import call_with_db_reconnect_retry
|
||||
from litellm.repositories.object_permission_repository import ObjectPermissionRepository
|
||||
from litellm.repositories.table_repositories import ToolRepository
|
||||
from litellm.types.tool_management import (
|
||||
@ -309,7 +310,11 @@ class ToolPolicyRegistry:
|
||||
async def sync_tool_policy_from_db(self, prisma_client: "PrismaClient") -> None:
|
||||
"""Load all tool policies and object-permission blocked_tools from DB."""
|
||||
try:
|
||||
tools = await ToolRepository(prisma_client).table.find_many()
|
||||
tools = await call_with_db_reconnect_retry(
|
||||
prisma_client,
|
||||
lambda: ToolRepository(prisma_client).table.find_many(),
|
||||
reason="sync_tool_policy_from_db_tools_lookup_failure",
|
||||
)
|
||||
self._tool_input_policies = {
|
||||
row.tool_name: getattr(row, "input_policy", "untrusted") or "untrusted"
|
||||
for row in tools
|
||||
@ -319,7 +324,11 @@ class ToolPolicyRegistry:
|
||||
for row in tools
|
||||
}
|
||||
|
||||
perms = await ObjectPermissionRepository(prisma_client).table.find_many()
|
||||
perms = await call_with_db_reconnect_retry(
|
||||
prisma_client,
|
||||
lambda: ObjectPermissionRepository(prisma_client).table.find_many(),
|
||||
reason="sync_tool_policy_from_db_perms_lookup_failure",
|
||||
)
|
||||
self._blocked_tools_by_op_id = {}
|
||||
for row in perms:
|
||||
op_id = getattr(row, "object_permission_id", None)
|
||||
|
||||
@ -27,6 +27,7 @@ from litellm.proxy._types import (
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.db.exception_handler import call_with_db_reconnect_retry
|
||||
from litellm.repositories.table_repositories import CacheConfigRepository
|
||||
from litellm.types.management_endpoints import (
|
||||
CACHE_SETTINGS_FIELDS,
|
||||
@ -160,8 +161,12 @@ class CacheSettingsManager:
|
||||
import json
|
||||
|
||||
try:
|
||||
cache_config = await CacheConfigRepository(prisma_client).table.find_unique(
|
||||
where={"id": "cache_config"}
|
||||
cache_config = await call_with_db_reconnect_retry(
|
||||
prisma_client,
|
||||
lambda: CacheConfigRepository(prisma_client).table.find_unique(
|
||||
where={"id": "cache_config"}
|
||||
),
|
||||
reason="init_cache_settings_in_db_lookup_failure",
|
||||
)
|
||||
if cache_config is not None and cache_config.cache_settings:
|
||||
# Parse cache settings JSON
|
||||
|
||||
@ -311,7 +311,10 @@ from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache
|
||||
from litellm.proxy.container_endpoints.endpoints import router as container_router
|
||||
from litellm.proxy.credential_endpoints.endpoints import router as credential_router
|
||||
from litellm.proxy.db.db_transaction_queue.spend_log_cleanup import SpendLogCleanup
|
||||
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
|
||||
from litellm.proxy.db.exception_handler import (
|
||||
PrismaDBExceptionHandler,
|
||||
call_with_db_reconnect_retry,
|
||||
)
|
||||
from litellm.proxy.db.spend_counter_reseed import SpendCounterReseed
|
||||
from litellm.proxy.discovery_endpoints import ui_discovery_endpoints_router
|
||||
from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router
|
||||
@ -5984,8 +5987,12 @@ class ProxyConfig:
|
||||
"""
|
||||
|
||||
try:
|
||||
sso_settings = await SSOConfigRepository(prisma_client).table.find_unique(
|
||||
where={"id": "sso_config"}
|
||||
sso_settings = await call_with_db_reconnect_retry(
|
||||
prisma_client,
|
||||
lambda: SSOConfigRepository(prisma_client).table.find_unique(
|
||||
where={"id": "sso_config"}
|
||||
),
|
||||
reason="init_sso_settings_in_db_lookup_failure",
|
||||
)
|
||||
if sso_settings is not None:
|
||||
sso_settings.sso_settings.pop("role_mappings", None)
|
||||
@ -6020,9 +6027,13 @@ class ProxyConfig:
|
||||
)
|
||||
|
||||
try:
|
||||
db_record = await ConfigOverridesRepository(
|
||||
prisma_client
|
||||
).table.find_unique(where={"config_type": "hashicorp_vault"})
|
||||
db_record = await call_with_db_reconnect_retry(
|
||||
prisma_client,
|
||||
lambda: ConfigOverridesRepository(prisma_client).table.find_unique(
|
||||
where={"config_type": "hashicorp_vault"}
|
||||
),
|
||||
reason="init_hashicorp_vault_config_override_lookup_failure",
|
||||
)
|
||||
|
||||
if db_record is None or db_record.config_value is None:
|
||||
if self._last_hashicorp_vault_config is not None:
|
||||
|
||||
@ -7,6 +7,7 @@ from typing import List, Optional
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy.db.exception_handler import call_with_db_reconnect_retry
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.repositories.table_repositories import SearchToolsRepository
|
||||
from litellm.types.search import SearchTool
|
||||
@ -180,10 +181,12 @@ class SearchToolRegistry:
|
||||
List of search tool configurations
|
||||
"""
|
||||
try:
|
||||
search_tools_from_db = await SearchToolsRepository(
|
||||
prisma_client
|
||||
).table.find_many(
|
||||
order={"created_at": "desc"},
|
||||
search_tools_from_db = await call_with_db_reconnect_retry(
|
||||
prisma_client,
|
||||
lambda: SearchToolsRepository(prisma_client).table.find_many(
|
||||
order={"created_at": "desc"},
|
||||
),
|
||||
reason="get_all_search_tools_from_db_lookup_failure",
|
||||
)
|
||||
|
||||
search_tools: List[SearchTool] = []
|
||||
|
||||
@ -291,3 +291,77 @@ async def test_tool_policy_registry_not_initialized_returns_untrusted():
|
||||
assert not registry.is_initialized()
|
||||
result = registry.get_effective_policies(["unknown_tool"])
|
||||
assert result == {"unknown_tool": "untrusted"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_tool_policy_from_db_retries_on_transport_error_first_read():
|
||||
"""`ToolPolicyRegistry.sync_tool_policy_from_db` self-heals across one
|
||||
ClientNotConnectedError on the tools read — the perms read still fires
|
||||
after the recovery and the registry initializes cleanly."""
|
||||
import prisma as prisma_pkg
|
||||
|
||||
registry = ToolPolicyRegistry()
|
||||
invocations: list = []
|
||||
|
||||
async def _flaky_find_many():
|
||||
invocations.append(None)
|
||||
if len(invocations) == 1:
|
||||
raise prisma_pkg.errors.ClientNotConnectedError()
|
||||
return []
|
||||
|
||||
mock_prisma_client = MagicMock()
|
||||
mock_prisma_client.db.litellm_tooltable.find_many = AsyncMock(
|
||||
side_effect=_flaky_find_many
|
||||
)
|
||||
mock_prisma_client.db.litellm_objectpermissiontable.find_many = AsyncMock(
|
||||
return_value=[]
|
||||
)
|
||||
mock_prisma_client.attempt_db_reconnect = AsyncMock(return_value=True)
|
||||
mock_prisma_client._db_auth_reconnect_timeout_seconds = 2.0
|
||||
mock_prisma_client._db_auth_reconnect_lock_timeout_seconds = 0.1
|
||||
|
||||
await registry.sync_tool_policy_from_db(mock_prisma_client)
|
||||
|
||||
assert len(invocations) == 2
|
||||
mock_prisma_client.attempt_db_reconnect.assert_awaited_once()
|
||||
reconnect_kwargs = mock_prisma_client.attempt_db_reconnect.await_args.kwargs
|
||||
assert (
|
||||
reconnect_kwargs["reason"]
|
||||
== "sync_tool_policy_from_db_tools_lookup_failure"
|
||||
)
|
||||
assert registry.is_initialized()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sync_tool_policy_from_db_retries_on_transport_error_second_read():
|
||||
"""Same as above but the blip happens on the perms read — distinct reason
|
||||
tag in telemetry confirms the second wrap is wired separately."""
|
||||
import prisma as prisma_pkg
|
||||
|
||||
registry = ToolPolicyRegistry()
|
||||
perms_invocations: list = []
|
||||
|
||||
async def _flaky_perms_find_many():
|
||||
perms_invocations.append(None)
|
||||
if len(perms_invocations) == 1:
|
||||
raise prisma_pkg.errors.ClientNotConnectedError()
|
||||
return []
|
||||
|
||||
mock_prisma_client = MagicMock()
|
||||
mock_prisma_client.db.litellm_tooltable.find_many = AsyncMock(return_value=[])
|
||||
mock_prisma_client.db.litellm_objectpermissiontable.find_many = AsyncMock(
|
||||
side_effect=_flaky_perms_find_many
|
||||
)
|
||||
mock_prisma_client.attempt_db_reconnect = AsyncMock(return_value=True)
|
||||
mock_prisma_client._db_auth_reconnect_timeout_seconds = 2.0
|
||||
mock_prisma_client._db_auth_reconnect_lock_timeout_seconds = 0.1
|
||||
|
||||
await registry.sync_tool_policy_from_db(mock_prisma_client)
|
||||
|
||||
assert len(perms_invocations) == 2
|
||||
mock_prisma_client.attempt_db_reconnect.assert_awaited_once()
|
||||
reconnect_kwargs = mock_prisma_client.attempt_db_reconnect.await_args.kwargs
|
||||
assert (
|
||||
reconnect_kwargs["reason"]
|
||||
== "sync_tool_policy_from_db_perms_lookup_failure"
|
||||
)
|
||||
|
||||
@ -611,6 +611,45 @@ async def test_list_search_tools_db_masking_sensitive_values(monkeypatch):
|
||||
app.dependency_overrides.pop(user_api_key_auth, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_search_tools_from_db_retries_on_transport_error():
|
||||
"""`SearchToolRegistry.get_all_search_tools_from_db` self-heals across one
|
||||
ClientNotConnectedError via call_with_db_reconnect_retry."""
|
||||
import prisma
|
||||
from litellm.proxy.search_endpoints.search_tool_registry import (
|
||||
SearchToolRegistry,
|
||||
)
|
||||
|
||||
invocations: list = []
|
||||
|
||||
async def _flaky_find_many(**kwargs):
|
||||
invocations.append(None)
|
||||
if len(invocations) == 1:
|
||||
raise prisma.errors.ClientNotConnectedError()
|
||||
return []
|
||||
|
||||
mock_prisma_client = MagicMock()
|
||||
mock_prisma_client.db.litellm_searchtoolstable.find_many = AsyncMock(
|
||||
side_effect=_flaky_find_many
|
||||
)
|
||||
mock_prisma_client.attempt_db_reconnect = AsyncMock(return_value=True)
|
||||
mock_prisma_client._db_auth_reconnect_timeout_seconds = 2.0
|
||||
mock_prisma_client._db_auth_reconnect_lock_timeout_seconds = 0.1
|
||||
|
||||
result = await SearchToolRegistry.get_all_search_tools_from_db(
|
||||
prisma_client=mock_prisma_client
|
||||
)
|
||||
|
||||
assert result == []
|
||||
assert len(invocations) == 2
|
||||
mock_prisma_client.attempt_db_reconnect.assert_awaited_once()
|
||||
reconnect_kwargs = mock_prisma_client.attempt_db_reconnect.await_args.kwargs
|
||||
assert (
|
||||
reconnect_kwargs["reason"]
|
||||
== "get_all_search_tools_from_db_lookup_failure"
|
||||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def _mock_search_tool_backend(db_tools):
|
||||
"""Patch the DB registry, prisma client, and config so /search_tools/list
|
||||
|
||||
@ -259,6 +259,41 @@ class TestCacheSettingsManager:
|
||||
mock_proxy_config._init_cache.assert_not_called()
|
||||
mock_proxy_config.switch_on_llm_response_caching.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_cache_settings_in_db_retries_on_transport_error(self):
|
||||
"""`CacheSettingsManager.init_cache_settings_in_db` self-heals across one
|
||||
ClientNotConnectedError via call_with_db_reconnect_retry."""
|
||||
import prisma
|
||||
|
||||
invocations: list = []
|
||||
|
||||
async def _flaky_find_unique(**kwargs):
|
||||
invocations.append(None)
|
||||
if len(invocations) == 1:
|
||||
raise prisma.errors.ClientNotConnectedError()
|
||||
return None # No config → function returns early after retry.
|
||||
|
||||
mock_prisma_client = MagicMock()
|
||||
mock_prisma_client.db.litellm_cacheconfig.find_unique = AsyncMock(
|
||||
side_effect=_flaky_find_unique
|
||||
)
|
||||
mock_prisma_client.attempt_db_reconnect = AsyncMock(return_value=True)
|
||||
mock_prisma_client._db_auth_reconnect_timeout_seconds = 2.0
|
||||
mock_prisma_client._db_auth_reconnect_lock_timeout_seconds = 0.1
|
||||
mock_proxy_config = MagicMock()
|
||||
|
||||
await CacheSettingsManager.init_cache_settings_in_db(
|
||||
prisma_client=mock_prisma_client, proxy_config=mock_proxy_config
|
||||
)
|
||||
|
||||
assert len(invocations) == 2
|
||||
mock_prisma_client.attempt_db_reconnect.assert_awaited_once()
|
||||
reconnect_kwargs = mock_prisma_client.attempt_db_reconnect.await_args.kwargs
|
||||
assert (
|
||||
reconnect_kwargs["reason"]
|
||||
== "init_cache_settings_in_db_lookup_failure"
|
||||
)
|
||||
|
||||
|
||||
# ── Audit-log emission for /cache/settings ────────────────────────────────────
|
||||
|
||||
|
||||
@ -4324,6 +4324,111 @@ async def test_init_sso_settings_in_db_empty_settings():
|
||||
assert uppercased_settings == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_sso_settings_in_db_retries_on_transport_error():
|
||||
"""`_init_sso_settings_in_db` self-heals across one ClientNotConnectedError
|
||||
via call_with_db_reconnect_retry — mirrors the auth-path behavior so
|
||||
startup/reload bursts don't spam the log."""
|
||||
import prisma
|
||||
|
||||
from litellm.proxy.proxy_server import ProxyConfig
|
||||
|
||||
proxy_config = ProxyConfig()
|
||||
mock_sso_config = MagicMock()
|
||||
mock_sso_config.sso_settings = {"GOOGLE_CLIENT_ID": "xxx"}
|
||||
|
||||
invocations: list = []
|
||||
|
||||
async def _flaky_find_unique(**kwargs):
|
||||
invocations.append(None)
|
||||
if len(invocations) == 1:
|
||||
raise prisma.errors.ClientNotConnectedError()
|
||||
return mock_sso_config
|
||||
|
||||
mock_prisma_client = MagicMock()
|
||||
mock_prisma_client.db.litellm_ssoconfig.find_unique = AsyncMock(
|
||||
side_effect=_flaky_find_unique
|
||||
)
|
||||
mock_prisma_client.attempt_db_reconnect = AsyncMock(return_value=True)
|
||||
mock_prisma_client._db_auth_reconnect_timeout_seconds = 2.0
|
||||
mock_prisma_client._db_auth_reconnect_lock_timeout_seconds = 0.1
|
||||
|
||||
with patch.object(
|
||||
proxy_config, "_decrypt_and_set_db_env_variables"
|
||||
) as mock_decrypt:
|
||||
await proxy_config._init_sso_settings_in_db(prisma_client=mock_prisma_client)
|
||||
|
||||
assert len(invocations) == 2
|
||||
mock_prisma_client.attempt_db_reconnect.assert_awaited_once()
|
||||
reconnect_kwargs = mock_prisma_client.attempt_db_reconnect.await_args.kwargs
|
||||
assert reconnect_kwargs["reason"] == "init_sso_settings_in_db_lookup_failure"
|
||||
mock_decrypt.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_sso_settings_in_db_propagates_when_reconnect_fails():
|
||||
"""When reconnect returns False (cooldown / lock contention), the original
|
||||
ClientNotConnectedError is caught by the function's `except Exception` and
|
||||
logged — no retry storm, no crash."""
|
||||
import prisma
|
||||
|
||||
from litellm.proxy.proxy_server import ProxyConfig
|
||||
|
||||
proxy_config = ProxyConfig()
|
||||
mock_prisma_client = MagicMock()
|
||||
mock_prisma_client.db.litellm_ssoconfig.find_unique = AsyncMock(
|
||||
side_effect=prisma.errors.ClientNotConnectedError()
|
||||
)
|
||||
mock_prisma_client.attempt_db_reconnect = AsyncMock(return_value=False)
|
||||
mock_prisma_client._db_auth_reconnect_timeout_seconds = 2.0
|
||||
mock_prisma_client._db_auth_reconnect_lock_timeout_seconds = 0.1
|
||||
|
||||
# Should NOT raise — the function's own try/except swallows the propagated error.
|
||||
await proxy_config._init_sso_settings_in_db(prisma_client=mock_prisma_client)
|
||||
|
||||
mock_prisma_client.attempt_db_reconnect.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_hashicorp_vault_config_override_retries_on_transport_error():
|
||||
"""`_init_hashicorp_vault_config_override` self-heals across one
|
||||
ClientNotConnectedError via call_with_db_reconnect_retry."""
|
||||
import prisma
|
||||
|
||||
from litellm.proxy.proxy_server import ProxyConfig
|
||||
|
||||
proxy_config = ProxyConfig()
|
||||
proxy_config._last_hashicorp_vault_config = None
|
||||
|
||||
invocations: list = []
|
||||
|
||||
async def _flaky_find_unique(**kwargs):
|
||||
invocations.append(None)
|
||||
if len(invocations) == 1:
|
||||
raise prisma.errors.ClientNotConnectedError()
|
||||
return None # No config in DB → function returns early after retry.
|
||||
|
||||
mock_prisma_client = MagicMock()
|
||||
mock_prisma_client.db.litellm_configoverrides.find_unique = AsyncMock(
|
||||
side_effect=_flaky_find_unique
|
||||
)
|
||||
mock_prisma_client.attempt_db_reconnect = AsyncMock(return_value=True)
|
||||
mock_prisma_client._db_auth_reconnect_timeout_seconds = 2.0
|
||||
mock_prisma_client._db_auth_reconnect_lock_timeout_seconds = 0.1
|
||||
|
||||
await proxy_config._init_hashicorp_vault_config_override(
|
||||
prisma_client=mock_prisma_client
|
||||
)
|
||||
|
||||
assert len(invocations) == 2
|
||||
mock_prisma_client.attempt_db_reconnect.assert_awaited_once()
|
||||
reconnect_kwargs = mock_prisma_client.attempt_db_reconnect.await_args.kwargs
|
||||
assert (
|
||||
reconnect_kwargs["reason"]
|
||||
== "init_hashicorp_vault_config_override_lookup_failure"
|
||||
)
|
||||
|
||||
|
||||
def test_update_config_fields_uppercases_env_vars(monkeypatch):
|
||||
"""
|
||||
Ensure environment variables pulled from DB are uppercased when applied so
|
||||
|
||||
Loading…
Reference in New Issue
Block a user