diff --git a/.github/workflows/test-unit-proxy-endpoints.yml b/.github/workflows/test-unit-proxy-endpoints.yml index 118408f746..6b34a08a8e 100644 --- a/.github/workflows/test-unit-proxy-endpoints.yml +++ b/.github/workflows/test-unit-proxy-endpoints.yml @@ -33,6 +33,7 @@ jobs: tests/test_litellm/proxy/image_endpoints tests/test_litellm/proxy/vector_store_endpoints tests/test_litellm/proxy/agent_endpoints + tests/test_litellm/proxy/a2a tests/test_litellm/proxy/discovery_endpoints tests/test_litellm/proxy/health_endpoints tests/test_litellm/proxy/public_endpoints diff --git a/litellm/a2a_protocol/litellm_completion_bridge/handler.py b/litellm/a2a_protocol/litellm_completion_bridge/handler.py index 4e66fe4ba6..67ffcf4f8f 100644 --- a/litellm/a2a_protocol/litellm_completion_bridge/handler.py +++ b/litellm/a2a_protocol/litellm_completion_bridge/handler.py @@ -107,6 +107,14 @@ class A2ACompletionBridgeHandler: if k not in ("model", "custom_llm_provider") and k not in _AGENT_ONLY_PARAMS } completion_params.update(litellm_params_to_add) + # Apply forward metadata AFTER the litellm_params merge so the helper + # sees any agent-owner-configured ``extra_body.metadata`` and can keep + # those keys authoritative over the client-supplied A2A metadata. + A2ACompletionBridgeTransformation.apply_forward_metadata_to_completion_params( + completion_params=completion_params, + a2a_message=message, + params=params, + ) # Call litellm.acompletion response = await litellm.acompletion(**completion_params) @@ -214,6 +222,14 @@ class A2ACompletionBridgeHandler: if k not in ("model", "custom_llm_provider") and k not in _AGENT_ONLY_PARAMS } completion_params.update(litellm_params_to_add) + # Apply forward metadata AFTER the litellm_params merge so the helper + # sees any agent-owner-configured ``extra_body.metadata`` and can keep + # those keys authoritative over the client-supplied A2A metadata. 
+ A2ACompletionBridgeTransformation.apply_forward_metadata_to_completion_params( + completion_params=completion_params, + a2a_message=message, + params=params, + ) # 1. Emit initial task event (kind: "task", status: "submitted") task_event = A2ACompletionBridgeTransformation.create_task_event(ctx) diff --git a/litellm/a2a_protocol/litellm_completion_bridge/transformation.py b/litellm/a2a_protocol/litellm_completion_bridge/transformation.py index 8a03569f68..06c0a8fc82 100644 --- a/litellm/a2a_protocol/litellm_completion_bridge/transformation.py +++ b/litellm/a2a_protocol/litellm_completion_bridge/transformation.py @@ -45,10 +45,80 @@ class A2ACompletionBridgeTransformation: Static methods for transforming between A2A and OpenAI message formats. """ + @staticmethod + def _extract_text_from_a2a_parts(parts: List[Dict[str, Any]]) -> str: + """Extract text from A2A parts (with or without explicit ``kind``).""" + content_parts: List[str] = [] + for part in parts: + if not isinstance(part, dict): + continue + kind = part.get("kind") + text = part.get("text") + if text is None: + continue + if kind in (None, "", "text"): + content_parts.append(str(text)) + return "\n".join(content_parts) + + @staticmethod + def get_forward_metadata( + a2a_message: Dict[str, Any], + params: Optional[Dict[str, Any]] = None, + ) -> Optional[Dict[str, Any]]: + """ + Merge A2A metadata from MessageSendParams and the message for downstream providers. + + Forwarded once on the LangGraph run payload (``metadata``), not duplicated on + each input message — see ``apply_forward_metadata_to_completion_params``. 
+ """ + merged: Dict[str, Any] = {} + if params and isinstance(params.get("metadata"), dict): + merged.update(params["metadata"]) + message_metadata = a2a_message.get("metadata") + if isinstance(message_metadata, dict): + merged.update(message_metadata) + return merged or None + + @staticmethod + def apply_forward_metadata_to_completion_params( + completion_params: Dict[str, Any], + a2a_message: Dict[str, Any], + params: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Attach A2A metadata to completion kwargs for provider bridges (e.g. LangGraph). + + Uses ``extra_body`` so we do not collide with LiteLLM's spend-log ``metadata`` kwarg. + """ + forward_metadata = A2ACompletionBridgeTransformation.get_forward_metadata( + a2a_message=a2a_message, + params=params, + ) + if not forward_metadata: + return + + extra_body = completion_params.get("extra_body") + if not isinstance(extra_body, dict): + extra_body = {} + # Layer client-supplied A2A metadata under any agent-owner-configured + # ``extra_body.metadata`` so the configured keys remain authoritative + # and an A2A caller cannot overwrite server-set run metadata. + existing_metadata = extra_body.get("metadata") + existing_dict: Dict[str, Any] = ( + existing_metadata if isinstance(existing_metadata, dict) else {} + ) + merged_metadata: Dict[str, Any] = {**forward_metadata, **existing_dict} + extra_body = {**extra_body, "metadata": merged_metadata} + completion_params["extra_body"] = extra_body + + verbose_logger.debug( + f"A2A -> completion forward metadata keys={list(forward_metadata.keys())}" + ) + @staticmethod def a2a_message_to_openai_messages( a2a_message: Dict[str, Any], - ) -> List[Dict[str, str]]: + ) -> List[Dict[str, Any]]: """ Transform an A2A message to OpenAI message format. 
@@ -70,21 +140,20 @@ class A2ACompletionBridgeTransformation: elif role == "system": openai_role = "system" - # Extract text content from parts - content_parts = [] - for part in parts: - kind = part.get("kind", "") - if kind == "text": - text = part.get("text", "") - content_parts.append(text) + if not isinstance(parts, list): + parts = [] - content = "\n".join(content_parts) if content_parts else "" + content = A2ACompletionBridgeTransformation._extract_text_from_a2a_parts(parts) + + # Do not attach A2A message.metadata here — the completion bridge forwards it + # once at run level via extra_body.metadata (LangGraph POST /runs/wait shape). + openai_message: Dict[str, Any] = {"role": openai_role, "content": content} verbose_logger.debug( f"A2A -> OpenAI transform: role={role} -> {openai_role}, content_length={len(content)}" ) - return [{"role": openai_role, "content": content}] + return [openai_message] @staticmethod def openai_response_to_a2a_response( @@ -110,6 +179,7 @@ class A2ACompletionBridgeTransformation: # Build A2A message a2a_message = { + "kind": "message", "role": "agent", "parts": [{"kind": "text", "text": content}], "messageId": uuid4().hex, @@ -119,9 +189,7 @@ class A2ACompletionBridgeTransformation: a2a_response = { "jsonrpc": "2.0", "id": request_id, - "result": { - "message": a2a_message, - }, + "result": a2a_message, } verbose_logger.debug(f"OpenAI -> A2A transform: content_length={len(content)}") @@ -235,50 +303,3 @@ class A2ACompletionBridgeTransformation: "taskId": ctx.task_id, }, } - - @staticmethod - def openai_chunk_to_a2a_chunk( - chunk: Any, - request_id: Optional[str] = None, - is_final: bool = False, - ) -> Optional[Dict[str, Any]]: - """ - Transform a LiteLLM streaming chunk to A2A streaming format. - - NOTE: This method is deprecated for streaming. Use the event-based - methods (create_task_event, create_status_update_event, - create_artifact_update_event) instead for proper A2A streaming. 
- - Args: - chunk: LiteLLM ModelResponse chunk - request_id: Original A2A request ID - is_final: Whether this is the final chunk - - Returns: - A2A streaming chunk dict or None if no content - """ - # Extract delta content - content = "" - if chunk is not None and hasattr(chunk, "choices") and chunk.choices: - choice = chunk.choices[0] - if hasattr(choice, "delta") and choice.delta: - content = choice.delta.content or "" - - if not content and not is_final: - return None - - # Build A2A streaming chunk (legacy format) - a2a_chunk = { - "jsonrpc": "2.0", - "id": request_id, - "result": { - "message": { - "role": "agent", - "parts": [{"kind": "text", "text": content}], - "messageId": uuid4().hex, - }, - "final": is_final, - }, - } - - return a2a_chunk diff --git a/litellm/a2a_protocol/providers/litellm_completion/README.md b/litellm/a2a_protocol/providers/litellm_completion/README.md deleted file mode 100644 index a809e9bf55..0000000000 --- a/litellm/a2a_protocol/providers/litellm_completion/README.md +++ /dev/null @@ -1,74 +0,0 @@ -# A2A to LiteLLM Completion Bridge - -Routes A2A protocol requests through `litellm.acompletion`, enabling any LiteLLM-supported provider to be invoked via A2A. 
- -## Flow - -``` -A2A Request → Transform → litellm.acompletion → Transform → A2A Response -``` - -## SDK Usage - -Use the existing `asend_message` and `asend_message_streaming` functions with `litellm_params`: - -```python -from litellm.a2a_protocol import asend_message, asend_message_streaming -from a2a.types import SendMessageRequest, SendStreamingMessageRequest, MessageSendParams -from uuid import uuid4 - -# Non-streaming -request = SendMessageRequest( - id=str(uuid4()), - params=MessageSendParams( - message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex} - ) -) -response = await asend_message( - request=request, - api_base="http://localhost:2024", - litellm_params={"custom_llm_provider": "langgraph", "model": "agent"}, -) - -# Streaming -stream_request = SendStreamingMessageRequest( - id=str(uuid4()), - params=MessageSendParams( - message={"role": "user", "parts": [{"kind": "text", "text": "Hello!"}], "messageId": uuid4().hex} - ) -) -async for chunk in asend_message_streaming( - request=stream_request, - api_base="http://localhost:2024", - litellm_params={"custom_llm_provider": "langgraph", "model": "agent"}, -): - print(chunk) -``` - -## Proxy Usage - -Configure an agent with `custom_llm_provider` in `litellm_params`: - -```yaml -agents: - - agent_name: my-langgraph-agent - agent_card_params: - name: "LangGraph Agent" - url: "http://localhost:2024" # Used as api_base - litellm_params: - custom_llm_provider: langgraph - model: agent -``` - -When an A2A request hits `/a2a/{agent_id}/message/send`, the bridge: - -1. Detects `custom_llm_provider` in agent's `litellm_params` -2. Transforms A2A message → OpenAI messages -3. Calls `litellm.acompletion(model="langgraph/agent", api_base="http://localhost:2024")` -4. 
Transforms response → A2A format - -## Classes - -- `A2ACompletionBridgeTransformation` - Static methods for message format conversion -- `A2ACompletionBridgeHandler` - Static methods for handling requests (streaming/non-streaming) - diff --git a/litellm/a2a_protocol/providers/litellm_completion/__init__.py b/litellm/a2a_protocol/providers/litellm_completion/__init__.py deleted file mode 100644 index fc2fc17f54..0000000000 --- a/litellm/a2a_protocol/providers/litellm_completion/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -""" -LiteLLM Completion bridge provider for A2A protocol. - -Routes A2A requests through litellm.acompletion based on custom_llm_provider. -""" diff --git a/litellm/a2a_protocol/providers/litellm_completion/handler.py b/litellm/a2a_protocol/providers/litellm_completion/handler.py deleted file mode 100644 index 730f8f6b36..0000000000 --- a/litellm/a2a_protocol/providers/litellm_completion/handler.py +++ /dev/null @@ -1,301 +0,0 @@ -""" -Handler for A2A to LiteLLM completion bridge. - -Routes A2A requests through litellm.acompletion based on custom_llm_provider. - -A2A Streaming Events (in order): -1. Task event (kind: "task") - Initial task creation with status "submitted" -2. Status update (kind: "status-update") - Status change to "working" -3. Artifact update (kind: "artifact-update") - Content/artifact delivery -4. Status update (kind: "status-update") - Final status "completed" with final=true -""" - -from typing import Any, AsyncIterator, Dict, Optional - -import litellm -from litellm._logging import verbose_logger -from litellm.a2a_protocol.litellm_completion_bridge.pydantic_ai_transformation import ( - PydanticAITransformation, -) -from litellm.a2a_protocol.litellm_completion_bridge.transformation import ( - A2ACompletionBridgeTransformation, - A2AStreamingContext, -) - - -class A2ACompletionBridgeHandler: - """ - Static methods for handling A2A requests via LiteLLM completion. 
- """ - - @staticmethod - async def handle_non_streaming( - request_id: str, - params: Dict[str, Any], - litellm_params: Dict[str, Any], - api_base: Optional[str] = None, - ) -> Dict[str, Any]: - """ - Handle non-streaming A2A request via litellm.acompletion. - - Args: - request_id: A2A JSON-RPC request ID - params: A2A MessageSendParams containing the message - litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.) - api_base: API base URL from agent_card_params - - Returns: - A2A SendMessageResponse dict - """ - # Check if this is a Pydantic AI agent request - custom_llm_provider = litellm_params.get("custom_llm_provider") - if custom_llm_provider == "pydantic_ai_agents": - if api_base is None: - raise ValueError("api_base is required for Pydantic AI agents") - - verbose_logger.info( - f"Pydantic AI: Routing to Pydantic AI agent at {api_base}" - ) - - # Send request directly to Pydantic AI agent - response_data = await PydanticAITransformation.send_non_streaming_request( - api_base=api_base, - request_id=request_id, - params=params, - ) - - return response_data - - # Extract message from params - message = params.get("message", {}) - - # Transform A2A message to OpenAI format - openai_messages = ( - A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message) - ) - - # Get completion params - custom_llm_provider = litellm_params.get("custom_llm_provider") - model = litellm_params.get("model", "agent") - - # Build full model string if provider specified - # Skip prepending if model already starts with the provider prefix - if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"): - full_model = f"{custom_llm_provider}/{model}" - else: - full_model = model - - verbose_logger.info( - f"A2A completion bridge: model={full_model}, api_base={api_base}" - ) - - # Build completion params dict - completion_params = { - "model": full_model, - "messages": openai_messages, - "api_base": api_base, - "stream": False, - } - # Add 
litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.) - litellm_params_to_add = { - k: v - for k, v in litellm_params.items() - if k not in ("model", "custom_llm_provider") - } - completion_params.update(litellm_params_to_add) - - # Call litellm.acompletion - response = await litellm.acompletion(**completion_params) - - # Transform response to A2A format - a2a_response = ( - A2ACompletionBridgeTransformation.openai_response_to_a2a_response( - response=response, - request_id=request_id, - ) - ) - - verbose_logger.info(f"A2A completion bridge completed: request_id={request_id}") - - return a2a_response - - @staticmethod - async def handle_streaming( - request_id: str, - params: Dict[str, Any], - litellm_params: Dict[str, Any], - api_base: Optional[str] = None, - ) -> AsyncIterator[Dict[str, Any]]: - """ - Handle streaming A2A request via litellm.acompletion with stream=True. - - Emits proper A2A streaming events: - 1. Task event (kind: "task") - Initial task with status "submitted" - 2. Status update (kind: "status-update") - Status "working" - 3. Artifact update (kind: "artifact-update") - Content delivery - 4. Status update (kind: "status-update") - Final "completed" status - - Args: - request_id: A2A JSON-RPC request ID - params: A2A MessageSendParams containing the message - litellm_params: Agent's litellm_params (custom_llm_provider, model, etc.) 
- api_base: API base URL from agent_card_params - - Yields: - A2A streaming response events - """ - # Check if this is a Pydantic AI agent request - custom_llm_provider = litellm_params.get("custom_llm_provider") - if custom_llm_provider == "pydantic_ai_agents": - if api_base is None: - raise ValueError("api_base is required for Pydantic AI agents") - - verbose_logger.info( - f"Pydantic AI: Faking streaming for Pydantic AI agent at {api_base}" - ) - - # Get non-streaming response first - response_data = await PydanticAITransformation.send_non_streaming_request( - api_base=api_base, - request_id=request_id, - params=params, - ) - - # Convert to fake streaming - async for chunk in PydanticAITransformation.fake_streaming_from_response( - response_data=response_data, - request_id=request_id, - ): - yield chunk - - return - - # Extract message from params - message = params.get("message", {}) - - # Create streaming context - ctx = A2AStreamingContext( - request_id=request_id, - input_message=message, - ) - - # Transform A2A message to OpenAI format - openai_messages = ( - A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message) - ) - - # Get completion params - custom_llm_provider = litellm_params.get("custom_llm_provider") - model = litellm_params.get("model", "agent") - - # Build full model string if provider specified - # Skip prepending if model already starts with the provider prefix - if custom_llm_provider and not model.startswith(f"{custom_llm_provider}/"): - full_model = f"{custom_llm_provider}/{model}" - else: - full_model = model - - verbose_logger.info( - f"A2A completion bridge streaming: model={full_model}, api_base={api_base}" - ) - - # Build completion params dict - completion_params = { - "model": full_model, - "messages": openai_messages, - "api_base": api_base, - "stream": True, - } - # Add litellm_params (contains api_key, client_id, client_secret, tenant_id, etc.) 
- litellm_params_to_add = { - k: v - for k, v in litellm_params.items() - if k not in ("model", "custom_llm_provider") - } - completion_params.update(litellm_params_to_add) - - # 1. Emit initial task event (kind: "task", status: "submitted") - task_event = A2ACompletionBridgeTransformation.create_task_event(ctx) - yield task_event - - # 2. Emit status update (kind: "status-update", status: "working") - working_event = A2ACompletionBridgeTransformation.create_status_update_event( - ctx=ctx, - state="working", - final=False, - message_text="Processing request...", - ) - yield working_event - - # Call litellm.acompletion with streaming - response = await litellm.acompletion(**completion_params) - - # 3. Accumulate content and emit artifact update - accumulated_text = "" - chunk_count = 0 - async for chunk in response: # type: ignore[union-attr] - chunk_count += 1 - - # Extract delta content - content = "" - if chunk is not None and hasattr(chunk, "choices") and chunk.choices: - choice = chunk.choices[0] - if hasattr(choice, "delta") and choice.delta: - content = choice.delta.content or "" - - if content: - accumulated_text += content - - # Emit artifact update with accumulated content - if accumulated_text: - artifact_event = ( - A2ACompletionBridgeTransformation.create_artifact_update_event( - ctx=ctx, - text=accumulated_text, - ) - ) - yield artifact_event - - # 4. 
Emit final status update (kind: "status-update", status: "completed", final: true) - completed_event = A2ACompletionBridgeTransformation.create_status_update_event( - ctx=ctx, - state="completed", - final=True, - ) - yield completed_event - - verbose_logger.info( - f"A2A completion bridge streaming completed: request_id={request_id}, chunks={chunk_count}" - ) - - -# Convenience functions that delegate to the class methods -async def handle_a2a_completion( - request_id: str, - params: Dict[str, Any], - litellm_params: Dict[str, Any], - api_base: Optional[str] = None, -) -> Dict[str, Any]: - """Convenience function for non-streaming A2A completion.""" - return await A2ACompletionBridgeHandler.handle_non_streaming( - request_id=request_id, - params=params, - litellm_params=litellm_params, - api_base=api_base, - ) - - -async def handle_a2a_completion_streaming( - request_id: str, - params: Dict[str, Any], - litellm_params: Dict[str, Any], - api_base: Optional[str] = None, -) -> AsyncIterator[Dict[str, Any]]: - """Convenience function for streaming A2A completion.""" - async for chunk in A2ACompletionBridgeHandler.handle_streaming( - request_id=request_id, - params=params, - litellm_params=litellm_params, - api_base=api_base, - ): - yield chunk diff --git a/litellm/a2a_protocol/providers/litellm_completion/transformation.py b/litellm/a2a_protocol/providers/litellm_completion/transformation.py deleted file mode 100644 index 8a03569f68..0000000000 --- a/litellm/a2a_protocol/providers/litellm_completion/transformation.py +++ /dev/null @@ -1,284 +0,0 @@ -""" -Transformation utilities for A2A <-> OpenAI message format conversion. 
- -A2A Message Format: -{ - "role": "user", - "parts": [{"kind": "text", "text": "Hello!"}], - "messageId": "abc123" -} - -OpenAI Message Format: -{"role": "user", "content": "Hello!"} - -A2A Streaming Events: -- Task event (kind: "task") - Initial task creation with status "submitted" -- Status update (kind: "status-update") - Status changes (working, completed) -- Artifact update (kind: "artifact-update") - Content/artifact delivery -""" - -from datetime import datetime, timezone -from typing import Any, Dict, List, Optional -from uuid import uuid4 - -from litellm._logging import verbose_logger - - -class A2AStreamingContext: - """ - Context holder for A2A streaming state. - Tracks task_id, context_id, and message accumulation. - """ - - def __init__(self, request_id: str, input_message: Dict[str, Any]): - self.request_id = request_id - self.task_id = str(uuid4()) - self.context_id = str(uuid4()) - self.input_message = input_message - self.accumulated_text = "" - self.has_emitted_task = False - self.has_emitted_working = False - - -class A2ACompletionBridgeTransformation: - """ - Static methods for transforming between A2A and OpenAI message formats. - """ - - @staticmethod - def a2a_message_to_openai_messages( - a2a_message: Dict[str, Any], - ) -> List[Dict[str, str]]: - """ - Transform an A2A message to OpenAI message format. 
- - Args: - a2a_message: A2A message with role, parts, and messageId - - Returns: - List of OpenAI-format messages - """ - role = a2a_message.get("role", "user") - parts = a2a_message.get("parts", []) - - # Map A2A roles to OpenAI roles - openai_role = role - if role == "user": - openai_role = "user" - elif role == "assistant": - openai_role = "assistant" - elif role == "system": - openai_role = "system" - - # Extract text content from parts - content_parts = [] - for part in parts: - kind = part.get("kind", "") - if kind == "text": - text = part.get("text", "") - content_parts.append(text) - - content = "\n".join(content_parts) if content_parts else "" - - verbose_logger.debug( - f"A2A -> OpenAI transform: role={role} -> {openai_role}, content_length={len(content)}" - ) - - return [{"role": openai_role, "content": content}] - - @staticmethod - def openai_response_to_a2a_response( - response: Any, - request_id: Optional[str] = None, - ) -> Dict[str, Any]: - """ - Transform a LiteLLM ModelResponse to A2A SendMessageResponse format. 
- - Args: - response: LiteLLM ModelResponse object - request_id: Original A2A request ID - - Returns: - A2A SendMessageResponse dict - """ - # Extract content from response - content = "" - if hasattr(response, "choices") and response.choices: - choice = response.choices[0] - if hasattr(choice, "message") and choice.message: - content = choice.message.content or "" - - # Build A2A message - a2a_message = { - "role": "agent", - "parts": [{"kind": "text", "text": content}], - "messageId": uuid4().hex, - } - - # Build A2A response - a2a_response = { - "jsonrpc": "2.0", - "id": request_id, - "result": { - "message": a2a_message, - }, - } - - verbose_logger.debug(f"OpenAI -> A2A transform: content_length={len(content)}") - - return a2a_response - - @staticmethod - def _get_timestamp() -> str: - """Get current timestamp in ISO format with timezone.""" - return datetime.now(timezone.utc).isoformat() - - @staticmethod - def create_task_event( - ctx: A2AStreamingContext, - ) -> Dict[str, Any]: - """ - Create the initial task event with status 'submitted'. - - This is the first event emitted in an A2A streaming response. - """ - return { - "id": ctx.request_id, - "jsonrpc": "2.0", - "result": { - "contextId": ctx.context_id, - "history": [ - { - "contextId": ctx.context_id, - "kind": "message", - "messageId": ctx.input_message.get("messageId", uuid4().hex), - "parts": ctx.input_message.get("parts", []), - "role": ctx.input_message.get("role", "user"), - "taskId": ctx.task_id, - } - ], - "id": ctx.task_id, - "kind": "task", - "status": { - "state": "submitted", - }, - }, - } - - @staticmethod - def create_status_update_event( - ctx: A2AStreamingContext, - state: str, - final: bool = False, - message_text: Optional[str] = None, - ) -> Dict[str, Any]: - """ - Create a status update event. 
- - Args: - ctx: Streaming context - state: Status state ('working', 'completed') - final: Whether this is the final event - message_text: Optional message text for 'working' status - """ - status: Dict[str, Any] = { - "state": state, - "timestamp": A2ACompletionBridgeTransformation._get_timestamp(), - } - - # Add message for 'working' status - if state == "working" and message_text: - status["message"] = { - "contextId": ctx.context_id, - "kind": "message", - "messageId": str(uuid4()), - "parts": [{"kind": "text", "text": message_text}], - "role": "agent", - "taskId": ctx.task_id, - } - - return { - "id": ctx.request_id, - "jsonrpc": "2.0", - "result": { - "contextId": ctx.context_id, - "final": final, - "kind": "status-update", - "status": status, - "taskId": ctx.task_id, - }, - } - - @staticmethod - def create_artifact_update_event( - ctx: A2AStreamingContext, - text: str, - ) -> Dict[str, Any]: - """ - Create an artifact update event with content. - - Args: - ctx: Streaming context - text: The text content for the artifact - """ - return { - "id": ctx.request_id, - "jsonrpc": "2.0", - "result": { - "artifact": { - "artifactId": str(uuid4()), - "name": "response", - "parts": [{"kind": "text", "text": text}], - }, - "contextId": ctx.context_id, - "kind": "artifact-update", - "taskId": ctx.task_id, - }, - } - - @staticmethod - def openai_chunk_to_a2a_chunk( - chunk: Any, - request_id: Optional[str] = None, - is_final: bool = False, - ) -> Optional[Dict[str, Any]]: - """ - Transform a LiteLLM streaming chunk to A2A streaming format. - - NOTE: This method is deprecated for streaming. Use the event-based - methods (create_task_event, create_status_update_event, - create_artifact_update_event) instead for proper A2A streaming. 
- - Args: - chunk: LiteLLM ModelResponse chunk - request_id: Original A2A request ID - is_final: Whether this is the final chunk - - Returns: - A2A streaming chunk dict or None if no content - """ - # Extract delta content - content = "" - if chunk is not None and hasattr(chunk, "choices") and chunk.choices: - choice = chunk.choices[0] - if hasattr(choice, "delta") and choice.delta: - content = choice.delta.content or "" - - if not content and not is_final: - return None - - # Build A2A streaming chunk (legacy format) - a2a_chunk = { - "jsonrpc": "2.0", - "id": request_id, - "result": { - "message": { - "role": "agent", - "parts": [{"kind": "text", "text": content}], - "messageId": uuid4().hex, - }, - "final": is_final, - }, - } - - return a2a_chunk diff --git a/litellm/a2a_protocol/providers/pydantic_ai_agents/transformation.py b/litellm/a2a_protocol/providers/pydantic_ai_agents/transformation.py index e73b17ac3c..bf68a01d98 100644 --- a/litellm/a2a_protocol/providers/pydantic_ai_agents/transformation.py +++ b/litellm/a2a_protocol/providers/pydantic_ai_agents/transformation.py @@ -289,16 +289,16 @@ class PydanticAITransformation: Transform Pydantic AI task response to standard A2A non-streaming format. Pydantic AI returns a task with history/artifacts, but the standard A2A - non-streaming format expects: + non-streaming format expects ``result`` to be the Message directly + (``kind="message"``), per the A2A spec / ``SendMessageResponse``: { "jsonrpc": "2.0", "id": "...", "result": { - "message": { - "role": "agent", - "parts": [{"kind": "text", "text": "..."}], - "messageId": "..." - } + "kind": "message", + "role": "agent", + "parts": [{"kind": "text", "text": "..."}], + "messageId": "..." 
} } @@ -316,6 +316,7 @@ class PydanticAITransformation: # Build standard A2A message a2a_message = { + "kind": "message", "role": "agent", "parts": parts if parts else [{"kind": "text", "text": full_text}], "messageId": message_id, @@ -325,9 +326,7 @@ class PydanticAITransformation: return { "jsonrpc": "2.0", "id": request_id, - "result": { - "message": a2a_message, - }, + "result": a2a_message, } @staticmethod diff --git a/litellm/a2a_protocol/utils.py b/litellm/a2a_protocol/utils.py index 1cdbde9775..0dbd1eefc6 100644 --- a/litellm/a2a_protocol/utils.py +++ b/litellm/a2a_protocol/utils.py @@ -60,6 +60,12 @@ class A2ARequestUtils: if not isinstance(result, dict): return "" + # Direct message format (A2A spec): detect by explicit kind tag only. + # The "parts" heuristic is too broad and would match any future result + # type that happens to include a "parts" field. + if result.get("kind") == "message": + return A2ARequestUtils.extract_text_from_message(result) + message = result.get("message", {}) return A2ARequestUtils.extract_text_from_message(message) diff --git a/litellm/llms/langgraph/chat/transformation.py b/litellm/llms/langgraph/chat/transformation.py index 00cc3a8f51..9808b665b5 100644 --- a/litellm/llms/langgraph/chat/transformation.py +++ b/litellm/llms/langgraph/chat/transformation.py @@ -139,14 +139,16 @@ class LangGraphConfig(BaseConfig): def _convert_messages_to_langgraph_format( self, messages: List[AllMessageValues] - ) -> List[Dict[str, str]]: + ) -> List[Dict[str, Any]]: """ Convert OpenAI-format messages to LangGraph format. OpenAI format: {"role": "user", "content": "..."} LangGraph format: {"role": "human", "content": "..."} + + Preserves per-message ``metadata`` when present (e.g. A2A ``skillId``). 
""" - langgraph_messages: List[Dict[str, str]] = [] + langgraph_messages: List[Dict[str, Any]] = [] for msg in messages: role = msg.get("role", "user") content = msg.get("content", "") @@ -169,7 +171,15 @@ class LangGraphConfig(BaseConfig): if not isinstance(content, str): content = str(content) - langgraph_messages.append({"role": langgraph_role, "content": content}) + langgraph_message: Dict[str, Any] = { + "role": langgraph_role, + "content": content, + } + message_metadata = msg.get("metadata") + if isinstance(message_metadata, dict) and message_metadata: + langgraph_message["metadata"] = message_metadata + + langgraph_messages.append(langgraph_message) return langgraph_messages diff --git a/litellm/proxy/_lazy_features.py b/litellm/proxy/_lazy_features.py index a70c5b3a92..73b8a5538d 100644 --- a/litellm/proxy/_lazy_features.py +++ b/litellm/proxy/_lazy_features.py @@ -51,6 +51,11 @@ class LazyFeature: # whose routes don't appear in the parent app's openapi spec. persistent_swagger_stub: bool = False + def matches(self, path: str) -> bool: + return any(path.startswith(p) for p in self.path_prefixes) or any( + path.endswith(s) for s in self.path_suffixes + ) + LAZY_FEATURES: Tuple[LazyFeature, ...] = ( LazyFeature( @@ -92,7 +97,15 @@ LAZY_FEATURES: Tuple[LazyFeature, ...] = ( LazyFeature( name="a2a", module_path="litellm.proxy.agent_endpoints.a2a_endpoints", - path_prefixes=("/a2a", "/v1/a2a"), + # ``/v1/a2a/{agent_id}/message/send`` is caught via the suffix so the + # ``/v1/a2a`` prefix doesn't subsume the discover prefix below. 
+ path_prefixes=("/a2a",), + path_suffixes=("/message/send",), + ), + LazyFeature( + name="a2a_registration", + module_path="litellm.proxy.a2a.endpoints", + path_prefixes=("/v1/a2a/discover",), ), LazyFeature( name="vector_stores", @@ -298,9 +311,7 @@ class LazyFeatureMiddleware: for feat in self._features: if feat.module_path in self._loaded: continue - if any(path.startswith(p) for p in feat.path_prefixes) or any( - path.endswith(s) for s in feat.path_suffixes - ): + if feat.matches(path): await _force_load(self._fastapi_app, feat) await self.app(scope, receive, send) @@ -376,11 +387,7 @@ def _make_warmup_router(app: "FastAPI") -> "APIRouter": await _force_load(app, feat) - feat_routes = [ - r - for r in app.routes - if any(getattr(r, "path", "").startswith(p) for p in feat.path_prefixes) - ] + feat_routes = [r for r in app.routes if feat.matches(getattr(r, "path", ""))] full = get_openapi(title=app.title, version=app.version, routes=feat_routes) # Force all operations under one tag so they group under a single Swagger # section — many lazy modules tag routes inconsistently. diff --git a/litellm/proxy/a2a/__init__.py b/litellm/proxy/a2a/__init__.py new file mode 100644 index 0000000000..10fb308f9f --- /dev/null +++ b/litellm/proxy/a2a/__init__.py @@ -0,0 +1,29 @@ +""" +A2A registration helpers for the LiteLLM proxy. + +- ``discovery``: fetches the upstream agent's well-known card so the UI can + display its skills/capabilities for the user to pick from. +- ``agent_card``: pure merge logic that builds the LiteLLM-fronted agent card + from the upstream card + the values the user set in the UI. +- ``endpoints``: FastAPI routes that wire the above into the proxy. 
"""
Pure logic for merging an upstream A2A agent card with LiteLLM-specific overrides.

The merge produces the card that LiteLLM exposes to A2A clients at
``/a2a/{agent_id}/.well-known/agent-card.json``. The upstream card is taken as
the base; specific fields are replaced so all traffic flows through the proxy
and uses LiteLLM auth.
"""

