fix(proxy): return 5xx on DB infra errors during auth; reserve 401 for genuine auth failures (#29986)

This commit is contained in:
Yassin Kortam 2026-06-10 16:48:11 -07:00 committed by GitHub
parent ba72ccf52c
commit da9d64b4de
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 569 additions and 1 deletions

View File

@ -168,6 +168,16 @@ class UserAPIKeyAuthExceptionHandler:
)
elif isinstance(e, ProxyException):
raise e
if PrismaDBExceptionHandler.is_database_service_unavailable_error(e):
raise ProxyException(
message=(
"Service Unavailable, the authentication database is "
"temporarily unreachable. Please retry shortly."
),
type=ProxyErrorTypes.no_db_connection,
param="None",
code=status.HTTP_503_SERVICE_UNAVAILABLE,
)
raise ProxyException(
message="Authentication Error, " + str(e),
type=ProxyErrorTypes.auth_error,

View File

@ -109,6 +109,92 @@ class PrismaDBExceptionHandler:
return True
return False
@staticmethod
def is_prisma_engine_internal_error(e: Exception) -> bool:
"""True iff ``e`` is a non-``PrismaError`` exception raised from inside
prisma-client-py's query-engine layer.
During the instant a DB connection is torn down, the query engine can
return a malformed error payload (``user_facing_error.meta`` is
``null``). prisma-client-py's ``handle_response_errors`` then crashes
with ``AttributeError: 'NoneType' object has no attribute 'get'``
before it can raise the proper P1001 "can't reach database server"
error. That AttributeError carries no connection keyword, so it can't
be matched by message; identify it by its ``prisma.engine`` origin
instead.
Recognized ``PrismaError`` subclasses are excluded: connectivity ones
are already classified by type/keyword above, and data-layer ones
(the DB IS reachable) must stay 401.
"""
import prisma
if isinstance(e, prisma.errors.PrismaError):
return False
tb = getattr(e, "__traceback__", None)
while tb is not None:
if tb.tb_frame.f_globals.get("__name__", "").startswith("prisma.engine"):
return True
tb = tb.tb_next
return False
@staticmethod
def is_database_service_unavailable_error(e: Exception) -> bool:
    """Classify ``e`` as "the database could not answer" versus anything else.

    Returns True for infrastructure-level failures (connection refused,
    socket/interface failure, timeout) and False for genuine auth failures
    (key not found) or data-layer errors (the DB IS reachable and rejected
    the data). Auth should answer 401 only when the DB confirms a key is
    invalid; when the DB is unreachable the request should surface as 503
    so callers retry instead of treating valid keys as invalid.

    Checks, in order:

    1. ``is_database_connection_error`` / ``is_database_transport_error``.
       prisma-client-py reports the P1001 "can't reach database server"
       failure as a ``DataError``, so a type-only check would miss it; the
       transport check's message keyword match catches that case while
       genuine data errors (no connection keyword) stay 401.
    2. ``is_prisma_engine_internal_error``: a non-``PrismaError`` raised
       from inside the prisma query engine (e.g. the ``AttributeError``
       from ``handle_response_errors`` on a malformed payload).
    3. The Postgres "cached plan must not change result type" message. It
       is matched here rather than in ``is_database_transport_error``
       because it is a transient stale-state condition on a healthy
       connection and must not trigger a reconnect.
    4. ``OSError`` (which covers ``ConnectionError`` and ``TimeoutError``)
       and ``asyncio.TimeoutError`` (a distinct class before Python 3.11).
    5. asyncpg connection/interface errors, when asyncpg is importable.

    Args:
        e: The exception raised during the DB-backed auth lookup.

    Returns:
        True if any of the checks above matches, else False.
    """
    import asyncio

    classifier = PrismaDBExceptionHandler
    if (
        classifier.is_database_connection_error(e)
        or classifier.is_database_transport_error(e)
        or classifier.is_prisma_engine_internal_error(e)
    ):
        return True

    message = str(e).lower()
    if "cached plan must not change result type" in message:
        return True

    stdlib_infra_errors = (OSError, asyncio.TimeoutError)
    if isinstance(e, stdlib_infra_errors):
        return True

    # asyncpg is optional; without it there is nothing left to match.
    try:
        import asyncpg
    except ImportError:
        return False

    asyncpg_infra_errors = (
        asyncpg.exceptions.PostgresConnectionError,
        asyncpg.exceptions.InterfaceError,
    )
    return isinstance(e, asyncpg_infra_errors)
@staticmethod
def handle_db_exception(e: Exception):
"""

View File

@ -112,6 +112,166 @@ async def test_handle_authentication_error_data_layer_errors_do_not_fall_back(
)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "db_error",
    [
        ConnectionError("connection refused"),
        TimeoutError("timed out"),
        asyncio.TimeoutError(),
        OSError("network is unreachable"),
        HTTPClientClosedError(),
        PrismaError("can't reach database server"),
        RawQueryError(
            data={
                "user_facing_error": {
                    "message": "cached plan must not change result type",
                    "meta": {"table": "t"},
                }
            }
        ),
    ],
)
async def test_handle_authentication_error_db_infra_error_returns_503(db_error):
    """An infrastructure-level DB failure during auth must become a 503.

    Regression for the outage in which valid keys were answered with 401
    for hours: when the DB could not confirm the key, the handler has to
    raise a ``no_db_connection`` 503 and never an "Invalid API key" 401.
    """
    failure_hook_patch = patch(
        "litellm.proxy.proxy_server.proxy_logging_obj.post_call_failure_hook",
        new_callable=AsyncMock,
        return_value=None,
    )
    seed_identity_patch = patch(
        "litellm.proxy.auth.auth_exception_handler.seed_request_identity",
    )
    settings_patch = patch(
        "litellm.proxy.proxy_server.general_settings",
        {"allow_requests_on_db_unavailable": False},
    )
    auth_handler = UserAPIKeyAuthExceptionHandler()

    with failure_hook_patch, seed_identity_patch, settings_patch:
        with pytest.raises(ProxyException) as raised:
            await auth_handler._handle_authentication_error(
                db_error,
                MagicMock(),
                {},
                "/v1/chat/completions",
                None,
                "sk-valid-but-db-down",
            )

    proxy_error = raised.value
    assert int(proxy_error.code) == status.HTTP_503_SERVICE_UNAVAILABLE
    assert proxy_error.type == ProxyErrorTypes.no_db_connection
    assert "Invalid API key" not in str(proxy_error.message)
@pytest.mark.asyncio
async def test_handle_authentication_error_prisma_engine_teardown_returns_503():
    """The prisma tear-down ``AttributeError`` must surface as 503, not 401.

    Regression for the first request of an outage: when the DB socket
    drops, the query engine returns a malformed error payload and
    prisma-client-py crashes with a bare
    ``AttributeError: 'NoneType' object has no attribute 'get'`` before it
    can raise P1001. The test reproduces that crash with the real
    ``handle_response_errors`` and feeds the resulting exception to the
    auth error handler.
    """
    from prisma.engine import utils as prisma_engine_utils

    # ``meta`` is None: this is the malformed shape that makes prisma crash.
    broken_user_facing_error = {
        "error_code": "P1001",
        "message": "Can't reach database server at `localhost`:`5503`",
        "meta": None,
    }
    malformed_payload = [
        {
            "error": "Can't reach database server",
            "user_facing_error": broken_user_facing_error,
        }
    ]

    try:
        prisma_engine_utils.handle_response_errors(None, malformed_payload)
    except AttributeError as caught:
        teardown_error = caught
    else:
        raise AssertionError("expected prisma to raise AttributeError")

    failure_hook_patch = patch(
        "litellm.proxy.proxy_server.proxy_logging_obj.post_call_failure_hook",
        new_callable=AsyncMock,
        return_value=None,
    )
    seed_identity_patch = patch(
        "litellm.proxy.auth.auth_exception_handler.seed_request_identity",
    )
    settings_patch = patch(
        "litellm.proxy.proxy_server.general_settings",
        {"allow_requests_on_db_unavailable": False},
    )
    auth_handler = UserAPIKeyAuthExceptionHandler()

    with failure_hook_patch, seed_identity_patch, settings_patch:
        with pytest.raises(ProxyException) as raised:
            await auth_handler._handle_authentication_error(
                teardown_error,
                MagicMock(),
                {},
                "/v1/chat/completions",
                None,
                "sk-valid-but-db-down",
            )

    proxy_error = raised.value
    assert int(proxy_error.code) == status.HTTP_503_SERVICE_UNAVAILABLE
    assert proxy_error.type == ProxyErrorTypes.no_db_connection
    assert "Invalid API key" not in str(proxy_error.message)
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "auth_error",
    [
        # The 401 that get_key_object raises when the DB returns no row.
        ProxyException(
            message="Authentication Error, Invalid proxy server token passed.",
            type=ProxyErrorTypes.token_not_found_in_db,
            param="key",
            code=status.HTTP_401_UNAUTHORIZED,
        ),
        # A plain-Exception auth failure (e.g. master-key-only route) must
        # also remain 401 and not be reclassified as 503.
        Exception("Invalid proxy server token passed"),
    ],
)
async def test_handle_authentication_error_genuine_auth_failure_stays_401(auth_error):
    """A real auth failure (missing or wrong key) must still produce 401.

    Guards against the 503 conversion being too broad.
    """
    failure_hook_patch = patch(
        "litellm.proxy.proxy_server.proxy_logging_obj.post_call_failure_hook",
        new_callable=AsyncMock,
        return_value=None,
    )
    seed_identity_patch = patch(
        "litellm.proxy.auth.auth_exception_handler.seed_request_identity",
    )
    settings_patch = patch(
        "litellm.proxy.proxy_server.general_settings",
        {"allow_requests_on_db_unavailable": False},
    )
    auth_handler = UserAPIKeyAuthExceptionHandler()

    with failure_hook_patch, seed_identity_patch, settings_patch:
        with pytest.raises(ProxyException) as raised:
            await auth_handler._handle_authentication_error(
                auth_error,
                MagicMock(),
                {},
                "/v1/chat/completions",
                None,
                "sk-bad-key",
            )

    assert int(raised.value.code) == status.HTTP_401_UNAUTHORIZED
@pytest.mark.asyncio
async def test_handle_authentication_error_budget_exceeded():
handler = UserAPIKeyAuthExceptionHandler()

View File

@ -1696,7 +1696,9 @@ class TestJWTOAuth2Coexistence:
assert mock_auto_register.call_args.kwargs["team_id"] == "validated-team"
assert mock_auto_register.call_args.kwargs["user_id"] == "validated-user"
assert mock_auto_register.call_args.kwargs["org_id"] == "validated-org"
assert mock_auto_register.call_args.kwargs["end_user_id"] == "validated-end-user"
assert (
mock_auto_register.call_args.kwargs["end_user_id"] == "validated-end-user"
)
assert result.org_id == "validated-org"
@pytest.mark.asyncio
@ -3608,3 +3610,118 @@ async def test_user_api_key_auth_does_not_overwrite_end_user_id_set_by_builder()
finally:
for k, v in originals.items():
setattr(_proxy_server_mod, k, v)
def _proxy_attrs_for_db_lookup():
    """Build the proxy_server attribute overrides used by the DB-lookup tests.

    Returns a fresh dict mapping attribute names to the values the tests
    install on ``litellm.proxy.proxy_server`` so the real
    ``_user_api_key_auth_builder`` can run down to the DB key lookup. The
    logging object is a ``MagicMock`` whose ``post_call_failure_hook`` is
    an awaitable mock returning ``None``.
    """
    logging_stub = MagicMock()
    logging_stub.post_call_failure_hook = AsyncMock(return_value=None)

    settings = {"allow_requests_on_db_unavailable": False}

    return dict(
        prisma_client=MagicMock(),
        user_api_key_cache=DualCache(),
        proxy_logging_obj=logging_stub,
        master_key="sk-test-master",
        general_settings=settings,
        llm_model_list=[],
        llm_router=None,
        open_telemetry_logger=None,
        model_max_budget_limiter=MagicMock(),
        user_custom_auth=None,
        jwt_handler=None,
        litellm_proxy_admin_name="admin",
    )
async def _run_builder_with_key_lookup(get_key_object_mock):
    """Run the real auth builder with ``get_key_object`` swapped for a mock.

    Installs the attributes from ``_proxy_attrs_for_db_lookup`` on the
    proxy_server module, patches ``get_key_object`` with
    ``get_key_object_mock`` and ``seed_request_identity`` with a default
    mock (so the failure path does not reach OTEL), then awaits
    ``_user_api_key_auth_builder`` for a ``/chat/completions`` request.

    The previous module attribute values are put back in a ``finally``
    block; an attribute that did not exist beforehand is restored as
    ``None``.

    Returns:
        Whatever ``_user_api_key_auth_builder`` returns. Exceptions raised
        by the builder propagate to the caller.
    """
    from fastapi import Request
    from starlette.datastructures import URL

    import litellm.proxy.proxy_server as _proxy_server_mod
    from litellm.proxy.auth.user_api_key_auth import _user_api_key_auth_builder

    overrides = _proxy_attrs_for_db_lookup()
    saved = {}
    for attr_name in overrides:
        saved[attr_name] = getattr(_proxy_server_mod, attr_name, None)

    try:
        for attr_name, replacement in overrides.items():
            setattr(_proxy_server_mod, attr_name, replacement)

        http_request = Request(scope={"type": "http"})
        http_request._url = URL(url="/chat/completions")

        key_lookup_patch = patch(
            "litellm.proxy.auth.user_api_key_auth.get_key_object",
            get_key_object_mock,
        )
        seed_identity_patch = patch(
            "litellm.proxy.auth.auth_exception_handler.seed_request_identity",
        )
        with key_lookup_patch, seed_identity_patch:
            return await _user_api_key_auth_builder(
                request=http_request,
                api_key="Bearer sk-db-lookup-test",
                azure_api_key_header="",
                anthropic_api_key_header=None,
                google_ai_studio_api_key_header=None,
                azure_apim_header=None,
                request_data={},
            )
    finally:
        for attr_name, previous in saved.items():
            setattr(_proxy_server_mod, attr_name, previous)
@pytest.mark.asyncio
async def test_builder_returns_503_when_db_lookup_raises_infra_error():
    """End-to-end: an infra failure in the key lookup must surface as 503.

    The ``ConnectionError`` raised by the mocked ``get_key_object`` has to
    propagate past the builder's ``except ProxyException`` guard and come
    out as a ``no_db_connection`` 503 rather than the 401 that masked the
    outage. Removing the 503 branch turns this into a 401 and fails here.
    """
    failing_lookup = AsyncMock(side_effect=ConnectionError("connection refused"))

    with pytest.raises(ProxyException) as raised:
        await _run_builder_with_key_lookup(failing_lookup)

    proxy_error = raised.value
    assert int(proxy_error.code) == status.HTTP_503_SERVICE_UNAVAILABLE
    assert proxy_error.type == ProxyErrorTypes.no_db_connection
    assert "Invalid API key" not in str(proxy_error.message)
@pytest.mark.asyncio
async def test_builder_returns_401_when_db_lookup_reports_missing_key():
    """A key the DB does not know must still be rejected with 401.

    The mocked ``get_key_object`` raises the 401 ``ProxyException`` used
    for a key that is not found in the DB; the builder must let that
    status through unchanged.
    """
    not_found = ProxyException(
        message="Authentication Error, Invalid proxy server token passed. key=..., not found in db.",
        type=ProxyErrorTypes.token_not_found_in_db,
        param="key",
        code=status.HTTP_401_UNAUTHORIZED,
    )
    failing_lookup = AsyncMock(side_effect=not_found)

    with pytest.raises(ProxyException) as raised:
        await _run_builder_with_key_lookup(failing_lookup)

    assert int(raised.value.code) == status.HTTP_401_UNAUTHORIZED
@pytest.mark.asyncio
async def test_builder_succeeds_when_db_lookup_returns_valid_token():
    """A valid key still authenticates; the 503 path never intercepts it."""
    expected_token = UserAPIKeyAuth(api_key="sk-db-lookup-test", token="hashed-valid")
    successful_lookup = AsyncMock(return_value=expected_token)

    assemble_patch = patch(
        "litellm.proxy.auth.user_api_key_auth._return_user_api_key_auth_obj",
        new_callable=AsyncMock,
        return_value=expected_token,
    )
    with assemble_patch as assemble_mock:
        auth_result = await _run_builder_with_key_lookup(successful_lookup)

    assert isinstance(auth_result, UserAPIKeyAuth)
    # The success-assembly helper being awaited exactly once shows the
    # builder returned through the normal path, not the exception handler.
    assemble_mock.assert_awaited_once()

View File

@ -107,6 +107,201 @@ def test_is_database_connection_generic_errors():
)
@pytest.mark.parametrize(
    "error",
    [
        ConnectionError("connection refused"),
        TimeoutError("timed out"),
        OSError("network is unreachable"),
        asyncio.TimeoutError(),
        HTTPClientClosedError(),
        ClientNotConnectedError(),
        PrismaError("can't reach database server"),
        PrismaError(),
    ],
)
def test_is_database_service_unavailable_error_infra_failures(error):
    """Socket, connection, timeout, prisma-transport and unknown
    ``PrismaError`` failures all mean the DB could not answer, so each must
    classify as service-unavailable (503) rather than an invalid key."""
    classified = PrismaDBExceptionHandler.is_database_service_unavailable_error(error)
    assert classified is True
def test_is_database_service_unavailable_error_prisma_p1001_masquerades_as_dataerror():
    """P1001 delivered as a ``DataError`` must still classify as unavailable.

    prisma-client-py raises the "can't reach database server" connectivity
    failure as a ``DataError`` (a data-layer type). A type-only check would
    miss it and yield 401 during a real outage, so the message keyword has
    to drive the classification to 503.
    """
    user_facing_error = {
        "message": "Can't reach database server at `127.0.0.1`:`5499`",
        "meta": {"table": "t"},
    }
    p1001_as_dataerror = DataError(data={"user_facing_error": user_facing_error})

    classified = PrismaDBExceptionHandler.is_database_service_unavailable_error(
        p1001_as_dataerror
    )
    assert classified is True
def test_is_database_service_unavailable_error_cached_plan_escapes_as_503():
    """An escaped "cached plan" error classifies as service-unavailable.

    When the cached-plan retry fails and Postgres' "cached plan must not
    change result type" error escapes (as a data-layer ``RawQueryError``),
    it is a transient stale-state condition and not an invalid key, so it
    must map to 503 instead of falling through to 401.
    """
    user_facing_error = {
        "message": "cached plan must not change result type",
        "meta": {"table": "t"},
    }
    cached_plan_error = RawQueryError(data={"user_facing_error": user_facing_error})

    classified = PrismaDBExceptionHandler.is_database_service_unavailable_error(
        cached_plan_error
    )
    assert classified is True
def test_is_database_service_unavailable_error_prisma_engine_malformed_payload():
    """The prisma malformed-payload crash classifies as service-unavailable.

    When the DB socket drops, the query engine can return an error payload
    whose ``user_facing_error.meta`` is ``null``; prisma-client-py's
    ``handle_response_errors`` then raises
    ``AttributeError: 'NoneType' object has no attribute 'get'`` before it
    can raise P1001. That error has no connection keyword, so only the
    prisma-engine-origin check keeps it from becoming a 401. The test
    reproduces the crash with the real prisma function and asserts the
    resulting exception maps to 503.
    """
    from prisma.engine import utils as prisma_engine_utils

    broken_user_facing_error = {
        "error_code": "P1001",
        "message": "Can't reach database server at `localhost`:`5503`",
        "meta": None,
    }
    malformed_payload = [
        {
            "error": "Can't reach database server",
            "user_facing_error": broken_user_facing_error,
        }
    ]

    with pytest.raises(AttributeError) as raised:
        prisma_engine_utils.handle_response_errors(None, malformed_payload)

    teardown_error = raised.value
    assert "no attribute 'get'" in str(teardown_error)
    classified = PrismaDBExceptionHandler.is_database_service_unavailable_error(
        teardown_error
    )
    assert classified is True
def test_is_prisma_engine_internal_error_excludes_application_attributeerror():
    """An ``AttributeError`` from application code is not a prisma-engine error.

    The origin check must stay narrow; otherwise genuine application bugs
    would silently be reported as 503s.
    """

    def application_bug():
        missing = None
        return missing.get("oops")

    with pytest.raises(AttributeError) as raised:
        application_bug()

    app_error = raised.value
    engine_origin = PrismaDBExceptionHandler.is_prisma_engine_internal_error(app_error)
    assert engine_origin is False
    unavailable = PrismaDBExceptionHandler.is_database_service_unavailable_error(
        app_error
    )
    assert unavailable is False
def test_is_prisma_engine_internal_error_excludes_data_layer_prisma_error():
    """A data-layer ``PrismaError`` is never a prisma-engine internal error.

    Such errors mean the DB was reachable and rejected the data, so they
    must stay 401; the check is expected to exclude any ``PrismaError`` by
    type.
    """
    unique_violation = UniqueViolationError(
        data={"user_facing_error": {"meta": {"table": "t"}}}
    )

    # Raise and catch so the exception carries a traceback.
    try:
        raise unique_violation
    except UniqueViolationError as caught:
        engine_origin = PrismaDBExceptionHandler.is_prisma_engine_internal_error(
            caught
        )
        assert engine_origin is False
@pytest.mark.parametrize(
    "error",
    [
        DataError(data={"user_facing_error": {"meta": {"table": "t"}}}),
        UniqueViolationError(data={"user_facing_error": {"meta": {"table": "t"}}}),
        RecordNotFoundError(data={"user_facing_error": {"meta": {"table": "t"}}}),
        Exception("some unrelated error"),
        ValueError("bad value"),
    ],
)
def test_is_database_service_unavailable_error_excludes_non_infra(error):
    """Data-layer errors (the DB answered) and generic non-DB errors must
    not classify as service-unavailable; otherwise a genuine 401 would be
    hidden behind a transient 503."""
    classified = PrismaDBExceptionHandler.is_database_service_unavailable_error(error)
    assert classified is False
def test_is_database_service_unavailable_error_asyncpg(monkeypatch):
    """asyncpg connection/interface errors classify as service-unavailable.

    asyncpg is not a hard dependency, so a stand-in ``asyncpg`` module with
    an ``exceptions`` submodule is placed in ``sys.modules`` for the
    duration of the test, making the branch run the same way whether or not
    the real package is installed.
    """
    import sys
    import types

    class PostgresConnectionError(Exception):
        pass

    class InterfaceError(Exception):
        pass

    class UniqueViolationError(Exception):  # data-layer stand-in; must stay False
        pass

    stub_exceptions = types.ModuleType("asyncpg.exceptions")
    stub_exceptions.PostgresConnectionError = PostgresConnectionError
    stub_exceptions.InterfaceError = InterfaceError
    stub_exceptions.UniqueViolationError = UniqueViolationError

    stub_asyncpg = types.ModuleType("asyncpg")
    stub_asyncpg.exceptions = stub_exceptions

    monkeypatch.setitem(sys.modules, "asyncpg", stub_asyncpg)
    monkeypatch.setitem(sys.modules, "asyncpg.exceptions", stub_exceptions)

    classify = PrismaDBExceptionHandler.is_database_service_unavailable_error

    assert classify(PostgresConnectionError("connection reset")) is True
    assert classify(InterfaceError("connection was closed")) is True
    assert classify(UniqueViolationError("duplicate key")) is False
# Test should_allow_request_on_db_unavailable method
@patch(
"litellm.proxy.proxy_server.general_settings",