Merge pull request #24823 from jaydns/fixes

chore: fixes
This commit is contained in:
Krrish Dholakia 2026-03-30 19:35:40 -07:00 committed by GitHub
commit 5d6c76aa1a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 205 additions and 22 deletions

View File

@ -169,9 +169,6 @@ def get_llm_provider( # noqa: PLR0915
return remainder, custom_llm_provider, dynamic_api_key, api_base
return model, custom_llm_provider, dynamic_api_key, api_base
if api_key and api_key.startswith("os.environ/"):
dynamic_api_key = get_secret_str(api_key)
# Check JSON-configured providers FIRST (before enum-based provider_list)
provider_prefix = model.split("/", 1)[0]
if len(model.split("/")) > 1 and JSONProviderRegistry.exists(provider_prefix):

View File

@ -7,6 +7,7 @@ JWT token must have 'litellm_proxy_admin' in scope.
"""
import fnmatch
import hashlib
import os
import re
from typing import Any, List, Literal, Optional, Set, Tuple, cast
@ -621,7 +622,7 @@ class JWTHandler:
# Check cache first
cache_key = (
f"oidc_userinfo_{token[:20]}" # Use first 20 chars of token as cache key
f"oidc_userinfo_{hashlib.sha256(token.encode()).hexdigest()}"
)
cached_userinfo = await self.user_api_key_cache.async_get_cache(cache_key)

View File

@ -13,6 +13,7 @@ from fastapi import HTTPException
import litellm
from litellm.constants import LITELLM_PROXY_ADMIN_NAME, LITELLM_UI_SESSION_DURATION
from litellm.proxy.utils import hash_password, verify_password
from litellm.proxy._types import (
LiteLLM_UserTable,
LitellmUserRoles,
@ -34,6 +35,18 @@ from litellm.secret_managers.main import get_secret_bool
from litellm.types.proxy.ui_sso import ReturnedUITokenObject
async def _rehash_password_if_needed(user_id: str, password: str, stored: str) -> None:
"""Rehash legacy password (SHA256) to scrypt on successful login."""
if stored.startswith("scrypt:"):
return
from litellm.proxy.proxy_server import prisma_client
if prisma_client is not None:
await prisma_client.db.litellm_usertable.update(
where={"user_id": user_id},
data={"password": hash_password(password)},
)
def get_ui_credentials(master_key: Optional[str]) -> tuple[str, str]:
"""
Get UI username and password from environment variables or master key.
@ -254,13 +267,8 @@ async def authenticate_user( # noqa: PLR0915
code=401,
)
# check if password == _user_row.password
hash_password = hash_token(token=password)
if secrets.compare_digest(
password.encode("utf-8"), _password.encode("utf-8")
) or secrets.compare_digest(
hash_password.encode("utf-8"), _password.encode("utf-8")
):
if verify_password(password, _password):
await _rehash_password_if_needed(_user_row.user_id, password, _password)
if os.getenv("DATABASE_URL") is not None:
response = await generate_key_helper_fn(
request_type="key",

View File

@ -41,7 +41,7 @@ from litellm.proxy.management_endpoints.key_management_endpoints import (
prepare_metadata_fields,
)
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import handle_exception_on_proxy
from litellm.proxy.utils import handle_exception_on_proxy, hash_password
from litellm.types.proxy.management_endpoints.common_daily_activity import (
SpendAnalyticsPaginatedResponse,
)
@ -58,6 +58,22 @@ if TYPE_CHECKING:
router = APIRouter()
def _hash_password_in_dict(data: dict) -> None:
"""Hash password field in-place if present."""
if "password" in data and data["password"] is not None:
data["password"] = hash_password(data["password"])
def _strip_password_from_response(response) -> None:
"""Strip password from API response (handles dicts, nested dicts, and Prisma models)."""
if isinstance(response, dict):
response.pop("password", None)
if isinstance(response.get("data"), dict):
response["data"].pop("password", None)
elif hasattr(response.get("data"), "__dict__"):
response["data"].__dict__.pop("password", None)
def _update_internal_new_user_params(data_json: dict, data: NewUserRequest) -> dict:
if "user_id" in data_json and data_json["user_id"] is None:
data_json["user_id"] = str(uuid.uuid4())
@ -438,6 +454,7 @@ async def new_user(
data_json = data.json() # type: ignore
data_json = _update_internal_new_user_params(data_json, data)
_hash_password_in_dict(data_json)
teams = data.teams
if teams is None:
teams = check_if_default_team_set()
@ -723,6 +740,8 @@ async def user_info(
_user_info = (
user_info.model_dump() if isinstance(user_info, BaseModel) else user_info
)
if isinstance(_user_info, dict):
_user_info.pop("password", None)
response_data = UserInfoResponse(
user_id=user_id, user_info=_user_info, keys=returned_keys, teams=team_list
)
@ -950,6 +969,8 @@ async def _get_user_info_for_proxy_admin(user_api_key_dict: UserAPIKeyAuth):
if isinstance(admin_user_info, BaseModel)
else admin_user_info
)
if isinstance(admin_user_info, dict):
admin_user_info.pop("password", None)
return UserInfoResponse(
user_id=admin_user_id,
@ -1089,6 +1110,8 @@ async def _update_single_user_helper(
data_json=data_json, data=user_request
)
_hash_password_in_dict(non_default_values)
# Get existing user data for audit logging and metadata preparation
existing_user_row: Optional[BaseModel] = None
if user_request.user_id:
@ -1205,6 +1228,7 @@ async def _update_single_user_helper(
status_code=400,
detail={"error": "Failed to update user"},
)
_strip_password_from_response(response)
return response

View File

@ -503,7 +503,9 @@ from litellm.proxy.utils import (
get_error_message_str,
get_server_root_path,
handle_exception_on_proxy,
hash_password,
hash_token,
migrate_passwords_to_scrypt_async,
model_dump_with_preserved_fields,
update_spend,
)
@ -870,6 +872,15 @@ async def proxy_startup_event(app: FastAPI): # noqa: PLR0915
user_api_key_cache=user_api_key_cache,
)
if prisma_client is not None:
async def _run_pw_migration():
try:
result = await migrate_passwords_to_scrypt_async(prisma_client)
verbose_proxy_logger.info(f"Password migration: {result}")
except Exception as e:
verbose_proxy_logger.warning(f"Password migration skipped: {e}")
asyncio.create_task(_run_pw_migration())
ProxyStartupEvent._initialize_startup_logging(
llm_router=llm_router,
proxy_logging_obj=proxy_logging_obj,
@ -11683,9 +11694,9 @@ async def claim_onboarding_link(data: InvitationClaim):
},
)
### UPDATE USER OBJECT ###
hash_password = hash_token(token=data.password)
hashed_pw = hash_password(data.password)
user_obj = await prisma_client.db.litellm_usertable.update(
where={"user_id": invite_obj.user_id}, data={"password": hash_password}
where={"user_id": invite_obj.user_id}, data={"password": hashed_pw}
)
if user_obj is None:
@ -11705,6 +11716,8 @@ async def claim_onboarding_link(data: InvitationClaim):
},
)
if user_obj and hasattr(user_obj, "__dict__"):
user_obj.__dict__.pop("password", None)
return user_obj
@ -12143,7 +12156,7 @@ async def invitation_delete(
dependencies=[Depends(user_api_key_auth)],
include_in_schema=False,
)
async def update_config(config_info: ConfigYAML): # noqa: PLR0915
async def update_config(config_info: ConfigYAML, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)): # noqa: PLR0915
"""
For Admin UI - allows admin to update config via UI
@ -12151,6 +12164,8 @@ async def update_config(config_info: ConfigYAML): # noqa: PLR0915
"""
global llm_router, llm_model_list, general_settings, proxy_config, proxy_logging_obj, master_key, prisma_client
try:
if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
raise HTTPException(status_code=403, detail="Only proxy admins can update config")
import base64
"""

View File

@ -66,6 +66,15 @@ async def spend_key_fn():
)
def _strip_password_from_users(users) -> None:
"""Strip password field from a list of user objects."""
for user in users if isinstance(users, list) else [users]:
if user and hasattr(user, "__dict__"):
user.__dict__.pop("password", None)
elif isinstance(user, dict):
user.pop("password", None)
@router.get(
"/spend/users",
tags=["Budget & Spend Tracking"],
@ -105,13 +114,15 @@ async def spend_user_fn(
user_info = await prisma_client.get_data(
table_name="user", query_type="find_unique", user_id=user_id
)
return [user_info]
result = [user_info]
else:
user_info = await prisma_client.get_data(
table_name="user", query_type="find_all"
)
result = user_info
return user_info
_strip_password_from_users(result)
return result
except Exception as e:
raise HTTPException(

View File

@ -4554,6 +4554,66 @@ def hash_token(token: str):
return hashed_token
def hash_password(password: str) -> str:
"""Hash a password using scrypt with a random salt."""
import base64
import hashlib
import os
salt = os.urandom(16)
dk = hashlib.scrypt(password.encode(), salt=salt, n=16384, r=8, p=1, dklen=32)
return "scrypt:" + base64.b64encode(salt + dk).decode()
def verify_password(password: str, stored: str) -> bool:
"""Verify a password against a stored hash. Supports scrypt and SHA256."""
import base64
import hashlib
import secrets
if stored.startswith("scrypt:"):
try:
raw = base64.b64decode(stored[7:])
salt, dk = raw[:16], raw[16:]
dk2 = hashlib.scrypt(password.encode(), salt=salt, n=16384, r=8, p=1, dklen=32)
return secrets.compare_digest(dk, dk2)
except Exception:
return False
# SHA256 fallback (not vulnerable to pass-the-hash: checks sha256(input) == stored)
if len(stored) == 64 and all(c in "0123456789abcdef" for c in stored):
return secrets.compare_digest(
hashlib.sha256(password.encode()).hexdigest().encode(), stored.encode()
)
return False
async def migrate_passwords_to_scrypt_async(prisma_client) -> str:
"""
Migrate plaintext passwords in the DB to scrypt. SHA256 passwords
are left alone (they migrate on next login via the SHA256 fallback).
Skips quickly if no plaintext passwords exist.
"""
all_with_pw = await prisma_client.db.litellm_usertable.find_many(
where={"password": {"not": None}},
)
def _is_sha256_hex(s: str) -> bool:
return len(s) == 64 and all(c in "0123456789abcdef" for c in s)
plaintext_users = [
u for u in all_with_pw
if u.password and not u.password.startswith("scrypt:") and not _is_sha256_hex(u.password)
]
if not plaintext_users:
return "No plaintext passwords found"
for user in plaintext_users:
await prisma_client.db.litellm_usertable.update(
where={"user_id": user.user_id},
data={"password": hash_password(user.password)},
)
return f"Migrated {len(plaintext_users)} plaintext passwords to scrypt"
def _hash_token_if_needed(token: str) -> str:
"""
Hash the token if it's a string and starts with "sk-"

