Merge pull request #26161 from BerriAI/litellm_access-group-routing-fix

fix(router): constrain same-name deployment routing by access groups
This commit is contained in:
Sameer Kankute 2026-05-02 11:19:51 +05:30 committed by GitHub
commit d2015f0baf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 637 additions and 42 deletions

View File

@ -9351,6 +9351,52 @@ class Router:
"""
return [m for m in self.model_list if m["litellm_params"]["model"] == model]
def _try_early_resolve_deployments_for_model_not_in_names(
    self, model: str, request_team_id: Optional[str]
) -> Optional[Tuple[str, Union[List, Dict]]]:
    """
    Resolve deployments for a ``model`` that is absent from ``self.model_names``.

    Lookup order (the first non-empty result wins):

    1. team deployments from ``self._get_all_deployments`` (only when
       ``request_team_id`` is given),
    2. matches from ``self.pattern_router``,
    3. matches from the team's entry in ``self.team_pattern_routers`` (only
       when ``request_team_id`` is given and has an entry),
    4. a copy of ``self.default_deployment`` whose
       ``litellm_params["model"]`` is replaced by ``model``.

    Returns:
        ``(model, deployments)`` for steps 1-3, ``(model, deployment_dict)``
        for step 4, or ``None`` when ``model`` is already in
        ``self.model_names`` or none of the lookups apply.
    """
    if model in self.model_names:
        # Known model names are handled by the regular by-name lookup.
        return None

    team_scoped = request_team_id is not None

    # Named team deployments (team_public_model_name) are checked first on
    # purpose, so they shadow wildcard/pattern routes further down.
    if team_scoped:
        named_team_matches = self._get_all_deployments(
            model_name=model, team_id=request_team_id
        )
        if named_team_matches:
            return model, named_team_matches

    global_pattern_matches = self.pattern_router.get_deployments_by_pattern(
        model=model,
    )
    if global_pattern_matches:
        return model, global_pattern_matches

    if team_scoped and request_team_id in self.team_pattern_routers:
        team_router = self.team_pattern_routers[request_team_id]
        team_pattern_matches = team_router.get_deployments_by_pattern(
            model=model,
        )
        if team_pattern_matches:
            return model, team_pattern_matches

    template = self.default_deployment
    if template is None:
        return None

    # Shallow copy plus a copy of the nested litellm_params (much cheaper than
    # deepcopy) so the stored default deployment is never mutated.
    resolved = template.copy()
    params = template["litellm_params"].copy()
    params["model"] = model
    resolved["litellm_params"] = params
    return model, resolved
def _common_checks_available_deployment(
self,
model: str,
@ -9393,56 +9439,52 @@ class Router:
if _model_from_alias is not None:
model = _model_from_alias
if model not in self.model_names:
# Check for team-specific deployments by team_public_model_name.
# This intentionally takes priority over team pattern routers below,
# so that named team deployments shadow wildcard/pattern routes.
if request_team_id is not None:
team_deployments = self._get_all_deployments(
model_name=model, team_id=request_team_id
)
if team_deployments:
return model, team_deployments
# check if provider/ specific wildcard routing use pattern matching
pattern_deployments = self.pattern_router.get_deployments_by_pattern(
model=model,
)
if pattern_deployments:
return model, pattern_deployments
if (
request_team_id is not None
and request_team_id in self.team_pattern_routers
):
pattern_deployments = self.team_pattern_routers[
request_team_id
].get_deployments_by_pattern(
model=model,
)
if pattern_deployments:
return model, pattern_deployments
# check if default deployment is set
if self.default_deployment is not None:
# Shallow copy with nested litellm_params copy (100x+ faster than deepcopy)
updated_deployment = self.default_deployment.copy()
updated_deployment["litellm_params"] = self.default_deployment[
"litellm_params"
].copy()
updated_deployment["litellm_params"]["model"] = model
return model, updated_deployment
early = self._try_early_resolve_deployments_for_model_not_in_names(
model=model, request_team_id=request_team_id
)
if early is not None:
return early
## get healthy deployments
### get all deployments
healthy_deployments = self._get_all_deployments(
model_name=model, team_id=request_team_id
)
_pre_model_access_group_filter_len = len(healthy_deployments)
healthy_deployments = self._filter_deployments_by_model_access_groups(
model=model,
healthy_deployments=healthy_deployments,
request_kwargs=request_kwargs,
request_team_id=request_team_id,
)
_access_group_filter_emptied_candidates = (
_pre_model_access_group_filter_len > 0 and len(healthy_deployments) == 0
)
if len(healthy_deployments) == 0:
# check if the user sent in a deployment name instead
healthy_deployments = self._get_deployment_by_litellm_model(model=model)
# Do not fall back when access-group filtering removed every candidate;
# _get_deployment_by_litellm_model does not re-apply that filter.
if _pre_model_access_group_filter_len == 0:
_litellm_model_deployments = self._get_deployment_by_litellm_model(
model=model
)
healthy_deployments = self._filter_deployments_by_model_access_groups(
model=model,
healthy_deployments=_litellm_model_deployments,
request_kwargs=request_kwargs,
request_team_id=request_team_id,
)
# If the litellm-model lookup produced candidates that access-group
# filtering then removed, treat this the same as the by-name path
# being emptied: prevent default-model fallback from bypassing the
# restriction (the fallback model may have no access_groups and
# would short-circuit the filter).
if (
len(_litellm_model_deployments) > 0
and len(healthy_deployments) == 0
):
_access_group_filter_emptied_candidates = True
if verbose_router_logger.isEnabledFor(logging.DEBUG):
verbose_router_logger.debug(
@ -9451,7 +9493,13 @@ class Router:
if len(healthy_deployments) == 0:
# Check for default fallbacks if no deployments are found for the requested model
if self._has_default_fallbacks():
# Do not fall back to another model when access-group filtering removed every
# candidate for the requested name: re-filtering the fallback model can be a
# no-op when it has no access_groups, incorrectly serving a different model.
if (
self._has_default_fallbacks()
and not _access_group_filter_emptied_candidates
):
fallback_model = self._get_first_default_fallback()
if fallback_model:
verbose_router_logger.info(
@ -9462,6 +9510,14 @@ class Router:
healthy_deployments = self._get_all_deployments(
model_name=model, team_id=request_team_id
)
healthy_deployments = (
self._filter_deployments_by_model_access_groups(
model=model,
healthy_deployments=healthy_deployments,
request_kwargs=request_kwargs,
request_team_id=request_team_id,
)
)
# If still no deployments after checking for fallbacks, raise an error
if len(healthy_deployments) == 0:
@ -9487,6 +9543,70 @@ class Router:
return model, healthy_deployments
def _filter_deployments_by_model_access_groups(
    self,
    model: str,
    healthy_deployments: List,
    request_kwargs: Optional[Dict],
    request_team_id: Optional[str],
) -> List:
    """
    Narrow candidate deployments to the access groups the caller is granted.

    The input list is returned unchanged unless all of the following hold:

    - ``request_kwargs`` carries a ``user_api_key_auth`` object under
      ``metadata`` or ``litellm_metadata``;
    - that object's ``models`` / ``team_models`` are non-empty and contain
      neither ``model`` itself, ``"*"`` nor ``"all-proxy-models"``;
    - ``self.get_model_access_groups`` reports at least one access group for
      ``model`` that also appears in the caller's grants.

    When they do hold, only deployments whose ``model_info["access_groups"]``
    intersects the caller's granted groups are kept.

    Returns:
        Either ``healthy_deployments`` itself or a new, possibly empty, list.
    """
    if not healthy_deployments or request_kwargs is None:
        return healthy_deployments

    # Prefer the auth object in `metadata`; fall back to `litellm_metadata`.
    caller_auth = None
    for metadata_key in ("metadata", "litellm_metadata"):
        caller_auth = (request_kwargs.get(metadata_key) or {}).get(
            "user_api_key_auth"
        )
        if caller_auth:
            break
    if caller_auth is None:
        return healthy_deployments

    granted = set(getattr(caller_auth, "models", []) or [])
    granted.update(getattr(caller_auth, "team_models", []) or [])
    if not granted:
        return healthy_deployments

    # Direct model, wildcard or all-proxy grants mean the caller is not
    # restricted to access-group membership, so nothing is filtered.
    if any(name in granted for name in (model, "*", "all-proxy-models")):
        return healthy_deployments

    model_groups = self.get_model_access_groups(
        model_name=model, team_id=request_team_id
    )
    if len(model_groups) == 0:
        return healthy_deployments

    usable_groups = granted.intersection(model_groups.keys())
    if not usable_groups:
        # No overlap: the request was not authorized through access-group
        # membership for this model, so group filtering is not forced.
        return healthy_deployments

    def _groups_of(deployment) -> set:
        info = deployment.get("model_info") or {}
        return set(info.get("access_groups", []) or [])

    return [
        deployment
        for deployment in healthy_deployments
        if _groups_of(deployment) & usable_groups
    ]
async def async_get_healthy_deployments(
self,
model: str,
@ -10007,6 +10127,7 @@ class Router:
messages=messages,
input=input,
specific_deployment=specific_deployment,
request_kwargs=request_kwargs,
)
if isinstance(healthy_deployments, dict):

View File

@ -4,6 +4,7 @@ from unittest.mock import Mock
import pytest
from litellm import Router
from litellm.proxy._types import UserAPIKeyAuth
from litellm.router_utils.common_utils import (
_deployment_supports_web_search,
add_model_file_id_mappings,
@ -365,6 +366,49 @@ def test_invalidate_model_group_info_cache():
assert router._cached_get_model_group_info.cache_info().currsize == 0
def test_filter_deployments_by_model_access_groups_access_group_only_key():
    """
    A key whose only grant is an access group must be narrowed to the
    deployments tagged with that group, even when several deployments are
    published under the same public model name.
    """
    ag1_deployment = {
        "model_name": "gpt-5",
        "litellm_params": {"model": "openai/gpt-5.1", "api_key": "key-1"},
        "model_info": {"access_groups": ["AG1"]},
    }
    ag2_deployment = {
        "model_name": "gpt-5",
        "litellm_params": {"model": "openai/gpt-4o", "api_key": "key-2"},
        "model_info": {"access_groups": ["AG2"]},
    }
    router = Router(model_list=[ag1_deployment, ag2_deployment])

    # The key is granted AG2 only, at both key and team level.
    ag2_only_key = UserAPIKeyAuth(
        api_key="hashed-key",
        team_id="team-2",
        models=["AG2"],
        team_models=["AG2"],
    )

    candidates = router._get_all_deployments(model_name="gpt-5")
    request_kwargs = {
        "metadata": {
            "user_api_key_team_id": "team-2",
            "user_api_key_auth": ag2_only_key,
        }
    }
    remaining = router._filter_deployments_by_model_access_groups(
        model="gpt-5",
        healthy_deployments=candidates,
        request_kwargs=request_kwargs,
        request_team_id="team-2",
    )

    assert len(remaining) == 1
    assert remaining[0].get("model_info", {}).get("access_groups") == ["AG2"]
class TestAddModelFileIdMappings:
"""Test cases for add_model_file_id_mappings.

