Merge remote-tracking branch 'origin/litellm_internal_staging' into litellm_/sweet-mcclintock-2b3656

This commit is contained in:
Yuneng Jiang 2026-05-07 13:46:53 -07:00
commit 4189d78a64
No known key found for this signature in database
17 changed files with 1073 additions and 183 deletions

View File

@ -388,6 +388,7 @@ anthropic_beta_headers_url: str = os.getenv(
suppress_debug_info = False
dynamodb_table_name: Optional[str] = None
s3_callback_params: Optional[Dict] = None
s3_audit_callback_params: Optional[Dict] = None
datadog_llm_observability_params: Optional[Union[DatadogLLMObsInitParams, Dict]] = None
datadog_params: Optional[Union[DatadogInitParams, Dict]] = None
aws_sqs_callback_params: Optional[Dict] = None

View File

@ -16,6 +16,7 @@ from litellm._logging import print_verbose, verbose_logger
from litellm.constants import DEFAULT_S3_BATCH_SIZE, DEFAULT_S3_FLUSH_INTERVAL_SECONDS
from litellm.integrations.s3 import get_s3_object_key
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
from litellm.litellm_core_utils.sensitive_data_masker import SensitiveDataMasker
from litellm.llms.bedrock.base_aws_llm import BaseAWSLLM
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
@ -53,15 +54,25 @@ class S3Logger(CustomBatchLogger, BaseAWSLLM):
s3_strip_base64_files: bool = False,
s3_use_key_prefix: bool = False,
s3_use_virtual_hosted_style: bool = False,
s3_callback_params_override: Optional[dict] = None,
**kwargs,
):
try:
verbose_logger.debug(
f"in init s3 logger - s3_callback_params {litellm.s3_callback_params}"
)
_masker = SensitiveDataMasker()
if s3_callback_params_override is not None:
verbose_logger.debug(
f"in init s3 logger (audit override) - "
f"{_masker.mask_dict(dict(s3_callback_params_override))}"
)
else:
verbose_logger.debug(
f"in init s3 logger - s3_callback_params "
f"{_masker.mask_dict(dict(litellm.s3_callback_params or {}))}"
)
# Initialize S3 params first to get the correct s3_verify value
self._init_s3_params(
params_source=s3_callback_params_override,
s3_bucket_name=s3_bucket_name,
s3_region_name=s3_region_name,
s3_api_version=s3_api_version,
@ -139,94 +150,85 @@ class S3Logger(CustomBatchLogger, BaseAWSLLM):
s3_strip_base64_files: bool = False,
s3_use_key_prefix: bool = False,
s3_use_virtual_hosted_style: bool = False,
params_source: Optional[dict] = None,
):
"""
Initialize the s3 params for this logging callback
Initialize the s3 params for this logging callback. Reads from
`params_source` if given (e.g. `s3_audit_callback_params` for the
audit-log instance), otherwise falls back to `litellm.s3_callback_params`.
Resolves `os.environ/X` markers into a local dict; never mutates the source.
"""
litellm.s3_callback_params = litellm.s3_callback_params or {}
# read in .env variables - example os.environ/AWS_BUCKET_NAME
for key, value in litellm.s3_callback_params.items():
if isinstance(value, str) and value.startswith("os.environ/"):
litellm.s3_callback_params[key] = litellm.get_secret(value)
if params_source is None:
params_source = litellm.s3_callback_params or {}
params: dict = {
key: (
litellm.get_secret(value)
if isinstance(value, str) and value.startswith("os.environ/")
else value
)
for key, value in params_source.items()
}
self.s3_bucket_name = (
litellm.s3_callback_params.get("s3_bucket_name") or s3_bucket_name
)
self.s3_region_name = (
litellm.s3_callback_params.get("s3_region_name") or s3_region_name
)
self.s3_api_version = (
litellm.s3_callback_params.get("s3_api_version") or s3_api_version
)
self.s3_bucket_name = params.get("s3_bucket_name") or s3_bucket_name
self.s3_region_name = params.get("s3_region_name") or s3_region_name
self.s3_api_version = params.get("s3_api_version") or s3_api_version
self.s3_use_ssl = (
litellm.s3_callback_params.get("s3_use_ssl", True)
if litellm.s3_callback_params.get("s3_use_ssl") is not None
params.get("s3_use_ssl", True)
if params.get("s3_use_ssl") is not None
else s3_use_ssl
)
self.s3_verify = (
litellm.s3_callback_params.get("s3_verify")
if litellm.s3_callback_params.get("s3_verify") is not None
params.get("s3_verify")
if params.get("s3_verify") is not None
else s3_verify
)
self.s3_endpoint_url = (
litellm.s3_callback_params.get("s3_endpoint_url") or s3_endpoint_url
)
self.s3_endpoint_url = params.get("s3_endpoint_url") or s3_endpoint_url
self.s3_aws_access_key_id = (
litellm.s3_callback_params.get("s3_aws_access_key_id")
or s3_aws_access_key_id
params.get("s3_aws_access_key_id") or s3_aws_access_key_id
)
self.s3_aws_secret_access_key = (
litellm.s3_callback_params.get("s3_aws_secret_access_key")
or s3_aws_secret_access_key
params.get("s3_aws_secret_access_key") or s3_aws_secret_access_key
)
self.s3_aws_session_token = (
litellm.s3_callback_params.get("s3_aws_session_token")
or s3_aws_session_token
params.get("s3_aws_session_token") or s3_aws_session_token
)
self.s3_aws_session_name = (
litellm.s3_callback_params.get("s3_aws_session_name") or s3_aws_session_name
params.get("s3_aws_session_name") or s3_aws_session_name
)
self.s3_aws_profile_name = (
litellm.s3_callback_params.get("s3_aws_profile_name") or s3_aws_profile_name
params.get("s3_aws_profile_name") or s3_aws_profile_name
)
self.s3_aws_role_name = (
litellm.s3_callback_params.get("s3_aws_role_name") or s3_aws_role_name
)
self.s3_aws_role_name = params.get("s3_aws_role_name") or s3_aws_role_name
self.s3_aws_web_identity_token = (
litellm.s3_callback_params.get("s3_aws_web_identity_token")
or s3_aws_web_identity_token
params.get("s3_aws_web_identity_token") or s3_aws_web_identity_token
)
self.s3_aws_sts_endpoint = (
litellm.s3_callback_params.get("s3_aws_sts_endpoint") or s3_aws_sts_endpoint
params.get("s3_aws_sts_endpoint") or s3_aws_sts_endpoint
)
self.s3_config = litellm.s3_callback_params.get("s3_config") or s3_config
self.s3_path = litellm.s3_callback_params.get("s3_path") or s3_path
# done reading litellm.s3_callback_params
self.s3_config = params.get("s3_config") or s3_config
self.s3_path = params.get("s3_path") or s3_path
self.s3_use_team_prefix = (
bool(litellm.s3_callback_params.get("s3_use_team_prefix", False))
or s3_use_team_prefix
bool(params.get("s3_use_team_prefix", False)) or s3_use_team_prefix
)
self.s3_use_key_prefix = (
bool(litellm.s3_callback_params.get("s3_use_key_prefix", False))
or s3_use_key_prefix
bool(params.get("s3_use_key_prefix", False)) or s3_use_key_prefix
)
self.s3_strip_base64_files = (
bool(litellm.s3_callback_params.get("s3_strip_base64_files", False))
or s3_strip_base64_files
bool(params.get("s3_strip_base64_files", False)) or s3_strip_base64_files
)
self.s3_use_virtual_hosted_style = (
bool(litellm.s3_callback_params.get("s3_use_virtual_hosted_style", False))
bool(params.get("s3_use_virtual_hosted_style", False))
or s3_use_virtual_hosted_style
)

