Model and Team filtering
This commit is contained in:
parent
a7e26460d0
commit
47810f1523
@ -7892,6 +7892,120 @@ def _paginate_models_response(
|
||||
}
|
||||
|
||||
|
||||
async def _filter_models_by_team_id(
|
||||
all_models: List[Dict[str, Any]],
|
||||
team_id: str,
|
||||
prisma_client: PrismaClient,
|
||||
llm_router: Router,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Filter models by team ID. Returns models where:
|
||||
- direct_access is True, OR
|
||||
- team_id is in access_via_team_ids
|
||||
|
||||
Also searches config and database for models accessible to the team.
|
||||
|
||||
Args:
|
||||
all_models: List of models to filter
|
||||
team_id: Team ID to filter by
|
||||
prisma_client: Prisma client for database queries
|
||||
llm_router: Router instance for config queries
|
||||
|
||||
Returns:
|
||||
Filtered list of models
|
||||
"""
|
||||
# Get team from database
|
||||
try:
|
||||
team_db_object = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id}
|
||||
)
|
||||
if team_db_object is None:
|
||||
verbose_proxy_logger.warning(f"Team {team_id} not found in database")
|
||||
# If team doesn't exist, return empty list
|
||||
return []
|
||||
|
||||
team_object = LiteLLM_TeamTable(**team_db_object.model_dump())
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(f"Error fetching team {team_id}: {str(e)}")
|
||||
return []
|
||||
|
||||
# Get models accessible to this team (similar to _add_team_models_to_all_models)
|
||||
team_accessible_model_ids: Set[str] = set()
|
||||
|
||||
if (
|
||||
len(team_object.models) == 0 # empty list = all model access
|
||||
or SpecialModelNames.all_proxy_models.value in team_object.models
|
||||
):
|
||||
# Team has access to all models
|
||||
model_list = llm_router.get_model_list() if llm_router else []
|
||||
if model_list is not None:
|
||||
for model in model_list:
|
||||
model_id = model.get("model_info", {}).get("id", None)
|
||||
if model_id is None:
|
||||
continue
|
||||
# if team model id set, check if team id matches
|
||||
team_model_id = model.get("model_info", {}).get("team_id", None)
|
||||
can_add_model = False
|
||||
if team_model_id is None:
|
||||
can_add_model = True
|
||||
elif team_model_id == team_id:
|
||||
can_add_model = True
|
||||
|
||||
if can_add_model:
|
||||
team_accessible_model_ids.add(model_id)
|
||||
else:
|
||||
# Team has access to specific models
|
||||
for model_name in team_object.models:
|
||||
_models = llm_router.get_model_list(
|
||||
model_name=model_name, team_id=team_id
|
||||
) if llm_router else []
|
||||
if _models is not None:
|
||||
for model in _models:
|
||||
model_id = model.get("model_info", {}).get("id", None)
|
||||
if model_id is not None:
|
||||
team_accessible_model_ids.add(model_id)
|
||||
|
||||
# Also search database for models accessible to this team
|
||||
# This complements the config search done above
|
||||
try:
|
||||
if team_object.models and SpecialModelNames.all_proxy_models.value not in team_object.models:
|
||||
# Team has specific models - check database for those model names
|
||||
db_models = await prisma_client.db.litellm_proxymodeltable.find_many(
|
||||
where={"model_name": {"in": team_object.models}}
|
||||
)
|
||||
for db_model in db_models:
|
||||
model_id = db_model.model_id
|
||||
if model_id:
|
||||
team_accessible_model_ids.add(model_id)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.debug(f"Error querying database models for team {team_id}: {str(e)}")
|
||||
|
||||
# Filter models based on direct_access or access_via_team_ids
|
||||
# Models are already enriched with these fields before this function is called
|
||||
filtered_models = []
|
||||
for _model in all_models:
|
||||
model_info = _model.get("model_info", {})
|
||||
model_id = model_info.get("id", None)
|
||||
|
||||
# Include if direct_access is True
|
||||
if model_info.get("direct_access", False):
|
||||
filtered_models.append(_model)
|
||||
continue
|
||||
|
||||
# Include if team_id is in access_via_team_ids
|
||||
access_via_team_ids = model_info.get("access_via_team_ids", [])
|
||||
if isinstance(access_via_team_ids, list) and team_id in access_via_team_ids:
|
||||
filtered_models.append(_model)
|
||||
continue
|
||||
|
||||
# Also include if model_id is in team_accessible_model_ids (from config/db search)
|
||||
# This catches models that might not have been enriched with access_via_team_ids yet
|
||||
if model_id and model_id in team_accessible_model_ids:
|
||||
filtered_models.append(_model)
|
||||
|
||||
return filtered_models
|
||||
|
||||
|
||||
@router.get(
|
||||
"/v2/model/info",
|
||||
description="v2 - returns models available to the user based on their API key permissions. Shows model info from config.yaml (except api key and api base). Filter to just user-added models with ?user_models_only=true",
|
||||
@ -7916,6 +8030,12 @@ async def model_info_v2(
|
||||
search: Optional[str] = fastapi.Query(
|
||||
None, description="Search model names (case-insensitive partial match)"
|
||||
),
|
||||
modelId: Optional[str] = fastapi.Query(
|
||||
None, description="Search for a specific model by its unique ID"
|
||||
),
|
||||
teamId: Optional[str] = fastapi.Query(
|
||||
None, description="Filter models by team ID. Returns models with direct_access=True or teamId in access_via_team_ids"
|
||||
),
|
||||
):
|
||||
"""
|
||||
BETA ENDPOINT. Might change unexpectedly. Use `/v1/model/info` for now.
|
||||
@ -7940,24 +8060,65 @@ async def model_info_v2(
|
||||
|
||||
# Load existing config
|
||||
await proxy_config.get_config()
|
||||
all_models = copy.deepcopy(llm_router.model_list)
|
||||
|
||||
if user_model is not None:
|
||||
# if user does not use a config.yaml, https://github.com/BerriAI/litellm/issues/2061
|
||||
all_models += [user_model]
|
||||
# If modelId is provided, search for the specific model
|
||||
if modelId is not None:
|
||||
found_model = None
|
||||
|
||||
# First, search in config
|
||||
if llm_router is not None:
|
||||
found_model = llm_router.get_model_info(id=modelId)
|
||||
if found_model:
|
||||
found_model = copy.deepcopy(found_model)
|
||||
|
||||
# If not found in config, search in database
|
||||
if found_model is None:
|
||||
try:
|
||||
db_model = await prisma_client.db.litellm_proxymodeltable.find_unique(
|
||||
where={"model_id": modelId}
|
||||
)
|
||||
if db_model:
|
||||
# Convert database model to router format
|
||||
decrypted_models = proxy_config.decrypt_model_list_from_db([db_model])
|
||||
if decrypted_models:
|
||||
found_model = decrypted_models[0]
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
f"Error querying database for modelId {modelId}: {str(e)}"
|
||||
)
|
||||
|
||||
# If model found, verify search filter if provided
|
||||
if found_model is not None:
|
||||
if search is not None and search.strip():
|
||||
search_lower = search.lower().strip()
|
||||
model_name = found_model.get("model_name", "")
|
||||
if search_lower not in model_name.lower():
|
||||
# Model found but doesn't match search filter
|
||||
found_model = None
|
||||
|
||||
# Set all_models to the found model or empty list
|
||||
all_models = [found_model] if found_model is not None else []
|
||||
search_total_count: Optional[int] = len(all_models)
|
||||
else:
|
||||
# Normal flow when modelId is not provided
|
||||
all_models = copy.deepcopy(llm_router.model_list)
|
||||
|
||||
if model is not None:
|
||||
all_models = [m for m in all_models if m["model_name"] == model]
|
||||
if user_model is not None:
|
||||
# if user does not use a config.yaml, https://github.com/BerriAI/litellm/issues/2061
|
||||
all_models += [user_model]
|
||||
|
||||
# Apply search filter if provided
|
||||
all_models, search_total_count = await _apply_search_filter_to_models(
|
||||
all_models=all_models,
|
||||
search=search or "",
|
||||
page=page,
|
||||
size=size,
|
||||
prisma_client=prisma_client,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
if model is not None:
|
||||
all_models = [m for m in all_models if m["model_name"] == model]
|
||||
|
||||
# Apply search filter if provided
|
||||
all_models, search_total_count = await _apply_search_filter_to_models(
|
||||
all_models=all_models,
|
||||
search=search or "",
|
||||
page=page,
|
||||
size=size,
|
||||
prisma_client=prisma_client,
|
||||
proxy_config=proxy_config,
|
||||
)
|
||||
|
||||
if user_models_only:
|
||||
all_models = await non_admin_all_models(
|
||||
@ -7976,10 +8137,27 @@ async def model_info_v2(
|
||||
)
|
||||
|
||||
# Fill in model info based on config.yaml and litellm model_prices_and_context_window.json
|
||||
# This must happen before teamId filtering so that direct_access and access_via_team_ids are populated
|
||||
for i, _model in enumerate(all_models):
|
||||
all_models[i] = _enrich_model_info_with_litellm_data(
|
||||
model=_model, debug=debug if debug is not None else False, llm_router=llm_router
|
||||
)
|
||||
|
||||
# Apply teamId filter if provided
|
||||
if teamId is not None and teamId.strip():
|
||||
all_models = await _filter_models_by_team_id(
|
||||
all_models=all_models,
|
||||
team_id=teamId.strip(),
|
||||
prisma_client=prisma_client,
|
||||
llm_router=llm_router,
|
||||
)
|
||||
# Update search_total_count after teamId filter is applied
|
||||
search_total_count = len(all_models)
|
||||
|
||||
# If modelId was provided, update search_total_count after filters are applied
|
||||
# to ensure pagination reflects the final filtered result (0 or 1)
|
||||
if modelId is not None:
|
||||
search_total_count = len(all_models)
|
||||
|
||||
verbose_proxy_logger.debug("all_models: %s", all_models)
|
||||
|
||||
|
||||
@ -3720,6 +3720,299 @@ async def test_model_info_v2_search_db_models(monkeypatch):
|
||||
app.dependency_overrides = original_overrides
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_info_v2_filter_by_model_id(monkeypatch):
|
||||
"""
|
||||
Test modelId parameter for filtering by specific model ID.
|
||||
Tests that modelId searches in router config first, then database.
|
||||
"""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.proxy_server import app, proxy_config, user_api_key_auth
|
||||
|
||||
# Create mock config models
|
||||
mock_config_models = [
|
||||
{
|
||||
"model_name": "gpt-4-turbo",
|
||||
"litellm_params": {"model": "gpt-4-turbo"},
|
||||
"model_info": {"id": "config-model-1"},
|
||||
},
|
||||
{
|
||||
"model_name": "claude-3-opus",
|
||||
"litellm_params": {"model": "claude-3-opus"},
|
||||
"model_info": {"id": "config-model-2"},
|
||||
},
|
||||
]
|
||||
|
||||
# Mock llm_router with get_model_info method
|
||||
mock_router = MagicMock()
|
||||
mock_router.model_list = mock_config_models
|
||||
mock_router.get_model_info = MagicMock(
|
||||
side_effect=lambda id: next(
|
||||
(m for m in mock_config_models if m["model_info"]["id"] == id), None
|
||||
)
|
||||
)
|
||||
|
||||
# Mock prisma_client for database queries
|
||||
mock_prisma_client = MagicMock()
|
||||
mock_db_table = MagicMock()
|
||||
mock_prisma_client.db.litellm_proxymodeltable = mock_db_table
|
||||
|
||||
# Mock database model
|
||||
mock_db_model = MagicMock()
|
||||
mock_db_model.model_id = "db-model-1"
|
||||
mock_db_model.model_name = "db-gpt-3.5"
|
||||
mock_db_model.litellm_params = '{"model": "gpt-3.5-turbo"}'
|
||||
mock_db_model.model_info = '{"id": "db-model-1", "db_model": true}'
|
||||
|
||||
# Mock find_unique to return db model when searching for db-model-1
|
||||
async def mock_find_unique(where):
|
||||
if where.get("model_id") == "db-model-1":
|
||||
return mock_db_model
|
||||
return None
|
||||
|
||||
mock_db_table.find_unique = AsyncMock(side_effect=mock_find_unique)
|
||||
|
||||
# Mock proxy_config.decrypt_model_list_from_db
|
||||
def mock_decrypt_models(db_models_list):
|
||||
if db_models_list:
|
||||
return [
|
||||
{
|
||||
"model_name": db_models_list[0].model_name,
|
||||
"litellm_params": {"model": "gpt-3.5-turbo"},
|
||||
"model_info": {"id": db_models_list[0].model_id, "db_model": True},
|
||||
}
|
||||
]
|
||||
return []
|
||||
|
||||
# Mock proxy_config.get_config
|
||||
mock_get_config = AsyncMock(return_value={})
|
||||
|
||||
# Mock user authentication
|
||||
mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
|
||||
mock_user_api_key_dict.user_id = "test-user"
|
||||
mock_user_api_key_dict.api_key = "test-key"
|
||||
mock_user_api_key_dict.team_models = []
|
||||
mock_user_api_key_dict.models = []
|
||||
|
||||
# Apply monkeypatches
|
||||
monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", mock_router)
|
||||
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
|
||||
monkeypatch.setattr("litellm.proxy.proxy_server.user_model", None)
|
||||
monkeypatch.setattr(proxy_config, "get_config", mock_get_config)
|
||||
monkeypatch.setattr(proxy_config, "decrypt_model_list_from_db", mock_decrypt_models)
|
||||
|
||||
# Override auth dependency
|
||||
original_overrides = app.dependency_overrides.copy()
|
||||
app.dependency_overrides[user_api_key_auth] = lambda: mock_user_api_key_dict
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
# Test Case 1: Filter by modelId that exists in config
|
||||
response = client.get("/v2/model/info", params={"modelId": "config-model-1"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_count"] == 1
|
||||
assert len(data["data"]) == 1
|
||||
assert data["data"][0]["model_info"]["id"] == "config-model-1"
|
||||
assert data["data"][0]["model_name"] == "gpt-4-turbo"
|
||||
# Verify router.get_model_info was called
|
||||
mock_router.get_model_info.assert_called_with(id="config-model-1")
|
||||
|
||||
# Test Case 2: Filter by modelId that exists in database (not in config)
|
||||
response = client.get("/v2/model/info", params={"modelId": "db-model-1"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_count"] == 1
|
||||
assert len(data["data"]) == 1
|
||||
assert data["data"][0]["model_info"]["id"] == "db-model-1"
|
||||
assert data["data"][0]["model_name"] == "db-gpt-3.5"
|
||||
# Verify database was queried
|
||||
mock_db_table.find_unique.assert_called()
|
||||
|
||||
# Test Case 3: Filter by modelId that doesn't exist
|
||||
mock_db_table.find_unique = AsyncMock(return_value=None)
|
||||
response = client.get("/v2/model/info", params={"modelId": "non-existent-model"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_count"] == 0
|
||||
assert len(data["data"]) == 0
|
||||
|
||||
# Test Case 4: Filter by modelId with search parameter (should filter further)
|
||||
response = client.get(
|
||||
"/v2/model/info", params={"modelId": "config-model-1", "search": "claude"}
|
||||
)
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# config-model-1 is gpt-4-turbo, doesn't match "claude", so should return empty
|
||||
assert data["total_count"] == 0
|
||||
assert len(data["data"]) == 0
|
||||
|
||||
finally:
|
||||
app.dependency_overrides = original_overrides
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_info_v2_filter_by_team_id(monkeypatch):
|
||||
"""
|
||||
Test teamId parameter for filtering models by team ID.
|
||||
Tests that teamId filters models based on direct_access or access_via_team_ids.
|
||||
"""
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
from litellm.proxy._types import UserAPIKeyAuth, LiteLLM_TeamTable
|
||||
from litellm.proxy.proxy_server import app, proxy_config, user_api_key_auth
|
||||
|
||||
# Create mock models with different access configurations
|
||||
mock_models = [
|
||||
{
|
||||
"model_name": "model-direct-access",
|
||||
"litellm_params": {"model": "gpt-4"},
|
||||
"model_info": {
|
||||
"id": "model-1",
|
||||
"direct_access": True, # Should be included
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "model-team-access",
|
||||
"litellm_params": {"model": "claude-3"},
|
||||
"model_info": {
|
||||
"id": "model-2",
|
||||
"direct_access": False,
|
||||
"access_via_team_ids": ["team-123"], # Should be included
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "model-no-access",
|
||||
"litellm_params": {"model": "gemini-pro"},
|
||||
"model_info": {
|
||||
"id": "model-3",
|
||||
"direct_access": False,
|
||||
"access_via_team_ids": ["team-456"], # Should NOT be included
|
||||
},
|
||||
},
|
||||
{
|
||||
"model_name": "model-multiple-teams",
|
||||
"litellm_params": {"model": "gpt-3.5"},
|
||||
"model_info": {
|
||||
"id": "model-4",
|
||||
"direct_access": False,
|
||||
"access_via_team_ids": ["team-789", "team-123"], # Should be included
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# Mock llm_router
|
||||
mock_router = MagicMock()
|
||||
mock_router.model_list = mock_models
|
||||
|
||||
# Mock get_model_list to return models based on model_name filter
|
||||
def mock_get_model_list(model_name=None, team_id=None):
|
||||
if model_name:
|
||||
return [m for m in mock_models if m["model_name"] == model_name]
|
||||
return mock_models
|
||||
|
||||
mock_router.get_model_list = MagicMock(side_effect=mock_get_model_list)
|
||||
|
||||
# Mock team database object - team has access to specific models
|
||||
mock_team_db_object = MagicMock()
|
||||
mock_team_db_object.model_dump.return_value = {
|
||||
"team_id": "team-123",
|
||||
"models": ["model-direct-access", "model-team-access", "model-multiple-teams"], # Specific models
|
||||
}
|
||||
|
||||
# Mock prisma_client
|
||||
mock_prisma_client = MagicMock()
|
||||
mock_team_table = MagicMock()
|
||||
mock_prisma_client.db.litellm_teamtable = mock_team_table
|
||||
mock_team_table.find_unique = AsyncMock(return_value=mock_team_db_object)
|
||||
|
||||
# Mock LiteLLM_TeamTable - team has access to specific models
|
||||
mock_team_object = LiteLLM_TeamTable(
|
||||
team_id="team-123",
|
||||
models=["model-direct-access", "model-team-access", "model-multiple-teams"],
|
||||
)
|
||||
|
||||
# Mock proxy_config.get_config
|
||||
mock_get_config = AsyncMock(return_value={})
|
||||
|
||||
# Mock user authentication
|
||||
mock_user_api_key_dict = MagicMock(spec=UserAPIKeyAuth)
|
||||
mock_user_api_key_dict.user_id = "test-user"
|
||||
mock_user_api_key_dict.api_key = "test-key"
|
||||
mock_user_api_key_dict.team_models = []
|
||||
mock_user_api_key_dict.models = []
|
||||
|
||||
# Apply monkeypatches
|
||||
monkeypatch.setattr("litellm.proxy.proxy_server.llm_router", mock_router)
|
||||
monkeypatch.setattr("litellm.proxy.proxy_server.prisma_client", mock_prisma_client)
|
||||
monkeypatch.setattr("litellm.proxy.proxy_server.user_model", None)
|
||||
monkeypatch.setattr(proxy_config, "get_config", mock_get_config)
|
||||
# Mock LiteLLM_TeamTable instantiation
|
||||
monkeypatch.setattr(
|
||||
"litellm.proxy.proxy_server.LiteLLM_TeamTable",
|
||||
lambda **kwargs: mock_team_object,
|
||||
)
|
||||
|
||||
# Override auth dependency
|
||||
original_overrides = app.dependency_overrides.copy()
|
||||
app.dependency_overrides[user_api_key_auth] = lambda: mock_user_api_key_dict
|
||||
|
||||
client = TestClient(app)
|
||||
try:
|
||||
# Test Case 1: Filter by teamId - should return models with direct_access=True or team-123 in access_via_team_ids
|
||||
response = client.get("/v2/model/info", params={"teamId": "team-123"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should include: model-1 (direct_access), model-2 (team-123 in access_via_team_ids), model-4 (team-123 in access_via_team_ids)
|
||||
# Should NOT include: model-3 (team-456 only)
|
||||
assert data["total_count"] == 3
|
||||
assert len(data["data"]) == 3
|
||||
model_ids = [m["model_info"]["id"] for m in data["data"]]
|
||||
assert "model-1" in model_ids # direct_access
|
||||
assert "model-2" in model_ids # team-123 in access_via_team_ids
|
||||
assert "model-4" in model_ids # team-123 in access_via_team_ids
|
||||
assert "model-3" not in model_ids # Should be excluded
|
||||
|
||||
# Test Case 2: Filter by teamId that doesn't exist - should return empty list
|
||||
mock_team_table.find_unique = AsyncMock(return_value=None)
|
||||
response = client.get("/v2/model/info", params={"teamId": "non-existent-team"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["total_count"] == 0
|
||||
assert len(data["data"]) == 0
|
||||
|
||||
# Test Case 3: Filter by different teamId - should only return models with that team in access_via_team_ids
|
||||
mock_team_db_object_456 = MagicMock()
|
||||
mock_team_db_object_456.model_dump.return_value = {
|
||||
"team_id": "team-456",
|
||||
"models": ["model-no-access"], # Team has access to model-no-access
|
||||
}
|
||||
mock_team_table.find_unique = AsyncMock(return_value=mock_team_db_object_456)
|
||||
mock_team_object_456 = LiteLLM_TeamTable(
|
||||
team_id="team-456",
|
||||
models=["model-no-access"],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"litellm.proxy.proxy_server.LiteLLM_TeamTable",
|
||||
lambda **kwargs: mock_team_object_456,
|
||||
)
|
||||
|
||||
response = client.get("/v2/model/info", params={"teamId": "team-456"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
# Should include: model-1 (direct_access), model-3 (team-456 in access_via_team_ids)
|
||||
# Should NOT include: model-2 (team-123 only), model-4 (team-789 and team-123, but not team-456)
|
||||
assert data["total_count"] >= 2
|
||||
model_ids = [m["model_info"]["id"] for m in data["data"]]
|
||||
assert "model-1" in model_ids # direct_access
|
||||
assert "model-3" in model_ids # team-456 in access_via_team_ids
|
||||
|
||||
finally:
|
||||
app.dependency_overrides = original_overrides
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_apply_search_filter_to_models(monkeypatch):
|
||||
"""
|
||||
|
||||
@ -8,6 +8,7 @@
|
||||
"start": "next start",
|
||||
"lint": "next lint",
|
||||
"test": "vitest",
|
||||
"test:dot": "vitest --reporter=dot",
|
||||
"test:watch": "vitest -w",
|
||||
"test:coverage": "vitest run --coverage",
|
||||
"format": "prettier --write .",
|
||||
|
||||
@ -102,6 +102,8 @@ describe("useModelsInfo", () => {
|
||||
"Admin",
|
||||
1,
|
||||
50,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined
|
||||
);
|
||||
expect(modelInfoCall).toHaveBeenCalledTimes(1);
|
||||
@ -122,6 +124,8 @@ describe("useModelsInfo", () => {
|
||||
"Admin",
|
||||
2,
|
||||
25,
|
||||
undefined,
|
||||
undefined,
|
||||
undefined
|
||||
);
|
||||
});
|
||||
|
||||
@ -27,7 +27,7 @@ const modelHubKeys = createQueryKeys("modelHub");
|
||||
const allProxyModelsKeys = createQueryKeys("allProxyModels");
|
||||
const selectedTeamModelsKeys = createQueryKeys("selectedTeamModels");
|
||||
|
||||
export const useModelsInfo = (page: number = 1, size: number = 50, search?: string) => {
|
||||
export const useModelsInfo = (page: number = 1, size: number = 50, search?: string, modelId?: string, teamId?: string) => {
|
||||
const { accessToken, userId, userRole } = useAuthorized();
|
||||
return useQuery<PaginatedModelInfoResponse>({
|
||||
queryKey: modelKeys.list({
|
||||
@ -37,9 +37,11 @@ export const useModelsInfo = (page: number = 1, size: number = 50, search?: stri
|
||||
page,
|
||||
size,
|
||||
...(search && { search }),
|
||||
...(modelId && { modelId }),
|
||||
...(teamId && { teamId }),
|
||||
},
|
||||
}),
|
||||
queryFn: async () => await modelInfoCall(accessToken!, userId!, userRole!, page, size, search),
|
||||
queryFn: async () => await modelInfoCall(accessToken!, userId!, userRole!, page, size, search, modelId, teamId),
|
||||
enabled: Boolean(accessToken && userId && userRole),
|
||||
});
|
||||
};
|
||||
|
||||
@ -318,7 +318,6 @@ const ModelsAndEndpointsView: React.FC<ModelDashboardProps> = ({ premiumUser, te
|
||||
onClose={() => {
|
||||
setSelectedModelId(null);
|
||||
}}
|
||||
modelData={processedModelData.data.find((model: any) => model.model_info.id === selectedModelId)}
|
||||
accessToken={accessToken}
|
||||
userID={userID}
|
||||
userRole={userRole}
|
||||
|
||||
@ -8,7 +8,7 @@ import { getDisplayModelName } from "@/components/view_model/model_name_display"
|
||||
import { InfoCircleOutlined } from "@ant-design/icons";
|
||||
import { PaginationState } from "@tanstack/react-table";
|
||||
import { Grid, Select, SelectItem, TabPanel, Text } from "@tremor/react";
|
||||
import { Skeleton } from "antd";
|
||||
import { Skeleton, Spin } from "antd";
|
||||
import debounce from "lodash/debounce";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { useModelsInfo } from "../../hooks/models/useModels";
|
||||
@ -34,7 +34,7 @@ const AllModelsTab = ({
|
||||
}: AllModelsTabProps) => {
|
||||
const { data: modelCostMapData, isLoading: isLoadingModelCostMap } = useModelCostMap();
|
||||
const { userId, userRole, premiumUser } = useAuthorized();
|
||||
const { data: teams } = useTeams();
|
||||
const { data: teams, isLoading: isLoadingTeams } = useTeams();
|
||||
|
||||
const [modelNameSearch, setModelNameSearch] = useState<string>("");
|
||||
const [debouncedSearch, setDebouncedSearch] = useState<string>("");
|
||||
@ -69,7 +69,16 @@ const AllModelsTab = ({
|
||||
};
|
||||
}, [modelNameSearch, debouncedUpdateSearch]);
|
||||
|
||||
const { data: rawModelData, isLoading: isLoadingModelsInfo } = useModelsInfo(currentPage, pageSize, debouncedSearch || undefined);
|
||||
// Determine teamId to pass to the query - only pass if not "personal"
|
||||
const teamIdForQuery = currentTeam === "personal" ? undefined : currentTeam.team_id;
|
||||
|
||||
const { data: rawModelData, isLoading: isLoadingModelsInfo } = useModelsInfo(
|
||||
currentPage,
|
||||
pageSize,
|
||||
debouncedSearch || undefined,
|
||||
undefined,
|
||||
teamIdForQuery
|
||||
);
|
||||
const isLoading = isLoadingModelsInfo || isLoadingModelCostMap;
|
||||
|
||||
const getProviderFromModel = (model: string) => {
|
||||
@ -122,30 +131,21 @@ const AllModelsTab = ({
|
||||
model.model_info["access_groups"]?.includes(selectedModelAccessGroupFilter) ||
|
||||
!selectedModelAccessGroupFilter;
|
||||
|
||||
let teamAccessMatch = true;
|
||||
if (modelViewMode === "current_team") {
|
||||
if (currentTeam === "personal") {
|
||||
teamAccessMatch = model.model_info?.direct_access === true;
|
||||
} else {
|
||||
// Check if model is directly associated with the team via team_ids
|
||||
const directTeamAccess = model.model_info?.access_via_team_ids?.includes(currentTeam.team_id) === true;
|
||||
|
||||
// Check if any of the team's models match the model's access groups
|
||||
const accessGroupMatch =
|
||||
currentTeam.models?.some((teamModel: string) => model.model_info?.access_groups?.includes(teamModel)) ===
|
||||
true;
|
||||
|
||||
teamAccessMatch = directTeamAccess || accessGroupMatch;
|
||||
}
|
||||
}
|
||||
|
||||
return modelNameMatch && accessGroupMatch && teamAccessMatch;
|
||||
// Team filtering is now handled server-side via teamId query parameter
|
||||
// Only apply client-side filtering for model groups and access groups
|
||||
return modelNameMatch && accessGroupMatch;
|
||||
});
|
||||
}, [modelData, selectedModelGroup, selectedModelAccessGroupFilter, currentTeam, modelViewMode]);
|
||||
}, [modelData, selectedModelGroup, selectedModelAccessGroupFilter]);
|
||||
|
||||
useEffect(() => {
|
||||
setPagination((prev: PaginationState) => ({ ...prev, pageIndex: 0 }));
|
||||
}, [selectedModelGroup, selectedModelAccessGroupFilter, currentTeam, modelViewMode]);
|
||||
}, [selectedModelGroup, selectedModelAccessGroupFilter]);
|
||||
|
||||
// Reset pagination when team changes
|
||||
useEffect(() => {
|
||||
setCurrentPage(1);
|
||||
setPagination((prev: PaginationState) => ({ ...prev, pageIndex: 0 }));
|
||||
}, [teamIdForQuery]);
|
||||
|
||||
const resetFilters = () => {
|
||||
setModelNameSearch("");
|
||||
@ -177,9 +177,17 @@ const AllModelsTab = ({
|
||||
onValueChange={(value) => {
|
||||
if (value === "personal") {
|
||||
setCurrentTeam("personal");
|
||||
// Reset to page 1 when team changes
|
||||
setCurrentPage(1);
|
||||
setPagination((prev: PaginationState) => ({ ...prev, pageIndex: 0 }));
|
||||
} else {
|
||||
const team = teams?.find((t) => t.team_id === value);
|
||||
if (team) setCurrentTeam(team);
|
||||
if (team) {
|
||||
setCurrentTeam(team);
|
||||
// Reset to page 1 when team changes
|
||||
setCurrentPage(1);
|
||||
setPagination((prev: PaginationState) => ({ ...prev, pageIndex: 0 }));
|
||||
}
|
||||
}
|
||||
}}
|
||||
>
|
||||
@ -189,20 +197,29 @@ const AllModelsTab = ({
|
||||
<span className="font-medium">Personal</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
{teams
|
||||
?.filter((team) => team.team_id)
|
||||
.map((team) => (
|
||||
<SelectItem key={team.team_id} value={team.team_id}>
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="w-2 h-2 bg-green-500 rounded-full"></div>
|
||||
<span className="font-medium">
|
||||
{team.team_alias
|
||||
? `${team.team_alias.slice(0, 30)}...`
|
||||
: `Team ${team.team_id.slice(0, 30)}...`}
|
||||
</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))}
|
||||
{isLoadingTeams ? (
|
||||
<SelectItem value="loading">
|
||||
<div className="flex items-center gap-2">
|
||||
<Spin size="small" />
|
||||
<span className="font-medium text-gray-500">Loading teams...</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
) : (
|
||||
teams
|
||||
?.filter((team) => team.team_id)
|
||||
.map((team) => (
|
||||
<SelectItem key={team.team_id} value={team.team_id}>
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="w-2 h-2 bg-green-500 rounded-full"></div>
|
||||
<span className="font-medium">
|
||||
{team.team_alias
|
||||
? `${team.team_alias.slice(0, 30)}...`
|
||||
: `Team ${team.team_id.slice(0, 30)}...`}
|
||||
</span>
|
||||
</div>
|
||||
</SelectItem>
|
||||
))
|
||||
)}
|
||||
</Select>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@ -1,112 +1,38 @@
|
||||
import { QueryClient, QueryClientProvider } from "@tanstack/react-query";
|
||||
import { render, waitFor } from "@testing-library/react";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import React, { ReactNode } from "react";
|
||||
import { beforeEach, describe, expect, it, vi } from "vitest";
|
||||
import ModelInfoView from "./model_info_view";
|
||||
|
||||
vi.mock("../../utils/dataUtils", () => ({
|
||||
copyToClipboard: vi.fn(),
|
||||
copyToClipboard: vi.fn().mockResolvedValue(true),
|
||||
}));
|
||||
|
||||
vi.mock("./networking", () => ({
|
||||
modelInfoV1Call: vi.fn().mockResolvedValue({
|
||||
data: [
|
||||
{
|
||||
model_name: "aws/anthropic/bedrock-claude-3-5-sonnet-v1",
|
||||
model_name: "GPT-4",
|
||||
litellm_params: {
|
||||
aws_region_name: "us-east-1",
|
||||
custom_llm_provider: "bedrock",
|
||||
use_in_pass_through: false,
|
||||
use_litellm_proxy: false,
|
||||
merge_reasoning_content_in_choices: false,
|
||||
model: "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
model: "gpt-4",
|
||||
api_base: "https://api.openai.com/v1",
|
||||
custom_llm_provider: "openai",
|
||||
},
|
||||
model_info: {
|
||||
id: "70b94bbd2af4a75215f7e3e465b5b199529dc15deb5d395d0668a4aabc496c84",
|
||||
db_model: false,
|
||||
access_via_team_ids: [
|
||||
"4fe3cfea-c907-412a-a645-60915b618d11",
|
||||
"9a4b2d15-4198-47e4-971b-7329b77f40e4",
|
||||
"14d55eef-b8d4-4cb8-b080-d973269dae54",
|
||||
"693ce1d2-9fae-4605-a5c9-1c9829415e1a",
|
||||
"fe29d910-4968-45bc-9fe0-6716e89c6270",
|
||||
],
|
||||
direct_access: true,
|
||||
key: "anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
max_tokens: 4096,
|
||||
max_input_tokens: 200000,
|
||||
max_output_tokens: 4096,
|
||||
input_cost_per_token: 0.000003,
|
||||
input_cost_per_token_flex: null,
|
||||
input_cost_per_token_priority: null,
|
||||
cache_creation_input_token_cost: null,
|
||||
cache_read_input_token_cost: null,
|
||||
cache_read_input_token_cost_flex: null,
|
||||
cache_read_input_token_cost_priority: null,
|
||||
cache_creation_input_token_cost_above_1hr: null,
|
||||
input_cost_per_character: null,
|
||||
input_cost_per_token_above_128k_tokens: null,
|
||||
input_cost_per_token_above_200k_tokens: null,
|
||||
input_cost_per_query: null,
|
||||
input_cost_per_second: null,
|
||||
input_cost_per_audio_token: null,
|
||||
input_cost_per_token_batches: null,
|
||||
output_cost_per_token_batches: null,
|
||||
output_cost_per_token: 0.000015,
|
||||
output_cost_per_token_flex: null,
|
||||
output_cost_per_token_priority: null,
|
||||
output_cost_per_audio_token: null,
|
||||
output_cost_per_character: null,
|
||||
output_cost_per_reasoning_token: null,
|
||||
output_cost_per_token_above_128k_tokens: null,
|
||||
output_cost_per_character_above_128k_tokens: null,
|
||||
output_cost_per_token_above_200k_tokens: null,
|
||||
output_cost_per_second: null,
|
||||
output_cost_per_video_per_second: null,
|
||||
output_cost_per_image: null,
|
||||
output_vector_size: null,
|
||||
citation_cost_per_token: null,
|
||||
tiered_pricing: null,
|
||||
litellm_provider: "bedrock",
|
||||
mode: "chat",
|
||||
supports_system_messages: null,
|
||||
supports_response_schema: true,
|
||||
supports_vision: true,
|
||||
supports_function_calling: true,
|
||||
supports_tool_choice: true,
|
||||
supports_assistant_prefill: null,
|
||||
supports_prompt_caching: null,
|
||||
supports_audio_input: null,
|
||||
supports_audio_output: null,
|
||||
supports_pdf_input: true,
|
||||
supports_embedding_image_input: null,
|
||||
supports_native_streaming: null,
|
||||
supports_web_search: null,
|
||||
supports_url_context: null,
|
||||
supports_reasoning: null,
|
||||
supports_computer_use: null,
|
||||
search_context_cost_per_query: null,
|
||||
tpm: null,
|
||||
rpm: null,
|
||||
ocr_cost_per_page: null,
|
||||
annotation_cost_per_page: null,
|
||||
supported_openai_params: [
|
||||
"max_tokens",
|
||||
"max_completion_tokens",
|
||||
"stream",
|
||||
"stream_options",
|
||||
"stop",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"extra_headers",
|
||||
"response_format",
|
||||
"requestMetadata",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
],
|
||||
id: "123",
|
||||
created_by: "123",
|
||||
db_model: true,
|
||||
input_cost_per_token: 0.00003,
|
||||
output_cost_per_token: 0.00006,
|
||||
},
|
||||
},
|
||||
],
|
||||
}),
|
||||
credentialGetCall: vi.fn().mockResolvedValue({}),
|
||||
credentialGetCall: vi.fn().mockResolvedValue({
|
||||
credential_name: "test-credential",
|
||||
credential_values: {},
|
||||
credential_info: {},
|
||||
}),
|
||||
getGuardrailsList: vi.fn().mockResolvedValue({
|
||||
guardrails: [{ guardrail_name: "content_filter" }, { guardrail_name: "toxicity_filter" }],
|
||||
}),
|
||||
@ -120,99 +46,135 @@ vi.mock("./networking", () => ({
|
||||
description: "Production ready models",
|
||||
},
|
||||
}),
|
||||
testConnectionRequest: vi.fn().mockResolvedValue({
|
||||
status: "success",
|
||||
}),
|
||||
modelPatchUpdateCall: vi.fn().mockResolvedValue({}),
|
||||
modelDeleteCall: vi.fn().mockResolvedValue({}),
|
||||
}));
|
||||
|
||||
// Mock the useModelsInfo hook since it uses React Query
|
||||
const mockUseModelsInfo = vi.fn();
|
||||
const mockUseModelHub = vi.fn();
|
||||
|
||||
vi.mock("@/app/(dashboard)/hooks/models/useModels", () => ({
|
||||
useModelsInfo: vi.fn().mockReturnValue({
|
||||
data: {
|
||||
data: [
|
||||
{
|
||||
model_name: "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
provider: "bedrock",
|
||||
litellm_model_name: "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
|
||||
},
|
||||
{
|
||||
model_name: "openai/gpt-4",
|
||||
provider: "openai",
|
||||
litellm_model_name: "gpt-4",
|
||||
},
|
||||
],
|
||||
},
|
||||
isLoading: false,
|
||||
error: null,
|
||||
}),
|
||||
useModelsInfo: (...args: any[]) => mockUseModelsInfo(...args),
|
||||
useModelHub: (...args: any[]) => mockUseModelHub(...args),
|
||||
}));
|
||||
|
||||
// Mock the useModelCostMap hook
|
||||
const mockUseModelCostMap = vi.fn();
|
||||
vi.mock("@/app/(dashboard)/hooks/models/useModelCostMap", () => ({
|
||||
useModelCostMap: (...args: any[]) => mockUseModelCostMap(...args),
|
||||
}));
|
||||
|
||||
describe("ModelInfoView", () => {
|
||||
const modelData = {
|
||||
model_info: {
|
||||
id: "123",
|
||||
created_by: "123",
|
||||
db_model: true,
|
||||
},
|
||||
let queryClient: QueryClient;
|
||||
|
||||
const defaultModelData = {
|
||||
model_name: "GPT-4",
|
||||
litellm_params: {
|
||||
model: "gpt-4",
|
||||
api_base: "https://api.openai.com/v1",
|
||||
custom_llm_provider: "openai",
|
||||
},
|
||||
litellm_model_name: "gpt-4",
|
||||
model_name: "GPT-4",
|
||||
litellm_provider: "openai",
|
||||
mode: "chat",
|
||||
supported_openai_params: ["temperature", "max_tokens", "top_p", "frequency_penalty", "presence_penalty"],
|
||||
model_info: {
|
||||
id: "123",
|
||||
created_by: "123",
|
||||
created_at: "2024-01-01T00:00:00Z",
|
||||
db_model: true,
|
||||
input_cost_per_token: 0.00003,
|
||||
output_cost_per_token: 0.00006,
|
||||
},
|
||||
};
|
||||
|
||||
beforeEach(() => {
|
||||
queryClient = new QueryClient({
|
||||
defaultOptions: {
|
||||
queries: {
|
||||
retry: false,
|
||||
},
|
||||
},
|
||||
});
|
||||
vi.clearAllMocks();
|
||||
|
||||
// Set up default mocks
|
||||
mockUseModelsInfo.mockReturnValue({
|
||||
data: {
|
||||
data: [defaultModelData],
|
||||
},
|
||||
isLoading: false,
|
||||
error: null,
|
||||
});
|
||||
|
||||
mockUseModelHub.mockReturnValue({
|
||||
data: {
|
||||
data: [],
|
||||
},
|
||||
isLoading: false,
|
||||
error: null,
|
||||
});
|
||||
|
||||
mockUseModelCostMap.mockReturnValue({
|
||||
data: {},
|
||||
isLoading: false,
|
||||
error: null,
|
||||
});
|
||||
});
|
||||
|
||||
const wrapper = ({ children }: { children: ReactNode }) =>
|
||||
React.createElement(QueryClientProvider, { client: queryClient }, children);
|
||||
|
||||
const DEFAULT_ADMIN_PROPS = {
|
||||
modelId: "123",
|
||||
onClose: () => {},
|
||||
modelData: modelData,
|
||||
accessToken: "123",
|
||||
userID: "123",
|
||||
userRole: "Admin",
|
||||
editModel: false,
|
||||
setEditModalVisible: () => {},
|
||||
setSelectedModel: () => {},
|
||||
onModelUpdate: () => {},
|
||||
modelAccessGroups: [],
|
||||
};
|
||||
|
||||
describe("Edit Model", () => {
|
||||
it("should render the model info view", async () => {
|
||||
const { getByText } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />);
|
||||
const { getByText } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />, { wrapper });
|
||||
await waitFor(() => {
|
||||
expect(getByText("Model Settings")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should not render an edit model button if the model is not a DB model", async () => {
|
||||
it("should not render an edit settings button if the model is not a DB model", async () => {
|
||||
const nonDbModelData = {
|
||||
...modelData,
|
||||
...defaultModelData,
|
||||
model_info: {
|
||||
...modelData.model_info,
|
||||
...defaultModelData.model_info,
|
||||
db_model: false,
|
||||
},
|
||||
};
|
||||
|
||||
const NON_DB_ADMIN_PROPS = {
|
||||
...DEFAULT_ADMIN_PROPS,
|
||||
modelData: nonDbModelData,
|
||||
};
|
||||
mockUseModelsInfo.mockReturnValue({
|
||||
data: {
|
||||
data: [nonDbModelData],
|
||||
},
|
||||
isLoading: false,
|
||||
error: null,
|
||||
});
|
||||
|
||||
const { queryByText } = render(<ModelInfoView {...NON_DB_ADMIN_PROPS} />);
|
||||
const { queryByText } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />, { wrapper });
|
||||
await waitFor(() => {
|
||||
expect(queryByText("Edit Model")).not.toBeInTheDocument();
|
||||
expect(queryByText("Edit Settings")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should render tags in the edit model", async () => {
|
||||
const { getByText } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />);
|
||||
const { getByText } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />, { wrapper });
|
||||
await waitFor(() => {
|
||||
expect(getByText("Tags")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should render the litellm params in the edit model", async () => {
|
||||
const { getByText } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />);
|
||||
const { getByText } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />, { wrapper });
|
||||
await waitFor(() => {
|
||||
expect(getByText("LiteLLM Params")).toBeInTheDocument();
|
||||
});
|
||||
@ -220,21 +182,21 @@ describe("ModelInfoView", () => {
|
||||
});
|
||||
|
||||
it("should render a test connection button", async () => {
|
||||
const { getByTestId } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />);
|
||||
const { getByTestId } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />, { wrapper });
|
||||
await waitFor(() => {
|
||||
expect(getByTestId("test-connection-button")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should render a reuse credentials button", async () => {
|
||||
const { getByTestId } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />);
|
||||
const { getByTestId } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />, { wrapper });
|
||||
await waitFor(() => {
|
||||
expect(getByTestId("reuse-credentials-button")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should render a delete model button", async () => {
|
||||
const { getByTestId } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />);
|
||||
const { getByTestId } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />, { wrapper });
|
||||
await waitFor(() => {
|
||||
expect(getByTestId("delete-model-button")).toBeInTheDocument();
|
||||
});
|
||||
@ -242,17 +204,22 @@ describe("ModelInfoView", () => {
|
||||
|
||||
it("should render a disabled delete model button if the model is not a DB model", async () => {
|
||||
const nonDbModelData = {
|
||||
...modelData,
|
||||
...defaultModelData,
|
||||
model_info: {
|
||||
...modelData.model_info,
|
||||
...defaultModelData.model_info,
|
||||
db_model: false,
|
||||
},
|
||||
};
|
||||
const NON_DB_ADMIN_PROPS = {
|
||||
...DEFAULT_ADMIN_PROPS,
|
||||
modelData: nonDbModelData,
|
||||
};
|
||||
const { getByTestId } = render(<ModelInfoView {...NON_DB_ADMIN_PROPS} />);
|
||||
|
||||
mockUseModelsInfo.mockReturnValue({
|
||||
data: {
|
||||
data: [nonDbModelData],
|
||||
},
|
||||
isLoading: false,
|
||||
error: null,
|
||||
});
|
||||
|
||||
const { getByTestId } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />, { wrapper });
|
||||
await waitFor(() => {
|
||||
expect(getByTestId("delete-model-button")).toBeDisabled();
|
||||
});
|
||||
@ -260,18 +227,27 @@ describe("ModelInfoView", () => {
|
||||
|
||||
it("should render a disabled delete model button if the user is not an admin and model is not created by the user", async () => {
|
||||
const nonCreatedByUserModelData = {
|
||||
...modelData,
|
||||
...defaultModelData,
|
||||
model_info: {
|
||||
...modelData.model_info,
|
||||
...defaultModelData.model_info,
|
||||
created_by: "456",
|
||||
},
|
||||
};
|
||||
|
||||
mockUseModelsInfo.mockReturnValue({
|
||||
data: {
|
||||
data: [nonCreatedByUserModelData],
|
||||
},
|
||||
isLoading: false,
|
||||
error: null,
|
||||
});
|
||||
|
||||
const NON_CREATED_BY_USER_ADMIN_PROPS = {
|
||||
...DEFAULT_ADMIN_PROPS,
|
||||
modelData: nonCreatedByUserModelData,
|
||||
userRole: "User",
|
||||
};
|
||||
const { getByTestId } = render(<ModelInfoView {...NON_CREATED_BY_USER_ADMIN_PROPS} />);
|
||||
|
||||
const { getByTestId } = render(<ModelInfoView {...NON_CREATED_BY_USER_ADMIN_PROPS} />, { wrapper });
|
||||
await waitFor(() => {
|
||||
expect(getByTestId("delete-model-button")).toBeDisabled();
|
||||
});
|
||||
@ -279,16 +255,22 @@ describe("ModelInfoView", () => {
|
||||
|
||||
it("should render health check model field for wildcard routes", async () => {
|
||||
const wildcardModelData = {
|
||||
...modelData,
|
||||
litellm_model_name: "openai/gpt-4*",
|
||||
...defaultModelData,
|
||||
litellm_params: {
|
||||
...defaultModelData.litellm_params,
|
||||
model: "openai/gpt-4*",
|
||||
},
|
||||
};
|
||||
|
||||
const WILDCARD_ADMIN_PROPS = {
|
||||
...DEFAULT_ADMIN_PROPS,
|
||||
modelData: wildcardModelData,
|
||||
};
|
||||
mockUseModelsInfo.mockReturnValue({
|
||||
data: {
|
||||
data: [wildcardModelData],
|
||||
},
|
||||
isLoading: false,
|
||||
error: null,
|
||||
});
|
||||
|
||||
const { getByText } = render(<ModelInfoView {...WILDCARD_ADMIN_PROPS} />);
|
||||
const { getByText } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />, { wrapper });
|
||||
await waitFor(() => {
|
||||
expect(getByText("Model Settings")).toBeInTheDocument();
|
||||
});
|
||||
@ -298,7 +280,7 @@ describe("ModelInfoView", () => {
|
||||
});
|
||||
|
||||
it("should not render health check model field for non-wildcard routes", async () => {
|
||||
const { queryByText } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />);
|
||||
const { queryByText } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />, { wrapper });
|
||||
await waitFor(() => {
|
||||
expect(queryByText("Model Settings")).toBeInTheDocument();
|
||||
});
|
||||
@ -309,14 +291,14 @@ describe("ModelInfoView", () => {
|
||||
|
||||
describe("View Model", () => {
|
||||
it("should render the model info view", async () => {
|
||||
const { getByText } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />);
|
||||
const { getByText } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />, { wrapper });
|
||||
await waitFor(() => {
|
||||
expect(getByText("Model Settings")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
it("should render tags in the view model", async () => {
|
||||
const { getByText } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />);
|
||||
const { getByText } = render(<ModelInfoView {...DEFAULT_ADMIN_PROPS} />, { wrapper });
|
||||
await waitFor(() => {
|
||||
expect(getByText("Tags")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
import { useModelsInfo } from "@/app/(dashboard)/hooks/models/useModels";
|
||||
import { useModelCostMap } from "@/app/(dashboard)/hooks/models/useModelCostMap";
|
||||
import { useModelHub, useModelsInfo } from "@/app/(dashboard)/hooks/models/useModels";
|
||||
import { transformModelData } from "@/app/(dashboard)/models-and-endpoints/utils/modelDataTransformer";
|
||||
import { InfoCircleOutlined } from "@ant-design/icons";
|
||||
import { ArrowLeftIcon, KeyIcon, RefreshIcon, TrashIcon } from "@heroicons/react/outline";
|
||||
import {
|
||||
@ -16,7 +18,7 @@ import {
|
||||
} from "@tremor/react";
|
||||
import { Button, Form, Input, Modal, Select, Tooltip } from "antd";
|
||||
import { CheckIcon, CopyIcon } from "lucide-react";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useEffect, useMemo, useState } from "react";
|
||||
import { copyToClipboard as utilCopyToClipboard } from "../utils/dataUtils";
|
||||
import { formItemValidateJSON, truncateString } from "../utils/textUtils";
|
||||
import CacheControlSettings from "./add_model/cache_control_settings";
|
||||
@ -43,7 +45,6 @@ import { getDisplayModelName } from "./view_model/model_name_display";
|
||||
interface ModelInfoViewProps {
|
||||
modelId: string;
|
||||
onClose: () => void;
|
||||
modelData: any;
|
||||
accessToken: string | null;
|
||||
userID: string | null;
|
||||
userRole: string | null;
|
||||
@ -54,7 +55,6 @@ interface ModelInfoViewProps {
|
||||
export default function ModelInfoView({
|
||||
modelId,
|
||||
onClose,
|
||||
modelData,
|
||||
accessToken,
|
||||
userID,
|
||||
userRole,
|
||||
@ -75,20 +75,64 @@ export default function ModelInfoView({
|
||||
const [isAutoRouterModalOpen, setIsAutoRouterModalOpen] = useState(false);
|
||||
const [guardrailsList, setGuardrailsList] = useState<string[]>([]);
|
||||
const [tagsList, setTagsList] = useState<Record<string, Tag>>({});
|
||||
|
||||
// Fetch model data using hook
|
||||
const { data: rawModelDataResponse, isLoading: isLoadingModel } = useModelsInfo(1, 50, undefined, modelId);
|
||||
const { data: modelCostMapData } = useModelCostMap();
|
||||
const { data: modelHubData } = useModelHub();
|
||||
|
||||
// Transform the model data
|
||||
const getProviderFromModel = (model: string) => {
|
||||
if (modelCostMapData !== null && modelCostMapData !== undefined) {
|
||||
if (typeof modelCostMapData == "object" && model in modelCostMapData) {
|
||||
return modelCostMapData[model]["litellm_provider"];
|
||||
}
|
||||
}
|
||||
return "openai";
|
||||
};
|
||||
|
||||
const transformedModelData = useMemo(() => {
|
||||
if (!rawModelDataResponse?.data || rawModelDataResponse.data.length === 0) {
|
||||
return null;
|
||||
}
|
||||
const transformed = transformModelData(rawModelDataResponse, getProviderFromModel);
|
||||
return transformed.data[0] || null;
|
||||
}, [rawModelDataResponse, modelCostMapData]);
|
||||
|
||||
// Keep modelData variable name for backwards compatibility
|
||||
const modelData = transformedModelData;
|
||||
|
||||
const canEditModel =
|
||||
(userRole === "Admin" || modelData?.model_info?.created_by === userID) && modelData?.model_info?.db_model;
|
||||
const isAdmin = userRole === "Admin";
|
||||
const isAutoRouter = modelData?.litellm_params?.auto_router_config != null;
|
||||
|
||||
const { data: modelsInfoData } = useModelsInfo();
|
||||
console.log("modelsInfoData, ", modelsInfoData);
|
||||
const usingExistingCredential =
|
||||
modelData?.litellm_params?.litellm_credential_name != null &&
|
||||
modelData?.litellm_params?.litellm_credential_name != undefined;
|
||||
console.log("usingExistingCredential, ", usingExistingCredential);
|
||||
console.log("modelData.litellm_params.litellm_credential_name, ", modelData?.litellm_params?.litellm_credential_name);
|
||||
|
||||
console.log("tagsList, ", modelData.litellm_params?.tags);
|
||||
// Initialize localModelData from modelData when available
|
||||
useEffect(() => {
|
||||
if (modelData && !localModelData) {
|
||||
let processedModelData = modelData;
|
||||
if (!processedModelData.litellm_model_name) {
|
||||
processedModelData = {
|
||||
...processedModelData,
|
||||
litellm_model_name:
|
||||
processedModelData?.litellm_params?.litellm_model_name ??
|
||||
processedModelData?.litellm_params?.model ??
|
||||
processedModelData?.model_info?.key ??
|
||||
null,
|
||||
};
|
||||
}
|
||||
setLocalModelData(processedModelData);
|
||||
|
||||
// Check if cache control is enabled
|
||||
if (processedModelData?.litellm_params?.cache_control_injection_points) {
|
||||
setShowCacheControl(true);
|
||||
}
|
||||
}
|
||||
}, [modelData, localModelData]);
|
||||
|
||||
useEffect(() => {
|
||||
const getExistingCredential = async () => {
|
||||
@ -106,6 +150,8 @@ export default function ModelInfoView({
|
||||
|
||||
const getModelInfo = async () => {
|
||||
if (!accessToken) return;
|
||||
// Only fetch if we don't have modelData yet
|
||||
if (modelData) return;
|
||||
let modelInfoResponse = await modelInfoV1Call(accessToken, modelId);
|
||||
console.log("modelInfoResponse, ", modelInfoResponse);
|
||||
let specificModelData = modelInfoResponse.data[0];
|
||||
@ -270,6 +316,19 @@ export default function ModelInfoView({
|
||||
}
|
||||
};
|
||||
|
||||
// Show loading state
|
||||
if (isLoadingModel) {
|
||||
return (
|
||||
<div className="p-4">
|
||||
<TremorButton icon={ArrowLeftIcon} variant="light" onClick={onClose} className="mb-4">
|
||||
Back to Models
|
||||
</TremorButton>
|
||||
<Text>Loading...</Text>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// Show not found if model is not found
|
||||
if (!modelData) {
|
||||
return (
|
||||
<div className="p-4">
|
||||
@ -353,6 +412,7 @@ export default function ModelInfoView({
|
||||
}
|
||||
};
|
||||
const isWildcardModel = modelData.litellm_model_name.includes("*");
|
||||
console.log("isWildcardModel, ", isWildcardModel);
|
||||
|
||||
return (
|
||||
<div className="p-4">
|
||||
@ -369,11 +429,10 @@ export default function ModelInfoView({
|
||||
size="small"
|
||||
icon={copiedStates["model-id"] ? <CheckIcon size={12} /> : <CopyIcon size={12} />}
|
||||
onClick={() => copyToClipboard(modelData.model_info.id, "model-id")}
|
||||
className={`left-2 z-10 transition-all duration-200 ${
|
||||
copiedStates["model-id"]
|
||||
? "text-green-600 bg-green-50 border-green-200"
|
||||
: "text-gray-500 hover:text-gray-700 hover:bg-gray-100"
|
||||
}`}
|
||||
className={`left-2 z-10 transition-all duration-200 ${copiedStates["model-id"]
|
||||
? "text-green-600 bg-green-50 border-green-200"
|
||||
: "text-gray-500 hover:text-gray-700 hover:bg-gray-100"
|
||||
}`}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
@ -484,10 +543,10 @@ export default function ModelInfoView({
|
||||
Created At{" "}
|
||||
{modelData.model_info.created_at
|
||||
? new Date(modelData.model_info.created_at).toLocaleDateString("en-US", {
|
||||
month: "short",
|
||||
day: "numeric",
|
||||
year: "numeric",
|
||||
})
|
||||
month: "short",
|
||||
day: "numeric",
|
||||
year: "numeric",
|
||||
})
|
||||
: "Not Set"}
|
||||
</div>
|
||||
<div className="flex items-center gap-x-2">
|
||||
@ -891,27 +950,19 @@ export default function ModelInfoView({
|
||||
optionFilterProp="children"
|
||||
allowClear
|
||||
options={(() => {
|
||||
const seen = new Set();
|
||||
return modelsInfoData?.data
|
||||
const wildcardProvider = modelData.litellm_model_name.split("/")[0];
|
||||
return modelHubData?.data
|
||||
?.filter((model: any) => {
|
||||
const modelProvider = model.provider;
|
||||
const wildcardProvider = modelData.litellm_model_name.split("/")[0];
|
||||
// Filter by provider to match the wildcard provider
|
||||
return (
|
||||
modelProvider === wildcardProvider &&
|
||||
model.model_name !== modelData.litellm_model_name
|
||||
model.providers?.includes(wildcardProvider) &&
|
||||
model.model_group !== modelData.litellm_model_name
|
||||
);
|
||||
})
|
||||
.filter((model: any) => {
|
||||
if (seen.has(model.model_name)) {
|
||||
return false;
|
||||
}
|
||||
seen.add(model.model_name);
|
||||
return true;
|
||||
})
|
||||
.map((model: any) => ({
|
||||
value: model.model_name,
|
||||
label: model.model_name,
|
||||
}));
|
||||
value: model.model_group,
|
||||
label: model.model_group,
|
||||
})) || [];
|
||||
})()}
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
@ -2007,12 +2007,12 @@ export const regenerateKeyCall = async (accessToken: string, keyToRegenerate: st
|
||||
let ModelListerrorShown = false;
|
||||
let errorTimer: NodeJS.Timeout | null = null;
|
||||
|
||||
export const modelInfoCall = async (accessToken: string, userID: string, userRole: string, page: number = 1, size: number = 50, search?: string) => {
|
||||
export const modelInfoCall = async (accessToken: string, userID: string, userRole: string, page: number = 1, size: number = 50, search?: string, modelId?: string, teamId?: string) => {
|
||||
/**
|
||||
* Get all models on proxy
|
||||
*/
|
||||
try {
|
||||
console.log("modelInfoCall:", accessToken, userID, userRole, page, size, search);
|
||||
console.log("modelInfoCall:", accessToken, userID, userRole, page, size, search, modelId, teamId);
|
||||
let url = proxyBaseUrl ? `${proxyBaseUrl}/v2/model/info` : `/v2/model/info`;
|
||||
const params = new URLSearchParams();
|
||||
params.append("include_team_models", "true");
|
||||
@ -2021,6 +2021,12 @@ export const modelInfoCall = async (accessToken: string, userID: string, userRol
|
||||
if (search && search.trim()) {
|
||||
params.append("search", search.trim());
|
||||
}
|
||||
if (modelId && modelId.trim()) {
|
||||
params.append("modelId", modelId.trim());
|
||||
}
|
||||
if (teamId && teamId.trim()) {
|
||||
params.append("teamId", teamId.trim());
|
||||
}
|
||||
if (params.toString()) {
|
||||
url += `?${params.toString()}`;
|
||||
}
|
||||
|
||||
@ -1025,7 +1025,6 @@ const OldModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||
setSelectedModelId(null);
|
||||
setEditModel(false);
|
||||
}}
|
||||
modelData={modelData.data.find((model: any) => model.model_info.id === selectedModelId)}
|
||||
accessToken={accessToken}
|
||||
userID={userID}
|
||||
userRole={userRole}
|
||||
@ -1270,9 +1269,9 @@ const OldModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||
<span className="text-sm text-gray-700">
|
||||
{filteredData.length > 0
|
||||
? `Showing ${pagination.pageIndex * pagination.pageSize + 1} - ${Math.min(
|
||||
(pagination.pageIndex + 1) * pagination.pageSize,
|
||||
filteredData.length,
|
||||
)} of ${filteredData.length} results`
|
||||
(pagination.pageIndex + 1) * pagination.pageSize,
|
||||
filteredData.length,
|
||||
)} of ${filteredData.length} results`
|
||||
: "Showing 0 results"}
|
||||
</span>
|
||||
|
||||
@ -1284,11 +1283,10 @@ const OldModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||
setPagination((prev) => ({ ...prev, pageIndex: prev.pageIndex - 1 }))
|
||||
}
|
||||
disabled={pagination.pageIndex === 0}
|
||||
className={`px-3 py-1 text-sm border rounded-md ${
|
||||
pagination.pageIndex === 0
|
||||
? "bg-gray-100 text-gray-400 cursor-not-allowed"
|
||||
: "hover:bg-gray-50"
|
||||
}`}
|
||||
className={`px-3 py-1 text-sm border rounded-md ${pagination.pageIndex === 0
|
||||
? "bg-gray-100 text-gray-400 cursor-not-allowed"
|
||||
: "hover:bg-gray-50"
|
||||
}`}
|
||||
>
|
||||
Previous
|
||||
</button>
|
||||
@ -1300,11 +1298,10 @@ const OldModelDashboard: React.FC<ModelDashboardProps> = ({
|
||||
disabled={
|
||||
pagination.pageIndex >= Math.ceil(filteredData.length / pagination.pageSize) - 1
|
||||
}
|
||||
className={`px-3 py-1 text-sm border rounded-md ${
|
||||
pagination.pageIndex >= Math.ceil(filteredData.length / pagination.pageSize) - 1
|
||||
? "bg-gray-100 text-gray-400 cursor-not-allowed"
|
||||
: "hover:bg-gray-50"
|
||||
}`}
|
||||
className={`px-3 py-1 text-sm border rounded-md ${pagination.pageIndex >= Math.ceil(filteredData.length / pagination.pageSize) - 1
|
||||
? "bg-gray-100 text-gray-400 cursor-not-allowed"
|
||||
: "hover:bg-gray-50"
|
||||
}`}
|
||||
>
|
||||
Next
|
||||
</button>
|
||||
|
||||
Loading…
Reference in New Issue
Block a user