View File

@ -3267,3 +3267,433 @@ async def test_multiregion_team_failover_between_regions():
"response from us-east-1",
"response from us-west-2",
]
def test_access_group_scoped_key_filters_deployments_with_same_public_model():
    """
    A key that reaches a model only through access-group membership must see
    only the deployments of that public model belonging to its allowed group,
    both in the candidate list and in the responses actually served.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    def _gpt5_deployment(litellm_model, api_key, mock_response, access_group):
        return {
            "model_name": "gpt-5",
            "litellm_params": {
                "model": litellm_model,
                "api_key": api_key,
                "mock_response": mock_response,
            },
            "model_info": {"access_groups": [access_group]},
        }

    router = litellm.Router(
        model_list=[
            _gpt5_deployment("openai/gpt-5.1", "key1", "response-via-AG1", "AG1"),
            _gpt5_deployment("openai/gpt-4o", "key2", "response-via-AG2", "AG2"),
        ]
    )

    ag2_only_key = UserAPIKeyAuth(
        api_key="hashed-key",
        team_id="team2",
        models=["AG2"],
        team_models=["AG2"],
    )

    def _fresh_metadata():
        # A new dict per call, so no call can observe another call's metadata.
        return {"user_api_key_team_id": "team2", "user_api_key_auth": ag2_only_key}

    _resolved_model, candidates = router._common_checks_available_deployment(
        model="gpt-5",
        request_kwargs={"metadata": _fresh_metadata()},
    )
    assert len(candidates) == 1
    assert candidates[0].get("model_info", {}).get("access_groups") == ["AG2"]

    # Repeat the request so an occasional pick of the AG1 deployment would show.
    observed_replies = {
        router.completion(
            model="gpt-5",
            messages=[{"role": "user", "content": "hello"}],
            metadata=_fresh_metadata(),
        )
        .choices[0]
        .message.content
        for _ in range(20)
    }
    assert observed_replies == {"response-via-AG2"}
def test_explicit_model_access_does_not_force_access_group_filtering():
    """
    A key that lists the model explicitly, alongside access-group entries,
    must not be restricted to access-group deployments: both deployments stay
    eligible.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    def _gpt5_deployment(litellm_model, api_key, mock_response, access_group):
        return {
            "model_name": "gpt-5",
            "litellm_params": {
                "model": litellm_model,
                "api_key": api_key,
                "mock_response": mock_response,
            },
            "model_info": {"access_groups": [access_group]},
        }

    router = litellm.Router(
        model_list=[
            _gpt5_deployment("openai/gpt-5.1", "key1", "response-via-AG1", "AG1"),
            _gpt5_deployment("openai/gpt-4o", "key2", "response-via-AG2", "AG2"),
        ]
    )

    # "gpt-5" is granted directly, in addition to the AG2 access group.
    key_with_direct_grant = UserAPIKeyAuth(
        api_key="hashed-key",
        team_id="team2",
        models=["AG2", "gpt-5"],
        team_models=["AG2", "gpt-5"],
    )

    request_kwargs = {
        "metadata": {
            "user_api_key_team_id": "team2",
            "user_api_key_auth": key_with_direct_grant,
        }
    }
    _resolved_model, candidates = router._common_checks_available_deployment(
        model="gpt-5",
        request_kwargs=request_kwargs,
    )

    groups_per_candidate = []
    for candidate in candidates:
        groups_per_candidate.append(
            candidate.get("model_info", {}).get("access_groups")
        )
    assert ["AG1"] in groups_per_candidate
    assert ["AG2"] in groups_per_candidate
