From dd4a1d2be2f95ac67df2744b6218ff07af65336b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 18 Apr 2026 16:35:17 -0700 Subject: [PATCH] feat: add adaptive routing to litellm allow model routing to improve based on conversation signals ensures router is picking best model for task --- .../migration.sql | 39 ++ .../litellm_proxy_extras/schema.prisma | 43 ++ .../out/{404.html => 404/index.html} | 0 .../index.html} | 0 .../index.html} | 0 .../out/{chat.html => chat/index.html} | 0 .../index.html} | 0 .../{budgets.html => budgets/index.html} | 0 .../{caching.html => caching/index.html} | 0 .../index.html} | 0 .../{old-usage.html => old-usage/index.html} | 0 .../{prompts.html => prompts/index.html} | 0 .../index.html} | 0 .../index.html} | 0 .../out/{login.html => login/index.html} | 0 .../out/{logs.html => logs/index.html} | 0 .../{callback.html => callback/index.html} | 0 .../{model-hub.html => model-hub/index.html} | 0 .../{model_hub.html => model_hub/index.html} | 0 .../index.html} | 0 .../index.html} | 0 .../index.html} | 0 .../index.html} | 0 .../index.html} | 0 .../{policies.html => policies/index.html} | 0 .../index.html} | 0 .../index.html} | 0 .../index.html} | 0 .../{ui-theme.html => ui-theme/index.html} | 0 .../out/{teams.html => teams/index.html} | 0 .../{test-key.html => test-key/index.html} | 0 .../index.html} | 0 .../index.html} | 0 .../out/{usage.html => usage/index.html} | 0 .../out/{users.html => users/index.html} | 0 .../index.html} | 0 litellm/proxy/_new_secret_config.yaml | 103 +++-- .../adaptive_router_update_queue.py | 216 ++++++++++ .../adaptive_router_example.yaml | 52 +++ litellm/proxy/proxy_server.py | 67 ++- litellm/proxy/schema.prisma | 43 ++ litellm/router.py | 171 +++++++- .../router_strategy/adaptive_router/README.md | 93 +++++ .../adaptive_router/__init__.py | 6 + .../adaptive_router/adaptive_router.py | 344 ++++++++++++++++ .../router_strategy/adaptive_router/bandit.py | 136 +++++++ .../adaptive_router/classifier.py | 140 
+++++++ .../router_strategy/adaptive_router/config.py | 55 +++ .../router_strategy/adaptive_router/hooks.py | 241 +++++++++++ .../adaptive_router/signals.py | 272 +++++++++++++ litellm/types/router.py | 47 ++- schema.prisma | 43 ++ scripts/verify_adaptive_router.py | 216 ++++++++++ .../test_adaptive_router_update_queue.py | 117 ++++++ .../adaptive_router/__init__.py | 0 .../fixtures/clean_no_signals.json | 16 + .../fixtures/clean_satisfaction.json | 23 ++ .../fixtures/disengagement_giveup.json | 16 + .../fixtures/exhaustion_429.json | 9 + .../fixtures/exhaustion_context_overflow.json | 13 + .../fixtures/failure_tool_error.json | 13 + .../fixtures/loop_same_tool.json | 35 ++ .../fixtures/misalignment_rephrase.json | 16 + .../mixed_failure_then_satisfaction.json | 31 ++ .../fixtures/stagnation_repeat.json | 16 + .../adaptive_router/test_adaptive_router.py | 224 +++++++++++ .../adaptive_router/test_async_pre_routing.py | 137 +++++++ .../adaptive_router/test_bandit.py | 134 ++++++ .../adaptive_router/test_classifier.py | 116 ++++++ .../adaptive_router/test_config.py | 55 +++ .../test_e2e_adaptive_router.py | 263 ++++++++++++ .../adaptive_router/test_hooks.py | 329 +++++++++++++++ .../adaptive_router/test_router_dispatch.py | 380 ++++++++++++++++++ .../adaptive_router/test_signals.py | 112 ++++++ .../adaptive_router/test_state_endpoint.py | 196 +++++++++ uv.lock | 6 +- 76 files changed, 4542 insertions(+), 42 deletions(-) create mode 100644 litellm-proxy-extras/litellm_proxy_extras/migrations/20260418000000_add_adaptive_router_tables/migration.sql rename litellm/proxy/_experimental/out/{404.html => 404/index.html} (100%) rename litellm/proxy/_experimental/out/{_not-found.html => _not-found/index.html} (100%) rename litellm/proxy/_experimental/out/{api-reference.html => api-reference/index.html} (100%) rename litellm/proxy/_experimental/out/{chat.html => chat/index.html} (100%) rename litellm/proxy/_experimental/out/experimental/{api-playground.html => 
api-playground/index.html} (100%) rename litellm/proxy/_experimental/out/experimental/{budgets.html => budgets/index.html} (100%) rename litellm/proxy/_experimental/out/experimental/{caching.html => caching/index.html} (100%) rename litellm/proxy/_experimental/out/experimental/{claude-code-plugins.html => claude-code-plugins/index.html} (100%) rename litellm/proxy/_experimental/out/experimental/{old-usage.html => old-usage/index.html} (100%) rename litellm/proxy/_experimental/out/experimental/{prompts.html => prompts/index.html} (100%) rename litellm/proxy/_experimental/out/experimental/{tag-management.html => tag-management/index.html} (100%) rename litellm/proxy/_experimental/out/{guardrails.html => guardrails/index.html} (100%) rename litellm/proxy/_experimental/out/{login.html => login/index.html} (100%) rename litellm/proxy/_experimental/out/{logs.html => logs/index.html} (100%) rename litellm/proxy/_experimental/out/mcp/oauth/{callback.html => callback/index.html} (100%) rename litellm/proxy/_experimental/out/{model-hub.html => model-hub/index.html} (100%) rename litellm/proxy/_experimental/out/{model_hub.html => model_hub/index.html} (100%) rename litellm/proxy/_experimental/out/{model_hub_table.html => model_hub_table/index.html} (100%) rename litellm/proxy/_experimental/out/{models-and-endpoints.html => models-and-endpoints/index.html} (100%) rename litellm/proxy/_experimental/out/{onboarding.html => onboarding/index.html} (100%) rename litellm/proxy/_experimental/out/{organizations.html => organizations/index.html} (100%) rename litellm/proxy/_experimental/out/{playground.html => playground/index.html} (100%) rename litellm/proxy/_experimental/out/{policies.html => policies/index.html} (100%) rename litellm/proxy/_experimental/out/settings/{admin-settings.html => admin-settings/index.html} (100%) rename litellm/proxy/_experimental/out/settings/{logging-and-alerts.html => logging-and-alerts/index.html} (100%) rename 
litellm/proxy/_experimental/out/settings/{router-settings.html => router-settings/index.html} (100%) rename litellm/proxy/_experimental/out/settings/{ui-theme.html => ui-theme/index.html} (100%) rename litellm/proxy/_experimental/out/{teams.html => teams/index.html} (100%) rename litellm/proxy/_experimental/out/{test-key.html => test-key/index.html} (100%) rename litellm/proxy/_experimental/out/tools/{mcp-servers.html => mcp-servers/index.html} (100%) rename litellm/proxy/_experimental/out/tools/{vector-stores.html => vector-stores/index.html} (100%) rename litellm/proxy/_experimental/out/{usage.html => usage/index.html} (100%) rename litellm/proxy/_experimental/out/{users.html => users/index.html} (100%) rename litellm/proxy/_experimental/out/{virtual-keys.html => virtual-keys/index.html} (100%) create mode 100644 litellm/proxy/db/db_transaction_queue/adaptive_router_update_queue.py create mode 100644 litellm/proxy/example_config_yaml/adaptive_router_example.yaml create mode 100644 litellm/router_strategy/adaptive_router/README.md create mode 100644 litellm/router_strategy/adaptive_router/__init__.py create mode 100644 litellm/router_strategy/adaptive_router/adaptive_router.py create mode 100644 litellm/router_strategy/adaptive_router/bandit.py create mode 100644 litellm/router_strategy/adaptive_router/classifier.py create mode 100644 litellm/router_strategy/adaptive_router/config.py create mode 100644 litellm/router_strategy/adaptive_router/hooks.py create mode 100644 litellm/router_strategy/adaptive_router/signals.py create mode 100644 scripts/verify_adaptive_router.py create mode 100644 tests/test_litellm/proxy/db/db_transaction_queue/test_adaptive_router_update_queue.py create mode 100644 tests/test_litellm/router_strategy/adaptive_router/__init__.py create mode 100644 tests/test_litellm/router_strategy/adaptive_router/fixtures/clean_no_signals.json create mode 100644 tests/test_litellm/router_strategy/adaptive_router/fixtures/clean_satisfaction.json create 
mode 100644 tests/test_litellm/router_strategy/adaptive_router/fixtures/disengagement_giveup.json create mode 100644 tests/test_litellm/router_strategy/adaptive_router/fixtures/exhaustion_429.json create mode 100644 tests/test_litellm/router_strategy/adaptive_router/fixtures/exhaustion_context_overflow.json create mode 100644 tests/test_litellm/router_strategy/adaptive_router/fixtures/failure_tool_error.json create mode 100644 tests/test_litellm/router_strategy/adaptive_router/fixtures/loop_same_tool.json create mode 100644 tests/test_litellm/router_strategy/adaptive_router/fixtures/misalignment_rephrase.json create mode 100644 tests/test_litellm/router_strategy/adaptive_router/fixtures/mixed_failure_then_satisfaction.json create mode 100644 tests/test_litellm/router_strategy/adaptive_router/fixtures/stagnation_repeat.json create mode 100644 tests/test_litellm/router_strategy/adaptive_router/test_adaptive_router.py create mode 100644 tests/test_litellm/router_strategy/adaptive_router/test_async_pre_routing.py create mode 100644 tests/test_litellm/router_strategy/adaptive_router/test_bandit.py create mode 100644 tests/test_litellm/router_strategy/adaptive_router/test_classifier.py create mode 100644 tests/test_litellm/router_strategy/adaptive_router/test_config.py create mode 100644 tests/test_litellm/router_strategy/adaptive_router/test_e2e_adaptive_router.py create mode 100644 tests/test_litellm/router_strategy/adaptive_router/test_hooks.py create mode 100644 tests/test_litellm/router_strategy/adaptive_router/test_router_dispatch.py create mode 100644 tests/test_litellm/router_strategy/adaptive_router/test_signals.py create mode 100644 tests/test_litellm/router_strategy/adaptive_router/test_state_endpoint.py diff --git a/litellm-proxy-extras/litellm_proxy_extras/migrations/20260418000000_add_adaptive_router_tables/migration.sql b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260418000000_add_adaptive_router_tables/migration.sql new file mode 100644 index 
0000000000..4d61db1115
--- /dev/null
+++ b/litellm-proxy-extras/litellm_proxy_extras/migrations/20260418000000_add_adaptive_router_tables/migration.sql
@@ -0,0 +1,40 @@
+-- One row per (router, request_type, model). Hot path on every routing decision.
+CREATE TABLE "LiteLLM_AdaptiveRouterState" (
+    router_name TEXT NOT NULL,
+    request_type TEXT NOT NULL,
+    model_name TEXT NOT NULL,
+    alpha DOUBLE PRECISION NOT NULL,
+    beta DOUBLE PRECISION NOT NULL,
+    total_samples INTEGER NOT NULL DEFAULT 0,
+    last_updated_at TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    PRIMARY KEY (router_name, request_type, model_name)
+);
+
+-- One row per (session, router, model). Updated per turn via the queue.
+-- Nullability mirrors schema.prisma: only the fields marked optional there stay nullable.
+CREATE TABLE "LiteLLM_AdaptiveRouterSession" (
+    session_id TEXT NOT NULL,
+    router_name TEXT NOT NULL,
+    model_name TEXT NOT NULL,
+    classified_type TEXT NOT NULL,
+    misalignment_count INTEGER NOT NULL DEFAULT 0,
+    stagnation_count INTEGER NOT NULL DEFAULT 0,
+    disengagement_count INTEGER NOT NULL DEFAULT 0,
+    satisfaction_count INTEGER NOT NULL DEFAULT 0,
+    failure_count INTEGER NOT NULL DEFAULT 0,
+    loop_count INTEGER NOT NULL DEFAULT 0,
+    exhaustion_count INTEGER NOT NULL DEFAULT 0,
+    last_user_content TEXT,
+    last_assistant_content TEXT,
+    tool_call_history JSONB NOT NULL DEFAULT '[]',
+    pending_tool_calls JSONB NOT NULL DEFAULT '{}',
+    turn_count INTEGER NOT NULL DEFAULT 0,
+    last_processed_turn INTEGER NOT NULL DEFAULT -1,
+    clean_credit_awarded BOOLEAN NOT NULL DEFAULT FALSE,
+    terminal_status INTEGER,
+    last_activity_at TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
+    PRIMARY KEY (session_id, router_name, model_name)
+);
+
+CREATE INDEX "idx_adaptive_router_session_activity"
+    ON "LiteLLM_AdaptiveRouterSession" (last_activity_at);
diff --git a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma index ce3f5f131f..4e448b22a1 100644 --- a/litellm-proxy-extras/litellm_proxy_extras/schema.prisma +++ b/litellm-proxy-extras/litellm_proxy_extras/schema.prisma @@ -1219,3 +1219,46 @@ model LiteLLM_ClaudeCodePluginTable {
@@map("LiteLLM_ClaudeCodePluginTable") } + +// Per-(router, request_type, model) Beta posterior for the adaptive router. +model LiteLLM_AdaptiveRouterState { + router_name String + request_type String + model_name String + alpha Float + beta Float + total_samples Int @default(0) + last_updated_at DateTime @default(now()) + + @@id([router_name, request_type, model_name]) +} + +// Per-(session, router, model) signal counters for the adaptive router. +model LiteLLM_AdaptiveRouterSession { + session_id String + router_name String + model_name String + classified_type String + + misalignment_count Int @default(0) + stagnation_count Int @default(0) + disengagement_count Int @default(0) + satisfaction_count Int @default(0) + failure_count Int @default(0) + loop_count Int @default(0) + exhaustion_count Int @default(0) + + last_user_content String? + last_assistant_content String? + tool_call_history Json @default("[]") + pending_tool_calls Json @default("{}") + + turn_count Int @default(0) + last_processed_turn Int @default(-1) + clean_credit_awarded Boolean @default(false) + terminal_status Int? 
+ last_activity_at DateTime @default(now()) + + @@id([session_id, router_name, model_name]) + @@index([last_activity_at]) +} diff --git a/litellm/proxy/_experimental/out/404.html b/litellm/proxy/_experimental/out/404/index.html similarity index 100% rename from litellm/proxy/_experimental/out/404.html rename to litellm/proxy/_experimental/out/404/index.html diff --git a/litellm/proxy/_experimental/out/_not-found.html b/litellm/proxy/_experimental/out/_not-found/index.html similarity index 100% rename from litellm/proxy/_experimental/out/_not-found.html rename to litellm/proxy/_experimental/out/_not-found/index.html diff --git a/litellm/proxy/_experimental/out/api-reference.html b/litellm/proxy/_experimental/out/api-reference/index.html similarity index 100% rename from litellm/proxy/_experimental/out/api-reference.html rename to litellm/proxy/_experimental/out/api-reference/index.html diff --git a/litellm/proxy/_experimental/out/chat.html b/litellm/proxy/_experimental/out/chat/index.html similarity index 100% rename from litellm/proxy/_experimental/out/chat.html rename to litellm/proxy/_experimental/out/chat/index.html diff --git a/litellm/proxy/_experimental/out/experimental/api-playground.html b/litellm/proxy/_experimental/out/experimental/api-playground/index.html similarity index 100% rename from litellm/proxy/_experimental/out/experimental/api-playground.html rename to litellm/proxy/_experimental/out/experimental/api-playground/index.html diff --git a/litellm/proxy/_experimental/out/experimental/budgets.html b/litellm/proxy/_experimental/out/experimental/budgets/index.html similarity index 100% rename from litellm/proxy/_experimental/out/experimental/budgets.html rename to litellm/proxy/_experimental/out/experimental/budgets/index.html diff --git a/litellm/proxy/_experimental/out/experimental/caching.html b/litellm/proxy/_experimental/out/experimental/caching/index.html similarity index 100% rename from litellm/proxy/_experimental/out/experimental/caching.html 
rename to litellm/proxy/_experimental/out/experimental/caching/index.html diff --git a/litellm/proxy/_experimental/out/experimental/claude-code-plugins.html b/litellm/proxy/_experimental/out/experimental/claude-code-plugins/index.html similarity index 100% rename from litellm/proxy/_experimental/out/experimental/claude-code-plugins.html rename to litellm/proxy/_experimental/out/experimental/claude-code-plugins/index.html diff --git a/litellm/proxy/_experimental/out/experimental/old-usage.html b/litellm/proxy/_experimental/out/experimental/old-usage/index.html similarity index 100% rename from litellm/proxy/_experimental/out/experimental/old-usage.html rename to litellm/proxy/_experimental/out/experimental/old-usage/index.html diff --git a/litellm/proxy/_experimental/out/experimental/prompts.html b/litellm/proxy/_experimental/out/experimental/prompts/index.html similarity index 100% rename from litellm/proxy/_experimental/out/experimental/prompts.html rename to litellm/proxy/_experimental/out/experimental/prompts/index.html diff --git a/litellm/proxy/_experimental/out/experimental/tag-management.html b/litellm/proxy/_experimental/out/experimental/tag-management/index.html similarity index 100% rename from litellm/proxy/_experimental/out/experimental/tag-management.html rename to litellm/proxy/_experimental/out/experimental/tag-management/index.html diff --git a/litellm/proxy/_experimental/out/guardrails.html b/litellm/proxy/_experimental/out/guardrails/index.html similarity index 100% rename from litellm/proxy/_experimental/out/guardrails.html rename to litellm/proxy/_experimental/out/guardrails/index.html diff --git a/litellm/proxy/_experimental/out/login.html b/litellm/proxy/_experimental/out/login/index.html similarity index 100% rename from litellm/proxy/_experimental/out/login.html rename to litellm/proxy/_experimental/out/login/index.html diff --git a/litellm/proxy/_experimental/out/logs.html b/litellm/proxy/_experimental/out/logs/index.html similarity index 
100% rename from litellm/proxy/_experimental/out/logs.html rename to litellm/proxy/_experimental/out/logs/index.html diff --git a/litellm/proxy/_experimental/out/mcp/oauth/callback.html b/litellm/proxy/_experimental/out/mcp/oauth/callback/index.html similarity index 100% rename from litellm/proxy/_experimental/out/mcp/oauth/callback.html rename to litellm/proxy/_experimental/out/mcp/oauth/callback/index.html diff --git a/litellm/proxy/_experimental/out/model-hub.html b/litellm/proxy/_experimental/out/model-hub/index.html similarity index 100% rename from litellm/proxy/_experimental/out/model-hub.html rename to litellm/proxy/_experimental/out/model-hub/index.html diff --git a/litellm/proxy/_experimental/out/model_hub.html b/litellm/proxy/_experimental/out/model_hub/index.html similarity index 100% rename from litellm/proxy/_experimental/out/model_hub.html rename to litellm/proxy/_experimental/out/model_hub/index.html diff --git a/litellm/proxy/_experimental/out/model_hub_table.html b/litellm/proxy/_experimental/out/model_hub_table/index.html similarity index 100% rename from litellm/proxy/_experimental/out/model_hub_table.html rename to litellm/proxy/_experimental/out/model_hub_table/index.html diff --git a/litellm/proxy/_experimental/out/models-and-endpoints.html b/litellm/proxy/_experimental/out/models-and-endpoints/index.html similarity index 100% rename from litellm/proxy/_experimental/out/models-and-endpoints.html rename to litellm/proxy/_experimental/out/models-and-endpoints/index.html diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding/index.html similarity index 100% rename from litellm/proxy/_experimental/out/onboarding.html rename to litellm/proxy/_experimental/out/onboarding/index.html diff --git a/litellm/proxy/_experimental/out/organizations.html b/litellm/proxy/_experimental/out/organizations/index.html similarity index 100% rename from litellm/proxy/_experimental/out/organizations.html rename to 
litellm/proxy/_experimental/out/organizations/index.html diff --git a/litellm/proxy/_experimental/out/playground.html b/litellm/proxy/_experimental/out/playground/index.html similarity index 100% rename from litellm/proxy/_experimental/out/playground.html rename to litellm/proxy/_experimental/out/playground/index.html diff --git a/litellm/proxy/_experimental/out/policies.html b/litellm/proxy/_experimental/out/policies/index.html similarity index 100% rename from litellm/proxy/_experimental/out/policies.html rename to litellm/proxy/_experimental/out/policies/index.html diff --git a/litellm/proxy/_experimental/out/settings/admin-settings.html b/litellm/proxy/_experimental/out/settings/admin-settings/index.html similarity index 100% rename from litellm/proxy/_experimental/out/settings/admin-settings.html rename to litellm/proxy/_experimental/out/settings/admin-settings/index.html diff --git a/litellm/proxy/_experimental/out/settings/logging-and-alerts.html b/litellm/proxy/_experimental/out/settings/logging-and-alerts/index.html similarity index 100% rename from litellm/proxy/_experimental/out/settings/logging-and-alerts.html rename to litellm/proxy/_experimental/out/settings/logging-and-alerts/index.html diff --git a/litellm/proxy/_experimental/out/settings/router-settings.html b/litellm/proxy/_experimental/out/settings/router-settings/index.html similarity index 100% rename from litellm/proxy/_experimental/out/settings/router-settings.html rename to litellm/proxy/_experimental/out/settings/router-settings/index.html diff --git a/litellm/proxy/_experimental/out/settings/ui-theme.html b/litellm/proxy/_experimental/out/settings/ui-theme/index.html similarity index 100% rename from litellm/proxy/_experimental/out/settings/ui-theme.html rename to litellm/proxy/_experimental/out/settings/ui-theme/index.html diff --git a/litellm/proxy/_experimental/out/teams.html b/litellm/proxy/_experimental/out/teams/index.html similarity index 100% rename from 
litellm/proxy/_experimental/out/teams.html rename to litellm/proxy/_experimental/out/teams/index.html diff --git a/litellm/proxy/_experimental/out/test-key.html b/litellm/proxy/_experimental/out/test-key/index.html similarity index 100% rename from litellm/proxy/_experimental/out/test-key.html rename to litellm/proxy/_experimental/out/test-key/index.html diff --git a/litellm/proxy/_experimental/out/tools/mcp-servers.html b/litellm/proxy/_experimental/out/tools/mcp-servers/index.html similarity index 100% rename from litellm/proxy/_experimental/out/tools/mcp-servers.html rename to litellm/proxy/_experimental/out/tools/mcp-servers/index.html diff --git a/litellm/proxy/_experimental/out/tools/vector-stores.html b/litellm/proxy/_experimental/out/tools/vector-stores/index.html similarity index 100% rename from litellm/proxy/_experimental/out/tools/vector-stores.html rename to litellm/proxy/_experimental/out/tools/vector-stores/index.html diff --git a/litellm/proxy/_experimental/out/usage.html b/litellm/proxy/_experimental/out/usage/index.html similarity index 100% rename from litellm/proxy/_experimental/out/usage.html rename to litellm/proxy/_experimental/out/usage/index.html diff --git a/litellm/proxy/_experimental/out/users.html b/litellm/proxy/_experimental/out/users/index.html similarity index 100% rename from litellm/proxy/_experimental/out/users.html rename to litellm/proxy/_experimental/out/users/index.html diff --git a/litellm/proxy/_experimental/out/virtual-keys.html b/litellm/proxy/_experimental/out/virtual-keys/index.html similarity index 100% rename from litellm/proxy/_experimental/out/virtual-keys.html rename to litellm/proxy/_experimental/out/virtual-keys/index.html diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 604e7d5f41..703fe6adc4 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,32 +1,83 @@ +# model_list: +# - model_name: claude-sonnet-4-6 +# 
litellm_params: {model: anthropic/claude-sonnet-4-6} +# model_info: +# litellm_routing_preferences: +# quality_tier: 1 +# keywords: [tin] +# - model_name: gpt-4o-mini +# litellm_params: {model: openai/gpt-4o-mini} +# model_info: +# litellm_routing_preferences: +# quality_tier: 1 +# keywords: [] +# - model_name: gpt-4o +# litellm_params: {model: openai/gpt-4o} +# model_info: +# litellm_routing_preferences: +# quality_tier: 2 +# keywords: [vision, function_calling] +# - model_name: opus +# litellm_params: {model: anthropic/claude-opus-4-7} +# model_info: +# litellm_routing_preferences: +# quality_tier: 3 +# keywords: ["architecture", "design"] +# - model_name: my-quality-router +# litellm_params: +# model: auto_router/adaptive_router +# adaptive_router_default_model: gpt-4o-mini +# adaptive_router_config: +# available_models: [gpt-4o-mini, gpt-4o, opus, claude-sonnet-4-6] +# Example proxy config for the adaptive router (v0). +# +# Wires one logical router ("smart-cheap-router") that adaptively picks between +# two real deployments ("fast" and "smart") based on per-session feedback signals. +# +# How to use from a client: +# POST /v1/chat/completions { "model": "smart-cheap-router", ... } +# Add { "metadata": { "litellm_session_id": "<your-session-id>" } } to enable +# sticky-session routing within a conversation. +# +# Required env vars: ANTHROPIC_API_KEY, DATABASE_URL. + model_list: - - # OpenAI model for /v1/chat/completions test — 200x custom pricing - - model_name: "gpt-4.1-mini" + # ---- The adaptive router "control" deployment ------------------------- + # `model_name` is what clients call. `available_models` lists the underlying + # deployments the router is allowed to pick from (must match other model_name + # entries in this list).
+ - model_name: smart-cheap-router litellm_params: - model: openai/gpt-4.1-mini - api_key: os.environ/OPENAI_API_KEY - model_info: - id: gpt-4.1-mini-custom-pricing - input_cost_per_token: 0.00004 # 100x standard ($0.40/1M = $0.0000004) - output_cost_per_token: 0.00016 # 100x standard ($1.60/1M = $0.0000016) + model: auto_router/adaptive_router + adaptive_router_config: + available_models: ["fast", "smart"] + weights: + quality: 0.7 + cost: 0.3 - # OpenAI model for /v1/responses test — 100x custom pricing - - model_name: "gpt-5" + # ---- Underlying deployments the router picks from --------------------- + - model_name: fast litellm_params: - model: openai/gpt-5 - api_key: os.environ/OPENAI_API_KEY - model_info: - id: gpt-5-custom-pricing - mode: "chat" - input_cost_per_token: 125 # 100x standard ($1.25/1M = $0.00000125) - output_cost_per_token: 10 # 100x standard ($10.00/1M = $0.00001) - - # Anthropic model for /v1/messages test — 100x custom pricing - - model_name: "claude-sonnet-4-20250514" - litellm_params: - model: anthropic/claude-sonnet-4-20250514 + model: anthropic/claude-sonnet-4-6 api_key: os.environ/ANTHROPIC_API_KEY + input_cost_per_token: 0.00000015 model_info: - id: claude-sonnet-4-custom-pricing - input_cost_per_token: 0.0003 # 100x standard ($0.000003) - output_cost_per_token: 0.0015 # 100x standard ($0.000015) \ No newline at end of file + adaptive_router_preferences: + quality_tier: 2 + strengths: [] + + - model_name: smart + litellm_params: + model: anthropic/claude-opus-4-7 + api_key: os.environ/ANTHROPIC_API_KEY + input_cost_per_token: 0.0000050 + model_info: + adaptive_router_preferences: + quality_tier: 3 + strengths: ["code_generation", "technical_design", "analytical_reasoning"] + +litellm_settings: + drop_params: True + +general_settings: + master_key: sk-1234 # REPLACE in production diff --git a/litellm/proxy/db/db_transaction_queue/adaptive_router_update_queue.py b/litellm/proxy/db/db_transaction_queue/adaptive_router_update_queue.py new 
file mode 100644
index 0000000000..3a76370e7d
--- /dev/null
+++ b/litellm/proxy/db/db_transaction_queue/adaptive_router_update_queue.py
@@ -0,0 +1,252 @@
+"""
+In-memory queues for adaptive router state and session updates.
+
+Pattern follows DailySpendUpdateQueue: hot path is fully in-memory; a background
+flusher task drains the aggregator and writes batches to Postgres.
+
+Two logical queues (one class):
+  1. STATE updates: increments to (router, request_type, model) bandit cell.
+     Aggregator key = (router_name, request_type, model_name)
+     Aggregated payload = {"delta_alpha": float, "delta_beta": float, "samples_added": int}
+  2. SESSION updates: full snapshot of a session row (last-write-wins per session+router+model).
+     Aggregator key = (session_id, router_name, model_name)
+     Aggregated payload = the full session state dict.
+
+The hot-path API never touches the database: each call only waits on a short
+in-memory lock and updates the aggregator. Flush is async and batched.
+"""
+
+from __future__ import annotations
+
+import asyncio
+from typing import Any, Dict, Tuple
+
+from litellm._logging import verbose_proxy_logger
+
+StateKey = Tuple[str, str, str]  # (router_name, request_type, model_name)
+SessionKey = Tuple[str, str, str]  # (session_id, router_name, model_name)
+
+
+class AdaptiveRouterUpdateQueue:
+    """
+    Aggregates adaptive-router writes in memory until a flusher persists them.
+
+    Two aggregators share one asyncio lock: summed bandit-cell deltas keyed by
+    StateKey, and last-write-wins session snapshots keyed by SessionKey.
+    """
+
+    def __init__(self) -> None:
+        self._state_agg: Dict[StateKey, Dict[str, float]] = {}
+        self._session_agg: Dict[SessionKey, Dict[str, Any]] = {}
+        self._lock = asyncio.Lock()
+        self._max_state_size_seen = 0
+        self._max_session_size_seen = 0
+
+    # ---- Hot-path: state delta -------------------------------------------
+
+    def _merge_state_delta(
+        self,
+        key: StateKey,
+        delta_alpha: float,
+        delta_beta: float,
+        samples_added: float,
+    ) -> None:
+        """Sum a delta into the state aggregator. Caller must hold ``self._lock``."""
+        current = self._state_agg.get(key)
+        if current is None:
+            self._state_agg[key] = {
+                "delta_alpha": delta_alpha,
+                "delta_beta": delta_beta,
+                "samples_added": samples_added,
+            }
+        else:
+            current["delta_alpha"] += delta_alpha
+            current["delta_beta"] += delta_beta
+            current["samples_added"] += samples_added
+        if len(self._state_agg) > self._max_state_size_seen:
+            self._max_state_size_seen = len(self._state_agg)
+
+    async def add_state_delta(
+        self,
+        router_name: str,
+        request_type: str,
+        model_name: str,
+        delta_alpha: float,
+        delta_beta: float,
+    ) -> None:
+        """Aggregate a bandit-cell delta. Multiple deltas to the same cell sum."""
+        key: StateKey = (router_name, request_type, model_name)
+        async with self._lock:
+            self._merge_state_delta(key, delta_alpha, delta_beta, 1)
+
+    # ---- Hot-path: session snapshot --------------------------------------
+
+    async def add_session_state(
+        self,
+        session_id: str,
+        router_name: str,
+        model_name: str,
+        state_dict: Dict[str, Any],
+    ) -> None:
+        """
+        Last-write-wins per session row. The state_dict is a snapshot of the
+        SessionState (signals counts + bookkeeping fields). The flusher will
+        upsert this into LiteLLM_AdaptiveRouterSession.
+
+        A shallow copy is stored so later top-level mutation of the caller's
+        dict cannot alter the queued snapshot.
+        """
+        key: SessionKey = (session_id, router_name, model_name)
+        async with self._lock:
+            self._session_agg[key] = dict(state_dict)
+            if len(self._session_agg) > self._max_session_size_seen:
+                self._max_session_size_seen = len(self._session_agg)
+
+    # ---- Flushers (called by background task) ----------------------------
+
+    async def flush_state_to_db(self, prisma_client: Any) -> int:
+        """
+        Drain state aggregator and apply to LiteLLM_AdaptiveRouterState.
+
+        Each cell is written with a single upsert whose update branch uses
+        atomic ``increment`` operations, so the addition happens inside the
+        database instead of a read-then-write in Python that concurrent
+        writers could interleave with and overwrite.
+
+        Cells whose write raises are logged and merged back into the
+        aggregator so the next flush retries them.
+
+        Returns the number of cells written successfully.
+        """
+        async with self._lock:
+            batch = self._state_agg
+            self._state_agg = {}
+
+        if not batch:
+            return 0
+
+        flushed = 0
+        failed: Dict[StateKey, Dict[str, float]] = {}
+        # Sort keys to give deterministic write order across writers and
+        # reduce the chance of cross-row deadlocks when other workers race us.
+        for key in sorted(batch):
+            router, rt, model = key
+            payload = batch[key]
+            delta_alpha = payload["delta_alpha"]
+            delta_beta = payload["delta_beta"]
+            samples_added = int(payload["samples_added"])
+            try:
+                await prisma_client.db.litellm_adaptiverouterstate.upsert(
+                    where={
+                        "router_name_request_type_model_name": {
+                            "router_name": router,
+                            "request_type": rt,
+                            "model_name": model,
+                        }
+                    },
+                    data={
+                        # A missing row counts as zero, so the first write
+                        # stores the deltas themselves.
+                        "create": {
+                            "router_name": router,
+                            "request_type": rt,
+                            "model_name": model,
+                            "alpha": delta_alpha,
+                            "beta": delta_beta,
+                            "total_samples": samples_added,
+                        },
+                        # NOTE(review): last_updated_at is not written on update,
+                        # so it keeps its creation-time default -- confirm intent.
+                        "update": {
+                            "alpha": {"increment": delta_alpha},
+                            "beta": {"increment": delta_beta},
+                            "total_samples": {"increment": samples_added},
+                        },
+                    },
+                )
+            except Exception as e:
+                verbose_proxy_logger.exception(
+                    "AdaptiveRouterUpdateQueue: failed to flush state for %s: %s",
+                    key,
+                    e,
+                )
+                failed[key] = payload
+            else:
+                flushed += 1
+
+        if failed:
+            # Merge (not overwrite): new deltas may have arrived while flushing.
+            async with self._lock:
+                for key, payload in failed.items():
+                    self._merge_state_delta(
+                        key,
+                        payload["delta_alpha"],
+                        payload["delta_beta"],
+                        payload["samples_added"],
+                    )
+
+        return flushed
+
+    async def flush_session_to_db(self, prisma_client: Any) -> int:
+        """
+        Drain session aggregator and upsert into LiteLLM_AdaptiveRouterSession.
+
+        Rows whose upsert raises are logged and dropped; they are not retried.
+
+        Returns the number of session rows written successfully.
+        """
+        async with self._lock:
+            batch = self._session_agg
+            self._session_agg = {}
+
+        if not batch:
+            return 0
+
+        flushed = 0
+        for key in sorted(batch):
+            session_id, router, model = key
+            payload = batch[key]
+            try:
+                # NOTE: Prisma client lower-cases model names, so
+                # `LiteLLM_AdaptiveRouterSession` -> `litellm_adaptiveroutersession`
+                # (single 's', not 'litellm_adaptiverouterssession').
+                await prisma_client.db.litellm_adaptiveroutersession.upsert(
+                    where={
+                        "session_id_router_name_model_name": {
+                            "session_id": session_id,
+                            "router_name": router,
+                            "model_name": model,
+                        }
+                    },
+                    data={
+                        "create": {
+                            "session_id": session_id,
+                            "router_name": router,
+                            "model_name": model,
+                            **payload,
+                        },
+                        "update": payload,
+                    },
+                )
+            except Exception as e:
+                verbose_proxy_logger.exception(
+                    "AdaptiveRouterUpdateQueue: failed to flush session for %s: %s",
+                    key,
+                    e,
+                )
+            else:
+                flushed += 1
+
+        return flushed
+
+    # ---- Observability ---------------------------------------------------
+
+    async def queue_size(self) -> Dict[str, int]:
+        """Return pending entry counts and the high-water marks seen so far."""
+        async with self._lock:
+            return {
+                "state_pending": len(self._state_agg),
+                "session_pending": len(self._session_agg),
+                "max_state_seen": self._max_state_size_seen,
+                "max_session_seen": self._max_session_size_seen,
+            }
diff --git a/litellm/proxy/example_config_yaml/adaptive_router_example.yaml b/litellm/proxy/example_config_yaml/adaptive_router_example.yaml new file mode 100644 index 0000000000..7cc060420a --- /dev/null +++ b/litellm/proxy/example_config_yaml/adaptive_router_example.yaml @@ -0,0 +1,52 @@ +# Example proxy config for the adaptive router (v0). +# +# Wires one logical router ("smart-cheap-router") that adaptively picks between +# two real deployments ("fast" and "smart") based on per-session feedback signals. +# +# How to use from a client: +# POST /v1/chat/completions { "model": "smart-cheap-router", ... } +# Add { "metadata": { "litellm_session_id": "" } } to enable +# sticky-session routing within a conversation.
+# +# Required env vars: OPENAI_API_KEY, DATABASE_URL. + +model_list: + # ---- The adaptive router "control" deployment ------------------------- + # `model_name` is what clients call. `available_models` lists the underlying + # deployments the router is allowed to pick from (must match other model_name + # entries in this list). + - model_name: smart-cheap-router + litellm_params: + model: auto_router/adaptive_router # required prefix that marks this entry as an adaptive router; never called directly -- router picks from available_models + adaptive_router_config: + available_models: ["fast", "smart"] + weights: + quality: 0.7 + cost: 0.3 + + # ---- Underlying deployments the router picks from --------------------- + - model_name: fast + litellm_params: + model: openai/gpt-4o-mini + api_key: os.environ/OPENAI_API_KEY + input_cost_per_token: 0.00000015 + model_info: + adaptive_router_preferences: + quality_tier: 2 + strengths: [] + + - model_name: smart + litellm_params: + model: openai/gpt-4o + api_key: os.environ/OPENAI_API_KEY + input_cost_per_token: 0.0000050 + model_info: + adaptive_router_preferences: + quality_tier: 3 + strengths: ["code_generation", "technical_design", "analytical_reasoning"] + +litellm_settings: + drop_params: True + +general_settings: + master_key: sk-1234 # REPLACE in production diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 2d789b982d..67a3414d0b 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -952,6 +952,10 @@ async def proxy_startup_event(app: FastAPI): # noqa: PLR0915 _run_background_health_check() ) # start the background health check coroutine. + # Start adaptive-router queue flusher if any AdaptiveRouter is configured.
+ if llm_router is not None and getattr(llm_router, "adaptive_routers", None): + asyncio.create_task(_adaptive_router_flusher_loop()) + ## [Optional] Initialize dd tracer ProxyStartupEvent._init_dd_tracer() @@ -2201,9 +2205,11 @@ def run_ollama_serve(): with open(os.devnull, "w") as devnull: subprocess.Popen(command, stdout=devnull, stderr=devnull) except Exception as e: - verbose_proxy_logger.debug(f""" + verbose_proxy_logger.debug( + f""" LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve` - """) + """ + ) def _get_process_rss_mb() -> Optional[float]: @@ -2385,6 +2391,31 @@ def _write_health_state_to_router_cache( ) +_ADAPTIVE_ROUTER_FLUSH_INTERVAL_SECONDS = 10 + + +async def _adaptive_router_flusher_loop(): + """ + Drain every AdaptiveRouter's in-memory state + session aggregators into + Postgres on a fixed cadence. Hot-path writes go to memory; this loop is + the only writer to the adaptive router DB tables. + """ + global llm_router, prisma_client + while True: + try: + await asyncio.sleep(_ADAPTIVE_ROUTER_FLUSH_INTERVAL_SECONDS) + adaptive_routers = getattr(llm_router, "adaptive_routers", None) or {} + if not adaptive_routers or prisma_client is None: + continue + for ar in adaptive_routers.values(): + await ar.queue.flush_state_to_db(prisma_client) + await ar.queue.flush_session_to_db(prisma_client) + except asyncio.CancelledError: + raise + except Exception: + verbose_proxy_logger.exception("adaptive_router flusher iteration failed") + + async def _run_background_health_check(): """ Periodically run health checks in the background on the endpoints. 
@@ -13877,6 +13908,38 @@ async def home(request: Request): return "LiteLLM: RUNNING" +@router.get( + "/adaptive_router/state", + tags=["adaptive_router"], + dependencies=[Depends(user_api_key_auth)], +) +async def get_adaptive_router_state( + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """Return live bandit posteriors + queue depth for every configured adaptive router. + + Admin-only. Returns 404 if no adaptive router is configured. + + Response shape: `{"routers": [, ...]}` — one snapshot per + adaptive-router deployment. Each snapshot's `router_name` field identifies + which deployment it came from. + """ + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: + raise HTTPException( + status_code=403, + detail={"error": CommonProxyErrors.not_allowed_access.value}, + ) + if llm_router is None or not llm_router.adaptive_routers: + raise HTTPException( + status_code=404, + detail={"error": "No adaptive_router is configured on this proxy."}, + ) + snapshots = [ + await ar.get_state_snapshot() for ar in llm_router.adaptive_routers.values() + ] + return {"routers": snapshots} + + @router.get("/routes", dependencies=[Depends(user_api_key_auth)]) async def get_routes(): """ diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index ce3f5f131f..4e448b22a1 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -1219,3 +1219,46 @@ model LiteLLM_ClaudeCodePluginTable { @@map("LiteLLM_ClaudeCodePluginTable") } + +// Per-(router, request_type, model) Beta posterior for the adaptive router. +model LiteLLM_AdaptiveRouterState { + router_name String + request_type String + model_name String + alpha Float + beta Float + total_samples Int @default(0) + last_updated_at DateTime @default(now()) + + @@id([router_name, request_type, model_name]) +} + +// Per-(session, router, model) signal counters for the adaptive router. 
+model LiteLLM_AdaptiveRouterSession { + session_id String + router_name String + model_name String + classified_type String + + misalignment_count Int @default(0) + stagnation_count Int @default(0) + disengagement_count Int @default(0) + satisfaction_count Int @default(0) + failure_count Int @default(0) + loop_count Int @default(0) + exhaustion_count Int @default(0) + + last_user_content String? + last_assistant_content String? + tool_call_history Json @default("[]") + pending_tool_calls Json @default("{}") + + turn_count Int @default(0) + last_processed_turn Int @default(-1) + clean_credit_awarded Boolean @default(false) + terminal_status Int? + last_activity_at DateTime @default(now()) + + @@id([session_id, router_name, model_name]) + @@index([last_activity_at]) +} diff --git a/litellm/router.py b/litellm/router.py index 9185e437a3..33736fbfff 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -200,12 +200,16 @@ if TYPE_CHECKING: from litellm.router_strategy.complexity_router.complexity_router import ( ComplexityRouter, ) + from litellm.router_strategy.adaptive_router.adaptive_router import ( + AdaptiveRouter, + ) Span = Union[_Span, Any] else: Span = Any AutoRouter = Any ComplexityRouter = Any + AdaptiveRouter = Any PreRoutingHookResponse = Any @@ -464,6 +468,7 @@ class Router: ) # {"TEAM_ID": PatternMatchRouter} self.auto_routers: Dict[str, "AutoRouter"] = {} self.complexity_routers: Dict[str, "ComplexityRouter"] = {} + self.adaptive_routers: Dict[str, "AdaptiveRouter"] = {} # Initialize model_group_alias early since it's used in set_model_list self.model_group_alias: Dict[str, Union[str, RouterModelGroupAliasItem]] = ( @@ -3864,7 +3869,7 @@ class Router: self._add_deployment_model_to_endpoint_for_llm_passthrough_route( kwargs=kwargs, model=model, model_name=model_name ) - + # Get custom_llm_provider from deployment params try: custom_llm_provider = data.get("custom_llm_provider") @@ -3872,10 +3877,12 @@ class Router: model=data["model"], 
custom_llm_provider=custom_llm_provider, ) - custom_llm_provider = custom_llm_provider or inferred_custom_llm_provider + custom_llm_provider = ( + custom_llm_provider or inferred_custom_llm_provider + ) except Exception: custom_llm_provider = None - + # Build response kwargs response_kwargs = { **data, @@ -3885,7 +3892,7 @@ class Router: # Only set custom_llm_provider if it's not None if custom_llm_provider is not None: response_kwargs["custom_llm_provider"] = custom_llm_provider - + response = original_generic_function(**response_kwargs) rpm_semaphore = self._get_client( @@ -3981,7 +3988,9 @@ class Router: model=data["model"], custom_llm_provider=custom_llm_provider, ) - custom_llm_provider = custom_llm_provider or inferred_custom_llm_provider + custom_llm_provider = ( + custom_llm_provider or inferred_custom_llm_provider + ) except Exception: custom_llm_provider = None @@ -4246,7 +4255,9 @@ class Router: custom_llm_provider=custom_llm_provider, ) # Preserve explicitly stored provider, fallback to inferred - custom_llm_provider = custom_llm_provider or inferred_custom_llm_provider + custom_llm_provider = ( + custom_llm_provider or inferred_custom_llm_provider + ) ## REPLACE MODEL IN FILE WITH SELECTED DEPLOYMENT ## purpose = cast(Optional[OpenAIFilesPurpose], kwargs.get("purpose")) @@ -5355,9 +5366,9 @@ class Router: e, (litellm.ContextWindowExceededError, litellm.ContentPolicyViolationError), ) - _request_team_id: Optional[str] = ( - kwargs.get("metadata", {}) or {} - ).get("user_api_key_team_id") + _request_team_id: Optional[str] = (kwargs.get("metadata", {}) or {}).get( + "user_api_key_team_id" + ) all_deployments = self._get_all_deployments( model_name=original_model_group, team_id=_request_team_id ) @@ -6804,10 +6815,13 @@ class Router: Check if the deployment is an auto-router deployment (semantic router). Returns True if the litellm_params model starts with "auto_router/" - but NOT "auto_router/complexity_router" (which uses complexity routing). 
+ but NOT "auto_router/complexity_router" or "auto_router/adaptive_router" + (which use the complexity-router and adaptive-router strategies). """ if litellm_params.model.startswith("auto_router/complexity_router"): return False # This is handled by complexity_router + if litellm_params.model.startswith("auto_router/adaptive_router"): + return False # This is handled by adaptive_router if litellm_params.model.startswith("auto_router/"): return True return False @@ -6914,6 +6928,121 @@ class Router: ) self.complexity_routers[deployment.model_name] = complexity_router + def _is_adaptive_router_deployment(self, litellm_params: LiteLLM_Params) -> bool: + """True when this deployment opts in via the `auto_router/adaptive_router` model prefix.""" + return litellm_params.model.startswith("auto_router/adaptive_router") + + def _finalize_adaptive_router_if_configured(self) -> None: + """Locate every adaptive-router deployment in the finalized model_list and + build an AdaptiveRouter for each. Safe no-op when none are configured. 
+ Idempotent: skips any deployment whose model_name is already initialized.""" + for entry in self.model_list or []: + lp = ( + entry.get("litellm_params") + if isinstance(entry, dict) + else entry.litellm_params + ) + lp_model = ( + (lp.get("model") if isinstance(lp, dict) else lp.model) if lp else None + ) + if not (lp_model and lp_model.startswith("auto_router/adaptive_router")): + continue + model_name = ( + entry.get("model_name") if isinstance(entry, dict) else entry.model_name + ) + if not model_name or not lp: + continue + if model_name in self.adaptive_routers: + continue + deployment = Deployment( + model_name=model_name, + litellm_params=( + lp if not isinstance(lp, dict) else LiteLLM_Params(**lp) + ), + model_info=( + entry.get("model_info") + if isinstance(entry, dict) + else entry.model_info + ), + ) + self.init_adaptive_router_deployment(deployment=deployment) + + def init_adaptive_router_deployment(self, deployment: Deployment) -> None: + """ + Build an AdaptiveRouter instance for this deployment and register its + post-call hook. Multiple adaptive routers can coexist on a single Router, + keyed by `deployment.model_name`. + + `model_to_prefs` and `model_to_cost` are derived from the OTHER models + already registered in `self.model_list` whose `model_name` appears in + `available_models`. Models not yet registered fall back to defaults. + """ + # Local import: AdaptiveRouter -> hooks -> classifier all import litellm + # internals which transitively import this module. (AGENTS.md exception clause.) 
+ from litellm.router_strategy.adaptive_router.adaptive_router import ( + AdaptiveRouter, + ) + from litellm.router_strategy.adaptive_router.hooks import ( + AdaptiveRouterPostCallHook, + ) + from litellm.types.router import ( + AdaptiveRouterConfig, + AdaptiveRouterPreferences, + ) + + raw_config = deployment.litellm_params.adaptive_router_config + if raw_config is None: + raise ValueError( + "adaptive_router_config is required for adaptive-router deployments." + ) + + config = AdaptiveRouterConfig(**raw_config) + + model_to_prefs: Dict[str, AdaptiveRouterPreferences] = {} + model_to_cost: Dict[str, float] = {} + for d in self.model_list or []: + name = d.get("model_name") if isinstance(d, dict) else d.model_name + if name not in config.available_models: + continue + mi = d.get("model_info") if isinstance(d, dict) else d.model_info + mi_dict: Dict[str, Any] = ( + mi if isinstance(mi, dict) else (mi.model_dump() if mi else {}) + ) + prefs_raw = mi_dict.get("adaptive_router_preferences") + if prefs_raw is not None: + model_to_prefs[name] = AdaptiveRouterPreferences(**prefs_raw) + + # `input_cost_per_token` is a LiteLLM_Params field per types/router.py. + lp = d.get("litellm_params") if isinstance(d, dict) else d.litellm_params + lp_dict: Dict[str, Any] = ( + lp if isinstance(lp, dict) else (lp.model_dump() if lp else {}) + ) + cost = lp_dict.get("input_cost_per_token") + if cost is not None: + model_to_cost[name] = float(cost) + + if deployment.model_name in self.adaptive_routers: + raise ValueError( + f"Adaptive-router deployment {deployment.model_name} already exists. " + "Please use a different model name." 
+ ) + + adaptive_router = AdaptiveRouter( + router_name=deployment.model_name, + config=config, + model_to_prefs=model_to_prefs, + model_to_cost=model_to_cost, + ) + self.adaptive_routers[deployment.model_name] = adaptive_router + litellm.callbacks.append( + AdaptiveRouterPostCallHook(adaptive_router=adaptive_router) + ) + verbose_router_logger.info( + "AdaptiveRouter[%s] initialized with %d models", + deployment.model_name, + len(config.available_models), + ) + def deployment_is_active_for_environment(self, deployment: Deployment) -> bool: """ Function to check if a llm deployment is active for a given environment. Allows using the same config.yaml across multople environments @@ -7007,6 +7136,10 @@ class Router: # Note: model_name_to_deployment_indices is already built incrementally # by _create_deployment -> _add_model_to_list_and_index_map + # Deferred: build the AdaptiveRouter strategy now that all underlying + # deployments are visible in self.model_list. + self._finalize_adaptive_router_if_configured() + def _add_deployment(self, deployment: Deployment) -> Deployment: import os @@ -7134,6 +7267,11 @@ class Router: ): self.init_complexity_router_deployment(deployment=deployment) + # NOTE: adaptive-router deployments are deferred to the end of + # set_model_list() because their init needs visibility into the OTHER + # deployments listed in `available_models` (which may not yet have + # been processed when this one is created). 
+ return deployment def _initialize_deployment_for_pass_through( @@ -9645,6 +9783,19 @@ class Router: specific_deployment=specific_deployment, ) + ######################################################### + # Check if an adaptive-router should be used + ######################################################### + adaptive_router = self.adaptive_routers.get(model) + if adaptive_router is not None: + return await adaptive_router.async_pre_routing_hook( + model=model, + request_kwargs=request_kwargs, + messages=messages, + input=input, + specific_deployment=specific_deployment, + ) + return None def get_available_deployment( diff --git a/litellm/router_strategy/adaptive_router/README.md b/litellm/router_strategy/adaptive_router/README.md new file mode 100644 index 0000000000..b2b8a52089 --- /dev/null +++ b/litellm/router_strategy/adaptive_router/README.md @@ -0,0 +1,93 @@ +# Adaptive Router (v0) + +A request-type-aware routing strategy. For each incoming request, classify the +prompt into one of seven `RequestType` buckets (code generation, writing, +analytical reasoning, …), then Thompson-sample a Beta(α, β) bandit posterior +per `(request_type, model)` cell to pick the best model. Quality estimates are +combined with a normalized cost score via a weighted linear sum. + +A post-call hook reads the response and runs lightweight regex + tool-call +detectors (see `signals.py`) to award per-turn credit/blame to the model that +served the turn. Updates are batched in-memory and flushed to Postgres every +~10s by a background task in `proxy_server.py`. 
+ +## Config example + +```yaml +model_list: + - model_name: gpt-4o + litellm_params: + model: openai/gpt-4o + input_cost_per_token: 0.0000025 + model_info: + adaptive_router_preferences: + quality_tier: 3 + strengths: ["code_generation", "analytical_reasoning"] + + - model_name: gpt-4o-mini + litellm_params: + model: openai/gpt-4o-mini + input_cost_per_token: 0.00000015 + model_info: + adaptive_router_preferences: + quality_tier: 2 + strengths: ["general", "factual_lookup"] + + - model_name: smart-router + litellm_params: + model: auto_router/adaptive_router + adaptive_router_default_model: gpt-4o-mini + adaptive_router_config: + available_models: ["gpt-4o", "gpt-4o-mini"] + weights: + quality: 0.7 + cost: 0.3 +``` + +Callers may pass header `x-litellm-min-quality-tier: 3` (or metadata key +`min_quality_tier: 3`) to force selection from tier-3-or-higher models only. + +## Behavior summary + +- **Cold start.** Each `(request_type, model)` cell starts with a + Beta prior whose mean = `BASE_TIER_WEIGHT[tier] (+ STRENGTH_BONUS if declared)` + and total mass = `COLD_START_MASS` (10). About ten real observations move it + meaningfully. +- **Per-request decision.** Sample once per eligible model, score with + `quality_weight·sample + cost_weight·normalized_cost`, pick the argmax. + Routing is stateless per-turn — no sticky lookup. Each call resamples. +- **Owner-cache attribution.** Post-call, the conversation's first picked + model claims an "owner slot" for `OWNER_CACHE_TTL_SECONDS` (24h). Later + turns of the same conversation only fire bandit/state updates if the + same model handled them — mismatches are dropped (no attribution) and + counted in `skipped_updates_total`. Conversation identity is the + client-supplied `litellm_session_id` if present, otherwise a sha256 over + caller identity (api key hash, team, user, end-user) + the first message. +- **Per-turn updates.** `satisfaction → +α`. `misalignment, stagnation, + disengagement, failure → +β` (each).
`loop → +0.5β`. `exhaustion → 0` + (uptime, not quality). Skipped if conversation has fewer than + `SIGNAL_GATE_MIN_MESSAGES` messages. +- **Persistence.** Bandit cells: aggregated deltas, eventually consistent. + Session rows: last-write-wins snapshots. + +## Known v0 limitations + +- **Latency is not in the score.** Quality + cost only. A pathologically slow + model can still be picked. +- **Hard sample cap at 200.** Once `α + β > 200`, deltas are silently dropped. + No rescaling — drift is a v1 concern. +- **24h owner-cache TTL.** No explicit eviction below TTL. The in-memory map + can grow if traffic patterns produce many one-shot sessions. +- **Owner-recovery skew.** If model A "owns" a conversation but is then + dethroned in the bandit, later turns served by model B are dropped — so + bandit updates for that conversation flatline until A's TTL expires. + Tracked via `skipped_updates_total`. +- **Signals are regex + tool-call only.** No LLM-judge, no embedding similarity, + no exemplar storage. Signals are best-effort and biased toward English. +- **One AdaptiveRouter per `model_name`.** Several `auto_router/adaptive_router` deployments + can share one `litellm.Router`; re-initializing an existing `model_name` raises. +- **Bandit-delta mapping is unvalidated.** `_compute_bandit_delta` is a v0 + guess; expect to retune after the first ~1000 sessions of real traffic. +- **`request_type` is classified per turn from the latest user message only.** + The first turn's classification doesn't carry forward; a multi-turn session + may shift bucket between turns. diff --git a/litellm/router_strategy/adaptive_router/__init__.py b/litellm/router_strategy/adaptive_router/__init__.py new file mode 100644 index 0000000000..d7f55ebced --- /dev/null +++ b/litellm/router_strategy/adaptive_router/__init__.py @@ -0,0 +1,6 @@ +"""Adaptive router strategy.
See README.md for design overview.""" + +from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter +from litellm.router_strategy.adaptive_router.hooks import AdaptiveRouterPostCallHook + +__all__ = ["AdaptiveRouter", "AdaptiveRouterPostCallHook"] diff --git a/litellm/router_strategy/adaptive_router/adaptive_router.py b/litellm/router_strategy/adaptive_router/adaptive_router.py new file mode 100644 index 0000000000..d73062ae96 --- /dev/null +++ b/litellm/router_strategy/adaptive_router/adaptive_router.py @@ -0,0 +1,344 @@ +""" +Main adaptive router strategy. See README.md for design overview. + +One AdaptiveRouter instance per router_name. Holds in-memory caches: +- _cells: Beta(alpha, beta) bandit posteriors per (request_type, model) +- _owner_cache: session_key -> (owner_model, expires_at) — the first model + picked for a conversation owns its bandit-update slot +- _session_states: (session_key, model) -> SessionState for incremental signal updates + +Owns the AdaptiveRouterUpdateQueue used by the proxy's flusher to persist +state and session snapshots back to Postgres. + +Routing is stateless per-turn (Thompson sample fresh on every call). The +owner cache is consulted only at post-call time to decide whether a turn's +signals should fire a bandit update — turns served by a different model than +the conversation's owner are skipped to avoid cross-model misattribution. 
+""" + +from __future__ import annotations + +import asyncio +import time +from dataclasses import asdict +from typing import Any, Dict, List, Optional, Tuple, Union, cast + +from litellm._logging import verbose_router_logger +from litellm.litellm_core_utils.prompt_templates.common_utils import ( + get_last_user_message, +) +from litellm.proxy.db.db_transaction_queue.adaptive_router_update_queue import ( + AdaptiveRouterUpdateQueue, +) +from litellm.router_strategy.adaptive_router.bandit import ( + BanditCell, + apply_delta, + initial_cell, + pick_best, +) +from litellm.router_strategy.adaptive_router.classifier import classify_prompt +from litellm.router_strategy.adaptive_router.config import ( + ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY, + OWNER_CACHE_TTL_SECONDS, +) +from litellm.router_strategy.adaptive_router.signals import ( + SessionState, + SignalDelta, + Turn, + apply_turn, +) +from litellm.types.llms.openai import AllMessageValues +from litellm.types.router import ( + AdaptiveRouterConfig, + AdaptiveRouterPreferences, + PreRoutingHookResponse, + RequestType, +) + + +def _default_prefs() -> AdaptiveRouterPreferences: + """Tier-2 prior with no declared strengths; used when a model omits prefs.""" + return AdaptiveRouterPreferences(quality_tier=2, strengths=[]) + + +class AdaptiveRouter: + """One instance per router_name. 
Holds in-memory caches + the update queue.""" + + def __init__( + self, + router_name: str, + config: AdaptiveRouterConfig, + model_to_prefs: Dict[str, AdaptiveRouterPreferences], + model_to_cost: Dict[str, float], + ) -> None: + self.router_name = router_name + self.config = config + self.model_to_prefs = model_to_prefs + self.model_to_cost = model_to_cost + self.queue = AdaptiveRouterUpdateQueue() + + self._cells: Dict[Tuple[RequestType, str], BanditCell] = {} + self._owner_cache: Dict[str, Tuple[str, float]] = {} + self._session_states: Dict[Tuple[str, str], SessionState] = {} + self._skipped_updates_total: int = 0 + self._lock = asyncio.Lock() + + self._init_cold_start_cells() + + # ---- Cold-start ------------------------------------------------------ + + def _init_cold_start_cells(self) -> None: + """Populate _cells with cold-start priors for every (rt, model) combination.""" + for rt in RequestType: + for model in self.config.available_models: + prefs = self.model_to_prefs.get(model) or _default_prefs() + self._cells[(rt, model)] = initial_cell(prefs, rt) + + async def load_state_from_db(self, prisma_client: Any) -> None: + """Override cold-start cells with persisted state. Called once at startup.""" + if prisma_client is None: + return + try: + rows = await prisma_client.db.litellm_adaptiverouterstate.find_many( + where={"router_name": self.router_name} + ) + loaded = 0 + for row in rows: + try: + rt = RequestType(row.request_type) + except ValueError: + # Unknown taxonomy entry from an older/newer version. Skip. 
+ continue + if row.model_name not in self.config.available_models: + continue + self._cells[(rt, row.model_name)] = BanditCell( + alpha=row.alpha, beta=row.beta + ) + loaded += 1 + verbose_router_logger.info( + "AdaptiveRouter[%s]: loaded %d cells from DB", + self.router_name, + loaded, + ) + except Exception as e: + verbose_router_logger.exception( + "AdaptiveRouter[%s]: failed to load state from DB: %s", + self.router_name, + e, + ) + + # ---- Pre-routing hook ------------------------------------------------ + + async def async_pre_routing_hook( + self, + model: str, + request_kwargs: Dict[str, Any], + messages: Optional[List[Dict[str, Any]]] = None, + input: Optional[Union[str, List]] = None, + specific_deployment: Optional[bool] = False, + ) -> Optional[PreRoutingHookResponse]: + """ + Plugin entry point invoked by `Router.async_pre_routing_hook` when the + inbound `model` matches this adaptive router's `router_name`. + + Classifies the last user message, picks a logical model via the bandit, + and stashes the chosen model on `request_kwargs["metadata"]` so the + post-call hook can surface it as a response header. + + Routing is stateless per-turn: every call Thompson-samples fresh, + regardless of any prior pick for the same session. Cross-turn + attribution is enforced post-call via the owner cache (see + `claim_or_check_owner`). + """ + user_text = ( + get_last_user_message(cast(List[AllMessageValues], messages or [])) or "" + ) + + request_type = classify_prompt(user_text) + chosen_model = await self.pick_model(request_type=request_type) + verbose_router_logger.debug( + "AdaptiveRouter[%s]: classified=%s -> chose %s", + self.router_name, + request_type.value, + chosen_model, + ) + + # Relay the chosen logical model to the post-call hook, which surfaces + # it as the `x-litellm-adaptive-router-model` response header. We use + # `metadata` (not a top-level kwarg) so the value doesn't leak into + # `litellm.acompletion(**input_kwargs)`. 
+ kwargs_metadata = request_kwargs.setdefault("metadata", {}) + if isinstance(kwargs_metadata, dict): + kwargs_metadata[ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY] = chosen_model + + return PreRoutingHookResponse(model=chosen_model, messages=messages) + + # ---- Pick model ------------------------------------------------------ + + async def pick_model( + self, + request_type: RequestType, + min_quality_tier: Optional[int] = None, + ) -> str: + """Thompson-sample across eligible models. Stateless per-turn.""" + eligible = self._eligible_models(min_quality_tier) + if not eligible: + raise ValueError( + f"AdaptiveRouter[{self.router_name}]: no models meet " + f"min_quality_tier={min_quality_tier}" + ) + + cells = {m: self._cells[(request_type, m)] for m in eligible} + costs = {m: self.model_to_cost.get(m, 0.0) for m in eligible} + return pick_best( + cells, + costs, + quality_weight=self.config.weights.quality, + cost_weight=self.config.weights.cost, + ) + + def claim_or_check_owner(self, session_key: str, current_model: str) -> bool: + """Resolve attribution for a turn under stateless routing. + + Returns True iff this turn should fire a bandit/state update. The + first call for a `session_key` claims ownership for `current_model` + and returns True. Subsequent calls return True only if the owner is + still live AND matches `current_model`. Mismatches (a different + model handled this turn) and expired owners both increment + `_skipped_updates_total` and return False — no attribution. + """ + now = time.time() + existing = self._owner_cache.get(session_key) + if existing is not None and existing[1] > now: + owner_model, _ = existing + if owner_model == current_model: + return True + self._skipped_updates_total += 1 + return False + + # No live owner -> claim for current_model. 
+ self._owner_cache[session_key] = ( + current_model, + now + OWNER_CACHE_TTL_SECONDS, + ) + return True + + async def get_state_snapshot(self) -> Dict[str, Any]: + """In-memory snapshot for the introspection endpoint. Cheap; no DB hit.""" + cells = [] + for (rt, model), cell in sorted( + self._cells.items(), key=lambda kv: (kv[0][0].value, kv[0][1]) + ): + total = cell.alpha + cell.beta + cells.append( + { + "request_type": rt.value, + "model": model, + "alpha": cell.alpha, + "beta": cell.beta, + "samples": total, + "quality_mean": cell.alpha / total if total > 0 else 0.0, + } + ) + queue = await self.queue.queue_size() + now = time.time() + owner_cache_live = sum(1 for _, exp in self._owner_cache.values() if exp > now) + return { + "router_name": self.router_name, + "available_models": list(self.config.available_models), + "weights": { + "quality": self.config.weights.quality, + "cost": self.config.weights.cost, + }, + "model_costs": dict(self.model_to_cost), + "cells": cells, + "owner_cache_live": owner_cache_live, + "skipped_updates_total": self._skipped_updates_total, + "queue": queue, + } + + def _eligible_models(self, min_quality_tier: Optional[int]) -> List[str]: + if min_quality_tier is None: + return list(self.config.available_models) + return [ + m + for m in self.config.available_models + if (self.model_to_prefs.get(m) or _default_prefs()).quality_tier + >= min_quality_tier + ] + + # ---- Session state --------------------------------------------------- + + def get_or_create_session_state( + self, + session_id: str, + model_name: str, + request_type: RequestType, + ) -> SessionState: + key = (session_id, model_name) + state = self._session_states.get(key) + if state is None: + state = SessionState( + session_id=session_id, + router_name=self.router_name, + model_name=model_name, + classified_type=request_type.value, + ) + self._session_states[key] = state + return state + + async def record_turn( + self, + session_id: str, + model_name: str, + 
request_type: RequestType, + turn: Turn, + ) -> SignalDelta: + """Apply one turn to the session state, enqueue the session snapshot plus any non-zero bandit delta, and return the turn's SignalDelta.""" + state = self.get_or_create_session_state(session_id, model_name, request_type) + delta = apply_turn(state, turn) + verbose_router_logger.debug("AdaptiveRouter signal delta=%s", delta) + + snapshot = asdict(state) + await self.queue.add_session_state( + session_id, self.router_name, model_name, snapshot + ) + + d_alpha, d_beta = self._compute_bandit_delta(delta) + verbose_router_logger.debug("bandit d_alpha=%s d_beta=%s", d_alpha, d_beta) + if d_alpha != 0 or d_beta != 0: + cell_key = (request_type, model_name) + self._cells[cell_key] = apply_delta(self._cells[cell_key], d_alpha, d_beta) + await self.queue.add_state_delta( + self.router_name, + request_type.value, + model_name, + d_alpha, + d_beta, + ) + + return delta + + @staticmethod + def _compute_bandit_delta(delta: SignalDelta) -> Tuple[float, float]: + """ + Translate per-turn signal deltas into bandit-cell deltas. + + v0 mapping (UNVALIDATED — D6): + - satisfaction -> +1 alpha + - misalignment, stagnation, + disengagement, failure -> +1 beta each + - loop -> +0.5 beta (weak; could be model OR user) + - exhaustion -> 0 (uptime issue, tracked separately later) + """ + d_alpha = float(delta.satisfaction) + d_beta = ( + float( + delta.misalignment + + delta.stagnation + + delta.disengagement + + delta.failure + ) + + 0.5 * delta.loop + ) + return d_alpha, d_beta diff --git a/litellm/router_strategy/adaptive_router/bandit.py b/litellm/router_strategy/adaptive_router/bandit.py new file mode 100644 index 0000000000..cc473ac58e --- /dev/null +++ b/litellm/router_strategy/adaptive_router/bandit.py @@ -0,0 +1,136 @@ +""" +Thompson sampling and prior initialization for the adaptive router bandit. + +Each (router, request_type, model) cell is a Beta(alpha, beta) posterior.
@dataclass(frozen=True)
class BanditCell:
    """Immutable Beta(alpha, beta) posterior for one (router, request_type, model) cell."""

    alpha: float
    beta: float

    @property
    def mean(self) -> float:
        """Posterior mean alpha / (alpha + beta); 0.5 when the cell has no mass."""
        mass = self.alpha + self.beta
        if mass > 0:
            return self.alpha / mass
        return 0.5

    @property
    def total_samples(self) -> int:
        """Mass beyond the cold-start prior, truncated to int and floored at 0."""
        observed = int(self.alpha + self.beta - COLD_START_MASS)
        return observed if observed > 0 else 0


def initial_cell(
    prefs: AdaptiveRouterPreferences, request_type: RequestType
) -> BanditCell:
    """
    Build the cold-start prior for a (model, request_type) cell.

    The prior mean is the tier's base weight, plus STRENGTH_BONUS when the
    request type is one of the model's declared strengths, capped at 0.95.
    The prior's total mass (alpha + beta) is COLD_START_MASS.
    """
    prior_mean = BASE_TIER_WEIGHT[prefs.quality_tier] + (
        STRENGTH_BONUS if request_type in prefs.strengths else 0.0
    )
    prior_mean = min(0.95, prior_mean)
    return BanditCell(
        alpha=prior_mean * COLD_START_MASS,
        beta=(1.0 - prior_mean) * COLD_START_MASS,
    )
+ """ + new_alpha = cell.alpha + delta_alpha + new_beta = cell.beta + delta_beta + if new_alpha + new_beta > SAMPLE_CAP: + return cell + return BanditCell(alpha=new_alpha, beta=new_beta) + + +def thompson_sample(cell: BanditCell, rng: Optional[random.Random] = None) -> float: + """Draw a sample from Beta(alpha, beta). Returns a quality estimate in [0, 1].""" + r = rng if rng is not None else random + return r.betavariate(cell.alpha, cell.beta) + + +def normalized_cost(model_cost: float, all_costs: List[float]) -> float: + """ + Map a raw $/1k-token cost into [0, 1] where 0 = most expensive, 1 = cheapest. + Returns 0.5 when there's no spread. + """ + if not all_costs: + return 0.5 + lo, hi = min(all_costs), max(all_costs) + if hi == lo: + return 0.5 + return 1.0 - ((model_cost - lo) / (hi - lo)) + + +def score( + quality_sample: float, + model_cost: float, + all_costs: List[float], + quality_weight: float = DEFAULT_QUALITY_WEIGHT, + cost_weight: float = DEFAULT_COST_WEIGHT, +) -> float: + """ + Multi-objective score. V0 is a weighted linear sum of (quality, normalized_cost). + Higher is better. Both inputs are in [0, 1]. + """ + cost_score = normalized_cost(model_cost, all_costs) + return quality_weight * quality_sample + cost_weight * cost_score + + +def pick_best( + cells: Dict[str, BanditCell], + model_costs: Dict[str, float], + quality_weight: float = DEFAULT_QUALITY_WEIGHT, + cost_weight: float = DEFAULT_COST_WEIGHT, + rng: Optional[random.Random] = None, +) -> str: + """ + Sample once per model, score each, return the model with highest score. 
def pick_best(
    cells: Dict[str, BanditCell],
    model_costs: Dict[str, float],
    quality_weight: float = DEFAULT_QUALITY_WEIGHT,
    cost_weight: float = DEFAULT_COST_WEIGHT,
    rng: Optional[random.Random] = None,
) -> str:
    """
    Sample once per model, score each, return the model with highest score.

    cells: {model_name: BanditCell}
    model_costs: {model_name: $/1k tokens}

    Costs are normalized against every value in ``model_costs``. Ties keep
    the first model in ``cells`` iteration order.

    Raises:
        ValueError: if ``cells`` is empty, or if no model produced a score
            greater than -inf (e.g. every score is NaN).
        KeyError: if a model in ``cells`` has no entry in ``model_costs``.
    """
    if not cells:
        raise ValueError("pick_best called with no models")
    all_costs = list(model_costs.values())
    best_model: Optional[str] = None
    best_score = float("-inf")
    for model, cell in cells.items():
        q = thompson_sample(cell, rng=rng)
        s = score(q, model_costs[model], all_costs, quality_weight, cost_weight)
        if s > best_score:
            best_score = s
            best_model = model
    # Explicit check instead of `assert`: asserts are stripped under -O,
    # which would let this function return None despite its `-> str` contract.
    if best_model is None:
        raise ValueError("pick_best could not score any model")
    return best_model
+""" + +import re +from typing import List, Pattern, Tuple + +from litellm.types.router import RequestType + +_RULES: List[Tuple[Pattern[str], RequestType]] = [ + ( + re.compile( + r"\b(write|create|generate|implement|build)\s+(?:a |an |the |me )?(?:python|javascript|typescript|java|rust|go|c\+\+|sql|bash|shell)\b", + re.IGNORECASE, + ), + RequestType.CODE_GENERATION, + ), + ( + re.compile( + r"\b(write|create|implement|build)\b(?:\s+\w+){0,4}?\s+(function|class|method|script|program|api|endpoint|microservice)\b", + re.IGNORECASE, + ), + RequestType.CODE_GENERATION, + ), + ( + re.compile( + r"\b(explain|describe|understand|walk me through|what does)\b.*\b(code|function|method|class|algorithm|snippet)\b", + re.IGNORECASE, + ), + RequestType.CODE_UNDERSTANDING, + ), + ( + re.compile( + r"\b(debug|fix|why (?:is|does|isn't)|what.s wrong|trace)\b.*\b(error|bug|exception|stacktrace|stack trace|traceback)\b", + re.IGNORECASE, + ), + RequestType.CODE_UNDERSTANDING, + ), + ( + re.compile( + r"\b(review|critique)\s+(?:this |my |the )?(?:code|pr|pull request|diff|patch)\b", + re.IGNORECASE, + ), + RequestType.CODE_UNDERSTANDING, + ), + ( + re.compile( + r"\b(design|architect|plan|architecture)\b.*\b(system|service|api|database|schema|module|microservice)\b", + re.IGNORECASE, + ), + RequestType.TECHNICAL_DESIGN, + ), + ( + re.compile( + r"\b(should i (?:use|choose|pick)|tradeoffs? 
between|compare)\b.*\b(library|framework|language|database|protocol|postgres|postgresql|mongodb|dynamodb|mysql|redis|kafka|sql|nosql)\b", + re.IGNORECASE, + ), + RequestType.TECHNICAL_DESIGN, + ), + ( + re.compile( + r"\bhow (?:should|do) i (?:design|structure|organize|model)\b", + re.IGNORECASE, + ), + RequestType.TECHNICAL_DESIGN, + ), + ( + re.compile( + r"\b(solve|compute|calculate|prove|derive)\b.*\b(equation|integral|derivative|theorem|proof|problem)\b", + re.IGNORECASE, + ), + RequestType.ANALYTICAL_REASONING, + ), + ( + re.compile(r"\b(if .+ then|given .+ find|suppose|assume)\b", re.IGNORECASE), + RequestType.ANALYTICAL_REASONING, + ), + ( + re.compile( + r"\b(probability|statistics|combinatorics|optimization problem)\b", + re.IGNORECASE, + ), + RequestType.ANALYTICAL_REASONING, + ), + ( + re.compile( + r"\b(write|draft|compose|rewrite|edit|proofread|polish)\b.*\b(email|essay|blog|post|article|letter|memo|copy|paragraph|sentence)\b", + re.IGNORECASE, + ), + RequestType.WRITING, + ), + ( + re.compile( + r"\b(make (?:this|it)|help me)\s+(?:more |less )?(?:concise|formal|casual|professional|persuasive)\b", + re.IGNORECASE, + ), + RequestType.WRITING, + ), + ( + re.compile( + r"^\s*(who|what|when|where|which)\s+(?:is|was|were|are)\b", re.IGNORECASE + ), + RequestType.FACTUAL_LOOKUP, + ), + ( + re.compile(r"^\s*(define|definition of|meaning of)\b", re.IGNORECASE), + RequestType.FACTUAL_LOOKUP, + ), + ( + re.compile( + r"^\s*how (?:do you spell|to spell|many .* are there|tall is)\b", + re.IGNORECASE, + ), + RequestType.FACTUAL_LOOKUP, + ), +] + + +def classify_prompt(text: str) -> RequestType: + """ + Classify a single user prompt. + + Falls back to GENERAL when no rule matches. Empty/whitespace-only also + returns GENERAL. 
+ """ + if not text or not text.strip(): + return RequestType.GENERAL + + truncated = text[:2000] + + for pattern, request_type in _RULES: + if pattern.search(truncated): + return request_type + + return RequestType.GENERAL diff --git a/litellm/router_strategy/adaptive_router/config.py b/litellm/router_strategy/adaptive_router/config.py new file mode 100644 index 0000000000..b49d7cdf6d --- /dev/null +++ b/litellm/router_strategy/adaptive_router/config.py @@ -0,0 +1,55 @@ +""" +Configuration constants for the adaptive_router strategy. + +All magic numbers are first-pass guesses (D3-D6 in the handoff plan). +Expect to retune after first 1000 sessions of real traffic. +""" + +from typing import Dict + +from litellm.types.router import RequestType # re-export for convenience # noqa: F401 + +# D3 — Score weights (default; user-overridable via AdaptiveRouterConfig.weights) +DEFAULT_QUALITY_WEIGHT: float = 0.7 # UNVALIDATED — calibrated against [0] sessions +DEFAULT_COST_WEIGHT: float = 0.3 # UNVALIDATED — calibrated against [0] sessions + +# D4 — Cold-start prior: (alpha + beta) total mass = COLD_START_MASS +# Mean of Beta = base_tier_weight + (strength_bonus if declared) +BASE_TIER_WEIGHT: Dict[int, float] = {1: 0.3, 2: 0.5, 3: 0.7} # UNVALIDATED +STRENGTH_BONUS: float = 0.3 # UNVALIDATED +COLD_START_MASS: float = 10.0 + +# D5 — Sample cap. Hard cap, no rescaling (drift handling is v1). +SAMPLE_CAP: int = 200 + +# D6 — Clean-trace credit: minimum turns before α += 1 can fire. +MIN_TURNS_FOR_CLEAN_CREDIT: int = 3 + +# D2 — Owner-cache TTL (seconds). 24h. +# A conversation's first-picked model "owns" the bandit-update slot for +# this long. Subsequent turns of the same conversation only contribute a +# bandit/state update when the same model is re-sampled. +OWNER_CACHE_TTL_SECONDS: int = 24 * 3600 + +# Below this many messages we skip post-call signal recording. 
# ---- D3: score weights ---------------------------------------------------
# Defaults only; users can override them via AdaptiveRouterConfig.weights.
DEFAULT_QUALITY_WEIGHT: float = 0.7  # UNVALIDATED — calibrated against [0] sessions
DEFAULT_COST_WEIGHT: float = 0.3  # UNVALIDATED — calibrated against [0] sessions

# ---- D4: cold-start prior ------------------------------------------------
# Prior mean = BASE_TIER_WEIGHT[tier] (+ STRENGTH_BONUS when declared);
# prior total mass (alpha + beta) = COLD_START_MASS.
BASE_TIER_WEIGHT: Dict[int, float] = {
    1: 0.3,
    2: 0.5,
    3: 0.7,
}  # UNVALIDATED
STRENGTH_BONUS: float = 0.3  # UNVALIDATED
COLD_START_MASS: float = 10.0

# ---- D5: sample cap ------------------------------------------------------
# Hard cap, no rescaling (drift handling is v1).
SAMPLE_CAP: int = 200

# ---- D6: clean-trace credit ----------------------------------------------
# Minimum turns before α += 1 can fire.
MIN_TURNS_FOR_CLEAN_CREDIT: int = 3

# ---- D2: owner-cache TTL (seconds), 24h ----------------------------------
# A conversation's first-picked model "owns" the bandit-update slot for this
# long. Later turns of the same conversation only contribute a bandit/state
# update when the same model is re-sampled.
OWNER_CACHE_TTL_SECONDS: int = 24 * 60 * 60

# Requests with fewer messages than this skip post-call signal recording.
# Most signals (misalignment, stagnation, satisfaction-in-response-to-prior-
# turn) need at least one full prior exchange to be meaningful.
SIGNAL_GATE_MIN_MESSAGES: int = 4

# ---- Detector thresholds (from Plano/Chen 2026 paper) --------------------
MISALIGNMENT_JACCARD_THRESHOLD: float = 0.45
STAGNATION_JACCARD_NEAR_DUP: float = 0.50
STAGNATION_JACCARD_EXACT: float = 0.85
LOOP_REPEAT_THRESHOLD: int = 3
TOOL_CALL_HISTORY_MAX: int = 20

# ---- D1: caller filter for min quality tier ------------------------------
MIN_QUALITY_TIER_HEADER: str = "x-litellm-min-quality-tier"
MIN_QUALITY_TIER_METADATA_KEY: str = "min_quality_tier"

# ---- Pre-routing -> post-call relay --------------------------------------
# The pre-routing hook stashes the chosen logical model on
# request_kwargs["metadata"][ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY]; the post-call
# hook reads it back and surfaces it as the ADAPTIVE_ROUTER_RESPONSE_HEADER
# response header.
ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY: str = "adaptive_router_chosen_model"
ADAPTIVE_ROUTER_RESPONSE_HEADER: str = "x-litellm-adaptive-router-model"
+""" + +from __future__ import annotations + +import hashlib +import json +from typing import Any, Dict, List, Optional + +from litellm._logging import verbose_router_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter +from litellm.router_strategy.adaptive_router.classifier import classify_prompt +from litellm.router_strategy.adaptive_router.config import ( + ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY, + ADAPTIVE_ROUTER_RESPONSE_HEADER, + SIGNAL_GATE_MIN_MESSAGES, +) +from litellm.router_strategy.adaptive_router.signals import Turn + +# Identity fields hashed into a derived session key so the same conversation +# from the same caller produces a stable key, while different keys/teams/users +# stay segregated even if they happen to send identical first messages. +_IDENTITY_FIELDS = ( + "user_api_key_hash", + "user_api_key_team_id", + "user_api_key_user_id", + "user_api_key_end_user_id", +) + + +def _resolve_session_key(kwargs: Dict[str, Any]) -> Optional[str]: + """Pick a stable per-conversation key for owner-cache attribution. + + Order: + 1. Honor a client-supplied session id (`litellm_session_id` on either + `litellm_params` or `litellm_params.metadata`, or `session_id` on + metadata) — backward compat for callers already wired up. + 2. Otherwise derive a sha256 over (identity fields, first message) so + the key is stable across turns of the same conversation. + + Returns None if there are no messages (nothing to attribute). 
+ """ + litellm_params = kwargs.get("litellm_params") or {} + sid = litellm_params.get("litellm_session_id") + if sid: + return str(sid) + metadata = litellm_params.get("metadata") or {} + if isinstance(metadata, dict): + sid = metadata.get("session_id") or metadata.get("litellm_session_id") + if sid: + return str(sid) + + messages = kwargs.get("messages") or [] + if not messages: + return None + + identity = ":".join( + str(metadata.get(f) or "") if isinstance(metadata, dict) else "" + for f in _IDENTITY_FIELDS + ) + first = messages[0] + payload = ( + identity + + "|" + + json.dumps( + {"role": first.get("role"), "content": first.get("content")}, + sort_keys=True, + default=str, + ) + ) + return hashlib.sha256(payload.encode("utf-8")).hexdigest() + + +def _last_user_content(messages: Optional[List[Dict[str, Any]]]) -> Optional[str]: + if not messages: + return None + for msg in reversed(messages): + if msg.get("role") == "user": + content = msg.get("content") + if isinstance(content, str): + return content + if isinstance(content, list): + # OpenAI vision-style content: pick first text part. 
def _assistant_content_and_tool_calls(response_obj: Any) -> tuple:
    """Extract ``(assistant_text, tool_calls)`` from a ModelResponse-ish object.

    Attribute access is tried first, then dict access, for ``choices``,
    ``message``, ``content`` and ``tool_calls``. Only the first choice is
    inspected. Tool calls that are not dicts are converted with
    ``model_dump()``; if that fails a ``{"name": ..., "arguments": ""}``
    placeholder is used. Returns ``(None, [])`` whenever the response,
    choices or message is missing.
    """
    nothing = (None, [])
    if response_obj is None:
        return nothing

    try:
        choices = getattr(response_obj, "choices", None) or response_obj.get("choices")
    except Exception:
        return nothing
    if not choices:
        return nothing

    first_choice = choices[0]
    message = getattr(first_choice, "message", None) or (
        first_choice.get("message") if isinstance(first_choice, dict) else None
    )
    if message is None:
        return nothing

    message_is_dict = isinstance(message, dict)

    text = getattr(message, "content", None)
    if text is None and message_is_dict:
        text = message.get("content")

    raw_calls = getattr(message, "tool_calls", None)
    if raw_calls is None and message_is_dict:
        raw_calls = message.get("tool_calls")

    normalized: List[Dict[str, Any]] = []
    for call in raw_calls or []:
        if isinstance(call, dict):
            normalized.append(call)
            continue
        try:
            normalized.append(call.model_dump())
        except Exception:
            normalized.append({"name": getattr(call, "name", ""), "arguments": ""})
    return text, normalized
class AdaptiveRouterPostCallHook(CustomLogger):
    """One hook instance per AdaptiveRouter. Registered into litellm.callbacks."""

    def __init__(self, adaptive_router: AdaptiveRouter) -> None:
        self.adaptive_router = adaptive_router

    async def async_post_call_success_hook(
        self,
        data: Dict[str, Any],
        user_api_key_dict: Any,
        response: Any,
    ) -> None:
        """
        Surface the chosen logical model picked by the pre-routing hook as the
        `x-litellm-adaptive-router-model` response header.

        The chosen model is read from
        `data["metadata"][ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY]` and written into
        `response._hidden_params["additional_headers"]`. Nothing happens when
        no model was stashed or when the response has no dict
        `_hidden_params`.
        """
        metadata = data.get("metadata") or {}
        chosen = (
            metadata.get(ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY)
            if isinstance(metadata, dict)
            else None
        )
        if not chosen:
            return
        hidden_params = getattr(response, "_hidden_params", None)
        if not isinstance(hidden_params, dict):
            return
        hidden_params.setdefault("additional_headers", {})
        hidden_params["additional_headers"][ADAPTIVE_ROUTER_RESPONSE_HEADER] = chosen

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Record a successful completion as a turn with status 200."""
        await self._record(kwargs, response_obj, response_status=200)

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Record a failed completion, using the best available status code.

        The status comes from `kwargs["response_status"]`, else the
        exception's `status_code`, else 500. Anything that cannot be coerced
        to int also becomes 500 so this hook never raises.
        """
        status = kwargs.get("response_status")
        if status is None:
            exc = kwargs.get("exception")
            status = getattr(exc, "status_code", 500) if exc is not None else 500
        try:
            status_code = int(status)
        except (TypeError, ValueError):
            # Coercion happens outside `_record`'s try/except, so guard it
            # here: signal recording must never break a request.
            status_code = 500
        await self._record(kwargs, response_obj, response_status=status_code)

    async def _record(
        self,
        kwargs: Dict[str, Any],
        response_obj: Any,
        response_status: int,
    ) -> None:
        """Build a Turn from the request/response and hand it to the router.

        Skips silently when the conversation is shorter than
        SIGNAL_GATE_MIN_MESSAGES, when no session key or chosen model can be
        resolved, or when another model owns the conversation. Any exception
        is logged and swallowed.
        """
        try:
            messages = kwargs.get("messages") or []
            if len(messages) < SIGNAL_GATE_MIN_MESSAGES:
                # Too few turns for any signal to be meaningful — skip.
                return

            session_key = _resolve_session_key(kwargs)
            if not session_key:
                return

            # The bandit cells are keyed by the *logical* model name from
            # `available_models` (e.g. "smart"/"fast"). `kwargs["model"]` at
            # post-call time is the physical upstream model
            # (e.g. "anthropic/claude-opus-4-7"), so it cannot be used directly.
            # The pre-routing hook stashes the logical pick under this key.
            litellm_params = kwargs.get("litellm_params") or {}
            metadata = litellm_params.get("metadata") or {}
            current_model = (
                metadata.get(ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY)
                if isinstance(metadata, dict)
                else None
            )
            if not current_model:
                return

            if not self.adaptive_router.claim_or_check_owner(
                session_key, current_model
            ):
                # A different model owns this conversation — skip attribution.
                return

            user_text = _last_user_content(messages)
            assistant_text, tool_calls = _assistant_content_and_tool_calls(response_obj)

            # Classify the conversation by its FIRST user message so every
            # turn of a session credits the same (request_type, model) cell.
            # Classifying the latest message would send feedback turns such
            # as "thanks, that worked" to the GENERAL cell instead of the
            # cell for the task that was actually routed.
            first_user = next(
                (
                    m
                    for m in messages
                    if isinstance(m, dict) and m.get("role") == "user"
                ),
                None,
            )
            first_user_text = (
                _last_user_content([first_user]) if first_user is not None else None
            )
            request_type = classify_prompt(first_user_text or user_text or "")

            turn = Turn(
                user_content=user_text,
                assistant_content=(
                    assistant_text if isinstance(assistant_text, str) else None
                ),
                tool_calls=tool_calls,
                tool_results=[],
                response_status=response_status,
            )
            await self.adaptive_router.record_turn(
                session_id=session_key,
                model_name=current_model,
                request_type=request_type,
                turn=turn,
            )
        except Exception as e:
            verbose_router_logger.exception(
                "AdaptiveRouterPostCallHook: failed to record turn: %s", e
            )
+""" + +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Set + +from litellm.router_strategy.adaptive_router.config import ( + LOOP_REPEAT_THRESHOLD, + MISALIGNMENT_JACCARD_THRESHOLD, + STAGNATION_JACCARD_NEAR_DUP, + TOOL_CALL_HISTORY_MAX, +) + + +# ---- Public types --------------------------------------------------------- + + +@dataclass +class SignalDelta: + """Which signals fired on a single turn. Counts are 0 or 1 (one delta per turn).""" + + misalignment: int = 0 + stagnation: int = 0 + disengagement: int = 0 + satisfaction: int = 0 + failure: int = 0 + loop: int = 0 + exhaustion: int = 0 + + def any_fired(self) -> bool: + return any( + [ + self.misalignment, + self.stagnation, + self.disengagement, + self.satisfaction, + self.failure, + self.loop, + self.exhaustion, + ] + ) + + +@dataclass +class SessionState: + """In-memory rolling state for one session. + + Mirrors the LiteLLM_AdaptiveRouterSession DB row (Wave 0 schema). The flusher + later persists this. We keep this as a plain dataclass — no DB coupling. + """ + + session_id: str + router_name: str + model_name: str + classified_type: str + + misalignment_count: int = 0 + stagnation_count: int = 0 + disengagement_count: int = 0 + satisfaction_count: int = 0 + failure_count: int = 0 + loop_count: int = 0 + exhaustion_count: int = 0 + + last_user_content: Optional[str] = None + last_assistant_content: Optional[str] = None + tool_call_history: List[str] = field(default_factory=list) + pending_tool_calls: Dict[str, str] = field(default_factory=dict) + + turn_count: int = 0 + terminal_status: Optional[int] = None + + +@dataclass +class Turn: + """One turn of input. 
@dataclass
class Turn:
    """Input for one turn, assembled by the caller from the request/response."""

    user_content: Optional[str] = None
    assistant_content: Optional[str] = None
    tool_calls: List[Dict[str, Any]] = field(default_factory=list)
    tool_results: List[Dict[str, Any]] = field(default_factory=list)
    response_status: Optional[int] = None


# ---- Detection helpers ----------------------------------------------------

_TOKEN_RE = re.compile(r"[A-Za-z0-9]+")


def _tokens(text: Optional[str]) -> Set[str]:
    """Return the set of lower-cased ASCII alphanumeric tokens in ``text``."""
    if not text:
        return set()
    return set(map(str.lower, _TOKEN_RE.findall(text)))


def _jaccard(a: Set[str], b: Set[str]) -> float:
    """Jaccard similarity |a ∩ b| / |a ∪ b|; 0.0 when both sets are empty."""
    combined = a | b
    if combined:
        return len(a & b) / len(combined)
    return 0.0


# Phrases treated as the user giving up on the conversation.
_DISENGAGEMENT_PATTERNS = [
    re.compile(
        r"\b(forget it|never mind|give up|talk to (?:a )?human|cancel)\b", re.IGNORECASE
    ),
    re.compile(r"\b(this (?:isn'?t|is not) working|stop|abort)\b", re.IGNORECASE),
    re.compile(r"\bi'?ll do it (?:myself|manually)\b", re.IGNORECASE),
]

# Phrases treated as the user confirming the answer helped.
_SATISFACTION_PATTERNS = [
    re.compile(
        r"\b(that worked|that did it|works now|fixed it|solved it|nice)\b",
        re.IGNORECASE,
    ),
    re.compile(r"\b(thanks|thank you|thx|appreciated|appreciate it)\b", re.IGNORECASE),
    re.compile(r"\b(perfect|great|excellent|exactly)\b", re.IGNORECASE),
]
def _detect_misalignment(prev_user: Optional[str], curr_user: Optional[str]) -> bool:
    """Fires when consecutive user messages share *some* topic (jaccard > 0)
    but are sufficiently different (jaccard < threshold) — i.e. user is
    rephrasing, not changing topic, not repeating."""
    if not prev_user or not curr_user:
        return False
    j = _jaccard(_tokens(prev_user), _tokens(curr_user))
    return 0.0 < j < MISALIGNMENT_JACCARD_THRESHOLD


def _detect_stagnation(prev_asst: Optional[str], curr_asst: Optional[str]) -> bool:
    """Fires when consecutive assistant messages are near-duplicates
    (token jaccard >= STAGNATION_JACCARD_NEAR_DUP)."""
    if not prev_asst or not curr_asst:
        return False
    j = _jaccard(_tokens(prev_asst), _tokens(curr_asst))
    return j >= STAGNATION_JACCARD_NEAR_DUP


def _detect_disengagement(curr_user: Optional[str]) -> bool:
    """Fires when the user message matches any disengagement pattern."""
    if not curr_user:
        return False
    return any(p.search(curr_user) for p in _DISENGAGEMENT_PATTERNS)


def _detect_satisfaction(curr_user: Optional[str]) -> bool:
    """Fires when the user message matches any satisfaction pattern."""
    if not curr_user:
        return False
    return any(p.search(curr_user) for p in _SATISFACTION_PATTERNS)


def _detect_failure(tool_results: List[Dict[str, Any]]) -> bool:
    """Any tool result that's an error or empty content."""
    for r in tool_results:
        if r.get("is_error"):
            return True
        content = r.get("content")
        if content is None or content == "" or content == [] or content == {}:
            return True
    return False


def _signature(call: Dict[str, Any]) -> str:
    """Stable signature for loop detection: name + sorted JSON-ish args.

    Accepts both flat calls (``{"name": ..., "arguments": ...}``) and nested
    ones (``{"function": {"name": ..., "arguments": ...}}``). Dict arguments
    are rendered as ``k=v`` pairs sorted by key; anything else is formatted
    as-is. A ``function`` entry that is missing, None or not a dict is
    treated as empty instead of raising.
    """
    fn = call.get("function")
    if not isinstance(fn, dict):
        # e.g. {"function": None}: `.get("function", {})` would hand back
        # None here and the chained `.get` would raise AttributeError.
        fn = {}
    name = call.get("name") or fn.get("name", "")
    args = call.get("arguments")
    if args is None:
        args = fn.get("arguments", "")
    if isinstance(args, dict):
        args = ",".join(f"{k}={args[k]}" for k in sorted(args.keys()))
    return f"{name}({args})"


def _detect_loop(history: List[str], new_calls: List[Dict[str, Any]]) -> bool:
    """Fires if any new call's signature appears >= LOOP_REPEAT_THRESHOLD-1 times
    in recent history (so this call would be the Nth)."""
    if not new_calls:
        return False
    for call in new_calls:
        sig = _signature(call)
        recent_count = history.count(sig)
        if recent_count >= LOOP_REPEAT_THRESHOLD - 1:
            return True
    return False
# HTTP statuses counted as exhaustion.
_EXHAUSTION_STATUSES = {408, 413, 429, 503, 504}

# Lower-case substrings in a tool result's content that count as exhaustion.
_EXHAUSTION_KEYWORDS = (
    "context length",
    "context window",
    "token limit",
    "rate limit",
    "too many requests",
    "timeout",
)


def _detect_exhaustion(
    status: Optional[int], tool_results: List[Dict[str, Any]]
) -> bool:
    """Fires on an exhaustion status code or an exhaustion keyword in any
    tool result's stringified, lower-cased ``content``."""
    if status in _EXHAUSTION_STATUSES:
        return True

    def _mentions_exhaustion(result: Dict[str, Any]) -> bool:
        haystack = str(result.get("content", "")).lower()
        return any(keyword in haystack for keyword in _EXHAUSTION_KEYWORDS)

    return any(_mentions_exhaustion(result) for result in tool_results)
# ---- Public entrypoint ----------------------------------------------------


def apply_turn(state: SessionState, turn: Turn) -> SignalDelta:
    """
    Run every detector against ``turn``, fold the result into ``state`` and
    return which signals fired.

    Detectors only look at ``state.last_*``, the tool-call history (trimmed
    to TOOL_CALL_HISTORY_MAX entries) and the new turn, so no earlier turns
    are re-read. ``state`` is mutated in place: counters, last user/assistant
    content, tool-call history, terminal status and turn count.
    """
    fired = SignalDelta(
        misalignment=int(
            _detect_misalignment(state.last_user_content, turn.user_content)
        ),
        stagnation=int(
            _detect_stagnation(state.last_assistant_content, turn.assistant_content)
        ),
        disengagement=int(_detect_disengagement(turn.user_content)),
        satisfaction=int(_detect_satisfaction(turn.user_content)),
        failure=int(_detect_failure(turn.tool_results)),
        loop=int(_detect_loop(state.tool_call_history, turn.tool_calls)),
        exhaustion=int(_detect_exhaustion(turn.response_status, turn.tool_results)),
    )

    # Fold this turn's 0/1 flags into the cumulative session counters.
    for signal in (
        "misalignment",
        "stagnation",
        "disengagement",
        "satisfaction",
        "failure",
        "loop",
        "exhaustion",
    ):
        counter = f"{signal}_count"
        setattr(state, counter, getattr(state, counter) + getattr(fired, signal))

    # Empty/None content keeps the previous value so the next turn still has
    # something to compare against.
    if turn.user_content:
        state.last_user_content = turn.user_content
    if turn.assistant_content:
        state.last_assistant_content = turn.assistant_content

    state.tool_call_history.extend(_signature(call) for call in turn.tool_calls)
    if len(state.tool_call_history) > TOOL_CALL_HISTORY_MAX:
        state.tool_call_history = state.tool_call_history[-TOOL_CALL_HISTORY_MAX:]

    if turn.response_status is not None:
        state.terminal_status = turn.response_status

    state.turn_count += 1

    return fired
class RequestType(str, enum.Enum):
    """Fixed v0 taxonomy. User-extensible types come in v1."""

    CODE_GENERATION = "code_generation"
    CODE_UNDERSTANDING = "code_understanding"
    TECHNICAL_DESIGN = "technical_design"
    ANALYTICAL_REASONING = "analytical_reasoning"
    WRITING = "writing"
    FACTUAL_LOOKUP = "factual_lookup"
    GENERAL = "general"


class AdaptiveRouterWeights(BaseModel):
    """Quality/cost weights for the adaptive router score; must sum to 1.0."""

    quality: float = Field(default=0.7, ge=0.0, le=1.0)
    # validate_default=True makes the sum check below run even when `cost`
    # is omitted. Without it pydantic skips validators for defaulted fields,
    # so e.g. `AdaptiveRouterWeights(quality=0.9)` would be accepted with
    # weights summing to 1.2.
    cost: float = Field(default=0.3, ge=0.0, le=1.0, validate_default=True)

    @field_validator("cost")
    @classmethod
    def _weights_sum_to_one(cls, v, info):
        """Reject weight pairs whose sum differs from 1.0 by more than 0.001."""
        q = info.data.get("quality", 0.7)
        if abs(q + v - 1.0) > 0.001:
            raise ValueError(
                f"weights must sum to 1.0, got quality={q} + cost={v} = {q + v}"
            )
        return v


class AdaptiveRouterConfig(BaseModel):
    """Router-level config: the logical models to choose from and the score weights."""

    available_models: List[str]
    weights: AdaptiveRouterWeights = Field(default_factory=AdaptiveRouterWeights)


class AdaptiveRouterPreferences(BaseModel):
    """model_info.adaptive_router_preferences — declared by each model."""

    model_config = ConfigDict(use_enum_values=False)

    quality_tier: int = Field(ge=1, le=3)
    strengths: List[RequestType] = Field(default_factory=list)
+model LiteLLM_AdaptiveRouterSession { + session_id String + router_name String + model_name String + classified_type String + + misalignment_count Int @default(0) + stagnation_count Int @default(0) + disengagement_count Int @default(0) + satisfaction_count Int @default(0) + failure_count Int @default(0) + loop_count Int @default(0) + exhaustion_count Int @default(0) + + last_user_content String? + last_assistant_content String? + tool_call_history Json @default("[]") + pending_tool_calls Json @default("{}") + + turn_count Int @default(0) + last_processed_turn Int @default(-1) + clean_credit_awarded Boolean @default(false) + terminal_status Int? + last_activity_at DateTime @default(now()) + + @@id([session_id, router_name, model_name]) + @@index([last_activity_at]) +} diff --git a/scripts/verify_adaptive_router.py b/scripts/verify_adaptive_router.py new file mode 100644 index 0000000000..fde9dc51a1 --- /dev/null +++ b/scripts/verify_adaptive_router.py @@ -0,0 +1,216 @@ +""" +End-to-end verification script for the adaptive router. + +Requires: + - LiteLLM proxy running on http://localhost:4000 with adaptive_router configured + (see litellm/proxy/example_config_yaml/adaptive_router_example.yaml). + - Postgres reachable via DATABASE_URL (same one the proxy uses). + - LITELLM_PROXY_KEY env var set (a valid key with permission to send requests). 
+ - Two model deployments configured under one adaptive_router: + * "fast" (cheap, lower quality) + * "smart" (expensive, higher quality) + +Run: + uv run python scripts/verify_adaptive_router.py + +Optional env: + LITELLM_PROXY_URL (default: http://localhost:4000) + ADAPTIVE_ROUTER_NAME (default: smart-cheap-router) + EXPECTED_WINNER (default: smart) -- model expected to dominate after training + TRAIN_SESSIONS (default: 20) -- training sessions in phase 1 + CONVERGE_SESSIONS (default: 10) -- cold sessions in phase 2 + WIN_THRESHOLD (default: 0.7) -- min share for EXPECTED_WINNER in phase 2 +""" + +from __future__ import annotations + +import asyncio +import os +import sys +import time +import uuid +from typing import List, Optional + +import httpx + +PROXY_URL: str = os.environ.get("LITELLM_PROXY_URL", "http://localhost:4000") +try: + PROXY_KEY: str = os.environ["LITELLM_PROXY_KEY"] +except KeyError: + print( + "ERROR: LITELLM_PROXY_KEY env var must be set (a proxy key with /chat/completions perms).", + file=sys.stderr, + ) + sys.exit(2) + +ROUTER_NAME: str = os.environ.get("ADAPTIVE_ROUTER_NAME", "smart-cheap-router") +EXPECTED_WINNER: str = os.environ.get("EXPECTED_WINNER", "smart") +TRAIN_SESSIONS: int = int(os.environ.get("TRAIN_SESSIONS", "20")) +CONVERGE_SESSIONS: int = int(os.environ.get("CONVERGE_SESSIONS", "10")) +WIN_THRESHOLD: float = float(os.environ.get("WIN_THRESHOLD", "0.7")) + +REQUEST_TIMEOUT_SECONDS: float = 30.0 +RETRY_ATTEMPTS: int = 3 +RETRY_BACKOFF_SECONDS: float = 1.0 +FLUSHER_DRAIN_WAIT_SECONDS: float = 30.0 # proxy flusher loop is 10s; pad with margin + +PROMPTS: List[str] = [ + "Write a Python function that reverses a binary tree", + "Explain the time complexity of quicksort", + "Design an API for a chat application", +] +SATISFACTION_PROMPT: str = "thanks, that worked!" + + +async def _post_chat( + client: httpx.AsyncClient, session_id: str, prompt: str +) -> Optional[dict]: + """POST a chat completion with retry + timeout. 
Returns response JSON or None.""" + body = { + "model": ROUTER_NAME, + "messages": [{"role": "user", "content": prompt}], + "metadata": {"litellm_session_id": session_id}, + } + last_exc: Optional[Exception] = None + for attempt in range(1, RETRY_ATTEMPTS + 1): + try: + r = await client.post( + f"{PROXY_URL}/v1/chat/completions", + json=body, + headers={"Authorization": f"Bearer {PROXY_KEY}"}, + timeout=REQUEST_TIMEOUT_SECONDS, + ) + r.raise_for_status() + return r.json() + except Exception as e: # noqa: BLE001 + last_exc = e + if attempt < RETRY_ATTEMPTS: + await asyncio.sleep(RETRY_BACKOFF_SECONDS * attempt) + print( + f" request failed after {RETRY_ATTEMPTS} attempts (session={session_id}): {last_exc}", + file=sys.stderr, + ) + return None + + +async def send_session( + client: httpx.AsyncClient, + session_id: str, + prompts: List[str], + satisfy: bool = True, +) -> Optional[str]: + """Send a session of N turns. Returns the model that handled the last turn.""" + last_model: Optional[str] = None + for prompt in prompts: + resp = await _post_chat(client, session_id, prompt) + if resp is None: + return None + last_model = resp.get("model") or last_model + if satisfy: + await _post_chat(client, session_id, SATISFACTION_PROMPT) + return last_model + + +async def _proxy_health_check(client: httpx.AsyncClient) -> bool: + """Confirm the proxy is reachable before doing anything else.""" + try: + r = await client.get(f"{PROXY_URL}/health/liveliness", timeout=5.0) + return r.status_code == 200 + except Exception as e: # noqa: BLE001 + print(f"proxy unreachable at {PROXY_URL}: {e}", file=sys.stderr) + return False + + +async def main() -> None: + print("=== verify_adaptive_router.py ===") + print(f"proxy: {PROXY_URL}") + print(f"router: {ROUTER_NAME}") + print(f"expected winner: {EXPECTED_WINNER}") + print(f"train sessions: {TRAIN_SESSIONS}") + print(f"converge runs: {CONVERGE_SESSIONS}\n") + + async with httpx.AsyncClient() as client: + if not await 
_proxy_health_check(client): + print("FAIL: proxy health check did not return 200.", file=sys.stderr) + sys.exit(1) + + # ---- Phase 1: training ------------------------------------------- + print( + f"Phase 1: training ({TRAIN_SESSIONS} sessions of 3 turns + satisfaction)..." + ) + for i in range(TRAIN_SESSIONS): + sid = f"verify-train-{uuid.uuid4()}" + await send_session(client, sid, PROMPTS, satisfy=True) + if (i + 1) % 5 == 0: + print(f" trained {i + 1}/{TRAIN_SESSIONS} sessions") + + print( + f"\nWaiting {FLUSHER_DRAIN_WAIT_SECONDS:.0f}s for flusher to drain queue..." + ) + await asyncio.sleep(FLUSHER_DRAIN_WAIT_SECONDS) + + # ---- Phase 2: convergence ---------------------------------------- + print(f"\nPhase 2: convergence test ({CONVERGE_SESSIONS} cold sessions)...") + picks: List[str] = [] + for i in range(CONVERGE_SESSIONS): + sid = f"verify-test-{uuid.uuid4()}" + m = await send_session(client, sid, [PROMPTS[0]], satisfy=False) + if m: + picks.append(m) + print(f" session {i + 1}: picked {m}") + + if not picks: + print("\nFAIL: no successful picks in convergence phase.", file=sys.stderr) + sys.exit(1) + winner_share = picks.count(EXPECTED_WINNER) / len(picks) + print( + f"\n{EXPECTED_WINNER} share: {winner_share:.0%} " + f"({picks.count(EXPECTED_WINNER)}/{len(picks)})" + ) + + # ---- Phase 3: sticky session ------------------------------------- + print("\nPhase 3: sticky session test...") + sid = f"verify-sticky-{uuid.uuid4()}" + models: List[str] = [] + for _ in range(3): + m = await send_session(client, sid, [PROMPTS[0]], satisfy=False) + if m: + models.append(m) + if len(models) == 3 and len(set(models)) == 1: + print(f" PASS: same model {models[0]} across 3 turns of session {sid}") + else: + print( + f" FAIL: expected one model across 3 turns, got: {models}", + file=sys.stderr, + ) + sys.exit(1) + + # ---- Phase 4: latency benchmark ---------------------------------- + print("\nPhase 4: routing latency (5 picks, p50)...") + latencies: List[float] = [] + for _ 
in range(5): + t0 = time.perf_counter() + await send_session( + client, f"verify-lat-{uuid.uuid4()}", [PROMPTS[0]], satisfy=False + ) + latencies.append(time.perf_counter() - t0) + latencies.sort() + p50 = latencies[len(latencies) // 2] + print(f" p50 e2e roundtrip: {p50 * 1000:.0f}ms") + + # ---- Verdict ----------------------------------------------------- + if winner_share >= WIN_THRESHOLD: + print( + f"\nPASS: convergence ({winner_share:.0%} >= {WIN_THRESHOLD:.0%}) + " + f"sticky + latency checks all green." + ) + sys.exit(0) + print( + f"\nFAIL: convergence too weak ({winner_share:.0%} < {WIN_THRESHOLD:.0%}).", + file=sys.stderr, + ) + sys.exit(1) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests/test_litellm/proxy/db/db_transaction_queue/test_adaptive_router_update_queue.py b/tests/test_litellm/proxy/db/db_transaction_queue/test_adaptive_router_update_queue.py new file mode 100644 index 0000000000..6ac8e84337 --- /dev/null +++ b/tests/test_litellm/proxy/db/db_transaction_queue/test_adaptive_router_update_queue.py @@ -0,0 +1,117 @@ +import asyncio +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from litellm.proxy.db.db_transaction_queue.adaptive_router_update_queue import ( + AdaptiveRouterUpdateQueue, +) + + +@pytest.fixture +def queue(): + return AdaptiveRouterUpdateQueue() + + +@pytest.fixture +def mock_prisma(): + """Prisma client with both adaptive router models stubbed as AsyncMocks.""" + p = MagicMock() + p.db.litellm_adaptiverouterstate.find_unique = AsyncMock(return_value=None) + p.db.litellm_adaptiverouterstate.upsert = AsyncMock() + p.db.litellm_adaptiveroutersession.upsert = AsyncMock() + return p + + +@pytest.mark.asyncio +async def test_add_state_delta_aggregates_same_key(queue): + await queue.add_state_delta("r1", "general", "gpt-4", 1.0, 0.0) + await queue.add_state_delta("r1", "general", "gpt-4", 0.0, 1.0) + sizes = await queue.queue_size() + assert sizes["state_pending"] == 1 + + 
+@pytest.mark.asyncio +async def test_add_state_delta_separate_keys(queue): + await queue.add_state_delta("r1", "general", "gpt-4", 1.0, 0.0) + await queue.add_state_delta("r1", "writing", "gpt-4", 1.0, 0.0) + sizes = await queue.queue_size() + assert sizes["state_pending"] == 2 + + +@pytest.mark.asyncio +async def test_add_session_state_last_write_wins(queue): + await queue.add_session_state("s1", "r1", "gpt-4", {"misalignment_count": 1}) + await queue.add_session_state("s1", "r1", "gpt-4", {"misalignment_count": 5}) + sizes = await queue.queue_size() + assert sizes["session_pending"] == 1 + + flushed = [] + p = MagicMock() + + async def upsert(**kwargs): + flushed.append(kwargs) + + p.db.litellm_adaptiveroutersession.upsert = upsert + await queue.flush_session_to_db(p) + assert len(flushed) == 1 + assert flushed[0]["data"]["update"]["misalignment_count"] == 5 + + +@pytest.mark.asyncio +async def test_flush_state_drains_aggregator(queue, mock_prisma): + await queue.add_state_delta("r1", "general", "gpt-4", 1.0, 0.0) + await queue.add_state_delta("r1", "writing", "gpt-4", 0.0, 1.0) + n = await queue.flush_state_to_db(mock_prisma) + assert n == 2 + sizes = await queue.queue_size() + assert sizes["state_pending"] == 0 + + +@pytest.mark.asyncio +async def test_flush_state_sums_correctly(queue, mock_prisma): + await queue.add_state_delta("r1", "general", "gpt-4", 1.0, 0.0) + await queue.add_state_delta("r1", "general", "gpt-4", 2.0, 1.0) + await queue.flush_state_to_db(mock_prisma) + # find_unique returned None (cold start), so alpha = 1+2 = 3, beta = 0+1 = 1 + call = mock_prisma.db.litellm_adaptiverouterstate.upsert.call_args + assert call.kwargs["data"]["create"]["alpha"] == 3.0 + assert call.kwargs["data"]["create"]["beta"] == 1.0 + assert call.kwargs["data"]["create"]["total_samples"] == 2 + + +@pytest.mark.asyncio +async def test_flush_session_drains_aggregator(queue, mock_prisma): + await queue.add_session_state("s1", "r1", "gpt-4", {"classified_type": 
"general"}) + n = await queue.flush_session_to_db(mock_prisma) + assert n == 1 + sizes = await queue.queue_size() + assert sizes["session_pending"] == 0 + + +@pytest.mark.asyncio +async def test_flush_empty_queue_returns_zero(queue, mock_prisma): + assert await queue.flush_state_to_db(mock_prisma) == 0 + assert await queue.flush_session_to_db(mock_prisma) == 0 + + +@pytest.mark.asyncio +async def test_flush_state_isolation_from_concurrent_adds(queue, mock_prisma): + """Adds during a flush should land in the NEW aggregator, not the drained batch.""" + await queue.add_state_delta("r1", "general", "gpt-4", 1.0, 0.0) + flush_task = asyncio.create_task(queue.flush_state_to_db(mock_prisma)) + # Yield control so the flush task can swap the aggregator before we add again. + await asyncio.sleep(0) + await queue.add_state_delta("r1", "general", "gpt-5", 2.0, 0.0) + await flush_task + sizes = await queue.queue_size() + assert sizes["state_pending"] == 1 + + +@pytest.mark.asyncio +async def test_max_size_observability(queue): + await queue.add_state_delta("r1", "general", "gpt-4", 1.0, 0.0) + await queue.add_state_delta("r1", "writing", "gpt-4", 1.0, 0.0) + await queue.add_state_delta("r1", "code_generation", "gpt-4", 1.0, 0.0) + sizes = await queue.queue_size() + assert sizes["max_state_seen"] >= 3 diff --git a/tests/test_litellm/router_strategy/adaptive_router/__init__.py b/tests/test_litellm/router_strategy/adaptive_router/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_litellm/router_strategy/adaptive_router/fixtures/clean_no_signals.json b/tests/test_litellm/router_strategy/adaptive_router/fixtures/clean_no_signals.json new file mode 100644 index 0000000000..e53cc50b4b --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/fixtures/clean_no_signals.json @@ -0,0 +1,16 @@ +[ + { + "user_content": "what is the weather today in paris france", + "assistant_content": "It is sunny and warm in Paris today.", + "tool_calls": 
[], + "tool_results": [], + "response_status": 200 + }, + { + "user_content": "what is the weather today in paris france tomorrow", + "assistant_content": "Light rain is expected throughout the day.", + "tool_calls": [], + "tool_results": [], + "response_status": 200 + } +] diff --git a/tests/test_litellm/router_strategy/adaptive_router/fixtures/clean_satisfaction.json b/tests/test_litellm/router_strategy/adaptive_router/fixtures/clean_satisfaction.json new file mode 100644 index 0000000000..6f9e81c9b0 --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/fixtures/clean_satisfaction.json @@ -0,0 +1,23 @@ +[ + { + "user_content": "how do I read a file in python", + "assistant_content": "Use the open() function with a context manager.", + "tool_calls": [], + "tool_results": [], + "response_status": 200 + }, + { + "user_content": "can you show an example", + "assistant_content": "with open('file.txt') as f: data = f.read()", + "tool_calls": [], + "tool_results": [], + "response_status": 200 + }, + { + "user_content": "thanks, that worked!", + "assistant_content": "Glad to hear it.", + "tool_calls": [], + "tool_results": [], + "response_status": 200 + } +] diff --git a/tests/test_litellm/router_strategy/adaptive_router/fixtures/disengagement_giveup.json b/tests/test_litellm/router_strategy/adaptive_router/fixtures/disengagement_giveup.json new file mode 100644 index 0000000000..d17a1cfe3c --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/fixtures/disengagement_giveup.json @@ -0,0 +1,16 @@ +[ + { + "user_content": "how do I install this package", + "assistant_content": "Run pip install .", + "tool_calls": [], + "tool_results": [], + "response_status": 200 + }, + { + "user_content": "forget it, I'll do it myself", + "assistant_content": "Okay, let me know if you need anything else.", + "tool_calls": [], + "tool_results": [], + "response_status": 200 + } +] diff --git 
a/tests/test_litellm/router_strategy/adaptive_router/fixtures/exhaustion_429.json b/tests/test_litellm/router_strategy/adaptive_router/fixtures/exhaustion_429.json new file mode 100644 index 0000000000..064bf21e1a --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/fixtures/exhaustion_429.json @@ -0,0 +1,9 @@ +[ + { + "user_content": "do the thing", + "assistant_content": null, + "tool_calls": [], + "tool_results": [], + "response_status": 429 + } +] diff --git a/tests/test_litellm/router_strategy/adaptive_router/fixtures/exhaustion_context_overflow.json b/tests/test_litellm/router_strategy/adaptive_router/fixtures/exhaustion_context_overflow.json new file mode 100644 index 0000000000..e3e55ac5e7 --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/fixtures/exhaustion_context_overflow.json @@ -0,0 +1,13 @@ +[ + { + "user_content": "summarize this giant document", + "assistant_content": null, + "tool_calls": [ + {"id": "c1", "name": "summarize", "arguments": {"doc_id": "big"}} + ], + "tool_results": [ + {"tool_call_id": "c1", "content": "Error: context length exceeded for this model"} + ], + "response_status": 200 + } +] diff --git a/tests/test_litellm/router_strategy/adaptive_router/fixtures/failure_tool_error.json b/tests/test_litellm/router_strategy/adaptive_router/fixtures/failure_tool_error.json new file mode 100644 index 0000000000..28c55850f8 --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/fixtures/failure_tool_error.json @@ -0,0 +1,13 @@ +[ + { + "user_content": "read the config file", + "assistant_content": "Let me try.", + "tool_calls": [ + {"id": "call_1", "name": "read_file", "arguments": {"path": "/etc/missing.conf"}} + ], + "tool_results": [ + {"tool_call_id": "call_1", "content": "ENOENT: no such file or directory", "is_error": true} + ], + "response_status": 200 + } +] diff --git a/tests/test_litellm/router_strategy/adaptive_router/fixtures/loop_same_tool.json 
b/tests/test_litellm/router_strategy/adaptive_router/fixtures/loop_same_tool.json new file mode 100644 index 0000000000..705f6a5a08 --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/fixtures/loop_same_tool.json @@ -0,0 +1,35 @@ +[ + { + "user_content": null, + "assistant_content": null, + "tool_calls": [ + {"id": "c1", "name": "read_file", "arguments": {"path": "/x"}} + ], + "tool_results": [ + {"tool_call_id": "c1", "content": "ok"} + ], + "response_status": 200 + }, + { + "user_content": null, + "assistant_content": null, + "tool_calls": [ + {"id": "c2", "name": "read_file", "arguments": {"path": "/x"}} + ], + "tool_results": [ + {"tool_call_id": "c2", "content": "ok"} + ], + "response_status": 200 + }, + { + "user_content": null, + "assistant_content": null, + "tool_calls": [ + {"id": "c3", "name": "read_file", "arguments": {"path": "/x"}} + ], + "tool_results": [ + {"tool_call_id": "c3", "content": "ok"} + ], + "response_status": 200 + } +] diff --git a/tests/test_litellm/router_strategy/adaptive_router/fixtures/misalignment_rephrase.json b/tests/test_litellm/router_strategy/adaptive_router/fixtures/misalignment_rephrase.json new file mode 100644 index 0000000000..37d0992155 --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/fixtures/misalignment_rephrase.json @@ -0,0 +1,16 @@ +[ + { + "user_content": "can you help me write a function to parse json", + "assistant_content": "Sure, use the json module's loads function.", + "tool_calls": [], + "tool_results": [], + "response_status": 200 + }, + { + "user_content": "actually I need to parse yaml instead", + "assistant_content": "Use the pyyaml library and yaml.safe_load.", + "tool_calls": [], + "tool_results": [], + "response_status": 200 + } +] diff --git a/tests/test_litellm/router_strategy/adaptive_router/fixtures/mixed_failure_then_satisfaction.json b/tests/test_litellm/router_strategy/adaptive_router/fixtures/mixed_failure_then_satisfaction.json new file mode 100644 
index 0000000000..6d68dd6fd0 --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/fixtures/mixed_failure_then_satisfaction.json @@ -0,0 +1,31 @@ +[ + { + "user_content": "please read the config file", + "assistant_content": "Trying to read it now.", + "tool_calls": [ + {"id": "c1", "name": "read_file", "arguments": {"path": "config.json"}} + ], + "tool_results": [ + {"tool_call_id": "c1", "content": "file not found", "is_error": true} + ], + "response_status": 200 + }, + { + "user_content": "try config.yaml instead", + "assistant_content": "Here are the contents of config.yaml.", + "tool_calls": [ + {"id": "c2", "name": "read_file", "arguments": {"path": "config.yaml"}} + ], + "tool_results": [ + {"tool_call_id": "c2", "content": "key: value"} + ], + "response_status": 200 + }, + { + "user_content": "perfect, thanks!", + "assistant_content": "You're welcome.", + "tool_calls": [], + "tool_results": [], + "response_status": 200 + } +] diff --git a/tests/test_litellm/router_strategy/adaptive_router/fixtures/stagnation_repeat.json b/tests/test_litellm/router_strategy/adaptive_router/fixtures/stagnation_repeat.json new file mode 100644 index 0000000000..1256c3c197 --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/fixtures/stagnation_repeat.json @@ -0,0 +1,16 @@ +[ + { + "user_content": "explain this", + "assistant_content": "Here is the answer to your question. 
The capital of France is Paris.", + "tool_calls": [], + "tool_results": [], + "response_status": 200 + }, + { + "user_content": "explain this", + "assistant_content": "The answer to your question is that the capital of France is Paris.", + "tool_calls": [], + "tool_results": [], + "response_status": 200 + } +] diff --git a/tests/test_litellm/router_strategy/adaptive_router/test_adaptive_router.py b/tests/test_litellm/router_strategy/adaptive_router/test_adaptive_router.py new file mode 100644 index 0000000000..49069e22fd --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/test_adaptive_router.py @@ -0,0 +1,224 @@ +"""Unit tests for the AdaptiveRouter strategy class.""" + +from unittest.mock import AsyncMock, MagicMock + +from litellm.router_strategy.adaptive_router import adaptive_router as ar_module + +import pytest + +from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter +from litellm.router_strategy.adaptive_router.config import ( + OWNER_CACHE_TTL_SECONDS, +) +from litellm.router_strategy.adaptive_router.signals import Turn +from litellm.types.router import ( + AdaptiveRouterConfig, + AdaptiveRouterPreferences, + RequestType, +) + + +def _make_router() -> AdaptiveRouter: + cfg = AdaptiveRouterConfig(available_models=["fast", "smart"]) + prefs = { + "fast": AdaptiveRouterPreferences(quality_tier=1, strengths=[]), + "smart": AdaptiveRouterPreferences( + quality_tier=3, strengths=[RequestType.CODE_GENERATION] + ), + } + costs = {"fast": 0.0001, "smart": 0.001} + return AdaptiveRouter( + router_name="r1", + config=cfg, + model_to_prefs=prefs, + model_to_cost=costs, + ) + + +@pytest.mark.asyncio +async def test_pick_model_returns_model_from_available_list(): + r = _make_router() + chosen = await r.pick_model(RequestType.GENERAL) + assert chosen in {"fast", "smart"} + + +@pytest.mark.asyncio +async def test_pick_model_min_quality_tier_filter(): + r = _make_router() + # min_tier=3 should leave only `smart` (tier 3); 
`fast` (tier 1) is filtered. + for _ in range(20): + chosen = await r.pick_model(RequestType.GENERAL, min_quality_tier=3) + assert chosen == "smart" + + +@pytest.mark.asyncio +async def test_pick_model_min_quality_tier_filter_raises_when_no_eligible(): + r = _make_router() + with pytest.raises(ValueError, match="min_quality_tier=4"): + await r.pick_model(RequestType.GENERAL, min_quality_tier=4) + + +@pytest.mark.asyncio +async def test_pick_model_is_stateless_no_owner_cache_writes(): + """pick_model must not touch the owner cache — that's gated post-call.""" + r = _make_router() + for _ in range(5): + await r.pick_model(RequestType.GENERAL) + assert r._owner_cache == {} + + +# ---- claim_or_check_owner ----------------------------------------------- + + +def test_claim_or_check_owner_first_call_claims_and_returns_true(monkeypatch): + r = _make_router() + monkeypatch.setattr(ar_module.time, "time", lambda: 1_000.0) + + assert r.claim_or_check_owner("sess-A", "fast") is True + assert r._owner_cache["sess-A"] == ("fast", 1_000.0 + OWNER_CACHE_TTL_SECONDS) + assert r._skipped_updates_total == 0 + + +def test_claim_or_check_owner_same_model_returns_true_without_extending_ttl( + monkeypatch, +): + r = _make_router() + monkeypatch.setattr(ar_module.time, "time", lambda: 1_000.0) + r.claim_or_check_owner("sess-A", "fast") + original_expiry = r._owner_cache["sess-A"][1] + + monkeypatch.setattr(ar_module.time, "time", lambda: 1_500.0) + assert r.claim_or_check_owner("sess-A", "fast") is True + # No extension on hit — owner cache snapshots the first claim. + assert r._owner_cache["sess-A"][1] == original_expiry + + +def test_claim_or_check_owner_mismatch_skips_and_increments_counter(monkeypatch): + r = _make_router() + monkeypatch.setattr(ar_module.time, "time", lambda: 1_000.0) + r.claim_or_check_owner("sess-A", "fast") + + assert r.claim_or_check_owner("sess-A", "smart") is False + assert r._skipped_updates_total == 1 + # Owner unchanged. 
+ assert r._owner_cache["sess-A"][0] == "fast" + + +def test_claim_or_check_owner_expired_owner_reclaims_for_new_model(monkeypatch): + r = _make_router() + monkeypatch.setattr(ar_module.time, "time", lambda: 1_000.0) + r.claim_or_check_owner("sess-A", "fast") + + monkeypatch.setattr( + ar_module.time, "time", lambda: 1_000.0 + OWNER_CACHE_TTL_SECONDS + 1 + ) + assert r.claim_or_check_owner("sess-A", "smart") is True + assert r._owner_cache["sess-A"][0] == "smart" + # Reclaim isn't a skip. + assert r._skipped_updates_total == 0 + + +# ---- record_turn -------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_record_turn_pushes_to_queue(): + r = _make_router() + r.queue.add_session_state = AsyncMock() + r.queue.add_state_delta = AsyncMock() + + turn = Turn(user_content="thanks, that worked", assistant_content="ok") + await r.record_turn( + session_id="s1", + model_name="fast", + request_type=RequestType.GENERAL, + turn=turn, + ) + + r.queue.add_session_state.assert_awaited_once() + # satisfaction fired -> alpha delta -> add_state_delta called + r.queue.add_state_delta.assert_awaited_once() + + +@pytest.mark.asyncio +async def test_record_turn_satisfaction_increments_alpha(): + r = _make_router() + cell_before = r._cells[(RequestType.GENERAL, "fast")] + turn = Turn(user_content="that worked, thanks!") + await r.record_turn( + session_id="sX", + model_name="fast", + request_type=RequestType.GENERAL, + turn=turn, + ) + cell_after = r._cells[(RequestType.GENERAL, "fast")] + assert cell_after.alpha == pytest.approx(cell_before.alpha + 1.0) + assert cell_after.beta == pytest.approx(cell_before.beta) + + +@pytest.mark.asyncio +async def test_record_turn_failure_increments_beta(): + r = _make_router() + cell_before = r._cells[(RequestType.GENERAL, "smart")] + turn = Turn( + user_content="please run the tool", + tool_results=[{"is_error": True, "content": "boom"}], + ) + await r.record_turn( + session_id="sY", + model_name="smart", + 
request_type=RequestType.GENERAL, + turn=turn, + ) + cell_after = r._cells[(RequestType.GENERAL, "smart")] + assert cell_after.beta == pytest.approx(cell_before.beta + 1.0) + assert cell_after.alpha == pytest.approx(cell_before.alpha) + + +@pytest.mark.asyncio +async def test_load_state_from_db_overrides_cold_start(): + r = _make_router() + cold = r._cells[(RequestType.GENERAL, "fast")] + + fake_row = MagicMock() + fake_row.request_type = "general" + fake_row.model_name = "fast" + fake_row.alpha = 42.0 + fake_row.beta = 13.0 + + prisma = MagicMock() + prisma.db.litellm_adaptiverouterstate.find_many = AsyncMock(return_value=[fake_row]) + await r.load_state_from_db(prisma) + + new_cell = r._cells[(RequestType.GENERAL, "fast")] + assert (new_cell.alpha, new_cell.beta) == (42.0, 13.0) + assert (new_cell.alpha, new_cell.beta) != (cold.alpha, cold.beta) + + +@pytest.mark.asyncio +async def test_load_state_from_db_handles_unknown_request_type(): + r = _make_router() + cold = r._cells[(RequestType.GENERAL, "fast")] + + bad_row = MagicMock() + bad_row.request_type = "nonexistent_type_v999" + bad_row.model_name = "fast" + bad_row.alpha = 999.0 + bad_row.beta = 999.0 + + good_row = MagicMock() + good_row.request_type = "general" + good_row.model_name = "fast" + good_row.alpha = 7.0 + good_row.beta = 3.0 + + prisma = MagicMock() + prisma.db.litellm_adaptiverouterstate.find_many = AsyncMock( + return_value=[bad_row, good_row] + ) + await r.load_state_from_db(prisma) + + # Unknown skipped; good applied. + assert r._cells[(RequestType.GENERAL, "fast")].alpha == 7.0 + # Other request types kept their cold-start values. 
+ assert r._cells[(RequestType.WRITING, "fast")].alpha == pytest.approx(cold.alpha) diff --git a/tests/test_litellm/router_strategy/adaptive_router/test_async_pre_routing.py b/tests/test_litellm/router_strategy/adaptive_router/test_async_pre_routing.py new file mode 100644 index 0000000000..313e20db41 --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/test_async_pre_routing.py @@ -0,0 +1,137 @@ +"""Direct unit tests for AdaptiveRouter.async_pre_routing_hook. + +The strategy method (newly extracted from `Router.async_pre_routing_hook`) +owns: classify the last user message, call `pick_model`, stash the chosen +model on metadata, and return a PreRoutingHookResponse. + +Routing is stateless per-turn — `pick_model` does not take a session id. +""" + +from unittest.mock import AsyncMock + +import pytest + +from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter +from litellm.types.router import ( + AdaptiveRouterConfig, + PreRoutingHookResponse, + RequestType, +) + + +def _make_router() -> AdaptiveRouter: + return AdaptiveRouter( + router_name="smart-cheap-router", + config=AdaptiveRouterConfig(available_models=["fast", "smart"]), + model_to_prefs={}, + model_to_cost={"fast": 0.00000015, "smart": 0.0000050}, + ) + + +@pytest.mark.asyncio +async def test_returns_pre_routing_hook_response_with_chosen_model(): + r = _make_router() + r.pick_model = AsyncMock(return_value="smart") # type: ignore[method-assign] + + response = await r.async_pre_routing_hook( + model="smart-cheap-router", + request_kwargs={}, + messages=[{"role": "user", "content": "hello"}], + ) + + assert isinstance(response, PreRoutingHookResponse) + assert response.model == "smart" + + +@pytest.mark.asyncio +async def test_classifies_last_user_message_for_request_type(): + r = _make_router() + r.pick_model = AsyncMock(return_value="smart") # type: ignore[method-assign] + + await r.async_pre_routing_hook( + model="smart-cheap-router", + request_kwargs={}, + messages=[{"role": 
"user", "content": "Write a Python function for fizzbuzz"}], + ) + + assert ( + r.pick_model.await_args.kwargs["request_type"] # type: ignore[union-attr] + == RequestType.CODE_GENERATION + ) + + +@pytest.mark.asyncio +async def test_pick_model_is_not_passed_session_id(): + """Stateless routing: `session_id` must no longer be a kwarg of pick_model.""" + r = _make_router() + r.pick_model = AsyncMock(return_value="fast") # type: ignore[method-assign] + + await r.async_pre_routing_hook( + model="smart-cheap-router", + request_kwargs={"metadata": {"litellm_session_id": "sess-A"}}, + messages=[{"role": "user", "content": "hi"}], + ) + + assert "session_id" not in r.pick_model.await_args.kwargs # type: ignore[union-attr] + + +@pytest.mark.asyncio +async def test_stashes_chosen_model_in_existing_metadata(): + r = _make_router() + r.pick_model = AsyncMock(return_value="smart") # type: ignore[method-assign] + + request_kwargs: dict = {"metadata": {"litellm_session_id": "sess-A"}} + await r.async_pre_routing_hook( + model="smart-cheap-router", + request_kwargs=request_kwargs, + messages=[{"role": "user", "content": "hi"}], + ) + + assert request_kwargs["metadata"]["adaptive_router_chosen_model"] == "smart" + assert request_kwargs["metadata"]["litellm_session_id"] == "sess-A" + + +@pytest.mark.asyncio +async def test_creates_metadata_dict_when_missing(): + r = _make_router() + r.pick_model = AsyncMock(return_value="fast") # type: ignore[method-assign] + + request_kwargs: dict = {} + await r.async_pre_routing_hook( + model="smart-cheap-router", + request_kwargs=request_kwargs, + messages=[{"role": "user", "content": "hi"}], + ) + + assert request_kwargs["metadata"]["adaptive_router_chosen_model"] == "fast" + + +@pytest.mark.asyncio +async def test_handles_empty_messages(): + r = _make_router() + r.pick_model = AsyncMock(return_value="fast") # type: ignore[method-assign] + + response = await r.async_pre_routing_hook( + model="smart-cheap-router", + request_kwargs={}, + 
messages=None, + ) + + assert isinstance(response, PreRoutingHookResponse) + assert response.model == "fast" + r.pick_model.assert_awaited_once() # type: ignore[union-attr] + + +@pytest.mark.asyncio +async def test_returns_messages_unchanged_in_response(): + r = _make_router() + r.pick_model = AsyncMock(return_value="smart") # type: ignore[method-assign] + + messages = [{"role": "user", "content": "hi"}] + response = await r.async_pre_routing_hook( + model="smart-cheap-router", + request_kwargs={}, + messages=messages, + ) + + assert response.messages == messages diff --git a/tests/test_litellm/router_strategy/adaptive_router/test_bandit.py b/tests/test_litellm/router_strategy/adaptive_router/test_bandit.py new file mode 100644 index 0000000000..ab322f0fb3 --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/test_bandit.py @@ -0,0 +1,134 @@ +import random + +import pytest + +from litellm.router_strategy.adaptive_router.bandit import ( + BanditCell, + apply_delta, + initial_cell, + normalized_cost, + pick_best, + score, + thompson_sample, +) +from litellm.router_strategy.adaptive_router.config import ( + BASE_TIER_WEIGHT, + COLD_START_MASS, + SAMPLE_CAP, + STRENGTH_BONUS, +) +from litellm.types.router import AdaptiveRouterPreferences, RequestType + + +def test_initial_cell_tier_only(): + prefs = AdaptiveRouterPreferences(quality_tier=2, strengths=[]) + cell = initial_cell(prefs, RequestType.GENERAL) + expected_mean = BASE_TIER_WEIGHT[2] + assert abs(cell.mean - expected_mean) < 0.001 + assert abs(cell.alpha + cell.beta - COLD_START_MASS) < 0.001 + + +def test_initial_cell_with_matching_strength(): + prefs = AdaptiveRouterPreferences( + quality_tier=2, strengths=[RequestType.CODE_GENERATION] + ) + cell = initial_cell(prefs, RequestType.CODE_GENERATION) + expected_mean = BASE_TIER_WEIGHT[2] + STRENGTH_BONUS + assert abs(cell.mean - expected_mean) < 0.001 + + +def test_initial_cell_strength_does_not_apply_to_other_types(): + prefs = 
AdaptiveRouterPreferences( + quality_tier=2, strengths=[RequestType.CODE_GENERATION] + ) + cell = initial_cell(prefs, RequestType.WRITING) + assert abs(cell.mean - BASE_TIER_WEIGHT[2]) < 0.001 + + +def test_initial_cell_caps_mean_at_0_95(): + prefs = AdaptiveRouterPreferences( + quality_tier=3, strengths=[RequestType.CODE_GENERATION] + ) + cell = initial_cell(prefs, RequestType.CODE_GENERATION) + assert cell.mean <= 0.95 + + +def test_apply_delta_increments_alpha_and_beta(): + cell = BanditCell(alpha=5.0, beta=5.0) + new_cell = apply_delta(cell, 1.0, 0.0) + assert new_cell.alpha == 6.0 + assert new_cell.beta == 5.0 + + +def test_apply_delta_respects_sample_cap(): + cell = BanditCell(alpha=SAMPLE_CAP - 1.0, beta=1.0) + same_cell = apply_delta(cell, 5.0, 5.0) + assert same_cell.alpha == cell.alpha + assert same_cell.beta == cell.beta + + +def test_thompson_sample_in_range(): + cell = BanditCell(alpha=10.0, beta=5.0) + rng = random.Random(42) + for _ in range(100): + s = thompson_sample(cell, rng=rng) + assert 0.0 <= s <= 1.0 + + +def test_normalized_cost_cheapest_wins(): + assert normalized_cost(0.001, [0.001, 0.005, 0.01]) == 1.0 + assert normalized_cost(0.01, [0.001, 0.005, 0.01]) == 0.0 + + +def test_normalized_cost_no_spread(): + assert normalized_cost(0.005, [0.005, 0.005]) == 0.5 + + +def test_normalized_cost_empty_list(): + assert normalized_cost(0.005, []) == 0.5 + + +def test_score_combines_quality_and_cost(): + s = score( + quality_sample=1.0, + model_cost=0.001, + all_costs=[0.001, 0.01], + quality_weight=0.7, + cost_weight=0.3, + ) + assert abs(s - 1.0) < 0.001 + + +def test_pick_best_empty_dict_raises(): + with pytest.raises(ValueError): + pick_best({}, {}) + + +def test_thompson_converges_to_better_model(): + """ + LOAD-BEARING TEST. If this regresses, the whole router is broken. + + Setup: 2 models, identical priors, identical cost. Model A's true mean = 0.8, + Model B's true mean = 0.3. 
After 200 simulated turns, A must be picked >= 80% of + last 50 turns. + """ + rng = random.Random(42) + cells = { + "A": BanditCell(alpha=5.0, beta=5.0), + "B": BanditCell(alpha=5.0, beta=5.0), + } + costs = {"A": 0.001, "B": 0.001} + true_means = {"A": 0.8, "B": 0.3} + + picks = [] + for _ in range(200): + chosen = pick_best(cells, costs, rng=rng) + picks.append(chosen) + outcome = 1.0 if rng.random() < true_means[chosen] else 0.0 + cells[chosen] = apply_delta(cells[chosen], outcome, 1.0 - outcome) + + last_50 = picks[-50:] + a_share = last_50.count("A") / 50 + assert ( + a_share >= 0.80 + ), f"Expected A to dominate ({a_share=}); priors aren't biasing the sample correctly" diff --git a/tests/test_litellm/router_strategy/adaptive_router/test_classifier.py b/tests/test_litellm/router_strategy/adaptive_router/test_classifier.py new file mode 100644 index 0000000000..c27e2d945a --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/test_classifier.py @@ -0,0 +1,116 @@ +import pytest + +from litellm.router_strategy.adaptive_router.classifier import classify_prompt +from litellm.types.router import RequestType + + +@pytest.mark.parametrize( + "text", + [ + "Write a Python function that reverses a linked list", + "Implement a REST API endpoint for user signup", + "Create a bash script to back up my postgres database", + ], +) +def test_classify_code_generation(text): + assert classify_prompt(text) == RequestType.CODE_GENERATION + + +@pytest.mark.parametrize( + "text", + [ + "Explain what this function does: def foo(): ...", + "Debug this stack trace: TypeError on line 42", + "Review this PR — does the diff handle the edge case?", + ], +) +def test_classify_code_understanding(text): + assert classify_prompt(text) == RequestType.CODE_UNDERSTANDING + + +@pytest.mark.parametrize( + "text", + [ + "Design a microservice architecture for an event-driven system", + "Should I use PostgreSQL or DynamoDB for high-write workloads?", + "How should I structure my 
Django app for multi-tenancy?", + ], +) +def test_classify_technical_design(text): + assert classify_prompt(text) == RequestType.TECHNICAL_DESIGN + + +@pytest.mark.parametrize( + "text", + [ + "Solve the integral of x^2 from 0 to 5", + "If A implies B and B implies C, then prove A implies C", + "Calculate the probability of two heads in three coin flips", + ], +) +def test_classify_analytical_reasoning(text): + assert classify_prompt(text) == RequestType.ANALYTICAL_REASONING + + +@pytest.mark.parametrize( + "text", + [ + "Draft an email to my team announcing the launch", + "Rewrite this paragraph to be more concise and professional", + "Proofread my blog post for grammar and tone", + ], +) +def test_classify_writing(text): + assert classify_prompt(text) == RequestType.WRITING + + +@pytest.mark.parametrize( + "text", + [ + "Who is the current president of France?", + "What is the capital of Australia?", + "Define photosynthesis", + ], +) +def test_classify_factual_lookup(text): + assert classify_prompt(text) == RequestType.FACTUAL_LOOKUP + + +@pytest.mark.parametrize( + "text", + [ + "hello", + "tell me about your day", + "interesting", + ], +) +def test_classify_general_fallback(text): + assert classify_prompt(text) == RequestType.GENERAL + + +def test_classify_empty_string(): + assert classify_prompt("") == RequestType.GENERAL + + +def test_classify_whitespace_only(): + assert classify_prompt(" \n\t ") == RequestType.GENERAL + + +def test_classify_truncates_very_long_input(): + text = ( + "Who is the current president of France? 
" + + "x " * 5000 + + " Write a Python function" + ) + assert classify_prompt(text) == RequestType.FACTUAL_LOOKUP + + +def test_classify_is_deterministic(): + text = "Implement a REST API endpoint for user signup" + results = {classify_prompt(text) for _ in range(10)} + assert len(results) == 1 + + +def test_classify_returns_request_type_enum(): + result = classify_prompt("hello") + assert isinstance(result, RequestType) diff --git a/tests/test_litellm/router_strategy/adaptive_router/test_config.py b/tests/test_litellm/router_strategy/adaptive_router/test_config.py new file mode 100644 index 0000000000..fd14556a0b --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/test_config.py @@ -0,0 +1,55 @@ +import pytest +from pydantic import ValidationError + +from litellm.types.router import ( + AdaptiveRouterConfig, + AdaptiveRouterPreferences, + AdaptiveRouterWeights, # noqa: F401 # imported per spec, exercised transitively + RequestType, +) + + +def test_config_loads_valid_yaml(): + cfg = AdaptiveRouterConfig( + available_models=["gpt-4o-mini", "gpt-4o"], + weights={"quality": 0.7, "cost": 0.3}, + ) + assert cfg.available_models == ["gpt-4o-mini", "gpt-4o"] + assert cfg.weights.quality == 0.7 + assert cfg.weights.cost == 0.3 + assert abs(cfg.weights.quality + cfg.weights.cost - 1.0) < 0.001 + + +def test_config_rejects_misspelled_strength(): + with pytest.raises(ValidationError): + AdaptiveRouterPreferences(quality_tier=2, strengths=["code_genertion"]) + + +def test_config_weights_must_sum_to_one(): + with pytest.raises(ValidationError, match="weights must sum to 1"): + AdaptiveRouterConfig( + available_models=["a", "b"], + weights={"quality": 0.9, "cost": 0.5}, + ) + + +def test_config_quality_tier_must_be_1_2_or_3(): + with pytest.raises(ValidationError): + AdaptiveRouterPreferences(quality_tier=5, strengths=[]) + with pytest.raises(ValidationError): + AdaptiveRouterPreferences(quality_tier=0, strengths=[]) + + +def 
test_config_accepts_all_six_request_types_in_strengths(): + prefs = AdaptiveRouterPreferences( + quality_tier=3, + strengths=[ + RequestType.CODE_GENERATION, + RequestType.CODE_UNDERSTANDING, + RequestType.TECHNICAL_DESIGN, + RequestType.ANALYTICAL_REASONING, + RequestType.WRITING, + RequestType.FACTUAL_LOOKUP, + ], + ) + assert len(prefs.strengths) == 6 diff --git a/tests/test_litellm/router_strategy/adaptive_router/test_e2e_adaptive_router.py b/tests/test_litellm/router_strategy/adaptive_router/test_e2e_adaptive_router.py new file mode 100644 index 0000000000..bb0e8df044 --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/test_e2e_adaptive_router.py @@ -0,0 +1,263 @@ +""" +End-to-end tests for the adaptive router. Wires the real strategy + queue + hook +with a mocked Prisma client. No live proxy or DB required. + +What we cover: + 1. Full lifecycle: pick -> record turn(s) -> flush -> DB upsert with correct deltas + 2. Owner cache pins attribution: same key + matching model -> updates flow + 3. Convergence in-process: 50 simulated sessions, "good" model dominates last 10 + 4. Cold-start state load from DB overrides priors + 5. Failure signal increments beta in the next flush + 6. Unknown request types in DB rows are silently skipped + 7. 
Flush isolates writes per (router, session, model) tuple +""" + +import random +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter +from litellm.router_strategy.adaptive_router.signals import Turn +from litellm.types.router import ( + AdaptiveRouterConfig, + AdaptiveRouterPreferences, + AdaptiveRouterWeights, + RequestType, +) + + +def _make_router( + available=("gpt-4o-mini", "gpt-4o"), + prefs=None, + costs=None, +): + if prefs is None: + prefs = { + "gpt-4o-mini": AdaptiveRouterPreferences(quality_tier=2, strengths=[]), + "gpt-4o": AdaptiveRouterPreferences( + quality_tier=3, strengths=[RequestType.CODE_GENERATION] + ), + } + if costs is None: + costs = {"gpt-4o-mini": 0.15, "gpt-4o": 5.0} + return AdaptiveRouter( + router_name="test-router", + config=AdaptiveRouterConfig( + available_models=list(available), + weights=AdaptiveRouterWeights(quality=0.7, cost=0.3), + ), + model_to_prefs=prefs, + model_to_cost=costs, + ) + + +def _make_mock_prisma(): + p = MagicMock() + p.db.litellm_adaptiverouterstate.find_unique = AsyncMock(return_value=None) + p.db.litellm_adaptiverouterstate.find_many = AsyncMock(return_value=[]) + p.db.litellm_adaptiverouterstate.upsert = AsyncMock() + p.db.litellm_adaptiveroutersession.upsert = AsyncMock() + return p + + +@pytest.mark.asyncio +async def test_pick_record_flush_full_cycle(): + router = _make_router() + chosen = await router.pick_model(RequestType.CODE_GENERATION) + assert chosen in router.config.available_models + + await router.record_turn( + session_id="s1", + model_name=chosen, + request_type=RequestType.CODE_GENERATION, + turn=Turn(user_content="thanks, that worked!", assistant_content="ok"), + ) + + prisma = _make_mock_prisma() + n_state = await router.queue.flush_state_to_db(prisma) + n_session = await router.queue.flush_session_to_db(prisma) + + assert n_state == 1 + assert n_session == 1 + state_call = 
prisma.db.litellm_adaptiverouterstate.upsert.call_args + # satisfaction signal -> +1 alpha, no existing row -> create.alpha == 1.0 + assert state_call.kwargs["data"]["create"]["alpha"] >= 1.0 + assert state_call.kwargs["data"]["create"]["beta"] == 0.0 + assert state_call.kwargs["data"]["create"]["total_samples"] == 1 + + session_call = prisma.db.litellm_adaptiveroutersession.upsert.call_args + assert session_call.kwargs["data"]["create"]["satisfaction_count"] == 1 + assert session_call.kwargs["data"]["create"]["session_id"] == "s1" + assert session_call.kwargs["data"]["create"]["model_name"] == chosen + + +@pytest.mark.asyncio +async def test_owner_cache_pins_attribution_to_first_picked_model(): + """First call claims ownership; matching model returns True, mismatch False.""" + router = _make_router() + chosen = await router.pick_model(RequestType.GENERAL) + assert router.claim_or_check_owner("sess-own", chosen) is True + + # Same model on later turns keeps attributing. + for _ in range(5): + assert router.claim_or_check_owner("sess-own", chosen) is True + + # A different model on a later turn is rejected. + other = "gpt-4o" if chosen == "gpt-4o-mini" else "gpt-4o-mini" + assert router.claim_or_check_owner("sess-own", other) is False + assert router._skipped_updates_total == 1 + + +@pytest.mark.asyncio +async def test_pick_model_returns_valid_models_without_error(): + router = _make_router() + # Picks may legitimately differ across calls (Thompson sampling is stochastic). + # Just confirm every pick is valid and nothing raises. + for _ in range(10): + m = await router.pick_model(RequestType.GENERAL) + assert m in router.config.available_models + + +@pytest.mark.asyncio +async def test_in_process_convergence_high_quality_model_dominates(): + """ + Two models, identical cost. "good" satisfies every turn, "bad" fails every turn. + After 50 sessions of 4 turns each, "good" should win >=70% of the last 10 picks. 
+ Seed `random` for determinism since pick_best uses the module-level RNG. + """ + random.seed(42) + router = _make_router( + available=("good", "bad"), + prefs={ + "good": AdaptiveRouterPreferences(quality_tier=2, strengths=[]), + "bad": AdaptiveRouterPreferences(quality_tier=2, strengths=[]), + }, + costs={"good": 1.0, "bad": 1.0}, + ) + + picks = [] + for sess in range(50): + sid = f"conv-{sess}" + chosen = await router.pick_model(RequestType.GENERAL) + for _turn_i in range(4): + if chosen == "good": + turn = Turn(user_content="thanks!", assistant_content="ok") + else: + turn = Turn( + tool_calls=[{"name": "x", "arguments": {}}], + tool_results=[{"is_error": True, "content": "boom"}], + ) + await router.record_turn(sid, chosen, RequestType.GENERAL, turn) + picks.append(chosen) + + last_10 = picks[-10:] + good_share = last_10.count("good") / 10 + assert good_share >= 0.7, f"good_share={good_share} (last picks={picks})" + + +@pytest.mark.asyncio +async def test_failure_signal_increments_beta_after_flush(): + router = _make_router( + available=("only",), + prefs={"only": AdaptiveRouterPreferences(quality_tier=2, strengths=[])}, + costs={"only": 1.0}, + ) + chosen = await router.pick_model(RequestType.GENERAL) + assert chosen == "only" + + await router.record_turn( + session_id="f1", + model_name=chosen, + request_type=RequestType.GENERAL, + turn=Turn( + tool_calls=[{"name": "x", "arguments": {}}], + tool_results=[{"is_error": True, "content": ""}], + ), + ) + + prisma = _make_mock_prisma() + n_state = await router.queue.flush_state_to_db(prisma) + assert n_state == 1 + state_call = prisma.db.litellm_adaptiverouterstate.upsert.call_args + assert state_call.kwargs["data"]["create"]["beta"] >= 1.0 + assert state_call.kwargs["data"]["create"]["alpha"] == 0.0 + + +@pytest.mark.asyncio +async def test_load_state_from_db_overrides_cold_start(): + router = _make_router() + fake_row = MagicMock() + fake_row.request_type = RequestType.GENERAL.value + fake_row.model_name = 
"gpt-4o" + fake_row.alpha = 90.0 + fake_row.beta = 10.0 + + prisma = _make_mock_prisma() + prisma.db.litellm_adaptiverouterstate.find_many = AsyncMock(return_value=[fake_row]) + + await router.load_state_from_db(prisma) + + cell = router._cells[(RequestType.GENERAL, "gpt-4o")] + assert cell.alpha == 90.0 + assert cell.beta == 10.0 + + +@pytest.mark.asyncio +async def test_load_state_from_db_handles_unknown_request_type(): + router = _make_router() + bad_row = MagicMock() + bad_row.request_type = "unknown_v1_type" + bad_row.model_name = "gpt-4o" + bad_row.alpha = 50.0 + bad_row.beta = 50.0 + + prisma = _make_mock_prisma() + prisma.db.litellm_adaptiverouterstate.find_many = AsyncMock(return_value=[bad_row]) + + # Should not raise; bad row is silently skipped and cold-start cells remain. + await router.load_state_from_db(prisma) + cell = router._cells[(RequestType.GENERAL, "gpt-4o")] + # Cold-start: tier 3 base = 0.7, mass = 10 -> alpha = 7, beta = 3 + assert cell.alpha == pytest.approx(7.0) + assert cell.beta == pytest.approx(3.0) + + +@pytest.mark.asyncio +async def test_flush_isolates_writes_per_router_session_model(): + router = _make_router() + await router.record_turn( + "s1", "gpt-4o", RequestType.GENERAL, Turn(user_content="thanks!") + ) + await router.record_turn( + "s2", "gpt-4o-mini", RequestType.GENERAL, Turn(user_content="thanks!") + ) + + prisma = _make_mock_prisma() + n = await router.queue.flush_session_to_db(prisma) + assert n == 2 + assert prisma.db.litellm_adaptiveroutersession.upsert.call_count == 2 + + n_state = await router.queue.flush_state_to_db(prisma) + assert n_state == 2 + assert prisma.db.litellm_adaptiverouterstate.upsert.call_count == 2 + + +@pytest.mark.asyncio +async def test_repeated_flush_drains_queue_and_subsequent_flush_is_noop(): + """Verifies the queue is fully drained on flush -- a second flush writes nothing.""" + router = _make_router() + chosen = await router.pick_model(RequestType.GENERAL) + await router.record_turn( + 
"drain-1", chosen, RequestType.GENERAL, Turn(user_content="thanks!") + ) + + prisma = _make_mock_prisma() + assert await router.queue.flush_state_to_db(prisma) == 1 + assert await router.queue.flush_session_to_db(prisma) == 1 + + # Second drain should be a no-op (queue is empty). + assert await router.queue.flush_state_to_db(prisma) == 0 + assert await router.queue.flush_session_to_db(prisma) == 0 + assert prisma.db.litellm_adaptiverouterstate.upsert.call_count == 1 + assert prisma.db.litellm_adaptiveroutersession.upsert.call_count == 1 diff --git a/tests/test_litellm/router_strategy/adaptive_router/test_hooks.py b/tests/test_litellm/router_strategy/adaptive_router/test_hooks.py new file mode 100644 index 0000000000..17fc4fd732 --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/test_hooks.py @@ -0,0 +1,329 @@ +"""Unit tests for the AdaptiveRouterPostCallHook.""" + +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from litellm.router_strategy.adaptive_router.config import ( + ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY, + SIGNAL_GATE_MIN_MESSAGES, +) +from litellm.router_strategy.adaptive_router.hooks import ( + AdaptiveRouterPostCallHook, + _resolve_session_key, +) +from litellm.router_strategy.adaptive_router.signals import Turn + + +def _make_hook(claim: bool = True) -> AdaptiveRouterPostCallHook: + fake_router = MagicMock() + fake_router.record_turn = AsyncMock() + fake_router.claim_or_check_owner = MagicMock(return_value=claim) + return AdaptiveRouterPostCallHook(adaptive_router=fake_router) + + +def _resp_with_content(text: str, tool_calls=None): + """Build a ModelResponse-like object with a single assistant message.""" + msg = MagicMock() + msg.content = text + msg.tool_calls = tool_calls or [] + choice = MagicMock() + choice.message = msg + resp = MagicMock() + resp.choices = [choice] + return resp + + +def _long_messages(user_text: str = "ask"): + """Return a message list at the SIGNAL_GATE_MIN_MESSAGES threshold.""" + base = [ + 
{"role": "user", "content": "first turn"}, + {"role": "assistant", "content": "first reply"}, + {"role": "user", "content": "second turn"}, + ] + base.append({"role": "user", "content": user_text}) + # Pad to threshold if needed. + while len(base) < SIGNAL_GATE_MIN_MESSAGES: + base.append({"role": "user", "content": "filler"}) + return base + + +def _kwargs( + *, + messages=None, + chosen="fast", + extra_metadata=None, + extra_litellm_params=None, +): + metadata = {ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY: chosen} if chosen else {} + if extra_metadata: + metadata.update(extra_metadata) + lp = {"metadata": metadata} + if extra_litellm_params: + lp.update(extra_litellm_params) + return { + "model": "anthropic/claude-opus-4-7", + "messages": messages if messages is not None else _long_messages(), + "litellm_params": lp, + } + + +# ---- _resolve_session_key ------------------------------------------------ + + +def test_resolve_session_key_honors_litellm_session_id_on_litellm_params(): + key = _resolve_session_key({"litellm_params": {"litellm_session_id": "sess-A"}}) + assert key == "sess-A" + + +def test_resolve_session_key_honors_metadata_session_id(): + key = _resolve_session_key( + {"litellm_params": {"metadata": {"session_id": "sess-B"}}} + ) + assert key == "sess-B" + + +def test_resolve_session_key_returns_none_when_no_messages(): + assert _resolve_session_key({"litellm_params": {}}) is None + assert _resolve_session_key({"litellm_params": {}, "messages": []}) is None + + +def test_resolve_session_key_derives_stable_hash_from_first_message(): + msgs = [{"role": "user", "content": "Hello, world"}] + k1 = _resolve_session_key({"messages": msgs}) + k2 = _resolve_session_key({"messages": list(msgs)}) + assert k1 == k2 + assert k1 and len(k1) == 64 # sha256 hex + + +def test_resolve_session_key_does_not_prefix_sk(): + key = _resolve_session_key({"messages": [{"role": "user", "content": "hi"}]}) + assert key and not key.startswith("sk_") + + +def 
test_resolve_session_key_segments_by_identity_fields(): + """Same first message but different api keys must yield different keys.""" + msgs = [{"role": "user", "content": "same prompt"}] + k_team_a = _resolve_session_key( + { + "messages": msgs, + "litellm_params": { + "metadata": { + "user_api_key_hash": "hash-A", + "user_api_key_team_id": "team-1", + } + }, + } + ) + k_team_b = _resolve_session_key( + { + "messages": msgs, + "litellm_params": { + "metadata": { + "user_api_key_hash": "hash-B", + "user_api_key_team_id": "team-2", + } + }, + } + ) + assert k_team_a != k_team_b + + +def test_resolve_session_key_changes_when_first_message_changes(): + k1 = _resolve_session_key({"messages": [{"role": "user", "content": "alpha"}]}) + k2 = _resolve_session_key({"messages": [{"role": "user", "content": "beta"}]}) + assert k1 != k2 + + +# ---- _record gating ----------------------------------------------------- + + +@pytest.mark.asyncio +async def test_hook_skips_when_below_signal_gate(): + """Conversations shorter than SIGNAL_GATE_MIN_MESSAGES should be ignored.""" + hook = _make_hook() + short = [{"role": "user", "content": "hi"}] + assert len(short) < SIGNAL_GATE_MIN_MESSAGES # sanity + kwargs = _kwargs(messages=short) + await hook.async_log_success_event(kwargs, _resp_with_content("ok"), 0.0, 1.0) + hook.adaptive_router.record_turn.assert_not_awaited() + hook.adaptive_router.claim_or_check_owner.assert_not_called() + + +@pytest.mark.asyncio +async def test_hook_skips_when_no_messages(): + hook = _make_hook() + kwargs = _kwargs(messages=[]) + await hook.async_log_success_event(kwargs, _resp_with_content("ok"), 0.0, 1.0) + hook.adaptive_router.record_turn.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_hook_skips_when_chosen_model_missing_from_metadata(): + hook = _make_hook() + kwargs = _kwargs(chosen=None) + await hook.async_log_success_event(kwargs, _resp_with_content("ok"), 0.0, 1.0) + hook.adaptive_router.record_turn.assert_not_awaited() + 
hook.adaptive_router.claim_or_check_owner.assert_not_called() + + +@pytest.mark.asyncio +async def test_hook_skips_when_owner_cache_mismatch(): + """A different model owns this conversation -> no attribution.""" + hook = _make_hook(claim=False) + kwargs = _kwargs(chosen="fast") + await hook.async_log_success_event(kwargs, _resp_with_content("ok"), 0.0, 1.0) + hook.adaptive_router.claim_or_check_owner.assert_called_once() + hook.adaptive_router.record_turn.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_hook_records_turn_when_owner_claims(): + hook = _make_hook(claim=True) + kwargs = _kwargs(chosen="smart", messages=_long_messages("ask")) + await hook.async_log_success_event( + kwargs, _resp_with_content("answer here"), 0.0, 1.0 + ) + call = hook.adaptive_router.record_turn.await_args + assert call.kwargs["model_name"] == "smart" + turn: Turn = call.kwargs["turn"] + assert turn.user_content == "ask" + assert turn.assistant_content == "answer here" + assert turn.response_status == 200 + + +@pytest.mark.asyncio +async def test_hook_uses_explicit_session_id_when_provided(): + """Explicit `litellm_session_id` is forwarded as the session key.""" + hook = _make_hook() + kwargs = _kwargs( + chosen="fast", + extra_litellm_params={"litellm_session_id": "explicit-sess"}, + ) + await hook.async_log_success_event(kwargs, _resp_with_content("ok"), 0.0, 1.0) + args, _ = hook.adaptive_router.claim_or_check_owner.call_args + assert args[0] == "explicit-sess" + assert hook.adaptive_router.record_turn.await_args.kwargs["session_id"] == ( + "explicit-sess" + ) + + +@pytest.mark.asyncio +async def test_hook_passes_tool_calls_through(): + hook = _make_hook() + tc = {"name": "search", "arguments": '{"q":"x"}'} + kwargs = _kwargs(chosen="fast") + await hook.async_log_success_event( + kwargs, _resp_with_content("calling tool", tool_calls=[tc]), 0.0, 1.0 + ) + turn: Turn = hook.adaptive_router.record_turn.await_args.kwargs["turn"] + assert turn.tool_calls == [tc] + + 
+@pytest.mark.asyncio +async def test_hook_swallows_exceptions_from_record_turn(): + hook = _make_hook() + hook.adaptive_router.record_turn.side_effect = RuntimeError("boom") + kwargs = _kwargs(chosen="fast") + # Must NOT raise — signal recording must never break a request. + await hook.async_log_success_event(kwargs, _resp_with_content("ok"), 0.0, 1.0) + + +@pytest.mark.asyncio +async def test_hook_failure_event_uses_status_code_from_exception(): + hook = _make_hook() + exc = MagicMock() + exc.status_code = 429 + kwargs = _kwargs(chosen="fast") + kwargs["exception"] = exc + await hook.async_log_failure_event(kwargs, None, 0.0, 1.0) + turn: Turn = hook.adaptive_router.record_turn.await_args.kwargs["turn"] + assert turn.response_status == 429 + + +# ---- async_post_call_success_hook (response header surfacing) ---------- + + +@pytest.mark.asyncio +async def test_post_call_success_hook_sets_response_header(): + hook = _make_hook() + response = MagicMock() + response._hidden_params = {} + + await hook.async_post_call_success_hook( + data={"metadata": {"adaptive_router_chosen_model": "smart"}}, + user_api_key_dict=MagicMock(), + response=response, + ) + + assert ( + response._hidden_params["additional_headers"]["x-litellm-adaptive-router-model"] + == "smart" + ) + + +@pytest.mark.asyncio +async def test_post_call_success_hook_preserves_existing_additional_headers(): + hook = _make_hook() + response = MagicMock() + response._hidden_params = {"additional_headers": {"x-existing": "keep-me"}} + + await hook.async_post_call_success_hook( + data={"metadata": {"adaptive_router_chosen_model": "fast"}}, + user_api_key_dict=MagicMock(), + response=response, + ) + + assert response._hidden_params["additional_headers"]["x-existing"] == "keep-me" + assert ( + response._hidden_params["additional_headers"]["x-litellm-adaptive-router-model"] + == "fast" + ) + + +@pytest.mark.asyncio +async def test_post_call_success_hook_noop_when_metadata_missing_key(): + hook = _make_hook() + 
response = MagicMock() + response._hidden_params = {} + + await hook.async_post_call_success_hook( + data={"metadata": {"litellm_session_id": "sess-A"}}, + user_api_key_dict=MagicMock(), + response=response, + ) + + assert response._hidden_params == {} + + +@pytest.mark.asyncio +async def test_post_call_success_hook_noop_when_no_metadata(): + hook = _make_hook() + response = MagicMock() + response._hidden_params = {} + + await hook.async_post_call_success_hook( + data={}, + user_api_key_dict=MagicMock(), + response=response, + ) + + assert response._hidden_params == {} + + +@pytest.mark.asyncio +async def test_post_call_success_hook_noop_when_hidden_params_not_dict(): + hook = _make_hook() + + class _NoHiddenParams: + pass + + response = _NoHiddenParams() + + await hook.async_post_call_success_hook( + data={"metadata": {"adaptive_router_chosen_model": "smart"}}, + user_api_key_dict=MagicMock(), + response=response, + ) + + assert not hasattr(response, "_hidden_params") diff --git a/tests/test_litellm/router_strategy/adaptive_router/test_router_dispatch.py b/tests/test_litellm/router_strategy/adaptive_router/test_router_dispatch.py new file mode 100644 index 0000000000..7a67dac1a8 --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/test_router_dispatch.py @@ -0,0 +1,380 @@ +"""Tests for the Router-level wiring of the adaptive router. + +Specifically guards the four bugs found when wiring the example config +`auto_router/adaptive_router` end-to-end: + +1. The `auto_router/adaptive_router` model prefix must NOT trigger the + semantic auto-router init path (which would crash on missing fields). +2. The same prefix MUST trigger the adaptive-router init path. +3. `init_adaptive_router_deployment` must read `input_cost_per_token` + from `litellm_params` (where users put it), not just `model_info`. +4. 
`Router.async_pre_routing_hook` must dispatch to the matching entry in + `self.adaptive_routers` when the inbound model matches a configured + adaptive-router name, returning the underlying model the bandit picked. +""" + +from unittest.mock import AsyncMock + +import pytest + +from litellm import Router +from litellm.types.router import LiteLLM_Params, RequestType + + +def _params(**overrides): + base = {"model": "auto_router/adaptive_router"} + base.update(overrides) + return LiteLLM_Params(**base) + + +# ---- Fix 1 & 2: opt-in prefix routing ----------------------------------- + + +def test_auto_router_check_excludes_adaptive_router_prefix(): + r = Router(model_list=[]) + assert ( + r._is_auto_router_deployment( + litellm_params=_params(model="auto_router/adaptive_router") + ) + is False + ) + + +def test_auto_router_check_excludes_complexity_router_prefix(): + r = Router(model_list=[]) + assert ( + r._is_auto_router_deployment( + litellm_params=_params(model="auto_router/complexity_router") + ) + is False + ) + + +def test_auto_router_check_still_matches_plain_auto_router_prefix(): + r = Router(model_list=[]) + assert ( + r._is_auto_router_deployment( + litellm_params=_params(model="auto_router/my-semantic-router") + ) + is True + ) + + +def test_adaptive_router_check_recognizes_prefix(): + r = Router(model_list=[]) + assert ( + r._is_adaptive_router_deployment( + litellm_params=_params(model="auto_router/adaptive_router") + ) + is True + ) + + +def test_adaptive_router_check_rejects_other_prefixes(): + r = Router(model_list=[]) + assert ( + r._is_adaptive_router_deployment(litellm_params=_params(model="openai/gpt-4o")) + is False + ) + + +# ---- Fix 3: cost field path -------------------------------------------- + + +def test_init_adaptive_router_reads_cost_from_litellm_params(): + r = Router( + model_list=[ + { + "model_name": "smart-cheap-router", + "litellm_params": { + "model": "auto_router/adaptive_router", + "adaptive_router_config": { + 
"available_models": ["fast", "smart"], + }, + }, + }, + { + "model_name": "fast", + "litellm_params": { + "model": "openai/gpt-4o-mini", + "input_cost_per_token": 0.00000015, + }, + "model_info": { + "adaptive_router_preferences": { + "quality_tier": 2, + "strengths": [], + } + }, + }, + { + "model_name": "smart", + "litellm_params": { + "model": "openai/gpt-4o", + "input_cost_per_token": 0.0000050, + }, + "model_info": { + "adaptive_router_preferences": { + "quality_tier": 3, + "strengths": ["code_generation"], + } + }, + }, + ] + ) + assert "smart-cheap-router" in r.adaptive_routers + assert r.adaptive_routers["smart-cheap-router"].model_to_cost == { + "fast": 0.00000015, + "smart": 0.0000050, + } + + +# ---- Fix 4: pre-routing dispatch --------------------------------------- + + +def _router_with_adaptive() -> Router: + return Router( + model_list=[ + { + "model_name": "smart-cheap-router", + "litellm_params": { + "model": "auto_router/adaptive_router", + "adaptive_router_config": { + "available_models": ["fast", "smart"], + }, + }, + }, + { + "model_name": "fast", + "litellm_params": { + "model": "openai/gpt-4o-mini", + "input_cost_per_token": 0.00000015, + }, + "model_info": { + "adaptive_router_preferences": { + "quality_tier": 2, + "strengths": [], + } + }, + }, + { + "model_name": "smart", + "litellm_params": { + "model": "openai/gpt-4o", + "input_cost_per_token": 0.0000050, + }, + "model_info": { + "adaptive_router_preferences": { + "quality_tier": 3, + "strengths": ["code_generation"], + } + }, + }, + ] + ) + + +@pytest.mark.asyncio +async def test_async_pre_routing_hook_dispatches_to_adaptive_router(): + r = _router_with_adaptive() + ar = r.adaptive_routers["smart-cheap-router"] + ar.pick_model = AsyncMock(return_value="smart") # type: ignore[assignment] + + response = await r.async_pre_routing_hook( + model="smart-cheap-router", + request_kwargs={"metadata": {"litellm_session_id": "sess-A"}}, + messages=[{"role": "user", "content": "Write a Python 
function"}], + ) + assert response is not None + assert response.model == "smart" + call = ar.pick_model.await_args # type: ignore[union-attr] + # Stateless routing: session_id is no longer passed to pick_model. + assert "session_id" not in call.kwargs + assert call.kwargs["request_type"] == RequestType.CODE_GENERATION + + +@pytest.mark.asyncio +async def test_async_pre_routing_hook_pick_model_not_passed_session_id(): + r = _router_with_adaptive() + ar = r.adaptive_routers["smart-cheap-router"] + ar.pick_model = AsyncMock(return_value="fast") # type: ignore[assignment] + + response = await r.async_pre_routing_hook( + model="smart-cheap-router", + request_kwargs={}, + messages=[{"role": "user", "content": "hello"}], + ) + assert response is not None + assert response.model == "fast" + assert "session_id" not in ar.pick_model.await_args.kwargs # type: ignore[union-attr] + + +@pytest.mark.asyncio +async def test_async_pre_routing_hook_returns_none_for_unrelated_model(): + r = _router_with_adaptive() + ar = r.adaptive_routers["smart-cheap-router"] + ar.pick_model = AsyncMock() # type: ignore[assignment] + response = await r.async_pre_routing_hook( + model="some-other-model", + request_kwargs={}, + messages=[{"role": "user", "content": "x"}], + ) + assert response is None + ar.pick_model.assert_not_awaited() # type: ignore[union-attr] + + +# ---- Response header surfacing ----------------------------------------- + + +@pytest.mark.asyncio +async def test_async_pre_routing_hook_stashes_chosen_model_in_metadata(): + """ + The adaptive-router branch must record the chosen logical model on + `request_kwargs["metadata"]` so `_acompletion` can surface it as the + `x-litellm-adaptive-router-model` response header. 
+ """ + r = _router_with_adaptive() + r.adaptive_routers["smart-cheap-router"].pick_model = AsyncMock( # type: ignore[assignment] + return_value="smart" + ) + + request_kwargs: dict = {"metadata": {"litellm_session_id": "sess-A"}} + await r.async_pre_routing_hook( + model="smart-cheap-router", + request_kwargs=request_kwargs, + messages=[{"role": "user", "content": "Write a Python function"}], + ) + assert request_kwargs["metadata"]["adaptive_router_chosen_model"] == "smart" + + +@pytest.mark.asyncio +async def test_async_pre_routing_hook_creates_metadata_when_missing(): + """If no metadata was passed in, the hook should create one to stash the chosen model.""" + r = _router_with_adaptive() + r.adaptive_routers["smart-cheap-router"].pick_model = AsyncMock( # type: ignore[assignment] + return_value="fast" + ) + + request_kwargs: dict = {} + await r.async_pre_routing_hook( + model="smart-cheap-router", + request_kwargs=request_kwargs, + messages=[{"role": "user", "content": "hello"}], + ) + assert request_kwargs["metadata"]["adaptive_router_chosen_model"] == "fast" + + +# ---- Multi-router support ---------------------------------------------- + + +def test_two_adaptive_routers_can_coexist_on_one_router(): + r = Router( + model_list=[ + { + "model_name": "cheap-router", + "litellm_params": { + "model": "auto_router/adaptive_router", + "adaptive_router_config": {"available_models": ["fast"]}, + }, + }, + { + "model_name": "premium-router", + "litellm_params": { + "model": "auto_router/adaptive_router", + "adaptive_router_config": {"available_models": ["smart"]}, + }, + }, + { + "model_name": "fast", + "litellm_params": { + "model": "openai/gpt-4o-mini", + "input_cost_per_token": 0.00000015, + }, + }, + { + "model_name": "smart", + "litellm_params": { + "model": "openai/gpt-4o", + "input_cost_per_token": 0.0000050, + }, + }, + ] + ) + assert set(r.adaptive_routers.keys()) == {"cheap-router", "premium-router"} + assert 
r.adaptive_routers["cheap-router"].config.available_models == ["fast"] + assert r.adaptive_routers["premium-router"].config.available_models == ["smart"] + + +@pytest.mark.asyncio +async def test_async_pre_routing_hook_dispatches_to_correct_router_when_multiple(): + """Each adaptive router only handles its own router_name.""" + r = Router( + model_list=[ + { + "model_name": "cheap-router", + "litellm_params": { + "model": "auto_router/adaptive_router", + "adaptive_router_config": {"available_models": ["fast"]}, + }, + }, + { + "model_name": "premium-router", + "litellm_params": { + "model": "auto_router/adaptive_router", + "adaptive_router_config": {"available_models": ["smart"]}, + }, + }, + { + "model_name": "fast", + "litellm_params": { + "model": "openai/gpt-4o-mini", + "input_cost_per_token": 0.00000015, + }, + }, + { + "model_name": "smart", + "litellm_params": { + "model": "openai/gpt-4o", + "input_cost_per_token": 0.0000050, + }, + }, + ] + ) + cheap = r.adaptive_routers["cheap-router"] + premium = r.adaptive_routers["premium-router"] + cheap.pick_model = AsyncMock(return_value="fast") # type: ignore[assignment] + premium.pick_model = AsyncMock(return_value="smart") # type: ignore[assignment] + + cheap_response = await r.async_pre_routing_hook( + model="cheap-router", + request_kwargs={}, + messages=[{"role": "user", "content": "hi"}], + ) + premium_response = await r.async_pre_routing_hook( + model="premium-router", + request_kwargs={}, + messages=[{"role": "user", "content": "hi"}], + ) + + assert cheap_response is not None and cheap_response.model == "fast" + assert premium_response is not None and premium_response.model == "smart" + cheap.pick_model.assert_awaited_once() # type: ignore[union-attr] + premium.pick_model.assert_awaited_once() # type: ignore[union-attr] + + +def test_init_adaptive_router_rejects_duplicate_model_name(): + """Two adaptive-router deployments with the same model_name must error.""" + from litellm.types.router import 
AdaptiveRouterConfig, Deployment + + r = Router(model_list=[]) + cfg = {"available_models": ["fast"]} + deployment = Deployment( + model_name="dup-router", + litellm_params=LiteLLM_Params( + model="auto_router/adaptive_router", + adaptive_router_config=cfg, + ), + model_info={"id": "x"}, + ) + r.init_adaptive_router_deployment(deployment=deployment) + with pytest.raises(ValueError, match="already exists"): + r.init_adaptive_router_deployment(deployment=deployment) diff --git a/tests/test_litellm/router_strategy/adaptive_router/test_signals.py b/tests/test_litellm/router_strategy/adaptive_router/test_signals.py new file mode 100644 index 0000000000..bf09b1b16f --- /dev/null +++ b/tests/test_litellm/router_strategy/adaptive_router/test_signals.py @@ -0,0 +1,112 @@ +import json +from pathlib import Path +from typing import List, Tuple + +import pytest + +from litellm.router_strategy.adaptive_router.config import TOOL_CALL_HISTORY_MAX +from litellm.router_strategy.adaptive_router.signals import ( + SessionState, + SignalDelta, + Turn, + apply_turn, +) + +FIXTURE_DIR = Path(__file__).parent / "fixtures" + + +def _load(name: str) -> list: + return json.loads((FIXTURE_DIR / f"{name}.json").read_text()) + + +def _replay(turns: list) -> Tuple[SessionState, List[SignalDelta]]: + state = SessionState( + session_id="s", + router_name="r", + model_name="m", + classified_type="general", + ) + deltas: List[SignalDelta] = [] + for t in turns: + deltas.append( + apply_turn( + state, + Turn( + user_content=t.get("user_content"), + assistant_content=t.get("assistant_content"), + tool_calls=t.get("tool_calls", []), + tool_results=t.get("tool_results", []), + response_status=t.get("response_status"), + ), + ) + ) + return state, deltas + + +def test_clean_satisfaction_fires_satisfaction_only(): + state, _ = _replay(_load("clean_satisfaction")) + assert state.satisfaction_count >= 1 + assert state.failure_count == 0 + assert state.disengagement_count == 0 + + +def 
test_misalignment_fires_on_rephrase(): + state, _ = _replay(_load("misalignment_rephrase")) + assert state.misalignment_count >= 1 + + +def test_stagnation_fires_on_repeated_assistant(): + state, _ = _replay(_load("stagnation_repeat")) + assert state.stagnation_count >= 1 + + +def test_disengagement_fires_on_giveup(): + state, _ = _replay(_load("disengagement_giveup")) + assert state.disengagement_count >= 1 + + +def test_failure_fires_on_tool_error(): + state, _ = _replay(_load("failure_tool_error")) + assert state.failure_count == 1 + + +def test_loop_fires_on_repeated_tool(): + state, _ = _replay(_load("loop_same_tool")) + assert state.loop_count >= 1 + + +@pytest.mark.parametrize("fixture", ["exhaustion_429", "exhaustion_context_overflow"]) +def test_exhaustion_fires_on_infra_signal(fixture): + state, _ = _replay(_load(fixture)) + assert state.exhaustion_count >= 1 + + +def test_no_signals_on_clean_session(): + state, _ = _replay(_load("clean_no_signals")) + assert state.misalignment_count == 0 + assert state.stagnation_count == 0 + assert state.disengagement_count == 0 + assert state.failure_count == 0 + assert state.loop_count == 0 + assert state.exhaustion_count == 0 + + +def test_mixed_failure_then_satisfaction(): + state, _ = _replay(_load("mixed_failure_then_satisfaction")) + assert state.failure_count >= 1 + assert state.satisfaction_count >= 1 + + +def test_apply_turn_is_o1_does_not_grow_history_unbounded(): + state = SessionState( + session_id="s", + router_name="r", + model_name="m", + classified_type="general", + ) + for i in range(100): + apply_turn( + state, + Turn(tool_calls=[{"name": f"tool_{i}", "arguments": {}}]), + ) + assert len(state.tool_call_history) <= TOOL_CALL_HISTORY_MAX diff --git a/tests/test_litellm/router_strategy/adaptive_router/test_state_endpoint.py b/tests/test_litellm/router_strategy/adaptive_router/test_state_endpoint.py new file mode 100644 index 0000000000..80fa2dc8a5 --- /dev/null +++ 
b/tests/test_litellm/router_strategy/adaptive_router/test_state_endpoint.py @@ -0,0 +1,196 @@ +"""Tests for the GET /adaptive_router/state introspection endpoint and the +underlying `AdaptiveRouter.get_state_snapshot()` helper.""" + +import time +from unittest.mock import MagicMock + +import pytest +from fastapi import HTTPException + +from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth +from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter +from litellm.router_strategy.adaptive_router.bandit import BanditCell, apply_delta +from litellm.types.router import ( + AdaptiveRouterConfig, + AdaptiveRouterPreferences, + RequestType, +) + + +def _make_router(name: str = "r1") -> AdaptiveRouter: + cfg = AdaptiveRouterConfig(available_models=["fast", "smart"]) + prefs = { + "fast": AdaptiveRouterPreferences(quality_tier=1, strengths=[]), + "smart": AdaptiveRouterPreferences( + quality_tier=3, strengths=[RequestType.CODE_GENERATION] + ), + } + costs = {"fast": 0.0001, "smart": 0.001} + return AdaptiveRouter( + router_name=name, + config=cfg, + model_to_prefs=prefs, + model_to_cost=costs, + ) + + +# ---- snapshot helper --------------------------------------------------- + + +@pytest.mark.asyncio +async def test_get_state_snapshot_returns_cell_per_request_type_per_model(): + r = _make_router() + snap = await r.get_state_snapshot() + + # Top-level shape + assert snap["router_name"] == "r1" + assert snap["available_models"] == ["fast", "smart"] + assert snap["weights"] == {"quality": 0.7, "cost": 0.3} + assert snap["model_costs"] == {"fast": 0.0001, "smart": 0.001} + assert snap["owner_cache_live"] == 0 + assert snap["skipped_updates_total"] == 0 + assert set(snap["queue"].keys()) == { + "state_pending", + "session_pending", + "max_state_seen", + "max_session_seen", + } + + # 7 request types x 2 models = 14 cells + assert len(snap["cells"]) == len(list(RequestType)) * 2 + for cell in snap["cells"]: + assert set(cell.keys()) == { + 
"request_type", + "model", + "alpha", + "beta", + "samples", + "quality_mean", + } + assert cell["model"] in {"fast", "smart"} + assert cell["request_type"] in {rt.value for rt in RequestType} + + +@pytest.mark.asyncio +async def test_get_state_snapshot_quality_mean_matches_alpha_over_total(): + r = _make_router() + + # Manually mutate one cell to a known state so the math is verifiable. + key = (RequestType.CODE_GENERATION, "smart") + r._cells[key] = apply_delta(r._cells[key], delta_alpha=10.0, delta_beta=0.0) + expected = r._cells[key] + expected_mean = expected.alpha / (expected.alpha + expected.beta) + + snap = await r.get_state_snapshot() + cell = next( + c + for c in snap["cells"] + if c["request_type"] == "code_generation" and c["model"] == "smart" + ) + assert cell["alpha"] == expected.alpha + assert cell["beta"] == expected.beta + assert cell["samples"] == expected.alpha + expected.beta + assert cell["quality_mean"] == pytest.approx(expected_mean) + + +@pytest.mark.asyncio +async def test_get_state_snapshot_counts_only_live_owner_cache_entries(): + r = _make_router() + now = time.time() + r._owner_cache["live-1"] = ("fast", now + 3600) + r._owner_cache["live-2"] = ("smart", now + 3600) + r._owner_cache["expired-1"] = ("fast", now - 1) + + snap = await r.get_state_snapshot() + assert snap["owner_cache_live"] == 2 + + +@pytest.mark.asyncio +async def test_get_state_snapshot_exposes_skipped_updates_total(): + r = _make_router() + r._skipped_updates_total = 7 + snap = await r.get_state_snapshot() + assert snap["skipped_updates_total"] == 7 + + +# ---- endpoint -------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_endpoint_returns_404_when_no_adaptive_router(monkeypatch): + """When llm_router is set but has no adaptive routers configured, return 404.""" + from litellm.proxy import proxy_server + + fake_router = MagicMock() + fake_router.adaptive_routers = {} + monkeypatch.setattr(proxy_server, "llm_router", 
fake_router) + + admin = UserAPIKeyAuth(api_key="sk-1234", user_role=LitellmUserRoles.PROXY_ADMIN) + with pytest.raises(HTTPException) as exc: + await proxy_server.get_adaptive_router_state(user_api_key_dict=admin) + assert exc.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_endpoint_returns_404_when_llm_router_is_none(monkeypatch): + from litellm.proxy import proxy_server + + monkeypatch.setattr(proxy_server, "llm_router", None) + + admin = UserAPIKeyAuth(api_key="sk-1234", user_role=LitellmUserRoles.PROXY_ADMIN) + with pytest.raises(HTTPException) as exc: + await proxy_server.get_adaptive_router_state(user_api_key_dict=admin) + assert exc.value.status_code == 404 + + +@pytest.mark.asyncio +async def test_endpoint_rejects_non_admin_role(monkeypatch): + from litellm.proxy import proxy_server + + fake_router = MagicMock() + fake_router.adaptive_routers = {"r1": _make_router()} + monkeypatch.setattr(proxy_server, "llm_router", fake_router) + + non_admin = UserAPIKeyAuth( + api_key="sk-user", user_role=LitellmUserRoles.INTERNAL_USER + ) + with pytest.raises(HTTPException) as exc: + await proxy_server.get_adaptive_router_state(user_api_key_dict=non_admin) + assert exc.value.status_code == 403 + + +@pytest.mark.asyncio +async def test_endpoint_returns_snapshot_list_for_admin(monkeypatch): + """Single configured router still returns the {"routers": [...]} list shape.""" + from litellm.proxy import proxy_server + + fake_router = MagicMock() + fake_router.adaptive_routers = {"r1": _make_router("r1")} + monkeypatch.setattr(proxy_server, "llm_router", fake_router) + + admin = UserAPIKeyAuth(api_key="sk-1234", user_role=LitellmUserRoles.PROXY_ADMIN) + result = await proxy_server.get_adaptive_router_state(user_api_key_dict=admin) + assert list(result.keys()) == ["routers"] + assert len(result["routers"]) == 1 + snap = result["routers"][0] + assert snap["router_name"] == "r1" + assert snap["available_models"] == ["fast", "smart"] + assert len(snap["cells"]) 
== len(list(RequestType)) * 2 + + +@pytest.mark.asyncio +async def test_endpoint_returns_one_snapshot_per_router(monkeypatch): + """With multiple adaptive routers configured, return one snapshot per router.""" + from litellm.proxy import proxy_server + + fake_router = MagicMock() + fake_router.adaptive_routers = { + "r1": _make_router("r1"), + "r2": _make_router("r2"), + } + monkeypatch.setattr(proxy_server, "llm_router", fake_router) + + admin = UserAPIKeyAuth(api_key="sk-1234", user_role=LitellmUserRoles.PROXY_ADMIN) + result = await proxy_server.get_adaptive_router_state(user_api_key_dict=admin) + names = sorted(s["router_name"] for s in result["routers"]) + assert names == ["r1", "r2"] diff --git a/uv.lock b/uv.lock index c403884a04..3accbc0303 100644 --- a/uv.lock +++ b/uv.lock @@ -10,7 +10,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-04-13T16:35:18.496811Z" +exclude-newer = "2026-04-15T20:11:16.497522Z" exclude-newer-span = "P3D" [manifest] @@ -3767,7 +3767,7 @@ wheels = [ [[package]] name = "litellm" -version = "1.83.8" +version = "1.83.9" source = { editable = "." } dependencies = [ { name = "aiohttp" }, @@ -4114,7 +4114,7 @@ source = { editable = "enterprise" } [[package]] name = "litellm-proxy-extras" -version = "0.4.65" +version = "0.4.66" source = { editable = "litellm-proxy-extras" } [[package]]