Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,15 @@ dependencies = [
# rather than relying on a2a-sdk's transitive pin (which is wider).
"protobuf>=6,<8",
"sse-starlette>=2.0", # required by a2a-sdk v0.3 compat adapter
"mcp>=1.23.2",
# 1.27.0 added ``session_idle_timeout`` to ``StreamableHTTPSessionManager``
# which we pass when adopters opt into stateful streamable-http. Upper
# bound at <2.0 because ``adcp.server.serve.create_mcp_server`` reads
# FastMCP private attrs (``_mcp_server``, ``_event_store``,
# ``_retry_interval``) when pre-creating the session manager — that
# contract is not preserved across majors, and upstream signaled v2
# development on ``main`` (v1.x maintenance branch only ports critical
# fixes). Bump deliberately when v2 lands.
"mcp>=1.27.0,<2.0",
"email-validator>=2.0.0",
"cryptography>=41.0.0",
# RFC 8785 JSON Canonicalization Scheme — used by the server-side
Expand Down
144 changes: 126 additions & 18 deletions src/adcp/server/auth.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
"""Bearer-token HTTP authentication middleware for ADCP MCP servers.

`examples/mcp_with_auth_middleware.py` is the full, load-bearing
recipe for multi-tenant sellers. Four things have to be right at the
same time — a ContextVar carrier for the authenticated principal,
constant-time token compare, the AdCP/MCP discovery-method bypass, and
reset-in-finally to prevent cross-request leak. Getting any of them
wrong is a security incident. This module factors that recipe into a
middleware class + matching ``context_factory`` so sellers write four
lines of wiring instead of four pages of auth code.
recipe for multi-tenant sellers. Five things have to be right at the
same time — a ``request.state`` carrier that survives the stateful
streamable-http session boundary, a ContextVar fallback for stateless
mode and A2A, constant-time token compare, the AdCP/MCP
discovery-method bypass, and reset-in-finally to prevent cross-request
leak. Getting any of them wrong is a security incident. This module
factors that recipe into a middleware class + matching
``context_factory`` so sellers write four lines of wiring instead of
four pages of auth code.

Typical usage::

Expand Down Expand Up @@ -194,6 +196,63 @@ def __call__(self, token: str) -> Awaitable[Principal | None]: ...
)


# Well-known ``request.state`` attribute names. The middleware writes
# these alongside the ContextVars; ``auth_context_factory`` reads them
# off the request reachable via :class:`RequestMetadata.request_context`.
# The state path is the only auth-propagation channel that survives the
# stateful streamable-http session boundary — the session task is a
# separate async task from the middleware's, so the ContextVars set
# above are invisible during dispatch in stateful mode. Stateless mode
# happens to share the dispatch context, so the ContextVar path also
# works there. The factory reads request.state first and falls back to
# the ContextVar so adopters running on stateless mode don't have to
# change anything.
REQUEST_STATE_PRINCIPAL = "adcp_auth_principal"
REQUEST_STATE_TENANT = "adcp_auth_tenant"
REQUEST_STATE_PRINCIPAL_METADATA = "adcp_auth_principal_metadata"


def _set_request_state(
request: Any,
principal_identity: str | None,
tenant_id: str | None,
principal_metadata: dict[str, Any] | None,
) -> None:
"""Write the auth triple onto ``request.state``.

Defensive: silently no-ops if ``request`` lacks a ``state``
attribute (e.g., a test double). Real Starlette ``Request`` objects
always have it.
"""
state = getattr(request, "state", None)
if state is None:
return
setattr(state, REQUEST_STATE_PRINCIPAL, principal_identity)
setattr(state, REQUEST_STATE_TENANT, tenant_id)
setattr(state, REQUEST_STATE_PRINCIPAL_METADATA, principal_metadata)


def _read_request_state_auth(
request: Any,
) -> tuple[str | None, str | None, dict[str, Any] | None] | None:
"""Read the auth triple off ``request.state``, or ``None`` if not set.

``None`` means the middleware never ran for this request (e.g., the
factory was invoked outside an HTTP path) — the caller should fall
back to the ContextVars.
"""
state = getattr(request, "state", None)
if state is None:
return None
if not hasattr(state, REQUEST_STATE_PRINCIPAL):
return None
return (
getattr(state, REQUEST_STATE_PRINCIPAL, None),
getattr(state, REQUEST_STATE_TENANT, None),
getattr(state, REQUEST_STATE_PRINCIPAL_METADATA, None),
)


class BearerTokenAuthMiddleware(BaseHTTPMiddleware):
"""Starlette HTTP middleware that gates every non-discovery JSON-RPC
request on a valid bearer token.
Expand Down Expand Up @@ -274,6 +333,7 @@ async def dispatch(self, request: Request, call_next: Any) -> Any:
principal_token = current_principal.set(None)
tenant_token = current_tenant.set(None)
metadata_token = current_principal_metadata.set(None)
_set_request_state(request, None, None, None)
return await call_next(request)

raw_header = request.headers.get(self._header_name, "")
Expand Down Expand Up @@ -305,10 +365,22 @@ async def dispatch(self, request: Request, call_next: Any) -> Any:
if principal is None:
return self._unauthenticated()

principal_metadata = dict(principal.metadata) if principal.metadata else None
principal_token = current_principal.set(principal.caller_identity)
tenant_token = current_tenant.set(principal.tenant_id)
metadata_token = current_principal_metadata.set(
dict(principal.metadata) if principal.metadata else None
metadata_token = current_principal_metadata.set(principal_metadata)
# Mirror onto ``request.state`` so the dispatch-side
# ``context_factory`` can read the principal even when the
# MCP server is in stateful mode (where the session task is a
# separate async task than this middleware's task and does
# not see the ContextVar set above). ``request.state`` is the
# standard Starlette per-request scratchpad and travels with
# the request through any nested ASGI app.
_set_request_state(
request,
principal.caller_identity,
principal.tenant_id,
principal_metadata,
)
return await call_next(request)
finally:
Expand Down Expand Up @@ -384,13 +456,27 @@ async def _peek_jsonrpc(request: Request) -> tuple[str | None, str | None]:


def auth_context_factory(meta: RequestMetadata) -> ToolContext:
"""Build a :class:`~adcp.server.ToolContext` from the ContextVars
:class:`BearerTokenAuthMiddleware` populates.
"""Build a :class:`~adcp.server.ToolContext` from auth state the
:class:`BearerTokenAuthMiddleware` populated for the in-flight
request.

Pass this to :func:`~adcp.server.create_mcp_server` (or
:func:`~adcp.server.serve`) alongside the middleware so handlers
receive a typed context carrying the authenticated principal.

Resolution order:

1. ``meta.request_context.state`` — the standard Starlette
per-request scratchpad. Survives the stateful streamable-http
session-task boundary (the dispatch sub-task gets the originating
Starlette ``Request`` via the upstream MCP ``request_ctx``
contextvar). Works on both stateless and stateful streamable-http.
2. Module-level :data:`current_principal` etc. ContextVars — the
legacy carrier. Works only when the dispatch runs in the same
async task as the middleware (i.e., stateless streamable-http
and A2A). In stateful streamable-http, these read ``None``
because the session task is a separate task.

Populates ``caller_identity``, ``tenant_id``, and a ``metadata``
dict containing the transport + tool name plus anything the
:class:`Principal` provided. SDK-owned keys (``tool_name``,
Expand All @@ -415,8 +501,22 @@ def auth_context_factory(meta: RequestMetadata) -> ToolContext:
framework. Do not pass ``ctx.metadata`` wholesale to a JSON serializer
— the ``AuthInfo`` object is not JSON-serializable.
"""
principal_identity = current_principal.get()
principal_metadata = current_principal_metadata.get() or {}
principal_identity: str | None = None
tenant_id: str | None = None
principal_metadata: dict[str, Any] | None = None
if meta.request_context is not None:
triple = _read_request_state_auth(meta.request_context)
if triple is not None:
principal_identity, tenant_id, principal_metadata = triple
if principal_identity is None and tenant_id is None and principal_metadata is None:
# Either no Request was threaded (stdio MCP, A2A pre-builder
# path) or the middleware didn't write to state — fall back to
# the ContextVars. Works on stateless streamable-http and A2A
# where dispatch shares the middleware's task context.
principal_identity = current_principal.get()
tenant_id = current_tenant.get()
principal_metadata = current_principal_metadata.get()
principal_metadata = principal_metadata or {}
combined_metadata: dict[str, Any] = {
**principal_metadata,
"tool_name": meta.tool_name,
Expand All @@ -438,7 +538,7 @@ def auth_context_factory(meta: RequestMetadata) -> ToolContext:
return ToolContext(
request_id=meta.request_id,
caller_identity=principal_identity,
tenant_id=current_tenant.get(),
tenant_id=tenant_id,
metadata=combined_metadata,
)

Expand Down Expand Up @@ -824,10 +924,11 @@ async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
# and the raw Principal (for downstream code reading scope['auth']).
# Mutating the scope dict before delegating propagates state to
# nested apps without copying.
principal_metadata = dict(principal.metadata) if principal.metadata else None
scope["user"] = _A2AAuthenticatedUser(
display_name=principal.caller_identity,
tenant_id=principal.tenant_id,
principal_metadata=dict(principal.metadata) if principal.metadata else None,
principal_metadata=principal_metadata,
)
scope["auth"] = principal

Expand All @@ -838,11 +939,18 @@ async def __call__(self, scope: Any, receive: Any, send: Any) -> None:
# ``None`` default while MCP handlers see the principal — a silent
# transport-coupled divergence that breaks tenant policies that
# require principal-bound calls. See issue #590.
#
# ContextVars carry on the A2A leg because the dispatch runs in
# the same async task as this middleware (no session-task seam
# like MCP stateful streamable-http). The MCP leg's mirror onto
# ``request.state`` is what survives the stateful session-task
# boundary; A2A's dispatcher reads ContextVars directly. If A2A
# ever grows a long-lived dispatch task that decouples from the
# request task, we'll need to thread the request through
# ``RequestMetadata`` on the A2A side too.
principal_token = current_principal.set(principal.caller_identity)
tenant_token = current_tenant.set(principal.tenant_id)
metadata_token = current_principal_metadata.set(
dict(principal.metadata) if principal.metadata else None
)
metadata_token = current_principal_metadata.set(principal_metadata)
try:
await self._app(scope, receive, send)
finally:
Expand Down
Loading
Loading