diff --git a/src/adcp/decisioning/serve.py b/src/adcp/decisioning/serve.py index bf36e86c9..8703006ea 100644 --- a/src/adcp/decisioning/serve.py +++ b/src/adcp/decisioning/serve.py @@ -373,6 +373,7 @@ def serve( advertise_all: bool = False, mock_ad_server: Any | None = None, enable_debug_endpoints: bool = False, + pre_validation_hooks: dict[str, Any] | None = None, **serve_kwargs: Any, ) -> None: """One-call wrapper — build the handler and serve over MCP. @@ -435,6 +436,25 @@ def serve( responses="strict")`` to enable schema-driven request/response validation against the bundled AdCP JSON schemas — sellers who want their server to enforce wire conformance turn it on here. + :param pre_validation_hooks: Optional dict mapping AdCP tool name to + a ``(tool_name, raw_args) -> raw_args`` callable. The hook runs + on the raw wire dict **before** schema + Pydantic validation — + use it to apply spec-mandated defaults for pre-v3 buyers that + omit required fields. Example:: + + serve( + router, + pre_validation_hooks={ + "get_products": lambda n, a: { + **a, "buying_mode": a.get("buying_mode", "brief") + }, + }, + ) + + Hook exceptions surface as ``INVALID_REQUEST`` on the wire. + The hook receives a shallow copy of the wire args, so it may rebind + top-level keys in place or return a new dict. Nested values are + shared with the wire input; mutating them alters the context echo. """ # Local import to avoid a circular at module-load time. Adopter # serves never run during foundation imports anyway. 
@@ -504,6 +524,8 @@ def serve( server_name = name or type(platform).__name__ debug_traffic_source = mock_ad_server.get_traffic if mock_ad_server is not None else None + if pre_validation_hooks is not None: + serve_kwargs["pre_validation_hooks"] = pre_validation_hooks _adcp_serve( handler, name=server_name, diff --git a/src/adcp/server/a2a_server.py b/src/adcp/server/a2a_server.py index c586c8d5b..6c61cfd41 100644 --- a/src/adcp/server/a2a_server.py +++ b/src/adcp/server/a2a_server.py @@ -143,6 +143,7 @@ def __init__( message_parser: MessageParser | None = None, advertise_all: bool = False, validation: ValidationHookConfig | None = SERVER_DEFAULT_VALIDATION, + pre_validation_hooks: dict[str, Any] | None = None, test_controller_account_resolver: Any | None = None, ) -> None: self._handler = handler @@ -169,7 +170,10 @@ def __init__( name = tool_def["name"] if name == "comply_test_controller" and test_controller is None: continue - self._tool_callers[name] = create_tool_caller(handler, name, validation=validation) + hook = (pre_validation_hooks or {}).get(name) + self._tool_callers[name] = create_tool_caller( + handler, name, validation=validation, pre_validation_hook=hook + ) if test_controller is not None: self._register_test_controller(test_controller) @@ -758,6 +762,7 @@ def create_a2a_server( message_parser: MessageParser | None = None, advertise_all: bool = False, validation: ValidationHookConfig | None = SERVER_DEFAULT_VALIDATION, + pre_validation_hooks: dict[str, Any] | None = None, context_builder: Any | None = None, auth: BearerTokenAuth | None = None, public_url: str | None = None, @@ -884,6 +889,7 @@ def create_a2a_server( message_parser=message_parser, advertise_all=advertise_all, validation=validation, + pre_validation_hooks=pre_validation_hooks, test_controller_account_resolver=test_controller_account_resolver, ) diff --git a/src/adcp/server/mcp_tools.py b/src/adcp/server/mcp_tools.py index 51e84e679..efa786e52 100644 --- a/src/adcp/server/mcp_tools.py 
+++ b/src/adcp/server/mcp_tools.py @@ -1870,6 +1870,7 @@ def create_tool_caller( method_name: str, *, validation: ValidationHookConfig | None = None, + pre_validation_hook: Callable[[str, dict[str, Any]], dict[str, Any]] | None = None, ) -> Callable[..., Any]: """Create a tool caller function for an ADCP handler method. @@ -1899,12 +1900,36 @@ def create_tool_caller( server validation is a deliberate opt-in for authors who want dispatcher-level enforcement. + **Pre-validation hook (issue #614).** When ``pre_validation_hook`` is + supplied, it is called with ``(tool_name, shallow_copy_of_args)`` and + must return a ``dict`` that replaces the wire args before schema + validation and Pydantic ``model_validate`` run. The framework passes + a shallow copy of the incoming params dict, so the hook may rebind + top-level keys in place or return a brand-new dict. Nested values are + shared with the wire params, so mutating them in place also alters + what context echo reports as buyer input. Use this to apply + spec-mandated defaults for pre-v3 buyers that omit required fields + (e.g. ``buying_mode``, ``format_id`` shape coercion, ``asset_type`` + inference). The hook runs on every call; keep it fast. + Exceptions from the hook surface as ``INVALID_REQUEST`` — do not raise + for missing-but-defaultable fields, only for structurally unusable args. + + .. note:: + For the specific case of buyers omitting ``account``, see issue + #623 ("Typed dispatcher rejects valid request when ``account`` is + omitted") — that will be the canonical spec-level fix for that + field. Once #623 lands you can drop any ``account`` placeholder + hook entry. + Args: handler: The ADCP handler instance method_name: Name of the method to call validation: Optional :class:`ValidationHookConfig` with per-side modes (``strict`` / ``warn`` / ``off``). Omitting it disables server-side schema validation entirely. 
+ pre_validation_hook: Optional callable ``(tool_name, args) -> args`` + invoked on the raw wire dict before schema + Pydantic validation. + See the **Pre-validation hook** section above. Returns: Async callable ``call_tool(params, context=None)``. The ``context`` @@ -1938,7 +1963,22 @@ def create_tool_caller( async def call_tool(params: dict[str, Any], context: ToolContext | None = None) -> Any: ctx = context if context is not None else ToolContext() - raw_params = params # Preserve the original dict for context echo. + + raw_params = params # Preserve original wire params for context echo. + + if pre_validation_hook is not None: + try: + params = pre_validation_hook(method_name, dict(params)) + except Exception as exc: + raise ADCPTaskError( + operation=method_name, + errors=[ + Error( + code="INVALID_REQUEST", + message=f"pre_validation_hook raised {type(exc).__name__}: {exc}", + ) + ], + ) from exc if request_mode is not None and request_mode != "off": outcome = validate_request(method_name, params) @@ -2069,6 +2109,7 @@ def __init__( *, advertise_all: bool = False, validation: ValidationHookConfig | None = None, + pre_validation_hooks: dict[str, Callable[[str, dict[str, Any]], dict[str, Any]]] | None = None, ): """Create tool set from handler. @@ -2081,6 +2122,9 @@ def __init__( (override-filtered advertisement). validation: Opt-in schema validation config applied to every tool caller. See :func:`create_tool_caller`. + pre_validation_hooks: Optional dict mapping tool name to a + ``(tool_name, args) -> args`` callable. Applied before + schema + Pydantic validation. See :func:`create_tool_caller`. 
""" self.handler = handler self._filtered_definitions = get_tools_for_handler(handler, advertise_all=advertise_all) @@ -2089,7 +2133,10 @@ def __init__( # Create tool callers only for filtered tools for tool_def in self._filtered_definitions: name = tool_def["name"] - self._tools[name] = create_tool_caller(handler, name, validation=validation) + hook = (pre_validation_hooks or {}).get(name) + self._tools[name] = create_tool_caller( + handler, name, validation=validation, pre_validation_hook=hook + ) @property def tool_definitions(self) -> list[dict[str, Any]]: @@ -2123,6 +2170,7 @@ def create_mcp_tools( *, advertise_all: bool = False, validation: ValidationHookConfig | None = None, + pre_validation_hooks: dict[str, Callable[[str, dict[str, Any]], dict[str, Any]]] | None = None, ) -> MCPToolSet: """Create MCP tools from an ADCP handler. @@ -2157,8 +2205,16 @@ async def call_tool(name: str, arguments: dict): every tool caller validates requests and responses against the bundled AdCP JSON schemas. See :func:`create_tool_caller` for mode semantics. + pre_validation_hooks: Optional dict mapping tool name to a + ``(tool_name, args) -> args`` callable. Applied before schema + + Pydantic validation. See :func:`create_tool_caller`. Returns: MCPToolSet with tool definitions and handlers. 
""" - return MCPToolSet(handler, advertise_all=advertise_all, validation=validation) + return MCPToolSet( + handler, + advertise_all=advertise_all, + validation=validation, + pre_validation_hooks=pre_validation_hooks, + ) diff --git a/src/adcp/server/serve.py b/src/adcp/server/serve.py index 80478191a..3bbfecaaa 100644 --- a/src/adcp/server/serve.py +++ b/src/adcp/server/serve.py @@ -134,6 +134,7 @@ class ServeConfig: advertise_all: bool = False max_request_size: int | None = None validation: ValidationHookConfig | None = None + pre_validation_hooks: dict[str, Callable[..., Any]] | None = None # --- Discovery manifest --- base_url: str | None = None @@ -525,6 +526,7 @@ def serve( max_request_size: int | None = None, streaming_responses: bool = False, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, + pre_validation_hooks: dict[str, Callable[..., Any]] | None = None, enable_debug_endpoints: bool = False, debug_traffic_source: Callable[[], dict[str, int]] | None = None, base_url: str | None = None, @@ -772,6 +774,7 @@ async def force_account_status(self, account_id, status): max_request_size = config.max_request_size streaming_responses = config.streaming_responses validation = config.validation + pre_validation_hooks = config.pre_validation_hooks enable_debug_endpoints = config.enable_debug_endpoints debug_traffic_source = config.debug_traffic_source base_url = config.base_url @@ -815,6 +818,7 @@ async def force_account_status(self, account_id, status): advertise_all=advertise_all, max_request_size=max_request_size, validation=validation, + pre_validation_hooks=pre_validation_hooks, base_url=base_url, specialisms=specialisms, description=description, @@ -838,6 +842,7 @@ async def force_account_status(self, account_id, status): max_request_size=max_request_size, streaming_responses=streaming_responses, validation=validation, + pre_validation_hooks=pre_validation_hooks, base_url=base_url, specialisms=specialisms, description=description, @@ -865,6 +870,7 
@@ async def force_account_status(self, account_id, status): max_request_size=max_request_size, streaming_responses=streaming_responses, validation=validation, + pre_validation_hooks=pre_validation_hooks, base_url=base_url, specialisms=specialisms, description=description, @@ -1239,6 +1245,7 @@ def _serve_mcp( max_request_size: int | None = None, streaming_responses: bool = False, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, + pre_validation_hooks: dict[str, Callable[..., Any]] | None = None, base_url: str | None = None, specialisms: list[str] | None = None, description: str | None = None, @@ -1260,6 +1267,7 @@ def _serve_mcp( advertise_all=advertise_all, streaming_responses=streaming_responses, validation=validation, + pre_validation_hooks=pre_validation_hooks, allowed_hosts=allowed_hosts, allowed_origins=allowed_origins, enable_dns_rebinding_protection=enable_dns_rebinding_protection, @@ -1399,6 +1407,7 @@ def _serve_a2a( advertise_all: bool = False, max_request_size: int | None = None, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, + pre_validation_hooks: dict[str, Callable[..., Any]] | None = None, base_url: str | None = None, specialisms: list[str] | None = None, description: str | None = None, @@ -1427,6 +1436,7 @@ def _serve_a2a( message_parser=message_parser, advertise_all=advertise_all, validation=validation, + pre_validation_hooks=pre_validation_hooks, auth=auth, public_url=public_url, ) @@ -1481,6 +1491,7 @@ def _build_mcp_and_a2a_app( max_request_size: int | None = None, streaming_responses: bool = False, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, + pre_validation_hooks: dict[str, Callable[..., Any]] | None = None, base_url: str | None = None, specialisms: list[str] | None = None, description: str | None = None, @@ -1523,6 +1534,7 @@ def _build_mcp_and_a2a_app( advertise_all=advertise_all, streaming_responses=streaming_responses, validation=validation, + pre_validation_hooks=pre_validation_hooks, 
allowed_hosts=allowed_hosts, allowed_origins=allowed_origins, enable_dns_rebinding_protection=enable_dns_rebinding_protection, @@ -1576,6 +1588,7 @@ def _build_mcp_and_a2a_app( message_parser=message_parser, advertise_all=advertise_all, validation=validation, + pre_validation_hooks=pre_validation_hooks, auth=auth, public_url=public_url, ) @@ -1659,6 +1672,7 @@ def _serve_mcp_and_a2a( max_request_size: int | None = None, streaming_responses: bool = False, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, + pre_validation_hooks: dict[str, Callable[..., Any]] | None = None, base_url: str | None = None, specialisms: list[str] | None = None, description: str | None = None, @@ -1706,6 +1720,7 @@ def _serve_mcp_and_a2a( max_request_size=max_request_size, streaming_responses=streaming_responses, validation=validation, + pre_validation_hooks=pre_validation_hooks, base_url=base_url, specialisms=specialisms, description=description, @@ -1787,6 +1802,7 @@ def create_mcp_server( advertise_all: bool = False, streaming_responses: bool = False, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, + pre_validation_hooks: dict[str, Callable[..., Any]] | None = None, allowed_hosts: Sequence[str] | None = None, allowed_origins: Sequence[str] | None = None, enable_dns_rebinding_protection: bool | None = None, @@ -1948,6 +1964,7 @@ def create_mcp_server( middleware=middleware, advertise_all=advertise_all, validation=validation, + pre_validation_hooks=pre_validation_hooks, ) return mcp @@ -1961,6 +1978,7 @@ def _register_handler_tools( middleware: Sequence[SkillMiddleware] | None = None, advertise_all: bool = False, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, + pre_validation_hooks: dict[str, Callable[..., Any]] | None = None, ) -> None: """Register all ADCP tools from a handler onto a FastMCP server.""" # Freeze middleware ordering at registration time. 
Tuple both guards @@ -1980,7 +1998,10 @@ def _register_handler_tools( description = tool_def.get("description", "") input_schema = tool_def.get("inputSchema", {"type": "object", "properties": {}}) output_schema = tool_def.get("outputSchema") - caller = create_tool_caller(handler, tool_name, validation=validation) + hook = (pre_validation_hooks or {}).get(tool_name) + caller = create_tool_caller( + handler, tool_name, validation=validation, pre_validation_hook=hook + ) _register_tool( mcp, tool_name, diff --git a/tests/test_pre_validation_hooks.py b/tests/test_pre_validation_hooks.py new file mode 100644 index 000000000..e9d7e66bf --- /dev/null +++ b/tests/test_pre_validation_hooks.py @@ -0,0 +1,232 @@ +"""pre_validation_hooks — wire dict rewriting before schema + Pydantic validation. + +Tests that: +- The hook is called with (tool_name, raw_dict) before validation. +- A missing required field supplied by the hook passes model_validate. +- A hook that raises surfaces as INVALID_REQUEST, not INTERNAL_ERROR. +- pre_validation_hooks=None (default) is a no-op (hot path unchanged). +- A hook for tool X is not called when tool Y is dispatched. +- Hook runs before validate_request in strict validation mode. +- In-place mutation of hook args is safe (framework passes a shallow copy). 
+""" + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock + +import pytest + +from adcp.exceptions import ADCPTaskError +from adcp.server.base import ADCPHandler, ToolContext +from adcp.server.mcp_tools import create_tool_caller + + +class _MinimalHandler(ADCPHandler[Any]): + """Passes params straight through as the return value.""" + + async def get_products(self, params: dict[str, Any], ctx: ToolContext) -> dict[str, Any]: + return {"params_received": dict(params)} + + +class _TypedHandler(ADCPHandler[Any]): + """Handler that validates get_products against the real Pydantic model.""" + + async def get_products(self, params: Any, ctx: ToolContext) -> dict[str, Any]: + return {"buying_mode": getattr(params, "buying_mode", None)} + + +# --------------------------------------------------------------------------- +# Basic hook mechanics +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_hook_is_called_with_tool_name_and_args() -> None: + """Hook receives (tool_name, raw_dict) and its return value replaces the args.""" + calls: list[tuple[str, dict[str, Any]]] = [] + + def my_hook(tool_name: str, args: dict[str, Any]) -> dict[str, Any]: + calls.append((tool_name, dict(args))) + return {**args, "injected": True} + + handler = _MinimalHandler() + caller = create_tool_caller(handler, "get_products", pre_validation_hook=my_hook) + + result = await caller({"foo": "bar"}) + assert len(calls) == 1 + assert calls[0] == ("get_products", {"foo": "bar"}) + assert result["params_received"]["injected"] is True + + +@pytest.mark.asyncio +async def test_hook_none_is_noop() -> None: + """pre_validation_hook=None (default) does not change dispatch behaviour.""" + handler = _MinimalHandler() + caller = create_tool_caller(handler, "get_products", pre_validation_hook=None) + result = await caller({"key": "val"}) + assert result["params_received"] == {"key": "val"} + + 
+@pytest.mark.asyncio +async def test_hook_not_called_for_other_tool() -> None: + """A hook registered for get_products is not invoked when another tool is dispatched.""" + calls: list[str] = [] + + def hook(tool_name: str, args: dict[str, Any]) -> dict[str, Any]: + calls.append(tool_name) + return args + + handler = _MinimalHandler() + # caller for a different tool, built with no hook (as a per-tool lookup miss would) + caller_other = create_tool_caller(handler, "get_adcp_capabilities", pre_validation_hook=None) + await caller_other({}) + assert calls == [], "hook was called for the wrong tool" + + +# --------------------------------------------------------------------------- +# Missing-required-field scenario: the primary use case +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_hook_supplies_missing_required_field() -> None: + """A hook that fills in buying_mode allows model_validate to succeed. + + Without the hook, a bare {} request would fail model_validate on + GetProductsRequest because buying_mode is required in 4.4+ schemas. + This test uses the dict-typed handler path to avoid importing schema + models directly. 
+ """ + called_with: list[dict[str, Any]] = [] + + def buying_mode_default(tool_name: str, args: dict[str, Any]) -> dict[str, Any]: + called_with.append(args) + return {**args, "buying_mode": args.get("buying_mode", "brief")} + + handler = _MinimalHandler() + caller = create_tool_caller(handler, "get_products", pre_validation_hook=buying_mode_default) + + result = await caller({}) + assert called_with == [{}] + assert result["params_received"]["buying_mode"] == "brief" + + +# --------------------------------------------------------------------------- +# Hook exception handling +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_hook_exception_surfaces_as_invalid_request() -> None: + """A hook that raises must surface as INVALID_REQUEST, not INTERNAL_ERROR.""" + + def bad_hook(tool_name: str, args: dict[str, Any]) -> dict[str, Any]: + raise ValueError("unsupported format_id shape") + + handler = _MinimalHandler() + caller = create_tool_caller(handler, "get_products", pre_validation_hook=bad_hook) + + with pytest.raises(ADCPTaskError) as exc_info: + await caller({"something": "here"}) + + errors = exc_info.value.errors + assert errors, "ADCPTaskError must carry at least one error" + assert errors[0].code == "INVALID_REQUEST" + assert "ValueError" in errors[0].message + + +# --------------------------------------------------------------------------- +# Hook must not mutate raw_params (context-echo path) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_hook_does_not_pollute_context_echo() -> None: + """raw_params must snapshot the original wire dict BEFORE the hook runs. + + inject_context echoes the wire ``context`` field from raw_params back into + the response. If raw_params were assigned after the hook, a hook that + returns a new dict (dropping ``context``) would silently suppress the echo. 
+ Conversely, a hook that adds keys would cause server-injected fields to + appear in the echo as if the buyer sent them. + + We exercise both directions: + - A hook that strips all fields and adds "server_default" (no context key + in its return) still produces context echo from the original wire params. + - The handler result carries hook-modified fields, confirming the hook ran. + """ + wire_context = {"correlation_id": "req-abc"} + wire_args = {"buyer_field": "x", "context": wire_context} + + def stripping_hook(tool_name: str, args: dict[str, Any]) -> dict[str, Any]: + # Returns a brand-new dict — deliberately omits "context" + return {"server_default": "y"} + + handler = _MinimalHandler() + caller = create_tool_caller(handler, "get_products", pre_validation_hook=stripping_hook) + result = await caller(dict(wire_args)) + + # Hook ran: handler received hook-modified params, not original + assert result["params_received"] == {"server_default": "y"} + # Context echo used raw_params (pre-hook snapshot), not hook return + assert result.get("context") == wire_context + + +@pytest.mark.asyncio +async def test_in_place_mutation_is_safe_for_context_echo() -> None: + """Hook that mutates its argument in-place must not corrupt context echo. + + The framework passes a shallow copy to the hook, so in-place mutation + of the hook argument leaves the original wire params untouched for the + context-echo path. This test exercises the ``args["key"] = val; return args`` + pattern that the original docstring labelled a "bug". 
+ """ + wire_context = {"correlation_id": "req-xyz"} + wire_args = {"buyer_field": "original", "context": wire_context} + + def mutating_hook(tool_name: str, args: dict[str, Any]) -> dict[str, Any]: + args["server_default"] = "injected" + del args["buyer_field"] + return args + + handler = _MinimalHandler() + caller = create_tool_caller(handler, "get_products", pre_validation_hook=mutating_hook) + result = await caller(dict(wire_args)) + + assert result["params_received"].get("server_default") == "injected" + assert "buyer_field" not in result["params_received"] + assert result.get("context") == wire_context + + +# --------------------------------------------------------------------------- +# MCPToolSet threading +# --------------------------------------------------------------------------- + + +def test_mcp_tool_set_threads_hook_to_tool_caller() -> None: + """MCPToolSet must forward pre_validation_hooks to create_tool_caller.""" + import importlib + from unittest.mock import patch + + _serve_mod = importlib.import_module("adcp.server.mcp_tools") + + handler = _MinimalHandler() + captured_hooks: list[Any] = [] + real = _serve_mod.create_tool_caller + + def spy(h: Any, name: str, **kw: Any) -> Any: + captured_hooks.append(kw.get("pre_validation_hook")) + return real(h, name, **kw) + + my_hook = lambda n, a: a # noqa: E731 + hooks = {"get_products": my_hook} + + with patch.object(_serve_mod, "create_tool_caller", side_effect=spy): + from adcp.server.mcp_tools import MCPToolSet + + MCPToolSet(handler, pre_validation_hooks=hooks) + + assert any(h is my_hook for h in captured_hooks), ( + "my_hook was not forwarded to create_tool_caller for get_products" + )