fix(auth): expand all-team-models sentinel in can_key_call_model for batch validation (#29746)

* fix(auth): expand all-team-models sentinel in can_key_call_model

Keys with models=["all-team-models"] were denied during batch JSONL
model validation because can_key_call_model matched the literal string
against the model name. Add _resolve_key_models_for_auth_check to
expand the sentinel to team_models before the check, consistent with
get_key_models in model_checks.py and the completion-route bypass.

Co-authored-by: Cursor <cursoragent@cursor.com>

* docs(auth): document empty team_models unrestricted access behavior; add regression test

Adds a docstring note to _resolve_key_models_for_auth_check explaining that
when team_models is empty, all-team-models resolves to [] which is treated as
unrestricted access (consistent with get_key_models behavior on other auth
paths). Adds a test to lock in this behavior.

* fix(auth): deny all-team-models access when key has no team_id

A key configured with models=["all-team-models"] but no team_id could
previously resolve to an empty allowlist, which _check_model_access_helper
treats as unrestricted access. Now the sentinel is only expanded when
team_id is set; otherwise the unresolved sentinel stays in the model list
and causes a deny (no real model name matches it). Same fix applied to
get_key_models in model_checks.py for consistency across batch and
non-batch auth paths.

* style: black format model_checks.py

* Fix batch all-team-models auth

* style: black format batch_rate_limiter.py

* fix(test): add tool_use_system_prompt_tokens to model prices schema validator

* fix(batch): catch get_team_object errors to avoid 404 escaping batch auth

* fix(batch): apply per-member model scope check after team auth in batch validation

* Fail closed on batch team auth fetch errors

* test(batch): cover team_object grant and member-scope denial in batch auth

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
Co-authored-by: mateo-berri <277851410+mateo-berri@users.noreply.github.com>
This commit is contained in:
Sameer Kankute 2026-06-05 21:34:45 +05:30 committed by GitHub
parent 89f177b7b6
commit 074455c138
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 461 additions and 24 deletions

View File

@ -3074,6 +3074,26 @@ def _model_in_team_aliases(
return False
def _resolve_key_models_for_auth_check(valid_token: UserAPIKeyAuth) -> List[str]:
"""
Expand key model sentinels before auth checks.
``all-team-models`` means inherit the parent team's allowlist — same
semantics as ``get_key_models`` in ``model_checks.py``.
If the key has no team_id the sentinel cannot be resolved, so the original
model list (still containing the sentinel string) is returned unchanged.
That string won't match any real model, so access is denied rather than
silently falling through to unrestricted access.
"""
models = list(valid_token.models or [])
if SpecialModelNames.all_team_models.value in models:
if valid_token.team_id is None:
return models
return list(valid_token.team_models or [])
return models
async def can_key_call_model(
model: Union[str, List[str]],
llm_model_list: Optional[list],
@ -3092,11 +3112,12 @@ async def can_key_call_model(
Raises:
- Exception: If token not allowed to call model
"""
key_models = _resolve_key_models_for_auth_check(valid_token=valid_token)
try:
return _can_object_call_model(
model=model,
llm_router=llm_router,
models=valid_token.models,
models=key_models,
team_model_aliases=valid_token.team_model_aliases,
team_id=valid_token.team_id,
object_type="key",

View File

@ -116,7 +116,10 @@ def get_key_models(
all_models = list(
user_api_key_dict.models
) # copy to avoid mutating cached objects
if SpecialModelNames.all_team_models.value in all_models:
if (
SpecialModelNames.all_team_models.value in all_models
and user_api_key_dict.team_id is not None
):
all_models = list(
user_api_key_dict.team_models
) # copy to avoid mutating cached objects

View File

@ -31,7 +31,12 @@ from litellm.batches.batch_utils import (
_get_models_from_batch_input_file_content,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
from litellm.proxy._types import (
ProxyErrorTypes,
ProxyException,
SpecialModelNames,
UserAPIKeyAuth,
)
from litellm.proxy.hooks.rate_limiter_utils import (
ProxyHTTPRateLimitError,
resolve_llm_provider_for_rate_limit,
@ -594,17 +599,52 @@ class _PROXY_BatchRateLimiter(CustomLogger):
"""Reject the batch if the caller is not authorized for every
``body.model`` named inside the JSONL.
Reuses ``can_key_call_model`` so the same allowlist semantics
(wildcards, access groups, ``all-proxy-models``, team aliases)
the proxy enforces on `/chat/completions` apply here.
Reuses standard auth helpers so the same model access rules the proxy
enforces on `/chat/completions` apply here.
"""
from litellm.proxy.auth.auth_checks import can_key_call_model
from litellm.proxy.auth.auth_checks import (
_check_team_member_model_access,
_key_access_group_grants_model,
can_key_call_model,
can_team_access_model,
get_team_object,
)
from litellm.proxy.proxy_server import llm_router
from litellm.proxy.proxy_server import prisma_client
from litellm.proxy.proxy_server import proxy_logging_obj
from litellm.proxy.proxy_server import user_api_key_cache
models = _get_models_from_batch_input_file_content(file_content_as_dict)
if not models:
return
team_object = None
if (
SpecialModelNames.all_team_models.value in (user_api_key_dict.models or [])
and user_api_key_dict.team_id is not None
and prisma_client is not None
):
try:
team_object = await get_team_object(
team_id=user_api_key_dict.team_id,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
parent_otel_span=user_api_key_dict.parent_otel_span,
proxy_logging_obj=proxy_logging_obj,
)
except HTTPException:
raise
except Exception as e:
raise HTTPException(
status_code=403,
detail={
"error": (
"Batch input file model access could not be "
"validated against the current team."
)
},
) from e
llm_model_list = llm_router.model_list if llm_router is not None else None
for model in models:
# body.model may be the provider id after replace_model_in_jsonl; map to proxy model_name for auth.
@ -614,18 +654,43 @@ class _PROXY_BatchRateLimiter(CustomLogger):
if proxy_model_name is not None:
model_to_check = proxy_model_name
try:
await can_key_call_model(
model=model_to_check,
llm_model_list=llm_model_list,
valid_token=user_api_key_dict,
llm_router=llm_router,
)
if team_object is not None:
try:
await can_team_access_model(
model=model_to_check,
team_object=team_object,
llm_router=llm_router,
team_model_aliases=user_api_key_dict.team_model_aliases,
)
except ProxyException as team_denial:
if team_denial.type != ProxyErrorTypes.team_model_access_denied:
raise
if not await _key_access_group_grants_model(
model=model_to_check,
valid_token=user_api_key_dict,
team_object=team_object,
llm_router=llm_router,
):
raise
await _check_team_member_model_access(
model=model_to_check,
team_object=team_object,
valid_token=user_api_key_dict,
llm_router=llm_router,
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
proxy_logging_obj=proxy_logging_obj,
)
else:
await can_key_call_model(
model=model_to_check,
llm_model_list=llm_model_list,
valid_token=user_api_key_dict,
llm_router=llm_router,
)
except HTTPException:
raise
except Exception as e:
# `can_key_call_model` raises ProxyException on denial;
# re-shape to a 403 so the batch endpoint returns a
# consistent rejection without leaking internal types.
raise HTTPException(
status_code=403,
detail={

View File

@ -271,6 +271,86 @@ async def test_can_user_call_model_no_default_models_returns_forbidden():
assert int(exc_info.value.code) == status.HTTP_403_FORBIDDEN
@pytest.mark.asyncio
async def test_can_key_call_model_all_team_models_uses_team_allowlist():
from litellm.proxy._types import SpecialModelNames
from litellm.proxy.auth.auth_checks import can_key_call_model
valid_token = UserAPIKeyAuth(
api_key="sk-team-key",
team_id="team-123",
models=[SpecialModelNames.all_team_models.value],
team_models=["openai/openai/gpt-5.5-batch"],
)
assert (
await can_key_call_model(
model="openai/openai/gpt-5.5-batch",
llm_model_list=None,
valid_token=valid_token,
llm_router=None,
)
is True
)
with pytest.raises(ProxyException) as exc_info:
await can_key_call_model(
model="gpt-4o",
llm_model_list=None,
valid_token=valid_token,
llm_router=None,
)
assert exc_info.value.type == ProxyErrorTypes.key_model_access_denied
@pytest.mark.asyncio
async def test_can_key_call_model_all_team_models_empty_team_models_is_unrestricted():
"""Team-bound key with empty team_models expands to [] -> unrestricted (same as get_key_models)."""
from litellm.proxy._types import SpecialModelNames
from litellm.proxy.auth.auth_checks import can_key_call_model
valid_token = UserAPIKeyAuth(
api_key="sk-team-key",
team_id="team-123",
models=[SpecialModelNames.all_team_models.value],
team_models=[],
)
assert (
await can_key_call_model(
model="any-model",
llm_model_list=None,
valid_token=valid_token,
llm_router=None,
)
is True
)
@pytest.mark.asyncio
async def test_can_key_call_model_all_team_models_no_team_id_is_denied():
"""Key with all-team-models but no team_id cannot resolve the sentinel; access must be denied."""
from litellm.proxy._types import SpecialModelNames
from litellm.proxy.auth.auth_checks import can_key_call_model
valid_token = UserAPIKeyAuth(
api_key="sk-orphan-key",
models=[SpecialModelNames.all_team_models.value],
team_models=[],
)
with pytest.raises(ProxyException) as exc_info:
await can_key_call_model(
model="gpt-4o",
llm_model_list=None,
valid_token=valid_token,
llm_router=None,
)
assert exc_info.value.type == ProxyErrorTypes.key_model_access_denied
@pytest.mark.asyncio
async def test_get_key_object_should_reconnect_once_on_db_connection_error():
mock_prisma_client = MagicMock()
@ -3391,18 +3471,16 @@ async def test_resolve_end_user_swallows_db_errors_and_returns_none(
@pytest.mark.asyncio
async def test_resolve_end_user(
_validate_flag_on, monkeypatch
):
async def test_resolve_end_user(_validate_flag_on, monkeypatch):
"""Verify that resolve_and_validate_end_user_id does NOT raise BudgetExceededError.
Note: As of the refactor that moved _check_end_user_budget out of
Note: As of the refactor that moved _check_end_user_budget out of
get_end_user_object, budget enforcement now happens in common_checks().
The end-user validation path should return the user ID regardless of budget status.
Budget enforcement for end users happens later in common_checks() via
Budget enforcement for end users happens later in common_checks() via
_check_end_user_budget(), which respects skip_budget_checks for zero-cost models.
This test verifies that even when get_end_user_object returns a user with a budget,
resolve_and_validate_end_user_id does not block the request - budget enforcement
is deferred to common_checks() where skip_budget_checks logic can be applied.

View File

@ -218,6 +218,275 @@ async def test_pre_call_rejects_unauthorized_model_in_batch_file():
assert "gpt-4o" in str(exc.value.detail)
@pytest.mark.asyncio
async def test_pre_call_allows_all_team_models_key_when_model_in_team_allowlist():
"""Keys with ``all-team-models`` must inherit the team allowlist when
validating models embedded in batch JSONL."""
from litellm.proxy._types import SpecialModelNames
from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter
rate_limiter = _PROXY_BatchRateLimiter(
internal_usage_cache=MagicMock(),
parallel_request_limiter=MagicMock(),
)
proxy_alias = "openai/openai/gpt-5.5-batch"
file_dict = [
{
"body": {
"model": proxy_alias,
"messages": [{"role": "user", "content": "x"}],
}
}
]
user = UserAPIKeyAuth(
api_key="sk-team",
user_id="alice",
team_id="team-123",
models=[SpecialModelNames.all_team_models.value],
team_models=[proxy_alias],
user_role=LitellmUserRoles.INTERNAL_USER.value,
)
with patch("litellm.proxy.proxy_server.llm_router", None):
await rate_limiter._enforce_batch_file_model_access(
user_api_key_dict=user,
file_content_as_dict=file_dict,
)
@pytest.mark.asyncio
async def test_pre_call_uses_current_team_allowlist_for_all_team_models_key():
from litellm.proxy._types import LiteLLM_TeamTable, SpecialModelNames
from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter
rate_limiter = _PROXY_BatchRateLimiter(
internal_usage_cache=MagicMock(),
parallel_request_limiter=MagicMock(),
)
stale_model = "stale-model"
current_model = "current-model"
file_dict = [
{
"body": {
"model": stale_model,
"messages": [{"role": "user", "content": "x"}],
}
}
]
user = UserAPIKeyAuth(
api_key="sk-team",
user_id="alice",
team_id="team-123",
models=[SpecialModelNames.all_team_models.value],
team_models=[stale_model],
user_role=LitellmUserRoles.INTERNAL_USER.value,
)
team_object = LiteLLM_TeamTable(
team_id="team-123",
models=[current_model],
)
with (
patch("litellm.proxy.proxy_server.prisma_client", MagicMock()),
patch("litellm.proxy.proxy_server.llm_router", None),
patch(
"litellm.proxy.auth.auth_checks.get_team_object",
new=AsyncMock(return_value=team_object),
) as mock_get_team_object,
pytest.raises(HTTPException) as exc_info,
):
await rate_limiter._enforce_batch_file_model_access(
user_api_key_dict=user,
file_content_as_dict=file_dict,
)
assert exc_info.value.status_code == 403
mock_get_team_object.assert_awaited_once()
@pytest.mark.asyncio
async def test_pre_call_allows_all_team_models_key_via_current_team_object():
"""Happy path for the team_object branch: with a DB client present, an
``all-team-models`` key whose batch model is on the *current* team
allowlist must be authorized through the freshly-fetched team object,
not the cached-``team_models`` fallback."""
from litellm.proxy._types import LiteLLM_TeamTable, SpecialModelNames
from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter
rate_limiter = _PROXY_BatchRateLimiter(
internal_usage_cache=MagicMock(),
parallel_request_limiter=MagicMock(),
)
current_model = "current-model"
file_dict = [
{
"body": {
"model": current_model,
"messages": [{"role": "user", "content": "x"}],
}
}
]
user = UserAPIKeyAuth(
api_key="sk-team",
user_id="alice",
team_id="team-123",
models=[SpecialModelNames.all_team_models.value],
team_models=["stale-model"],
user_role=LitellmUserRoles.INTERNAL_USER.value,
)
team_object = LiteLLM_TeamTable(
team_id="team-123",
models=[current_model],
)
can_key_call_model = AsyncMock(return_value=True)
with (
patch("litellm.proxy.proxy_server.prisma_client", MagicMock()),
patch("litellm.proxy.proxy_server.llm_router", None),
patch(
"litellm.proxy.auth.auth_checks.get_team_object",
new=AsyncMock(return_value=team_object),
) as mock_get_team_object,
patch(
"litellm.proxy.auth.auth_checks.get_team_membership",
new=AsyncMock(return_value=None),
),
patch(
"litellm.proxy.auth.auth_checks.can_key_call_model",
new=can_key_call_model,
),
):
await rate_limiter._enforce_batch_file_model_access(
user_api_key_dict=user,
file_content_as_dict=file_dict,
)
mock_get_team_object.assert_awaited_once()
can_key_call_model.assert_not_awaited()
@pytest.mark.asyncio
async def test_pre_call_denies_all_team_models_key_via_member_scope():
"""The team_object branch must also apply the per-member model scope: a
model on the team allowlist but outside the member's ``allowed_models``
must be rejected with a 403."""
from litellm.proxy._types import (
LiteLLM_BudgetTable,
LiteLLM_TeamMembership,
LiteLLM_TeamTable,
SpecialModelNames,
)
from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter
rate_limiter = _PROXY_BatchRateLimiter(
internal_usage_cache=MagicMock(),
parallel_request_limiter=MagicMock(),
)
team_model = "team-model"
file_dict = [
{
"body": {
"model": team_model,
"messages": [{"role": "user", "content": "x"}],
}
}
]
user = UserAPIKeyAuth(
api_key="sk-team",
user_id="alice",
team_id="team-123",
models=[SpecialModelNames.all_team_models.value],
team_models=[team_model],
user_role=LitellmUserRoles.INTERNAL_USER.value,
)
team_object = LiteLLM_TeamTable(team_id="team-123", models=[team_model])
membership = LiteLLM_TeamMembership(
user_id="alice",
team_id="team-123",
litellm_budget_table=LiteLLM_BudgetTable(allowed_models=["other-model"]),
)
with (
patch("litellm.proxy.proxy_server.prisma_client", MagicMock()),
patch("litellm.proxy.proxy_server.llm_router", None),
patch(
"litellm.proxy.auth.auth_checks.get_team_object",
new=AsyncMock(return_value=team_object),
),
patch(
"litellm.proxy.auth.auth_checks.get_team_membership",
new=AsyncMock(return_value=membership),
),
pytest.raises(HTTPException) as exc_info,
):
await rate_limiter._enforce_batch_file_model_access(
user_api_key_dict=user,
file_content_as_dict=file_dict,
)
assert exc_info.value.status_code == 403
assert team_model in str(exc_info.value.detail)
@pytest.mark.parametrize(
("team_fetch_error", "expected_status"),
[
(HTTPException(status_code=404, detail="team not found"), 404),
(Exception("team fetch failed"), 403),
],
)
@pytest.mark.asyncio
async def test_pre_call_fails_closed_when_current_team_fetch_fails_for_all_team_models_key(
team_fetch_error, expected_status
):
from litellm.proxy._types import SpecialModelNames
from litellm.proxy.hooks.batch_rate_limiter import _PROXY_BatchRateLimiter
rate_limiter = _PROXY_BatchRateLimiter(
internal_usage_cache=MagicMock(),
parallel_request_limiter=MagicMock(),
)
stale_model = "stale-model"
file_dict = [
{
"body": {
"model": stale_model,
"messages": [{"role": "user", "content": "x"}],
}
}
]
user = UserAPIKeyAuth(
api_key="sk-team",
user_id="alice",
team_id="team-123",
models=[SpecialModelNames.all_team_models.value],
team_models=[stale_model],
user_role=LitellmUserRoles.INTERNAL_USER.value,
)
with (
patch("litellm.proxy.proxy_server.prisma_client", MagicMock()),
patch("litellm.proxy.proxy_server.llm_router", None),
patch(
"litellm.proxy.auth.auth_checks.get_team_object",
new=AsyncMock(side_effect=team_fetch_error),
) as mock_get_team_object,
patch(
"litellm.proxy.auth.auth_checks.can_key_call_model",
new=AsyncMock(return_value=True),
) as mock_can_key_call_model,
pytest.raises(HTTPException) as exc_info,
):
await rate_limiter._enforce_batch_file_model_access(
user_api_key_dict=user,
file_content_as_dict=file_dict,
)
assert exc_info.value.status_code == expected_status
mock_get_team_object.assert_awaited_once()
mock_can_key_call_model.assert_not_awaited()
@pytest.mark.asyncio
async def test_pre_call_allows_authorized_model_in_batch_file():
"""If every model in the JSONL is on the caller's allowlist, the hook

View File

@ -692,6 +692,7 @@ def test_aaamodel_prices_and_context_window_json_is_valid():
"type": "object",
"properties": {
"supports_computer_use": {"type": "boolean"},
"tool_use_system_prompt_tokens": {"type": "number"},
"cache_creation_input_audio_token_cost": {"type": "number"},
"cache_creation_input_token_cost": {"type": "number"},
"cache_creation_input_token_cost_above_1hr": {"type": "number"},