fix(passthrough): resolve costing model when body model is unknown (#30160)

This commit is contained in:
Yassin Kortam 2026-06-11 14:26:55 -07:00 committed by GitHub
parent 8e12d42ea7
commit 1828a7c6f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 312 additions and 0 deletions

View File

@ -100,6 +100,42 @@ class AnthropicPassthroughLoggingHandler:
return get_end_user_id_from_request_body(request_body)
return None
@staticmethod
def _resolve_costing_model(model: str, logging_obj: LiteLLMLoggingObj) -> str:
    """Pick the model name to use for cost calculation.

    Resolution order:

    1. ``model`` itself, when it is non-empty and not the ``"unknown"``
       sentinel.
    2. ``litellm_params["model"]`` read from
       ``logging_obj.model_call_details``, under the same condition.
    3. ``litellm_params["metadata"]["model_group"]`` with a leading
       ``"passthrough/"`` prefix removed.
    4. ``model`` unchanged, when none of the above yields a value.

    Args:
        model: Model name taken from the request/response.
        logging_obj: Logging object whose ``model_call_details`` mapping is
            consulted; a missing or falsy mapping is treated as empty.

    Returns:
        The resolved model name, or ``model`` as given when nothing better
        is available.
    """
    sentinel = "unknown"
    if not model or model == sentinel:
        call_details = getattr(logging_obj, "model_call_details", {}) or {}
        params = call_details.get("litellm_params", {}) or {}

        candidate = params.get("model")
        if candidate and candidate != sentinel:
            return candidate

        metadata = params.get("metadata", {}) or {}
        group = metadata.get("model_group")
        if group:
            return group.removeprefix("passthrough/")
    return model
@staticmethod
def _extract_model_from_anthropic_chunks(
    all_chunks: Sequence[Union[str, bytes]],
) -> Optional[str]:
    """Return the model named by the first usable ``message_start`` event.

    Each chunk is split into lines and only lines that *start* with
    ``data:`` are considered, so a ``data:`` substring inside another SSE
    field cannot shift the JSON offset. The scan is best-effort: anything
    that cannot be decoded, parsed, or does not have the expected shape is
    skipped rather than raised, so a malformed chunk cannot abort the
    caller's logging path.

    Args:
        all_chunks: Collected stream chunks, each either text or UTF-8
            encoded bytes, possibly holding several SSE lines.

    Returns:
        The non-empty string found at ``message.model`` of the first
        ``message_start`` payload that carries one, or ``None`` when no
        such payload exists.
    """
    data_prefix = "data:"
    for raw in all_chunks:
        if isinstance(raw, (bytes, bytearray)):
            # errors="replace": a chunk holding invalid or truncated UTF-8
            # must not raise; the damaged line simply fails JSON parsing.
            text = raw.decode("utf-8", errors="replace")
        else:
            text = raw
        for line in text.splitlines():
            if not line.startswith(data_prefix):
                continue
            try:
                data = json.loads(line[len(data_prefix):].strip())
            except ValueError:
                # json.JSONDecodeError is a ValueError subclass.
                continue
            if not isinstance(data, dict):
                # Scalar payloads such as `data: null` carry no event type.
                continue
            if data.get("type") != "message_start":
                continue
            message = data.get("message")
            if not isinstance(message, dict):
                # A non-mapping `message` would otherwise raise on `.get`.
                continue
            model = message.get("model")
            # Only a real, non-empty string honours the Optional[str] contract.
            if isinstance(model, str) and model:
                return model
    return None
@staticmethod
def _create_anthropic_response_logging_payload(
litellm_model_response: Union[ModelResponse, TextCompletionResponse],
@ -127,6 +163,10 @@ class AnthropicPassthroughLoggingHandler:
"custom_llm_provider"
)
model = AnthropicPassthroughLoggingHandler._resolve_costing_model(
model, logging_obj
)
# Prepend custom_llm_provider to model if not already present
model_for_cost = model
if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"):
@ -213,6 +253,15 @@ class AnthropicPassthroughLoggingHandler:
):
model = cast(str, litellm_logging_obj.model_call_details.get("model"))
if not model or model == "unknown":
chunk_model = (
AnthropicPassthroughLoggingHandler._extract_model_from_anthropic_chunks(
all_chunks
)
)
if chunk_model:
model = chunk_model
complete_streaming_response = (
AnthropicPassthroughLoggingHandler._build_complete_streaming_response(
all_chunks=all_chunks,

View File

@ -321,6 +321,269 @@ class TestAzureAnthropicCostCalculation:
assert call_kwargs["model"] == "azure_ai/claude-sonnet-4-5_gb_20250929"
assert call_kwargs["custom_llm_provider"] == "azure_ai"
@patch("litellm.completion_cost")
def test_cost_calculation_resolves_unknown_model_from_litellm_params(
    self, mock_completion_cost
):
    """The "unknown" body model is replaced by litellm_params["model"].

    Costing must be invoked with the deployment model rather than the
    sentinel; pricing "unknown" would make completion_cost raise and the
    cost silently fall back to $0.
    """
    from datetime import datetime
    from litellm.types.utils import ModelResponse

    expected_model = "anthropic/claude-3-5-haiku-20241022"
    expected_cost = 0.001
    mock_completion_cost.return_value = expected_cost

    # Response object that still carries the sentinel model name.
    response = MagicMock(spec=ModelResponse)
    response.id = "test-id"
    response.model = "unknown"

    # Logging object with both a deployment model and a model_group set.
    logging_obj = self._create_mock_logging_obj(model="unknown")
    logging_obj.model_call_details["litellm_params"] = {
        "model": "anthropic/claude-3-5-haiku-20241022",
        "metadata": {
            "model_group": "passthrough/anthropic/claude-3-5-haiku-20241022"
        },
    }
    logging_obj.litellm_params = logging_obj.model_call_details["litellm_params"]

    payload = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
        litellm_model_response=response,
        model="unknown",
        kwargs={},
        start_time=datetime.now(),
        end_time=datetime.now(),
        logging_obj=logging_obj,
    )

    mock_completion_cost.assert_called_once()
    _, cost_call_kwargs = mock_completion_cost.call_args
    assert cost_call_kwargs["model"] == expected_model
    assert payload["response_cost"] == expected_cost
    assert payload["model"] == expected_model
@patch("litellm.completion_cost")
def test_cost_calculation_resolves_unknown_model_from_model_group(
    self, mock_completion_cost
):
    """Only model_group is available, so it supplies the costing model.

    litellm_params has no "model" entry here; the leading passthrough/
    prefix of model_group must be dropped so the cost map can resolve the
    name.
    """
    from datetime import datetime
    from litellm.types.utils import ModelResponse

    expected_model = "anthropic/claude-3-5-haiku-20241022"
    expected_cost = 0.002
    mock_completion_cost.return_value = expected_cost

    # Response object that still carries the sentinel model name.
    response = MagicMock(spec=ModelResponse)
    response.id = "test-id"
    response.model = "unknown"

    # Logging object exposing nothing but metadata.model_group.
    logging_obj = self._create_mock_logging_obj(model="unknown")
    logging_obj.model_call_details["litellm_params"] = {
        "metadata": {
            "model_group": "passthrough/anthropic/claude-3-5-haiku-20241022"
        }
    }
    logging_obj.litellm_params = logging_obj.model_call_details["litellm_params"]

    payload = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
        litellm_model_response=response,
        model="unknown",
        kwargs={},
        start_time=datetime.now(),
        end_time=datetime.now(),
        logging_obj=logging_obj,
    )

    mock_completion_cost.assert_called_once()
    _, cost_call_kwargs = mock_completion_cost.call_args
    assert cost_call_kwargs["model"] == expected_model
    assert payload["response_cost"] == expected_cost
