diff --git a/src/adcp/signing/__init__.py b/src/adcp/signing/__init__.py index da1effaa9..6510defd8 100644 --- a/src/adcp/signing/__init__.py +++ b/src/adcp/signing/__init__.py @@ -63,10 +63,15 @@ SignatureVerificationError, ) from adcp.signing.jwks import ( + AsyncCachingJwksResolver, + AsyncJwksFetcher, + AsyncJwksResolver, CachingJwksResolver, JwksResolver, SSRFValidationError, StaticJwksResolver, + as_async_resolver, + async_default_jwks_fetcher, default_jwks_fetcher, validate_jwks_uri, ) @@ -75,6 +80,9 @@ JwsMalformedError, JwsSignatureInvalidError, JwsUnknownKeyError, + averify_detached_jws, + averify_jws_document, + verify_detached_jws, verify_jws_document, ) from adcp.signing.middleware import ( @@ -87,12 +95,15 @@ from adcp.signing.revocation_fetcher import ( DEFAULT_GRACE_MULTIPLIER, REVOCATION_LIST_TYP, + AsyncCachingRevocationChecker, + AsyncRevocationListFetcher, CachingRevocationChecker, FetchResult, RevocationListFetcher, RevocationListFetchError, RevocationListFreshnessError, RevocationListParseError, + async_default_revocation_list_fetcher, default_revocation_list_fetcher, ) from adcp.signing.signer import ( @@ -110,6 +121,11 @@ "ALG_ED25519", "ALG_ES256", "ALLOWED_ALGS", + "AsyncCachingJwksResolver", + "AsyncCachingRevocationChecker", + "AsyncJwksFetcher", + "AsyncJwksResolver", + "AsyncRevocationListFetcher", "CachingJwksResolver", "CachingRevocationChecker", "DEFAULT_EXPIRES_IN_SECONDS", @@ -163,6 +179,11 @@ "VerifierCapability", "VerifyOptions", "alg_for_jwk", + "as_async_resolver", + "async_default_jwks_fetcher", + "async_default_revocation_list_fetcher", + "averify_detached_jws", + "averify_jws_document", "b64url_decode", "b64url_encode", "build_signature_base", @@ -182,6 +203,7 @@ "sign_signature_base", "unauthorized_response_headers", "validate_jwks_uri", + "verify_detached_jws", "verify_flask_request", "verify_jws_document", "verify_request_signature", diff --git a/src/adcp/signing/jwks.py b/src/adcp/signing/jwks.py index 29113e09b..168473f4c 100644 --- a/src/adcp/signing/jwks.py +++ b/src/adcp/signing/jwks.py @@ -12,10 +12,19 @@ tracked in #190 and not implemented here — the current design is vulnerable to a TOCTOU where DNS resolves to an allowed IP during validation and a blocked IP at connect time. + +Naming conventions +------------------ +* Classes use the ``Async`` CapWords prefix (``AsyncCachingJwksResolver``). +* Free functions use the ``async_`` snake_case prefix + (``async_default_jwks_fetcher``). +* Methods use the ``a`` prefix (``aclose``, ``aprime``) — matches the + ``httpx`` / ``anyio`` ecosystem. """ from __future__ import annotations +import asyncio import ipaddress import socket import time @@ -60,18 +69,43 @@ class JwksFetcher(Protocol): def __call__(self, uri: str, *, allow_private: bool = False) -> dict[str, Any]: ... +class AsyncJwksFetcher(Protocol): + """Async variant of :class:`JwksFetcher`.""" + + async def __call__( + self, uri: str, *, allow_private: bool = False + ) -> dict[str, Any]: ... + + class JwksResolver(Protocol): """Resolves a keyid to a JWK, or returns None if unknown. - The canonical Protocol used by both the RFC 9421 verifier and the JWS - document verifier. Implementations include + The canonical Protocol used by the sync RFC 9421 verifier and the + sync JWS document verifier. Implementations include :class:`StaticJwksResolver` (in-memory, for tests) and :class:`CachingJwksResolver` (fetches + caches from a URI). + + Async callers use :class:`AsyncJwksResolver` instead. """ def __call__(self, keyid: str) -> dict[str, Any] | None: ... +class AsyncJwksResolver(Protocol): + """Async variant of :class:`JwksResolver`. + + Used by the async JWS document verifier and the async revocation + checker so JWKS cache-misses don't block the event loop. + Implementations: :class:`AsyncCachingJwksResolver`. For tests, + :class:`StaticJwksResolver` doubles as an async resolver if you wrap + it in a thin async callable — there's no async work, just a dict + lookup — but typically you'll just use the static one directly where + an :class:`AsyncJwksResolver` is expected via :func:`as_async`. + """ + + async def __call__(self, keyid: str) -> dict[str, Any] | None: ... + + def validate_jwks_uri(uri: str, *, allow_private: bool = False) -> None: """Raise SSRFValidationError if `uri` resolves to a blocked IP or has a bad scheme.""" parts = urlsplit(uri) @@ -200,8 +234,131 @@ def __call__(self, keyid: str) -> dict[str, Any] | None: return self._keys.get(keyid) +# --------------------------------------------------------------------------- +# Async variants +# --------------------------------------------------------------------------- + + +async def async_default_jwks_fetcher( + uri: str, *, allow_private: bool = False +) -> dict[str, Any]: + """Async counterpart to :func:`default_jwks_fetcher`. + + Uses :class:`httpx.AsyncClient` so callers on an asyncio event loop + don't block the loop on JWKS fetches. Same SSRF + follow-redirects + rules as the sync version. + """ + validate_jwks_uri(uri, allow_private=allow_private) + async with httpx.AsyncClient( + timeout=DEFAULT_JWKS_TIMEOUT_SECONDS, follow_redirects=False + ) as client: + response = await client.get(uri, headers={"Accept": "application/json"}) + response.raise_for_status() + body = response.json() + if not isinstance(body, dict) or "keys" not in body: + raise ValueError(f"JWKS document at {uri!r} has no 'keys' array") + return body + + +class AsyncCachingJwksResolver: + """Async JWKS resolver with per-URI cache and refetch cooldown. + + Identical semantics to :class:`CachingJwksResolver` — per-``kid`` + cache, cooldown-gated refresh on miss, SSRF errors surface as + ``request_signature_jwks_untrusted``, network errors as + ``request_signature_jwks_unavailable`` — but awaitable and backed by + :class:`httpx.AsyncClient`. + + Concurrency: a single :class:`asyncio.Lock` serializes refreshes so + N parallel verifying tasks all driving the first miss don't fire N + fetches. + """ + + def __init__( + self, + jwks_uri: str, + *, + fetcher: AsyncJwksFetcher | None = None, + cooldown_seconds: float = DEFAULT_JWKS_COOLDOWN_SECONDS, + allow_private: bool = False, + clock: Callable[[], float] = time.monotonic, + ) -> None: + self._jwks_uri = jwks_uri + self._fetcher = fetcher or async_default_jwks_fetcher + self._cooldown = cooldown_seconds + self._allow_private = allow_private + self._clock = clock + self._cache: dict[str, dict[str, Any]] = {} + self._last_attempt: float | None = None + self._primed = False + # Construct the lock eagerly. Lazy init was racy: two tasks both + # seeing ``self._lock is None`` would each construct a separate + # Lock and proceed in parallel without serialization. In Python + # 3.10+ ``asyncio.Lock()`` no longer requires a running loop at + # construction time — it binds to whatever loop is running the + # first ``async with`` call. Instances are per-loop: don't share + # across ``asyncio.run`` boundaries. + self._lock: asyncio.Lock = asyncio.Lock() + + async def __call__(self, keyid: str) -> dict[str, Any] | None: + if keyid in self._cache: + return self._cache[keyid] + now = self._clock() + if not self._primed or ( + self._last_attempt is not None and now - self._last_attempt >= self._cooldown + ): + async with self._lock: + # Re-check after acquiring: another task may have refreshed. + if keyid in self._cache: + return self._cache[keyid] + now = self._clock() + if not self._primed or ( + self._last_attempt is not None + and now - self._last_attempt >= self._cooldown + ): + await self._refresh(now) + return self._cache.get(keyid) + + async def _refresh(self, now: float) -> None: + self._last_attempt = now + try: + jwks = await self._fetcher(self._jwks_uri, allow_private=self._allow_private) + except SSRFValidationError as exc: + raise SignatureVerificationError( + REQUEST_SIGNATURE_JWKS_UNTRUSTED, + step=7, + message=f"JWKS URI failed SSRF check: {exc}", + ) from exc + except (httpx.HTTPError, ValueError, OSError) as exc: + raise SignatureVerificationError( + REQUEST_SIGNATURE_JWKS_UNAVAILABLE, + step=7, + message=f"JWKS fetch failed: {exc}", + ) from exc + self._primed = True + self._cache = {jwk["kid"]: jwk for jwk in jwks.get("keys", []) if "kid" in jwk} + + +def as_async_resolver(resolver: JwksResolver) -> AsyncJwksResolver: + """Wrap a sync :class:`JwksResolver` so it satisfies :class:`AsyncJwksResolver`. + + Useful for tests: pass a :class:`StaticJwksResolver` through + :func:`as_async_resolver` to plug it into an async-verifier + pipeline that types ``AsyncJwksResolver``. There's no real async + work (just a dict lookup); the wrapper is a shape adapter. + """ + + async def resolve(keyid: str) -> dict[str, Any] | None: + return resolver(keyid) + + return resolve + + __all__ = [ "BLOCKED_METADATA_IPS", + "AsyncCachingJwksResolver", + "AsyncJwksFetcher", + "AsyncJwksResolver", "CachingJwksResolver", "DEFAULT_JWKS_COOLDOWN_SECONDS", "DEFAULT_JWKS_TIMEOUT_SECONDS", @@ -209,6 +366,8 @@ def __call__(self, keyid: str) -> dict[str, Any] | None: "JwksResolver", "SSRFValidationError", "StaticJwksResolver", + "as_async_resolver", + "async_default_jwks_fetcher", "default_jwks_fetcher", "validate_jwks_uri", ] diff --git a/src/adcp/signing/jws.py b/src/adcp/signing/jws.py index d50ea5de6..8c4501111 100644 --- a/src/adcp/signing/jws.py +++ b/src/adcp/signing/jws.py @@ -33,7 +33,7 @@ public_key_from_jwk, verify_signature, ) -from adcp.signing.jwks import JwksResolver +from adcp.signing.jwks import AsyncJwksResolver, JwksResolver # JWS uses RFC 7518 algorithm names; RFC 9421 uses the IANA HTTP Signature # names. We convert at the JWS boundary so the rest of the module speaks @@ -138,32 +138,15 @@ def parse_general_json_jws(doc: dict[str, Any]) -> tuple[str, str, bytes]: return b64_header, b64_payload, signature -def verify_detached_jws( - *, +def _validate_header( b64_protected: str, - b64_payload: str, - signature: bytes, - jwks_resolver: JwksResolver, + *, expected_typ: str, - allowed_algs: frozenset[str] = ALLOWED_JWS_ALGS, -) -> dict[str, Any]: - """Verify a parsed JWS and return the decoded payload JSON. - - Performs, in order: - 1. Decode + parse the protected header. - 2. Reject if ``alg`` is absent, is ``none``, or not in ``allowed_algs``. - 3. Reject if ``typ`` is not exactly ``expected_typ`` (byte equality, no - normalization — matches the AdCP spec). - 4. Resolve ``kid`` via ``jwks_resolver``; reject unknown. - 5. Reconstruct the signature base - (``b64_protected + "." + b64_payload``) using the ORIGINAL - base64url strings and verify the signature bytes with the existing - HTTP-signature crypto. Using the original strings (not decode + - re-encode) defends against lenient-decode mismatches. - 6. Decode the payload as JSON and return it. + allowed_algs: frozenset[str], +) -> tuple[str, str]: + """Parse + validate the protected header; return ``(internal_alg, kid)``. - Any failure raises a :class:`JwsError` subclass. The caller maps these - to transport-error codes (e.g. ``request_signature_revocation_stale``). + Shared by sync and async verify paths — everything here is pure CPU. """ header = _decode_protected_header(b64_protected) @@ -180,30 +163,32 @@ def verify_detached_jws( f"JWS typ {typ!r} does not match expected {expected_typ!r}" ) - # crit handling: the AdCP profile defines no extensions, so any - # unrecognized crit entry means the caller can't safely process it. - # Only allow crit if it's empty / absent. crit = header.get("crit") if crit is not None and (not isinstance(crit, list) or len(crit) > 0): - raise JwsMalformedError( - "JWS 'crit' header is not supported for this profile" - ) + raise JwsMalformedError("JWS 'crit' header is not supported for this profile") kid = header.get("kid") if not isinstance(kid, str) or not kid: raise JwsMalformedError("JWS protected header must include a non-empty 'kid'") - jwk = jwks_resolver(kid) - if jwk is None: - raise JwsUnknownKeyError(f"no JWK for kid {kid!r}") + return internal_alg, kid + + +def _verify_signature_and_decode_payload( + *, + b64_protected: str, + b64_payload: str, + signature: bytes, + jwk: dict[str, Any], + internal_alg: str, + kid: str, +) -> dict[str, Any]: + """Run the cryptographic verify + decode the payload as JSON. - # Reconstruct the detached signature base per RFC 7515 §5.1 step 5: - # ASCII(BASE64URL(protected header)) || "." || ASCII(BASE64URL(payload)). - # Use the ORIGINAL base64url strings — don't decode-then-re-encode, since - # ``urlsafe_b64decode`` is lenient (accepts padding, tolerates some - # alphabet variants) and a round-trip can produce different bytes than - # the wire form. Verifying against the re-encoded bytes would let a - # crafted token verify against bytes the signer never signed. + Signing input uses the ORIGINAL base64url strings — don't decode- + then-re-encode, since ``urlsafe_b64decode`` is lenient and a + round-trip can produce different bytes than the wire form. + """ signing_input = (b64_protected + "." + b64_payload).encode("ascii") public_key = public_key_from_jwk(jwk) @@ -215,8 +200,6 @@ def verify_detached_jws( ): raise JwsSignatureInvalidError(f"signature did not verify for kid {kid!r}") - # Now that the signature is verified, decode the payload bytes from the - # trusted b64 string and parse as JSON. try: payload_bytes = b64url_decode(b64_payload) except (ValueError, binascii.Error) as exc: @@ -230,6 +213,75 @@ def verify_detached_jws( return decoded_payload +def verify_detached_jws( + *, + b64_protected: str, + b64_payload: str, + signature: bytes, + jwks_resolver: JwksResolver, + expected_typ: str, + allowed_algs: frozenset[str] = ALLOWED_JWS_ALGS, +) -> dict[str, Any]: + """Verify a parsed JWS and return the decoded payload JSON. + + Performs, in order: + 1. Decode + parse the protected header. + 2. Reject if ``alg`` is absent, is ``none``, or not in ``allowed_algs``. + 3. Reject if ``typ`` is not exactly ``expected_typ`` (byte equality, no + normalization — matches the AdCP spec). + 4. Resolve ``kid`` via ``jwks_resolver``; reject unknown. + 5. Verify the detached signature against the original b64 strings. + 6. Decode the payload as JSON and return it. + + Any failure raises a :class:`JwsError` subclass. The caller maps these + to transport-error codes (e.g. ``request_signature_revocation_stale``). + """ + internal_alg, kid = _validate_header( + b64_protected, expected_typ=expected_typ, allowed_algs=allowed_algs + ) + jwk = jwks_resolver(kid) + if jwk is None: + raise JwsUnknownKeyError(f"no JWK for kid {kid!r}") + return _verify_signature_and_decode_payload( + b64_protected=b64_protected, + b64_payload=b64_payload, + signature=signature, + jwk=jwk, + internal_alg=internal_alg, + kid=kid, + ) + + +async def averify_detached_jws( + *, + b64_protected: str, + b64_payload: str, + signature: bytes, + jwks_resolver: AsyncJwksResolver, + expected_typ: str, + allowed_algs: frozenset[str] = ALLOWED_JWS_ALGS, +) -> dict[str, Any]: + """Async variant of :func:`verify_detached_jws`. + + Awaits an :class:`AsyncJwksResolver` on the ``kid`` lookup; all + other work is pure CPU and runs inline. + """ + internal_alg, kid = _validate_header( + b64_protected, expected_typ=expected_typ, allowed_algs=allowed_algs + ) + jwk = await jwks_resolver(kid) + if jwk is None: + raise JwsUnknownKeyError(f"no JWK for kid {kid!r}") + return _verify_signature_and_decode_payload( + b64_protected=b64_protected, + b64_payload=b64_payload, + signature=signature, + jwk=jwk, + internal_alg=internal_alg, + kid=kid, + ) + + def verify_jws_document( doc: str | dict[str, Any], *, @@ -260,14 +312,47 @@ def verify_jws_document( ) +async def averify_jws_document( + doc: str | dict[str, Any], + *, + jwks_resolver: AsyncJwksResolver, + expected_typ: str, + allowed_algs: frozenset[str] = ALLOWED_JWS_ALGS, +) -> dict[str, Any]: + """Async variant of :func:`verify_jws_document`. + + Accepts an :class:`AsyncJwksResolver`; parsing and signature verify + are pure CPU and run inline. + """ + if isinstance(doc, str): + b64_header, b64_payload, signature = parse_compact_jws(doc) + elif isinstance(doc, dict): + b64_header, b64_payload, signature = parse_general_json_jws(doc) + else: + raise JwsMalformedError( + "JWS document must be a compact string or JSON general-serialization object" + ) + return await averify_detached_jws( + b64_protected=b64_header, + b64_payload=b64_payload, + signature=signature, + jwks_resolver=jwks_resolver, + expected_typ=expected_typ, + allowed_algs=allowed_algs, + ) + + __all__ = [ "ALLOWED_JWS_ALGS", + "AsyncJwksResolver", "JWS_ALG_TO_INTERNAL", "JwksResolver", "JwsError", "JwsMalformedError", "JwsSignatureInvalidError", "JwsUnknownKeyError", + "averify_detached_jws", + "averify_jws_document", "parse_compact_jws", "parse_general_json_jws", "verify_detached_jws", diff --git a/src/adcp/signing/revocation_fetcher.py b/src/adcp/signing/revocation_fetcher.py index a5e915699..8593c7c9c 100644 --- a/src/adcp/signing/revocation_fetcher.py +++ b/src/adcp/signing/revocation_fetcher.py @@ -13,12 +13,25 @@ :class:`adcp.signing.revocation.RevocationChecker` Protocol. Handles first fetch, refetch near ``next_update``, 304s, and the spec's fail-closed rule past ``next_update + grace``. +* :class:`AsyncCachingRevocationChecker` and + :func:`async_default_revocation_list_fetcher` — async counterparts + for verifiers running on an asyncio event loop. The verifier plugs it into :class:`VerifyOptions.revocation_checker`. + +Naming conventions +------------------ +* Classes use the ``Async`` CapWords prefix + (``AsyncCachingRevocationChecker``). +* Free functions use the ``async_`` snake_case prefix + (``async_default_revocation_list_fetcher``). +* Methods use the ``a`` prefix (``aprime``) — matches ``httpx`` / + ``anyio`` ecosystem idioms. """ from __future__ import annotations +import asyncio import json import logging import re @@ -32,11 +45,13 @@ from adcp.signing.jwks import ( DEFAULT_JWKS_TIMEOUT_SECONDS, + AsyncJwksResolver, JwksResolver, validate_jwks_uri, ) from adcp.signing.jws import ( JwsError, + averify_jws_document, verify_jws_document, ) from adcp.signing.revocation import RevocationList @@ -150,56 +165,59 @@ def __call__( ) -> FetchResult: ... -def default_revocation_list_fetcher( - uri: str, - *, - if_none_match: str | None = None, - if_modified_since: str | None = None, - allow_private: bool = False, - timeout: float = DEFAULT_JWKS_TIMEOUT_SECONDS, -) -> FetchResult: - """HTTPS GET the revocation list, honoring SSRF rules and conditional requests. +class AsyncRevocationListFetcher(Protocol): + """Async variant of :class:`RevocationListFetcher`. - Reuses ``validate_jwks_uri`` — the SSRF controls are identical (same - reserved-range rejection, same cloud-metadata block). ``httpx`` - re-resolves the hostname on connect, which is the TOCTOU window - tracked separately in #190. Sends ``If-None-Match`` when an ETag is - supplied and ``If-Modified-Since`` when a ``Last-Modified`` is - supplied; the spec accepts either (sellers SHOULD use both when - available). + Used by :class:`AsyncCachingRevocationChecker` so async verifier + pipelines don't block the event loop on revocation-list fetches. """ - validate_jwks_uri(uri, allow_private=allow_private) + + async def __call__( + self, + uri: str, + *, + if_none_match: str | None = None, + if_modified_since: str | None = None, + ) -> FetchResult: ... + + +def _build_fetch_headers( + *, if_none_match: str | None, if_modified_since: str | None +) -> dict[str, str]: headers = {"Accept": "application/jose+json, application/json, application/jose"} if if_none_match is not None: headers["If-None-Match"] = if_none_match if if_modified_since is not None: headers["If-Modified-Since"] = if_modified_since + return headers - try: - with httpx.Client(timeout=timeout, follow_redirects=False) as client: - response = client.get(uri, headers=headers) - except httpx.HTTPError as exc: - raise RevocationListFetchError(f"revocation list GET {uri!r} failed: {exc}") from exc - if response.status_code == 304: +def _fetch_result_from_response( + uri: str, + status_code: int, + response_text: str, + response_headers: Any, + *, + if_none_match: str | None, + if_modified_since: str | None, +) -> FetchResult: + """Shared response → FetchResult conversion for sync + async fetchers.""" + if status_code == 304: return FetchResult( body="", etag=if_none_match, last_modified=if_modified_since, not_modified=True, ) - if response.status_code != 200: + if status_code != 200: raise RevocationListFetchError( - f"revocation list {uri!r} returned HTTP {response.status_code}" + f"revocation list {uri!r} returned HTTP {status_code}" ) - etag = response.headers.get("ETag") - last_modified = _sanitize_last_modified(response.headers.get("Last-Modified")) - raw_body = response.text.strip() + etag = response_headers.get("ETag") + last_modified = _sanitize_last_modified(response_headers.get("Last-Modified")) + raw_body = response_text.strip() - # General JSON serialization starts with `{`; compact form is three - # base64url segments separated by dots. Dispatch by first-byte shape - # rather than trusting Content-Type, which is unreliable in practice. if not raw_body: raise RevocationListFetchError(f"revocation list {uri!r} returned empty body") @@ -222,6 +240,78 @@ def default_revocation_list_fetcher( ) +def default_revocation_list_fetcher( + uri: str, + *, + if_none_match: str | None = None, + if_modified_since: str | None = None, + allow_private: bool = False, + timeout: float = DEFAULT_JWKS_TIMEOUT_SECONDS, +) -> FetchResult: + """HTTPS GET the revocation list, honoring SSRF rules and conditional requests. + + Reuses ``validate_jwks_uri`` — the SSRF controls are identical (same + reserved-range rejection, same cloud-metadata block). ``httpx`` + re-resolves the hostname on connect, which is the TOCTOU window + tracked separately in #190. Sends ``If-None-Match`` when an ETag is + supplied and ``If-Modified-Since`` when a ``Last-Modified`` is + supplied; the spec accepts either (sellers SHOULD use both when + available). + """ + validate_jwks_uri(uri, allow_private=allow_private) + headers = _build_fetch_headers( + if_none_match=if_none_match, if_modified_since=if_modified_since + ) + try: + with httpx.Client(timeout=timeout, follow_redirects=False) as client: + response = client.get(uri, headers=headers) + except httpx.HTTPError as exc: + raise RevocationListFetchError(f"revocation list GET {uri!r} failed: {exc}") from exc + + return _fetch_result_from_response( + uri, + response.status_code, + response.text, + response.headers, + if_none_match=if_none_match, + if_modified_since=if_modified_since, + ) + + +async def async_default_revocation_list_fetcher( + uri: str, + *, + if_none_match: str | None = None, + if_modified_since: str | None = None, + allow_private: bool = False, + timeout: float = DEFAULT_JWKS_TIMEOUT_SECONDS, +) -> FetchResult: + """Async counterpart to :func:`default_revocation_list_fetcher`. + + Same SSRF + conditional-request behavior, but uses + :class:`httpx.AsyncClient` so the event loop isn't blocked during + the round-trip. + """ + validate_jwks_uri(uri, allow_private=allow_private) + headers = _build_fetch_headers( + if_none_match=if_none_match, if_modified_since=if_modified_since + ) + try: + async with httpx.AsyncClient(timeout=timeout, follow_redirects=False) as client: + response = await client.get(uri, headers=headers) + except httpx.HTTPError as exc: + raise RevocationListFetchError(f"revocation list GET {uri!r} failed: {exc}") from exc + + return _fetch_result_from_response( + uri, + response.status_code, + response.text, + response.headers, + if_none_match=if_none_match, + if_modified_since=if_modified_since, + ) + + def _sanitize_last_modified(raw: str | None) -> str | None: """Validate a ``Last-Modified`` header value before persisting it. @@ -289,6 +379,149 @@ def _normalize_issuer(issuer: str) -> str: return urlunsplit((scheme, netloc, "", "", "")) +def _slide_next_update( + current: RevocationList, polling_interval_seconds: float +) -> RevocationList: + """Return ``current`` with ``next_update`` advanced by one polling interval. + + Used on a 304 response so the cached list's freshness window slides + forward without needing a fresh JWS. Preserves every other field. + """ + prior = _parse_iso8601(current.next_update) + new_next_update = prior + timedelta(seconds=polling_interval_seconds) + return RevocationList( + issuer=current.issuer, + updated=current.updated, + next_update=new_next_update.isoformat().replace("+00:00", "Z"), + revoked_kids=current.revoked_kids, + revoked_jtis=current.revoked_jtis, + ) + + +def _post_jws_validation( + payload: dict[str, Any], + *, + expected_issuer: str, + now_wall: datetime, + current_list: RevocationList | None, +) -> RevocationList: + """Schema + clock-skew + replay check on an already-verified JWS payload. + + Everything in this helper is pure CPU — runs identically from sync + and async refresh paths. Raises :class:`RevocationListParseError` on + any violation. + """ + revocation_list = _build_list_from_payload(payload, expected_issuer=expected_issuer) + + updated = _parse_iso8601(revocation_list.updated) + next_update = _parse_iso8601(revocation_list.next_update) + + # Verify `updated` is not in the future beyond clock skew. An issuer + # whose clock is far ahead would otherwise force an immediate stale + # rejection. + if updated > now_wall.replace(microsecond=0): + delta = (updated - now_wall).total_seconds() + if delta > 60: # 60s clock skew tolerance, mirrors JWS exp/iat rules + raise RevocationListParseError( + f"revocation list updated={revocation_list.updated!r} is " + f"{delta:.0f}s in the future" + ) + if next_update <= updated: + raise RevocationListParseError( + f"revocation list next_update {revocation_list.next_update!r} is not " + f"after updated {revocation_list.updated!r}" + ) + + # Reject a freshly-fetched list whose `updated` is older than the + # one we already have cached. Defense against CDN replay or + # compromised operator serving an older list with revocations + # removed. The spec doesn't permit un-revocation. + if current_list is not None: + current_updated = _parse_iso8601(current_list.updated) + if updated < current_updated: + raise RevocationListParseError( + f"revocation list updated={revocation_list.updated!r} is older " + f"than cached list updated={current_list.updated!r} — " + f"refusing to roll back" + ) + + return revocation_list + + +def _polling_interval_seconds(revocation_list: RevocationList) -> float: + """Compute the clamped polling interval for a list. + + Parse-time already enforces declared cadence >= 60s and we clamp + the ceiling at :data:`MAX_POLLING_INTERVAL_SECONDS` as defense + against an issuer declaring an out-of-bounds ceiling value. + """ + updated = _parse_iso8601(revocation_list.updated) + next_update = _parse_iso8601(revocation_list.next_update) + return min(MAX_POLLING_INTERVAL_SECONDS, (next_update - updated).total_seconds()) + + +class _CheckerState: + """Mixin for the mutable cache state shared by sync + async checkers. + + Both checkers carry identical cache fields and run identical + post-fetch transitions (304-slide, commit, grace calculation). + Keeping this as a mixin lets either class inherit without + duplicating the methods — the alternative (module-level helpers + that take a mutable state dict) reads worse and loses type info. + """ + + _current_list: RevocationList | None + _current_etag: str | None + _current_last_modified: str | None + _last_successful_refresh: float | None + _last_polling_interval_seconds: float | None + _last_refresh_attempt: float | None + _grace_multiplier: float + + def _init_state(self) -> None: + self._current_list = None + self._current_etag = None + self._current_last_modified = None + self._last_successful_refresh = None + self._last_polling_interval_seconds = None + self._last_refresh_attempt = None + + def _handle_not_modified(self, *, now_mono: float) -> None: + """Slide ``next_update`` forward on a 304 response. + + Without this, subsequent calls past the original ``next_update`` + would re-enter the refresh branch on every verification (gated + only by the 60s cooldown). Advancing the cached + ``next_update`` by one polling interval lets the hot path + short-circuit cleanly. + """ + self._last_successful_refresh = now_mono + if self._current_list is not None and self._last_polling_interval_seconds: + self._current_list = _slide_next_update( + self._current_list, self._last_polling_interval_seconds + ) + + def _commit( + self, + *, + result: FetchResult, + revocation_list: RevocationList, + now_mono: float, + ) -> None: + """Commit a validated list to the cache.""" + self._current_list = revocation_list + self._current_etag = result.etag + # Sanitize again at write-side: custom fetcher impls may not have + # validated the value before constructing FetchResult. + self._current_last_modified = _sanitize_last_modified(result.last_modified) + self._last_successful_refresh = now_mono + self._last_polling_interval_seconds = _polling_interval_seconds(revocation_list) + + def _grace_seconds(self) -> float: + interval = self._last_polling_interval_seconds or MAX_POLLING_INTERVAL_SECONDS + return interval * self._grace_multiplier + + def _build_list_from_payload(payload: dict[str, Any], expected_issuer: str) -> RevocationList: """Validate the JWS payload schema and assemble a ``RevocationList``. @@ -343,7 +576,7 @@ def _build_list_from_payload(payload: dict[str, Any], expected_issuer: str) -> R return RevocationList.from_dict(payload) -class CachingRevocationChecker: +class CachingRevocationChecker(_CheckerState): """Live revocation checker with caching, refetch, grace, and fail-closed. Implements the ``RevocationChecker`` Protocol — callable as @@ -440,16 +673,7 @@ def __init__( self._clock = clock self._wall_clock = wall_clock - self._current_list: RevocationList | None = None - self._current_etag: str | None = None - self._current_last_modified: str | None = None - self._last_successful_refresh: float | None = None - self._last_polling_interval_seconds: float | None = None - # Cooldown state: when a refresh attempt fails, we don't retry - # until at least MIN_POLLING_INTERVAL_SECONDS of monotonic time - # have elapsed. Stops a high-traffic verifier from hammering a - # dead revocation endpoint. - self._last_refresh_attempt: float | None = None + self._init_state() @classmethod def from_issuer_origin( @@ -586,25 +810,7 @@ def _refresh(self, *, conditional: bool, now_wall: datetime, now_mono: float) -> if_modified_since=if_modified_since, ) if result.not_modified: - # 304: server confirms the cached list is still current. Advance - # the cached list's `next_update` by the declared polling - # interval so subsequent calls don't re-enter the past- - # next_update branch on every request — without this, the 60s - # cooldown gate would fire once per verification request once - # we cross the original next_update. - self._last_successful_refresh = now_mono - if self._current_list is not None and self._last_polling_interval_seconds: - prior = _parse_iso8601(self._current_list.next_update) - new_next_update = prior + timedelta( - seconds=self._last_polling_interval_seconds - ) - self._current_list = RevocationList( - issuer=self._current_list.issuer, - updated=self._current_list.updated, - next_update=new_next_update.isoformat().replace("+00:00", "Z"), - revoked_kids=self._current_list.revoked_kids, - revoked_jtis=self._current_list.revoked_jtis, - ) + self._handle_not_modified(now_mono=now_mono) return try: @@ -618,63 +824,235 @@ def _refresh(self, *, conditional: bool, now_wall: datetime, now_mono: float) -> f"revocation list JWS verification failed: {exc}" ) from exc - revocation_list = _build_list_from_payload(payload, expected_issuer=self._issuer) - - updated = _parse_iso8601(revocation_list.updated) - next_update = _parse_iso8601(revocation_list.next_update) - - # Verify `updated` is not in the future beyond clock skew. An issuer - # whose clock is far ahead would otherwise force an immediate stale - # rejection. - if updated > now_wall.replace(microsecond=0): - delta = (updated - now_wall).total_seconds() - if delta > 60: # 60s clock skew tolerance, mirrors JWS exp/iat rules - raise RevocationListParseError( - f"revocation list updated={revocation_list.updated!r} is " - f"{delta:.0f}s in the future" - ) - if next_update <= updated: - raise RevocationListParseError( - f"revocation list next_update {revocation_list.next_update!r} is not " - f"after updated {revocation_list.updated!r}" + revocation_list = _post_jws_validation( + payload, + expected_issuer=self._issuer, + now_wall=now_wall, + current_list=self._current_list, + ) + self._commit(result=result, revocation_list=revocation_list, now_mono=now_mono) + + +class AsyncCachingRevocationChecker(_CheckerState): + """Async counterpart to :class:`CachingRevocationChecker`. + + Same state machine, identical semantics — see the sync class for + full behavior documentation. Differences: + + * ``__call__(keyid)`` is awaitable. + * :meth:`aprime` replaces :meth:`prime`. + * :meth:`is_jti_revoked` is awaitable. + * The ``jwks_resolver`` is an :class:`AsyncJwksResolver`. + * The ``fetcher`` is an :class:`AsyncRevocationListFetcher`. + * A single ``asyncio.Lock`` serializes refreshes so N concurrent + verifying tasks that all miss on the first verification fire + exactly one fetch. (The sync class is documented single-threaded; + async concurrency is the typical case.) + + Wire it into :class:`adcp.signing.VerifyOptions.revocation_checker` + only if you have an async verifier pipeline — the sync + ``verify_request_signature`` expects a sync ``RevocationChecker`` + and will not await this class. + """ + + def __init__( + self, + *, + revocation_uri: str, + issuer: str, + jwks_resolver: AsyncJwksResolver, + fetcher: AsyncRevocationListFetcher | None = None, + grace_multiplier: float = DEFAULT_GRACE_MULTIPLIER, + clock: Callable[[], float] = time.monotonic, + wall_clock: Callable[[], datetime] = lambda: datetime.now(timezone.utc), + ) -> None: + if clock is time.time: + raise ValueError( + "clock must be a monotonic time source (use time.monotonic, " + "not time.time); wall-clock jumps would break cooldown math" + ) + if clock is wall_clock: # type: ignore[comparison-overlap] + raise ValueError( + "clock and wall_clock must be different sources — clock is " + "monotonic seconds (cooldown timing), wall_clock is a UTC " + "datetime source (freshness evaluation)" ) - # Reject a freshly-fetched list whose `updated` is older than the - # one we already have cached. Defense-in-depth against: - # - CDN replaying a stale list after a kid has been revoked, - # - operator-key compromise where an attacker serves an older list - # with revocations removed. - # The spec doesn't permit un-revocation, so `updated` MUST be - # monotonically non-decreasing across refreshes. - if self._current_list is not None: - current_updated = _parse_iso8601(self._current_list.updated) - if updated < current_updated: - raise RevocationListParseError( - f"revocation list updated={revocation_list.updated!r} is older " - f"than cached list updated={self._current_list.updated!r} — " - f"refusing to roll back" - ) + self._revocation_uri = revocation_uri + self._issuer = _normalize_issuer(issuer) + self._jwks_resolver = jwks_resolver + self._fetcher: AsyncRevocationListFetcher = ( + fetcher or async_default_revocation_list_fetcher + ) + self._grace_multiplier = grace_multiplier + self._clock = clock + self._wall_clock = wall_clock + self._init_state() + # Eager lock construction: lazy-init was racy (two tasks both + # seeing self._lock is None construct separate Locks and skip + # serialization). In Python 3.10+ asyncio.Lock() does not + # require a running loop at construction time. Instances are + # per-loop — don't reuse a checker across asyncio.run boundaries. + self._lock: asyncio.Lock = asyncio.Lock() - self._current_list = revocation_list - self._current_etag = result.etag - # Sanitize again at write-side: custom fetcher impls may not have - # validated the value before constructing FetchResult. - self._current_last_modified = _sanitize_last_modified(result.last_modified) - self._last_successful_refresh = now_mono - # Declared cadence is already validated >= 60s and bounded above - # at parse time; clamp to the spec ceiling as defense against an - # issuer declaring an out-of-bounds ceiling value. - self._last_polling_interval_seconds = min( - MAX_POLLING_INTERVAL_SECONDS, - (next_update - updated).total_seconds(), + @classmethod + def from_issuer_origin( + cls, + origin: str, + *, + jwks_resolver: AsyncJwksResolver, + fetcher: AsyncRevocationListFetcher | None = None, + grace_multiplier: float = DEFAULT_GRACE_MULTIPLIER, + clock: Callable[[], float] = time.monotonic, + wall_clock: Callable[[], datetime] = lambda: datetime.now(timezone.utc), + ) -> AsyncCachingRevocationChecker: + """Async variant of :meth:`CachingRevocationChecker.from_issuer_origin`.""" + normalized = _normalize_issuer(origin) + revocation_uri = f"{normalized}/.well-known/governance-revocations.json" + return cls( + revocation_uri=revocation_uri, + issuer=normalized, + jwks_resolver=jwks_resolver, + fetcher=fetcher, + grace_multiplier=grace_multiplier, + clock=clock, + wall_clock=wall_clock, ) - def _grace_seconds(self) -> float: - interval = self._last_polling_interval_seconds or MAX_POLLING_INTERVAL_SECONDS - return interval * self._grace_multiplier + async def aprime(self) -> None: + """Fetch and verify the revocation list. Call at startup for fail-fast. + + Async counterpart to :meth:`CachingRevocationChecker.prime`. + """ + await self._ensure_fresh() + + async def __call__(self, keyid: str) -> bool: + """Return True iff ``keyid`` is in the cached list's ``revoked_kids``. + + See :meth:`CachingRevocationChecker.__call__`. For + governance-token ``jti`` checks use :meth:`is_jti_revoked`. + """ + await self._ensure_fresh() + if self._current_list is None: + raise RevocationListFreshnessError("revocation list not available") + return self._current_list.is_revoked(keyid) + + async def is_jti_revoked(self, jti: str) -> bool: + """Async variant of :meth:`CachingRevocationChecker.is_jti_revoked`.""" + await self._ensure_fresh() + if self._current_list is None: + raise RevocationListFreshnessError("revocation list not available") + return jti in self._current_list.revoked_jtis + + async def _ensure_fresh(self) -> None: + if self._current_list is None: + async with self._lock: + if self._current_list is None: + # Recompute clocks inside the lock: the awaiting task + # may have been queued behind a slow refresh, so the + # pre-lock clocks could be stale. + now_wall = self._wall_clock() + now_mono = self._clock() + await self._refresh( + conditional=False, now_wall=now_wall, now_mono=now_mono + ) + return + + next_update = _parse_iso8601(self._current_list.next_update) + now_wall = self._wall_clock() + if now_wall < next_update: + return + + now_mono = self._clock() + since_last_attempt = ( + now_mono - self._last_refresh_attempt + if self._last_refresh_attempt is not None + else float("inf") + ) + if since_last_attempt >= MIN_POLLING_INTERVAL_SECONDS: + try: + async with self._lock: + # Re-check under the lock with fresh clock reads. + now_mono_inside = self._clock() + if self._last_refresh_attempt is None or ( + now_mono_inside - self._last_refresh_attempt + >= MIN_POLLING_INTERVAL_SECONDS + ): + now_wall_inside = self._wall_clock() + await self._refresh( + conditional=True, + now_wall=now_wall_inside, + now_mono=now_mono_inside, + ) + return + except (RevocationListFetchError, RevocationListParseError) as exc: + last_exc: Exception = exc + else: + last_exc = RevocationListFetchError( + f"refresh cooldown not elapsed ({since_last_attempt:.0f}s < " + f"{MIN_POLLING_INTERVAL_SECONDS}s)" + ) + + grace_seconds = self._grace_seconds() + if now_wall.timestamp() >= next_update.timestamp() + grace_seconds: + raise RevocationListFreshnessError( + f"revocation list {self._revocation_uri!r} past next_update " + f"({self._current_list.next_update}) + grace ({grace_seconds:.0f}s); " + f"last refresh error: {last_exc}" + ) from last_exc + + async def _refresh( + self, *, conditional: bool, now_wall: datetime, now_mono: float + ) -> None: + # Stamp the attempt BEFORE the awaitable. On CancelledError the + # finally block rolls it back so a cancelled task doesn't burn + # the 60s cooldown for the next caller — non-cancellation + # failures keep the cooldown (correct: a server rejection still + # counts as "recently tried"). Requires try/finally rather than + # setting post-fetch so the cooldown fires if the server replies + # before the caller cancels. + prior_attempt = self._last_refresh_attempt + self._last_refresh_attempt = now_mono + try: + if_none_match = self._current_etag if conditional else None + if_modified_since = self._current_last_modified if conditional else None + result = await self._fetcher( + self._revocation_uri, + if_none_match=if_none_match, + if_modified_since=if_modified_since, + ) + except asyncio.CancelledError: + # Cancellation doesn't count as "tried" — don't burn the + # cooldown for the next (non-cancelled) caller. + self._last_refresh_attempt = prior_attempt + raise + if result.not_modified: + self._handle_not_modified(now_mono=now_mono) + return + + try: + payload = await averify_jws_document( + result.body, + jwks_resolver=self._jwks_resolver, + expected_typ=REVOCATION_LIST_TYP, + ) + except JwsError as exc: + raise RevocationListSignatureError( + f"revocation list JWS verification failed: {exc}" + ) from exc + + revocation_list = _post_jws_validation( + payload, + expected_issuer=self._issuer, + now_wall=now_wall, + current_list=self._current_list, + ) + self._commit(result=result, revocation_list=revocation_list, now_mono=now_mono) __all__ = [ + "AsyncCachingRevocationChecker", + "AsyncRevocationListFetcher", "CachingRevocationChecker", "DEFAULT_GRACE_MULTIPLIER", "FetchResult", @@ -683,6 +1061,7 @@ def _grace_seconds(self) -> float: "RevocationListFetcher", "RevocationListFreshnessError", "RevocationListParseError", + "async_default_revocation_list_fetcher", "default_revocation_list_fetcher", ] diff --git a/tests/conformance/signing/test_async_revocation.py b/tests/conformance/signing/test_async_revocation.py new file mode 100644 index 000000000..250a981b2 --- /dev/null +++ b/tests/conformance/signing/test_async_revocation.py @@ -0,0 +1,649 @@ +"""Unit + e2e tests for the async revocation-list fetcher and checker. + +Mirrors the sync coverage in ``test_revocation_fetcher.py`` and +``test_revocation_e2e.py`` against the async counterparts. Uses a +scripted async fetcher for unit tests and a real Starlette app via +``httpx.ASGITransport`` + ``httpx.AsyncClient`` for e2e — no +``asyncio.run`` bridge, which is the whole point of the async path. +""" + +from __future__ import annotations + +import asyncio +import json +from collections.abc import Callable +from datetime import datetime, timedelta, timezone +from typing import Any + +import httpx +import pytest +from cryptography.hazmat.primitives.asymmetric import ed25519 +from starlette.applications import Starlette +from starlette.responses import PlainTextResponse, Response +from starlette.routing import Route + +from adcp.signing import ( + REVOCATION_LIST_TYP, + AsyncCachingJwksResolver, + AsyncCachingRevocationChecker, + AsyncJwksResolver, + FetchResult, + RevocationListFetchError, + RevocationListFreshnessError, + as_async_resolver, + async_default_revocation_list_fetcher, + averify_jws_document, +) +from adcp.signing.crypto import ALG_ED25519, b64url_encode, sign_signature_base +from adcp.signing.jws import JwsMalformedError + +ISSUER = "https://gov.example.com" +REVOCATION_URI = f"{ISSUER}/.well-known/governance-revocations.json" + + +# -- shared fixtures ---------------------------------------------------- + + +def _operator_key_and_resolver() -> tuple[ed25519.Ed25519PrivateKey, AsyncJwksResolver]: + private = ed25519.Ed25519PrivateKey.generate() + jwk = { + "kty": "OKP", + "crv": "Ed25519", + "alg": "EdDSA", + "use": "sig", + "key_ops": ["verify"], + "kid": "operator-2026", + "x": b64url_encode(private.public_key().public_bytes_raw()), + } + + async def resolver(keyid: str) -> dict[str, Any] | None: + return jwk if keyid == "operator-2026" else None + + return private, resolver + + +def _make_payload( + *, + issuer: str = ISSUER, + updated: str = "2026-04-18T14:00:00Z", + next_update: str = "2026-04-18T14:15:00Z", + revoked_kids: list[str] | None = None, + revoked_jtis: list[str] | None = None, +) -> dict[str, Any]: + return { + "version": 1, + "issuer": issuer, + "updated": updated, + "next_update": next_update, + "revoked_kids": revoked_kids or [], + "revoked_jtis": revoked_jtis or [], + } + + +def _sign_compact(payload: dict[str, Any], *, private: ed25519.Ed25519PrivateKey) -> str: + header = {"alg": "EdDSA", "kid": "operator-2026", "typ": REVOCATION_LIST_TYP} + b64_header = b64url_encode(json.dumps(header, separators=(",", ":")).encode()) + b64_payload = b64url_encode(json.dumps(payload, separators=(",", ":")).encode()) + signing_input = (b64_header + "." + b64_payload).encode("ascii") + signature = sign_signature_base( + alg=ALG_ED25519, private_key=private, signature_base=signing_input + ) + return b64_header + "." + b64_payload + "." + b64url_encode(signature) + + +class _ScriptedAsyncFetcher: + def __init__(self) -> None: + self.calls: list[tuple[str, str | None, str | None]] = [] + self._queue: list[FetchResult | Exception] = [] + + def enqueue(self, result: FetchResult | Exception) -> None: + self._queue.append(result) + + async def __call__( + self, + uri: str, + *, + if_none_match: str | None = None, + if_modified_since: str | None = None, + ) -> FetchResult: + self.calls.append((uri, if_none_match, if_modified_since)) + if not self._queue: + raise AssertionError("ScriptedAsyncFetcher had no response queued") + next_up = self._queue.pop(0) + if isinstance(next_up, Exception): + raise next_up + return next_up + + +def _controllable_clock( + start: datetime, +) -> tuple[Callable[[], datetime], Callable[[], float], Callable[[float], None]]: + now = [start] + mono = [0.0] + + def wall_clock() -> datetime: + return now[0] + + def monotonic_clock() -> float: + return mono[0] + + def advance_seconds(seconds: float) -> None: + now[0] = now[0] + timedelta(seconds=seconds) + mono[0] = mono[0] + seconds + + return wall_clock, monotonic_clock, advance_seconds + + +# -- averify_jws_document ---------------------------------------------- + + +async def test_averify_jws_round_trip() -> None: + private, resolver = _operator_key_and_resolver() + token = _sign_compact(_make_payload(), private=private) + + verified = await averify_jws_document( + token, jwks_resolver=resolver, expected_typ=REVOCATION_LIST_TYP + ) + assert verified["issuer"] == ISSUER + + +async def test_averify_jws_rejects_unknown_kid() -> None: + private, _ = _operator_key_and_resolver() + token = _sign_compact(_make_payload(), private=private) + + async def empty_resolver(_keyid: str) -> dict[str, Any] | None: + return None + + with pytest.raises(Exception) as exc_info: + await averify_jws_document( + token, jwks_resolver=empty_resolver, expected_typ=REVOCATION_LIST_TYP + ) + assert "no JWK" in str(exc_info.value) + + +async def test_averify_jws_rejects_wrong_typ() -> None: + private, resolver = _operator_key_and_resolver() + # Build a token with the wrong typ header. + header = {"alg": "EdDSA", "kid": "operator-2026", "typ": "different+jws"} + payload = _make_payload() + b64_header = b64url_encode(json.dumps(header, separators=(",", ":")).encode()) + b64_payload = b64url_encode(json.dumps(payload, separators=(",", ":")).encode()) + signature = sign_signature_base( + alg=ALG_ED25519, + private_key=private, + signature_base=(b64_header + "." + b64_payload).encode("ascii"), + ) + token = b64_header + "." + b64_payload + "." + b64url_encode(signature) + + with pytest.raises(JwsMalformedError, match="typ"): + await averify_jws_document( + token, jwks_resolver=resolver, expected_typ=REVOCATION_LIST_TYP + ) + + +# -- as_async_resolver -------------------------------------------------- + + +async def test_as_async_resolver_wraps_sync_resolver() -> None: + def sync_resolver(keyid: str) -> dict[str, Any] | None: + return {"kid": keyid} if keyid == "x" else None + + async_resolver: AsyncJwksResolver = as_async_resolver(sync_resolver) + assert await async_resolver("x") == {"kid": "x"} + assert await async_resolver("y") is None + + +# -- AsyncCachingRevocationChecker: happy path + cache ------------------ + + +async def test_async_first_call_fetches_and_decides() -> None: + private, resolver = _operator_key_and_resolver() + token = _sign_compact(_make_payload(revoked_kids=["rev"]), private=private) + + fetcher = _ScriptedAsyncFetcher() + fetcher.enqueue(FetchResult(body=token, etag='"v1"', not_modified=False)) + + wall_clock, mono_clock, _ = _controllable_clock( + datetime(2026, 4, 18, 14, 5, tzinfo=timezone.utc) + ) + checker = AsyncCachingRevocationChecker( + revocation_uri=REVOCATION_URI, + issuer=ISSUER, + jwks_resolver=resolver, + fetcher=fetcher, + wall_clock=wall_clock, + clock=mono_clock, + ) + + assert await checker("rev") is True + assert await checker("clean") is False + assert len(fetcher.calls) == 1 + + +async def test_async_cache_hit_within_next_update_skips_refetch() -> None: + private, resolver = _operator_key_and_resolver() + token = _sign_compact(_make_payload(), private=private) + fetcher = _ScriptedAsyncFetcher() + fetcher.enqueue(FetchResult(body=token, etag=None, not_modified=False)) + + wall_clock, mono_clock, advance = _controllable_clock( + datetime(2026, 4, 18, 14, 1, tzinfo=timezone.utc) + ) + checker = AsyncCachingRevocationChecker( + revocation_uri=REVOCATION_URI, + issuer=ISSUER, + jwks_resolver=resolver, + fetcher=fetcher, + wall_clock=wall_clock, + clock=mono_clock, + ) + await checker("k") + advance(60) + await checker("k") + assert len(fetcher.calls) == 1 + + +async def test_async_past_next_update_triggers_conditional_refresh() -> None: + private, resolver = _operator_key_and_resolver() + token = _sign_compact(_make_payload(), private=private) + fetcher = _ScriptedAsyncFetcher() + fetcher.enqueue(FetchResult(body=token, etag='"v1"', not_modified=False)) + fetcher.enqueue(FetchResult(body="", etag='"v1"', not_modified=True)) + + wall_clock, mono_clock, advance = _controllable_clock( + datetime(2026, 4, 18, 14, 1, tzinfo=timezone.utc) + ) + checker = AsyncCachingRevocationChecker( + revocation_uri=REVOCATION_URI, + issuer=ISSUER, + jwks_resolver=resolver, + fetcher=fetcher, + wall_clock=wall_clock, + clock=mono_clock, + ) + await checker("k") + advance(20 * 60) + await checker("k") + assert len(fetcher.calls) == 2 + _, if_none_match, _ = fetcher.calls[1] + assert if_none_match == '"v1"' + + +async def test_async_replay_older_list_rejected() -> None: + private, resolver = _operator_key_and_resolver() + newer = _make_payload( + updated="2026-04-18T14:10:00Z", + next_update="2026-04-18T14:25:00Z", + revoked_kids=["compromised"], + ) + older = _make_payload( + updated="2026-04-18T14:00:00Z", + next_update="2026-04-18T14:15:00Z", + revoked_kids=[], + ) + fetcher = _ScriptedAsyncFetcher() + fetcher.enqueue(FetchResult( + body=_sign_compact(newer, private=private), etag='"v2"', not_modified=False + )) + fetcher.enqueue(FetchResult( + body=_sign_compact(older, private=private), etag='"v1"', not_modified=False + )) + + wall_clock, mono_clock, advance = _controllable_clock( + datetime(2026, 4, 18, 14, 15, tzinfo=timezone.utc) + ) + checker = AsyncCachingRevocationChecker( + revocation_uri=REVOCATION_URI, + issuer=ISSUER, + jwks_resolver=resolver, + fetcher=fetcher, + wall_clock=wall_clock, + clock=mono_clock, + ) + assert await checker("compromised") is True + advance(15 * 60) + assert await checker("compromised") is True + assert checker._current_list is not None + assert checker._current_list.updated == "2026-04-18T14:10:00Z" + + +async def test_async_refresh_failure_past_grace_raises_freshness_error() -> None: + private, resolver = _operator_key_and_resolver() + token = _sign_compact(_make_payload(revoked_kids=["rev"]), private=private) + fetcher = _ScriptedAsyncFetcher() + fetcher.enqueue(FetchResult(body=token, etag='"v1"', not_modified=False)) + for _ in range(5): + fetcher.enqueue(RevocationListFetchError("server unavailable")) + + wall_clock, mono_clock, advance = _controllable_clock( + datetime(2026, 4, 18, 14, 1, tzinfo=timezone.utc) + ) + checker = AsyncCachingRevocationChecker( + revocation_uri=REVOCATION_URI, + issuer=ISSUER, + jwks_resolver=resolver, + fetcher=fetcher, + wall_clock=wall_clock, + clock=mono_clock, + ) + await checker("rev") + advance(45 * 60 + 1) + with pytest.raises(RevocationListFreshnessError): + await checker("rev") + + +async def test_async_aprime_fails_fast() -> None: + _, resolver = _operator_key_and_resolver() + fetcher = _ScriptedAsyncFetcher() + fetcher.enqueue(RevocationListFetchError("operator unreachable")) + + checker = AsyncCachingRevocationChecker( + revocation_uri=REVOCATION_URI, + issuer=ISSUER, + jwks_resolver=resolver, + fetcher=fetcher, + ) + with pytest.raises(RevocationListFetchError, match="operator unreachable"): + await checker.aprime() + + +async def test_async_is_jti_revoked() -> None: + private, resolver = _operator_key_and_resolver() + token = _sign_compact( + _make_payload(revoked_jtis=["jti-abc"]), private=private + ) + fetcher = _ScriptedAsyncFetcher() + fetcher.enqueue(FetchResult(body=token, etag=None, not_modified=False)) + + wall_clock, mono_clock, _ = _controllable_clock( + datetime(2026, 4, 18, 14, 5, tzinfo=timezone.utc) + ) + checker = AsyncCachingRevocationChecker( + revocation_uri=REVOCATION_URI, + issuer=ISSUER, + jwks_resolver=resolver, + fetcher=fetcher, + wall_clock=wall_clock, + clock=mono_clock, + ) + assert await checker.is_jti_revoked("jti-abc") is True + assert await checker.is_jti_revoked("jti-other") is False + + +async def test_async_from_issuer_origin_builds_spec_path() -> None: + _, resolver = _operator_key_and_resolver() + checker = AsyncCachingRevocationChecker.from_issuer_origin( + "https://Gov.Example.COM/", + jwks_resolver=resolver, + ) + assert checker._revocation_uri == ( + "https://gov.example.com/.well-known/governance-revocations.json" + ) + assert checker._issuer == "https://gov.example.com" + + +# -- concurrency: lock serializes refreshes ---------------------------- + + +async def test_async_lock_serializes_first_fetch_under_concurrency() -> None: + """N concurrent tasks hitting the first miss should fire exactly one fetch.""" + private, resolver = _operator_key_and_resolver() + token = _sign_compact(_make_payload(revoked_kids=["rev"]), private=private) + + # Fetcher with a small await inside to make the race observable. + fetch_count = [0] + + async def slow_fetcher( + _uri: str, + *, + if_none_match: str | None = None, + if_modified_since: str | None = None, + ) -> FetchResult: + fetch_count[0] += 1 + await asyncio.sleep(0) # let other tasks interleave + return FetchResult(body=token, etag='"v1"', not_modified=False) + + wall_clock, mono_clock, _ = _controllable_clock( + datetime(2026, 4, 18, 14, 5, tzinfo=timezone.utc) + ) + checker = AsyncCachingRevocationChecker( + revocation_uri=REVOCATION_URI, + issuer=ISSUER, + jwks_resolver=resolver, + fetcher=slow_fetcher, + wall_clock=wall_clock, + clock=mono_clock, + ) + results = await asyncio.gather( + checker("rev"), checker("rev"), checker("rev"), checker("rev") + ) + assert all(r is True for r in results) + assert fetch_count[0] == 1 + + +# -- async JWKS resolver ----------------------------------------------- + + +async def test_async_caching_jwks_resolver_caches_first_fetch() -> None: + jwk = { + "kty": "OKP", + "crv": "Ed25519", + "alg": "EdDSA", + "use": "sig", + "key_ops": ["verify"], + "kid": "my-kid", + "x": "a" * 43, + } + fetch_count = [0] + + async def fake_fetcher(_uri: str, *, allow_private: bool = False) -> dict[str, Any]: + fetch_count[0] += 1 + return {"keys": [jwk]} + + resolver = AsyncCachingJwksResolver( + "https://gov.example.com/.well-known/jwks.json", + fetcher=fake_fetcher, + ) + assert await resolver("my-kid") == jwk + assert await resolver("my-kid") == jwk # cached + assert fetch_count[0] == 1 + + +async def test_async_caching_jwks_resolver_handles_concurrent_misses() -> None: + jwk = {"kid": "concurrent-kid", "kty": "OKP", "crv": "Ed25519", "x": "a" * 43} + fetch_count = [0] + + async def slow_fetcher(_uri: str, *, allow_private: bool = False) -> dict[str, Any]: + fetch_count[0] += 1 + await asyncio.sleep(0) + return {"keys": [jwk]} + + resolver = AsyncCachingJwksResolver( + "https://gov.example.com/.well-known/jwks.json", + fetcher=slow_fetcher, + ) + results = await asyncio.gather( + resolver("concurrent-kid"), + resolver("concurrent-kid"), + resolver("concurrent-kid"), + ) + assert all(r == jwk for r in results) + assert fetch_count[0] == 1 + + +# -- e2e via Starlette + ASGITransport (native async, no asyncio.run) --- + + +def _build_revocation_app(*, body: str, etag: str) -> Starlette: + async def handler(request: Any) -> Response: + if_none_match = request.headers.get("if-none-match") + if if_none_match == etag: + return Response(status_code=304, headers={"ETag": etag}) + return PlainTextResponse( + content=body, + media_type="application/jose", + headers={"ETag": etag}, + ) + + return Starlette( + routes=[Route("/.well-known/governance-revocations.json", handler, methods=["GET"])] + ) + + +def _asgi_async_fetcher(app: Starlette) -> Any: + transport = httpx.ASGITransport(app=app) + + async def fetch( + uri: str, + *, + if_none_match: str | None = None, + if_modified_since: str | None = None, + ) -> FetchResult: + headers: dict[str, str] = {"Accept": "application/jose"} + if if_none_match is not None: + headers["If-None-Match"] = if_none_match + if if_modified_since is not None: + headers["If-Modified-Since"] = if_modified_since + + async with httpx.AsyncClient(transport=transport, base_url=ISSUER) as client: + response = await client.get( + "/.well-known/governance-revocations.json", headers=headers + ) + if response.status_code == 304: + return FetchResult( + body="", + etag=if_none_match, + last_modified=if_modified_since, + not_modified=True, + ) + response.raise_for_status() + return FetchResult( + body=response.text, + etag=response.headers.get("ETag"), + last_modified=response.headers.get("Last-Modified"), + not_modified=False, + ) + + return fetch + + +async def test_async_e2e_asgi_round_trip() -> None: + """Native async path — no asyncio.run bridge in the test.""" + private, resolver = _operator_key_and_resolver() + payload = _make_payload(revoked_kids=["compromised"]) + token = _sign_compact(payload, private=private) + app = _build_revocation_app(body=token, etag='"rev-1"') + + checker = AsyncCachingRevocationChecker( + revocation_uri=REVOCATION_URI, + issuer=ISSUER, + jwks_resolver=resolver, + fetcher=_asgi_async_fetcher(app), + wall_clock=lambda: datetime(2026, 4, 18, 14, 5, tzinfo=timezone.utc), + ) + assert await checker("compromised") is True + assert await checker("clean") is False + + +# -- default fetcher (SSRF smoke only, like sync version) -------------- + + +async def test_async_default_fetcher_rejects_non_https() -> None: + from adcp.signing.jwks import SSRFValidationError + + with pytest.raises(SSRFValidationError): + await async_default_revocation_list_fetcher("ftp://example.com/list.json") + + +async def test_async_default_fetcher_rejects_metadata_ip() -> None: + from adcp.signing.jwks import SSRFValidationError + + with pytest.raises(SSRFValidationError): + await async_default_revocation_list_fetcher("https://169.254.169.254/list.json") + + +async def test_concurrent_first_calls_share_one_refresh() -> None: + """Concurrent first-miss calls on a brand-new checker use the eager + lock to serialize refreshes. + + Regression test for the lazy-lock-init race — previously two tasks + both seeing ``self._lock is None`` would construct separate Locks + and skip serialization. + """ + private, resolver = _operator_key_and_resolver() + token = _sign_compact(_make_payload(revoked_kids=["k"]), private=private) + + refresh_count = [0] + + async def counting_fetcher( + _uri: str, + *, + if_none_match: str | None = None, + if_modified_since: str | None = None, + ) -> FetchResult: + refresh_count[0] += 1 + # Yield to force interleaving between tasks. + await asyncio.sleep(0) + return FetchResult(body=token, etag='"v1"', not_modified=False) + + wall_clock, mono_clock, _ = _controllable_clock( + datetime(2026, 4, 18, 14, 5, tzinfo=timezone.utc) + ) + checker = AsyncCachingRevocationChecker( + revocation_uri=REVOCATION_URI, + issuer=ISSUER, + jwks_resolver=resolver, + fetcher=counting_fetcher, + wall_clock=wall_clock, + clock=mono_clock, + ) + # 20 concurrent tasks all hitting the first miss. + results = await asyncio.gather(*(checker("k") for _ in range(20))) + assert all(r is True for r in results) + assert refresh_count[0] == 1 # one shared refresh for all 20 tasks + + +# -- cancellation safety ------------------------------------------------ + + +async def test_cancellation_rolls_back_cooldown_attempt() -> None: + """A cancelled refresh doesn't burn the 60s cooldown for the next caller. + + Covers the security-reviewer's cancellation-safety concern: setting + ``_last_refresh_attempt`` before the await meant a cancelled task + could block legitimate retries for up to ``MIN_POLLING_INTERVAL_SECONDS``. + The fix rolls the timestamp back on CancelledError. + """ + _, resolver = _operator_key_and_resolver() + + cancel_event = asyncio.Event() + + async def cancellable_fetcher( + _uri: str, + *, + if_none_match: str | None = None, + if_modified_since: str | None = None, + ) -> FetchResult: + # Block forever until the task is cancelled. + await cancel_event.wait() + raise AssertionError("unreachable — task must be cancelled before this") + + checker = AsyncCachingRevocationChecker( + revocation_uri=REVOCATION_URI, + issuer=ISSUER, + jwks_resolver=resolver, + fetcher=cancellable_fetcher, + ) + prior_attempt = checker._last_refresh_attempt + assert prior_attempt is None + + task = asyncio.create_task(checker("rev")) + await asyncio.sleep(0) # let the task block on the fetcher + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + # The cancelled refresh must NOT have left `_last_refresh_attempt` + # stamped — otherwise the next legitimate caller would be blocked + # by the cooldown. + assert checker._last_refresh_attempt == prior_attempt