test: stabilize batch VCR coverage and stop live upload/network leaks (#29477)
* test: stabilize batch VCR coverage * test: replay bedrock batch s3 uploads * test: stop batch tests leaking live uploads * test: keep bedrock batch workflow off live s3 * test: mock bedrock batch workflow network * test: accept realtime guardrail refusal wording * test: update gemini thought signature model * test: quiet logging worker atexit flush * test: address Greptile review on batch VCR fixes Handle content= bodies in the bedrock batch post stub so payload extraction does not raise a TypeError when a request omits json and data. Restore litellm list state faithfully by preserving None instead of coercing it to an empty list, so callbacks that start as None are not turned into [] after a test. Set logging.raiseExceptions inside the try block in the atexit flush so the finally always restores the previous value. * test: scope atexit logging suppression to the drain loop Wrap only the queue drain loop in LoggingWorker._flush_on_exit with the logging.raiseExceptions toggle so the process-wide global is suppressed for the smallest possible window, keeping other threads' logging error reporting intact outside the loop. * test: cover atexit flush error-swallow branch in LoggingWorker The _flush_on_exit drain loop was wrapped in a try/finally to scope the logging.raiseExceptions toggle, which reindented the existing edge-case branches into the diff and dropped patch coverage below target. Add a regression test that enqueues a coroutine which raises during the atexit flush and asserts the failure is swallowed while later queued events are still drained, exercising the silent-failure path directly.
This commit is contained in:
parent
3f33efdd57
commit
6a9f542f81
@ -3,6 +3,7 @@
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import logging
|
||||
from typing import Coroutine, Optional
|
||||
import atexit
|
||||
from typing_extensions import TypedDict
|
||||
@ -494,31 +495,43 @@ class LoggingWorker:
|
||||
processed = 0
|
||||
start_time = loop.time()
|
||||
|
||||
while not self._queue.empty() and processed < MAX_ITERATIONS_TO_CLEAR_QUEUE:
|
||||
if loop.time() - start_time >= MAX_TIME_TO_CLEAR_QUEUE:
|
||||
self._safe_log(
|
||||
"warning",
|
||||
f"[LoggingWorker] atexit: Reached time limit ({MAX_TIME_TO_CLEAR_QUEUE}s), stopping flush",
|
||||
)
|
||||
break
|
||||
# logging.raiseExceptions is a process-wide global; scope the
|
||||
# suppression to just the drain loop, where shutdown callbacks may
|
||||
# log to already-closed handler streams, so other threads keep their
|
||||
# logging error reporting for as little of the window as possible.
|
||||
previous_raise_exceptions = logging.raiseExceptions
|
||||
logging.raiseExceptions = False
|
||||
try:
|
||||
while (
|
||||
not self._queue.empty()
|
||||
and processed < MAX_ITERATIONS_TO_CLEAR_QUEUE
|
||||
):
|
||||
if loop.time() - start_time >= MAX_TIME_TO_CLEAR_QUEUE:
|
||||
self._safe_log(
|
||||
"warning",
|
||||
f"[LoggingWorker] atexit: Reached time limit ({MAX_TIME_TO_CLEAR_QUEUE}s), stopping flush",
|
||||
)
|
||||
break
|
||||
|
||||
try:
|
||||
task = self._queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
try:
|
||||
task = self._queue.get_nowait()
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
# Run the coroutine synchronously in new loop
|
||||
# Note: We run the coroutine directly, not via create_task,
|
||||
# since we're in a new event loop context
|
||||
try:
|
||||
loop.run_until_complete(task["coroutine"])
|
||||
processed += 1
|
||||
except Exception:
|
||||
# Silent failure to not break user's program
|
||||
pass
|
||||
finally:
|
||||
# Clear reference to prevent memory leaks
|
||||
task = None
|
||||
# Run the coroutine synchronously in new loop
|
||||
# Note: We run the coroutine directly, not via create_task,
|
||||
# since we're in a new event loop context
|
||||
try:
|
||||
loop.run_until_complete(task["coroutine"])
|
||||
processed += 1
|
||||
except Exception:
|
||||
# Silent failure to not break user's program
|
||||
pass
|
||||
finally:
|
||||
# Clear reference to prevent memory leaks
|
||||
task = None
|
||||
finally:
|
||||
logging.raiseExceptions = previous_raise_exceptions
|
||||
|
||||
self._safe_log(
|
||||
"info",
|
||||
|
||||
@ -53,6 +53,7 @@ CASSETTE_CACHE_HIGH_WATER_FRACTION = 0.85
|
||||
SAFE_BODY_MATCHER_NAME = "safe_body"
|
||||
KEY_FINGERPRINT_MATCHER_NAME = "key_fingerprint"
|
||||
TOLERANT_QUERY_MATCHER_NAME = "tolerant_query"
|
||||
TOLERANT_PATH_MATCHER_NAME = "tolerant_path"
|
||||
KEY_FINGERPRINT_HEADER = "x-litellm-key-fp"
|
||||
|
||||
VCR_DIAG_DIR_ENV = "LITELLM_VCR_DIAG_DIR"
|
||||
@ -411,6 +412,7 @@ def _canonical_body(request) -> tuple[bytes, str]:
|
||||
_VCR_UUID_RE = re.compile(
|
||||
rb"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
)
|
||||
_VCR_LITELLM_BATCH_JOB_RE = re.compile(rb"litellm-batch-[0-9a-fA-F]{8}")
|
||||
# ISO-8601 timestamps, e.g. ``2026-05-25T03:40:37.262045Z`` /
|
||||
# ``2026-05-25T03:40:37+00:00``.
|
||||
_VCR_ISO_TS_RE = re.compile(
|
||||
@ -436,6 +438,7 @@ def _normalize_volatile_tokens(body: bytes) -> bytes:
|
||||
if not body:
|
||||
return body
|
||||
body = _VCR_UUID_RE.sub(b"<vcr-uuid>", body)
|
||||
body = _VCR_LITELLM_BATCH_JOB_RE.sub(b"litellm-batch-<vcr-id>", body)
|
||||
body = _VCR_ISO_TS_RE.sub(b"<vcr-iso-ts>", body)
|
||||
body = _VCR_UNIX_MS_RE.sub(b"<vcr-unix-ms>", body)
|
||||
body = _VCR_UNIX_FLOAT_RE.sub(b"<vcr-unix-float>", body)
|
||||
@ -1059,6 +1062,54 @@ def _tolerant_query_matcher(r1, r2) -> None:
|
||||
_vcr_matchers.query(r1, r2)
|
||||
|
||||
|
||||
_BEDROCK_MANAGED_S3_PATH_RE = re.compile(
|
||||
r"(?P<prefix>(?:^|/)(?:litellm-bedrock-files/[^/?#]+-|litellm-bedrock-files-[^/?#]+-))"
|
||||
r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||
r"(?P<suffix>\.jsonl)"
|
||||
)
|
||||
|
||||
|
||||
def _request_path_for_matcher(request) -> str:
|
||||
path = getattr(request, "path", None)
|
||||
if path is not None:
|
||||
return str(path)
|
||||
|
||||
uri = getattr(request, "uri", None) or getattr(request, "url", "") or ""
|
||||
uri = str(uri)
|
||||
if not uri:
|
||||
return ""
|
||||
if "//" in uri:
|
||||
rest = uri.split("//", 1)[1]
|
||||
path_part = "/" + rest.split("/", 1)[1] if "/" in rest else "/"
|
||||
else:
|
||||
path_part = uri
|
||||
return path_part.split("?", 1)[0]
|
||||
|
||||
|
||||
def _normalize_volatile_path(path: str) -> str:
|
||||
return _BEDROCK_MANAGED_S3_PATH_RE.sub(
|
||||
lambda match: f"{match.group('prefix')}<vcr-uuid>{match.group('suffix')}",
|
||||
path,
|
||||
)
|
||||
|
||||
|
||||
def _tolerant_path_matcher(r1, r2) -> None:
|
||||
"""vcrpy's ``path`` matcher, plus LiteLLM-managed Bedrock S3 upload UUIDs.
|
||||
|
||||
Bedrock batch file uploads use object keys like
|
||||
``litellm-bedrock-files-{model}-{uuid}.jsonl`` (and older cassettes may
|
||||
contain ``litellm-bedrock-files/{model}-{uuid}.jsonl``). The UUID is
|
||||
generated client-side before the S3 PUT, so strict path matching makes
|
||||
every replay miss even when the JSONL body and all provider semantics are
|
||||
identical.
|
||||
"""
|
||||
path1 = _normalize_volatile_path(_request_path_for_matcher(r1))
|
||||
path2 = _normalize_volatile_path(_request_path_for_matcher(r2))
|
||||
if path1 == path2:
|
||||
return
|
||||
_vcr_matchers.path(r1, r2)
|
||||
|
||||
|
||||
def vcr_config_dict() -> dict:
|
||||
return {
|
||||
"decode_compressed_response": True,
|
||||
@ -1069,7 +1120,7 @@ def vcr_config_dict() -> dict:
|
||||
"scheme",
|
||||
"host",
|
||||
"port",
|
||||
"path",
|
||||
TOLERANT_PATH_MATCHER_NAME,
|
||||
TOLERANT_QUERY_MATCHER_NAME,
|
||||
KEY_FINGERPRINT_MATCHER_NAME,
|
||||
SAFE_BODY_MATCHER_NAME,
|
||||
@ -1136,6 +1187,7 @@ def register_persister_if_enabled(vcr) -> None:
|
||||
vcr.register_matcher(SAFE_BODY_MATCHER_NAME, _safe_body_matcher)
|
||||
vcr.register_matcher(KEY_FINGERPRINT_MATCHER_NAME, _key_fingerprint_matcher)
|
||||
vcr.register_matcher(TOLERANT_QUERY_MATCHER_NAME, _tolerant_query_matcher)
|
||||
vcr.register_matcher(TOLERANT_PATH_MATCHER_NAME, _tolerant_path_matcher)
|
||||
patch_vcrpy_aiohttp_record_path()
|
||||
patch_vcrpy_cassette_load_guard()
|
||||
global _atexit_banner_registered
|
||||
@ -1647,13 +1699,18 @@ def _is_live_call_host(host: str) -> bool:
|
||||
return False
|
||||
if any(host.endswith(suffix) for suffix in _LIVE_CALL_HOST_SUFFIXES):
|
||||
return True
|
||||
# AWS Bedrock endpoints are ``bedrock-runtime[-fips].{region}.amazonaws.com``
|
||||
# (region between ``bedrock-runtime`` and ``amazonaws.com``), so plain
|
||||
# suffix matching can't catch them.
|
||||
if host.endswith(".amazonaws.com") and host.split(".", 1)[0].startswith(
|
||||
"bedrock-runtime"
|
||||
):
|
||||
return True
|
||||
if host.endswith(".amazonaws.com"):
|
||||
first_label = host.split(".", 1)[0]
|
||||
# AWS Bedrock control/runtime endpoints are
|
||||
# ``bedrock[-runtime][-fips].{region}.amazonaws.com`` (region between
|
||||
# the service label and ``amazonaws.com``), so plain suffix matching
|
||||
# can't catch them.
|
||||
if first_label.startswith("bedrock"):
|
||||
return True
|
||||
# Bedrock batch file upload/download uses real S3. Treat those as part
|
||||
# of the paid provider path so unmarked batch tests surface as leaks.
|
||||
if first_label in {"s3", "s3-fips"} or ".s3." in host or ".s3-" in host:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@ -1,5 +1,3 @@
|
||||
# conftest.py
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
@ -11,6 +9,82 @@ sys.path.insert(
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm # noqa: E402,F401
|
||||
|
||||
from tests._vcr_conftest_common import ( # noqa: E402,F401
|
||||
VerboseReporterState,
|
||||
_pin_multipart_boundary,
|
||||
apply_vcr_auto_marker_to_items,
|
||||
emit_cassette_cache_session_banner,
|
||||
emit_vcr_classification_summary,
|
||||
emit_vcr_diagnostic_log,
|
||||
install_live_call_probe,
|
||||
record_vcr_outcome,
|
||||
register_persister_if_enabled,
|
||||
reset_vcr_diag_dir,
|
||||
vcr_config_dict,
|
||||
)
|
||||
|
||||
_verbose_state = VerboseReporterState()
|
||||
|
||||
_CALLBACK_ATTRS = (
|
||||
"callbacks",
|
||||
"success_callback",
|
||||
"failure_callback",
|
||||
"_async_success_callback",
|
||||
"_async_failure_callback",
|
||||
)
|
||||
|
||||
_SCALAR_ATTRS = (
|
||||
"num_retries",
|
||||
"set_verbose",
|
||||
"cache",
|
||||
"allowed_fails",
|
||||
"disable_aiohttp_transport",
|
||||
"force_ipv4",
|
||||
"drop_params",
|
||||
"modify_params",
|
||||
"api_base",
|
||||
"api_key",
|
||||
"cohere_key",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def vcr_config():
|
||||
return vcr_config_dict()
|
||||
|
||||
|
||||
def pytest_recording_configure(config, vcr):
|
||||
register_persister_if_enabled(vcr)
|
||||
|
||||
|
||||
@pytest.hookimpl(hookwrapper=True)
|
||||
def pytest_runtest_makereport(item, call):
|
||||
outcome = yield
|
||||
rep = outcome.get_result()
|
||||
setattr(item, f"rep_{rep.when}", rep)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _vcr_outcome_gate(request, vcr):
|
||||
install_live_call_probe(request, vcr)
|
||||
yield
|
||||
record_vcr_outcome(request, vcr)
|
||||
|
||||
|
||||
def pytest_configure(config):
|
||||
_verbose_state.remember_pluginmanager(config)
|
||||
reset_vcr_diag_dir()
|
||||
|
||||
|
||||
def pytest_runtest_logreport(report):
|
||||
_verbose_state.maybe_emit_verdict(report)
|
||||
|
||||
|
||||
def pytest_terminal_summary(terminalreporter, exitstatus, config):
|
||||
emit_cassette_cache_session_banner(terminalreporter)
|
||||
emit_vcr_classification_summary(terminalreporter)
|
||||
emit_vcr_diagnostic_log(terminalreporter)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def event_loop():
|
||||
@ -20,3 +94,64 @@ def event_loop():
|
||||
loop = asyncio.new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
def _copy_litellm_state():
|
||||
state = {}
|
||||
for attr in _CALLBACK_ATTRS:
|
||||
if hasattr(litellm, attr):
|
||||
value = getattr(litellm, attr)
|
||||
state[attr] = value.copy() if isinstance(value, list) else value
|
||||
for attr in _SCALAR_ATTRS:
|
||||
if hasattr(litellm, attr):
|
||||
state[attr] = getattr(litellm, attr)
|
||||
return state
|
||||
|
||||
|
||||
def _restore_litellm_state(state) -> None:
|
||||
for attr, value in state.items():
|
||||
if hasattr(litellm, attr):
|
||||
setattr(litellm, attr, value)
|
||||
|
||||
|
||||
def _reset_litellm_callbacks() -> None:
|
||||
for attr in _CALLBACK_ATTRS:
|
||||
if hasattr(litellm, attr):
|
||||
setattr(litellm, attr, [])
|
||||
manager = getattr(litellm, "logging_callback_manager", None)
|
||||
reset = getattr(manager, "_reset_all_callbacks", None)
|
||||
if callable(reset):
|
||||
reset()
|
||||
|
||||
|
||||
def _clear_logging_queue(loop=None) -> None:
|
||||
from litellm.litellm_core_utils.logging_worker import GLOBAL_LOGGING_WORKER
|
||||
|
||||
if loop is not None and not loop.is_closed() and not loop.is_running():
|
||||
loop.run_until_complete(GLOBAL_LOGGING_WORKER.clear_queue())
|
||||
return
|
||||
asyncio.run(GLOBAL_LOGGING_WORKER.clear_queue())
|
||||
|
||||
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def setup_and_teardown(event_loop):
|
||||
original_state = _copy_litellm_state()
|
||||
_clear_logging_queue(event_loop)
|
||||
_reset_litellm_callbacks()
|
||||
asyncio.set_event_loop(event_loop)
|
||||
|
||||
yield
|
||||
|
||||
_clear_logging_queue(event_loop)
|
||||
_reset_litellm_callbacks()
|
||||
_restore_litellm_state(original_state)
|
||||
|
||||
pending = asyncio.all_tasks(event_loop)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
if pending:
|
||||
event_loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True))
|
||||
|
||||
|
||||
def pytest_collection_modifyitems(config, items):
|
||||
apply_vcr_auto_marker_to_items(items)
|
||||
|
||||
@ -51,6 +51,12 @@ def get_expected_batch_file_usage(file_path: str) -> tuple[int, int]:
|
||||
return expected_request_count, expected_total_tokens
|
||||
|
||||
|
||||
def _write_batch_file(tmp_path, file_name: str, content: str) -> str:
|
||||
path = tmp_path / file_name
|
||||
path.write_text(content)
|
||||
return str(path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("OPENAI_API_KEY") is None,
|
||||
@ -114,7 +120,7 @@ async def test_batch_rate_limits():
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_batch_rate_limit_single_file():
|
||||
async def test_batch_rate_limit_single_file(tmp_path):
|
||||
"""
|
||||
Test batch rate limiting with a single file.
|
||||
|
||||
@ -122,8 +128,6 @@ async def test_batch_rate_limit_single_file():
|
||||
- File with < 200 tokens: should go through
|
||||
- File with > 200 tokens: should hit rate limit
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
CUSTOM_LLM_PROVIDER = "openai"
|
||||
|
||||
# Setup: Create internal usage cache and rate limiter
|
||||
@ -152,17 +156,18 @@ async def test_batch_rate_limit_single_file():
|
||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hi"}]}}
|
||||
{"custom_id": "request-3", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey"}]}}"""
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(small_batch_content)
|
||||
small_file_path = f.name
|
||||
small_file_path = _write_batch_file(
|
||||
tmp_path, "small-batch-rate-limit.jsonl", small_batch_content
|
||||
)
|
||||
|
||||
try:
|
||||
# Upload file to OpenAI
|
||||
file_obj_small = await litellm.acreate_file(
|
||||
file=open(small_file_path, "rb"),
|
||||
purpose="batch",
|
||||
custom_llm_provider=CUSTOM_LLM_PROVIDER,
|
||||
)
|
||||
with open(small_file_path, "rb") as batch_file:
|
||||
file_obj_small = await litellm.acreate_file(
|
||||
file=batch_file,
|
||||
purpose="batch",
|
||||
custom_llm_provider=CUSTOM_LLM_PROVIDER,
|
||||
)
|
||||
print(f"Created small file: {file_obj_small.id}")
|
||||
await asyncio.sleep(1) # Give API time to process
|
||||
|
||||
@ -183,8 +188,6 @@ async def test_batch_rate_limit_single_file():
|
||||
print(f" Actual tokens: {result.get('_batch_token_count')}")
|
||||
except HTTPException as e:
|
||||
pytest.fail(f"Should not have hit rate limit with small file: {e.detail}")
|
||||
finally:
|
||||
os.unlink(small_file_path)
|
||||
|
||||
# Test 2: File with > 200 tokens should hit rate limit
|
||||
print("\n=== Test 2: File over 200 tokens ===")
|
||||
@ -221,47 +224,45 @@ async def test_batch_rate_limit_single_file():
|
||||
|
||||
large_batch_content = "\n".join(requests)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(large_batch_content)
|
||||
large_file_path = f.name
|
||||
large_file_path = _write_batch_file(
|
||||
tmp_path, "large-batch-rate-limit.jsonl", large_batch_content
|
||||
)
|
||||
|
||||
try:
|
||||
# Upload file to OpenAI
|
||||
# Upload file to OpenAI
|
||||
with open(large_file_path, "rb") as batch_file:
|
||||
file_obj_large = await litellm.acreate_file(
|
||||
file=open(large_file_path, "rb"),
|
||||
file=batch_file,
|
||||
purpose="batch",
|
||||
custom_llm_provider=CUSTOM_LLM_PROVIDER,
|
||||
)
|
||||
print(f"Created large file: {file_obj_large.id}")
|
||||
await asyncio.sleep(1) # Give API time to process
|
||||
print(f"Created large file: {file_obj_large.id}")
|
||||
await asyncio.sleep(1) # Give API time to process
|
||||
|
||||
data_over_limit = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"input_file_id": file_obj_large.id,
|
||||
"custom_llm_provider": CUSTOM_LLM_PROVIDER,
|
||||
}
|
||||
data_over_limit = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"input_file_id": file_obj_large.id,
|
||||
"custom_llm_provider": CUSTOM_LLM_PROVIDER,
|
||||
}
|
||||
|
||||
# Should raise HTTPException with 429 status
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await batch_limiter.async_pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=dual_cache,
|
||||
data=data_over_limit,
|
||||
call_type="acreate_batch",
|
||||
)
|
||||
# Should raise HTTPException with 429 status
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await batch_limiter.async_pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=dual_cache,
|
||||
data=data_over_limit,
|
||||
call_type="acreate_batch",
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 429, "Should return 429 status code"
|
||||
assert (
|
||||
"tokens" in exc_info.value.detail.lower()
|
||||
), "Error message should mention tokens"
|
||||
print(f"✓ File with 250+ tokens correctly rejected (over limit of 200)")
|
||||
print(f" Error: {exc_info.value.detail}")
|
||||
finally:
|
||||
os.unlink(large_file_path)
|
||||
assert exc_info.value.status_code == 429, "Should return 429 status code"
|
||||
assert (
|
||||
"tokens" in exc_info.value.detail.lower()
|
||||
), "Error message should mention tokens"
|
||||
print(f"✓ File with 250+ tokens correctly rejected (over limit of 200)")
|
||||
print(f" Error: {exc_info.value.detail}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_batch_rate_limit_multiple_requests():
|
||||
async def test_batch_rate_limit_multiple_requests(tmp_path):
|
||||
"""
|
||||
Test batch rate limiting with multiple requests.
|
||||
|
||||
@ -269,8 +270,6 @@ async def test_batch_rate_limit_multiple_requests():
|
||||
- Request 1: file with ~100 tokens (should go through, 100/200 used)
|
||||
- Request 2: file with ~105 tokens (should hit limit, 100+105=205 > 200)
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
CUSTOM_LLM_PROVIDER = "openai"
|
||||
|
||||
# Setup: Create internal usage cache and rate limiter
|
||||
@ -313,17 +312,18 @@ async def test_batch_rate_limit_multiple_requests():
|
||||
|
||||
batch_content_1 = "\n".join(requests_1)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(batch_content_1)
|
||||
file_path_1 = f.name
|
||||
file_path_1 = _write_batch_file(
|
||||
tmp_path, "batch-rate-limit-request-1.jsonl", batch_content_1
|
||||
)
|
||||
|
||||
try:
|
||||
# Upload file to OpenAI
|
||||
file_obj_1 = await litellm.acreate_file(
|
||||
file=open(file_path_1, "rb"),
|
||||
purpose="batch",
|
||||
custom_llm_provider=CUSTOM_LLM_PROVIDER,
|
||||
)
|
||||
with open(file_path_1, "rb") as batch_file:
|
||||
file_obj_1 = await litellm.acreate_file(
|
||||
file=batch_file,
|
||||
purpose="batch",
|
||||
custom_llm_provider=CUSTOM_LLM_PROVIDER,
|
||||
)
|
||||
print(f"Created file 1: {file_obj_1.id}")
|
||||
await asyncio.sleep(1) # Give API time to process
|
||||
|
||||
@ -346,8 +346,6 @@ async def test_batch_rate_limit_multiple_requests():
|
||||
)
|
||||
except HTTPException as e:
|
||||
pytest.fail(f"Request 1 should not have hit rate limit: {e.detail}")
|
||||
finally:
|
||||
os.unlink(file_path_1)
|
||||
|
||||
# Request 2: File with ~105+ tokens (total would exceed 200)
|
||||
print("\n=== Request 2: File with ~105 tokens (should hit limit) ===")
|
||||
@ -371,43 +369,41 @@ async def test_batch_rate_limit_multiple_requests():
|
||||
|
||||
batch_content_2 = "\n".join(requests_2)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(batch_content_2)
|
||||
file_path_2 = f.name
|
||||
file_path_2 = _write_batch_file(
|
||||
tmp_path, "batch-rate-limit-request-2.jsonl", batch_content_2
|
||||
)
|
||||
|
||||
try:
|
||||
# Upload file to OpenAI
|
||||
# Upload file to OpenAI
|
||||
with open(file_path_2, "rb") as batch_file:
|
||||
file_obj_2 = await litellm.acreate_file(
|
||||
file=open(file_path_2, "rb"),
|
||||
file=batch_file,
|
||||
purpose="batch",
|
||||
custom_llm_provider=CUSTOM_LLM_PROVIDER,
|
||||
)
|
||||
print(f"Created file 2: {file_obj_2.id}")
|
||||
await asyncio.sleep(1) # Give API time to process
|
||||
print(f"Created file 2: {file_obj_2.id}")
|
||||
await asyncio.sleep(1) # Give API time to process
|
||||
|
||||
data_request2 = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"input_file_id": file_obj_2.id,
|
||||
"custom_llm_provider": CUSTOM_LLM_PROVIDER,
|
||||
}
|
||||
data_request2 = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"input_file_id": file_obj_2.id,
|
||||
"custom_llm_provider": CUSTOM_LLM_PROVIDER,
|
||||
}
|
||||
|
||||
# Should raise HTTPException with 429 status
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await batch_limiter.async_pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=dual_cache,
|
||||
data=data_request2,
|
||||
call_type="acreate_batch",
|
||||
)
|
||||
# Should raise HTTPException with 429 status
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
await batch_limiter.async_pre_call_hook(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
cache=dual_cache,
|
||||
data=data_request2,
|
||||
call_type="acreate_batch",
|
||||
)
|
||||
|
||||
assert exc_info.value.status_code == 429, "Should return 429 status code"
|
||||
assert (
|
||||
"tokens" in exc_info.value.detail.lower()
|
||||
), "Error message should mention tokens"
|
||||
print(f"✓ Request 2 correctly rejected")
|
||||
print(f" Error: {exc_info.value.detail}")
|
||||
finally:
|
||||
os.unlink(file_path_2)
|
||||
assert exc_info.value.status_code == 429, "Should return 429 status code"
|
||||
assert (
|
||||
"tokens" in exc_info.value.detail.lower()
|
||||
), "Error message should mention tokens"
|
||||
print(f"✓ Request 2 correctly rejected")
|
||||
print(f" Error: {exc_info.value.detail}")
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@ -415,7 +411,7 @@ async def test_batch_rate_limit_multiple_requests():
|
||||
os.environ.get("OPENAI_API_KEY") is None,
|
||||
reason="OPENAI_API_KEY not set - skipping integration test",
|
||||
)
|
||||
async def test_batch_rate_limiter_with_managed_files():
|
||||
async def test_batch_rate_limiter_with_managed_files(tmp_path):
|
||||
"""
|
||||
Test for GEN-2166: Verify batch rate limiter can read user files when managed files are enabled.
|
||||
|
||||
@ -425,7 +421,6 @@ async def test_batch_rate_limiter_with_managed_files():
|
||||
3. Rate limiting is enforced (not silently bypassed)
|
||||
4. No 403 Permission Denied errors occur for files owned by the user
|
||||
"""
|
||||
import tempfile
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
CUSTOM_LLM_PROVIDER = "openai"
|
||||
@ -472,18 +467,19 @@ async def test_batch_rate_limiter_with_managed_files():
|
||||
|
||||
batch_content = "\n".join(requests)
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(batch_content)
|
||||
file_path = f.name
|
||||
file_path = _write_batch_file(
|
||||
tmp_path, "managed-files-batch-rate-limit.jsonl", batch_content
|
||||
)
|
||||
|
||||
try:
|
||||
# Step 1: Upload file to OpenAI (simulating user upload)
|
||||
print("\n1. Uploading batch input file...")
|
||||
file_obj = await litellm.acreate_file(
|
||||
file=open(file_path, "rb"),
|
||||
purpose="batch",
|
||||
custom_llm_provider=CUSTOM_LLM_PROVIDER,
|
||||
)
|
||||
with open(file_path, "rb") as batch_file:
|
||||
file_obj = await litellm.acreate_file(
|
||||
file=batch_file,
|
||||
purpose="batch",
|
||||
custom_llm_provider=CUSTOM_LLM_PROVIDER,
|
||||
)
|
||||
print(f" ✓ File uploaded: {file_obj.id}")
|
||||
await asyncio.sleep(1) # Give API time to process
|
||||
|
||||
@ -568,12 +564,10 @@ async def test_batch_rate_limiter_with_managed_files():
|
||||
raise
|
||||
except Exception as e:
|
||||
pytest.fail(f"Unexpected error: {str(e)}")
|
||||
finally:
|
||||
os.unlink(file_path)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
async def test_batch_rate_limiter_without_user_context():
|
||||
async def test_batch_rate_limiter_without_user_context(tmp_path):
|
||||
"""
|
||||
Test that verifies the bug scenario from GEN-2166.
|
||||
|
||||
@ -583,8 +577,6 @@ async def test_batch_rate_limiter_without_user_context():
|
||||
|
||||
This test documents the expected behavior with and without user context.
|
||||
"""
|
||||
import tempfile
|
||||
|
||||
CUSTOM_LLM_PROVIDER = "openai"
|
||||
|
||||
# Setup
|
||||
@ -596,56 +588,53 @@ async def test_batch_rate_limiter_without_user_context():
|
||||
# Create a simple batch file
|
||||
batch_content = """{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}}"""
|
||||
|
||||
with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
|
||||
f.write(batch_content)
|
||||
file_path = f.name
|
||||
file_path = _write_batch_file(
|
||||
tmp_path, "without-user-context-batch-rate-limit.jsonl", batch_content
|
||||
)
|
||||
|
||||
try:
|
||||
# Upload file
|
||||
# Upload file
|
||||
with open(file_path, "rb") as batch_file:
|
||||
file_obj = await litellm.acreate_file(
|
||||
file=open(file_path, "rb"),
|
||||
file=batch_file,
|
||||
purpose="batch",
|
||||
custom_llm_provider=CUSTOM_LLM_PROVIDER,
|
||||
)
|
||||
await asyncio.sleep(1)
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Test 1: Without user context (old behavior - would fail with managed files)
|
||||
print("\n=== Test 1: count_input_file_usage WITHOUT user context ===")
|
||||
try:
|
||||
usage_without_context = await BATCH_LIMITER.count_input_file_usage(
|
||||
file_id=file_obj.id,
|
||||
custom_llm_provider=CUSTOM_LLM_PROVIDER,
|
||||
user_api_key_dict=None, # Explicitly passing None
|
||||
)
|
||||
print(
|
||||
f"✓ Works for non-managed files (tokens: {usage_without_context.total_tokens})"
|
||||
)
|
||||
print(" Note: Would fail with 403 for managed files (GEN-2166 bug)")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed: {str(e)}")
|
||||
|
||||
# Test 2: With user context (new behavior - works with managed files)
|
||||
print("\n=== Test 2: count_input_file_usage WITH user context ===")
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
api_key="test-key",
|
||||
user_id="test-user-123",
|
||||
)
|
||||
|
||||
usage_with_context = await BATCH_LIMITER.count_input_file_usage(
|
||||
# Test 1: Without user context (old behavior - would fail with managed files)
|
||||
print("\n=== Test 1: count_input_file_usage WITHOUT user context ===")
|
||||
try:
|
||||
usage_without_context = await BATCH_LIMITER.count_input_file_usage(
|
||||
file_id=file_obj.id,
|
||||
custom_llm_provider=CUSTOM_LLM_PROVIDER,
|
||||
user_api_key_dict=user_api_key_dict, # Passing user context
|
||||
user_api_key_dict=None, # Explicitly passing None
|
||||
)
|
||||
print(f"✓ Works with user context (tokens: {usage_with_context.total_tokens})")
|
||||
print(" Note: This fixes GEN-2166 for managed files")
|
||||
print(
|
||||
f"✓ Works for non-managed files (tokens: {usage_without_context.total_tokens})"
|
||||
)
|
||||
print(" Note: Would fail with 403 for managed files (GEN-2166 bug)")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed: {str(e)}")
|
||||
|
||||
# Verify both return the same results
|
||||
assert usage_with_context.total_tokens == usage_without_context.total_tokens
|
||||
assert usage_with_context.request_count == usage_without_context.request_count
|
||||
print("\n✓ Both methods return identical results for non-managed files")
|
||||
# Test 2: With user context (new behavior - works with managed files)
|
||||
print("\n=== Test 2: count_input_file_usage WITH user context ===")
|
||||
user_api_key_dict = UserAPIKeyAuth(
|
||||
api_key="test-key",
|
||||
user_id="test-user-123",
|
||||
)
|
||||
|
||||
finally:
|
||||
os.unlink(file_path)
|
||||
usage_with_context = await BATCH_LIMITER.count_input_file_usage(
|
||||
file_id=file_obj.id,
|
||||
custom_llm_provider=CUSTOM_LLM_PROVIDER,
|
||||
user_api_key_dict=user_api_key_dict, # Passing user context
|
||||
)
|
||||
print(f"✓ Works with user context (tokens: {usage_with_context.total_tokens})")
|
||||
print(" Note: This fixes GEN-2166 for managed files")
|
||||
|
||||
# Verify both return the same results
|
||||
assert usage_with_context.total_tokens == usage_without_context.total_tokens
|
||||
assert usage_with_context.request_count == usage_without_context.request_count
|
||||
print("\n✓ Both methods return identical results for non-managed files")
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
# What is this?
|
||||
## Unit Tests for OpenAI Batches API
|
||||
import asyncio
|
||||
import json
|
||||
import json as json_module
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
@ -19,6 +19,103 @@ from typing import Optional
|
||||
import litellm
|
||||
from unittest.mock import patch, MagicMock
|
||||
import httpx
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
|
||||
|
||||
_BEDROCK_TEST_AWS_ENV = {
|
||||
"AWS_ACCESS_KEY_ID": "test-access-key",
|
||||
"AWS_SECRET_ACCESS_KEY": "test-secret-key",
|
||||
"AWS_REGION": "us-west-2",
|
||||
"AWS_DEFAULT_REGION": "us-west-2",
|
||||
}
|
||||
|
||||
|
||||
class _CaptureAsyncHTTPHandler(AsyncHTTPHandler):
|
||||
def __init__(self):
|
||||
self.timeout = None
|
||||
self.event_hooks = None
|
||||
self.client_alias = "bedrock-test"
|
||||
self.put_calls = []
|
||||
self.post_calls = []
|
||||
self.batch_jobs = {}
|
||||
|
||||
async def put(
|
||||
self,
|
||||
url: str,
|
||||
data=None,
|
||||
json=None,
|
||||
params=None,
|
||||
headers=None,
|
||||
timeout=None,
|
||||
stream: bool = False,
|
||||
content=None,
|
||||
):
|
||||
self.put_calls.append(
|
||||
{
|
||||
"url": url,
|
||||
"data": data,
|
||||
"json": json,
|
||||
"params": params,
|
||||
"headers": headers or {},
|
||||
"timeout": timeout,
|
||||
"stream": stream,
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
body = data if data is not None else content
|
||||
content_bytes = body.encode("utf-8") if isinstance(body, str) else body or b""
|
||||
content_length = len(content_bytes)
|
||||
return httpx.Response(
|
||||
status_code=200,
|
||||
headers={"Content-Length": str(content_length)},
|
||||
request=httpx.Request("PUT", url),
|
||||
)
|
||||
|
||||
async def post(
|
||||
self,
|
||||
url: str,
|
||||
data=None,
|
||||
json=None,
|
||||
params=None,
|
||||
headers=None,
|
||||
timeout=None,
|
||||
stream: bool = False,
|
||||
logging_obj=None,
|
||||
files=None,
|
||||
content=None,
|
||||
):
|
||||
self.post_calls.append(
|
||||
{
|
||||
"url": url,
|
||||
"data": data,
|
||||
"json": json,
|
||||
"params": params,
|
||||
"headers": headers or {},
|
||||
"timeout": timeout,
|
||||
"stream": stream,
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
raw = json if json is not None else (data if data is not None else content)
|
||||
payload = raw if isinstance(raw, dict) else json_module.loads(raw)
|
||||
job_name = payload["jobName"]
|
||||
job_arn = f"arn:aws:bedrock:us-west-2:941277531214:model-invocation-job/{job_name}"
|
||||
self.batch_jobs[job_arn] = {
|
||||
"jobArn": job_arn,
|
||||
"jobName": job_name,
|
||||
"modelId": payload["modelId"],
|
||||
"roleArn": payload["roleArn"],
|
||||
"status": "InProgress",
|
||||
"submitTime": "2026-06-02T03:50:00Z",
|
||||
"lastModifiedTime": "2026-06-02T03:55:00Z",
|
||||
"inputDataConfig": payload["inputDataConfig"],
|
||||
"outputDataConfig": payload["outputDataConfig"],
|
||||
}
|
||||
return httpx.Response(
|
||||
status_code=200,
|
||||
json={"jobArn": job_arn, "jobName": job_name, "status": "Submitted"},
|
||||
request=httpx.Request("POST", url),
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@ -34,12 +131,34 @@ async def test_async_create_file():
|
||||
file_name = "bedrock_batch_completions.jsonl"
|
||||
_current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(_current_dir, file_name)
|
||||
file_obj = await litellm.acreate_file(
|
||||
file=open(file_path, "rb"),
|
||||
purpose="batch",
|
||||
custom_llm_provider="bedrock",
|
||||
s3_bucket_name="litellm-proxy",
|
||||
capture_client = _CaptureAsyncHTTPHandler()
|
||||
with (
|
||||
patch.dict(os.environ, _BEDROCK_TEST_AWS_ENV),
|
||||
open(file_path, "rb") as batch_file,
|
||||
):
|
||||
file_obj = await litellm.acreate_file(
|
||||
file=batch_file,
|
||||
purpose="batch",
|
||||
custom_llm_provider="bedrock",
|
||||
s3_bucket_name="litellm-proxy-941277531214",
|
||||
client=capture_client,
|
||||
)
|
||||
|
||||
assert len(capture_client.put_calls) == 1
|
||||
put_call = capture_client.put_calls[0]
|
||||
assert put_call["url"].startswith(
|
||||
"https://s3.us-west-2.amazonaws.com/litellm-proxy-941277531214/"
|
||||
)
|
||||
assert "/litellm-bedrock-files-us.anthropic.claude-haiku-4-5-20251001-v1-0-" in (
|
||||
put_call["url"]
|
||||
)
|
||||
assert put_call["url"].endswith(".jsonl")
|
||||
assert put_call["headers"]["Authorization"].startswith("AWS4-HMAC-SHA256")
|
||||
assert "recordId" in put_call["data"]
|
||||
assert file_obj.id.startswith(
|
||||
"s3://litellm-proxy-941277531214/litellm-bedrock-files-"
|
||||
)
|
||||
assert file_obj.filename.endswith(".jsonl")
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@ -51,36 +170,54 @@ async def test_async_file_and_batch():
|
||||
file_name = "bedrock_batch_completions.jsonl"
|
||||
_current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(_current_dir, file_name)
|
||||
file_obj = await litellm.acreate_file(
|
||||
file=open(file_path, "rb"),
|
||||
purpose="batch",
|
||||
custom_llm_provider="bedrock",
|
||||
s3_bucket_name="litellm-proxy",
|
||||
)
|
||||
print("CREATED FILE RESPONSE=", file_obj)
|
||||
capture_client = _CaptureAsyncHTTPHandler()
|
||||
with patch.dict(os.environ, _BEDROCK_TEST_AWS_ENV):
|
||||
with open(file_path, "rb") as batch_file:
|
||||
file_obj = await litellm.acreate_file(
|
||||
file=batch_file,
|
||||
purpose="batch",
|
||||
custom_llm_provider="bedrock",
|
||||
s3_bucket_name="litellm-proxy-941277531214",
|
||||
client=capture_client,
|
||||
)
|
||||
assert len(capture_client.put_calls) == 1
|
||||
print("CREATED FILE RESPONSE=", file_obj)
|
||||
|
||||
# create batch
|
||||
create_batch_response = await litellm.acreate_batch(
|
||||
completion_window="24h",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id=file_obj.id,
|
||||
metadata={"key1": "value1", "key2": "value2"},
|
||||
custom_llm_provider="bedrock",
|
||||
#########################################################
|
||||
# bedrock specific params
|
||||
#########################################################
|
||||
model="us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
aws_batch_role_arn="arn:aws:iam::888602223428:role/service-role/AmazonBedrockExecutionRoleForAgents_BB9HNW6V4CV",
|
||||
)
|
||||
print("CREATED BATCH RESPONSE=", create_batch_response)
|
||||
with patch(
|
||||
"litellm.llms.custom_httpx.llm_http_handler.get_async_httpx_client",
|
||||
return_value=capture_client,
|
||||
):
|
||||
# create batch
|
||||
create_batch_response = await litellm.acreate_batch(
|
||||
completion_window="24h",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id=file_obj.id,
|
||||
metadata={"key1": "value1", "key2": "value2"},
|
||||
custom_llm_provider="bedrock",
|
||||
#########################################################
|
||||
# bedrock specific params
|
||||
#########################################################
|
||||
model="us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
aws_batch_role_arn="arn:aws:iam::941277531214:role/service-role/AmazonBedrockExecutionRoleForAgents_BB9HNW6V4CV",
|
||||
)
|
||||
assert len(capture_client.post_calls) == 1
|
||||
print("CREATED BATCH RESPONSE=", create_batch_response)
|
||||
|
||||
# retrieve batch
|
||||
retrieve_batch_response = await litellm.aretrieve_batch(
|
||||
batch_id=create_batch_response.id,
|
||||
custom_llm_provider="bedrock",
|
||||
model="us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
)
|
||||
print("RETRIEVED BATCH RESPONSE=", retrieve_batch_response)
|
||||
# retrieve batch
|
||||
mock_bedrock_client = MagicMock()
|
||||
mock_bedrock_client.get_model_invocation_job.side_effect = (
|
||||
lambda jobIdentifier: capture_client.batch_jobs[jobIdentifier]
|
||||
)
|
||||
with patch("boto3.client", return_value=mock_bedrock_client):
|
||||
retrieve_batch_response = await litellm.aretrieve_batch(
|
||||
batch_id=create_batch_response.id,
|
||||
custom_llm_provider="bedrock",
|
||||
model="us.anthropic.claude-haiku-4-5-20251001-v1:0",
|
||||
)
|
||||
mock_bedrock_client.get_model_invocation_job.assert_called_once_with(
|
||||
jobIdentifier=create_batch_response.id
|
||||
)
|
||||
print("RETRIEVED BATCH RESPONSE=", retrieve_batch_response)
|
||||
|
||||
# Validate the response
|
||||
assert retrieve_batch_response.id == create_batch_response.id
|
||||
@ -101,52 +238,36 @@ async def test_mock_bedrock_file_url_mapping():
|
||||
"""
|
||||
print("Testing Bedrock file URL mapping")
|
||||
|
||||
captured_put_url = None
|
||||
|
||||
async def mock_async_create_file(transformed_request, **kwargs):
|
||||
nonlocal captured_put_url
|
||||
# Capture PUT URL from transformed request
|
||||
if isinstance(transformed_request, dict) and "url" in transformed_request:
|
||||
captured_put_url = transformed_request["url"]
|
||||
|
||||
# Call the real method to get actual response
|
||||
from litellm.files.main import base_llm_http_handler
|
||||
|
||||
return await base_llm_http_handler.__class__.async_create_file(
|
||||
base_llm_http_handler, transformed_request, **kwargs
|
||||
)
|
||||
|
||||
with patch(
|
||||
"litellm.files.main.base_llm_http_handler.async_create_file",
|
||||
side_effect=mock_async_create_file,
|
||||
capture_client = _CaptureAsyncHTTPHandler()
|
||||
with (
|
||||
patch.dict(os.environ, _BEDROCK_TEST_AWS_ENV),
|
||||
open(
|
||||
os.path.join(os.path.dirname(__file__), "bedrock_batch_completions.jsonl"),
|
||||
"rb",
|
||||
) as batch_file,
|
||||
):
|
||||
file_obj = await litellm.acreate_file(
|
||||
file=open(
|
||||
os.path.join(
|
||||
os.path.dirname(__file__), "bedrock_batch_completions.jsonl"
|
||||
),
|
||||
"rb",
|
||||
),
|
||||
file=batch_file,
|
||||
purpose="batch",
|
||||
custom_llm_provider="bedrock",
|
||||
s3_bucket_name="litellm-proxy",
|
||||
s3_bucket_name="litellm-proxy-941277531214",
|
||||
client=capture_client,
|
||||
)
|
||||
|
||||
print(f"PUT URL: {captured_put_url}")
|
||||
print(f"File ID: {file_obj.id}")
|
||||
captured_put_url = capture_client.put_calls[0]["url"]
|
||||
print(f"PUT URL: {captured_put_url}")
|
||||
print(f"File ID: {file_obj.id}")
|
||||
|
||||
# Validate URL was captured and response is correct
|
||||
assert captured_put_url is not None
|
||||
assert file_obj.id.startswith("s3://")
|
||||
# Validate URL was captured and response is correct
|
||||
assert captured_put_url is not None
|
||||
assert file_obj.id.startswith("s3://")
|
||||
|
||||
# Verify mapping
|
||||
from litellm.llms.bedrock.files.transformation import BedrockFilesConfig
|
||||
# Verify mapping
|
||||
from litellm.llms.bedrock.files.transformation import BedrockFilesConfig
|
||||
|
||||
bedrock_config = BedrockFilesConfig()
|
||||
expected_s3_uri, _ = bedrock_config._convert_https_url_to_s3_uri(
|
||||
captured_put_url
|
||||
)
|
||||
assert file_obj.id == expected_s3_uri
|
||||
bedrock_config = BedrockFilesConfig()
|
||||
expected_s3_uri, _ = bedrock_config._convert_https_url_to_s3_uri(captured_put_url)
|
||||
assert file_obj.id == expected_s3_uri
|
||||
|
||||
|
||||
@pytest.mark.asyncio()
|
||||
@ -237,8 +358,12 @@ def test_bedrock_batch_with_encryption_key_in_post_request():
|
||||
mock_response.raise_for_status.return_value = None
|
||||
return mock_response
|
||||
|
||||
with patch(
|
||||
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post", side_effect=mock_post
|
||||
with (
|
||||
patch.dict(os.environ, _BEDROCK_TEST_AWS_ENV),
|
||||
patch(
|
||||
"litellm.llms.custom_httpx.http_handler.HTTPHandler.post",
|
||||
side_effect=mock_post,
|
||||
),
|
||||
):
|
||||
response = litellm.create_batch(
|
||||
completion_window="24h",
|
||||
|
||||
@ -4,7 +4,6 @@ import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
import tempfile
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@ -15,12 +14,10 @@ sys.path.insert(
|
||||
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
from typing import Optional
|
||||
import litellm
|
||||
from litellm import create_batch, create_file
|
||||
from litellm._logging import verbose_logger
|
||||
import openai
|
||||
|
||||
@ -28,7 +25,6 @@ verbose_logger.setLevel(logging.DEBUG)
|
||||
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
import random
|
||||
import socket
|
||||
import httpx
|
||||
from unittest.mock import patch, MagicMock
|
||||
@ -49,6 +45,21 @@ skip_if_no_openai_network = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
|
||||
async def _wait_for_standard_logging_object(
|
||||
custom_logger: "TestCustomLogger", timeout: float = 15.0
|
||||
) -> StandardLoggingPayload:
|
||||
from litellm.litellm_core_utils.logging_worker import GLOBAL_LOGGING_WORKER
|
||||
|
||||
deadline = time.monotonic() + timeout
|
||||
while time.monotonic() < deadline:
|
||||
await GLOBAL_LOGGING_WORKER.flush()
|
||||
if custom_logger.standard_logging_object is not None:
|
||||
return custom_logger.standard_logging_object
|
||||
await asyncio.sleep(0.25)
|
||||
assert custom_logger.standard_logging_object is not None
|
||||
return custom_logger.standard_logging_object
|
||||
|
||||
|
||||
def load_vertex_ai_credentials():
|
||||
# Define the path to the vertex_key.json file
|
||||
print("loading vertex ai credentials")
|
||||
@ -95,7 +106,7 @@ def load_vertex_ai_credentials():
|
||||
@pytest.mark.parametrize("provider", ["openai"]) # , "azure"
|
||||
@pytest.mark.asyncio
|
||||
@skip_if_no_openai_network
|
||||
async def test_create_batch(provider):
|
||||
async def test_create_batch(provider, tmp_path):
|
||||
"""
|
||||
1. Create File for Batch completion
|
||||
2. Create Batch Request
|
||||
@ -108,11 +119,12 @@ async def test_create_batch(provider):
|
||||
_current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(_current_dir, file_name)
|
||||
|
||||
file_obj = await litellm.acreate_file(
|
||||
file=open(file_path, "rb"),
|
||||
purpose="batch",
|
||||
custom_llm_provider=provider,
|
||||
)
|
||||
with open(file_path, "rb") as batch_file:
|
||||
file_obj = await litellm.acreate_file(
|
||||
file=batch_file,
|
||||
purpose="batch",
|
||||
custom_llm_provider=provider,
|
||||
)
|
||||
print("Response from creating file=", file_obj)
|
||||
|
||||
batch_input_file_id = file_obj.id
|
||||
@ -161,10 +173,8 @@ async def test_create_batch(provider):
|
||||
|
||||
result = file_content.content
|
||||
|
||||
result_file_name = "batch_job_results_furniture.jsonl"
|
||||
|
||||
with open(result_file_name, "wb") as file:
|
||||
file.write(result)
|
||||
result_file_path = tmp_path / "batch_job_results_furniture.jsonl"
|
||||
result_file_path.write_bytes(result)
|
||||
|
||||
# Cancel Batch - handle race condition where batch may already be completed
|
||||
try:
|
||||
@ -268,9 +278,8 @@ def cleanup_azure_ft_models():
|
||||
|
||||
@pytest.mark.parametrize("provider", ["openai"])
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
@skip_if_no_openai_network
|
||||
async def test_async_create_batch(provider):
|
||||
async def test_async_create_batch(provider, tmp_path):
|
||||
"""
|
||||
1. Create File for Batch completion
|
||||
2. Create Batch Request
|
||||
@ -279,17 +288,16 @@ async def test_async_create_batch(provider):
|
||||
litellm._turn_on_debug()
|
||||
print("Testing async create batch")
|
||||
litellm.logging_callback_manager._reset_all_callbacks()
|
||||
custom_logger = TestCustomLogger()
|
||||
litellm.callbacks = [custom_logger, "datadog"]
|
||||
|
||||
file_name = "openai_batch_completions.jsonl"
|
||||
_current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(_current_dir, file_name)
|
||||
file_obj = await litellm.acreate_file(
|
||||
file=open(file_path, "rb"),
|
||||
purpose="batch",
|
||||
custom_llm_provider=provider,
|
||||
)
|
||||
with open(file_path, "rb") as batch_file:
|
||||
file_obj = await litellm.acreate_file(
|
||||
file=batch_file,
|
||||
purpose="batch",
|
||||
custom_llm_provider=provider,
|
||||
)
|
||||
print("Response from creating file=", file_obj)
|
||||
|
||||
await asyncio.sleep(10)
|
||||
@ -302,6 +310,8 @@ async def test_async_create_batch(provider):
|
||||
"user_api_key_alias": "special_api_key_alias",
|
||||
"user_api_key_team_alias": "special_team_alias",
|
||||
}
|
||||
custom_logger = TestCustomLogger()
|
||||
litellm.callbacks = [custom_logger, "datadog"]
|
||||
create_batch_response = await litellm.acreate_batch(
|
||||
completion_window="24h",
|
||||
endpoint="/v1/chat/completions",
|
||||
@ -325,19 +335,18 @@ async def test_async_create_batch(provider):
|
||||
create_batch_response.input_file_id == batch_input_file_id
|
||||
), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"
|
||||
|
||||
await asyncio.sleep(6)
|
||||
# Assert that the create batch event is logged on CustomLogger
|
||||
assert custom_logger.standard_logging_object is not None
|
||||
standard_logging_object = await _wait_for_standard_logging_object(custom_logger)
|
||||
print(
|
||||
"standard_logging_object=",
|
||||
json.dumps(custom_logger.standard_logging_object, indent=4, default=str),
|
||||
json.dumps(standard_logging_object, indent=4, default=str),
|
||||
)
|
||||
assert (
|
||||
custom_logger.standard_logging_object["metadata"]["user_api_key_alias"]
|
||||
standard_logging_object["metadata"]["user_api_key_alias"]
|
||||
== extra_metadata_field["user_api_key_alias"]
|
||||
)
|
||||
assert (
|
||||
custom_logger.standard_logging_object["metadata"]["user_api_key_team_alias"]
|
||||
standard_logging_object["metadata"]["user_api_key_team_alias"]
|
||||
== extra_metadata_field["user_api_key_team_alias"]
|
||||
)
|
||||
|
||||
@ -383,10 +392,8 @@ async def test_async_create_batch(provider):
|
||||
|
||||
print("all_files_list = ", all_files_list)
|
||||
|
||||
result_file_name = "batch_job_results_furniture.jsonl"
|
||||
|
||||
with open(result_file_name, "wb") as file:
|
||||
file.write(file_content.content)
|
||||
result_file_path = tmp_path / "batch_job_results_furniture.jsonl"
|
||||
result_file_path.write_bytes(file_content.content)
|
||||
|
||||
# Cancel Batch - handle race condition where batch may already be completed
|
||||
try:
|
||||
@ -407,11 +414,6 @@ async def test_async_create_batch(provider):
|
||||
print(f"Unexpected error during batch cancellation: {e}")
|
||||
raise
|
||||
|
||||
if random.randint(1, 3) == 1:
|
||||
print("Running random cleanup of Azure files and models...")
|
||||
cleanup_azure_files()
|
||||
cleanup_azure_ft_models()
|
||||
|
||||
|
||||
mock_file_response = {
|
||||
"kind": "storage#object",
|
||||
|
||||
@ -212,6 +212,8 @@ async def test_text_message_blocked_by_guardrail_no_ai_response():
|
||||
"policy",
|
||||
"can't repeat",
|
||||
"cannot repeat",
|
||||
"can't say",
|
||||
"cannot say",
|
||||
"won't repeat",
|
||||
"can't assist",
|
||||
"can't help",
|
||||
|
||||
@ -178,9 +178,12 @@ def test_should_distinguish_different_aws_access_keys():
|
||||
[
|
||||
("api.openai.com", True),
|
||||
("api.anthropic.com", True),
|
||||
("bedrock.us-east-1.amazonaws.com", True),
|
||||
("bedrock-runtime.us-east-1.amazonaws.com", True),
|
||||
("bedrock-runtime-fips.us-east-1.amazonaws.com", True),
|
||||
("api.us-east-1.bedrock-runtime.amazonaws.com", False),
|
||||
("s3.us-west-2.amazonaws.com", True),
|
||||
("litellm-proxy-test.s3.us-west-2.amazonaws.com", True),
|
||||
("foo.bar.openai.com", True),
|
||||
("127.0.0.1", False),
|
||||
("localhost", False),
|
||||
|
||||
@ -4,6 +4,8 @@ Tests for the LoggingWorker class to ensure graceful shutdown handling.
|
||||
|
||||
import asyncio
|
||||
import contextvars
|
||||
import io
|
||||
import logging
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
@ -65,6 +67,57 @@ class TestLoggingWorker:
|
||||
# Verify the queue is empty after clearing
|
||||
assert logging_worker._queue.empty()
|
||||
|
||||
def test_flush_on_exit_suppresses_closed_handler_errors(self, capsys):
|
||||
"""Atexit flushing should not print logging errors after streams close."""
|
||||
worker = LoggingWorker(timeout=1.0, max_queue_size=10)
|
||||
worker._queue = asyncio.Queue(maxsize=10)
|
||||
|
||||
stream = io.StringIO()
|
||||
handler = logging.StreamHandler(stream)
|
||||
logger = logging.getLogger("test_logging_worker_closed_handler")
|
||||
logger.addHandler(handler)
|
||||
logger.setLevel(logging.DEBUG)
|
||||
logger.propagate = False
|
||||
|
||||
async def log_with_closed_handler():
|
||||
logger.debug("flush me during shutdown")
|
||||
|
||||
previous_raise_exceptions = logging.raiseExceptions
|
||||
logging.raiseExceptions = True
|
||||
|
||||
try:
|
||||
worker.enqueue(log_with_closed_handler())
|
||||
stream.close()
|
||||
|
||||
worker._flush_on_exit()
|
||||
|
||||
captured = capsys.readouterr()
|
||||
assert "I/O operation on closed file" not in captured.err
|
||||
finally:
|
||||
logging.raiseExceptions = previous_raise_exceptions
|
||||
logger.removeHandler(handler)
|
||||
|
||||
def test_flush_on_exit_swallows_errors_and_drains_remaining(self):
|
||||
"""A failing queued coroutine must not abort the atexit drain of later events."""
|
||||
worker = LoggingWorker(timeout=1.0, max_queue_size=10)
|
||||
worker._queue = asyncio.Queue(maxsize=10)
|
||||
|
||||
processed = []
|
||||
|
||||
async def raises_during_flush():
|
||||
raise RuntimeError("boom during shutdown flush")
|
||||
|
||||
async def records_during_flush():
|
||||
processed.append("ran")
|
||||
|
||||
worker.enqueue(raises_during_flush())
|
||||
worker.enqueue(records_during_flush())
|
||||
|
||||
worker._flush_on_exit()
|
||||
|
||||
assert processed == ["ran"]
|
||||
assert worker._queue.empty()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_worker_handles_cancellation_gracefully(self, logging_worker):
|
||||
"""Test that the worker handles cancellation without throwing exceptions."""
|
||||
|
||||
@ -14,6 +14,7 @@ from tests._vcr_conftest_common import ( # noqa: E402
|
||||
KEY_FINGERPRINT_HEADER,
|
||||
KEY_FINGERPRINT_MATCHER_NAME,
|
||||
SAFE_BODY_MATCHER_NAME,
|
||||
TOLERANT_PATH_MATCHER_NAME,
|
||||
TOLERANT_QUERY_MATCHER_NAME,
|
||||
_before_record_request,
|
||||
_is_credential_exchange_request,
|
||||
@ -21,6 +22,7 @@ from tests._vcr_conftest_common import ( # noqa: E402
|
||||
_key_fingerprint_matcher,
|
||||
_normalize_volatile_tokens,
|
||||
_safe_body_matcher,
|
||||
_tolerant_path_matcher,
|
||||
_tolerant_query_matcher,
|
||||
vcr_config_dict,
|
||||
)
|
||||
@ -198,6 +200,20 @@ def test_normalize_volatile_tokens_collapses_uuid_and_timestamps():
|
||||
assert _normalize_volatile_tokens(e) == _normalize_volatile_tokens(f)
|
||||
|
||||
|
||||
def test_normalize_volatile_tokens_collapses_bedrock_batch_job_names():
|
||||
a = (
|
||||
b'{"jobName":"litellm-batch-aaaaaaaa",'
|
||||
b'"outputDataConfig":{"s3OutputDataConfig":'
|
||||
b'{"s3Uri":"s3://bucket/litellm-batch-outputs/litellm-batch-aaaaaaaa/"}}}'
|
||||
)
|
||||
b = (
|
||||
b'{"jobName":"litellm-batch-bbbbbbbb",'
|
||||
b'"outputDataConfig":{"s3OutputDataConfig":'
|
||||
b'{"s3Uri":"s3://bucket/litellm-batch-outputs/litellm-batch-bbbbbbbb/"}}}'
|
||||
)
|
||||
assert _normalize_volatile_tokens(a) == _normalize_volatile_tokens(b)
|
||||
|
||||
|
||||
def test_normalize_volatile_tokens_leaves_deterministic_bodies_unchanged():
|
||||
body = b'{"model":"claude-haiku-4-5-20251001","temperature":0.0,"n":2}'
|
||||
assert _normalize_volatile_tokens(body) == body
|
||||
@ -239,6 +255,83 @@ def test_match_on_uses_tolerant_query_not_builtin():
|
||||
assert "query" not in cfg["match_on"]
|
||||
|
||||
|
||||
def test_match_on_uses_tolerant_path_not_builtin():
|
||||
cfg = vcr_config_dict()
|
||||
assert TOLERANT_PATH_MATCHER_NAME in cfg["match_on"]
|
||||
assert "path" not in cfg["match_on"]
|
||||
|
||||
|
||||
def test_tolerant_path_normalizes_bedrock_managed_s3_file_uuid():
|
||||
from vcr.request import Request
|
||||
|
||||
a = Request(
|
||||
method="PUT",
|
||||
uri=(
|
||||
"https://s3.us-west-2.amazonaws.com/litellm-proxy-test/"
|
||||
"litellm-bedrock-files/us.anthropic.claude-haiku-4-5-20251001-v1-0-"
|
||||
"123e4567-e89b-12d3-a456-426614174000.jsonl"
|
||||
),
|
||||
body=b"",
|
||||
headers={},
|
||||
)
|
||||
b = Request(
|
||||
method="PUT",
|
||||
uri=(
|
||||
"https://s3.us-west-2.amazonaws.com/litellm-proxy-test/"
|
||||
"litellm-bedrock-files/us.anthropic.claude-haiku-4-5-20251001-v1-0-"
|
||||
"abcdefab-1234-5678-9abc-def012345678.jsonl"
|
||||
),
|
||||
body=b"",
|
||||
headers={},
|
||||
)
|
||||
_tolerant_path_matcher(a, b)
|
||||
|
||||
|
||||
def test_tolerant_path_normalizes_bedrock_batch_s3_file_uuid():
|
||||
from vcr.request import Request
|
||||
|
||||
a = Request(
|
||||
method="PUT",
|
||||
uri=(
|
||||
"https://s3.us-west-2.amazonaws.com/litellm-proxy-test/"
|
||||
"litellm-bedrock-files-us.anthropic.claude-haiku-4-5-20251001-v1-0-"
|
||||
"a48e9ec2-5594-45e3-bdbb-44f5d71c06f3.jsonl"
|
||||
),
|
||||
body=b"",
|
||||
headers={},
|
||||
)
|
||||
b = Request(
|
||||
method="PUT",
|
||||
uri=(
|
||||
"https://s3.us-west-2.amazonaws.com/litellm-proxy-test/"
|
||||
"litellm-bedrock-files-us.anthropic.claude-haiku-4-5-20251001-v1-0-"
|
||||
"123e4567-e89b-12d3-a456-426614174000.jsonl"
|
||||
),
|
||||
body=b"",
|
||||
headers={},
|
||||
)
|
||||
_tolerant_path_matcher(a, b)
|
||||
|
||||
|
||||
def test_tolerant_path_still_rejects_different_regular_paths():
|
||||
from vcr.request import Request
|
||||
|
||||
a = Request(
|
||||
method="GET",
|
||||
uri="https://api.openai.com/v1/files/file-a/content",
|
||||
body=b"",
|
||||
headers={},
|
||||
)
|
||||
b = Request(
|
||||
method="GET",
|
||||
uri="https://api.openai.com/v1/files/file-b/content",
|
||||
body=b"",
|
||||
headers={},
|
||||
)
|
||||
with pytest.raises(AssertionError):
|
||||
_tolerant_path_matcher(a, b)
|
||||
|
||||
|
||||
def test_telemetry_request_detection():
|
||||
assert _is_telemetry_request(
|
||||
_req(b"x", uri="https://us.cloud.langfuse.com/api/public/ingestion")
|
||||
|
||||
Loading…
Reference in New Issue
Block a user