[Staging] - Ishaan March 17th (#23903)
* feat(xai): add grok-4.20 beta 2 models with pricing (#23900)
Add three grok-4.20 beta 2 model variants from xAI:
- grok-4.20-multi-agent-beta-0309 (reasoning + multi-agent)
- grok-4.20-beta-0309-reasoning (reasoning)
- grok-4.20-beta-0309-non-reasoning
Pricing (from https://docs.x.ai/docs/models):
- Input: $2.00/1M tokens ($0.20/1M cached)
- Output: $6.00/1M tokens
- Context: 2M tokens
All variants support vision, function calling, tool choice, and web search.
Closes LIT-2171
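A rough cost sketch based on the prices listed above. The function name and the billing split are illustrative assumptions (cached tokens are treated as a subset of input tokens and billed at the cached rate); this is not LiteLLM's cost calculator.

```python
# Prices from the commit message, in USD per 1M tokens.
INPUT_PER_M = 2.00
CACHED_INPUT_PER_M = 0.20
OUTPUT_PER_M = 6.00


def estimate_cost_usd(
    input_tokens: int, output_tokens: int, cached_tokens: int = 0
) -> float:
    """Hypothetical helper: estimate the cost of one grok-4.20 beta 2 request.

    Assumption: cached_tokens is a subset of input_tokens and is billed at
    the cached rate; the remaining input tokens use the full input rate.
    """
    if not 0 <= cached_tokens <= input_tokens:
        raise ValueError("cached_tokens must be between 0 and input_tokens")
    uncached_tokens = input_tokens - cached_tokens
    total = (
        uncached_tokens * INPUT_PER_M
        + cached_tokens * CACHED_INPUT_PER_M
        + output_tokens * OUTPUT_PER_M
    )
    return total / 1_000_000
```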
* docs: add Quick Install section for litellm --setup wizard (#23905)
* docs: add Quick Install section for litellm --setup wizard
* docs: clarify setup wizard is for local/beginner use
* feat(setup): interactive setup wizard + install.sh (#23644)
* feat(setup): add interactive setup wizard + install.sh
Adds `litellm --setup` — a Claude Code-style TUI onboarding wizard that
guides users through provider selection, API key entry, and proxy config
generation, then optionally starts the proxy immediately.
- litellm/setup_wizard.py: wizard with ASCII art, numbered provider menu
(OpenAI, Anthropic, Azure, Gemini, Bedrock, Ollama), API key prompts,
port/master-key config, and litellm_config.yaml generation
- litellm/proxy/proxy_cli.py: adds --setup flag that invokes the wizard
- scripts/install.sh: curl-installable script (detect OS/Python, pip
install litellm[proxy], launch wizard)
Usage:
curl -fsSL https://raw.githubusercontent.com/BerriAI/litellm/main/scripts/install.sh | sh
litellm --setup
* fix(install.sh): remove orange color, add LITELLM_BRANCH env var for branch installs
* fix(install.sh): install from git branch so --setup is available for QA
* fix(install.sh): remove stale LITELLM_BRANCH reference that caused unbound variable error
* fix(install.sh): force-reinstall from git to bypass cached PyPI version
* fix(install.sh): show pip progress bar during install
* fix(install.sh): always launch wizard via $PYTHON_BIN -m litellm, not PATH binary
* fix(install.sh): use litellm.proxy.proxy_cli module (no __main__.py exists)
* fix(install.sh): suppress RuntimeWarning from module invocation
* fix(install.sh): use Python bin-dir litellm binary to avoid CWD sys.path shadowing
* fix(install.sh): use sysconfig.get_path('scripts') to find pip-installed litellm binary
* fix(install.sh): redirect stdin from /dev/tty on exec so wizard gets terminal, not exhausted pipe
* fix(install.sh): warn about git clone duration, drop --no-cache-dir so re-runs are faster
* feat(setup_wizard): arrow-key selector, updated model names
* fix(setup_wizard): use sysconfig binary to start proxy, not python -m litellm
* feat(setup_wizard): credential validation after key entry + clear next-steps after proxy start
* style(install.sh): show git clone warning in blue
* refactor(setup_wizard): class with static methods, use check_valid_key from litellm.utils
* address greptile review: fix yaml escaping, port validation, display name collisions, tests
- setup_wizard.py: add _yaml_escape() for safe YAML embedding of API keys
- setup_wizard.py: add _styled_input() with readline ANSI ignore markers
- setup_wizard.py: change DIVIDER to _divider() fn to avoid import-time color capture
- setup_wizard.py: validate port range 1-65535, initialize before loop
- setup_wizard.py: qualify azure display names (azure-gpt-4o) to avoid collision with openai
- setup_wizard.py: work on env_copy in _build_config to avoid mutating caller's dict
- setup_wizard.py: skip model_list entries for providers with no credentials
- setup_wizard.py: prompt for azure deployment name
- setup_wizard.py: wrap os.execlp in try/except with friendly fallback
- setup_wizard.py: wrap config write in try/except OSError
- setup_wizard.py: fix _validate_and_report to use two print lines (no \r overwrite)
- setup_wizard.py: add .gitignore tip next to key storage notice
- setup_wizard.py: fix run_setup_wizard() return type annotation to None
- scripts/install.sh: drop pipefail (not supported by dash on Ubuntu when invoked as sh)
- scripts/install.sh: use litellm[proxy] from PyPI (not hardcoded dev branch)
- scripts/install.sh: guard /dev/tty read with -r check for Docker/CI compat
- scripts/install.sh: remove --force-reinstall to avoid downgrading dependencies
- tests/test_litellm/test_setup_wizard.py: 13 unit tests for _build_config and _yaml_escape
* style: black format setup_wizard.py
* fix: address remaining greptile issues - Windows compat, YAML quoting, credential flow
- guard termios/tty imports with try/except ImportError for Windows compat
- quote master_key as YAML double-quoted scalar (same as env vars)
- remove unused port param from _build_config signature
- _validate_and_report now returns the final key so re-entered creds are stored
- add test for master_key YAML quoting
* fix: add --port to suggested command, guard /dev/tty exec in install.sh
* fix: quote api_base in YAML, skip azure if no deployment, only redraw on state change
* fix: address greptile review comments
- _yaml_escape: add control character escaping (\n, \r, \t)
- test: fix tautological assertion in test_build_config_azure_no_deployment_skipped
- test: add tests for control character escaping in _yaml_escape
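The escaping and port checks described in the review fixes above can be sketched as follows. `yaml_escape` and `parse_port` are illustrative stand-ins written for this note, not the wizard's actual `_yaml_escape` or its port prompt; only the behaviour named in the commit message (double-quoted scalars, `\n`/`\r`/`\t` escaping, range 1-65535) is taken from the source.

```python
def yaml_escape(value: str) -> str:
    """Escape a string for embedding inside a YAML double-quoted scalar."""
    return (
        value.replace("\\", "\\\\")  # backslash first, so later escapes are not doubled
        .replace('"', '\\"')
        .replace("\n", "\\n")
        .replace("\r", "\\r")
        .replace("\t", "\\t")
    )


def parse_port(raw: str, default: int = 4000) -> int:
    """Return a port in 1-65535; empty input falls back to the default."""
    if not raw.strip():
        return default
    port = int(raw)  # raises ValueError for non-numeric input
    if not 1 <= port <= 65535:
        raise ValueError(f"port out of range: {port}")
    return port


# A key containing a quote and a newline stays on one YAML line.
api_key = 'sk-"quoted"\nsecond-line'
line = f'OPENAI_API_KEY: "{yaml_escape(api_key)}"'
```

Escaping the backslash before the other characters matters: doing it last would double the backslashes that the quote and newline replacements just introduced.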
* feat(ui): remove Chat UI page link and banner from sidebar and playground (#23908)
* feat(guardrails): MCPJWTSigner - built-in guardrail for zero trust MCP auth (#23897)
* Allow pre_mcp_call guardrail hooks to mutate outbound MCP headers
* Enhance MCPServerManager to support hook-modified arguments and extra headers. Update tests to validate argument mutation and header injection behavior, including warnings for OpenAPI-backed servers when headers are present.
* Refactor MCPServerManager to raise HTTPException for extra headers in OpenAPI-backed servers. Update tests to reflect this change, ensuring proper exception handling instead of logging warnings.
* feat(guardrails): add MCPJWTSigner built-in guardrail for zero trust MCP auth
Signs outbound MCP tool calls with a LiteLLM-issued RS256 JWT so MCP servers
can trust a single signing authority instead of every upstream IdP.
Enable in config.yaml:
guardrails:
- guardrail_name: mcp-jwt-signer
litellm_params:
guardrail: mcp_jwt_signer
mode: pre_mcp_call
default_on: true
JWT carries sub (user_id), act.sub (team_id, RFC 8693), tool-level scope, iss,
aud, iat/exp/nbf. RSA-2048 keypair auto-generated at startup unless
MCP_JWT_SIGNING_KEY env var is set.
Adds /.well-known/jwks.json endpoint and jwks_uri to /.well-known/openid-configuration
so MCP servers can verify LiteLLM-issued tokens via OIDC discovery.
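A minimal sketch of the claim set described above, assuming a plain dict payload. `build_claims`, its parameters, and the 300-second default TTL are hypothetical; RS256 signing and key handling are deliberately left out, and the scope string is passed in rather than generated because its exact format is not stated here.

```python
import time
from typing import Any, Dict, Optional


def build_claims(
    user_id: str,
    team_id: Optional[str],
    scope: str,
    issuer: str,
    audience: str,
    ttl_seconds: int = 300,  # assumed default, not taken from the source
    now: Optional[int] = None,
) -> Dict[str, Any]:
    """Hypothetical helper: assemble the payload of an outbound MCP JWT."""
    if ttl_seconds <= 0:
        raise ValueError("ttl_seconds must be > 0")
    issued_at = int(time.time()) if now is None else now
    claims: Dict[str, Any] = {
        "iss": issuer,
        "aud": audience,
        "sub": user_id,
        "scope": scope,
        "iat": issued_at,
        "nbf": issued_at,
        "exp": issued_at + ttl_seconds,
    }
    if team_id is not None:
        # RFC 8693 actor claim: the team acting on behalf of the user.
        claims["act"] = {"sub": team_id}
    return claims
```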
* Update MCPServerManager to raise HTTPException with status code 400 for extra headers in OpenAPI-backed servers. Adjust tests to verify the correct status code and exception message.
* fix: address P1 issues in MCPJWTSigner
- OpenAPI servers: warn + skip header injection instead of 500
- JWKS Cache-Control: 5min for auto-generated keys, 1h for persistent
- sub claim: fallback to apikey:{token_hash} for anonymous callers
- ttl_seconds: validate > 0 at init time
* docs: add MCP zero trust auth guide with architecture diagram
* docs: add FastMCP JWT verification guide to zero trust doc
* fix: address remaining Greptile review issues (round 2)
- mcp_server_manager: warn when hook Authorization overwrites existing header
- __init__: remove _mcp_jwt_signer_instance from __all__ (private internal)
- discoverable_endpoints: copy dict instead of mutating in-place on OIDC augmentation
- test docstring: reflect warn-and-continue behavior for OpenAPI servers
- test: update scope assertions for least-privilege (no mcp:tools/list on tool-call JWTs)
* fix: address Greptile round 3 feedback
- initialize_guardrail: validate mode='pre_mcp_call' at init time — misconfigured
mode silently bypasses JWT injection, which is a zero-trust bypass
- _build_claims: remove duplicate inline 'import re' (module-level import already present)
- _types.py: add TODO comment explaining jwt_claims is forward-compat plumbing
for a follow-up PR that will forward upstream IdP claims into outbound MCP JWTs
* feat(mcp_jwt_signer): add verify+re-sign, claim ops, two-token model, configurable scopes
Addresses all missing pieces from the scoping doc review:
FR-5 (Verify + re-sign): MCPJWTSigner now accepts access_token_discovery_uri
and token_introspection_endpoint. When set, the incoming Bearer token is
extracted from raw_headers (threaded through pre_call_tool_check), verified
against the IdP's JWKS (JWT) or introspected (opaque), and only re-signed if
valid. Falls back to user_api_key_dict.jwt_claims for LiteLLM JWT-auth mode.
FR-12 (Configurable end-user identity mapping): end_user_claim_sources
ordered list drives sub resolution — sources: token:<claim>, litellm:user_id,
litellm:email, litellm:end_user_id, litellm:team_id.
FR-13 (Claim operations): add_claims (insert-if-absent), set_claims (always
override), remove_claims (delete) applied in that order.
FR-14 (Two-token model): channel_token_audience + channel_token_ttl issue a
second JWT injected as x-mcp-channel-token: Bearer <token>.
FR-15 (Incoming claim validation): required_claims raises HTTP 403 when any
listed claim is absent; optional_claims passes listed claims from verified
token into the outbound JWT.
FR-9 (Debug headers): debug_headers: true emits x-litellm-mcp-debug with kid,
sub, iss, exp, scope.
FR-10 (Configurable scopes): allowed_scopes replaces auto-generation. Also
fixed: tool-call JWTs no longer grant mcp:tools/list (overpermission).
P1 fixes:
- proxy/utils.py: _convert_mcp_hook_response_to_kwargs merges rather than
replaces extra_headers, preserving headers from prior guardrails.
- mcp_server_manager.py: warns when hook injects Authorization alongside a
server-configured authentication_token (previously silent).
- mcp_server_manager.py: pre_call_tool_check now accepts raw_headers and
extracts incoming_bearer_token so FR-5 verification has the raw token.
- proxy/utils.py: remove stray inline import inspect inside loop (pre-existing
lint error, now cleaned up).
Tests: 43 passing (28 new tests covering all FR flags + P1 fixes).
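The FR-13 ordering above (add, then set, then remove) can be illustrated with a small sketch. `apply_claim_operations` is a hypothetical function written for this note; only the three operation names and their order come from the commit message.

```python
from typing import Any, Dict, Iterable, Optional


def apply_claim_operations(
    claims: Dict[str, Any],
    add_claims: Optional[Dict[str, Any]] = None,
    set_claims: Optional[Dict[str, Any]] = None,
    remove_claims: Optional[Iterable[str]] = None,
) -> Dict[str, Any]:
    """Apply add (insert-if-absent), set (always override), remove, in that order."""
    result = dict(claims)  # work on a copy so the caller's dict is untouched
    for key, value in (add_claims or {}).items():
        result.setdefault(key, value)  # keeps an existing value
    result.update(set_claims or {})  # overrides unconditionally
    for key in remove_claims or ():
        result.pop(key, None)  # absent keys are ignored
    return result


claims = apply_claim_operations(
    {"sub": "user-1", "env": "dev"},
    add_claims={"env": "prod", "tenant": "acme"},  # "env" exists, so it is kept
    set_claims={"sub": "user-2"},
    remove_claims=["tenant"],
)
```

Because removal runs last, a claim listed in both `add_claims` and `remove_claims` ends up absent from the result.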
* feat(mcp_jwt_signer): add verify+re-sign, claim ops, two-token model, configurable scopes (core)
Remaining files from the FR implementation:
mcp_jwt_signer.py — full rewrite with all new params:
FR-5: access_token_discovery_uri, token_introspection_endpoint,
verify_issuer, verify_audience + _verify_incoming_jwt(),
_introspect_opaque_token()
FR-12: end_user_claim_sources ordered resolution chain
FR-13: add_claims, set_claims, remove_claims
FR-14: channel_token_audience, channel_token_ttl → x-mcp-channel-token
FR-15: required_claims (raises 403), optional_claims (passthrough)
FR-9: debug_headers → x-litellm-mcp-debug
FR-10: allowed_scopes; tool-call JWTs no longer over-grant tools/list
mcp_server_manager.py:
- pre_call_tool_check gains raw_headers param to extract incoming_bearer_token
- Silent Authorization override warning fixed: now fires when server has
authentication_token AND hook injects Authorization
tests/test_mcp_jwt_signer.py:
28 new tests covering all FR flags + P1 fixes (43 total, all passing)
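A sketch of how an ordered `end_user_claim_sources` chain (FR-12) might resolve `sub`. The source prefixes `token:` and `litellm:` come from the commit message; the function `resolve_sub`, the dict-based inputs, and the "first non-empty value wins" rule are assumptions made for illustration.

```python
from typing import Any, Dict, List, Optional


def resolve_sub(
    sources: List[str],
    token_claims: Dict[str, Any],
    litellm_context: Dict[str, Any],
) -> Optional[str]:
    """Walk the ordered sources and return the first non-empty value."""
    for source in sources:
        kind, _, name = source.partition(":")
        if kind == "token":
            value = token_claims.get(name)  # claim from the verified incoming token
        elif kind == "litellm":
            value = litellm_context.get(name)  # e.g. user_id, email, team_id
        else:
            raise ValueError(f"unknown claim source: {source!r}")
        if value:
            return str(value)
    return None  # nothing matched; the caller decides on a fallback


sub = resolve_sub(
    ["token:email", "litellm:user_id", "litellm:team_id"],
    token_claims={"sub": "idp-123"},  # no "email" claim present
    litellm_context={"user_id": "user-42", "team_id": "team-7"},
)
```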
* fix(mcp_jwt_signer): address pre-landing review issues
- Remove stale TODO comment on UserAPIKeyAuth.jwt_claims — the field is
already populated and consumed by MCPJWTSigner in the same PR
- Fix _get_oidc_discovery to only cache the OIDC discovery doc when
jwks_uri is present; a malformed/empty doc now retries on the next
request instead of being permanently cached until proxy restart
- Add FR-5 test coverage for _fetch_jwks (cache hit/miss),
_get_oidc_discovery (cache/no-cache on bad doc), _verify_incoming_jwt
(valid token, expired token), _introspect_opaque_token (active,
inactive, no endpoint), and the end-to-end 401 hook path — 53 tests
total, all passing
* docs(mcp_zero_trust): rewrite as use-case guide covering all new JWT signer features
Add scenario-driven sections for each new config area:
- Verify+re-sign with Okta/Azure AD (access_token_discovery_uri,
end_user_claim_sources, token_introspection_endpoint)
- Enforcing caller attributes with required_claims / optional_claims
- Adding metadata via add_claims / set_claims / remove_claims
- Two-token model for AWS Bedrock AgentCore Gateway
(channel_token_audience / channel_token_ttl)
- Controlling scopes with allowed_scopes
- Debugging JWT rejections with debug_headers
Update JWT claims table to reflect configurable sub (end_user_claim_sources)
* fix(mcp_jwt_signer): wire all config.yaml params through initialize_guardrail
The factory was only passing issuer/audience/ttl_seconds to MCPJWTSigner.
All FR-5/9/10/12/13/14/15 params (access_token_discovery_uri,
end_user_claim_sources, add/set/remove_claims, channel_token_audience,
required/optional_claims, debug_headers, allowed_scopes, etc.) were
silently dropped, making every advertised advanced feature non-functional
when loaded from config.yaml.
Add regression test that asserts every param is wired through correctly.
* docs(mcp_zero_trust): add hero image
* docs(mcp_zero_trust): apply Linear-style edits
- Lead with the problem (unsigned direct calls bypass access controls)
- Shorter statement section headers instead of question-form headers
- Move diagram/OIDC discovery block after the reader is bought in
- Add 'read further only if you need to' callout after basic setup
- Two-token section now opens from the user problem not product jargon
- Add concrete 403 error response example in required_claims section
- Debug section opens from the symptom (MCP server returning 401)
- Lowercase claims reference header for consistency
* fix(mcp_jwt_signer): fix algorithm confusion attack + add OIDC discovery 24h TTL
- Remove alg from unverified JWT header; use signing_jwk.algorithm_name from JWKS key instead.
Reading alg from attacker-controlled headers enables alg:none / HS256 confusion attacks.
- Add _oidc_discovery_fetched_at timestamp and _OIDC_DISCOVERY_TTL = 86400 (24h).
Without a TTL the cached discovery doc never refreshes, so IdP key rotation is invisible.
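The two caching rules described in this commit list (cache the discovery document only when `jwks_uri` is present, and expire it after 24 hours) can be sketched together. `DiscoveryCache` and its injected `fetch` and `clock` callables are hypothetical; only the 86400-second TTL and the two rules come from the commit messages.

```python
import time
from typing import Any, Callable, Dict, Optional

OIDC_DISCOVERY_TTL = 86400  # 24h, as stated in the commit message


class DiscoveryCache:
    """Hypothetical cache for an OIDC discovery document."""

    def __init__(
        self,
        fetch: Callable[[], Dict[str, Any]],
        ttl: float = OIDC_DISCOVERY_TTL,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._fetch = fetch
        self._ttl = ttl
        self._clock = clock
        self._doc: Optional[Dict[str, Any]] = None
        self._fetched_at = 0.0

    def get(self) -> Dict[str, Any]:
        now = self._clock()
        if self._doc is not None and now - self._fetched_at < self._ttl:
            return self._doc  # fresh cache hit
        doc = self._fetch()
        if doc.get("jwks_uri"):
            # Only a usable document is cached; a malformed or empty one
            # is returned once and fetched again on the next request.
            self._doc = doc
            self._fetched_at = now
        return doc
```

Injecting the clock keeps the expiry path testable without sleeping.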
---------
Co-authored-by: Noah Nistler <60981020+noahnistler@users.noreply.github.com>
* fix(ci): stabilize CI - formatting, type errors, test polling, security CVEs, router bug, batch resolution
Fix 1: Run Black formatter on 35 files
Fix 2: Fix MyPy type errors:
- setup_wizard.py: add type annotation for 'selected' set variable
- user_api_key_auth.py: remove redundant type annotation on jwt_claims reassignment
Fix 3: Fix spend accuracy test burst 2 polling to wait for expected total
spend instead of just 'any increase' from burst 2
Fix 4: Bump Next.js 16.1.6 -> 16.1.7 to fix CVE-2026-27978, CVE-2026-27979,
CVE-2026-27980, CVE-2026-29057
Fix 5: Fix router _pre_call_checks model variable being overwritten inside
loop, causing wrong model lookups on subsequent deployments. Use local
_deployment_model variable instead.
Fix 6: Add missing resolve_output_file_ids_to_unified call in batch retrieve
non-terminal-to-terminal path (matching the terminal path behavior)
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
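The class of bug behind Fix 5 can be shown with a generic sketch. This is not the router's code: the function, the dict keys (`model_name`, `underlying_model`), and the context-window lookup are invented for illustration. Only the pattern is taken from the commit message, namely a per-deployment value stored in a local `_deployment_model` instead of overwriting `model` inside the loop.

```python
from typing import Dict, List


def eligible_deployments(
    model: str,
    deployments: List[Dict[str, str]],
    context_windows: Dict[str, int],
    needed_tokens: int,
) -> List[Dict[str, str]]:
    """Keep deployments of the requested group that fit the prompt."""
    kept = []
    for deployment in deployments:
        # `model` must still hold the requested group name on every pass.
        if deployment["model_name"] != model:
            continue
        # Use a separate local name. Writing `model = ...` here would make
        # the group check above fail for every later deployment.
        _deployment_model = deployment["underlying_model"]
        if context_windows.get(_deployment_model, 0) >= needed_tokens:
            kept.append(deployment)
    return kept


deployments = [
    {"model_name": "chat", "underlying_model": "provider-a/large"},
    {"model_name": "chat", "underlying_model": "provider-b/large"},
]
windows = {"provider-a/large": 8000, "provider-b/large": 16000}
result = eligible_deployments("chat", deployments, windows, needed_tokens=4000)
```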
* chore: regenerate poetry.lock to sync with pyproject.toml
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
* fix: format merged files from main and regenerate poetry.lock
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
* fix(mypy): annotate jwt_claims as Optional[dict] to fix type incompatibility
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
* fix(ci): update router region test to use gpt-4.1-mini (fix flaky model lookup)
Replace deprecated gpt-3.5-turbo-1106 with gpt-4.1-mini + mock_response in
test_router_region_pre_call_check, following the same pattern used in commit
717d37cc5b for test_router_context_window_check_pre_call_check_out_group.
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
* ci: retry flaky logging_testing (async event loop race condition)
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
* fix(ci): aggregate all mock calls in langfuse e2e test to fix race condition
The _verify_langfuse_call helper only inspected the last mock call
(mock_post.call_args), but the Langfuse SDK may split trace-create and
generation-create events across separate HTTP flush cycles. This caused
an IndexError when the last call's batch contained only one event type.
Fix: iterate over mock_post.call_args_list to collect batch items from
ALL calls. Also add a safety assertion after filtering by trace_id and
mark all langfuse e2e tests with @pytest.mark.flaky(retries=3) as an
extra safety net for any residual timing issues.
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
* fix(ci): black formatting + update OpenAPI compliance tests for spec changes
- Apply Black 26.x formatting to litellm_logging.py (parenthesized style)
- Update test_input_types_match_spec to follow $ref to InteractionsInput schema
(Google updated their OpenAPI spec to use $ref instead of inline oneOf)
- Update test_content_schema_uses_discriminator to handle discriminator without
explicit mapping (Google removed the mapping key from Content discriminator)
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
* revert: undo incorrect Black 26.x formatting on litellm_logging.py
The file was correctly formatted for Black 23.12.1 (the version pinned
in pyproject.toml). The previous commit applied Black 26.x formatting
which was incompatible with the CI's Black version.
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
* fix(ci): deduplicate and sort langfuse batch events after aggregation
The Langfuse SDK may send the same event (e.g., trace-create) in
multiple flush cycles, causing duplicates when we aggregate from all
mock calls. After filtering by trace_id, deduplicate by keeping only
the first event of each type, then sort to ensure trace-create is at
index 0 and generation-create at index 1.
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
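The aggregate, deduplicate, and sort steps described in the two Langfuse test fixes can be sketched with `unittest.mock`. The payload shape is an assumption made for this example (each call passes `json={"batch": [...]}` and each event carries `type` and `trace_id` keys); the real SDK payload and the real `_verify_langfuse_call` helper may differ.

```python
from typing import Any, Dict, List
from unittest.mock import MagicMock

# Desired order of event types in the final list.
EVENT_ORDER = {"trace-create": 0, "generation-create": 1}


def collect_events(mock_post: MagicMock, trace_id: str) -> List[Dict[str, Any]]:
    """Gather batch events from ALL mock calls, not just the last one."""
    events: List[Dict[str, Any]] = []
    for call in mock_post.call_args_list:
        events.extend(call.kwargs["json"]["batch"])
    matching = [e for e in events if e["trace_id"] == trace_id]
    if not matching:
        raise AssertionError(f"no events found for trace {trace_id}")
    first_by_type: Dict[str, Dict[str, Any]] = {}
    for event in matching:
        first_by_type.setdefault(event["type"], event)  # keep first of each type
    return sorted(
        first_by_type.values(),
        key=lambda e: EVENT_ORDER.get(e["type"], len(EVENT_ORDER)),
    )


# Two flush cycles: events arrive out of order and with a duplicate.
mock_post = MagicMock()
mock_post(json={"batch": [{"type": "generation-create", "trace_id": "t1"}]})
mock_post(
    json={
        "batch": [
            {"type": "trace-create", "trace_id": "t1"},
            {"type": "trace-create", "trace_id": "t1"},
        ]
    }
)
events = collect_events(mock_post, "t1")
```

Reading `mock_post.call_args` alone would have seen only the second call, which is the situation the commit message describes as causing the `IndexError`.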
---------
Co-authored-by: Noah Nistler <60981020+noahnistler@users.noreply.github.com>
Co-authored-by: Cursor Agent <cursoragent@cursor.com>
Co-authored-by: Ishaan Jaff <ishaan-jaff@users.noreply.github.com>
Parent: cbb4c2c220
Commit: 8e61b32b8e
@ -140,6 +140,11 @@ LiteLLM is a unified interface for 100+ LLM providers with two main components:
- **Check index coverage.** For new or modified queries, check `schema.prisma` for a supporting index. Prefer extending an existing index (e.g. `@@index([a])` → `@@index([a, b])`) over adding a new one, unless it's a `@@unique`. Only add indexes for large/frequent queries.
- **Keep schema files in sync.** Apply schema changes to all `schema.prisma` copies (`schema.prisma`, `litellm/proxy/`, `litellm-proxy-extras/`, `litellm-js/spend-logs/` for SpendLogs) with a migration under `litellm-proxy-extras/litellm_proxy_extras/migrations/`.

### Setup Wizard (`litellm/setup_wizard.py`)
- The wizard is implemented as a single `SetupWizard` class with `@staticmethod` methods — keep it that way. No module-level functions except `run_setup_wizard()` (the public entrypoint) and pure helpers (color, ANSI).
- Use `litellm.utils.check_valid_key(model, api_key)` for credential validation — never roll a custom completion call.
- Do not hardcode provider env-key names or model lists that already exist in the codebase. Add a `test_model` field to each provider entry to drive `check_valid_key`; set it to `None` for providers that can't be validated with a single API key (Azure, Bedrock, Ollama).

### Enterprise Features
- Enterprise-specific code in `enterprise/` directory
- Optional features enabled via environment variables
@ -5,11 +5,76 @@ import Image from '@theme/IdealImage';
# Getting Started Tutorial

End-to-End tutorial for LiteLLM Proxy to:
- Add an Azure OpenAI model
- Make a successful /chat/completion call
- Generate a virtual key
- Set RPM limit on virtual key
- Add an Azure OpenAI model
- Make a successful /chat/completion call
- Generate a virtual key
- Set RPM limit on virtual key

## Quick Install (Recommended for local / beginners)

New to LiteLLM? This is the easiest way to get started locally. One command installs LiteLLM and walks you through setup interactively — no config files to write by hand.

### 1. Install

```bash
curl -fsSL https://raw.githubusercontent.com/BerriAI/litellm/main/scripts/install.sh | sh
```

This detects your OS, installs `litellm[proxy]`, and drops you straight into the setup wizard.

### 2. Follow the wizard

```
$ litellm --setup

Welcome to LiteLLM

Choose your LLM providers
○ 1. OpenAI GPT-4o, GPT-4o-mini, o1
○ 2. Anthropic Claude Opus, Sonnet, Haiku
○ 3. Azure OpenAI GPT-4o via Azure
○ 4. Google Gemini Gemini 2.0 Flash, 1.5 Pro
○ 5. AWS Bedrock Claude, Llama via AWS
○ 6. Ollama Local models

❯ Provider(s): 1,2

❯ OpenAI API key: sk-...
❯ Anthropic API key: sk-ant-...

❯ Port [4000]:
❯ Master key [auto-generate]:

✔ Config saved → ./litellm_config.yaml

❯ Start the proxy now? (Y/n):
```

The wizard walks you through:
1. Pick your LLM providers (OpenAI, Anthropic, Azure, Bedrock, Gemini, Ollama)
2. Enter API keys for each provider
3. Set a port and master key (or accept the defaults)
4. Config is saved to `./litellm_config.yaml` and the proxy starts immediately

### 3. Make a call

Your proxy is running on `http://0.0.0.0:4000`. Test it:

```bash
curl -X POST 'http://0.0.0.0:4000/chat/completions' \
-H 'Content-Type: application/json' \
-H 'Authorization: Bearer <your-master-key>' \
-d '{
"model": "gpt-4o",
"messages": [{"role": "user", "content": "Hello!"}]
}'
```

:::tip Already have pip installed?
You can skip the curl install and run `litellm --setup` directly after `pip install 'litellm[proxy]'`.
:::

---

## Pre-Requisites

BIN
docs/my-website/img/mcp_zero_trust_gateway.png
Normal file
Binary file not shown. Size: 294 KiB
@ -672,6 +672,7 @@ const sidebars = {
"mcp_control",
"mcp_cost",
"mcp_guardrail",
"mcp_zero_trust",
"mcp_troubleshoot",
]
},
@ -1465,9 +1465,15 @@ if TYPE_CHECKING:
from .llms.petals.completion.transformation import PetalsConfig as PetalsConfig
from .llms.ollama.chat.transformation import OllamaChatConfig as OllamaChatConfig
from .llms.ollama.completion.transformation import OllamaConfig as OllamaConfig
from .llms.sagemaker.completion.transformation import SagemakerConfig as SagemakerConfig
from .llms.sagemaker.chat.transformation import SagemakerChatConfig as SagemakerChatConfig
from .llms.sagemaker.nova.transformation import SagemakerNovaConfig as SagemakerNovaConfig
from .llms.sagemaker.completion.transformation import (
SagemakerConfig as SagemakerConfig,
)
from .llms.sagemaker.chat.transformation import (
SagemakerChatConfig as SagemakerChatConfig,
)
from .llms.sagemaker.nova.transformation import (
SagemakerNovaConfig as SagemakerNovaConfig,
)
from .llms.cohere.chat.transformation import CohereChatConfig as CohereChatConfig
from .llms.anthropic.experimental_pass_through.messages.transformation import (
AnthropicMessagesConfig as AnthropicMessagesConfig,
@ -17,7 +17,9 @@ if set_verbose is True:
"`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs."
)

_ENABLE_SECRET_REDACTION = os.getenv("LITELLM_DISABLE_REDACT_SECRETS", "").lower() != "true"
_ENABLE_SECRET_REDACTION = (
os.getenv("LITELLM_DISABLE_REDACT_SECRETS", "").lower() != "true"
)

_REDACTED = "REDACTED"
@ -199,7 +201,9 @@ class JsonFormatter(Formatter):
json_record[key] = value

if record.exc_info:
json_record["stacktrace"] = record.exc_text or self.formatException(record.exc_info)
json_record["stacktrace"] = record.exc_text or self.formatException(
record.exc_info
)

return safe_dumps(json_record)

@ -1189,7 +1189,9 @@ def completion_cost( # noqa: PLR0915
and _usage["prompt_tokens_details"] != {}
and _usage["prompt_tokens_details"]
):
prompt_tokens_details = _usage.get("prompt_tokens_details") or {}
prompt_tokens_details = (
_usage.get("prompt_tokens_details") or {}
)
cache_read_input_tokens = prompt_tokens_details.get(
"cached_tokens", 0
)
@ -1515,7 +1517,9 @@ def completion_cost( # noqa: PLR0915
if custom_llm_provider == "azure_ai":
model_for_additional_costs = request_model_for_cost
if completion_response is not None:
hidden_params = getattr(completion_response, "_hidden_params", None) or {}
hidden_params = (
getattr(completion_response, "_hidden_params", None) or {}
)
hidden_model = hidden_params.get("model") or hidden_params.get(
"litellm_model_name"
)
@ -59,17 +59,14 @@ class FocusDestinationFactory:
return {k: v for k, v in resolved.items() if v is not None}
if provider == "vantage":
resolved = {
"api_key": overrides.get("api_key")
or os.getenv("VANTAGE_API_KEY"),
"api_key": overrides.get("api_key") or os.getenv("VANTAGE_API_KEY"),
"integration_token": overrides.get("integration_token")
or os.getenv("VANTAGE_INTEGRATION_TOKEN"),
"base_url": overrides.get("base_url")
or os.getenv("VANTAGE_BASE_URL", "https://api.vantage.sh"),
}
if not resolved.get("api_key"):
raise ValueError(
"VANTAGE_API_KEY must be provided for Vantage exports"
)
raise ValueError("VANTAGE_API_KEY must be provided for Vantage exports")
if not resolved.get("integration_token"):
raise ValueError(
"VANTAGE_INTEGRATION_TOKEN must be provided for Vantage exports"
@ -340,9 +340,9 @@ class LangfusePromptManagement(LangFuseLogger, PromptManagementBase, CustomLogge
)
status_message = str(kwargs.get("exception", "Unknown error"))
if standard_logging_object is not None:
status_message = standard_logging_object.get(
"error_str", None
) or status_message
status_message = (
standard_logging_object.get("error_str", None) or status_message
)
langfuse_logger_to_use.log_event_on_langfuse(
start_time=start_time,
end_time=end_time,
@ -83,7 +83,9 @@ class VantageLogger(FocusLogger):

verbose_logger.debug(
"VantageLogger initialized (integration_token=%s)",
resolved_token[:4] + "***" if resolved_token and len(resolved_token) > 4 else "***",
resolved_token[:4] + "***"
if resolved_token and len(resolved_token) > 4
else "***",
)

async def initialize_focus_export_job(self) -> None:
@ -128,9 +130,7 @@ class VantageLogger(FocusLogger):
callback_type=VantageLogger
)
if not vantage_loggers:
verbose_logger.debug(
"No Vantage logger registered; skipping scheduler"
)
verbose_logger.debug("No Vantage logger registered; skipping scheduler")
return

vantage_logger = cast(VantageLogger, vantage_loggers[0])
@ -26,7 +26,9 @@ if custom_cache_dir:
else:
cache_dir = filename

os.environ["TIKTOKEN_CACHE_DIR"] = cache_dir # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
os.environ[
"TIKTOKEN_CACHE_DIR"
] = cache_dir # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071

import tiktoken
import time
@ -48,4 +50,3 @@ for attempt in range(_max_retries):
# Exponential backoff with jitter to reduce collision probability
delay = _retry_delay * (2**attempt) + random.uniform(0, 0.1)
time.sleep(delay)

@ -352,9 +352,9 @@ class Logging(LiteLLMLoggingBaseClass):
)
self.function_id = function_id
self.streaming_chunks: List[Any] = [] # for generating complete stream response
self.sync_streaming_chunks: List[Any] = (
[]
) # for generating complete stream response
self.sync_streaming_chunks: List[
Any
] = [] # for generating complete stream response
self.log_raw_request_response = log_raw_request_response

# Initialize dynamic callbacks
@ -782,9 +782,9 @@ class Logging(LiteLLMLoggingBaseClass):
prompt_spec=prompt_spec,
dynamic_callback_params=dynamic_callback_params,
):
self.model_call_details["prompt_integration"] = (
logger.__class__.__name__
)
self.model_call_details[
"prompt_integration"
] = logger.__class__.__name__
return logger
except Exception:
# If check fails, continue to next logger
@ -852,9 +852,9 @@ class Logging(LiteLLMLoggingBaseClass):
if anthropic_cache_control_logger := AnthropicCacheControlHook.get_custom_logger_for_anthropic_cache_control_hook(
non_default_params
):
self.model_call_details["prompt_integration"] = (
anthropic_cache_control_logger.__class__.__name__
)
self.model_call_details[
"prompt_integration"
] = anthropic_cache_control_logger.__class__.__name__
return anthropic_cache_control_logger

#########################################################
@ -866,9 +866,9 @@ class Logging(LiteLLMLoggingBaseClass):
internal_usage_cache=None,
llm_router=None,
)
self.model_call_details["prompt_integration"] = (
vector_store_custom_logger.__class__.__name__
)
self.model_call_details[
"prompt_integration"
] = vector_store_custom_logger.__class__.__name__
# Add to global callbacks so post-call hooks are invoked
if (
vector_store_custom_logger
@ -928,9 +928,9 @@ class Logging(LiteLLMLoggingBaseClass):
model
): # if model name was changes pre-call, overwrite the initial model call name with the new one
self.model_call_details["model"] = model
self.model_call_details["litellm_params"]["api_base"] = (
self._get_masked_api_base(additional_args.get("api_base", ""))
)
self.model_call_details["litellm_params"][
"api_base"
] = self._get_masked_api_base(additional_args.get("api_base", ""))

def pre_call(self, input, api_key, model=None, additional_args={}): # noqa: PLR0915
# Log the exact input to the LLM API
@@ -959,10 +959,10 @@ class Logging(LiteLLMLoggingBaseClass):
 try:
 # [Non-blocking Extra Debug Information in metadata]
 if turn_off_message_logging is True:
-_metadata["raw_request"] = (
-"redacted by litellm. \
+_metadata[
+"raw_request"
+] = "redacted by litellm. \
 'litellm.turn_off_message_logging=True'"
-)
 else:
 curl_command = self._get_request_curl_command(
 api_base=additional_args.get("api_base", ""),
@@ -973,34 +973,34 @@ class Logging(LiteLLMLoggingBaseClass):

 _metadata["raw_request"] = str(curl_command)
 # split up, so it's easier to parse in the UI
-self.model_call_details["raw_request_typed_dict"] = (
-RawRequestTypedDict(
-raw_request_api_base=str(
-additional_args.get("api_base") or ""
-),
-raw_request_body=self._get_raw_request_body(
-additional_args.get("complete_input_dict", {})
-),
-# NOTE: setting ignore_sensitive_headers to True will cause
-# the Authorization header to be leaked when calls to the health
-# endpoint are made and fail.
-raw_request_headers=self._get_masked_headers(
-additional_args.get("headers", {}) or {},
-),
-error=None,
-)
+self.model_call_details[
+"raw_request_typed_dict"
+] = RawRequestTypedDict(
+raw_request_api_base=str(
+additional_args.get("api_base") or ""
+),
+raw_request_body=self._get_raw_request_body(
+additional_args.get("complete_input_dict", {})
+),
+# NOTE: setting ignore_sensitive_headers to True will cause
+# the Authorization header to be leaked when calls to the health
+# endpoint are made and fail.
+raw_request_headers=self._get_masked_headers(
+additional_args.get("headers", {}) or {},
+),
+error=None,
 )
 except Exception as e:
-self.model_call_details["raw_request_typed_dict"] = (
-RawRequestTypedDict(
-error=str(e),
-)
+self.model_call_details[
+"raw_request_typed_dict"
+] = RawRequestTypedDict(
+error=str(e),
 )
-_metadata["raw_request"] = (
-"Unable to Log \
+_metadata[
+"raw_request"
+] = "Unable to Log \
 raw request: {}".format(
-str(e)
-)
+str(e)
 )
 if getattr(self, "logger_fn", None) and callable(self.logger_fn):
 try:
@@ -1301,13 +1301,13 @@ class Logging(LiteLLMLoggingBaseClass):
 for callback in callbacks:
 try:
 if isinstance(callback, CustomLogger):
-response: Optional[MCPPostCallResponseObject] = (
-await callback.async_post_mcp_tool_call_hook(
-kwargs=kwargs,
-response_obj=post_mcp_tool_call_response_obj,
-start_time=start_time,
-end_time=end_time,
-)
+response: Optional[
+MCPPostCallResponseObject
+] = await callback.async_post_mcp_tool_call_hook(
+kwargs=kwargs,
+response_obj=post_mcp_tool_call_response_obj,
+start_time=start_time,
+end_time=end_time,
 )
 ######################################################################
 # if any of the callbacks modify the response, use the modified response
@@ -1502,9 +1502,9 @@ class Logging(LiteLLMLoggingBaseClass):
 verbose_logger.debug(
 f"response_cost_failure_debug_information: {debug_info}"
 )
-self.model_call_details["response_cost_failure_debug_information"] = (
-debug_info
-)
+self.model_call_details[
+"response_cost_failure_debug_information"
+] = debug_info
 return None

 try:
@@ -1530,9 +1530,9 @@ class Logging(LiteLLMLoggingBaseClass):
 verbose_logger.debug(
 f"response_cost_failure_debug_information: {debug_info}"
 )
-self.model_call_details["response_cost_failure_debug_information"] = (
-debug_info
-)
+self.model_call_details[
+"response_cost_failure_debug_information"
+] = debug_info

 return None

@@ -1688,9 +1688,9 @@ class Logging(LiteLLMLoggingBaseClass):
 result=logging_result
 )

-self.model_call_details["standard_logging_object"] = (
-self._build_standard_logging_payload(logging_result, start_time, end_time)
-)
+self.model_call_details[
+"standard_logging_object"
+] = self._build_standard_logging_payload(logging_result, start_time, end_time)

 if (
 standard_logging_payload := self.model_call_details.get(
@@ -1768,9 +1768,9 @@ class Logging(LiteLLMLoggingBaseClass):
 end_time = datetime.datetime.now()
 if self.completion_start_time is None:
 self.completion_start_time = end_time
-self.model_call_details["completion_start_time"] = (
-self.completion_start_time
-)
+self.model_call_details[
+"completion_start_time"
+] = self.completion_start_time

 self.model_call_details["log_event_type"] = "successful_api_call"
 self.model_call_details["end_time"] = end_time
@@ -1807,10 +1807,10 @@ class Logging(LiteLLMLoggingBaseClass):
 end_time=end_time,
 )
 elif isinstance(result, dict) or isinstance(result, list):
-self.model_call_details["standard_logging_object"] = (
-self._build_standard_logging_payload(
-result, start_time, end_time
-)
+self.model_call_details[
+"standard_logging_object"
+] = self._build_standard_logging_payload(
+result, start_time, end_time
 )
 if (
 standard_logging_payload := self.model_call_details.get(
@@ -1819,9 +1819,9 @@ class Logging(LiteLLMLoggingBaseClass):
 ) is not None:
 emit_standard_logging_payload(standard_logging_payload)
 elif standard_logging_object is not None:
-self.model_call_details["standard_logging_object"] = (
-standard_logging_object
-)
+self.model_call_details[
+"standard_logging_object"
+] = standard_logging_object
 else:
 self.model_call_details["response_cost"] = None

@@ -1979,17 +1979,17 @@ class Logging(LiteLLMLoggingBaseClass):
 verbose_logger.debug(
 "Logging Details LiteLLM-Success Call streaming complete"
 )
-self.model_call_details["complete_streaming_response"] = (
-complete_streaming_response
-)
-self.model_call_details["response_cost"] = (
-self._response_cost_calculator(result=complete_streaming_response)
-)
+self.model_call_details[
+"complete_streaming_response"
+] = complete_streaming_response
+self.model_call_details[
+"response_cost"
+] = self._response_cost_calculator(result=complete_streaming_response)
 ## STANDARDIZED LOGGING PAYLOAD
-self.model_call_details["standard_logging_object"] = (
-self._build_standard_logging_payload(
-complete_streaming_response, start_time, end_time
-)
+self.model_call_details[
+"standard_logging_object"
+] = self._build_standard_logging_payload(
+complete_streaming_response, start_time, end_time
 )
 if (
 standard_logging_payload := self.model_call_details.get(
@@ -2323,10 +2323,10 @@ class Logging(LiteLLMLoggingBaseClass):
 )
 else:
 if self.stream and complete_streaming_response:
-self.model_call_details["complete_response"] = (
-self.model_call_details.get(
-"complete_streaming_response", {}
-)
+self.model_call_details[
+"complete_response"
+] = self.model_call_details.get(
+"complete_streaming_response", {}
 )
 result = self.model_call_details["complete_response"]
 openMeterLogger.log_success_event(
@@ -2350,10 +2350,10 @@ class Logging(LiteLLMLoggingBaseClass):
 )
 else:
 if self.stream and complete_streaming_response:
-self.model_call_details["complete_response"] = (
-self.model_call_details.get(
-"complete_streaming_response", {}
-)
+self.model_call_details[
+"complete_response"
+] = self.model_call_details.get(
+"complete_streaming_response", {}
 )
 result = self.model_call_details["complete_response"]

@@ -2492,9 +2492,9 @@ class Logging(LiteLLMLoggingBaseClass):
 if complete_streaming_response is not None:
 print_verbose("Async success callbacks: Got a complete streaming response")

-self.model_call_details["async_complete_streaming_response"] = (
-complete_streaming_response
-)
+self.model_call_details[
+"async_complete_streaming_response"
+] = complete_streaming_response

 try:
 if self.model_call_details.get("cache_hit", False) is True:
@@ -2505,10 +2505,10 @@ class Logging(LiteLLMLoggingBaseClass):
 model_call_details=self.model_call_details
 )
 # base_model defaults to None if not set on model_info
-self.model_call_details["response_cost"] = (
-self._response_cost_calculator(
-result=complete_streaming_response
-)
+self.model_call_details[
+"response_cost"
+] = self._response_cost_calculator(
+result=complete_streaming_response
 )

 verbose_logger.debug(
@@ -2521,10 +2521,10 @@ class Logging(LiteLLMLoggingBaseClass):
 self.model_call_details["response_cost"] = None

 ## STANDARDIZED LOGGING PAYLOAD
-self.model_call_details["standard_logging_object"] = (
-self._build_standard_logging_payload(
-complete_streaming_response, start_time, end_time
-)
+self.model_call_details[
+"standard_logging_object"
+] = self._build_standard_logging_payload(
+complete_streaming_response, start_time, end_time
 )

 # print standard logging payload
@@ -2551,9 +2551,9 @@ class Logging(LiteLLMLoggingBaseClass):
 # _success_handler_helper_fn
 if self.model_call_details.get("standard_logging_object") is None:
 ## STANDARDIZED LOGGING PAYLOAD
-self.model_call_details["standard_logging_object"] = (
-self._build_standard_logging_payload(result, start_time, end_time)
-)
+self.model_call_details[
+"standard_logging_object"
+] = self._build_standard_logging_payload(result, start_time, end_time)

 # print standard logging payload
 if (
@@ -2796,18 +2796,18 @@ class Logging(LiteLLMLoggingBaseClass):

 ## STANDARDIZED LOGGING PAYLOAD

-self.model_call_details["standard_logging_object"] = (
-get_standard_logging_object_payload(
-kwargs=self.model_call_details,
-init_response_obj={},
-start_time=start_time,
-end_time=end_time,
-logging_obj=self,
-status="failure",
-error_str=str(exception),
-original_exception=exception,
-standard_built_in_tools_params=self.standard_built_in_tools_params,
-)
+self.model_call_details[
+"standard_logging_object"
+] = get_standard_logging_object_payload(
+kwargs=self.model_call_details,
+init_response_obj={},
+start_time=start_time,
+end_time=end_time,
+logging_obj=self,
+status="failure",
+error_str=str(exception),
+original_exception=exception,
+standard_built_in_tools_params=self.standard_built_in_tools_params,
 )
 return start_time, end_time

@@ -3774,9 +3774,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
 service_name=arize_config.project_name,
 )

-os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
-f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}"
-)
+os.environ[
+"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
+] = f"space_id={arize_config.space_key or arize_config.space_id},api_key={arize_config.api_key}"
 for callback in _in_memory_loggers:
 if (
 isinstance(callback, ArizeLogger)
@@ -3802,13 +3802,13 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
 existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
 # Add openinference.project.name attribute
 if existing_attrs:
-os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
-f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}"
-)
+os.environ[
+"OTEL_RESOURCE_ATTRIBUTES"
+] = f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}"
 else:
-os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
-f"openinference.project.name={arize_phoenix_config.project_name}"
-)
+os.environ[
+"OTEL_RESOURCE_ATTRIBUTES"
+] = f"openinference.project.name={arize_phoenix_config.project_name}"

 # Set Phoenix project name from environment variable
 phoenix_project_name = os.environ.get("PHOENIX_PROJECT_NAME", None)
@@ -3816,19 +3816,19 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
 existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "")
 # Add openinference.project.name attribute
 if existing_attrs:
-os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
-f"{existing_attrs},openinference.project.name={phoenix_project_name}"
-)
+os.environ[
+"OTEL_RESOURCE_ATTRIBUTES"
+] = f"{existing_attrs},openinference.project.name={phoenix_project_name}"
 else:
-os.environ["OTEL_RESOURCE_ATTRIBUTES"] = (
-f"openinference.project.name={phoenix_project_name}"
-)
+os.environ[
+"OTEL_RESOURCE_ATTRIBUTES"
+] = f"openinference.project.name={phoenix_project_name}"

 # auth can be disabled on local deployments of arize phoenix
 if arize_phoenix_config.otlp_auth_headers is not None:
-os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
-arize_phoenix_config.otlp_auth_headers
-)
+os.environ[
+"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
+] = arize_phoenix_config.otlp_auth_headers

 for callback in _in_memory_loggers:
 if (
@@ -3907,7 +3907,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
 from litellm.integrations.focus.focus_logger import FocusLogger

 for callback in _in_memory_loggers:
-if type(callback) is FocusLogger: # exact match; exclude subclasses like VantageLogger
+if (
+type(callback) is FocusLogger
+): # exact match; exclude subclasses like VantageLogger
 return callback # type: ignore
 focus_logger = FocusLogger()
 _in_memory_loggers.append(focus_logger)
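The `type(callback) is FocusLogger` check in the hunk above is intentionally stricter than `isinstance`: per its inline comment, subclasses such as `VantageLogger` must not be returned. A minimal sketch of the difference, using stand-in classes and a hypothetical `find_exact` helper (neither is LiteLLM code):

```python
class FocusLogger:
    """Stand-in for the real FocusLogger; hypothetical, for illustration only."""


class VantageLogger(FocusLogger):
    """Stand-in subclass, mirroring the relationship named in the diff comment."""


def find_exact(loggers, cls):
    """Return the first logger whose type is exactly cls, ignoring subclasses."""
    for callback in loggers:
        if type(callback) is cls:
            return callback
    return None


vantage = VantageLogger()
focus = FocusLogger()
in_memory = [vantage, focus]

# Exact-type lookup skips the subclass instance that sits first in the list.
exact = find_exact(in_memory, FocusLogger)

# An isinstance-based lookup would pick up the subclass instance instead.
loose = next(cb for cb in in_memory if isinstance(cb, FocusLogger))
```

With `isinstance`, a previously registered subclass instance would be handed back in place of a plain `FocusLogger`; the exact-type comparison avoids that.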
@@ -4013,9 +4015,9 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915
 exporter="otlp_http",
 endpoint="https://langtrace.ai/api/trace",
 )
-os.environ["OTEL_EXPORTER_OTLP_TRACES_HEADERS"] = (
-f"api_key={os.getenv('LANGTRACE_API_KEY')}"
-)
+os.environ[
+"OTEL_EXPORTER_OTLP_TRACES_HEADERS"
+] = f"api_key={os.getenv('LANGTRACE_API_KEY')}"
 for callback in _in_memory_loggers:
 if (
 isinstance(callback, OpenTelemetry)
@@ -4289,7 +4291,9 @@ def get_custom_logger_compatible_class( # noqa: PLR0915
 from litellm.integrations.focus.focus_logger import FocusLogger

 for callback in _in_memory_loggers:
-if type(callback) is FocusLogger: # exact match; exclude subclasses like VantageLogger
+if (
+type(callback) is FocusLogger
+): # exact match; exclude subclasses like VantageLogger
 return callback
 elif logging_integration == "vantage":
 from litellm.integrations.vantage.vantage_logger import VantageLogger
@@ -4937,10 +4941,10 @@ class StandardLoggingPayloadSetup:
 for key in StandardLoggingHiddenParams.__annotations__.keys():
 if key in hidden_params:
 if key == "additional_headers":
-clean_hidden_params["additional_headers"] = (
-StandardLoggingPayloadSetup.get_additional_headers(
-hidden_params[key]
-)
+clean_hidden_params[
+"additional_headers"
+] = StandardLoggingPayloadSetup.get_additional_headers(
+hidden_params[key]
 )
 else:
 clean_hidden_params[key] = hidden_params[key] # type: ignore
@@ -5579,9 +5583,9 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
 ):
 for k, v in metadata["user_api_key_metadata"].items():
 if k == "logging": # prevent logging user logging keys
-cleaned_user_api_key_metadata[k] = (
-"scrubbed_by_litellm_for_sensitive_keys"
-)
+cleaned_user_api_key_metadata[
+k
+] = "scrubbed_by_litellm_for_sensitive_keys"
 else:
 cleaned_user_api_key_metadata[k] = v

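The loop touched by the hunk above copies `user_api_key_metadata` key by key and replaces the value stored under `"logging"` with a fixed placeholder. A minimal standalone sketch of that behavior, assuming a plain dict as input; the function name `scrub_user_api_key_metadata` and the constant `SCRUBBED` are hypothetical and only the placeholder string comes from the diff:

```python
SCRUBBED = "scrubbed_by_litellm_for_sensitive_keys"  # placeholder string from the diff


def scrub_user_api_key_metadata(user_api_key_metadata: dict) -> dict:
    """Copy the metadata, masking the value stored under the "logging" key."""
    cleaned = {}
    for k, v in user_api_key_metadata.items():
        if k == "logging":  # user logging settings may carry credentials
            cleaned[k] = SCRUBBED
        else:
            cleaned[k] = v
    return cleaned


example = {"team": "search", "logging": [{"callback_name": "langfuse"}]}
cleaned_example = scrub_user_api_key_metadata(example)
```

The input dict is left unmodified; only the returned copy carries the placeholder.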
@@ -2442,7 +2442,9 @@ def anthropic_messages_pt( # noqa: PLR0915
 _document_content_element = cast(
 AnthropicMessagesDocumentParam,
 add_cache_control_to_content(
-anthropic_content_element=cast(AnthropicMessagesDocumentParam, m),
+anthropic_content_element=cast(
+AnthropicMessagesDocumentParam, m
+),
 original_content_element=dict(m),
 ),
 )
@@ -2454,10 +2456,18 @@ def anthropic_messages_pt( # noqa: PLR0915
 )
 )
 _file_content_element = add_cache_control_to_content(
-anthropic_content_element=cast(AnthropicMessagesDocumentParam, _file_content_element),
+anthropic_content_element=cast(
+AnthropicMessagesDocumentParam,
+_file_content_element,
+),
 original_content_element=dict(m),
 )
-user_content.append(cast(AnthropicMessagesDocumentParam,_file_content_element))
+user_content.append(
+cast(
+AnthropicMessagesDocumentParam,
+_file_content_element,
+)
+)
 elif isinstance(user_message_types_block["content"], str):
 _anthropic_content_text_element: AnthropicMessagesTextParam = {
 "type": "text",

@@ -780,7 +780,7 @@ class LiteLLMAnthropicMessagesAdapter:
 # Keep Anthropic-native tools in their original format
 new_tools.append(tool) # type: ignore[arg-type]
 continue
-
+
 original_name = tool["name"]
 truncated_name = truncate_tool_name(original_name)

@@ -336,9 +336,7 @@ class BaseVideoConfig(ABC):
 Returns:
 Tuple[str, Dict]: (url, data) for the POST request
 """
-raise NotImplementedError(
-"video edit is not supported for this provider"
-)
+raise NotImplementedError("video edit is not supported for this provider")

 def transform_video_edit_response(
 self,
@@ -346,9 +344,7 @@ class BaseVideoConfig(ABC):
 logging_obj: LiteLLMLoggingObj,
 custom_llm_provider: Optional[str] = None,
 ) -> VideoObject:
-raise NotImplementedError(
-"video edit is not supported for this provider"
-)
+raise NotImplementedError("video edit is not supported for this provider")

 def transform_video_extension_request(
 self,
@@ -366,9 +362,7 @@ class BaseVideoConfig(ABC):
 Returns:
 Tuple[str, Dict]: (url, data) for the POST request
 """
-raise NotImplementedError(
-"video extension is not supported for this provider"
-)
+raise NotImplementedError("video extension is not supported for this provider")

 def transform_video_extension_response(
 self,
@@ -376,9 +370,7 @@ class BaseVideoConfig(ABC):
 logging_obj: LiteLLMLoggingObj,
 custom_llm_provider: Optional[str] = None,
 ) -> VideoObject:
-raise NotImplementedError(
-"video extension is not supported for this provider"
-)
+raise NotImplementedError("video extension is not supported for this provider")

 def get_error_class(
 self, error_message: str, status_code: int, headers: Union[dict, httpx.Headers]

@@ -6162,7 +6162,10 @@ class BaseLLMHTTPHandler:
 litellm_params=dict(litellm_params),
 )

-url, files_list = video_provider_config.transform_video_create_character_request(
+(
+url,
+files_list,
+) = video_provider_config.transform_video_create_character_request(
 name=name,
 video=video,
 api_base=api_base,
@@ -6230,7 +6233,10 @@ class BaseLLMHTTPHandler:
 litellm_params=dict(litellm_params),
 )

-url, files_list = video_provider_config.transform_video_create_character_request(
+(
+url,
+files_list,
+) = video_provider_config.transform_video_create_character_request(
 name=name,
 video=video,
 api_base=api_base,
@@ -6324,11 +6330,7 @@ class BaseLLMHTTPHandler:
 )

 try:
-response = sync_httpx_client.get(
-url=url,
-headers=headers,
-params=params
-)
+response = sync_httpx_client.get(url=url, headers=headers, params=params)
 response.raise_for_status()
 return video_provider_config.transform_video_get_character_response(
 raw_response=response,
@@ -6386,9 +6388,7 @@ class BaseLLMHTTPHandler:

 try:
 response = await async_httpx_client.get(
-url=url,
-headers=headers,
-params=params
+url=url, headers=headers, params=params
 )
 response.raise_for_status()
 return video_provider_config.transform_video_get_character_response(

@@ -525,28 +525,47 @@ class GeminiVideoConfig(BaseVideoConfig):
 """Video delete is not supported."""
 raise NotImplementedError("Video delete is not supported by Google Veo.")

-def transform_video_create_character_request(self, name, video, api_base, litellm_params, headers):
+def transform_video_create_character_request(
+self, name, video, api_base, litellm_params, headers
+):
 raise NotImplementedError("video create character is not supported for Gemini")

 def transform_video_create_character_response(self, raw_response, logging_obj):
 raise NotImplementedError("video create character is not supported for Gemini")

-def transform_video_get_character_request(self, character_id, api_base, litellm_params, headers):
+def transform_video_get_character_request(
+self, character_id, api_base, litellm_params, headers
+):
 raise NotImplementedError("video get character is not supported for Gemini")

 def transform_video_get_character_response(self, raw_response, logging_obj):
 raise NotImplementedError("video get character is not supported for Gemini")

-def transform_video_edit_request(self, prompt, video_id, api_base, litellm_params, headers, extra_body=None):
+def transform_video_edit_request(
+self, prompt, video_id, api_base, litellm_params, headers, extra_body=None
+):
 raise NotImplementedError("video edit is not supported for Gemini")

-def transform_video_edit_response(self, raw_response, logging_obj, custom_llm_provider=None):
+def transform_video_edit_response(
+self, raw_response, logging_obj, custom_llm_provider=None
+):
 raise NotImplementedError("video edit is not supported for Gemini")

-def transform_video_extension_request(self, prompt, video_id, seconds, api_base, litellm_params, headers, extra_body=None):
+def transform_video_extension_request(
+self,
+prompt,
+video_id,
+seconds,
+api_base,
+litellm_params,
+headers,
+extra_body=None,
+):
 raise NotImplementedError("video extension is not supported for Gemini")

-def transform_video_extension_response(self, raw_response, logging_obj, custom_llm_provider=None):
+def transform_video_extension_response(
+self, raw_response, logging_obj, custom_llm_provider=None
+):
 raise NotImplementedError("video extension is not supported for Gemini")

 def get_error_class(

@@ -19,7 +19,8 @@ class MoonshotChatConfig(OpenAIGPTConfig):
 @overload
 def _transform_messages(
 self, messages: List[AllMessageValues], model: str, is_async: Literal[True]
-) -> Coroutine[Any, Any, List[AllMessageValues]]: ...
+) -> Coroutine[Any, Any, List[AllMessageValues]]:
+...

 @overload
 def _transform_messages(
@@ -27,7 +28,8 @@ class MoonshotChatConfig(OpenAIGPTConfig):
 messages: List[AllMessageValues],
 model: str,
 is_async: Literal[False] = False,
-) -> List[AllMessageValues]: ...
+) -> List[AllMessageValues]:
+...

 def _transform_messages(
 self, messages: List[AllMessageValues], model: str, is_async: bool = False
@@ -53,9 +55,13 @@ class MoonshotChatConfig(OpenAIGPTConfig):
 messages = handle_messages_with_content_list_to_str_conversion(messages)

 if is_async:
-return super()._transform_messages(messages=messages, model=model, is_async=True)
+return super()._transform_messages(
+messages=messages, model=model, is_async=True
+)
 else:
-return super()._transform_messages(messages=messages, model=model, is_async=False)
+return super()._transform_messages(
+messages=messages, model=model, is_async=False
+)

 def _get_openai_compatible_provider_info(
 self, api_base: Optional[str], api_key: Optional[str]
@@ -141,7 +147,9 @@ class MoonshotChatConfig(OpenAIGPTConfig):
 optional_params["temperature"] = 0.3
 return optional_params

-def fill_reasoning_content(self, messages: List[AllMessageValues]) -> List[AllMessageValues]:
+def fill_reasoning_content(
+self, messages: List[AllMessageValues]
+) -> List[AllMessageValues]:
 """
 Moonshot reasoning models require `reasoning_content` on every assistant
 message that contains tool_calls (multi-turn tool-calling flows).

@@ -592,28 +592,51 @@ class RunwayMLVideoConfig(BaseVideoConfig):

 return video_obj

-def transform_video_create_character_request(self, name, video, api_base, litellm_params, headers):
-raise NotImplementedError("video create character is not supported for RunwayML")
+def transform_video_create_character_request(
+self, name, video, api_base, litellm_params, headers
+):
+raise NotImplementedError(
+"video create character is not supported for RunwayML"
+)

 def transform_video_create_character_response(self, raw_response, logging_obj):
-raise NotImplementedError("video create character is not supported for RunwayML")
+raise NotImplementedError(
+"video create character is not supported for RunwayML"
+)

-def transform_video_get_character_request(self, character_id, api_base, litellm_params, headers):
+def transform_video_get_character_request(
+self, character_id, api_base, litellm_params, headers
+):
 raise NotImplementedError("video get character is not supported for RunwayML")

 def transform_video_get_character_response(self, raw_response, logging_obj):
 raise NotImplementedError("video get character is not supported for RunwayML")

-def transform_video_edit_request(self, prompt, video_id, api_base, litellm_params, headers, extra_body=None):
+def transform_video_edit_request(
+self, prompt, video_id, api_base, litellm_params, headers, extra_body=None
+):
 raise NotImplementedError("video edit is not supported for RunwayML")

-def transform_video_edit_response(self, raw_response, logging_obj, custom_llm_provider=None):
+def transform_video_edit_response(
+self, raw_response, logging_obj, custom_llm_provider=None
+):
 raise NotImplementedError("video edit is not supported for RunwayML")

-def transform_video_extension_request(self, prompt, video_id, seconds, api_base, litellm_params, headers, extra_body=None):
+def transform_video_extension_request(
+self,
+prompt,
+video_id,
+seconds,
+api_base,
+litellm_params,
+headers,
+extra_body=None,
+):
 raise NotImplementedError("video extension is not supported for RunwayML")

-def transform_video_extension_response(self, raw_response, logging_obj, custom_llm_provider=None):
+def transform_video_extension_response(
+self, raw_response, logging_obj, custom_llm_provider=None
+):
 raise NotImplementedError("video extension is not supported for RunwayML")

 def get_error_class(

@@ -184,9 +184,7 @@ class SagemakerChatConfig(OpenAIGPTConfig, BaseAWSLLM):
 llm_provider = LlmProviders(custom_llm_provider)
 except ValueError:
 llm_provider = LlmProviders.SAGEMAKER_CHAT
-client = get_async_httpx_client(
-llm_provider=llm_provider, params={}
-)
+client = get_async_httpx_client(llm_provider=llm_provider, params={})

 try:
 response = await client.post(

@@ -142,8 +142,8 @@ class VertexAIBatchTransformation:
 Gets the output file id from the Vertex AI Batch response
 """

-output_file_id: str = (
-response.get("outputInfo", OutputInfo()).get("gcsOutputDirectory", "")
+output_file_id: str = response.get("outputInfo", OutputInfo()).get(
+"gcsOutputDirectory", ""
 )
 if output_file_id:
 output_file_id = output_file_id.rstrip("/") + "/predictions.jsonl"

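The Vertex AI batch hunk above derives the output file id by stripping any trailing slash from `gcsOutputDirectory` and appending `/predictions.jsonl`, leaving an empty value untouched. A minimal sketch of that string handling in isolation; the helper name `predictions_path` and the bucket paths are hypothetical:

```python
def predictions_path(gcs_output_directory: str) -> str:
    """Append /predictions.jsonl to a GCS directory, tolerating trailing slashes."""
    if not gcs_output_directory:
        # Mirrors the `if output_file_id:` guard: nothing to build from.
        return ""
    return gcs_output_directory.rstrip("/") + "/predictions.jsonl"


with_slash = predictions_path("gs://example-bucket/batch-out/")
without_slash = predictions_path("gs://example-bucket/batch-out")
```

Both inputs yield the same path, which is the point of the `rstrip("/")` call.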
@@ -624,28 +624,51 @@ class VertexAIVideoConfig(BaseVideoConfig, VertexBase):
 """Video delete is not supported."""
 raise NotImplementedError("Video delete is not supported by Vertex AI Veo.")

-def transform_video_create_character_request(self, name, video, api_base, litellm_params, headers):
-raise NotImplementedError("video create character is not supported for Vertex AI")
+def transform_video_create_character_request(
+self, name, video, api_base, litellm_params, headers
+):
+raise NotImplementedError(
+"video create character is not supported for Vertex AI"
+)

 def transform_video_create_character_response(self, raw_response, logging_obj):
-raise NotImplementedError("video create character is not supported for Vertex AI")
+raise NotImplementedError(
+"video create character is not supported for Vertex AI"
+)

-def transform_video_get_character_request(self, character_id, api_base, litellm_params, headers):
+def transform_video_get_character_request(
+self, character_id, api_base, litellm_params, headers
+):
 raise NotImplementedError("video get character is not supported for Vertex AI")

 def transform_video_get_character_response(self, raw_response, logging_obj):
 raise NotImplementedError("video get character is not supported for Vertex AI")

-def transform_video_edit_request(self, prompt, video_id, api_base, litellm_params, headers, extra_body=None):
+def transform_video_edit_request(
+self, prompt, video_id, api_base, litellm_params, headers, extra_body=None
+):
 raise NotImplementedError("video edit is not supported for Vertex AI")

-def transform_video_edit_response(self, raw_response, logging_obj, custom_llm_provider=None):
+def transform_video_edit_response(
+self, raw_response, logging_obj, custom_llm_provider=None
+):
 raise NotImplementedError("video edit is not supported for Vertex AI")

-def transform_video_extension_request(self, prompt, video_id, seconds, api_base, litellm_params, headers, extra_body=None):
+def transform_video_extension_request(
+self,
+prompt,
+video_id,
+seconds,
+api_base,
+litellm_params,
+headers,
+extra_body=None,
+):
 raise NotImplementedError("video extension is not supported for Vertex AI")

-def transform_video_extension_response(self, raw_response, logging_obj, custom_llm_provider=None):
+def transform_video_extension_response(
+self, raw_response, logging_obj, custom_llm_provider=None
+):
 raise NotImplementedError("video extension is not supported for Vertex AI")

 def get_error_class(

@@ -7533,9 +7533,7 @@ def stream_chunk_builder(  # noqa: PLR0915
             # the final chunk.
             all_annotations: list = []
             for ac in annotation_chunks:
-                all_annotations.extend(
-                    ac["choices"][0]["delta"]["annotations"]
-                )
+                all_annotations.extend(ac["choices"][0]["delta"]["annotations"])
             response["choices"][0]["message"]["annotations"] = all_annotations

         audio_chunks = [

@@ -32354,6 +32354,53 @@
         "supports_vision": true,
         "supports_web_search": true
     },
+    "xai/grok-4.20-multi-agent-beta-0309": {
+        "cache_read_input_token_cost": 2e-07,
+        "input_cost_per_token": 2e-06,
+        "litellm_provider": "xai",
+        "max_input_tokens": 2000000,
+        "max_output_tokens": 2000000,
+        "max_tokens": 2000000,
+        "mode": "chat",
+        "output_cost_per_token": 6e-06,
+        "source": "https://docs.x.ai/docs/models",
+        "supports_function_calling": true,
+        "supports_reasoning": true,
+        "supports_tool_choice": true,
+        "supports_vision": true,
+        "supports_web_search": true
+    },
+    "xai/grok-4.20-beta-0309-reasoning": {
+        "cache_read_input_token_cost": 2e-07,
+        "input_cost_per_token": 2e-06,
+        "litellm_provider": "xai",
+        "max_input_tokens": 2000000,
+        "max_output_tokens": 2000000,
+        "max_tokens": 2000000,
+        "mode": "chat",
+        "output_cost_per_token": 6e-06,
+        "source": "https://docs.x.ai/docs/models",
+        "supports_function_calling": true,
+        "supports_reasoning": true,
+        "supports_tool_choice": true,
+        "supports_vision": true,
+        "supports_web_search": true
+    },
+    "xai/grok-4.20-beta-0309-non-reasoning": {
+        "cache_read_input_token_cost": 2e-07,
+        "input_cost_per_token": 2e-06,
+        "litellm_provider": "xai",
+        "max_input_tokens": 2000000,
+        "max_output_tokens": 2000000,
+        "max_tokens": 2000000,
+        "mode": "chat",
+        "output_cost_per_token": 6e-06,
+        "source": "https://docs.x.ai/docs/models",
+        "supports_function_calling": true,
+        "supports_tool_choice": true,
+        "supports_vision": true,
+        "supports_web_search": true
+    },
     "xai/grok-beta": {
         "input_cost_per_token": 5e-06,
         "litellm_provider": "xai",

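To make the per-token prices in the grok-4.20 entries concrete, here is a minimal sketch of how a request cost could be estimated from them. The helper `estimate_cost` is hypothetical and is not a LiteLLM API; it also assumes that cached tokens are a subset of the prompt tokens and are billed at the cache-read rate, which this diff does not confirm.

```python
# Per-token prices copied from the grok-4.20 beta entries in the hunk above.
GROK_4_20_PRICING = {
    "input_cost_per_token": 2e-06,
    "cache_read_input_token_cost": 2e-07,
    "output_cost_per_token": 6e-06,
}


def estimate_cost(
    prompt_tokens: int, completion_tokens: int, cached_tokens: int = 0
) -> float:
    """Hypothetical helper (not a LiteLLM API): estimate a request cost in USD.

    Assumption: cached tokens are part of prompt_tokens and are billed at the
    cache-read rate instead of the regular input rate.
    """
    if not 0 <= cached_tokens <= prompt_tokens:
        raise ValueError("cached_tokens must be between 0 and prompt_tokens")
    uncached_tokens = prompt_tokens - cached_tokens
    return (
        uncached_tokens * GROK_4_20_PRICING["input_cost_per_token"]
        + cached_tokens * GROK_4_20_PRICING["cache_read_input_token_cost"]
        + completion_tokens * GROK_4_20_PRICING["output_cost_per_token"]
    )
```

Under these assumptions, one million uncached input tokens plus one million output tokens come to roughly $2.00 + $6.00, matching the pricing listed in the commit message.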
@@ -677,7 +677,60 @@ async def oauth_authorization_server_mcp(
 # Alias for standard OpenID discovery
 @router.get("/.well-known/openid-configuration")
 async def openid_configuration(request: Request):
-    return await oauth_authorization_server_mcp(request)
+    response = await oauth_authorization_server_mcp(request)
+
+    # If MCPJWTSigner is active, augment the discovery doc with JWKS fields so
+    # MCP servers and gateways (e.g. AWS Bedrock AgentCore Gateway) can resolve
+    # the signing keys and verify liteLLM-issued tokens.
+    try:
+        from litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer import (
+            get_mcp_jwt_signer,
+        )
+
+        signer = get_mcp_jwt_signer()
+        if signer is not None:
+            request_base_url = get_request_base_url(request)
+            if isinstance(response, dict):
+                response = {
+                    **response,
+                    "jwks_uri": f"{request_base_url}/.well-known/jwks.json",
+                    "id_token_signing_alg_values_supported": ["RS256"],
+                }
+    except ImportError:
+        pass
+
+    return response
+
+
+@router.get("/.well-known/jwks.json")
+async def jwks_json(request: Request):
+    """
+    JSON Web Key Set endpoint.
+
+    Returns the RSA public key used by MCPJWTSigner to sign outbound MCP tokens.
+    MCP servers and gateways use this endpoint to verify liteLLM-issued JWTs.
+
+    Returns an empty key set if MCPJWTSigner is not configured.
+    """
+    try:
+        from litellm.proxy.guardrails.guardrail_hooks.mcp_jwt_signer.mcp_jwt_signer import (
+            get_mcp_jwt_signer,
+        )
+
+        signer = get_mcp_jwt_signer()
+        if signer is not None:
+            return JSONResponse(
+                content=signer.get_jwks(),
+                headers={"Cache-Control": f"public, max-age={signer.jwks_max_age}"},
+            )
+    except ImportError:
+        pass
+
+    # No signer active — return empty key set; short cache so activation is picked up quickly.
+    return JSONResponse(
+        content={"keys": []},
+        headers={"Cache-Control": "public, max-age=60"},
+    )


 # Additional legacy pattern support

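The discovery-document change above can be read as a small pure transformation. The sketch below isolates it under stated assumptions: `augment_discovery_doc` and its `signer_active` flag are hypothetical names standing in for the `get_mcp_jwt_signer()` check, and the base URL is passed in rather than derived from a request.

```python
from typing import Any


def augment_discovery_doc(doc: Any, base_url: str, signer_active: bool) -> Any:
    """Hypothetical stand-in for the logic added to openid_configuration().

    When a signer is active and the document is a dict, return a copy that
    also advertises the JWKS location and the RS256 signing algorithm.
    Anything else is passed through untouched.
    """
    if not signer_active or not isinstance(doc, dict):
        return doc
    return {
        **doc,
        "jwks_uri": f"{base_url}/.well-known/jwks.json",
        "id_token_signing_alg_values_supported": ["RS256"],
    }
```

Building a new dict instead of mutating the input mirrors the hunk, which rebinds `response` rather than updating it in place.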
@@ -1908,7 +1908,15 @@ class MCPServerManager:
         user_api_key_auth: Optional[UserAPIKeyAuth],
         proxy_logging_obj: ProxyLogging,
         server: MCPServer,
-    ):
+        raw_headers: Optional[Dict[str, str]] = None,
+    ) -> Dict[str, Any]:
+        """
+        Run pre-call checks and guardrail hooks for an MCP tool call.
+
+        Returns a dict that may contain:
+        - "arguments": hook-modified tool arguments (only if changed)
+        - "extra_headers": headers injected by pre_mcp_call guardrail hooks
+        """
         ## check if the tool is allowed or banned for the given server
         if not self.check_allowed_or_banned_tools(name, server):
             raise HTTPException(
@@ -1932,6 +1940,14 @@ class MCPServerManager:
             server=server,
         )

+        # Extract incoming Bearer token from raw request headers so
+        # guardrails like MCPJWTSigner can verify + re-sign it (FR-5).
+        normalized_raw = {k.lower(): v for k, v in (raw_headers or {}).items()}
+        incoming_bearer_token: Optional[str] = None
+        auth_hdr = normalized_raw.get("authorization", "")
+        if auth_hdr.lower().startswith("bearer "):
+            incoming_bearer_token = auth_hdr[len("bearer ") :]
+
         pre_hook_kwargs = {
             "name": name,
             "arguments": arguments,
@@ -1957,6 +1973,7 @@ class MCPServerManager:
                 if user_api_key_auth
                 else None
             ),
+            "incoming_bearer_token": incoming_bearer_token,
         }

         # Create MCP request object for processing
@@ -1969,6 +1986,7 @@ class MCPServerManager:
             mcp_request_obj, pre_hook_kwargs
         )

+        hook_result: Dict[str, Any] = {}
         try:
             # Use standard pre_call_hook
             modified_data = await proxy_logging_obj.pre_call_hook(
@@ -1984,7 +2002,9 @@ class MCPServerManager:
                     )
                 )
                 if modified_kwargs.get("arguments") != arguments:
-                    arguments = modified_kwargs["arguments"]
+                    hook_result["arguments"] = modified_kwargs["arguments"]
+                if modified_kwargs.get("extra_headers"):
+                    hook_result["extra_headers"] = modified_kwargs["extra_headers"]

         except (
             BlockedPiiEntityError,
@@ -1995,6 +2015,8 @@ class MCPServerManager:
             verbose_logger.error(f"Guardrail blocked MCP tool call pre call: {str(e)}")
             raise e

+        return hook_result
+
     def _create_during_hook_task(
         self,
         name: str,
@@ -2047,6 +2069,7 @@ class MCPServerManager:
         raw_headers: Optional[Dict[str, str]],
         proxy_logging_obj: Optional[ProxyLogging],
         host_progress_callback: Optional[Callable] = None,
+        hook_extra_headers: Optional[Dict[str, str]] = None,
     ) -> CallToolResult:
         """
         Call a regular MCP tool using the MCP client.
@@ -2061,6 +2084,9 @@ class MCPServerManager:
             oauth2_headers: Optional OAuth2 headers
             raw_headers: Optional raw headers from the request
             proxy_logging_obj: Optional ProxyLogging object for hook integration
+            host_progress_callback: Optional callback for progress updates
+            hook_extra_headers: Optional headers injected by pre_mcp_call guardrail
+                hooks. Merged last (highest priority) into outbound request headers.

         Returns:
             CallToolResult from the MCP server
@@ -2116,6 +2142,31 @@ class MCPServerManager:
                 extra_headers = {}
             extra_headers.update(mcp_server.static_headers)

+        if hook_extra_headers:
+            if extra_headers is None:
+                extra_headers = {}
+            if "Authorization" in hook_extra_headers:
+                if "Authorization" in extra_headers:
+                    verbose_logger.warning(
+                        "MCPServerManager: hook_extra_headers 'Authorization' will overwrite "
+                        "the existing Authorization header from static_headers. "
+                        "The hook JWT will take precedence."
+                    )
+                elif server_auth_header is not None:
+                    # server_auth_header is passed separately to _create_mcp_client as
+                    # auth_value. Both will reach the upstream server — warn so admins
+                    # know two Authorization credentials are being sent.
+                    verbose_logger.warning(
+                        "MCPServerManager: hook_extra_headers injects 'Authorization' while "
+                        "server '%s' already has a configured authentication_token. "
+                        "Both credentials will be sent; the hook header is in extra_headers "
+                        "and the server token is in auth_value — the upstream server decides "
+                        "which one wins. Consider unsetting authentication_token if you want "
+                        "the hook JWT to be the sole credential.",
+                        mcp_server.server_name or mcp_server.name,
+                    )
+            extra_headers.update(hook_extra_headers)
+
         stdio_env = self._build_stdio_env(mcp_server, raw_headers)

         client = await self._create_mcp_client(
@@ -2201,15 +2252,19 @@ class MCPServerManager:
         # Allow validation and modification of tool calls before execution
         # Using standard pre_call_hook
         #########################################################
+        hook_result: Dict[str, Any] = {}
         if proxy_logging_obj:
-            await self.pre_call_tool_check(
+            hook_result = await self.pre_call_tool_check(
                 name=name,
                 arguments=arguments,
                 server_name=server_name,
                 user_api_key_auth=user_api_key_auth,
                 proxy_logging_obj=proxy_logging_obj,
                 server=mcp_server,
+                raw_headers=raw_headers,
             )
+            if "arguments" in hook_result:
+                arguments = hook_result["arguments"]

         # Prepare tasks for during hooks
         tasks = []
@@ -2227,8 +2282,16 @@ class MCPServerManager:
         # For OpenAPI servers, call the tool handler directly instead of via MCP client
         if mcp_server.spec_path:
             verbose_logger.debug(
-                f"Calling OpenAPI tool {name} directly via HTTP handler"
+                "Calling OpenAPI tool %s directly via HTTP handler", name
             )
+            if hook_result.get("extra_headers"):
+                verbose_logger.warning(
+                    "pre_mcp_call hook returned extra_headers for OpenAPI-backed "
+                    "MCP server '%s' — header injection is not supported for "
+                    "OpenAPI servers; headers will be ignored. Use SSE/HTTP "
+                    "transport to enable hook header injection.",
+                    server_name,
+                )
             tasks.append(
                 asyncio.create_task(
                     self._call_openapi_tool_handler(mcp_server, name, arguments)
@@ -2247,6 +2310,7 @@ class MCPServerManager:
                 raw_headers=raw_headers,
                 proxy_logging_obj=proxy_logging_obj,
                 host_progress_callback=host_progress_callback,
+                hook_extra_headers=hook_result.get("extra_headers"),
             )

         # For OpenAPI tools, await outside the client context

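Two small pieces of the MCPServerManager change are easy to check in isolation: the case-insensitive Bearer token extraction and the rule that hook headers are merged last. The sketch below reproduces both as standalone functions; the names `extract_bearer_token` and `merge_outbound_headers` are hypothetical and do not exist in the PR.

```python
from typing import Dict, Optional


def extract_bearer_token(raw_headers: Optional[Dict[str, str]]) -> Optional[str]:
    """Return the token from an Authorization: Bearer header, if present.

    Header names and the "bearer" scheme are matched case-insensitively,
    as in the hunk above.
    """
    normalized = {k.lower(): v for k, v in (raw_headers or {}).items()}
    auth_hdr = normalized.get("authorization", "")
    if auth_hdr.lower().startswith("bearer "):
        return auth_hdr[len("bearer "):]
    return None


def merge_outbound_headers(
    static_headers: Optional[Dict[str, str]],
    hook_extra_headers: Optional[Dict[str, str]],
) -> Dict[str, str]:
    """Merge hook headers last so they override static headers on key clashes."""
    merged = dict(static_headers or {})
    merged.update(hook_extra_headers or {})
    return merged
```

Note that `dict.update` compares keys case-sensitively, so in this sketch `authorization` and `Authorization` would be kept as two separate entries; the hunk above checks only the capitalized `Authorization` spelling before warning.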
@@ -903,12 +903,12 @@ if MCP_AVAILABLE:
         try:
             client_id, client_secret, scopes = _extract_credentials(request)

-            _oauth2_flow: Optional[Literal["client_credentials", "authorization_code"]] = (
-                request.oauth2_flow or (
-                    "client_credentials"
-                    if client_id and client_secret and request.token_url
-                    else None
-                )
+            _oauth2_flow: Optional[
+                Literal["client_credentials", "authorization_code"]
+            ] = request.oauth2_flow or (
+                "client_credentials"
+                if client_id and client_secret and request.token_url
+                else None
             )
             # client_credentials requires token_url to fetch a token; without it the
             # incoming auth header would be dropped with nothing to replace it.

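The flow-selection expression in this hunk is unchanged in behavior; only its formatting moved. As a rough illustration of what it evaluates to, here is a standalone sketch in which `infer_oauth2_flow` is a hypothetical function name and plain arguments stand in for the `request` fields.

```python
from typing import Optional


def infer_oauth2_flow(
    oauth2_flow: Optional[str],
    client_id: Optional[str],
    client_secret: Optional[str],
    token_url: Optional[str],
) -> Optional[str]:
    """Mirror the `request.oauth2_flow or (...)` expression from the hunk.

    An explicit flow always wins. Otherwise "client_credentials" is inferred
    only when client_id, client_secret and token_url are all set.
    """
    return oauth2_flow or (
        "client_credentials" if client_id and client_secret and token_url else None
    )
```

Leaving out `token_url` yields `None`, which matches the comment in the hunk: without a token URL there is nothing to fetch a replacement token from.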
@@ -2471,6 +2471,9 @@ class UserAPIKeyAuth(
         Any
     ] = None  # Expanded created_by user when expand=user is used
     end_user_object_permission: Optional[LiteLLM_ObjectPermissionTable] = None
+    # Decoded upstream IdP claims (groups, roles, etc.) propagated by JWT auth machinery
+    # and forwarded into outbound tokens by guardrails such as MCPJWTSigner.
+    jwt_claims: Optional[Dict] = None

     model_config = ConfigDict(arbitrary_types_allowed=True)


@@ -680,7 +680,7 @@ def get_customer_user_header_from_mapping(user_id_mapping) -> Optional[list]:

     if customer_headers_mappings:
         return customer_headers_mappings
-
+
     return None


@@ -754,15 +754,11 @@ def get_end_user_id_from_request_body(
                         user_id_str = str(header_value)
                         if user_id_str.strip():
                             return user_id_str
-
+
         elif isinstance(custom_header_name_to_check, str):
             for header_name, header_value in request_headers.items():
                 if header_name.lower() == custom_header_name_to_check.lower():
-                    user_id_str = (
-                        str(header_value)
-                        if header_value is not None
-                        else ""
-                    )
+                    user_id_str = str(header_value) if header_value is not None else ""
                     if user_id_str.strip():
                         return user_id_str


@@ -685,6 +685,7 @@ async def _user_api_key_auth_builder(  # noqa: PLR0915
             do_standard_jwt_auth = True
             if jwt_handler.litellm_jwtauth.virtual_key_claim_field is not None:
                 # Decode JWT to get claims without running full auth_builder
+                jwt_claims: Optional[dict]
                 if jwt_handler.litellm_jwtauth.oidc_userinfo_enabled:
                     jwt_claims = await jwt_handler.get_oidc_userinfo(token=api_key)
                 else:
@@ -700,6 +701,7 @@ async def _user_api_key_auth_builder(  # noqa: PLR0915
                     )
                     if valid_token is not None:
                         api_key = valid_token.token or ""
+                        valid_token.jwt_claims = jwt_claims
                         do_standard_jwt_auth = False
                         # Fall through to virtual key checks

@@ -729,6 +731,7 @@ async def _user_api_key_auth_builder(  # noqa: PLR0915
                 team_membership: Optional[LiteLLM_TeamMembership] = result.get(
                     "team_membership", None
                 )
+                jwt_claims = result.get("jwt_claims", None)

                 global_proxy_spend = await get_global_proxy_spend(
                     litellm_proxy_admin_name=litellm_proxy_admin_name,
@@ -757,6 +760,7 @@ async def _user_api_key_auth_builder(  # noqa: PLR0915
                     org_id=org_id,
                     end_user_id=end_user_id,
                     parent_otel_span=parent_otel_span,
+                    jwt_claims=jwt_claims,
                 )

                 valid_token = UserAPIKeyAuth(
@@ -803,6 +807,7 @@ async def _user_api_key_auth_builder(  # noqa: PLR0915
                     team_metadata=(
                         team_object.metadata if team_object is not None else None
                     ),
+                    jwt_claims=jwt_claims,
                 )

                 # Check if model has zero cost - if so, skip all budget checks

@@ -537,9 +537,10 @@ async def retrieve_batch(  # noqa: PLR0915
             )

         # Fix: bug_feb14_batch_retrieve_returns_raw_input_file_id
-        # Resolve raw provider input_file_id to unified ID.
+        # Resolve raw provider file IDs (input, output, error) to unified IDs.
         if unified_batch_id:
             await resolve_input_file_id_to_unified(response, prisma_client)
+            await resolve_output_file_ids_to_unified(response, prisma_client)

         ### ALERTING ###
         asyncio.create_task(

@@ -0,0 +1,84 @@
+"""MCP JWT Signer guardrail — built-in LiteLLM guardrail for zero trust MCP auth."""
+
+from typing import TYPE_CHECKING
+
+from litellm.types.guardrails import SupportedGuardrailIntegrations
+
+from .mcp_jwt_signer import MCPJWTSigner, get_mcp_jwt_signer
+
+if TYPE_CHECKING:
+    from litellm.types.guardrails import Guardrail, LitellmParams
+
+
+def initialize_guardrail(
+    litellm_params: "LitellmParams", guardrail: "Guardrail"
+) -> MCPJWTSigner:
+    import litellm
+
+    guardrail_name = guardrail.get("guardrail_name")
+    if not guardrail_name:
+        raise ValueError("MCPJWTSigner guardrail requires a guardrail_name")
+
+    mode = litellm_params.mode
+    if mode != "pre_mcp_call":
+        raise ValueError(
+            f"MCPJWTSigner guardrail '{guardrail_name}' has mode='{mode}' but must use "
+            "mode='pre_mcp_call'. JWT injection only fires for MCP tool calls."
+        )
+
+    optional_params = getattr(litellm_params, "optional_params", None)
+
+    def _get(key):  # type: ignore[no-untyped-def]
+        if optional_params is not None:
+            v = getattr(optional_params, key, None)
+            if v is not None:
+                return v
+        return getattr(litellm_params, key, None)
+
+    signer = MCPJWTSigner(
+        guardrail_name=guardrail_name,
+        event_hook=litellm_params.mode,
+        default_on=litellm_params.default_on,
+        # Core signing
+        issuer=_get("issuer"),
+        audience=_get("audience"),
+        ttl_seconds=_get("ttl_seconds"),
+        # FR-5: verify + re-sign
+        access_token_discovery_uri=_get("access_token_discovery_uri"),
+        token_introspection_endpoint=_get("token_introspection_endpoint"),
+        verify_issuer=_get("verify_issuer"),
+        verify_audience=_get("verify_audience"),
+        # FR-12: end-user identity mapping
+        end_user_claim_sources=_get("end_user_claim_sources"),
+        # FR-13: claim operations
+        add_claims=_get("add_claims"),
+        set_claims=_get("set_claims"),
+        remove_claims=_get("remove_claims"),
+        # FR-14: two-token model
+        channel_token_audience=_get("channel_token_audience"),
+        channel_token_ttl=_get("channel_token_ttl"),
+        # FR-15: incoming claim validation
+        required_claims=_get("required_claims"),
+        optional_claims=_get("optional_claims"),
+        # FR-9: debug headers
+        debug_headers=_get("debug_headers") or False,
+        # FR-10: configurable scopes
+        allowed_scopes=_get("allowed_scopes"),
+    )
+    litellm.logging_callback_manager.add_litellm_callback(signer)
+    return signer
+
+
+guardrail_initializer_registry = {
+    SupportedGuardrailIntegrations.MCP_JWT_SIGNER.value: initialize_guardrail,
+}
+
+guardrail_class_registry = {
+    SupportedGuardrailIntegrations.MCP_JWT_SIGNER.value: MCPJWTSigner,
+}
+
+__all__ = [
+    "MCPJWTSigner",
+    "initialize_guardrail",
+    "get_mcp_jwt_signer",
+]
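The nested `_get` helper in `initialize_guardrail` resolves each setting from `optional_params` first and falls back to the top-level params object. The sketch below shows that lookup order on plain objects; `lookup_param` is a hypothetical name, and `SimpleNamespace` stands in for the real `LitellmParams` type, whose exact shape is not shown in this diff.

```python
from types import SimpleNamespace
from typing import Any


def lookup_param(litellm_params: Any, key: str) -> Any:
    """Same resolution order as the nested _get() helper above.

    1. Use optional_params.<key> when it exists and is not None.
    2. Otherwise fall back to litellm_params.<key>, or None if absent.
    """
    optional_params = getattr(litellm_params, "optional_params", None)
    if optional_params is not None:
        value = getattr(optional_params, key, None)
        if value is not None:
            return value
    return getattr(litellm_params, key, None)


# Illustrative values only; SimpleNamespace is a stand-in for LitellmParams.
example_params = SimpleNamespace(
    issuer="https://top-level.example.com",
    ttl_seconds=300,
    optional_params=SimpleNamespace(
        issuer="https://nested.example.com",
        ttl_seconds=None,
    ),
)
```

With these values, `issuer` resolves to the nested entry while `ttl_seconds` falls through to the top level, because a `None` in `optional_params` is treated as "not set".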
@@ -0,0 +1,891 @@
+"""
+MCPJWTSigner — Built-in LiteLLM guardrail for zero trust MCP authentication.
+
+Signs outbound MCP requests with a LiteLLM-issued RS256 JWT so that MCP servers
+can trust a single signing authority (liteLLM) instead of every upstream IdP.
+
+Usage in config.yaml:
+
+    guardrails:
+      - guardrail_name: "mcp-jwt-signer"
+        litellm_params:
+          guardrail: mcp_jwt_signer
+          mode: "pre_mcp_call"
+          default_on: true
+
+          # Core signing config
+          issuer: "https://my-litellm.example.com"  # optional
+          audience: "mcp"  # optional
+          ttl_seconds: 300  # optional
+
+          # FR-5: Verify + re-sign — validate incoming Bearer token before signing
+          access_token_discovery_uri: "https://idp.example.com/.well-known/openid-configuration"
+          token_introspection_endpoint: "https://idp.example.com/introspect"  # opaque tokens
+          verify_issuer: "https://idp.example.com"  # expected iss in incoming JWT
+          verify_audience: "api://my-app"  # expected aud in incoming JWT
+
+          # FR-12: End-user identity mapping — ordered resolution chain
+          # Supported: token:<claim>, litellm:user_id, litellm:email,
+          #            litellm:end_user_id, litellm:team_id
+          end_user_claim_sources:
+            - "token:sub"
+            - "token:email"
+            - "litellm:user_id"
+
+          # FR-13: Claim operations
+          add_claims:  # add if key not already present in the JWT
+            deployment_id: "prod-001"
+          set_claims:  # always set (overrides computed value)
+            env: "production"
+          remove_claims:  # remove from final JWT
+            - "nbf"
+
+          # FR-14: Two-token model — issue a second JWT for the MCP transport channel
+          channel_token_audience: "bedrock-gateway"
+          channel_token_ttl: 60
+
+          # FR-15: Incoming claim validation — enforce required IdP claims
+          required_claims:
+            - "sub"
+            - "email"
+          optional_claims:  # pass through from jwt_claims into outbound JWT
+            - "groups"
+            - "roles"
+
+          # FR-9: Debug headers
+          debug_headers: false  # emit x-litellm-mcp-debug header when true
+
+          # FR-10: Configurable scopes — explicit list replaces auto-generation
+          allowed_scopes:
+            - "mcp:tools/call"
+            - "mcp:tools/list"
+
+MCP servers verify tokens via:
+    GET /.well-known/openid-configuration → { jwks_uri: ".../.well-known/jwks.json" }
+    GET /.well-known/jwks.json → RSA public key in JWKS format
+
+Optionally set MCP_JWT_SIGNING_KEY env var (PEM string or file:///path) to use
+your own RSA keypair. If unset, an RSA-2048 keypair is auto-generated at startup.
+"""
+
+import base64
+import hashlib
+import os
+import re
+import time
+from typing import Any, Dict, List, Optional, Union
+
+import jwt
+from cryptography.hazmat.primitives import serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
+from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
+
+from litellm._logging import verbose_proxy_logger
+from litellm.caching import DualCache
+from litellm.integrations.custom_guardrail import (
+    CustomGuardrail,
+    log_guardrail_information,
+)
+from litellm.proxy._types import UserAPIKeyAuth
+from litellm.types.utils import CallTypesLiteral
+
+# Module-level singleton for the JWKS discovery endpoint to access.
+_mcp_jwt_signer_instance: Optional["MCPJWTSigner"] = None
+
+# Simple in-memory JWKS cache: keyed by JWKS URI → (keys_list, fetched_at).
+_jwks_cache: Dict[str, tuple] = {}
+_JWKS_CACHE_TTL = 3600  # 1 hour
+
+
+def get_mcp_jwt_signer() -> Optional["MCPJWTSigner"]:
+    """Return the active MCPJWTSigner singleton, or None if not initialized."""
+    return _mcp_jwt_signer_instance
+
+
+def _load_private_key_from_env(env_var: str) -> RSAPrivateKey:
+    """Load an RSA private key from an env var (PEM string or file:// path)."""
+    key_material = os.environ.get(env_var, "")
+    if not key_material:
+        raise ValueError(
+            f"MCPJWTSigner: environment variable '{env_var}' is set but empty."
+        )
+    if key_material.startswith("file://"):
+        path = key_material[len("file://") :]
+        with open(path, "rb") as f:
+            key_bytes = f.read()
+    else:
+        key_bytes = key_material.encode("utf-8")
+    return serialization.load_pem_private_key(key_bytes, password=None)  # type: ignore[return-value]
+
+
+def _generate_rsa_key_pair() -> RSAPrivateKey:
+    """Generate a new RSA-2048 private key."""
+    return rsa.generate_private_key(
+        public_exponent=65537,
+        key_size=2048,
+    )
+
+
+def _int_to_base64url(n: int) -> str:
+    """Encode an integer as a base64url string (no padding)."""
+    byte_length = (n.bit_length() + 7) // 8
+    return (
+        base64.urlsafe_b64encode(n.to_bytes(byte_length, byteorder="big"))
+        .rstrip(b"=")
+        .decode("ascii")
+    )
+
+
+def _compute_kid(public_key: Any) -> str:
+    """Derive a key ID from the public key's DER encoding (SHA-256, first 16 hex chars)."""
+    der_bytes = public_key.public_bytes(
+        encoding=serialization.Encoding.DER,
+        format=serialization.PublicFormat.SubjectPublicKeyInfo,
+    )
+    return hashlib.sha256(der_bytes).hexdigest()[:16]
+
+
+async def _fetch_jwks(jwks_uri: str) -> List[Dict[str, Any]]:
+    """
+    Fetch and cache a JWKS from the given URI.
+
+    Results are cached for _JWKS_CACHE_TTL seconds to avoid hammering the IdP.
+    """
+    now = time.time()
+    cached = _jwks_cache.get(jwks_uri)
+    if cached is not None:
+        keys, fetched_at = cached
+        if now - fetched_at < _JWKS_CACHE_TTL:
+            return keys  # type: ignore[return-value]
+
+    from litellm.llms.custom_httpx.http_handler import (
+        get_async_httpx_client,
+        httpxSpecialProvider,
+    )
+
+    client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
+    resp = await client.get(jwks_uri, headers={"Accept": "application/json"})
+    resp.raise_for_status()
+    keys = resp.json().get("keys", [])
+    _jwks_cache[jwks_uri] = (keys, now)
+    return keys  # type: ignore[return-value]
+
+
+async def _fetch_oidc_discovery(discovery_uri: str) -> Dict[str, Any]:
+    """Fetch an OIDC discovery document and return its parsed JSON."""
+    from litellm.llms.custom_httpx.http_handler import (
+        get_async_httpx_client,
+        httpxSpecialProvider,
+    )
+
+    client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
+    resp = await client.get(discovery_uri, headers={"Accept": "application/json"})
+    resp.raise_for_status()
+    return resp.json()  # type: ignore[return-value]
+
+
+class MCPJWTSigner(CustomGuardrail):
+    """
+    Built-in LiteLLM guardrail that signs outbound MCP requests with a
+    LiteLLM-issued RS256 JWT, enabling zero trust authentication.
+
+    MCP servers verify tokens using liteLLM's OIDC discovery endpoint and
+    JWKS endpoint rather than trusting each upstream IdP directly.
+
+    The signed JWT carries:
+      - iss: LiteLLM issuer identifier
+      - aud: MCP audience (configurable)
+      - sub: End-user identity (resolved via end_user_claim_sources, RFC 8693)
+      - act: Actor/agent identity (team_id or org_id, RFC 8693 delegation)
+      - scope: Tool-level access scopes (configurable via allowed_scopes)
+      - iat, exp, nbf: Standard timing claims
+
+    Feature set:
+      FR-5: Verify + re-sign (access_token_discovery_uri, token_introspection_endpoint)
+      FR-9: Debug headers (debug_headers)
+      FR-10: Configurable scopes (allowed_scopes)
+      FR-12: Configurable end-user identity mapping (end_user_claim_sources)
+      FR-13: Claim operations (add_claims, set_claims, remove_claims)
+      FR-14: Two-token model (channel_token_audience, channel_token_ttl)
+      FR-15: Incoming claim validation (required_claims, optional_claims)
+    """
+
+    ALGORITHM = "RS256"
+    DEFAULT_TTL = 300
+    DEFAULT_AUDIENCE = "mcp"
+    SIGNING_KEY_ENV = "MCP_JWT_SIGNING_KEY"
+
+    def __init__(
+        self,
+        # Core signing config
+        issuer: Optional[str] = None,
+        audience: Optional[str] = None,
+        ttl_seconds: Optional[int] = None,
+        # FR-5: Verify + re-sign
+        access_token_discovery_uri: Optional[str] = None,
+        token_introspection_endpoint: Optional[str] = None,
+        verify_issuer: Optional[str] = None,
+        verify_audience: Optional[str] = None,
+        # FR-12: End-user identity mapping
+        end_user_claim_sources: Optional[List[str]] = None,
+        # FR-13: Claim operations
+        add_claims: Optional[Dict[str, Any]] = None,
+        set_claims: Optional[Dict[str, Any]] = None,
+        remove_claims: Optional[List[str]] = None,
+        # FR-14: Two-token model
+        channel_token_audience: Optional[str] = None,
+        channel_token_ttl: Optional[int] = None,
+        # FR-15: Incoming claim validation
+        required_claims: Optional[List[str]] = None,
+        optional_claims: Optional[List[str]] = None,
+        # FR-9: Debug headers
+        debug_headers: bool = False,
+        # FR-10: Configurable scopes
+        allowed_scopes: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(**kwargs)
+
+        # --- Signing key setup ---
+        key_material = os.environ.get(self.SIGNING_KEY_ENV)
+        if key_material:
+            self._private_key = _load_private_key_from_env(self.SIGNING_KEY_ENV)
+            self._persistent_key: bool = True
+            verbose_proxy_logger.info(
+                "MCPJWTSigner: loaded RSA key from env var %s", self.SIGNING_KEY_ENV
+            )
+        else:
+            self._private_key = _generate_rsa_key_pair()
+            self._persistent_key = False
+            verbose_proxy_logger.info(
+                "MCPJWTSigner: auto-generated RSA-2048 keypair (set %s to use your own key)",
+                self.SIGNING_KEY_ENV,
+            )
+
+        self._public_key = self._private_key.public_key()
+        self._kid = _compute_kid(self._public_key)
+
+        # --- Core config ---
+        self.issuer: str = (
+            issuer
+            or os.environ.get("MCP_JWT_ISSUER")
+            or os.environ.get("LITELLM_EXTERNAL_URL")
+            or "litellm"
+        )
+        self.audience: str = (
+            audience or os.environ.get("MCP_JWT_AUDIENCE") or self.DEFAULT_AUDIENCE
+        )
+        resolved_ttl = int(
+            ttl_seconds
+            if ttl_seconds is not None
+            else os.environ.get("MCP_JWT_TTL_SECONDS", str(self.DEFAULT_TTL))
+        )
+        if resolved_ttl <= 0:
+            raise ValueError(
+                f"MCPJWTSigner: ttl_seconds must be > 0, got {resolved_ttl}"
+            )
+        self.ttl_seconds: int = resolved_ttl
+
+        # --- FR-5: Verify + re-sign ---
+        self.access_token_discovery_uri: Optional[str] = access_token_discovery_uri
+        self.token_introspection_endpoint: Optional[str] = token_introspection_endpoint
+        self.verify_issuer: Optional[str] = verify_issuer
+        self.verify_audience: Optional[str] = verify_audience
+        # Cached OIDC discovery document (fetched lazily, TTL = 24 h)
+        self._oidc_discovery_doc: Optional[Dict[str, Any]] = None
+        self._oidc_discovery_fetched_at: float = 0.0
+
+        # --- FR-12: End-user identity mapping ---
+        # Default chain: try incoming JWT sub, fall back to litellm user_id
+        self.end_user_claim_sources: List[str] = end_user_claim_sources or [
+            "token:sub",
+            "litellm:user_id",
+        ]
+
+        # --- FR-13: Claim operations ---
+        self.add_claims: Dict[str, Any] = add_claims or {}
+        self.set_claims: Dict[str, Any] = set_claims or {}
+        self.remove_claims: List[str] = remove_claims or []
+
+        # --- FR-14: Two-token model ---
+        self.channel_token_audience: Optional[str] = channel_token_audience
+        self.channel_token_ttl: int = (
+            channel_token_ttl if channel_token_ttl is not None else self.ttl_seconds
+        )
+
+        # --- FR-15: Incoming claim validation ---
+        self.required_claims: List[str] = required_claims or []
+        self.optional_claims: List[str] = optional_claims or []
+
+        # --- FR-9: Debug headers ---
+        self.debug_headers: bool = debug_headers
+
+        # --- FR-10: Configurable scopes ---
+        self.allowed_scopes: Optional[List[str]] = allowed_scopes
+
+        # Register singleton for JWKS/OIDC discovery endpoints.
+        global _mcp_jwt_signer_instance
+        if _mcp_jwt_signer_instance is not None:
+            verbose_proxy_logger.warning(
+                "MCPJWTSigner: replacing existing singleton — previously issued tokens "
+                "signed with the old key will fail JWKS verification. "
+                "Avoid configuring multiple mcp_jwt_signer guardrails."
+            )
+        _mcp_jwt_signer_instance = self
+
+        verbose_proxy_logger.info(
+            "MCPJWTSigner initialized: issuer=%s audience=%s ttl=%ds kid=%s "
+            "verify=%s channel_token=%s debug=%s",
+            self.issuer,
+            self.audience,
+            self.ttl_seconds,
+            self._kid,
+            bool(self.access_token_discovery_uri),
+            bool(self.channel_token_audience),
+            self.debug_headers,
+        )
+
||||
# ------------------------------------------------------------------
|
||||
# Public helpers (used by /.well-known/jwks.json endpoint)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@property
|
||||
def jwks_max_age(self) -> int:
|
||||
"""
|
||||
Recommended Cache-Control max-age for the JWKS response (seconds).
|
||||
|
||||
1 hour for persistent keys; 5 minutes for auto-generated keys so MCP
|
||||
servers re-fetch quickly after a proxy restart.
|
||||
"""
|
||||
return 3600 if self._persistent_key else 300
|
||||
|
||||
def get_jwks(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Return the JWKS for the RSA public key.
|
||||
Used by GET /.well-known/jwks.json so MCP servers can verify tokens.
|
||||
"""
|
||||
public_numbers = self._public_key.public_numbers()
|
||||
return {
|
||||
"keys": [
|
||||
{
|
||||
"kty": "RSA",
|
||||
"alg": self.ALGORITHM,
|
||||
"use": "sig",
|
||||
"kid": self._kid,
|
||||
"n": _int_to_base64url(public_numbers.n),
|
||||
"e": _int_to_base64url(public_numbers.e),
|
||||
}
|
||||
]
|
||||
}
|
||||
|
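The `n` and `e` members returned by `get_jwks` are base64url-encoded big-endian integers. The `_int_to_base64url` helper is not shown in this hunk, so the following is only a minimal sketch of what such a helper usually does; the name `int_to_base64url` and its exact behavior are assumptions, not the PR's implementation.

```python
import base64


def int_to_base64url(value: int) -> str:
    """Encode a non-negative integer as unpadded base64url (big-endian bytes).

    Hypothetical stand-in for the `_int_to_base64url` helper referenced above.
    """
    # At least one byte so that 0 still encodes to something.
    length = max(1, (value.bit_length() + 7) // 8)
    raw = value.to_bytes(length, "big")
    # JWK members drop the trailing '=' padding.
    return base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")


print(int_to_base64url(65537))  # common RSA public exponent -> AQAB
```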
||||
# ------------------------------------------------------------------
|
||||
# FR-5: Verify + re-sign helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# 24-hour TTL for the OIDC discovery doc — long enough to avoid hammering
|
||||
# the IdP, short enough to pick up jwks_uri changes after key rotation.
|
||||
_OIDC_DISCOVERY_TTL = 86400
|
||||
|
||||
async def _get_oidc_discovery(self) -> Dict[str, Any]:
|
||||
"""Fetch and cache the OIDC discovery document with a 24-hour TTL.
|
||||
|
||||
Only caches when the doc contains a 'jwks_uri' so that a transient or
|
||||
malformed response doesn't permanently disable JWT verification.
|
||||
"""
|
||||
now = time.time()
|
||||
cache_expired = (
|
||||
now - self._oidc_discovery_fetched_at
|
||||
) >= self._OIDC_DISCOVERY_TTL
|
||||
if (
|
||||
self._oidc_discovery_doc is None or cache_expired
|
||||
) and self.access_token_discovery_uri:
|
||||
doc = await _fetch_oidc_discovery(self.access_token_discovery_uri)
|
||||
if "jwks_uri" in doc:
|
||||
self._oidc_discovery_doc = doc
|
||||
self._oidc_discovery_fetched_at = now
|
||||
else:
|
||||
return doc
|
||||
return self._oidc_discovery_doc or {}
|
||||
|
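The caching rule in `_get_oidc_discovery` (refetch when empty or expired, but only store a document that carries `jwks_uri`) can be sketched in isolation. This is a synchronous illustration with an injected fetcher and clock; the `DiscoveryCache` class and its parameters are hypothetical and exist only to make the behavior testable.

```python
import time
from typing import Any, Callable, Dict, Optional


class DiscoveryCache:
    """Illustrative TTL cache that refuses to store malformed documents."""

    def __init__(
        self,
        fetch: Callable[[], Dict[str, Any]],
        ttl: float = 86400,
        clock: Callable[[], float] = time.time,
    ) -> None:
        self._fetch = fetch
        self._ttl = ttl
        self._clock = clock
        self._doc: Optional[Dict[str, Any]] = None
        self._fetched_at = 0.0

    def get(self) -> Dict[str, Any]:
        now = self._clock()
        expired = (now - self._fetched_at) >= self._ttl
        if self._doc is None or expired:
            doc = self._fetch()
            if "jwks_uri" in doc:
                # Only a usable document replaces the cached copy.
                self._doc = doc
                self._fetched_at = now
            else:
                # Hand the bad document back uncached so the next call retries.
                return doc
        return self._doc or {}
```

Because a response without `jwks_uri` is never stored, a transient IdP failure costs one failed request rather than up to 24 hours of failed verification.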
||||
async def _verify_incoming_jwt(self, raw_token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Verify an incoming Bearer JWT against the configured IdP's JWKS.
|
||||
|
||||
Returns the verified payload claims dict.
|
||||
Raises jwt.PyJWTError (or subclass) if verification fails.
|
||||
"""
|
||||
discovery = await self._get_oidc_discovery()
|
||||
jwks_uri = discovery.get("jwks_uri")
|
||||
if not jwks_uri:
|
||||
raise ValueError(
|
||||
"MCPJWTSigner: access_token_discovery_uri discovery document "
|
||||
f"at {self.access_token_discovery_uri!r} has no 'jwks_uri'."
|
||||
)
|
||||
|
||||
jwks_keys = await _fetch_jwks(jwks_uri)
|
||||
|
||||
# Only read `kid` from the unverified header — never `alg`.
|
||||
# Reading `alg` from an attacker-controlled header enables algorithm
|
||||
# confusion attacks (e.g. alg:none, HS256 with the public key as secret).
|
||||
# The algorithm is determined from the JWKS key entry instead.
|
||||
unverified_header = jwt.get_unverified_header(raw_token)
|
||||
kid = unverified_header.get("kid")
|
||||
|
||||
# Build a JWKS object and pick the matching key.
|
||||
# PyJWT's PyJWKSet handles key-type parsing and kid matching correctly.
|
||||
from jwt import PyJWKSet
|
||||
|
||||
try:
|
||||
jwks_set = PyJWKSet.from_dict({"keys": jwks_keys})
|
||||
except Exception as exc:
|
||||
raise jwt.exceptions.PyJWKSetError( # type: ignore[attr-defined]
|
||||
f"Failed to parse JWKS from {jwks_uri!r}: {exc}"
|
||||
) from exc
|
||||
|
||||
signing_jwk = None
|
||||
for jwk_obj in jwks_set.keys:
|
||||
if not kid or jwk_obj.key_id == kid:
|
||||
signing_jwk = jwk_obj
|
||||
break
|
||||
|
||||
if signing_jwk is None:
|
||||
raise jwt.exceptions.PyJWKSetError( # type: ignore[attr-defined]
|
||||
f"No JWKS key matching kid={kid!r} at {jwks_uri!r}"
|
||||
)
|
||||
|
||||
# Use the algorithm declared by the JWKS key entry, not the token header.
|
||||
# PyJWT populates algorithm_name from the key's `alg` field; when absent
|
||||
# it infers from the key type (RSAPublicKey → RS256).
|
||||
alg = getattr(signing_jwk, "algorithm_name", None) or "RS256"
|
||||
|
||||
decode_options: Dict[str, Any] = {"verify_exp": True}
|
||||
decode_kwargs: Dict[str, Any] = {
|
||||
"algorithms": [alg],
|
||||
"options": decode_options,
|
||||
}
|
||||
if self.verify_audience:
|
||||
decode_kwargs["audience"] = self.verify_audience
|
||||
else:
|
||||
decode_options["verify_aud"] = False
|
||||
|
||||
if self.verify_issuer:
|
||||
decode_kwargs["issuer"] = self.verify_issuer
|
||||
|
||||
payload: Dict[str, Any] = jwt.decode(
|
||||
raw_token, signing_jwk.key, **decode_kwargs
|
||||
)
|
||||
return payload
|
||||
|
||||
async def _introspect_opaque_token(self, token: str) -> Dict[str, Any]:
|
||||
"""
|
||||
Perform RFC 7662 token introspection for opaque (non-JWT) tokens.
|
||||
|
||||
Returns the introspection response dict. Raises on HTTP error or
|
||||
inactive token.
|
||||
"""
|
||||
if not self.token_introspection_endpoint:
|
||||
raise ValueError(
|
||||
"MCPJWTSigner: token_introspection_endpoint is required for "
|
||||
"opaque token verification but is not configured."
|
||||
)
|
||||
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
get_async_httpx_client,
|
||||
httpxSpecialProvider,
|
||||
)
|
||||
|
||||
client = get_async_httpx_client(llm_provider=httpxSpecialProvider.Oauth2Check)
|
||||
resp = await client.post(
|
||||
self.token_introspection_endpoint,
|
||||
data={"token": token},
|
||||
headers={"Accept": "application/json"},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
result: Dict[str, Any] = resp.json()
|
||||
if not result.get("active", False):
|
||||
raise jwt.exceptions.ExpiredSignatureError( # type: ignore[attr-defined]
|
||||
"MCPJWTSigner: incoming token is inactive (introspection returned active=false)"
|
||||
)
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# FR-15: Incoming claim validation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _validate_required_claims(
|
||||
self,
|
||||
jwt_claims: Optional[Dict[str, Any]],
|
||||
) -> None:
|
||||
"""
|
||||
Raise HTTP 403 if any required_claims are absent from the verified
|
||||
incoming token claims.
|
||||
"""
|
||||
if not self.required_claims:
|
||||
return
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
missing = [c for c in self.required_claims if not (jwt_claims or {}).get(c)]
|
||||
if missing:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={
|
||||
"error": (
|
||||
f"MCPJWTSigner: incoming token is missing required claims: "
|
||||
f"{missing}. Configure the IdP to include these claims."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# FR-12: End-user identity mapping
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _resolve_end_user_identity(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
jwt_claims: Optional[Dict[str, Any]],
|
||||
) -> str:
|
||||
"""
|
||||
Resolve the outbound JWT 'sub' using the ordered end_user_claim_sources list.
|
||||
|
||||
Supported source prefixes:
|
||||
token:<claim> — from verified incoming JWT / introspection claims
|
||||
litellm:user_id — from UserAPIKeyAuth.user_id
|
||||
litellm:email — from UserAPIKeyAuth.user_email
|
||||
litellm:end_user_id — from UserAPIKeyAuth.end_user_id
|
||||
litellm:team_id — from UserAPIKeyAuth.team_id
|
||||
|
||||
Falls back to a stable hash of the API token for service-account callers.
|
||||
"""
|
||||
for source in self.end_user_claim_sources:
|
||||
value: Optional[str] = None
|
||||
|
||||
if source.startswith("token:"):
|
||||
claim_name = source[len("token:") :]
|
||||
raw = (jwt_claims or {}).get(claim_name)
|
||||
value = str(raw) if raw else None
|
||||
|
||||
elif source == "litellm:user_id":
|
||||
uid = getattr(user_api_key_dict, "user_id", None)
|
||||
value = str(uid) if uid else None
|
||||
|
||||
elif source == "litellm:email":
|
||||
email = getattr(user_api_key_dict, "user_email", None)
|
||||
value = str(email) if email else None
|
||||
|
||||
elif source == "litellm:end_user_id":
|
||||
eid = getattr(user_api_key_dict, "end_user_id", None)
|
||||
value = str(eid) if eid else None
|
||||
|
||||
elif source == "litellm:team_id":
|
||||
tid = getattr(user_api_key_dict, "team_id", None)
|
||||
value = str(tid) if tid else None
|
||||
|
||||
else:
|
||||
verbose_proxy_logger.warning(
|
||||
"MCPJWTSigner: unknown end_user_claim_source %r — skipping", source
|
||||
)
|
||||
continue
|
||||
|
||||
if value:
|
||||
return value
|
||||
|
||||
# Final fallback for service accounts with no user identity
|
||||
token = getattr(user_api_key_dict, "token", None) or getattr(
|
||||
user_api_key_dict, "api_key", None
|
||||
)
|
||||
if token:
|
||||
return "apikey:" + hashlib.sha256(str(token).encode()).hexdigest()[:16]
|
||||
return "litellm-proxy"
|
||||
|
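The ordered fallback in `_resolve_end_user_identity` can be illustrated with plain dicts. This is a simplified sketch: `resolve_sub` is a hypothetical name, and the `litellm:` prefix here looks up a dict key directly instead of mapping to `UserAPIKeyAuth` attributes as the method above does.

```python
import hashlib
from typing import Any, Dict, List, Optional


def resolve_sub(
    sources: List[str],
    token_claims: Dict[str, Any],
    litellm_ctx: Dict[str, Any],
    api_token: Optional[str] = None,
) -> str:
    """Return the first non-empty value from the ordered source list."""
    for source in sources:
        prefix, _, name = source.partition(":")
        if prefix == "token":
            raw = token_claims.get(name)
        elif prefix == "litellm":
            raw = litellm_ctx.get(name)
        else:
            continue  # unknown source: skip it, as the method above does
        if raw:
            return str(raw)
    # Service accounts: stable, non-reversible identifier from the API token.
    if api_token:
        return "apikey:" + hashlib.sha256(api_token.encode()).hexdigest()[:16]
    return "litellm-proxy"
```

With the default chain `["token:sub", "litellm:user_id"]`, a verified incoming `sub` wins, and the LiteLLM user ID is used only when the token carries none.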
||||
# ------------------------------------------------------------------
|
||||
# FR-10: Scope building
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _build_scope(self, raw_tool_name: str) -> str:
|
||||
"""
|
||||
Build the JWT scope string.
|
||||
|
||||
When allowed_scopes is configured: join them verbatim.
|
||||
Otherwise auto-generate minimal, least-privilege scopes:
|
||||
- Tool call → mcp:tools/call mcp:tools/<name>:call
|
||||
- No tool → mcp:tools/call mcp:tools/list
|
||||
|
||||
NOTE: tools/list is intentionally NOT granted on tool-call JWTs to
|
||||
prevent callers from enumerating tools they didn't ask to use.
|
||||
"""
|
||||
if self.allowed_scopes is not None:
|
||||
return " ".join(self.allowed_scopes)
|
||||
|
||||
tool_name = (
|
||||
re.sub(r"[^a-zA-Z0-9_\-]", "_", raw_tool_name) if raw_tool_name else ""
|
||||
)
|
||||
if tool_name:
|
||||
scopes = ["mcp:tools/call", f"mcp:tools/{tool_name}:call"]
|
||||
else:
|
||||
scopes = ["mcp:tools/call", "mcp:tools/list"]
|
||||
return " ".join(scopes)
|
||||
|
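A standalone version of the scope logic in `_build_scope` makes the sanitization and the least-privilege default easier to see. The free function `build_scope` and its `allowed_scopes` parameter are illustrative assumptions; the regex and scope strings are copied from the method above.

```python
import re
from typing import List, Optional


def build_scope(raw_tool_name: str, allowed_scopes: Optional[List[str]] = None) -> str:
    """Mirror of the scope rules above as a free function (sketch only)."""
    if allowed_scopes is not None:
        # Configured scopes are joined verbatim; an empty list yields "".
        return " ".join(allowed_scopes)
    # Replace anything outside [a-zA-Z0-9_-] so the name is safe in a scope token.
    tool_name = re.sub(r"[^a-zA-Z0-9_\-]", "_", raw_tool_name) if raw_tool_name else ""
    if tool_name:
        return f"mcp:tools/call mcp:tools/{tool_name}:call"
    return "mcp:tools/call mcp:tools/list"
```

Note that `allowed_scopes=[]` is treated as configured and produces an empty scope string, which differs from leaving it as `None`.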
||||
# ------------------------------------------------------------------
|
||||
# FR-13: Claim operations
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _apply_claim_operations(self, claims: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Apply add_claims, set_claims, and remove_claims to the claim dict."""
|
||||
# add_claims: insert only when key is absent
|
||||
for k, v in self.add_claims.items():
|
||||
if k not in claims:
|
||||
claims[k] = v
|
||||
|
||||
# set_claims: always override (highest priority)
|
||||
claims = {**claims, **self.set_claims}
|
||||
|
||||
# remove_claims: delete listed keys
|
||||
for k in self.remove_claims:
|
||||
claims.pop(k, None)
|
||||
|
||||
return claims
|
||||
|
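The precedence between the three claim operations is easy to misread, so here is a small sketch that applies them in the same order as `_apply_claim_operations`. The function name and signature are illustrative, and unlike the method above it copies the input rather than mutating it.

```python
from typing import Any, Dict, List


def apply_claim_operations(
    claims: Dict[str, Any],
    add_claims: Dict[str, Any],
    set_claims: Dict[str, Any],
    remove_claims: List[str],
) -> Dict[str, Any]:
    """Apply add, then set, then remove, returning a new dict."""
    result = dict(claims)
    for key, value in add_claims.items():
        result.setdefault(key, value)  # add: only when the key is absent
    result.update(set_claims)          # set: always overrides
    for key in remove_claims:
        result.pop(key, None)          # remove: runs last, so it beats set
    return result
```

Because removal runs last, a key listed in both `set_claims` and `remove_claims` ends up absent from the outbound token.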
||||
# ------------------------------------------------------------------
|
||||
# FR-15: optional_claims passthrough
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _passthrough_optional_claims(
|
||||
self,
|
||||
claims: Dict[str, Any],
|
||||
jwt_claims: Optional[Dict[str, Any]],
|
||||
) -> Dict[str, Any]:
|
||||
"""Forward optional_claims from verified incoming token into the outbound JWT."""
|
||||
if not self.optional_claims or not jwt_claims:
|
||||
return claims
|
||||
for claim in self.optional_claims:
|
||||
if claim in jwt_claims and claim not in claims:
|
||||
claims[claim] = jwt_claims[claim]
|
||||
return claims
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Core JWT builder
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _build_claims(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
data: dict,
|
||||
jwt_claims: Optional[Dict[str, Any]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build JWT claims for the outbound MCP access token.
|
||||
|
||||
Args:
|
||||
user_api_key_dict: LiteLLM auth context for the current request.
|
||||
data: Pre-call hook data dict (contains mcp_tool_name etc.).
|
||||
jwt_claims: Verified incoming IdP claims (FR-5), or LiteLLM-decoded
|
||||
jwt_claims if available. None for pure API-key requests.
|
||||
"""
|
||||
now = int(time.time())
|
||||
claims: Dict[str, Any] = {
|
||||
"iss": self.issuer,
|
||||
"aud": self.audience,
|
||||
"iat": now,
|
||||
"exp": now + self.ttl_seconds,
|
||||
"nbf": now,
|
||||
}
|
||||
|
||||
# sub — resolved via ordered claim sources (FR-12)
|
||||
claims["sub"] = self._resolve_end_user_identity(user_api_key_dict, jwt_claims)
|
||||
|
||||
# email passthrough when available from LiteLLM context
|
||||
user_email = getattr(user_api_key_dict, "user_email", None)
|
||||
if user_email:
|
||||
claims["email"] = user_email
|
||||
|
||||
# act — RFC 8693 delegation claim (team/org context)
|
||||
team_id = getattr(user_api_key_dict, "team_id", None)
|
||||
org_id = getattr(user_api_key_dict, "org_id", None)
|
||||
act_sub = team_id or org_id or "litellm-proxy"
|
||||
claims["act"] = {"sub": act_sub}
|
||||
|
||||
# end_user_id when set separately from user_id
|
||||
end_user_id = getattr(user_api_key_dict, "end_user_id", None)
|
||||
if end_user_id:
|
||||
claims["end_user_id"] = end_user_id
|
||||
|
||||
# scope (FR-10)
|
||||
raw_tool_name: str = data.get("mcp_tool_name", "")
|
||||
claims["scope"] = self._build_scope(raw_tool_name)
|
||||
|
||||
# optional_claims passthrough (FR-15)
|
||||
claims = self._passthrough_optional_claims(claims, jwt_claims)
|
||||
|
||||
# Claim operations — applied last so admin overrides take effect (FR-13)
|
||||
claims = self._apply_claim_operations(claims)
|
||||
|
||||
return claims
|
||||
|
||||
def _build_channel_token_claims(
|
||||
self,
|
||||
base_claims: Dict[str, Any],
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Build claims for the channel token (FR-14 two-token model).
|
||||
|
||||
Inherits sub/act/scope from the access token but uses a separate
|
||||
audience and TTL so the transport layer and resource layer receive
|
||||
purpose-bound credentials.
|
||||
"""
|
||||
now = int(time.time())
|
||||
return {
|
||||
**base_claims,
|
||||
"aud": self.channel_token_audience,
|
||||
"iat": now,
|
||||
"exp": now + self.channel_token_ttl,
|
||||
"nbf": now,
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# FR-9: Debug header
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@staticmethod
|
||||
def _build_debug_header(claims: Dict[str, Any], kid: str) -> str:
|
||||
"""
|
||||
Build the x-litellm-mcp-debug header value.
|
||||
|
||||
Format: v=1; kid=<kid>; sub=<sub>; iss=<iss>; exp=<exp>; scope=<scope>
|
||||
Scope is truncated to 80 chars for header safety.
|
||||
"""
|
||||
sub = claims.get("sub", "")
|
||||
iss = claims.get("iss", "")
|
||||
exp = claims.get("exp", 0)
|
||||
scope = claims.get("scope", "")
|
||||
if len(scope) > 80:
|
||||
scope = scope[:77] + "..."
|
||||
return f"v=1; kid={kid}; sub={sub}; iss={iss}; exp={exp}; scope={scope}"
|
||||
|
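For reference, the debug header produced by `_build_debug_header` looks like this when reproduced as a free function. The logic is copied from the method above; the sample claim values are made up.

```python
from typing import Any, Dict


def build_debug_header(claims: Dict[str, Any], kid: str) -> str:
    """Format selected claims as a single header value (sketch)."""
    scope = claims.get("scope", "")
    if len(scope) > 80:
        # Keep the value at 80 characters: 77 of scope plus an ellipsis.
        scope = scope[:77] + "..."
    return (
        f"v=1; kid={kid}; sub={claims.get('sub', '')}; "
        f"iss={claims.get('iss', '')}; exp={claims.get('exp', 0)}; scope={scope}"
    )


sample = {"sub": "alice", "iss": "litellm", "exp": 123, "scope": "mcp:tools/call"}
print(build_debug_header(sample, "k1"))
```

The header exposes `sub` and `scope` in clear text, which is presumably why it sits behind the `debug_headers` flag.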
||||
# ------------------------------------------------------------------
|
||||
# Guardrail hook
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@log_guardrail_information
|
||||
async def async_pre_call_hook(
|
||||
self,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
cache: DualCache,
|
||||
data: dict,
|
||||
call_type: CallTypesLiteral,
|
||||
) -> Optional[Union[Exception, str, dict]]:
|
||||
"""
|
||||
Verifies the incoming token (when configured), validates required claims,
|
||||
then signs an outbound JWT and injects it as the Authorization header.
|
||||
|
||||
All non-MCP call types pass through unchanged.
|
||||
"""
|
||||
if call_type != "call_mcp_tool":
|
||||
return data
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# FR-5: Verify incoming token before re-signing
|
||||
# ------------------------------------------------------------------
|
||||
jwt_claims: Optional[Dict[str, Any]] = None
|
||||
raw_token: Optional[str] = data.get("incoming_bearer_token")
|
||||
|
||||
if self.access_token_discovery_uri and raw_token:
|
||||
# Three dot-separated segments (two dots) → JWT; otherwise opaque.
|
||||
is_jwt = raw_token.count(".") == 2
|
||||
try:
|
||||
if is_jwt:
|
||||
jwt_claims = await self._verify_incoming_jwt(raw_token)
|
||||
elif self.token_introspection_endpoint:
|
||||
jwt_claims = await self._introspect_opaque_token(raw_token)
|
||||
else:
|
||||
verbose_proxy_logger.warning(
|
||||
"MCPJWTSigner: access_token_discovery_uri is set but the "
|
||||
"incoming token appears to be opaque and no "
|
||||
"token_introspection_endpoint is configured. "
|
||||
"Proceeding without incoming token verification."
|
||||
)
|
||||
except Exception as exc:
|
||||
verbose_proxy_logger.error(
|
||||
"MCPJWTSigner: incoming token verification failed: %s", exc
|
||||
)
|
||||
from fastapi import HTTPException
|
||||
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail={
|
||||
"error": (
|
||||
f"MCPJWTSigner: incoming token verification failed: {exc}"
|
||||
)
|
||||
},
|
||||
)
|
||||
elif not raw_token and self.access_token_discovery_uri:
|
||||
verbose_proxy_logger.debug(
|
||||
"MCPJWTSigner: access_token_discovery_uri configured but no Bearer "
|
||||
"token found in request (API-key auth request — skipping verification)."
|
||||
)
|
||||
|
||||
# Fall back to LiteLLM-decoded JWT claims (available when proxy uses JWT auth).
|
||||
if jwt_claims is None:
|
||||
jwt_claims = getattr(user_api_key_dict, "jwt_claims", None)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# FR-15: Validate required claims
|
||||
# ------------------------------------------------------------------
|
||||
self._validate_required_claims(jwt_claims)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Build outbound access token
|
||||
# ------------------------------------------------------------------
|
||||
claims = self._build_claims(user_api_key_dict, data, jwt_claims)
|
||||
|
||||
signed_token = jwt.encode(
|
||||
claims,
|
||||
self._private_key,
|
||||
algorithm=self.ALGORITHM,
|
||||
headers={"kid": self._kid},
|
||||
)
|
||||
|
||||
# Merge into existing extra_headers — a prior guardrail in the chain may
|
||||
# have already injected tracing headers or correlation IDs.
|
||||
existing_headers: Dict[str, str] = data.get("extra_headers") or {}
|
||||
new_headers: Dict[str, str] = {
|
||||
**existing_headers,
|
||||
"Authorization": f"Bearer {signed_token}",
|
||||
}
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# FR-14: Two-token model — channel token
|
||||
# ------------------------------------------------------------------
|
||||
if self.channel_token_audience:
|
||||
channel_claims = self._build_channel_token_claims(claims)
|
||||
channel_token = jwt.encode(
|
||||
channel_claims,
|
||||
self._private_key,
|
||||
algorithm=self.ALGORITHM,
|
||||
headers={"kid": self._kid},
|
||||
)
|
||||
new_headers["x-mcp-channel-token"] = f"Bearer {channel_token}"
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# FR-9: Debug header
|
||||
# ------------------------------------------------------------------
|
||||
if self.debug_headers:
|
||||
new_headers["x-litellm-mcp-debug"] = self._build_debug_header(
|
||||
claims, self._kid
|
||||
)
|
||||
|
||||
data["extra_headers"] = new_headers
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"MCPJWTSigner: signed JWT sub=%s act=%s tool=%s exp=%d "
|
||||
"verified=%s channel=%s",
|
||||
claims.get("sub"),
|
||||
claims.get("act", {}).get("sub"),
|
||||
data.get("mcp_tool_name"),
|
||||
claims["exp"],
|
||||
jwt_claims is not None,
|
||||
bool(self.channel_token_audience),
|
||||
)
|
||||
|
||||
return data
|
||||
@@ -2142,8 +2142,7 @@ async def _resolve_org_filter_for_user_search(
|
||||
member_org_ids: List[str] = []
|
||||
if caller_user is not None:
|
||||
member_org_ids = [
|
||||
m.organization_id
|
||||
for m in (caller_user.organization_memberships or [])
|
||||
m.organization_id for m in (caller_user.organization_memberships or [])
|
||||
]
|
||||
|
||||
if member_org_ids:
|
||||
|
||||
@@ -1863,16 +1863,10 @@ async def _validate_update_key_data(
|
||||
user_api_key_cache: Any,
|
||||
) -> None:
|
||||
"""Validate permissions and constraints for key update."""
|
||||
_is_proxy_admin = (
|
||||
user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||
)
|
||||
_is_proxy_admin = user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||
|
||||
# Prevent non-admin from removing user_id (setting to empty string) (LIT-1884)
|
||||
if (
|
||||
data.user_id is not None
|
||||
and data.user_id == ""
|
||||
and not _is_proxy_admin
|
||||
):
|
||||
if data.user_id is not None and data.user_id == "" and not _is_proxy_admin:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="Non-admin users cannot remove the user_id from a key.",
|
||||
|
||||
@@ -857,7 +857,13 @@ async def new_team( # noqa: PLR0915
|
||||
|
||||
# Apply defaults from litellm.default_team_params for any fields
|
||||
# not explicitly provided in the request.
|
||||
for field in ("max_budget", "budget_duration", "tpm_limit", "rpm_limit", "team_member_permissions"):
|
||||
for field in (
|
||||
"max_budget",
|
||||
"budget_duration",
|
||||
"tpm_limit",
|
||||
"rpm_limit",
|
||||
"team_member_permissions",
|
||||
):
|
||||
if getattr(data, field, None) is None:
|
||||
default_value = _get_default_team_param(field)
|
||||
if default_value is not None:
|
||||
|
||||
@@ -857,7 +857,10 @@ async def update_batch_in_database(
|
||||
# If the batch_processed column doesn't exist (old schema),
|
||||
# retry without it so the status update still succeeds.
|
||||
err_str = str(col_err).lower()
|
||||
if "batch_processed" in err_str and update_data.get("batch_processed") is not None:
|
||||
if (
|
||||
"batch_processed" in err_str
|
||||
and update_data.get("batch_processed") is not None
|
||||
):
|
||||
verbose_proxy_logger.warning(
|
||||
f"batch_processed column not found, retrying update without it: {col_err}"
|
||||
)
|
||||
|
||||
@@ -468,6 +468,12 @@ class ProxyInitializationHelpers:
|
||||
type=str,
|
||||
help="Path to the logging configuration file",
|
||||
)
|
||||
@click.option(
|
||||
"--setup",
|
||||
is_flag=True,
|
||||
default=False,
|
||||
help="Run the interactive setup wizard to configure providers and generate a config file",
|
||||
)
|
||||
@click.option(
|
||||
"--version",
|
||||
"-v",
|
||||
@@ -598,6 +604,7 @@ def run_server( # noqa: PLR0915
|
||||
num_requests,
|
||||
use_queue,
|
||||
health,
|
||||
setup,
|
||||
version,
|
||||
run_gunicorn,
|
||||
run_hypercorn,
|
||||
@@ -611,6 +618,12 @@ def run_server( # noqa: PLR0915
|
||||
max_requests_before_restart,
|
||||
enforce_prisma_migration_check: bool,
|
||||
):
|
||||
if setup:
|
||||
from litellm.setup_wizard import run_setup_wizard
|
||||
|
||||
run_setup_wizard()
|
||||
return
|
||||
|
||||
args = locals()
|
||||
if local:
|
||||
from proxy_server import (
|
||||
@@ -904,7 +917,7 @@ def run_server( # noqa: PLR0915
|
||||
# Auto-create PROMETHEUS_MULTIPROC_DIR for multi-worker setups
|
||||
ProxyInitializationHelpers._maybe_setup_prometheus_multiproc_dir(
|
||||
num_workers=num_workers,
|
||||
litellm_settings=litellm_settings if config else None,
|
||||
litellm_settings=litellm_settings if config else None, # type: ignore[possibly-unbound]
|
||||
)
|
||||
|
||||
# --- SEPARATE HEALTH APP LOGIC ---
|
||||
|
||||
@@ -115,7 +115,9 @@ async def background_streaming_task( # noqa: PLR0915
|
||||
UPDATE_INTERVAL = 0.150 # 150ms batching interval
|
||||
|
||||
# Track the terminal event from the stream (may not be "completed")
|
||||
terminal_status: Optional[ResponsesAPIStatus] = None # Will be set by response.completed/failed/incomplete/cancelled
|
||||
terminal_status: Optional[
|
||||
ResponsesAPIStatus
|
||||
] = None # Will be set by response.completed/failed/incomplete/cancelled
|
||||
terminal_error = None
|
||||
_event_to_status = {
|
||||
"response.completed": "completed",
|
||||
@@ -259,7 +261,10 @@ async def background_streaming_task( # noqa: PLR0915
|
||||
)
|
||||
|
||||
# Extract error for failed and incomplete responses
|
||||
if event_type == "response.failed" or event_type == "response.incomplete":
|
||||
if (
|
||||
event_type == "response.failed"
|
||||
or event_type == "response.incomplete"
|
||||
):
|
||||
terminal_error = response_data.get("error")
|
||||
|
||||
# Core response fields
|
||||
|
||||
@@ -40,9 +40,7 @@ def _get_registered_vantage_logger():
|
||||
return None
|
||||
|
||||
|
||||
async def _set_vantage_settings(
|
||||
api_key: str, integration_token: str, base_url: str
|
||||
):
|
||||
async def _set_vantage_settings(api_key: str, integration_token: str, base_url: str):
|
||||
"""Store Vantage settings in the database with encrypted API key."""
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
|
||||
@@ -341,9 +339,7 @@ async def init_vantage_settings(
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
f"Error initializing Vantage settings: {str(e)}"
|
||||
)
|
||||
verbose_proxy_logger.error(f"Error initializing Vantage settings: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={"error": f"Failed to initialize Vantage settings: {str(e)}"},
|
||||
@@ -395,7 +391,8 @@ async def vantage_dry_run_export(
|
||||
"""Cast Decimal columns to Float64 so .to_dicts() produces
|
||||
JSON-serializable float values instead of decimal.Decimal."""
|
||||
decimal_cols = [
|
||||
col for col, dtype in zip(frame.columns, frame.dtypes)
|
||||
col
|
||||
for col, dtype in zip(frame.columns, frame.dtypes)
|
||||
if isinstance(dtype, pl.Decimal)
|
||||
]
|
||||
if decimal_cols:
|
||||
@@ -404,8 +401,16 @@ async def vantage_dry_run_export(
|
||||
)
|
||||
return frame.to_dicts()
|
||||
|
||||
usage_sample = _to_json_safe_dicts(data.head(min(50, len(data)))) if not data.is_empty() else []
|
||||
normalized_sample = _to_json_safe_dicts(normalized.head(min(50, len(normalized)))) if not normalized.is_empty() else []
|
||||
usage_sample = (
|
||||
_to_json_safe_dicts(data.head(min(50, len(data))))
|
||||
if not data.is_empty()
|
||||
else []
|
||||
)
|
||||
normalized_sample = (
|
||||
_to_json_safe_dicts(normalized.head(min(50, len(normalized))))
|
||||
if not normalized.is_empty()
|
||||
else []
|
||||
)
|
||||
|
||||
# Use the same pre-transform column names as
|
||||
# FocusExportEngine.dry_run_export_usage_data for consistency.
|
||||
@@ -437,14 +442,10 @@ async def vantage_dry_run_export(
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
f"Error performing Vantage dry run export: {str(e)}"
|
||||
)
|
||||
verbose_proxy_logger.error(f"Error performing Vantage dry run export: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
"error": f"Failed to perform Vantage dry run export: {str(e)}"
|
||||
},
|
||||
detail={"error": f"Failed to perform Vantage dry run export: {str(e)}"},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -454,8 +454,6 @@ class ProxyLogging:
|
||||
|
||||
for hook in PROXY_HOOKS:
|
||||
proxy_hook = get_proxy_hook(hook)
|
||||
import inspect
|
||||
|
||||
expected_args = inspect.getfullargspec(proxy_hook).args
|
||||
passed_in_args: Dict[str, Any] = {}
|
||||
if "internal_usage_cache" in expected_args:
|
||||
@@ -559,6 +557,10 @@ class ProxyLogging:
|
||||
"user_api_key_request_route": kwargs.get("user_api_key_request_route"),
|
||||
"mcp_tool_name": request_obj.tool_name, # Keep original for reference
|
||||
"mcp_arguments": request_obj.arguments, # Keep original for reference
|
||||
# Raw Bearer token from the original HTTP request — allows guardrails
|
||||
# (e.g. MCPJWTSigner) to independently verify the caller's identity
|
||||
# before re-signing an outbound token (FR-5 verify+re-sign).
|
||||
"incoming_bearer_token": kwargs.get("incoming_bearer_token"),
|
||||
}
|
||||
|
||||
return synthetic_data
|
||||
@@ -824,17 +826,30 @@ class ProxyLogging:
|
||||
) -> dict:
|
||||
"""
|
||||
Helper function to convert pre_call_hook response back to kwargs for MCP usage.
|
||||
|
||||
Supports:
|
||||
- modified_arguments: Override tool call arguments
|
||||
- extra_headers: Inject custom headers into the outbound MCP request
|
||||
"""
|
||||
if not response_data:
|
||||
return original_kwargs
|
||||
|
||||
# Apply any argument modifications from the hook response
|
||||
modified_kwargs = original_kwargs.copy()
|
||||
|
||||
# If the response contains modified arguments, apply them
|
||||
if response_data.get("modified_arguments"):
|
||||
modified_kwargs["arguments"] = response_data["modified_arguments"]
|
||||
|
||||
if response_data.get("extra_headers"):
|
||||
# Merge rather than replace — a prior guardrail in the chain may have
|
||||
# already injected headers (e.g. tracing IDs). Later guardrails win on
|
||||
# key collisions so that the most-specific guardrail (e.g. JWT signer)
|
||||
# takes precedence over earlier ones.
|
||||
existing = modified_kwargs.get("extra_headers") or {}
|
||||
modified_kwargs["extra_headers"] = {
|
||||
**existing,
|
||||
**response_data["extra_headers"],
|
||||
}
|
||||
|
||||
return modified_kwargs
|
||||
|
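The merge semantics described in the comments above (later guardrails win on key collisions, earlier headers are otherwise preserved) can be checked with a small standalone sketch. The function name `merge_hook_response` is hypothetical; the body follows the helper shown in this hunk.

```python
from typing import Any, Dict


def merge_hook_response(
    original_kwargs: Dict[str, Any], response_data: Dict[str, Any]
) -> Dict[str, Any]:
    """Fold a pre-call hook response back into the MCP call kwargs (sketch)."""
    if not response_data:
        return original_kwargs
    merged = original_kwargs.copy()
    if response_data.get("modified_arguments"):
        merged["arguments"] = response_data["modified_arguments"]
    if response_data.get("extra_headers"):
        existing = merged.get("extra_headers") or {}
        # New dict: the caller's header dict is left untouched.
        merged["extra_headers"] = {**existing, **response_data["extra_headers"]}
    return merged


kwargs = {"arguments": {"q": 1}, "extra_headers": {"x-trace": "t1", "Authorization": "Bearer old"}}
result = merge_hook_response(kwargs, {"extra_headers": {"Authorization": "Bearer new"}})
print(result["extra_headers"])
```

The tracing header survives while `Authorization` is replaced by the later guardrail's value.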
||||
async def process_pre_call_hook_response(self, response, data, call_type):
|
||||
|
||||
@@ -7,7 +7,9 @@ from litellm.types.videos.utils import encode_character_id_with_provider
|
||||
|
||||
def extract_model_from_target_model_names(target_model_names: Any) -> Optional[str]:
|
||||
if isinstance(target_model_names, str):
|
||||
target_model_names = [m.strip() for m in target_model_names.split(",") if m.strip()]
|
||||
target_model_names = [
|
||||
m.strip() for m in target_model_names.split(",") if m.strip()
|
||||
]
|
||||
elif not isinstance(target_model_names, list):
|
||||
return None
|
||||
return target_model_names[0] if target_model_names else None
|
||||
|
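A quick usage sketch for `extract_model_from_target_model_names`, with the function body reproduced from this hunk so the snippet runs on its own. The sample model names are made up.

```python
from typing import Any, Optional


def extract_model_from_target_model_names(target_model_names: Any) -> Optional[str]:
    if isinstance(target_model_names, str):
        # Comma-separated string: split, trim, and drop empty entries.
        target_model_names = [
            m.strip() for m in target_model_names.split(",") if m.strip()
        ]
    elif not isinstance(target_model_names, list):
        return None
    return target_model_names[0] if target_model_names else None


print(extract_model_from_target_model_names("model-a, model-b"))
print(extract_model_from_target_model_names(["model-c"]))
print(extract_model_from_target_model_names(None))
```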
||||
@@ -692,11 +692,11 @@ def responses(
|
||||
return run_async_function(aresponses_api_with_mcp, **mcp_call_kwargs)
|
||||
|
||||
# get provider config
|
||||
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
|
||||
ProviderConfigManager.get_provider_responses_api_config(
|
||||
model=model,
|
||||
provider=custom_llm_provider,
|
||||
)
|
||||
responses_api_provider_config: Optional[
|
||||
BaseResponsesAPIConfig
|
||||
] = ProviderConfigManager.get_provider_responses_api_config(
|
||||
model=model,
|
||||
provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
local_vars.update(kwargs)
|
||||
@ -908,11 +908,11 @@ def delete_responses(
|
||||
raise ValueError("custom_llm_provider is required but passed as None")
|
||||
|
||||
# get provider config
|
||||
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
|
||||
ProviderConfigManager.get_provider_responses_api_config(
|
||||
model=None,
|
||||
provider=custom_llm_provider,
|
||||
)
|
||||
responses_api_provider_config: Optional[
|
||||
BaseResponsesAPIConfig
|
||||
] = ProviderConfigManager.get_provider_responses_api_config(
|
||||
model=None,
|
||||
provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
if responses_api_provider_config is None:
|
||||
@ -1089,11 +1089,11 @@ def get_responses(
|
||||
raise ValueError("custom_llm_provider is required but passed as None")
|
||||
|
||||
# get provider config
|
||||
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
|
||||
ProviderConfigManager.get_provider_responses_api_config(
|
||||
model=None,
|
||||
provider=custom_llm_provider,
|
||||
)
|
||||
responses_api_provider_config: Optional[
|
||||
BaseResponsesAPIConfig
|
||||
] = ProviderConfigManager.get_provider_responses_api_config(
|
||||
model=None,
|
||||
provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
if responses_api_provider_config is None:
|
||||
@ -1247,11 +1247,11 @@ def list_input_items(
|
||||
if custom_llm_provider is None:
|
||||
raise ValueError("custom_llm_provider is required but passed as None")
|
||||
|
||||
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
|
||||
ProviderConfigManager.get_provider_responses_api_config(
|
||||
model=None,
|
||||
provider=custom_llm_provider,
|
||||
)
|
||||
responses_api_provider_config: Optional[
|
||||
BaseResponsesAPIConfig
|
||||
] = ProviderConfigManager.get_provider_responses_api_config(
|
||||
model=None,
|
||||
provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
if responses_api_provider_config is None:
|
||||
@ -1406,11 +1406,11 @@ def cancel_responses(
|
||||
raise ValueError("custom_llm_provider is required but passed as None")
|
||||
|
||||
# get provider config
|
||||
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
|
||||
ProviderConfigManager.get_provider_responses_api_config(
|
||||
model=None,
|
||||
provider=custom_llm_provider,
|
||||
)
|
||||
responses_api_provider_config: Optional[
|
||||
BaseResponsesAPIConfig
|
||||
] = ProviderConfigManager.get_provider_responses_api_config(
|
||||
model=None,
|
||||
provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
if responses_api_provider_config is None:
|
||||
@ -1594,11 +1594,11 @@ def compact_responses(
|
||||
raise ValueError("custom_llm_provider is required but passed as None")
|
||||
|
||||
# get provider config
|
||||
responses_api_provider_config: Optional[BaseResponsesAPIConfig] = (
|
||||
ProviderConfigManager.get_provider_responses_api_config(
|
||||
model=model,
|
||||
provider=custom_llm_provider,
|
||||
)
|
||||
responses_api_provider_config: Optional[
|
||||
BaseResponsesAPIConfig
|
||||
] = ProviderConfigManager.get_provider_responses_api_config(
|
||||
model=model,
|
||||
provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
if responses_api_provider_config is None:
|
||||
|
||||
@@ -8611,6 +8611,7 @@ class Router:
             _model_info = deployment.get("model_info", {})

             # see if we have the info for this model
+            _deployment_model = None  # per-deployment model name (avoids overwriting the outer `model` group name)
             try:
                 base_model = _model_info.get("base_model", None)
                 if base_model is None:
@@ -8618,7 +8619,7 @@ class Router:
                 model_info = self.get_router_model_info(
                     deployment=deployment, received_model_name=model
                 )
-                model = base_model or _litellm_params.get("model", None)
+                _deployment_model = base_model or _litellm_params.get("model", None)

                 if (
                     isinstance(model_info, dict)
@@ -8632,7 +8633,9 @@ class Router:
                         _context_window_error = True
                         _potential_error_str += (
                             "Model={}, Max Input Tokens={}, Got={}".format(
-                                model, model_info["max_input_tokens"], input_tokens
+                                _deployment_model,
+                                model_info["max_input_tokens"],
+                                input_tokens,
                             )
                         )
                         continue
@@ -8688,13 +8691,21 @@ class Router:

             ## INVALID PARAMS ## -> catch 'gpt-3.5-turbo-16k' not supporting 'response_format' param
             if request_kwargs is not None and litellm.drop_params is False:
-                # get supported params
-                model, custom_llm_provider, _, _ = litellm.get_llm_provider(
-                    model=model, litellm_params=LiteLLM_Params(**_litellm_params)
+                # get supported params — use per-deployment model to avoid overwriting the outer model group name
+                _dep_model_for_params = _deployment_model or model
+                (
+                    _dep_model_for_params,
+                    custom_llm_provider,
+                    _,
+                    _,
+                ) = litellm.get_llm_provider(
+                    model=_dep_model_for_params,
+                    litellm_params=LiteLLM_Params(**_litellm_params),
                 )

                 supported_openai_params = litellm.get_supported_openai_params(
-                    model=model, custom_llm_provider=custom_llm_provider
+                    model=_dep_model_for_params,
+                    custom_llm_provider=custom_llm_provider,
                 )

                 if supported_openai_params is None:
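The Router hunks above stop reassigning the outer `model` variable inside the per-deployment loop. The effect of that kind of reassignment can be reproduced with a small standalone sketch; the function names and deployment dictionaries below are hypothetical and stand in for the much larger pre-call check.

```python
from typing import Dict, List


def names_seen_buggy(model: str, deployments: List[Dict[str, str]]) -> List[str]:
    """Reassigns `model` in the loop, so later iterations lose the group name."""
    seen = []
    for deployment in deployments:
        seen.append(model)  # name used for this iteration's group-level lookup
        model = deployment["model"]  # bug: clobbers the outer group name
    return seen


def names_seen_fixed(model: str, deployments: List[Dict[str, str]]) -> List[str]:
    """Keeps the group name intact by using a separate per-deployment variable."""
    seen = []
    for deployment in deployments:
        seen.append(model)
        _deployment_model = deployment["model"]  # per-deployment name only
        assert _deployment_model  # would feed error messages / param checks
    return seen


deployments = [{"model": "azure/gpt-4o-eu"}, {"model": "openai/gpt-4o"}]
print(names_seen_buggy("gpt-4o", deployments))  # ['gpt-4o', 'azure/gpt-4o-eu']
print(names_seen_fixed("gpt-4o", deployments))  # ['gpt-4o', 'gpt-4o']
```

In the buggy variant the second iteration sees the first deployment's model name instead of the model group name, which is the kind of leak the `_deployment_model` variable is introduced to avoid.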
668
litellm/setup_wizard.py
Normal file
@@ -0,0 +1,668 @@
+# ruff: noqa: T201
+# flake8: noqa: T201
+"""
+LiteLLM Interactive Setup Wizard
+
+Guides users through selecting LLM providers, entering API keys,
+and generating a proxy config file — mirroring the Claude Code onboarding UX.
+"""
+
+import importlib.metadata
+import os
+import re
+import secrets
+import sys
+import sysconfig
+from pathlib import Path
+from typing import Dict, List, Optional, Set
+
+# termios / tty are Unix-only; fall back gracefully on Windows
+try:
+    import termios
+    import tty
+
+    _HAS_RAW_TERMINAL: bool = True
+except ImportError:
+    termios = None  # type: ignore[assignment]
+    tty = None  # type: ignore[assignment]
+    _HAS_RAW_TERMINAL = False
+
+from litellm.utils import check_valid_key
+
+# ---------------------------------------------------------------------------
+# Provider definitions
+# ---------------------------------------------------------------------------
+# Each entry describes one provider card shown in the wizard.
+# `env_key` — primary env var name (None = no key needed, e.g. Ollama)
+# `test_model` — model passed to check_valid_key for credential validation
+# (None = skip validation, e.g. Azure needs a deployment name)
+# `models` — default models written into the generated config
+# ---------------------------------------------------------------------------
+
+PROVIDERS: List[Dict] = [
+    {
+        "id": "openai",
+        "name": "OpenAI",
+        "description": "GPT-4o, GPT-4o-mini, o3-mini",
+        "env_key": "OPENAI_API_KEY",
+        "key_hint": "sk-...",
+        "test_model": "gpt-4o-mini",
+        "models": ["gpt-4o", "gpt-4o-mini"],
+    },
+    {
+        "id": "anthropic",
+        "name": "Anthropic",
+        "description": "Claude Opus 4.6, Sonnet 4.6, Haiku 4.5",
+        "env_key": "ANTHROPIC_API_KEY",
+        "key_hint": "sk-ant-...",
+        "test_model": "claude-haiku-4-5-20251001",
+        "models": ["claude-opus-4-6", "claude-sonnet-4-6", "claude-haiku-4-5-20251001"],
+    },
+    {
+        "id": "gemini",
+        "name": "Google Gemini",
+        "description": "Gemini 2.0 Flash, Gemini 2.5 Pro",
+        "env_key": "GEMINI_API_KEY",
+        "key_hint": "AIza...",
+        "test_model": "gemini/gemini-2.0-flash",
+        "models": ["gemini/gemini-2.0-flash", "gemini/gemini-2.5-pro"],
+    },
+    {
+        "id": "azure",
+        "name": "Azure OpenAI",
+        "description": "GPT-4o via Azure",
+        "env_key": "AZURE_API_KEY",
+        "key_hint": "your-azure-key",
+        "test_model": None,  # needs deployment name — skip validation
+        "models": [],
+        "needs_api_base": True,
+        "api_base_hint": "https://<resource>.openai.azure.com/",
+        "api_version": "2024-07-01-preview",
+    },
+    {
+        "id": "bedrock",
+        "name": "AWS Bedrock",
+        "description": "Claude 3.5, Llama 3 via AWS",
+        "env_key": "AWS_ACCESS_KEY_ID",
+        "key_hint": "AKIA...",
+        "test_model": None,  # multi-key auth — skip validation
+        "models": ["bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0"],
+        "extra_keys": ["AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME"],
+        "extra_hints": ["your-secret-key", "us-east-1"],
+    },
+    {
+        "id": "ollama",
+        "name": "Ollama",
+        "description": "Local models (llama3.2, mistral, etc.)",
+        "env_key": None,
+        "key_hint": None,
+        "test_model": None,  # local — no remote validation
+        "models": ["ollama/llama3.2", "ollama/mistral"],
+        "api_base": "http://localhost:11434",
+    },
+]
+
+
+# ---------------------------------------------------------------------------
+# ANSI colour helpers
+# ---------------------------------------------------------------------------
+
+_ANSI_RE = re.compile(r"\033\[[^m]*m")
+
+_ORANGE = "\033[38;2;215;119;87m"
+_DIM = "\033[2m"
+_BOLD = "\033[1m"
+_GREEN = "\033[38;2;78;186;101m"
+_BLUE = "\033[38;2;177;185;249m"
+_GREY = "\033[38;2;153;153;153m"
+_RESET = "\033[0m"
+_CHECK = "✔"
+_CROSS = "✘"
+
+_CURSOR_HIDE = "\033[?25l"
+_CURSOR_SHOW = "\033[?25h"
+_MOVE_UP = "\033[{}A"
+
+
+def _supports_color() -> bool:
+    return sys.stdout.isatty() and os.environ.get("NO_COLOR") is None
+
+
+def _c(code: str, text: str) -> str:
+    return f"{code}{text}{_RESET}" if _supports_color() else text
+
+
+def orange(t: str) -> str:
+    return _c(_ORANGE, t)
+
+
+def bold(t: str) -> str:
+    return _c(_BOLD, t)
+
+
+def green(t: str) -> str:
+    return _c(_GREEN, t)
+
+
+def blue(t: str) -> str:
+    return _c(_BLUE, t)
+
+
+def grey(t: str) -> str:
+    return _c(_GREY, t)
+
+
+def dim(t: str) -> str:
+    return _c(_DIM, t)
+
+
+def _divider() -> str:
+    """Return a styled divider line (evaluated at call-time, not import-time)."""
+    return dim(" " + "╌" * 74)
+
+
+def _styled_input(prompt: str) -> str:
+    """
+    Like input() but wraps ANSI sequences in readline ignore markers
+    (\\001...\\002) so readline correctly tracks the cursor column.
+    In non-TTY contexts, strips ANSI entirely so no escape codes appear.
+    """
+    if sys.stdout.isatty():
+        rl_prompt = _ANSI_RE.sub(lambda m: f"\001{m.group()}\002", prompt)
+    else:
+        rl_prompt = _ANSI_RE.sub("", prompt)
+    return input(rl_prompt).strip()
+
+
+def _yaml_escape(value: str) -> str:
+    """Escape a string for safe embedding in a double-quoted YAML scalar."""
+    return (
+        value.replace("\\", "\\\\")
+        .replace('"', '\\"')
+        .replace("\n", "\\n")
+        .replace("\r", "\\r")
+        .replace("\t", "\\t")
+    )
+
+
+# ---------------------------------------------------------------------------
+# Layout constants
+# ---------------------------------------------------------------------------
+
+LITELLM_ASCII = r"""
+██╗ ██╗████████╗███████╗██╗ ██╗ ███╗ ███╗
+██║ ██║╚══██╔══╝██╔════╝██║ ██║ ████╗ ████║
+██║ ██║ ██║ █████╗ ██║ ██║ ██╔████╔██║
+██║ ██║ ██║ ██╔══╝ ██║ ██║ ██║╚██╔╝██║
+███████╗██║ ██║ ███████╗███████╗███████╗██║ ╚═╝ ██║
+╚══════╝╚═╝ ╚═╝ ╚══════╝╚══════╝╚══════╝╚═╝ ╚═╝
+"""
+
+
+# ---------------------------------------------------------------------------
+# Setup wizard
+# ---------------------------------------------------------------------------
+
+
+class SetupWizard:
+    """
+    Interactive onboarding wizard: provider selection → API keys → config file.
+
+    All methods are static — the class is purely a namespace with clear
+    single-responsibility sections. Entry point: SetupWizard.run().
+    """
+
+    # ── entry point ─────────────────────────────────────────────────────────
+
+    @staticmethod
+    def run() -> None:
+        try:
+            SetupWizard._wizard()
+        except (KeyboardInterrupt, EOFError):
+            print(f"\n\n {grey('Setup cancelled.')}\n")
+
+    # ── wizard steps ────────────────────────────────────────────────────────
+
+    @staticmethod
+    def _wizard() -> None:
+        SetupWizard._print_welcome()
+        print(f" {bold('Lets get started.')}")
+        print()
+
+        providers = SetupWizard._select_providers()
+        env_vars = SetupWizard._collect_keys(providers)
+        port, master_key = SetupWizard._proxy_settings()
+
+        config_path = Path(os.getcwd()) / "litellm_config.yaml"
+        try:
+            config_path.write_text(
+                SetupWizard._build_config(providers, env_vars, master_key)
+            )
+        except OSError as exc:
+            print(f"\n {bold(_CROSS + ' Could not write config:')} {exc}")
+            print(" Try running from a directory you have write access to.\n")
+            return
+
+        SetupWizard._print_success(config_path, port, master_key)
+        SetupWizard._offer_start(config_path, port, master_key)
+
+    # ── welcome ─────────────────────────────────────────────────────────────
+
+    @staticmethod
+    def _print_welcome() -> None:
+        try:
+            version = importlib.metadata.version("litellm")
+        except Exception:
+            version = "unknown"
+        print()
+        print(orange(LITELLM_ASCII.rstrip("\n")))
+        print(f" {orange('Welcome')} to {bold('LiteLLM')} {grey('v' + version)}")
+        print()
+        print(_divider())
+        print()
+
+    # ── provider selector ───────────────────────────────────────────────────
+
+    @staticmethod
+    def _select_providers() -> List[Dict]:
+        """Arrow-key multi-select. Falls back to number input if /dev/tty unavailable."""
+        if not _HAS_RAW_TERMINAL:
+            return SetupWizard._select_fallback()
+        try:
+            return SetupWizard._select_interactive()
+        except OSError:
+            return SetupWizard._select_fallback()
+
+    @staticmethod
+    def _read_key() -> str:
+        """Read one keypress from /dev/tty in raw mode."""
+        assert (
+            termios is not None and tty is not None
+        )  # only called when _HAS_RAW_TERMINAL
+        with open("/dev/tty", "rb") as tty_fh:
+            fd = tty_fh.fileno()
+            old = termios.tcgetattr(fd)
+            try:
+                tty.setraw(fd)
+                ch = tty_fh.read(1)
+                if ch == b"\x1b":
+                    ch2 = tty_fh.read(1)
+                    if ch2 == b"[":
+                        ch3 = tty_fh.read(1)
+                        return "\x1b[" + ch3.decode("utf-8", errors="replace")
+                    return "\x1b" + ch2.decode("utf-8", errors="replace")
+                return ch.decode("utf-8", errors="replace")
+            finally:
+                termios.tcsetattr(fd, termios.TCSADRAIN, old)
+
+    @staticmethod
+    def _render_selector(cursor: int, selected: Set[int], first_render: bool) -> int:
+        """Draw or redraw the provider list. Returns the number of lines printed."""
+        lines = [
+            f"\n {bold('Add your first model')}\n",
+            grey(" ↑↓ to navigate · Space to select · Enter to confirm") + "\n",
+            "\n",
+        ]
+        for i, p in enumerate(PROVIDERS):
+            arrow = blue("❯") if i == cursor else " "
+            bullet = green("◉") if i in selected else grey("○")
+            name_str = bold(p["name"]) if i == cursor else p["name"]
+            lines.append(f" {arrow} {bullet} {name_str} {grey(p['description'])}\n")
+        lines.append("\n")
+
+        content = "".join(lines)
+        if not first_render and _supports_color():
+            sys.stdout.write(_MOVE_UP.format(content.count("\n")))
+        sys.stdout.write(content)
+        sys.stdout.flush()
+        return content.count("\n")
+
+    @staticmethod
+    def _select_interactive() -> List[Dict]:
+        cursor = 0
+        selected: set[int] = set()
+
+        if _supports_color():
+            sys.stdout.write(_CURSOR_HIDE)
+            sys.stdout.flush()
+        try:
+            SetupWizard._render_selector(cursor, selected, first_render=True)
+            while True:
+                key = SetupWizard._read_key()
+                dirty = False
+                if key == "\x1b[A":
+                    cursor = (cursor - 1) % len(PROVIDERS)
+                    dirty = True
+                elif key == "\x1b[B":
+                    cursor = (cursor + 1) % len(PROVIDERS)
+                    dirty = True
+                elif key == " ":
+                    selected.symmetric_difference_update({cursor})
+                    dirty = True
+                elif key in ("\r", "\n"):
+                    if not selected:
+                        selected.add(cursor)
+                    break
+                elif key in ("\x03", "\x04"):
+                    raise KeyboardInterrupt
+                if dirty:
+                    SetupWizard._render_selector(cursor, selected, first_render=False)
+        finally:
+            if _supports_color():
+                sys.stdout.write(_CURSOR_SHOW)
+                sys.stdout.flush()
+
+        return [PROVIDERS[i] for i in sorted(selected)]
+
+    @staticmethod
+    def _select_fallback() -> List[Dict]:
+        """Number-based fallback when raw terminal input is unavailable."""
+        print()
+        print(f" {bold('Add your first model')}")
+        print(
+            grey(
+                " Enter numbers separated by commas (e.g. 1,2). Press Enter to confirm."
+            )
+        )
+        print()
+        for i, p in enumerate(PROVIDERS, 1):
+            print(f" {grey(str(i) + '.')} {bold(p['name'])} {grey(p['description'])}")
+        print()
+
+        while True:
+            raw = _styled_input(f" {blue('❯')} Provider(s): ")
+            if not raw:
+                print(grey(" Please select at least one provider."))
+                continue
+            try:
+                nums = [
+                    int(x.strip())
+                    for x in raw.replace(" ", ",").split(",")
+                    if x.strip()
+                ]
+                valid = sorted({n for n in nums if 1 <= n <= len(PROVIDERS)})
+                if not valid:
+                    print(grey(f" Enter numbers between 1 and {len(PROVIDERS)}."))
+                    continue
+                return [PROVIDERS[i - 1] for i in valid]
+            except ValueError:
+                print(grey(" Enter numbers separated by commas, e.g. 1,3"))
+
+    # ── key collection ───────────────────────────────────────────────────────
+
+    @staticmethod
+    def _collect_keys(providers: List[Dict]) -> Dict[str, str]:
+        env_vars: Dict[str, str] = {}
+        print()
+        print(_divider())
+        print()
+        print(f" {bold('Enter your API keys')}")
+        print(grey(" Keys are stored only in the generated config file."))
+        print(
+            grey(
+                " Tip: add litellm_config.yaml to .gitignore to avoid committing secrets."
+            )
+        )
+        print()
+
+        for p in providers:
+            if p["env_key"] is None:
+                print(
+                    f" {green(p['name'])}: {grey('no key needed (uses local Ollama)')}"
+                )
+                continue
+
+            key = SetupWizard._prompt_key(p)
+            if not key:
+                continue
+
+            for extra_key, extra_hint in zip(
+                p.get("extra_keys", []), p.get("extra_hints", [])
+            ):
+                val = _styled_input(f" {blue('❯')} {extra_key} {grey(extra_hint)}: ")
+                if val:
+                    env_vars[extra_key] = val
+
+            if p.get("needs_api_base"):
+                api_base = _styled_input(
+                    f" {blue('❯')} Azure endpoint URL {grey(p.get('api_base_hint', ''))}: "
+                )
+                if api_base:
+                    env_vars[f"_LITELLM_AZURE_API_BASE_{p['id'].upper()}"] = api_base
+                deployment = _styled_input(
+                    f" {blue('❯')} Azure deployment name {grey('(e.g. my-gpt4o)')}: "
+                )
+                if deployment:
+                    env_vars[
+                        f"_LITELLM_AZURE_DEPLOYMENT_{p['id'].upper()}"
+                    ] = deployment
+
+            # Store the key returned by validation — may be a re-entered replacement
+            env_vars[p["env_key"]] = SetupWizard._validate_and_report(p, key)
+
+        return env_vars
+
+    @staticmethod
+    def _prompt_key(provider: Dict) -> str:
+        """Prompt for a provider's API key, with skip option. Returns the key or ''."""
+        hint = grey(provider.get("key_hint", ""))
+        while True:
+            key = _styled_input(
+                f" {blue('❯')} {bold(provider['name'])} API key {hint}: "
+            )
+            if key:
+                return key
+            print(grey(" Key is required. Leave blank to skip this provider."))
+            if _styled_input(grey(" Skip? (y/N): ")).lower() == "y":
+                return ""
+
+    @staticmethod
+    def _validate_and_report(provider: Dict, api_key: str) -> str:
+        """
+        Validate credentials using litellm.utils.check_valid_key and print result.
+        Offers a re-entry loop on failure. Returns the final (possibly re-entered) key.
+        """
+        test_model: Optional[str] = provider.get("test_model")
+        if not test_model:
+            return api_key  # Azure / Bedrock / Ollama — skip validation
+
+        while True:
+            print(
+                f" {grey('Testing connection to ' + provider['name'] + '...')}",
+                flush=True,
+            )
+            valid = check_valid_key(model=test_model, api_key=api_key)
+            if valid:
+                print(
+                    f" {green(_CHECK)} {bold(provider['name'])} connected successfully"
+                )
+                return api_key
+
+            print(f" {_CROSS} {bold(provider['name'])} {grey('— invalid API key')}")
+            if (
+                _styled_input(f" {blue('❯')} Re-enter key? {grey('(y/N)')}: ").lower()
+                != "y"
+            ):
+                return api_key
+
+            hint = grey(provider.get("key_hint", ""))
+            new_key = _styled_input(
+                f" {blue('❯')} {bold(provider['name'])} API key {hint}: "
+            )
+            if not new_key:
+                return api_key
+            api_key = new_key
+
+    # ── proxy settings ───────────────────────────────────────────────────────
+
+    @staticmethod
+    def _proxy_settings() -> "tuple[int, str]":
+        print()
+        print(_divider())
+        print()
+        print(f" {bold('Proxy settings')}")
+        print()
+        port = 4000
+        while True:
+            port_raw = _styled_input(f" {blue('❯')} Port {grey('[4000]')}: ")
+            if not port_raw:
+                break
+            if port_raw.isdigit() and 1 <= int(port_raw) <= 65535:
+                port = int(port_raw)
+                break
+            print(grey(" Enter a valid port number (1–65535)."))
+        key_raw = _styled_input(f" {blue('❯')} Master key {grey('[auto-generate]')}: ")
+        master_key = key_raw if key_raw else f"sk-{secrets.token_urlsafe(32)}"
+        return port, master_key
+
+    # ── config generation ────────────────────────────────────────────────────
+
+    @staticmethod
+    def _build_config(
+        providers: List[Dict],
+        env_vars: Dict[str, str],
+        master_key: str,
+    ) -> str:
+        env_copy = dict(env_vars)  # work on a copy — do not mutate caller's dict
+        lines = ["model_list:"]
+        for p in providers:
+            # Only emit models for providers that actually have credentials
+            has_creds = p["env_key"] is None or p["env_key"] in env_copy
+            if not has_creds:
+                continue
+
+            if p["id"] == "azure":
+                deployment = env_copy.pop(
+                    f"_LITELLM_AZURE_DEPLOYMENT_{p['id'].upper()}", ""
+                )
+                if not deployment:
+                    continue  # skip Azure entirely if no deployment name was provided
+                models = [f"azure/{deployment}"]
+            else:
+                models = p["models"]
+
+            for model in models:
+                raw_display = model.split("/")[-1] if "/" in model else model
+                # Qualify azure display names to avoid collision with OpenAI model names
+                display = f"azure-{raw_display}" if p["id"] == "azure" else raw_display
+                lines += [
+                    f"  - model_name: {display}",
+                    "    litellm_params:",
+                    f"      model: {model}",
+                ]
+                if p["env_key"] and p["env_key"] in env_copy:
+                    lines.append(f"      api_key: os.environ/{p['env_key']}")
+                if p.get("api_base"):
+                    lines.append(
+                        f'      api_base: "{_yaml_escape(str(p["api_base"]))}"'
+                    )
+                elif p.get("needs_api_base"):
+                    azure_base_key = f"_LITELLM_AZURE_API_BASE_{p['id'].upper()}"
+                    if azure_base_key in env_copy:
+                        lines.append(
+                            f'      api_base: "{_yaml_escape(env_copy.pop(azure_base_key))}"'
+                        )
+                if p.get("api_version"):
+                    lines.append(f"      api_version: {p['api_version']}")
+
+        lines += [
+            "",
+            "general_settings:",
+            f'  master_key: "{_yaml_escape(master_key)}"',
+            "",
+        ]
+
+        real_vars = {k: v for k, v in env_copy.items() if not k.startswith("_LITELLM_")}
+        if real_vars:
+            lines.append("environment_variables:")
+            for k, v in real_vars.items():
+                lines.append(f'  {k}: "{_yaml_escape(v)}"')
+            lines.append("")
+
+        return "\n".join(lines)
+
+    # ── success + launch ─────────────────────────────────────────────────────
+
+    @staticmethod
+    def _print_success(config_path: Path, port: int, master_key: str) -> None:
+        print()
+        print(_divider())
+        print()
+        print(f" {green(_CHECK + ' Config saved')} → {bold(str(config_path))}")
+        print()
+        print(f" {bold('To start your proxy:')}")
+        print()
+        print(f" {grey('$')} litellm --config {config_path} --port {port}")
+        print()
+        print(f" {bold('Then set your client:')}")
+        print()
+        print(f" export OPENAI_BASE_URL=http://localhost:{port}")
+        print(f" export OPENAI_API_KEY={master_key}")
+        print()
+        print(_divider())
+        print()
+
+    @staticmethod
+    def _offer_start(config_path: Path, port: int, master_key: str) -> None:
+        start = _styled_input(
+            f" {blue('❯')} Start the proxy now? {grey('(Y/n)')}: "
+        ).lower()
+        if start not in ("", "y", "yes"):
+            print()
+            print(
+                f" Run {bold(f'litellm --config {config_path}')} whenever you're ready."
+            )
+            print()
+            print(
+                grey(f" Quick test once running: curl http://localhost:{port}/health")
+            )
+            print()
+            return
+
+        print()
+        print(_divider())
+        print()
+        print(f" {bold('Proxy is starting on')} http://localhost:{port}")
+        print()
+        print(grey(" Your proxy is OpenAI-compatible. Point any OpenAI SDK at it:"))
+        print()
+        print(f" export OPENAI_BASE_URL=http://localhost:{port}")
+        print(f" export OPENAI_API_KEY={master_key}")
+        print()
+        print(grey(" Quick test (in another terminal):"))
+        print()
+        print(f" curl http://localhost:{port}/health")
+        print()
+        print(grey(" Dashboard:"))
+        print()
+        print(f" http://localhost:{port}/ui {grey('(login with your master key)')}")
+        print()
+        print(_divider())
+        print()
+        print(f" {green(_CHECK)} Starting… {grey('(Ctrl+C to stop)')}")
+        print()
+
+        scripts_dir = sysconfig.get_path("scripts")
+        litellm_bin = os.path.join(scripts_dir or "", "litellm")
+        try:
+            os.execlp(
+                litellm_bin,
+                litellm_bin,
+                "--config",
+                str(config_path),
+                "--port",
+                str(port),
+            )  # noqa: S606
+        except OSError as exc:
+            print(f"\n {bold(_CROSS + ' Could not start proxy:')} {exc}")
+            print(f" Run manually: litellm --config {config_path} --port {port}\n")
+
+
+# ---------------------------------------------------------------------------
+# Public entrypoint
+# ---------------------------------------------------------------------------
+
+
+def run_setup_wizard() -> None:
+    """Run the interactive setup wizard. Called by `litellm --setup`."""
+    SetupWizard.run()
@@ -79,6 +79,7 @@ class SupportedGuardrailIntegrations(Enum):
     SEMANTIC_GUARD = "semantic_guard"
     MCP_END_USER_PERMISSION = "mcp_end_user_permission"
     BLOCK_CODE_EXECUTION = "block_code_execution"
+    MCP_JWT_SIGNER = "mcp_jwt_signer"


 class Role(Enum):
@@ -39,7 +39,8 @@ class VantageExportRequest(BaseModel):
     """Request model for Vantage export operations (actual export, no default limit)"""

     limit: Optional[int] = Field(
-        None, description="Optional limit on number of records to export (default: no limit)"
+        None,
+        description="Optional limit on number of records to export (default: no limit)",
     )
     start_time_utc: Optional[datetime] = Field(
         None, description="Start time for data export in UTC"
@@ -195,7 +195,9 @@ def decode_character_id_with_provider(encoded_character_id: str) -> DecodedChara
             character_id=decoded_character_id,
         )
     except Exception as e:
-        verbose_logger.debug(f"Error decoding character_id '{encoded_character_id}': {e}")
+        verbose_logger.debug(
+            f"Error decoding character_id '{encoded_character_id}': {e}"
+        )
         return DecodedCharacterId(
             custom_llm_provider=None,
             model_id=None,
@@ -1186,13 +1186,17 @@ def video_create_character(

         litellm_params = GenericLiteLLMParams(**kwargs)

-        provider_config: Optional[BaseVideoConfig] = ProviderConfigManager.get_provider_video_config(
+        provider_config: Optional[
+            BaseVideoConfig
+        ] = ProviderConfigManager.get_provider_video_config(
             model=None,
             provider=litellm.LlmProviders(custom_llm_provider),
         )

         if provider_config is None:
-            raise ValueError(f"video create character is not supported for {custom_llm_provider}")
+            raise ValueError(
+                f"video create character is not supported for {custom_llm_provider}"
+            )

         local_vars.update(kwargs)
         request_params: Dict = {"name": name}
@@ -1311,13 +1315,17 @@ def video_get_character(

         litellm_params = GenericLiteLLMParams(**kwargs)

-        provider_config: Optional[BaseVideoConfig] = ProviderConfigManager.get_provider_video_config(
+        provider_config: Optional[
+            BaseVideoConfig
+        ] = ProviderConfigManager.get_provider_video_config(
             model=None,
             provider=litellm.LlmProviders(custom_llm_provider),
         )

         if provider_config is None:
-            raise ValueError(f"video get character is not supported for {custom_llm_provider}")
+            raise ValueError(
+                f"video get character is not supported for {custom_llm_provider}"
+            )

         local_vars.update(kwargs)
         request_params: Dict = {"character_id": character_id}
@@ -1439,7 +1447,9 @@ def video_edit(

         litellm_params = GenericLiteLLMParams(**kwargs)

-        provider_config: Optional[BaseVideoConfig] = ProviderConfigManager.get_provider_video_config(
+        provider_config: Optional[
+            BaseVideoConfig
+        ] = ProviderConfigManager.get_provider_video_config(
             model=None,
             provider=litellm.LlmProviders(custom_llm_provider),
         )
@@ -1572,16 +1582,24 @@ def video_extension(

         litellm_params = GenericLiteLLMParams(**kwargs)

-        provider_config: Optional[BaseVideoConfig] = ProviderConfigManager.get_provider_video_config(
+        provider_config: Optional[
+            BaseVideoConfig
+        ] = ProviderConfigManager.get_provider_video_config(
             model=None,
             provider=litellm.LlmProviders(custom_llm_provider),
         )

         if provider_config is None:
-            raise ValueError(f"video extension is not supported for {custom_llm_provider}")
+            raise ValueError(
+                f"video extension is not supported for {custom_llm_provider}"
+            )

         local_vars.update(kwargs)
-        request_params: Dict = {"video_id": video_id, "prompt": prompt, "seconds": seconds}
+        request_params: Dict = {
+            "video_id": video_id,
+            "prompt": prompt,
+            "seconds": seconds,
+        }

         litellm_logging_obj.update_environment_variables(
             model="",
@ -32354,6 +32354,53 @@
|
||||
"supports_vision": true,
|
||||
"supports_web_search": true
|
||||
},
|
||||
"xai/grok-4.20-multi-agent-beta-0309": {
|
||||
"cache_read_input_token_cost": 2e-07,
|
||||
"input_cost_per_token": 2e-06,
|
||||
"litellm_provider": "xai",
|
||||
"max_input_tokens": 2000000,
|
||||
"max_output_tokens": 2000000,
|
||||
"max_tokens": 2000000,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 6e-06,
|
||||
"source": "https://docs.x.ai/docs/models",
|
||||
"supports_function_calling": true,
|
||||
"supports_reasoning": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_vision": true,
|
||||
"supports_web_search": true
|
||||
},
|
||||
"xai/grok-4.20-beta-0309-reasoning": {
|
||||
"cache_read_input_token_cost": 2e-07,
|
||||
"input_cost_per_token": 2e-06,
|
||||
"litellm_provider": "xai",
|
||||
"max_input_tokens": 2000000,
|
||||
"max_output_tokens": 2000000,
|
||||
"max_tokens": 2000000,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 6e-06,
|
||||
"source": "https://docs.x.ai/docs/models",
|
||||
"supports_function_calling": true,
|
||||
"supports_reasoning": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_vision": true,
|
||||
"supports_web_search": true
|
||||
},
|
||||
"xai/grok-4.20-beta-0309-non-reasoning": {
|
||||
"cache_read_input_token_cost": 2e-07,
|
||||
"input_cost_per_token": 2e-06,
|
||||
"litellm_provider": "xai",
|
||||
"max_input_tokens": 2000000,
|
||||
"max_output_tokens": 2000000,
|
||||
"max_tokens": 2000000,
|
||||
"mode": "chat",
|
||||
"output_cost_per_token": 6e-06,
|
||||
"source": "https://docs.x.ai/docs/models",
|
||||
"supports_function_calling": true,
|
||||
"supports_tool_choice": true,
|
||||
"supports_vision": true,
|
||||
"supports_web_search": true
|
||||
},
|
||||
"xai/grok-beta": {
|
||||
"input_cost_per_token": 5e-06,
|
||||
"litellm_provider": "xai",
|
||||
|
||||
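The pricing fields added in this hunk are per-token rates, so the cost of a request is a linear combination of its token counts. A minimal sketch of that arithmetic follows. It assumes cached prompt tokens are billed at `cache_read_input_token_cost` in place of `input_cost_per_token`. The helper name `estimate_cost` and its signature are hypothetical and not part of LiteLLM.

```python
# Rates copied from the grok-4.20 beta entries in the hunk above (USD per token).
GROK_4_20_RATES = {
    "input_cost_per_token": 2e-06,
    "output_cost_per_token": 6e-06,
    "cache_read_input_token_cost": 2e-07,
}


def estimate_cost(prompt_tokens, completion_tokens, cached_tokens=0, rates=GROK_4_20_RATES):
    """Hypothetical helper: estimate the USD cost of one request.

    Assumes cached_tokens is a subset of prompt_tokens and is billed at the
    cache-read rate rather than the regular input rate.
    """
    if not 0 <= cached_tokens <= prompt_tokens:
        raise ValueError("cached_tokens must be between 0 and prompt_tokens")
    uncached = prompt_tokens - cached_tokens
    return (
        uncached * rates["input_cost_per_token"]
        + cached_tokens * rates["cache_read_input_token_cost"]
        + completion_tokens * rates["output_cost_per_token"]
    )
```

At these rates one million uncached input tokens come to $2.00 and one million output tokens to $6.00, which matches the figures quoted in the commit message.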
143
scripts/install.sh
Executable file
@ -0,0 +1,143 @@
|
||||
#!/usr/bin/env bash
|
||||
# LiteLLM Installer
|
||||
# Usage: curl -fsSL https://raw.githubusercontent.com/BerriAI/litellm/main/scripts/install.sh | sh
|
||||
#
|
||||
# NOTE: set -e without pipefail for POSIX sh compatibility (dash on Ubuntu/Debian
|
||||
# ignores the shebang when invoked as `sh` and does not support `pipefail`).
|
||||
set -eu
|
||||
|
||||
MIN_PYTHON_MAJOR=3
|
||||
MIN_PYTHON_MINOR=9
|
||||
|
||||
# NOTE: before merging, this must stay as "litellm[proxy]" to install from PyPI.
|
||||
LITELLM_PACKAGE="litellm[proxy]"
|
||||
|
||||
# ── colours ────────────────────────────────────────────────────────────────
|
||||
if [ -t 1 ]; then
|
||||
BOLD='\033[1m'
|
||||
GREEN='\033[38;2;78;186;101m'
|
||||
GREY='\033[38;2;153;153;153m'
|
||||
RESET='\033[0m'
|
||||
else
|
||||
BOLD='' GREEN='' GREY='' RESET=''
|
||||
fi
|
||||
|
||||
info() { printf "${GREY} %s${RESET}\n" "$*"; }
|
||||
success() { printf "${GREEN} ✔ %s${RESET}\n" "$*"; }
|
||||
header() { printf "${BOLD} %s${RESET}\n" "$*"; }
|
||||
die() { printf "\n Error: %s\n\n" "$*" >&2; exit 1; }
|
||||
|
||||
# ── banner ─────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
cat << 'EOF'
██╗ ██╗████████╗███████╗██╗ ██╗ ███╗ ███╗
██║ ██║╚══██╔══╝██╔════╝██║ ██║ ████╗ ████║
██║ ██║ ██║ █████╗ ██║ ██║ ██╔████╔██║
██║ ██║ ██║ ██╔══╝ ██║ ██║ ██║╚██╔╝██║
███████╗██║ ██║ ███████╗███████╗███████╗██║ ╚═╝ ██║
╚══════╝╚═╝ ╚═╝ ╚══════╝╚══════╝╚══════╝╚═╝ ╚═╝
EOF
|
||||
printf " ${BOLD}LiteLLM Installer${RESET} ${GREY}— unified gateway for 100+ LLM providers${RESET}\n\n"
|
||||
|
||||
# ── OS detection ───────────────────────────────────────────────────────────
|
||||
OS="$(uname -s)"
|
||||
ARCH="$(uname -m)"
|
||||
|
||||
case "$OS" in
|
||||
Darwin) PLATFORM="macOS ($ARCH)" ;;
|
||||
Linux) PLATFORM="Linux ($ARCH)" ;;
|
||||
*) die "Unsupported OS: $OS. LiteLLM supports macOS and Linux." ;;
|
||||
esac
|
||||
|
||||
info "Platform: $PLATFORM"
|
||||
|
||||
# ── Python detection ───────────────────────────────────────────────────────
|
||||
PYTHON_BIN=""
|
||||
for candidate in python3 python; do
|
||||
if command -v "$candidate" >/dev/null 2>&1; then
|
||||
major="$("$candidate" -c 'import sys; print(sys.version_info.major)' 2>/dev/null || true)"
|
||||
minor="$("$candidate" -c 'import sys; print(sys.version_info.minor)' 2>/dev/null || true)"
|
||||
if [ "${major:-0}" -ge "$MIN_PYTHON_MAJOR" ] && [ "${minor:-0}" -ge "$MIN_PYTHON_MINOR" ]; then
|
||||
PYTHON_BIN="$(command -v "$candidate")"
|
||||
info "Python: $("$candidate" --version 2>&1)"
|
||||
break
|
||||
fi
|
||||
fi
|
||||
done
|
||||
|
||||
if [ -z "$PYTHON_BIN" ]; then
|
||||
die "Python ${MIN_PYTHON_MAJOR}.${MIN_PYTHON_MINOR}+ is required but not found.
Install it from https://python.org/downloads or via your package manager:
macOS: brew install python@3
Ubuntu: sudo apt install python3 python3-pip"
|
||||
fi
|
||||
|
||||
# ── pip detection ──────────────────────────────────────────────────────────
|
||||
if ! "$PYTHON_BIN" -m pip --version >/dev/null 2>&1; then
|
||||
die "pip is not available. Install it with:
$PYTHON_BIN -m ensurepip --upgrade"
|
||||
fi
|
||||
|
||||
# ── install ────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
header "Installing litellm[proxy]…"
|
||||
echo ""
|
||||
|
||||
"$PYTHON_BIN" -m pip install --upgrade "${LITELLM_PACKAGE}" \
|
||||
|| die "pip install failed. Try manually: $PYTHON_BIN -m pip install '${LITELLM_PACKAGE}'"
|
||||
|
||||
# ── find the litellm binary installed by pip for this Python ───────────────
|
||||
# sysconfig.get_path('scripts') is where pip puts console scripts — reliable
|
||||
# even when the Python lives in a libexec/ symlink tree (e.g. Homebrew).
|
||||
SCRIPTS_DIR="$("$PYTHON_BIN" -c 'import sysconfig; print(sysconfig.get_path("scripts"))')"
|
||||
LITELLM_BIN="${SCRIPTS_DIR}/litellm"
|
||||
|
||||
if [ ! -x "$LITELLM_BIN" ]; then
|
||||
# Fall back to user-base bin (pip install --user)
|
||||
USER_BIN="$("$PYTHON_BIN" -c 'import site; print(site.getuserbase())')/bin"
|
||||
LITELLM_BIN="${USER_BIN}/litellm"
|
||||
fi
|
||||
|
||||
if [ ! -x "$LITELLM_BIN" ]; then
|
||||
die "litellm binary not found after install. Try: $PYTHON_BIN -m pip install --user '${LITELLM_PACKAGE}'"
|
||||
fi
|
||||
|
||||
# ── success banner ─────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
success "LiteLLM installed"
|
||||
|
||||
installed_ver="$("$LITELLM_BIN" --version 2>&1 | grep -oE '[0-9]+\.[0-9]+\.[0-9]+' | head -1 || true)"
|
||||
[ -n "$installed_ver" ] && info "Version: $installed_ver"
|
||||
|
||||
# ── PATH hint ──────────────────────────────────────────────────────────────
|
||||
if ! command -v litellm >/dev/null 2>&1; then
|
||||
info "Note: add litellm to your PATH: export PATH=\"\$PATH:${SCRIPTS_DIR}\""
|
||||
fi
|
||||
|
||||
# ── launch setup wizard ────────────────────────────────────────────────────
|
||||
echo ""
|
||||
printf " ${BOLD}Run the interactive setup wizard?${RESET} ${GREY}(Y/n)${RESET}: "
|
||||
# /dev/tty may be unavailable in Docker/CI — default to yes if it can't be read
|
||||
answer=""
|
||||
if [ -r /dev/tty ]; then
|
||||
read -r answer </dev/tty || answer=""
|
||||
fi
|
||||
|
||||
if [ -z "$answer" ] || [ "$answer" = "y" ] || [ "$answer" = "Y" ]; then
|
||||
echo ""
|
||||
# Use /dev/tty for interactive input when available (stdin is a pipe from curl)
|
||||
if [ -r /dev/tty ]; then
|
||||
exec "$LITELLM_BIN" --setup </dev/tty
|
||||
else
|
||||
exec "$LITELLM_BIN" --setup
|
||||
fi
|
||||
else
|
||||
echo ""
|
||||
header "Quick start:"
|
||||
echo ""
|
||||
info " litellm --setup # interactive wizard"
|
||||
info " litellm --model gpt-4o # single-model quickstart"
|
||||
echo ""
|
||||
info "Docs: https://docs.litellm.ai"
|
||||
echo ""
|
||||
fi
|
||||
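The Python detection step in the installer accepts an interpreter when its major and minor version numbers each meet the configured minimum. A minimal sketch of the same gate written as a single tuple comparison is shown below. Unlike the two separate checks, it would also accept a future 4.0 interpreter. The function name `meets_minimum` is illustrative and does not appear in the installer.

```python
import sys

MIN_PYTHON = (3, 9)  # mirrors MIN_PYTHON_MAJOR / MIN_PYTHON_MINOR in install.sh


def meets_minimum(version=None, minimum=MIN_PYTHON):
    """Return True when (major, minor) is at least the minimum.

    Tuples compare element by element, so (4, 0) >= (3, 9) holds even though
    the minor number alone is below 9.
    """
    if version is None:
        version = sys.version_info[:2]
    return tuple(version[:2]) >= tuple(minimum)
```

Calling `meets_minimum()` with no arguments checks the interpreter that is running the snippet.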
@ -907,8 +907,9 @@ def test_router_region_pre_call_check(allowed_model_region):
|
||||
{
|
||||
"model_name": "gpt-3.5-turbo-large", # openai model name
|
||||
"litellm_params": { # params for litellm completion/embedding call
|
||||
"model": "gpt-3.5-turbo-1106",
|
||||
"model": "gpt-4.1-mini",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
"mock_response": "This is a mock response.",
|
||||
},
|
||||
"model_info": {"id": "2"},
|
||||
},
|
||||
|
||||
@ -60,10 +60,33 @@ def assert_langfuse_request_matches_expected(
|
||||
)
|
||||
]
|
||||
|
||||
# When aggregating from multiple flush cycles, deduplicate by keeping
|
||||
# only one trace-create and one generation-create per trace_id.
|
||||
seen_types: dict = {}
|
||||
deduped_batch: list = []
|
||||
for item in actual_request_body["batch"]:
|
||||
item_type = item["type"]
|
||||
if item_type not in seen_types:
|
||||
seen_types[item_type] = True
|
||||
deduped_batch.append(item)
|
||||
actual_request_body["batch"] = deduped_batch
|
||||
|
||||
# Ensure canonical order: trace-create first, generation-create second
|
||||
actual_request_body["batch"].sort(
|
||||
key=lambda x: 0 if x["type"] == "trace-create" else 1
|
||||
)
|
||||
|
||||
print(
|
||||
"actual_request_body after filtering", json.dumps(actual_request_body, indent=4)
|
||||
)
|
||||
|
||||
assert len(actual_request_body["batch"]) >= 2, (
|
||||
f"Expected at least 2 batch items (trace-create + generation-create) "
|
||||
f"after filtering by trace_id={trace_id}, "
|
||||
f"but got {len(actual_request_body['batch'])}. "
|
||||
f"Items: {json.dumps(actual_request_body['batch'], indent=2)}"
|
||||
)
|
||||
|
||||
# Replace dynamic values in actual request body
|
||||
for item in actual_request_body["batch"]:
|
||||
|
||||
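The filtering added in this hunk keeps the first batch item of each `type` and then orders `trace-create` ahead of everything else. A self-contained sketch of that step on plain dictionaries follows. The function name `dedupe_and_order` and the sample payload are illustrative only.

```python
def dedupe_and_order(batch):
    """Keep the first item per "type", then put trace-create first."""
    seen_types = set()
    deduped = []
    for item in batch:
        if item["type"] not in seen_types:
            seen_types.add(item["type"])
            deduped.append(item)
    # sorted() is stable, so items other than trace-create keep their order.
    return sorted(deduped, key=lambda x: 0 if x["type"] == "trace-create" else 1)


sample = [
    {"type": "generation-create", "id": "g1"},
    {"type": "trace-create", "id": "t1"},
    {"type": "generation-create", "id": "g2"},  # dropped as a repeat of its type
]
ordered = dedupe_and_order(sample)
```

Because the key is the item type alone, this only behaves as intended once the batch has already been narrowed to a single trace, as the test does before this step.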
@ -150,19 +173,36 @@ class TestLangfuseLogging:
|
||||
"""Helper method to verify Langfuse API calls"""
|
||||
await asyncio.sleep(3)
|
||||
|
||||
# Verify the call
|
||||
# Verify at least one call was made
|
||||
assert mock_post.call_count >= 1
|
||||
url = mock_post.call_args[0][0]
|
||||
request_body = mock_post.call_args[1].get("content")
|
||||
|
||||
# Parse the JSON string into a dict for assertions
|
||||
actual_request_body = json.loads(request_body)
|
||||
# Aggregate batch items from ALL calls — the Langfuse SDK may split
|
||||
# trace-create and generation-create across separate HTTP flushes.
|
||||
langfuse_url = "https://us.cloud.langfuse.com/api/public/ingestion"
|
||||
all_batch_items: list = []
|
||||
metadata: Optional[dict] = None
|
||||
for call in mock_post.call_args_list:
|
||||
url = call[0][0]
|
||||
if url != langfuse_url:
|
||||
continue
|
||||
request_body = call[1].get("content")
|
||||
if request_body:
|
||||
body = json.loads(request_body)
|
||||
all_batch_items.extend(body.get("batch", []))
|
||||
if metadata is None:
|
||||
metadata = body.get("metadata")
|
||||
|
||||
print("\nMocked Request Details:")
|
||||
print(f"URL: {url}")
|
||||
assert len(all_batch_items) > 0, "No Langfuse ingestion calls found"
|
||||
assert metadata is not None, "No metadata found in Langfuse calls"
|
||||
|
||||
actual_request_body = {
|
||||
"batch": all_batch_items,
|
||||
"metadata": metadata,
|
||||
}
|
||||
|
||||
print("\nMocked Request Details (aggregated from all calls):")
|
||||
print(f"Request Body: {json.dumps(actual_request_body, indent=4)}")
|
||||
|
||||
assert url == "https://us.cloud.langfuse.com/api/public/ingestion"
|
||||
assert_langfuse_request_matches_expected(
|
||||
actual_request_body,
|
||||
expected_file_name,
|
||||
@ -170,6 +210,7 @@ class TestLangfuseLogging:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
async def test_langfuse_logging_completion(self, mock_setup):
|
||||
"""Test Langfuse logging for chat completion"""
|
||||
setup = mock_setup
|
||||
@ -185,6 +226,7 @@ class TestLangfuseLogging:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
async def test_langfuse_logging_completion_with_tags(self, mock_setup):
|
||||
"""Test Langfuse logging for chat completion with tags"""
|
||||
setup = mock_setup
|
||||
@ -203,6 +245,7 @@ class TestLangfuseLogging:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
async def test_langfuse_logging_completion_with_tags_stream(self, mock_setup):
|
||||
"""Test Langfuse logging for chat completion with tags"""
|
||||
setup = mock_setup
|
||||
@ -223,6 +266,7 @@ class TestLangfuseLogging:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
async def test_langfuse_logging_completion_with_langfuse_metadata(self, mock_setup):
|
||||
"""Test Langfuse logging for chat completion with metadata for langfuse"""
|
||||
setup = mock_setup
|
||||
@ -252,6 +296,7 @@ class TestLangfuseLogging:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
async def test_langfuse_logging_with_non_serializable_metadata(self, mock_setup):
|
||||
"""Test Langfuse logging with metadata that requires preparation (Pydantic models, sets, etc)"""
|
||||
from pydantic import BaseModel
|
||||
@ -358,6 +403,7 @@ class TestLangfuseLogging:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
async def test_langfuse_logging_completion_with_malformed_llm_response(
|
||||
self, mock_setup
|
||||
):
|
||||
@ -387,6 +433,7 @@ class TestLangfuseLogging:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
async def test_langfuse_logging_completion_with_bedrock_llm_response(
|
||||
self, mock_setup
|
||||
):
|
||||
@ -418,6 +465,7 @@ class TestLangfuseLogging:
|
||||
setup["mock_post"], "completion_with_bedrock_call.json", setup["trace_id"]
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
async def test_langfuse_logging_completion_with_vertex_llm_response(
|
||||
self, mock_setup
|
||||
):
|
||||
@ -449,6 +497,7 @@ class TestLangfuseLogging:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
async def test_langfuse_logging_vllm_embedding(self, mock_setup):
|
||||
"""
|
||||
Test that the request sent to the vllm embedding endpoint is correct.
|
||||
@ -500,6 +549,7 @@ class TestLangfuseLogging:
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
async def test_langfuse_logging_with_router(self, mock_setup):
|
||||
"""Test Langfuse logging with router"""
|
||||
litellm._turn_on_debug()
|
||||
|
||||
@ -308,16 +308,19 @@ async def test_long_term_spend_accuracy_with_bursts():
|
||||
response = await chat_completion(session, key)
|
||||
print(f"Burst 2 - Request {i + 1}/{BURST_2_REQUESTS} completed")
|
||||
|
||||
# Poll until key spend reflects burst 2
|
||||
burst_1_spend = intermediate_key_info["info"]["spend"]
|
||||
# Poll until key spend reaches expected total (burst 1 + burst 2)
|
||||
start = time.time()
|
||||
while time.time() - start < 120:
|
||||
key_info_check = await get_spend_info(session, "key", key)
|
||||
current_spend = key_info_check["info"]["spend"]
|
||||
if current_spend > burst_1_spend:
|
||||
print(f"Key spend increased to {current_spend} after {time.time() - start:.1f}s")
|
||||
if abs(current_spend - expected_spend) < TOLERANCE:
|
||||
print(
|
||||
f"Total spend reached expected {expected_spend} after {time.time() - start:.1f}s"
|
||||
)
|
||||
break
|
||||
print(f"Key spend still {current_spend}, waiting for burst 2 flush...")
|
||||
print(
|
||||
f"Key spend {current_spend}, expected {expected_spend}, waiting..."
|
||||
)
|
||||
await asyncio.sleep(10)
|
||||
|
||||
# Allow extra time for all entity spend aggregations
|
||||
|
||||
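The revised loop waits until spend is within a tolerance of the expected total, where the earlier version stopped at the first increase over burst 1. A minimal synchronous sketch of that polling pattern follows, assuming a caller-supplied `read_spend` callable. `wait_for_spend` and its parameters are hypothetical names; the test itself uses `asyncio.sleep` and `get_spend_info`.

```python
import time


def wait_for_spend(read_spend, expected, tolerance, timeout=120.0, interval=10.0,
                   clock=time.monotonic, sleep=time.sleep):
    """Poll read_spend() until it is within tolerance of expected.

    Returns the last observed spend, whether or not the target was reached
    before the timeout; the caller decides how to assert on it.
    """
    start = clock()
    current = read_spend()
    while abs(current - expected) >= tolerance and clock() - start < timeout:
        sleep(interval)
        current = read_spend()
    return current
```

Injecting `clock` and `sleep` keeps the helper testable without real waiting, which is a design choice of this sketch and not something the original test does.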
@ -77,6 +77,11 @@ class TestRequestCompliance:
|
||||
schema = spec_dict["components"]["schemas"]["CreateModelInteractionParams"]
|
||||
input_schema = schema["properties"]["input"]
|
||||
|
||||
# The input property may be inline oneOf or a $ref to InteractionsInput
|
||||
if "$ref" in input_schema:
|
||||
ref_name = input_schema["$ref"].split("/")[-1]
|
||||
input_schema = spec_dict["components"]["schemas"][ref_name]
|
||||
|
||||
# Should be oneOf with multiple types
|
||||
assert "oneOf" in input_schema
|
||||
|
||||
@ -100,10 +105,21 @@ class TestRequestCompliance:
|
||||
assert "discriminator" in content_schema
|
||||
assert content_schema["discriminator"]["propertyName"] == "type"
|
||||
|
||||
# Check TextContent is an option
|
||||
mapping = content_schema["discriminator"]["mapping"]
|
||||
assert "text" in mapping
|
||||
print(f"Content type discriminator mapping: {list(mapping.keys())}")
|
||||
# Check TextContent is an option (via mapping if present, or via oneOf refs)
|
||||
mapping = content_schema["discriminator"].get("mapping")
|
||||
if mapping:
|
||||
assert "text" in mapping
|
||||
print(f"Content type discriminator mapping: {list(mapping.keys())}")
|
||||
else:
|
||||
# Discriminator without explicit mapping — verify via oneOf
|
||||
one_of = content_schema.get("oneOf", [])
|
||||
ref_names = [
|
||||
opt["$ref"].split("/")[-1] for opt in one_of if "$ref" in opt
|
||||
]
|
||||
assert "TextContent" in ref_names, (
|
||||
f"TextContent not found in oneOf refs: {ref_names}"
|
||||
)
|
||||
print(f"Content type discriminator (no mapping), oneOf refs: {ref_names}")
|
||||
|
||||
def test_text_content_schema(self, spec_dict):
|
||||
"""Verify TextContent schema."""
|
||||
|
||||
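The two compliance-test hunks above make the checks tolerant of a spec that uses a `$ref` for `input` and a discriminator without an explicit `mapping`. A sketch of both checks on a toy spec dictionary follows. The helper names and the toy spec are assumptions for illustration and do not reflect the real OpenAPI document.

```python
def resolve_ref(schema, spec):
    """Follow a local "#/components/schemas/Name" $ref one level, if present."""
    if "$ref" in schema:
        ref_name = schema["$ref"].split("/")[-1]
        return spec["components"]["schemas"][ref_name]
    return schema


def has_text_content(content_schema):
    """Check for the text option via discriminator mapping, else via oneOf refs."""
    mapping = content_schema.get("discriminator", {}).get("mapping")
    if mapping:
        return "text" in mapping
    ref_names = [
        opt["$ref"].split("/")[-1]
        for opt in content_schema.get("oneOf", [])
        if "$ref" in opt
    ]
    return "TextContent" in ref_names


toy_spec = {
    "components": {
        "schemas": {
            "InteractionsInput": {"oneOf": [{"type": "string"}, {"type": "array"}]},
        }
    }
}
resolved = resolve_ref({"$ref": "#/components/schemas/InteractionsInput"}, toy_spec)
```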
@ -386,8 +386,56 @@ class TestXAICostCalculator:
|
||||
completion_tokens=50,
|
||||
total_tokens=150,
|
||||
)
|
||||
|
||||
|
||||
web_search_cost = cost_per_web_search_request(usage=usage, model_info={})
|
||||
|
||||
|
||||
# Expected cost: No web search data = $0.0
|
||||
assert web_search_cost == 0.0
|
||||
|
||||
def test_grok_4_20_beta_reasoning_cost_calculation(self):
|
||||
"""Test cost calculation for grok-4.20-beta-0309-reasoning model."""
|
||||
usage = Usage(prompt_tokens=100, completion_tokens=200, total_tokens=300)
|
||||
|
||||
prompt_cost, completion_cost = cost_per_token(
|
||||
model="grok-4.20-beta-0309-reasoning", usage=usage
|
||||
)
|
||||
|
||||
# Input: 100 tokens * $2e-6 = $0.0002
|
||||
# Output: 200 tokens * $6e-6 = $0.0012
|
||||
expected_prompt_cost = 100 * 2e-6
|
||||
expected_completion_cost = 200 * 6e-6
|
||||
|
||||
assert math.isclose(prompt_cost, expected_prompt_cost, rel_tol=1e-10)
|
||||
assert math.isclose(completion_cost, expected_completion_cost, rel_tol=1e-10)
|
||||
|
||||
def test_grok_4_20_beta_non_reasoning_cost_calculation(self):
|
||||
"""Test cost calculation for grok-4.20-beta-0309-non-reasoning model."""
|
||||
usage = Usage(prompt_tokens=50, completion_tokens=100, total_tokens=150)
|
||||
|
||||
prompt_cost, completion_cost = cost_per_token(
|
||||
model="grok-4.20-beta-0309-non-reasoning", usage=usage
|
||||
)
|
||||
|
||||
# Input: 50 tokens * $2e-6 = $0.0001
|
||||
# Output: 100 tokens * $6e-6 = $0.0006
|
||||
expected_prompt_cost = 50 * 2e-6
|
||||
expected_completion_cost = 100 * 6e-6
|
||||
|
||||
assert math.isclose(prompt_cost, expected_prompt_cost, rel_tol=1e-10)
|
||||
assert math.isclose(completion_cost, expected_completion_cost, rel_tol=1e-10)
|
||||
|
||||
def test_grok_4_20_multi_agent_cost_calculation(self):
|
||||
"""Test cost calculation for grok-4.20-multi-agent-beta-0309 model."""
|
||||
usage = Usage(prompt_tokens=200, completion_tokens=300, total_tokens=500)
|
||||
|
||||
prompt_cost, completion_cost = cost_per_token(
|
||||
model="grok-4.20-multi-agent-beta-0309", usage=usage
|
||||
)
|
||||
|
||||
# Input: 200 tokens * $2e-6 = $0.0004
|
||||
# Output: 300 tokens * $6e-6 = $0.0018
|
||||
expected_prompt_cost = 200 * 2e-6
|
||||
expected_completion_cost = 300 * 6e-6
|
||||
|
||||
assert math.isclose(prompt_cost, expected_prompt_cost, rel_tol=1e-10)
|
||||
assert math.isclose(completion_cost, expected_completion_cost, rel_tol=1e-10)
|
||||
|
||||
@ -0,0 +1,707 @@
|
||||
"""
|
||||
Tests for pre_mcp_call guardrail hook header mutation support.
|
||||
|
||||
Validates that:
|
||||
1. _convert_mcp_hook_response_to_kwargs extracts extra_headers from hook response
|
||||
2. pre_call_tool_check returns hook-provided extra_headers AND modified arguments
|
||||
3. call_tool flows hook headers and modified arguments downstream
|
||||
4. Hook-provided headers take highest priority (merge after static_headers)
|
||||
5. OpenAPI-backed servers log a warning and continue (skip injection) when hook headers are present
|
||||
6. JWT claims are propagated in both standard and virtual-key fast paths
|
||||
7. Backward compatibility: hooks without extra_headers continue to work
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Any, Dict, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from litellm.proxy._experimental.mcp_server.mcp_server_manager import MCPServerManager
|
||||
from litellm.proxy._types import UserAPIKeyAuth
|
||||
from litellm.proxy.utils import ProxyLogging
|
||||
from litellm.types.mcp import MCPAuth, MCPTransport
|
||||
from litellm.types.mcp_server.mcp_server_manager import MCPServer
|
||||
|
||||
|
||||
class TestConvertMcpHookResponseToKwargs:
|
||||
"""Tests for ProxyLogging._convert_mcp_hook_response_to_kwargs"""
|
||||
|
||||
def setup_method(self):
|
||||
self.proxy_logging = ProxyLogging(user_api_key_cache=MagicMock())
|
||||
|
||||
def test_returns_original_kwargs_when_response_is_none(self):
|
||||
original = {"arguments": {"key": "val"}, "name": "tool"}
|
||||
result = self.proxy_logging._convert_mcp_hook_response_to_kwargs(
|
||||
None, original
|
||||
)
|
||||
assert result == original
|
||||
|
||||
def test_returns_original_kwargs_when_response_is_empty_dict(self):
|
||||
original = {"arguments": {"key": "val"}}
|
||||
result = self.proxy_logging._convert_mcp_hook_response_to_kwargs({}, original)
|
||||
assert result == original
|
||||
|
||||
def test_extracts_modified_arguments(self):
|
||||
original = {"arguments": {"old": "value"}}
|
||||
response = {"modified_arguments": {"new": "value"}}
|
||||
result = self.proxy_logging._convert_mcp_hook_response_to_kwargs(
|
||||
response, original
|
||||
)
|
||||
assert result["arguments"] == {"new": "value"}
|
||||
|
||||
def test_extracts_extra_headers(self):
|
||||
original = {"arguments": {"key": "val"}}
|
||||
response = {"extra_headers": {"Authorization": "Bearer signed-jwt"}}
|
||||
result = self.proxy_logging._convert_mcp_hook_response_to_kwargs(
|
||||
response, original
|
||||
)
|
||||
assert result["extra_headers"] == {"Authorization": "Bearer signed-jwt"}
|
||||
|
||||
def test_extracts_both_arguments_and_headers(self):
|
||||
original = {"arguments": {"old": "value"}}
|
||||
response = {
|
||||
"modified_arguments": {"new": "value"},
|
||||
"extra_headers": {"X-Custom": "header-val"},
|
||||
}
|
||||
result = self.proxy_logging._convert_mcp_hook_response_to_kwargs(
|
||||
response, original
|
||||
)
|
||||
assert result["arguments"] == {"new": "value"}
|
||||
assert result["extra_headers"] == {"X-Custom": "header-val"}
|
||||
|
||||
def test_no_extra_headers_key_preserves_original(self):
|
||||
"""Backward compat: hooks that only return modified_arguments still work."""
|
||||
original = {"arguments": {"key": "val"}}
|
||||
response = {"modified_arguments": {"key": "new_val"}}
|
||||
result = self.proxy_logging._convert_mcp_hook_response_to_kwargs(
|
||||
response, original
|
||||
)
|
||||
assert "extra_headers" not in result
|
||||
assert result["arguments"] == {"key": "new_val"}
|
||||
|
||||
def test_empty_extra_headers_not_set(self):
|
||||
"""Empty dict for extra_headers is falsy and should not be set."""
|
||||
original = {"arguments": {"key": "val"}}
|
||||
response = {"extra_headers": {}}
|
||||
result = self.proxy_logging._convert_mcp_hook_response_to_kwargs(
|
||||
response, original
|
||||
)
|
||||
assert "extra_headers" not in result
|
||||
|
||||
|
||||
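Taken together, the assertions in `TestConvertMcpHookResponseToKwargs` describe a small merge rule: an empty or missing hook response leaves the kwargs untouched, `modified_arguments` replaces `arguments`, and a non-empty `extra_headers` is copied across. The sketch below reproduces only what those assertions check. The function `convert_hook_response` is a hypothetical stand-in, and the real `ProxyLogging._convert_mcp_hook_response_to_kwargs` may differ in detail.

```python
def convert_hook_response(response, original_kwargs):
    """Illustrative stand-in for the behaviour the tests above assert."""
    if not response:
        # None or {}: nothing to apply, hand back the caller's kwargs.
        return original_kwargs
    result = dict(original_kwargs)
    if response.get("modified_arguments"):
        result["arguments"] = response["modified_arguments"]
    if response.get("extra_headers"):  # an empty dict is falsy and is skipped
        result["extra_headers"] = response["extra_headers"]
    return result
```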
class TestPreCallToolCheckReturnsHeaders:
|
||||
"""Tests that pre_call_tool_check returns hook-provided headers."""
|
||||
|
||||
def _make_server(self, name="test_server"):
|
||||
return MCPServer(
|
||||
server_id="test-id",
|
||||
name=name,
|
||||
server_name=name,
|
||||
url="https://example.com",
|
||||
transport=MCPTransport.http,
|
||||
auth_type=MCPAuth.none,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_dict_when_hook_has_no_headers(self):
|
||||
manager = MCPServerManager()
|
||||
server = self._make_server()
|
||||
|
||||
proxy_logging = MagicMock(spec=ProxyLogging)
|
||||
proxy_logging._create_mcp_request_object_from_kwargs = MagicMock(
|
||||
return_value=MagicMock()
|
||||
)
|
||||
proxy_logging._convert_mcp_to_llm_format = MagicMock(
|
||||
return_value={"model": "fake"}
|
||||
)
|
||||
proxy_logging.pre_call_hook = AsyncMock(
|
||||
return_value={"modified_arguments": {"key": "val"}}
|
||||
)
|
||||
proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock(
|
||||
return_value={"arguments": {"key": "val"}}
|
||||
)
|
||||
|
||||
with patch.object(manager, "check_allowed_or_banned_tools", return_value=True):
|
||||
with patch.object(
|
||||
manager,
|
||||
"check_tool_permission_for_key_team",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
with patch.object(manager, "validate_allowed_params"):
|
||||
result = await manager.pre_call_tool_check(
|
||||
name="test_tool",
|
||||
arguments={"key": "val"},
|
||||
server_name="test_server",
|
||||
user_api_key_auth=None,
|
||||
proxy_logging_obj=proxy_logging,
|
||||
server=server,
|
||||
)
|
||||
|
||||
assert result == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_extra_headers_from_hook(self):
|
||||
manager = MCPServerManager()
|
||||
server = self._make_server()
|
||||
|
||||
hook_headers = {"Authorization": "Bearer signed-jwt", "X-Trace-Id": "abc123"}
|
||||
|
||||
proxy_logging = MagicMock(spec=ProxyLogging)
|
||||
proxy_logging._create_mcp_request_object_from_kwargs = MagicMock(
|
||||
return_value=MagicMock()
|
||||
)
|
||||
proxy_logging._convert_mcp_to_llm_format = MagicMock(
|
||||
return_value={"model": "fake"}
|
||||
)
|
||||
proxy_logging.pre_call_hook = AsyncMock(
|
||||
return_value={"extra_headers": hook_headers}
|
||||
)
|
||||
proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock(
|
||||
return_value={"arguments": {"key": "val"}, "extra_headers": hook_headers}
|
||||
)
|
||||
|
||||
with patch.object(manager, "check_allowed_or_banned_tools", return_value=True):
|
||||
with patch.object(
|
||||
manager,
|
||||
"check_tool_permission_for_key_team",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
with patch.object(manager, "validate_allowed_params"):
|
||||
result = await manager.pre_call_tool_check(
|
||||
name="test_tool",
|
||||
arguments={"key": "val"},
|
||||
server_name="test_server",
|
||||
user_api_key_auth=None,
|
||||
proxy_logging_obj=proxy_logging,
|
||||
server=server,
|
||||
)
|
||||
|
||||
assert result["extra_headers"] == hook_headers
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_empty_dict_when_hook_returns_none(self):
|
||||
manager = MCPServerManager()
|
||||
server = self._make_server()
|
||||
|
||||
proxy_logging = MagicMock(spec=ProxyLogging)
|
||||
proxy_logging._create_mcp_request_object_from_kwargs = MagicMock(
|
||||
return_value=MagicMock()
|
||||
)
|
||||
proxy_logging._convert_mcp_to_llm_format = MagicMock(
|
||||
return_value={"model": "fake"}
|
||||
)
|
||||
proxy_logging.pre_call_hook = AsyncMock(return_value=None)
|
||||
|
||||
with patch.object(manager, "check_allowed_or_banned_tools", return_value=True):
|
||||
with patch.object(
|
||||
manager,
|
||||
"check_tool_permission_for_key_team",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
with patch.object(manager, "validate_allowed_params"):
|
||||
result = await manager.pre_call_tool_check(
|
||||
name="test_tool",
|
||||
arguments={"key": "val"},
|
||||
server_name="test_server",
|
||||
user_api_key_auth=None,
|
||||
proxy_logging_obj=proxy_logging,
|
||||
server=server,
|
||||
)
|
||||
|
||||
assert result == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_modified_arguments_from_hook(self):
|
||||
"""Modified arguments from the hook must be returned so the caller can use them."""
|
||||
manager = MCPServerManager()
|
||||
server = self._make_server()
|
||||
|
||||
original_args = {"key": "original"}
|
||||
modified_args = {"key": "modified", "extra": "added"}
|
||||
|
||||
proxy_logging = MagicMock(spec=ProxyLogging)
|
||||
proxy_logging._create_mcp_request_object_from_kwargs = MagicMock(
|
||||
return_value=MagicMock()
|
||||
)
|
||||
proxy_logging._convert_mcp_to_llm_format = MagicMock(
|
||||
return_value={"model": "fake"}
|
||||
)
|
||||
proxy_logging.pre_call_hook = AsyncMock(
|
||||
return_value={"modified_arguments": modified_args}
|
||||
)
|
||||
proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock(
|
||||
return_value={"arguments": modified_args}
|
||||
)
|
||||
|
||||
with patch.object(manager, "check_allowed_or_banned_tools", return_value=True):
|
||||
with patch.object(
|
||||
manager,
|
||||
"check_tool_permission_for_key_team",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
with patch.object(manager, "validate_allowed_params"):
|
||||
result = await manager.pre_call_tool_check(
|
||||
name="test_tool",
|
||||
arguments=original_args,
|
||||
server_name="test_server",
|
||||
user_api_key_auth=None,
|
||||
proxy_logging_obj=proxy_logging,
|
||||
server=server,
|
||||
)
|
||||
|
||||
assert result["arguments"] == modified_args
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_both_modified_arguments_and_headers(self):
|
||||
"""Hook can modify both arguments and inject headers simultaneously."""
|
||||
manager = MCPServerManager()
|
||||
server = self._make_server()
|
||||
|
||||
modified_args = {"key": "modified"}
|
||||
hook_headers = {"Authorization": "Bearer jwt"}
|
||||
|
||||
proxy_logging = MagicMock(spec=ProxyLogging)
|
||||
proxy_logging._create_mcp_request_object_from_kwargs = MagicMock(
|
||||
return_value=MagicMock()
|
||||
)
|
||||
proxy_logging._convert_mcp_to_llm_format = MagicMock(
|
||||
return_value={"model": "fake"}
|
||||
)
|
||||
proxy_logging.pre_call_hook = AsyncMock(return_value={"dummy": True})
|
||||
proxy_logging._convert_mcp_hook_response_to_kwargs = MagicMock(
|
||||
return_value={"arguments": modified_args, "extra_headers": hook_headers}
|
||||
)
|
||||
|
||||
with patch.object(manager, "check_allowed_or_banned_tools", return_value=True):
|
||||
with patch.object(
|
||||
manager,
|
||||
"check_tool_permission_for_key_team",
|
||||
new_callable=AsyncMock,
|
||||
):
|
||||
with patch.object(manager, "validate_allowed_params"):
|
||||
result = await manager.pre_call_tool_check(
|
||||
name="test_tool",
|
||||
arguments={"key": "original"},
|
||||
server_name="test_server",
|
||||
user_api_key_auth=None,
|
||||
proxy_logging_obj=proxy_logging,
|
||||
server=server,
|
||||
)
|
||||
|
||||
assert result["arguments"] == modified_args
|
||||
assert result["extra_headers"] == hook_headers
|
||||
|
||||
|
||||
class TestCallToolFlowsHookHeaders:
    """Tests that call_tool passes hook_extra_headers to _call_regular_mcp_tool."""

    def _make_server(self, name="test_server"):
        return MCPServer(
            server_id="test-id",
            name=name,
            server_name=name,
            url="https://example.com",
            transport=MCPTransport.http,
            auth_type=MCPAuth.none,
        )

    @pytest.mark.asyncio
    async def test_hook_headers_passed_to_call_regular_mcp_tool(self):
        """Verify that hook_extra_headers kwarg is forwarded."""
        manager = MCPServerManager()
        server = self._make_server()

        hook_headers = {"Authorization": "Bearer signed-jwt"}

        with patch.object(
            manager,
            "_get_mcp_server_from_tool_name",
            return_value=server,
        ):
            with patch.object(
                manager,
                "pre_call_tool_check",
                new_callable=AsyncMock,
                return_value={"extra_headers": hook_headers},
            ):
                with patch.object(
                    manager,
                    "_create_during_hook_task",
                    return_value=asyncio.create_task(asyncio.sleep(0)),
                ):
                    with patch.object(
                        manager,
                        "_call_regular_mcp_tool",
                        new_callable=AsyncMock,
                        return_value=MagicMock(),
                    ) as mock_call:
                        proxy_logging = MagicMock(spec=ProxyLogging)

                        await manager.call_tool(
                            server_name="test_server",
                            name="test_tool",
                            arguments={"key": "val"},
                            proxy_logging_obj=proxy_logging,
                        )

                        mock_call.assert_called_once()
                        call_kwargs = mock_call.call_args
                        assert call_kwargs.kwargs.get("hook_extra_headers") == hook_headers

    @pytest.mark.asyncio
    async def test_no_hook_headers_when_no_proxy_logging(self):
        """Without proxy_logging_obj, no pre_call_tool_check runs."""
        manager = MCPServerManager()
        server = self._make_server()

        with patch.object(
            manager,
            "_get_mcp_server_from_tool_name",
            return_value=server,
        ):
            with patch.object(
                manager,
                "_call_regular_mcp_tool",
                new_callable=AsyncMock,
                return_value=MagicMock(),
            ) as mock_call:
                await manager.call_tool(
                    server_name="test_server",
                    name="test_tool",
                    arguments={"key": "val"},
                    proxy_logging_obj=None,
                )

                mock_call.assert_called_once()
                call_kwargs = mock_call.call_args
                assert call_kwargs.kwargs.get("hook_extra_headers") is None

    @pytest.mark.asyncio
    async def test_modified_arguments_passed_to_downstream(self):
        """Hook-modified arguments must be used for the actual tool call."""
        manager = MCPServerManager()
        server = self._make_server()

        modified_args = {"key": "modified_by_hook"}

        with patch.object(
            manager,
            "_get_mcp_server_from_tool_name",
            return_value=server,
        ):
            with patch.object(
                manager,
                "pre_call_tool_check",
                new_callable=AsyncMock,
                return_value={"arguments": modified_args},
            ):
                with patch.object(
                    manager,
                    "_create_during_hook_task",
                    return_value=asyncio.create_task(asyncio.sleep(0)),
                ):
                    with patch.object(
                        manager,
                        "_call_regular_mcp_tool",
                        new_callable=AsyncMock,
                        return_value=MagicMock(),
                    ) as mock_call:
                        proxy_logging = MagicMock(spec=ProxyLogging)

                        await manager.call_tool(
                            server_name="test_server",
                            name="test_tool",
                            arguments={"key": "original"},
                            proxy_logging_obj=proxy_logging,
                        )

                        mock_call.assert_called_once()
                        call_kwargs = mock_call.call_args
                        assert call_kwargs.kwargs.get("arguments") == modified_args

    @pytest.mark.asyncio
    async def test_openapi_server_warns_and_continues_on_hook_headers(self):
        """OpenAPI-backed servers log a warning and continue when hook injects headers."""
        manager = MCPServerManager()
        server = MCPServer(
            server_id="test-id",
            name="openapi_server",
            server_name="openapi_server",
            url="https://example.com",
            transport=MCPTransport.http,
            auth_type=MCPAuth.none,
            spec_path="/path/to/spec.yaml",
        )

        with patch.object(
            manager, "_get_mcp_server_from_tool_name", return_value=server
        ):
            with patch.object(
                manager,
                "pre_call_tool_check",
                new_callable=AsyncMock,
                return_value={"extra_headers": {"Authorization": "Bearer jwt"}},
            ):
                with patch.object(
                    manager,
                    "_create_during_hook_task",
                    return_value=asyncio.create_task(asyncio.sleep(0)),
                ):
                    with patch.object(
                        manager,
                        "_call_openapi_tool_handler",
                        new_callable=AsyncMock,
                        return_value=MagicMock(),
                    ):
                        import litellm.proxy._experimental.mcp_server.mcp_server_manager as mgr_mod

                        proxy_logging = MagicMock(spec=ProxyLogging)

                        with patch.object(mgr_mod, "verbose_logger") as mock_logger:
                            # Should NOT raise — just warn and proceed
                            await manager.call_tool(
                                server_name="openapi_server",
                                name="test_tool",
                                arguments={},
                                proxy_logging_obj=proxy_logging,
                            )
                            mock_logger.warning.assert_called_once()
                            assert "header injection is not supported" in mock_logger.warning.call_args[0][0]

    @pytest.mark.asyncio
    async def test_openapi_server_no_error_without_hook_headers(self):
        """No exception when OpenAPI server has no hook-injected headers."""
        manager = MCPServerManager()
        server = MCPServer(
            server_id="test-id",
            name="openapi_server",
            server_name="openapi_server",
            url="https://example.com",
            transport=MCPTransport.http,
            auth_type=MCPAuth.none,
            spec_path="/path/to/spec.yaml",
        )

        with patch.object(
            manager, "_get_mcp_server_from_tool_name", return_value=server
        ):
            with patch.object(
                manager,
                "pre_call_tool_check",
                new_callable=AsyncMock,
                return_value={},
            ):
                with patch.object(
                    manager,
                    "_create_during_hook_task",
                    return_value=asyncio.create_task(asyncio.sleep(0)),
                ):
                    with patch.object(
                        manager,
                        "_call_openapi_tool_handler",
                        new_callable=AsyncMock,
                        return_value=MagicMock(),
                    ):
                        proxy_logging = MagicMock(spec=ProxyLogging)

                        await manager.call_tool(
                            server_name="openapi_server",
                            name="test_tool",
                            arguments={},
                            proxy_logging_obj=proxy_logging,
                        )


class TestHookHeaderMergePriority:
    """Tests that hook-provided headers have highest priority in _call_regular_mcp_tool."""

    def _make_server(
        self,
        static_headers: Optional[Dict[str, str]] = None,
        extra_headers_config: Optional[list] = None,
    ):
        return MCPServer(
            server_id="test-id",
            name="Test Server",
            server_name="test_server",
            url="https://example.com",
            transport=MCPTransport.http,
            auth_type=MCPAuth.none,
            static_headers=static_headers,
            extra_headers=extra_headers_config,
        )

    @pytest.mark.asyncio
    async def test_hook_headers_override_static_headers(self):
        """Hook headers should take precedence over static_headers."""
        manager = MCPServerManager()
        server = self._make_server(
            static_headers={"Authorization": "Bearer static-token", "X-Static": "yes"}
        )

        hook_headers = {"Authorization": "Bearer hook-signed-jwt"}

        captured_extra_headers: Dict[str, Any] = {}

        async def fake_create_mcp_client(
            server, mcp_auth_header=None, extra_headers=None, stdio_env=None
        ):
            captured_extra_headers["value"] = extra_headers
            mock_client = MagicMock()
            mock_client.call_tool = AsyncMock(return_value=MagicMock())
            return mock_client

        with patch.object(
            manager, "_create_mcp_client", side_effect=fake_create_mcp_client
        ):
            with patch.object(manager, "_build_stdio_env", return_value=None):
                try:
                    await manager._call_regular_mcp_tool(
                        mcp_server=server,
                        original_tool_name="test_tool",
                        arguments={"key": "val"},
                        tasks=[],
                        mcp_auth_header=None,
                        mcp_server_auth_headers=None,
                        oauth2_headers=None,
                        raw_headers=None,
                        proxy_logging_obj=None,
                        hook_extra_headers=hook_headers,
                    )
                except Exception:
                    pass

        headers = captured_extra_headers.get("value", {})
        assert headers["Authorization"] == "Bearer hook-signed-jwt"
        assert headers["X-Static"] == "yes"

    @pytest.mark.asyncio
    async def test_no_hook_headers_preserves_existing_behavior(self):
        """When hook_extra_headers is None, existing header logic is unchanged."""
        manager = MCPServerManager()
        server = self._make_server(
            static_headers={"X-Static": "static-value"}
        )

        captured_extra_headers: Dict[str, Any] = {}

        async def fake_create_mcp_client(
            server, mcp_auth_header=None, extra_headers=None, stdio_env=None
        ):
            captured_extra_headers["value"] = extra_headers
            mock_client = MagicMock()
            mock_client.call_tool = AsyncMock(return_value=MagicMock())
            return mock_client

        with patch.object(
            manager, "_create_mcp_client", side_effect=fake_create_mcp_client
        ):
            with patch.object(manager, "_build_stdio_env", return_value=None):
                try:
                    await manager._call_regular_mcp_tool(
                        mcp_server=server,
                        original_tool_name="test_tool",
                        arguments={"key": "val"},
                        tasks=[],
                        mcp_auth_header=None,
                        mcp_server_auth_headers=None,
                        oauth2_headers=None,
                        raw_headers=None,
                        proxy_logging_obj=None,
                        hook_extra_headers=None,
                    )
                except Exception:
                    pass

        headers = captured_extra_headers.get("value", {})
        assert headers == {"X-Static": "static-value"}

    @pytest.mark.asyncio
    async def test_hook_headers_merge_with_oauth2(self):
        """Hook headers merge on top of OAuth2 headers."""
        manager = MCPServerManager()
        server = MCPServer(
            server_id="test-id",
            name="Test Server",
            server_name="test_server",
            url="https://example.com",
            transport=MCPTransport.http,
            auth_type=MCPAuth.oauth2,
        )

        captured_extra_headers: Dict[str, Any] = {}

        async def fake_create_mcp_client(
            server, mcp_auth_header=None, extra_headers=None, stdio_env=None
        ):
            captured_extra_headers["value"] = extra_headers
            mock_client = MagicMock()
            mock_client.call_tool = AsyncMock(return_value=MagicMock())
            return mock_client

        with patch.object(
            manager, "_create_mcp_client", side_effect=fake_create_mcp_client
        ):
            with patch.object(manager, "_build_stdio_env", return_value=None):
                try:
                    await manager._call_regular_mcp_tool(
                        mcp_server=server,
                        original_tool_name="test_tool",
                        arguments={"key": "val"},
                        tasks=[],
                        mcp_auth_header=None,
                        mcp_server_auth_headers=None,
                        oauth2_headers={
                            "Authorization": "Bearer oauth2-token",
                            "X-OAuth": "yes",
                        },
                        raw_headers=None,
                        proxy_logging_obj=None,
                        hook_extra_headers={
                            "Authorization": "Bearer hook-jwt",
                            "X-Trace-Id": "trace-123",
                        },
                    )
                except Exception:
                    pass

        headers = captured_extra_headers.get("value", {})
        assert headers["Authorization"] == "Bearer hook-jwt"
        assert headers["X-OAuth"] == "yes"
        assert headers["X-Trace-Id"] == "trace-123"


class TestUserAPIKeyAuthJwtClaims:
    """Tests that UserAPIKeyAuth correctly carries jwt_claims."""

    def test_jwt_claims_field_defaults_to_none(self):
        auth = UserAPIKeyAuth(api_key="test-key")
        assert auth.jwt_claims is None

    def test_jwt_claims_field_accepts_dict(self):
        claims = {"sub": "user-123", "iss": "litellm", "exp": 9999999999}
        auth = UserAPIKeyAuth(api_key="test-key", jwt_claims=claims)
        assert auth.jwt_claims == claims
        assert auth.jwt_claims["sub"] == "user-123"

    def test_jwt_claims_backward_compatible_without_field(self):
        """Existing code that doesn't pass jwt_claims should still work."""
        auth = UserAPIKeyAuth(
            api_key="test-key",
            user_id="user-1",
            team_id="team-1",
        )
        assert auth.jwt_claims is None
        assert auth.user_id == "user-1"

    def test_jwt_claims_set_after_construction(self):
        """Virtual-key fast path sets jwt_claims after the object is created."""
        auth = UserAPIKeyAuth(api_key="test-key")
        assert auth.jwt_claims is None

        claims = {"sub": "user-456", "iss": "okta", "groups": ["admin"]}
        auth.jwt_claims = claims
        assert auth.jwt_claims == claims
        assert auth.jwt_claims["groups"] == ["admin"]
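Review note: the precedence asserted by `TestHookHeaderMergePriority` above can be pictured as a layered dict merge in which hook headers are applied last. The sketch below is only an illustration under stated assumptions, not the code inside `_call_regular_mcp_tool`: `merge_mcp_headers` is a hypothetical helper, and the relative order of static and OAuth2 headers is assumed (the tests only show hook headers overriding each of them).

```python
from typing import Dict, Optional


def merge_mcp_headers(
    static_headers: Optional[Dict[str, str]] = None,
    oauth2_headers: Optional[Dict[str, str]] = None,
    hook_extra_headers: Optional[Dict[str, str]] = None,
) -> Dict[str, str]:
    """Hypothetical helper: merge header sources, later sources win on collisions."""
    merged: Dict[str, str] = {}
    # Assumed order: static, then OAuth2, then hook headers (highest priority).
    for source in (static_headers, oauth2_headers, hook_extra_headers):
        if source:
            merged.update(source)
    return merged
```

Passing `None` for `hook_extra_headers` leaves the lower layers untouched, which mirrors the "existing behavior is unchanged" case in the tests.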
1103
tests/test_litellm/proxy/guardrails/test_mcp_jwt_signer.py
Normal file
File diff suppressed because it is too large
188
tests/test_litellm/test_setup_wizard.py
Normal file
@ -0,0 +1,188 @@
"""Unit tests for litellm.setup_wizard — pure functions only, no network calls."""

from litellm.setup_wizard import SetupWizard, _yaml_escape

# ---------------------------------------------------------------------------
# _yaml_escape
# ---------------------------------------------------------------------------


def test_yaml_escape_plain():
    assert _yaml_escape("sk-abc123") == "sk-abc123"


def test_yaml_escape_double_quote():
    assert _yaml_escape('sk-ab"cd') == 'sk-ab\\"cd'


def test_yaml_escape_backslash():
    assert _yaml_escape("sk-ab\\cd") == "sk-ab\\\\cd"


def test_yaml_escape_combined():
    assert _yaml_escape('ab\\"cd') == 'ab\\\\\\"cd'


def test_yaml_escape_newline():
    assert _yaml_escape("sk-abc\ndef") == "sk-abc\\ndef"


def test_yaml_escape_carriage_return():
    assert _yaml_escape("sk-abc\rdef") == "sk-abc\\rdef"


def test_yaml_escape_tab():
    assert _yaml_escape("sk-abc\tdef") == "sk-abc\\tdef"


# ---------------------------------------------------------------------------
# SetupWizard._build_config
# ---------------------------------------------------------------------------

_OPENAI = {
    "id": "openai",
    "name": "OpenAI",
    "env_key": "OPENAI_API_KEY",
    "models": ["gpt-4o", "gpt-4o-mini"],
    "test_model": "gpt-4o-mini",
}

_ANTHROPIC = {
    "id": "anthropic",
    "name": "Anthropic",
    "env_key": "ANTHROPIC_API_KEY",
    "models": ["claude-opus-4-6"],
    "test_model": "claude-haiku-4-5-20251001",
}

_AZURE = {
    "id": "azure",
    "name": "Azure OpenAI",
    "env_key": "AZURE_API_KEY",
    "models": [],
    "test_model": None,
    "needs_api_base": True,
    "api_base_hint": "https://<resource>.openai.azure.com/",
    "api_version": "2024-07-01-preview",
}

_OLLAMA = {
    "id": "ollama",
    "name": "Ollama",
    "env_key": None,
    "models": ["ollama/llama3.2"],
    "test_model": None,
    "api_base": "http://localhost:11434",
}


def test_build_config_basic_openai():
    config = SetupWizard._build_config(
        [_OPENAI],
        {"OPENAI_API_KEY": "sk-test"},
        "sk-master",
    )
    assert "model_list:" in config
    assert "model_name: gpt-4o" in config
    assert "model: gpt-4o" in config
    assert "api_key: os.environ/OPENAI_API_KEY" in config
    assert 'master_key: "sk-master"' in config


def test_build_config_skipped_provider_omitted():
    """Provider with no key in env_vars should not appear in model_list."""
    config = SetupWizard._build_config(
        [_OPENAI, _ANTHROPIC],
        {"ANTHROPIC_API_KEY": "sk-ant-test"},  # OpenAI key missing
        "sk-master",
    )
    assert "gpt-4o" not in config
    assert "claude-opus-4-6" in config


def test_build_config_env_vars_written_escaped():
    """API keys with special chars should be YAML-escaped."""
    config = SetupWizard._build_config(
        [_OPENAI],
        {"OPENAI_API_KEY": 'sk-ab"cd'},
        "sk-master",
    )
    assert 'OPENAI_API_KEY: "sk-ab\\"cd"' in config


def test_build_config_master_key_quoted():
    """master_key must be quoted in YAML to handle special characters."""
    config = SetupWizard._build_config(
        [_OPENAI],
        {"OPENAI_API_KEY": "sk-test"},
        'sk-master"special',
    )
    assert 'master_key: "sk-master\\"special"' in config


def test_build_config_does_not_mutate_env_vars():
    """_build_config must not modify the caller's env_vars dict."""
    env_vars = {
        "AZURE_API_KEY": "az-key",
        "_LITELLM_AZURE_API_BASE_AZURE": "https://my.azure.com",
        "_LITELLM_AZURE_DEPLOYMENT_AZURE": "my-deployment",
    }
    original_keys = set(env_vars.keys())
    SetupWizard._build_config([_AZURE], env_vars, "sk-master")
    assert set(env_vars.keys()) == original_keys


def test_build_config_azure_uses_deployment_name():
    env_vars = {
        "AZURE_API_KEY": "az-key",
        "_LITELLM_AZURE_API_BASE_AZURE": "https://my.azure.com",
        "_LITELLM_AZURE_DEPLOYMENT_AZURE": "my-gpt4o",
    }
    config = SetupWizard._build_config([_AZURE], env_vars, "sk-master")
    assert "model: azure/my-gpt4o" in config
    assert "model_name: azure-my-gpt4o" in config
    # api_base must be quoted to survive YAML special chars
    assert 'api_base: "https://my.azure.com"' in config


def test_build_config_azure_no_deployment_skipped():
    """Azure without a deployment name should emit nothing (not fallback to gpt-4o)."""
    env_vars = {"AZURE_API_KEY": "az-key"}  # no deployment sentinel
    config = SetupWizard._build_config([_AZURE], env_vars, "sk-master")
    # No azure model entry should be emitted when deployment name is absent
    assert "model: azure/" not in config


def test_build_config_no_display_name_collision_openai_and_azure():
    """OpenAI gpt-4o and azure gpt-4o should get distinct model_name values."""
    env_vars = {
        "OPENAI_API_KEY": "sk-openai",
        "AZURE_API_KEY": "az-key",
        "_LITELLM_AZURE_DEPLOYMENT_AZURE": "gpt-4o",
    }
    config = SetupWizard._build_config([_OPENAI, _AZURE], env_vars, "sk-master")
    assert "model_name: gpt-4o" in config  # OpenAI
    assert "model_name: azure-gpt-4o" in config  # Azure — qualified


def test_build_config_ollama_no_api_key_line():
    """Ollama has no env_key — config should not contain an api_key line for it."""
    config = SetupWizard._build_config([_OLLAMA], {}, "sk-master")
    assert "ollama/llama3.2" in config
    assert "api_key:" not in config


def test_build_config_master_key_in_general_settings():
    """master_key is written to general_settings."""
    config = SetupWizard._build_config([_OPENAI], {"OPENAI_API_KEY": "k"}, "sk-m")
    assert 'master_key: "sk-m"' in config


def test_build_config_internal_sentinel_keys_excluded():
    """_LITELLM_ prefixed sentinel keys must not appear in environment_variables."""
    env_vars = {
        "OPENAI_API_KEY": "sk-real",
        "_LITELLM_AZURE_API_BASE_AZURE": "https://x.azure.com",
    }
    config = SetupWizard._build_config([_OPENAI], env_vars, "sk-master")
    assert "_LITELLM_" not in config
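Review note: the `test_yaml_escape_*` cases in `tests/test_litellm/test_setup_wizard.py` pin down the escaping contract without showing an implementation. One way to satisfy them is sketched below. This is an assumption-laden illustration, not the actual `litellm.setup_wizard._yaml_escape`; the name `yaml_escape_sketch` is hypothetical. The one ordering constraint that the combined test implies is that backslashes are escaped first, so the backslashes introduced by later replacements are not doubled again.

```python
def yaml_escape_sketch(value: str) -> str:
    """Hypothetical stand-in: escape a string for a YAML double-quoted scalar."""
    replacements = (
        ("\\", "\\\\"),  # backslash first, so later escapes are not re-escaped
        ('"', '\\"'),    # double quote
        ("\n", "\\n"),   # newline
        ("\r", "\\r"),   # carriage return
        ("\t", "\\t"),   # tab
    )
    for raw, escaped in replacements:
        value = value.replace(raw, escaped)
    return value
```

The caller is still expected to wrap the result in double quotes, as the `master_key: "..."` assertions suggest.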
@ -20,7 +20,6 @@ import {
  ToolOutlined,
  TagsOutlined,
  AuditOutlined,
  MessageOutlined,
} from "@ant-design/icons";
// import {
// all_admin_roles,
@ -466,41 +465,6 @@ const Sidebar2: React.FC<SidebarProps> = ({ accessToken, userRole, defaultSelect
        </ConfigProvider>
        {isAdminRole(userRole) && !collapsed && <UsageIndicator accessToken={accessToken} width={220} />}

        {/* Pinned "Open Chat" button at bottom */}
        <div style={{
          padding: collapsed ? "10px 8px" : "10px 12px",
          borderTop: "1px solid #f0f0f0",
          flexShrink: 0,
        }}>
          <a
            href={toHref("chat")}
            target="_blank"
            rel="noopener noreferrer"
            style={{
              display: "flex",
              alignItems: "center",
              justifyContent: collapsed ? "center" : "flex-start",
              gap: 8,
              padding: collapsed ? "8px 0" : "8px 10px",
              borderRadius: 8,
              background: "#1677ff",
              color: "#fff",
              textDecoration: "none",
              fontSize: 13,
              fontWeight: 600,
              transition: "background 0.15s",
            }}
            onMouseEnter={(e) => {
              (e.currentTarget as HTMLAnchorElement).style.background = "#0958d9";
            }}
            onMouseLeave={(e) => {
              (e.currentTarget as HTMLAnchorElement).style.background = "#1677ff";
            }}
          >
            <MessageOutlined style={{ fontSize: 16, flexShrink: 0 }} />
            {!collapsed && <span>Open Chat</span>}
          </a>
        </div>
      </Sider>
    </Layout>
  );

@ -8,8 +8,6 @@ import ComplianceUI from "@/components/playground/complianceUI/ComplianceUI";
import { TabGroup, TabList, Tab, TabPanels, TabPanel } from "@tremor/react";
import useAuthorized from "@/app/(dashboard)/hooks/useAuthorized";
import { fetchProxySettings } from "@/utils/proxyUtils";
import { useUIConfig } from "@/app/(dashboard)/hooks/uiConfig/useUIConfig";
import { MessageOutlined, CloseOutlined } from "@ant-design/icons";

interface ProxySettings {
  PROXY_BASE_URL?: string;
@ -19,12 +17,6 @@ interface ProxySettings {
export default function PlaygroundPage() {
  const { accessToken, userRole, userId, disabledPersonalKeyCreation, token } = useAuthorized();
  const [proxySettings, setProxySettings] = useState<ProxySettings | undefined>(undefined);
  const [chatBannerDismissed, setChatBannerDismissed] = useState(false);
  const { data: uiConfig } = useUIConfig();
  const uiRoot = uiConfig?.server_root_path && uiConfig.server_root_path !== "/"
    ? uiConfig.server_root_path.replace(/\/+$/, "")
    : "";
  const chatHref = `${uiRoot}/ui/chat`;

  useEffect(() => {
    const initializeProxySettings = async () => {
@ -44,64 +36,6 @@ export default function PlaygroundPage() {

  return (
    <div className="h-full w-full flex flex-col">
      {!chatBannerDismissed && (
        <div style={{
          display: "flex",
          alignItems: "center",
          gap: 16,
          padding: "10px 20px",
          background: "#f0f9ff",
          borderBottom: "1px solid #bae6fd",
          flexShrink: 0,
        }}>
          <span style={{
            fontSize: 10,
            fontWeight: 700,
            color: "#fff",
            background: "#0ea5e9",
            borderRadius: 4,
            padding: "2px 7px",
            letterSpacing: "0.08em",
            textTransform: "uppercase",
            flexShrink: 0,
            lineHeight: "18px",
          }}>
            New
          </span>
          <span style={{ flex: 1, color: "#0c4a6e", fontSize: 13.5, lineHeight: 1.5 }}>
            <strong>Chat UI</strong>
            {" "}— a ChatGPT-like interface for your users to chat with AI models and MCP tools. Share it with your team.
          </span>
          <a
            href={chatHref}
            target="_blank"
            rel="noopener noreferrer"
            style={{
              display: "inline-flex",
              alignItems: "center",
              gap: 5,
              padding: "5px 14px",
              borderRadius: 6,
              background: "#0ea5e9",
              color: "#fff",
              fontSize: 12.5,
              fontWeight: 600,
              textDecoration: "none",
              whiteSpace: "nowrap",
              flexShrink: 0,
            }}
          >
            Open Chat UI →
          </a>
          <button
            onClick={() => setChatBannerDismissed(true)}
            style={{ background: "none", border: "none", cursor: "pointer", color: "#64748b", padding: 4, flexShrink: 0, lineHeight: 1 }}
            aria-label="Dismiss"
          >
            <CloseOutlined style={{ fontSize: 13 }} />
          </button>
        </div>
      )}
      <TabGroup className="w-full" style={{ flex: 1, minHeight: 0, display: "flex", flexDirection: "column" }}>
        <TabList className="mb-0">
          <Tab>Chat</Tab>
