From b5d3a5fc856ed1cf9b101d37bd0ec6d6d44751b2 Mon Sep 17 00:00:00 2001 From: Yassin Kortam Date: Fri, 8 May 2026 21:05:50 -0700 Subject: [PATCH] feat: add read-replica routing for Prisma DB via DATABASE_URL_READ_REPLICA (#27493) - Introduce RoutingPrismaWrapper that transparently routes read operations (find_*, count, group_by, query_raw, query_first) to a reader endpoint while writes remain on the writer, enabling Aurora-style reader/writer endpoint splits - Add IAMEndpoint dataclass and parse_iam_endpoint_from_url() to capture static connection fields from a reader URL so only the IAM token needs to rotate, avoiding the need for separate DATABASE_HOST_READ_REPLICA/etc. env vars - Enhance PrismaWrapper with per-instance knobs (db_url_env_var, iam_endpoint, recreate_uses_datasource, log_prefix) so writer and reader wrappers are independent: the reader writes its fresh URL to DATABASE_URL_READ_REPLICA and passes datasource override to Prisma since Prisma only auto-reads DATABASE_URL - Fix deadlock in PrismaWrapper.__getattr__: when called from inside a running event loop, schedule the token refresh as a background task instead of blocking with run_coroutine_threadsafe + future.result(), which would deadlock the loop thread waiting for a coroutine that needs the loop to run - Fix botocore crash when DATABASE_PORT is unset by defaulting to "5432" in both proxy_cli.py and PrismaWrapper.get_rds_iam_token(); passing None caused botocore to embed the literal string "None" in the presigned URL - Implement graceful reader degradation: reader connect/recreate failures are non-fatal; wrapper sets _reader_unavailable=True and silently routes reads to the writer to keep the proxy serving traffic during transient reader outages - Add PrismaClient.writer_db property so the reconnect smoke-test always validates the writer engine specifically; query_raw on the routing wrapper would route to the reader and not verify the newly-recreated writer - Expose DATABASE_URL_READ_REPLICA in Helm chart (values.yaml + deployment.yaml) via both plain value and secret key reference, and document the field in docker-compose.yml - Add 887-line test suite covering routing logic, IAM token refresh paths, reader degradation scenarios, datasource override behavior, and the deadlock regression Co-authored-by: Yassin Kortam --- .../litellm-helm/templates/deployment.yaml | 10 + deploy/charts/litellm-helm/values.yaml | 20 + docker-compose.yml | 5 + litellm/proxy/db/prisma_client.py | 237 +++-- litellm/proxy/db/routing_prisma_wrapper.py | 213 +++++ litellm/proxy/proxy_cli.py | 7 +- litellm/proxy/utils.py | 121 ++- .../proxy/db/test_routing_prisma_wrapper.py | 887 ++++++++++++++++++ 8 files changed, 1423 insertions(+), 77 deletions(-) create mode 100644 litellm/proxy/db/routing_prisma_wrapper.py create mode 100644 tests/test_litellm/proxy/db/test_routing_prisma_wrapper.py diff --git a/deploy/charts/litellm-helm/templates/deployment.yaml b/deploy/charts/litellm-helm/templates/deployment.yaml index 6aa1771b7b..25f6908087 100644 --- a/deploy/charts/litellm-helm/templates/deployment.yaml +++ b/deploy/charts/litellm-helm/templates/deployment.yaml @@ -100,6 +100,16 @@ spec: - name: DATABASE_URL value: {{ .Values.db.url | quote }} {{- end }} + {{- if and .Values.db.useExisting .Values.db.secret.readReplicaUrlKey }} + - name: DATABASE_URL_READ_REPLICA + valueFrom: + secretKeyRef: + name: {{ .Values.db.secret.name }} + key: {{ .Values.db.secret.readReplicaUrlKey }} + {{- else if .Values.db.readReplicaUrl }} + - name: DATABASE_URL_READ_REPLICA + value: {{ .Values.db.readReplicaUrl | quote }} + {{- end }} - name: PROXY_MASTER_KEY valueFrom: secretKeyRef: diff --git a/deploy/charts/litellm-helm/values.yaml b/deploy/charts/litellm-helm/values.yaml index ba4059e084..9c7c013341 100644 --- a/deploy/charts/litellm-helm/values.yaml +++ b/deploy/charts/litellm-helm/values.yaml @@ -252,6 +252,26 @@ db: passwordKey: password # Optional: when set, DATABASE_HOST will be sourced from this secret key instead of db.endpoint endpointKey: "" + # Optional: when set, DATABASE_URL_READ_REPLICA will be sourced from this + # secret key instead of db.readReplicaUrl. Prefer this over the plain + # value: read-replica URLs typically embed credentials, and a value + # written to db.readReplicaUrl ends up visible in the rendered pod spec + # and the Helm release secret. + readReplicaUrlKey: "" + + # Optional read-replica routing. When set, the proxy sends read-only + # queries (find_*, count, group_by, query_raw/_first) to this URL while + # writes continue to go to db.url. Useful for Aurora-style clusters with + # separate reader/writer endpoints. Leave empty to keep single-DB behavior. + # When IAM_TOKEN_DB_AUTH is enabled, the reader URL is auto-refreshed + # alongside the writer (host/port/user/db are parsed from this URL once + # at startup; only the IAM token rotates). + # + # If the URL embeds credentials, prefer db.secret.readReplicaUrlKey over + # this field — the plain value is rendered into the pod spec and the + # Helm release secret. This field is intended for credential-less URLs + # only (e.g. when IAM_TOKEN_DB_AUTH supplies the token at runtime). + readReplicaUrl: "" # Use the Stackgres Helm chart to deploy an instance of a Stackgres cluster. # The Stackgres Operator must already be installed within the target diff --git a/docker-compose.yml b/docker-compose.yml index 988860a787..80e1f289aa 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -16,6 +16,11 @@ services: - "4000:4000" # Map the container port to the host, change the host port if necessary environment: DATABASE_URL: "postgresql://llmproxy:dbpassword9090@db:5432/litellm" + # Optional: route read-only queries (find_*, count, group_by, query_raw/_first) + # to a separate reader endpoint, e.g. an Aurora reader. Leave unset for + # single-DB deployments. With IAM_TOKEN_DB_AUTH enabled, the reader URL + # is auto-refreshed alongside the writer. + # DATABASE_URL_READ_REPLICA: "postgresql://llmproxy:dbpassword9090@db-reader:5432/litellm" STORE_MODEL_IN_DB: "True" # allows adding models to proxy via UI env_file: - .env # Load local .env file diff --git a/litellm/proxy/db/prisma_client.py b/litellm/proxy/db/prisma_client.py index d112e22230..af5a58802b 100644 --- a/litellm/proxy/db/prisma_client.py +++ b/litellm/proxy/db/prisma_client.py @@ -10,13 +10,64 @@ import subprocess import time import urllib import urllib.parse +from dataclasses import dataclass from datetime import datetime, timedelta -from typing import Any, Optional, Union +from typing import Any, Dict, Optional, Union from litellm._logging import verbose_proxy_logger from litellm.secret_managers.main import str_to_bool +@dataclass(frozen=True) +class IAMEndpoint: + """Static parts of an RDS IAM-authenticated Postgres connection. + + The IAM token rotates every ~15 minutes; everything else (host, port, user, + database name, schema) stays fixed. We capture the static fields once so + refresh just regenerates the token and reassembles the URL. + """ + + host: str + port: str + user: str + name: str + schema: Optional[str] = None + + def build_url(self, token: str) -> str: + url = f"postgresql://{self.user}:{token}@{self.host}:{self.port}/{self.name}" + if self.schema: + url += f"?schema={self.schema}" + return url + + +def parse_iam_endpoint_from_url(url: str) -> IAMEndpoint: + """Parse an IAMEndpoint from a Postgres URL. + + Used so a reader URL can drive its own IAM refresh without requiring + callers to set parallel DATABASE_HOST_READ_REPLICA / etc. env vars. + """ + parsed = urllib.parse.urlparse(url) + if not parsed.hostname or not parsed.username: + raise ValueError("Cannot parse IAM endpoint from URL: missing host or username") + name = (parsed.path or "/").lstrip("/") + if not name: + raise ValueError("Cannot parse IAM endpoint from URL: missing database name") + port = str(parsed.port) if parsed.port else "5432" + schema: Optional[str] = None + if parsed.query: + qs = urllib.parse.parse_qs(parsed.query) + schema_vals = qs.get("schema") + if schema_vals: + schema = schema_vals[0] + return IAMEndpoint( + host=parsed.hostname, + port=port, + user=parsed.username, + name=name, + schema=schema, + ) + + class PrismaWrapper: """ Wrapper around Prisma client that handles RDS IAM token authentication. @@ -37,10 +88,33 @@ class PrismaWrapper: # Fallback refresh interval if token parsing fails (10 minutes) FALLBACK_REFRESH_INTERVAL_SECONDS = 600 - def __init__(self, original_prisma: Any, iam_token_db_auth: bool): + def __init__( + self, + original_prisma: Any, + iam_token_db_auth: bool, + *, + db_url_env_var: str = "DATABASE_URL", + iam_endpoint: Optional[IAMEndpoint] = None, + recreate_uses_datasource: bool = False, + log_prefix: str = "", + ): self._original_prisma = original_prisma self.iam_token_db_auth = iam_token_db_auth + # Per-connection knobs so the same wrapper can be used for the writer + # (defaults: DATABASE_URL env, IAM endpoint from DATABASE_HOST/etc., + # recreate via env reload) or for a reader (DATABASE_URL_READ_REPLICA + # env, IAM endpoint parsed from that URL, recreate via datasource + # override since Prisma only auto-reads DATABASE_URL). + self._db_url_env_var = db_url_env_var + self._iam_endpoint = iam_endpoint + self._recreate_uses_datasource = recreate_uses_datasource + # Tag every log line emitted by this wrapper instance so writer and + # reader can be told apart in interleaved output (e.g. "[writer] RDS + # IAM token refresh scheduled in 720 seconds"). Empty string (default) + # keeps backward-compatible logs for the single-DB case. + self._log_prefix = f"{log_prefix} " if log_prefix else "" + # Background token refresh task management self._token_refresh_task: Optional[asyncio.Task] = None self._reconnection_lock = asyncio.Lock() @@ -157,7 +231,7 @@ class PrismaWrapper: Returns 0 if token should be refreshed immediately. Returns FALLBACK_REFRESH_INTERVAL_SECONDS if parsing fails. """ - db_url = os.getenv("DATABASE_URL") + db_url = os.getenv(self._db_url_env_var) token = self._extract_token_from_db_url(db_url) expiration_time = self._parse_token_expiration(token) @@ -199,12 +273,30 @@ class PrismaWrapper: return datetime.utcnow() > expiration_time def get_rds_iam_token(self) -> Optional[str]: - """Generate a new RDS IAM token and update DATABASE_URL.""" - if self.iam_token_db_auth: - from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token + """Generate a new RDS IAM token and update the configured DB URL env var. + When the wrapper was constructed with an explicit `iam_endpoint` + (typical for a reader wrapper whose host/port/user came from a parsed + URL), use that. Otherwise fall back to the legacy DATABASE_HOST/PORT/ + USER/NAME/SCHEMA env vars (writer behavior). + """ + if not self.iam_token_db_auth: + return None + + from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token + + if self._iam_endpoint is not None: + endpoint = self._iam_endpoint + token = generate_iam_auth_token( + db_host=endpoint.host, db_port=endpoint.port, db_user=endpoint.user + ) + _db_url = endpoint.build_url(token) + else: db_host = os.getenv("DATABASE_HOST") - db_port = os.getenv("DATABASE_PORT") + # Default to the Postgres standard port; passing None to + # `generate_iam_auth_token` makes botocore embed the literal + # string "None" in the presigned URL, which then fails to parse. + db_port = os.getenv("DATABASE_PORT", "5432") db_user = os.getenv("DATABASE_USER") db_name = os.getenv("DATABASE_NAME") db_schema = os.getenv("DATABASE_SCHEMA") @@ -217,9 +309,8 @@ class PrismaWrapper: if db_schema: _db_url += f"?schema={db_schema}" - os.environ["DATABASE_URL"] = _db_url - return _db_url - return None + os.environ[self._db_url_env_var] = _db_url + return _db_url async def recreate_prisma_client( self, new_db_url: str, http_client: Optional[Any] = None @@ -231,6 +322,11 @@ class PrismaWrapper: synchronous `subprocess.Popen.wait()` that can freeze the asyncio event loop for 30-120+ seconds when the engine is stuck on TCP close, breaking `/health/liveliness` and causing Kubernetes pod restarts. + + The writer wrapper relies on Prisma re-reading `DATABASE_URL` from env; + the reader wrapper opts into `recreate_uses_datasource=True` so the + new URL is passed explicitly via `datasource={"url": ...}` (Prisma + does not auto-read alternate env vars like DATABASE_URL_READ_REPLICA). """ from prisma import Prisma # type: ignore @@ -238,10 +334,12 @@ class PrismaWrapper: if old_engine_pid > 0: await self._kill_engine_process(old_engine_pid) + kwargs: Dict[str, Any] = {} if http_client is not None: - self._original_prisma = Prisma(http=http_client) - else: - self._original_prisma = Prisma() + kwargs["http"] = http_client + if self._recreate_uses_datasource: + kwargs["datasource"] = {"url": new_db_url} + self._original_prisma = Prisma(**kwargs) await self._original_prisma.connect() @@ -265,7 +363,8 @@ class PrismaWrapper: self._token_refresh_task = asyncio.create_task(self._token_refresh_loop()) verbose_proxy_logger.info( - "Started RDS IAM token proactive refresh background task" + "%sStarted RDS IAM token proactive refresh background task", + self._log_prefix, ) async def stop_token_refresh_task(self) -> None: @@ -283,7 +382,9 @@ class PrismaWrapper: except asyncio.CancelledError: pass self._token_refresh_task = None - verbose_proxy_logger.info("Stopped RDS IAM token refresh background task") + verbose_proxy_logger.info( + "%sStopped RDS IAM token refresh background task", self._log_prefix + ) async def _token_refresh_loop(self) -> None: """ @@ -294,7 +395,7 @@ class PrismaWrapper: This is more efficient than polling, requiring only 1 wake-up per token cycle. """ verbose_proxy_logger.info( - f"RDS IAM token refresh loop started. " + f"{self._log_prefix}RDS IAM token refresh loop started. " f"Tokens will be refreshed {self.TOKEN_REFRESH_BUFFER_SECONDS}s before expiration." ) @@ -305,21 +406,25 @@ class PrismaWrapper: if sleep_seconds > 0: verbose_proxy_logger.info( - f"RDS IAM token refresh scheduled in {sleep_seconds:.0f} seconds " - f"({sleep_seconds / 60:.1f} minutes)" + f"{self._log_prefix}RDS IAM token refresh scheduled in " + f"{sleep_seconds:.0f} seconds ({sleep_seconds / 60:.1f} minutes)" ) await asyncio.sleep(sleep_seconds) # Refresh the token - verbose_proxy_logger.info("Proactively refreshing RDS IAM token...") + verbose_proxy_logger.info( + "%sProactively refreshing RDS IAM token...", self._log_prefix + ) await self._safe_refresh_token() except asyncio.CancelledError: - verbose_proxy_logger.info("RDS IAM token refresh loop cancelled") + verbose_proxy_logger.info( + "%sRDS IAM token refresh loop cancelled", self._log_prefix + ) break except Exception as e: verbose_proxy_logger.error( - f"Error in RDS IAM token refresh loop: {e}. " + f"{self._log_prefix}Error in RDS IAM token refresh loop: {e}. " f"Retrying in {self.FALLBACK_REFRESH_INTERVAL_SECONDS}s..." ) # On error, wait before retrying to avoid tight error loops @@ -341,65 +446,75 @@ class PrismaWrapper: await self.recreate_prisma_client(new_db_url) self._last_refresh_time = datetime.utcnow() verbose_proxy_logger.info( - "RDS IAM token refreshed successfully. New token valid for ~15 minutes." + "%sRDS IAM token refreshed successfully. New token valid for ~15 minutes.", + self._log_prefix, ) else: verbose_proxy_logger.error( - "Failed to generate new RDS IAM token during proactive refresh" + "%sFailed to generate new RDS IAM token during proactive refresh", + self._log_prefix, ) def __getattr__(self, name: str): """ Proxy attribute access to the underlying Prisma client. - If IAM token auth is enabled and the token is expired, this method - provides a synchronous fallback to refresh the token. However, this - should rarely be needed since the background task proactively refreshes - tokens before they expire. + If IAM token auth is enabled and the token is found expired here, the + proactive refresh task has missed its window. Behavior depends on + whether we're called from inside a running event loop: - FIXED: Now properly waits for reconnection to complete before returning, - instead of the previous fire-and-forget pattern that caused the bug. + - Inside the loop (typical: from a coroutine): schedule a refresh as a + background task and return the (stale) attribute. The caller's await + will likely fail with a connection error and be retried by upper + layers (`call_with_db_reconnect_retry`); by that time the refresh + has either completed or escalated to the proactive loop's error + path. We CANNOT block here — `run_coroutine_threadsafe(...)` + + `future.result()` from inside the same loop deadlocks the loop + (loop thread is blocked, scheduled coroutine never runs, 30s timeout). + + - No running loop (sync caller, mostly tests): run the refresh in a + fresh loop and re-fetch the attribute. """ original_attr = getattr(self._original_prisma, name) if self.iam_token_db_auth: - db_url = os.getenv("DATABASE_URL") + db_url = os.getenv(self._db_url_env_var) # Check if token is expired (should be rare if background task is running) if self.is_token_expired(db_url): - verbose_proxy_logger.warning( - "RDS IAM token expired in __getattr__ - proactive refresh may have failed. " - "Triggering synchronous fallback refresh..." - ) + try: + running_loop = asyncio.get_running_loop() + except RuntimeError: + running_loop = None - new_db_url = self.get_rds_iam_token() - if new_db_url: - loop = asyncio.get_event_loop() - - if loop.is_running(): - # FIXED: Actually wait for the reconnection to complete! - # The previous code used fire-and-forget which caused the bug. - future = asyncio.run_coroutine_threadsafe( - self.recreate_prisma_client(new_db_url), loop - ) - try: - # Wait up to 30 seconds for reconnection - future.result(timeout=30) - verbose_proxy_logger.info( - "Synchronous token refresh completed successfully" - ) - except Exception as e: - verbose_proxy_logger.error( - f"Failed to refresh token synchronously: {e}" - ) - raise - else: - asyncio.run(self.recreate_prisma_client(new_db_url)) - - # Get the NEW attribute after reconnection - original_attr = getattr(self._original_prisma, name) + if running_loop is not None: + verbose_proxy_logger.warning( + "%sRDS IAM token expired in __getattr__ — proactive refresh " + "may have failed. Scheduling async refresh; the current " + "request may fail and be retried with the fresh token.", + self._log_prefix, + ) + # Non-blocking: schedule the locked refresh on the + # running loop. The reconnection lock inside + # `_safe_refresh_token` coalesces concurrent triggers. + running_loop.create_task(self._safe_refresh_token()) else: - raise ValueError("Failed to get RDS IAM token") + verbose_proxy_logger.warning( + "%sRDS IAM token expired in __getattr__ — proactive refresh " + "may have failed. Triggering synchronous fallback refresh...", + self._log_prefix, + ) + new_db_url = self.get_rds_iam_token() + if new_db_url: + asyncio.run(self.recreate_prisma_client(new_db_url)) + # Re-fetch attribute against the recreated Prisma instance. + original_attr = getattr(self._original_prisma, name) + verbose_proxy_logger.info( + "%sSynchronous token refresh completed successfully", + self._log_prefix, + ) + else: + raise ValueError("Failed to get RDS IAM token") return original_attr diff --git a/litellm/proxy/db/routing_prisma_wrapper.py b/litellm/proxy/db/routing_prisma_wrapper.py new file mode 100644 index 0000000000..0a976e9f1e --- /dev/null +++ b/litellm/proxy/db/routing_prisma_wrapper.py @@ -0,0 +1,213 @@ +""" +RoutingPrismaWrapper: routes Prisma reads to a read-replica client and writes +to a writer client. Used when DATABASE_URL_READ_REPLICA is configured; +otherwise PrismaClient uses the writer-only PrismaWrapper directly. +""" + +import os +from typing import Any, Callable, Optional + +from litellm._logging import verbose_proxy_logger +from litellm.proxy.db.prisma_client import PrismaWrapper + +# Per-model action methods that read from the database. These are routed to +# the read replica when one is configured. +_MODEL_READ_METHODS = frozenset( + { + "find_first", + "find_first_or_raise", + "find_many", + "find_unique", + "find_unique_or_raise", + "count", + "group_by", + "query_first", + "query_raw", + } +) + +# Top-level Prisma client methods that read from the database. +_TOP_LEVEL_READ_METHODS = frozenset({"query_first", "query_raw"}) + + +class _RoutedActions: + """Per-model accessor that sends reads to the reader and writes to the writer. + + `should_use_reader` is consulted on every read dispatch so a mid-call flip + of the routing wrapper's reader-availability flag (e.g. after the reader + fails a recreate) is observed without re-fetching the actions accessor. + """ + + __slots__ = ("_writer_actions", "_reader_actions", "_should_use_reader") + + def __init__( + self, + writer_actions: Any, + reader_actions: Any, + should_use_reader: Callable[[], bool], + ): + self._writer_actions = writer_actions + self._reader_actions = reader_actions + self._should_use_reader = should_use_reader + + def __getattr__(self, name: str) -> Any: + if name in _MODEL_READ_METHODS and self._should_use_reader(): + return getattr(self._reader_actions, name) + return getattr(self._writer_actions, name) + + +class RoutingPrismaWrapper: + """ + Routes Prisma operations between a writer and a reader Prisma client. + + Reads (find_*, count, group_by, query_raw, query_first) go to the reader; + everything else (writes, transactions, raw execute) goes to the writer. + Lifecycle methods (connect, disconnect, IAM token refresh) act on both + clients so callers do not need to know about the split. When + IAM_TOKEN_DB_AUTH is enabled, both writer and reader refresh their tokens + independently on their own ~12-minute cadence. + + Reader degradation: a reader-side failure (failed connect, failed + recreate) is non-fatal — the wrapper sets `_reader_unavailable=True`, logs + a warning, and routes subsequent reads to the writer. The next successful + `connect()` or `recreate_prisma_client()` clears the flag. This keeps the + proxy serving traffic during transient reader outages instead of failing + startup or returning errors for read-heavy endpoints. + """ + + def __init__(self, writer: PrismaWrapper, reader: PrismaWrapper): + self._writer = writer + self._reader = reader + # When True, reads fall back to the writer. Flipped on by reader + # connect/recreate failures and flipped off on the next reader recovery. + self._reader_unavailable: bool = False + + @property + def writer(self) -> PrismaWrapper: + return self._writer + + @property + def reader(self) -> PrismaWrapper: + return self._reader + + @property + def reader_unavailable(self) -> bool: + return self._reader_unavailable + + def _should_use_reader(self) -> bool: + return not self._reader_unavailable + + async def connect(self, *args: Any, **kwargs: Any) -> None: + await self._writer.connect(*args, **kwargs) + verbose_proxy_logger.info("[writer] DB connected") + try: + await self._reader.connect(*args, **kwargs) + self._reader_unavailable = False + verbose_proxy_logger.info("[reader] DB connected") + except Exception as e: + # Degrade gracefully: the proxy keeps serving traffic with reads + # routed to the writer until the reader endpoint is reachable. + # Aborting startup here would tie proxy availability to an + # opt-in, best-effort reader endpoint. + self._reader_unavailable = True + verbose_proxy_logger.warning( + "Failed to connect to read replica DB: %s. " + "Falling back to the writer for reads until the reader is reachable.", + e, + ) + + async def disconnect(self, *args: Any, **kwargs: Any) -> None: + first_error: Optional[BaseException] = None + for client in (self._writer, self._reader): + try: + await client.disconnect(*args, **kwargs) + except Exception as e: + if first_error is None: + first_error = e + verbose_proxy_logger.warning("Error disconnecting Prisma client: %s", e) + if first_error is not None: + raise first_error + + def is_connected(self) -> bool: + # Reflects writer health only. The reader is best-effort; its + # availability is tracked via `_reader_unavailable` and a degraded + # reader must NOT cause a writer reconnect (would loop indefinitely + # since recreate_prisma_client only fixes writer-side problems). + return bool(self._writer.is_connected()) + + async def start_token_refresh_task(self) -> None: + await self._writer.start_token_refresh_task() + await self._reader.start_token_refresh_task() + + async def stop_token_refresh_task(self) -> None: + await self._writer.stop_token_refresh_task() + await self._reader.stop_token_refresh_task() + + async def recreate_prisma_client( + self, new_db_url: str, http_client: Optional[Any] = None + ) -> None: + """Recreate both writer and reader Prisma clients. + + The writer reconnect path in PrismaClient calls + `self.db.recreate_prisma_client(...)`. Without this method, a DB-wide + connectivity event would only re-create the writer; the reader engine + would stay broken and every routed read would fail. We always recreate + the writer first (its URL is the one passed in), then best-effort + recreate the reader. A reader failure flips `_reader_unavailable=True` + so reads transparently fall through to the writer. + """ + await self._writer.recreate_prisma_client(new_db_url, http_client=http_client) + try: + await self._recreate_reader(http_client=http_client) + self._reader_unavailable = False + except Exception as e: + self._reader_unavailable = True + verbose_proxy_logger.warning( + "Failed to recreate reader Prisma client: %s. " + "Reads will fall back to the writer until the reader recovers.", + e, + ) + + async def _recreate_reader(self, http_client: Optional[Any] = None) -> None: + """Resolve the reader URL and recreate its Prisma client. + + IAM-enabled readers regenerate their token (host/port/user came from + the parsed reader URL at construction time). Non-IAM readers reuse + the URL stored in `DATABASE_URL_READ_REPLICA`. + """ + if self._reader.iam_token_db_auth: + new_reader_url = self._reader.get_rds_iam_token() + if not new_reader_url: + raise RuntimeError( + "Failed to generate fresh IAM token for read replica" + ) + await self._reader.recreate_prisma_client( + new_reader_url, http_client=http_client + ) + return + reader_url = os.getenv("DATABASE_URL_READ_REPLICA", "") + if not reader_url: + raise RuntimeError( + "DATABASE_URL_READ_REPLICA not set; cannot recreate read replica client" + ) + await self._reader.recreate_prisma_client(reader_url, http_client=http_client) + + def __getattr__(self, name: str) -> Any: + if name in _TOP_LEVEL_READ_METHODS: + target = self._writer if self._reader_unavailable else self._reader + return getattr(target, name) + writer_attr = getattr(self._writer, name) + # Per-model action accessors are non-callable instances that expose + # both `find_many` and `create`. Methods like execute_raw / batch_ / + # tx are callables and stay on the writer untouched. + if ( + not callable(writer_attr) + and hasattr(writer_attr, "find_many") + and hasattr(writer_attr, "create") + ): + try: + reader_attr = getattr(self._reader, name) + except AttributeError: + return writer_attr + return _RoutedActions(writer_attr, reader_attr, self._should_use_reader) + return writer_attr diff --git a/litellm/proxy/proxy_cli.py b/litellm/proxy/proxy_cli.py index 06fc0819a7..6359b48654 100644 --- a/litellm/proxy/proxy_cli.py +++ b/litellm/proxy/proxy_cli.py @@ -813,7 +813,12 @@ def run_server( # noqa: PLR0915 from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token db_host = os.getenv("DATABASE_HOST") - db_port = os.getenv("DATABASE_PORT") + # Default to the Postgres standard port. Without a default, + # `db_port=None` flows into `boto.generate_db_auth_token(Port=None)` + # and botocore stringifies it to `"None"` while building the + # presigned URL, which then blows up with `ValueError: Port could + # not be cast to integer value as 'None'` during signing. + db_port = os.getenv("DATABASE_PORT", "5432") db_user = os.getenv("DATABASE_USER") db_name = os.getenv("DATABASE_NAME") db_schema = os.getenv("DATABASE_SCHEMA") diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index a52dc8e55f..0577110a26 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -113,7 +113,11 @@ from litellm.proxy.db.exception_handler import ( call_with_db_reconnect_retry, ) from litellm.proxy.db.log_db_metrics import log_db_metrics -from litellm.proxy.db.prisma_client import PrismaWrapper +from litellm.proxy.db.prisma_client import ( + PrismaWrapper, + parse_iam_endpoint_from_url, +) +from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import ( UnifiedLLMGuardrails, ) @@ -2569,24 +2573,101 @@ class PrismaClient: raise Exception( "Unable to find Prisma binaries. Please run 'prisma generate' first." ) + iam_flag = ( + self.iam_token_db_auth if self.iam_token_db_auth is not None else False + ) + # When read-replica routing is on, tag log lines with [writer]/[reader] + # so the two wrappers' interleaved IAM refresh logs can be told apart. + # Single-DB deployments get an empty prefix (logs unchanged). + read_replica_url = os.getenv("DATABASE_URL_READ_REPLICA") + writer_log_prefix = "[writer]" if read_replica_url else "" if http_client is not None: - self.db = PrismaWrapper( + writer_wrapper = PrismaWrapper( original_prisma=Prisma(http=http_client), - iam_token_db_auth=( - self.iam_token_db_auth - if self.iam_token_db_auth is not None - else False - ), + iam_token_db_auth=iam_flag, + log_prefix=writer_log_prefix, ) else: - self.db = PrismaWrapper( + writer_wrapper = PrismaWrapper( original_prisma=Prisma(), - iam_token_db_auth=( - self.iam_token_db_auth - if self.iam_token_db_auth is not None - else False - ), - ) # Client to connect to Prisma db + iam_token_db_auth=iam_flag, + log_prefix=writer_log_prefix, + ) + + # Optional read-replica routing. When DATABASE_URL_READ_REPLICA is set, + # reads (find_*, count, group_by, query_raw/_first) are routed to the + # reader endpoint and writes stay on the writer. Falls back to the + # writer-only wrapper when the env var is unset, preserving existing + # single-DB deployments. + self.db: Union[PrismaWrapper, RoutingPrismaWrapper] + if read_replica_url: + try: + # If IAM auth is enabled, the reader refreshes its own token on + # the same cadence as the writer. We parse the static endpoint + # pieces (host/port/user/db) once from the reader URL — only + # the IAM token rotates after that. + reader_iam_endpoint = ( + parse_iam_endpoint_from_url(read_replica_url) if iam_flag else None + ) + # Mint a fresh IAM token for the reader BEFORE constructing the + # Prisma client. Mirrors what `proxy_cli.py` already does for + # the writer (proxy_cli.py:812-832) — without this, the reader + # Prisma is built with whatever placeholder URL the user + # supplied (no real token), and the first query falls through + # to the synchronous fallback path in + # `PrismaWrapper.__getattr__`, which deadlocks the event loop + # and times out after 30s. + if iam_flag and reader_iam_endpoint is not None: + from litellm.proxy.auth.rds_iam_token import ( + generate_iam_auth_token, + ) + + reader_token = generate_iam_auth_token( + db_host=reader_iam_endpoint.host, + db_port=reader_iam_endpoint.port, + db_user=reader_iam_endpoint.user, + ) + read_replica_url = reader_iam_endpoint.build_url(reader_token) + os.environ["DATABASE_URL_READ_REPLICA"] = read_replica_url + reader_kwargs: Dict[str, Any] = { + "datasource": {"url": read_replica_url} + } + if http_client is not None: + reader_prisma = Prisma(http=http_client, **reader_kwargs) + else: + reader_prisma = Prisma(**reader_kwargs) + reader_wrapper = PrismaWrapper( + original_prisma=reader_prisma, + iam_token_db_auth=iam_flag, + db_url_env_var="DATABASE_URL_READ_REPLICA", + iam_endpoint=reader_iam_endpoint, + recreate_uses_datasource=True, + log_prefix="[reader]", + ) + self.db = RoutingPrismaWrapper( + writer=writer_wrapper, reader=reader_wrapper + ) + verbose_proxy_logger.info( + "PrismaClient: read-replica routing enabled via DATABASE_URL_READ_REPLICA" + + (" (with IAM token auto-refresh)" if iam_flag else "") + ) + except Exception as e: + # Reader is opt-in; never let its construction fail proxy + # startup. Mirrors the runtime contract from + # `RoutingPrismaWrapper.connect`: reader-side failures are + # logged and we keep serving traffic via the writer alone. + # This recovers from transient AWS STS hiccups during the + # reader IAM token mint, malformed DATABASE_URL_READ_REPLICA, + # and Prisma construction errors. Operator restart is required + # to retry read-routing once the underlying issue is resolved. + verbose_proxy_logger.warning( + "Failed to initialize read replica Prisma client: %s. " + "Falling back to writer-only mode (no read routing) until proxy restart.", + e, + ) + self.db = writer_wrapper + else: + self.db = writer_wrapper # Client to connect to Prisma db self._db_reconnect_lock = asyncio.Lock() self._db_health_watchdog_task: Optional[asyncio.Task] = None self._db_last_reconnect_attempt_ts: float = 0.0 @@ -2624,6 +2705,13 @@ class PrismaClient: self._engine_wait_thread: Optional[threading.Thread] = None verbose_proxy_logger.debug("Success - Created Prisma Client") + @property + def writer_db(self) -> PrismaWrapper: + """Underlying writer Prisma wrapper, regardless of read-replica routing.""" + if isinstance(self.db, RoutingPrismaWrapper): + return self.db.writer + return self.db + def get_request_status( self, payload: Union[dict, SpendLogsPayload] ) -> Literal["success", "failure"]: @@ -4272,7 +4360,10 @@ class PrismaClient: self._cleanup_engine_watcher() await self.db.recreate_prisma_client(db_url) await self._start_engine_watcher() - await self.db.query_raw("SELECT 1") + # Smoke-test the writer specifically; query_raw on the routing + # wrapper sends to the reader, which would not validate the + # newly-recreated writer engine. + await self.writer_db.query_raw("SELECT 1") await asyncio.wait_for(_do_direct_reconnect(), timeout=effective_timeout) diff --git a/tests/test_litellm/proxy/db/test_routing_prisma_wrapper.py b/tests/test_litellm/proxy/db/test_routing_prisma_wrapper.py new file mode 100644 index 0000000000..8c3a2b9e2d --- /dev/null +++ b/tests/test_litellm/proxy/db/test_routing_prisma_wrapper.py @@ -0,0 +1,887 @@ +import asyncio +import logging +import os +import sys +from typing import Any, Dict +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +sys.path.insert(0, os.path.abspath("../../../..")) + + +# NOTE: do NOT patch sys.modules["prisma"] file-wide via an autouse fixture. +# Doing so leaks across pytest-xdist test scheduling: when a worker runs a +# routing test, then later runs test_exception_handler.py, the cached MagicMock +# attribute references break `isinstance(e, prisma.errors.X)` in +# `is_database_transport_error`. The two tests below that actually need to +# stub the prisma SDK do so per-test via monkeypatch, which is properly scoped. + + +def _make_wrappers(): + from litellm.proxy.db.prisma_client import PrismaWrapper + + writer_inner = MagicMock(name="writer_prisma") + reader_inner = MagicMock(name="reader_prisma") + writer = PrismaWrapper(original_prisma=writer_inner, iam_token_db_auth=False) + reader = PrismaWrapper(original_prisma=reader_inner, iam_token_db_auth=False) + return writer, writer_inner, reader, reader_inner + + +class _FakeActions: + """Stand-in for a Prisma per-model Actions instance (non-callable, has find_many/create).""" + + def __init__(self, name: str): + self._name = name + for method in ( + "find_many", + "find_unique", + "find_first", + "count", + "group_by", + "create", + "update", + "upsert", + "delete", + "delete_many", + "update_many", + ): + setattr(self, method, MagicMock(name=f"{name}.{method}")) + + +def _model_actions_mock(name: str) -> _FakeActions: + return _FakeActions(name) + + +def test_top_level_query_raw_routes_to_reader(): + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer, writer_inner, reader, reader_inner = _make_wrappers() + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + + # query_raw should resolve to the reader's underlying client. + assert routing.query_raw is reader_inner.query_raw + assert routing.query_first is reader_inner.query_first + + +def test_top_level_execute_raw_routes_to_writer(): + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer, writer_inner, reader, reader_inner = _make_wrappers() + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + + # execute_raw, batch_, tx are write-side and must hit the writer. + assert routing.execute_raw is writer_inner.execute_raw + assert routing.batch_ is writer_inner.batch_ + assert routing.tx is writer_inner.tx + + +def test_per_model_reads_route_to_reader_writes_to_writer(): + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer, writer_inner, reader, reader_inner = _make_wrappers() + writer_inner.litellm_usertable = _model_actions_mock("writer_users") + reader_inner.litellm_usertable = _model_actions_mock("reader_users") + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + + actions = routing.litellm_usertable + + # Reads → reader actions. + assert actions.find_many is reader_inner.litellm_usertable.find_many + assert actions.find_unique is reader_inner.litellm_usertable.find_unique + assert actions.find_first is reader_inner.litellm_usertable.find_first + assert actions.count is reader_inner.litellm_usertable.count + assert actions.group_by is reader_inner.litellm_usertable.group_by + + # Writes → writer actions. + assert actions.create is writer_inner.litellm_usertable.create + assert actions.update is writer_inner.litellm_usertable.update + assert actions.upsert is writer_inner.litellm_usertable.upsert + assert actions.delete is writer_inner.litellm_usertable.delete + assert actions.update_many is writer_inner.litellm_usertable.update_many + assert actions.delete_many is writer_inner.litellm_usertable.delete_many + + +@pytest.mark.asyncio +async def test_connect_invokes_both_clients(): + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer, writer_inner, reader, reader_inner = _make_wrappers() + writer_inner.connect = AsyncMock() + reader_inner.connect = AsyncMock() + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + + await routing.connect() + + writer_inner.connect.assert_awaited_once() + reader_inner.connect.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_connect_logs_writer_and_reader_success(caplog): + """Successful startup emits a positive INFO confirmation for both writer + and reader so operators can verify connectivity without inspecting the URL + in logs.""" + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer, writer_inner, reader, reader_inner = _make_wrappers() + writer_inner.connect = AsyncMock() + reader_inner.connect = AsyncMock() + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + + with caplog.at_level(logging.INFO, logger="LiteLLM Proxy"): + await routing.connect() + + messages = [r.getMessage() for r in caplog.records] + assert "[writer] DB connected" in messages + assert "[reader] DB connected" in messages + + +@pytest.mark.asyncio +async def test_disconnect_continues_when_one_side_fails(): + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer, writer_inner, reader, reader_inner = _make_wrappers() + writer_inner.disconnect = AsyncMock(side_effect=RuntimeError("writer down")) + reader_inner.disconnect = AsyncMock() + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + + with pytest.raises(RuntimeError, match="writer down"): + await routing.disconnect() + + # Reader still attempted even though writer raised. + reader_inner.disconnect.assert_awaited_once() + + +def test_is_connected_reflects_writer_only(): + """is_connected() must NOT depend on reader health — a healthy writer with + a degraded reader should report True so that PrismaClient.connect()'s + health check does not re-trigger a writer reconnect (which only fixes + writer-side problems and would loop indefinitely).""" + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer, writer_inner, reader, reader_inner = _make_wrappers() + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + + writer_inner.is_connected = MagicMock(return_value=True) + reader_inner.is_connected = MagicMock(return_value=True) + assert routing.is_connected() is True + + # Reader down → still True (reader degradation is tracked separately). + reader_inner.is_connected = MagicMock(return_value=False) + assert routing.is_connected() is True + + # Writer down → False. + writer_inner.is_connected = MagicMock(return_value=False) + assert routing.is_connected() is False + + +def test_token_refresh_delegates_to_both_writer_and_reader(): + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer = MagicMock() + writer.start_token_refresh_task = AsyncMock() + writer.stop_token_refresh_task = AsyncMock() + reader = MagicMock() + reader.start_token_refresh_task = AsyncMock() + reader.stop_token_refresh_task = AsyncMock() + + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + + asyncio.run(routing.start_token_refresh_task()) + asyncio.run(routing.stop_token_refresh_task()) + + # Both wrappers get start/stop — each manages its own IAM token. When + # IAM is disabled on a wrapper its task body is a no-op. + writer.start_token_refresh_task.assert_awaited_once() + writer.stop_token_refresh_task.assert_awaited_once() + reader.start_token_refresh_task.assert_awaited_once() + reader.stop_token_refresh_task.assert_awaited_once() + + +def test_routed_actions_falls_back_to_writer_for_unknown_methods(): + from litellm.proxy.db.routing_prisma_wrapper import _RoutedActions + + writer_actions = _model_actions_mock("writer") + writer_actions.some_custom_method = "writer-custom" + reader_actions = _model_actions_mock("reader") + reader_actions.some_custom_method = "reader-custom" + + routed = _RoutedActions(writer_actions, reader_actions, lambda: True) + # Unknown method → defaults to writer (safe fallback for write-like ops). + assert routed.some_custom_method == "writer-custom" + + +def test_routed_actions_respects_should_use_reader_flag(): + """When the routing wrapper marks the reader unavailable, _RoutedActions + must redirect reads to the writer instead — without needing to re-fetch + the actions accessor.""" + from litellm.proxy.db.routing_prisma_wrapper import _RoutedActions + + writer_actions = _model_actions_mock("writer") + reader_actions = _model_actions_mock("reader") + + use_reader = {"value": True} + routed = _RoutedActions(writer_actions, reader_actions, lambda: use_reader["value"]) + + # Reader healthy → reads to reader. + assert routed.find_many is reader_actions.find_many + + # Reader degrades mid-flight → next read goes to writer. + use_reader["value"] = False + assert routed.find_many is writer_actions.find_many + + +# --------------------------------------------------------------------------- +# Reader graceful degradation +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_connect_swallows_reader_failure_and_falls_back_to_writer(): + """A reader connect failure must NOT abort proxy startup. The wrapper + flips into degraded mode so subsequent reads route to the writer.""" + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer, writer_inner, reader, reader_inner = _make_wrappers() + writer_inner.connect = AsyncMock() + reader_inner.connect = AsyncMock(side_effect=RuntimeError("reader unreachable")) + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + + # Must not raise — reader failure is non-fatal. + await routing.connect() + + assert routing.reader_unavailable is True + writer_inner.connect.assert_awaited_once() + reader_inner.connect.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_reads_route_to_writer_when_reader_unavailable(): + """Top-level read methods and per-model reads must fall through to the + writer while the reader is degraded.""" + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer, writer_inner, reader, reader_inner = _make_wrappers() + writer_inner.litellm_usertable = _model_actions_mock("writer_users") + reader_inner.litellm_usertable = _model_actions_mock("reader_users") + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + routing._reader_unavailable = True + + # Top-level reads → writer. + assert routing.query_raw is writer_inner.query_raw + assert routing.query_first is writer_inner.query_first + + # Per-model reads → writer actions. + actions = routing.litellm_usertable + assert actions.find_many is writer_inner.litellm_usertable.find_many + assert actions.find_unique is writer_inner.litellm_usertable.find_unique + + +@pytest.mark.asyncio +async def test_recreate_prisma_client_recreates_both_writer_and_reader(): + """Writer reconnect path calls recreate_prisma_client. The routing wrapper + must recreate BOTH clients so a DB-wide event doesn't leave a stale reader.""" + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer = MagicMock() + writer.recreate_prisma_client = AsyncMock() + reader = MagicMock() + reader.iam_token_db_auth = False + reader.recreate_prisma_client = AsyncMock() + + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + + with patch.dict(os.environ, {"DATABASE_URL_READ_REPLICA": "reader-url"}): + await routing.recreate_prisma_client("writer-url", http_client=None) + + writer.recreate_prisma_client.assert_awaited_once_with( + "writer-url", http_client=None + ) + reader.recreate_prisma_client.assert_awaited_once_with( + "reader-url", http_client=None + ) + assert routing.reader_unavailable is False + + +@pytest.mark.asyncio +async def test_recreate_recovers_reader_after_prior_degradation(): + """If a previous connect/recreate degraded the reader, a successful + recreate must clear the flag so reads start hitting the reader again.""" + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer = MagicMock() + writer.recreate_prisma_client = AsyncMock() + reader = MagicMock() + reader.iam_token_db_auth = False + reader.recreate_prisma_client = AsyncMock() + + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + routing._reader_unavailable = True + + with patch.dict(os.environ, {"DATABASE_URL_READ_REPLICA": "reader-url"}): + await routing.recreate_prisma_client("writer-url") + + assert routing.reader_unavailable is False + + +@pytest.mark.asyncio +async def test_recreate_degrades_reader_if_reader_recreate_fails(): + """If the reader recreate fails, writer recreate still succeeds and the + routing wrapper degrades (does not raise).""" + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer = MagicMock() + writer.recreate_prisma_client = AsyncMock() + reader = MagicMock() + reader.iam_token_db_auth = False + reader.recreate_prisma_client = AsyncMock( + side_effect=RuntimeError("reader still down") + ) + + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + + with patch.dict(os.environ, {"DATABASE_URL_READ_REPLICA": "reader-url"}): + # Must not raise — writer was recreated, reader is best-effort. + await routing.recreate_prisma_client("writer-url") + + writer.recreate_prisma_client.assert_awaited_once() + assert routing.reader_unavailable is True + + +@pytest.mark.asyncio +async def test_recreate_degrades_reader_when_replica_url_missing(): + """Non-IAM reader needs DATABASE_URL_READ_REPLICA. If it's missing + (configuration drift), the wrapper degrades instead of raising.""" + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer = MagicMock() + writer.recreate_prisma_client = AsyncMock() + reader = MagicMock() + reader.iam_token_db_auth = False + reader.recreate_prisma_client = AsyncMock() + + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + + # Ensure env var is absent. + with patch.dict(os.environ, {}, clear=False): + os.environ.pop("DATABASE_URL_READ_REPLICA", None) + await routing.recreate_prisma_client("writer-url") + + writer.recreate_prisma_client.assert_awaited_once() + reader.recreate_prisma_client.assert_not_awaited() + assert routing.reader_unavailable is True + + +@pytest.mark.asyncio +async def test_recreate_iam_reader_refreshes_token(): + """IAM-enabled readers must refresh their token (reader has its own parsed + endpoint) and pass the fresh URL to recreate_prisma_client.""" + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer = MagicMock() + writer.recreate_prisma_client = AsyncMock() + reader = MagicMock() + reader.iam_token_db_auth = True + reader.get_rds_iam_token = MagicMock(return_value="postgresql://u:fresh@h:5432/db") + reader.recreate_prisma_client = AsyncMock() + + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + await routing.recreate_prisma_client("writer-url") + + reader.get_rds_iam_token.assert_called_once() + reader.recreate_prisma_client.assert_awaited_once_with( + "postgresql://u:fresh@h:5432/db", http_client=None + ) + assert routing.reader_unavailable is False + + +@pytest.mark.asyncio +async def test_recreate_degrades_when_iam_token_generation_returns_none(): + """If `get_rds_iam_token` returns None (e.g. AWS-side failure), the wrapper + must degrade rather than crash — this exercises the explicit `raise + RuntimeError` inside `_recreate_reader`'s IAM branch.""" + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer = MagicMock() + writer.recreate_prisma_client = AsyncMock() + reader = MagicMock() + reader.iam_token_db_auth = True + reader.get_rds_iam_token = MagicMock(return_value=None) + reader.recreate_prisma_client = AsyncMock() + + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + await routing.recreate_prisma_client("writer-url") + + writer.recreate_prisma_client.assert_awaited_once() + reader.recreate_prisma_client.assert_not_awaited() + assert routing.reader_unavailable is True + + +def test_writer_and_reader_properties_expose_underlying_wrappers(): + """The `writer` and `reader` properties are used by PrismaClient.writer_db + to smoke-test the writer specifically during reconnect — they must return + the exact wrappers passed in.""" + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + writer, _, reader, _ = _make_wrappers() + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + + assert routing.writer is writer + assert routing.reader is reader + + +def test_per_model_accessor_falls_back_when_reader_lacks_attr(): + """If the reader Prisma client somehow lacks a model accessor that the + writer has (older client / partial mock), the wrapper must fall back to + the writer accessor instead of raising AttributeError to the caller.""" + from litellm.proxy.db.prisma_client import PrismaWrapper + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + # Plain class with only the accessor set on the writer side. Using a real + # class instead of MagicMock so attribute access raises AttributeError + # naturally instead of auto-creating mock attributes. + class _PartialPrisma: + pass + + writer_inner = _PartialPrisma() + writer_inner.litellm_usertable = _model_actions_mock("writer_users") + reader_inner = _PartialPrisma() # deliberately missing litellm_usertable + + writer = PrismaWrapper(original_prisma=writer_inner, iam_token_db_auth=False) + reader = PrismaWrapper(original_prisma=reader_inner, iam_token_db_auth=False) + routing = RoutingPrismaWrapper(writer=writer, reader=reader) + + actions = routing.litellm_usertable + # Falls back to the writer's accessor verbatim — not a _RoutedActions wrapper. + assert actions is writer_inner.litellm_usertable + + +@pytest.mark.asyncio +async def test_writer_recreate_passes_http_client_through(monkeypatch): + """When PrismaClient is constructed with an http_client, recreate must + forward it to the new Prisma() so connection settings persist across + reconnects.""" + from litellm.proxy.db.prisma_client import PrismaWrapper + + captured_kwargs: Dict[str, Any] = {} + + class FakePrisma: + def __init__(self, **kwargs): + captured_kwargs.update(kwargs) + + async def connect(self): + return None + + fake_module = MagicMock() + fake_module.Prisma = FakePrisma + monkeypatch.setitem(sys.modules, "prisma", fake_module) + + writer = PrismaWrapper(original_prisma=MagicMock(), iam_token_db_auth=False) + sentinel_http = object() + await writer.recreate_prisma_client( + "postgresql://u:p@h:5432/db", http_client=sentinel_http + ) + + assert captured_kwargs == {"http": sentinel_http} + + +# --------------------------------------------------------------------------- +# IAM endpoint parsing + reader IAM refresh +# --------------------------------------------------------------------------- + + +def test_parse_iam_endpoint_from_url_extracts_all_fields(): + from litellm.proxy.db.prisma_client import parse_iam_endpoint_from_url + + ep = parse_iam_endpoint_from_url( + "postgresql://litellm_user:initial-token@aurora-reader.example.com:6543/litellm?schema=public" + ) + assert ep.host == "aurora-reader.example.com" + assert ep.port == "6543" + assert ep.user == "litellm_user" + assert ep.name == "litellm" + assert ep.schema == "public" + + +def test_parse_iam_endpoint_defaults_port_to_5432_and_skips_schema(): + from litellm.proxy.db.prisma_client import parse_iam_endpoint_from_url + + ep = parse_iam_endpoint_from_url("postgresql://u@host/dbname") + assert ep.host == "host" + assert ep.port == "5432" + assert ep.user == "u" + assert ep.name == "dbname" + assert ep.schema is None + + +def test_parse_iam_endpoint_rejects_url_without_user_or_dbname(): + from litellm.proxy.db.prisma_client import parse_iam_endpoint_from_url + + with pytest.raises(ValueError, match="missing host or username"): + parse_iam_endpoint_from_url("postgresql://host:5432/db") + with pytest.raises(ValueError, match="missing database name"): + parse_iam_endpoint_from_url("postgresql://u@host:5432/") + + +def test_iam_endpoint_build_url_inserts_token_verbatim(): + from litellm.proxy.db.prisma_client import IAMEndpoint + + # `generate_iam_auth_token` already URL-encodes the presigned token, so + # `build_url` must NOT encode again — double-encoding turned `%3D` into + # `%253D` and broke RDS auth on the reader path. + ep = IAMEndpoint(host="h", port="5432", user="u", name="db", schema="public") + pre_encoded_token = "token%2Fwith%3Fweird%26chars%3Dyes" + url = ep.build_url(pre_encoded_token) + assert url == f"postgresql://u:{pre_encoded_token}@h:5432/db?schema=public" + # Sanity check: no `%25` (the encoding of `%`), confirming we didn't re-encode. + assert "%25" not in url + + +@pytest.mark.asyncio +async def test_iam_refresh_logs_carry_log_prefix(caplog): + """When `log_prefix` is set on a PrismaWrapper, every IAM-related log + line emitted by that wrapper must start with the prefix so writer and + reader can be told apart in interleaved output.""" + from litellm.proxy.db.prisma_client import PrismaWrapper + + wrapper = PrismaWrapper( + original_prisma=MagicMock(), + iam_token_db_auth=True, + log_prefix="[reader]", + ) + + with caplog.at_level(logging.INFO, logger="LiteLLM Proxy"): + await wrapper.start_token_refresh_task() + # Loop emits "RDS IAM token refresh loop started..." on first tick. + # Cancel immediately so the loop body runs once and we can assert. + await wrapper.stop_token_refresh_task() + + messages = [r.getMessage() for r in caplog.records] + # Both start and stop notifications carry the prefix. + assert any( + m.startswith("[reader] Started RDS IAM token proactive refresh") + for m in messages + ) + assert any( + m.startswith("[reader] Stopped RDS IAM token refresh background task") + for m in messages + ) + + +def test_get_rds_iam_token_returns_none_when_iam_disabled(): + """`get_rds_iam_token` short-circuits to None when iam_token_db_auth is + False — covers the early-return guard at the top of the method.""" + from litellm.proxy.db.prisma_client import PrismaWrapper + + wrapper = PrismaWrapper(original_prisma=MagicMock(), iam_token_db_auth=False) + assert wrapper.get_rds_iam_token() is None + + +@pytest.mark.asyncio +async def test_getattr_does_not_block_inside_running_loop_on_expired_token(monkeypatch): + """When `__getattr__` runs inside a running event loop and the IAM token + is expired, it MUST schedule the refresh as a background task and return + immediately. The previous `run_coroutine_threadsafe` + `future.result()` + pattern deadlocks the loop (loop thread blocks waiting for a coroutine + that needs the loop to run) and times out at 30s — exactly what was + breaking the reader on first query.""" + from litellm.proxy.db.prisma_client import PrismaWrapper + + # Stale URL — `is_token_expired` returns True because the password isn't + # a parseable IAM token, so we exercise the expired branch. + monkeypatch.setenv( + "DATABASE_URL_READ_REPLICA", + "postgresql://reader:placeholder@reader.aurora.local:5432/litellm", + ) + + inner = MagicMock() + inner.query_raw = MagicMock(name="query_raw_attr") + + wrapper = PrismaWrapper( + original_prisma=inner, + iam_token_db_auth=True, + db_url_env_var="DATABASE_URL_READ_REPLICA", + ) + + # Replace the heavy refresh coroutine with a no-op AsyncMock so we can + # observe whether it was scheduled without actually doing the recreate. + refresh_calls = {"count": 0} + + async def fake_refresh(): + refresh_calls["count"] += 1 + + monkeypatch.setattr(wrapper, "_safe_refresh_token", fake_refresh) + + # Direct attribute access from inside this async test runs __getattr__ + # on the loop thread, exercising the in-loop branch. If the previous + # `run_coroutine_threadsafe` + `future.result()` pattern were back, this + # line would deadlock the loop and the test would hang (and pytest's + # per-test timeout would catch it). + attr = wrapper.query_raw + # Yield once so the scheduled refresh task gets a chance to run. + await asyncio.sleep(0) + + assert attr is inner.query_raw + assert refresh_calls["count"] == 1 + + +def test_writer_get_rds_iam_token_defaults_port_when_unset(monkeypatch): + """When DATABASE_PORT is unset, the writer must default to the Postgres + standard port instead of passing `None` through. Passing None to + `generate_iam_auth_token` makes botocore embed the literal string + \"None\" in the presigned URL during signing and crashes with + `ValueError: Port could not be cast to integer value as 'None'`.""" + from litellm.proxy.db.prisma_client import PrismaWrapper + + monkeypatch.setenv("DATABASE_HOST", "writer.aurora.local") + monkeypatch.delenv("DATABASE_PORT", raising=False) + monkeypatch.setenv("DATABASE_USER", "litellm") + monkeypatch.setenv("DATABASE_NAME", "litellm") + monkeypatch.delenv("DATABASE_SCHEMA", raising=False) + monkeypatch.delenv("DATABASE_URL", raising=False) + + captured: Dict[str, Any] = {} + + def fake_generate(db_host=None, db_port=None, db_user=None): + captured["port"] = db_port + return "TOKEN" + + fake_module = MagicMock() + fake_module.generate_iam_auth_token = fake_generate + monkeypatch.setitem(sys.modules, "litellm.proxy.auth.rds_iam_token", fake_module) + + writer = PrismaWrapper( + original_prisma=MagicMock(), + iam_token_db_auth=True, + ) + new_url = writer.get_rds_iam_token() + + assert captured["port"] == "5432" # default applied, NOT None + assert ":5432/litellm" in (new_url or "") + + +def test_writer_get_rds_iam_token_uses_database_host_env_vars(monkeypatch): + """Writer's IAM path (no iam_endpoint configured) reads host/port/user/db + from the legacy DATABASE_HOST/PORT/USER/NAME env vars and writes the URL + back to DATABASE_URL — this is the pre-read-replica behavior the patch + must preserve.""" + from litellm.proxy.db.prisma_client import PrismaWrapper + + monkeypatch.setenv("DATABASE_HOST", "writer.aurora.local") + monkeypatch.setenv("DATABASE_PORT", "5432") + monkeypatch.setenv("DATABASE_USER", "litellm") + monkeypatch.setenv("DATABASE_NAME", "litellm") + monkeypatch.setenv("DATABASE_SCHEMA", "public") + monkeypatch.delenv("DATABASE_URL", raising=False) + + captured: Dict[str, Any] = {} + + def fake_generate(db_host=None, db_port=None, db_user=None): + captured["host"] = db_host + captured["port"] = db_port + captured["user"] = db_user + return "WRITER-TOKEN" + + fake_module = MagicMock() + fake_module.generate_iam_auth_token = fake_generate + monkeypatch.setitem(sys.modules, "litellm.proxy.auth.rds_iam_token", fake_module) + + writer = PrismaWrapper( + original_prisma=MagicMock(), + iam_token_db_auth=True, + # No iam_endpoint → legacy DATABASE_HOST/etc. path. + ) + new_url = writer.get_rds_iam_token() + + assert captured == { + "host": "writer.aurora.local", + "port": "5432", + "user": "litellm", + } + assert new_url == ( + "postgresql://litellm:WRITER-TOKEN@writer.aurora.local:5432/litellm?schema=public" + ) + # Writer updates its own env var (DATABASE_URL by default), not the reader's. + assert os.environ["DATABASE_URL"] == new_url + + +def test_reader_iam_refresh_uses_parsed_endpoint(monkeypatch): + """The reader generates fresh tokens against its parsed endpoint and + writes the new URL to DATABASE_URL_READ_REPLICA — not DATABASE_URL.""" + from litellm.proxy.db.prisma_client import IAMEndpoint, PrismaWrapper + + # Pre-seed env vars so we can prove the reader does NOT touch DATABASE_URL. + monkeypatch.setenv("DATABASE_URL", "writer-url-untouched") + monkeypatch.setenv("DATABASE_URL_READ_REPLICA", "stale-reader-url") + + captured: Dict[str, Any] = {} + + def fake_generate(db_host=None, db_port=None, db_user=None): + captured["host"] = db_host + captured["port"] = db_port + captured["user"] = db_user + return "FRESH-TOKEN" + + fake_module = MagicMock() + fake_module.generate_iam_auth_token = fake_generate + monkeypatch.setitem(sys.modules, "litellm.proxy.auth.rds_iam_token", fake_module) + + endpoint = IAMEndpoint( + host="reader.aurora.local", + port="5432", + user="lit", + name="litellm", + schema=None, + ) + reader = PrismaWrapper( + original_prisma=MagicMock(), + iam_token_db_auth=True, + db_url_env_var="DATABASE_URL_READ_REPLICA", + iam_endpoint=endpoint, + recreate_uses_datasource=True, + ) + + new_url = reader.get_rds_iam_token() + + # IAM token generator was called with the reader's parsed endpoint, not + # the writer's DATABASE_HOST/PORT/USER env vars. + assert captured == { + "host": "reader.aurora.local", + "port": "5432", + "user": "lit", + } + assert new_url is not None + assert new_url.startswith( + "postgresql://lit:FRESH-TOKEN@reader.aurora.local:5432/litellm" + ) + # The reader updates its OWN env var; writer's DATABASE_URL is left alone. + assert os.environ["DATABASE_URL_READ_REPLICA"] == new_url + assert os.environ["DATABASE_URL"] == "writer-url-untouched" + + +@pytest.mark.asyncio +async def test_reader_recreate_uses_datasource_override(monkeypatch): + """Reader recreate must pass `datasource={"url": ...}` to Prisma() — Prisma + only auto-reads DATABASE_URL, so without the override the new reader URL + would be silently ignored.""" + from litellm.proxy.db.prisma_client import IAMEndpoint, PrismaWrapper + + captured_kwargs: Dict[str, Any] = {} + + class FakePrisma: + def __init__(self, **kwargs): + captured_kwargs.update(kwargs) + + async def connect(self): + return None + + fake_module = MagicMock() + fake_module.Prisma = FakePrisma + monkeypatch.setitem(sys.modules, "prisma", fake_module) + + reader = PrismaWrapper( + original_prisma=MagicMock(), + iam_token_db_auth=True, + db_url_env_var="DATABASE_URL_READ_REPLICA", + iam_endpoint=IAMEndpoint(host="h", port="5432", user="u", name="db"), + recreate_uses_datasource=True, + ) + + await reader.recreate_prisma_client( + "postgresql://u:newtoken@h:5432/db", http_client=None + ) + + assert captured_kwargs == { + "datasource": {"url": "postgresql://u:newtoken@h:5432/db"} + } + + +@pytest.mark.asyncio +async def test_writer_recreate_does_not_use_datasource(monkeypatch): + """Writer keeps relying on Prisma reading DATABASE_URL from env — datasource + override must NOT leak into the writer path (would override the freshly + rotated env var).""" + from litellm.proxy.db.prisma_client import PrismaWrapper + + captured_kwargs: Dict[str, Any] = {} + + class FakePrisma: + def __init__(self, **kwargs): + captured_kwargs.update(kwargs) + + async def connect(self): + return None + + fake_module = MagicMock() + fake_module.Prisma = FakePrisma + monkeypatch.setitem(sys.modules, "prisma", fake_module) + + writer = PrismaWrapper( + original_prisma=MagicMock(), + iam_token_db_auth=True, + ) + + await writer.recreate_prisma_client( + "postgresql://u:newtoken@h:5432/db", http_client=None + ) + + assert "datasource" not in captured_kwargs + + +def test_prisma_client_init_falls_back_to_writer_when_reader_iam_token_fails( + monkeypatch, caplog +): + """A transient AWS STS error (or any other failure) during the reader + IAM token mint must NOT abort proxy startup. The reader is opt-in, so + `PrismaClient.__init__` should log a warning and fall back to the + writer-only `PrismaWrapper`. The runtime contract in + `RoutingPrismaWrapper.connect` already says reader-side failures are + non-fatal — but that code never runs if construction throws first.""" + from litellm.proxy.db.prisma_client import PrismaWrapper + from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper + + monkeypatch.setenv("IAM_TOKEN_DB_AUTH", "true") + monkeypatch.setenv( + "DATABASE_URL_READ_REPLICA", + "postgresql://reader_user@reader.aurora.local:5432/litellm", + ) + + class FakePrisma: + def __init__(self, **kwargs): + self.kwargs = kwargs + + async def connect(self): + return None + + fake_prisma_module = MagicMock() + fake_prisma_module.Prisma = FakePrisma + monkeypatch.setitem(sys.modules, "prisma", fake_prisma_module) + + fake_iam_module = MagicMock() + + def boom(**_kwargs): + raise RuntimeError("simulated AWS STS hiccup") + + fake_iam_module.generate_iam_auth_token = boom + monkeypatch.setitem( + sys.modules, "litellm.proxy.auth.rds_iam_token", fake_iam_module + ) + + from litellm.proxy.utils import PrismaClient + + with caplog.at_level(logging.WARNING, logger="LiteLLM Proxy"): + client = PrismaClient( + database_url="postgresql://writer@writer.aurora.local:5432/litellm", + proxy_logging_obj=MagicMock(), + ) + + # Construction did not raise, and the proxy is in writer-only mode — + # NOT a RoutingPrismaWrapper, so reads will go to the writer. + assert isinstance(client.db, PrismaWrapper) + assert not isinstance(client.db, RoutingPrismaWrapper) + # And the operator gets a clear warning. + assert any( + "Failed to initialize read replica Prisma client" in r.getMessage() + for r in caplog.records + )