@patch("litellm.completion_cost")
def test_cost_calculation_skips_unknown_litellm_params_model_for_model_group(
    self, mock_completion_cost
):
    """A sentinel deployment model must not win over a usable model_group.

    litellm_params["model"] is itself "unknown", so resolution has to fall
    through to model_group and costing prices the real model.
    """
    from datetime import datetime
    from litellm.types.utils import ModelResponse

    expected_model = "anthropic/claude-3-5-haiku-20241022"
    expected_cost = 0.003
    mock_completion_cost.return_value = expected_cost

    # Response object that still carries the sentinel model name.
    response = MagicMock(spec=ModelResponse)
    response.id = "test-id"
    response.model = "unknown"

    # Deployment model is the sentinel; only model_group is meaningful.
    logging_obj = self._create_mock_logging_obj(model="unknown")
    logging_obj.model_call_details["litellm_params"] = {
        "model": "unknown",
        "metadata": {
            "model_group": "passthrough/anthropic/claude-3-5-haiku-20241022"
        },
    }
    logging_obj.litellm_params = logging_obj.model_call_details["litellm_params"]

    payload = AnthropicPassthroughLoggingHandler._create_anthropic_response_logging_payload(
        litellm_model_response=response,
        model="unknown",
        kwargs={},
        start_time=datetime.now(),
        end_time=datetime.now(),
        logging_obj=logging_obj,
    )

    mock_completion_cost.assert_called_once()
    _, cost_call_kwargs = mock_completion_cost.call_args
    assert cost_call_kwargs["model"] == expected_model
    assert payload["response_cost"] == expected_cost
    assert payload["model"] == expected_model
