commit
5d6c76aa1a
@ -169,9 +169,6 @@ def get_llm_provider( # noqa: PLR0915
|
||||
return remainder, custom_llm_provider, dynamic_api_key, api_base
|
||||
return model, custom_llm_provider, dynamic_api_key, api_base
|
||||
|
||||
if api_key and api_key.startswith("os.environ/"):
|
||||
dynamic_api_key = get_secret_str(api_key)
|
||||
|
||||
# Check JSON-configured providers FIRST (before enum-based provider_list)
|
||||
provider_prefix = model.split("/", 1)[0]
|
||||
if len(model.split("/")) > 1 and JSONProviderRegistry.exists(provider_prefix):
|
||||
|
||||
@ -7,6 +7,7 @@ JWT token must have 'litellm_proxy_admin' in scope.
|
||||
"""
|
||||
|
||||
import fnmatch
|
||||
import hashlib
|
||||
import os
|
||||
import re
|
||||
from typing import Any, List, Literal, Optional, Set, Tuple, cast
|
||||
@ -621,7 +622,7 @@ class JWTHandler:
|
||||
|
||||
# Check cache first
|
||||
cache_key = (
|
||||
f"oidc_userinfo_{token[:20]}" # Use first 20 chars of token as cache key
|
||||
f"oidc_userinfo_{hashlib.sha256(token.encode()).hexdigest()}"
|
||||
)
|
||||
cached_userinfo = await self.user_api_key_cache.async_get_cache(cache_key)
|
||||
|
||||
|
||||
@ -13,6 +13,7 @@ from fastapi import HTTPException
|
||||
|
||||
import litellm
|
||||
from litellm.constants import LITELLM_PROXY_ADMIN_NAME, LITELLM_UI_SESSION_DURATION
|
||||
from litellm.proxy.utils import hash_password, verify_password
|
||||
from litellm.proxy._types import (
|
||||
LiteLLM_UserTable,
|
||||
LitellmUserRoles,
|
||||
@ -34,6 +35,18 @@ from litellm.secret_managers.main import get_secret_bool
|
||||
from litellm.types.proxy.ui_sso import ReturnedUITokenObject
|
||||
|
||||
|
||||
async def _rehash_password_if_needed(user_id: str, password: str, stored: str) -> None:
|
||||
"""Rehash legacy password (SHA256) to scrypt on successful login."""
|
||||
if stored.startswith("scrypt:"):
|
||||
return
|
||||
from litellm.proxy.proxy_server import prisma_client
|
||||
if prisma_client is not None:
|
||||
await prisma_client.db.litellm_usertable.update(
|
||||
where={"user_id": user_id},
|
||||
data={"password": hash_password(password)},
|
||||
)
|
||||
|
||||
|
||||
def get_ui_credentials(master_key: Optional[str]) -> tuple[str, str]:
|
||||
"""
|
||||
Get UI username and password from environment variables or master key.
|
||||
@ -254,13 +267,8 @@ async def authenticate_user( # noqa: PLR0915
|
||||
code=401,
|
||||
)
|
||||
|
||||
# check if password == _user_row.password
|
||||
hash_password = hash_token(token=password)
|
||||
if secrets.compare_digest(
|
||||
password.encode("utf-8"), _password.encode("utf-8")
|
||||
) or secrets.compare_digest(
|
||||
hash_password.encode("utf-8"), _password.encode("utf-8")
|
||||
):
|
||||
if verify_password(password, _password):
|
||||
await _rehash_password_if_needed(_user_row.user_id, password, _password)
|
||||
if os.getenv("DATABASE_URL") is not None:
|
||||
response = await generate_key_helper_fn(
|
||||
request_type="key",
|
||||
|
||||
@ -41,7 +41,7 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
prepare_metadata_fields,
|
||||
)
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
from litellm.proxy.utils import handle_exception_on_proxy
|
||||
from litellm.proxy.utils import handle_exception_on_proxy, hash_password
|
||||
from litellm.types.proxy.management_endpoints.common_daily_activity import (
|
||||
SpendAnalyticsPaginatedResponse,
|
||||
)
|
||||
@ -58,6 +58,22 @@ if TYPE_CHECKING:
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _hash_password_in_dict(data: dict) -> None:
|
||||
"""Hash password field in-place if present."""
|
||||
if "password" in data and data["password"] is not None:
|
||||
data["password"] = hash_password(data["password"])
|
||||
|
||||
|
||||
def _strip_password_from_response(response) -> None:
|
||||
"""Strip password from API response (handles dicts, nested dicts, and Prisma models)."""
|
||||
if isinstance(response, dict):
|
||||
response.pop("password", None)
|
||||
if isinstance(response.get("data"), dict):
|
||||
response["data"].pop("password", None)
|
||||
elif hasattr(response.get("data"), "__dict__"):
|
||||
response["data"].__dict__.pop("password", None)
|
||||
|
||||
|
||||
def _update_internal_new_user_params(data_json: dict, data: NewUserRequest) -> dict:
|
||||
if "user_id" in data_json and data_json["user_id"] is None:
|
||||
data_json["user_id"] = str(uuid.uuid4())
|
||||
@ -438,6 +454,7 @@ async def new_user(
|
||||
|
||||
data_json = data.json() # type: ignore
|
||||
data_json = _update_internal_new_user_params(data_json, data)
|
||||
_hash_password_in_dict(data_json)
|
||||
teams = data.teams
|
||||
if teams is None:
|
||||
teams = check_if_default_team_set()
|
||||
@ -723,6 +740,8 @@ async def user_info(
|
||||
_user_info = (
|
||||
user_info.model_dump() if isinstance(user_info, BaseModel) else user_info
|
||||
)
|
||||
if isinstance(_user_info, dict):
|
||||
_user_info.pop("password", None)
|
||||
response_data = UserInfoResponse(
|
||||
user_id=user_id, user_info=_user_info, keys=returned_keys, teams=team_list
|
||||
)
|
||||
@ -950,6 +969,8 @@ async def _get_user_info_for_proxy_admin(user_api_key_dict: UserAPIKeyAuth):
|
||||
if isinstance(admin_user_info, BaseModel)
|
||||
else admin_user_info
|
||||
)
|
||||
if isinstance(admin_user_info, dict):
|
||||
admin_user_info.pop("password", None)
|
||||
|
||||
return UserInfoResponse(
|
||||
user_id=admin_user_id,
|
||||
@ -1089,6 +1110,8 @@ async def _update_single_user_helper(
|
||||
data_json=data_json, data=user_request
|
||||
)
|
||||
|
||||
_hash_password_in_dict(non_default_values)
|
||||
|
||||
# Get existing user data for audit logging and metadata preparation
|
||||
existing_user_row: Optional[BaseModel] = None
|
||||
if user_request.user_id:
|
||||
@ -1205,6 +1228,7 @@ async def _update_single_user_helper(
|
||||
status_code=400,
|
||||
detail={"error": "Failed to update user"},
|
||||
)
|
||||
_strip_password_from_response(response)
|
||||
return response
|
||||
|
||||
|
||||
|
||||
@ -503,7 +503,9 @@ from litellm.proxy.utils import (
|
||||
get_error_message_str,
|
||||
get_server_root_path,
|
||||
handle_exception_on_proxy,
|
||||
hash_password,
|
||||
hash_token,
|
||||
migrate_passwords_to_scrypt_async,
|
||||
model_dump_with_preserved_fields,
|
||||
update_spend,
|
||||
)
|
||||
@ -870,6 +872,15 @@ async def proxy_startup_event(app: FastAPI): # noqa: PLR0915
|
||||
user_api_key_cache=user_api_key_cache,
|
||||
)
|
||||
|
||||
if prisma_client is not None:
|
||||
async def _run_pw_migration():
|
||||
try:
|
||||
result = await migrate_passwords_to_scrypt_async(prisma_client)
|
||||
verbose_proxy_logger.info(f"Password migration: {result}")
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.warning(f"Password migration skipped: {e}")
|
||||
asyncio.create_task(_run_pw_migration())
|
||||
|
||||
ProxyStartupEvent._initialize_startup_logging(
|
||||
llm_router=llm_router,
|
||||
proxy_logging_obj=proxy_logging_obj,
|
||||
@ -11683,9 +11694,9 @@ async def claim_onboarding_link(data: InvitationClaim):
|
||||
},
|
||||
)
|
||||
### UPDATE USER OBJECT ###
|
||||
hash_password = hash_token(token=data.password)
|
||||
hashed_pw = hash_password(data.password)
|
||||
user_obj = await prisma_client.db.litellm_usertable.update(
|
||||
where={"user_id": invite_obj.user_id}, data={"password": hash_password}
|
||||
where={"user_id": invite_obj.user_id}, data={"password": hashed_pw}
|
||||
)
|
||||
|
||||
if user_obj is None:
|
||||
@ -11705,6 +11716,8 @@ async def claim_onboarding_link(data: InvitationClaim):
|
||||
},
|
||||
)
|
||||
|
||||
if user_obj and hasattr(user_obj, "__dict__"):
|
||||
user_obj.__dict__.pop("password", None)
|
||||
return user_obj
|
||||
|
||||
|
||||
@ -12143,7 +12156,7 @@ async def invitation_delete(
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def update_config(config_info: ConfigYAML): # noqa: PLR0915
|
||||
async def update_config(config_info: ConfigYAML, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)): # noqa: PLR0915
|
||||
"""
|
||||
For Admin UI - allows admin to update config via UI
|
||||
|
||||
@ -12151,6 +12164,8 @@ async def update_config(config_info: ConfigYAML): # noqa: PLR0915
|
||||
"""
|
||||
global llm_router, llm_model_list, general_settings, proxy_config, proxy_logging_obj, master_key, prisma_client
|
||||
try:
|
||||
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
|
||||
raise HTTPException(status_code=403, detail="Only proxy admins can update config")
|
||||
import base64
|
||||
|
||||
"""
|
||||
|
||||
@ -66,6 +66,15 @@ async def spend_key_fn():
|
||||
)
|
||||
|
||||
|
||||
def _strip_password_from_users(users) -> None:
|
||||
"""Strip password field from a list of user objects."""
|
||||
for user in users if isinstance(users, list) else [users]:
|
||||
if user and hasattr(user, "__dict__"):
|
||||
user.__dict__.pop("password", None)
|
||||
elif isinstance(user, dict):
|
||||
user.pop("password", None)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/spend/users",
|
||||
tags=["Budget & Spend Tracking"],
|
||||
@ -105,13 +114,15 @@ async def spend_user_fn(
|
||||
user_info = await prisma_client.get_data(
|
||||
table_name="user", query_type="find_unique", user_id=user_id
|
||||
)
|
||||
return [user_info]
|
||||
result = [user_info]
|
||||
else:
|
||||
user_info = await prisma_client.get_data(
|
||||
table_name="user", query_type="find_all"
|
||||
)
|
||||
result = user_info
|
||||
|
||||
return user_info
|
||||
_strip_password_from_users(result)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
||||
@ -4554,6 +4554,66 @@ def hash_token(token: str):
|
||||
return hashed_token
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
"""Hash a password using scrypt with a random salt."""
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
salt = os.urandom(16)
|
||||
dk = hashlib.scrypt(password.encode(), salt=salt, n=16384, r=8, p=1, dklen=32)
|
||||
return "scrypt:" + base64.b64encode(salt + dk).decode()
|
||||
|
||||
|
||||
def verify_password(password: str, stored: str) -> bool:
|
||||
"""Verify a password against a stored hash. Supports scrypt and SHA256."""
|
||||
import base64
|
||||
import hashlib
|
||||
import secrets
|
||||
|
||||
if stored.startswith("scrypt:"):
|
||||
try:
|
||||
raw = base64.b64decode(stored[7:])
|
||||
salt, dk = raw[:16], raw[16:]
|
||||
dk2 = hashlib.scrypt(password.encode(), salt=salt, n=16384, r=8, p=1, dklen=32)
|
||||
return secrets.compare_digest(dk, dk2)
|
||||
except Exception:
|
||||
return False
|
||||
# SHA256 fallback (not vulnerable to pass-the-hash: checks sha256(input) == stored)
|
||||
if len(stored) == 64 and all(c in "0123456789abcdef" for c in stored):
|
||||
return secrets.compare_digest(
|
||||
hashlib.sha256(password.encode()).hexdigest().encode(), stored.encode()
|
||||
)
|
||||
return False
|
||||
|
||||
|
||||
async def migrate_passwords_to_scrypt_async(prisma_client) -> str:
|
||||
"""
|
||||
Migrate plaintext passwords in the DB to scrypt. SHA256 passwords
|
||||
are left alone (they migrate on next login via the SHA256 fallback).
|
||||
Skips quickly if no plaintext passwords exist.
|
||||
"""
|
||||
all_with_pw = await prisma_client.db.litellm_usertable.find_many(
|
||||
where={"password": {"not": None}},
|
||||
)
|
||||
def _is_sha256_hex(s: str) -> bool:
|
||||
return len(s) == 64 and all(c in "0123456789abcdef" for c in s)
|
||||
|
||||
plaintext_users = [
|
||||
u for u in all_with_pw
|
||||
if u.password and not u.password.startswith("scrypt:") and not _is_sha256_hex(u.password)
|
||||
]
|
||||
if not plaintext_users:
|
||||
return "No plaintext passwords found"
|
||||
|
||||
for user in plaintext_users:
|
||||
await prisma_client.db.litellm_usertable.update(
|
||||
where={"user_id": user.user_id},
|
||||
data={"password": hash_password(user.password)},
|
||||
)
|
||||
return f"Migrated {len(plaintext_users)} plaintext passwords to scrypt"
|
||||
|
||||
|
||||
def _hash_token_if_needed(token: str) -> str:
|
||||
"""
|
||||
Hash the token if it's a string and starts with "sk-"
|
||||
|
||||
@ -1495,10 +1495,6 @@ def client(original_function): # noqa: PLR0915
|
||||
)
|
||||
logging_obj._llm_caching_handler = _llm_caching_handler
|
||||
|
||||
# CHECK FOR 'os.environ/' in kwargs
|
||||
for k, v in kwargs.items():
|
||||
if v is not None and isinstance(v, str) and v.startswith("os.environ/"):
|
||||
kwargs[k] = litellm.get_secret(v)
|
||||
# [OPTIONAL] CHECK BUDGET
|
||||
if litellm.max_budget:
|
||||
if litellm._current_cost > litellm.max_budget:
|
||||
|
||||
@ -2787,7 +2787,9 @@ async def test_update_config_success_callback_normalization():
|
||||
|
||||
# Update config with mixed-case callbacks - expect normalization to lowercase
|
||||
config_update = ConfigYAML(litellm_settings={"success_callback": ["SQS", "sQs"]})
|
||||
await proxy_server.update_config(config_update)
|
||||
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
|
||||
admin_user = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-test")
|
||||
await proxy_server.update_config(config_update, user_api_key_dict=admin_user)
|
||||
|
||||
saved = mock_proxy_config.saved_config
|
||||
assert saved is not None, "save_config was not called"
|
||||
|
||||
69
tests/test_litellm/proxy/auth/test_password_hashing.py
Normal file
69
tests/test_litellm/proxy/auth/test_password_hashing.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""Tests for password hashing and verification utilities."""
|
||||
|
||||
import hashlib
|
||||
|
||||
import pytest
|
||||
|
||||
from litellm.proxy.utils import hash_password, verify_password
|
||||
|
||||
|
||||
class TestHashPassword:
|
||||
def test_produces_scrypt_prefix(self):
|
||||
assert hash_password("test").startswith("scrypt:")
|
||||
|
||||
def test_unique_salt_per_call(self):
|
||||
assert hash_password("same") != hash_password("same")
|
||||
|
||||
def test_output_length(self):
|
||||
# "scrypt:" (7) + base64(48 bytes) (64) = 71
|
||||
assert len(hash_password("test")) == 71
|
||||
|
||||
|
||||
class TestVerifyPassword:
|
||||
def test_correct_password(self):
|
||||
h = hash_password("correct")
|
||||
assert verify_password("correct", h) is True
|
||||
|
||||
def test_wrong_password(self):
|
||||
h = hash_password("correct")
|
||||
assert verify_password("wrong", h) is False
|
||||
|
||||
def test_empty_password(self):
|
||||
h = hash_password("")
|
||||
assert verify_password("", h) is True
|
||||
assert verify_password("notempty", h) is False
|
||||
|
||||
def test_unicode_password(self):
|
||||
h = hash_password("pässwörd")
|
||||
assert verify_password("pässwörd", h) is True
|
||||
assert verify_password("password", h) is False
|
||||
|
||||
def test_long_password(self):
|
||||
pw = "a" * 1000
|
||||
h = hash_password(pw)
|
||||
assert verify_password(pw, h) is True
|
||||
|
||||
|
||||
class TestVerifyPasswordFallbacks:
|
||||
def test_sha256_fallback(self):
|
||||
stored = hashlib.sha256("oldpass".encode()).hexdigest()
|
||||
assert verify_password("oldpass", stored) is True
|
||||
assert verify_password("wrong", stored) is False
|
||||
|
||||
def test_no_plaintext_fallback(self):
|
||||
# Plaintext fallback removed to prevent pass-the-hash attacks
|
||||
assert verify_password("plaintext", "plaintext") is False
|
||||
|
||||
def test_scrypt_preferred_over_fallbacks(self):
|
||||
h = hash_password("test")
|
||||
# Scrypt hash should not accidentally match as plaintext or SHA256
|
||||
assert verify_password("test", h) is True
|
||||
assert h.startswith("scrypt:")
|
||||
|
||||
def test_sha256_not_confused_with_plaintext(self):
|
||||
# A 64-char hex string that isn't a valid SHA256 of the password
|
||||
fake_hex = "a" * 64
|
||||
assert verify_password("test", fake_hex) is False
|
||||
|
||||
def test_scrypt_invalid_base64_rejected(self):
|
||||
assert verify_password("test", "scrypt:not-valid-base64!!!") is False
|
||||
Loading…
Reference in New Issue
Block a user