fix(proxy): recover from cached-plan errors by reconnecting the Prisma client (#29983)

This commit is contained in:
Yassin Kortam 2026-06-10 16:06:01 -07:00 committed by GitHub
parent 1436ee9092
commit 3bd3951e37
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 119 additions and 32 deletions

View File

@ -3253,40 +3253,49 @@ class PrismaClient:
self, sql_query: str, *args
) -> Optional[dict]:
"""
Execute a query with automatic fallback for PostgreSQL cached plan errors.
Execute a query, recovering once from PostgreSQL's "cached plan must not
change result type" error.
This handles the "cached plan must not change result type" error that occurs
during rolling deployments when schema changes are applied while old pods
still have cached query plans expecting the old schema.
That error surfaces during rolling deployments when a schema change
invalidates the prepared-statement plans that pooled connections still
hold. Clearing only the server-side plans with DEALLOCATE ALL makes
things worse: Prisma's query engine keeps a per-connection client-side
cache of prepared-statement names, so once the server drops a plan the
engine re-sends a name PostgreSQL no longer recognizes and the
connection breaks with `prepared statement "sN" does not exist`. With a
small pool that connection stays poisoned and every auth lookup fails.
Args:
sql_query: SQL query string to execute
Recreating the Prisma client kills the engine subprocess and drops the
server-side plans and the engine's client-side name cache together, so
the retried query is prepared fresh. We reconnect through
`attempt_db_reconnect`, which is singleflight: when a schema change
poisons every pooled connection at once, the first cached-plan error
recreates the client and the concurrent waiters reuse that single
recreate instead of racing to kill each other's fresh engine. We then
retry the identical query exactly once.
Returns:
Query result or None
The retry reuses the original query byte-for-byte. Mutating the SQL
(e.g. injecting a unique comment) would defeat PostgreSQL's plan cache,
forcing a fresh plan on every request and pegging the database CPU.
Raises:
Original exception if not a cached plan error
If the reconnect is skipped because a recent reconnect is still within
its cooldown, the retry runs against the same connection and may fail
again; the get_data backoff decorator re-runs the lookup and a later
attempt reconnects once the cooldown elapses.
"""
try:
return await self.db.query_first(sql_query, *args)
except Exception as e:
error_str = str(e)
if "cached plan must not change result type" in error_str:
# Force PostgreSQL to re-plan by invalidating the cache
# Add a unique comment to make the query different
sql_query_retry = sql_query.replace(
"SELECT",
f"SELECT /* cache_invalidated_{int(time.time() * 1000)} */",
)
verbose_proxy_logger.warning(
"PostgreSQL cached plan error detected for token lookup, "
"retrying with fresh plan. This may occur during rolling deployments "
"when schema changes are applied."
)
return await self.db.query_first(sql_query_retry, *args)
else:
if "cached plan must not change result type" not in str(e):
raise
verbose_proxy_logger.warning(
"PostgreSQL cached plan error detected for token lookup; "
"recreating the database connection and retrying with the same "
"query. This may occur during rolling deployments when schema "
"changes are applied."
)
await self.attempt_db_reconnect(reason="postgres_cached_plan_error")
return await self.db.query_first(sql_query, *args)
@backoff.on_exception(
backoff.expo,

View File

@ -193,6 +193,7 @@ async def test_query_first_with_cached_plan_fallback_happy_returns_row(
) -> None:
expected = {"token": "abc", "team_spend": 1.0, "team_max_budget": 5.0}
prisma_client.db.query_first = AsyncMock(return_value=expected)
prisma_client.attempt_db_reconnect = AsyncMock(return_value=True)
result = await prisma_client._query_first_with_cached_plan_fallback(
"SELECT * FROM x WHERE token = $1", "abc"
)
@ -208,35 +209,110 @@ async def test_query_first_with_cached_plan_fallback_happy_returns_row(
"args": ("SELECT * FROM x WHERE token = $1", "abc"),
"matches": True,
}
prisma_client.attempt_db_reconnect.assert_not_awaited()
@pytest.mark.asyncio
async def test_query_first_with_cached_plan_fallback_reconnects_then_retries_identical_query(
    prisma_client: PrismaClient,
) -> None:
    """A cached-plan error triggers one reconnect, then one unmodified retry.

    The first ``query_first`` call fails with PostgreSQL's "cached plan must
    not change result type" error and the second succeeds. The test pins:

    * the retry receives exactly the same SQL and args as the first attempt,
    * ``attempt_db_reconnect`` is awaited exactly once and is not forced,
    * the calls happen in the order query -> reconnect -> query.
    """
    original_query = 'SELECT * FROM "LiteLLM_VerificationToken" WHERE v.token = $1'
    expected = {"token": "abc", "team_spend": 1.0, "team_max_budget": 5.0}

    query_first = AsyncMock(
        side_effect=[
            RuntimeError("cached plan must not change result type"),
            expected,
        ]
    )
    reconnect = AsyncMock(return_value=True)

    # Attach both mocks to one parent so their calls are recorded on a single
    # timeline (manager.mock_calls), which lets us assert relative ordering.
    manager = MagicMock()
    manager.attach_mock(query_first, "query_first")
    manager.attach_mock(reconnect, "attempt_db_reconnect")

    prisma_client.db.query_first = query_first
    prisma_client.attempt_db_reconnect = reconnect

    result = await prisma_client._query_first_with_cached_plan_fallback(
        original_query, "abc"
    )

    assert result == expected

    # Exactly one retry, and it reuses the original query and args unchanged.
    assert query_first.await_count == 2
    first_call, retry_call = query_first.await_args_list
    assert retry_call.args == first_call.args == (original_query, "abc")

    # One reconnect, without force=True.
    reconnect.assert_awaited_once()
    assert reconnect.await_args.kwargs.get("force", False) is False

    # The reconnect must sit between the failing query and the retry.
    assert [name for name, *_ in manager.mock_calls] == [
        "query_first",
        "attempt_db_reconnect",
        "query_first",
    ]
@pytest.mark.asyncio
async def test_query_first_with_cached_plan_fallback_never_deallocates(
    prisma_client: PrismaClient,
) -> None:
    """Recovering from a cached-plan error must not go through ``execute_raw``.

    ``execute_raw`` is the channel a ``DEALLOCATE ALL`` statement would be
    sent through, so asserting it is never awaited guards against recovery
    falling back to clearing server-side plans.
    """
    expected = {"token": "abc"}
    prisma_client.db.query_first = AsyncMock(
        side_effect=[
            RuntimeError("cached plan must not change result type"),
            expected,
        ]
    )
    prisma_client.db.execute_raw = AsyncMock(return_value=0)
    prisma_client.attempt_db_reconnect = AsyncMock(return_value=True)

    await prisma_client._query_first_with_cached_plan_fallback("SELECT 1")

    prisma_client.db.execute_raw.assert_not_awaited()
@pytest.mark.asyncio
async def test_query_first_with_cached_plan_fallback_propagates_when_retry_also_fails(
    prisma_client: PrismaClient,
) -> None:
    """If the retry hits the cached-plan error again, that error propagates.

    Both the initial query and the single retry fail; the helper must make
    exactly two query attempts with one reconnect between them and then let
    the error reach the caller.
    """
    failure = RuntimeError("cached plan must not change result type")
    query_mock = AsyncMock(side_effect=[failure] * 2)
    reconnect_mock = AsyncMock(return_value=True)
    prisma_client.db.query_first = query_mock
    prisma_client.attempt_db_reconnect = reconnect_mock

    with pytest.raises(RuntimeError, match="cached plan must not change result type"):
        await prisma_client._query_first_with_cached_plan_fallback("SELECT 1")

    # One original attempt plus one retry, and no further looping.
    assert prisma_client.db.query_first.await_count == 2
    prisma_client.attempt_db_reconnect.assert_awaited_once()
@pytest.mark.asyncio
async def test_query_first_with_cached_plan_fallback_retries_when_reconnect_returns_false(
    prisma_client: PrismaClient,
) -> None:
    """The query is still retried when ``attempt_db_reconnect`` returns False.

    A False return means no reconnect was performed; the helper must still
    retry once, and the retry must reuse the original SQL unchanged (no
    injected cache-busting comment).
    """
    expected = {"token": "abc"}
    prisma_client.db.query_first = AsyncMock(
        side_effect=[
            RuntimeError("cached plan must not change result type"),
            expected,
        ]
    )
    prisma_client.attempt_db_reconnect = AsyncMock(return_value=False)

    result = await prisma_client._query_first_with_cached_plan_fallback("SELECT 1")

    assert result == expected
    assert prisma_client.db.query_first.await_count == 2

    # The retry is the same query as the first attempt, not a rewritten one.
    first_call, retry_call = prisma_client.db.query_first.await_args_list
    assert retry_call.args == first_call.args == ("SELECT 1",)
    assert "cache_invalidated_" not in retry_call.args[0]
@pytest.mark.asyncio
async def test_query_first_with_cached_plan_fallback_reraises_non_plan_errors(
    prisma_client: PrismaClient,
) -> None:
    """Errors other than the cached-plan error propagate with no recovery.

    The helper must not retry the query and must not attempt a reconnect
    when the failure message does not contain the cached-plan text.
    """
    prisma_client.db.query_first = AsyncMock(
        side_effect=RuntimeError("totally unrelated")
    )
    prisma_client.attempt_db_reconnect = AsyncMock(return_value=True)

    with pytest.raises(RuntimeError, match="totally unrelated"):
        await prisma_client._query_first_with_cached_plan_fallback("SELECT 1")

    # Single attempt only: no retry, no reconnect.
    assert prisma_client.db.query_first.await_count == 1
    prisma_client.attempt_db_reconnect.assert_not_awaited()
@pytest.mark.asyncio
@ -351,7 +427,9 @@ async def test_get_data_token_find_unique_returns_record(
async def test_get_data_token_find_unique_missing_token_raises_401(
prisma_client: PrismaClient,
) -> None:
prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock(return_value=None)
prisma_client.db.litellm_verificationtoken.find_unique = AsyncMock(
return_value=None
)
with pytest.raises(HTTPException) as excinfo:
await prisma_client.get_data(token="sk-missing", table_name="key")
err = excinfo.value