Agents - assign tools (#22064)
* feat(proxy): add max_iterations limiter for agent session loops (#22058) Adds a new proxy hook that enforces a per-session cap on the number of LLM calls an agentic loop can make. Callers send a session_id with each request, and the hook counts calls per session, returning 429 when the configured max_iterations limit is exceeded. - Uses Redis Lua script for atomic increment (multi-instance safe) - Falls back to in-memory cache when Redis unavailable - Follows parallel_request_limiter_v3 pattern - Configurable via key metadata: {"max_iterations": 25} - Session counters auto-expire via TTL (default 1hr) Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> * feat: add new code execution dataset * feat(agent_endpoints/): allow giving agents keys * fix: ui fixes * feat: allow assigning mcp servers to agents * fix: eliminate duplicate DB queries in MCP agent auth and N+1 in agent listing (#22110) - Extract _get_agent_object_permission helper so _get_allowed_mcp_servers_for_agent and _get_agent_tool_permissions_for_server share a single DB fetch instead of each independently querying the same agent row (was 1+N queries per MCP request) - Use include={"object_permission": True} on find_many in get_all_agents_from_db to eagerly load permissions in one query instead of N+1 - Use include={"object_permission": True} on create/update/find_unique in all agent CRUD operations, removing attach_object_permission_to_dict follow-up calls Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
a4baa022c5
commit
12c4876891
@ -0,0 +1,40 @@
|
||||
-- Migration summary: links agents to an object-permission row, adds "agent_id" to
-- virtual keys, drops "spec_path" from MCP servers, and creates "LiteLLM_ToolTable".
-- AlterTable: nullable reference from an agent to its object-permission row
-- (the foreign key itself is added by the last statement of this migration).
|
||||
ALTER TABLE "LiteLLM_AgentsTable" ADD COLUMN "object_permission_id" TEXT;
|
||||
|
||||
-- AlterTable: removes "spec_path" from the MCP server table.
-- NOTE(review): DROP COLUMN discards any values already stored in this column; confirm that is intended.
|
||||
ALTER TABLE "LiteLLM_MCPServerTable" DROP COLUMN "spec_path";
|
||||
|
||||
-- AlterTable: nullable "agent_id" on virtual keys.
-- No index or foreign key is created for this column in this migration.
|
||||
ALTER TABLE "LiteLLM_VerificationToken" ADD COLUMN "agent_id" TEXT;
|
||||
|
||||
-- CreateTable: tool registry; "call_policy" defaults to 'untrusted' and "call_count" to 0.
|
||||
CREATE TABLE "LiteLLM_ToolTable" (
|
||||
"tool_id" TEXT NOT NULL,
|
||||
"tool_name" TEXT NOT NULL,
|
||||
"origin" TEXT,
|
||||
"call_policy" TEXT NOT NULL DEFAULT 'untrusted',
|
||||
"call_count" INTEGER NOT NULL DEFAULT 0,
|
||||
"assignments" JSONB DEFAULT '{}',
|
||||
"key_hash" TEXT,
|
||||
"team_id" TEXT,
|
||||
"key_alias" TEXT,
|
||||
"created_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"created_by" TEXT,
|
||||
"updated_at" TIMESTAMP(3) NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
"updated_by" TEXT,
|
||||
|
||||
CONSTRAINT "LiteLLM_ToolTable_pkey" PRIMARY KEY ("tool_id")
|
||||
);
|
||||
|
||||
-- CreateIndex: tool names are unique across the whole table.
|
||||
CREATE UNIQUE INDEX "LiteLLM_ToolTable_tool_name_key" ON "LiteLLM_ToolTable"("tool_name");
|
||||
|
||||
-- CreateIndex: lookup of tools by call policy.
|
||||
CREATE INDEX "LiteLLM_ToolTable_call_policy_idx" ON "LiteLLM_ToolTable"("call_policy");
|
||||
|
||||
-- CreateIndex: lookup of tools by team.
|
||||
CREATE INDEX "LiteLLM_ToolTable_team_id_idx" ON "LiteLLM_ToolTable"("team_id");
|
||||
|
||||
-- AddForeignKey: deleting the referenced permission row sets the agent's
-- "object_permission_id" to NULL; updates to the referenced key cascade.
|
||||
ALTER TABLE "LiteLLM_AgentsTable" ADD CONSTRAINT "LiteLLM_AgentsTable_object_permission_id_fkey" FOREIGN KEY ("object_permission_id") REFERENCES "LiteLLM_ObjectPermissionTable"("object_permission_id") ON DELETE SET NULL ON UPDATE CASCADE;
|
||||
|
||||
@ -64,6 +64,8 @@ model LiteLLM_AgentsTable {
|
||||
litellm_params Json?
|
||||
agent_card_params Json
|
||||
agent_access_groups String[] @default([])
|
||||
object_permission_id String?
|
||||
object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id])
|
||||
created_at DateTime @default(now()) @map("created_at")
|
||||
created_by String
|
||||
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||
@ -264,6 +266,7 @@ model LiteLLM_ObjectPermissionTable {
|
||||
organizations LiteLLM_OrganizationTable[]
|
||||
users LiteLLM_UserTable[]
|
||||
end_users LiteLLM_EndUserTable[]
|
||||
agents_table LiteLLM_AgentsTable[]
|
||||
}
|
||||
|
||||
// Holds the MCP server configuration
|
||||
@ -273,7 +276,6 @@ model LiteLLM_MCPServerTable {
|
||||
alias String?
|
||||
description String?
|
||||
url String?
|
||||
spec_path String?
|
||||
transport String @default("sse")
|
||||
auth_type String?
|
||||
credentials Json? @default("{}")
|
||||
@ -315,6 +317,7 @@ model LiteLLM_VerificationToken {
|
||||
router_settings Json? @default("{}")
|
||||
user_id String?
|
||||
team_id String?
|
||||
agent_id String?
|
||||
project_id String?
|
||||
permissions Json @default("{}")
|
||||
max_parallel_requests Int?
|
||||
@ -1052,6 +1055,26 @@ model LiteLLM_PolicyAttachmentTable {
|
||||
updated_by String?
|
||||
}
|
||||
|
||||
// Global tool registry - auto-discovered from LLM responses; admins set call_policy here
// One row per tool name (tool_name is unique); indexed by call_policy and team_id.
|
||||
model LiteLLM_ToolTable {
|
||||
tool_id String @id @default(uuid()) // primary key, generated as a UUID when not supplied
|
||||
tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space"
|
||||
origin String? // MCP server name or "user_defined"
|
||||
call_policy String @default("untrusted") // "trusted" | "untrusted" | "dual_llm" | "blocked"
|
||||
call_count Int @default(0) // cumulative number of times this tool was seen
|
||||
assignments Json? @default("{}") // NOTE(review): the JSON structure is not described in this schema - confirm against the code that writes it
|
||||
key_hash String? // hash of the virtual key that first called this tool
|
||||
team_id String? // team that first called this tool
|
||||
key_alias String? // human-readable alias of the virtual key
|
||||
created_at DateTime @default(now()) // set when the row is created
|
||||
created_by String?
|
||||
updated_at DateTime @default(now()) @updatedAt // refreshed automatically on every update
|
||||
updated_by String?
|
||||
|
||||
@@index([call_policy])
|
||||
@@index([team_id])
|
||||
}
|
||||
|
||||
//Unified Access Groups table for storing unified access groups
|
||||
model LiteLLM_AccessGroupTable {
|
||||
access_group_id String @id @default(uuid())
|
||||
|
||||
@ -412,6 +412,26 @@ class MCPRequestHandler:
|
||||
)
|
||||
return []
|
||||
|
||||
#########################################################
|
||||
# Check agent permissions if agent_id is set on the key
|
||||
#########################################################
|
||||
if user_api_key_auth and user_api_key_auth.agent_id:
|
||||
allowed_mcp_servers_for_agent = (
|
||||
await MCPRequestHandler._get_allowed_mcp_servers_for_agent(
|
||||
user_api_key_auth
|
||||
)
|
||||
)
|
||||
if len(allowed_mcp_servers_for_agent) > 0:
|
||||
# Intersect: agent can only use servers allowed by BOTH key/team AND agent config
|
||||
allowed_mcp_servers = [
|
||||
s
|
||||
for s in allowed_mcp_servers
|
||||
if s in allowed_mcp_servers_for_agent
|
||||
]
|
||||
verbose_logger.debug(
|
||||
f"Applied agent intersection filter. Final allowed servers: {allowed_mcp_servers}"
|
||||
)
|
||||
|
||||
return list(set(allowed_mcp_servers))
|
||||
except Exception as e:
|
||||
verbose_logger.warning(f"Failed to get allowed MCP servers: {str(e)}")
|
||||
@ -513,13 +533,33 @@ class MCPRequestHandler:
|
||||
if team_tools:
|
||||
if key_tools:
|
||||
# Both have restrictions → intersection
|
||||
return list(set(team_tools) & set(key_tools))
|
||||
allowed_tools = list(set(team_tools) & set(key_tools))
|
||||
else:
|
||||
# Only team has restrictions → inherit from team
|
||||
return team_tools
|
||||
allowed_tools = team_tools
|
||||
else:
|
||||
# No team restrictions → use key restrictions
|
||||
return key_tools
|
||||
allowed_tools = key_tools
|
||||
|
||||
# Intersect with agent's tool permissions if agent_id is set
|
||||
if user_api_key_auth.agent_id:
|
||||
# Pre-fetch agent object_permission once to avoid duplicate DB query
|
||||
agent_obj_perm = await MCPRequestHandler._get_agent_object_permission(
|
||||
user_api_key_auth
|
||||
)
|
||||
agent_tools = await MCPRequestHandler._get_agent_tool_permissions_for_server(
|
||||
server_id=server_id,
|
||||
user_api_key_auth=user_api_key_auth,
|
||||
agent_object_permission=agent_obj_perm,
|
||||
)
|
||||
if agent_tools is not None:
|
||||
if allowed_tools is not None:
|
||||
allowed_tools = list(
|
||||
set(allowed_tools) & set(agent_tools)
|
||||
)
|
||||
else:
|
||||
allowed_tools = agent_tools
|
||||
return allowed_tools
|
||||
|
||||
except Exception as e:
|
||||
verbose_logger.warning(f"Failed to get allowed tools for server: {str(e)}")
|
||||
@ -715,6 +755,131 @@ class MCPRequestHandler:
|
||||
)
|
||||
return []
|
||||
|
||||
@staticmethod
async def _get_agent_object_permission(
    user_api_key_auth: Optional[UserAPIKeyAuth] = None,
):
    """
    Load the object_permission attached to the agent referenced by the key.

    Issues one `find_unique` on the agents table with the `object_permission`
    relation included.

    Args:
        user_api_key_auth: Auth object for the request; its `agent_id`
            selects the agent row.

    Returns:
        The agent's object_permission record, or None when there is no auth
        object / agent_id, no prisma client, no matching agent row, no
        linked permission, or the lookup raised (logged as a warning).
    """
    from litellm.proxy.proxy_server import prisma_client

    agent_id = user_api_key_auth.agent_id if user_api_key_auth else None
    if not agent_id:
        return None

    if prisma_client is None:
        verbose_logger.debug("prisma_client is None")
        return None

    try:
        row = await prisma_client.db.litellm_agentstable.find_unique(
            where={"agent_id": agent_id},
            include={"object_permission": True},
        )
        # A missing row and a row without a linked permission both yield None.
        return None if row is None else row.object_permission
    except Exception as err:
        verbose_logger.warning(
            f"Failed to get agent object permission: {str(err)}"
        )
        return None
|
||||
|
||||
@staticmethod
async def _get_allowed_mcp_servers_for_agent(
    user_api_key_auth: Optional[UserAPIKeyAuth] = None,
    agent_object_permission=None,
) -> List[str]:
    """
    Resolve the MCP servers an agent is allowed to use.

    Combines the servers listed directly on the agent's object_permission
    with the servers reachable through its MCP access groups, de-duplicated.

    An empty list means "no extra restriction from the agent". It is returned
    when there is no agent_id, no object_permission, or nothing configured.
    NOTE(review): the error path also returns [], so a failed lookup lifts the
    agent-level restriction instead of denying access - confirm this is intended.

    Args:
        user_api_key_auth: User auth with agent_id.
        agent_object_permission: Already-loaded object_permission. When None
            it is fetched via `_get_agent_object_permission`.

    Returns:
        De-duplicated list of server identifiers (order not guaranteed).
    """
    if not user_api_key_auth or not user_api_key_auth.agent_id:
        return []

    try:
        permission = agent_object_permission
        if permission is None:
            permission = await MCPRequestHandler._get_agent_object_permission(
                user_api_key_auth
            )
            if permission is None:
                return []

        def _list_attr(attr_name):
            # Missing/empty attributes and plain strings are treated as "nothing configured".
            value = getattr(permission, attr_name, None) or []
            return [] if isinstance(value, str) else value

        direct_servers = _list_attr("mcp_servers")
        access_groups = _list_attr("mcp_access_groups")

        group_servers = await MCPRequestHandler._get_mcp_servers_from_access_groups(
            access_groups
        )
        combined = list(direct_servers) + group_servers
        return list(set(combined))
    except Exception as err:
        verbose_logger.warning(
            f"Failed to get allowed MCP servers for agent: {str(err)}"
        )
        return []
|
||||
|
||||
@staticmethod
async def _get_agent_tool_permissions_for_server(
    server_id: str,
    user_api_key_auth: Optional[UserAPIKeyAuth] = None,
    agent_object_permission=None,
) -> Optional[List[str]]:
    """
    Look up the tool names an agent may call on one MCP server.

    Reads `mcp_tool_permissions` from the agent's object_permission and
    returns the entry stored under `server_id`.

    Args:
        server_id: Server ID to check permissions for.
        user_api_key_auth: User auth with agent_id.
        agent_object_permission: Already-loaded object_permission. When None
            it is fetched via `_get_agent_object_permission`.

    Returns:
        A list of tool names, or None when the agent imposes no tool
        restriction for this server (no agent_id, no permission record, no
        non-empty dict of tool permissions, no non-empty entry for the server,
        or an error, which is logged as a warning).
    """
    if not user_api_key_auth or not user_api_key_auth.agent_id:
        return None

    try:
        permission = agent_object_permission
        if permission is None:
            permission = await MCPRequestHandler._get_agent_object_permission(
                user_api_key_auth
            )
            if permission is None:
                return None

        tool_map = getattr(permission, "mcp_tool_permissions", None)
        # Only a non-empty dict can carry per-server restrictions.
        if not tool_map or not isinstance(tool_map, dict):
            return None

        server_tools = tool_map.get(server_id)
        if not server_tools:
            return None
        return list(server_tools)
    except Exception as err:
        verbose_logger.warning(
            f"Failed to get agent tool permissions for server: {str(err)}"
        )
        return None
|
||||
|
||||
@staticmethod
|
||||
def _get_config_server_ids_for_access_groups(
|
||||
config_mcp_servers, access_groups: List[str]
|
||||
|
||||
@ -1,60 +1,40 @@
|
||||
import enum
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Union
|
||||
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Literal,
|
||||
Optional, Union)
|
||||
|
||||
import httpx
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
Json,
|
||||
field_validator,
|
||||
model_validator,
|
||||
)
|
||||
from pydantic import (BaseModel, ConfigDict, Field, Json, field_validator,
|
||||
model_validator)
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
from litellm._uuid import uuid
|
||||
from litellm.types.integrations.slack_alerting import AlertType
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
OpenAIFileObject,
|
||||
ResponsesAPIResponse,
|
||||
)
|
||||
from litellm.types.mcp import (
|
||||
MCPAuth,
|
||||
MCPAuthType,
|
||||
MCPCredentials,
|
||||
MCPTransport,
|
||||
MCPTransportType,
|
||||
)
|
||||
from litellm.types.llms.openai import (AllMessageValues, OpenAIFileObject,
|
||||
ResponsesAPIResponse)
|
||||
from litellm.types.mcp import (MCPAuth, MCPAuthType, MCPCredentials,
|
||||
MCPTransport, MCPTransportType)
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPInfo
|
||||
from litellm.types.router import RouterErrors, UpdateRouterConfig
|
||||
from litellm.types.secret_managers.main import KeyManagementSystem
|
||||
from litellm.types.utils import (
|
||||
CallTypes,
|
||||
CostBreakdown,
|
||||
EmbeddingResponse,
|
||||
GenericBudgetConfigType,
|
||||
ImageResponse,
|
||||
LiteLLMBatch,
|
||||
LiteLLMFineTuningJob,
|
||||
LiteLLMPydanticObjectBase,
|
||||
ModelResponse,
|
||||
ProviderField,
|
||||
StandardCallbackDynamicParams,
|
||||
StandardLoggingGuardrailInformation,
|
||||
StandardLoggingMCPToolCall,
|
||||
StandardLoggingModelInformation,
|
||||
StandardLoggingPayloadErrorInformation,
|
||||
StandardLoggingPayloadStatus,
|
||||
StandardLoggingVectorStoreRequest,
|
||||
StandardPassThroughResponseObject,
|
||||
TextCompletionResponse,
|
||||
)
|
||||
from litellm.types.utils import (CallTypes, CostBreakdown, EmbeddingResponse,
|
||||
GenericBudgetConfigType, ImageResponse,
|
||||
LiteLLMBatch, LiteLLMFineTuningJob,
|
||||
LiteLLMPydanticObjectBase, ModelResponse,
|
||||
ProviderField, StandardCallbackDynamicParams,
|
||||
StandardLoggingGuardrailInformation,
|
||||
StandardLoggingMCPToolCall,
|
||||
StandardLoggingModelInformation,
|
||||
StandardLoggingPayloadErrorInformation,
|
||||
StandardLoggingPayloadStatus,
|
||||
StandardLoggingVectorStoreRequest,
|
||||
StandardPassThroughResponseObject,
|
||||
TextCompletionResponse)
|
||||
from litellm.types.videos.main import VideoObject
|
||||
|
||||
from .types_utils.utils import get_instance_fn, validate_custom_validate_return_type
|
||||
from .types_utils.utils import (get_instance_fn,
|
||||
validate_custom_validate_return_type)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from opentelemetry.trace import Span as _Span
|
||||
@ -2202,6 +2182,7 @@ class LiteLLM_VerificationToken(LiteLLMPydanticObjectBase):
|
||||
config: Dict = {}
|
||||
user_id: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
agent_id: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
max_parallel_requests: Optional[int] = None
|
||||
metadata: Dict = {}
|
||||
@ -2379,7 +2360,8 @@ class UserAPIKeyAuth(
|
||||
|
||||
This is used to track number of requests/spend for health check calls.
|
||||
"""
|
||||
from litellm.constants import LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME
|
||||
from litellm.constants import \
|
||||
LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME
|
||||
|
||||
return cls(
|
||||
api_key=LITTELM_INTERNAL_HEALTH_SERVICE_ACCOUNT_NAME,
|
||||
@ -2411,7 +2393,8 @@ class UserAPIKeyAuth(
|
||||
|
||||
This is used to track actions performed by automated system jobs.
|
||||
"""
|
||||
from litellm.constants import LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME
|
||||
from litellm.constants import \
|
||||
LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME
|
||||
|
||||
return cls(
|
||||
api_key=LITELLM_INTERNAL_JOBS_SERVICE_ACCOUNT_NAME,
|
||||
@ -2802,7 +2785,8 @@ class LiteLLM_AuditLogs(LiteLLMPydanticObjectBase):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def mask_api_keys(self):
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
|
||||
from litellm.litellm_core_utils.sensitive_data_masker import \
|
||||
SensitiveDataMasker
|
||||
|
||||
masker = SensitiveDataMasker(sensitive_patterns={"key"})
|
||||
|
||||
|
||||
@ -5,6 +5,9 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.proxy.management_helpers.object_permission_utils import (
|
||||
handle_update_object_permission_common,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.types.agents import AgentConfig, AgentResponse, PatchAgentRequest
|
||||
|
||||
@ -117,20 +120,39 @@ class AgentRegistry:
|
||||
)
|
||||
agent_card_params: str = safe_dumps(agent_card_params_dict)
|
||||
|
||||
# Handle object_permission (MCP tool access for agent)
|
||||
object_permission_id: Optional[str] = None
|
||||
if agent.get("object_permission") is not None:
|
||||
agent_copy = dict(agent)
|
||||
object_permission_id = await handle_update_object_permission_common(
|
||||
agent_copy, None, prisma_client
|
||||
)
|
||||
|
||||
create_data: Dict[str, Any] = {
|
||||
"agent_name": agent_name,
|
||||
"litellm_params": litellm_params,
|
||||
"agent_card_params": agent_card_params,
|
||||
"created_by": created_by,
|
||||
"updated_by": created_by,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
}
|
||||
if object_permission_id is not None:
|
||||
create_data["object_permission_id"] = object_permission_id
|
||||
|
||||
# Create agent in DB
|
||||
created_agent = await prisma_client.db.litellm_agentstable.create(
|
||||
data={
|
||||
"agent_name": agent_name,
|
||||
"litellm_params": litellm_params,
|
||||
"agent_card_params": agent_card_params,
|
||||
"created_by": created_by,
|
||||
"updated_by": created_by,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
}
|
||||
data=create_data,
|
||||
include={"object_permission": True},
|
||||
)
|
||||
|
||||
return AgentResponse(**created_agent.model_dump()) # type: ignore
|
||||
created_agent_dict = created_agent.model_dump()
|
||||
if created_agent.object_permission is not None:
|
||||
try:
|
||||
created_agent_dict["object_permission"] = created_agent.object_permission.model_dump()
|
||||
except Exception:
|
||||
created_agent_dict["object_permission"] = created_agent.object_permission.dict()
|
||||
return AgentResponse(**created_agent_dict) # type: ignore
|
||||
except Exception as e:
|
||||
raise Exception(f"Error adding agent to DB: {str(e)}")
|
||||
|
||||
@ -181,7 +203,7 @@ class AgentRegistry:
|
||||
raise Exception(f"Agent with ID {agent_id} not found")
|
||||
|
||||
augment_agent = {**existing_agent, **agent}
|
||||
update_data = {}
|
||||
update_data: Dict[str, Any] = {}
|
||||
if augment_agent.get("agent_name"):
|
||||
update_data["agent_name"] = augment_agent.get("agent_name")
|
||||
if augment_agent.get("litellm_params"):
|
||||
@ -192,6 +214,20 @@ class AgentRegistry:
|
||||
update_data["agent_card_params"] = safe_dumps(
|
||||
augment_agent.get("agent_card_params")
|
||||
)
|
||||
if agent.get("object_permission") is not None:
|
||||
agent_copy = dict(augment_agent)
|
||||
existing_object_permission_id = existing_agent.get(
|
||||
"object_permission_id"
|
||||
)
|
||||
object_permission_id = (
|
||||
await handle_update_object_permission_common(
|
||||
agent_copy,
|
||||
existing_object_permission_id,
|
||||
prisma_client,
|
||||
)
|
||||
)
|
||||
if object_permission_id is not None:
|
||||
update_data["object_permission_id"] = object_permission_id
|
||||
# Patch agent in DB
|
||||
patched_agent = await prisma_client.db.litellm_agentstable.update(
|
||||
where={"agent_id": agent_id},
|
||||
@ -200,8 +236,15 @@ class AgentRegistry:
|
||||
"updated_by": updated_by,
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
},
|
||||
include={"object_permission": True},
|
||||
)
|
||||
return AgentResponse(**patched_agent.model_dump()) # type: ignore
|
||||
patched_agent_dict = patched_agent.model_dump()
|
||||
if patched_agent.object_permission is not None:
|
||||
try:
|
||||
patched_agent_dict["object_permission"] = patched_agent.object_permission.model_dump()
|
||||
except Exception:
|
||||
patched_agent_dict["object_permission"] = patched_agent.object_permission.dict()
|
||||
return AgentResponse(**patched_agent_dict) # type: ignore
|
||||
except Exception as e:
|
||||
raise Exception(f"Error patching agent in DB: {str(e)}")
|
||||
|
||||
@ -238,19 +281,47 @@ class AgentRegistry:
|
||||
)
|
||||
agent_card_params: str = safe_dumps(agent_card_params_dict)
|
||||
|
||||
update_data: Dict[str, Any] = {
|
||||
"agent_name": agent_name,
|
||||
"litellm_params": litellm_params,
|
||||
"agent_card_params": agent_card_params,
|
||||
"updated_by": updated_by,
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
}
|
||||
if agent.get("object_permission") is not None:
|
||||
existing_agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
)
|
||||
existing_object_permission_id = (
|
||||
existing_agent.object_permission_id
|
||||
if existing_agent is not None
|
||||
else None
|
||||
)
|
||||
agent_copy = dict(agent)
|
||||
object_permission_id = (
|
||||
await handle_update_object_permission_common(
|
||||
agent_copy,
|
||||
existing_object_permission_id,
|
||||
prisma_client,
|
||||
)
|
||||
)
|
||||
if object_permission_id is not None:
|
||||
update_data["object_permission_id"] = object_permission_id
|
||||
|
||||
# Update agent in DB
|
||||
updated_agent = await prisma_client.db.litellm_agentstable.update(
|
||||
where={"agent_id": agent_id},
|
||||
data={
|
||||
"agent_name": agent_name,
|
||||
"litellm_params": litellm_params,
|
||||
"agent_card_params": agent_card_params,
|
||||
"updated_by": updated_by,
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
},
|
||||
data=update_data,
|
||||
include={"object_permission": True},
|
||||
)
|
||||
|
||||
return AgentResponse(**updated_agent.model_dump()) # type: ignore
|
||||
updated_agent_dict = updated_agent.model_dump()
|
||||
if updated_agent.object_permission is not None:
|
||||
try:
|
||||
updated_agent_dict["object_permission"] = updated_agent.object_permission.model_dump()
|
||||
except Exception:
|
||||
updated_agent_dict["object_permission"] = updated_agent.object_permission.dict()
|
||||
return AgentResponse(**updated_agent_dict) # type: ignore
|
||||
except Exception as e:
|
||||
raise Exception(f"Error updating agent in DB: {str(e)}")
|
||||
|
||||
@ -264,11 +335,19 @@ class AgentRegistry:
|
||||
try:
|
||||
agents_from_db = await prisma_client.db.litellm_agentstable.find_many(
|
||||
order={"created_at": "desc"},
|
||||
include={"object_permission": True},
|
||||
)
|
||||
|
||||
agents: List[Dict[str, Any]] = []
|
||||
for agent in agents_from_db:
|
||||
agents.append(dict(agent))
|
||||
agent_dict = dict(agent)
|
||||
# object_permission is eagerly loaded via include above
|
||||
if agent.object_permission is not None:
|
||||
try:
|
||||
agent_dict["object_permission"] = agent.object_permission.model_dump()
|
||||
except Exception:
|
||||
agent_dict["object_permission"] = agent.object_permission.dict()
|
||||
agents.append(agent_dict)
|
||||
|
||||
return agents
|
||||
except Exception as e:
|
||||
|
||||
@ -16,6 +16,7 @@ import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import CommonProxyErrors, LitellmUserRoles, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity
|
||||
from litellm.types.agents import (
|
||||
AgentConfig,
|
||||
AgentMakePublicResponse,
|
||||
@ -23,8 +24,6 @@ from litellm.types.agents import (
|
||||
MakeAgentsPublicRequest,
|
||||
PatchAgentRequest,
|
||||
)
|
||||
|
||||
from litellm.proxy.management_endpoints.common_daily_activity import get_daily_activity
|
||||
from litellm.types.proxy.management_endpoints.common_daily_activity import (
|
||||
SpendAnalyticsPaginatedResponse,
|
||||
)
|
||||
@ -233,11 +232,18 @@ async def get_agent_by_id(agent_id: str):
|
||||
try:
|
||||
agent = AGENT_REGISTRY.get_agent_by_id(agent_id=agent_id)
|
||||
if agent is None:
|
||||
agent = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
where={"agent_id": agent_id}
|
||||
agent_row = await prisma_client.db.litellm_agentstable.find_unique(
|
||||
where={"agent_id": agent_id},
|
||||
include={"object_permission": True},
|
||||
)
|
||||
if agent is not None:
|
||||
agent = AgentResponse(**agent.model_dump()) # type: ignore
|
||||
if agent_row is not None:
|
||||
agent_dict = agent_row.model_dump()
|
||||
if agent_row.object_permission is not None:
|
||||
try:
|
||||
agent_dict["object_permission"] = agent_row.object_permission.model_dump()
|
||||
except Exception:
|
||||
agent_dict["object_permission"] = agent_row.object_permission.dict()
|
||||
agent = AgentResponse(**agent_dict) # type: ignore
|
||||
|
||||
if agent is None:
|
||||
raise HTTPException(
|
||||
|
||||
208
litellm/proxy/hooks/max_iterations_limiter.py
Normal file
208
litellm/proxy/hooks/max_iterations_limiter.py
Normal file
@ -0,0 +1,208 @@
|
||||
"""
|
||||
Max Iterations Limiter for LiteLLM Proxy.
|
||||
|
||||
Enforces a per-session cap on the number of LLM calls an agentic loop can make.
|
||||
Callers send a `session_id` with each request (via `x-litellm-session-id` header
|
||||
or `metadata.session_id`), and this hook counts calls per session. When the count
|
||||
exceeds `max_iterations` (configured in key metadata), the hook returns 429.
|
||||
|
||||
Works across multiple proxy instances via DualCache (in-memory + Redis).
|
||||
Follows the same pattern as parallel_request_limiter_v3.py.
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm import DualCache
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.utils import InternalUsageCache as _InternalUsageCache
|
||||
|
||||
InternalUsageCache = _InternalUsageCache
|
||||
else:
|
||||
InternalUsageCache = Any
|
||||
|
||||
|
||||
# Redis Lua script for atomic increment with TTL.
|
||||
# Returns the new count after increment.
|
||||
# Only sets EXPIRE on first increment (when count becomes 1).
# Later increments therefore do not extend the key's lifetime.
# Inputs: KEYS[1] is the counter key, ARGV[1] is the TTL passed to EXPIRE.
|
||||
MAX_ITERATIONS_INCREMENT_SCRIPT = """
|
||||
local key = KEYS[1]
|
||||
local ttl = tonumber(ARGV[1])
|
||||
|
||||
local current = redis.call('INCR', key)
|
||||
if current == 1 then
|
||||
redis.call('EXPIRE', key, ttl)
|
||||
end
|
||||
|
||||
return current
|
||||
"""
|
||||
|
||||
# Default TTL for session iteration counters (1 hour)
# Used by the handler when LITELLM_MAX_ITERATIONS_TTL is not set in the environment.
|
||||
DEFAULT_MAX_ITERATIONS_TTL = 3600
|
||||
|
||||
|
||||
class _PROXY_MaxIterationsHandler(CustomLogger):
    """
    Pre-call hook that enforces max_iterations per session.

    Requests without a session_id, and keys without a usable max_iterations,
    pass through without being counted.

    Configuration:
        - max_iterations: set in key metadata via /key/generate or /key/update
          e.g. metadata={"max_iterations": 25}. A value that cannot be
          converted to an int is ignored with a warning (no limit applied).
        - session_id: read from metadata.session_id, then
          litellm_metadata.session_id, in the request body.
          NOTE(review): the x-litellm-session-id header is not read in this
          class; presumably it is copied into request metadata upstream -
          confirm.

    Cache key pattern:
        {session_iterations:<session_id>}:count

    Multi-instance support:
        Uses Redis Lua script for atomic increment (same pattern as
        parallel_request_limiter_v3). Falls back to in-memory cache
        when Redis is unavailable.
    """

    def __init__(self, internal_usage_cache: InternalUsageCache):
        """
        Store the cache, read the counter TTL and register the Lua script.

        Args:
            internal_usage_cache: Cache wrapper; its `dual_cache.redis_cache`
                decides whether the Redis script path is available.

        Raises:
            ValueError: if LITELLM_MAX_ITERATIONS_TTL is set to a non-integer.
        """
        self.internal_usage_cache = internal_usage_cache
        # os.getenv yields a str when the variable is set, so the default is a str too.
        self.ttl = int(
            os.getenv("LITELLM_MAX_ITERATIONS_TTL", str(DEFAULT_MAX_ITERATIONS_TTL))
        )

        # Register Lua script with Redis if available (same pattern as v3 limiter)
        redis_cache = self.internal_usage_cache.dual_cache.redis_cache
        if redis_cache is not None:
            self.increment_script = redis_cache.async_register_script(
                MAX_ITERATIONS_INCREMENT_SCRIPT
            )
        else:
            self.increment_script = None

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ) -> Optional[Union[Exception, str, dict]]:
        """
        Check session iteration count before making the API call.

        Extracts session_id from request metadata and max_iterations from
        key metadata. The session counter is incremented on every call that
        has both, including calls that end up rejected.

        Returns:
            None in all non-error cases (the request is left untouched).

        Raises:
            HTTPException: 429 when the incremented count exceeds max_iterations.
        """
        # Extract session_id from request data
        session_id = self._get_session_id(data)
        if session_id is None:
            return None

        # Extract max_iterations from key metadata
        max_iterations = self._get_max_iterations(user_api_key_dict)
        if max_iterations is None:
            return None

        verbose_proxy_logger.debug(
            "MaxIterationsHandler: session_id=%s, max_iterations=%s",
            session_id,
            max_iterations,
        )

        # Increment and check
        cache_key = self._make_cache_key(session_id)
        current_count = await self._increment_and_get(cache_key)

        if current_count > max_iterations:
            raise HTTPException(
                status_code=429,
                detail=(
                    f"Max iterations exceeded for session {session_id}. "
                    f"Current count: {current_count}, max_iterations: {max_iterations}."
                ),
            )

        verbose_proxy_logger.debug(
            "MaxIterationsHandler: session_id=%s, count=%s/%s",
            session_id,
            current_count,
            max_iterations,
        )

        return None

    def _get_session_id(self, data: dict) -> Optional[str]:
        """
        Extract session_id from request metadata.

        Checks `metadata` first, then `litellm_metadata` (used for /thread and
        /assistant endpoints). A container that has no callable `.get`
        (for example a plain string) is skipped instead of raising.

        Returns:
            The session id as a string, or None when neither location has one.
        """
        for field_name in ("metadata", "litellm_metadata"):
            container = data.get(field_name)
            if not container:
                continue
            getter = getattr(container, "get", None)
            if not callable(getter):
                # Not dict-like: nothing to read a session_id from.
                continue
            session_id = getter("session_id")
            if session_id is not None:
                return str(session_id)

        return None

    def _get_max_iterations(
        self, user_api_key_dict: UserAPIKeyAuth
    ) -> Optional[int]:
        """
        Extract max_iterations from key metadata.

        Returns:
            The limit as an int, or None when it is absent or cannot be
            converted to an int (the latter is logged as a warning so a
            malformed value does not fail the request).
        """
        metadata = user_api_key_dict.metadata or {}
        raw_limit = metadata.get("max_iterations")
        if raw_limit is None:
            return None
        try:
            return int(raw_limit)
        except (TypeError, ValueError, OverflowError):
            verbose_proxy_logger.warning(
                "MaxIterationsHandler: ignoring invalid max_iterations=%r in key metadata",
                raw_limit,
            )
            return None

    def _make_cache_key(self, session_id: str) -> str:
        """
        Create cache key for session iteration counter.

        Uses Redis hash-tag pattern {session_iterations:<session_id>} so all
        keys for a session land on the same Redis Cluster slot.
        """
        return f"{{session_iterations:{session_id}}}:count"

    async def _increment_and_get(self, cache_key: str) -> int:
        """
        Atomically increment the session counter and return the new value.

        Tries Redis first (via registered Lua script for atomicity across
        instances), falls back to in-memory cache. The fallback reads only the
        local cache, so after a Redis failure the count continues from whatever
        this process holds locally.
        """
        if self.increment_script is not None:
            try:
                result = await self.increment_script(
                    keys=[cache_key],
                    args=[self.ttl],
                )
                return int(result)
            except Exception as e:
                verbose_proxy_logger.warning(
                    "MaxIterationsHandler: Redis failed, falling back to in-memory: %s",
                    str(e),
                )

        # Fallback: in-memory cache
        return await self._in_memory_increment(cache_key)

    async def _in_memory_increment(self, cache_key: str) -> int:
        """
        Increment counter in in-memory cache with TTL.

        This is a read followed by a write with awaits in between, so it is
        not atomic with respect to other coroutines touching the same key.
        NOTE(review): the TTL is passed on every write; this presumably resets
        the expiry on each increment, unlike the Redis path - confirm against
        the cache implementation.
        """
        current = await self.internal_usage_cache.async_get_cache(
            key=cache_key,
            litellm_parent_otel_span=None,
            local_only=True,
        )
        new_value = (int(current) if current is not None else 0) + 1
        await self.internal_usage_cache.async_set_cache(
            key=cache_key,
            value=new_value,
            ttl=self.ttl,
            litellm_parent_otel_span=None,
            local_only=True,
        )
        return new_value
|
||||
@ -167,16 +167,25 @@ class AugmentedAgentCard(AgentCard):
|
||||
is_public: bool
|
||||
|
||||
|
||||
# Object permission shape for agent MCP tool access (mirrors LiteLLM_ObjectPermissionBase)
class AgentObjectPermission(TypedDict, total=False):
    """
    MCP access settings attached to an agent.

    Every key is optional (``total=False``) and each value may also be None.
    """

    # MCP server identifiers the agent may use
    mcp_servers: Optional[List[str]]
    # MCP access group names granted to the agent
    mcp_access_groups: Optional[List[str]]
    # Per-server tool allow-list, e.g. {"server_1": ["tool_a", "tool_b"]}
    mcp_tool_permissions: Optional[Dict[str, List[str]]]
|
||||
|
||||
|
||||
class AgentConfig(TypedDict, total=False):
    """
    Configuration for a single agent.

    ``agent_name`` and ``agent_card_params`` are marked ``Required``; the
    remaining keys are optional because the class is declared ``total=False``.
    """

    agent_name: Required[str]
    agent_card_params: Required[AgentCard]
    litellm_params: Dict[str, Any]  # allow for any future litellm params
    # MCP server / tool access granted to the agent
    object_permission: AgentObjectPermission
|
||||
|
||||
|
||||
class PatchAgentRequest(TypedDict, total=False):
    """
    Partial-update payload for an agent.

    The class is ``total=False``, so every key is optional.
    """

    agent_name: str
    agent_card_params: AgentCard
    litellm_params: Dict[str, Any]
    # MCP server / tool access for the agent
    object_permission: AgentObjectPermission
|
||||
|
||||
|
||||
# Request/Response models for CRUD endpoints
|
||||
@ -187,6 +196,7 @@ class AgentResponse(BaseModel):
|
||||
agent_name: str
|
||||
litellm_params: Optional[Dict[str, Any]] = None
|
||||
agent_card_params: Dict[str, Any]
|
||||
object_permission: Optional[Dict[str, Any]] = None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
created_by: Optional[str] = None
|
||||
|
||||
@ -64,6 +64,8 @@ model LiteLLM_AgentsTable {
|
||||
litellm_params Json?
|
||||
agent_card_params Json
|
||||
agent_access_groups String[] @default([])
|
||||
object_permission_id String?
|
||||
object_permission LiteLLM_ObjectPermissionTable? @relation(fields: [object_permission_id], references: [object_permission_id])
|
||||
created_at DateTime @default(now()) @map("created_at")
|
||||
created_by String
|
||||
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||
@ -264,6 +266,7 @@ model LiteLLM_ObjectPermissionTable {
|
||||
organizations LiteLLM_OrganizationTable[]
|
||||
users LiteLLM_UserTable[]
|
||||
end_users LiteLLM_EndUserTable[]
|
||||
agents_table LiteLLM_AgentsTable[]
|
||||
}
|
||||
|
||||
// Holds the MCP server configuration
|
||||
@ -273,7 +276,6 @@ model LiteLLM_MCPServerTable {
|
||||
alias String?
|
||||
description String?
|
||||
url String?
|
||||
spec_path String?
|
||||
transport String @default("sse")
|
||||
auth_type String?
|
||||
credentials Json? @default("{}")
|
||||
@ -315,6 +317,7 @@ model LiteLLM_VerificationToken {
|
||||
router_settings Json? @default("{}")
|
||||
user_id String?
|
||||
team_id String?
|
||||
agent_id String?
|
||||
project_id String?
|
||||
permissions Json @default("{}")
|
||||
max_parallel_requests Int?
|
||||
@ -1052,6 +1055,26 @@ model LiteLLM_PolicyAttachmentTable {
|
||||
updated_by String?
|
||||
}
|
||||
|
||||
// Global tool registry - auto-discovered from LLM responses; admins set call_policy here
model LiteLLM_ToolTable {
  tool_id     String   @id @default(uuid())
  tool_name   String   @unique // e.g. "huggingface_remote-mcp__dynamic_space"
  origin      String? // MCP server name or "user_defined"
  call_policy String   @default("untrusted") // "trusted" | "untrusted" | "dual_llm" | "blocked"
  call_count  Int      @default(0) // cumulative number of times this tool was seen
  assignments Json?    @default("{}")
  key_hash    String? // hash of the virtual key that first called this tool
  team_id     String? // team that first called this tool
  key_alias   String? // human-readable alias of the virtual key
  created_at  DateTime @default(now())
  created_by  String?
  updated_at  DateTime @default(now()) @updatedAt
  updated_by  String?

  @@index([call_policy])
  @@index([team_id])
}
|
||||
|
||||
//Unified Access Groups table for storing unified access groups
|
||||
model LiteLLM_AccessGroupTable {
|
||||
access_group_id String @id @default(uuid())
|
||||
|
||||
186
scripts/test_agent_mcp_endpoints.sh
Executable file
186
scripts/test_agent_mcp_endpoints.sh
Executable file
@ -0,0 +1,186 @@
|
||||
#!/usr/bin/env bash
#
# Test agent endpoint-level changes for MCP tool permissions (object_permission).
# Requires: proxy running, valid admin API key, curl, jq.
#
# Usage:
#   export LITELLM_PROXY_BASE_URL="http://localhost:4000"  # optional, default below
#   export LITELLM_API_KEY="sk-..."                        # required
#   ./scripts/test_agent_mcp_endpoints.sh
#
# Agents created by this script are deleted on exit, including when a check
# fails part-way through.
#
set -euo pipefail

BASE_URL="${LITELLM_PROXY_BASE_URL:-http://localhost:4000}"
API_KEY="${LITELLM_API_KEY:-}"

if ! command -v jq &>/dev/null; then
  echo "Error: jq is required. Install with: brew install jq (macOS) or apt install jq (Linux)"
  exit 1
fi
if [[ -z "$API_KEY" ]]; then
  echo "Error: LITELLM_API_KEY is not set. Export it or pass via env."
  exit 1
fi

AUTH_HEADER="Authorization: Bearer $API_KEY"
AGENT_NAME="test-agent-mcp-$(date +%s)"

# IDs of the agents created by this run. They stay empty until the matching
# POST succeeds, so cleanup only ever deletes what this run created.
AGENT_ID=""
AGENT_ID_2=""

# Colors for output
RED='\033[0;31m'
GREEN='\033[0;32m'
YELLOW='\033[1;33m'
NC='\033[0m'

pass() { echo -e "${GREEN}PASS${NC}: $*"; }
fail() { echo -e "${RED}FAIL${NC}: $*"; exit 1; }
info() { echo -e "${YELLOW}INFO${NC}: $*"; }

# Best-effort deletion of every agent created so far. Unset or "null" ids are
# skipped, and a failed DELETE is reported but never aborts the script.
cleanup() {
  local aid del_code
  for aid in "$AGENT_ID" "$AGENT_ID_2"; do
    if [[ -z "$aid" || "$aid" == "null" ]]; then
      continue
    fi
    del_code=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE_URL/v1/agents/$aid" -H "$AUTH_HEADER") || true
    if [[ "$del_code" != "200" ]]; then
      info "DELETE /v1/agents/$aid returned $del_code (non-fatal)"
    fi
  done
}
# fail() exits immediately; the EXIT trap makes sure test agents are still removed.
trap cleanup EXIT

# --- 1. Create agent with object_permission ---
info "Creating agent with object_permission (mcp_servers, mcp_tool_permissions)..."
CREATE_RESP=$(curl -s -w "\n%{http_code}" -X POST "$BASE_URL/v1/agents" \
  -H "$AUTH_HEADER" \
  -H "Content-Type: application/json" \
  -d '{
    "agent_name": "'"$AGENT_NAME"'",
    "agent_card_params": {
      "protocolVersion": "1.0",
      "name": "Test MCP Agent",
      "description": "Agent for endpoint tests",
      "url": "http://localhost:9999/",
      "version": "1.0.0",
      "defaultInputModes": ["text"],
      "defaultOutputModes": ["text"],
      "capabilities": {"streaming": true},
      "skills": []
    },
    "object_permission": {
      "mcp_servers": ["server_1", "server_2"],
      "mcp_access_groups": ["group_a"],
      "mcp_tool_permissions": {"server_1": ["tool_a", "tool_b"], "server_2": ["tool_c"]}
    }
  }')
# The last line of the curl output is the HTTP status; everything before it is the body.
HTTP_CODE=$(echo "$CREATE_RESP" | tail -n1)
BODY=$(echo "$CREATE_RESP" | sed '$d')
if [[ "$HTTP_CODE" != "200" ]]; then
  fail "POST /v1/agents returned $HTTP_CODE. Body: $BODY"
fi
AGENT_ID=$(echo "$BODY" | jq -r '.agent_id')
if [[ -z "$AGENT_ID" || "$AGENT_ID" == "null" ]]; then
  fail "POST /v1/agents did not return agent_id. Body: $BODY"
fi
pass "Created agent $AGENT_ID"

# Check create response includes object_permission
OP=$(echo "$BODY" | jq '.object_permission')
if [[ "$OP" == "null" || -z "$OP" ]]; then
  fail "POST /v1/agents response missing object_permission. Body: $BODY"
fi
SERVERS=$(echo "$OP" | jq -r '.mcp_servers | join(",")')
if [[ "$SERVERS" != "server_1,server_2" ]]; then
  fail "object_permission.mcp_servers unexpected: $SERVERS"
fi
pass "Create response includes object_permission with mcp_servers and mcp_tool_permissions"

# --- 2. GET /v1/agents (list) includes object_permission for our agent ---
info "GET /v1/agents and check one agent has object_permission..."
LIST_RESP=$(curl -s -w "\n%{http_code}" -X GET "$BASE_URL/v1/agents" -H "$AUTH_HEADER")
LIST_CODE=$(echo "$LIST_RESP" | tail -n1)
LIST_BODY=$(echo "$LIST_RESP" | sed '$d')
if [[ "$LIST_CODE" != "200" ]]; then
  fail "GET /v1/agents returned $LIST_CODE"
fi
AGENT_IN_LIST=$(echo "$LIST_BODY" | jq --arg id "$AGENT_ID" '.[] | select(.agent_id == $id)')
if [[ -z "$AGENT_IN_LIST" ]]; then
  fail "GET /v1/agents did not return agent $AGENT_ID (list might be key-scoped)"
fi
OP_LIST=$(echo "$AGENT_IN_LIST" | jq '.object_permission')
if [[ "$OP_LIST" == "null" || -z "$OP_LIST" ]]; then
  fail "GET /v1/agents list entry for agent missing object_permission"
fi
pass "GET /v1/agents list includes object_permission for agent"

# --- 3. GET /v1/agents/{agent_id} returns object_permission ---
info "GET /v1/agents/{agent_id}..."
GET_RESP=$(curl -s -w "\n%{http_code}" -X GET "$BASE_URL/v1/agents/$AGENT_ID" -H "$AUTH_HEADER")
GET_CODE=$(echo "$GET_RESP" | tail -n1)
GET_BODY=$(echo "$GET_RESP" | sed '$d')
if [[ "$GET_CODE" != "200" ]]; then
  fail "GET /v1/agents/$AGENT_ID returned $GET_CODE. Body: $GET_BODY"
fi
OP_GET=$(echo "$GET_BODY" | jq '.object_permission')
if [[ "$OP_GET" == "null" || -z "$OP_GET" ]]; then
  fail "GET /v1/agents/$AGENT_ID response missing object_permission"
fi
TOOL_PERMS=$(echo "$OP_GET" | jq -r '.mcp_tool_permissions.server_1 | join(",")')
if [[ "$TOOL_PERMS" != "tool_a,tool_b" ]]; then
  fail "object_permission.mcp_tool_permissions.server_1 unexpected: $TOOL_PERMS"
fi
pass "GET /v1/agents/{agent_id} returns object_permission with mcp_tool_permissions"

# --- 4. PATCH /v1/agents/{agent_id} with new object_permission ---
info "PATCH /v1/agents/{agent_id} with updated object_permission..."
PATCH_RESP=$(curl -s -w "\n%{http_code}" -X PATCH "$BASE_URL/v1/agents/$AGENT_ID" \
  -H "$AUTH_HEADER" \
  -H "Content-Type: application/json" \
  -d '{
    "object_permission": {
      "mcp_servers": ["server_3"],
      "mcp_tool_permissions": {"server_3": ["tool_x"]}
    }
  }')
PATCH_CODE=$(echo "$PATCH_RESP" | tail -n1)
PATCH_BODY=$(echo "$PATCH_RESP" | sed '$d')
if [[ "$PATCH_CODE" != "200" ]]; then
  fail "PATCH /v1/agents/$AGENT_ID returned $PATCH_CODE. Body: $PATCH_BODY"
fi
OP_PATCH=$(echo "$PATCH_BODY" | jq '.object_permission')
if [[ "$OP_PATCH" == "null" || -z "$OP_PATCH" ]]; then
  fail "PATCH response missing object_permission"
fi
PATCH_SERVERS=$(echo "$OP_PATCH" | jq -r '.mcp_servers | join(",")')
if [[ "$PATCH_SERVERS" != "server_3" ]]; then
  fail "PATCH object_permission.mcp_servers unexpected: $PATCH_SERVERS"
fi
pass "PATCH /v1/agents/{agent_id} updates and returns object_permission"

# --- 5. Create agent without object_permission; GET should still work ---
info "Creating agent without object_permission..."
AGENT_NAME_2="test-agent-no-mcp-$(date +%s)"
CREATE2_RESP=$(curl -s -w "\n%{http_code}" -X POST "$BASE_URL/v1/agents" \
  -H "$AUTH_HEADER" \
  -H "Content-Type: application/json" \
  -d '{
    "agent_name": "'"$AGENT_NAME_2"'",
    "agent_card_params": {
      "protocolVersion": "1.0",
      "name": "No MCP Agent",
      "description": "No object_permission",
      "url": "http://localhost:9999/",
      "version": "1.0.0",
      "defaultInputModes": ["text"],
      "defaultOutputModes": ["text"],
      "capabilities": {},
      "skills": []
    }
  }')
CODE2=$(echo "$CREATE2_RESP" | tail -n1)
BODY2=$(echo "$CREATE2_RESP" | sed '$d')
if [[ "$CODE2" != "200" ]]; then
  fail "POST /v1/agents (no object_permission) returned $CODE2. Body: $BODY2"
fi
AGENT_ID_2=$(echo "$BODY2" | jq -r '.agent_id')
# Same sanity check as for the first agent: a 200 without an id is a failure.
if [[ -z "$AGENT_ID_2" || "$AGENT_ID_2" == "null" ]]; then
  fail "POST /v1/agents (no object_permission) did not return agent_id. Body: $BODY2"
fi
# object_permission may be null or absent
pass "Created agent without object_permission: $AGENT_ID_2"

# --- 6. Cleanup: delete both agents ---
info "Deleting test agents..."
cleanup
# Agents are already gone; do not run cleanup a second time on exit.
trap - EXIT
pass "Cleanup done"

echo ""
echo -e "${GREEN}All endpoint checks passed.${NC}"
|
||||
@ -1595,3 +1595,146 @@ async def test_get_allowed_mcp_servers_for_key_prefers_in_memory_permission():
|
||||
assert set(result) == {"direct-server", "group-server"}
|
||||
mock_get_perm.assert_not_called()
|
||||
mock_access_groups.assert_called_once_with(["grp-alpha"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
class TestAgentMCPPermissions:
    """Agent-level MCP server and tool permissions are intersected with key/team permissions."""

    async def test_get_allowed_mcp_servers_agent_intersection(self):
        """Key/team allow [server_1, server_2]; agent allows [server_1]. Result = [server_1]."""
        auth = UserAPIKeyAuth(
            api_key="test-key",
            user_id="test-user",
            team_id="test-team",
            agent_id="agent-123",
        )
        key_patch = patch.object(
            MCPRequestHandler,
            "_get_allowed_mcp_servers_for_key",
            return_value=["server_1", "server_2"],
        )
        team_patch = patch.object(
            MCPRequestHandler,
            "_get_allowed_mcp_servers_for_team",
            return_value=[],
        )
        agent_patch = patch.object(
            MCPRequestHandler,
            "_get_allowed_mcp_servers_for_agent",
            return_value=["server_1"],
        )
        with key_patch, team_patch, agent_patch as agent_lookup:
            allowed = await MCPRequestHandler.get_allowed_mcp_servers(
                user_api_key_auth=auth
            )
            assert sorted(allowed) == ["server_1"]
            agent_lookup.assert_called_once_with(auth)

    async def test_get_allowed_mcp_servers_agent_no_restriction(self):
        """Agent with no object_permission returns []; no intersection applied (inherit key/team)."""
        auth = UserAPIKeyAuth(
            api_key="test-key",
            user_id="test-user",
            agent_id="agent-456",
        )
        key_patch = patch.object(
            MCPRequestHandler,
            "_get_allowed_mcp_servers_for_key",
            return_value=["server_1", "server_2"],
        )
        team_patch = patch.object(
            MCPRequestHandler,
            "_get_allowed_mcp_servers_for_team",
            return_value=[],
        )
        # An empty list from the agent lookup means "no agent-level restriction".
        agent_patch = patch.object(
            MCPRequestHandler,
            "_get_allowed_mcp_servers_for_agent",
            return_value=[],
        )
        with key_patch, team_patch, agent_patch as agent_lookup:
            allowed = await MCPRequestHandler.get_allowed_mcp_servers(
                user_api_key_auth=auth
            )
            assert sorted(allowed) == ["server_1", "server_2"]
            agent_lookup.assert_called_once_with(auth)

    async def test_get_allowed_mcp_servers_key_team_agent_intersection(self):
        """Key allows [1, 2], agent allows [2, 3]. Result = [2]."""
        auth = UserAPIKeyAuth(
            api_key="test-key",
            user_id="test-user",
            agent_id="agent-789",
        )
        key_patch = patch.object(
            MCPRequestHandler,
            "_get_allowed_mcp_servers_for_key",
            return_value=["server_1", "server_2"],
        )
        team_patch = patch.object(
            MCPRequestHandler,
            "_get_allowed_mcp_servers_for_team",
            return_value=[],
        )
        agent_patch = patch.object(
            MCPRequestHandler,
            "_get_allowed_mcp_servers_for_agent",
            return_value=["server_2", "server_3"],
        )
        with key_patch, team_patch, agent_patch:
            allowed = await MCPRequestHandler.get_allowed_mcp_servers(
                user_api_key_auth=auth
            )
            assert sorted(allowed) == ["server_2"]

    async def test_get_allowed_tools_for_server_agent_intersection(self):
        """Key allows [tool_a, tool_b], agent allows [tool_a]. Result = [tool_a]."""
        auth = UserAPIKeyAuth(
            api_key="test-key",
            user_id="test-user",
            agent_id="agent-tools",
        )
        key_perm = MagicMock()
        key_perm.mcp_tool_permissions = {"server_1": ["tool_a", "tool_b"]}
        team_perm = None

        key_perm_patch = patch.object(
            MCPRequestHandler, "_get_key_object_permission", return_value=key_perm
        )
        team_perm_patch = patch.object(
            MCPRequestHandler,
            "_get_team_object_permission",
            new_callable=AsyncMock,
            return_value=team_perm,
        )
        agent_tools_patch = patch.object(
            MCPRequestHandler,
            "_get_agent_tool_permissions_for_server",
            new_callable=AsyncMock,
            return_value=["tool_a"],
        )
        with key_perm_patch, team_perm_patch, agent_tools_patch as agent_tools_lookup:
            allowed_tools = await MCPRequestHandler.get_allowed_tools_for_server(
                server_id="server_1",
                user_api_key_auth=auth,
            )
            assert allowed_tools == ["tool_a"]
            agent_tools_lookup.assert_called_once()
            lookup_kwargs = agent_tools_lookup.call_args.kwargs
            assert lookup_kwargs["server_id"] == "server_1"
            assert lookup_kwargs["user_api_key_auth"] == auth

    async def test_get_allowed_tools_for_server_agent_no_restriction(self):
        """Agent has no tool permissions for server; key/team result is unchanged."""
        auth = UserAPIKeyAuth(
            api_key="test-key",
            user_id="test-user",
            agent_id="agent-no-tools",
        )
        key_perm = MagicMock()
        key_perm.mcp_tool_permissions = {"server_1": ["tool_a", "tool_b"]}

        key_perm_patch = patch.object(
            MCPRequestHandler, "_get_key_object_permission", return_value=key_perm
        )
        team_perm_patch = patch.object(
            MCPRequestHandler,
            "_get_team_object_permission",
            new_callable=AsyncMock,
            return_value=None,
        )
        agent_tools_patch = patch.object(
            MCPRequestHandler,
            "_get_agent_tool_permissions_for_server",
            new_callable=AsyncMock,
            return_value=None,
        )
        with key_perm_patch, team_perm_patch, agent_tools_patch:
            allowed_tools = await MCPRequestHandler.get_allowed_tools_for_server(
                server_id="server_1",
                user_api_key_auth=auth,
            )
            assert sorted(allowed_tools) == ["tool_a", "tool_b"]
|
||||
|
||||
106
tests/test_litellm/proxy/hooks/test_max_iterations_limiter.py
Normal file
106
tests/test_litellm/proxy/hooks/test_max_iterations_limiter.py
Normal file
@ -0,0 +1,106 @@
|
||||
"""
|
||||
Unit Tests for the max iterations limiter for the proxy.
|
||||
|
||||
Tests that session-scoped iteration counting works correctly:
|
||||
- Enforces max_iterations per session_id
|
||||
- Different sessions have independent counters
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm.caching.caching import DualCache
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.hooks.max_iterations_limiter import _PROXY_MaxIterationsHandler
|
||||
from litellm.proxy.utils import InternalUsageCache
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_max_iterations_basic_enforcement():
    """
    max_iterations is enforced per session_id.

    - With max_iterations=3, three requests on the same session_id succeed
    - The fourth request is rejected with a 429
    """
    dual_cache = DualCache()
    limiter = _PROXY_MaxIterationsHandler(
        internal_usage_cache=InternalUsageCache(dual_cache),
    )
    key_auth = UserAPIKeyAuth(
        api_key="sk-test-key-1234", metadata={"max_iterations": 3}
    )

    async def call_hook():
        # A fresh request body is built for every call.
        await limiter.async_pre_call_hook(
            user_api_key_dict=key_auth,
            cache=dual_cache,
            data={"metadata": {"session_id": "session-abc"}},
            call_type="",
        )

    # Requests 1-3 stay within the budget.
    for _ in range(3):
        await call_hook()

    # Request 4 exceeds it.
    with pytest.raises(HTTPException) as exc_info:
        await call_hook()
    assert exc_info.value.status_code == 429
    assert "max_iterations" in str(exc_info.value.detail).lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_max_iterations_different_sessions_independent():
    """
    Each session_id has its own iteration counter.

    - Session A and Session B each get a full max_iterations budget
    - Using up Session A's budget leaves Session B's untouched
    """
    dual_cache = DualCache()
    limiter = _PROXY_MaxIterationsHandler(
        internal_usage_cache=InternalUsageCache(dual_cache),
    )
    key_auth = UserAPIKeyAuth(
        api_key="sk-test-key-5678", metadata={"max_iterations": 2}
    )

    async def call_hook(session_id):
        await limiter.async_pre_call_hook(
            user_api_key_dict=key_auth,
            cache=dual_cache,
            data={"metadata": {"session_id": session_id}},
            call_type="",
        )

    # Two calls on session A, then two on session B: all within budget.
    for session_id in ("session-A", "session-B"):
        for _ in range(2):
            await call_hook(session_id)

    # A third call on session A is rejected.
    with pytest.raises(HTTPException) as exc_info:
        await call_hook("session-A")
    assert exc_info.value.status_code == 429

    # A third call on session B is rejected as well.
    with pytest.raises(HTTPException):
        await call_hook("session-B")
|
||||
15
ui/litellm-dashboard/package-lock.json
generated
15
ui/litellm-dashboard/package-lock.json
generated
@ -13056,6 +13056,21 @@
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/wooorm"
|
||||
}
|
||||
},
|
||||
"node_modules/@next/swc-win32-ia32-msvc": {
|
||||
"version": "14.2.33",
|
||||
"resolved": "https://registry.npmjs.org/@next/swc-win32-ia32-msvc/-/swc-win32-ia32-msvc-14.2.33.tgz",
|
||||
"integrity": "sha512-pc9LpGNKhJ0dXQhZ5QMmYxtARwwmWLpeocFmVG5Z0DzWq5Uf0izcI8tLc+qOpqxO1PWqZ5A7J1blrUIKrIFc7Q==",
|
||||
"cpu": [
|
||||
"ia32"
|
||||
],
|
||||
"optional": true,
|
||||
"os": [
|
||||
"win32"
|
||||
],
|
||||
"engines": {
|
||||
"node": ">= 10"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import React from "react";
|
||||
import { describe, beforeEach, expect, it, vi } from "vitest";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import ModelRetrySettingsTab from "./ModelRetrySettingsTab";
|
||||
|
||||
// TabPanel requires a parent Tabs context in Tremor. We stub it to render children
|
||||
|
||||
@ -1,13 +1,13 @@
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { Button } from "@tremor/react";
|
||||
import { Modal } from "antd";
|
||||
import { getAgentsList, deleteAgentCall } from "./networking";
|
||||
import { Modal, Alert } from "antd";
|
||||
import { getAgentsList, deleteAgentCall, keyListCall } from "./networking";
|
||||
import AddAgentForm from "./agents/add_agent_form";
|
||||
import AgentTable from "./agents/agent_table";
|
||||
import AgentCardGrid from "./agents/agent_card_grid";
|
||||
import { isAdminRole } from "@/utils/roles";
|
||||
import AgentInfoView from "./agents/agent_info";
|
||||
import NotificationsManager from "./molecules/notifications_manager";
|
||||
import { Agent } from "./agents/types";
|
||||
import { Agent, AgentKeyInfo } from "./agents/types";
|
||||
|
||||
interface AgentsPanelProps {
|
||||
accessToken: string | null;
|
||||
@ -20,6 +20,7 @@ interface AgentsResponse {
|
||||
|
||||
const AgentsPanel: React.FC<AgentsPanelProps> = ({ accessToken, userRole }) => {
|
||||
const [agentsList, setAgentsList] = useState<Agent[]>([]);
|
||||
const [keyInfoMap, setKeyInfoMap] = useState<Record<string, AgentKeyInfo>>({});
|
||||
const [isAddModalVisible, setIsAddModalVisible] = useState(false);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [isDeleting, setIsDeleting] = useState(false);
|
||||
@ -36,8 +37,7 @@ const AgentsPanel: React.FC<AgentsPanelProps> = ({ accessToken, userRole }) => {
|
||||
setIsLoading(true);
|
||||
try {
|
||||
const response: AgentsResponse = await getAgentsList(accessToken);
|
||||
console.log(`agents: ${JSON.stringify(response)}`);
|
||||
setAgentsList(response.agents);
|
||||
setAgentsList(response.agents || []);
|
||||
} catch (error) {
|
||||
console.error("Error fetching agents:", error);
|
||||
} finally {
|
||||
@ -45,10 +45,50 @@ const AgentsPanel: React.FC<AgentsPanelProps> = ({ accessToken, userRole }) => {
|
||||
}
|
||||
};
|
||||
|
||||
/**
 * Loads the first page of keys (page 1, page size 500) and records, per
 * agent_id, the first key found for that agent in `keyInfoMap`.
 * Errors are logged and leave the current map untouched.
 */
const fetchKeysForAgents = async () => {
  if (!accessToken) return;
  try {
    const { keys = [] } = await keyListCall(accessToken, null, null, null, null, null, 1, 500);
    const map: Record<string, AgentKeyInfo> = {};
    for (const rawKey of keys) {
      const {
        agent_id: agentId,
        key_alias: keyAlias,
        token,
      } = rawKey as { agent_id?: string; key_alias?: string; token?: string };
      // Keys without an agent, and later keys for an agent already seen, are ignored.
      if (!agentId || map[agentId]) continue;
      map[agentId] = {
        has_key: true,
        key_alias: keyAlias,
        token_prefix: token ? `${token.slice(0, 8)}…` : undefined,
      };
    }
    setKeyInfoMap(map);
  } catch (error) {
    console.error("Error fetching keys for agents:", error);
  }
};
|
||||
|
||||
useEffect(() => {
|
||||
fetchAgents();
|
||||
}, [accessToken]);
|
||||
|
||||
useEffect(() => {
|
||||
if (accessToken && agentsList.length > 0) {
|
||||
fetchKeysForAgents();
|
||||
} else if (agentsList.length === 0) {
|
||||
setKeyInfoMap({});
|
||||
}
|
||||
}, [accessToken, agentsList.length]);
|
||||
|
||||
const handleAddAgent = () => {
|
||||
if (selectedAgentId) {
|
||||
setSelectedAgentId(null);
|
||||
@ -94,6 +134,13 @@ const AgentsPanel: React.FC<AgentsPanelProps> = ({ accessToken, userRole }) => {
|
||||
<div className="flex flex-col gap-2 mb-4">
|
||||
<h1 className="text-2xl font-bold">Agents</h1>
|
||||
<p className="text-sm text-gray-600">List of A2A-spec agents that are available to be used in your organization. Go to AI Hub, to make agents public.</p>
|
||||
<Alert
|
||||
message="Why do agents need keys?"
|
||||
description="Keys scope access to an agent and allow it to call MCP tools. Assign a key when creating an agent or from the Virtual Keys page."
|
||||
type="info"
|
||||
showIcon
|
||||
className="mb-3"
|
||||
/>
|
||||
<div className="mt-2">
|
||||
<Button onClick={handleAddAgent} disabled={!accessToken}>
|
||||
+ Add New Agent
|
||||
@ -109,8 +156,9 @@ const AgentsPanel: React.FC<AgentsPanelProps> = ({ accessToken, userRole }) => {
|
||||
isAdmin={isAdmin}
|
||||
/>
|
||||
) : (
|
||||
<AgentTable
|
||||
<AgentCardGrid
|
||||
agentsList={agentsList}
|
||||
keyInfoMap={keyInfoMap}
|
||||
isLoading={isLoading}
|
||||
onDeleteClick={handleDeleteClick}
|
||||
accessToken={accessToken}
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { Modal, Form, message, Select, Input, Steps, Radio, Tag, Divider } from "antd";
|
||||
import { Button } from "@tremor/react";
|
||||
import { CheckCircleFilled, KeyOutlined, RobotOutlined, AppstoreOutlined } from "@ant-design/icons";
|
||||
import { CheckCircleFilled, KeyOutlined, RobotOutlined, AppstoreOutlined, InfoCircleOutlined } from "@ant-design/icons";
|
||||
import CreatedKeyDisplay from "../shared/CreatedKeyDisplay";
|
||||
import {
|
||||
createAgentCall,
|
||||
@ -9,11 +9,16 @@ import {
|
||||
keyCreateForAgentCall,
|
||||
keyListCall,
|
||||
keyUpdateCall,
|
||||
modelAvailableCall,
|
||||
AgentCreateInfo,
|
||||
} from "../networking";
|
||||
import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized";
|
||||
import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_team_key";
|
||||
import AgentFormFields from "./agent_form_fields";
|
||||
import DynamicAgentFormFields, { buildDynamicAgentData } from "./dynamic_agent_form_fields";
|
||||
import { getDefaultFormValues, buildAgentDataFromForm } from "./agent_config";
|
||||
import MCPServerSelector from "../mcp_server_management/MCPServerSelector";
|
||||
import MCPToolPermissions from "../mcp_server_management/MCPToolPermissions";
|
||||
|
||||
const { Step } = Steps;
|
||||
|
||||
@ -32,6 +37,7 @@ const AddAgentForm: React.FC<AddAgentFormProps> = ({
|
||||
accessToken,
|
||||
onSuccess,
|
||||
}) => {
|
||||
const { userId, userRole } = useAuthorized();
|
||||
const [form] = Form.useForm();
|
||||
const [currentStep, setCurrentStep] = useState(0);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
@ -46,6 +52,8 @@ const AddAgentForm: React.FC<AddAgentFormProps> = ({
|
||||
const [existingKeys, setExistingKeys] = useState<any[]>([]);
|
||||
const [selectedExistingKey, setSelectedExistingKey] = useState<string | null>(null);
|
||||
const [loadingKeys, setLoadingKeys] = useState(false);
|
||||
const [availableModels, setAvailableModels] = useState<string[]>([]);
|
||||
const [loadingModels, setLoadingModels] = useState(false);
|
||||
|
||||
// Step 2: results
|
||||
const [createdAgentName, setCreatedAgentName] = useState<string>("");
|
||||
@ -68,9 +76,9 @@ const AddAgentForm: React.FC<AddAgentFormProps> = ({
|
||||
fetchMetadata();
|
||||
}, []);
|
||||
|
||||
// Fetch existing keys when assign key step becomes active
|
||||
// Fetch existing keys when assign key step becomes active (step 2)
|
||||
useEffect(() => {
|
||||
if (currentStep === 1 && accessToken && existingKeys.length === 0) {
|
||||
if (currentStep === 2 && accessToken && existingKeys.length === 0) {
|
||||
const fetchKeys = async () => {
|
||||
setLoadingKeys(true);
|
||||
try {
|
||||
@ -86,6 +94,31 @@ const AddAgentForm: React.FC<AddAgentFormProps> = ({
|
||||
}
|
||||
}, [currentStep, accessToken]);
|
||||
|
||||
// Load the selectable model list while the Assign Key step (index 2) is showing
// (same list as key generation).
useEffect(() => {
  const onAssignKeyStep = currentStep === 2;
  if (!onAssignKeyStep || !accessToken || !userId || !userRole) {
    return;
  }

  // Set by the cleanup below so a late response cannot touch state after this
  // effect run has been torn down.
  let isStale = false;

  setLoadingModels(true);
  modelAvailableCall(accessToken, userId, userRole)
    .then((response) => {
      if (isStale) {
        return;
      }
      // Accept either a `{ data: [...] }` envelope or a bare array response.
      const rawModels = response?.data ?? (Array.isArray(response) ? response : []);
      const modelIds = rawModels
        .map((model: { id?: string; model_name?: string }) => model.id ?? model.model_name)
        .filter(Boolean) as string[];
      setAvailableModels(modelIds);
    })
    .catch((error) => {
      if (!isStale) {
        console.error("Error fetching models:", error);
      }
    })
    .finally(() => {
      if (!isStale) {
        setLoadingModels(false);
      }
    });

  return () => {
    isStale = true;
  };
}, [currentStep, accessToken, userId, userRole]);
|
||||
|
||||
const selectedAgentTypeInfo = agentTypeMetadata.find(
|
||||
(info) => info.agent_type === agentType
|
||||
);
|
||||
@ -156,8 +189,6 @@ const AddAgentForm: React.FC<AddAgentFormProps> = ({
|
||||
|
||||
setIsSubmitting(true);
|
||||
try {
|
||||
// getFieldsValue(true) returns ALL preserved values including fields from
|
||||
// unmounted steps; merge with any currently-mounted validated fields.
|
||||
await form.validateFields();
|
||||
const values = { ...form.getFieldsValue(true) };
|
||||
const agentData = buildAgentData(values);
|
||||
@ -167,6 +198,26 @@ const AddAgentForm: React.FC<AddAgentFormProps> = ({
|
||||
return;
|
||||
}
|
||||
|
||||
// Build object_permission from MCP Tools step (allowed_mcp_servers_and_groups, mcp_tool_permissions)
|
||||
const mcpServersAndGroups = values.allowed_mcp_servers_and_groups;
|
||||
const mcpToolPermissions = values.mcp_tool_permissions || {};
|
||||
if (
|
||||
mcpServersAndGroups &&
|
||||
(mcpServersAndGroups.servers?.length > 0 || mcpServersAndGroups.accessGroups?.length > 0) ||
|
||||
Object.keys(mcpToolPermissions).length > 0
|
||||
) {
|
||||
agentData.object_permission = {};
|
||||
if (mcpServersAndGroups?.servers?.length > 0) {
|
||||
agentData.object_permission.mcp_servers = mcpServersAndGroups.servers;
|
||||
}
|
||||
if (mcpServersAndGroups?.accessGroups?.length > 0) {
|
||||
agentData.object_permission.mcp_access_groups = mcpServersAndGroups.accessGroups;
|
||||
}
|
||||
if (Object.keys(mcpToolPermissions).length > 0) {
|
||||
agentData.object_permission.mcp_tool_permissions = mcpToolPermissions;
|
||||
}
|
||||
}
|
||||
|
||||
const agentResponse = await createAgentCall(accessToken, agentData);
|
||||
const agentId: string = agentResponse.agent_id;
|
||||
const agentName: string = agentResponse.agent_name || values.agent_name || agentId;
|
||||
@ -194,11 +245,12 @@ const AddAgentForm: React.FC<AddAgentFormProps> = ({
|
||||
setAssignedKeyAlias(keyInfo?.key_alias || selectedExistingKey.slice(0, 12) + "…");
|
||||
}
|
||||
|
||||
setCurrentStep(2);
|
||||
setCurrentStep(3);
|
||||
onSuccess();
|
||||
} catch (error) {
|
||||
console.error("Error creating agent:", error);
|
||||
message.error("Failed to create agent");
|
||||
const errorMessage = error instanceof Error ? error.message : String(error);
|
||||
message.error(errorMessage ? `Failed to create agent: ${errorMessage}` : "Failed to create agent");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
@ -218,6 +270,54 @@ const AddAgentForm: React.FC<AddAgentFormProps> = ({
|
||||
onClose();
|
||||
};
|
||||
|
||||
/**
 * Renders the "MCP Tools" wizard step: an optional MCP server / access-group
 * picker plus per-server tool permissions, both stored on the shared form under
 * `allowed_mcp_servers_and_groups` and `mcp_tool_permissions`.
 */
const renderMCPToolsStep = () => {
  const serversLabel = (
    <span>
      Allowed MCP Servers{" "}
      <InfoCircleOutlined title="Select which MCP servers or access groups this agent can access" style={{ marginLeft: "4px" }} />
    </span>
  );

  const handleServerSelectionChange = (val: { servers?: string[]; accessGroups?: string[] }) =>
    form.setFieldValue("allowed_mcp_servers_and_groups", val);

  const handleToolPermissionsChange = (toolPerms: Record<string, string[]>) =>
    form.setFieldsValue({ mcp_tool_permissions: toolPerms });

  const currentServerSelection =
    form.getFieldValue("allowed_mcp_servers_and_groups") || { servers: [], accessGroups: [] };

  return (
    <div className="space-y-4">
      <p className="text-sm text-gray-600">
        Optionally restrict which MCP servers and tools this agent can use. Leave empty to allow all (subject to key/team permissions).
      </p>

      <Form.Item
        label={serversLabel}
        name="allowed_mcp_servers_and_groups"
        initialValue={{ servers: [], accessGroups: [] }}
      >
        <MCPServerSelector
          onChange={handleServerSelectionChange}
          value={currentServerSelection}
          accessToken={accessToken ?? ""}
          placeholder="Select MCP servers or access groups (optional)"
        />
      </Form.Item>

      {/* Hidden field so the tool-permission map is registered on the form. */}
      <Form.Item name="mcp_tool_permissions" initialValue={{}} hidden>
        <Input type="hidden" />
      </Form.Item>

      {/* Re-render the permissions panel whenever either MCP field changes. */}
      <Form.Item
        noStyle
        shouldUpdate={(prev, curr) =>
          !(
            prev.allowed_mcp_servers_and_groups === curr.allowed_mcp_servers_and_groups &&
            prev.mcp_tool_permissions === curr.mcp_tool_permissions
          )
        }
      >
        {() => (
          <div className="mt-4">
            <MCPToolPermissions
              accessToken={accessToken ?? ""}
              selectedServers={form.getFieldValue("allowed_mcp_servers_and_groups")?.servers ?? []}
              toolPermissions={form.getFieldValue("mcp_tool_permissions") ?? {}}
              onChange={handleToolPermissionsChange}
            />
          </div>
        )}
      </Form.Item>
    </div>
  );
};
|
||||
|
||||
const handleAgentTypeChange = (value: string) => {
|
||||
setAgentType(value);
|
||||
form.resetFields();
|
||||
@ -412,10 +512,16 @@ const AddAgentForm: React.FC<AddAgentFormProps> = ({
|
||||
<Select
|
||||
mode="tags"
|
||||
style={{ width: "100%" }}
|
||||
placeholder="e.g. gpt-4o, claude-3-5-sonnet"
|
||||
placeholder={loadingModels ? "Loading models..." : "e.g. gpt-4o, claude-3-5-sonnet"}
|
||||
value={newKeyModels}
|
||||
onChange={setNewKeyModels}
|
||||
tokenSeparators={[","]}
|
||||
loading={loadingModels}
|
||||
showSearch
|
||||
options={availableModels.map((m) => ({
|
||||
label: getModelDisplayName(m),
|
||||
value: m,
|
||||
}))}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
@ -537,6 +643,7 @@ const AddAgentForm: React.FC<AddAgentFormProps> = ({
|
||||
{/* Step indicator */}
|
||||
<Steps current={currentStep} size="small" className="mb-8">
|
||||
<Step title="Configure" />
|
||||
<Step title="MCP Tools" />
|
||||
<Step title="Assign Key" />
|
||||
<Step title="Ready" />
|
||||
</Steps>
|
||||
@ -544,18 +651,23 @@ const AddAgentForm: React.FC<AddAgentFormProps> = ({
|
||||
<Form
|
||||
form={form}
|
||||
layout="vertical"
|
||||
initialValues={agentType === "a2a" ? getDefaultFormValues() : {}}
|
||||
initialValues={
|
||||
agentType === "a2a"
|
||||
? { ...getDefaultFormValues(), allowed_mcp_servers_and_groups: { servers: [], accessGroups: [] }, mcp_tool_permissions: {} }
|
||||
: { allowed_mcp_servers_and_groups: { servers: [], accessGroups: [] }, mcp_tool_permissions: {} }
|
||||
}
|
||||
className="space-y-4"
|
||||
>
|
||||
{currentStep === 0 && renderConfigureStep()}
|
||||
{currentStep === 1 && renderAssignKeyStep()}
|
||||
{currentStep === 2 && renderReadyStep()}
|
||||
{currentStep === 1 && renderMCPToolsStep()}
|
||||
{currentStep === 2 && renderAssignKeyStep()}
|
||||
{currentStep === 3 && renderReadyStep()}
|
||||
</Form>
|
||||
|
||||
{/* Footer navigation */}
|
||||
<div className="flex items-center justify-between pt-6 border-t border-gray-100 mt-6">
|
||||
<div>
|
||||
{currentStep > 0 && currentStep < 2 && (
|
||||
{currentStep > 0 && currentStep < 3 && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={handleBack}
|
||||
@ -566,7 +678,7 @@ const AddAgentForm: React.FC<AddAgentFormProps> = ({
|
||||
)}
|
||||
</div>
|
||||
<div className="flex gap-3">
|
||||
{currentStep < 2 && (
|
||||
{currentStep < 3 && (
|
||||
<Button variant="secondary" onClick={handleClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
@ -577,11 +689,16 @@ const AddAgentForm: React.FC<AddAgentFormProps> = ({
|
||||
</Button>
|
||||
)}
|
||||
{currentStep === 1 && (
|
||||
<Button variant="primary" onClick={handleNext}>
|
||||
Next →
|
||||
</Button>
|
||||
)}
|
||||
{currentStep === 2 && (
|
||||
<Button variant="primary" loading={isSubmitting} onClick={handleCreateAgent}>
|
||||
{isSubmitting ? "Creating..." : "Create Agent →"}
|
||||
</Button>
|
||||
)}
|
||||
{currentStep === 2 && (
|
||||
{currentStep === 3 && (
|
||||
<Button variant="primary" onClick={handleClose}>
|
||||
Done
|
||||
</Button>
|
||||
|
||||
103
ui/litellm-dashboard/src/components/agents/agent_card.tsx
Normal file
103
ui/litellm-dashboard/src/components/agents/agent_card.tsx
Normal file
@ -0,0 +1,103 @@
|
||||
import React from "react";
|
||||
import { Card, Badge, Tooltip, Button } from "antd";
|
||||
import { CopyOutlined, KeyOutlined, WarningOutlined, DeleteOutlined } from "@ant-design/icons";
|
||||
import { Agent, AgentKeyInfo } from "./types";
|
||||
|
||||
/** Props accepted by {@link AgentCard}. */
interface AgentCardProps {
  /** Agent whose name, ID, description and URL are shown on the card. */
  agent: Agent;
  /** Key summary for the agent; drives the status badge and the footer line. */
  keyInfo?: AgentKeyInfo;
  /** When true (and `onDeleteClick` is provided) the delete button is rendered. */
  isAdmin: boolean;
  /** Passed by the grid; not read by the AgentCard component body. */
  accessToken: string | null;

  /** Invoked with the agent's ID when the card itself is clicked. */
  onAgentClick: (agentId: string) => void;
  /** Invoked with the agent's ID and name when the delete button is clicked. */
  onDeleteClick?: (agentId: string, agentName: string) => void;
  /** Passed by the grid; not read by the AgentCard component body. */
  onAgentUpdated: () => void;
}
|
||||
|
||||
/**
 * Clickable summary card for one agent.
 *
 * Shows the agent name (with a copy-ID icon), a status badge derived from
 * `keyInfo.has_key`, the description and URL from `agent_card_params`, and a
 * footer describing the assigned key. Clicking the card calls `onAgentClick`;
 * the delete button is shown only when `isAdmin` is true and `onDeleteClick`
 * is supplied.
 */
const AgentCard: React.FC<AgentCardProps> = ({
  agent,
  keyInfo,
  onAgentClick,
  onDeleteClick,
  isAdmin,
}) => {
  const description =
    agent.agent_card_params?.description || "No description";
  const url = agent.agent_card_params?.url;
  const hasKey = keyInfo?.has_key ?? false;
  const statusBadge = hasKey ? (
    <Badge status="success" text="Active" />
  ) : (
    <Badge status="warning" text="Needs Setup" />
  );

  /**
   * Copies `text` to the clipboard without triggering the card's own click.
   * Failures are logged instead of surfacing as an uncaught error.
   */
  const copyToClipboard = (e: React.MouseEvent, text: string) => {
    // Keep the click from bubbling to the Card's onClick (which opens the agent).
    e.stopPropagation();

    // The async Clipboard API is only exposed in secure contexts, so it can be
    // missing at runtime even though the DOM typings declare it as present.
    const clipboard: Clipboard | undefined = navigator.clipboard;
    if (!clipboard) {
      console.error("Clipboard API is unavailable; could not copy to clipboard");
      return;
    }

    // writeText rejects when permission is denied; handle it so the promise
    // is not left floating.
    clipboard.writeText(text).catch((error: unknown) => {
      console.error("Failed to copy to clipboard:", error);
    });
  };

  return (
    <Card
      hoverable
      className="h-full flex flex-col"
      styles={{
        body: { flex: 1, display: "flex", flexDirection: "column" },
      }}
      onClick={() => onAgentClick(agent.agent_id)}
    >
      <div className="flex items-start justify-between gap-2 mb-2">
        <div className="flex-1 min-w-0">
          <div className="flex items-center gap-2 flex-wrap">
            <span className="font-medium text-gray-900 truncate">
              {agent.agent_name}
            </span>
            <Tooltip title="Copy Agent ID">
              <CopyOutlined
                onClick={(e) => copyToClipboard(e, agent.agent_id)}
                className="cursor-pointer text-gray-400 hover:text-blue-500 text-xs shrink-0"
              />
            </Tooltip>
          </div>
          <div className="mt-1">{statusBadge}</div>
        </div>
        {isAdmin && onDeleteClick && (
          <Tooltip title="Delete agent">
            <Button
              type="text"
              size="small"
              danger
              icon={<DeleteOutlined />}
              onClick={(e) => {
                // Deleting must not also open the agent detail view.
                e.stopPropagation();
                onDeleteClick(agent.agent_id, agent.agent_name);
              }}
              className="shrink-0 -mr-1"
            />
          </Tooltip>
        )}
      </div>
      <p className="text-sm text-gray-600 line-clamp-2 flex-1 mb-3">
        {description}
      </p>
      {url && (
        <p className="text-xs text-gray-500 truncate mb-2" title={url}>
          {url}
        </p>
      )}
      <div className="mt-auto pt-3 border-t border-gray-100 text-xs">
        {hasKey ? (
          <div className="flex items-center gap-1.5 text-gray-600">
            <KeyOutlined />
            <span>{keyInfo?.key_alias || keyInfo?.token_prefix || "Key assigned"}</span>
          </div>
        ) : (
          <div className="flex items-center gap-1.5 text-amber-600">
            <WarningOutlined />
            <span>No key assigned</span>
          </div>
        )}
      </div>
    </Card>
  );
};
|
||||
|
||||
export default AgentCard;
|
||||
@ -0,0 +1,63 @@
|
||||
import React from "react";
|
||||
import { Skeleton } from "antd";
|
||||
import AgentCard from "./agent_card";
|
||||
import { Agent, AgentKeyInfo } from "./types";
|
||||
|
||||
/** Props accepted by {@link AgentCardGrid}. */
interface AgentCardGridProps {
  /** Agents to render, one card each. */
  agentsList: Agent[];
  /** Key summaries looked up by `agent_id` for each rendered card. */
  keyInfoMap: Record<string, AgentKeyInfo>;
  /** While true, skeleton placeholders are shown instead of cards. */
  isLoading: boolean;
  /** Forwarded to every card; delete is only wired up when this is true. */
  isAdmin: boolean;
  /** Forwarded unchanged to every card. */
  accessToken: string | null;

  /** Forwarded to every card as its click handler. */
  onAgentClick: (agentId: string) => void;
  /** Forwarded to cards as the delete handler, but only for admins. */
  onDeleteClick: (agentId: string, agentName: string) => void;
  /** Forwarded unchanged to every card. */
  onAgentUpdated: () => void;
}
|
||||
|
||||
/**
 * Responsive grid of agent cards.
 *
 * Renders three skeleton placeholders while loading, an empty-state message
 * when there are no agents, and otherwise one AgentCard per agent.
 */
const AgentCardGrid: React.FC<AgentCardGridProps> = (props) => {
  const {
    agentsList,
    keyInfoMap,
    isLoading,
    onDeleteClick,
    accessToken,
    onAgentUpdated,
    isAdmin,
    onAgentClick,
  } = props;

  if (isLoading) {
    const placeholderKeys = [1, 2, 3];
    return (
      <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-6">
        {placeholderKeys.map((placeholderKey) => (
          <Skeleton key={placeholderKey} active paragraph={{ rows: 3 }} />
        ))}
      </div>
    );
  }

  // Covers both a missing list and an empty one.
  if (!agentsList?.length) {
    return (
      <div className="rounded-lg border border-gray-200 bg-gray-50/50 py-12 text-center">
        <p className="text-gray-500">No agents found. Create one to get started.</p>
      </div>
    );
  }

  // Non-admins get no delete handler, so their cards render no delete button.
  const deleteHandler = isAdmin ? onDeleteClick : undefined;

  const cards = agentsList.map((agent) => (
    <AgentCard
      key={agent.agent_id}
      agent={agent}
      keyInfo={keyInfoMap[agent.agent_id]}
      onAgentClick={onAgentClick}
      onDeleteClick={deleteHandler}
      accessToken={accessToken}
      isAdmin={isAdmin}
      onAgentUpdated={onAgentUpdated}
    />
  ));

  return <div className="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-3 gap-6">{cards}</div>;
};
|
||||
|
||||
export default AgentCardGrid;
|
||||
@ -205,6 +205,44 @@ const AgentInfoView: React.FC<AgentInfoViewProps> = ({
|
||||
<Descriptions.Item label="Updated At">{formatDate(agent.updated_at)}</Descriptions.Item>
|
||||
</Descriptions>
|
||||
|
||||
{agent.object_permission &&
|
||||
(agent.object_permission.mcp_servers?.length ||
|
||||
agent.object_permission.mcp_access_groups?.length ||
|
||||
(agent.object_permission.mcp_tool_permissions &&
|
||||
Object.keys(agent.object_permission.mcp_tool_permissions).length > 0)) && (
|
||||
<div style={{ marginTop: 24 }}>
|
||||
<Title>MCP Tool Permissions</Title>
|
||||
<Descriptions bordered column={1} style={{ marginTop: 16 }}>
|
||||
{agent.object_permission.mcp_servers && agent.object_permission.mcp_servers.length > 0 && (
|
||||
<Descriptions.Item label="MCP Servers">
|
||||
{agent.object_permission.mcp_servers.join(", ")}
|
||||
</Descriptions.Item>
|
||||
)}
|
||||
{agent.object_permission.mcp_access_groups &&
|
||||
agent.object_permission.mcp_access_groups.length > 0 && (
|
||||
<Descriptions.Item label="MCP Access Groups">
|
||||
{agent.object_permission.mcp_access_groups.join(", ")}
|
||||
</Descriptions.Item>
|
||||
)}
|
||||
{agent.object_permission.mcp_tool_permissions &&
|
||||
Object.keys(agent.object_permission.mcp_tool_permissions).length > 0 && (
|
||||
<Descriptions.Item label="Tool permissions per server">
|
||||
<div className="space-y-1">
|
||||
{Object.entries(agent.object_permission.mcp_tool_permissions).map(
|
||||
([serverId, tools]) => (
|
||||
<div key={serverId}>
|
||||
<span className="font-medium">{serverId}:</span>{" "}
|
||||
{Array.isArray(tools) ? tools.join(", ") : String(tools)}
|
||||
</div>
|
||||
)
|
||||
)}
|
||||
</div>
|
||||
</Descriptions.Item>
|
||||
)}
|
||||
</Descriptions>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<AgentCostView agent={agent} />
|
||||
|
||||
{agent.agent_card_params?.skills && agent.agent_card_params.skills.length > 0 && (
|
||||
|
||||
@ -1,3 +1,15 @@
|
||||
/** Summary of the API key (if any) associated with an agent. */
export interface AgentKeyInfo {
  /** Whether a key is assigned; the agent card shows "Active" vs "Needs Setup" from this. */
  has_key: boolean;
  /** Key alias; the agent card prefers this over `token_prefix` when labelling the key. */
  key_alias?: string;
  /** Shown by the agent card when no `key_alias` is available. */
  token_prefix?: string;
}
|
||||
|
||||
/** MCP-related permissions attached to an agent via its `object_permission`. */
export interface AgentObjectPermission {
  /** Per-server tool lists, keyed by server ID. */
  mcp_tool_permissions?: Record<string, string[]>;
  /** MCP servers assigned to the agent. */
  mcp_servers?: string[];
  /** MCP access groups assigned to the agent. */
  mcp_access_groups?: string[];
}
|
||||
|
||||
export interface Agent {
|
||||
agent_id: string;
|
||||
agent_name: string;
|
||||
@ -7,8 +19,10 @@ export interface Agent {
|
||||
};
|
||||
agent_card_params?: {
|
||||
description?: string;
|
||||
url?: string;
|
||||
[key: string]: any;
|
||||
};
|
||||
object_permission?: AgentObjectPermission;
|
||||
created_at?: string;
|
||||
updated_at?: string;
|
||||
created_by?: string;
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import React, { useState } from "react";
|
||||
import { Input } from "antd";
|
||||
import { SearchOutlined, ArrowRightOutlined } from "@ant-design/icons";
|
||||
import { GuardrailCardInfo, LITELLM_CONTENT_FILTER_CARDS, PARTNER_GUARDRAIL_CARDS, ALL_CARDS } from "./guardrail_garden_data";
|
||||
import { GuardrailCardInfo, ALL_CARDS } from "./guardrail_garden_data";
|
||||
import GuardrailCard from "./guardrail_garden_card";
|
||||
import GuardrailDetailView from "./guardrail_garden_detail";
|
||||
|
||||
|
||||
@ -194,7 +194,7 @@ const MCPToolConfiguration: React.FC<MCPToolConfigurationProps> = ({
|
||||
{filteredTools.length === 0 ? (
|
||||
<div className="text-center py-6 text-gray-400 border rounded-lg border-dashed">
|
||||
<SearchOutlined className="text-2xl mb-2" />
|
||||
<Text>No tools found matching "{toolSearchTerm}"</Text>
|
||||
<Text>No tools found matching "{toolSearchTerm}"</Text>
|
||||
</div>
|
||||
) : (
|
||||
filteredTools.map((tool, index) => (
|
||||
|
||||
@ -1,12 +1,12 @@
|
||||
import React, { useState } from "react";
|
||||
import { useQuery, useMutation } from "@tanstack/react-query";
|
||||
import { ToolTestPanel } from "./ToolTestPanel";
|
||||
import { MCPTool, MCPToolsViewerProps, MCPContent, CallMCPToolResponse, AUTH_TYPE } from "./types";
|
||||
import { MCPTool, MCPToolsViewerProps, MCPContent, CallMCPToolResponse } from "./types";
|
||||
import { listMCPTools, callMCPTool } from "../networking";
|
||||
|
||||
import { Card, Title, Text } from "@tremor/react";
|
||||
import { RobotOutlined, ToolOutlined, SearchOutlined, LockOutlined, KeyOutlined } from "@ant-design/icons";
|
||||
import { Input, Alert, Button as AntdButton } from "antd";
|
||||
import { RobotOutlined, ToolOutlined, SearchOutlined, KeyOutlined } from "@ant-design/icons";
|
||||
import { Input, Button as AntdButton } from "antd";
|
||||
|
||||
const MCPToolsViewer = ({
|
||||
serverId,
|
||||
@ -134,7 +134,7 @@ const MCPToolsViewer = ({
|
||||
|
||||
{!showHeaderInput && Object.keys(passthroughHeaders).length === 0 && (
|
||||
<Text className="text-xs text-blue-700">
|
||||
This server requires additional headers. Click "Configure" to provide values.
|
||||
This server requires additional headers. Click "Configure" to provide values.
|
||||
</Text>
|
||||
)}
|
||||
|
||||
@ -255,7 +255,7 @@ const MCPToolsViewer = ({
|
||||
<div className="p-4 text-center bg-white border border-gray-200 rounded-lg">
|
||||
<SearchOutlined className="text-2xl text-gray-400 mb-2" />
|
||||
<p className="text-xs font-medium text-gray-700 mb-1">No tools found</p>
|
||||
<p className="text-xs text-gray-500">No tools match "{toolSearchTerm}"</p>
|
||||
<p className="text-xs text-gray-500">No tools match "{toolSearchTerm}"</p>
|
||||
</div>
|
||||
) : (
|
||||
<div
|
||||
|
||||
@ -5908,6 +5908,7 @@ export const enrichPolicyTemplateStream = async (
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
// eslint-disable-next-line no-constant-condition -- stream read loop
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
@ -5982,6 +5983,7 @@ export const usageAiChatStream = async (
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = "";
|
||||
|
||||
// eslint-disable-next-line no-constant-condition -- stream read loop
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
if (done) break;
|
||||
|
||||
@ -5,7 +5,7 @@ import { formatNumberWithCommas } from "@/utils/dataUtils";
|
||||
import { InfoCircleOutlined } from "@ant-design/icons";
|
||||
import { useQueryClient } from "@tanstack/react-query";
|
||||
import { Accordion, AccordionBody, AccordionHeader, Button, Col, Grid, Text, TextInput, Title } from "@tremor/react";
|
||||
import { Button as Button2, Form, Input, Modal, Radio, Select, Switch, Tag, Tooltip } from "antd";
|
||||
import { Button as Button2, Form, Input, message, Modal, Radio, Select, Switch, Tag, Tooltip } from "antd";
|
||||
import debounce from "lodash/debounce";
|
||||
import React, { useCallback, useEffect, useState } from "react";
|
||||
import { rolesWithWriteAccess } from "../../utils/roles";
|
||||
|
||||
@ -998,7 +998,7 @@ const PipelineTestPanel: React.FC<PipelineTestPanelProps> = ({
|
||||
|
||||
{!result && !error && complianceResults.length === 0 && (
|
||||
<div style={{ textAlign: "center", color: "#9ca3af", fontSize: 13, marginTop: 24 }}>
|
||||
Choose a test source above (quick chat or a compliance dataset) and click "Run Test"
|
||||
Choose a test source above (quick chat or a compliance dataset) and click "Run Test"
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import React, { useState, useEffect, useMemo } from "react";
|
||||
import { Card, Button, Spin, message, Checkbox, Badge } from "antd";
|
||||
import { Card, Button, Spin, message, Checkbox } from "antd";
|
||||
import {
|
||||
ShieldCheckIcon,
|
||||
ShieldExclamationIcon,
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
import { render, screen, waitFor } from "@testing-library/react";
|
||||
import { render, screen } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { describe, expect, it, vi } from "vitest";
|
||||
import { KeyInfoHeader, KeyInfoData } from "./KeyInfoHeader";
|
||||
|
||||
@ -4,7 +4,6 @@ import moment from "moment";
|
||||
import { LogEntry } from "../columns";
|
||||
import { formatNumberWithCommas } from "@/utils/dataUtils";
|
||||
import GuardrailViewer from "../GuardrailViewer/GuardrailViewer";
|
||||
import CompliancePanel from "../GuardrailViewer/CompliancePanel";
|
||||
import { CostBreakdownViewer } from "../CostBreakdownViewer";
|
||||
import { ConfigInfoMessage } from "../ConfigInfoMessage";
|
||||
import { VectorStoreViewer } from "../VectorStoreViewer";
|
||||
|
||||
@ -15,7 +15,7 @@ import { fetchAllKeyAliases } from "../key_team_helpers/filter_helpers";
|
||||
import { KeyResponse, Team } from "../key_team_helpers/key_list";
|
||||
import { PaginatedModelSelect } from "../ModelSelect/PaginatedModelSelect/PaginatedModelSelect";
|
||||
import FilterComponent, { FilterOption } from "../molecules/filter";
|
||||
import { allEndUsersCall, keyInfoV1Call, keyListCall, uiSpendLogsCall } from "../networking";
|
||||
import { allEndUsersCall, keyInfoV1Call, uiSpendLogsCall } from "../networking";
|
||||
import KeyInfoView from "../templates/key_info_view";
|
||||
import AuditLogs from "./audit_logs";
|
||||
import { createColumns, LogEntry, type LogsSortField } from "./columns";
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
// Auto-generated from block_investment.csv — do not edit manually.
|
||||
// Regenerate: python scripts/generate_compliance_prompts.py --csv ... --output ...
|
||||
|
||||
import type { CompliancePrompt, ComplianceFramework } from "./compliancePrompts";
|
||||
import type { CompliancePrompt } from "./compliancePrompts";
|
||||
|
||||
export const financialCompliancePrompts: CompliancePrompt[] = [
|
||||
{
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
// Auto-generated from block_insults.csv — do not edit manually.
|
||||
// Regenerate: python scripts/generate_compliance_prompts.py --csv ... --output ...
|
||||
|
||||
import type { CompliancePrompt, ComplianceFramework } from "./compliancePrompts";
|
||||
import type { CompliancePrompt } from "./compliancePrompts";
|
||||
|
||||
export const insultsCompliancePrompts: CompliancePrompt[] = [
|
||||
{
|
||||
|
||||
Loading…
Reference in New Issue
Block a user