diff --git a/src/adcp/client.py b/src/adcp/client.py index 23a162d74..84546b599 100644 --- a/src/adcp/client.py +++ b/src/adcp/client.py @@ -309,6 +309,7 @@ def __init__( validate_features: bool = False, strict_idempotency: bool = False, signing: SigningConfig | None = None, + context_id: str | None = None, ): """ Initialize ADCP client for a single agent. @@ -343,6 +344,20 @@ def __init__( ``jwks_uri``. Supported on both A2A and MCP (``mcp_transport="streamable_http"``); SSE-transport MCP logs a warning and falls through unsigned. + context_id: A2A-only. Seed the A2A conversation context. Pass a + previously-returned ``context_id`` to resume a session + across process restarts, or a self-assigned UUID to name + the session with your own correlation key (the ADK server + honors buyer-proposed ids). If omitted, the server mints + one on the first message and this client auto-retains it + for subsequent calls. Read the current value via + ``client.context_id``; call ``client.reset_context()`` to + start a fresh conversation. Rule of thumb: one + ``ADCPClient`` per A2A conversation — if a buyer has + multiple concurrent briefs with the same agent, construct + one client per brief rather than sharing. + + Raises ``ValueError`` if passed with a non-A2A protocol. """ self.agent_config = agent_config self.webhook_url_template = webhook_url_template @@ -380,11 +395,81 @@ def __init__( if signing is not None: self.adapter.signing_request_hook = self._sign_outgoing_request + if context_id is not None: + if not isinstance(self.adapter, A2AAdapter): + raise ValueError( + "context_id is only supported for A2A protocol; " f"got {agent_config.protocol}" + ) + self.adapter.set_context_id(context_id) + # Initialize simple API accessor (lazy import to avoid circular dependency) from adcp.simple import SimpleAPI self.simple = SimpleAPI(self) + @property + def context_id(self) -> str | None: + """Current A2A conversation context_id. + + Reads the context_id currently associated with this client: the + value assigned by the A2A server (auto-captured from the most + recent response) or the one seeded via the constructor or + ``reset_context()``. Returns ``None`` before the first A2A call + in a fresh conversation, or for clients on non-A2A protocols — + reads are lenient across protocols so generic code can probe + ``if client.context_id: ...`` safely. Writes (constructor kwarg, + ``reset_context``) raise on non-A2A because the operation has no + meaning there. + + Not safe for concurrent calls on the same client — the adapter + mutates this on every response. Rule of thumb: one ADCPClient + per A2A conversation. Persist this value (e.g., Redis keyed by + your brief id) to resume across process restarts by passing it + to ``ADCPClient(context_id=...)``. + """ + if isinstance(self.adapter, A2AAdapter): + return self.adapter.context_id + return None + + @property + def pending_task_id(self) -> str | None: + """A2A task_id pending resume, or None if no task is in-flight. + + Set when the last A2A response was non-terminal + (``input-required``, ``working``, ``submitted``, + ``auth-required``). The adapter echoes this id on the next + outbound message so the server resumes the same task. Clears + automatically when the task reaches a terminal state. + + Returns ``None`` for non-A2A clients. + """ + if isinstance(self.adapter, A2AAdapter): + return self.adapter.pending_task_id + return None + + def reset_context(self, context_id: str | None = None) -> None: + """Start a new A2A conversation on this client. + + Passing ``None`` (default) clears the current context so the + server mints a fresh one on the next call. Passing a string uses + it as the new conversation id — useful for resuming a specific + prior session or for naming the conversation with your own + correlation key. Note: some servers (notably ADK) rewrite + client-supplied ids into their own session format; the client + auto-adopts the rewritten id on the next response. + + Also clears any pending_task_id — starting a new conversation + discards any in-flight task on the old one. + + Raises ``ValueError`` when called on a non-A2A client. + """ + if not isinstance(self.adapter, A2AAdapter): + raise ValueError( + "reset_context is only supported for A2A protocol; " + f"got {self.agent_config.protocol}" + ) + self.adapter.set_context_id(context_id) + async def _ensure_idempotency_capability(self) -> None: """Verify the seller positively declares idempotency support in capabilities. diff --git a/src/adcp/protocols/a2a.py b/src/adcp/protocols/a2a.py index ed1408a1e..ec312b2c1 100644 --- a/src/adcp/protocols/a2a.py +++ b/src/adcp/protocols/a2a.py @@ -38,11 +38,81 @@ class A2AAdapter(ProtocolAdapter): """Adapter for A2A protocol using official a2a-sdk client.""" + # A2A task states in which the server is still expecting more from + # the buyer on the same task (input-required, auth-required, and + # in-flight states). While the adapter holds a task_id in one of + # these states, the next outbound Message must echo it back so the + # server resumes the same task rather than orphaning it and starting + # a new one. Terminal states (completed/failed/canceled/rejected) + # clear the retained task_id — subsequent calls in the conversation + # are new tasks. + _NONTERMINAL_TASK_STATES = frozenset( + {"submitted", "working", "input-required", "auth-required"} + ) + def __init__(self, agent_config: AgentConfig): """Initialize A2A adapter with official A2A client.""" super().__init__(agent_config) self._httpx_client: httpx.AsyncClient | None = None self._a2a_client: A2AClient | None = None + # A2A contextId for multi-turn conversations. First request sends + # context_id=None → server mints one and returns it on Task.context_id; + # we stash it here and echo it back on every subsequent send so the + # server can scope state to the same session. Callers can seed this + # via ADCPClient(context_id=...) to resume a session across process + # restarts, or clear it via ADCPClient.reset_context() to start a + # new conversation. + self._context_id: str | None = None + # A2A task_id retained across turns only while the prior task is + # non-terminal (input-required, working, etc). On terminal states + # this clears to None so the next call starts a new task under + # the same context_id. Without this, resume of an input-required + # task orphans the server-side pending task. + self._pending_task_id: str | None = None + + @property + def context_id(self) -> str | None: + """Current A2A conversation context_id, or None if not yet established. + + ``None`` means either (a) a fresh conversation where the server + has not yet replied, or (b) the context was cleared via + ``set_context_id(None)``. Callers that need to distinguish these + must track their own state. + + Not thread-safe: the adapter mutates this on every response. For + concurrent use, serialize calls on one adapter or construct one + per conversation. + """ + return self._context_id + + @property + def pending_task_id(self) -> str | None: + """A2A task_id retained for resume, or None if no task is pending. + + Populated when the last response was non-terminal (e.g. + ``input-required``). Echoed on the next outbound message so the + server continues the same task. Clears to None on terminal + states (``completed``/``failed``/``canceled``). + """ + return self._pending_task_id + + def set_context_id(self, context_id: str | None) -> None: + """Set the A2A context_id for subsequent message sends. + + Pass ``None`` to clear — the server mints a fresh id on the next + call — or a string to seed. Seeding is safe for *resume* (pass + back an id the server previously returned). Seeding with a + *self-generated* id is server-dependent: per the A2A spec, + agents MAY accept or reject client-supplied ids, and some + frameworks (notably ADK) rewrite the id into their own session + format and return the rewritten value on the next response — at + which point this adapter auto-adopts it. + + Also clears any retained ``pending_task_id``: switching context + always starts a fresh task under the new context. + """ + self._context_id = context_id + self._pending_task_id = None async def _get_httpx_client(self) -> httpx.AsyncClient: """Get or create the HTTP client with connection pooling.""" @@ -189,6 +259,8 @@ async def _call_a2a_tool( message_id=message_id, role=Role.user, parts=[Part(root=data_part)], + context_id=self._context_id, + task_id=self._pending_task_id, ) else: # Natural language invocation (flexible) @@ -198,6 +270,8 @@ async def _call_a2a_tool( message_id=message_id, role=Role.user, parts=[Part(root=text_part)], + context_id=self._context_id, + task_id=self._pending_task_id, ) # Build request params @@ -265,6 +339,18 @@ async def _call_a2a_tool( # Result can be either Task or Message if isinstance(result, Task): + # Retain the server-assigned context_id so subsequent + # turns continue the same A2A conversation. Task.context_id + # is required by a2a-sdk, so no None-guard needed. + self._context_id = result.context_id + # Retain task_id only while the task is non-terminal. + # On terminal states (completed/failed/canceled/rejected) + # the next send must NOT echo this task_id — it starts a + # fresh task under the same context. + if result.status.state in self._NONTERMINAL_TASK_STATES: + self._pending_task_id = result.id + else: + self._pending_task_id = None task_result = self._process_task_response(result, debug_info) _idempotency.raise_for_idempotency_error( tool_name, task_result.data, self.agent_config.id diff --git a/tests/integration/test_a2a_context_id.py b/tests/integration/test_a2a_context_id.py new file mode 100644 index 000000000..dd1ef46b0 --- /dev/null +++ b/tests/integration/test_a2a_context_id.py @@ -0,0 +1,403 @@ +"""End-to-end HTTP integration tests for A2A contextId / taskId handling. + +Spins up a real A2A Starlette app on localhost via uvicorn and drives +it with the SDK's ``ADCPClient``. These tests prove the wire-level +behavior — they are the counterpart to the mocked adapter tests in +``tests/test_protocols.py`` and the only thing that would catch +regressions in how the client actually serializes ``context_id`` / +``task_id`` onto the JSON-RPC ``message/send`` request. + +Also guards the server-side session-scoping claim: a handler keyed on +``context_id`` would start seeing fresh buckets on every call if the +client stopped echoing the id — these tests would fire first. + +Note on what the observer can see: the a2a-sdk's ``RequestContext`` +auto-populates ``context_id`` / ``task_id`` server-side — it mints +one when the client sent nothing, so a server-side observer cannot +tell "client sent None" from "client sent X" by reading +``RequestContext.context_id`` alone. What it *can* prove is the +higher-value contract — that two turns on one ``ADCPClient`` land on +the *same* server-observed id, and that ``reset_context()`` produces +a *different* one. That's the buyer-visible semantic; the None-on- +first-wire detail is covered by the unit tests in +``tests/test_protocols.py``. +""" + +from __future__ import annotations + +import asyncio +import socket +import sys +from collections.abc import AsyncIterator +from contextlib import asynccontextmanager +from typing import Any +from uuid import uuid4 + +import pytest +import uvicorn +from a2a.server.agent_execution.agent_executor import AgentExecutor +from a2a.server.agent_execution.context import RequestContext +from a2a.server.events.event_queue import EventQueue +from a2a.server.request_handlers.default_request_handler import ( + DefaultRequestHandler, +) +from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore +from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentSkill, + Artifact, + DataPart, + Part, + Task, + TaskState, + TaskStatus, + TextPart, +) + +from adcp import ADCPClient +from adcp.server import ADCPHandler +from adcp.server.a2a_server import create_a2a_server +from adcp.types import AgentConfig, Protocol + +# Starlette/uvicorn A2A integration requires Python 3.11+. +pytestmark = pytest.mark.skipif( + sys.version_info < (3, 11), + reason="a2a-sdk starlette integration requires Python 3.11+", +) + + +class _EchoHandler(ADCPHandler): + """Minimal handler for the happy-path tests — the assertions are at + the protocol layer, not the handler. Returns empty payloads.""" + + async def get_adcp_capabilities(self, params: Any, context: Any = None) -> dict[str, Any]: + return {"adcp": {"major_versions": [3]}, "supported_protocols": ["media_buy"]} + + async def get_products(self, params: Any, context: Any = None) -> dict[str, Any]: + return {"products": [{"id": "p1", "name": "Display"}]} + + async def create_media_buy(self, params: Any, context: Any = None) -> dict[str, Any]: + return {"media_buy_id": "mb-1"} + + +class _Observer: + """Captures the (context_id, task_id) the server saw on each + incoming A2A message. + + Installed via the ``message_parser`` hook — the parser is invoked + on every ``message/send`` with the full RequestContext. See the + module docstring for what this observation point can and cannot + prove. + """ + + def __init__(self) -> None: + self.calls: list[dict[str, str | None]] = [] + + def parser(self, context: RequestContext) -> tuple[str | None, dict[str, Any]]: + self.calls.append({"context_id": context.context_id, "task_id": context.task_id}) + # Reimplement the default DataPart(skill=..., parameters=...) + # parse inline so we don't reach into executor internals. + msg = context.message + if msg is None: + return None, {} + for part in msg.parts: + inner = part.root if hasattr(part, "root") else part + if isinstance(inner, DataPart) and isinstance(inner.data, dict): + skill = inner.data.get("skill") + params = inner.data.get("parameters") or {} + if skill and isinstance(params, dict): + return str(skill), params + return None, {} + + +def _pick_free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return int(s.getsockname()[1]) + + +@asynccontextmanager +async def _running_server(handler: ADCPHandler, observer: _Observer) -> AsyncIterator[str]: + """Start an in-process uvicorn serving the A2A app, yield its base URL.""" + port = _pick_free_port() + app = create_a2a_server( + handler, + name="integration-test-agent", + port=port, + message_parser=observer.parser, + ) + config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning") + server = uvicorn.Server(config) + task = asyncio.create_task(server.serve()) + try: + for _ in range(200): # ~10s ceiling + if server.started: + break + await asyncio.sleep(0.05) + else: + raise RuntimeError("uvicorn failed to start within timeout") + yield f"http://127.0.0.1:{port}" + finally: + server.should_exit = True + await task + + +@pytest.mark.asyncio +async def test_two_calls_share_server_assigned_context_id(): + """Core contract: two sequential calls on one ADCPClient land on + the server under the same context_id. If the client stopped + echoing it, turn 2 would get a different server-minted id — the + assertion that calls[0] == calls[1] catches that regression.""" + observer = _Observer() + async with _running_server(_EchoHandler(), observer) as base_url: + config = AgentConfig( + id="ctx-test-agent", + agent_uri=base_url, + protocol=Protocol.A2A, + auth_token="test", + ) + async with ADCPClient(config) as client: + assert client.context_id is None + r1 = await client.adapter.get_products({"brief": "x"}) + assert r1.success, r1.error + # After turn 1 the client has captured the server's id. + assert client.context_id is not None + captured_after_turn_1 = client.context_id + + r2 = await client.adapter.create_media_buy({"budget": 1000}) + assert r2.success, r2.error + + assert len(observer.calls) == 2 + # The server saw the same context_id on both turns — this proves + # session continuity end-to-end. The server is authoritative, so + # this value is what any handler keyed on context_id would scope to. + assert observer.calls[0]["context_id"] == observer.calls[1]["context_id"] + # And that server-observed id matches what the client captured. + assert observer.calls[1]["context_id"] == captured_after_turn_1 + + +@pytest.mark.asyncio +async def test_reset_context_produces_new_server_side_session(): + """After ``reset_context()``, the next call must land on a + different server-side context than the one before. If reset were a + no-op the two server-observed ids would match and this would fire.""" + observer = _Observer() + async with _running_server(_EchoHandler(), observer) as base_url: + config = AgentConfig( + id="ctx-test-agent", + agent_uri=base_url, + protocol=Protocol.A2A, + auth_token="test", + ) + async with ADCPClient(config) as client: + await client.adapter.get_products({"brief": "x"}) + client.reset_context() + assert client.context_id is None + await client.adapter.get_products({"brief": "y"}) + + # Two distinct server-side sessions. + assert observer.calls[0]["context_id"] != observer.calls[1]["context_id"] + + +@pytest.mark.asyncio +async def test_seeded_context_id_reaches_server_on_first_call(): + """Constructor seeding (``ADCPClient(context_id=...)``) — the + resume-across-restart use case. The server sees the seeded id on + turn 1 with no round-trip, so a buyer rehydrating from persisted + state lands on the same server-side session.""" + seed = f"buyer-seeded-{uuid4()}" + observer = _Observer() + async with _running_server(_EchoHandler(), observer) as base_url: + config = AgentConfig( + id="ctx-test-agent", + agent_uri=base_url, + protocol=Protocol.A2A, + auth_token="test", + ) + async with ADCPClient(config, context_id=seed) as client: + assert client.context_id == seed + r1 = await client.adapter.get_products({"brief": "x"}) + assert r1.success, r1.error + + assert observer.calls[0]["context_id"] == seed + + +# --------------------------------------------------------------------------- +# HITL / input-required resume — requires a custom AgentExecutor that emits +# a non-terminal task on the first call. ADCPAgentExecutor always emits +# terminal (completed/failed), so we drop down to the raw a2a-sdk here. +# --------------------------------------------------------------------------- + + +class _HitlExecutor(AgentExecutor): + """Emits an ``input-required`` task on the first call, then a + ``completed`` task on the second. Records what came in on the wire. + """ + + def __init__(self) -> None: + self.observations: list[dict[str, str | None]] = [] + self._served = 0 + + async def execute(self, context: RequestContext, event_queue: EventQueue) -> None: + self.observations.append( + { + "context_id": context.context_id, + "task_id": context.task_id, + "message_task_id": (context.message.task_id if context.message else None), + "message_context_id": (context.message.context_id if context.message else None), + } + ) + self._served += 1 + if self._served == 1: + state = TaskState.input_required + text = "manager approval needed" + else: + state = TaskState.completed + text = "approved" + + task = Task( + id=context.task_id or str(uuid4()), + context_id=context.context_id or str(uuid4()), + status=TaskStatus(state=state), + artifacts=[ + Artifact( + artifact_id=str(uuid4()), + parts=[ + Part(root=TextPart(text=text)), + Part(root=DataPart(data={"approved": state == TaskState.completed})), + ], + ) + ], + ) + await event_queue.enqueue_event(task) + + async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: + task = Task( + id=context.task_id or str(uuid4()), + context_id=context.context_id or str(uuid4()), + status=TaskStatus(state=TaskState.canceled), + ) + await event_queue.enqueue_event(task) + + +def _make_hitl_app(executor: _HitlExecutor, port: int) -> Any: + """Build a raw A2A Starlette app around the custom executor. + + The agent-card ``url`` must include the serving port — the client + routes JSON-RPC POSTs to ``agent_card.url``, not to the base_url + it passed to the resolver. + """ + from a2a.server.apps.jsonrpc.starlette_app import A2AStarletteApplication + + card = AgentCard( + name="hitl-test-agent", + description="non-terminal-state test", + url=f"http://127.0.0.1:{port}/", + version="1.0.0", + capabilities=AgentCapabilities(streaming=False), + default_input_modes=["application/json"], + default_output_modes=["application/json"], + skills=[ + AgentSkill( + id="create_media_buy", + name="create_media_buy", + description="create_media_buy", + tags=["adcp"], + ) + ], + ) + handler = DefaultRequestHandler( + agent_executor=executor, + task_store=InMemoryTaskStore(), + ) + return A2AStarletteApplication(agent_card=card, http_handler=handler).build() + + +@asynccontextmanager +async def _running_raw_server( + executor: _HitlExecutor, +) -> AsyncIterator[str]: + port = _pick_free_port() + app = _make_hitl_app(executor, port) + config = uvicorn.Config(app, host="127.0.0.1", port=port, log_level="warning") + server = uvicorn.Server(config) + task = asyncio.create_task(server.serve()) + try: + for _ in range(200): + if server.started: + break + await asyncio.sleep(0.05) + else: + raise RuntimeError("uvicorn failed to start within timeout") + yield f"http://127.0.0.1:{port}" + finally: + server.should_exit = True + await task + + +@pytest.mark.asyncio +async def test_task_id_echoed_on_resume_after_input_required(): + """HITL flow: server returns ``input-required`` on turn 1 → client + auto-retains task_id → turn 2 carries both context_id and task_id + so the server resumes the same task. Without task_id echo the + server would orphan the pending HITL task.""" + executor = _HitlExecutor() + async with _running_raw_server(executor) as base_url: + config = AgentConfig( + id="hitl-agent", + agent_uri=base_url, + protocol=Protocol.A2A, + auth_token="test", + ) + async with ADCPClient(config) as client: + r1 = await client.adapter.create_media_buy({"budget": 1000}) + # After an input-required response the adapter stashed both ids. + assert client.context_id is not None + assert client.pending_task_id is not None + retained_task_id = client.pending_task_id + retained_context_id = client.context_id + + r2 = await client.adapter.create_media_buy({"approval": "yes"}) + # Terminal state on turn 2 cleared pending_task_id; context stays. + assert client.pending_task_id is None + assert client.context_id == retained_context_id + + assert len(executor.observations) == 2 + # Turn 1: both ids are server-generated (client sent nothing). + # Turn 2: the client echoed the server's task_id back on the Message — + # this is what resumes the pending HITL task server-side. + assert executor.observations[1]["message_task_id"] == retained_task_id + assert executor.observations[1]["message_context_id"] == retained_context_id + # Sanity: r2 came back as completed. + assert r2.success, r2.error + _ = r1 + + +@pytest.mark.asyncio +async def test_resume_across_simulated_restart_lands_on_same_session(): + """Persistence-across-restart story: client A establishes a + session, persists its context_id, dies. Client B spins up, seeds + with the persisted id, and its first call must carry that id so + the server can reattach it to the original session.""" + observer = _Observer() + async with _running_server(_EchoHandler(), observer) as base_url: + config = AgentConfig( + id="ctx-test-agent", + agent_uri=base_url, + protocol=Protocol.A2A, + auth_token="test", + ) + # Client A — establishes the session and "persists" the id. + async with ADCPClient(config) as client_a: + await client_a.adapter.get_products({"brief": "x"}) + persisted_context_id = client_a.context_id + assert persisted_context_id is not None + + # Client B — different instance, seeds from persisted state. + async with ADCPClient(config, context_id=persisted_context_id) as client_b: + await client_b.adapter.create_media_buy({"budget": 1000}) + + # Both server-observed calls share the same context_id. + assert observer.calls[0]["context_id"] == observer.calls[1]["context_id"] + assert observer.calls[1]["context_id"] == persisted_context_id diff --git a/tests/test_protocols.py b/tests/test_protocols.py index 3254de979..a88c8dd96 100644 --- a/tests/test_protocols.py +++ b/tests/test_protocols.py @@ -442,6 +442,323 @@ async def test_get_agent_info_without_extensions(self, a2a_config): assert "protocols_supported" not in info +class TestA2AContextId: + """Tests for A2A contextId auto-retain, inject, and reset. + + Covers the multi-turn conversation story: first send carries + context_id=None, server assigns one, adapter echoes it on every + subsequent turn. Callers can seed the id at construction (resume / + self-named sessions) or clear it to start a fresh conversation. + """ + + @staticmethod + def _captured_context_id(mock_send_message: AsyncMock) -> str | None: + """Pull the ``Message.context_id`` off the captured send call. + + The adapter wraps the outbound ``Message`` in ``MessageSendParams`` + inside a ``SendMessageRequest`` — drill through to the message. + """ + request = mock_send_message.call_args[0][0] + return request.params.message.context_id + + @pytest.mark.asyncio + async def test_first_call_sends_no_context_id_and_captures_server_assigned(self, a2a_config): + """First turn: no context yet → server assigns → adapter stores it.""" + adapter = A2AAdapter(a2a_config) + assert adapter.context_id is None + + mock_task = create_mock_a2a_task(context_id="server-assigned-abc") + mock_response = SendMessageSuccessResponse(result=mock_task) + mock_a2a_client = AsyncMock() + mock_a2a_client.send_message = AsyncMock(return_value=mock_response) + + with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): + await adapter._call_a2a_tool("get_products", {}) + + assert self._captured_context_id(mock_a2a_client.send_message) is None + assert adapter.context_id == "server-assigned-abc" + + @pytest.mark.asyncio + async def test_subsequent_call_echoes_retained_context_id(self, a2a_config): + """Second turn: adapter sends the context_id captured on turn one.""" + adapter = A2AAdapter(a2a_config) + + first_task = create_mock_a2a_task(context_id="ctx-session-1") + second_task = create_mock_a2a_task(context_id="ctx-session-1") + mock_a2a_client = AsyncMock() + mock_a2a_client.send_message = AsyncMock( + side_effect=[ + SendMessageSuccessResponse(result=first_task), + SendMessageSuccessResponse(result=second_task), + ] + ) + + with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): + await adapter._call_a2a_tool("get_products", {}) + await adapter._call_a2a_tool("create_media_buy", {}) + + second_call = mock_a2a_client.send_message.call_args_list[1] + assert second_call[0][0].params.message.context_id == "ctx-session-1" + assert adapter.context_id == "ctx-session-1" + + @pytest.mark.asyncio + async def test_set_context_id_is_used_on_next_send(self, a2a_config): + """Seeded context_id is sent on the very next call (resume use case).""" + adapter = A2AAdapter(a2a_config) + adapter.set_context_id("resumed-from-redis") + + mock_task = create_mock_a2a_task(context_id="resumed-from-redis") + mock_a2a_client = AsyncMock() + mock_a2a_client.send_message = AsyncMock( + return_value=SendMessageSuccessResponse(result=mock_task) + ) + + with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): + await adapter._call_a2a_tool("get_products", {}) + + assert self._captured_context_id(mock_a2a_client.send_message) == "resumed-from-redis" + + @pytest.mark.asyncio + async def test_clearing_context_id_starts_fresh_conversation(self, a2a_config): + """set_context_id(None) clears; next send carries no context_id.""" + adapter = A2AAdapter(a2a_config) + adapter._context_id = "old-ctx" + + mock_task = create_mock_a2a_task(context_id="new-server-ctx") + mock_a2a_client = AsyncMock() + mock_a2a_client.send_message = AsyncMock( + return_value=SendMessageSuccessResponse(result=mock_task) + ) + + adapter.set_context_id(None) + assert adapter.context_id is None + + with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): + await adapter._call_a2a_tool("get_products", {}) + + assert self._captured_context_id(mock_a2a_client.send_message) is None + # After response we retain whatever the server assigned. + assert adapter.context_id == "new-server-ctx" + + @staticmethod + def _captured_task_id(mock_send_message: AsyncMock, call_index: int = 0) -> str | None: + """Pull the ``Message.task_id`` off a specific captured send call.""" + request = mock_send_message.call_args_list[call_index][0][0] + return request.params.message.task_id + + @pytest.mark.asyncio + async def test_task_id_retained_when_state_is_input_required(self, a2a_config): + """Non-terminal state (input-required) → task_id echoed on next send + so the server resumes the same task rather than orphaning it.""" + adapter = A2AAdapter(a2a_config) + + hitl_task = create_mock_a2a_task( + task_id="task-hitl-1", + context_id="ctx-abc", + state="input-required", + parts=[TextPart(text="Need approval")], + ) + resume_task = create_mock_a2a_task( + task_id="task-hitl-1", + context_id="ctx-abc", + state="completed", + ) + mock_a2a_client = AsyncMock() + mock_a2a_client.send_message = AsyncMock( + side_effect=[ + SendMessageSuccessResponse(result=hitl_task), + SendMessageSuccessResponse(result=resume_task), + ] + ) + + with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): + await adapter._call_a2a_tool("create_media_buy", {}) + assert adapter.pending_task_id == "task-hitl-1" + + await adapter._call_a2a_tool("create_media_buy", {"approval": "yes"}) + + assert self._captured_task_id(mock_a2a_client.send_message, 0) is None + assert self._captured_task_id(mock_a2a_client.send_message, 1) == "task-hitl-1" + # Terminal state clears the pending task. + assert adapter.pending_task_id is None + + @pytest.mark.asyncio + async def test_task_id_cleared_on_completed_state(self, a2a_config): + """Terminal state → subsequent call starts a new task under the + same context (task_id=None on send, context_id retained).""" + adapter = A2AAdapter(a2a_config) + + first_task = create_mock_a2a_task( + task_id="task-get-products", + context_id="ctx-session", + state="completed", + ) + second_task = create_mock_a2a_task( + task_id="task-create-media-buy", + context_id="ctx-session", + state="completed", + ) + mock_a2a_client = AsyncMock() + mock_a2a_client.send_message = AsyncMock( + side_effect=[ + SendMessageSuccessResponse(result=first_task), + SendMessageSuccessResponse(result=second_task), + ] + ) + + with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): + await adapter._call_a2a_tool("get_products", {}) + assert adapter.pending_task_id is None + + await adapter._call_a2a_tool("create_media_buy", {}) + + assert self._captured_task_id(mock_a2a_client.send_message, 1) is None + second_call = mock_a2a_client.send_message.call_args_list[1] + assert second_call[0][0].params.message.context_id == "ctx-session" + + @pytest.mark.asyncio + async def test_task_id_cleared_on_failed_state(self, a2a_config): + """Failure is terminal too — pending task_id must clear.""" + adapter = A2AAdapter(a2a_config) + + failed = create_mock_a2a_task( + task_id="task-failed", + context_id="ctx", + state="failed", + parts=[TextPart(text="server error")], + ) + mock_a2a_client = AsyncMock() + mock_a2a_client.send_message = AsyncMock( + return_value=SendMessageSuccessResponse(result=failed) + ) + + with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): + await adapter._call_a2a_tool("get_products", {}) + + assert adapter.pending_task_id is None + + @pytest.mark.asyncio + async def test_task_id_retained_on_working_state(self, a2a_config): + """'working' is also non-terminal — adapter must retain task_id + so clients polling / resuming land on the right task.""" + adapter = A2AAdapter(a2a_config) + + working = create_mock_a2a_task( + task_id="task-in-progress", + context_id="ctx", + state="working", + parts=[TextPart(text="processing...")], + ) + mock_a2a_client = AsyncMock() + mock_a2a_client.send_message = AsyncMock( + return_value=SendMessageSuccessResponse(result=working) + ) + + with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): + await adapter._call_a2a_tool("create_media_buy", {}) + + assert adapter.pending_task_id == "task-in-progress" + + @pytest.mark.asyncio + async def test_set_context_id_clears_pending_task(self, a2a_config): + """Switching context discards any in-flight task — a new + conversation shouldn't try to resume a task from the old one.""" + adapter = A2AAdapter(a2a_config) + adapter._context_id = "old-ctx" + adapter._pending_task_id = "old-task" + + adapter.set_context_id("new-ctx") + + assert adapter.context_id == "new-ctx" + assert adapter.pending_task_id is None + + @pytest.mark.asyncio + async def test_server_rebinding_context_id_is_honored(self, a2a_config): + """If the server returns a different context_id than we proposed, + we adopt the server's value — servers are authoritative on context + assignment even when the buyer self-named the session.""" + adapter = A2AAdapter(a2a_config) + adapter.set_context_id("buyer-proposed") + + mock_task = create_mock_a2a_task(context_id="server-overrode") + mock_a2a_client = AsyncMock() + mock_a2a_client.send_message = AsyncMock( + return_value=SendMessageSuccessResponse(result=mock_task) + ) + + with patch.object(adapter, "_get_a2a_client", return_value=mock_a2a_client): + await adapter._call_a2a_tool("get_products", {}) + + assert self._captured_context_id(mock_a2a_client.send_message) == "buyer-proposed" + assert adapter.context_id == "server-overrode" + + +class TestADCPClientContextId: + """Tests for the ADCPClient-level contextId surface.""" + + def test_constructor_seeds_context_id_on_a2a_client(self, a2a_config): + from adcp.client import ADCPClient + + client = ADCPClient(a2a_config, context_id="seeded-ctx") + assert client.context_id == "seeded-ctx" + assert isinstance(client.adapter, A2AAdapter) + assert client.adapter.context_id == "seeded-ctx" + + def test_context_id_property_defaults_to_none(self, a2a_config): + from adcp.client import ADCPClient + + client = ADCPClient(a2a_config) + assert client.context_id is None + + def test_reset_context_clears(self, a2a_config): + from adcp.client import ADCPClient + + client = ADCPClient(a2a_config, context_id="will-clear") + client.reset_context() + assert client.context_id is None + + def test_reset_context_with_new_id(self, a2a_config): + from adcp.client import ADCPClient + + client = ADCPClient(a2a_config) + client.reset_context("fresh-named-session") + assert client.context_id == "fresh-named-session" + + def test_constructor_rejects_context_id_on_non_a2a(self, mcp_config): + from adcp.client import ADCPClient + + with pytest.raises(ValueError, match="only supported for A2A"): + ADCPClient(mcp_config, context_id="nope") + + def test_reset_context_rejects_on_non_a2a(self, mcp_config): + from adcp.client import ADCPClient + + client = ADCPClient(mcp_config) + with pytest.raises(ValueError, match="only supported for A2A"): + client.reset_context("anything") + + def test_context_id_property_returns_none_on_non_a2a(self, mcp_config): + from adcp.client import ADCPClient + + client = ADCPClient(mcp_config) + assert client.context_id is None + + def test_pending_task_id_property_exposes_adapter_state(self, a2a_config): + from adcp.client import ADCPClient + + client = ADCPClient(a2a_config) + assert client.pending_task_id is None + assert isinstance(client.adapter, A2AAdapter) + client.adapter._pending_task_id = "task-mid-flight" + assert client.pending_task_id == "task-mid-flight" + + def test_pending_task_id_returns_none_on_non_a2a(self, mcp_config): + from adcp.client import ADCPClient + + client = ADCPClient(mcp_config) + assert client.pending_task_id is None + + class TestMCPAdapter: """Tests for MCP protocol adapter.""" @@ -532,9 +849,7 @@ async def test_call_tool_text_json_fallback(self, mcp_config): mock_session = AsyncMock() mock_result = MagicMock() # Reference-agent shape: JSON inside TextContent, no structuredContent. - mock_result.content = [ - {"type": "text", "text": '{"status":"completed","products":[]}'} - ] + mock_result.content = [{"type": "text", "text": '{"status":"completed","products":[]}'}] mock_result.structuredContent = None mock_result.isError = False mock_session.call_tool.return_value = mock_result