View File

@ -436,12 +436,21 @@ def update_messages_with_model_file_ids(
"""
Updates messages with model file ids.
For managed files (unified file IDs), uses model_file_id_mapping if it
resolves the id, otherwise decodes the base64-encoded unified file ID
and extracts the llm_output_file_id directly. Mirrors the Responses-API
sibling `update_responses_input_with_model_file_ids`.
model_file_id_mapping: Dict[str, Dict[str, str]] = {
"litellm_proxy/file_id": {
"model_id": "provider_file_id"
}
}
"""
from litellm.proxy.openai_files_endpoints.common_utils import (
_is_base64_encoded_unified_file_id,
convert_b64_uid_to_unified_uid,
)
for message in messages:
if message.get("role") == "user":
@ -450,7 +459,13 @@ def update_messages_with_model_file_ids(
if isinstance(content, str):
continue
for c in content:
if c["type"] == "file":
if not isinstance(c, dict):
# Content list items aren't always dicts. e.g.
# text_completion forwards a token-ids list/list-of-
# lists through this path. Skip non-dict items
# instead of indexing into them.
continue
if c.get("type") == "file":
file_object = cast(ChatCompletionFileObject, c)
file_object_file_field = file_object.get("file")
if not isinstance(file_object_file_field, dict):
@ -468,9 +483,23 @@ def update_messages_with_model_file_ids(
if file_id:
provider_file_id = (
model_file_id_mapping.get(file_id, {}).get(model_id)
or file_id
if model_file_id_mapping
else None
)
if (
not provider_file_id
and _is_base64_encoded_unified_file_id(file_id)
):
unified_file_id = convert_b64_uid_to_unified_uid(
file_id
)
if "llm_output_file_id," in unified_file_id:
provider_file_id = unified_file_id.split(
"llm_output_file_id,"
)[1].split(";")[0]
file_object_file_field["file_id"] = (
provider_file_id or file_id
)
file_object_file_field["file_id"] = provider_file_id
if format:
file_object_file_field["format"] = format
return messages

View File

@ -1459,14 +1459,14 @@ def completion( # type: ignore # noqa: PLR0915
if eos_token:
custom_prompt_dict[model]["eos_token"] = eos_token
if kwargs.get("model_file_id_mapping"):
messages = update_messages_with_model_file_ids(
messages=messages,
model_id=kwargs.get("model_info", {}).get("id", None),
model_file_id_mapping=cast(
Dict[str, Dict[str, str]], kwargs.get("model_file_id_mapping")
),
)
messages = update_messages_with_model_file_ids(
messages=messages,
model_id=kwargs.get("model_info", {}).get("id", None),
model_file_id_mapping=cast(
Dict[str, Dict[str, str]],
kwargs.get("model_file_id_mapping") or {},
),
)
provider_config: Optional[BaseConfig] = None
if custom_llm_provider is not None and custom_llm_provider in [

View File

@ -1512,7 +1512,7 @@ class ProxyBaseLLMRequestProcessing:
status_code=result.status_code,
headers=HttpPassThroughEndpointHelpers.get_response_headers(
headers=result.headers,
custom_headers=None,
custom_headers=dict(fastapi_response.headers),
),
)

View File

@ -46,25 +46,46 @@ def get_audit_log_changed_by(
def _resolve_audit_log_callback(name: str) -> Optional[CustomLogger]:
"""Resolve a string callback name to a CustomLogger instance, with caching."""
"""Resolve a string callback name to a CustomLogger instance, with caching.
For "s3_v2" with `litellm.s3_audit_callback_params` set, constructs a
dedicated `S3Logger` so audit logs can target a different bucket than the
normal-log singleton served by `_init_custom_logger_compatible_class`.
"""
if name in _audit_log_callback_cache:
return _audit_log_callback_cache[name]
from litellm.litellm_core_utils.litellm_logging import (
_init_custom_logger_compatible_class,
)
instance: Optional[CustomLogger]
if (
name == "s3_v2"
and getattr(litellm, "s3_audit_callback_params", None) is not None
):
from litellm.integrations.s3_v2 import S3Logger as S3V2Logger
instance = _init_custom_logger_compatible_class(
logging_integration=name, # type: ignore
internal_usage_cache=None,
llm_router=None,
)
instance = S3V2Logger(
s3_callback_params_override=litellm.s3_audit_callback_params
)
else:
from litellm.litellm_core_utils.litellm_logging import (
_init_custom_logger_compatible_class,
)
instance = _init_custom_logger_compatible_class(
logging_integration=name, # type: ignore
internal_usage_cache=None,
llm_router=None,
)
if instance is not None:
_audit_log_callback_cache[name] = instance
return instance
def reset_audit_log_callback_cache() -> None:
    """Drop every cached audit-log callback instance.

    Empties the module-level ``_audit_log_callback_cache`` in place, so the
    next ``_resolve_audit_log_callback`` lookup has to build its instance
    again. Intended to be called on config reload.
    """
    _audit_log_callback_cache.clear()
def _build_audit_log_payload(
request_data: LiteLLM_AuditLogs,
) -> StandardAuditLogPayload:

View File

@ -3823,6 +3823,11 @@ class ProxyConfig:
f"{blue_color_code} Initialized Failure Callbacks - {litellm.failure_callback} {reset_color_code}"
) # noqa
elif key == "audit_log_callbacks":
from litellm.proxy.management_helpers.audit_logs import (
reset_audit_log_callback_cache,
)
reset_audit_log_callback_cache()
litellm.audit_log_callbacks = []
for callback in value:
@ -3904,6 +3909,21 @@ class ProxyConfig:
f"{blue_color_code} setting litellm.{key}={value}{reset_color_code}"
)
setattr(litellm, key, value)
if key in {"s3_audit_callback_params", "s3_callback_params"}:
from litellm.proxy.management_helpers.audit_logs import (
reset_audit_log_callback_cache,
)
from litellm.litellm_core_utils.litellm_logging import (
_in_memory_loggers,
)
from litellm.integrations.s3_v2 import S3Logger as S3V2Logger
reset_audit_log_callback_cache()
_in_memory_loggers[:] = [
cb
for cb in _in_memory_loggers
if not isinstance(cb, S3V2Logger)
]
## GENERAL SERVER SETTINGS (e.g. master key,..) # do this after initializing litellm, to ensure sentry logging works for proxylogging
general_settings = config.get("general_settings", {})

View File

@ -7,7 +7,9 @@ from __future__ import annotations
import atexit
import hashlib
import json
import os
import re
import sys
from typing import Iterable
@ -74,6 +76,17 @@ FILTERED_RESPONSE_HEADERS = (
"date",
)
# Tiny placeholder used to replace base64 image payloads in cassettes.
# Decodes to b"test" — short, valid base64 so test code that decodes
# the field still succeeds.
VCR_IMAGE_B64_PLACEHOLDER = "dGVzdA=="
# Fixed boundary substituted into multipart request bodies so the
# ``safe_body`` matcher sees the same bytes across record and replay.
# httpx generates a fresh random boundary per request via os.urandom,
# which otherwise turns every multipart cassette into a permanent miss.
VCR_FIXED_MULTIPART_BOUNDARY = "vcr-static-boundary"
def _scrub_response(response):
if not isinstance(response, dict):
@ -86,8 +99,88 @@ def _scrub_response(response):
return response
def _replace_b64_json_in_place(obj) -> bool:
    """Swap every long ``b64_json`` string in a JSON tree for the placeholder.

    Walks dicts and lists recursively and mutates them in place. Only
    string values stored under the key ``b64_json`` that are longer than
    the placeholder are rewritten, which keeps the function idempotent:
    a value that already is the placeholder is left alone.

    Returns ``True`` if at least one value was rewritten, else ``False``.
    """
    if isinstance(obj, list):
        # Build the full list first so every element is visited (no
        # short-circuit), then report whether any of them changed.
        outcomes = [_replace_b64_json_in_place(element) for element in obj]
        return any(outcomes)

    if not isinstance(obj, dict):
        return False

    rewritten = False
    threshold = len(VCR_IMAGE_B64_PLACEHOLDER)
    for field, current in obj.items():
        is_image_payload = (
            field == "b64_json"
            and isinstance(current, str)
            and len(current) > threshold
        )
        if is_image_payload:
            # Re-assigning an existing key does not resize the dict, so it
            # is safe while iterating over ``items()``.
            obj[field] = VCR_IMAGE_B64_PLACEHOLDER
            rewritten = True
        elif _replace_b64_json_in_place(current):
            rewritten = True
    return rewritten
def _strip_image_b64_payloads(response):
    """Replace ``b64_json`` payloads in image-gen responses before save.

    Image-edit and image-generation responses carry the full base64
    PNG/JPEG (1-10+ MB) in ``data[*].b64_json``. The image_gen tests only
    assert response shape (the field decodes, the schema validates); they
    never inspect pixel content. Swapping in a tiny placeholder preserves
    those checks while shrinking cassettes by ~99%.

    The response dict is mutated in place and returned. It is returned
    untouched when it is not a dict, has no dict ``body``, the body string
    is missing / not text / not UTF-8 / not JSON, or when the JSON holds no
    oversized ``b64_json`` value. When the body is rewritten, every
    ``Content-Length`` header is updated to the new byte length.
    """
    if not isinstance(response, dict):
        return response
    body = response.get("body")
    if not isinstance(body, dict):
        return response

    raw = body.get("string")
    if isinstance(raw, str):
        text, was_bytes = raw, False
    elif isinstance(raw, (bytes, bytearray)):
        try:
            text = bytes(raw).decode("utf-8")
        except UnicodeDecodeError:
            return response
        was_bytes = True
    else:
        # Covers a missing body string (None) and any unexpected type.
        return response

    try:
        payload = json.loads(text)
    except (ValueError, TypeError):
        return response

    if not _replace_b64_json_in_place(payload):
        return response

    compact = json.dumps(payload, separators=(",", ":"))
    encoded = compact.encode("utf-8")
    # Hand the body back in the same type (bytes vs. str) it arrived in.
    body["string"] = encoded if was_bytes else compact

    headers = response.get("headers")
    if isinstance(headers, dict):
        length = str(len(encoded))
        length_keys = [k for k in headers if str(k).lower() == "content-length"]
        for name in length_keys:
            # Keep the header's container shape: list-valued stays a list.
            headers[name] = [length] if isinstance(headers[name], list) else length
    return response
def _before_record_response(response):
return filter_non_2xx_response(_scrub_response(response))
return filter_non_2xx_response(_scrub_response(_strip_image_b64_payloads(response)))
def _safe_body_matcher(r1, r2) -> None:
@ -172,8 +265,84 @@ def _strip_headers(headers, names: Iterable[str]) -> None:
pass
def _normalize_multipart_boundary(request) -> None:
    """Rewrite a random multipart boundary to a fixed string, in place.

    httpx generates a fresh ``boundary=<random hex>`` for every multipart
    request via ``os.urandom``. Without normalization, the request body
    bytes differ across runs even when everything else is identical, the
    ``safe_body`` matcher misses, and the persister keeps appending new
    episodes until ``MAX_EPISODES_PER_CASSETTE`` refuses the save, leaving
    audio-transcription tests effectively unmocked. Replacing the boundary
    in both the Content-Type header and the body makes the request
    deterministic.

    Idempotent: vcrpy invokes this hook multiple times per request, so a
    later invocation sees ``boundary=vcr-static-boundary`` already and
    short-circuits.

    Args:
        request: Object exposing a mapping-like ``headers`` attribute and a
            ``body`` attribute (``bytes``, ``bytearray``, ``str`` or
            ``None``). Anything else is left untouched.

    Returns:
        None. ``request.headers`` and ``request.body`` are updated in place.
    """
    headers = getattr(request, "headers", None)
    if headers is None:
        return

    content_type_key = None
    raw_value = None
    try:
        for key in list(headers.keys()):
            if str(key).lower() == "content-type":
                content_type_key = key
                raw_value = headers[key]
                break
    except AttributeError:
        return

    # Only str/bytes values can be rewritten faithfully. Stringifying any
    # other type (e.g. a list) would match against its repr and write a
    # corrupted header back, so leave those requests alone.
    if isinstance(raw_value, str):
        content_type_value = raw_value
        header_is_bytes = False
    elif isinstance(raw_value, (bytes, bytearray)):
        # latin-1 round-trips every byte, so decoding cannot fail.
        content_type_value = bytes(raw_value).decode("latin-1")
        header_is_bytes = True
    else:
        return

    if not content_type_value or "multipart/" not in content_type_value.lower():
        return

    fixed_param = f"boundary={VCR_FIXED_MULTIPART_BOUNDARY}"
    if fixed_param in content_type_value:
        return

    # MIME parameter names are case-insensitive.
    match = re.search(r"boundary=([^\s;]+)", content_type_value, re.IGNORECASE)
    if not match:
        return

    current_boundary = match.group(1).strip('"')
    # An empty boundary would make ``replace`` insert the fixed boundary
    # between every byte of the body; treat it as "nothing to normalize".
    if not current_boundary or current_boundary == VCR_FIXED_MULTIPART_BOUNDARY:
        return

    new_header = content_type_value.replace(match.group(0), fixed_param)
    try:
        # Write the header back in the type it was read in.
        headers[content_type_key] = (
            new_header.encode("latin-1") if header_is_bytes else new_header
        )
    except (TypeError, AttributeError):
        return

    body = getattr(request, "body", None)
    if body is None:
        return

    if isinstance(body, (bytes, bytearray)):
        try:
            new_body = bytes(body).replace(
                current_boundary.encode("utf-8"),
                VCR_FIXED_MULTIPART_BOUNDARY.encode("utf-8"),
            )
        except (TypeError, ValueError):
            return
    elif isinstance(body, str):
        new_body = body.replace(current_boundary, VCR_FIXED_MULTIPART_BOUNDARY)
    else:
        return

    try:
        request.body = new_body
    except (AttributeError, TypeError):
        # Read-only body attribute: nothing more can be done here.
        pass
def _before_record_request(request):
"""Fingerprint API keys, then scrub them.
"""Fingerprint API keys, scrub them, and normalize multipart boundaries.
Order matters in two ways:
@ -187,7 +356,8 @@ def _before_record_request(request):
auth headers we already stripped, so re-hashing would yield
``"no-key"`` and the stored vs. incoming fingerprints would
diverge. Skip the recompute when the header is already set so
this hook is idempotent.
this hook is idempotent. The boundary normalizer is also
idempotent for the same reason.
"""
headers = getattr(request, "headers", None)
if headers is None:
@ -199,6 +369,7 @@ def _before_record_request(request):
except (TypeError, AttributeError):
pass
_strip_headers(headers, FILTERED_REQUEST_HEADERS)
_normalize_multipart_boundary(request)
return request

View File

@ -853,7 +853,11 @@ class BaseLLMChatTest(ABC):
@pytest.mark.parametrize(
"image_url",
[
"http://img1.etsystatic.com/260/0/7813604/il_fullxfull.4226713999_q86e.jpg",
# In-repo logo served via jsdelivr (sha-pinned, immutable).
# Bedrock fetches the URL and base64-embeds it in the
# Converse request body; using a multi-MB hosted product
# photo here previously bloated cassettes to ~60 MB each.
"https://cdn.jsdelivr.net/gh/BerriAI/litellm@d769e81c90d453240c61fc572cdb27fae06a89d0/ui/litellm-dashboard/public/assets/logos/litellm_logo.jpg",
"https://awsmp-logos.s3.amazonaws.com/seller-xw5kijmvmzasy/c233c9ade2ccb5491072ae232c814942.png",
],
)

View File

@ -2,6 +2,7 @@
Tests for Evals API operations across providers
"""
import hashlib
import os
import sys
from abc import ABC, abstractmethod
@ -20,6 +21,46 @@ from litellm.types.llms.openai_evals import (
)
def _stable_eval_name(test_node_name: str, suffix: str = "") -> str:
"""Deterministic eval name keyed off the test's node name.
The previous ``f"Test Eval {int(time.time())}"`` pattern embedded a
fresh value into the request body every run, defeating VCR's
``safe_body`` matcher and forcing a real OpenAI ``create`` call on
every CI run. With a stable per-test name the cassette matches on
replay, and provider-side resources stay bounded because each test
deletes the eval it owns on teardown.
"""
nonce = hashlib.sha1(test_node_name.encode()).hexdigest()[:12]
return f"vcr-managed-{nonce}{suffix}"
# Grader definition passed as ``testing_criteria`` to ``litellm.create_eval``
# by the tests in this module: a single ``label_model`` grader that labels a
# statement as "positive" or "negative", with "positive" as the passing label.
_TESTING_CRITERIA = [
{
"type": "label_model",
"model": "gpt-4o",
"input": [
{
"role": "developer",
"content": "Classify the sentiment as 'positive' or 'negative'",
},
{"role": "user", "content": "Statement: {{item.input}}"},
],
"passing_labels": ["positive"],
"labels": ["positive", "negative"],
"name": "Sentiment grader",
}
]
# Exception types the tests in this module catch around ``create_eval`` and
# turn into ``pytest.skip("Provider service unavailable")`` instead of a
# test failure.
_PROVIDER_FLAKINESS = (
litellm.InternalServerError,
litellm.APIConnectionError,
litellm.Timeout,
litellm.ServiceUnavailableError,
)
class BaseEvalsAPITest(ABC):
"""
Base test class for Evals API operations.
@ -41,13 +82,64 @@ class BaseEvalsAPITest(ABC):
"""Return the API base URL for the provider"""
pass
@pytest.fixture
def managed_eval(self, request):
    """Yield a freshly created, stable-named eval; delete it on teardown.

    Function-scoped so each cassette captures the full
    create -> test -> delete cycle. A class-scoped fixture would push the
    create into whichever test ran first and the delete into whichever
    ran last, which is fragile under reordering.

    Replaces the prior ``list_evals().data[0].id`` pattern, which made
    the URL of ``get_eval`` / ``update_eval`` vary across runs (the
    "first" eval depends on what other runs left behind).

    Skips the test when no API key is configured, or when creation fails
    with a rate limit or one of the ``_PROVIDER_FLAKINESS`` errors.
    """
    provider = self.get_custom_llm_provider()
    key = self.get_api_key()
    base = self.get_api_base()

    if not key:
        pytest.skip(f"No API key provided for {provider}")

    # Shared by the create and delete calls below.
    provider_kwargs = {
        "custom_llm_provider": provider,
        "api_key": key,
        "api_base": base,
    }

    try:
        created = litellm.create_eval(
            name=_stable_eval_name(request.node.name),
            data_source_config={
                "type": "stored_completions",
                "metadata": {"usecase": "chatbot", "vcr": "managed"},
            },
            testing_criteria=_TESTING_CRITERIA,
            **provider_kwargs,
        )
    except _PROVIDER_FLAKINESS:
        pytest.skip("Provider service unavailable")
    except litellm.RateLimitError:
        pytest.skip("Rate limit exceeded")

    yield created

    # Best-effort cleanup. OpenAI eval names are not unique-keyed (only
    # IDs are), so a failed delete doesn't block the next run's create.
    try:
        litellm.delete_eval(eval_id=created.id, **provider_kwargs)
    except Exception:
        pass
@pytest.mark.flaky(retries=3, delay=2)
def test_create_eval(self):
def test_create_eval(self, request):
"""
Test creating an evaluation.
"""
import time
custom_llm_provider = self.get_custom_llm_provider()
api_key = self.get_api_key()
api_base = self.get_api_base()
@ -56,53 +148,45 @@ class BaseEvalsAPITest(ABC):
pytest.skip(f"No API key provided for {custom_llm_provider}")
litellm.set_verbose = True
unique_name = _stable_eval_name(request.node.name)
# Create eval with stored_completions data source
unique_name = f"Test Eval {int(time.time())}"
created_id = None
try:
response = litellm.create_eval(
name=unique_name,
data_source_config={
"type": "stored_completions",
"metadata": {"usecase": "chatbot"},
},
testing_criteria=[
{
"type": "label_model",
"model": "gpt-4o",
"input": [
{
"role": "developer",
"content": "Classify the sentiment as 'positive' or 'negative'",
},
{"role": "user", "content": "Statement: {{item.input}}"},
],
"passing_labels": ["positive"],
"labels": ["positive", "negative"],
"name": "Sentiment grader",
}
],
custom_llm_provider=custom_llm_provider,
api_key=api_key,
api_base=api_base,
)
except (
litellm.InternalServerError,
litellm.APIConnectionError,
litellm.Timeout,
litellm.ServiceUnavailableError,
):
pytest.skip("Provider service unavailable")
except litellm.RateLimitError:
pytest.skip("Rate limit exceeded")
try:
response = litellm.create_eval(
name=unique_name,
data_source_config={
"type": "stored_completions",
"metadata": {"usecase": "chatbot"},
},
testing_criteria=_TESTING_CRITERIA,
custom_llm_provider=custom_llm_provider,
api_key=api_key,
api_base=api_base,
)
except _PROVIDER_FLAKINESS:
pytest.skip("Provider service unavailable")
except litellm.RateLimitError:
pytest.skip("Rate limit exceeded")
assert response is not None
assert isinstance(response, Eval)
assert response.id is not None
assert response.name == unique_name
print(f"Created eval: {response}")
print(f"Eval ID: {response.id}")
assert response is not None
assert isinstance(response, Eval)
assert response.id is not None
assert response.name == unique_name
created_id = response.id
print(f"Created eval: {response}")
print(f"Eval ID: {response.id}")
finally:
if created_id is not None:
try:
litellm.delete_eval(
eval_id=created_id,
custom_llm_provider=custom_llm_provider,
api_key=api_key,
api_base=api_base,
)
except Exception:
pass
def test_list_evals(self):
"""
@ -130,7 +214,7 @@ class BaseEvalsAPITest(ABC):
assert hasattr(response, "has_more")
print(f"Listed evals: {len(response.data)} evaluations")
def test_get_eval(self):
def test_get_eval(self, managed_eval):
"""
Test getting a specific evaluation by ID.
"""
@ -138,89 +222,54 @@ class BaseEvalsAPITest(ABC):
api_key = self.get_api_key()
api_base = self.get_api_base()
if not api_key:
pytest.skip(f"No API key provided for {custom_llm_provider}")
litellm.set_verbose = True
# First list existing evals to get an ID
list_response = litellm.list_evals(
limit=1,
response = litellm.get_eval(
eval_id=managed_eval.id,
custom_llm_provider=custom_llm_provider,
api_key=api_key,
api_base=api_base,
)
assert isinstance(list_response, ListEvalsResponse)
assert response is not None
assert isinstance(response, Eval)
assert response.id == managed_eval.id
print(f"Retrieved eval: {response}")
if list_response.data and len(list_response.data) > 0:
eval_id = list_response.data[0].id
print(f"Testing with eval ID: {eval_id}")
# Get the eval
response = litellm.get_eval(
eval_id=eval_id,
custom_llm_provider=custom_llm_provider,
api_key=api_key,
api_base=api_base,
)
assert response is not None
assert isinstance(response, Eval)
assert response.id == eval_id
print(f"Retrieved eval: {response}")
else:
pytest.skip("No existing evals to test with")
def test_update_eval(self):
@pytest.mark.flaky(retries=3, delay=2)
def test_update_eval(self, request, managed_eval):
"""
Test updating an evaluation.
"""
import time
custom_llm_provider = self.get_custom_llm_provider()
api_key = self.get_api_key()
api_base = self.get_api_base()
if not api_key:
pytest.skip(f"No API key provided for {custom_llm_provider}")
litellm.set_verbose = True
updated_name = _stable_eval_name(request.node.name, suffix="-updated")
# First list existing evals
list_response = litellm.list_evals(
limit=1,
response = litellm.update_eval(
eval_id=managed_eval.id,
name=updated_name,
custom_llm_provider=custom_llm_provider,
api_key=api_key,
api_base=api_base,
)
assert isinstance(list_response, ListEvalsResponse)
if list_response.data and len(list_response.data) > 0:
eval_id = list_response.data[0].id
updated_name = f"Updated Eval {int(time.time())}"
# Update the eval
response = litellm.update_eval(
eval_id=eval_id,
name=updated_name,
custom_llm_provider=custom_llm_provider,
api_key=api_key,
api_base=api_base,
)
assert response is not None
assert isinstance(response, Eval)
assert response.id == eval_id
assert response.name == updated_name
print(f"Updated eval: {response}")
else:
pytest.skip("No existing evals to test with")
assert response is not None
assert isinstance(response, Eval)
assert response.id == managed_eval.id
assert response.name == updated_name
print(f"Updated eval: {response}")
def test_delete_eval(self):
"""
Test deleting an evaluation.
Real delete coverage now lives in the ``managed_eval`` fixture
teardown and in ``test_create_eval``'s ``finally`` block, so
this stays a no-op skip rather than creating a fresh resource
just to delete it.
"""
custom_llm_provider = self.get_custom_llm_provider()
api_key = self.get_api_key()
@ -229,8 +278,7 @@ class BaseEvalsAPITest(ABC):
if not api_key:
pytest.skip(f"No API key provided for {custom_llm_provider}")
# Skip this test to avoid deleting production evals
pytest.skip("Skipping delete test to preserve existing evals")
pytest.skip("Delete is exercised via managed_eval fixture teardown.")
class TestOpenAIEvalsAPI(BaseEvalsAPITest):

View File

@ -0,0 +1,220 @@
"""Unit tests for the VCR record-time filters that keep cassettes small.
Covers:
- ``_strip_image_b64_payloads`` replaces base64 image bodies in
image-gen responses so cassettes don't carry MB-class PNG payloads.
- ``_normalize_multipart_boundary`` rewrites random multipart
boundaries to a fixed string so audio-transcription request bodies
match across record and replay.
"""
from __future__ import annotations
import json
import os
import sys
from vcr.request import Request
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from tests._vcr_conftest_common import ( # noqa: E402
VCR_FIXED_MULTIPART_BOUNDARY,
VCR_IMAGE_B64_PLACEHOLDER,
_normalize_multipart_boundary,
_strip_image_b64_payloads,
)
# ---------------------------------------------------------------------------
# Image b64 stripper
# ---------------------------------------------------------------------------
def _image_response(b64_payload: str, body_type: str = "bytes") -> dict:
body_text = json.dumps({"data": [{"b64_json": b64_payload}]})
body_string = body_text.encode("utf-8") if body_type == "bytes" else body_text
return {
"status": {"code": 200, "message": "OK"},
"headers": {
"content-type": ["application/json"],
"content-length": [str(len(body_text.encode("utf-8")))],
},
"body": {"string": body_string},
}
def test_strip_image_b64_replaces_payload_when_body_is_bytes():
    """A bytes body gets its ``b64_json`` value swapped for the placeholder."""
    stripped = _strip_image_b64_payloads(
        _image_response("A" * 5000, body_type="bytes")
    )
    decoded = json.loads(stripped["body"]["string"].decode("utf-8"))
    assert decoded["data"][0]["b64_json"] == VCR_IMAGE_B64_PLACEHOLDER
def test_strip_image_b64_replaces_payload_when_body_is_str():
    """A str body gets its ``b64_json`` value swapped for the placeholder."""
    stripped = _strip_image_b64_payloads(
        _image_response("A" * 5000, body_type="str")
    )
    decoded = json.loads(stripped["body"]["string"])
    assert decoded["data"][0]["b64_json"] == VCR_IMAGE_B64_PLACEHOLDER
def test_strip_image_b64_updates_content_length():
    """``content-length`` is recomputed to match the rewritten body."""
    stripped = _strip_image_b64_payloads(_image_response("A" * 5000))
    body_size = len(stripped["body"]["string"])
    assert stripped["headers"]["content-length"] == [str(body_size)]
def test_strip_image_b64_is_idempotent():
    """A second pass over an already-stripped response changes nothing.

    ``_strip_image_b64_payloads`` mutates and returns the dict it was
    given, so the results of both passes are the same object. The state
    after the first pass is therefore snapshotted before the second call;
    comparing the two return values directly would compare a value with
    itself.
    """
    response = _image_response("A" * 5000)

    once = _strip_image_b64_payloads(response)
    body_after_first = once["body"]["string"]
    length_after_first = list(once["headers"]["content-length"])

    twice = _strip_image_b64_payloads(once)

    assert twice["body"]["string"] == body_after_first
    assert twice["headers"]["content-length"] == length_after_first
def test_strip_image_b64_handles_nested_data():
    """Every nested ``b64_json`` is replaced; sibling fields are kept."""
    images = [
        {"b64_json": "X" * 4000, "label": "first"},
        {"b64_json": "Y" * 4000, "label": "second"},
    ]
    encoded_body = json.dumps({"outer": {"data": images}}).encode("utf-8")
    response = {
        "status": {"code": 200, "message": "OK"},
        "headers": {"content-type": ["application/json"]},
        "body": {"string": encoded_body},
    }

    stripped = _strip_image_b64_payloads(response)

    data = json.loads(stripped["body"]["string"].decode("utf-8"))["outer"]["data"]
    assert data[0]["b64_json"] == VCR_IMAGE_B64_PLACEHOLDER
    assert data[1]["b64_json"] == VCR_IMAGE_B64_PLACEHOLDER
    assert data[0]["label"] == "first"
def test_strip_image_b64_leaves_non_image_response_unchanged():
    """A JSON body without any ``b64_json`` field keeps its content."""
    original_text = json.dumps({"choices": [{"message": {"content": "hello"}}]})
    response = {
        "status": {"code": 200, "message": "OK"},
        "headers": {"content-type": ["application/json"]},
        "body": {"string": original_text.encode("utf-8")},
    }

    result = _strip_image_b64_payloads(response)

    after = json.loads(result["body"]["string"].decode("utf-8"))
    assert after == json.loads(original_text)
def test_strip_image_b64_leaves_invalid_json_unchanged():
    """A body that is not JSON passes through byte-for-byte."""
    response = {
        "status": {"code": 200, "message": "OK"},
        "headers": {"content-type": ["application/octet-stream"]},
        "body": {"string": b"\x89PNG\r\n\x1a\n binary stuff not json"},
    }

    result = _strip_image_b64_payloads(response)

    assert result["body"]["string"] == b"\x89PNG\r\n\x1a\n binary stuff not json"
def test_strip_image_b64_skips_short_values():
    """A `b64_json` value that is already the placeholder is still the
    placeholder after stripping (idempotency guard)."""
    already_stripped = {"data": [{"b64_json": VCR_IMAGE_B64_PLACEHOLDER}]}
    response = {
        "status": {"code": 200, "message": "OK"},
        "headers": {"content-type": ["application/json"]},
        "body": {"string": json.dumps(already_stripped).encode("utf-8")},
    }

    stripped = _strip_image_b64_payloads(response)

    decoded = json.loads(stripped["body"]["string"].decode("utf-8"))
    assert decoded["data"][0]["b64_json"] == VCR_IMAGE_B64_PLACEHOLDER
# ---------------------------------------------------------------------------
# Multipart boundary normalizer
# ---------------------------------------------------------------------------
def _multipart_request(boundary: str):
    """Build a POST `Request` whose multipart/form-data body holds a single
    fake audio part delimited by *boundary*."""
    segments = (
        f"--{boundary}\r\n",
        'Content-Disposition: form-data; name="file"; filename="audio.wav"\r\n',
        "Content-Type: audio/wav\r\n",
        "\r\n",
        "fake-audio-bytes\r\n",
        f"--{boundary}--\r\n",
    )
    encoded_body = "".join(segments).encode("utf-8")
    content_type = f"multipart/form-data; boundary={boundary}"
    return Request(
        method="POST",
        uri="https://api.openai.com/v1/audio/transcriptions",
        body=encoded_body,
        headers={"content-type": content_type},
    )
def test_normalize_multipart_rewrites_header_and_body():
    """Both the content-type header and the body switch from the random
    boundary to the fixed one."""
    request = _multipart_request("abc123random")

    _normalize_multipart_boundary(request)

    expected_header = f"multipart/form-data; boundary={VCR_FIXED_MULTIPART_BOUNDARY}"
    assert request.headers["content-type"] == expected_header
    assert b"abc123random" not in request.body
    assert VCR_FIXED_MULTIPART_BOUNDARY.encode("utf-8") in request.body
def test_normalize_multipart_is_idempotent():
    """Normalizing a second time changes neither the body nor the header."""
    request = _multipart_request("abc123random")

    _normalize_multipart_boundary(request)
    body_after_first = request.body
    header_after_first = request.headers["content-type"]

    _normalize_multipart_boundary(request)
    assert body_after_first == request.body
    assert header_after_first == request.headers["content-type"]
def test_normalize_multipart_two_distinct_boundaries_match_after_normalize():
    """Whisper-style: requests built with different random boundaries must
    have equal bodies and content-type headers once normalized."""
    requests = [_multipart_request(name) for name in ("boundaryAAA", "boundaryBBB")]
    for request in requests:
        _normalize_multipart_boundary(request)

    left, right = requests
    assert left.body == right.body
    assert left.headers["content-type"] == right.headers["content-type"]
def test_normalize_multipart_skips_non_multipart_requests():
    """A JSON request keeps its content-type header and body untouched."""
    json_request = Request(
        method="POST",
        uri="https://api.openai.com/v1/chat/completions",
        body=b'{"model":"gpt-4o"}',
        headers={"content-type": "application/json"},
    )

    _normalize_multipart_boundary(json_request)

    assert json_request.headers["content-type"] == "application/json"
    assert json_request.body == b'{"model":"gpt-4o"}'
def test_normalize_multipart_skips_request_without_content_type():
    """With no content-type header present, the body is left as it was."""
    headerless_request = Request(
        method="POST",
        uri="https://api.openai.com/v1/chat/completions",
        body=b"unknown body",
        headers={},
    )

    _normalize_multipart_boundary(headerless_request)

    assert headerless_request.body == b"unknown body"
def test_normalize_multipart_handles_quoted_boundary():
    """A boundary given in double quotes in the header is still replaced
    in the body by the fixed boundary."""
    quoted_header = 'multipart/form-data; boundary="quoted-boundary"'
    request = Request(
        method="POST",
        uri="https://api.openai.com/v1/audio/transcriptions",
        body=b"--quoted-boundary--body content--quoted-boundary--",
        headers={"content-type": quoted_header},
    )

    _normalize_multipart_boundary(request)

    assert b"quoted-boundary" not in request.body
    assert VCR_FIXED_MULTIPART_BOUNDARY.encode("utf-8") in request.body

View File

@ -101,6 +101,7 @@ _SCALAR_ATTRS = (
"redact_messages_in_exceptions",
"redact_user_api_key_info",
"s3_callback_params",
"s3_audit_callback_params",
"datadog_params",
"vector_store_registry",
)
@ -128,6 +129,7 @@ def isolate_litellm_state():
leaking across tests within the same xdist worker.
"""
from litellm.litellm_core_utils import litellm_logging as ll_logging
from litellm.proxy.management_helpers import audit_logs as ll_audit_logs
# Flush cache and clear internal logger instances before test
if hasattr(litellm, "in_memory_llm_clients_cache"):
@ -135,6 +137,7 @@ def isolate_litellm_state():
# Clear cached logger instances (LangsmithLogger, SlackAlerting, etc.)
ll_logging._in_memory_loggers.clear()
ll_audit_logs._audit_log_callback_cache.clear()
# Reset ALL attrs to their true defaults before the test runs.
# This undoes any module-level mutations from test file imports.
@ -156,6 +159,7 @@ def isolate_litellm_state():
litellm.in_memory_llm_clients_cache.flush_cache()
ll_logging._in_memory_loggers.clear()
ll_audit_logs._audit_log_callback_cache.clear()
for attr in _LIST_ATTRS:
if attr in _DEFAULTS:

View File

@ -12,7 +12,15 @@ from abc import ABC, abstractmethod
# Test resources
TEST_IMAGE_PATH = "test_image_edit.png"
TEST_PDF_URL = "https://arxiv.org/pdf/2201.04234"
# Tiny in-repo PDF served via jsdelivr (sha-pinned, immutable). The arxiv
# PDF previously used here was several MB — once base64-encoded into the
# Vertex OCR request it ballooned cassettes past 100 MB per test. Keep
# the URL stable across runs so cassettes don't churn.
TEST_PDF_URL = (
"https://cdn.jsdelivr.net/gh/BerriAI/litellm"
"@d769e81c90d453240c61fc572cdb27fae06a89d0"
"/tests/llm_translation/fixtures/dummy.pdf"
)
class BaseOCRTest(ABC):

View File

@ -1123,3 +1123,74 @@ async def test_combined_prefix_reflects_in_s3_object_key():
result = logger.create_s3_batch_logging_element(datetime.utcnow(), payload)
key = result.s3_object_key
assert "myteam/apikey/" in key, f"Expected both prefixes in key: {key}"
# --------------------------------------------------------------
# params_source / s3_callback_params_override (audit-log decoupling)
# --------------------------------------------------------------
def test_s3_callback_params_override_uses_alternate_dict():
    """With `s3_callback_params_override` given, the logger's bucket, path
    and region come from that dict rather than `litellm.s3_callback_params`."""
    import litellm

    audit_params = {
        "s3_bucket_name": "audit-bucket",
        "s3_path": "audit-prefix",
        "s3_region_name": "us-west-2",
    }
    saved_global = litellm.s3_callback_params
    litellm.s3_callback_params = {"s3_bucket_name": "normal-bucket"}
    try:
        s3_logger = S3Logger(s3_callback_params_override=audit_params)
        assert s3_logger.s3_bucket_name == "audit-bucket"
        assert s3_logger.s3_path == "audit-prefix"
        assert s3_logger.s3_region_name == "us-west-2"
    finally:
        # Always put the module-level config back, even if an assert fails.
        litellm.s3_callback_params = saved_global
def test_s3_callback_params_override_does_not_mutate_inputs(monkeypatch):
    """The logger sees the resolved env value, while the `os.environ/X`
    marker stays in both the override dict and `litellm.s3_callback_params`."""
    import litellm

    env_marker = "os.environ/MY_AUDIT_BUCKET"
    monkeypatch.setenv("MY_AUDIT_BUCKET", "resolved-bucket")
    override = {"s3_bucket_name": env_marker}
    saved_global = litellm.s3_callback_params
    litellm.s3_callback_params = {"s3_bucket_name": env_marker}
    try:
        s3_logger = S3Logger(s3_callback_params_override=override)
        assert s3_logger.s3_bucket_name == "resolved-bucket"
        # Neither input may have been rewritten with the resolved value.
        assert override["s3_bucket_name"] == "os.environ/MY_AUDIT_BUCKET"
        global_bucket = litellm.s3_callback_params["s3_bucket_name"]
        assert global_bucket == "os.environ/MY_AUDIT_BUCKET"
    finally:
        litellm.s3_callback_params = saved_global
def test_s3_callback_params_override_none_falls_back_to_global():
    """Without an override the logger reads `litellm.s3_callback_params`."""
    import litellm

    saved_global = litellm.s3_callback_params
    litellm.s3_callback_params = {"s3_bucket_name": "from-global"}
    try:
        s3_logger = S3Logger()
        assert s3_logger.s3_bucket_name == "from-global"
    finally:
        litellm.s3_callback_params = saved_global
def test_s3_callback_params_override_empty_dict_is_opt_in():
    """An empty override dict still counts as an override: the global bucket
    name is not picked up (env/IAM-only config)."""
    import litellm

    saved_global = litellm.s3_callback_params
    litellm.s3_callback_params = {"s3_bucket_name": "from-global"}
    try:
        s3_logger = S3Logger(s3_callback_params_override={})
        assert s3_logger.s3_bucket_name is None
    finally:
        litellm.s3_callback_params = saved_global

View File

@ -367,3 +367,128 @@ def test_update_messages_with_model_file_ids_skips_non_openai_file_blocks():
# Messages pass through unchanged when there is no `file` sub-dict to remap.
assert updated == messages
# Reusable fixture: a base64-encoded unified file id. It decodes to
#   litellm_proxy:application/pdf;unified_id,...;
#   target_model_names,gpt-4o;llm_output_file_id,file-ECBPW7ML9g7XHdwGgUPZaM;
#   llm_output_file_model_id,...
# The embedded llm_output_file_id is the value the decode-fallback tests
# expect to be substituted for this id.
UNIFIED_FILE_ID_B64 = (
"bGl0ZWxsbV9wcm94eTphcHBsaWNhdGlvbi9wZGY7dW5pZmllZF9pZCw2YzBiNTg5MC04OTE0"
"LTQ4ZTAtYjhmNC0wYWU1ZWQzYzE0YTU7dGFyZ2V0X21vZGVsX25hbWVzLGdwdC00bztsbG1f"
"b3V0cHV0X2ZpbGVfaWQsZmlsZS1FQ0JQVzdNTDlnN1hIZHdHZ1VQWmFNO2xsbV9vdXRwdXRf"
"ZmlsZV9tb2RlbF9pZCxlMjY0NTNmOWU3NmU3OTkzNjgwZDAwNjhkOThjMWY0Y2MyMDViYmFk"
"MDk2N2EzM2M2NjQ4OTM1NjhjYTc0M2My"
)
def test_update_messages_with_model_file_ids_decodes_unified_id_when_mapping_empty():
    """With an empty mapping (e.g. a multi-replica cache miss), the
    base64-encoded unified file id must be decoded and replaced by the
    embedded llm_output_file_id - mirroring the Responses-API sibling."""
    file_block = {
        "type": "file",
        "file": {
            "file_id": UNIFIED_FILE_ID_B64,
            "format": "audio/wav",
        },
    }
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "text", "text": "What is in this recording?"},
                file_block,
            ],
        }
    ]

    updated = update_messages_with_model_file_ids(messages, "any-model-id", {})

    resolved_file = updated[0]["content"][1]["file"]
    assert resolved_file["file_id"] == "file-ECBPW7ML9g7XHdwGgUPZaM"
    # Customer-supplied format is preserved (this is the field whose absence
    # the misleading error message used to complain about).
    assert resolved_file["format"] == "audio/wav"
def test_update_messages_with_model_file_ids_mapping_takes_precedence_over_decode():
    """If the mapping and the decode fallback could both resolve the id, the
    mapped value wins (preserves per-deployment routing precision)."""
    file_block = {
        "type": "file",
        "file": {
            "file_id": UNIFIED_FILE_ID_B64,
            "format": "application/pdf",
        },
    }
    messages = [{"role": "user", "content": [file_block]}]
    mapping = {UNIFIED_FILE_ID_B64: {"model-A": "mapped-provider-file-id"}}

    updated = update_messages_with_model_file_ids(messages, "model-A", mapping)

    resolved_id = updated[0]["content"][0]["file"]["file_id"]
    assert resolved_id == "mapped-provider-file-id"
def test_update_messages_with_model_file_ids_non_unified_passes_through():
    """A raw provider id (e.g. a gs:// URI or a random string) that the
    mapping does not resolve must come back untouched; the decode fallback
    must not corrupt non-unified ids."""
    raw_id = "gs://my-bucket/uploads/abc-123.wav"
    file_block = {"type": "file", "file": {"file_id": raw_id, "format": "audio/wav"}}
    messages = [{"role": "user", "content": [file_block]}]

    updated = update_messages_with_model_file_ids(messages, "model-A", {})

    assert updated[0]["content"][0]["file"]["file_id"] == raw_id
def test_update_messages_with_model_file_ids_mapping_miss_falls_back_to_decode():
    """A non-empty mapping that lacks this file_id should still trigger the
    decode fallback - this covers the case where the hook resolved *some*
    ids but not this one."""
    unrelated_id = "some-other-file-id"
    mapping = {unrelated_id: {"model-A": "other-provider-id"}}
    file_block = {
        "type": "file",
        "file": {"file_id": UNIFIED_FILE_ID_B64, "format": "audio/wav"},
    }
    messages = [{"role": "user", "content": [file_block]}]

    updated = update_messages_with_model_file_ids(messages, "model-A", mapping)

    resolved_id = updated[0]["content"][0]["file"]["file_id"]
    assert resolved_id == "file-ECBPW7ML9g7XHdwGgUPZaM"
def test_update_messages_with_model_file_ids_tolerates_non_dict_content_items():
    """Content list items aren't always dicts: text_completion forwards
    token ids (a list of ints, or a list of lists of ints for batches)
    through this path. Such items must be skipped, not indexed into."""
    flat_token_ids = [{"role": "user", "content": [15496, 995]}]
    batched_token_ids = [{"role": "user", "content": [[15496, 995], [9906, 0]]}]

    # Both shapes should pass through unchanged without raising.
    for messages in (flat_token_ids, batched_token_ids):
        result = update_messages_with_model_file_ids(messages, "model-A", {})
        assert result == messages

View File

@ -346,3 +346,126 @@ class TestS3LoggerAuditLogEvent:
element = logger.log_queue[0]
assert element.s3_object_key.startswith("audit_logs/")
assert "audit-456" in element.s3_object_key
class TestS3AuditCallbackParamsDecoupling:
    """Checks how the logger returned by `_resolve_audit_log_callback("s3_v2")`
    relates to the normal-log S3Logger, depending on whether
    `litellm.s3_audit_callback_params` is set."""

    @pytest.fixture(autouse=True)
    def _isolate_caches_and_globals(self):
        """Clear both logger caches around every test and restore the two
        S3 param globals afterwards."""
        from litellm.litellm_core_utils import litellm_logging as ll_logging
        from litellm.proxy.management_helpers import audit_logs as ll_audit_logs

        def _drop_cached_loggers():
            ll_audit_logs._audit_log_callback_cache.clear()
            ll_logging._in_memory_loggers.clear()

        saved_s3 = litellm.s3_callback_params
        saved_audit = getattr(litellm, "s3_audit_callback_params", None)
        _drop_cached_loggers()

        yield

        litellm.s3_callback_params = saved_s3
        litellm.s3_audit_callback_params = saved_audit
        _drop_cached_loggers()

    def test_opt_in_constructs_separate_instance_with_audit_config(self):
        """With audit params set, the audit resolver returns an S3Logger on
        the audit bucket that is a different object from the normal logger."""
        from litellm.integrations.s3_v2 import S3Logger
        from litellm.litellm_core_utils.litellm_logging import (
            _init_custom_logger_compatible_class,
        )
        from litellm.proxy.management_helpers.audit_logs import (
            _resolve_audit_log_callback,
        )

        litellm.s3_callback_params = {"s3_bucket_name": "normal-bucket"}
        litellm.s3_audit_callback_params = {"s3_bucket_name": "audit-bucket"}

        with patch("asyncio.create_task"):
            audit_logger = _resolve_audit_log_callback("s3_v2")
            normal_logger = _init_custom_logger_compatible_class(
                logging_integration="s3_v2",
                internal_usage_cache=None,
                llm_router=None,
            )

        assert isinstance(audit_logger, S3Logger)
        assert isinstance(normal_logger, S3Logger)
        assert audit_logger is not normal_logger
        assert audit_logger.s3_bucket_name == "audit-bucket"
        assert normal_logger.s3_bucket_name == "normal-bucket"

    def test_opt_out_preserves_singleton_behavior(self):
        """With `s3_audit_callback_params` unset, audit and normal logging
        resolve to the same S3Logger object (regression guard)."""
        from litellm.integrations.s3_v2 import S3Logger
        from litellm.litellm_core_utils.litellm_logging import (
            _init_custom_logger_compatible_class,
        )
        from litellm.proxy.management_helpers.audit_logs import (
            _resolve_audit_log_callback,
        )

        litellm.s3_callback_params = {"s3_bucket_name": "shared-bucket"}
        litellm.s3_audit_callback_params = None

        with patch("asyncio.create_task"):
            normal_logger = _init_custom_logger_compatible_class(
                logging_integration="s3_v2",
                internal_usage_cache=None,
                llm_router=None,
            )
            audit_logger = _resolve_audit_log_callback("s3_v2")

        assert isinstance(audit_logger, S3Logger)
        assert audit_logger is normal_logger
        assert audit_logger.s3_bucket_name == "shared-bucket"

    def test_empty_dict_opts_in(self):
        """`s3_audit_callback_params = {}` still opts in: the audit logger is
        a separate object with no bucket configured (env/IAM-only)."""
        from litellm.integrations.s3_v2 import S3Logger
        from litellm.litellm_core_utils.litellm_logging import (
            _init_custom_logger_compatible_class,
        )
        from litellm.proxy.management_helpers.audit_logs import (
            _resolve_audit_log_callback,
        )

        litellm.s3_callback_params = {"s3_bucket_name": "normal-bucket"}
        litellm.s3_audit_callback_params = {}

        with patch("asyncio.create_task"):
            audit_logger = _resolve_audit_log_callback("s3_v2")
            normal_logger = _init_custom_logger_compatible_class(
                logging_integration="s3_v2",
                internal_usage_cache=None,
                llm_router=None,
            )

        assert audit_logger is not normal_logger
        assert audit_logger.s3_bucket_name is None
        assert normal_logger.s3_bucket_name == "normal-bucket"

    def test_reset_audit_log_callback_cache_clears_audit_instance(self):
        """`reset_audit_log_callback_cache()` must evict the cached audit
        logger so the next resolve is built from the current params."""
        from litellm.proxy.management_helpers.audit_logs import (
            _audit_log_callback_cache,
            _resolve_audit_log_callback,
            reset_audit_log_callback_cache,
        )

        litellm.s3_audit_callback_params = {"s3_bucket_name": "first"}

        with patch("asyncio.create_task"):
            before_reset = _resolve_audit_log_callback("s3_v2")
            assert before_reset is not None and "s3_v2" in _audit_log_callback_cache

            reset_audit_log_callback_cache()
            assert "s3_v2" not in _audit_log_callback_cache

            # Resolve again under new params; a fresh instance is expected.
            litellm.s3_audit_callback_params = {"s3_bucket_name": "second"}
            after_reset = _resolve_audit_log_callback("s3_v2")

        assert after_reset is not None
        assert after_reset is not before_reset
        assert after_reset.s3_bucket_name == "second"

View File

@ -3,8 +3,9 @@ import datetime
from typing import AsyncGenerator
from unittest.mock import AsyncMock, MagicMock, patch
import httpx
import pytest
from fastapi import HTTPException, Request, status
from fastapi import HTTPException, Request, Response, status
from fastapi.responses import JSONResponse, StreamingResponse
import litellm
@ -26,6 +27,48 @@ from litellm.proxy.utils import ProxyLogging
class TestProxyBaseLLMRequestProcessing:
@pytest.mark.asyncio
async def test_base_passthrough_process_llm_request_preserves_litellm_headers_for_non_streaming_response(
    self, monkeypatch
):
    """Headers set on the FastAPI response by the base processor must appear
    on the passthrough result alongside the upstream response's own headers,
    with status code and body carried over."""
    processor = ProxyBaseLLMRequestProcessing(data={})

    async def fake_base_process_llm_request(**kwargs):
        # Stand-in for the real processor: stamp litellm headers on the
        # response object it was handed, then return an upstream-style reply.
        handed_response = kwargs["fastapi_response"]
        handed_response.headers["x-litellm-call-id"] = "test-call-id"
        handed_response.headers["x-litellm-version"] = "test-version"
        upstream_headers = {
            "content-type": "application/json",
            "x-amzn-requestid": "bedrock-request-id",
        }
        return httpx.Response(
            status_code=200,
            content=b'{"ok":true}',
            headers=upstream_headers,
        )

    monkeypatch.setattr(
        processor,
        "base_process_llm_request",
        fake_base_process_llm_request,
    )

    outcome = await processor.base_passthrough_process_llm_request(
        request=MagicMock(spec=Request),
        fastapi_response=Response(),
        user_api_key_dict=MagicMock(spec=UserAPIKeyAuth),
        proxy_logging_obj=MagicMock(spec=ProxyLogging),
        general_settings={},
        proxy_config=MagicMock(spec=ProxyConfig),
        select_data_generator=MagicMock(),
        model="bedrock-test-model",
    )

    assert outcome.status_code == 200
    assert outcome.body == b'{"ok":true}'
    result_headers = outcome.headers
    assert result_headers["x-amzn-requestid"] == "bedrock-request-id"
    assert result_headers["x-litellm-call-id"] == "test-call-id"
    assert result_headers["x-litellm-version"] == "test-version"
@pytest.mark.asyncio
async def test_common_processing_pre_call_logic_pre_call_hook_receives_litellm_call_id(
self, monkeypatch