[Feat] Add Tool Policies for AI Gateway (#22732)
* fix: fix ui render * fix: fix minor bugs * refactor: use prisma functions instead of raw sql (safer) * fix(add-new-tiles-to-tool-policies): allow developer to see what's available * feat: ensure tool allowlist runs correctly for tool names + mcp's * refactor: more ui improvements * feat: working key tool blocking * feat(tools): show tool logs * refactor: backend code improvements * refactor: improve log viewer for tools * fix: address PR review feedback for tool access control - Add missing blocked_tools column to root schema.prisma (schema drift) - Invalidate ToolPolicyRegistry after policy mutations so changes take effect immediately - Remove dead code: unused get_effective_policies, get_tool_policies_cached, and helpers Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: race condition in permission resolution and remove duplicate allowlist check - Use atomic update_many with object_permission_id=None to prevent concurrent requests from creating orphaned permission rows and losing tool blocks - Remove duplicate allowed_tools enforcement from guardrail (already enforced in auth layer via check_tools_allowlist) - Move inline uuid import to module level Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * update to account for userAgent * UI - Add ToolDetails * input/output policy * LiteLLM_PolicyAttachmentTable * LiteLLM_PolicyAttachmentTable * fix: add _enqueue_tool_registry_upsert * fix: tool mgmt endpoints * tool mgmt endpoints * Update tests/test_litellm/proxy/db/test_tool_registry_writer.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * Update tests/test_litellm/proxy/db/test_tool_registry_writer.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * Update tests/test_litellm/proxy/db/test_tool_registry_writer.py Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> * fix: sync root schema.prisma and fix test_tool_registry_writer for input/output policy - Migrate root schema.prisma LiteLLM_ToolTable from call_policy to input_policy/output_policy, add missing user_agent and last_used_at columns (now consistent with litellm/proxy/schema.prisma and litellm-proxy-extras) - Fix SpendLogToolIndex comment across all three schema files - Fix all call_policy references in test_tool_registry_writer.py: swapped update_tool_policy arguments, wrong get_tools_by_names return type assertions, _mock_tool_row setting call_policy instead of input_policy Addresses Greptile review feedback on PR #22732. Made-with: Cursor --------- Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com> Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
This commit is contained in:
parent
8baa3ae8cb
commit
1f412bc6d8
@ -109,6 +109,8 @@ Key files:
|
||||
- `litellm/proxy/auth/` - Authentication logic
|
||||
- `litellm/proxy/management_endpoints/` - Admin API endpoints
|
||||
|
||||
**Database (proxy)**: Use Prisma model methods (`prisma_client.db.<model>.upsert`, `.find_many`, `.find_unique`, etc.), not raw SQL (`execute_raw`/`query_raw`). See COMMON PITFALLS for details.
|
||||
|
||||
## MCP (MODEL CONTEXT PROTOCOL) SUPPORT
|
||||
|
||||
LiteLLM supports MCP for agent workflows:
|
||||
@ -176,6 +178,7 @@ When opening issues or pull requests, follow these templates:
|
||||
5. **Dependencies**: Keep dependencies minimal and well-justified
|
||||
6. **UI/Backend Contract Mismatch**: When adding a new entity type to the UI, always check whether the backend endpoint accepts a single value or an array. Match the UI control accordingly (single-select vs. multi-select) to avoid silently dropping user selections
|
||||
7. **Missing Tests for New Entity Types**: When adding a new entity type (e.g., in `EntityUsage`, `UsageViewSelect`), always add corresponding tests in the existing test files and update any icon/component mocks
|
||||
8. **Raw SQL in proxy DB code**: Do not use `execute_raw` or `query_raw` for proxy database access. Use Prisma model methods (e.g. `prisma_client.db.litellm_tooltable.upsert()`, `.find_many()`, `.find_unique()`) so behavior stays consistent with the schema, the client stays mockable in tests, and you avoid the pitfalls of hand-written SQL (parameter ordering, type casting, schema drift)
|
||||
|
||||
8. **Do not hardcode model-specific flags**: Put model-specific capability flags in `model_prices_and_context_window.json` and read them via `get_model_info` (or existing helpers like `supports_reasoning`). This prevents users from needing to upgrade LiteLLM each time a new model supports a feature.
|
||||
|
||||
|
||||
@ -107,6 +107,10 @@ LiteLLM is a unified interface for 100+ LLM providers with two main components:
|
||||
- Migration files auto-generated with `prisma migrate dev`
|
||||
- Always test migrations against both PostgreSQL and SQLite
|
||||
|
||||
### Proxy database access
|
||||
- **Do not write raw SQL** for proxy DB operations. Use Prisma model methods instead of `execute_raw` / `query_raw`.
|
||||
- Use the generated client: `prisma_client.db.<model>` (e.g. `litellm_tooltable`, `litellm_usertable`) with `.upsert()`, `.find_many()`, `.find_unique()`, `.update()`, `.update_many()` as appropriate. This avoids schema/client drift, keeps code testable with simple mocks, and matches patterns used in spend logs and other proxy code.
|
||||
|
||||
### Enterprise Features
|
||||
- Enterprise-specific code in `enterprise/` directory
|
||||
- Optional features enabled via environment variables
|
||||
|
||||
@ -0,0 +1,2 @@
|
||||
-- AlterTable
|
||||
ALTER TABLE "LiteLLM_ObjectPermissionTable" ADD COLUMN "blocked_tools" TEXT[] DEFAULT ARRAY[]::TEXT[];
|
||||
@ -0,0 +1,11 @@
|
||||
-- CreateTable
|
||||
CREATE TABLE "LiteLLM_SpendLogToolIndex" (
|
||||
"request_id" TEXT NOT NULL,
|
||||
"tool_name" TEXT NOT NULL,
|
||||
"start_time" TIMESTAMP(3) NOT NULL,
|
||||
|
||||
CONSTRAINT "LiteLLM_SpendLogToolIndex_pkey" PRIMARY KEY ("request_id","tool_name")
|
||||
);
|
||||
|
||||
-- CreateIndex
|
||||
CREATE INDEX "LiteLLM_SpendLogToolIndex_tool_name_start_time_idx" ON "LiteLLM_SpendLogToolIndex"("tool_name", "start_time");
|
||||
@ -260,6 +260,7 @@ model LiteLLM_ObjectPermissionTable {
|
||||
vector_stores String[] @default([])
|
||||
agents String[] @default([])
|
||||
agent_access_groups String[] @default([])
|
||||
blocked_tools String[] @default([]) // Tool names blocked for any key/team/user with this permission
|
||||
teams LiteLLM_TeamTable[]
|
||||
projects LiteLLM_ProjectTable[]
|
||||
verification_tokens LiteLLM_VerificationToken[]
|
||||
@ -928,6 +929,16 @@ model LiteLLM_SpendLogGuardrailIndex {
|
||||
@@index([policy_id, start_time])
|
||||
}
|
||||
|
||||
// Index for fast "last N logs for tool" from SpendLogs – see how a tool is called in production
|
||||
model LiteLLM_SpendLogToolIndex {
|
||||
request_id String
|
||||
tool_name String // matches LiteLLM_ToolTable.tool_name; join for input_policy/output_policy etc.
|
||||
start_time DateTime
|
||||
|
||||
@@id([request_id, tool_name])
|
||||
@@index([tool_name, start_time])
|
||||
}
|
||||
|
||||
// Prompt table for storing prompt configurations
|
||||
model LiteLLM_PromptTable {
|
||||
id String @id @default(uuid())
|
||||
@ -1065,26 +1076,31 @@ model LiteLLM_PolicyAttachmentTable {
|
||||
updated_by String?
|
||||
}
|
||||
|
||||
// Global tool registry - auto-discovered from LLM responses; admins set call_policy here
|
||||
// Global tool registry - auto-discovered from LLM responses; admins set input_policy/output_policy here
|
||||
model LiteLLM_ToolTable {
|
||||
tool_id String @id @default(uuid())
|
||||
tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space"
|
||||
origin String? // MCP server name or "user_defined"
|
||||
call_policy String @default("untrusted") // "trusted" | "untrusted" | "dual_llm" | "blocked"
|
||||
call_count Int @default(0) // cumulative number of times this tool was seen
|
||||
assignments Json? @default("{}")
|
||||
key_hash String? // hash of the virtual key that first called this tool
|
||||
team_id String? // team that first called this tool
|
||||
key_alias String? // human-readable alias of the virtual key
|
||||
created_at DateTime @default(now())
|
||||
created_by String?
|
||||
updated_at DateTime @default(now()) @updatedAt
|
||||
updated_by String?
|
||||
tool_id String @id @default(uuid())
|
||||
tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space"
|
||||
origin String? // MCP server name or "user_defined"
|
||||
input_policy String @default("untrusted") // "trusted" | "untrusted" | "blocked"
|
||||
output_policy String @default("untrusted") // "trusted" | "untrusted"
|
||||
call_count Int @default(0) // cumulative number of times this tool was seen
|
||||
assignments Json? @default("{}")
|
||||
key_hash String? // hash of the virtual key that first called this tool
|
||||
team_id String? // team that first called this tool
|
||||
key_alias String? // human-readable alias of the virtual key
|
||||
user_agent String? // user-agent of the first request that discovered this tool
|
||||
last_used_at DateTime? // timestamp of the most recent call
|
||||
created_at DateTime @default(now())
|
||||
created_by String?
|
||||
updated_at DateTime @default(now()) @updatedAt
|
||||
updated_by String?
|
||||
|
||||
@@index([call_policy])
|
||||
@@index([input_policy])
|
||||
@@index([output_policy])
|
||||
@@index([team_id])
|
||||
}
|
||||
|
||||
// Per-(tool, team/key) policy overrides. When present, override replaces global tool policy for that scope.
|
||||
//Unified Access Groups table for storing unified access groups
|
||||
model LiteLLM_AccessGroupTable {
|
||||
access_group_id String @id @default(uuid())
|
||||
|
||||
@ -75,7 +75,7 @@ class AnthropicMessagesHandler(BaseTranslation):
|
||||
if messages is None:
|
||||
return data
|
||||
|
||||
chat_completion_compatible_request, tool_name_mapping = (
|
||||
chat_completion_compatible_request, _tool_name_mapping = (
|
||||
LiteLLMAnthropicMessagesAdapter().translate_anthropic_to_openai(
|
||||
# Use a shallow copy to avoid mutating request data (pop on litellm_metadata).
|
||||
anthropic_message_request=cast(AnthropicMessagesRequest, data.copy())
|
||||
@ -141,6 +141,14 @@ class AnthropicMessagesHandler(BaseTranslation):
|
||||
|
||||
return data
|
||||
|
||||
def extract_request_tool_names(self, data: dict) -> List[str]:
|
||||
"""Extract tool names from Anthropic messages request (tools[].name)."""
|
||||
names: List[str] = []
|
||||
for tool in data.get("tools") or []:
|
||||
if isinstance(tool, dict) and tool.get("name"):
|
||||
names.append(str(tool["name"]))
|
||||
return names
|
||||
|
||||
def _extract_input_text_and_images(
|
||||
self,
|
||||
message: Dict[str, Any],
|
||||
|
||||
@ -98,3 +98,10 @@ class BaseTranslation(ABC):
|
||||
Optional to override in subclasses.
|
||||
"""
|
||||
return responses_so_far
|
||||
|
||||
def extract_request_tool_names(self, data: dict) -> List[str]:
|
||||
"""
|
||||
Extract tool names from the request body for allowlist/policy checks.
|
||||
Override in tool-capable handlers; default returns [].
|
||||
"""
|
||||
return []
|
||||
|
||||
@ -135,6 +135,19 @@ class OpenAIChatCompletionsHandler(BaseTranslation):
|
||||
|
||||
return data
|
||||
|
||||
def extract_request_tool_names(self, data: dict) -> List[str]:
|
||||
"""Extract tool names from OpenAI chat completions request (tools[].function.name, functions[].name)."""
|
||||
names: List[str] = []
|
||||
for tool in data.get("tools") or []:
|
||||
if isinstance(tool, dict) and tool.get("type") == "function":
|
||||
fn = tool.get("function")
|
||||
if isinstance(fn, dict) and fn.get("name"):
|
||||
names.append(str(fn["name"]))
|
||||
for fn in data.get("functions") or []:
|
||||
if isinstance(fn, dict) and fn.get("name"):
|
||||
names.append(str(fn["name"]))
|
||||
return names
|
||||
|
||||
def _extract_inputs(
|
||||
self,
|
||||
message: Dict[str, Any],
|
||||
|
||||
@ -30,27 +30,22 @@ Output: response.output is List[GenericResponseOutputItem] where each has:
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union, cast
|
||||
|
||||
from openai.types.responses.response_function_tool_call import ResponseFunctionToolCall
|
||||
from openai.types.responses.response_function_tool_call import \
|
||||
ResponseFunctionToolCall
|
||||
from pydantic import BaseModel
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.completion_extras.litellm_responses_transformation.transformation import (
|
||||
LiteLLMResponsesTransformationHandler,
|
||||
OpenAiResponsesToChatCompletionStreamIterator,
|
||||
)
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import BaseTranslation
|
||||
from litellm.responses.litellm_completion_transformation.transformation import (
|
||||
LiteLLMCompletionResponsesConfig,
|
||||
)
|
||||
from litellm.types.llms.openai import (
|
||||
ChatCompletionToolCallChunk,
|
||||
ChatCompletionToolParam,
|
||||
)
|
||||
from litellm.types.responses.main import (
|
||||
GenericResponseOutputItem,
|
||||
OutputFunctionToolCall,
|
||||
OutputText,
|
||||
)
|
||||
OpenAiResponsesToChatCompletionStreamIterator)
|
||||
from litellm.llms.base_llm.guardrail_translation.base_translation import \
|
||||
BaseTranslation
|
||||
from litellm.responses.litellm_completion_transformation.transformation import \
|
||||
LiteLLMCompletionResponsesConfig
|
||||
from litellm.types.llms.openai import (ChatCompletionToolCallChunk,
|
||||
ChatCompletionToolParam)
|
||||
from litellm.types.responses.main import (GenericResponseOutputItem,
|
||||
OutputFunctionToolCall, OutputText)
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -188,6 +183,18 @@ class OpenAIResponsesHandler(BaseTranslation):
|
||||
|
||||
return data
|
||||
|
||||
def extract_request_tool_names(self, data: dict) -> List[str]:
|
||||
"""Extract tool names from Responses API request (tools[].name for function, tools[].server_label for mcp)."""
|
||||
names: List[str] = []
|
||||
for tool in data.get("tools") or []:
|
||||
if not isinstance(tool, dict):
|
||||
continue
|
||||
if tool.get("type") == "function" and tool.get("name"):
|
||||
names.append(str(tool["name"]))
|
||||
elif tool.get("type") == "mcp" and tool.get("server_label"):
|
||||
names.append(str(tool["server_label"]))
|
||||
return names
|
||||
|
||||
def _extract_and_transform_tools(
|
||||
self,
|
||||
tools: List[Dict[str, Any]],
|
||||
|
||||
@ -9779,6 +9779,122 @@
|
||||
}
|
||||
]
|
||||
},
|
||||
"dashscope/qwen3-max-2026-01-23": {
|
||||
"litellm_provider": "dashscope",
|
||||
"max_input_tokens": 258048,
|
||||
"max_output_tokens": 65536,
|
||||
"max_tokens": 65536,
|
||||
"mode": "chat",
|
||||
"source": "https://www.alibabacloud.com/help/en/model-studio/models",
|
||||
"supports_function_calling": true,
|
||||
"supports_reasoning": true,
|
||||
"supports_tool_choice": true,
|
||||
"tiered_pricing": [
|
||||
{
|
||||
"input_cost_per_token": 1.2e-06,
|
||||
"output_cost_per_token": 6e-06,
|
||||
"range": [
|
||||
0,
|
||||
32000.0
|
||||
]
|
||||
},
|
||||
{
|
||||
"input_cost_per_token": 2.4e-06,
|
||||
"output_cost_per_token": 1.2e-05,
|
||||
"range": [
|
||||
32000.0,
|
||||
128000.0
|
||||
]
|
||||
},
|
||||
{
|
||||
"input_cost_per_token": 3e-06,
|
||||
"output_cost_per_token": 1.5e-05,
|
||||
"range": [
|
||||
128000.0,
|
||||
252000.0
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"dashscope/qwen3-next-80b-a3b-instruct": {
|
||||
"input_cost_per_token": 1.5e-07,
|
||||
"litellm_provider": "dashscope",
|
||||
"max_input_tokens": 262144,
|
||||
"max_output_tokens": 65536,
|
||||
"max_tokens": 65536,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1.2e-06,
|
||||
"source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing",
|
||||
"supports_function_calling": true,
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"dashscope/qwen3-next-80b-a3b-thinking": {
|
||||
"input_cost_per_token": 1.5e-07,
|
||||
"litellm_provider": "dashscope",
|
||||
"max_input_tokens": 262144,
|
||||
"max_output_tokens": 65536,
|
||||
"max_tokens": 65536,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1.2e-06,
|
||||
"source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing",
|
||||
"supports_function_calling": true,
|
||||
"supports_reasoning": true,
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"dashscope/qwen3-vl-235b-a22b-instruct": {
|
||||
"input_cost_per_token": 4e-07,
|
||||
"litellm_provider": "dashscope",
|
||||
"max_input_tokens": 131072,
|
||||
"max_output_tokens": 32768,
|
||||
"max_tokens": 32768,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1.6e-06,
|
||||
"source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing",
|
||||
"supports_function_calling": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"dashscope/qwen3-vl-235b-a22b-thinking": {
|
||||
"input_cost_per_token": 4e-07,
|
||||
"litellm_provider": "dashscope",
|
||||
"max_input_tokens": 131072,
|
||||
"max_output_tokens": 32768,
|
||||
"max_tokens": 32768,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 4e-06,
|
||||
"source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing",
|
||||
"supports_function_calling": true,
|
||||
"supports_reasoning": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"dashscope/qwen3-vl-32b-instruct": {
|
||||
"input_cost_per_token": 1.6e-07,
|
||||
"litellm_provider": "dashscope",
|
||||
"max_input_tokens": 131072,
|
||||
"max_output_tokens": 32768,
|
||||
"max_tokens": 32768,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 6.4e-07,
|
||||
"source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing",
|
||||
"supports_function_calling": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"dashscope/qwen3-vl-32b-thinking": {
|
||||
"input_cost_per_token": 1.6e-07,
|
||||
"litellm_provider": "dashscope",
|
||||
"max_input_tokens": 131072,
|
||||
"max_output_tokens": 32768,
|
||||
"max_tokens": 32768,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 2.87e-06,
|
||||
"source": "https://www.alibabacloud.com/help/en/model-studio/model-pricing",
|
||||
"supports_function_calling": true,
|
||||
"supports_reasoning": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"dashscope/qwen3-vl-plus": {
|
||||
"litellm_provider": "dashscope",
|
||||
"max_input_tokens": 260096,
|
||||
@ -10844,7 +10960,8 @@
|
||||
"output_cost_per_token": 9e-08,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/NousResearch/Hermes-3-Llama-3.1-405B": {
|
||||
"max_tokens": 131072,
|
||||
@ -10854,7 +10971,8 @@
|
||||
"output_cost_per_token": 1e-06,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/NousResearch/Hermes-3-Llama-3.1-70B": {
|
||||
"max_tokens": 131072,
|
||||
@ -10874,7 +10992,8 @@
|
||||
"output_cost_per_token": 4e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/Qwen/Qwen2.5-72B-Instruct": {
|
||||
"max_tokens": 32768,
|
||||
@ -10884,7 +11003,8 @@
|
||||
"output_cost_per_token": 3.9e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/Qwen/Qwen2.5-7B-Instruct": {
|
||||
"max_tokens": 32768,
|
||||
@ -10905,7 +11025,8 @@
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true,
|
||||
"supports_vision": true
|
||||
"supports_vision": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/Qwen/Qwen3-14B": {
|
||||
"max_tokens": 40960,
|
||||
@ -10915,7 +11036,8 @@
|
||||
"output_cost_per_token": 2.4e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/Qwen/Qwen3-235B-A22B": {
|
||||
"max_tokens": 40960,
|
||||
@ -10925,7 +11047,8 @@
|
||||
"output_cost_per_token": 5.4e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/Qwen/Qwen3-235B-A22B-Instruct-2507": {
|
||||
"max_tokens": 262144,
|
||||
@ -10935,7 +11058,8 @@
|
||||
"output_cost_per_token": 6e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/Qwen/Qwen3-235B-A22B-Thinking-2507": {
|
||||
"max_tokens": 262144,
|
||||
@ -10945,7 +11069,8 @@
|
||||
"output_cost_per_token": 2.9e-06,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/Qwen/Qwen3-30B-A3B": {
|
||||
"max_tokens": 40960,
|
||||
@ -10955,7 +11080,8 @@
|
||||
"output_cost_per_token": 2.9e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/Qwen/Qwen3-32B": {
|
||||
"max_tokens": 40960,
|
||||
@ -10965,7 +11091,8 @@
|
||||
"output_cost_per_token": 2.8e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/Qwen/Qwen3-Coder-480B-A35B-Instruct": {
|
||||
"max_tokens": 262144,
|
||||
@ -10975,7 +11102,8 @@
|
||||
"output_cost_per_token": 1.6e-06,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/Qwen/Qwen3-Coder-480B-A35B-Instruct-Turbo": {
|
||||
"max_tokens": 262144,
|
||||
@ -10985,7 +11113,8 @@
|
||||
"output_cost_per_token": 1.2e-06,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/Qwen/Qwen3-Next-80B-A3B-Instruct": {
|
||||
"max_tokens": 262144,
|
||||
@ -10995,7 +11124,8 @@
|
||||
"output_cost_per_token": 1.4e-06,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/Qwen/Qwen3-Next-80B-A3B-Thinking": {
|
||||
"max_tokens": 262144,
|
||||
@ -11005,7 +11135,8 @@
|
||||
"output_cost_per_token": 1.4e-06,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/Sao10K/L3-8B-Lunaris-v1-Turbo": {
|
||||
"max_tokens": 8192,
|
||||
@ -11056,7 +11187,8 @@
|
||||
"cache_read_input_token_cost": 3.3e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/anthropic/claude-4-opus": {
|
||||
"max_tokens": 200000,
|
||||
@ -11066,7 +11198,8 @@
|
||||
"output_cost_per_token": 8.25e-05,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/anthropic/claude-4-sonnet": {
|
||||
"max_tokens": 200000,
|
||||
@ -11076,7 +11209,8 @@
|
||||
"output_cost_per_token": 1.65e-05,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/deepseek-ai/DeepSeek-R1": {
|
||||
"max_tokens": 163840,
|
||||
@ -11086,7 +11220,8 @@
|
||||
"output_cost_per_token": 2.4e-06,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/deepseek-ai/DeepSeek-R1-0528": {
|
||||
"max_tokens": 163840,
|
||||
@ -11097,7 +11232,8 @@
|
||||
"cache_read_input_token_cost": 4e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/deepseek-ai/DeepSeek-R1-0528-Turbo": {
|
||||
"max_tokens": 32768,
|
||||
@ -11107,7 +11243,8 @@
|
||||
"output_cost_per_token": 3e-06,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/deepseek-ai/DeepSeek-R1-Distill-Llama-70B": {
|
||||
"max_tokens": 131072,
|
||||
@ -11127,7 +11264,8 @@
|
||||
"output_cost_per_token": 2.7e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/deepseek-ai/DeepSeek-R1-Turbo": {
|
||||
"max_tokens": 40960,
|
||||
@ -11137,7 +11275,8 @@
|
||||
"output_cost_per_token": 3e-06,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/deepseek-ai/DeepSeek-V3": {
|
||||
"max_tokens": 163840,
|
||||
@ -11147,7 +11286,8 @@
|
||||
"output_cost_per_token": 8.9e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/deepseek-ai/DeepSeek-V3-0324": {
|
||||
"max_tokens": 163840,
|
||||
@ -11157,7 +11297,8 @@
|
||||
"output_cost_per_token": 8.8e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/deepseek-ai/DeepSeek-V3.1": {
|
||||
"max_tokens": 163840,
|
||||
@ -11169,7 +11310,8 @@
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true,
|
||||
"supports_reasoning": true
|
||||
"supports_reasoning": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/deepseek-ai/DeepSeek-V3.1-Terminus": {
|
||||
"max_tokens": 163840,
|
||||
@ -11180,7 +11322,8 @@
|
||||
"cache_read_input_token_cost": 2.16e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/google/gemini-2.0-flash-001": {
|
||||
"deprecation_date": "2026-06-01",
|
||||
@ -11191,7 +11334,8 @@
|
||||
"output_cost_per_token": 4e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/google/gemini-2.5-flash": {
|
||||
"max_tokens": 1000000,
|
||||
@ -11201,7 +11345,8 @@
|
||||
"output_cost_per_token": 2.5e-06,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/google/gemini-2.5-pro": {
|
||||
"max_tokens": 1000000,
|
||||
@ -11211,7 +11356,8 @@
|
||||
"output_cost_per_token": 1e-05,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/google/gemma-3-12b-it": {
|
||||
"max_tokens": 131072,
|
||||
@ -11221,7 +11367,8 @@
|
||||
"output_cost_per_token": 1e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/google/gemma-3-27b-it": {
|
||||
"max_tokens": 131072,
|
||||
@ -11231,7 +11378,8 @@
|
||||
"output_cost_per_token": 1.6e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/google/gemma-3-4b-it": {
|
||||
"max_tokens": 131072,
|
||||
@ -11241,7 +11389,8 @@
|
||||
"output_cost_per_token": 8e-08,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/meta-llama/Llama-3.2-11B-Vision-Instruct": {
|
||||
"max_tokens": 131072,
|
||||
@ -11261,7 +11410,8 @@
|
||||
"output_cost_per_token": 2e-08,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/meta-llama/Llama-3.3-70B-Instruct": {
|
||||
"max_tokens": 131072,
|
||||
@ -11271,7 +11421,8 @@
|
||||
"output_cost_per_token": 4e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/meta-llama/Llama-3.3-70B-Instruct-Turbo": {
|
||||
"max_tokens": 131072,
|
||||
@ -11281,6 +11432,7 @@
|
||||
"output_cost_per_token": 3.9e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"deepinfra/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8": {
|
||||
@ -11291,7 +11443,8 @@
|
||||
"output_cost_per_token": 6e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/meta-llama/Llama-4-Scout-17B-16E-Instruct": {
|
||||
"max_tokens": 327680,
|
||||
@ -11301,7 +11454,8 @@
|
||||
"output_cost_per_token": 3e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/meta-llama/Llama-Guard-3-8B": {
|
||||
"max_tokens": 131072,
|
||||
@ -11331,7 +11485,8 @@
|
||||
"output_cost_per_token": 6e-08,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/meta-llama/Meta-Llama-3.1-70B-Instruct": {
|
||||
"max_tokens": 131072,
|
||||
@ -11341,7 +11496,8 @@
|
||||
"output_cost_per_token": 4e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
|
||||
"max_tokens": 131072,
|
||||
@ -11351,7 +11507,8 @@
|
||||
"output_cost_per_token": 2.8e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/meta-llama/Meta-Llama-3.1-8B-Instruct": {
|
||||
"max_tokens": 131072,
|
||||
@ -11361,7 +11518,8 @@
|
||||
"output_cost_per_token": 5e-08,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo": {
|
||||
"max_tokens": 131072,
|
||||
@ -11371,7 +11529,8 @@
|
||||
"output_cost_per_token": 3e-08,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/microsoft/WizardLM-2-8x22B": {
|
||||
"max_tokens": 65536,
|
||||
@ -11391,7 +11550,8 @@
|
||||
"output_cost_per_token": 1.4e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/mistralai/Mistral-Nemo-Instruct-2407": {
|
||||
"max_tokens": 131072,
|
||||
@ -11401,7 +11561,8 @@
|
||||
"output_cost_per_token": 4e-08,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/mistralai/Mistral-Small-24B-Instruct-2501": {
|
||||
"max_tokens": 32768,
|
||||
@ -11411,7 +11572,8 @@
|
||||
"output_cost_per_token": 8e-08,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/mistralai/Mistral-Small-3.2-24B-Instruct-2506": {
|
||||
"max_tokens": 128000,
|
||||
@ -11421,7 +11583,8 @@
|
||||
"output_cost_per_token": 2e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/mistralai/Mixtral-8x7B-Instruct-v0.1": {
|
||||
"max_tokens": 32768,
|
||||
@ -11431,7 +11594,8 @@
|
||||
"output_cost_per_token": 4e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/moonshotai/Kimi-K2-Instruct": {
|
||||
"max_tokens": 131072,
|
||||
@ -11441,7 +11605,8 @@
|
||||
"output_cost_per_token": 2e-06,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/moonshotai/Kimi-K2-Instruct-0905": {
|
||||
"max_tokens": 262144,
|
||||
@ -11452,7 +11617,8 @@
|
||||
"cache_read_input_token_cost": 4e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/nvidia/Llama-3.1-Nemotron-70B-Instruct": {
|
||||
"max_tokens": 131072,
|
||||
@ -11462,7 +11628,8 @@
|
||||
"output_cost_per_token": 6e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/nvidia/Llama-3.3-Nemotron-Super-49B-v1.5": {
|
||||
"max_tokens": 131072,
|
||||
@ -11472,7 +11639,8 @@
|
||||
"output_cost_per_token": 4e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/nvidia/NVIDIA-Nemotron-Nano-9B-v2": {
|
||||
"max_tokens": 131072,
|
||||
@ -11482,7 +11650,8 @@
|
||||
"output_cost_per_token": 1.6e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/openai/gpt-oss-120b": {
|
||||
"max_tokens": 131072,
|
||||
@ -11492,7 +11661,8 @@
|
||||
"output_cost_per_token": 4.5e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/openai/gpt-oss-20b": {
|
||||
"max_tokens": 131072,
|
||||
@ -11502,7 +11672,8 @@
|
||||
"output_cost_per_token": 1.5e-07,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepinfra/zai-org/GLM-4.5": {
|
||||
"max_tokens": 131072,
|
||||
@ -11512,7 +11683,8 @@
|
||||
"output_cost_per_token": 1.6e-06,
|
||||
"litellm_provider": "deepinfra",
|
||||
"mode": "chat",
|
||||
"supports_tool_choice": true
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"deepseek/deepseek-chat": {
|
||||
"cache_creation_input_token_cost": 0.0,
|
||||
@ -25806,6 +25978,30 @@
|
||||
"supports_vision": true,
|
||||
"tool_use_system_prompt_tokens": 159
|
||||
},
|
||||
"openrouter/anthropic/claude-sonnet-4.6": {
|
||||
"cache_creation_input_token_cost": 3.75e-06,
|
||||
"cache_creation_input_token_cost_above_200k_tokens": 7.5e-06,
|
||||
"cache_read_input_token_cost": 3e-07,
|
||||
"cache_read_input_token_cost_above_200k_tokens": 6e-07,
|
||||
"input_cost_per_token": 3e-06,
|
||||
"input_cost_per_token_above_200k_tokens": 6e-06,
|
||||
"litellm_provider": "openrouter",
|
||||
"max_input_tokens": 1000000,
|
||||
"max_output_tokens": 128000,
|
||||
"max_tokens": 128000,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1.5e-05,
|
||||
"output_cost_per_token_above_200k_tokens": 2.25e-05,
|
||||
"source": "https://openrouter.ai/anthropic/claude-sonnet-4.6",
|
||||
"supports_assistant_prefill": true,
|
||||
"supports_computer_use": true,
|
||||
"supports_function_calling": true,
|
||||
"supports_prompt_caching": true,
|
||||
"supports_reasoning": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_vision": true,
|
||||
"tool_use_system_prompt_tokens": 159
|
||||
},
|
||||
"openrouter/anthropic/claude-opus-4.5": {
|
||||
"cache_creation_input_token_cost": 6.25e-06,
|
||||
"cache_read_input_token_cost": 5e-07,
|
||||
@ -26156,6 +26352,39 @@
|
||||
"supports_web_search": true,
|
||||
"tpm": 800000
|
||||
},
|
||||
"openrouter/google/gemini-3.1-pro-preview": {
|
||||
"cache_read_input_token_cost": 2e-07,
|
||||
"cache_read_input_token_cost_above_200k_tokens": 4e-07,
|
||||
"cache_creation_input_token_cost_above_200k_tokens": 2.5e-07,
|
||||
"input_cost_per_token": 2e-06,
|
||||
"input_cost_per_token_above_200k_tokens": 4e-06,
|
||||
"litellm_provider": "openrouter",
|
||||
"max_input_tokens": 1048576,
|
||||
"max_output_tokens": 65536,
|
||||
"max_tokens": 65536,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1.2e-05,
|
||||
"output_cost_per_token_above_200k_tokens": 1.8e-05,
|
||||
"source": "https://openrouter.ai/google/gemini-3.1-pro-preview",
|
||||
"supported_modalities": [
|
||||
"text",
|
||||
"image",
|
||||
"audio",
|
||||
"video"
|
||||
],
|
||||
"supported_output_modalities": [
|
||||
"text"
|
||||
],
|
||||
"supports_audio_input": true,
|
||||
"supports_function_calling": true,
|
||||
"supports_pdf_input": true,
|
||||
"supports_prompt_caching": true,
|
||||
"supports_reasoning": true,
|
||||
"supports_response_schema": true,
|
||||
"supports_system_messages": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"openrouter/gryphe/mythomax-l2-13b": {
|
||||
"input_cost_per_token": 1.875e-06,
|
||||
"litellm_provider": "openrouter",
|
||||
@ -26533,6 +26762,29 @@
|
||||
"supports_reasoning": true,
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"openrouter/openai/gpt-5.1-codex-max": {
|
||||
"cache_read_input_token_cost": 1.25e-07,
|
||||
"input_cost_per_token": 1.25e-06,
|
||||
"litellm_provider": "openrouter",
|
||||
"max_input_tokens": 400000,
|
||||
"max_output_tokens": 128000,
|
||||
"max_tokens": 128000,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 1e-05,
|
||||
"source": "https://openrouter.ai/openai/gpt-5.1-codex-max",
|
||||
"supported_modalities": [
|
||||
"text",
|
||||
"image"
|
||||
],
|
||||
"supported_output_modalities": [
|
||||
"text"
|
||||
],
|
||||
"supports_function_calling": true,
|
||||
"supports_prompt_caching": true,
|
||||
"supports_reasoning": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_vision": true
|
||||
},
|
||||
"openrouter/openai/gpt-5.2": {
|
||||
"input_cost_per_image": 0,
|
||||
"cache_read_input_token_cost": 1.75e-07,
|
||||
@ -26687,6 +26939,19 @@
|
||||
"supports_tool_choice": true,
|
||||
"supports_function_calling": true
|
||||
},
|
||||
"openrouter/qwen/qwen3-coder-plus": {
|
||||
"input_cost_per_token": 1e-06,
|
||||
"litellm_provider": "openrouter",
|
||||
"max_input_tokens": 997952,
|
||||
"max_output_tokens": 65536,
|
||||
"max_tokens": 65536,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 5e-06,
|
||||
"source": "https://openrouter.ai/qwen/qwen3-coder-plus",
|
||||
"supports_function_calling": true,
|
||||
"supports_reasoning": true,
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"openrouter/qwen/qwen3-235b-a22b-2507": {
|
||||
"input_cost_per_token": 7.1e-08,
|
||||
"litellm_provider": "openrouter",
|
||||
@ -26822,6 +27087,19 @@
|
||||
"supports_vision": true,
|
||||
"supports_prompt_caching": false
|
||||
},
|
||||
"openrouter/z-ai/glm-5": {
|
||||
"input_cost_per_token": 8e-07,
|
||||
"litellm_provider": "openrouter",
|
||||
"max_input_tokens": 202752,
|
||||
"max_output_tokens": 128000,
|
||||
"max_tokens": 128000,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 2.56e-06,
|
||||
"source": "https://openrouter.ai/z-ai/glm-5",
|
||||
"supports_function_calling": true,
|
||||
"supports_reasoning": true,
|
||||
"supports_tool_choice": true
|
||||
},
|
||||
"openrouter/minimax/minimax-m2.1": {
|
||||
"input_cost_per_token": 2.7e-07,
|
||||
"output_cost_per_token": 1.2e-06,
|
||||
@ -34327,6 +34605,36 @@
|
||||
"supports_tool_choice": true,
|
||||
"source": "https://aws.amazon.com/bedrock/pricing/"
|
||||
},
|
||||
"zai/glm-5": {
|
||||
"cache_creation_input_token_cost": 0,
|
||||
"cache_read_input_token_cost": 2e-07,
|
||||
"input_cost_per_token": 1e-06,
|
||||
"output_cost_per_token": 3.2e-06,
|
||||
"litellm_provider": "zai",
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 128000,
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_prompt_caching": true,
|
||||
"supports_reasoning": true,
|
||||
"supports_tool_choice": true,
|
||||
"source": "https://docs.z.ai/guides/overview/pricing"
|
||||
},
|
||||
"zai/glm-5-code": {
|
||||
"cache_creation_input_token_cost": 0,
|
||||
"cache_read_input_token_cost": 3e-07,
|
||||
"input_cost_per_token": 1.2e-06,
|
||||
"output_cost_per_token": 5e-06,
|
||||
"litellm_provider": "zai",
|
||||
"max_input_tokens": 200000,
|
||||
"max_output_tokens": 128000,
|
||||
"mode": "chat",
|
||||
"supports_function_calling": true,
|
||||
"supports_prompt_caching": true,
|
||||
"supports_reasoning": true,
|
||||
"supports_tool_choice": true,
|
||||
"source": "https://docs.z.ai/guides/overview/pricing"
|
||||
},
|
||||
"zai/glm-4.7": {
|
||||
"cache_creation_input_token_cost": 0,
|
||||
"cache_read_input_token_cost": 1.1e-07,
|
||||
|
||||
@ -23,33 +23,11 @@ model_list:
|
||||
|
||||
|
||||
guardrails:
|
||||
- guardrail_name: "airline-competitor-intent"
|
||||
guardrail_id: "airline-competitor-intent"
|
||||
- guardrail_name: "tool_policy"
|
||||
litellm_params:
|
||||
guardrail: litellm_content_filter
|
||||
mode: pre_call
|
||||
default_on: false
|
||||
competitor_intent_config:
|
||||
brand_self:
|
||||
- emirates
|
||||
- ek
|
||||
competitors:
|
||||
- qatar airways
|
||||
- qatar
|
||||
- etihad
|
||||
locations:
|
||||
- qatar
|
||||
- doha
|
||||
- doh
|
||||
competitor_aliases:
|
||||
qatar airways: [qr, doha airline]
|
||||
qatar: [qr]
|
||||
policy:
|
||||
competitor_comparison: refuse
|
||||
possible_competitor_comparison: reframe
|
||||
threshold_high: 0.70
|
||||
threshold_medium: 0.45
|
||||
threshold_low: 0.30
|
||||
guardrail: tool_policy
|
||||
mode: [pre_call, post_call]
|
||||
default_on: true
|
||||
|
||||
mcp_servers:
|
||||
my_http_server:
|
||||
|
||||
@ -77,6 +77,7 @@ class SupportedDBObjectType(str, enum.Enum):
|
||||
PASS_THROUGH_ENDPOINTS = "pass_through_endpoints"
|
||||
PROMPTS = "prompts"
|
||||
MODEL_COST_MAP = "model_cost_map"
|
||||
TOOLS = "tools"
|
||||
|
||||
def __str__(self):
|
||||
return str(self.value)
|
||||
@ -2133,7 +2134,7 @@ class ConfigGeneralSettings(LiteLLMPydanticObjectBase):
|
||||
user_header_mappings: Optional[List[UserHeaderMapping]] = None
|
||||
supported_db_objects: Optional[List[SupportedDBObjectType]] = Field(
|
||||
None,
|
||||
description="Fine-grained control over which object types to load from the database when store_model_in_db is True. Available types: 'models', 'mcp', 'guardrails', 'vector_stores', 'pass_through_endpoints', 'prompts', 'model_cost_map'. If not set, all objects are loaded (default behavior).",
|
||||
description="Fine-grained control over which object types to load from the database when store_model_in_db is True. Available types: 'models', 'mcp', 'guardrails', 'vector_stores', 'pass_through_endpoints', 'prompts', 'model_cost_map', 'tools'. If not set, all objects are loaded (default behavior).",
|
||||
)
|
||||
user_mcp_management_mode: Optional[UserMCPManagementMode] = Field(
|
||||
None,
|
||||
@ -3377,6 +3378,11 @@ class ProxyErrorTypes(str, enum.Enum):
|
||||
Team member is already in team
|
||||
"""
|
||||
|
||||
tool_access_denied = "tool_access_denied"
|
||||
"""
|
||||
Tool is not in the allowed tools list for this key/team
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def get_model_access_error_type_for_object(
|
||||
cls, object_type: Literal["key", "user", "team", "org", "project"]
|
||||
@ -4161,6 +4167,7 @@ class ToolDiscoveryQueueItem(TypedDict, total=False):
|
||||
key_hash: Optional[str] # hash of virtual key that triggered discovery
|
||||
team_id: Optional[str] # team that triggered discovery
|
||||
key_alias: Optional[str] # human-readable key alias
|
||||
user_agent: Optional[str] # HTTP User-Agent of the caller
|
||||
|
||||
|
||||
class LiteLLM_ManagedFileTable(LiteLLMPydanticObjectBase):
|
||||
|
||||
@ -58,6 +58,10 @@ from litellm.proxy._types import (
|
||||
)
|
||||
from litellm.proxy.auth.route_checks import RouteChecks
|
||||
from litellm.proxy.db.exception_handler import PrismaDBExceptionHandler
|
||||
from litellm.proxy.guardrails.tool_name_extraction import (
|
||||
TOOL_CAPABLE_CALL_TYPES,
|
||||
extract_request_tool_names,
|
||||
)
|
||||
from litellm.proxy.route_llm_request import route_request
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
|
||||
from litellm.router import Router
|
||||
@ -220,7 +224,48 @@ async def _run_project_checks(
|
||||
)
|
||||
|
||||
|
||||
async def common_checks(
|
||||
async def check_tools_allowlist(
|
||||
request_body: dict,
|
||||
valid_token: Optional[UserAPIKeyAuth],
|
||||
team_object: Optional[LiteLLM_TeamTable],
|
||||
route: str,
|
||||
) -> None:
|
||||
"""
|
||||
Enforce key/team tool allowlist (metadata.allowed_tools). No DB in hot path —
|
||||
effective allowlist is read from valid_token.metadata and valid_token.team_metadata.
|
||||
Raises ProxyException with tool_access_denied if a tool is not allowed.
|
||||
"""
|
||||
from litellm.litellm_core_utils.api_route_to_call_types import (
|
||||
get_call_types_for_route,
|
||||
)
|
||||
|
||||
if valid_token is None:
|
||||
return
|
||||
call_types = get_call_types_for_route(route)
|
||||
if not call_types or not any(ct.value in TOOL_CAPABLE_CALL_TYPES for ct in call_types):
|
||||
return
|
||||
tool_names = extract_request_tool_names(route, request_body)
|
||||
if not tool_names:
|
||||
return
|
||||
key_meta = (valid_token.metadata or {}) if isinstance(valid_token.metadata, dict) else {}
|
||||
team_meta = (valid_token.team_metadata or {}) if isinstance(valid_token.team_metadata, dict) else {}
|
||||
key_allowed = key_meta.get("allowed_tools")
|
||||
team_allowed = team_meta.get("allowed_tools")
|
||||
effective = key_allowed if (isinstance(key_allowed, list) and len(key_allowed) > 0) else team_allowed
|
||||
if not isinstance(effective, list) or len(effective) == 0:
|
||||
return
|
||||
allowed_set = {str(t) for t in effective}
|
||||
disallowed = [n for n in tool_names if n not in allowed_set]
|
||||
if disallowed:
|
||||
raise ProxyException(
|
||||
message=f"Tool(s) {disallowed} are not in the allowed tools list for this key/team.",
|
||||
type=ProxyErrorTypes.tool_access_denied,
|
||||
param="tools",
|
||||
code=status.HTTP_403_FORBIDDEN,
|
||||
)
|
||||
|
||||
|
||||
async def common_checks( # noqa: PLR0915
|
||||
request_body: dict,
|
||||
team_object: Optional[LiteLLM_TeamTable],
|
||||
user_object: Optional[LiteLLM_UserTable],
|
||||
@ -477,6 +522,14 @@ async def common_checks(
|
||||
valid_token=valid_token,
|
||||
)
|
||||
|
||||
# 12. [OPTIONAL] Tool allowlist - key/team allowed_tools (no DB in hot path)
|
||||
await check_tools_allowlist(
|
||||
request_body=request_body,
|
||||
valid_token=valid_token,
|
||||
team_object=team_object,
|
||||
route=route,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@ -13,49 +13,36 @@ import random
|
||||
import time
|
||||
import traceback
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Optional,
|
||||
Union,
|
||||
cast,
|
||||
overload,
|
||||
)
|
||||
from typing import (TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union,
|
||||
cast, overload)
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import DualCache, RedisCache
|
||||
from litellm.constants import DB_SPEND_UPDATE_JOB_NAME
|
||||
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
|
||||
from litellm.proxy._types import (
|
||||
DB_CONNECTION_ERROR_TYPES,
|
||||
BaseDailySpendTransaction,
|
||||
DailyAgentSpendTransaction,
|
||||
DailyEndUserSpendTransaction,
|
||||
DailyOrganizationSpendTransaction,
|
||||
DailyTagSpendTransaction,
|
||||
DailyTeamSpendTransaction,
|
||||
DailyUserSpendTransaction,
|
||||
DBSpendUpdateTransactions,
|
||||
Litellm_EntityType,
|
||||
LiteLLM_UserTable,
|
||||
SpendLogsMetadata,
|
||||
SpendLogsPayload,
|
||||
SpendUpdateQueueItem,
|
||||
ToolDiscoveryQueueItem,
|
||||
)
|
||||
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import (
|
||||
DailySpendUpdateQueue,
|
||||
)
|
||||
from litellm.proxy.db.db_transaction_queue.pod_lock_manager import PodLockManager
|
||||
from litellm.proxy.db.db_transaction_queue.redis_update_buffer import RedisUpdateBuffer
|
||||
from litellm.proxy.db.db_transaction_queue.spend_update_queue import SpendUpdateQueue
|
||||
from litellm.proxy.db.db_transaction_queue.tool_discovery_queue import (
|
||||
ToolDiscoveryQueue,
|
||||
)
|
||||
from litellm.proxy._types import (DB_CONNECTION_ERROR_TYPES,
|
||||
BaseDailySpendTransaction,
|
||||
DailyAgentSpendTransaction,
|
||||
DailyEndUserSpendTransaction,
|
||||
DailyOrganizationSpendTransaction,
|
||||
DailyTagSpendTransaction,
|
||||
DailyTeamSpendTransaction,
|
||||
DailyUserSpendTransaction,
|
||||
DBSpendUpdateTransactions,
|
||||
Litellm_EntityType, LiteLLM_UserTable,
|
||||
SpendLogsMetadata, SpendLogsPayload,
|
||||
SpendUpdateQueueItem, ToolDiscoveryQueueItem)
|
||||
from litellm.proxy.db.db_transaction_queue.daily_spend_update_queue import \
|
||||
DailySpendUpdateQueue
|
||||
from litellm.proxy.db.db_transaction_queue.pod_lock_manager import \
|
||||
PodLockManager
|
||||
from litellm.proxy.db.db_transaction_queue.redis_update_buffer import \
|
||||
RedisUpdateBuffer
|
||||
from litellm.proxy.db.db_transaction_queue.spend_update_queue import \
|
||||
SpendUpdateQueue
|
||||
from litellm.proxy.db.db_transaction_queue.tool_discovery_queue import \
|
||||
ToolDiscoveryQueue
|
||||
from litellm.proxy.route_llm_request import ROUTE_ENDPOINT_MAPPING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -104,12 +91,10 @@ class DBSpendUpdateWriter:
|
||||
end_time: Optional[datetime],
|
||||
response_cost: Optional[float],
|
||||
):
|
||||
from litellm.proxy.proxy_server import (
|
||||
disable_spend_logs,
|
||||
litellm_proxy_budget_name,
|
||||
prisma_client,
|
||||
user_api_key_cache,
|
||||
)
|
||||
from litellm.proxy.proxy_server import (disable_spend_logs,
|
||||
litellm_proxy_budget_name,
|
||||
prisma_client,
|
||||
user_api_key_cache)
|
||||
from litellm.proxy.utils import ProxyUpdateSpend, hash_token
|
||||
|
||||
try:
|
||||
@ -124,9 +109,8 @@ class DBSpendUpdateWriter:
|
||||
hashed_token = token
|
||||
|
||||
## CREATE SPEND LOG PAYLOAD ##
|
||||
from litellm.proxy.spend_tracking.spend_tracking_utils import (
|
||||
get_logging_payload,
|
||||
)
|
||||
from litellm.proxy.spend_tracking.spend_tracking_utils import \
|
||||
get_logging_payload
|
||||
|
||||
payload = get_logging_payload(
|
||||
kwargs=kwargs,
|
||||
@ -230,6 +214,7 @@ class DBSpendUpdateWriter:
|
||||
_litellm_params = kwargs.get("litellm_params") or {}
|
||||
_metadata = _litellm_params.get("metadata") or {}
|
||||
key_alias = _metadata.get("user_api_key_alias") or None
|
||||
user_agent = _metadata.get("user_agent") or None
|
||||
|
||||
def _enqueue(tool_name: str, origin: str = "user_defined") -> None:
|
||||
self.tool_discovery_queue.add_update(
|
||||
@ -239,17 +224,20 @@ class DBSpendUpdateWriter:
|
||||
key_hash=hashed_token,
|
||||
team_id=team_id,
|
||||
key_alias=key_alias,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
)
|
||||
|
||||
# --- MCP tool calls ---
|
||||
sl_object = kwargs.get("standard_logging_object")
|
||||
if sl_object is not None:
|
||||
mcp_metadata = (
|
||||
sl_object.get("metadata", {}) or {}
|
||||
).get("mcp_tool_call_metadata")
|
||||
mcp_metadata = (sl_object.get("metadata", {}) or {}).get(
|
||||
"mcp_tool_call_metadata"
|
||||
)
|
||||
if mcp_metadata and isinstance(mcp_metadata, dict):
|
||||
tool_name = mcp_metadata.get("namespaced_tool_name") or mcp_metadata.get("name")
|
||||
tool_name = mcp_metadata.get(
|
||||
"namespaced_tool_name"
|
||||
) or mcp_metadata.get("name")
|
||||
mcp_server_name = mcp_metadata.get("mcp_server_name")
|
||||
if tool_name:
|
||||
_enqueue(tool_name, origin=mcp_server_name or "user_defined")
|
||||
@ -280,7 +268,9 @@ class DBSpendUpdateWriter:
|
||||
_enqueue(name)
|
||||
|
||||
# --- Response tool_calls (OpenAI format; Anthropic pass-through converts tool_use here) ---
|
||||
if completion_response is not None and hasattr(completion_response, "choices"):
|
||||
if completion_response is not None and hasattr(
|
||||
completion_response, "choices"
|
||||
):
|
||||
for choice in completion_response.choices or []:
|
||||
message = getattr(choice, "message", None)
|
||||
if message is None:
|
||||
@ -768,19 +758,46 @@ class DBSpendUpdateWriter:
|
||||
daily_end_user_spend_update_transactions,
|
||||
daily_agent_spend_update_transactions,
|
||||
daily_tag_spend_update_transactions,
|
||||
) = await self.redis_update_buffer.get_all_transactions_from_redis_buffer_pipeline()
|
||||
) = (
|
||||
await self.redis_update_buffer.get_all_transactions_from_redis_buffer_pipeline()
|
||||
)
|
||||
|
||||
if db_spend_update_transactions is not None:
|
||||
verbose_proxy_logger.info(
|
||||
"Spend tracking - committing spend updates from Redis to DB: "
|
||||
"keys=%d, users=%d, teams=%d, orgs=%d, end_users=%d, team_members=%d, tags=%d",
|
||||
len(db_spend_update_transactions.get("key_list_transactions") or {}),
|
||||
len(db_spend_update_transactions.get("user_list_transactions") or {}),
|
||||
len(db_spend_update_transactions.get("team_list_transactions") or {}),
|
||||
len(db_spend_update_transactions.get("org_list_transactions") or {}),
|
||||
len(db_spend_update_transactions.get("end_user_list_transactions") or {}),
|
||||
len(db_spend_update_transactions.get("team_member_list_transactions") or {}),
|
||||
len(db_spend_update_transactions.get("tag_list_transactions") or {}),
|
||||
len(
|
||||
db_spend_update_transactions.get("key_list_transactions")
|
||||
or {}
|
||||
),
|
||||
len(
|
||||
db_spend_update_transactions.get("user_list_transactions")
|
||||
or {}
|
||||
),
|
||||
len(
|
||||
db_spend_update_transactions.get("team_list_transactions")
|
||||
or {}
|
||||
),
|
||||
len(
|
||||
db_spend_update_transactions.get("org_list_transactions")
|
||||
or {}
|
||||
),
|
||||
len(
|
||||
db_spend_update_transactions.get(
|
||||
"end_user_list_transactions"
|
||||
)
|
||||
or {}
|
||||
),
|
||||
len(
|
||||
db_spend_update_transactions.get(
|
||||
"team_member_list_transactions"
|
||||
)
|
||||
or {}
|
||||
),
|
||||
len(
|
||||
db_spend_update_transactions.get("tag_list_transactions")
|
||||
or {}
|
||||
),
|
||||
)
|
||||
await self._commit_spend_updates_to_db(
|
||||
prisma_client=prisma_client,
|
||||
@ -985,10 +1002,8 @@ class DBSpendUpdateWriter:
|
||||
Commits all the spend `UPDATE` transactions to the Database
|
||||
|
||||
"""
|
||||
from litellm.proxy.utils import (
|
||||
ProxyUpdateSpend,
|
||||
_raise_failed_update_spend_exception,
|
||||
)
|
||||
from litellm.proxy.utils import (ProxyUpdateSpend,
|
||||
_raise_failed_update_spend_exception)
|
||||
|
||||
### UPDATE USER TABLE ###
|
||||
user_list_transactions = db_spend_update_transactions["user_list_transactions"]
|
||||
@ -1523,14 +1538,14 @@ class DBSpendUpdateWriter:
|
||||
|
||||
# Add cache-related fields if they exist
|
||||
if "cache_read_input_tokens" in transaction:
|
||||
common_data[
|
||||
"cache_read_input_tokens"
|
||||
] = transaction.get("cache_read_input_tokens", 0)
|
||||
common_data["cache_read_input_tokens"] = (
|
||||
transaction.get("cache_read_input_tokens", 0)
|
||||
)
|
||||
if "cache_creation_input_tokens" in transaction:
|
||||
common_data[
|
||||
"cache_creation_input_tokens"
|
||||
] = transaction.get(
|
||||
"cache_creation_input_tokens", 0
|
||||
common_data["cache_creation_input_tokens"] = (
|
||||
transaction.get(
|
||||
"cache_creation_input_tokens", 0
|
||||
)
|
||||
)
|
||||
|
||||
if entity_type == "tag" and "request_id" in transaction:
|
||||
|
||||
147
litellm/proxy/db/spend_log_tool_index.py
Normal file
147
litellm/proxy/db/spend_log_tool_index.py
Normal file
@ -0,0 +1,147 @@
|
||||
"""
|
||||
Track tool usage for the dashboard: insert into SpendLogToolIndex when spend logs
|
||||
are written, so "last N requests for tool X" and "how is this tool called in production"
|
||||
queries are fast.
|
||||
"""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Set
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.litellm_core_utils.safe_json_loads import safe_json_loads
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
|
||||
def _add_tool_calls_to_set(tool_calls: Any, out: Set[str]) -> None:
|
||||
"""Extract tool names from OpenAI-style tool_calls list into out."""
|
||||
if not isinstance(tool_calls, list):
|
||||
return
|
||||
for tc in tool_calls:
|
||||
if not isinstance(tc, dict):
|
||||
continue
|
||||
fn = tc.get("function")
|
||||
if isinstance(fn, dict):
|
||||
name = fn.get("name")
|
||||
if name and isinstance(name, str) and name.strip():
|
||||
out.add(name.strip())
|
||||
|
||||
|
||||
def _parse_tool_names_from_payload(payload: Dict[str, Any]) -> Set[str]:
|
||||
"""
|
||||
Extract deduplicated tool names from a spend log payload.
|
||||
Sources: mcp_namespaced_tool_name, response (tool_calls), proxy_server_request (tools).
|
||||
"""
|
||||
tool_names: Set[str] = set()
|
||||
|
||||
# Top-level MCP tool name (single tool per request for that flow)
|
||||
mcp_name = payload.get("mcp_namespaced_tool_name")
|
||||
if mcp_name and isinstance(mcp_name, str) and mcp_name.strip():
|
||||
tool_names.add(mcp_name.strip())
|
||||
|
||||
# Response: OpenAI-style tool_calls[].function.name or choices[0].message.tool_calls
|
||||
response_raw = payload.get("response")
|
||||
if response_raw:
|
||||
response_obj = (
|
||||
safe_json_loads(response_raw, default=None)
|
||||
if isinstance(response_raw, str)
|
||||
else response_raw
|
||||
)
|
||||
if isinstance(response_obj, dict):
|
||||
_add_tool_calls_to_set(response_obj.get("tool_calls"), tool_names)
|
||||
choices = response_obj.get("choices")
|
||||
if isinstance(choices, list) and choices:
|
||||
msg = choices[0].get("message") if isinstance(choices[0], dict) else None
|
||||
if isinstance(msg, dict):
|
||||
_add_tool_calls_to_set(msg.get("tool_calls"), tool_names)
|
||||
|
||||
# Request body: tools[].function.name
|
||||
request_raw = payload.get("proxy_server_request")
|
||||
if request_raw:
|
||||
request_obj = (
|
||||
safe_json_loads(request_raw, default=None)
|
||||
if isinstance(request_raw, str)
|
||||
else request_raw
|
||||
)
|
||||
if isinstance(request_obj, dict):
|
||||
body = request_obj.get("body", request_obj)
|
||||
if isinstance(body, dict):
|
||||
request_obj = body
|
||||
if isinstance(request_obj, dict):
|
||||
tools = request_obj.get("tools")
|
||||
if isinstance(tools, list):
|
||||
for t in tools:
|
||||
if isinstance(t, dict):
|
||||
fn = t.get("function")
|
||||
if isinstance(fn, dict):
|
||||
name = fn.get("name")
|
||||
if name and isinstance(name, str) and name.strip():
|
||||
tool_names.add(name.strip())
|
||||
|
||||
return tool_names
|
||||
|
||||
|
||||
async def process_spend_logs_tool_usage(
|
||||
prisma_client: PrismaClient,
|
||||
logs_to_process: List[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""
|
||||
After spend logs are written: insert SpendLogToolIndex rows from each payload.
|
||||
Extracts tool names from mcp_namespaced_tool_name, response tool_calls, and
|
||||
proxy_server_request tools.
|
||||
"""
|
||||
if not logs_to_process:
|
||||
return
|
||||
|
||||
index_rows: List[Dict[str, Any]] = []
|
||||
|
||||
for payload in logs_to_process:
|
||||
request_id = payload.get("request_id")
|
||||
start_time = payload.get("startTime")
|
||||
if not request_id or not start_time:
|
||||
continue
|
||||
if isinstance(start_time, str):
|
||||
try:
|
||||
start_time = datetime.fromisoformat(
|
||||
start_time.replace("Z", "+00:00")
|
||||
)
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
if start_time.tzinfo is None:
|
||||
start_time = start_time.replace(tzinfo=timezone.utc)
|
||||
|
||||
tool_names = _parse_tool_names_from_payload(payload)
|
||||
for tool_name in tool_names:
|
||||
index_rows.append({
|
||||
"request_id": request_id,
|
||||
"tool_name": tool_name,
|
||||
"start_time": start_time,
|
||||
})
|
||||
|
||||
if not index_rows:
|
||||
return
|
||||
|
||||
try:
|
||||
index_data = []
|
||||
for r in index_rows:
|
||||
st = r["start_time"]
|
||||
if isinstance(st, str):
|
||||
try:
|
||||
st = datetime.fromisoformat(st.replace("Z", "+00:00"))
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
if st.tzinfo is None:
|
||||
st = st.replace(tzinfo=timezone.utc)
|
||||
index_data.append({
|
||||
"request_id": r["request_id"],
|
||||
"tool_name": r["tool_name"],
|
||||
"start_time": st,
|
||||
})
|
||||
if index_data:
|
||||
await prisma_client.db.litellm_spendlogtoolindex.create_many(
|
||||
data=index_data,
|
||||
skip_duplicates=True,
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.warning(
|
||||
"Tool usage tracking (SpendLogToolIndex) failed (non-fatal): %s", e
|
||||
)
|
||||
@ -2,36 +2,64 @@
|
||||
DB helpers for LiteLLM_ToolTable — the global tool registry.
|
||||
|
||||
Tools are auto-discovered from LLM responses and upserted here.
|
||||
Admins use the management endpoints to read and update call_policy.
|
||||
|
||||
NOTE: Uses raw SQL (query_raw / execute_raw) instead of Prisma model methods
|
||||
because the generated Prisma Python client may not have LiteLLM_ToolTable
|
||||
when running against an older generated schema.
|
||||
Admins use the management endpoints to read and update input_policy / output_policy.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import ToolDiscoveryQueueItem
|
||||
from litellm.types.tool_management import LiteLLM_ToolTableRow, ToolCallPolicy
|
||||
from litellm.types.tool_management import (
|
||||
LiteLLM_ToolTableRow,
|
||||
ToolPolicyOverrideRow,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
|
||||
def _row_to_model(row: dict) -> LiteLLM_ToolTableRow:
|
||||
def _row_to_model(row: Union[dict, Any]) -> LiteLLM_ToolTableRow:
|
||||
"""Convert a Prisma model instance or dict to LiteLLM_ToolTableRow."""
|
||||
model_dump = getattr(row, "model_dump", None)
|
||||
if callable(model_dump):
|
||||
row = model_dump()
|
||||
elif not isinstance(row, dict):
|
||||
row = {
|
||||
k: getattr(row, k, None)
|
||||
for k in (
|
||||
"tool_id",
|
||||
"tool_name",
|
||||
"origin",
|
||||
"input_policy",
|
||||
"output_policy",
|
||||
"call_count",
|
||||
"assignments",
|
||||
"key_hash",
|
||||
"team_id",
|
||||
"key_alias",
|
||||
"user_agent",
|
||||
"last_used_at",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"created_by",
|
||||
"updated_by",
|
||||
)
|
||||
}
|
||||
return LiteLLM_ToolTableRow(
|
||||
tool_id=row.get("tool_id", ""),
|
||||
tool_name=row.get("tool_name", ""),
|
||||
origin=row.get("origin"),
|
||||
call_policy=row.get("call_policy", "untrusted"),
|
||||
input_policy=row.get("input_policy") or "untrusted",
|
||||
output_policy=row.get("output_policy") or "untrusted",
|
||||
call_count=int(row.get("call_count") or 0),
|
||||
assignments=row.get("assignments"),
|
||||
key_hash=row.get("key_hash"),
|
||||
team_id=row.get("team_id"),
|
||||
key_alias=row.get("key_alias"),
|
||||
user_agent=row.get("user_agent"),
|
||||
last_used_at=row.get("last_used_at"),
|
||||
created_at=row.get("created_at"),
|
||||
updated_at=row.get("updated_at"),
|
||||
created_by=row.get("created_by"),
|
||||
@ -44,10 +72,10 @@ async def batch_upsert_tools(
|
||||
items: List[ToolDiscoveryQueueItem],
|
||||
) -> None:
|
||||
"""
|
||||
Batch-upsert tool registry rows via raw SQL.
|
||||
Batch-upsert tool registry rows via Prisma.
|
||||
|
||||
On first insert: sets call_policy = "untrusted" (schema default), call_count = 1.
|
||||
On conflict: increments call_count; preserves existing call_policy.
|
||||
On first insert: sets input_policy/output_policy = "untrusted" (default), call_count = 1.
|
||||
On conflict: increments call_count; preserves existing policies.
|
||||
"""
|
||||
if not items:
|
||||
return
|
||||
@ -55,6 +83,8 @@ async def batch_upsert_tools(
|
||||
data = [item for item in items if item.get("tool_name")]
|
||||
if not data:
|
||||
return
|
||||
now = datetime.now(timezone.utc)
|
||||
table = prisma_client.db.litellm_tooltable
|
||||
for item in data:
|
||||
tool_name = item.get("tool_name", "")
|
||||
origin = item.get("origin") or "user_defined"
|
||||
@ -62,49 +92,52 @@ async def batch_upsert_tools(
|
||||
key_hash = item.get("key_hash")
|
||||
team_id = item.get("team_id")
|
||||
key_alias = item.get("key_alias")
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
await prisma_client.db.execute_raw(
|
||||
'INSERT INTO "LiteLLM_ToolTable" '
|
||||
"(tool_id, tool_name, origin, call_policy, call_count, created_by, updated_by, key_hash, team_id, key_alias, created_at, updated_at) "
|
||||
"VALUES ($7, $1, $2, 'untrusted', 1, $3, $3, $4, $5, $6, $8, $8) "
|
||||
"ON CONFLICT (tool_name) DO UPDATE SET "
|
||||
"call_count = \"LiteLLM_ToolTable\".call_count + 1, "
|
||||
"updated_at = $8",
|
||||
tool_name,
|
||||
origin,
|
||||
created_by,
|
||||
key_hash,
|
||||
team_id,
|
||||
key_alias,
|
||||
str(uuid.uuid4()),
|
||||
now,
|
||||
user_agent = item.get("user_agent")
|
||||
await table.upsert(
|
||||
where={"tool_name": tool_name},
|
||||
data={
|
||||
"create": {
|
||||
"tool_id": str(uuid.uuid4()),
|
||||
"tool_name": tool_name,
|
||||
"origin": origin,
|
||||
"input_policy": "untrusted",
|
||||
"output_policy": "untrusted",
|
||||
"call_count": 1,
|
||||
"created_by": created_by,
|
||||
"updated_by": created_by,
|
||||
"key_hash": key_hash,
|
||||
"team_id": team_id,
|
||||
"key_alias": key_alias,
|
||||
"user_agent": user_agent,
|
||||
"last_used_at": now,
|
||||
},
|
||||
"update": {
|
||||
"call_count": {"increment": 1},
|
||||
"updated_at": now,
|
||||
"last_used_at": now,
|
||||
},
|
||||
},
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"tool_registry_writer: upserted %d tool(s)", len(data)
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error("tool_registry_writer batch_upsert_tools error: %s", e)
|
||||
verbose_proxy_logger.error(
|
||||
"tool_registry_writer batch_upsert_tools error: %s", e
|
||||
)
|
||||
|
||||
|
||||
async def list_tools(
|
||||
prisma_client: "PrismaClient",
|
||||
call_policy: Optional[ToolCallPolicy] = None,
|
||||
input_policy: Optional[str] = None,
|
||||
) -> List[LiteLLM_ToolTableRow]:
|
||||
"""Return all tools, optionally filtered by call_policy."""
|
||||
"""Return all tools, optionally filtered by input_policy."""
|
||||
try:
|
||||
if call_policy is not None:
|
||||
rows = await prisma_client.db.query_raw(
|
||||
'SELECT tool_id, tool_name, origin, call_policy, call_count, assignments, '
|
||||
'key_hash, team_id, key_alias, created_at, updated_at, created_by, updated_by '
|
||||
'FROM "LiteLLM_ToolTable" WHERE call_policy = $1 ORDER BY created_at DESC',
|
||||
call_policy,
|
||||
)
|
||||
else:
|
||||
rows = await prisma_client.db.query_raw(
|
||||
'SELECT tool_id, tool_name, origin, call_policy, call_count, assignments, '
|
||||
'key_hash, team_id, key_alias, created_at, updated_at, created_by, updated_by '
|
||||
'FROM "LiteLLM_ToolTable" ORDER BY created_at DESC',
|
||||
)
|
||||
where = {"input_policy": input_policy} if input_policy is not None else {}
|
||||
rows = await prisma_client.db.litellm_tooltable.find_many(
|
||||
where=where,
|
||||
order={"created_at": "desc"},
|
||||
)
|
||||
return [_row_to_model(row) for row in rows]
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error("tool_registry_writer list_tools error: %s", e)
|
||||
@ -117,15 +150,12 @@ async def get_tool(
|
||||
) -> Optional[LiteLLM_ToolTableRow]:
|
||||
"""Return a single tool row by tool_name."""
|
||||
try:
|
||||
rows = await prisma_client.db.query_raw(
|
||||
'SELECT tool_id, tool_name, origin, call_policy, call_count, assignments, '
|
||||
'key_hash, team_id, key_alias, created_at, updated_at, created_by, updated_by '
|
||||
'FROM "LiteLLM_ToolTable" WHERE tool_name = $1',
|
||||
tool_name,
|
||||
row = await prisma_client.db.litellm_tooltable.find_unique(
|
||||
where={"tool_name": tool_name},
|
||||
)
|
||||
if not rows:
|
||||
if row is None:
|
||||
return None
|
||||
return _row_to_model(rows[0])
|
||||
return _row_to_model(row)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error("tool_registry_writer get_tool error: %s", e)
|
||||
return None
|
||||
@ -134,46 +164,279 @@ async def get_tool(
|
||||
async def update_tool_policy(
|
||||
prisma_client: "PrismaClient",
|
||||
tool_name: str,
|
||||
call_policy: ToolCallPolicy,
|
||||
updated_by: Optional[str],
|
||||
input_policy: Optional[str] = None,
|
||||
output_policy: Optional[str] = None,
|
||||
) -> Optional[LiteLLM_ToolTableRow]:
|
||||
"""Update the call_policy for a tool. Upserts the row if it does not exist yet."""
|
||||
"""Update input_policy and/or output_policy for a tool. Upserts the row if it does not exist yet."""
|
||||
try:
|
||||
_updated_by = updated_by or "system"
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
await prisma_client.db.execute_raw(
|
||||
'INSERT INTO "LiteLLM_ToolTable" (tool_id, tool_name, call_policy, created_by, updated_by, created_at, updated_at) '
|
||||
"VALUES ($4, $1, $2, $3, $3, $5, $5) "
|
||||
"ON CONFLICT (tool_name) DO UPDATE SET call_policy = $2, updated_by = $3, updated_at = $5",
|
||||
tool_name,
|
||||
call_policy,
|
||||
_updated_by,
|
||||
str(uuid.uuid4()),
|
||||
now,
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
create_data: dict = {
|
||||
"tool_id": str(uuid.uuid4()),
|
||||
"tool_name": tool_name,
|
||||
"input_policy": input_policy or "untrusted",
|
||||
"output_policy": output_policy or "untrusted",
|
||||
"created_by": _updated_by,
|
||||
"updated_by": _updated_by,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
update_data: dict = {
|
||||
"updated_by": _updated_by,
|
||||
"updated_at": now,
|
||||
}
|
||||
if input_policy is not None:
|
||||
update_data["input_policy"] = input_policy
|
||||
if output_policy is not None:
|
||||
update_data["output_policy"] = output_policy
|
||||
|
||||
await prisma_client.db.litellm_tooltable.upsert(
|
||||
where={"tool_name": tool_name},
|
||||
data={
|
||||
"create": create_data,
|
||||
"update": update_data,
|
||||
},
|
||||
)
|
||||
return await get_tool(prisma_client, tool_name)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error("tool_registry_writer update_tool_policy error: %s", e)
|
||||
verbose_proxy_logger.error(
|
||||
"tool_registry_writer update_tool_policy error: %s", e
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
async def get_tools_by_names(
|
||||
prisma_client: "PrismaClient",
|
||||
tool_names: List[str],
|
||||
) -> Dict[str, str]:
|
||||
) -> Dict[str, Tuple[str, str]]:
|
||||
"""
|
||||
Return a {tool_name: call_policy} map for the given tool names.
|
||||
Used by the policy enforcement guardrail — single batch query, never N+1.
|
||||
Return a {tool_name: (input_policy, output_policy)} map for the given tool names.
|
||||
"""
|
||||
if not tool_names:
|
||||
return {}
|
||||
try:
|
||||
placeholders = ", ".join(f"${i+1}" for i in range(len(tool_names)))
|
||||
rows = await prisma_client.db.query_raw(
|
||||
f'SELECT tool_name, call_policy FROM "LiteLLM_ToolTable" WHERE tool_name IN ({placeholders})',
|
||||
*tool_names,
|
||||
rows = await prisma_client.db.litellm_tooltable.find_many(
|
||||
where={"tool_name": {"in": tool_names}},
|
||||
)
|
||||
return {row["tool_name"]: row["call_policy"] for row in rows}
|
||||
return {
|
||||
row.tool_name: (
|
||||
getattr(row, "input_policy", "untrusted") or "untrusted",
|
||||
getattr(row, "output_policy", "untrusted") or "untrusted",
|
||||
)
|
||||
for row in rows
|
||||
}
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error("tool_registry_writer get_tools_by_names error: %s", e)
|
||||
verbose_proxy_logger.error(
|
||||
"tool_registry_writer get_tools_by_names error: %s", e
|
||||
)
|
||||
return {}
|
||||
|
||||
|
||||
async def list_overrides_for_tool(
|
||||
prisma_client: "PrismaClient",
|
||||
tool_name: str,
|
||||
) -> List[ToolPolicyOverrideRow]:
|
||||
"""
|
||||
Return override-like rows for a tool by finding object permissions that have
|
||||
this tool in blocked_tools, then resolving each permission to key/team scope for display.
|
||||
"""
|
||||
out: List[ToolPolicyOverrideRow] = []
|
||||
try:
|
||||
perms = await prisma_client.db.litellm_objectpermissiontable.find_many(
|
||||
where={"blocked_tools": {"has": tool_name}},
|
||||
include={
|
||||
"verification_tokens": True,
|
||||
"teams": True,
|
||||
},
|
||||
)
|
||||
for perm in perms:
|
||||
op_id = getattr(perm, "object_permission_id", None) or ""
|
||||
tokens = getattr(perm, "verification_tokens", []) or []
|
||||
teams = getattr(perm, "teams", []) or []
|
||||
for t in tokens:
|
||||
out.append(
|
||||
ToolPolicyOverrideRow(
|
||||
override_id=op_id,
|
||||
tool_name=tool_name,
|
||||
team_id=None,
|
||||
key_hash=getattr(t, "token", None),
|
||||
input_policy="blocked",
|
||||
key_alias=getattr(t, "key_alias", None),
|
||||
created_at=None,
|
||||
updated_at=None,
|
||||
)
|
||||
)
|
||||
for team in teams:
|
||||
out.append(
|
||||
ToolPolicyOverrideRow(
|
||||
override_id=op_id,
|
||||
tool_name=tool_name,
|
||||
team_id=getattr(team, "team_id", None),
|
||||
key_hash=None,
|
||||
input_policy="blocked",
|
||||
key_alias=getattr(team, "team_alias", None),
|
||||
created_at=None,
|
||||
updated_at=None,
|
||||
)
|
||||
)
|
||||
return out
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
"tool_registry_writer list_overrides_for_tool error: %s", e
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
class ToolPolicyRegistry:
|
||||
"""
|
||||
In-memory registry of tool policies synced from DB.
|
||||
Hot path uses get_effective_policies only — no DB, no cache.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._tool_input_policies: Dict[str, str] = {}
|
||||
self._tool_output_policies: Dict[str, str] = {}
|
||||
self._blocked_tools_by_op_id: Dict[str, List[str]] = {}
|
||||
self._initialized: bool = False
|
||||
|
||||
def is_initialized(self) -> bool:
|
||||
return self._initialized
|
||||
|
||||
async def sync_tool_policy_from_db(self, prisma_client: "PrismaClient") -> None:
|
||||
"""Load all tool policies and object-permission blocked_tools from DB."""
|
||||
try:
|
||||
tools = await prisma_client.db.litellm_tooltable.find_many()
|
||||
self._tool_input_policies = {
|
||||
row.tool_name: getattr(row, "input_policy", "untrusted") or "untrusted"
|
||||
for row in tools
|
||||
}
|
||||
self._tool_output_policies = {
|
||||
row.tool_name: getattr(row, "output_policy", "untrusted") or "untrusted"
|
||||
for row in tools
|
||||
}
|
||||
|
||||
perms = await prisma_client.db.litellm_objectpermissiontable.find_many()
|
||||
self._blocked_tools_by_op_id = {}
|
||||
for row in perms:
|
||||
op_id = getattr(row, "object_permission_id", None)
|
||||
blocked = getattr(row, "blocked_tools", None) or []
|
||||
if op_id:
|
||||
self._blocked_tools_by_op_id[op_id] = list(blocked)
|
||||
|
||||
self._initialized = True
|
||||
verbose_proxy_logger.info(
|
||||
"ToolPolicyRegistry: synced %d tool policies and %d object permissions from DB",
|
||||
len(self._tool_input_policies),
|
||||
len(self._blocked_tools_by_op_id),
|
||||
)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"ToolPolicyRegistry sync_tool_policy_from_db error: %s", e
|
||||
)
|
||||
raise
|
||||
|
||||
def get_input_policy(self, tool_name: str) -> str:
|
||||
return self._tool_input_policies.get(tool_name, "untrusted")
|
||||
|
||||
def get_output_policy(self, tool_name: str) -> str:
|
||||
return self._tool_output_policies.get(tool_name, "untrusted")
|
||||
|
||||
def get_effective_policies(
|
||||
self,
|
||||
tool_names: List[str],
|
||||
object_permission_id: Optional[str] = None,
|
||||
team_object_permission_id: Optional[str] = None,
|
||||
) -> Dict[str, str]:
|
||||
"""
|
||||
Return effective input_policy per tool from in-memory state.
|
||||
If tool is in key or team blocked_tools -> "blocked", else global input_policy or "untrusted".
|
||||
"""
|
||||
if not tool_names:
|
||||
return {}
|
||||
blocked: set = set()
|
||||
for op_id in (object_permission_id, team_object_permission_id):
|
||||
if op_id and op_id.strip():
|
||||
blocked.update(
|
||||
self._blocked_tools_by_op_id.get(op_id.strip(), [])
|
||||
)
|
||||
result: Dict[str, str] = {}
|
||||
for name in tool_names:
|
||||
if name in blocked:
|
||||
result[name] = "blocked"
|
||||
else:
|
||||
result[name] = self._tool_input_policies.get(name, "untrusted")
|
||||
return result
|
||||
|
||||
|
||||
_tool_policy_registry: Optional[ToolPolicyRegistry] = None
|
||||
|
||||
|
||||
def get_tool_policy_registry() -> ToolPolicyRegistry:
|
||||
"""Return the global ToolPolicyRegistry singleton."""
|
||||
global _tool_policy_registry
|
||||
if _tool_policy_registry is None:
|
||||
_tool_policy_registry = ToolPolicyRegistry()
|
||||
return _tool_policy_registry
|
||||
|
||||
|
||||
async def add_tool_to_object_permission_blocked(
|
||||
prisma_client: "PrismaClient",
|
||||
object_permission_id: str,
|
||||
tool_name: str,
|
||||
) -> bool:
|
||||
"""Add tool_name to the permission's blocked_tools if not already present."""
|
||||
if not object_permission_id or not tool_name:
|
||||
return False
|
||||
try:
|
||||
row = await prisma_client.db.litellm_objectpermissiontable.find_unique(
|
||||
where={"object_permission_id": object_permission_id},
|
||||
)
|
||||
if row is None:
|
||||
return False
|
||||
current = list(getattr(row, "blocked_tools", []) or [])
|
||||
if tool_name in current:
|
||||
return True
|
||||
current.append(tool_name)
|
||||
await prisma_client.db.litellm_objectpermissiontable.update(
|
||||
where={"object_permission_id": object_permission_id},
|
||||
data={"blocked_tools": current},
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
"tool_registry_writer add_tool_to_object_permission_blocked error: %s", e
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def remove_tool_from_object_permission_blocked(
|
||||
prisma_client: "PrismaClient",
|
||||
object_permission_id: str,
|
||||
tool_name: str,
|
||||
) -> bool:
|
||||
"""Remove tool_name from the permission's blocked_tools. Returns False if tool was not in list."""
|
||||
if not object_permission_id or not tool_name:
|
||||
return False
|
||||
try:
|
||||
row = await prisma_client.db.litellm_objectpermissiontable.find_unique(
|
||||
where={"object_permission_id": object_permission_id},
|
||||
)
|
||||
if row is None:
|
||||
return False
|
||||
current = list(getattr(row, "blocked_tools", []) or [])
|
||||
if tool_name not in current:
|
||||
return False
|
||||
current = [t for t in current if t != tool_name]
|
||||
await prisma_client.db.litellm_objectpermissiontable.update(
|
||||
where={"object_permission_id": object_permission_id},
|
||||
data={"blocked_tools": current},
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
"tool_registry_writer remove_tool_from_object_permission_blocked error: %s",
|
||||
e,
|
||||
)
|
||||
return False
|
||||
|
||||
@ -1,13 +1,16 @@
|
||||
"""
|
||||
Tool Policy Guardrail
|
||||
|
||||
Reads call_policy from LiteLLM_ToolTable and enforces it on LLM requests/responses.
|
||||
Reads input_policy / output_policy from LiteLLM_ToolTable and enforces them.
|
||||
|
||||
Policy values:
|
||||
"trusted" - allow through (no action)
|
||||
"untrusted" - allow through (no action; default for newly discovered tools)
|
||||
Input policy values:
|
||||
"untrusted" - allow through (default for newly discovered tools)
|
||||
"trusted" - only allow if conversation contains no untrusted tool output
|
||||
"blocked" - raise HTTPException, preventing the tool call
|
||||
"dual_llm" - (Phase 3) send to second LLM for verification; currently treated as allowed
|
||||
|
||||
Output policy values:
|
||||
"untrusted" - output may be tainted (default)
|
||||
"trusted" - output is verified safe
|
||||
|
||||
Configuration in proxy config YAML:
|
||||
guardrails:
|
||||
@ -15,25 +18,18 @@ Configuration in proxy config YAML:
|
||||
litellm_params:
|
||||
guardrail: tool_policy
|
||||
mode: post_call
|
||||
|
||||
or both pre and post call:
|
||||
- guardrail_name: "tool_policy"
|
||||
litellm_params:
|
||||
guardrail: tool_policy
|
||||
mode: during_call # runs before LLM and on response
|
||||
"""
|
||||
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching.dual_cache import DualCache
|
||||
from litellm.constants import TOOL_POLICY_CACHE_TTL_SECONDS
|
||||
from litellm.integrations.custom_guardrail import (
|
||||
CustomGuardrail,
|
||||
log_guardrail_information,
|
||||
)
|
||||
from litellm.proxy.guardrails.tool_name_extraction import extract_request_tool_names
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
from litellm.types.utils import GenericGuardrailAPIInputs
|
||||
|
||||
@ -43,12 +39,71 @@ if TYPE_CHECKING:
|
||||
GUARDRAIL_NAME = "tool_policy"
|
||||
|
||||
|
||||
def _get_request_object_permission_ids(
|
||||
request_data: dict,
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Extract object_permission_id and team_object_permission_id from request_data."""
|
||||
if not request_data:
|
||||
return None, None
|
||||
for key in ("litellm_metadata", "metadata"):
|
||||
meta = request_data.get(key)
|
||||
if not isinstance(meta, dict):
|
||||
continue
|
||||
auth = meta.get("user_api_key_auth")
|
||||
if auth is not None and hasattr(auth, "object_permission_id"):
|
||||
key_op = getattr(auth, "object_permission_id", None)
|
||||
team_op = getattr(auth, "team_object_permission_id", None)
|
||||
if key_op is not None or team_op is not None:
|
||||
return (
|
||||
str(key_op).strip() if key_op else None,
|
||||
str(team_op).strip() if team_op else None,
|
||||
)
|
||||
key_op = meta.get("user_api_key_object_permission_id")
|
||||
team_op = meta.get("user_api_key_team_object_permission_id")
|
||||
if key_op is not None or team_op is not None:
|
||||
return (
|
||||
str(key_op).strip() if key_op else None,
|
||||
str(team_op).strip() if team_op else None,
|
||||
)
|
||||
return None, None
|
||||
|
||||
|
||||
def _get_request_route_from_data(request_data: dict) -> Optional[str]:
|
||||
"""Get request route from request_data (metadata or top-level)."""
|
||||
route = request_data.get("user_api_key_request_route")
|
||||
if route:
|
||||
return route
|
||||
meta = request_data.get("metadata") or request_data.get("litellm_metadata") or {}
|
||||
return meta.get("user_api_key_request_route")
|
||||
|
||||
|
||||
def _resolve_tool_names_from_messages(messages: List[dict]) -> Dict[str, str]:
|
||||
"""
|
||||
Build a map of tool_call_id -> tool_name from assistant messages' tool_calls.
|
||||
Used to resolve which tool produced each tool result in the conversation.
|
||||
"""
|
||||
mapping: Dict[str, str] = {}
|
||||
for msg in messages:
|
||||
if msg.get("role") != "assistant":
|
||||
continue
|
||||
tool_calls = msg.get("tool_calls") or []
|
||||
for tc in tool_calls:
|
||||
if isinstance(tc, dict):
|
||||
tc_id = tc.get("id")
|
||||
fn = (tc.get("function") or {}).get("name")
|
||||
else:
|
||||
tc_id = getattr(tc, "id", None)
|
||||
fn_obj = getattr(tc, "function", None)
|
||||
fn = getattr(fn_obj, "name", None) if fn_obj else None
|
||||
if tc_id and fn:
|
||||
mapping[tc_id] = fn
|
||||
return mapping
|
||||
|
||||
|
||||
class ToolPolicyGuardrail(CustomGuardrail):
|
||||
"""
|
||||
Guardrail that enforces per-tool call policies stored in LiteLLM_ToolTable.
|
||||
|
||||
Tools with call_policy="blocked" are rejected before/after the LLM call.
|
||||
Tools with call_policy="trusted" or "untrusted" pass through unchanged.
|
||||
Guardrail that enforces per-tool input/output policies from the in-memory
|
||||
ToolPolicyRegistry (synced from DB).
|
||||
"""
|
||||
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
@ -59,7 +114,6 @@ class ToolPolicyGuardrail(CustomGuardrail):
|
||||
GuardrailEventHooks.during_call,
|
||||
]
|
||||
super().__init__(**kwargs)
|
||||
self._policy_cache: DualCache = DualCache()
|
||||
|
||||
@log_guardrail_information
|
||||
async def apply_guardrail(
|
||||
@ -70,12 +124,7 @@ class ToolPolicyGuardrail(CustomGuardrail):
|
||||
logging_obj: Optional["LiteLLMLoggingObj"] = None,
|
||||
) -> GenericGuardrailAPIInputs:
|
||||
"""
|
||||
Enforce tool policies on both request tools and response tool_calls.
|
||||
|
||||
- input_type="request": check inputs["tools"] (tool definitions in the LLM request)
|
||||
- input_type="response": check inputs["tool_calls"] (tool_calls in the LLM response)
|
||||
|
||||
Raises HTTPException (400) if any tool is "blocked".
|
||||
Enforce input_policy and output_policy trust chain on request tools / response tool_calls.
|
||||
"""
|
||||
if input_type == "request":
|
||||
tools = inputs.get("tools") or []
|
||||
@ -86,7 +135,11 @@ class ToolPolicyGuardrail(CustomGuardrail):
|
||||
and isinstance(t.get("function"), dict)
|
||||
and t["function"].get("name")
|
||||
]
|
||||
else: # response
|
||||
if not tool_names:
|
||||
route = _get_request_route_from_data(request_data)
|
||||
if route:
|
||||
tool_names = extract_request_tool_names(route, request_data)
|
||||
else:
|
||||
tool_calls = inputs.get("tool_calls") or []
|
||||
tool_names = []
|
||||
for tc in tool_calls:
|
||||
@ -101,12 +154,25 @@ class ToolPolicyGuardrail(CustomGuardrail):
|
||||
if not tool_names:
|
||||
return inputs
|
||||
|
||||
policy_map = await self._get_policies_cached(tool_names)
|
||||
object_permission_id, team_object_permission_id = (
|
||||
_get_request_object_permission_ids(request_data)
|
||||
)
|
||||
from litellm.proxy.db.tool_registry_writer import get_tool_policy_registry
|
||||
|
||||
registry = get_tool_policy_registry()
|
||||
if not registry.is_initialized():
|
||||
return inputs
|
||||
|
||||
# Stage 1: Check for blocked tools (input_policy=blocked or per-key/team override)
|
||||
policy_map = registry.get_effective_policies(
|
||||
tool_names,
|
||||
object_permission_id=object_permission_id,
|
||||
team_object_permission_id=team_object_permission_id,
|
||||
)
|
||||
blocked = [name for name in tool_names if policy_map.get(name) == "blocked"]
|
||||
if blocked:
|
||||
verbose_proxy_logger.warning(
|
||||
"ToolPolicyGuardrail: blocking tool(s) %s (policy=blocked)", blocked
|
||||
"ToolPolicyGuardrail: blocking tool(s) %s (input_policy=blocked)", blocked
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@ -117,47 +183,47 @@ class ToolPolicyGuardrail(CustomGuardrail):
|
||||
},
|
||||
)
|
||||
|
||||
# Stage 2: Trust chain enforcement (response path only)
|
||||
# For each tool with input_policy=trusted, check if conversation
|
||||
# contains output from tools with output_policy=untrusted
|
||||
if input_type == "response":
|
||||
trusted_input_tools = [
|
||||
name for name in tool_names if policy_map.get(name) == "trusted"
|
||||
]
|
||||
if trusted_input_tools:
|
||||
messages = request_data.get("messages") or []
|
||||
tc_id_to_name = _resolve_tool_names_from_messages(messages)
|
||||
|
||||
untrusted_sources: List[str] = []
|
||||
for msg in messages:
|
||||
if msg.get("role") != "tool":
|
||||
continue
|
||||
tool_call_id = msg.get("tool_call_id")
|
||||
source_tool = tc_id_to_name.get(tool_call_id, "") if tool_call_id else ""
|
||||
if not source_tool:
|
||||
continue
|
||||
if registry.get_output_policy(source_tool) == "untrusted":
|
||||
if source_tool not in untrusted_sources:
|
||||
untrusted_sources.append(source_tool)
|
||||
|
||||
if untrusted_sources:
|
||||
verbose_proxy_logger.warning(
|
||||
"ToolPolicyGuardrail: trust chain violation — %s require trusted input "
|
||||
"but conversation has untrusted output from %s",
|
||||
trusted_input_tools,
|
||||
untrusted_sources,
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Violated tool policy",
|
||||
"blocked_tools": trusted_input_tools,
|
||||
"untrusted_sources": untrusted_sources,
|
||||
"message": (
|
||||
f"{', '.join(trusted_input_tools)} requires trusted input but "
|
||||
f"conversation contains untrusted output from {', '.join(untrusted_sources)}."
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
return inputs
|
||||
|
||||
async def _get_policies_cached(self, tool_names: List[str]) -> Dict[str, str]:
|
||||
"""
|
||||
Batch-fetch call_policy for the given tool names.
|
||||
|
||||
Caches per individual tool name (not per combination) so that adding
|
||||
a new tool to a request doesn't invalidate the cached policies for all
|
||||
the other tools already in the cache.
|
||||
"""
|
||||
from litellm.proxy.db.tool_registry_writer import get_tools_by_names
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if not tool_names or prisma_client is None:
|
||||
return {}
|
||||
|
||||
result: Dict[str, str] = {}
|
||||
cache_misses: List[str] = []
|
||||
|
||||
for name in tool_names:
|
||||
cached = await self._policy_cache.async_get_cache(f"tool_policy:{name}")
|
||||
if cached is not None and isinstance(cached, str):
|
||||
result[name] = cached
|
||||
else:
|
||||
cache_misses.append(name)
|
||||
|
||||
if cache_misses:
|
||||
fetched = await get_tools_by_names(
|
||||
prisma_client=prisma_client, tool_names=cache_misses
|
||||
)
|
||||
for name, policy in fetched.items():
|
||||
result[name] = policy
|
||||
await self._policy_cache.async_set_cache(
|
||||
key=f"tool_policy:{name}",
|
||||
value=policy,
|
||||
ttl=TOOL_POLICY_CACHE_TTL_SECONDS,
|
||||
)
|
||||
verbose_proxy_logger.debug(
|
||||
"ToolPolicyGuardrail: fetched %d policies from DB (cache hits: %d)",
|
||||
len(cache_misses),
|
||||
len(tool_names) - len(cache_misses),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
85
litellm/proxy/guardrails/tool_name_extraction.py
Normal file
85
litellm/proxy/guardrails/tool_name_extraction.py
Normal file
@ -0,0 +1,85 @@
|
||||
"""
|
||||
Extract tool names from request body by route/call type.
|
||||
|
||||
Used by auth (check_tools_allowlist) and ToolPolicyGuardrail so tool-format
|
||||
knowledge lives in one place. Uses guardrail translation handlers where available,
|
||||
with standalone extractors for generate_content and MCP.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from litellm.litellm_core_utils.api_route_to_call_types import get_call_types_for_route
|
||||
from litellm.llms import load_guardrail_translation_mappings
|
||||
from litellm.types.utils import CallTypes
|
||||
|
||||
# Call types that have no guardrail translation handler; we use standalone extractors
|
||||
STANDALONE_EXTRACTORS: Dict[str, Any] = {}
|
||||
|
||||
|
||||
def _extract_generate_content_tool_names(data: dict) -> List[str]:
|
||||
"""Google generateContent: tools[].functionDeclarations[].name"""
|
||||
names: List[str] = []
|
||||
for tool in data.get("tools") or []:
|
||||
if not isinstance(tool, dict):
|
||||
continue
|
||||
for decl in tool.get("functionDeclarations") or []:
|
||||
if isinstance(decl, dict) and decl.get("name"):
|
||||
names.append(str(decl["name"]))
|
||||
return names
|
||||
|
||||
|
||||
def _extract_mcp_tool_names(data: dict) -> List[str]:
|
||||
"""MCP call_tool: name or mcp_tool_name in body"""
|
||||
names: List[str] = []
|
||||
name = data.get("name") or data.get("mcp_tool_name")
|
||||
if name:
|
||||
names.append(str(name))
|
||||
return names
|
||||
|
||||
|
||||
def _register_standalone_extractors() -> None:
|
||||
if STANDALONE_EXTRACTORS:
|
||||
return
|
||||
STANDALONE_EXTRACTORS[CallTypes.generate_content.value] = _extract_generate_content_tool_names
|
||||
STANDALONE_EXTRACTORS[CallTypes.agenerate_content.value] = _extract_generate_content_tool_names
|
||||
STANDALONE_EXTRACTORS[CallTypes.call_mcp_tool.value] = _extract_mcp_tool_names
|
||||
|
||||
|
||||
# Tool-capable call types (routes that can send tools in the request)
|
||||
TOOL_CAPABLE_CALL_TYPES = frozenset({
|
||||
CallTypes.completion.value,
|
||||
CallTypes.acompletion.value,
|
||||
CallTypes.responses.value,
|
||||
CallTypes.aresponses.value,
|
||||
CallTypes.anthropic_messages.value,
|
||||
CallTypes.generate_content.value,
|
||||
CallTypes.agenerate_content.value,
|
||||
CallTypes.call_mcp_tool.value,
|
||||
})
|
||||
|
||||
|
||||
def extract_request_tool_names(route: str, data: dict) -> List[str]:
|
||||
"""
|
||||
Extract tool names from the request body for the given route.
|
||||
Uses guardrail translation handlers when available, else standalone extractors
|
||||
for generate_content and MCP. Returns [] for non-tool-capable routes or when
|
||||
no tools are present.
|
||||
"""
|
||||
call_types = get_call_types_for_route(route)
|
||||
if not call_types:
|
||||
return []
|
||||
_register_standalone_extractors()
|
||||
mappings = load_guardrail_translation_mappings()
|
||||
for call_type in call_types:
|
||||
if not isinstance(call_type, CallTypes):
|
||||
continue
|
||||
if call_type.value not in TOOL_CAPABLE_CALL_TYPES:
|
||||
continue
|
||||
if call_type.value in STANDALONE_EXTRACTORS:
|
||||
return STANDALONE_EXTRACTORS[call_type.value](data)
|
||||
handler_cls = mappings.get(call_type)
|
||||
if handler_cls is not None:
|
||||
names = handler_cls().extract_request_tool_names(data)
|
||||
if names:
|
||||
return names
|
||||
return []
|
||||
@ -1091,6 +1091,15 @@ async def add_litellm_data_to_request( # noqa: PLR0915
|
||||
] = user_api_key_dict.user_max_budget
|
||||
|
||||
data[_metadata_variable_name]["user_api_key_metadata"] = user_api_key_dict.metadata
|
||||
data[_metadata_variable_name]["user_api_key_team_metadata"] = (
|
||||
user_api_key_dict.team_metadata
|
||||
)
|
||||
data[_metadata_variable_name]["user_api_key_object_permission_id"] = (
|
||||
getattr(user_api_key_dict, "object_permission_id", None)
|
||||
)
|
||||
data[_metadata_variable_name]["user_api_key_team_object_permission_id"] = (
|
||||
getattr(user_api_key_dict, "team_object_permission_id", None)
|
||||
)
|
||||
data[_metadata_variable_name]["headers"] = _headers
|
||||
data[_metadata_variable_name]["endpoint"] = str(request.url)
|
||||
|
||||
|
||||
@ -4,27 +4,87 @@ TOOL POLICY MANAGEMENT
|
||||
All /tool management endpoints
|
||||
|
||||
GET /v1/tool/list - List all discovered tools and their policies
|
||||
GET /v1/tool/policy/options - List available input/output policy options with descriptions
|
||||
GET /v1/tool/{tool_name} - Get a single tool's details
|
||||
POST /v1/tool/policy - Update the call_policy for a tool
|
||||
POST /v1/tool/policy - Update the input_policy / output_policy for a tool
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import CommonProxyErrors, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.types.tool_management import (
|
||||
LiteLLM_ToolTableRow,
|
||||
ToolCallPolicy,
|
||||
ToolDetailResponse,
|
||||
ToolInputPolicy,
|
||||
ToolListResponse,
|
||||
ToolOutputPolicy,
|
||||
ToolPolicyOption,
|
||||
ToolPolicyOptionsResponse,
|
||||
ToolPolicyUpdateRequest,
|
||||
ToolPolicyUpdateResponse,
|
||||
ToolUsageLogEntry,
|
||||
ToolUsageLogsResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
TOOL_POLICY_OPTIONS = ToolPolicyOptionsResponse(
|
||||
input_policies=[
|
||||
ToolPolicyOption(
|
||||
value="untrusted",
|
||||
label="Untrusted",
|
||||
description="Tool accepts any input, including data from untrusted tool outputs. Default for newly discovered tools.",
|
||||
),
|
||||
ToolPolicyOption(
|
||||
value="trusted",
|
||||
label="Trusted",
|
||||
description="Tool requires trusted input. Blocked if the conversation contains output from any tool with output_policy=untrusted.",
|
||||
),
|
||||
ToolPolicyOption(
|
||||
value="blocked",
|
||||
label="Blocked",
|
||||
description="Tool is completely prohibited. Any attempt to call it is rejected.",
|
||||
),
|
||||
],
|
||||
output_policies=[
|
||||
ToolPolicyOption(
|
||||
value="untrusted",
|
||||
label="Untrusted",
|
||||
description="Tool output may contain unsafe content (prompt injection, risky code). Downstream tools with input_policy=trusted will be blocked.",
|
||||
),
|
||||
ToolPolicyOption(
|
||||
value="trusted",
|
||||
label="Trusted",
|
||||
description="Tool output is verified safe. Will not trigger trust-chain blocks on downstream tools.",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/v1/tool/policy/options",
|
||||
tags=["tool management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=ToolPolicyOptionsResponse,
|
||||
)
|
||||
async def get_tool_policy_options(
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Return the available input and output policy options with descriptions.
|
||||
Static data — no DB call.
|
||||
"""
|
||||
return TOOL_POLICY_OPTIONS
|
||||
|
||||
|
||||
@router.get(
|
||||
"/v1/tool/list",
|
||||
@ -33,14 +93,14 @@ router = APIRouter()
|
||||
response_model=ToolListResponse,
|
||||
)
|
||||
async def list_tools(
|
||||
call_policy: Optional[ToolCallPolicy] = None,
|
||||
input_policy: Optional[ToolInputPolicy] = None,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
List all auto-discovered tools and their call policies.
|
||||
List all auto-discovered tools and their policies.
|
||||
|
||||
Parameters:
|
||||
- call_policy: Optional filter — one of "trusted", "untrusted", "dual_llm", "blocked"
|
||||
- input_policy: Optional filter — one of "trusted", "untrusted", "blocked"
|
||||
"""
|
||||
from litellm.proxy.db.tool_registry_writer import list_tools as db_list_tools
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
@ -51,13 +111,201 @@ async def list_tools(
|
||||
)
|
||||
|
||||
try:
|
||||
tools = await db_list_tools(prisma_client=prisma_client, call_policy=call_policy)
|
||||
tools = await db_list_tools(
|
||||
prisma_client=prisma_client, input_policy=input_policy
|
||||
)
|
||||
return ToolListResponse(tools=tools, total=len(tools))
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception("Error listing tools: %s", e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/v1/tool/{tool_name:path}/detail",
|
||||
tags=["tool management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=ToolDetailResponse,
|
||||
)
|
||||
async def get_tool_detail(
|
||||
tool_name: str,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Get a single tool with its policy overrides (for UI detail view).
|
||||
"""
|
||||
from litellm.proxy.db.tool_registry_writer import get_tool as db_get_tool
|
||||
from litellm.proxy.db.tool_registry_writer import list_overrides_for_tool
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
|
||||
)
|
||||
|
||||
try:
|
||||
tool = await db_get_tool(prisma_client=prisma_client, tool_name=tool_name)
|
||||
if tool is None:
|
||||
raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
|
||||
overrides = await list_overrides_for_tool(
|
||||
prisma_client=prisma_client, tool_name=tool_name
|
||||
)
|
||||
return ToolDetailResponse(tool=tool, overrides=overrides)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception("Error getting tool detail: %s", e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
def _input_snippet_for_tool_log(sl: Any, max_len: int = 200) -> Optional[str]:
|
||||
"""Short snippet from messages or proxy_server_request for tool usage log row."""
|
||||
if sl is None:
|
||||
return None
|
||||
messages = getattr(sl, "messages", None)
|
||||
if messages is not None:
|
||||
s = _snippet_str(messages, max_len)
|
||||
if s:
|
||||
return s
|
||||
psr = getattr(sl, "proxy_server_request", None)
|
||||
if not psr:
|
||||
return None
|
||||
if isinstance(psr, str):
|
||||
import json
|
||||
|
||||
try:
|
||||
psr = json.loads(psr)
|
||||
except Exception:
|
||||
return _snippet_str(psr, max_len)
|
||||
if isinstance(psr, dict):
|
||||
msgs = psr.get("messages")
|
||||
if msgs is None and isinstance(psr.get("body"), dict):
|
||||
msgs = psr["body"].get("messages")
|
||||
s = _snippet_str(msgs, max_len)
|
||||
if s:
|
||||
return s
|
||||
return _snippet_str(psr, max_len)
|
||||
|
||||
|
||||
def _snippet_str(text: Any, max_len: int = 200) -> Optional[str]:
|
||||
if text is None:
|
||||
return None
|
||||
if isinstance(text, str):
|
||||
s = text
|
||||
elif isinstance(text, list):
|
||||
parts = []
|
||||
for item in text:
|
||||
if isinstance(item, dict) and "content" in item:
|
||||
c = item["content"]
|
||||
parts.append(c if isinstance(c, str) else str(c))
|
||||
else:
|
||||
parts.append(str(item))
|
||||
s = " ".join(parts)
|
||||
else:
|
||||
s = str(text)
|
||||
if not s or s == "{}":
|
||||
return None
|
||||
return (s[:max_len] + "...") if len(s) > max_len else s
|
||||
|
||||
|
||||
@router.get(
|
||||
"/v1/tool/{tool_name:path}/logs",
|
||||
tags=["tool management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=ToolUsageLogsResponse,
|
||||
)
|
||||
async def get_tool_usage_logs(
|
||||
tool_name: str,
|
||||
page: int = Query(1, ge=1),
|
||||
page_size: int = Query(50, ge=1, le=100),
|
||||
start_date: Optional[str] = Query(None, description="YYYY-MM-DD"),
|
||||
end_date: Optional[str] = Query(None, description="YYYY-MM-DD"),
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Return paginated spend logs for requests that used this tool (from SpendLogToolIndex).
|
||||
"""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
|
||||
)
|
||||
|
||||
try:
|
||||
where: dict = {"tool_name": tool_name}
|
||||
if start_date or end_date:
|
||||
start_time_filter: Optional[datetime] = None
|
||||
end_time_filter: Optional[datetime] = None
|
||||
if start_date:
|
||||
try:
|
||||
start_time_filter = datetime.strptime(
|
||||
start_date + "T00:00:00", "%Y-%m-%dT%H:%M:%S"
|
||||
).replace(tzinfo=timezone.utc)
|
||||
except ValueError:
|
||||
pass
|
||||
if end_date:
|
||||
try:
|
||||
end_time_filter = datetime.strptime(
|
||||
end_date + "T23:59:59", "%Y-%m-%dT%H:%M:%S"
|
||||
).replace(tzinfo=timezone.utc)
|
||||
except ValueError:
|
||||
pass
|
||||
if start_time_filter is not None or end_time_filter is not None:
|
||||
where["start_time"] = {}
|
||||
if start_time_filter is not None:
|
||||
where["start_time"]["gte"] = start_time_filter
|
||||
if end_time_filter is not None:
|
||||
where["start_time"]["lte"] = end_time_filter
|
||||
|
||||
total = await prisma_client.db.litellm_spendlogtoolindex.count(where=where)
|
||||
index_rows = await prisma_client.db.litellm_spendlogtoolindex.find_many(
|
||||
where=where,
|
||||
order={"start_time": "desc"},
|
||||
skip=(page - 1) * page_size,
|
||||
take=page_size,
|
||||
)
|
||||
request_ids = [r.request_id for r in index_rows]
|
||||
if not request_ids:
|
||||
return ToolUsageLogsResponse(
|
||||
logs=[], total=total, page=page, page_size=page_size
|
||||
)
|
||||
|
||||
spend_logs = await prisma_client.db.litellm_spendlogs.find_many(
|
||||
where={"request_id": {"in": request_ids}}
|
||||
)
|
||||
log_by_id = {s.request_id: s for s in spend_logs}
|
||||
|
||||
logs_out: List[ToolUsageLogEntry] = []
|
||||
for r in index_rows:
|
||||
sl = log_by_id.get(r.request_id)
|
||||
if not sl:
|
||||
continue
|
||||
ts = (
|
||||
sl.startTime.isoformat()
|
||||
if hasattr(sl.startTime, "isoformat")
|
||||
else str(sl.startTime)
|
||||
)
|
||||
logs_out.append(
|
||||
ToolUsageLogEntry(
|
||||
id=sl.request_id,
|
||||
timestamp=ts,
|
||||
model=getattr(sl, "model", None) or None,
|
||||
spend=getattr(sl, "spend", None),
|
||||
total_tokens=getattr(sl, "total_tokens", None),
|
||||
input_snippet=_input_snippet_for_tool_log(sl),
|
||||
)
|
||||
)
|
||||
|
||||
return ToolUsageLogsResponse(
|
||||
logs=logs_out, total=total, page=page, page_size=page_size
|
||||
)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception("Error getting tool usage logs: %s", e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/v1/tool/{tool_name:path}",
|
||||
tags=["tool management"],
|
||||
@ -70,9 +318,6 @@ async def get_tool(
|
||||
):
|
||||
"""
|
||||
Get details for a single tool.
|
||||
|
||||
Parameters:
|
||||
- tool_name: The tool name (supports namespaced names with slashes)
|
||||
"""
|
||||
from litellm.proxy.db.tool_registry_writer import get_tool as db_get_tool
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
@ -85,9 +330,7 @@ async def get_tool(
|
||||
try:
|
||||
tool = await db_get_tool(prisma_client=prisma_client, tool_name=tool_name)
|
||||
if tool is None:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Tool '{tool_name}' not found"
|
||||
)
|
||||
raise HTTPException(status_code=404, detail=f"Tool '{tool_name}' not found")
|
||||
return tool
|
||||
except HTTPException:
|
||||
raise
|
||||
@ -96,6 +339,80 @@ async def get_tool(
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
async def _resolve_key_hash_to_object_permission_id(
|
||||
prisma_client: "PrismaClient",
|
||||
key_hash: str,
|
||||
) -> Optional[str]:
|
||||
"""Resolve key (hash or raw) to object_permission_id; create permission if key has none."""
|
||||
from litellm.proxy.proxy_server import hash_token
|
||||
|
||||
hashed = key_hash if "sk-" not in (key_hash or "") else hash_token(key_hash)
|
||||
if not hashed:
|
||||
return None
|
||||
row = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
where={"token": hashed}
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
op_id = getattr(row, "object_permission_id", None)
|
||||
if op_id:
|
||||
return op_id
|
||||
new_id = str(uuid.uuid4())
|
||||
await prisma_client.db.litellm_objectpermissiontable.create(
|
||||
data={"object_permission_id": new_id, "blocked_tools": []}
|
||||
)
|
||||
updated_count = await prisma_client.db.litellm_verificationtoken.update_many(
|
||||
where={"token": hashed, "object_permission_id": None},
|
||||
data={"object_permission_id": new_id},
|
||||
)
|
||||
if updated_count == 0:
|
||||
await prisma_client.db.litellm_objectpermissiontable.delete(
|
||||
where={"object_permission_id": new_id}
|
||||
)
|
||||
row = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
where={"token": hashed}
|
||||
)
|
||||
return getattr(row, "object_permission_id", None) if row else None
|
||||
return new_id
|
||||
|
||||
|
||||
async def _resolve_team_id_to_object_permission_id(
|
||||
prisma_client: "PrismaClient",
|
||||
team_id: str,
|
||||
) -> Optional[str]:
|
||||
"""Resolve team_id to object_permission_id; create permission if team has none."""
|
||||
if not team_id or not team_id.strip():
|
||||
return None
|
||||
team_id_clean = team_id.strip()
|
||||
row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id_clean},
|
||||
select={"object_permission_id": True},
|
||||
)
|
||||
if row is None:
|
||||
return None
|
||||
op_id = getattr(row, "object_permission_id", None)
|
||||
if op_id:
|
||||
return op_id
|
||||
new_id = str(uuid.uuid4())
|
||||
await prisma_client.db.litellm_objectpermissiontable.create(
|
||||
data={"object_permission_id": new_id, "blocked_tools": []}
|
||||
)
|
||||
updated_count = await prisma_client.db.litellm_teamtable.update_many(
|
||||
where={"team_id": team_id_clean, "object_permission_id": None},
|
||||
data={"object_permission_id": new_id},
|
||||
)
|
||||
if updated_count == 0:
|
||||
await prisma_client.db.litellm_objectpermissiontable.delete(
|
||||
where={"object_permission_id": new_id}
|
||||
)
|
||||
row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||
where={"team_id": team_id_clean},
|
||||
select={"object_permission_id": True},
|
||||
)
|
||||
return getattr(row, "object_permission_id", None) if row else None
|
||||
return new_id
|
||||
|
||||
|
||||
@router.post(
|
||||
"/v1/tool/policy",
|
||||
tags=["tool management"],
|
||||
@ -107,15 +424,20 @@ async def update_tool_policy(
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Set the call policy for a tool.
|
||||
Set the input_policy and/or output_policy for a tool (global), or block for a specific team/key (override).
|
||||
|
||||
Parameters:
|
||||
- tool_name: str - The tool to update
|
||||
- call_policy: "trusted" | "untrusted" | "dual_llm" | "blocked"
|
||||
|
||||
Setting a tool to "blocked" will cause the ToolPolicyGuardrail to remove
|
||||
that tool_call from LLM responses before returning them to the client.
|
||||
- input_policy: optional - "trusted" | "untrusted" | "blocked"
|
||||
- output_policy: optional - "trusted" | "untrusted"
|
||||
- team_id: optional - if set, create/update override for this team only
|
||||
- key_hash: optional - if set, create/update override for this key only
|
||||
"""
|
||||
from litellm.proxy.db.tool_registry_writer import (
|
||||
add_tool_to_object_permission_blocked,
|
||||
get_tool_policy_registry,
|
||||
remove_tool_from_object_permission_blocked,
|
||||
)
|
||||
from litellm.proxy.db.tool_registry_writer import (
|
||||
update_tool_policy as db_update_tool_policy,
|
||||
)
|
||||
@ -127,19 +449,80 @@ async def update_tool_policy(
|
||||
)
|
||||
|
||||
try:
|
||||
if data.team_id is not None or data.key_hash is not None:
|
||||
if data.team_id is not None and data.key_hash is not None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Provide either team_id or key_hash, not both",
|
||||
)
|
||||
if data.key_hash is not None:
|
||||
op_id = await _resolve_key_hash_to_object_permission_id(
|
||||
prisma_client, data.key_hash
|
||||
)
|
||||
else:
|
||||
op_id = await _resolve_team_id_to_object_permission_id(
|
||||
prisma_client, data.team_id or ""
|
||||
)
|
||||
if op_id is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Key or team not found for the given identifier",
|
||||
)
|
||||
is_blocking = data.input_policy == "blocked"
|
||||
if is_blocking:
|
||||
ok = await add_tool_to_object_permission_blocked(
|
||||
prisma_client=prisma_client,
|
||||
object_permission_id=op_id,
|
||||
tool_name=data.tool_name,
|
||||
)
|
||||
else:
|
||||
ok = await remove_tool_from_object_permission_blocked(
|
||||
prisma_client=prisma_client,
|
||||
object_permission_id=op_id,
|
||||
tool_name=data.tool_name,
|
||||
)
|
||||
if not ok:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to update policy override for tool '{data.tool_name}'",
|
||||
)
|
||||
registry = get_tool_policy_registry()
|
||||
if registry.is_initialized():
|
||||
await registry.sync_tool_policy_from_db(prisma_client)
|
||||
return ToolPolicyUpdateResponse(
|
||||
tool_name=data.tool_name,
|
||||
input_policy=data.input_policy,
|
||||
output_policy=data.output_policy,
|
||||
updated=True,
|
||||
team_id=data.team_id,
|
||||
key_hash=data.key_hash,
|
||||
)
|
||||
|
||||
if data.input_policy is None and data.output_policy is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="At least one of input_policy or output_policy must be provided",
|
||||
)
|
||||
|
||||
updated = await db_update_tool_policy(
|
||||
prisma_client=prisma_client,
|
||||
tool_name=data.tool_name,
|
||||
call_policy=data.call_policy,
|
||||
updated_by=user_api_key_dict.user_id,
|
||||
input_policy=data.input_policy,
|
||||
output_policy=data.output_policy,
|
||||
)
|
||||
if updated is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to update policy for tool '{data.tool_name}'"
|
||||
status_code=500,
|
||||
detail=f"Failed to update policy for tool '{data.tool_name}'",
|
||||
)
|
||||
registry = get_tool_policy_registry()
|
||||
if registry.is_initialized():
|
||||
await registry.sync_tool_policy_from_db(prisma_client)
|
||||
return ToolPolicyUpdateResponse(
|
||||
tool_name=updated.tool_name,
|
||||
call_policy=updated.call_policy,
|
||||
input_policy=updated.input_policy,
|
||||
output_policy=updated.output_policy,
|
||||
updated=True,
|
||||
)
|
||||
except HTTPException:
|
||||
@ -147,3 +530,77 @@ async def update_tool_policy(
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception("Error updating tool policy: %s", e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/v1/tool/{tool_name:path}/overrides",
|
||||
tags=["tool management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def delete_tool_policy_override(
|
||||
tool_name: str,
|
||||
team_id: Optional[str] = Query(
|
||||
None, description="Team ID of the override to remove"
|
||||
),
|
||||
key_hash: Optional[str] = Query(
|
||||
None, description="Key hash of the override to remove"
|
||||
),
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
"""
|
||||
Remove a policy override for a tool. Specify the override by team_id or key_hash
|
||||
(exactly one required).
|
||||
"""
|
||||
from litellm.proxy.db.tool_registry_writer import (
|
||||
get_tool_policy_registry,
|
||||
remove_tool_from_object_permission_blocked,
|
||||
)
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
if prisma_client is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=CommonProxyErrors.db_not_connected_error.value
|
||||
)
|
||||
if team_id is None and key_hash is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="At least one of team_id or key_hash is required to identify the override",
|
||||
)
|
||||
if team_id is not None and key_hash is not None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Provide either team_id or key_hash, not both",
|
||||
)
|
||||
try:
|
||||
if key_hash is not None:
|
||||
op_id = await _resolve_key_hash_to_object_permission_id(
|
||||
prisma_client, key_hash
|
||||
)
|
||||
else:
|
||||
op_id = await _resolve_team_id_to_object_permission_id(
|
||||
prisma_client, team_id or ""
|
||||
)
|
||||
if op_id is None:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail="Key or team not found for the given identifier",
|
||||
)
|
||||
deleted = await remove_tool_from_object_permission_blocked(
|
||||
prisma_client=prisma_client,
|
||||
object_permission_id=op_id,
|
||||
tool_name=tool_name,
|
||||
)
|
||||
if not deleted:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail=f"No override found for tool '{tool_name}' with the given scope",
|
||||
)
|
||||
registry = get_tool_policy_registry()
|
||||
if registry.is_initialized():
|
||||
await registry.sync_tool_policy_from_db(prisma_client)
|
||||
return {"deleted": True, "tool_name": tool_name}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception("Error deleting tool policy override: %s", e)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@ -4411,6 +4411,9 @@ class ProxyConfig:
|
||||
if self._should_load_db_object(object_type="search_tools"):
|
||||
await self._init_search_tools_in_db(prisma_client=prisma_client)
|
||||
|
||||
if self._should_load_db_object(object_type="tools"):
|
||||
await self._init_tool_policy_in_db(prisma_client=prisma_client)
|
||||
|
||||
if self._should_load_db_object(object_type="model_cost_map"):
|
||||
await self._check_and_reload_model_cost_map(prisma_client=prisma_client)
|
||||
|
||||
@ -4847,6 +4850,24 @@ class ProxyConfig:
|
||||
)
|
||||
)
|
||||
|
||||
async def _init_tool_policy_in_db(self, prisma_client: PrismaClient):
|
||||
"""
|
||||
Initialize tool policy from database into the in-memory registry.
|
||||
Synced periodically by add_deployment -> _init_non_llm_objects_in_db.
|
||||
"""
|
||||
from litellm.proxy.db.tool_registry_writer import get_tool_policy_registry
|
||||
|
||||
try:
|
||||
registry = get_tool_policy_registry()
|
||||
await registry.sync_tool_policy_from_db(prisma_client=prisma_client)
|
||||
verbose_proxy_logger.debug("Successfully synced tool policy from DB")
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.exception(
|
||||
"litellm.proxy.proxy_server.py::ProxyConfig:_init_tool_policy_in_db - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
|
||||
async def _init_vector_stores_in_db(self, prisma_client: PrismaClient):
|
||||
from litellm.vector_stores.vector_store_registry import VectorStoreRegistry
|
||||
|
||||
@ -10577,6 +10598,12 @@ async def async_queue_request(
|
||||
data["metadata"]["user_api_key_team_id"] = getattr(
|
||||
user_api_key_dict, "team_id", None
|
||||
)
|
||||
data["metadata"]["user_api_key_object_permission_id"] = getattr(
|
||||
user_api_key_dict, "object_permission_id", None
|
||||
)
|
||||
data["metadata"]["user_api_key_team_object_permission_id"] = getattr(
|
||||
user_api_key_dict, "team_object_permission_id", None
|
||||
)
|
||||
data["metadata"]["endpoint"] = str(request.url)
|
||||
|
||||
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
||||
@ -11093,9 +11120,7 @@ async def get_favicon():
|
||||
|
||||
if favicon_url.startswith(("http://", "https://")):
|
||||
try:
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.types.llms.custom_http import httpxSpecialProvider
|
||||
|
||||
async_client = get_async_httpx_client(
|
||||
|
||||
@ -260,6 +260,7 @@ model LiteLLM_ObjectPermissionTable {
|
||||
vector_stores String[] @default([])
|
||||
agents String[] @default([])
|
||||
agent_access_groups String[] @default([])
|
||||
blocked_tools String[] @default([]) // Tool names blocked for any key/team/user with this permission
|
||||
teams LiteLLM_TeamTable[]
|
||||
projects LiteLLM_ProjectTable[]
|
||||
verification_tokens LiteLLM_VerificationToken[]
|
||||
@ -928,6 +929,16 @@ model LiteLLM_SpendLogGuardrailIndex {
|
||||
@@index([policy_id, start_time])
|
||||
}
|
||||
|
||||
// Index for fast "last N logs for tool" from SpendLogs – see how a tool is called in production
|
||||
model LiteLLM_SpendLogToolIndex {
|
||||
request_id String
|
||||
tool_name String // matches LiteLLM_ToolTable.tool_name; join for input_policy/output_policy etc.
|
||||
start_time DateTime
|
||||
|
||||
@@id([request_id, tool_name])
|
||||
@@index([tool_name, start_time])
|
||||
}
|
||||
|
||||
// Prompt table for storing prompt configurations
|
||||
model LiteLLM_PromptTable {
|
||||
id String @id @default(uuid())
|
||||
@ -1065,23 +1076,27 @@ model LiteLLM_PolicyAttachmentTable {
|
||||
updated_by String?
|
||||
}
|
||||
|
||||
// Global tool registry - auto-discovered from LLM responses; admins set call_policy here
|
||||
// Global tool registry - auto-discovered from LLM responses; admins set input/output policies here
|
||||
model LiteLLM_ToolTable {
|
||||
tool_id String @id @default(uuid())
|
||||
tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space"
|
||||
origin String? // MCP server name or "user_defined"
|
||||
call_policy String @default("untrusted") // "trusted" | "untrusted" | "dual_llm" | "blocked"
|
||||
call_count Int @default(0) // cumulative number of times this tool was seen
|
||||
assignments Json? @default("{}")
|
||||
key_hash String? // hash of the virtual key that first called this tool
|
||||
team_id String? // team that first called this tool
|
||||
key_alias String? // human-readable alias of the virtual key
|
||||
created_at DateTime @default(now())
|
||||
created_by String?
|
||||
updated_at DateTime @default(now()) @updatedAt
|
||||
updated_by String?
|
||||
tool_id String @id @default(uuid())
|
||||
tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space"
|
||||
origin String? // MCP server name or "user_defined"
|
||||
input_policy String @default("untrusted") // "trusted" | "untrusted" | "blocked"
|
||||
output_policy String @default("untrusted") // "trusted" | "untrusted"
|
||||
call_count Int @default(0) // cumulative number of times this tool was seen
|
||||
assignments Json? @default("{}")
|
||||
key_hash String? // hash of the virtual key that first called this tool
|
||||
team_id String? // team that first called this tool
|
||||
key_alias String? // human-readable alias of the virtual key
|
||||
user_agent String? // user-agent of the first request that discovered this tool
|
||||
last_used_at DateTime? // timestamp of the most recent call
|
||||
created_at DateTime @default(now())
|
||||
created_by String?
|
||||
updated_at DateTime @default(now()) @updatedAt
|
||||
updated_by String?
|
||||
|
||||
@@index([call_policy])
|
||||
@@index([input_policy])
|
||||
@@index([output_policy])
|
||||
@@index([team_id])
|
||||
}
|
||||
|
||||
|
||||
@ -3583,8 +3583,9 @@ class PrismaClient:
|
||||
def _get_engine_pid(self) -> int:
|
||||
try:
|
||||
engine = self.db._original_prisma._engine # type: ignore[attr-defined]
|
||||
if engine is not None and engine.process is not None:
|
||||
return engine.process.pid
|
||||
process = getattr(engine, "process", None) if engine is not None else None
|
||||
if process is not None:
|
||||
return process.pid
|
||||
except (AttributeError, TypeError):
|
||||
pass
|
||||
return 0
|
||||
@ -4688,6 +4689,19 @@ async def update_spend_logs_job(
|
||||
guardrail_tracking_err,
|
||||
)
|
||||
|
||||
# Tool usage tracking (same batch): SpendLogToolIndex for "last N requests for tool X"
|
||||
try:
|
||||
from litellm.proxy.db.spend_log_tool_index import process_spend_logs_tool_usage
|
||||
await process_spend_logs_tool_usage(
|
||||
prisma_client=prisma_client,
|
||||
logs_to_process=logs_to_process,
|
||||
)
|
||||
except Exception as tool_tracking_err:
|
||||
verbose_proxy_logger.warning(
|
||||
"Spend tracking - tool usage tracking failed (non-fatal): %s",
|
||||
tool_tracking_err,
|
||||
)
|
||||
|
||||
|
||||
async def _monitor_spend_logs_queue(
|
||||
prisma_client: PrismaClient,
|
||||
|
||||
@ -5,21 +5,27 @@ Pydantic models for Tool Policy management endpoints.
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
ToolCallPolicy = Literal["trusted", "untrusted", "dual_llm", "blocked"]
|
||||
|
||||
ToolInputPolicy = Literal["trusted", "untrusted", "blocked"]
|
||||
ToolOutputPolicy = Literal["trusted", "untrusted"]
|
||||
|
||||
|
||||
class LiteLLM_ToolTableRow(BaseModel):
|
||||
tool_id: str
|
||||
tool_name: str
|
||||
origin: Optional[str] = None
|
||||
call_policy: ToolCallPolicy = "untrusted"
|
||||
input_policy: ToolInputPolicy = "untrusted"
|
||||
output_policy: ToolOutputPolicy = "untrusted"
|
||||
call_count: int = 0
|
||||
assignments: Optional[Dict] = None
|
||||
key_hash: Optional[str] = None
|
||||
team_id: Optional[str] = None
|
||||
key_alias: Optional[str] = None
|
||||
user_agent: Optional[str] = None
|
||||
last_used_at: Optional[datetime] = None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
created_by: Optional[str] = None
|
||||
@ -33,10 +39,62 @@ class ToolListResponse(BaseModel):
|
||||
|
||||
class ToolPolicyUpdateRequest(BaseModel):
|
||||
tool_name: str
|
||||
call_policy: ToolCallPolicy
|
||||
input_policy: Optional[ToolInputPolicy] = None
|
||||
output_policy: Optional[ToolOutputPolicy] = None
|
||||
team_id: Optional[str] = None
|
||||
key_hash: Optional[str] = None
|
||||
key_alias: Optional[str] = None
|
||||
|
||||
|
||||
class ToolPolicyUpdateResponse(BaseModel):
|
||||
tool_name: str
|
||||
call_policy: ToolCallPolicy
|
||||
input_policy: Optional[ToolInputPolicy] = None
|
||||
output_policy: Optional[ToolOutputPolicy] = None
|
||||
updated: bool
|
||||
team_id: Optional[str] = None
|
||||
key_hash: Optional[str] = None
|
||||
|
||||
|
||||
class ToolPolicyOverrideRow(BaseModel):
|
||||
override_id: str
|
||||
tool_name: str
|
||||
team_id: Optional[str] = None
|
||||
key_hash: Optional[str] = None
|
||||
input_policy: ToolInputPolicy = "blocked"
|
||||
key_alias: Optional[str] = None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
|
||||
class ToolPolicyOption(BaseModel):
|
||||
value: str
|
||||
label: str
|
||||
description: str
|
||||
|
||||
|
||||
class ToolPolicyOptionsResponse(BaseModel):
|
||||
input_policies: List[ToolPolicyOption]
|
||||
output_policies: List[ToolPolicyOption]
|
||||
|
||||
|
||||
class ToolDetailResponse(BaseModel):
|
||||
tool: LiteLLM_ToolTableRow
|
||||
overrides: List[ToolPolicyOverrideRow] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ToolUsageLogEntry(BaseModel):
|
||||
"""One spend log row for a tool call (for UI "recent logs" table)."""
|
||||
|
||||
id: str # request_id
|
||||
timestamp: str
|
||||
model: Optional[str] = None
|
||||
spend: Optional[float] = None
|
||||
total_tokens: Optional[int] = None
|
||||
input_snippet: Optional[str] = None
|
||||
|
||||
|
||||
class ToolUsageLogsResponse(BaseModel):
|
||||
logs: List[ToolUsageLogEntry]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
|
||||
@ -260,6 +260,7 @@ model LiteLLM_ObjectPermissionTable {
|
||||
vector_stores String[] @default([])
|
||||
agents String[] @default([])
|
||||
agent_access_groups String[] @default([])
|
||||
blocked_tools String[] @default([]) // Tool names blocked for any key/team/user with this permission
|
||||
teams LiteLLM_TeamTable[]
|
||||
projects LiteLLM_ProjectTable[]
|
||||
verification_tokens LiteLLM_VerificationToken[]
|
||||
@ -928,6 +929,16 @@ model LiteLLM_SpendLogGuardrailIndex {
|
||||
@@index([policy_id, start_time])
|
||||
}
|
||||
|
||||
// Index for fast "last N logs for tool" from SpendLogs – see how a tool is called in production
|
||||
model LiteLLM_SpendLogToolIndex {
|
||||
request_id String
|
||||
tool_name String // matches LiteLLM_ToolTable.tool_name; join for input_policy/output_policy etc.
|
||||
start_time DateTime
|
||||
|
||||
@@id([request_id, tool_name])
|
||||
@@index([tool_name, start_time])
|
||||
}
|
||||
|
||||
// Prompt table for storing prompt configurations
|
||||
model LiteLLM_PromptTable {
|
||||
id String @id @default(uuid())
|
||||
@ -1065,23 +1076,27 @@ model LiteLLM_PolicyAttachmentTable {
|
||||
updated_by String?
|
||||
}
|
||||
|
||||
// Global tool registry - auto-discovered from LLM responses; admins set call_policy here
|
||||
// Global tool registry - auto-discovered from LLM responses; admins set input/output policies here
|
||||
model LiteLLM_ToolTable {
|
||||
tool_id String @id @default(uuid())
|
||||
tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space"
|
||||
origin String? // MCP server name or "user_defined"
|
||||
call_policy String @default("untrusted") // "trusted" | "untrusted" | "dual_llm" | "blocked"
|
||||
call_count Int @default(0) // cumulative number of times this tool was seen
|
||||
assignments Json? @default("{}")
|
||||
key_hash String? // hash of the virtual key that first called this tool
|
||||
team_id String? // team that first called this tool
|
||||
key_alias String? // human-readable alias of the virtual key
|
||||
created_at DateTime @default(now())
|
||||
created_by String?
|
||||
updated_at DateTime @default(now()) @updatedAt
|
||||
updated_by String?
|
||||
tool_id String @id @default(uuid())
|
||||
tool_name String @unique // e.g. "huggingface_remote-mcp__dynamic_space"
|
||||
origin String? // MCP server name or "user_defined"
|
||||
input_policy String @default("untrusted") // "trusted" | "untrusted" | "blocked"
|
||||
output_policy String @default("untrusted") // "trusted" | "untrusted"
|
||||
call_count Int @default(0) // cumulative number of times this tool was seen
|
||||
assignments Json? @default("{}")
|
||||
key_hash String? // hash of the virtual key that first called this tool
|
||||
team_id String? // team that first called this tool
|
||||
key_alias String? // human-readable alias of the virtual key
|
||||
user_agent String? // user-agent of the first request that discovered this tool
|
||||
last_used_at DateTime? // timestamp of the most recent call
|
||||
created_at DateTime @default(now())
|
||||
created_by String?
|
||||
updated_at DateTime @default(now()) @updatedAt
|
||||
updated_by String?
|
||||
|
||||
@@index([call_policy])
|
||||
@@index([input_policy])
|
||||
@@index([output_policy])
|
||||
@@index([team_id])
|
||||
}
|
||||
|
||||
|
||||
116
scripts/test_tool_allowlist_script.py
Normal file
116
scripts/test_tool_allowlist_script.py
Normal file
@ -0,0 +1,116 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Standalone script to test tool allowlist enforcement and tool name extraction.
|
||||
|
||||
Run from repo root:
|
||||
poetry run python scripts/test_tool_allowlist_script.py
|
||||
|
||||
Or run the unit tests:
|
||||
poetry run pytest tests/test_litellm/proxy/test_tools_allowlist_enforcement.py -v
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure repo root is on path
|
||||
repo_root = Path(__file__).resolve().parent.parent
|
||||
if str(repo_root) not in sys.path:
|
||||
sys.path.insert(0, str(repo_root))
|
||||
|
||||
|
||||
def test_extraction():
|
||||
"""Test extract_request_tool_names for each API shape."""
|
||||
from litellm.proxy.guardrails.tool_name_extraction import extract_request_tool_names
|
||||
|
||||
cases = [
|
||||
("OpenAI chat tools", "/v1/chat/completions", {"tools": [{"type": "function", "function": {"name": "get_weather"}}]}),
|
||||
("OpenAI chat functions", "/v1/chat/completions", {"functions": [{"name": "run_sql"}]}),
|
||||
("OpenAI responses function", "/v1/responses", {"tools": [{"type": "function", "name": "get_current_weather"}]}),
|
||||
("OpenAI responses MCP", "/v1/responses", {"tools": [{"type": "mcp", "server_label": "dmcp"}]}),
|
||||
("Anthropic", "/v1/messages", {"tools": [{"name": "get_weather"}, {"name": "run_sql"}]}),
|
||||
("Google generateContent", "/generate_content", {"tools": [{"functionDeclarations": [{"name": "schedule_meeting"}]}]}),
|
||||
("MCP call_tool", "/mcp/call_tool", {"name": "my_tool", "arguments": {}}),
|
||||
("Non-tool route", "/v1/embeddings", {"tools": [{"type": "function", "function": {"name": "x"}}]}),
|
||||
]
|
||||
print("=== extract_request_tool_names(route, data) ===\n")
|
||||
for label, route, data in cases:
|
||||
names = extract_request_tool_names(route, data)
|
||||
print(f" {label}: {names}")
|
||||
print()
|
||||
|
||||
|
||||
async def test_check_tools_allowlist():
|
||||
"""Test check_tools_allowlist with mock tokens."""
|
||||
from litellm.proxy._types import ProxyErrorTypes, ProxyException, UserAPIKeyAuth
|
||||
from litellm.proxy.auth.auth_checks import check_tools_allowlist
|
||||
|
||||
def token(metadata=None, team_metadata=None):
|
||||
return UserAPIKeyAuth(
|
||||
api_key="test-key",
|
||||
user_id="user",
|
||||
team_id="team",
|
||||
org_id=None,
|
||||
models=["*"],
|
||||
metadata=metadata or {},
|
||||
team_metadata=team_metadata or {},
|
||||
)
|
||||
|
||||
print("=== check_tools_allowlist (auth) ===\n")
|
||||
|
||||
# No allowlist -> pass
|
||||
await check_tools_allowlist(
|
||||
request_body={"tools": [{"type": "function", "function": {"name": "get_weather"}}]},
|
||||
valid_token=token(),
|
||||
team_object=None,
|
||||
route="/v1/chat/completions",
|
||||
)
|
||||
print(" No allowlist, body has tools: PASS")
|
||||
|
||||
# Allowed tool -> pass
|
||||
await check_tools_allowlist(
|
||||
request_body={"tools": [{"type": "function", "function": {"name": "get_weather"}}]},
|
||||
valid_token=token(metadata={"allowed_tools": ["get_weather"]}),
|
||||
team_object=None,
|
||||
route="/v1/chat/completions",
|
||||
)
|
||||
print(" allowed_tools=['get_weather'], body has get_weather: PASS")
|
||||
|
||||
# Disallowed tool -> raise
|
||||
try:
|
||||
await check_tools_allowlist(
|
||||
request_body={"tools": [{"type": "function", "function": {"name": "get_weather"}}]},
|
||||
valid_token=token(metadata={"allowed_tools": ["other_tool"]}),
|
||||
team_object=None,
|
||||
route="/v1/chat/completions",
|
||||
)
|
||||
print(" DISALLOWED: expected ProxyException")
|
||||
except ProxyException as e:
|
||||
if e.type == ProxyErrorTypes.tool_access_denied:
|
||||
print(" allowed_tools=['other_tool'], body has get_weather: PASS (raised tool_access_denied)")
|
||||
else:
|
||||
print(f" Unexpected ProxyException type: {e.type}")
|
||||
except Exception as e:
|
||||
print(f" Unexpected: {e}")
|
||||
|
||||
# Team allowlist when key empty
|
||||
await check_tools_allowlist(
|
||||
request_body={"tools": [{"type": "function", "function": {"name": "get_weather"}}]},
|
||||
valid_token=token(team_metadata={"allowed_tools": ["get_weather"]}),
|
||||
team_object=None,
|
||||
route="/v1/chat/completions",
|
||||
)
|
||||
print(" team_metadata.allowed_tools=['get_weather']: PASS")
|
||||
print()
|
||||
|
||||
|
||||
def main():
|
||||
print("Tool allowlist / tool name extraction – script checks\n")
|
||||
test_extraction()
|
||||
asyncio.run(test_check_tools_allowlist())
|
||||
print("Done. For full unit tests run:")
|
||||
print(" poetry run pytest tests/test_litellm/proxy/test_tools_allowlist_enforcement.py -v")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@ -1,6 +1,6 @@
|
||||
"""
|
||||
Unit tests for tool_registry_writer.py — uses a mock prisma client
|
||||
that exposes execute_raw / query_raw (matching the actual raw-SQL implementation).
|
||||
that exposes litellm_tooltable.upsert / find_many / find_unique.
|
||||
"""
|
||||
|
||||
import os
|
||||
@ -13,21 +13,28 @@ import pytest
|
||||
sys.path.insert(0, os.path.abspath("../../.."))
|
||||
|
||||
from litellm.proxy.db.tool_registry_writer import (
|
||||
ToolPolicyRegistry,
|
||||
batch_upsert_tools,
|
||||
get_tool,
|
||||
get_tool_policy_registry,
|
||||
get_tools_by_names,
|
||||
list_tools,
|
||||
update_tool_policy,
|
||||
)
|
||||
|
||||
|
||||
def _make_prisma(query_rows=None):
|
||||
"""Return a minimal mock prisma_client with execute_raw / query_raw."""
|
||||
default_row = {
|
||||
def _mock_row(**kwargs):
|
||||
"""Build a row-like object with real attributes (no MagicMock) for _row_to_model."""
|
||||
|
||||
class Row:
|
||||
pass
|
||||
|
||||
default = {
|
||||
"tool_id": "uuid-1",
|
||||
"tool_name": "my_tool",
|
||||
"origin": "user_defined",
|
||||
"call_policy": "untrusted",
|
||||
"input_policy": "untrusted",
|
||||
"output_policy": "untrusted",
|
||||
"call_count": 1,
|
||||
"assignments": {},
|
||||
"key_hash": None,
|
||||
@ -38,31 +45,54 @@ def _make_prisma(query_rows=None):
|
||||
"created_by": None,
|
||||
"updated_by": None,
|
||||
}
|
||||
rows = query_rows if query_rows is not None else [default_row]
|
||||
default.update(kwargs)
|
||||
row = Row()
|
||||
for k, v in default.items():
|
||||
setattr(row, k, v)
|
||||
return row
|
||||
|
||||
|
||||
def _make_prisma(
|
||||
*,
|
||||
upsert_return=None,
|
||||
find_many_rows=None,
|
||||
find_unique_row=None,
|
||||
):
|
||||
"""Return a mock prisma_client with litellm_tooltable.upsert, find_many, find_unique."""
|
||||
prisma = MagicMock()
|
||||
prisma.db.execute_raw = AsyncMock(return_value=None)
|
||||
prisma.db.query_raw = AsyncMock(return_value=rows)
|
||||
prisma.db.litellm_tooltable = MagicMock()
|
||||
prisma.db.litellm_tooltable.upsert = AsyncMock(return_value=upsert_return)
|
||||
prisma.db.litellm_tooltable.find_many = AsyncMock(
|
||||
return_value=find_many_rows if find_many_rows is not None else []
|
||||
)
|
||||
prisma.db.litellm_tooltable.find_unique = AsyncMock(
|
||||
return_value=find_unique_row
|
||||
)
|
||||
return prisma
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_upsert_tools_calls_execute_raw():
|
||||
async def test_batch_upsert_tools_calls_upsert():
|
||||
prisma = _make_prisma()
|
||||
items = [{"tool_name": "tool_a", "origin": "mcp_server", "created_by": None}]
|
||||
await batch_upsert_tools(prisma, items)
|
||||
prisma.db.execute_raw.assert_awaited_once()
|
||||
call_args = prisma.db.execute_raw.call_args
|
||||
sql = call_args.args[0]
|
||||
assert "LiteLLM_ToolTable" in sql
|
||||
assert "ON CONFLICT" in sql
|
||||
prisma.db.litellm_tooltable.upsert.assert_awaited_once()
|
||||
call_kw = prisma.db.litellm_tooltable.upsert.call_args.kwargs
|
||||
assert call_kw["where"] == {"tool_name": "tool_a"}
|
||||
assert call_kw["data"]["create"]["tool_name"] == "tool_a"
|
||||
assert call_kw["data"]["create"]["origin"] == "mcp_server"
|
||||
assert call_kw["data"]["create"]["input_policy"] == "untrusted"
|
||||
assert call_kw["data"]["create"]["output_policy"] == "untrusted"
|
||||
assert call_kw["data"]["create"]["call_count"] == 1
|
||||
assert call_kw["data"]["update"]["call_count"] == {"increment": 1}
|
||||
assert "updated_at" in call_kw["data"]["update"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_upsert_tools_empty_list():
|
||||
prisma = _make_prisma()
|
||||
await batch_upsert_tools(prisma, [])
|
||||
prisma.db.execute_raw.assert_not_awaited()
|
||||
prisma.db.litellm_tooltable.upsert.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -70,123 +100,120 @@ async def test_batch_upsert_tools_skips_empty_names():
|
||||
prisma = _make_prisma()
|
||||
items = [{"tool_name": "", "origin": None}, {"tool_name": None}] # type: ignore[list-item]
|
||||
await batch_upsert_tools(prisma, items)
|
||||
prisma.db.execute_raw.assert_not_awaited()
|
||||
prisma.db.litellm_tooltable.upsert.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_upsert_multiple_tools_calls_execute_raw_per_tool():
|
||||
async def test_batch_upsert_multiple_tools_calls_upsert_per_tool():
|
||||
prisma = _make_prisma()
|
||||
items = [
|
||||
{"tool_name": "tool_a", "origin": "mcp_server", "created_by": None},
|
||||
{"tool_name": "tool_b", "origin": "user_defined", "created_by": "alice"},
|
||||
]
|
||||
await batch_upsert_tools(prisma, items)
|
||||
assert prisma.db.execute_raw.await_count == 2
|
||||
assert prisma.db.litellm_tooltable.upsert.await_count == 2
|
||||
calls = prisma.db.litellm_tooltable.upsert.call_args_list
|
||||
assert calls[0].kwargs["where"]["tool_name"] == "tool_a"
|
||||
assert calls[1].kwargs["where"]["tool_name"] == "tool_b"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools_no_filter():
|
||||
row = {
|
||||
"tool_id": "id1",
|
||||
"tool_name": "tool_a",
|
||||
"origin": "mcp",
|
||||
"call_policy": "untrusted",
|
||||
"call_count": 5,
|
||||
"assignments": {},
|
||||
"key_hash": None,
|
||||
"team_id": None,
|
||||
"key_alias": None,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
"created_by": None,
|
||||
"updated_by": None,
|
||||
}
|
||||
prisma = _make_prisma(query_rows=[row])
|
||||
row = _mock_row(
|
||||
tool_id="id1",
|
||||
tool_name="tool_a",
|
||||
origin="mcp",
|
||||
input_policy="untrusted",
|
||||
output_policy="untrusted",
|
||||
call_count=5,
|
||||
)
|
||||
prisma = _make_prisma(find_many_rows=[row])
|
||||
result = await list_tools(prisma)
|
||||
assert len(result) == 1
|
||||
assert result[0].tool_name == "tool_a"
|
||||
assert result[0].call_count == 5
|
||||
prisma.db.query_raw.assert_awaited_once()
|
||||
prisma.db.litellm_tooltable.find_many.assert_awaited_once()
|
||||
call_kw = prisma.db.litellm_tooltable.find_many.call_args.kwargs
|
||||
assert call_kw["where"] == {}
|
||||
assert call_kw["order"] == {"created_at": "desc"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tools_with_policy_filter():
|
||||
row = {
|
||||
"tool_id": "id1",
|
||||
"tool_name": "blocked_tool",
|
||||
"origin": None,
|
||||
"call_policy": "blocked",
|
||||
"call_count": 2,
|
||||
"assignments": None,
|
||||
"key_hash": None,
|
||||
"team_id": None,
|
||||
"key_alias": None,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
"created_by": None,
|
||||
"updated_by": None,
|
||||
}
|
||||
prisma = _make_prisma(query_rows=[row])
|
||||
result = await list_tools(prisma, call_policy="blocked")
|
||||
assert result[0].call_policy == "blocked"
|
||||
call_args = prisma.db.query_raw.call_args
|
||||
sql = call_args.args[0]
|
||||
assert "WHERE call_policy" in sql
|
||||
async def test_list_tools_with_input_policy_filter():
|
||||
row = _mock_row(
|
||||
tool_id="id1",
|
||||
tool_name="blocked_tool",
|
||||
origin=None,
|
||||
input_policy="blocked",
|
||||
output_policy="untrusted",
|
||||
call_count=2,
|
||||
assignments=None,
|
||||
)
|
||||
prisma = _make_prisma(find_many_rows=[row])
|
||||
result = await list_tools(prisma, input_policy="blocked")
|
||||
assert result[0].input_policy == "blocked"
|
||||
call_kw = prisma.db.litellm_tooltable.find_many.call_args.kwargs
|
||||
assert call_kw["where"] == {"input_policy": "blocked"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tool_found():
|
||||
prisma = _make_prisma()
|
||||
row = _mock_row(tool_name="my_tool")
|
||||
prisma = _make_prisma(find_unique_row=row)
|
||||
result = await get_tool(prisma, "my_tool")
|
||||
assert result is not None
|
||||
assert result.tool_name == "my_tool"
|
||||
prisma.db.query_raw.assert_awaited_once()
|
||||
prisma.db.litellm_tooltable.find_unique.assert_awaited_once_with(
|
||||
where={"tool_name": "my_tool"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tool_not_found():
|
||||
prisma = _make_prisma(query_rows=[])
|
||||
prisma = _make_prisma(find_unique_row=None)
|
||||
result = await get_tool(prisma, "nonexistent")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_tool_policy_calls_execute_raw():
|
||||
row = {
|
||||
"tool_id": "uuid-1",
|
||||
"tool_name": "my_tool",
|
||||
"origin": "user_defined",
|
||||
"call_policy": "blocked",
|
||||
"call_count": 1,
|
||||
"assignments": {},
|
||||
"key_hash": None,
|
||||
"team_id": None,
|
||||
"key_alias": None,
|
||||
"created_at": datetime.now(timezone.utc),
|
||||
"updated_at": datetime.now(timezone.utc),
|
||||
"created_by": None,
|
||||
"updated_by": "admin",
|
||||
}
|
||||
prisma = _make_prisma(query_rows=[row])
|
||||
result = await update_tool_policy(prisma, "my_tool", "blocked", "admin")
|
||||
async def test_update_tool_policy_calls_upsert_then_get_tool():
|
||||
row = _mock_row(
|
||||
tool_name="my_tool",
|
||||
input_policy="blocked",
|
||||
output_policy="untrusted",
|
||||
updated_by="admin",
|
||||
)
|
||||
prisma = _make_prisma(find_unique_row=row)
|
||||
result = await update_tool_policy(
|
||||
prisma, "my_tool", updated_by="admin", input_policy="blocked"
|
||||
)
|
||||
assert result is not None
|
||||
assert result.call_policy == "blocked"
|
||||
prisma.db.execute_raw.assert_awaited_once()
|
||||
call_args = prisma.db.execute_raw.call_args
|
||||
sql = call_args.args[0]
|
||||
assert "ON CONFLICT" in sql
|
||||
assert "call_policy" in sql
|
||||
assert result.input_policy == "blocked"
|
||||
prisma.db.litellm_tooltable.upsert.assert_awaited_once()
|
||||
call_kw = prisma.db.litellm_tooltable.upsert.call_args.kwargs
|
||||
assert call_kw["where"] == {"tool_name": "my_tool"}
|
||||
assert call_kw["data"]["update"]["input_policy"] == "blocked"
|
||||
assert call_kw["data"]["update"]["updated_by"] == "admin"
|
||||
prisma.db.litellm_tooltable.find_unique.assert_awaited_with(
|
||||
where={"tool_name": "my_tool"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tools_by_names_returns_policy_map():
|
||||
rows = [
|
||||
{"tool_name": "tool_a", "call_policy": "trusted"},
|
||||
{"tool_name": "tool_b", "call_policy": "blocked"},
|
||||
_mock_row(tool_name="tool_a", input_policy="trusted", output_policy="untrusted"),
|
||||
_mock_row(tool_name="tool_b", input_policy="blocked", output_policy="untrusted"),
|
||||
]
|
||||
prisma = _make_prisma(query_rows=rows)
|
||||
prisma = _make_prisma(find_many_rows=rows)
|
||||
result = await get_tools_by_names(prisma, ["tool_a", "tool_b"])
|
||||
assert result == {"tool_a": "trusted", "tool_b": "blocked"}
|
||||
assert result == {
|
||||
"tool_a": ("trusted", "untrusted"),
|
||||
"tool_b": ("blocked", "untrusted"),
|
||||
}
|
||||
prisma.db.litellm_tooltable.find_many.assert_awaited_once_with(
|
||||
where={"tool_name": {"in": ["tool_a", "tool_b"]}}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@ -194,4 +221,71 @@ async def test_get_tools_by_names_empty_list():
|
||||
prisma = _make_prisma()
|
||||
result = await get_tools_by_names(prisma, [])
|
||||
assert result == {}
|
||||
prisma.db.query_raw.assert_not_awaited()
|
||||
prisma.db.litellm_tooltable.find_many.assert_not_awaited()
|
||||
|
||||
|
||||
# --- ToolPolicyRegistry ---
|
||||
|
||||
|
||||
def _mock_tool_row(
|
||||
tool_name: str,
|
||||
input_policy: str = "untrusted",
|
||||
output_policy: str = "untrusted",
|
||||
):
|
||||
row = MagicMock()
|
||||
row.tool_name = tool_name
|
||||
row.input_policy = input_policy
|
||||
row.output_policy = output_policy
|
||||
return row
|
||||
|
||||
|
||||
def _mock_perm_row(object_permission_id: str, blocked_tools: list):
|
||||
row = MagicMock()
|
||||
row.object_permission_id = object_permission_id
|
||||
row.blocked_tools = blocked_tools
|
||||
return row
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_policy_registry_sync_and_get_effective_policies():
|
||||
"""Registry syncs from DB; get_effective_policies returns merged blocked + global."""
|
||||
prisma = MagicMock()
|
||||
prisma.db.litellm_tooltable.find_many = AsyncMock(
|
||||
return_value=[
|
||||
_mock_tool_row("tool_a", input_policy="trusted"),
|
||||
_mock_tool_row("tool_b", input_policy="blocked"),
|
||||
_mock_tool_row("tool_c", input_policy="untrusted"),
|
||||
]
|
||||
)
|
||||
prisma.db.litellm_objectpermissiontable.find_many = AsyncMock(
|
||||
return_value=[
|
||||
_mock_perm_row("op-key-1", ["tool_a"]),
|
||||
_mock_perm_row("op-team-1", ["tool_c"]),
|
||||
]
|
||||
)
|
||||
registry = get_tool_policy_registry()
|
||||
await registry.sync_tool_policy_from_db(prisma)
|
||||
assert registry.is_initialized()
|
||||
# Key blocked: tool_a. Team blocked: tool_c. Global: tool_b blocked.
|
||||
result = registry.get_effective_policies(
|
||||
["tool_a", "tool_b", "tool_c"],
|
||||
object_permission_id="op-key-1",
|
||||
team_object_permission_id="op-team-1",
|
||||
)
|
||||
assert result["tool_a"] == "blocked"
|
||||
assert result["tool_b"] == "blocked"
|
||||
assert result["tool_c"] == "blocked"
|
||||
# No op ids: only global
|
||||
result_global = registry.get_effective_policies(["tool_a", "tool_b", "tool_c"])
|
||||
assert result_global["tool_a"] == "trusted"
|
||||
assert result_global["tool_b"] == "blocked"
|
||||
assert result_global["tool_c"] == "untrusted"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_policy_registry_not_initialized_returns_untrusted():
|
||||
"""When not synced, get_effective_policies still returns untrusted for unknown tools."""
|
||||
registry = ToolPolicyRegistry()
|
||||
assert not registry.is_initialized()
|
||||
result = registry.get_effective_policies(["unknown_tool"])
|
||||
assert result == {"unknown_tool": "untrusted"}
|
||||
|
||||
@ -12,9 +12,8 @@ from fastapi import HTTPException
|
||||
|
||||
sys.path.insert(0, os.path.abspath("../../../../../.."))
|
||||
|
||||
from litellm.proxy.guardrails.guardrail_hooks.tool_policy.tool_policy_guardrail import (
|
||||
ToolPolicyGuardrail,
|
||||
)
|
||||
from litellm.proxy.guardrails.guardrail_hooks.tool_policy.tool_policy_guardrail import \
|
||||
ToolPolicyGuardrail
|
||||
from litellm.types.guardrails import GuardrailEventHooks
|
||||
|
||||
|
||||
@ -70,10 +69,21 @@ async def test_no_tool_calls_in_response_passes_through(guardrail):
|
||||
assert result is inputs
|
||||
|
||||
|
||||
def _registry_mock(policy_map: dict):
|
||||
"""Return a mock registry with is_initialized=True and get_effective_policies returning policy_map."""
|
||||
reg = MagicMock()
|
||||
reg.is_initialized.return_value = True
|
||||
reg.get_effective_policies.return_value = policy_map
|
||||
return reg
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_untrusted_tools_pass_through(guardrail):
|
||||
policy_map = {"search": "untrusted", "read_file": "trusted"}
|
||||
with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value=policy_map)):
|
||||
with patch(
|
||||
"litellm.proxy.db.tool_registry_writer.get_tool_policy_registry",
|
||||
return_value=_registry_mock(policy_map),
|
||||
):
|
||||
inputs: Any = _tool_request_inputs(["search", "read_file"])
|
||||
result = await guardrail.apply_guardrail(
|
||||
inputs=inputs, request_data={}, input_type="request"
|
||||
@ -84,7 +94,10 @@ async def test_untrusted_tools_pass_through(guardrail):
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocked_tool_in_request_raises_http_exception(guardrail):
|
||||
policy_map = {"dangerous_tool": "blocked"}
|
||||
with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value=policy_map)):
|
||||
with patch(
|
||||
"litellm.proxy.db.tool_registry_writer.get_tool_policy_registry",
|
||||
return_value=_registry_mock(policy_map),
|
||||
):
|
||||
inputs: Any = _tool_request_inputs(["dangerous_tool"])
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await guardrail.apply_guardrail(
|
||||
@ -97,7 +110,10 @@ async def test_blocked_tool_in_request_raises_http_exception(guardrail):
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocked_tool_in_response_raises_http_exception(guardrail):
|
||||
policy_map = {"exfil_tool": "blocked"}
|
||||
with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value=policy_map)):
|
||||
with patch(
|
||||
"litellm.proxy.db.tool_registry_writer.get_tool_policy_registry",
|
||||
return_value=_registry_mock(policy_map),
|
||||
):
|
||||
inputs: Any = _tool_response_inputs(["exfil_tool"])
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await guardrail.apply_guardrail(
|
||||
@ -110,7 +126,10 @@ async def test_blocked_tool_in_response_raises_http_exception(guardrail):
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_blocked_and_allowed_raises_for_blocked(guardrail):
|
||||
policy_map = {"safe_tool": "trusted", "bad_tool": "blocked"}
|
||||
with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value=policy_map)):
|
||||
with patch(
|
||||
"litellm.proxy.db.tool_registry_writer.get_tool_policy_registry",
|
||||
return_value=_registry_mock(policy_map),
|
||||
):
|
||||
inputs: Any = _tool_request_inputs(["safe_tool", "bad_tool"])
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await guardrail.apply_guardrail(
|
||||
@ -123,8 +142,11 @@ async def test_mixed_blocked_and_allowed_raises_for_blocked(guardrail):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_not_in_db_passes_through(guardrail):
|
||||
"""Tools not found in the DB (no entry) should not be blocked."""
|
||||
with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value={})):
|
||||
"""When registry returns no policy (or empty), tools are not blocked."""
|
||||
with patch(
|
||||
"litellm.proxy.db.tool_registry_writer.get_tool_policy_registry",
|
||||
return_value=_registry_mock({}),
|
||||
):
|
||||
inputs: Any = _tool_request_inputs(["unknown_tool"])
|
||||
result = await guardrail.apply_guardrail(
|
||||
inputs=inputs, request_data={}, input_type="request"
|
||||
@ -133,43 +155,30 @@ async def test_tool_not_in_db_passes_through(guardrail):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_policies_cached_uses_cache(guardrail):
|
||||
"""Second call with same tool names should return the cached result."""
|
||||
policy_map = {"tool_a": "trusted"}
|
||||
async def test_registry_not_initialized_passes_through(guardrail):
|
||||
"""When registry is not initialized, no tools are blocked (empty policy map)."""
|
||||
reg = MagicMock()
|
||||
reg.is_initialized.return_value = False
|
||||
with patch(
|
||||
"litellm.proxy.db.tool_registry_writer.get_tools_by_names",
|
||||
new=AsyncMock(return_value=policy_map),
|
||||
) as mock_db, patch(
|
||||
"litellm.proxy.proxy_server.prisma_client",
|
||||
new=MagicMock(),
|
||||
"litellm.proxy.db.tool_registry_writer.get_tool_policy_registry",
|
||||
return_value=reg,
|
||||
):
|
||||
# first call — should hit DB
|
||||
result1 = await guardrail._get_policies_cached(["tool_a"])
|
||||
assert result1 == policy_map
|
||||
|
||||
# second call — should hit cache, not DB again
|
||||
result2 = await guardrail._get_policies_cached(["tool_a"])
|
||||
assert result2 == policy_map
|
||||
|
||||
assert mock_db.call_count == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_policies_cached_no_prisma(guardrail):
|
||||
"""Without a prisma client, returns empty dict."""
|
||||
with patch(
|
||||
"litellm.proxy.proxy_server.prisma_client",
|
||||
None,
|
||||
):
|
||||
result = await guardrail._get_policies_cached(["tool_a"])
|
||||
assert result == {}
|
||||
inputs: Any = _tool_request_inputs(["any_tool"])
|
||||
result = await guardrail.apply_guardrail(
|
||||
inputs=inputs, request_data={}, input_type="request"
|
||||
)
|
||||
assert result is inputs
|
||||
reg.get_effective_policies.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_tool_calls_as_objects(guardrail):
|
||||
"""tool_calls that are objects (not dicts) with .function.name should work."""
|
||||
policy_map = {"obj_tool": "blocked"}
|
||||
with patch.object(guardrail, "_get_policies_cached", new=AsyncMock(return_value=policy_map)):
|
||||
with patch(
|
||||
"litellm.proxy.db.tool_registry_writer.get_tool_policy_registry",
|
||||
return_value=_registry_mock(policy_map),
|
||||
):
|
||||
fn = MagicMock()
|
||||
fn.name = "obj_tool"
|
||||
tc = MagicMock()
|
||||
|
||||
200
tests/test_litellm/proxy/test_tools_allowlist_enforcement.py
Normal file
200
tests/test_litellm/proxy/test_tools_allowlist_enforcement.py
Normal file
@ -0,0 +1,200 @@
|
||||
"""
|
||||
Tests for tool allowlist enforcement (key/team metadata.allowed_tools).
|
||||
|
||||
Covers:
|
||||
- check_tools_allowlist: allowed, disallowed, no allowlist, non-tool routes
|
||||
- extract_request_tool_names: OpenAI chat, responses, Anthropic, generate_content, MCP
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.proxy._types import (ProxyErrorTypes, ProxyException,
|
||||
UserAPIKeyAuth)
|
||||
from litellm.proxy.auth.auth_checks import check_tools_allowlist
|
||||
from litellm.proxy.guardrails.tool_name_extraction import (
|
||||
TOOL_CAPABLE_CALL_TYPES, extract_request_tool_names)
|
||||
|
||||
|
||||
def _token(metadata=None, team_metadata=None):
|
||||
return UserAPIKeyAuth(
|
||||
api_key="test-key",
|
||||
user_id="user",
|
||||
team_id="team",
|
||||
org_id=None,
|
||||
models=["*"],
|
||||
metadata=metadata or {},
|
||||
team_metadata=team_metadata or {},
|
||||
)
|
||||
|
||||
|
||||
class TestExtractRequestToolNames:
|
||||
"""Test tool name extraction per API format."""
|
||||
|
||||
def test_openai_chat_tools(self):
|
||||
data = {
|
||||
"tools": [
|
||||
{"type": "function", "function": {"name": "get_weather"}},
|
||||
{"type": "function", "function": {"name": "run_sql"}},
|
||||
]
|
||||
}
|
||||
assert extract_request_tool_names("/v1/chat/completions", data) == [
|
||||
"get_weather",
|
||||
"run_sql",
|
||||
]
|
||||
|
||||
def test_openai_chat_functions_legacy(self):
|
||||
data = {"functions": [{"name": "get_weather"}, {"name": "run_sql"}]}
|
||||
assert extract_request_tool_names("/v1/chat/completions", data) == [
|
||||
"get_weather",
|
||||
"run_sql",
|
||||
]
|
||||
|
||||
def test_openai_responses_function_tools(self):
|
||||
data = {
|
||||
"tools": [
|
||||
{"type": "function", "name": "get_current_weather", "description": "x"},
|
||||
]
|
||||
}
|
||||
assert extract_request_tool_names("/v1/responses", data) == [
|
||||
"get_current_weather"
|
||||
]
|
||||
|
||||
def test_openai_responses_mcp_tools(self):
|
||||
data = {
|
||||
"tools": [
|
||||
{"type": "mcp", "server_label": "dmcp", "server_url": "http://x"},
|
||||
]
|
||||
}
|
||||
assert extract_request_tool_names("/v1/responses", data) == ["dmcp"]
|
||||
|
||||
def test_anthropic_tools(self):
|
||||
data = {"tools": [{"name": "get_weather"}, {"name": "run_sql"}]}
|
||||
assert extract_request_tool_names("/v1/messages", data) == [
|
||||
"get_weather",
|
||||
"run_sql",
|
||||
]
|
||||
|
||||
def test_generate_content_tools(self):
|
||||
data = {
|
||||
"tools": [
|
||||
{
|
||||
"functionDeclarations": [
|
||||
{"name": "schedule_meeting", "description": "x"},
|
||||
]
|
||||
},
|
||||
]
|
||||
}
|
||||
assert extract_request_tool_names("/generate_content", data) == [
|
||||
"schedule_meeting"
|
||||
]
|
||||
|
||||
def test_mcp_call_tool_name(self):
|
||||
data = {"name": "my_tool", "arguments": {}}
|
||||
assert extract_request_tool_names("/mcp/call_tool", data) == ["my_tool"]
|
||||
|
||||
def test_mcp_call_tool_mcp_tool_name(self):
|
||||
data = {"mcp_tool_name": "other_tool"}
|
||||
assert extract_request_tool_names("/mcp/call_tool", data) == ["other_tool"]
|
||||
|
||||
def test_non_tool_route_returns_empty(self):
|
||||
data = {"tools": [{"type": "function", "function": {"name": "x"}}]}
|
||||
assert extract_request_tool_names("/v1/embeddings", data) == []
|
||||
|
||||
|
||||
class TestCheckToolsAllowlist:
|
||||
"""Test allowlist enforcement in auth (no DB in hot path)."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_allowlist_passes(self):
|
||||
token = _token(metadata={}, team_metadata={})
|
||||
body = {
|
||||
"tools": [{"type": "function", "function": {"name": "get_weather"}}]
|
||||
}
|
||||
await check_tools_allowlist(
|
||||
request_body=body,
|
||||
valid_token=token,
|
||||
team_object=None,
|
||||
route="/v1/chat/completions",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allowed_tool_passes(self):
|
||||
token = _token(metadata={"allowed_tools": ["get_weather"]})
|
||||
body = {
|
||||
"tools": [{"type": "function", "function": {"name": "get_weather"}}]
|
||||
}
|
||||
await check_tools_allowlist(
|
||||
request_body=body,
|
||||
valid_token=token,
|
||||
team_object=None,
|
||||
route="/v1/chat/completions",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_disallowed_tool_raises(self):
|
||||
token = _token(metadata={"allowed_tools": ["other_tool"]})
|
||||
body = {
|
||||
"tools": [{"type": "function", "function": {"name": "get_weather"}}]
|
||||
}
|
||||
with pytest.raises(ProxyException) as exc_info:
|
||||
await check_tools_allowlist(
|
||||
request_body=body,
|
||||
valid_token=token,
|
||||
team_object=None,
|
||||
route="/v1/chat/completions",
|
||||
)
|
||||
assert exc_info.value.type == ProxyErrorTypes.tool_access_denied
|
||||
assert "get_weather" in str(exc_info.value.message)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_allowlist_used_when_key_empty(self):
|
||||
token = _token(
|
||||
metadata={},
|
||||
team_metadata={"allowed_tools": ["get_weather"]},
|
||||
)
|
||||
body = {
|
||||
"tools": [{"type": "function", "function": {"name": "get_weather"}}]
|
||||
}
|
||||
await check_tools_allowlist(
|
||||
request_body=body,
|
||||
valid_token=token,
|
||||
team_object=None,
|
||||
route="/v1/chat/completions",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_key_allowlist_overrides_team(self):
|
||||
token = _token(
|
||||
metadata={"allowed_tools": ["get_weather"]},
|
||||
team_metadata={"allowed_tools": ["other_tool"]},
|
||||
)
|
||||
body = {
|
||||
"tools": [{"type": "function", "function": {"name": "get_weather"}}]
|
||||
}
|
||||
await check_tools_allowlist(
|
||||
request_body=body,
|
||||
valid_token=token,
|
||||
team_object=None,
|
||||
route="/v1/chat/completions",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_valid_token_none_skips(self):
|
||||
await check_tools_allowlist(
|
||||
request_body={"tools": [{"type": "function", "function": {"name": "x"}}]},
|
||||
valid_token=None,
|
||||
team_object=None,
|
||||
route="/v1/chat/completions",
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tools_in_body_passes(self):
|
||||
token = _token(metadata={"allowed_tools": ["get_weather"]})
|
||||
await check_tools_allowlist(
|
||||
request_body={"messages": []},
|
||||
valid_token=token,
|
||||
team_object=None,
|
||||
route="/v1/chat/completions",
|
||||
)
|
||||
@ -39,7 +39,7 @@ import UserDashboard from "@/components/user_dashboard";
|
||||
import { AccessGroupsPage } from "@/components/AccessGroups/AccessGroupsPage";
|
||||
import { ProjectsPage } from "@/components/Projects/ProjectsPage";
|
||||
import VectorStoreManagement from "@/components/vector_store_management";
|
||||
import ToolPolicies from "@/components/ToolPolicies";
|
||||
import ToolPoliciesView from "@/components/ToolPoliciesView";
|
||||
import SpendLogsTable from "@/components/view_logs";
|
||||
import ViewUserDashboard from "@/components/view_users";
|
||||
import { ThemeProvider } from "@/contexts/ThemeContext";
|
||||
@ -549,7 +549,7 @@ function CreateKeyPageContent() {
|
||||
) : page == "vector-stores" ? (
|
||||
<VectorStoreManagement accessToken={accessToken} userRole={userRole} userID={userID} />
|
||||
) : page == "tool-policies" ? (
|
||||
<ToolPolicies accessToken={accessToken} userRole={userRole} />
|
||||
<ToolPoliciesView accessToken={accessToken} userRole={userRole} />
|
||||
) : page == "guardrails-monitor" ? (
|
||||
<GuardrailsMonitorView accessToken={accessToken} />
|
||||
) : page == "new_usage" ? (
|
||||
|
||||
445
ui/litellm-dashboard/src/components/ToolDetail.tsx
Normal file
445
ui/litellm-dashboard/src/components/ToolDetail.tsx
Normal file
@ -0,0 +1,445 @@
|
||||
"use client";
|
||||
|
||||
import { ArrowLeftOutlined, HistoryOutlined, ToolOutlined } from "@ant-design/icons";
|
||||
import { useQuery, useQueryClient } from "@tanstack/react-query";
|
||||
import { Button, Select, Spin } from "antd";
|
||||
import React, { useCallback, useEffect, useMemo, useState } from "react";
|
||||
import TeamDropdown from "@/components/common_components/team_dropdown";
|
||||
import { LogViewer } from "@/components/GuardrailsMonitor/LogViewer";
|
||||
import type { LogEntry } from "@/components/GuardrailsMonitor/mockData";
|
||||
import { PolicySelect } from "@/components/ToolPolicies/PolicySelect";
|
||||
import {
|
||||
deleteToolPolicyOverride,
|
||||
fetchToolDetail,
|
||||
fetchToolPolicyOptions,
|
||||
getToolUsageLogs,
|
||||
keyListCall,
|
||||
teamListCall,
|
||||
updateToolPolicy,
|
||||
type ToolPolicyOption,
|
||||
type ToolPolicyOverrideRow,
|
||||
} from "@/components/networking";
|
||||
import type { Team } from "@/components/key_team_helpers/key_list";
|
||||
|
||||
interface ToolDetailProps {
|
||||
toolName: string;
|
||||
onBack: () => void;
|
||||
accessToken: string | null;
|
||||
}
|
||||
|
||||
interface KeyOption {
|
||||
token: string;
|
||||
key_alias?: string;
|
||||
}
|
||||
|
||||
const TOOL_DETAIL_QUERY_KEY = "tool-detail";
|
||||
|
||||
const LOGS_PAGE_SIZE = 50;
|
||||
|
||||
function getDefaultLogsDateRange(): { start: string; end: string } {
|
||||
const end = new Date();
|
||||
const start = new Date();
|
||||
start.setDate(start.getDate() - 90);
|
||||
const fmt = (d: Date) =>
|
||||
d.toISOString().slice(0, 19).replace("T", " ");
|
||||
return { start: fmt(start), end: fmt(end) };
|
||||
}
|
||||
|
||||
export function ToolDetail({ toolName, onBack, accessToken }: ToolDetailProps) {
|
||||
const queryClient = useQueryClient();
|
||||
const [overrideSaving, setOverrideSaving] = useState(false);
|
||||
const [inputPolicySaving, setInputPolicySaving] = useState(false);
|
||||
const [outputPolicySaving, setOutputPolicySaving] = useState(false);
|
||||
const [blockScope, setBlockScope] = useState<"team" | "key">("team");
|
||||
const [blockTeamId, setBlockTeamId] = useState<string | null>(null);
|
||||
const [blockKey, setBlockKey] = useState<KeyOption | null>(null);
|
||||
|
||||
const logsDateRange = useMemo(() => getDefaultLogsDateRange(), []);
|
||||
|
||||
const { data: detail, isLoading: detailLoading, error: detailError } = useQuery({
|
||||
queryKey: [TOOL_DETAIL_QUERY_KEY, toolName],
|
||||
queryFn: () => fetchToolDetail(accessToken!, toolName),
|
||||
enabled: !!accessToken && !!toolName,
|
||||
});
|
||||
|
||||
const { data: policyOptions } = useQuery({
|
||||
queryKey: ["tool-policy-options"],
|
||||
queryFn: () => fetchToolPolicyOptions(accessToken!),
|
||||
enabled: !!accessToken,
|
||||
staleTime: 60_000,
|
||||
});
|
||||
|
||||
const { data: teamsData } = useQuery({
|
||||
queryKey: ["teams-list-tool-detail"],
|
||||
queryFn: () => teamListCall(accessToken!, null, null),
|
||||
enabled: !!accessToken,
|
||||
});
|
||||
|
||||
const { data: keysData } = useQuery({
|
||||
queryKey: ["keys-list-tool-detail"],
|
||||
queryFn: () => keyListCall(accessToken!, null, null, null, null, null, 1, 100),
|
||||
enabled: !!accessToken,
|
||||
});
|
||||
|
||||
const { data: logsData, isLoading: logsLoading } = useQuery({
|
||||
queryKey: ["tool-usage-logs", toolName, logsDateRange.start, logsDateRange.end],
|
||||
queryFn: () =>
|
||||
getToolUsageLogs(accessToken!, toolName, {
|
||||
page: 1,
|
||||
pageSize: LOGS_PAGE_SIZE,
|
||||
startDate: logsDateRange.start,
|
||||
endDate: logsDateRange.end,
|
||||
}),
|
||||
enabled: !!accessToken && !!toolName,
|
||||
});
|
||||
|
||||
const logs: LogEntry[] = useMemo(() => {
|
||||
const list = logsData?.logs ?? [];
|
||||
return list.map((l) => ({
|
||||
id: l.id,
|
||||
timestamp: l.timestamp,
|
||||
action: "passed" as const,
|
||||
model: l.model ?? undefined,
|
||||
input_snippet: l.input_snippet ?? undefined,
|
||||
}));
|
||||
}, [logsData?.logs]);
|
||||
|
||||
const teams: Team[] = useMemo(() => {
|
||||
const arr = Array.isArray(teamsData) ? teamsData : teamsData?.data ?? [];
|
||||
return arr.map((t: { team_id?: string; id?: string; team_alias?: string }) => ({
|
||||
team_id: t.team_id ?? t.id ?? "",
|
||||
team_alias: t.team_alias ?? t.team_id ?? "",
|
||||
models: [],
|
||||
max_budget: null,
|
||||
budget_duration: null,
|
||||
tpm_limit: null,
|
||||
rpm_limit: null,
|
||||
organization_id: "",
|
||||
created_at: "",
|
||||
keys: [],
|
||||
members_with_roles: [],
|
||||
spend: 0,
|
||||
}));
|
||||
}, [teamsData]);
|
||||
|
||||
const keys: KeyOption[] = useMemo(() => {
|
||||
const keysRes = keysData?.keys ?? keysData?.data ?? [];
|
||||
return keysRes.map((k: { token?: string; api_key?: string; key_hash?: string; key_alias?: string }) => ({
|
||||
token: k.token ?? k.api_key ?? k.key_hash ?? "",
|
||||
key_alias: k.key_alias ?? (k.token ?? k.api_key ?? k.key_hash)?.toString?.()?.substring?.(0, 8),
|
||||
}));
|
||||
}, [keysData]);
|
||||
|
||||
const invalidateDetail = useCallback(() => {
|
||||
queryClient.invalidateQueries({ queryKey: [TOOL_DETAIL_QUERY_KEY, toolName] });
|
||||
}, [queryClient, toolName]);
|
||||
|
||||
const handleInputPolicyChange = useCallback(
|
||||
async (_name: string, newPolicy: string) => {
|
||||
if (!accessToken) return;
|
||||
setInputPolicySaving(true);
|
||||
try {
|
||||
await updateToolPolicy(accessToken, toolName, { input_policy: newPolicy });
|
||||
invalidateDetail();
|
||||
} catch (e: unknown) {
|
||||
alert(`Failed to update input policy: ${e instanceof Error ? e.message : String(e)}`);
|
||||
} finally {
|
||||
setInputPolicySaving(false);
|
||||
}
|
||||
},
|
||||
[accessToken, toolName, invalidateDetail]
|
||||
);
|
||||
|
||||
const handleOutputPolicyChange = useCallback(
|
||||
async (_name: string, newPolicy: string) => {
|
||||
if (!accessToken) return;
|
||||
setOutputPolicySaving(true);
|
||||
try {
|
||||
await updateToolPolicy(accessToken, toolName, { output_policy: newPolicy });
|
||||
invalidateDetail();
|
||||
} catch (e: unknown) {
|
||||
alert(`Failed to update output policy: ${e instanceof Error ? e.message : String(e)}`);
|
||||
} finally {
|
||||
setOutputPolicySaving(false);
|
||||
}
|
||||
},
|
||||
[accessToken, toolName, invalidateDetail]
|
||||
);
|
||||
|
||||
const handleAddOverride = useCallback(async () => {
|
||||
if (!accessToken || !toolName) return;
|
||||
const isTeam = blockScope === "team";
|
||||
if (isTeam && !blockTeamId) return;
|
||||
if (!isTeam && !blockKey?.token) return;
|
||||
setOverrideSaving(true);
|
||||
try {
|
||||
await updateToolPolicy(accessToken, toolName, { input_policy: "blocked" }, {
|
||||
team_id: isTeam ? blockTeamId : undefined,
|
||||
key_hash: !isTeam ? blockKey!.token : undefined,
|
||||
key_alias: !isTeam ? blockKey!.key_alias : undefined,
|
||||
});
|
||||
invalidateDetail();
|
||||
setBlockTeamId(null);
|
||||
setBlockKey(null);
|
||||
} catch (e: unknown) {
|
||||
alert(`Failed to add override: ${e instanceof Error ? e.message : String(e)}`);
|
||||
} finally {
|
||||
setOverrideSaving(false);
|
||||
}
|
||||
}, [accessToken, toolName, blockScope, blockTeamId, blockKey, invalidateDetail]);
|
||||
|
||||
const handleRemoveOverride = useCallback(
|
||||
async (override: ToolPolicyOverrideRow) => {
|
||||
if (!accessToken || !toolName) return;
|
||||
setOverrideSaving(true);
|
||||
try {
|
||||
await deleteToolPolicyOverride(accessToken, toolName, {
|
||||
team_id: override.team_id ?? undefined,
|
||||
key_hash: override.key_hash ?? undefined,
|
||||
});
|
||||
invalidateDetail();
|
||||
} catch (e: unknown) {
|
||||
alert(`Failed to remove override: ${e instanceof Error ? e.message : String(e)}`);
|
||||
} finally {
|
||||
setOverrideSaving(false);
|
||||
}
|
||||
},
|
||||
[accessToken, toolName, invalidateDetail]
|
||||
);
|
||||
|
||||
if (detailLoading && !detail) {
|
||||
return (
|
||||
<div className="flex items-center justify-center py-12">
|
||||
<Spin size="large" />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (detailError && !detail) {
|
||||
return (
|
||||
<div>
|
||||
<Button type="link" icon={<ArrowLeftOutlined />} onClick={onBack} className="pl-0 mb-4">
|
||||
Back to Tool Policies
|
||||
</Button>
|
||||
<p className="text-red-600">Failed to load tool details.</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!detail) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const { tool, overrides } = detail;
|
||||
|
||||
const inputDesc = policyOptions?.input_policies?.find(
|
||||
(p) => p.value === tool.input_policy
|
||||
)?.description;
|
||||
const outputDesc = policyOptions?.output_policies?.find(
|
||||
(p) => p.value === tool.output_policy
|
||||
)?.description;
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="mb-6">
|
||||
<Button
|
||||
type="link"
|
||||
icon={<ArrowLeftOutlined />}
|
||||
onClick={onBack}
|
||||
className="pl-0 mb-4"
|
||||
>
|
||||
Back to Tool Policies
|
||||
</Button>
|
||||
|
||||
<div className="flex items-start justify-between">
|
||||
<div>
|
||||
<div className="flex items-center gap-3 mb-1 flex-wrap">
|
||||
<ToolOutlined className="text-xl text-gray-400" />
|
||||
<h1 className="text-xl font-semibold text-gray-900 font-mono">{tool.tool_name}</h1>
|
||||
<span className="inline-flex items-center px-2.5 py-1 text-xs font-medium rounded-md bg-gray-100 text-gray-700 border border-gray-200">
|
||||
{tool.origin ?? "—"}
|
||||
</span>
|
||||
<span className="inline-flex items-center px-2.5 py-1 text-xs font-medium rounded-md bg-indigo-50 text-indigo-700 border border-indigo-200">
|
||||
{(tool.call_count ?? 0).toLocaleString()} calls
|
||||
</span>
|
||||
</div>
|
||||
<dl className="mt-3 flex flex-wrap gap-x-6 gap-y-1 text-sm text-gray-600">
|
||||
{tool.user_agent && (
|
||||
<div className="flex items-center gap-1.5">
|
||||
<dt className="font-medium text-gray-500 whitespace-nowrap">User Agent:</dt>
|
||||
<dd className="font-mono truncate max-w-[40ch]" title={tool.user_agent}>{tool.user_agent}</dd>
|
||||
</div>
|
||||
)}
|
||||
{tool.created_at && (
|
||||
<div className="flex items-center gap-1.5">
|
||||
<dt className="font-medium text-gray-500 whitespace-nowrap">First Discovered:</dt>
|
||||
<dd>{new Date(tool.created_at).toLocaleString()}</dd>
|
||||
</div>
|
||||
)}
|
||||
{tool.last_used_at && (
|
||||
<div className="flex items-center gap-1.5">
|
||||
<dt className="font-medium text-gray-500 whitespace-nowrap">Last Used:</dt>
|
||||
<dd>{new Date(tool.last_used_at).toLocaleString()}</dd>
|
||||
</div>
|
||||
)}
|
||||
</dl>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-6">
|
||||
{/* Two-panel policy layout */}
|
||||
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
|
||||
<section className="bg-white rounded-lg border border-gray-200 p-5 shadow-sm">
|
||||
<h2 className="text-sm font-semibold text-gray-700 mb-1">Input Policy</h2>
|
||||
<p className="text-xs text-gray-500 mb-3">
|
||||
{inputDesc ?? "Controls what data this tool is allowed to accept."}
|
||||
</p>
|
||||
<PolicySelect
|
||||
value={tool.input_policy}
|
||||
toolName={tool.tool_name}
|
||||
saving={inputPolicySaving}
|
||||
onChange={handleInputPolicyChange}
|
||||
policyType="input"
|
||||
size="middle"
|
||||
minWidth={140}
|
||||
stopPropagation={false}
|
||||
/>
|
||||
</section>
|
||||
|
||||
<section className="bg-white rounded-lg border border-gray-200 p-5 shadow-sm">
|
||||
<h2 className="text-sm font-semibold text-gray-700 mb-1">Output Policy</h2>
|
||||
<p className="text-xs text-gray-500 mb-3">
|
||||
{outputDesc ?? "Controls how this tool's output is trusted by downstream tools."}
|
||||
</p>
|
||||
<PolicySelect
|
||||
value={tool.output_policy}
|
||||
toolName={tool.tool_name}
|
||||
saving={outputPolicySaving}
|
||||
onChange={handleOutputPolicyChange}
|
||||
policyType="output"
|
||||
size="middle"
|
||||
minWidth={140}
|
||||
stopPropagation={false}
|
||||
/>
|
||||
</section>
|
||||
</div>
|
||||
|
||||
{overrides.length > 0 && (
|
||||
<section className="bg-white rounded-lg border border-gray-200 p-5 shadow-sm">
|
||||
<h2 className="text-sm font-semibold text-gray-700 mb-3">Blocked for team or key</h2>
|
||||
<ul className="border rounded-md divide-y divide-gray-100 bg-red-50/30">
|
||||
{overrides.map((ov) => (
|
||||
<li
|
||||
key={ov.override_id}
|
||||
className="flex items-center justify-between px-3 py-2.5 text-sm"
|
||||
>
|
||||
<span className="text-gray-700">
|
||||
{ov.team_id ? `Team: ${ov.team_id}` : ""}
|
||||
{ov.team_id && ov.key_hash ? " · " : ""}
|
||||
{ov.key_hash ? `Key: ${ov.key_alias || ov.key_hash.substring(0, 8)}` : ""}
|
||||
{!ov.team_id && !ov.key_hash ? "—" : ""}
|
||||
</span>
|
||||
<Button
|
||||
type="link"
|
||||
danger
|
||||
size="small"
|
||||
disabled={overrideSaving}
|
||||
onClick={() => handleRemoveOverride(ov)}
|
||||
>
|
||||
Remove
|
||||
</Button>
|
||||
</li>
|
||||
))}
|
||||
</ul>
|
||||
</section>
|
||||
)}
|
||||
|
||||
<section className="bg-white rounded-lg border border-gray-200 p-5 shadow-sm">
|
||||
<h2 className="text-sm font-semibold text-gray-700 mb-3">Block for team or key</h2>
|
||||
<div className="flex flex-col gap-4 max-w-md">
|
||||
<div>
|
||||
<span className="text-sm font-medium text-gray-700 block mb-2">Scope</span>
|
||||
<div className="flex items-center gap-6">
|
||||
<label className="flex items-center gap-2 cursor-pointer text-sm text-gray-700">
|
||||
<input
|
||||
type="radio"
|
||||
checked={blockScope === "team"}
|
||||
onChange={() => setBlockScope("team")}
|
||||
className="align-middle"
|
||||
/>
|
||||
Team
|
||||
</label>
|
||||
<label className="flex items-center gap-2 cursor-pointer text-sm text-gray-700">
|
||||
<input
|
||||
type="radio"
|
||||
checked={blockScope === "key"}
|
||||
onChange={() => setBlockScope("key")}
|
||||
className="align-middle"
|
||||
/>
|
||||
Key
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
<span className="text-sm font-medium text-gray-700 block mb-2">
|
||||
{blockScope === "team" ? "Team" : "Key"}
|
||||
</span>
|
||||
{blockScope === "team" ? (
|
||||
<TeamDropdown
|
||||
teams={teams}
|
||||
value={blockTeamId ?? undefined}
|
||||
onChange={(id) => setBlockTeamId(id || null)}
|
||||
/>
|
||||
) : (
|
||||
<Select
|
||||
placeholder="Select key"
|
||||
allowClear
|
||||
showSearch
|
||||
optionFilterProp="label"
|
||||
value={blockKey ? blockKey.token : undefined}
|
||||
onChange={(token) => {
|
||||
const k = keys.find((x) => x.token === token);
|
||||
setBlockKey(k ?? null);
|
||||
}}
|
||||
options={keys.map((k) => ({
|
||||
value: k.token,
|
||||
label: k.key_alias || k.token?.substring?.(0, 12) || k.token,
|
||||
}))}
|
||||
className="w-full"
|
||||
style={{ minWidth: 200 }}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
<Button
|
||||
type="primary"
|
||||
danger
|
||||
disabled={overrideSaving || (blockScope === "team" ? !blockTeamId : !blockKey?.token)}
|
||||
loading={overrideSaving}
|
||||
onClick={handleAddOverride}
|
||||
>
|
||||
Block for {blockScope}
|
||||
</Button>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<section className="bg-white rounded-lg border border-gray-200 p-5 shadow-sm">
|
||||
<h2 className="text-sm font-semibold text-gray-700 mb-3 flex items-center gap-2">
|
||||
<HistoryOutlined />
|
||||
Recent logs
|
||||
</h2>
|
||||
<LogViewer
|
||||
guardrailName={tool.tool_name}
|
||||
filterAction="passed"
|
||||
logs={logs}
|
||||
logsLoading={logsLoading}
|
||||
totalLogs={logsData?.total ?? 0}
|
||||
accessToken={accessToken}
|
||||
startDate={logsDateRange.start}
|
||||
endDate={logsDateRange.end}
|
||||
/>
|
||||
</section>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -1,24 +1,46 @@
|
||||
"use client";
|
||||
|
||||
import { Table, TableBody, TableCell, TableHead, TableHeaderCell, TableRow } from "@tremor/react";
|
||||
import { Select, Switch, Tooltip } from "antd";
|
||||
import React, { useCallback, useDeferredValue, useEffect, useState } from "react";
|
||||
import React, { useCallback, useDeferredValue, useEffect, useMemo, useState } from "react";
|
||||
import { Button, Switch, Tooltip } from "antd";
|
||||
import { Table, TableHead, TableHeaderCell, TableBody, TableRow, TableCell } from "@tremor/react";
|
||||
import { TimeCell } from "./view_logs/time_cell";
|
||||
import type { SortState } from "./common_components/TableHeaderSortDropdown/TableHeaderSortDropdown";
|
||||
import { TableHeaderSortDropdown } from "./common_components/TableHeaderSortDropdown/TableHeaderSortDropdown";
|
||||
import FilterComponent, { FilterOption } from "./molecules/filter";
|
||||
import { fetchToolsList, ToolRow, updateToolPolicy } from "./networking";
|
||||
import { TimeCell } from "./view_logs/time_cell";
|
||||
import { MetricCard } from "./GuardrailsMonitor/MetricCard";
|
||||
import { PolicySelect, INPUT_POLICY_OPTIONS, OUTPUT_POLICY_OPTIONS } from "./ToolPolicies/PolicySelect";
|
||||
import {
|
||||
fetchToolsList,
|
||||
updateToolPolicy,
|
||||
ToolRow,
|
||||
} from "./networking";
|
||||
|
||||
const POLICY_OPTIONS = [
|
||||
{ value: "trusted", label: "trusted", color: "#065f46", bg: "#d1fae5", border: "#6ee7b7" },
|
||||
{ value: "blocked", label: "blocked", color: "#991b1b", bg: "#fee2e2", border: "#fca5a5" },
|
||||
] as const;
|
||||
function getUTCDateKey(date: Date): string {
|
||||
return `${date.getUTCFullYear()}-${String(date.getUTCMonth() + 1).padStart(2, "0")}-${String(date.getUTCDate()).padStart(2, "0")}`;
|
||||
}
|
||||
|
||||
type PolicyValue = "trusted" | "blocked";
|
||||
function isCreatedInUTCDay(createdAt: string | undefined, utcDateKey: string): boolean {
|
||||
if (!createdAt) return false;
|
||||
try {
|
||||
const d = new Date(createdAt);
|
||||
return getUTCDateKey(d) === utcDateKey;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
const policyStyle = (p: string) => POLICY_OPTIONS.find((o) => o.value === p) ?? POLICY_OPTIONS[1];
|
||||
function countToolsInUTCDay(tools: ToolRow[], utcDateKey: string): number {
|
||||
return tools.filter((t) => isCreatedInUTCDay(t.created_at, utcDateKey)).length;
|
||||
}
|
||||
|
||||
type SortField = "tool_name" | "call_policy" | "team_id" | "key_alias" | "created_at" | "call_count";
|
||||
function getTrendSubtitle(newToday: number, newYesterday: number): string | undefined {
|
||||
const diff = newToday - newYesterday;
|
||||
if (diff === 0) return undefined;
|
||||
if (diff > 0) return `+${diff} since yesterday`;
|
||||
return `${diff} since yesterday`;
|
||||
}
|
||||
|
||||
type SortField = "tool_name" | "input_policy" | "output_policy" | "team_id" | "key_alias" | "created_at" | "call_count";
|
||||
|
||||
interface FilterValues {
|
||||
[key: string]: string;
|
||||
@ -27,65 +49,16 @@ interface FilterValues {
|
||||
interface ToolPoliciesProps {
|
||||
accessToken: string | null;
|
||||
userRole?: string;
|
||||
onSelectTool?: (toolName: string) => void;
|
||||
}
|
||||
|
||||
const PolicySelect: React.FC<{
|
||||
value: string;
|
||||
toolName: string;
|
||||
saving: boolean;
|
||||
onChange: (toolName: string, policy: string) => void;
|
||||
}> = ({ value, toolName, saving, onChange }) => {
|
||||
const style = policyStyle(value);
|
||||
return (
|
||||
<Select
|
||||
size="small"
|
||||
value={value}
|
||||
disabled={saving}
|
||||
loading={saving}
|
||||
onChange={(v) => onChange(toolName, v)}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
style={{
|
||||
minWidth: 110,
|
||||
fontWeight: 500,
|
||||
}}
|
||||
popupMatchSelectWidth={false}
|
||||
options={POLICY_OPTIONS.map((o) => ({
|
||||
value: o.value,
|
||||
label: (
|
||||
<span
|
||||
style={{
|
||||
display: "inline-flex",
|
||||
alignItems: "center",
|
||||
gap: 6,
|
||||
fontSize: 12,
|
||||
fontWeight: 500,
|
||||
color: o.color,
|
||||
}}
|
||||
>
|
||||
<span
|
||||
style={{
|
||||
width: 8,
|
||||
height: 8,
|
||||
borderRadius: "50%",
|
||||
backgroundColor: o.color,
|
||||
display: "inline-block",
|
||||
flexShrink: 0,
|
||||
}}
|
||||
/>
|
||||
{o.label}
|
||||
</span>
|
||||
),
|
||||
}))}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
|
||||
export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken, onSelectTool }) => {
|
||||
const [tools, setTools] = useState<ToolRow[]>([]);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [isFetching, setIsFetching] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [saving, setSaving] = useState<string | null>(null);
|
||||
const [savingInput, setSavingInput] = useState<string | null>(null);
|
||||
const [savingOutput, setSavingOutput] = useState<string | null>(null);
|
||||
|
||||
const [searchTerm, setSearchTerm] = useState("");
|
||||
const [sortField, setSortField] = useState<SortField>("created_at");
|
||||
@ -123,16 +96,29 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
|
||||
return () => clearInterval(id);
|
||||
}, [isLiveTail, load]);
|
||||
|
||||
const handlePolicyChange = async (toolName: string, newPolicy: string) => {
|
||||
const handleInputPolicyChange = async (toolName: string, newPolicy: string) => {
|
||||
if (!accessToken) return;
|
||||
setSaving(toolName);
|
||||
setSavingInput(toolName);
|
||||
try {
|
||||
await updateToolPolicy(accessToken, toolName, newPolicy);
|
||||
setTools((prev) => prev.map((t) => (t.tool_name === toolName ? { ...t, call_policy: newPolicy } : t)));
|
||||
await updateToolPolicy(accessToken, toolName, { input_policy: newPolicy });
|
||||
setTools((prev) => prev.map((t) => (t.tool_name === toolName ? { ...t, input_policy: newPolicy } : t)));
|
||||
} catch (e: any) {
|
||||
alert(`Failed to update policy: ${e.message}`);
|
||||
alert(`Failed to update input policy: ${e.message}`);
|
||||
} finally {
|
||||
setSaving(null);
|
||||
setSavingInput(null);
|
||||
}
|
||||
};
|
||||
|
||||
const handleOutputPolicyChange = async (toolName: string, newPolicy: string) => {
|
||||
if (!accessToken) return;
|
||||
setSavingOutput(toolName);
|
||||
try {
|
||||
await updateToolPolicy(accessToken, toolName, { output_policy: newPolicy });
|
||||
setTools((prev) => prev.map((t) => (t.tool_name === toolName ? { ...t, output_policy: newPolicy } : t)));
|
||||
} catch (e: any) {
|
||||
alert(`Failed to update output policy: ${e.message}`);
|
||||
} finally {
|
||||
setSavingOutput(null);
|
||||
}
|
||||
};
|
||||
|
||||
@ -157,7 +143,6 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
|
||||
setCurrentPage(1);
|
||||
};
|
||||
|
||||
// Build unique team/key options from loaded data
|
||||
const teamOptions = Array.from(new Set(tools.map((t) => t.team_id).filter(Boolean))).map((v) => ({
|
||||
label: v as string,
|
||||
value: v as string,
|
||||
@ -169,9 +154,14 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
|
||||
|
||||
const filterOptions: FilterOption[] = [
|
||||
{
|
||||
name: "Policy",
|
||||
label: "Policy",
|
||||
options: POLICY_OPTIONS.map((o) => ({ label: o.label, value: o.value })),
|
||||
name: "Input Policy",
|
||||
label: "Input Policy",
|
||||
options: INPUT_POLICY_OPTIONS.map((o) => ({ label: o.label, value: o.value })),
|
||||
},
|
||||
{
|
||||
name: "Output Policy",
|
||||
label: "Output Policy",
|
||||
options: OUTPUT_POLICY_OPTIONS.map((o) => ({ label: o.label, value: o.value })),
|
||||
},
|
||||
{
|
||||
name: "Team Name",
|
||||
@ -185,6 +175,39 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
|
||||
},
|
||||
];
|
||||
|
||||
const { newToday, newYesterday, trendSubtitle, totalTools, blockedCount, activeTeamsCount, needsReviewTools } =
|
||||
useMemo(() => {
|
||||
const now = new Date();
|
||||
const todayKey = getUTCDateKey(now);
|
||||
const yesterday = new Date(now);
|
||||
yesterday.setUTCDate(yesterday.getUTCDate() - 1);
|
||||
const yesterdayKey = getUTCDateKey(yesterday);
|
||||
|
||||
const newToday = countToolsInUTCDay(tools, todayKey);
|
||||
const newYesterday = countToolsInUTCDay(tools, yesterdayKey);
|
||||
const trendSubtitle = getTrendSubtitle(newToday, newYesterday);
|
||||
|
||||
const totalTools = tools.length;
|
||||
const blockedCount = tools.filter((t) => t.input_policy === "blocked").length;
|
||||
const activeTeamsCount = new Set(tools.map((t) => t.team_id).filter(Boolean)).size;
|
||||
|
||||
const needsReviewTools = tools.filter(
|
||||
(t) =>
|
||||
isCreatedInUTCDay(t.created_at, todayKey) &&
|
||||
t.input_policy === "untrusted"
|
||||
);
|
||||
|
||||
return {
|
||||
newToday,
|
||||
newYesterday,
|
||||
trendSubtitle,
|
||||
totalTools,
|
||||
blockedCount,
|
||||
activeTeamsCount,
|
||||
needsReviewTools,
|
||||
};
|
||||
}, [tools]);
|
||||
|
||||
const SortHeader = ({ label, field }: { label: string; field: SortField }) => (
|
||||
<div className="flex items-center gap-1">
|
||||
<span>{label}</span>
|
||||
@ -203,10 +226,12 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
|
||||
(t.team_id ?? "").toLowerCase().includes(q) ||
|
||||
(t.key_alias ?? "").toLowerCase().includes(q) ||
|
||||
(t.key_hash ?? "").toLowerCase().includes(q) ||
|
||||
t.call_policy.toLowerCase().includes(q);
|
||||
t.input_policy.toLowerCase().includes(q) ||
|
||||
t.output_policy.toLowerCase().includes(q);
|
||||
if (!matchesSearch) return false;
|
||||
}
|
||||
if (activeFilters["Policy"] && t.call_policy !== activeFilters["Policy"]) return false;
|
||||
if (activeFilters["Input Policy"] && t.input_policy !== activeFilters["Input Policy"]) return false;
|
||||
if (activeFilters["Output Policy"] && t.output_policy !== activeFilters["Output Policy"]) return false;
|
||||
if (activeFilters["Team Name"] && t.team_id !== activeFilters["Team Name"]) return false;
|
||||
if (activeFilters["Key Name"] && t.key_alias !== activeFilters["Key Name"]) return false;
|
||||
return true;
|
||||
@ -223,11 +248,74 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
|
||||
const totalPages = Math.max(1, Math.ceil(sorted.length / pageSize));
|
||||
const paginated = sorted.slice((currentPage - 1) * pageSize, currentPage * pageSize);
|
||||
|
||||
const scrollToToolRow = (toolId: string) => {
|
||||
const idx = sorted.findIndex((t) => t.tool_id === toolId);
|
||||
if (idx >= 0) {
|
||||
const page = Math.floor(idx / pageSize) + 1;
|
||||
if (page !== currentPage) setCurrentPage(page);
|
||||
requestAnimationFrame(() => {
|
||||
setTimeout(() => {
|
||||
document.getElementById(`tool-row-${toolId}`)?.scrollIntoView({ behavior: "smooth", block: "center" });
|
||||
}, 100);
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="p-6 w-full">
|
||||
<div className="w-full">
|
||||
<h1 className="text-2xl font-semibold text-gray-900 mb-6">Tool Policies</h1>
|
||||
|
||||
<div className="grid grid-cols-2 lg:grid-cols-4 gap-4 mb-6">
|
||||
<MetricCard
|
||||
label="New Today"
|
||||
value={newToday}
|
||||
valueColor="text-green-600"
|
||||
subtitle={trendSubtitle}
|
||||
icon={
|
||||
<svg className="w-4 h-4 text-green-500" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M13 7h8m0 0v8m0-8l-8 8-4-4-6 6" />
|
||||
</svg>
|
||||
}
|
||||
/>
|
||||
<MetricCard label="Total Tools Discovered" value={totalTools} />
|
||||
<MetricCard
|
||||
label="Blocked Tools"
|
||||
value={blockedCount}
|
||||
valueColor={blockedCount > 0 ? "text-red-600" : undefined}
|
||||
/>
|
||||
<MetricCard label="Active Teams" value={activeTeamsCount > 0 ? activeTeamsCount : "—"} />
|
||||
</div>
|
||||
|
||||
{needsReviewTools.length > 0 && (
|
||||
<div className="bg-amber-50 border border-amber-200 rounded-lg p-4 mb-6">
|
||||
<h2 className="text-sm font-semibold text-amber-900 mb-1">Needs Review</h2>
|
||||
<p className="text-sm text-amber-800 mb-3">
|
||||
{needsReviewTools.length} new tool{needsReviewTools.length !== 1 ? "s" : ""} discovered that require
|
||||
policy decisions.
|
||||
</p>
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{needsReviewTools.map((t) => (
|
||||
<span
|
||||
key={t.tool_id}
|
||||
className="inline-flex items-center gap-2 px-3 py-1.5 bg-white border border-amber-200 rounded-md text-sm"
|
||||
>
|
||||
<span className="font-mono text-amber-900 truncate max-w-[200px]" title={t.tool_name}>
|
||||
{t.tool_name}
|
||||
</span>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => scrollToToolRow(t.tool_id)}
|
||||
className="text-amber-700 hover:text-amber-900 font-medium text-xs whitespace-nowrap"
|
||||
>
|
||||
Review
|
||||
</button>
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="bg-white rounded-lg shadow w-full max-w-full box-border">
|
||||
{/* Toolbar */}
|
||||
<div className="border-b px-6 py-4 w-full max-w-full box-border">
|
||||
<div className="flex flex-col md:flex-row items-start md:items-center justify-between space-y-4 md:space-y-0 w-full max-w-full box-border">
|
||||
<div className="flex flex-wrap items-center gap-3">
|
||||
@ -311,7 +399,6 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Filter row */}
|
||||
<div className="mt-3">
|
||||
<FilterComponent
|
||||
options={filterOptions}
|
||||
@ -322,7 +409,6 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Auto-refresh banner */}
|
||||
{isLiveTail && (
|
||||
<div className="bg-green-50 border-b border-green-100 px-6 py-2 flex items-center justify-between">
|
||||
<span className="text-sm text-green-700">Auto-refreshing every 15 seconds</span>
|
||||
@ -336,7 +422,6 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
|
||||
<div className="mx-6 mt-4 p-3 bg-red-50 border border-red-200 rounded text-sm text-red-700">{error}</div>
|
||||
)}
|
||||
|
||||
{/* Table */}
|
||||
<Table className="[&_td]:py-0.5 [&_th]:py-1 w-full">
|
||||
<TableHead>
|
||||
<TableRow>
|
||||
@ -347,7 +432,10 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
|
||||
<SortHeader label="Tool Name" field="tool_name" />
|
||||
</TableHeaderCell>
|
||||
<TableHeaderCell className="py-1 h-8">
|
||||
<SortHeader label="Policy" field="call_policy" />
|
||||
<SortHeader label="Input Policy" field="input_policy" />
|
||||
</TableHeaderCell>
|
||||
<TableHeaderCell className="py-1 h-8">
|
||||
<SortHeader label="Output Policy" field="output_policy" />
|
||||
</TableHeaderCell>
|
||||
<TableHeaderCell className="py-1 h-8">
|
||||
<SortHeader label="# Calls" field="call_count" />
|
||||
@ -359,45 +447,61 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
|
||||
<TableHeaderCell className="py-1 h-8">
|
||||
<SortHeader label="Key Name" field="key_alias" />
|
||||
</TableHeaderCell>
|
||||
<TableHeaderCell className="py-1 h-8">Origin</TableHeaderCell>
|
||||
<TableHeaderCell className="py-1 h-8">User Agent</TableHeaderCell>
|
||||
</TableRow>
|
||||
</TableHead>
|
||||
<TableBody>
|
||||
{loading ? (
|
||||
<TableRow>
|
||||
<TableCell colSpan={8} className="h-8 text-center text-gray-500">
|
||||
<TableCell colSpan={9} className="h-8 text-center text-gray-500">
|
||||
Loading tools…
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
) : paginated.length === 0 ? (
|
||||
<TableRow>
|
||||
<TableCell colSpan={8} className="h-8 text-center text-gray-500">
|
||||
<TableCell colSpan={9} className="h-8 text-center text-gray-500">
|
||||
No tools discovered yet. Make a chat completion that returns tool_calls to start auto-discovery.
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
) : (
|
||||
paginated.map((tool) => (
|
||||
<TableRow key={tool.tool_id} className="h-8 hover:bg-gray-50">
|
||||
<TableRow key={tool.tool_id} id={`tool-row-${tool.tool_id}`} className="h-8 hover:bg-gray-50">
|
||||
<TableCell className="py-0.5 max-h-8 overflow-hidden whitespace-nowrap">
|
||||
<TimeCell utcTime={tool.created_at ?? ""} />
|
||||
</TableCell>
|
||||
<TableCell className="py-0.5 max-h-8 overflow-hidden">
|
||||
<Tooltip title={tool.tool_name}>
|
||||
<span className="font-mono text-xs max-w-[20ch] truncate block font-medium">
|
||||
{tool.tool_name}
|
||||
</span>
|
||||
</Tooltip>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => onSelectTool?.(tool.tool_name)}
|
||||
className="text-left w-full font-mono text-xs max-w-[20ch] truncate block font-medium text-blue-600 hover:text-blue-800 hover:underline focus:outline-none focus:ring-0"
|
||||
>
|
||||
<Tooltip title={onSelectTool ? "Click to view details and block for team/key" : tool.tool_name}>
|
||||
<span>{tool.tool_name}</span>
|
||||
</Tooltip>
|
||||
</button>
|
||||
</TableCell>
|
||||
<TableCell className="py-0.5 max-h-8">
|
||||
<PolicySelect
|
||||
value={tool.call_policy}
|
||||
value={tool.input_policy}
|
||||
toolName={tool.tool_name}
|
||||
saving={saving === tool.tool_name}
|
||||
onChange={handlePolicyChange}
|
||||
saving={savingInput === tool.tool_name}
|
||||
onChange={handleInputPolicyChange}
|
||||
policyType="input"
|
||||
/>
|
||||
</TableCell>
|
||||
<TableCell className="py-0.5 max-h-8 text-right tabular-nums text-sm font-mono text-gray-700">
|
||||
{(tool.call_count ?? 0).toLocaleString()}
|
||||
<TableCell className="py-0.5 max-h-8">
|
||||
<PolicySelect
|
||||
value={tool.output_policy}
|
||||
toolName={tool.tool_name}
|
||||
saving={savingOutput === tool.tool_name}
|
||||
onChange={handleOutputPolicyChange}
|
||||
policyType="output"
|
||||
/>
|
||||
</TableCell>
|
||||
<TableCell className="py-0.5 max-h-8">
|
||||
<div className="flex items-center justify-end h-8 tabular-nums text-sm font-mono text-gray-700">
|
||||
{(tool.call_count ?? 0).toLocaleString()}
|
||||
</div>
|
||||
</TableCell>
|
||||
<TableCell className="py-0.5 max-h-8 overflow-hidden whitespace-nowrap">
|
||||
<Tooltip title={tool.team_id ?? "-"}>
|
||||
@ -417,8 +521,8 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
|
||||
</Tooltip>
|
||||
</TableCell>
|
||||
<TableCell className="py-0.5 max-h-8 overflow-hidden whitespace-nowrap">
|
||||
<Tooltip title={tool.origin ?? "-"}>
|
||||
<span className="max-w-[15ch] truncate block">{tool.origin ?? "-"}</span>
|
||||
<Tooltip title={tool.user_agent ?? "-"}>
|
||||
<span className="font-mono max-w-[20ch] truncate block text-xs text-gray-500">{tool.user_agent ?? "-"}</span>
|
||||
</Tooltip>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
@ -427,7 +531,6 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
{/* Bottom pagination (only when > 1 page) */}
|
||||
{totalPages > 1 && (
|
||||
<div className="border-t px-6 py-3 flex items-center justify-between text-sm text-gray-600">
|
||||
<span>
|
||||
@ -453,6 +556,7 @@ export const ToolPolicies: React.FC<ToolPoliciesProps> = ({ accessToken }) => {
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
@ -0,0 +1,92 @@
|
||||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import { Select } from "antd";
|
||||
|
||||
export const INPUT_POLICY_OPTIONS = [
|
||||
{ value: "untrusted", label: "untrusted", color: "#92400e", bg: "#fef3c7", border: "#fcd34d" },
|
||||
{ value: "trusted", label: "trusted", color: "#065f46", bg: "#d1fae5", border: "#6ee7b7" },
|
||||
{ value: "blocked", label: "blocked", color: "#991b1b", bg: "#fee2e2", border: "#fca5a5" },
|
||||
] as const;
|
||||
|
||||
export const OUTPUT_POLICY_OPTIONS = [
|
||||
{ value: "untrusted", label: "untrusted", color: "#92400e", bg: "#fef3c7", border: "#fcd34d" },
|
||||
{ value: "trusted", label: "trusted", color: "#065f46", bg: "#d1fae5", border: "#6ee7b7" },
|
||||
] as const;
|
||||
|
||||
export const POLICY_OPTIONS = INPUT_POLICY_OPTIONS;
|
||||
|
||||
export const policyStyle = (p: string) =>
|
||||
INPUT_POLICY_OPTIONS.find((o) => o.value === p) ?? INPUT_POLICY_OPTIONS[0];
|
||||
|
||||
export interface PolicySelectProps {
|
||||
value: string;
|
||||
toolName: string;
|
||||
saving: boolean;
|
||||
onChange: (toolName: string, policy: string) => void;
|
||||
policyType?: "input" | "output";
|
||||
size?: "small" | "middle";
|
||||
minWidth?: number;
|
||||
stopPropagation?: boolean;
|
||||
}
|
||||
|
||||
export const PolicySelect: React.FC<PolicySelectProps> = ({
|
||||
value,
|
||||
toolName,
|
||||
saving,
|
||||
onChange,
|
||||
policyType = "input",
|
||||
size = "small",
|
||||
minWidth = 110,
|
||||
stopPropagation = true,
|
||||
}) => {
|
||||
const options = policyType === "output" ? OUTPUT_POLICY_OPTIONS : INPUT_POLICY_OPTIONS;
|
||||
const style = policyStyle(value);
|
||||
return (
|
||||
<Select
|
||||
size={size}
|
||||
value={value}
|
||||
disabled={saving}
|
||||
loading={saving}
|
||||
onChange={(v) => onChange(toolName, v)}
|
||||
onClick={(e) => stopPropagation && e.stopPropagation()}
|
||||
style={{
|
||||
minWidth,
|
||||
fontWeight: 500,
|
||||
backgroundColor: style.bg,
|
||||
borderColor: style.border,
|
||||
color: style.color,
|
||||
borderRadius: 999,
|
||||
fontSize: size === "small" ? 11 : 12,
|
||||
}}
|
||||
popupMatchSelectWidth={false}
|
||||
options={options.map((o) => ({
|
||||
value: o.value,
|
||||
label: (
|
||||
<span
|
||||
style={{
|
||||
display: "inline-flex",
|
||||
alignItems: "center",
|
||||
gap: 6,
|
||||
fontSize: 12,
|
||||
fontWeight: 500,
|
||||
color: o.color,
|
||||
}}
|
||||
>
|
||||
<span
|
||||
style={{
|
||||
width: 8,
|
||||
height: 8,
|
||||
borderRadius: "50%",
|
||||
backgroundColor: o.color,
|
||||
display: "inline-block",
|
||||
flexShrink: 0,
|
||||
}}
|
||||
/>
|
||||
{o.label}
|
||||
</span>
|
||||
),
|
||||
}))}
|
||||
/>
|
||||
);
|
||||
};
|
||||
44
ui/litellm-dashboard/src/components/ToolPoliciesView.tsx
Normal file
44
ui/litellm-dashboard/src/components/ToolPoliciesView.tsx
Normal file
@ -0,0 +1,44 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState } from "react";
|
||||
import { ToolDetail } from "@/components/ToolDetail";
|
||||
import { ToolPolicies } from "@/components/ToolPolicies";
|
||||
|
||||
type View =
|
||||
| { type: "overview" }
|
||||
| { type: "detail"; toolName: string };
|
||||
|
||||
interface ToolPoliciesViewProps {
|
||||
accessToken: string | null;
|
||||
userRole?: string;
|
||||
}
|
||||
|
||||
export default function ToolPoliciesView({ accessToken, userRole }: ToolPoliciesViewProps) {
|
||||
const [view, setView] = useState<View>({ type: "overview" });
|
||||
|
||||
const handleSelectTool = (toolName: string) => {
|
||||
setView({ type: "detail", toolName });
|
||||
};
|
||||
|
||||
const handleBack = () => {
|
||||
setView({ type: "overview" });
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="p-6 w-full min-w-0 flex-1">
|
||||
{view.type === "detail" ? (
|
||||
<ToolDetail
|
||||
toolName={view.toolName}
|
||||
onBack={handleBack}
|
||||
accessToken={accessToken}
|
||||
/>
|
||||
) : (
|
||||
<ToolPolicies
|
||||
accessToken={accessToken}
|
||||
userRole={userRole}
|
||||
onSelectTool={handleSelectTool}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@ -10127,7 +10127,8 @@ export interface ToolRow {
|
||||
tool_id: string;
|
||||
tool_name: string;
|
||||
origin?: string;
|
||||
call_policy: string;
|
||||
input_policy: string;
|
||||
output_policy: string;
|
||||
call_count?: number;
|
||||
assignments?: Record<string, any>;
|
||||
key_hash?: string;
|
||||
@ -10139,6 +10140,37 @@ export interface ToolRow {
|
||||
updated_by?: string;
|
||||
}
|
||||
|
||||
export interface ToolPolicyOption {
|
||||
value: string;
|
||||
label: string;
|
||||
description: string;
|
||||
}
|
||||
|
||||
export interface ToolPolicyOptionsResponse {
|
||||
input_policies: ToolPolicyOption[];
|
||||
output_policies: ToolPolicyOption[];
|
||||
}
|
||||
|
||||
export const fetchToolPolicyOptions = async (
|
||||
accessToken: string
|
||||
): Promise<ToolPolicyOptionsResponse> => {
|
||||
const url = proxyBaseUrl
|
||||
? `${proxyBaseUrl}/v1/tool/policy/options`
|
||||
: `/v1/tool/policy/options`;
|
||||
const response = await fetch(url, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
if (!response.ok) {
|
||||
const errorData = await response.text();
|
||||
throw new Error(errorData);
|
||||
}
|
||||
return response.json();
|
||||
};
|
||||
|
||||
export const fetchToolsList = async (accessToken: string): Promise<ToolRow[]> => {
|
||||
const url = proxyBaseUrl ? `${proxyBaseUrl}/v1/tool/list` : `/v1/tool/list`;
|
||||
const response = await fetch(url, {
|
||||
@ -10156,19 +10188,137 @@ export const fetchToolsList = async (accessToken: string): Promise<ToolRow[]> =>
|
||||
return data.tools ?? [];
|
||||
};
|
||||
|
||||
export const updateToolPolicy = async (
|
||||
export interface ToolPolicyOverrideRow {
|
||||
override_id: string;
|
||||
tool_name: string;
|
||||
team_id?: string | null;
|
||||
key_hash?: string | null;
|
||||
input_policy: string;
|
||||
key_alias?: string | null;
|
||||
created_at?: string;
|
||||
updated_at?: string;
|
||||
}
|
||||
|
||||
export interface ToolDetailResponse {
|
||||
tool: ToolRow;
|
||||
overrides: ToolPolicyOverrideRow[];
|
||||
}
|
||||
|
||||
export interface ToolUsageLogEntry {
|
||||
id: string;
|
||||
timestamp: string;
|
||||
model?: string | null;
|
||||
spend?: number | null;
|
||||
total_tokens?: number | null;
|
||||
input_snippet?: string | null;
|
||||
}
|
||||
|
||||
export interface ToolUsageLogsResponse {
|
||||
logs: ToolUsageLogEntry[];
|
||||
total: number;
|
||||
page: number;
|
||||
page_size: number;
|
||||
}
|
||||
|
||||
export const getToolUsageLogs = async (
|
||||
accessToken: string,
|
||||
toolName: string,
|
||||
callPolicy: string
|
||||
): Promise<ToolRow> => {
|
||||
const url = proxyBaseUrl ? `${proxyBaseUrl}/v1/tool/policy` : `/v1/tool/policy`;
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
options: { page?: number; pageSize?: number; startDate?: string; endDate?: string }
|
||||
): Promise<ToolUsageLogsResponse> => {
|
||||
const encoded = encodeURIComponent(toolName);
|
||||
const url = proxyBaseUrl
|
||||
? `${proxyBaseUrl}/v1/tool/${encoded}/logs`
|
||||
: `/v1/tool/${encoded}/logs`;
|
||||
const params = new URLSearchParams();
|
||||
if (options.page != null) params.append("page", String(options.page));
|
||||
if (options.pageSize != null) params.append("page_size", String(options.pageSize));
|
||||
if (options.startDate) params.append("start_date", options.startDate);
|
||||
if (options.endDate) params.append("end_date", options.endDate);
|
||||
const fullUrl = params.toString() ? `${url}?${params.toString()}` : url;
|
||||
const response = await fetch(fullUrl, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
});
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json().catch(() => ({}));
|
||||
throw new Error(deriveErrorMessage(errorData));
|
||||
}
|
||||
return response.json();
|
||||
};
|
||||
|
||||
export const fetchToolDetail = async (
|
||||
accessToken: string,
|
||||
toolName: string
|
||||
): Promise<ToolDetailResponse> => {
|
||||
const encoded = encodeURIComponent(toolName);
|
||||
const url = proxyBaseUrl
|
||||
? `${proxyBaseUrl}/v1/tool/${encoded}/detail`
|
||||
: `/v1/tool/${encoded}/detail`;
|
||||
const response = await fetch(url, {
|
||||
method: "GET",
|
||||
headers: {
|
||||
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ tool_name: toolName, call_policy: callPolicy }),
|
||||
});
|
||||
if (!response.ok) {
|
||||
const errorData = await response.text();
|
||||
throw new Error(errorData);
|
||||
}
|
||||
return response.json();
|
||||
};
|
||||
|
||||
export const updateToolPolicy = async (
|
||||
accessToken: string,
|
||||
toolName: string,
|
||||
policies: { input_policy?: string; output_policy?: string },
|
||||
options?: { team_id?: string | null; key_hash?: string | null; key_alias?: string | null }
|
||||
): Promise<ToolRow> => {
|
||||
const url = proxyBaseUrl ? `${proxyBaseUrl}/v1/tool/policy` : `/v1/tool/policy`;
|
||||
const body: Record<string, string | undefined | null> = {
|
||||
tool_name: toolName,
|
||||
};
|
||||
if (policies.input_policy != null) body.input_policy = policies.input_policy;
|
||||
if (policies.output_policy != null) body.output_policy = policies.output_policy;
|
||||
if (options?.team_id != null) body.team_id = options.team_id || undefined;
|
||||
if (options?.key_hash != null) body.key_hash = options.key_hash || undefined;
|
||||
if (options?.key_alias != null) body.key_alias = options.key_alias || undefined;
|
||||
const response = await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(body),
|
||||
});
|
||||
if (!response.ok) {
|
||||
const errorData = await response.text();
|
||||
throw new Error(errorData);
|
||||
}
|
||||
return response.json();
|
||||
};
|
||||
|
||||
export const deleteToolPolicyOverride = async (
|
||||
accessToken: string,
|
||||
toolName: string,
|
||||
params: { team_id?: string | null; key_hash?: string | null }
|
||||
): Promise<{ deleted: boolean; tool_name: string }> => {
|
||||
const encoded = encodeURIComponent(toolName);
|
||||
const q = new URLSearchParams();
|
||||
if (params.team_id != null && params.team_id !== "") q.set("team_id", params.team_id);
|
||||
if (params.key_hash != null && params.key_hash !== "") q.set("key_hash", params.key_hash);
|
||||
const query = q.toString();
|
||||
const url = proxyBaseUrl
|
||||
? `${proxyBaseUrl}/v1/tool/${encoded}/overrides${query ? `?${query}` : ""}`
|
||||
: `/v1/tool/${encoded}/overrides${query ? `?${query}` : ""}`;
|
||||
const response = await fetch(url, {
|
||||
method: "DELETE",
|
||||
headers: {
|
||||
[globalLitellmHeaderName]: `Bearer ${accessToken}`,
|
||||
},
|
||||
});
|
||||
if (!response.ok) {
|
||||
const errorData = await response.text();
|
||||
|
||||
Loading…
Reference in New Issue
Block a user