diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 6128f48574..f513381c86 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2401,6 +2401,13 @@ class ConfigGeneralSettings(LiteLLMPydanticObjectBase): "health checks run without a concurrency cap" ), ) + health_check_skip_disabled_background_models: bool = Field( + False, + description=( + "When true, deployments with model_info.disable_background_health_check " + "are skipped for on-demand GET /health as well as the background health loop." + ), + ) alerting: Optional[List] = Field( None, description="List of alerting integrations. Today, just slack - `alerting: ['slack']`", diff --git a/litellm/proxy/health_check.py b/litellm/proxy/health_check.py index 400da9da0d..585ba88394 100644 --- a/litellm/proxy/health_check.py +++ b/litellm/proxy/health_check.py @@ -86,6 +86,24 @@ def _clean_endpoint_data(endpoint_data: dict, details: Optional[bool] = True): ) +def health_check_filter_kwargs_from_general_settings( + general_settings: Optional[dict], +) -> dict: + """ + Build kwargs for ``perform_health_check`` from ``general_settings``. + + When ``health_check_skip_disabled_background_models`` is true, deployments with + ``model_info.disable_background_health_check`` are omitted from health runs + (including on-demand ``GET /health``), matching the background loop behavior. + """ + g = general_settings or {} + return { + "health_check_skip_disabled_background_models": bool( + g.get("health_check_skip_disabled_background_models", False) + ), + } + + def filter_deployments_by_id( model_list: List, ) -> List: @@ -438,6 +456,7 @@ async def perform_health_check( model_id: Optional[str] = None, max_concurrency: Optional[int] = None, instrumentation_context: Optional[dict] = None, + health_check_skip_disabled_background_models: bool = False, ): """ Perform a health check on the system. @@ -446,6 +465,12 @@ async def perform_health_check( (so models that share the same name but have different ids are checked separately). When model (name) is provided, all deployments matching that name are checked. + When ``health_check_skip_disabled_background_models`` is True (via + ``general_settings.health_check_skip_disabled_background_models``), deployments + with ``model_info.disable_background_health_check: true`` are omitted from + this run (including targeted ``/health`` queries), consistent with the + background health loop. + Returns: (bool): True if the health check passes, False otherwise. """ @@ -486,6 +511,23 @@ async def perform_health_check( _new_model_list = [x for x in model_list if x["model_name"] == model] model_list = _new_model_list + if health_check_skip_disabled_background_models: + model_list = [ + x + for x in model_list + if not (x.get("model_info") or {}).get( + "disable_background_health_check", False + ) + ] + if not model_list: + if instrumentation_enabled: + logger.debug( + "health_check_cycle_skipped source=%s cycle_id=%s reason=no_models_after_filter", + source, + cycle_id, + ) + return [], [], {} + post_filter_model_count = len(model_list) model_list = filter_deployments_by_id( model_list=model_list diff --git a/litellm/proxy/health_check_utils/shared_health_check_manager.py b/litellm/proxy/health_check_utils/shared_health_check_manager.py index ad58bc7e28..5c5f8929a3 100644 --- a/litellm/proxy/health_check_utils/shared_health_check_manager.py +++ b/litellm/proxy/health_check_utils/shared_health_check_manager.py @@ -192,6 +192,7 @@ class SharedHealthCheckManager: model_list: List[Dict[str, Any]], details: bool = True, max_concurrency: Optional[int] = None, + health_check_skip_disabled_background_models: bool = False, ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]], Dict[str, Any]]: """ Perform health check with shared state coordination. @@ -207,6 +208,7 @@ class SharedHealthCheckManager: model_list: List of models to check details: Whether to include detailed information max_concurrency: Optional limit on concurrent health check requests + health_check_skip_disabled_background_models: Remove models with disable_background_health_check: true Returns: Tuple of (healthy_endpoints, unhealthy_endpoints) @@ -240,6 +242,7 @@ class SharedHealthCheckManager: model_list=model_list, details=details, max_concurrency=max_concurrency, + health_check_skip_disabled_background_models=health_check_skip_disabled_background_models, ) # Cache the results @@ -260,6 +263,7 @@ class SharedHealthCheckManager: model_list=model_list, details=details, max_concurrency=max_concurrency, + health_check_skip_disabled_background_models=health_check_skip_disabled_background_models, ) # Lock not acquired — poll for cached results until the lock @@ -316,6 +320,7 @@ class SharedHealthCheckManager: model_list=model_list, details=details, max_concurrency=max_concurrency, + health_check_skip_disabled_background_models=health_check_skip_disabled_background_models, ) async def is_health_check_in_progress(self) -> bool: diff --git a/litellm/proxy/health_endpoints/_health_endpoints.py b/litellm/proxy/health_endpoints/_health_endpoints.py index 2c857c0fde..3032c9d38c 100644 --- a/litellm/proxy/health_endpoints/_health_endpoints.py +++ b/litellm/proxy/health_endpoints/_health_endpoints.py @@ -32,6 +32,7 @@ from litellm.proxy.health_check import ( ADMIN_ONLY_HEALTH_DISPLAY_PARAMS, _clean_endpoint_data, _update_litellm_params_for_health_check, + health_check_filter_kwargs_from_general_settings, perform_health_check, run_with_timeout, ) @@ -858,6 +859,7 @@ async def _perform_health_check_and_save( user_id, model_id=None, max_concurrency=None, + **perform_health_check_extra, ): """Helper function to perform health check and save results to database""" healthy_endpoints, unhealthy_endpoints, _ = await perform_health_check( @@ -867,6 +869,7 @@ async def _perform_health_check_and_save( details=details, max_concurrency=max_concurrency, model_id=model_id, + **perform_health_check_extra, ) # Optionally save health check result to database (non-blocking) @@ -894,6 +897,37 @@ async def _perform_health_check_and_save( } +def _health_endpoint_resolve_target_model_name( + model: Optional[str], + model_id: Optional[str], + llm_router, +) -> Optional[str]: + """Map ``model_id`` (without ``model``) to ``model_name`` for live health checks.""" + if not model_id or model: + return model + if llm_router is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"error": f"Model with ID {model_id} not found"}, + ) + try: + deployment = llm_router.get_deployment(model_id=model_id) + except Exception as e: + verbose_proxy_logger.error( + f"Error getting deployment for model_id {model_id}: {e}" + ) + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"error": f"Model with ID {model_id} not found"}, + ) from e + if deployment is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"error": f"Model with ID {model_id} not found"}, + ) + return deployment.model_name + + @router.get("/health", tags=["health"], dependencies=[Depends(user_api_key_auth)]) async def health_endpoint( response: Response, @@ -920,10 +954,15 @@ async def health_endpoint( background_health_checks: True ``` else, the health checks will be run on models when /health is called. + + To skip deployments that set ``model_info.disable_background_health_check: true`` + on ``GET /health`` as well as in the background loop, set + ``general_settings.health_check_skip_disabled_background_models: true``. """ import time from litellm.proxy.proxy_server import ( + general_settings, health_check_concurrency, health_check_details, health_check_results, @@ -934,35 +973,12 @@ async def health_endpoint( user_model, ) + _hc_filter = health_check_filter_kwargs_from_general_settings(general_settings) start_time = time.time() - # Handle model_id parameter - convert to model name for health check - target_model = model - if model_id and not model: - # Use get_deployment from router to find the model name - if llm_router is not None: - try: - deployment = llm_router.get_deployment(model_id=model_id) - if deployment is not None: - target_model = deployment.model_name - else: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail={"error": f"Model with ID {model_id} not found"}, - ) - except Exception as e: - verbose_proxy_logger.error( - f"Error getting deployment for model_id {model_id}: {e}" - ) - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail={"error": f"Model with ID {model_id} not found"}, - ) - else: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail={"error": f"Model with ID {model_id} not found"}, - ) + target_model = _health_endpoint_resolve_target_model_name( + model, model_id, llm_router + ) is_admin = _is_proxy_admin(user_api_key_dict) model_specific_request = bool(model or model_id) @@ -1000,6 +1016,7 @@ async def health_endpoint( user_id=user_api_key_dict.user_id, model_id=None, # CLI model doesn't have model_id max_concurrency=health_check_concurrency, + **_hc_filter, ) return _post_process(cli_result) raise HTTPException( @@ -1085,6 +1102,7 @@ async def health_endpoint( user_id=user_api_key_dict.user_id, model_id=model_id, max_concurrency=health_check_concurrency, + **_hc_filter, ) return _post_process(router_result) except Exception as e: diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 4a538a28e0..5851d2550b 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -314,7 +314,10 @@ from litellm.proxy.guardrails.init_guardrails import ( init_guardrails_v2, initialize_guardrails, ) -from litellm.proxy.health_check import perform_health_check +from litellm.proxy.health_check import ( + health_check_filter_kwargs_from_general_settings, + perform_health_check, +) from litellm.proxy.health_endpoints._health_endpoints import router as health_router from litellm.proxy.hooks.model_max_budget_limiter import ( _PROXY_VirtualKeyModelMaxBudgetLimiter, @@ -2733,29 +2736,44 @@ def _rss_mb_for_log() -> str: return f"{rss_mb:.2f}" +def _is_unexpected_keyword_argument_type_error(exc: BaseException) -> bool: + """True when ``exc`` is a TypeError from passing a kwarg the callee does not accept.""" + return isinstance(exc, TypeError) and ( + "unexpected keyword argument" in str(exc).lower() + ) + + async def _run_direct_health_check_with_instrumentation( model_list: list, details: Optional[bool], max_concurrency: Optional[int], instrumentation_context: dict, ): - try: - return await perform_health_check( - model_list=model_list, - details=details, - max_concurrency=max_concurrency, - instrumentation_context=instrumentation_context, - ) - except TypeError as e: - if "instrumentation_context" not in str(e): - raise - # Backward compatibility for monkeypatched or wrapped callables - # that do not accept instrumentation_context. - return await perform_health_check( - model_list=model_list, - details=details, - max_concurrency=max_concurrency, - ) + """Call ``perform_health_check``, retrying with fewer kwargs on unexpected-kw TypeErrors.""" + _hc_filter = health_check_filter_kwargs_from_general_settings(general_settings) + last_type_error: Optional[TypeError] = None + for extra_kwargs in ( + { + "instrumentation_context": instrumentation_context, + **_hc_filter, + }, + {"instrumentation_context": instrumentation_context}, + dict(_hc_filter), + {}, + ): + try: + return await perform_health_check( + model_list=model_list, + details=details, + max_concurrency=max_concurrency, + **extra_kwargs, + ) + except TypeError as e: + if not _is_unexpected_keyword_argument_type_error(e): + raise + last_type_error = e + assert last_type_error is not None + raise last_type_error def _schedule_background_health_check_db_save( @@ -3020,6 +3038,7 @@ async def _run_background_health_check(): details_bool = ( health_check_details if health_check_details is not None else True ) + _hc_filter = health_check_filter_kwargs_from_general_settings(general_settings) if shared_health_manager is not None: try: @@ -3031,6 +3050,7 @@ async def _run_background_health_check(): model_list=_llm_model_list, details=details_bool, max_concurrency=health_check_concurrency, + **_hc_filter, ) except Exception as e: verbose_proxy_logger.error( @@ -3043,7 +3063,7 @@ async def _run_background_health_check(): _exceptions_by_model_id, ) = await _run_direct_health_check_with_instrumentation( _llm_model_list, - health_check_details, + details_bool, health_check_concurrency, instrumentation_context, ) @@ -3054,7 +3074,7 @@ async def _run_background_health_check(): _exceptions_by_model_id, ) = await _run_direct_health_check_with_instrumentation( _llm_model_list, - health_check_details, + details_bool, health_check_concurrency, instrumentation_context, ) diff --git a/tests/litellm_utils_tests/test_health_check.py b/tests/litellm_utils_tests/test_health_check.py index 61590e8f95..8a6bed61a3 100644 --- a/tests/litellm_utils_tests/test_health_check.py +++ b/tests/litellm_utils_tests/test_health_check.py @@ -495,6 +495,45 @@ async def test_perform_health_check_filters_by_model_id(): assert healthy_endpoints[0]["api_key"] == "fake-key-2" +@pytest.mark.asyncio +async def test_perform_health_check_skip_disabled_background_models(): + from litellm.proxy.health_check import perform_health_check + + model_list = [ + { + "model_name": "a", + "model_info": {"id": "id-a"}, + "litellm_params": {"model": "m-a", "api_key": "k1"}, + }, + { + "model_name": "b", + "model_info": { + "id": "id-b", + "disable_background_health_check": True, + }, + "litellm_params": {"model": "m-b", "api_key": "k2"}, + }, + ] + captured = [] + + async def mock_inner(m_list, details=True, **kwargs): + captured.append(list(m_list)) + return [], [], {} + + with patch( + "litellm.proxy.health_check._perform_health_check", + side_effect=mock_inner, + ): + await perform_health_check( + model_list=model_list, + health_check_skip_disabled_background_models=True, + ) + + assert len(captured) == 1 + assert len(captured[0]) == 1 + assert captured[0][0]["model_name"] == "a" + + @pytest.mark.asyncio async def test_perform_health_check_with_health_check_model(): """ diff --git a/tests/proxy_unit_tests/test_proxy_server.py b/tests/proxy_unit_tests/test_proxy_server.py index cdcdc89e7f..9c08175767 100644 --- a/tests/proxy_unit_tests/test_proxy_server.py +++ b/tests/proxy_unit_tests/test_proxy_server.py @@ -2485,7 +2485,9 @@ async def test_background_health_check_skip_disabled_models(monkeypatch): ] called_model_lists = [] - async def fake_perform_health_check(model_list, details, max_concurrency=None): + async def fake_perform_health_check( + model_list, details, max_concurrency=None, **kwargs + ): called_model_lists.append(copy.deepcopy(model_list)) return (["healthy"], [], {}) @@ -2508,6 +2510,100 @@ async def test_background_health_check_skip_disabled_models(monkeypatch): assert called_model_lists == [[{"model_name": "model-a"}]] +@pytest.mark.asyncio +async def test_run_direct_health_check_with_instrumentation_legacy_three_arg_stub( + monkeypatch, +): + """Monkeypatched perform_health_check with only base kwargs should still run.""" + import litellm.proxy.proxy_server as proxy_server + + async def fake_perform_health_check(model_list, details, max_concurrency=None): + return ([], [], {}) + + monkeypatch.setattr(proxy_server, "perform_health_check", fake_perform_health_check) + result = await proxy_server._run_direct_health_check_with_instrumentation( + [{"model_name": "m"}], + True, + 1, + {"enabled": True, "source": "test", "cycle_id": "c1"}, + ) + assert result == ([], [], {}) + + +@pytest.mark.asyncio +async def test_run_direct_health_check_with_instrumentation_accepts_instrumentation_only( + monkeypatch, +): + """Stub that accepts instrumentation_context but not health_check filter kwargs.""" + import litellm.proxy.proxy_server as proxy_server + + seen: list = [] + + async def fake_perform_health_check( + model_list, details, max_concurrency=None, instrumentation_context=None + ): + seen.append(instrumentation_context) + return ([], [], {}) + + monkeypatch.setattr(proxy_server, "perform_health_check", fake_perform_health_check) + await proxy_server._run_direct_health_check_with_instrumentation( + [], + False, + 2, + {"enabled": True, "source": "test", "cycle_id": "c2"}, + ) + assert len(seen) == 1 + assert seen[0]["cycle_id"] == "c2" + + +@pytest.mark.asyncio +async def test_run_direct_health_check_with_instrumentation_accepts_filter_only( + monkeypatch, +): + """Stub that accepts health_check_skip_disabled_background_models but not instrumentation.""" + import litellm.proxy.proxy_server as proxy_server + + seen: list = [] + + async def fake_perform_health_check( + model_list, + details, + max_concurrency=None, + health_check_skip_disabled_background_models=False, + ): + seen.append(health_check_skip_disabled_background_models) + return ([], [], {}) + + monkeypatch.setattr(proxy_server, "perform_health_check", fake_perform_health_check) + await proxy_server._run_direct_health_check_with_instrumentation( + [], + True, + None, + {"enabled": False}, + ) + assert len(seen) == 1 + assert seen[0] is False + + +@pytest.mark.asyncio +async def test_run_direct_health_check_with_instrumentation_non_kw_typeerror_reraises( + monkeypatch, +): + import litellm.proxy.proxy_server as proxy_server + + async def fake_perform_health_check(**kwargs): + raise TypeError("unsupported operand type(s)") + + monkeypatch.setattr(proxy_server, "perform_health_check", fake_perform_health_check) + with pytest.raises(TypeError, match="unsupported operand"): + await proxy_server._run_direct_health_check_with_instrumentation( + [], + True, + 1, + {}, + ) + + def test_get_timeout_from_request(): from litellm.proxy.litellm_pre_call_utils import LiteLLMProxyRequestSetup diff --git a/tests/test_litellm/proxy/test_health_check_functions.py b/tests/test_litellm/proxy/test_health_check_functions.py index bd79361fa9..f223241baf 100644 --- a/tests/test_litellm/proxy/test_health_check_functions.py +++ b/tests/test_litellm/proxy/test_health_check_functions.py @@ -566,6 +566,7 @@ async def test_perform_health_check_and_save_passes_model_id_to_perform_health_c details=True, model_id=None, max_concurrency=None, + **kwargs, ): return healthy, unhealthy, {} @@ -591,5 +592,39 @@ async def test_perform_health_check_and_save_passes_model_id_to_perform_health_c assert result["unhealthy_count"] == 0 +@pytest.mark.asyncio +async def test_perform_health_check_and_save_forwards_skip_disabled_background_flag(): + """health_check_skip_disabled_background_models should reach perform_health_check.""" + model_list = [ + { + "model_name": "gpt-4", + "model_info": {"id": "deployment-abc"}, + "litellm_params": {"model": "gpt-4"}, + }, + ] + + async def mock_perform_health_check(**kwargs): + return [], [], {} + + with patch( + "litellm.proxy.health_endpoints._health_endpoints.perform_health_check", + side_effect=mock_perform_health_check, + ) as mock_perform: + await _perform_health_check_and_save( + model_list=model_list, + target_model=None, + cli_model=None, + details=True, + prisma_client=None, + start_time=0.0, + user_id="user-1", + model_id=None, + health_check_skip_disabled_background_models=True, + ) + + call_kwargs = mock_perform.call_args[1] + assert call_kwargs["health_check_skip_disabled_background_models"] is True + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/tests/test_litellm/proxy/test_shared_health_check.py b/tests/test_litellm/proxy/test_shared_health_check.py index 1530d33608..9f4078880e 100644 --- a/tests/test_litellm/proxy/test_shared_health_check.py +++ b/tests/test_litellm/proxy/test_shared_health_check.py @@ -310,7 +310,10 @@ class TestSharedHealthCheckManager: # Should call perform_health_check and cache results mock_perform.assert_called_once_with( - model_list=model_list, details=True, max_concurrency=None + model_list=model_list, + details=True, + max_concurrency=None, + health_check_skip_disabled_background_models=False, ) assert healthy == expected_healthy assert unhealthy == expected_unhealthy @@ -397,7 +400,10 @@ class TestSharedHealthCheckManager: assert mock_sleep.call_count == 2 mock_sleep.assert_called_with(5) mock_perform.assert_called_once_with( - model_list=model_list, details=True, max_concurrency=None + model_list=model_list, + details=True, + max_concurrency=None, + health_check_skip_disabled_background_models=False, ) assert healthy == expected_healthy assert unhealthy == expected_unhealthy @@ -437,7 +443,10 @@ class TestSharedHealthCheckManager: # Should detect orphaned lock after 1 iteration and fall back immediately mock_sleep.assert_called_once_with(5) mock_perform.assert_called_once_with( - model_list=model_list, details=True, max_concurrency=None + model_list=model_list, + details=True, + max_concurrency=None, + health_check_skip_disabled_background_models=False, ) assert healthy == expected_healthy assert unhealthy == expected_unhealthy @@ -506,7 +515,10 @@ class TestSharedHealthCheckManager: # Should NOT sleep at all — falls back to local health check immediately mock_sleep.assert_not_called() mock_perform.assert_called_once_with( - model_list=model_list, details=True, max_concurrency=None + model_list=model_list, + details=True, + max_concurrency=None, + health_check_skip_disabled_background_models=False, ) assert healthy == expected_healthy assert unhealthy == expected_unhealthy