@patch("litellm.completion_cost")
def test_streaming_cost_calculation_resolves_model_from_message_start_chunk(
    self, mock_completion_cost
):
    """Streaming costing recovers the model from the message_start event.

    This test gives the logging object an empty litellm_params (no model,
    no model_group) and the "unknown" sentinel as its model, so the only
    source of the real model name is the message_start SSE payload.
    completion_cost must be called with that name instead of "unknown".
    """
    from datetime import datetime
    from litellm.litellm_core_utils.litellm_logging import (
        Logging as RealLoggingObj,
    )
    from litellm.proxy.pass_through_endpoints.streaming_handler import (
        PassThroughStreamingHandler,
    )

    expected_model = "claude-3-5-haiku-20241022"
    expected_cost = 0.001
    mock_completion_cost.return_value = expected_cost

    def _encode_event(event, data):
        # One SSE frame: event line, data line, blank-line terminator.
        return f"event: {event}\ndata: {json.dumps(data)}\n\n".encode()

    message_start = {
        "type": "message_start",
        "message": {
            "id": "msg_1",
            "type": "message",
            "role": "assistant",
            "model": "claude-3-5-haiku-20241022",
            "content": [],
            "stop_reason": None,
            "stop_sequence": None,
            "usage": {"input_tokens": 10, "output_tokens": 0},
        },
    }
    content_block_start = {
        "type": "content_block_start",
        "index": 0,
        "content_block": {"type": "text", "text": ""},
    }
    content_block_delta = {
        "type": "content_block_delta",
        "index": 0,
        "delta": {"type": "text_delta", "text": "hi"},
    }
    message_delta = {
        "type": "message_delta",
        "delta": {"stop_reason": "end_turn", "stop_sequence": None},
        "usage": {"output_tokens": 1},
    }
    raw_events = [
        _encode_event("message_start", message_start),
        _encode_event("content_block_start", content_block_start),
        _encode_event("content_block_delta", content_block_delta),
        _encode_event(
            "content_block_stop", {"type": "content_block_stop", "index": 0}
        ),
        _encode_event("message_delta", message_delta),
        _encode_event("message_stop", {"type": "message_stop"}),
    ]
    all_chunks = list(
        PassThroughStreamingHandler._convert_raw_bytes_to_str_lines(raw_events)
    )

    # Real logging object stripped of every model hint except the sentinel.
    logging_obj = RealLoggingObj(
        model="unknown",
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
        call_type="pass_through_endpoint",
        start_time=datetime.now(),
        litellm_call_id="test-call-id",
        function_id="1",
    )
    logging_obj.model_call_details["model"] = "unknown"
    logging_obj.model_call_details["stream"] = True
    logging_obj.model_call_details["litellm_params"] = {}
    logging_obj.litellm_params = {}

    outcome = AnthropicPassthroughLoggingHandler._handle_logging_anthropic_collected_chunks(
        litellm_logging_obj=logging_obj,
        passthrough_success_handler_obj=MagicMock(),
        url_route="/anthropic/v1/messages",
        request_body={"stream": True},
        endpoint_type="messages",
        start_time=datetime.now(),
        all_chunks=all_chunks,
        end_time=datetime.now(),
    )

    assert outcome["result"] is not None
    mock_completion_cost.assert_called_once()
    _, cost_call_kwargs = mock_completion_cost.call_args
    assert cost_call_kwargs["model"] == expected_model
    assert outcome["kwargs"]["response_cost"] == expected_cost
    assert outcome["kwargs"]["model"] == expected_model
def test_extract_model_skips_non_dict_data_payload(self):
    """Scalar `data:` payloads are ignored instead of raising.

    A payload such as `data: null` parses to a non-dict; the extractor must
    move past it (no AttributeError escaping into the streaming log
    handler) and still find the model in the later message_start event.
    """
    ping_with_null_payload = "event: ping\ndata: null\n\n"
    message_start_event = (
        'event: message_start\ndata: {"type": "message_start", "message": '
        '{"model": "claude-3-5-haiku-20241022"}}\n\n'
    )

    extracted = (
        AnthropicPassthroughLoggingHandler._extract_model_from_anthropic_chunks(
            [ping_with_null_payload, message_start_event]
        )
    )

    assert extracted == "claude-3-5-haiku-20241022"
def test_extract_model_parses_per_line_not_first_data_substring(self):
    """Only lines beginning with "data:" are treated as payload lines.

    The first line of the event merely contains the substring "data:".
    Slicing at the first occurrence of that substring would hand invalid
    JSON to the parser and yield None; matching on line prefix recovers
    the message_start model from the second line.
    """
    misleading_line = "event: ping data: not-json\n"
    payload_line = (
        'data: {"type": "message_start", "message": '
        '{"model": "claude-3-5-haiku-20241022"}}\n\n'
    )
    raw_event = misleading_line + payload_line

    extracted = (
        AnthropicPassthroughLoggingHandler._extract_model_from_anthropic_chunks(
            [raw_event]
        )
    )

    assert extracted == "claude-3-5-haiku-20241022"
def test_passthrough_logging_sets_response_cost_with_server_tool_use_dict(self):
from litellm.types.utils import Choices, Message, ModelResponse