def test_access_group_filter_empty_does_not_bypass_via_litellm_model_fallback(
    monkeypatch: pytest.MonkeyPatch,
):
    """
    Once access-group filtering has removed every candidate, the lookup by
    ``litellm_params.model`` must not be used as a fallback: it could return
    blocked deployments whose ``litellm_params.model`` equals the requested
    model string.

    ``get_model_access_groups`` is patched to report both AG1 and AG2 for the
    public model, so the filter runs with a non-empty allowed set, while every
    deployment registered under that name is AG2-only and is therefore
    filtered out. The AG1-only caller must get ``BadRequestError``.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    ag2_only_deployments = [
        {
            "model_name": "gpt-5",
            "litellm_params": {
                "model": "gpt-5",
                "api_key": api_key,
                "mock_response": mock_response,
            },
            "model_info": {"access_groups": ["AG2"]},
        }
        for api_key, mock_response in (
            ("key1", "blocked-dep-1"),
            ("key2", "blocked-dep-2"),
        )
    ]
    router = litellm.Router(model_list=ag2_only_deployments)

    real_lookup = router.get_model_access_groups

    def patched_lookup(model_name=None, model_access_group=None, team_id=None):
        # Only the plain "groups for gpt-5" query is faked; everything else
        # is delegated to the real implementation.
        is_gpt5_group_query = model_name == "gpt-5" and model_access_group is None
        if not is_gpt5_group_query:
            return real_lookup(
                model_name=model_name,
                model_access_group=model_access_group,
                team_id=team_id,
            )
        return {"AG1": ["gpt-5"], "AG2": ["gpt-5"]}

    monkeypatch.setattr(router, "get_model_access_groups", patched_lookup)

    ag1_only_key = UserAPIKeyAuth(
        api_key="hashed-key",
        team_id="team2",
        models=["AG1"],
        team_models=["AG1"],
    )
    request_kwargs = {
        "metadata": {
            "user_api_key_team_id": "team2",
            "user_api_key_auth": ag1_only_key,
        }
    }

    with pytest.raises(litellm.BadRequestError):
        router._common_checks_available_deployment(
            model="gpt-5",
            request_kwargs=request_kwargs,
        )
def test_access_group_block_does_not_silently_use_default_fallback_model(
    monkeypatch: pytest.MonkeyPatch,
):
    """
    When access-group filtering empties the candidates for model X, the router
    must not follow the default ``*`` entry in ``fallbacks`` to model Y.
    Y may carry no ``access_groups``, in which case
    ``_filter_deployments_by_model_access_groups`` would not constrain it and
    a caller blocked from X would still be served.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    def _blocked_gpt5(api_key, mock_response):
        return {
            "model_name": "gpt-5",
            "litellm_params": {
                "model": "gpt-5",
                "api_key": api_key,
                "mock_response": mock_response,
            },
            "model_info": {"access_groups": ["AG2"]},
        }

    # The fallback target deliberately has no model_info / access_groups.
    unrestricted_fallback = {
        "model_name": "gpt-4-fallback",
        "litellm_params": {
            "model": "gpt-4",
            "api_key": "fallback-key",
            "mock_response": "should-not-reach",
        },
    }
    router = litellm.Router(
        model_list=[
            _blocked_gpt5("key1", "blocked-dep-1"),
            _blocked_gpt5("key2", "blocked-dep-2"),
            unrestricted_fallback,
        ],
        fallbacks=[{"*": ["gpt-4-fallback"]}],
    )

    real_lookup = router.get_model_access_groups

    def patched_lookup(model_name=None, model_access_group=None, team_id=None):
        # Fake only the "groups for gpt-5" query; delegate anything else.
        is_gpt5_group_query = model_name == "gpt-5" and model_access_group is None
        if not is_gpt5_group_query:
            return real_lookup(
                model_name=model_name,
                model_access_group=model_access_group,
                team_id=team_id,
            )
        return {"AG1": ["gpt-5"], "AG2": ["gpt-5"]}

    monkeypatch.setattr(router, "get_model_access_groups", patched_lookup)

    ag1_only_key = UserAPIKeyAuth(
        api_key="hashed-key",
        team_id="team2",
        models=["AG1"],
        team_models=["AG1"],
    )
    request_kwargs = {
        "metadata": {
            "user_api_key_team_id": "team2",
            "user_api_key_auth": ag1_only_key,
        }
    }

    with pytest.raises(litellm.BadRequestError):
        router._common_checks_available_deployment(
            model="gpt-5",
            request_kwargs=request_kwargs,
        )
