From 9583ff1853301e7625dfa06facc6f315cbe96ff5 Mon Sep 17 00:00:00 2001 From: Brian O'Kelley Date: Sun, 10 May 2026 09:50:29 -0400 Subject: [PATCH] feat(server): default MCP streamable-http to stateful with idle eviction MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Stateless was the production-broken default — upstream MCP holds GET-SSE streams open with no idle eviction, causing connection accumulation under load. Stateful mode + session_idle_timeout=1800s (the knob added in mcp 1.27.0) is the production-safe shape. To make stateful safe by default, plumb the originating Starlette Request into RequestMetadata.request_context (sourced from upstream's mcp.server.lowlevel.server.request_ctx, set reliably in both dispatch paths). BearerTokenAuthMiddleware mirrors principal/tenant onto request.state in addition to ContextVars; auth_context_factory reads request.state first and falls back to ContextVars. Adopters using the bundled middleware + factory get the fix for free. Adopters with custom factories using ContextVars need to migrate to meta.request_context.state for stateful — documented in the docstrings. Adopters who genuinely need stateless (multi-replica without sticky LB on Mcp-Session-Id) opt back in via stateless_http=True; the SDK suppresses session_idle_timeout in that combination to honor the upstream constructor's stateless+timeout RuntimeError contract. Bumps mcp pin to >=1.27.0,<2.0 — the upper bound is required because create_mcp_server pre-creates StreamableHTTPSessionManager (FastMCP doesn't expose session_idle_timeout in its settings) which reads four FastMCP private attrs whose contract is not preserved across majors. Co-Authored-By: Claude Opus 4.7 (1M context) --- pyproject.toml | 10 +- src/adcp/server/auth.py | 144 +++++++++-- src/adcp/server/serve.py | 175 ++++++++++++- tests/test_mcp_middleware_composition.py | 61 +++-- tests/test_mcp_stateful_session.py | 305 +++++++++++++++++++++++ tests/test_tools_list_output_schema.py | 7 + 6 files changed, 659 insertions(+), 43 deletions(-) create mode 100644 tests/test_mcp_stateful_session.py diff --git a/pyproject.toml b/pyproject.toml index b1bdcafe4..58da8b3c0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -66,7 +66,15 @@ dependencies = [ # rather than relying on a2a-sdk's transitive pin (which is wider). "protobuf>=6,<8", "sse-starlette>=2.0", # required by a2a-sdk v0.3 compat adapter - "mcp>=1.23.2", + # 1.27.0 added ``session_idle_timeout`` to ``StreamableHTTPSessionManager`` + # which we pass when adopters opt into stateful streamable-http. Upper + # bound at <2.0 because ``adcp.server.serve.create_mcp_server`` reads + # FastMCP private attrs (``_mcp_server``, ``_event_store``, + # ``_retry_interval``) when pre-creating the session manager — that + # contract is not preserved across majors, and upstream signaled v2 + # development on ``main`` (v1.x maintenance branch only ports critical + # fixes). Bump deliberately when v2 lands. + "mcp>=1.27.0,<2.0", "email-validator>=2.0.0", "cryptography>=41.0.0", # RFC 8785 JSON Canonicalization Scheme — used by the server-side diff --git a/src/adcp/server/auth.py b/src/adcp/server/auth.py index 4b27c66c4..038f0d356 100644 --- a/src/adcp/server/auth.py +++ b/src/adcp/server/auth.py @@ -1,13 +1,15 @@ """Bearer-token HTTP authentication middleware for ADCP MCP servers. `examples/mcp_with_auth_middleware.py` is the full, load-bearing -recipe for multi-tenant sellers. Four things have to be right at the -same time — a ContextVar carrier for the authenticated principal, -constant-time token compare, the AdCP/MCP discovery-method bypass, and -reset-in-finally to prevent cross-request leak. Getting any of them -wrong is a security incident. This module factors that recipe into a -middleware class + matching ``context_factory`` so sellers write four -lines of wiring instead of four pages of auth code. +recipe for multi-tenant sellers. Five things have to be right at the +same time — a ``request.state`` carrier that survives the stateful +streamable-http session boundary, a ContextVar fallback for stateless +mode and A2A, constant-time token compare, the AdCP/MCP +discovery-method bypass, and reset-in-finally to prevent cross-request +leak. Getting any of them wrong is a security incident. This module +factors that recipe into a middleware class + matching +``context_factory`` so sellers write four lines of wiring instead of +four pages of auth code. Typical usage:: @@ -194,6 +196,63 @@ def __call__(self, token: str) -> Awaitable[Principal | None]: ... ) +# Well-known ``request.state`` attribute names. The middleware writes +# these alongside the ContextVars; ``auth_context_factory`` reads them +# off the request reachable via :class:`RequestMetadata.request_context`. +# The state path is the only auth-propagation channel that survives the +# stateful streamable-http session boundary — the session task is a +# separate async task from the middleware's, so the ContextVars set +# above are invisible during dispatch in stateful mode. Stateless mode +# happens to share the dispatch context, so the ContextVar path also +# works there. The factory reads request.state first and falls back to +# the ContextVar so adopters running on stateless mode don't have to +# change anything. +REQUEST_STATE_PRINCIPAL = "adcp_auth_principal" +REQUEST_STATE_TENANT = "adcp_auth_tenant" +REQUEST_STATE_PRINCIPAL_METADATA = "adcp_auth_principal_metadata" + + +def _set_request_state( + request: Any, + principal_identity: str | None, + tenant_id: str | None, + principal_metadata: dict[str, Any] | None, +) -> None: + """Write the auth triple onto ``request.state``. + + Defensive: silently no-ops if ``request`` lacks a ``state`` + attribute (e.g., a test double). Real Starlette ``Request`` objects + always have it. + """ + state = getattr(request, "state", None) + if state is None: + return + setattr(state, REQUEST_STATE_PRINCIPAL, principal_identity) + setattr(state, REQUEST_STATE_TENANT, tenant_id) + setattr(state, REQUEST_STATE_PRINCIPAL_METADATA, principal_metadata) + + +def _read_request_state_auth( + request: Any, +) -> tuple[str | None, str | None, dict[str, Any] | None] | None: + """Read the auth triple off ``request.state``, or ``None`` if not set. + + ``None`` means the middleware never ran for this request (e.g., the + factory was invoked outside an HTTP path) — the caller should fall + back to the ContextVars. + """ + state = getattr(request, "state", None) + if state is None: + return None + if not hasattr(state, REQUEST_STATE_PRINCIPAL): + return None + return ( + getattr(state, REQUEST_STATE_PRINCIPAL, None), + getattr(state, REQUEST_STATE_TENANT, None), + getattr(state, REQUEST_STATE_PRINCIPAL_METADATA, None), + ) + + class BearerTokenAuthMiddleware(BaseHTTPMiddleware): """Starlette HTTP middleware that gates every non-discovery JSON-RPC request on a valid bearer token. @@ -274,6 +333,7 @@ async def dispatch(self, request: Request, call_next: Any) -> Any: principal_token = current_principal.set(None) tenant_token = current_tenant.set(None) metadata_token = current_principal_metadata.set(None) + _set_request_state(request, None, None, None) return await call_next(request) raw_header = request.headers.get(self._header_name, "") @@ -305,10 +365,22 @@ async def dispatch(self, request: Request, call_next: Any) -> Any: if principal is None: return self._unauthenticated() + principal_metadata = dict(principal.metadata) if principal.metadata else None principal_token = current_principal.set(principal.caller_identity) tenant_token = current_tenant.set(principal.tenant_id) - metadata_token = current_principal_metadata.set( - dict(principal.metadata) if principal.metadata else None + metadata_token = current_principal_metadata.set(principal_metadata) + # Mirror onto ``request.state`` so the dispatch-side + # ``context_factory`` can read the principal even when the + # MCP server is in stateful mode (where the session task is a + # separate async task than this middleware's task and does + # not see the ContextVar set above). ``request.state`` is the + # standard Starlette per-request scratchpad and travels with + # the request through any nested ASGI app. + _set_request_state( + request, + principal.caller_identity, + principal.tenant_id, + principal_metadata, ) return await call_next(request) finally: @@ -384,13 +456,27 @@ async def _peek_jsonrpc(request: Request) -> tuple[str | None, str | None]: def auth_context_factory(meta: RequestMetadata) -> ToolContext: - """Build a :class:`~adcp.server.ToolContext` from the ContextVars - :class:`BearerTokenAuthMiddleware` populates. + """Build a :class:`~adcp.server.ToolContext` from auth state the + :class:`BearerTokenAuthMiddleware` populated for the in-flight + request. Pass this to :func:`~adcp.server.create_mcp_server` (or :func:`~adcp.server.serve`) alongside the middleware so handlers receive a typed context carrying the authenticated principal. + Resolution order: + + 1. ``meta.request_context.state`` — the standard Starlette + per-request scratchpad. Survives the stateful streamable-http + session-task boundary (the dispatch sub-task gets the originating + Starlette ``Request`` via the upstream MCP ``request_ctx`` + contextvar). Works on both stateless and stateful streamable-http. + 2. Module-level :data:`current_principal` etc. ContextVars — the + legacy carrier. Works only when the dispatch runs in the same + async task as the middleware (i.e., stateless streamable-http + and A2A). In stateful streamable-http, these read ``None`` + because the session task is a separate task. + Populates ``caller_identity``, ``tenant_id``, and a ``metadata`` dict containing the transport + tool name plus anything the :class:`Principal` provided. SDK-owned keys (``tool_name``, @@ -415,8 +501,22 @@ def auth_context_factory(meta: RequestMetadata) -> ToolContext: framework. Do not pass ``ctx.metadata`` wholesale to a JSON serializer — the ``AuthInfo`` object is not JSON-serializable. """ - principal_identity = current_principal.get() - principal_metadata = current_principal_metadata.get() or {} + principal_identity: str | None = None + tenant_id: str | None = None + principal_metadata: dict[str, Any] | None = None + if meta.request_context is not None: + triple = _read_request_state_auth(meta.request_context) + if triple is not None: + principal_identity, tenant_id, principal_metadata = triple + if principal_identity is None and tenant_id is None and principal_metadata is None: + # Either no Request was threaded (stdio MCP, A2A pre-builder + # path) or the middleware didn't write to state — fall back to + # the ContextVars. Works on stateless streamable-http and A2A + # where dispatch shares the middleware's task context. + principal_identity = current_principal.get() + tenant_id = current_tenant.get() + principal_metadata = current_principal_metadata.get() + principal_metadata = principal_metadata or {} combined_metadata: dict[str, Any] = { **principal_metadata, "tool_name": meta.tool_name, @@ -438,7 +538,7 @@ def auth_context_factory(meta: RequestMetadata) -> ToolContext: return ToolContext( request_id=meta.request_id, caller_identity=principal_identity, - tenant_id=current_tenant.get(), + tenant_id=tenant_id, metadata=combined_metadata, ) @@ -824,10 +924,11 @@ async def __call__(self, scope: Any, receive: Any, send: Any) -> None: # and the raw Principal (for downstream code reading scope['auth']). # Mutating the scope dict before delegating propagates state to # nested apps without copying. + principal_metadata = dict(principal.metadata) if principal.metadata else None scope["user"] = _A2AAuthenticatedUser( display_name=principal.caller_identity, tenant_id=principal.tenant_id, - principal_metadata=dict(principal.metadata) if principal.metadata else None, + principal_metadata=principal_metadata, ) scope["auth"] = principal @@ -838,11 +939,18 @@ async def __call__(self, scope: Any, receive: Any, send: Any) -> None: # ``None`` default while MCP handlers see the principal — a silent # transport-coupled divergence that breaks tenant policies that # require principal-bound calls. See issue #590. + # + # ContextVars carry on the A2A leg because the dispatch runs in + # the same async task as this middleware (no session-task seam + # like MCP stateful streamable-http). The MCP leg's mirror onto + # ``request.state`` is what survives the stateful session-task + # boundary; A2A's dispatcher reads ContextVars directly. If A2A + # ever grows a long-lived dispatch task that decouples from the + # request task, we'll need to thread the request through + # ``RequestMetadata`` on the A2A side too. principal_token = current_principal.set(principal.caller_identity) tenant_token = current_tenant.set(principal.tenant_id) - metadata_token = current_principal_metadata.set( - dict(principal.metadata) if principal.metadata else None - ) + metadata_token = current_principal_metadata.set(principal_metadata) try: await self._app(scope, receive, send) finally: diff --git a/src/adcp/server/serve.py b/src/adcp/server/serve.py index 80478191a..fd1b5e999 100644 --- a/src/adcp/server/serve.py +++ b/src/adcp/server/serve.py @@ -27,6 +27,8 @@ async def get_adcp_capabilities(self, params, context=None): logger = logging.getLogger("adcp.server") +from mcp.server.streamable_http_manager import StreamableHTTPSessionManager + from adcp.server.base import ADCPHandler, ToolContext from adcp.server.mcp_tools import ( _HANDLER_TOOLS, @@ -77,11 +79,22 @@ class RequestMetadata: :param request_id: The transport-assigned request id when one exists (A2A populates this from the task id; MCP leaves it ``None`` at the SDK layer today). + :param request_context: The originating Starlette ``Request`` for + HTTP-borne calls (MCP streamable-http, A2A). ``None`` for + stdio MCP and any path that doesn't have a Request to thread. + Use ``request_context.state`` to read per-request state set + by ASGI middleware — this works in both stateless and stateful + MCP modes, where the older :mod:`contextvars` pattern only + works in stateless (the stateful session task is a separate + async task and does not see middleware-set ContextVars). + Typed as ``Any`` to keep this dataclass dependency-light; + adopters can ``cast(Request, meta.request_context)``. """ tool_name: str transport: Literal["mcp", "a2a"] request_id: str | None = None + request_context: Any = None @dataclass(frozen=True) @@ -119,6 +132,8 @@ class ServeConfig: # --- MCP only --- instructions: str | None = None streaming_responses: bool = False + stateless_http: bool = False + session_idle_timeout: float | None = 1800.0 # --- A2A only --- task_store: TaskStore | None = None @@ -146,7 +161,13 @@ class ServeConfig: def __post_init__(self) -> None: _a2a_only = ("task_store", "push_config_store", "message_parser", "public_url") - _mcp_only = ("instructions", "streaming_responses") + # ``session_idle_timeout`` (default 1800.0) is excluded from + # the warning list: the ``not in (None, False)`` heuristic + # treats any non-falsy default as "set" and would fire + # spuriously under transport='a2a'. ``stateless_http`` (default + # False) and ``streaming_responses`` (default False) work + # cleanly with the heuristic. + _mcp_only = ("instructions", "streaming_responses", "stateless_http") if self.transport == "a2a": mcp_set = sorted(f for f in _mcp_only if getattr(self, f) not in (None, False)) if mcp_set: @@ -288,6 +309,34 @@ async def audit_middleware( """ +def _get_starlette_request_for_dispatch() -> Any: + """Return the Starlette ``Request`` for the in-flight MCP tool call, if + any — else ``None``. + + The MCP lowlevel server stashes the originating ``Request`` in a + contextvar (``mcp.server.lowlevel.server.request_ctx``) for the + duration of each dispatched request, in both stateless and stateful + modes. The contextvar lives in the dispatch sub-task that the + session task spawned (``tg.start_soon(_handle_message, ...)``), so + the value reachable here is the originating request — not the + session-creation request — even when the streamable-http transport + holds a long-lived session task. + + Returns ``None`` when called outside an MCP dispatch (e.g. from the + server-builder smoke tests, or from A2A's executor which has its + own context channel via ``ServerCallContext``). + """ + try: + from mcp.server.lowlevel.server import request_ctx + except ImportError: # pragma: no cover — mcp pin guarantees this + return None + try: + ctx = request_ctx.get() + except LookupError: + return None + return getattr(ctx, "request", None) + + def _log_advertised_tools( *, transport: Literal["mcp", "a2a"], @@ -524,6 +573,8 @@ def serve( advertise_all: bool = False, max_request_size: int | None = None, streaming_responses: bool = False, + stateless_http: bool = False, + session_idle_timeout: float | None = 1800.0, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, enable_debug_endpoints: bool = False, debug_traffic_source: Callable[[], dict[str, int]] | None = None, @@ -653,6 +704,26 @@ def serve( (MCP transports only). Note: the legacy ``transport="sse"`` is a separate (deprecated) MCP transport, unrelated to this flag. + stateless_http: When ``False`` (default), MCP keeps a per-client + session alive across requests so subsequent ``tools/call`` + posts skip the transport-construction tax — meaningfully + faster for chatty clients, and the only mode where + ``StreamableHTTPSessionManager``'s idle-reap path actually + runs. (Stateless mode in upstream MCP holds GET-SSE streams + with no idle eviction, which is why production adopters + saw connections accumulate.) The SDK threads the + originating Starlette ``Request`` into + ``RequestMetadata.request_context`` in both modes so + ``context_factory`` can read auth off ``request.state``; + the bundled :func:`auth_context_factory` already does + this. Set ``True`` for stateless deployments — multi-replica + without sticky LB on ``Mcp-Session-Id``, or where you + cannot configure session affinity. + session_idle_timeout: Idle reap deadline (seconds) for stateful + sessions. Each request pushes the deadline forward; idle + sessions are terminated and their per-session state freed. + Defaults to 1800 (30 min); ``None`` disables reaping. + Ignored when ``stateless_http=True``. enable_debug_endpoints: When ``True``, mount ``GET /_debug/traffic`` on the outer HTTP app. Returns the JSON dict from ``debug_traffic_source()`` — typically wired to the @@ -771,6 +842,8 @@ async def force_account_status(self, account_id, status): advertise_all = config.advertise_all max_request_size = config.max_request_size streaming_responses = config.streaming_responses + stateless_http = config.stateless_http + session_idle_timeout = config.session_idle_timeout validation = config.validation enable_debug_endpoints = config.enable_debug_endpoints debug_traffic_source = config.debug_traffic_source @@ -837,6 +910,8 @@ async def force_account_status(self, account_id, status): advertise_all=advertise_all, max_request_size=max_request_size, streaming_responses=streaming_responses, + stateless_http=stateless_http, + session_idle_timeout=session_idle_timeout, validation=validation, base_url=base_url, specialisms=specialisms, @@ -864,6 +939,8 @@ async def force_account_status(self, account_id, status): advertise_all=advertise_all, max_request_size=max_request_size, streaming_responses=streaming_responses, + stateless_http=stateless_http, + session_idle_timeout=session_idle_timeout, validation=validation, base_url=base_url, specialisms=specialisms, @@ -1238,6 +1315,8 @@ def _serve_mcp( advertise_all: bool = False, max_request_size: int | None = None, streaming_responses: bool = False, + stateless_http: bool = False, + session_idle_timeout: float | None = 1800.0, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, base_url: str | None = None, specialisms: list[str] | None = None, @@ -1259,6 +1338,8 @@ def _serve_mcp( middleware=middleware, advertise_all=advertise_all, streaming_responses=streaming_responses, + stateless_http=stateless_http, + session_idle_timeout=session_idle_timeout, validation=validation, allowed_hosts=allowed_hosts, allowed_origins=allowed_origins, @@ -1480,6 +1561,8 @@ def _build_mcp_and_a2a_app( advertise_all: bool = False, max_request_size: int | None = None, streaming_responses: bool = False, + stateless_http: bool = False, + session_idle_timeout: float | None = 1800.0, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, base_url: str | None = None, specialisms: list[str] | None = None, @@ -1522,6 +1605,8 @@ def _build_mcp_and_a2a_app( middleware=middleware, advertise_all=advertise_all, streaming_responses=streaming_responses, + stateless_http=stateless_http, + session_idle_timeout=session_idle_timeout, validation=validation, allowed_hosts=allowed_hosts, allowed_origins=allowed_origins, @@ -1658,6 +1743,8 @@ def _serve_mcp_and_a2a( advertise_all: bool = False, max_request_size: int | None = None, streaming_responses: bool = False, + stateless_http: bool = False, + session_idle_timeout: float | None = 1800.0, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, base_url: str | None = None, specialisms: list[str] | None = None, @@ -1705,6 +1792,8 @@ def _serve_mcp_and_a2a( advertise_all=advertise_all, max_request_size=max_request_size, streaming_responses=streaming_responses, + stateless_http=stateless_http, + session_idle_timeout=session_idle_timeout, validation=validation, base_url=base_url, specialisms=specialisms, @@ -1786,6 +1875,8 @@ def create_mcp_server( middleware: Sequence[SkillMiddleware] | None = None, advertise_all: bool = False, streaming_responses: bool = False, + stateless_http: bool = False, + session_idle_timeout: float | None = 1800.0, validation: ValidationHookConfig | None = DEFAULT_VALIDATION, allowed_hosts: Sequence[str] | None = None, allowed_origins: Sequence[str] | None = None, @@ -1844,6 +1935,35 @@ def create_mcp_server( without completing, blocking the storyboard runner. Set to ``True`` only if your tools genuinely emit progress notifications and your clients consume the SSE stream. + stateless_http: When ``False`` (default), MCP keeps a + per-client session task alive across requests so subsequent + ``tools/call`` posts skip the per-request transport- + construction tax — meaningfully faster for chatty clients, + and the only mode where ``StreamableHTTPSessionManager``'s + idle-reap path actually runs. (Stateless mode in upstream + MCP holds GET-SSE streams open with no idle eviction — + connections accumulate.) The SDK threads the originating + Starlette ``Request`` into + ``RequestMetadata.request_context``; the bundled + :func:`~adcp.server.auth_context_factory` reads auth off + ``request.state`` and works in both stateless and stateful. + Custom factories using :mod:`contextvars` set in ASGI + middleware should migrate — those vars do NOT propagate + from the HTTP request task to the stateful session's + dispatch task. Multi-replica stateful deployments need + sticky load balancing on ``Mcp-Session-Id``; set + ``stateless_http=True`` only when affinity isn't possible. + Do not memoize per-call state on ``mcp.Context`` or + session-manager-scoped objects in stateful mode — that + smears identity across calls. + session_idle_timeout: Idle reap deadline (seconds) for stateful + sessions. Each request pushes the deadline forward; idle + sessions are terminated and their per-session state freed. + Defaults to 1800 (30 minutes); set to ``None`` to disable + reaping. Ignored when ``stateless_http=True``. Required + because without it + ``StreamableHTTPSessionManager._server_instances`` grows + without bound for clients that disconnect without DELETE. Returns: A configured FastMCP server instance. Call ``mcp.run()`` to start, @@ -1900,11 +2020,12 @@ def create_mcp_server( mcp = FastMCP(name, instructions=instructions, port=resolved_port) mcp.settings.host = resolved_host if not streaming_responses: - # FastMCP's SSE-internal default has an upstream bug; switching to - # stateless JSON-response mode is also semantically correct for - # AdCP tools, which return one complete envelope per request. - mcp.settings.stateless_http = True + # FastMCP's SSE-internal streaming default has an upstream bug + # that drops the ASGI response without completing; AdCP tools + # return one complete envelope per request anyway, so JSON + # response mode is both safer and semantically correct. mcp.settings.json_response = True + mcp.settings.stateless_http = stateless_http # FastMCP's TransportSecurityMiddleware enforces DNS-rebinding # protection: the default ``allowed_hosts`` accepts only loopback # patterns (``127.0.0.1:*``, ``localhost:*``, ``[::1]:*``). Adopters @@ -1949,6 +2070,34 @@ def create_mcp_server( advertise_all=advertise_all, validation=validation, ) + # Pre-create the StreamableHTTPSessionManager so we can pass + # ``session_idle_timeout`` — FastMCP's settings don't expose it as of + # mcp 1.27.x. ``streamable_http_app()`` lazy-creates the manager only + # if ``_session_manager`` is ``None``, so populating it here is the + # extension point. Reaches into FastMCP private attrs ``_mcp_server``, + # ``_event_store``, ``_retry_interval`` to mirror upstream's own + # constructor call — guarded by the ``mcp<2.0`` pin since v2 may + # rename these. + if session_idle_timeout is not None and session_idle_timeout <= 0: + raise ValueError( + f"session_idle_timeout must be positive (got {session_idle_timeout!r}); " + "set None to disable reaping." + ) + # Suppress the timeout in stateless mode — upstream raises + # ``RuntimeError`` if both are set. Silent because ``stateless_http=True, + # session_idle_timeout=1800.0`` is the default combination and would + # warn on every server boot otherwise. Adopters who explicitly want a + # timeout should set ``stateless_http=False``. + idle_timeout = None if mcp.settings.stateless_http else session_idle_timeout + mcp._session_manager = StreamableHTTPSessionManager( + app=mcp._mcp_server, + event_store=mcp._event_store, + retry_interval=mcp._retry_interval, + json_response=mcp.settings.json_response, + stateless=mcp.settings.stateless_http, + security_settings=mcp.settings.transport_security, + session_idle_timeout=idle_timeout, + ) return mcp @@ -2040,13 +2189,23 @@ async def fn(**kwargs: Any) -> dict[str, Any]: # at the SDK level (``Context.client_id`` is a session hint, not an # authenticated user). Sellers wire auth via HTTP middleware on # ``mcp.streamable_http_app()`` and pass ``context_factory`` to - # ``create_mcp_server()`` — the factory reads a ``contextvars.ContextVar`` - # the middleware populates and returns a typed ``ToolContext``. + # ``create_mcp_server()``. ``RequestMetadata.request_context`` carries + # the originating Starlette ``Request`` so the factory can read + # ``request.state.*`` set by middleware — this works in both + # stateless and stateful streamable-http modes, where the older + # ``contextvars.ContextVar`` pattern only works in stateless (the + # stateful session task is a separate async task than the HTTP + # request task and does not see middleware-set ContextVars). # The A2A transport derives ``caller_identity`` from # ``ServerCallContext.user`` automatically. context: ToolContext | None = None if context_factory is not None: - meta = RequestMetadata(tool_name=name, transport="mcp") + request_context = _get_starlette_request_for_dispatch() + meta = RequestMetadata( + tool_name=name, + transport="mcp", + request_context=request_context, + ) context = context_factory(meta) if not isinstance(context, ToolContext): # Catch downstream factories that return a dict or other diff --git a/tests/test_mcp_middleware_composition.py b/tests/test_mcp_middleware_composition.py index 31fa3315c..09fd645bf 100644 --- a/tests/test_mcp_middleware_composition.py +++ b/tests/test_mcp_middleware_composition.py @@ -48,6 +48,16 @@ _current_principal: ContextVar[str | None] = ContextVar("test_current_principal", default=None) _current_tenant: ContextVar[str | None] = ContextVar("test_current_tenant", default=None) +# Per-request state attribute names. With the default stateful +# streamable-http transport, the dispatch task is a different async +# task than the middleware's, so ``ContextVar`` propagation breaks. +# ``request.state`` survives the boundary because the dispatch reads +# the originating Starlette ``Request`` from +# ``mcp.server.lowlevel.server.request_ctx`` and we plumb it through +# ``RequestMetadata.request_context``. +_REQUEST_STATE_PRINCIPAL = "test_principal" +_REQUEST_STATE_TENANT = "test_tenant" + class _RecordingHandler(ADCPHandler): """Handler that records the ToolContext each call received.""" @@ -93,11 +103,18 @@ async def dispatch(self, request: Request, call_next: Any) -> Any: if token not in self.VALID_TOKENS: return JSONResponse({"error": "unauthenticated"}, status_code=401) principal, tenant = self.VALID_TOKENS[token] - _principal_token = _current_principal.set(principal) - _tenant_token = _current_tenant.set(tenant) else: - _principal_token = _current_principal.set(None) - _tenant_token = _current_tenant.set(None) + principal = None + tenant = None + + # Write to BOTH ``request.state`` (survives the stateful + # streamable-http session-task boundary) and the legacy + # ContextVars (read by adopters who haven't migrated). The + # dispatch-side ``_build_context`` prefers ``request.state``. + setattr(request.state, _REQUEST_STATE_PRINCIPAL, principal) + setattr(request.state, _REQUEST_STATE_TENANT, tenant) + _principal_token = _current_principal.set(principal) + _tenant_token = _current_tenant.set(tenant) try: return await call_next(request) @@ -134,10 +151,20 @@ async def _peek_jsonrpc(request: Request) -> tuple[str | None, str | None]: def _build_context(meta: RequestMetadata) -> ToolContext: + """Read auth state off the request the SDK threaded into + ``meta.request_context``. This is the pattern adopters should use + in stateful streamable-http mode (the default). The + :mod:`contextvars`-based pattern only works when stateless mode is + explicitly opted in.""" + principal = None + tenant = None + if meta.request_context is not None: + principal = getattr(meta.request_context.state, _REQUEST_STATE_PRINCIPAL, None) + tenant = getattr(meta.request_context.state, _REQUEST_STATE_TENANT, None) return ToolContext( request_id=meta.request_id, - caller_identity=_current_principal.get(), - tenant_id=_current_tenant.get(), + caller_identity=principal, + tenant_id=tenant, metadata={"tool_name": meta.tool_name, "transport": meta.transport}, ) @@ -155,14 +182,10 @@ async def handler_and_client() -> Any: # so a non-spec-conformant stub response doesn't get rewritten # into a VALIDATION_ERROR before the assertion runs. validation=None, + # Allow in-process test host — MCP's DNS-rebinding protection + # rejects unknown Host headers by default when enabled. + allowed_hosts=["localhost", "127.0.0.1"], ) - # Force stateless JSON responses. Production deployments mount the - # MCP app behind a reverse proxy; this test covers that shape. - mcp.settings.stateless_http = True - mcp.settings.json_response = True - # Allow in-process test host — MCP's DNS-rebinding protection - # rejects unknown Host headers by default when enabled. - mcp.settings.transport_security.allowed_hosts = ["localhost", "127.0.0.1"] app = mcp.streamable_http_app() app.add_middleware(_AuthMiddleware) @@ -336,8 +359,10 @@ def test_validate_discovery_set_rejects_unknown_tool() -> None: async def _initialize_session( client: httpx.AsyncClient, *, headers: dict[str, str] | None = None ) -> httpx.Response: - """Send an MCP ``initialize`` JSON-RPC call — FastMCP requires this - before ``tools/call`` even in stateless mode.""" + """Send an MCP ``initialize`` JSON-RPC call. Required before any + ``tools/call`` — and in stateful streamable-http, the response's + ``Mcp-Session-Id`` header must be echoed on every subsequent request + targeting the same session.""" request_headers = { "content-type": "application/json", "accept": "application/json, text/event-stream", @@ -354,7 +379,11 @@ async def _initialize_session( "clientInfo": {"name": "test-client", "version": "1.0"}, }, } - return await client.post("/mcp/", json=body, headers=request_headers) + response = await client.post("/mcp/", json=body, headers=request_headers) + session_id = response.headers.get("mcp-session-id") + if session_id is not None: + client.headers["mcp-session-id"] = session_id + return response async def _call_tool( diff --git a/tests/test_mcp_stateful_session.py b/tests/test_mcp_stateful_session.py new file mode 100644 index 000000000..dd22e86fa --- /dev/null +++ b/tests/test_mcp_stateful_session.py @@ -0,0 +1,305 @@ +"""Coverage for stateful streamable-http (the default). + +The default is ``stateless_http=False`` — the only mode where +``StreamableHTTPSessionManager``'s idle-reap path runs. Stateless +mode in upstream MCP holds GET-SSE streams open without idle +eviction; production adopters saw connections accumulate. Adopters +who can't run stateful (multi-replica without sticky LB on +``Mcp-Session-Id``) opt back into stateless via +``stateless_http=True``; ``session_idle_timeout`` (default 1800s) +caps idle stateful sessions. + +These tests exercise: +1. Default builds a stateful session manager with + ``session_idle_timeout=1800``. +2. End-to-end auth propagation: the SDK's built-in + ``BearerTokenAuthMiddleware`` + ``auth_context_factory`` + surface the principal/tenant on a real ``tools/call`` POST + through the real session manager. +3. The session-id contract: a stateful session reuses across calls + and rejects requests without ``Mcp-Session-Id``. +4. The upstream constraint that ``session_idle_timeout`` cannot + combine with ``stateless=True`` — we suppress the timeout + automatically when adopters opt back into stateless rather than + letting the upstream constructor raise. +""" + +from __future__ import annotations + +from typing import Any + +import httpx +import pytest +from asgi_lifespan import LifespanManager + +from adcp.server import ADCPHandler, create_mcp_server + + +class _BareHandler(ADCPHandler[Any]): + """Minimal handler — exercises only server construction.""" + + +def test_default_is_stateful() -> None: + """Default flipped from stateless to stateful in v4.7. Stateless + mode in upstream MCP holds GET-SSE streams open with no idle + eviction, leading to the connection-leak adopters reported (~100 + stuck connections across hours). Stateful with + ``session_idle_timeout=1800`` reaps abandoned sessions properly.""" + mcp = create_mcp_server(_BareHandler(), name="t", advertise_all=True) + assert mcp.settings.stateless_http is False + assert mcp.settings.json_response is True + assert mcp._session_manager.session_idle_timeout == 1800.0 + assert mcp._session_manager.stateless is False + + +def test_stateless_opt_in_drops_idle_timeout() -> None: + """Adopters who explicitly opt into stateless (multi-replica + without affinity, no shared session store) need the upstream + constructor to accept ``stateless=True`` — which forbids + ``session_idle_timeout``. Verify we suppress before construction.""" + mcp = create_mcp_server(_BareHandler(), name="t", advertise_all=True, stateless_http=True) + assert mcp.settings.stateless_http is True + assert mcp._session_manager.session_idle_timeout is None + assert mcp._session_manager.stateless is True + + +def test_stateful_opt_in_explicit_timeout() -> None: + mcp = create_mcp_server( + _BareHandler(), + name="t", + advertise_all=True, + stateless_http=False, + session_idle_timeout=600.0, + ) + assert mcp._session_manager.session_idle_timeout == 600.0 + + +def test_stateful_with_disabled_timeout() -> None: + mcp = create_mcp_server( + _BareHandler(), + name="t", + advertise_all=True, + stateless_http=False, + session_idle_timeout=None, + ) + # Adopter explicitly opted out of reaping — pass through. + assert mcp._session_manager.session_idle_timeout is None + assert mcp._session_manager.stateless is False + + +def test_stateless_suppresses_caller_supplied_timeout() -> None: + """Upstream raises if ``stateless=True`` AND + ``session_idle_timeout`` is set. We suppress before construction so + the default-arg combo doesn't blow up adopters who never touched + these knobs.""" + mcp = create_mcp_server( + _BareHandler(), + name="t", + advertise_all=True, + stateless_http=True, + session_idle_timeout=600.0, + ) + assert mcp._session_manager.session_idle_timeout is None + + +def test_negative_idle_timeout_rejected_at_boundary() -> None: + """Upstream raises ``ValueError`` for ``session_idle_timeout <= 0``. + We catch it at the SDK boundary so adopters see a message that + mentions the framework parameter name, not an upstream stack trace.""" + with pytest.raises(ValueError, match="session_idle_timeout must be positive"): + create_mcp_server( + _BareHandler(), + name="t", + advertise_all=True, + stateless_http=False, + session_idle_timeout=0, + ) + + +def test_streaming_responses_keeps_json_response_off() -> None: + """Adopters who genuinely emit progress events flip + ``streaming_responses=True``; that path must NOT also force + ``json_response=True`` (which would defeat the point).""" + mcp = create_mcp_server( + _BareHandler(), + name="t", + advertise_all=True, + streaming_responses=True, + ) + # FastMCP's default is False; staying False means the SSE-stream + # code path remains active for tools that need progress. + assert mcp.settings.json_response is False + + +@pytest.mark.asyncio +async def test_stateful_session_reuses_across_calls() -> None: + """End-to-end sanity: the same ``Mcp-Session-Id`` from + ``initialize`` is accepted on a follow-up ``tools/list``. In + stateless mode the second request would fail with "Missing session + ID" (as today's middleware tests show); proving stateful mode + accepts the bound session id is the regression guard.""" + mcp = create_mcp_server( + _BareHandler(), + name="t", + advertise_all=True, + stateless_http=False, + allowed_hosts=["localhost", "127.0.0.1"], + ) + app = mcp.streamable_http_app() + headers = { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + } + + async with LifespanManager(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://localhost", + follow_redirects=True, + ) as client: + init_resp = await client.post( + "/mcp/", + json={ + "jsonrpc": "2.0", + "id": 0, + "method": "initialize", + "params": { + "protocolVersion": "2025-06-18", + "capabilities": {}, + "clientInfo": {"name": "t", "version": "1"}, + }, + }, + headers=headers, + ) + assert init_resp.status_code == 200, init_resp.text + session_id = init_resp.headers.get("mcp-session-id") + assert session_id, "stateful mode must return Mcp-Session-Id" + + list_resp = await client.post( + "/mcp/", + json={"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}}, + headers={**headers, "mcp-session-id": session_id}, + ) + assert list_resp.status_code == 200, list_resp.text + + +@pytest.mark.asyncio +async def test_stateful_auth_propagates_via_request_state() -> None: + """The headline guarantee for the default flip: in stateful mode + (the default), middleware-set state on the Starlette ``Request`` + reaches ``context_factory`` via ``meta.request_context.state``. + Without this, the contextvar-only auth pattern would silently fail + in production because the session task is a different async task + than the middleware's. This test wires the SDK's built-in + ``BearerTokenAuthMiddleware`` + ``auth_context_factory`` end-to-end + and asserts the principal/tenant arrive at the handler.""" + from adcp.server import ( + BearerTokenAuthMiddleware, + Principal, + ToolContext, + auth_context_factory, + create_mcp_server, + validator_from_token_map, + ) + + received: dict[str, Any] = {} + + class _Recording(_BareHandler): + async def get_products( + self, params: Any, context: ToolContext | None = None + ) -> dict[str, Any]: + received["caller_identity"] = context.caller_identity if context is not None else None + received["tenant_id"] = context.tenant_id if context is not None else None + return {"products": []} + + mcp = create_mcp_server( + _Recording(), + name="t", + advertise_all=True, + context_factory=auth_context_factory, + allowed_hosts=["localhost", "127.0.0.1"], + validation=None, + ) + app = mcp.streamable_http_app() + app.add_middleware( + BearerTokenAuthMiddleware, + validate_token=validator_from_token_map( + {"tk-acme": Principal(caller_identity="p-acme", tenant_id="t-acme")} + ), + ) + + headers = { + "content-type": "application/json", + "accept": "application/json, text/event-stream", + "authorization": "Bearer tk-acme", + } + + async with LifespanManager(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://localhost", + follow_redirects=True, + ) as client: + init = await client.post( + "/mcp/", + json={ + "jsonrpc": "2.0", + "id": 0, + "method": "initialize", + "params": { + "protocolVersion": "2025-06-18", + "capabilities": {}, + "clientInfo": {"name": "t", "version": "1"}, + }, + }, + headers=headers, + ) + assert init.status_code == 200, init.text + session_id = init.headers["mcp-session-id"] + + call = await client.post( + "/mcp/", + json={ + "jsonrpc": "2.0", + "id": 1, + "method": "tools/call", + "params": {"name": "get_products", "arguments": {}}, + }, + headers={**headers, "mcp-session-id": session_id}, + ) + assert call.status_code == 200, call.text + + assert received["caller_identity"] == "p-acme" + assert received["tenant_id"] == "t-acme" + + +@pytest.mark.asyncio +async def test_stateful_rejects_request_without_session_id() -> None: + """Inverse of the above — without ``Mcp-Session-Id`` the upstream + SDK returns 400 ``Missing session ID``. Locks the contract that + adopters who flip stateful mode know to thread the session id.""" + mcp = create_mcp_server( + _BareHandler(), + name="t", + advertise_all=True, + stateless_http=False, + allowed_hosts=["localhost", "127.0.0.1"], + ) + app = mcp.streamable_http_app() + + async with LifespanManager(app): + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://localhost", + follow_redirects=True, + ) as client: + resp = await client.post( + "/mcp/", + json={"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}}, + headers={ + "content-type": "application/json", + "accept": "application/json, text/event-stream", + }, + ) + assert resp.status_code == 400 + assert "session" in resp.text.lower() diff --git a/tests/test_tools_list_output_schema.py b/tests/test_tools_list_output_schema.py index 676f0c822..dae3ad86f 100644 --- a/tests/test_tools_list_output_schema.py +++ b/tests/test_tools_list_output_schema.py @@ -125,6 +125,13 @@ async def _initialize_session(client: httpx.AsyncClient) -> None: } resp = await client.post("/mcp/", json=body, headers=headers) assert resp.status_code == 200, resp.text + # Stateful streamable-http binds subsequent requests to the + # ``Mcp-Session-Id`` returned by ``initialize``. Persist it on the + # client's default headers so ``tools/list`` and ``tools/call`` from + # tests target the same session. + session_id = resp.headers.get("mcp-session-id") + if session_id is not None: + client.headers["mcp-session-id"] = session_id async def _list_tools(client: httpx.AsyncClient) -> dict[str, Any]: