fix(proxy): strip NUL bytes from spend log payloads to prevent PostgreSQL 22P05 (#29515)
A raw NUL byte (\x00) in request/response content is serialized by json.dumps
into the \u0000 JSON escape. When update_spend_logs writes this to the
LiteLLM_SpendLogs jsonb columns, Postgres rejects the whole batch with
error 22P05 ("unsupported Unicode escape sequence ... cannot be converted to
text"), crashing the periodic update_spend job and dropping the spend-log batch.
Centralize stripping in safe_dumps (covers metadata/response paths and any
future caller) and route the messages, proxy_server_request, request_tags, and
response (string branch) payloads through it instead of json.dumps. Dict keys
are stripped too.
Adds regression tests for safe_dumps and the spend-log message, response, and
request_tags payload builders.
Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
parent
ce7b1fd29d
commit
efaafbbd02
@ -6,10 +6,16 @@ from pydantic import BaseModel
|
||||
from litellm.constants import DEFAULT_MAX_RECURSE_DEPTH
|
||||
|
||||
|
||||
def strip_null_bytes(value: str) -> str:
|
||||
"""Strip NUL bytes, which PostgreSQL text/jsonb columns reject (error 22P05)."""
|
||||
return value.replace("\x00", "")
|
||||
|
||||
|
||||
def safe_dumps(data: Any, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH) -> str:
|
||||
"""
|
||||
Recursively serialize data while detecting circular references.
|
||||
If a circular reference is detected then a marker string is returned.
|
||||
NUL bytes are stripped from strings to prevent PostgreSQL 22P05 errors.
|
||||
"""
|
||||
|
||||
def _serialize(obj: Any, seen: set, depth: int) -> Any:
|
||||
@ -17,7 +23,9 @@ def safe_dumps(data: Any, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH) -> str:
|
||||
if depth > max_depth:
|
||||
return "MaxDepthExceeded"
|
||||
# Base-case: if it is a primitive, simply return it.
|
||||
if isinstance(obj, (str, int, float, bool, type(None))):
|
||||
if isinstance(obj, str):
|
||||
return strip_null_bytes(obj)
|
||||
if isinstance(obj, (int, float, bool, type(None))):
|
||||
return obj
|
||||
# Check for circular reference.
|
||||
if id(obj) in seen:
|
||||
@ -28,7 +36,7 @@ def safe_dumps(data: Any, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH) -> str:
|
||||
result = {}
|
||||
for k, v in obj.items():
|
||||
if isinstance(k, (str)):
|
||||
result[k] = _serialize(v, seen, depth + 1)
|
||||
result[strip_null_bytes(k)] = _serialize(v, seen, depth + 1)
|
||||
seen.remove(id(obj))
|
||||
return result
|
||||
elif isinstance(obj, list):
|
||||
@ -51,7 +59,7 @@ def safe_dumps(data: Any, max_depth: int = DEFAULT_MAX_RECURSE_DEPTH) -> str:
|
||||
else:
|
||||
# Fall back to string conversion for non-serializable objects.
|
||||
try:
|
||||
return str(obj)
|
||||
return strip_null_bytes(str(obj))
|
||||
except Exception:
|
||||
return "Unserializable Object"
|
||||
|
||||
|
||||
@ -24,7 +24,7 @@ from litellm.litellm_core_utils.core_helpers import (
|
||||
get_litellm_metadata_from_kwargs,
|
||||
reconstruct_model_name,
|
||||
)
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps, strip_null_bytes
|
||||
from litellm.proxy._types import SpendLogsMetadata, SpendLogsPayload
|
||||
from litellm.proxy.spend_tracking.spend_log_error_logger import spend_log_error
|
||||
from litellm.proxy.utils import PrismaClient, hash_token
|
||||
@ -304,7 +304,7 @@ def get_logging_payload( # noqa: PLR0915
|
||||
# BUG FIX: Don't overwrite api_key when standard_logging_payload is None
|
||||
# The api_key was already extracted from metadata (line 243) and hashed (lines 256-259)
|
||||
request_tags = (
|
||||
json.dumps(metadata.get("tags", []))
|
||||
safe_dumps(metadata.get("tags", []))
|
||||
if isinstance(metadata.get("tags", []), list)
|
||||
else "[]"
|
||||
)
|
||||
@ -312,7 +312,7 @@ def get_logging_payload( # noqa: PLR0915
|
||||
standard_logging_payload is not None
|
||||
and standard_logging_payload.get("request_tags") is not None
|
||||
): # use 'tags' from standard logging payload instead
|
||||
request_tags = json.dumps(standard_logging_payload["request_tags"])
|
||||
request_tags = safe_dumps(standard_logging_payload["request_tags"])
|
||||
|
||||
_model_id = metadata.get("model_info", {}).get("id", "")
|
||||
_model_group = metadata.get("model_group", "")
|
||||
@ -606,7 +606,7 @@ def _get_messages_for_spend_logs_payload(
|
||||
messages = standard_logging_payload.get("messages")
|
||||
if messages is not None:
|
||||
try:
|
||||
return json.dumps(messages, default=str)
|
||||
return safe_dumps(messages)
|
||||
except Exception:
|
||||
return "{}"
|
||||
return "{}"
|
||||
@ -976,7 +976,7 @@ def _get_proxy_server_request_for_spend_logs_payload(
|
||||
perform_redaction(model_call_details=_request_body, result=None)
|
||||
|
||||
_request_body = _sanitize_request_body_for_spend_logs_payload(_request_body)
|
||||
_request_body_json_str = json.dumps(_request_body, default=str)
|
||||
_request_body_json_str = safe_dumps(_request_body)
|
||||
if LITELLM_TRUNCATED_PAYLOAD_FIELD in _request_body_json_str:
|
||||
verbose_proxy_logger.info(
|
||||
"Spend Log: request body was truncated before storing in DB. %s",
|
||||
@ -1059,7 +1059,7 @@ def _get_response_for_spend_logs_payload(
|
||||
if sanitized_response is None:
|
||||
return "{}"
|
||||
if isinstance(sanitized_response, str):
|
||||
result_str = sanitized_response
|
||||
result_str = strip_null_bytes(sanitized_response)
|
||||
else:
|
||||
result_str = safe_dumps(sanitized_response)
|
||||
if LITELLM_TRUNCATED_PAYLOAD_FIELD in result_str:
|
||||
|
||||
@ -8,7 +8,7 @@ sys.path.insert(
|
||||
0, os.path.abspath("../../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps
|
||||
from litellm.litellm_core_utils.safe_json_dumps import safe_dumps, strip_null_bytes
|
||||
|
||||
|
||||
def test_primitive_types():
|
||||
@ -140,6 +140,40 @@ def test_non_standard_dict_keys_complex():
|
||||
raise e
|
||||
|
||||
|
||||
def test_strip_null_bytes_helper():
|
||||
assert strip_null_bytes("hello\x00world") == "helloworld"
|
||||
assert strip_null_bytes("\x00\x00abc\x00") == "abc"
|
||||
assert strip_null_bytes("no null here") == "no null here"
|
||||
|
||||
|
||||
def test_null_byte_stripped_from_string():
|
||||
out = safe_dumps("hello\x00world")
|
||||
assert "\\u0000" not in out
|
||||
assert json.loads(out) == "helloworld"
|
||||
|
||||
|
||||
def test_null_byte_stripped_in_nested_structure():
|
||||
data = {
|
||||
"messages": [{"role": "user", "content": "bad\x00content"}],
|
||||
"nested": {"k\x00ey": "v\x00alue"},
|
||||
}
|
||||
out = safe_dumps(data)
|
||||
assert "\\u0000" not in out
|
||||
result = json.loads(out)
|
||||
assert result["messages"][0]["content"] == "badcontent"
|
||||
assert result["nested"] == {"key": "value"}
|
||||
|
||||
|
||||
def test_null_byte_stripped_in_fallback_str():
|
||||
class WithNullStr:
|
||||
def __str__(self):
|
||||
return "obj\x00repr"
|
||||
|
||||
out = safe_dumps({"obj": WithNullStr()})
|
||||
assert "\\u0000" not in out
|
||||
assert json.loads(out)["obj"] == "objrepr"
|
||||
|
||||
|
||||
def test_pydantic_base_model():
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
@ -300,6 +300,25 @@ def test_get_messages_for_spend_logs_realtime_returns_messages(mock_should_store
|
||||
assert parsed[1]["content"] == "What is the weather today?"
|
||||
|
||||
|
||||
@patch(
|
||||
"litellm.proxy.spend_tracking.spend_tracking_utils._should_store_prompts_and_responses_in_spend_logs"
|
||||
)
|
||||
def test_get_messages_for_spend_logs_strips_null_bytes(mock_should_store):
|
||||
"""Regression for PostgreSQL 22P05: NUL bytes must be stripped from messages."""
|
||||
mock_should_store.return_value = True
|
||||
payload = cast(
|
||||
StandardLoggingPayload,
|
||||
{
|
||||
"call_type": "_arealtime",
|
||||
"messages": [{"role": "user", "content": "hello\x00world"}],
|
||||
},
|
||||
)
|
||||
result = _get_messages_for_spend_logs_payload(payload)
|
||||
assert "\\u0000" not in result
|
||||
parsed = json.loads(result)
|
||||
assert parsed[0]["content"] == "helloworld"
|
||||
|
||||
|
||||
@patch(
|
||||
"litellm.proxy.spend_tracking.spend_tracking_utils._should_store_prompts_and_responses_in_spend_logs"
|
||||
)
|
||||
@ -370,6 +389,21 @@ def test_get_response_for_spend_logs_payload_truncates_large_base64(mock_should_
|
||||
assert parsed["data"][0]["other_field"] == "value"
|
||||
|
||||
|
||||
@patch(
|
||||
"litellm.proxy.spend_tracking.spend_tracking_utils._should_store_prompts_and_responses_in_spend_logs"
|
||||
)
|
||||
def test_get_response_for_spend_logs_payload_strips_null_bytes(mock_should_store):
|
||||
"""Regression for PostgreSQL 22P05: NUL bytes must be stripped from response."""
|
||||
mock_should_store.return_value = True
|
||||
payload = cast(
|
||||
StandardLoggingPayload,
|
||||
{"response": {"content": "answer\x00here"}},
|
||||
)
|
||||
response_json = _get_response_for_spend_logs_payload(payload)
|
||||
assert "\\u0000" not in response_json
|
||||
assert json.loads(response_json)["content"] == "answerhere"
|
||||
|
||||
|
||||
@patch(
|
||||
"litellm.proxy.spend_tracking.spend_tracking_utils._should_store_prompts_and_responses_in_spend_logs"
|
||||
)
|
||||
@ -936,6 +970,36 @@ def test_get_logging_payload_includes_overhead_in_spend_logs_metadata():
|
||||
), f"Expected overhead '{test_overhead_ms}', got '{metadata.get('litellm_overhead_time_ms')}'"
|
||||
|
||||
|
||||
@patch("litellm.proxy.proxy_server.master_key", None)
|
||||
@patch("litellm.proxy.proxy_server.general_settings", {})
|
||||
def test_get_logging_payload_strips_null_bytes_from_request_tags():
|
||||
"""Regression for PostgreSQL 22P05: NUL bytes must be stripped from request_tags."""
|
||||
kwargs = {
|
||||
"model": "gpt-3.5-turbo",
|
||||
"litellm_params": {
|
||||
"metadata": {
|
||||
"user_api_key": "sk-test-key",
|
||||
"tags": ["clean-tag", "bad\x00tag"],
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
start_time = datetime.datetime.now(timezone.utc)
|
||||
end_time = datetime.datetime.now(timezone.utc)
|
||||
|
||||
payload = get_logging_payload(
|
||||
kwargs=kwargs,
|
||||
response_obj={},
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
request_tags = payload.get("request_tags")
|
||||
assert request_tags is not None
|
||||
assert "\\u0000" not in request_tags
|
||||
assert json.loads(request_tags) == ["clean-tag", "badtag"]
|
||||
|
||||
|
||||
@patch("litellm.proxy.proxy_server.master_key", None)
|
||||
@patch("litellm.proxy.proxy_server.general_settings", {})
|
||||
def test_get_logging_payload_handles_missing_overhead_gracefully():
|
||||
|
||||
Loading…
Reference in New Issue
Block a user