Merge pull request #26809 from BerriAI/litellm_teamMemberNullBudgetFallback
[Fix] Team member null budget fallback
This commit is contained in:
commit
383bf12001
@ -915,7 +915,8 @@ async def get_team_member_default_budget(
|
||||
Fetches the team-level default per-member budget referenced by team.metadata["team_member_budget_id"].
|
||||
|
||||
This budget is applied to team members whose TeamMembership row has no
|
||||
linked budget. Results are cached for performance.
|
||||
linked budget, or whose linked budget has max_budget=NULL. Results are
|
||||
cached for performance.
|
||||
|
||||
Args:
|
||||
budget_id: The budget_id pulled from team.metadata["team_member_budget_id"]
|
||||
@ -3293,6 +3294,7 @@ async def _check_team_member_budget(
|
||||
if (
|
||||
team_membership is not None
|
||||
and team_membership.litellm_budget_table is not None
|
||||
and team_membership.litellm_budget_table.max_budget is not None
|
||||
):
|
||||
team_member_budget = team_membership.litellm_budget_table.max_budget
|
||||
else:
|
||||
|
||||
@ -2336,3 +2336,134 @@ async def test_team_member_budget_check_per_member_override_wins_over_team_defau
|
||||
)
|
||||
assert exc_info.value.current_cost == 250.0
|
||||
assert exc_info.value.max_budget == 200.0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_member_budget_check_null_clone_falls_back_to_team_default():
|
||||
"""Per-member NULL max_budget falls through to the team default cap."""
|
||||
from litellm.caching.dual_cache import DualCache
|
||||
from litellm.proxy._types import LiteLLM_BudgetTable, LiteLLM_TeamMembership
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
|
||||
team_object = LiteLLM_TeamTable(
|
||||
team_id="test-team",
|
||||
metadata={"team_member_budget_id": "budget-default"},
|
||||
)
|
||||
user_object = LiteLLM_UserTable(user_id="test-user")
|
||||
valid_token = UserAPIKeyAuth(
|
||||
token="test-token",
|
||||
user_id="test-user",
|
||||
team_id="test-team",
|
||||
)
|
||||
|
||||
# Per-member row exists with NULL max_budget (the cloned-from-incomplete-default case).
|
||||
team_membership = LiteLLM_TeamMembership(
|
||||
user_id="test-user",
|
||||
team_id="test-team",
|
||||
spend=0.0,
|
||||
budget_id="budget-clone",
|
||||
litellm_budget_table=LiteLLM_BudgetTable(max_budget=None),
|
||||
)
|
||||
|
||||
proxy_logging_obj = ProxyLogging(user_api_key_cache=None)
|
||||
|
||||
fake_default_row = MagicMock()
|
||||
fake_default_row.max_budget = 65.0
|
||||
fake_default_row.dict = MagicMock(
|
||||
return_value={"budget_id": "budget-default", "max_budget": 65.0}
|
||||
)
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.db.litellm_budgettable.find_unique = AsyncMock(
|
||||
return_value=fake_default_row
|
||||
)
|
||||
|
||||
async def mock_get_current_spend(counter_key, fallback_spend):
|
||||
if counter_key == "spend:team_member:test-user:test-team":
|
||||
return 500.0
|
||||
return fallback_spend
|
||||
|
||||
with (
|
||||
patch("litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend),
|
||||
patch(
|
||||
"litellm.proxy.auth.auth_checks.get_team_membership",
|
||||
new_callable=AsyncMock,
|
||||
return_value=team_membership,
|
||||
),
|
||||
):
|
||||
with pytest.raises(litellm.BudgetExceededError) as exc_info:
|
||||
await _check_team_member_budget(
|
||||
team_object=team_object,
|
||||
user_object=user_object,
|
||||
valid_token=valid_token,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=DualCache(),
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
assert exc_info.value.current_cost == 500.0
|
||||
assert exc_info.value.max_budget == 65.0
|
||||
prisma_client.db.litellm_budgettable.find_unique.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_member_budget_check_null_clone_with_null_default_skips_enforcement():
|
||||
"""When per-member and team default are both NULL, enforcement still skips."""
|
||||
from litellm.caching.dual_cache import DualCache
|
||||
from litellm.proxy._types import LiteLLM_BudgetTable, LiteLLM_TeamMembership
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
|
||||
team_object = LiteLLM_TeamTable(
|
||||
team_id="test-team",
|
||||
metadata={"team_member_budget_id": "budget-default"},
|
||||
)
|
||||
user_object = LiteLLM_UserTable(user_id="test-user")
|
||||
valid_token = UserAPIKeyAuth(
|
||||
token="test-token",
|
||||
user_id="test-user",
|
||||
team_id="test-team",
|
||||
)
|
||||
|
||||
team_membership = LiteLLM_TeamMembership(
|
||||
user_id="test-user",
|
||||
team_id="test-team",
|
||||
spend=0.0,
|
||||
budget_id="budget-clone",
|
||||
litellm_budget_table=LiteLLM_BudgetTable(max_budget=None),
|
||||
)
|
||||
|
||||
proxy_logging_obj = ProxyLogging(user_api_key_cache=None)
|
||||
|
||||
fake_default_row = MagicMock()
|
||||
fake_default_row.max_budget = None
|
||||
fake_default_row.dict = MagicMock(
|
||||
return_value={"budget_id": "budget-default", "max_budget": None}
|
||||
)
|
||||
|
||||
prisma_client = MagicMock()
|
||||
prisma_client.db.litellm_budgettable.find_unique = AsyncMock(
|
||||
return_value=fake_default_row
|
||||
)
|
||||
|
||||
async def mock_get_current_spend(counter_key, fallback_spend):
|
||||
if counter_key == "spend:team_member:test-user:test-team":
|
||||
return 1000.0
|
||||
return fallback_spend
|
||||
|
||||
with (
|
||||
patch("litellm.proxy.proxy_server.get_current_spend", mock_get_current_spend),
|
||||
patch(
|
||||
"litellm.proxy.auth.auth_checks.get_team_membership",
|
||||
new_callable=AsyncMock,
|
||||
return_value=team_membership,
|
||||
),
|
||||
):
|
||||
# No raise: both rows are NULL, so enforcement is correctly skipped.
|
||||
await _check_team_member_budget(
|
||||
team_object=team_object,
|
||||
user_object=user_object,
|
||||
valid_token=valid_token,
|
||||
prisma_client=prisma_client,
|
||||
user_api_key_cache=DualCache(),
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user