Merge pull request #25732 from harish876/health-check-oom
Optimize database query to prevent OOM errors during health checks
This commit is contained in:
commit
cb8fc480e6
@ -0,0 +1,12 @@
|
||||
-- CreateIndex (CONCURRENTLY)
|
||||
--
|
||||
-- Disclaimer:
|
||||
-- - CREATE INDEX CONCURRENTLY cannot run inside a transaction. This migration must stay a
|
||||
-- single statement so Prisma Migrate on PostgreSQL can apply it outside a transaction.
|
||||
-- - Builds are slower and use more I/O than a blocking CREATE INDEX; if the build is
|
||||
-- interrupted, Postgres may leave an INVALID index that must be dropped and recreated.
|
||||
-- - Do not edit this file after it has been applied to any database: Prisma checksums
|
||||
-- migrations; add a new migration instead.
|
||||
-- - Requires PostgreSQL that supports CONCURRENTLY with IF NOT EXISTS (use a new migration
|
||||
-- without IF NOT EXISTS if you must support older versions).
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS "LiteLLM_HealthCheckTable_model_id_model_name_checked_at_idx" ON "LiteLLM_HealthCheckTable"("model_id", "model_name", "checked_at" DESC);
|
||||
@ -1045,6 +1045,7 @@ model LiteLLM_HealthCheckTable {
|
||||
@@index([model_name])
|
||||
@@index([checked_at])
|
||||
@@index([status])
|
||||
@@index([model_id, model_name, checked_at(sort: Desc)], map: "LiteLLM_HealthCheckTable_model_id_model_name_checked_at_idx")
|
||||
}
|
||||
|
||||
// Search Tools table for storing search tool configurations
|
||||
|
||||
@ -84,7 +84,7 @@ class SharedHealthCheckManager:
|
||||
"Pod %s failed to acquire health check lock", self.pod_id
|
||||
)
|
||||
|
||||
return acquired
|
||||
return bool(acquired)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error("Error acquiring health check lock: %s", str(e))
|
||||
return False
|
||||
|
||||
@ -1045,6 +1045,7 @@ model LiteLLM_HealthCheckTable {
|
||||
@@index([model_name])
|
||||
@@index([checked_at])
|
||||
@@index([status])
|
||||
@@index([model_id, model_name, checked_at(sort: Desc)], map: "LiteLLM_HealthCheckTable_model_id_model_name_checked_at_idx")
|
||||
}
|
||||
|
||||
// Search Tools table for storing search tool configurations
|
||||
|
||||
@ -4525,29 +4525,20 @@ class PrismaClient:
|
||||
|
||||
async def get_all_latest_health_checks(self):
|
||||
"""
|
||||
Get the latest health check for each model
|
||||
Get the latest health check for each model.
|
||||
|
||||
Uses DB-level DISTINCT ON (model_id, model_name) with ORDER BY checked_at DESC
|
||||
(via Prisma ``distinct`` + ``order``) so we never load the full history into memory.
|
||||
"""
|
||||
try:
|
||||
# Get all unique model names first
|
||||
all_checks = await self.db.litellm_healthchecktable.find_many(
|
||||
order={"checked_at": "desc"}
|
||||
return await self.db.litellm_healthchecktable.find_many(
|
||||
distinct=["model_id", "model_name"],
|
||||
order=[
|
||||
{"model_id": "asc"},
|
||||
{"model_name": "asc"},
|
||||
{"checked_at": "desc"},
|
||||
],
|
||||
)
|
||||
|
||||
# Group by model_name and get the latest for each
|
||||
latest_checks = {}
|
||||
for check in all_checks:
|
||||
# Create a unique key: prefer model_id if available, otherwise use model_name
|
||||
# This ensures we get the latest check for each unique model
|
||||
if check.model_id:
|
||||
key = (check.model_id, check.model_name)
|
||||
else:
|
||||
key = (None, check.model_name)
|
||||
|
||||
# Only add if we haven't seen this key yet (since checks are ordered by checked_at desc)
|
||||
if key not in latest_checks:
|
||||
latest_checks[key] = check
|
||||
|
||||
return list(latest_checks.values())
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(f"Error getting all latest health checks: {e}")
|
||||
return []
|
||||
|
||||
@ -1045,6 +1045,7 @@ model LiteLLM_HealthCheckTable {
|
||||
@@index([model_name])
|
||||
@@index([checked_at])
|
||||
@@index([status])
|
||||
@@index([model_id, model_name, checked_at(sort: Desc)], map: "LiteLLM_HealthCheckTable_model_id_model_name_checked_at_idx")
|
||||
}
|
||||
|
||||
// Search Tools table for storing search tool configurations
|
||||
|
||||
182
scripts/health_check/benchmark_get_all_latest_health_checks.py
Normal file
182
scripts/health_check/benchmark_get_all_latest_health_checks.py
Normal file
@ -0,0 +1,182 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Bench LiteLLM_HealthCheckTable + PrismaClient
|
||||
- set DATABASE_URL to your Postgres
|
||||
- Run ```prisma generate``` to install prisma client before running test )
|
||||
- This test writes to the default "public" database. Make sure to run cleanup after testing
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import gc
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import tracemalloc
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, List
|
||||
|
||||
SEED_MARKER = "benchmark_get_all_latest_health_checks.py" # Utility Marker for cleanup process.
|
||||
|
||||
|
||||
def _rss_kb_linux() -> int:
|
||||
try:
|
||||
with open("/proc/self/status", encoding="utf-8") as f:
|
||||
for line in f:
|
||||
if line.startswith("VmRSS:"):
|
||||
return int(line.split()[1])
|
||||
except OSError:
|
||||
pass
|
||||
return 0
|
||||
|
||||
|
||||
def _fmt_kb(kb: int) -> str:
|
||||
if kb <= 0:
|
||||
return "n/a"
|
||||
return f"{kb} KiB (~{kb / 1024.0:.1f} MiB)"
|
||||
|
||||
|
||||
def _build_batch(
|
||||
*,
|
||||
batch_index: int,
|
||||
batch_size: int,
|
||||
num_models: int,
|
||||
base_time: datetime,
|
||||
) -> List[dict[str, Any]]:
|
||||
rows: List[dict[str, Any]] = []
|
||||
for i in range(batch_size):
|
||||
global_i = batch_index * batch_size + i
|
||||
model_idx = global_i % max(num_models, 1)
|
||||
model_name = f"bench-model-{model_idx}"
|
||||
model_id = f"bench-mid-{model_idx}" if model_idx % 2 == 0 else None
|
||||
checked_at = base_time - timedelta(seconds=global_i)
|
||||
rows.append(
|
||||
{
|
||||
"model_name": model_name,
|
||||
"model_id": model_id,
|
||||
"status": "healthy" if global_i % 3 else "unhealthy",
|
||||
"healthy_count": 1,
|
||||
"unhealthy_count": 0,
|
||||
"checked_by": SEED_MARKER,
|
||||
"checked_at": checked_at,
|
||||
}
|
||||
)
|
||||
return rows
|
||||
|
||||
|
||||
async def _seed(
|
||||
prisma: Any,
|
||||
*,
|
||||
total_rows: int,
|
||||
batch_size: int,
|
||||
num_models: int,
|
||||
) -> None:
|
||||
db = prisma.db
|
||||
base_time = datetime.now(timezone.utc)
|
||||
inserted = 0
|
||||
batch_idx = 0
|
||||
while inserted < total_rows:
|
||||
n = min(batch_size, total_rows - inserted)
|
||||
await db.litellm_healthchecktable.create_many(
|
||||
data=_build_batch(
|
||||
batch_index=batch_idx,
|
||||
batch_size=n,
|
||||
num_models=num_models,
|
||||
base_time=base_time,
|
||||
)
|
||||
)
|
||||
inserted += n
|
||||
batch_idx += 1
|
||||
if batch_idx % 10 == 0:
|
||||
print(f" {inserted}/{total_rows}", flush=True)
|
||||
print(f"Seeded {inserted} rows ({SEED_MARKER}).")
|
||||
|
||||
|
||||
async def _cleanup(prisma: Any) -> None:
|
||||
result = await prisma.db.litellm_healthchecktable.delete_many(
|
||||
where={"checked_by": SEED_MARKER},
|
||||
)
|
||||
n = getattr(result, "count", result)
|
||||
print(f"Deleted {n} rows.")
|
||||
|
||||
|
||||
async def _bench(prisma: Any) -> None:
|
||||
gc.collect()
|
||||
rss0 = _rss_kb_linux()
|
||||
print(f"RSS (after gc): {_fmt_kb(rss0)}")
|
||||
|
||||
tracemalloc.start()
|
||||
t0 = time.perf_counter()
|
||||
try:
|
||||
rows = await prisma.get_all_latest_health_checks()
|
||||
finally:
|
||||
elapsed = time.perf_counter() - t0
|
||||
_, peak = tracemalloc.get_traced_memory()
|
||||
tracemalloc.stop()
|
||||
|
||||
gc.collect()
|
||||
rss1 = _rss_kb_linux()
|
||||
print(f"get_all_latest_health_checks: {len(rows)} rows in {elapsed:.2f}s")
|
||||
print(f"tracemalloc peak: {peak / 1e6:.2f} MiB")
|
||||
print(f"RSS after: {_fmt_kb(rss1)}")
|
||||
|
||||
|
||||
async def _amain() -> int:
|
||||
p = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter)
|
||||
p.add_argument("action", choices=("seed", "bench", "cleanup"))
|
||||
p.add_argument("--rows", type=int, default=10_000)
|
||||
p.add_argument("--batch-size", type=int, default=1000)
|
||||
p.add_argument("--num-models", type=int, default=50)
|
||||
args = p.parse_args()
|
||||
|
||||
database_url = os.getenv("DATABASE_URL")
|
||||
if not database_url:
|
||||
print("Set DATABASE_URL.", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
repo_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
if repo_root not in sys.path:
|
||||
sys.path.insert(0, repo_root)
|
||||
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.proxy.proxy_cli import append_query_params
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging
|
||||
|
||||
db_url = append_query_params(
|
||||
database_url, {"connection_limit": 100, "pool_timeout": 60}
|
||||
)
|
||||
prisma = PrismaClient(
|
||||
database_url=db_url,
|
||||
proxy_logging_obj=ProxyLogging(user_api_key_cache=DualCache()),
|
||||
)
|
||||
try:
|
||||
await prisma.connect()
|
||||
except Exception as e:
|
||||
print(f"Connect failed: {e}", file=sys.stderr)
|
||||
return 1
|
||||
|
||||
try:
|
||||
if args.action == "seed":
|
||||
await _seed(
|
||||
prisma,
|
||||
total_rows=args.rows,
|
||||
batch_size=args.batch_size,
|
||||
num_models=args.num_models,
|
||||
)
|
||||
elif args.action == "bench":
|
||||
await _bench(prisma)
|
||||
else:
|
||||
await _cleanup(prisma)
|
||||
finally:
|
||||
try:
|
||||
await prisma.disconnect()
|
||||
except Exception:
|
||||
pass
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(asyncio.run(_amain()))
|
||||
@ -406,12 +406,6 @@ async def test_save_background_health_checks_to_db_exception_handling():
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_latest_health_checks_with_model_id(mock_prisma):
|
||||
"""Test get_all_latest_health_checks properly groups by model_id"""
|
||||
# Create mock checks with same model_name but different model_id
|
||||
mock_check1 = MagicMock()
|
||||
mock_check1.model_id = "model-123"
|
||||
mock_check1.model_name = "gpt-3.5-turbo"
|
||||
mock_check1.checked_at = datetime.now(timezone.utc) - timedelta(minutes=10)
|
||||
|
||||
mock_check2 = MagicMock()
|
||||
mock_check2.model_id = "model-456"
|
||||
mock_check2.model_name = "gpt-3.5-turbo"
|
||||
@ -424,7 +418,7 @@ async def test_get_all_latest_health_checks_with_model_id(mock_prisma):
|
||||
|
||||
# Order by checked_at desc
|
||||
mock_prisma.db.litellm_healthchecktable.find_many = AsyncMock(
|
||||
return_value=[mock_check3, mock_check2, mock_check1]
|
||||
return_value=[mock_check3, mock_check2]
|
||||
)
|
||||
|
||||
result = await mock_prisma.get_all_latest_health_checks()
|
||||
@ -445,18 +439,13 @@ async def test_get_all_latest_health_checks_with_model_id(mock_prisma):
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_latest_health_checks_without_model_id(mock_prisma):
|
||||
"""Test get_all_latest_health_checks groups by model_name when model_id is None"""
|
||||
mock_check1 = MagicMock()
|
||||
mock_check1.model_id = None
|
||||
mock_check1.model_name = "gpt-3.5-turbo"
|
||||
mock_check1.checked_at = datetime.now(timezone.utc) - timedelta(minutes=10)
|
||||
|
||||
mock_check2 = MagicMock()
|
||||
mock_check2.model_id = None
|
||||
mock_check2.model_name = "gpt-3.5-turbo"
|
||||
mock_check2.checked_at = datetime.now(timezone.utc) - timedelta(minutes=1) # Latest
|
||||
|
||||
mock_prisma.db.litellm_healthchecktable.find_many = AsyncMock(
|
||||
return_value=[mock_check2, mock_check1]
|
||||
return_value=[mock_check2]
|
||||
)
|
||||
|
||||
result = await mock_prisma.get_all_latest_health_checks()
|
||||
@ -467,6 +456,41 @@ async def test_get_all_latest_health_checks_without_model_id(mock_prisma):
|
||||
assert result[0].checked_at == mock_check2.checked_at # Latest
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_all_latest_health_checks_same_name_with_and_without_model_id(mock_prisma):
|
||||
"""
|
||||
Same model_name can appear twice after DISTINCT ON: once keyed by (model_id, name)
|
||||
and once by (NULL, name) — different Postgres groups than a single row with id.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
with_id = MagicMock()
|
||||
with_id.model_id = "deployment-abc"
|
||||
with_id.model_name = "gpt-4"
|
||||
with_id.checked_at = now - timedelta(minutes=2)
|
||||
|
||||
without_id = MagicMock()
|
||||
without_id.model_id = None
|
||||
without_id.model_name = "gpt-4"
|
||||
without_id.checked_at = now - timedelta(minutes=1)
|
||||
|
||||
mock_prisma.db.litellm_healthchecktable.find_many = AsyncMock(
|
||||
return_value=[without_id, with_id]
|
||||
)
|
||||
|
||||
result = await mock_prisma.get_all_latest_health_checks()
|
||||
|
||||
assert len(result) == 2
|
||||
names = {r.model_name for r in result}
|
||||
assert names == {"gpt-4"}
|
||||
ids = {r.model_id for r in result}
|
||||
assert "deployment-abc" in ids
|
||||
assert None in ids
|
||||
|
||||
by_key = {(r.model_id, r.model_name): r for r in result}
|
||||
assert by_key[("deployment-abc", "gpt-4")].checked_at == with_id.checked_at
|
||||
assert by_key[(None, "gpt-4")].checked_at == without_id.checked_at
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perform_health_check_and_save_passes_model_id_to_perform_health_check():
|
||||
"""Test that _perform_health_check_and_save passes model_id to perform_health_check so health checks run by model id."""
|
||||
|
||||
Loading…
Reference in New Issue
Block a user