From 1fa082003acd317b9743099fb83dc898614db573 Mon Sep 17 00:00:00 2001 From: Brian O'Kelley Date: Sun, 19 Apr 2026 19:06:00 -0400 Subject: [PATCH 1/2] feat(signing): IP-pinned httpx transport closes DNS-rebinding TOCTOU Closes #190. See PR description for details. --- pyproject.toml | 8 + src/adcp/signing/__init__.py | 12 + src/adcp/signing/ip_pinned_transport.py | 267 +++++++++++++++++ src/adcp/signing/jwks.py | 124 ++++++-- src/adcp/signing/revocation_fetcher.py | 56 ++-- .../signing/test_ip_pinned_transport.py | 280 ++++++++++++++++++ 6 files changed, 691 insertions(+), 56 deletions(-) create mode 100644 src/adcp/signing/ip_pinned_transport.py create mode 100644 tests/conformance/signing/test_ip_pinned_transport.py diff --git a/pyproject.toml b/pyproject.toml index 71d924a67..08f4d9e9f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,14 @@ classifiers = [ dependencies = [ "httpx>=0.24.0", + # httpcore is a transitive dep of httpx, but we import from it + # directly (``adcp.signing.ip_pinned_transport`` subclasses + # ``SyncBackend`` / ``AnyIOBackend`` from ``httpcore._backends.*``). + # Pin the major so a breaking upgrade fails at install rather than + # at runtime. The contract test in + # ``tests/conformance/signing/test_ip_pinned_transport_contract.py`` + # guards the specific API shapes we rely on. + "httpcore>=1.0,<2.0", "pydantic>=2.0.0", "typing-extensions>=4.5.0", "a2a-sdk>=0.3.0", diff --git a/src/adcp/signing/__init__.py b/src/adcp/signing/__init__.py index 5b002185c..33b31aef6 100644 --- a/src/adcp/signing/__init__.py +++ b/src/adcp/signing/__init__.py @@ -101,6 +101,12 @@ REQUEST_SIGNATURE_WINDOW_INVALID, SignatureVerificationError, ) +from adcp.signing.ip_pinned_transport import ( + AsyncIpPinnedTransport, + IpPinnedTransport, + abuild_ip_pinned_transport, + build_ip_pinned_transport, +) from adcp.signing.jwks import ( AsyncCachingJwksResolver, AsyncJwksFetcher, @@ -112,6 +118,7 @@ as_async_resolver, async_default_jwks_fetcher, default_jwks_fetcher, + resolve_and_validate_host, validate_jwks_uri, ) from adcp.signing.jws import ( @@ -187,6 +194,7 @@ def __init__(self, *args: object, **kwargs: object) -> None: "ALLOWED_ALGS", "AsyncCachingJwksResolver", "AsyncCachingRevocationChecker", + "AsyncIpPinnedTransport", "AsyncJwksFetcher", "AsyncJwksResolver", "AsyncRevocationListFetcher", @@ -198,6 +206,7 @@ def __init__(self, *args: object, **kwargs: object) -> None: "DEFAULT_TAG", "FetchResult", "InMemoryReplayStore", + "IpPinnedTransport", "JwksResolver", "JwsError", "JwsMalformedError", @@ -244,6 +253,7 @@ def __init__(self, *args: object, **kwargs: object) -> None: "VerifierCapability", "VerifyOptions", "alg_for_jwk", + "abuild_ip_pinned_transport", "as_async_resolver", "async_default_jwks_fetcher", "async_default_revocation_list_fetcher", @@ -251,6 +261,7 @@ def __init__(self, *args: object, **kwargs: object) -> None: "averify_jws_document", "b64url_decode", "b64url_encode", + "build_ip_pinned_transport", "build_signature_base", "canonicalize_authority", "canonicalize_target_uri", @@ -265,6 +276,7 @@ def __init__(self, *args: object, **kwargs: object) -> None: "parse_signature_input_header", "private_key_from_jwk", "public_key_from_jwk", + "resolve_and_validate_host", "sign_request", "sign_signature_base", "unauthorized_response_headers", diff --git a/src/adcp/signing/ip_pinned_transport.py b/src/adcp/signing/ip_pinned_transport.py new file mode 100644 index 000000000..4214f7e32 --- /dev/null +++ b/src/adcp/signing/ip_pinned_transport.py @@ -0,0 +1,267 @@ +"""IP-pinned httpx transports that close the DNS-rebinding TOCTOU. + +The default signing fetchers (JWKS + revocation-list, sync + async) +resolve the target hostname via :func:`resolve_and_validate_host`, +then hand the URL back to httpx — which resolves the hostname a +second time at connect. A malicious origin with ``TTL=0`` can return +a safe IP on the first lookup (passing SSRF validation) and a +private IP or cloud-metadata address on the second. + +This module closes that gap. :func:`build_ip_pinned_transport` +resolves once, picks the first IP that passes the SSRF validator, +and returns an :class:`httpx.HTTPTransport` wired to a custom +:mod:`httpcore` network backend that translates the pinned +hostname → IP at connect time. TLS certificate validation still +runs against the original hostname (httpcore passes it separately +as ``server_hostname`` during the TLS handshake), so cert CN/SAN +matching is unaffected. + +The transport is single-host-scoped. Reusing it for a DIFFERENT +hostname would bypass the pin and either connect to the wrong IP +or fail SSRF re-resolution. Build one transport per hostname you +need to reach; the existing fetchers do this per-call. + +Naming conventions +------------------ + +* Classes use the ``Async`` CapWords prefix + (:class:`AsyncIpPinnedTransport`). +* Free functions use the ``async_``/``a`` prefix + (:func:`abuild_ip_pinned_transport`) — matches the rest of this + sub-package. + +Dependency on httpcore internals +-------------------------------- + +We reach into httpcore at two points: + +1. ``httpcore.ConnectionPool(network_backend=...)`` — public API. +2. ``httpcore._backends.sync.SyncBackend`` / + ``httpcore._backends.anyio.AnyIOBackend`` — underscore-prefixed + path, nominally private. The backend classes are the documented + default-backend implementations, and the ``network_backend`` kwarg + is the sanctioned extension point, but the stability of the + backend class names themselves isn't guaranteed. + +Mitigations: + +* ``pyproject.toml`` pins ``httpcore>=1.0,<2.0``. +* :class:`adcp.signing.ip_pinned_transport` exports the backend + signatures from a contract test that fails on import if upstream + changes them — see + ``tests/conformance/signing/test_ip_pinned_transport.py``. +""" + +from __future__ import annotations + +import ssl +from collections.abc import Iterable +from typing import TYPE_CHECKING, Any + +import httpcore +import httpx + +# Private but documented-as-the-default-backend implementations. The +# underscore prefix is a stability hazard; the contract test in +# tests/conformance/signing/test_ip_pinned_transport.py fails if the +# signatures we rely on change, so a silent upstream break becomes a +# CI failure instead of a latent security regression. +from httpcore._backends.anyio import AnyIOBackend as _AnyIOBackend +from httpcore._backends.sync import SyncBackend as _SyncBackend + +from adcp.signing.jwks import ( + DEFAULT_JWKS_TIMEOUT_SECONDS, + resolve_and_validate_host, +) + +if TYPE_CHECKING: + from httpcore._backends.base import SOCKET_OPTION + + +__all__ = [ + "AsyncIpPinnedTransport", + "IpPinnedTransport", + "abuild_ip_pinned_transport", + "build_ip_pinned_transport", +] + + +def _build_ssl_context() -> ssl.SSLContext: + """Standard cert-validating TLS context. ``check_hostname`` stays True + so the hostname-in-cert-SAN match runs against the URL's original + host (the hostname httpcore passes as ``server_hostname`` during + the handshake), not the pinned IP. + """ + return ssl.create_default_context() + + +class _IpPinnedSyncBackend(_SyncBackend): + """httpcore sync backend that connects by IP for one pinned hostname. + + Delegates to the parent's ``connect_tcp`` after swapping the + host argument from the hostname to the pre-resolved IP. All + other methods (``connect_unix_socket``) pass through unchanged. + """ + + def __init__(self, *, hostname: str, resolved_ip: str) -> None: + super().__init__() + self._hostname = hostname.lower() + self._resolved_ip = resolved_ip + + def connect_tcp( + self, + host: str, + port: int, + timeout: float | None = None, + local_address: str | None = None, + socket_options: Iterable[SOCKET_OPTION] | None = None, + ) -> Any: + if host.lower() == self._hostname: + host = self._resolved_ip + return super().connect_tcp( + host=host, + port=port, + timeout=timeout, + local_address=local_address, + socket_options=socket_options, + ) + + +class _IpPinnedAsyncBackend(_AnyIOBackend): + """Async counterpart to :class:`_IpPinnedSyncBackend`.""" + + def __init__(self, *, hostname: str, resolved_ip: str) -> None: + super().__init__() + self._hostname = hostname.lower() + self._resolved_ip = resolved_ip + + async def connect_tcp( + self, + host: str, + port: int, + timeout: float | None = None, + local_address: str | None = None, + socket_options: Iterable[SOCKET_OPTION] | None = None, + ) -> Any: + if host.lower() == self._hostname: + host = self._resolved_ip + return await super().connect_tcp( + host=host, + port=port, + timeout=timeout, + local_address=local_address, + socket_options=socket_options, + ) + + +class IpPinnedTransport(httpx.HTTPTransport): + """``httpx.HTTPTransport`` that connects by pre-resolved IP. + + Preserves normal httpx ergonomics — pass this to + ``httpx.Client(transport=...)`` and everything else works + unchanged. The TLS handshake uses the original hostname for + SNI + cert validation; only the TCP destination is rewritten. + + Construct via :func:`build_ip_pinned_transport` unless you've + already resolved the hostname yourself. + """ + + def __init__( + self, + *, + hostname: str, + resolved_ip: str, + verify: bool = True, + retries: int = 0, + ) -> None: + if verify: + ssl_context = _build_ssl_context() + else: + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + backend = _IpPinnedSyncBackend(hostname=hostname, resolved_ip=resolved_ip) + # Build the ConnectionPool ourselves (rather than super().__init__ + # and then reassign ._pool) so the TLS + backend config is set + # up atomically and we don't briefly own a vanilla pool. + self._pool = httpcore.ConnectionPool( + ssl_context=ssl_context, + network_backend=backend, + http1=True, + http2=False, + retries=retries, + ) + + +class AsyncIpPinnedTransport(httpx.AsyncHTTPTransport): + """Async counterpart to :class:`IpPinnedTransport`.""" + + def __init__( + self, + *, + hostname: str, + resolved_ip: str, + verify: bool = True, + retries: int = 0, + ) -> None: + if verify: + ssl_context = _build_ssl_context() + else: + ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + backend = _IpPinnedAsyncBackend(hostname=hostname, resolved_ip=resolved_ip) + self._pool = httpcore.AsyncConnectionPool( + ssl_context=ssl_context, + network_backend=backend, + http1=True, + http2=False, + retries=retries, + ) + + +def build_ip_pinned_transport( + uri: str, + *, + allow_private: bool = False, + verify: bool = True, +) -> IpPinnedTransport: + """Resolve ``uri`` once and return a transport pinned to the validated IP. + + Raises :class:`SSRFValidationError` if the URI's scheme isn't + ``http``/``https``, the host doesn't resolve, or every resolved + IP is in a blocked range. + + Typical use inside a fetcher:: + + transport = build_ip_pinned_transport(uri) + with httpx.Client(transport=transport, timeout=10.0) as client: + response = client.get(uri) + """ + hostname, resolved_ip, _port = resolve_and_validate_host(uri, allow_private=allow_private) + return IpPinnedTransport(hostname=hostname, resolved_ip=resolved_ip, verify=verify) + + +def abuild_ip_pinned_transport( + uri: str, + *, + allow_private: bool = False, + verify: bool = True, +) -> AsyncIpPinnedTransport: + """Async counterpart to :func:`build_ip_pinned_transport`. + + The resolution itself is synchronous (``socket.getaddrinfo``); + this function is not actually awaitable, but the prefix matches + the rest of the sub-package's naming and the returned transport + is async. + """ + hostname, resolved_ip, _port = resolve_and_validate_host(uri, allow_private=allow_private) + return AsyncIpPinnedTransport(hostname=hostname, resolved_ip=resolved_ip, verify=verify) + + +# Quiet the unused-import warning — DEFAULT_JWKS_TIMEOUT_SECONDS is +# imported for re-export convenience and callers who want to pair the +# transport with a matching timeout. +_ = DEFAULT_JWKS_TIMEOUT_SECONDS diff --git a/src/adcp/signing/jwks.py b/src/adcp/signing/jwks.py index 168473f4c..eb34d2b47 100644 --- a/src/adcp/signing/jwks.py +++ b/src/adcp/signing/jwks.py @@ -72,9 +72,7 @@ def __call__(self, uri: str, *, allow_private: bool = False) -> dict[str, Any]: class AsyncJwksFetcher(Protocol): """Async variant of :class:`JwksFetcher`.""" - async def __call__( - self, uri: str, *, allow_private: bool = False - ) -> dict[str, Any]: ... + async def __call__(self, uri: str, *, allow_private: bool = False) -> dict[str, Any]: ... class JwksResolver(Protocol): @@ -107,19 +105,67 @@ async def __call__(self, keyid: str) -> dict[str, Any] | None: ... def validate_jwks_uri(uri: str, *, allow_private: bool = False) -> None: - """Raise SSRFValidationError if `uri` resolves to a blocked IP or has a bad scheme.""" + """Raise SSRFValidationError if `uri` resolves to a blocked IP or has a bad scheme. + + This is kept as a standalone no-return helper for callers that only + want validation — :func:`resolve_and_validate_host` returns the + accepted IP when the caller needs it for IP-pinned connects. + """ + resolve_and_validate_host(uri, allow_private=allow_private) + + +def resolve_and_validate_host( + uri: str, + *, + allow_private: bool = False, +) -> tuple[str, str, int]: + """Resolve the URI's hostname once and return ``(hostname, ip, port)``. + + Runs the full SSRF validation — reserved-range rejection + cloud- + metadata blocklist — and returns the first IP that passes. Callers + that connect by IP (see :class:`adcp.signing.IpPinnedTransport`) + use the returned IP to close the DNS-rebinding TOCTOU: they resolve + ONCE through this helper, then pin subsequent connects to that IP. + + The returned IP is always ASCII (no IPv6 scope id, no IPv4-mapped + IPv6 wrapping) so it can be handed verbatim to + :func:`socket.create_connection`. + + Parameters + ---------- + uri: + A full URL. Only ``http`` and ``https`` schemes are accepted. + allow_private: + Skip the reserved-range check. For tests only; cloud-metadata + IPs remain blocked unconditionally. + + Returns + ------- + tuple[str, str, int] + ``(hostname, ip, port)`` — hostname and port are parsed from + the URI; IP is the validated resolution. + + Raises + ------ + SSRFValidationError + Scheme is not ``http``/``https``, the hostname doesn't resolve, + or every resolved IP is in a blocked range. + """ parts = urlsplit(uri) if parts.scheme not in ("http", "https"): raise SSRFValidationError(f"unsupported scheme for JWKS URI: {parts.scheme!r}") host = parts.hostname if host is None or host == "": raise SSRFValidationError("JWKS URI has no host") + port = parts.port if parts.port is not None else (443 if parts.scheme == "https" else 80) try: infos = socket.getaddrinfo(host, None) except OSError as exc: raise SSRFValidationError(f"cannot resolve host {host!r}: {exc}") from exc + accepted_ip: str | None = None + last_rejection: str | None = None for _family, _, _, _, sockaddr in infos[:_MAX_RESOLVED_ADDRESSES]: ip_raw = sockaddr[0] ip_str = str(ip_raw) @@ -135,9 +181,7 @@ def validate_jwks_uri(uri: str, *, allow_private: bool = False) -> None: ip = ip.ipv4_mapped if str(ip) in BLOCKED_METADATA_IPS: raise SSRFValidationError(f"cloud metadata IP {ip} blocked") - if allow_private: - continue - if ( + if not allow_private and ( ip.is_private or ip.is_loopback or ip.is_link_local @@ -145,16 +189,47 @@ def validate_jwks_uri(uri: str, *, allow_private: bool = False) -> None: or ip.is_reserved or ip.is_unspecified ): - raise SSRFValidationError(f"resolved IP {ip} is in a reserved range") + last_rejection = f"resolved IP {ip} is in a reserved range" + # Historical behavior of validate_jwks_uri was to raise on + # ANY reserved IP in the result list, not to skip-and-try- + # the-next-one. Preserve that: reject immediately so a host + # with mixed public + private results doesn't silently pin + # the public one. + raise SSRFValidationError(last_rejection) + if accepted_ip is None: + accepted_ip = str(ip) + + if accepted_ip is None: + # Shouldn't happen — getaddrinfo with results + no raise means + # at least one entry passed. Belt-and-braces. + raise SSRFValidationError( + f"host {host!r} resolved but no usable IP ({last_rejection or 'unknown'})" + ) + return host, accepted_ip, port def default_jwks_fetcher(uri: str, *, allow_private: bool = False) -> dict[str, Any]: - """Validate the URI against SSRF rules, then GET the JWKS document.""" - validate_jwks_uri(uri, allow_private=allow_private) + """Validate + resolve the URI once, then GET the JWKS over an IP-pinned + transport. + + Pinning closes the DNS-rebinding TOCTOU that would otherwise let a + ``TTL=0`` attacker pass SSRF validation with one IP and connect to + a different one. See :mod:`adcp.signing.ip_pinned_transport`. + """ + # Import lazily to avoid a module-load cycle with the transport module + # (which imports from this file). + from adcp.signing.ip_pinned_transport import build_ip_pinned_transport + + transport = build_ip_pinned_transport(uri, allow_private=allow_private) # follow_redirects=False is explicit: httpx already defaults to no-follow, - # but an attacker controlling the JWKS origin could 302 us to an IP that - # `validate_jwks_uri` already cleared. - with httpx.Client(timeout=DEFAULT_JWKS_TIMEOUT_SECONDS, follow_redirects=False) as client: + # but an attacker controlling the JWKS origin could 302 us to a + # hostname our pinned transport doesn't cover, re-introducing the + # TOCTOU. Keep redirects off. + with httpx.Client( + transport=transport, + timeout=DEFAULT_JWKS_TIMEOUT_SECONDS, + follow_redirects=False, + ) as client: response = client.get(uri, headers={"Accept": "application/json"}) response.raise_for_status() body = response.json() @@ -239,18 +314,21 @@ def __call__(self, keyid: str) -> dict[str, Any] | None: # --------------------------------------------------------------------------- -async def async_default_jwks_fetcher( - uri: str, *, allow_private: bool = False -) -> dict[str, Any]: +async def async_default_jwks_fetcher(uri: str, *, allow_private: bool = False) -> dict[str, Any]: """Async counterpart to :func:`default_jwks_fetcher`. - Uses :class:`httpx.AsyncClient` so callers on an asyncio event loop - don't block the loop on JWKS fetches. Same SSRF + follow-redirects - rules as the sync version. + Uses :class:`httpx.AsyncClient` with an IP-pinned transport so + callers on an asyncio event loop don't block the loop on JWKS + fetches AND the DNS-rebinding TOCTOU stays closed. Same SSRF + + follow-redirects rules as the sync version. """ - validate_jwks_uri(uri, allow_private=allow_private) + from adcp.signing.ip_pinned_transport import abuild_ip_pinned_transport + + transport = abuild_ip_pinned_transport(uri, allow_private=allow_private) async with httpx.AsyncClient( - timeout=DEFAULT_JWKS_TIMEOUT_SECONDS, follow_redirects=False + transport=transport, + timeout=DEFAULT_JWKS_TIMEOUT_SECONDS, + follow_redirects=False, ) as client: response = await client.get(uri, headers={"Accept": "application/json"}) response.raise_for_status() @@ -313,8 +391,7 @@ async def __call__(self, keyid: str) -> dict[str, Any] | None: return self._cache[keyid] now = self._clock() if not self._primed or ( - self._last_attempt is not None - and now - self._last_attempt >= self._cooldown + self._last_attempt is not None and now - self._last_attempt >= self._cooldown ): await self._refresh(now) return self._cache.get(keyid) @@ -369,5 +446,6 @@ async def resolve(keyid: str) -> dict[str, Any] | None: "as_async_resolver", "async_default_jwks_fetcher", "default_jwks_fetcher", + "resolve_and_validate_host", "validate_jwks_uri", ] diff --git a/src/adcp/signing/revocation_fetcher.py b/src/adcp/signing/revocation_fetcher.py index 8593c7c9c..ede0d1e59 100644 --- a/src/adcp/signing/revocation_fetcher.py +++ b/src/adcp/signing/revocation_fetcher.py @@ -47,7 +47,6 @@ DEFAULT_JWKS_TIMEOUT_SECONDS, AsyncJwksResolver, JwksResolver, - validate_jwks_uri, ) from adcp.signing.jws import ( JwsError, @@ -210,9 +209,7 @@ def _fetch_result_from_response( not_modified=True, ) if status_code != 200: - raise RevocationListFetchError( - f"revocation list {uri!r} returned HTTP {status_code}" - ) + raise RevocationListFetchError(f"revocation list {uri!r} returned HTTP {status_code}") etag = response_headers.get("ETag") last_modified = _sanitize_last_modified(response_headers.get("Last-Modified")) @@ -250,20 +247,20 @@ def default_revocation_list_fetcher( ) -> FetchResult: """HTTPS GET the revocation list, honoring SSRF rules and conditional requests. - Reuses ``validate_jwks_uri`` — the SSRF controls are identical (same - reserved-range rejection, same cloud-metadata block). ``httpx`` - re-resolves the hostname on connect, which is the TOCTOU window - tracked separately in #190. Sends ``If-None-Match`` when an ETag is + Reuses the JWKS SSRF controls (same reserved-range rejection, same + cloud-metadata block) via an IP-pinned transport so the hostname + is resolved once and connections target that validated IP — no + DNS-rebinding TOCTOU. Sends ``If-None-Match`` when an ETag is supplied and ``If-Modified-Since`` when a ``Last-Modified`` is supplied; the spec accepts either (sellers SHOULD use both when available). """ - validate_jwks_uri(uri, allow_private=allow_private) - headers = _build_fetch_headers( - if_none_match=if_none_match, if_modified_since=if_modified_since - ) + from adcp.signing.ip_pinned_transport import build_ip_pinned_transport + + transport = build_ip_pinned_transport(uri, allow_private=allow_private) + headers = _build_fetch_headers(if_none_match=if_none_match, if_modified_since=if_modified_since) try: - with httpx.Client(timeout=timeout, follow_redirects=False) as client: + with httpx.Client(transport=transport, timeout=timeout, follow_redirects=False) as client: response = client.get(uri, headers=headers) except httpx.HTTPError as exc: raise RevocationListFetchError(f"revocation list GET {uri!r} failed: {exc}") from exc @@ -288,16 +285,18 @@ async def async_default_revocation_list_fetcher( ) -> FetchResult: """Async counterpart to :func:`default_revocation_list_fetcher`. - Same SSRF + conditional-request behavior, but uses + Same SSRF + IP-pinned + conditional-request behavior, but uses :class:`httpx.AsyncClient` so the event loop isn't blocked during the round-trip. """ - validate_jwks_uri(uri, allow_private=allow_private) - headers = _build_fetch_headers( - if_none_match=if_none_match, if_modified_since=if_modified_since - ) + from adcp.signing.ip_pinned_transport import abuild_ip_pinned_transport + + transport = abuild_ip_pinned_transport(uri, allow_private=allow_private) + headers = _build_fetch_headers(if_none_match=if_none_match, if_modified_since=if_modified_since) try: - async with httpx.AsyncClient(timeout=timeout, follow_redirects=False) as client: + async with httpx.AsyncClient( + transport=transport, timeout=timeout, follow_redirects=False + ) as client: response = await client.get(uri, headers=headers) except httpx.HTTPError as exc: raise RevocationListFetchError(f"revocation list GET {uri!r} failed: {exc}") from exc @@ -379,9 +378,7 @@ def _normalize_issuer(issuer: str) -> str: return urlunsplit((scheme, netloc, "", "", "")) -def _slide_next_update( - current: RevocationList, polling_interval_seconds: float -) -> RevocationList: +def _slide_next_update(current: RevocationList, polling_interval_seconds: float) -> RevocationList: """Return ``current`` with ``next_update`` advanced by one polling interval. Used on a 304 response so the cached list's freshness window slides @@ -881,9 +878,7 @@ def __init__( self._revocation_uri = revocation_uri self._issuer = _normalize_issuer(issuer) self._jwks_resolver = jwks_resolver - self._fetcher: AsyncRevocationListFetcher = ( - fetcher or async_default_revocation_list_fetcher - ) + self._fetcher: AsyncRevocationListFetcher = fetcher or async_default_revocation_list_fetcher self._grace_multiplier = grace_multiplier self._clock = clock self._wall_clock = wall_clock @@ -953,9 +948,7 @@ async def _ensure_fresh(self) -> None: # pre-lock clocks could be stale. now_wall = self._wall_clock() now_mono = self._clock() - await self._refresh( - conditional=False, now_wall=now_wall, now_mono=now_mono - ) + await self._refresh(conditional=False, now_wall=now_wall, now_mono=now_mono) return next_update = _parse_iso8601(self._current_list.next_update) @@ -975,8 +968,7 @@ async def _ensure_fresh(self) -> None: # Re-check under the lock with fresh clock reads. now_mono_inside = self._clock() if self._last_refresh_attempt is None or ( - now_mono_inside - self._last_refresh_attempt - >= MIN_POLLING_INTERVAL_SECONDS + now_mono_inside - self._last_refresh_attempt >= MIN_POLLING_INTERVAL_SECONDS ): now_wall_inside = self._wall_clock() await self._refresh( @@ -1001,9 +993,7 @@ async def _ensure_fresh(self) -> None: f"last refresh error: {last_exc}" ) from last_exc - async def _refresh( - self, *, conditional: bool, now_wall: datetime, now_mono: float - ) -> None: + async def _refresh(self, *, conditional: bool, now_wall: datetime, now_mono: float) -> None: # Stamp the attempt BEFORE the awaitable. On CancelledError the # finally block rolls it back so a cancelled task doesn't burn # the 60s cooldown for the next caller — non-cancellation diff --git a/tests/conformance/signing/test_ip_pinned_transport.py b/tests/conformance/signing/test_ip_pinned_transport.py new file mode 100644 index 000000000..46706956b --- /dev/null +++ b/tests/conformance/signing/test_ip_pinned_transport.py @@ -0,0 +1,280 @@ +"""Tests for :mod:`adcp.signing.ip_pinned_transport`. + +Three kinds of coverage: + +1. **Contract tests** — fail fast if httpcore's private-backend API + shifts shape between versions. These protect against silent upstream + breakage of the ``_backends`` subpackage we reach into. +2. **Rebinding simulation** — monkey-patch :func:`socket.getaddrinfo` + so the first resolution returns a safe IP and a hypothetical + second resolution would return a dangerous one. The pinned + transport MUST connect to the first IP, and the second resolution + MUST never happen. +3. **SSRF integration** — the fetchers defer resolution to + :func:`resolve_and_validate_host`; reserved-range and + cloud-metadata IPs still reject at construction. +""" + +from __future__ import annotations + +import inspect +import socket +from unittest.mock import patch + +import httpcore +import pytest +from httpcore._backends.anyio import AnyIOBackend # type: ignore[attr-defined] +from httpcore._backends.sync import SyncBackend # type: ignore[attr-defined] + +from adcp.signing import ( + AsyncIpPinnedTransport, + IpPinnedTransport, + SSRFValidationError, + abuild_ip_pinned_transport, + build_ip_pinned_transport, + resolve_and_validate_host, +) + +# -- contract tests --------------------------------------------------- + + +def test_httpcore_sync_backend_connect_tcp_signature_unchanged() -> None: + """If httpcore changes ``SyncBackend.connect_tcp``, the pinned + transport silently breaks. This test fails fast on upgrade so we + notice during CI, not during a real rebinding attempt. + """ + sig = inspect.signature(SyncBackend.connect_tcp) + # Required positional: host, port. Then the rest are kwargs with + # defaults. If any of these vanish, override becomes wrong. + params = list(sig.parameters) + assert params[0] == "self" + assert params[1] == "host" + assert params[2] == "port" + assert "timeout" in params + assert "local_address" in params + assert "socket_options" in params + + +def test_httpcore_async_backend_connect_tcp_signature_unchanged() -> None: + sig = inspect.signature(AnyIOBackend.connect_tcp) + params = list(sig.parameters) + assert params[0] == "self" + assert params[1] == "host" + assert params[2] == "port" + assert "timeout" in params + assert "local_address" in params + assert "socket_options" in params + + +def test_httpcore_connection_pool_accepts_network_backend() -> None: + """``ConnectionPool(network_backend=...)`` is the public extension + point we rely on.""" + sig = inspect.signature(httpcore.ConnectionPool.__init__) + assert "network_backend" in sig.parameters + sig_async = inspect.signature(httpcore.AsyncConnectionPool.__init__) + assert "network_backend" in sig_async.parameters + + +# -- resolve_and_validate_host --------------------------------------- + + +def test_resolve_returns_tuple_of_host_ip_port() -> None: + host, ip, port = resolve_and_validate_host("https://example.com/jwks") + assert host == "example.com" + assert port == 443 + # Accepted IP is a string form, not a wrapped ipaddress object. + assert isinstance(ip, str) + # example.com resolves publicly; we just check the ip isn't private. + import ipaddress + + parsed = ipaddress.ip_address(ip) + assert not parsed.is_private + assert not parsed.is_loopback + + +def test_resolve_defaults_http_port_80() -> None: + # Even though we normally refuse non-https elsewhere, the helper + # itself is scheme-agnostic for the port default. + host, _ip, port = resolve_and_validate_host("http://example.com/jwks") + assert host == "example.com" + assert port == 80 + + +def test_resolve_rejects_non_http_scheme() -> None: + with pytest.raises(SSRFValidationError, match="scheme"): + resolve_and_validate_host("ftp://example.com/jwks") + + +def test_resolve_rejects_private_result_without_allow_private() -> None: + # Simulate getaddrinfo returning a private IP — a rebinding + # attacker's payload. + def fake_getaddrinfo(_host, _port, *_args, **_kwargs): + return [(socket.AF_INET, 0, 0, "", ("10.0.0.1", 0))] + + with patch("adcp.signing.jwks.socket.getaddrinfo", side_effect=fake_getaddrinfo): + with pytest.raises(SSRFValidationError, match="reserved range"): + resolve_and_validate_host("https://example.com/") + + +def test_resolve_rejects_cloud_metadata_ip_even_with_allow_private() -> None: + """Cloud metadata IPs are blocked unconditionally — not even + ``allow_private=True`` unlocks them.""" + + def fake_getaddrinfo(_host, _port, *_args, **_kwargs): + return [(socket.AF_INET, 0, 0, "", ("169.254.169.254", 0))] + + with patch("adcp.signing.jwks.socket.getaddrinfo", side_effect=fake_getaddrinfo): + with pytest.raises(SSRFValidationError, match="metadata"): + resolve_and_validate_host("https://example.com/", allow_private=True) + + +# -- rebinding simulation -------------------------------------------- + + +def test_transport_pins_first_resolution_against_rebinding() -> None: + """Attacker scenario: TTL=0 DNS returns a safe IP first (passes + validation), then returns a metadata IP on the second resolution. + The transport MUST ignore the second resolution and connect to + the first IP. + """ + call_count = {"n": 0} + + def fake_getaddrinfo(_host, _port, *_args, **_kwargs): + call_count["n"] += 1 + if call_count["n"] == 1: + # Safe public IP — passes validation. + return [(socket.AF_INET, 0, 0, "", ("8.8.8.8", 0))] + # Any subsequent lookup would return cloud metadata. + return [(socket.AF_INET, 0, 0, "", ("169.254.169.254", 0))] + + with patch("adcp.signing.jwks.socket.getaddrinfo", side_effect=fake_getaddrinfo): + transport = build_ip_pinned_transport("https://attacker.example/") + + # One resolution happened during build; nothing else is allowed. + assert call_count["n"] == 1 + + # Inspect the backend the transport installed: the pinned IP must + # be the first resolution, and the hostname match must be + # case-insensitive for safety. + pool = transport._pool + backend = pool._network_backend # type: ignore[attr-defined] + assert backend._resolved_ip == "8.8.8.8" + assert backend._hostname == "attacker.example" + + +def test_async_transport_pins_first_resolution_against_rebinding() -> None: + call_count = {"n": 0} + + def fake_getaddrinfo(_host, _port, *_args, **_kwargs): + call_count["n"] += 1 + return [(socket.AF_INET, 0, 0, "", ("1.1.1.1", 0))] + + with patch("adcp.signing.jwks.socket.getaddrinfo", side_effect=fake_getaddrinfo): + transport = abuild_ip_pinned_transport("https://attacker.example/") + + assert call_count["n"] == 1 + pool = transport._pool + backend = pool._network_backend # type: ignore[attr-defined] + assert backend._resolved_ip == "1.1.1.1" + assert backend._hostname == "attacker.example" + + +def test_backend_connect_tcp_swaps_hostname_for_pinned_ip() -> None: + """Directly test the backend's override — ``connect_tcp`` with the + pinned hostname calls the parent with the resolved IP instead. + """ + from adcp.signing.ip_pinned_transport import _IpPinnedSyncBackend + + backend = _IpPinnedSyncBackend(hostname="attacker.example", resolved_ip="198.51.100.30") + + captured = {} + + def _fake_parent_connect(self, *, host, port, timeout, local_address, socket_options): + captured["host"] = host + captured["port"] = port + return object() # stand-in for a NetworkStream + + with patch.object(SyncBackend, "connect_tcp", _fake_parent_connect): + backend.connect_tcp(host="attacker.example", port=443) + + assert captured["host"] == "198.51.100.30" + assert captured["port"] == 443 + + +def test_backend_connect_tcp_leaves_other_hosts_unchanged() -> None: + """If some code path reuses the transport for a DIFFERENT host + (misuse), the backend MUST NOT silently route it to the pinned IP. + """ + from adcp.signing.ip_pinned_transport import _IpPinnedSyncBackend + + backend = _IpPinnedSyncBackend(hostname="attacker.example", resolved_ip="198.51.100.40") + captured = {} + + def _fake_parent_connect(self, *, host, port, timeout, local_address, socket_options): + captured["host"] = host + return object() + + with patch.object(SyncBackend, "connect_tcp", _fake_parent_connect): + backend.connect_tcp(host="other.example", port=443) + + assert captured["host"] == "other.example" + + +def test_backend_hostname_match_is_case_insensitive() -> None: + from adcp.signing.ip_pinned_transport import _IpPinnedSyncBackend + + backend = _IpPinnedSyncBackend(hostname="Attacker.Example", resolved_ip="198.51.100.50") + captured = {} + + def _fake_parent_connect(self, *, host, port, timeout, local_address, socket_options): + captured["host"] = host + return object() + + with patch.object(SyncBackend, "connect_tcp", _fake_parent_connect): + backend.connect_tcp(host="ATTACKER.example", port=443) + + assert captured["host"] == "198.51.100.50" + + +# -- real-network smoke (optional, skipped if no internet) ------------ + + +def _internet_ok() -> bool: + try: + socket.create_connection(("1.1.1.1", 443), timeout=2).close() + return True + except OSError: + return False + + +@pytest.mark.skipif(not _internet_ok(), reason="no outbound internet") +def test_real_tls_handshake_still_validates_hostname() -> None: + """End-to-end sanity: with the pinned transport, TLS cert + validation still runs against the hostname (not the IP). A + successful GET against a public HTTPS endpoint proves the TLS + SNI + cert validation paths are intact. + """ + import httpx + + transport = build_ip_pinned_transport("https://example.com/") + # Generous timeout — this test is inherently network-dependent and + # real TLS handshakes occasionally slow-run on constrained CI + # machines. The intent is "handshake didn't reject", not speed. + with httpx.Client(transport=transport, timeout=60.0) as client: + response = client.get("https://example.com/") + assert response.status_code == 200 + + +@pytest.mark.skipif(not _internet_ok(), reason="no outbound internet") +def test_transport_type_is_httpx_httptransport() -> None: + """The returned transport IS an httpx.HTTPTransport instance so + callers can use it with httpx.Client without type-gymnastics.""" + import httpx + + transport = build_ip_pinned_transport("https://example.com/") + assert isinstance(transport, httpx.HTTPTransport) + assert isinstance(transport, IpPinnedTransport) + + atransport = abuild_ip_pinned_transport("https://example.com/") + assert isinstance(atransport, httpx.AsyncHTTPTransport) + assert isinstance(atransport, AsyncIpPinnedTransport) From b432b03003c5cf2d1afa393d7dd5d714e0785405 Mon Sep 17 00:00:00 2001 From: Brian O'Kelley Date: Sun, 19 Apr 2026 20:07:06 -0400 Subject: [PATCH 2/2] =?UTF-8?q?fix(signing):=20PR=20#206=20reviewer=20fixe?= =?UTF-8?q?s=20=E2=80=94=20IDN=20normalization,=20HTTPS=5FPROXY,=20fail-cl?= =?UTF-8?q?osed=20reuse?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Applies 8 SHOULD-FIX items + nits from code-reviewer, security-reviewer, and dx-expert walkthroughs of the IP-pinned httpx transport. Security: - IDN/punycode normalization in resolve_and_validate_host and _normalize_pin_host closes a bypass where a Unicode hostname would resolve successfully but the backend pin match would fail, silently falling through to an unpinned connect_tcp. - trust_env=False on all four default fetchers (sync+async JWKS, sync+async revocation) blocks HTTPS_PROXY from reopening the TOCTOU by routing through a CONNECT proxy that re-resolves the host. DX: - Backend fails closed on wrong-host reuse. Reusing a transport for a different hostname now raises RuntimeError instead of silently bypassing the pin — agents were observed caching one transport per base URL and reusing it across hosts. - Generic error wording ("SSRF-validated fetch" instead of "JWKS URI") since the helper is shared with revocation fetchers and custom transports. - Rename abuild_ip_pinned_transport → build_async_ip_pinned_transport; the factory is synchronous, only the returned transport is async. Legacy alias emits DeprecationWarning for one release. - verify=False now warns at construction. - Quickstart gains a "Custom fetchers" section pointing at the two public helpers. Implementation: - Explicit max_connections=100, max_keepalive_connections=20 on ConnectionPool to match httpx defaults (httpcore's 10/_ would be a surprise downgrade). - New tests cover IDN punycode, trailing-dot FQDN, wrong-host refusal, and the deprecation alias. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/adcp/signing/__init__.py | 15 ++ src/adcp/signing/ip_pinned_transport.py | 135 ++++++++++++++---- src/adcp/signing/jwks.py | 40 ++++-- src/adcp/signing/revocation_fetcher.py | 19 ++- .../signing/test_ip_pinned_transport.py | 103 +++++++++++-- 5 files changed, 264 insertions(+), 48 deletions(-) diff --git a/src/adcp/signing/__init__.py b/src/adcp/signing/__init__.py index 33b31aef6..70ef2bf5b 100644 --- a/src/adcp/signing/__init__.py +++ b/src/adcp/signing/__init__.py @@ -40,6 +40,19 @@ revocation list from ``{issuer}/.well-known/governance-revocations.json`` * Async variants: :class:`AsyncCachingJwksResolver`, :class:`AsyncCachingRevocationChecker` + +**Custom fetchers** (rolling your own JWKS / revocation transport): + +* :func:`build_ip_pinned_transport` / + :func:`build_async_ip_pinned_transport` — returns an + :class:`httpx.HTTPTransport` wired to resolve the URI's host once + (with SSRF validation) and pin subsequent connects to that IP. + Closes the DNS-rebinding TOCTOU for anything built on + :class:`httpx.Client`. +* :func:`resolve_and_validate_host` — returns ``(host, ip, port)``; + same SSRF rules as :func:`validate_jwks_uri`. Use this if you're + wiring your own transport and only need the resolved + validated + IP. """ from __future__ import annotations @@ -105,6 +118,7 @@ AsyncIpPinnedTransport, IpPinnedTransport, abuild_ip_pinned_transport, + build_async_ip_pinned_transport, build_ip_pinned_transport, ) from adcp.signing.jwks import ( @@ -261,6 +275,7 @@ def __init__(self, *args: object, **kwargs: object) -> None: "averify_jws_document", "b64url_decode", "b64url_encode", + "build_async_ip_pinned_transport", "build_ip_pinned_transport", "build_signature_base", "canonicalize_authority", diff --git a/src/adcp/signing/ip_pinned_transport.py b/src/adcp/signing/ip_pinned_transport.py index 4214f7e32..0a8df62d7 100644 --- a/src/adcp/signing/ip_pinned_transport.py +++ b/src/adcp/signing/ip_pinned_transport.py @@ -26,9 +26,11 @@ * Classes use the ``Async`` CapWords prefix (:class:`AsyncIpPinnedTransport`). -* Free functions use the ``async_``/``a`` prefix - (:func:`abuild_ip_pinned_transport`) — matches the rest of this - sub-package. +* Factory functions that BUILD an async transport use + ``build_async_*`` (:func:`build_async_ip_pinned_transport`). The + factory itself is synchronous — it returns an async transport. +* The legacy ``abuild_*`` alias remains for backward-compatibility + but is deprecated. Dependency on httpcore internals -------------------------------- @@ -55,6 +57,7 @@ from __future__ import annotations import ssl +import warnings from collections.abc import Iterable from typing import TYPE_CHECKING, Any @@ -69,10 +72,7 @@ from httpcore._backends.anyio import AnyIOBackend as _AnyIOBackend from httpcore._backends.sync import SyncBackend as _SyncBackend -from adcp.signing.jwks import ( - DEFAULT_JWKS_TIMEOUT_SECONDS, - resolve_and_validate_host, -) +from adcp.signing.jwks import resolve_and_validate_host if TYPE_CHECKING: from httpcore._backends.base import SOCKET_OPTION @@ -81,7 +81,8 @@ __all__ = [ "AsyncIpPinnedTransport", "IpPinnedTransport", - "abuild_ip_pinned_transport", + "abuild_ip_pinned_transport", # deprecated alias; remove next release + "build_async_ip_pinned_transport", "build_ip_pinned_transport", ] @@ -95,17 +96,42 @@ def _build_ssl_context() -> ssl.SSLContext: return ssl.create_default_context() +def _normalize_pin_host(host: str) -> str: + """Normalize a hostname for byte-equal comparison. + + Lowercases, strips a single trailing dot, and IDNA-encodes so + Unicode hostnames compare equal to the punycode form httpx + passes to httpcore. + """ + host = host.lower() + if host.endswith("."): + host = host[:-1] + try: + return host.encode("idna").decode("ascii") + except (UnicodeError, UnicodeEncodeError): + # Caller already stored the normalized form; fall through + # with the lowercased input so the comparison just fails + # cleanly instead of raising inside connect_tcp. + return host + + class _IpPinnedSyncBackend(_SyncBackend): """httpcore sync backend that connects by IP for one pinned hostname. Delegates to the parent's ``connect_tcp`` after swapping the host argument from the hostname to the pre-resolved IP. All other methods (``connect_unix_socket``) pass through unchanged. + + **Fails closed on wrong-host reuse.** If the caller reuses this + transport for a DIFFERENT hostname (stored in a dict keyed by + origin, for example), we raise instead of falling through to an + unpinned ``connect_tcp`` — that fall-through is exactly the + TOCTOU the pin exists to close. Build a new transport per host. """ def __init__(self, *, hostname: str, resolved_ip: str) -> None: super().__init__() - self._hostname = hostname.lower() + self._hostname = _normalize_pin_host(hostname) self._resolved_ip = resolved_ip def connect_tcp( @@ -116,10 +142,15 @@ def connect_tcp( local_address: str | None = None, socket_options: Iterable[SOCKET_OPTION] | None = None, ) -> Any: - if host.lower() == self._hostname: - host = self._resolved_ip + normalized = _normalize_pin_host(host) + if normalized != self._hostname: + raise RuntimeError( + f"IpPinnedTransport is pinned to {self._hostname!r}; " + f"refusing connect to {host!r} — build a new transport per host " + f"(see build_ip_pinned_transport)" + ) return super().connect_tcp( - host=host, + host=self._resolved_ip, port=port, timeout=timeout, local_address=local_address, @@ -128,11 +159,15 @@ def connect_tcp( class _IpPinnedAsyncBackend(_AnyIOBackend): - """Async counterpart to :class:`_IpPinnedSyncBackend`.""" + """Async counterpart to :class:`_IpPinnedSyncBackend`. + + See :class:`_IpPinnedSyncBackend` for the fail-closed contract + on wrong-host reuse. + """ def __init__(self, *, hostname: str, resolved_ip: str) -> None: super().__init__() - self._hostname = hostname.lower() + self._hostname = _normalize_pin_host(hostname) self._resolved_ip = resolved_ip async def connect_tcp( @@ -143,10 +178,15 @@ async def connect_tcp( local_address: str | None = None, socket_options: Iterable[SOCKET_OPTION] | None = None, ) -> Any: - if host.lower() == self._hostname: - host = self._resolved_ip + normalized = _normalize_pin_host(host) + if normalized != self._hostname: + raise RuntimeError( + f"AsyncIpPinnedTransport is pinned to {self._hostname!r}; " + f"refusing connect to {host!r} — build a new transport per host " + f"(see abuild_ip_pinned_transport)" + ) return await super().connect_tcp( - host=host, + host=self._resolved_ip, port=port, timeout=timeout, local_address=local_address, @@ -173,10 +213,18 @@ def __init__( resolved_ip: str, verify: bool = True, retries: int = 0, + max_connections: int | None = 100, + max_keepalive_connections: int | None = 20, ) -> None: if verify: ssl_context = _build_ssl_context() else: + warnings.warn( + "IpPinnedTransport constructed with verify=False — TLS cert " + "validation is disabled. Use only for tests against local " + "origins; NEVER in production.", + stacklevel=2, + ) ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE @@ -184,13 +232,18 @@ def __init__( backend = _IpPinnedSyncBackend(hostname=hostname, resolved_ip=resolved_ip) # Build the ConnectionPool ourselves (rather than super().__init__ # and then reassign ._pool) so the TLS + backend config is set - # up atomically and we don't briefly own a vanilla pool. + # up atomically and we don't briefly own a vanilla pool. Match + # httpx's default connection limits explicitly — httpcore's + # ConnectionPool default is 10/_ which would be a surprise + # downgrade for callers who expect httpx-shaped pool sizing. self._pool = httpcore.ConnectionPool( ssl_context=ssl_context, network_backend=backend, http1=True, http2=False, retries=retries, + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, ) @@ -204,10 +257,18 @@ def __init__( resolved_ip: str, verify: bool = True, retries: int = 0, + max_connections: int | None = 100, + max_keepalive_connections: int | None = 20, ) -> None: if verify: ssl_context = _build_ssl_context() else: + warnings.warn( + "AsyncIpPinnedTransport constructed with verify=False — TLS " + "cert validation is disabled. Use only for tests against " + "local origins; NEVER in production.", + stacklevel=2, + ) ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) ssl_context.check_hostname = False ssl_context.verify_mode = ssl.CERT_NONE @@ -219,6 +280,8 @@ def __init__( http1=True, http2=False, retries=retries, + max_connections=max_connections, + max_keepalive_connections=max_keepalive_connections, ) @@ -244,24 +307,40 @@ def build_ip_pinned_transport( return IpPinnedTransport(hostname=hostname, resolved_ip=resolved_ip, verify=verify) -def abuild_ip_pinned_transport( +def build_async_ip_pinned_transport( uri: str, *, allow_private: bool = False, verify: bool = True, ) -> AsyncIpPinnedTransport: - """Async counterpart to :func:`build_ip_pinned_transport`. + """Build an :class:`AsyncIpPinnedTransport` for ``uri``. - The resolution itself is synchronous (``socket.getaddrinfo``); - this function is not actually awaitable, but the prefix matches - the rest of the sub-package's naming and the returned transport - is async. + Resolve + validate run synchronously (``socket.getaddrinfo``); this + function itself is not awaitable. The returned transport plugs + into :class:`httpx.AsyncClient`. """ hostname, resolved_ip, _port = resolve_and_validate_host(uri, allow_private=allow_private) return AsyncIpPinnedTransport(hostname=hostname, resolved_ip=resolved_ip, verify=verify) -# Quiet the unused-import warning — DEFAULT_JWKS_TIMEOUT_SECONDS is -# imported for re-export convenience and callers who want to pair the -# transport with a matching timeout. -_ = DEFAULT_JWKS_TIMEOUT_SECONDS +def abuild_ip_pinned_transport( + uri: str, + *, + allow_private: bool = False, + verify: bool = True, +) -> AsyncIpPinnedTransport: + """Deprecated alias for :func:`build_async_ip_pinned_transport`. + + The ``a``-prefix convention in this package means "awaitable + coroutine" (``averify_detached_jws`` etc.) — but this factory is + synchronous. Renamed during PR #206 review; kept for one release + so downstream callers have time to migrate. + """ + warnings.warn( + "abuild_ip_pinned_transport is deprecated; use " + "build_async_ip_pinned_transport (factory is sync, returns " + "an AsyncIpPinnedTransport).", + DeprecationWarning, + stacklevel=2, + ) + return build_async_ip_pinned_transport(uri, allow_private=allow_private, verify=verify) diff --git a/src/adcp/signing/jwks.py b/src/adcp/signing/jwks.py index eb34d2b47..79fcbc9c3 100644 --- a/src/adcp/signing/jwks.py +++ b/src/adcp/signing/jwks.py @@ -153,10 +153,29 @@ def resolve_and_validate_host( """ parts = urlsplit(uri) if parts.scheme not in ("http", "https"): - raise SSRFValidationError(f"unsupported scheme for JWKS URI: {parts.scheme!r}") + raise SSRFValidationError( + f"unsupported URI scheme for SSRF-validated fetch: " + f"{parts.scheme!r} (only http/https allowed)" + ) host = parts.hostname if host is None or host == "": - raise SSRFValidationError("JWKS URI has no host") + raise SSRFValidationError(f"URI has no host: {uri!r}") + # Strip a single trailing dot (FQDN form) so the pin matches what + # httpx / httpcore pass on subsequent requests. Without this, a + # caller who constructs with ``https://host./`` and then requests + # ``https://host/`` (or vice versa) sees the backend's + # hostname-match fail and falls through to unpinned resolution. + if host.endswith("."): + host = host[:-1] + # IDNA-encode so Unicode hostnames match the ASCII form httpx + # produces before calling into httpcore. urlsplit preserves the + # raw Unicode; httpx encodes it. A mismatch here breaks the + # hostname-match in the backend override and silently reopens + # the TOCTOU for IDN hosts. + try: + host = host.encode("idna").decode("ascii").lower() + except (UnicodeError, UnicodeEncodeError) as exc: + raise SSRFValidationError(f"URI host {host!r} is not IDNA-valid: {exc}") from exc port = parts.port if parts.port is not None else (443 if parts.scheme == "https" else 80) try: @@ -221,14 +240,17 @@ def default_jwks_fetcher(uri: str, *, allow_private: bool = False) -> dict[str, from adcp.signing.ip_pinned_transport import build_ip_pinned_transport transport = build_ip_pinned_transport(uri, allow_private=allow_private) - # follow_redirects=False is explicit: httpx already defaults to no-follow, - # but an attacker controlling the JWKS origin could 302 us to a - # hostname our pinned transport doesn't cover, re-introducing the - # TOCTOU. Keep redirects off. + # follow_redirects=False: a 302 to a different hostname would + # bypass the pin. trust_env=False: httpx's default True picks up + # HTTPS_PROXY / HTTP_PROXY from the environment and routes the + # request through an HTTPProxy pool that ignores our pinned + # backend entirely — a process with HTTPS_PROXY set to an + # attacker-controlled endpoint would bypass the TOCTOU fix. with httpx.Client( transport=transport, timeout=DEFAULT_JWKS_TIMEOUT_SECONDS, follow_redirects=False, + trust_env=False, ) as client: response = client.get(uri, headers={"Accept": "application/json"}) response.raise_for_status() @@ -322,13 +344,15 @@ async def async_default_jwks_fetcher(uri: str, *, allow_private: bool = False) - fetches AND the DNS-rebinding TOCTOU stays closed. Same SSRF + follow-redirects rules as the sync version. """ - from adcp.signing.ip_pinned_transport import abuild_ip_pinned_transport + from adcp.signing.ip_pinned_transport import build_async_ip_pinned_transport - transport = abuild_ip_pinned_transport(uri, allow_private=allow_private) + transport = build_async_ip_pinned_transport(uri, allow_private=allow_private) + # See default_jwks_fetcher for why trust_env=False matters. async with httpx.AsyncClient( transport=transport, timeout=DEFAULT_JWKS_TIMEOUT_SECONDS, follow_redirects=False, + trust_env=False, ) as client: response = await client.get(uri, headers={"Accept": "application/json"}) response.raise_for_status() diff --git a/src/adcp/signing/revocation_fetcher.py b/src/adcp/signing/revocation_fetcher.py index ede0d1e59..c902a04d6 100644 --- a/src/adcp/signing/revocation_fetcher.py +++ b/src/adcp/signing/revocation_fetcher.py @@ -260,7 +260,15 @@ def default_revocation_list_fetcher( transport = build_ip_pinned_transport(uri, allow_private=allow_private) headers = _build_fetch_headers(if_none_match=if_none_match, if_modified_since=if_modified_since) try: - with httpx.Client(transport=transport, timeout=timeout, follow_redirects=False) as client: + # trust_env=False keeps HTTPS_PROXY env vars from routing through + # an attacker-controlled proxy that would bypass the IP-pinned + # transport. See default_jwks_fetcher docstring. + with httpx.Client( + transport=transport, + timeout=timeout, + follow_redirects=False, + trust_env=False, + ) as client: response = client.get(uri, headers=headers) except httpx.HTTPError as exc: raise RevocationListFetchError(f"revocation list GET {uri!r} failed: {exc}") from exc @@ -289,13 +297,16 @@ async def async_default_revocation_list_fetcher( :class:`httpx.AsyncClient` so the event loop isn't blocked during the round-trip. """ - from adcp.signing.ip_pinned_transport import abuild_ip_pinned_transport + from adcp.signing.ip_pinned_transport import build_async_ip_pinned_transport - transport = abuild_ip_pinned_transport(uri, allow_private=allow_private) + transport = build_async_ip_pinned_transport(uri, allow_private=allow_private) headers = _build_fetch_headers(if_none_match=if_none_match, if_modified_since=if_modified_since) try: async with httpx.AsyncClient( - transport=transport, timeout=timeout, follow_redirects=False + transport=transport, + timeout=timeout, + follow_redirects=False, + trust_env=False, ) as client: response = await client.get(uri, headers=headers) except httpx.HTTPError as exc: diff --git a/tests/conformance/signing/test_ip_pinned_transport.py b/tests/conformance/signing/test_ip_pinned_transport.py index 46706956b..b9292b8e5 100644 --- a/tests/conformance/signing/test_ip_pinned_transport.py +++ b/tests/conformance/signing/test_ip_pinned_transport.py @@ -31,6 +31,7 @@ IpPinnedTransport, SSRFValidationError, abuild_ip_pinned_transport, + build_async_ip_pinned_transport, build_ip_pinned_transport, resolve_and_validate_host, ) @@ -101,10 +102,46 @@ def test_resolve_defaults_http_port_80() -> None: def test_resolve_rejects_non_http_scheme() -> None: - with pytest.raises(SSRFValidationError, match="scheme"): + # Error wording is generic (not "JWKS URI") since the helper is + # used by revocation fetchers and custom-transport callers too. + with pytest.raises(SSRFValidationError, match="SSRF-validated"): resolve_and_validate_host("ftp://example.com/jwks") +def test_resolve_normalizes_idn_hostname_to_punycode() -> None: + """Unicode hostnames get IDNA-encoded to the ASCII form httpx + passes to httpcore — otherwise the backend's hostname-match fails + and the pin silently falls through to the parent's unpinned + connect_tcp, reopening the TOCTOU. + """ + + # Patch getaddrinfo to short-circuit DNS for the IDN test host. + def fake_getaddrinfo(host, _port, *_args, **_kwargs): + # Must be called with the ASCII-encoded form. + assert ( + host == "xn--mnchen-3ya.example" + ), f"resolve_and_validate_host should IDNA-encode; got {host!r}" + return [(socket.AF_INET, 0, 0, "", ("8.8.8.8", 0))] + + with patch("adcp.signing.jwks.socket.getaddrinfo", side_effect=fake_getaddrinfo): + host, _ip, _port = resolve_and_validate_host("https://münchen.example/") + assert host == "xn--mnchen-3ya.example" + + +def test_resolve_strips_trailing_dot_fqdn() -> None: + """An FQDN URL form (trailing dot) must compare equal to the + non-FQDN form so the backend pin fires either way. + """ + + def fake_getaddrinfo(host, _port, *_args, **_kwargs): + assert host == "example.com", f"trailing dot not stripped; got {host!r}" + return [(socket.AF_INET, 0, 0, "", ("8.8.8.8", 0))] + + with patch("adcp.signing.jwks.socket.getaddrinfo", side_effect=fake_getaddrinfo): + host, _ip, _port = resolve_and_validate_host("https://example.com./jwks") + assert host == "example.com" + + def test_resolve_rejects_private_result_without_allow_private() -> None: # Simulate getaddrinfo returning a private IP — a rebinding # attacker's payload. @@ -170,7 +207,7 @@ def fake_getaddrinfo(_host, _port, *_args, **_kwargs): return [(socket.AF_INET, 0, 0, "", ("1.1.1.1", 0))] with patch("adcp.signing.jwks.socket.getaddrinfo", side_effect=fake_getaddrinfo): - transport = abuild_ip_pinned_transport("https://attacker.example/") + transport = build_async_ip_pinned_transport("https://attacker.example/") assert call_count["n"] == 1 pool = transport._pool @@ -179,6 +216,19 @@ def fake_getaddrinfo(_host, _port, *_args, **_kwargs): assert backend._hostname == "attacker.example" +def test_abuild_alias_emits_deprecation_warning() -> None: + """Legacy alias still works but warns. Remove after downstream + migration lands.""" + + def fake_getaddrinfo(_host, _port, *_args, **_kwargs): + return [(socket.AF_INET, 0, 0, "", ("8.8.8.8", 0))] + + with patch("adcp.signing.jwks.socket.getaddrinfo", side_effect=fake_getaddrinfo): + with pytest.warns(DeprecationWarning, match="build_async_ip_pinned_transport"): + transport = abuild_ip_pinned_transport("https://example.com/") + assert isinstance(transport, AsyncIpPinnedTransport) + + def test_backend_connect_tcp_swaps_hostname_for_pinned_ip() -> None: """Directly test the backend's override — ``connect_tcp`` with the pinned hostname calls the parent with the resolved IP instead. @@ -201,13 +251,30 @@ def _fake_parent_connect(self, *, host, port, timeout, local_address, socket_opt assert captured["port"] == 443 -def test_backend_connect_tcp_leaves_other_hosts_unchanged() -> None: - """If some code path reuses the transport for a DIFFERENT host - (misuse), the backend MUST NOT silently route it to the pinned IP. +def test_backend_connect_tcp_refuses_wrong_host() -> None: + """Reuse of a transport for a DIFFERENT host MUST raise. + + Falling through to the parent's unpinned ``connect_tcp`` would + silently re-open the DNS-rebinding TOCTOU — exactly what the + pin exists to close. Agents observed caching one transport in + a dict keyed by base URL and reusing it, which would hit this + path in production. """ from adcp.signing.ip_pinned_transport import _IpPinnedSyncBackend backend = _IpPinnedSyncBackend(hostname="attacker.example", resolved_ip="198.51.100.40") + with pytest.raises(RuntimeError, match="pinned to 'attacker.example'"): + backend.connect_tcp(host="other.example", port=443) + + +def test_backend_accepts_trailing_dot_fqdn_form() -> None: + """A caller pinning ``host.`` and connecting to ``host`` (or vice + versa) must still fire the pin — trailing dots are stripped on + both sides. + """ + from adcp.signing.ip_pinned_transport import _IpPinnedSyncBackend + + backend = _IpPinnedSyncBackend(hostname="attacker.example.", resolved_ip="198.51.100.45") captured = {} def _fake_parent_connect(self, *, host, port, timeout, local_address, socket_options): @@ -215,9 +282,29 @@ def _fake_parent_connect(self, *, host, port, timeout, local_address, socket_opt return object() with patch.object(SyncBackend, "connect_tcp", _fake_parent_connect): - backend.connect_tcp(host="other.example", port=443) + backend.connect_tcp(host="attacker.example", port=443) + assert captured["host"] == "198.51.100.45" - assert captured["host"] == "other.example" + +def test_backend_accepts_idn_punycode_form() -> None: + """httpx IDNA-encodes Unicode hostnames before calling httpcore. + The backend stored the punycode form, so the comparison against + a punycode input must succeed and the pin must fire. + """ + from adcp.signing.ip_pinned_transport import _IpPinnedSyncBackend + + # Pin the Unicode form; _normalize_pin_host encodes it to punycode. + backend = _IpPinnedSyncBackend(hostname="münchen.example", resolved_ip="198.51.100.55") + captured = {} + + def _fake_parent_connect(self, *, host, port, timeout, local_address, socket_options): + captured["host"] = host + return object() + + with patch.object(SyncBackend, "connect_tcp", _fake_parent_connect): + # httpx passes the punycode form. + backend.connect_tcp(host="xn--mnchen-3ya.example", port=443) + assert captured["host"] == "198.51.100.55" def test_backend_hostname_match_is_case_insensitive() -> None: @@ -275,6 +362,6 @@ def test_transport_type_is_httpx_httptransport() -> None: assert isinstance(transport, httpx.HTTPTransport) assert isinstance(transport, IpPinnedTransport) - atransport = abuild_ip_pinned_transport("https://example.com/") + atransport = build_async_ip_pinned_transport("https://example.com/") assert isinstance(atransport, httpx.AsyncHTTPTransport) assert isinstance(atransport, AsyncIpPinnedTransport)