fix(team-management): delete a team's BYOK models when the team is deleted (#29977)
A team's BYOK models (rows in LiteLLM_ProxyModelTable with model_info.team_id set) were left orphaned when the team was deleted; they lingered in the database and kept showing on the Models + Endpoints page. delete_team now removes them via a new delete_team_models helper that deletes the rows in one transaction and syncs the in-memory router only after that transaction commits, run before the team rows are deleted so a mid-flight failure never leaves the team gone with its models orphaned
This commit is contained in:
parent
bac2590b39
commit
c24a3603d9
@ -561,7 +561,7 @@ async def _setup_new_team_model_assignment(
|
||||
|
||||
|
||||
async def _get_team_deployments(
|
||||
team_id: str, prisma_client: PrismaClient
|
||||
team_id: str, prisma_client: PrismaClient, table: Optional[Any] = None
|
||||
) -> List[LiteLLM_ProxyModelTable]:
|
||||
"""
|
||||
Fetch all deployments for a given team_id from the database.
|
||||
@ -572,9 +572,13 @@ async def _get_team_deployments(
|
||||
Note: prisma-client-py 0.11.0 does not support JSON path filtering, so we filter
|
||||
by the model_name prefix (team models use "model_name_{team_id}_*") and confirm
|
||||
team_id in model_info with Python-side filtering.
|
||||
|
||||
Pass ``table`` (a transaction's proxy-model table) to run the read inside an
|
||||
existing transaction.
|
||||
"""
|
||||
prefix = f"model_name_{team_id}_"
|
||||
response = await ModelRepository(prisma_client).table.find_many(
|
||||
table = table or ModelRepository(prisma_client).table
|
||||
response = await table.find_many(
|
||||
where={
|
||||
"model_name": {"startswith": prefix},
|
||||
}
|
||||
@ -596,6 +600,42 @@ async def _get_team_deployments(
|
||||
return result
|
||||
|
||||
|
||||
async def delete_team_models(
|
||||
team_ids: List[str],
|
||||
prisma_client: PrismaClient,
|
||||
llm_router: Optional[Any],
|
||||
) -> List[str]:
|
||||
"""
|
||||
Delete every BYOK model owned by the given teams, from the DB and the router.
|
||||
|
||||
The DB rows are removed inside a single transaction, so deletion is atomic
|
||||
across all team_ids. Each team's rows are deleted by the exact model_ids read
|
||||
in the same transaction, which keeps the deleted set identical to the set
|
||||
handed to the router. The router is synced only after the transaction commits,
|
||||
so a rollback can never leave a deployment live in the router without its row.
|
||||
|
||||
Returns the model_ids that were deleted.
|
||||
"""
|
||||
deleted_model_ids: List[str] = []
|
||||
async with prisma_client.db.tx() as tx:
|
||||
for team_id in team_ids:
|
||||
rows = await _get_team_deployments(
|
||||
team_id, prisma_client, table=tx.litellm_proxymodeltable
|
||||
)
|
||||
model_ids = [row.model_id for row in rows]
|
||||
if model_ids:
|
||||
await tx.litellm_proxymodeltable.delete_many(
|
||||
where={"model_id": {"in": model_ids}}
|
||||
)
|
||||
deleted_model_ids.extend(model_ids)
|
||||
|
||||
if llm_router is not None:
|
||||
for model_id in deleted_model_ids:
|
||||
llm_router.delete_deployment(id=model_id)
|
||||
|
||||
return deleted_model_ids
|
||||
|
||||
|
||||
async def _get_team_public_model_names(
|
||||
team_id: str,
|
||||
prisma_client: PrismaClient,
|
||||
|
||||
@ -3289,6 +3289,20 @@ async def delete_team(
|
||||
|
||||
await prisma_client.delete_data(team_id_list=data.team_ids, table_name="key")
|
||||
|
||||
## DELETE ASSOCIATED BYOK MODELS
|
||||
# Runs before the team rows are deleted so a mid-flight failure never leaves
|
||||
# the team gone with its models orphaned.
|
||||
from litellm.proxy.management_endpoints.model_management_endpoints import (
|
||||
delete_team_models,
|
||||
)
|
||||
from litellm.proxy.proxy_server import llm_router
|
||||
|
||||
await delete_team_models(
|
||||
team_ids=data.team_ids,
|
||||
prisma_client=prisma_client,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
|
||||
# ## DELETE TEAM MEMBERSHIPS
|
||||
for team_row in team_rows:
|
||||
### get all team members
|
||||
|
||||
@ -24,6 +24,7 @@ from litellm.proxy.management_endpoints.model_management_endpoints import (
|
||||
ModelManagementAuthChecks,
|
||||
_get_team_deployments,
|
||||
clear_cache,
|
||||
delete_team_models,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.types.router import Deployment, LiteLLM_Params, updateDeployment
|
||||
@ -2184,6 +2185,148 @@ class TestGetTeamDeployments:
|
||||
assert result[0] is dep1
|
||||
|
||||
|
||||
def _model_row(model_id: str, team_id: str):
|
||||
row = MagicMock()
|
||||
row.model_id = model_id
|
||||
row.model_name = f"model_name_{team_id}_{model_id}"
|
||||
row.model_info = {"team_id": team_id}
|
||||
return row
|
||||
|
||||
|
||||
class _TxProxyModelTable:
|
||||
"""Transactional proxy-model table that records the order of DB writes."""
|
||||
|
||||
def __init__(self, rows, events):
|
||||
self._rows = list(rows)
|
||||
self.events = events
|
||||
|
||||
async def find_many(self, where):
|
||||
prefix = where["model_name"]["startswith"]
|
||||
return [r for r in self._rows if r.model_name.startswith(prefix)]
|
||||
|
||||
async def delete_many(self, where):
|
||||
ids = list(where["model_id"]["in"])
|
||||
self.events.append(("delete_many", tuple(ids)))
|
||||
self._rows = [r for r in self._rows if r.model_id not in ids]
|
||||
return len(ids)
|
||||
|
||||
|
||||
class _TxPrismaClient:
|
||||
"""Minimal prisma stub whose ``db.tx()`` yields a transaction and records commit."""
|
||||
|
||||
def __init__(self, rows):
|
||||
self.events: list = []
|
||||
self._table = _TxProxyModelTable(rows, self.events)
|
||||
tx = MagicMock()
|
||||
tx.litellm_proxymodeltable = self._table
|
||||
outer = self
|
||||
|
||||
class _TxCM:
|
||||
async def __aenter__(self):
|
||||
return tx
|
||||
|
||||
async def __aexit__(self, *exc):
|
||||
outer.events.append(("commit",))
|
||||
return False
|
||||
|
||||
self.db = MagicMock()
|
||||
self.db.tx = MagicMock(return_value=_TxCM())
|
||||
|
||||
|
||||
class _RecordingRouter:
|
||||
def __init__(self, events):
|
||||
self.events = events
|
||||
self.deleted: list = []
|
||||
|
||||
def delete_deployment(self, id): # noqa: A002 - matches router signature
|
||||
self.events.append(("router", id))
|
||||
self.deleted.append(id)
|
||||
|
||||
|
||||
class TestDeleteTeamModels:
|
||||
"""delete_team_models must remove every team's BYOK models in one transaction
|
||||
and sync the in-memory router only after that transaction commits."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deletes_all_teams_models_and_syncs_router(self):
|
||||
rows = [_model_row("a1", "team_a"), _model_row("b1", "team_b")]
|
||||
prisma = _TxPrismaClient(rows)
|
||||
router = _RecordingRouter(prisma.events)
|
||||
|
||||
deleted = await delete_team_models(
|
||||
team_ids=["team_a", "team_b"],
|
||||
prisma_client=prisma,
|
||||
llm_router=router,
|
||||
)
|
||||
|
||||
assert sorted(deleted) == ["a1", "b1"]
|
||||
assert sorted(router.deleted) == ["a1", "b1"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_router_sync_happens_after_commit(self):
|
||||
"""Race-safety: the router is touched only once the DB transaction has
|
||||
committed, so a rollback can never leave a deployment without its row."""
|
||||
rows = [_model_row("a1", "team_a"), _model_row("b1", "team_b")]
|
||||
prisma = _TxPrismaClient(rows)
|
||||
router = _RecordingRouter(prisma.events)
|
||||
|
||||
await delete_team_models(
|
||||
team_ids=["team_a", "team_b"], prisma_client=prisma, llm_router=router
|
||||
)
|
||||
|
||||
commit_idx = prisma.events.index(("commit",))
|
||||
router_indices = [i for i, e in enumerate(prisma.events) if e[0] == "router"]
|
||||
delete_indices = [
|
||||
i for i, e in enumerate(prisma.events) if e[0] == "delete_many"
|
||||
]
|
||||
assert router_indices, "router was never synced"
|
||||
assert all(i > commit_idx for i in router_indices)
|
||||
assert all(i < commit_idx for i in delete_indices)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_only_owning_team_models_deleted(self):
|
||||
"""A row sharing the prefix but a different model_info.team_id is left alone."""
|
||||
mine = _model_row("a1", "team_a")
|
||||
intruder = MagicMock()
|
||||
intruder.model_id = "x9"
|
||||
intruder.model_name = "model_name_team_a_x9"
|
||||
intruder.model_info = {"team_id": "someone_else"}
|
||||
prisma = _TxPrismaClient([mine, intruder])
|
||||
router = _RecordingRouter(prisma.events)
|
||||
|
||||
deleted = await delete_team_models(
|
||||
team_ids=["team_a"], prisma_client=prisma, llm_router=router
|
||||
)
|
||||
|
||||
assert deleted == ["a1"]
|
||||
assert router.deleted == ["a1"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_models_no_writes(self):
|
||||
prisma = _TxPrismaClient([])
|
||||
router = _RecordingRouter(prisma.events)
|
||||
|
||||
deleted = await delete_team_models(
|
||||
team_ids=["team_a"], prisma_client=prisma, llm_router=router
|
||||
)
|
||||
|
||||
assert deleted == []
|
||||
assert router.deleted == []
|
||||
assert not any(e[0] == "delete_many" for e in prisma.events)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_router_is_safe(self):
|
||||
rows = [_model_row("a1", "team_a")]
|
||||
prisma = _TxPrismaClient(rows)
|
||||
|
||||
deleted = await delete_team_models(
|
||||
team_ids=["team_a"], prisma_client=prisma, llm_router=None
|
||||
)
|
||||
|
||||
assert deleted == ["a1"]
|
||||
assert any(e[0] == "delete_many" for e in prisma.events)
|
||||
|
||||
|
||||
def _build_db_model_for_blocked_test():
|
||||
from litellm.types.router import Deployment, LiteLLM_Params, ModelInfo
|
||||
|
||||
|
||||
@ -6350,6 +6350,14 @@ async def test_delete_team_persists_deleted_teams(monkeypatch):
|
||||
mock_find_many_keys = AsyncMock(return_value=[])
|
||||
mock_prisma_client.db.litellm_verificationtoken.find_many = mock_find_many_keys
|
||||
|
||||
# delete_team now deletes team BYOK models inside a transaction; this team has none.
|
||||
mock_tx = AsyncMock()
|
||||
mock_tx.litellm_proxymodeltable.find_many = AsyncMock(return_value=[])
|
||||
mock_tx_cm = MagicMock()
|
||||
mock_tx_cm.__aenter__ = AsyncMock(return_value=mock_tx)
|
||||
mock_tx_cm.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_prisma_client.db.tx = MagicMock(return_value=mock_tx_cm)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"litellm.proxy.proxy_server.prisma_client",
|
||||
mock_prisma_client,
|
||||
@ -8499,6 +8507,8 @@ async def test_new_team_encrypts_callback_vars(
|
||||
assert cv["langfuse_secret_key"] != "sk-real"
|
||||
recovered = decrypt_callback_vars(metadata)["logging"][0]["callback_vars"]
|
||||
assert recovered["langfuse_secret_key"] == "sk-real"
|
||||
|
||||
|
||||
def _non_admin_auth():
|
||||
return UserAPIKeyAuth(
|
||||
user_id="u-team-admin", user_role=LitellmUserRoles.INTERNAL_USER
|
||||
|
||||
Loading…
Reference in New Issue
Block a user