diff --git a/litellm/proxy/db/tool_registry_writer.py b/litellm/proxy/db/tool_registry_writer.py index bbcc7396d6..08bc8944b9 100644 --- a/litellm/proxy/db/tool_registry_writer.py +++ b/litellm/proxy/db/tool_registry_writer.py @@ -11,6 +11,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union from litellm._logging import verbose_proxy_logger from litellm.proxy._types import ToolDiscoveryQueueItem +from litellm.proxy.db.exception_handler import call_with_db_reconnect_retry from litellm.repositories.object_permission_repository import ObjectPermissionRepository from litellm.repositories.table_repositories import ToolRepository from litellm.types.tool_management import ( @@ -309,7 +310,11 @@ class ToolPolicyRegistry: async def sync_tool_policy_from_db(self, prisma_client: "PrismaClient") -> None: """Load all tool policies and object-permission blocked_tools from DB.""" try: - tools = await ToolRepository(prisma_client).table.find_many() + tools = await call_with_db_reconnect_retry( + prisma_client, + lambda: ToolRepository(prisma_client).table.find_many(), + reason="sync_tool_policy_from_db_tools_lookup_failure", + ) self._tool_input_policies = { row.tool_name: getattr(row, "input_policy", "untrusted") or "untrusted" for row in tools @@ -319,7 +324,11 @@ class ToolPolicyRegistry: for row in tools } - perms = await ObjectPermissionRepository(prisma_client).table.find_many() + perms = await call_with_db_reconnect_retry( + prisma_client, + lambda: ObjectPermissionRepository(prisma_client).table.find_many(), + reason="sync_tool_policy_from_db_perms_lookup_failure", + ) self._blocked_tools_by_op_id = {} for row in perms: op_id = getattr(row, "object_permission_id", None) diff --git a/litellm/proxy/management_endpoints/cache_settings_endpoints.py b/litellm/proxy/management_endpoints/cache_settings_endpoints.py index d8eb5dfee9..b6ddf2d8e0 100644 --- a/litellm/proxy/management_endpoints/cache_settings_endpoints.py +++ b/litellm/proxy/management_endpoints/cache_settings_endpoints.py @@ -27,6 +27,7 @@ from litellm.proxy._types import ( UserAPIKeyAuth, ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy.db.exception_handler import call_with_db_reconnect_retry from litellm.repositories.table_repositories import CacheConfigRepository from litellm.types.management_endpoints import ( CACHE_SETTINGS_FIELDS, @@ -160,8 +161,12 @@ class CacheSettingsManager: import json try: - cache_config = await CacheConfigRepository(prisma_client).table.find_unique( - where={"id": "cache_config"} + cache_config = await call_with_db_reconnect_retry( + prisma_client, + lambda: CacheConfigRepository(prisma_client).table.find_unique( + where={"id": "cache_config"} + ), + reason="init_cache_settings_in_db_lookup_failure", ) if cache_config is not None and cache_config.cache_settings: # Parse cache settings JSON diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 37a0285b19..709b5dcf7e 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -311,7 +311,10 @@ from litellm.proxy.common_utils.user_api_key_cache import UserApiKeyCache from litellm.proxy.container_endpoints.endpoints import router as container_router from litellm.proxy.credential_endpoints.endpoints import router as credential_router from litellm.proxy.db.db_transaction_queue.spend_log_cleanup import SpendLogCleanup -from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler +from litellm.proxy.db.exception_handler import ( + PrismaDBExceptionHandler, + call_with_db_reconnect_retry, +) from litellm.proxy.db.spend_counter_reseed import SpendCounterReseed from litellm.proxy.discovery_endpoints import ui_discovery_endpoints_router from litellm.proxy.fine_tuning_endpoints.endpoints import router as fine_tuning_router @@ -5984,8 +5987,12 @@ class ProxyConfig: """ try: - sso_settings = await SSOConfigRepository(prisma_client).table.find_unique( - where={"id": "sso_config"} + sso_settings = await call_with_db_reconnect_retry( + prisma_client, + lambda: SSOConfigRepository(prisma_client).table.find_unique( + where={"id": "sso_config"} + ), + reason="init_sso_settings_in_db_lookup_failure", ) if sso_settings is not None: sso_settings.sso_settings.pop("role_mappings", None) @@ -6020,9 +6027,13 @@ class ProxyConfig: ) try: - db_record = await ConfigOverridesRepository( - prisma_client - ).table.find_unique(where={"config_type": "hashicorp_vault"}) + db_record = await call_with_db_reconnect_retry( + prisma_client, + lambda: ConfigOverridesRepository(prisma_client).table.find_unique( + where={"config_type": "hashicorp_vault"} + ), + reason="init_hashicorp_vault_config_override_lookup_failure", + ) if db_record is None or db_record.config_value is None: if self._last_hashicorp_vault_config is not None: diff --git a/litellm/proxy/search_endpoints/search_tool_registry.py b/litellm/proxy/search_endpoints/search_tool_registry.py index 588d71b77f..2ec2533211 100644 --- a/litellm/proxy/search_endpoints/search_tool_registry.py +++ b/litellm/proxy/search_endpoints/search_tool_registry.py @@ -7,6 +7,7 @@ from typing import List, Optional from litellm._logging import verbose_proxy_logger from litellm.litellm_core_utils.safe_json_dumps import safe_dumps +from litellm.proxy.db.exception_handler import call_with_db_reconnect_retry from litellm.proxy.utils import PrismaClient from litellm.repositories.table_repositories import SearchToolsRepository from litellm.types.search import SearchTool @@ -180,10 +181,12 @@ class SearchToolRegistry: List of search tool configurations """ try: - search_tools_from_db = await SearchToolsRepository( - prisma_client - ).table.find_many( - order={"created_at": "desc"}, + search_tools_from_db = await call_with_db_reconnect_retry( + prisma_client, + lambda: SearchToolsRepository(prisma_client).table.find_many( + order={"created_at": "desc"}, + ), + reason="get_all_search_tools_from_db_lookup_failure", ) search_tools: List[SearchTool] = [] diff --git a/tests/test_litellm/proxy/db/test_tool_registry_writer.py b/tests/test_litellm/proxy/db/test_tool_registry_writer.py index 8074871c3d..7bf1ffda4f 100644 --- a/tests/test_litellm/proxy/db/test_tool_registry_writer.py +++ b/tests/test_litellm/proxy/db/test_tool_registry_writer.py @@ -291,3 +291,77 @@ async def test_tool_policy_registry_not_initialized_returns_untrusted(): assert not registry.is_initialized() result = registry.get_effective_policies(["unknown_tool"]) assert result == {"unknown_tool": "untrusted"} + + +@pytest.mark.asyncio +async def test_sync_tool_policy_from_db_retries_on_transport_error_first_read(): + """`ToolPolicyRegistry.sync_tool_policy_from_db` self-heals across one + ClientNotConnectedError on the tools read — the perms read still fires + after the recovery and the registry initializes cleanly.""" + import prisma as prisma_pkg + + registry = ToolPolicyRegistry() + invocations: list = [] + + async def _flaky_find_many(): + invocations.append(None) + if len(invocations) == 1: + raise prisma_pkg.errors.ClientNotConnectedError() + return [] + + mock_prisma_client = MagicMock() + mock_prisma_client.db.litellm_tooltable.find_many = AsyncMock( + side_effect=_flaky_find_many + ) + mock_prisma_client.db.litellm_objectpermissiontable.find_many = AsyncMock( + return_value=[] + ) + mock_prisma_client.attempt_db_reconnect = AsyncMock(return_value=True) + mock_prisma_client._db_auth_reconnect_timeout_seconds = 2.0 + mock_prisma_client._db_auth_reconnect_lock_timeout_seconds = 0.1 + + await registry.sync_tool_policy_from_db(mock_prisma_client) + + assert len(invocations) == 2 + mock_prisma_client.attempt_db_reconnect.assert_awaited_once() + reconnect_kwargs = mock_prisma_client.attempt_db_reconnect.await_args.kwargs + assert ( + reconnect_kwargs["reason"] + == "sync_tool_policy_from_db_tools_lookup_failure" + ) + assert registry.is_initialized() + + +@pytest.mark.asyncio +async def test_sync_tool_policy_from_db_retries_on_transport_error_second_read(): + """Same as above but the blip happens on the perms read — distinct reason + tag in telemetry confirms the second wrap is wired separately.""" + import prisma as prisma_pkg + + registry = ToolPolicyRegistry() + perms_invocations: list = [] + + async def _flaky_perms_find_many(): + perms_invocations.append(None) + if len(perms_invocations) == 1: + raise prisma_pkg.errors.ClientNotConnectedError() + return [] + + mock_prisma_client = MagicMock() + mock_prisma_client.db.litellm_tooltable.find_many = AsyncMock(return_value=[]) + mock_prisma_client.db.litellm_objectpermissiontable.find_many = AsyncMock( + side_effect=_flaky_perms_find_many + ) + mock_prisma_client.attempt_db_reconnect = AsyncMock(return_value=True) + mock_prisma_client._db_auth_reconnect_timeout_seconds = 2.0 + mock_prisma_client._db_auth_reconnect_lock_timeout_seconds = 0.1 + + await registry.sync_tool_policy_from_db(mock_prisma_client) + + assert len(perms_invocations) == 2 + mock_prisma_client.attempt_db_reconnect.assert_awaited_once() + reconnect_kwargs = mock_prisma_client.attempt_db_reconnect.await_args.kwargs + assert ( + reconnect_kwargs["reason"] + == "sync_tool_policy_from_db_perms_lookup_failure" + ) diff --git a/tests/test_litellm/proxy/management_endpoints/search_endpoints/test_search_tool_management.py b/tests/test_litellm/proxy/management_endpoints/search_endpoints/test_search_tool_management.py index ea7e5591f1..f2ccfcd015 100644 --- a/tests/test_litellm/proxy/management_endpoints/search_endpoints/test_search_tool_management.py +++ b/tests/test_litellm/proxy/management_endpoints/search_endpoints/test_search_tool_management.py @@ -611,6 +611,45 @@ async def test_list_search_tools_db_masking_sensitive_values(monkeypatch): app.dependency_overrides.pop(user_api_key_auth, None) +@pytest.mark.asyncio +async def test_get_all_search_tools_from_db_retries_on_transport_error(): + """`SearchToolRegistry.get_all_search_tools_from_db` self-heals across one + ClientNotConnectedError via call_with_db_reconnect_retry.""" + import prisma + from litellm.proxy.search_endpoints.search_tool_registry import ( + SearchToolRegistry, + ) + + invocations: list = [] + + async def _flaky_find_many(**kwargs): + invocations.append(None) + if len(invocations) == 1: + raise prisma.errors.ClientNotConnectedError() + return [] + + mock_prisma_client = MagicMock() + mock_prisma_client.db.litellm_searchtoolstable.find_many = AsyncMock( + side_effect=_flaky_find_many + ) + mock_prisma_client.attempt_db_reconnect = AsyncMock(return_value=True) + mock_prisma_client._db_auth_reconnect_timeout_seconds = 2.0 + mock_prisma_client._db_auth_reconnect_lock_timeout_seconds = 0.1 + + result = await SearchToolRegistry.get_all_search_tools_from_db( + prisma_client=mock_prisma_client + ) + + assert result == [] + assert len(invocations) == 2 + mock_prisma_client.attempt_db_reconnect.assert_awaited_once() + reconnect_kwargs = mock_prisma_client.attempt_db_reconnect.await_args.kwargs + assert ( + reconnect_kwargs["reason"] + == "get_all_search_tools_from_db_lookup_failure" + ) + + @contextlib.contextmanager def _mock_search_tool_backend(db_tools): """Patch the DB registry, prisma client, and config so /search_tools/list diff --git a/tests/test_litellm/proxy/management_endpoints/test_cache_settings_endpoints.py b/tests/test_litellm/proxy/management_endpoints/test_cache_settings_endpoints.py index b892c4e556..4bdef2e8f9 100644 --- a/tests/test_litellm/proxy/management_endpoints/test_cache_settings_endpoints.py +++ b/tests/test_litellm/proxy/management_endpoints/test_cache_settings_endpoints.py @@ -259,6 +259,41 @@ class TestCacheSettingsManager: mock_proxy_config._init_cache.assert_not_called() mock_proxy_config.switch_on_llm_response_caching.assert_not_called() + @pytest.mark.asyncio + async def test_init_cache_settings_in_db_retries_on_transport_error(self): + """`CacheSettingsManager.init_cache_settings_in_db` self-heals across one + ClientNotConnectedError via call_with_db_reconnect_retry.""" + import prisma + + invocations: list = [] + + async def _flaky_find_unique(**kwargs): + invocations.append(None) + if len(invocations) == 1: + raise prisma.errors.ClientNotConnectedError() + return None # No config → function returns early after retry. + + mock_prisma_client = MagicMock() + mock_prisma_client.db.litellm_cacheconfig.find_unique = AsyncMock( + side_effect=_flaky_find_unique + ) + mock_prisma_client.attempt_db_reconnect = AsyncMock(return_value=True) + mock_prisma_client._db_auth_reconnect_timeout_seconds = 2.0 + mock_prisma_client._db_auth_reconnect_lock_timeout_seconds = 0.1 + mock_proxy_config = MagicMock() + + await CacheSettingsManager.init_cache_settings_in_db( + prisma_client=mock_prisma_client, proxy_config=mock_proxy_config + ) + + assert len(invocations) == 2 + mock_prisma_client.attempt_db_reconnect.assert_awaited_once() + reconnect_kwargs = mock_prisma_client.attempt_db_reconnect.await_args.kwargs + assert ( + reconnect_kwargs["reason"] + == "init_cache_settings_in_db_lookup_failure" + ) + # ── Audit-log emission for /cache/settings ──────────────────────────────────── diff --git a/tests/test_litellm/proxy/test_proxy_server.py b/tests/test_litellm/proxy/test_proxy_server.py index b2e36fd64c..b1dee205ee 100644 --- a/tests/test_litellm/proxy/test_proxy_server.py +++ b/tests/test_litellm/proxy/test_proxy_server.py @@ -4324,6 +4324,111 @@ async def test_init_sso_settings_in_db_empty_settings(): assert uppercased_settings == {} +@pytest.mark.asyncio +async def test_init_sso_settings_in_db_retries_on_transport_error(): + """`_init_sso_settings_in_db` self-heals across one ClientNotConnectedError + via call_with_db_reconnect_retry — mirrors the auth-path behavior so + startup/reload bursts don't spam the log.""" + import prisma + + from litellm.proxy.proxy_server import ProxyConfig + + proxy_config = ProxyConfig() + mock_sso_config = MagicMock() + mock_sso_config.sso_settings = {"GOOGLE_CLIENT_ID": "xxx"} + + invocations: list = [] + + async def _flaky_find_unique(**kwargs): + invocations.append(None) + if len(invocations) == 1: + raise prisma.errors.ClientNotConnectedError() + return mock_sso_config + + mock_prisma_client = MagicMock() + mock_prisma_client.db.litellm_ssoconfig.find_unique = AsyncMock( + side_effect=_flaky_find_unique + ) + mock_prisma_client.attempt_db_reconnect = AsyncMock(return_value=True) + mock_prisma_client._db_auth_reconnect_timeout_seconds = 2.0 + mock_prisma_client._db_auth_reconnect_lock_timeout_seconds = 0.1 + + with patch.object( + proxy_config, "_decrypt_and_set_db_env_variables" + ) as mock_decrypt: + await proxy_config._init_sso_settings_in_db(prisma_client=mock_prisma_client) + + assert len(invocations) == 2 + mock_prisma_client.attempt_db_reconnect.assert_awaited_once() + reconnect_kwargs = mock_prisma_client.attempt_db_reconnect.await_args.kwargs + assert reconnect_kwargs["reason"] == "init_sso_settings_in_db_lookup_failure" + mock_decrypt.assert_called_once() + + +@pytest.mark.asyncio +async def test_init_sso_settings_in_db_propagates_when_reconnect_fails(): + """When reconnect returns False (cooldown / lock contention), the original + ClientNotConnectedError is caught by the function's `except Exception` and + logged — no retry storm, no crash.""" + import prisma + + from litellm.proxy.proxy_server import ProxyConfig + + proxy_config = ProxyConfig() + mock_prisma_client = MagicMock() + mock_prisma_client.db.litellm_ssoconfig.find_unique = AsyncMock( + side_effect=prisma.errors.ClientNotConnectedError() + ) + mock_prisma_client.attempt_db_reconnect = AsyncMock(return_value=False) + mock_prisma_client._db_auth_reconnect_timeout_seconds = 2.0 + mock_prisma_client._db_auth_reconnect_lock_timeout_seconds = 0.1 + + # Should NOT raise — the function's own try/except swallows the propagated error. + await proxy_config._init_sso_settings_in_db(prisma_client=mock_prisma_client) + + mock_prisma_client.attempt_db_reconnect.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_init_hashicorp_vault_config_override_retries_on_transport_error(): + """`_init_hashicorp_vault_config_override` self-heals across one + ClientNotConnectedError via call_with_db_reconnect_retry.""" + import prisma + + from litellm.proxy.proxy_server import ProxyConfig + + proxy_config = ProxyConfig() + proxy_config._last_hashicorp_vault_config = None + + invocations: list = [] + + async def _flaky_find_unique(**kwargs): + invocations.append(None) + if len(invocations) == 1: + raise prisma.errors.ClientNotConnectedError() + return None # No config in DB → function returns early after retry. + + mock_prisma_client = MagicMock() + mock_prisma_client.db.litellm_configoverrides.find_unique = AsyncMock( + side_effect=_flaky_find_unique + ) + mock_prisma_client.attempt_db_reconnect = AsyncMock(return_value=True) + mock_prisma_client._db_auth_reconnect_timeout_seconds = 2.0 + mock_prisma_client._db_auth_reconnect_lock_timeout_seconds = 0.1 + + await proxy_config._init_hashicorp_vault_config_override( + prisma_client=mock_prisma_client + ) + + assert len(invocations) == 2 + mock_prisma_client.attempt_db_reconnect.assert_awaited_once() + reconnect_kwargs = mock_prisma_client.attempt_db_reconnect.await_args.kwargs + assert ( + reconnect_kwargs["reason"] + == "init_hashicorp_vault_config_override_lookup_failure" + ) + + def test_update_config_fields_uppercases_env_vars(monkeypatch): """ Ensure environment variables pulled from DB are uppercased when applied so