fix(guardrails): read CrowdStrike AIDR identity from both metadata bags (#29991)

Capture user_id and extra_info from metadata or litellm_metadata. The single-bag read dropped identity whenever a request carried a present litellm_metadata field (null or a user-supplied dict), since /chat/completions routes the authenticated identity into metadata while the guardrail read litellm_metadata first
This commit is contained in:
yuneng-jiang 2026-06-08 17:46:28 -07:00 committed by GitHub
parent 411bd3da5b
commit 1bbaf1c39d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 62 additions and 2 deletions

View File

@ -105,6 +105,16 @@ def _extract_text_from_content(content: object) -> str:
return ""
def _merge_metadata_bags(request_data: Mapping[str, Any]) -> Optional[dict[str, Any]]:
merged: dict[str, Any] = {}
present = False
for bag in (request_data.get("metadata"), request_data.get("litellm_metadata")):
if isinstance(bag, Mapping):
present = True
merged.update(bag)
return merged if present else None
class CrowdStrikeAIDRHandler(CustomGuardrail):
"""
CrowdStrike AIDR AI Guardrail handler to interact with the CrowdStrike AIDR
@ -321,8 +331,8 @@ class CrowdStrikeAIDRHandler(CustomGuardrail):
if model:
ai_guard_payload["model"] = model
metadata = request_data.get("litellm_metadata", request_data.get("metadata"))
if isinstance(metadata, Mapping):
metadata = _merge_metadata_bags(request_data)
if metadata is not None:
user_id = metadata.get("user_api_key_user_id")
if user_id:
ai_guard_payload["user_id"] = user_id

View File

@ -529,6 +529,56 @@ async def test_apply_guardrail_no_metadata_skips_user_fields(
assert "extra_info" not in payload
@pytest.mark.asyncio
@pytest.mark.parametrize(
"litellm_metadata, metadata",
[
(None, {"user_api_key_user_id": "uid-abc", "user_api_key_user_email": "alice@example.com"}),
({"trace_id": "t1"}, {"user_api_key_user_id": "uid-abc", "user_api_key_user_email": "alice@example.com"}),
(["unexpected"], {"user_api_key_user_id": "uid-abc", "user_api_key_user_email": "alice@example.com"}),
({"user_api_key_user_id": "uid-abc", "user_api_key_user_email": "alice@example.com"}, {"trace_id": "t1"}),
],
ids=["identity_in_metadata_llm_none", "identity_in_metadata_llm_user_dict", "identity_in_metadata_llm_non_mapping", "identity_in_litellm_metadata"],
)
async def test_apply_guardrail_reads_identity_from_either_metadata_bag(
crowdstrike_aidr_guardrail: CrowdStrikeAIDRHandler,
litellm_metadata,
metadata,
) -> None:
inputs: GenericGuardrailAPIInputs = {
"texts": ["Hello"],
"structured_messages": [{"role": "user", "content": "Hello"}],
"model": "gpt-4o",
}
request_data = {
"messages": inputs["structured_messages"],
"model": "gpt-4o",
"litellm_metadata": litellm_metadata,
"metadata": metadata,
}
guardrail_endpoint = (
f"{crowdstrike_aidr_guardrail.api_base}/v1/guard_chat_completions"
)
with patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post",
return_value=httpx.Response(
status_code=200,
json={"result": {"blocked": False, "transformed": False}},
request=httpx.Request(method="POST", url=guardrail_endpoint),
),
) as mock_method:
await crowdstrike_aidr_guardrail.apply_guardrail(
inputs=inputs,
request_data=request_data,
input_type="request",
)
payload = mock_method.call_args.kwargs["json"]
assert payload["user_id"] == "uid-abc"
assert payload["extra_info"] == {"user_name": "alice@example.com"}
@pytest.mark.asyncio
async def test_apply_guardrail_request_skipped_messages_stay_aligned(
crowdstrike_aidr_guardrail: CrowdStrikeAIDRHandler,