diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e7da672d8..387339acd 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -131,6 +131,100 @@ jobs: echo "" echo "✅ All commits follow Conventional Commits format" + downstream-imports: + name: Downstream import smoke (representative consumer symbols) + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.11 + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Build and install wheel + run: | + python -m pip install --upgrade pip build + python -m build --wheel --outdir dist/ + pip install dist/*.whl + + # Proxy for real downstream import sites (salesagent, creative agents, + # signals agents). Any ImportError here means we broke the public API + # surface without a migration pointer — failing CI is the goal. + - name: Import representative public-API symbols + run: | + python - <<'PY' + from adcp import ( + ADCPClient, + AgentConfig, + BrandReference, + CpmPricingOption, + CreateMediaBuyRequest, + Error, + GetProductsRequest, + ListCreativesRequest, + MediaBuyStatus, + Package, + PackageRequest, + PublisherPropertiesAll, + SyncCatalogsRequest, + ) + from adcp.types import ( + AudioFormatAsset, + BriefFormatAsset, + CatalogFormatAsset, + ContextObject, + CreativeAsset, + CssFormatAsset, + DaastFormatAsset, + HtmlFormatAsset, + ImageFormatAsset, + JavascriptFormatAsset, + MarkdownFormatAsset, + RepeatableAssetGroup, + TargetingOverlay, + TextFormatAsset, + UrlFormatAsset, + VastFormatAsset, + VideoFormatAsset, + WebhookFormatAsset, + ) + + # Removed-type shims: old import paths must raise a guided + # ImportError pointing at the migration guide. + import adcp + for name in ("BrandManifest", "FormatCategory", "DeliverTo"): + try: + getattr(adcp, name) + except ImportError as exc: + assert "MIGRATION_v3_to_v4" in str(exc), ( + f"{name} deprecation shim dropped migration pointer: {exc}" + ) + else: + raise AssertionError( + f"{name} import should raise ImportError with migration pointer" + ) + + # The deep submodule path (some older import sites reach this far) + # must also surface the migration pointer, not a bare ModuleNotFoundError. + try: + from adcp.types.generated_poc.enums.format_category import FormatCategory # noqa: F401 + except ImportError as exc: + assert "MIGRATION_v3_to_v4" in str(exc), exc + else: + raise AssertionError( + "format_category submodule should raise ImportError with migration pointer" + ) + + assert adcp.__version__ and adcp.__version__ != "3.12.0", ( + f"adcp.__version__={adcp.__version__!r} — expected real pkg metadata" + ) + assert adcp.get_adcp_version(), "ADCP_VERSION file is empty" + + print(f"OK — adcp=={adcp.__version__}, spec={adcp.get_adcp_version()}") + PY + schema-check: name: Validate schemas are up-to-date runs-on: ubuntu-latest diff --git a/MIGRATION_v3_to_v4.md b/MIGRATION_v3_to_v4.md index 888c50175..cda3e55ec 100644 --- a/MIGRATION_v3_to_v4.md +++ b/MIGRATION_v3_to_v4.md @@ -177,6 +177,56 @@ Aliases for all discriminated-union success/error variants live in `adcp/types/aliases.py`. If a variant you need isn't aliased, file an issue — aliasing is the supported path; direct `Assets*` imports aren't. +### Creative format asset slots: `FormatAsset` aliases + +Format definitions enumerate asset slots with a discriminated union on +`asset_type`. These are the classes salesagent hit when `Assets5`/`Assets14` +renumbered to `Assets57`/`Assets149`. The stable names are: + +| Generated class | Semantic alias | `asset_type` | +|--------------------|-------------------------------|--------------------| +| `Assets` (base) | `ImageFormatAsset` | `image` | +| `Assets81` | `VideoFormatAsset` | `video` | +| `Assets82` | `AudioFormatAsset` | `audio` | +| `Assets83` | `TextFormatAsset` | `text` | +| `Assets84` | `MarkdownFormatAsset` | `markdown` | +| `Assets85` | `HtmlFormatAsset` | `html` | +| `Assets86` | `CssFormatAsset` | `css` | +| `Assets87` | `JavascriptFormatAsset` | `javascript` | +| `Assets88` | `VastFormatAsset` | `vast` | +| `Assets89` | `DaastFormatAsset` | `daast` | +| `Assets90` | `UrlFormatAsset` | `url` | +| `Assets91` | `WebhookFormatAsset` | `webhook` | +| `Assets92` | `BriefFormatAsset` | `brief` | +| `Assets93` | `CatalogFormatAsset` | `catalog` | +| `Assets94` | `RepeatableAssetGroup` | `repeatable_group` | +| `Assets95…Assets106` | `ImageFormatGroupAsset` etc.| (same type inside a group) | + +The `Format` prefix disambiguates these *format-slot* types from the +separate *asset-content* types (`VideoAsset`, `HtmlAsset`, `ImageAsset`, +etc. in `adcp.types`), which describe the actual asset payload (codec, +duration, file URL) delivered by creative sync — a distinct concept. + +`tests/test_asset_aliases_stable.py` pins each alias to its expected +`asset_type` discriminator default. When upstream renumbers, that test +fails and points at the specific alias that drifted — fix the numbered +import in `src/adcp/types/aliases.py`, not your call sites. + +### Deep-submodule `format_category` shim + +Some older import sites reach into the raw generated path: + +```python +from adcp.types.generated_poc.enums.format_category import FormatCategory +``` + +4.0 registers a ``sys.modules`` shim for this path so the import raises +an ``ImportError`` with the same migration pointer as the top-level +``from adcp import FormatCategory``, instead of a bare +``ModuleNotFoundError``. If you're seeing the deep path in your code, +switch to the migration above — the shim is a safety net, not a +permanent export. + ## Public vs. internal imports `adcp.types.generated_poc.*` is internal. Generated module paths and class diff --git a/docs/handler-authoring.md b/docs/handler-authoring.md new file mode 100644 index 000000000..a338bd20e --- /dev/null +++ b/docs/handler-authoring.md @@ -0,0 +1,245 @@ +# Authoring an ADCP server handler + +This guide is for teams building AdCP-compliant agents — sales agents, +creative agents, governance agents, signals agents — on top of +`adcp.server`. It captures the patterns that keep handlers spec-compliant +and production-grade, plus the hooks the SDK provides so you don't have +to rebuild middleware that already exists. + +## 15-minute decision tree + +- **Just want an agent running?** → Start with "The one-file starting + point" below, then `serve()`. +- **Need auth in front of tools?** → If your proxy already validates + credentials, use "Pattern 1 — reverse-proxy auth". Otherwise copy + `examples/mcp_with_auth_middleware.py` — it covers the ContextVars + pattern, the `DISCOVERY_TOOLS` bypass, and `hmac.compare_digest`. +- **Multi-tenant?** → Subclass `ToolContext`, populate `tenant_id` in + your `context_factory`, and read the + [Multi-tenant typing](#multi-tenant-typing) section. The idempotency + middleware uses `(tenant_id, caller_identity)` for scope isolation — + populating `tenant_id` is required for cross-tenant safety. +- **Full context?** → Keep reading. + +## The one-file starting point + +```python +from adcp.server import ADCPHandler, ToolContext, serve +from adcp.server.responses import capabilities_response, products_response + +class MyAgent(ADCPHandler): + async def get_adcp_capabilities(self, params, context=None): + return capabilities_response(["media_buy"]) + + async def get_products(self, params, context=None): + return products_response(MY_PRODUCTS) + +serve(MyAgent(), name="my-agent") +``` + +That's a complete AdCP agent. All 57+ other tools return `not_supported` +automatically via the `ADCPHandler` default methods; override only what +your agent actually implements. + +## The `_impl` pattern (production-grade) + +Production agents usually don't put business logic directly on handler +methods. Instead: + +- Business logic lives in `src/core/_impl/` or similar — transport-free, + takes typed domain objects, returns typed responses. +- `ADCPHandler` methods are thin delegations that pull identity / + adapter config out of `ToolContext` and call the `_impl` function. + +This keeps the tested surface independent of whether the caller came in +via MCP, A2A, HTTP, a background job, or a test. The SDK's server +framework is designed for this shape: + +```python +from adcp.server import ADCPHandler, ToolContext +from myagent.impl.products import get_products_impl +from myagent.identity import ResolvedIdentity + +class MyAgent(ADCPHandler): + async def get_products(self, params, context: ToolContext | None = None): + identity = _resolve_identity(context) + return await get_products_impl(params, identity=identity) + +def _resolve_identity(ctx: ToolContext | None) -> ResolvedIdentity: + if ctx is None or ctx.caller_identity is None: + raise AuthenticationRequired() + return ResolvedIdentity( + principal_id=ctx.caller_identity, + tenant_id=ctx.tenant_id, + # … adapter config, feature flags, etc. from your DB + ) +``` + +## Authentication + +The SDK does not enforce authentication. There are two supported +integration patterns: + +### Pattern 1 — reverse-proxy auth + +The proxy (nginx, Caddy, Envoy) validates credentials and forwards only +authenticated requests. The SDK trusts the proxy's decision. Simplest, +and the right choice when your identity provider and tool endpoints run +behind the same gateway. + +### Pattern 2 — in-process HTTP middleware + +Call `mcp.streamable_http_app()` to get the Starlette ASGI app, then +`app.add_middleware(YourAuthMiddleware)`. The middleware validates +credentials, stashes the resolved principal + tenant somewhere the +`context_factory` can read (ContextVars are recommended), and calls +`context_factory=` on `create_mcp_server()` to inject a typed +`ToolContext` per call. + +Full worked example: `examples/mcp_with_auth_middleware.py`. Integration +test proving the composition: `tests/test_mcp_middleware_composition.py`. + +### Discovery tools bypass auth + +Per AdCP spec, `get_adcp_capabilities` is the handshake — clients MUST +be able to call it before authenticating. The SDK exports the list as a +frozenset: + +```python +from adcp.server import DISCOVERY_TOOLS + +async def dispatch(self, request, call_next): + tool_name = _peek_tool_name(request) + if tool_name not in DISCOVERY_TOOLS: + self._require_valid_token(request) + return await call_next(request) +``` + +Your agent may have additional public discovery tools outside the AdCP +spec (e.g. a public `list_public_formats`); extend with `DISCOVERY_TOOLS +| {"your_tool"}` rather than redefining the set. + +## Idempotency + +The SDK ships an `IdempotencyStore` middleware that honors the +`Idempotency-Key` header per AdCP §idempotency. Requests with the same +`(caller_identity, idempotency_key)` return the cached response instead +of re-executing the handler. + +The store keys on `ToolContext.caller_identity` — if your transport +doesn't populate it, per-principal scoping falls through and dedup is +skipped (with a UserWarning). A2A populates it automatically from +`ServerCallContext.user`; MCP requires you to wire `context_factory`. + +Don't rebuild idempotency in your handler. Import the middleware. + +## Error handling + +Raise `AdCPError` (or a subclass: `ADCPTaskError`, `IdempotencyConflictError`) +from handler code. The SDK translates to the wire-level error shape the +AdCP spec mandates — MCP gets a `ToolError` with the spec error code in +the message, A2A gets a `JSON-RPC error` with the code populated. + +Use the error classification helpers: + +```python +from adcp.server import adcp_error + +raise adcp_error("BUDGET_TOO_LOW") # auto-classifies as correctable +raise adcp_error("DOWNSTREAM_TIMEOUT") # auto-classifies as transient +``` + +The recovery hint (transient / correctable / terminal) gets populated +from 20+ standard codes — don't reinvent the table. + +## Response builders + +Manual `model_dump()` on response Pydantic objects is error-prone — +you'll drift from the spec's required fields. Use the response builders: + +```python +from adcp.server.responses import media_buy_response, products_response + +return media_buy_response( + media_buy_id="mb_123", + status="active", # auto-populates valid_actions from the state machine +) +``` + +One per AdCP operation. Read the `adcp.server.responses` docstrings. + +## Multi-tenant typing + +Production multi-tenant agents usually carry `tenant + principal + +adapter + testing hooks` in their own identity type. `ToolContext` +exposes the fields those handlers need: + +- `ToolContext.tenant_id: str | None` — first-class field; populate from + your `context_factory`. **Required** for multi-tenant deployments + whose principal IDs are only unique within a tenant (Okta group-scoped, + SCIM per-tenant, seller-internal employee IDs) — the idempotency + store keys its cache on `(tenant_id, caller_identity)`, so leaving + `tenant_id` unset collapses distinct tenants into the same scope and + enables cross-tenant response replay. +- `ToolContext.metadata: dict[str, Any]` — escape hatch for adapter + instance handles, testing hooks, per-tenant config blobs. +- Subclassing `ToolContext` is supported — return the subclass from your + `context_factory` and your handler methods `isinstance(context, + MyContext)` (or `cast(MyContext, context)` if you've established the + invariant via the factory) to reach the extra fields. + +When in doubt, subclass: `metadata: dict[str, Any]` loses type safety. + +## A2A transport + +`serve(MyAgent(), transport="a2a")` wires the same handler through the +A2A protocol with auto-generated agent card (`/.well-known/agent.json`) +derived from the `ADCPHandler` methods your class overrides. + +Caveats: + +- The SDK uses `a2a-sdk`'s `DefaultRequestHandler` + `InMemoryTaskStore`. + Tasks do not persist across restarts. +- Push-notification config is in-memory only. +- Per-skill middleware hooks for audit logging / activity feeds don't + exist yet (tracked in the SDK adoption roadmap). + +If your agent needs DB-backed tasks, persistent push-notif config, or +per-skill audit hooks, keep a custom A2A server for now. The MCP side is +production-ready; the A2A side is reference-quality. + +## Testing + +The integration test pattern in `tests/test_mcp_middleware_composition.py` +is the shape you can copy for your own middleware tests. Key pieces: + +- `create_mcp_server(..., context_factory=build_context)` wires the + context factory. +- `mcp.settings.stateless_http = True` + `mcp.settings.json_response = True` + disables the session manager so tests don't need a TaskGroup. +- `mcp.settings.transport_security.allowed_hosts = ["localhost"]` allows + in-process `httpx.ASGITransport` requests through the DNS-rebinding + guard. +- Run the app's lifespan manually if you're exercising HTTP endpoints. + +## What not to build + +- Don't write per-tool `@mcp.tool()` wrappers. `create_mcp_server()` + registers all ADCP tools from a handler automatically. +- Don't hand-maintain an agent card. A2A auto-derives it from the + handler methods you override. +- Don't reinvent `IdempotencyStore`, response builders, or error + classification. Use the shipped helpers. +- Don't import from `adcp.types.generated_poc.*`. Everything public + lives at `adcp.types` or `adcp` — and the internal paths renumber + between releases (see `MIGRATION_v3_to_v4.md`). + +## Where to look next + +- `examples/minimal_sales_agent.py` — handler-only starting point. +- `examples/mcp_with_auth_middleware.py` — full auth + typed context. +- `src/adcp/server/responses.py` — response builder reference. +- `src/adcp/server/helpers.py` — error codes, state machine, account + resolution. +- `tests/test_mcp_middleware_composition.py` — the integration test + that protects this contract. diff --git a/examples/mcp_with_auth_middleware.py b/examples/mcp_with_auth_middleware.py new file mode 100644 index 000000000..f86eb68f3 --- /dev/null +++ b/examples/mcp_with_auth_middleware.py @@ -0,0 +1,216 @@ +"""Example: custom HTTP auth middleware + typed ToolContext via context_factory. + +This is the recipe for multi-tenant sales agents that need to: + +1. Validate bearer tokens (or any other credential) in front of + :func:`adcp.server.create_mcp_server`-registered tools. +2. Allow the AdCP discovery handshake (``get_adcp_capabilities``) to go + through unauthenticated — per :data:`adcp.server.DISCOVERY_TOOLS`. +3. Pass the authenticated principal + tenant to handlers as a typed + :class:`adcp.server.ToolContext`. + +Run:: + + uv run python examples/mcp_with_auth_middleware.py + # → server on http://localhost:3001/mcp/ + # curl -H 'Authorization: Bearer token-acme' ... + +Production note: ``mcp.run()`` is used here for brevity. Real deployments +should mount the Starlette app behind uvicorn + a reverse proxy that +terminates TLS and handles rate limiting. +""" + +from __future__ import annotations + +import hashlib +import hmac +from contextvars import ContextVar +from typing import Any + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse + +from adcp.server import ( + DISCOVERY_TOOLS, + ADCPHandler, + RequestMetadata, + ToolContext, + create_mcp_server, +) +from adcp.server.responses import capabilities_response, products_response + +# ---------------------------------------------------------------------- +# Per-request auth state — populated by middleware, read by context_factory. +# ContextVars are the recommended carrier: they compose cleanly with +# async tasks and don't leak across requests the way module globals do. +# IMPORTANT: always pair ``.set(x)`` with ``.reset(token)`` in a ``finally:`` +# block so the value doesn't linger in the current context past the +# response — otherwise a subsequent task reusing the same context reads a +# stale principal (cross-request confidentiality leak). +# ---------------------------------------------------------------------- + +_principal: ContextVar[str | None] = ContextVar("adcp_principal", default=None) +_tenant: ContextVar[str | None] = ContextVar("adcp_tenant", default=None) + + +# Real agents look tokens up in Postgres / Vault / an identity provider / +# etc. This dict is a stand-in: it stores a per-token SHA-256 so the +# example's token-compare path uses ``hmac.compare_digest`` (constant-time) +# against a hash rather than comparing raw bearer tokens with ``==`` or +# ``in``. Never ship plain-text token equality against a user-supplied +# bearer token — it leaks information via timing, and dict lookups short- +# circuit on hash mismatch. +_TOKEN_HASHES: dict[str, tuple[str, str]] = { + hashlib.sha256(raw.encode()).hexdigest(): (principal, tenant) + for raw, (principal, tenant) in { + "token-acme": ("principal-acme-ops", "tenant-acme"), + "token-globex": ("principal-globex-ops", "tenant-globex"), + }.items() +} + + +def _lookup_token(token: str) -> tuple[str, str] | None: + """Constant-time bearer-token lookup. + + Iterate all known hashes with ``hmac.compare_digest`` so the wall-clock + runtime doesn't depend on how much of the candidate matches any entry — + the dict-lookup-then-equality pattern leaks that. + """ + if not token: + return None + candidate = hashlib.sha256(token.encode()).hexdigest() + for stored_hash, identity in _TOKEN_HASHES.items(): + if hmac.compare_digest(candidate, stored_hash): + return identity + return None + + +# ---------------------------------------------------------------------- +# HTTP middleware — auth gate, honors DISCOVERY_TOOLS. +# ---------------------------------------------------------------------- + + +class BearerAuthMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next: Any) -> Any: + tool = await _peek_tool_name(request) + + principal_token = None + tenant_token = None + try: + # AdCP spec: ``get_adcp_capabilities`` is the discovery handshake — + # clients MUST be able to call it without authenticating. + if tool in DISCOVERY_TOOLS: + principal_token = _principal.set(None) + tenant_token = _tenant.set(None) + return await call_next(request) + + # Everything else requires a bearer token. + auth_header = request.headers.get("authorization", "") + bearer = auth_header.removeprefix("Bearer ").strip() + identity = _lookup_token(bearer) + if identity is None: + return JSONResponse({"error": "unauthenticated"}, status_code=401) + + principal_id, tenant_id = identity + principal_token = _principal.set(principal_id) + tenant_token = _tenant.set(tenant_id) + return await call_next(request) + finally: + # Reset unconditionally. Without this, a later task running in + # the same context reads the leftover principal — a + # cross-request confidentiality leak. + if principal_token is not None: + _principal.reset(principal_token) + if tenant_token is not None: + _tenant.reset(tenant_token) + + +async def _peek_tool_name(request: Request) -> str | None: + """Inspect the JSON-RPC body without consuming it for downstream handlers.""" + body = await request.body() + if not body: + return None + import json + + try: + payload = json.loads(body) + except ValueError: + return None + if payload.get("method") != "tools/call": + return None + params = payload.get("params") or {} + name = params.get("name") + return name if isinstance(name, str) else None + + +# ---------------------------------------------------------------------- +# context_factory — runs per tool call, reads the ContextVars the +# middleware populated, returns a typed ToolContext. +# ---------------------------------------------------------------------- + + +def build_context(meta: RequestMetadata) -> ToolContext: + return ToolContext( + request_id=meta.request_id, + caller_identity=_principal.get(), + tenant_id=_tenant.get(), + metadata={"tool_name": meta.tool_name, "transport": meta.transport}, + ) + + +# ---------------------------------------------------------------------- +# Handler — reads caller_identity + tenant_id off the ToolContext. +# ---------------------------------------------------------------------- + + +class MultiTenantSalesAgent(ADCPHandler): + _agent_type = "demo multi-tenant sales agent" + + async def get_adcp_capabilities( + self, params: Any, context: ToolContext | None = None + ) -> dict[str, Any]: + return capabilities_response(["media_buy"]) + + async def get_products(self, params: Any, context: ToolContext | None = None) -> dict[str, Any]: + # context.caller_identity is the authenticated principal; + # context.tenant_id is populated for multi-tenant agents. + tenant = context.tenant_id if context is not None else None + catalog = _products_for_tenant(tenant) + return products_response(catalog) + + +def _products_for_tenant(tenant_id: str | None) -> list[dict[str, Any]]: + if tenant_id == "tenant-acme": + return [{"product_id": "acme_display_1", "name": "Acme homepage display"}] + if tenant_id == "tenant-globex": + return [{"product_id": "globex_video_1", "name": "Globex CTV video"}] + return [] + + +# ---------------------------------------------------------------------- +# Wiring — create_mcp_server with context_factory, then add middleware +# to the Starlette app. +# ---------------------------------------------------------------------- + + +def main() -> None: + mcp = create_mcp_server( + MultiTenantSalesAgent(), + name="multi-tenant-sales-agent", + context_factory=build_context, + ) + + # Middleware must be added BEFORE the app runs. create_mcp_server + # returns a FastMCP instance; its ASGI app is accessed via + # streamable_http_app(), which is a standard Starlette app. + app = mcp.streamable_http_app() + app.add_middleware(BearerAuthMiddleware) + + # mcp.run() hands control to FastMCP. In production, mount with + # uvicorn and a reverse proxy for TLS + rate limiting. + mcp.run(transport="streamable-http") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml index 08f4d9e9f..b0d98f47a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,11 @@ dev = [ # CreateMediaBuyResponse) shifts between versions, producing diff churn and breaking # generated-code imports that reference specific suffixes. "datamodel-code-generator[http]==0.56.1", + # Runs Starlette app lifespan under httpx.ASGITransport in tests — + # the canonical library for what httpx doesn't do natively. Used by + # tests/test_mcp_middleware_composition.py and future integration + # tests that exercise the streamable-HTTP ASGI app in-process. + "asgi-lifespan>=2.1.0", ] docs = [ "pdoc3>=0.10.0", diff --git a/scripts/post_generate_fixes.py b/scripts/post_generate_fixes.py index 1559e66b4..2d011e8dc 100644 --- a/scripts/post_generate_fixes.py +++ b/scripts/post_generate_fixes.py @@ -383,7 +383,9 @@ def __getattr__(self, name): ... lines = content.split("\n") # Collect classes to unwrap (process in reverse order to preserve line numbers) - replacements: list[tuple[int, int, str, str]] = [] # (start_line, end_line, name, union_types) + replacements: list[tuple[int, int, str, str]] = ( + [] + ) # (start_line, end_line, name, union_types) for node in ast.walk(tree): if not isinstance(node, ast.ClassDef) or node.name not in _UNWRAP_TO_UNION: @@ -423,9 +425,7 @@ def __getattr__(self, name): ... if "\n" in union_types: # Re-indent continuation lines to 4 spaces union_lines = [ln.strip() for ln in union_types.split("\n")] - indented = union_lines[0] + "\n" + "\n".join( - f" {ln}" for ln in union_lines[1:] - ) + indented = union_lines[0] + "\n" + "\n".join(f" {ln}" for ln in union_lines[1:]) replacement = f"{type_name} = (\n {indented}\n)" else: replacement = f"{type_name} = {union_types}" @@ -480,9 +480,7 @@ def add_rootmodel_getattr_proxy(): # Ensure Any is imported before parsing AST (avoids line number shift) if "from typing import Any" not in source and "Any," not in source: if "from typing import " in source: - source = source.replace( - "from typing import ", "from typing import Any, ", 1 - ) + source = source.replace("from typing import ", "from typing import Any, ", 1) else: source = "from typing import Any\n" + source @@ -554,8 +552,8 @@ def fix_list_field_shadowing(): # Replace list[identifier...] and dict[str, list[identifier...]] patterns content = re.sub( - r'(? object: if name in _REMOVED_IN_V4: + hint, anchor = _REMOVED_IN_V4[name] raise ImportError( - f"`{name}` was removed in adcp 4.0: {_REMOVED_IN_V4[name]}. " - "See MIGRATION_v3_to_v4.md." + f"`{name}` was removed in adcp 4.0: {hint}. " f"See MIGRATION_v3_to_v4.md#{anchor}." ) raise AttributeError(f"module 'adcp' has no attribute {name!r}") diff --git a/src/adcp/server/__init__.py b/src/adcp/server/__init__.py index c5e934e4d..c3563297b 100644 --- a/src/adcp/server/__init__.py +++ b/src/adcp/server/__init__.py @@ -80,7 +80,13 @@ async def get_products(params, context=None): valid_actions_for_status, ) from adcp.server.idempotency import IdempotencyStore, MemoryBackend -from adcp.server.mcp_tools import MCPToolSet, create_mcp_tools, get_tools_for_handler +from adcp.server.mcp_tools import ( + DISCOVERY_TOOLS, + MCPToolSet, + create_mcp_tools, + get_tools_for_handler, + validate_discovery_set, +) from adcp.server.proposal import ProposalBuilder, ProposalNotSupported from adcp.server.responses import ( activate_signal_response, @@ -103,7 +109,12 @@ async def get_products(params, context=None): sync_governance_response, update_media_buy_response, ) -from adcp.server.serve import create_mcp_server, serve +from adcp.server.serve import ( + ContextFactory, + RequestMetadata, + create_mcp_server, + serve, +) from adcp.server.sponsored_intelligence import SponsoredIntelligenceHandler from adcp.server.test_controller import ( TestControllerError, @@ -131,11 +142,15 @@ async def get_products(params, context=None): "ProposalBuilder", "ProposalNotSupported", # MCP integration + "ContextFactory", + "DISCOVERY_TOOLS", "MCPToolSet", + "RequestMetadata", "create_mcp_tools", "create_mcp_server", "get_tools_for_handler", "serve", + "validate_discovery_set", # A2A integration "ADCPAgentExecutor", "create_a2a_server", diff --git a/src/adcp/server/a2a_server.py b/src/adcp/server/a2a_server.py index 3026fd25b..7694440a2 100644 --- a/src/adcp/server/a2a_server.py +++ b/src/adcp/server/a2a_server.py @@ -12,7 +12,7 @@ import json import logging import os -from typing import Any +from typing import TYPE_CHECKING, Any from uuid import uuid4 from a2a.server.agent_execution.agent_executor import AgentExecutor @@ -37,6 +37,9 @@ from adcp.exceptions import ADCPError, ADCPTaskError from adcp.server.base import ADCPHandler, ToolContext + +if TYPE_CHECKING: + from adcp.server.serve import ContextFactory from adcp.server.helpers import STANDARD_ERROR_CODES from adcp.server.mcp_tools import create_tool_caller, get_tools_for_handler from adcp.server.test_controller import TestControllerStore, _handle_test_controller @@ -59,8 +62,11 @@ def __init__( self, handler: ADCPHandler, test_controller: TestControllerStore | None = None, + *, + context_factory: ContextFactory | None = None, ) -> None: self._handler = handler + self._context_factory = context_factory self._tool_callers: dict[str, Any] = {} # Build tool callers for all tools this handler supports. @@ -104,7 +110,7 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non await self._send_error(event_queue, context, f"Unknown skill: {skill_name}") return - tool_context = _tool_context_from_request(context) + tool_context = self._build_tool_context(skill_name, context) try: result = await self._tool_callers[skill_name](params, tool_context) await self._send_result(event_queue, context, skill_name, result) @@ -120,6 +126,49 @@ async def execute(self, context: RequestContext, event_queue: EventQueue) -> Non logger.exception("Error executing skill %s", skill_name) await self._send_error(event_queue, context, f"Skill execution failed: {skill_name}") + def _build_tool_context(self, skill_name: str, request: RequestContext) -> ToolContext: + """Build the :class:`ToolContext` handed to the skill dispatcher. + + When ``context_factory`` is configured, call it with a + :class:`RequestMetadata` describing this A2A invocation; overlay the + transport-derived ``caller_identity`` / ``request_id`` afterwards + **only when the factory left them unset**, so factories that already + know the principal (e.g. from a ContextVar the seller's auth layer + populated) aren't clobbered. + + When no factory is configured, fall back to the A2A-only path that + derives ``caller_identity`` from ``ServerCallContext.user`` — + preserving behavior for sellers who haven't adopted + ``context_factory=`` yet. + """ + if self._context_factory is None: + return _tool_context_from_request(request) + + from adcp.server.serve import RequestMetadata + + meta = RequestMetadata( + tool_name=skill_name, + transport="a2a", + request_id=request.task_id, + ) + ctx = self._context_factory(meta) + if not isinstance(ctx, ToolContext): + raise TypeError( + f"context_factory for skill {skill_name!r} returned " + f"{type(ctx).__name__}, not a ToolContext instance" + ) + # Fill in transport-derived fields the factory didn't set. This + # preserves the pre-factory A2A security invariant: if the seller + # didn't explicitly populate caller_identity in their factory, + # fall through to ServerCallContext.user (verified by the a2a-sdk + # auth middleware) rather than silently sending None. + if ctx.caller_identity is None: + fallback = _tool_context_from_request(request) + ctx.caller_identity = fallback.caller_identity + if ctx.request_id is None: + ctx.request_id = request.task_id + return ctx + async def cancel(self, context: RequestContext, event_queue: EventQueue) -> None: """ADCP operations are synchronous; cancellation sets state to canceled.""" event = _make_task( @@ -388,6 +437,7 @@ def create_a2a_server( description: str | None = None, version: str = "1.0.0", test_controller: TestControllerStore | None = None, + context_factory: ContextFactory | None = None, ) -> Any: """Create an A2A Starlette application from an ADCP handler. @@ -398,6 +448,15 @@ def create_a2a_server( description: Agent description for the agent card. version: Agent version string. test_controller: Optional TestControllerStore for storyboard testing. + context_factory: Optional callable invoked per skill call to build + a :class:`ToolContext` from :class:`RequestMetadata`. Mirrors + the MCP-side ``context_factory=`` on + :func:`~adcp.server.create_mcp_server` so a single factory + populates tenant/adapter fields on both transports. When + unset, the executor falls back to deriving ``caller_identity`` + from ``ServerCallContext.user`` — preserving pre-factory + behavior. See :data:`~adcp.server.ContextFactory` for the + recommended contextvars pattern. Returns: A Starlette app ready to be run with uvicorn. @@ -406,7 +465,9 @@ def create_a2a_server( resolved_port = port or int(os.environ.get("PORT", "3001")) - executor = ADCPAgentExecutor(handler, test_controller=test_controller) + executor = ADCPAgentExecutor( + handler, test_controller=test_controller, context_factory=context_factory + ) agent_card = _build_agent_card( handler, diff --git a/src/adcp/server/base.py b/src/adcp/server/base.py index b1eb4e748..9546f79c2 100644 --- a/src/adcp/server/base.py +++ b/src/adcp/server/base.py @@ -83,6 +83,11 @@ class ToolContext: Contains metadata about the current request that may be useful for logging, authorization, or other cross-cutting concerns. + Subclassing is supported. Multi-tenant agents commonly define a + subclass carrying typed tenant + adapter fields (see + ``docs/handler-authoring.md``) and populate it from a + ``context_factory`` passed to :func:`create_mcp_server`. + :param caller_identity: The authenticated principal making the request. **MUST** be a stable, globally-unique identifier within the seller's tenant — never an email, display name, or any other mutable handle. @@ -92,10 +97,26 @@ class ToolContext: causes cross-principal replay (confidentiality leak). Populated by the transport layer (A2A: ``ServerCallContext.user.user_name``; MCP: seller's FastMCP auth middleware). + :param tenant_id: Multi-tenant agents may populate this with the tenant + the request is scoped to. Typed as a first-class field so + multi-tenant handlers don't have to smuggle it through ``metadata``. + The server-side idempotency middleware composes the cache scope key + from ``(tenant_id, caller_identity)`` when ``tenant_id`` is set — + sellers whose principal IDs are only unique *within* a tenant (Okta + group-scoped, SCIM per-tenant, seller-internal employee IDs) **MUST** + populate this so cross-tenant response replay can't happen. When + unset, the scope collapses to ``caller_identity`` alone (safe for + single-tenant deployments). + :param metadata: Open extension point for transport-specific or + agent-specific fields (e.g. adapter instance handles, request + headers, testing hooks). Downstream agents may subclass + :class:`ToolContext` for typed fields; ``metadata`` is the escape + hatch when subclassing isn't worth it. """ request_id: str | None = None caller_identity: str | None = None + tenant_id: str | None = None metadata: dict[str, Any] = field(default_factory=dict) diff --git a/src/adcp/server/idempotency/backends.py b/src/adcp/server/idempotency/backends.py index d51a00ebf..2f26cc5e9 100644 --- a/src/adcp/server/idempotency/backends.py +++ b/src/adcp/server/idempotency/backends.py @@ -52,19 +52,23 @@ class IdempotencyBackend(ABC): """ @abstractmethod - async def get( - self, principal_id: str, key: str - ) -> CachedResponse | None: - """Return the cached entry, or None if missing or expired.""" + async def get(self, scope_key: str, key: str) -> CachedResponse | None: + """Return the cached entry, or None if missing or expired. + + ``scope_key`` is the caller-composed identity scope — typically + ``tenant_id + caller_identity``. Backends treat it as an opaque + string; the composition is owned by + :class:`~adcp.server.idempotency.IdempotencyStore`. + """ @abstractmethod async def put( self, - principal_id: str, + scope_key: str, key: str, entry: CachedResponse, ) -> None: - """Store ``entry`` under ``(principal_id, key)``. Overwrites any prior + """Store ``entry`` under ``(scope_key, key)``. Overwrites any prior entry — the store only calls ``put`` after verifying the slot is empty or expired, so an overwrite in that window is a legitimate retry of the write itself.""" @@ -101,28 +105,26 @@ def __init__(self, *, clock: Callable[[], float] = time.time) -> None: self._lock = asyncio.Lock() self._clock = clock - async def get( - self, principal_id: str, key: str - ) -> CachedResponse | None: + async def get(self, scope_key: str, key: str) -> CachedResponse | None: async with self._lock: - entry = self._store.get((principal_id, key)) + entry = self._store.get((scope_key, key)) if entry is None: return None if entry.expires_at_epoch <= self._clock(): # Lazy expiry — drop the stale entry so the next request # treats the slot as fresh and races to repopulate. - del self._store[(principal_id, key)] + del self._store[(scope_key, key)] return None return entry async def put( self, - principal_id: str, + scope_key: str, key: str, entry: CachedResponse, ) -> None: async with self._lock: - self._store[(principal_id, key)] = entry + self._store[(scope_key, key)] = entry async def delete_expired(self, now_epoch: float | None = None) -> int: cutoff = now_epoch if now_epoch is not None else self._clock() @@ -167,12 +169,12 @@ class PgBackend(IdempotencyBackend): .. code-block:: sql CREATE TABLE adcp_idempotency ( - principal_id TEXT COLLATE "C" NOT NULL, + scope_key TEXT COLLATE "C" NOT NULL, key TEXT COLLATE "C" NOT NULL, payload_hash TEXT NOT NULL, response JSONB NOT NULL, expires_at TIMESTAMPTZ NOT NULL, - PRIMARY KEY (principal_id, key) + PRIMARY KEY (scope_key, key) ); Notes: @@ -181,13 +183,15 @@ class PgBackend(IdempotencyBackend): the default locale collation on the identifier columns. On some locales ``Principal-A`` and ``principal-a`` compare equal, which would collapse distinct tenants into the same cache slot. - * Queries MUST filter on ``principal_id`` in the ``WHERE`` clause even - with the composite PK — row-level security (RLS) enforced via a - policy like ``USING (principal_id = current_setting('adcp.principal_id')::text)`` - gives belt-and-suspenders protection against accidental cross-tenant - reads in future handlers. + * ``scope_key`` is already composed from ``(tenant_id, caller_identity)`` + by the store — Postgres sees it as an opaque string. Queries MUST + still filter on ``scope_key`` in the ``WHERE`` clause even with the + composite PK — row-level security (RLS) enforced via a policy like + ``USING (scope_key = current_setting('adcp.scope_key')::text)`` gives + belt-and-suspenders protection against accidental cross-tenant reads + in future handlers. * ``get`` uses ``SELECT ... WHERE expires_at > now()``. - * ``put`` uses ``INSERT ... ON CONFLICT (principal_id, key) DO UPDATE``. + * ``put`` uses ``INSERT ... ON CONFLICT (scope_key, key) DO UPDATE``. * Accept a SQLAlchemy/asyncpg session factory so the caller can thread the handler's transaction through for atomic commit — the atomicity guarantee is the whole reason to use a SQL backend. @@ -202,20 +206,16 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: "https://github.com/adcontextprotocol/adcp-client-python/issues/182." ) - async def get( - self, principal_id: str, key: str - ) -> CachedResponse | None: # pragma: no cover + async def get(self, scope_key: str, key: str) -> CachedResponse | None: # pragma: no cover raise NotImplementedError async def put( self, - principal_id: str, + scope_key: str, key: str, entry: CachedResponse, ) -> None: # pragma: no cover raise NotImplementedError - async def delete_expired( - self, now_epoch: float | None = None - ) -> int: # pragma: no cover + async def delete_expired(self, now_epoch: float | None = None) -> int: # pragma: no cover raise NotImplementedError diff --git a/src/adcp/server/idempotency/store.py b/src/adcp/server/idempotency/store.py index 8ba6bb8ef..63927f16c 100644 --- a/src/adcp/server/idempotency/store.py +++ b/src/adcp/server/idempotency/store.py @@ -3,7 +3,8 @@ Responsibilities: 1. Extract ``idempotency_key`` from the incoming request. -2. Scope lookups by ``(principal_id, key)`` via the backend. +2. Scope lookups by ``(scope_key, key)`` via the backend, where ``scope_key`` + composes ``tenant_id`` (when present) with ``caller_identity``. 3. On cache hit with matching canonical payload hash: return the cached response and mark ``replayed=True`` on the envelope. 4. On cache hit with a different hash: raise @@ -11,10 +12,17 @@ 5. On miss: run the wrapped handler, then commit ``(hash, response)`` to the backend. -Per-principal scoping is a hard security requirement (AdCP #2315): a key from -principal A on seller S has no meaning for principal B. The store pulls the -principal id from :class:`adcp.server.base.ToolContext.caller_identity`. If no -context / no caller_identity is supplied, the store refuses to proceed — +Per-scope scoping is a hard security requirement (AdCP #2315): a key from +principal A on tenant T has no meaning for principal B or tenant T'. The store +pulls both ``tenant_id`` and ``caller_identity`` from +:class:`adcp.server.base.ToolContext` and composes them into a single scope +key — sellers whose principal ids are only unique *within* a tenant (Okta +group-scoped IDs, seller-internal employee IDs, SCIM per-tenant IDs) must +populate ``tenant_id`` so the store can keep those tenants isolated. When no +``tenant_id`` is set, the scope collapses to ``caller_identity`` alone +(safe for single-tenant deployments). + +If no context / no caller_identity is supplied, the store refuses to proceed — fail-closed rather than collapse every buyer into a shared namespace. """ @@ -113,8 +121,8 @@ async def _wrapped( *args: Any, **kwargs: Any, ) -> Any: - principal_id, idempotency_key, params_dict = self._prepare(params, context) - if principal_id is None or idempotency_key is None: + scope_key, idempotency_key, params_dict = self._prepare(params, context) + if scope_key is None or idempotency_key is None: # No key → spec says the server MUST reject with INVALID_REQUEST. # We let the handler run so validation layers above us (Pydantic, # FastAPI, etc.) can reject with a typed error; the middleware's @@ -123,12 +131,12 @@ async def _wrapped( payload_hash = self._hash_fn(params_dict) - cached = await self.backend.get(principal_id, idempotency_key) + cached = await self.backend.get(scope_key, idempotency_key) if cached is not None: if cached.payload_hash == payload_hash: logger.debug( - "idempotency replay: principal=%s key_prefix=%s", - principal_id, + "idempotency replay: scope=%s key_prefix=%s", + scope_key, idempotency_key[:8], ) return _clone_response(cached.response) @@ -168,14 +176,14 @@ async def _wrapped( # operators, return the result, and accept that the next retry # with this key will re-execute. try: - await self.backend.put(principal_id, idempotency_key, entry) + await self.backend.put(scope_key, idempotency_key, entry) except Exception: logger.warning( - "Idempotency cache put failed for principal=%s key_prefix=%s — " + "Idempotency cache put failed for scope=%s key_prefix=%s — " "handler completed but a subsequent retry with this key will " "re-execute rather than replay. This indicates an operational " "issue with the idempotency backend.", - principal_id, + scope_key, idempotency_key[:8], exc_info=True, ) @@ -184,7 +192,11 @@ async def _wrapped( return _wrapped def _prepare(self, params: Any, context: Any) -> tuple[str | None, str | None, dict[str, Any]]: - """Normalize inputs and extract the (principal, key, params_dict) tuple. + """Normalize inputs and extract the (scope_key, key, params_dict) tuple. + + ``scope_key`` composes ``tenant_id`` (when present) with + ``caller_identity`` so cache entries are isolated across tenants even + if the seller's principal IDs are only unique within each tenant. Returns ``(None, None, params_dict)`` when idempotency doesn't apply (no caller identity or no key supplied). The caller falls through to @@ -195,15 +207,15 @@ def _prepare(self, params: Any, context: Any) -> tuple[str | None, str | None, d idempotency_key = params_dict.get("idempotency_key") if not isinstance(idempotency_key, str) or not idempotency_key: return None, None, params_dict - principal_id = _extract_principal_id(context) - if principal_id is None: + scope_key = _extract_scope_key(context) + if scope_key is None: # No caller identity: we can't safely scope the key. Spec requires # per-principal scope; anything else is a cross-principal replay # attack surface. Fall through to the handler (which will process # the request normally — no dedup, but no security regression). self._warn_missing_principal_once() return None, None, params_dict - return principal_id, idempotency_key, params_dict + return scope_key, idempotency_key, params_dict _missing_principal_warned: bool = False @@ -244,37 +256,73 @@ def _to_dict(value: Any) -> dict[str, Any]: raise TypeError(f"Cannot coerce {type(value).__name__} to dict for idempotency caching") -def _extract_principal_id(context: Any) -> str | None: - """Pull the principal id from a ToolContext or equivalent shape. +# \x1e (ASCII 0x1e "record separator") is used between tenant and principal — +# distinct from anything a tenant/principal id would contain, and the resulting +# scope key stays opaque to callers. Downstream backends compare it as a plain +# string; the separator is only meaningful internally. +_SCOPE_SEP = "\x1e" + + +def _extract_scope_key(context: Any) -> str | None: + """Pull the idempotency scope key from a ToolContext or equivalent shape. + + The scope key composes ``tenant_id`` (when present) with + ``caller_identity`` so cache entries can't collide across tenants whose + principal IDs are only locally unique. Returns ``None`` when no caller + identity is available — idempotency then falls through to the handler + (no dedup, but no cross-principal leakage either). Accepts: - * :class:`adcp.server.base.ToolContext` with ``caller_identity`` - * Any object exposing ``caller_identity`` / ``principal_id`` / ``principal.id`` + * :class:`adcp.server.base.ToolContext` with ``caller_identity`` and + optional ``tenant_id`` + * Any object exposing ``caller_identity`` / ``principal_id`` / + ``principal.id`` (and optional ``tenant_id``) * A dict with any of the above keys """ if context is None: return None + + principal_id: str | None = None + tenant_id: str | None = None + for attr in ("caller_identity", "principal_id"): val = getattr(context, attr, None) if isinstance(val, str) and val: - return val - principal = getattr(context, "principal", None) - if principal is not None: - val = getattr(principal, "id", None) - if isinstance(val, str) and val: - return val - if isinstance(context, dict): + principal_id = val + break + if principal_id is None: + principal = getattr(context, "principal", None) + if principal is not None: + val = getattr(principal, "id", None) + if isinstance(val, str) and val: + principal_id = val + val = getattr(context, "tenant_id", None) + if isinstance(val, str) and val: + tenant_id = val + + if principal_id is None and isinstance(context, dict): for key in ("caller_identity", "principal_id"): val = context.get(key) if isinstance(val, str) and val: - return val - principal = context.get("principal") - if isinstance(principal, dict): - val = principal.get("id") + principal_id = val + break + if principal_id is None: + principal = context.get("principal") + if isinstance(principal, dict): + val = principal.get("id") + if isinstance(val, str) and val: + principal_id = val + if tenant_id is None: + val = context.get("tenant_id") if isinstance(val, str) and val: - return val - return None + tenant_id = val + + if principal_id is None: + return None + if tenant_id is None: + return principal_id + return f"{tenant_id}{_SCOPE_SEP}{principal_id}" def _clone_response(response: dict[str, Any]) -> dict[str, Any]: diff --git a/src/adcp/server/mcp_tools.py b/src/adcp/server/mcp_tools.py index 7793b5467..d2e1b9c40 100644 --- a/src/adcp/server/mcp_tools.py +++ b/src/adcp/server/mcp_tools.py @@ -8,7 +8,7 @@ import json import logging -from collections.abc import Callable +from collections.abc import Callable, Iterable from typing import Any from adcp.server.base import ADCPHandler, ToolContext @@ -850,6 +850,75 @@ # Protocol discovery tool included for all handler types _PROTOCOL_TOOLS: set[str] = {"get_adcp_capabilities"} + +# Tools the AdCP spec allows callers to invoke without an authenticated +# principal. ``get_adcp_capabilities`` is the handshake tool — any client +# has to call it before auth to discover which ops the agent supports and +# what auth scheme to use. Everything else requires a principal. +# +# Sellers wiring their own auth middleware (the SDK explicitly punts auth +# to the transport layer — see :func:`adcp.server.create_mcp_server`) +# should import this and skip auth enforcement for any tool name in the +# set. Downstream MAY extend it for discovery tools outside the AdCP spec +# (e.g. a public ``list_public_formats`` surface). The base set is the +# spec-mandated floor, not a cap. +# +# Example:: +# +# from adcp.server import DISCOVERY_TOOLS +# +# async def dispatch(self, request, call_next): +# tool = _extract_tool_name(request) +# if tool not in DISCOVERY_TOOLS: +# self._require_valid_token(request) +# return await call_next(request) +DISCOVERY_TOOLS: frozenset[str] = frozenset({"get_adcp_capabilities"}) + + +def validate_discovery_set(tools: Iterable[str]) -> None: + """Fail-closed validation for an auth-optional tool set. + + Downstream that extends :data:`DISCOVERY_TOOLS` (``DISCOVERY_TOOLS | + {"my_public_tool"}``) risks accidentally including a mutation tool, + which would silently unauthenticate writes over HTTP. This helper + asserts every name in the set resolves to a known ADCP tool whose + annotations declare ``readOnlyHint: True`` — it refuses to pass + anything mutating, destructive, or unknown. + + Call this at server startup on the effective set your middleware + uses:: + + from adcp.server import DISCOVERY_TOOLS, validate_discovery_set + + MY_DISCOVERY = DISCOVERY_TOOLS | {"list_public_formats"} + validate_discovery_set(MY_DISCOVERY) # raises early if misconfigured + + :raises ValueError: if any name in ``tools`` is unknown or resolves + to a non-read-only tool. + """ + by_name = {t["name"]: t for t in ADCP_TOOL_DEFINITIONS} + unknown: list[str] = [] + mutating: list[str] = [] + for name in tools: + tool = by_name.get(name) + if tool is None: + unknown.append(name) + continue + annotations = tool.get("annotations") or {} + if not annotations.get("readOnlyHint"): + mutating.append(name) + problems: list[str] = [] + if unknown: + problems.append(f"unknown tool(s): {sorted(unknown)}") + if mutating: + problems.append( + f"non-read-only tool(s) {sorted(mutating)} — adding these to the " + "auth-optional set would silently unauthenticate mutations" + ) + if problems: + raise ValueError("validate_discovery_set rejected the set: " + "; ".join(problems)) + + # Tools specific to each specialized handler type _HANDLER_TOOLS: dict[str, set[str]] = { "GovernanceHandler": { diff --git a/src/adcp/server/serve.py b/src/adcp/server/serve.py index 72846eef7..6fde3a10c 100644 --- a/src/adcp/server/serve.py +++ b/src/adcp/server/serve.py @@ -20,15 +20,82 @@ async def get_adcp_capabilities(self, params, context=None): import os from collections.abc import Callable -from typing import TYPE_CHECKING, Any +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Literal -from adcp.server.base import ADCPHandler +from adcp.server.base import ADCPHandler, ToolContext from adcp.server.mcp_tools import create_tool_caller, get_tools_for_handler if TYPE_CHECKING: from adcp.server.test_controller import TestControllerStore +@dataclass(frozen=True) +class RequestMetadata: + """Per-request metadata passed to :class:`ContextFactory`. + + Populated by the SDK before invoking the factory. Stable across the + MCP and A2A transports — factories written against this shape work + on both sides. Additional fields may be added in minor releases; + factories should keep accepting ``RequestMetadata`` and pluck the + fields they need by name, not by positional unpacking. + + :param tool_name: The AdCP operation being invoked (e.g. + ``"get_products"``, ``"create_media_buy"``). Useful for + tool-level audit logging and feature flagging. + :param transport: ``"mcp"`` or ``"a2a"`` — the wire protocol + currently dispatching this call. Agents that expose both can + use this to branch on transport-specific behavior. + :param request_id: The transport-assigned request id when one + exists (A2A populates this from the task id; MCP leaves it + ``None`` at the SDK layer today). + """ + + tool_name: str + transport: Literal["mcp", "a2a"] + request_id: str | None = None + + +ContextFactory = Callable[[RequestMetadata], ToolContext] +"""Factory invoked per tool call to build a :class:`ToolContext`. + +The SDK's server-side idempotency middleware reads +``ToolContext.caller_identity`` (and ``tenant_id`` for multi-tenant +scope) for cache keying, so factories wiring auth MUST populate +``caller_identity``. See :class:`~adcp.server.base.ToolContext` for +the full field contract. + +The SDK deliberately does not know how your auth middleware surfaces +the authenticated principal — different downstreams use Starlette +``request.state``, ``contextvars.ContextVar``, thread-locals, etc. +The factory closes over whatever mechanism your middleware populates +and returns a ``ToolContext`` (or subclass). + +Example using ``contextvars`` (recommended — middleware-agnostic):: + + from contextvars import ContextVar + from adcp.server import RequestMetadata, ToolContext, create_mcp_server + + _principal: ContextVar[str | None] = ContextVar( + "adcp_principal", default=None + ) + _tenant: ContextVar[str | None] = ContextVar( + "adcp_tenant", default=None + ) + + # Your HTTP middleware sets the ContextVars; tool calls read them. + def build_context(meta: RequestMetadata) -> ToolContext: + return ToolContext( + request_id=meta.request_id, + caller_identity=_principal.get(), + tenant_id=_tenant.get(), + metadata={"tool_name": meta.tool_name, "transport": meta.transport}, + ) + + mcp = create_mcp_server(MyAgent(), context_factory=build_context) +""" + + def serve( handler: ADCPHandler | Any, *, @@ -37,6 +104,7 @@ def serve( transport: str = "streamable-http", instructions: str | None = None, test_controller: TestControllerStore | None = None, + context_factory: ContextFactory | None = None, ) -> None: """Start an MCP or A2A server from an ADCP handler or server builder. @@ -92,7 +160,13 @@ async def force_account_status(self, account_id, status): handler = handler.build_handler() if transport == "a2a": - _serve_a2a(handler, name=name, port=port, test_controller=test_controller) + _serve_a2a( + handler, + name=name, + port=port, + test_controller=test_controller, + context_factory=context_factory, + ) elif transport in ("streamable-http", "sse", "stdio"): _serve_mcp( handler, @@ -101,6 +175,7 @@ async def force_account_status(self, account_id, status): transport=transport, instructions=instructions, test_controller=test_controller, + context_factory=context_factory, ) else: valid = ", ".join(sorted(("a2a", "streamable-http", "sse", "stdio"))) @@ -152,6 +227,7 @@ def _serve_mcp( transport: str, instructions: str | None, test_controller: TestControllerStore | None, + context_factory: ContextFactory | None = None, ) -> None: """Start an MCP server.""" mcp = create_mcp_server( @@ -160,6 +236,7 @@ def _serve_mcp( port=port, instructions=instructions, include_test_controller=test_controller is not None, + context_factory=context_factory, ) if test_controller is not None: @@ -214,6 +291,7 @@ def _serve_a2a( name: str, port: int | None, test_controller: TestControllerStore | None, + context_factory: ContextFactory | None = None, ) -> None: """Start an A2A server using uvicorn.""" import uvicorn @@ -222,7 +300,13 @@ def _serve_a2a( resolved_port = port or int(os.environ.get("PORT", "3001")) - app = create_a2a_server(handler, name=name, port=resolved_port, test_controller=test_controller) + app = create_a2a_server( + handler, + name=name, + port=resolved_port, + test_controller=test_controller, + context_factory=context_factory, + ) sock = _bind_reusable_socket("0.0.0.0", resolved_port) try: config = uvicorn.Config(app) @@ -244,6 +328,7 @@ def create_mcp_server( port: int | None = None, instructions: str | None = None, include_test_controller: bool = False, + context_factory: ContextFactory | None = None, ) -> Any: """Create a FastMCP server from an ADCP handler without starting it. @@ -262,24 +347,85 @@ def create_mcp_server( via :func:`register_test_controller` and sets this flag implicitly. Registering the handler stub unconditionally would advertise a tool the seller didn't opt into. + context_factory: Optional callable invoked per tool call to build + a :class:`ToolContext` from the incoming :class:`RequestMetadata`. + **Wiring this is how the server-side idempotency middleware + gets the caller identity and tenant it needs for per-principal + scoping** — a factory that returns ``caller_identity=None`` + effectively disables idempotency dedup. Sellers wiring their + own HTTP auth middleware pass this to inject the authenticated + principal into ``ToolContext.caller_identity``. See + :data:`ContextFactory` for the recommended contextvars + pattern. When ``None``, handlers receive a bare + ``ToolContext()`` (no caller identity, no tenant). Returns: - A configured FastMCP server instance. Call mcp.run() to start. - - Example: - mcp = create_mcp_server(MyAgent(), name="my-agent") - mcp.run(transport="streamable-http") + A configured FastMCP server instance. Call ``mcp.run()`` to start, + or ``mcp.streamable_http_app()`` to get the Starlette ASGI app for + mounting behind a reverse proxy / adding HTTP middleware. + + Authentication: + The SDK does not enforce authentication itself. Two integration + patterns work: + + 1. **Reverse-proxy auth** (simplest): the proxy (nginx, Caddy, + Envoy) validates credentials and forwards only authenticated + requests. The SDK trusts the proxy's decision. + + 2. **In-process HTTP middleware**: call + ``mcp.streamable_http_app()`` to get the Starlette app, then + ``app.add_middleware(YourAuthMiddleware)``. The middleware + extracts auth state per request (token, tenant, principal) + into ContextVars; ``context_factory`` reads those to build a + typed ``ToolContext``. Tools in + :data:`adcp.server.DISCOVERY_TOOLS` (``get_adcp_capabilities``) + should bypass auth per AdCP spec. See + ``examples/mcp_with_auth_middleware.py`` and + ``docs/handler-authoring.md``. + + Example (basic): + >>> mcp = create_mcp_server(MyAgent(), name="my-agent") + >>> mcp.run(transport="streamable-http") + + Example (custom auth + typed context via contextvars): + >>> from contextvars import ContextVar + >>> from adcp.server import RequestMetadata, ToolContext, create_mcp_server + >>> + >>> _principal: ContextVar[str | None] = ContextVar("p", default=None) + >>> _tenant: ContextVar[str | None] = ContextVar("t", default=None) + >>> + >>> def build_context(meta: RequestMetadata) -> ToolContext: + ... return ToolContext( + ... caller_identity=_principal.get(), + ... tenant_id=_tenant.get(), + ... ) + >>> + >>> mcp = create_mcp_server( + ... MyAgent(), name="my-agent", context_factory=build_context + ... ) + >>> app = mcp.streamable_http_app() + >>> app.add_middleware(MyAuthMiddleware) # sets the ContextVars + >>> # run via uvicorn """ from mcp.server.fastmcp import FastMCP resolved_port = port or int(os.environ.get("PORT", "3001")) mcp = FastMCP(name, instructions=instructions, port=resolved_port) - _register_handler_tools(mcp, handler, include_test_controller=include_test_controller) + _register_handler_tools( + mcp, + handler, + include_test_controller=include_test_controller, + context_factory=context_factory, + ) return mcp def _register_handler_tools( - mcp: Any, handler: ADCPHandler, *, include_test_controller: bool = False + mcp: Any, + handler: ADCPHandler, + *, + include_test_controller: bool = False, + context_factory: ContextFactory | None = None, ) -> None: """Register all ADCP tools from a handler onto a FastMCP server.""" tool_defs = get_tools_for_handler(handler) @@ -293,7 +439,14 @@ def _register_handler_tools( description = tool_def.get("description", "") input_schema = tool_def.get("inputSchema", {"type": "object", "properties": {}}) caller = create_tool_caller(handler, tool_name) - _register_tool(mcp, tool_name, description, input_schema, caller) + _register_tool( + mcp, + tool_name, + description, + input_schema, + caller, + context_factory=context_factory, + ) def _register_tool( @@ -302,6 +455,8 @@ def _register_tool( description: str, input_schema: dict[str, Any], caller: Callable[..., Any], + *, + context_factory: ContextFactory | None = None, ) -> None: """Register a single ADCP tool on a FastMCP server. @@ -318,17 +473,28 @@ def _register_tool( from adcp.server.translate import translate_error async def fn(**kwargs: Any) -> dict[str, Any]: - # Note on caller identity: FastMCP does not expose an authenticated - # principal to tool handlers at the SDK level — ``Context.client_id`` - # is a session hint, not an authenticated user identifier. Sellers - # who need per-principal server middleware (e.g. the idempotency - # store's per-principal scoping) should wire their own FastMCP auth - # middleware and either pre-populate ``params`` with a principal - # hint their handler reads, or override ``create_tool_caller`` to - # build a ToolContext from their auth layer. The A2A transport - # derives caller_identity from ServerCallContext.user automatically. + # Caller identity: FastMCP does not expose an authenticated principal + # at the SDK level (``Context.client_id`` is a session hint, not an + # authenticated user). Sellers wire auth via HTTP middleware on + # ``mcp.streamable_http_app()`` and pass ``context_factory`` to + # ``create_mcp_server()`` — the factory reads a ``contextvars.ContextVar`` + # the middleware populates and returns a typed ``ToolContext``. + # The A2A transport derives ``caller_identity`` from + # ``ServerCallContext.user`` automatically. + context: ToolContext | None = None + if context_factory is not None: + meta = RequestMetadata(tool_name=name, transport="mcp") + context = context_factory(meta) + if not isinstance(context, ToolContext): + # Catch downstream factories that return a dict or other + # shape early — otherwise the handler explodes deep inside + # with an AttributeError on caller_identity. + raise TypeError( + f"context_factory for tool {name!r} returned " + f"{type(context).__name__}, not a ToolContext instance" + ) try: - result = await caller(kwargs) + result = await caller(kwargs, context=context) except ADCPError as exc: # Translate AdCP-typed exceptions (IdempotencyConflictError, # ADCPTaskError with a spec code, etc.) into a ToolError so FastMCP diff --git a/src/adcp/types/__init__.py b/src/adcp/types/__init__.py index 3941ab5a0..e2f1f4275 100644 --- a/src/adcp/types/__init__.py +++ b/src/adcp/types/__init__.py @@ -28,6 +28,11 @@ from __future__ import annotations # Apply type coercion to generated types (must be imported before other types) +# Note: the deprecation shim for the removed ``format_category`` submodule +# lives as a real file at ``generated_poc/enums/format_category.py`` — no +# sys.modules dance needed; Python's import system picks it up natively. +# ``scripts/post_generate_fixes.py`` restores that file after codegen wipes +# ``generated_poc/``. from adcp.types import ( _ergonomic, # noqa: F401 aliases, # noqa: F401 @@ -396,6 +401,8 @@ ActivateSignalSuccessResponse, AgentDeployment, AgentDestination, + AudioFormatAsset, + AudioFormatGroupAsset, AuthorizedAgent, AuthorizedAgentsByInlineProperties, AuthorizedAgentsByPropertyId, @@ -404,10 +411,12 @@ AuthorizedAgentsBySignalId, AuthorizedAgentsBySignalTag, BothPreviewRender, + BriefFormatAsset, BuildCreativeErrorResponse, BuildCreativeSuccessResponse, CalibrateContentErrorResponse, CalibrateContentSuccessResponse, + CatalogFormatAsset, CatalogGroupBinding, ComplyErrorResponse, ComplyListScenariosResponse, @@ -418,6 +427,10 @@ CreateContentStandardsSuccessResponse, CreateMediaBuyErrorResponse, CreateMediaBuySuccessResponse, + CssFormatAsset, + CssFormatGroupAsset, + DaastFormatAsset, + DaastFormatGroupAsset, Deployment, Destination, GetAccountFinancialsErrorResponse, @@ -442,14 +455,22 @@ GetRightsSuccessResponse, GetSignalsDiscoveryRequest, GetSignalsLookupRequest, + HtmlFormatAsset, + HtmlFormatGroupAsset, HtmlPreviewRender, + ImageFormatAsset, + ImageFormatGroupAsset, InlineDaastAsset, InlineVastAsset, + JavascriptFormatAsset, + JavascriptFormatGroupAsset, KeyValueActivationKey, ListContentStandardsErrorResponse, ListContentStandardsSuccessResponse, LogEventErrorResponse, LogEventSuccessResponse, + MarkdownFormatAsset, + MarkdownFormatGroupAsset, MediaBuyDeliveryStatus, PlatformDeployment, PlatformDestination, @@ -467,6 +488,7 @@ PublisherPropertiesAll, PublisherPropertiesById, PublisherPropertiesByTag, + RepeatableAssetGroup, SegmentIdActivationKey, SiSendActionResponseRequest, SiSendTextMessageRequest, @@ -483,6 +505,8 @@ SyncCreativesSuccessResponse, SyncEventSourcesErrorResponse, SyncEventSourcesSuccessResponse, + TextFormatAsset, + TextFormatGroupAsset, UpdateContentStandardsErrorResponse, UpdateContentStandardsSuccessResponse, UpdateMediaBuyErrorResponse, @@ -490,10 +514,18 @@ UpdateMediaBuyPropertiesRequest, UpdateMediaBuySuccessResponse, UrlDaastAsset, + UrlFormatAsset, + UrlFormatGroupAsset, UrlPreviewRender, UrlVastAsset, ValidateContentDeliveryErrorResponse, ValidateContentDeliverySuccessResponse, + VastFormatAsset, + VastFormatGroupAsset, + VideoFormatAsset, + VideoFormatGroupAsset, + WebhookFormatAsset, + WebhookFormatGroupAsset, ) # Re-export core types (not in generated, but part of public API) @@ -1108,6 +1140,35 @@ def __init__(self, *args: object, **kwargs: object) -> None: "ComplyStateTransitionResponse", "ComplySimulationResponse", "ComplyErrorResponse", + # Creative format asset slot aliases (item_type='individual') + "ImageFormatAsset", + "VideoFormatAsset", + "AudioFormatAsset", + "TextFormatAsset", + "MarkdownFormatAsset", + "HtmlFormatAsset", + "CssFormatAsset", + "JavascriptFormatAsset", + "VastFormatAsset", + "DaastFormatAsset", + "UrlFormatAsset", + "WebhookFormatAsset", + "BriefFormatAsset", + "CatalogFormatAsset", + # Creative format asset slot aliases (repeatable groups) + "RepeatableAssetGroup", + "ImageFormatGroupAsset", + "VideoFormatGroupAsset", + "AudioFormatGroupAsset", + "TextFormatGroupAsset", + "MarkdownFormatGroupAsset", + "HtmlFormatGroupAsset", + "CssFormatGroupAsset", + "JavascriptFormatGroupAsset", + "VastFormatGroupAsset", + "DaastFormatGroupAsset", + "UrlFormatGroupAsset", + "WebhookFormatGroupAsset", # Audiences responses "MediaBuyDeliveryStatus", "SyncAudiencesErrorResponse", diff --git a/src/adcp/types/aliases.py b/src/adcp/types/aliases.py index e6223ab28..d2ffd660f 100644 --- a/src/adcp/types/aliases.py +++ b/src/adcp/types/aliases.py @@ -1170,6 +1170,262 @@ def get_pricing(options: list[PricingOption]) -> None: ``` """ +# ============================================================================ +# CREATIVE FORMAT ASSET ALIASES - Discriminated Union on asset_type +# ============================================================================ +# AdCP creative format definitions enumerate asset slots as a discriminated +# union on the ``asset_type`` field (image, video, audio, text, markdown, +# html, css, javascript, vast, daast, url, webhook, brief, catalog). The +# code generator emits numbered class names (``Assets``, ``Assets81``, +# ``Assets82``, ...) that renumber between releases whenever the upstream +# ``$defs`` ordering shifts. These aliases pin semantic names so consumers +# never import ``AssetsNN`` directly. +# +# Two asset shapes exist in a format definition: +# +# 1. Individual asset slots (``item_type='individual'``) — top-level slots in +# a creative format. Aliased as ``FormatAsset``. The ``Format`` +# prefix disambiguates from the separate asset-content types (``VideoAsset``, +# ``HtmlAsset``, etc. in ``adcp.types``) which describe the actual asset +# payload (codec, duration, file URL) delivered by creative sync — a +# distinct concept. +# 2. Group asset variants — the same asset types nested inside a +# ``RepeatableAssetGroup`` (``Assets94``). Aliased as ``FormatGroupAsset``. +# +# Stability contract: these aliases are covered by +# ``tests/test_asset_aliases_stable.py`` which asserts each alias resolves +# to a class whose ``asset_type`` literal default matches the expected +# value. Generator renumbering is caught there, not in downstream code. + +from adcp.types.generated_poc.core.format import ( + Assets as _ImageFormatAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets81 as _VideoFormatAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets82 as _AudioFormatAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets83 as _TextFormatAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets84 as _MarkdownFormatAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets85 as _HtmlFormatAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets86 as _CssFormatAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets87 as _JavascriptFormatAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets88 as _VastFormatAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets89 as _DaastFormatAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets90 as _UrlFormatAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets91 as _WebhookFormatAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets92 as _BriefFormatAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets93 as _CatalogFormatAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets94 as _RepeatableAssetGroupInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets95 as _ImageFormatGroupAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets96 as _VideoFormatGroupAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets97 as _AudioFormatGroupAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets98 as _TextFormatGroupAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets99 as _MarkdownFormatGroupAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets100 as _HtmlFormatGroupAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets101 as _CssFormatGroupAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets102 as _JavascriptFormatGroupAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets103 as _VastFormatGroupAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets104 as _DaastFormatGroupAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets105 as _UrlFormatGroupAssetInternal, +) +from adcp.types.generated_poc.core.format import ( + Assets106 as _WebhookFormatGroupAssetInternal, +) + +ImageFormatAsset = _ImageFormatAssetInternal +"""Image asset slot in a creative format (asset_type='image'). + +Distinct from ``ImageAsset`` in ``adcp.types`` (the asset-content type +describing an actual image payload — dimensions, file URL, etc.). This +alias names the slot shape used inside a format definition. +""" + +VideoFormatAsset = _VideoFormatAssetInternal +"""Video asset slot in a creative format (asset_type='video'). + +Distinct from ``VideoAsset`` in ``adcp.types`` (the asset-content type +describing an actual video payload — codec, duration, file URL). This +alias names the slot shape used inside a format definition. +""" + +AudioFormatAsset = _AudioFormatAssetInternal +"""Audio asset slot in a creative format (asset_type='audio'). + +Distinct from ``AudioAsset`` in ``adcp.types`` (the asset-content type +describing an actual audio payload). This alias names the slot shape +used inside a format definition. +""" + +TextFormatAsset = _TextFormatAssetInternal +"""Text asset slot in a creative format (asset_type='text'). + +Distinct from ``TextAsset`` in ``adcp.types``. This alias names the slot +shape used inside a format definition. +""" + +MarkdownFormatAsset = _MarkdownFormatAssetInternal +"""Markdown asset slot in a creative format (asset_type='markdown'). + +Distinct from the asset-content type in ``adcp.types``. This alias names +the slot shape used inside a format definition. +""" + +HtmlFormatAsset = _HtmlFormatAssetInternal +"""HTML asset slot in a creative format (asset_type='html'). + +Distinct from ``HtmlAsset`` in ``adcp.types`` (the asset-content type +describing actual HTML payload). This alias names the slot shape used +inside a format definition. +""" + +CssFormatAsset = _CssFormatAssetInternal +"""CSS asset slot in a creative format (asset_type='css'). + +Distinct from ``CssAsset`` in ``adcp.types``. This alias names the slot +shape used inside a format definition. +""" + +JavascriptFormatAsset = _JavascriptFormatAssetInternal +"""JavaScript asset slot in a creative format (asset_type='javascript'). + +Distinct from ``JavascriptAsset`` in ``adcp.types``. This alias names +the slot shape used inside a format definition. +""" + +VastFormatAsset = _VastFormatAssetInternal +"""VAST asset slot in a creative format (asset_type='vast'). + +Distinct from ``UrlVastAsset`` / ``InlineVastAsset``, which describe how +the VAST content itself is delivered (url vs inline payload). This alias +names the asset slot inside a format definition. +""" + +DaastFormatAsset = _DaastFormatAssetInternal +"""DAAST asset slot in a creative format (asset_type='daast'). + +Distinct from ``UrlDaastAsset`` / ``InlineDaastAsset``, which describe how +the DAAST content itself is delivered (url vs inline payload). +""" + +UrlFormatAsset = _UrlFormatAssetInternal +"""URL asset slot in a creative format (asset_type='url'). + +Distinct from ``UrlAsset`` in ``adcp.types``. This alias names the slot +shape used inside a format definition. +""" + +WebhookFormatAsset = _WebhookFormatAssetInternal +"""Webhook asset slot in a creative format (asset_type='webhook'). + +Distinct from ``WebhookAsset`` in ``adcp.types``. This alias names the +slot shape used inside a format definition. +""" + +BriefFormatAsset = _BriefFormatAssetInternal +"""Brief asset slot in a creative format (asset_type='brief'). + +Distinct from ``BriefAsset`` in ``adcp.types``. This alias names the +slot shape used inside a format definition. +""" + +CatalogFormatAsset = _CatalogFormatAssetInternal +"""Catalog asset slot in a creative format (asset_type='catalog'). + +Distinct from ``CatalogAsset`` in ``adcp.types``. This alias names the +slot shape used inside a format definition. +""" + +RepeatableAssetGroup = _RepeatableAssetGroupInternal +"""Repeatable asset group in a creative format (item_type='repeatable_group'). + +Holds a sequence of group-variant assets (``FormatGroupAsset``) that +repeat either sequentially (carousels, playlists) or via platform +optimization. +""" + +ImageFormatGroupAsset = _ImageFormatGroupAssetInternal +"""Image asset nested in a RepeatableAssetGroup (asset_type='image').""" + +VideoFormatGroupAsset = _VideoFormatGroupAssetInternal +"""Video asset nested in a RepeatableAssetGroup (asset_type='video').""" + +AudioFormatGroupAsset = _AudioFormatGroupAssetInternal +"""Audio asset nested in a RepeatableAssetGroup (asset_type='audio').""" + +TextFormatGroupAsset = _TextFormatGroupAssetInternal +"""Text asset nested in a RepeatableAssetGroup (asset_type='text').""" + +MarkdownFormatGroupAsset = _MarkdownFormatGroupAssetInternal +"""Markdown asset nested in a RepeatableAssetGroup (asset_type='markdown').""" + +HtmlFormatGroupAsset = _HtmlFormatGroupAssetInternal +"""HTML asset nested in a RepeatableAssetGroup (asset_type='html').""" + +CssFormatGroupAsset = _CssFormatGroupAssetInternal +"""CSS asset nested in a RepeatableAssetGroup (asset_type='css').""" + +JavascriptFormatGroupAsset = _JavascriptFormatGroupAssetInternal +"""JavaScript asset nested in a RepeatableAssetGroup (asset_type='javascript').""" + +VastFormatGroupAsset = _VastFormatGroupAssetInternal +"""VAST asset slot nested in a RepeatableAssetGroup (asset_type='vast').""" + +DaastFormatGroupAsset = _DaastFormatGroupAssetInternal +"""DAAST asset slot nested in a RepeatableAssetGroup (asset_type='daast').""" + +UrlFormatGroupAsset = _UrlFormatGroupAssetInternal +"""URL asset slot nested in a RepeatableAssetGroup (asset_type='url').""" + +WebhookFormatGroupAsset = _WebhookFormatGroupAssetInternal +"""Webhook asset slot nested in a RepeatableAssetGroup (asset_type='webhook').""" + # ============================================================================ # EXPORTS # ============================================================================ @@ -1328,4 +1584,33 @@ def get_pricing(options: list[PricingOption]) -> None: "ComplyStateTransitionResponse", "ComplySimulationResponse", "ComplyErrorResponse", + # Creative format asset slot aliases (item_type='individual') + "ImageFormatAsset", + "VideoFormatAsset", + "AudioFormatAsset", + "TextFormatAsset", + "MarkdownFormatAsset", + "HtmlFormatAsset", + "CssFormatAsset", + "JavascriptFormatAsset", + "VastFormatAsset", + "DaastFormatAsset", + "UrlFormatAsset", + "WebhookFormatAsset", + "BriefFormatAsset", + "CatalogFormatAsset", + # Creative format asset slot aliases (repeatable groups) + "RepeatableAssetGroup", + "ImageFormatGroupAsset", + "VideoFormatGroupAsset", + "AudioFormatGroupAsset", + "TextFormatGroupAsset", + "MarkdownFormatGroupAsset", + "HtmlFormatGroupAsset", + "CssFormatGroupAsset", + "JavascriptFormatGroupAsset", + "VastFormatGroupAsset", + "DaastFormatGroupAsset", + "UrlFormatGroupAsset", + "WebhookFormatGroupAsset", ] diff --git a/src/adcp/types/generated_poc/enums/format_category.py b/src/adcp/types/generated_poc/enums/format_category.py new file mode 100644 index 000000000..89c830e29 --- /dev/null +++ b/src/adcp/types/generated_poc/enums/format_category.py @@ -0,0 +1,23 @@ +"""Deprecation shim for the removed ``format_category`` submodule. + +``FormatCategory`` was replaced by free-form ``FormatId`` strings in +AdCP 3.0. See MIGRATION_v3_to_v4.md for the full migration path. + +Importing this module raises :class:`ImportError` with a pointer to the +migration guide — so downstream import sites like:: + + from adcp.types.generated_poc.enums.format_category import FormatCategory + +get the same pointer as the top-level ``from adcp import FormatCategory``, +instead of a bare ``ModuleNotFoundError``. + +This file is restored after every codegen run by +``scripts/post_generate_fixes.py`` (which wipes ``generated_poc/``). +""" + +raise ImportError( + "adcp.types.generated_poc.enums.format_category was removed in AdCP 3.0. " + "Use free-form format-id strings (e.g. 'goog:video_responsive_ad') via " + "adcp.types.FormatId. See MIGRATION_v3_to_v4.md#creative-format-asset-slots-formataasset-aliases " + "for details." +) diff --git a/tests/fixtures/public_api_snapshot.json b/tests/fixtures/public_api_snapshot.json index d35d20065..d0726f00b 100644 --- a/tests/fixtures/public_api_snapshot.json +++ b/tests/fixtures/public_api_snapshot.json @@ -400,6 +400,8 @@ "Assignments", "AudienceSource", "AudioAsset", + "AudioFormatAsset", + "AudioFormatGroupAsset", "Authentication", "AuthenticationScheme", "AuthorizedAgent", @@ -416,6 +418,7 @@ "BothPreviewRender", "BrandReference", "BrandSource", + "BriefFormatAsset", "BuildCreativeErrorResponse", "BuildCreativeRequest", "BuildCreativeResponse", @@ -433,6 +436,7 @@ "CatalogFieldBinding", "CatalogFieldBinding1", "CatalogFieldMapping", + "CatalogFormatAsset", "CatalogGroupBinding", "CatalogItemStatus", "CatalogRequirements", @@ -493,6 +497,10 @@ "CreativeStatus", "CreativeVariant", "CssAsset", + "CssFormatAsset", + "CssFormatGroupAsset", + "DaastFormatAsset", + "DaastFormatGroupAsset", "DaastTrackingEvent", "DaastVersion", "DailyBreakdownItem", @@ -598,16 +606,22 @@ "GetSignalsResponse", "Gtin", "HtmlAsset", + "HtmlFormatAsset", + "HtmlFormatGroupAsset", "HtmlPreviewRender", "HttpMethod", "Identifier", "IdentityMatchRequest", "IdentityMatchResponse", "ImageAsset", + "ImageFormatAsset", + "ImageFormatGroupAsset", "InlineDaastAsset", "InlineVastAsset", "Input", "JavascriptAsset", + "JavascriptFormatAsset", + "JavascriptFormatGroupAsset", "JavascriptModuleType", "KellerType", "KeyValueActivationKey", @@ -634,6 +648,8 @@ "LogEventSuccessResponse", "Logo", "MarkdownFlavor", + "MarkdownFormatAsset", + "MarkdownFormatGroupAsset", "McpWebhookPayload", "MeasurementPeriod", "MediaBuy", @@ -729,6 +745,7 @@ "QuerySummary", "ReachUnit", "Refine", + "RepeatableAssetGroup", "ReportPlanOutcomeRequest", "ReportPlanOutcomeResponse", "ReportUsageRequest", @@ -809,6 +826,8 @@ "TaskResult", "TaskType", "TextAsset", + "TextFormatAsset", + "TextFormatGroupAsset", "TextSubAsset", "TimeBasedPricingOption", "TimeUnit", @@ -838,6 +857,8 @@ "UrlAsset", "UrlAssetType", "UrlDaastAsset", + "UrlFormatAsset", + "UrlFormatGroupAsset", "UrlPreviewRender", "UrlType", "UrlVastAsset", @@ -846,6 +867,8 @@ "ValidateContentDeliveryResponse", "ValidateContentDeliverySuccessResponse", "ValidationMode", + "VastFormatAsset", + "VastFormatGroupAsset", "VastTrackingEvent", "VastVersion", "VcpmAuctionPricingOption", @@ -853,9 +876,13 @@ "VcpmPricingOption", "VenueBreakdownItem", "VideoAsset", + "VideoFormatAsset", + "VideoFormatGroupAsset", "ViewThreshold", "WcagLevel", "WebhookAsset", + "WebhookFormatAsset", + "WebhookFormatGroupAsset", "WebhookMetadata", "WebhookResponseType", "aliases", diff --git a/tests/test_asset_aliases_stable.py b/tests/test_asset_aliases_stable.py new file mode 100644 index 000000000..50300fe6f --- /dev/null +++ b/tests/test_asset_aliases_stable.py @@ -0,0 +1,155 @@ +"""Stability contract for creative-format asset aliases. + +`adcp.types.aliases` pins semantic names onto ``AssetsNN`` generated classes +so downstream consumers never import the numbered form. datamodel-codegen +renumbers these whenever the upstream ``$defs`` ordering shifts, which +silently breaks consumers that pinned to ``Assets5`` or ``Assets14``. + +These tests assert each semantic alias still resolves to a class whose +``asset_type`` (or ``item_type`` for the group container) discriminator +default matches the expected literal. If a test here fails after a schema +regeneration, the generator renumbered something and the corresponding +alias in ``src/adcp/types/aliases.py`` needs its numbered import updated. + +The contract is intentionally loud: add/remove aliases here in lockstep +with ``aliases.py`` so the public API remains stable across generator runs. +""" + +from __future__ import annotations + +from typing import Literal + +import pytest + +from adcp import types as adcp_types + +INDIVIDUAL_ASSET_EXPECTATIONS: dict[str, str] = { + "ImageFormatAsset": "image", + "VideoFormatAsset": "video", + "AudioFormatAsset": "audio", + "TextFormatAsset": "text", + "MarkdownFormatAsset": "markdown", + "HtmlFormatAsset": "html", + "CssFormatAsset": "css", + "JavascriptFormatAsset": "javascript", + "VastFormatAsset": "vast", + "DaastFormatAsset": "daast", + "UrlFormatAsset": "url", + "WebhookFormatAsset": "webhook", + "BriefFormatAsset": "brief", + "CatalogFormatAsset": "catalog", +} + +GROUP_ASSET_EXPECTATIONS: dict[str, str] = { + "ImageFormatGroupAsset": "image", + "VideoFormatGroupAsset": "video", + "AudioFormatGroupAsset": "audio", + "TextFormatGroupAsset": "text", + "MarkdownFormatGroupAsset": "markdown", + "HtmlFormatGroupAsset": "html", + "CssFormatGroupAsset": "css", + "JavascriptFormatGroupAsset": "javascript", + "VastFormatGroupAsset": "vast", + "DaastFormatGroupAsset": "daast", + "UrlFormatGroupAsset": "url", + "WebhookFormatGroupAsset": "webhook", +} + + +def _literal_value(cls, field_name: str) -> str | None: + """Return the single literal value on a Pydantic discriminator field. + + Parses the field's ``annotation`` (the ``Literal[...]`` type) rather + than reading ``FieldInfo.default`` — several generated discriminator + fields declare the literal as the annotation without also setting a + default, which would make ``FieldInfo.default`` return + ``PydanticUndefined`` and any equality comparison vacuously pass. + Reading the annotation catches renumbering even when defaults aren't + populated on the field. + """ + from typing import get_args, get_origin + + field = cls.model_fields[field_name] + annotation = field.annotation + if get_origin(annotation) is Literal: + args = get_args(annotation) + if len(args) == 1 and isinstance(args[0], str): + return args[0] + # Fall back to FieldInfo.default when the annotation isn't a bare + # Literal (Annotated[Literal[...], Field(default=...)] shape). + default = field.default + if isinstance(default, str): + return default + return None + + +@pytest.mark.parametrize( + ("alias_name", "expected_asset_type"), + list(INDIVIDUAL_ASSET_EXPECTATIONS.items()), +) +def test_individual_asset_alias_resolves_to_expected_discriminator( + alias_name: str, expected_asset_type: str +) -> None: + cls = getattr(adcp_types, alias_name) + asset_type = _literal_value(cls, "asset_type") + assert asset_type == expected_asset_type, ( + f"{alias_name} resolved to class with asset_type={asset_type!r}; " + f"expected {expected_asset_type!r}. The generator likely renumbered " + "AssetsNN — update the numbered import in src/adcp/types/aliases.py " + "to point at the class matching this asset_type." + ) + item_type = _literal_value(cls, "item_type") + assert item_type == "individual", ( + f"{alias_name} resolved to class with item_type={item_type!r}; " "expected 'individual'." + ) + + +@pytest.mark.parametrize( + ("alias_name", "expected_asset_type"), + list(GROUP_ASSET_EXPECTATIONS.items()), +) +def test_group_asset_alias_resolves_to_expected_discriminator( + alias_name: str, expected_asset_type: str +) -> None: + cls = getattr(adcp_types, alias_name) + asset_type = _literal_value(cls, "asset_type") + assert asset_type == expected_asset_type, ( + f"{alias_name} resolved to class with asset_type={asset_type!r}; " + f"expected {expected_asset_type!r}. The generator likely renumbered " + "AssetsNN — update the numbered import in src/adcp/types/aliases.py " + "to point at the class matching this asset_type." + ) + + +def test_repeatable_asset_group_discriminator_is_stable() -> None: + cls = adcp_types.RepeatableAssetGroup + item_type = _literal_value(cls, "item_type") + assert item_type == "repeatable_group", ( + f"RepeatableAssetGroup.item_type={item_type!r}; " "expected 'repeatable_group'." + ) + + +def test_format_category_module_raises_migration_pointer() -> None: + # MIGRATION_v3_to_v4: `FormatCategory` was removed from the generated + # schemas in AdCP 3.0. Importing the old module path now raises a + # guided ImportError instead of ModuleNotFoundError. + with pytest.raises(ImportError, match="MIGRATION_v3_to_v4"): + from adcp.types.generated_poc.enums.format_category import ( # noqa: F401 + FormatCategory, + ) + + +def test_all_aliases_exported_from_adcp_types() -> None: + # `adcp.types` is the canonical public surface per CLAUDE.md. + # Top-level ``adcp`` re-exports a curated subset; these format-asset + # aliases are specifically reachable via ``from adcp.types import X``. + missing = [ + name + for name in ( + *INDIVIDUAL_ASSET_EXPECTATIONS, + *GROUP_ASSET_EXPECTATIONS, + "RepeatableAssetGroup", + ) + if not hasattr(adcp_types, name) + ] + assert not missing, f"Asset aliases missing from adcp.types: {missing}" diff --git a/tests/test_mcp_middleware_composition.py b/tests/test_mcp_middleware_composition.py new file mode 100644 index 000000000..39515e306 --- /dev/null +++ b/tests/test_mcp_middleware_composition.py @@ -0,0 +1,310 @@ +"""Integration test: custom HTTP middleware composes with SDK-registered tools. + +Downstream agents (salesagent, creative agents) need to wire their own +auth middleware around tools registered by ``create_mcp_server()``. This +test proves the composition path works end-to-end: + +1. ``mcp.streamable_http_app()`` returns a Starlette app that accepts + ``.add_middleware()``. +2. The middleware fires before tool dispatch and can reject requests + (401 Unauthorized) or let them through. +3. When the middleware lets the request through, a ``context_factory`` + passed to ``create_mcp_server()`` builds a :class:`ToolContext` the + handler receives — populated from the middleware's side-channel + (``contextvars.ContextVar``). +4. Tools in :data:`adcp.server.DISCOVERY_TOOLS` are callable without + auth (the spec-mandated handshake path). + +If any of this regresses, salesagent and every other downstream has to +keep their wrapper layer (``mcp_context_wrapper.py``, custom +``@mcp.tool()`` scaffolding) forever. Failing here is the signal to fix +the integration, not the test. +""" + +from __future__ import annotations + +from contextvars import ContextVar +from typing import Any + +import httpx +import pytest +from asgi_lifespan import LifespanManager +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import JSONResponse + +from adcp.server import ( + DISCOVERY_TOOLS, + ADCPHandler, + RequestMetadata, + ToolContext, + create_mcp_server, +) + +_current_principal: ContextVar[str | None] = ContextVar("test_current_principal", default=None) +_current_tenant: ContextVar[str | None] = ContextVar("test_current_tenant", default=None) + + +class _RecordingHandler(ADCPHandler): + """Handler that records the ToolContext each call received.""" + + def __init__(self) -> None: + self.calls: list[ToolContext | None] = [] + + async def get_adcp_capabilities( + self, params: Any, context: ToolContext | None = None + ) -> dict[str, Any]: + self.calls.append(context) + return {"adcp": {"major_versions": [3]}} + + async def get_products(self, params: Any, context: ToolContext | None = None) -> dict[str, Any]: + self.calls.append(context) + return {"products": []} + + +class _AuthMiddleware(BaseHTTPMiddleware): + """Middleware that validates Authorization headers. + + Rejects any tool call except :data:`DISCOVERY_TOOLS` without a valid + token. On a valid token, stashes principal + tenant in ContextVars + so the handler-side ``context_factory`` can read them. + """ + + VALID_TOKENS: dict[str, tuple[str, str]] = { + "token-acme": ("principal-acme-1", "tenant-acme"), + "token-beta": ("principal-beta-9", "tenant-beta"), + } + + async def dispatch(self, request: Request, call_next: Any) -> Any: + tool_name = await _peek_tool_name(request) + + if tool_name not in DISCOVERY_TOOLS: + auth = request.headers.get("authorization", "") + token = auth.removeprefix("Bearer ").strip() + if token not in self.VALID_TOKENS: + return JSONResponse({"error": "unauthenticated"}, status_code=401) + principal, tenant = self.VALID_TOKENS[token] + _principal_token = _current_principal.set(principal) + _tenant_token = _current_tenant.set(tenant) + else: + _principal_token = _current_principal.set(None) + _tenant_token = _current_tenant.set(None) + + try: + return await call_next(request) + finally: + _current_principal.reset(_principal_token) + _current_tenant.reset(_tenant_token) + + +async def _peek_tool_name(request: Request) -> str | None: + """Extract the MCP tool name from the incoming JSON-RPC body without + consuming the request body for downstream handlers.""" + # Starlette caches ``request._body`` on first read, so subsequent + # reads inside the app still see the bytes. + body = await request.body() + if not body: + return None + try: + import json + + payload = json.loads(body) + except ValueError: + return None + if payload.get("method") != "tools/call": + return None + params = payload.get("params") or {} + name = params.get("name") + return name if isinstance(name, str) else None + + +def _build_context(meta: RequestMetadata) -> ToolContext: + return ToolContext( + request_id=meta.request_id, + caller_identity=_current_principal.get(), + tenant_id=_current_tenant.get(), + metadata={"tool_name": meta.tool_name, "transport": meta.transport}, + ) + + +@pytest.fixture +async def handler_and_client() -> Any: + handler = _RecordingHandler() + mcp = create_mcp_server( + handler, + name="test-agent", + context_factory=_build_context, + ) + # Force stateless JSON responses. Production deployments mount the + # MCP app behind a reverse proxy; this test covers that shape. + mcp.settings.stateless_http = True + mcp.settings.json_response = True + # Allow in-process test host — MCP's DNS-rebinding protection + # rejects unknown Host headers by default when enabled. + mcp.settings.transport_security.allowed_hosts = ["localhost", "127.0.0.1"] + app = mcp.streamable_http_app() + app.add_middleware(_AuthMiddleware) + + # FastMCP's streamable HTTP session manager initializes a TaskGroup + # via the Starlette app lifespan. httpx.ASGITransport does not run + # lifespan by default — asgi-lifespan handles startup/shutdown and + # surfaces exceptions raised during startup so test failures report + # the real error instead of hanging. + async with LifespanManager(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://localhost", + follow_redirects=True, + ) as client: + yield handler, client + + +@pytest.mark.asyncio +async def test_discovery_tool_is_callable_without_auth(handler_and_client: Any) -> None: + handler, client = handler_and_client + + await _initialize_session(client) + response = await _call_tool(client, "get_adcp_capabilities", {}) + + assert response.status_code == 200, response.text + payload = _parse_event_stream(response.text) + assert "result" in payload, payload + assert handler.calls, "handler was not invoked" + call_context = handler.calls[-1] + # Discovery calls have no authenticated principal — that's the whole point. + assert call_context is not None + assert call_context.caller_identity is None + assert call_context.tenant_id is None + + +@pytest.mark.asyncio +async def test_authenticated_tool_call_populates_caller_identity( + handler_and_client: Any, +) -> None: + handler, client = handler_and_client + + await _initialize_session(client, headers={"Authorization": "Bearer token-acme"}) + response = await _call_tool( + client, + "get_products", + {"brief": "coffee"}, + headers={"Authorization": "Bearer token-acme"}, + ) + + assert response.status_code == 200, response.text + call_context = handler.calls[-1] + assert call_context is not None + assert call_context.caller_identity == "principal-acme-1" + assert call_context.tenant_id == "tenant-acme" + + +@pytest.mark.asyncio +async def test_missing_token_blocks_non_discovery_tool(handler_and_client: Any) -> None: + handler, client = handler_and_client + + response = await _call_tool(client, "get_products", {"brief": "coffee"}) + + assert response.status_code == 401 + assert not handler.calls, ( + "handler was invoked despite missing auth — middleware did NOT " + "compose with the tool dispatch" + ) + + +def test_discovery_tools_frozenset_contract() -> None: + # Protects against accidental widening/narrowing of the spec-mandated + # auth-optional set. Callers extend via ``DISCOVERY_TOOLS | {...}``. + assert DISCOVERY_TOOLS == frozenset({"get_adcp_capabilities"}) + + +def test_validate_discovery_set_accepts_base_set() -> None: + from adcp.server import validate_discovery_set + + # The base DISCOVERY_TOOLS set must always validate — any regression + # here means we added a mutation tool to the spec-mandated handshake. + validate_discovery_set(DISCOVERY_TOOLS) + + +def test_validate_discovery_set_accepts_read_only_extension() -> None: + from adcp.server import validate_discovery_set + + # list_creative_formats is annotated read-only — downstream that + # wants to make format listing public should be allowed to. + validate_discovery_set(DISCOVERY_TOOLS | {"list_creative_formats"}) + + +def test_validate_discovery_set_rejects_mutation_tool() -> None: + from adcp.server import validate_discovery_set + + with pytest.raises(ValueError, match="non-read-only"): + validate_discovery_set(DISCOVERY_TOOLS | {"create_media_buy"}) + + +def test_validate_discovery_set_rejects_unknown_tool() -> None: + from adcp.server import validate_discovery_set + + with pytest.raises(ValueError, match="unknown tool"): + validate_discovery_set(DISCOVERY_TOOLS | {"not_a_real_tool"}) + + +# ---------------------------------------------------------------------- +# Helpers +# ---------------------------------------------------------------------- + + +async def _initialize_session( + client: httpx.AsyncClient, *, headers: dict[str, str] | None = None +) -> httpx.Response: + """Send an MCP ``initialize`` JSON-RPC call — FastMCP requires this + before ``tools/call`` even in stateless mode.""" + request_headers = { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + } + if headers: + request_headers.update(headers) + body = { + "jsonrpc": "2.0", + "id": 0, + "method": "initialize", + "params": { + "protocolVersion": "2025-06-18", + "capabilities": {}, + "clientInfo": {"name": "test-client", "version": "1.0"}, + }, + } + return await client.post("/mcp/", json=body, headers=request_headers) + + +async def _call_tool( + client: httpx.AsyncClient, + tool_name: str, + arguments: dict[str, Any], + *, + headers: dict[str, str] | None = None, +) -> httpx.Response: + """POST a JSON-RPC ``tools/call`` to the MCP endpoint.""" + request_headers = { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + } + if headers: + request_headers.update(headers) + body = { + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": {"name": tool_name, "arguments": arguments}, + } + return await client.post("/mcp/", json=body, headers=request_headers) + + +def _parse_event_stream(body: str) -> dict[str, Any]: + """Parse SSE event-stream body from FastMCP into a dict.""" + import json + + for line in body.splitlines(): + line = line.strip() + if line.startswith("data: "): + return json.loads(line.removeprefix("data: ")) + return json.loads(body) if body.strip() else {} diff --git a/tests/test_server_idempotency.py b/tests/test_server_idempotency.py index 6dbd1380a..11ea43545 100644 --- a/tests/test_server_idempotency.py +++ b/tests/test_server_idempotency.py @@ -296,6 +296,46 @@ async def test_per_principal_scope_enforced(self) -> None: assert r_a["media_buy_id"] != r_b["media_buy_id"] assert handler.call_count == 2 + @pytest.mark.asyncio + async def test_per_tenant_scope_enforced_for_shared_principal_id(self) -> None: + # Multi-tenant deployments whose principal IDs are only unique + # *within* a tenant (Okta group-scoped, SCIM per-tenant, seller- + # internal employee IDs) must not leak cached responses across + # tenants on the same (locally-unique) principal id. + store = self._make_store() + handler = _FakeHandler() + wrapped = store.wrap(_FakeHandler.create_media_buy) + key = str(uuid.uuid4()) + r_a = await wrapped( + handler, + {"idempotency_key": key, "b": 1}, + ToolContext(caller_identity="alice-42", tenant_id="tenant-acme"), + ) + r_b = await wrapped( + handler, + {"idempotency_key": key, "b": 1}, + ToolContext(caller_identity="alice-42", tenant_id="tenant-beta"), + ) + assert r_a["media_buy_id"] != r_b["media_buy_id"], ( + "Same principal_id across two tenants shared the cache slot — " + "cross-tenant response replay is possible." + ) + assert handler.call_count == 2 + + @pytest.mark.asyncio + async def test_tenant_scope_matches_on_identical_tenant_and_principal(self) -> None: + # Sanity-check the positive case: same (tenant_id, caller_identity) + # still shares the scope and replays from cache. + store = self._make_store() + handler = _FakeHandler() + wrapped = store.wrap(_FakeHandler.create_media_buy) + key = str(uuid.uuid4()) + ctx = ToolContext(caller_identity="alice-42", tenant_id="tenant-acme") + r1 = await wrapped(handler, {"idempotency_key": key, "b": 1}, ctx) + r2 = await wrapped(handler, {"idempotency_key": key, "b": 1}, ctx) + assert handler.call_count == 1 + assert r1 == r2 + @pytest.mark.asyncio async def test_no_idempotency_key_falls_through(self) -> None: # Middleware doesn't reject; server-side schema validation handles that.