Merge pull request #26809 from BerriAI/litellm_teamMemberNullBudgetFallback

[Fix] Team member null budget fallback
This commit is contained in:
yuneng-jiang 2026-04-29 19:44:41 -07:00 committed by GitHub
commit 383bf12001
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 134 additions and 1 deletions

View File

@ -915,7 +915,8 @@ async def get_team_member_default_budget(
Fetches the team-level default per-member budget referenced by team.metadata["team_member_budget_id"].
This budget is applied to team members whose TeamMembership row has no
linked budget. Results are cached for performance.
linked budget, or whose linked budget has max_budget=NULL. Results are
cached for performance.
Args:
budget_id: The budget_id pulled from team.metadata["team_member_budget_id"]
@ -3293,6 +3294,7 @@ async def _check_team_member_budget(
if (
team_membership is not None
and team_membership.litellm_budget_table is not None
and team_membership.litellm_budget_table.max_budget is not None
):
team_member_budget = team_membership.litellm_budget_table.max_budget
else:

View File

@ -2336,3 +2336,134 @@ async def test_team_member_budget_check_per_member_override_wins_over_team_defau
)
assert exc_info.value.current_cost == 250.0
assert exc_info.value.max_budget == 200.0
@pytest.mark.asyncio
async def test_team_member_budget_check_null_clone_falls_back_to_team_default():
"""Per-member NULL max_budget falls through to the team default cap."""
from litellm.caching.dual_cache import DualCache
from litellm.proxy._types import LiteLLM_BudgetTable, LiteLLM_TeamMembership
from litellm.proxy.utils import ProxyLogging
team_object = LiteLLM_TeamTable(
team_id="test-team",
metadata={"team_member_budget_id": "budget-default"},
)
user_object = LiteLLM_UserTable(user_id="test-user")
valid_token = UserAPIKeyAuth(
token="test-token",
user_id="test-user",
team_id="test-team",
)
# Per-member row exists with NULL max_budget (the cloned-from-incomplete-default case).
team_membership = LiteLLM_TeamMembership(
user_id="test-user",
team_id="test-team",
spend=0.0,
budget_id="budget-clone",
litellm_budget_table=LiteLLM_BudgetTable(max_budget=None),
)
proxy_logging_obj = ProxyLogging(user_api_key_cache=None)
fake_default_row = MagicMock()
fake_default_row.max_budget = 65.0
fake_default_row.dict = MagicMock(
return_value={"budget_id": "budget-default", "max_budget": 65.0}
)
prisma_client = MagicMock()
prisma_client.db.litellm_budgettable.find_unique = AsyncMock(
return_value=fake_default_row
)
async def mock_get_current_spend(counter_key, fallback_spend):
if counter_key == "spend:team_member:test-user:test-team":
return 500.0
return fallback_spend
with (
patch("litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend),
patch(
"litellm.proxy.auth.auth_checks.get_team_membership",
new_callable=AsyncMock,
return_value=team_membership,
),
):
with pytest.raises(litellm.BudgetExceededError) as exc_info:
await _check_team_member_budget(
team_object=team_object,
user_object=user_object,
valid_token=valid_token,
prisma_client=prisma_client,
user_api_key_cache=DualCache(),
proxy_logging_obj=proxy_logging_obj,
)
assert exc_info.value.current_cost == 500.0
assert exc_info.value.max_budget == 65.0
prisma_client.db.litellm_budgettable.find_unique.assert_awaited_once()
@pytest.mark.asyncio
async def test_team_member_budget_check_null_clone_with_null_default_skips_enforcement():
"""When per-member and team default are both NULL, enforcement still skips."""
from litellm.caching.dual_cache import DualCache
from litellm.proxy._types import LiteLLM_BudgetTable, LiteLLM_TeamMembership
from litellm.proxy.utils import ProxyLogging
team_object = LiteLLM_TeamTable(
team_id="test-team",
metadata={"team_member_budget_id": "budget-default"},
)
user_object = LiteLLM_UserTable(user_id="test-user")
valid_token = UserAPIKeyAuth(
token="test-token",
user_id="test-user",
team_id="test-team",
)
team_membership = LiteLLM_TeamMembership(
user_id="test-user",
team_id="test-team",
spend=0.0,
budget_id="budget-clone",
litellm_budget_table=LiteLLM_BudgetTable(max_budget=None),
)
proxy_logging_obj = ProxyLogging(user_api_key_cache=None)
fake_default_row = MagicMock()
fake_default_row.max_budget = None
fake_default_row.dict = MagicMock(
return_value={"budget_id": "budget-default", "max_budget": None}
)
prisma_client = MagicMock()
prisma_client.db.litellm_budgettable.find_unique = AsyncMock(
return_value=fake_default_row
)
async def mock_get_current_spend(counter_key, fallback_spend):
if counter_key == "spend:team_member:test-user:test-team":
return 1000.0
return fallback_spend
with (
patch("litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend),
patch(
"litellm.proxy.auth.auth_checks.get_team_membership",
new_callable=AsyncMock,
return_value=team_membership,
),
):
# No raise: both rows are NULL, so enforcement is correctly skipped.
await _check_team_member_budget(
team_object=team_object,
user_object=user_object,
valid_token=valid_token,
prisma_client=prisma_client,
user_api_key_cache=DualCache(),
proxy_logging_obj=proxy_logging_obj,
)