[Feat] - Ishaan main merge branch (#23596)
* fix(bedrock): respect s3_region_name for batch file uploads (#23569)

* fix(bedrock): respect s3_region_name for batch file uploads (GovCloud fix)

* fix: s3_region_name always wins over aws_region_name for S3 signing (Greptile feedback)

* fix: _filter_headers_for_aws_signature - Bedrock KB (#23571)

* fix: _filter_headers_for_aws_signature

* fix: filter None header values in all post-signing re-merge paths

Addresses Greptile feedback: None-valued headers were being filtered during SigV4 signing but re-merged back into the final headers dict afterward, which would cause downstream HTTP client failures.

Made-with: Cursor

* feat(router): tag_regex routing - route by User-Agent regex without per-developer tag config (#23594)

* feat(router): add tag_regex support for header-based routing

Adds a new `tag_regex` field to litellm_params that lets operators route requests based on regex patterns matched against request headers (primarily User-Agent) without requiring per-developer tag configuration.

Use case: route all Claude Code traffic (User-Agent: claude-code/x.y.z) to a dedicated deployment by setting:

    tag_regex:
      - "^User-Agent: claude-code\\/"

in the deployment's litellm_params.

Works alongside existing `tags` routing; exact tag match takes precedence over regex match. Unmatched requests fall through to deployments tagged `default`.

The matched deployment, pattern, and user_agent are recorded in `metadata["tag_routing"]` so they flow through to SpendLogs automatically.

* fix(tag_regex): address backwards-compat, metadata overwrite, and warning noise

Three issues from code review:

1. Backwards-compat: `has_tag_filter` was widened to activate on any non-empty User-Agent, which would raise ValueError for existing deployments using plain tags without a `default` fallback. Fix: only activate header-based regex filtering when at least one candidate deployment has `tag_regex` configured.

2. Metadata overwrite: `metadata["tag_routing"]` was overwritten for every matching deployment in the loop, leaving inaccurate provenance when multiple deployments match. Fix: write only for the first match.

3. Warning noise: an invalid regex pattern logged one warning per header string rather than once per pattern. Fix: compile first (catching re.error once), then iterate over header strings.

Also adds two new tests covering these cases, and adds docs page for tag_regex routing with a Claude Code walk-through.

* refactor(tag_regex): remove unnecessary _healthy_list copy

* docs: merge tag_regex section into tag_routing.md, remove standalone page

- Add ## Regex-based tag routing (tag_regex) section to existing tag_routing.md instead of a separate page
- Remove tag_regex_routing.md standalone doc (odd UX to have a separate page for a sub-feature)
- Remove proxy/tag_regex_routing from sidebars.js
- Add match_any=False debug warning in tag_based_routing.py when regex routing fires under strict mode (regex always uses OR semantics)

* fix(tag_regex): address greptile review - security docs, strict-mode enforcement, validation order

- Strengthen security note in tag_routing.md: explicitly state User-Agent is client-supplied and can be set to any value; frame tag_regex as a traffic classification hint, not an access-control mechanism
- Move tag_regex startup validation before _add_deployment() so an invalid pattern never leaves partial router state
- Enforce match_any=False strict-tag policy: when a deployment has both tags and tag_regex and the strict tag check fails, skip the regex fallback rather than silently bypassing the operator's intent
- Extract per-deployment match logic into _match_deployment() helper to keep get_deployments_for_tag() readable
- Add two new tests: strict-mode blocks regex fallback, regex-only deployment still matches under match_any=False

* fix(ci): apply Black formatting to 14 files and stabilize flaky caplog tests

- Run Black formatter on 14 files that were failing the lint check
- Replace caplog-based assertions in TestAliasConflicts with unittest.mock.patch on verbose_logger.warning for xdist compatibility
- The caplog fixture can produce empty text in pytest-xdist workers in certain CI environments, causing flaky test failures

Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>

---------

Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
parent 25ee2fb3f9
commit b87d1f8dad
@ -209,6 +209,106 @@ Expect to see the following response header when this works
|
||||
x-litellm-model-id: default-model
|
||||
```
|
||||
|
||||
## Regex-based tag routing (`tag_regex`)
|
||||
|
||||
Use `tag_regex` to route requests based on regex patterns matched against request headers, without requiring clients to pass a tag explicitly. This is useful when clients already send a recognisable header, such as `User-Agent`.
|
||||
|
||||
**Use case: route all Claude Code traffic to dedicated AWS accounts**
|
||||
|
||||
Claude Code always sends `User-Agent: claude-code/<version>`. With `tag_regex` you can route that traffic to a dedicated deployment automatically — no per-developer configuration needed.
|
||||
|
||||
### 1. Config
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
# Claude Code traffic → dedicated deployment, matched by User-Agent
|
||||
- model_name: claude-sonnet
|
||||
litellm_params:
|
||||
model: bedrock/converse/anthropic-claude-sonnet-4-6
|
||||
aws_region_name: us-east-1
|
||||
aws_role_name: arn:aws:iam::111122223333:role/LiteLLMClaudeCode
|
||||
tag_regex:
|
||||
- "^User-Agent: claude-code\\/" # matches claude-code/1.x, 2.x, etc.
|
||||
model_info:
|
||||
id: claude-code-deployment
|
||||
|
||||
# All other traffic falls back to the default deployment
|
||||
- model_name: claude-sonnet
|
||||
litellm_params:
|
||||
model: bedrock/converse/anthropic-claude-sonnet-4-6
|
||||
aws_region_name: us-east-1
|
||||
aws_role_name: arn:aws:iam::444455556666:role/LiteLLMDefault
|
||||
tags:
|
||||
- default
|
||||
model_info:
|
||||
id: regular-deployment
|
||||
|
||||
router_settings:
|
||||
enable_tag_filtering: true
|
||||
tag_filtering_match_any: true
|
||||
|
||||
general_settings:
|
||||
master_key: sk-1234
|
||||
```
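
As a rough illustration of how the pattern in the config above is evaluated: the router code in this commit builds a `"User-Agent: <value>"` string from the request and runs a regex search with each configured pattern. The sketch below reproduces that check in isolation. `first_matching_pattern` is a hypothetical helper name used only for this example, not part of LiteLLM's API.

```python
import re
from typing import List, Optional


def first_matching_pattern(patterns: List[str], user_agent: str) -> Optional[str]:
    """Return the first pattern matching "User-Agent: <value>", or None."""
    if not user_agent:
        # No User-Agent means there is no header string to match against.
        return None
    header = f"User-Agent: {user_agent}"
    for pattern in patterns:
        try:
            compiled = re.compile(pattern)
        except re.error:
            # Invalid pattern: skip it (the router logs a warning at this point).
            continue
        if compiled.search(header):
            return pattern
    return None


# Same regex as in the YAML above, after YAML unescaping of "\\/" to "\/".
patterns = [r"^User-Agent: claude-code\/"]
print(first_matching_pattern(patterns, "claude-code/1.2.3"))
print(first_matching_pattern(patterns, "curl/8.4.0"))
```

Because the pattern is anchored with `^User-Agent: `, it matches only when the header value starts with `claude-code/`; a value that merely contains that substring later on does not match.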
|
||||
|
||||
### 2. Verify routing
|
||||
|
||||
Claude Code sets `User-Agent: claude-code/<version>` automatically — no client config needed:
|
||||
|
||||
```shell
|
||||
# Claude Code request (User-Agent set automatically by Claude Code)
|
||||
curl -i http://localhost:4000/v1/chat/completions \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-H "User-Agent: claude-code/1.2.3" \
|
||||
-d '{"model": "claude-sonnet", "messages": [{"role": "user", "content": "hi"}]}'
|
||||
# → x-litellm-model-id: claude-code-deployment
|
||||
|
||||
# Any other client (no matching User-Agent) → default deployment
|
||||
curl -i http://localhost:4000/v1/chat/completions \
|
||||
-H "Authorization: Bearer sk-1234" \
|
||||
-d '{"model": "claude-sonnet", "messages": [{"role": "user", "content": "hi"}]}'
|
||||
# → x-litellm-model-id: regular-deployment
|
||||
```
|
||||
|
||||
### How matching works
|
||||
|
||||
| Priority | Condition | Result |
|
||||
|----------|-----------|--------|
|
||||
| 1 | Request has `tags` AND deployment has `tags` | Exact tag match (respects `match_any` setting) |
|
||||
| 2 | Deployment has `tag_regex` AND request has a `User-Agent` | Regex match (always OR logic — any pattern match suffices) |
|
||||
| 3 | Deployment has `tags: [default]` | Default fallback |
|
||||
| 4 | No default set | All healthy deployments returned |
|
||||
|
||||
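`tag_regex` always uses OR semantics: a match on any single pattern is enough. `tag_filtering_match_any=False` applies only to exact tag matching, not to regex patterns, with one exception: when a deployment has both `tags` and `tag_regex`, the request carries tags, and the strict tag check fails, the regex fallback is skipped so that it cannot bypass the strict-tag policy. A deployment with only `tag_regex` still matches under `tag_filtering_match_any=False`.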
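
The first two priorities can be sketched per deployment as follows. This is a simplified illustration modelled on the `_match_deployment()` helper added in this commit, not the actual implementation. The function name `match_deployment` and its flat argument list are hypothetical, and the exact-tag semantics (any overlap when `match_any=True`, all request tags present when `match_any=False`) are an assumption. Priorities 3 and 4 are applied afterwards, across all deployments, and are not shown.

```python
import re
from typing import Dict, List, Optional


def match_deployment(
    deployment_tags: Optional[List[str]],
    deployment_tag_regex: Optional[List[str]],
    request_tags: Optional[List[str]],
    user_agent: str,
    match_any: bool = True,
) -> Optional[Dict[str, str]]:
    """Return how the deployment matched, or None if it should be excluded."""
    # Priority 1: exact tag match.
    if deployment_tags and request_tags:
        wanted, have = set(request_tags), set(deployment_tags)
        # Assumed semantics: any overlap vs. every request tag present.
        tags_ok = bool(wanted & have) if match_any else wanted <= have
        if tags_ok:
            matched = next(
                (t for t in deployment_tags if t in wanted), deployment_tags[0]
            )
            return {"matched_via": "tags", "matched_value": matched}

    # Priority 2: regex match, skipped when a strict tag check ran and failed.
    strict_tag_check_failed = (
        not match_any and bool(deployment_tags) and bool(request_tags)
    )
    if deployment_tag_regex and user_agent and not strict_tag_check_failed:
        header = f"User-Agent: {user_agent}"
        # Patterns are assumed valid here (they are validated at router startup).
        for pattern in deployment_tag_regex:
            if re.search(pattern, header):
                return {"matched_via": "tag_regex", "matched_value": pattern}

    return None


PATTERN = r"^User-Agent: claude-code\/"
print(match_deployment(None, [PATTERN], None, "claude-code/1.2.3"))
print(match_deployment(["team-a"], [PATTERN], ["team-b"], "claude-code/1.2.3", match_any=False))
```

The second call returns `None`: the deployment has plain tags, the request tags fail the strict check, and the regex is not allowed to override that result.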
|
||||
|
||||
### Observability
|
||||
|
||||
When a deployment matches, `tag_routing` is written into request metadata and flows to SpendLogs. It is recorded only for the first matching deployment. For a regex match it looks like this:
|
||||
|
||||
```json
|
||||
{
"tag_routing": {
"matched_deployment": "claude-sonnet",
|
||||
"matched_via": "tag_regex",
|
||||
"matched_value": "^User-Agent: claude-code\\/",
|
||||
"user_agent": "claude-code/1.2.3",
|
||||
"request_tags": []
|
||||
}
|
||||
}
|
||||
```
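
The "first match only" behaviour can be sketched as below. This is a minimal illustration of the idea under stated assumptions, not the router's code: `record_tag_routing` is a hypothetical helper name, and the field names follow the metadata dict written in this commit's `get_deployments_for_tag()` change.

```python
from typing import Any, Dict, List, Optional


def record_tag_routing(
    metadata: Dict[str, Any],
    model_name: str,
    match: Dict[str, str],
    request_tags: Optional[List[str]],
    user_agent: str,
) -> None:
    """Write tag_routing provenance once; later matches do not overwrite it."""
    if "tag_routing" in metadata:
        return
    metadata["tag_routing"] = {
        "matched_deployment": model_name,
        "matched_via": match["matched_via"],
        "matched_value": match["matched_value"],
        "request_tags": request_tags or [],
        "user_agent": user_agent,
    }


metadata: Dict[str, Any] = {}
regex_match = {"matched_via": "tag_regex", "matched_value": r"^User-Agent: claude-code\/"}
tag_match = {"matched_via": "tags", "matched_value": "default"}

record_tag_routing(metadata, "claude-sonnet", regex_match, None, "claude-code/1.2.3")
# A second matching deployment leaves the first record untouched.
record_tag_routing(metadata, "other-model", tag_match, ["default"], "claude-code/1.2.3")
print(metadata["tag_routing"]["matched_via"])
```

Writing once keeps the record stable when several deployments match, since the load balancer later picks a single deployment from the matched set.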
|
||||
|
||||
### Security note
|
||||
|
||||
:::caution
|
||||
|
||||
**`User-Agent` is a client-supplied header and can be set to any value.** Any API consumer can send `User-Agent: claude-code/1.0` regardless of whether they are actually using Claude Code.
|
||||
|
||||
Do not rely on `tag_regex` routing to enforce access controls or spend limits — use [team/key-based routing](./users) for that. `tag_regex` is a **traffic classification hint** (useful for billing visibility, capacity planning, and routing convenience), not a security boundary.
|
||||
|
||||
:::
|
||||
|
||||
|
||||
---
|
||||
|
||||
## ✨ Team based tag routing (Enterprise)
|
||||
|
||||
LiteLLM Proxy supports team-based tag routing, allowing you to associate specific tags with teams and route requests accordingly. Example **Team A can access gpt-4 deployment A, Team B can access gpt-4 deployment B** (LLM Access Control For Teams)
|
||||
|
||||
@ -398,6 +398,7 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
|
||||
ResponseOutputMessage,
|
||||
ResponseReasoningItem,
|
||||
)
|
||||
|
||||
try:
|
||||
from openai.types.responses.response_output_item import (
|
||||
ResponseApplyPatchToolCall,
|
||||
@ -460,7 +461,9 @@ class LiteLLMResponsesTransformationHandler(CompletionTransformationBridge):
|
||||
accumulated_tool_calls.append(tool_call_dict)
|
||||
tool_call_index += 1
|
||||
|
||||
elif ResponseApplyPatchToolCall is not None and isinstance(item, ResponseApplyPatchToolCall):
|
||||
elif ResponseApplyPatchToolCall is not None and isinstance(
|
||||
item, ResponseApplyPatchToolCall
|
||||
):
|
||||
from litellm.responses.litellm_completion_transformation.transformation import (
|
||||
LiteLLMCompletionResponsesConfig,
|
||||
)
|
||||
|
||||
@ -2680,7 +2680,9 @@ def anthropic_messages_pt( # noqa: PLR0915
|
||||
_content_is_list = "content" in assistant_content_block and isinstance(
|
||||
assistant_content_block["content"], list
|
||||
)
|
||||
_content_list = assistant_content_block.get("content") if _content_is_list else None
|
||||
_content_list = (
|
||||
assistant_content_block.get("content") if _content_is_list else None
|
||||
)
|
||||
_list_has_thinking = False
|
||||
if _content_is_list and _content_list is not None:
|
||||
for _item in _content_list:
|
||||
|
||||
@ -79,7 +79,9 @@ class AnthropicFilesConfig(BaseFilesConfig):
|
||||
return AnthropicError(
|
||||
status_code=status_code,
|
||||
message=error_message,
|
||||
headers=cast(httpx.Headers, headers) if isinstance(headers, dict) else headers,
|
||||
headers=cast(httpx.Headers, headers)
|
||||
if isinstance(headers, dict)
|
||||
else headers,
|
||||
)
|
||||
|
||||
def validate_environment(
|
||||
|
||||
@ -144,9 +144,7 @@ class BaseModelResponseIterator:
|
||||
# Skip empty lines (common in SSE streams between events).
|
||||
# Only apply to str chunks — non-string objects (e.g. Pydantic
|
||||
# BaseModel events from the Responses API) must pass through.
|
||||
if isinstance(str_line, str) and (
|
||||
not str_line or not str_line.strip()
|
||||
):
|
||||
if isinstance(str_line, str) and (not str_line or not str_line.strip()):
|
||||
continue
|
||||
|
||||
# chunk is a str at this point
|
||||
@ -184,9 +182,7 @@ class BaseModelResponseIterator:
|
||||
# Skip empty lines (common in SSE streams between events).
|
||||
# Only apply to str chunks — non-string objects (e.g. Pydantic
|
||||
# BaseModel events from the Responses API) must pass through.
|
||||
if isinstance(str_line, str) and (
|
||||
not str_line or not str_line.strip()
|
||||
):
|
||||
if isinstance(str_line, str) and (not str_line or not str_line.strip()):
|
||||
continue
|
||||
|
||||
# chunk is a str at this point
|
||||
|
||||
@@ -1268,7 +1268,8 @@ class BaseAWSLLM:

         # Add back all original headers (including forwarded ones) after signature calculation
         for header_name, header_value in headers.items():
-            request.headers[header_name] = header_value
+            if header_value is not None:
+                request.headers[header_name] = header_value

         if (
             extra_headers is not None and "Authorization" in extra_headers
@@ -1298,6 +1299,8 @@ class BaseAWSLLM:
         }

         for header_name, header_value in headers.items():
+            if header_value is None:
+                continue
             header_lower = header_name.lower()
             if (
                 header_lower in aws_headers
@@ -1393,7 +1396,8 @@ class BaseAWSLLM:
         # Add back original headers after signing. Only headers in SignedHeaders
         # are integrity-protected; forwarded headers (x-forwarded-*) must remain unsigned.
         for header_name, header_value in headers.items():
-            request_headers_dict[header_name] = header_value
+            if header_value is not None:
+                request_headers_dict[header_name] = header_value
         if (
             headers is not None and "Authorization" in headers
         ):  # prevent sigv4 from overwriting the auth header

@@ -173,7 +173,12 @@ class BedrockFilesConfig(BaseAWSLLM, BaseFilesConfig):
                 "S3 bucket_name is required. Set 's3_bucket_name' in litellm_params or AWS_S3_BUCKET_NAME env var"
             )

-        aws_region_name = self._get_aws_region_name(optional_params, model)
+        s3_region_name = litellm_params.get("s3_region_name") or optional_params.get(
+            "s3_region_name"
+        )
+        aws_region_name = s3_region_name or self._get_aws_region_name(
+            optional_params, model
+        )

         file_data = data.get("file")
         purpose = data.get("purpose")
@@ -398,6 +403,15 @@ class BedrockFilesConfig(BaseAWSLLM, BaseFilesConfig):
             data=create_file_data,
         )

+        # s3_region_name always wins for S3 operations (same priority as in
+        # get_complete_file_url above). Overwrite aws_region_name unconditionally
+        # so the SigV4 region matches the URL region, avoiding SignatureDoesNotMatch.
+        s3_region_name = litellm_params.get("s3_region_name") or optional_params.get(
+            "s3_region_name"
+        )
+        if s3_region_name:
+            optional_params = {**optional_params, "aws_region_name": s3_region_name}
+
         # Sign the request and return a pre-signed request object
         signed_headers, signed_body = self._sign_s3_request(
             content=file_content,

@ -201,7 +201,9 @@ class BlackForestLabsImageEditConfig(BaseImageEditConfig):
|
||||
return image
|
||||
elif isinstance(image, list):
|
||||
# If it's a list, take the first image
|
||||
return self._read_image_bytes(image[0], depth=depth + 1, max_depth=max_depth)
|
||||
return self._read_image_bytes(
|
||||
image[0], depth=depth + 1, max_depth=max_depth
|
||||
)
|
||||
elif isinstance(image, str):
|
||||
if image.startswith(("http://", "https://")):
|
||||
# Download image from URL
|
||||
|
||||
@ -71,7 +71,9 @@ class PerplexityResponsesConfig(OpenAIResponsesAPIConfig):
|
||||
result: List[Any] = []
|
||||
for item in input:
|
||||
if isinstance(item, dict) and "type" not in item:
|
||||
new_item = dict(item) # convert to plain dict to avoid TypedDict checking
|
||||
new_item = dict(
|
||||
item
|
||||
) # convert to plain dict to avoid TypedDict checking
|
||||
new_item["type"] = "message"
|
||||
result.append(new_item)
|
||||
else:
|
||||
|
||||
@ -378,9 +378,7 @@ if MCP_AVAILABLE:
|
||||
# Resolve a server name to its UUID if needed
|
||||
_name_resolved = None
|
||||
if server_id not in allowed_server_ids:
|
||||
_name_resolved = global_mcp_server_manager.get_mcp_server_by_name(
|
||||
server_id
|
||||
)
|
||||
_name_resolved = global_mcp_server_manager.get_mcp_server_by_name(server_id)
|
||||
if _name_resolved is not None and _name_resolved.server_id in set(
|
||||
allowed_server_ids
|
||||
):
|
||||
@ -442,9 +440,7 @@ if MCP_AVAILABLE:
|
||||
extra_headers=user_oauth_extra_headers,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"Error getting tools from {server.name}: {e}"
|
||||
)
|
||||
verbose_logger.exception(f"Error getting tools from {server.name}: {e}")
|
||||
return {
|
||||
"tools": [],
|
||||
"error": "server_error",
|
||||
@ -473,7 +469,9 @@ if MCP_AVAILABLE:
|
||||
_name_resolved = None
|
||||
if server_id not in allowed_server_ids:
|
||||
_name_resolved = global_mcp_server_manager.get_mcp_server_by_name(server_id)
|
||||
if _name_resolved is not None and _name_resolved.server_id in set(allowed_server_ids):
|
||||
if _name_resolved is not None and _name_resolved.server_id in set(
|
||||
allowed_server_ids
|
||||
):
|
||||
server_id = _name_resolved.server_id
|
||||
|
||||
if server_id not in allowed_server_ids:
|
||||
@ -518,7 +516,9 @@ if MCP_AVAILABLE:
|
||||
server_auth_header = _get_server_auth_header(
|
||||
server, mcp_server_auth_headers, mcp_auth_header
|
||||
)
|
||||
user_oauth_extra_headers = await _get_user_oauth_extra_headers(server, user_api_key_dict)
|
||||
user_oauth_extra_headers = await _get_user_oauth_extra_headers(
|
||||
server, user_api_key_dict
|
||||
)
|
||||
|
||||
try:
|
||||
list_tools_result = await _get_tools_for_single_server(
|
||||
@ -529,9 +529,7 @@ if MCP_AVAILABLE:
|
||||
extra_headers=user_oauth_extra_headers,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
f"Error getting tools from {server.name}: {e}"
|
||||
)
|
||||
verbose_logger.exception(f"Error getting tools from {server.name}: {e}")
|
||||
return {
|
||||
"tools": [],
|
||||
"error": "server_error",
|
||||
@ -905,7 +903,9 @@ if MCP_AVAILABLE:
|
||||
try:
|
||||
client_id, client_secret, scopes = _extract_credentials(request)
|
||||
|
||||
_oauth2_flow: Optional[Literal["client_credentials", "authorization_code"]] = (
|
||||
_oauth2_flow: Optional[
|
||||
Literal["client_credentials", "authorization_code"]
|
||||
] = (
|
||||
"client_credentials"
|
||||
if client_id and client_secret and request.token_url
|
||||
else None
|
||||
|
||||
@ -64,7 +64,11 @@ from litellm.proxy.common_utils.http_parsing_utils import (
|
||||
populate_request_with_path_params,
|
||||
)
|
||||
from litellm.proxy.common_utils.realtime_utils import _realtime_request_body
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, normalize_route_for_root_path
|
||||
from litellm.proxy.utils import (
|
||||
PrismaClient,
|
||||
ProxyLogging,
|
||||
normalize_route_for_root_path,
|
||||
)
|
||||
from litellm.secret_managers.main import get_secret_bool
|
||||
from litellm.types.services import ServiceTypes
|
||||
|
||||
|
||||
@ -106,9 +106,13 @@ class _OPTIONAL_PresidioPIIMasking(CustomGuardrail):
|
||||
if (self.output_parse_pii or self.apply_to_output) and not logging_only:
|
||||
current_hook = self.event_hook
|
||||
if isinstance(current_hook, str) and current_hook != "post_call":
|
||||
self.event_hook = cast(List[GuardrailEventHooks], [current_hook, "post_call"])
|
||||
self.event_hook = cast(
|
||||
List[GuardrailEventHooks], [current_hook, "post_call"]
|
||||
)
|
||||
elif isinstance(current_hook, list) and "post_call" not in current_hook:
|
||||
self.event_hook = cast(List[GuardrailEventHooks], current_hook + ["post_call"])
|
||||
self.event_hook = cast(
|
||||
List[GuardrailEventHooks], current_hook + ["post_call"]
|
||||
)
|
||||
self.pii_entities_config: Dict[Union[PiiEntityType, str], PiiAction] = (
|
||||
pii_entities_config or {}
|
||||
)
|
||||
|
||||
@ -1838,9 +1838,7 @@ async def _validate_update_key_data(
|
||||
|
||||
# Check team limits if key has a team_id (from request or existing key)
|
||||
team_obj: Optional[LiteLLM_TeamTableCachedObj] = None
|
||||
_team_id_to_check = data.team_id or getattr(
|
||||
existing_key_row, "team_id", None
|
||||
)
|
||||
_team_id_to_check = data.team_id or getattr(existing_key_row, "team_id", None)
|
||||
if _team_id_to_check is not None:
|
||||
team_obj = await get_team_object(
|
||||
team_id=_team_id_to_check,
|
||||
@ -1910,9 +1908,7 @@ async def _validate_update_key_data(
|
||||
if team_obj is None:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": "Team object not found for team change validation"
|
||||
},
|
||||
detail={"error": "Team object not found for team change validation"},
|
||||
)
|
||||
await validate_key_team_change(
|
||||
key=existing_key_row,
|
||||
|
||||
@ -846,12 +846,16 @@ async def get_generic_sso_response(
|
||||
verbose_proxy_logger.debug("calling generic_sso.verify_and_process")
|
||||
additional_generic_sso_headers_dict = _parse_generic_sso_headers()
|
||||
|
||||
code_verifier: Optional[str] = None # assigned inside try; initialized for type tracking
|
||||
code_verifier: Optional[
|
||||
str
|
||||
] = None # assigned inside try; initialized for type tracking
|
||||
|
||||
try:
|
||||
token_exchange_params = await SSOAuthenticationHandler.prepare_token_exchange_parameters(
|
||||
request=request,
|
||||
generic_include_client_id=generic_include_client_id,
|
||||
token_exchange_params = (
|
||||
await SSOAuthenticationHandler.prepare_token_exchange_parameters(
|
||||
request=request,
|
||||
generic_include_client_id=generic_include_client_id,
|
||||
)
|
||||
)
|
||||
|
||||
# Extract code_verifier (and the cache key for deferred deletion) before calling fastapi-sso
|
||||
@ -915,7 +919,9 @@ async def get_generic_sso_response(
|
||||
# Assign directly rather than relying on nonlocal mutation so that Pyright
|
||||
# can track that received_response is non-None from this point on.
|
||||
received_response = {
|
||||
k: v for k, v in combined_response.items() if k not in _OAUTH_TOKEN_FIELDS
|
||||
k: v
|
||||
for k, v in combined_response.items()
|
||||
if k not in _OAUTH_TOKEN_FIELDS
|
||||
}
|
||||
# In the PKCE path verify_and_process is skipped, so generic_sso.access_token
|
||||
# is never set. Read the token directly from the exchange response instead so
|
||||
@ -2598,7 +2604,9 @@ class SSOAuthenticationHandler:
|
||||
state,
|
||||
)
|
||||
else:
|
||||
verbose_proxy_logger.debug("PKCE code_verifier retrieved from cache")
|
||||
verbose_proxy_logger.debug(
|
||||
"PKCE code_verifier retrieved from cache"
|
||||
)
|
||||
elif isinstance(cached_data, str):
|
||||
# Handle legacy format (plain string) for backward compatibility
|
||||
code_verifier = cached_data
|
||||
@ -2647,7 +2655,9 @@ class SSOAuthenticationHandler:
|
||||
In strict mode (PKCE_STRICT_CACHE_MISS=true) raises ProxyException.
|
||||
Otherwise logs a warning and returns (token exchange proceeds without verifier).
|
||||
"""
|
||||
active_cache = redis_usage_cache if redis_usage_cache is not None else user_api_key_cache
|
||||
active_cache = (
|
||||
redis_usage_cache if redis_usage_cache is not None else user_api_key_cache
|
||||
)
|
||||
strict_cache_miss = (
|
||||
os.getenv("PKCE_STRICT_CACHE_MISS", "false").lower() == "true"
|
||||
)
|
||||
|
||||
@ -5233,7 +5233,7 @@ def normalize_route_for_root_path(route: str) -> Optional[str]:
|
||||
root_path = get_server_root_path()
|
||||
if root_path and root_path != "/":
|
||||
if route.startswith(root_path + "/"):
|
||||
return route[len(root_path):]
|
||||
return route[len(root_path) :]
|
||||
return None
|
||||
return route
|
||||
|
||||
|
||||
@ -415,7 +415,9 @@ class LiteLLMCompletionResponsesConfig:
|
||||
if isinstance(new_msg, dict)
|
||||
else getattr(new_msg, "tool_calls", None)
|
||||
)
|
||||
new_tcs: list = _raw_tcs if isinstance(_raw_tcs, list) else []
|
||||
new_tcs: list = (
|
||||
_raw_tcs if isinstance(_raw_tcs, list) else []
|
||||
)
|
||||
for tc in new_tcs:
|
||||
LiteLLMCompletionResponsesConfig._add_tool_call_to_assistant(
|
||||
last_msg, tc
|
||||
|
||||
@ -14,6 +14,7 @@ import hashlib
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
@@ -1551,7 +1552,9 @@ class Router:
             # Drain any fire-and-forget tasks (e.g. alerting hooks)
             # scheduled via asyncio.create_task during acompletion.
             pending = asyncio.all_tasks()
-            pending.discard(asyncio.current_task())
+            current = asyncio.current_task()
+            if current is not None:
+                pending.discard(current)
             if pending:
                 await asyncio.gather(*pending, return_exceptions=True)

@@ -6542,6 +6545,18 @@ class Router:
             )
             return None

+        # Validate tag_regex patterns BEFORE adding the deployment so we never
+        # have partially-initialised router state if a pattern is invalid.
+        _tag_regex = deployment.litellm_params.get("tag_regex") or []
+        for pattern in _tag_regex:
+            try:
+                re.compile(pattern)
+            except re.error as exc:
+                raise ValueError(
+                    f"Invalid regex in tag_regex for model '{deployment.model_name}': "
+                    f"{pattern!r} — {exc}"
+                ) from exc
+
         deployment = self._add_deployment(deployment=deployment)

         model = deployment.to_json(exclude_none=True)

@ -6,6 +6,7 @@ Use this to route requests between Teams
|
||||
- If no default_deployments are set, return all deployments
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
@ -19,6 +20,29 @@ else:
|
||||
LitellmRouter = Any
|
||||
|
||||
|
||||
def _is_valid_deployment_tag_regex(
|
||||
tag_regexes: List[str],
|
||||
header_strings: List[str],
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Test compiled regex patterns against "Header-Name: value" strings.
|
||||
|
||||
Returns the first matching pattern string, or None if nothing matches.
|
||||
Compiles each pattern once (re's LRU cache) and logs invalid patterns once
|
||||
per pattern, not once per header string.
|
||||
"""
|
||||
for pattern in tag_regexes:
|
||||
try:
|
||||
compiled = re.compile(pattern)
|
||||
except re.error:
|
||||
verbose_logger.warning("tag_regex: invalid pattern %r — skipping", pattern)
|
||||
continue
|
||||
for header_str in header_strings:
|
||||
if compiled.search(header_str):
|
||||
return pattern
|
||||
return None
|
||||
|
||||
|
||||
def is_valid_deployment_tag(
|
||||
deployment_tags: List[str], request_tags: List[str], match_any: bool = True
|
||||
) -> bool:
|
||||
@ -47,6 +71,54 @@ def is_valid_deployment_tag(
|
||||
return False
|
||||
|
||||
|
||||
def _match_deployment(
|
||||
deployment: Any,
|
||||
request_tags: Optional[List[str]],
|
||||
header_strings: List[str],
|
||||
match_any: bool,
|
||||
) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
Determine whether *deployment* matches the current request.
|
||||
|
||||
Returns {"matched_via": ..., "matched_value": ...} if the deployment
|
||||
should be included, or None if it should be excluded.
|
||||
|
||||
Priority:
|
||||
1. Exact tag match (respects match_any semantics).
|
||||
2. Regex match — skipped when match_any=False and the tag check already
|
||||
ran and failed, so the regex cannot override strict-tag policy.
|
||||
"""
|
||||
litellm_params = deployment.get("litellm_params", {})
|
||||
deployment_tags: Optional[List[str]] = litellm_params.get("tags")
|
||||
deployment_tag_regex: Optional[List[str]] = litellm_params.get("tag_regex")
|
||||
|
||||
# 1. Exact tag match (existing behaviour).
|
||||
if deployment_tags and request_tags:
|
||||
if is_valid_deployment_tag(deployment_tags, request_tags, match_any):
|
||||
matched_value = next(
|
||||
(t for t in deployment_tags if t in set(request_tags)),
|
||||
deployment_tags[0],
|
||||
)
|
||||
return {"matched_via": "tags", "matched_value": matched_value}
|
||||
|
||||
# 2. Regex match against request headers.
|
||||
# When match_any=False and the deployment has both plain tags and tag_regex,
|
||||
# the strict tag check has already failed (step 1 returned None). Allow
|
||||
# the regex to fire only when the deployment has NO plain tags, so we never
|
||||
# use regex as a backdoor around the operator's strict-tag policy.
|
||||
strict_tag_check_failed = (
|
||||
not match_any and bool(deployment_tags) and bool(request_tags)
|
||||
)
|
||||
if deployment_tag_regex and header_strings and not strict_tag_check_failed:
|
||||
regex_match = _is_valid_deployment_tag_regex(
|
||||
deployment_tag_regex, header_strings
|
||||
)
|
||||
if regex_match is not None:
|
||||
return {"matched_via": "tag_regex", "matched_value": regex_match}
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def get_deployments_for_tag(
|
||||
llm_router_instance: LitellmRouter,
|
||||
model: str, # used to raise the correct error
|
||||
@ -83,30 +155,63 @@ async def get_deployments_for_tag(
|
||||
request_tags = metadata.get("tags")
|
||||
match_any = llm_router_instance.tag_filtering_match_any
|
||||
|
||||
new_healthy_deployments = []
|
||||
default_deployments = []
|
||||
if request_tags:
|
||||
verbose_logger.debug(
|
||||
"get_deployments_for_tag routing: router_keys: %s", request_tags
|
||||
)
|
||||
# example this can be router_keys=["free", "custom"]
|
||||
for deployment in healthy_deployments:
|
||||
deployment_litellm_params = deployment.get("litellm_params")
|
||||
deployment_tags = deployment_litellm_params.get("tags")
|
||||
# Build header strings for regex matching from what the proxy already stores.
|
||||
# Currently we match against User-Agent; format matches "^User-Agent: claude-code/..."
|
||||
user_agent = metadata.get("user_agent", "")
|
||||
header_strings: List[str] = [f"User-Agent: {user_agent}"] if user_agent else []
|
||||
|
||||
verbose_logger.debug(
|
||||
"deployment: %s, deployment_router_keys: %s",
|
||||
deployment,
|
||||
deployment_tags,
|
||||
new_healthy_deployments: List[Any] = []
|
||||
default_deployments: List[Any] = []
|
||||
|
||||
# Only activate header-based regex filtering when at least one deployment in
|
||||
# the candidate set has tag_regex configured. This preserves existing
|
||||
    # behaviour for operators who use plain tags: a request that carries a
    # User-Agent (all proxy requests do) but targets deployments with no
    # tag_regex will continue to use the original tag-only code path.
    has_regex_deployments = any(
        d.get("litellm_params", {}).get("tag_regex") for d in healthy_deployments
    )
    has_tag_filter = bool(request_tags) or (
        bool(header_strings) and has_regex_deployments
    )
    if has_tag_filter:
        verbose_logger.debug(
            "get_deployments_for_tag routing: request_tags=%s user_agent=%s",
            request_tags,
            user_agent,
        )
        for deployment in healthy_deployments:
            deployment_tags = deployment.get("litellm_params", {}).get("tags")

            match_result = _match_deployment(
                deployment=deployment,
                request_tags=request_tags,
                header_strings=header_strings,
                match_any=match_any,
            )

-           if deployment_tags is None:
-               continue

-           if is_valid_deployment_tag(deployment_tags, request_tags, match_any):
+           if match_result is not None:
                verbose_logger.debug(
                    "tag routing match: deployment=%s matched_via=%s matched_value=%s",
                    deployment.get("model_name"),
                    match_result["matched_via"],
                    match_result["matched_value"],
                )
                # Record provenance in metadata so it flows to SpendLogs.
                # Written only for the first match — load balancer selects one
                # deployment from new_healthy_deployments, so overwriting on
                # subsequent matches would produce misleading observability data.
                if "tag_routing" not in metadata:
                    metadata["tag_routing"] = {
                        "matched_deployment": deployment.get("model_name"),
                        "matched_via": match_result["matched_via"],
                        "matched_value": match_result["matched_value"],
                        "request_tags": request_tags or [],
                        "user_agent": user_agent,
                    }
                new_healthy_deployments.append(deployment)

-           if "default" in deployment_tags:
+           if deployment_tags and "default" in deployment_tags:
                default_deployments.append(deployment)

        if len(new_healthy_deployments) == 0 and len(default_deployments) == 0:

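The activation rule in this hunk can be tried in isolation. The following is a minimal sketch, not the router's code: `should_filter_by_tag` is a hypothetical helper name, and the deployment dicts are simplified stand-ins that only mirror the `litellm_params` / `tag_regex` keys shown above.

```python
from typing import Any, Dict, List, Optional


def should_filter_by_tag(
    request_tags: Optional[List[str]],
    header_strings: List[str],
    healthy_deployments: List[Dict[str, Any]],
) -> bool:
    """Hypothetical helper mirroring the has_tag_filter expression above."""
    # Header-based filtering only counts when some deployment opts in via tag_regex.
    has_regex_deployments = any(
        d.get("litellm_params", {}).get("tag_regex") for d in healthy_deployments
    )
    return bool(request_tags) or (bool(header_strings) and has_regex_deployments)


# Simplified example deployments (assumed shapes, for illustration only).
plain = [{"litellm_params": {"tags": ["premium"]}}]
with_regex = plain + [
    {"litellm_params": {"tag_regex": [r"^User-Agent: claude-code\/"]}}
]
headers = ["User-Agent: Mozilla/5.0"]

print(should_filter_by_tag(None, headers, plain))       # False
print(should_filter_by_tag(None, headers, with_regex))  # True
print(should_filter_by_tag(["premium"], [], plain))     # True
```

The first call is the backwards-compatibility case: a User-Agent alone does not switch on filtering when no deployment declares `tag_regex`.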
@@ -198,6 +198,11 @@ class GenericLiteLLMParams(CredentialLiteLLMParams, CustomPricingLiteLLMParams):
    model_info: Optional[Dict] = None
    mock_response: Optional[Union[str, ModelResponse, Exception, Any]] = None

    # tag-based routing
    tags: Optional[List[str]] = None
    # regex patterns matched against request headers for tag routing
    tag_regex: Optional[List[str]] = None

    # auto-router params
    auto_router_config_path: Optional[str] = None
    auto_router_config: Optional[str] = None
@@ -334,6 +339,8 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
    # routing params
    # use this for tag-based routing
    tags: Optional[List[str]]
    # regex patterns matched against request headers (e.g. "^User-Agent:\\s*claude-code\\/")
    tag_regex: Optional[List[str]]

    # deployment budgets
    max_budget: Optional[float]

@@ -272,6 +272,164 @@ class TestBedrockFilesTransformation:
        assert "messages" in model_input
        assert "max_tokens" in model_input

    def test_get_complete_file_url_respects_s3_region_name(self):
        """
        s3_region_name in litellm_params must be used when building the S3 URL.
        Previously the code fell back to us-west-2 even when s3_region_name was set,
        breaking GovCloud (us-gov-west-1) deployments.
        """
        from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

        config = BedrockFilesConfig()

        jsonl_content = json.dumps(
            {
                "custom_id": "req-1",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": "bedrock/amazon.nova-pro-v1:0",
                    "messages": [{"role": "user", "content": "Hello"}],
                    "max_tokens": 10,
                },
            }
        ).encode()

        create_file_data = {
            "file": ("batch.jsonl", jsonl_content, "application/jsonl"),
            "purpose": "batch",
        }

        litellm_params = {
            "s3_bucket_name": "litellm-batch-352026",
            "s3_region_name": "us-gov-west-1",
        }

        url = config.get_complete_file_url(
            api_base=None,
            api_key=None,
            model="amazon.nova-pro-v1:0",
            optional_params={},
            litellm_params=litellm_params,
            data=create_file_data,
        )

        assert "us-gov-west-1" in url, (
            f"Expected us-gov-west-1 in URL but got: {url}"
        )
        assert "us-west-2" not in url, (
            f"us-west-2 must not appear when s3_region_name is set, got: {url}"
        )
        assert "litellm-batch-352026" in url

    def test_transform_create_file_request_injects_s3_region_for_signing(self):
        """
        When s3_region_name is provided, transform_create_file_request must pass
        that region to _sign_s3_request so SigV4 signatures use the correct region.
        """
        from unittest.mock import patch

        from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

        config = BedrockFilesConfig()

        jsonl_content = json.dumps(
            {
                "custom_id": "req-1",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": "bedrock/amazon.nova-pro-v1:0",
                    "messages": [{"role": "user", "content": "Hello"}],
                    "max_tokens": 10,
                },
            }
        ).encode()

        create_file_data = {
            "file": ("batch.jsonl", jsonl_content, "application/jsonl"),
            "purpose": "batch",
        }

        litellm_params = {
            "s3_bucket_name": "litellm-batch-352026",
            "s3_region_name": "us-gov-west-1",
        }

        captured_optional_params: dict = {}

        def fake_sign(content, api_base, optional_params):
            captured_optional_params.update(optional_params)
            return {"Authorization": "fake"}, content

        with patch.object(config, "_sign_s3_request", side_effect=fake_sign):
            config.transform_create_file_request(
                model="amazon.nova-pro-v1:0",
                create_file_data=create_file_data,
                optional_params={},
                litellm_params=litellm_params,
            )

        assert captured_optional_params.get("aws_region_name") == "us-gov-west-1", (
            "s3_region_name must be forwarded as aws_region_name for SigV4 signing"
        )

    def test_s3_region_name_wins_over_aws_region_name_for_signing(self):
        """
        When both s3_region_name and aws_region_name are set to different values,
        s3_region_name must win for signing (same as for the URL). Otherwise the
        SigV4 signature would be computed against a different region than the URL,
        causing SignatureDoesNotMatch from AWS.
        """
        from unittest.mock import patch

        from litellm.llms.bedrock.files.transformation import BedrockFilesConfig

        config = BedrockFilesConfig()

        jsonl_content = json.dumps(
            {
                "custom_id": "req-1",
                "method": "POST",
                "url": "/v1/chat/completions",
                "body": {
                    "model": "bedrock/amazon.nova-pro-v1:0",
                    "messages": [{"role": "user", "content": "Hello"}],
                    "max_tokens": 10,
                },
            }
        ).encode()

        create_file_data = {
            "file": ("batch.jsonl", jsonl_content, "application/jsonl"),
            "purpose": "batch",
        }

        litellm_params = {
            "s3_bucket_name": "litellm-batch-352026",
            "s3_region_name": "us-gov-west-1",
        }
        # aws_region_name set to something different — s3_region_name must still win
        optional_params = {"aws_region_name": "us-east-1"}

        captured_optional_params: dict = {}

        def fake_sign(content, api_base, optional_params):
            captured_optional_params.update(optional_params)
            return {"Authorization": "fake"}, content

        with patch.object(config, "_sign_s3_request", side_effect=fake_sign):
            config.transform_create_file_request(
                model="amazon.nova-pro-v1:0",
                create_file_data=create_file_data,
                optional_params=optional_params,
                litellm_params=litellm_params,
            )

        assert captured_optional_params.get("aws_region_name") == "us-gov-west-1", (
            "s3_region_name must override aws_region_name for SigV4 signing"
        )

    def test_openai_passthrough_still_works(self):
        """
        Regression test: ensure OpenAI-compatible models (e.g. gpt-oss)

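The precedence these three tests pin down can be sketched as a small pure function. This is an illustration under stated assumptions, not the code in `BedrockFilesConfig`: `resolve_s3_region` and `params_for_signing` are hypothetical names, the `us-west-2` default is taken from the first test's docstring, and the middle fallback to `aws_region_name` is an assumption the tests do not cover.

```python
from typing import Any, Dict

DEFAULT_S3_REGION = "us-west-2"  # assumed fallback, per the test docstring


def resolve_s3_region(
    litellm_params: Dict[str, Any], optional_params: Dict[str, Any]
) -> str:
    """Hypothetical helper: s3_region_name first, then aws_region_name, then default."""
    return (
        litellm_params.get("s3_region_name")
        or optional_params.get("aws_region_name")  # assumed fallback order
        or DEFAULT_S3_REGION
    )


def params_for_signing(
    litellm_params: Dict[str, Any], optional_params: Dict[str, Any]
) -> Dict[str, Any]:
    """Return a copy of optional_params whose region matches the URL region."""
    signing_params = dict(optional_params)  # do not mutate the caller's dict
    signing_params["aws_region_name"] = resolve_s3_region(
        litellm_params, optional_params
    )
    return signing_params


gov = {"s3_bucket_name": "litellm-batch-352026", "s3_region_name": "us-gov-west-1"}
print(params_for_signing(gov, {"aws_region_name": "us-east-1"}))
# {'aws_region_name': 'us-gov-west-1'}
```

Deriving both the URL and the signing region from one function is what keeps them from drifting apart, which is the failure mode the third test's docstring describes.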
@@ -14,15 +14,16 @@ from datetime import datetime, timedelta, timezone
from typing import Any, Dict
from unittest.mock import MagicMock, patch

+from botocore.awsrequest import AWSPreparedRequest, AWSRequest
+from botocore.credentials import Credentials
-from botocore.awsrequest import AWSRequest, AWSPreparedRequest

import litellm
+from litellm.caching.caching import DualCache
from litellm.llms.bedrock.base_aws_llm import (
    AwsAuthError,
    BaseAWSLLM,
    Boto3CredentialsInfo,
)
-from litellm.caching.caching import DualCache

# Global variable for the base_aws_llm.py file path

@@ -1519,6 +1520,83 @@ def test_is_already_running_as_role_invalid_target_arn():
    assert base_aws_llm._is_already_running_as_role("not-a-valid-arn") is False


def test_filter_headers_skips_none_values():
    """
    Test that _filter_headers_for_aws_signature skips headers with None values.

    Reproduces the issue where botocore's SigV4Auth crashes with
    'NoneType' object has no attribute 'split' when a header value is None.
    """
    llm = BaseAWSLLM()

    headers = {
        "Content-Type": "application/json",
        "x-amz-security-token": None,
        "x-amzn-bedrock-kb-session-id": None,
        "host": None,
        "x-amz-date": "20240101T000000Z",
        "x-custom-header": None,
    }

    filtered = llm._filter_headers_for_aws_signature(headers)

    assert filtered["Content-Type"] == "application/json"
    assert filtered["x-amz-date"] == "20240101T000000Z"
    assert "x-amz-security-token" not in filtered
    assert "x-amzn-bedrock-kb-session-id" not in filtered
    assert "host" not in filtered
    # Non-AWS headers are excluded regardless
    assert "x-custom-header" not in filtered


def test_sign_request_with_none_header_values():
    """
    End-to-end test that _sign_request does not crash when headers contain
    None values for x-amz-* keys.

    This reproduces the Bedrock KB GovCloud issue where SigV4 signing failed
    with 'NoneType' object has no attribute 'split'.

    Also verifies that None-valued headers are NOT re-merged into the
    returned headers dict (which would cause downstream HTTP client failures).
    """
    llm = BaseAWSLLM()

    mock_credentials = Credentials("test_key", "test_secret")

    headers_with_nones = {
        "Content-Type": "application/json",
        "x-amzn-trace-id": None,
        "x-forwarded-for": None,
    }

    with patch.object(
        llm, "get_credentials", return_value=mock_credentials
    ), patch.object(
        llm, "_get_aws_region_name", return_value="us-gov-west-1"
    ):
        result_headers, result_body = llm._sign_request(
            service_name="bedrock",
            headers=headers_with_nones,
            optional_params={
                "aws_access_key_id": "test_key",
                "aws_secret_access_key": "test_secret",
                "aws_region_name": "us-gov-west-1",
            },
            request_data={"retrievalQuery": {"text": "test query"}},
            api_base="https://bedrock-agent-runtime.us-gov-west-1.amazonaws.com/knowledgebases/KB123/retrieve",
        )

    assert "Authorization" in result_headers
    assert result_body is not None

    # None-valued headers must NOT appear in the returned headers
    for header_name, header_value in result_headers.items():
        assert header_value is not None, (
            f"Header '{header_name}' has None value in returned headers"
        )


def test_is_already_running_as_role_ssl_verify_passed():
    """
    Test that ssl_verify parameter is correctly passed to the STS client.

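The two failure points these tests cover (None values reaching the signer, and None values being merged back in afterwards) can be shown with plain dicts. A minimal sketch, assuming nothing about `BaseAWSLLM` internals: `drop_none_headers` and `merge_signed_headers` are hypothetical helper names, and the "signed" dict below is hand-written rather than produced by botocore.

```python
from typing import Dict, Optional


def drop_none_headers(headers: Dict[str, Optional[str]]) -> Dict[str, str]:
    """Hypothetical helper: keep only headers that carry a real value."""
    return {name: value for name, value in headers.items() if value is not None}


def merge_signed_headers(
    original: Dict[str, Optional[str]], signed: Dict[str, Optional[str]]
) -> Dict[str, str]:
    """Merge signer output over the original headers without reintroducing None."""
    merged = drop_none_headers(original)
    merged.update(drop_none_headers(signed))
    return merged


request_headers = {
    "Content-Type": "application/json",
    "x-amzn-trace-id": None,
    "x-forwarded-for": None,
}
# Stand-in for what a signer would add; the values are made up.
signed_headers = {"Authorization": "fake", "x-amz-date": "20240101T000000Z"}

final_headers = merge_signed_headers(request_headers, signed_headers)
print(sorted(final_headers))
# ['Authorization', 'Content-Type', 'x-amz-date']
```

Filtering on both sides of the merge is the point: cleaning the signer's input alone would still let the original None entries leak into the final dict.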
@@ -0,0 +1,375 @@
"""
Unit tests for tag_regex routing.

Tests _is_valid_deployment_tag_regex() and get_deployments_for_tag() with tag_regex
patterns, verifying that regex-based header matching works correctly alongside
existing tag-based routing.
"""

import os
import sys

import pytest

sys.path.insert(0, os.path.abspath("../.."))

from unittest.mock import MagicMock

from litellm.router_strategy import tag_based_routing
from litellm.router_strategy.tag_based_routing import get_deployments_for_tag

_is_valid_deployment_tag_regex = tag_based_routing._is_valid_deployment_tag_regex


# ---------------------------------------------------------------------------
# _is_valid_deployment_tag_regex unit tests
# ---------------------------------------------------------------------------


def test_regex_matches_claude_code_user_agent():
    """^User-Agent: claude-code/ matches a claude-code UA string."""
    result = _is_valid_deployment_tag_regex(
        tag_regexes=[r"^User-Agent: claude-code\/"],
        header_strings=["User-Agent: claude-code/1.2.3"],
    )
    assert result == r"^User-Agent: claude-code\/"


def test_regex_no_match_for_other_ua():
    """Pattern does not match a non-claude-code User-Agent."""
    result = _is_valid_deployment_tag_regex(
        tag_regexes=[r"^User-Agent: claude-code\/"],
        header_strings=["User-Agent: Mozilla/5.0 (browser)"],
    )
    assert result is None


def test_regex_returns_first_matching_pattern():
    """When multiple patterns are provided, returns the first match."""
    result = _is_valid_deployment_tag_regex(
        tag_regexes=[r"^User-Agent: cursor\/", r"^User-Agent: claude-code\/"],
        header_strings=["User-Agent: claude-code/2.0.0"],
    )
    assert result == r"^User-Agent: claude-code\/"


def test_regex_empty_inputs_return_none():
    """Empty lists return None without errors."""
    assert _is_valid_deployment_tag_regex([], ["User-Agent: claude-code/1.0"]) is None
    assert _is_valid_deployment_tag_regex([r"^User-Agent: claude-code\/"], []) is None


def test_invalid_regex_skipped_does_not_raise():
    """An invalid regex pattern is skipped (warning logged) — no exception raised."""
    result = _is_valid_deployment_tag_regex(
        tag_regexes=["[invalid(regex"],
        header_strings=["User-Agent: claude-code/1.0"],
    )
    assert result is None


def test_regex_matches_version_range():
    """Semver-aware pattern matches multiple versions."""
    pattern = r"^User-Agent: claude-code\/\d"
    for ua in ["claude-code/1.0", "claude-code/2.0.0-beta.1", "claude-code/99.0"]:
        result = _is_valid_deployment_tag_regex(
            tag_regexes=[pattern],
            header_strings=[f"User-Agent: {ua}"],
        )
        assert result == pattern, f"Expected match for UA: {ua}"


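The behaviour these unit tests describe (return the first pattern that matches, skip an invalid pattern with a single warning, return None on empty input) can be reproduced with the standard `re` module. This is a sketch of that contract, not the body of `_is_valid_deployment_tag_regex`: the function name `first_matching_pattern` and the logger name are hypothetical, and the use of `search` rather than `match` is an assumption.

```python
import logging
import re
from typing import List, Optional

logger = logging.getLogger("tag_regex_sketch")  # hypothetical logger name


def first_matching_pattern(
    tag_regexes: List[str], header_strings: List[str]
) -> Optional[str]:
    """Return the first pattern that matches any header string, else None."""
    for pattern in tag_regexes:
        try:
            # Compile once per pattern so a bad pattern warns once,
            # not once per header string.
            compiled = re.compile(pattern)
        except re.error as exc:
            logger.warning("skipping invalid tag_regex %r: %s", pattern, exc)
            continue
        if any(compiled.search(header) for header in header_strings):
            return pattern
    return None


patterns = [r"^User-Agent: cursor\/", r"^User-Agent: claude-code\/"]
print(first_matching_pattern(patterns, ["User-Agent: claude-code/2.0.0"]))
# ^User-Agent: claude-code\/
print(first_matching_pattern(["[invalid(regex"], ["User-Agent: claude-code/1.0"]))
# None
```

Because the patterns are anchored with `^` and matched against a `"User-Agent: <value>"` string, the header name is part of what the operator writes in the pattern.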
# ---------------------------------------------------------------------------
# get_deployments_for_tag integration tests
# ---------------------------------------------------------------------------

CLAUDE_CODE_DEPLOYMENT = {
    "model_name": "claude-sonnet",
    "litellm_params": {
        "model": "openai/claude-code-deployment",
        "api_key": "fake",
        "mock_response": "cc",
        "tag_regex": [r"^User-Agent: claude-code\/"],
    },
    "model_info": {"id": "claude-code-deployment"},
}

REGULAR_DEPLOYMENT = {
    "model_name": "claude-sonnet",
    "litellm_params": {
        "model": "openai/regular-deployment",
        "api_key": "fake",
        "mock_response": "regular",
        "tags": ["default"],
    },
    "model_info": {"id": "regular-deployment"},
}

ALL_DEPLOYMENTS = [CLAUDE_CODE_DEPLOYMENT, REGULAR_DEPLOYMENT]


def _make_router_mock(enable_tag_filtering=True, match_any=True):
    mock = MagicMock()
    mock.enable_tag_filtering = enable_tag_filtering
    mock.tag_filtering_match_any = match_any
    return mock


@pytest.mark.asyncio
async def test_claude_code_ua_routes_to_cc_deployment():
    """claude-code/x.y.z UA → claude-code-deployment via tag_regex."""
    router = _make_router_mock()
    result = await get_deployments_for_tag(
        llm_router_instance=router,
        model="claude-sonnet",
        healthy_deployments=ALL_DEPLOYMENTS,
        request_kwargs={"metadata": {"user_agent": "claude-code/1.2.3"}},
    )
    assert len(result) == 1
    assert result[0]["model_info"]["id"] == "claude-code-deployment"


@pytest.mark.asyncio
async def test_regular_ua_routes_to_default_deployment():
    """Mozilla UA → regular-deployment via default tag fallback."""
    router = _make_router_mock()
    result = await get_deployments_for_tag(
        llm_router_instance=router,
        model="claude-sonnet",
        healthy_deployments=ALL_DEPLOYMENTS,
        request_kwargs={"metadata": {"user_agent": "Mozilla/5.0 (browser)"}},
    )
    assert len(result) == 1
    assert result[0]["model_info"]["id"] == "regular-deployment"


@pytest.mark.asyncio
async def test_no_ua_routes_to_default_deployment():
    """No User-Agent → default deployment."""
    router = _make_router_mock()
    result = await get_deployments_for_tag(
        llm_router_instance=router,
        model="claude-sonnet",
        healthy_deployments=ALL_DEPLOYMENTS,
        request_kwargs={"metadata": {}},
    )
    assert len(result) == 1
    assert result[0]["model_info"]["id"] == "regular-deployment"


@pytest.mark.asyncio
async def test_tag_routing_metadata_written_for_regex_match():
    """tag_routing metadata block is populated when regex matches."""
    router = _make_router_mock()
    metadata: dict = {"user_agent": "claude-code/2.0.0-beta.1"}
    await get_deployments_for_tag(
        llm_router_instance=router,
        model="claude-sonnet",
        healthy_deployments=ALL_DEPLOYMENTS,
        request_kwargs={"metadata": metadata},
    )
    assert "tag_routing" in metadata
    tr = metadata["tag_routing"]
    assert tr["matched_via"] == "tag_regex"
    assert tr["matched_value"] == r"^User-Agent: claude-code\/"
    assert tr["user_agent"] == "claude-code/2.0.0-beta.1"


@pytest.mark.asyncio
async def test_tag_filtering_disabled_returns_all_deployments():
    """When enable_tag_filtering is False, all deployments returned regardless of UA."""
    router = _make_router_mock(enable_tag_filtering=False)
    result = await get_deployments_for_tag(
        llm_router_instance=router,
        model="claude-sonnet",
        healthy_deployments=ALL_DEPLOYMENTS,
        request_kwargs={"metadata": {"user_agent": "claude-code/1.0"}},
    )
    assert result == ALL_DEPLOYMENTS


@pytest.mark.asyncio
async def test_explicit_tag_match_takes_precedence_over_regex():
    """A deployment with both tags and tag_regex: exact tag match fires first."""
    deployment_with_both = {
        "model_name": "claude-sonnet",
        "litellm_params": {
            "model": "openai/both-deployment",
            "api_key": "fake",
            "tags": ["premium"],
            "tag_regex": [r"^User-Agent: claude-code\/"],
        },
        "model_info": {"id": "both-deployment"},
    }
    router = _make_router_mock()
    metadata: dict = {
        "tags": ["premium"],
        "user_agent": "claude-code/1.0",
    }
    result = await get_deployments_for_tag(
        llm_router_instance=router,
        model="claude-sonnet",
        healthy_deployments=[deployment_with_both],
        request_kwargs={"metadata": metadata},
    )
    assert len(result) == 1
    tr = metadata.get("tag_routing", {})
    assert tr.get("matched_via") == "tags"


@pytest.mark.asyncio
async def test_user_agent_present_no_tag_regex_deployments_does_not_raise():
    """
    Backwards-compat: a request that carries a User-Agent but targets plain-tag
    deployments (no tag_regex) must NOT raise ValueError — it should fall
    through to the default/all-deployments path just like before.
    """
    plain_tag_only_deployments = [
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "openai/premium-deployment",
                "api_key": "fake",
                "tags": ["premium"],
            },
            "model_info": {"id": "premium-deployment"},
        },
        {
            "model_name": "gpt-4",
            "litellm_params": {
                "model": "openai/free-deployment",
                "api_key": "fake",
                "tags": ["free"],
            },
            "model_info": {"id": "free-deployment"},
        },
    ]
    router = _make_router_mock()
    # The request has a User-Agent (as all proxy requests do) but NO tags and
    # neither deployment has tag_regex — must not raise, must return all.
    result = await get_deployments_for_tag(
        llm_router_instance=router,
        model="gpt-4",
        healthy_deployments=plain_tag_only_deployments,
        request_kwargs={"metadata": {"user_agent": "Mozilla/5.0 (any-client)"}},
    )
    # Falls through to "return healthy_deployments" path unchanged
    assert result == plain_tag_only_deployments


@pytest.mark.asyncio
async def test_tag_routing_metadata_not_overwritten_for_multiple_matches():
    """
    When multiple deployments match, tag_routing records only the first match
    so the provenance reflects what the load balancer likely selected.
    """
    deployment_a = {
        "model_name": "claude-sonnet",
        "litellm_params": {
            "model": "openai/cc-deployment-a",
            "api_key": "fake",
            "tag_regex": [r"^User-Agent: claude-code\/"],
        },
        "model_info": {"id": "cc-deployment-a"},
    }
    deployment_b = {
        "model_name": "claude-sonnet",
        "litellm_params": {
            "model": "openai/cc-deployment-b",
            "api_key": "fake",
            "tag_regex": [r"^User-Agent: claude-code\/"],
        },
        "model_info": {"id": "cc-deployment-b"},
    }
    router = _make_router_mock()
    metadata: dict = {"user_agent": "claude-code/1.0"}
    result = await get_deployments_for_tag(
        llm_router_instance=router,
        model="claude-sonnet",
        healthy_deployments=[deployment_a, deployment_b],
        request_kwargs={"metadata": metadata},
    )
    assert len(result) == 2
    # tag_routing recorded once and reflects the first match
    tr = metadata.get("tag_routing", {})
    assert tr.get("matched_deployment") == "claude-sonnet"
    assert tr.get("matched_via") == "tag_regex"


@pytest.mark.asyncio
async def test_match_any_false_strict_tag_check_blocks_regex_fallback():
    """
    When match_any=False and a deployment has both tags and tag_regex:
    if the strict tag check fails (request has a tag NOT present on the
    deployment, so req_set is NOT a subset of dep_set), the regex fallback
    must NOT fire — that would violate the operator's strict-filtering intent.

    Semantics of match_any=False: req_set.issubset(dep_set), i.e. every
    request tag must appear on the deployment. A request with tags ["vip"]
    against a deployment with tags ["premium"] fails because "vip" ∉ dep_set.
    """
    deployment_strict = {
        "model_name": "claude-sonnet",
        "litellm_params": {
            "model": "openai/strict-deployment",
            "api_key": "fake",
            "tags": ["premium"],
            "tag_regex": [r"^User-Agent: claude-code\/"],
        },
        "model_info": {"id": "strict-deployment"},
    }
    default_deployment = {
        "model_name": "claude-sonnet",
        "litellm_params": {
            "model": "openai/default-deployment",
            "api_key": "fake",
            "tags": ["default"],
        },
        "model_info": {"id": "default-deployment"},
    }
    # match_any=False: req_set must be a subset of dep_set.
    # Request has "vip" which is NOT in ["premium"], so tag check fails.
    # Even though UA matches tag_regex, the deployment must NOT be selected.
    router = _make_router_mock(enable_tag_filtering=True, match_any=False)
    metadata: dict = {
        "tags": ["vip"],  # "vip" not in deployment tags → strict check fails
        "user_agent": "claude-code/1.0",
    }
    result = await get_deployments_for_tag(
        llm_router_instance=router,
        model="claude-sonnet",
        healthy_deployments=[deployment_strict, default_deployment],
        request_kwargs={"metadata": metadata},
    )
    ids = [d["model_info"]["id"] for d in result]
    assert "strict-deployment" not in ids, (
        "strict-deployment should not be selected: strict tag check failed "
        "and regex must not override the strict policy"
    )


@pytest.mark.asyncio
async def test_match_any_false_regex_only_deployment_still_matches():
    """
    When match_any=False and a deployment has ONLY tag_regex (no plain tags),
    there is no strict tag policy to violate, so the regex check must still fire.
    """
    regex_only_deployment = {
        "model_name": "claude-sonnet",
        "litellm_params": {
            "model": "openai/regex-only-deployment",
            "api_key": "fake",
            "tag_regex": [r"^User-Agent: claude-code\/"],
            # no "tags" key at all
        },
        "model_info": {"id": "regex-only-deployment"},
    }
    router = _make_router_mock(enable_tag_filtering=True, match_any=False)
    result = await get_deployments_for_tag(
        llm_router_instance=router,
        model="claude-sonnet",
        healthy_deployments=[regex_only_deployment],
        request_kwargs={"metadata": {"user_agent": "claude-code/1.0"}},
    )
    assert len(result) == 1
    assert result[0]["model_info"]["id"] == "regex-only-deployment"
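The per-deployment decision that the last three tests exercise can be written out as one small function. The sketch below is a hypothetical stand-in for `_match_deployment`, reconstructed only from what the tests assert: the function name, the choice of `matched_value` for tag matches, the handling of a request with no tags, and the assumption that patterns are already valid are all illustrative guesses.

```python
import re
from typing import Any, Dict, List, Optional


def match_deployment_sketch(
    deployment: Dict[str, Any],
    request_tags: Optional[List[str]],
    header_strings: List[str],
    match_any: bool,
) -> Optional[Dict[str, str]]:
    """Hypothetical per-deployment decision: exact tags first, then tag_regex."""
    params = deployment.get("litellm_params", {})
    dep_tags = set(params.get("tags") or [])
    req_tags = set(request_tags or [])

    if dep_tags and req_tags:
        if match_any:
            common = sorted(dep_tags & req_tags)
            if common:
                return {"matched_via": "tags", "matched_value": common[0]}
        elif req_tags.issubset(dep_tags):
            return {"matched_via": "tags", "matched_value": sorted(req_tags)[0]}
        else:
            # Strict mode: a failed tag check must not be rescued by a regex match.
            return None

    # Regex fallback; patterns are assumed to have been validated at startup.
    for pattern in params.get("tag_regex") or []:
        if any(re.search(pattern, header) for header in header_strings):
            return {"matched_via": "tag_regex", "matched_value": pattern}
    return None


ua = ["User-Agent: claude-code/1.0"]
both = {"litellm_params": {"tags": ["premium"], "tag_regex": [r"^User-Agent: claude-code\/"]}}
regex_only = {"litellm_params": {"tag_regex": [r"^User-Agent: claude-code\/"]}}

print(match_deployment_sketch(both, ["vip"], ua, match_any=False))      # None
print(match_deployment_sketch(both, ["premium"], ua, match_any=True))   # matched via tags
print(match_deployment_sketch(regex_only, None, ua, match_any=False))   # matched via tag_regex
```

The early `return None` in the strict branch is the design choice the strict-mode test guards: without it, a client-supplied User-Agent could route a request onto a deployment the operator restricted by tag.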
@@ -5,8 +5,9 @@ The ``_expand_model_aliases`` function processes ``aliases`` lists from model
entries, creating shared dict references for alias entries at load time.
"""

-import logging
+from unittest.mock import patch

+from litellm import verbose_logger
from litellm.litellm_core_utils.get_model_cost_map import _expand_model_aliases


@@ -118,7 +119,7 @@ class TestExpandModelAliases:
class TestAliasConflicts:
    """Tests for alias conflict detection and handling."""

-   def test_alias_conflicts_with_canonical_entry(self, caplog):
+   def test_alias_conflicts_with_canonical_entry(self):
        """Alias that matches an existing canonical entry is skipped with a warning."""
        model_cost = {
            "model-latest": {
@@ -133,14 +134,17 @@ class TestAliasConflicts:
                "mode": "chat",
            },
        }
-       with caplog.at_level(logging.WARNING, logger="LiteLLM"):
+       with patch.object(verbose_logger, "warning") as mock_warn:
            result = _expand_model_aliases(model_cost)

        # The canonical "model-dated" entry is preserved, not overwritten
        assert "model-dated" in result
-       assert "alias conflict" in caplog.text.lower()
+       # Verify a warning about the alias conflict was logged
+       mock_warn.assert_called()
+       warning_messages = " ".join(str(c) for c in mock_warn.call_args_list)
+       assert "alias conflict" in warning_messages.lower()

-   def test_duplicate_alias_across_entries(self, caplog):
+   def test_duplicate_alias_across_entries(self):
        """Same alias claimed by two different entries: second one is skipped."""
        model_cost = {
            "model-a": {
@@ -156,13 +160,16 @@ class TestAliasConflicts:
                "mode": "chat",
            },
        }
-       with caplog.at_level(logging.WARNING, logger="LiteLLM"):
+       with patch.object(verbose_logger, "warning") as mock_warn:
            result = _expand_model_aliases(model_cost)

        # "shared-alias" should point to model-a (first one wins)
        assert "shared-alias" in result
        assert result["shared-alias"]["input_cost_per_token"] == 1e-06
-       assert "alias conflict" in caplog.text.lower()
+       # Verify a warning about the alias conflict was logged
+       mock_warn.assert_called()
+       warning_messages = " ".join(str(c) for c in mock_warn.call_args_list)
+       assert "alias conflict" in warning_messages.lower()

    def test_canonical_entry_not_overwritten_by_alias(self):
        """An alias must never overwrite an existing canonical entry's data."""

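The switch from `caplog` to patching the logger's `warning` method can be illustrated without LiteLLM. The sketch below uses a made-up logger and a toy `register_alias` function (both hypothetical) to show the same assertion style. One plausible reason this is more stable, offered here as an assumption rather than something the diff states, is that the assertion no longer depends on how the logger's handlers and propagation are configured when the test runs.

```python
import logging
from unittest.mock import patch

logger = logging.getLogger("alias_sketch")  # hypothetical logger, not LiteLLM's


def register_alias(registry: dict, alias: str, target: str) -> None:
    """Toy function: the first claim on an alias wins; later claims log a warning."""
    if alias in registry:
        logger.warning(
            "alias conflict: %r already maps to %r", alias, registry[alias]
        )
        return
    registry[alias] = target


def collect_warnings() -> str:
    """Run two conflicting registrations and return the captured warning calls."""
    registry: dict = {}
    # Replace logger.warning with a mock for the duration of the block.
    with patch.object(logger, "warning") as mock_warn:
        register_alias(registry, "shared-alias", "model-a")
        register_alias(registry, "shared-alias", "model-b")
    mock_warn.assert_called_once()
    return " ".join(str(call) for call in mock_warn.call_args_list)


print("alias conflict" in collect_warnings().lower())  # True
```

Joining `call_args_list` into one string keeps the assertion tolerant of whether the message is passed pre-formatted or as a format string plus arguments.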