diff --git a/.github/workflows/test-unit-proxy-endpoints.yml b/.github/workflows/test-unit-proxy-endpoints.yml index 6b34a08a8e..1ce0abfc00 100644 --- a/.github/workflows/test-unit-proxy-endpoints.yml +++ b/.github/workflows/test-unit-proxy-endpoints.yml @@ -36,6 +36,7 @@ jobs: tests/test_litellm/proxy/a2a tests/test_litellm/proxy/discovery_endpoints tests/test_litellm/proxy/health_endpoints + tests/test_litellm/proxy/shutdown tests/test_litellm/proxy/public_endpoints tests/test_litellm/proxy/prompts tests/test_litellm/proxy/rag_endpoints diff --git a/Makefile b/Makefile index 5dbd308a3e..a00a90da60 100644 --- a/Makefile +++ b/Makefile @@ -146,7 +146,7 @@ test-unit-proxy-core: install-test-deps $(UV_RUN) pytest tests/test_litellm/proxy/auth tests/test_litellm/proxy/client tests/test_litellm/proxy/db tests/test_litellm/proxy/hooks tests/test_litellm/proxy/policy_engine --tb=short -vv -n 4 --durations=20 test-unit-proxy-misc: install-test-deps - $(UV_RUN) pytest tests/test_litellm/proxy/_experimental tests/test_litellm/proxy/agent_endpoints tests/test_litellm/proxy/anthropic_endpoints tests/test_litellm/proxy/common_utils tests/test_litellm/proxy/discovery_endpoints tests/test_litellm/proxy/experimental tests/test_litellm/proxy/google_endpoints tests/test_litellm/proxy/health_endpoints tests/test_litellm/proxy/image_endpoints tests/test_litellm/proxy/middleware tests/test_litellm/proxy/openai_files_endpoint tests/test_litellm/proxy/pass_through_endpoints tests/test_litellm/proxy/prompts tests/test_litellm/proxy/public_endpoints tests/test_litellm/proxy/response_api_endpoints tests/test_litellm/proxy/spend_tracking tests/test_litellm/proxy/ui_crud_endpoints tests/test_litellm/proxy/vector_store_endpoints tests/test_litellm/proxy/test_*.py --tb=short -vv -n 4 --durations=20 + $(UV_RUN) pytest tests/test_litellm/proxy/_experimental tests/test_litellm/proxy/agent_endpoints tests/test_litellm/proxy/anthropic_endpoints tests/test_litellm/proxy/common_utils 
tests/test_litellm/proxy/discovery_endpoints tests/test_litellm/proxy/experimental tests/test_litellm/proxy/google_endpoints tests/test_litellm/proxy/health_endpoints tests/test_litellm/proxy/image_endpoints tests/test_litellm/proxy/middleware tests/test_litellm/proxy/openai_files_endpoint tests/test_litellm/proxy/pass_through_endpoints tests/test_litellm/proxy/prompts tests/test_litellm/proxy/public_endpoints tests/test_litellm/proxy/response_api_endpoints tests/test_litellm/proxy/shutdown tests/test_litellm/proxy/spend_tracking tests/test_litellm/proxy/ui_crud_endpoints tests/test_litellm/proxy/vector_store_endpoints tests/test_litellm/proxy/test_*.py --tb=short -vv -n 4 --durations=20 test-unit-integrations: install-test-deps $(UV_RUN) pytest tests/test_litellm/integrations --tb=short -vv -n 4 --durations=20 diff --git a/deploy/charts/litellm-helm/values.yaml b/deploy/charts/litellm-helm/values.yaml index a9cdf28f0e..6e30a6af44 100644 --- a/deploy/charts/litellm-helm/values.yaml +++ b/deploy/charts/litellm-helm/values.yaml @@ -285,11 +285,31 @@ db: deployStandalone: true # Lifecycle hooks for the LiteLLM container +# +# Prefer the native /health/drain preStop hook over a fixed `sleep`: it marks +# the pod NotReady and blocks only until in-flight requests actually finish +# (bounded by GRACEFUL_SHUTDOWN_TIMEOUT, default 30s), instead of always +# waiting the worst-case duration. The drain runs once (the preStop hook and +# the SIGTERM handler share it), so set terminationGracePeriodSeconds a few +# seconds above GRACEFUL_SHUTDOWN_TIMEOUT to leave room for teardown before +# SIGKILL. +# +# /health/drain is off by default; enable it with +# general_settings.enable_drain_endpoint: true. 
The kubelet calls preStop +# hooks without proxy credentials, so when the health port is reachable from +# other pods (the common case) also set +# general_settings.drain_endpoint_token (or the DRAIN_ENDPOINT_TOKEN env +# var) and send the same value on the X-Drain-Token header from the hook. +# Calls with a missing or wrong token get a 401 and have no side effect. # Example: # lifecycle: # preStop: -# exec: -# command: ["/bin/sh", "-c", "sleep 10"] +# httpGet: +# path: /health/drain +# port: 4000 +# httpHeaders: +# - name: X-Drain-Token +# value: lifecycle: {} # Settings for Bitnami postgresql chart (if db.deployStandalone is true, ignored diff --git a/litellm/proxy/health_endpoints/_health_endpoints.py b/litellm/proxy/health_endpoints/_health_endpoints.py index ba3aee7504..c109f37499 100644 --- a/litellm/proxy/health_endpoints/_health_endpoints.py +++ b/litellm/proxy/health_endpoints/_health_endpoints.py @@ -2,6 +2,7 @@ import asyncio import copy import logging import os +import secrets import time import traceback from datetime import datetime, timedelta @@ -39,6 +40,7 @@ from litellm.proxy.health_check import ( from litellm.proxy.middleware.in_flight_requests_middleware import ( get_in_flight_requests, ) +from litellm.proxy.shutdown.graceful_shutdown_manager import GracefulShutdownManager #### Health ENDPOINTS #### @@ -1551,6 +1553,50 @@ def _allow_public_health_readiness_details() -> bool: return general_settings.get("allow_public_health_readiness_details") is True +def _drain_endpoint_enabled() -> bool: + from litellm.proxy.proxy_server import general_settings + + return general_settings.get("enable_drain_endpoint") is True + + +def _drain_endpoint_token() -> Optional[str]: + """ + Shared secret required on the X-Drain-Token header to call /health/drain. + + Falls back to the ``DRAIN_ENDPOINT_TOKEN`` env var when unset in + general_settings so the kubelet preStop hook can supply it via + ``valueFrom.secretKeyRef`` without a config reload. 
+ """ + from litellm.proxy.proxy_server import general_settings + + token = general_settings.get("drain_endpoint_token") + if isinstance(token, str) and token: + return token + env_token = os.getenv("DRAIN_ENDPOINT_TOKEN") + if env_token: + return env_token + return None + + +def _authorize_drain_request(request: Request) -> None: + """ + Reject /health/drain calls that don't carry the configured X-Drain-Token. + + When no token is configured the endpoint is treated as already opted-in + (the ``enable_drain_endpoint`` flag is the only gate). Comparison uses + ``secrets.compare_digest`` to avoid timing leaks. + """ + expected = _drain_endpoint_token() + if expected is None: + return + supplied = request.headers.get("x-drain-token") or "" + if not secrets.compare_digest(supplied, expected): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Invalid or missing X-Drain-Token", + ) + + async def _resolve_public_readiness_db(response: Response) -> str: """ Return the db status string for the public probe and flip the response to @@ -1580,6 +1626,10 @@ async def health_readiness(response: Response): credential. Admins can opt into the legacy detailed payload with general_settings.allow_public_health_readiness_details. """ + if GracefulShutdownManager.is_shutting_down(): + response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE + return {"status": "shutting_down"} + if _allow_public_health_readiness_details(): return await _get_health_readiness_details(response=response) @@ -1616,6 +1666,54 @@ async def health_backlog(): return {"in_flight_requests": get_in_flight_requests()} +@router.get( + "/health/drain", + tags=["health"], +) +async def health_drain(request: Request): + """ + Graceful-drain probe for Kubernetes ``preStop`` hooks. + + Disabled by default and returns 404 unless ``general_settings`` sets + ``enable_drain_endpoint: true``. 
Calling it flips a process-wide + shutting-down flag, so a successful call permanently takes the worker out + of rotation until the pod restarts. + + Because the kubelet calls preStop hooks without proxy credentials, the + endpoint does not require ``user_api_key_auth``. To prevent any + pod-reachable caller from triggering shutdown, set + ``general_settings.drain_endpoint_token`` (or the ``DRAIN_ENDPOINT_TOKEN`` + env var) and supply the same value on the ``X-Drain-Token`` header from + the preStop hook. Calls without the header (or with a wrong value) get a + 401 and have no side effect. + + When enabled, it marks the worker as shutting down (so /health/readiness + and /health/liveliness immediately start returning 503, removing the pod + from service) and blocks until the in-flight request counter drains to + zero or ``GRACEFUL_SHUTDOWN_TIMEOUT`` elapses. Unlike a fixed ``sleep``, + this returns as soon as real in-flight work is done. + + Wire it up as: + + ```yaml + lifecycle: + preStop: + httpGet: + path: /health/drain + port: 4000 + httpHeaders: + - name: X-Drain-Token + value: + ``` + """ + if not _drain_endpoint_enabled(): + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found") + _authorize_drain_request(request) + GracefulShutdownManager.start_shutdown() + drained = await GracefulShutdownManager.wait_for_drain(exclude_self=True) + return {"status": "drained", "drained_requests": drained} + + @router.get( "/health/liveliness", # Historical LiteLLM name; doesn't match k8s terminology but kept for backwards compatibility tags=["health"], @@ -1624,10 +1722,16 @@ async def health_backlog(): "/health/liveness", # Kubernetes has "liveness" probes (https://kubernetes.io/docs/tasks/configure-pod-container/configure-liveness-readiness-startup-probes/#define-a-liveness-command) tags=["health"], ) -async def health_liveliness(): +async def health_liveliness(response: Response): """ - Unprotected endpoint for checking if worker is alive + 
Unprotected endpoint for checking if worker is alive. + + Returns 503 once graceful shutdown has begun so Kubernetes stops counting + the draining pod as live and terminates it on schedule. """ + if GracefulShutdownManager.is_shutting_down(): + response.status_code = status.HTTP_503_SERVICE_UNAVAILABLE + return {"status": "shutting_down"} return "I'm alive!" diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 61df62ba2b..bd667298b4 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -417,6 +417,7 @@ from litellm.proxy.middleware.prometheus_auth_middleware import PrometheusAuthMi from litellm.proxy.middleware.request_size_limit_middleware import ( RequestSizeLimitMiddleware, ) +from litellm.proxy.shutdown.graceful_shutdown_manager import GracefulShutdownManager from litellm.proxy.ocr_endpoints.endpoints import router as ocr_router from litellm.proxy.openai_files_endpoints.files_endpoints import ( router as openai_files_router, @@ -973,6 +974,11 @@ async def proxy_startup_event(app: FastAPI): # noqa: PLR0915 # End of startup event yield + # Shutdown event - drain in-flight requests before tearing down dependencies + # so SIGTERM (rolling update, scale-down, liveness kill) doesn't drop them. + GracefulShutdownManager.start_shutdown() + await GracefulShutdownManager.wait_for_drain() + # Shutdown event - close shared aiohttp session if shared_aiohttp_session is not None: try: diff --git a/litellm/proxy/shutdown/__init__.py b/litellm/proxy/shutdown/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/litellm/proxy/shutdown/graceful_shutdown_manager.py b/litellm/proxy/shutdown/graceful_shutdown_manager.py new file mode 100644 index 0000000000..20ffabdd7d --- /dev/null +++ b/litellm/proxy/shutdown/graceful_shutdown_manager.py @@ -0,0 +1,174 @@ +""" +Application-level graceful shutdown coordination for the LiteLLM proxy. 
+ +Kubernetes terminates a pod by sending ``SIGTERM`` and, after +``terminationGracePeriodSeconds``, ``SIGKILL``. By default LiteLLM delegates +the signal to uvicorn and tears down immediately, dropping any in-flight +requests (streaming, batch inference, long-lived calls). + +A fixed ``preStop`` sleep can not solve this: it has to be sized for the +*worst-case* request, so it either wastes time on every routine shutdown or is +too short for a long-running request. This manager instead drains based on the +*actual* in-flight request counter (already tracked by +``InFlightRequestsMiddleware``), so a pod terminates as soon as its real +in-flight work is done — and never waits longer than ``GRACEFUL_SHUTDOWN_TIMEOUT``. + +The state is process-scoped (class-level), matching the per-uvicorn-worker +granularity of ``InFlightRequestsMiddleware``. +""" + +import asyncio +import os +import time +from typing import Callable, Optional + +from litellm._logging import verbose_proxy_logger +from litellm.proxy.middleware.in_flight_requests_middleware import ( + get_in_flight_requests, +) + +# Keep below terminationGracePeriodSeconds so the process exits before SIGKILL. +DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT = 30.0 +_DRAIN_POLL_INTERVAL = 0.1 +_DRAIN_LOG_INTERVAL = 5.0 + + +class GracefulShutdownManager: + """ + Process-scoped singleton that tracks whether the worker is draining and + blocks until in-flight requests reach zero (or a timeout elapses). + """ + + _is_shutting_down: bool = False + _shutdown_started_at: Optional[float] = None + _drain_performed: bool = False + + @classmethod + def is_shutting_down(cls) -> bool: + """Whether this worker has begun graceful shutdown.""" + return cls._is_shutting_down + + @classmethod + def get_timeout(cls) -> float: + """ + Read GRACEFUL_SHUTDOWN_TIMEOUT (seconds) from the environment on each + call so deployments can tune it without code changes. Falls back to the + default on an unset or malformed value. 
+ """ + raw = os.getenv("GRACEFUL_SHUTDOWN_TIMEOUT") + if raw is None: + return DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT + try: + return float(raw) + except (TypeError, ValueError): + verbose_proxy_logger.warning( + "GRACEFUL_SHUTDOWN_TIMEOUT=%r is not a number; using default %ss", + raw, + DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT, + ) + return DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT + + @classmethod + def start_shutdown(cls) -> None: + """ + Mark the worker as draining. Idempotent — repeated calls (e.g. SIGTERM + followed by a preStop hit on /health/drain) do not reset the clock. + """ + if cls._is_shutting_down: + return + cls._is_shutting_down = True + cls._shutdown_started_at = time.monotonic() + verbose_proxy_logger.info( + "graceful_shutdown_started in_flight_requests=%s", + get_in_flight_requests(), + ) + + @classmethod + async def wait_for_drain( + cls, + timeout: Optional[float] = None, + exclude_self: bool = False, + count_fn: Optional[Callable[[], int]] = None, + poll_interval: float = _DRAIN_POLL_INTERVAL, + log_interval: float = _DRAIN_LOG_INTERVAL, + ) -> int: + """ + Poll the in-flight request counter until it reaches the drain target or + ``timeout`` seconds elapse. + + Args: + timeout: Max seconds to wait. Defaults to ``get_timeout()``. + exclude_self: When the caller is itself an in-flight HTTP request + (the /health/drain endpoint), set this so the caller's own + request is not counted as outstanding work. + count_fn: Source of the current in-flight count. Defaults to the + live ``InFlightRequestsMiddleware`` counter; injectable for tests. + poll_interval: Seconds between counter polls. + log_interval: Minimum seconds between ``drain_waiting`` log lines. + + Returns: + Number of requests that drained while waiting (>= 0). 
+ """ + # A preStop /health/drain hook and the lifespan SIGTERM handler both + # drain; once one has run, the other must not wait again, otherwise the + # effective window is 2x the timeout and terminationGracePeriodSeconds + # has to be doubled to avoid a mid-drain SIGKILL. + if cls._drain_performed: + return 0 + cls._drain_performed = True + + if timeout is None: + timeout = cls.get_timeout() + if count_fn is None: + count_fn = get_in_flight_requests + + # The /health/drain HTTP request flows through InFlightRequestsMiddleware + # and so counts itself; treat <=1 as "drained" in that case. + target = 1 if exclude_self else 0 + + start = time.monotonic() + initial = count_fn() + last_log = start + + if timeout <= 0: + return max(0, initial - target) + + while True: + current = count_fn() + if current <= target: + drained = max(0, initial - current) + verbose_proxy_logger.info( + "graceful_shutdown_complete drained_requests=%s elapsed_s=%.2f", + drained, + time.monotonic() - start, + ) + return drained + + elapsed = time.monotonic() - start + if elapsed >= timeout: + verbose_proxy_logger.warning( + "graceful_shutdown_timeout in_flight_requests=%s elapsed_s=%.2f " + "timeout_s=%s — proceeding with teardown", + current, + elapsed, + timeout, + ) + return max(0, initial - current) + + now = time.monotonic() + if now - last_log >= log_interval: + verbose_proxy_logger.info( + "drain_waiting in_flight_requests=%s elapsed_s=%.2f", + current, + elapsed, + ) + last_log = now + + await asyncio.sleep(poll_interval) + + @classmethod + def reset(cls) -> None: + """Reset state. 
Intended for use in tests.""" + cls._is_shutting_down = False + cls._shutdown_started_at = None + cls._drain_performed = False diff --git a/tests/test_litellm/proxy/conftest.py b/tests/test_litellm/proxy/conftest.py index 20236ebdf4..607315eb24 100644 --- a/tests/test_litellm/proxy/conftest.py +++ b/tests/test_litellm/proxy/conftest.py @@ -14,7 +14,6 @@ import pytest import yaml from fastapi.testclient import TestClient - _PROXY_MODULE_GLOBALS_TO_ISOLATE = ( "master_key", "prisma_client", @@ -49,6 +48,18 @@ def _isolate_proxy_module_globals(): setattr(proxy_server, name, value) +@pytest.fixture(autouse=True) +def _reset_graceful_shutdown_state(): + """Graceful shutdown state is process-scoped; keep it from leaking between tests.""" + from litellm.proxy.shutdown.graceful_shutdown_manager import ( + GracefulShutdownManager, + ) + + GracefulShutdownManager.reset() + yield + GracefulShutdownManager.reset() + + def build_cache_config(enable_cache: bool = True) -> Optional[Dict]: """ Build Redis cache configuration from environment variables. diff --git a/tests/test_litellm/proxy/health_endpoints/test_graceful_shutdown_endpoints.py b/tests/test_litellm/proxy/health_endpoints/test_graceful_shutdown_endpoints.py new file mode 100644 index 0000000000..e20b54f28f --- /dev/null +++ b/tests/test_litellm/proxy/health_endpoints/test_graceful_shutdown_endpoints.py @@ -0,0 +1,166 @@ +""" +Behaviour tests for the graceful-shutdown health probes. + +Builds a minimal FastAPI app from the health router plus +InFlightRequestsMiddleware so the probe responses can be asserted without +standing up the full proxy. 
+""" + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +from litellm.proxy.health_endpoints._health_endpoints import router +from litellm.proxy.middleware.in_flight_requests_middleware import ( + InFlightRequestsMiddleware, +) +from litellm.proxy.shutdown.graceful_shutdown_manager import GracefulShutdownManager + + +@pytest.fixture(autouse=True) +def _reset(): + GracefulShutdownManager.reset() + InFlightRequestsMiddleware._in_flight = 0 + yield + GracefulShutdownManager.reset() + InFlightRequestsMiddleware._in_flight = 0 + + +@pytest.fixture +def client(): + app = FastAPI() + app.include_router(router) + app.add_middleware(InFlightRequestsMiddleware) + return TestClient(app) + + +@pytest.fixture +def enable_drain(monkeypatch): + from litellm.proxy import proxy_server + + monkeypatch.setattr( + proxy_server, "general_settings", {"enable_drain_endpoint": True} + ) + + +@pytest.fixture +def enable_drain_with_token(monkeypatch): + from litellm.proxy import proxy_server + + monkeypatch.setattr( + proxy_server, + "general_settings", + {"enable_drain_endpoint": True, "drain_endpoint_token": "secret-123"}, + ) + + +def test_drain_disabled_by_default_returns_404_with_no_side_effect(client, monkeypatch): + from litellm.proxy import proxy_server + + monkeypatch.setattr(proxy_server, "general_settings", {}) + resp = client.get("/health/drain") + assert resp.status_code == 404 + assert GracefulShutdownManager.is_shutting_down() is False + + +def test_drain_disabled_ignores_token_header(client, monkeypatch): + """A token alone must not bypass the enable flag; otherwise enabling the + token side-channel would silently enable the endpoint.""" + from litellm.proxy import proxy_server + + monkeypatch.setattr( + proxy_server, "general_settings", {"drain_endpoint_token": "secret-123"} + ) + resp = client.get("/health/drain", headers={"X-Drain-Token": "secret-123"}) + assert resp.status_code == 404 + assert 
GracefulShutdownManager.is_shutting_down() is False + + +def test_drain_when_enabled_without_token_sets_shutting_down_and_returns_drained( + client, enable_drain +): + resp = client.get("/health/drain") + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "drained" + assert body["drained_requests"] == 0 + assert GracefulShutdownManager.is_shutting_down() is True + + +def test_drain_with_token_configured_rejects_missing_header( + client, enable_drain_with_token +): + resp = client.get("/health/drain") + assert resp.status_code == 401 + assert GracefulShutdownManager.is_shutting_down() is False + + +def test_drain_with_token_configured_rejects_wrong_header( + client, enable_drain_with_token +): + resp = client.get("/health/drain", headers={"X-Drain-Token": "wrong-value"}) + assert resp.status_code == 401 + assert GracefulShutdownManager.is_shutting_down() is False + + +def test_drain_with_token_configured_accepts_correct_header( + client, enable_drain_with_token +): + resp = client.get("/health/drain", headers={"X-Drain-Token": "secret-123"}) + assert resp.status_code == 200 + assert resp.json()["status"] == "drained" + assert GracefulShutdownManager.is_shutting_down() is True + + +def test_drain_with_token_from_env_var(client, enable_drain, monkeypatch): + monkeypatch.setenv("DRAIN_ENDPOINT_TOKEN", "env-token") + resp = client.get("/health/drain") + assert resp.status_code == 401 + resp = client.get("/health/drain", headers={"X-Drain-Token": "env-token"}) + assert resp.status_code == 200 + + +def test_drain_general_settings_token_overrides_env_var(client, monkeypatch): + from litellm.proxy import proxy_server + + monkeypatch.setattr( + proxy_server, + "general_settings", + {"enable_drain_endpoint": True, "drain_endpoint_token": "config-token"}, + ) + monkeypatch.setenv("DRAIN_ENDPOINT_TOKEN", "env-token") + resp = client.get("/health/drain", headers={"X-Drain-Token": "env-token"}) + assert resp.status_code == 401 + resp = 
client.get("/health/drain", headers={"X-Drain-Token": "config-token"}) + assert resp.status_code == 200 + + +def test_readiness_returns_503_shutting_down_during_drain(client): + GracefulShutdownManager.start_shutdown() + resp = client.get("/health/readiness") + assert resp.status_code == 503 + assert resp.json() == {"status": "shutting_down"} + + +def test_readiness_does_not_report_shutting_down_normally(client): + resp = client.get("/health/readiness") + assert resp.json().get("status") != "shutting_down" + + +def test_liveliness_returns_503_during_drain(client): + GracefulShutdownManager.start_shutdown() + resp = client.get("/health/liveliness") + assert resp.status_code == 503 + assert resp.json() == {"status": "shutting_down"} + + +def test_liveness_alias_returns_503_during_drain(client): + GracefulShutdownManager.start_shutdown() + resp = client.get("/health/liveness") + assert resp.status_code == 503 + + +def test_liveliness_returns_alive_when_not_shutting_down(client): + resp = client.get("/health/liveliness") + assert resp.status_code == 200 + assert resp.json() == "I'm alive!" diff --git a/tests/test_litellm/proxy/shutdown/test_graceful_shutdown_manager.py b/tests/test_litellm/proxy/shutdown/test_graceful_shutdown_manager.py new file mode 100644 index 0000000000..d38852617b --- /dev/null +++ b/tests/test_litellm/proxy/shutdown/test_graceful_shutdown_manager.py @@ -0,0 +1,181 @@ +""" +Tests for GracefulShutdownManager. + +These verify the drain logic that lets a pod terminate as soon as its real +in-flight work is done (bounded by GRACEFUL_SHUTDOWN_TIMEOUT) rather than +sleeping for a fixed worst-case duration. 
+""" + +import time + +import pytest + +from litellm.proxy.shutdown.graceful_shutdown_manager import ( + DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT, + GracefulShutdownManager, +) + + +@pytest.fixture(autouse=True) +def _reset(): + GracefulShutdownManager.reset() + yield + GracefulShutdownManager.reset() + + +def _counter_that_drains_after(calls_before_zero: int): + """Return a count_fn that reports N in-flight until it has been polled + `calls_before_zero` times, then reports 0.""" + state = {"polls": 0} + + def count_fn() -> int: + state["polls"] += 1 + return 0 if state["polls"] > calls_before_zero else 3 + + return count_fn + + +# ── shutdown flag ─────────────────────────────────────────────────────────── + + +def test_not_shutting_down_by_default(): + assert GracefulShutdownManager.is_shutting_down() is False + + +def test_start_shutdown_sets_flag(): + GracefulShutdownManager.start_shutdown() + assert GracefulShutdownManager.is_shutting_down() is True + + +def test_start_shutdown_is_idempotent_and_does_not_reset_clock(): + GracefulShutdownManager.start_shutdown() + first = GracefulShutdownManager._shutdown_started_at + time.sleep(0.01) + GracefulShutdownManager.start_shutdown() + assert GracefulShutdownManager._shutdown_started_at == first + + +def test_reset_clears_flag(): + GracefulShutdownManager.start_shutdown() + GracefulShutdownManager.reset() + assert GracefulShutdownManager.is_shutting_down() is False + + +# ── timeout config ──────────────────────────────────────────────────────────── + + +def test_timeout_defaults_when_unset(monkeypatch): + monkeypatch.delenv("GRACEFUL_SHUTDOWN_TIMEOUT", raising=False) + assert GracefulShutdownManager.get_timeout() == DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT + + +def test_timeout_reads_env(monkeypatch): + monkeypatch.setenv("GRACEFUL_SHUTDOWN_TIMEOUT", "5") + assert GracefulShutdownManager.get_timeout() == 5.0 + + +def test_timeout_falls_back_on_garbage(monkeypatch): + monkeypatch.setenv("GRACEFUL_SHUTDOWN_TIMEOUT", "not-a-number") 
+ assert GracefulShutdownManager.get_timeout() == DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT + + +# ── wait_for_drain ──────────────────────────────────────────────────────────── + + +@pytest.mark.asyncio +async def test_returns_immediately_when_already_drained(): + start = time.monotonic() + drained = await GracefulShutdownManager.wait_for_drain( + timeout=10, count_fn=lambda: 0 + ) + assert drained == 0 + assert time.monotonic() - start < 0.5 + + +@pytest.mark.asyncio +async def test_waits_until_counter_reaches_zero_then_returns_drained_count(): + count_fn = _counter_that_drains_after(calls_before_zero=3) + drained = await GracefulShutdownManager.wait_for_drain( + timeout=10, count_fn=count_fn + ) + assert drained == 3 + + +@pytest.mark.asyncio +async def test_times_out_when_counter_never_drains(): + start = time.monotonic() + drained = await GracefulShutdownManager.wait_for_drain( + timeout=0.3, count_fn=lambda: 2 + ) + elapsed = time.monotonic() - start + assert 0.3 <= elapsed < 2.0 + assert drained == 0 + + +@pytest.mark.asyncio +async def test_zero_timeout_does_not_block(): + start = time.monotonic() + drained = await GracefulShutdownManager.wait_for_drain( + timeout=0, count_fn=lambda: 5 + ) + assert time.monotonic() - start < 0.2 + assert drained == 5 + + +@pytest.mark.asyncio +async def test_exclude_self_treats_one_inflight_as_drained(): + """The /health/drain request counts itself, so a steady count of 1 must be + treated as fully drained rather than timing out.""" + start = time.monotonic() + drained = await GracefulShutdownManager.wait_for_drain( + timeout=5, exclude_self=True, count_fn=lambda: 1 + ) + assert time.monotonic() - start < 0.5 + assert drained == 0 + + +@pytest.mark.asyncio +async def test_without_exclude_self_one_inflight_blocks_until_timeout(): + start = time.monotonic() + await GracefulShutdownManager.wait_for_drain(timeout=0.3, count_fn=lambda: 1) + assert time.monotonic() - start >= 0.3 + + +@pytest.mark.asyncio +async def 
test_defaults_to_get_timeout_and_live_counter(monkeypatch): + """With no timeout/count_fn passed, it falls back to get_timeout() and the + live InFlightRequestsMiddleware counter.""" + from litellm.proxy.middleware.in_flight_requests_middleware import ( + InFlightRequestsMiddleware, + ) + + monkeypatch.delenv("GRACEFUL_SHUTDOWN_TIMEOUT", raising=False) + InFlightRequestsMiddleware._in_flight = 0 + drained = await GracefulShutdownManager.wait_for_drain() + assert drained == 0 + + +@pytest.mark.asyncio +async def test_second_drain_is_a_noop_so_window_is_not_doubled(): + """preStop /health/drain and the lifespan SIGTERM handler both drain; the + second call must return immediately rather than waiting another full + timeout (which would require doubling terminationGracePeriodSeconds).""" + await GracefulShutdownManager.wait_for_drain(timeout=0.2, count_fn=lambda: 1) + + start = time.monotonic() + drained = await GracefulShutdownManager.wait_for_drain( + timeout=5, count_fn=lambda: 1 + ) + assert time.monotonic() - start < 0.1 + assert drained == 0 + + +@pytest.mark.asyncio +async def test_emits_periodic_drain_waiting_log_while_waiting(): + """With a zero log interval, the periodic drain_waiting branch runs on each + poll until the counter finally drains.""" + count_fn = _counter_that_drains_after(calls_before_zero=2) + drained = await GracefulShutdownManager.wait_for_drain( + timeout=10, count_fn=count_fn, poll_interval=0, log_interval=0 + ) + assert drained == 3