chore: simplify ownership tracking — drop thin stores, in-memory fallback, hand-rolled cache
Substantial reduction (~765 LOC) without changing the security boundary: * Drop ContainerOwnershipStore and LiteLLMSkillsStore — both were one-method-per-Prisma-call wrappers. Inline the calls instead, matching the established pattern in vector_store_endpoints, agent_endpoints, and mcp_server/db.py. * Drop the prisma_client is None in-memory fallback. Production deploys always have Prisma; running ownership-critical paths on a process-local dict is a security footgun in the dev-mode case it was meant to support, and complicates every code path with a branch. Fail-secure: skip recording if Prisma is unavailable, and treat reads as "not found" (admin-only). * Drop the hand-rolled module-level cache. Replace with the existing litellm.caching.in_memory_cache.InMemoryCache, which already has TTL + max-size + eviction tested in its own module. Sentinel string for negative caching since InMemoryCache can't disambiguate "miss" from "cached as None". * Tests: drop coverage for removed code paths (in-memory fallback, hand-rolled cache internals). Keep tests for actual behavior (cache hit-rate, negative caching, owner check, list filtering, identity-less reject, admin bypass).
This commit is contained in:
parent
12fe945e7b
commit
6ce84effe1
@ -5,13 +5,11 @@ This module contains the actual database operations for skills CRUD.
|
||||
Used by the transformation layer and skills injection hook.
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.llms.litellm_proxy.skills.store import LiteLLMSkillsStore
|
||||
from litellm.caching.in_memory_cache import InMemoryCache
|
||||
from litellm.proxy._types import LiteLLM_SkillsTable, NewSkillRequest, UserAPIKeyAuth
|
||||
from litellm.proxy.common_utils.resource_ownership import (
|
||||
get_primary_resource_owner_scope,
|
||||
@ -21,88 +19,37 @@ from litellm.proxy.common_utils.resource_ownership import (
|
||||
)
|
||||
|
||||
# Skills are looked up on every chat completion that has skills enabled
|
||||
# (`LiteLLMSkillsHandler.fetch_skill_from_db` in the injection hook). Cache
|
||||
# the Prisma skill row for a short window so the hot path doesn't issue a DB
|
||||
# round-trip per request. Same shape as `_byok_cred_cache` and the container
|
||||
# ownership cache: (value, monotonic_timestamp). `None` is cached as a true
|
||||
# negative ("skill does not exist") so repeated misses also avoid the DB.
|
||||
_SKILL_CACHE: "OrderedDict[str, Tuple[Optional[Any], float]]" = OrderedDict()
|
||||
_SKILL_CACHE_TTL = 60 # seconds
|
||||
_SKILL_CACHE_MAX_SIZE = 10000
|
||||
|
||||
|
||||
def _read_skill_cache(skill_id: str) -> Tuple[bool, Optional[Any]]:
|
||||
"""Return (hit, value). hit=False means caller must consult the DB."""
|
||||
entry = _SKILL_CACHE.get(skill_id)
|
||||
if entry is None:
|
||||
return False, None
|
||||
value, timestamp = entry
|
||||
if time.monotonic() - timestamp > _SKILL_CACHE_TTL:
|
||||
_SKILL_CACHE.pop(skill_id, None)
|
||||
return False, None
|
||||
_SKILL_CACHE.move_to_end(skill_id)
|
||||
return True, value
|
||||
|
||||
|
||||
def _write_skill_cache(skill_id: str, skill: Optional[Any]) -> None:
|
||||
# LRU eviction (popitem(last=False)) instead of full ``clear()`` —
|
||||
# see container ownership cache for rationale.
|
||||
if skill_id in _SKILL_CACHE:
|
||||
_SKILL_CACHE.move_to_end(skill_id)
|
||||
_SKILL_CACHE[skill_id] = (skill, time.monotonic())
|
||||
while len(_SKILL_CACHE) > _SKILL_CACHE_MAX_SIZE:
|
||||
_SKILL_CACHE.popitem(last=False)
|
||||
|
||||
|
||||
def _invalidate_skill_cache(skill_id: str) -> None:
|
||||
"""Drop the cache entry after a write so the next read sees the new row."""
|
||||
_SKILL_CACHE.pop(skill_id, None)
|
||||
|
||||
|
||||
def _user_can_access_skill_owner(
|
||||
owner: Optional[str],
|
||||
user_api_key_dict: Optional[UserAPIKeyAuth],
|
||||
) -> bool:
|
||||
# Pre-isolation skills with no ``created_by`` are admin-only — same
|
||||
# rule as untracked containers. Owners need to either re-create via
|
||||
# the now-tracked flow or have an admin assign ``created_by`` on the
|
||||
# row.
|
||||
return user_can_access_resource_owner(owner, user_api_key_dict)
|
||||
# (`SkillsInjectionHook` calls ``fetch_skill_from_db``). 60s LRU/TTL cache
|
||||
# absorbs the hot read before it reaches Prisma. ``_NEGATIVE_SKILL_SENTINEL``
|
||||
# lets us cache a true "skill does not exist" so repeated misses also
|
||||
# avoid the DB — ``InMemoryCache`` returns ``None`` indistinguishably for
|
||||
# "miss" and "cached as None".
|
||||
_NEGATIVE_SKILL_SENTINEL = "__litellm_skill_not_found__"
|
||||
_SKILL_CACHE = InMemoryCache(max_size_in_memory=10000, default_ttl=60)
|
||||
|
||||
|
||||
def _prisma_skill_to_litellm(prisma_skill) -> LiteLLM_SkillsTable:
|
||||
"""
|
||||
Convert a Prisma skill record to LiteLLM_SkillsTable.
|
||||
"""Convert a Prisma skill record to LiteLLM_SkillsTable.
|
||||
|
||||
Handles Base64 decoding of file_content field.
|
||||
Handles Base64 decoding of file_content field — model_dump() converts
|
||||
Base64 fields to base64-encoded strings.
|
||||
"""
|
||||
import base64
|
||||
|
||||
data = prisma_skill.model_dump()
|
||||
|
||||
# Decode Base64 file_content back to bytes
|
||||
# model_dump() converts Base64 field to base64-encoded string
|
||||
if data.get("file_content") is not None:
|
||||
if isinstance(data["file_content"], str):
|
||||
data["file_content"] = base64.b64decode(data["file_content"])
|
||||
elif isinstance(data["file_content"], bytes):
|
||||
# Already bytes, no conversion needed
|
||||
pass
|
||||
|
||||
return LiteLLM_SkillsTable(**data)
|
||||
|
||||
|
||||
class LiteLLMSkillsHandler:
|
||||
"""
|
||||
Handler for LiteLLM database-backed skills operations.
|
||||
|
||||
This class provides static methods for CRUD operations on skills
|
||||
stored in the LiteLLM proxy database (LiteLLM_SkillsTable).
|
||||
"""
|
||||
"""CRUD for skills stored in ``litellm_skillstable``."""
|
||||
|
||||
@staticmethod
|
||||
async def _get_prisma_client():
|
||||
"""Get the prisma client from proxy server."""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
@ -118,28 +65,16 @@ class LiteLLMSkillsHandler:
|
||||
user_id: Optional[str] = None,
|
||||
user_api_key_dict: Optional[UserAPIKeyAuth] = None,
|
||||
) -> LiteLLM_SkillsTable:
|
||||
"""
|
||||
Create a new skill in the LiteLLM database.
|
||||
|
||||
Args:
|
||||
data: NewSkillRequest with skill details
|
||||
user_id: Optional user ID for tracking
|
||||
|
||||
Returns:
|
||||
LiteLLM_SkillsTable record
|
||||
"""
|
||||
prisma_client = await LiteLLMSkillsHandler._get_prisma_client()
|
||||
store = LiteLLMSkillsStore(prisma_client)
|
||||
|
||||
skill_id = f"litellm_skill_{uuid.uuid4()}"
|
||||
owner = get_primary_resource_owner_scope(user_api_key_dict) or user_id
|
||||
if owner is None:
|
||||
# Caller has no identity scope (no user_id / team_id / org_id /
|
||||
# api_key / token). Stamping a placeholder would let any two
|
||||
# identity-less callers see each other's skills via the shared
|
||||
# owner — the cross-tenant primitive we avoid. ValueError keeps
|
||||
# this module FastAPI-free per the project layering rule
|
||||
# (litellm_proxy provider integrations live outside proxy/).
|
||||
# Identity-less callers (no user_id / team_id / org_id /
|
||||
# api_key / token) can't be uniquely stamped on the row.
|
||||
# Stamping a placeholder would let any two such callers see
|
||||
# each other's skills via the shared owner. ValueError keeps
|
||||
# this module FastAPI-free per the project layering rule.
|
||||
raise ValueError(
|
||||
"Unable to record skill ownership: caller has no identity scope."
|
||||
)
|
||||
@ -154,13 +89,11 @@ class LiteLLMSkillsHandler:
|
||||
"updated_by": owner,
|
||||
}
|
||||
|
||||
# Handle metadata
|
||||
if data.metadata is not None:
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
|
||||
skill_data["metadata"] = safe_dumps(data.metadata)
|
||||
|
||||
# Handle file content - wrap bytes in Base64 for Prisma
|
||||
if data.file_content is not None:
|
||||
from prisma.fields import Base64
|
||||
|
||||
@ -174,8 +107,7 @@ class LiteLLMSkillsHandler:
|
||||
f"LiteLLMSkillsHandler: Creating skill {skill_id} with title={data.display_title}"
|
||||
)
|
||||
|
||||
new_skill = await store.create_skill(skill_data)
|
||||
|
||||
new_skill = await prisma_client.db.litellm_skillstable.create(data=skill_data)
|
||||
return _prisma_skill_to_litellm(new_skill)
|
||||
|
||||
@staticmethod
|
||||
@ -184,18 +116,7 @@ class LiteLLMSkillsHandler:
|
||||
offset: int = 0,
|
||||
user_api_key_dict: Optional[UserAPIKeyAuth] = None,
|
||||
) -> List[LiteLLM_SkillsTable]:
|
||||
"""
|
||||
List skills from the LiteLLM database.
|
||||
|
||||
Args:
|
||||
limit: Maximum number of skills to return
|
||||
offset: Number of skills to skip
|
||||
|
||||
Returns:
|
||||
List of LiteLLM_SkillsTable records
|
||||
"""
|
||||
prisma_client = await LiteLLMSkillsHandler._get_prisma_client()
|
||||
store = LiteLLMSkillsStore(prisma_client)
|
||||
|
||||
verbose_logger.debug(
|
||||
f"LiteLLMSkillsHandler: Listing skills with limit={limit}, offset={offset}"
|
||||
@ -212,26 +133,29 @@ class LiteLLMSkillsHandler:
|
||||
return []
|
||||
find_many_kwargs["where"] = {"created_by": {"in": owner_scopes}}
|
||||
|
||||
skills = await store.list_skills(find_many_kwargs)
|
||||
|
||||
skills = await prisma_client.db.litellm_skillstable.find_many(
|
||||
**find_many_kwargs
|
||||
)
|
||||
return [_prisma_skill_to_litellm(s) for s in skills]
|
||||
|
||||
@staticmethod
|
||||
async def _load_skill(skill_id: str) -> Optional[Any]:
|
||||
"""Cache-first read of the Prisma skill row.
|
||||
|
||||
Caching here keeps `fetch_skill_from_db` (called per chat completion in
|
||||
the skills injection hook) off the DB. Owner-scope filtering happens
|
||||
on the cached row, so the cache is per-skill and not per-caller.
|
||||
"""Cache-first read of the Prisma skill row. Owner-scope filtering
|
||||
happens on the cached row, so the cache is per-skill not per-caller.
|
||||
"""
|
||||
cached_hit, cached_skill = _read_skill_cache(skill_id)
|
||||
if cached_hit:
|
||||
return cached_skill
|
||||
cached = _SKILL_CACHE.get_cache(skill_id)
|
||||
if cached == _NEGATIVE_SKILL_SENTINEL:
|
||||
return None
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
prisma_client = await LiteLLMSkillsHandler._get_prisma_client()
|
||||
store = LiteLLMSkillsStore(prisma_client)
|
||||
skill = await store.find_skill(skill_id)
|
||||
_write_skill_cache(skill_id, skill)
|
||||
skill = await prisma_client.db.litellm_skillstable.find_unique(
|
||||
where={"skill_id": skill_id}
|
||||
)
|
||||
_SKILL_CACHE.set_cache(
|
||||
skill_id, skill if skill is not None else _NEGATIVE_SKILL_SENTINEL
|
||||
)
|
||||
return skill
|
||||
|
||||
@staticmethod
|
||||
@ -239,26 +163,12 @@ class LiteLLMSkillsHandler:
|
||||
skill_id: str,
|
||||
user_api_key_dict: Optional[UserAPIKeyAuth] = None,
|
||||
) -> LiteLLM_SkillsTable:
|
||||
"""
|
||||
Get a skill by ID from the LiteLLM database.
|
||||
|
||||
Args:
|
||||
skill_id: The skill ID to retrieve
|
||||
|
||||
Returns:
|
||||
LiteLLM_SkillsTable record
|
||||
|
||||
Raises:
|
||||
ValueError: If skill not found
|
||||
"""
|
||||
verbose_logger.debug(f"LiteLLMSkillsHandler: Getting skill {skill_id}")
|
||||
|
||||
skill = await LiteLLMSkillsHandler._load_skill(skill_id)
|
||||
|
||||
if skill is None:
|
||||
raise ValueError(f"Skill not found: {skill_id}")
|
||||
|
||||
if not _user_can_access_skill_owner(
|
||||
# Same "not found" message for both "missing" and "cross-tenant"
|
||||
# so callers can't enumerate skill IDs they don't own.
|
||||
if skill is None or not user_can_access_resource_owner(
|
||||
getattr(skill, "created_by", None), user_api_key_dict
|
||||
):
|
||||
raise ValueError(f"Skill not found: {skill_id}")
|
||||
@ -270,37 +180,17 @@ class LiteLLMSkillsHandler:
|
||||
skill_id: str,
|
||||
user_api_key_dict: Optional[UserAPIKeyAuth] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Delete a skill by ID from the LiteLLM database.
|
||||
|
||||
Args:
|
||||
skill_id: The skill ID to delete
|
||||
|
||||
Returns:
|
||||
Dict with id and type of deleted skill
|
||||
|
||||
Raises:
|
||||
ValueError: If skill not found
|
||||
"""
|
||||
prisma_client = await LiteLLMSkillsHandler._get_prisma_client()
|
||||
store = LiteLLMSkillsStore(prisma_client)
|
||||
|
||||
verbose_logger.debug(f"LiteLLMSkillsHandler: Deleting skill {skill_id}")
|
||||
|
||||
# Check if skill exists
|
||||
skill = await LiteLLMSkillsHandler._load_skill(skill_id)
|
||||
|
||||
if skill is None:
|
||||
raise ValueError(f"Skill not found: {skill_id}")
|
||||
|
||||
if not _user_can_access_skill_owner(
|
||||
if skill is None or not user_can_access_resource_owner(
|
||||
getattr(skill, "created_by", None), user_api_key_dict
|
||||
):
|
||||
raise ValueError(f"Skill not found: {skill_id}")
|
||||
|
||||
# Delete the skill
|
||||
await store.delete_skill(skill_id)
|
||||
_invalidate_skill_cache(skill_id)
|
||||
await prisma_client.db.litellm_skillstable.delete(where={"skill_id": skill_id})
|
||||
_SKILL_CACHE.set_cache(skill_id, _NEGATIVE_SKILL_SENTINEL)
|
||||
|
||||
return {"id": skill_id, "type": "skill_deleted"}
|
||||
|
||||
@ -309,22 +199,11 @@ class LiteLLMSkillsHandler:
|
||||
skill_id: str,
|
||||
user_api_key_dict: Optional[UserAPIKeyAuth] = None,
|
||||
) -> Optional[LiteLLM_SkillsTable]:
|
||||
"""
|
||||
Fetch a skill from the database (used by skills injection hook).
|
||||
|
||||
This is a convenience method that returns None instead of raising
|
||||
an exception if the skill is not found.
|
||||
|
||||
Args:
|
||||
skill_id: The skill ID to fetch
|
||||
|
||||
Returns:
|
||||
LiteLLM_SkillsTable or None if not found
|
||||
"""
|
||||
"""Skills-injection-hook helper: returns None instead of raising on
|
||||
not-found / not-authorized so the hook can silently skip."""
|
||||
try:
|
||||
return await LiteLLMSkillsHandler.get_skill(
|
||||
skill_id,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
skill_id, user_api_key_dict=user_api_key_dict
|
||||
)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
@ -1,22 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
|
||||
class LiteLLMSkillsStore:
|
||||
def __init__(self, prisma_client: Any):
|
||||
self.prisma_client = prisma_client
|
||||
|
||||
@property
|
||||
def _table(self) -> Any:
|
||||
return self.prisma_client.db.litellm_skillstable
|
||||
|
||||
async def create_skill(self, data: Dict[str, Any]) -> Any:
|
||||
return await self._table.create(data=data)
|
||||
|
||||
async def list_skills(self, find_many_kwargs: Dict[str, Any]) -> List[Any]:
|
||||
return await self._table.find_many(**find_many_kwargs)
|
||||
|
||||
async def find_skill(self, skill_id: str) -> Optional[Any]:
|
||||
return await self._table.find_unique(where={"skill_id": skill_id})
|
||||
|
||||
async def delete_skill(self, skill_id: str) -> None:
|
||||
await self._table.delete(where={"skill_id": skill_id})
|
||||
@ -1,10 +1,9 @@
|
||||
import time
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, List, Optional, Set, Tuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.in_memory_cache import InMemoryCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.common_utils.resource_ownership import (
|
||||
get_primary_resource_owner_scope,
|
||||
@ -12,75 +11,26 @@ from litellm.proxy.common_utils.resource_ownership import (
|
||||
is_proxy_admin,
|
||||
user_can_access_resource_owner,
|
||||
)
|
||||
from litellm.proxy.container_endpoints.ownership_store import (
|
||||
CONTAINER_OBJECT_PURPOSE,
|
||||
ContainerOwnershipStore,
|
||||
)
|
||||
from litellm.responses.utils import ResponsesAPIRequestUtils
|
||||
|
||||
MAX_IN_MEMORY_CONTAINER_OWNERS = 10000
|
||||
_IN_MEMORY_CONTAINER_OWNERS: "OrderedDict[str, str]" = OrderedDict()
|
||||
CONTAINER_OBJECT_PURPOSE = "container"
|
||||
|
||||
# Short-lived cache keeps every container access check from hitting the DB
|
||||
# (`_get_container_owner` is invoked on retrieve / delete / list / file-content
|
||||
# paths). Mirrors the `_byok_cred_cache` pattern in mcp_server/server.py:
|
||||
# (value, monotonic_timestamp) tuples, TTL'd, LRU-evicted at capacity,
|
||||
# invalidated by writes. A ``None`` value caches "untracked" so repeated
|
||||
# negative lookups also avoid DB.
|
||||
_CONTAINER_OWNER_CACHE: "OrderedDict[str, Tuple[Optional[str], float]]" = OrderedDict()
|
||||
_CONTAINER_OWNER_CACHE_TTL = 60 # seconds
|
||||
_CONTAINER_OWNER_CACHE_MAX_SIZE = 10000
|
||||
|
||||
|
||||
def _read_container_owner_cache(model_object_id: str) -> Tuple[bool, Optional[str]]:
|
||||
"""Return (hit, value). hit=False means caller must consult the DB."""
|
||||
entry = _CONTAINER_OWNER_CACHE.get(model_object_id)
|
||||
if entry is None:
|
||||
return False, None
|
||||
value, timestamp = entry
|
||||
if time.monotonic() - timestamp > _CONTAINER_OWNER_CACHE_TTL:
|
||||
_CONTAINER_OWNER_CACHE.pop(model_object_id, None)
|
||||
return False, None
|
||||
_CONTAINER_OWNER_CACHE.move_to_end(model_object_id)
|
||||
return True, value
|
||||
|
||||
|
||||
def _write_container_owner_cache(model_object_id: str, owner: Optional[str]) -> None:
|
||||
# LRU eviction (popitem(last=False)) instead of full ``clear()`` — a
|
||||
# full clear at capacity converts a steady-state cached workload into
|
||||
# a periodic full-DB-load oscillation as the cache repopulates from
|
||||
# zero and clears again.
|
||||
if model_object_id in _CONTAINER_OWNER_CACHE:
|
||||
_CONTAINER_OWNER_CACHE.move_to_end(model_object_id)
|
||||
_CONTAINER_OWNER_CACHE[model_object_id] = (owner, time.monotonic())
|
||||
while len(_CONTAINER_OWNER_CACHE) > _CONTAINER_OWNER_CACHE_MAX_SIZE:
|
||||
_CONTAINER_OWNER_CACHE.popitem(last=False)
|
||||
|
||||
|
||||
def _invalidate_container_owner_cache(model_object_id: str) -> None:
|
||||
"""Drop a cache entry after a write so the next read sees the new owner."""
|
||||
_CONTAINER_OWNER_CACHE.pop(model_object_id, None)
|
||||
|
||||
|
||||
def _remember_container_owner(model_object_id: str, owner: str) -> None:
|
||||
existing_owner = _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id)
|
||||
if existing_owner is not None:
|
||||
_IN_MEMORY_CONTAINER_OWNERS.move_to_end(model_object_id)
|
||||
_IN_MEMORY_CONTAINER_OWNERS[model_object_id] = owner
|
||||
while len(_IN_MEMORY_CONTAINER_OWNERS) > MAX_IN_MEMORY_CONTAINER_OWNERS:
|
||||
_IN_MEMORY_CONTAINER_OWNERS.popitem(last=False)
|
||||
# 60s LRU/TTL cache absorbs every container access check before it reaches
|
||||
# Prisma. ``_NEGATIVE_OWNER_SENTINEL`` lets us cache a true "untracked"
|
||||
# answer so repeated misses also avoid the DB — ``InMemoryCache`` returns
|
||||
# ``None`` indistinguishably for "miss" and "cached as None".
|
||||
_NEGATIVE_OWNER_SENTINEL = "__litellm_container_no_owner__"
|
||||
_CONTAINER_OWNER_CACHE = InMemoryCache(max_size_in_memory=10000, default_ttl=60)
|
||||
|
||||
|
||||
def _container_model_object_id(
|
||||
original_container_id: str,
|
||||
custom_llm_provider: str,
|
||||
original_container_id: str, custom_llm_provider: str
|
||||
) -> str:
|
||||
return f"{CONTAINER_OBJECT_PURPOSE}:{custom_llm_provider}:{original_container_id}"
|
||||
|
||||
|
||||
def decode_container_id_for_ownership(
|
||||
container_id: str,
|
||||
custom_llm_provider: str,
|
||||
container_id: str, custom_llm_provider: str
|
||||
) -> Tuple[str, str]:
|
||||
decoded = ResponsesAPIRequestUtils._decode_container_id(container_id)
|
||||
original_container_id = decoded.get("response_id", container_id)
|
||||
@ -91,9 +41,7 @@ def decode_container_id_for_ownership(
|
||||
|
||||
|
||||
def get_container_forwarding_params(
|
||||
container_id: str,
|
||||
original_container_id: str,
|
||||
custom_llm_provider: str,
|
||||
container_id: str, original_container_id: str, custom_llm_provider: str
|
||||
) -> Dict[str, str]:
|
||||
params = {
|
||||
"container_id": original_container_id,
|
||||
@ -147,122 +95,93 @@ async def record_container_owner(
|
||||
)
|
||||
return response
|
||||
if owner is None:
|
||||
# Caller has no identity (no user_id / team_id / org_id / api_key /
|
||||
# token) we can stamp on the row. Recording with a placeholder
|
||||
# would collapse every such caller into a single shared owner —
|
||||
# the cross-tenant data-access primitive we explicitly avoid.
|
||||
# Reject with 403 rather than fall back to a sentinel.
|
||||
# Identity-less callers (no user_id / team_id / org_id / api_key /
|
||||
# token) can't be uniquely stamped on the row. Stamping a
|
||||
# placeholder would collapse every such caller into a shared
|
||||
# owner — the cross-tenant primitive we explicitly avoid.
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Unable to record container ownership: caller has no identity scope.",
|
||||
)
|
||||
|
||||
original_container_id, resolved_provider = decode_container_id_for_ownership(
|
||||
container_id,
|
||||
custom_llm_provider,
|
||||
container_id, custom_llm_provider
|
||||
)
|
||||
model_object_id = _container_model_object_id(
|
||||
original_container_id,
|
||||
resolved_provider,
|
||||
original_container_id, resolved_provider
|
||||
)
|
||||
file_object = _dump_response(response)
|
||||
file_object["custom_llm_provider"] = resolved_provider
|
||||
file_object["provider_container_id"] = original_container_id
|
||||
|
||||
try:
|
||||
prisma_client = await _get_prisma_client()
|
||||
if prisma_client is None:
|
||||
existing_owner = _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id)
|
||||
if existing_owner is not None and not user_can_access_resource_owner(
|
||||
existing_owner, user_api_key_dict
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
_remember_container_owner(model_object_id, owner)
|
||||
_invalidate_container_owner_cache(model_object_id)
|
||||
return response
|
||||
|
||||
store = ContainerOwnershipStore(prisma_client)
|
||||
existing = await store.find_by_model_object_id(model_object_id)
|
||||
if existing is not None:
|
||||
if getattr(existing, "file_purpose", None) != CONTAINER_OBJECT_PURPOSE:
|
||||
raise HTTPException(status_code=500, detail="Unable to track container")
|
||||
if not user_can_access_resource_owner(
|
||||
getattr(existing, "created_by", None), user_api_key_dict
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
await store.update_owner_record(
|
||||
model_object_id=model_object_id,
|
||||
data={
|
||||
"unified_object_id": container_id,
|
||||
"file_object": file_object,
|
||||
"updated_by": owner,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await store.create_owner_record(
|
||||
data={
|
||||
"unified_object_id": container_id,
|
||||
"model_object_id": model_object_id,
|
||||
"file_object": file_object,
|
||||
"file_purpose": CONTAINER_OBJECT_PURPOSE,
|
||||
"created_by": owner,
|
||||
"updated_by": owner,
|
||||
}
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
prisma_client = await _get_prisma_client()
|
||||
if prisma_client is None:
|
||||
verbose_proxy_logger.warning(
|
||||
"Failed to persist container ownership for container_id=%s; "
|
||||
"falling back to in-process tracking: %s",
|
||||
model_object_id,
|
||||
e,
|
||||
"Skipping container ownership tracking because prisma_client is None"
|
||||
)
|
||||
existing_owner = _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id)
|
||||
if existing_owner is not None and not user_can_access_resource_owner(
|
||||
existing_owner, user_api_key_dict
|
||||
return response
|
||||
|
||||
table = prisma_client.db.litellm_managedobjecttable
|
||||
existing = await table.find_unique(where={"model_object_id": model_object_id})
|
||||
if existing is not None:
|
||||
if getattr(existing, "file_purpose", None) != CONTAINER_OBJECT_PURPOSE:
|
||||
raise HTTPException(status_code=500, detail="Unable to track container")
|
||||
if not user_can_access_resource_owner(
|
||||
getattr(existing, "created_by", None), user_api_key_dict
|
||||
):
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
_remember_container_owner(model_object_id, owner)
|
||||
await table.update(
|
||||
where={"model_object_id": model_object_id},
|
||||
data={
|
||||
"unified_object_id": container_id,
|
||||
"file_object": file_object,
|
||||
"updated_by": owner,
|
||||
},
|
||||
)
|
||||
else:
|
||||
await table.create(
|
||||
data={
|
||||
"unified_object_id": container_id,
|
||||
"model_object_id": model_object_id,
|
||||
"file_object": file_object,
|
||||
"file_purpose": CONTAINER_OBJECT_PURPOSE,
|
||||
"created_by": owner,
|
||||
"updated_by": owner,
|
||||
}
|
||||
)
|
||||
|
||||
_invalidate_container_owner_cache(model_object_id)
|
||||
_CONTAINER_OWNER_CACHE.set_cache(model_object_id, owner)
|
||||
return response
|
||||
|
||||
|
||||
async def _get_container_owner(
|
||||
original_container_id: str,
|
||||
custom_llm_provider: str,
|
||||
original_container_id: str, custom_llm_provider: str
|
||||
) -> Optional[str]:
|
||||
model_object_id = _container_model_object_id(
|
||||
original_container_id,
|
||||
custom_llm_provider,
|
||||
original_container_id, custom_llm_provider
|
||||
)
|
||||
|
||||
cached_hit, cached_value = _read_container_owner_cache(model_object_id)
|
||||
if cached_hit:
|
||||
return cached_value
|
||||
cached = _CONTAINER_OWNER_CACHE.get_cache(model_object_id)
|
||||
if cached == _NEGATIVE_OWNER_SENTINEL:
|
||||
return None
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
try:
|
||||
prisma_client = await _get_prisma_client()
|
||||
if prisma_client is None:
|
||||
owner = _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id)
|
||||
_write_container_owner_cache(model_object_id, owner)
|
||||
return owner
|
||||
prisma_client = await _get_prisma_client()
|
||||
if prisma_client is None:
|
||||
return None
|
||||
|
||||
owner = await ContainerOwnershipStore(prisma_client).get_owner(model_object_id)
|
||||
if owner is None:
|
||||
owner = _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id)
|
||||
_write_container_owner_cache(model_object_id, owner)
|
||||
return owner
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.warning(
|
||||
"Failed to load container ownership for container_id=%s; "
|
||||
"falling back to in-process tracking: %s",
|
||||
model_object_id,
|
||||
e,
|
||||
)
|
||||
# Don't cache transient DB errors — let the next request retry.
|
||||
return _IN_MEMORY_CONTAINER_OWNERS.get(model_object_id)
|
||||
row = await prisma_client.db.litellm_managedobjecttable.find_first(
|
||||
where={
|
||||
"model_object_id": model_object_id,
|
||||
"file_purpose": CONTAINER_OBJECT_PURPOSE,
|
||||
}
|
||||
)
|
||||
owner = getattr(row, "created_by", None) if row is not None else None
|
||||
_CONTAINER_OWNER_CACHE.set_cache(
|
||||
model_object_id, owner if owner is not None else _NEGATIVE_OWNER_SENTINEL
|
||||
)
|
||||
return owner
|
||||
|
||||
|
||||
async def assert_user_can_access_container(
|
||||
@ -271,17 +190,15 @@ async def assert_user_can_access_container(
|
||||
custom_llm_provider: str,
|
||||
) -> Tuple[str, str]:
|
||||
original_container_id, resolved_provider = decode_container_id_for_ownership(
|
||||
container_id,
|
||||
custom_llm_provider,
|
||||
container_id, custom_llm_provider
|
||||
)
|
||||
|
||||
if is_proxy_admin(user_api_key_dict):
|
||||
return original_container_id, resolved_provider
|
||||
|
||||
# Untracked containers (no ownership row) are admin-only. Pre-isolation
|
||||
# rows that pre-date this enforcement need the admin to either re-create
|
||||
# via the now-tracked flow or explicitly assign ``created_by`` on the
|
||||
# ``litellm_managedobjecttable`` row.
|
||||
# Untracked rows (no ownership) are admin-only. Pre-isolation rows
|
||||
# that pre-date this enforcement need an admin to either re-create
|
||||
# via the now-tracked flow or assign ``created_by`` on the row.
|
||||
owner = await _get_container_owner(original_container_id, resolved_provider)
|
||||
if not user_can_access_resource_owner(owner, user_api_key_dict):
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
@ -327,35 +244,26 @@ def _set_container_list_data(
|
||||
|
||||
async def _get_allowed_container_ids(
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
custom_llm_provider: str,
|
||||
) -> Set[str]:
|
||||
owner_scopes = get_resource_owner_scopes(user_api_key_dict)
|
||||
if not owner_scopes:
|
||||
return set()
|
||||
|
||||
in_memory_allowed_ids = {
|
||||
model_object_id
|
||||
for model_object_id, owner in _IN_MEMORY_CONTAINER_OWNERS.items()
|
||||
if owner in owner_scopes
|
||||
}
|
||||
try:
|
||||
prisma_client = await _get_prisma_client()
|
||||
if prisma_client is None:
|
||||
return in_memory_allowed_ids
|
||||
prisma_client = await _get_prisma_client()
|
||||
if prisma_client is None:
|
||||
return set()
|
||||
|
||||
db_allowed_ids = await ContainerOwnershipStore(
|
||||
prisma_client
|
||||
).list_model_object_ids_for_owners(
|
||||
owner_scopes=owner_scopes,
|
||||
)
|
||||
return in_memory_allowed_ids | db_allowed_ids
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.warning(
|
||||
"Failed to load allowed container ids; falling back to in-process "
|
||||
"tracking: %s",
|
||||
e,
|
||||
)
|
||||
return in_memory_allowed_ids
|
||||
rows = await prisma_client.db.litellm_managedobjecttable.find_many(
|
||||
where={
|
||||
"file_purpose": CONTAINER_OBJECT_PURPOSE,
|
||||
"created_by": {"in": owner_scopes},
|
||||
}
|
||||
)
|
||||
return {
|
||||
row.model_object_id
|
||||
for row in rows
|
||||
if getattr(row, "model_object_id", None) is not None
|
||||
}
|
||||
|
||||
|
||||
async def filter_container_list_response(
|
||||
@ -370,18 +278,14 @@ async def filter_container_list_response(
|
||||
if data is None:
|
||||
return response
|
||||
|
||||
allowed_container_ids = await _get_allowed_container_ids(
|
||||
user_api_key_dict,
|
||||
custom_llm_provider,
|
||||
)
|
||||
allowed_container_ids = await _get_allowed_container_ids(user_api_key_dict)
|
||||
filtered: List[Any] = []
|
||||
for item in data:
|
||||
container_id = _get_response_id(item)
|
||||
if container_id is None:
|
||||
continue
|
||||
original_container_id, resolved_provider = decode_container_id_for_ownership(
|
||||
container_id,
|
||||
custom_llm_provider,
|
||||
container_id, custom_llm_provider
|
||||
)
|
||||
if (
|
||||
_container_model_object_id(original_container_id, resolved_provider)
|
||||
@ -390,7 +294,5 @@ async def filter_container_list_response(
|
||||
filtered.append(item)
|
||||
|
||||
return _set_container_list_data(
|
||||
response,
|
||||
filtered,
|
||||
removed_filtered_items=len(filtered) != len(data),
|
||||
response, filtered, removed_filtered_items=len(filtered) != len(data)
|
||||
)
|
||||
|
||||
@ -1,55 +0,0 @@
|
||||
from typing import Any, Dict, List, Optional, Set
|
||||
|
||||
CONTAINER_OBJECT_PURPOSE = "container"
|
||||
|
||||
|
||||
class ContainerOwnershipStore:
|
||||
def __init__(self, prisma_client: Any):
|
||||
self.prisma_client = prisma_client
|
||||
|
||||
@property
|
||||
def _table(self) -> Any:
|
||||
return self.prisma_client.db.litellm_managedobjecttable
|
||||
|
||||
async def find_by_model_object_id(self, model_object_id: str) -> Optional[Any]:
|
||||
return await self._table.find_unique(where={"model_object_id": model_object_id})
|
||||
|
||||
async def create_owner_record(self, data: Dict[str, Any]) -> None:
|
||||
await self._table.create(data=data)
|
||||
|
||||
async def update_owner_record(
|
||||
self,
|
||||
model_object_id: str,
|
||||
data: Dict[str, Any],
|
||||
) -> None:
|
||||
await self._table.update(
|
||||
where={"model_object_id": model_object_id},
|
||||
data=data,
|
||||
)
|
||||
|
||||
async def get_owner(self, model_object_id: str) -> Optional[str]:
|
||||
row = await self._table.find_first(
|
||||
where={
|
||||
"model_object_id": model_object_id,
|
||||
"file_purpose": CONTAINER_OBJECT_PURPOSE,
|
||||
}
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
return getattr(row, "created_by", None)
|
||||
|
||||
async def list_model_object_ids_for_owners(
|
||||
self,
|
||||
owner_scopes: List[str],
|
||||
) -> Set[str]:
|
||||
rows = await self._table.find_many(
|
||||
where={
|
||||
"file_purpose": CONTAINER_OBJECT_PURPOSE,
|
||||
"created_by": {"in": owner_scopes},
|
||||
}
|
||||
)
|
||||
return {
|
||||
row.model_object_id
|
||||
for row in rows
|
||||
if getattr(row, "model_object_id", None) is not None
|
||||
}
|
||||
@ -12,12 +12,12 @@ from litellm.types.containers.main import ContainerListResponse, ContainerObject
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_in_memory_container_owners():
|
||||
ownership._IN_MEMORY_CONTAINER_OWNERS.clear()
|
||||
ownership._CONTAINER_OWNER_CACHE.clear()
|
||||
def clear_container_owner_cache():
|
||||
ownership._CONTAINER_OWNER_CACHE.cache_dict.clear()
|
||||
ownership._CONTAINER_OWNER_CACHE.ttl_dict.clear()
|
||||
yield
|
||||
ownership._IN_MEMORY_CONTAINER_OWNERS.clear()
|
||||
ownership._CONTAINER_OWNER_CACHE.clear()
|
||||
ownership._CONTAINER_OWNER_CACHE.cache_dict.clear()
|
||||
ownership._CONTAINER_OWNER_CACHE.ttl_dict.clear()
|
||||
|
||||
|
||||
def _container(container_id: str) -> ContainerObject:
|
||||
@ -159,7 +159,6 @@ async def test_should_reject_record_for_identityless_proxy_auth(monkeypatch):
|
||||
)
|
||||
assert exc.value.status_code == 403
|
||||
assert "identity scope" in str(exc.value.detail)
|
||||
assert ownership._IN_MEMORY_CONTAINER_OWNERS == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -186,106 +185,6 @@ async def test_should_skip_owner_record_when_provider_response_has_no_id(monkeyp
|
||||
table.create.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_fallback_to_memory_when_persistent_owner_record_fails(
|
||||
monkeypatch,
|
||||
):
|
||||
table = AsyncMock()
|
||||
table.find_unique.side_effect = Exception("db unavailable")
|
||||
prisma_client = SimpleNamespace(
|
||||
db=SimpleNamespace(litellm_managedobjecttable=table)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ownership,
|
||||
"_get_prisma_client",
|
||||
AsyncMock(return_value=prisma_client),
|
||||
)
|
||||
auth = UserAPIKeyAuth(user_id="user-1")
|
||||
|
||||
await ownership.record_container_owner(
|
||||
response=_container("cntr_provider"),
|
||||
user_api_key_dict=auth,
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
|
||||
assert (
|
||||
ownership._IN_MEMORY_CONTAINER_OWNERS["container:openai:cntr_provider"]
|
||||
== "user-1"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_track_container_owner_in_memory_without_prisma(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
ownership,
|
||||
"_get_prisma_client",
|
||||
AsyncMock(return_value=None),
|
||||
)
|
||||
auth = UserAPIKeyAuth(user_id="user-1")
|
||||
|
||||
await ownership.record_container_owner(
|
||||
response=_container("cntr_provider"),
|
||||
user_api_key_dict=auth,
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
|
||||
original_id, provider = await ownership.assert_user_can_access_container(
|
||||
container_id="cntr_provider",
|
||||
user_api_key_dict=auth,
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
|
||||
assert original_id == "cntr_provider"
|
||||
assert provider == "openai"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_bound_in_memory_container_owner_tracking(monkeypatch):
|
||||
monkeypatch.setattr(ownership, "MAX_IN_MEMORY_CONTAINER_OWNERS", 2)
|
||||
monkeypatch.setattr(
|
||||
ownership,
|
||||
"_get_prisma_client",
|
||||
AsyncMock(return_value=None),
|
||||
)
|
||||
auth = UserAPIKeyAuth(user_id="user-1")
|
||||
|
||||
for container_id in ("cntr_1", "cntr_2", "cntr_3"):
|
||||
await ownership.record_container_owner(
|
||||
response=_container(container_id),
|
||||
user_api_key_dict=auth,
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
|
||||
assert list(ownership._IN_MEMORY_CONTAINER_OWNERS.keys()) == [
|
||||
"container:openai:cntr_2",
|
||||
"container:openai:cntr_3",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_deny_container_access_for_different_owner(monkeypatch):
|
||||
table = AsyncMock()
|
||||
table.find_first.return_value = SimpleNamespace(created_by="user-2")
|
||||
prisma_client = SimpleNamespace(
|
||||
db=SimpleNamespace(litellm_managedobjecttable=table)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ownership,
|
||||
"_get_prisma_client",
|
||||
AsyncMock(return_value=prisma_client),
|
||||
)
|
||||
auth = UserAPIKeyAuth(user_id="user-1")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await ownership.assert_user_can_access_container(
|
||||
container_id="cntr_provider",
|
||||
user_api_key_dict=auth,
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_deny_untracked_container_access_by_default(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
@ -305,103 +204,6 @@ async def test_should_deny_untracked_container_access_by_default(monkeypatch):
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_fallback_to_memory_when_owner_lookup_fails(monkeypatch):
|
||||
table = AsyncMock()
|
||||
table.find_first.side_effect = Exception("db unavailable")
|
||||
prisma_client = SimpleNamespace(
|
||||
db=SimpleNamespace(litellm_managedobjecttable=table)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ownership,
|
||||
"_get_prisma_client",
|
||||
AsyncMock(return_value=prisma_client),
|
||||
)
|
||||
ownership._IN_MEMORY_CONTAINER_OWNERS["container:openai:cntr_owned"] = "user-1"
|
||||
auth = UserAPIKeyAuth(user_id="user-1")
|
||||
|
||||
original_id, provider = await ownership.assert_user_can_access_container(
|
||||
container_id="cntr_owned",
|
||||
user_api_key_dict=auth,
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
|
||||
assert original_id == "cntr_owned"
|
||||
assert provider == "openai"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_use_memory_owner_when_db_recovers_without_row(monkeypatch):
|
||||
table = AsyncMock()
|
||||
table.find_first.return_value = None
|
||||
prisma_client = SimpleNamespace(
|
||||
db=SimpleNamespace(litellm_managedobjecttable=table)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ownership,
|
||||
"_get_prisma_client",
|
||||
AsyncMock(return_value=prisma_client),
|
||||
)
|
||||
ownership._IN_MEMORY_CONTAINER_OWNERS["container:openai:cntr_owned"] = "user-1"
|
||||
auth = UserAPIKeyAuth(user_id="user-1")
|
||||
|
||||
original_id, provider = await ownership.assert_user_can_access_container(
|
||||
container_id="cntr_owned",
|
||||
user_api_key_dict=auth,
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
|
||||
assert original_id == "cntr_owned"
|
||||
assert provider == "openai"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_fail_closed_when_owner_lookup_fails_without_memory(monkeypatch):
|
||||
table = AsyncMock()
|
||||
table.find_first.side_effect = Exception("db unavailable")
|
||||
prisma_client = SimpleNamespace(
|
||||
db=SimpleNamespace(litellm_managedobjecttable=table)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ownership,
|
||||
"_get_prisma_client",
|
||||
AsyncMock(return_value=prisma_client),
|
||||
)
|
||||
auth = UserAPIKeyAuth(user_id="user-1")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await ownership.assert_user_can_access_container(
|
||||
container_id="cntr_owned",
|
||||
user_api_key_dict=auth,
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_untracked_container_is_admin_only(monkeypatch):
|
||||
"""Pre-isolation containers with no ownership row are admin-only.
|
||||
Non-admin callers see them as 403, with no opt-out flag re-opening
|
||||
the cross-tenant access primitive."""
|
||||
from fastapi import HTTPException
|
||||
|
||||
monkeypatch.setattr(
|
||||
ownership,
|
||||
"_get_prisma_client",
|
||||
AsyncMock(return_value=None),
|
||||
)
|
||||
auth = UserAPIKeyAuth(user_id="user-1")
|
||||
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
await ownership.assert_user_can_access_container(
|
||||
container_id="cntr_untracked",
|
||||
user_api_key_dict=auth,
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
assert exc.value.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_not_reassign_existing_container_to_different_owner(monkeypatch):
|
||||
table = AsyncMock()
|
||||
@ -536,165 +338,6 @@ async def test_should_clear_dict_has_more_when_filtered_container_list_is_empty(
|
||||
assert filtered["has_more"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_filter_container_list_with_in_memory_ownership(monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
ownership,
|
||||
"_get_prisma_client",
|
||||
AsyncMock(return_value=None),
|
||||
)
|
||||
auth = UserAPIKeyAuth(user_id="user-1")
|
||||
|
||||
await ownership.record_container_owner(
|
||||
response=_container("cntr_owned"),
|
||||
user_api_key_dict=auth,
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
|
||||
response = ContainerListResponse(
|
||||
object="list",
|
||||
data=[_container("cntr_owned"), _container("cntr_other")],
|
||||
has_more=False,
|
||||
)
|
||||
|
||||
filtered = await ownership.filter_container_list_response(
|
||||
response=response,
|
||||
user_api_key_dict=auth,
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
|
||||
assert [item.id for item in filtered.data] == ["cntr_owned"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_filter_container_list_with_memory_when_db_lookup_fails(
|
||||
monkeypatch,
|
||||
):
|
||||
table = AsyncMock()
|
||||
table.find_many.side_effect = Exception("db unavailable")
|
||||
prisma_client = SimpleNamespace(
|
||||
db=SimpleNamespace(litellm_managedobjecttable=table)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ownership,
|
||||
"_get_prisma_client",
|
||||
AsyncMock(return_value=prisma_client),
|
||||
)
|
||||
ownership._IN_MEMORY_CONTAINER_OWNERS["container:openai:cntr_owned"] = "user-1"
|
||||
auth = UserAPIKeyAuth(user_id="user-1")
|
||||
response = ContainerListResponse(
|
||||
object="list",
|
||||
data=[_container("cntr_owned"), _container("cntr_other")],
|
||||
has_more=True,
|
||||
)
|
||||
|
||||
filtered = await ownership.filter_container_list_response(
|
||||
response=response,
|
||||
user_api_key_dict=auth,
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
|
||||
assert [item.id for item in filtered.data] == ["cntr_owned"]
|
||||
assert filtered.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_include_memory_container_list_when_db_recovers_without_row(
|
||||
monkeypatch,
|
||||
):
|
||||
table = AsyncMock()
|
||||
table.find_many.return_value = []
|
||||
prisma_client = SimpleNamespace(
|
||||
db=SimpleNamespace(litellm_managedobjecttable=table)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ownership,
|
||||
"_get_prisma_client",
|
||||
AsyncMock(return_value=prisma_client),
|
||||
)
|
||||
ownership._IN_MEMORY_CONTAINER_OWNERS["container:openai:cntr_owned"] = "user-1"
|
||||
auth = UserAPIKeyAuth(user_id="user-1")
|
||||
response = ContainerListResponse(
|
||||
object="list",
|
||||
data=[_container("cntr_owned"), _container("cntr_other")],
|
||||
has_more=True,
|
||||
)
|
||||
|
||||
filtered = await ownership.filter_container_list_response(
|
||||
response=response,
|
||||
user_api_key_dict=auth,
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
|
||||
assert [item.id for item in filtered.data] == ["cntr_owned"]
|
||||
assert filtered.has_more is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_validate_owner_and_forward_decoded_id_for_proxy_forwarding(
|
||||
monkeypatch,
|
||||
):
|
||||
from litellm.proxy.container_endpoints import handler_factory
|
||||
|
||||
proxy_server_stub = SimpleNamespace(
|
||||
general_settings={},
|
||||
llm_router=None,
|
||||
proxy_config=None,
|
||||
proxy_logging_obj=None,
|
||||
select_data_generator=None,
|
||||
user_api_base=None,
|
||||
user_max_tokens=None,
|
||||
user_model=None,
|
||||
user_request_timeout=None,
|
||||
user_temperature=None,
|
||||
version="test",
|
||||
)
|
||||
monkeypatch.setitem(sys.modules, "litellm.proxy.proxy_server", proxy_server_stub)
|
||||
|
||||
captured = {}
|
||||
|
||||
class FakeProcessor:
|
||||
def __init__(self, data):
|
||||
captured["data"] = data
|
||||
|
||||
async def base_process_llm_request(self, **kwargs):
|
||||
return captured["data"]
|
||||
|
||||
async def _handle_llm_api_exception(self, **kwargs):
|
||||
raise kwargs["e"]
|
||||
|
||||
monkeypatch.setattr(
|
||||
handler_factory,
|
||||
"ProxyBaseLLMRequestProcessing",
|
||||
FakeProcessor,
|
||||
)
|
||||
access_check = AsyncMock(return_value=("cntr_provider", "azure"))
|
||||
monkeypatch.setattr(
|
||||
handler_factory,
|
||||
"assert_user_can_access_container",
|
||||
access_check,
|
||||
)
|
||||
encoded_id = ResponsesAPIRequestUtils._build_container_id(
|
||||
custom_llm_provider="azure",
|
||||
model_id="router-gpt",
|
||||
container_id="cntr_provider",
|
||||
)
|
||||
|
||||
result = await handler_factory._process_request(
|
||||
request=SimpleNamespace(query_params={}, headers={}),
|
||||
fastapi_response=SimpleNamespace(),
|
||||
user_api_key_dict=UserAPIKeyAuth(user_id="user-1"),
|
||||
route_type="alist_container_files",
|
||||
path_params={"container_id": encoded_id},
|
||||
)
|
||||
|
||||
access_check.assert_awaited_once()
|
||||
assert access_check.await_args.kwargs["container_id"] == encoded_id
|
||||
assert result["container_id"] == "cntr_provider"
|
||||
assert result["custom_llm_provider"] == "azure"
|
||||
assert result["model_id"] == "router-gpt"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_should_validate_owner_and_forward_decoded_id_for_multipart_upload(
|
||||
monkeypatch,
|
||||
@ -1158,93 +801,3 @@ async def test_get_container_owner_caches_negative_lookups(monkeypatch):
|
||||
assert await ownership._get_container_owner("cntr_x", "openai") is None
|
||||
assert await ownership._get_container_owner("cntr_x", "openai") is None
|
||||
assert table.find_first.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_record_container_owner_invalidates_cache(monkeypatch):
|
||||
"""A recorded owner must drop the cached value so the next read re-fetches.
|
||||
|
||||
Otherwise a stale `None` from a prior negative lookup would survive the
|
||||
create and the new owner would be invisible until the TTL elapses.
|
||||
"""
|
||||
# Seed the cache with a stale negative result.
|
||||
ownership._write_container_owner_cache("container:openai:cntr_new", None)
|
||||
cached_hit, cached_value = ownership._read_container_owner_cache(
|
||||
"container:openai:cntr_new"
|
||||
)
|
||||
assert cached_hit and cached_value is None
|
||||
|
||||
table = AsyncMock()
|
||||
table.find_unique.return_value = None
|
||||
prisma_client = SimpleNamespace(
|
||||
db=SimpleNamespace(litellm_managedobjecttable=table)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ownership,
|
||||
"_get_prisma_client",
|
||||
AsyncMock(return_value=prisma_client),
|
||||
)
|
||||
|
||||
await ownership.record_container_owner(
|
||||
response=_container("cntr_new"),
|
||||
user_api_key_dict=UserAPIKeyAuth(user_id="user-1"),
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
|
||||
# Invalidation drops the entry — next read goes to the DB.
|
||||
cached_hit, _ = ownership._read_container_owner_cache("container:openai:cntr_new")
|
||||
assert not cached_hit
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_container_owner_does_not_cache_on_db_error(monkeypatch):
|
||||
"""DB errors must skip caching so transient failures don't pin a `None`."""
|
||||
table = AsyncMock()
|
||||
table.find_first.side_effect = Exception("db unavailable")
|
||||
prisma_client = SimpleNamespace(
|
||||
db=SimpleNamespace(litellm_managedobjecttable=table)
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
ownership,
|
||||
"_get_prisma_client",
|
||||
AsyncMock(return_value=prisma_client),
|
||||
)
|
||||
|
||||
result = await ownership._get_container_owner("cntr_x", "openai")
|
||||
assert result is None
|
||||
cached_hit, _ = ownership._read_container_owner_cache("container:openai:cntr_x")
|
||||
assert not cached_hit
|
||||
|
||||
|
||||
def test_container_owner_cache_expires_after_ttl(monkeypatch):
|
||||
"""Entries past the TTL count as misses so writes elsewhere are eventually
|
||||
visible to this process."""
|
||||
monkeypatch.setattr(ownership, "_CONTAINER_OWNER_CACHE_TTL", 0.0)
|
||||
ownership._write_container_owner_cache("k", "user-1")
|
||||
cached_hit, _ = ownership._read_container_owner_cache("k")
|
||||
# TTL of 0 means anything in the cache is already stale.
|
||||
assert not cached_hit
|
||||
|
||||
|
||||
def test_container_owner_cache_evicts_when_at_capacity(monkeypatch):
|
||||
"""The cache must not grow unbounded; reaching capacity LRU-evicts the
|
||||
oldest entry, not the entire cache."""
|
||||
monkeypatch.setattr(ownership, "_CONTAINER_OWNER_CACHE_MAX_SIZE", 2)
|
||||
ownership._write_container_owner_cache("a", "user-a")
|
||||
ownership._write_container_owner_cache("b", "user-b")
|
||||
ownership._write_container_owner_cache("c", "user-c")
|
||||
# ``a`` was the oldest and is dropped; ``b`` and ``c`` survive.
|
||||
assert list(ownership._CONTAINER_OWNER_CACHE.keys()) == ["b", "c"]
|
||||
|
||||
|
||||
def test_container_owner_cache_read_marks_as_recently_used(monkeypatch):
|
||||
"""Reading an entry should reset its position so a subsequent eviction
|
||||
drops a less-recently-used entry instead of the just-touched one."""
|
||||
monkeypatch.setattr(ownership, "_CONTAINER_OWNER_CACHE_MAX_SIZE", 2)
|
||||
ownership._write_container_owner_cache("a", "user-a")
|
||||
ownership._write_container_owner_cache("b", "user-b")
|
||||
# Touch ``a`` so it becomes the most-recently-used.
|
||||
ownership._read_container_owner_cache("a")
|
||||
ownership._write_container_owner_cache("c", "user-c")
|
||||
# ``b`` is the LRU at this point; ``a`` and ``c`` survive.
|
||||
assert list(ownership._CONTAINER_OWNER_CACHE.keys()) == ["a", "c"]
|
||||
|
||||
@ -19,9 +19,11 @@ from litellm.skills import main as skills_main
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_skill_cache():
|
||||
skills_handler._SKILL_CACHE.clear()
|
||||
skills_handler._SKILL_CACHE.cache_dict.clear()
|
||||
skills_handler._SKILL_CACHE.ttl_dict.clear()
|
||||
yield
|
||||
skills_handler._SKILL_CACHE.clear()
|
||||
skills_handler._SKILL_CACHE.cache_dict.clear()
|
||||
skills_handler._SKILL_CACHE.ttl_dict.clear()
|
||||
|
||||
|
||||
def _skill(skill_id: str, created_by: str | None) -> LiteLLM_SkillsTable:
|
||||
@ -440,99 +442,75 @@ async def test_should_scope_skill_injection_fetch_to_authenticated_user(monkeypa
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_skill_uses_cache_after_first_db_hit(monkeypatch):
|
||||
"""`fetch_skill_from_db` is hit per-chat-completion; the cache absorbs
|
||||
"""``fetch_skill_from_db`` runs per chat-completion; the cache absorbs
|
||||
repeats so we don't issue a Prisma query on every request."""
|
||||
fake_skill = Mock(created_by="user-1", skill_id="litellm_skill_a")
|
||||
store_factory = AsyncMock()
|
||||
store = Mock()
|
||||
store.find_skill = AsyncMock(return_value=fake_skill)
|
||||
table = AsyncMock()
|
||||
table.find_unique = AsyncMock(return_value=fake_skill)
|
||||
prisma_client = type(
|
||||
"Prisma", (), {"db": type("DB", (), {"litellm_skillstable": table})()}
|
||||
)()
|
||||
monkeypatch.setattr(
|
||||
skills_handler.LiteLLMSkillsHandler,
|
||||
"_get_prisma_client",
|
||||
AsyncMock(return_value=store_factory),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
skills_handler,
|
||||
"LiteLLMSkillsStore",
|
||||
Mock(return_value=store),
|
||||
AsyncMock(return_value=prisma_client),
|
||||
)
|
||||
|
||||
first = await skills_handler.LiteLLMSkillsHandler._load_skill("litellm_skill_a")
|
||||
second = await skills_handler.LiteLLMSkillsHandler._load_skill("litellm_skill_a")
|
||||
third = await skills_handler.LiteLLMSkillsHandler._load_skill("litellm_skill_a")
|
||||
|
||||
assert first is fake_skill
|
||||
assert second is fake_skill
|
||||
assert third is fake_skill
|
||||
assert store.find_skill.await_count == 1
|
||||
for _ in range(3):
|
||||
assert (
|
||||
await skills_handler.LiteLLMSkillsHandler._load_skill("litellm_skill_a")
|
||||
is fake_skill
|
||||
)
|
||||
assert table.find_unique.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_load_skill_caches_negative_lookups(monkeypatch):
|
||||
"""Missing skills must cache as `None` so repeated lookups skip the DB."""
|
||||
store_factory = AsyncMock()
|
||||
store = Mock()
|
||||
store.find_skill = AsyncMock(return_value=None)
|
||||
"""Missing skills cache as the negative sentinel so repeated misses skip
|
||||
the DB and the caller still sees ``None``."""
|
||||
table = AsyncMock()
|
||||
table.find_unique = AsyncMock(return_value=None)
|
||||
prisma_client = type(
|
||||
"Prisma", (), {"db": type("DB", (), {"litellm_skillstable": table})()}
|
||||
)()
|
||||
monkeypatch.setattr(
|
||||
skills_handler.LiteLLMSkillsHandler,
|
||||
"_get_prisma_client",
|
||||
AsyncMock(return_value=store_factory),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
skills_handler,
|
||||
"LiteLLMSkillsStore",
|
||||
Mock(return_value=store),
|
||||
AsyncMock(return_value=prisma_client),
|
||||
)
|
||||
|
||||
assert await skills_handler.LiteLLMSkillsHandler._load_skill("missing") is None
|
||||
assert await skills_handler.LiteLLMSkillsHandler._load_skill("missing") is None
|
||||
assert store.find_skill.await_count == 1
|
||||
assert table.find_unique.await_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_skill_invalidates_cache(monkeypatch):
|
||||
"""After delete, the next read must consult the DB rather than the cached
|
||||
pre-delete row."""
|
||||
"""After delete, the next read should not see the pre-delete cached row."""
|
||||
fake_skill = Mock(created_by="user-1", skill_id="litellm_skill_a")
|
||||
store = Mock()
|
||||
store.find_skill = AsyncMock(return_value=fake_skill)
|
||||
store.delete_skill = AsyncMock()
|
||||
table = AsyncMock()
|
||||
table.find_unique = AsyncMock(return_value=fake_skill)
|
||||
table.delete = AsyncMock()
|
||||
prisma_client = type(
|
||||
"Prisma", (), {"db": type("DB", (), {"litellm_skillstable": table})()}
|
||||
)()
|
||||
monkeypatch.setattr(
|
||||
skills_handler.LiteLLMSkillsHandler,
|
||||
"_get_prisma_client",
|
||||
AsyncMock(return_value=Mock()),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
skills_handler,
|
||||
"LiteLLMSkillsStore",
|
||||
Mock(return_value=store),
|
||||
AsyncMock(return_value=prisma_client),
|
||||
)
|
||||
|
||||
# Prime the cache via the read path.
|
||||
await skills_handler.LiteLLMSkillsHandler._load_skill("litellm_skill_a")
|
||||
cached_hit, _ = skills_handler._read_skill_cache("litellm_skill_a")
|
||||
assert cached_hit
|
||||
assert skills_handler._SKILL_CACHE.get_cache("litellm_skill_a") is fake_skill
|
||||
|
||||
auth = UserAPIKeyAuth(user_id="user-1")
|
||||
await skills_handler.LiteLLMSkillsHandler.delete_skill(
|
||||
"litellm_skill_a", user_api_key_dict=auth
|
||||
)
|
||||
|
||||
cached_hit_after, _ = skills_handler._read_skill_cache("litellm_skill_a")
|
||||
assert not cached_hit_after
|
||||
|
||||
|
||||
def test_skill_cache_expires_after_ttl(monkeypatch):
|
||||
monkeypatch.setattr(skills_handler, "_SKILL_CACHE_TTL", 0.0)
|
||||
skills_handler._write_skill_cache("k", Mock())
|
||||
cached_hit, _ = skills_handler._read_skill_cache("k")
|
||||
assert not cached_hit
|
||||
|
||||
|
||||
def test_skill_cache_evicts_when_at_capacity(monkeypatch):
|
||||
monkeypatch.setattr(skills_handler, "_SKILL_CACHE_MAX_SIZE", 2)
|
||||
skills_handler._write_skill_cache("a", Mock())
|
||||
skills_handler._write_skill_cache("b", Mock())
|
||||
skills_handler._write_skill_cache("c", Mock())
|
||||
# ``a`` was the oldest and is LRU-evicted; ``b`` and ``c`` survive.
|
||||
assert list(skills_handler._SKILL_CACHE.keys()) == ["b", "c"]
|
||||
# Post-delete, the cache holds the negative sentinel — not the stale row.
|
||||
assert (
|
||||
skills_handler._SKILL_CACHE.get_cache("litellm_skill_a")
|
||||
== skills_handler._NEGATIVE_SKILL_SENTINEL
|
||||
)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user