diff --git a/docs/handler-authoring.md b/docs/handler-authoring.md index 9d8be887c..f22434f1b 100644 --- a/docs/handler-authoring.md +++ b/docs/handler-authoring.md @@ -489,6 +489,102 @@ is the shape you can copy for your own middleware tests. Key pieces: guard. - Run the app's lifespan manually if you're exercising HTTP endpoints. +## Testing hooks — storyboard + header-driven composition + +Two orthogonal test-runtime shapes exist in the wild. Compose them via +the same `context_factory` you already wire for auth: + +**Storyboard-driven** (SDK-native). Sellers register a +`TestControllerStore` and clients invoke the `comply_test_controller` +skill with a scenario name (`force_media_buy_status`, `simulate_delivery`, +etc.). This is the AdCP spec's compliance-test shape and what the +conformance suite exercises. + +**Header-driven** (downstream pattern, e.g. salesagent's +`AdCPTestContext.from_headers(request.headers)`). Clients pass HTTP +headers like `X-AdCP-Test-Mode: slow` and the server adjusts mock +behavior. Useful for scenario-wide state that doesn't fit the +storyboard frame — "every update in this request returns pending", +"this request simulates a delayed ad server". + +Before SDK 3.x you had to pick one. As of #227 both compose through +the existing `context_factory`: + +```python +from contextvars import ContextVar +from starlette.middleware.base import BaseHTTPMiddleware + +from adcp.server import RequestMetadata, ToolContext, create_mcp_server +from adcp.server.test_controller import ( + TestControllerStore, + register_test_controller, +) + +# 1. ContextVar the HTTP middleware populates from request headers. +_test_context: ContextVar[AdCPTestContext | None] = ContextVar( + "test_context", default=None +) + + +# 2. Starlette middleware reads headers into the ContextVar per request. +# Always reset the token in a finally block — otherwise the set +# value leaks into the next request that reuses this asyncio task +# (cross-request state bleed; see PR #232's cross-tenant idempotency +# scoping for the analogous failure mode). +class TestHeaderMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request, call_next): + token = _test_context.set(AdCPTestContext.from_headers(request.headers)) + try: + return await call_next(request) + finally: + _test_context.reset(token) + + +# 3. context_factory snapshots the ContextVar onto ToolContext. +def build_context(meta: RequestMetadata) -> ToolContext: + return ToolContext( + metadata={"test_context": _test_context.get()}, + ) + + +# 4. Store methods that want header-driven state accept `context`. +class MyStore(TestControllerStore): + async def force_media_buy_status( + self, + media_buy_id: str, + status: str, + rejection_reason: str | None = None, + *, + context: ToolContext | None = None, + ) -> dict[str, Any]: + test_ctx = (context.metadata.get("test_context") if context else None) + if test_ctx and test_ctx.slow_ad_server: + status = "pending" # header-driven behavior override + self.media_buys[media_buy_id] = status + return {"previous_state": "active", "current_state": status} + + +# 5. Wire the same factory into both create_mcp_server AND +# register_test_controller. Regular handler methods and +# comply_test_controller both see the same context. +mcp = create_mcp_server(MySeller(), name="my-agent", context_factory=build_context) +register_test_controller(mcp, MyStore(), context_factory=build_context) + +app = mcp.streamable_http_app() +app.add_middleware(TestHeaderMiddleware) +``` + +**Backward compatibility**: stores whose methods don't declare +`context` keep working. The dispatcher inspects the signature and +only passes `context` to methods that opt in. `serve(..., test_controller=...)` +automatically threads `context_factory` through, so no extra wiring is +needed if you use the `serve()` helper. + +**When to pick which**: the storyboard skill is for spec-level +compliance tests (scenarios named by the AdCP test suite). Headers are +for your own mock-ad-server behaviors that sit outside the spec. +Sellers typically need both. + ## What not to build - Don't write per-tool `@mcp.tool()` wrappers. `create_mcp_server()` diff --git a/src/adcp/server/a2a_server.py b/src/adcp/server/a2a_server.py index 9f38611ae..e8a1834d5 100644 --- a/src/adcp/server/a2a_server.py +++ b/src/adcp/server/a2a_server.py @@ -113,12 +113,19 @@ def supported_skills(self) -> list[str]: return list(self._tool_callers.keys()) def _register_test_controller(self, store: TestControllerStore) -> None: - """Register comply_test_controller as a callable skill.""" + """Register comply_test_controller as a callable skill. + + Threads the ToolContext that the A2A executor built for this + dispatch into the store so header-driven test state (populated + by ``context_factory`` from ``ServerCallContext.user`` / + message-metadata headers) composes with the storyboard-driven + ``comply_test_controller`` skill. See #227. + """ async def _call_test_controller( params: dict[str, Any], context: ToolContext | None = None ) -> Any: - return await _handle_test_controller(store, params) + return await _handle_test_controller(store, params, context=context) self._tool_callers["comply_test_controller"] = _call_test_controller diff --git a/src/adcp/server/serve.py b/src/adcp/server/serve.py index 2c2b04006..18b181eed 100644 --- a/src/adcp/server/serve.py +++ b/src/adcp/server/serve.py @@ -397,7 +397,7 @@ def _serve_mcp( if test_controller is not None: from adcp.server.test_controller import register_test_controller - register_test_controller(mcp, test_controller) + register_test_controller(mcp, test_controller, context_factory=context_factory) if transport in ("streamable-http", "sse"): _run_mcp_http(mcp, transport=transport) diff --git a/src/adcp/server/test_controller.py b/src/adcp/server/test_controller.py index 9929a76ae..b52fb64b5 100644 --- a/src/adcp/server/test_controller.py +++ b/src/adcp/server/test_controller.py @@ -17,12 +17,32 @@ async def force_account_status(self, account_id, status): store = MyStore() serve(MySeller(), name="my-agent", test_controller=store) + +Header-driven compatibility: + Store methods MAY accept a keyword-only ``context: ToolContext | None`` + parameter. When the server was configured with a ``context_factory``, + the dispatcher calls the factory per request and threads the + resulting ``ToolContext`` into the store method. This lets sellers + whose test runtime reads request headers (e.g. + ``AdCPTestContext.from_headers(request.headers)``) compose the + storyboard-driven ``comply_test_controller`` skill with their + existing header-driven mock state — populate the test context in + the ``context_factory`` (from a ContextVar set by your HTTP + middleware) and read it off ``context.metadata`` inside the store. + Stores that don't declare ``context`` on a method keep working + unchanged — the dispatcher only passes ``context`` to methods whose + signature accepts it. """ from __future__ import annotations +import inspect import json -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from adcp.server.base import ToolContext + from adcp.server.serve import ContextFactory # Scenario names — must match the AdCP comply_test_controller schema SCENARIOS = [ @@ -71,10 +91,22 @@ class TestControllerStore: and excluded from list_scenarios. Raise TestControllerError for structured error responses. + + Methods MAY declare an optional keyword-only ``context: ToolContext | + None = None`` parameter. When present, the dispatcher threads the + ``ToolContext`` built by the server's ``context_factory`` into the + call — header-driven mock state (e.g. ``AdCPTestContext.from_headers``) + populated in the factory is readable off ``context.metadata``. + Stores that don't declare ``context`` keep working unchanged. """ async def force_creative_status( - self, creative_id: str, status: str, rejection_reason: str | None = None + self, + creative_id: str, + status: str, + rejection_reason: str | None = None, + *, + context: ToolContext | None = None, ) -> dict[str, Any]: """Force a creative to a given status. @@ -83,7 +115,13 @@ async def force_creative_status( """ raise NotImplementedError - async def force_account_status(self, account_id: str, status: str) -> dict[str, Any]: + async def force_account_status( + self, + account_id: str, + status: str, + *, + context: ToolContext | None = None, + ) -> dict[str, Any]: """Force an account to a given status. Returns: @@ -92,7 +130,12 @@ async def force_account_status(self, account_id: str, status: str) -> dict[str, raise NotImplementedError async def force_media_buy_status( - self, media_buy_id: str, status: str, rejection_reason: str | None = None + self, + media_buy_id: str, + status: str, + rejection_reason: str | None = None, + *, + context: ToolContext | None = None, ) -> dict[str, Any]: """Force a media buy to a given status. @@ -102,7 +145,12 @@ async def force_media_buy_status( raise NotImplementedError async def force_session_status( - self, session_id: str, status: str, termination_reason: str | None = None + self, + session_id: str, + status: str, + termination_reason: str | None = None, + *, + context: ToolContext | None = None, ) -> dict[str, Any]: """Force a session to a given status. @@ -118,6 +166,8 @@ async def simulate_delivery( clicks: int | None = None, conversions: int | None = None, reported_spend: dict[str, Any] | None = None, + *, + context: ToolContext | None = None, ) -> dict[str, Any]: """Simulate delivery metrics for a media buy. @@ -131,6 +181,8 @@ async def simulate_budget_spend( spend_percentage: float, account_id: str | None = None, media_buy_id: str | None = None, + *, + context: ToolContext | None = None, ) -> dict[str, Any]: """Simulate budget spend to a percentage. @@ -159,9 +211,7 @@ def _list_scenarios(store: TestControllerStore) -> list[str]: return implemented -def _controller_error( - error: str, detail: str, current_state: str | None = None -) -> dict[str, Any]: +def _controller_error(error: str, detail: str, current_state: str | None = None) -> dict[str, Any]: """Format a test controller error response.""" resp: dict[str, Any] = { "success": False, @@ -173,10 +223,64 @@ def _controller_error( return resp +def _accepts_context_kwarg(method: Any) -> bool: + """True when ``method``'s signature accepts ``context=`` by keyword. + + TestControllerStore subclasses written against the original API + (pre-#227) don't declare ``context``; passing it would raise + ``TypeError`` at the call site. Signature inspection keeps the + dispatcher backward-compatible while letting stores opt in to + header-driven context by simply adding ``context=None`` to their + override. + + Counts as an opt-in: + + - ``*, context: ...`` — keyword-only (the documented recipe). + - ``context: ...`` as a regular positional-or-keyword parameter. + - ``**kwargs`` — accepts any keyword, including ``context``. + + Does **not** count: + + - ``context`` as positional-only (before ``/``) — passing by + keyword raises ``TypeError``. + - ``context`` as ``*args`` (it's never a variadic positional). + + Caveat: ``inspect.signature`` follows ``__wrapped__`` set by + ``@functools.wraps``. A decorator that wraps a legacy store method + and exposes the legacy signature will look "not opted in" even if + the wrapper itself would accept ``context``. This matches the + behavior callers expect — the wrapped callable signature is the + authoritative contract. + """ + try: + sig = inspect.signature(method) + except (TypeError, ValueError): + return False + allowed = { + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + } + for param in sig.parameters.values(): + if param.kind == inspect.Parameter.VAR_KEYWORD: + return True + if param.name == "context" and param.kind in allowed: + return True + return False + + async def _handle_test_controller( - store: TestControllerStore, params: dict[str, Any] + store: TestControllerStore, + params: dict[str, Any], + context: ToolContext | None = None, ) -> dict[str, Any]: - """Dispatch a comply_test_controller request to the store.""" + """Dispatch a comply_test_controller request to the store. + + When ``context`` is supplied and the store's scenario method accepts + a ``context`` keyword, it's passed through — enabling header-driven + mock behavior composed with storyboard-driven compliance testing. + Methods without ``context`` in their signature keep working + unchanged. + """ scenario = params.get("scenario") implemented = _list_scenarios(store) @@ -201,29 +305,37 @@ async def _handle_test_controller( method = getattr(store, scenario) scenario_params = params.get("params", {}) + extra: dict[str, Any] = {} + if context is not None and _accepts_context_kwarg(method): + extra["context"] = context + try: if scenario == "force_creative_status": result = await method( creative_id=scenario_params["creative_id"], status=scenario_params["status"], rejection_reason=scenario_params.get("rejection_reason"), + **extra, ) elif scenario == "force_account_status": result = await method( account_id=scenario_params["account_id"], status=scenario_params["status"], + **extra, ) elif scenario == "force_media_buy_status": result = await method( media_buy_id=scenario_params["media_buy_id"], status=scenario_params["status"], rejection_reason=scenario_params.get("rejection_reason"), + **extra, ) elif scenario == "force_session_status": result = await method( session_id=scenario_params["session_id"], status=scenario_params["status"], termination_reason=scenario_params.get("termination_reason"), + **extra, ) elif scenario == "simulate_delivery": result = await method( @@ -232,12 +344,14 @@ async def _handle_test_controller( clicks=scenario_params.get("clicks"), conversions=scenario_params.get("conversions"), reported_spend=scenario_params.get("reported_spend"), + **extra, ) elif scenario == "simulate_budget_spend": result = await method( spend_percentage=scenario_params["spend_percentage"], account_id=scenario_params.get("account_id"), media_buy_id=scenario_params.get("media_buy_id"), + **extra, ) else: return _controller_error("UNKNOWN_SCENARIO", f"Unknown scenario: {scenario}") @@ -260,7 +374,12 @@ async def _handle_test_controller( return dict(result) -def register_test_controller(mcp: Any, store: TestControllerStore) -> None: +def register_test_controller( + mcp: Any, + store: TestControllerStore, + *, + context_factory: ContextFactory | None = None, +) -> None: """Register the comply_test_controller tool on an MCP server. This is the Python equivalent of the JS SDK's registerTestController(). @@ -269,6 +388,15 @@ def register_test_controller(mcp: Any, store: TestControllerStore) -> None: Args: mcp: A FastMCP server instance. store: Your TestControllerStore implementation. + context_factory: Optional ``ContextFactory`` invoked per call to + build a :class:`ToolContext`. When set, the context is + threaded into store methods that declare a ``context`` + keyword — which is how sellers whose test runtime reads + request headers (``AdCPTestContext.from_headers``) combine + header-driven mock state with the storyboard-driven + ``comply_test_controller`` skill. Wire the same factory you + pass to :func:`create_mcp_server` so both paths see the + same per-request context. Example: from adcp.server.test_controller import TestControllerStore, register_test_controller @@ -288,8 +416,20 @@ async def force_account_status(self, account_id, status): from mcp.server.fastmcp.utilities.func_metadata import ArgModelBase, FuncMetadata from pydantic import ConfigDict + from adcp.server.base import ToolContext as _ToolContext + from adcp.server.serve import RequestMetadata as _RequestMetadata + async def comply_test_controller(**kwargs: Any) -> str: - result = await _handle_test_controller(store, kwargs) + context: _ToolContext | None = None + if context_factory is not None: + meta = _RequestMetadata(tool_name="comply_test_controller", transport="mcp") + context = context_factory(meta) + if not isinstance(context, _ToolContext): + raise TypeError( + "context_factory for comply_test_controller returned " + f"{type(context).__name__}, not a ToolContext instance" + ) + result = await _handle_test_controller(store, kwargs, context=context) return json.dumps(result) tool = Tool.from_function( diff --git a/tests/test_test_controller_context.py b/tests/test_test_controller_context.py new file mode 100644 index 000000000..f45ba11f1 --- /dev/null +++ b/tests/test_test_controller_context.py @@ -0,0 +1,364 @@ +"""Header-driven test context + TestControllerStore composition — closes #227. + +Downstream (salesagent) drives mock-ad-server behavior from request +headers via ``AdCPTestContext.from_headers(request.headers)``. The SDK's +:class:`~adcp.server.test_controller.TestControllerStore` is +storyboard-shaped — scenarios dispatch via the ``comply_test_controller`` +skill. Before #227, there was no way to read HTTP headers inside a +``TestControllerStore`` method, so sellers who adopted the SDK's +storyboard testing lost their header-driven test scaffolding. + +Fix: ``register_test_controller`` accepts the same ``context_factory`` +as ``create_mcp_server``. The dispatcher builds a ``ToolContext`` per +call and threads it into store methods that declare a ``context`` +keyword. Sellers populate test state in the factory (typically by +reading a ContextVar set by their HTTP middleware from request headers) +and read it off ``context.metadata`` inside their store. + +Backward-compatibility contract: stores whose methods do NOT declare +``context`` MUST keep working — the dispatcher inspects the signature +and only passes ``context`` when the store opts in. +""" + +from __future__ import annotations + +from contextvars import ContextVar +from typing import Any + +import pytest + +from adcp.server import ( + ADCPHandler, + RequestMetadata, + ToolContext, + create_mcp_server, +) +from adcp.server.test_controller import ( + TestControllerStore, + _accepts_context_kwarg, + _handle_test_controller, + register_test_controller, +) + +# --------------------------------------------------------------------------- +# Minimal handler — TestControllerStore tests need an MCP server but not +# an interesting ADCPHandler. +# --------------------------------------------------------------------------- + + +class _MinimalHandler(ADCPHandler): + _agent_type = "test" + + async def get_adcp_capabilities(self, params, context=None): + return {"adcp": {"major_versions": [3]}} + + +# --------------------------------------------------------------------------- +# _accepts_context_kwarg — the signature-inspection helper +# --------------------------------------------------------------------------- + + +def test_accepts_context_kwarg_detects_keyword_only_param(): + """The documented opt-in pattern — ``*, context=None`` — must be + detected.""" + + async def fn(self, account_id: str, *, context: Any = None) -> dict[str, Any]: + return {} + + assert _accepts_context_kwarg(fn) is True + + +def test_accepts_context_kwarg_detects_var_keyword(): + """Stores that catch arbitrary kwargs (``**kwargs``) count as + accepting ``context`` — passing it won't raise TypeError.""" + + async def fn(self, account_id: str, **kwargs: Any) -> dict[str, Any]: + return {} + + assert _accepts_context_kwarg(fn) is True + + +def test_accepts_context_kwarg_rejects_methods_without_context(): + """Pre-#227 TestControllerStore overrides don't have ``context`` in + their signature — the dispatcher must NOT pass it, or the call + raises TypeError.""" + + async def fn(self, account_id: str, status: str) -> dict[str, Any]: + return {} + + assert _accepts_context_kwarg(fn) is False + + +def test_accepts_context_kwarg_rejects_positional_only_context(): + """``context`` before a ``/`` is positional-only — the dispatcher + passes ``context=ctx`` by keyword, so a positional-only declaration + would raise TypeError. Must be treated as not-opted-in.""" + import textwrap + + # positional-only syntax (PEP 570) works in Py 3.8+; use exec so + # the file still parses cleanly without introducing a sigil. + ns: dict[str, Any] = {} + exec( + textwrap.dedent( + """ + async def fn(self, context, /, account_id, status): + return {} + """ + ), + ns, + ) + assert _accepts_context_kwarg(ns["fn"]) is False + + +def test_accepts_context_kwarg_follows_functools_wraps(): + """``inspect.signature`` follows ``__wrapped__``. A decorator using + ``@functools.wraps`` exposes the wrapped signature — which is the + authoritative contract. Verify the detection pipeline respects it + so operators can reason about which methods opt in.""" + import functools + + async def legacy(self, account_id: str, status: str) -> dict[str, Any]: + return {} + + @functools.wraps(legacy) + async def wrapper(self, *args, **kwargs): + return await legacy(self, *args, **kwargs) + + # Wrapper preserves the legacy signature — no context visible. + assert _accepts_context_kwarg(wrapper) is False + + async def modern(self, account_id: str, status: str, *, context: Any = None) -> dict[str, Any]: + return {} + + @functools.wraps(modern) + async def modern_wrapper(self, *args, **kwargs): + return await modern(self, *args, **kwargs) + + # Wrapped signature preserves the kwarg — opt-in survives. + assert _accepts_context_kwarg(modern_wrapper) is True + + +def test_dispatcher_finds_override_on_intermediate_base(): + """A store may compose behavior across an inheritance chain + (``MyStore(Mixin, TestControllerStore)``). The dispatcher must find + the override wherever it lives in the MRO — and the context-kwarg + detection must work on the bound method even when the implementing + class is an intermediate base, not the leaf.""" + + class _Mixin: + async def force_media_buy_status( + self, + media_buy_id: str, + status: str, + rejection_reason: str | None = None, + *, + context: ToolContext | None = None, + ) -> dict[str, Any]: + return { + "previous_state": "active", + "current_state": status, + "from_mixin": True, + "saw_context": context is not None, + } + + class _Store(_Mixin, TestControllerStore): + pass + + ctx = ToolContext(caller_identity="p-x", metadata={"test_context": {"x": 1}}) + import asyncio + + result = asyncio.run( + _handle_test_controller( + _Store(), + { + "scenario": "force_media_buy_status", + "params": {"media_buy_id": "mb-1", "status": "paused"}, + }, + context=ctx, + ) + ) + assert result["success"] is True + assert result["from_mixin"] is True + assert result["saw_context"] is True + + +# --------------------------------------------------------------------------- +# Dispatcher threads context into store methods that opt in +# --------------------------------------------------------------------------- + + +async def test_store_with_context_kwarg_receives_the_context(): + """The primary #227 scenario: a store method that accepts ``context`` + receives the ToolContext the caller passed into the dispatcher.""" + received: list[ToolContext | None] = [] + + class _Store(TestControllerStore): + async def force_account_status( + self, + account_id: str, + status: str, + *, + context: ToolContext | None = None, + ) -> dict[str, Any]: + received.append(context) + return {"previous_state": "active", "current_state": status} + + ctx = ToolContext( + caller_identity="p-1", + tenant_id="t-1", + metadata={"test_context": {"env": "ci", "slow_ad_server": True}}, + ) + result = await _handle_test_controller( + _Store(), + { + "scenario": "force_account_status", + "params": {"account_id": "acc-1", "status": "suspended"}, + }, + context=ctx, + ) + + assert result["success"] is True + assert result["current_state"] == "suspended" + assert len(received) == 1 + assert received[0] is ctx + # The seller's header-driven state threads through verbatim. + assert received[0].metadata["test_context"]["env"] == "ci" + + +async def test_legacy_store_without_context_kwarg_still_works(): + """Backward-compat contract. Stores written before #227 don't + declare ``context``; the dispatcher must NOT pass it or the call + raises TypeError. This test fails fast if signature detection + regresses.""" + + class _LegacyStore(TestControllerStore): + # Original API shape — no context kwarg. + async def force_account_status(self, account_id: str, status: str) -> dict[str, Any]: + return {"previous_state": "active", "current_state": status} + + ctx = ToolContext(caller_identity="p-legacy") + result = await _handle_test_controller( + _LegacyStore(), + { + "scenario": "force_account_status", + "params": {"account_id": "acc-1", "status": "suspended"}, + }, + context=ctx, + ) + + assert result["success"] is True + assert result["current_state"] == "suspended" + + +async def test_context_not_passed_when_none(): + """If the caller didn't supply a context (``context=None``), don't + shove None into a store method that might not have the kwarg. The + call should succeed without any context-related machinery firing.""" + + class _Store(TestControllerStore): + # Legacy-shape method — no context in signature. + async def force_creative_status( + self, creative_id: str, status: str, rejection_reason: str | None = None + ) -> dict[str, Any]: + return {"previous_state": "pending", "current_state": status} + + result = await _handle_test_controller( + _Store(), + { + "scenario": "force_creative_status", + "params": {"creative_id": "cr-1", "status": "approved"}, + }, + context=None, + ) + + assert result["success"] is True + + +# --------------------------------------------------------------------------- +# End-to-end: context_factory + TestControllerStore via FastMCP registration +# --------------------------------------------------------------------------- + + +async def test_register_test_controller_threads_context_factory(): + """Integration: ``register_test_controller`` with a + ``context_factory`` matches the pattern sellers use to wire HTTP + middleware → ContextVars → ToolContext. The factory is called per + request; the store reads header-derived state off the context.""" + # A ContextVar the downstream HTTP middleware would populate from + # request headers. This test simulates the middleware having already + # run by setting the ContextVar directly. + test_state_var: ContextVar[dict[str, Any] | None] = ContextVar("_test_state", default=None) + received: list[ToolContext | None] = [] + + class _Store(TestControllerStore): + async def force_account_status( + self, + account_id: str, + status: str, + *, + context: ToolContext | None = None, + ) -> dict[str, Any]: + received.append(context) + return {"previous_state": "active", "current_state": status} + + def build_context(meta: RequestMetadata) -> ToolContext: + return ToolContext( + metadata={ + "tool_name": meta.tool_name, + "test_context": test_state_var.get(), + }, + ) + + mcp = create_mcp_server(_MinimalHandler(), name="test-agent") + register_test_controller(mcp, _Store(), context_factory=build_context) + + # Simulate the HTTP middleware populating the ContextVar from headers. + test_state_var.set({"env": "ci", "slow_ad_server": True}) + + tool = mcp._tool_manager._tools["comply_test_controller"] + # FastMCP's tool wrapper takes the function args as kwargs. + fn = tool.fn # type: ignore[attr-defined] + result_json = await fn( + scenario="force_account_status", + params={"account_id": "acc-1", "status": "suspended"}, + ) + + import json + + result = json.loads(result_json) + assert result["success"] is True + assert result["current_state"] == "suspended" + # The factory ran, built a ToolContext, and the store saw the header- + # derived test state verbatim. + assert len(received) == 1 + assert received[0] is not None + assert received[0].metadata["test_context"] == { + "env": "ci", + "slow_ad_server": True, + } + # And the tool name was populated by RequestMetadata. + assert received[0].metadata["tool_name"] == "comply_test_controller" + + +async def test_register_test_controller_rejects_non_toolcontext_from_factory(): + """Guard rail — a factory that returns a dict instead of a + ToolContext fails loudly at call time, not deep inside the store.""" + + class _Store(TestControllerStore): + async def force_account_status(self, account_id: str, status: str) -> dict[str, Any]: + return {"previous_state": "active", "current_state": status} + + def bad_factory(meta: RequestMetadata) -> Any: + return {"not": "a ToolContext"} + + mcp = create_mcp_server(_MinimalHandler(), name="test-agent") + register_test_controller(mcp, _Store(), context_factory=bad_factory) + + tool = mcp._tool_manager._tools["comply_test_controller"] + fn = tool.fn # type: ignore[attr-defined] + + with pytest.raises(TypeError, match="not a ToolContext"): + await fn( + scenario="force_account_status", + params={"account_id": "acc-1", "status": "suspended"}, + )