View File

@ -1495,10 +1495,6 @@ def client(original_function): # noqa: PLR0915
)
logging_obj._llm_caching_handler = _llm_caching_handler
# CHECK FOR 'os.environ/' in kwargs
for k, v in kwargs.items():
if v is not None and isinstance(v, str) and v.startswith("os.environ/"):
kwargs[k] = litellm.get_secret(v)
# [OPTIONAL] CHECK BUDGET
if litellm.max_budget:
if litellm._current_cost > litellm.max_budget:

View File

@ -2787,7 +2787,9 @@ async def test_update_config_success_callback_normalization():
# Update config with mixed-case callbacks - expect normalization to lowercase
config_update = ConfigYAML(litellm_settings={"success_callback": ["SQS", "sQs"]})
await proxy_server.update_config(config_update)
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
admin_user = UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN, api_key="sk-test")
await proxy_server.update_config(config_update, user_api_key_dict=admin_user)
saved = mock_proxy_config.saved_config
assert saved is not None, "save_config was not called"

View File

@ -0,0 +1,69 @@
"""Tests for password hashing and verification utilities."""
import hashlib
import pytest
from litellm.proxy.utils import hash_password, verify_password
class TestHashPassword:
def test_produces_scrypt_prefix(self):
assert hash_password("test").startswith("scrypt:")
def test_unique_salt_per_call(self):
assert hash_password("same") != hash_password("same")
def test_output_length(self):
# "scrypt:" (7) + base64(48 bytes) (64) = 71
assert len(hash_password("test")) == 71
class TestVerifyPassword:
def test_correct_password(self):
h = hash_password("correct")
assert verify_password("correct", h) is True
def test_wrong_password(self):
h = hash_password("correct")
assert verify_password("wrong", h) is False
def test_empty_password(self):
h = hash_password("")
assert verify_password("", h) is True
assert verify_password("notempty", h) is False
def test_unicode_password(self):
h = hash_password("pässwörd")
assert verify_password("pässwörd", h) is True
assert verify_password("password", h) is False
def test_long_password(self):
pw = "a" * 1000
h = hash_password(pw)
assert verify_password(pw, h) is True
class TestVerifyPasswordFallbacks:
def test_sha256_fallback(self):
stored = hashlib.sha256("oldpass".encode()).hexdigest()
assert verify_password("oldpass", stored) is True
assert verify_password("wrong", stored) is False
def test_no_plaintext_fallback(self):
# Plaintext fallback removed to prevent pass-the-hash attacks
assert verify_password("plaintext", "plaintext") is False
def test_scrypt_preferred_over_fallbacks(self):
h = hash_password("test")
# Scrypt hash should not accidentally match as plaintext or SHA256
assert verify_password("test", h) is True
assert h.startswith("scrypt:")
def test_sha256_not_confused_with_plaintext(self):
# A 64-char hex string that isn't a valid SHA256 of the password
fake_hex = "a" * 64
assert verify_password("test", fake_hex) is False
def test_scrypt_invalid_base64_rejected(self):
assert verify_password("test", "scrypt:not-valid-base64!!!") is False