feat: add read-replica routing for Prisma DB via DATABASE_URL_READ_REPLICA (#27493)

- Introduce RoutingPrismaWrapper that transparently routes read operations (find_*, count, group_by, query_raw, query_first) to a reader endpoint while writes remain on the writer, enabling Aurora-style reader/writer endpoint splits
- Add IAMEndpoint dataclass and parse_iam_endpoint_from_url() to capture static connection fields from a reader URL so only the IAM token needs to rotate, avoiding the need for separate DATABASE_HOST_READ_REPLICA/etc. env vars
- Enhance PrismaWrapper with per-instance knobs (db_url_env_var, iam_endpoint, recreate_uses_datasource, log_prefix) so writer and reader wrappers are independent: the reader writes its fresh URL to DATABASE_URL_READ_REPLICA and passes datasource override to Prisma since Prisma only auto-reads DATABASE_URL
- Fix deadlock in PrismaWrapper.__getattr__: when called from inside a running event loop, schedule the token refresh as a background task instead of blocking with run_coroutine_threadsafe + future.result(), which would deadlock the loop thread waiting for a coroutine that needs the loop to run
- Fix botocore crash when DATABASE_PORT is unset by defaulting to "5432" in both proxy_cli.py and PrismaWrapper.get_rds_iam_token(); passing None caused botocore to embed the literal string "None" in the presigned URL
- Implement graceful reader degradation: reader connect/recreate failures are non-fatal; wrapper sets _reader_unavailable=True and silently routes reads to the writer to keep the proxy serving traffic during transient reader outages
- Add PrismaClient.writer_db property so the reconnect smoke-test always validates the writer engine specifically; query_raw on the routing wrapper would route to the reader and not verify the newly-recreated writer
- Expose DATABASE_URL_READ_REPLICA in Helm chart (values.yaml + deployment.yaml) via both plain value and secret key reference, and document the field in docker-compose.yml
- Add 887-line test suite covering routing logic, IAM token refresh paths, reader degradation scenarios, datasource override behavior, and the deadlock regression

Co-authored-by: Yassin Kortam <yassinkortam@g.ucla.edu>
This commit is contained in:
Yassin Kortam 2026-05-08 21:05:50 -07:00 committed by GitHub
parent 0bcff0214a
commit b5d3a5fc85
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 1423 additions and 77 deletions

View File

@ -100,6 +100,16 @@ spec:
- name: DATABASE_URL - name: DATABASE_URL
value: {{ .Values.db.url | quote }} value: {{ .Values.db.url | quote }}
{{- end }} {{- end }}
{{- if and .Values.db.useExisting .Values.db.secret.readReplicaUrlKey }}
- name: DATABASE_URL_READ_REPLICA
valueFrom:
secretKeyRef:
name: {{ .Values.db.secret.name }}
key: {{ .Values.db.secret.readReplicaUrlKey }}
{{- else if .Values.db.readReplicaUrl }}
- name: DATABASE_URL_READ_REPLICA
value: {{ .Values.db.readReplicaUrl | quote }}
{{- end }}
- name: PROXY_MASTER_KEY - name: PROXY_MASTER_KEY
valueFrom: valueFrom:
secretKeyRef: secretKeyRef:

View File

@ -252,6 +252,26 @@ db:
passwordKey: password passwordKey: password
# Optional: when set, DATABASE_HOST will be sourced from this secret key instead of db.endpoint # Optional: when set, DATABASE_HOST will be sourced from this secret key instead of db.endpoint
endpointKey: "" endpointKey: ""
# Optional: when set, DATABASE_URL_READ_REPLICA will be sourced from this
# secret key instead of db.readReplicaUrl. Prefer this over the plain
# value: read-replica URLs typically embed credentials, and a value
# written to db.readReplicaUrl ends up visible in the rendered pod spec
# and the Helm release secret.
readReplicaUrlKey: ""
# Optional read-replica routing. When set, the proxy sends read-only
# queries (find_*, count, group_by, query_raw/_first) to this URL while
# writes continue to go to db.url. Useful for Aurora-style clusters with
# separate reader/writer endpoints. Leave empty to keep single-DB behavior.
# When IAM_TOKEN_DB_AUTH is enabled, the reader URL is auto-refreshed
# alongside the writer (host/port/user/db are parsed from this URL once
# at startup; only the IAM token rotates).
#
# If the URL embeds credentials, prefer db.secret.readReplicaUrlKey over
# this field — the plain value is rendered into the pod spec and the
# Helm release secret. This field is intended for credential-less URLs
# only (e.g. when IAM_TOKEN_DB_AUTH supplies the token at runtime).
readReplicaUrl: ""
# Use the Stackgres Helm chart to deploy an instance of a Stackgres cluster. # Use the Stackgres Helm chart to deploy an instance of a Stackgres cluster.
# The Stackgres Operator must already be installed within the target # The Stackgres Operator must already be installed within the target

View File

@ -16,6 +16,11 @@ services:
- "4000:4000" # Map the container port to the host, change the host port if necessary - "4000:4000" # Map the container port to the host, change the host port if necessary
environment: environment:
DATABASE_URL: "postgresql://llmproxy:dbpassword9090@db:5432/litellm" DATABASE_URL: "postgresql://llmproxy:dbpassword9090@db:5432/litellm"
# Optional: route read-only queries (find_*, count, group_by, query_raw/_first)
# to a separate reader endpoint, e.g. an Aurora reader. Leave unset for
# single-DB deployments. With IAM_TOKEN_DB_AUTH enabled, the reader URL
# is auto-refreshed alongside the writer.
# DATABASE_URL_READ_REPLICA: "postgresql://llmproxy:dbpassword9090@db-reader:5432/litellm"
STORE_MODEL_IN_DB: "True" # allows adding models to proxy via UI STORE_MODEL_IN_DB: "True" # allows adding models to proxy via UI
env_file: env_file:
- .env # Load local .env file - .env # Load local .env file

View File

@ -10,13 +10,64 @@ import subprocess
import time import time
import urllib import urllib
import urllib.parse import urllib.parse
from dataclasses import dataclass
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Optional, Union from typing import Any, Dict, Optional, Union
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.secret_managers.main import str_to_bool from litellm.secret_managers.main import str_to_bool
@dataclass(frozen=True)
class IAMEndpoint:
"""Static parts of an RDS IAM-authenticated Postgres connection.
The IAM token rotates every ~15 minutes; everything else (host, port, user,
database name, schema) stays fixed. We capture the static fields once so
refresh just regenerates the token and reassembles the URL.
"""
host: str
port: str
user: str
name: str
schema: Optional[str] = None
def build_url(self, token: str) -> str:
url = f"postgresql://{self.user}:{token}@{self.host}:{self.port}/{self.name}"
if self.schema:
url += f"?schema={self.schema}"
return url
def parse_iam_endpoint_from_url(url: str) -> IAMEndpoint:
"""Parse an IAMEndpoint from a Postgres URL.
Used so a reader URL can drive its own IAM refresh without requiring
callers to set parallel DATABASE_HOST_READ_REPLICA / etc. env vars.
"""
parsed = urllib.parse.urlparse(url)
if not parsed.hostname or not parsed.username:
raise ValueError("Cannot parse IAM endpoint from URL: missing host or username")
name = (parsed.path or "/").lstrip("/")
if not name:
raise ValueError("Cannot parse IAM endpoint from URL: missing database name")
port = str(parsed.port) if parsed.port else "5432"
schema: Optional[str] = None
if parsed.query:
qs = urllib.parse.parse_qs(parsed.query)
schema_vals = qs.get("schema")
if schema_vals:
schema = schema_vals[0]
return IAMEndpoint(
host=parsed.hostname,
port=port,
user=parsed.username,
name=name,
schema=schema,
)
class PrismaWrapper: class PrismaWrapper:
""" """
Wrapper around Prisma client that handles RDS IAM token authentication. Wrapper around Prisma client that handles RDS IAM token authentication.
@ -37,10 +88,33 @@ class PrismaWrapper:
# Fallback refresh interval if token parsing fails (10 minutes) # Fallback refresh interval if token parsing fails (10 minutes)
FALLBACK_REFRESH_INTERVAL_SECONDS = 600 FALLBACK_REFRESH_INTERVAL_SECONDS = 600
def __init__(self, original_prisma: Any, iam_token_db_auth: bool): def __init__(
self,
original_prisma: Any,
iam_token_db_auth: bool,
*,
db_url_env_var: str = "DATABASE_URL",
iam_endpoint: Optional[IAMEndpoint] = None,
recreate_uses_datasource: bool = False,
log_prefix: str = "",
):
self._original_prisma = original_prisma self._original_prisma = original_prisma
self.iam_token_db_auth = iam_token_db_auth self.iam_token_db_auth = iam_token_db_auth
# Per-connection knobs so the same wrapper can be used for the writer
# (defaults: DATABASE_URL env, IAM endpoint from DATABASE_HOST/etc.,
# recreate via env reload) or for a reader (DATABASE_URL_READ_REPLICA
# env, IAM endpoint parsed from that URL, recreate via datasource
# override since Prisma only auto-reads DATABASE_URL).
self._db_url_env_var = db_url_env_var
self._iam_endpoint = iam_endpoint
self._recreate_uses_datasource = recreate_uses_datasource
# Tag every log line emitted by this wrapper instance so writer and
# reader can be told apart in interleaved output (e.g. "[writer] RDS
# IAM token refresh scheduled in 720 seconds"). Empty string (default)
# keeps backward-compatible logs for the single-DB case.
self._log_prefix = f"{log_prefix} " if log_prefix else ""
# Background token refresh task management # Background token refresh task management
self._token_refresh_task: Optional[asyncio.Task] = None self._token_refresh_task: Optional[asyncio.Task] = None
self._reconnection_lock = asyncio.Lock() self._reconnection_lock = asyncio.Lock()
@ -157,7 +231,7 @@ class PrismaWrapper:
Returns 0 if token should be refreshed immediately. Returns 0 if token should be refreshed immediately.
Returns FALLBACK_REFRESH_INTERVAL_SECONDS if parsing fails. Returns FALLBACK_REFRESH_INTERVAL_SECONDS if parsing fails.
""" """
db_url = os.getenv("DATABASE_URL") db_url = os.getenv(self._db_url_env_var)
token = self._extract_token_from_db_url(db_url) token = self._extract_token_from_db_url(db_url)
expiration_time = self._parse_token_expiration(token) expiration_time = self._parse_token_expiration(token)
@ -199,12 +273,30 @@ class PrismaWrapper:
return datetime.utcnow() > expiration_time return datetime.utcnow() > expiration_time
def get_rds_iam_token(self) -> Optional[str]: def get_rds_iam_token(self) -> Optional[str]:
"""Generate a new RDS IAM token and update DATABASE_URL.""" """Generate a new RDS IAM token and update the configured DB URL env var.
if self.iam_token_db_auth:
from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token
When the wrapper was constructed with an explicit `iam_endpoint`
(typical for a reader wrapper whose host/port/user came from a parsed
URL), use that. Otherwise fall back to the legacy DATABASE_HOST/PORT/
USER/NAME/SCHEMA env vars (writer behavior).
"""
if not self.iam_token_db_auth:
return None
from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token
if self._iam_endpoint is not None:
endpoint = self._iam_endpoint
token = generate_iam_auth_token(
db_host=endpoint.host, db_port=endpoint.port, db_user=endpoint.user
)
_db_url = endpoint.build_url(token)
else:
db_host = os.getenv("DATABASE_HOST") db_host = os.getenv("DATABASE_HOST")
db_port = os.getenv("DATABASE_PORT") # Default to the Postgres standard port; passing None to
# `generate_iam_auth_token` makes botocore embed the literal
# string "None" in the presigned URL, which then fails to parse.
db_port = os.getenv("DATABASE_PORT", "5432")
db_user = os.getenv("DATABASE_USER") db_user = os.getenv("DATABASE_USER")
db_name = os.getenv("DATABASE_NAME") db_name = os.getenv("DATABASE_NAME")
db_schema = os.getenv("DATABASE_SCHEMA") db_schema = os.getenv("DATABASE_SCHEMA")
@ -217,9 +309,8 @@ class PrismaWrapper:
if db_schema: if db_schema:
_db_url += f"?schema={db_schema}" _db_url += f"?schema={db_schema}"
os.environ["DATABASE_URL"] = _db_url os.environ[self._db_url_env_var] = _db_url
return _db_url return _db_url
return None
async def recreate_prisma_client( async def recreate_prisma_client(
self, new_db_url: str, http_client: Optional[Any] = None self, new_db_url: str, http_client: Optional[Any] = None
@ -231,6 +322,11 @@ class PrismaWrapper:
synchronous `subprocess.Popen.wait()` that can freeze the asyncio event synchronous `subprocess.Popen.wait()` that can freeze the asyncio event
loop for 30-120+ seconds when the engine is stuck on TCP close, loop for 30-120+ seconds when the engine is stuck on TCP close,
breaking `/health/liveliness` and causing Kubernetes pod restarts. breaking `/health/liveliness` and causing Kubernetes pod restarts.
The writer wrapper relies on Prisma re-reading `DATABASE_URL` from env;
the reader wrapper opts into `recreate_uses_datasource=True` so the
new URL is passed explicitly via `datasource={"url": ...}` (Prisma
does not auto-read alternate env vars like DATABASE_URL_READ_REPLICA).
""" """
from prisma import Prisma # type: ignore from prisma import Prisma # type: ignore
@ -238,10 +334,12 @@ class PrismaWrapper:
if old_engine_pid > 0: if old_engine_pid > 0:
await self._kill_engine_process(old_engine_pid) await self._kill_engine_process(old_engine_pid)
kwargs: Dict[str, Any] = {}
if http_client is not None: if http_client is not None:
self._original_prisma = Prisma(http=http_client) kwargs["http"] = http_client
else: if self._recreate_uses_datasource:
self._original_prisma = Prisma() kwargs["datasource"] = {"url": new_db_url}
self._original_prisma = Prisma(**kwargs)
await self._original_prisma.connect() await self._original_prisma.connect()
@ -265,7 +363,8 @@ class PrismaWrapper:
self._token_refresh_task = asyncio.create_task(self._token_refresh_loop()) self._token_refresh_task = asyncio.create_task(self._token_refresh_loop())
verbose_proxy_logger.info( verbose_proxy_logger.info(
"Started RDS IAM token proactive refresh background task" "%sStarted RDS IAM token proactive refresh background task",
self._log_prefix,
) )
async def stop_token_refresh_task(self) -> None: async def stop_token_refresh_task(self) -> None:
@ -283,7 +382,9 @@ class PrismaWrapper:
except asyncio.CancelledError: except asyncio.CancelledError:
pass pass
self._token_refresh_task = None self._token_refresh_task = None
verbose_proxy_logger.info("Stopped RDS IAM token refresh background task") verbose_proxy_logger.info(
"%sStopped RDS IAM token refresh background task", self._log_prefix
)
async def _token_refresh_loop(self) -> None: async def _token_refresh_loop(self) -> None:
""" """
@ -294,7 +395,7 @@ class PrismaWrapper:
This is more efficient than polling, requiring only 1 wake-up per token cycle. This is more efficient than polling, requiring only 1 wake-up per token cycle.
""" """
verbose_proxy_logger.info( verbose_proxy_logger.info(
f"RDS IAM token refresh loop started. " f"{self._log_prefix}RDS IAM token refresh loop started. "
f"Tokens will be refreshed {self.TOKEN_REFRESH_BUFFER_SECONDS}s before expiration." f"Tokens will be refreshed {self.TOKEN_REFRESH_BUFFER_SECONDS}s before expiration."
) )
@ -305,21 +406,25 @@ class PrismaWrapper:
if sleep_seconds > 0: if sleep_seconds > 0:
verbose_proxy_logger.info( verbose_proxy_logger.info(
f"RDS IAM token refresh scheduled in {sleep_seconds:.0f} seconds " f"{self._log_prefix}RDS IAM token refresh scheduled in "
f"({sleep_seconds / 60:.1f} minutes)" f"{sleep_seconds:.0f} seconds ({sleep_seconds / 60:.1f} minutes)"
) )
await asyncio.sleep(sleep_seconds) await asyncio.sleep(sleep_seconds)
# Refresh the token # Refresh the token
verbose_proxy_logger.info("Proactively refreshing RDS IAM token...") verbose_proxy_logger.info(
"%sProactively refreshing RDS IAM token...", self._log_prefix
)
await self._safe_refresh_token() await self._safe_refresh_token()
except asyncio.CancelledError: except asyncio.CancelledError:
verbose_proxy_logger.info("RDS IAM token refresh loop cancelled") verbose_proxy_logger.info(
"%sRDS IAM token refresh loop cancelled", self._log_prefix
)
break break
except Exception as e: except Exception as e:
verbose_proxy_logger.error( verbose_proxy_logger.error(
f"Error in RDS IAM token refresh loop: {e}. " f"{self._log_prefix}Error in RDS IAM token refresh loop: {e}. "
f"Retrying in {self.FALLBACK_REFRESH_INTERVAL_SECONDS}s..." f"Retrying in {self.FALLBACK_REFRESH_INTERVAL_SECONDS}s..."
) )
# On error, wait before retrying to avoid tight error loops # On error, wait before retrying to avoid tight error loops
@ -341,65 +446,75 @@ class PrismaWrapper:
await self.recreate_prisma_client(new_db_url) await self.recreate_prisma_client(new_db_url)
self._last_refresh_time = datetime.utcnow() self._last_refresh_time = datetime.utcnow()
verbose_proxy_logger.info( verbose_proxy_logger.info(
"RDS IAM token refreshed successfully. New token valid for ~15 minutes." "%sRDS IAM token refreshed successfully. New token valid for ~15 minutes.",
self._log_prefix,
) )
else: else:
verbose_proxy_logger.error( verbose_proxy_logger.error(
"Failed to generate new RDS IAM token during proactive refresh" "%sFailed to generate new RDS IAM token during proactive refresh",
self._log_prefix,
) )
def __getattr__(self, name: str): def __getattr__(self, name: str):
""" """
Proxy attribute access to the underlying Prisma client. Proxy attribute access to the underlying Prisma client.
If IAM token auth is enabled and the token is expired, this method If IAM token auth is enabled and the token is found expired here, the
provides a synchronous fallback to refresh the token. However, this proactive refresh task has missed its window. Behavior depends on
should rarely be needed since the background task proactively refreshes whether we're called from inside a running event loop:
tokens before they expire.
FIXED: Now properly waits for reconnection to complete before returning, - Inside the loop (typical: from a coroutine): schedule a refresh as a
instead of the previous fire-and-forget pattern that caused the bug. background task and return the (stale) attribute. The caller's await
will likely fail with a connection error and be retried by upper
layers (`call_with_db_reconnect_retry`); by that time the refresh
has either completed or escalated to the proactive loop's error
path. We CANNOT block here `run_coroutine_threadsafe(...)` +
`future.result()` from inside the same loop deadlocks the loop
(loop thread is blocked, scheduled coroutine never runs, 30s timeout).
- No running loop (sync caller, mostly tests): run the refresh in a
fresh loop and re-fetch the attribute.
""" """
original_attr = getattr(self._original_prisma, name) original_attr = getattr(self._original_prisma, name)
if self.iam_token_db_auth: if self.iam_token_db_auth:
db_url = os.getenv("DATABASE_URL") db_url = os.getenv(self._db_url_env_var)
# Check if token is expired (should be rare if background task is running) # Check if token is expired (should be rare if background task is running)
if self.is_token_expired(db_url): if self.is_token_expired(db_url):
verbose_proxy_logger.warning( try:
"RDS IAM token expired in __getattr__ - proactive refresh may have failed. " running_loop = asyncio.get_running_loop()
"Triggering synchronous fallback refresh..." except RuntimeError:
) running_loop = None
new_db_url = self.get_rds_iam_token() if running_loop is not None:
if new_db_url: verbose_proxy_logger.warning(
loop = asyncio.get_event_loop() "%sRDS IAM token expired in __getattr__ — proactive refresh "
"may have failed. Scheduling async refresh; the current "
if loop.is_running(): "request may fail and be retried with the fresh token.",
# FIXED: Actually wait for the reconnection to complete! self._log_prefix,
# The previous code used fire-and-forget which caused the bug. )
future = asyncio.run_coroutine_threadsafe( # Non-blocking: schedule the locked refresh on the
self.recreate_prisma_client(new_db_url), loop # running loop. The reconnection lock inside
) # `_safe_refresh_token` coalesces concurrent triggers.
try: running_loop.create_task(self._safe_refresh_token())
# Wait up to 30 seconds for reconnection
future.result(timeout=30)
verbose_proxy_logger.info(
"Synchronous token refresh completed successfully"
)
except Exception as e:
verbose_proxy_logger.error(
f"Failed to refresh token synchronously: {e}"
)
raise
else:
asyncio.run(self.recreate_prisma_client(new_db_url))
# Get the NEW attribute after reconnection
original_attr = getattr(self._original_prisma, name)
else: else:
raise ValueError("Failed to get RDS IAM token") verbose_proxy_logger.warning(
"%sRDS IAM token expired in __getattr__ — proactive refresh "
"may have failed. Triggering synchronous fallback refresh...",
self._log_prefix,
)
new_db_url = self.get_rds_iam_token()
if new_db_url:
asyncio.run(self.recreate_prisma_client(new_db_url))
# Re-fetch attribute against the recreated Prisma instance.
original_attr = getattr(self._original_prisma, name)
verbose_proxy_logger.info(
"%sSynchronous token refresh completed successfully",
self._log_prefix,
)
else:
raise ValueError("Failed to get RDS IAM token")
return original_attr return original_attr

View File

@ -0,0 +1,213 @@
"""
RoutingPrismaWrapper: routes Prisma reads to a read-replica client and writes
to a writer client. Used when DATABASE_URL_READ_REPLICA is configured;
otherwise PrismaClient uses the writer-only PrismaWrapper directly.
"""
import os
from typing import Any, Callable, Optional
from litellm._logging import verbose_proxy_logger
from litellm.proxy.db.prisma_client import PrismaWrapper
# Per-model action methods that read from the database. These are routed to
# the read replica when one is configured.
_MODEL_READ_METHODS = frozenset(
{
"find_first",
"find_first_or_raise",
"find_many",
"find_unique",
"find_unique_or_raise",
"count",
"group_by",
"query_first",
"query_raw",
}
)
# Top-level Prisma client methods that read from the database.
_TOP_LEVEL_READ_METHODS = frozenset({"query_first", "query_raw"})
class _RoutedActions:
"""Per-model accessor that sends reads to the reader and writes to the writer.
`should_use_reader` is consulted on every read dispatch so a mid-call flip
of the routing wrapper's reader-availability flag (e.g. after the reader
fails a recreate) is observed without re-fetching the actions accessor.
"""
__slots__ = ("_writer_actions", "_reader_actions", "_should_use_reader")
def __init__(
self,
writer_actions: Any,
reader_actions: Any,
should_use_reader: Callable[[], bool],
):
self._writer_actions = writer_actions
self._reader_actions = reader_actions
self._should_use_reader = should_use_reader
def __getattr__(self, name: str) -> Any:
if name in _MODEL_READ_METHODS and self._should_use_reader():
return getattr(self._reader_actions, name)
return getattr(self._writer_actions, name)
class RoutingPrismaWrapper:
"""
Routes Prisma operations between a writer and a reader Prisma client.
Reads (find_*, count, group_by, query_raw, query_first) go to the reader;
everything else (writes, transactions, raw execute) goes to the writer.
Lifecycle methods (connect, disconnect, IAM token refresh) act on both
clients so callers do not need to know about the split. When
IAM_TOKEN_DB_AUTH is enabled, both writer and reader refresh their tokens
independently on their own ~12-minute cadence.
Reader degradation: a reader-side failure (failed connect, failed
recreate) is non-fatal the wrapper sets `_reader_unavailable=True`, logs
a warning, and routes subsequent reads to the writer. The next successful
`connect()` or `recreate_prisma_client()` clears the flag. This keeps the
proxy serving traffic during transient reader outages instead of failing
startup or returning errors for read-heavy endpoints.
"""
def __init__(self, writer: PrismaWrapper, reader: PrismaWrapper):
self._writer = writer
self._reader = reader
# When True, reads fall back to the writer. Flipped on by reader
# connect/recreate failures and flipped off on the next reader recovery.
self._reader_unavailable: bool = False
@property
def writer(self) -> PrismaWrapper:
return self._writer
@property
def reader(self) -> PrismaWrapper:
return self._reader
@property
def reader_unavailable(self) -> bool:
return self._reader_unavailable
def _should_use_reader(self) -> bool:
return not self._reader_unavailable
async def connect(self, *args: Any, **kwargs: Any) -> None:
await self._writer.connect(*args, **kwargs)
verbose_proxy_logger.info("[writer] DB connected")
try:
await self._reader.connect(*args, **kwargs)
self._reader_unavailable = False
verbose_proxy_logger.info("[reader] DB connected")
except Exception as e:
# Degrade gracefully: the proxy keeps serving traffic with reads
# routed to the writer until the reader endpoint is reachable.
# Aborting startup here would tie proxy availability to an
# opt-in, best-effort reader endpoint.
self._reader_unavailable = True
verbose_proxy_logger.warning(
"Failed to connect to read replica DB: %s. "
"Falling back to the writer for reads until the reader is reachable.",
e,
)
async def disconnect(self, *args: Any, **kwargs: Any) -> None:
first_error: Optional[BaseException] = None
for client in (self._writer, self._reader):
try:
await client.disconnect(*args, **kwargs)
except Exception as e:
if first_error is None:
first_error = e
verbose_proxy_logger.warning("Error disconnecting Prisma client: %s", e)
if first_error is not None:
raise first_error
def is_connected(self) -> bool:
# Reflects writer health only. The reader is best-effort; its
# availability is tracked via `_reader_unavailable` and a degraded
# reader must NOT cause a writer reconnect (would loop indefinitely
# since recreate_prisma_client only fixes writer-side problems).
return bool(self._writer.is_connected())
async def start_token_refresh_task(self) -> None:
await self._writer.start_token_refresh_task()
await self._reader.start_token_refresh_task()
async def stop_token_refresh_task(self) -> None:
await self._writer.stop_token_refresh_task()
await self._reader.stop_token_refresh_task()
async def recreate_prisma_client(
self, new_db_url: str, http_client: Optional[Any] = None
) -> None:
"""Recreate both writer and reader Prisma clients.
The writer reconnect path in PrismaClient calls
`self.db.recreate_prisma_client(...)`. Without this method, a DB-wide
connectivity event would only re-create the writer; the reader engine
would stay broken and every routed read would fail. We always recreate
the writer first (its URL is the one passed in), then best-effort
recreate the reader. A reader failure flips `_reader_unavailable=True`
so reads transparently fall through to the writer.
"""
await self._writer.recreate_prisma_client(new_db_url, http_client=http_client)
try:
await self._recreate_reader(http_client=http_client)
self._reader_unavailable = False
except Exception as e:
self._reader_unavailable = True
verbose_proxy_logger.warning(
"Failed to recreate reader Prisma client: %s. "
"Reads will fall back to the writer until the reader recovers.",
e,
)
async def _recreate_reader(self, http_client: Optional[Any] = None) -> None:
"""Resolve the reader URL and recreate its Prisma client.
IAM-enabled readers regenerate their token (host/port/user came from
the parsed reader URL at construction time). Non-IAM readers reuse
the URL stored in `DATABASE_URL_READ_REPLICA`.
"""
if self._reader.iam_token_db_auth:
new_reader_url = self._reader.get_rds_iam_token()
if not new_reader_url:
raise RuntimeError(
"Failed to generate fresh IAM token for read replica"
)
await self._reader.recreate_prisma_client(
new_reader_url, http_client=http_client
)
return
reader_url = os.getenv("DATABASE_URL_READ_REPLICA", "")
if not reader_url:
raise RuntimeError(
"DATABASE_URL_READ_REPLICA not set; cannot recreate read replica client"
)
await self._reader.recreate_prisma_client(reader_url, http_client=http_client)
def __getattr__(self, name: str) -> Any:
if name in _TOP_LEVEL_READ_METHODS:
target = self._writer if self._reader_unavailable else self._reader
return getattr(target, name)
writer_attr = getattr(self._writer, name)
# Per-model action accessors are non-callable instances that expose
# both `find_many` and `create`. Methods like execute_raw / batch_ /
# tx are callables and stay on the writer untouched.
if (
not callable(writer_attr)
and hasattr(writer_attr, "find_many")
and hasattr(writer_attr, "create")
):
try:
reader_attr = getattr(self._reader, name)
except AttributeError:
return writer_attr
return _RoutedActions(writer_attr, reader_attr, self._should_use_reader)
return writer_attr

View File

@ -813,7 +813,12 @@ def run_server( # noqa: PLR0915
from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token
db_host = os.getenv("DATABASE_HOST") db_host = os.getenv("DATABASE_HOST")
db_port = os.getenv("DATABASE_PORT") # Default to the Postgres standard port. Without a default,
# `db_port=None` flows into `boto.generate_db_auth_token(Port=None)`
# and botocore stringifies it to `"None"` while building the
# presigned URL, which then blows up with `ValueError: Port could
# not be cast to integer value as 'None'` during signing.
db_port = os.getenv("DATABASE_PORT", "5432")
db_user = os.getenv("DATABASE_USER") db_user = os.getenv("DATABASE_USER")
db_name = os.getenv("DATABASE_NAME") db_name = os.getenv("DATABASE_NAME")
db_schema = os.getenv("DATABASE_SCHEMA") db_schema = os.getenv("DATABASE_SCHEMA")

View File

@ -113,7 +113,11 @@ from litellm.proxy.db.exception_handler import (
call_with_db_reconnect_retry, call_with_db_reconnect_retry,
) )
from litellm.proxy.db.log_db_metrics import log_db_metrics from litellm.proxy.db.log_db_metrics import log_db_metrics
from litellm.proxy.db.prisma_client import PrismaWrapper from litellm.proxy.db.prisma_client import (
PrismaWrapper,
parse_iam_endpoint_from_url,
)
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import ( from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import (
UnifiedLLMGuardrails, UnifiedLLMGuardrails,
) )
@ -2569,24 +2573,101 @@ class PrismaClient:
raise Exception( raise Exception(
"Unable to find Prisma binaries. Please run 'prisma generate' first." "Unable to find Prisma binaries. Please run 'prisma generate' first."
) )
iam_flag = (
self.iam_token_db_auth if self.iam_token_db_auth is not None else False
)
# When read-replica routing is on, tag log lines with [writer]/[reader]
# so the two wrappers' interleaved IAM refresh logs can be told apart.
# Single-DB deployments get an empty prefix (logs unchanged).
read_replica_url = os.getenv("DATABASE_URL_READ_REPLICA")
writer_log_prefix = "[writer]" if read_replica_url else ""
if http_client is not None: if http_client is not None:
self.db = PrismaWrapper( writer_wrapper = PrismaWrapper(
original_prisma=Prisma(http=http_client), original_prisma=Prisma(http=http_client),
iam_token_db_auth=( iam_token_db_auth=iam_flag,
self.iam_token_db_auth log_prefix=writer_log_prefix,
if self.iam_token_db_auth is not None
else False
),
) )
else: else:
self.db = PrismaWrapper( writer_wrapper = PrismaWrapper(
original_prisma=Prisma(), original_prisma=Prisma(),
iam_token_db_auth=( iam_token_db_auth=iam_flag,
self.iam_token_db_auth log_prefix=writer_log_prefix,
if self.iam_token_db_auth is not None )
else False
), # Optional read-replica routing. When DATABASE_URL_READ_REPLICA is set,
) # Client to connect to Prisma db # reads (find_*, count, group_by, query_raw/_first) are routed to the
# reader endpoint and writes stay on the writer. Falls back to the
# writer-only wrapper when the env var is unset, preserving existing
# single-DB deployments.
self.db: Union[PrismaWrapper, RoutingPrismaWrapper]
if read_replica_url:
try:
# If IAM auth is enabled, the reader refreshes its own token on
# the same cadence as the writer. We parse the static endpoint
# pieces (host/port/user/db) once from the reader URL — only
# the IAM token rotates after that.
reader_iam_endpoint = (
parse_iam_endpoint_from_url(read_replica_url) if iam_flag else None
)
# Mint a fresh IAM token for the reader BEFORE constructing the
# Prisma client. Mirrors what `proxy_cli.py` already does for
# the writer (proxy_cli.py:812-832) — without this, the reader
# Prisma is built with whatever placeholder URL the user
# supplied (no real token), and the first query falls through
# to the synchronous fallback path in
# `PrismaWrapper.__getattr__`, which deadlocks the event loop
# and times out after 30s.
if iam_flag and reader_iam_endpoint is not None:
from litellm.proxy.auth.rds_iam_token import (
generate_iam_auth_token,
)
reader_token = generate_iam_auth_token(
db_host=reader_iam_endpoint.host,
db_port=reader_iam_endpoint.port,
db_user=reader_iam_endpoint.user,
)
read_replica_url = reader_iam_endpoint.build_url(reader_token)
os.environ["DATABASE_URL_READ_REPLICA"] = read_replica_url
reader_kwargs: Dict[str, Any] = {
"datasource": {"url": read_replica_url}
}
if http_client is not None:
reader_prisma = Prisma(http=http_client, **reader_kwargs)
else:
reader_prisma = Prisma(**reader_kwargs)
reader_wrapper = PrismaWrapper(
original_prisma=reader_prisma,
iam_token_db_auth=iam_flag,
db_url_env_var="DATABASE_URL_READ_REPLICA",
iam_endpoint=reader_iam_endpoint,
recreate_uses_datasource=True,
log_prefix="[reader]",
)
self.db = RoutingPrismaWrapper(
writer=writer_wrapper, reader=reader_wrapper
)
verbose_proxy_logger.info(
"PrismaClient: read-replica routing enabled via DATABASE_URL_READ_REPLICA"
+ (" (with IAM token auto-refresh)" if iam_flag else "")
)
except Exception as e:
# Reader is opt-in; never let its construction fail proxy
# startup. Mirrors the runtime contract from
# `RoutingPrismaWrapper.connect`: reader-side failures are
# logged and we keep serving traffic via the writer alone.
# This recovers from transient AWS STS hiccups during the
# reader IAM token mint, malformed DATABASE_URL_READ_REPLICA,
# and Prisma construction errors. Operator restart is required
# to retry read-routing once the underlying issue is resolved.
verbose_proxy_logger.warning(
"Failed to initialize read replica Prisma client: %s. "
"Falling back to writer-only mode (no read routing) until proxy restart.",
e,
)
self.db = writer_wrapper
else:
self.db = writer_wrapper # Client to connect to Prisma db
self._db_reconnect_lock = asyncio.Lock() self._db_reconnect_lock = asyncio.Lock()
self._db_health_watchdog_task: Optional[asyncio.Task] = None self._db_health_watchdog_task: Optional[asyncio.Task] = None
self._db_last_reconnect_attempt_ts: float = 0.0 self._db_last_reconnect_attempt_ts: float = 0.0
@ -2624,6 +2705,13 @@ class PrismaClient:
self._engine_wait_thread: Optional[threading.Thread] = None self._engine_wait_thread: Optional[threading.Thread] = None
verbose_proxy_logger.debug("Success - Created Prisma Client") verbose_proxy_logger.debug("Success - Created Prisma Client")
@property
def writer_db(self) -> PrismaWrapper:
"""Underlying writer Prisma wrapper, regardless of read-replica routing."""
if isinstance(self.db, RoutingPrismaWrapper):
return self.db.writer
return self.db
def get_request_status( def get_request_status(
self, payload: Union[dict, SpendLogsPayload] self, payload: Union[dict, SpendLogsPayload]
) -> Literal["success", "failure"]: ) -> Literal["success", "failure"]:
@ -4272,7 +4360,10 @@ class PrismaClient:
self._cleanup_engine_watcher() self._cleanup_engine_watcher()
await self.db.recreate_prisma_client(db_url) await self.db.recreate_prisma_client(db_url)
await self._start_engine_watcher() await self._start_engine_watcher()
await self.db.query_raw("SELECT 1") # Smoke-test the writer specifically; query_raw on the routing
# wrapper sends to the reader, which would not validate the
# newly-recreated writer engine.
await self.writer_db.query_raw("SELECT 1")
await asyncio.wait_for(_do_direct_reconnect(), timeout=effective_timeout) await asyncio.wait_for(_do_direct_reconnect(), timeout=effective_timeout)

View File

@ -0,0 +1,887 @@
import asyncio
import logging
import os
import sys
from typing import Any, Dict
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
sys.path.insert(0, os.path.abspath("../../../.."))
# NOTE: do NOT patch sys.modules["prisma"] file-wide via an autouse fixture.
# Doing so leaks across pytest-xdist test scheduling: when a worker runs a
# routing test, then later runs test_exception_handler.py, the cached MagicMock
# attribute references break `isinstance(e, prisma.errors.X)` in
# `is_database_transport_error`. The two tests below that actually need to
# stub the prisma SDK do so per-test via monkeypatch, which is properly scoped.
def _make_wrappers():
from litellm.proxy.db.prisma_client import PrismaWrapper
writer_inner = MagicMock(name="writer_prisma")
reader_inner = MagicMock(name="reader_prisma")
writer = PrismaWrapper(original_prisma=writer_inner, iam_token_db_auth=False)
reader = PrismaWrapper(original_prisma=reader_inner, iam_token_db_auth=False)
return writer, writer_inner, reader, reader_inner
class _FakeActions:
"""Stand-in for a Prisma per-model Actions instance (non-callable, has find_many/create)."""
def __init__(self, name: str):
self._name = name
for method in (
"find_many",
"find_unique",
"find_first",
"count",
"group_by",
"create",
"update",
"upsert",
"delete",
"delete_many",
"update_many",
):
setattr(self, method, MagicMock(name=f"{name}.{method}"))
def _model_actions_mock(name: str) -> _FakeActions:
return _FakeActions(name)
def test_top_level_query_raw_routes_to_reader():
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer, writer_inner, reader, reader_inner = _make_wrappers()
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
# query_raw should resolve to the reader's underlying client.
assert routing.query_raw is reader_inner.query_raw
assert routing.query_first is reader_inner.query_first
def test_top_level_execute_raw_routes_to_writer():
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer, writer_inner, reader, reader_inner = _make_wrappers()
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
# execute_raw, batch_, tx are write-side and must hit the writer.
assert routing.execute_raw is writer_inner.execute_raw
assert routing.batch_ is writer_inner.batch_
assert routing.tx is writer_inner.tx
def test_per_model_reads_route_to_reader_writes_to_writer():
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer, writer_inner, reader, reader_inner = _make_wrappers()
writer_inner.litellm_usertable = _model_actions_mock("writer_users")
reader_inner.litellm_usertable = _model_actions_mock("reader_users")
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
actions = routing.litellm_usertable
# Reads → reader actions.
assert actions.find_many is reader_inner.litellm_usertable.find_many
assert actions.find_unique is reader_inner.litellm_usertable.find_unique
assert actions.find_first is reader_inner.litellm_usertable.find_first
assert actions.count is reader_inner.litellm_usertable.count
assert actions.group_by is reader_inner.litellm_usertable.group_by
# Writes → writer actions.
assert actions.create is writer_inner.litellm_usertable.create
assert actions.update is writer_inner.litellm_usertable.update
assert actions.upsert is writer_inner.litellm_usertable.upsert
assert actions.delete is writer_inner.litellm_usertable.delete
assert actions.update_many is writer_inner.litellm_usertable.update_many
assert actions.delete_many is writer_inner.litellm_usertable.delete_many
@pytest.mark.asyncio
async def test_connect_invokes_both_clients():
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer, writer_inner, reader, reader_inner = _make_wrappers()
writer_inner.connect = AsyncMock()
reader_inner.connect = AsyncMock()
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
await routing.connect()
writer_inner.connect.assert_awaited_once()
reader_inner.connect.assert_awaited_once()
@pytest.mark.asyncio
async def test_connect_logs_writer_and_reader_success(caplog):
"""Successful startup emits a positive INFO confirmation for both writer
and reader so operators can verify connectivity without inspecting the URL
in logs."""
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer, writer_inner, reader, reader_inner = _make_wrappers()
writer_inner.connect = AsyncMock()
reader_inner.connect = AsyncMock()
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
with caplog.at_level(logging.INFO, logger="LiteLLM Proxy"):
await routing.connect()
messages = [r.getMessage() for r in caplog.records]
assert "[writer] DB connected" in messages
assert "[reader] DB connected" in messages
@pytest.mark.asyncio
async def test_disconnect_continues_when_one_side_fails():
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer, writer_inner, reader, reader_inner = _make_wrappers()
writer_inner.disconnect = AsyncMock(side_effect=RuntimeError("writer down"))
reader_inner.disconnect = AsyncMock()
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
with pytest.raises(RuntimeError, match="writer down"):
await routing.disconnect()
# Reader still attempted even though writer raised.
reader_inner.disconnect.assert_awaited_once()
def test_is_connected_reflects_writer_only():
"""is_connected() must NOT depend on reader health — a healthy writer with
a degraded reader should report True so that PrismaClient.connect()'s
health check does not re-trigger a writer reconnect (which only fixes
writer-side problems and would loop indefinitely)."""
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer, writer_inner, reader, reader_inner = _make_wrappers()
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
writer_inner.is_connected = MagicMock(return_value=True)
reader_inner.is_connected = MagicMock(return_value=True)
assert routing.is_connected() is True
# Reader down → still True (reader degradation is tracked separately).
reader_inner.is_connected = MagicMock(return_value=False)
assert routing.is_connected() is True
# Writer down → False.
writer_inner.is_connected = MagicMock(return_value=False)
assert routing.is_connected() is False
def test_token_refresh_delegates_to_both_writer_and_reader():
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer = MagicMock()
writer.start_token_refresh_task = AsyncMock()
writer.stop_token_refresh_task = AsyncMock()
reader = MagicMock()
reader.start_token_refresh_task = AsyncMock()
reader.stop_token_refresh_task = AsyncMock()
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
asyncio.run(routing.start_token_refresh_task())
asyncio.run(routing.stop_token_refresh_task())
# Both wrappers get start/stop — each manages its own IAM token. When
# IAM is disabled on a wrapper its task body is a no-op.
writer.start_token_refresh_task.assert_awaited_once()
writer.stop_token_refresh_task.assert_awaited_once()
reader.start_token_refresh_task.assert_awaited_once()
reader.stop_token_refresh_task.assert_awaited_once()
def test_routed_actions_falls_back_to_writer_for_unknown_methods():
from litellm.proxy.db.routing_prisma_wrapper import _RoutedActions
writer_actions = _model_actions_mock("writer")
writer_actions.some_custom_method = "writer-custom"
reader_actions = _model_actions_mock("reader")
reader_actions.some_custom_method = "reader-custom"
routed = _RoutedActions(writer_actions, reader_actions, lambda: True)
# Unknown method → defaults to writer (safe fallback for write-like ops).
assert routed.some_custom_method == "writer-custom"
def test_routed_actions_respects_should_use_reader_flag():
"""When the routing wrapper marks the reader unavailable, _RoutedActions
must redirect reads to the writer instead without needing to re-fetch
the actions accessor."""
from litellm.proxy.db.routing_prisma_wrapper import _RoutedActions
writer_actions = _model_actions_mock("writer")
reader_actions = _model_actions_mock("reader")
use_reader = {"value": True}
routed = _RoutedActions(writer_actions, reader_actions, lambda: use_reader["value"])
# Reader healthy → reads to reader.
assert routed.find_many is reader_actions.find_many
# Reader degrades mid-flight → next read goes to writer.
use_reader["value"] = False
assert routed.find_many is writer_actions.find_many
# ---------------------------------------------------------------------------
# Reader graceful degradation
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_connect_swallows_reader_failure_and_falls_back_to_writer():
"""A reader connect failure must NOT abort proxy startup. The wrapper
flips into degraded mode so subsequent reads route to the writer."""
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer, writer_inner, reader, reader_inner = _make_wrappers()
writer_inner.connect = AsyncMock()
reader_inner.connect = AsyncMock(side_effect=RuntimeError("reader unreachable"))
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
# Must not raise — reader failure is non-fatal.
await routing.connect()
assert routing.reader_unavailable is True
writer_inner.connect.assert_awaited_once()
reader_inner.connect.assert_awaited_once()
@pytest.mark.asyncio
async def test_reads_route_to_writer_when_reader_unavailable():
"""Top-level read methods and per-model reads must fall through to the
writer while the reader is degraded."""
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer, writer_inner, reader, reader_inner = _make_wrappers()
writer_inner.litellm_usertable = _model_actions_mock("writer_users")
reader_inner.litellm_usertable = _model_actions_mock("reader_users")
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
routing._reader_unavailable = True
# Top-level reads → writer.
assert routing.query_raw is writer_inner.query_raw
assert routing.query_first is writer_inner.query_first
# Per-model reads → writer actions.
actions = routing.litellm_usertable
assert actions.find_many is writer_inner.litellm_usertable.find_many
assert actions.find_unique is writer_inner.litellm_usertable.find_unique
@pytest.mark.asyncio
async def test_recreate_prisma_client_recreates_both_writer_and_reader():
"""Writer reconnect path calls recreate_prisma_client. The routing wrapper
must recreate BOTH clients so a DB-wide event doesn't leave a stale reader."""
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer = MagicMock()
writer.recreate_prisma_client = AsyncMock()
reader = MagicMock()
reader.iam_token_db_auth = False
reader.recreate_prisma_client = AsyncMock()
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
with patch.dict(os.environ, {"DATABASE_URL_READ_REPLICA": "reader-url"}):
await routing.recreate_prisma_client("writer-url", http_client=None)
writer.recreate_prisma_client.assert_awaited_once_with(
"writer-url", http_client=None
)
reader.recreate_prisma_client.assert_awaited_once_with(
"reader-url", http_client=None
)
assert routing.reader_unavailable is False
@pytest.mark.asyncio
async def test_recreate_recovers_reader_after_prior_degradation():
"""If a previous connect/recreate degraded the reader, a successful
recreate must clear the flag so reads start hitting the reader again."""
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer = MagicMock()
writer.recreate_prisma_client = AsyncMock()
reader = MagicMock()
reader.iam_token_db_auth = False
reader.recreate_prisma_client = AsyncMock()
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
routing._reader_unavailable = True
with patch.dict(os.environ, {"DATABASE_URL_READ_REPLICA": "reader-url"}):
await routing.recreate_prisma_client("writer-url")
assert routing.reader_unavailable is False
@pytest.mark.asyncio
async def test_recreate_degrades_reader_if_reader_recreate_fails():
"""If the reader recreate fails, writer recreate still succeeds and the
routing wrapper degrades (does not raise)."""
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer = MagicMock()
writer.recreate_prisma_client = AsyncMock()
reader = MagicMock()
reader.iam_token_db_auth = False
reader.recreate_prisma_client = AsyncMock(
side_effect=RuntimeError("reader still down")
)
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
with patch.dict(os.environ, {"DATABASE_URL_READ_REPLICA": "reader-url"}):
# Must not raise — writer was recreated, reader is best-effort.
await routing.recreate_prisma_client("writer-url")
writer.recreate_prisma_client.assert_awaited_once()
assert routing.reader_unavailable is True
@pytest.mark.asyncio
async def test_recreate_degrades_reader_when_replica_url_missing():
"""Non-IAM reader needs DATABASE_URL_READ_REPLICA. If it's missing
(configuration drift), the wrapper degrades instead of raising."""
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer = MagicMock()
writer.recreate_prisma_client = AsyncMock()
reader = MagicMock()
reader.iam_token_db_auth = False
reader.recreate_prisma_client = AsyncMock()
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
# Ensure env var is absent.
with patch.dict(os.environ, {}, clear=False):
os.environ.pop("DATABASE_URL_READ_REPLICA", None)
await routing.recreate_prisma_client("writer-url")
writer.recreate_prisma_client.assert_awaited_once()
reader.recreate_prisma_client.assert_not_awaited()
assert routing.reader_unavailable is True
@pytest.mark.asyncio
async def test_recreate_iam_reader_refreshes_token():
"""IAM-enabled readers must refresh their token (reader has its own parsed
endpoint) and pass the fresh URL to recreate_prisma_client."""
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer = MagicMock()
writer.recreate_prisma_client = AsyncMock()
reader = MagicMock()
reader.iam_token_db_auth = True
reader.get_rds_iam_token = MagicMock(return_value="postgresql://u:fresh@h:5432/db")
reader.recreate_prisma_client = AsyncMock()
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
await routing.recreate_prisma_client("writer-url")
reader.get_rds_iam_token.assert_called_once()
reader.recreate_prisma_client.assert_awaited_once_with(
"postgresql://u:fresh@h:5432/db", http_client=None
)
assert routing.reader_unavailable is False
@pytest.mark.asyncio
async def test_recreate_degrades_when_iam_token_generation_returns_none():
"""If `get_rds_iam_token` returns None (e.g. AWS-side failure), the wrapper
must degrade rather than crash this exercises the explicit `raise
RuntimeError` inside `_recreate_reader`'s IAM branch."""
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer = MagicMock()
writer.recreate_prisma_client = AsyncMock()
reader = MagicMock()
reader.iam_token_db_auth = True
reader.get_rds_iam_token = MagicMock(return_value=None)
reader.recreate_prisma_client = AsyncMock()
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
await routing.recreate_prisma_client("writer-url")
writer.recreate_prisma_client.assert_awaited_once()
reader.recreate_prisma_client.assert_not_awaited()
assert routing.reader_unavailable is True
def test_writer_and_reader_properties_expose_underlying_wrappers():
"""The `writer` and `reader` properties are used by PrismaClient.writer_db
to smoke-test the writer specifically during reconnect they must return
the exact wrappers passed in."""
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
writer, _, reader, _ = _make_wrappers()
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
assert routing.writer is writer
assert routing.reader is reader
def test_per_model_accessor_falls_back_when_reader_lacks_attr():
"""If the reader Prisma client somehow lacks a model accessor that the
writer has (older client / partial mock), the wrapper must fall back to
the writer accessor instead of raising AttributeError to the caller."""
from litellm.proxy.db.prisma_client import PrismaWrapper
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
# Plain class with only the accessor set on the writer side. Using a real
# class instead of MagicMock so attribute access raises AttributeError
# naturally instead of auto-creating mock attributes.
class _PartialPrisma:
pass
writer_inner = _PartialPrisma()
writer_inner.litellm_usertable = _model_actions_mock("writer_users")
reader_inner = _PartialPrisma() # deliberately missing litellm_usertable
writer = PrismaWrapper(original_prisma=writer_inner, iam_token_db_auth=False)
reader = PrismaWrapper(original_prisma=reader_inner, iam_token_db_auth=False)
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
actions = routing.litellm_usertable
# Falls back to the writer's accessor verbatim — not a _RoutedActions wrapper.
assert actions is writer_inner.litellm_usertable
@pytest.mark.asyncio
async def test_writer_recreate_passes_http_client_through(monkeypatch):
"""When PrismaClient is constructed with an http_client, recreate must
forward it to the new Prisma() so connection settings persist across
reconnects."""
from litellm.proxy.db.prisma_client import PrismaWrapper
captured_kwargs: Dict[str, Any] = {}
class FakePrisma:
def __init__(self, **kwargs):
captured_kwargs.update(kwargs)
async def connect(self):
return None
fake_module = MagicMock()
fake_module.Prisma = FakePrisma
monkeypatch.setitem(sys.modules, "prisma", fake_module)
writer = PrismaWrapper(original_prisma=MagicMock(), iam_token_db_auth=False)
sentinel_http = object()
await writer.recreate_prisma_client(
"postgresql://u:p@h:5432/db", http_client=sentinel_http
)
assert captured_kwargs == {"http": sentinel_http}
# ---------------------------------------------------------------------------
# IAM endpoint parsing + reader IAM refresh
# ---------------------------------------------------------------------------
def test_parse_iam_endpoint_from_url_extracts_all_fields():
from litellm.proxy.db.prisma_client import parse_iam_endpoint_from_url
ep = parse_iam_endpoint_from_url(
"postgresql://litellm_user:initial-token@aurora-reader.example.com:6543/litellm?schema=public"
)
assert ep.host == "aurora-reader.example.com"
assert ep.port == "6543"
assert ep.user == "litellm_user"
assert ep.name == "litellm"
assert ep.schema == "public"
def test_parse_iam_endpoint_defaults_port_to_5432_and_skips_schema():
from litellm.proxy.db.prisma_client import parse_iam_endpoint_from_url
ep = parse_iam_endpoint_from_url("postgresql://u@host/dbname")
assert ep.host == "host"
assert ep.port == "5432"
assert ep.user == "u"
assert ep.name == "dbname"
assert ep.schema is None
def test_parse_iam_endpoint_rejects_url_without_user_or_dbname():
from litellm.proxy.db.prisma_client import parse_iam_endpoint_from_url
with pytest.raises(ValueError, match="missing host or username"):
parse_iam_endpoint_from_url("postgresql://host:5432/db")
with pytest.raises(ValueError, match="missing database name"):
parse_iam_endpoint_from_url("postgresql://u@host:5432/")
def test_iam_endpoint_build_url_inserts_token_verbatim():
from litellm.proxy.db.prisma_client import IAMEndpoint
# `generate_iam_auth_token` already URL-encodes the presigned token, so
# `build_url` must NOT encode again — double-encoding turned `%3D` into
# `%253D` and broke RDS auth on the reader path.
ep = IAMEndpoint(host="h", port="5432", user="u", name="db", schema="public")
pre_encoded_token = "token%2Fwith%3Fweird%26chars%3Dyes"
url = ep.build_url(pre_encoded_token)
assert url == f"postgresql://u:{pre_encoded_token}@h:5432/db?schema=public"
# Sanity check: no `%25` (the encoding of `%`), confirming we didn't re-encode.
assert "%25" not in url
@pytest.mark.asyncio
async def test_iam_refresh_logs_carry_log_prefix(caplog):
"""When `log_prefix` is set on a PrismaWrapper, every IAM-related log
line emitted by that wrapper must start with the prefix so writer and
reader can be told apart in interleaved output."""
from litellm.proxy.db.prisma_client import PrismaWrapper
wrapper = PrismaWrapper(
original_prisma=MagicMock(),
iam_token_db_auth=True,
log_prefix="[reader]",
)
with caplog.at_level(logging.INFO, logger="LiteLLM Proxy"):
await wrapper.start_token_refresh_task()
# Loop emits "RDS IAM token refresh loop started..." on first tick.
# Cancel immediately so the loop body runs once and we can assert.
await wrapper.stop_token_refresh_task()
messages = [r.getMessage() for r in caplog.records]
# Both start and stop notifications carry the prefix.
assert any(
m.startswith("[reader] Started RDS IAM token proactive refresh")
for m in messages
)
assert any(
m.startswith("[reader] Stopped RDS IAM token refresh background task")
for m in messages
)
def test_get_rds_iam_token_returns_none_when_iam_disabled():
"""`get_rds_iam_token` short-circuits to None when iam_token_db_auth is
False covers the early-return guard at the top of the method."""
from litellm.proxy.db.prisma_client import PrismaWrapper
wrapper = PrismaWrapper(original_prisma=MagicMock(), iam_token_db_auth=False)
assert wrapper.get_rds_iam_token() is None
@pytest.mark.asyncio
async def test_getattr_does_not_block_inside_running_loop_on_expired_token(monkeypatch):
"""When `__getattr__` runs inside a running event loop and the IAM token
is expired, it MUST schedule the refresh as a background task and return
immediately. The previous `run_coroutine_threadsafe` + `future.result()`
pattern deadlocks the loop (loop thread blocks waiting for a coroutine
that needs the loop to run) and times out at 30s exactly what was
breaking the reader on first query."""
from litellm.proxy.db.prisma_client import PrismaWrapper
# Stale URL — `is_token_expired` returns True because the password isn't
# a parseable IAM token, so we exercise the expired branch.
monkeypatch.setenv(
"DATABASE_URL_READ_REPLICA",
"postgresql://reader:placeholder@reader.aurora.local:5432/litellm",
)
inner = MagicMock()
inner.query_raw = MagicMock(name="query_raw_attr")
wrapper = PrismaWrapper(
original_prisma=inner,
iam_token_db_auth=True,
db_url_env_var="DATABASE_URL_READ_REPLICA",
)
# Replace the heavy refresh coroutine with a no-op AsyncMock so we can
# observe whether it was scheduled without actually doing the recreate.
refresh_calls = {"count": 0}
async def fake_refresh():
refresh_calls["count"] += 1
monkeypatch.setattr(wrapper, "_safe_refresh_token", fake_refresh)
# Direct attribute access from inside this async test runs __getattr__
# on the loop thread, exercising the in-loop branch. If the previous
# `run_coroutine_threadsafe` + `future.result()` pattern were back, this
# line would deadlock the loop and the test would hang (and pytest's
# per-test timeout would catch it).
attr = wrapper.query_raw
# Yield once so the scheduled refresh task gets a chance to run.
await asyncio.sleep(0)
assert attr is inner.query_raw
assert refresh_calls["count"] == 1
def test_writer_get_rds_iam_token_defaults_port_when_unset(monkeypatch):
"""When DATABASE_PORT is unset, the writer must default to the Postgres
standard port instead of passing `None` through. Passing None to
`generate_iam_auth_token` makes botocore embed the literal string
\"None\" in the presigned URL during signing and crashes with
`ValueError: Port could not be cast to integer value as 'None'`."""
from litellm.proxy.db.prisma_client import PrismaWrapper
monkeypatch.setenv("DATABASE_HOST", "writer.aurora.local")
monkeypatch.delenv("DATABASE_PORT", raising=False)
monkeypatch.setenv("DATABASE_USER", "litellm")
monkeypatch.setenv("DATABASE_NAME", "litellm")
monkeypatch.delenv("DATABASE_SCHEMA", raising=False)
monkeypatch.delenv("DATABASE_URL", raising=False)
captured: Dict[str, Any] = {}
def fake_generate(db_host=None, db_port=None, db_user=None):
captured["port"] = db_port
return "TOKEN"
fake_module = MagicMock()
fake_module.generate_iam_auth_token = fake_generate
monkeypatch.setitem(sys.modules, "litellm.proxy.auth.rds_iam_token", fake_module)
writer = PrismaWrapper(
original_prisma=MagicMock(),
iam_token_db_auth=True,
)
new_url = writer.get_rds_iam_token()
assert captured["port"] == "5432" # default applied, NOT None
assert ":5432/litellm" in (new_url or "")
def test_writer_get_rds_iam_token_uses_database_host_env_vars(monkeypatch):
"""Writer's IAM path (no iam_endpoint configured) reads host/port/user/db
from the legacy DATABASE_HOST/PORT/USER/NAME env vars and writes the URL
back to DATABASE_URL this is the pre-read-replica behavior the patch
must preserve."""
from litellm.proxy.db.prisma_client import PrismaWrapper
monkeypatch.setenv("DATABASE_HOST", "writer.aurora.local")
monkeypatch.setenv("DATABASE_PORT", "5432")
monkeypatch.setenv("DATABASE_USER", "litellm")
monkeypatch.setenv("DATABASE_NAME", "litellm")
monkeypatch.setenv("DATABASE_SCHEMA", "public")
monkeypatch.delenv("DATABASE_URL", raising=False)
captured: Dict[str, Any] = {}
def fake_generate(db_host=None, db_port=None, db_user=None):
captured["host"] = db_host
captured["port"] = db_port
captured["user"] = db_user
return "WRITER-TOKEN"
fake_module = MagicMock()
fake_module.generate_iam_auth_token = fake_generate
monkeypatch.setitem(sys.modules, "litellm.proxy.auth.rds_iam_token", fake_module)
writer = PrismaWrapper(
original_prisma=MagicMock(),
iam_token_db_auth=True,
# No iam_endpoint → legacy DATABASE_HOST/etc. path.
)
new_url = writer.get_rds_iam_token()
assert captured == {
"host": "writer.aurora.local",
"port": "5432",
"user": "litellm",
}
assert new_url == (
"postgresql://litellm:WRITER-TOKEN@writer.aurora.local:5432/litellm?schema=public"
)
# Writer updates its own env var (DATABASE_URL by default), not the reader's.
assert os.environ["DATABASE_URL"] == new_url
def test_reader_iam_refresh_uses_parsed_endpoint(monkeypatch):
"""The reader generates fresh tokens against its parsed endpoint and
writes the new URL to DATABASE_URL_READ_REPLICA not DATABASE_URL."""
from litellm.proxy.db.prisma_client import IAMEndpoint, PrismaWrapper
# Pre-seed env vars so we can prove the reader does NOT touch DATABASE_URL.
monkeypatch.setenv("DATABASE_URL", "writer-url-untouched")
monkeypatch.setenv("DATABASE_URL_READ_REPLICA", "stale-reader-url")
captured: Dict[str, Any] = {}
def fake_generate(db_host=None, db_port=None, db_user=None):
captured["host"] = db_host
captured["port"] = db_port
captured["user"] = db_user
return "FRESH-TOKEN"
fake_module = MagicMock()
fake_module.generate_iam_auth_token = fake_generate
monkeypatch.setitem(sys.modules, "litellm.proxy.auth.rds_iam_token", fake_module)
endpoint = IAMEndpoint(
host="reader.aurora.local",
port="5432",
user="lit",
name="litellm",
schema=None,
)
reader = PrismaWrapper(
original_prisma=MagicMock(),
iam_token_db_auth=True,
db_url_env_var="DATABASE_URL_READ_REPLICA",
iam_endpoint=endpoint,
recreate_uses_datasource=True,
)
new_url = reader.get_rds_iam_token()
# IAM token generator was called with the reader's parsed endpoint, not
# the writer's DATABASE_HOST/PORT/USER env vars.
assert captured == {
"host": "reader.aurora.local",
"port": "5432",
"user": "lit",
}
assert new_url is not None
assert new_url.startswith(
"postgresql://lit:FRESH-TOKEN@reader.aurora.local:5432/litellm"
)
# The reader updates its OWN env var; writer's DATABASE_URL is left alone.
assert os.environ["DATABASE_URL_READ_REPLICA"] == new_url
assert os.environ["DATABASE_URL"] == "writer-url-untouched"
@pytest.mark.asyncio
async def test_reader_recreate_uses_datasource_override(monkeypatch):
"""Reader recreate must pass `datasource={"url": ...}` to Prisma() — Prisma
only auto-reads DATABASE_URL, so without the override the new reader URL
would be silently ignored."""
from litellm.proxy.db.prisma_client import IAMEndpoint, PrismaWrapper
captured_kwargs: Dict[str, Any] = {}
class FakePrisma:
def __init__(self, **kwargs):
captured_kwargs.update(kwargs)
async def connect(self):
return None
fake_module = MagicMock()
fake_module.Prisma = FakePrisma
monkeypatch.setitem(sys.modules, "prisma", fake_module)
reader = PrismaWrapper(
original_prisma=MagicMock(),
iam_token_db_auth=True,
db_url_env_var="DATABASE_URL_READ_REPLICA",
iam_endpoint=IAMEndpoint(host="h", port="5432", user="u", name="db"),
recreate_uses_datasource=True,
)
await reader.recreate_prisma_client(
"postgresql://u:newtoken@h:5432/db", http_client=None
)
assert captured_kwargs == {
"datasource": {"url": "postgresql://u:newtoken@h:5432/db"}
}
@pytest.mark.asyncio
async def test_writer_recreate_does_not_use_datasource(monkeypatch):
"""Writer keeps relying on Prisma reading DATABASE_URL from env — datasource
override must NOT leak into the writer path (would override the freshly
rotated env var)."""
from litellm.proxy.db.prisma_client import PrismaWrapper
captured_kwargs: Dict[str, Any] = {}
class FakePrisma:
def __init__(self, **kwargs):
captured_kwargs.update(kwargs)
async def connect(self):
return None
fake_module = MagicMock()
fake_module.Prisma = FakePrisma
monkeypatch.setitem(sys.modules, "prisma", fake_module)
writer = PrismaWrapper(
original_prisma=MagicMock(),
iam_token_db_auth=True,
)
await writer.recreate_prisma_client(
"postgresql://u:newtoken@h:5432/db", http_client=None
)
assert "datasource" not in captured_kwargs
def test_prisma_client_init_falls_back_to_writer_when_reader_iam_token_fails(
monkeypatch, caplog
):
"""A transient AWS STS error (or any other failure) during the reader
IAM token mint must NOT abort proxy startup. The reader is opt-in, so
`PrismaClient.__init__` should log a warning and fall back to the
writer-only `PrismaWrapper`. The runtime contract in
`RoutingPrismaWrapper.connect` already says reader-side failures are
non-fatal but that code never runs if construction throws first."""
from litellm.proxy.db.prisma_client import PrismaWrapper
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
monkeypatch.setenv("IAM_TOKEN_DB_AUTH", "true")
monkeypatch.setenv(
"DATABASE_URL_READ_REPLICA",
"postgresql://reader_user@reader.aurora.local:5432/litellm",
)
class FakePrisma:
def __init__(self, **kwargs):
self.kwargs = kwargs
async def connect(self):
return None
fake_prisma_module = MagicMock()
fake_prisma_module.Prisma = FakePrisma
monkeypatch.setitem(sys.modules, "prisma", fake_prisma_module)
fake_iam_module = MagicMock()
def boom(**_kwargs):
raise RuntimeError("simulated AWS STS hiccup")
fake_iam_module.generate_iam_auth_token = boom
monkeypatch.setitem(
sys.modules, "litellm.proxy.auth.rds_iam_token", fake_iam_module
)
from litellm.proxy.utils import PrismaClient
with caplog.at_level(logging.WARNING, logger="LiteLLM Proxy"):
client = PrismaClient(
database_url="postgresql://writer@writer.aurora.local:5432/litellm",
proxy_logging_obj=MagicMock(),
)
# Construction did not raise, and the proxy is in writer-only mode —
# NOT a RoutingPrismaWrapper, so reads will go to the writer.
assert isinstance(client.db, PrismaWrapper)
assert not isinstance(client.db, RoutingPrismaWrapper)
# And the operator gets a clear warning.
assert any(
"Failed to initialize read replica Prisma client" in r.getMessage()
for r in caplog.records
)