from copy import deepcopy
from typing import Any, Dict, List, Mapping, Optional

# Protocol version LiteLLM speaks. Bump when the proxy's A2A surface changes.
LITELLM_A2A_PROTOCOL_VERSION = "1.0"

# Security scheme exposed by the LiteLLM-fronted agent card. Always replaces
# whatever upstream advertised — the client must authenticate to the proxy,
# not the upstream agent.
LITELLM_SECURITY_SCHEMES: Dict[str, Dict[str, Any]] = {
    "LiteLLMKey": {
        "type": "http",
        "scheme": "bearer",
        "description": "LiteLLM virtual key",
    },
}

LITELLM_SECURITY_REQUIREMENTS: List[Dict[str, List[str]]] = [{"LiteLLMKey": []}]

# Capabilities LiteLLM can faithfully proxy today. Anything not in this set is
# dropped during merge so we don't advertise behavior the proxy can't deliver.
#
# ``streaming`` is allowlisted, but it only survives the merge when the
# upstream card advertises it with a truthy value (see
# ``_filter_capabilities``).
# TODO: exercise the A2A streaming endpoint at
#       ``POST /a2a/{agent_id}/message/stream`` end-to-end with cost tracking
#       + guardrails. It's wired in ``a2a_endpoints.py`` but not yet covered
#       by tests.
# TODO: ``pushNotifications`` — proxy has no webhook plumbing yet.
# TODO: ``extendedAgentCard`` — no separate authenticated-extended-card
#       endpoint exposed by the proxy.
# TODO: ``extensions`` — protocol extensions aren't validated/forwarded yet.
_ALLOWED_CAPABILITY_KEYS = {"streaming"}

