diff --git a/skills/build-creative-agent/SKILL.md b/skills/build-creative-agent/SKILL.md index 937f53193..acb3563d7 100644 --- a/skills/build-creative-agent/SKILL.md +++ b/skills/build-creative-agent/SKILL.md @@ -228,7 +228,7 @@ async def build_creative(self, params, context=None): | Function | Usage | |----------|-------| -| `serve(handler)` | Start server on `:3001/mcp` | +| `serve(handler, transport="a2a"\|"streamable-http", port=3001)` | Start MCP or A2A server. Context passthrough is automatic. | | `capabilities_response(protocols)` | `get_adcp_capabilities` response | | `creative_formats_response(formats)` | `list_creative_formats` response | | `sync_creatives_response(creatives)` | `sync_creatives` response | diff --git a/skills/build-seller-agent/SKILL.md b/skills/build-seller-agent/SKILL.md index 964106819..ab0c63a0c 100644 --- a/skills/build-seller-agent/SKILL.md +++ b/skills/build-seller-agent/SKILL.md @@ -58,7 +58,7 @@ Pricing models: One file. Subclass `ADCPHandler`, override the tools you support, call `serve()`. ```python -from adcp.server import ADCPHandler, serve, adcp_error, resolve_account, inject_context +from adcp.server import ADCPHandler, serve, adcp_error, resolve_account from adcp.server.responses import capabilities_response, products_response, media_buy_response from adcp.server.test_controller import TestControllerStore @@ -76,6 +76,8 @@ serve(MySeller(), name="my-seller", test_controller=MyStore()) ## Product Construction Example +Every product needs `description`, `reporting_capabilities`, and `delivery_measurement` — these are required by the schema and the storyboard validator. + ```python PRODUCTS = [ { @@ -87,7 +89,7 @@ PRODUCTS = [ {"publisher_domain": "example.com", "selection_type": "all"} ], "format_ids": [ - {"agent_url": "http://localhost:3001/mcp", "id": "display_970x250"} + {"agent_url": AGENT_URL, "id": "display_970x250"} ], "pricing_options": [ { @@ -97,6 +99,19 @@ PRODUCTS = [ "currency": "USD", } ], + "reporting_capabilities": { + "available_metrics": ["impressions", "spend", "clicks", "ctr"], + "available_reporting_frequencies": ["hourly", "daily"], # hourly|daily|monthly only + "date_range_support": "date_range", # or "lifetime_only" + "supports_webhooks": False, + "expected_delay_minutes": 60, + "timezone": "UTC", + }, + "delivery_measurement": { + "measurement_type": "server_side", + "verification": "internal", + "provider": "internal", + }, }, ] ``` @@ -105,13 +120,15 @@ PRODUCTS = [ The SDK provides helpers that eliminate common boilerplate. Import from `adcp.server`: +**Automatic behaviors** (no handler code needed): +- **Context passthrough** — if the request has a `context` field, it's echoed back in the response automatically. + | Helper | What it does | |--------|-------------| | `adcp_error(code, message, field=, suggestion=)` | Structured error with auto-recovery classification (20+ standard codes) | | `media_buy_response(..., status="active")` | Auto-populates `valid_actions` from status, auto-sets `revision` and `confirmed_at` | | `cancel_media_buy_response(id, "buyer")` | Auto-sets `canceled_at`, `status`, `valid_actions=[]` | | `resolve_account(params, resolver)` | Auto-resolves AccountReference, returns ACCOUNT_NOT_FOUND if missing | -| `inject_context(params, response)` | Echoes `context` field from request to response (ADCP requirement) | | `valid_actions_for_status(status)` | Maps status to valid buyer actions | | `is_terminal_status(status)` | True for completed/rejected/canceled | | `AccountError(code, message, suggestion=)` | Raise from resolver for suspended/payment/ambiguous accounts | @@ -257,7 +274,7 @@ async def get_products(self, params, context=None): **`create_media_buy`** ```python -from adcp.server import adcp_error, inject_context +from adcp.server import adcp_error from adcp.server.responses import media_buy_response async def create_media_buy(self, params, context=None): @@ -283,8 +300,7 @@ async def create_media_buy(self, params, context=None): # Store so get_media_buys and test controller can find it media_buys[mb_id] = {"status": "active", "currency": "USD", "packages": packages} # status="active" auto-populates valid_actions, revision, confirmed_at - resp = media_buy_response(mb_id, packages, status="active") - return inject_context(params, resp) + return media_buy_response(mb_id, packages, status="active") ``` **`get_media_buys`** @@ -308,7 +324,7 @@ async def get_media_buys(self, params, context=None): **`update_media_buy`** — handles pause, resume, cancel, budget changes, package updates. ```python -from adcp.server import adcp_error, inject_context, cancel_media_buy_response +from adcp.server import adcp_error, cancel_media_buy_response from adcp.server.responses import update_media_buy_response async def update_media_buy(self, params, context=None): @@ -342,10 +358,9 @@ async def update_media_buy(self, params, context=None): # Apply budget, dates, etc. mb["revision"] = mb.get("revision", 1) + 1 - resp = update_media_buy_response( + return update_media_buy_response( mb_id, status=mb["status"], revision=mb["revision"] ) - return inject_context(params, resp) ``` **`list_creative_formats`** @@ -490,9 +505,8 @@ return capabilities_response(["media_buy", "compliance_testing"]) | `adcp_error(code, message, field=, suggestion=)` | Structured error with auto-recovery | | `cancel_media_buy_response(id, "buyer"/"seller")` | Cancellation with auto-defaults | | `resolve_account(params, resolver)` | Account resolution with auto-error | -| `inject_context(params, response)` | Context passthrough (ADCP requirement) | | `valid_actions_for_status(status)` | Status-to-actions mapping | -| `serve(handler, test_controller=store)` | Start server on `:3001/mcp` | +| `serve(handler, transport="a2a"\|"streamable-http", port=3001, test_controller=store)` | Start MCP or A2A server. Context passthrough is automatic. | Import helpers from `adcp.server`. Import response builders from `adcp.server.responses`. Import types from `adcp.types`. @@ -525,6 +539,9 @@ npx @adcp/client storyboard run http://localhost:3001/mcp media_buy_seller --jso | Missing `brand`/`operator` in sync_accounts | Echo them back from the request | | Not storing entities in memory | Test controller needs to find accounts, media buys, creatives | | Wrong `delivery_response` signature | Takes `delivery_response(deliveries_list, reporting_period=...)`, not individual metrics | +| Missing `reporting_capabilities` on products | Required. Sub-fields: `available_metrics`, `available_reporting_frequencies`, `date_range_support`, `supports_webhooks` | +| `weekly` in `available_reporting_frequencies` | Only `hourly`, `daily`, `monthly` are valid | +| Missing `delivery_measurement.provider` | Required field — use `"internal"` or third-party provider name | ## Reference diff --git a/skills/build-signals-agent/SKILL.md b/skills/build-signals-agent/SKILL.md index 5b9f4d5aa..10d1a65d0 100644 --- a/skills/build-signals-agent/SKILL.md +++ b/skills/build-signals-agent/SKILL.md @@ -53,7 +53,7 @@ At least one pricing option per signal: One file. Subclass `ADCPHandler`, override the tools you support, call `serve()`. ```python -from adcp.server import ADCPHandler, serve, adcp_error, inject_context +from adcp.server import ADCPHandler, serve, adcp_error from adcp.server.responses import capabilities_response, signals_response, activate_signal_response class MySignalsAgent(ADCPHandler): @@ -205,8 +205,7 @@ async def activate_signal(self, params, context=None): | Function | Usage | |----------|-------| | `adcp_error(code, message, field=, suggestion=)` | Structured error with auto-recovery | -| `inject_context(params, response)` | Context passthrough (ADCP requirement) | -| `serve(handler)` | Start server on `:3001/mcp` | +| `serve(handler, transport="a2a"\|"streamable-http", port=3001)` | Start MCP or A2A server. Context passthrough is automatic — no need to call `inject_context` in handlers. | Import helpers from `adcp.server`. Import response builders from `adcp.server.responses`. diff --git a/src/adcp/server/__init__.py b/src/adcp/server/__init__.py index a570d0a6f..6bb09391b 100644 --- a/src/adcp/server/__init__.py +++ b/src/adcp/server/__init__.py @@ -53,6 +53,7 @@ async def get_products(params, context=None): from __future__ import annotations from adcp.capabilities import validate_capabilities +from adcp.server.a2a_server import ADCPAgentExecutor, create_a2a_server from adcp.server.base import ( ADCPHandler, NotImplementedResponse, @@ -134,6 +135,9 @@ async def get_products(params, context=None): "create_mcp_server", "get_tools_for_handler", "serve", + # A2A integration + "ADCPAgentExecutor", + "create_a2a_server", # Test controller "TestControllerStore", "TestControllerError", diff --git a/src/adcp/server/a2a_server.py b/src/adcp/server/a2a_server.py new file mode 100644 index 000000000..0d490cc3b --- /dev/null +++ b/src/adcp/server/a2a_server.py @@ -0,0 +1,355 @@ +"""A2A server support for ADCP handlers. + +Bridges ADCPHandler to the a2a-sdk server framework so the same handler +can be served over both MCP and A2A transports. + + from adcp.server import ADCPHandler, serve + serve(MyHandler(), name="my-agent", transport="a2a") +""" + +from __future__ import annotations + +import json +import logging +import os +from typing import Any +from uuid import uuid4 + +from a2a.server.agent_execution.agent_executor import AgentExecutor +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.server.request_handlers.default_request_handler import ( + DefaultRequestHandler, +) +from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentSkill, + Artifact, + DataPart, + Part, + Task, + TaskState, + TaskStatus, + TextPart, +) + +from adcp.server.base import ADCPHandler +from adcp.server.mcp_tools import create_tool_caller, get_tools_for_handler +from adcp.server.test_controller import TestControllerStore, _handle_test_controller + +logger = logging.getLogger(__name__) + + +class ADCPAgentExecutor(AgentExecutor): + """Bridges ADCPHandler methods to the a2a-sdk AgentExecutor interface. + + Incoming A2A messages are parsed to extract the ADCP skill name and + parameters, dispatched to the matching handler method, and the result + is published back as A2A Task events. + + Expects the explicit skill invocation format used by A2AAdapter: + DataPart(data={"skill": "get_products", "parameters": {...}}) + """ + + def __init__( + self, + handler: ADCPHandler, + test_controller: TestControllerStore | None = None, + ) -> None: + self._handler = handler + self._tool_callers: dict[str, Any] = {} + + # Build tool callers for all tools this handler supports + tool_defs = get_tools_for_handler(handler) + for tool_def in tool_defs: + name = tool_def["name"] + self._tool_callers[name] = create_tool_caller(handler, name) + + if test_controller is not None: + self._register_test_controller(test_controller) + + @property + def supported_skills(self) -> list[str]: + """List of skill names this executor can handle.""" + return list(self._tool_callers.keys()) + + def _register_test_controller(self, store: TestControllerStore) -> None: + """Register comply_test_controller as a callable skill.""" + + async def _call_test_controller(params: dict[str, Any]) -> Any: + return await _handle_test_controller(store, params) + + self._tool_callers["comply_test_controller"] = _call_test_controller + + async def execute( + self, context: RequestContext, event_queue: EventQueue + ) -> None: + """Execute an ADCP skill from an incoming A2A message.""" + skill_name, params = self._parse_request(context) + + if skill_name is None: + await self._send_error( + event_queue, context, "No skill specified in message" + ) + return + + if skill_name not in self._tool_callers: + await self._send_error( + event_queue, context, f"Unknown skill: {skill_name}" + ) + return + + try: + result = await self._tool_callers[skill_name](params) + await self._send_result(event_queue, context, skill_name, result) + except Exception: + logger.exception("Error executing skill %s", skill_name) + await self._send_error( + event_queue, context, f"Skill execution failed: {skill_name}" + ) + + async def cancel( + self, context: RequestContext, event_queue: EventQueue + ) -> None: + """ADCP operations are synchronous; cancellation sets state to canceled.""" + event = _make_task( + context, + state=TaskState.canceled, + message="Task canceled", + ) + await event_queue.enqueue_event(event) + + # ------------------------------------------------------------------ + # Message parsing + # ------------------------------------------------------------------ + + def _parse_request( + self, context: RequestContext + ) -> tuple[str | None, dict[str, Any]]: + """Extract skill name and parameters from the A2A message. + + Supports two formats: + 1. Explicit skill invocation via DataPart: + DataPart(data={"skill": "get_products", "parameters": {...}}) + 2. Natural language fallback via TextPart (best-effort parse) + """ + msg = context.message + if msg is None or not msg.parts: + return None, {} + + # Try DataPart first (explicit skill invocation) + for part in msg.parts: + inner = part.root if hasattr(part, "root") else part + if isinstance(inner, DataPart) and isinstance(inner.data, dict): + skill = inner.data.get("skill") + params = inner.data.get("parameters", {}) + if skill: + return str(skill), params if isinstance(params, dict) else {} + + # Fallback: try to parse TextPart as JSON + for part in msg.parts: + inner = part.root if hasattr(part, "root") else part + if isinstance(inner, TextPart): + parsed = self._parse_text_request(inner.text) + if parsed[0] is not None: + return parsed + + return None, {} + + def _parse_text_request( + self, text: str + ) -> tuple[str | None, dict[str, Any]]: + """Best-effort parse of a text request for skill + params.""" + try: + data = json.loads(text) + if isinstance(data, dict) and "skill" in data: + return str(data["skill"]), data.get("parameters", {}) + except (json.JSONDecodeError, TypeError): + pass + return None, {} + + # ------------------------------------------------------------------ + # Response helpers + # ------------------------------------------------------------------ + + async def _send_result( + self, + event_queue: EventQueue, + context: RequestContext, + skill_name: str, + result: Any, + ) -> None: + """Publish a completed task with the skill result.""" + # Normalize result to a JSON-safe dict + if hasattr(result, "model_dump"): + data = result.model_dump(mode="json", exclude_none=True) + elif not isinstance(result, dict): + data = {"result": result} + else: + data = result + + task = _make_task( + context, + state=TaskState.completed, + data=data, + message=f"Completed {skill_name}", + ) + await event_queue.enqueue_event(task) + + async def _send_error( + self, + event_queue: EventQueue, + context: RequestContext, + error_msg: str, + ) -> None: + """Publish a failed task.""" + task = _make_task( + context, + state=TaskState.failed, + message=error_msg, + ) + await event_queue.enqueue_event(task) + + +# ------------------------------------------------------------------ +# Task factory +# ------------------------------------------------------------------ + + +def _make_task( + context: RequestContext, + *, + state: TaskState, + data: dict[str, Any] | None = None, + message: str | None = None, +) -> Task: + """Build an a2a Task event from context and result data.""" + parts: list[Part] = [] + if data is not None: + parts.append(Part(root=DataPart(data=data))) + if message: + parts.append(Part(root=TextPart(text=message))) + + artifacts = [] + if parts: + artifacts.append( + Artifact( + artifact_id=str(uuid4()), + parts=parts, + ) + ) + + return Task( + id=context.task_id or str(uuid4()), + context_id=context.context_id or str(uuid4()), + status=TaskStatus(state=state), + artifacts=artifacts if artifacts else None, + ) + + +# ------------------------------------------------------------------ +# Public API +# ------------------------------------------------------------------ + + +def _build_agent_card( + handler: ADCPHandler, + *, + name: str, + port: int, + description: str | None = None, + version: str = "1.0.0", + extra_skills: list[AgentSkill] | None = None, +) -> AgentCard: + """Build an A2A AgentCard from an ADCPHandler's tool definitions.""" + tool_defs = get_tools_for_handler(handler) + + skills = [ + AgentSkill( + id=td["name"], + name=td["name"], + description=td.get("description", td["name"]), + tags=["adcp"], + ) + for td in tool_defs + ] + + if extra_skills: + skills.extend(extra_skills) + + return AgentCard( + name=name, + description=description or f"ADCP agent: {name}", + url=f"http://localhost:{port}/", + version=version, + skills=skills, + capabilities=AgentCapabilities(streaming=False), + default_input_modes=["application/json"], + default_output_modes=["application/json"], + ) + + +def create_a2a_server( + handler: ADCPHandler, + *, + name: str = "adcp-agent", + port: int | None = None, + description: str | None = None, + version: str = "1.0.0", + test_controller: TestControllerStore | None = None, +) -> Any: + """Create an A2A Starlette application from an ADCP handler. + + Args: + handler: An ADCPHandler subclass instance. + name: Agent name shown in the A2A agent card. + port: Port number (used in the agent card URL). + description: Agent description for the agent card. + version: Agent version string. + test_controller: Optional TestControllerStore for storyboard testing. + + Returns: + A Starlette app ready to be run with uvicorn. + """ + from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication + + resolved_port = port or int(os.environ.get("PORT", "3001")) + + executor = ADCPAgentExecutor(handler, test_controller=test_controller) + + agent_card = _build_agent_card( + handler, + name=name, + port=resolved_port, + description=description, + version=version, + extra_skills=_test_controller_skills() if test_controller else None, + ) + + task_store = InMemoryTaskStore() + + request_handler = DefaultRequestHandler( + agent_executor=executor, + task_store=task_store, + ) + + a2a_app = A2AStarletteApplication( + agent_card=agent_card, + http_handler=request_handler, + ) + + return a2a_app.build() + + +def _test_controller_skills() -> list[AgentSkill]: + """Build A2A skill definition for comply_test_controller.""" + return [ + AgentSkill( + id="comply_test_controller", + name="comply_test_controller", + description="Compliance test controller. Sandbox only, not for production use.", + tags=["adcp", "testing"], + ) + ] diff --git a/src/adcp/server/mcp_tools.py b/src/adcp/server/mcp_tools.py index d8cb6718f..b9dab7a6c 100644 --- a/src/adcp/server/mcp_tools.py +++ b/src/adcp/server/mcp_tools.py @@ -1117,6 +1117,10 @@ def create_tool_caller( ) -> Callable[[dict[str, Any]], Any]: """Create a tool caller function for an ADCP handler method. + Automatically injects context passthrough: if the request contains a + ``context`` field, it is echoed back in the response (ADCP requirement). + Handlers no longer need to call ``inject_context()`` manually. + Args: handler: The ADCP handler instance method_name: Name of the method to call @@ -1124,6 +1128,8 @@ def create_tool_caller( Returns: Async callable that invokes the handler method """ + from adcp.server.helpers import inject_context + method = getattr(handler, method_name) async def call_tool(params: dict[str, Any]) -> Any: @@ -1131,7 +1137,10 @@ async def call_tool(params: dict[str, Any]) -> Any: result = await method(params, context) # Convert Pydantic models to JSON-safe dicts for MCP serialization if hasattr(result, "model_dump"): - return result.model_dump(mode="json", exclude_none=True) + result = result.model_dump(mode="json", exclude_none=True) + # ADCP requires echoing context from request to response + if isinstance(result, dict): + inject_context(params, result) return result return call_tool diff --git a/src/adcp/server/serve.py b/src/adcp/server/serve.py index 9d1cc9bb0..72fe26006 100644 --- a/src/adcp/server/serve.py +++ b/src/adcp/server/serve.py @@ -1,6 +1,6 @@ -"""One-liner MCP server for ADCP handlers. +"""One-liner server for ADCP handlers (MCP or A2A). -Stand up an ADCP-compliant MCP server with a single function call: +Stand up an ADCP-compliant server with a single function call: from adcp.server import ADCPHandler, serve from adcp.server.responses import capabilities_response @@ -9,7 +9,11 @@ class MyAgent(ADCPHandler): async def get_adcp_capabilities(self, params, context=None): return capabilities_response(["media_buy"]) + # MCP (default) serve(MyAgent()) + + # A2A + serve(MyAgent(), transport="a2a") """ from __future__ import annotations @@ -35,31 +39,31 @@ def serve( instructions: str | None = None, test_controller: TestControllerStore | None = None, ) -> None: - """Start an MCP server from an ADCP handler or server builder. + """Start an MCP or A2A server from an ADCP handler or server builder. Accepts either an ``ADCPHandler`` instance or an ``ADCPServerBuilder`` (from ``adcp_server()``). Builders are auto-converted via ``build_handler()``. - This is the simplest way to run an ADCP agent. It creates a FastMCP server, - registers all tools from the handler, optionally registers a test controller, - and starts serving. + This is the simplest way to run an ADCP agent. Set ``transport="a2a"`` + to serve over the A2A protocol instead of MCP. Args: handler: An ADCPHandler subclass instance with your tool implementations. - name: Server name shown to MCP clients. + name: Server name shown to clients / in the A2A agent card. port: Port to listen on. Defaults to PORT env var, then 3001. - mount: URL path to mount MCP endpoint on. - transport: MCP transport type. Default "streamable-http". - instructions: Optional system instructions for the agent. + mount: URL path to mount MCP endpoint on (MCP only). + transport: ``"streamable-http"`` (default, MCP) or ``"a2a"``. + instructions: Optional system instructions for the agent (MCP only). test_controller: Optional TestControllerStore instance for storyboard testing. Security: This function does NOT configure authentication. In production, use a reverse proxy or middleware that validates credentials - before forwarding to the MCP endpoint. Without authentication, - the tools/list endpoint exposes the agent's capability surface. + before forwarding to the endpoint. Without authentication, + MCP exposes tools/list and A2A exposes /.well-known/agent.json, + both of which reveal the agent's full capability surface. - Example: + Example (MCP): from adcp.server import ADCPHandler, serve from adcp.server.responses import capabilities_response @@ -69,6 +73,9 @@ async def get_adcp_capabilities(self, params, context=None): serve(MyAgent(), name="my-agent") + Example (A2A): + serve(MyAgent(), name="my-agent", transport="a2a") + With test controller: from adcp.server.test_controller import TestControllerStore @@ -86,6 +93,32 @@ async def force_account_status(self, account_id, status): name = handler.name handler = handler.build_handler() + if transport == "a2a": + _serve_a2a(handler, name=name, port=port, test_controller=test_controller) + elif transport in ("streamable-http", "sse", "stdio"): + _serve_mcp( + handler, + name=name, + port=port, + transport=transport, + instructions=instructions, + test_controller=test_controller, + ) + else: + valid = ", ".join(sorted(("a2a", "streamable-http", "sse", "stdio"))) + raise ValueError(f"Unknown transport {transport!r}. Valid: {valid}") + + +def _serve_mcp( + handler: ADCPHandler, + *, + name: str, + port: int | None, + transport: str, + instructions: str | None, + test_controller: TestControllerStore | None, +) -> None: + """Start an MCP server.""" mcp = create_mcp_server(handler, name=name, port=port, instructions=instructions) if test_controller is not None: @@ -96,6 +129,26 @@ async def force_account_status(self, account_id, status): mcp.run(transport=transport) +def _serve_a2a( + handler: ADCPHandler, + *, + name: str, + port: int | None, + test_controller: TestControllerStore | None, +) -> None: + """Start an A2A server using uvicorn.""" + import uvicorn + + from adcp.server.a2a_server import create_a2a_server + + resolved_port = port or int(os.environ.get("PORT", "3001")) + + app = create_a2a_server( + handler, name=name, port=resolved_port, test_controller=test_controller + ) + uvicorn.run(app, host="0.0.0.0", port=resolved_port) + + def create_mcp_server( handler: ADCPHandler, *, diff --git a/tests/test_a2a_server.py b/tests/test_a2a_server.py new file mode 100644 index 000000000..84599e478 --- /dev/null +++ b/tests/test_a2a_server.py @@ -0,0 +1,390 @@ +"""Tests for A2A server support: ADCPAgentExecutor, create_a2a_server.""" + +from __future__ import annotations + +import json +import sys +from typing import Any + +import pytest +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.types import ( + DataPart, + Message, + MessageSendParams, + Part, + Role, + Task, + TextPart, +) + +from adcp.server import ADCPHandler +from adcp.server.a2a_server import ( + ADCPAgentExecutor, + _build_agent_card, + create_a2a_server, +) +from adcp.server.test_controller import TestControllerError, TestControllerStore + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +class _TestHandler(ADCPHandler): + """Minimal handler that supports get_adcp_capabilities and get_products.""" + + async def get_adcp_capabilities( + self, params: Any, context: Any = None + ) -> dict[str, Any]: + return {"adcp": {"major_versions": [3]}, "supported_protocols": ["media_buy"]} + + async def get_products( + self, params: Any, context: Any = None + ) -> dict[str, Any]: + return { + "products": [{"id": "p1", "name": "Display"}], + "sandbox": True, + } + + +def _make_datapart_msg(skill: str, parameters: dict[str, Any] | None = None) -> Message: + return Message( + message_id="msg-1", + role=Role.user, + parts=[ + Part( + root=DataPart(data={"skill": skill, "parameters": parameters or {}}) + ) + ], + ) + + +def _make_text_msg(text: str) -> Message: + return Message( + message_id="msg-1", + role=Role.user, + parts=[Part(root=TextPart(text=text))], + ) + + +# --------------------------------------------------------------------------- +# ADCPAgentExecutor — sync tests +# --------------------------------------------------------------------------- + + +def test_executor_supported_skills(): + executor = ADCPAgentExecutor(_TestHandler()) + skills = executor.supported_skills + assert "get_adcp_capabilities" in skills + assert "get_products" in skills + + +# --------------------------------------------------------------------------- +# ADCPAgentExecutor — async tests +# --------------------------------------------------------------------------- + + +async def test_execute_with_datapart(): + """Executor dispatches DataPart skill invocation to handler.""" + executor = ADCPAgentExecutor(_TestHandler()) + ctx = RequestContext(request=MessageSendParams(message=_make_datapart_msg("get_products"))) + queue = EventQueue() + + await executor.execute(ctx, queue) + + event = await queue.dequeue_event(no_wait=True) + assert isinstance(event, Task) + assert event.status.state == "completed" + + # Verify the result data is in the artifact + assert event.artifacts + data_parts = [ + p.root + for p in event.artifacts[0].parts + if hasattr(p.root, "data") and isinstance(p.root.data, dict) + ] + assert len(data_parts) >= 1 + result = data_parts[0].data + assert "products" in result + assert result["products"][0]["id"] == "p1" + + +async def test_context_auto_injected(): + """Context from request is automatically echoed in response.""" + executor = ADCPAgentExecutor(_TestHandler()) + ctx = RequestContext( + request=MessageSendParams( + message=_make_datapart_msg( + "get_products", + {"context": {"correlation_id": "test-ctx-123"}}, + ) + ) + ) + queue = EventQueue() + + await executor.execute(ctx, queue) + + event = await queue.dequeue_event(no_wait=True) + assert isinstance(event, Task) + data_parts = [ + p.root + for p in event.artifacts[0].parts + if hasattr(p.root, "data") and isinstance(p.root.data, dict) + ] + result = data_parts[0].data + assert result["context"]["correlation_id"] == "test-ctx-123" + + +async def test_execute_unknown_skill(): + """Executor returns failed task for unknown skills.""" + executor = ADCPAgentExecutor(_TestHandler()) + ctx = RequestContext( + request=MessageSendParams(message=_make_datapart_msg("nonexistent_skill")) + ) + queue = EventQueue() + + await executor.execute(ctx, queue) + + event = await queue.dequeue_event(no_wait=True) + assert isinstance(event, Task) + assert event.status.state == "failed" + + +async def test_execute_no_skill_in_message(): + """Executor returns failed task when message has no parseable skill.""" + executor = ADCPAgentExecutor(_TestHandler()) + ctx = RequestContext(request=MessageSendParams(message=_make_text_msg("hello"))) + queue = EventQueue() + + await executor.execute(ctx, queue) + + event = await queue.dequeue_event(no_wait=True) + assert isinstance(event, Task) + assert event.status.state == "failed" + + +async def test_execute_json_text_fallback(): + """Executor parses JSON text as skill invocation.""" + executor = ADCPAgentExecutor(_TestHandler()) + payload = json.dumps({"skill": "get_products", "parameters": {}}) + ctx = RequestContext(request=MessageSendParams(message=_make_text_msg(payload))) + queue = EventQueue() + + await executor.execute(ctx, queue) + + event = await queue.dequeue_event(no_wait=True) + assert isinstance(event, Task) + assert event.status.state == "completed" + + +async def test_execute_handler_exception(): + """Handler exception returns failed task without leaking details.""" + + class _BrokenHandler(ADCPHandler): + async def get_adcp_capabilities(self, params: Any, context: Any = None) -> Any: + return {"adcp": {"major_versions": [3]}} + + async def get_products(self, params: Any, context: Any = None) -> Any: + raise RuntimeError("secret database connection string leaked") + + executor = ADCPAgentExecutor(_BrokenHandler()) + ctx = RequestContext( + request=MessageSendParams(message=_make_datapart_msg("get_products")) + ) + queue = EventQueue() + + await executor.execute(ctx, queue) + + event = await queue.dequeue_event(no_wait=True) + assert isinstance(event, Task) + assert event.status.state == "failed" + + # Verify exception details are NOT in the error message + text_parts = [ + p.root for p in event.artifacts[0].parts if hasattr(p.root, "text") + ] + error_text = text_parts[0].text + assert "secret database" not in error_text + assert "get_products" in error_text + + +async def test_cancel(): + """Cancel returns a canceled task.""" + executor = ADCPAgentExecutor(_TestHandler()) + ctx = RequestContext(task_id="t1", context_id="c1") + queue = EventQueue() + + await executor.cancel(ctx, queue) + + event = await queue.dequeue_event(no_wait=True) + assert isinstance(event, Task) + assert event.status.state == "canceled" + + +# --------------------------------------------------------------------------- +# Agent card builder +# --------------------------------------------------------------------------- + + +def test_build_agent_card_with_skills(): + card = _build_agent_card(_TestHandler(), name="test-agent", port=3001) + assert card.name == "test-agent" + assert card.url == "http://localhost:3001/" + skill_ids = [s.id for s in card.skills] + assert "get_adcp_capabilities" in skill_ids + assert "get_products" in skill_ids + + +def test_build_agent_card_skills_tagged_adcp(): + card = _build_agent_card(_TestHandler(), name="test", port=8080) + for skill in card.skills: + assert "adcp" in skill.tags + + +# --------------------------------------------------------------------------- +# create_a2a_server +# --------------------------------------------------------------------------- + + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="a2a-sdk starlette integration requires Python 3.11+", +) +def test_create_a2a_server_creates_starlette_app(): + app = create_a2a_server(_TestHandler(), name="test-agent") + # Starlette app has .routes + assert hasattr(app, "routes") + route_paths = [r.path for r in app.routes] + # A2A well-known agent card endpoint + assert "/.well-known/agent.json" in route_paths + + +# --------------------------------------------------------------------------- +# TestControllerStore integration +# --------------------------------------------------------------------------- + + +class _TestStore(TestControllerStore): + def __init__(self) -> None: + self.accounts: dict[str, str] = {"acct-1": "active"} + + async def force_account_status(self, account_id: str, status: str) -> dict[str, Any]: + if account_id not in self.accounts: + raise TestControllerError("NOT_FOUND", f"Account {account_id} not found") + prev = self.accounts[account_id] + self.accounts[account_id] = status + return {"previous_state": prev, "current_state": status} + + +def test_executor_with_test_controller_has_skill(): + """Test controller registers comply_test_controller as a skill.""" + executor = ADCPAgentExecutor(_TestHandler(), test_controller=_TestStore()) + assert "comply_test_controller" in executor.supported_skills + + +async def test_execute_test_controller_list_scenarios(): + """comply_test_controller list_scenarios works via A2A.""" + executor = ADCPAgentExecutor(_TestHandler(), test_controller=_TestStore()) + ctx = RequestContext( + request=MessageSendParams( + message=_make_datapart_msg( + "comply_test_controller", + {"scenario": "list_scenarios"}, + ) + ) + ) + queue = EventQueue() + + await executor.execute(ctx, queue) + + event = await queue.dequeue_event(no_wait=True) + assert isinstance(event, Task) + assert event.status.state == "completed" + + data_parts = [ + p.root + for p in event.artifacts[0].parts + if hasattr(p.root, "data") and isinstance(p.root.data, dict) + ] + result = data_parts[0].data + assert result["success"] is True + assert "force_account_status" in result["scenarios"] + + +async def test_execute_test_controller_force_account_status(): + """comply_test_controller dispatches force_account_status correctly.""" + executor = ADCPAgentExecutor(_TestHandler(), test_controller=_TestStore()) + ctx = RequestContext( + request=MessageSendParams( + message=_make_datapart_msg( + "comply_test_controller", + { + "scenario": "force_account_status", + "params": {"account_id": "acct-1", "status": "suspended"}, + }, + ) + ) + ) + queue = EventQueue() + + await executor.execute(ctx, queue) + + event = await queue.dequeue_event(no_wait=True) + assert isinstance(event, Task) + assert event.status.state == "completed" + + data_parts = [ + p.root + for p in event.artifacts[0].parts + if hasattr(p.root, "data") and isinstance(p.root.data, dict) + ] + result = data_parts[0].data + assert result["success"] is True + assert result["previous_state"] == "active" + assert result["current_state"] == "suspended" + + +async def test_execute_test_controller_error(): + """comply_test_controller handles TestControllerError.""" + executor = ADCPAgentExecutor(_TestHandler(), test_controller=_TestStore()) + ctx = RequestContext( + request=MessageSendParams( + message=_make_datapart_msg( + "comply_test_controller", + { + "scenario": "force_account_status", + "params": {"account_id": "nonexistent", "status": "active"}, + }, + ) + ) + ) + queue = EventQueue() + + await executor.execute(ctx, queue) + + event = await queue.dequeue_event(no_wait=True) + assert isinstance(event, Task) + assert event.status.state == "completed" # A2A task succeeds; error is in data + + data_parts = [ + p.root + for p in event.artifacts[0].parts + if hasattr(p.root, "data") and isinstance(p.root.data, dict) + ] + result = data_parts[0].data + assert result["success"] is False + assert result["error"] == "NOT_FOUND" + + +@pytest.mark.skipif( + sys.version_info < (3, 11), + reason="a2a-sdk starlette integration requires Python 3.11+", +) +def test_create_a2a_server_with_test_controller(): + """create_a2a_server includes comply_test_controller in agent card.""" + app = create_a2a_server( + _TestHandler(), name="test-agent", test_controller=_TestStore() + ) + assert hasattr(app, "routes")