diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index 36486747c3..300c311f36 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -240,7 +240,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac self, model: str, request_kwargs: Dict, - messages: Optional[List[Dict[str, str]]] = None, + messages: Optional[List[Dict[str, Any]]] = None, input: Optional[Union[str, List]] = None, specific_deployment: Optional[bool] = False, ) -> Optional[PreRoutingHookResponse]: diff --git a/litellm/llms/anthropic/chat/guardrail_translation/handler.py b/litellm/llms/anthropic/chat/guardrail_translation/handler.py index cb430b0694..2bb82f227b 100644 --- a/litellm/llms/anthropic/chat/guardrail_translation/handler.py +++ b/litellm/llms/anthropic/chat/guardrail_translation/handler.py @@ -34,6 +34,7 @@ from litellm.types.llms.anthropic import ( ) from litellm.types.llms.openai import ( AllMessageValues, + ChatCompletionRequest, ChatCompletionToolCallChunk, ChatCompletionToolParam, ) @@ -67,6 +68,32 @@ class AnthropicMessagesHandler(BaseTranslation): super().__init__() self.adapter = LiteLLMAnthropicMessagesAdapter() + def _translate_to_openai(self, data: dict) -> ChatCompletionRequest: + """Translate Anthropic request to OpenAI chat completion format.""" + ( + chat_completion_compatible_request, + _tool_name_mapping, + ) = LiteLLMAnthropicMessagesAdapter().translate_anthropic_to_openai( + anthropic_message_request=cast(AnthropicMessagesRequest, data.copy()) + ) + return chat_completion_compatible_request + + def get_structured_messages(self, data: dict) -> Optional[List[AllMessageValues]]: + """ + Convert Anthropic messages request data to OpenAI-spec structured messages. + + Uses the Anthropic-to-OpenAI adapter to translate message format. + """ + messages = data.get("messages") + if messages is None: + return None + chat_completion_compatible_request = self._translate_to_openai(data) + result = cast( + List[AllMessageValues], + chat_completion_compatible_request.get("messages", []), + ) + return result if result else None + async def process_input_messages( self, data: dict, @@ -82,13 +109,7 @@ class AnthropicMessagesHandler(BaseTranslation): skip_system = effective_skip_system_message_for_guardrail(guardrail_to_apply) - ( - chat_completion_compatible_request, - _tool_name_mapping, - ) = LiteLLMAnthropicMessagesAdapter().translate_anthropic_to_openai( - # Use a shallow copy to avoid mutating request data (pop on litellm_metadata). - anthropic_message_request=cast(AnthropicMessagesRequest, data.copy()) - ) + chat_completion_compatible_request = self._translate_to_openai(data) structured_messages = cast( List[AllMessageValues], @@ -103,8 +124,6 @@ class AnthropicMessagesHandler(BaseTranslation): chat_completion_compatible_request.get("tools", []) ) task_mappings: List[Tuple[int, Optional[int]]] = [] - # Track (message_index, content_index) for each text - # content_index is None for string content, int for list content # Step 1: Extract all text content and images for msg_idx, message in enumerate(messages): diff --git a/litellm/llms/base_llm/guardrail_translation/base_translation.py b/litellm/llms/base_llm/guardrail_translation/base_translation.py index e1da0dfa29..1efeb159a3 100644 --- a/litellm/llms/base_llm/guardrail_translation/base_translation.py +++ b/litellm/llms/base_llm/guardrail_translation/base_translation.py @@ -5,6 +5,7 @@ if TYPE_CHECKING: from litellm.integrations.custom_guardrail import CustomGuardrail from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.proxy._types import UserAPIKeyAuth + from litellm.types.llms.openai import AllMessageValues class BaseTranslation(ABC): @@ -101,6 +102,16 @@ class BaseTranslation(ABC): """ return responses_so_far + def get_structured_messages(self, data: dict) -> Optional[List["AllMessageValues"]]: + """ + Convert request data to OpenAI-spec structured messages. + + Override in subclasses for format-specific conversion. + + Returns None if no convertible content is found. + """ + return None + def extract_request_tool_names(self, data: dict) -> List[str]: """ Extract tool names from the request body for allowlist/policy checks. diff --git a/litellm/llms/openai/chat/guardrail_translation/handler.py b/litellm/llms/openai/chat/guardrail_translation/handler.py index 2db19dea0b..86ca662562 100644 --- a/litellm/llms/openai/chat/guardrail_translation/handler.py +++ b/litellm/llms/openai/chat/guardrail_translation/handler.py @@ -48,6 +48,17 @@ class OpenAIChatCompletionsHandler(BaseTranslation): Methods can be overridden to customize behavior for different message formats. """ + def get_structured_messages(self, data: dict) -> Optional[List[AllMessageValues]]: + """ + Convert chat completions request data to OpenAI-spec structured messages. + + Messages are already in OpenAI format, so this is a simple extraction. + """ + messages = data.get("messages") + if messages is None: + return None + return cast(List[AllMessageValues], messages) + async def process_input_messages( self, data: dict, @@ -68,9 +79,6 @@ class OpenAIChatCompletionsHandler(BaseTranslation): tool_calls_to_check: List[ChatCompletionToolParam] = [] text_task_mappings: List[Tuple[int, Optional[int]]] = [] tool_call_task_mappings: List[Tuple[int, int]] = [] - # text_task_mappings: Track (message_index, content_index) for each text - # content_index is None for string content, int for list content - # tool_call_task_mappings: Track (message_index, tool_call_index) for each tool call # Step 1: Extract all text content, images, and tool calls for msg_idx, message in enumerate(messages): @@ -92,12 +100,12 @@ class OpenAIChatCompletionsHandler(BaseTranslation): inputs["images"] = images_to_check if tool_calls_to_check: inputs["tool_calls"] = tool_calls_to_check # type: ignore - if messages: - msg_list = cast(List[AllMessageValues], messages) + structured_messages = self.get_structured_messages(data) + if structured_messages: inputs["structured_messages"] = ( - openai_messages_without_system(msg_list) + openai_messages_without_system(structured_messages) if skip_system - else msg_list + else structured_messages ) # Pass tools (function definitions) to the guardrail tools = data.get("tools") diff --git a/litellm/llms/openai/responses/guardrail_translation/handler.py b/litellm/llms/openai/responses/guardrail_translation/handler.py index 76f40eed71..f7dd68aec5 100644 --- a/litellm/llms/openai/responses/guardrail_translation/handler.py +++ b/litellm/llms/openai/responses/guardrail_translation/handler.py @@ -43,6 +43,7 @@ from litellm.responses.litellm_completion_transformation.transformation import ( LiteLLMCompletionResponsesConfig, ) from litellm.types.llms.openai import ( + AllMessageValues, ChatCompletionToolCallChunk, ChatCompletionToolParam, ) @@ -70,6 +71,24 @@ class OpenAIResponsesHandler(BaseTranslation): Methods can be overridden to customize behavior for different message formats. """ + def get_structured_messages(self, data: dict) -> Optional[List[AllMessageValues]]: + """ + Convert Responses API request data to OpenAI-spec structured messages. + + Transforms `input` (string or ResponseInputParam) and optional + `instructions` into chat completion messages. + """ + input_data = data.get("input") + if input_data is None: + return None + messages = ( + LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages( + input=input_data, + responses_api_request=data, + ) + ) + return cast(List[AllMessageValues], messages) if messages else None + async def process_input_messages( self, data: dict, @@ -86,12 +105,7 @@ class OpenAIResponsesHandler(BaseTranslation): if input_data is None: return data - structured_messages = ( - LiteLLMCompletionResponsesConfig.transform_responses_api_input_to_messages( - input=input_data, - responses_api_request=data, - ) - ) + structured_messages = self.get_structured_messages(data) # Handle simple string input if isinstance(input_data, str): diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 36c90c2855..427ec46740 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -30,13 +30,13 @@ model_list: id: claude-sonnet-4-custom-pricing input_cost_per_token: 0.0003 # 100x standard ($0.000003) output_cost_per_token: 0.0015 # 100x standard ($0.000015) - -litellm_settings: - callbacks: ["compression_interception"] - compression_interception_params: - enabled: true - compression_trigger: 100000 -# # optional: -# # embedding_model: "text-embedding-3-small" -# # embedding_model_params: -# # dimensions: 512 \ No newline at end of file + - model_name: my-auto + litellm_params: + model: auto_router/complexity_router + complexity_router_config: + tiers: + SIMPLE: "gpt-4.1-mini" + COMPLEX: claude-sonnet-4-6 + tier_boundaries: + simple_medium: 0.30 + complexity_router_default_model: small-model diff --git a/litellm/router.py b/litellm/router.py index 6572d96f7b..f6976109bc 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -200,12 +200,16 @@ if TYPE_CHECKING: from litellm.router_strategy.complexity_router.complexity_router import ( ComplexityRouter, ) + from litellm.router_strategy.quality_router.quality_router import ( + QualityRouter, + ) Span = Union[_Span, Any] else: Span = Any AutoRouter = Any ComplexityRouter = Any + QualityRouter = Any PreRoutingHookResponse = Any @@ -464,6 +468,7 @@ class Router: ) # {"TEAM_ID": PatternMatchRouter} self.auto_routers: Dict[str, "AutoRouter"] = {} self.complexity_routers: Dict[str, "ComplexityRouter"] = {} + self.quality_routers: Dict[str, "QualityRouter"] = {} # Initialize model_group_alias early since it's used in set_model_list self.model_group_alias: Dict[str, Union[str, RouterModelGroupAliasItem]] = ( @@ -5884,7 +5889,7 @@ class Router: response = await response ## PROCESS RESPONSE HEADERS response = await self.set_response_headers( - response=response, model_group=model_group + response=response, model_group=model_group, request_kwargs=kwargs ) return response @@ -6814,6 +6819,8 @@ class Router: """ if litellm_params.model.startswith("auto_router/complexity_router"): return False # This is handled by complexity_router + if litellm_params.model.startswith("auto_router/quality_router"): + return False # This is handled by quality_router if litellm_params.model.startswith("auto_router/"): return True return False @@ -6920,6 +6927,58 @@ class Router: ) self.complexity_routers[deployment.model_name] = complexity_router + def _is_quality_router_deployment(self, litellm_params: LiteLLM_Params) -> bool: + """ + Check if the deployment is a quality-router deployment. + + Returns True if the litellm_params model starts with "auto_router/quality_router". + """ + if litellm_params.model.startswith("auto_router/quality_router"): + return True + return False + + def init_quality_router_deployment(self, deployment: Deployment): + """ + Initialize the quality-router deployment. + + Resolves the default model from either `quality_router_default_model` or + `quality_router_config["default_model"]`, then instantiates the + QualityRouter and stores it in `self.quality_routers`. + """ + # Import here to mirror the AutoRouter / ComplexityRouter init pattern + # and avoid circular imports. + from litellm.router_strategy.quality_router.quality_router import ( + QualityRouter, + ) + + quality_router_config: Optional[dict] = ( + deployment.litellm_params.quality_router_config + ) + + default_model: Optional[str] = ( + deployment.litellm_params.quality_router_default_model + ) + if default_model is None and quality_router_config: + default_model = quality_router_config.get("default_model") + + if default_model is None: + raise ValueError( + "quality_router_default_model is required for quality-router deployments, " + "or set default_model in quality_router_config. Please configure it in the litellm_params" + ) + + quality_router: QualityRouter = QualityRouter( + model_name=deployment.model_name, + default_model=default_model, + litellm_router_instance=self, + quality_router_config=quality_router_config, + ) + if deployment.model_name in self.quality_routers: + raise ValueError( + f"Quality-router deployment {deployment.model_name} already exists. Please use a different model name." + ) + self.quality_routers[deployment.model_name] = quality_router + def deployment_is_active_for_environment(self, deployment: Deployment) -> bool: """ Function to check if a llm deployment is active for a given environment. Allows using the same config.yaml across multople environments @@ -6966,6 +7025,11 @@ class Router: self.model_id_to_deployment_index_map = {} # Reset the index self.model_name_to_deployment_indices = {} # Reset the model_name index self.team_model_to_deployment_indices = {} # Reset the team_model index + # Reset per-strategy router registries so hot-reload doesn't leave + # stale routers pointing at the old model_list. + self.quality_routers = {} + self.complexity_routers = {} + self.auto_routers = {} self._invalidate_model_group_info_cache() self._invalidate_access_groups_cache() # we add api_base/api_key each model so load balancing between azure/gpt on api_base1 and api_base2 works @@ -7140,6 +7204,12 @@ class Router: ): self.init_complexity_router_deployment(deployment=deployment) + ######################################################### + # Check if this is a quality-router deployment + ######################################################### + if self._is_quality_router_deployment(litellm_params=deployment.litellm_params): + self.init_quality_router_deployment(deployment=deployment) + return deployment def _initialize_deployment_for_pass_through( @@ -8143,7 +8213,10 @@ class Router: return returned_dict async def set_response_headers( - self, response: Any, model_group: Optional[str] = None + self, + response: Any, + model_group: Optional[str] = None, + request_kwargs: Optional[dict] = None, ) -> Any: """ Add the most accurate rate limit headers for a given model response. @@ -8164,6 +8237,45 @@ class Router: additional_headers = response._hidden_params["additional_headers"] # type: ignore + # Lift QualityRouter routing decision into response headers for + # transparency. The decision is stashed in request_kwargs.metadata + # by QualityRouter.async_pre_routing_hook. + metadata = ( + (request_kwargs.get("metadata") or {}) + if isinstance(request_kwargs, dict) + else {} + ) + decision = ( + metadata.get("quality_router_decision") + if isinstance(metadata, dict) + else None + ) + if isinstance(decision, dict): + # Only emit headers for fields that have a meaningful value. + # `complexity_tier` and `matched_keyword` are mutually exclusive + # (the keyword path short-circuits classification), so each + # request emits one or the other but not both. + if decision.get("routed_model") is not None: + additional_headers["x-litellm-quality-router-model"] = str( + decision["routed_model"] + ) + if decision.get("quality_tier") is not None: + additional_headers["x-litellm-quality-router-tier"] = str( + decision["quality_tier"] + ) + if decision.get("routed_via") is not None: + additional_headers["x-litellm-quality-router-via"] = str( + decision["routed_via"] + ) + if decision.get("matched_keyword") is not None: + additional_headers["x-litellm-quality-router-keyword"] = str( + decision["matched_keyword"] + ) + if decision.get("complexity_tier") is not None: + additional_headers["x-litellm-quality-router-complexity"] = str( + decision["complexity_tier"] + ) + if ( "x-ratelimit-remaining-tokens" not in additional_headers and "x-ratelimit-remaining-requests" not in additional_headers @@ -8708,8 +8820,6 @@ class Router: and self.routing_strategy == "latency-based-routing" ): _settings_to_return[var] = self.lowestlatency_logger.routing_args.json() - elif var == "routing_strategy_args": - _settings_to_return[var] = None return _settings_to_return def update_settings(self, **kwargs): @@ -9620,7 +9730,7 @@ class Router: self, model: str, request_kwargs: Dict, - messages: Optional[List[Dict[str, str]]] = None, + messages: Optional[List[Dict[str, Any]]] = None, input: Optional[Union[str, List]] = None, specific_deployment: Optional[bool] = False, ) -> Optional[PreRoutingHookResponse]: @@ -9653,6 +9763,18 @@ class Router: specific_deployment=specific_deployment, ) + ######################################################### + # Check if any quality-router should be used + ######################################################### + if model in self.quality_routers: + return await self.quality_routers[model].async_pre_routing_hook( + model=model, + request_kwargs=request_kwargs, + messages=messages, + input=input, + specific_deployment=specific_deployment, + ) + return None def get_available_deployment( diff --git a/litellm/router_strategy/auto_router/auto_router.py b/litellm/router_strategy/auto_router/auto_router.py index 4ead7225ab..58b2c5a391 100644 --- a/litellm/router_strategy/auto_router/auto_router.py +++ b/litellm/router_strategy/auto_router/auto_router.py @@ -82,11 +82,34 @@ class AutoRouter(CustomLogger): ) return auto_router_routes + @staticmethod + def _extract_text_from_messages(messages: List[Dict[str, Any]]) -> str: + """ + Extract text content from the last user message for routing. + + Handles tool-call conversations (where the last message may be an + assistant or tool message with non-string content) and multimodal + messages (where content is a list of content blocks). + """ + for msg in reversed(messages): + if msg.get("role") == "user": + content = msg.get("content") + if content is None: + return "" + if isinstance(content, list): + return " ".join( + block.get("text", "") + for block in content + if isinstance(block, dict) and block.get("type") == "text" + ) + return str(content) + return "" + async def async_pre_routing_hook( self, model: str, request_kwargs: Dict, - messages: Optional[List[Dict[str, str]]] = None, + messages: Optional[List[Dict[str, Any]]] = None, input: Optional[Union[str, List]] = None, specific_deployment: Optional[bool] = False, ) -> Optional["PreRoutingHookResponse"]: @@ -120,8 +143,7 @@ class AutoRouter(CustomLogger): auto_sync=self.auto_sync_value, ) - user_message: Dict[str, str] = messages[-1] - message_content: str = user_message.get("content", "") + message_content = self._extract_text_from_messages(messages) route_choice: Optional[Union[RouteChoice, List[RouteChoice]]] = self.routelayer( text=message_content ) diff --git a/litellm/router_strategy/complexity_router/complexity_router.py b/litellm/router_strategy/complexity_router/complexity_router.py index e51249b1cb..aa3bcef639 100644 --- a/litellm/router_strategy/complexity_router/complexity_router.py +++ b/litellm/router_strategy/complexity_router/complexity_router.py @@ -332,45 +332,68 @@ class ComplexityRouter(CustomLogger): f"No model configured for tier {tier_key} and no default_model set" ) - async def async_pre_routing_hook( + def _resolve_messages( self, - model: str, + messages: Optional[List[Dict[str, Any]]], request_kwargs: Dict, - messages: Optional[List[Dict[str, Any]]] = None, - input: Optional[Union[str, List]] = None, - specific_deployment: Optional[bool] = False, - ) -> Optional["PreRoutingHookResponse"]: + ) -> Optional[List[Dict[str, Any]]]: """ - Pre-routing hook called before the routing decision. + Resolve messages from the request, converting from other formats if needed. - Classifies the request by complexity and returns the appropriate model. - - Args: - model: The original model name requested. - request_kwargs: The request kwargs. - messages: The messages in the request. - input: Optional input for embeddings. - specific_deployment: Whether a specific deployment was requested. - - Returns: - PreRoutingHookResponse with the routed model, or None if no routing needed. + Uses the guardrail translation handler dispatch to convert Responses API + ``input`` (or other non-chat-completions formats) into OpenAI-spec messages. """ - from litellm.types.router import PreRoutingHookResponse + if messages: + return messages - if messages is None or len(messages) == 0: - verbose_router_logger.debug( - "ComplexityRouter: No messages provided, skipping routing" - ) - return None + from litellm.litellm_core_utils.api_route_to_call_types import ( + get_call_types_for_route, + ) + from litellm.llms import load_guardrail_translation_mappings + from litellm.types.utils import CallTypes - # Extract the last user message and the last system prompt + mappings = load_guardrail_translation_mappings() + call_type: Optional[CallTypes] = None + + # 1. Try route-based inference from proxy metadata + route = request_kwargs.get("litellm_metadata", {}).get( + "user_api_key_request_route" + ) + if route: + call_types_list = get_call_types_for_route(route) + if call_types_list: + for ct in call_types_list: + if ct in mappings: + call_type = ct + break + + # 2. Fallback: try each mapped handler until one produces messages + handlers_to_try: List[Any] = [] + if call_type is not None and call_type in mappings: + handlers_to_try.append(mappings[call_type]()) + else: + handlers_to_try.extend(handler_cls() for handler_cls in mappings.values()) + + for handler in handlers_to_try: + structured = handler.get_structured_messages(request_kwargs) + if structured: + return [ + msg if isinstance(msg, dict) else msg.model_dump() # type: ignore + for msg in structured + ] + return None + + @staticmethod + def _extract_user_message_and_system_prompt( + messages: List[Dict[str, Any]], + ) -> Tuple[Optional[str], Optional[str]]: + """Extract the last user message text and last system prompt from messages.""" user_message: Optional[str] = None system_prompt: Optional[str] = None for msg in reversed(messages): role = msg.get("role", "") content = msg.get("content") or "" - # content may be a list of content parts (e.g. [{"type": "text", "text": "..."}]) if isinstance(content, list): text_parts = [ part.get("text", "") @@ -383,6 +406,52 @@ class ComplexityRouter(CustomLogger): user_message = content elif role == "system" and system_prompt is None: system_prompt = content + if user_message is not None and system_prompt is not None: + break + + return user_message, system_prompt + + async def async_pre_routing_hook( + self, + model: str, + request_kwargs: Dict, + messages: Optional[List[Dict[str, Any]]] = None, + input: Optional[Union[str, List]] = None, + specific_deployment: Optional[bool] = False, + ) -> Optional["PreRoutingHookResponse"]: + """ + Pre-routing hook called before the routing decision. + + Classifies the request by complexity and returns the appropriate model. + Supports chat completions (messages), Responses API (input), and other + formats via the guardrail translation handler dispatch. + + Args: + model: The original model name requested. + request_kwargs: The request kwargs. + messages: The messages in the request. + input: Optional input for Responses API or embeddings. + specific_deployment: Whether a specific deployment was requested. + + Returns: + PreRoutingHookResponse with the routed model, or None if no routing needed. + """ + from litellm.types.router import PreRoutingHookResponse + + resolved_messages = self._resolve_messages(messages, request_kwargs) + + if not resolved_messages: + verbose_router_logger.debug( + "ComplexityRouter: No messages could be resolved, skipping routing" + ) + return None + + # Determine whether the original request used messages directly + has_original_messages = messages is not None and len(messages) > 0 + + user_message, system_prompt = self._extract_user_message_and_system_prompt( + resolved_messages + ) if user_message is None: verbose_router_logger.debug( @@ -391,13 +460,10 @@ class ComplexityRouter(CustomLogger): return PreRoutingHookResponse( model=self.config.default_model or self.get_model_for_tier(ComplexityTier.MEDIUM), - messages=messages, + messages=messages if has_original_messages else None, ) - # Classify the request tier, score, signals = self.classify(user_message, system_prompt) - - # Get the model for this tier routed_model = self.get_model_for_tier(tier) verbose_router_logger.info( @@ -407,5 +473,5 @@ class ComplexityRouter(CustomLogger): return PreRoutingHookResponse( model=routed_model, - messages=messages, + messages=messages if has_original_messages else None, ) diff --git a/litellm/router_strategy/quality_router/__init__.py b/litellm/router_strategy/quality_router/__init__.py new file mode 100644 index 0000000000..5728943448 --- /dev/null +++ b/litellm/router_strategy/quality_router/__init__.py @@ -0,0 +1,21 @@ +""" +Quality-tier auto-router. + +Re-uses the ComplexityRouter's classification to decide a request's complexity, +then maps that complexity to an admin-configured quality tier and resolves the +target model from each candidate's `model_info.litellm_routing_preferences`. +""" + +from .config import ( + DEFAULT_COMPLEXITY_TO_QUALITY, + QualityRouterConfig, + RoutingPreferences, +) +from .quality_router import QualityRouter + +__all__ = [ + "QualityRouter", + "QualityRouterConfig", + "RoutingPreferences", + "DEFAULT_COMPLEXITY_TO_QUALITY", +] diff --git a/litellm/router_strategy/quality_router/config.py b/litellm/router_strategy/quality_router/config.py new file mode 100644 index 0000000000..125ecd5bb9 --- /dev/null +++ b/litellm/router_strategy/quality_router/config.py @@ -0,0 +1,74 @@ +""" +Configuration models for the QualityRouter. +""" + +from typing import Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field + +# Default mapping from ComplexityTier name (string) to quality tier (int). +# Higher tier = higher capability requirement. +DEFAULT_COMPLEXITY_TO_QUALITY: Dict[str, int] = { + "SIMPLE": 1, + "MEDIUM": 2, + "COMPLEX": 3, + "REASONING": 4, +} + + +class QualityRouterConfig(BaseModel): + """Configuration for the QualityRouter.""" + + available_models: List[str] = Field( + default_factory=list, + description=( + "List of candidate model names this router may route to. Each model " + "must declare its quality_tier in model_info.litellm_routing_preferences." + ), + ) + + default_model: Optional[str] = Field( + default=None, + description="Fallback model when no quality tier resolves.", + ) + + complexity_to_quality: Dict[str, int] = Field( + default_factory=lambda: DEFAULT_COMPLEXITY_TO_QUALITY.copy(), + description="Mapping from ComplexityTier name to quality tier (int).", + ) + + model_config = ConfigDict(extra="allow") + + +class RoutingPreferences(BaseModel): + """Per-deployment routing preferences declared on model_info.""" + + quality_tier: int = Field( + ..., + description="The quality tier this deployment satisfies.", + ) + + keywords: List[str] = Field( + default_factory=list, + description=( + "Substring keywords (case-insensitive) that, when present in the " + "user message, route the request to this deployment. See `order` " + "for explicit collision handling, otherwise ties fall through to " + "(highest quality_tier, then cheapest model_info.input_cost_per_token)." + ), + ) + + order: Optional[int] = Field( + default=None, + description=( + "Explicit priority used to break ties between deployments at the " + "same quality tier. Lower values win. Applies both to keyword " + "collisions and to picking between multiple deployments at the " + "same quality_tier. Tiebreak order is " + "(quality_tier DESC, order ASC, input_cost_per_token ASC, " + "model_name ASC) — quality always wins first, then explicit " + "order, then price." + ), + ) + + model_config = ConfigDict(extra="allow") diff --git a/litellm/router_strategy/quality_router/quality_router.py b/litellm/router_strategy/quality_router/quality_router.py new file mode 100644 index 0000000000..a79b4384f5 --- /dev/null +++ b/litellm/router_strategy/quality_router/quality_router.py @@ -0,0 +1,446 @@ +""" +Quality-tier Auto Router. + +Routes a request to a model at a target quality tier. The quality tier is +inferred by re-using the existing ComplexityRouter's classification, then +mapped through an admin-configured `complexity_to_quality` table. Each +candidate model declares its own `quality_tier` in +`model_info.litellm_routing_preferences`. + +Optional keyword override: deployments may also declare `keywords` in +`litellm_routing_preferences`. If any declared keyword appears in the user +message (case-insensitive substring match), the router short-circuits the +complexity-classification flow and routes to the matching deployment. When +multiple deployments match, ties are broken by (highest quality_tier first, +then cheapest `model_info.input_cost_per_token`). +""" + +import math +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +from litellm._logging import verbose_router_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.router_strategy.complexity_router.complexity_router import ( + ComplexityRouter, +) + +from .config import QualityRouterConfig, RoutingPreferences + +if TYPE_CHECKING: + from litellm.router import Router + from litellm.types.router import PreRoutingHookResponse +else: + Router = Any + PreRoutingHookResponse = Any + + +class QualityRouter(CustomLogger): + """ + Routes requests to a model at a target quality tier, with an optional + keyword override. + """ + + def __init__( + self, + model_name: str, + litellm_router_instance: "Router", + default_model: Optional[str] = None, + quality_router_config: Optional[Dict[str, Any]] = None, + ): + self.model_name = model_name + self.litellm_router_instance = litellm_router_instance + + if quality_router_config: + self.config = QualityRouterConfig(**quality_router_config) + else: + self.config = QualityRouterConfig() + + # Explicit default_model arg overrides anything in the config dict. + if default_model: + self.config.default_model = default_model + + # Internal scorer — re-use the existing rule-based classifier. + self._scorer = ComplexityRouter( + model_name=f"{model_name}::scorer", + litellm_router_instance=litellm_router_instance, + ) + + # Per-model indices populated alongside the tier index. `_model_keywords` + # stores keywords lowercased so we can substring-match against the + # lowercased user message in O(total-keyword-count). `_model_quality`, + # `_model_cost`, and `_model_order` drive tiebreaking — `_model_order` + # is the explicit priority (lower wins, unset = +inf). + self._model_keywords: Dict[str, List[str]] = {} + self._model_quality: Dict[str, int] = {} + self._model_cost: Dict[str, Optional[float]] = {} + self._model_order: Dict[str, Optional[int]] = {} + + # Tier → models index. Built lazily on first access so the QualityRouter + # deployment does NOT need to appear after all its referenced models in + # the config — when `_build_tier_index` runs eagerly in `__init__`, the + # router instance's `model_list` is still being assembled incrementally + # by `_create_deployment`, and any `available_models` defined AFTER the + # router entry in config.yaml would silently be reported as missing. + self._tier_to_models_cache: Optional[Dict[int, List[str]]] = None + + verbose_router_logger.debug( + f"QualityRouter initialized for {model_name} with " + f"available_models={self.config.available_models}, " + f"default_model={self.config.default_model}" + ) + + @property + def _tier_to_models(self) -> Dict[int, List[str]]: + """Lazy tier→models index; built on first access.""" + if self._tier_to_models_cache is None: + self._tier_to_models_cache = self._build_tier_index() + return self._tier_to_models_cache + + def _get_routing_preferences(self, deployment: Any) -> Optional[Dict[str, Any]]: + """ + Extract litellm_routing_preferences from a deployment, handling both + dict-shaped and Pydantic-object-shaped deployments. + """ + # Dict-shaped deployment. + if isinstance(deployment, dict): + model_info = deployment.get("model_info") or {} + if isinstance(model_info, dict): + return model_info.get("litellm_routing_preferences") + # Pydantic ModelInfo nested in a dict. + return getattr(model_info, "litellm_routing_preferences", None) + + # Pydantic-object deployment. + model_info = getattr(deployment, "model_info", None) + if model_info is None: + return None + if isinstance(model_info, dict): + return model_info.get("litellm_routing_preferences") + return getattr(model_info, "litellm_routing_preferences", None) + + def _get_deployment_input_cost(self, deployment: Any) -> Optional[float]: + """ + Extract `input_cost_per_token` from a deployment's model_info. + + Returns None when not declared — None is treated as "infinite cost" + for the cheapest-tiebreak ordering, so unpriced models lose ties to + priced ones. (Admins who want a model to win on price must declare it.) + """ + if isinstance(deployment, dict): + model_info = deployment.get("model_info") or {} + else: + model_info = getattr(deployment, "model_info", None) or {} + + if isinstance(model_info, dict): + cost = model_info.get("input_cost_per_token") + else: + cost = getattr(model_info, "input_cost_per_token", None) + + if cost is None: + return None + try: + return float(cost) + except (TypeError, ValueError): + return None + + def _get_deployment_model_name(self, deployment: Any) -> Optional[str]: + """Extract `model_name` from a dict- or object-shaped deployment.""" + if isinstance(deployment, dict): + return deployment.get("model_name") + return getattr(deployment, "model_name", None) + + def _build_tier_index(self) -> Dict[int, List[str]]: + """ + Build {quality_tier: [model_name, ...]} for every model in + `available_models`, plus side indices `_model_keywords`, + `_model_quality`, and `_model_cost`. Raises if any listed model is + missing `litellm_routing_preferences`. + """ + model_list = getattr(self.litellm_router_instance, "model_list", None) or [] + available = set(self.config.available_models) + + # Track which available models we've matched so we can error on missing. + seen: Dict[str, bool] = {name: False for name in available} + tier_to_models: Dict[int, List[str]] = {} + + for deployment in model_list: + name = self._get_deployment_model_name(deployment) + if name is None or name not in available: + continue + + raw_prefs = self._get_routing_preferences(deployment) + if raw_prefs is None: + raise ValueError( + f"QualityRouter: model '{name}' is listed in available_models " + f"but has no model_info.litellm_routing_preferences" + ) + + # Validate via the Pydantic model so we get a clear error for + # missing quality_tier, wrong types, etc. This also means + # `RoutingPreferences` is the single source of truth for the + # accepted shape — readers relied on raw dicts before. + try: + if isinstance(raw_prefs, RoutingPreferences): + prefs = raw_prefs + elif isinstance(raw_prefs, dict): + prefs = RoutingPreferences(**raw_prefs) + else: + # A Pydantic object of some other shape — coerce via its dict. + prefs = RoutingPreferences( + **( + raw_prefs.model_dump() + if hasattr(raw_prefs, "model_dump") + else dict(raw_prefs) + ) + ) + except Exception as e: + raise ValueError( + f"QualityRouter: model '{name}' has invalid " + f"litellm_routing_preferences: {e}" + ) from e + + tier_int = int(prefs.quality_tier) + tier_to_models.setdefault(tier_int, []).append(name) + self._model_keywords[name] = [str(k).lower() for k in prefs.keywords if k] + self._model_quality[name] = tier_int + self._model_cost[name] = self._get_deployment_input_cost(deployment) + self._model_order[name] = prefs.order + seen[name] = True + + missing = [name for name, found in seen.items() if not found] + if missing: + raise ValueError( + f"QualityRouter: the following available_models are not present in " + f"the router's model_list (or are missing routing preferences): {missing}" + ) + + # Sort each tier's model list so `_resolve_model_for_quality_tier` + # (which picks index [0]) honors (order ASC, cost ASC, name ASC). + # Quality is moot within a single tier; keep parity with the keyword + # tiebreak by ordering on (order, cost, name) here. + for models in tier_to_models.values(): + models.sort(key=lambda n: (self._order_key(n), self._cost_key(n), n)) + + return tier_to_models + + def _order_key(self, model_name: str) -> float: + """`order` lookup as a float — unset becomes +inf so explicit wins.""" + order = self._model_order.get(model_name) + return float(order) if order is not None else math.inf + + def _cost_key(self, model_name: str) -> float: + """`input_cost_per_token` as a float — unset becomes +inf.""" + cost = self._model_cost.get(model_name) + return float(cost) if cost is not None else math.inf + + def _keyword_override(self, user_message: str) -> Optional[Tuple[str, str]]: + """ + Find a deployment whose declared keywords appear in `user_message`. + + Returns (model_name, matched_keyword) or None when no keyword matches. + When multiple deployments match, sorts by: + 1. quality_tier DESC (best quality always wins first) + 2. `order` ASC (explicit priority — unset = +inf so explicit wins + within the same tier) + 3. input_cost_per_token ASC (unpriced = +inf so priced wins) + 4. model_name ASC (deterministic stability) + """ + # Touch the lazy index so `_model_keywords` / `_model_quality` / + # `_model_cost` / `_model_order` are populated. + _ = self._tier_to_models + + text = user_message.lower() + + matches: List[Tuple[str, str]] = [] # (model_name, matched_keyword) + for model_name, keywords in self._model_keywords.items(): + for kw in keywords: + if kw and kw in text: + matches.append((model_name, kw)) + break # one match per model is enough + + if not matches: + return None + + def sort_key(match: Tuple[str, str]) -> Tuple[int, float, float, str]: + name = match[0] + quality = self._model_quality.get(name, 0) + order_val = self._order_key(name) + cost = self._model_cost.get(name) + cost_val = cost if cost is not None else math.inf + # Negate quality so higher tier sorts first under ASC sort. + return (-quality, order_val, cost_val, name) + + matches.sort(key=sort_key) + return matches[0] + + def _resolve_model_for_quality_tier(self, tier: int) -> str: + """ + Resolve a quality tier to a concrete model name. + + Strategy: + 1. Exact tier match → first model registered at that tier. + 2. Round UP to the next higher tier that has a model (closer to a + request we might lack capacity for). + 3. Round DOWN to the closest lower tier that has a model (degrade + gracefully instead of jumping straight to `default_model`, + which may be off-tier). + 4. Fall back to `config.default_model`. + 5. Otherwise raise. + """ + tier_index = self._tier_to_models + if tier in tier_index and tier_index[tier]: + return tier_index[tier][0] + + # Round up. + higher_tiers = sorted(t for t in tier_index if t > tier) + for t in higher_tiers: + if tier_index[t]: + return tier_index[t][0] + + # Round down — closest lower tier first. + lower_tiers = sorted((t for t in tier_index if t < tier), reverse=True) + for t in lower_tiers: + if tier_index[t]: + return tier_index[t][0] + + if self.config.default_model: + return self.config.default_model + + raise ValueError( + f"QualityRouter: no model available for quality tier {tier} and " + f"no default_model configured" + ) + + def _stash_decision( + self, + request_kwargs: Optional[Dict[str, Any]], + decision: Dict[str, Any], + ) -> None: + """ + Stash the routing decision in request_kwargs.metadata so the Router can + lift it into response headers (`x-litellm-quality-router-*`). The same + dict object flows from here through to `make_call.set_response_headers`. + """ + if request_kwargs is None: + return + metadata = request_kwargs.setdefault("metadata", {}) + if isinstance(metadata, dict): + metadata["quality_router_decision"] = decision + + async def async_pre_routing_hook( + self, + model: str, + request_kwargs: Dict, + messages: Optional[List[Dict[str, Any]]] = None, + input: Optional[Union[str, List]] = None, + specific_deployment: Optional[bool] = False, + ) -> Optional["PreRoutingHookResponse"]: + """Try keyword override first; fall back to complexity-tier routing.""" + from litellm.types.router import PreRoutingHookResponse + + if messages is None or len(messages) == 0: + verbose_router_logger.debug( + "QualityRouter: No messages provided, skipping routing" + ) + return None + + # Extract last user message and last system prompt — same rules as + # ComplexityRouter.async_pre_routing_hook. + user_message: Optional[str] = None + system_prompt: Optional[str] = None + + for msg in reversed(messages): + role = msg.get("role", "") + content = msg.get("content") or "" + if isinstance(content, list): + text_parts = [ + part.get("text", "") + for part in content + if isinstance(part, dict) and part.get("type") == "text" + ] + content = " ".join(text_parts).strip() + if isinstance(content, str) and content: + if role == "user" and user_message is None: + user_message = content + elif role == "system" and system_prompt is None: + system_prompt = content + + if user_message is None: + verbose_router_logger.debug( + "QualityRouter: No user message found, routing to default model" + ) + if not self.config.default_model: + raise ValueError( + "QualityRouter: no user message and no default_model configured" + ) + return PreRoutingHookResponse( + model=self.config.default_model, + messages=messages, + ) + + # Try keyword override first — it short-circuits complexity classification. + keyword_match = self._keyword_override(user_message) + if keyword_match is not None: + routed_model, matched_keyword = keyword_match + verbose_router_logger.info( + f"QualityRouter: keyword override matched='{matched_keyword}' " + f"routed_model={routed_model} " + f"(quality_tier={self._model_quality.get(routed_model)}, " + f"input_cost_per_token={self._model_cost.get(routed_model)})" + ) + self._stash_decision( + request_kwargs, + { + "router_model_name": self.model_name, + "routed_model": routed_model, + "routed_via": "keyword", + "matched_keyword": matched_keyword, + "quality_tier": self._model_quality.get(routed_model), + "complexity_tier": None, + }, + ) + return PreRoutingHookResponse( + model=routed_model, + messages=messages, + ) + + # No keyword match → complexity classification flow. + complexity_tier, score, signals = self._scorer.classify( + user_message, system_prompt + ) + complexity_name = ( + complexity_tier.value + if hasattr(complexity_tier, "value") + else str(complexity_tier) + ) + + quality_tier = self.config.complexity_to_quality.get(complexity_name) + if quality_tier is None: + raise ValueError( + f"QualityRouter: complexity tier '{complexity_name}' not present " + f"in complexity_to_quality mapping {self.config.complexity_to_quality}" + ) + + routed_model = self._resolve_model_for_quality_tier(int(quality_tier)) + + verbose_router_logger.info( + f"QualityRouter: complexity={complexity_name}, score={score:.3f}, " + f"signals={signals}, quality_tier={quality_tier}, " + f"routed_model={routed_model}" + ) + + self._stash_decision( + request_kwargs, + { + "router_model_name": self.model_name, + "routed_model": routed_model, + "routed_via": "quality_tier", + "matched_keyword": None, + "quality_tier": int(quality_tier), + "complexity_tier": complexity_name, + }, + ) + + return PreRoutingHookResponse( + model=routed_model, + messages=messages, + ) diff --git a/litellm/types/router.py b/litellm/types/router.py index 6bd64915d7..fb71e1f649 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -221,6 +221,10 @@ class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams): complexity_router_config: Optional[Dict] = None complexity_router_default_model: Optional[str] = None + # quality-router params + quality_router_config: Optional[Dict] = None + quality_router_default_model: Optional[str] = None + # Batch/File API Params s3_bucket_name: Optional[str] = None s3_encryption_key_id: Optional[str] = None diff --git a/tests/test_litellm/llms/openai/chat/guardrail_translation/test_openai_guardrail_handler.py b/tests/test_litellm/llms/openai/chat/guardrail_translation/test_openai_guardrail_handler.py index a4ac4c94d2..a2c3700294 100644 --- a/tests/test_litellm/llms/openai/chat/guardrail_translation/test_openai_guardrail_handler.py +++ b/tests/test_litellm/llms/openai/chat/guardrail_translation/test_openai_guardrail_handler.py @@ -891,6 +891,61 @@ class TestOpenAIChatCompletionsHandlerStreamingOutput: assert result == responses_so_far +class TestGetStructuredMessages: + """Test the get_structured_messages method.""" + + def test_should_return_messages_from_chat_completions_request(self): + """Test that messages are returned from a chat completions request.""" + handler = OpenAIChatCompletionsHandler() + data = { + "messages": [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + ] + } + result = handler.get_structured_messages(data) + assert result is not None + assert len(result) == 2 + assert result[0]["role"] == "system" + assert result[1]["role"] == "user" + + def test_should_return_none_when_no_messages(self): + """Test that None is returned when no messages key exists.""" + handler = OpenAIChatCompletionsHandler() + data = {"model": "gpt-4"} + result = handler.get_structured_messages(data) + assert result is None + + def test_should_return_none_for_none_messages(self): + """Test that None is returned when messages is explicitly None.""" + handler = OpenAIChatCompletionsHandler() + data = {"messages": None} + result = handler.get_structured_messages(data) + assert result is None + + def test_should_handle_multimodal_content(self): + """Test that messages with multimodal content are returned.""" + handler = OpenAIChatCompletionsHandler() + data = { + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": {"url": "https://example.com/image.png"}, + }, + ], + } + ] + } + result = handler.get_structured_messages(data) + assert result is not None + assert len(result) == 1 + assert isinstance(result[0]["content"], list) + + if __name__ == "__main__": # Run the tests pytest.main([__file__, "-v"]) diff --git a/tests/test_litellm/llms/openai/responses/test_openai_responses_guardrail_handler.py b/tests/test_litellm/llms/openai/responses/test_openai_responses_guardrail_handler.py index ccece8018f..aee6ccc2e7 100644 --- a/tests/test_litellm/llms/openai/responses/test_openai_responses_guardrail_handler.py +++ b/tests/test_litellm/llms/openai/responses/test_openai_responses_guardrail_handler.py @@ -995,3 +995,63 @@ class TestOpenAIResponsesHandlerStreamingOutputProcessing: # Should return the responses assert result == responses_so_far + + +class TestGetStructuredMessages: + """Test the get_structured_messages method for Responses API handler.""" + + def test_should_convert_string_input_to_messages(self): + """Test that a simple string input is converted to OpenAI messages.""" + handler = OpenAIResponsesHandler() + data = {"input": "What is the capital of France?"} + result = handler.get_structured_messages(data) + assert result is not None + assert len(result) >= 1 + found_user = False + for msg in result: + if isinstance(msg, dict) and msg.get("role") == "user": + found_user = True + break + assert found_user, f"Expected a user message, got: {result}" + + def test_should_convert_list_input_to_messages(self): + """Test that list input (ResponseInputParam) is converted to OpenAI messages.""" + handler = OpenAIResponsesHandler() + data = { + "input": [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + {"role": "user", "content": "How are you?"}, + ] + } + result = handler.get_structured_messages(data) + assert result is not None + assert len(result) >= 3 + + def test_should_include_instructions_as_system_message(self): + """Test that instructions are included as a system message.""" + handler = OpenAIResponsesHandler() + data = { + "input": "Roll a d20", + "instructions": "You are a helpful dungeon master.", + } + result = handler.get_structured_messages(data) + assert result is not None + has_system = any( + isinstance(msg, dict) and msg.get("role") == "system" for msg in result + ) + assert has_system, f"Expected system message from instructions, got: {result}" + + def test_should_return_none_when_no_input(self): + """Test that None is returned when input key is missing.""" + handler = OpenAIResponsesHandler() + data = {"model": "gpt-4o"} + result = handler.get_structured_messages(data) + assert result is None + + def test_should_return_none_for_none_input(self): + """Test that None is returned when input is explicitly None.""" + handler = OpenAIResponsesHandler() + data = {"input": None} + result = handler.get_structured_messages(data) + assert result is None diff --git a/tests/test_litellm/router_strategy/test_auto_router.py b/tests/test_litellm/router_strategy/test_auto_router.py index caff2bc8f1..cb46a4ae55 100644 --- a/tests/test_litellm/router_strategy/test_auto_router.py +++ b/tests/test_litellm/router_strategy/test_auto_router.py @@ -12,7 +12,148 @@ sys.path.insert( from litellm.router_strategy.auto_router.auto_router import AutoRouter -pytestmark = pytest.mark.skip(reason="Skipping auto router tests - beta feature") +pytestmark_skip_beta = pytest.mark.skip( + reason="Skipping auto router tests - beta feature" +) + + +class TestExtractTextFromMessages: + """Tests for AutoRouter._extract_text_from_messages (no semantic_router dependency).""" + + def test_should_extract_content_from_simple_user_message(self): + messages = [{"role": "user", "content": "Hello world"}] + result = AutoRouter._extract_text_from_messages(messages) + assert result == "Hello world" + + def test_should_extract_last_user_message_from_tool_call_conversation(self): + messages = [ + {"role": "user", "content": "What's the weather in NYC?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "NYC"}', + }, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_abc123", + "content": "72°F and sunny", + }, + {"role": "user", "content": "Now tell me about London"}, + ] + result = AutoRouter._extract_text_from_messages(messages) + assert result == "Now tell me about London" + + def test_should_find_user_message_when_last_message_is_assistant_with_tool_calls( + self, + ): + messages = [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_abc", + "type": "function", + "function": {"name": "get_weather", "arguments": "{}"}, + } + ], + }, + ] + result = AutoRouter._extract_text_from_messages(messages) + assert result == "What's the weather?" + + def test_should_find_user_message_when_last_message_is_tool_response(self): + messages = [ + {"role": "user", "content": "What's the weather?"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_abc", + "type": "function", + "function": {"name": "get_weather", "arguments": "{}"}, + } + ], + }, + { + "role": "tool", + "tool_call_id": "call_abc", + "content": "72°F and sunny", + }, + ] + result = AutoRouter._extract_text_from_messages(messages) + assert result == "What's the weather?" + + def test_should_handle_multimodal_content_list(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": {"url": "https://example.com/img.png"}, + }, + ], + } + ] + result = AutoRouter._extract_text_from_messages(messages) + assert result == "What's in this image?" + + def test_should_handle_multimodal_content_with_multiple_text_blocks(self): + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "First part"}, + {"type": "text", "text": "Second part"}, + { + "type": "image_url", + "image_url": {"url": "https://example.com/img.png"}, + }, + ], + } + ] + result = AutoRouter._extract_text_from_messages(messages) + assert result == "First part Second part" + + def test_should_return_empty_string_when_user_content_is_none(self): + messages = [{"role": "user", "content": None}] + result = AutoRouter._extract_text_from_messages(messages) + assert result == "" + + def test_should_return_empty_string_when_no_user_messages(self): + messages = [ + {"role": "system", "content": "You are a helpful assistant"}, + { + "role": "assistant", + "content": None, + "tool_calls": [ + { + "id": "call_abc", + "type": "function", + "function": {"name": "get_weather", "arguments": "{}"}, + } + ], + }, + ] + result = AutoRouter._extract_text_from_messages(messages) + assert result == "" + + def test_should_return_empty_string_for_empty_messages_list(self): + result = AutoRouter._extract_text_from_messages([]) + assert result == "" @pytest.fixture @@ -41,6 +182,7 @@ def mock_route_choice(): return mock_choice +@pytestmark_skip_beta class TestAutoRouter: """Test class for AutoRouter methods.""" diff --git a/tests/test_litellm/router_strategy/test_complexity_router.py b/tests/test_litellm/router_strategy/test_complexity_router.py index 8d36fc2ba3..e68ea863d8 100644 --- a/tests/test_litellm/router_strategy/test_complexity_router.py +++ b/tests/test_litellm/router_strategy/test_complexity_router.py @@ -7,7 +7,7 @@ Tests the rule-based complexity scoring and tier assignment logic. import os import sys from typing import Dict, List -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest @@ -828,3 +828,222 @@ class TestRouterComplexityDeploymentMethods: ) router.init_complexity_router_deployment(deployment) assert "auto_router/complexity_router/test-router" in router.complexity_routers + + +class TestAsyncPreRoutingHookMultiFormat: + """Test async_pre_routing_hook with multiple input formats.""" + + @pytest.mark.asyncio + async def test_should_route_with_chat_completions_messages(self, complexity_router): + """Test routing with standard chat completions messages.""" + result = await complexity_router.async_pre_routing_hook( + model="test-model", + request_kwargs={}, + messages=[{"role": "user", "content": "What is 2+2?"}], + ) + assert result is not None + assert result.model is not None + assert result.messages is not None + + @pytest.mark.asyncio + async def test_should_route_with_responses_api_string_input( + self, complexity_router + ): + """Test routing with Responses API string input via handler dispatch.""" + from litellm.llms.openai.responses.guardrail_translation.handler import ( + OpenAIResponsesHandler, + ) + from litellm.types.utils import CallTypes + + mock_mappings = {CallTypes.responses: OpenAIResponsesHandler} + + with patch( + "litellm.llms.load_guardrail_translation_mappings", + return_value=mock_mappings, + ): + result = await complexity_router.async_pre_routing_hook( + model="test-model", + request_kwargs={"input": "What is the capital of France?"}, + messages=None, + input="What is the capital of France?", + ) + + assert result is not None + assert result.model is not None + # messages should be None since the original request didn't have messages + assert result.messages is None + + @pytest.mark.asyncio + async def test_should_route_with_responses_api_list_input(self, complexity_router): + """Test routing with Responses API list input via handler dispatch.""" + from litellm.llms.openai.responses.guardrail_translation.handler import ( + OpenAIResponsesHandler, + ) + from litellm.types.utils import CallTypes + + mock_mappings = {CallTypes.responses: OpenAIResponsesHandler} + + list_input = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi there!"}, + { + "role": "user", + "content": "Write a Python function to sort a list using merge sort", + }, + ] + + with patch( + "litellm.llms.load_guardrail_translation_mappings", + return_value=mock_mappings, + ): + result = await complexity_router.async_pre_routing_hook( + model="test-model", + request_kwargs={"input": list_input}, + messages=None, + input=list_input, + ) + + assert result is not None + assert result.model is not None + assert result.messages is None + + @pytest.mark.asyncio + async def test_should_use_route_based_inference(self, complexity_router): + """Test that route-based call type inference is used when available.""" + from litellm.llms.openai.responses.guardrail_translation.handler import ( + OpenAIResponsesHandler, + ) + from litellm.types.utils import CallTypes + + mock_mappings = {CallTypes.responses: OpenAIResponsesHandler} + + with patch( + "litellm.llms.load_guardrail_translation_mappings", + return_value=mock_mappings, + ): + result = await complexity_router.async_pre_routing_hook( + model="test-model", + request_kwargs={ + "input": "Roll 2d4+1", + "litellm_metadata": { + "user_api_key_request_route": "/v1/responses", + }, + }, + messages=None, + ) + + assert result is not None + assert result.model is not None + + @pytest.mark.asyncio + async def test_should_return_none_when_no_messages_or_input( + self, complexity_router + ): + """Test that None is returned when neither messages nor input is available.""" + result = await complexity_router.async_pre_routing_hook( + model="test-model", + request_kwargs={}, + messages=None, + input=None, + ) + assert result is None + + @pytest.mark.asyncio + async def test_should_prefer_original_messages_over_conversion( + self, complexity_router + ): + """Test that original messages are used when both messages and input are available.""" + messages = [{"role": "user", "content": "What is 2+2?"}] + result = await complexity_router.async_pre_routing_hook( + model="test-model", + request_kwargs={"input": "This should be ignored"}, + messages=messages, + ) + assert result is not None + assert result.messages == messages + + @pytest.mark.asyncio + async def test_should_include_instructions_in_classification( + self, complexity_router + ): + """Test that Responses API instructions influence classification via system message.""" + from litellm.llms.openai.responses.guardrail_translation.handler import ( + OpenAIResponsesHandler, + ) + from litellm.types.utils import CallTypes + + mock_mappings = {CallTypes.responses: OpenAIResponsesHandler} + + with patch( + "litellm.llms.load_guardrail_translation_mappings", + return_value=mock_mappings, + ): + result = await complexity_router.async_pre_routing_hook( + model="test-model", + request_kwargs={ + "input": "Write merge sort", + "instructions": "You are an expert Python developer. Use advanced algorithms and optimize for performance.", + }, + messages=None, + ) + + assert result is not None + assert result.model is not None + + +class TestExtractUserMessageAndSystemPrompt: + """Test the _extract_user_message_and_system_prompt static method.""" + + def test_should_extract_user_message(self): + """Test extraction of the last user message.""" + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + {"role": "user", "content": "How are you?"}, + ] + user_msg, sys_prompt = ComplexityRouter._extract_user_message_and_system_prompt( + messages + ) + assert user_msg == "How are you?" + assert sys_prompt == "You are helpful." + + def test_should_handle_no_user_message(self): + """Test when there is no user message.""" + messages = [ + {"role": "system", "content": "You are helpful."}, + {"role": "assistant", "content": "Hi!"}, + ] + user_msg, sys_prompt = ComplexityRouter._extract_user_message_and_system_prompt( + messages + ) + assert user_msg is None + assert sys_prompt == "You are helpful." + + def test_should_handle_multipart_content(self): + """Test extraction from multipart content messages.""" + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Describe this image"}, + { + "type": "image_url", + "image_url": {"url": "https://example.com/img.png"}, + }, + ], + } + ] + user_msg, sys_prompt = ComplexityRouter._extract_user_message_and_system_prompt( + messages + ) + assert user_msg == "Describe this image" + assert sys_prompt is None + + def test_should_handle_empty_messages(self): + """Test with empty messages list.""" + user_msg, sys_prompt = ComplexityRouter._extract_user_message_and_system_prompt( + [] + ) + assert user_msg is None + assert sys_prompt is None diff --git a/tests/test_litellm/router_strategy/test_quality_router.py b/tests/test_litellm/router_strategy/test_quality_router.py new file mode 100644 index 0000000000..01574cb980 --- /dev/null +++ b/tests/test_litellm/router_strategy/test_quality_router.py @@ -0,0 +1,1033 @@ +""" +Tests for the QualityRouter. + +Covers: +- Tier index construction from `model_info.litellm_routing_preferences`. +- Quality-tier resolution (exact, round-up, default fallback). +- Keyword override (match, tiebreaking by quality + price). +- Pre-routing hook end-to-end. +- Decision metadata stash + Router.set_response_headers lift. +""" + +import os +import sys +from typing import Any, Dict, List +from unittest.mock import MagicMock + +import pytest + +sys.path.insert(0, os.path.abspath("../../..")) + +from litellm.router_strategy.quality_router.config import ( + DEFAULT_COMPLEXITY_TO_QUALITY, +) +from litellm.router_strategy.quality_router.quality_router import QualityRouter + + +def _make_model_list(spec: List[Dict[str, Any]]) -> List[Dict[str, Any]]: + """ + Build a router model_list from a compact spec. + + spec entry shape: { + "model_name": str, + "quality_tier": Optional[int], + "keywords": Optional[List[str]], + "order": Optional[int], + "input_cost_per_token": Optional[float], + } + If quality_tier is None, the deployment is created without + `litellm_routing_preferences`. + """ + out: List[Dict[str, Any]] = [] + for entry in spec: + model_info: Dict[str, Any] = {"id": f"id-{entry['model_name']}"} + if entry.get("quality_tier") is not None: + prefs: Dict[str, Any] = {"quality_tier": entry["quality_tier"]} + if "keywords" in entry: + prefs["keywords"] = entry["keywords"] + if "order" in entry: + prefs["order"] = entry["order"] + model_info["litellm_routing_preferences"] = prefs + if "input_cost_per_token" in entry: + model_info["input_cost_per_token"] = entry["input_cost_per_token"] + out.append( + { + "model_name": entry["model_name"], + "litellm_params": {"model": f"openai/{entry['model_name']}"}, + "model_info": model_info, + } + ) + return out + + +@pytest.fixture +def four_tier_model_list() -> List[Dict[str, Any]]: + """A standard haiku(1)/sonnet(2)/opus(3)/opus-next(4) model list.""" + return _make_model_list( + [ + {"model_name": "haiku", "quality_tier": 1}, + {"model_name": "sonnet", "quality_tier": 2}, + {"model_name": "opus", "quality_tier": 3}, + {"model_name": "opus-next", "quality_tier": 4}, + ] + ) + + +@pytest.fixture +def mock_router(four_tier_model_list): + """A MagicMock router preloaded with the four-tier model list.""" + router = MagicMock() + router.model_list = four_tier_model_list + return router + + +@pytest.fixture +def quality_router(mock_router) -> QualityRouter: + """Default QualityRouter wired to all four tiers.""" + config = { + "available_models": ["haiku", "sonnet", "opus", "opus-next"], + "complexity_to_quality": DEFAULT_COMPLEXITY_TO_QUALITY, + } + return QualityRouter( + model_name="quality-router-test", + litellm_router_instance=mock_router, + default_model="haiku", + quality_router_config=config, + ) + + +# ─── Tier index ───────────────────────────────────────────────────────────── + + +class TestTierIndex: + def test_builds_correct_tier_to_models_map(self, quality_router): + assert quality_router._tier_to_models == { + 1: ["haiku"], + 2: ["sonnet"], + 3: ["opus"], + 4: ["opus-next"], + } + + def test_ignores_models_not_in_available_models(self, four_tier_model_list): + # Add a model the config doesn't list — it should be ignored. + extra = _make_model_list([{"model_name": "ghost", "quality_tier": 5}]) + router = MagicMock() + router.model_list = four_tier_model_list + extra + + qr = QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="haiku", + quality_router_config={ + "available_models": ["haiku", "sonnet", "opus", "opus-next"] + }, + ) + + for models in qr._tier_to_models.values(): + assert "ghost" not in models + + def test_raises_when_routing_preferences_missing(self): + # `sonnet` is in available_models but has no preferences. + ml = _make_model_list( + [ + {"model_name": "haiku", "quality_tier": 1}, + {"model_name": "sonnet", "quality_tier": None}, + ] + ) + router = MagicMock() + router.model_list = ml + + # Construction succeeds (tier index is lazy); the error surfaces on + # first use so the router entry doesn't have to appear after all of + # its referenced models in config.yaml. + qr = QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="haiku", + quality_router_config={"available_models": ["haiku", "sonnet"]}, + ) + with pytest.raises(ValueError, match="sonnet"): + _ = qr._tier_to_models + + +# ─── Resolve model for quality tier ───────────────────────────────────────── + + +class TestResolveModelForQualityTier: + def test_exact_match(self, quality_router): + assert quality_router._resolve_model_for_quality_tier(2) == "sonnet" + assert quality_router._resolve_model_for_quality_tier(4) == "opus-next" + + def test_rounds_up_when_tier_missing(self, mock_router): + # Available tiers: 1, 3, 4. Asking for 2 should round up to 3. + spec = [ + {"model_name": "haiku", "quality_tier": 1}, + {"model_name": "opus", "quality_tier": 3}, + {"model_name": "opus-next", "quality_tier": 4}, + ] + router = MagicMock() + router.model_list = _make_model_list(spec) + + qr = QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="haiku", + quality_router_config={"available_models": ["haiku", "opus", "opus-next"]}, + ) + + assert qr._resolve_model_for_quality_tier(2) == "opus" + + def test_rounds_down_when_no_higher_tier_exists(self): + # Only tier 1 available. Asking for tier 4 rounds up (nothing), then + # rounds DOWN to the closest lower tier — tier 1. + spec = [{"model_name": "haiku", "quality_tier": 1}] + router = MagicMock() + router.model_list = _make_model_list(spec) + + qr = QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="emergency-default", + quality_router_config={"available_models": ["haiku"]}, + ) + + assert qr._resolve_model_for_quality_tier(4) == "haiku" + + def test_rounds_down_prefers_closest_lower_tier(self): + # Available: 1, 2. Asking for 4 rounds down to tier 2 (not tier 1). + spec = [ + {"model_name": "haiku", "quality_tier": 1}, + {"model_name": "sonnet", "quality_tier": 2}, + ] + router = MagicMock() + router.model_list = _make_model_list(spec) + + qr = QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="emergency-default", + quality_router_config={"available_models": ["haiku", "sonnet"]}, + ) + + assert qr._resolve_model_for_quality_tier(4) == "sonnet" + + def test_prefers_round_up_over_round_down(self): + # Available: 1, 3. Asking for 2 rounds UP to 3, not DOWN to 1. + spec = [ + {"model_name": "haiku", "quality_tier": 1}, + {"model_name": "opus", "quality_tier": 3}, + ] + router = MagicMock() + router.model_list = _make_model_list(spec) + + qr = QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="emergency-default", + quality_router_config={"available_models": ["haiku", "opus"]}, + ) + + assert qr._resolve_model_for_quality_tier(2) == "opus" + + +# ─── RoutingPreferences validation ───────────────────────────────────────── + + +class TestRoutingPreferencesValidation: + def test_invalid_quality_tier_type_raises_clear_error(self): + # quality_tier must be an int — pass a non-coercible string. + ml = [ + { + "model_name": "haiku", + "litellm_params": {"model": "openai/gpt-4o-mini"}, + "model_info": { + "id": "id-haiku", + "litellm_routing_preferences": {"quality_tier": "not-an-int"}, + }, + } + ] + router = MagicMock() + router.model_list = ml + + qr = QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="haiku", + quality_router_config={"available_models": ["haiku"]}, + ) + with pytest.raises(ValueError, match="invalid litellm_routing_preferences"): + _ = qr._tier_to_models + + +# ─── Config-ordering independence (lazy index build) ─────────────────────── + + +class TestConfigOrderingIndependence: + def test_router_can_be_instantiated_before_its_targets_exist(self): + # Build a router instance whose referenced model_list is EMPTY at + # construction time (simulating a config where the router entry + # appears before its target deployments). The tier index must not be + # built eagerly — it's deferred until first use. + router = MagicMock() + router.model_list = [] # <- targets haven't been added yet + + qr = QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="haiku", + quality_router_config={"available_models": ["haiku", "sonnet", "opus"]}, + ) + + # Now the targets come online. This mirrors the incremental add by + # `Router._create_deployment`. + router.model_list = _make_model_list( + [ + {"model_name": "haiku", "quality_tier": 1}, + {"model_name": "sonnet", "quality_tier": 2}, + {"model_name": "opus", "quality_tier": 3}, + ] + ) + + # First access triggers the index build and sees the full list. + assert qr._tier_to_models == { + 1: ["haiku"], + 2: ["sonnet"], + 3: ["opus"], + } + + +# ─── Router.set_model_list resets quality_routers (hot reload) ───────────── + + +class TestSetModelListResetsQualityRouters: + def test_set_model_list_clears_quality_routers_registry(self): + from litellm.router import Router + + router = Router( + model_list=[ + { + "model_name": "haiku", + "litellm_params": { + "model": "openai/gpt-4o-mini", + "api_key": "sk-test", + }, + "model_info": {"litellm_routing_preferences": {"quality_tier": 1}}, + }, + { + "model_name": "my-qr", + "litellm_params": { + "model": "auto_router/quality_router", + "quality_router_default_model": "haiku", + "quality_router_config": {"available_models": ["haiku"]}, + }, + }, + ] + ) + + assert "my-qr" in router.quality_routers + + # Hot-reload with a new model_list that doesn't define the router. + router.set_model_list( + [ + { + "model_name": "haiku", + "litellm_params": { + "model": "openai/gpt-4o-mini", + "api_key": "sk-test", + }, + } + ] + ) + + # Stale router from before must be cleared. + assert "my-qr" not in router.quality_routers + + +# ─── Pre-routing hook ─────────────────────────────────────────────────────── + + +class TestPreRoutingHook: + @pytest.mark.asyncio + async def test_simple_message_routes_to_tier_1(self, quality_router): + messages = [{"role": "user", "content": "hi"}] + resp = await quality_router.async_pre_routing_hook( + model="quality-router-test", + request_kwargs={}, + messages=messages, + ) + assert resp is not None + assert resp.model == "haiku" + + @pytest.mark.asyncio + async def test_reasoning_message_routes_to_tier_4(self, quality_router): + # Two reasoning markers triggers ComplexityTier.REASONING → quality 4. + messages = [ + { + "role": "user", + "content": ( + "Think step by step and reason through this problem. " + "Analyze this carefully and break down each component." + ), + } + ] + resp = await quality_router.async_pre_routing_hook( + model="quality-router-test", + request_kwargs={}, + messages=messages, + ) + assert resp is not None + assert resp.model == "opus-next" + + @pytest.mark.asyncio + async def test_empty_messages_returns_none(self, quality_router): + resp = await quality_router.async_pre_routing_hook( + model="quality-router-test", + request_kwargs={}, + messages=[], + ) + assert resp is None + + @pytest.mark.asyncio + async def test_only_system_message_routes_to_default(self, quality_router): + messages = [{"role": "system", "content": "You are a helpful assistant."}] + resp = await quality_router.async_pre_routing_hook( + model="quality-router-test", + request_kwargs={}, + messages=messages, + ) + assert resp is not None + assert resp.model == "haiku" # the configured default_model + + +# ─── Keyword override ────────────────────────────────────────────────────── + + +@pytest.fixture +def keyword_router(): + """ + Router where multiple deployments declare overlapping keywords so we can + exercise the (quality DESC, price ASC) tiebreak. + + - cheap-coder tier 2, keywords [code, python], cost 0.000001 + - smart-coder tier 3, keywords [code, python], cost 0.000010 + - law-bot tier 2, keywords [legal, contract], cost 0.000005 + - default-haiku tier 1, no keywords, cost 0.0000005 + """ + spec = [ + { + "model_name": "default-haiku", + "quality_tier": 1, + "keywords": [], + "input_cost_per_token": 0.0000005, + }, + { + "model_name": "cheap-coder", + "quality_tier": 2, + "keywords": ["code", "python"], + "input_cost_per_token": 0.000001, + }, + { + "model_name": "smart-coder", + "quality_tier": 3, + "keywords": ["code", "python"], + "input_cost_per_token": 0.000010, + }, + { + "model_name": "law-bot", + "quality_tier": 2, + "keywords": ["legal", "contract"], + "input_cost_per_token": 0.000005, + }, + ] + router = MagicMock() + router.model_list = _make_model_list(spec) + return QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="default-haiku", + quality_router_config={ + "available_models": [ + "default-haiku", + "cheap-coder", + "smart-coder", + "law-bot", + ], + }, + ) + + +class TestKeywordOverride: + def test_no_keyword_in_message_returns_none(self, keyword_router): + assert keyword_router._keyword_override("hello there") is None + + def test_single_match_returns_that_model(self, keyword_router): + # Only law-bot declares "legal". + assert keyword_router._keyword_override("review this legal doc") == ( + "law-bot", + "legal", + ) + + def test_case_insensitive_match(self, keyword_router): + assert keyword_router._keyword_override("LEGAL question") == ( + "law-bot", + "legal", + ) + + def test_overlap_picks_highest_quality_tier(self, keyword_router): + # Both cheap-coder (tier 2) and smart-coder (tier 3) declare "code". + # Quality wins over price → smart-coder. + assert keyword_router._keyword_override("write some code for me") == ( + "smart-coder", + "code", + ) + + def test_same_tier_picks_cheapest(self): + # Two models at the same tier, both matching "data" — cheapest wins. + spec = [ + { + "model_name": "expensive", + "quality_tier": 2, + "keywords": ["data"], + "input_cost_per_token": 0.000050, + }, + { + "model_name": "cheap", + "quality_tier": 2, + "keywords": ["data"], + "input_cost_per_token": 0.000005, + }, + ] + router = MagicMock() + router.model_list = _make_model_list(spec) + qr = QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="cheap", + quality_router_config={"available_models": ["expensive", "cheap"]}, + ) + match = qr._keyword_override("show me the data") + assert match == ("cheap", "data") + + def test_unpriced_loses_to_priced_at_same_tier(self): + # Same quality tier, one has cost, one doesn't → priced wins. + spec = [ + { + "model_name": "no-price", + "quality_tier": 2, + "keywords": ["data"], + # input_cost_per_token deliberately omitted + }, + { + "model_name": "with-price", + "quality_tier": 2, + "keywords": ["data"], + "input_cost_per_token": 0.000005, + }, + ] + router = MagicMock() + router.model_list = _make_model_list(spec) + qr = QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="no-price", + quality_router_config={"available_models": ["no-price", "with-price"]}, + ) + match = qr._keyword_override("show me the data") + assert match == ("with-price", "data") + + @pytest.mark.asyncio + async def test_hook_short_circuits_complexity_on_keyword_match( + self, keyword_router + ): + # A reasoning-style prompt would normally route to a high-quality model + # via the complexity flow — but the keyword "code" should short-circuit + # to smart-coder (highest tier among "code" models). + messages = [ + { + "role": "user", + "content": ( + "Think step by step and reason through this code problem. " + "Analyze this carefully and break down each component." + ), + } + ] + request_kwargs: Dict[str, Any] = {} + resp = await keyword_router.async_pre_routing_hook( + model="qr", + request_kwargs=request_kwargs, + messages=messages, + ) + assert resp is not None + assert resp.model == "smart-coder" + + decision = request_kwargs["metadata"]["quality_router_decision"] + assert decision["routed_via"] == "keyword" + assert decision["matched_keyword"] == "code" + assert decision["complexity_tier"] is None # short-circuited + + def test_quality_wins_over_explicit_order(self): + # Quality always beats order. A tier-3 model with no `order` wins over + # a tier-2 model with `order=1`. + spec = [ + { + "model_name": "ordered-tier2", + "quality_tier": 2, + "keywords": ["code"], + "order": 1, + "input_cost_per_token": 0.000010, + }, + { + "model_name": "implicit-tier3", + "quality_tier": 3, + "keywords": ["code"], + "input_cost_per_token": 0.000005, + }, + ] + router = MagicMock() + router.model_list = _make_model_list(spec) + qr = QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="ordered-tier2", + quality_router_config={ + "available_models": ["ordered-tier2", "implicit-tier3"] + }, + ) + match = qr._keyword_override("write some code") + assert match == ("implicit-tier3", "code") + + def test_order_breaks_tie_within_same_quality_tier(self): + # Two tier-3 models, both match "code". Lower `order` wins. + spec = [ + { + "model_name": "preferred", + "quality_tier": 3, + "keywords": ["code"], + "order": 1, + "input_cost_per_token": 0.000050, # more expensive + }, + { + "model_name": "default-tier3", + "quality_tier": 3, + "keywords": ["code"], + "input_cost_per_token": 0.000005, # cheaper + }, + ] + router = MagicMock() + router.model_list = _make_model_list(spec) + qr = QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="default-tier3", + quality_router_config={"available_models": ["preferred", "default-tier3"]}, + ) + match = qr._keyword_override("write some code") + assert match == ("preferred", "code") + + def test_explicit_order_overrides_price(self): + # Same tier, but the more expensive one has a lower `order` and wins. + spec = [ + { + "model_name": "expensive-but-preferred", + "quality_tier": 2, + "keywords": ["data"], + "order": 1, + "input_cost_per_token": 0.000050, + }, + { + "model_name": "cheap-default", + "quality_tier": 2, + "keywords": ["data"], + "input_cost_per_token": 0.000005, + }, + ] + router = MagicMock() + router.model_list = _make_model_list(spec) + qr = QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="cheap-default", + quality_router_config={ + "available_models": ["expensive-but-preferred", "cheap-default"] + }, + ) + match = qr._keyword_override("show me the data") + assert match == ("expensive-but-preferred", "data") + + def test_lower_order_wins_between_two_explicitly_ordered(self): + spec = [ + { + "model_name": "second", + "quality_tier": 2, + "keywords": ["data"], + "order": 5, + }, + { + "model_name": "first", + "quality_tier": 2, + "keywords": ["data"], + "order": 1, + }, + ] + router = MagicMock() + router.model_list = _make_model_list(spec) + qr = QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="first", + quality_router_config={"available_models": ["first", "second"]}, + ) + match = qr._keyword_override("show me the data") + assert match == ("first", "data") + + def test_same_order_falls_through_to_quality_then_price(self): + # All three models share order=1 → tiebreak falls through to + # (quality DESC, cost ASC). + spec = [ + { + "model_name": "low-tier", + "quality_tier": 1, + "keywords": ["data"], + "order": 1, + "input_cost_per_token": 0.000001, + }, + { + "model_name": "high-tier-cheap", + "quality_tier": 3, + "keywords": ["data"], + "order": 1, + "input_cost_per_token": 0.000005, + }, + { + "model_name": "high-tier-expensive", + "quality_tier": 3, + "keywords": ["data"], + "order": 1, + "input_cost_per_token": 0.000050, + }, + ] + router = MagicMock() + router.model_list = _make_model_list(spec) + qr = QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="low-tier", + quality_router_config={ + "available_models": [ + "low-tier", + "high-tier-cheap", + "high-tier-expensive", + ] + }, + ) + match = qr._keyword_override("show me the data") + assert match == ("high-tier-cheap", "data") + + def test_order_is_used_in_tier_resolution_too(self): + # Two models at the same tier. Explicit `order=1` on the second one + # should make _resolve_model_for_quality_tier(2) pick it. + spec = [ + { + "model_name": "default-pick", + "quality_tier": 2, + }, + { + "model_name": "preferred-pick", + "quality_tier": 2, + "order": 1, + }, + ] + router = MagicMock() + router.model_list = _make_model_list(spec) + qr = QualityRouter( + model_name="qr", + litellm_router_instance=router, + default_model="default-pick", + quality_router_config={ + "available_models": ["default-pick", "preferred-pick"] + }, + ) + assert qr._resolve_model_for_quality_tier(2) == "preferred-pick" + + @pytest.mark.asyncio + async def test_hook_falls_back_to_complexity_when_no_keyword(self, keyword_router): + # No declared keyword in the message → complexity-based routing. + # "hi" is SIMPLE → quality 1 → default-haiku (the only tier-1 model). + messages = [{"role": "user", "content": "hi"}] + request_kwargs: Dict[str, Any] = {} + resp = await keyword_router.async_pre_routing_hook( + model="qr", + request_kwargs=request_kwargs, + messages=messages, + ) + assert resp is not None + assert resp.model == "default-haiku" + + decision = request_kwargs["metadata"]["quality_router_decision"] + assert decision["routed_via"] == "quality_tier" + assert decision["matched_keyword"] is None + assert decision["complexity_tier"] == "SIMPLE" + + +# ─── Routing-decision metadata (powers x-litellm-quality-router-* headers) ── + + +class TestDecisionMetadata: + @pytest.mark.asyncio + async def test_hook_stashes_decision_in_request_kwargs_metadata( + self, quality_router + ): + # Reasoning prompt → REASONING → quality tier 4 → opus-next. + messages = [ + { + "role": "user", + "content": ( + "Think step by step and reason through this problem. " + "Analyze this carefully and break down each component." + ), + } + ] + request_kwargs: Dict[str, Any] = {} + + resp = await quality_router.async_pre_routing_hook( + model="quality-router-test", + request_kwargs=request_kwargs, + messages=messages, + ) + assert resp is not None and resp.model == "opus-next" + + decision = request_kwargs["metadata"]["quality_router_decision"] + assert decision["routed_model"] == "opus-next" + assert decision["quality_tier"] == 4 + assert decision["complexity_tier"] == "REASONING" + assert decision["router_model_name"] == "quality-router-test" + assert decision["routed_via"] == "quality_tier" + assert decision["matched_keyword"] is None + + @pytest.mark.asyncio + async def test_decision_metadata_preserves_existing_metadata(self, quality_router): + request_kwargs: Dict[str, Any] = { + "metadata": {"trace_id": "abc-123", "user_id": "u-1"} + } + + await quality_router.async_pre_routing_hook( + model="quality-router-test", + request_kwargs=request_kwargs, + messages=[{"role": "user", "content": "hi"}], + ) + + # Existing metadata keys are intact and the decision is added alongside. + assert request_kwargs["metadata"]["trace_id"] == "abc-123" + assert request_kwargs["metadata"]["user_id"] == "u-1" + assert "quality_router_decision" in request_kwargs["metadata"] + + +# ─── Router.set_response_headers lifts decision into x-litellm-quality-* ──── + + +class TestSetResponseHeadersLiftsDecision: + """ + Verify the Router.set_response_headers helper turns a stashed quality-router + decision into x-litellm-quality-router-* headers on the response. + """ + + @pytest.mark.asyncio + async def test_lifts_decision_into_additional_headers(self): + from pydantic import BaseModel + + from litellm.router import Router + + class FakeResponse(BaseModel): + model_config = {"arbitrary_types_allowed": True} + _hidden_params: Dict[str, Any] = {} + + # Build a real Router with a tiny model_list — enough to satisfy + # set_response_headers without needing the rest of the router stack. + router = Router( + model_list=[ + { + "model_name": "haiku", + "litellm_params": { + "model": "openai/gpt-4o-mini", + "api_key": "sk-test", + }, + } + ] + ) + + response = FakeResponse() + response._hidden_params = {} + + request_kwargs = { + "metadata": { + "quality_router_decision": { + "router_model_name": "qr", + "routed_model": "smart-coder", + "routed_via": "keyword", + "matched_keyword": "code", + "quality_tier": 3, + "complexity_tier": None, + } + } + } + + await router.set_response_headers( + response=response, + model_group="qr", + request_kwargs=request_kwargs, + ) + + headers = response._hidden_params["additional_headers"] + assert headers["x-litellm-quality-router-model"] == "smart-coder" + assert headers["x-litellm-quality-router-tier"] == "3" + assert headers["x-litellm-quality-router-via"] == "keyword" + assert headers["x-litellm-quality-router-keyword"] == "code" + # Keyword route short-circuits classification → no complexity header. + assert "x-litellm-quality-router-complexity" not in headers + # Existing x-litellm-model-group behavior is unchanged. + assert headers["x-litellm-model-group"] == "qr" + + @pytest.mark.asyncio + async def test_quality_tier_route_emits_complexity_not_keyword(self): + from pydantic import BaseModel + + from litellm.router import Router + + class FakeResponse(BaseModel): + model_config = {"arbitrary_types_allowed": True} + _hidden_params: Dict[str, Any] = {} + + router = Router( + model_list=[ + { + "model_name": "haiku", + "litellm_params": { + "model": "openai/gpt-4o-mini", + "api_key": "sk-test", + }, + } + ] + ) + + response = FakeResponse() + response._hidden_params = {} + + request_kwargs = { + "metadata": { + "quality_router_decision": { + "router_model_name": "qr", + "routed_model": "haiku", + "routed_via": "quality_tier", + "matched_keyword": None, + "quality_tier": 1, + "complexity_tier": "SIMPLE", + } + } + } + + await router.set_response_headers( + response=response, + model_group="qr", + request_kwargs=request_kwargs, + ) + + headers = response._hidden_params["additional_headers"] + assert headers["x-litellm-quality-router-via"] == "quality_tier" + assert headers["x-litellm-quality-router-complexity"] == "SIMPLE" + # Quality-tier route → no keyword header. + assert "x-litellm-quality-router-keyword" not in headers + + @pytest.mark.asyncio + async def test_no_decision_leaves_quality_router_headers_unset(self): + from pydantic import BaseModel + + from litellm.router import Router + + class FakeResponse(BaseModel): + model_config = {"arbitrary_types_allowed": True} + _hidden_params: Dict[str, Any] = {} + + router = Router( + model_list=[ + { + "model_name": "haiku", + "litellm_params": { + "model": "openai/gpt-4o-mini", + "api_key": "sk-test", + }, + } + ] + ) + + response = FakeResponse() + response._hidden_params = {} + + await router.set_response_headers( + response=response, + model_group="haiku", + request_kwargs={}, # no quality_router_decision + ) + + headers = response._hidden_params["additional_headers"] + assert "x-litellm-quality-router-model" not in headers + assert "x-litellm-quality-router-tier" not in headers + + +class TestRouterQualityDeploymentMethods: + """Tests for Router._is_quality_router_deployment and Router.init_quality_router_deployment.""" + + def test_is_quality_router_deployment_true(self): + """_is_quality_router_deployment returns True for quality router models.""" + from litellm.router import Router + from litellm.types.router import LiteLLM_Params + + router = Router( + model_list=[ + { + "model_name": "gpt-4o-mini", + "litellm_params": {"model": "openai/gpt-4o-mini"}, + } + ] + ) + params = LiteLLM_Params(model="auto_router/quality_router/my-router") + assert router._is_quality_router_deployment(params) is True + + def test_is_quality_router_deployment_false(self): + """_is_quality_router_deployment returns False for regular models.""" + from litellm.router import Router + from litellm.types.router import LiteLLM_Params + + router = Router( + model_list=[ + { + "model_name": "gpt-4o-mini", + "litellm_params": {"model": "openai/gpt-4o-mini"}, + } + ] + ) + params = LiteLLM_Params(model="openai/gpt-4o-mini") + assert router._is_quality_router_deployment(params) is False + + def test_init_quality_router_deployment(self): + """init_quality_router_deployment registers a QualityRouter.""" + from litellm.router import Router + from litellm.types.router import Deployment, LiteLLM_Params + + router = Router( + model_list=[ + { + "model_name": "gpt-4o-mini", + "litellm_params": {"model": "openai/gpt-4o-mini"}, + } + ] + ) + deployment = Deployment( + model_name="auto_router/quality_router/test-router", + litellm_params=LiteLLM_Params( + model="auto_router/quality_router/test-router", + quality_router_default_model="gpt-4o-mini", + ), + model_info={"id": "test-id"}, + ) + router.init_quality_router_deployment(deployment) + assert "auto_router/quality_router/test-router" in router.quality_routers