feat: add read-replica routing for Prisma DB via DATABASE_URL_READ_REPLICA (#27493)
- Introduce RoutingPrismaWrapper that transparently routes read operations (find_*, count, group_by, query_raw, query_first) to a reader endpoint while writes remain on the writer, enabling Aurora-style reader/writer endpoint splits - Add IAMEndpoint dataclass and parse_iam_endpoint_from_url() to capture static connection fields from a reader URL so only the IAM token needs to rotate, avoiding the need for separate DATABASE_HOST_READ_REPLICA/etc. env vars - Enhance PrismaWrapper with per-instance knobs (db_url_env_var, iam_endpoint, recreate_uses_datasource, log_prefix) so writer and reader wrappers are independent: the reader writes its fresh URL to DATABASE_URL_READ_REPLICA and passes datasource override to Prisma since Prisma only auto-reads DATABASE_URL - Fix deadlock in PrismaWrapper.__getattr__: when called from inside a running event loop, schedule the token refresh as a background task instead of blocking with run_coroutine_threadsafe + future.result(), which would deadlock the loop thread waiting for a coroutine that needs the loop to run - Fix botocore crash when DATABASE_PORT is unset by defaulting to "5432" in both proxy_cli.py and PrismaWrapper.get_rds_iam_token(); passing None caused botocore to embed the literal string "None" in the presigned URL - Implement graceful reader degradation: reader connect/recreate failures are non-fatal; wrapper sets _reader_unavailable=True and silently routes reads to the writer to keep the proxy serving traffic during transient reader outages - Add PrismaClient.writer_db property so the reconnect smoke-test always validates the writer engine specifically; query_raw on the routing wrapper would route to the reader and not verify the newly-recreated writer - Expose DATABASE_URL_READ_REPLICA in Helm chart (values.yaml + deployment.yaml) via both plain value and secret key reference, and document the field in docker-compose.yml - Add 887-line test suite covering routing logic, IAM token refresh paths, reader degradation scenarios, datasource override behavior, and the deadlock regression Co-authored-by: Yassin Kortam <yassinkortam@g.ucla.edu>
This commit is contained in:
parent
0bcff0214a
commit
b5d3a5fc85
@ -100,6 +100,16 @@ spec:
|
|||||||
- name: DATABASE_URL
|
- name: DATABASE_URL
|
||||||
value: {{ .Values.db.url | quote }}
|
value: {{ .Values.db.url | quote }}
|
||||||
{{- end }}
|
{{- end }}
|
||||||
|
{{- if and .Values.db.useExisting .Values.db.secret.readReplicaUrlKey }}
|
||||||
|
- name: DATABASE_URL_READ_REPLICA
|
||||||
|
valueFrom:
|
||||||
|
secretKeyRef:
|
||||||
|
name: {{ .Values.db.secret.name }}
|
||||||
|
key: {{ .Values.db.secret.readReplicaUrlKey }}
|
||||||
|
{{- else if .Values.db.readReplicaUrl }}
|
||||||
|
- name: DATABASE_URL_READ_REPLICA
|
||||||
|
value: {{ .Values.db.readReplicaUrl | quote }}
|
||||||
|
{{- end }}
|
||||||
- name: PROXY_MASTER_KEY
|
- name: PROXY_MASTER_KEY
|
||||||
valueFrom:
|
valueFrom:
|
||||||
secretKeyRef:
|
secretKeyRef:
|
||||||
|
|||||||
@ -252,6 +252,26 @@ db:
|
|||||||
passwordKey: password
|
passwordKey: password
|
||||||
# Optional: when set, DATABASE_HOST will be sourced from this secret key instead of db.endpoint
|
# Optional: when set, DATABASE_HOST will be sourced from this secret key instead of db.endpoint
|
||||||
endpointKey: ""
|
endpointKey: ""
|
||||||
|
# Optional: when set, DATABASE_URL_READ_REPLICA will be sourced from this
|
||||||
|
# secret key instead of db.readReplicaUrl. Prefer this over the plain
|
||||||
|
# value: read-replica URLs typically embed credentials, and a value
|
||||||
|
# written to db.readReplicaUrl ends up visible in the rendered pod spec
|
||||||
|
# and the Helm release secret.
|
||||||
|
readReplicaUrlKey: ""
|
||||||
|
|
||||||
|
# Optional read-replica routing. When set, the proxy sends read-only
|
||||||
|
# queries (find_*, count, group_by, query_raw/_first) to this URL while
|
||||||
|
# writes continue to go to db.url. Useful for Aurora-style clusters with
|
||||||
|
# separate reader/writer endpoints. Leave empty to keep single-DB behavior.
|
||||||
|
# When IAM_TOKEN_DB_AUTH is enabled, the reader URL is auto-refreshed
|
||||||
|
# alongside the writer (host/port/user/db are parsed from this URL once
|
||||||
|
# at startup; only the IAM token rotates).
|
||||||
|
#
|
||||||
|
# If the URL embeds credentials, prefer db.secret.readReplicaUrlKey over
|
||||||
|
# this field — the plain value is rendered into the pod spec and the
|
||||||
|
# Helm release secret. This field is intended for credential-less URLs
|
||||||
|
# only (e.g. when IAM_TOKEN_DB_AUTH supplies the token at runtime).
|
||||||
|
readReplicaUrl: ""
|
||||||
|
|
||||||
# Use the Stackgres Helm chart to deploy an instance of a Stackgres cluster.
|
# Use the Stackgres Helm chart to deploy an instance of a Stackgres cluster.
|
||||||
# The Stackgres Operator must already be installed within the target
|
# The Stackgres Operator must already be installed within the target
|
||||||
|
|||||||
@ -16,6 +16,11 @@ services:
|
|||||||
- "4000:4000" # Map the container port to the host, change the host port if necessary
|
- "4000:4000" # Map the container port to the host, change the host port if necessary
|
||||||
environment:
|
environment:
|
||||||
DATABASE_URL: "postgresql://llmproxy:dbpassword9090@db:5432/litellm"
|
DATABASE_URL: "postgresql://llmproxy:dbpassword9090@db:5432/litellm"
|
||||||
|
# Optional: route read-only queries (find_*, count, group_by, query_raw/_first)
|
||||||
|
# to a separate reader endpoint, e.g. an Aurora reader. Leave unset for
|
||||||
|
# single-DB deployments. With IAM_TOKEN_DB_AUTH enabled, the reader URL
|
||||||
|
# is auto-refreshed alongside the writer.
|
||||||
|
# DATABASE_URL_READ_REPLICA: "postgresql://llmproxy:dbpassword9090@db-reader:5432/litellm"
|
||||||
STORE_MODEL_IN_DB: "True" # allows adding models to proxy via UI
|
STORE_MODEL_IN_DB: "True" # allows adding models to proxy via UI
|
||||||
env_file:
|
env_file:
|
||||||
- .env # Load local .env file
|
- .env # Load local .env file
|
||||||
|
|||||||
@ -10,13 +10,64 @@ import subprocess
|
|||||||
import time
|
import time
|
||||||
import urllib
|
import urllib
|
||||||
import urllib.parse
|
import urllib.parse
|
||||||
|
from dataclasses import dataclass
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Dict, Optional, Union
|
||||||
|
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.secret_managers.main import str_to_bool
|
from litellm.secret_managers.main import str_to_bool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class IAMEndpoint:
|
||||||
|
"""Static parts of an RDS IAM-authenticated Postgres connection.
|
||||||
|
|
||||||
|
The IAM token rotates every ~15 minutes; everything else (host, port, user,
|
||||||
|
database name, schema) stays fixed. We capture the static fields once so
|
||||||
|
refresh just regenerates the token and reassembles the URL.
|
||||||
|
"""
|
||||||
|
|
||||||
|
host: str
|
||||||
|
port: str
|
||||||
|
user: str
|
||||||
|
name: str
|
||||||
|
schema: Optional[str] = None
|
||||||
|
|
||||||
|
def build_url(self, token: str) -> str:
|
||||||
|
url = f"postgresql://{self.user}:{token}@{self.host}:{self.port}/{self.name}"
|
||||||
|
if self.schema:
|
||||||
|
url += f"?schema={self.schema}"
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
def parse_iam_endpoint_from_url(url: str) -> IAMEndpoint:
|
||||||
|
"""Parse an IAMEndpoint from a Postgres URL.
|
||||||
|
|
||||||
|
Used so a reader URL can drive its own IAM refresh without requiring
|
||||||
|
callers to set parallel DATABASE_HOST_READ_REPLICA / etc. env vars.
|
||||||
|
"""
|
||||||
|
parsed = urllib.parse.urlparse(url)
|
||||||
|
if not parsed.hostname or not parsed.username:
|
||||||
|
raise ValueError("Cannot parse IAM endpoint from URL: missing host or username")
|
||||||
|
name = (parsed.path or "/").lstrip("/")
|
||||||
|
if not name:
|
||||||
|
raise ValueError("Cannot parse IAM endpoint from URL: missing database name")
|
||||||
|
port = str(parsed.port) if parsed.port else "5432"
|
||||||
|
schema: Optional[str] = None
|
||||||
|
if parsed.query:
|
||||||
|
qs = urllib.parse.parse_qs(parsed.query)
|
||||||
|
schema_vals = qs.get("schema")
|
||||||
|
if schema_vals:
|
||||||
|
schema = schema_vals[0]
|
||||||
|
return IAMEndpoint(
|
||||||
|
host=parsed.hostname,
|
||||||
|
port=port,
|
||||||
|
user=parsed.username,
|
||||||
|
name=name,
|
||||||
|
schema=schema,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PrismaWrapper:
|
class PrismaWrapper:
|
||||||
"""
|
"""
|
||||||
Wrapper around Prisma client that handles RDS IAM token authentication.
|
Wrapper around Prisma client that handles RDS IAM token authentication.
|
||||||
@ -37,10 +88,33 @@ class PrismaWrapper:
|
|||||||
# Fallback refresh interval if token parsing fails (10 minutes)
|
# Fallback refresh interval if token parsing fails (10 minutes)
|
||||||
FALLBACK_REFRESH_INTERVAL_SECONDS = 600
|
FALLBACK_REFRESH_INTERVAL_SECONDS = 600
|
||||||
|
|
||||||
def __init__(self, original_prisma: Any, iam_token_db_auth: bool):
|
def __init__(
|
||||||
|
self,
|
||||||
|
original_prisma: Any,
|
||||||
|
iam_token_db_auth: bool,
|
||||||
|
*,
|
||||||
|
db_url_env_var: str = "DATABASE_URL",
|
||||||
|
iam_endpoint: Optional[IAMEndpoint] = None,
|
||||||
|
recreate_uses_datasource: bool = False,
|
||||||
|
log_prefix: str = "",
|
||||||
|
):
|
||||||
self._original_prisma = original_prisma
|
self._original_prisma = original_prisma
|
||||||
self.iam_token_db_auth = iam_token_db_auth
|
self.iam_token_db_auth = iam_token_db_auth
|
||||||
|
|
||||||
|
# Per-connection knobs so the same wrapper can be used for the writer
|
||||||
|
# (defaults: DATABASE_URL env, IAM endpoint from DATABASE_HOST/etc.,
|
||||||
|
# recreate via env reload) or for a reader (DATABASE_URL_READ_REPLICA
|
||||||
|
# env, IAM endpoint parsed from that URL, recreate via datasource
|
||||||
|
# override since Prisma only auto-reads DATABASE_URL).
|
||||||
|
self._db_url_env_var = db_url_env_var
|
||||||
|
self._iam_endpoint = iam_endpoint
|
||||||
|
self._recreate_uses_datasource = recreate_uses_datasource
|
||||||
|
# Tag every log line emitted by this wrapper instance so writer and
|
||||||
|
# reader can be told apart in interleaved output (e.g. "[writer] RDS
|
||||||
|
# IAM token refresh scheduled in 720 seconds"). Empty string (default)
|
||||||
|
# keeps backward-compatible logs for the single-DB case.
|
||||||
|
self._log_prefix = f"{log_prefix} " if log_prefix else ""
|
||||||
|
|
||||||
# Background token refresh task management
|
# Background token refresh task management
|
||||||
self._token_refresh_task: Optional[asyncio.Task] = None
|
self._token_refresh_task: Optional[asyncio.Task] = None
|
||||||
self._reconnection_lock = asyncio.Lock()
|
self._reconnection_lock = asyncio.Lock()
|
||||||
@ -157,7 +231,7 @@ class PrismaWrapper:
|
|||||||
Returns 0 if token should be refreshed immediately.
|
Returns 0 if token should be refreshed immediately.
|
||||||
Returns FALLBACK_REFRESH_INTERVAL_SECONDS if parsing fails.
|
Returns FALLBACK_REFRESH_INTERVAL_SECONDS if parsing fails.
|
||||||
"""
|
"""
|
||||||
db_url = os.getenv("DATABASE_URL")
|
db_url = os.getenv(self._db_url_env_var)
|
||||||
token = self._extract_token_from_db_url(db_url)
|
token = self._extract_token_from_db_url(db_url)
|
||||||
expiration_time = self._parse_token_expiration(token)
|
expiration_time = self._parse_token_expiration(token)
|
||||||
|
|
||||||
@ -199,12 +273,30 @@ class PrismaWrapper:
|
|||||||
return datetime.utcnow() > expiration_time
|
return datetime.utcnow() > expiration_time
|
||||||
|
|
||||||
def get_rds_iam_token(self) -> Optional[str]:
|
def get_rds_iam_token(self) -> Optional[str]:
|
||||||
"""Generate a new RDS IAM token and update DATABASE_URL."""
|
"""Generate a new RDS IAM token and update the configured DB URL env var.
|
||||||
if self.iam_token_db_auth:
|
|
||||||
from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token
|
|
||||||
|
|
||||||
|
When the wrapper was constructed with an explicit `iam_endpoint`
|
||||||
|
(typical for a reader wrapper whose host/port/user came from a parsed
|
||||||
|
URL), use that. Otherwise fall back to the legacy DATABASE_HOST/PORT/
|
||||||
|
USER/NAME/SCHEMA env vars (writer behavior).
|
||||||
|
"""
|
||||||
|
if not self.iam_token_db_auth:
|
||||||
|
return None
|
||||||
|
|
||||||
|
from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token
|
||||||
|
|
||||||
|
if self._iam_endpoint is not None:
|
||||||
|
endpoint = self._iam_endpoint
|
||||||
|
token = generate_iam_auth_token(
|
||||||
|
db_host=endpoint.host, db_port=endpoint.port, db_user=endpoint.user
|
||||||
|
)
|
||||||
|
_db_url = endpoint.build_url(token)
|
||||||
|
else:
|
||||||
db_host = os.getenv("DATABASE_HOST")
|
db_host = os.getenv("DATABASE_HOST")
|
||||||
db_port = os.getenv("DATABASE_PORT")
|
# Default to the Postgres standard port; passing None to
|
||||||
|
# `generate_iam_auth_token` makes botocore embed the literal
|
||||||
|
# string "None" in the presigned URL, which then fails to parse.
|
||||||
|
db_port = os.getenv("DATABASE_PORT", "5432")
|
||||||
db_user = os.getenv("DATABASE_USER")
|
db_user = os.getenv("DATABASE_USER")
|
||||||
db_name = os.getenv("DATABASE_NAME")
|
db_name = os.getenv("DATABASE_NAME")
|
||||||
db_schema = os.getenv("DATABASE_SCHEMA")
|
db_schema = os.getenv("DATABASE_SCHEMA")
|
||||||
@ -217,9 +309,8 @@ class PrismaWrapper:
|
|||||||
if db_schema:
|
if db_schema:
|
||||||
_db_url += f"?schema={db_schema}"
|
_db_url += f"?schema={db_schema}"
|
||||||
|
|
||||||
os.environ["DATABASE_URL"] = _db_url
|
os.environ[self._db_url_env_var] = _db_url
|
||||||
return _db_url
|
return _db_url
|
||||||
return None
|
|
||||||
|
|
||||||
async def recreate_prisma_client(
|
async def recreate_prisma_client(
|
||||||
self, new_db_url: str, http_client: Optional[Any] = None
|
self, new_db_url: str, http_client: Optional[Any] = None
|
||||||
@ -231,6 +322,11 @@ class PrismaWrapper:
|
|||||||
synchronous `subprocess.Popen.wait()` that can freeze the asyncio event
|
synchronous `subprocess.Popen.wait()` that can freeze the asyncio event
|
||||||
loop for 30-120+ seconds when the engine is stuck on TCP close,
|
loop for 30-120+ seconds when the engine is stuck on TCP close,
|
||||||
breaking `/health/liveliness` and causing Kubernetes pod restarts.
|
breaking `/health/liveliness` and causing Kubernetes pod restarts.
|
||||||
|
|
||||||
|
The writer wrapper relies on Prisma re-reading `DATABASE_URL` from env;
|
||||||
|
the reader wrapper opts into `recreate_uses_datasource=True` so the
|
||||||
|
new URL is passed explicitly via `datasource={"url": ...}` (Prisma
|
||||||
|
does not auto-read alternate env vars like DATABASE_URL_READ_REPLICA).
|
||||||
"""
|
"""
|
||||||
from prisma import Prisma # type: ignore
|
from prisma import Prisma # type: ignore
|
||||||
|
|
||||||
@ -238,10 +334,12 @@ class PrismaWrapper:
|
|||||||
if old_engine_pid > 0:
|
if old_engine_pid > 0:
|
||||||
await self._kill_engine_process(old_engine_pid)
|
await self._kill_engine_process(old_engine_pid)
|
||||||
|
|
||||||
|
kwargs: Dict[str, Any] = {}
|
||||||
if http_client is not None:
|
if http_client is not None:
|
||||||
self._original_prisma = Prisma(http=http_client)
|
kwargs["http"] = http_client
|
||||||
else:
|
if self._recreate_uses_datasource:
|
||||||
self._original_prisma = Prisma()
|
kwargs["datasource"] = {"url": new_db_url}
|
||||||
|
self._original_prisma = Prisma(**kwargs)
|
||||||
|
|
||||||
await self._original_prisma.connect()
|
await self._original_prisma.connect()
|
||||||
|
|
||||||
@ -265,7 +363,8 @@ class PrismaWrapper:
|
|||||||
|
|
||||||
self._token_refresh_task = asyncio.create_task(self._token_refresh_loop())
|
self._token_refresh_task = asyncio.create_task(self._token_refresh_loop())
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
"Started RDS IAM token proactive refresh background task"
|
"%sStarted RDS IAM token proactive refresh background task",
|
||||||
|
self._log_prefix,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def stop_token_refresh_task(self) -> None:
|
async def stop_token_refresh_task(self) -> None:
|
||||||
@ -283,7 +382,9 @@ class PrismaWrapper:
|
|||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
pass
|
pass
|
||||||
self._token_refresh_task = None
|
self._token_refresh_task = None
|
||||||
verbose_proxy_logger.info("Stopped RDS IAM token refresh background task")
|
verbose_proxy_logger.info(
|
||||||
|
"%sStopped RDS IAM token refresh background task", self._log_prefix
|
||||||
|
)
|
||||||
|
|
||||||
async def _token_refresh_loop(self) -> None:
|
async def _token_refresh_loop(self) -> None:
|
||||||
"""
|
"""
|
||||||
@ -294,7 +395,7 @@ class PrismaWrapper:
|
|||||||
This is more efficient than polling, requiring only 1 wake-up per token cycle.
|
This is more efficient than polling, requiring only 1 wake-up per token cycle.
|
||||||
"""
|
"""
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
f"RDS IAM token refresh loop started. "
|
f"{self._log_prefix}RDS IAM token refresh loop started. "
|
||||||
f"Tokens will be refreshed {self.TOKEN_REFRESH_BUFFER_SECONDS}s before expiration."
|
f"Tokens will be refreshed {self.TOKEN_REFRESH_BUFFER_SECONDS}s before expiration."
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -305,21 +406,25 @@ class PrismaWrapper:
|
|||||||
|
|
||||||
if sleep_seconds > 0:
|
if sleep_seconds > 0:
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
f"RDS IAM token refresh scheduled in {sleep_seconds:.0f} seconds "
|
f"{self._log_prefix}RDS IAM token refresh scheduled in "
|
||||||
f"({sleep_seconds / 60:.1f} minutes)"
|
f"{sleep_seconds:.0f} seconds ({sleep_seconds / 60:.1f} minutes)"
|
||||||
)
|
)
|
||||||
await asyncio.sleep(sleep_seconds)
|
await asyncio.sleep(sleep_seconds)
|
||||||
|
|
||||||
# Refresh the token
|
# Refresh the token
|
||||||
verbose_proxy_logger.info("Proactively refreshing RDS IAM token...")
|
verbose_proxy_logger.info(
|
||||||
|
"%sProactively refreshing RDS IAM token...", self._log_prefix
|
||||||
|
)
|
||||||
await self._safe_refresh_token()
|
await self._safe_refresh_token()
|
||||||
|
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
verbose_proxy_logger.info("RDS IAM token refresh loop cancelled")
|
verbose_proxy_logger.info(
|
||||||
|
"%sRDS IAM token refresh loop cancelled", self._log_prefix
|
||||||
|
)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
verbose_proxy_logger.error(
|
verbose_proxy_logger.error(
|
||||||
f"Error in RDS IAM token refresh loop: {e}. "
|
f"{self._log_prefix}Error in RDS IAM token refresh loop: {e}. "
|
||||||
f"Retrying in {self.FALLBACK_REFRESH_INTERVAL_SECONDS}s..."
|
f"Retrying in {self.FALLBACK_REFRESH_INTERVAL_SECONDS}s..."
|
||||||
)
|
)
|
||||||
# On error, wait before retrying to avoid tight error loops
|
# On error, wait before retrying to avoid tight error loops
|
||||||
@ -341,65 +446,75 @@ class PrismaWrapper:
|
|||||||
await self.recreate_prisma_client(new_db_url)
|
await self.recreate_prisma_client(new_db_url)
|
||||||
self._last_refresh_time = datetime.utcnow()
|
self._last_refresh_time = datetime.utcnow()
|
||||||
verbose_proxy_logger.info(
|
verbose_proxy_logger.info(
|
||||||
"RDS IAM token refreshed successfully. New token valid for ~15 minutes."
|
"%sRDS IAM token refreshed successfully. New token valid for ~15 minutes.",
|
||||||
|
self._log_prefix,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
verbose_proxy_logger.error(
|
verbose_proxy_logger.error(
|
||||||
"Failed to generate new RDS IAM token during proactive refresh"
|
"%sFailed to generate new RDS IAM token during proactive refresh",
|
||||||
|
self._log_prefix,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __getattr__(self, name: str):
|
def __getattr__(self, name: str):
|
||||||
"""
|
"""
|
||||||
Proxy attribute access to the underlying Prisma client.
|
Proxy attribute access to the underlying Prisma client.
|
||||||
|
|
||||||
If IAM token auth is enabled and the token is expired, this method
|
If IAM token auth is enabled and the token is found expired here, the
|
||||||
provides a synchronous fallback to refresh the token. However, this
|
proactive refresh task has missed its window. Behavior depends on
|
||||||
should rarely be needed since the background task proactively refreshes
|
whether we're called from inside a running event loop:
|
||||||
tokens before they expire.
|
|
||||||
|
|
||||||
FIXED: Now properly waits for reconnection to complete before returning,
|
- Inside the loop (typical: from a coroutine): schedule a refresh as a
|
||||||
instead of the previous fire-and-forget pattern that caused the bug.
|
background task and return the (stale) attribute. The caller's await
|
||||||
|
will likely fail with a connection error and be retried by upper
|
||||||
|
layers (`call_with_db_reconnect_retry`); by that time the refresh
|
||||||
|
has either completed or escalated to the proactive loop's error
|
||||||
|
path. We CANNOT block here — `run_coroutine_threadsafe(...)` +
|
||||||
|
`future.result()` from inside the same loop deadlocks the loop
|
||||||
|
(loop thread is blocked, scheduled coroutine never runs, 30s timeout).
|
||||||
|
|
||||||
|
- No running loop (sync caller, mostly tests): run the refresh in a
|
||||||
|
fresh loop and re-fetch the attribute.
|
||||||
"""
|
"""
|
||||||
original_attr = getattr(self._original_prisma, name)
|
original_attr = getattr(self._original_prisma, name)
|
||||||
|
|
||||||
if self.iam_token_db_auth:
|
if self.iam_token_db_auth:
|
||||||
db_url = os.getenv("DATABASE_URL")
|
db_url = os.getenv(self._db_url_env_var)
|
||||||
|
|
||||||
# Check if token is expired (should be rare if background task is running)
|
# Check if token is expired (should be rare if background task is running)
|
||||||
if self.is_token_expired(db_url):
|
if self.is_token_expired(db_url):
|
||||||
verbose_proxy_logger.warning(
|
try:
|
||||||
"RDS IAM token expired in __getattr__ - proactive refresh may have failed. "
|
running_loop = asyncio.get_running_loop()
|
||||||
"Triggering synchronous fallback refresh..."
|
except RuntimeError:
|
||||||
)
|
running_loop = None
|
||||||
|
|
||||||
new_db_url = self.get_rds_iam_token()
|
if running_loop is not None:
|
||||||
if new_db_url:
|
verbose_proxy_logger.warning(
|
||||||
loop = asyncio.get_event_loop()
|
"%sRDS IAM token expired in __getattr__ — proactive refresh "
|
||||||
|
"may have failed. Scheduling async refresh; the current "
|
||||||
if loop.is_running():
|
"request may fail and be retried with the fresh token.",
|
||||||
# FIXED: Actually wait for the reconnection to complete!
|
self._log_prefix,
|
||||||
# The previous code used fire-and-forget which caused the bug.
|
)
|
||||||
future = asyncio.run_coroutine_threadsafe(
|
# Non-blocking: schedule the locked refresh on the
|
||||||
self.recreate_prisma_client(new_db_url), loop
|
# running loop. The reconnection lock inside
|
||||||
)
|
# `_safe_refresh_token` coalesces concurrent triggers.
|
||||||
try:
|
running_loop.create_task(self._safe_refresh_token())
|
||||||
# Wait up to 30 seconds for reconnection
|
|
||||||
future.result(timeout=30)
|
|
||||||
verbose_proxy_logger.info(
|
|
||||||
"Synchronous token refresh completed successfully"
|
|
||||||
)
|
|
||||||
except Exception as e:
|
|
||||||
verbose_proxy_logger.error(
|
|
||||||
f"Failed to refresh token synchronously: {e}"
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
else:
|
|
||||||
asyncio.run(self.recreate_prisma_client(new_db_url))
|
|
||||||
|
|
||||||
# Get the NEW attribute after reconnection
|
|
||||||
original_attr = getattr(self._original_prisma, name)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError("Failed to get RDS IAM token")
|
verbose_proxy_logger.warning(
|
||||||
|
"%sRDS IAM token expired in __getattr__ — proactive refresh "
|
||||||
|
"may have failed. Triggering synchronous fallback refresh...",
|
||||||
|
self._log_prefix,
|
||||||
|
)
|
||||||
|
new_db_url = self.get_rds_iam_token()
|
||||||
|
if new_db_url:
|
||||||
|
asyncio.run(self.recreate_prisma_client(new_db_url))
|
||||||
|
# Re-fetch attribute against the recreated Prisma instance.
|
||||||
|
original_attr = getattr(self._original_prisma, name)
|
||||||
|
verbose_proxy_logger.info(
|
||||||
|
"%sSynchronous token refresh completed successfully",
|
||||||
|
self._log_prefix,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError("Failed to get RDS IAM token")
|
||||||
|
|
||||||
return original_attr
|
return original_attr
|
||||||
|
|
||||||
|
|||||||
213
litellm/proxy/db/routing_prisma_wrapper.py
Normal file
213
litellm/proxy/db/routing_prisma_wrapper.py
Normal file
@ -0,0 +1,213 @@
|
|||||||
|
"""
|
||||||
|
RoutingPrismaWrapper: routes Prisma reads to a read-replica client and writes
|
||||||
|
to a writer client. Used when DATABASE_URL_READ_REPLICA is configured;
|
||||||
|
otherwise PrismaClient uses the writer-only PrismaWrapper directly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
from litellm.proxy.db.prisma_client import PrismaWrapper
|
||||||
|
|
||||||
|
# Per-model action methods that read from the database. These are routed to
|
||||||
|
# the read replica when one is configured.
|
||||||
|
_MODEL_READ_METHODS = frozenset(
|
||||||
|
{
|
||||||
|
"find_first",
|
||||||
|
"find_first_or_raise",
|
||||||
|
"find_many",
|
||||||
|
"find_unique",
|
||||||
|
"find_unique_or_raise",
|
||||||
|
"count",
|
||||||
|
"group_by",
|
||||||
|
"query_first",
|
||||||
|
"query_raw",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Top-level Prisma client methods that read from the database.
|
||||||
|
_TOP_LEVEL_READ_METHODS = frozenset({"query_first", "query_raw"})
|
||||||
|
|
||||||
|
|
||||||
|
class _RoutedActions:
|
||||||
|
"""Per-model accessor that sends reads to the reader and writes to the writer.
|
||||||
|
|
||||||
|
`should_use_reader` is consulted on every read dispatch so a mid-call flip
|
||||||
|
of the routing wrapper's reader-availability flag (e.g. after the reader
|
||||||
|
fails a recreate) is observed without re-fetching the actions accessor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__slots__ = ("_writer_actions", "_reader_actions", "_should_use_reader")
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
writer_actions: Any,
|
||||||
|
reader_actions: Any,
|
||||||
|
should_use_reader: Callable[[], bool],
|
||||||
|
):
|
||||||
|
self._writer_actions = writer_actions
|
||||||
|
self._reader_actions = reader_actions
|
||||||
|
self._should_use_reader = should_use_reader
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
if name in _MODEL_READ_METHODS and self._should_use_reader():
|
||||||
|
return getattr(self._reader_actions, name)
|
||||||
|
return getattr(self._writer_actions, name)
|
||||||
|
|
||||||
|
|
||||||
|
class RoutingPrismaWrapper:
|
||||||
|
"""
|
||||||
|
Routes Prisma operations between a writer and a reader Prisma client.
|
||||||
|
|
||||||
|
Reads (find_*, count, group_by, query_raw, query_first) go to the reader;
|
||||||
|
everything else (writes, transactions, raw execute) goes to the writer.
|
||||||
|
Lifecycle methods (connect, disconnect, IAM token refresh) act on both
|
||||||
|
clients so callers do not need to know about the split. When
|
||||||
|
IAM_TOKEN_DB_AUTH is enabled, both writer and reader refresh their tokens
|
||||||
|
independently on their own ~12-minute cadence.
|
||||||
|
|
||||||
|
Reader degradation: a reader-side failure (failed connect, failed
|
||||||
|
recreate) is non-fatal — the wrapper sets `_reader_unavailable=True`, logs
|
||||||
|
a warning, and routes subsequent reads to the writer. The next successful
|
||||||
|
`connect()` or `recreate_prisma_client()` clears the flag. This keeps the
|
||||||
|
proxy serving traffic during transient reader outages instead of failing
|
||||||
|
startup or returning errors for read-heavy endpoints.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, writer: PrismaWrapper, reader: PrismaWrapper):
|
||||||
|
self._writer = writer
|
||||||
|
self._reader = reader
|
||||||
|
# When True, reads fall back to the writer. Flipped on by reader
|
||||||
|
# connect/recreate failures and flipped off on the next reader recovery.
|
||||||
|
self._reader_unavailable: bool = False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def writer(self) -> PrismaWrapper:
|
||||||
|
return self._writer
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reader(self) -> PrismaWrapper:
|
||||||
|
return self._reader
|
||||||
|
|
||||||
|
@property
|
||||||
|
def reader_unavailable(self) -> bool:
|
||||||
|
return self._reader_unavailable
|
||||||
|
|
||||||
|
def _should_use_reader(self) -> bool:
|
||||||
|
return not self._reader_unavailable
|
||||||
|
|
||||||
|
async def connect(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
await self._writer.connect(*args, **kwargs)
|
||||||
|
verbose_proxy_logger.info("[writer] DB connected")
|
||||||
|
try:
|
||||||
|
await self._reader.connect(*args, **kwargs)
|
||||||
|
self._reader_unavailable = False
|
||||||
|
verbose_proxy_logger.info("[reader] DB connected")
|
||||||
|
except Exception as e:
|
||||||
|
# Degrade gracefully: the proxy keeps serving traffic with reads
|
||||||
|
# routed to the writer until the reader endpoint is reachable.
|
||||||
|
# Aborting startup here would tie proxy availability to an
|
||||||
|
# opt-in, best-effort reader endpoint.
|
||||||
|
self._reader_unavailable = True
|
||||||
|
verbose_proxy_logger.warning(
|
||||||
|
"Failed to connect to read replica DB: %s. "
|
||||||
|
"Falling back to the writer for reads until the reader is reachable.",
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def disconnect(self, *args: Any, **kwargs: Any) -> None:
|
||||||
|
first_error: Optional[BaseException] = None
|
||||||
|
for client in (self._writer, self._reader):
|
||||||
|
try:
|
||||||
|
await client.disconnect(*args, **kwargs)
|
||||||
|
except Exception as e:
|
||||||
|
if first_error is None:
|
||||||
|
first_error = e
|
||||||
|
verbose_proxy_logger.warning("Error disconnecting Prisma client: %s", e)
|
||||||
|
if first_error is not None:
|
||||||
|
raise first_error
|
||||||
|
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
# Reflects writer health only. The reader is best-effort; its
|
||||||
|
# availability is tracked via `_reader_unavailable` and a degraded
|
||||||
|
# reader must NOT cause a writer reconnect (would loop indefinitely
|
||||||
|
# since recreate_prisma_client only fixes writer-side problems).
|
||||||
|
return bool(self._writer.is_connected())
|
||||||
|
|
||||||
|
async def start_token_refresh_task(self) -> None:
|
||||||
|
await self._writer.start_token_refresh_task()
|
||||||
|
await self._reader.start_token_refresh_task()
|
||||||
|
|
||||||
|
async def stop_token_refresh_task(self) -> None:
|
||||||
|
await self._writer.stop_token_refresh_task()
|
||||||
|
await self._reader.stop_token_refresh_task()
|
||||||
|
|
||||||
|
async def recreate_prisma_client(
|
||||||
|
self, new_db_url: str, http_client: Optional[Any] = None
|
||||||
|
) -> None:
|
||||||
|
"""Recreate both writer and reader Prisma clients.
|
||||||
|
|
||||||
|
The writer reconnect path in PrismaClient calls
|
||||||
|
`self.db.recreate_prisma_client(...)`. Without this method, a DB-wide
|
||||||
|
connectivity event would only re-create the writer; the reader engine
|
||||||
|
would stay broken and every routed read would fail. We always recreate
|
||||||
|
the writer first (its URL is the one passed in), then best-effort
|
||||||
|
recreate the reader. A reader failure flips `_reader_unavailable=True`
|
||||||
|
so reads transparently fall through to the writer.
|
||||||
|
"""
|
||||||
|
await self._writer.recreate_prisma_client(new_db_url, http_client=http_client)
|
||||||
|
try:
|
||||||
|
await self._recreate_reader(http_client=http_client)
|
||||||
|
self._reader_unavailable = False
|
||||||
|
except Exception as e:
|
||||||
|
self._reader_unavailable = True
|
||||||
|
verbose_proxy_logger.warning(
|
||||||
|
"Failed to recreate reader Prisma client: %s. "
|
||||||
|
"Reads will fall back to the writer until the reader recovers.",
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _recreate_reader(self, http_client: Optional[Any] = None) -> None:
|
||||||
|
"""Resolve the reader URL and recreate its Prisma client.
|
||||||
|
|
||||||
|
IAM-enabled readers regenerate their token (host/port/user came from
|
||||||
|
the parsed reader URL at construction time). Non-IAM readers reuse
|
||||||
|
the URL stored in `DATABASE_URL_READ_REPLICA`.
|
||||||
|
"""
|
||||||
|
if self._reader.iam_token_db_auth:
|
||||||
|
new_reader_url = self._reader.get_rds_iam_token()
|
||||||
|
if not new_reader_url:
|
||||||
|
raise RuntimeError(
|
||||||
|
"Failed to generate fresh IAM token for read replica"
|
||||||
|
)
|
||||||
|
await self._reader.recreate_prisma_client(
|
||||||
|
new_reader_url, http_client=http_client
|
||||||
|
)
|
||||||
|
return
|
||||||
|
reader_url = os.getenv("DATABASE_URL_READ_REPLICA", "")
|
||||||
|
if not reader_url:
|
||||||
|
raise RuntimeError(
|
||||||
|
"DATABASE_URL_READ_REPLICA not set; cannot recreate read replica client"
|
||||||
|
)
|
||||||
|
await self._reader.recreate_prisma_client(reader_url, http_client=http_client)
|
||||||
|
|
||||||
|
def __getattr__(self, name: str) -> Any:
|
||||||
|
if name in _TOP_LEVEL_READ_METHODS:
|
||||||
|
target = self._writer if self._reader_unavailable else self._reader
|
||||||
|
return getattr(target, name)
|
||||||
|
writer_attr = getattr(self._writer, name)
|
||||||
|
# Per-model action accessors are non-callable instances that expose
|
||||||
|
# both `find_many` and `create`. Methods like execute_raw / batch_ /
|
||||||
|
# tx are callables and stay on the writer untouched.
|
||||||
|
if (
|
||||||
|
not callable(writer_attr)
|
||||||
|
and hasattr(writer_attr, "find_many")
|
||||||
|
and hasattr(writer_attr, "create")
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
reader_attr = getattr(self._reader, name)
|
||||||
|
except AttributeError:
|
||||||
|
return writer_attr
|
||||||
|
return _RoutedActions(writer_attr, reader_attr, self._should_use_reader)
|
||||||
|
return writer_attr
|
||||||
@ -813,7 +813,12 @@ def run_server( # noqa: PLR0915
|
|||||||
from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token
|
from litellm.proxy.auth.rds_iam_token import generate_iam_auth_token
|
||||||
|
|
||||||
db_host = os.getenv("DATABASE_HOST")
|
db_host = os.getenv("DATABASE_HOST")
|
||||||
db_port = os.getenv("DATABASE_PORT")
|
# Default to the Postgres standard port. Without a default,
|
||||||
|
# `db_port=None` flows into `boto.generate_db_auth_token(Port=None)`
|
||||||
|
# and botocore stringifies it to `"None"` while building the
|
||||||
|
# presigned URL, which then blows up with `ValueError: Port could
|
||||||
|
# not be cast to integer value as 'None'` during signing.
|
||||||
|
db_port = os.getenv("DATABASE_PORT", "5432")
|
||||||
db_user = os.getenv("DATABASE_USER")
|
db_user = os.getenv("DATABASE_USER")
|
||||||
db_name = os.getenv("DATABASE_NAME")
|
db_name = os.getenv("DATABASE_NAME")
|
||||||
db_schema = os.getenv("DATABASE_SCHEMA")
|
db_schema = os.getenv("DATABASE_SCHEMA")
|
||||||
|
|||||||
@ -113,7 +113,11 @@ from litellm.proxy.db.exception_handler import (
|
|||||||
call_with_db_reconnect_retry,
|
call_with_db_reconnect_retry,
|
||||||
)
|
)
|
||||||
from litellm.proxy.db.log_db_metrics import log_db_metrics
|
from litellm.proxy.db.log_db_metrics import log_db_metrics
|
||||||
from litellm.proxy.db.prisma_client import PrismaWrapper
|
from litellm.proxy.db.prisma_client import (
|
||||||
|
PrismaWrapper,
|
||||||
|
parse_iam_endpoint_from_url,
|
||||||
|
)
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import (
|
from litellm.proxy.guardrails.guardrail_hooks.unified_guardrail.unified_guardrail import (
|
||||||
UnifiedLLMGuardrails,
|
UnifiedLLMGuardrails,
|
||||||
)
|
)
|
||||||
@ -2569,24 +2573,101 @@ class PrismaClient:
|
|||||||
raise Exception(
|
raise Exception(
|
||||||
"Unable to find Prisma binaries. Please run 'prisma generate' first."
|
"Unable to find Prisma binaries. Please run 'prisma generate' first."
|
||||||
)
|
)
|
||||||
|
iam_flag = (
|
||||||
|
self.iam_token_db_auth if self.iam_token_db_auth is not None else False
|
||||||
|
)
|
||||||
|
# When read-replica routing is on, tag log lines with [writer]/[reader]
|
||||||
|
# so the two wrappers' interleaved IAM refresh logs can be told apart.
|
||||||
|
# Single-DB deployments get an empty prefix (logs unchanged).
|
||||||
|
read_replica_url = os.getenv("DATABASE_URL_READ_REPLICA")
|
||||||
|
writer_log_prefix = "[writer]" if read_replica_url else ""
|
||||||
if http_client is not None:
|
if http_client is not None:
|
||||||
self.db = PrismaWrapper(
|
writer_wrapper = PrismaWrapper(
|
||||||
original_prisma=Prisma(http=http_client),
|
original_prisma=Prisma(http=http_client),
|
||||||
iam_token_db_auth=(
|
iam_token_db_auth=iam_flag,
|
||||||
self.iam_token_db_auth
|
log_prefix=writer_log_prefix,
|
||||||
if self.iam_token_db_auth is not None
|
|
||||||
else False
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.db = PrismaWrapper(
|
writer_wrapper = PrismaWrapper(
|
||||||
original_prisma=Prisma(),
|
original_prisma=Prisma(),
|
||||||
iam_token_db_auth=(
|
iam_token_db_auth=iam_flag,
|
||||||
self.iam_token_db_auth
|
log_prefix=writer_log_prefix,
|
||||||
if self.iam_token_db_auth is not None
|
)
|
||||||
else False
|
|
||||||
),
|
# Optional read-replica routing. When DATABASE_URL_READ_REPLICA is set,
|
||||||
) # Client to connect to Prisma db
|
# reads (find_*, count, group_by, query_raw/_first) are routed to the
|
||||||
|
# reader endpoint and writes stay on the writer. Falls back to the
|
||||||
|
# writer-only wrapper when the env var is unset, preserving existing
|
||||||
|
# single-DB deployments.
|
||||||
|
self.db: Union[PrismaWrapper, RoutingPrismaWrapper]
|
||||||
|
if read_replica_url:
|
||||||
|
try:
|
||||||
|
# If IAM auth is enabled, the reader refreshes its own token on
|
||||||
|
# the same cadence as the writer. We parse the static endpoint
|
||||||
|
# pieces (host/port/user/db) once from the reader URL — only
|
||||||
|
# the IAM token rotates after that.
|
||||||
|
reader_iam_endpoint = (
|
||||||
|
parse_iam_endpoint_from_url(read_replica_url) if iam_flag else None
|
||||||
|
)
|
||||||
|
# Mint a fresh IAM token for the reader BEFORE constructing the
|
||||||
|
# Prisma client. Mirrors what `proxy_cli.py` already does for
|
||||||
|
# the writer (proxy_cli.py:812-832) — without this, the reader
|
||||||
|
# Prisma is built with whatever placeholder URL the user
|
||||||
|
# supplied (no real token), and the first query falls through
|
||||||
|
# to the synchronous fallback path in
|
||||||
|
# `PrismaWrapper.__getattr__`, which deadlocks the event loop
|
||||||
|
# and times out after 30s.
|
||||||
|
if iam_flag and reader_iam_endpoint is not None:
|
||||||
|
from litellm.proxy.auth.rds_iam_token import (
|
||||||
|
generate_iam_auth_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
reader_token = generate_iam_auth_token(
|
||||||
|
db_host=reader_iam_endpoint.host,
|
||||||
|
db_port=reader_iam_endpoint.port,
|
||||||
|
db_user=reader_iam_endpoint.user,
|
||||||
|
)
|
||||||
|
read_replica_url = reader_iam_endpoint.build_url(reader_token)
|
||||||
|
os.environ["DATABASE_URL_READ_REPLICA"] = read_replica_url
|
||||||
|
reader_kwargs: Dict[str, Any] = {
|
||||||
|
"datasource": {"url": read_replica_url}
|
||||||
|
}
|
||||||
|
if http_client is not None:
|
||||||
|
reader_prisma = Prisma(http=http_client, **reader_kwargs)
|
||||||
|
else:
|
||||||
|
reader_prisma = Prisma(**reader_kwargs)
|
||||||
|
reader_wrapper = PrismaWrapper(
|
||||||
|
original_prisma=reader_prisma,
|
||||||
|
iam_token_db_auth=iam_flag,
|
||||||
|
db_url_env_var="DATABASE_URL_READ_REPLICA",
|
||||||
|
iam_endpoint=reader_iam_endpoint,
|
||||||
|
recreate_uses_datasource=True,
|
||||||
|
log_prefix="[reader]",
|
||||||
|
)
|
||||||
|
self.db = RoutingPrismaWrapper(
|
||||||
|
writer=writer_wrapper, reader=reader_wrapper
|
||||||
|
)
|
||||||
|
verbose_proxy_logger.info(
|
||||||
|
"PrismaClient: read-replica routing enabled via DATABASE_URL_READ_REPLICA"
|
||||||
|
+ (" (with IAM token auto-refresh)" if iam_flag else "")
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
# Reader is opt-in; never let its construction fail proxy
|
||||||
|
# startup. Mirrors the runtime contract from
|
||||||
|
# `RoutingPrismaWrapper.connect`: reader-side failures are
|
||||||
|
# logged and we keep serving traffic via the writer alone.
|
||||||
|
# This recovers from transient AWS STS hiccups during the
|
||||||
|
# reader IAM token mint, malformed DATABASE_URL_READ_REPLICA,
|
||||||
|
# and Prisma construction errors. Operator restart is required
|
||||||
|
# to retry read-routing once the underlying issue is resolved.
|
||||||
|
verbose_proxy_logger.warning(
|
||||||
|
"Failed to initialize read replica Prisma client: %s. "
|
||||||
|
"Falling back to writer-only mode (no read routing) until proxy restart.",
|
||||||
|
e,
|
||||||
|
)
|
||||||
|
self.db = writer_wrapper
|
||||||
|
else:
|
||||||
|
self.db = writer_wrapper # Client to connect to Prisma db
|
||||||
self._db_reconnect_lock = asyncio.Lock()
|
self._db_reconnect_lock = asyncio.Lock()
|
||||||
self._db_health_watchdog_task: Optional[asyncio.Task] = None
|
self._db_health_watchdog_task: Optional[asyncio.Task] = None
|
||||||
self._db_last_reconnect_attempt_ts: float = 0.0
|
self._db_last_reconnect_attempt_ts: float = 0.0
|
||||||
@ -2624,6 +2705,13 @@ class PrismaClient:
|
|||||||
self._engine_wait_thread: Optional[threading.Thread] = None
|
self._engine_wait_thread: Optional[threading.Thread] = None
|
||||||
verbose_proxy_logger.debug("Success - Created Prisma Client")
|
verbose_proxy_logger.debug("Success - Created Prisma Client")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def writer_db(self) -> PrismaWrapper:
|
||||||
|
"""Underlying writer Prisma wrapper, regardless of read-replica routing."""
|
||||||
|
if isinstance(self.db, RoutingPrismaWrapper):
|
||||||
|
return self.db.writer
|
||||||
|
return self.db
|
||||||
|
|
||||||
def get_request_status(
|
def get_request_status(
|
||||||
self, payload: Union[dict, SpendLogsPayload]
|
self, payload: Union[dict, SpendLogsPayload]
|
||||||
) -> Literal["success", "failure"]:
|
) -> Literal["success", "failure"]:
|
||||||
@ -4272,7 +4360,10 @@ class PrismaClient:
|
|||||||
self._cleanup_engine_watcher()
|
self._cleanup_engine_watcher()
|
||||||
await self.db.recreate_prisma_client(db_url)
|
await self.db.recreate_prisma_client(db_url)
|
||||||
await self._start_engine_watcher()
|
await self._start_engine_watcher()
|
||||||
await self.db.query_raw("SELECT 1")
|
# Smoke-test the writer specifically; query_raw on the routing
|
||||||
|
# wrapper sends to the reader, which would not validate the
|
||||||
|
# newly-recreated writer engine.
|
||||||
|
await self.writer_db.query_raw("SELECT 1")
|
||||||
|
|
||||||
await asyncio.wait_for(_do_direct_reconnect(), timeout=effective_timeout)
|
await asyncio.wait_for(_do_direct_reconnect(), timeout=effective_timeout)
|
||||||
|
|
||||||
|
|||||||
887
tests/test_litellm/proxy/db/test_routing_prisma_wrapper.py
Normal file
887
tests/test_litellm/proxy/db/test_routing_prisma_wrapper.py
Normal file
@ -0,0 +1,887 @@
|
|||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from typing import Any, Dict
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
sys.path.insert(0, os.path.abspath("../../../.."))
|
||||||
|
|
||||||
|
|
||||||
|
# NOTE: do NOT patch sys.modules["prisma"] file-wide via an autouse fixture.
|
||||||
|
# Doing so leaks across pytest-xdist test scheduling: when a worker runs a
|
||||||
|
# routing test, then later runs test_exception_handler.py, the cached MagicMock
|
||||||
|
# attribute references break `isinstance(e, prisma.errors.X)` in
|
||||||
|
# `is_database_transport_error`. The two tests below that actually need to
|
||||||
|
# stub the prisma SDK do so per-test via monkeypatch, which is properly scoped.
|
||||||
|
|
||||||
|
|
||||||
|
def _make_wrappers():
|
||||||
|
from litellm.proxy.db.prisma_client import PrismaWrapper
|
||||||
|
|
||||||
|
writer_inner = MagicMock(name="writer_prisma")
|
||||||
|
reader_inner = MagicMock(name="reader_prisma")
|
||||||
|
writer = PrismaWrapper(original_prisma=writer_inner, iam_token_db_auth=False)
|
||||||
|
reader = PrismaWrapper(original_prisma=reader_inner, iam_token_db_auth=False)
|
||||||
|
return writer, writer_inner, reader, reader_inner
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeActions:
|
||||||
|
"""Stand-in for a Prisma per-model Actions instance (non-callable, has find_many/create)."""
|
||||||
|
|
||||||
|
def __init__(self, name: str):
|
||||||
|
self._name = name
|
||||||
|
for method in (
|
||||||
|
"find_many",
|
||||||
|
"find_unique",
|
||||||
|
"find_first",
|
||||||
|
"count",
|
||||||
|
"group_by",
|
||||||
|
"create",
|
||||||
|
"update",
|
||||||
|
"upsert",
|
||||||
|
"delete",
|
||||||
|
"delete_many",
|
||||||
|
"update_many",
|
||||||
|
):
|
||||||
|
setattr(self, method, MagicMock(name=f"{name}.{method}"))
|
||||||
|
|
||||||
|
|
||||||
|
def _model_actions_mock(name: str) -> _FakeActions:
|
||||||
|
return _FakeActions(name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_top_level_query_raw_routes_to_reader():
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer, writer_inner, reader, reader_inner = _make_wrappers()
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
|
||||||
|
# query_raw should resolve to the reader's underlying client.
|
||||||
|
assert routing.query_raw is reader_inner.query_raw
|
||||||
|
assert routing.query_first is reader_inner.query_first
|
||||||
|
|
||||||
|
|
||||||
|
def test_top_level_execute_raw_routes_to_writer():
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer, writer_inner, reader, reader_inner = _make_wrappers()
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
|
||||||
|
# execute_raw, batch_, tx are write-side and must hit the writer.
|
||||||
|
assert routing.execute_raw is writer_inner.execute_raw
|
||||||
|
assert routing.batch_ is writer_inner.batch_
|
||||||
|
assert routing.tx is writer_inner.tx
|
||||||
|
|
||||||
|
|
||||||
|
def test_per_model_reads_route_to_reader_writes_to_writer():
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer, writer_inner, reader, reader_inner = _make_wrappers()
|
||||||
|
writer_inner.litellm_usertable = _model_actions_mock("writer_users")
|
||||||
|
reader_inner.litellm_usertable = _model_actions_mock("reader_users")
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
|
||||||
|
actions = routing.litellm_usertable
|
||||||
|
|
||||||
|
# Reads → reader actions.
|
||||||
|
assert actions.find_many is reader_inner.litellm_usertable.find_many
|
||||||
|
assert actions.find_unique is reader_inner.litellm_usertable.find_unique
|
||||||
|
assert actions.find_first is reader_inner.litellm_usertable.find_first
|
||||||
|
assert actions.count is reader_inner.litellm_usertable.count
|
||||||
|
assert actions.group_by is reader_inner.litellm_usertable.group_by
|
||||||
|
|
||||||
|
# Writes → writer actions.
|
||||||
|
assert actions.create is writer_inner.litellm_usertable.create
|
||||||
|
assert actions.update is writer_inner.litellm_usertable.update
|
||||||
|
assert actions.upsert is writer_inner.litellm_usertable.upsert
|
||||||
|
assert actions.delete is writer_inner.litellm_usertable.delete
|
||||||
|
assert actions.update_many is writer_inner.litellm_usertable.update_many
|
||||||
|
assert actions.delete_many is writer_inner.litellm_usertable.delete_many
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_invokes_both_clients():
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer, writer_inner, reader, reader_inner = _make_wrappers()
|
||||||
|
writer_inner.connect = AsyncMock()
|
||||||
|
reader_inner.connect = AsyncMock()
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
|
||||||
|
await routing.connect()
|
||||||
|
|
||||||
|
writer_inner.connect.assert_awaited_once()
|
||||||
|
reader_inner.connect.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_logs_writer_and_reader_success(caplog):
|
||||||
|
"""Successful startup emits a positive INFO confirmation for both writer
|
||||||
|
and reader so operators can verify connectivity without inspecting the URL
|
||||||
|
in logs."""
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer, writer_inner, reader, reader_inner = _make_wrappers()
|
||||||
|
writer_inner.connect = AsyncMock()
|
||||||
|
reader_inner.connect = AsyncMock()
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
|
||||||
|
with caplog.at_level(logging.INFO, logger="LiteLLM Proxy"):
|
||||||
|
await routing.connect()
|
||||||
|
|
||||||
|
messages = [r.getMessage() for r in caplog.records]
|
||||||
|
assert "[writer] DB connected" in messages
|
||||||
|
assert "[reader] DB connected" in messages
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_disconnect_continues_when_one_side_fails():
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer, writer_inner, reader, reader_inner = _make_wrappers()
|
||||||
|
writer_inner.disconnect = AsyncMock(side_effect=RuntimeError("writer down"))
|
||||||
|
reader_inner.disconnect = AsyncMock()
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
|
||||||
|
with pytest.raises(RuntimeError, match="writer down"):
|
||||||
|
await routing.disconnect()
|
||||||
|
|
||||||
|
# Reader still attempted even though writer raised.
|
||||||
|
reader_inner.disconnect.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_connected_reflects_writer_only():
|
||||||
|
"""is_connected() must NOT depend on reader health — a healthy writer with
|
||||||
|
a degraded reader should report True so that PrismaClient.connect()'s
|
||||||
|
health check does not re-trigger a writer reconnect (which only fixes
|
||||||
|
writer-side problems and would loop indefinitely)."""
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer, writer_inner, reader, reader_inner = _make_wrappers()
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
|
||||||
|
writer_inner.is_connected = MagicMock(return_value=True)
|
||||||
|
reader_inner.is_connected = MagicMock(return_value=True)
|
||||||
|
assert routing.is_connected() is True
|
||||||
|
|
||||||
|
# Reader down → still True (reader degradation is tracked separately).
|
||||||
|
reader_inner.is_connected = MagicMock(return_value=False)
|
||||||
|
assert routing.is_connected() is True
|
||||||
|
|
||||||
|
# Writer down → False.
|
||||||
|
writer_inner.is_connected = MagicMock(return_value=False)
|
||||||
|
assert routing.is_connected() is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_refresh_delegates_to_both_writer_and_reader():
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer = MagicMock()
|
||||||
|
writer.start_token_refresh_task = AsyncMock()
|
||||||
|
writer.stop_token_refresh_task = AsyncMock()
|
||||||
|
reader = MagicMock()
|
||||||
|
reader.start_token_refresh_task = AsyncMock()
|
||||||
|
reader.stop_token_refresh_task = AsyncMock()
|
||||||
|
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
|
||||||
|
asyncio.run(routing.start_token_refresh_task())
|
||||||
|
asyncio.run(routing.stop_token_refresh_task())
|
||||||
|
|
||||||
|
# Both wrappers get start/stop — each manages its own IAM token. When
|
||||||
|
# IAM is disabled on a wrapper its task body is a no-op.
|
||||||
|
writer.start_token_refresh_task.assert_awaited_once()
|
||||||
|
writer.stop_token_refresh_task.assert_awaited_once()
|
||||||
|
reader.start_token_refresh_task.assert_awaited_once()
|
||||||
|
reader.stop_token_refresh_task.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_routed_actions_falls_back_to_writer_for_unknown_methods():
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import _RoutedActions
|
||||||
|
|
||||||
|
writer_actions = _model_actions_mock("writer")
|
||||||
|
writer_actions.some_custom_method = "writer-custom"
|
||||||
|
reader_actions = _model_actions_mock("reader")
|
||||||
|
reader_actions.some_custom_method = "reader-custom"
|
||||||
|
|
||||||
|
routed = _RoutedActions(writer_actions, reader_actions, lambda: True)
|
||||||
|
# Unknown method → defaults to writer (safe fallback for write-like ops).
|
||||||
|
assert routed.some_custom_method == "writer-custom"
|
||||||
|
|
||||||
|
|
||||||
|
def test_routed_actions_respects_should_use_reader_flag():
|
||||||
|
"""When the routing wrapper marks the reader unavailable, _RoutedActions
|
||||||
|
must redirect reads to the writer instead — without needing to re-fetch
|
||||||
|
the actions accessor."""
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import _RoutedActions
|
||||||
|
|
||||||
|
writer_actions = _model_actions_mock("writer")
|
||||||
|
reader_actions = _model_actions_mock("reader")
|
||||||
|
|
||||||
|
use_reader = {"value": True}
|
||||||
|
routed = _RoutedActions(writer_actions, reader_actions, lambda: use_reader["value"])
|
||||||
|
|
||||||
|
# Reader healthy → reads to reader.
|
||||||
|
assert routed.find_many is reader_actions.find_many
|
||||||
|
|
||||||
|
# Reader degrades mid-flight → next read goes to writer.
|
||||||
|
use_reader["value"] = False
|
||||||
|
assert routed.find_many is writer_actions.find_many
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Reader graceful degradation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_connect_swallows_reader_failure_and_falls_back_to_writer():
|
||||||
|
"""A reader connect failure must NOT abort proxy startup. The wrapper
|
||||||
|
flips into degraded mode so subsequent reads route to the writer."""
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer, writer_inner, reader, reader_inner = _make_wrappers()
|
||||||
|
writer_inner.connect = AsyncMock()
|
||||||
|
reader_inner.connect = AsyncMock(side_effect=RuntimeError("reader unreachable"))
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
|
||||||
|
# Must not raise — reader failure is non-fatal.
|
||||||
|
await routing.connect()
|
||||||
|
|
||||||
|
assert routing.reader_unavailable is True
|
||||||
|
writer_inner.connect.assert_awaited_once()
|
||||||
|
reader_inner.connect.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reads_route_to_writer_when_reader_unavailable():
|
||||||
|
"""Top-level read methods and per-model reads must fall through to the
|
||||||
|
writer while the reader is degraded."""
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer, writer_inner, reader, reader_inner = _make_wrappers()
|
||||||
|
writer_inner.litellm_usertable = _model_actions_mock("writer_users")
|
||||||
|
reader_inner.litellm_usertable = _model_actions_mock("reader_users")
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
routing._reader_unavailable = True
|
||||||
|
|
||||||
|
# Top-level reads → writer.
|
||||||
|
assert routing.query_raw is writer_inner.query_raw
|
||||||
|
assert routing.query_first is writer_inner.query_first
|
||||||
|
|
||||||
|
# Per-model reads → writer actions.
|
||||||
|
actions = routing.litellm_usertable
|
||||||
|
assert actions.find_many is writer_inner.litellm_usertable.find_many
|
||||||
|
assert actions.find_unique is writer_inner.litellm_usertable.find_unique
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_recreate_prisma_client_recreates_both_writer_and_reader():
|
||||||
|
"""Writer reconnect path calls recreate_prisma_client. The routing wrapper
|
||||||
|
must recreate BOTH clients so a DB-wide event doesn't leave a stale reader."""
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer = MagicMock()
|
||||||
|
writer.recreate_prisma_client = AsyncMock()
|
||||||
|
reader = MagicMock()
|
||||||
|
reader.iam_token_db_auth = False
|
||||||
|
reader.recreate_prisma_client = AsyncMock()
|
||||||
|
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"DATABASE_URL_READ_REPLICA": "reader-url"}):
|
||||||
|
await routing.recreate_prisma_client("writer-url", http_client=None)
|
||||||
|
|
||||||
|
writer.recreate_prisma_client.assert_awaited_once_with(
|
||||||
|
"writer-url", http_client=None
|
||||||
|
)
|
||||||
|
reader.recreate_prisma_client.assert_awaited_once_with(
|
||||||
|
"reader-url", http_client=None
|
||||||
|
)
|
||||||
|
assert routing.reader_unavailable is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_recreate_recovers_reader_after_prior_degradation():
|
||||||
|
"""If a previous connect/recreate degraded the reader, a successful
|
||||||
|
recreate must clear the flag so reads start hitting the reader again."""
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer = MagicMock()
|
||||||
|
writer.recreate_prisma_client = AsyncMock()
|
||||||
|
reader = MagicMock()
|
||||||
|
reader.iam_token_db_auth = False
|
||||||
|
reader.recreate_prisma_client = AsyncMock()
|
||||||
|
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
routing._reader_unavailable = True
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"DATABASE_URL_READ_REPLICA": "reader-url"}):
|
||||||
|
await routing.recreate_prisma_client("writer-url")
|
||||||
|
|
||||||
|
assert routing.reader_unavailable is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_recreate_degrades_reader_if_reader_recreate_fails():
|
||||||
|
"""If the reader recreate fails, writer recreate still succeeds and the
|
||||||
|
routing wrapper degrades (does not raise)."""
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer = MagicMock()
|
||||||
|
writer.recreate_prisma_client = AsyncMock()
|
||||||
|
reader = MagicMock()
|
||||||
|
reader.iam_token_db_auth = False
|
||||||
|
reader.recreate_prisma_client = AsyncMock(
|
||||||
|
side_effect=RuntimeError("reader still down")
|
||||||
|
)
|
||||||
|
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
|
||||||
|
with patch.dict(os.environ, {"DATABASE_URL_READ_REPLICA": "reader-url"}):
|
||||||
|
# Must not raise — writer was recreated, reader is best-effort.
|
||||||
|
await routing.recreate_prisma_client("writer-url")
|
||||||
|
|
||||||
|
writer.recreate_prisma_client.assert_awaited_once()
|
||||||
|
assert routing.reader_unavailable is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_recreate_degrades_reader_when_replica_url_missing():
|
||||||
|
"""Non-IAM reader needs DATABASE_URL_READ_REPLICA. If it's missing
|
||||||
|
(configuration drift), the wrapper degrades instead of raising."""
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer = MagicMock()
|
||||||
|
writer.recreate_prisma_client = AsyncMock()
|
||||||
|
reader = MagicMock()
|
||||||
|
reader.iam_token_db_auth = False
|
||||||
|
reader.recreate_prisma_client = AsyncMock()
|
||||||
|
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
|
||||||
|
# Ensure env var is absent.
|
||||||
|
with patch.dict(os.environ, {}, clear=False):
|
||||||
|
os.environ.pop("DATABASE_URL_READ_REPLICA", None)
|
||||||
|
await routing.recreate_prisma_client("writer-url")
|
||||||
|
|
||||||
|
writer.recreate_prisma_client.assert_awaited_once()
|
||||||
|
reader.recreate_prisma_client.assert_not_awaited()
|
||||||
|
assert routing.reader_unavailable is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_recreate_iam_reader_refreshes_token():
|
||||||
|
"""IAM-enabled readers must refresh their token (reader has its own parsed
|
||||||
|
endpoint) and pass the fresh URL to recreate_prisma_client."""
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer = MagicMock()
|
||||||
|
writer.recreate_prisma_client = AsyncMock()
|
||||||
|
reader = MagicMock()
|
||||||
|
reader.iam_token_db_auth = True
|
||||||
|
reader.get_rds_iam_token = MagicMock(return_value="postgresql://u:fresh@h:5432/db")
|
||||||
|
reader.recreate_prisma_client = AsyncMock()
|
||||||
|
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
await routing.recreate_prisma_client("writer-url")
|
||||||
|
|
||||||
|
reader.get_rds_iam_token.assert_called_once()
|
||||||
|
reader.recreate_prisma_client.assert_awaited_once_with(
|
||||||
|
"postgresql://u:fresh@h:5432/db", http_client=None
|
||||||
|
)
|
||||||
|
assert routing.reader_unavailable is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_recreate_degrades_when_iam_token_generation_returns_none():
|
||||||
|
"""If `get_rds_iam_token` returns None (e.g. AWS-side failure), the wrapper
|
||||||
|
must degrade rather than crash — this exercises the explicit `raise
|
||||||
|
RuntimeError` inside `_recreate_reader`'s IAM branch."""
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer = MagicMock()
|
||||||
|
writer.recreate_prisma_client = AsyncMock()
|
||||||
|
reader = MagicMock()
|
||||||
|
reader.iam_token_db_auth = True
|
||||||
|
reader.get_rds_iam_token = MagicMock(return_value=None)
|
||||||
|
reader.recreate_prisma_client = AsyncMock()
|
||||||
|
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
await routing.recreate_prisma_client("writer-url")
|
||||||
|
|
||||||
|
writer.recreate_prisma_client.assert_awaited_once()
|
||||||
|
reader.recreate_prisma_client.assert_not_awaited()
|
||||||
|
assert routing.reader_unavailable is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_writer_and_reader_properties_expose_underlying_wrappers():
|
||||||
|
"""The `writer` and `reader` properties are used by PrismaClient.writer_db
|
||||||
|
to smoke-test the writer specifically during reconnect — they must return
|
||||||
|
the exact wrappers passed in."""
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
writer, _, reader, _ = _make_wrappers()
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
|
||||||
|
assert routing.writer is writer
|
||||||
|
assert routing.reader is reader
|
||||||
|
|
||||||
|
|
||||||
|
def test_per_model_accessor_falls_back_when_reader_lacks_attr():
|
||||||
|
"""If the reader Prisma client somehow lacks a model accessor that the
|
||||||
|
writer has (older client / partial mock), the wrapper must fall back to
|
||||||
|
the writer accessor instead of raising AttributeError to the caller."""
|
||||||
|
from litellm.proxy.db.prisma_client import PrismaWrapper
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
# Plain class with only the accessor set on the writer side. Using a real
|
||||||
|
# class instead of MagicMock so attribute access raises AttributeError
|
||||||
|
# naturally instead of auto-creating mock attributes.
|
||||||
|
class _PartialPrisma:
|
||||||
|
pass
|
||||||
|
|
||||||
|
writer_inner = _PartialPrisma()
|
||||||
|
writer_inner.litellm_usertable = _model_actions_mock("writer_users")
|
||||||
|
reader_inner = _PartialPrisma() # deliberately missing litellm_usertable
|
||||||
|
|
||||||
|
writer = PrismaWrapper(original_prisma=writer_inner, iam_token_db_auth=False)
|
||||||
|
reader = PrismaWrapper(original_prisma=reader_inner, iam_token_db_auth=False)
|
||||||
|
routing = RoutingPrismaWrapper(writer=writer, reader=reader)
|
||||||
|
|
||||||
|
actions = routing.litellm_usertable
|
||||||
|
# Falls back to the writer's accessor verbatim — not a _RoutedActions wrapper.
|
||||||
|
assert actions is writer_inner.litellm_usertable
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_writer_recreate_passes_http_client_through(monkeypatch):
|
||||||
|
"""When PrismaClient is constructed with an http_client, recreate must
|
||||||
|
forward it to the new Prisma() so connection settings persist across
|
||||||
|
reconnects."""
|
||||||
|
from litellm.proxy.db.prisma_client import PrismaWrapper
|
||||||
|
|
||||||
|
captured_kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
class FakePrisma:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
captured_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
fake_module = MagicMock()
|
||||||
|
fake_module.Prisma = FakePrisma
|
||||||
|
monkeypatch.setitem(sys.modules, "prisma", fake_module)
|
||||||
|
|
||||||
|
writer = PrismaWrapper(original_prisma=MagicMock(), iam_token_db_auth=False)
|
||||||
|
sentinel_http = object()
|
||||||
|
await writer.recreate_prisma_client(
|
||||||
|
"postgresql://u:p@h:5432/db", http_client=sentinel_http
|
||||||
|
)
|
||||||
|
|
||||||
|
assert captured_kwargs == {"http": sentinel_http}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# IAM endpoint parsing + reader IAM refresh
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_iam_endpoint_from_url_extracts_all_fields():
|
||||||
|
from litellm.proxy.db.prisma_client import parse_iam_endpoint_from_url
|
||||||
|
|
||||||
|
ep = parse_iam_endpoint_from_url(
|
||||||
|
"postgresql://litellm_user:initial-token@aurora-reader.example.com:6543/litellm?schema=public"
|
||||||
|
)
|
||||||
|
assert ep.host == "aurora-reader.example.com"
|
||||||
|
assert ep.port == "6543"
|
||||||
|
assert ep.user == "litellm_user"
|
||||||
|
assert ep.name == "litellm"
|
||||||
|
assert ep.schema == "public"
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_iam_endpoint_defaults_port_to_5432_and_skips_schema():
|
||||||
|
from litellm.proxy.db.prisma_client import parse_iam_endpoint_from_url
|
||||||
|
|
||||||
|
ep = parse_iam_endpoint_from_url("postgresql://u@host/dbname")
|
||||||
|
assert ep.host == "host"
|
||||||
|
assert ep.port == "5432"
|
||||||
|
assert ep.user == "u"
|
||||||
|
assert ep.name == "dbname"
|
||||||
|
assert ep.schema is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_iam_endpoint_rejects_url_without_user_or_dbname():
|
||||||
|
from litellm.proxy.db.prisma_client import parse_iam_endpoint_from_url
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="missing host or username"):
|
||||||
|
parse_iam_endpoint_from_url("postgresql://host:5432/db")
|
||||||
|
with pytest.raises(ValueError, match="missing database name"):
|
||||||
|
parse_iam_endpoint_from_url("postgresql://u@host:5432/")
|
||||||
|
|
||||||
|
|
||||||
|
def test_iam_endpoint_build_url_inserts_token_verbatim():
|
||||||
|
from litellm.proxy.db.prisma_client import IAMEndpoint
|
||||||
|
|
||||||
|
# `generate_iam_auth_token` already URL-encodes the presigned token, so
|
||||||
|
# `build_url` must NOT encode again — double-encoding turned `%3D` into
|
||||||
|
# `%253D` and broke RDS auth on the reader path.
|
||||||
|
ep = IAMEndpoint(host="h", port="5432", user="u", name="db", schema="public")
|
||||||
|
pre_encoded_token = "token%2Fwith%3Fweird%26chars%3Dyes"
|
||||||
|
url = ep.build_url(pre_encoded_token)
|
||||||
|
assert url == f"postgresql://u:{pre_encoded_token}@h:5432/db?schema=public"
|
||||||
|
# Sanity check: no `%25` (the encoding of `%`), confirming we didn't re-encode.
|
||||||
|
assert "%25" not in url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_iam_refresh_logs_carry_log_prefix(caplog):
|
||||||
|
"""When `log_prefix` is set on a PrismaWrapper, every IAM-related log
|
||||||
|
line emitted by that wrapper must start with the prefix so writer and
|
||||||
|
reader can be told apart in interleaved output."""
|
||||||
|
from litellm.proxy.db.prisma_client import PrismaWrapper
|
||||||
|
|
||||||
|
wrapper = PrismaWrapper(
|
||||||
|
original_prisma=MagicMock(),
|
||||||
|
iam_token_db_auth=True,
|
||||||
|
log_prefix="[reader]",
|
||||||
|
)
|
||||||
|
|
||||||
|
with caplog.at_level(logging.INFO, logger="LiteLLM Proxy"):
|
||||||
|
await wrapper.start_token_refresh_task()
|
||||||
|
# Loop emits "RDS IAM token refresh loop started..." on first tick.
|
||||||
|
# Cancel immediately so the loop body runs once and we can assert.
|
||||||
|
await wrapper.stop_token_refresh_task()
|
||||||
|
|
||||||
|
messages = [r.getMessage() for r in caplog.records]
|
||||||
|
# Both start and stop notifications carry the prefix.
|
||||||
|
assert any(
|
||||||
|
m.startswith("[reader] Started RDS IAM token proactive refresh")
|
||||||
|
for m in messages
|
||||||
|
)
|
||||||
|
assert any(
|
||||||
|
m.startswith("[reader] Stopped RDS IAM token refresh background task")
|
||||||
|
for m in messages
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_rds_iam_token_returns_none_when_iam_disabled():
|
||||||
|
"""`get_rds_iam_token` short-circuits to None when iam_token_db_auth is
|
||||||
|
False — covers the early-return guard at the top of the method."""
|
||||||
|
from litellm.proxy.db.prisma_client import PrismaWrapper
|
||||||
|
|
||||||
|
wrapper = PrismaWrapper(original_prisma=MagicMock(), iam_token_db_auth=False)
|
||||||
|
assert wrapper.get_rds_iam_token() is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_getattr_does_not_block_inside_running_loop_on_expired_token(monkeypatch):
|
||||||
|
"""When `__getattr__` runs inside a running event loop and the IAM token
|
||||||
|
is expired, it MUST schedule the refresh as a background task and return
|
||||||
|
immediately. The previous `run_coroutine_threadsafe` + `future.result()`
|
||||||
|
pattern deadlocks the loop (loop thread blocks waiting for a coroutine
|
||||||
|
that needs the loop to run) and times out at 30s — exactly what was
|
||||||
|
breaking the reader on first query."""
|
||||||
|
from litellm.proxy.db.prisma_client import PrismaWrapper
|
||||||
|
|
||||||
|
# Stale URL — `is_token_expired` returns True because the password isn't
|
||||||
|
# a parseable IAM token, so we exercise the expired branch.
|
||||||
|
monkeypatch.setenv(
|
||||||
|
"DATABASE_URL_READ_REPLICA",
|
||||||
|
"postgresql://reader:placeholder@reader.aurora.local:5432/litellm",
|
||||||
|
)
|
||||||
|
|
||||||
|
inner = MagicMock()
|
||||||
|
inner.query_raw = MagicMock(name="query_raw_attr")
|
||||||
|
|
||||||
|
wrapper = PrismaWrapper(
|
||||||
|
original_prisma=inner,
|
||||||
|
iam_token_db_auth=True,
|
||||||
|
db_url_env_var="DATABASE_URL_READ_REPLICA",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Replace the heavy refresh coroutine with a no-op AsyncMock so we can
|
||||||
|
# observe whether it was scheduled without actually doing the recreate.
|
||||||
|
refresh_calls = {"count": 0}
|
||||||
|
|
||||||
|
async def fake_refresh():
|
||||||
|
refresh_calls["count"] += 1
|
||||||
|
|
||||||
|
monkeypatch.setattr(wrapper, "_safe_refresh_token", fake_refresh)
|
||||||
|
|
||||||
|
# Direct attribute access from inside this async test runs __getattr__
|
||||||
|
# on the loop thread, exercising the in-loop branch. If the previous
|
||||||
|
# `run_coroutine_threadsafe` + `future.result()` pattern were back, this
|
||||||
|
# line would deadlock the loop and the test would hang (and pytest's
|
||||||
|
# per-test timeout would catch it).
|
||||||
|
attr = wrapper.query_raw
|
||||||
|
# Yield once so the scheduled refresh task gets a chance to run.
|
||||||
|
await asyncio.sleep(0)
|
||||||
|
|
||||||
|
assert attr is inner.query_raw
|
||||||
|
assert refresh_calls["count"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_writer_get_rds_iam_token_defaults_port_when_unset(monkeypatch):
|
||||||
|
"""When DATABASE_PORT is unset, the writer must default to the Postgres
|
||||||
|
standard port instead of passing `None` through. Passing None to
|
||||||
|
`generate_iam_auth_token` makes botocore embed the literal string
|
||||||
|
\"None\" in the presigned URL during signing and crashes with
|
||||||
|
`ValueError: Port could not be cast to integer value as 'None'`."""
|
||||||
|
from litellm.proxy.db.prisma_client import PrismaWrapper
|
||||||
|
|
||||||
|
monkeypatch.setenv("DATABASE_HOST", "writer.aurora.local")
|
||||||
|
monkeypatch.delenv("DATABASE_PORT", raising=False)
|
||||||
|
monkeypatch.setenv("DATABASE_USER", "litellm")
|
||||||
|
monkeypatch.setenv("DATABASE_NAME", "litellm")
|
||||||
|
monkeypatch.delenv("DATABASE_SCHEMA", raising=False)
|
||||||
|
monkeypatch.delenv("DATABASE_URL", raising=False)
|
||||||
|
|
||||||
|
captured: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def fake_generate(db_host=None, db_port=None, db_user=None):
|
||||||
|
captured["port"] = db_port
|
||||||
|
return "TOKEN"
|
||||||
|
|
||||||
|
fake_module = MagicMock()
|
||||||
|
fake_module.generate_iam_auth_token = fake_generate
|
||||||
|
monkeypatch.setitem(sys.modules, "litellm.proxy.auth.rds_iam_token", fake_module)
|
||||||
|
|
||||||
|
writer = PrismaWrapper(
|
||||||
|
original_prisma=MagicMock(),
|
||||||
|
iam_token_db_auth=True,
|
||||||
|
)
|
||||||
|
new_url = writer.get_rds_iam_token()
|
||||||
|
|
||||||
|
assert captured["port"] == "5432" # default applied, NOT None
|
||||||
|
assert ":5432/litellm" in (new_url or "")
|
||||||
|
|
||||||
|
|
||||||
|
def test_writer_get_rds_iam_token_uses_database_host_env_vars(monkeypatch):
|
||||||
|
"""Writer's IAM path (no iam_endpoint configured) reads host/port/user/db
|
||||||
|
from the legacy DATABASE_HOST/PORT/USER/NAME env vars and writes the URL
|
||||||
|
back to DATABASE_URL — this is the pre-read-replica behavior the patch
|
||||||
|
must preserve."""
|
||||||
|
from litellm.proxy.db.prisma_client import PrismaWrapper
|
||||||
|
|
||||||
|
monkeypatch.setenv("DATABASE_HOST", "writer.aurora.local")
|
||||||
|
monkeypatch.setenv("DATABASE_PORT", "5432")
|
||||||
|
monkeypatch.setenv("DATABASE_USER", "litellm")
|
||||||
|
monkeypatch.setenv("DATABASE_NAME", "litellm")
|
||||||
|
monkeypatch.setenv("DATABASE_SCHEMA", "public")
|
||||||
|
monkeypatch.delenv("DATABASE_URL", raising=False)
|
||||||
|
|
||||||
|
captured: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def fake_generate(db_host=None, db_port=None, db_user=None):
|
||||||
|
captured["host"] = db_host
|
||||||
|
captured["port"] = db_port
|
||||||
|
captured["user"] = db_user
|
||||||
|
return "WRITER-TOKEN"
|
||||||
|
|
||||||
|
fake_module = MagicMock()
|
||||||
|
fake_module.generate_iam_auth_token = fake_generate
|
||||||
|
monkeypatch.setitem(sys.modules, "litellm.proxy.auth.rds_iam_token", fake_module)
|
||||||
|
|
||||||
|
writer = PrismaWrapper(
|
||||||
|
original_prisma=MagicMock(),
|
||||||
|
iam_token_db_auth=True,
|
||||||
|
# No iam_endpoint → legacy DATABASE_HOST/etc. path.
|
||||||
|
)
|
||||||
|
new_url = writer.get_rds_iam_token()
|
||||||
|
|
||||||
|
assert captured == {
|
||||||
|
"host": "writer.aurora.local",
|
||||||
|
"port": "5432",
|
||||||
|
"user": "litellm",
|
||||||
|
}
|
||||||
|
assert new_url == (
|
||||||
|
"postgresql://litellm:WRITER-TOKEN@writer.aurora.local:5432/litellm?schema=public"
|
||||||
|
)
|
||||||
|
# Writer updates its own env var (DATABASE_URL by default), not the reader's.
|
||||||
|
assert os.environ["DATABASE_URL"] == new_url
|
||||||
|
|
||||||
|
|
||||||
|
def test_reader_iam_refresh_uses_parsed_endpoint(monkeypatch):
|
||||||
|
"""The reader generates fresh tokens against its parsed endpoint and
|
||||||
|
writes the new URL to DATABASE_URL_READ_REPLICA — not DATABASE_URL."""
|
||||||
|
from litellm.proxy.db.prisma_client import IAMEndpoint, PrismaWrapper
|
||||||
|
|
||||||
|
# Pre-seed env vars so we can prove the reader does NOT touch DATABASE_URL.
|
||||||
|
monkeypatch.setenv("DATABASE_URL", "writer-url-untouched")
|
||||||
|
monkeypatch.setenv("DATABASE_URL_READ_REPLICA", "stale-reader-url")
|
||||||
|
|
||||||
|
captured: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def fake_generate(db_host=None, db_port=None, db_user=None):
|
||||||
|
captured["host"] = db_host
|
||||||
|
captured["port"] = db_port
|
||||||
|
captured["user"] = db_user
|
||||||
|
return "FRESH-TOKEN"
|
||||||
|
|
||||||
|
fake_module = MagicMock()
|
||||||
|
fake_module.generate_iam_auth_token = fake_generate
|
||||||
|
monkeypatch.setitem(sys.modules, "litellm.proxy.auth.rds_iam_token", fake_module)
|
||||||
|
|
||||||
|
endpoint = IAMEndpoint(
|
||||||
|
host="reader.aurora.local",
|
||||||
|
port="5432",
|
||||||
|
user="lit",
|
||||||
|
name="litellm",
|
||||||
|
schema=None,
|
||||||
|
)
|
||||||
|
reader = PrismaWrapper(
|
||||||
|
original_prisma=MagicMock(),
|
||||||
|
iam_token_db_auth=True,
|
||||||
|
db_url_env_var="DATABASE_URL_READ_REPLICA",
|
||||||
|
iam_endpoint=endpoint,
|
||||||
|
recreate_uses_datasource=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
new_url = reader.get_rds_iam_token()
|
||||||
|
|
||||||
|
# IAM token generator was called with the reader's parsed endpoint, not
|
||||||
|
# the writer's DATABASE_HOST/PORT/USER env vars.
|
||||||
|
assert captured == {
|
||||||
|
"host": "reader.aurora.local",
|
||||||
|
"port": "5432",
|
||||||
|
"user": "lit",
|
||||||
|
}
|
||||||
|
assert new_url is not None
|
||||||
|
assert new_url.startswith(
|
||||||
|
"postgresql://lit:FRESH-TOKEN@reader.aurora.local:5432/litellm"
|
||||||
|
)
|
||||||
|
# The reader updates its OWN env var; writer's DATABASE_URL is left alone.
|
||||||
|
assert os.environ["DATABASE_URL_READ_REPLICA"] == new_url
|
||||||
|
assert os.environ["DATABASE_URL"] == "writer-url-untouched"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reader_recreate_uses_datasource_override(monkeypatch):
|
||||||
|
"""Reader recreate must pass `datasource={"url": ...}` to Prisma() — Prisma
|
||||||
|
only auto-reads DATABASE_URL, so without the override the new reader URL
|
||||||
|
would be silently ignored."""
|
||||||
|
from litellm.proxy.db.prisma_client import IAMEndpoint, PrismaWrapper
|
||||||
|
|
||||||
|
captured_kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
class FakePrisma:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
captured_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
fake_module = MagicMock()
|
||||||
|
fake_module.Prisma = FakePrisma
|
||||||
|
monkeypatch.setitem(sys.modules, "prisma", fake_module)
|
||||||
|
|
||||||
|
reader = PrismaWrapper(
|
||||||
|
original_prisma=MagicMock(),
|
||||||
|
iam_token_db_auth=True,
|
||||||
|
db_url_env_var="DATABASE_URL_READ_REPLICA",
|
||||||
|
iam_endpoint=IAMEndpoint(host="h", port="5432", user="u", name="db"),
|
||||||
|
recreate_uses_datasource=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await reader.recreate_prisma_client(
|
||||||
|
"postgresql://u:newtoken@h:5432/db", http_client=None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert captured_kwargs == {
|
||||||
|
"datasource": {"url": "postgresql://u:newtoken@h:5432/db"}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_writer_recreate_does_not_use_datasource(monkeypatch):
|
||||||
|
"""Writer keeps relying on Prisma reading DATABASE_URL from env — datasource
|
||||||
|
override must NOT leak into the writer path (would override the freshly
|
||||||
|
rotated env var)."""
|
||||||
|
from litellm.proxy.db.prisma_client import PrismaWrapper
|
||||||
|
|
||||||
|
captured_kwargs: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
class FakePrisma:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
captured_kwargs.update(kwargs)
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
fake_module = MagicMock()
|
||||||
|
fake_module.Prisma = FakePrisma
|
||||||
|
monkeypatch.setitem(sys.modules, "prisma", fake_module)
|
||||||
|
|
||||||
|
writer = PrismaWrapper(
|
||||||
|
original_prisma=MagicMock(),
|
||||||
|
iam_token_db_auth=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
await writer.recreate_prisma_client(
|
||||||
|
"postgresql://u:newtoken@h:5432/db", http_client=None
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "datasource" not in captured_kwargs
|
||||||
|
|
||||||
|
|
||||||
|
def test_prisma_client_init_falls_back_to_writer_when_reader_iam_token_fails(
|
||||||
|
monkeypatch, caplog
|
||||||
|
):
|
||||||
|
"""A transient AWS STS error (or any other failure) during the reader
|
||||||
|
IAM token mint must NOT abort proxy startup. The reader is opt-in, so
|
||||||
|
`PrismaClient.__init__` should log a warning and fall back to the
|
||||||
|
writer-only `PrismaWrapper`. The runtime contract in
|
||||||
|
`RoutingPrismaWrapper.connect` already says reader-side failures are
|
||||||
|
non-fatal — but that code never runs if construction throws first."""
|
||||||
|
from litellm.proxy.db.prisma_client import PrismaWrapper
|
||||||
|
from litellm.proxy.db.routing_prisma_wrapper import RoutingPrismaWrapper
|
||||||
|
|
||||||
|
monkeypatch.setenv("IAM_TOKEN_DB_AUTH", "true")
|
||||||
|
monkeypatch.setenv(
|
||||||
|
"DATABASE_URL_READ_REPLICA",
|
||||||
|
"postgresql://reader_user@reader.aurora.local:5432/litellm",
|
||||||
|
)
|
||||||
|
|
||||||
|
class FakePrisma:
|
||||||
|
def __init__(self, **kwargs):
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
fake_prisma_module = MagicMock()
|
||||||
|
fake_prisma_module.Prisma = FakePrisma
|
||||||
|
monkeypatch.setitem(sys.modules, "prisma", fake_prisma_module)
|
||||||
|
|
||||||
|
fake_iam_module = MagicMock()
|
||||||
|
|
||||||
|
def boom(**_kwargs):
|
||||||
|
raise RuntimeError("simulated AWS STS hiccup")
|
||||||
|
|
||||||
|
fake_iam_module.generate_iam_auth_token = boom
|
||||||
|
monkeypatch.setitem(
|
||||||
|
sys.modules, "litellm.proxy.auth.rds_iam_token", fake_iam_module
|
||||||
|
)
|
||||||
|
|
||||||
|
from litellm.proxy.utils import PrismaClient
|
||||||
|
|
||||||
|
with caplog.at_level(logging.WARNING, logger="LiteLLM Proxy"):
|
||||||
|
client = PrismaClient(
|
||||||
|
database_url="postgresql://writer@writer.aurora.local:5432/litellm",
|
||||||
|
proxy_logging_obj=MagicMock(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Construction did not raise, and the proxy is in writer-only mode —
|
||||||
|
# NOT a RoutingPrismaWrapper, so reads will go to the writer.
|
||||||
|
assert isinstance(client.db, PrismaWrapper)
|
||||||
|
assert not isinstance(client.db, RoutingPrismaWrapper)
|
||||||
|
# And the operator gets a clear warning.
|
||||||
|
assert any(
|
||||||
|
"Failed to initialize read replica Prisma client" in r.getMessage()
|
||||||
|
for r in caplog.records
|
||||||
|
)
|
||||||
Loading…
Reference in New Issue
Block a user