# AgentCard top-level fields that survive the merge. Anything else is stripped
# from the merged card as a defense against upstream drift.
# ``supportedInterfaces`` is kept verbatim per product spec even though it is
# not in the v1.0 schema — clients that expect it will find it; clients that
# don't will ignore it.
#
# ``additionalInterfaces`` is deliberately excluded: it advertises alternate
# upstream URLs (HTTP/JSONRPC/gRPC backends) that, if persisted and served,
# would let authenticated agent callers reach the backend directly and bypass
# the proxy's auth/budget/logging. The proxy publishes its own entrypoint via
# ``supportedInterfaces`` instead.
#
# ``signatures`` is deliberately excluded: upstream signatures are computed
# over the upstream card, and the merge rewrites signed content
# (``protocolVersion``, ``securitySchemes``, ``security``, possibly ``name``),
# so forwarding them would hand clients signatures that can no longer verify.
#
# ``supportsAuthenticatedExtendedCard`` is deliberately excluded: the proxy
# exposes no authenticated-extended-card endpoint (see the capability TODO
# above), so passing through an upstream ``true`` would advertise a feature
# the proxy cannot serve. An absent flag reads as "not supported".
_ALLOWED_TOP_LEVEL_KEYS = {
    "protocolVersion",
    "name",
    "description",
    "version",
    "capabilities",
    "defaultInputModes",
    "defaultOutputModes",
    "skills",
    "preferredTransport",
    "supportedInterfaces",
    "iconUrl",
    "provider",
    "documentationUrl",
    "securitySchemes",
    "security",
    # ``url`` is retained on the stored card because the runtime A2A invocation
    # path (``a2a_endpoints.py``) reads ``agent.agent_card_params['url']`` to
    # locate the upstream backend. The public ``/.well-known/agent-card.json``
    # endpoint rewrites this field to the proxy URL before serving it to
    # clients, so retaining it here does not leak the upstream to A2A callers.
    "url",
}

_DEFAULT_SKILLS: List[Dict[str, Any]] = [
    {
        "id": "chat",
        "name": "Chat",
        "description": "Conversational interaction with the agent.",
        "tags": ["chat"],
    }
]

_DEFAULT_MODES: List[str] = ["text"]

# Fallback ``version`` when the upstream card omits the field. The A2A v1.0
# schema requires ``version`` on every card, so without this default the
# merged card would fail validation on clients that ``model_validate`` it.
_DEFAULT_AGENT_VERSION = "1.0.0"


def _filter_capabilities(upstream_capabilities: Any) -> Dict[str, Any]:
    """Return a capabilities dict containing only allowlisted, truthy keys.

    Non-dict input (``None``, lists, strings, ...) yields an empty dict.
    """
    if not isinstance(upstream_capabilities, dict):
        return {}
    return {
        key: value
        for key, value in upstream_capabilities.items()
        if key in _ALLOWED_CAPABILITY_KEYS and bool(value)
    }


def _default_litellm_provider(proxy_base_url: str) -> Dict[str, str]:
    """Build the provider record used when the upstream card has none."""
    return {"organization": "LiteLLM Proxy", "url": proxy_base_url}


def merge_agent_card(
    upstream_card: Optional[Mapping[str, Any]],
    *,
    proxy_url: str,
    proxy_base_url: str,
    name: Optional[str] = None,
    description: Optional[str] = None,
) -> Dict[str, Any]:
    """
    Build the LiteLLM-fronted agent card.

    The upstream card is deep-copied first, so the caller's mapping is never
    mutated.

    Args:
        upstream_card: Card returned by the upstream agent's well-known endpoint.
            May be ``None``/empty when the upstream did not expose one.
        proxy_url: Full URL clients should hit to invoke this agent through
            the proxy, e.g. ``https://proxy.example.com/a2a/<agent_id>``.
        proxy_base_url: Root URL of the LiteLLM proxy, used as a fallback when
            we synthesize a provider record.
        name: User-supplied agent name from the LiteLLM UI. Takes precedence
            over the upstream card's ``name``.
        description: User-supplied description from the LiteLLM UI. Takes
            precedence over the upstream card's ``description``.

    Returns:
        A dict suitable for serving as the proxy's agent card. Only keys in
        ``_ALLOWED_TOP_LEVEL_KEYS`` are emitted; in particular the upstream
        ``additionalInterfaces``, ``signatures`` and
        ``supportsAuthenticatedExtendedCard`` fields are dropped.
    """
    base: Dict[str, Any] = deepcopy(dict(upstream_card)) if upstream_card else {}

    # Keep the upstream ``url`` on the stored card: the runtime A2A
    # invocation path reads it from ``agent_card_params`` to know where to
    # proxy requests. The public well-known endpoint rewrites this field
    # to the proxy URL before exposing the card to clients.

    base["protocolVersion"] = LITELLM_A2A_PROTOCOL_VERSION

    if name:
        base["name"] = name
    if description:
        base["description"] = description

    if not base.get("version"):
        base["version"] = _DEFAULT_AGENT_VERSION

    base["capabilities"] = _filter_capabilities(base.get("capabilities"))

    if not base.get("skills"):
        base["skills"] = deepcopy(_DEFAULT_SKILLS)
    if not base.get("defaultInputModes"):
        base["defaultInputModes"] = list(_DEFAULT_MODES)
    if not base.get("defaultOutputModes"):
        base["defaultOutputModes"] = list(_DEFAULT_MODES)

    if not base.get("provider"):
        base["provider"] = _default_litellm_provider(proxy_base_url)

    base["supportedInterfaces"] = [
        {
            "url": proxy_url,
            "protocolBinding": "JSONRPC",
            "protocolVersion": LITELLM_A2A_PROTOCOL_VERSION,
        }
    ]

    base["securitySchemes"] = deepcopy(LITELLM_SECURITY_SCHEMES)
    # Use the standard A2A/OpenAPI ``security`` field for requirements, not
    # the non-standard ``securityRequirements`` alias. The upstream's own
    # ``security`` selector is overwritten here because the proxy enforces its
    # own scheme regardless of what upstream required.
    base["security"] = deepcopy(LITELLM_SECURITY_REQUIREMENTS)

    # Final allowlist pass: strips unknown keys plus the deliberately
    # excluded ones documented on ``_ALLOWED_TOP_LEVEL_KEYS``.
    return {key: value for key, value in base.items() if key in _ALLOWED_TOP_LEVEL_KEYS}
"""
Fetch an A2A agent's well-known card from the upstream agent.

Different agent runtimes publish the card at different URL shapes, so the
fetcher dispatches by ``discovery_mode``:

- ``well_known_fallback`` (pure A2A): the card lives at one of the standard
  well-known paths on the agent's own base URL. We try the canonical path,
  then the previous-spec path, then a non-standard root fallback.

- ``langgraph_platform``: LangGraph Platform mounts a single card endpoint at
  ``{base}/.well-known/agent-card.json`` and disambiguates assistants via the
  ``assistant_id`` query parameter. There is no per-assistant subpath, so the
  pure-A2A fallback strategy returns 404 for these deployments.
"""

from enum import Enum
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlencode

from litellm._logging import verbose_proxy_logger
from litellm.litellm_core_utils.url_utils import SSRFError, async_safe_get
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.types.llms.custom_http import httpxSpecialProvider


class DiscoveryMode(str, Enum):
    """How to locate the upstream agent card.

    String-valued so it serializes cleanly over JSON / Pydantic.
    """

    WELL_KNOWN_FALLBACK = "well_known_fallback"
    LANGGRAPH_PLATFORM = "langgraph_platform"


# Paths the pure-A2A fetcher tries in order. The first two are the current and
# previous A2A spec locations; ``/agent.json`` is a non-standard root fallback
# some agents still serve.
AGENT_CARD_WELL_KNOWN_PATHS: Tuple[str, ...] = (
    "/.well-known/agent-card.json",
    "/.well-known/agent.json",
    "/agent.json",
)

DEFAULT_DISCOVERY_TIMEOUT_SECONDS = 10.0


class AgentCardDiscoveryError(Exception):
    """Raised when none of the well-known paths returned a usable agent card."""


def _normalize_base_url(base_url: str) -> str:
    """Drop trailing slashes so joining with ``/path`` never doubles them."""
    return base_url.rstrip("/")


def _build_langgraph_platform_paths(
    params: Optional[Dict[str, Any]],
) -> Tuple[str, ...]:
    """Build the paths to try for LangGraph Platform discovery.

    LangGraph serves the card at ``/.well-known/agent-card.json`` with the
    ``assistant_id`` carried as a query parameter. We still try the other
    A2A path variants (with the same query string appended) so we degrade
    gracefully if a deployment uses an older spec name.

    Raises:
        AgentCardDiscoveryError: if ``params`` has no truthy ``assistant_id``.
    """
    assistant_id = (params or {}).get("assistant_id")
    if not assistant_id:
        raise AgentCardDiscoveryError(
            "langgraph_platform discovery requires params.assistant_id"
        )
    # ``urlencode`` percent-escapes the id so it is safe inside a query string.
    query = urlencode({"assistant_id": str(assistant_id)})
    return tuple(f"{path}?{query}" for path in AGENT_CARD_WELL_KNOWN_PATHS)


def _paths_for_mode(
    mode: DiscoveryMode, params: Optional[Dict[str, Any]]
) -> Tuple[str, ...]:
    """Return the ordered candidate paths for ``mode``.

    Raises:
        AgentCardDiscoveryError: for an unsupported mode, or when the mode's
            required params are missing.
    """
    if mode == DiscoveryMode.WELL_KNOWN_FALLBACK:
        return AGENT_CARD_WELL_KNOWN_PATHS
    if mode == DiscoveryMode.LANGGRAPH_PLATFORM:
        return _build_langgraph_platform_paths(params)
    raise AgentCardDiscoveryError(f"unsupported discovery_mode: {mode}")


