Merge pull request #25732 from harish876/health-check-oom

Optimize database query to prevent OOM errors during health checks
This commit is contained in:
ishaan-berri 2026-04-15 18:13:11 -07:00 committed by GitHub
commit cb8fc480e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 246 additions and 34 deletions

View File

@ -0,0 +1,12 @@
-- CreateIndex (CONCURRENTLY)
--
-- Disclaimer:
-- - CREATE INDEX CONCURRENTLY cannot run inside a transaction. This migration must stay a
-- single statement so Prisma Migrate on PostgreSQL can apply it outside a transaction.
-- - Builds are slower and use more I/O than a blocking CREATE INDEX; if the build is
-- interrupted, Postgres may leave an INVALID index that must be dropped and recreated.
-- - Do not edit this file after it has been applied to any database: Prisma checksums
-- migrations; add a new migration instead.
-- - Requires PostgreSQL that supports CONCURRENTLY with IF NOT EXISTS (use a new migration
-- without IF NOT EXISTS if you must support older versions).
CREATE INDEX CONCURRENTLY IF NOT EXISTS "LiteLLM_HealthCheckTable_model_id_model_name_checked_at_idx" ON "LiteLLM_HealthCheckTable"("model_id", "model_name", "checked_at" DESC);

View File

@ -1045,6 +1045,7 @@ model LiteLLM_HealthCheckTable {
@@index([model_name])
@@index([checked_at])
@@index([status])
@@index([model_id, model_name, checked_at(sort: Desc)], map: "LiteLLM_HealthCheckTable_model_id_model_name_checked_at_idx")
}
// Search Tools table for storing search tool configurations

View File

@ -84,7 +84,7 @@ class SharedHealthCheckManager:
"Pod %s failed to acquire health check lock", self.pod_id
)
return acquired
return bool(acquired)
except Exception as e:
verbose_proxy_logger.error("Error acquiring health check lock: %s", str(e))
return False

View File

@ -1045,6 +1045,7 @@ model LiteLLM_HealthCheckTable {
@@index([model_name])
@@index([checked_at])
@@index([status])
@@index([model_id, model_name, checked_at(sort: Desc)], map: "LiteLLM_HealthCheckTable_model_id_model_name_checked_at_idx")
}
// Search Tools table for storing search tool configurations

View File