def test_access_group_block_via_litellm_model_branch_does_not_use_default_fallback(
    monkeypatch: pytest.MonkeyPatch,
):
    """
    When the by-name lookup finds nothing and the litellm-model branch finds
    candidates that access-group filtering then removes, the router must not
    fall through to default ``fallbacks`` routing. The default fallback model
    may have no ``access_groups`` and would short-circuit the filter, silently
    serving a caller who is blocked by access-group restrictions.
    """
    from litellm.proxy._types import UserAPIKeyAuth

    # Reachable only through litellm_params.model == "gpt-5"; the public name
    # is "gpt-5-alias", so the by-name lookup for "gpt-5" finds nothing.
    aliased_ag2_deployment = {
        "model_name": "gpt-5-alias",
        "litellm_params": {
            "model": "gpt-5",
            "api_key": "key1",
            "mock_response": "blocked-dep-1",
        },
        "model_info": {"access_groups": ["AG2"]},
    }
    unrestricted_fallback = {
        "model_name": "gpt-4-fallback",
        "litellm_params": {
            "model": "gpt-4",
            "api_key": "fallback-key",
            "mock_response": "should-not-reach",
        },
    }
    router = litellm.Router(
        model_list=[aliased_ag2_deployment, unrestricted_fallback],
        fallbacks=[{"*": ["gpt-4-fallback"]}],
    )

    real_lookup = router.get_model_access_groups

    def patched_lookup(model_name=None, model_access_group=None, team_id=None):
        # Fake only the "groups for gpt-5" query; delegate anything else.
        is_gpt5_group_query = model_name == "gpt-5" and model_access_group is None
        if not is_gpt5_group_query:
            return real_lookup(
                model_name=model_name,
                model_access_group=model_access_group,
                team_id=team_id,
            )
        return {"AG1": ["gpt-5"], "AG2": ["gpt-5"]}

    monkeypatch.setattr(router, "get_model_access_groups", patched_lookup)

    ag1_only_key = UserAPIKeyAuth(
        api_key="hashed-key",
        team_id="team2",
        models=["AG1"],
        team_models=["AG1"],
    )
    request_kwargs = {
        "metadata": {
            "user_api_key_team_id": "team2",
            "user_api_key_auth": ag1_only_key,
        }
    }

    with pytest.raises(litellm.BadRequestError):
        router._common_checks_available_deployment(
            model="gpt-5",
            request_kwargs=request_kwargs,
        )
