From da9d64b4de4b6927d3496f89fa402490a98bfb10 Mon Sep 17 00:00:00 2001 From: Yassin Kortam Date: Wed, 10 Jun 2026 16:48:11 -0700 Subject: [PATCH] fix(proxy): return 5xx on DB infra errors during auth; reserve 401 for genuine auth failures (#29986) --- litellm/proxy/auth/auth_exception_handler.py | 10 + litellm/proxy/db/exception_handler.py | 86 ++++++++ .../proxy/auth/test_auth_exception_handler.py | 160 ++++++++++++++ .../proxy/auth/test_user_api_key_auth.py | 119 ++++++++++- .../proxy/db/test_exception_handler.py | 195 ++++++++++++++++++ 5 files changed, 569 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/auth/auth_exception_handler.py b/litellm/proxy/auth/auth_exception_handler.py index f76949f4d1..83f1817318 100644 --- a/litellm/proxy/auth/auth_exception_handler.py +++ b/litellm/proxy/auth/auth_exception_handler.py @@ -168,6 +168,16 @@ class UserAPIKeyAuthExceptionHandler: ) elif isinstance(e, ProxyException): raise e + if PrismaDBExceptionHandler.is_database_service_unavailable_error(e): + raise ProxyException( + message=( + "Service Unavailable, the authentication database is " + "temporarily unreachable. Please retry shortly." + ), + type=ProxyErrorTypes.no_db_connection, + param="None", + code=status.HTTP_503_SERVICE_UNAVAILABLE, + ) raise ProxyException( message="Authentication Error, " + str(e), type=ProxyErrorTypes.auth_error, diff --git a/litellm/proxy/db/exception_handler.py b/litellm/proxy/db/exception_handler.py index ab9d341aa5..c500e72759 100644 --- a/litellm/proxy/db/exception_handler.py +++ b/litellm/proxy/db/exception_handler.py @@ -109,6 +109,92 @@ class PrismaDBExceptionHandler: return True return False + @staticmethod + def is_prisma_engine_internal_error(e: Exception) -> bool: + """True iff ``e`` is a non-``PrismaError`` exception raised from inside + prisma-client-py's query-engine layer. + + During the instant a DB connection is torn down, the query engine can + return a malformed error payload (``user_facing_error.meta`` is + ``null``). prisma-client-py's ``handle_response_errors`` then crashes + with ``AttributeError: 'NoneType' object has no attribute 'get'`` + before it can raise the proper P1001 "can't reach database server" + error. That AttributeError carries no connection keyword, so it can't + be matched by message; identify it by its ``prisma.engine`` origin + instead. + + Recognized ``PrismaError`` subclasses are excluded: connectivity ones + are already classified by type/keyword above, and data-layer ones + (the DB IS reachable) must stay 401. + """ + import prisma + + if isinstance(e, prisma.errors.PrismaError): + return False + tb = getattr(e, "__traceback__", None) + while tb is not None: + if tb.tb_frame.f_globals.get("__name__", "").startswith("prisma.engine"): + return True + tb = tb.tb_next + return False + + @staticmethod + def is_database_service_unavailable_error(e: Exception) -> bool: + """True iff the exception means the database could not answer at the + infrastructure level (connection refused, socket/interface failure, + timeout) rather than a genuine auth failure (key not found) or a + data-layer error (the DB IS reachable and rejected the data). + + Auth must answer 401 only for a key the DB confirms is invalid. When + the DB itself is unreachable, the request has to surface as 503 so + callers retry instead of treating valid keys as invalid during an + outage. + + Note: prisma-client-py mislabels the P1001 "can't reach database + server" connectivity failure as a ``DataError`` (a data-layer type), + so a type-only check misses real outages. ``is_database_transport_error`` + keyword-matches the connection message and catches that masquerade, + while genuine data errors (no connection keyword) correctly stay 401. + + The Postgres "cached plan must not change result type" error is matched + here, not in ``is_database_transport_error``: it is a transient stale-DB- + state condition (not an invalid key), but the connection is healthy so it + must not trigger a reconnect. + + A non-``PrismaError`` raised from inside the prisma query engine (e.g. + the ``AttributeError`` from ``handle_response_errors`` when the engine + returns a malformed error payload mid-tear-down) is also treated as + unavailable; see ``is_prisma_engine_internal_error``. + """ + import asyncio + + if PrismaDBExceptionHandler.is_database_connection_error(e): + return True + if PrismaDBExceptionHandler.is_database_transport_error(e): + return True + if PrismaDBExceptionHandler.is_prisma_engine_internal_error(e): + return True + if "cached plan must not change result type" in str(e).lower(): + return True + + # OSError already covers ConnectionError and (Py3.3+) TimeoutError. + # asyncio.TimeoutError is a distinct class before Py3.11. + if isinstance(e, (OSError, asyncio.TimeoutError)): + return True + + try: + import asyncpg + except ImportError: + return False + + return isinstance( + e, + ( + asyncpg.exceptions.PostgresConnectionError, + asyncpg.exceptions.InterfaceError, + ), + ) + @staticmethod def handle_db_exception(e: Exception): """ diff --git a/tests/test_litellm/proxy/auth/test_auth_exception_handler.py b/tests/test_litellm/proxy/auth/test_auth_exception_handler.py index 27f6015e6f..11e6f483e3 100644 --- a/tests/test_litellm/proxy/auth/test_auth_exception_handler.py +++ b/tests/test_litellm/proxy/auth/test_auth_exception_handler.py @@ -112,6 +112,166 @@ async def test_handle_authentication_error_data_layer_errors_do_not_fall_back( ) +@pytest.mark.asyncio +@pytest.mark.parametrize( + "db_error", + [ + ConnectionError("connection refused"), + TimeoutError("timed out"), + asyncio.TimeoutError(), + OSError("network is unreachable"), + HTTPClientClosedError(), + PrismaError("can't reach database server"), + RawQueryError( + data={ + "user_facing_error": { + "message": "cached plan must not change result type", + "meta": {"table": "t"}, + } + } + ), + ], +) +async def test_handle_authentication_error_db_infra_error_returns_503(db_error): + """Regression for the outage where valid keys got 401 for 4 hours: an + infrastructure-level DB failure during auth must surface as 503 (the DB + could not confirm the key), never as 401 ("Invalid API key").""" + handler = UserAPIKeyAuthExceptionHandler() + + with ( + patch( + "litellm.proxy.proxy_server.proxy_logging_obj.post_call_failure_hook", + new_callable=AsyncMock, + return_value=None, + ), + patch( + "litellm.proxy.auth.auth_exception_handler.seed_request_identity", + ), + patch( + "litellm.proxy.proxy_server.general_settings", + {"allow_requests_on_db_unavailable": False}, + ), + ): + with pytest.raises(ProxyException) as exc_info: + await handler._handle_authentication_error( + db_error, + MagicMock(), + {}, + "/v1/chat/completions", + None, + "sk-valid-but-db-down", + ) + + assert int(exc_info.value.code) == status.HTTP_503_SERVICE_UNAVAILABLE + assert exc_info.value.type == ProxyErrorTypes.no_db_connection + assert "Invalid API key" not in str(exc_info.value.message) + + +@pytest.mark.asyncio +async def test_handle_authentication_error_prisma_engine_teardown_returns_503(): + """Regression for the first-request-of-an-outage edge case: at the instant + the DB socket drops, the prisma query engine returns a malformed error + payload and prisma-client-py crashes with a bare + ``AttributeError: 'NoneType' object has no attribute 'get'`` before it can + raise P1001. That AttributeError reached auth and fell through to 401. It + must surface as 503 like every other infra failure during the outage.""" + from prisma.engine import utils as prisma_engine_utils + + malformed_payload = [ + { + "error": "Can't reach database server", + "user_facing_error": { + "error_code": "P1001", + "message": "Can't reach database server at `localhost`:`5503`", + "meta": None, + }, + } + ] + try: + prisma_engine_utils.handle_response_errors(None, malformed_payload) + raise AssertionError("expected prisma to raise AttributeError") + except AttributeError as e: + teardown_error = e + + handler = UserAPIKeyAuthExceptionHandler() + + with ( + patch( + "litellm.proxy.proxy_server.proxy_logging_obj.post_call_failure_hook", + new_callable=AsyncMock, + return_value=None, + ), + patch( + "litellm.proxy.auth.auth_exception_handler.seed_request_identity", + ), + patch( + "litellm.proxy.proxy_server.general_settings", + {"allow_requests_on_db_unavailable": False}, + ), + ): + with pytest.raises(ProxyException) as exc_info: + await handler._handle_authentication_error( + teardown_error, + MagicMock(), + {}, + "/v1/chat/completions", + None, + "sk-valid-but-db-down", + ) + + assert int(exc_info.value.code) == status.HTTP_503_SERVICE_UNAVAILABLE + assert exc_info.value.type == ProxyErrorTypes.no_db_connection + assert "Invalid API key" not in str(exc_info.value.message) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "auth_error", + [ + # DB returned no row -> get_key_object raises this exact 401. + ProxyException( + message="Authentication Error, Invalid proxy server token passed.", + type=ProxyErrorTypes.token_not_found_in_db, + param="key", + code=status.HTTP_401_UNAUTHORIZED, + ), + # A bare auth failure raised as a plain Exception (e.g. master-key-only + # route) must keep returning 401, not get reclassified as 503. + Exception("Invalid proxy server token passed"), + ], +) +async def test_handle_authentication_error_genuine_auth_failure_stays_401(auth_error): + """Guard against the 503 conversion being too broad: a genuine auth + failure (missing key / wrong key) must still be 401.""" + handler = UserAPIKeyAuthExceptionHandler() + + with ( + patch( + "litellm.proxy.proxy_server.proxy_logging_obj.post_call_failure_hook", + new_callable=AsyncMock, + return_value=None, + ), + patch( + "litellm.proxy.auth.auth_exception_handler.seed_request_identity", + ), + patch( + "litellm.proxy.proxy_server.general_settings", + {"allow_requests_on_db_unavailable": False}, + ), + ): + with pytest.raises(ProxyException) as exc_info: + await handler._handle_authentication_error( + auth_error, + MagicMock(), + {}, + "/v1/chat/completions", + None, + "sk-bad-key", + ) + + assert int(exc_info.value.code) == status.HTTP_401_UNAUTHORIZED + + @pytest.mark.asyncio async def test_handle_authentication_error_budget_exceeded(): handler = UserAPIKeyAuthExceptionHandler() diff --git a/tests/test_litellm/proxy/auth/test_user_api_key_auth.py b/tests/test_litellm/proxy/auth/test_user_api_key_auth.py index 0236646c79..80f12d4459 100644 --- a/tests/test_litellm/proxy/auth/test_user_api_key_auth.py +++ b/tests/test_litellm/proxy/auth/test_user_api_key_auth.py @@ -1696,7 +1696,9 @@ class TestJWTOAuth2Coexistence: assert mock_auto_register.call_args.kwargs["team_id"] == "validated-team" assert mock_auto_register.call_args.kwargs["user_id"] == "validated-user" assert mock_auto_register.call_args.kwargs["org_id"] == "validated-org" - assert mock_auto_register.call_args.kwargs["end_user_id"] == "validated-end-user" + assert ( + mock_auto_register.call_args.kwargs["end_user_id"] == "validated-end-user" + ) assert result.org_id == "validated-org" @pytest.mark.asyncio @@ -3608,3 +3610,118 @@ async def test_user_api_key_auth_does_not_overwrite_end_user_id_set_by_builder() finally: for k, v in originals.items(): setattr(_proxy_server_mod, k, v) + + +def _proxy_attrs_for_db_lookup(): + """Minimal proxy_server attributes for driving the real + ``_user_api_key_auth_builder`` down to the DB key lookup.""" + proxy_logging_obj = MagicMock() + proxy_logging_obj.post_call_failure_hook = AsyncMock(return_value=None) + return { + "prisma_client": MagicMock(), + "user_api_key_cache": DualCache(), + "proxy_logging_obj": proxy_logging_obj, + "master_key": "sk-test-master", + "general_settings": {"allow_requests_on_db_unavailable": False}, + "llm_model_list": [], + "llm_router": None, + "open_telemetry_logger": None, + "model_max_budget_limiter": MagicMock(), + "user_custom_auth": None, + "jwt_handler": None, + "litellm_proxy_admin_name": "admin", + } + + +async def _run_builder_with_key_lookup(get_key_object_mock): + """Drive the real auth builder with ``get_key_object`` replaced by the + given mock. Returns the builder result. Patches ``seed_request_identity`` + so the failure path doesn't touch OTEL.""" + from fastapi import Request + from starlette.datastructures import URL + + import litellm.proxy.proxy_server as _proxy_server_mod + from litellm.proxy.auth.user_api_key_auth import _user_api_key_auth_builder + + attrs = _proxy_attrs_for_db_lookup() + originals = {a: getattr(_proxy_server_mod, a, None) for a in attrs} + try: + for k, v in attrs.items(): + setattr(_proxy_server_mod, k, v) + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + with ( + patch( + "litellm.proxy.auth.user_api_key_auth.get_key_object", + get_key_object_mock, + ), + patch( + "litellm.proxy.auth.auth_exception_handler.seed_request_identity", + ), + ): + return await _user_api_key_auth_builder( + request=request, + api_key="Bearer sk-db-lookup-test", + azure_api_key_header="", + anthropic_api_key_header=None, + google_ai_studio_api_key_header=None, + azure_apim_header=None, + request_data={}, + ) + finally: + for k, v in originals.items(): + setattr(_proxy_server_mod, k, v) + + +@pytest.mark.asyncio +async def test_builder_returns_503_when_db_lookup_raises_infra_error(): + """End-to-end: a DB infrastructure failure during the key lookup must + propagate past the ``except ProxyException`` guard and surface as 503, + not the 401 that masked the 4-hour outage. Killing the new 503 branch + flips this to 401 and fails the test.""" + get_key_object = AsyncMock(side_effect=ConnectionError("connection refused")) + + with pytest.raises(ProxyException) as exc_info: + await _run_builder_with_key_lookup(get_key_object) + + assert int(exc_info.value.code) == status.HTTP_503_SERVICE_UNAVAILABLE + assert exc_info.value.type == ProxyErrorTypes.no_db_connection + assert "Invalid API key" not in str(exc_info.value.message) + + +@pytest.mark.asyncio +async def test_builder_returns_401_when_db_lookup_reports_missing_key(): + """Regression guard: a genuinely missing key (DB returned no row, which + ``get_key_object`` raises as a 401 ProxyException) must still be 401.""" + missing_key_error = ProxyException( + message="Authentication Error, Invalid proxy server token passed. key=..., not found in db.", + type=ProxyErrorTypes.token_not_found_in_db, + param="key", + code=status.HTTP_401_UNAUTHORIZED, + ) + get_key_object = AsyncMock(side_effect=missing_key_error) + + with pytest.raises(ProxyException) as exc_info: + await _run_builder_with_key_lookup(get_key_object) + + assert int(exc_info.value.code) == status.HTTP_401_UNAUTHORIZED + + +@pytest.mark.asyncio +async def test_builder_succeeds_when_db_lookup_returns_valid_token(): + """Regression guard: a valid key still authenticates. Proves the 503 + conversion only fires on the failure path and never intercepts success.""" + valid_token = UserAPIKeyAuth(api_key="sk-db-lookup-test", token="hashed-valid") + get_key_object = AsyncMock(return_value=valid_token) + + with patch( + "litellm.proxy.auth.user_api_key_auth._return_user_api_key_auth_obj", + new_callable=AsyncMock, + return_value=valid_token, + ) as mock_return: + result = await _run_builder_with_key_lookup(get_key_object) + + assert isinstance(result, UserAPIKeyAuth) + # Reaching the success-assembly return (never the exception handler) + # proves a valid key is unaffected by the 503 conversion. + mock_return.assert_awaited_once() diff --git a/tests/test_litellm/proxy/db/test_exception_handler.py b/tests/test_litellm/proxy/db/test_exception_handler.py index 9dcf5df4ae..6021c22142 100644 --- a/tests/test_litellm/proxy/db/test_exception_handler.py +++ b/tests/test_litellm/proxy/db/test_exception_handler.py @@ -107,6 +107,201 @@ def test_is_database_connection_generic_errors(): ) +@pytest.mark.parametrize( + "error", + [ + ConnectionError("connection refused"), + TimeoutError("timed out"), + OSError("network is unreachable"), + asyncio.TimeoutError(), + HTTPClientClosedError(), + ClientNotConnectedError(), + PrismaError("can't reach database server"), + PrismaError(), + ], +) +def test_is_database_service_unavailable_error_infra_failures(error): + """Infrastructure-level failures (socket/connection/timeout, prisma + transport, unknown PrismaError) mean the DB could not answer, so auth + must surface 503 instead of treating a valid key as invalid.""" + assert PrismaDBExceptionHandler.is_database_service_unavailable_error(error) is True + + +def test_is_database_service_unavailable_error_prisma_p1001_masquerades_as_dataerror(): + """Real-world regression: prisma-client-py raises the P1001 "can't reach + database server" connectivity failure as a DataError (a data-layer type). + A type-only check would miss it and return 401 during a genuine outage; + the message keyword must still classify it as service-unavailable -> 503.""" + p1001_as_dataerror = DataError( + data={ + "user_facing_error": { + "message": "Can't reach database server at `127.0.0.1`:`5499`", + "meta": {"table": "t"}, + } + } + ) + assert ( + PrismaDBExceptionHandler.is_database_service_unavailable_error( + p1001_as_dataerror + ) + is True + ) + + +def test_is_database_service_unavailable_error_cached_plan_escapes_as_503(): + """Composes with the cached-plan retry: when that recovery fails and the + Postgres "cached plan must not change result type" error escapes (raised by + prisma as a data-layer RawQueryError), it is a transient stale-DB-state + condition, not an invalid key, so it must classify as service-unavailable + -> 503 rather than fall through to 401.""" + cached_plan_error = RawQueryError( + data={ + "user_facing_error": { + "message": "cached plan must not change result type", + "meta": {"table": "t"}, + } + } + ) + assert ( + PrismaDBExceptionHandler.is_database_service_unavailable_error( + cached_plan_error + ) + is True + ) + + +def test_is_database_service_unavailable_error_prisma_engine_malformed_payload(): + """Real-world regression: at the instant the DB socket drops, the prisma + query engine returns a malformed error payload (``user_facing_error.meta`` + is ``null``). prisma-client-py's ``handle_response_errors`` then crashes + with ``AttributeError: 'NoneType' object has no attribute 'get'`` before it + can raise the proper P1001 error. That bare AttributeError has no + connection keyword, so without the prisma-engine-origin check it falls + through to 401 on the first request of an outage. Reproduce the exact + prisma crash and assert it classifies as service-unavailable -> 503.""" + from prisma.engine import utils as prisma_engine_utils + + malformed_payload = [ + { + "error": "Can't reach database server", + "user_facing_error": { + "error_code": "P1001", + "message": "Can't reach database server at `localhost`:`5503`", + "meta": None, + }, + } + ] + with pytest.raises(AttributeError) as exc_info: + prisma_engine_utils.handle_response_errors(None, malformed_payload) + + assert "no attribute 'get'" in str(exc_info.value) + assert ( + PrismaDBExceptionHandler.is_database_service_unavailable_error(exc_info.value) + is True + ) + + +def test_is_prisma_engine_internal_error_excludes_application_attributeerror(): + """The prisma-engine-origin check must stay narrow: a genuine AttributeError + raised by application code (a real bug) must NOT be classified as + service-unavailable, otherwise real bugs would silently become 503s.""" + + def application_bug(): + none_value = None + return none_value.get("oops") + + with pytest.raises(AttributeError) as exc_info: + application_bug() + + assert ( + PrismaDBExceptionHandler.is_prisma_engine_internal_error(exc_info.value) + is False + ) + assert ( + PrismaDBExceptionHandler.is_database_service_unavailable_error(exc_info.value) + is False + ) + + +def test_is_prisma_engine_internal_error_excludes_data_layer_prisma_error(): + """A data-layer ``PrismaError`` (the DB IS reachable and rejected the data) + must stay 401. These are always raised from prisma internals, so the check + excludes any ``PrismaError`` by type before inspecting the traceback.""" + data_layer_error = UniqueViolationError( + data={"user_facing_error": {"meta": {"table": "t"}}} + ) + try: + raise data_layer_error + except UniqueViolationError as e: + assert PrismaDBExceptionHandler.is_prisma_engine_internal_error(e) is False + + +@pytest.mark.parametrize( + "error", + [ + DataError(data={"user_facing_error": {"meta": {"table": "t"}}}), + UniqueViolationError(data={"user_facing_error": {"meta": {"table": "t"}}}), + RecordNotFoundError(data={"user_facing_error": {"meta": {"table": "t"}}}), + Exception("some unrelated error"), + ValueError("bad value"), + ], +) +def test_is_database_service_unavailable_error_excludes_non_infra(error): + """Data-layer errors (the DB IS reachable and answered) and generic + non-DB errors must NOT be classified as service-unavailable, otherwise a + genuine 401 would be masked as a transient 503.""" + assert ( + PrismaDBExceptionHandler.is_database_service_unavailable_error(error) is False + ) + + +def test_is_database_service_unavailable_error_asyncpg(monkeypatch): + """asyncpg connection/interface errors map to service-unavailable. asyncpg + is not a hard dependency, so inject a stand-in module to exercise the + branch deterministically regardless of the install environment.""" + import sys + import types + + fake_asyncpg = types.ModuleType("asyncpg") + fake_exceptions = types.ModuleType("asyncpg.exceptions") + + class PostgresConnectionError(Exception): + pass + + class InterfaceError(Exception): + pass + + class UniqueViolationError(Exception): # data-layer, must stay False + pass + + fake_exceptions.PostgresConnectionError = PostgresConnectionError + fake_exceptions.InterfaceError = InterfaceError + fake_exceptions.UniqueViolationError = UniqueViolationError + fake_asyncpg.exceptions = fake_exceptions + + monkeypatch.setitem(sys.modules, "asyncpg", fake_asyncpg) + monkeypatch.setitem(sys.modules, "asyncpg.exceptions", fake_exceptions) + + assert ( + PrismaDBExceptionHandler.is_database_service_unavailable_error( + PostgresConnectionError("connection reset") + ) + is True + ) + assert ( + PrismaDBExceptionHandler.is_database_service_unavailable_error( + InterfaceError("connection was closed") + ) + is True + ) + assert ( + PrismaDBExceptionHandler.is_database_service_unavailable_error( + UniqueViolationError("duplicate key") + ) + is False + ) + + # Test should_allow_request_on_db_unavailable method @patch( "litellm.proxy.proxy_server.general_settings",