Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/adcp/decisioning/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,7 @@ def serve(
advertise_all: bool = False,
mock_ad_server: Any | None = None,
enable_debug_endpoints: bool = False,
pre_validation_hooks: dict[str, Any] | None = None,
**serve_kwargs: Any,
) -> None:
"""One-call wrapper — build the handler and serve over MCP.
Expand Down Expand Up @@ -435,6 +436,25 @@ def serve(
responses="strict")`` to enable schema-driven request/response
validation against the bundled AdCP JSON schemas — sellers who
want their server to enforce wire conformance turn it on here.
:param pre_validation_hooks: Optional dict mapping AdCP tool name to
a ``(tool_name, raw_args) -> raw_args`` callable. The hook runs
on the raw wire dict **before** schema + Pydantic validation —
use it to apply spec-mandated defaults for pre-v3 buyers that
omit required fields. Example::

serve(
router,
pre_validation_hooks={
"get_products": lambda n, a: {
**a, "buying_mode": a.get("buying_mode", "brief")
},
},
)

Hook exceptions surface as ``INVALID_REQUEST`` on the wire.
The hook receives a shallow copy of the wire args, so it may
mutate its argument freely or return a new dict — either style
is safe. Context echo always reflects the original wire input.
"""
# Local import to avoid a circular at module-load time. Adopter
# serves never run during foundation imports anyway.
Expand Down Expand Up @@ -504,6 +524,8 @@ def serve(

server_name = name or type(platform).__name__
debug_traffic_source = mock_ad_server.get_traffic if mock_ad_server is not None else None
if pre_validation_hooks is not None:
serve_kwargs["pre_validation_hooks"] = pre_validation_hooks
_adcp_serve(
handler,
name=server_name,
Expand Down
8 changes: 7 additions & 1 deletion src/adcp/server/a2a_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ def __init__(
message_parser: MessageParser | None = None,
advertise_all: bool = False,
validation: ValidationHookConfig | None = SERVER_DEFAULT_VALIDATION,
pre_validation_hooks: dict[str, Any] | None = None,
test_controller_account_resolver: Any | None = None,
) -> None:
self._handler = handler
Expand All @@ -169,7 +170,10 @@ def __init__(
name = tool_def["name"]
if name == "comply_test_controller" and test_controller is None:
continue
self._tool_callers[name] = create_tool_caller(handler, name, validation=validation)
hook = (pre_validation_hooks or {}).get(name)
self._tool_callers[name] = create_tool_caller(
handler, name, validation=validation, pre_validation_hook=hook
)

if test_controller is not None:
self._register_test_controller(test_controller)
Expand Down Expand Up @@ -758,6 +762,7 @@ def create_a2a_server(
message_parser: MessageParser | None = None,
advertise_all: bool = False,
validation: ValidationHookConfig | None = SERVER_DEFAULT_VALIDATION,
pre_validation_hooks: dict[str, Any] | None = None,
context_builder: Any | None = None,
auth: BearerTokenAuth | None = None,
public_url: str | None = None,
Expand Down Expand Up @@ -884,6 +889,7 @@ def create_a2a_server(
message_parser=message_parser,
advertise_all=advertise_all,
validation=validation,
pre_validation_hooks=pre_validation_hooks,
test_controller_account_resolver=test_controller_account_resolver,
)

Expand Down
62 changes: 59 additions & 3 deletions src/adcp/server/mcp_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1870,6 +1870,7 @@ def create_tool_caller(
method_name: str,
*,
validation: ValidationHookConfig | None = None,
pre_validation_hook: Callable[[str, dict[str, Any]], dict[str, Any]] | None = None,
) -> Callable[..., Any]:
"""Create a tool caller function for an ADCP handler method.

Expand Down Expand Up @@ -1899,12 +1900,36 @@ def create_tool_caller(
server validation is a deliberate opt-in for authors who want
dispatcher-level enforcement.

**Pre-validation hook (issue #614).** When ``pre_validation_hook`` is
supplied, it is called with ``(tool_name, shallow_copy_of_args)`` and
must return a ``dict`` that replaces the wire args before schema
validation and Pydantic ``model_validate`` run. The framework passes
a shallow copy of the incoming params dict, so the hook may mutate
its argument freely or return a brand-new dict — either style is safe.
The original wire params are captured before the copy is made, so
context echo always reflects what the buyer sent. Use this to apply
spec-mandated defaults for pre-v3 buyers that omit required fields
(e.g. ``buying_mode``, ``format_id`` shape coercion, ``asset_type``
inference). The hook runs on every call; keep it fast.
Exceptions from the hook surface as ``INVALID_REQUEST`` — do not raise
for missing-but-defaultable fields, only for structurally unusable args.

.. note::
For the specific case of buyers omitting ``account``, see issue
#623 ("Typed dispatcher rejects valid request when ``account`` is
omitted") — that will be the canonical spec-level fix for that
field. Once #623 lands you can drop any ``account`` placeholder
hook entry.

Args:
handler: The ADCP handler instance
method_name: Name of the method to call
validation: Optional :class:`ValidationHookConfig` with
per-side modes (``strict`` / ``warn`` / ``off``). Omitting
it disables server-side schema validation entirely.
pre_validation_hook: Optional callable ``(tool_name, args) -> args``
invoked on the raw wire dict before schema + Pydantic validation.
See the **Pre-validation hook** section above.

Returns:
Async callable ``call_tool(params, context=None)``. The ``context``
Expand Down Expand Up @@ -1938,7 +1963,22 @@ def create_tool_caller(

async def call_tool(params: dict[str, Any], context: ToolContext | None = None) -> Any:
ctx = context if context is not None else ToolContext()
raw_params = params # Preserve the original dict for context echo.

raw_params = params # Preserve original wire params for context echo.

if pre_validation_hook is not None:
try:
params = pre_validation_hook(method_name, dict(params))
except Exception as exc:
raise ADCPTaskError(
operation=method_name,
errors=[
Error(
code="INVALID_REQUEST",
message=f"pre_validation_hook raised {type(exc).__name__}: {exc}",
)
],
) from exc

if request_mode is not None and request_mode != "off":
outcome = validate_request(method_name, params)
Expand Down Expand Up @@ -2069,6 +2109,7 @@ def __init__(
*,
advertise_all: bool = False,
validation: ValidationHookConfig | None = None,
pre_validation_hooks: dict[str, Callable[[str, dict[str, Any]], dict[str, Any]]] | None = None,
):
"""Create tool set from handler.

Expand All @@ -2081,6 +2122,9 @@ def __init__(
(override-filtered advertisement).
validation: Opt-in schema validation config applied to every
tool caller. See :func:`create_tool_caller`.
pre_validation_hooks: Optional dict mapping tool name to a
``(tool_name, args) -> args`` callable. Applied before
schema + Pydantic validation. See :func:`create_tool_caller`.
"""
self.handler = handler
self._filtered_definitions = get_tools_for_handler(handler, advertise_all=advertise_all)
Expand All @@ -2089,7 +2133,10 @@ def __init__(
# Create tool callers only for filtered tools
for tool_def in self._filtered_definitions:
name = tool_def["name"]
self._tools[name] = create_tool_caller(handler, name, validation=validation)
hook = (pre_validation_hooks or {}).get(name)
self._tools[name] = create_tool_caller(
handler, name, validation=validation, pre_validation_hook=hook
)

@property
def tool_definitions(self) -> list[dict[str, Any]]:
Expand Down Expand Up @@ -2123,6 +2170,7 @@ def create_mcp_tools(
*,
advertise_all: bool = False,
validation: ValidationHookConfig | None = None,
pre_validation_hooks: dict[str, Callable[[str, dict[str, Any]], dict[str, Any]]] | None = None,
) -> MCPToolSet:
"""Create MCP tools from an ADCP handler.

Expand Down Expand Up @@ -2157,8 +2205,16 @@ async def call_tool(name: str, arguments: dict):
every tool caller validates requests and responses against
the bundled AdCP JSON schemas. See
:func:`create_tool_caller` for mode semantics.
pre_validation_hooks: Optional dict mapping tool name to a
``(tool_name, args) -> args`` callable. Applied before schema
+ Pydantic validation. See :func:`create_tool_caller`.

Returns:
MCPToolSet with tool definitions and handlers.
"""
return MCPToolSet(handler, advertise_all=advertise_all, validation=validation)
return MCPToolSet(
handler,
advertise_all=advertise_all,
validation=validation,
pre_validation_hooks=pre_validation_hooks,
)
23 changes: 22 additions & 1 deletion src/adcp/server/serve.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ class ServeConfig:
advertise_all: bool = False
max_request_size: int | None = None
validation: ValidationHookConfig | None = None
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None

# --- Discovery manifest ---
base_url: str | None = None
Expand Down Expand Up @@ -525,6 +526,7 @@ def serve(
max_request_size: int | None = None,
streaming_responses: bool = False,
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
enable_debug_endpoints: bool = False,
debug_traffic_source: Callable[[], dict[str, int]] | None = None,
base_url: str | None = None,
Expand Down Expand Up @@ -772,6 +774,7 @@ async def force_account_status(self, account_id, status):
max_request_size = config.max_request_size
streaming_responses = config.streaming_responses
validation = config.validation
pre_validation_hooks = config.pre_validation_hooks
enable_debug_endpoints = config.enable_debug_endpoints
debug_traffic_source = config.debug_traffic_source
base_url = config.base_url
Expand Down Expand Up @@ -815,6 +818,7 @@ async def force_account_status(self, account_id, status):
advertise_all=advertise_all,
max_request_size=max_request_size,
validation=validation,
pre_validation_hooks=pre_validation_hooks,
base_url=base_url,
specialisms=specialisms,
description=description,
Expand All @@ -838,6 +842,7 @@ async def force_account_status(self, account_id, status):
max_request_size=max_request_size,
streaming_responses=streaming_responses,
validation=validation,
pre_validation_hooks=pre_validation_hooks,
base_url=base_url,
specialisms=specialisms,
description=description,
Expand Down Expand Up @@ -865,6 +870,7 @@ async def force_account_status(self, account_id, status):
max_request_size=max_request_size,
streaming_responses=streaming_responses,
validation=validation,
pre_validation_hooks=pre_validation_hooks,
base_url=base_url,
specialisms=specialisms,
description=description,
Expand Down Expand Up @@ -1239,6 +1245,7 @@ def _serve_mcp(
max_request_size: int | None = None,
streaming_responses: bool = False,
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
base_url: str | None = None,
specialisms: list[str] | None = None,
description: str | None = None,
Expand All @@ -1260,6 +1267,7 @@ def _serve_mcp(
advertise_all=advertise_all,
streaming_responses=streaming_responses,
validation=validation,
pre_validation_hooks=pre_validation_hooks,
allowed_hosts=allowed_hosts,
allowed_origins=allowed_origins,
enable_dns_rebinding_protection=enable_dns_rebinding_protection,
Expand Down Expand Up @@ -1399,6 +1407,7 @@ def _serve_a2a(
advertise_all: bool = False,
max_request_size: int | None = None,
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
base_url: str | None = None,
specialisms: list[str] | None = None,
description: str | None = None,
Expand Down Expand Up @@ -1427,6 +1436,7 @@ def _serve_a2a(
message_parser=message_parser,
advertise_all=advertise_all,
validation=validation,
pre_validation_hooks=pre_validation_hooks,
auth=auth,
public_url=public_url,
)
Expand Down Expand Up @@ -1481,6 +1491,7 @@ def _build_mcp_and_a2a_app(
max_request_size: int | None = None,
streaming_responses: bool = False,
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
base_url: str | None = None,
specialisms: list[str] | None = None,
description: str | None = None,
Expand Down Expand Up @@ -1523,6 +1534,7 @@ def _build_mcp_and_a2a_app(
advertise_all=advertise_all,
streaming_responses=streaming_responses,
validation=validation,
pre_validation_hooks=pre_validation_hooks,
allowed_hosts=allowed_hosts,
allowed_origins=allowed_origins,
enable_dns_rebinding_protection=enable_dns_rebinding_protection,
Expand Down Expand Up @@ -1576,6 +1588,7 @@ def _build_mcp_and_a2a_app(
message_parser=message_parser,
advertise_all=advertise_all,
validation=validation,
pre_validation_hooks=pre_validation_hooks,
auth=auth,
public_url=public_url,
)
Expand Down Expand Up @@ -1659,6 +1672,7 @@ def _serve_mcp_and_a2a(
max_request_size: int | None = None,
streaming_responses: bool = False,
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
base_url: str | None = None,
specialisms: list[str] | None = None,
description: str | None = None,
Expand Down Expand Up @@ -1706,6 +1720,7 @@ def _serve_mcp_and_a2a(
max_request_size=max_request_size,
streaming_responses=streaming_responses,
validation=validation,
pre_validation_hooks=pre_validation_hooks,
base_url=base_url,
specialisms=specialisms,
description=description,
Expand Down Expand Up @@ -1787,6 +1802,7 @@ def create_mcp_server(
advertise_all: bool = False,
streaming_responses: bool = False,
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
allowed_hosts: Sequence[str] | None = None,
allowed_origins: Sequence[str] | None = None,
enable_dns_rebinding_protection: bool | None = None,
Expand Down Expand Up @@ -1948,6 +1964,7 @@ def create_mcp_server(
middleware=middleware,
advertise_all=advertise_all,
validation=validation,
pre_validation_hooks=pre_validation_hooks,
)
return mcp

Expand All @@ -1961,6 +1978,7 @@ def _register_handler_tools(
middleware: Sequence[SkillMiddleware] | None = None,
advertise_all: bool = False,
validation: ValidationHookConfig | None = DEFAULT_VALIDATION,
pre_validation_hooks: dict[str, Callable[..., Any]] | None = None,
) -> None:
"""Register all ADCP tools from a handler onto a FastMCP server."""
# Freeze middleware ordering at registration time. Tuple both guards
Expand All @@ -1980,7 +1998,10 @@ def _register_handler_tools(
description = tool_def.get("description", "")
input_schema = tool_def.get("inputSchema", {"type": "object", "properties": {}})
output_schema = tool_def.get("outputSchema")
caller = create_tool_caller(handler, tool_name, validation=validation)
hook = (pre_validation_hooks or {}).get(tool_name)
caller = create_tool_caller(
handler, tool_name, validation=validation, pre_validation_hook=hook
)
_register_tool(
mcp,
tool_name,
Expand Down
Loading
Loading