diff --git a/deploy/charts/litellm-helm/templates/deployment.yaml b/deploy/charts/litellm-helm/templates/deployment.yaml index aefe2a564b..b9cd1be06e 100644 --- a/deploy/charts/litellm-helm/templates/deployment.yaml +++ b/deploy/charts/litellm-helm/templates/deployment.yaml @@ -30,7 +30,7 @@ spec: checksum/config: {{ include (print $.Template.BasePath "/configmap-litellm.yaml") . | sha256sum }} {{- end }} {{- with .Values.podAnnotations }} - {{- toYaml . | nindent 8 }} + {{- tpl (toYaml .) $ | nindent 8 }} {{- end }} labels: {{- include "litellm.labels" . | nindent 8 }} diff --git a/deploy/charts/litellm-helm/tests/deployment_tests.yaml b/deploy/charts/litellm-helm/tests/deployment_tests.yaml index df6d134564..f3d62651d8 100644 --- a/deploy/charts/litellm-helm/tests/deployment_tests.yaml +++ b/deploy/charts/litellm-helm/tests/deployment_tests.yaml @@ -377,3 +377,28 @@ tests: content: name: sidecar-tpl image: "ghcr.io/berriai/litellm-database:test" + - it: should support tpl in podAnnotations + template: deployment.yaml + set: + image: + repository: ghcr.io/berriai/litellm-database + tag: test + # Mirrors the real-world scenario this feature unblocks: + # user disables the built-in ConfigMap (and its built-in checksum/config + # annotation) and re-implements checksum/config themselves via tpl. + proxyConfigMap: + create: false + podAnnotations: + checksum/config: "{{ .Values.image.tag }}" + example.com/some-key: "{{ .Values.image.repository }}" + example.com/literal: "plain-string-value" + asserts: + - equal: + path: spec.template.metadata.annotations["checksum/config"] + value: "test" + - equal: + path: spec.template.metadata.annotations["example.com/some-key"] + value: "ghcr.io/berriai/litellm-database" + - equal: + path: spec.template.metadata.annotations["example.com/literal"] + value: "plain-string-value" diff --git a/litellm/integrations/arize/arize_phoenix.py b/litellm/integrations/arize/arize_phoenix.py index b8cd04836c..d48dba8e7b 100644 --- a/litellm/integrations/arize/arize_phoenix.py +++ b/litellm/integrations/arize/arize_phoenix.py @@ -1,5 +1,7 @@ import os -from typing import TYPE_CHECKING, Any, Optional, Union +import threading +from collections import OrderedDict +from typing import TYPE_CHECKING, Any, Optional, Tuple, Union from litellm._logging import verbose_logger from litellm.integrations.arize import _utils @@ -8,8 +10,10 @@ from litellm.types.integrations.arize_phoenix import ArizePhoenixConfig if TYPE_CHECKING: from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SpanProcessor from opentelemetry.trace import Span as _Span from opentelemetry.trace import SpanKind + from opentelemetry.trace import Tracer from litellm.integrations.opentelemetry import OpenTelemetry as _OpenTelemetry from litellm.integrations.opentelemetry import ( @@ -21,20 +25,27 @@ if TYPE_CHECKING: OpenTelemetryConfig = _OpenTelemetryConfig Span = Union[_Span, Any] OpenTelemetry = _OpenTelemetry + LITELLM_TRACER_NAME: str else: Protocol = Any OpenTelemetryConfig = Any Span = Any + Tracer = Any TracerProvider = Any SpanKind = Any - # Import OpenTelemetry at runtime + SpanProcessor = Any try: - from litellm.integrations.opentelemetry import OpenTelemetry + from litellm.integrations.opentelemetry import ( + LITELLM_TRACER_NAME, + OpenTelemetry, + ) except ImportError: + LITELLM_TRACER_NAME = "litellm" OpenTelemetry = None # type: ignore ARIZE_HOSTED_PHOENIX_ENDPOINT = "https://otlp.arize.com/v1/traces" +_MAX_PROJECT_PROVIDERS = 64 class ArizePhoenixLogger(OpenTelemetry): # type: ignore @@ -48,37 +59,142 @@ class ArizePhoenixLogger(OpenTelemetry): # type: ignore def _init_tracing(self, tracer_provider): """ - Override to always create a *private* TracerProvider for Arize Phoenix. + Override to create per-project TracerProviders (LRU-cached) for Arize Phoenix. The base ``OpenTelemetry._init_tracing`` falls back to the global TracerProvider when one already exists. That causes whichever integration initialises second to silently reuse the first one's exporter, so spans only reach one destination. - - By creating our own provider we guarantee Arize Phoenix always gets - its own exporter pipeline, regardless of initialisation order. """ - from opentelemetry.sdk.trace import TracerProvider from opentelemetry.trace import SpanKind if tracer_provider is not None: - # Explicitly supplied (e.g. in tests) — honour it. - self.tracer = tracer_provider.get_tracer("litellm") + self._use_injected_tracer_provider = True + self._shared_span_processor = None + self.tracer = tracer_provider.get_tracer(LITELLM_TRACER_NAME) self.span_kind = SpanKind return - # Always create a dedicated provider — never touch the global one. - provider = TracerProvider(resource=self._get_litellm_resource(self.config)) - provider.add_span_processor(self._get_span_processor()) - self.tracer = provider.get_tracer("litellm") + self._use_injected_tracer_provider = False + self._project_providers: OrderedDict[str, TracerProvider] = OrderedDict() + self._project_providers_lock = threading.Lock() + self._shared_span_processor = self._get_span_processor() self.span_kind = SpanKind + + default_project = self._resolve_project_name({}) + self.tracer = self._get_tracer_for(default_project) verbose_logger.debug( - "ArizePhoenixLogger: Created dedicated TracerProvider " - "(endpoint=%s, exporter=%s)", + "ArizePhoenixLogger: Initialized per-project TracerProvider cache " + "(default_project=%s, endpoint=%s, exporter=%s)", + default_project, self.config.endpoint, self.config.exporter, ) + def flush_tracer_providers(self) -> None: + """ + Flush all cached per-project providers and the shared span processor. + + Call on graceful proxy shutdown. Do not call on LRU eviction — in-flight + spans may still reference evicted providers. + """ + if getattr(self, "_use_injected_tracer_provider", False): + return + + shared_processor = getattr(self, "_shared_span_processor", None) + if shared_processor is not None: + try: + shared_processor.force_flush() + except Exception as e: + verbose_logger.debug( + "ArizePhoenixLogger: shared span processor force_flush failed: %s", + e, + ) + + with getattr(self, "_project_providers_lock", threading.Lock()): + providers = list(getattr(self, "_project_providers", {}).values()) + + for provider in providers: + try: + provider.force_flush() + except Exception as e: + verbose_logger.debug( + "ArizePhoenixLogger: TracerProvider force_flush failed: %s", e + ) + + def _get_litellm_resource_for_project(self, project_name: str): + """ + Build an OTEL Resource with project routing attrs that win over env detector. + + Phoenix uses ``openinference.project.name``; Arize AX uses ``model_id`` and + ``service.name``. Project attrs are merged last so OTEL_RESOURCE_ATTRIBUTES + from init does not pin every provider to one project. + """ + from opentelemetry.sdk.resources import OTELResourceDetector, Resource + + project_attributes: dict[str, str] = { + "openinference.project.name": project_name, + "model_id": project_name, + "service.name": project_name, + } + deployment_environment = getattr(self.config, "deployment_environment", None) + if deployment_environment is not None: + project_attributes["deployment.environment"] = deployment_environment + + env_resource = OTELResourceDetector().detect() + project_resource = Resource.create(project_attributes) # type: ignore[arg-type] + return env_resource.merge(project_resource) + + def _build_tracer_provider_for_project(self, project_name: str) -> TracerProvider: + """Create a TracerProvider for *project_name* (caller holds no cache lock).""" + from opentelemetry.sdk.trace import TracerProvider + + provider = TracerProvider( + resource=self._get_litellm_resource_for_project(project_name) + ) + provider.add_span_processor(self._shared_span_processor) + return provider + + def _get_tracer_for(self, project_name: str) -> Tracer: + """Return a tracer for *project_name*, creating/caching a provider on miss.""" + if getattr(self, "_use_injected_tracer_provider", False): + return self.tracer + + with self._project_providers_lock: + if project_name in self._project_providers: + self._project_providers.move_to_end(project_name) + return self._project_providers[project_name].get_tracer( + LITELLM_TRACER_NAME + ) + + # OTELResourceDetector().detect() is synchronous; build outside the lock so + # concurrent requests for other projects are not blocked on cache misses. + new_provider = self._build_tracer_provider_for_project(project_name) + + with self._project_providers_lock: + if project_name in self._project_providers: + self._project_providers.move_to_end(project_name) + return self._project_providers[project_name].get_tracer( + LITELLM_TRACER_NAME + ) + + if len(self._project_providers) >= _MAX_PROJECT_PROVIDERS: + self._project_providers.popitem(last=False) + + self._project_providers[project_name] = new_provider + return new_provider.get_tracer(LITELLM_TRACER_NAME) + + def _resolve_tracer_for_kwargs(self, kwargs: dict) -> Tuple[str, Tracer]: + """Resolve project name once and return the matching tracer.""" + project_name = self._resolve_project_name(kwargs) + return project_name, self._get_tracer_for(project_name) + + def get_tracer_to_use_for_request(self, kwargs: dict) -> Tracer: + """Route guardrail/raw-request spans to the same per-project tracer as the request.""" + if getattr(self, "_use_injected_tracer_provider", False): + return self.tracer + return self._resolve_tracer_for_kwargs(kwargs)[1] + def _init_otel_logger_on_litellm_proxy(self): """ Override: Arize Phoenix should NOT overwrite the proxy's @@ -93,56 +209,109 @@ class ArizePhoenixLogger(OpenTelemetry): # type: ignore @staticmethod def set_arize_phoenix_attributes(span: Span, kwargs, response_obj): - from litellm.integrations.opentelemetry_utils.base_otel_llm_obs_attributes import ( - safe_set_attribute, - ) - _utils.set_attributes(span, kwargs, response_obj, ArizeOTELAttributes) - - # Dynamic project name: check metadata first, then fall back to env var config - dynamic_project_name = ArizePhoenixLogger._get_dynamic_project_name(kwargs) - if dynamic_project_name: - safe_set_attribute(span, "openinference.project.name", dynamic_project_name) - else: - # Fall back to static config from env var - config = ArizePhoenixLogger.get_arize_phoenix_config() - if config.project_name: - safe_set_attribute( - span, "openinference.project.name", config.project_name - ) - return @staticmethod - def _get_dynamic_project_name(kwargs) -> Optional[str]: - """ - Retrieve dynamic Phoenix project name from request metadata. + def _normalize_project_name(name: Optional[str]) -> Optional[str]: + if name is None: + return None + normalized = str(name).strip() + return normalized if normalized else None - Users can set `metadata.phoenix_project_name` in their request to route - traces to different Phoenix projects dynamically. - """ - standard_logging_payload = kwargs.get("standard_logging_object") - if isinstance(standard_logging_payload, dict): - metadata = standard_logging_payload.get("metadata") + @staticmethod + def _iter_metadata_dicts_from_kwargs(kwargs: dict): + """Yield request metadata dicts; standard_logging_object before litellm_params.""" + for key in ("standard_logging_object", "litellm_params"): + found_key = kwargs.get(key) + if not isinstance(found_key, dict): + continue + metadata = found_key.get("metadata") if isinstance(metadata, dict): - project_name = metadata.get("phoenix_project_name") - if project_name: - return str(project_name) + yield metadata - # Also check litellm_params.metadata for SDK usage + @staticmethod + def _is_proxy_request(kwargs: dict) -> bool: + """True when the call is routed through the LiteLLM proxy. + + Proxy mode is determined solely by the server-set ``proxy_server_request`` + field in ``litellm_params``. Checking request metadata for + ``user_api_key_auth_metadata`` is intentionally avoided: that field is + user-supplied and would let an authenticated caller fake proxy-mode + detection to route their telemetry into arbitrary Arize/Phoenix projects. + """ litellm_params = kwargs.get("litellm_params") - if isinstance(litellm_params, dict): - metadata = litellm_params.get("metadata") or {} - else: - metadata = {} - if isinstance(metadata, dict): - project_name = metadata.get("phoenix_project_name") - if project_name: - return str(project_name) + return isinstance(litellm_params, dict) and bool( + litellm_params.get("proxy_server_request") + ) + @staticmethod + def _project_from_metadata_dict( + metadata: dict, metadata_key: str, *, proxy_mode: bool + ) -> Optional[str]: + """ + Read a Phoenix project field from proxy/SDK metadata. + + On the proxy, only ``user_api_key_auth_metadata`` (team/key config) may + select the project. SDK callers may still set project fields directly on + ``metadata``. + """ + auth_metadata = metadata.get("user_api_key_auth_metadata") + if isinstance(auth_metadata, dict): + project = ArizePhoenixLogger._normalize_project_name( + auth_metadata.get(metadata_key) + ) + if project: + return project + + if not proxy_mode: + return ArizePhoenixLogger._normalize_project_name( + metadata.get(metadata_key) + ) return None - def _get_phoenix_context(self, kwargs): + @staticmethod + def _metadata_project_from_kwargs(kwargs: dict, metadata_key: str) -> Optional[str]: + proxy_mode = ArizePhoenixLogger._is_proxy_request(kwargs) + for metadata in ArizePhoenixLogger._iter_metadata_dicts_from_kwargs(kwargs): + project = ArizePhoenixLogger._project_from_metadata_dict( + metadata, metadata_key, proxy_mode=proxy_mode + ) + if project: + return project + return None + + @staticmethod + def _resolve_project_name(kwargs: dict) -> str: + """ + Resolve the target Phoenix/Arize project for this request. + + Proxy priority: ``user_api_key_auth_metadata.phoenix_project_name_override``, + ``user_api_key_auth_metadata.phoenix_project_name``, env, then ``default``. + SDK priority: request metadata fields, then env, then ``default``. + """ + override = ArizePhoenixLogger._metadata_project_from_kwargs( + kwargs, "phoenix_project_name_override" + ) + if override: + return override + + phoenix_name = ArizePhoenixLogger._metadata_project_from_kwargs( + kwargs, "phoenix_project_name" + ) + if phoenix_name: + return phoenix_name + + env_name = ArizePhoenixLogger._normalize_project_name( + os.environ.get("PHOENIX_PROJECT_NAME") + or os.environ.get("ARIZE_PROJECT_NAME") + ) + if env_name: + return env_name + + return "default" + + def _get_phoenix_context(self, kwargs, tracer: Optional[Tracer] = None): """ Build a trace context for Phoenix's dedicated TracerProvider. @@ -159,11 +328,13 @@ class ArizePhoenixLogger(OpenTelemetry): # type: ignore """ from opentelemetry import trace + if tracer is None: + tracer = self._resolve_tracer_for_kwargs(kwargs)[1] + litellm_params = kwargs.get("litellm_params", {}) or {} proxy_server_request = litellm_params.get("proxy_server_request", {}) or {} headers = proxy_server_request.get("headers", {}) or {} - # Propagate distributed trace context if the caller sent a traceparent traceparent_ctx = ( self.get_traceparent_from_header(headers=headers) if headers.get("traceparent") @@ -173,10 +344,8 @@ class ArizePhoenixLogger(OpenTelemetry): # type: ignore is_proxy_mode = bool(proxy_server_request) if is_proxy_mode: - # Create a parent span on Phoenix's own tracer so both parent - # and child are exported to Phoenix. start_time_val = kwargs.get("start_time", kwargs.get("api_call_start_time")) - parent_span = self.tracer.start_span( + parent_span = tracer.start_span( name="litellm_proxy_request", start_time=( self._to_ns(start_time_val) if start_time_val is not None else None @@ -187,100 +356,77 @@ class ArizePhoenixLogger(OpenTelemetry): # type: ignore ctx = trace.set_span_in_context(parent_span) return ctx, parent_span - # SDK mode — no parent span needed return traceparent_ctx, None def _handle_success(self, kwargs, response_obj, start_time, end_time): - """ - Override to always create spans on ArizePhoenixLogger's dedicated TracerProvider. - - The base class's ``_get_span_context`` would find the parent span created by - the ``otel`` callback on the *global* TracerProvider. That span is invisible - in Phoenix (different exporter pipeline), so we ignore it and build our own - hierarchy via ``_get_phoenix_context``. - """ - from opentelemetry.trace import Status, StatusCode - - verbose_logger.debug( - "ArizePhoenixLogger: Logging kwargs: %s, OTEL config settings=%s", - kwargs, - self.config, + self._handle_phoenix_trace( + kwargs, response_obj, start_time, end_time, success=True ) - ctx, parent_span = self._get_phoenix_context(kwargs) - - # Create litellm_request span (child of our parent when in proxy mode) - span = self.tracer.start_span( - name=self._get_span_name(kwargs), - start_time=self._to_ns(start_time), - context=ctx, - ) - span.set_status(Status(StatusCode.OK)) - self.set_attributes(span, kwargs, response_obj) - - # Raw-request sub-span (if enabled) — must be created before - # ending the parent span so the hierarchy is valid. - self._maybe_log_raw_request(kwargs, response_obj, start_time, end_time, span) - span.end(end_time=self._to_ns(end_time)) - - # Guardrail span - self._create_guardrail_span(kwargs=kwargs, context=ctx) - - # Annotate and close our proxy parent span - if parent_span is not None: - parent_span.set_status(Status(StatusCode.OK)) - self.set_attributes(parent_span, kwargs, response_obj) - parent_span.end(end_time=self._to_ns(end_time)) - - # Metrics & cost recording - self._record_metrics(kwargs, response_obj, start_time, end_time) - - # Semantic logs - if self.config.enable_events: - self._emit_semantic_logs(kwargs, response_obj, span) - def _handle_failure(self, kwargs, response_obj, start_time, end_time): - """ - Override to always create failure spans on ArizePhoenixLogger's dedicated - TracerProvider. Mirrors ``_handle_success`` but sets ERROR status. - """ + self._handle_phoenix_trace( + kwargs, response_obj, start_time, end_time, success=False + ) + + def _handle_phoenix_trace( + self, + kwargs, + response_obj, + start_time, + end_time, + *, + success: bool, + ): from opentelemetry.trace import Status, StatusCode verbose_logger.debug( - "ArizePhoenixLogger: Failure - Logging kwargs: %s, OTEL config settings=%s", + "ArizePhoenixLogger: %s - kwargs: %s, OTEL config settings=%s", + "success" if success else "failure", kwargs, self.config, ) - ctx, parent_span = self._get_phoenix_context(kwargs) + _project_name, tracer = self._resolve_tracer_for_kwargs(kwargs) + ctx, parent_span = self._get_phoenix_context(kwargs, tracer=tracer) - # Create litellm_request span (child of our parent when in proxy mode) - span = self.tracer.start_span( + status = Status(StatusCode.OK if success else StatusCode.ERROR) + + span = tracer.start_span( name=self._get_span_name(kwargs), start_time=self._to_ns(start_time), context=ctx, ) - span.set_status(Status(StatusCode.ERROR)) + span.set_status(status) self.set_attributes(span, kwargs, response_obj) - self._record_exception_on_span(span=span, kwargs=kwargs) + if not success: + self._record_exception_on_span(span=span, kwargs=kwargs) + + if success: + self._maybe_log_raw_request( + kwargs, response_obj, start_time, end_time, span + ) span.end(end_time=self._to_ns(end_time)) - # Guardrail span self._create_guardrail_span(kwargs=kwargs, context=ctx) - # Annotate and close our proxy parent span if parent_span is not None: - parent_span.set_status(Status(StatusCode.ERROR)) + parent_span.set_status(status) self.set_attributes(parent_span, kwargs, response_obj) - self._record_exception_on_span(span=parent_span, kwargs=kwargs) + if not success: + self._record_exception_on_span(span=parent_span, kwargs=kwargs) parent_span.end(end_time=self._to_ns(end_time)) + if success: + self._record_metrics(kwargs, response_obj, start_time, end_time) + + if self.config.enable_events: + self._emit_semantic_logs(kwargs, response_obj, span) + @staticmethod def get_arize_phoenix_config() -> ArizePhoenixConfig: """ Retrieves the Arize Phoenix configuration based on environment variables. Returns: - ArizePhoenixConfig: A Pydantic model containing Arize Phoenix configuration. """ api_key = os.environ.get("PHOENIX_API_KEY", None) @@ -295,18 +441,15 @@ class ArizePhoenixLogger(OpenTelemetry): # type: ignore protocol: Protocol = "otlp_http" if collector_endpoint: - # Parse the endpoint to determine protocol if collector_endpoint.startswith("grpc://") or ( ":4317" in collector_endpoint and "/v1/traces" not in collector_endpoint ): endpoint = collector_endpoint protocol = "otlp_grpc" else: - # Phoenix Cloud endpoints (app.phoenix.arize.com) include the space in the URL if "app.phoenix.arize.com" in collector_endpoint: endpoint = collector_endpoint protocol = "otlp_http" - # For other HTTP endpoints, ensure they have the correct path elif "/v1/traces" not in collector_endpoint: if collector_endpoint.endswith("/v1"): endpoint = collector_endpoint + "/traces" @@ -318,7 +461,6 @@ class ArizePhoenixLogger(OpenTelemetry): # type: ignore endpoint = collector_endpoint protocol = "otlp_http" else: - # If no endpoint specified, self hosted phoenix endpoint = "http://localhost:6006/v1/traces" protocol = "otlp_http" verbose_logger.debug( @@ -329,12 +471,11 @@ class ArizePhoenixLogger(OpenTelemetry): # type: ignore if api_key is not None: otlp_auth_headers = f"Authorization=Bearer {api_key}" elif "app.phoenix.arize.com" in endpoint: - # Phoenix Cloud requires an API key raise ValueError( "PHOENIX_API_KEY must be set when using Phoenix Cloud (app.phoenix.arize.com)." ) - project_name = os.environ.get("PHOENIX_PROJECT_NAME", "default") + project_name = os.environ.get("PHOENIX_PROJECT_NAME") or "default" return ArizePhoenixConfig( otlp_auth_headers=otlp_auth_headers, @@ -343,8 +484,6 @@ class ArizePhoenixLogger(OpenTelemetry): # type: ignore project_name=project_name, ) - ## cannot suppress additional proxy server spans, removed previous methods. - async def async_health_check(self): config = self.get_arize_phoenix_config() diff --git a/litellm/integrations/datadog/datadog_metrics.py b/litellm/integrations/datadog/datadog_metrics.py index fcf40701e2..d7847027d7 100644 --- a/litellm/integrations/datadog/datadog_metrics.py +++ b/litellm/integrations/datadog/datadog_metrics.py @@ -144,7 +144,26 @@ class DatadogMetricsLogger(CustomBatchLogger): } self.log_queue.append(series_llm_latency) - # 3. Request Count / Status Code + # 3. LiteLLM Overhead Latency Metric (total - llm_api time) + hidden_params = log.get("hidden_params", {}) or {} + litellm_overhead_time_ms = hidden_params.get("litellm_overhead_time_ms") + if litellm_overhead_time_ms is not None: + overhead_tags = self._extract_tags(log) # no status_code on latency metric + series_overhead: DatadogMetricSeries = { + "metric": "litellm.overhead.latency", + "type": 3, # gauge + "points": [ + { + "timestamp": timestamp, + "value": litellm_overhead_time_ms + / 1000, # convert ms → seconds + } + ], + "tags": overhead_tags, + } + self.log_queue.append(series_overhead) + + # 4. Request Count / Status Code series_count: DatadogMetricSeries = { "metric": "litellm.llm_api.request_count", "type": 1, # count diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index ef0e674715..97266096ef 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -3910,31 +3910,6 @@ def _init_custom_logger_compatible_class( # noqa: PLR0915 endpoint=arize_phoenix_config.endpoint, headers=arize_phoenix_config.otlp_auth_headers, ) - if arize_phoenix_config.project_name: - existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "") - # Add openinference.project.name attribute - if existing_attrs: - os.environ["OTEL_RESOURCE_ATTRIBUTES"] = ( - f"{existing_attrs},openinference.project.name={arize_phoenix_config.project_name}" - ) - else: - os.environ["OTEL_RESOURCE_ATTRIBUTES"] = ( - f"openinference.project.name={arize_phoenix_config.project_name}" - ) - - # Set Phoenix project name from environment variable - phoenix_project_name = os.environ.get("PHOENIX_PROJECT_NAME", None) - if phoenix_project_name: - existing_attrs = os.environ.get("OTEL_RESOURCE_ATTRIBUTES", "") - # Add openinference.project.name attribute - if existing_attrs: - os.environ["OTEL_RESOURCE_ATTRIBUTES"] = ( - f"{existing_attrs},openinference.project.name={phoenix_project_name}" - ) - else: - os.environ["OTEL_RESOURCE_ATTRIBUTES"] = ( - f"openinference.project.name={phoenix_project_name}" - ) # auth can be disabled on local deployments of arize phoenix if arize_phoenix_config.otlp_auth_headers is not None: diff --git a/litellm/litellm_core_utils/llm_cost_calc/utils.py b/litellm/litellm_core_utils/llm_cost_calc/utils.py index 6c999590dd..882561ed2e 100644 --- a/litellm/litellm_core_utils/llm_cost_calc/utils.py +++ b/litellm/litellm_core_utils/llm_cost_calc/utils.py @@ -30,6 +30,9 @@ _IMAGE_RESPONSE_CALL_TYPES = frozenset( } ) +# Pre-resolved DataResidency enum values for fast membership checks +_VALID_DATA_RESIDENCIES = frozenset(r.value for r in DataResidency) + def _is_above_128k(tokens: float) -> bool: if tokens > 128000: @@ -636,7 +639,7 @@ def _get_regional_uplift_multiplier( if data_residency is None: return 1.0 residency = data_residency.lower() - if residency not in {r.value for r in DataResidency}: + if residency not in _VALID_DATA_RESIDENCIES: return 1.0 multiplier = model_info.get(f"regional_processing_uplift_multiplier_{residency}") if multiplier is None: diff --git a/litellm/llms/base_llm/chat/transformation.py b/litellm/llms/base_llm/chat/transformation.py index bec25916c4..5f35a58ce1 100644 --- a/litellm/llms/base_llm/chat/transformation.py +++ b/litellm/llms/base_llm/chat/transformation.py @@ -108,10 +108,9 @@ class BaseConfig(ABC): return type_to_response_format_param(response_format=response_format) def is_thinking_enabled(self, non_default_params: dict) -> bool: - return ( - non_default_params.get("thinking", {}).get("type") == "enabled" - or non_default_params.get("reasoning_effort") is not None - ) + return (non_default_params.get("thinking") or {}).get( + "type" + ) == "enabled" or non_default_params.get("reasoning_effort") is not None def is_max_tokens_in_request(self, non_default_params: dict) -> bool: """ diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index bed030d284..ec16f19799 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -17919,22 +17919,9 @@ }, "github_copilot/claude-haiku-4.5": { "litellm_provider": "github_copilot", - "max_input_tokens": 128000, - "max_output_tokens": 16000, - "max_tokens": 16000, - "mode": "chat", - "supported_endpoints": [ - "/v1/chat/completions" - ], - "supports_function_calling": true, - "supports_parallel_function_calling": true, - "supports_vision": true - }, - "github_copilot/claude-opus-4.5": { - "litellm_provider": "github_copilot", - "max_input_tokens": 128000, - "max_output_tokens": 16000, - "max_tokens": 16000, + "max_input_tokens": 200000, + "max_output_tokens": 32000, + "max_tokens": 32000, "mode": "chat", "supported_endpoints": [ "/v1/chat/completions" @@ -17942,7 +17929,22 @@ "supports_function_calling": true, "supports_parallel_function_calling": true, "supports_vision": true, - "supports_minimal_reasoning_effort": true + "supports_reasoning": true + }, + "github_copilot/claude-opus-4.5": { + "litellm_provider": "github_copilot", + "max_input_tokens": 200000, + "max_output_tokens": 32000, + "max_tokens": 32000, + "mode": "chat", + "supported_endpoints": [ + "/v1/chat/completions" + ], + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_vision": true, + "supports_minimal_reasoning_effort": true, + "supports_reasoning": true }, "github_copilot/claude-opus-4.6-fast": { "litellm_provider": "github_copilot", @@ -17957,6 +17959,22 @@ "supports_parallel_function_calling": true, "supports_vision": true }, + "github_copilot/claude-opus-4.7": { + "litellm_provider": "github_copilot", + "max_input_tokens": 200000, + "max_output_tokens": 64000, + "max_tokens": 64000, + "mode": "chat", + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/messages" + ], + "supports_vision": true, + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_response_schema": true, + "supports_reasoning": true + }, "github_copilot/claude-opus-41": { "litellm_provider": "github_copilot", "max_input_tokens": 80000, @@ -17983,16 +18001,33 @@ }, "github_copilot/claude-sonnet-4.5": { "litellm_provider": "github_copilot", - "max_input_tokens": 128000, - "max_output_tokens": 16000, - "max_tokens": 16000, + "max_input_tokens": 200000, + "max_output_tokens": 32000, + "max_tokens": 32000, "mode": "chat", "supported_endpoints": [ "/v1/chat/completions" ], "supports_function_calling": true, "supports_parallel_function_calling": true, - "supports_vision": true + "supports_vision": true, + "supports_reasoning": true + }, + "github_copilot/claude-sonnet-4.6": { + "litellm_provider": "github_copilot", + "max_input_tokens": 200000, + "max_output_tokens": 32000, + "max_tokens": 32000, + "mode": "chat", + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/messages" + ], + "supports_vision": true, + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_response_schema": true, + "supports_reasoning": true }, "github_copilot/gemini-2.5-pro": { "litellm_provider": "github_copilot", @@ -18002,7 +18037,25 @@ "mode": "chat", "supports_function_calling": true, "supports_parallel_function_calling": true, - "supports_vision": true + "supports_vision": true, + "supported_endpoints": [ + "/v1/chat/completions" + ], + "supports_reasoning": true + }, + "github_copilot/gemini-3-flash-preview": { + "litellm_provider": "github_copilot", + "max_input_tokens": 128000, + "max_output_tokens": 64000, + "max_tokens": 64000, + "mode": "chat", + "supported_endpoints": [ + "/v1/chat/completions" + ], + "supports_vision": true, + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_reasoning": true }, "github_copilot/gemini-3-pro-preview": { "litellm_provider": "github_copilot", @@ -18014,13 +18067,30 @@ "supports_parallel_function_calling": true, "supports_vision": true }, + "github_copilot/gemini-3.1-pro-preview": { + "litellm_provider": "github_copilot", + "max_input_tokens": 128000, + "max_output_tokens": 64000, + "max_tokens": 64000, + "mode": "chat", + "supported_endpoints": [ + "/v1/chat/completions" + ], + "supports_vision": true, + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_reasoning": true + }, "github_copilot/gpt-3.5-turbo": { "litellm_provider": "github_copilot", "max_input_tokens": 16384, "max_output_tokens": 4096, "max_tokens": 4096, "mode": "chat", - "supports_function_calling": true + "supports_function_calling": true, + "supported_endpoints": [ + "/v1/chat/completions" + ] }, "github_copilot/gpt-3.5-turbo-0613": { "litellm_provider": "github_copilot", @@ -18028,7 +18098,10 @@ "max_output_tokens": 4096, "max_tokens": 4096, "mode": "chat", - "supports_function_calling": true + "supports_function_calling": true, + "supported_endpoints": [ + "/v1/chat/completions" + ] }, "github_copilot/gpt-4": { "litellm_provider": "github_copilot", @@ -18036,7 +18109,22 @@ "max_output_tokens": 4096, "max_tokens": 4096, "mode": "chat", - "supports_function_calling": true + "supports_function_calling": true, + "supported_endpoints": [ + "/v1/chat/completions" + ] + }, + "github_copilot/gpt-4-0125-preview": { + "litellm_provider": "github_copilot", + "max_input_tokens": 128000, + "max_output_tokens": 4096, + "max_tokens": 4096, + "mode": "chat", + "supported_endpoints": [ + "/v1/chat/completions" + ], + "supports_function_calling": true, + "supports_parallel_function_calling": true }, "github_copilot/gpt-4-0613": { "litellm_provider": "github_copilot", @@ -18044,16 +18132,22 @@ "max_output_tokens": 4096, "max_tokens": 4096, "mode": "chat", - "supports_function_calling": true + "supports_function_calling": true, + "supported_endpoints": [ + "/v1/chat/completions" + ] }, "github_copilot/gpt-4-o-preview": { "litellm_provider": "github_copilot", - "max_input_tokens": 64000, + "max_input_tokens": 128000, "max_output_tokens": 4096, "max_tokens": 4096, "mode": "chat", "supports_function_calling": true, - "supports_parallel_function_calling": true + "supports_parallel_function_calling": true, + "supported_endpoints": [ + "/v1/chat/completions" + ] }, "github_copilot/gpt-4.1": { "litellm_provider": "github_copilot", @@ -18064,7 +18158,10 @@ "supports_function_calling": true, "supports_parallel_function_calling": true, "supports_response_schema": true, - "supports_vision": true + "supports_vision": true, + "supported_endpoints": [ + "/v1/chat/completions" + ] }, "github_copilot/gpt-4.1-2025-04-14": { "litellm_provider": "github_copilot", @@ -18075,68 +18172,89 @@ "supports_function_calling": true, "supports_parallel_function_calling": true, "supports_response_schema": true, - "supports_vision": true + "supports_vision": true, + "supported_endpoints": [ + "/v1/chat/completions" + ] }, "github_copilot/gpt-41-copilot": { "litellm_provider": "github_copilot", - "mode": "completion" + "mode": "chat" }, "github_copilot/gpt-4o": { "litellm_provider": "github_copilot", - "max_input_tokens": 64000, + "max_input_tokens": 128000, "max_output_tokens": 4096, "max_tokens": 4096, "mode": "chat", "supports_function_calling": true, "supports_parallel_function_calling": true, - "supports_vision": true + "supports_vision": true, + "supported_endpoints": [ + "/v1/chat/completions" + ] }, "github_copilot/gpt-4o-2024-05-13": { "litellm_provider": "github_copilot", - "max_input_tokens": 64000, + "max_input_tokens": 128000, "max_output_tokens": 4096, "max_tokens": 4096, "mode": "chat", "supports_function_calling": true, "supports_parallel_function_calling": true, - "supports_vision": true + "supports_vision": true, + "supported_endpoints": [ + "/v1/chat/completions" + ] }, "github_copilot/gpt-4o-2024-08-06": { "litellm_provider": "github_copilot", - "max_input_tokens": 64000, - "max_output_tokens": 16384, - "max_tokens": 16384, - "mode": "chat", - "supports_function_calling": true, - "supports_parallel_function_calling": true - }, - "github_copilot/gpt-4o-2024-11-20": { - "litellm_provider": "github_copilot", - "max_input_tokens": 64000, + "max_input_tokens": 128000, "max_output_tokens": 16384, "max_tokens": 16384, "mode": "chat", "supports_function_calling": true, "supports_parallel_function_calling": true, - "supports_vision": true + "supported_endpoints": [ + "/v1/chat/completions" + ] + }, + "github_copilot/gpt-4o-2024-11-20": { + "litellm_provider": "github_copilot", + "max_input_tokens": 128000, + "max_output_tokens": 16384, + "max_tokens": 16384, + "mode": "chat", + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_vision": true, + "supported_endpoints": [ + "/v1/chat/completions" + ] }, "github_copilot/gpt-4o-mini": { "litellm_provider": "github_copilot", - "max_input_tokens": 64000, + "max_input_tokens": 128000, "max_output_tokens": 4096, "max_tokens": 4096, "mode": "chat", "supports_function_calling": true, - "supports_parallel_function_calling": true + "supports_parallel_function_calling": true, + "supported_endpoints": [ + "/v1/chat/completions" + ] }, "github_copilot/gpt-4o-mini-2024-07-18": { "litellm_provider": "github_copilot", - "max_input_tokens": 64000, + "max_input_tokens": 128000, "max_output_tokens": 4096, "max_tokens": 4096, "mode": "chat", "supports_function_calling": true, - "supports_parallel_function_calling": true + "supports_parallel_function_calling": true, + "supported_endpoints": [ + "/v1/chat/completions" + ] }, "github_copilot/gpt-5": { "litellm_provider": "github_copilot", @@ -18155,14 +18273,19 @@ }, "github_copilot/gpt-5-mini": { "litellm_provider": "github_copilot", - "max_input_tokens": 128000, + "max_input_tokens": 264000, "max_output_tokens": 64000, "max_tokens": 64000, "mode": "chat", "supports_function_calling": true, "supports_parallel_function_calling": true, "supports_response_schema": true, - "supports_vision": true + "supports_vision": true, + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/responses" + ], + "supports_reasoning": true }, "github_copilot/gpt-5.1": { "litellm_provider": "github_copilot", @@ -18195,7 +18318,7 @@ }, "github_copilot/gpt-5.2": { "litellm_provider": "github_copilot", - "max_input_tokens": 128000, + "max_input_tokens": 264000, "max_output_tokens": 64000, "max_tokens": 64000, "mode": "chat", @@ -18206,11 +18329,27 @@ "supports_function_calling": true, "supports_parallel_function_calling": true, "supports_response_schema": true, - "supports_vision": true + "supports_vision": true, + "supports_reasoning": true + }, + "github_copilot/gpt-5.2-codex": { + "litellm_provider": "github_copilot", + "max_input_tokens": 400000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "responses", + "supported_endpoints": [ + "/v1/responses" + ], + "supports_vision": true, + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_response_schema": true, + "supports_reasoning": true }, "github_copilot/gpt-5.3-codex": { "litellm_provider": "github_copilot", - "max_input_tokens": 128000, + "max_input_tokens": 400000, "max_output_tokens": 128000, "max_tokens": 128000, "mode": "responses", @@ -18220,25 +18359,96 @@ "supports_function_calling": true, "supports_parallel_function_calling": true, "supports_response_schema": true, - "supports_vision": true + "supports_vision": true, + "supports_reasoning": true + }, + "github_copilot/gpt-5.4": { + "litellm_provider": "github_copilot", + "max_input_tokens": 400000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "chat", + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/responses" + ], + "supports_vision": true, + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_response_schema": true, + "supports_reasoning": true + }, + "github_copilot/gpt-5.4-mini": { + "litellm_provider": "github_copilot", + "max_input_tokens": 400000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "responses", + "supported_endpoints": [ + "/v1/responses" + ], + "supports_vision": true, + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_response_schema": true, + "supports_reasoning": true + }, + "github_copilot/gpt-5.5": { + "litellm_provider": "github_copilot", + "max_input_tokens": 400000, + "max_output_tokens": 128000, + "max_tokens": 128000, + "mode": "responses", + "supported_endpoints": [ + "/v1/responses" + ], + "supports_vision": true, + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_response_schema": true, + "supports_reasoning": true + }, + "github_copilot/oswe-vscode-prime": { + "litellm_provider": "github_copilot", + "max_input_tokens": 264000, + "max_output_tokens": 64000, + "max_tokens": 64000, + "mode": "chat", + "supported_endpoints": [ + "/v1/chat/completions", + "/v1/responses" + ], + "supports_vision": true, + "supports_function_calling": true, + "supports_parallel_function_calling": true, + "supports_response_schema": true }, "github_copilot/text-embedding-3-small": { "litellm_provider": "github_copilot", "max_input_tokens": 8191, "max_tokens": 8191, - "mode": "embedding" + "mode": "embedding", + "supported_endpoints": [ + "/v1/embeddings" + ] }, "github_copilot/text-embedding-3-small-inference": { "litellm_provider": "github_copilot", "max_input_tokens": 8191, "max_tokens": 8191, - "mode": "embedding" + "mode": "embedding", + "supported_endpoints": [ + "/v1/embeddings" + ] }, "github_copilot/text-embedding-ada-002": { "litellm_provider": "github_copilot", "max_input_tokens": 8191, "max_tokens": 8191, - "mode": "embedding" + "mode": "embedding", + "supported_endpoints": [ + "/v1/embeddings" + ] }, "chatgpt/gpt-5.4": { "litellm_provider": "chatgpt", diff --git a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py index 652e284ed4..8324ba641a 100644 --- a/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py +++ b/litellm/proxy/_experimental/mcp_server/discoverable_endpoints.py @@ -1,3 +1,4 @@ +import html as _html import json from typing import Any, Dict, Optional from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse @@ -618,8 +619,105 @@ async def token_endpoint( ) +# Per RFC 6749 §4.1.2.1, an IdP that rejects an OAuth authorization request +# redirects back to the configured redirect URI with ``error`` / +# ``error_description`` / ``error_uri`` query params and no ``code``. The MCP +# loopback flow funnels that response through this /callback endpoint, so +# the endpoint must accept either a successful (``code``+``state``) or an +# error response. Declaring ``code``/``state`` as required would cause +# FastAPI to reject the error response with a 422 before the handler runs, +# which strands the MCP client waiting on the loopback (see LIT-2750). + + +def _render_oauth_error_html(error: str, description: Optional[str]) -> HTMLResponse: + """Render an actionable HTML page for an IdP-reported OAuth error. + + Used when we cannot propagate the error back to the registered + ``redirect_uri`` (state missing or undecryptable). Returned with a 400 + status so the failure is observable to operators while still being a + human-readable page for the end user. + """ + safe_error = _html.escape(error or "unknown_error") + safe_description = _html.escape(description) if description else "" + description_html = f"
{safe_description}
" if safe_description else "" + body = ( + "" + "Error: {safe_error}
" + f"{description_html}" + "You can close this window and try again.
" + "" + ) + return HTMLResponse(body, status_code=400) + + @router.get("/callback") -async def callback(request: Request, code: str, state: str): +async def callback( + request: Request, + code: Optional[str] = None, + state: Optional[str] = None, + error: Optional[str] = None, + error_description: Optional[str] = None, + error_uri: Optional[str] = None, +): + """OAuth 2.0 authorization response handler for MCP loopback clients. + + Accepts either: + + - A successful authorization response (``code`` + ``state``), which is + forwarded back to the validated client ``redirect_uri`` with the + original (un-wrapped) ``state``. + - An error response (``error``[+``error_description``/``error_uri``]), per + RFC 6749 §4.1.2.1. When ``state`` is present and decodes to a trusted + ``redirect_uri``, the error params are propagated back to the client so + its OAuth library can surface them. Otherwise we render an HTML error + page so the user is not left on an opaque 422 / blank screen. + """ + # 1. IdP-reported error path (e.g. ``?error=access_denied``). + if error: + verbose_logger.info( + "MCP /callback received IdP error: error=%s, error_description=%s", + error, + error_description, + ) + if state: + try: + state_data = decode_state_hash(state) + original_state = state_data.get("original_state") + redirect_uri = _get_validated_client_redirect_uri(request, state_data) + except HTTPException: + # Untrusted/invalid client redirect_uri — surface inline rather + # than blindly forwarding the error to an attacker-controlled URL. + return _render_oauth_error_html(error, error_description) + except Exception: + # State could not be decrypted (expired key, tampered, etc.). + return _render_oauth_error_html(error, error_description) + + params: Dict[str, str] = {"error": error} + if error_description: + params["error_description"] = error_description + if error_uri: + params["error_uri"] = error_uri + if original_state is not None: + params["state"] = original_state + complete_returned_url = _append_query_params(redirect_uri, params) + return RedirectResponse(url=complete_returned_url, status_code=302) + + # No state — nothing to round-trip to. Show the user the error. + return _render_oauth_error_html(error, error_description) + + # 2. Neither success nor error parameters present — most likely a stray + # GET / dropped SSO redirect chain. Surface a 400 instead of 422. + if not code or not state: + missing = [ + name for name, value in (("code", code), ("state", state)) if not value + ] + return _render_oauth_error_html( + "invalid_request", + f"Missing authorization {' and '.join(repr(m) for m in missing)} parameter(s).", + ) + + # 3. Successful authorization response. try: state_data = decode_state_hash(state) original_state = state_data["original_state"] diff --git a/litellm/proxy/auth/auth_utils.py b/litellm/proxy/auth/auth_utils.py index 1e87dcaef1..8626527035 100644 --- a/litellm/proxy/auth/auth_utils.py +++ b/litellm/proxy/auth/auth_utils.py @@ -213,6 +213,12 @@ _EXTRA_BANNED_OBSERVABILITY_PARAMS: FrozenSet[str] = frozenset( { "posthog_api_url", "phoenix_project_name", + "phoenix_project_name_override", + # Server-reserved: written exclusively by add_user_api_key_auth_to_request_metadata + # from the authenticated key's database record. A caller-supplied value + # would survive the server merge and let an authenticated user redirect + # their Arize/Phoenix telemetry into arbitrary projects. + "user_api_key_auth_metadata", "wandb_api_key", "weave_project_id", } diff --git a/litellm/proxy/example_config_yaml/oai_misc_config.yaml b/litellm/proxy/example_config_yaml/oai_misc_config.yaml index 0b647de8a0..551043ec76 100644 --- a/litellm/proxy/example_config_yaml/oai_misc_config.yaml +++ b/litellm/proxy/example_config_yaml/oai_misc_config.yaml @@ -23,11 +23,11 @@ model_list: model: bedrock/us.anthropic.claude-haiku-4-5-20251001-v1:0 ######################################################### ########## batch specific params ######################## - s3_bucket_name: litellm-proxy-941277531214 + s3_bucket_name: litellm-proxy-123456789012 s3_region_name: us-west-2 s3_access_key_id: os.environ/AWS_ACCESS_KEY_ID s3_secret_access_key: os.environ/AWS_SECRET_ACCESS_KEY - aws_batch_role_arn: arn:aws:iam::941277531214:role/service-role/AmazonBedrockExecutionRoleForAgents_BB9HNW6V4CV + aws_batch_role_arn: arn:aws:iam::123456789012:role/service-role/AmazonBedrockExecutionRoleForAgents_EXAMPLE model_info: mode: batch diff --git a/litellm/proxy/management_helpers/utils.py b/litellm/proxy/management_helpers/utils.py index 495bce2f00..549d9e469f 100644 --- a/litellm/proxy/management_helpers/utils.py +++ b/litellm/proxy/management_helpers/utils.py @@ -471,10 +471,32 @@ async def _emit_management_endpoint_otel_span( route = func.__name__ request_body = {} + _CREDENTIAL_FIELDS = frozenset( + { + "key", + "token", + "api_key", + "secret", + "password", + "access_token", + "refresh_token", + "private_key", + "service_account_key", + } + ) + + _response: Optional[dict] = None + if exception is None and result is not None: + try: + raw = dict(result) + _response = {k: v for k, v in raw.items() if k not in _CREDENTIAL_FIELDS} + except Exception: + _response = None + logging_payload = ManagementEndpointLoggingPayload( route=route, request_data=request_body, - response=None, + response=_response, start_time=start_time, end_time=end_time, exception=exception, diff --git a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py index 00eaba09ac..36a389d233 100644 --- a/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py +++ b/litellm/proxy/pass_through_endpoints/pass_through_endpoints.py @@ -908,7 +908,7 @@ async def pass_through_request( # noqa: PLR0915 else {"json": _parsed_body} ) req = async_client.build_request( - "POST", + request.method, url, params=requested_query_params, headers=headers, diff --git a/tests/test_litellm/integrations/arize/test_arize_phoenix.py b/tests/test_litellm/integrations/arize/test_arize_phoenix.py index 4a2eab29e8..afd83f81ce 100644 --- a/tests/test_litellm/integrations/arize/test_arize_phoenix.py +++ b/tests/test_litellm/integrations/arize/test_arize_phoenix.py @@ -7,7 +7,6 @@ from litellm.integrations.arize.arize_phoenix import ( ArizePhoenixConfig, ArizePhoenixLogger, ) -from litellm.integrations.arize._utils import ArizeOTELAttributes class TestArizePhoenixConfig(unittest.TestCase): @@ -217,44 +216,147 @@ def test_get_arize_phoenix_config_expection_on_missing_api_key(monkeypatch, env_ # --------------------------------------------------------------------------- -# Dynamic project naming from metadata +# Per-project routing via Resource (not span attributes) # --------------------------------------------------------------------------- -class TestGetDynamicProjectName: - """Tests for _get_dynamic_project_name extraction logic.""" +class TestResolveProjectName: + """Tests for _resolve_project_name priority chain.""" - def test_extracts_from_standard_logging_object_metadata(self): + def test_extracts_phoenix_name_from_standard_logging_object_metadata(self): kwargs = { "standard_logging_object": { "metadata": {"phoenix_project_name": "my-project"}, } } - assert ArizePhoenixLogger._get_dynamic_project_name(kwargs) == "my-project" + assert ArizePhoenixLogger._resolve_project_name(kwargs) == "my-project" - def test_extracts_from_litellm_params_metadata(self): + def test_extracts_phoenix_name_from_litellm_params_metadata(self): kwargs = { "litellm_params": { "metadata": {"phoenix_project_name": "sdk-project"}, } } - assert ArizePhoenixLogger._get_dynamic_project_name(kwargs) == "sdk-project" + assert ArizePhoenixLogger._resolve_project_name(kwargs) == "sdk-project" - def test_returns_none_when_no_metadata(self): - assert ArizePhoenixLogger._get_dynamic_project_name({}) is None + @patch.dict("os.environ", {"PHOENIX_PROJECT_NAME": "env-project"}, clear=False) + def test_falls_back_to_phoenix_env_when_no_metadata(self): + assert ArizePhoenixLogger._resolve_project_name({}) == "env-project" + + @patch.dict( + "os.environ", + {"ARIZE_PROJECT_NAME": "arize-env", "PHOENIX_PROJECT_NAME": ""}, + clear=False, + ) + def test_falls_back_to_arize_env_when_phoenix_unset(self): + assert ArizePhoenixLogger._resolve_project_name({}) == "arize-env" + + @patch.dict("os.environ", {}, clear=True) + def test_falls_back_to_default_when_no_metadata_or_env(self): + assert ArizePhoenixLogger._resolve_project_name({}) == "default" + + def test_phoenix_override_beats_phoenix_metadata(self): + kwargs = { + "standard_logging_object": { + "metadata": { + "phoenix_project_name_override": "override-proj", + "phoenix_project_name": "phoenix-proj", + }, + } + } + assert ArizePhoenixLogger._resolve_project_name(kwargs) == "override-proj" + + def test_whitespace_only_metadata_falls_through_to_default(self): + kwargs = { + "standard_logging_object": { + "metadata": {"phoenix_project_name_override": " "}, + } + } + with patch.dict("os.environ", {}, clear=True): + assert ArizePhoenixLogger._resolve_project_name(kwargs) == "default" + + def test_strips_whitespace_from_project_name(self): + kwargs = { + "standard_logging_object": { + "metadata": {"phoenix_project_name": " trimmed "}, + } + } + assert ArizePhoenixLogger._resolve_project_name(kwargs) == "trimmed" def test_non_dict_standard_logging_object_does_not_raise(self): - """isinstance(dict) guard prevents AttributeError on non-dict payloads.""" kwargs = {"standard_logging_object": "not-a-dict"} - assert ArizePhoenixLogger._get_dynamic_project_name(kwargs) is None + with patch.dict("os.environ", {}, clear=True): + assert ArizePhoenixLogger._resolve_project_name(kwargs) == "default" + + def test_resolves_override_from_user_api_key_auth_metadata(self): + kwargs = { + "litellm_params": { + "metadata": { + "user_api_key_auth_metadata": { + "phoenix_project_name_override": "claude-code", + }, + }, + }, + } + with patch.dict("os.environ", {}, clear=True): + assert ArizePhoenixLogger._resolve_project_name(kwargs) == "claude-code" + + def test_resolves_phoenix_name_from_user_api_key_auth_metadata(self): + kwargs = { + "standard_logging_object": { + "metadata": { + "user_api_key_auth_metadata": { + "phoenix_project_name": "team-project", + }, + }, + }, + } + with patch.dict("os.environ", {}, clear=True): + assert ArizePhoenixLogger._resolve_project_name(kwargs) == "team-project" + + def test_proxy_ignores_client_metadata_when_auth_metadata_set(self): + kwargs = { + "litellm_params": { + "proxy_server_request": { + "url": "/v1/chat/completions", + "method": "POST", + "headers": {}, + }, + "metadata": { + "phoenix_project_name_override": "attacker-project", + "user_api_key_auth_metadata": { + "phoenix_project_name_override": "team-project", + }, + }, + }, + } + with patch.dict("os.environ", {}, clear=True): + assert ArizePhoenixLogger._resolve_project_name(kwargs) == "team-project" + + def test_proxy_without_auth_metadata_falls_back_to_env(self): + kwargs = { + "litellm_params": { + "proxy_server_request": { + "url": "/v1/chat/completions", + "method": "POST", + "headers": {}, + }, + "metadata": {"phoenix_project_name": "attacker-project"}, + }, + } + with patch.dict( + "os.environ", {"PHOENIX_PROJECT_NAME": "env-project"}, clear=True + ): + assert ArizePhoenixLogger._resolve_project_name(kwargs) == "env-project" -class TestDynamicProjectNameOnSpan: - """set_arize_phoenix_attributes sets openinference.project.name on the span.""" +class TestProjectNameNotOnSpan: + """Project routing uses Resource on TracerProvider, not span attributes.""" - @patch.dict("os.environ", {"PHOENIX_PROJECT_NAME": "env-fallback"}, clear=False) @patch("litellm.integrations.arize._utils.set_attributes") - def test_dynamic_name_sets_span_attribute(self, _mock_set_attrs): + def test_set_arize_phoenix_attributes_does_not_set_project_on_span( + self, _mock_set_attrs + ): span = MagicMock() kwargs = { "standard_logging_object": { @@ -263,20 +365,468 @@ class TestDynamicProjectNameOnSpan: } ArizePhoenixLogger.set_arize_phoenix_attributes(span, kwargs, response_obj=None) - span.set_attribute.assert_called_once_with( - "openinference.project.name", "dynamic-proj" + for call in span.set_attribute.call_args_list: + assert call[0][0] != "openinference.project.name" + + +class TestPerProjectTracerProviderCache: + """Spans for different projects use different Resources on export.""" + + def test_different_metadata_routes_to_different_resource(self): + from datetime import datetime + + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, ) - @patch.dict("os.environ", {"PHOENIX_PROJECT_NAME": "env-project"}, clear=False) - @patch("litellm.integrations.arize._utils.set_attributes") - def test_falls_back_to_env_var_when_no_dynamic_name(self, _mock_set_attrs): - span = MagicMock() - ArizePhoenixLogger.set_arize_phoenix_attributes(span, {}, response_obj=None) + from litellm.integrations.opentelemetry import OpenTelemetryConfig - span.set_attribute.assert_called_once_with( - "openinference.project.name", "env-project" + exporter = InMemorySpanExporter() + logger = ArizePhoenixLogger( + config=OpenTelemetryConfig(exporter=exporter), + callback_name="arize_phoenix", ) + start = datetime(2024, 1, 1, 12, 0, 0) + end = datetime(2024, 1, 1, 12, 0, 1) + + logger._handle_success( + { + "standard_logging_object": { + "metadata": {"phoenix_project_name": "project-a"}, + }, + }, + response_obj={}, + start_time=start, + end_time=end, + ) + logger._handle_success( + { + "standard_logging_object": { + "metadata": {"phoenix_project_name": "project-b"}, + }, + }, + response_obj={}, + start_time=start, + end_time=end, + ) + + spans = exporter.get_finished_spans() + project_names = { + s.resource.attributes.get("openinference.project.name") for s in spans + } + assert "project-a" in project_names + assert "project-b" in project_names + + def test_shared_span_processor_created_once_at_init(self): + from litellm.integrations.opentelemetry import ( + OpenTelemetry, + OpenTelemetryConfig, + ) + + mock_processor = MagicMock() + with patch.object( + OpenTelemetry, "_get_span_processor", return_value=mock_processor + ) as mock_get_processor: + logger = ArizePhoenixLogger( + config=OpenTelemetryConfig(exporter=MagicMock()), + callback_name="arize_phoenix", + ) + assert mock_get_processor.call_count == 1 + assert logger._shared_span_processor is mock_processor + + logger._project_providers.clear() + logger._get_tracer_for("project-a") + logger._get_tracer_for("project-b") + assert mock_get_processor.call_count == 1 + + def test_lru_eviction_does_not_shutdown_provider(self): + from litellm.integrations.opentelemetry import OpenTelemetryConfig + + logger = ArizePhoenixLogger( + config=OpenTelemetryConfig(exporter=MagicMock()), + callback_name="arize_phoenix", + ) + logger._project_providers.clear() + + logger._get_tracer_for("project-0") + evicted_provider = logger._project_providers["project-0"] + shutdown_mock = MagicMock() + evicted_provider.shutdown = shutdown_mock # type: ignore[method-assign] + + for i in range(1, 65): + logger._get_tracer_for(f"project-{i}") + + assert len(logger._project_providers) == 64 + assert "project-0" not in logger._project_providers + assert "project-64" in logger._project_providers + shutdown_mock.assert_not_called() + + def test_flush_tracer_providers_force_flushes_shared_processor(self): + from litellm.integrations.opentelemetry import OpenTelemetryConfig + + logger = ArizePhoenixLogger( + config=OpenTelemetryConfig(exporter=MagicMock()), + callback_name="arize_phoenix", + ) + mock_processor = MagicMock() + logger._shared_span_processor = mock_processor + mock_provider = MagicMock() + logger._project_providers["proj"] = mock_provider + + logger.flush_tracer_providers() + + mock_processor.force_flush.assert_called_once() + mock_provider.force_flush.assert_called_once() + + +class TestGetLitellmResourceForProject: + """Resource attrs used by Phoenix OSS and Arize AX for project routing.""" + + def test_project_attrs_win_over_otel_resource_attributes_env(self): + from litellm.integrations.opentelemetry import OpenTelemetryConfig + + logger = ArizePhoenixLogger( + config=OpenTelemetryConfig(exporter=MagicMock()), + callback_name="arize_phoenix", + ) + + with patch.dict( + "os.environ", + { + "OTEL_RESOURCE_ATTRIBUTES": "openinference.project.name=env-pinned,model_id=env-model" + }, + clear=False, + ): + resource = logger._get_litellm_resource_for_project("dynamic-proj") + + assert resource.attributes["openinference.project.name"] == "dynamic-proj" + assert resource.attributes["model_id"] == "dynamic-proj" + assert resource.attributes["service.name"] == "dynamic-proj" + + @patch.dict("os.environ", {"OTEL_DEPLOYMENT_ENVIRONMENT": "staging"}, clear=False) + def test_preserves_deployment_environment_from_config(self): + from litellm.integrations.opentelemetry import OpenTelemetryConfig + + logger = ArizePhoenixLogger( + config=OpenTelemetryConfig( + exporter=MagicMock(), deployment_environment="staging" + ), + callback_name="arize_phoenix", + ) + resource = logger._get_litellm_resource_for_project("my-proj") + assert resource.attributes.get("deployment.environment") == "staging" + + +class TestTracerResolutionAndCache: + """_resolve_tracer_for_kwargs, get_tracer_to_use_for_request, provider cache.""" + + def test_get_tracer_to_use_for_request_matches_resolve_tracer(self): + from litellm.integrations.opentelemetry import OpenTelemetryConfig + + logger = ArizePhoenixLogger( + config=OpenTelemetryConfig(exporter=MagicMock()), + callback_name="arize_phoenix", + ) + kwargs = { + "standard_logging_object": { + "metadata": {"phoenix_project_name": "same-proj"}, + } + } + project_name, _ = logger._resolve_tracer_for_kwargs(kwargs) + tracer_from_request = logger.get_tracer_to_use_for_request(kwargs) + assert project_name == "same-proj" + assert "same-proj" in logger._project_providers + assert logger._resolve_project_name(kwargs) == project_name + assert tracer_from_request is not None + + def test_cache_reuses_provider_for_same_project(self): + from litellm.integrations.opentelemetry import OpenTelemetryConfig + + logger = ArizePhoenixLogger( + config=OpenTelemetryConfig(exporter=MagicMock()), + callback_name="arize_phoenix", + ) + logger._project_providers.clear() + + logger._get_tracer_for("cached-proj") + provider_first = logger._project_providers["cached-proj"] + + logger._get_tracer_for("cached-proj") + provider_second = logger._project_providers["cached-proj"] + + assert provider_first is provider_second + assert len(logger._project_providers) == 1 + + def test_parallel_cache_miss_for_same_project_inserts_once(self): + import threading + + from litellm.integrations.opentelemetry import OpenTelemetryConfig + + logger = ArizePhoenixLogger( + config=OpenTelemetryConfig(exporter=MagicMock()), + callback_name="arize_phoenix", + ) + logger._project_providers.clear() + + build_calls: list[str] = [] + real_build = logger._build_tracer_provider_for_project + + def tracking_build(project_name: str): + build_calls.append(project_name) + return real_build(project_name) + + barrier = threading.Barrier(10) + errors: list[Exception] = [] + + def worker() -> None: + try: + barrier.wait() + logger._get_tracer_for("race-proj") + except Exception as exc: + errors.append(exc) + + with patch.object( + logger, + "_build_tracer_provider_for_project", + side_effect=tracking_build, + ): + threads = [threading.Thread(target=worker) for _ in range(10)] + for thread in threads: + thread.start() + for thread in threads: + thread.join() + + assert not errors + assert len(logger._project_providers) == 1 + assert "race-proj" in logger._project_providers + assert len(build_calls) >= 1 + + def test_injected_tracer_provider_bypasses_project_cache(self): + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) + + from litellm.integrations.opentelemetry import OpenTelemetryConfig + + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + + logger = ArizePhoenixLogger( + config=OpenTelemetryConfig(exporter=exporter), + callback_name="arize_phoenix", + tracer_provider=provider, + ) + + assert getattr(logger, "_use_injected_tracer_provider", False) is True + assert not hasattr(logger, "_project_providers") or not getattr( + logger, "_project_providers", None + ) + + tracer_a = logger._get_tracer_for("any-project") + tracer_b = logger.get_tracer_to_use_for_request( + {"standard_logging_object": {"metadata": {"phoenix_project_name": "x"}}} + ) + assert tracer_a is logger.tracer + assert tracer_b is logger.tracer + + def test_flush_tracer_providers_noop_for_injected_provider(self): + from opentelemetry.sdk.trace import TracerProvider + from opentelemetry.sdk.trace.export import SimpleSpanProcessor + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) + + from litellm.integrations.opentelemetry import OpenTelemetryConfig + + exporter = InMemorySpanExporter() + provider = TracerProvider() + provider.add_span_processor(SimpleSpanProcessor(exporter)) + + logger = ArizePhoenixLogger( + config=OpenTelemetryConfig(exporter=exporter), + callback_name="arize_phoenix", + tracer_provider=provider, + ) + logger.flush_tracer_providers() + exporter.shutdown() + + def test_standard_logging_metadata_wins_over_litellm_params(self): + kwargs = { + "standard_logging_object": { + "metadata": {"phoenix_project_name_override": "from-logging"}, + }, + "litellm_params": { + "metadata": {"phoenix_project_name_override": "from-params"}, + }, + } + assert ArizePhoenixLogger._resolve_project_name(kwargs) == "from-logging" + + +class TestPhoenixTraceHandling: + """_handle_success / _handle_failure span export behavior.""" + + def test_handle_failure_sets_error_status_on_request_span(self): + from datetime import datetime + + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) + from opentelemetry.trace import StatusCode + + from litellm.integrations.opentelemetry import ( + LITELLM_REQUEST_SPAN_NAME, + OpenTelemetryConfig, + ) + + exporter = InMemorySpanExporter() + logger = ArizePhoenixLogger( + config=OpenTelemetryConfig(exporter=exporter), + callback_name="arize_phoenix", + ) + + start = datetime(2024, 1, 1, 12, 0, 0) + end = datetime(2024, 1, 1, 12, 0, 1) + + logger._handle_failure( + { + "standard_logging_object": { + "metadata": {"phoenix_project_name": "fail-proj"}, + }, + "exception": Exception("boom"), + }, + response_obj=None, + start_time=start, + end_time=end, + ) + + spans = exporter.get_finished_spans() + request_spans = [s for s in spans if s.name == LITELLM_REQUEST_SPAN_NAME] + assert len(request_spans) == 1 + assert request_spans[0].status.status_code == StatusCode.ERROR + assert ( + request_spans[0].resource.attributes.get("openinference.project.name") + == "fail-proj" + ) + + def test_proxy_mode_parent_and_child_share_trace_id(self): + from datetime import datetime + + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) + + from litellm.integrations.opentelemetry import ( + LITELLM_REQUEST_SPAN_NAME, + OpenTelemetryConfig, + ) + + exporter = InMemorySpanExporter() + logger = ArizePhoenixLogger( + config=OpenTelemetryConfig(exporter=exporter), + callback_name="arize_phoenix", + ) + + start = datetime(2024, 1, 1, 12, 0, 0) + end = datetime(2024, 1, 1, 12, 0, 1) + + logger._handle_success( + { + "litellm_params": { + "proxy_server_request": { + "url": "/chat/completions", + "method": "POST", + "headers": {}, + }, + "metadata": { + "user_api_key_auth_metadata": { + "phoenix_project_name_override": "proxy-proj", + }, + }, + }, + }, + response_obj={}, + start_time=start, + end_time=end, + ) + + spans = exporter.get_finished_spans() + span_names = {s.name for s in spans} + assert "litellm_proxy_request" in span_names + assert LITELLM_REQUEST_SPAN_NAME in span_names + + trace_ids = {s.context.trace_id for s in spans} + assert len(trace_ids) == 1 + for span in spans: + assert ( + span.resource.attributes.get("openinference.project.name") + == "proxy-proj" + ) + + def test_override_routes_all_spans_to_one_project_in_single_request(self): + from datetime import datetime + + from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, + ) + + from litellm.integrations.opentelemetry import OpenTelemetryConfig + + exporter = InMemorySpanExporter() + logger = ArizePhoenixLogger( + config=OpenTelemetryConfig(exporter=exporter), + callback_name="arize_phoenix", + ) + + start = datetime(2024, 1, 1, 12, 0, 0) + end = datetime(2024, 1, 1, 12, 0, 1) + + logger._handle_success( + { + "standard_logging_object": { + "metadata": { + "user_api_key_auth_metadata": { + "phoenix_project_name_override": "unified-proj", + }, + }, + }, + "litellm_params": { + "proxy_server_request": { + "url": "/v1/chat/completions", + "method": "POST", + "headers": {}, + }, + }, + }, + response_obj={"id": "resp-1"}, + start_time=start, + end_time=end, + ) + + for span in exporter.get_finished_spans(): + assert ( + span.resource.attributes.get("openinference.project.name") + == "unified-proj" + ) + assert span.resource.attributes.get("model_id") == "unified-proj" + + +class TestGetArizePhoenixConfigProjectName: + @patch.dict( + "os.environ", {"PHOENIX_PROJECT_NAME": "phoenix-config-proj"}, clear=True + ) + def test_project_name_from_phoenix_env(self): + config = ArizePhoenixLogger.get_arize_phoenix_config() + assert config.project_name == "phoenix-config-proj" + + @patch.dict("os.environ", {}, clear=True) + def test_project_name_defaults_when_env_unset(self): + config = ArizePhoenixLogger.get_arize_phoenix_config() + assert config.project_name == "default" + if __name__ == "__main__": unittest.main() diff --git a/tests/test_litellm/integrations/datadog/test_datadog_metrics.py b/tests/test_litellm/integrations/datadog/test_datadog_metrics.py index 757c558c29..2a26b7fade 100644 --- a/tests/test_litellm/integrations/datadog/test_datadog_metrics.py +++ b/tests/test_litellm/integrations/datadog/test_datadog_metrics.py @@ -104,6 +104,7 @@ async def test_add_metrics_from_log(clean_env): logger._add_metrics_from_log(log=payload, kwargs=kwargs, status_code="200") # Should have 3 series: total_latency, llm_api_latency, request_count + # (no overhead metric because payload has no hidden_params litellm_overhead_time_ms) assert len(logger.log_queue) == 3 metrics = {s["metric"]: s for s in logger.log_queue} @@ -125,6 +126,72 @@ async def test_add_metrics_from_log(clean_env): assert "status_code:200" in count["tags"] +@pytest.mark.asyncio +async def test_overhead_latency_metric_emitted(clean_env): + """Test that litellm.overhead.latency is emitted when hidden_params contains litellm_overhead_time_ms.""" + logger = DatadogMetricsLogger(batch_size=100, start_periodic_flush=False) + + now = datetime.now() + start_time = now - timedelta(seconds=2) + api_call_start_time = now - timedelta(seconds=1) + + payload = StandardLoggingPayload( + custom_llm_provider="openai", + model="gpt-4o", + hidden_params={ + "litellm_overhead_time_ms": 250.0, # 250 ms of overhead + }, + ) + + kwargs = { + "start_time": start_time, + "api_call_start_time": api_call_start_time, + "end_time": now, + } + + logger._add_metrics_from_log(log=payload, kwargs=kwargs, status_code="200") + + metrics = {s["metric"]: s for s in logger.log_queue} + + # Overhead metric must be present + assert ( + "litellm.overhead.latency" in metrics + ), f"Expected 'litellm.overhead.latency' in emitted metrics, got: {list(metrics.keys())}" + overhead = metrics["litellm.overhead.latency"] + assert overhead["type"] == 3 # gauge + # 250 ms → 0.25 s + assert abs(overhead["points"][0]["value"] - 0.25) < 1e-6 + # status_code should NOT be in overhead tags (it is a latency metric, not a request count) + assert not any(tag.startswith("status_code:") for tag in overhead["tags"]) + + +@pytest.mark.asyncio +async def test_overhead_latency_metric_absent_when_no_hidden_params(clean_env): + """Test that litellm.overhead.latency is NOT emitted when hidden_params has no overhead value.""" + logger = DatadogMetricsLogger(batch_size=100, start_periodic_flush=False) + + now = datetime.now() + start_time = now - timedelta(seconds=2) + api_call_start_time = now - timedelta(seconds=1) + + payload = StandardLoggingPayload( + custom_llm_provider="openai", + model="gpt-4o", + # No hidden_params / no litellm_overhead_time_ms + ) + + kwargs = { + "start_time": start_time, + "api_call_start_time": api_call_start_time, + "end_time": now, + } + + logger._add_metrics_from_log(log=payload, kwargs=kwargs, status_code="200") + + metrics = {s["metric"]: s for s in logger.log_queue} + assert "litellm.overhead.latency" not in metrics + + @pytest.mark.asyncio async def test_async_log_success_event(clean_env): """Test that success events are added to the queue.""" diff --git a/tests/test_litellm/proxy/_experimental/mcp_server/test_callback_oauth_error_responses.py b/tests/test_litellm/proxy/_experimental/mcp_server/test_callback_oauth_error_responses.py new file mode 100644 index 0000000000..11ef40b996 --- /dev/null +++ b/tests/test_litellm/proxy/_experimental/mcp_server/test_callback_oauth_error_responses.py @@ -0,0 +1,210 @@ +"""Regression tests for LIT-2750. + +The MCP OAuth ``/callback`` endpoint must handle IdP error responses +(e.g. ``?error=access_denied``) gracefully instead of returning a 422 +because ``code`` and ``state`` were declared as required FastAPI query +params. Per RFC 6749 §4.1.2.1 the IdP redirects to the configured +redirect URI with ``error`` / ``error_description`` / ``error_uri`` +query params and no ``code`` when the user denies access. + +These tests cover both the propagate-to-client path (when state decodes +to a trusted ``redirect_uri``) and the in-page fallback (when state is +missing, undecryptable, or carries an untrusted redirect_uri). They also +pin the success path (``code`` + ``state``) against accidental +regressions. +""" + +import pytest + + +@pytest.fixture(autouse=True) +def _mock_mcp_client_ip(): + """Bypass IP-based access control for the in-process TestClient. + + Mirrors the autouse fixture in ``test_discoverable_endpoints.py`` so + these tests don't require a real client IP context. + """ + from unittest.mock import patch + + with patch( + "litellm.proxy._experimental.mcp_server.discoverable_endpoints.IPAddressUtils.get_mcp_client_ip", + return_value=None, + ): + yield + + +@pytest.fixture +def callback_test_client(monkeypatch): + """FastAPI TestClient mounted with the MCP discoverable router. + + Sets a deterministic ``LITELLM_SALT_KEY`` so encoded states minted + in-test can be decrypted by the handler. + """ + from fastapi import FastAPI + from fastapi.testclient import TestClient + + monkeypatch.setenv("LITELLM_SALT_KEY", "sk-test-salt-for-LIT-2750") + + from litellm.proxy._experimental.mcp_server.discoverable_endpoints import ( + router, + ) + + app = FastAPI() + app.include_router(router) + return TestClient(app) + + +class TestCallbackOAuthErrorResponses: + """LIT-2750: IdP error responses to ``/callback`` must not 422.""" + + def test_idp_error_with_no_state_returns_400_html(self, callback_test_client): + """Pre-fix: 422 Pydantic. Post-fix: 400 HTML with the IdP's error.""" + resp = callback_test_client.get( + "/callback", + params={ + "error": "access_denied", + "error_description": "User declined access", + }, + follow_redirects=False, + ) + assert resp.status_code == 400 + assert "text/html" in resp.headers["content-type"] + body = resp.text + assert "access_denied" in body + assert "User declined access" in body + # Sanity: must not leak the Pydantic validation error. + assert "Field required" not in body + + def test_idp_error_html_escapes_user_controlled_fields( + self, callback_test_client + ): + """A malicious IdP must not be able to inject HTML/JS via error params.""" + resp = callback_test_client.get( + "/callback", + params={ + "error": "", + "error_description": "