feat: add adaptive routing to litellm

allow model routing to improve based on conversation signals

ensures router is picking best model for task
This commit is contained in:
Krrish Dholakia 2026-04-18 16:35:17 -07:00
parent 850fe595ac
commit dd4a1d2be2
76 changed files with 4542 additions and 42 deletions

View File

@ -0,0 +1,39 @@
-- One row per (router, request_type, model). Hot path on every routing decision.
-- Each row is one bandit cell: alpha/beta are the parameters of the Beta
-- posterior for that (router, request_type, model) combination.
CREATE TABLE "LiteLLM_AdaptiveRouterState" (
router_name TEXT NOT NULL,
request_type TEXT NOT NULL,
model_name TEXT NOT NULL,
-- Beta posterior parameters; the queue flusher folds aggregated deltas into these.
alpha DOUBLE PRECISION NOT NULL,
beta DOUBLE PRECISION NOT NULL,
-- Number of deltas that have been folded into this cell.
total_samples INTEGER NOT NULL DEFAULT 0,
-- Only defaulted at insert time; writers must set it explicitly on update.
last_updated_at TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
-- Composite key mirrors the Prisma model's @@id([router_name, request_type, model_name]).
PRIMARY KEY (router_name, request_type, model_name)
);
-- One row per (session, router, model). Updated per turn via the queue.
--
-- Nullability mirrors the Prisma model: every field declared without `?` in
-- schema.prisma is NOT NULL here, so the database can never hold a NULL that the
-- generated client would refuse to deserialize. Only last_user_content,
-- last_assistant_content and terminal_status are optional.
CREATE TABLE "LiteLLM_AdaptiveRouterSession" (
    session_id TEXT NOT NULL,
    router_name TEXT NOT NULL,
    model_name TEXT NOT NULL,
    classified_type TEXT NOT NULL,
    -- Per-signal counters accumulated over the session.
    misalignment_count INTEGER NOT NULL DEFAULT 0,
    stagnation_count INTEGER NOT NULL DEFAULT 0,
    disengagement_count INTEGER NOT NULL DEFAULT 0,
    satisfaction_count INTEGER NOT NULL DEFAULT 0,
    failure_count INTEGER NOT NULL DEFAULT 0,
    loop_count INTEGER NOT NULL DEFAULT 0,
    exhaustion_count INTEGER NOT NULL DEFAULT 0,
    -- Bookkeeping carried between turns.
    last_user_content TEXT,
    last_assistant_content TEXT,
    tool_call_history JSONB NOT NULL DEFAULT '[]',
    pending_tool_calls JSONB NOT NULL DEFAULT '{}',
    turn_count INTEGER NOT NULL DEFAULT 0,
    last_processed_turn INTEGER NOT NULL DEFAULT -1,
    clean_credit_awarded BOOLEAN NOT NULL DEFAULT FALSE,
    terminal_status INTEGER,
    last_activity_at TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
    PRIMARY KEY (session_id, router_name, model_name)
);

-- Index name follows Prisma's default for `@@index([last_activity_at])`
-- (<Model>_<field>_idx) so the migration and schema.prisma do not drift.
CREATE INDEX "LiteLLM_AdaptiveRouterSession_last_activity_at_idx"
    ON "LiteLLM_AdaptiveRouterSession" (last_activity_at);

View File

@ -1219,3 +1219,46 @@ model LiteLLM_ClaudeCodePluginTable {
@@map("LiteLLM_ClaudeCodePluginTable")
}
// Per-(router, request_type, model) Beta posterior for the adaptive router.
// One row is one bandit cell; the composite @@id below is the cell key.
model LiteLLM_AdaptiveRouterState {
router_name String
request_type String
model_name String
// Beta posterior parameters for this cell.
alpha Float
beta Float
// Number of deltas folded into this cell so far.
total_samples Int @default(0)
// NOTE(review): only defaulted at insert (no @updatedAt), so writers must set
// this explicitly when they update a row.
last_updated_at DateTime @default(now())
@@id([router_name, request_type, model_name])
}
// Per-(session, router, model) signal counters for the adaptive router.
// Rows are written as whole-row snapshots (last write wins) by the queue flusher.
model LiteLLM_AdaptiveRouterSession {
session_id String
router_name String
model_name String
classified_type String
// One counter per detected signal type.
misalignment_count Int @default(0)
stagnation_count Int @default(0)
disengagement_count Int @default(0)
satisfaction_count Int @default(0)
failure_count Int @default(0)
loop_count Int @default(0)
exhaustion_count Int @default(0)
// Bookkeeping carried between turns of the session.
last_user_content String?
last_assistant_content String?
tool_call_history Json @default("[]")
pending_tool_calls Json @default("{}")
turn_count Int @default(0)
// Presumably -1 means "no turn processed yet" -- TODO confirm against signals.py.
last_processed_turn Int @default(-1)
clean_credit_awarded Boolean @default(false)
terminal_status Int?
last_activity_at DateTime @default(now())
@@id([session_id, router_name, model_name])
// Indexed so rows can be looked up by recency of activity.
@@index([last_activity_at])
}

View File

@ -1,32 +1,83 @@
# model_list:
# - model_name: claude-sonnet-4-6
# litellm_params: {model: anthropic/claude-sonnet-4-6}
# model_info:
# litellm_routing_preferences:
# quality_tier: 1
# keywords: [tin]
# - model_name: gpt-4o-mini
# litellm_params: {model: openai/gpt-4o-mini}
# model_info:
# litellm_routing_preferences:
# quality_tier: 1
# keywords: []
# - model_name: gpt-4o
# litellm_params: {model: openai/gpt-4o}
# model_info:
# litellm_routing_preferences:
# quality_tier: 2
# keywords: [vision, function_calling]
# - model_name: opus
# litellm_params: {model: anthropic/claude-opus-4-7}
# model_info:
# litellm_routing_preferences:
# quality_tier: 3
# keywords: ["architecture", "design"]
# - model_name: my-quality-router
# litellm_params:
# model: auto_router/adaptive_router
# adaptive_router_default_model: gpt-4o-mini
# adaptive_router_config:
# available_models: [gpt-4o-mini, gpt-4o, opus, claude-sonnet-4-6]
# Example proxy config for the adaptive router (v0).
#
# Wires one logical router ("smart-cheap-router") that adaptively picks between
# two real deployments ("fast" and "smart") based on per-session feedback signals.
#
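# How to use from a client:
# POST /v1/chat/completions { "model": "smart-cheap-router", ... }
# Add { "metadata": { "litellm_session_id": "<your-session-id>" } } so that all
# turns of one conversation are attributed to the same session when feedback
# signals are recorded. (Model selection itself is re-sampled on every request.)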
#
# Required env vars: OPENAI_API_KEY, DATABASE_URL.
model_list:
# OpenAI model for /v1/chat/completions test — 200x custom pricing
- model_name: "gpt-4.1-mini"
# ---- The adaptive router "control" deployment -------------------------
# `model_name` is what clients call. `available_models` lists the underlying
# deployments the router is allowed to pick from (must match other model_name
# entries in this list).
- model_name: smart-cheap-router
litellm_params:
model: openai/gpt-4.1-mini
api_key: os.environ/OPENAI_API_KEY
model_info:
id: gpt-4.1-mini-custom-pricing
input_cost_per_token: 0.00004 # 100x standard ($0.40/1M = $0.0000004)
output_cost_per_token: 0.00016 # 100x standard ($1.60/1M = $0.0000016)
model: auto_router/adaptive_router
adaptive_router_config:
available_models: ["fast", "smart"]
weights:
quality: 0.7
cost: 0.3
# OpenAI model for /v1/responses test — 100x custom pricing
- model_name: "gpt-5"
# ---- Underlying deployments the router picks from ---------------------
- model_name: fast
litellm_params:
model: openai/gpt-5
api_key: os.environ/OPENAI_API_KEY
model_info:
id: gpt-5-custom-pricing
mode: "chat"
input_cost_per_token: 125 # 100x standard ($1.25/1M = $0.00000125)
output_cost_per_token: 10 # 100x standard ($10.00/1M = $0.00001)
# Anthropic model for /v1/messages test — 100x custom pricing
- model_name: "claude-sonnet-4-20250514"
litellm_params:
model: anthropic/claude-sonnet-4-20250514
model: anthropic/claude-sonnet-4-6
api_key: os.environ/ANTHROPIC_API_KEY
input_cost_per_token: 0.00000015
model_info:
id: claude-sonnet-4-custom-pricing
input_cost_per_token: 0.0003 # 100x standard ($0.000003)
output_cost_per_token: 0.0015 # 100x standard ($0.000015)
adaptive_router_preferences:
quality_tier: 2
strengths: []
- model_name: smart
litellm_params:
model: anthropic/claude-opus-4-7
api_key: os.environ/ANTHROPIC_API_KEY
input_cost_per_token: 0.0000050
model_info:
adaptive_router_preferences:
quality_tier: 3
strengths: ["code_generation", "technical_design", "analytical_reasoning"]
litellm_settings:
drop_params: True
general_settings:
master_key: sk-1234 # REPLACE in production

View File

@ -0,0 +1,216 @@
"""
In-memory queues for adaptive router state and session updates.
Pattern follows DailySpendUpdateQueue: hot path is fully in-memory; a background
flusher task drains the aggregator and writes batches to Postgres.
Two logical queues (one class):
1. STATE updates: increments to (router, request_type, model) bandit cell.
Aggregator key = (router_name, request_type, model_name)
Aggregated payload = {"delta_alpha": float, "delta_beta": float, "samples_added": int}
2. SESSION updates: full snapshot of a session row (last-write-wins per session+router+model).
Aggregator key = (session_id, router_name, model_name)
Aggregated payload = the full session state dict.
Hot-path API is async but performs no I/O: each call only takes an in-process
lock and updates the in-memory aggregator. Flushing to the database is async
and batched.
"""
from __future__ import annotations

import asyncio
from datetime import datetime, timezone
from typing import Any, Dict, Tuple

from litellm._logging import verbose_proxy_logger
StateKey = Tuple[str, str, str]  # (router_name, request_type, model_name)
SessionKey = Tuple[str, str, str]  # (session_id, router_name, model_name)


class AdaptiveRouterUpdateQueue:
    """
    In-memory aggregator for adaptive-router writes, drained by a background flusher.

    Two independent aggregators live behind one lock:

    * state deltas, keyed by (router_name, request_type, model_name); deltas
      to the same cell are summed until the next flush.
    * session snapshots, keyed by (session_id, router_name, model_name); the
      latest snapshot for a key replaces any earlier one.

    Held by the AdaptiveRouter strategy instance and started by the proxy on boot.
    """

    def __init__(self) -> None:
        self._state_agg: Dict[StateKey, Dict[str, float]] = {}
        self._session_agg: Dict[SessionKey, Dict[str, Any]] = {}
        self._lock = asyncio.Lock()
        # High-water marks, exposed through queue_size() for observability.
        self._max_state_size_seen = 0
        self._max_session_size_seen = 0

    # ---- Hot-path: state delta -------------------------------------------
    def _merge_state_delta(
        self,
        key: StateKey,
        delta_alpha: float,
        delta_beta: float,
        samples_added: int,
    ) -> None:
        """Sum a delta into the state aggregator. The caller must hold ``self._lock``."""
        current = self._state_agg.get(key)
        if current is None:
            self._state_agg[key] = {
                "delta_alpha": delta_alpha,
                "delta_beta": delta_beta,
                "samples_added": samples_added,
            }
        else:
            current["delta_alpha"] += delta_alpha
            current["delta_beta"] += delta_beta
            current["samples_added"] += samples_added
        if len(self._state_agg) > self._max_state_size_seen:
            self._max_state_size_seen = len(self._state_agg)

    async def add_state_delta(
        self,
        router_name: str,
        request_type: str,
        model_name: str,
        delta_alpha: float,
        delta_beta: float,
    ) -> None:
        """Aggregate a bandit-cell delta. Multiple deltas to the same cell sum."""
        key: StateKey = (router_name, request_type, model_name)
        async with self._lock:
            self._merge_state_delta(key, delta_alpha, delta_beta, 1)

    # ---- Hot-path: session snapshot --------------------------------------
    async def add_session_state(
        self,
        session_id: str,
        router_name: str,
        model_name: str,
        state_dict: Dict[str, Any],
    ) -> None:
        """
        Queue a session snapshot; last write wins per (session, router, model).

        The flusher upserts the snapshot into LiteLLM_AdaptiveRouterSession.
        A shallow copy is stored so that the caller re-assigning keys on its
        own dict after enqueueing cannot change what gets flushed.
        """
        key: SessionKey = (session_id, router_name, model_name)
        async with self._lock:
            self._session_agg[key] = dict(state_dict)
            if len(self._session_agg) > self._max_session_size_seen:
                self._max_session_size_seen = len(self._session_agg)

    # ---- Flushers (called by background task) ----------------------------
    async def flush_state_to_db(self, prisma_client: Any) -> int:
        """
        Drain the state aggregator and apply it to LiteLLM_AdaptiveRouterState.

        Existing rows are updated with atomic ``increment`` operations, so the
        addition happens inside the database. Reading the row first and writing
        back absolute values would let concurrent writers (other proxy workers)
        overwrite each other's updates.

        Cells whose write fails are merged back into the aggregator and retried
        on the next flush instead of being dropped.

        Returns the number of cells successfully written.
        """
        async with self._lock:
            batch, self._state_agg = self._state_agg, {}
        if not batch:
            return 0

        flushed = 0
        failed: Dict[StateKey, Dict[str, float]] = {}
        # Sort keys to give deterministic write order across writers and
        # reduce the chance of cross-row deadlocks when other workers race us.
        for key in sorted(batch):
            router, rt, model = key
            payload = batch[key]
            delta_alpha = payload["delta_alpha"]
            delta_beta = payload["delta_beta"]
            samples_added = int(payload["samples_added"])
            try:
                await prisma_client.db.litellm_adaptiverouterstate.upsert(
                    where={
                        "router_name_request_type_model_name": {
                            "router_name": router,
                            "request_type": rt,
                            "model_name": model,
                        }
                    },
                    data={
                        # A brand-new row starts from zero, so the deltas are
                        # its initial values.
                        "create": {
                            "router_name": router,
                            "request_type": rt,
                            "model_name": model,
                            "alpha": delta_alpha,
                            "beta": delta_beta,
                            "total_samples": samples_added,
                        },
                        "update": {
                            "alpha": {"increment": delta_alpha},
                            "beta": {"increment": delta_beta},
                            "total_samples": {"increment": samples_added},
                            # The column is only defaulted at insert time, so
                            # refresh it explicitly on every update.
                            "last_updated_at": datetime.now(timezone.utc),
                        },
                    },
                )
                flushed += 1
            except Exception as e:
                verbose_proxy_logger.exception(
                    "AdaptiveRouterUpdateQueue: failed to flush state for %s: %s",
                    key,
                    e,
                )
                failed[key] = payload

        if failed:
            # Fold failed deltas back in, summing with anything queued since
            # the batch was taken.
            async with self._lock:
                for key, payload in failed.items():
                    self._merge_state_delta(
                        key,
                        payload["delta_alpha"],
                        payload["delta_beta"],
                        int(payload["samples_added"]),
                    )
        return flushed

    async def flush_session_to_db(self, prisma_client: Any) -> int:
        """
        Drain the session aggregator and upsert into LiteLLM_AdaptiveRouterSession.

        A failed row is logged and dropped; a later snapshot for the same
        session replaces it anyway.

        Returns the number of session rows successfully written.
        """
        async with self._lock:
            batch, self._session_agg = self._session_agg, {}
        if not batch:
            return 0

        flushed = 0
        for key in sorted(batch):
            session_id, router, model = key
            payload = batch[key]
            try:
                # NOTE: Prisma client lower-cases model names, so
                # `LiteLLM_AdaptiveRouterSession` -> `litellm_adaptiveroutersession`
                # (single 's', not 'litellm_adaptiverouterssession').
                await prisma_client.db.litellm_adaptiveroutersession.upsert(
                    where={
                        "session_id_router_name_model_name": {
                            "session_id": session_id,
                            "router_name": router,
                            "model_name": model,
                        }
                    },
                    data={
                        "create": {
                            "session_id": session_id,
                            "router_name": router,
                            "model_name": model,
                            **payload,
                        },
                        "update": payload,
                    },
                )
                flushed += 1
            except Exception as e:
                verbose_proxy_logger.exception(
                    "AdaptiveRouterUpdateQueue: failed to flush session for %s: %s",
                    key,
                    e,
                )
        return flushed

    # ---- Observability ---------------------------------------------------
    async def queue_size(self) -> Dict[str, int]:
        """Return pending entry counts and high-water marks for both aggregators."""
        async with self._lock:
            return {
                "state_pending": len(self._state_agg),
                "session_pending": len(self._session_agg),
                "max_state_seen": self._max_state_size_seen,
                "max_session_seen": self._max_session_size_seen,
            }

View File

@ -0,0 +1,52 @@
# Example proxy config for the adaptive router (v0).
#
# Wires one logical router ("smart-cheap-router") that adaptively picks between
# two real deployments ("fast" and "smart") based on per-session feedback signals.
#
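# How to use from a client:
# POST /v1/chat/completions { "model": "smart-cheap-router", ... }
# Add { "metadata": { "litellm_session_id": "<your-session-id>" } } so that all
# turns of one conversation are attributed to the same session when feedback
# signals are recorded. (Model selection itself is re-sampled on every request.)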
#
# Required env vars: OPENAI_API_KEY, DATABASE_URL.
model_list:
  # ---- The adaptive router "control" deployment -------------------------
  # `model_name` is what clients call. `available_models` lists the underlying
  # deployments the router is allowed to pick from (must match other model_name
  # entries in this list).
  - model_name: smart-cheap-router
    litellm_params:
      # The `auto_router/adaptive_router` prefix is what opts this deployment in
      # to adaptive routing; with any other value no AdaptiveRouter is built and
      # the entry behaves like an ordinary deployment. This model is never
      # called itself -- the router picks from `available_models`.
      model: auto_router/adaptive_router
      adaptive_router_config:
        available_models: ["fast", "smart"]
        weights:
          quality: 0.7
          cost: 0.3

  # ---- Underlying deployments the router picks from ---------------------
  # `input_cost_per_token` is read from `litellm_params`; the quality tier and
  # strengths are read from `model_info.adaptive_router_preferences`.
  - model_name: fast
    litellm_params:
      model: openai/gpt-4o-mini
      api_key: os.environ/OPENAI_API_KEY
      input_cost_per_token: 0.00000015
    model_info:
      adaptive_router_preferences:
        quality_tier: 2
        strengths: []

  - model_name: smart
    litellm_params:
      model: openai/gpt-4o
      api_key: os.environ/OPENAI_API_KEY
      input_cost_per_token: 0.0000050
    model_info:
      adaptive_router_preferences:
        quality_tier: 3
        strengths: ["code_generation", "technical_design", "analytical_reasoning"]

litellm_settings:
  drop_params: True

general_settings:
  master_key: sk-1234 # REPLACE in production

View File

@ -952,6 +952,10 @@ async def proxy_startup_event(app: FastAPI): # noqa: PLR0915
_run_background_health_check()
) # start the background health check coroutine.
# Start adaptive-router queue flusher if any AdaptiveRouter is configured.
if llm_router is not None and getattr(llm_router, "adaptive_routers", None):
asyncio.create_task(_adaptive_router_flusher_loop())
## [Optional] Initialize dd tracer
ProxyStartupEvent._init_dd_tracer()
@ -2201,9 +2205,11 @@ def run_ollama_serve():
with open(os.devnull, "w") as devnull:
subprocess.Popen(command, stdout=devnull, stderr=devnull)
except Exception as e:
verbose_proxy_logger.debug(f"""
verbose_proxy_logger.debug(
f"""
LiteLLM Warning: proxy started with `ollama` model\n`ollama serve` failed with Exception{e}. \nEnsure you run `ollama serve`
""")
"""
)
def _get_process_rss_mb() -> Optional[float]:
@ -2385,6 +2391,31 @@ def _write_health_state_to_router_cache(
)
_ADAPTIVE_ROUTER_FLUSH_INTERVAL_SECONDS = 10


async def _adaptive_router_flusher_loop():
    """
    Background task: periodically persist adaptive-router updates.

    Every `_ADAPTIVE_ROUTER_FLUSH_INTERVAL_SECONDS` seconds, each configured
    AdaptiveRouter's queue has its state aggregator and then its session
    aggregator flushed to the database. An iteration is a no-op while no
    adaptive router or no Prisma client is available. Cancellation propagates;
    any other error is logged and the loop keeps running.
    """
    global llm_router, prisma_client

    while True:
        try:
            # Sleep first so the initial flush happens one interval after startup.
            await asyncio.sleep(_ADAPTIVE_ROUTER_FLUSH_INTERVAL_SECONDS)

            routers = getattr(llm_router, "adaptive_routers", None)
            if routers and prisma_client is not None:
                for adaptive_router in routers.values():
                    pending = adaptive_router.queue
                    await pending.flush_state_to_db(prisma_client)
                    await pending.flush_session_to_db(prisma_client)
        except asyncio.CancelledError:
            # Let task cancellation (e.g. shutdown) propagate.
            raise
        except Exception:
            verbose_proxy_logger.exception("adaptive_router flusher iteration failed")
async def _run_background_health_check():
"""
Periodically run health checks in the background on the endpoints.
@ -13877,6 +13908,38 @@ async def home(request: Request):
return "LiteLLM: RUNNING"
@router.get(
    "/adaptive_router/state",
    tags=["adaptive_router"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_adaptive_router_state(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """Return live bandit posteriors + queue depth for every configured adaptive router.
    Admin-only. Returns 404 if no adaptive router is configured.
    Response shape: `{"routers": [<snapshot>, ...]}` one snapshot per
    adaptive-router deployment. Each snapshot's `router_name` field identifies
    which deployment it came from.
    """
    # Guard 1: only proxy admins may inspect router internals.
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        forbidden = {"error": CommonProxyErrors.not_allowed_access.value}
        raise HTTPException(status_code=403, detail=forbidden)

    # Guard 2: nothing to report when no router (or no adaptive router) exists.
    configured = llm_router.adaptive_routers if llm_router is not None else None
    if not configured:
        not_found = {"error": "No adaptive_router is configured on this proxy."}
        raise HTTPException(status_code=404, detail=not_found)

    # Collect one snapshot per adaptive router, in registration order.
    snapshots = []
    for adaptive_router in configured.values():
        snapshots.append(await adaptive_router.get_state_snapshot())
    return {"routers": snapshots}
@router.get("/routes", dependencies=[Depends(user_api_key_auth)])
async def get_routes():
"""

View File

@ -1219,3 +1219,46 @@ model LiteLLM_ClaudeCodePluginTable {
@@map("LiteLLM_ClaudeCodePluginTable")
}
// Per-(router, request_type, model) Beta posterior for the adaptive router.
// One row is one bandit cell; the composite @@id below is the cell key.
model LiteLLM_AdaptiveRouterState {
router_name String
request_type String
model_name String
// Beta posterior parameters for this cell.
alpha Float
beta Float
// Number of deltas folded into this cell so far.
total_samples Int @default(0)
// NOTE(review): only defaulted at insert (no @updatedAt), so writers must set
// this explicitly when they update a row.
last_updated_at DateTime @default(now())
@@id([router_name, request_type, model_name])
}
// Per-(session, router, model) signal counters for the adaptive router.
// Rows are written as whole-row snapshots (last write wins) by the queue flusher.
model LiteLLM_AdaptiveRouterSession {
session_id String
router_name String
model_name String
classified_type String
// One counter per detected signal type.
misalignment_count Int @default(0)
stagnation_count Int @default(0)
disengagement_count Int @default(0)
satisfaction_count Int @default(0)
failure_count Int @default(0)
loop_count Int @default(0)
exhaustion_count Int @default(0)
// Bookkeeping carried between turns of the session.
last_user_content String?
last_assistant_content String?
tool_call_history Json @default("[]")
pending_tool_calls Json @default("{}")
turn_count Int @default(0)
// Presumably -1 means "no turn processed yet" -- TODO confirm against signals.py.
last_processed_turn Int @default(-1)
clean_credit_awarded Boolean @default(false)
terminal_status Int?
last_activity_at DateTime @default(now())
@@id([session_id, router_name, model_name])
// Indexed so rows can be looked up by recency of activity.
@@index([last_activity_at])
}

View File

@ -200,12 +200,16 @@ if TYPE_CHECKING:
from litellm.router_strategy.complexity_router.complexity_router import (
ComplexityRouter,
)
from litellm.router_strategy.adaptive_router.adaptive_router import (
AdaptiveRouter,
)
Span = Union[_Span, Any]
else:
Span = Any
AutoRouter = Any
ComplexityRouter = Any
AdaptiveRouter = Any
PreRoutingHookResponse = Any
@ -464,6 +468,7 @@ class Router:
) # {"TEAM_ID": PatternMatchRouter}
self.auto_routers: Dict[str, "AutoRouter"] = {}
self.complexity_routers: Dict[str, "ComplexityRouter"] = {}
self.adaptive_routers: Dict[str, "AdaptiveRouter"] = {}
# Initialize model_group_alias early since it's used in set_model_list
self.model_group_alias: Dict[str, Union[str, RouterModelGroupAliasItem]] = (
@ -3864,7 +3869,7 @@ class Router:
self._add_deployment_model_to_endpoint_for_llm_passthrough_route(
kwargs=kwargs, model=model, model_name=model_name
)
# Get custom_llm_provider from deployment params
try:
custom_llm_provider = data.get("custom_llm_provider")
@ -3872,10 +3877,12 @@ class Router:
model=data["model"],
custom_llm_provider=custom_llm_provider,
)
custom_llm_provider = custom_llm_provider or inferred_custom_llm_provider
custom_llm_provider = (
custom_llm_provider or inferred_custom_llm_provider
)
except Exception:
custom_llm_provider = None
# Build response kwargs
response_kwargs = {
**data,
@ -3885,7 +3892,7 @@ class Router:
# Only set custom_llm_provider if it's not None
if custom_llm_provider is not None:
response_kwargs["custom_llm_provider"] = custom_llm_provider
response = original_generic_function(**response_kwargs)
rpm_semaphore = self._get_client(
@ -3981,7 +3988,9 @@ class Router:
model=data["model"],
custom_llm_provider=custom_llm_provider,
)
custom_llm_provider = custom_llm_provider or inferred_custom_llm_provider
custom_llm_provider = (
custom_llm_provider or inferred_custom_llm_provider
)
except Exception:
custom_llm_provider = None
@ -4246,7 +4255,9 @@ class Router:
custom_llm_provider=custom_llm_provider,
)
# Preserve explicitly stored provider, fallback to inferred
custom_llm_provider = custom_llm_provider or inferred_custom_llm_provider
custom_llm_provider = (
custom_llm_provider or inferred_custom_llm_provider
)
## REPLACE MODEL IN FILE WITH SELECTED DEPLOYMENT ##
purpose = cast(Optional[OpenAIFilesPurpose], kwargs.get("purpose"))
@ -5355,9 +5366,9 @@ class Router:
e,
(litellm.ContextWindowExceededError, litellm.ContentPolicyViolationError),
)
_request_team_id: Optional[str] = (
kwargs.get("metadata", {}) or {}
).get("user_api_key_team_id")
_request_team_id: Optional[str] = (kwargs.get("metadata", {}) or {}).get(
"user_api_key_team_id"
)
all_deployments = self._get_all_deployments(
model_name=original_model_group, team_id=_request_team_id
)
@ -6804,10 +6815,13 @@ class Router:
Check if the deployment is an auto-router deployment (semantic router).
Returns True if the litellm_params model starts with "auto_router/"
but NOT "auto_router/complexity_router" (which uses complexity routing).
but NOT "auto_router/complexity_router" or "auto_router/adaptive_router"
(which use the complexity-router and adaptive-router strategies).
"""
if litellm_params.model.startswith("auto_router/complexity_router"):
return False # This is handled by complexity_router
if litellm_params.model.startswith("auto_router/adaptive_router"):
return False # This is handled by adaptive_router
if litellm_params.model.startswith("auto_router/"):
return True
return False
@ -6914,6 +6928,121 @@ class Router:
)
self.complexity_routers[deployment.model_name] = complexity_router
def _is_adaptive_router_deployment(self, litellm_params: LiteLLM_Params) -> bool:
    """Report whether the deployment's model string carries the
    `auto_router/adaptive_router` prefix, i.e. opts in to adaptive routing."""
    adaptive_prefix = "auto_router/adaptive_router"
    configured_model = litellm_params.model
    return configured_model.startswith(adaptive_prefix)
def _finalize_adaptive_router_if_configured(self) -> None:
    """Build an AdaptiveRouter for each adaptive-router entry in the model list.

    Walks the finalized `self.model_list` (entries may be dicts or objects),
    and for every entry whose model starts with `auto_router/adaptive_router`
    constructs a `Deployment` and hands it to `init_adaptive_router_deployment`.
    Does nothing when no such entry exists, and skips entries whose
    `model_name` already has an adaptive router, so repeated calls are safe.
    """
    adaptive_prefix = "auto_router/adaptive_router"

    for entry in self.model_list or []:
        entry_is_dict = isinstance(entry, dict)

        params = entry.get("litellm_params") if entry_is_dict else entry.litellm_params
        if not params:
            continue

        target_model = (
            params.get("model") if isinstance(params, dict) else params.model
        )
        if not target_model or not target_model.startswith(adaptive_prefix):
            continue

        router_name = entry.get("model_name") if entry_is_dict else entry.model_name
        # Skip unnamed entries and routers that were already initialized.
        if not router_name or router_name in self.adaptive_routers:
            continue

        if isinstance(params, dict):
            params = LiteLLM_Params(**params)
        info = entry.get("model_info") if entry_is_dict else entry.model_info

        self.init_adaptive_router_deployment(
            deployment=Deployment(
                model_name=router_name,
                litellm_params=params,
                model_info=info,
            )
        )
def init_adaptive_router_deployment(self, deployment: Deployment) -> None:
    """
    Build an AdaptiveRouter instance for this deployment and register its
    post-call hook. Multiple adaptive routers can coexist on a single Router,
    keyed by `deployment.model_name`.

    `model_to_prefs` and `model_to_cost` are derived from the OTHER models
    already registered in `self.model_list` whose `model_name` appears in
    `available_models`. Models without declared preferences or cost are simply
    absent from those maps. Names in `available_models` that match no
    registered deployment are logged as a warning, since they usually indicate
    a typo in the config.

    Raises:
        ValueError: if an adaptive router with this `model_name` already
            exists, or if `adaptive_router_config` is missing.
    """
    # Local import: AdaptiveRouter -> hooks -> classifier all import litellm
    # internals which transitively import this module. (AGENTS.md exception clause.)
    from litellm.router_strategy.adaptive_router.adaptive_router import (
        AdaptiveRouter,
    )
    from litellm.router_strategy.adaptive_router.hooks import (
        AdaptiveRouterPostCallHook,
    )
    from litellm.types.router import (
        AdaptiveRouterConfig,
        AdaptiveRouterPreferences,
    )

    router_name = deployment.model_name

    # Fail fast on duplicates, before doing any per-model work.
    if router_name in self.adaptive_routers:
        raise ValueError(
            f"Adaptive-router deployment {router_name} already exists. "
            "Please use a different model name."
        )

    raw_config = deployment.litellm_params.adaptive_router_config
    if raw_config is None:
        raise ValueError(
            "adaptive_router_config is required for adaptive-router deployments."
        )
    config = AdaptiveRouterConfig(**raw_config)

    # Set membership: the check below runs once per deployment in model_list.
    available = set(config.available_models)
    matched_names = set()

    model_to_prefs: Dict[str, AdaptiveRouterPreferences] = {}
    model_to_cost: Dict[str, float] = {}
    for d in self.model_list or []:
        name = d.get("model_name") if isinstance(d, dict) else d.model_name
        if name not in available:
            continue
        matched_names.add(name)

        mi = d.get("model_info") if isinstance(d, dict) else d.model_info
        mi_dict: Dict[str, Any] = (
            mi if isinstance(mi, dict) else (mi.model_dump() if mi else {})
        )
        prefs_raw = mi_dict.get("adaptive_router_preferences")
        if prefs_raw is not None:
            model_to_prefs[name] = AdaptiveRouterPreferences(**prefs_raw)

        # `input_cost_per_token` is a LiteLLM_Params field per types/router.py.
        lp = d.get("litellm_params") if isinstance(d, dict) else d.litellm_params
        lp_dict: Dict[str, Any] = (
            lp if isinstance(lp, dict) else (lp.model_dump() if lp else {})
        )
        cost = lp_dict.get("input_cost_per_token")
        if cost is not None:
            model_to_cost[name] = float(cost)

    # Surface config typos: a name the router may pick but that no deployment
    # in model_list serves.
    unmatched = [m for m in config.available_models if m not in matched_names]
    if unmatched:
        verbose_router_logger.warning(
            "AdaptiveRouter[%s]: available_models %s do not match any "
            "model_name in model_list",
            router_name,
            unmatched,
        )

    adaptive_router = AdaptiveRouter(
        router_name=router_name,
        config=config,
        model_to_prefs=model_to_prefs,
        model_to_cost=model_to_cost,
    )
    self.adaptive_routers[router_name] = adaptive_router
    litellm.callbacks.append(
        AdaptiveRouterPostCallHook(adaptive_router=adaptive_router)
    )
    verbose_router_logger.info(
        "AdaptiveRouter[%s] initialized with %d models",
        router_name,
        len(config.available_models),
    )
def deployment_is_active_for_environment(self, deployment: Deployment) -> bool:
"""
Function to check if a llm deployment is active for a given environment. Allows using the same config.yaml across multople environments
@ -7007,6 +7136,10 @@ class Router:
# Note: model_name_to_deployment_indices is already built incrementally
# by _create_deployment -> _add_model_to_list_and_index_map
# Deferred: build the AdaptiveRouter strategy now that all underlying
# deployments are visible in self.model_list.
self._finalize_adaptive_router_if_configured()
def _add_deployment(self, deployment: Deployment) -> Deployment:
import os
@ -7134,6 +7267,11 @@ class Router:
):
self.init_complexity_router_deployment(deployment=deployment)
# NOTE: adaptive-router deployments are deferred to the end of
# set_model_list() because their init needs visibility into the OTHER
# deployments listed in `available_models` (which may not yet have
# been processed when this one is created).
return deployment
def _initialize_deployment_for_pass_through(
@ -9645,6 +9783,19 @@ class Router:
specific_deployment=specific_deployment,
)
#########################################################
# Check if an adaptive-router should be used
#########################################################
adaptive_router = self.adaptive_routers.get(model)
if adaptive_router is not None:
return await adaptive_router.async_pre_routing_hook(
model=model,
request_kwargs=request_kwargs,
messages=messages,
input=input,
specific_deployment=specific_deployment,
)
return None
def get_available_deployment(

View File

@ -0,0 +1,93 @@
# Adaptive Router (v0)
A request-type-aware routing strategy. For each incoming request, classify the
prompt into one of seven `RequestType` buckets (code generation, writing,
analytical reasoning, …), then Thompson-sample a Beta(α, β) bandit posterior
per `(request_type, model)` cell to pick the best model. Quality estimates are
combined with a normalized cost score via a weighted linear sum.
A post-call hook reads the response and runs lightweight regex + tool-call
detectors (see `signals.py`) to award per-turn credit/blame to the model that
served the turn. Updates are batched in-memory and flushed to Postgres every
~10s by a background task in `proxy_server.py`.
## Config example
```yaml
model_list:
  - model_name: gpt-4o
    litellm_params:
      model: openai/gpt-4o
      # Cost is read from `litellm_params`, not `model_info`.
      input_cost_per_token: 0.0000025
    model_info:
      adaptive_router_preferences:
        quality_tier: 3
        strengths: ["code_generation", "analytical_reasoning"]
  - model_name: gpt-4o-mini
    litellm_params:
      model: openai/gpt-4o-mini
      input_cost_per_token: 0.00000015
    model_info:
      adaptive_router_preferences:
        quality_tier: 2
        strengths: ["general", "factual_lookup"]
  - model_name: smart-router
    litellm_params:
      # The `auto_router/adaptive_router` prefix enables the adaptive router.
      model: auto_router/adaptive_router
      adaptive_router_default_model: gpt-4o-mini
      adaptive_router_config:
        available_models: ["gpt-4o", "gpt-4o-mini"]
        weights:
          quality: 0.7
          cost: 0.3
```
Callers may pass header `x-litellm-min-quality-tier: 3` (or metadata key
`min_quality_tier: 3`) to force selection from tier-3-or-higher models only.
## Behavior summary
- **Cold start.** Each `(request_type, model)` cell starts with a
Beta prior whose mean = `BASE_TIER_WEIGHT[tier] (+ STRENGTH_BONUS if declared)`
and total mass = `COLD_START_MASS` (10). About ten real observations move it
meaningfully.
- **Per-request decision.** Sample once per eligible model, score with
`quality_weight·sample + cost_weight·normalized_cost`, pick the argmax.
Routing is stateless per-turn — no sticky lookup. Each call resamples.
- **Owner-cache attribution.** Post-call, the conversation's first picked
model claims an "owner slot" for `OWNER_CACHE_TTL_SECONDS` (24h). Later
turns of the same conversation only fire bandit/state updates if the
same model handled them — mismatches are dropped (no attribution) and
counted in `skipped_updates_total`. Conversation identity is the
client-supplied `litellm_session_id` if present, otherwise a sha256 over
caller identity (api key hash, team, user, end-user) + the first message.
- **Per-turn updates.** `satisfaction → +α`. `misalignment, stagnation,
disengagement, failure → +β` (each). `loop → +0.5β`. `exhaustion → 0`
(uptime, not quality). Skipped if conversation has fewer than
`SIGNAL_GATE_MIN_MESSAGES` messages.
- **Persistence.** Bandit cells: aggregated deltas, eventually consistent.
Session rows: last-write-wins snapshots.
## Known v0 limitations
- **Latency is not in the score.** Quality + cost only. A pathologically slow
model can still be picked.
- **Hard sample cap at 200.** Once `α + β > 200`, deltas are silently dropped.
No rescaling — drift is a v1 concern.
- **24h owner-cache TTL.** No explicit eviction below TTL. The in-memory map
can grow if traffic patterns produce many one-shot sessions.
- **Owner-recovery skew.** If model A "owns" a conversation but is then
dethroned in the bandit, later turns served by model B are dropped — so
bandit updates for that conversation flatline until A's TTL expires.
Tracked via `skipped_updates_total`.
- **Signals are regex + tool-call only.** No LLM-judge, no embedding similarity,
no exemplar storage. Signals are best-effort and biased toward English.
- **One AdaptiveRouter per deployment name.** Multiple `auto_router/adaptive_router`
deployments can coexist on the same `litellm.Router`, keyed by `model_name`;
registering the same `model_name` twice raises at init.
- **Bandit-delta mapping is unvalidated.** `_compute_bandit_delta` is a v0
guess; expect to retune after the first ~1000 sessions of real traffic.
- **`request_type` is classified per turn from the latest user message only.**
The first turn's classification doesn't carry forward; a multi-turn session
may shift bucket between turns.

View File

@ -0,0 +1,6 @@
"""Adaptive router strategy. See README.md for design overview."""
from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter
from litellm.router_strategy.adaptive_router.hooks import AdaptiveRouterPostCallHook
__all__ = ["AdaptiveRouter", "AdaptiveRouterPostCallHook"]

View File

@ -0,0 +1,344 @@
"""
Main adaptive router strategy. See README.md for design overview.
One AdaptiveRouter instance per router_name. Holds in-memory caches:
- _cells: Beta(alpha, beta) bandit posteriors per (request_type, model)
- _owner_cache: session_key -> (owner_model, expires_at); the first model
picked for a conversation owns its bandit-update slot
- _session_states: (session_key, model) -> SessionState for incremental signal updates
Owns the AdaptiveRouterUpdateQueue used by the proxy's flusher to persist
state and session snapshots back to Postgres.
Routing is stateless per-turn (Thompson sample fresh on every call). The
owner cache is consulted only at post-call time to decide whether a turn's
signals should fire a bandit update; turns served by a different model than
the conversation's owner are skipped to avoid cross-model misattribution.
"""
from __future__ import annotations
import asyncio
import time
from dataclasses import asdict
from typing import Any, Dict, List, Optional, Tuple, Union, cast
from litellm._logging import verbose_router_logger
from litellm.litellm_core_utils.prompt_templates.common_utils import (
get_last_user_message,
)
from litellm.proxy.db.db_transaction_queue.adaptive_router_update_queue import (
AdaptiveRouterUpdateQueue,
)
from litellm.router_strategy.adaptive_router.bandit import (
BanditCell,
apply_delta,
initial_cell,
pick_best,
)
from litellm.router_strategy.adaptive_router.classifier import classify_prompt
from litellm.router_strategy.adaptive_router.config import (
ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY,
OWNER_CACHE_TTL_SECONDS,
)
from litellm.router_strategy.adaptive_router.signals import (
SessionState,
SignalDelta,
Turn,
apply_turn,
)
from litellm.types.llms.openai import AllMessageValues
from litellm.types.router import (
AdaptiveRouterConfig,
AdaptiveRouterPreferences,
PreRoutingHookResponse,
RequestType,
)
def _default_prefs() -> AdaptiveRouterPreferences:
    """Build the fallback preferences used when a model declares none.

    The fallback is quality tier 2 with an empty ``strengths`` list.
    """
    fallback = AdaptiveRouterPreferences(strengths=[], quality_tier=2)
    return fallback
class AdaptiveRouter:
    """One instance per router_name. Holds in-memory caches + the update queue.

    State kept on the instance:
      - ``_cells``: Beta(alpha, beta) posterior per (request_type, model).
      - ``_owner_cache``: session_key -> (owner_model, expires_at).
      - ``_session_states``: (session_id, model) -> SessionState.
      - ``_skipped_updates_total``: count of turns rejected by
        ``claim_or_check_owner``.
      - ``queue``: AdaptiveRouterUpdateQueue that receives the session
        snapshots and bandit deltas produced by ``record_turn``.
    """

    def __init__(
        self,
        router_name: str,
        config: AdaptiveRouterConfig,
        model_to_prefs: Dict[str, AdaptiveRouterPreferences],
        model_to_cost: Dict[str, float],
    ) -> None:
        """Store configuration and seed every cell with its cold-start prior.

        Args:
            router_name: Name this router is registered under; used as the
                DB filter in ``load_state_from_db`` and in log messages.
            config: Supplies ``available_models`` and the score ``weights``.
            model_to_prefs: Per-model preferences; models missing here fall
                back to ``_default_prefs()``.
            model_to_cost: Per-model cost; models missing here are treated
                as cost 0.0 in ``pick_model``.
        """
        self.router_name = router_name
        self.config = config
        self.model_to_prefs = model_to_prefs
        self.model_to_cost = model_to_cost
        self.queue = AdaptiveRouterUpdateQueue()
        self._cells: Dict[Tuple[RequestType, str], BanditCell] = {}
        self._owner_cache: Dict[str, Tuple[str, float]] = {}
        self._session_states: Dict[Tuple[str, str], SessionState] = {}
        self._skipped_updates_total: int = 0
        self._lock = asyncio.Lock()
        self._init_cold_start_cells()

    # ---- Cold-start ------------------------------------------------------

    def _init_cold_start_cells(self) -> None:
        """Populate _cells with cold-start priors for every (rt, model) combination."""
        for rt in RequestType:
            for model in self.config.available_models:
                prefs = self.model_to_prefs.get(model) or _default_prefs()
                self._cells[(rt, model)] = initial_cell(prefs, rt)

    async def load_state_from_db(self, prisma_client: Any) -> None:
        """Override cold-start cells with persisted state. Called once at startup.

        Rows whose ``request_type`` is not a known ``RequestType`` or whose
        ``model_name`` is not in ``config.available_models`` are skipped. Any
        exception is logged and swallowed, leaving the cells loaded so far
        (and the cold-start priors for the rest) in place.
        """
        if prisma_client is None:
            return
        try:
            rows = await prisma_client.db.litellm_adaptiverouterstate.find_many(
                where={"router_name": self.router_name}
            )
            loaded = 0
            for row in rows:
                try:
                    rt = RequestType(row.request_type)
                except ValueError:
                    # Unknown taxonomy entry from an older/newer version. Skip.
                    continue
                if row.model_name not in self.config.available_models:
                    continue
                self._cells[(rt, row.model_name)] = BanditCell(
                    alpha=row.alpha, beta=row.beta
                )
                loaded += 1
            verbose_router_logger.info(
                "AdaptiveRouter[%s]: loaded %d cells from DB",
                self.router_name,
                loaded,
            )
        except Exception as e:
            verbose_router_logger.exception(
                "AdaptiveRouter[%s]: failed to load state from DB: %s",
                self.router_name,
                e,
            )

    # ---- Pre-routing hook ------------------------------------------------

    async def async_pre_routing_hook(
        self,
        model: str,
        request_kwargs: Dict[str, Any],
        messages: Optional[List[Dict[str, Any]]] = None,
        input: Optional[Union[str, List]] = None,
        specific_deployment: Optional[bool] = False,
    ) -> Optional[PreRoutingHookResponse]:
        """
        Plugin entry point invoked by `Router.async_pre_routing_hook` when the
        inbound `model` matches this adaptive router's `router_name`.

        Classifies the last user message, picks a logical model via the bandit,
        and stashes the chosen model on `request_kwargs["metadata"]` so the
        post-call hook can surface it as a response header.

        Routing is stateless per-turn: every call Thompson-samples fresh,
        regardless of any prior pick for the same session. Cross-turn
        attribution is enforced post-call via the owner cache (see
        `claim_or_check_owner`).

        Returns:
            A ``PreRoutingHookResponse`` carrying the chosen model and the
            unmodified ``messages``.
        """
        user_text = (
            get_last_user_message(cast(List[AllMessageValues], messages or [])) or ""
        )
        request_type = classify_prompt(user_text)
        chosen_model = await self.pick_model(request_type=request_type)
        verbose_router_logger.debug(
            "AdaptiveRouter[%s]: classified=%s -> chose %s",
            self.router_name,
            request_type.value,
            chosen_model,
        )
        # Relay the chosen logical model to the post-call hook, which surfaces
        # it as the `x-litellm-adaptive-router-model` response header. We use
        # `metadata` (not a top-level kwarg) so the value doesn't leak into
        # `litellm.acompletion(**input_kwargs)`.
        kwargs_metadata = request_kwargs.get("metadata")
        if kwargs_metadata is None:
            # Covers both a missing key and an explicit `metadata=None`;
            # `setdefault` alone would hand back the None and drop the relay.
            kwargs_metadata = {}
            request_kwargs["metadata"] = kwargs_metadata
        if isinstance(kwargs_metadata, dict):
            kwargs_metadata[ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY] = chosen_model
        return PreRoutingHookResponse(model=chosen_model, messages=messages)

    # ---- Pick model ------------------------------------------------------

    async def pick_model(
        self,
        request_type: RequestType,
        min_quality_tier: Optional[int] = None,
    ) -> str:
        """Thompson-sample across eligible models. Stateless per-turn.

        Raises:
            ValueError: if no model satisfies ``min_quality_tier``.
        """
        eligible = self._eligible_models(min_quality_tier)
        if not eligible:
            raise ValueError(
                f"AdaptiveRouter[{self.router_name}]: no models meet "
                f"min_quality_tier={min_quality_tier}"
            )
        cells = {m: self._cells[(request_type, m)] for m in eligible}
        # Models without a configured cost are scored as free (0.0).
        costs = {m: self.model_to_cost.get(m, 0.0) for m in eligible}
        return pick_best(
            cells,
            costs,
            quality_weight=self.config.weights.quality,
            cost_weight=self.config.weights.cost,
        )

    def claim_or_check_owner(self, session_key: str, current_model: str) -> bool:
        """Resolve attribution for a turn under stateless routing.

        Returns True iff this turn should fire a bandit/state update.

        - No live owner entry for ``session_key`` (never claimed, or the
          entry has expired): ``current_model`` claims ownership with a fresh
          TTL and the call returns True.
        - Live owner equals ``current_model``: returns True (the TTL is not
          refreshed).
        - Live owner differs from ``current_model``: increments
          ``_skipped_updates_total`` and returns False - no attribution.
        """
        now = time.time()
        existing = self._owner_cache.get(session_key)
        if existing is not None and existing[1] > now:
            owner_model, _ = existing
            if owner_model == current_model:
                return True
            self._skipped_updates_total += 1
            return False
        # No live owner -> claim for current_model.
        self._owner_cache[session_key] = (
            current_model,
            now + OWNER_CACHE_TTL_SECONDS,
        )
        return True

    async def get_state_snapshot(self) -> Dict[str, Any]:
        """In-memory snapshot for the introspection endpoint. Cheap; no DB hit."""
        cells = []
        for (rt, model), cell in sorted(
            self._cells.items(), key=lambda kv: (kv[0][0].value, kv[0][1])
        ):
            # `samples` is the raw alpha + beta mass of the cell.
            total = cell.alpha + cell.beta
            cells.append(
                {
                    "request_type": rt.value,
                    "model": model,
                    "alpha": cell.alpha,
                    "beta": cell.beta,
                    "samples": total,
                    "quality_mean": cell.alpha / total if total > 0 else 0.0,
                }
            )
        queue = await self.queue.queue_size()
        now = time.time()
        owner_cache_live = sum(1 for _, exp in self._owner_cache.values() if exp > now)
        return {
            "router_name": self.router_name,
            "available_models": list(self.config.available_models),
            "weights": {
                "quality": self.config.weights.quality,
                "cost": self.config.weights.cost,
            },
            "model_costs": dict(self.model_to_cost),
            "cells": cells,
            "owner_cache_live": owner_cache_live,
            "skipped_updates_total": self._skipped_updates_total,
            "queue": queue,
        }

    def _eligible_models(self, min_quality_tier: Optional[int]) -> List[str]:
        """Return available models whose quality tier is >= ``min_quality_tier``.

        With ``min_quality_tier=None`` every available model is eligible.
        """
        if min_quality_tier is None:
            return list(self.config.available_models)
        return [
            m
            for m in self.config.available_models
            if (self.model_to_prefs.get(m) or _default_prefs()).quality_tier
            >= min_quality_tier
        ]

    # ---- Session state ---------------------------------------------------

    def get_or_create_session_state(
        self,
        session_id: str,
        model_name: str,
        request_type: RequestType,
    ) -> SessionState:
        """Return the cached SessionState for (session_id, model_name).

        A new state is created (and cached) on first use; ``request_type`` is
        only consulted at creation time, to fill ``classified_type``.
        """
        key = (session_id, model_name)
        state = self._session_states.get(key)
        if state is None:
            state = SessionState(
                session_id=session_id,
                router_name=self.router_name,
                model_name=model_name,
                classified_type=request_type.value,
            )
            self._session_states[key] = state
        return state

    async def record_turn(
        self,
        session_id: str,
        model_name: str,
        request_type: RequestType,
        turn: Turn,
    ) -> SignalDelta:
        """Apply one turn, push session snapshot + bandit deltas to the queue.

        The session snapshot is always queued. The in-memory cell and the
        queued state delta are only touched when the turn produced a non-zero
        alpha or beta increment.

        Returns:
            The ``SignalDelta`` describing which signals fired on this turn.
        """
        state = self.get_or_create_session_state(session_id, model_name, request_type)
        delta = apply_turn(state, turn)
        snapshot = asdict(state)
        await self.queue.add_session_state(
            session_id, self.router_name, model_name, snapshot
        )
        d_alpha, d_beta = self._compute_bandit_delta(delta)
        # Debug-level logging instead of writing to stdout on every turn.
        verbose_router_logger.debug(
            "AdaptiveRouter[%s]: signal delta=%s -> d_alpha=%s d_beta=%s",
            self.router_name,
            delta,
            d_alpha,
            d_beta,
        )
        if d_alpha != 0 or d_beta != 0:
            cell_key = (request_type, model_name)
            self._cells[cell_key] = apply_delta(self._cells[cell_key], d_alpha, d_beta)
            await self.queue.add_state_delta(
                self.router_name,
                request_type.value,
                model_name,
                d_alpha,
                d_beta,
            )
        return delta

    @staticmethod
    def _compute_bandit_delta(delta: SignalDelta) -> Tuple[float, float]:
        """
        Translate per-turn signal deltas into bandit-cell deltas.

        v0 mapping (UNVALIDATED, D6):
        - satisfaction -> +1 alpha
        - misalignment, stagnation,
          disengagement, failure -> +1 beta each
        - loop -> +0.5 beta (weak; could be model OR user)
        - exhaustion -> 0 (uptime issue, tracked separately later)

        Returns:
            ``(d_alpha, d_beta)`` as floats.
        """
        d_alpha = float(delta.satisfaction)
        d_beta = (
            float(
                delta.misalignment
                + delta.stagnation
                + delta.disengagement
                + delta.failure
            )
            + 0.5 * delta.loop
        )
        return d_alpha, d_beta

View File

@ -0,0 +1,136 @@
"""
Thompson sampling and prior initialization for the adaptive router bandit.
Each (router, request_type, model) cell is a Beta(alpha, beta) posterior.
- alpha = pseudo-successes
- beta = pseudo-failures
- mean = alpha / (alpha + beta)
- total samples = alpha + beta - COLD_START_MASS (informative prior, not data)
Hot path: thompson_sample() pure function, no I/O.
"""
import random
from dataclasses import dataclass
from typing import Dict, List, Optional
from litellm.router_strategy.adaptive_router.config import (
BASE_TIER_WEIGHT,
COLD_START_MASS,
DEFAULT_COST_WEIGHT,
DEFAULT_QUALITY_WEIGHT,
SAMPLE_CAP,
STRENGTH_BONUS,
)
from litellm.types.router import AdaptiveRouterPreferences, RequestType
@dataclass(frozen=True)
class BanditCell:
    """Immutable Beta(alpha, beta) posterior for one (router, request_type, model) cell."""

    alpha: float
    beta: float

    @property
    def mean(self) -> float:
        """Posterior mean alpha / (alpha + beta); 0.5 when the mass is not positive."""
        mass = self.alpha + self.beta
        if mass > 0:
            return self.alpha / mass
        return 0.5

    @property
    def total_samples(self) -> int:
        """Mass beyond COLD_START_MASS, truncated to int and floored at 0."""
        observed = int(self.alpha + self.beta - COLD_START_MASS)
        return observed if observed > 0 else 0
def initial_cell(
    prefs: AdaptiveRouterPreferences, request_type: RequestType
) -> BanditCell:
    """
    Build the cold-start prior for one (model, request_type) cell.

    The prior mean is BASE_TIER_WEIGHT[prefs.quality_tier], plus
    STRENGTH_BONUS when ``request_type`` is one of ``prefs.strengths``,
    clamped to at most 0.95. Alpha and beta split COLD_START_MASS according
    to that mean, so alpha + beta == COLD_START_MASS.
    """
    if request_type in prefs.strengths:
        strength_bonus = STRENGTH_BONUS
    else:
        strength_bonus = 0.0
    prior_mean = min(0.95, BASE_TIER_WEIGHT[prefs.quality_tier] + strength_bonus)
    return BanditCell(
        alpha=prior_mean * COLD_START_MASS,
        beta=(1.0 - prior_mean) * COLD_START_MASS,
    )
def apply_delta(cell: BanditCell, delta_alpha: float, delta_beta: float) -> BanditCell:
    """
    Return the cell after a learning update, honouring the sample cap.

    SAMPLE_CAP is a hard ceiling on (alpha + beta): if the updated mass would
    exceed it, the update is dropped and the original cell is returned
    unchanged. (D5: hard cap, no rescaling - keep v0 simple.)
    """
    candidate = BanditCell(
        alpha=cell.alpha + delta_alpha,
        beta=cell.beta + delta_beta,
    )
    over_cap = candidate.alpha + candidate.beta > SAMPLE_CAP
    return cell if over_cap else candidate
def thompson_sample(cell: BanditCell, rng: Optional[random.Random] = None) -> float:
    """Draw one value from Beta(cell.alpha, cell.beta).

    Uses ``rng`` when given, otherwise the module-level ``random`` generator.
    """
    source = random if rng is None else rng
    return source.betavariate(cell.alpha, cell.beta)
def normalized_cost(model_cost: float, all_costs: List[float]) -> float:
    """
    Rescale a raw cost against the range of ``all_costs``.

    The cheapest cost in ``all_costs`` maps to 1.0 and the most expensive to
    0.0. Returns 0.5 when ``all_costs`` is empty or has no spread.
    """
    if not all_costs:
        return 0.5
    cheapest = min(all_costs)
    priciest = max(all_costs)
    if priciest == cheapest:
        return 0.5
    spread = priciest - cheapest
    return 1.0 - ((model_cost - cheapest) / spread)
def score(
    quality_sample: float,
    model_cost: float,
    all_costs: List[float],
    quality_weight: float = DEFAULT_QUALITY_WEIGHT,
    cost_weight: float = DEFAULT_COST_WEIGHT,
) -> float:
    """
    Combine quality and cost into one number; higher is better.

    V0 is a weighted linear sum of ``quality_sample`` and the normalized cost
    of ``model_cost`` within ``all_costs``.
    """
    quality_term = quality_weight * quality_sample
    cost_term = cost_weight * normalized_cost(model_cost, all_costs)
    return quality_term + cost_term
def pick_best(
    cells: Dict[str, BanditCell],
    model_costs: Dict[str, float],
    quality_weight: float = DEFAULT_QUALITY_WEIGHT,
    cost_weight: float = DEFAULT_COST_WEIGHT,
    rng: Optional[random.Random] = None,
) -> str:
    """
    Draw one Thompson sample per model, score it, and return the top scorer.

    cells: {model_name: BanditCell}
    model_costs: {model_name: $/1k tokens}; must contain every key of ``cells``.

    Ties keep the model encountered first in ``cells`` iteration order.

    Raises:
        ValueError: if ``cells`` is empty.
    """
    if not cells:
        raise ValueError("pick_best called with no models")
    cost_pool = list(model_costs.values())
    winner: Optional[str] = None
    top_score = float("-inf")
    for name in cells:
        draw = thompson_sample(cells[name], rng=rng)
        candidate = score(
            draw, model_costs[name], cost_pool, quality_weight, cost_weight
        )
        if candidate > top_score:
            winner, top_score = name, candidate
    assert winner is not None
    return winner

View File

@ -0,0 +1,140 @@
"""
Rule-based classifier mapping a user prompt to a RequestType.
V0 design choice: deterministic regex over a single user prompt. Callers in
this package pass the latest user message of each turn; any per-session
caching of the result is the caller's responsibility, not ours.
Order matters: we check more specific types first, falling back to GENERAL.
"""
import re
from typing import List, Pattern, Tuple
from litellm.types.router import RequestType
# Ordered (pattern, request_type) rules; `classify_prompt` returns the type of
# the first pattern that matches, so more specific rules come first.
_RULES: List[Tuple[Pattern[str], RequestType]] = [
    # --- Code generation ---------------------------------------------------
    (
        # "<verb> [a|an|the|me] <language>". `c\+\+` sits outside the group
        # that carries the trailing `\b`: a word boundary can never match
        # between "+" and a following space/punctuation/end-of-string, so
        # with `\b` "write c++ code" would not be recognized.
        re.compile(
            r"\b(write|create|generate|implement|build)\s+(?:a |an |the |me )?(?:(?:python|javascript|typescript|java|rust|go|sql|bash|shell)\b|c\+\+)",
            re.IGNORECASE,
        ),
        RequestType.CODE_GENERATION,
    ),
    (
        # "<verb> ... <artifact>" with at most four words in between.
        re.compile(
            r"\b(write|create|implement|build)\b(?:\s+\w+){0,4}?\s+(function|class|method|script|program|api|endpoint|microservice)\b",
            re.IGNORECASE,
        ),
        RequestType.CODE_GENERATION,
    ),
    # --- Code understanding ------------------------------------------------
    (
        re.compile(
            r"\b(explain|describe|understand|walk me through|what does)\b.*\b(code|function|method|class|algorithm|snippet)\b",
            re.IGNORECASE,
        ),
        RequestType.CODE_UNDERSTANDING,
    ),
    (
        re.compile(
            r"\b(debug|fix|why (?:is|does|isn't)|what.s wrong|trace)\b.*\b(error|bug|exception|stacktrace|stack trace|traceback)\b",
            re.IGNORECASE,
        ),
        RequestType.CODE_UNDERSTANDING,
    ),
    (
        re.compile(
            r"\b(review|critique)\s+(?:this |my |the )?(?:code|pr|pull request|diff|patch)\b",
            re.IGNORECASE,
        ),
        RequestType.CODE_UNDERSTANDING,
    ),
    # --- Technical design --------------------------------------------------
    (
        re.compile(
            r"\b(design|architect|plan|architecture)\b.*\b(system|service|api|database|schema|module|microservice)\b",
            re.IGNORECASE,
        ),
        RequestType.TECHNICAL_DESIGN,
    ),
    (
        re.compile(
            r"\b(should i (?:use|choose|pick)|tradeoffs? between|compare)\b.*\b(library|framework|language|database|protocol|postgres|postgresql|mongodb|dynamodb|mysql|redis|kafka|sql|nosql)\b",
            re.IGNORECASE,
        ),
        RequestType.TECHNICAL_DESIGN,
    ),
    (
        re.compile(
            r"\bhow (?:should|do) i (?:design|structure|organize|model)\b",
            re.IGNORECASE,
        ),
        RequestType.TECHNICAL_DESIGN,
    ),
    # --- Analytical reasoning ----------------------------------------------
    (
        re.compile(
            r"\b(solve|compute|calculate|prove|derive)\b.*\b(equation|integral|derivative|theorem|proof|problem)\b",
            re.IGNORECASE,
        ),
        RequestType.ANALYTICAL_REASONING,
    ),
    (
        re.compile(r"\b(if .+ then|given .+ find|suppose|assume)\b", re.IGNORECASE),
        RequestType.ANALYTICAL_REASONING,
    ),
    (
        re.compile(
            r"\b(probability|statistics|combinatorics|optimization problem)\b",
            re.IGNORECASE,
        ),
        RequestType.ANALYTICAL_REASONING,
    ),
    # --- Writing -----------------------------------------------------------
    (
        re.compile(
            r"\b(write|draft|compose|rewrite|edit|proofread|polish)\b.*\b(email|essay|blog|post|article|letter|memo|copy|paragraph|sentence)\b",
            re.IGNORECASE,
        ),
        RequestType.WRITING,
    ),
    (
        re.compile(
            r"\b(make (?:this|it)|help me)\s+(?:more |less )?(?:concise|formal|casual|professional|persuasive)\b",
            re.IGNORECASE,
        ),
        RequestType.WRITING,
    ),
    # --- Factual lookup (anchored at the start of the prompt) ---------------
    (
        re.compile(
            r"^\s*(who|what|when|where|which)\s+(?:is|was|were|are)\b", re.IGNORECASE
        ),
        RequestType.FACTUAL_LOOKUP,
    ),
    (
        re.compile(r"^\s*(define|definition of|meaning of)\b", re.IGNORECASE),
        RequestType.FACTUAL_LOOKUP,
    ),
    (
        re.compile(
            r"^\s*how (?:do you spell|to spell|many .* are there|tall is)\b",
            re.IGNORECASE,
        ),
        RequestType.FACTUAL_LOOKUP,
    ),
]
def classify_prompt(text: str) -> RequestType:
    """
    Map one user prompt to a RequestType.

    Only the first 2000 characters are inspected. The first matching entry of
    ``_RULES`` decides the result; when nothing matches, or the prompt is
    empty/whitespace-only, GENERAL is returned.
    """
    if not text or not text.strip():
        return RequestType.GENERAL
    head = text[:2000]
    matched = next(
        (rule_type for rule_pattern, rule_type in _RULES if rule_pattern.search(head)),
        RequestType.GENERAL,
    )
    return matched

View File

@ -0,0 +1,55 @@
"""
Configuration constants for the adaptive_router strategy.
All magic numbers are first-pass guesses (D3-D6 in the handoff plan).
Expect to retune after first 1000 sessions of real traffic.
"""
from typing import Dict
from litellm.types.router import RequestType # re-export for convenience # noqa: F401
# D3 — Score weights (default; user-overridable via AdaptiveRouterConfig.weights)
DEFAULT_QUALITY_WEIGHT: float = 0.7 # UNVALIDATED — calibrated against [0] sessions
DEFAULT_COST_WEIGHT: float = 0.3 # UNVALIDATED — calibrated against [0] sessions
# D4 — Cold-start prior: (alpha + beta) total mass = COLD_START_MASS
# Mean of Beta = base_tier_weight + (strength_bonus if declared)
# BASE_TIER_WEIGHT is keyed by quality tier (1, 2, 3).
BASE_TIER_WEIGHT: Dict[int, float] = {1: 0.3, 2: 0.5, 3: 0.7} # UNVALIDATED
STRENGTH_BONUS: float = 0.3 # UNVALIDATED
COLD_START_MASS: float = 10.0
# D5 — Sample cap. Hard cap, no rescaling (drift handling is v1).
SAMPLE_CAP: int = 200
# D6 — Clean-trace credit: minimum turns before α += 1 can fire.
# NOTE(review): no reader of this constant is visible alongside this file — confirm it is wired up.
MIN_TURNS_FOR_CLEAN_CREDIT: int = 3
# D2 — Owner-cache TTL (seconds). 24h.
# A conversation's first-picked model "owns" the bandit-update slot for
# this long. Subsequent turns of the same conversation only contribute a
# bandit/state update when the same model is re-sampled.
OWNER_CACHE_TTL_SECONDS: int = 24 * 3600
# Below this many messages we skip post-call signal recording. Most signals
# (misalignment, stagnation, satisfaction-in-response-to-prior-turn) need at
# least one full prior exchange to be meaningful.
SIGNAL_GATE_MIN_MESSAGES: int = 4
# Detector thresholds (from Plano/Chen 2026 paper).
MISALIGNMENT_JACCARD_THRESHOLD: float = 0.45
STAGNATION_JACCARD_NEAR_DUP: float = 0.50
# NOTE(review): STAGNATION_JACCARD_EXACT does not appear to be imported by the
# signal detectors — confirm whether it is still needed.
STAGNATION_JACCARD_EXACT: float = 0.85
LOOP_REPEAT_THRESHOLD: int = 3
TOOL_CALL_HISTORY_MAX: int = 20
# D1 — Caller filter for min quality tier.
# NOTE(review): the code that reads this header / metadata key is not visible here — verify against the proxy.
MIN_QUALITY_TIER_HEADER: str = "x-litellm-min-quality-tier"
MIN_QUALITY_TIER_METADATA_KEY: str = "min_quality_tier"
# Pre-routing -> post-call relay: the chosen logical model is stashed on
# request_kwargs["metadata"][ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY] by the
# pre-routing hook, then read by the post-call hook to surface as the
# ADAPTIVE_ROUTER_RESPONSE_HEADER response header.
ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY: str = "adaptive_router_chosen_model"
ADAPTIVE_ROUTER_RESPONSE_HEADER: str = "x-litellm-adaptive-router-model"

View File

@ -0,0 +1,241 @@
"""
Post-call hook for the adaptive router.
On each successful or failed completion, build a Turn from the request/response
and push it through `AdaptiveRouter.record_turn`. The router then updates the
in-memory bandit cell + session state and queues writes for the proxy flusher.
All work happens after the response has been returned to the caller. Any
exception is swallowed signal recording must never break a request.
"""
from __future__ import annotations
import hashlib
import json
from typing import Any, Dict, List, Optional
from litellm._logging import verbose_router_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter
from litellm.router_strategy.adaptive_router.classifier import classify_prompt
from litellm.router_strategy.adaptive_router.config import (
ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY,
ADAPTIVE_ROUTER_RESPONSE_HEADER,
SIGNAL_GATE_MIN_MESSAGES,
)
from litellm.router_strategy.adaptive_router.signals import Turn
# Identity fields hashed into a derived session key so the same conversation
# from the same caller produces a stable key, while different keys/teams/users
# stay segregated even if they happen to send identical first messages.
_IDENTITY_FIELDS = (
"user_api_key_hash",
"user_api_key_team_id",
"user_api_key_user_id",
"user_api_key_end_user_id",
)
def _resolve_session_key(kwargs: Dict[str, Any]) -> Optional[str]:
"""Pick a stable per-conversation key for owner-cache attribution.
Order:
1. Honor a client-supplied session id (`litellm_session_id` on either
`litellm_params` or `litellm_params.metadata`, or `session_id` on
metadata) backward compat for callers already wired up.
2. Otherwise derive a sha256 over (identity fields, first message) so
the key is stable across turns of the same conversation.
Returns None if there are no messages (nothing to attribute).
"""
litellm_params = kwargs.get("litellm_params") or {}
sid = litellm_params.get("litellm_session_id")
if sid:
return str(sid)
metadata = litellm_params.get("metadata") or {}
if isinstance(metadata, dict):
sid = metadata.get("session_id") or metadata.get("litellm_session_id")
if sid:
return str(sid)
messages = kwargs.get("messages") or []
if not messages:
return None
identity = ":".join(
str(metadata.get(f) or "") if isinstance(metadata, dict) else ""
for f in _IDENTITY_FIELDS
)
first = messages[0]
payload = (
identity
+ "|"
+ json.dumps(
{"role": first.get("role"), "content": first.get("content")},
sort_keys=True,
default=str,
)
)
return hashlib.sha256(payload.encode("utf-8")).hexdigest()
def _last_user_content(messages: Optional[List[Dict[str, Any]]]) -> Optional[str]:
if not messages:
return None
for msg in reversed(messages):
if msg.get("role") == "user":
content = msg.get("content")
if isinstance(content, str):
return content
if isinstance(content, list):
# OpenAI vision-style content: pick first text part.
for part in content:
if isinstance(part, dict) and part.get("type") == "text":
return part.get("text")
return None
return None
def _assistant_content_and_tool_calls(response_obj: Any) -> tuple:
"""Return (assistant_text, tool_calls_list) extracted from a ModelResponse-ish object."""
if response_obj is None:
return None, []
try:
choices = getattr(response_obj, "choices", None) or response_obj.get("choices")
except Exception:
return None, []
if not choices:
return None, []
msg = choices[0]
msg = getattr(msg, "message", None) or (
msg.get("message") if isinstance(msg, dict) else None
)
if msg is None:
return None, []
content = getattr(msg, "content", None)
if content is None and isinstance(msg, dict):
content = msg.get("content")
raw_tool_calls = getattr(msg, "tool_calls", None)
if raw_tool_calls is None and isinstance(msg, dict):
raw_tool_calls = msg.get("tool_calls")
tool_calls: List[Dict[str, Any]] = []
for tc in raw_tool_calls or []:
if isinstance(tc, dict):
tool_calls.append(tc)
else:
try:
tool_calls.append(tc.model_dump())
except Exception:
tool_calls.append({"name": getattr(tc, "name", ""), "arguments": ""})
return content, tool_calls
class AdaptiveRouterPostCallHook(CustomLogger):
    """One hook instance per AdaptiveRouter. Registered into litellm.callbacks.

    Two responsibilities:
      - surface the logical model chosen by the pre-routing hook as a
        response header (``async_post_call_success_hook``);
      - turn each completed request into a ``Turn`` and feed it to
        ``AdaptiveRouter.record_turn`` (``async_log_*_event`` -> ``_record``).
    """

    def __init__(self, adaptive_router: AdaptiveRouter) -> None:
        self.adaptive_router = adaptive_router

    async def async_post_call_success_hook(
        self,
        data: Dict[str, Any],
        user_api_key_dict: Any,
        response: Any,
    ) -> None:
        """
        Surface the chosen logical model picked by the pre-routing hook as the
        `x-litellm-adaptive-router-model` response header.

        The chosen model is stashed on `data["metadata"]` by
        `AdaptiveRouter.async_pre_routing_hook`. The proxy awaits this hook
        before reading `_hidden_params["additional_headers"]` for the outgoing
        HTTP response, so any value we write here flows through.

        Does nothing when no chosen model was stashed or when the response
        has no dict-valued ``_hidden_params``.
        """
        metadata = data.get("metadata") or {}
        chosen = (
            metadata.get(ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY)
            if isinstance(metadata, dict)
            else None
        )
        if not chosen:
            return
        hidden_params = getattr(response, "_hidden_params", None)
        if not isinstance(hidden_params, dict):
            return
        headers = hidden_params.get("additional_headers")
        if not isinstance(headers, dict):
            # Missing, None, or otherwise unusable: start a fresh header dict
            # instead of failing on item assignment.
            headers = {}
            hidden_params["additional_headers"] = headers
        headers[ADAPTIVE_ROUTER_RESPONSE_HEADER] = chosen

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Record a successful completion as a turn with status 200."""
        await self._record(kwargs, response_obj, response_status=200)

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Record a failed completion as a turn.

        The status comes from ``kwargs["response_status"]``, else from the
        exception's ``status_code``; it falls back to 500 when neither is
        present or the value cannot be converted to an int.
        """
        status = kwargs.get("response_status")
        if status is None:
            status = getattr(kwargs.get("exception"), "status_code", None)
        try:
            status_code = int(status)
        except (TypeError, ValueError):
            # Signal recording must never raise out of the logging callback.
            status_code = 500
        await self._record(kwargs, response_obj, response_status=status_code)

    async def _record(
        self,
        kwargs: Dict[str, Any],
        response_obj: Any,
        response_status: int,
    ) -> None:
        """Build a Turn from the request/response and hand it to the router.

        Skips silently when the conversation has fewer than
        SIGNAL_GATE_MIN_MESSAGES messages, when no session key or chosen
        model can be resolved, or when another model owns the session. Any
        exception is logged and swallowed.
        """
        try:
            messages = kwargs.get("messages") or []
            if len(messages) < SIGNAL_GATE_MIN_MESSAGES:
                # Too few turns for any signal to be meaningful — skip.
                return
            session_key = _resolve_session_key(kwargs)
            if not session_key:
                return
            # The bandit cells are keyed by the *logical* model name from
            # `available_models` (e.g. "smart"/"fast"). `kwargs["model"]` at
            # post-call time is the physical upstream model
            # (e.g. "anthropic/claude-opus-4-7"), so it cannot be used directly.
            # The pre-routing hook stashes the logical pick under this key.
            litellm_params = kwargs.get("litellm_params") or {}
            metadata = litellm_params.get("metadata") or {}
            current_model = (
                metadata.get(ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY)
                if isinstance(metadata, dict)
                else None
            )
            if not current_model:
                return
            if not self.adaptive_router.claim_or_check_owner(
                session_key, current_model
            ):
                # A different model owns this conversation — skip attribution.
                return
            user_text = _last_user_content(messages)
            assistant_text, tool_calls = _assistant_content_and_tool_calls(response_obj)
            request_type = classify_prompt(user_text or "")
            turn = Turn(
                user_content=user_text,
                assistant_content=(
                    assistant_text if isinstance(assistant_text, str) else None
                ),
                tool_calls=tool_calls,
                tool_results=[],
                response_status=response_status,
            )
            await self.adaptive_router.record_turn(
                session_id=session_key,
                model_name=current_model,
                request_type=request_type,
                turn=turn,
            )
        except Exception as e:
            verbose_router_logger.exception(
                "AdaptiveRouterPostCallHook: failed to record turn: %s", e
            )

View File

@ -0,0 +1,272 @@
"""
Incremental signal detection for the adaptive router.
Each session maintains a SessionState. On every turn, we call apply_turn(state, turn)
which mutates the state in place and returns a SignalDelta listing which signals
fired on THIS turn. The router then queues the delta to be flushed to DB.
Design constraint: O(1) work per turn. No re-scanning the full session history.
We keep small bounded windows: last_user_content, last_assistant_content, and a
bounded list of recent tool call signatures.
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Set
from litellm.router_strategy.adaptive_router.config import (
LOOP_REPEAT_THRESHOLD,
MISALIGNMENT_JACCARD_THRESHOLD,
STAGNATION_JACCARD_NEAR_DUP,
TOOL_CALL_HISTORY_MAX,
)
# ---- Public types ---------------------------------------------------------
@dataclass
class SignalDelta:
    """Per-turn signal flags; each field is 0 or 1 for a single turn."""

    misalignment: int = 0
    stagnation: int = 0
    disengagement: int = 0
    satisfaction: int = 0
    failure: int = 0
    loop: int = 0
    exhaustion: int = 0

    def any_fired(self) -> bool:
        """Return True when at least one signal field is non-zero."""
        flags = (
            self.misalignment,
            self.stagnation,
            self.disengagement,
            self.satisfaction,
            self.failure,
            self.loop,
            self.exhaustion,
        )
        return any(flags)
@dataclass
class SessionState:
    """Rolling in-memory state for one (session, router, model).

    Field names line up with columns of the LiteLLM_AdaptiveRouterSession
    table (Wave 0 schema), which the flusher persists later. This is a plain
    dataclass with no DB coupling.
    """

    # Identity of the row.
    session_id: str
    router_name: str
    model_name: str
    classified_type: str

    # Cumulative signal counters.
    misalignment_count: int = 0
    stagnation_count: int = 0
    disengagement_count: int = 0
    satisfaction_count: int = 0
    failure_count: int = 0
    loop_count: int = 0
    exhaustion_count: int = 0

    # Small windows kept so each turn can be compared with the previous one.
    last_user_content: Optional[str] = None
    last_assistant_content: Optional[str] = None
    tool_call_history: List[str] = field(default_factory=list)
    pending_tool_calls: Dict[str, str] = field(default_factory=dict)

    # Bookkeeping.
    turn_count: int = 0
    terminal_status: Optional[int] = None
@dataclass
class Turn:
    """Input for a single turn, assembled by the caller from request/response."""

    # Text of the user message and of the assistant reply, when available.
    user_content: Optional[str] = None
    assistant_content: Optional[str] = None
    # Tool calls made on this turn and the results that came back.
    tool_calls: List[Dict[str, Any]] = field(default_factory=list)
    tool_results: List[Dict[str, Any]] = field(default_factory=list)
    # Status code of the response, when known.
    response_status: Optional[int] = None
# ---- Detection helpers ----------------------------------------------------
_TOKEN_RE = re.compile(r"[A-Za-z0-9]+")
def _tokens(text: Optional[str]) -> Set[str]:
if not text:
return set()
return {t.lower() for t in _TOKEN_RE.findall(text)}
def _jaccard(a: Set[str], b: Set[str]) -> float:
union = a | b
if not union:
return 0.0
return len(a & b) / len(union)
# Case-insensitive phrases in a user message that count as disengagement
# (giving up, asking for a human, cancelling). A match on any one pattern is
# enough for `_detect_disengagement` to fire.
_DISENGAGEMENT_PATTERNS = [
re.compile(
r"\b(forget it|never mind|give up|talk to (?:a )?human|cancel)\b", re.IGNORECASE
),
re.compile(r"\b(this (?:isn'?t|is not) working|stop|abort)\b", re.IGNORECASE),
re.compile(r"\bi'?ll do it (?:myself|manually)\b", re.IGNORECASE),
]
# Case-insensitive phrases in a user message that count as satisfaction
# (confirmation that something worked, thanks, praise). A match on any one
# pattern is enough for `_detect_satisfaction` to fire.
# NOTE(review): single words such as "nice", "great" or "exactly" match
# anywhere in the message, regardless of context — confirm this breadth is intended.
_SATISFACTION_PATTERNS = [
re.compile(
r"\b(that worked|that did it|works now|fixed it|solved it|nice)\b",
re.IGNORECASE,
),
re.compile(r"\b(thanks|thank you|thx|appreciated|appreciate it)\b", re.IGNORECASE),
re.compile(r"\b(perfect|great|excellent|exactly)\b", re.IGNORECASE),
]
def _detect_misalignment(prev_user: Optional[str], curr_user: Optional[str]) -> bool:
    """Report a likely rephrase between two consecutive user messages.

    Fires when the token sets overlap at all (jaccard > 0) yet stay below
    MISALIGNMENT_JACCARD_THRESHOLD - i.e. same topic, different wording.
    Never fires when either message is empty or missing.
    """
    if not (prev_user and curr_user):
        return False
    overlap = _jaccard(_tokens(prev_user), _tokens(curr_user))
    return overlap > 0.0 and overlap < MISALIGNMENT_JACCARD_THRESHOLD
def _detect_stagnation(prev_asst: Optional[str], curr_asst: Optional[str]) -> bool:
    """Report near-duplicate consecutive assistant messages.

    Fires when token-set jaccard is at least STAGNATION_JACCARD_NEAR_DUP.
    Never fires when either message is empty or missing.
    """
    if not (prev_asst and curr_asst):
        return False
    similarity = _jaccard(_tokens(prev_asst), _tokens(curr_asst))
    return similarity >= STAGNATION_JACCARD_NEAR_DUP
def _detect_disengagement(curr_user: Optional[str]) -> bool:
    """Return True when the user message matches any disengagement pattern."""
    if not curr_user:
        return False
    for pattern in _DISENGAGEMENT_PATTERNS:
        if pattern.search(curr_user):
            return True
    return False
def _detect_satisfaction(curr_user: Optional[str]) -> bool:
    """Return True when the user message matches any satisfaction pattern."""
    if not curr_user:
        return False
    for pattern in _SATISFACTION_PATTERNS:
        if pattern.search(curr_user):
            return True
    return False
def _detect_failure(tool_results: List[Dict[str, Any]]) -> bool:
"""Any tool result that's an error or empty content."""
for r in tool_results:
if r.get("is_error"):
return True
content = r.get("content")
if content is None or content == "" or content == [] or content == {}:
return True
return False
def _signature(call: Dict[str, Any]) -> str:
"""Stable signature for loop detection: name + sorted JSON-ish args."""
name = call.get("name") or call.get("function", {}).get("name", "")
args = call.get("arguments")
if args is None:
args = call.get("function", {}).get("arguments", "")
if isinstance(args, dict):
args = ",".join(f"{k}={args[k]}" for k in sorted(args.keys()))
return f"{name}({args})"
def _detect_loop(history: List[str], new_calls: List[Dict[str, Any]]) -> bool:
    """Report a repeated tool call.

    Fires when the signature of any new call already occurs at least
    LOOP_REPEAT_THRESHOLD - 1 times in ``history`` (so the new call would be
    the Nth occurrence). Never fires when there are no new calls.
    """
    if not new_calls:
        return False
    trigger = LOOP_REPEAT_THRESHOLD - 1
    return any(history.count(_signature(call)) >= trigger for call in new_calls)
_EXHAUSTION_STATUSES = {408, 413, 429, 503, 504}
_EXHAUSTION_KEYWORDS = (
"context length",
"context window",
"token limit",
"rate limit",
"too many requests",
"timeout",
)
def _detect_exhaustion(
status: Optional[int], tool_results: List[Dict[str, Any]]
) -> bool:
if status is not None and status in _EXHAUSTION_STATUSES:
return True
for r in tool_results:
content = str(r.get("content", "")).lower()
if any(kw in content for kw in _EXHAUSTION_KEYWORDS):
return True
return False
# ---- Public entrypoint ----------------------------------------------------
def apply_turn(state: SessionState, turn: Turn) -> SignalDelta:
    """Detect signals on one turn, fold them into ``state``, return the delta.

    Every detector sees ``state`` as it was before this turn. The counters,
    the last user/assistant messages, the tool-call history, the terminal
    status and the turn count are updated only after detection.

    No full-history rescan: only the ``last_*`` fields, the recent tool
    history (trimmed to ``TOOL_CALL_HISTORY_MAX`` entries) and the new turn
    payload are inspected.

    Args:
        state: Per-session accumulator; mutated in place.
        turn: The turn payload to analyse.

    Returns:
        A ``SignalDelta`` with a 1 on each signal that fired this turn.
    """
    user_text = turn.user_content
    assistant_text = turn.assistant_content
    results = turn.tool_results
    calls = turn.tool_calls
    status = turn.response_status

    # ---- detect ----
    delta = SignalDelta()
    if _detect_misalignment(state.last_user_content, user_text):
        delta.misalignment = 1
    if _detect_stagnation(state.last_assistant_content, assistant_text):
        delta.stagnation = 1
    if _detect_disengagement(user_text):
        delta.disengagement = 1
    if _detect_satisfaction(user_text):
        delta.satisfaction = 1
    if _detect_failure(results):
        delta.failure = 1
    if _detect_loop(state.tool_call_history, calls):
        delta.loop = 1
    if _detect_exhaustion(status, results):
        delta.exhaustion = 1

    # ---- accumulate counters ----
    state.misalignment_count += delta.misalignment
    state.stagnation_count += delta.stagnation
    state.disengagement_count += delta.disengagement
    state.satisfaction_count += delta.satisfaction
    state.failure_count += delta.failure
    state.loop_count += delta.loop
    state.exhaustion_count += delta.exhaustion

    # ---- roll the "previous turn" snapshot forward ----
    # Empty/None content keeps the earlier message as the comparison baseline.
    if user_text:
        state.last_user_content = user_text
    if assistant_text:
        state.last_assistant_content = assistant_text

    # Record the new calls, then keep only the newest TOOL_CALL_HISTORY_MAX.
    # The trim is skipped when the turn made no calls.
    if calls:
        state.tool_call_history.extend(_signature(call) for call in calls)
        if len(state.tool_call_history) > TOOL_CALL_HISTORY_MAX:
            state.tool_call_history = state.tool_call_history[
                -TOOL_CALL_HISTORY_MAX:
            ]

    if status is not None:
        state.terminal_status = status
    state.turn_count += 1
    return delta

View File

@ -8,7 +8,7 @@ from dataclasses import dataclass
from typing import Any, Dict, List, Literal, Optional, Tuple, Union, get_type_hints
import httpx
from pydantic import BaseModel, ConfigDict, Field, model_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
from typing_extensions import Required, TypedDict
from litellm._uuid import uuid
@ -219,6 +219,10 @@ class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
complexity_router_config: Optional[Dict] = None
complexity_router_default_model: Optional[str] = None
# adaptive-router params
adaptive_router_default_model: Optional[str] = None
adaptive_router_config: Optional[Dict] = None
# Batch/File API Params
s3_bucket_name: Optional[str] = None
s3_encryption_key_id: Optional[str] = None
@ -788,3 +792,44 @@ class PreRoutingHookResponse(BaseModel):
model: str
messages: Optional[List[Dict[str, Any]]]
class RequestType(str, enum.Enum):
    """Fixed v0 taxonomy. User-extensible types come in v1.

    The class also subclasses ``str``, so each member compares equal to its
    string value (e.g. ``RequestType.GENERAL == "general"``).
    """

    CODE_GENERATION = "code_generation"
    CODE_UNDERSTANDING = "code_understanding"
    TECHNICAL_DESIGN = "technical_design"
    ANALYTICAL_REASONING = "analytical_reasoning"
    WRITING = "writing"
    FACTUAL_LOOKUP = "factual_lookup"
    GENERAL = "general"
class AdaptiveRouterWeights(BaseModel):
    """Relative weight of quality versus cost for the adaptive router.

    Each weight must lie in [0.0, 1.0] and the two must sum to 1.0 within a
    tolerance of 0.001. A violation is raised as a validation error.
    """

    quality: float = Field(default=0.7, ge=0.0, le=1.0)
    # validate_default=True is required: pydantic does not run field
    # validators on a defaulted field, so without it
    # AdaptiveRouterWeights(quality=0.5) would be accepted with the implicit
    # cost of 0.3 (sum 0.8) and the sum check below would never run.
    cost: float = Field(default=0.3, ge=0.0, le=1.0, validate_default=True)

    @field_validator("cost")
    @classmethod
    def _weights_sum_to_one(cls, v, info):
        """Reject a quality/cost pair that does not sum to 1.0 (±0.001)."""
        # `quality` is declared before `cost`, so it has already been
        # validated; fall back to its default if it is absent from info.data.
        q = info.data.get("quality", 0.7)
        if abs(q + v - 1.0) > 0.001:
            raise ValueError(
                f"weights must sum to 1.0, got quality={q} + cost={v} = {q + v}"
            )
        return v
class AdaptiveRouterConfig(BaseModel):
    """Configuration for one adaptive router."""

    # Names of the models the router may choose between (required).
    available_models: List[str]
    # Quality/cost weighting; a default AdaptiveRouterWeights() when omitted.
    weights: AdaptiveRouterWeights = Field(default_factory=AdaptiveRouterWeights)
class AdaptiveRouterPreferences(BaseModel):
    """model_info.adaptive_router_preferences — declared by each model."""

    # Keep `strengths` entries as RequestType members instead of converting
    # them to their raw string values.
    model_config = ConfigDict(use_enum_values=False)
    # Required integer tier, constrained to 1..3.
    # NOTE(review): presumably a higher tier means higher quality — confirm.
    quality_tier: int = Field(ge=1, le=3)
    # Request types this model declares as strengths; empty by default.
    strengths: List[RequestType] = Field(default_factory=list)

View File

@ -1219,3 +1219,46 @@ model LiteLLM_ClaudeCodePluginTable {
@@map("LiteLLM_ClaudeCodePluginTable")
}
// Per-(router, request_type, model) Beta posterior for the adaptive router.
// The composite primary key allows at most one row per such triple.
model LiteLLM_AdaptiveRouterState {
  router_name     String
  request_type    String
  model_name      String
  // Parameters of the Beta posterior; no database default, so the writer
  // must supply both when creating a row.
  alpha           Float
  beta            Float
  total_samples   Int      @default(0)
  last_updated_at DateTime @default(now())

  @@id([router_name, request_type, model_name])
}
// Per-(session, router, model) signal counters for the adaptive router.
// The composite primary key allows at most one row per such triple.
model LiteLLM_AdaptiveRouterSession {
  session_id             String
  router_name            String
  model_name             String
  classified_type        String
  // One counter per conversation signal; all start at 0.
  misalignment_count     Int      @default(0)
  stagnation_count       Int      @default(0)
  disengagement_count    Int      @default(0)
  satisfaction_count     Int      @default(0)
  failure_count          Int      @default(0)
  loop_count             Int      @default(0)
  exhaustion_count       Int      @default(0)
  // Most recent user / assistant message text; nullable.
  last_user_content      String?
  last_assistant_content String?
  // JSON columns: history defaults to an empty array, pending calls to an
  // empty object.
  tool_call_history      Json     @default("[]")
  pending_tool_calls     Json     @default("{}")
  turn_count             Int      @default(0)
  // -1 is the initial value, i.e. before any turn index has been recorded.
  last_processed_turn    Int      @default(-1)
  clean_credit_awarded   Boolean  @default(false)
  terminal_status        Int?
  last_activity_at       DateTime @default(now())

  @@id([session_id, router_name, model_name])
  // NOTE(review): presumably supports finding or expiring idle sessions by
  // activity time — confirm against the queries that use it.
  @@index([last_activity_at])
}

View File

@ -0,0 +1,216 @@
"""
End-to-end verification script for the adaptive router.
Requires:
- LiteLLM proxy running on http://localhost:4000 with adaptive_router configured
(see litellm/proxy/example_config_yaml/adaptive_router_example.yaml).
- Postgres reachable by the proxy via DATABASE_URL (this script never connects to the database itself).
- LITELLM_PROXY_KEY env var set (a valid key with permission to send requests).
- Two model deployments configured under one adaptive_router:
* "fast" (cheap, lower quality)
* "smart" (expensive, higher quality)
Run:
uv run python scripts/verify_adaptive_router.py
Optional env:
LITELLM_PROXY_URL (default: http://localhost:4000)
ADAPTIVE_ROUTER_NAME (default: smart-cheap-router)
EXPECTED_WINNER (default: smart) -- model expected to dominate after training
TRAIN_SESSIONS (default: 20) -- training sessions in phase 1
CONVERGE_SESSIONS (default: 10) -- cold sessions in phase 2
WIN_THRESHOLD (default: 0.7) -- min share for EXPECTED_WINNER in phase 2
"""
from __future__ import annotations
import asyncio
import os
import sys
import time
import uuid
from typing import List, Optional
import httpx
# Base URL of the proxy under test; override with LITELLM_PROXY_URL.
PROXY_URL: str = os.environ.get("LITELLM_PROXY_URL", "http://localhost:4000")
# The proxy key is mandatory: print an error and exit with status 2 at import
# time when it is missing.
try:
    PROXY_KEY: str = os.environ["LITELLM_PROXY_KEY"]
except KeyError:
    print(
        "ERROR: LITELLM_PROXY_KEY env var must be set (a proxy key with /chat/completions perms).",
        file=sys.stderr,
    )
    sys.exit(2)
# Sent as the `model` field of every request.
ROUTER_NAME: str = os.environ.get("ADAPTIVE_ROUTER_NAME", "smart-cheap-router")
# Model name whose share of phase-2 picks is compared against WIN_THRESHOLD.
EXPECTED_WINNER: str = os.environ.get("EXPECTED_WINNER", "smart")
TRAIN_SESSIONS: int = int(os.environ.get("TRAIN_SESSIONS", "20"))
CONVERGE_SESSIONS: int = int(os.environ.get("CONVERGE_SESSIONS", "10"))
WIN_THRESHOLD: float = float(os.environ.get("WIN_THRESHOLD", "0.7"))
# Per-request timeout and retry policy used by _post_chat.
REQUEST_TIMEOUT_SECONDS: float = 30.0
RETRY_ATTEMPTS: int = 3
RETRY_BACKOFF_SECONDS: float = 1.0
FLUSHER_DRAIN_WAIT_SECONDS: float = 30.0 # proxy flusher loop is 10s; pad with margin
# Prompts sent, one per turn, in each training session.
PROMPTS: List[str] = [
    "Write a Python function that reverses a binary tree",
    "Explain the time complexity of quicksort",
    "Design an API for a chat application",
]
# Optional extra final turn sent by send_session when satisfy=True.
SATISFACTION_PROMPT: str = "thanks, that worked!"
async def _post_chat(
    client: httpx.AsyncClient, session_id: str, prompt: str
) -> Optional[dict]:
    """POST one chat completion to the proxy, with timeout and retries.

    Transport errors, 5xx responses, 408 and 429 are retried up to
    ``RETRY_ATTEMPTS`` times with a linearly growing backoff. Any other 4xx
    response (bad key, unknown model, malformed request) cannot succeed on a
    retry, so it fails immediately instead of sleeping through the backoff.

    Args:
        client: Shared HTTP client.
        session_id: Sent as ``metadata.litellm_session_id``.
        prompt: Content of the single user message.

    Returns:
        The decoded response JSON, or ``None`` if the request ultimately
        failed. The failure is reported on stderr and never raised.
    """
    body = {
        "model": ROUTER_NAME,
        "messages": [{"role": "user", "content": prompt}],
        "metadata": {"litellm_session_id": session_id},
    }
    last_exc: Optional[Exception] = None
    attempts_made = 0
    for attempt in range(1, RETRY_ATTEMPTS + 1):
        attempts_made = attempt
        try:
            r = await client.post(
                f"{PROXY_URL}/v1/chat/completions",
                json=body,
                headers={"Authorization": f"Bearer {PROXY_KEY}"},
                timeout=REQUEST_TIMEOUT_SECONDS,
            )
            r.raise_for_status()
            return r.json()
        except httpx.HTTPStatusError as e:
            last_exc = e
            status = e.response.status_code
            # Client errors other than timeout / rate-limit are permanent.
            if 400 <= status < 500 and status not in (408, 429):
                break
        except Exception as e:  # noqa: BLE001
            # Deliberately broad: this is a best-effort verification script.
            last_exc = e
        if attempt < RETRY_ATTEMPTS:
            await asyncio.sleep(RETRY_BACKOFF_SECONDS * attempt)
    print(
        f" request failed after {attempts_made} attempt(s) (session={session_id}): {last_exc}",
        file=sys.stderr,
    )
    return None
async def send_session(
    client: httpx.AsyncClient,
    session_id: str,
    prompts: List[str],
    satisfy: bool = True,
) -> Optional[str]:
    """Send each prompt as one turn of a session and report who answered.

    Args:
        client: Shared HTTP client.
        session_id: Session identifier attached to every turn.
        prompts: User prompts, sent one request per prompt, in order.
        satisfy: When true, send ``SATISFACTION_PROMPT`` as a final turn; its
            response is ignored.

    Returns:
        The ``model`` field of the last response that carried one, or ``None``
        if any prompt failed (the session is abandoned at that point) or no
        response reported a model.
    """
    handled_by: Optional[str] = None
    for text in prompts:
        reply = await _post_chat(client, session_id, text)
        if reply is None:
            # Abandon the session on the first failed turn.
            return None
        reported = reply.get("model")
        if reported:
            handled_by = reported
    if not satisfy:
        return handled_by
    await _post_chat(client, session_id, SATISFACTION_PROMPT)
    return handled_by
async def _proxy_health_check(client: httpx.AsyncClient) -> bool:
    """Return True only if the proxy liveness endpoint answers with HTTP 200.

    Any exception while making the request is reported on stderr and treated
    as "unreachable" rather than raised.
    """
    liveness_url = f"{PROXY_URL}/health/liveliness"
    try:
        response = await client.get(liveness_url, timeout=5.0)
    except Exception as e:  # noqa: BLE001
        print(f"proxy unreachable at {PROXY_URL}: {e}", file=sys.stderr)
        return False
    return response.status_code == 200
async def main() -> None:
    """Run the verification phases against a live proxy, then exit.

    Phases:
        1. Training: ``TRAIN_SESSIONS`` sessions of every prompt plus a
           satisfaction turn, then a wait for the proxy flusher to drain.
        2. Convergence: ``CONVERGE_SESSIONS`` fresh single-turn sessions; the
           share answered by ``EXPECTED_WINNER`` is recorded.
        3. Sticky session: three requests on one session id must all be
           answered by the same model.
        4. Latency: p50 end-to-end round-trip over up to five successful
           requests (informational only; it never fails the run).

    Exits 0 when the phase-2 share reaches ``WIN_THRESHOLD`` and the sticky
    check passed. Exits 1 on a failed health check, no successful phase-2
    picks, a sticky-session mismatch, or a share below the threshold.
    """
    print("=== verify_adaptive_router.py ===")
    print(f"proxy: {PROXY_URL}")
    print(f"router: {ROUTER_NAME}")
    print(f"expected winner: {EXPECTED_WINNER}")
    print(f"train sessions: {TRAIN_SESSIONS}")
    print(f"converge runs: {CONVERGE_SESSIONS}\n")

    async with httpx.AsyncClient() as client:
        if not await _proxy_health_check(client):
            print("FAIL: proxy health check did not return 200.", file=sys.stderr)
            sys.exit(1)

        # ---- Phase 1: training -------------------------------------------
        print(
            f"Phase 1: training ({TRAIN_SESSIONS} sessions of "
            f"{len(PROMPTS)} turns + satisfaction)..."
        )
        failed_training = 0
        for i in range(TRAIN_SESSIONS):
            sid = f"verify-train-{uuid.uuid4()}"
            # send_session returns None when a turn failed; count it instead
            # of dropping the result, so a weak convergence can be explained.
            if await send_session(client, sid, PROMPTS, satisfy=True) is None:
                failed_training += 1
            if (i + 1) % 5 == 0:
                print(f" trained {i + 1}/{TRAIN_SESSIONS} sessions")
        if failed_training:
            print(
                f" WARNING: {failed_training}/{TRAIN_SESSIONS} training sessions "
                f"did not complete; convergence may be weaker than expected.",
                file=sys.stderr,
            )

        print(
            f"\nWaiting {FLUSHER_DRAIN_WAIT_SECONDS:.0f}s for flusher to drain queue..."
        )
        await asyncio.sleep(FLUSHER_DRAIN_WAIT_SECONDS)

        # ---- Phase 2: convergence ----------------------------------------
        print(f"\nPhase 2: convergence test ({CONVERGE_SESSIONS} cold sessions)...")
        picks: List[str] = []
        for i in range(CONVERGE_SESSIONS):
            sid = f"verify-test-{uuid.uuid4()}"
            m = await send_session(client, sid, [PROMPTS[0]], satisfy=False)
            if m:
                picks.append(m)
            print(f" session {i + 1}: picked {m}")
        if not picks:
            print("\nFAIL: no successful picks in convergence phase.", file=sys.stderr)
            sys.exit(1)
        winner_share = picks.count(EXPECTED_WINNER) / len(picks)
        print(
            f"\n{EXPECTED_WINNER} share: {winner_share:.0%} "
            f"({picks.count(EXPECTED_WINNER)}/{len(picks)})"
        )

        # ---- Phase 3: sticky session -------------------------------------
        print("\nPhase 3: sticky session test...")
        sid = f"verify-sticky-{uuid.uuid4()}"
        models: List[str] = []
        for _ in range(3):
            m = await send_session(client, sid, [PROMPTS[0]], satisfy=False)
            if m:
                models.append(m)
        if len(models) == 3 and len(set(models)) == 1:
            print(f" PASS: same model {models[0]} across 3 turns of session {sid}")
        else:
            print(
                f" FAIL: models differed within session: {models}",
                file=sys.stderr,
            )
            sys.exit(1)

        # ---- Phase 4: latency benchmark ----------------------------------
        print("\nPhase 4: routing latency (5 picks, p50)...")
        latency_runs = 5
        latencies: List[float] = []
        for _ in range(latency_runs):
            t0 = time.perf_counter()
            m = await send_session(
                client, f"verify-lat-{uuid.uuid4()}", [PROMPTS[0]], satisfy=False
            )
            elapsed = time.perf_counter() - t0
            # A failed request includes retry backoff and is not a pick;
            # keep it out of the percentile.
            if m is not None:
                latencies.append(elapsed)
        if latencies:
            latencies.sort()
            p50 = latencies[len(latencies) // 2]
            print(f" p50 e2e roundtrip: {p50 * 1000:.0f}ms")
            if len(latencies) < latency_runs:
                print(
                    f" WARNING: only {len(latencies)}/{latency_runs} latency "
                    f"requests succeeded.",
                    file=sys.stderr,
                )
        else:
            print(
                " WARNING: no latency request succeeded; p50 not measured.",
                file=sys.stderr,
            )

        # ---- Verdict -----------------------------------------------------
        if winner_share >= WIN_THRESHOLD:
            print(
                f"\nPASS: convergence ({winner_share:.0%} >= {WIN_THRESHOLD:.0%}) + "
                f"sticky + latency checks all green."
            )
            sys.exit(0)
        print(
            f"\nFAIL: convergence too weak ({winner_share:.0%} < {WIN_THRESHOLD:.0%}).",
            file=sys.stderr,
        )
        sys.exit(1)
if __name__ == "__main__":
asyncio.run(main())

View File

@ -0,0 +1,117 @@
import asyncio
from unittest.mock import AsyncMock, MagicMock
import pytest
from litellm.proxy.db.db_transaction_queue.adaptive_router_update_queue import (
AdaptiveRouterUpdateQueue,
)
@pytest.fixture
def queue():
    """Provide a fresh AdaptiveRouterUpdateQueue for each test."""
    return AdaptiveRouterUpdateQueue()
@pytest.fixture
def mock_prisma():
    """Prisma stand-in whose adaptive-router table methods are AsyncMocks.

    ``find_unique`` on the state table resolves to ``None`` (no existing row);
    both ``upsert`` methods are plain AsyncMocks that record their calls.
    """
    prisma = MagicMock()
    state_table = prisma.db.litellm_adaptiverouterstate
    state_table.find_unique = AsyncMock(return_value=None)
    state_table.upsert = AsyncMock()
    prisma.db.litellm_adaptiveroutersession.upsert = AsyncMock()
    return prisma
@pytest.mark.asyncio
async def test_add_state_delta_aggregates_same_key(queue):
    """Two deltas for the same (router, type, model) share one pending entry."""
    same_key = ("r1", "general", "gpt-4")
    await queue.add_state_delta(*same_key, 1.0, 0.0)
    await queue.add_state_delta(*same_key, 0.0, 1.0)
    pending = (await queue.queue_size())["state_pending"]
    assert pending == 1
@pytest.mark.asyncio
async def test_add_state_delta_separate_keys(queue):
    """Deltas that differ only in request type stay as separate entries."""
    for request_type in ("general", "writing"):
        await queue.add_state_delta("r1", request_type, "gpt-4", 1.0, 0.0)
    pending = (await queue.queue_size())["state_pending"]
    assert pending == 2
@pytest.mark.asyncio
async def test_add_session_state_last_write_wins(queue):
    """A second write for the same session replaces the first before flushing."""
    await queue.add_session_state("s1", "r1", "gpt-4", {"misalignment_count": 1})
    await queue.add_session_state("s1", "r1", "gpt-4", {"misalignment_count": 5})
    pending = (await queue.queue_size())["session_pending"]
    assert pending == 1

    # Capture what the flush sends to the session table's upsert.
    upsert_calls = []

    async def recording_upsert(**kwargs):
        upsert_calls.append(kwargs)

    prisma = MagicMock()
    prisma.db.litellm_adaptiveroutersession.upsert = recording_upsert
    await queue.flush_session_to_db(prisma)

    assert len(upsert_calls) == 1
    assert upsert_calls[0]["data"]["update"]["misalignment_count"] == 5
@pytest.mark.asyncio
async def test_flush_state_drains_aggregator(queue, mock_prisma):
    """Flushing reports the number of entries written and empties the queue."""
    await queue.add_state_delta("r1", "general", "gpt-4", 1.0, 0.0)
    await queue.add_state_delta("r1", "writing", "gpt-4", 0.0, 1.0)
    flushed_count = await queue.flush_state_to_db(mock_prisma)
    assert flushed_count == 2
    remaining = (await queue.queue_size())["state_pending"]
    assert remaining == 0
@pytest.mark.asyncio
async def test_flush_state_sums_correctly(queue, mock_prisma):
    """Deltas for one key are summed before being written on a cold start."""
    await queue.add_state_delta("r1", "general", "gpt-4", 1.0, 0.0)
    await queue.add_state_delta("r1", "general", "gpt-4", 2.0, 1.0)
    await queue.flush_state_to_db(mock_prisma)
    # find_unique returned None (cold start), so alpha = 1+2 = 3, beta = 0+1 = 1
    upsert_call = mock_prisma.db.litellm_adaptiverouterstate.upsert.call_args
    created = upsert_call.kwargs["data"]["create"]
    assert created["alpha"] == 3.0
    assert created["beta"] == 1.0
    assert created["total_samples"] == 2
@pytest.mark.asyncio
async def test_flush_session_drains_aggregator(queue, mock_prisma):
    """Flushing session state reports one write and leaves nothing pending."""
    await queue.add_session_state("s1", "r1", "gpt-4", {"classified_type": "general"})
    flushed_count = await queue.flush_session_to_db(mock_prisma)
    assert flushed_count == 1
    remaining = (await queue.queue_size())["session_pending"]
    assert remaining == 0
@pytest.mark.asyncio
async def test_flush_empty_queue_returns_zero(queue, mock_prisma):
    """Flushing with nothing queued reports zero writes for both tables."""
    state_written = await queue.flush_state_to_db(mock_prisma)
    assert state_written == 0
    session_written = await queue.flush_session_to_db(mock_prisma)
    assert session_written == 0
@pytest.mark.asyncio
async def test_flush_state_isolation_from_concurrent_adds(queue, mock_prisma):
    """Adds during a flush should land in the NEW aggregator, not the drained batch.

    The order of awaits is the test: start the flush as a task, yield once,
    add a second delta, then wait for the flush. Exactly that second delta
    must still be pending afterwards.
    """
    await queue.add_state_delta("r1", "general", "gpt-4", 1.0, 0.0)
    flush_task = asyncio.create_task(queue.flush_state_to_db(mock_prisma))
    # Yield control so the flush task can swap the aggregator before we add again.
    await asyncio.sleep(0)
    await queue.add_state_delta("r1", "general", "gpt-5", 2.0, 0.0)
    await flush_task
    sizes = await queue.queue_size()
    # Only the delta added after the swap remains queued.
    assert sizes["state_pending"] == 1
@pytest.mark.asyncio
async def test_max_size_observability(queue):
    """The high-water mark reflects at least the three distinct keys queued."""
    for request_type in ("general", "writing", "code_generation"):
        await queue.add_state_delta("r1", request_type, "gpt-4", 1.0, 0.0)
    high_water = (await queue.queue_size())["max_state_seen"]
    assert high_water >= 3

View File

@ -0,0 +1,16 @@
[
{
"user_content": "what is the weather today in paris france",
"assistant_content": "It is sunny and warm in Paris today.",
"tool_calls": [],
"tool_results": [],
"response_status": 200
},
{
"user_content": "what is the weather today in paris france tomorrow",
"assistant_content": "Light rain is expected throughout the day.",
"tool_calls": [],
"tool_results": [],
"response_status": 200
}
]

View File

@ -0,0 +1,23 @@
[
{
"user_content": "how do I read a file in python",
"assistant_content": "Use the open() function with a context manager.",
"tool_calls": [],
"tool_results": [],
"response_status": 200
},
{
"user_content": "can you show an example",
"assistant_content": "with open('file.txt') as f: data = f.read()",
"tool_calls": [],
"tool_results": [],
"response_status": 200
},
{
"user_content": "thanks, that worked!",
"assistant_content": "Glad to hear it.",
"tool_calls": [],
"tool_results": [],
"response_status": 200
}
]

View File

@ -0,0 +1,16 @@
[
{
"user_content": "how do I install this package",
"assistant_content": "Run pip install <package_name>.",
"tool_calls": [],
"tool_results": [],
"response_status": 200
},
{
"user_content": "forget it, I'll do it myself",
"assistant_content": "Okay, let me know if you need anything else.",
"tool_calls": [],
"tool_results": [],
"response_status": 200
}
]

View File

@ -0,0 +1,9 @@
[
{
"user_content": "do the thing",
"assistant_content": null,
"tool_calls": [],
"tool_results": [],
"response_status": 429
}
]

View File

@ -0,0 +1,13 @@
[
{
"user_content": "summarize this giant document",
"assistant_content": null,
"tool_calls": [
{"id": "c1", "name": "summarize", "arguments": {"doc_id": "big"}}
],
"tool_results": [
{"tool_call_id": "c1", "content": "Error: context length exceeded for this model"}
],
"response_status": 200
}
]

View File

@ -0,0 +1,13 @@
[
{
"user_content": "read the config file",
"assistant_content": "Let me try.",
"tool_calls": [
{"id": "call_1", "name": "read_file", "arguments": {"path": "/etc/missing.conf"}}
],
"tool_results": [
{"tool_call_id": "call_1", "content": "ENOENT: no such file or directory", "is_error": true}
],
"response_status": 200
}
]

View File

@ -0,0 +1,35 @@
[
{
"user_content": null,
"assistant_content": null,
"tool_calls": [
{"id": "c1", "name": "read_file", "arguments": {"path": "/x"}}
],
"tool_results": [
{"tool_call_id": "c1", "content": "ok"}
],
"response_status": 200
},
{
"user_content": null,
"assistant_content": null,
"tool_calls": [
{"id": "c2", "name": "read_file", "arguments": {"path": "/x"}}
],
"tool_results": [
{"tool_call_id": "c2", "content": "ok"}
],
"response_status": 200
},
{
"user_content": null,
"assistant_content": null,
"tool_calls": [
{"id": "c3", "name": "read_file", "arguments": {"path": "/x"}}
],
"tool_results": [
{"tool_call_id": "c3", "content": "ok"}
],
"response_status": 200
}
]

View File

@ -0,0 +1,16 @@
[
{
"user_content": "can you help me write a function to parse json",
"assistant_content": "Sure, use the json module's loads function.",
"tool_calls": [],
"tool_results": [],
"response_status": 200
},
{
"user_content": "actually I need to parse yaml instead",
"assistant_content": "Use the pyyaml library and yaml.safe_load.",
"tool_calls": [],
"tool_results": [],
"response_status": 200
}
]

View File

@ -0,0 +1,31 @@
[
{
"user_content": "please read the config file",
"assistant_content": "Trying to read it now.",
"tool_calls": [
{"id": "c1", "name": "read_file", "arguments": {"path": "config.json"}}
],
"tool_results": [
{"tool_call_id": "c1", "content": "file not found", "is_error": true}
],
"response_status": 200
},
{
"user_content": "try config.yaml instead",
"assistant_content": "Here are the contents of config.yaml.",
"tool_calls": [
{"id": "c2", "name": "read_file", "arguments": {"path": "config.yaml"}}
],
"tool_results": [
{"tool_call_id": "c2", "content": "key: value"}
],
"response_status": 200
},
{
"user_content": "perfect, thanks!",
"assistant_content": "You're welcome.",
"tool_calls": [],
"tool_results": [],
"response_status": 200
}
]

View File

@ -0,0 +1,16 @@
[
{
"user_content": "explain this",
"assistant_content": "Here is the answer to your question. The capital of France is Paris.",
"tool_calls": [],
"tool_results": [],
"response_status": 200
},
{
"user_content": "explain this",
"assistant_content": "The answer to your question is that the capital of France is Paris.",
"tool_calls": [],
"tool_results": [],
"response_status": 200
}
]

View File

@ -0,0 +1,224 @@
"""Unit tests for the AdaptiveRouter strategy class."""
from unittest.mock import AsyncMock, MagicMock
from litellm.router_strategy.adaptive_router import adaptive_router as ar_module
import pytest
from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter
from litellm.router_strategy.adaptive_router.config import (
OWNER_CACHE_TTL_SECONDS,
)
from litellm.router_strategy.adaptive_router.signals import Turn
from litellm.types.router import (
AdaptiveRouterConfig,
AdaptiveRouterPreferences,
RequestType,
)
def _make_router() -> AdaptiveRouter:
    """Build a two-model router: cheap tier-1 ``fast`` and pricier tier-3 ``smart``.

    ``smart`` additionally declares code generation as a strength.
    """
    fast_prefs = AdaptiveRouterPreferences(quality_tier=1, strengths=[])
    smart_prefs = AdaptiveRouterPreferences(
        quality_tier=3, strengths=[RequestType.CODE_GENERATION]
    )
    return AdaptiveRouter(
        router_name="r1",
        config=AdaptiveRouterConfig(available_models=["fast", "smart"]),
        model_to_prefs={"fast": fast_prefs, "smart": smart_prefs},
        model_to_cost={"fast": 0.0001, "smart": 0.001},
    )
@pytest.mark.asyncio
async def test_pick_model_returns_model_from_available_list():
    """pick_model only ever returns one of the configured models."""
    router = _make_router()
    picked = await router.pick_model(RequestType.GENERAL)
    assert picked in {"fast", "smart"}
@pytest.mark.asyncio
async def test_pick_model_min_quality_tier_filter():
    """With min_quality_tier=3 every pick must be the tier-3 model."""
    router = _make_router()
    # min_tier=3 should leave only `smart` (tier 3); `fast` (tier 1) is filtered.
    # Repeat because selection is sampled.
    attempt = 0
    while attempt < 20:
        picked = await router.pick_model(RequestType.GENERAL, min_quality_tier=3)
        assert picked == "smart"
        attempt += 1
@pytest.mark.asyncio
async def test_pick_model_min_quality_tier_filter_raises_when_no_eligible():
    """A tier floor above every model's tier raises ValueError naming the floor."""
    r = _make_router()
    # The fixture's highest tier is 3, so 4 leaves no eligible model.
    with pytest.raises(ValueError, match="min_quality_tier=4"):
        await r.pick_model(RequestType.GENERAL, min_quality_tier=4)
@pytest.mark.asyncio
async def test_pick_model_is_stateless_no_owner_cache_writes():
    """pick_model must not touch the owner cache — that's gated post-call."""
    r = _make_router()
    # Several picks in a row; none of them may create an owner entry.
    for _ in range(5):
        await r.pick_model(RequestType.GENERAL)
    assert r._owner_cache == {}
# ---- claim_or_check_owner -----------------------------------------------
def test_claim_or_check_owner_first_call_claims_and_returns_true(monkeypatch):
    """The first claim for a session records (model, now + TTL) and succeeds."""
    r = _make_router()
    # Freeze the module's clock so the stored expiry is predictable.
    monkeypatch.setattr(ar_module.time, "time", lambda: 1_000.0)
    assert r.claim_or_check_owner("sess-A", "fast") is True
    assert r._owner_cache["sess-A"] == ("fast", 1_000.0 + OWNER_CACHE_TTL_SECONDS)
    assert r._skipped_updates_total == 0
def test_claim_or_check_owner_same_model_returns_true_without_extending_ttl(
    monkeypatch,
):
    """A repeat claim by the owning model succeeds and keeps the original expiry."""
    r = _make_router()
    monkeypatch.setattr(ar_module.time, "time", lambda: 1_000.0)
    r.claim_or_check_owner("sess-A", "fast")
    original_expiry = r._owner_cache["sess-A"][1]
    # Advance the clock, then claim again with the same model.
    monkeypatch.setattr(ar_module.time, "time", lambda: 1_500.0)
    assert r.claim_or_check_owner("sess-A", "fast") is True
    # No extension on hit — owner cache snapshots the first claim.
    assert r._owner_cache["sess-A"][1] == original_expiry
def test_claim_or_check_owner_mismatch_skips_and_increments_counter(monkeypatch):
    """A claim by a different model is refused, counted, and leaves the owner as is."""
    r = _make_router()
    monkeypatch.setattr(ar_module.time, "time", lambda: 1_000.0)
    r.claim_or_check_owner("sess-A", "fast")
    # Same frozen instant, so the first claim has not expired.
    assert r.claim_or_check_owner("sess-A", "smart") is False
    assert r._skipped_updates_total == 1
    # Owner unchanged.
    assert r._owner_cache["sess-A"][0] == "fast"
def test_claim_or_check_owner_expired_owner_reclaims_for_new_model(monkeypatch):
    """Once the owner entry has expired, a different model may take the session."""
    r = _make_router()
    monkeypatch.setattr(ar_module.time, "time", lambda: 1_000.0)
    r.claim_or_check_owner("sess-A", "fast")
    # Move the clock one second past the stored expiry.
    monkeypatch.setattr(
        ar_module.time, "time", lambda: 1_000.0 + OWNER_CACHE_TTL_SECONDS + 1
    )
    assert r.claim_or_check_owner("sess-A", "smart") is True
    assert r._owner_cache["sess-A"][0] == "smart"
    # Reclaim isn't a skip.
    assert r._skipped_updates_total == 0
# ---- record_turn --------------------------------------------------------
@pytest.mark.asyncio
async def test_record_turn_pushes_to_queue():
    """A satisfied turn enqueues one session write and one state delta."""
    router = _make_router()
    router.queue.add_session_state = AsyncMock()
    router.queue.add_state_delta = AsyncMock()

    satisfied_turn = Turn(user_content="thanks, that worked", assistant_content="ok")
    await router.record_turn(
        session_id="s1",
        model_name="fast",
        request_type=RequestType.GENERAL,
        turn=satisfied_turn,
    )

    router.queue.add_session_state.assert_awaited_once()
    # satisfaction fired -> alpha delta -> add_state_delta called
    router.queue.add_state_delta.assert_awaited_once()
@pytest.mark.asyncio
async def test_record_turn_satisfaction_increments_alpha():
    """A satisfaction signal adds 1.0 to alpha and leaves beta untouched."""
    router = _make_router()
    cell_key = (RequestType.GENERAL, "fast")
    before = router._cells[cell_key]

    await router.record_turn(
        session_id="sX",
        model_name="fast",
        request_type=RequestType.GENERAL,
        turn=Turn(user_content="that worked, thanks!"),
    )

    after = router._cells[cell_key]
    assert after.alpha == pytest.approx(before.alpha + 1.0)
    assert after.beta == pytest.approx(before.beta)
@pytest.mark.asyncio
async def test_record_turn_failure_increments_beta():
    """A tool-error turn adds 1.0 to beta and leaves alpha untouched."""
    router = _make_router()
    cell_key = (RequestType.GENERAL, "smart")
    before = router._cells[cell_key]

    failing_turn = Turn(
        user_content="please run the tool",
        tool_results=[{"is_error": True, "content": "boom"}],
    )
    await router.record_turn(
        session_id="sY",
        model_name="smart",
        request_type=RequestType.GENERAL,
        turn=failing_turn,
    )

    after = router._cells[cell_key]
    assert after.beta == pytest.approx(before.beta + 1.0)
    assert after.alpha == pytest.approx(before.alpha)
@pytest.mark.asyncio
async def test_load_state_from_db_overrides_cold_start():
    """A persisted row replaces the cold-start prior for its cell."""
    router = _make_router()
    cell_key = (RequestType.GENERAL, "fast")
    cold = router._cells[cell_key]

    persisted_row = MagicMock()
    persisted_row.request_type = "general"
    persisted_row.model_name = "fast"
    persisted_row.alpha = 42.0
    persisted_row.beta = 13.0

    prisma = MagicMock()
    prisma.db.litellm_adaptiverouterstate.find_many = AsyncMock(
        return_value=[persisted_row]
    )
    await router.load_state_from_db(prisma)

    loaded = router._cells[cell_key]
    assert (loaded.alpha, loaded.beta) == (42.0, 13.0)
    assert (loaded.alpha, loaded.beta) != (cold.alpha, cold.beta)
@pytest.mark.asyncio
async def test_load_state_from_db_handles_unknown_request_type():
    """A row with an unrecognised request_type is skipped without side effects.

    The valid row loaded alongside it must still be applied, and a cell that
    no row refers to must keep its cold-start values.
    """
    r = _make_router()
    # Snapshot the untouched cell as plain numbers *before* loading, so the
    # check below compares WRITING against WRITING.
    cold_writing = r._cells[(RequestType.WRITING, "fast")]
    cold_writing_params = (cold_writing.alpha, cold_writing.beta)

    bad_row = MagicMock()
    bad_row.request_type = "nonexistent_type_v999"
    bad_row.model_name = "fast"
    bad_row.alpha = 999.0
    bad_row.beta = 999.0
    good_row = MagicMock()
    good_row.request_type = "general"
    good_row.model_name = "fast"
    good_row.alpha = 7.0
    good_row.beta = 3.0
    prisma = MagicMock()
    prisma.db.litellm_adaptiverouterstate.find_many = AsyncMock(
        return_value=[bad_row, good_row]
    )
    await r.load_state_from_db(prisma)

    # Unknown skipped; good applied.
    general = r._cells[(RequestType.GENERAL, "fast")]
    assert (general.alpha, general.beta) == (7.0, 3.0)
    # Other request types kept their cold-start values.
    writing = r._cells[(RequestType.WRITING, "fast")]
    assert (writing.alpha, writing.beta) == cold_writing_params

View File

@ -0,0 +1,137 @@
"""Direct unit tests for AdaptiveRouter.async_pre_routing_hook.
The strategy method (newly extracted from `Router.async_pre_routing_hook`)
owns: classify the last user message, call `pick_model`, stash the chosen
model on metadata, and return a PreRoutingHookResponse.
Routing is stateless per turn: `pick_model` does not take a session id.
"""
from unittest.mock import AsyncMock
import pytest
from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter
from litellm.types.router import (
AdaptiveRouterConfig,
PreRoutingHookResponse,
RequestType,
)
def _make_router() -> AdaptiveRouter:
    """Build a two-model router with no declared per-model preferences."""
    config = AdaptiveRouterConfig(available_models=["fast", "smart"])
    per_token_cost = {"fast": 0.00000015, "smart": 0.0000050}
    return AdaptiveRouter(
        router_name="smart-cheap-router",
        config=config,
        model_to_prefs={},
        model_to_cost=per_token_cost,
    )
@pytest.mark.asyncio
async def test_returns_pre_routing_hook_response_with_chosen_model():
    """The hook wraps pick_model's choice in a PreRoutingHookResponse."""
    router = _make_router()
    router.pick_model = AsyncMock(return_value="smart")  # type: ignore[method-assign]

    result = await router.async_pre_routing_hook(
        model="smart-cheap-router",
        request_kwargs={},
        messages=[{"role": "user", "content": "hello"}],
    )

    assert isinstance(result, PreRoutingHookResponse)
    assert result.model == "smart"
@pytest.mark.asyncio
async def test_classifies_last_user_message_for_request_type():
    """A coding prompt reaches pick_model classified as CODE_GENERATION."""
    router = _make_router()
    router.pick_model = AsyncMock(return_value="smart")  # type: ignore[method-assign]

    coding_prompt = {"role": "user", "content": "Write a Python function for fizzbuzz"}
    await router.async_pre_routing_hook(
        model="smart-cheap-router",
        request_kwargs={},
        messages=[coding_prompt],
    )

    passed_type = router.pick_model.await_args.kwargs["request_type"]  # type: ignore[union-attr]
    assert passed_type == RequestType.CODE_GENERATION
@pytest.mark.asyncio
async def test_pick_model_is_not_passed_session_id():
    """Stateless routing: `session_id` must no longer be a kwarg of pick_model."""
    router = _make_router()
    router.pick_model = AsyncMock(return_value="fast")  # type: ignore[method-assign]

    # A session id is present on the request, yet must not be forwarded.
    await router.async_pre_routing_hook(
        model="smart-cheap-router",
        request_kwargs={"metadata": {"litellm_session_id": "sess-A"}},
        messages=[{"role": "user", "content": "hi"}],
    )

    forwarded_kwargs = router.pick_model.await_args.kwargs  # type: ignore[union-attr]
    assert "session_id" not in forwarded_kwargs
@pytest.mark.asyncio
async def test_stashes_chosen_model_in_existing_metadata():
    """The chosen model is added to caller metadata without dropping existing keys."""
    router = _make_router()
    router.pick_model = AsyncMock(return_value="smart")  # type: ignore[method-assign]
    request_kwargs: dict = {"metadata": {"litellm_session_id": "sess-A"}}

    await router.async_pre_routing_hook(
        model="smart-cheap-router",
        request_kwargs=request_kwargs,
        messages=[{"role": "user", "content": "hi"}],
    )

    metadata = request_kwargs["metadata"]
    assert metadata["adaptive_router_chosen_model"] == "smart"
    assert metadata["litellm_session_id"] == "sess-A"
@pytest.mark.asyncio
async def test_creates_metadata_dict_when_missing():
    """When the request has no metadata, the hook creates it to hold the choice."""
    router = _make_router()
    router.pick_model = AsyncMock(return_value="fast")  # type: ignore[method-assign]
    request_kwargs: dict = {}

    await router.async_pre_routing_hook(
        model="smart-cheap-router",
        request_kwargs=request_kwargs,
        messages=[{"role": "user", "content": "hi"}],
    )

    stashed = request_kwargs["metadata"]["adaptive_router_chosen_model"]
    assert stashed == "fast"
@pytest.mark.asyncio
async def test_handles_empty_messages():
    """With ``messages=None`` the hook still picks a model exactly once."""
    router = _make_router()
    router.pick_model = AsyncMock(return_value="fast")  # type: ignore[method-assign]

    result = await router.async_pre_routing_hook(
        model="smart-cheap-router",
        request_kwargs={},
        messages=None,
    )

    assert isinstance(result, PreRoutingHookResponse)
    assert result.model == "fast"
    router.pick_model.assert_awaited_once()  # type: ignore[union-attr]
@pytest.mark.asyncio
async def test_returns_messages_unchanged_in_response():
    """The response carries the same messages that were passed in."""
    router = _make_router()
    router.pick_model = AsyncMock(return_value="smart")  # type: ignore[method-assign]
    original_messages = [{"role": "user", "content": "hi"}]

    result = await router.async_pre_routing_hook(
        model="smart-cheap-router",
        request_kwargs={},
        messages=original_messages,
    )

    assert result.messages == original_messages

View File

@ -0,0 +1,134 @@
import random
import pytest
from litellm.router_strategy.adaptive_router.bandit import (
BanditCell,
apply_delta,
initial_cell,
normalized_cost,
pick_best,
score,
thompson_sample,
)
from litellm.router_strategy.adaptive_router.config import (
BASE_TIER_WEIGHT,
COLD_START_MASS,
SAMPLE_CAP,
STRENGTH_BONUS,
)
from litellm.types.router import AdaptiveRouterPreferences, RequestType
def test_initial_cell_tier_only():
    """With no strengths, the prior mean is the tier's base weight.

    The prior's total mass (alpha + beta) equals COLD_START_MASS.
    """
    tier_two = AdaptiveRouterPreferences(quality_tier=2, strengths=[])
    cell = initial_cell(tier_two, RequestType.GENERAL)
    mean_error = abs(cell.mean - BASE_TIER_WEIGHT[2])
    mass_error = abs(cell.alpha + cell.beta - COLD_START_MASS)
    assert mean_error < 0.001
    assert mass_error < 0.001
def test_initial_cell_with_matching_strength():
    """A declared strength adds STRENGTH_BONUS to the prior mean for that type."""
    coder = AdaptiveRouterPreferences(
        quality_tier=2, strengths=[RequestType.CODE_GENERATION]
    )
    cell = initial_cell(coder, RequestType.CODE_GENERATION)
    boosted_mean = BASE_TIER_WEIGHT[2] + STRENGTH_BONUS
    assert abs(cell.mean - boosted_mean) < 0.001
def test_initial_cell_strength_does_not_apply_to_other_types():
    """A CODE_GENERATION strength gives no bonus when the request type is WRITING."""
    preferences = AdaptiveRouterPreferences(
        quality_tier=2,
        strengths=[RequestType.CODE_GENERATION],
    )
    prior = initial_cell(preferences, RequestType.WRITING)
    deviation = abs(prior.mean - BASE_TIER_WEIGHT[2])
    assert deviation < 0.001
def test_initial_cell_caps_mean_at_0_95():
    """Even a tier-3 model with a matching strength never starts above a 0.95 mean."""
    request_type = RequestType.CODE_GENERATION
    top_tier = AdaptiveRouterPreferences(quality_tier=3, strengths=[request_type])
    prior = initial_cell(top_tier, request_type)
    assert prior.mean <= 0.95
def test_apply_delta_increments_alpha_and_beta():
    """A (1.0, 0.0) delta raises alpha by one and leaves beta untouched."""
    before = BanditCell(alpha=5.0, beta=5.0)
    after = apply_delta(before, 1.0, 0.0)
    assert after.alpha == 6.0
    assert after.beta == 5.0
def test_apply_delta_respects_sample_cap():
    """A cell whose alpha + beta already equals SAMPLE_CAP is returned unchanged."""
    saturated = BanditCell(alpha=SAMPLE_CAP - 1.0, beta=1.0)
    result = apply_delta(saturated, 5.0, 5.0)
    assert result.alpha == saturated.alpha
    assert result.beta == saturated.beta
def test_thompson_sample_in_range():
    """One hundred seeded draws from a Beta(10, 5) cell all fall inside [0, 1]."""
    cell = BanditCell(alpha=10.0, beta=5.0)
    seeded = random.Random(42)
    draws = [thompson_sample(cell, rng=seeded) for _ in range(100)]
    assert all(0.0 <= draw <= 1.0 for draw in draws)
def test_normalized_cost_cheapest_wins():
    """The cheapest price normalizes to 1.0 and the most expensive to 0.0."""
    prices = (0.001, 0.005, 0.01)
    cheapest_score = normalized_cost(0.001, list(prices))
    assert cheapest_score == 1.0
    priciest_score = normalized_cost(0.01, list(prices))
    assert priciest_score == 0.0
def test_normalized_cost_no_spread():
    """When every price is identical the normalized cost is the neutral 0.5."""
    flat_prices = [0.005, 0.005]
    assert normalized_cost(0.005, flat_prices) == 0.5
def test_normalized_cost_empty_list():
    """With no prices to compare against the normalized cost is the neutral 0.5."""
    no_prices: list = []
    assert normalized_cost(0.005, no_prices) == 0.5
def test_score_combines_quality_and_cost():
    """A perfect quality sample on the cheapest model scores ~1.0 under 0.7/0.3 weights."""
    inputs = {
        "quality_sample": 1.0,
        "model_cost": 0.001,
        "all_costs": [0.001, 0.01],
        "quality_weight": 0.7,
        "cost_weight": 0.3,
    }
    combined = score(**inputs)
    assert abs(combined - 1.0) < 0.001
def test_pick_best_empty_dict_raises():
    """Picking from an empty set of cells raises ValueError."""
    no_cells: dict = {}
    no_costs: dict = {}
    with pytest.raises(ValueError):
        pick_best(no_cells, no_costs)
def test_thompson_converges_to_better_model():
    """
    LOAD-BEARING TEST. If this regresses, the whole router is broken.

    Two models share the same prior and the same cost, but A succeeds with
    probability 0.8 and B with probability 0.3. After 200 simulated turns,
    A must account for at least 80% of the final 50 picks.
    """
    rng = random.Random(42)
    success_rate = {"A": 0.8, "B": 0.3}
    prices = {"A": 0.001, "B": 0.001}
    posterior = {name: BanditCell(alpha=5.0, beta=5.0) for name in ("A", "B")}

    history = []
    for _ in range(200):
        winner = pick_best(posterior, prices, rng=rng)
        history.append(winner)
        # Bernoulli outcome drawn from the same seeded RNG the sampler uses.
        succeeded = rng.random() < success_rate[winner]
        reward = 1.0 if succeeded else 0.0
        posterior[winner] = apply_delta(posterior[winner], reward, 1.0 - reward)

    tail = history[-50:]
    a_share = tail.count("A") / 50
    assert (
        a_share >= 0.80
    ), f"Expected A to dominate ({a_share=}); priors aren't biasing the sample correctly"

View File

@ -0,0 +1,116 @@
import pytest
from litellm.router_strategy.adaptive_router.classifier import classify_prompt
from litellm.types.router import RequestType
@pytest.mark.parametrize(
    "prompt",
    [
        "Write a Python function that reverses a linked list",
        "Implement a REST API endpoint for user signup",
        "Create a bash script to back up my postgres database",
    ],
)
def test_classify_code_generation(prompt):
    """Requests to write, implement or create code classify as CODE_GENERATION."""
    predicted = classify_prompt(prompt)
    assert predicted == RequestType.CODE_GENERATION
@pytest.mark.parametrize(
    "prompt",
    [
        "Explain what this function does: def foo(): ...",
        "Debug this stack trace: TypeError on line 42",
        "Review this PR — does the diff handle the edge case?",
    ],
)
def test_classify_code_understanding(prompt):
    """Requests to explain, debug or review code classify as CODE_UNDERSTANDING."""
    predicted = classify_prompt(prompt)
    assert predicted == RequestType.CODE_UNDERSTANDING
@pytest.mark.parametrize(
    "prompt",
    [
        "Design a microservice architecture for an event-driven system",
        "Should I use PostgreSQL or DynamoDB for high-write workloads?",
        "How should I structure my Django app for multi-tenancy?",
    ],
)
def test_classify_technical_design(prompt):
    """Architecture and technology-choice questions classify as TECHNICAL_DESIGN."""
    predicted = classify_prompt(prompt)
    assert predicted == RequestType.TECHNICAL_DESIGN
@pytest.mark.parametrize(
    "prompt",
    [
        "Solve the integral of x^2 from 0 to 5",
        "If A implies B and B implies C, then prove A implies C",
        "Calculate the probability of two heads in three coin flips",
    ],
)
def test_classify_analytical_reasoning(prompt):
    """Math, logic and probability problems classify as ANALYTICAL_REASONING."""
    predicted = classify_prompt(prompt)
    assert predicted == RequestType.ANALYTICAL_REASONING
@pytest.mark.parametrize(
    "prompt",
    [
        "Draft an email to my team announcing the launch",
        "Rewrite this paragraph to be more concise and professional",
        "Proofread my blog post for grammar and tone",
    ],
)
def test_classify_writing(prompt):
    """Drafting, rewriting and proofreading requests classify as WRITING."""
    predicted = classify_prompt(prompt)
    assert predicted == RequestType.WRITING
@pytest.mark.parametrize(
    "prompt",
    [
        "Who is the current president of France?",
        "What is the capital of Australia?",
        "Define photosynthesis",
    ],
)
def test_classify_factual_lookup(prompt):
    """Short who/what/define questions classify as FACTUAL_LOOKUP."""
    predicted = classify_prompt(prompt)
    assert predicted == RequestType.FACTUAL_LOOKUP
@pytest.mark.parametrize(
    "prompt",
    [
        "hello",
        "tell me about your day",
        "interesting",
    ],
)
def test_classify_general_fallback(prompt):
    """Small talk that matches no specific category falls back to GENERAL."""
    predicted = classify_prompt(prompt)
    assert predicted == RequestType.GENERAL
def test_classify_empty_string():
    """The empty prompt classifies as GENERAL."""
    predicted = classify_prompt("")
    assert predicted == RequestType.GENERAL
def test_classify_whitespace_only():
    """A prompt made only of spaces, newlines and tabs classifies as GENERAL."""
    blank_prompt = " \n\t "
    assert classify_prompt(blank_prompt) == RequestType.GENERAL
def test_classify_truncates_very_long_input():
    """A leading factual question wins over a code request buried after ~10k filler chars.

    Presumably this holds because the classifier only inspects a leading
    prefix of the prompt -- confirm against ``classify_prompt``.
    """
    pieces = [
        "Who is the current president of France? ",
        "x " * 5000,
        " Write a Python function",
    ]
    long_prompt = "".join(pieces)
    assert classify_prompt(long_prompt) == RequestType.FACTUAL_LOOKUP
def test_classify_is_deterministic():
    """Ten classifications of the same prompt all produce the same label."""
    prompt = "Implement a REST API endpoint for user signup"
    labels = [classify_prompt(prompt) for _ in range(10)]
    assert len(set(labels)) == 1
def test_classify_returns_request_type_enum():
    """The classifier returns a RequestType member."""
    label = classify_prompt("hello")
    assert isinstance(label, RequestType)

View File

@ -0,0 +1,55 @@
import pytest
from pydantic import ValidationError
from litellm.types.router import (
AdaptiveRouterConfig,
AdaptiveRouterPreferences,
AdaptiveRouterWeights, # noqa: F401 # imported per spec, exercised transitively
RequestType,
)
def test_config_loads_valid_yaml():
    """A config built from plain values exposes its models and its weights as attributes."""
    raw_weights = {"quality": 0.7, "cost": 0.3}
    cfg = AdaptiveRouterConfig(
        available_models=["gpt-4o-mini", "gpt-4o"],
        weights=raw_weights,
    )

    assert cfg.available_models == ["gpt-4o-mini", "gpt-4o"]
    parsed = cfg.weights
    assert parsed.quality == 0.7
    assert parsed.cost == 0.3
    assert abs(parsed.quality + parsed.cost - 1.0) < 0.001
def test_config_rejects_misspelled_strength():
    """A strength that is not a known request type fails validation."""
    typo = "code_genertion"
    with pytest.raises(ValidationError):
        AdaptiveRouterPreferences(quality_tier=2, strengths=[typo])
def test_config_weights_must_sum_to_one():
    """Weights totalling 1.4 are rejected with a 'weights must sum to 1' error."""
    unbalanced = {"quality": 0.9, "cost": 0.5}
    with pytest.raises(ValidationError, match="weights must sum to 1"):
        AdaptiveRouterConfig(available_models=["a", "b"], weights=unbalanced)
def test_config_quality_tier_must_be_1_2_or_3():
    """Quality tiers outside 1..3 (here 5 and 0) fail validation."""
    for invalid_tier in (5, 0):
        with pytest.raises(ValidationError):
            AdaptiveRouterPreferences(quality_tier=invalid_tier, strengths=[])
def test_config_accepts_all_six_request_types_in_strengths():
    """All six non-general request types can be listed as strengths at once."""
    every_strength = [
        RequestType.CODE_GENERATION,
        RequestType.CODE_UNDERSTANDING,
        RequestType.TECHNICAL_DESIGN,
        RequestType.ANALYTICAL_REASONING,
        RequestType.WRITING,
        RequestType.FACTUAL_LOOKUP,
    ]
    prefs = AdaptiveRouterPreferences(quality_tier=3, strengths=every_strength)
    assert len(prefs.strengths) == 6

View File

@ -0,0 +1,263 @@
"""
End-to-end tests for the adaptive router. Wires the real strategy + queue + hook
with a mocked Prisma client. No live proxy or DB required.
What we cover:
1. Full lifecycle: pick -> record turn(s) -> flush -> DB upsert with correct deltas
2. Owner cache pins attribution: same key + matching model -> updates flow
3. Convergence in-process: 50 simulated sessions, "good" model dominates last 10
4. Cold-start state load from DB overrides priors
5. Failure signal increments beta in the next flush
6. Unknown request types in DB rows are silently skipped
7. Flush isolates writes per (router, session, model) tuple
"""
import random
from unittest.mock import AsyncMock, MagicMock
import pytest
from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter
from litellm.router_strategy.adaptive_router.signals import Turn
from litellm.types.router import (
AdaptiveRouterConfig,
AdaptiveRouterPreferences,
AdaptiveRouterWeights,
RequestType,
)
def _make_router(
    available=("gpt-4o-mini", "gpt-4o"),
    prefs=None,
    costs=None,
):
    """Build an AdaptiveRouter named "test-router" with 0.7/0.3 quality/cost weights.

    ``prefs`` and ``costs`` default to a two-model setup: a tier-2
    "gpt-4o-mini" with no strengths and a tier-3 "gpt-4o" that is strong at
    code generation, priced at 0.15 and 5.0 respectively.
    """
    model_prefs = prefs
    if model_prefs is None:
        mini_prefs = AdaptiveRouterPreferences(quality_tier=2, strengths=[])
        full_prefs = AdaptiveRouterPreferences(
            quality_tier=3, strengths=[RequestType.CODE_GENERATION]
        )
        model_prefs = {"gpt-4o-mini": mini_prefs, "gpt-4o": full_prefs}

    model_costs = costs
    if model_costs is None:
        model_costs = {"gpt-4o-mini": 0.15, "gpt-4o": 5.0}

    router_config = AdaptiveRouterConfig(
        available_models=list(available),
        weights=AdaptiveRouterWeights(quality=0.7, cost=0.3),
    )
    return AdaptiveRouter(
        router_name="test-router",
        config=router_config,
        model_to_prefs=model_prefs,
        model_to_cost=model_costs,
    )
def _make_mock_prisma():
    """Return a MagicMock Prisma client whose adaptive-router table methods are AsyncMocks.

    State lookups come back empty (``find_unique`` -> None, ``find_many`` -> []).
    """
    client = MagicMock()

    state_table = client.db.litellm_adaptiverouterstate
    state_table.find_unique = AsyncMock(return_value=None)
    state_table.find_many = AsyncMock(return_value=[])
    state_table.upsert = AsyncMock()

    session_table = client.db.litellm_adaptiveroutersession
    session_table.upsert = AsyncMock()
    return client
@pytest.mark.asyncio
async def test_pick_record_flush_full_cycle():
    """Pick a model, record one satisfied turn, flush, and inspect both upsert payloads."""
    router = _make_router()
    request_type = RequestType.CODE_GENERATION

    chosen = await router.pick_model(request_type)
    assert chosen in router.config.available_models

    satisfied_turn = Turn(user_content="thanks, that worked!", assistant_content="ok")
    await router.record_turn(
        session_id="s1",
        model_name=chosen,
        request_type=request_type,
        turn=satisfied_turn,
    )

    prisma = _make_mock_prisma()
    flushed_state = await router.queue.flush_state_to_db(prisma)
    flushed_sessions = await router.queue.flush_session_to_db(prisma)
    assert flushed_state == 1
    assert flushed_sessions == 1

    # Satisfaction signal -> +1 alpha; with no existing row the delta lands in `create`.
    state_upsert = prisma.db.litellm_adaptiverouterstate.upsert.call_args
    state_create = state_upsert.kwargs["data"]["create"]
    assert state_create["alpha"] >= 1.0
    assert state_create["beta"] == 0.0
    assert state_create["total_samples"] == 1

    session_upsert = prisma.db.litellm_adaptiveroutersession.upsert.call_args
    session_create = session_upsert.kwargs["data"]["create"]
    assert session_create["satisfaction_count"] == 1
    assert session_create["session_id"] == "s1"
    assert session_create["model_name"] == chosen
@pytest.mark.asyncio
async def test_owner_cache_pins_attribution_to_first_picked_model():
    """The first claim owns the session; the same model keeps matching, another is rejected."""
    session = "sess-own"
    router = _make_router()
    owner = await router.pick_model(RequestType.GENERAL)

    # Initial claim.
    assert router.claim_or_check_owner(session, owner) is True

    # Later turns with the owning model keep attributing.
    repeat_checks = [router.claim_or_check_owner(session, owner) for _ in range(5)]
    assert all(result is True for result in repeat_checks)

    # A later turn served by the other model is rejected and counted as skipped.
    if owner == "gpt-4o-mini":
        intruder = "gpt-4o"
    else:
        intruder = "gpt-4o-mini"
    assert router.claim_or_check_owner(session, intruder) is False
    assert router._skipped_updates_total == 1
@pytest.mark.asyncio
async def test_pick_model_returns_valid_models_without_error():
    """Ten consecutive picks all come from the configured model list.

    Thompson sampling is stochastic, so individual picks may differ between
    calls; only validity is checked.
    """
    router = _make_router()
    allowed = router.config.available_models
    for _ in range(10):
        picked = await router.pick_model(RequestType.GENERAL)
        assert picked in allowed
@pytest.mark.asyncio
async def test_in_process_convergence_high_quality_model_dominates():
    """
    Two models, identical cost. "good" satisfies every turn, "bad" fails every turn.
    After 50 sessions of 4 turns each, "good" should win >=70% of the last 10 picks.

    `random` is seeded for determinism since pick_best uses the module-level RNG.
    The previous global RNG state is restored afterwards so that seeding here
    does not leak into, and silently couple, other tests in the same process.
    """
    saved_rng_state = random.getstate()
    random.seed(42)
    try:
        router = _make_router(
            available=("good", "bad"),
            prefs={
                "good": AdaptiveRouterPreferences(quality_tier=2, strengths=[]),
                "bad": AdaptiveRouterPreferences(quality_tier=2, strengths=[]),
            },
            costs={"good": 1.0, "bad": 1.0},
        )
        picks = []
        for sess in range(50):
            sid = f"conv-{sess}"
            chosen = await router.pick_model(RequestType.GENERAL)
            for _turn_i in range(4):
                if chosen == "good":
                    # Explicit thanks from the user -> satisfaction signal.
                    turn = Turn(user_content="thanks!", assistant_content="ok")
                else:
                    # A tool call whose result is an error -> failure signal.
                    turn = Turn(
                        tool_calls=[{"name": "x", "arguments": {}}],
                        tool_results=[{"is_error": True, "content": "boom"}],
                    )
                await router.record_turn(sid, chosen, RequestType.GENERAL, turn)
            picks.append(chosen)
    finally:
        # Always undo the global reseed, even if the simulation raised.
        random.setstate(saved_rng_state)

    last_10 = picks[-10:]
    good_share = last_10.count("good") / 10
    assert good_share >= 0.7, f"good_share={good_share} (last picks={picks})"
@pytest.mark.asyncio
async def test_failure_signal_increments_beta_after_flush():
    """A turn with an erroring tool result flushes as a beta-only delta."""
    router = _make_router(
        available=("only",),
        prefs={"only": AdaptiveRouterPreferences(quality_tier=2, strengths=[])},
        costs={"only": 1.0},
    )
    picked = await router.pick_model(RequestType.GENERAL)
    assert picked == "only"

    failing_turn = Turn(
        tool_calls=[{"name": "x", "arguments": {}}],
        tool_results=[{"is_error": True, "content": ""}],
    )
    await router.record_turn(
        session_id="f1",
        model_name=picked,
        request_type=RequestType.GENERAL,
        turn=failing_turn,
    )

    prisma = _make_mock_prisma()
    rows_written = await router.queue.flush_state_to_db(prisma)
    assert rows_written == 1

    upsert_call = prisma.db.litellm_adaptiverouterstate.upsert.call_args
    created = upsert_call.kwargs["data"]["create"]
    assert created["beta"] >= 1.0
    assert created["alpha"] == 0.0
@pytest.mark.asyncio
async def test_load_state_from_db_overrides_cold_start():
    """A persisted (GENERAL, gpt-4o) row replaces the cold-start prior for that cell."""
    router = _make_router()
    stored_row = MagicMock(
        request_type=RequestType.GENERAL.value,
        model_name="gpt-4o",
        alpha=90.0,
        beta=10.0,
    )
    prisma = _make_mock_prisma()
    prisma.db.litellm_adaptiverouterstate.find_many = AsyncMock(
        return_value=[stored_row]
    )

    await router.load_state_from_db(prisma)

    loaded = router._cells[(RequestType.GENERAL, "gpt-4o")]
    assert loaded.alpha == 90.0
    assert loaded.beta == 10.0
@pytest.mark.asyncio
async def test_load_state_from_db_handles_unknown_request_type():
    """A row with an unrecognized request type is ignored; cold-start cells survive."""
    router = _make_router()
    stale_row = MagicMock(
        request_type="unknown_v1_type",
        model_name="gpt-4o",
        alpha=50.0,
        beta=50.0,
    )
    prisma = _make_mock_prisma()
    prisma.db.litellm_adaptiverouterstate.find_many = AsyncMock(
        return_value=[stale_row]
    )

    # Loading must not raise on the bad row.
    await router.load_state_from_db(prisma)

    untouched = router._cells[(RequestType.GENERAL, "gpt-4o")]
    # Cold-start: tier 3 base = 0.7, mass = 10 -> alpha = 7, beta = 3
    assert untouched.alpha == pytest.approx(7.0)
    assert untouched.beta == pytest.approx(3.0)
@pytest.mark.asyncio
async def test_flush_isolates_writes_per_router_session_model():
    """Two sessions on two different models flush as two session rows and two state rows."""
    router = _make_router()
    for session_id, model in (("s1", "gpt-4o"), ("s2", "gpt-4o-mini")):
        await router.record_turn(
            session_id, model, RequestType.GENERAL, Turn(user_content="thanks!")
        )

    prisma = _make_mock_prisma()

    session_rows = await router.queue.flush_session_to_db(prisma)
    assert session_rows == 2
    assert prisma.db.litellm_adaptiveroutersession.upsert.call_count == 2

    state_rows = await router.queue.flush_state_to_db(prisma)
    assert state_rows == 2
    assert prisma.db.litellm_adaptiverouterstate.upsert.call_count == 2
@pytest.mark.asyncio
async def test_repeated_flush_drains_queue_and_subsequent_flush_is_noop():
    """The first flush empties the queue, so a second flush writes nothing more."""
    router = _make_router()
    picked = await router.pick_model(RequestType.GENERAL)
    await router.record_turn(
        "drain-1", picked, RequestType.GENERAL, Turn(user_content="thanks!")
    )
    prisma = _make_mock_prisma()
    queue = router.queue

    # First drain writes one state row and one session row.
    first_state = await queue.flush_state_to_db(prisma)
    assert first_state == 1
    first_session = await queue.flush_session_to_db(prisma)
    assert first_session == 1

    # Second drain finds an empty queue.
    second_state = await queue.flush_state_to_db(prisma)
    assert second_state == 0
    second_session = await queue.flush_session_to_db(prisma)
    assert second_session == 0

    # No extra upserts were issued by the second drain.
    assert prisma.db.litellm_adaptiverouterstate.upsert.call_count == 1
    assert prisma.db.litellm_adaptiveroutersession.upsert.call_count == 1

View File

@ -0,0 +1,329 @@
"""Unit tests for the AdaptiveRouterPostCallHook."""
from unittest.mock import AsyncMock, MagicMock
import pytest
from litellm.router_strategy.adaptive_router.config import (
ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY,
SIGNAL_GATE_MIN_MESSAGES,
)
from litellm.router_strategy.adaptive_router.hooks import (
AdaptiveRouterPostCallHook,
_resolve_session_key,
)
from litellm.router_strategy.adaptive_router.signals import Turn
def _make_hook(claim: bool = True) -> AdaptiveRouterPostCallHook:
    """Build a post-call hook around a mocked adaptive router.

    ``claim`` is the value the mocked ``claim_or_check_owner`` returns;
    ``record_turn`` is an AsyncMock so awaits can be inspected.
    """
    router_stub = MagicMock()
    router_stub.record_turn = AsyncMock()
    router_stub.claim_or_check_owner = MagicMock(return_value=claim)
    hook = AdaptiveRouterPostCallHook(adaptive_router=router_stub)
    return hook
def _resp_with_content(text: str, tool_calls=None):
    """Return a mock response exposing ``choices[0].message`` with the given content.

    ``tool_calls`` falls back to an empty list when omitted or falsy.
    """
    assistant_message = MagicMock()
    assistant_message.content = text
    assistant_message.tool_calls = tool_calls if tool_calls else []

    only_choice = MagicMock()
    only_choice.message = assistant_message

    response = MagicMock()
    response.choices = [only_choice]
    return response
def _long_messages(user_text: str = "ask"):
    """Return a conversation of at least SIGNAL_GATE_MIN_MESSAGES messages ending in *user_text*.

    The list always starts with a fixed user/assistant/user exchange and
    always ends with a user message carrying ``user_text``. If the gate
    threshold is larger than that, "filler" user messages are inserted
    before the final message so the last user message stays ``user_text``.
    """
    history = [
        {"role": "user", "content": "first turn"},
        {"role": "assistant", "content": "first reply"},
        {"role": "user", "content": "second turn"},
    ]
    # Pad BEFORE the final message: the "+ 1" reserves the slot for user_text,
    # which tests rely on being the most recent user message.
    while len(history) + 1 < SIGNAL_GATE_MIN_MESSAGES:
        history.append({"role": "user", "content": "filler"})
    history.append({"role": "user", "content": user_text})
    return history
def _kwargs(
    *,
    messages=None,
    chosen="fast",
    extra_metadata=None,
    extra_litellm_params=None,
):
    """Assemble the logging-callback kwargs dict the hook receives.

    ``chosen`` (when truthy) is stored in metadata under
    ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY. ``extra_metadata`` and
    ``extra_litellm_params`` are merged over the defaults, and ``messages``
    falls back to ``_long_messages()`` only when it is None.
    """
    metadata = {}
    if chosen:
        metadata[ADAPTIVE_ROUTER_CHOSEN_MODEL_KEY] = chosen
    if extra_metadata:
        metadata.update(extra_metadata)

    litellm_params = {"metadata": metadata}
    if extra_litellm_params:
        litellm_params.update(extra_litellm_params)

    if messages is None:
        messages = _long_messages()
    return {
        "model": "anthropic/claude-opus-4-7",
        "messages": messages,
        "litellm_params": litellm_params,
    }
# ---- _resolve_session_key ------------------------------------------------
def test_resolve_session_key_honors_litellm_session_id_on_litellm_params():
    """An explicit ``litellm_session_id`` on litellm_params is returned as the key."""
    call_kwargs = {"litellm_params": {"litellm_session_id": "sess-A"}}
    assert _resolve_session_key(call_kwargs) == "sess-A"
def test_resolve_session_key_honors_metadata_session_id():
    """A ``session_id`` nested in litellm_params metadata is returned as the key."""
    call_kwargs = {"litellm_params": {"metadata": {"session_id": "sess-B"}}}
    resolved = _resolve_session_key(call_kwargs)
    assert resolved == "sess-B"
def test_resolve_session_key_returns_none_when_no_messages():
    """Without a session id, missing or empty ``messages`` yields no key."""
    without_messages = {"litellm_params": {}}
    assert _resolve_session_key(without_messages) is None

    with_empty_messages = {"litellm_params": {}, "messages": []}
    assert _resolve_session_key(with_empty_messages) is None
def test_resolve_session_key_derives_stable_hash_from_first_message():
    """Equal message lists produce the same 64-character derived key."""
    conversation = [{"role": "user", "content": "Hello, world"}]
    from_original = _resolve_session_key({"messages": conversation})
    from_copy = _resolve_session_key({"messages": list(conversation)})

    assert from_original == from_copy
    assert from_original and len(from_original) == 64  # sha256 hex
def test_resolve_session_key_does_not_prefix_sk():
    """A derived key is non-empty and does not start with ``sk_``."""
    conversation = [{"role": "user", "content": "hi"}]
    derived = _resolve_session_key({"messages": conversation})
    assert derived and not derived.startswith("sk_")
def test_resolve_session_key_segments_by_identity_fields():
    """Same first message but different api keys must yield different keys."""
    shared_messages = [{"role": "user", "content": "same prompt"}]

    def _key_for(api_key_hash, team_id):
        # Only the caller-identity metadata varies between the two lookups.
        return _resolve_session_key(
            {
                "messages": shared_messages,
                "litellm_params": {
                    "metadata": {
                        "user_api_key_hash": api_key_hash,
                        "user_api_key_team_id": team_id,
                    }
                },
            }
        )

    key_team_a = _key_for("hash-A", "team-1")
    key_team_b = _key_for("hash-B", "team-2")
    assert key_team_a != key_team_b
def test_resolve_session_key_changes_when_first_message_changes():
    """Different first-message content produces different derived keys."""
    derived = [
        _resolve_session_key({"messages": [{"role": "user", "content": content}]})
        for content in ("alpha", "beta")
    ]
    assert derived[0] != derived[1]
# ---- _record gating -----------------------------------------------------
@pytest.mark.asyncio
async def test_hook_skips_when_below_signal_gate():
    """Conversations shorter than SIGNAL_GATE_MIN_MESSAGES should be ignored."""
    hook = _make_hook()
    single_message = [{"role": "user", "content": "hi"}]
    assert len(single_message) < SIGNAL_GATE_MIN_MESSAGES  # sanity

    await hook.async_log_success_event(
        _kwargs(messages=single_message), _resp_with_content("ok"), 0.0, 1.0
    )

    router_stub = hook.adaptive_router
    router_stub.record_turn.assert_not_awaited()
    router_stub.claim_or_check_owner.assert_not_called()
@pytest.mark.asyncio
async def test_hook_skips_when_no_messages():
    """An empty message list records no turn."""
    hook = _make_hook()
    call_kwargs = _kwargs(messages=[])
    response = _resp_with_content("ok")
    await hook.async_log_success_event(call_kwargs, response, 0.0, 1.0)
    hook.adaptive_router.record_turn.assert_not_awaited()
@pytest.mark.asyncio
async def test_hook_skips_when_chosen_model_missing_from_metadata():
    """Without the chosen-model metadata key, neither ownership nor recording is touched."""
    hook = _make_hook()
    call_kwargs = _kwargs(chosen=None)

    await hook.async_log_success_event(
        call_kwargs, _resp_with_content("ok"), 0.0, 1.0
    )

    router_stub = hook.adaptive_router
    router_stub.record_turn.assert_not_awaited()
    router_stub.claim_or_check_owner.assert_not_called()
@pytest.mark.asyncio
async def test_hook_skips_when_owner_cache_mismatch():
    """A different model owns this conversation -> no attribution."""
    hook = _make_hook(claim=False)
    call_kwargs = _kwargs(chosen="fast")

    await hook.async_log_success_event(
        call_kwargs, _resp_with_content("ok"), 0.0, 1.0
    )

    router_stub = hook.adaptive_router
    # Ownership is consulted exactly once, and its refusal blocks recording.
    router_stub.claim_or_check_owner.assert_called_once()
    router_stub.record_turn.assert_not_awaited()
@pytest.mark.asyncio
async def test_hook_records_turn_when_owner_claims():
    """When ownership is granted the hook records a Turn built from request and response."""
    hook = _make_hook(claim=True)
    call_kwargs = _kwargs(chosen="smart", messages=_long_messages("ask"))
    response = _resp_with_content("answer here")

    await hook.async_log_success_event(call_kwargs, response, 0.0, 1.0)

    recorded = hook.adaptive_router.record_turn.await_args.kwargs
    assert recorded["model_name"] == "smart"
    turn: Turn = recorded["turn"]
    assert turn.user_content == "ask"
    assert turn.assistant_content == "answer here"
    assert turn.response_status == 200
@pytest.mark.asyncio
async def test_hook_uses_explicit_session_id_when_provided():
    """Explicit `litellm_session_id` is forwarded as the session key."""
    explicit_id = "explicit-sess"
    hook = _make_hook()
    call_kwargs = _kwargs(
        chosen="fast",
        extra_litellm_params={"litellm_session_id": explicit_id},
    )

    await hook.async_log_success_event(
        call_kwargs, _resp_with_content("ok"), 0.0, 1.0
    )

    router_stub = hook.adaptive_router
    # The session key is the first positional argument of the ownership check.
    owner_args = router_stub.claim_or_check_owner.call_args.args
    assert owner_args[0] == "explicit-sess"
    recorded = router_stub.record_turn.await_args.kwargs
    assert recorded["session_id"] == "explicit-sess"
@pytest.mark.asyncio
async def test_hook_passes_tool_calls_through():
    """Tool calls on the assistant message end up on the recorded Turn."""
    hook = _make_hook()
    tool_call = {"name": "search", "arguments": '{"q":"x"}'}
    call_kwargs = _kwargs(chosen="fast")
    response = _resp_with_content("calling tool", tool_calls=[tool_call])

    await hook.async_log_success_event(call_kwargs, response, 0.0, 1.0)

    recorded = hook.adaptive_router.record_turn.await_args.kwargs
    turn: Turn = recorded["turn"]
    assert turn.tool_calls == [tool_call]
@pytest.mark.asyncio
async def test_hook_swallows_exceptions_from_record_turn():
    """A failure inside ``record_turn`` must not propagate out of the hook."""
    hook = _make_hook()
    hook.adaptive_router.record_turn.side_effect = RuntimeError("boom")
    kwargs = _kwargs(chosen="fast")
    # Must NOT raise — signal recording must never break a request.
    await hook.async_log_success_event(kwargs, _resp_with_content("ok"), 0.0, 1.0)
    # Guard against a vacuous pass: if the hook bailed out before reaching
    # record_turn, nothing was actually swallowed. AsyncMock records the await
    # before raising its side_effect, so this holds even though it raised.
    hook.adaptive_router.record_turn.assert_awaited_once()
@pytest.mark.asyncio
async def test_hook_failure_event_uses_status_code_from_exception():
    """On failure the recorded Turn takes its status from ``kwargs["exception"].status_code``."""
    hook = _make_hook()
    rate_limited = MagicMock(status_code=429)
    call_kwargs = _kwargs(chosen="fast")
    call_kwargs["exception"] = rate_limited

    await hook.async_log_failure_event(call_kwargs, None, 0.0, 1.0)

    recorded = hook.adaptive_router.record_turn.await_args.kwargs
    turn: Turn = recorded["turn"]
    assert turn.response_status == 429
# ---- async_post_call_success_hook (response header surfacing) ----------
@pytest.mark.asyncio
async def test_post_call_success_hook_sets_response_header():
    """The chosen model from request metadata is surfaced as a response header."""
    hook = _make_hook()
    response = MagicMock()
    response._hidden_params = {}
    request_data = {"metadata": {"adaptive_router_chosen_model": "smart"}}

    await hook.async_post_call_success_hook(
        data=request_data,
        user_api_key_dict=MagicMock(),
        response=response,
    )

    headers = response._hidden_params["additional_headers"]
    assert headers["x-litellm-adaptive-router-model"] == "smart"
@pytest.mark.asyncio
async def test_post_call_success_hook_preserves_existing_additional_headers():
    """Headers already present in ``additional_headers`` survive alongside the new one."""
    hook = _make_hook()
    response = MagicMock()
    response._hidden_params = {"additional_headers": {"x-existing": "keep-me"}}
    request_data = {"metadata": {"adaptive_router_chosen_model": "fast"}}

    await hook.async_post_call_success_hook(
        data=request_data,
        user_api_key_dict=MagicMock(),
        response=response,
    )

    assert response._hidden_params["additional_headers"]["x-existing"] == "keep-me"
    router_header = response._hidden_params["additional_headers"][
        "x-litellm-adaptive-router-model"
    ]
    assert router_header == "fast"
@pytest.mark.asyncio
async def test_post_call_success_hook_noop_when_metadata_missing_key():
    """Metadata without the chosen-model key leaves the hidden params untouched."""
    hook = _make_hook()
    response = MagicMock()
    response._hidden_params = {}
    unrelated_metadata = {"litellm_session_id": "sess-A"}

    await hook.async_post_call_success_hook(
        data={"metadata": unrelated_metadata},
        user_api_key_dict=MagicMock(),
        response=response,
    )

    assert response._hidden_params == {}
@pytest.mark.asyncio
async def test_post_call_success_hook_noop_when_no_metadata():
    """Request data with no metadata at all leaves the hidden params untouched."""
    hook = _make_hook()
    response = MagicMock()
    response._hidden_params = {}
    empty_request: dict = {}

    await hook.async_post_call_success_hook(
        data=empty_request,
        user_api_key_dict=MagicMock(),
        response=response,
    )

    assert response._hidden_params == {}
@pytest.mark.asyncio
async def test_post_call_success_hook_noop_when_hidden_params_not_dict():
    """A response object lacking ``_hidden_params`` is left without that attribute."""

    class _NoHiddenParams:
        """Plain object that deliberately has no ``_hidden_params``."""

    hook = _make_hook()
    bare_response = _NoHiddenParams()
    request_data = {"metadata": {"adaptive_router_chosen_model": "smart"}}

    await hook.async_post_call_success_hook(
        data=request_data,
        user_api_key_dict=MagicMock(),
        response=bare_response,
    )

    assert not hasattr(bare_response, "_hidden_params")

View File

@ -0,0 +1,380 @@
"""Tests for the Router-level wiring of the adaptive router.
Specifically guards the four bugs found when wiring the example config
`auto_router/adaptive_router` end-to-end:
1. The `auto_router/adaptive_router` model prefix must NOT trigger the
semantic auto-router init path (which would crash on missing fields).
2. The same prefix MUST trigger the adaptive-router init path.
3. `init_adaptive_router_deployment` must read `input_cost_per_token`
from `litellm_params` (where users put it), not just `model_info`.
4. `Router.async_pre_routing_hook` must dispatch to the matching entry in
`self.adaptive_routers` when the inbound model matches a configured
adaptive-router name, returning the underlying model the bandit picked.
"""
from unittest.mock import AsyncMock
import pytest
from litellm import Router
from litellm.types.router import LiteLLM_Params, RequestType
def _params(**overrides):
    """Build LiteLLM_Params defaulting ``model`` to the adaptive-router prefix.

    Any keyword in ``overrides`` (including ``model``) replaces the default.
    """
    merged = {"model": "auto_router/adaptive_router", **overrides}
    return LiteLLM_Params(**merged)
# ---- Fix 1 & 2: opt-in prefix routing -----------------------------------
def test_auto_router_check_excludes_adaptive_router_prefix():
    """`auto_router/adaptive_router` is not treated as a semantic auto-router deployment."""
    router = Router(model_list=[])
    params = _params(model="auto_router/adaptive_router")
    verdict = router._is_auto_router_deployment(litellm_params=params)
    assert verdict is False
def test_auto_router_check_excludes_complexity_router_prefix():
    """`auto_router/complexity_router` is not treated as a semantic auto-router deployment."""
    router = Router(model_list=[])
    params = _params(model="auto_router/complexity_router")
    verdict = router._is_auto_router_deployment(litellm_params=params)
    assert verdict is False
def test_auto_router_check_still_matches_plain_auto_router_prefix():
    """Any other `auto_router/...` model is still detected as an auto-router deployment."""
    router = Router(model_list=[])
    params = _params(model="auto_router/my-semantic-router")
    verdict = router._is_auto_router_deployment(litellm_params=params)
    assert verdict is True
def test_adaptive_router_check_recognizes_prefix():
    """`auto_router/adaptive_router` is detected as an adaptive-router deployment."""
    router = Router(model_list=[])
    params = _params(model="auto_router/adaptive_router")
    verdict = router._is_adaptive_router_deployment(litellm_params=params)
    assert verdict is True
def test_adaptive_router_check_rejects_other_prefixes():
    """An ordinary provider model is not detected as an adaptive-router deployment."""
    router = Router(model_list=[])
    params = _params(model="openai/gpt-4o")
    verdict = router._is_adaptive_router_deployment(litellm_params=params)
    assert verdict is False
# ---- Fix 3: cost field path --------------------------------------------
def test_init_adaptive_router_reads_cost_from_litellm_params():
    """Per-model cost is picked up from ``litellm_params.input_cost_per_token``.

    The fixture comes from ``_router_with_adaptive()``, which declares
    "smart-cheap-router" over the "fast" and "smart" deployments and puts
    each deployment's ``input_cost_per_token`` under ``litellm_params``
    (not ``model_info``). Reusing the helper keeps this test and the
    pre-routing tests on one shared model list instead of two copies.
    """
    r = _router_with_adaptive()
    assert "smart-cheap-router" in r.adaptive_routers
    assert r.adaptive_routers["smart-cheap-router"].model_to_cost == {
        "fast": 0.00000015,
        "smart": 0.0000050,
    }
# ---- Fix 4: pre-routing dispatch ---------------------------------------
def _router_with_adaptive() -> Router:
    """Build a Router with the adaptive router "smart-cheap-router" over "fast" and "smart".

    Both candidates carry their ``input_cost_per_token`` in ``litellm_params``
    and their adaptive-router preferences in ``model_info``.
    """

    def _candidate(name, upstream_model, cost_per_token, tier, strengths):
        # One concrete deployment the adaptive router may pick.
        return {
            "model_name": name,
            "litellm_params": {
                "model": upstream_model,
                "input_cost_per_token": cost_per_token,
            },
            "model_info": {
                "adaptive_router_preferences": {
                    "quality_tier": tier,
                    "strengths": strengths,
                }
            },
        }

    adaptive_entry = {
        "model_name": "smart-cheap-router",
        "litellm_params": {
            "model": "auto_router/adaptive_router",
            "adaptive_router_config": {
                "available_models": ["fast", "smart"],
            },
        },
    }
    deployments = [
        adaptive_entry,
        _candidate("fast", "openai/gpt-4o-mini", 0.00000015, 2, []),
        _candidate("smart", "openai/gpt-4o", 0.0000050, 3, ["code_generation"]),
    ]
    return Router(model_list=deployments)
@pytest.mark.asyncio
async def test_async_pre_routing_hook_dispatches_to_adaptive_router():
    """A request for the adaptive router's name is resolved to the model ``pick_model`` returns."""
    router = _router_with_adaptive()
    adaptive = router.adaptive_routers["smart-cheap-router"]
    adaptive.pick_model = AsyncMock(return_value="smart")  # type: ignore[assignment]

    hook_result = await router.async_pre_routing_hook(
        model="smart-cheap-router",
        request_kwargs={"metadata": {"litellm_session_id": "sess-A"}},
        messages=[{"role": "user", "content": "Write a Python function"}],
    )

    assert hook_result is not None
    assert hook_result.model == "smart"

    pick_kwargs = adaptive.pick_model.await_args.kwargs  # type: ignore[union-attr]
    # Stateless routing: session_id is no longer passed to pick_model.
    assert "session_id" not in pick_kwargs
    assert pick_kwargs["request_type"] == RequestType.CODE_GENERATION
@pytest.mark.asyncio
async def test_async_pre_routing_hook_pick_model_not_passed_session_id():
    """``pick_model`` is awaited without a ``session_id`` keyword."""
    router = _router_with_adaptive()
    adaptive = router.adaptive_routers["smart-cheap-router"]
    adaptive.pick_model = AsyncMock(return_value="fast")  # type: ignore[assignment]

    hook_result = await router.async_pre_routing_hook(
        model="smart-cheap-router",
        request_kwargs={},
        messages=[{"role": "user", "content": "hello"}],
    )

    assert hook_result is not None
    assert hook_result.model == "fast"
    pick_kwargs = adaptive.pick_model.await_args.kwargs  # type: ignore[union-attr]
    assert "session_id" not in pick_kwargs
@pytest.mark.asyncio
async def test_async_pre_routing_hook_returns_none_for_unrelated_model():
    """A model name that is not an adaptive router yields None and never calls ``pick_model``."""
    router = _router_with_adaptive()
    adaptive = router.adaptive_routers["smart-cheap-router"]
    adaptive.pick_model = AsyncMock()  # type: ignore[assignment]

    hook_result = await router.async_pre_routing_hook(
        model="some-other-model",
        request_kwargs={},
        messages=[{"role": "user", "content": "x"}],
    )

    assert hook_result is None
    adaptive.pick_model.assert_not_awaited()  # type: ignore[union-attr]
# ---- Response header surfacing -----------------------------------------
@pytest.mark.asyncio
async def test_async_pre_routing_hook_stashes_chosen_model_in_metadata():
    """
    The adaptive-router branch must record the chosen logical model on
    `request_kwargs["metadata"]` so `_acompletion` can surface it as the
    `x-litellm-adaptive-router-model` response header.
    """
    router = _router_with_adaptive()
    adaptive = router.adaptive_routers["smart-cheap-router"]
    adaptive.pick_model = AsyncMock(return_value="smart")  # type: ignore[assignment]

    # Pre-existing metadata: the hook must add to it in place.
    inbound_kwargs: dict = {"metadata": {"litellm_session_id": "sess-A"}}
    await router.async_pre_routing_hook(
        model="smart-cheap-router",
        request_kwargs=inbound_kwargs,
        messages=[{"role": "user", "content": "Write a Python function"}],
    )

    stashed = inbound_kwargs["metadata"]["adaptive_router_chosen_model"]
    assert stashed == "smart"
@pytest.mark.asyncio
async def test_async_pre_routing_hook_creates_metadata_when_missing():
    """If no metadata was passed in, the hook should create one to stash the chosen model."""
    router = _router_with_adaptive()
    adaptive = router.adaptive_routers["smart-cheap-router"]
    adaptive.pick_model = AsyncMock(return_value="fast")  # type: ignore[assignment]

    inbound_kwargs: dict = {}
    await router.async_pre_routing_hook(
        model="smart-cheap-router",
        request_kwargs=inbound_kwargs,
        messages=[{"role": "user", "content": "hello"}],
    )

    stashed = inbound_kwargs["metadata"]["adaptive_router_chosen_model"]
    assert stashed == "fast"
# ---- Multi-router support ----------------------------------------------
def test_two_adaptive_routers_can_coexist_on_one_router():
    """A Router built with two adaptive-router entries must register both, each with its own config."""
    adaptive_entries = [
        {
            "model_name": "cheap-router",
            "litellm_params": {
                "model": "auto_router/adaptive_router",
                "adaptive_router_config": {"available_models": ["fast"]},
            },
        },
        {
            "model_name": "premium-router",
            "litellm_params": {
                "model": "auto_router/adaptive_router",
                "adaptive_router_config": {"available_models": ["smart"]},
            },
        },
    ]
    backing_models = [
        {
            "model_name": "fast",
            "litellm_params": {
                "model": "openai/gpt-4o-mini",
                "input_cost_per_token": 0.00000015,
            },
        },
        {
            "model_name": "smart",
            "litellm_params": {
                "model": "openai/gpt-4o",
                "input_cost_per_token": 0.0000050,
            },
        },
    ]
    router = Router(model_list=adaptive_entries + backing_models)

    expected_models = {"cheap-router": ["fast"], "premium-router": ["smart"]}
    assert set(router.adaptive_routers.keys()) == set(expected_models)
    for router_name, available in expected_models.items():
        assert router.adaptive_routers[router_name].config.available_models == available
@pytest.mark.asyncio
async def test_async_pre_routing_hook_dispatches_to_correct_router_when_multiple():
    """Each adaptive router only handles its own router_name."""
    router = Router(
        model_list=[
            {
                "model_name": "cheap-router",
                "litellm_params": {
                    "model": "auto_router/adaptive_router",
                    "adaptive_router_config": {"available_models": ["fast"]},
                },
            },
            {
                "model_name": "premium-router",
                "litellm_params": {
                    "model": "auto_router/adaptive_router",
                    "adaptive_router_config": {"available_models": ["smart"]},
                },
            },
            {
                "model_name": "fast",
                "litellm_params": {
                    "model": "openai/gpt-4o-mini",
                    "input_cost_per_token": 0.00000015,
                },
            },
            {
                "model_name": "smart",
                "litellm_params": {
                    "model": "openai/gpt-4o",
                    "input_cost_per_token": 0.0000050,
                },
            },
        ]
    )
    # Router name -> the model its mocked pick_model returns.
    expected_choice = {"cheap-router": "fast", "premium-router": "smart"}
    for router_name, chosen in expected_choice.items():
        router.adaptive_routers[router_name].pick_model = AsyncMock(  # type: ignore[assignment]
            return_value=chosen
        )

    responses = {}
    for router_name in expected_choice:
        responses[router_name] = await router.async_pre_routing_hook(
            model=router_name,
            request_kwargs={},
            messages=[{"role": "user", "content": "hi"}],
        )

    for router_name, chosen in expected_choice.items():
        response = responses[router_name]
        assert response is not None and response.model == chosen
    # Exactly one await per router shows neither request reached the other router.
    for router_name in expected_choice:
        router.adaptive_routers[router_name].pick_model.assert_awaited_once()  # type: ignore[union-attr]
def test_init_adaptive_router_rejects_duplicate_model_name():
    """Two adaptive-router deployments with the same model_name must error.

    The first `init_adaptive_router_deployment` call for "dup-router" has to
    succeed; repeating it with the same deployment must raise a `ValueError`
    whose message matches "already exists".
    """
    from litellm.types.router import Deployment

    router = Router(model_list=[])
    deployment = Deployment(
        model_name="dup-router",
        litellm_params=LiteLLM_Params(
            model="auto_router/adaptive_router",
            adaptive_router_config={"available_models": ["fast"]},
        ),
        model_info={"id": "x"},
    )

    # First registration is outside the `raises` block: it must not fail.
    router.init_adaptive_router_deployment(deployment=deployment)
    with pytest.raises(ValueError, match="already exists"):
        router.init_adaptive_router_deployment(deployment=deployment)

View File

@ -0,0 +1,112 @@
import json
from pathlib import Path
from typing import List, Tuple
import pytest
from litellm.router_strategy.adaptive_router.config import TOOL_CALL_HISTORY_MAX
from litellm.router_strategy.adaptive_router.signals import (
SessionState,
SignalDelta,
Turn,
apply_turn,
)
FIXTURE_DIR = Path(__file__).parent / "fixtures"
def _load(name: str) -> list:
    """Parse the JSON fixture `<name>.json` located in FIXTURE_DIR.

    The file is decoded explicitly as UTF-8 so the parsed content does not
    depend on the platform's default locale encoding.
    """
    fixture_path = FIXTURE_DIR / f"{name}.json"
    return json.loads(fixture_path.read_text(encoding="utf-8"))
def _replay(turns: list) -> Tuple[SessionState, List[SignalDelta]]:
    """Feed each turn dict through `apply_turn` on a fresh SessionState.

    Returns the state after all turns plus the value `apply_turn` returned
    for each turn, in order. Missing `tool_calls` / `tool_results` keys
    default to empty lists; other missing keys become None.
    """
    session = SessionState(
        session_id="s",
        router_name="r",
        model_name="m",
        classified_type="general",
    )

    def _to_turn(raw: dict) -> Turn:
        return Turn(
            user_content=raw.get("user_content"),
            assistant_content=raw.get("assistant_content"),
            tool_calls=raw.get("tool_calls", []),
            tool_results=raw.get("tool_results", []),
            response_status=raw.get("response_status"),
        )

    per_turn_deltas: List[SignalDelta] = [
        apply_turn(session, _to_turn(raw)) for raw in turns
    ]
    return session, per_turn_deltas
def test_clean_satisfaction_fires_satisfaction_only():
    """The `clean_satisfaction` fixture must register satisfaction but no failure or disengagement."""
    final_state, _deltas = _replay(_load("clean_satisfaction"))

    assert final_state.satisfaction_count >= 1
    for silent_counter in ("failure_count", "disengagement_count"):
        assert getattr(final_state, silent_counter) == 0, silent_counter
def test_misalignment_fires_on_rephrase():
    """Replaying `misalignment_rephrase` must leave misalignment_count at 1 or more."""
    final_state, _deltas = _replay(_load("misalignment_rephrase"))

    assert final_state.misalignment_count >= 1
def test_stagnation_fires_on_repeated_assistant():
    """Replaying `stagnation_repeat` must leave stagnation_count at 1 or more."""
    final_state, _deltas = _replay(_load("stagnation_repeat"))

    assert final_state.stagnation_count >= 1
def test_disengagement_fires_on_giveup():
    """Replaying `disengagement_giveup` must leave disengagement_count at 1 or more."""
    final_state, _deltas = _replay(_load("disengagement_giveup"))

    assert final_state.disengagement_count >= 1
def test_failure_fires_on_tool_error():
    """Replaying `failure_tool_error` must leave failure_count at exactly 1."""
    final_state, _deltas = _replay(_load("failure_tool_error"))

    assert final_state.failure_count == 1
def test_loop_fires_on_repeated_tool():
    """Replaying `loop_same_tool` must leave loop_count at 1 or more."""
    final_state, _deltas = _replay(_load("loop_same_tool"))

    assert final_state.loop_count >= 1
@pytest.mark.parametrize("fixture", ["exhaustion_429", "exhaustion_context_overflow"])
def test_exhaustion_fires_on_infra_signal(fixture):
    """Each parametrized exhaustion fixture must leave exhaustion_count at 1 or more."""
    final_state, _deltas = _replay(_load(fixture))

    assert final_state.exhaustion_count >= 1
def test_no_signals_on_clean_session():
    """Replaying `clean_no_signals` must leave the six counters checked below at zero."""
    final_state, _deltas = _replay(_load("clean_no_signals"))

    counters_expected_zero = (
        "misalignment_count",
        "stagnation_count",
        "disengagement_count",
        "failure_count",
        "loop_count",
        "exhaustion_count",
    )
    for counter in counters_expected_zero:
        assert getattr(final_state, counter) == 0, counter
def test_mixed_failure_then_satisfaction():
    """Replaying `mixed_failure_then_satisfaction` must register both a failure and a satisfaction."""
    final_state, _deltas = _replay(_load("mixed_failure_then_satisfaction"))

    for fired_counter in ("failure_count", "satisfaction_count"):
        assert getattr(final_state, fired_counter) >= 1, fired_counter
def test_apply_turn_is_o1_does_not_grow_history_unbounded():
    """After 100 turns with distinct tool calls, tool_call_history must not exceed TOOL_CALL_HISTORY_MAX."""
    session = SessionState(
        session_id="s",
        router_name="r",
        model_name="m",
        classified_type="general",
    )

    for turn_index in range(100):
        tool_call = {"name": f"tool_{turn_index}", "arguments": {}}
        apply_turn(session, Turn(tool_calls=[tool_call]))

    assert len(session.tool_call_history) <= TOOL_CALL_HISTORY_MAX

View File

@ -0,0 +1,196 @@
"""Tests for the GET /adaptive_router/state introspection endpoint and the
underlying `AdaptiveRouter.get_state_snapshot()` helper."""
import time
from unittest.mock import MagicMock
import pytest
from fastapi import HTTPException
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
from litellm.router_strategy.adaptive_router.adaptive_router import AdaptiveRouter
from litellm.router_strategy.adaptive_router.bandit import BanditCell, apply_delta
from litellm.types.router import (
AdaptiveRouterConfig,
AdaptiveRouterPreferences,
RequestType,
)
def _make_router(name: str = "r1") -> AdaptiveRouter:
    """Build an AdaptiveRouter named `name` over the two models "fast" and "smart"."""
    config = AdaptiveRouterConfig(available_models=["fast", "smart"])
    fast_prefs = AdaptiveRouterPreferences(quality_tier=1, strengths=[])
    smart_prefs = AdaptiveRouterPreferences(
        quality_tier=3, strengths=[RequestType.CODE_GENERATION]
    )
    return AdaptiveRouter(
        router_name=name,
        config=config,
        model_to_prefs={"fast": fast_prefs, "smart": smart_prefs},
        model_to_cost={"fast": 0.0001, "smart": 0.001},
    )
# ---- snapshot helper ---------------------------------------------------
@pytest.mark.asyncio
async def test_get_state_snapshot_returns_cell_per_request_type_per_model():
    """Check the snapshot's top-level fields and that it holds one cell per (request type, model)."""
    snapshot = await _make_router().get_state_snapshot()

    # Top-level shape
    expected_top_level = {
        "router_name": "r1",
        "available_models": ["fast", "smart"],
        "weights": {"quality": 0.7, "cost": 0.3},
        "model_costs": {"fast": 0.0001, "smart": 0.001},
        "owner_cache_live": 0,
        "skipped_updates_total": 0,
    }
    for field, expected in expected_top_level.items():
        assert snapshot[field] == expected, field
    assert set(snapshot["queue"].keys()) == {
        "state_pending",
        "session_pending",
        "max_state_seen",
        "max_session_seen",
    }

    # One cell for every request type crossed with the two configured models.
    models = {"fast", "smart"}
    request_type_values = {rt.value for rt in RequestType}
    cell_fields = {
        "request_type",
        "model",
        "alpha",
        "beta",
        "samples",
        "quality_mean",
    }
    cells = snapshot["cells"]
    assert len(cells) == len(RequestType) * len(models)
    for cell in cells:
        assert set(cell.keys()) == cell_fields
        assert cell["model"] in models
        assert cell["request_type"] in request_type_values
@pytest.mark.asyncio
async def test_get_state_snapshot_quality_mean_matches_alpha_over_total():
    """A cell's `quality_mean` in the snapshot must equal alpha / (alpha + beta)."""
    router = _make_router()
    # Push one cell to a known state so the expected numbers can be computed here.
    cell_key = (RequestType.CODE_GENERATION, "smart")
    router._cells[cell_key] = apply_delta(
        router._cells[cell_key], delta_alpha=10.0, delta_beta=0.0
    )
    updated = router._cells[cell_key]
    total = updated.alpha + updated.beta
    expected_mean = updated.alpha / total

    snapshot = await router.get_state_snapshot()
    reported = next(
        entry
        for entry in snapshot["cells"]
        if entry["request_type"] == "code_generation" and entry["model"] == "smart"
    )

    assert reported["alpha"] == updated.alpha
    assert reported["beta"] == updated.beta
    assert reported["samples"] == updated.alpha + updated.beta
    assert reported["quality_mean"] == pytest.approx(expected_mean)
@pytest.mark.asyncio
async def test_get_state_snapshot_counts_only_live_owner_cache_entries():
    """`owner_cache_live` must count the two future-dated entries and skip the past-dated one."""
    router = _make_router()
    now = time.time()
    # Each value is (model, timestamp); only "expired-1" has a timestamp in the past.
    seeded_entries = {
        "live-1": ("fast", now + 3600),
        "live-2": ("smart", now + 3600),
        "expired-1": ("fast", now - 1),
    }
    for cache_key, entry in seeded_entries.items():
        router._owner_cache[cache_key] = entry

    snapshot = await router.get_state_snapshot()

    assert snapshot["owner_cache_live"] == 2
@pytest.mark.asyncio
async def test_get_state_snapshot_exposes_skipped_updates_total():
    """The snapshot must report the router's `_skipped_updates_total` value unchanged."""
    router = _make_router()
    router._skipped_updates_total = 7

    snapshot = await router.get_state_snapshot()

    assert snapshot["skipped_updates_total"] == 7
# ---- endpoint --------------------------------------------------------
@pytest.mark.asyncio
async def test_endpoint_returns_404_when_no_adaptive_router(monkeypatch):
    """When llm_router is set but has no adaptive routers configured, return 404."""
    from litellm.proxy import proxy_server

    router_stub = MagicMock()
    router_stub.adaptive_routers = {}
    monkeypatch.setattr(proxy_server, "llm_router", router_stub)
    admin_key = UserAPIKeyAuth(
        api_key="sk-1234", user_role=LitellmUserRoles.PROXY_ADMIN
    )

    with pytest.raises(HTTPException) as exc_info:
        await proxy_server.get_adaptive_router_state(user_api_key_dict=admin_key)

    assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_endpoint_returns_404_when_llm_router_is_none(monkeypatch):
    """With `proxy_server.llm_router` set to None, the endpoint must raise a 404 HTTPException."""
    from litellm.proxy import proxy_server

    monkeypatch.setattr(proxy_server, "llm_router", None)
    admin_key = UserAPIKeyAuth(
        api_key="sk-1234", user_role=LitellmUserRoles.PROXY_ADMIN
    )

    with pytest.raises(HTTPException) as exc_info:
        await proxy_server.get_adaptive_router_state(user_api_key_dict=admin_key)

    assert exc_info.value.status_code == 404
@pytest.mark.asyncio
async def test_endpoint_rejects_non_admin_role(monkeypatch):
    """A caller with the INTERNAL_USER role must get a 403 even when an adaptive router exists."""
    from litellm.proxy import proxy_server

    router_stub = MagicMock()
    router_stub.adaptive_routers = {"r1": _make_router()}
    monkeypatch.setattr(proxy_server, "llm_router", router_stub)
    internal_user_key = UserAPIKeyAuth(
        api_key="sk-user", user_role=LitellmUserRoles.INTERNAL_USER
    )

    with pytest.raises(HTTPException) as exc_info:
        await proxy_server.get_adaptive_router_state(
            user_api_key_dict=internal_user_key
        )

    assert exc_info.value.status_code == 403
@pytest.mark.asyncio
async def test_endpoint_returns_snapshot_list_for_admin(monkeypatch):
    """Single configured router still returns the {"routers": [...]} list shape."""
    from litellm.proxy import proxy_server

    router_stub = MagicMock()
    router_stub.adaptive_routers = {"r1": _make_router("r1")}
    monkeypatch.setattr(proxy_server, "llm_router", router_stub)
    admin_key = UserAPIKeyAuth(
        api_key="sk-1234", user_role=LitellmUserRoles.PROXY_ADMIN
    )

    payload = await proxy_server.get_adaptive_router_state(
        user_api_key_dict=admin_key
    )

    assert list(payload.keys()) == ["routers"]
    snapshots = payload["routers"]
    assert len(snapshots) == 1
    only_snapshot = snapshots[0]
    assert only_snapshot["router_name"] == "r1"
    assert only_snapshot["available_models"] == ["fast", "smart"]
    assert len(only_snapshot["cells"]) == len(list(RequestType)) * 2
@pytest.mark.asyncio
async def test_endpoint_returns_one_snapshot_per_router(monkeypatch):
    """With multiple adaptive routers configured, return one snapshot per router."""
    from litellm.proxy import proxy_server

    router_stub = MagicMock()
    router_stub.adaptive_routers = {
        "r1": _make_router("r1"),
        "r2": _make_router("r2"),
    }
    monkeypatch.setattr(proxy_server, "llm_router", router_stub)
    admin_key = UserAPIKeyAuth(
        api_key="sk-1234", user_role=LitellmUserRoles.PROXY_ADMIN
    )

    payload = await proxy_server.get_adaptive_router_state(
        user_api_key_dict=admin_key
    )

    # Sorted, so the check does not depend on the order snapshots come back in.
    reported_names = sorted(
        snapshot["router_name"] for snapshot in payload["routers"]
    )
    assert reported_names == ["r1", "r2"]

6
uv.lock generated
View File

@ -10,7 +10,7 @@ resolution-markers = [
]
[options]
exclude-newer = "2026-04-13T16:35:18.496811Z"
exclude-newer = "2026-04-15T20:11:16.497522Z"
exclude-newer-span = "P3D"
[manifest]
@ -3767,7 +3767,7 @@ wheels = [
[[package]]
name = "litellm"
version = "1.83.8"
version = "1.83.9"
source = { editable = "." }
dependencies = [
{ name = "aiohttp" },
@ -4114,7 +4114,7 @@ source = { editable = "enterprise" }
[[package]]
name = "litellm-proxy-extras"
version = "0.4.65"
version = "0.4.66"
source = { editable = "litellm-proxy-extras" }
[[package]]