def test_try_early_resolve_deployments_for_model_not_in_names():
    """
    Direct coverage for ``_try_early_resolve_deployments_for_model_not_in_names``.

    Scenarios checked:
    - ``None`` when the requested model is already in ``self.model_names``
      (the by-name lookup path handles it).
    - ``None`` when there are no team deployments, no pattern matches and no
      default deployment.
    - The pattern-router match when the model matches a wildcard route.
    - The default deployment with the requested model substituted in, leaving
      the stored default unmodified.
    """
    # --- Scenarios 1 and 2: nothing to resolve early ------------------------
    named_only_router = litellm.Router(
        model_list=[
            {
                "model_name": "gpt-5",
                "litellm_params": {
                    "model": "openai/gpt-5",
                    "api_key": "key1",
                },
            },
        ]
    )
    for requested in ("gpt-5", "some-unknown-model"):
        outcome = (
            named_only_router._try_early_resolve_deployments_for_model_not_in_names(
                model=requested, request_team_id=None
            )
        )
        assert outcome is None

    # --- Scenario 3: wildcard route -----------------------------------------
    wildcard_router = litellm.Router(
        model_list=[
            {
                "model_name": "openai/*",
                "litellm_params": {
                    "model": "openai/*",
                    "api_key": "key-pattern",
                },
            },
        ]
    )
    wildcard_outcome = (
        wildcard_router._try_early_resolve_deployments_for_model_not_in_names(
            model="openai/gpt-4o-mini", request_team_id=None
        )
    )
    assert wildcard_outcome is not None
    wildcard_model, wildcard_deployments = wildcard_outcome
    assert wildcard_model == "openai/gpt-4o-mini"
    assert isinstance(wildcard_deployments, list)
    assert len(wildcard_deployments) == 1

    # --- Scenario 4: default deployment -------------------------------------
    router_with_default = litellm.Router(
        model_list=[
            {
                "model_name": "named-model",
                "litellm_params": {
                    "model": "openai/gpt-4o",
                    "api_key": "key-named",
                },
            },
        ]
    )
    router_with_default.default_deployment = {
        "model_name": "default",
        "litellm_params": {
            "model": "openai/will-be-overridden",
            "api_key": "key-default",
        },
    }
    default_outcome = (
        router_with_default._try_early_resolve_deployments_for_model_not_in_names(
            model="brand-new-model", request_team_id=None
        )
    )
    assert default_outcome is not None
    substituted_model, substituted_deployment = default_outcome
    assert substituted_model == "brand-new-model"
    assert isinstance(substituted_deployment, dict)
    assert substituted_deployment["litellm_params"]["model"] == "brand-new-model"

    # The stored default deployment must not have been mutated.
    stored_params = router_with_default.default_deployment["litellm_params"]
    assert stored_params["model"] == "openai/will-be-overridden"