[Staging] - Ishaan March 17th (#23903)

* feat(xai): add grok-4.20 beta 2 models with pricing (#23900)

Add three grok-4.20 beta 2 model variants from xAI:
- grok-4.20-multi-agent-beta-0309 (reasoning + multi-agent)
- grok-4.20-beta-0309-reasoning (reasoning)
- grok-4.20-beta-0309-non-reasoning

Pricing (from https://docs.x.ai/docs/models):
- Input: $2.00/1M tokens ($0.20/1M cached)
- Output: $6.00/1M tokens
- Context: 2M tokens

All variants support vision, function calling, tool choice, and web search.
Closes LIT-2171
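
As a rough illustration of the pricing above, the sketch below estimates the cost of one request. The helper name `estimate_cost` is hypothetical, and it assumes cached tokens are reported as a subset of the prompt tokens; neither is confirmed by this commit.

```python
# Prices in USD per 1M tokens, copied from the pricing list above.
INPUT_PER_M = 2.00
CACHED_INPUT_PER_M = 0.20
OUTPUT_PER_M = 6.00


def estimate_cost(prompt_tokens: int, completion_tokens: int, cached_tokens: int = 0) -> float:
    """Hypothetical helper: estimate the USD cost of a single request.

    Assumption: cached_tokens is the cached subset of prompt_tokens.
    """
    if cached_tokens > prompt_tokens:
        raise ValueError("cached_tokens cannot exceed prompt_tokens")
    uncached_tokens = prompt_tokens - cached_tokens
    total = (
        uncached_tokens * INPUT_PER_M
        + cached_tokens * CACHED_INPUT_PER_M
        + completion_tokens * OUTPUT_PER_M
    )
    return total / 1_000_000
```

For example, `estimate_cost(1_000_000, 1_000_000)` combines the full input price with the full output price for one million tokens each.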

* docs: add Quick Install section for litellm --setup wizard (#23905)

* docs: add Quick Install section for litellm --setup wizard

* docs: clarify setup wizard is for local/beginner use

* feat(setup): interactive setup wizard + install.sh (#23644)

* feat(setup): add interactive setup wizard + install.sh

Adds `litellm --setup` — a Claude Code-style TUI onboarding wizard that
guides users through provider selection, API key entry, and proxy config
generation, then optionally starts the proxy immediately.

- litellm/setup_wizard.py: wizard with ASCII art, numbered provider menu
  (OpenAI, Anthropic, Azure, Gemini, Bedrock, Ollama), API key prompts,
  port/master-key config, and litellm_config.yaml generation
- litellm/proxy/proxy_cli.py: adds --setup flag that invokes the wizard
- scripts/install.sh: curl-installable script (detect OS/Python, pip
  install litellm[proxy], launch wizard)

Usage:
  curl -fsSL https://raw.githubusercontent.com/BerriAI/litellm/main/scripts/install.sh | sh
  litellm --setup

* fix(install.sh): remove orange color, add LITELLM_BRANCH env var for branch installs

* fix(install.sh): install from git branch so --setup is available for QA

* fix(install.sh): remove stale LITELLM_BRANCH reference that caused unbound variable error

* fix(install.sh): force-reinstall from git to bypass cached PyPI version

* fix(install.sh): show pip progress bar during install

* fix(install.sh): always launch wizard via $PYTHON_BIN -m litellm, not PATH binary

* fix(install.sh): use litellm.proxy.proxy_cli module (no __main__.py exists)

* fix(install.sh): suppress RuntimeWarning from module invocation

* fix(install.sh): use Python bin-dir litellm binary to avoid CWD sys.path shadowing

* fix(install.sh): use sysconfig.get_path('scripts') to find pip-installed litellm binary

* fix(install.sh): redirect stdin from /dev/tty on exec so wizard gets terminal, not exhausted pipe

* fix(install.sh): warn about git clone duration, drop --no-cache-dir so re-runs are faster

* feat(setup_wizard): arrow-key selector, updated model names

* fix(setup_wizard): use sysconfig binary to start proxy, not python -m litellm

* feat(setup_wizard): credential validation after key entry + clear next-steps after proxy start

* style(install.sh): show git clone warning in blue

* refactor(setup_wizard): class with static methods, use check_valid_key from litellm.utils

* address greptile review: fix yaml escaping, port validation, display name collisions, tests

- setup_wizard.py: add _yaml_escape() for safe YAML embedding of API keys
- setup_wizard.py: add _styled_input() with readline ANSI ignore markers
- setup_wizard.py: change DIVIDER to _divider() fn to avoid import-time color capture
- setup_wizard.py: validate port range 1-65535, initialize before loop
- setup_wizard.py: qualify azure display names (azure-gpt-4o) to avoid collision with openai
- setup_wizard.py: work on env_copy in _build_config to avoid mutating caller's dict
- setup_wizard.py: skip model_list entries for providers with no credentials
- setup_wizard.py: prompt for azure deployment name
- setup_wizard.py: wrap os.execlp in try/except with friendly fallback
- setup_wizard.py: wrap config write in try/except OSError
- setup_wizard.py: fix _validate_and_report to use two print lines (no \r overwrite)
- setup_wizard.py: add .gitignore tip next to key storage notice
- setup_wizard.py: fix run_setup_wizard() return type annotation to None
- scripts/install.sh: drop pipefail (not supported by dash on Ubuntu when invoked as sh)
- scripts/install.sh: use litellm[proxy] from PyPI (not hardcoded dev branch)
- scripts/install.sh: guard /dev/tty read with -r check for Docker/CI compat
- scripts/install.sh: remove --force-reinstall to avoid downgrading dependencies
- tests/test_litellm/test_setup_wizard.py: 13 unit tests for _build_config and _yaml_escape
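
The port range check mentioned in the list above could look roughly like the following. This is a minimal sketch, not the wizard's actual code: `parse_port` and `DEFAULT_PORT` are hypothetical names, and the default of 4000 is taken from the wizard prompt shown in the docs.

```python
from typing import Optional

DEFAULT_PORT = 4000  # default shown in the wizard's "Port [4000]:" prompt


def parse_port(raw: str, default: int = DEFAULT_PORT) -> Optional[int]:
    """Return a valid TCP port, the default for empty input, or None if invalid."""
    text = raw.strip()
    if not text:
        return default
    try:
        port = int(text)
    except ValueError:
        return None
    # Valid TCP ports are 1-65535 inclusive.
    return port if 1 <= port <= 65535 else None
```

A caller would typically loop on the prompt until `parse_port` returns something other than `None`.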

* style: black format setup_wizard.py

* fix: address remaining greptile issues - Windows compat, YAML quoting, credential flow

- guard termios/tty imports with try/except ImportError for Windows compat
- quote master_key as YAML double-quoted scalar (same as env vars)
- remove unused port param from _build_config signature
- _validate_and_report now returns the final key so re-entered creds are stored
- add test for master_key YAML quoting
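
The Windows guard in the first bullet above follows a common pattern, sketched below. The flag `HAS_RAW_TERMINAL` and the `selector_mode` helper are hypothetical, and falling back to a numbered menu is an assumption about how the wizard degrades.

```python
# termios and tty exist only on POSIX platforms, so the import is guarded.
try:
    import termios  # noqa: F401
    import tty  # noqa: F401

    HAS_RAW_TERMINAL = True
except ImportError:
    HAS_RAW_TERMINAL = False


def selector_mode() -> str:
    """Pick the arrow-key selector when raw terminal mode is available."""
    # Assumption: without termios the wizard falls back to a numbered menu.
    return "arrow-keys" if HAS_RAW_TERMINAL else "numbered"
```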

* fix: add --port to suggested command, guard /dev/tty exec in install.sh

* fix: quote api_base in YAML, skip azure if no deployment, only redraw on state change

* fix: address greptile review comments

- _yaml_escape: add control character escaping (\n, \r, \t)
- test: fix tautological assertion in test_build_config_azure_no_deployment_skipped
- test: add tests for control character escaping in _yaml_escape
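
For readers unfamiliar with the escaping problem, here is a minimal sketch of escaping a value for a YAML double-quoted scalar. The function names are hypothetical and this is not the body of `_yaml_escape`; it only covers the characters named in the commits above (backslash, double quote, `\n`, `\r`, `\t`).

```python
def yaml_escape(value: str) -> str:
    """Escape a string for use inside a YAML double-quoted scalar."""
    return (
        value.replace("\\", "\\\\")  # backslash first, so later escapes are not doubled
        .replace('"', '\\"')
        .replace("\n", "\\n")
        .replace("\r", "\\r")
        .replace("\t", "\\t")
    )


def yaml_quote(value: str) -> str:
    """Wrap the escaped value in double quotes, e.g. for an API key entry."""
    return f'"{yaml_escape(value)}"'
```

Escaping the backslash before anything else matters: doing it last would re-escape the backslashes introduced by the other replacements.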

* feat(ui): remove Chat UI page link and banner from sidebar and playground (#23908)

* feat(guardrails): MCPJWTSigner - built-in guardrail for zero trust MCP auth (#23897)

* Allow pre_mcp_call guardrail hooks to mutate outbound MCP headers

* Enhance MCPServerManager to support hook-modified arguments and extra headers. Update tests to validate argument mutation and header injection behavior, including warnings for OpenAPI-backed servers when headers are present.

* Refactor MCPServerManager to raise HTTPException for extra headers in OpenAPI-backed servers. Update tests to reflect this change, ensuring proper exception handling instead of logging warnings.

* feat(guardrails): add MCPJWTSigner built-in guardrail for zero trust MCP auth

Signs outbound MCP tool calls with a LiteLLM-issued RS256 JWT so MCP servers
can trust a single signing authority instead of every upstream IdP.

Enable in config.yaml:
  guardrails:
    - guardrail_name: mcp-jwt-signer
      litellm_params:
        guardrail: mcp_jwt_signer
        mode: pre_mcp_call
        default_on: true

JWT carries sub (user_id), act.sub (team_id, RFC 8693), tool-level scope, iss,
aud, iat/exp/nbf. RSA-2048 keypair auto-generated at startup unless
MCP_JWT_SIGNING_KEY env var is set.

Adds /.well-known/jwks.json endpoint and jwks_uri to /.well-known/openid-configuration
so MCP servers can verify LiteLLM-issued tokens via OIDC discovery.
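
To make the claim set described above concrete, the sketch below assembles an unsigned claims dict. It is an illustration only: `build_claims` is a hypothetical name, the scope string format is an assumption (the commit only says the scope is tool-level), and signing with RS256 is left out.

```python
import time
from typing import Any, Dict, Optional


def build_claims(
    user_id: str,
    team_id: Optional[str],
    tool_name: str,
    issuer: str,
    audience: str,
    ttl_seconds: int = 300,
    now: Optional[int] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: build the claims for an outbound MCP JWT."""
    if ttl_seconds <= 0:
        raise ValueError("ttl_seconds must be > 0")
    issued_at = int(time.time()) if now is None else now
    claims: Dict[str, Any] = {
        "iss": issuer,
        "aud": audience,
        "sub": user_id,
        # Assumption: illustrative scope format, not the real scope strings.
        "scope": f"mcp:tools/{tool_name}:call",
        "iat": issued_at,
        "nbf": issued_at,
        "exp": issued_at + ttl_seconds,
    }
    if team_id:
        # RFC 8693 actor claim carrying the team identity.
        claims["act"] = {"sub": team_id}
    return claims
```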

* Update MCPServerManager to raise HTTPException with status code 400 for extra headers in OpenAPI-backed servers. Adjust tests to verify the correct status code and exception message.

* fix: address P1 issues in MCPJWTSigner

- OpenAPI servers: warn + skip header injection instead of 500
- JWKS Cache-Control: 5min for auto-generated keys, 1h for persistent
- sub claim: fallback to apikey:{token_hash} for anonymous callers
- ttl_seconds: validate > 0 at init time

* docs: add MCP zero trust auth guide with architecture diagram

* docs: add FastMCP JWT verification guide to zero trust doc

* fix: address remaining Greptile review issues (round 2)

- mcp_server_manager: warn when hook Authorization overwrites existing header
- __init__: remove _mcp_jwt_signer_instance from __all__ (private internal)
- discoverable_endpoints: copy dict instead of mutating in-place on OIDC augmentation
- test docstring: reflect warn-and-continue behavior for OpenAPI servers
- test: update scope assertions for least-privilege (no mcp:tools/list on tool-call JWTs)

* fix: address Greptile round 3 feedback

- initialize_guardrail: validate mode='pre_mcp_call' at init time — misconfigured
  mode silently bypasses JWT injection, which is a zero-trust bypass
- _build_claims: remove duplicate inline 'import re' (module-level import already present)
- _types.py: add TODO comment explaining jwt_claims is forward-compat plumbing
  for a follow-up PR that will forward upstream IdP claims into outbound MCP JWTs

* feat(mcp_jwt_signer): add verify+re-sign, claim ops, two-token model, configurable scopes

Addresses all missing pieces from the scoping doc review:

FR-5 (Verify + re-sign): MCPJWTSigner now accepts access_token_discovery_uri
and token_introspection_endpoint.  When set, the incoming Bearer token is
extracted from raw_headers (threaded through pre_call_tool_check), verified
against the IdP's JWKS (JWT) or introspected (opaque), and only re-signed if
valid.  Falls back to user_api_key_dict.jwt_claims for LiteLLM JWT-auth mode.

FR-12 (Configurable end-user identity mapping): end_user_claim_sources
ordered list drives sub resolution — sources: token:<claim>, litellm:user_id,
litellm:email, litellm:end_user_id, litellm:team_id.
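
The ordered resolution described for FR-12 can be sketched as below. The function and parameter names are hypothetical, and falling back to `apikey:{token_hash}` when no source matches is an assumption that combines FR-12 with the anonymous-caller fix from the earlier P1 commit.

```python
from typing import Any, Dict, List


def resolve_sub(
    sources: List[str],
    token_claims: Dict[str, Any],
    litellm_context: Dict[str, Any],
    token_hash: str,
) -> str:
    """Return the first non-empty value found by walking the sources in order."""
    for source in sources:
        kind, _, name = source.partition(":")
        if kind == "token":
            value = token_claims.get(name)
        elif kind == "litellm":
            value = litellm_context.get(name)
        else:
            value = None  # unknown source kinds are skipped in this sketch
        if value:
            return str(value)
    # Assumption: anonymous callers fall back to a key-derived identity.
    return f"apikey:{token_hash}"
```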

FR-13 (Claim operations): add_claims (insert-if-absent), set_claims (always
override), remove_claims (delete) applied in that order.
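
The add, set, remove ordering in FR-13 can be illustrated with a small sketch. `apply_claim_ops` is a hypothetical name and this is not the signer's implementation, only the ordering the paragraph above describes.

```python
from typing import Any, Dict, Iterable, Optional


def apply_claim_ops(
    claims: Dict[str, Any],
    add_claims: Optional[Dict[str, Any]] = None,
    set_claims: Optional[Dict[str, Any]] = None,
    remove_claims: Optional[Iterable[str]] = None,
) -> Dict[str, Any]:
    """Apply add (insert-if-absent), then set (override), then remove."""
    result = dict(claims)  # work on a copy so the caller's dict is untouched
    for key, value in (add_claims or {}).items():
        result.setdefault(key, value)
    result.update(set_claims or {})
    for key in remove_claims or ():
        result.pop(key, None)
    return result
```

Because removal runs last, a claim listed in both `set_claims` and `remove_claims` ends up absent from the result.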

FR-14 (Two-token model): channel_token_audience + channel_token_ttl issue a
second JWT injected as x-mcp-channel-token: Bearer <token>.

FR-15 (Incoming claim validation): required_claims raises HTTP 403 when any
listed claim is absent; optional_claims passes listed claims from verified
token into the outbound JWT.
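
A minimal sketch of the FR-15 behaviour follows. The exception class and function name are hypothetical stand-ins; the real guardrail raises an HTTP 403, which is represented here by a plain exception carrying a status code.

```python
from typing import Any, Dict, List


class ClaimValidationError(Exception):
    """Hypothetical stand-in for the HTTP 403 raised by the guardrail."""

    status_code = 403


def validate_and_select(
    verified_claims: Dict[str, Any],
    required_claims: List[str],
    optional_claims: List[str],
) -> Dict[str, Any]:
    """Reject tokens missing a required claim; return the claims to pass through."""
    missing = [name for name in required_claims if name not in verified_claims]
    if missing:
        raise ClaimValidationError(f"missing required claims: {', '.join(missing)}")
    # Only listed optional claims that are actually present are forwarded.
    return {name: verified_claims[name] for name in optional_claims if name in verified_claims}
```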

FR-9 (Debug headers): debug_headers: true emits x-litellm-mcp-debug with kid,
sub, iss, exp, scope.

FR-10 (Configurable scopes): allowed_scopes replaces auto-generation.  Also
fixed: tool-call JWTs no longer grant mcp:tools/list (overpermission).

P1 fixes:
- proxy/utils.py: _convert_mcp_hook_response_to_kwargs merges rather than
  replaces extra_headers, preserving headers from prior guardrails.
- mcp_server_manager.py: warns when hook injects Authorization alongside a
  server-configured authentication_token (previously silent).
- mcp_server_manager.py: pre_call_tool_check now accepts raw_headers and
  extracts incoming_bearer_token so FR-5 verification has the raw token.
- proxy/utils.py: remove stray inline import inspect inside loop (pre-existing
  lint error, now cleaned up).
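
The merge-rather-than-replace fix for `extra_headers` amounts to the pattern below. This is a sketch with a hypothetical helper name; letting the later guardrail win on a key conflict is an assumption, not something the commit states.

```python
from typing import Dict, Optional


def merge_hook_headers(
    existing: Optional[Dict[str, str]],
    hook_headers: Optional[Dict[str, str]],
) -> Dict[str, str]:
    """Combine headers from prior guardrails with headers from the current hook."""
    merged = dict(existing or {})
    # Assumption: on a key conflict the later hook's value wins.
    merged.update(hook_headers or {})
    return merged
```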

Tests: 43 passing (28 new tests covering all FR flags + P1 fixes).

* feat(mcp_jwt_signer): add verify+re-sign, claim ops, two-token model, configurable scopes (core)

Remaining files from the FR implementation:

mcp_jwt_signer.py — full rewrite with all new params:
  FR-5:  access_token_discovery_uri, token_introspection_endpoint,
         verify_issuer, verify_audience + _verify_incoming_jwt(),
         _introspect_opaque_token()
  FR-12: end_user_claim_sources ordered resolution chain
  FR-13: add_claims, set_claims, remove_claims
  FR-14: channel_token_audience, channel_token_ttl → x-mcp-channel-token
  FR-15: required_claims (raises 403), optional_claims (passthrough)
  FR-9:  debug_headers → x-litellm-mcp-debug
  FR-10: allowed_scopes; tool-call JWTs no longer over-grant tools/list

mcp_server_manager.py:
  - pre_call_tool_check gains raw_headers param to extract incoming_bearer_token
  - Silent Authorization override warning fixed: now fires when server has
    authentication_token AND hook injects Authorization

tests/test_mcp_jwt_signer.py:
  28 new tests covering all FR flags + P1 fixes (43 total, all passing)

* fix(mcp_jwt_signer): address pre-landing review issues

- Remove stale TODO comment on UserAPIKeyAuth.jwt_claims — the field is
  already populated and consumed by MCPJWTSigner in the same PR
- Fix _get_oidc_discovery to only cache the OIDC discovery doc when
  jwks_uri is present; a malformed/empty doc now retries on the next
  request instead of being permanently cached until proxy restart
- Add FR-5 test coverage for _fetch_jwks (cache hit/miss),
  _get_oidc_discovery (cache/no-cache on bad doc), _verify_incoming_jwt
  (valid token, expired token), _introspect_opaque_token (active,
  inactive, no endpoint), and the end-to-end 401 hook path — 53 tests
  total, all passing

* docs(mcp_zero_trust): rewrite as use-case guide covering all new JWT signer features

Add scenario-driven sections for each new config area:
- Verify+re-sign with Okta/Azure AD (access_token_discovery_uri,
  end_user_claim_sources, token_introspection_endpoint)
- Enforcing caller attributes with required_claims / optional_claims
- Adding metadata via add_claims / set_claims / remove_claims
- Two-token model for AWS Bedrock AgentCore Gateway
  (channel_token_audience / channel_token_ttl)
- Controlling scopes with allowed_scopes
- Debugging JWT rejections with debug_headers

Update JWT claims table to reflect configurable sub (end_user_claim_sources)

* fix(mcp_jwt_signer): wire all config.yaml params through initialize_guardrail

The factory was only passing issuer/audience/ttl_seconds to MCPJWTSigner.
All FR-5/9/10/12/13/14/15 params (access_token_discovery_uri,
end_user_claim_sources, add/set/remove_claims, channel_token_audience,
required/optional_claims, debug_headers, allowed_scopes, etc.) were
silently dropped, making every advertised advanced feature non-functional
when loaded from config.yaml.

Add regression test that asserts every param is wired through correctly.

* docs(mcp_zero_trust): add hero image

* docs(mcp_zero_trust): apply Linear-style edits

- Lead with the problem (unsigned direct calls bypass access controls)
- Shorter statement section headers instead of question-form headers
- Move diagram/OIDC discovery block after the reader is bought in
- Add 'read further only if you need to' callout after basic setup
- Two-token section now opens from the user problem not product jargon
- Add concrete 403 error response example in required_claims section
- Debug section opens from the symptom (MCP server returning 401)
- Lowercase claims reference header for consistency

* fix(mcp_jwt_signer): fix algorithm confusion attack + add OIDC discovery 24h TTL

- Remove alg from unverified JWT header; use signing_jwk.algorithm_name from JWKS key instead.
  Reading alg from attacker-controlled headers enables alg:none / HS256 confusion attacks.
- Add _oidc_discovery_fetched_at timestamp and _OIDC_DISCOVERY_TTL = 86400 (24h).
  Without a TTL the cached discovery doc never refreshes, so IdP key rotation is invisible.
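
The two caching rules from this PR (cache the discovery doc only when `jwks_uri` is present, and expire it after 24h) can be sketched together as below. The class name and the injected `fetch` and `clock` callables are hypothetical and exist only to keep the example self-contained.

```python
import time
from typing import Any, Callable, Dict, Optional

OIDC_DISCOVERY_TTL = 86400  # 24h, matching the TTL named in the commit above


class DiscoveryCache:
    """Sketch of a TTL cache that refuses to store malformed discovery docs."""

    def __init__(
        self,
        fetch: Callable[[], Dict[str, Any]],
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._fetch = fetch
        self._clock = clock
        self._doc: Optional[Dict[str, Any]] = None
        self._fetched_at = 0.0

    def get(self) -> Dict[str, Any]:
        now = self._clock()
        if self._doc is not None and now - self._fetched_at < OIDC_DISCOVERY_TTL:
            return self._doc
        doc = self._fetch()
        if doc.get("jwks_uri"):
            # Only a usable doc is cached; a bad one is retried on the next call.
            self._doc = doc
            self._fetched_at = now
        return doc
```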

---------

Co-authored-by: Noah Nistler <60981020+noahnistler@users.noreply.github.com>

* fix(ci): stabilize CI - formatting, type errors, test polling, security CVEs, router bug, batch resolution

Fix 1: Run Black formatter on 35 files
Fix 2: Fix MyPy type errors:
  - setup_wizard.py: add type annotation for 'selected' set variable
  - user_api_key_auth.py: remove redundant type annotation on jwt_claims reassignment
Fix 3: Fix spend accuracy test burst 2 polling to wait for expected total
  spend instead of just 'any increase' from burst 2
Fix 4: Bump Next.js 16.1.6 -> 16.1.7 to fix CVE-2026-27978, CVE-2026-27979,
  CVE-2026-27980, CVE-2026-29057
Fix 5: Fix router _pre_call_checks model variable being overwritten inside
  loop, causing wrong model lookups on subsequent deployments. Use local
  _deployment_model variable instead.
Fix 6: Add missing resolve_output_file_ids_to_unified call in batch retrieve
  non-terminal-to-terminal path (matching the terminal path behavior)
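
Fix 5 above describes a general bug class: reassigning a function argument inside a loop that still compares against it. The sketch below reproduces that class with simplified, hypothetical data; it is not the router's `_pre_call_checks` code.

```python
from typing import Any, Dict, List


def collect_buggy(model: str, deployments: List[Dict[str, Any]]) -> List[str]:
    found = []
    for deployment in deployments:
        if deployment["model_name"] != model:
            continue
        # BUG: overwrites the argument that later iterations compare against.
        model = deployment["litellm_params"]["model"]
        found.append(model)
    return found


def collect_fixed(model: str, deployments: List[Dict[str, Any]]) -> List[str]:
    found = []
    for deployment in deployments:
        if deployment["model_name"] != model:
            continue
        # A local variable leaves the `model` argument untouched.
        _deployment_model = deployment["litellm_params"]["model"]
        found.append(_deployment_model)
    return found
```

With two deployments in the same model group, the buggy version silently skips the second one because `model` no longer holds the group name.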

Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>

* chore: regenerate poetry.lock to sync with pyproject.toml

Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>

* fix: format merged files from main and regenerate poetry.lock

Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>

* fix(mypy): annotate jwt_claims as Optional[dict] to fix type incompatibility

Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>

* fix(ci): update router region test to use gpt-4.1-mini (fix flaky model lookup)

Replace deprecated gpt-3.5-turbo-1106 with gpt-4.1-mini + mock_response in
test_router_region_pre_call_check, following the same pattern used in commit
717d37cc5b for test_router_context_window_check_pre_call_check_out_group.

Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>

* ci: retry flaky logging_testing (async event loop race condition)

Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>

* fix(ci): aggregate all mock calls in langfuse e2e test to fix race condition

The _verify_langfuse_call helper only inspected the last mock call
(mock_post.call_args), but the Langfuse SDK may split trace-create and
generation-create events across separate HTTP flush cycles. This caused
an IndexError when the last call's batch contained only one event type.

Fix: iterate over mock_post.call_args_list to collect batch items from
ALL calls. Also add a safety assertion after filtering by trace_id and
mark all langfuse e2e tests with @pytest.mark.flaky(retries=3) as an
extra safety net for any residual timing issues.
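
The switch from `call_args` to `call_args_list` can be sketched as follows. The helper name is hypothetical, and the payload shape (a `json=` keyword argument containing a `batch` list) is an assumption made for illustration.

```python
from typing import Any, Dict, List
from unittest.mock import MagicMock


def collect_batch_items(mock_post: MagicMock) -> List[Dict[str, Any]]:
    """Gather batch events from every recorded call, not only the last one."""
    items: List[Dict[str, Any]] = []
    for call in mock_post.call_args_list:
        # Assumption: each request body is passed as json={"batch": [...]}.
        body = call.kwargs.get("json") or {}
        items.extend(body.get("batch", []))
    return items
```

Inspecting only `mock_post.call_args` would return the final flush, which is why a split trace-create and generation-create pair could previously be missed.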

Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>

* fix(ci): black formatting + update OpenAPI compliance tests for spec changes

- Apply Black 26.x formatting to litellm_logging.py (parenthesized style)
- Update test_input_types_match_spec to follow $ref to InteractionsInput schema
  (Google updated their OpenAPI spec to use $ref instead of inline oneOf)
- Update test_content_schema_uses_discriminator to handle discriminator without
  explicit mapping (Google removed the mapping key from Content discriminator)

Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>

* revert: undo incorrect Black 26.x formatting on litellm_logging.py

The file was correctly formatted for Black 23.12.1 (the version pinned
in pyproject.toml). The previous commit applied Black 26.x formatting
which was incompatible with the CI's Black version.

Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>

* fix(ci): deduplicate and sort langfuse batch events after aggregation

The Langfuse SDK may send the same event (e.g., trace-create) in
multiple flush cycles, causing duplicates when we aggregate from all
mock calls. After filtering by trace_id, deduplicate by keeping only
the first event of each type, then sort to ensure trace-create is at
index 0 and generation-create at index 1.
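
The deduplicate-then-sort step described above could be written roughly like this. The helper name and the `type` key on each event are assumptions; only the two event types named in the commit are given an explicit order.

```python
from typing import Any, Dict, List

# Desired position of each event type after sorting.
EVENT_ORDER = {"trace-create": 0, "generation-create": 1}


def dedupe_and_sort(events: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep the first event of each type, then order trace before generation."""
    first_by_type: Dict[str, Dict[str, Any]] = {}
    for event in events:
        first_by_type.setdefault(event["type"], event)
    # Unknown types sort after the known ones; sorted() is stable.
    return sorted(
        first_by_type.values(),
        key=lambda event: EVENT_ORDER.get(event["type"], len(EVENT_ORDER)),
    )
```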

Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>

---------

Co-authored-by: Noah Nistler <60981020+noahnistler@users.noreply.github.com>
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
Ishaan Jaff 2026-03-18 15:09:01 -07:00 committed by GitHub
parent cbb4c2c220
commit 8e61b32b8e
61 changed files with 4695 additions and 450 deletions

@ -140,6 +140,11 @@ LiteLLM is a unified interface for 100+ LLM providers with two main components:
- **Check index coverage.** For new or modified queries, check `schema.prisma` for a supporting index. Prefer extending an existing index (e.g. `@@index([a])` → `@@index([a, b])`) over adding a new one, unless it's a `@@unique`. Only add indexes for large/frequent queries.
- **Keep schema files in sync.** Apply schema changes to all `schema.prisma` copies (`schema.prisma`, `litellm/proxy/`, `litellm-proxy-extras/`, `litellm-js/spend-logs/` for SpendLogs) with a migration under `litellm-proxy-extras/litellm_proxy_extras/migrations/`.
### Setup Wizard (`litellm/setup_wizard.py`)
- The wizard is implemented as a single `SetupWizard` class with `@staticmethod` methods — keep it that way. No module-level functions except `run_setup_wizard()` (the public entrypoint) and pure helpers (color, ANSI).
- Use `litellm.utils.check_valid_key(model, api_key)` for credential validation — never roll a custom completion call.
- Do not hardcode provider env-key names or model lists that already exist in the codebase. Add a `test_model` field to each provider entry to drive `check_valid_key`; set it to `None` for providers that can't be validated with a single API key (Azure, Bedrock, Ollama).
### Enterprise Features
- Enterprise-specific code in `enterprise/` directory
- Optional features enabled via environment variables

@ -5,11 +5,76 @@ import Image from '@theme/IdealImage';
# Getting Started Tutorial
End-to-End tutorial for LiteLLM Proxy to:
- Add an Azure OpenAI model
- Make a successful /chat/completions call
- Generate a virtual key
- Set RPM limit on virtual key
## Quick Install (Recommended for local / beginners)
New to LiteLLM? This is the easiest way to get started locally. One command installs LiteLLM and walks you through setup interactively — no config files to write by hand.
### 1. Install
```bash
curl -fsSL https://raw.githubusercontent.com/BerriAI/litellm/main/scripts/install.sh | sh
```
This detects your OS, installs `litellm[proxy]`, and drops you straight into the setup wizard.
### 2. Follow the wizard
```
$ litellm --setup
Welcome to LiteLLM
Choose your LLM providers
○ 1. OpenAI GPT-4o, GPT-4o-mini, o1
○ 2. Anthropic Claude Opus, Sonnet, Haiku
○ 3. Azure OpenAI GPT-4o via Azure
○ 4. Google Gemini Gemini 2.0 Flash, 1.5 Pro
○ 5. AWS Bedrock Claude, Llama via AWS
○ 6. Ollama Local models
Provider(s): 1,2
OpenAI API key: sk-...
Anthropic API key: sk-ant-...
Port [4000]:
Master key [auto-generate]:
✔ Config saved → ./litellm_config.yaml
Start the proxy now? (Y/n):
```
The wizard walks you through:
1. Pick your LLM providers (OpenAI, Anthropic, Azure, Bedrock, Gemini, Ollama)
2. Enter API keys for each provider
3. Set a port and master key (or accept the defaults)
4. Config is saved to `./litellm_config.yaml`, and the proxy starts immediately if you confirm at the prompt
### 3. Make a call
Your proxy is running on `http://0.0.0.0:4000`. Test it:
```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer <your-master-key>' \
-d '{
"model": "gpt-4o",
"messages": [{"role": "user", "content": "Hello!"}]
}'
```
:::tip Already have pip installed?
You can skip the curl install and run `litellm --setup` directly after `pip install 'litellm[proxy]'`.
:::
---
## Pre-Requisites

Binary file not shown (image, 294 KiB).

@ -672,6 +672,7 @@ const sidebars = {
"mcp_control",
"mcp_cost",
"mcp_guardrail",
"mcp_zero_trust",
"mcp_troubleshoot",
]
},

@ -1465,9 +1465,15 @@ if TYPE_CHECKING:
from .llms.petals.completion.transformation import PetalsConfig as PetalsConfig
from .llms.ollama.chat.transformation import OllamaChatConfig as OllamaChatConfig
from .llms.ollama.completion.transformation import OllamaConfig as OllamaConfig
from .llms.sagemaker.completion.transformation import SagemakerConfig as SagemakerConfig
from .llms.sagemaker.chat.transformation import SagemakerChatConfig as SagemakerChatConfig
from .llms.sagemaker.nova.transformation import SagemakerNovaConfig as SagemakerNovaConfig
from .llms.sagemaker.completion.transformation import (
SagemakerConfig as SagemakerConfig,
)
from .llms.sagemaker.chat.transformation import (
SagemakerChatConfig as SagemakerChatConfig,
)
from .llms.sagemaker.nova.transformation import (
SagemakerNovaConfig as SagemakerNovaConfig,
)
from .llms.cohere.chat.transformation import CohereChatConfig as CohereChatConfig
from .llms.anthropic.experimental_pass_through.messages.transformation import (
AnthropicMessagesConfig as AnthropicMessagesConfig,

@ -17,7 +17,9 @@ if set_verbose is True:
"`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs."
)
_ENABLE_SECRET_REDACTION = os.getenv("LITELLM_DISABLE_REDACT_SECRETS", "").lower() != "true"
_ENABLE_SECRET_REDACTION = (
os.getenv("LITELLM_DISABLE_REDACT_SECRETS", "").lower() != "true"
)
_REDACTED = "REDACTED"
@ -199,7 +201,9 @@ class JsonFormatter(Formatter):
json_record[key] = value
if record.exc_info:
json_record["stacktrace"] = record.exc_text or self.formatException(record.exc_info)
json_record["stacktrace"] = record.exc_text or self.formatException(
record.exc_info
)
return safe_dumps(json_record)

@ -1189,7 +1189,9 @@ def completion_cost( # noqa: PLR0915
and _usage["prompt_tokens_details"] != {}
and _usage["prompt_tokens_details"]
):
prompt_tokens_details = _usage.get("prompt_tokens_details") or {}
prompt_tokens_details = (
_usage.get("prompt_tokens_details") or {}
)
cache_read_input_tokens = prompt_tokens_details.get(
"cached_tokens", 0
)
@ -1515,7 +1517,9 @@ def completion_cost( # noqa: PLR0915
if custom_llm_provider == "azure_ai":
model_for_additional_costs = request_model_for_cost
if completion_response is not None:
hidden_params = getattr(completion_response, "_hidden_params", None) or {}
hidden_params = (
getattr(completion_response, "_hidden_params", None) or {}
)
hidden_model = hidden_params.get("model") or hidden_params.get(
"litellm_model_name"
)

@ -59,17 +59,14 @@ class FocusDestinationFactory:
return {k: v for k, v in resolved.items() if v is not None}
if provider == "vantage":
resolved = {
"api_key": overrides.get("api_key")
or os.getenv("VANTAGE_API_KEY"),
"api_key": overrides.get("api_key") or os.getenv("VANTAGE_API_KEY"),
"integration_token": overrides.get("integration_token")
or os.getenv("VANTAGE_INTEGRATION_TOKEN"),
"base_url": overrides.get("base_url")
or os.getenv("VANTAGE_BASE_URL", "https://api.vantage.sh"),
}
if not resolved.get("api_key"):
raise ValueError(
"VANTAGE_API_KEY must be provided for Vantage exports"
)
raise ValueError("VANTAGE_API_KEY must be provided for Vantage exports")
if not resolved.get("integration_token"):
raise ValueError(
"VANTAGE_INTEGRATION_TOKEN must be provided for Vantage exports"

@ -340,9 +340,9 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
)
status_message = str(kwargs.get("exception", "Unknown error"))
if standard_logging_object is not None:
status_message = standard_logging_object.get(
"error_str", None
) or status_message
status_message = (
standard_logging_object.get("error_str", None) or status_message
)
langfuse_logger_to_use.log_event_on_langfuse(
start_time=start_time,
end_time=end_time,

@ -83,7 +83,9 @@ class VantageLogger(FocusLogger):
verbose_logger.debug(
"VantageLogger initialized (integration_token=%s)",
resolved_token[:4] + "***" if resolved_token and len(resolved_token) > 4 else "***",
resolved_token[:4] + "***"
if resolved_token and len(resolved_token) > 4
else "***",
)
async def initialize_focus_export_job(self) -> None:
@ -128,9 +130,7 @@ class VantageLogger(FocusLogger):
callback_type=VantageLogger
)
if not vantage_loggers:
verbose_logger.debug(
"No Vantage logger registered; skipping scheduler"
)
verbose_logger.debug("No Vantage logger registered; skipping scheduler")
return
vantage_logger = cast(VantageLogger, vantage_loggers[0])

@ -26,7 +26,9 @@ if custom_cache_dir:
else:
cache_dir = filename
os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
os.environ[
"TIKTOKEN_CACHE_DIR"
] = cache_dir # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
import tiktoken
import time
@ -48,4 +50,3 @@ for attempt in range(_max_retries):
# Exponential backoff with jitter to reduce collision probability
delay = _retry_delay * (2**attempt) + random.uniform(0, 0.1)
time.sleep(delay)

@ -352,9 +352,9 @@ class Logging(LiteLLMLoggingBaseClass):
)
self.function_id = function_id
self.streaming_chunks: List[Any] = [] # for generating complete stream response
self.sync_streaming_chunks: List[Any] = (
[]
) # for generating complete stream response
self.sync_streaming_chunks: List[
Any
] = [] # for generating complete stream response
self.log_raw_request_response = log_raw_request_response
# Initialize dynamic callbacks
@ -782,9 +782,9 @@ class Logging(LiteLLMLoggingBaseClass):
prompt_spec=prompt_spec,
dynamic_callback_params=dynamic_callback_params,
):
self.model_call_details["prompt_integration"] = (
logger.__class__.__name__
)
self.model_call_details[
"prompt_integration"
] = logger.__class__.__name__
return logger
except Exception:
# If check fails, continue to next logger
@ -852,9 +852,9 @@ class Logging(LiteLLMLoggingBaseClass):
if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook(
non_default_params
):
self.model_call_details["prompt_integration"] = (
anthropic_cache_control_logger.__class__.__name__
)
self.model_call_details[
"prompt_integration"
] = anthropic_cache_control_logger.__class__.__name__
return anthropic_cache_control_logger
#########################################################
@ -866,9 +866,9 @@ class Logging(LiteLLMLoggingBaseClass):
internal_usage_cache=None,
llm_router=None,
)
self.model_call_details["prompt_integration"] = (
vector_store_custom_logger.__class__.__name__
)
self.model_call_details[
"prompt_integration"
] = vector_store_custom_logger.__class__.__name__
# Add to global callbacks so post-call hooks are invoked
if (
vector_store_custom_logger
@ -928,9 +928,9 @@ class Logging(LiteLLMLoggingBaseClass):
model
): # if model name was changes pre-call, overwrite the initial model call name with the new one
self.model_call_details["model"] = model
self.model_call_details["litellm_params"]["api_base"] = (
self._get_masked_api_base(additional_args.get("api_base", ""))
)
self.model_call_details["litellm_params"][
"api_base"
] = self._get_masked_api_base(additional_args.get("api_base", ""))
def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915
# Log the exact input to the LLM API
@ -959,10 +959,10 @@ class Logging(LiteLLMLoggingBaseClass):
try:
# [Non-blocking Extra Debug Information in metadata]
if turn_off_message_logging is True:
_metadata["raw_request"] = (
"redacted by litellm. \
_metadata[
"raw_request"
] = "redacted by litellm. \
'litellm.turn_off_message_logging=True'"
)
else:
curl_command = self._get_request_curl_command(
api_base=additional_args.get("api_base", ""),
@ -973,34 +973,34 @@ class Logging(LiteLLMLoggingBaseClass):
_metadata["raw_request"] = str(curl_command)
# split up, so it's easier to parse in the UI
self.model_call_details["raw_request_typed_dict"] = (
RawRequestTypedDict(
raw_request_api_base=str(
additional_args.get("api_base") or ""
),
raw_request_body=self._get_raw_request_body(
additional_args.get("complete_input_dict", {})
),
# NOTE: setting ignore_sensitive_headers to True will cause
# the Authorization header to be leaked when calls to the health
# endpoint are made and fail.
raw_request_headers=self._get_masked_headers(
additional_args.get("headers", {}) or {},
),
error=None,
)
self.model_call_details[
"raw_request_typed_dict"
] = RawRequestTypedDict(
raw_request_api_base=str(
additional_args.get("api_base") or ""
),
raw_request_body=self._get_raw_request_body(
additional_args.get("complete_input_dict", {})
),
# NOTE: setting ignore_sensitive_headers to True will cause
# the Authorization header to be leaked when calls to the health
# endpoint are made and fail.
raw_request_headers=self._get_masked_headers(
additional_args.get("headers", {}) or {},
),
error=None,
)
except Exception as e:
self.model_call_details["raw_request_typed_dict"] = (
RawRequestTypedDict(
error=str(e),
)
self.model_call_details[
"raw_request_typed_dict"
] = RawRequestTypedDict(
error=str(e),
)
_metadata["raw_request"] = (
"Unable to Log \
_metadata[
"raw_request"
] = "Unable to Log \
raw request: {}".format(
str(e)
)
str(e)
)
if getattr(self, "logger_fn", None) and callable(self.logger_fn):
try:
@ -1301,13 +1301,13 @@ class Logging(LiteLLMLoggingBaseClass):
for callback in callbacks:
try:
if isinstance(callback, CustomLogger):
response: Optional[MCPPostCallResponseObject] = (
await callback.async_post_mcp_tool_call_hook(
kwargs=kwargs,
response_obj=post_mcp_tool_call_response_obj,
start_time=start_time,
end_time=end_time,
)
response: Optional[
MCPPostCallResponseObject
] = await callback.async_post_mcp_tool_call_hook(
kwargs=kwargs,
response_obj=post_mcp_tool_call_response_obj,
start_time=start_time,
end_time=end_time,
)
######################################################################
# if any of the callbacks modify the response, use the modified response
@ -1502,9 +1502,9 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
f"response_cost_failure_debug_information: {debug_info}"
)
self.model_call_details["response_cost_failure_debug_information"] = (
debug_info
)
self.model_call_details[
"response_cost_failure_debug_information"
] = debug_info
return None
try:
@ -1530,9 +1530,9 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
f"response_cost_failure_debug_information: {debug_info}"
)
self.model_call_details["response_cost_failure_debug_information"] = (
debug_info
)
self.model_call_details[
"response_cost_failure_debug_information"
] = debug_info
return None
@ -1688,9 +1688,9 @@ class Logging(LiteLLMLoggingBaseClass):
result=logging_result
)
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(logging_result, start_time, end_time)
)
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(logging_result, start_time, end_time)
if (
standard_logging_payload := self.model_call_details.get(
@ -1768,9 +1768,9 @@ class Logging(LiteLLMLoggingBaseClass):
end_time = datetime.datetime.now()
if self.completion_start_time is None:
self.completion_start_time = end_time
self.model_call_details["completion_start_time"] = (
self.completion_start_time
)
self.model_call_details[
"completion_start_time"
] = self.completion_start_time
self.model_call_details["log_event_type"] = "successful_api_call"
self.model_call_details["end_time"] = end_time
@ -1807,10 +1807,10 @@ class Logging(LiteLLMLoggingBaseClass):
end_time=end_time,
)
elif isinstance(result, dict) or isinstance(result, list):
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(
result, start_time, end_time
)
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(
result, start_time, end_time
)
if (
standard_logging_payload := self.model_call_details.get(
@ -1819,9 +1819,9 @@ class Logging(LiteLLMLoggingBaseClass):
) is not None:
emit_standard_logging_payload(standard_logging_payload)
elif standard_logging_object is not None:
self.model_call_details["standard_logging_object"] = (
standard_logging_object
)
self.model_call_details[
"standard_logging_object"
] = standard_logging_object
else:
self.model_call_details["response_cost"] = None
@ -1979,17 +1979,17 @@ class Logging(LiteLLMLoggingBaseClass):
verbose_logger.debug(
"Logging Details LiteLLM-Success Call streaming complete"
)
self.model_call_details["complete_streaming_response"] = (
complete_streaming_response
)
self.model_call_details["response_cost"] = (
self._response_cost_calculator(result=complete_streaming_response)
)
self.model_call_details[
"complete_streaming_response"
] = complete_streaming_response
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(result=complete_streaming_response)
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(
complete_streaming_response, start_time, end_time
)
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(
complete_streaming_response, start_time, end_time
)
if (
standard_logging_payload := self.model_call_details.get(
@ -2323,10 +2323,10 @@ class Logging(LiteLLMLoggingBaseClass):
)
else:
if self.stream and complete_streaming_response:
self.model_call_details["complete_response"] = (
self.model_call_details.get(
"complete_streaming_response", {}
)
self.model_call_details[
"complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {}
)
result = self.model_call_details["complete_response"]
openMeterLogger.log_success_event(
@ -2350,10 +2350,10 @@ class Logging(LiteLLMLoggingBaseClass):
)
else:
if self.stream and complete_streaming_response:
self.model_call_details["complete_response"] = (
self.model_call_details.get(
"complete_streaming_response", {}
)
self.model_call_details[
"complete_response"
] = self.model_call_details.get(
"complete_streaming_response", {}
)
result = self.model_call_details["complete_response"]
@ -2492,9 +2492,9 @@ class Logging(LiteLLMLoggingBaseClass):
if complete_streaming_response is not None:
print_verbose("Async success callbacks: Got a complete streaming response")
self.model_call_details["async_complete_streaming_response"] = (
complete_streaming_response
)
self.model_call_details[
"async_complete_streaming_response"
] = complete_streaming_response
try:
if self.model_call_details.get("cache_hit", False) is True:
@ -2505,10 +2505,10 @@ class Logging(LiteLLMLoggingBaseClass):
model_call_details=self.model_call_details
)
# base_model defaults to None if not set on model_info
self.model_call_details["response_cost"] = (
self._response_cost_calculator(
result=complete_streaming_response
)
self.model_call_details[
"response_cost"
] = self._response_cost_calculator(
result=complete_streaming_response
)
verbose_logger.debug(
@ -2521,10 +2521,10 @@ class Logging(LiteLLMLoggingBaseClass):
self.model_call_details["response_cost"] = None
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(
complete_streaming_response, start_time, end_time
)
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(
complete_streaming_response, start_time, end_time
)
# print standard logging payload
@ -2551,9 +2551,9 @@ class Logging(LiteLLMLoggingBaseClass):
# _success_handler_helper_fn
if self.model_call_details.get("standard_logging_object") is None:
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
self._build_standard_logging_payload(result, start_time, end_time)
)
self.model_call_details[
"standard_logging_object"
] = self._build_standard_logging_payload(result, start_time, end_time)
# print standard logging payload
if (
@ -2796,18 +2796,18 @@ class Logging(LiteLLMLoggingBaseClass):
## STANDARDIZED LOGGING PAYLOAD
self.model_call_details["standard_logging_object"] = (
get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj={},
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="failure",
error_str=str(exception),
original_exception=exception,
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
self.model_call_details[
"standard_logging_object"
] = get_standard_logging_object_payload(
kwargs=self.model_call_details,
init_response_obj={},
start_time=start_time,
end_time=end_time,
logging_obj=self,
status="failure",
error_str=str(exception),
original_exception=exception,
standard_built_in_tools_params=self.standard_built_in_tools_params,
)
return start_time, end_time
@ -3774,9 +3774,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
service_name=arize_config.project_name,
)
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}"
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}"
for callback in _in_memory_loggers:
if (
isinstance(callback, ArizeLogger)
@ -3802,13 +3802,13 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
# Add openinference.project.name attribute
if existing_attrs:
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}"
)
os.environ[
"OTEL_RESOURCE_ATTRIBUTES"
] = f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}"
else:
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"openinference.project.name={arize_phoenix_config.project_name}"
)
os.environ[
"OTEL_RESOURCE_ATTRIBUTES"
] = f"openinference.project.name={arize_phoenix_config.project_name}"
# Set Phoenix project name from environment variable
phoenix_project_name = os.environ.get("PHOENIX_PROJECT_NAME", None)
@ -3816,19 +3816,19 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
# Add openinference.project.name attribute
if existing_attrs:
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"{existing_attrs},openinference.project.name={phoenix_project_name}"
)
os.environ[
"OTEL_RESOURCE_ATTRIBUTES"
] = f"{existing_attrs},openinference.project.name={phoenix_project_name}"
else:
os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
f"openinference.project.name={phoenix_project_name}"
)
os.environ[
"OTEL_RESOURCE_ATTRIBUTES"
] = f"openinference.project.name={phoenix_project_name}"
# auth can be disabled on local deployments of arize phoenix
if arize_phoenix_config.otlp_auth_headers is not None:
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
arize_phoenix_config.otlp_auth_headers
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = arize_phoenix_config.otlp_auth_headers
for callback in _in_memory_loggers:
if (
@ -3907,7 +3907,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
from litellm.integrations.focus.focus_logger import FocusLogger
for callback in _in_memory_loggers:
if type(callback) is FocusLogger: # exact match; exclude subclasses like VantageLogger
if (
type(callback) is FocusLogger
): # exact match; exclude subclasses like VantageLogger
return callback # type: ignore
focus_logger = FocusLogger()
_in_memory_loggers.append(focus_logger)
@ -4013,9 +4015,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
exporter="otlp_http",
endpoint="https://langtrace.ai/api/trace",
)
os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
f"api_key={os.getenv('LANGTRACE_API_KEY')}"
)
os.environ[
"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
for callback in _in_memory_loggers:
if (
isinstance(callback, OpenTelemetry)
@ -4289,7 +4291,9 @@ def get_custom_logger_compatible_class( # noqa: PLR0915
from litellm.integrations.focus.focus_logger import FocusLogger
for callback in _in_memory_loggers:
if type(callback) is FocusLogger: # exact match; exclude subclasses like VantageLogger
if (
type(callback) is FocusLogger
): # exact match; exclude subclasses like VantageLogger
return callback
elif logging_integration == "vantage":
from litellm.integrations.vantage.vantage_logger import VantageLogger
@ -4937,10 +4941,10 @@ class StandardLoggingPayloadSetup:
for key in StandardLoggingHiddenParams.__annotations__.keys():
if key in hidden_params:
if key == "additional_headers":
clean_hidden_params["additional_headers"] = (
StandardLoggingPayloadSetup.get_additional_headers(
hidden_params[key]
)
clean_hidden_params[
"additional_headers"
] = StandardLoggingPayloadSetup.get_additional_headers(
hidden_params[key]
)
else:
clean_hidden_params[key] = hidden_params[key] # type: ignore
@ -5579,9 +5583,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
):
for k, v in metadata["user_api_key_metadata"].items():
if k == "logging": # prevent logging user logging keys
cleaned_user_api_key_metadata[k] = (
"scrubbed_by_litellm_for_sensitive_keys"
)
cleaned_user_api_key_metadata[
k
] = "scrubbed_by_litellm_for_sensitive_keys"
else:
cleaned_user_api_key_metadata[k] = v


@ -2442,7 +2442,9 @@ def anthropic_messages_pt( # noqa: PLR0915
_document_content_element = cast(
AnthropicMessagesDocumentParam,
add_cache_control_to_content(
anthropic_content_element=cast(AnthropicMessagesDocumentParam, m),
anthropic_content_element=cast(
AnthropicMessagesDocumentParam, m
),
original_content_element=dict(m),
),
)
@ -2454,10 +2456,18 @@ def anthropic_messages_pt( # noqa: PLR0915
)
)
_file_content_element = add_cache_control_to_content(
anthropic_content_element=cast(AnthropicMessagesDocumentParam, _file_content_element),
anthropic_content_element=cast(
AnthropicMessagesDocumentParam,
_file_content_element,
),
original_content_element=dict(m),
)
user_content.append(cast(AnthropicMessagesDocumentParam,_file_content_element))
user_content.append(
cast(
AnthropicMessagesDocumentParam,
_file_content_element,
)
)
elif isinstance(user_message_types_block["content"], str):
_anthropic_content_text_element: AnthropicMessagesTextParam = {
"type": "text",


@ -780,7 +780,7 @@ class LiteLLMAnthropicMessagesAdapter:
# Keep Anthropic-native tools in their original format
new_tools.append(tool) # type: ignore[arg-type]
continue
original_name = tool["name"]
truncated_name = truncate_tool_name(original_name)


@ -336,9 +336,7 @@ class BaseVideoConfig(ABC):
Returns:
Tuple[str, Dict]: (url, data) for the POST request
"""
raise NotImplementedError(
"video edit is not supported for this provider"
)
raise NotImplementedError("video edit is not supported for this provider")
def transform_video_edit_response(
self,
@ -346,9 +344,7 @@ class BaseVideoConfig(ABC):
logging_obj: LiteLLMLoggingObj,
custom_llm_provider: Optional[str] = None,
) -> VideoObject:
raise NotImplementedError(
"video edit is not supported for this provider"
)
raise NotImplementedError("video edit is not supported for this provider")
def transform_video_extension_request(
self,
@ -366,9 +362,7 @@ class BaseVideoConfig(ABC):
Returns:
Tuple[str, Dict]: (url, data) for the POST request
"""
raise NotImplementedError(
"video extension is not supported for this provider"
)
raise NotImplementedError("video extension is not supported for this provider")
def transform_video_extension_response(
self,
@ -376,9 +370,7 @@ class BaseVideoConfig(ABC):
logging_obj: LiteLLMLoggingObj,
custom_llm_provider: Optional[str] = None,
) -> VideoObject:
raise NotImplementedError(
"video extension is not supported for this provider"
)
raise NotImplementedError("video extension is not supported for this provider")
def get_error_class(
self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]


@ -6162,7 +6162,10 @@ class BaseLLMHTTPHandler:
litellm_params=dict(litellm_params),
)
url, files_list = video_provider_config.transform_video_create_character_request(
(
url,
files_list,
) = video_provider_config.transform_video_create_character_request(
name=name,
video=video,
api_base=api_base,
@ -6230,7 +6233,10 @@ class BaseLLMHTTPHandler:
litellm_params=dict(litellm_params),
)
url, files_list = video_provider_config.transform_video_create_character_request(
(
url,
files_list,
) = video_provider_config.transform_video_create_character_request(
name=name,
video=video,
api_base=api_base,
@ -6324,11 +6330,7 @@ class BaseLLMHTTPHandler:
)
try:
response = sync_httpx_client.get(
url=url,
headers=headers,
params=params
)
response = sync_httpx_client.get(url=url, headers=headers, params=params)
response.raise_for_status()
return video_provider_config.transform_video_get_character_response(
raw_response=response,
@ -6386,9 +6388,7 @@ class BaseLLMHTTPHandler:
try:
response = await async_httpx_client.get(
url=url,
headers=headers,
params=params
url=url, headers=headers, params=params
)
response.raise_for_status()
return video_provider_config.transform_video_get_character_response(


@ -525,28 +525,47 @@ class GeminiVideoConfig(BaseVideoConfig):
"""Video delete is not supported."""
raise NotImplementedError("Video delete is not supported by Google Veo.")
def transform_video_create_character_request(self, name, video, api_base, litellm_params, headers):
def transform_video_create_character_request(
self, name, video, api_base, litellm_params, headers
):
raise NotImplementedError("video create character is not supported for Gemini")
def transform_video_create_character_response(self, raw_response, logging_obj):
raise NotImplementedError("video create character is not supported for Gemini")
def transform_video_get_character_request(self, character_id, api_base, litellm_params, headers):
def transform_video_get_character_request(
self, character_id, api_base, litellm_params, headers
):
raise NotImplementedError("video get character is not supported for Gemini")
def transform_video_get_character_response(self, raw_response, logging_obj):
raise NotImplementedError("video get character is not supported for Gemini")
def transform_video_edit_request(self, prompt, video_id, api_base, litellm_params, headers, extra_body=None):
def transform_video_edit_request(
self, prompt, video_id, api_base, litellm_params, headers, extra_body=None
):
raise NotImplementedError("video edit is not supported for Gemini")
def transform_video_edit_response(self, raw_response, logging_obj, custom_llm_provider=None):
def transform_video_edit_response(
self, raw_response, logging_obj, custom_llm_provider=None
):
raise NotImplementedError("video edit is not supported for Gemini")
def transform_video_extension_request(self, prompt, video_id, seconds, api_base, litellm_params, headers, extra_body=None):
def transform_video_extension_request(
self,
prompt,
video_id,
seconds,
api_base,
litellm_params,
headers,
extra_body=None,
):
raise NotImplementedError("video extension is not supported for Gemini")
def transform_video_extension_response(self, raw_response, logging_obj, custom_llm_provider=None):
def transform_video_extension_response(
self, raw_response, logging_obj, custom_llm_provider=None
):
raise NotImplementedError("video extension is not supported for Gemini")
def get_error_class(


@ -19,7 +19,8 @@ class MoonshotChatConfig(OpenAIGPTConfig):
@overload
def _transform_messages(
self, messages: List[AllMessageValues], model: str, is_async: Literal[True]
) -> Coroutine[Any, Any, List[AllMessageValues]]: ...
) -> Coroutine[Any, Any, List[AllMessageValues]]:
...
@overload
def _transform_messages(
@ -27,7 +28,8 @@ class MoonshotChatConfig(OpenAIGPTConfig):
messages: List[AllMessageValues],
model: str,
is_async: Literal[False] = False,
) -> List[AllMessageValues]: ...
) -> List[AllMessageValues]:
...
def _transform_messages(
self, messages: List[AllMessageValues], model: str, is_async: bool = False
@ -53,9 +55,13 @@ class MoonshotChatConfig(OpenAIGPTConfig):
messages = handle_messages_with_content_list_to_str_conversion(messages)
if is_async:
return super()._transform_messages(messages=messages, model=model, is_async=True)
return super()._transform_messages(
messages=messages, model=model, is_async=True
)
else:
return super()._transform_messages(messages=messages, model=model, is_async=False)
return super()._transform_messages(
messages=messages, model=model, is_async=False
)
def _get_openai_compatible_provider_info(
self, api_base: Optional[str], api_key: Optional[str]
@ -141,7 +147,9 @@ class MoonshotChatConfig(OpenAIGPTConfig):
optional_params["temperature"] = 0.3
return optional_params
def fill_reasoning_content(self, messages: List[AllMessageValues]) -> List[AllMessageValues]:
def fill_reasoning_content(
self, messages: List[AllMessageValues]
) -> List[AllMessageValues]:
"""
Moonshot reasoning models require `reasoning_content` on every assistant
message that contains tool_calls (multi-turn tool-calling flows).


@ -592,28 +592,51 @@ class RunwayMLVideoConfig(BaseVideoConfig):
return video_obj
def transform_video_create_character_request(self, name, video, api_base, litellm_params, headers):
raise NotImplementedError("video create character is not supported for RunwayML")
def transform_video_create_character_request(
self, name, video, api_base, litellm_params, headers
):
raise NotImplementedError(
"video create character is not supported for RunwayML"
)
def transform_video_create_character_response(self, raw_response, logging_obj):
raise NotImplementedError("video create character is not supported for RunwayML")
raise NotImplementedError(
"video create character is not supported for RunwayML"
)
def transform_video_get_character_request(self, character_id, api_base, litellm_params, headers):
def transform_video_get_character_request(
self, character_id, api_base, litellm_params, headers
):
raise NotImplementedError("video get character is not supported for RunwayML")
def transform_video_get_character_response(self, raw_response, logging_obj):
raise NotImplementedError("video get character is not supported for RunwayML")
def transform_video_edit_request(self, prompt, video_id, api_base, litellm_params, headers, extra_body=None):
def transform_video_edit_request(
self, prompt, video_id, api_base, litellm_params, headers, extra_body=None
):
raise NotImplementedError("video edit is not supported for RunwayML")
def transform_video_edit_response(self, raw_response, logging_obj, custom_llm_provider=None):
def transform_video_edit_response(
self, raw_response, logging_obj, custom_llm_provider=None
):
raise NotImplementedError("video edit is not supported for RunwayML")
def transform_video_extension_request(self, prompt, video_id, seconds, api_base, litellm_params, headers, extra_body=None):
def transform_video_extension_request(
self,
prompt,
video_id,
seconds,
api_base,
litellm_params,
headers,
extra_body=None,
):
raise NotImplementedError("video extension is not supported for RunwayML")
def transform_video_extension_response(self, raw_response, logging_obj, custom_llm_provider=None):
def transform_video_extension_response(
self, raw_response, logging_obj, custom_llm_provider=None
):
raise NotImplementedError("video extension is not supported for RunwayML")
def get_error_class(


@ -184,9 +184,7 @@ class SagemakerChatConfig(OpenAIGPTConfig, BaseAWSLLM):
llm_provider = LlmProviders(custom_llm_provider)
except ValueError:
llm_provider = LlmProviders.SAGEMAKER_CHAT
client = get_async_httpx_client(
llm_provider=llm_provider, params={}
)
client = get_async_httpx_client(llm_provider=llm_provider, params={})
try:
response = await client.post(


@ -142,8 +142,8 @@ class VertexAIBatchTransformation:
Gets the output file id from the Vertex AI Batch response
"""
output_file_id: str = (
response.get("outputInfo", OutputInfo()).get("gcsOutputDirectory", "")
output_file_id: str = response.get("outputInfo", OutputInfo()).get(
"gcsOutputDirectory", ""
)
if output_file_id:
output_file_id = output_file_id.rstrip("/") + "/predictions.jsonl"
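
The path handling in this hunk can be sketched in isolation as follows. This is a minimal illustration, not the Vertex AI transformation class itself: the function name `get_output_file_id` is hypothetical, a plain dict stands in for the `OutputInfo` default, and returning an empty string when no directory is present is an assumption.

```python
from typing import Any, Dict


def get_output_file_id(response: Dict[str, Any]) -> str:
    """Hypothetical stand-in for the batch output lookup shown in the hunk."""
    # A missing "outputInfo" falls back to an empty dict, so the chained
    # .get() never raises.
    output_dir: str = response.get("outputInfo", {}).get("gcsOutputDirectory", "")
    if output_dir:
        # Strip trailing slashes first so the result has exactly one separator.
        return output_dir.rstrip("/") + "/predictions.jsonl"
    # Assumption: with no output directory there is nothing to point at.
    return ""
```

Stripping the trailing slash before appending keeps `gs://bucket/out` and `gs://bucket/out/` from producing different object paths.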


@ -624,28 +624,51 @@ class VertexAIVideoConfig(BaseVideoConfig, VertexBase):
"""Video delete is not supported."""
raise NotImplementedError("Video delete is not supported by Vertex AI Veo.")
def transform_video_create_character_request(self, name, video, api_base, litellm_params, headers):
raise NotImplementedError("video create character is not supported for Vertex AI")
def transform_video_create_character_request(
self, name, video, api_base, litellm_params, headers
):
raise NotImplementedError(
"video create character is not supported for Vertex AI"
)
def transform_video_create_character_response(self, raw_response, logging_obj):
raise NotImplementedError("video create character is not supported for Vertex AI")
raise NotImplementedError(
"video create character is not supported for Vertex AI"
)
def transform_video_get_character_request(self, character_id, api_base, litellm_params, headers):
def transform_video_get_character_request(
self, character_id, api_base, litellm_params, headers
):
raise NotImplementedError("video get character is not supported for Vertex AI")
def transform_video_get_character_response(self, raw_response, logging_obj):
raise NotImplementedError("video get character is not supported for Vertex AI")
def transform_video_edit_request(self, prompt, video_id, api_base, litellm_params, headers, extra_body=None):
def transform_video_edit_request(
self, prompt, video_id, api_base, litellm_params, headers, extra_body=None
):
raise NotImplementedError("video edit is not supported for Vertex AI")
def transform_video_edit_response(self, raw_response, logging_obj, custom_llm_provider=None):
def transform_video_edit_response(
self, raw_response, logging_obj, custom_llm_provider=None
):
raise NotImplementedError("video edit is not supported for Vertex AI")
def transform_video_extension_request(self, prompt, video_id, seconds, api_base, litellm_params, headers, extra_body=None):
def transform_video_extension_request(
self,
prompt,
video_id,
seconds,
api_base,
litellm_params,
headers,
extra_body=None,
):
raise NotImplementedError("video extension is not supported for Vertex AI")
def transform_video_extension_response(self, raw_response, logging_obj, custom_llm_provider=None):
def transform_video_extension_response(
self, raw_response, logging_obj, custom_llm_provider=None
):
raise NotImplementedError("video extension is not supported for Vertex AI")
def get_error_class(


@ -7533,9 +7533,7 @@ def stream_chunk_builder( # noqa: PLR0915
# the final chunk.
all_annotations: list = []
for ac in annotation_chunks:
all_annotations.extend(
ac["choices"][0]["delta"]["annotations"]
)
all_annotations.extend(ac["choices"][0]["delta"]["annotations"])
response["choices"][0]["message"]["annotations"] = all_annotations
audio_chunks = [


@ -32354,6 +32354,53 @@
"supports_vision": true,
"supports_web_search": true
},
"xai/grok-4.20-multi-agent-beta-0309": {
"cache_read_input_token_cost": 2e-07,
"input_cost_per_token": 2e-06,
"litellm_provider": "xai",
"max_input_tokens": 2000000,
"max_output_tokens": 2000000,
"max_tokens": 2000000,
"mode": "chat",
"output_cost_per_token": 6e-06,
"source": "https://docs.x.ai/docs/models",
"supports_function_calling": true,
"supports_reasoning": true,
"supports_tool_choice": true,
"supports_vision": true,
"supports_web_search": true
},
"xai/grok-4.20-beta-0309-reasoning": {
"cache_read_input_token_cost": 2e-07,
"input_cost_per_token": 2e-06,
"litellm_provider": "xai",
"max_input_tokens": 2000000,
"max_output_tokens": 2000000,
"max_tokens": 2000000,
"mode": "chat",
"output_cost_per_token": 6e-06,
"source": "https://docs.x.ai/docs/models",
"supports_function_calling": true,
"supports_reasoning": true,
"supports_tool_choice": true,
"supports_vision": true,
"supports_web_search": true
},
"xai/grok-4.20-beta-0309-non-reasoning": {
"cache_read_input_token_cost": 2e-07,
"input_cost_per_token": 2e-06,
"litellm_provider": "xai",
"max_input_tokens": 2000000,
"max_output_tokens": 2000000,
"max_tokens": 2000000,
"mode": "chat",
"output_cost_per_token": 6e-06,
"source": "https://docs.x.ai/docs/models",
"supports_function_calling": true,
"supports_tool_choice": true,
"supports_vision": true,
"supports_web_search": true
},
"xai/grok-beta": {
"input_cost_per_token": 5e-06,
"litellm_provider": "xai",
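
To see how the per-token prices in the new grok-4.20 entries translate into a request cost, here is a rough sketch. The `estimate_cost` helper and its parameters are hypothetical and are not LiteLLM's cost calculator; it assumes cached prompt tokens are billed at `cache_read_input_token_cost` instead of `input_cost_per_token`.

```python
# Prices copied from the grok-4.20 beta entries above (USD per token).
GROK_4_20_PRICING = {
    "input_cost_per_token": 2e-06,
    "cache_read_input_token_cost": 2e-07,
    "output_cost_per_token": 6e-06,
}


def estimate_cost(
    pricing: dict,
    prompt_tokens: int,
    completion_tokens: int,
    cached_tokens: int = 0,
) -> float:
    """Hypothetical helper: estimate the USD cost of one request."""
    if cached_tokens > prompt_tokens:
        raise ValueError("cached_tokens cannot exceed prompt_tokens")
    # Assumption: cached tokens replace, rather than add to, the normal
    # input charge for those tokens.
    uncached_tokens = prompt_tokens - cached_tokens
    return (
        uncached_tokens * pricing["input_cost_per_token"]
        + cached_tokens * pricing["cache_read_input_token_cost"]
        + completion_tokens * pricing["output_cost_per_token"]
    )
```

With these numbers, one million uncached input tokens cost about $2.00 and one million output tokens about $6.00, matching the per-1M pricing quoted in the commit message.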


@ -677,7 +677,60 @@ async def oauth_authorization_server_mcp(
# Alias for standard OpenID discovery
@router.get("/.well-known/openid-configuration")
async def openid_configuration(request: Request):
return await oauth_authorization_server_mcp(request)
response = await oauth_authorization_server_mcp(request)
# If MCPJWTSigner is active, augment the discovery doc with JWKS fields so
# MCP servers and gateways (e.g. AWS Bedrock AgentCore Gateway) can resolve
# the signing keys and verify liteLLM-issued tokens.
try:
from litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer import (
get_mcp_jwt_signer,
)
signer = get_mcp_jwt_signer()
if signer is not None:
request_base_url = get_request_base_url(request)
if isinstance(response, dict):
response = {
**response,
"jwks_uri": f"{request_base_url}/.well-known/jwks.json",
"id_token_signing_alg_values_supported": ["RS256"],
}
except ImportError:
pass
return response
@router.get("/.well-known/jwks.json")
async def jwks_json(request: Request):
"""
JSON Web Key Set endpoint.
Returns the RSA public key used by MCPJWTSigner to sign outbound MCP tokens.
MCP servers and gateways use this endpoint to verify liteLLM-issued JWTs.
Returns an empty key set if MCPJWTSigner is not configured.
"""
try:
from litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer import (
get_mcp_jwt_signer,
)
signer = get_mcp_jwt_signer()
if signer is not None:
return JSONResponse(
content=signer.get_jwks(),
headers={"Cache-Control": f"public, max-age={signer.jwks_max_age}"},
)
except ImportError:
pass
# No signer active — return empty key set; short cache so activation is picked up quickly.
return JSONResponse(
content={"keys": []},
headers={"Cache-Control": "public, max-age=60"},
)
# Additional legacy pattern support
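
The discovery-document augmentation in `openid_configuration` can be sketched as a pure function. This is an illustration under stated assumptions, not the proxy's code: `augment_discovery_doc` and the `signer_active` flag are hypothetical names standing in for the `get_mcp_jwt_signer()` check.

```python
from typing import Any, Dict


def augment_discovery_doc(
    doc: Dict[str, Any], base_url: str, signer_active: bool
) -> Dict[str, Any]:
    """Hypothetical helper mirroring the JWKS augmentation in the hunk above."""
    if not signer_active:
        # No signer configured: the discovery document is returned untouched.
        return doc
    # Build a new dict rather than mutating the caller's document.
    return {
        **doc,
        "jwks_uri": f"{base_url}/.well-known/jwks.json",
        "id_token_signing_alg_values_supported": ["RS256"],
    }
```

Returning a copy means a document shared with the `oauth_authorization_server_mcp` path is never modified in place.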


@ -1908,7 +1908,15 @@ class MCPServerManager:
user_api_key_auth: Optional[UserAPIKeyAuth],
proxy_logging_obj: ProxyLogging,
server: MCPServer,
):
raw_headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
"""
Run pre-call checks and guardrail hooks for an MCP tool call.
Returns a dict that may contain:
- "arguments": hook-modified tool arguments (only if changed)
- "extra_headers": headers injected by pre_mcp_call guardrail hooks
"""
## check if the tool is allowed or banned for the given server
if not self.check_allowed_or_banned_tools(name, server):
raise HTTPException(
@ -1932,6 +1940,14 @@ class MCPServerManager:
server=server,
)
# Extract incoming Bearer token from raw request headers so
# guardrails like MCPJWTSigner can verify + re-sign it (FR-5).
normalized_raw = {k.lower(): v for k, v in (raw_headers or {}).items()}
incoming_bearer_token: Optional[str] = None
auth_hdr = normalized_raw.get("authorization", "")
if auth_hdr.lower().startswith("bearer "):
incoming_bearer_token = auth_hdr[len("bearer ") :]
pre_hook_kwargs = {
"name": name,
"arguments": arguments,
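
The Bearer-token extraction added in this hunk can be exercised on its own. The sketch below lifts the same logic into a standalone function; the name `extract_bearer_token` is hypothetical and does not exist in `MCPServerManager`.

```python
from typing import Dict, Optional


def extract_bearer_token(raw_headers: Optional[Dict[str, str]]) -> Optional[str]:
    """Hypothetical helper: pull the token out of an Authorization header."""
    # Lowercase the header names so "Authorization" and "authorization"
    # are treated the same.
    normalized = {k.lower(): v for k, v in (raw_headers or {}).items()}
    auth_hdr = normalized.get("authorization", "")
    # The scheme comparison is case-insensitive; the token keeps its case.
    if auth_hdr.lower().startswith("bearer "):
        return auth_hdr[len("bearer "):]
    return None
```

Non-Bearer schemes such as `Basic` yield `None`, so a guardrail receiving `incoming_bearer_token` only ever sees a Bearer credential or nothing.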
@ -1957,6 +1973,7 @@ class MCPServerManager:
if user_api_key_auth
else None
),
"incoming_bearer_token": incoming_bearer_token,
}
# Create MCP request object for processing
@ -1969,6 +1986,7 @@ class MCPServerManager:
mcp_request_obj, pre_hook_kwargs
)
hook_result: Dict[str, Any] = {}
try:
# Use standard pre_call_hook
modified_data = await proxy_logging_obj.pre_call_hook(
@ -1984,7 +2002,9 @@ class MCPServerManager:
)
)
if modified_kwargs.get("arguments") != arguments:
arguments = modified_kwargs["arguments"]
hook_result["arguments"] = modified_kwargs["arguments"]
if modified_kwargs.get("extra_headers"):
hook_result["extra_headers"] = modified_kwargs["extra_headers"]
except (
BlockedPiiEntityError,
@ -1995,6 +2015,8 @@ class MCPServerManager:
verbose_logger.error(f"Guardrail blocked MCP tool call pre call: {str(e)}")
raise e
return hook_result
def _create_during_hook_task(
self,
name: str,
@ -2047,6 +2069,7 @@ class MCPServerManager:
raw_headers: Optional[Dict[str, str]],
proxy_logging_obj: Optional[ProxyLogging],
host_progress_callback: Optional[Callable] = None,
hook_extra_headers: Optional[Dict[str, str]] = None,
) -> CallToolResult:
"""
Call a regular MCP tool using the MCP client.
@ -2061,6 +2084,9 @@ class MCPServerManager:
oauth2_headers: Optional OAuth2 headers
raw_headers: Optional raw headers from the request
proxy_logging_obj: Optional ProxyLogging object for hook integration
host_progress_callback: Optional callback for progress updates
hook_extra_headers: Optional headers injected by pre_mcp_call guardrail
hooks. Merged last (highest priority) into outbound request headers.
Returns:
CallToolResult from the MCP server
@ -2116,6 +2142,31 @@ class MCPServerManager:
extra_headers = {}
extra_headers.update(mcp_server.static_headers)
if hook_extra_headers:
if extra_headers is None:
extra_headers = {}
if "Authorization" in hook_extra_headers:
if "Authorization" in extra_headers:
verbose_logger.warning(
"MCPServerManager: hook_extra_headers 'Authorization' will overwrite "
"the existing Authorization header from static_headers. "
"The hook JWT will take precedence."
)
elif server_auth_header is not None:
# server_auth_header is passed separately to _create_mcp_client as
# auth_value. Both will reach the upstream server — warn so admins
# know two Authorization credentials are being sent.
verbose_logger.warning(
"MCPServerManager: hook_extra_headers injects 'Authorization' while "
"server '%s' already has a configured authentication_token. "
"Both credentials will be sent; the hook header is in extra_headers "
"and the server token is in auth_value — the upstream server decides "
"which one wins. Consider unsetting authentication_token if you want "
"the hook JWT to be the sole credential.",
mcp_server.server_name or mcp_server.name,
)
extra_headers.update(hook_extra_headers)
stdio_env = self._build_stdio_env(mcp_server, raw_headers)
client = await self._create_mcp_client(
@ -2201,15 +2252,19 @@ class MCPServerManager:
# Allow validation and modification of tool calls before execution
# Using standard pre_call_hook
#########################################################
hook_result: Dict[str, Any] = {}
if proxy_logging_obj:
- await self.pre_call_tool_check(
+ hook_result = await self.pre_call_tool_check(
name=name,
arguments=arguments,
server_name=server_name,
user_api_key_auth=user_api_key_auth,
proxy_logging_obj=proxy_logging_obj,
server=mcp_server,
raw_headers=raw_headers,
)
if "arguments" in hook_result:
arguments = hook_result["arguments"]
# Prepare tasks for during hooks
tasks = []
@ -2227,8 +2282,16 @@ class MCPServerManager:
# For OpenAPI servers, call the tool handler directly instead of via MCP client
if mcp_server.spec_path:
verbose_logger.debug(
- f"Calling OpenAPI tool {name} directly via HTTP handler"
+ "Calling OpenAPI tool %s directly via HTTP handler", name
)
if hook_result.get("extra_headers"):
verbose_logger.warning(
"pre_mcp_call hook returned extra_headers for OpenAPI-backed "
"MCP server '%s' — header injection is not supported for "
"OpenAPI servers; headers will be ignored. Use SSE/HTTP "
"transport to enable hook header injection.",
server_name,
)
tasks.append(
asyncio.create_task(
self._call_openapi_tool_handler(mcp_server, name, arguments)
@ -2247,6 +2310,7 @@ class MCPServerManager:
raw_headers=raw_headers,
proxy_logging_obj=proxy_logging_obj,
host_progress_callback=host_progress_callback,
hook_extra_headers=hook_result.get("extra_headers"),
)
# For OpenAPI tools, await outside the client context
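
The header precedence in this hunk (static headers first, hook headers merged last) can be illustrated with a small standalone sketch. It assumes plain dicts, and `merge_outbound_headers` is a hypothetical helper written for this illustration, not a function in `MCPServerManager`.

```python
from typing import Dict, Optional


def merge_outbound_headers(
    static_headers: Optional[Dict[str, str]],
    hook_extra_headers: Optional[Dict[str, str]],
) -> Dict[str, str]:
    """Hypothetical helper: merge static headers first, hook headers last."""
    merged: Dict[str, str] = {}
    if static_headers:
        merged.update(static_headers)
    if hook_extra_headers:
        # Applied last, so hook headers win on key collisions.
        merged.update(hook_extra_headers)
    return merged


headers = merge_outbound_headers(
    {"Authorization": "Bearer static-token", "X-Team": "alpha"},
    {"Authorization": "Bearer hook-jwt"},
)
print(headers)
```

Because `dict.update` matches keys case-sensitively, a hook header spelled `authorization` would not replace a static `Authorization` entry in this sketch; both would be kept.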

@ -903,12 +903,12 @@ if MCP_AVAILABLE:
try:
client_id, client_secret, scopes = _extract_credentials(request)
- _oauth2_flow: Optional[Literal["client_credentials", "authorization_code"]] = (
- request.oauth2_flow or (
- "client_credentials"
- if client_id and client_secret and request.token_url
- else None
- )
+ _oauth2_flow: Optional[
+ Literal["client_credentials", "authorization_code"]
+ ] = request.oauth2_flow or (
+ "client_credentials"
+ if client_id and client_secret and request.token_url
+ else None
)
# client_credentials requires token_url to fetch a token; without it the
# incoming auth header would be dropped with nothing to replace it.

@ -2471,6 +2471,9 @@ class UserAPIKeyAuth(
Any
] = None # Expanded created_by user when expand=user is used
end_user_object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
# Decoded upstream IdP claims (groups, roles, etc.) propagated by JWT auth machinery
# and forwarded into outbound tokens by guardrails such as MCPJWTSigner.
jwt_claims: Optional[Dict] = None
model_config = ConfigDict(arbitrary_types_allowed=True)

@ -680,7 +680,7 @@ def get_customer_user_header_from_mapping(user_id_mapping) -> Optional[list]:
if customer_headers_mappings:
return customer_headers_mappings
return None
@ -754,15 +754,11 @@ def get_end_user_id_from_request_body(
user_id_str = str(header_value)
if user_id_str.strip():
return user_id_str
elif isinstance(custom_header_name_to_check, str):
for header_name, header_value in request_headers.items():
if header_name.lower() == custom_header_name_to_check.lower():
- user_id_str = (
- str(header_value)
- if header_value is not None
- else ""
- )
+ user_id_str = str(header_value) if header_value is not None else ""
if user_id_str.strip():
return user_id_str

@ -685,6 +685,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
do_standard_jwt_auth = True
if jwt_handler.litellm_jwtauth.virtual_key_claim_field is not None:
# Decode JWT to get claims without running full auth_builder
jwt_claims: Optional[dict]
if jwt_handler.litellm_jwtauth.oidc_userinfo_enabled:
jwt_claims = await jwt_handler.get_oidc_userinfo(token=api_key)
else:
@ -700,6 +701,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
)
if valid_token is not None:
api_key = valid_token.token or ""
valid_token.jwt_claims = jwt_claims
do_standard_jwt_auth = False
# Fall through to virtual key checks
@ -729,6 +731,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
team_membership: Optional[LiteLLM_TeamMembership] = result.get(
"team_membership", None
)
jwt_claims = result.get("jwt_claims", None)
global_proxy_spend = await get_global_proxy_spend(
litellm_proxy_admin_name=litellm_proxy_admin_name,
@ -757,6 +760,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
org_id=org_id,
end_user_id=end_user_id,
parent_otel_span=parent_otel_span,
jwt_claims=jwt_claims,
)
valid_token = UserAPIKeyAuth(
@ -803,6 +807,7 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
team_metadata=(
team_object.metadata if team_object is not None else None
),
jwt_claims=jwt_claims,
)
# Check if model has zero cost - if so, skip all budget checks

@ -537,9 +537,10 @@ async def retrieve_batch( # noqa: PLR0915
)
# Fix: bug_feb14_batch_retrieve_returns_raw_input_file_id
- # Resolve raw provider input_file_id to unified ID.
+ # Resolve raw provider file IDs (input, output, error) to unified IDs.
if unified_batch_id:
await resolve_input_file_id_to_unified(response, prisma_client)
await resolve_output_file_ids_to_unified(response, prisma_client)
### ALERTING ###
asyncio.create_task(

@ -0,0 +1,84 @@
"""MCP JWT Signer guardrail — built-in LiteLLM guardrail for zero trust MCP auth."""
from typing import TYPE_CHECKING
from litellm.types.guardrails import SupportedGuardrailIntegrations
from .mcp_jwt_signer import MCPJWTSigner, get_mcp_jwt_signer
if TYPE_CHECKING:
from litellm.types.guardrails import Guardrail, LitellmParams
def initialize_guardrail(
litellm_params: "LitellmParams", guardrail: "Guardrail"
) -> MCPJWTSigner:
import litellm
guardrail_name = guardrail.get("guardrail_name")
if not guardrail_name:
raise ValueError("MCPJWTSigner guardrail requires a guardrail_name")
mode = litellm_params.mode
if mode != "pre_mcp_call":
raise ValueError(
f"MCPJWTSigner guardrail '{guardrail_name}' has mode='{mode}' but must use "
"mode='pre_mcp_call'. JWT injection only fires for MCP tool calls."
)
optional_params = getattr(litellm_params, "optional_params", None)
def _get(key): # type: ignore[no-untyped-def]
if optional_params is not None:
v = getattr(optional_params, key, None)
if v is not None:
return v
return getattr(litellm_params, key, None)
signer = MCPJWTSigner(
guardrail_name=guardrail_name,
event_hook=litellm_params.mode,
default_on=litellm_params.default_on,
# Core signing
issuer=_get("issuer"),
audience=_get("audience"),
ttl_seconds=_get("ttl_seconds"),
# FR-5: verify + re-sign
access_token_discovery_uri=_get("access_token_discovery_uri"),
token_introspection_endpoint=_get("token_introspection_endpoint"),
verify_issuer=_get("verify_issuer"),
verify_audience=_get("verify_audience"),
# FR-12: end-user identity mapping
end_user_claim_sources=_get("end_user_claim_sources"),
# FR-13: claim operations
add_claims=_get("add_claims"),
set_claims=_get("set_claims"),
remove_claims=_get("remove_claims"),
# FR-14: two-token model
channel_token_audience=_get("channel_token_audience"),
channel_token_ttl=_get("channel_token_ttl"),
# FR-15: incoming claim validation
required_claims=_get("required_claims"),
optional_claims=_get("optional_claims"),
# FR-9: debug headers
debug_headers=_get("debug_headers") or False,
# FR-10: configurable scopes
allowed_scopes=_get("allowed_scopes"),
)
litellm.logging_callback_manager.add_litellm_callback(signer)
return signer
guardrail_initializer_registry = {
SupportedGuardrailIntegrations.MCP_JWT_SIGNER.value: initialize_guardrail,
}
guardrail_class_registry = {
SupportedGuardrailIntegrations.MCP_JWT_SIGNER.value: MCPJWTSigner,
}
__all__ = [
"MCPJWTSigner",
"initialize_guardrail",
"get_mcp_jwt_signer",
]
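
The `_get` lookup used by `initialize_guardrail` prefers a non-None value on `optional_params` and otherwise falls back to the top-level params object. A minimal sketch of that fallback, using `SimpleNamespace` stand-ins instead of the real `LitellmParams` type (the `lookup_param` name and the sample values are assumptions for illustration):

```python
from types import SimpleNamespace
from typing import Any, Optional


def lookup_param(litellm_params: Any, key: str) -> Optional[Any]:
    """Hypothetical stand-in for the nested _get helper."""
    optional_params = getattr(litellm_params, "optional_params", None)
    if optional_params is not None:
        value = getattr(optional_params, key, None)
        if value is not None:
            # A non-None value on optional_params takes priority.
            return value
    # Otherwise fall back to the top-level attribute (or None).
    return getattr(litellm_params, key, None)


params = SimpleNamespace(
    issuer="https://top-level.example.com",
    optional_params=SimpleNamespace(issuer=None, audience="mcp-tools"),
)
print(lookup_param(params, "issuer"))
print(lookup_param(params, "audience"))
print(lookup_param(params, "ttl_seconds"))
```

A key that is missing from both places resolves to None, which lets the `MCPJWTSigner` constructor apply its own defaults.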

@ -0,0 +1,891 @@
"""
MCPJWTSigner: built-in LiteLLM guardrail for zero trust MCP authentication.
Signs outbound MCP requests with a LiteLLM-issued RS256 JWT so that MCP servers
can trust a single signing authority (LiteLLM) instead of every upstream IdP.
Usage in config.yaml:
guardrails:
- guardrail_name: "mcp-jwt-signer"
litellm_params:
guardrail: mcp_jwt_signer
mode: "pre_mcp_call"
default_on: true
# Core signing config
issuer: "https://my-litellm.example.com" # optional
audience: "mcp" # optional
ttl_seconds: 300 # optional
# FR-5: Verify + re-sign — validate incoming Bearer token before signing
access_token_discovery_uri: "https://idp.example.com/.well-known/openid-configuration"
token_introspection_endpoint: "https://idp.example.com/introspect" # opaque tokens
verify_issuer: "https://idp.example.com" # expected iss in incoming JWT
verify_audience: "api://my-app" # expected aud in incoming JWT
# FR-12: End-user identity mapping — ordered resolution chain
# Supported: token:<claim>, litellm:user_id, litellm:email,
# litellm:end_user_id, litellm:team_id
end_user_claim_sources:
- "token:sub"
- "token:email"
- "litellm:user_id"
# FR-13: Claim operations
add_claims: # add if key not already present in the JWT
deployment_id: "prod-001"
set_claims: # always set (overrides computed value)
env: "production"
remove_claims: # remove from final JWT
- "nbf"
# FR-14: Two-token model — issue a second JWT for the MCP transport channel
channel_token_audience: "bedrock-gateway"
channel_token_ttl: 60
# FR-15: Incoming claim validation — enforce required IdP claims
required_claims:
- "sub"
- "email"
optional_claims: # pass through from jwt_claims into outbound JWT
- "groups"
- "roles"
# FR-9: Debug headers
debug_headers: false # emit x-litellm-mcp-debug header when true
# FR-10: Configurable scopes — explicit list replaces auto-generation
allowed_scopes:
- "mcp:tools/call"
- "mcp:tools/list"
MCP servers verify tokens via:
GET /.well-known/openid-configuration → { jwks_uri: ".../.well-known/jwks.json" }
GET /.well-known/jwks.json → RSA public key in JWKS format
Optionally set MCP_JWT_SIGNING_KEY env var (PEM string or file:///path) to use
your own RSA keypair. If unset, an RSA-2048 keypair is auto-generated at startup.
"""
import base64
import hashlib
import os
import re
import time
from typing import Any, Dict, List, Optional, Union
import jwt
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from litellm._logging import verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_guardrail import (
CustomGuardrail,
log_guardrail_information,
)
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.utils import CallTypesLiteral
# Module-level singleton for the JWKS discovery endpoint to access.
_mcp_jwt_signer_instance: Optional["MCPJWTSigner"] = None
# Simple in-memory JWKS cache: keyed by JWKS URI → (keys_list, fetched_at).
_jwks_cache: Dict[str, tuple] = {}
_JWKS_CACHE_TTL = 3600 # 1 hour
def get_mcp_jwt_signer() -> Optional["MCPJWTSigner"]:
"""Return the active MCPJWTSigner singleton, or None if not initialized."""
return _mcp_jwt_signer_instance
def _load_private_key_from_env(env_var: str) -> RSAPrivateKey:
"""Load an RSA private key from an env var (PEM string or file:// path)."""
key_material = os.environ.get(env_var, "")
if not key_material:
raise ValueError(
f"MCPJWTSigner: environment variable '{env_var}' is unset or empty."
)
if key_material.startswith("file://"):
path = key_material[len("file://") :]
with open(path, "rb") as f:
key_bytes = f.read()
else:
key_bytes = key_material.encode("utf-8")
return serialization.load_pem_private_key(key_bytes, password=None) # type: ignore[return-value]
def _generate_rsa_key_pair() -> RSAPrivateKey:
"""Generate a new RSA-2048 private key."""
return rsa.generate_private_key(
public_exponent=65537,
key_size=2048,
)
def _int_to_base64url(n: int) -> str:
"""Encode an integer as a base64url string (no padding)."""
byte_length = (n.bit_length() + 7) // 8
return (
base64.urlsafe_b64encode(n.to_bytes(byte_length, byteorder="big"))
.rstrip(b"=")
.decode("ascii")
)
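
As a quick check of the base64url integer encoding above, the common RSA public exponent 65537 (bytes `01 00 01`) encodes to `AQAB`, the value usually seen in the `e` field of a JWKS entry. The sketch below restates the encoder under a public name and adds a decoder; `base64url_to_int` is a hypothetical helper introduced only for the round trip.

```python
import base64


def int_to_base64url(n: int) -> str:
    """Encode a non-negative integer as unpadded base64url (big-endian)."""
    byte_length = (n.bit_length() + 7) // 8
    raw = n.to_bytes(byte_length, byteorder="big")
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


def base64url_to_int(value: str) -> int:
    """Hypothetical inverse: restore padding, decode, read as big-endian."""
    padded = value + "=" * (-len(value) % 4)
    return int.from_bytes(base64.urlsafe_b64decode(padded), byteorder="big")


encoded_exponent = int_to_base64url(65537)
print(encoded_exponent)
```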
def _compute_kid(public_key: Any) -> str:
"""Derive a key ID from the public key's DER encoding (SHA-256, first 16 hex chars)."""
der_bytes = public_key.public_bytes(
encoding=serialization.Encoding.DER,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
)
return hashlib.sha256(der_bytes).hexdigest()[:16]
async def _fetch_jwks(jwks_uri: str) -> List[Dict[str, Any]]:
"""
Fetch and cache a JWKS from the given URI.
Results are cached for _JWKS_CACHE_TTL seconds to avoid hammering the IdP.
"""
now = time.time()
cached = _jwks_cache.get(jwks_uri)
if cached is not None:
keys, fetched_at = cached
if now - fetched_at < _JWKS_CACHE_TTL:
return keys # type: ignore[return-value]
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
resp = await client.get(jwks_uri, headers={"Accept": "application/json"})
resp.raise_for_status()
keys = resp.json().get("keys", [])
_jwks_cache[jwks_uri] = (keys, now)
return keys # type: ignore[return-value]
async def _fetch_oidc_discovery(discovery_uri: str) -> Dict[str, Any]:
"""Fetch an OIDC discovery document and return its parsed JSON."""
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
resp = await client.get(discovery_uri, headers={"Accept": "application/json"})
resp.raise_for_status()
return resp.json() # type: ignore[return-value]
class MCPJWTSigner(CustomGuardrail):
"""
Built-in LiteLLM guardrail that signs outbound MCP requests with a
LiteLLM-issued RS256 JWT, enabling zero trust authentication.
MCP servers verify tokens using LiteLLM's OIDC discovery endpoint and
JWKS endpoint rather than trusting each upstream IdP directly.
The signed JWT carries:
- iss: LiteLLM issuer identifier
- aud: MCP audience (configurable)
- sub: End-user identity (resolved via end_user_claim_sources, RFC 8693)
- act: Actor/agent identity (team_id or org_id, RFC 8693 delegation)
- scope: Tool-level access scopes (configurable via allowed_scopes)
- iat, exp, nbf: Standard timing claims
Feature set:
FR-5: Verify + re-sign (access_token_discovery_uri, token_introspection_endpoint)
FR-9: Debug headers (debug_headers)
FR-10: Configurable scopes (allowed_scopes)
FR-12: Configurable end-user identity mapping (end_user_claim_sources)
FR-13: Claim operations (add_claims, set_claims, remove_claims)
FR-14: Two-token model (channel_token_audience, channel_token_ttl)
FR-15: Incoming claim validation (required_claims, optional_claims)
"""
ALGORITHM = "RS256"
DEFAULT_TTL = 300
DEFAULT_AUDIENCE = "mcp"
SIGNING_KEY_ENV = "MCP_JWT_SIGNING_KEY"
def __init__(
self,
# Core signing config
issuer: Optional[str] = None,
audience: Optional[str] = None,
ttl_seconds: Optional[int] = None,
# FR-5: Verify + re-sign
access_token_discovery_uri: Optional[str] = None,
token_introspection_endpoint: Optional[str] = None,
verify_issuer: Optional[str] = None,
verify_audience: Optional[str] = None,
# FR-12: End-user identity mapping
end_user_claim_sources: Optional[List[str]] = None,
# FR-13: Claim operations
add_claims: Optional[Dict[str, Any]] = None,
set_claims: Optional[Dict[str, Any]] = None,
remove_claims: Optional[List[str]] = None,
# FR-14: Two-token model
channel_token_audience: Optional[str] = None,
channel_token_ttl: Optional[int] = None,
# FR-15: Incoming claim validation
required_claims: Optional[List[str]] = None,
optional_claims: Optional[List[str]] = None,
# FR-9: Debug headers
debug_headers: bool = False,
# FR-10: Configurable scopes
allowed_scopes: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
# --- Signing key setup ---
key_material = os.environ.get(self.SIGNING_KEY_ENV)
if key_material:
self._private_key = _load_private_key_from_env(self.SIGNING_KEY_ENV)
self._persistent_key: bool = True
verbose_proxy_logger.info(
"MCPJWTSigner: loaded RSA key from env var %s", self.SIGNING_KEY_ENV
)
else:
self._private_key = _generate_rsa_key_pair()
self._persistent_key = False
verbose_proxy_logger.info(
"MCPJWTSigner: auto-generated RSA-2048 keypair (set %s to use your own key)",
self.SIGNING_KEY_ENV,
)
self._public_key = self._private_key.public_key()
self._kid = _compute_kid(self._public_key)
# --- Core config ---
self.issuer: str = (
issuer
or os.environ.get("MCP_JWT_ISSUER")
or os.environ.get("LITELLM_EXTERNAL_URL")
or "litellm"
)
self.audience: str = (
audience or os.environ.get("MCP_JWT_AUDIENCE") or self.DEFAULT_AUDIENCE
)
resolved_ttl = int(
ttl_seconds
if ttl_seconds is not None
else os.environ.get("MCP_JWT_TTL_SECONDS", str(self.DEFAULT_TTL))
)
if resolved_ttl <= 0:
raise ValueError(
f"MCPJWTSigner: ttl_seconds must be > 0, got {resolved_ttl}"
)
self.ttl_seconds: int = resolved_ttl
# --- FR-5: Verify + re-sign ---
self.access_token_discovery_uri: Optional[str] = access_token_discovery_uri
self.token_introspection_endpoint: Optional[str] = token_introspection_endpoint
self.verify_issuer: Optional[str] = verify_issuer
self.verify_audience: Optional[str] = verify_audience
# Cached OIDC discovery document (fetched lazily, TTL = 24 h)
self._oidc_discovery_doc: Optional[Dict[str, Any]] = None
self._oidc_discovery_fetched_at: float = 0.0
# --- FR-12: End-user identity mapping ---
# Default chain: try incoming JWT sub, fall back to litellm user_id
self.end_user_claim_sources: List[str] = end_user_claim_sources or [
"token:sub",
"litellm:user_id",
]
# --- FR-13: Claim operations ---
self.add_claims: Dict[str, Any] = add_claims or {}
self.set_claims: Dict[str, Any] = set_claims or {}
self.remove_claims: List[str] = remove_claims or []
# --- FR-14: Two-token model ---
self.channel_token_audience: Optional[str] = channel_token_audience
self.channel_token_ttl: int = (
channel_token_ttl if channel_token_ttl is not None else self.ttl_seconds
)
# --- FR-15: Incoming claim validation ---
self.required_claims: List[str] = required_claims or []
self.optional_claims: List[str] = optional_claims or []
# --- FR-9: Debug headers ---
self.debug_headers: bool = debug_headers
# --- FR-10: Configurable scopes ---
self.allowed_scopes: Optional[List[str]] = allowed_scopes
# Register singleton for JWKS/OIDC discovery endpoints.
global _mcp_jwt_signer_instance
if _mcp_jwt_signer_instance is not None:
verbose_proxy_logger.warning(
"MCPJWTSigner: replacing existing singleton — previously issued tokens "
"signed with the old key will fail JWKS verification. "
"Avoid configuring multiple mcp_jwt_signer guardrails."
)
_mcp_jwt_signer_instance = self
verbose_proxy_logger.info(
"MCPJWTSigner initialized: issuer=%s audience=%s ttl=%ds kid=%s "
"verify=%s channel_token=%s debug=%s",
self.issuer,
self.audience,
self.ttl_seconds,
self._kid,
bool(self.access_token_discovery_uri),
bool(self.channel_token_audience),
self.debug_headers,
)
# ------------------------------------------------------------------
# Public helpers (used by /.well-known/jwks.json endpoint)
# ------------------------------------------------------------------
@property
def jwks_max_age(self) -> int:
"""
Recommended Cache-Control max-age for the JWKS response (seconds).
1 hour for persistent keys; 5 minutes for auto-generated keys so MCP
servers re-fetch quickly after a proxy restart.
"""
return 3600 if self._persistent_key else 300
def get_jwks(self) -> Dict[str, Any]:
"""
Return the JWKS for the RSA public key.
Used by GET /.well-known/jwks.json so MCP servers can verify tokens.
"""
public_numbers = self._public_key.public_numbers()
return {
"keys": [
{
"kty": "RSA",
"alg": self.ALGORITHM,
"use": "sig",
"kid": self._kid,
"n": _int_to_base64url(public_numbers.n),
"e": _int_to_base64url(public_numbers.e),
}
]
}
# ------------------------------------------------------------------
# FR-5: Verify + re-sign helpers
# ------------------------------------------------------------------
# 24-hour TTL for the OIDC discovery doc — long enough to avoid hammering
# the IdP, short enough to pick up jwks_uri changes after key rotation.
_OIDC_DISCOVERY_TTL = 86400
async def _get_oidc_discovery(self) -> Dict[str, Any]:
"""Fetch and cache the OIDC discovery document with a 24-hour TTL.
Only caches when the doc contains a 'jwks_uri' so that a transient or
malformed response doesn't permanently disable JWT verification.
"""
now = time.time()
cache_expired = (
now - self._oidc_discovery_fetched_at
) >= self._OIDC_DISCOVERY_TTL
if (
self._oidc_discovery_doc is None or cache_expired
) and self.access_token_discovery_uri:
doc = await _fetch_oidc_discovery(self.access_token_discovery_uri)
if "jwks_uri" in doc:
self._oidc_discovery_doc = doc
self._oidc_discovery_fetched_at = now
else:
return doc
return self._oidc_discovery_doc or {}
async def _verify_incoming_jwt(self, raw_token: str) -> Dict[str, Any]:
"""
Verify an incoming Bearer JWT against the configured IdP's JWKS.
Returns the verified payload claims dict.
Raises jwt.PyJWTError (or subclass) if verification fails.
"""
discovery = await self._get_oidc_discovery()
jwks_uri = discovery.get("jwks_uri")
if not jwks_uri:
raise ValueError(
"MCPJWTSigner: access_token_discovery_uri discovery document "
f"at {self.access_token_discovery_uri!r} has no 'jwks_uri'."
)
jwks_keys = await _fetch_jwks(jwks_uri)
# Only read `kid` from the unverified header — never `alg`.
# Reading `alg` from an attacker-controlled header enables algorithm
# confusion attacks (e.g. alg:none, HS256 with the public key as secret).
# The algorithm is determined from the JWKS key entry instead.
unverified_header = jwt.get_unverified_header(raw_token)
kid = unverified_header.get("kid")
# Build a JWKS object and pick the matching key.
# PyJWT's PyJWKSet handles key-type parsing and kid matching correctly.
from jwt import PyJWKSet
try:
jwks_set = PyJWKSet.from_dict({"keys": jwks_keys})
except Exception as exc:
raise jwt.exceptions.PyJWKSetError( # type: ignore[attr-defined]
f"Failed to parse JWKS from {jwks_uri!r}: {exc}"
) from exc
signing_jwk = None
for jwk_obj in jwks_set.keys:
if not kid or jwk_obj.key_id == kid:
signing_jwk = jwk_obj
break
if signing_jwk is None:
raise jwt.exceptions.PyJWKSetError( # type: ignore[attr-defined]
f"No JWKS key matching kid={kid!r} at {jwks_uri!r}"
)
# Use the algorithm declared by the JWKS key entry, not the token header.
# PyJWT populates algorithm_name from the key's `alg` field; when absent
# it infers from the key type (RSAPublicKey → RS256).
alg = getattr(signing_jwk, "algorithm_name", None) or "RS256"
decode_options: Dict[str, Any] = {"verify_exp": True}
decode_kwargs: Dict[str, Any] = {
"algorithms": [alg],
"options": decode_options,
}
if self.verify_audience:
decode_kwargs["audience"] = self.verify_audience
else:
decode_options["verify_aud"] = False
if self.verify_issuer:
decode_kwargs["issuer"] = self.verify_issuer
payload: Dict[str, Any] = jwt.decode(
raw_token, signing_jwk.key, **decode_kwargs
)
return payload
async def _introspect_opaque_token(self, token: str) -> Dict[str, Any]:
"""
Perform RFC 7662 token introspection for opaque (non-JWT) tokens.
Returns the introspection response dict. Raises on HTTP error or
inactive token.
"""
if not self.token_introspection_endpoint:
raise ValueError(
"MCPJWTSigner: token_introspection_endpoint is required for "
"opaque token verification but is not configured."
)
from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
resp = await client.post(
self.token_introspection_endpoint,
data={"token": token},
headers={"Accept": "application/json"},
)
resp.raise_for_status()
result: Dict[str, Any] = resp.json()
if not result.get("active", False):
raise jwt.exceptions.ExpiredSignatureError( # type: ignore[attr-defined]
"MCPJWTSigner: incoming token is inactive (introspection returned active=false)"
)
return result
# ------------------------------------------------------------------
# FR-15: Incoming claim validation
# ------------------------------------------------------------------
def _validate_required_claims(
self,
jwt_claims: Optional[Dict[str, Any]],
) -> None:
"""
Raise HTTP 403 if any required_claims are absent from the verified
incoming token claims.
"""
if not self.required_claims:
return
from fastapi import HTTPException
missing = [c for c in self.required_claims if not (jwt_claims or {}).get(c)]
if missing:
raise HTTPException(
status_code=403,
detail={
"error": (
f"MCPJWTSigner: incoming token is missing required claims: "
f"{missing}. Configure the IdP to include these claims."
)
},
)
# ------------------------------------------------------------------
# FR-12: End-user identity mapping
# ------------------------------------------------------------------
def _resolve_end_user_identity(
self,
user_api_key_dict: UserAPIKeyAuth,
jwt_claims: Optional[Dict[str, Any]],
) -> str:
"""
Resolve the outbound JWT 'sub' using the ordered end_user_claim_sources list.
Supported source prefixes:
token:<claim> from verified incoming JWT / introspection claims
litellm:user_id from UserAPIKeyAuth.user_id
litellm:email from UserAPIKeyAuth.user_email
litellm:end_user_id from UserAPIKeyAuth.end_user_id
litellm:team_id from UserAPIKeyAuth.team_id
Falls back to a stable hash of the API token for service-account callers.
"""
for source in self.end_user_claim_sources:
value: Optional[str] = None
if source.startswith("token:"):
claim_name = source[len("token:") :]
raw = (jwt_claims or {}).get(claim_name)
value = str(raw) if raw else None
elif source == "litellm:user_id":
uid = getattr(user_api_key_dict, "user_id", None)
value = str(uid) if uid else None
elif source == "litellm:email":
email = getattr(user_api_key_dict, "user_email", None)
value = str(email) if email else None
elif source == "litellm:end_user_id":
eid = getattr(user_api_key_dict, "end_user_id", None)
value = str(eid) if eid else None
elif source == "litellm:team_id":
tid = getattr(user_api_key_dict, "team_id", None)
value = str(tid) if tid else None
else:
verbose_proxy_logger.warning(
"MCPJWTSigner: unknown end_user_claim_source %r — skipping", source
)
continue
if value:
return value
# Final fallback for service accounts with no user identity
token = getattr(user_api_key_dict, "token", None) or getattr(
user_api_key_dict, "api_key", None
)
if token:
return "apikey:" + hashlib.sha256(str(token).encode()).hexdigest()[:16]
return "litellm-proxy"
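
The ordered resolution chain can be exercised outside the proxy with a simplified sketch. It assumes a `SimpleNamespace` in place of `UserAPIKeyAuth`; `resolve_sub` and `ATTR_BY_SOURCE` are hypothetical names, and unknown sources are skipped silently here rather than logged.

```python
import hashlib
from types import SimpleNamespace
from typing import Any, Dict, List, Optional

# Maps each litellm:* source to the attribute it reads (as in the method above).
ATTR_BY_SOURCE = {
    "litellm:user_id": "user_id",
    "litellm:email": "user_email",
    "litellm:end_user_id": "end_user_id",
    "litellm:team_id": "team_id",
}


def resolve_sub(
    sources: List[str], auth: Any, jwt_claims: Optional[Dict[str, Any]]
) -> str:
    """Return the first non-empty value found by walking the sources in order."""
    for source in sources:
        if source.startswith("token:"):
            raw = (jwt_claims or {}).get(source[len("token:"):])
        elif source in ATTR_BY_SOURCE:
            raw = getattr(auth, ATTR_BY_SOURCE[source], None)
        else:
            continue  # unknown source: skipped in this sketch
        if raw:
            return str(raw)
    # Service-account fallback: stable, truncated hash of the API token.
    token = getattr(auth, "token", None) or getattr(auth, "api_key", None)
    if token:
        return "apikey:" + hashlib.sha256(str(token).encode()).hexdigest()[:16]
    return "litellm-proxy"


chain = ["token:sub", "litellm:user_id"]
user = SimpleNamespace(user_id="u-42", token="sk-example")
service = SimpleNamespace(user_id=None, token="sk-example")

print(resolve_sub(chain, user, {"sub": "idp-user"}))
print(resolve_sub(chain, user, None))
print(resolve_sub(chain, service, None))
```

With verified token claims present the IdP subject is used; without them the chain falls through to the LiteLLM user ID, and a caller with no user identity gets the hashed-token form.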
# ------------------------------------------------------------------
# FR-10: Scope building
# ------------------------------------------------------------------
def _build_scope(self, raw_tool_name: str) -> str:
"""
Build the JWT scope string.
When allowed_scopes is configured: join them verbatim.
Otherwise auto-generate minimal, least-privilege scopes:
- Tool call → mcp:tools/call mcp:tools/<name>:call
- No tool → mcp:tools/call mcp:tools/list
NOTE: tools/list is intentionally NOT granted on tool-call JWTs to
prevent callers from enumerating tools they didn't ask to use.
"""
if self.allowed_scopes is not None:
return " ".join(self.allowed_scopes)
tool_name = (
re.sub(r"[^a-zA-Z0-9_\-]", "_", raw_tool_name) if raw_tool_name else ""
)
if tool_name:
scopes = ["mcp:tools/call", f"mcp:tools/{tool_name}:call"]
else:
scopes = ["mcp:tools/call", "mcp:tools/list"]
return " ".join(scopes)
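
To make the scope rules concrete, here is a standalone sketch of the same logic with sample inputs. `build_scope` is a free-function restatement for illustration, and the tool name `github/create issue` is made up.

```python
import re
from typing import List, Optional


def build_scope(raw_tool_name: str, allowed_scopes: Optional[List[str]] = None) -> str:
    """Standalone restatement of the scope rules, for illustration only."""
    if allowed_scopes is not None:
        # An explicit list replaces auto-generation entirely.
        return " ".join(allowed_scopes)
    # Characters outside [a-zA-Z0-9_-] are replaced with underscores.
    tool_name = re.sub(r"[^a-zA-Z0-9_\-]", "_", raw_tool_name) if raw_tool_name else ""
    if tool_name:
        return " ".join(["mcp:tools/call", f"mcp:tools/{tool_name}:call"])
    return " ".join(["mcp:tools/call", "mcp:tools/list"])


print(build_scope("github/create issue"))
print(build_scope(""))
print(build_scope("any_tool", ["mcp:tools/call", "mcp:tools/list"]))
```

Note that an empty `allowed_scopes` list is still "configured" under the `is not None` check, so it yields an empty scope string rather than falling back to auto-generation.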
# ------------------------------------------------------------------
# FR-13: Claim operations
# ------------------------------------------------------------------
def _apply_claim_operations(self, claims: Dict[str, Any]) -> Dict[str, Any]:
"""Apply add_claims, set_claims, and remove_claims to the claim dict."""
# add_claims: insert only when key is absent
for k, v in self.add_claims.items():
if k not in claims:
claims[k] = v
# set_claims: always override (highest priority)
claims = {**claims, **self.set_claims}
# remove_claims: delete listed keys
for k in self.remove_claims:
claims.pop(k, None)
return claims
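
The order of the three claim operations matters: add fills gaps, set overrides, and remove runs last. A minimal sketch with made-up claim values follows; unlike the method above, this hypothetical `apply_claim_operations` function copies its input instead of mutating it.

```python
from typing import Any, Dict, List


def apply_claim_operations(
    claims: Dict[str, Any],
    add_claims: Dict[str, Any],
    set_claims: Dict[str, Any],
    remove_claims: List[str],
) -> Dict[str, Any]:
    """Illustrative restatement: add (if absent), then set, then remove."""
    result = dict(claims)  # work on a copy in this sketch
    for key, value in add_claims.items():
        result.setdefault(key, value)  # existing keys are left untouched
    result.update(set_claims)  # always overrides
    for key in remove_claims:
        result.pop(key, None)  # missing keys are ignored
    return result


final_claims = apply_claim_operations(
    {"iss": "litellm", "nbf": 1700000000, "env": "staging"},
    add_claims={"deployment_id": "prod-001", "iss": "ignored"},
    set_claims={"env": "production"},
    remove_claims=["nbf"],
)
print(final_claims)
```

Because removal is applied last, a key listed in both `set_claims` and `remove_claims` ends up absent from the final JWT.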
# ------------------------------------------------------------------
# FR-15: optional_claims passthrough
# ------------------------------------------------------------------
def _passthrough_optional_claims(
self,
claims: Dict[str, Any],
jwt_claims: Optional[Dict[str, Any]],
) -> Dict[str, Any]:
"""Forward optional_claims from verified incoming token into the outbound JWT."""
if not self.optional_claims or not jwt_claims:
return claims
for claim in self.optional_claims:
if claim in jwt_claims and claim not in claims:
claims[claim] = jwt_claims[claim]
return claims
# ------------------------------------------------------------------
# Core JWT builder
# ------------------------------------------------------------------
def _build_claims(
self,
user_api_key_dict: UserAPIKeyAuth,
data: dict,
jwt_claims: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
"""
Build JWT claims for the outbound MCP access token.
Args:
user_api_key_dict: LiteLLM auth context for the current request.
data: Pre-call hook data dict (contains mcp_tool_name etc.).
jwt_claims: Verified incoming IdP claims (FR-5), or LiteLLM-decoded
jwt_claims if available. None for pure API-key requests.
"""
now = int(time.time())
claims: Dict[str, Any] = {
"iss": self.issuer,
"aud": self.audience,
"iat": now,
"exp": now + self.ttl_seconds,
"nbf": now,
}
# sub — resolved via ordered claim sources (FR-12)
claims["sub"] = self._resolve_end_user_identity(user_api_key_dict, jwt_claims)
# email passthrough when available from LiteLLM context
user_email = getattr(user_api_key_dict, "user_email", None)
if user_email:
claims["email"] = user_email
# act — RFC 8693 delegation claim (team/org context)
team_id = getattr(user_api_key_dict, "team_id", None)
org_id = getattr(user_api_key_dict, "org_id", None)
act_sub = team_id or org_id or "litellm-proxy"
claims["act"] = {"sub": act_sub}
# end_user_id when set separately from user_id
end_user_id = getattr(user_api_key_dict, "end_user_id", None)
if end_user_id:
claims["end_user_id"] = end_user_id
# scope (FR-10)
raw_tool_name: str = data.get("mcp_tool_name", "")
claims["scope"] = self._build_scope(raw_tool_name)
# optional_claims passthrough (FR-15)
claims = self._passthrough_optional_claims(claims, jwt_claims)
# Claim operations — applied last so admin overrides take effect (FR-13)
claims = self._apply_claim_operations(claims)
return claims
def _build_channel_token_claims(
self,
base_claims: Dict[str, Any],
) -> Dict[str, Any]:
"""
Build claims for the channel token (FR-14 two-token model).
Inherits sub/act/scope from the access token but uses a separate
audience and TTL so the transport layer and resource layer receive
purpose-bound credentials.
"""
now = int(time.time())
return {
**base_claims,
"aud": self.channel_token_audience,
"iat": now,
"exp": now + self.channel_token_ttl,
"nbf": now,
}
# ------------------------------------------------------------------
# FR-9: Debug header
# ------------------------------------------------------------------
@staticmethod
def _build_debug_header(claims: Dict[str, Any], kid: str) -> str:
"""
Build the x-litellm-mcp-debug header value.
Format: v=1; kid=<kid>; sub=<sub>; iss=<iss>; exp=<exp>; scope=<scope>
Scope is truncated to 80 chars for header safety.
"""
sub = claims.get("sub", "")
iss = claims.get("iss", "")
exp = claims.get("exp", 0)
scope = claims.get("scope", "")
if len(scope) > 80:
scope = scope[:77] + "..."
return f"v=1; kid={kid}; sub={sub}; iss={iss}; exp={exp}; scope={scope}"
# ------------------------------------------------------------------
# Guardrail hook
# ------------------------------------------------------------------
@log_guardrail_information
async def async_pre_call_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
cache: DualCache,
data: dict,
call_type: CallTypesLiteral,
) -> Optional[Union[Exception, str, dict]]:
"""
Verifies the incoming token (when configured), validates required claims,
then signs an outbound JWT and injects it as the Authorization header.
All non-MCP call types pass through unchanged.
"""
if call_type != "call_mcp_tool":
return data
# ------------------------------------------------------------------
# FR-5: Verify incoming token before re-signing
# ------------------------------------------------------------------
jwt_claims: Optional[Dict[str, Any]] = None
raw_token: Optional[str] = data.get("incoming_bearer_token")
if self.access_token_discovery_uri and raw_token:
# Two dots (three dot-separated segments) → JWT; otherwise opaque.
is_jwt = raw_token.count(".") == 2
try:
if is_jwt:
jwt_claims = await self._verify_incoming_jwt(raw_token)
elif self.token_introspection_endpoint:
jwt_claims = await self._introspect_opaque_token(raw_token)
else:
verbose_proxy_logger.warning(
"MCPJWTSigner: access_token_discovery_uri is set but the "
"incoming token appears to be opaque and no "
"token_introspection_endpoint is configured. "
"Proceeding without incoming token verification."
)
except Exception as exc:
verbose_proxy_logger.error(
"MCPJWTSigner: incoming token verification failed: %s", exc
)
from fastapi import HTTPException
raise HTTPException(
status_code=401,
detail={
"error": (
f"MCPJWTSigner: incoming token verification failed: {exc}"
)
},
)
elif not raw_token and self.access_token_discovery_uri:
verbose_proxy_logger.debug(
"MCPJWTSigner: access_token_discovery_uri configured but no Bearer "
"token found in request (API-key auth request — skipping verification)."
)
# Fall back to LiteLLM-decoded JWT claims (available when proxy uses JWT auth).
if jwt_claims is None:
jwt_claims = getattr(user_api_key_dict, "jwt_claims", None)
# ------------------------------------------------------------------
# FR-15: Validate required claims
# ------------------------------------------------------------------
self._validate_required_claims(jwt_claims)
# ------------------------------------------------------------------
# Build outbound access token
# ------------------------------------------------------------------
claims = self._build_claims(user_api_key_dict, data, jwt_claims)
signed_token = jwt.encode(
claims,
self._private_key,
algorithm=self.ALGORITHM,
headers={"kid": self._kid},
)
# Merge into existing extra_headers — a prior guardrail in the chain may
# have already injected tracing headers or correlation IDs.
existing_headers: Dict[str, str] = data.get("extra_headers") or {}
new_headers: Dict[str, str] = {
**existing_headers,
"Authorization": f"Bearer {signed_token}",
}
# ------------------------------------------------------------------
# FR-14: Two-token model — channel token
# ------------------------------------------------------------------
if self.channel_token_audience:
channel_claims = self._build_channel_token_claims(claims)
channel_token = jwt.encode(
channel_claims,
self._private_key,
algorithm=self.ALGORITHM,
headers={"kid": self._kid},
)
new_headers["x-mcp-channel-token"] = f"Bearer {channel_token}"
# ------------------------------------------------------------------
# FR-9: Debug header
# ------------------------------------------------------------------
if self.debug_headers:
new_headers["x-litellm-mcp-debug"] = self._build_debug_header(
claims, self._kid
)
data["extra_headers"] = new_headers
verbose_proxy_logger.debug(
"MCPJWTSigner: signed JWT sub=%s act=%s tool=%s exp=%d "
"verified=%s channel=%s",
claims.get("sub"),
claims.get("act", {}).get("sub"),
data.get("mcp_tool_name"),
claims["exp"],
jwt_claims is not None,
bool(self.channel_token_audience),
)
return data
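
The two-token model (FR-14) described in `_build_channel_token_claims` can be sketched without any JWT library: the channel token copies the access-token claims and overrides only the audience and the time-bound fields. The helper name `derive_channel_claims` and all sample claim values below are hypothetical and chosen only for illustration; signing is left out.

```python
import time
from typing import Any, Dict, Optional


def derive_channel_claims(
    base_claims: Dict[str, Any],
    channel_audience: str,
    channel_ttl_seconds: int,
    now: Optional[int] = None,
) -> Dict[str, Any]:
    """Copy the base claims, then override aud/iat/exp/nbf for the channel token."""
    issued_at = int(time.time()) if now is None else now
    return {
        **base_claims,  # sub, act, scope and iss are inherited unchanged
        "aud": channel_audience,
        "iat": issued_at,
        "exp": issued_at + channel_ttl_seconds,
        "nbf": issued_at,
    }


# Sample access-token claims (illustrative values, not taken from the source).
access_claims = {
    "iss": "litellm",
    "aud": "mcp",
    "sub": "user-123",
    "act": {"sub": "team-a"},
    "scope": "example-scope",
    "iat": 1000,
    "exp": 1300,
    "nbf": 1000,
}
channel_claims = derive_channel_claims(access_claims, "mcp-channel", 60, now=1000)
```

Because the result is a new dict, the access-token claims are not mutated, so both tokens can be signed from their own claim sets.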

@ -2142,8 +2142,7 @@ async def _resolve_org_filter_for_user_search(
member_org_ids: List[str] = []
if caller_user is not None:
member_org_ids = [
m.organization_id
for m in (caller_user.organization_memberships or [])
m.organization_id for m in (caller_user.organization_memberships or [])
]
if member_org_ids:

@ -1863,16 +1863,10 @@ async def _validate_update_key_data(
user_api_key_cache: Any,
) -> None:
"""Validate permissions and constraints for key update."""
_is_proxy_admin = (
user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
)
_is_proxy_admin = user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
# Prevent non-admin from removing user_id (setting to empty string) (LIT-1884)
if (
data.user_id is not None
and data.user_id == ""
and not _is_proxy_admin
):
if data.user_id is not None and data.user_id == "" and not _is_proxy_admin:
raise HTTPException(
status_code=403,
detail="Non-admin users cannot remove the user_id from a key.",

@ -857,7 +857,13 @@ async def new_team( # noqa: PLR0915
# Apply defaults from litellm.default_team_params for any fields
# not explicitly provided in the request.
for field in ("max_budget", "budget_duration", "tpm_limit", "rpm_limit", "team_member_permissions"):
for field in (
"max_budget",
"budget_duration",
"tpm_limit",
"rpm_limit",
"team_member_permissions",
):
if getattr(data, field, None) is None:
default_value = _get_default_team_param(field)
if default_value is not None:

@ -857,7 +857,10 @@ async def update_batch_in_database(
# If the batch_processed column doesn't exist (old schema),
# retry without it so the status update still succeeds.
err_str = str(col_err).lower()
if "batch_processed" in err_str and update_data.get("batch_processed") is not None:
if (
"batch_processed" in err_str
and update_data.get("batch_processed") is not None
):
verbose_proxy_logger.warning(
f"batch_processed column not found, retrying update without it: {col_err}"
)

@ -468,6 +468,12 @@ class ProxyInitializationHelpers:
type=str,
help="Path to the logging configuration file",
)
@click.option(
"--setup",
is_flag=True,
default=False,
help="Run the interactive setup wizard to configure providers and generate a config file",
)
@click.option(
"--version",
"-v",
@ -598,6 +604,7 @@ def run_server( # noqa: PLR0915
num_requests,
use_queue,
health,
setup,
version,
run_gunicorn,
run_hypercorn,
@ -611,6 +618,12 @@ def run_server( # noqa: PLR0915
max_requests_before_restart,
enforce_prisma_migration_check: bool,
):
if setup:
from litellm.setup_wizard import run_setup_wizard
run_setup_wizard()
return
args = locals()
if local:
from proxy_server import (
@ -904,7 +917,7 @@ def run_server( # noqa: PLR0915
# Auto-create PROMETHEUS_MULTIPROC_DIR for multi-worker setups
ProxyInitializationHelpers._maybe_setup_prometheus_multiproc_dir(
num_workers=num_workers,
litellm_settings=litellm_settings if config else None,
litellm_settings=litellm_settings if config else None, # type: ignore[possibly-unbound]
)
# --- SEPARATE HEALTH APP LOGIC ---

@ -115,7 +115,9 @@ async def background_streaming_task( # noqa: PLR0915
UPDATE_INTERVAL = 0.150 # 150ms batching interval
# Track the terminal event from the stream (may not be "completed")
terminal_status: Optional[ResponsesAPIStatus] = None # Will be set by response.completed/failed/incomplete/cancelled
terminal_status: Optional[
ResponsesAPIStatus
] = None # Will be set by response.completed/failed/incomplete/cancelled
terminal_error = None
_event_to_status = {
"response.completed": "completed",
@ -259,7 +261,10 @@ async def background_streaming_task( # noqa: PLR0915
)
# Extract error for failed and incomplete responses
if event_type == "response.failed" or event_type == "response.incomplete":
if (
event_type == "response.failed"
or event_type == "response.incomplete"
):
terminal_error = response_data.get("error")
# Core response fields

@ -40,9 +40,7 @@ def _get_registered_vantage_logger():
return None
async def _set_vantage_settings(
api_key: str, integration_token: str, base_url: str
):
async def _set_vantage_settings(api_key: str, integration_token: str, base_url: str):
"""Store Vantage settings in the database with encrypted API key."""
from litellm.proxy.proxy_server import prisma_client
@ -341,9 +339,7 @@ async def init_vantage_settings(
except HTTPException:
raise
except Exception as e:
verbose_proxy_logger.error(
f"Error initializing Vantage settings: {str(e)}"
)
verbose_proxy_logger.error(f"Error initializing Vantage settings: {str(e)}")
raise HTTPException(
status_code=500,
detail={"error": f"Failed to initialize Vantage settings: {str(e)}"},
@ -395,7 +391,8 @@ async def vantage_dry_run_export(
"""Cast Decimal columns to Float64 so .to_dicts() produces
JSON-serializable float values instead of decimal.Decimal."""
decimal_cols = [
col for col, dtype in zip(frame.columns, frame.dtypes)
col
for col, dtype in zip(frame.columns, frame.dtypes)
if isinstance(dtype, pl.Decimal)
]
if decimal_cols:
@ -404,8 +401,16 @@ async def vantage_dry_run_export(
)
return frame.to_dicts()
usage_sample = _to_json_safe_dicts(data.head(min(50, len(data)))) if not data.is_empty() else []
normalized_sample = _to_json_safe_dicts(normalized.head(min(50, len(normalized)))) if not normalized.is_empty() else []
usage_sample = (
_to_json_safe_dicts(data.head(min(50, len(data))))
if not data.is_empty()
else []
)
normalized_sample = (
_to_json_safe_dicts(normalized.head(min(50, len(normalized))))
if not normalized.is_empty()
else []
)
# Use the same pre-transform column names as
# FocusExportEngine.dry_run_export_usage_data for consistency.
@ -437,14 +442,10 @@ async def vantage_dry_run_export(
except HTTPException:
raise
except Exception as e:
verbose_proxy_logger.error(
f"Error performing Vantage dry run export: {str(e)}"
)
verbose_proxy_logger.error(f"Error performing Vantage dry run export: {str(e)}")
raise HTTPException(
status_code=500,
detail={
"error": f"Failed to perform Vantage dry run export: {str(e)}"
},
detail={"error": f"Failed to perform Vantage dry run export: {str(e)}"},
)

@ -454,8 +454,6 @@ class ProxyLogging:
for hook in PROXY_HOOKS:
proxy_hook = get_proxy_hook(hook)
import inspect
expected_args = inspect.getfullargspec(proxy_hook).args
passed_in_args: Dict[str, Any] = {}
if "internal_usage_cache" in expected_args:
@ -559,6 +557,10 @@ class ProxyLogging:
"user_api_key_request_route": kwargs.get("user_api_key_request_route"),
"mcp_tool_name": request_obj.tool_name, # Keep original for reference
"mcp_arguments": request_obj.arguments, # Keep original for reference
# Raw Bearer token from the original HTTP request — allows guardrails
# (e.g. MCPJWTSigner) to independently verify the caller's identity
# before re-signing an outbound token (FR-5 verify+re-sign).
"incoming_bearer_token": kwargs.get("incoming_bearer_token"),
}
return synthetic_data
@ -824,17 +826,30 @@ class ProxyLogging:
) -> dict:
"""
Helper function to convert pre_call_hook response back to kwargs for MCP usage.
Supports:
- modified_arguments: Override tool call arguments
- extra_headers: Inject custom headers into the outbound MCP request
"""
if not response_data:
return original_kwargs
# Apply any argument modifications from the hook response
modified_kwargs = original_kwargs.copy()
# If the response contains modified arguments, apply them
if response_data.get("modified_arguments"):
modified_kwargs["arguments"] = response_data["modified_arguments"]
if response_data.get("extra_headers"):
# Merge rather than replace — a prior guardrail in the chain may have
# already injected headers (e.g. tracing IDs). Later guardrails win on
# key collisions so that the most-specific guardrail (e.g. JWT signer)
# takes precedence over earlier ones.
existing = modified_kwargs.get("extra_headers") or {}
modified_kwargs["extra_headers"] = {
**existing,
**response_data["extra_headers"],
}
return modified_kwargs
async def process_pre_call_hook_response(self, response, data, call_type):
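
The merge-rather-than-replace rule for `extra_headers` comes down to dict unpacking order. A minimal sketch, assuming plain `Dict[str, str]` headers; the helper name `merge_extra_headers` and the sample header values are hypothetical:

```python
from typing import Dict, Optional


def merge_extra_headers(
    existing: Optional[Dict[str, str]],
    incoming: Optional[Dict[str, str]],
) -> Dict[str, str]:
    """Merge two header dicts; keys from `incoming` win on collision."""
    return {**(existing or {}), **(incoming or {})}


# Headers left by an earlier guardrail, then headers from a later one.
earlier = {"x-trace-id": "abc", "Authorization": "Bearer old"}
later = {"Authorization": "Bearer new"}
merged = merge_extra_headers(earlier, later)
```

Note that dict keys are case-sensitive, so `authorization` and `Authorization` would both survive a merge like this; callers that mix casings would need to normalize keys first.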

@ -7,7 +7,9 @@ from litellm.types.videos.utils import encode_character_id_with_provider
def extract_model_from_target_model_names(target_model_names: Any) -> Optional[str]:
if isinstance(target_model_names, str):
target_model_names = [m.strip() for m in target_model_names.split(",") if m.strip()]
target_model_names = [
m.strip() for m in target_model_names.split(",") if m.strip()
]
elif not isinstance(target_model_names, list):
return None
return target_model_names[0] if target_model_names else None

@ -692,11 +692,11 @@ def responses(
return run_async_function(aresponses_api_with_mcp, **mcp_call_kwargs)
# get provider config
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
ProviderConfigManager.get_provider_responses_api_config(
model=model,
provider=custom_llm_provider,
)
responses_api_provider_config: Optional[
BaseResponsesAPIConfig
] = ProviderConfigManager.get_provider_responses_api_config(
model=model,
provider=custom_llm_provider,
)
local_vars.update(kwargs)
@ -908,11 +908,11 @@ def delete_responses(
raise ValueError("custom_llm_provider is required but passed as None")
# get provider config
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=custom_llm_provider,
)
responses_api_provider_config: Optional[
BaseResponsesAPIConfig
] = ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=custom_llm_provider,
)
if responses_api_provider_config is None:
@ -1089,11 +1089,11 @@ def get_responses(
raise ValueError("custom_llm_provider is required but passed as None")
# get provider config
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=custom_llm_provider,
)
responses_api_provider_config: Optional[
BaseResponsesAPIConfig
] = ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=custom_llm_provider,
)
if responses_api_provider_config is None:
@ -1247,11 +1247,11 @@ def list_input_items(
if custom_llm_provider is None:
raise ValueError("custom_llm_provider is required but passed as None")
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=custom_llm_provider,
)
responses_api_provider_config: Optional[
BaseResponsesAPIConfig
] = ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=custom_llm_provider,
)
if responses_api_provider_config is None:
@ -1406,11 +1406,11 @@ def cancel_responses(
raise ValueError("custom_llm_provider is required but passed as None")
# get provider config
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=custom_llm_provider,
)
responses_api_provider_config: Optional[
BaseResponsesAPIConfig
] = ProviderConfigManager.get_provider_responses_api_config(
model=None,
provider=custom_llm_provider,
)
if responses_api_provider_config is None:
@ -1594,11 +1594,11 @@ def compact_responses(
raise ValueError("custom_llm_provider is required but passed as None")
# get provider config
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
ProviderConfigManager.get_provider_responses_api_config(
model=model,
provider=custom_llm_provider,
)
responses_api_provider_config: Optional[
BaseResponsesAPIConfig
] = ProviderConfigManager.get_provider_responses_api_config(
model=model,
provider=custom_llm_provider,
)
if responses_api_provider_config is None:

@ -8611,6 +8611,7 @@ class Router:
_model_info = deployment.get("model_info", {})
# see if we have the info for this model
_deployment_model = None # per-deployment model name (avoids overwriting the outer `model` group name)
try:
base_model = _model_info.get("base_model", None)
if base_model is None:
@ -8618,7 +8619,7 @@ class Router:
model_info = self.get_router_model_info(
deployment=deployment, received_model_name=model
)
model = base_model or _litellm_params.get("model", None)
_deployment_model = base_model or _litellm_params.get("model", None)
if (
isinstance(model_info, dict)
@ -8632,7 +8633,9 @@ class Router:
_context_window_error = True
_potential_error_str += (
"Model={}, Max Input Tokens={}, Got={}".format(
model, model_info["max_input_tokens"], input_tokens
_deployment_model,
model_info["max_input_tokens"],
input_tokens,
)
)
continue
@ -8688,13 +8691,21 @@ class Router:
## INVALID PARAMS ## -> catch 'gpt-3.5-turbo-16k' not supporting 'response_format' param
if request_kwargs is not None and litellm.drop_params is False:
# get supported params
model, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=model, litellm_params=LiteLLM_Params(**_litellm_params)
# get supported params — use per-deployment model to avoid overwriting the outer model group name
_dep_model_for_params = _deployment_model or model
(
_dep_model_for_params,
custom_llm_provider,
_,
_,
) = litellm.get_llm_provider(
model=_dep_model_for_params,
litellm_params=LiteLLM_Params(**_litellm_params),
)
supported_openai_params = litellm.get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
model=_dep_model_for_params,
custom_llm_provider=custom_llm_provider,
)
if supported_openai_params is None:
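
The Router change above keeps the per-deployment model name out of the outer `model` variable. The effect of the old reassignment can be reproduced in isolation; the deployment records and function names below are hypothetical stand-ins, not the Router's real data structures:

```python
from typing import Dict, List

# Two deployments that belong to the same model group (illustrative values).
deployments: List[Dict[str, str]] = [
    {"group": "gpt-4o", "litellm_model": "azure/my-deployment"},
    {"group": "gpt-4o", "litellm_model": "openai/gpt-4o"},
]


def names_seen_buggy(model: str) -> List[str]:
    """Reassigns `model` in the loop, so later iterations see a deployment name."""
    seen = []
    for dep in deployments:
        seen.append(model)  # name used for the lookup in this iteration
        model = dep["litellm_model"]  # clobbers the model group name
    return seen


def names_seen_fixed(model: str) -> List[str]:
    """Keeps the per-deployment name in its own variable."""
    seen = []
    for dep in deployments:
        seen.append(model)
        _deployment_model = dep["litellm_model"]  # outer `model` stays intact
    return seen


buggy = names_seen_buggy("gpt-4o")
fixed = names_seen_fixed("gpt-4o")
```

In the buggy variant the second iteration looks up `azure/my-deployment` instead of the group name, which is the kind of cross-iteration leak the rename to `_deployment_model` avoids.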

litellm/setup_wizard.py

@ -0,0 +1,668 @@
# ruff: noqa: T201
# flake8: noqa: T201
"""
LiteLLM Interactive Setup Wizard
Guides users through selecting LLM providers, entering API keys,
and generating a proxy config file, mirroring the Claude Code onboarding UX.
"""
import importlib.metadata
import os
import re
import secrets
import sys
import sysconfig
from pathlib import Path
from typing import Dict, List, Optional, Set
# termios / tty are Unix-only; fall back gracefully on Windows
try:
import termios
import tty
_HAS_RAW_TERMINAL: bool = True
except ImportError:
termios = None # type: ignore[assignment]
tty = None # type: ignore[assignment]
_HAS_RAW_TERMINAL = False
from litellm.utils import check_valid_key
# ---------------------------------------------------------------------------
# Provider definitions
# ---------------------------------------------------------------------------
# Each entry describes one provider card shown in the wizard.
# `env_key` — primary env var name (None = no key needed, e.g. Ollama)
# `test_model` — model passed to check_valid_key for credential validation
# (None = skip validation, e.g. Azure needs a deployment name)
# `models` — default models written into the generated config
# ---------------------------------------------------------------------------
PROVIDERS: List[Dict] = [
{
"id": "openai",
"name": "OpenAI",
"description": "GPT-4o, GPT-4o-mini, o3-mini",
"env_key": "OPENAI_API_KEY",
"key_hint": "sk-...",
"test_model": "gpt-4o-mini",
"models": ["gpt-4o", "gpt-4o-mini"],
},
{
"id": "anthropic",
"name": "Anthropic",
"description": "Claude Opus 4.6, Sonnet 4.6, Haiku 4.5",
"env_key": "ANTHROPIC_API_KEY",
"key_hint": "sk-ant-...",
"test_model": "claude-haiku-4-5-20251001",
"models": ["claude-opus-4-6", "claude-sonnet-4-6", "claude-haiku-4-5-20251001"],
},
{
"id": "gemini",
"name": "Google Gemini",
"description": "Gemini 2.0 Flash, Gemini 2.5 Pro",
"env_key": "GEMINI_API_KEY",
"key_hint": "AIza...",
"test_model": "gemini/gemini-2.0-flash",
"models": ["gemini/gemini-2.0-flash", "gemini/gemini-2.5-pro"],
},
{
"id": "azure",
"name": "Azure OpenAI",
"description": "GPT-4o via Azure",
"env_key": "AZURE_API_KEY",
"key_hint": "your-azure-key",
"test_model": None, # needs deployment name — skip validation
"models": [],
"needs_api_base": True,
"api_base_hint": "https://<resource>.openai.azure.com/",
"api_version": "2024-07-01-preview",
},
{
"id": "bedrock",
"name": "AWS Bedrock",
"description": "Claude 3.5, Llama 3 via AWS",
"env_key": "AWS_ACCESS_KEY_ID",
"key_hint": "AKIA...",
"test_model": None, # multi-key auth — skip validation
"models": ["bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0"],
"extra_keys": ["AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"],
"extra_hints": ["your-secret-key", "us-east-1"],
},
{
"id": "ollama",
"name": "Ollama",
"description": "Local models (llama3.2, mistral, etc.)",
"env_key": None,
"key_hint": None,
"test_model": None, # local — no remote validation
"models": ["ollama/llama3.2", "ollama/mistral"],
"api_base": "http://localhost:11434",
},
]
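
The `env_key` and `test_model` fields together decide how the wizard treats each provider. A minimal sketch of that decision, assuming the semantics given in the comments above the `PROVIDERS` list; the `describe` helper and the trimmed-down sample entries are hypothetical:

```python
from typing import Dict, List

# Trimmed-down provider entries carrying only the fields that drive the decision.
sample_providers: List[Dict] = [
    {"id": "openai", "env_key": "OPENAI_API_KEY", "test_model": "gpt-4o-mini"},
    {"id": "azure", "env_key": "AZURE_API_KEY", "test_model": None},
    {"id": "ollama", "env_key": None, "test_model": None},
]


def describe(provider: Dict) -> str:
    """Classify a provider by whether a key is prompted for and validated."""
    if provider.get("env_key") is None:
        return "no key needed"
    if provider.get("test_model") is None:
        return "key collected, validation skipped"
    return "key collected and validated"


summary = {p["id"]: describe(p) for p in sample_providers}
```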
# ---------------------------------------------------------------------------
# ANSI colour helpers
# ---------------------------------------------------------------------------
_ANSI_RE = re.compile(r"\033\[[^m]*m")
_ORANGE = "\033[38;2;215;119;87m"
_DIM = "\033[2m"
_BOLD = "\033[1m"
_GREEN = "\033[38;2;78;186;101m"
_BLUE = "\033[38;2;177;185;249m"
_GREY = "\033[38;2;153;153;153m"
_RESET = "\033[0m"
_CHECK = "✔"
_CROSS = "✘"
_CURSOR_HIDE = "\033[?25l"
_CURSOR_SHOW = "\033[?25h"
_MOVE_UP = "\033[{}A"
def _supports_color() -> bool:
return sys.stdout.isatty() and os.environ.get("NO_COLOR") is None
def _c(code: str, text: str) -> str:
return f"{code}{text}{_RESET}" if _supports_color() else text
def orange(t: str) -> str:
return _c(_ORANGE, t)
def bold(t: str) -> str:
return _c(_BOLD, t)
def green(t: str) -> str:
return _c(_GREEN, t)
def blue(t: str) -> str:
return _c(_BLUE, t)
def grey(t: str) -> str:
return _c(_GREY, t)
def dim(t: str) -> str:
return _c(_DIM, t)
def _divider() -> str:
"""Return a styled divider line (evaluated at call-time, not import-time)."""
return dim(" " + "─" * 74)
def _styled_input(prompt: str) -> str:
"""
Like input() but wraps ANSI sequences in readline ignore markers
(\\001...\\002) so readline correctly tracks the cursor column.
In non-TTY contexts, strips ANSI entirely so no escape codes appear.
"""
if sys.stdout.isatty():
rl_prompt = _ANSI_RE.sub(lambda m: f"\001{m.group()}\002", prompt)
else:
rl_prompt = _ANSI_RE.sub("", prompt)
return input(rl_prompt).strip()
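
The readline trick used by `_styled_input` can be checked in isolation. In GNU readline, `\001` and `\002` mark the start and end of non-printing characters in a prompt, so the escape sequences do not count toward the cursor column. The sketch below reuses the same regular expression; the function names `wrap_for_readline` and `strip_ansi` are hypothetical:

```python
import re

# Same pattern as _ANSI_RE: ESC [ ... m
ANSI_RE = re.compile(r"\033\[[^m]*m")


def wrap_for_readline(prompt: str) -> str:
    """Wrap each ANSI sequence in \\001...\\002 so readline ignores its width."""
    return ANSI_RE.sub(lambda m: f"\001{m.group()}\002", prompt)


def strip_ansi(prompt: str) -> str:
    """Remove ANSI sequences entirely, for non-TTY output."""
    return ANSI_RE.sub("", prompt)


styled = "\033[1mPort\033[0m: "
wrapped = wrap_for_readline(styled)
plain = strip_ansi(styled)
```

Stripping the wrapped prompt of ANSI sequences and markers yields the same visible text as `plain`, which is the property readline relies on when it computes the prompt width.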
def _yaml_escape(value: str) -> str:
"""Escape a string for safe embedding in a double-quoted YAML scalar."""
return (
value.replace("\\", "\\\\")
.replace('"', '\\"')
.replace("\n", "\\n")
.replace("\r", "\\r")
.replace("\t", "\\t")
)
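
One way to sanity-check `_yaml_escape` without a YAML parser is to round-trip through `json.loads`: the five escapes it emits (`\\`, `\"`, `\n`, `\r`, `\t`) mean the same thing in a JSON string as in a double-quoted YAML scalar. This is only a sketch under that assumption, not a full YAML conformance test, and the sample value is made up:

```python
import json


def yaml_escape(value: str) -> str:
    """Same replacements as _yaml_escape; the backslash must be handled first."""
    return (
        value.replace("\\", "\\\\")
        .replace('"', '\\"')
        .replace("\n", "\\n")
        .replace("\r", "\\r")
        .replace("\t", "\\t")
    )


# A value containing a quote, a backslash, a newline and a tab.
raw = 'sk-"quoted"\\path\nline2\ttab'
escaped = yaml_escape(raw)

# Decode the escaped text as a JSON string to confirm the original is recovered.
round_tripped = json.loads(f'"{escaped}"')
```

Replacing the backslash first matters: if it ran last, the backslashes introduced by the other replacements would themselves be doubled.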
# ---------------------------------------------------------------------------
# Layout constants
# ---------------------------------------------------------------------------
LITELLM_ASCII = r"""
"""
# ---------------------------------------------------------------------------
# Setup wizard
# ---------------------------------------------------------------------------
class SetupWizard:
"""
Interactive onboarding wizard: provider selection -> API keys -> config file.
All methods are static; the class is purely a namespace with clear
single-responsibility sections. Entry point: SetupWizard.run().
"""
# ── entry point ─────────────────────────────────────────────────────────
@staticmethod
def run() -> None:
try:
SetupWizard._wizard()
except (KeyboardInterrupt, EOFError):
print(f"\n\n {grey('Setup cancelled.')}\n")
# ── wizard steps ────────────────────────────────────────────────────────
@staticmethod
def _wizard() -> None:
SetupWizard._print_welcome()
print(f" {bold(chr(76) + 'et' + chr(39) + 's get started.')}")
print()
providers = SetupWizard._select_providers()
env_vars = SetupWizard._collect_keys(providers)
port, master_key = SetupWizard._proxy_settings()
config_path = Path(os.getcwd()) / "litellm_config.yaml"
try:
config_path.write_text(
SetupWizard._build_config(providers, env_vars, master_key)
)
except OSError as exc:
print(f"\n {bold(_CROSS + ' Could not write config:')} {exc}")
print(" Try running from a directory you have write access to.\n")
return
SetupWizard._print_success(config_path, port, master_key)
SetupWizard._offer_start(config_path, port, master_key)
# ── welcome ─────────────────────────────────────────────────────────────
@staticmethod
def _print_welcome() -> None:
try:
version = importlib.metadata.version("litellm")
except Exception:
version = "unknown"
print()
print(orange(LITELLM_ASCII.rstrip("\n")))
print(f" {orange('Welcome')} to {bold('LiteLLM')} {grey('v' + version)}")
print()
print(_divider())
print()
# ── provider selector ───────────────────────────────────────────────────
@staticmethod
def _select_providers() -> List[Dict]:
"""Arrow-key multi-select. Falls back to number input if /dev/tty unavailable."""
if not _HAS_RAW_TERMINAL:
return SetupWizard._select_fallback()
try:
return SetupWizard._select_interactive()
except OSError:
return SetupWizard._select_fallback()
@staticmethod
def _read_key() -> str:
"""Read one keypress from /dev/tty in raw mode."""
assert (
termios is not None and tty is not None
) # only called when _HAS_RAW_TERMINAL
with open("/dev/tty", "rb") as tty_fh:
fd = tty_fh.fileno()
old = termios.tcgetattr(fd)
try:
tty.setraw(fd)
ch = tty_fh.read(1)
if ch == b"\x1b":
ch2 = tty_fh.read(1)
if ch2 == b"[":
ch3 = tty_fh.read(1)
return "\x1b[" + ch3.decode("utf-8", errors="replace")
return "\x1b" + ch2.decode("utf-8", errors="replace")
return ch.decode("utf-8", errors="replace")
finally:
termios.tcsetattr(fd, termios.TCSADRAIN, old)
@staticmethod
def _render_selector(cursor: int, selected: Set[int], first_render: bool) -> int:
"""Draw or redraw the provider list. Returns the number of lines printed."""
lines = [
f"\n {bold('Add your first model')}\n",
grey(" ↑↓ to navigate · Space to select · Enter to confirm") + "\n",
"\n",
]
for i, p in enumerate(PROVIDERS):
arrow = blue("❯") if i == cursor else " "
bullet = green("◉") if i in selected else grey("○")
name_str = bold(p["name"]) if i == cursor else p["name"]
lines.append(f" {arrow} {bullet} {name_str} {grey(p['description'])}\n")
lines.append("\n")
content = "".join(lines)
if not first_render and _supports_color():
sys.stdout.write(_MOVE_UP.format(content.count("\n")))
sys.stdout.write(content)
sys.stdout.flush()
return content.count("\n")
@staticmethod
def _select_interactive() -> List[Dict]:
cursor = 0
selected: set[int] = set()
if _supports_color():
sys.stdout.write(_CURSOR_HIDE)
sys.stdout.flush()
try:
SetupWizard._render_selector(cursor, selected, first_render=True)
while True:
key = SetupWizard._read_key()
dirty = False
if key == "\x1b[A":
cursor = (cursor - 1) % len(PROVIDERS)
dirty = True
elif key == "\x1b[B":
cursor = (cursor + 1) % len(PROVIDERS)
dirty = True
elif key == " ":
selected.symmetric_difference_update({cursor})
dirty = True
elif key in ("\r", "\n"):
if not selected:
selected.add(cursor)
break
elif key in ("\x03", "\x04"):
raise KeyboardInterrupt
if dirty:
SetupWizard._render_selector(cursor, selected, first_render=False)
finally:
if _supports_color():
sys.stdout.write(_CURSOR_SHOW)
sys.stdout.flush()
return [PROVIDERS[i] for i in sorted(selected)]
@staticmethod
def _select_fallback() -> List[Dict]:
"""Number-based fallback when raw terminal input is unavailable."""
print()
print(f" {bold('Add your first model')}")
print(
grey(
" Enter numbers separated by commas (e.g. 1,2). Press Enter to confirm."
)
)
print()
for i, p in enumerate(PROVIDERS, 1):
print(f" {grey(str(i) + '.')} {bold(p['name'])} {grey(p['description'])}")
print()
while True:
raw = _styled_input(f" {blue('❯')} Provider(s): ")
if not raw:
print(grey(" Please select at least one provider."))
continue
try:
nums = [
int(x.strip())
for x in raw.replace(" ", ",").split(",")
if x.strip()
]
valid = sorted({n for n in nums if 1 <= n <= len(PROVIDERS)})
if not valid:
print(grey(f" Enter numbers between 1 and {len(PROVIDERS)}."))
continue
return [PROVIDERS[i - 1] for i in valid]
except ValueError:
print(grey(" Enter numbers separated by commas, e.g. 1,3"))
# ── key collection ───────────────────────────────────────────────────────
@staticmethod
def _collect_keys(providers: List[Dict]) -> Dict[str, str]:
env_vars: Dict[str, str] = {}
print()
print(_divider())
print()
print(f" {bold('Enter your API keys')}")
print(grey(" Keys are stored only in the generated config file."))
print(
grey(
" Tip: add litellm_config.yaml to .gitignore to avoid committing secrets."
)
)
print()
for p in providers:
if p["env_key"] is None:
print(
f" {green(p['name'])}: {grey('no key needed (uses local Ollama)')}"
)
continue
key = SetupWizard._prompt_key(p)
if not key:
continue
for extra_key, extra_hint in zip(
p.get("extra_keys", []), p.get("extra_hints", [])
):
val = _styled_input(f" {blue('❯')} {extra_key} {grey(extra_hint)}: ")
if val:
env_vars[extra_key] = val
if p.get("needs_api_base"):
api_base = _styled_input(
f" {blue('❯')} Azure endpoint URL {grey(p.get('api_base_hint', ''))}: "
)
if api_base:
env_vars[f"_LITELLM_AZURE_API_BASE_{p['id'].upper()}"] = api_base
deployment = _styled_input(
f" {blue('❯')} Azure deployment name {grey('(e.g. my-gpt4o)')}: "
)
if deployment:
env_vars[
f"_LITELLM_AZURE_DEPLOYMENT_{p['id'].upper()}"
] = deployment
# Store the key returned by validation — may be a re-entered replacement
env_vars[p["env_key"]] = SetupWizard._validate_and_report(p, key)
return env_vars
@staticmethod
def _prompt_key(provider: Dict) -> str:
"""Prompt for a provider's API key, with skip option. Returns the key or ''."""
hint = grey(provider.get("key_hint", ""))
while True:
key = _styled_input(
f" {blue('❯')} {bold(provider['name'])} API key {hint}: "
)
if key:
return key
print(grey(" Key is required. Leave blank to skip this provider."))
if _styled_input(grey(" Skip? (y/N): ")).lower() == "y":
return ""
@staticmethod
def _validate_and_report(provider: Dict, api_key: str) -> str:
"""
Validate credentials using litellm.utils.check_valid_key and print result.
Offers a re-entry loop on failure. Returns the final (possibly re-entered) key.
"""
test_model: Optional[str] = provider.get("test_model")
if not test_model:
return api_key # Azure / Bedrock / Ollama — skip validation
while True:
print(
f" {grey('Testing connection to ' + provider['name'] + '...')}",
flush=True,
)
valid = check_valid_key(model=test_model, api_key=api_key)
if valid:
print(
f" {green(_CHECK)} {bold(provider['name'])} connected successfully"
)
return api_key
print(f" {_CROSS} {bold(provider['name'])} {grey('— invalid API key')}")
if (
_styled_input(f" {blue('❯')} Re-enter key? {grey('(y/N)')}: ").lower()
!= "y"
):
return api_key
hint = grey(provider.get("key_hint", ""))
new_key = _styled_input(
f" {blue('❯')} {bold(provider['name'])} API key {hint}: "
)
if not new_key:
return api_key
api_key = new_key
# ── proxy settings ───────────────────────────────────────────────────────
@staticmethod
def _proxy_settings() -> "tuple[int, str]":
print()
print(_divider())
print()
print(f" {bold('Proxy settings')}")
print()
port = 4000
while True:
port_raw = _styled_input(f" {blue('❯')} Port {grey('[4000]')}: ")
if not port_raw:
break
if port_raw.isdigit() and 1 <= int(port_raw) <= 65535:
port = int(port_raw)
break
print(grey(" Enter a valid port number (1–65535)."))
key_raw = _styled_input(f" {blue('❯')} Master key {grey('[auto-generate]')}: ")
master_key = key_raw if key_raw else f"sk-{secrets.token_urlsafe(32)}"
return port, master_key
# ── config generation ────────────────────────────────────────────────────
@staticmethod
def _build_config(
providers: List[Dict],
env_vars: Dict[str, str],
master_key: str,
) -> str:
env_copy = dict(env_vars) # work on a copy — do not mutate caller's dict
lines = ["model_list:"]
for p in providers:
# Only emit models for providers that actually have credentials
has_creds = p["env_key"] is None or p["env_key"] in env_copy
if not has_creds:
continue
if p["id"] == "azure":
deployment = env_copy.pop(
f"_LITELLM_AZURE_DEPLOYMENT_{p['id'].upper()}", ""
)
if not deployment:
continue # skip Azure entirely if no deployment name was provided
models = [f"azure/{deployment}"]
else:
models = p["models"]
for model in models:
raw_display = model.split("/")[-1] if "/" in model else model
# Qualify azure display names to avoid collision with OpenAI model names
display = f"azure-{raw_display}" if p["id"] == "azure" else raw_display
lines += [
f" - model_name: {display}",
" litellm_params:",
f" model: {model}",
]
if p["env_key"] and p["env_key"] in env_copy:
lines.append(f" api_key: os.environ/{p['env_key']}")
if p.get("api_base"):
lines.append(
f' api_base: "{_yaml_escape(str(p["api_base"]))}"'
)
elif p.get("needs_api_base"):
azure_base_key = f"_LITELLM_AZURE_API_BASE_{p['id'].upper()}"
if azure_base_key in env_copy:
lines.append(
f' api_base: "{_yaml_escape(env_copy.pop(azure_base_key))}"'
)
if p.get("api_version"):
lines.append(f" api_version: {p['api_version']}")
lines += [
"",
"general_settings:",
f' master_key: "{_yaml_escape(master_key)}"',
"",
]
real_vars = {k: v for k, v in env_copy.items() if not k.startswith("_LITELLM_")}
if real_vars:
lines.append("environment_variables:")
for k, v in real_vars.items():
lines.append(f' {k}: "{_yaml_escape(v)}"')
lines.append("")
return "\n".join(lines)
# ── success + launch ─────────────────────────────────────────────────────
@staticmethod
def _print_success(config_path: Path, port: int, master_key: str) -> None:
print()
print(_divider())
print()
print(f" {green(_CHECK + ' Config saved')}{bold(str(config_path))}")
print()
print(f" {bold('To start your proxy:')}")
print()
print(f" {grey('$')} litellm --config {config_path} --port {port}")
print()
print(f" {bold('Then set your client:')}")
print()
print(f" export OPENAI_BASE_URL=http://localhost:{port}")
print(f" export OPENAI_API_KEY={master_key}")
print()
print(_divider())
print()
@staticmethod
def _offer_start(config_path: Path, port: int, master_key: str) -> None:
start = _styled_input(
f" {blue('')} Start the proxy now? {grey('(Y/n)')}: "
).lower()
if start not in ("", "y", "yes"):
print()
print(
f" Run {bold(f'litellm --config {config_path}')} whenever you're ready."
)
print()
print(
grey(f" Quick test once running: curl http://localhost:{port}/health")
)
print()
return
print()
print(_divider())
print()
print(f" {bold('Proxy is starting on')} http://localhost:{port}")
print()
print(grey(" Your proxy is OpenAI-compatible. Point any OpenAI SDK at it:"))
print()
print(f" export OPENAI_BASE_URL=http://localhost:{port}")
print(f" export OPENAI_API_KEY={master_key}")
print()
print(grey(" Quick test (in another terminal):"))
print()
print(f" curl http://localhost:{port}/health")
print()
print(grey(" Dashboard:"))
print()
print(f" http://localhost:{port}/ui {grey('(login with your master key)')}")
print()
print(_divider())
print()
print(f" {green(_CHECK)} Starting… {grey('(Ctrl+C to stop)')}")
print()
scripts_dir = sysconfig.get_path("scripts")
litellm_bin = os.path.join(scripts_dir or "", "litellm")
try:
os.execlp(
litellm_bin,
litellm_bin,
"--config",
str(config_path),
"--port",
str(port),
) # noqa: S606
except OSError as exc:
print(f"\n {bold(_CROSS + ' Could not start proxy:')} {exc}")
print(f" Run manually: litellm --config {config_path} --port {port}\n")
# ---------------------------------------------------------------------------
# Public entrypoint
# ---------------------------------------------------------------------------
def run_setup_wizard() -> None:
"""Run the interactive setup wizard. Called by `litellm --setup`."""
SetupWizard.run()
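
The port and master-key handling in `_proxy_settings` can be sketched as two pure functions, which makes the validation rules testable without a terminal. This is a minimal sketch, not the wizard's code: `parse_port` and `make_master_key` are hypothetical names, and the extra `isascii()` guard is an added assumption. Only the default of 4000, the 1 to 65535 range, and the `sk-` prefix with `secrets.token_urlsafe(32)` come from the diff.

```python
import secrets
from typing import Optional

DEFAULT_PORT = 4000  # default shown in the wizard prompt


def parse_port(raw: str, default: int = DEFAULT_PORT) -> Optional[int]:
    """Return a valid port, the default for empty input, or None if invalid."""
    if not raw:
        return default
    # isascii() is an added guard: str.isdigit() alone accepts characters
    # such as superscript digits that int() cannot parse.
    if raw.isascii() and raw.isdigit() and 1 <= int(raw) <= 65535:
        return int(raw)
    return None


def make_master_key(raw: str = "") -> str:
    """Use the supplied key, or auto-generate one with an sk- prefix."""
    return raw if raw else f"sk-{secrets.token_urlsafe(32)}"
```

A caller would loop on `parse_port` until it returns something other than `None`, which is what the `while True` loop in the wizard does.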

@ -79,6 +79,7 @@ class SupportedGuardrailIntegrations(Enum):
SEMANTIC_GUARD = "semantic_guard"
MCP_END_USER_PERMISSION = "mcp_end_user_permission"
BLOCK_CODE_EXECUTION = "block_code_execution"
MCP_JWT_SIGNER = "mcp_jwt_signer"
class Role(Enum):

@ -39,7 +39,8 @@ class VantageExportRequest(BaseModel):
"""Request model for Vantage export operations (actual export, no default limit)"""
limit: Optional[int] = Field(
None, description="Optional limit on number of records to export (default: no limit)"
None,
description="Optional limit on number of records to export (default: no limit)",
)
start_time_utc: Optional[datetime] = Field(
None, description="Start time for data export in UTC"

@ -195,7 +195,9 @@ def decode_character_id_with_provider(encoded_character_id: str) -> DecodedChara
character_id=decoded_character_id,
)
except Exception as e:
verbose_logger.debug(f"Error decoding character_id '{encoded_character_id}': {e}")
verbose_logger.debug(
f"Error decoding character_id '{encoded_character_id}': {e}"
)
return DecodedCharacterId(
custom_llm_provider=None,
model_id=None,

@ -1186,13 +1186,17 @@ def video_create_character(
litellm_params = GenericLiteLLMParams(**kwargs)
provider_config: Optional[BaseVideoConfig] = ProviderConfigManager.get_provider_video_config(
provider_config: Optional[
BaseVideoConfig
] = ProviderConfigManager.get_provider_video_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)
if provider_config is None:
raise ValueError(f"video create character is not supported for {custom_llm_provider}")
raise ValueError(
f"video create character is not supported for {custom_llm_provider}"
)
local_vars.update(kwargs)
request_params: Dict = {"name": name}
@ -1311,13 +1315,17 @@ def video_get_character(
litellm_params = GenericLiteLLMParams(**kwargs)
provider_config: Optional[BaseVideoConfig] = ProviderConfigManager.get_provider_video_config(
provider_config: Optional[
BaseVideoConfig
] = ProviderConfigManager.get_provider_video_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)
if provider_config is None:
raise ValueError(f"video get character is not supported for {custom_llm_provider}")
raise ValueError(
f"video get character is not supported for {custom_llm_provider}"
)
local_vars.update(kwargs)
request_params: Dict = {"character_id": character_id}
@ -1439,7 +1447,9 @@ def video_edit(
litellm_params = GenericLiteLLMParams(**kwargs)
provider_config: Optional[BaseVideoConfig] = ProviderConfigManager.get_provider_video_config(
provider_config: Optional[
BaseVideoConfig
] = ProviderConfigManager.get_provider_video_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)
@ -1572,16 +1582,24 @@ def video_extension(
litellm_params = GenericLiteLLMParams(**kwargs)
provider_config: Optional[BaseVideoConfig] = ProviderConfigManager.get_provider_video_config(
provider_config: Optional[
BaseVideoConfig
] = ProviderConfigManager.get_provider_video_config(
model=None,
provider=litellm.LlmProviders(custom_llm_provider),
)
if provider_config is None:
raise ValueError(f"video extension is not supported for {custom_llm_provider}")
raise ValueError(
f"video extension is not supported for {custom_llm_provider}"
)
local_vars.update(kwargs)
request_params: Dict = {"video_id": video_id, "prompt": prompt, "seconds": seconds}
request_params: Dict = {
"video_id": video_id,
"prompt": prompt,
"seconds": seconds,
}
litellm_logging_obj.update_environment_variables(
model="",

@ -32354,6 +32354,53 @@
"supports_vision": true,
"supports_web_search": true
},
"xai/grok-4.20-multi-agent-beta-0309": {
"cache_read_input_token_cost": 2e-07,
"input_cost_per_token": 2e-06,
"litellm_provider": "xai",
"max_input_tokens": 2000000,
"max_output_tokens": 2000000,
"max_tokens": 2000000,
"mode": "chat",
"output_cost_per_token": 6e-06,
"source": "https://docs.x.ai/docs/models",
"supports_function_calling": true,
"supports_reasoning": true,
"supports_tool_choice": true,
"supports_vision": true,
"supports_web_search": true
},
"xai/grok-4.20-beta-0309-reasoning": {
"cache_read_input_token_cost": 2e-07,
"input_cost_per_token": 2e-06,
"litellm_provider": "xai",
"max_input_tokens": 2000000,
"max_output_tokens": 2000000,
"max_tokens": 2000000,
"mode": "chat",
"output_cost_per_token": 6e-06,
"source": "https://docs.x.ai/docs/models",
"supports_function_calling": true,
"supports_reasoning": true,
"supports_tool_choice": true,
"supports_vision": true,
"supports_web_search": true
},
"xai/grok-4.20-beta-0309-non-reasoning": {
"cache_read_input_token_cost": 2e-07,
"input_cost_per_token": 2e-06,
"litellm_provider": "xai",
"max_input_tokens": 2000000,
"max_output_tokens": 2000000,
"max_tokens": 2000000,
"mode": "chat",
"output_cost_per_token": 6e-06,
"source": "https://docs.x.ai/docs/models",
"supports_function_calling": true,
"supports_tool_choice": true,
"supports_vision": true,
"supports_web_search": true
},
"xai/grok-beta": {
"input_cost_per_token": 5e-06,
"litellm_provider": "xai",

scripts/install.sh Executable file
@ -0,0 +1,143 @@
#!/usr/bin/env bash
# LiteLLM Installer
# Usage: curl -fsSL https://raw.githubusercontent.com/BerriAI/litellm/main/scripts/install.sh | sh
#
# NOTE: set -e without pipefail for POSIX sh compatibility (dash on Ubuntu/Debian
# ignores the shebang when invoked as `sh` and does not support `pipefail`).
set -eu
MIN_PYTHON_MAJOR=3
MIN_PYTHON_MINOR=9
# NOTE: before merging, this must stay as "litellm[proxy]" to install from PyPI.
LITELLM_PACKAGE="litellm[proxy]"
# ── colours ────────────────────────────────────────────────────────────────
if [ -t 1 ]; then
BOLD='\033[1m'
GREEN='\033[38;2;78;186;101m'
GREY='\033[38;2;153;153;153m'
RESET='\033[0m'
else
BOLD='' GREEN='' GREY='' RESET=''
fi
info() { printf "${GREY} %s${RESET}\n" "$*"; }
success() { printf "${GREEN} ✔ %s${RESET}\n" "$*"; }
header() { printf "${BOLD} %s${RESET}\n" "$*"; }
die() { printf "\n Error: %s\n\n" "$*" >&2; exit 1; }
# ── banner ─────────────────────────────────────────────────────────────────
echo ""
cat << 'EOF'
██╗ ██╗████████╗███████╗██╗ ██╗ ███╗ ███╗
██║ ██║╚══██╔══╝██╔════╝██║ ██║ ████╗ ████║
██║ ██║ ██║ █████╗ ██║ ██║ ██╔████╔██║
██║ ██║ ██║ ██╔══╝ ██║ ██║ ██║╚██╔╝██║
███████╗██║ ██║ ███████╗███████╗███████╗██║ ╚═╝ ██║
╚══════╝╚═╝ ╚═╝ ╚══════╝╚══════╝╚══════╝╚═╝ ╚═╝
EOF
printf " ${BOLD}LiteLLM Installer${RESET} ${GREY}— unified gateway for 100+ LLM providers${RESET}\n\n"
# ── OS detection ───────────────────────────────────────────────────────────
OS="$(uname -s)"
ARCH="$(uname -m)"
case "$OS" in
Darwin) PLATFORM="macOS ($ARCH)" ;;
Linux) PLATFORM="Linux ($ARCH)" ;;
*) die "Unsupported OS: $OS. LiteLLM supports macOS and Linux." ;;
esac
info "Platform: $PLATFORM"
# ── Python detection ───────────────────────────────────────────────────────
PYTHON_BIN=""
for candidate in python3 python; do
if command -v "$candidate" >/dev/null 2>&1; then
major="$("$candidate" -c 'import sys; print(sys.version_info.major)' 2>/dev/null || true)"
minor="$("$candidate" -c 'import sys; print(sys.version_info.minor)' 2>/dev/null || true)"
if [ "${major:-0}" -gt "$MIN_PYTHON_MAJOR" ] || { [ "${major:-0}" -eq "$MIN_PYTHON_MAJOR" ] && [ "${minor:-0}" -ge "$MIN_PYTHON_MINOR" ]; }; then
PYTHON_BIN="$(command -v "$candidate")"
info "Python: $("$candidate" --version 2>&1)"
break
fi
fi
done
if [ -z "$PYTHON_BIN" ]; then
die "Python ${MIN_PYTHON_MAJOR}.${MIN_PYTHON_MINOR}+ is required but not found.
Install it from https://python.org/downloads or via your package manager:
macOS: brew install python@3
Ubuntu: sudo apt install python3 python3-pip"
fi
# ── pip detection ──────────────────────────────────────────────────────────
if ! "$PYTHON_BIN" -m pip --version >/dev/null 2>&1; then
die "pip is not available. Install it with:
$PYTHON_BIN -m ensurepip --upgrade"
fi
# ── install ────────────────────────────────────────────────────────────────
echo ""
header "Installing litellm[proxy]…"
echo ""
"$PYTHON_BIN" -m pip install --upgrade "${LITELLM_PACKAGE}" \
|| die "pip install failed. Try manually: $PYTHON_BIN -m pip install '${LITELLM_PACKAGE}'"
# ── find the litellm binary installed by pip for this Python ───────────────
# sysconfig.get_path('scripts') is where pip puts console scripts — reliable
# even when the Python lives in a libexec/ symlink tree (e.g. Homebrew).
SCRIPTS_DIR="$("$PYTHON_BIN" -c 'import sysconfig; print(sysconfig.get_path("scripts"))')"
LITELLM_BIN="${SCRIPTS_DIR}/litellm"
if [ ! -x "$LITELLM_BIN" ]; then
# Fall back to user-base bin (pip install --user)
USER_BIN="$("$PYTHON_BIN" -c 'import site; print(site.getuserbase())')/bin"
LITELLM_BIN="${USER_BIN}/litellm"
fi
if [ ! -x "$LITELLM_BIN" ]; then
die "litellm binary not found after install. Try: $PYTHON_BIN -m pip install --user '${LITELLM_PACKAGE}'"
fi
# ── success banner ─────────────────────────────────────────────────────────
echo ""
success "LiteLLM installed"
installed_ver="$("$LITELLM_BIN" --version 2>&1 | grep -oE '[0-9]+\.[0-9]+\.[0-9]+' | head -1 || true)"
[ -n "$installed_ver" ] && info "Version: $installed_ver"
# ── PATH hint ──────────────────────────────────────────────────────────────
if ! command -v litellm >/dev/null 2>&1; then
info "Note: add litellm to your PATH: export PATH=\"\$PATH:${SCRIPTS_DIR}\""
fi
# ── launch setup wizard ────────────────────────────────────────────────────
echo ""
printf " ${BOLD}Run the interactive setup wizard?${RESET} ${GREY}(Y/n)${RESET}: "
# /dev/tty may be unavailable in Docker/CI — default to yes if it can't be read
answer=""
if [ -r /dev/tty ]; then
read -r answer </dev/tty || answer=""
fi
if [ -z "$answer" ] || [ "$answer" = "y" ] || [ "$answer" = "Y" ]; then
echo ""
# Use /dev/tty for interactive input when available (stdin is a pipe from curl)
if [ -r /dev/tty ]; then
exec "$LITELLM_BIN" --setup </dev/tty
else
exec "$LITELLM_BIN" --setup
fi
else
echo ""
header "Quick start:"
echo ""
info " litellm --setup # interactive wizard"
info " litellm --model gpt-4o # single-model quickstart"
echo ""
info "Docs: https://docs.litellm.ai"
echo ""
fi
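
The binary lookup in install.sh (the `scripts` path from `sysconfig`, then the user-base `bin` directory) can be reproduced in Python to debug a "binary not found after install" error. This is a rough sketch under assumptions: `candidate_script_paths` and `find_script` are hypothetical helper names, and the `bin` subdirectory matches the macOS and Linux layout the installer targets.

```python
import os
import site
import sysconfig
from typing import List, Optional


def candidate_script_paths(name: str = "litellm") -> List[str]:
    """List the locations install.sh checks, in the same order."""
    scripts_dir = sysconfig.get_path("scripts")  # where pip puts console scripts
    user_bin = os.path.join(site.getuserbase(), "bin")  # pip install --user
    return [os.path.join(scripts_dir, name), os.path.join(user_bin, name)]


def find_script(name: str = "litellm") -> Optional[str]:
    """Return the first candidate that exists and is executable, else None."""
    for path in candidate_script_paths(name):
        if os.path.isfile(path) and os.access(path, os.X_OK):
            return path
    return None
```

Running `find_script()` with the same interpreter the installer picked shows which of the two locations, if any, holds the console script.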

@ -907,8 +907,9 @@ def test_router_region_pre_call_check(allowed_model_region):
{
"model_name": "gpt-3.5-turbo-large", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "gpt-3.5-turbo-1106",
"model": "gpt-4.1-mini",
"api_key": os.getenv("OPENAI_API_KEY"),
"mock_response": "This is a mock response.",
},
"model_info": {"id": "2"},
},

@ -60,10 +60,33 @@ def assert_langfuse_request_matches_expected(
)
]
# When aggregating from multiple flush cycles, deduplicate by keeping
# only one trace-create and one generation-create per trace_id.
seen_types: dict = {}
deduped_batch: list = []
for item in actual_request_body["batch"]:
item_type = item["type"]
if item_type not in seen_types:
seen_types[item_type] = True
deduped_batch.append(item)
actual_request_body["batch"] = deduped_batch
# Ensure canonical order: trace-create first, generation-create second
actual_request_body["batch"].sort(
key=lambda x: 0 if x["type"] == "trace-create" else 1
)
print(
"actual_request_body after filtering", json.dumps(actual_request_body, indent=4)
)
assert len(actual_request_body["batch"]) >= 2, (
f"Expected at least 2 batch items (trace-create + generation-create) "
f"after filtering by trace_id={trace_id}, "
f"but got {len(actual_request_body['batch'])}. "
f"Items: {json.dumps(actual_request_body['batch'], indent=2)}"
)
# Replace dynamic values in actual request body
for item in actual_request_body["batch"]:
@ -150,19 +173,36 @@ class TestLangfuseLogging:
"""Helper method to verify Langfuse API calls"""
await asyncio.sleep(3)
# Verify the call
# Verify at least one call was made
assert mock_post.call_count >= 1
url = mock_post.call_args[0][0]
request_body = mock_post.call_args[1].get("content")
# Parse the JSON string into a dict for assertions
actual_request_body = json.loads(request_body)
# Aggregate batch items from ALL calls — the Langfuse SDK may split
# trace-create and generation-create across separate HTTP flushes.
langfuse_url = "https://us.cloud.langfuse.com/api/public/ingestion"
all_batch_items: list = []
metadata: Optional[dict] = None
for call in mock_post.call_args_list:
url = call[0][0]
if url != langfuse_url:
continue
request_body = call[1].get("content")
if request_body:
body = json.loads(request_body)
all_batch_items.extend(body.get("batch", []))
if metadata is None:
metadata = body.get("metadata")
print("\nMocked Request Details:")
print(f"URL: {url}")
assert len(all_batch_items) > 0, "No Langfuse ingestion calls found"
assert metadata is not None, "No metadata found in Langfuse calls"
actual_request_body = {
"batch": all_batch_items,
"metadata": metadata,
}
print("\nMocked Request Details (aggregated from all calls):")
print(f"Request Body: {json.dumps(actual_request_body, indent=4)}")
assert url == "https://us.cloud.langfuse.com/api/public/ingestion"
assert_langfuse_request_matches_expected(
actual_request_body,
expected_file_name,
@ -170,6 +210,7 @@ class TestLangfuseLogging:
)
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_langfuse_logging_completion(self, mock_setup):
"""Test Langfuse logging for chat completion"""
setup = mock_setup
@ -185,6 +226,7 @@ class TestLangfuseLogging:
)
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_langfuse_logging_completion_with_tags(self, mock_setup):
"""Test Langfuse logging for chat completion with tags"""
setup = mock_setup
@ -203,6 +245,7 @@ class TestLangfuseLogging:
)
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_langfuse_logging_completion_with_tags_stream(self, mock_setup):
"""Test Langfuse logging for streaming chat completion with tags"""
setup = mock_setup
@ -223,6 +266,7 @@ class TestLangfuseLogging:
)
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_langfuse_logging_completion_with_langfuse_metadata(self, mock_setup):
"""Test Langfuse logging for chat completion with metadata for langfuse"""
setup = mock_setup
@ -252,6 +296,7 @@ class TestLangfuseLogging:
)
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_langfuse_logging_with_non_serializable_metadata(self, mock_setup):
"""Test Langfuse logging with metadata that requires preparation (Pydantic models, sets, etc)"""
from pydantic import BaseModel
@ -358,6 +403,7 @@ class TestLangfuseLogging:
)
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_langfuse_logging_completion_with_malformed_llm_response(
self, mock_setup
):
@ -387,6 +433,7 @@ class TestLangfuseLogging:
)
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_langfuse_logging_completion_with_bedrock_llm_response(
self, mock_setup
):
@ -418,6 +465,7 @@ class TestLangfuseLogging:
setup["mock_post"], "completion_with_bedrock_call.json", setup["trace_id"]
)
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_langfuse_logging_completion_with_vertex_llm_response(
self, mock_setup
):
@ -449,6 +497,7 @@ class TestLangfuseLogging:
)
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_langfuse_logging_vllm_embedding(self, mock_setup):
"""
Test that the request sent to the vllm embedding endpoint is correct.
@ -500,6 +549,7 @@ class TestLangfuseLogging:
)
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_langfuse_logging_with_router(self, mock_setup):
"""Test Langfuse logging with router"""
litellm._turn_on_debug()
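
The aggregation and deduplication that this test helper now performs can be sketched in isolation. This is an illustrative sketch, not the test file's code: `aggregate_batches` and `dedupe_and_order` are hypothetical names, and each mocked call is simplified to a `(url, content)` tuple instead of a `unittest.mock` call object.

```python
import json
from typing import Any, Dict, List, Optional, Tuple

INGESTION_URL = "https://us.cloud.langfuse.com/api/public/ingestion"


def aggregate_batches(
    calls: List[Tuple[str, Optional[str]]], url: str = INGESTION_URL
) -> Tuple[List[Dict[str, Any]], Optional[Dict[str, Any]]]:
    """Collect batch items from every call to `url`; keep the first metadata."""
    items: List[Dict[str, Any]] = []
    metadata: Optional[Dict[str, Any]] = None
    for call_url, content in calls:
        if call_url != url or not content:
            continue
        body = json.loads(content)
        items.extend(body.get("batch", []))
        if metadata is None:
            metadata = body.get("metadata")
    return items, metadata


def dedupe_and_order(items: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """Keep the first item of each type, with trace-create sorted first."""
    seen = set()
    deduped = []
    for item in items:
        if item["type"] not in seen:
            seen.add(item["type"])
            deduped.append(item)
    # sorted() is stable, so items other than trace-create keep their order.
    return sorted(deduped, key=lambda x: 0 if x["type"] == "trace-create" else 1)
```

Keeping only the first item per type is enough here because the batch has already been filtered to a single trace_id before deduplication.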

@ -308,16 +308,19 @@ async def test_long_term_spend_accuracy_with_bursts():
response = await chat_completion(session, key)
print(f"Burst 2 - Request {i + 1}/{BURST_2_REQUESTS} completed")
# Poll until key spend reflects burst 2
burst_1_spend = intermediate_key_info["info"]["spend"]
# Poll until key spend reaches expected total (burst 1 + burst 2)
start = time.time()
while time.time() - start < 120:
key_info_check = await get_spend_info(session, "key", key)
current_spend = key_info_check["info"]["spend"]
if current_spend > burst_1_spend:
print(f"Key spend increased to {current_spend} after {time.time() - start:.1f}s")
if abs(current_spend - expected_spend) < TOLERANCE:
print(
f"Total spend reached expected {expected_spend} after {time.time() - start:.1f}s"
)
break
print(f"Key spend still {current_spend}, waiting for burst 2 flush...")
print(
f"Key spend {current_spend}, expected {expected_spend}, waiting..."
)
await asyncio.sleep(10)
# Allow extra time for all entity spend aggregations
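
The polling change above (wait until spend is within a tolerance of the expected total, not just greater than the burst 1 value) follows a general poll-until pattern. A minimal sketch, assuming a hypothetical `poll_until_close` helper and a caller-supplied `fetch` coroutine. The 120 second timeout and 10 second interval mirror the values in the diff.

```python
import asyncio
import time
from typing import Awaitable, Callable, Optional


async def poll_until_close(
    fetch: Callable[[], Awaitable[float]],
    expected: float,
    tolerance: float,
    timeout: float = 120.0,
    interval: float = 10.0,
) -> Optional[float]:
    """Poll `fetch` until its value is within `tolerance` of `expected`.

    Returns the matching value, or None if `timeout` seconds elapse first.
    """
    start = time.monotonic()
    while time.monotonic() - start < timeout:
        current = await fetch()
        if abs(current - expected) < tolerance:
            return current
        await asyncio.sleep(interval)
    return None
```

Returning `None` on timeout lets the caller decide whether to fail immediately or fall through to a later assertion, as the test does.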

@ -77,6 +77,11 @@ class TestRequestCompliance:
schema = spec_dict["components"]["schemas"]["CreateModelInteractionParams"]
input_schema = schema["properties"]["input"]
# The input property may be inline oneOf or a $ref to InteractionsInput
if "$ref" in input_schema:
ref_name = input_schema["$ref"].split("/")[-1]
input_schema = spec_dict["components"]["schemas"][ref_name]
# Should be oneOf with multiple types
assert "oneOf" in input_schema
@ -100,10 +105,21 @@ class TestRequestCompliance:
assert "discriminator" in content_schema
assert content_schema["discriminator"]["propertyName"] == "type"
# Check TextContent is an option
mapping = content_schema["discriminator"]["mapping"]
assert "text" in mapping
print(f"Content type discriminator mapping: {list(mapping.keys())}")
# Check TextContent is an option (via mapping if present, or via oneOf refs)
mapping = content_schema["discriminator"].get("mapping")
if mapping:
assert "text" in mapping
print(f"Content type discriminator mapping: {list(mapping.keys())}")
else:
# Discriminator without explicit mapping — verify via oneOf
one_of = content_schema.get("oneOf", [])
ref_names = [
opt["$ref"].split("/")[-1] for opt in one_of if "$ref" in opt
]
assert "TextContent" in ref_names, (
f"TextContent not found in oneOf refs: {ref_names}"
)
print(f"Content type discriminator (no mapping), oneOf refs: {ref_names}")
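
The two fallbacks added to this test (follow a local `$ref`, and read option names from `oneOf` when the discriminator has no mapping) can be sketched as small helpers. This is a sketch under assumptions: `resolve_local_ref` and `discriminator_options` are hypothetical names, and only refs under `components/schemas` are handled.

```python
from typing import Any, Dict, List


def resolve_local_ref(spec: Dict[str, Any], schema: Dict[str, Any]) -> Dict[str, Any]:
    """Return the referenced schema if `schema` is a $ref, else `schema` itself."""
    ref = schema.get("$ref")
    if ref is None:
        return schema
    return spec["components"]["schemas"][ref.split("/")[-1]]


def discriminator_options(schema: Dict[str, Any]) -> List[str]:
    """Return mapping keys if present, else the schema names in oneOf refs."""
    mapping = schema.get("discriminator", {}).get("mapping")
    if mapping:
        return list(mapping.keys())
    return [
        opt["$ref"].split("/")[-1] for opt in schema.get("oneOf", []) if "$ref" in opt
    ]
```

Note that the two branches return different kinds of names: discriminator values such as `text` in the first case, and schema names such as `TextContent` in the second.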
def test_text_content_schema(self, spec_dict):
"""Verify TextContent schema."""

@ -386,8 +386,56 @@ class TestXAICostCalculator:
completion_tokens=50,
total_tokens=150,
)
web_search_cost = cost_per_web_search_request(usage=usage, model_info={})
# Expected cost: No web search data = $0.0
assert web_search_cost == 0.0
def test_grok_4_20_beta_reasoning_cost_calculation(self):
"""Test cost calculation for grok-4.20-beta-0309-reasoning model."""
usage = Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)
prompt_cost, completion_cost = cost_per_token(
model="grok-4.20-beta-0309-reasoning", usage=usage
)
# Input: 100 tokens * $2e-6 = $0.0002
# Output: 200 tokens * $6e-6 = $0.0012
expected_prompt_cost = 100 * 2e-6
expected_completion_cost = 200 * 6e-6
assert math.isclose(prompt_cost, expected_prompt_cost, rel_tol=1e-10)
assert math.isclose(completion_cost, expected_completion_cost, rel_tol=1e-10)
def test_grok_4_20_beta_non_reasoning_cost_calculation(self):
"""Test cost calculation for grok-4.20-beta-0309-non-reasoning model."""
usage = Usage(prompt_tokens=50, completion_tokens=100, total_tokens=150)
prompt_cost, completion_cost = cost_per_token(
model="grok-4.20-beta-0309-non-reasoning", usage=usage
)
# Input: 50 tokens * $2e-6 = $0.0001
# Output: 100 tokens * $6e-6 = $0.0006
expected_prompt_cost = 50 * 2e-6
expected_completion_cost = 100 * 6e-6
assert math.isclose(prompt_cost, expected_prompt_cost, rel_tol=1e-10)
assert math.isclose(completion_cost, expected_completion_cost, rel_tol=1e-10)
def test_grok_4_20_multi_agent_cost_calculation(self):
"""Test cost calculation for grok-4.20-multi-agent-beta-0309 model."""
usage = Usage(prompt_tokens=200, completion_tokens=300, total_tokens=500)
prompt_cost, completion_cost = cost_per_token(
model="grok-4.20-multi-agent-beta-0309", usage=usage
)
# Input: 200 tokens * $2e-6 = $0.0004
# Output: 300 tokens * $6e-6 = $0.0018
expected_prompt_cost = 200 * 2e-6
expected_completion_cost = 300 * 6e-6
assert math.isclose(prompt_cost, expected_prompt_cost, rel_tol=1e-10)
assert math.isclose(completion_cost, expected_completion_cost, rel_tol=1e-10)
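
The expected values in these tests follow directly from the per-token prices in the model map entries. A small worked sketch of that arithmetic, assuming hypothetical helper names `per_million` and `expected_costs`. The three constants are the values listed for the grok-4.20 beta entries in this diff.

```python
import math
from typing import Tuple

INPUT_COST_PER_TOKEN = 2e-06       # $2.00 per 1M input tokens
OUTPUT_COST_PER_TOKEN = 6e-06      # $6.00 per 1M output tokens
CACHE_READ_COST_PER_TOKEN = 2e-07  # $0.20 per 1M cached input tokens


def per_million(cost_per_token: float) -> float:
    """Convert a per-token price to a price per one million tokens."""
    return cost_per_token * 1_000_000


def expected_costs(prompt_tokens: int, completion_tokens: int) -> Tuple[float, float]:
    """Return (prompt_cost, completion_cost) in dollars, ignoring cached tokens."""
    return (
        prompt_tokens * INPUT_COST_PER_TOKEN,
        completion_tokens * OUTPUT_COST_PER_TOKEN,
    )


prompt_cost, completion_cost = expected_costs(100, 200)
# Compare floats with a tolerance, as the tests do.
assert math.isclose(prompt_cost, 0.0002, rel_tol=1e-10)
assert math.isclose(completion_cost, 0.0012, rel_tol=1e-10)
```

The sketch ignores cached tokens, as do the tests, so `CACHE_READ_COST_PER_TOKEN` is only used for the per-million conversion.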

@ -0,0 +1,707 @@
"""
Tests for pre_mcp_call guardrail hook header mutation support.
Validates that:
1. _convert_mcp_hook_response_to_kwargs extracts extra_headers from hook response
2. pre_call_tool_check returns hook-provided extra_headers AND modified arguments
3. call_tool flows hook headers and modified arguments downstream
4. Hook-provided headers take highest priority (merge after static_headers)
5. OpenAPI-backed servers log a warning and continue (skip injection) when hook headers are present
6. JWT claims are propagated in both standard and virtual-key fast paths
7. Backward compatibility: hooks without extra_headers continue to work
"""
import asyncio
from typing import Any, Dict, Optional
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import HTTPException
from litellm.proxy._experimental.mcp_server.mcp_server_manager import MCPServerManager
from litellm.proxy._types import UserAPIKeyAuth
from litellm.proxy.utils import ProxyLogging
from litellm.types.mcp import MCPAuth, MCPTransport
from litellm.types.mcp_server.mcp_server_manager import MCPServer
class TestConvertMcpHookResponseToKwargs:
"""Tests for ProxyLogging._convert_mcp_hook_response_to_kwargs"""
def setup_method(self):
self.proxy_logging = ProxyLogging(user_api_key_cache=MagicMock())
def test_returns_original_kwargs_when_response_is_none(self):
original = {"arguments": {"key": "val"}, "name": "tool"}
result = self.proxy_logging._convert_mcp_hook_response_to_kwargs(
None, original
)
assert result == original
def test_returns_original_kwargs_when_response_is_empty_dict(self):
original = {"arguments": {"key": "val"}}
result = self.proxy_logging._convert_mcp_hook_response_to_kwargs({}, original)
assert result == original
def test_extracts_modified_arguments(self):
original = {"arguments": {"old": "value"}}
response = {"modified_arguments": {"new": "value"}}
result = self.proxy_logging._convert_mcp_hook_response_to_kwargs(
response, original
)
assert result["arguments"] == {"new": "value"}
def test_extracts_extra_headers(self):
original = {"arguments": {"key": "val"}}
response = {"extra_headers": {"Authorization": "Bearer signed-jwt"}}
result = self.proxy_logging._convert_mcp_hook_response_to_kwargs(
response, original
)
assert result["extra_headers"] == {"Authorization": "Bearer signed-jwt"}
def test_extracts_both_arguments_and_headers(self):
original = {"arguments": {"old": "value"}}
response = {
"modified_arguments": {"new": "value"},
"extra_headers": {"X-Custom": "header-val"},
}
result = self.proxy_logging._convert_mcp_hook_response_to_kwargs(
response, original
)
assert result["arguments"] == {"new": "value"}
assert result["extra_headers"] == {"X-Custom": "header-val"}
def test_no_extra_headers_key_preserves_original(self):
"""Backward compat: hooks that only return modified_arguments still work."""
original = {"arguments": {"key": "val"}}
response = {"modified_arguments": {"key": "new_val"}}
result = self.proxy_logging._convert_mcp_hook_response_to_kwargs(
response, original
)
assert "extra_headers" not in result
assert result["arguments"] == {"key": "new_val"}
def test_empty_extra_headers_not_set(self):
"""Empty dict for extra_headers is falsy and should not be set."""
original = {"arguments": {"key": "val"}}
response = {"extra_headers": {}}
result = self.proxy_logging._convert_mcp_hook_response_to_kwargs(
response, original
)
assert "extra_headers" not in result
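
Taken together, the tests in this class describe a small contract for the conversion step. The function below is a hypothetical stand-in that satisfies that contract. It is not the body of `ProxyLogging._convert_mcp_hook_response_to_kwargs`, and its handling of an empty `modified_arguments` is an assumption, since no test covers that case.

```python
from typing import Any, Dict, Optional


def convert_hook_response(
    response: Optional[Dict[str, Any]], original_kwargs: Dict[str, Any]
) -> Dict[str, Any]:
    """Apply a hook response to tool-call kwargs, as the tests describe."""
    if not response:
        # None or {}: nothing to apply, return the original kwargs.
        return original_kwargs
    result = dict(original_kwargs)  # shallow copy so the caller's dict is untouched
    if response.get("modified_arguments"):
        # Assumption: an empty modified_arguments is ignored.
        result["arguments"] = response["modified_arguments"]
    if response.get("extra_headers"):
        # An empty dict is falsy, so no extra_headers key is added for it.
        result["extra_headers"] = response["extra_headers"]
    return result
```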
class TestPreCallToolCheckReturnsHeaders:
"""Tests that pre_call_tool_check returns hook-provided headers."""
def _make_server(self, name="test_server"):
return MCPServer(
server_id="test-id",
name=name,
server_name=name,
url="https://example.com",
transport=MCPTransport.http,
auth_type=MCPAuth.none,
)
@pytest.mark.asyncio
async def test_returns_empty_dict_when_hook_has_no_headers(self):
manager = MCPServerManager()
server = self._make_server()
proxy_logging = MagicMock(spec=ProxyLogging)
proxy_logging._create_mcp_request_object_from_kwargs = MagicMock(
return_value=MagicMock()
)
proxy_logging._convert_mcp_to_llm_format = MagicMock(
return_value={"model": "fake"}
)
proxy_logging.pre_call_hook = AsyncMock(
return_value={"modified_arguments": {"key": "val"}}
)
proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock(
return_value={"arguments": {"key": "val"}}
)
with patch.object(manager, "check_allowed_or_banned_tools", return_value=True):
with patch.object(
manager,
"check_tool_permission_for_key_team",
new_callable=AsyncMock,
):
with patch.object(manager, "validate_allowed_params"):
result = await manager.pre_call_tool_check(
name="test_tool",
arguments={"key": "val"},
server_name="test_server",
user_api_key_auth=None,
proxy_logging_obj=proxy_logging,
server=server,
)
assert result == {}
@pytest.mark.asyncio
async def test_returns_extra_headers_from_hook(self):
manager = MCPServerManager()
server = self._make_server()
hook_headers = {"Authorization": "Bearer signed-jwt", "X-Trace-Id": "abc123"}
proxy_logging = MagicMock(spec=ProxyLogging)
proxy_logging._create_mcp_request_object_from_kwargs = MagicMock(
return_value=MagicMock()
)
proxy_logging._convert_mcp_to_llm_format = MagicMock(
return_value={"model": "fake"}
)
proxy_logging.pre_call_hook = AsyncMock(
return_value={"extra_headers": hook_headers}
)
proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock(
return_value={"arguments": {"key": "val"}, "extra_headers": hook_headers}
)
with patch.object(manager, "check_allowed_or_banned_tools", return_value=True):
with patch.object(
manager,
"check_tool_permission_for_key_team",
new_callable=AsyncMock,
):
with patch.object(manager, "validate_allowed_params"):
result = await manager.pre_call_tool_check(
name="test_tool",
arguments={"key": "val"},
server_name="test_server",
user_api_key_auth=None,
proxy_logging_obj=proxy_logging,
server=server,
)
assert result["extra_headers"] == hook_headers
@pytest.mark.asyncio
async def test_returns_empty_dict_when_hook_returns_none(self):
manager = MCPServerManager()
server = self._make_server()
proxy_logging = MagicMock(spec=ProxyLogging)
proxy_logging._create_mcp_request_object_from_kwargs = MagicMock(
return_value=MagicMock()
)
proxy_logging._convert_mcp_to_llm_format = MagicMock(
return_value={"model": "fake"}
)
proxy_logging.pre_call_hook = AsyncMock(return_value=None)
with patch.object(manager, "check_allowed_or_banned_tools", return_value=True):
with patch.object(
manager,
"check_tool_permission_for_key_team",
new_callable=AsyncMock,
):
with patch.object(manager, "validate_allowed_params"):
result = await manager.pre_call_tool_check(
name="test_tool",
arguments={"key": "val"},
server_name="test_server",
user_api_key_auth=None,
proxy_logging_obj=proxy_logging,
server=server,
)
assert result == {}

    @pytest.mark.asyncio
    async def test_returns_modified_arguments_from_hook(self):
        """Modified arguments from the hook must be returned so the caller can use them."""
        manager = MCPServerManager()
        server = self._make_server()
        original_args = {"key": "original"}
        modified_args = {"key": "modified", "extra": "added"}

        proxy_logging = MagicMock(spec=ProxyLogging)
        proxy_logging._create_mcp_request_object_from_kwargs = MagicMock(
            return_value=MagicMock()
        )
        proxy_logging._convert_mcp_to_llm_format = MagicMock(
            return_value={"model": "fake"}
        )
        proxy_logging.pre_call_hook = AsyncMock(
            return_value={"modified_arguments": modified_args}
        )
        proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock(
            return_value={"arguments": modified_args}
        )

        with patch.object(manager, "check_allowed_or_banned_tools", return_value=True):
            with patch.object(
                manager,
                "check_tool_permission_for_key_team",
                new_callable=AsyncMock,
            ):
                with patch.object(manager, "validate_allowed_params"):
                    result = await manager.pre_call_tool_check(
                        name="test_tool",
                        arguments=original_args,
                        server_name="test_server",
                        user_api_key_auth=None,
                        proxy_logging_obj=proxy_logging,
                        server=server,
                    )

        assert result["arguments"] == modified_args

    @pytest.mark.asyncio
    async def test_returns_both_modified_arguments_and_headers(self):
        """Hook can modify both arguments and inject headers simultaneously."""
        manager = MCPServerManager()
        server = self._make_server()
        modified_args = {"key": "modified"}
        hook_headers = {"Authorization": "Bearer jwt"}

        proxy_logging = MagicMock(spec=ProxyLogging)
        proxy_logging._create_mcp_request_object_from_kwargs = MagicMock(
            return_value=MagicMock()
        )
        proxy_logging._convert_mcp_to_llm_format = MagicMock(
            return_value={"model": "fake"}
        )
        proxy_logging.pre_call_hook = AsyncMock(return_value={"dummy": True})
        proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock(
            return_value={"arguments": modified_args, "extra_headers": hook_headers}
        )

        with patch.object(manager, "check_allowed_or_banned_tools", return_value=True):
            with patch.object(
                manager,
                "check_tool_permission_for_key_team",
                new_callable=AsyncMock,
            ):
                with patch.object(manager, "validate_allowed_params"):
                    result = await manager.pre_call_tool_check(
                        name="test_tool",
                        arguments={"key": "original"},
                        server_name="test_server",
                        user_api_key_auth=None,
                        proxy_logging_obj=proxy_logging,
                        server=server,
                    )

        assert result["arguments"] == modified_args
        assert result["extra_headers"] == hook_headers


class TestCallToolFlowsHookHeaders:
    """Tests that call_tool passes hook_extra_headers to _call_regular_mcp_tool."""

    def _make_server(self, name="test_server"):
        return MCPServer(
            server_id="test-id",
            name=name,
            server_name=name,
            url="https://example.com",
            transport=MCPTransport.http,
            auth_type=MCPAuth.none,
        )

    @pytest.mark.asyncio
    async def test_hook_headers_passed_to_call_regular_mcp_tool(self):
        """Verify that hook_extra_headers kwarg is forwarded."""
        manager = MCPServerManager()
        server = self._make_server()
        hook_headers = {"Authorization": "Bearer signed-jwt"}

        with patch.object(
            manager,
            "_get_mcp_server_from_tool_name",
            return_value=server,
        ):
            with patch.object(
                manager,
                "pre_call_tool_check",
                new_callable=AsyncMock,
                return_value={"extra_headers": hook_headers},
            ):
                with patch.object(
                    manager,
                    "_create_during_hook_task",
                    return_value=asyncio.create_task(asyncio.sleep(0)),
                ):
                    with patch.object(
                        manager,
                        "_call_regular_mcp_tool",
                        new_callable=AsyncMock,
                        return_value=MagicMock(),
                    ) as mock_call:
                        proxy_logging = MagicMock(spec=ProxyLogging)
                        await manager.call_tool(
                            server_name="test_server",
                            name="test_tool",
                            arguments={"key": "val"},
                            proxy_logging_obj=proxy_logging,
                        )

        mock_call.assert_called_once()
        call_kwargs = mock_call.call_args
        assert call_kwargs.kwargs.get("hook_extra_headers") == hook_headers

    @pytest.mark.asyncio
    async def test_no_hook_headers_when_no_proxy_logging(self):
        """Without proxy_logging_obj, no pre_call_tool_check runs."""
        manager = MCPServerManager()
        server = self._make_server()

        with patch.object(
            manager,
            "_get_mcp_server_from_tool_name",
            return_value=server,
        ):
            with patch.object(
                manager,
                "_call_regular_mcp_tool",
                new_callable=AsyncMock,
                return_value=MagicMock(),
            ) as mock_call:
                await manager.call_tool(
                    server_name="test_server",
                    name="test_tool",
                    arguments={"key": "val"},
                    proxy_logging_obj=None,
                )

        mock_call.assert_called_once()
        call_kwargs = mock_call.call_args
        assert call_kwargs.kwargs.get("hook_extra_headers") is None

    @pytest.mark.asyncio
    async def test_modified_arguments_passed_to_downstream(self):
        """Hook-modified arguments must be used for the actual tool call."""
        manager = MCPServerManager()
        server = self._make_server()
        modified_args = {"key": "modified_by_hook"}

        with patch.object(
            manager,
            "_get_mcp_server_from_tool_name",
            return_value=server,
        ):
            with patch.object(
                manager,
                "pre_call_tool_check",
                new_callable=AsyncMock,
                return_value={"arguments": modified_args},
            ):
                with patch.object(
                    manager,
                    "_create_during_hook_task",
                    return_value=asyncio.create_task(asyncio.sleep(0)),
                ):
                    with patch.object(
                        manager,
                        "_call_regular_mcp_tool",
                        new_callable=AsyncMock,
                        return_value=MagicMock(),
                    ) as mock_call:
                        proxy_logging = MagicMock(spec=ProxyLogging)
                        await manager.call_tool(
                            server_name="test_server",
                            name="test_tool",
                            arguments={"key": "original"},
                            proxy_logging_obj=proxy_logging,
                        )

        mock_call.assert_called_once()
        call_kwargs = mock_call.call_args
        assert call_kwargs.kwargs.get("arguments") == modified_args

    @pytest.mark.asyncio
    async def test_openapi_server_warns_and_continues_on_hook_headers(self):
        """OpenAPI-backed servers log a warning and continue when hook injects headers."""
        manager = MCPServerManager()
        server = MCPServer(
            server_id="test-id",
            name="openapi_server",
            server_name="openapi_server",
            url="https://example.com",
            transport=MCPTransport.http,
            auth_type=MCPAuth.none,
            spec_path="/path/to/spec.yaml",
        )

        with patch.object(
            manager, "_get_mcp_server_from_tool_name", return_value=server
        ):
            with patch.object(
                manager,
                "pre_call_tool_check",
                new_callable=AsyncMock,
                return_value={"extra_headers": {"Authorization": "Bearer jwt"}},
            ):
                with patch.object(
                    manager,
                    "_create_during_hook_task",
                    return_value=asyncio.create_task(asyncio.sleep(0)),
                ):
                    with patch.object(
                        manager,
                        "_call_openapi_tool_handler",
                        new_callable=AsyncMock,
                        return_value=MagicMock(),
                    ):
                        import litellm.proxy._experimental.mcp_server.mcp_server_manager as mgr_mod

                        proxy_logging = MagicMock(spec=ProxyLogging)
                        with patch.object(mgr_mod, "verbose_logger") as mock_logger:
                            # Should NOT raise; just warn and proceed
                            await manager.call_tool(
                                server_name="openapi_server",
                                name="test_tool",
                                arguments={},
                                proxy_logging_obj=proxy_logging,
                            )

        mock_logger.warning.assert_called_once()
        assert "header injection is not supported" in mock_logger.warning.call_args[0][0]

    @pytest.mark.asyncio
    async def test_openapi_server_no_error_without_hook_headers(self):
        """No exception when OpenAPI server has no hook-injected headers."""
        manager = MCPServerManager()
        server = MCPServer(
            server_id="test-id",
            name="openapi_server",
            server_name="openapi_server",
            url="https://example.com",
            transport=MCPTransport.http,
            auth_type=MCPAuth.none,
            spec_path="/path/to/spec.yaml",
        )

        with patch.object(
            manager, "_get_mcp_server_from_tool_name", return_value=server
        ):
            with patch.object(
                manager,
                "pre_call_tool_check",
                new_callable=AsyncMock,
                return_value={},
            ):
                with patch.object(
                    manager,
                    "_create_during_hook_task",
                    return_value=asyncio.create_task(asyncio.sleep(0)),
                ):
                    with patch.object(
                        manager,
                        "_call_openapi_tool_handler",
                        new_callable=AsyncMock,
                        return_value=MagicMock(),
                    ):
                        proxy_logging = MagicMock(spec=ProxyLogging)
                        await manager.call_tool(
                            server_name="openapi_server",
                            name="test_tool",
                            arguments={},
                            proxy_logging_obj=proxy_logging,
                        )


class TestHookHeaderMergePriority:
    """Tests that hook-provided headers have highest priority in _call_regular_mcp_tool."""

    def _make_server(
        self,
        static_headers: Optional[Dict[str, str]] = None,
        extra_headers_config: Optional[list] = None,
    ):
        return MCPServer(
            server_id="test-id",
            name="Test Server",
            server_name="test_server",
            url="https://example.com",
            transport=MCPTransport.http,
            auth_type=MCPAuth.none,
            static_headers=static_headers,
            extra_headers=extra_headers_config,
        )

    @pytest.mark.asyncio
    async def test_hook_headers_override_static_headers(self):
        """Hook headers should take precedence over static_headers."""
        manager = MCPServerManager()
        server = self._make_server(
            static_headers={"Authorization": "Bearer static-token", "X-Static": "yes"}
        )
        hook_headers = {"Authorization": "Bearer hook-signed-jwt"}

        captured_extra_headers: Dict[str, Any] = {}

        async def fake_create_mcp_client(
            server, mcp_auth_header=None, extra_headers=None, stdio_env=None
        ):
            captured_extra_headers["value"] = extra_headers
            mock_client = MagicMock()
            mock_client.call_tool = AsyncMock(return_value=MagicMock())
            return mock_client

        with patch.object(
            manager, "_create_mcp_client", side_effect=fake_create_mcp_client
        ):
            with patch.object(manager, "_build_stdio_env", return_value=None):
                try:
                    await manager._call_regular_mcp_tool(
                        mcp_server=server,
                        original_tool_name="test_tool",
                        arguments={"key": "val"},
                        tasks=[],
                        mcp_auth_header=None,
                        mcp_server_auth_headers=None,
                        oauth2_headers=None,
                        raw_headers=None,
                        proxy_logging_obj=None,
                        hook_extra_headers=hook_headers,
                    )
                except Exception:
                    pass

        headers = captured_extra_headers.get("value", {})
        assert headers["Authorization"] == "Bearer hook-signed-jwt"
        assert headers["X-Static"] == "yes"

    @pytest.mark.asyncio
    async def test_no_hook_headers_preserves_existing_behavior(self):
        """When hook_extra_headers is None, existing header logic is unchanged."""
        manager = MCPServerManager()
        server = self._make_server(
            static_headers={"X-Static": "static-value"}
        )

        captured_extra_headers: Dict[str, Any] = {}

        async def fake_create_mcp_client(
            server, mcp_auth_header=None, extra_headers=None, stdio_env=None
        ):
            captured_extra_headers["value"] = extra_headers
            mock_client = MagicMock()
            mock_client.call_tool = AsyncMock(return_value=MagicMock())
            return mock_client

        with patch.object(
            manager, "_create_mcp_client", side_effect=fake_create_mcp_client
        ):
            with patch.object(manager, "_build_stdio_env", return_value=None):
                try:
                    await manager._call_regular_mcp_tool(
                        mcp_server=server,
                        original_tool_name="test_tool",
                        arguments={"key": "val"},
                        tasks=[],
                        mcp_auth_header=None,
                        mcp_server_auth_headers=None,
                        oauth2_headers=None,
                        raw_headers=None,
                        proxy_logging_obj=None,
                        hook_extra_headers=None,
                    )
                except Exception:
                    pass

        headers = captured_extra_headers.get("value", {})
        assert headers == {"X-Static": "static-value"}

    @pytest.mark.asyncio
    async def test_hook_headers_merge_with_oauth2(self):
        """Hook headers merge on top of OAuth2 headers."""
        manager = MCPServerManager()
        server = MCPServer(
            server_id="test-id",
            name="Test Server",
            server_name="test_server",
            url="https://example.com",
            transport=MCPTransport.http,
            auth_type=MCPAuth.oauth2,
        )

        captured_extra_headers: Dict[str, Any] = {}

        async def fake_create_mcp_client(
            server, mcp_auth_header=None, extra_headers=None, stdio_env=None
        ):
            captured_extra_headers["value"] = extra_headers
            mock_client = MagicMock()
            mock_client.call_tool = AsyncMock(return_value=MagicMock())
            return mock_client

        with patch.object(
            manager, "_create_mcp_client", side_effect=fake_create_mcp_client
        ):
            with patch.object(manager, "_build_stdio_env", return_value=None):
                try:
                    await manager._call_regular_mcp_tool(
                        mcp_server=server,
                        original_tool_name="test_tool",
                        arguments={"key": "val"},
                        tasks=[],
                        mcp_auth_header=None,
                        mcp_server_auth_headers=None,
                        oauth2_headers={
                            "Authorization": "Bearer oauth2-token",
                            "X-OAuth": "yes",
                        },
                        raw_headers=None,
                        proxy_logging_obj=None,
                        hook_extra_headers={
                            "Authorization": "Bearer hook-jwt",
                            "X-Trace-Id": "trace-123",
                        },
                    )
                except Exception:
                    pass

        headers = captured_extra_headers.get("value", {})
        assert headers["Authorization"] == "Bearer hook-jwt"
        assert headers["X-OAuth"] == "yes"
        assert headers["X-Trace-Id"] == "trace-123"


class TestUserAPIKeyAuthJwtClaims:
    """Tests that UserAPIKeyAuth correctly carries jwt_claims."""

    def test_jwt_claims_field_defaults_to_none(self):
        auth = UserAPIKeyAuth(api_key="test-key")
        assert auth.jwt_claims is None

    def test_jwt_claims_field_accepts_dict(self):
        claims = {"sub": "user-123", "iss": "litellm", "exp": 9999999999}
        auth = UserAPIKeyAuth(api_key="test-key", jwt_claims=claims)
        assert auth.jwt_claims == claims
        assert auth.jwt_claims["sub"] == "user-123"

    def test_jwt_claims_backward_compatible_without_field(self):
        """Existing code that doesn't pass jwt_claims should still work."""
        auth = UserAPIKeyAuth(
            api_key="test-key",
            user_id="user-1",
            team_id="team-1",
        )
        assert auth.jwt_claims is None
        assert auth.user_id == "user-1"

    def test_jwt_claims_set_after_construction(self):
        """Virtual-key fast path sets jwt_claims after the object is created."""
        auth = UserAPIKeyAuth(api_key="test-key")
        assert auth.jwt_claims is None
        claims = {"sub": "user-456", "iss": "okta", "groups": ["admin"]}
        auth.jwt_claims = claims
        assert auth.jwt_claims == claims
        assert auth.jwt_claims["groups"] == ["admin"]
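
The precedence that TestHookHeaderMergePriority checks above can be pictured as a layered dictionary merge in which later layers win. The sketch below is illustrative only: `merge_headers` is a hypothetical helper, not the function MCPServerManager uses, and the relative order of the static and OAuth2 layers is an assumption. The tests only establish that hook headers override both.

```python
from typing import Dict, Optional


def merge_headers(*layers: Optional[Dict[str, str]]) -> Dict[str, str]:
    """Merge header layers left to right; later layers override earlier ones.

    Hypothetical helper for illustration, not part of litellm.
    """
    merged: Dict[str, str] = {}
    for layer in layers:
        if layer:  # skip None and empty layers
            merged.update(layer)
    return merged


static_headers = {"Authorization": "Bearer static-token", "X-Static": "yes"}
oauth2_headers = {"Authorization": "Bearer oauth2-token", "X-OAuth": "yes"}
hook_headers = {"Authorization": "Bearer hook-jwt", "X-Trace-Id": "trace-123"}

# Hook headers are passed last so they take priority (layer order is assumed).
result = merge_headers(static_headers, oauth2_headers, hook_headers)
```

A plain dict merge like this is case-sensitive, so `authorization` and `Authorization` would be kept as separate keys. Whether the real merge normalizes header names is not shown in this diff.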


@ -0,0 +1,188 @@
"""Unit tests for litellm.setup_wizard — pure functions only, no network calls."""

from litellm.setup_wizard import SetupWizard, _yaml_escape

# ---------------------------------------------------------------------------
# _yaml_escape
# ---------------------------------------------------------------------------


def test_yaml_escape_plain():
    assert _yaml_escape("sk-abc123") == "sk-abc123"


def test_yaml_escape_double_quote():
    assert _yaml_escape('sk-ab"cd') == 'sk-ab\\"cd'


def test_yaml_escape_backslash():
    assert _yaml_escape("sk-ab\\cd") == "sk-ab\\\\cd"


def test_yaml_escape_combined():
    assert _yaml_escape('ab\\"cd') == 'ab\\\\\\"cd'


def test_yaml_escape_newline():
    assert _yaml_escape("sk-abc\ndef") == "sk-abc\\ndef"


def test_yaml_escape_carriage_return():
    assert _yaml_escape("sk-abc\rdef") == "sk-abc\\rdef"


def test_yaml_escape_tab():
    assert _yaml_escape("sk-abc\tdef") == "sk-abc\\tdef"


# ---------------------------------------------------------------------------
# SetupWizard._build_config
# ---------------------------------------------------------------------------

_OPENAI = {
    "id": "openai",
    "name": "OpenAI",
    "env_key": "OPENAI_API_KEY",
    "models": ["gpt-4o", "gpt-4o-mini"],
    "test_model": "gpt-4o-mini",
}

_ANTHROPIC = {
    "id": "anthropic",
    "name": "Anthropic",
    "env_key": "ANTHROPIC_API_KEY",
    "models": ["claude-opus-4-6"],
    "test_model": "claude-haiku-4-5-20251001",
}

_AZURE = {
    "id": "azure",
    "name": "Azure OpenAI",
    "env_key": "AZURE_API_KEY",
    "models": [],
    "test_model": None,
    "needs_api_base": True,
    "api_base_hint": "https://<resource>.openai.azure.com/",
    "api_version": "2024-07-01-preview",
}

_OLLAMA = {
    "id": "ollama",
    "name": "Ollama",
    "env_key": None,
    "models": ["ollama/llama3.2"],
    "test_model": None,
    "api_base": "http://localhost:11434",
}


def test_build_config_basic_openai():
    config = SetupWizard._build_config(
        [_OPENAI],
        {"OPENAI_API_KEY": "sk-test"},
        "sk-master",
    )
    assert "model_list:" in config
    assert "model_name: gpt-4o" in config
    assert "model: gpt-4o" in config
    assert "api_key: os.environ/OPENAI_API_KEY" in config
    assert 'master_key: "sk-master"' in config


def test_build_config_skipped_provider_omitted():
    """Provider with no key in env_vars should not appear in model_list."""
    config = SetupWizard._build_config(
        [_OPENAI, _ANTHROPIC],
        {"ANTHROPIC_API_KEY": "sk-ant-test"},  # OpenAI key missing
        "sk-master",
    )
    assert "gpt-4o" not in config
    assert "claude-opus-4-6" in config


def test_build_config_env_vars_written_escaped():
    """API keys with special chars should be YAML-escaped."""
    config = SetupWizard._build_config(
        [_OPENAI],
        {"OPENAI_API_KEY": 'sk-ab"cd'},
        "sk-master",
    )
    assert 'OPENAI_API_KEY: "sk-ab\\"cd"' in config


def test_build_config_master_key_quoted():
    """master_key must be quoted in YAML to handle special characters."""
    config = SetupWizard._build_config(
        [_OPENAI],
        {"OPENAI_API_KEY": "sk-test"},
        'sk-master"special',
    )
    assert 'master_key: "sk-master\\"special"' in config


def test_build_config_does_not_mutate_env_vars():
    """_build_config must not modify the caller's env_vars dict."""
    env_vars = {
        "AZURE_API_KEY": "az-key",
        "_LITELLM_AZURE_API_BASE_AZURE": "https://my.azure.com",
        "_LITELLM_AZURE_DEPLOYMENT_AZURE": "my-deployment",
    }
    original_keys = set(env_vars.keys())
    SetupWizard._build_config([_AZURE], env_vars, "sk-master")
    assert set(env_vars.keys()) == original_keys


def test_build_config_azure_uses_deployment_name():
    env_vars = {
        "AZURE_API_KEY": "az-key",
        "_LITELLM_AZURE_API_BASE_AZURE": "https://my.azure.com",
        "_LITELLM_AZURE_DEPLOYMENT_AZURE": "my-gpt4o",
    }
    config = SetupWizard._build_config([_AZURE], env_vars, "sk-master")
    assert "model: azure/my-gpt4o" in config
    assert "model_name: azure-my-gpt4o" in config
    # api_base must be quoted to survive YAML special chars
    assert 'api_base: "https://my.azure.com"' in config


def test_build_config_azure_no_deployment_skipped():
    """Azure without a deployment name should emit nothing (not fallback to gpt-4o)."""
    env_vars = {"AZURE_API_KEY": "az-key"}  # no deployment sentinel
    config = SetupWizard._build_config([_AZURE], env_vars, "sk-master")
    # No azure model entry should be emitted when deployment name is absent
    assert "model: azure/" not in config


def test_build_config_no_display_name_collision_openai_and_azure():
    """OpenAI gpt-4o and azure gpt-4o should get distinct model_name values."""
    env_vars = {
        "OPENAI_API_KEY": "sk-openai",
        "AZURE_API_KEY": "az-key",
        "_LITELLM_AZURE_DEPLOYMENT_AZURE": "gpt-4o",
    }
    config = SetupWizard._build_config([_OPENAI, _AZURE], env_vars, "sk-master")
    assert "model_name: gpt-4o" in config  # OpenAI
    assert "model_name: azure-gpt-4o" in config  # Azure, qualified


def test_build_config_ollama_no_api_key_line():
    """Ollama has no env_key, so config should not contain an api_key line for it."""
    config = SetupWizard._build_config([_OLLAMA], {}, "sk-master")
    assert "ollama/llama3.2" in config
    assert "api_key:" not in config


def test_build_config_master_key_in_general_settings():
    """master_key is written to general_settings."""
    config = SetupWizard._build_config([_OPENAI], {"OPENAI_API_KEY": "k"}, "sk-m")
    assert 'master_key: "sk-m"' in config


def test_build_config_internal_sentinel_keys_excluded():
    """_LITELLM_ prefixed sentinel keys must not appear in environment_variables."""
    env_vars = {
        "OPENAI_API_KEY": "sk-real",
        "_LITELLM_AZURE_API_BASE_AZURE": "https://x.azure.com",
    }
    config = SetupWizard._build_config([_OPENAI], env_vars, "sk-master")
    assert "_LITELLM_" not in config
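
The `_yaml_escape` tests above pin down the expected output for values placed inside a double-quoted YAML scalar. A minimal function that satisfies those cases might look like the following. `yaml_escape_sketch` is a hypothetical name, and the real `litellm.setup_wizard._yaml_escape` may be implemented differently.

```python
def yaml_escape_sketch(value: str) -> str:
    """Escape a string for use inside a double-quoted YAML scalar.

    Hypothetical stand-in for illustration; not the litellm implementation.
    """
    return (
        value.replace("\\", "\\\\")  # backslash first, so later escapes are not doubled
        .replace('"', '\\"')
        .replace("\n", "\\n")
        .replace("\r", "\\r")
        .replace("\t", "\\t")
    )


# Example: writing a master key that contains a double quote.
key = 'sk-master"special'
line = f'master_key: "{yaml_escape_sketch(key)}"'
```

Replacing the backslash before anything else matters: if the quote were escaped first, the backslash it introduces would be doubled by the later backslash pass.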


@ -20,7 +20,6 @@ import {
ToolOutlined,
TagsOutlined,
AuditOutlined,
MessageOutlined,
} from "@ant-design/icons";
// import {
// all_admin_roles,
@ -466,41 +465,6 @@ const Sidebar2: React.FC<SidebarProps> = ({ accessToken, userRole, defaultSelect
</ConfigProvider>
{isAdminRole(userRole) && !collapsed && <UsageIndicator accessToken={accessToken} width={220} />}
{/* Pinned "Open Chat" button at bottom */}
<div style={{
padding: collapsed ? "10px 8px" : "10px 12px",
borderTop: "1px solid #f0f0f0",
flexShrink: 0,
}}>
<a
href={toHref("chat")}
target="_blank"
rel="noopener noreferrer"
style={{
display: "flex",
alignItems: "center",
justifyContent: collapsed ? "center" : "flex-start",
gap: 8,
padding: collapsed ? "8px 0" : "8px 10px",
borderRadius: 8,
background: "#1677ff",
color: "#fff",
textDecoration: "none",
fontSize: 13,
fontWeight: 600,
transition: "background 0.15s",
}}
onMouseEnter={(e) => {
(e.currentTarget as HTMLAnchorElement).style.background = "#0958d9";
}}
onMouseLeave={(e) => {
(e.currentTarget as HTMLAnchorElement).style.background = "#1677ff";
}}
>
<MessageOutlined style={{ fontSize: 16, flexShrink: 0 }} />
{!collapsed && <span>Open Chat</span>}
</a>
</div>
</Sider>
</Layout>
);


@ -8,8 +8,6 @@ import ComplianceUI from "@/components/playground/complianceUI/ComplianceUI";
import { TabGroup, TabList, Tab, TabPanels, TabPanel } from "@tremor/react";
import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized";
import { fetchProxySettings } from "@/utils/proxyUtils";
import { useUIConfig } from "@/app/(dashboard)/hooks/uiConfig/useUIConfig";
import { MessageOutlined, CloseOutlined } from "@ant-design/icons";
interface ProxySettings {
PROXY_BASE_URL?: string;
@ -19,12 +17,6 @@ interface ProxySettings {
export default function PlaygroundPage() {
const { accessToken, userRole, userId, disabledPersonalKeyCreation, token } = useAuthorized();
const [proxySettings, setProxySettings] = useState<ProxySettings | undefined>(undefined);
const [chatBannerDismissed, setChatBannerDismissed] = useState(false);
const { data: uiConfig } = useUIConfig();
const uiRoot = uiConfig?.server_root_path && uiConfig.server_root_path !== "/"
? uiConfig.server_root_path.replace(/\/+$/, "")
: "";
const chatHref = `${uiRoot}/ui/chat`;
useEffect(() => {
const initializeProxySettings = async () => {
@ -44,64 +36,6 @@ export default function PlaygroundPage() {
return (
<div className="h-full w-full flex flex-col">
{!chatBannerDismissed && (
<div style={{
display: "flex",
alignItems: "center",
gap: 16,
padding: "10px 20px",
background: "#f0f9ff",
borderBottom: "1px solid #bae6fd",
flexShrink: 0,
}}>
<span style={{
fontSize: 10,
fontWeight: 700,
color: "#fff",
background: "#0ea5e9",
borderRadius: 4,
padding: "2px 7px",
letterSpacing: "0.08em",
textTransform: "uppercase",
flexShrink: 0,
lineHeight: "18px",
}}>
New
</span>
<span style={{ flex: 1, color: "#0c4a6e", fontSize: 13.5, lineHeight: 1.5 }}>
<strong>Chat UI</strong>
{" "} a ChatGPT-like interface for your users to chat with AI models and MCP tools. Share it with your team.
</span>
<a
href={chatHref}
target="_blank"
rel="noopener noreferrer"
style={{
display: "inline-flex",
alignItems: "center",
gap: 5,
padding: "5px 14px",
borderRadius: 6,
background: "#0ea5e9",
color: "#fff",
fontSize: 12.5,
fontWeight: 600,
textDecoration: "none",
whiteSpace: "nowrap",
flexShrink: 0,
}}
>
Open Chat UI
</a>
<button
onClick={() => setChatBannerDismissed(true)}
style={{ background: "none", border: "none", cursor: "pointer", color: "#64748b", padding: 4, flexShrink: 0, lineHeight: 1 }}
aria-label="Dismiss"
>
<CloseOutlined style={{ fontSize: 13 }} />
</button>
</div>
)}
<TabGroup className="w-full" style={{ flex: 1, minHeight: 0, display: "flex", flexDirection: "column" }}>
<TabList className="mb-0">
<Tab>Chat</Tab>