@ -4525,29 +4525,20 @@ class PrismaClient:
async def get_all_latest_health_checks(self):
"""
Get the latest health check for each model
Get the latest health check for each model.
Uses DB-level DISTINCT ON (model_id, model_name) with ORDER BY checked_at DESC
(via Prisma ``distinct`` + ``order``) so we never load the full history into memory.
"""
try:
# Get all unique model names first
all_checks = await self.db.litellm_healthchecktable.find_many(
order={"checked_at": "desc"}
return await self.db.litellm_healthchecktable.find_many(
distinct=["model_id", "model_name"],
order=[
{"model_id": "asc"},
{"model_name": "asc"},
{"checked_at": "desc"},
],
)
# Group by model_name and get the latest for each
latest_checks = {}
for check in all_checks:
# Create a unique key: prefer model_id if available, otherwise use model_name
# This ensures we get the latest check for each unique model
if check.model_id:
key = (check.model_id, check.model_name)
else:
key = (None, check.model_name)
# Only add if we haven't seen this key yet (since checks are ordered by checked_at desc)
if key not in latest_checks:
latest_checks[key] = check
return list(latest_checks.values())
except Exception as e:
verbose_proxy_logger.error(f"Error getting all latest health checks: {e}")
return []

View File

@ -1045,6 +1045,7 @@ model LiteLLM_HealthCheckTable {
@@index([model_name])
@@index([checked_at])
@@index([status])
@@index([model_id, model_name, checked_at(sort: Desc)], map: "LiteLLM_HealthCheckTable_model_id_model_name_checked_at_idx")
}
// Search Tools table for storing search tool configurations

View File

@ -0,0 +1,182 @@
#!/usr/bin/env python3
"""
Bench LiteLLM_HealthCheckTable + PrismaClient
- set DATABASE_URL to your Postgres
- Run ```prisma generate``` to install prisma client before running test )
- This test writes to the default "public" database. Make sure to run cleanup after testing
"""
from __future__ import annotations
import argparse
import asyncio
import gc
import os
import sys
import time
import tracemalloc
from datetime import datetime, timedelta, timezone
from typing import Any, List
SEED_MARKER = "benchmark_get_all_latest_health_checks.py" # Utility Marker for cleanup process.
def _rss_kb_linux() -> int:
try:
with open("/proc/self/status", encoding="utf-8") as f:
for line in f:
if line.startswith("VmRSS:"):
return int(line.split()[1])
except OSError:
pass
return 0
def _fmt_kb(kb: int) -> str:
if kb <= 0:
return "n/a"
return f"{kb} KiB (~{kb / 1024.0:.1f} MiB)"
def _build_batch(
*,
batch_index: int,
batch_size: int,
num_models: int,
base_time: datetime,
) -> List[dict[str, Any]]:
rows: List[dict[str, Any]] = []
for i in range(batch_size):
global_i = batch_index * batch_size + i
model_idx = global_i % max(num_models, 1)
model_name = f"bench-model-{model_idx}"
model_id = f"bench-mid-{model_idx}" if model_idx % 2 == 0 else None
checked_at = base_time - timedelta(seconds=global_i)
rows.append(
{
"model_name": model_name,
"model_id": model_id,
"status": "healthy" if global_i % 3 else "unhealthy",
"healthy_count": 1,
"unhealthy_count": 0,
"checked_by": SEED_MARKER,
"checked_at": checked_at,
}
)
return rows
async def _seed(
prisma: Any,
*,
total_rows: int,
batch_size: int,
num_models: int,
) -> None:
db = prisma.db
base_time = datetime.now(timezone.utc)
inserted = 0
batch_idx = 0
while inserted < total_rows:
n = min(batch_size, total_rows - inserted)
await db.litellm_healthchecktable.create_many(
data=_build_batch(
batch_index=batch_idx,
batch_size=n,
num_models=num_models,
base_time=base_time,
)
)
inserted += n
batch_idx += 1
if batch_idx % 10 == 0:
print(f" {inserted}/{total_rows}", flush=True)
print(f"Seeded {inserted} rows ({SEED_MARKER}).")
async def _cleanup(prisma: Any) -> None:
result = await prisma.db.litellm_healthchecktable.delete_many(
where={"checked_by": SEED_MARKER},
)
n = getattr(result, "count", result)
print(f"Deleted {n} rows.")
async def _bench(prisma: Any) -> None:
gc.collect()
rss0 = _rss_kb_linux()
print(f"RSS (after gc): {_fmt_kb(rss0)}")
tracemalloc.start()
t0 = time.perf_counter()
try:
rows = await prisma.get_all_latest_health_checks()
finally:
elapsed = time.perf_counter() - t0
_, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
gc.collect()
rss1 = _rss_kb_linux()
print(f"get_all_latest_health_checks: {len(rows)} rows in {elapsed:.2f}s")
print(f"tracemalloc peak: {peak / 1e6:.2f} MiB")
print(f"RSS after: {_fmt_kb(rss1)}")
async def _amain() -> int:
p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
p.add_argument("action", choices=("seed", "bench", "cleanup"))
p.add_argument("--rows", type=int, default=10_000)
p.add_argument("--batch-size", type=int, default=1000)
p.add_argument("--num-models", type=int, default=50)
args = p.parse_args()
database_url = os.getenv("DATABASE_URL")
if not database_url:
print("Set DATABASE_URL.", file=sys.stderr)
return 1
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
if repo_root not in sys.path:
sys.path.insert(0, repo_root)
from litellm.caching.caching import DualCache
from litellm.proxy.proxy_cli import append_query_params
from litellm.proxy.utils import PrismaClient, ProxyLogging
db_url = append_query_params(
database_url, {"connection_limit": 100, "pool_timeout": 60}
)
prisma = PrismaClient(
database_url=db_url,
proxy_logging_obj=ProxyLogging(user_api_key_cache=DualCache()),
)
try:
await prisma.connect()
except Exception as e:
print(f"Connect failed: {e}", file=sys.stderr)
return 1
try:
if args.action == "seed":
await _seed(
prisma,
total_rows=args.rows,
batch_size=args.batch_size,
num_models=args.num_models,
)
elif args.action == "bench":
await _bench(prisma)
else:
await _cleanup(prisma)
finally:
try:
await prisma.disconnect()
except Exception:
pass
return 0
if __name__ == "__main__":
raise SystemExit(asyncio.run(_amain()))

View File

@ -406,12 +406,6 @@ async def test_save_background_health_checks_to_db_exception_handling():
@pytest.mark.asyncio
async def test_get_all_latest_health_checks_with_model_id(mock_prisma):
"""Test get_all_latest_health_checks properly groups by model_id"""
# Create mock checks with same model_name but different model_id
mock_check1 = MagicMock()
mock_check1.model_id = "model-123"
mock_check1.model_name = "gpt-3.5-turbo"
mock_check1.checked_at = datetime.now(timezone.utc) - timedelta(minutes=10)
mock_check2 = MagicMock()
mock_check2.model_id = "model-456"
mock_check2.model_name = "gpt-3.5-turbo"
@ -424,7 +418,7 @@ async def test_get_all_latest_health_checks_with_model_id(mock_prisma):
# Order by checked_at desc
mock_prisma.db.litellm_healthchecktable.find_many = AsyncMock(
return_value=[mock_check3, mock_check2, mock_check1]
return_value=[mock_check3, mock_check2]
)
result = await mock_prisma.get_all_latest_health_checks()
@ -445,18 +439,13 @@ async def test_get_all_latest_health_checks_with_model_id(mock_prisma):
@pytest.mark.asyncio
async def test_get_all_latest_health_checks_without_model_id(mock_prisma):
"""Test get_all_latest_health_checks groups by model_name when model_id is None"""
mock_check1 = MagicMock()
mock_check1.model_id = None
mock_check1.model_name = "gpt-3.5-turbo"
mock_check1.checked_at = datetime.now(timezone.utc) - timedelta(minutes=10)
mock_check2 = MagicMock()
mock_check2.model_id = None
mock_check2.model_name = "gpt-3.5-turbo"
mock_check2.checked_at = datetime.now(timezone.utc) - timedelta(minutes=1) # Latest
mock_prisma.db.litellm_healthchecktable.find_many = AsyncMock(
return_value=[mock_check2, mock_check1]
return_value=[mock_check2]
)
result = await mock_prisma.get_all_latest_health_checks()
@ -467,6 +456,41 @@ async def test_get_all_latest_health_checks_without_model_id(mock_prisma):
assert result[0].checked_at == mock_check2.checked_at # Latest
@pytest.mark.asyncio
async def test_get_all_latest_health_checks_same_name_with_and_without_model_id(mock_prisma):
"""
Same model_name can appear twice after DISTINCT ON: once keyed by (model_id, name)
and once by (NULL, name) different Postgres groups than a single row with id.
"""
now = datetime.now(timezone.utc)
with_id = MagicMock()
with_id.model_id = "deployment-abc"
with_id.model_name = "gpt-4"
with_id.checked_at = now - timedelta(minutes=2)
without_id = MagicMock()
without_id.model_id = None
without_id.model_name = "gpt-4"
without_id.checked_at = now - timedelta(minutes=1)
mock_prisma.db.litellm_healthchecktable.find_many = AsyncMock(
return_value=[without_id, with_id]
)
result = await mock_prisma.get_all_latest_health_checks()
assert len(result) == 2
names = {r.model_name for r in result}
assert names == {"gpt-4"}
ids = {r.model_id for r in result}
assert "deployment-abc" in ids
assert None in ids
by_key = {(r.model_id, r.model_name): r for r in result}
assert by_key[("deployment-abc", "gpt-4")].checked_at == with_id.checked_at
assert by_key[(None, "gpt-4")].checked_at == without_id.checked_at
@pytest.mark.asyncio
async def test_perform_health_check_and_save_passes_model_id_to_perform_health_check():
"""Test that _perform_health_check_and_save passes model_id to perform_health_check so health checks run by model id."""