Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions src/adcp/signing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,15 @@
SignatureVerificationError,
)
from adcp.signing.jwks import (
AsyncCachingJwksResolver,
AsyncJwksFetcher,
AsyncJwksResolver,
CachingJwksResolver,
JwksResolver,
SSRFValidationError,
StaticJwksResolver,
as_async_resolver,
async_default_jwks_fetcher,
default_jwks_fetcher,
validate_jwks_uri,
)
Expand All @@ -75,6 +80,9 @@
JwsMalformedError,
JwsSignatureInvalidError,
JwsUnknownKeyError,
averify_detached_jws,
averify_jws_document,
verify_detached_jws,
verify_jws_document,
)
from adcp.signing.middleware import (
Expand All @@ -87,12 +95,15 @@
from adcp.signing.revocation_fetcher import (
DEFAULT_GRACE_MULTIPLIER,
REVOCATION_LIST_TYP,
AsyncCachingRevocationChecker,
AsyncRevocationListFetcher,
CachingRevocationChecker,
FetchResult,
RevocationListFetcher,
RevocationListFetchError,
RevocationListFreshnessError,
RevocationListParseError,
async_default_revocation_list_fetcher,
default_revocation_list_fetcher,
)
from adcp.signing.signer import (
Expand All @@ -110,6 +121,11 @@
"ALG_ED25519",
"ALG_ES256",
"ALLOWED_ALGS",
"AsyncCachingJwksResolver",
"AsyncCachingRevocationChecker",
"AsyncJwksFetcher",
"AsyncJwksResolver",
"AsyncRevocationListFetcher",
"CachingJwksResolver",
"CachingRevocationChecker",
"DEFAULT_EXPIRES_IN_SECONDS",
Expand Down Expand Up @@ -163,6 +179,11 @@
"VerifierCapability",
"VerifyOptions",
"alg_for_jwk",
"as_async_resolver",
"async_default_jwks_fetcher",
"async_default_revocation_list_fetcher",
"averify_detached_jws",
"averify_jws_document",
"b64url_decode",
"b64url_encode",
"build_signature_base",
Expand All @@ -182,6 +203,7 @@
"sign_signature_base",
"unauthorized_response_headers",
"validate_jwks_uri",
"verify_detached_jws",
"verify_flask_request",
"verify_jws_document",
"verify_request_signature",
Expand Down
163 changes: 161 additions & 2 deletions src/adcp/signing/jwks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,19 @@
tracked in #190 and not implemented here — the current design is vulnerable to
a TOCTOU where DNS resolves to an allowed IP during validation and a blocked
IP at connect time.

Naming conventions
------------------
* Classes use the ``Async`` CapWords prefix (``AsyncCachingJwksResolver``).
* Free functions use the ``async_`` snake_case prefix
(``async_default_jwks_fetcher``).
* Methods use the ``a`` prefix (``aclose``, ``aprime``) — matches the
``httpx`` / ``anyio`` ecosystem.
"""

from __future__ import annotations

import asyncio
import ipaddress
import socket
import time
Expand Down Expand Up @@ -60,18 +69,43 @@ class JwksFetcher(Protocol):
def __call__(self, uri: str, *, allow_private: bool = False) -> dict[str, Any]: ...


class AsyncJwksFetcher(Protocol):
"""Async variant of :class:`JwksFetcher`."""

async def __call__(
self, uri: str, *, allow_private: bool = False
) -> dict[str, Any]: ...


class JwksResolver(Protocol):
"""Resolves a keyid to a JWK, or returns None if unknown.

The canonical Protocol used by both the RFC 9421 verifier and the JWS
document verifier. Implementations include
The canonical Protocol used by the sync RFC 9421 verifier and the
sync JWS document verifier. Implementations include
:class:`StaticJwksResolver` (in-memory, for tests) and
:class:`CachingJwksResolver` (fetches + caches from a URI).

Async callers use :class:`AsyncJwksResolver` instead.
"""

def __call__(self, keyid: str) -> dict[str, Any] | None: ...


class AsyncJwksResolver(Protocol):
"""Async variant of :class:`JwksResolver`.

Used by the async JWS document verifier and the async revocation
checker so JWKS cache-misses don't block the event loop.
Implementations: :class:`AsyncCachingJwksResolver`. For tests,
:class:`StaticJwksResolver` doubles as an async resolver if you wrap
it in a thin async callable — there's no async work, just a dict
lookup — but typically you'll just use the static one directly where
an :class:`AsyncJwksResolver` is expected via :func:`as_async`.
"""

async def __call__(self, keyid: str) -> dict[str, Any] | None: ...


def validate_jwks_uri(uri: str, *, allow_private: bool = False) -> None:
"""Raise SSRFValidationError if `uri` resolves to a blocked IP or has a bad scheme."""
parts = urlsplit(uri)
Expand Down Expand Up @@ -200,15 +234,140 @@ def __call__(self, keyid: str) -> dict[str, Any] | None:
return self._keys.get(keyid)


# ---------------------------------------------------------------------------
# Async variants
# ---------------------------------------------------------------------------


async def async_default_jwks_fetcher(
uri: str, *, allow_private: bool = False
) -> dict[str, Any]:
"""Async counterpart to :func:`default_jwks_fetcher`.

Uses :class:`httpx.AsyncClient` so callers on an asyncio event loop
don't block the loop on JWKS fetches. Same SSRF + follow-redirects
rules as the sync version.
"""
validate_jwks_uri(uri, allow_private=allow_private)
async with httpx.AsyncClient(
timeout=DEFAULT_JWKS_TIMEOUT_SECONDS, follow_redirects=False
) as client:
response = await client.get(uri, headers={"Accept": "application/json"})
response.raise_for_status()
body = response.json()
if not isinstance(body, dict) or "keys" not in body:
raise ValueError(f"JWKS document at {uri!r} has no 'keys' array")
return body


class AsyncCachingJwksResolver:
"""Async JWKS resolver with per-URI cache and refetch cooldown.

Identical semantics to :class:`CachingJwksResolver` — per-``kid``
cache, cooldown-gated refresh on miss, SSRF errors surface as
``request_signature_jwks_untrusted``, network errors as
``request_signature_jwks_unavailable`` — but awaitable and backed by
:class:`httpx.AsyncClient`.

Concurrency: a single :class:`asyncio.Lock` serializes refreshes so
N parallel verifying tasks all driving the first miss don't fire N
fetches.
"""

def __init__(
self,
jwks_uri: str,
*,
fetcher: AsyncJwksFetcher | None = None,
cooldown_seconds: float = DEFAULT_JWKS_COOLDOWN_SECONDS,
allow_private: bool = False,
clock: Callable[[], float] = time.monotonic,
) -> None:
self._jwks_uri = jwks_uri
self._fetcher = fetcher or async_default_jwks_fetcher
self._cooldown = cooldown_seconds
self._allow_private = allow_private
self._clock = clock
self._cache: dict[str, dict[str, Any]] = {}
self._last_attempt: float | None = None
self._primed = False
# Construct the lock eagerly. Lazy init was racy: two tasks both
# seeing ``self._lock is None`` would each construct a separate
# Lock and proceed in parallel without serialization. In Python
# 3.10+ ``asyncio.Lock()`` no longer requires a running loop at
# construction time — it binds to whatever loop is running the
# first ``async with`` call. Instances are per-loop: don't share
# across ``asyncio.run`` boundaries.
self._lock: asyncio.Lock = asyncio.Lock()

async def __call__(self, keyid: str) -> dict[str, Any] | None:
if keyid in self._cache:
return self._cache[keyid]
now = self._clock()
if not self._primed or (
self._last_attempt is not None and now - self._last_attempt >= self._cooldown
):
async with self._lock:
# Re-check after acquiring: another task may have refreshed.
if keyid in self._cache:
return self._cache[keyid]
now = self._clock()
if not self._primed or (
self._last_attempt is not None
and now - self._last_attempt >= self._cooldown
):
await self._refresh(now)
return self._cache.get(keyid)

async def _refresh(self, now: float) -> None:
self._last_attempt = now
try:
jwks = await self._fetcher(self._jwks_uri, allow_private=self._allow_private)
except SSRFValidationError as exc:
raise SignatureVerificationError(
REQUEST_SIGNATURE_JWKS_UNTRUSTED,
step=7,
message=f"JWKS URI failed SSRF check: {exc}",
) from exc
except (httpx.HTTPError, ValueError, OSError) as exc:
raise SignatureVerificationError(
REQUEST_SIGNATURE_JWKS_UNAVAILABLE,
step=7,
message=f"JWKS fetch failed: {exc}",
) from exc
self._primed = True
self._cache = {jwk["kid"]: jwk for jwk in jwks.get("keys", []) if "kid" in jwk}


def as_async_resolver(resolver: JwksResolver) -> AsyncJwksResolver:
"""Wrap a sync :class:`JwksResolver` so it satisfies :class:`AsyncJwksResolver`.

Useful for tests: pass a :class:`StaticJwksResolver` through
:func:`as_async_resolver` to plug it into an async-verifier
pipeline that types ``AsyncJwksResolver``. There's no real async
work (just a dict lookup); the wrapper is a shape adapter.
"""

async def resolve(keyid: str) -> dict[str, Any] | None:
return resolver(keyid)

return resolve


__all__ = [
"BLOCKED_METADATA_IPS",
"AsyncCachingJwksResolver",
"AsyncJwksFetcher",
"AsyncJwksResolver",
"CachingJwksResolver",
"DEFAULT_JWKS_COOLDOWN_SECONDS",
"DEFAULT_JWKS_TIMEOUT_SECONDS",
"JwksFetcher",
"JwksResolver",
"SSRFValidationError",
"StaticJwksResolver",
"as_async_resolver",
"async_default_jwks_fetcher",
"default_jwks_fetcher",
"validate_jwks_uri",
]
Loading
Loading