async def fetch_well_known_card(
    base_url: str,
    *,
    discovery_mode: DiscoveryMode = DiscoveryMode.WELL_KNOWN_FALLBACK,
    params: Optional[Dict[str, Any]] = None,
    timeout: float = DEFAULT_DISCOVERY_TIMEOUT_SECONDS,
    headers: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    """
    Fetch an agent card from ``base_url`` using the strategy chosen by
    ``discovery_mode``. Returns the parsed JSON from the first path that
    responds with a JSON object.

    Args:
        base_url: Base URL of the upstream agent; trailing slashes are ignored.
        discovery_mode: Strategy used to build the candidate paths. The plain
            string value (e.g. ``"langgraph_platform"``) is accepted as well.
        params: Mode-specific parameters (``assistant_id`` for
            ``langgraph_platform``).
        timeout: Timeout handed to the shared async HTTP client.
        headers: Optional extra request headers.

    Raises:
        AgentCardDiscoveryError: if every path fails (network error, non-2xx,
            or non-JSON body), or if the chosen mode is unsupported or missing
            required params.
    """
    if not base_url:
        raise AgentCardDiscoveryError("base_url is required")

    # Normalise to the enum up front: a caller passing the plain string value
    # would otherwise compare equal in ``_paths_for_mode`` but break on
    # ``.value`` when the final error message is built.
    try:
        mode = DiscoveryMode(discovery_mode)
    except ValueError as exc:
        raise AgentCardDiscoveryError(
            f"unsupported discovery_mode: {discovery_mode}"
        ) from exc

    normalized = _normalize_base_url(base_url)
    paths = _paths_for_mode(mode, params)
    client = get_async_httpx_client(
        llm_provider=httpxSpecialProvider.A2A,
        params={"timeout": timeout},
    )

    last_error: Optional[str] = None
    for path in paths:
        url = f"{normalized}{path}"
        try:
            # ``async_safe_get`` validates the URL against the SSRF blocklist
            # (private/loopback IPs, cloud metadata endpoints, etc.) on every
            # redirect hop. Even though the discovery endpoint is admin-only,
            # we don't want a compromised admin key to be able to probe
            # internal infrastructure through this fetcher.
            # Pass ``headers or {}`` because ``async_safe_get`` (in the
            # URL-validation path) uses ``kwargs.pop("headers", {})`` which
            # returns ``None`` when the key is present-but-None, then crashes
            # on ``{**None, "Host": ...}``. Default the kwarg to an empty
            # dict so production (``user_url_validation=True``) doesn't 500.
            response = await async_safe_get(client, url, headers=headers or {})
        except SSRFError as exc:
            last_error = f"{url}: {exc!s}"
            verbose_proxy_logger.debug(
                "A2A discovery blocked by SSRF guard for %s: %s", url, exc
            )
            continue
        except Exception as exc:
            # Best-effort probing: any transport failure just moves on to the
            # next candidate path; the last failure is reported at the end.
            last_error = f"{url}: {exc!s}"
            verbose_proxy_logger.debug("A2A discovery failed for %s: %s", url, exc)
            continue

        if response.status_code >= 400:
            last_error = f"{url}: HTTP {response.status_code}"
            verbose_proxy_logger.debug(
                "A2A discovery HTTP %s for %s", response.status_code, url
            )
            continue

        try:
            card = response.json()
        except Exception as exc:
            last_error = f"{url}: invalid JSON ({exc!s})"
            # Logged like every other failure branch so a full trace of the
            # probing sequence is available at debug level.
            verbose_proxy_logger.debug(
                "A2A discovery got invalid JSON from %s: %s", url, exc
            )
            continue

        if not isinstance(card, dict):
            last_error = f"{url}: expected JSON object, got {type(card).__name__}"
            verbose_proxy_logger.debug(
                "A2A discovery got non-object JSON (%s) from %s",
                type(card).__name__,
                url,
            )
            continue

        verbose_proxy_logger.debug("A2A discovery succeeded at %s", url)
        return card

    raise AgentCardDiscoveryError(
        f"Could not fetch agent card from {base_url} (mode={mode.value}). "
        f"Last error: {last_error}"
    )
"""
FastAPI routes for the A2A registration flow.

Today this exposes a single endpoint, ``POST /v1/a2a/discover``, used by the
LiteLLM UI when an admin registers a new A2A agent: the UI hands us the
upstream agent's base URL, we fetch its well-known card, and we return the
raw card so the UI can render the agent's skills/capabilities and let the
admin pick which ones to expose through the proxy. The actual merge into a
LiteLLM-fronted card happens when the agent is saved via ``POST /v1/agents``.
"""

from typing import Any, Dict, Optional

from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel, Field

from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth
from litellm.proxy.a2a.discovery import (
    AGENT_CARD_WELL_KNOWN_PATHS,
    AgentCardDiscoveryError,
    DiscoveryMode,
    fetch_well_known_card,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth

router = APIRouter()


class DiscoverAgentRequest(BaseModel):
    """Request body for ``POST /v1/a2a/discover``."""

    url: str = Field(
        ...,
        description=(
            "Base URL of the upstream agent. Behavior depends on "
            "``discovery_mode``: ``well_known_fallback`` (default) tries "
            f"{', '.join(AGENT_CARD_WELL_KNOWN_PATHS)} under this URL in "
            "order; ``langgraph_platform`` hits "
            "``/.well-known/agent-card.json?assistant_id=<assistant_id>`` instead."
        ),
    )
    discovery_mode: DiscoveryMode = Field(
        default=DiscoveryMode.WELL_KNOWN_FALLBACK,
        description=(
            "How to locate the upstream card. "
            "``well_known_fallback`` for pure A2A agents (try standard paths); "
            "``langgraph_platform`` for LangGraph Platform deployments where "
            "the card is shared across assistants and disambiguated by a "
            "query parameter."
        ),
    )
    params: Optional[Dict[str, Any]] = Field(
        default=None,
        description=(
            "Mode-specific parameters. ``langgraph_platform`` requires "
            "``{'assistant_id': <assistant_id>}``. ``well_known_fallback`` ignores this."
        ),
    )


class DiscoverAgentResponse(BaseModel):
    """Response body: the URL that was probed plus the raw upstream card."""

    url: str
    agent_card: Dict[str, Any]


@router.post(
    "/v1/a2a/discover",
    tags=["[beta] A2A Agents"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=DiscoverAgentResponse,
)
async def discover_agent_card(
    request: DiscoverAgentRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
) -> JSONResponse:
    """
    Fetch the upstream agent's well-known card so the UI can show the admin
    which skills/capabilities the agent exposes.

    Only proxy admins can call this — the UI uses it during agent registration,
    and we don't want arbitrary keys probing internal URLs.

    Responds with 403 for non-admin callers, 400 when discovery fails in an
    expected way (``AgentCardDiscoveryError``) and 500 for anything else.

    Example:
    ```bash
    curl -X POST "http://localhost:4000/v1/a2a/discover" \\
        -H "Authorization: Bearer <admin-key>" \\
        -H "Content-Type: application/json" \\
        -d '{"url": "https://upstream-agent.example.com"}'
    ```
    """
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=403,
            detail=(
                "Only proxy admins can discover agent cards. "
                f"Your role={user_api_key_dict.user_role}"
            ),
        )

    try:
        card = await fetch_well_known_card(
            request.url,
            discovery_mode=request.discovery_mode,
            params=request.params,
        )
    except AgentCardDiscoveryError as exc:
        # Expected failure (bad URL, upstream unreachable, missing params):
        # surface the message to the admin. ``from exc`` keeps the original
        # traceback attached to the HTTP error for debugging.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except Exception as exc:
        verbose_proxy_logger.exception("Unexpected error during A2A discovery: %s", exc)
        raise HTTPException(
            status_code=500, detail=f"Discovery failed: {exc!s}"
        ) from exc

    return JSONResponse(
        content={"url": request.url, "agent_card": card},
        media_type="application/json",
    )
Contact proxy admin for access.", ) + if not agent.agent_card_params: + raise HTTPException( + status_code=404, + detail=f"Agent '{agent_id}' has no agent card configured", + ) + # Copy and rewrite URL to point to LiteLLM proxy - agent_card = dict(agent.agent_card_params) - agent_card["url"] = f"{str(request.base_url).rstrip('/')}/a2a/{agent_id}" + agent_card = { + **agent.agent_card_params, + "url": f"{str(request.base_url).rstrip('/')}/a2a/{agent_id}", + } verbose_proxy_logger.debug( f"Returning agent card for '{agent_id}' with proxy URL: {agent_card['url']}" @@ -332,9 +340,14 @@ async def invoke_agent_a2a( # noqa: PLR0915 if params: # extract any litellm params from the params - eg. 'guardrails' + # ``metadata`` is intentionally excluded: it's a first-class A2A + # ``MessageSendParams`` field that the completion bridge forwards + # downstream via ``get_forward_metadata``. Stripping it here would + # collide with litellm's spend-tracking ``metadata`` kwarg and + # silently drop the caller's A2A request-level metadata. 
params_to_remove = [] for key, value in params.items(): - if key in all_litellm_params: + if key in all_litellm_params and key != "metadata": params_to_remove.append(key) body[key] = value for key in params_to_remove: @@ -368,8 +381,9 @@ async def invoke_agent_a2a( # noqa: PLR0915 _enforce_inbound_trace_id(agent, request) # Get backend URL and agent name - agent_url = agent.agent_card_params.get("url") - agent_name = agent.agent_card_params.get("name", agent_id) + agent_card_params = agent.agent_card_params or {} + agent_url = agent_card_params.get("url") + agent_name = agent_card_params.get("name", agent_id) # Get litellm_params (may include custom_llm_provider for completion bridge) litellm_params = agent.litellm_params or {} diff --git a/litellm/proxy/agent_endpoints/agent_registry.py b/litellm/proxy/agent_endpoints/agent_registry.py index e69d16e0ea..13a2dd9f04 100644 --- a/litellm/proxy/agent_endpoints/agent_registry.py +++ b/litellm/proxy/agent_endpoints/agent_registry.py @@ -92,10 +92,19 @@ class AgentRegistry: ########### DB management helpers for agents ########### ############################################################ async def add_agent_to_db( - self, agent: AgentConfig, prisma_client: PrismaClient, created_by: str + self, + agent: AgentConfig, + prisma_client: PrismaClient, + created_by: str, + agent_id: Optional[str] = None, ) -> AgentResponse: """ - Add an agent to the database + Add an agent to the database. + + If ``agent_id`` is provided, it is used as the primary key for the new + row (otherwise the DB generates a UUID). Callers pass an explicit ID + when the agent_card_params must reference the agent's own URL before + the row exists, e.g. the A2A merge in ``create_agent``. 
def _proxy_base_url(http_request: Request) -> str:
    """Give the proxy root URL as the caller addressed it, minus trailing slashes."""
    base = str(http_request.base_url)
    return base.rstrip("/")


def _build_merged_agent_card(
    upstream_card: Optional[Mapping[str, Any]],
    *,
    agent_id: str,
    http_request: Request,
    agent_name: Optional[str] = None,
) -> Dict[str, Any]:
    """Run ``merge_agent_card`` on ``upstream_card`` with URLs derived for ``agent_id``.

    The invoke URL handed to the merge is ``<proxy root>/a2a/<agent_id>``, and
    the proxy root itself is passed as the provider fallback.
    """
    root = _proxy_base_url(http_request)

    # Name precedence: a ``name`` carried by the card itself (the discovery UI
    # exposes an editable "Name (shown to API clients)" field that flows into
    # ``agent_card_params.name``) beats the internal ``agent_name``
    # identifier, which is only the fallback.
    if upstream_card:
        display_name = upstream_card.get("name") or agent_name
    else:
        display_name = agent_name

    invoke_url = f"{root}/a2a/{agent_id}"
    return merge_agent_card(
        upstream_card,
        proxy_url=invoke_url,
        proxy_base_url=root,
        name=display_name,
    )
+ new_agent_id = str(uuid.uuid4()) + merged_card = _build_merged_agent_card( + upstream_card, + agent_id=new_agent_id, + http_request=http_request, + agent_name=request.get("agent_name"), + ) + agent_to_create = {**request, "agent_card_params": merged_card} # type: ignore[typeddict-item] + result = await AGENT_REGISTRY.add_agent_to_db( - agent=request, prisma_client=prisma_client, created_by=created_by + agent=agent_to_create, + prisma_client=prisma_client, + created_by=created_by, + agent_id=new_agent_id, ) agent_name = result.agent_name @@ -473,6 +527,7 @@ async def get_agent_by_id( async def update_agent( agent_id: str, request: AgentConfig, + http_request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ @@ -533,9 +588,25 @@ async def update_agent( # Get the user ID from the API key auth updated_by = user_api_key_dict.user_id or "unknown" + # Re-apply the LiteLLM-fronting merge — an update is a re-registration, + # so any new upstream card the admin pasted must go through the same + # transformation as initial create. Plain agents without an + # ``agent_card_params`` skip the merge so we don't synthesise an A2A + # card for them. 
+ upstream_card = request.get("agent_card_params") + agent_to_update: AgentConfig = request + if upstream_card is not None: + merged_card = _build_merged_agent_card( + upstream_card, + agent_id=agent_id, + http_request=http_request, + agent_name=request.get("agent_name"), + ) + agent_to_update = {**request, "agent_card_params": merged_card} # type: ignore[typeddict-item] + result = await AGENT_REGISTRY.update_agent_in_db( agent_id=agent_id, - agent=request, + agent=agent_to_update, prisma_client=prisma_client, updated_by=updated_by, ) @@ -566,6 +637,7 @@ async def update_agent( async def patch_agent( agent_id: str, request: PatchAgentRequest, + http_request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), ): """ @@ -626,9 +698,26 @@ async def patch_agent( # Get the user ID from the API key auth updated_by = user_api_key_dict.user_id or "unknown" + # Re-merge only when the patch actually touches agent_card_params; a + # patch updating just litellm_params/rate limits (``agent_card_params`` + # omitted) shouldn't rewrite the stored card. An explicitly provided + # ``agent_card_params`` — even an empty dict — still goes through the + # merge so LiteLLM applies its security schemes and supported + # interfaces instead of storing a bare card. 
+ patch_payload: PatchAgentRequest = request + upstream_card = request.get("agent_card_params") + if upstream_card is not None: + merged_card = _build_merged_agent_card( + upstream_card, + agent_id=agent_id, + http_request=http_request, + agent_name=request.get("agent_name"), + ) + patch_payload = {**request, "agent_card_params": merged_card} # type: ignore[typeddict-item] + result = await AGENT_REGISTRY.patch_agent_in_db( agent_id=agent_id, - agent=request, + agent=patch_payload, prisma_client=prisma_client, updated_by=updated_by, ) diff --git a/litellm/proxy/public_endpoints/public_endpoints.py b/litellm/proxy/public_endpoints/public_endpoints.py index 7d9da543c7..d12e7a35fb 100644 --- a/litellm/proxy/public_endpoints/public_endpoints.py +++ b/litellm/proxy/public_endpoints/public_endpoints.py @@ -5,7 +5,7 @@ from importlib.resources import files from typing import Any, Dict, List, Optional import litellm -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter, HTTPException, Request from litellm._logging import verbose_logger from litellm.litellm_core_utils.get_blog_posts import ( @@ -211,7 +211,7 @@ async def public_model_hub(): tags=["[beta] Agents", "public"], response_model=List[AgentCard], ) -async def get_agents(): +async def get_agents(request: Request): import litellm from litellm.proxy.agent_endpoints.agent_registry import global_agent_registry @@ -219,12 +219,16 @@ async def get_agents(): if litellm.public_agent_groups is None: return [] - agent_card_list = [ - agent.agent_card_params + + proxy_base = str(request.base_url).rstrip("/") + return [ + { + **(agent.agent_card_params or {}), + "url": f"{proxy_base}/a2a/{agent.agent_id}", + } for agent in agents if agent.agent_id in litellm.public_agent_groups ] - return agent_card_list @router.get( diff --git a/tests/agent_tests/local_only_agent_tests/test_a2a_completion_bridge.py b/tests/agent_tests/local_only_agent_tests/test_a2a_completion_bridge.py index 95d76ba580..d3d582f29e 100644 
--- a/tests/agent_tests/local_only_agent_tests/test_a2a_completion_bridge.py +++ b/tests/agent_tests/local_only_agent_tests/test_a2a_completion_bridge.py @@ -54,9 +54,9 @@ async def test_a2a_completion_bridge_non_streaming(): assert response.jsonrpc == "2.0" assert response.id is not None assert response.result is not None - assert "message" in response.result + assert response.result.get("kind") == "message" - message = response.result["message"] + message = response.result assert "role" in message assert message["role"] == "agent" assert "parts" in message diff --git a/tests/litellm/a2a_protocol/providers/pydantic_ai_agents/test_pydantic_ai_agent_transformation.py b/tests/litellm/a2a_protocol/providers/pydantic_ai_agents/test_pydantic_ai_agent_transformation.py index efbb628ee6..717a7c902b 100644 --- a/tests/litellm/a2a_protocol/providers/pydantic_ai_agents/test_pydantic_ai_agent_transformation.py +++ b/tests/litellm/a2a_protocol/providers/pydantic_ai_agents/test_pydantic_ai_agent_transformation.py @@ -90,9 +90,10 @@ class TestPydanticAITransformation: request_id="req-123", ) - # Should return standard A2A format with message + # Should return standard A2A non-streaming format where `result` is the + # Message itself (kind="message"), per A2A spec / SendMessageResponse. assert result["jsonrpc"] == "2.0" assert result["id"] == "req-123" - assert "message" in result["result"] - assert result["result"]["message"]["role"] == "agent" - assert result["result"]["message"]["parts"][0]["text"] == "The answer is 4." + assert result["result"]["kind"] == "message" + assert result["result"]["role"] == "agent" + assert result["result"]["parts"][0]["text"] == "The answer is 4." 
diff --git a/tests/test_litellm/a2a_protocol/test_completion_bridge_streaming.py b/tests/test_litellm/a2a_protocol/test_completion_bridge_streaming.py index 39c303f275..1b3e5f8602 100644 --- a/tests/test_litellm/a2a_protocol/test_completion_bridge_streaming.py +++ b/tests/test_litellm/a2a_protocol/test_completion_bridge_streaming.py @@ -16,6 +16,89 @@ import pytest class TestA2AStreamingTransformation: """Test the A2A streaming transformation creates proper events.""" + def test_a2a_metadata_forwarded_to_completion_params(self): + from litellm.a2a_protocol.litellm_completion_bridge.transformation import ( + A2ACompletionBridgeTransformation, + ) + + message = { + "role": "user", + "parts": [{"text": "Reply to ticket #4823"}], + "metadata": {"skillId": "draft_reply"}, + } + openai_messages = ( + A2ACompletionBridgeTransformation.a2a_message_to_openai_messages(message) + ) + # Metadata is forwarded on the run payload only, not duplicated on messages. + assert "metadata" not in openai_messages[0] + + completion_params: dict = { + "model": "langgraph/agent", + "messages": openai_messages, + } + A2ACompletionBridgeTransformation.apply_forward_metadata_to_completion_params( + completion_params=completion_params, + a2a_message=message, + params={"metadata": {"trace": "abc"}}, + ) + assert completion_params["extra_body"]["metadata"] == { + "trace": "abc", + "skillId": "draft_reply", + } + + def test_configured_metadata_wins_over_forwarded_a2a_metadata(self): + from litellm.a2a_protocol.litellm_completion_bridge.transformation import ( + A2ACompletionBridgeTransformation, + ) + + # Agent-owner-configured run metadata in ``extra_body``. + completion_params: dict = { + "model": "langgraph/agent", + "messages": [], + "extra_body": { + "metadata": {"owner_tag": "prod", "trace": "server-set"}, + "other": "keep", + }, + } + # Client tries to overwrite ``trace`` and inject a new key. 
+ message = { + "role": "user", + "parts": [{"text": "hi"}], + "metadata": {"trace": "client-spoof", "skillId": "draft_reply"}, + } + A2ACompletionBridgeTransformation.apply_forward_metadata_to_completion_params( + completion_params=completion_params, + a2a_message=message, + params={"metadata": {"trace": "client-spoof-2"}}, + ) + assert completion_params["extra_body"]["other"] == "keep" + assert completion_params["extra_body"]["metadata"] == { + "owner_tag": "prod", + "trace": "server-set", + "skillId": "draft_reply", + } + + def test_langgraph_transform_preserves_message_metadata(self): + from litellm.llms.langgraph.chat.transformation import LangGraphConfig + + config = LangGraphConfig() + request = config.transform_request( + model="langgraph/agent", + messages=[ + { + "role": "user", + "content": "Reply to ticket #4823", + "metadata": {"skillId": "draft_reply"}, + } + ], + optional_params={}, + litellm_params={"stream": False}, + headers={}, + ) + assert request["input"]["messages"][-1]["metadata"] == { + "skillId": "draft_reply", + } + def test_create_task_event(self): """Test that create_task_event produces proper A2A task event structure.""" from litellm.a2a_protocol.litellm_completion_bridge.transformation import ( diff --git a/tests/test_litellm/proxy/a2a/__init__.py b/tests/test_litellm/proxy/a2a/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_litellm/proxy/a2a/test_agent_card.py b/tests/test_litellm/proxy/a2a/test_agent_card.py new file mode 100644 index 0000000000..0022053d8d --- /dev/null +++ b/tests/test_litellm/proxy/a2a/test_agent_card.py @@ -0,0 +1,189 @@ +"""Unit tests for the pure merge logic in litellm/proxy/a2a/agent_card.py.""" + +from litellm.proxy.a2a.agent_card import ( + LITELLM_A2A_PROTOCOL_VERSION, + LITELLM_SECURITY_REQUIREMENTS, + LITELLM_SECURITY_SCHEMES, + merge_agent_card, +) + +PROXY_URL = "https://proxy.example/a2a/agent-xyz" +PROXY_BASE = "https://proxy.example" + + +def 
_full_upstream_card() -> dict: + return { + "protocolVersion": "0.9", + "name": "Upstream Name", + "description": "Upstream description", + "url": "http://internal:9999/", + "version": "1.2.3", + "capabilities": { + "streaming": True, + "pushNotifications": True, + "stateTransitionHistory": True, + "extensions": [{"uri": "x"}], + }, + "skills": [ + {"id": "s1", "name": "skill one", "description": "d", "tags": ["t"]} + ], + "defaultInputModes": ["text", "audio"], + "defaultOutputModes": ["text"], + "securitySchemes": {"upstreamKey": {"type": "apiKey"}}, + "security": [{"upstreamKey": []}], + "provider": {"organization": "UpstreamCo", "url": "https://upstream.example"}, + "iconUrl": "https://upstream.example/icon.png", + "documentationUrl": "https://upstream.example/docs", + "somethingNotInSchema": "should be stripped", + } + + +def test_preserves_top_level_url_for_runtime_invocation(): + # The runtime A2A invocation path reads ``agent_card_params['url']`` to + # know where to proxy requests, so the merge must keep the upstream URL + # on the stored card. The public well-known endpoint rewrites this field + # to the proxy URL before exposing it to clients. 
+ merged = merge_agent_card( + _full_upstream_card(), proxy_url=PROXY_URL, proxy_base_url=PROXY_BASE + ) + assert merged["url"] == "http://internal:9999/" + + +def test_overrides_protocol_version(): + merged = merge_agent_card( + _full_upstream_card(), proxy_url=PROXY_URL, proxy_base_url=PROXY_BASE + ) + assert merged["protocolVersion"] == LITELLM_A2A_PROTOCOL_VERSION + + +def test_overrides_name_and_description_when_provided(): + merged = merge_agent_card( + _full_upstream_card(), + proxy_url=PROXY_URL, + proxy_base_url=PROXY_BASE, + name="UI Name", + description="UI Description", + ) + assert merged["name"] == "UI Name" + assert merged["description"] == "UI Description" + + +def test_keeps_upstream_name_and_description_when_not_overridden(): + merged = merge_agent_card( + _full_upstream_card(), proxy_url=PROXY_URL, proxy_base_url=PROXY_BASE + ) + assert merged["name"] == "Upstream Name" + assert merged["description"] == "Upstream description" + + +def test_filters_capabilities_to_allowlist(): + merged = merge_agent_card( + _full_upstream_card(), proxy_url=PROXY_URL, proxy_base_url=PROXY_BASE + ) + # Only ``streaming`` is allowlisted today. 
+ assert merged["capabilities"] == {"streaming": True} + + +def test_drops_streaming_when_upstream_disables_it(): + upstream = _full_upstream_card() + upstream["capabilities"]["streaming"] = False + merged = merge_agent_card(upstream, proxy_url=PROXY_URL, proxy_base_url=PROXY_BASE) + assert merged["capabilities"] == {} + + +def test_replaces_security_schemes_and_requirements(): + merged = merge_agent_card( + _full_upstream_card(), proxy_url=PROXY_URL, proxy_base_url=PROXY_BASE + ) + assert merged["securitySchemes"] == LITELLM_SECURITY_SCHEMES + assert merged["security"] == LITELLM_SECURITY_REQUIREMENTS + assert "securityRequirements" not in merged + + +def test_emits_supported_interfaces_pointing_at_proxy(): + merged = merge_agent_card( + _full_upstream_card(), proxy_url=PROXY_URL, proxy_base_url=PROXY_BASE + ) + assert merged["supportedInterfaces"] == [ + { + "url": PROXY_URL, + "protocolBinding": "JSONRPC", + "protocolVersion": LITELLM_A2A_PROTOCOL_VERSION, + } + ] + + +def test_passes_through_skills_modes_provider_icon_docs(): + merged = merge_agent_card( + _full_upstream_card(), proxy_url=PROXY_URL, proxy_base_url=PROXY_BASE + ) + assert merged["skills"] == _full_upstream_card()["skills"] + assert merged["defaultInputModes"] == ["text", "audio"] + assert merged["defaultOutputModes"] == ["text"] + assert merged["provider"] == { + "organization": "UpstreamCo", + "url": "https://upstream.example", + } + assert merged["iconUrl"] == "https://upstream.example/icon.png" + assert merged["documentationUrl"] == "https://upstream.example/docs" + + +def test_strips_fields_not_in_v1_schema(): + merged = merge_agent_card( + _full_upstream_card(), proxy_url=PROXY_URL, proxy_base_url=PROXY_BASE + ) + assert "somethingNotInSchema" not in merged + + +def test_defaults_for_missing_skills_and_modes(): + sparse = {"name": "x", "description": "y", "version": "1"} + merged = merge_agent_card(sparse, proxy_url=PROXY_URL, proxy_base_url=PROXY_BASE) + assert merged["skills"] and 
merged["skills"][0]["id"] == "chat" + assert merged["defaultInputModes"] == ["text"] + assert merged["defaultOutputModes"] == ["text"] + + +def test_defaults_version_when_upstream_omits_it(): + sparse = {"name": "x", "description": "y"} + merged = merge_agent_card(sparse, proxy_url=PROXY_URL, proxy_base_url=PROXY_BASE) + assert merged["version"] == "1.0.0" + + +def test_preserves_upstream_version_when_present(): + merged = merge_agent_card( + _full_upstream_card(), proxy_url=PROXY_URL, proxy_base_url=PROXY_BASE + ) + assert merged["version"] == "1.2.3" + + +def test_falls_back_to_litellm_provider_when_upstream_lacks_one(): + sparse = {"name": "x", "description": "y", "version": "1"} + merged = merge_agent_card(sparse, proxy_url=PROXY_URL, proxy_base_url=PROXY_BASE) + assert merged["provider"] == { + "organization": "LiteLLM Proxy", + "url": PROXY_BASE, + } + + +def test_handles_none_upstream_card(): + merged = merge_agent_card(None, proxy_url=PROXY_URL, proxy_base_url=PROXY_BASE) + assert merged["protocolVersion"] == LITELLM_A2A_PROTOCOL_VERSION + assert merged["supportedInterfaces"][0]["url"] == PROXY_URL + assert merged["securitySchemes"] == LITELLM_SECURITY_SCHEMES + + +def test_does_not_mutate_input(): + upstream = _full_upstream_card() + snapshot = dict(upstream) + merge_agent_card(upstream, proxy_url=PROXY_URL, proxy_base_url=PROXY_BASE) + assert upstream == snapshot + + +def test_strips_additional_interfaces_to_prevent_backend_url_leak(): + upstream = _full_upstream_card() + upstream["additionalInterfaces"] = [ + {"url": "http://internal-backend:8080/", "transport": "JSONRPC"}, + {"url": "grpc://internal-backend:50051", "transport": "GRPC"}, + ] + merged = merge_agent_card(upstream, proxy_url=PROXY_URL, proxy_base_url=PROXY_BASE) + assert "additionalInterfaces" not in merged diff --git a/tests/test_litellm/proxy/a2a/test_discovery.py b/tests/test_litellm/proxy/a2a/test_discovery.py new file mode 100644 index 0000000000..ac1e7dfbb5 --- /dev/null +++ 
b/tests/test_litellm/proxy/a2a/test_discovery.py @@ -0,0 +1,283 @@ +"""Tests for the well-known card fetcher and the discovery endpoint.""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient + +import litellm +from litellm.proxy._types import LitellmUserRoles, UserAPIKeyAuth +from litellm.proxy.a2a.discovery import ( + AGENT_CARD_WELL_KNOWN_PATHS, + AgentCardDiscoveryError, + DiscoveryMode, + fetch_well_known_card, +) +from litellm.proxy.a2a.endpoints import router as a2a_router +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth + + +@pytest.fixture(autouse=True) +def _disable_url_validation_for_mocks(monkeypatch): + """The fetch tests use placeholder hostnames (``upstream.example``, + ``localhost:2024``) with mocked HTTP clients. ``async_safe_get`` would + otherwise resolve those hostnames and either fail DNS or block on the + SSRF guard. Disabling validation here lets the unit tests focus on + fallback / parsing logic; SSRF behavior is covered in its own test.""" + monkeypatch.setattr(litellm, "user_url_validation", False) + + +# --------------------------------------------------------------------------- +# fetch_well_known_card +# --------------------------------------------------------------------------- + + +def _mock_response(status_code: int = 200, body=None, raise_json=False): + response = MagicMock() + response.status_code = status_code + if raise_json: + response.json = MagicMock(side_effect=ValueError("bad json")) + else: + response.json = MagicMock(return_value=body) + return response + + +@pytest.mark.asyncio +async def test_fetch_uses_first_path_that_returns_200(): + body = {"name": "agent"} + fake_client = MagicMock() + fake_client.get = AsyncMock(return_value=_mock_response(200, body=body)) + + with patch( + "litellm.proxy.a2a.discovery.get_async_httpx_client", return_value=fake_client + ): + card = await 
fetch_well_known_card("https://upstream.example") + + assert card == body + # First call should be to the canonical path. + called_url = fake_client.get.call_args.args[0] + assert called_url == f"https://upstream.example{AGENT_CARD_WELL_KNOWN_PATHS[0]}" + + +@pytest.mark.asyncio +async def test_fetch_falls_back_to_later_paths_on_404(): + body = {"name": "agent"} + fake_client = MagicMock() + fake_client.get = AsyncMock( + side_effect=[ + _mock_response(404), + _mock_response(404), + _mock_response(200, body=body), + ] + ) + + with patch( + "litellm.proxy.a2a.discovery.get_async_httpx_client", return_value=fake_client + ): + card = await fetch_well_known_card("https://upstream.example") + + assert card == body + assert fake_client.get.await_count == len(AGENT_CARD_WELL_KNOWN_PATHS) + + +@pytest.mark.asyncio +async def test_fetch_raises_when_all_paths_fail(): + fake_client = MagicMock() + fake_client.get = AsyncMock( + side_effect=[_mock_response(404) for _ in AGENT_CARD_WELL_KNOWN_PATHS] + ) + + with patch( + "litellm.proxy.a2a.discovery.get_async_httpx_client", return_value=fake_client + ): + with pytest.raises(AgentCardDiscoveryError): + await fetch_well_known_card("https://upstream.example") + + +@pytest.mark.asyncio +async def test_fetch_skips_path_that_returns_non_json_body(): + body = {"name": "agent"} + fake_client = MagicMock() + fake_client.get = AsyncMock( + side_effect=[ + _mock_response(200, raise_json=True), + _mock_response(200, body=body), + ] + ) + + with patch( + "litellm.proxy.a2a.discovery.get_async_httpx_client", return_value=fake_client + ): + card = await fetch_well_known_card("https://upstream.example") + + assert card == body + + +@pytest.mark.asyncio +async def test_fetch_skips_path_that_returns_non_object_json(): + fake_client = MagicMock() + fake_client.get = AsyncMock( + side_effect=[ + _mock_response(200, body=["not", "an", "object"]), + _mock_response(200, body={"name": "agent"}), + _mock_response(404), + ] + ) + + with patch( + 
"litellm.proxy.a2a.discovery.get_async_httpx_client", return_value=fake_client + ): + card = await fetch_well_known_card("https://upstream.example") + + assert card == {"name": "agent"} + + +@pytest.mark.asyncio +async def test_fetch_requires_base_url(): + with pytest.raises(AgentCardDiscoveryError): + await fetch_well_known_card("") + + +# --------------------------------------------------------------------------- +# LangGraph Platform discovery mode +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_langgraph_mode_appends_assistant_id_query_param(): + """LangGraph serves one card endpoint; the assistant is selected via query string.""" + body = {"name": "support-agent"} + fake_client = MagicMock() + fake_client.get = AsyncMock(return_value=_mock_response(200, body=body)) + + with patch( + "litellm.proxy.a2a.discovery.get_async_httpx_client", return_value=fake_client + ): + card = await fetch_well_known_card( + "http://localhost:2024", + discovery_mode=DiscoveryMode.LANGGRAPH_PLATFORM, + params={"assistant_id": "agent"}, + ) + + assert card == body + called_url = fake_client.get.call_args.args[0] + # The canonical A2A path with the LangGraph query parameter — NOT a + # per-assistant subpath like /agent/.well-known/agent-card.json. 
+ assert called_url == ( + "http://localhost:2024/.well-known/agent-card.json?assistant_id=agent" + ) + + +@pytest.mark.asyncio +async def test_langgraph_mode_requires_assistant_id(): + with pytest.raises(AgentCardDiscoveryError, match="assistant_id"): + await fetch_well_known_card( + "http://localhost:2024", + discovery_mode=DiscoveryMode.LANGGRAPH_PLATFORM, + params={}, + ) + + +@pytest.mark.asyncio +async def test_langgraph_mode_falls_back_to_older_well_known_paths(): + """If an older LangGraph deployment serves /.well-known/agent.json, accept that too.""" + fake_client = MagicMock() + fake_client.get = AsyncMock( + side_effect=[ + _mock_response(404), + _mock_response(200, body={"name": "support-agent"}), + ] + ) + + with patch( + "litellm.proxy.a2a.discovery.get_async_httpx_client", return_value=fake_client + ): + card = await fetch_well_known_card( + "http://localhost:2024", + discovery_mode=DiscoveryMode.LANGGRAPH_PLATFORM, + params={"assistant_id": "agent"}, + ) + + assert card == {"name": "support-agent"} + # Both calls carry the assistant_id query param. 
+ for call in fake_client.get.await_args_list: + assert "assistant_id=agent" in call.args[0] + + +# --------------------------------------------------------------------------- +# POST /v1/a2a/discover +# --------------------------------------------------------------------------- + + +def _client_for_role(role: LitellmUserRoles) -> TestClient: + app = FastAPI() + app.include_router(a2a_router) + app.dependency_overrides[user_api_key_auth] = lambda: UserAPIKeyAuth( + user_id="u", user_role=role + ) + return TestClient(app) + + +def test_discover_admin_returns_raw_card(): + client = _client_for_role(LitellmUserRoles.PROXY_ADMIN) + with patch( + "litellm.proxy.a2a.endpoints.fetch_well_known_card", + new=AsyncMock(return_value={"name": "Upstream"}), + ): + resp = client.post("/v1/a2a/discover", json={"url": "https://upstream.example"}) + + assert resp.status_code == 200 + body = resp.json() + assert body["url"] == "https://upstream.example" + assert body["agent_card"] == {"name": "Upstream"} + + +def test_discover_non_admin_forbidden(): + client = _client_for_role(LitellmUserRoles.INTERNAL_USER) + resp = client.post("/v1/a2a/discover", json={"url": "https://upstream.example"}) + assert resp.status_code == 403 + + +def test_discover_returns_400_when_upstream_unreachable(): + client = _client_for_role(LitellmUserRoles.PROXY_ADMIN) + with patch( + "litellm.proxy.a2a.endpoints.fetch_well_known_card", + new=AsyncMock(side_effect=AgentCardDiscoveryError("no luck")), + ): + resp = client.post("/v1/a2a/discover", json={"url": "https://upstream.example"}) + + assert resp.status_code == 400 + assert "no luck" in resp.json()["detail"] + + +def test_discover_forwards_mode_and_params_to_fetcher(): + """The endpoint must hand discovery_mode + params to fetch_well_known_card.""" + client = _client_for_role(LitellmUserRoles.PROXY_ADMIN) + fetch_stub = AsyncMock(return_value={"name": "support-agent"}) + with patch("litellm.proxy.a2a.endpoints.fetch_well_known_card", new=fetch_stub): + 
resp = client.post( + "/v1/a2a/discover", + json={ + "url": "http://localhost:2024", + "discovery_mode": "langgraph_platform", + "params": {"assistant_id": "agent"}, + }, + ) + + assert resp.status_code == 200 + # Pydantic deserializes the JSON string back into the DiscoveryMode enum. + assert fetch_stub.await_args is not None + kwargs = fetch_stub.await_args.kwargs + assert kwargs["discovery_mode"] == DiscoveryMode.LANGGRAPH_PLATFORM + assert kwargs["params"] == {"assistant_id": "agent"} + + +def test_discover_rejects_unknown_mode(): + """Pydantic should 422 on an enum value we don't recognize.""" + client = _client_for_role(LitellmUserRoles.PROXY_ADMIN) + resp = client.post( + "/v1/a2a/discover", + json={"url": "http://localhost:2024", "discovery_mode": "bogus"}, + ) + assert resp.status_code == 422 diff --git a/tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py b/tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py index dec2e66710..268e6d2dc1 100644 --- a/tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py +++ b/tests/test_litellm/proxy/agent_endpoints/test_a2a_endpoints.py @@ -4,6 +4,7 @@ Mock tests for A2A endpoints. Tests that invoke_agent_a2a properly integrates with add_litellm_data_to_request. """ +import json import sys from unittest.mock import AsyncMock, MagicMock, patch @@ -181,3 +182,67 @@ async def test_invoke_agent_a2a_adds_litellm_data(): # Verify proxy_server_request was added assert "proxy_server_request" in captured_data assert captured_data["proxy_server_request"]["method"] == "POST" + + +@pytest.mark.asyncio +async def test_invoke_agent_a2a_handles_none_agent_card_params(): + """Agents without ``agent_card_params`` (e.g. plain chat agents routed + through the A2A endpoint by mistake) must not raise ``AttributeError`` on + ``agent_card_params.get(...)`` — they should return a JSON-RPC error. 
+ """ + from litellm.proxy._types import UserAPIKeyAuth + + mock_agent = MagicMock() + mock_agent.agent_card_params = None + mock_agent.litellm_params = None + + mock_request = MagicMock() + mock_request.json = AsyncMock( + return_value={ + "jsonrpc": "2.0", + "id": "test-id", + "method": "message/send", + "params": { + "message": { + "role": "user", + "parts": [{"kind": "text", "text": "Hello"}], + "messageId": "msg-123", + } + }, + } + ) + + mock_user_api_key_dict = UserAPIKeyAuth( + api_key="sk-test-key", + user_id="test-user", + team_id="test-team", + ) + + with ( + patch( + "litellm.proxy.agent_endpoints.a2a_endpoints._get_agent", + return_value=mock_agent, + ), + patch( + "litellm.a2a_protocol.main.A2A_SDK_AVAILABLE", + True, + ), + patch.dict(sys.modules, {"a2a": MagicMock(), "a2a.types": MagicMock()}), + ): + from litellm.proxy.agent_endpoints.a2a_endpoints import invoke_agent_a2a + + mock_fastapi_response = MagicMock() + + response = await invoke_agent_a2a( + agent_id="test-agent", + request=mock_request, + fastapi_response=mock_fastapi_response, + user_api_key_dict=mock_user_api_key_dict, + ) + + # JSONResponse exposes the body bytes; decode and verify it's a + # JSON-RPC error, not an "internal error" from a Python exception. 
+ body = json.loads(response.body.decode()) + assert body["jsonrpc"] == "2.0" + assert body["error"]["code"] == -32000 + assert "no URL configured" in body["error"]["message"] diff --git a/tests/test_litellm/proxy/agent_endpoints/test_endpoints.py b/tests/test_litellm/proxy/agent_endpoints/test_endpoints.py index 75928b55a9..68df365ff3 100644 --- a/tests/test_litellm/proxy/agent_endpoints/test_endpoints.py +++ b/tests/test_litellm/proxy/agent_endpoints/test_endpoints.py @@ -395,6 +395,36 @@ class TestAgentRBACProxyAdmin: ) assert resp.status_code == 200 + def test_create_agent_applies_litellm_merge_to_stored_card(self): + """The card stored in the DB must reflect the LiteLLM-fronting merge.""" + with patch("litellm.proxy.proxy_server.prisma_client"): + self.mock_registry.get_agent_by_name = MagicMock(return_value=None) + self.mock_registry.add_agent_to_db = AsyncMock( + return_value=_sample_agent_response() + ) + self.mock_registry.register_agent = MagicMock() + + self.admin_client.post( + "/v1/agents", + json=_sample_agent_config(), + headers={"Authorization": "Bearer k"}, + ) + + call_kwargs = self.mock_registry.add_agent_to_db.await_args.kwargs + stored_card = call_kwargs["agent"]["agent_card_params"] + new_agent_id = call_kwargs["agent_id"] + + # Top-level url is retained for runtime A2A invocation (the public + # well-known endpoint rewrites it before exposing to clients); + # supportedInterfaces points at the proxy. + assert stored_card["url"] == "http://localhost" + assert stored_card["supportedInterfaces"][0]["protocolBinding"] == "JSONRPC" + assert stored_card["supportedInterfaces"][0]["url"].endswith( + f"/a2a/{new_agent_id}" + ) + # Security scheme is the LiteLLM scheme. 
+ assert "LiteLLMKey" in stored_card["securitySchemes"] + def test_should_allow_admin_to_delete_agent(self): existing = { "agent_id": "agent-123", diff --git a/tests/test_litellm/proxy/public_endpoints/test_public_endpoints.py b/tests/test_litellm/proxy/public_endpoints/test_public_endpoints.py index f82da59899..6cff91d2c7 100644 --- a/tests/test_litellm/proxy/public_endpoints/test_public_endpoints.py +++ b/tests/test_litellm/proxy/public_endpoints/test_public_endpoints.py @@ -463,6 +463,70 @@ def test_public_model_hub_mixed_health_statuses(): app.dependency_overrides.clear() +# --------------------------------------------------------------------------- +# /public/agent_hub +# --------------------------------------------------------------------------- + + +def test_public_agent_hub_rewrites_upstream_url_to_proxy(): + """Public agent hub must not leak the upstream backend URL retained on the + stored card. The ``url`` field has to be overwritten with the proxy + ``/a2a/{agent_id}`` entrypoint, matching the well-known card endpoint, so + an unauthenticated client cannot call the backend directly.""" + from litellm.types.agents import AgentResponse + + upstream_url = "https://upstream.internal.example.com/a2a" + agent = AgentResponse( + agent_id="agent-123", + agent_name="public-agent", + agent_card_params={"name": "public-agent", "url": upstream_url}, + ) + + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + mock_registry = MagicMock() + mock_registry.get_public_agent_list.return_value = [agent] + + with ( + patch("litellm.public_agent_groups", ["agent-123"]), + patch( + "litellm.proxy.agent_endpoints.agent_registry.global_agent_registry", + mock_registry, + ), + ): + response = client.get("/public/agent_hub") + + assert response.status_code == 200, response.text + payload = response.json() + assert len(payload) == 1 + card = payload[0] + assert upstream_url not in card.get("url", "") + assert card["url"].endswith("/a2a/agent-123") + + +def 
test_public_agent_hub_returns_empty_when_no_public_groups(): + app = FastAPI() + app.include_router(router) + client = TestClient(app) + + mock_registry = MagicMock() + mock_registry.get_public_agent_list.return_value = [] + + with ( + patch("litellm.public_agent_groups", None), + patch( + "litellm.proxy.agent_endpoints.agent_registry.global_agent_registry", + mock_registry, + ), + ): + response = client.get("/public/agent_hub") + + assert response.status_code == 200 + assert response.json() == [] + + # --------------------------------------------------------------------------- # /public/endpoints # --------------------------------------------------------------------------- diff --git a/tests/test_litellm/proxy/test_proxy_server.py b/tests/test_litellm/proxy/test_proxy_server.py index ba9c3b75ba..b912fec247 100644 --- a/tests/test_litellm/proxy/test_proxy_server.py +++ b/tests/test_litellm/proxy/test_proxy_server.py @@ -605,9 +605,7 @@ def test_ui_extensionless_route_requires_restructure(tmp_path): def test_admin_ui_export_serves_nested_extensionless_routes(): - out_dir = ( - Path(litellm.__file__).parent / "proxy" / "_experimental" / "out" - ) + out_dir = Path(litellm.__file__).parent / "proxy" / "_experimental" / "out" assert out_dir.is_dir(), f"missing UI export at {out_dir}" nested_html_offenders = [ @@ -619,8 +617,7 @@ def test_admin_ui_export_serves_nested_extensionless_routes(): and "litellm-asset-prefix" not in path.parts ] assert not nested_html_offenders, ( - "Nested routes must be named index.html. Offenders: " - f"{nested_html_offenders}" + "Nested routes must be named index.html. 
Offenders: " f"{nested_html_offenders}" ) callback_index = out_dir / "mcp" / "oauth" / "callback" / "index.html" @@ -630,9 +627,7 @@ def test_admin_ui_export_serves_nested_extensionless_routes(): ) fastapi_app = FastAPI() - fastapi_app.mount( - "/ui", StaticFiles(directory=str(out_dir), html=True), name="ui" - ) + fastapi_app.mount("/ui", StaticFiles(directory=str(out_dir), html=True), name="ui") client = TestClient(fastapi_app) redirect = client.get( @@ -640,7 +635,9 @@ def test_admin_ui_export_serves_nested_extensionless_routes(): follow_redirects=False, ) assert redirect.status_code == 307 - assert redirect.headers["location"].endswith("/ui/mcp/oauth/callback/?code=abc&state=xyz") + assert redirect.headers["location"].endswith( + "/ui/mcp/oauth/callback/?code=abc&state=xyz" + ) landed = client.get("/ui/mcp/oauth/callback?code=abc&state=xyz") assert landed.status_code == 200 @@ -5902,15 +5899,16 @@ async def test_primary_spend_counter_redis_concurrent_seed_does_not_double_seed( if call.kwargs.get("nx") is True ] assert len(nx_writes) == 2 - assert sorted(set_results) == [False, True], ( - f"expected exactly one SET NX winner and one loser, got {set_results}" - ) + assert sorted(set_results) == [ + False, + True, + ], f"expected exactly one SET NX winner and one loser, got {set_results}" # Loser path executed: after the winner's SET NX returned True, the # losing coalesced() call falls back to async_get_cache to read the # winner's value rather than re-seeding. 
- assert get_after_set_count >= 1, ( - "loser branch (else: read back winner's value) was never exercised" - ) + assert ( + get_after_set_count >= 1 + ), "loser branch (else: read back winner's value) was never exercised" @pytest.mark.asyncio @@ -7137,6 +7135,25 @@ class TestLazyFeatureRegistry: names = [f.name for f in LAZY_FEATURES] assert len(names) == len(set(names)), "duplicate feature names" + def test_matches_covers_prefix_and_suffix(self): + """``matches`` is the single matcher shared by the middleware (request + paths) and the warm endpoint (registered route paths), so a route that + only matches via suffix — e.g. ``/v1/a2a/{id}/message/send`` against the + ``/a2a`` prefix — must still be claimed by the feature.""" + from litellm.proxy._lazy_features import LazyFeature + + feat = LazyFeature( + name="a2a", + module_path="json", + path_prefixes=("/a2a",), + path_suffixes=("/message/send",), + ) + assert feat.matches("/a2a/abc/message/send") + assert feat.matches("/v1/a2a/abc/message/send") + assert feat.matches("/a2a/abc/.well-known/agent-card.json") + assert not feat.matches("/v1/a2a/discover") + assert not feat.matches("/unrelated") + class TestLazyFeaturesNotImportedAtStartup: """ diff --git a/ui/litellm-dashboard/src/components/agents/add_agent_form.tsx b/ui/litellm-dashboard/src/components/agents/add_agent_form.tsx index 046b28640c..3929a9e183 100644 --- a/ui/litellm-dashboard/src/components/agents/add_agent_form.tsx +++ b/ui/litellm-dashboard/src/components/agents/add_agent_form.tsx @@ -19,6 +19,13 @@ import { getModelDisplayName } from "../key_team_helpers/fetch_available_models_ import { Team } from "../key_team_helpers/key_list"; import TeamDropdown from "../common_components/team_dropdown"; import AgentFormFields from "./agent_form_fields"; +import AgentCardDiscovery, { + DiscoveredAgentCardSelection, +} from "./agent_card_discovery"; +import { + buildDiscoveryRequest, + overlayDiscoveredCardParams, +} from "./agent_discovery_utils"; import 
DynamicAgentFormFields, { buildDynamicAgentData } from "./dynamic_agent_form_fields"; import { getDefaultFormValues, buildAgentDataFromForm } from "./agent_config"; import MCPServerSelector from "../mcp_server_management/MCPServerSelector"; @@ -75,6 +82,12 @@ const AddAgentForm: React.FC = ({ const [maxIterations, setMaxIterations] = useState(null); const [maxBudgetPerSession, setMaxBudgetPerSession] = useState(null); + // Latest upstream card selection from auto-discovery (skills, capabilities, + // name, description). Dynamic agent forms don't render Form.Items for those + // fields, so we overlay this onto agent_card_params at submit. + const [appliedDiscoveredSelection, setAppliedDiscoveredSelection] = + useState(null); + // Fetch agent type metadata on mount useEffect(() => { const fetchMetadata = async () => { @@ -157,6 +170,30 @@ const AddAgentForm: React.FC = ({ (info) => info.agent_type === agentType ); + // Watch every form field so we can recompute the discovery plan whenever + // the user types into a relevant credential field below. + const watchedFormValues = Form.useWatch([], form); + + // Build the discovery plan for the proxy. Different agent runtimes publish + // their cards at different URL shapes: + // + // - LangGraph Platform: one well-known endpoint on the base URL, + // ``?assistant_id=`` selects the assistant. + // - Pure A2A (the default): card lives at one of the well-known paths + // on the agent's own base URL. + // + // Returns undefined when nothing usable is filled in yet, which causes the + // component to fall back to a manual URL input. 
+ const discoveryRequest = React.useMemo( + () => + buildDiscoveryRequest( + agentType, + watchedFormValues || {}, + selectedAgentTypeInfo, + ), + [watchedFormValues, selectedAgentTypeInfo, agentType], + ); + const handleNext = async () => { try { if (currentStep === 0) { @@ -192,10 +229,13 @@ const AddAgentForm: React.FC = ({ skills: [], }, }; - } else if (agentType === "a2a") { - return buildAgentDataFromForm(values); + } + + let agentData: Record; + if (agentType === "a2a") { + agentData = buildAgentDataFromForm(values); } else if (selectedAgentTypeInfo?.use_a2a_form_fields) { - const agentData = buildAgentDataFromForm(values); + agentData = buildAgentDataFromForm(values); if (selectedAgentTypeInfo.litellm_params_template) { agentData.litellm_params = { ...agentData.litellm_params, @@ -208,11 +248,16 @@ const AddAgentForm: React.FC = ({ agentData.litellm_params[field.key] = value; } } - return agentData; } else if (selectedAgentTypeInfo) { - return buildDynamicAgentData(values, selectedAgentTypeInfo); + agentData = buildDynamicAgentData(values, selectedAgentTypeInfo); + } else { + return null; } - return null; + + return overlayDiscoveredCardParams( + agentData, + appliedDiscoveredSelection?.selected_card, + ); }; const handleCreateAgent = async () => { @@ -340,6 +385,7 @@ const AddAgentForm: React.FC = ({ setRequireTraceIdOutbound(false); setMaxIterations(null); setMaxBudgetPerSession(null); + setAppliedDiscoveredSelection(null); onClose(); }; @@ -568,6 +614,66 @@ const AddAgentForm: React.FC = ({ const handleAgentTypeChange = (value: string) => { setAgentType(value); form.resetFields(); + // Discovery selections are tied to a specific agent type's URL shape; + // switching types invalidates them. + setAppliedDiscoveredSelection(null); + }; + + // Apply a discovered agent card to the form so the rest of Step 1 (skills, + // capabilities, name, description, URL) reflects what the user picked. 
The + // proxy re-applies its own merge at registration; we only seed defaults here. + // + // AntD's `setFieldsValue` silently ignores keys whose Form.Item isn't + // registered, so this is safe across all agent types — A2A forms pick up + // every field below; LangGraph and other dynamic forms only pick up the + // shared ones (`agent_name`, `description`, plus any credential field whose + // key looks URL-ish). + const handleApplyDiscoveredCard = ( + selection: DiscoveredAgentCardSelection | null, + ) => { + setAppliedDiscoveredSelection(selection); + if (!selection) return; + const { selected_card, upstream_url } = selection; + const skills = (selected_card.skills ?? []).map((s) => ({ + id: s.id ?? "", + name: s.name ?? "", + description: s.description ?? "", + tags: s.tags ?? [], + examples: s.examples ?? [], + })); + + const currentAgentName = form.getFieldValue("agent_name"); + const seededAgentName = + currentAgentName || selected_card.name || selected_card.provider?.organization || ""; + + const fieldsToSet: Record = { + agent_name: seededAgentName, + name: selected_card.name, + description: selected_card.description, + url: upstream_url, + version: selected_card.version, + protocolVersion: selected_card.protocolVersion ?? "1.0", + streaming: Boolean(selected_card.capabilities?.streaming), + skills, + iconUrl: selected_card.iconUrl, + documentationUrl: selected_card.documentationUrl, + }; + + // For dynamic agent types (e.g. LangGraph), the URL lives in a + // type-specific credential field. Match on common naming variants so the + // user doesn't have to re-paste the URL they already typed above. + const urlCredentialKeys = (selectedAgentTypeInfo?.credential_fields ?? 
[]) + .map((f) => f.key) + .filter((key) => /(^|_)(url|api_base|endpoint)$/i.test(key)); + for (const key of urlCredentialKeys) { + fieldsToSet[key] = upstream_url; + } + + form.setFieldsValue(fieldsToSet); + + if (!newKeyName && seededAgentName) { + setNewKeyName(`${seededAgentName}-key`); + } }; const isCustomAgent = agentType === CUSTOM_AGENT_TYPE; @@ -702,6 +808,21 @@ const AddAgentForm: React.FC = ({ ) : selectedAgentTypeInfo ? ( ) : null} + + {/* Discovery sits at the bottom so its URL can be derived from the + credential fields the user typed above. The plan (URL + mode + + params) is computed from the agent type — LangGraph hits a + different shape than pure A2A. Custom agents have no upstream to + discover, so we skip them. */} + {agentType !== CUSTOM_AGENT_TYPE && ( +
+ +
+ )} diff --git a/ui/litellm-dashboard/src/components/agents/agent_card_discovery.test.tsx b/ui/litellm-dashboard/src/components/agents/agent_card_discovery.test.tsx new file mode 100644 index 0000000000..18b874c95e --- /dev/null +++ b/ui/litellm-dashboard/src/components/agents/agent_card_discovery.test.tsx @@ -0,0 +1,299 @@ +import React from "react"; +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; +import { screen, waitFor } from "@testing-library/react"; +import userEvent from "@testing-library/user-event"; +import { renderWithProviders } from "../../../tests/test-utils"; +import AgentCardDiscovery from "./agent_card_discovery"; + +vi.mock("../networking", async () => { + const actual = await vi.importActual("../networking"); + return { + ...actual, + discoverAgentCardCall: vi.fn(), + }; +}); + +import { discoverAgentCardCall } from "../networking"; + +const mockDiscover = discoverAgentCardCall as unknown as ReturnType; + +const sampleCard = { + protocolVersion: "1.0", + name: "Upstream Agent", + description: "An upstream agent", + version: "1.2.3", + url: "http://internal:9000", + capabilities: { streaming: true, pushNotifications: true }, + skills: [ + { + id: "search", + name: "Search", + description: "Search the web", + tags: ["search"], + }, + { + id: "summarize", + name: "Summarize", + description: "Summarize a document", + tags: ["llm"], + }, + ], + provider: { organization: "UpstreamCo", url: "https://upstream.example" }, +}; + +describe("AgentCardDiscovery", () => { + beforeEach(() => { + vi.useFakeTimers({ shouldAdvanceTime: true }); + mockDiscover.mockReset(); + }); + + afterEach(() => { + vi.useRealTimers(); + }); + + it("renders the URL input and a Re-discover button after manual entry", async () => { + mockDiscover.mockResolvedValue({ + url: "https://upstream.example.com", + agent_card: sampleCard, + }); + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }); + renderWithProviders( + , + ); + + expect( 
+ screen.getByPlaceholderText("https://upstream-agent.example.com"), + ).toBeInTheDocument(); + + await user.type( + screen.getByPlaceholderText("https://upstream-agent.example.com"), + "https://upstream.example.com", + ); + await vi.advanceTimersByTimeAsync(500); + + await waitFor(() => expect(mockDiscover).toHaveBeenCalled()); + expect( + await screen.findByRole("button", { name: /re-discover/i }), + ).toBeInTheDocument(); + }); + + it("shows an error when re-discover is clicked without a URL", async () => { + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }); + renderWithProviders( + , + ); + + await user.click(screen.getByRole("button", { name: /discover/i })); + expect( + await screen.findByText(/Enter the agent's base URL first/i), + ).toBeInTheDocument(); + expect(mockDiscover).not.toHaveBeenCalled(); + }); + + it("auto-discovers and renders upstream skills on success", async () => { + mockDiscover.mockResolvedValueOnce({ + url: "https://upstream.example.com", + agent_card: sampleCard, + }); + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }); + renderWithProviders( + , + ); + + await user.type( + screen.getByPlaceholderText("https://upstream-agent.example.com"), + "https://upstream.example.com", + ); + await vi.advanceTimersByTimeAsync(500); + + expect(await screen.findByText("Upstream card loaded")).toBeInTheDocument(); + expect(screen.getByText("Search")).toBeInTheDocument(); + expect(screen.getByText("Summarize")).toBeInTheDocument(); + expect(screen.getByText(/^streaming$/i)).toBeInTheDocument(); + expect(screen.queryByText(/pushNotifications/i)).not.toBeInTheDocument(); + expect( + screen.queryByRole("button", { name: /use these selections/i }), + ).not.toBeInTheDocument(); + }); + + it("shows an inline error when discovery fails", async () => { + mockDiscover.mockRejectedValueOnce(new Error("upstream unreachable")); + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }); + 
renderWithProviders( + , + ); + + await user.type( + screen.getByPlaceholderText("https://upstream-agent.example.com"), + "https://nope.example", + ); + await vi.advanceTimersByTimeAsync(500); + + expect(await screen.findByText("Discovery failed")).toBeInTheDocument(); + expect(screen.getByText(/upstream unreachable/)).toBeInTheDocument(); + }); + + it("syncs the selected subset to the parent as the user edits", async () => { + mockDiscover.mockResolvedValueOnce({ + url: "https://upstream.example.com", + agent_card: sampleCard, + }); + const onApply = vi.fn(); + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }); + renderWithProviders( + , + ); + + await user.type( + screen.getByPlaceholderText("https://upstream-agent.example.com"), + "https://upstream.example.com", + ); + await vi.advanceTimersByTimeAsync(500); + await screen.findByText("Upstream card loaded"); + + await waitFor(() => expect(onApply).toHaveBeenCalled()); + const initialSelection = onApply.mock.calls.at(-1)?.[0]; + expect(initialSelection.upstream_url).toBe("https://upstream.example.com"); + expect(initialSelection.selected_card.skills).toHaveLength(2); + + const summarizeLabel = screen.getByText("Summarize").closest("label"); + expect(summarizeLabel).toBeTruthy(); + const summarizeCheckbox = summarizeLabel!.querySelector( + "input[type='checkbox']", + ) as HTMLInputElement; + await user.click(summarizeCheckbox); + + await waitFor(() => { + const latest = onApply.mock.calls.at(-1)?.[0]; + expect(latest.selected_card.skills).toHaveLength(1); + expect(latest.selected_card.skills[0].id).toBe("search"); + }); + }); + + it("hides the URL input and shows the display URL when parent-driven", () => { + renderWithProviders( + , + ); + + expect( + screen.queryByPlaceholderText("https://upstream-agent.example.com"), + ).not.toBeInTheDocument(); + expect( + screen.getByText( + "http://localhost:2024/.well-known/agent-card.json?assistant_id=agent", + ), + ).toBeInTheDocument(); + }); + + 
it("auto-discovers with discovery_mode and params from the parent plan", async () => { + mockDiscover.mockResolvedValueOnce({ + url: "http://localhost:2024", + agent_card: sampleCard, + }); + renderWithProviders( + , + ); + + await vi.advanceTimersByTimeAsync(0); + await waitFor(() => expect(mockDiscover).toHaveBeenCalledTimes(1)); + expect(mockDiscover).toHaveBeenCalledWith("tok", "http://localhost:2024", { + discovery_mode: "langgraph_platform", + params: { assistant_id: "agent" }, + }); + }); + + it("disables Re-discover until the parent provides a usable URL", async () => { + renderWithProviders( + , + ); + + expect( + (screen.getByRole("button", { + name: /discover/i, + }) as HTMLButtonElement).disabled, + ).toBe(true); + expect(mockDiscover).not.toHaveBeenCalled(); + }); + + it("pre-selects only skills present in savedAgentCard when editing", async () => { + mockDiscover.mockResolvedValueOnce({ + url: "http://localhost:2024", + agent_card: sampleCard, + }); + const onApply = vi.fn(); + renderWithProviders( + , + ); + + await vi.advanceTimersByTimeAsync(0); + await screen.findByText("Upstream card loaded"); + + await waitFor(() => expect(onApply).toHaveBeenCalled()); + const selection = onApply.mock.calls.at(-1)?.[0]; + expect(selection.selected_card.skills).toHaveLength(1); + expect(selection.selected_card.skills[0].id).toBe("search"); + expect(selection.selected_card.name).toBe("DB Agent"); + expect(selection.selected_card.capabilities.streaming).toBe(false); + }); + + it("blocks discover when no access token is provided", async () => { + const user = userEvent.setup({ advanceTimers: vi.advanceTimersByTime }); + renderWithProviders( + , + ); + + await user.type( + screen.getByPlaceholderText("https://upstream-agent.example.com"), + "https://upstream.example.com", + ); + await user.click(screen.getByRole("button", { name: /discover/i })); + + expect( + await screen.findByText(/No access token available/i), + ).toBeInTheDocument(); + 
expect(mockDiscover).not.toHaveBeenCalled(); + }); +}); diff --git a/ui/litellm-dashboard/src/components/agents/agent_card_discovery.tsx b/ui/litellm-dashboard/src/components/agents/agent_card_discovery.tsx new file mode 100644 index 0000000000..ee34450d09 --- /dev/null +++ b/ui/litellm-dashboard/src/components/agents/agent_card_discovery.tsx @@ -0,0 +1,511 @@ +"use client"; + +import React, { useCallback, useEffect, useMemo, useRef, useState } from "react"; +import { + Alert, + Button, + Checkbox, + Collapse, + Empty, + Input, + Space, + Spin, + Switch, + Tag, + Tooltip, + Typography, +} from "antd"; +// Empty is used in the skills panel below. +import { + CheckCircleTwoTone, + InfoCircleOutlined, + LinkOutlined, + ReloadOutlined, + SearchOutlined, +} from "@ant-design/icons"; + +import { + DiscoveredAgentCard, + discoverAgentCardCall, +} from "../networking"; +import { + ALLOWED_CAPABILITY_KEYS, + selectionsFromSavedAgentCard, + selectionsFromUpstreamCard, + skillId, +} from "./agent_discovery_utils"; + +const { Text, Paragraph } = Typography; +const { Panel } = Collapse; + +export interface DiscoveredAgentCardSelection { + /** Full upstream card the proxy fetched, unmodified. */ + raw_card: DiscoveredAgentCard; + /** Subset of the upstream card with only the user-selected skills and + * capabilities, plus the user-edited name/description. Suitable to send as + * ``agent_card_params`` on ``POST /v1/agents``. */ + selected_card: DiscoveredAgentCard; + /** The base URL the user pasted in. */ + upstream_url: string; +} + +export type { DiscoveryRequestPlan } from "./agent_discovery_utils"; +import type { DiscoveryRequestPlan } from "./agent_discovery_utils"; + +interface AgentCardDiscoveryProps { + accessToken: string | null; + /** Called whenever the upstream card or the user's selections change. Pass + * ``null`` when discovery is cleared or fails so the parent can reset. 
*/ + onApply: (selection: DiscoveredAgentCardSelection | null) => void; + /** + * Parent-supplied discovery plan. When provided the component uses these + * values verbatim and hides its free-form URL input — the parent is the + * source of truth (e.g. for LangGraph it's derived from api_base + + * assistant_id). When omitted the component falls back to a manual URL + * input that defaults to ``well_known_fallback`` mode. + */ + discoveryRequest?: DiscoveryRequestPlan; + /** When editing an existing agent, the card stored in the DB. Upstream + * discovery lists everything available; only skills/capabilities present + * here are pre-selected. */ + savedAgentCard?: DiscoveredAgentCard | null; +} + +const AgentCardDiscovery: React.FC = ({ + accessToken, + onApply, + discoveryRequest, + savedAgentCard, +}) => { + // When the parent drives discovery, ``manualUrl`` is unused — the URL + // comes from ``discoveryRequest.url`` directly. When the parent hasn't + // supplied a plan, the admin types into this field manually. + const [manualUrl, setManualUrl] = useState(""); + const [loading, setLoading] = useState(false); + const [error, setError] = useState(null); + const [card, setCard] = useState(null); + + const isParentDriven = discoveryRequest !== undefined; + const effectiveUrl = isParentDriven ? 
discoveryRequest!.url : manualUrl; + + const [editedName, setEditedName] = useState(""); + const [editedDescription, setEditedDescription] = useState(""); + const [selectedSkillIds, setSelectedSkillIds] = useState>(new Set()); + const [selectedCapabilities, setSelectedCapabilities] = useState< + Record + >({}); + + const onApplyRef = useRef(onApply); + onApplyRef.current = onApply; + const discoverRequestIdRef = useRef(0); + const lastSyncedSelectionRef = useRef(null); + // Hold the latest ``discoveryRequest`` in a ref so ``handleDiscover`` can + // read its ``discovery_mode``/``params`` without depending on the object + // identity itself — the parent recreates the object on every form keystroke + // even when the underlying values are unchanged. We use stable primitive + // keys (``discoveryMode`` + ``discoveryParamsKey``) as the actual deps so + // the callback / effect only re-run when content actually changes. + const discoveryRequestRef = useRef(discoveryRequest); + discoveryRequestRef.current = discoveryRequest; + // Hold ``savedAgentCard`` in a ref so ``resetSelections`` always sees the + // latest value without making it a dependency of ``handleDiscover``. + // Putting ``savedAgentCard`` directly in the callback deps means any parent + // re-render that hands us a new object reference (e.g. a background + // agent-data refresh during editing) recreates ``handleDiscover``, which + // re-fires the auto-discover effect and overwrites in-progress user edits. + const savedAgentCardRef = useRef(savedAgentCard); + savedAgentCardRef.current = savedAgentCard; + + const resetSelections = (fresh: DiscoveredAgentCard) => { + const saved = savedAgentCardRef.current; + const initial = saved + ? 
selectionsFromSavedAgentCard(fresh, saved) + : selectionsFromUpstreamCard(fresh); + setEditedName(initial.editedName); + setEditedDescription(initial.editedDescription); + setSelectedSkillIds(initial.selectedSkillIds); + setSelectedCapabilities(initial.selectedCapabilities); + }; + + const discoveryMode = discoveryRequest?.discovery_mode; + const discoveryParamsKey = useMemo( + () => JSON.stringify(discoveryRequest?.params ?? null), + [discoveryRequest?.params], + ); + + const handleDiscover = useCallback(async () => { + if (!accessToken) { + setError("No access token available"); + onApplyRef.current(null); + return; + } + const trimmed = effectiveUrl.trim(); + if (!trimmed) { + setError( + isParentDriven + ? "Fill in the agent's connection details above first" + : "Enter the agent's base URL first", + ); + setCard(null); + onApplyRef.current(null); + return; + } + + const currentDiscoveryRequest = discoveryRequestRef.current; + const requestId = ++discoverRequestIdRef.current; + setLoading(true); + setError(null); + try { + const response = await discoverAgentCardCall( + accessToken, + trimmed, + isParentDriven && currentDiscoveryRequest + ? { + discovery_mode: currentDiscoveryRequest.discovery_mode, + params: currentDiscoveryRequest.params, + } + : undefined, + ); + if (requestId !== discoverRequestIdRef.current) return; + lastSyncedSelectionRef.current = null; + setCard(response.agent_card); + resetSelections(response.agent_card); + } catch (e: any) { + if (requestId !== discoverRequestIdRef.current) return; + setError(e?.message ? 
String(e.message) : "Failed to discover agent card"); + setCard(null); + lastSyncedSelectionRef.current = null; + onApplyRef.current(null); + } finally { + if (requestId === discoverRequestIdRef.current) { + setLoading(false); + } + } + // ``discoveryMode`` / ``discoveryParamsKey`` are primitive proxies for + // ``discoveryRequest`` content; the actual object is read via the ref + // above so identity churn from the parent doesn't recreate this callback. + // ``savedAgentCard`` is intentionally NOT a dep — it's read via + // ``savedAgentCardRef`` inside ``resetSelections``. Including it here + // would recreate this callback whenever the parent hands us a new + // ``savedAgentCard`` object (e.g. a background refresh of agent data + // during editing), which would re-fire the auto-discover effect and + // wipe in-progress user selections. + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [ + accessToken, + effectiveUrl, + isParentDriven, + discoveryMode, + discoveryParamsKey, + ]); + + // Auto-discover when the URL (or parent plan) becomes available. Debounce + // is applied uniformly so rapid changes from a watched parent form (e.g. + // typing into a LangGraph api_base / assistant_id field) don't fire one + // HTTP request per keystroke. 
+  useEffect(() => {
+    // Without a token nothing can be fetched; leave the current state alone.
+    if (!accessToken) return;
+    const trimmed = effectiveUrl.trim();
+    if (!trimmed) {
+      // The URL was cleared. Invalidate any request that is still in flight:
+      // ``handleDiscover`` only applies a response whose id still equals
+      // ``discoverRequestIdRef.current``, so bumping the counter makes a late
+      // response a no-op instead of repopulating the card we are clearing.
+      // That stale request's ``finally`` will no longer reset ``loading``
+      // (its id no longer matches), so reset it here.
+      discoverRequestIdRef.current += 1;
+      setLoading(false);
+      setCard(null);
+      setError(null);
+      lastSyncedSelectionRef.current = null;
+      onApplyRef.current(null);
+      return;
+    }
+
+    // Debounce: wait 400ms after the last change before discovering; the
+    // cleanup cancels a pending timer when the deps change again.
+    const timer = window.setTimeout(() => {
+      void handleDiscover();
+    }, 400);
+    return () => window.clearTimeout(timer);
+  }, [accessToken, effectiveUrl, handleDiscover]);
+
+  // Add or remove one skill id from the selected set. The previous Set is
+  // copied rather than mutated in place.
+  const toggleSkill = (id: string, checked: boolean) => {
+    setSelectedSkillIds((prev) => {
+      const next = new Set(prev);
+      if (checked) next.add(id);
+      else next.delete(id);
+      return next;
+    });
+  };
+
+  // Build the payload handed to the parent: the untouched upstream card plus
+  // a copy narrowed to the selected skills / capabilities and carrying the
+  // edited name and description. Returns null while no card is loaded.
+  const buildSelection = useCallback((): DiscoveredAgentCardSelection | null => {
+    if (!card) return null;
+    const skills = card.skills ?? [];
+    // ``skillId`` receives the index within the full upstream list, matching
+    // how ids are derived where the skills are rendered and pre-selected.
+    const filteredSkills = skills.filter((s, i) =>
+      selectedSkillIds.has(skillId(s, i)),
+    );
+
+    const selected_card: DiscoveredAgentCard = {
+      ...card,
+      name: editedName,
+      description: editedDescription,
+      skills: filteredSkills,
+      capabilities: { ...selectedCapabilities },
+    };
+
+    return {
+      raw_card: card,
+      selected_card,
+      upstream_url: effectiveUrl.trim(),
+    };
+  }, [
+    card,
+    editedDescription,
+    editedName,
+    effectiveUrl,
+    selectedCapabilities,
+    selectedSkillIds,
+  ]);
+
+  // Keep the parent form in sync as the user edits selections — no extra
+  // "apply" click needed before hitting Next. The serialized selection is
+  // remembered so the parent is only notified when the content changes.
+  useEffect(() => {
+    if (!card) return;
+    const selection = buildSelection();
+    const serialized = JSON.stringify(selection);
+    if (lastSyncedSelectionRef.current === serialized) return;
+    lastSyncedSelectionRef.current = serialized;
+    onApplyRef.current(selection);
+  }, [buildSelection, card]);
+
+  const skillCount = card?.skills?.length ?? 0;
+  const selectedSkillCount = selectedSkillIds.size;
+
+  return (
+
+
+ + Discover from agent URL + + + +
+ {isParentDriven ? ( + <> + + Using the connection details you entered above. We'll fetch: + +
+ {discoveryRequest!.display_url || effectiveUrl || ( + + Fill in the fields above first + + )} +
+
+ +
+ + ) : ( + <> + + Paste the upstream agent's base URL. We'll try{" "} + /.well-known/agent-card.json,{" "} + /.well-known/agent.json, and /agent.json{" "} + in order. + + + + setManualUrl(e.target.value)} + onPressEnter={handleDiscover} + allowClear + disabled={loading} + /> + + + + )} + + {error && ( + setError(null)} + /> + )} + + {loading && !card && ( +
+ +
+ )} + + {card && ( +
+
+ + + Upstream card loaded + {card.version && v{card.version}} + {card.provider?.organization && ( + {card.provider.organization} + )} + +
+ +
+
+ + setEditedName(e.target.value)} + placeholder="Agent name" + /> +
+
+ + setEditedDescription(e.target.value)} + rows={2} + placeholder="What this agent does" + /> +
+
+ + + + Skills + + {selectedSkillCount} / {skillCount} selected + + + } + > + {skillCount === 0 ? ( + + ) : ( +
+ {(card.skills ?? []).map((skill, idx) => { + const id = skillId(skill, idx); + const checked = selectedSkillIds.has(id); + return ( + + ); + })} +
+ )} +
+ + + Capabilities + + + + + } + > +
+ {ALLOWED_CAPABILITY_KEYS.map((key) => { + const upstreamHas = Boolean(card.capabilities?.[key]); + return ( +
+
+ + {key} + + {!upstreamHas && ( + + not advertised upstream + + )} +
+ + setSelectedCapabilities((prev) => ({ + ...prev, + [key]: checked, + })) + } + /> +
+ ); + })} +
+
+
+ +
+ )} +
+ ); +}; + +export default AgentCardDiscovery; diff --git a/ui/litellm-dashboard/src/components/agents/agent_discovery_utils.test.ts b/ui/litellm-dashboard/src/components/agents/agent_discovery_utils.test.ts new file mode 100644 index 0000000000..632fba1532 --- /dev/null +++ b/ui/litellm-dashboard/src/components/agents/agent_discovery_utils.test.ts @@ -0,0 +1,64 @@ +import { describe, it, expect } from "vitest"; +import { + selectionsFromSavedAgentCard, + selectionsFromUpstreamCard, + skillId, +} from "./agent_discovery_utils"; + +const upstreamCard = { + name: "Upstream Agent", + description: "Upstream description", + capabilities: { streaming: true }, + skills: [ + { id: "search", name: "Search", description: "Search the web" }, + { id: "summarize", name: "Summarize", description: "Summarize docs" }, + { id: "chat", name: "Chat", description: "General chat" }, + ], +}; + +describe("selectionsFromSavedAgentCard", () => { + it("pre-selects only skills that exist in the saved DB card", () => { + const savedCard = { + name: "My Agent", + description: "Saved description", + capabilities: { streaming: false }, + skills: [{ id: "search", name: "Search" }], + }; + + const result = selectionsFromSavedAgentCard(upstreamCard, savedCard); + + expect(result.editedName).toBe("My Agent"); + expect(result.editedDescription).toBe("Saved description"); + expect(result.selectedCapabilities.streaming).toBe(false); + expect(result.selectedSkillIds.has(skillId(upstreamCard.skills![0], 0))).toBe( + true, + ); + expect( + result.selectedSkillIds.has(skillId(upstreamCard.skills![1], 1)), + ).toBe(false); + expect( + result.selectedSkillIds.has(skillId(upstreamCard.skills![2], 2)), + ).toBe(false); + }); + + it("matches saved skills by name when id is missing", () => { + const savedCard = { + skills: [{ name: "Summarize" }], + }; + + const result = selectionsFromSavedAgentCard(upstreamCard, savedCard); + + expect( + result.selectedSkillIds.has(skillId(upstreamCard.skills![1], 1)), + 
).toBe(true); + expect(result.selectedSkillIds.size).toBe(1); + }); +}); + +describe("selectionsFromUpstreamCard", () => { + it("selects all upstream skills for create flow", () => { + const result = selectionsFromUpstreamCard(upstreamCard); + expect(result.selectedSkillIds.size).toBe(3); + expect(result.editedName).toBe("Upstream Agent"); + }); +}); diff --git a/ui/litellm-dashboard/src/components/agents/agent_discovery_utils.ts b/ui/litellm-dashboard/src/components/agents/agent_discovery_utils.ts new file mode 100644 index 0000000000..bb3d79e259 --- /dev/null +++ b/ui/litellm-dashboard/src/components/agents/agent_discovery_utils.ts @@ -0,0 +1,176 @@ +import { + AgentCreateInfo, + DiscoveredAgentCard, + DiscoveryMode, +} from "../networking"; + +export interface DiscoveryRequestPlan { + url: string; + discovery_mode: DiscoveryMode; + params?: Record; + display_url?: string; +} + +export const skillId = (skill: any, idx: number): string => + skill?.id ?? skill?.name ?? `skill-${idx}`; + +export const ALLOWED_CAPABILITY_KEYS = ["streaming"] as const; + +export const filterCapabilitiesForUI = ( + capabilities: Record | undefined, +): Record => { + if (!capabilities) return {}; + return ALLOWED_CAPABILITY_KEYS.reduce>((acc, key) => { + if (key in capabilities) acc[key] = Boolean(capabilities[key]); + return acc; + }, {}); +}; + +/** + * After fetching the full upstream card, pre-select only skills and + * capabilities that already exist on the agent record in the DB. + */ +export const selectionsFromSavedAgentCard = ( + upstreamCard: DiscoveredAgentCard, + savedCard: DiscoveredAgentCard | undefined | null, +): { + editedName: string; + editedDescription: string; + selectedSkillIds: Set; + selectedCapabilities: Record; +} => { + const upstreamSkills = upstreamCard.skills ?? []; + const savedSkills = savedCard?.skills ?? 
[]; + + const savedSkillIds = new Set( + savedSkills.map((s) => s?.id).filter(Boolean) as string[], + ); + const savedSkillNames = new Set( + savedSkills.map((s) => s?.name).filter(Boolean) as string[], + ); + + const selectedSkillIds = new Set(); + upstreamSkills.forEach((skill, idx) => { + const id = skillId(skill, idx); + const matchesById = skill.id && savedSkillIds.has(skill.id); + const matchesByName = skill.name && savedSkillNames.has(skill.name); + if (matchesById || matchesByName) { + selectedSkillIds.add(id); + } + }); + + const selectedCapabilities = filterCapabilitiesForUI(upstreamCard.capabilities); + if (savedCard?.capabilities) { + for (const key of ALLOWED_CAPABILITY_KEYS) { + if (key in savedCard.capabilities) { + selectedCapabilities[key] = Boolean(savedCard.capabilities[key]); + } + } + } + + return { + editedName: savedCard?.name ?? upstreamCard.name ?? "", + editedDescription: savedCard?.description ?? upstreamCard.description ?? "", + selectedSkillIds, + selectedCapabilities, + }; +}; + +/** Default for create flow: select everything the upstream advertises. */ +export const selectionsFromUpstreamCard = ( + upstreamCard: DiscoveredAgentCard, +): { + editedName: string; + editedDescription: string; + selectedSkillIds: Set; + selectedCapabilities: Record; +} => { + const upstreamSkills = upstreamCard.skills ?? []; + return { + editedName: upstreamCard.name ?? "", + editedDescription: upstreamCard.description ?? "", + selectedSkillIds: new Set(upstreamSkills.map((s, i) => skillId(s, i))), + selectedCapabilities: filterCapabilitiesForUI(upstreamCard.capabilities), + }; +}; + +/** + * Overlay the admin's discovery selections onto the ``agent_card_params`` + * built from the form. Dynamic agent forms (e.g. 
LangGraph) don't register + * Form.Items for name / description / skills / capabilities, so AntD's + * setFieldsValue silently drops those keys and the values never make it back + * through buildAgentData — we re-apply them here from the selection. + */ +export const overlayDiscoveredCardParams = ( + agentData: Record, + discovered: DiscoveredAgentCard | null | undefined, +): Record => { + if (!discovered) return agentData; + return { + ...agentData, + agent_card_params: { + ...agentData.agent_card_params, + name: discovered.name ?? agentData.agent_card_params?.name, + description: + discovered.description ?? agentData.agent_card_params?.description, + ...(Array.isArray(discovered.skills) && { + skills: discovered.skills, + }), + ...(discovered.capabilities && { + capabilities: discovered.capabilities, + }), + ...(Array.isArray(discovered.defaultInputModes) && + discovered.defaultInputModes.length > 0 && { + defaultInputModes: discovered.defaultInputModes, + }), + ...(Array.isArray(discovered.defaultOutputModes) && + discovered.defaultOutputModes.length > 0 && { + defaultOutputModes: discovered.defaultOutputModes, + }), + ...(discovered.provider && { provider: discovered.provider }), + ...(discovered.iconUrl && { iconUrl: discovered.iconUrl }), + ...(discovered.documentationUrl && { + documentationUrl: discovered.documentationUrl, + }), + }, + }; +}; + +export const buildDiscoveryRequest = ( + agentType: string, + values: Record, + selectedAgentTypeInfo?: AgentCreateInfo, +): DiscoveryRequestPlan | undefined => { + const trim = (v: unknown) => (v ?? 
"").toString().trim(); + const stripTrailingSlash = (s: string) => s.replace(/\/+$/, ""); + + if (agentType === "langgraph") { + const base = stripTrailingSlash(trim(values.api_base)); + const assistantId = trim(values.assistant_id); + if (!base || !assistantId) return undefined; + const query = `?assistant_id=${encodeURIComponent(assistantId)}`; + return { + url: base, + discovery_mode: "langgraph_platform", + params: { assistant_id: assistantId }, + display_url: `${base}/.well-known/agent-card.json${query}`, + }; + } + + if (agentType === "a2a" || selectedAgentTypeInfo?.use_a2a_form_fields) { + const base = stripTrailingSlash(trim(values.url)); + if (!base) return undefined; + return { + url: base, + discovery_mode: "well_known_fallback", + display_url: `${base}/.well-known/agent-card.json`, + }; + } + + // Non-A2A agent runtimes (Azure AI Foundry, Bedrock AgentCore, Vertex, + // etc.) don't expose well-known agent cards on their credential URLs, so + // we deliberately don't auto-fire discovery for them. The + // ``AgentCardDiscovery`` widget falls back to a manual URL input the admin + // can use as an escape hatch. 
+ return undefined; +}; diff --git a/ui/litellm-dashboard/src/components/agents/agent_info.tsx b/ui/litellm-dashboard/src/components/agents/agent_info.tsx index d543be8356..1e4280d361 100644 --- a/ui/litellm-dashboard/src/components/agents/agent_info.tsx +++ b/ui/litellm-dashboard/src/components/agents/agent_info.tsx @@ -1,4 +1,4 @@ -import React, { useState, useEffect } from "react"; +import React, { useState, useEffect, useMemo } from "react"; import { Card, Title, Text, Button as TremorButton, Tab, TabGroup, TabList, TabPanel, TabPanels} from "@tremor/react"; import { Form, Input, InputNumber, Button as AntButton, Spin, Descriptions, Divider } from "antd"; import MessageManager from "@/components/molecules/message_manager"; @@ -10,6 +10,13 @@ import DynamicAgentFormFields, { buildDynamicAgentData } from "./dynamic_agent_f import { buildAgentDataFromForm, parseAgentForForm } from "./agent_config"; import AgentCostView from "./agent_cost_view"; import { detectAgentType, parseDynamicAgentForForm } from "./agent_type_utils"; +import AgentCardDiscovery, { + DiscoveredAgentCardSelection, +} from "./agent_card_discovery"; +import { + buildDiscoveryRequest, + overlayDiscoveredCardParams, +} from "./agent_discovery_utils"; interface AgentInfoViewProps { agentId: string; @@ -31,6 +38,8 @@ const AgentInfoView: React.FC = ({ const [form] = Form.useForm(); const [agentTypeMetadata, setAgentTypeMetadata] = useState([]); const [detectedAgentType, setDetectedAgentType] = useState("a2a"); + const [appliedDiscoveredSelection, setAppliedDiscoveredSelection] = + useState(null); useEffect(() => { const fetchMetadata = async () => { @@ -93,6 +102,51 @@ const AgentInfoView: React.FC = ({ }, [agentTypeMetadata, agent]); const selectedAgentTypeInfo = agentTypeMetadata.find(t => t.agent_type === detectedAgentType); + const watchedFormValues = Form.useWatch([], form); + + const discoveryRequest = useMemo( + () => + buildDiscoveryRequest( + detectedAgentType, + watchedFormValues || {}, + 
selectedAgentTypeInfo, + ), + [watchedFormValues, selectedAgentTypeInfo, detectedAgentType], + ); + + const handleApplyDiscoveredCard = ( + selection: DiscoveredAgentCardSelection | null, + ) => { + setAppliedDiscoveredSelection(selection); + if (!selection) return; + const { selected_card } = selection; + const skills = (selected_card.skills ?? []).map((s) => ({ + id: s.id ?? "", + name: s.name ?? "", + description: s.description ?? "", + tags: s.tags ?? [], + examples: s.examples ?? [], + })); + + const fieldsToSet: Record = { + name: selected_card.name, + description: selected_card.description, + url: selection.upstream_url, + streaming: Boolean(selected_card.capabilities?.streaming), + skills, + iconUrl: selected_card.iconUrl, + documentationUrl: selected_card.documentationUrl, + }; + + const urlCredentialKeys = (selectedAgentTypeInfo?.credential_fields ?? []) + .map((f) => f.key) + .filter((key) => /(^|_)(url|api_base|endpoint)$/i.test(key)); + for (const key of urlCredentialKeys) { + fieldsToSet[key] = selection.upstream_url; + } + + form.setFieldsValue(fieldsToSet); + }; const handleUpdate = async (values: any) => { if (!accessToken || !agent) return; @@ -105,12 +159,18 @@ const AgentInfoView: React.FC = ({ updateData = buildAgentDataFromForm(values, agent); } else if (selectedAgentTypeInfo) { updateData = buildDynamicAgentData(values, selectedAgentTypeInfo); - // Preserve the agent_name from form updateData.agent_name = values.agent_name; } else { updateData = buildAgentDataFromForm(values, agent); } - + + if (appliedDiscoveredSelection) { + updateData = overlayDiscoveredCardParams( + updateData, + appliedDiscoveredSelection.selected_card, + ); + } + await patchAgentCall(accessToken, agentId, updateData); MessageManager.success("Agent updated successfully"); setIsEditing(false); @@ -278,7 +338,14 @@ const AgentInfoView: React.FC = ({
Agent Settings {!isEditing && ( - setIsEditing(true)}>Edit Settings + { + setAppliedDiscoveredSelection(null); + setIsEditing(true); + }} + > + Edit Settings + )}
@@ -300,6 +367,17 @@ const AgentInfoView: React.FC = ({ )} + {discoveryRequest && ( +
+ +
+ )} + Rate Limits
@@ -321,6 +399,7 @@ const AgentInfoView: React.FC = ({
{ + setAppliedDiscoveredSelection(null); setIsEditing(false); fetchAgentInfo(); }}> diff --git a/ui/litellm-dashboard/src/components/networking.tsx b/ui/litellm-dashboard/src/components/networking.tsx index 57e7d51123..aa5e08d195 100644 --- a/ui/litellm-dashboard/src/components/networking.tsx +++ b/ui/litellm-dashboard/src/components/networking.tsx @@ -6299,6 +6299,79 @@ export const createAgentCall = async (accessToken: string, agentData: any) => { } }; +export interface DiscoveredAgentCard { + protocolVersion?: string; + name?: string; + description?: string; + version?: string; + url?: string; + iconUrl?: string; + documentationUrl?: string; + defaultInputModes?: string[]; + defaultOutputModes?: string[]; + capabilities?: Record; + skills?: Array<{ + id?: string; + name?: string; + description?: string; + tags?: string[]; + examples?: string[]; + [key: string]: any; + }>; + provider?: { organization?: string; url?: string }; + [key: string]: any; +} + +export interface DiscoverAgentCardResponse { + url: string; + agent_card: DiscoveredAgentCard; +} + +/** + * How the backend should locate the upstream agent card. + * + * - ``well_known_fallback`` (default): pure A2A — try the three standard + * well-known paths under the base URL. + * - ``langgraph_platform``: LangGraph Platform — hits the canonical + * well-known path with an ``assistant_id`` query parameter, because + * LangGraph mounts one shared card endpoint per deployment. + */ +export type DiscoveryMode = "well_known_fallback" | "langgraph_platform"; + +export interface DiscoverAgentCardOptions { + discovery_mode?: DiscoveryMode; + /** Mode-specific params. ``langgraph_platform`` requires ``assistant_id``. */ + params?: Record; +} + +export const discoverAgentCardCall = async ( + accessToken: string, + url: string, + options?: DiscoverAgentCardOptions, +): Promise => { + const endpoint = proxyBaseUrl ? 
`${proxyBaseUrl}/v1/a2a/discover` : `/v1/a2a/discover`; + const body: Record = { url }; + if (options?.discovery_mode) body.discovery_mode = options.discovery_mode; + if (options?.params) body.params = options.params; + + const response = await fetch(endpoint, { + method: "POST", + headers: { + [globalLitellmHeaderName]: `Bearer ${accessToken}`, + "Content-Type": "application/json", + }, + body: JSON.stringify(body), + }); + + if (!response.ok) { + const errorData = await response.text(); + handleError(errorData); + throw new Error(errorData); + } + + return (await response.json()) as DiscoverAgentCardResponse; +}; + export const createGuardrailCall = async (accessToken: string, guardrailData: any) => { try { const url = proxyBaseUrl ? `${proxyBaseUrl}/guardrails` : `/guardrails`;