diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ccea9484a..e7da672d8 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -39,6 +39,48 @@ jobs: run: | pytest tests/ -v --cov=src/adcp --cov-report=term-missing + pg-replay-store: + name: PgReplayStore tests (Postgres 16) + runs-on: ubuntu-latest + services: + postgres: + # CI-local ephemeral database. POSTGRES_HOST_AUTH_METHOD=trust + # avoids shipping any password literal (real or placeholder) in + # this workflow — GitHub's default CI network is already the + # trust boundary for this throwaway service. + image: postgres:16 + env: + POSTGRES_HOST_AUTH_METHOD: trust + POSTGRES_DB: adcp_test + ports: + - 5432:5432 + options: >- + --health-cmd pg_isready + --health-interval 5s + --health-timeout 5s + --health-retries 10 + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install dependencies (with [pg] extra) + run: | + python -m pip install --upgrade pip + pip install -e ".[dev,pg]" + + - name: Run PgReplayStore tests (unit + full-wire e2e) + env: + ADCP_PG_TEST_URL: postgresql://postgres@localhost:5432/adcp_test + run: | + pytest tests/conformance/signing/test_pg_replay_store.py \ + tests/conformance/signing/test_pg_replay_store_e2e.py \ + -v + conventional-commits: name: Validate conventional commit format runs-on: ubuntu-latest diff --git a/pyproject.toml b/pyproject.toml index 791980721..71d924a67 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,6 +59,13 @@ dev = [ docs = [ "pdoc3>=0.10.0", ] +pg = [ + # PostgreSQL-backed PgReplayStore (and future PgIdempotencyBackend). + # psycopg3 gives both sync + async client interfaces so the same dep + # serves the sync replay store today and an async one later. 
+ "psycopg[binary]>=3.1.0", + "psycopg-pool>=3.2.0", +] [project.urls] Homepage = "https://github.com/adcontextprotocol/adcp-client-python" @@ -70,7 +77,7 @@ Issues = "https://github.com/adcontextprotocol/adcp-client-python/issues" where = ["src"] [tool.setuptools.package-data] -adcp = ["py.typed", "ADCP_VERSION"] +adcp = ["py.typed", "ADCP_VERSION", "signing/pg/*.sql"] [tool.black] line-length = 100 @@ -114,6 +121,12 @@ disable_error_code = ["valid-type"] module = "tests.integration.*" ignore_errors = true +# psycopg is an optional dep behind the [pg] extra; type stubs aren't +# guaranteed to be present when the base SDK is installed. +[[tool.mypy.overrides]] +module = ["psycopg", "psycopg.*", "psycopg_pool", "psycopg_pool.*"] +ignore_missing_imports = true + [tool.pytest.ini_options] testpaths = ["tests"] asyncio_mode = "auto" diff --git a/src/adcp/signing/__init__.py b/src/adcp/signing/__init__.py index 6510defd8..5b002185c 100644 --- a/src/adcp/signing/__init__.py +++ b/src/adcp/signing/__init__.py @@ -1,7 +1,45 @@ """AdCP RFC 9421 request-signing profile. -Implements the transport-layer signed-request profile from the AdCP specification. -See: https://adcontextprotocol.org/docs/building/implementation/security#signed-requests-transport-layer +Implements the transport-layer signed-request profile from the AdCP +specification. 
See: +https://adcontextprotocol.org/docs/building/implementation/security#signed-requests-transport-layer + +Quickstart +========== + +The core names you'll reach for (everything else is for advanced use): + +**Buyers** (signing outgoing requests): + +* :func:`sign_request` — produce ``Signature`` / ``Signature-Input`` + headers for one request +* :func:`load_private_key_pem` — rehydrate the PEM ``adcp-keygen`` wrote +* :class:`SigningConfig` — bundle key material for auto-signing via + ``ADCPClient(signing=...)`` + +**Sellers** (verifying incoming requests): + +* :func:`verify_starlette_request` / :func:`verify_flask_request` — + framework-shaped wrappers around :func:`verify_request_signature` +* :class:`VerifyOptions` — the knobs (capability, jwks_resolver, + replay_store, revocation_checker) +* :class:`VerifierCapability` — what the seller advertises (e.g. + ``required_for={"create_media_buy"}``) +* :class:`StaticJwksResolver` — for testing; use + :class:`CachingJwksResolver` against a live ``jwks_uri`` +* :class:`SignatureVerificationError` — raised on rejection; ``.code`` + is the spec error string +* :func:`unauthorized_response_headers` — builds the 401 + ``WWW-Authenticate: Signature error="..."`` header +* :class:`InMemoryReplayStore` for single-process deployments; + :class:`PgReplayStore` (behind ``[pg]`` extra) for multi-worker + +**Governance agents**: + +* :class:`CachingRevocationChecker` — fetches + caches a signed + revocation list from ``{issuer}/.well-known/governance-revocations.json`` +* Async variants: :class:`AsyncCachingJwksResolver`, + :class:`AsyncCachingRevocationChecker` """ from __future__ import annotations @@ -35,6 +73,7 @@ b64url_encode, extract_signature_bytes, format_signature_header, + load_private_key_pem, private_key_from_jwk, public_key_from_jwk, sign_signature_base, @@ -117,6 +156,31 @@ verify_request_signature, ) +# Conditional import: PgReplayStore needs the [pg] extra. 
Always expose +# the name — if psycopg isn't installed we fall through to a stub class +# whose constructor raises ImportError with the install hint. Exposing +# None would give callers a confusing ``TypeError: 'NoneType' object is +# not callable`` on instantiation; the stub turns that into a +# self-explanatory error at the right moment. +try: + from adcp.signing.pg import PgReplayStore # noqa: F401 +except ImportError: # pragma: no cover — exercised by the [pg] extra tests + + class PgReplayStore: # type: ignore[no-redef] + """Stub raised when ``adcp[pg]`` isn't installed. + + Attempting to instantiate raises :class:`ImportError` with the + install-hint text from :mod:`adcp.signing.pg.replay_store`. + """ + + def __init__(self, *args: object, **kwargs: object) -> None: + raise ImportError( + "PgReplayStore requires psycopg3 and psycopg-pool. Install the " + "'pg' extra: `pip install 'adcp[pg]'` (Poetry: " + "`poetry add 'adcp[pg]'`)." + ) + + __all__ = [ "ALG_ED25519", "ALG_ES256", @@ -141,6 +205,7 @@ "JwsUnknownKeyError", "MAX_WINDOW_SECONDS", "NONCE_BYTES", + "PgReplayStore", "REQUEST_SIGNATURE_ALG_NOT_ALLOWED", "REQUEST_SIGNATURE_COMPONENTS_INCOMPLETE", "REQUEST_SIGNATURE_COMPONENTS_UNEXPECTED", @@ -195,6 +260,7 @@ "default_revocation_list_fetcher", "extract_signature_bytes", "format_signature_header", + "load_private_key_pem", "operation_needs_signing", "parse_signature_input_header", "private_key_from_jwk", diff --git a/src/adcp/signing/crypto.py b/src/adcp/signing/crypto.py index 2e9ff64cd..7096b9f1d 100644 --- a/src/adcp/signing/crypto.py +++ b/src/adcp/signing/crypto.py @@ -18,7 +18,7 @@ from typing import Any from cryptography.exceptions import InvalidSignature -from cryptography.hazmat.primitives import hashes +from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import ec, ed25519 from cryptography.hazmat.primitives.asymmetric.utils import ( decode_dss_signature, @@ -50,6 +50,50 @@ def 
b64url_encode(b: bytes) -> str: return base64.urlsafe_b64encode(b).rstrip(b"=").decode("ascii") +def load_private_key_pem(pem: bytes, *, password: bytes | None = None) -> PrivateKey: + """Load an Ed25519 or P-256 private key from PKCS8 PEM bytes. + + Closes the loop between ``adcp-keygen`` (which writes a PEM) and + :func:`sign_request` (which takes a ``PrivateKey`` object), so + integrators don't need a direct ``cryptography`` import just to + rehydrate the key. + + Parameters + ---------- + pem: + PEM-encoded PKCS8 private key as bytes. Read via + ``pathlib.Path(...).read_bytes()``. + password: + Passphrase if the PEM is encrypted (``adcp-keygen --encrypt``). + Passed through to the cryptography loader as bytes. + + Returns + ------- + PrivateKey + An :class:`Ed25519PrivateKey` or + :class:`EllipticCurvePrivateKey` ready to pass into + :func:`sign_request`. + + Raises + ------ + ValueError + The PEM is not Ed25519 or ES256 (P-256). These are the only + algorithms the AdCP request-signing profile allows. 
+ """ + key = serialization.load_pem_private_key(pem, password=password) + if not isinstance(key, (ed25519.Ed25519PrivateKey, ec.EllipticCurvePrivateKey)): + raise ValueError( + f"unsupported private key type {type(key).__name__} — " + f"AdCP signing accepts Ed25519 or ECDSA-P-256 only" + ) + if isinstance(key, ec.EllipticCurvePrivateKey) and not isinstance(key.curve, ec.SECP256R1): + raise ValueError( + f"EC key curve {key.curve.name} is not supported — only " + f"P-256 (SECP256R1) is allowed" + ) + return key + + def public_key_from_jwk(jwk: dict[str, Any]) -> PublicKey: """Reconstruct a public key from its JWK.""" kty = jwk.get("kty") diff --git a/src/adcp/signing/middleware.py b/src/adcp/signing/middleware.py index 9c793942c..a0a8efec8 100644 --- a/src/adcp/signing/middleware.py +++ b/src/adcp/signing/middleware.py @@ -34,10 +34,28 @@ def verify_flask_request(request: Any, *, options: VerifyOptions) -> VerifiedSig async def verify_starlette_request(request: Any, *, options: VerifyOptions) -> VerifiedSigner: - """Verify a Starlette / FastAPI `Request` object against the AdCP profile. + """Verify a Starlette / FastAPI ``Request`` object against the AdCP profile. - Consumes `await request.body()` — if downstream code also needs the body, - it must read `request.state` or the returned `VerifiedSigner`-side context. + Consumes ``await request.body()`` once — Starlette caches the result + internally, so downstream handlers calling ``request.body()`` or + ``request.json()`` again will get the same bytes. If your handler + needs the parsed body AFTER this verifier succeeds, call + ``await request.body()`` yourself downstream; there's no hidden + side channel on the returned :class:`VerifiedSigner`. + + Returns + ------- + VerifiedSigner + On success — carries the verified ``key_id`` and metadata. + + Raises + ------ + SignatureVerificationError + On any failure of the AdCP verifier checklist. The ``.code`` + attribute holds the spec's error code string (e.g. 
+ ``request_signature_replayed``) and ``.step`` points at the + failed checklist step. Frameworks typically map this to a 401 + with :func:`unauthorized_response_headers`. """ body = await request.body() return verify_request_signature( diff --git a/src/adcp/signing/pg/__init__.py b/src/adcp/signing/pg/__init__.py new file mode 100644 index 000000000..04f23cc32 --- /dev/null +++ b/src/adcp/signing/pg/__init__.py @@ -0,0 +1,22 @@ +"""PostgreSQL-backed implementations for the signing module. + +This sub-package ships backends that require PostgreSQL via psycopg3. +They live here (and behind the ``[pg]`` optional extra) so the base +``adcp.signing`` import path stays free of SQL dependencies for users +who only need the pure-Python primitives. + +Available when ``adcp[pg]`` is installed: + +* :class:`PgReplayStore` — multi-instance-safe replay store for the + RFC 9421 verifier pipeline. + +The schema DDL ships alongside the Python code at +``adcp/signing/pg/replay_store.sql`` so integrators can run it through +whatever migration tool they use (Alembic, Flyway, psql, ...). +""" + +from __future__ import annotations + +from adcp.signing.pg.replay_store import PgReplayStore + +__all__ = ["PgReplayStore"] diff --git a/src/adcp/signing/pg/replay_store.py b/src/adcp/signing/pg/replay_store.py new file mode 100644 index 000000000..f8b46077d --- /dev/null +++ b/src/adcp/signing/pg/replay_store.py @@ -0,0 +1,302 @@ +"""PostgreSQL-backed :class:`~adcp.signing.ReplayStore` implementation. + +Gives multi-instance AdCP verifiers a shared nonce-seen store so a +replay accepted on worker A can't land again on worker B within the +signature's validity window. + +The caller supplies a :class:`psycopg_pool.ConnectionPool`. We don't +open, own, or close the pool — integrators typically already have one +for their main database and sharing is cleaner than a second pool. 
+ +End-to-end example +------------------ + +:: + + from psycopg_pool import ConnectionPool + from adcp.signing import ( + PgReplayStore, + StaticJwksResolver, + VerifierCapability, + VerifyOptions, + verify_request_signature, + ) + + pool = ConnectionPool("postgresql://...", min_size=4, max_size=20) + replay = PgReplayStore(pool=pool) + replay.create_schema() # bootstrap once per deployment; idempotent + + options = VerifyOptions( + now=..., + capability=VerifierCapability(required_for=frozenset({"create_media_buy"})), + operation="create_media_buy", + jwks_resolver=StaticJwksResolver({"keys": [...]}), + replay_store=replay, # <-- plug in here + ) + signer = verify_request_signature( + method="POST", url=..., headers=..., body=..., options=options, + ) + +REQUIRED: sweep job +------------------- + +:meth:`seen` self-filters via ``expires_at > now()``, so lookups never +return stale entries. Rows accumulate, though — you MUST run a +periodic sweep or the table grows unbounded. Two options: + +1. **pg_cron** (or any out-of-process scheduler):: + + DELETE FROM adcp_replay WHERE expires_at <= now(); + +2. **In-process loop** — call :meth:`sweep_expired` on a timer:: + + async def sweep_forever(store: PgReplayStore, interval: float = 60.0) -> None: + while True: + await asyncio.to_thread(store.sweep_expired) + await asyncio.sleep(interval) + +Pick one. A deployment without a sweep is unbounded table growth waiting to +page your oncall. + +Failure mode +------------ + +Transport or connection errors propagate from psycopg unchanged +(``OperationalError``, ``PoolTimeout``, etc.). The current verifier +does not catch them — so a pool hiccup raises out of +:func:`~adcp.signing.verify_request_signature`, and the enclosing +framework returns a 5xx. That's fail-closed from the client's +perspective (no 2xx on a broken store), but it's the framework's +default, not a SignatureVerificationError the caller can cleanly +handle. 
If your handler wraps verifier calls in a +``except Exception: return 503``, you're good; if it only catches +``SignatureVerificationError``, a broken store bubbles up as an +uncaught exception. +""" + +from __future__ import annotations + +import re +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from psycopg_pool import ConnectionPool + +try: + import psycopg # noqa: F401 + import psycopg_pool # noqa: F401 + + PG_AVAILABLE = True +except ImportError: + PG_AVAILABLE = False + + +DEFAULT_TABLE_NAME = "adcp_replay" + +# Byte-level ASCII identifier match. ``str.islower()`` / ``str.isalpha()`` +# return True for non-ASCII Unicode letters (``é``, fullwidth Latin +# ``t``, ``µ``, ``ß`` etc.) — which would then format verbatim into SQL +# as a DIFFERENT table from the one the operator thinks they configured. +# Under multi-tenant config where ``table_name`` can be attacker- +# influenced, that's a real replay-bypass vector. The regex here is +# ASCII-range-by-construction. +_SAFE_IDENTIFIER_RE = re.compile(r"^[a-z_][a-z0-9_]{0,62}$") + +_INSTALL_HINT = ( + "PgReplayStore requires psycopg3 and psycopg-pool. Install the 'pg' " + "extra: `pip install 'adcp[pg]'` (Poetry: `poetry add 'adcp[pg]'`)." +) + + +class PgReplayStore: + """PostgreSQL-backed replay store implementing :class:`ReplayStore`. + + Parameters + ---------- + pool: + A :class:`psycopg_pool.ConnectionPool` owned by the caller. Each + operation acquires a short-lived connection, runs a single + statement, and returns the connection. No long-lived + transactions, no cross-operation state. + per_keyid_cap: + Maximum number of live (non-expired) nonces per ``keyid``. + Mirrors :class:`InMemoryReplayStore`; spec-recommended 1M. + When :meth:`at_capacity` reports True, the verifier rejects + with ``request_signature_rate_abuse`` rather than silently + evicting older entries (which would create a replay window + under attack). 
+ table_name: + Override the default ``adcp_replay`` table if two tenants share + a database and need separate replay stores. Must be a lowercase + ASCII identifier — the name is formatted into the SQL text + unquoted (it cannot be bound as a parameter), so the constructor + validates its shape to rule out injection. + + Concurrency + ----------- + + Safe to share across threads and processes. Postgres provides the + cross-instance locking we need via PK conflict resolution on + ``INSERT ... ON CONFLICT``. + """ + + def __init__( + self, + *, + pool: ConnectionPool, + per_keyid_cap: int = 1_000_000, + table_name: str = DEFAULT_TABLE_NAME, + ) -> None: + if not PG_AVAILABLE: + raise ImportError(_INSTALL_HINT) + if not _is_safe_identifier(table_name): + raise ValueError( + f"table_name must match [a-z_][a-z0-9_]* (ASCII only), " f"got {table_name!r}" + ) + self._pool = pool + self._per_keyid_cap = per_keyid_cap + self._table = table_name + + # Pre-format queries with the validated table name so the hot + # path doesn't f-string per call. + self._sql_seen = ( + f"SELECT 1 FROM {self._table} " # noqa: S608 — table name is whitelisted + f"WHERE keyid = %s AND nonce = %s AND expires_at > now()" + ) + # ``WHERE EXCLUDED.expires_at > {table}.expires_at`` avoids write + # amplification on the common case (a row is already present + # with a later-or-equal expiry). Without the predicate, every + # remember() would re-write the MVCC tuple even when the new + # TTL is shorter or equal. 
+ self._sql_remember = ( + f"INSERT INTO {self._table} (keyid, nonce, expires_at) " # noqa: S608 + f"VALUES (%s, %s, now() + make_interval(secs => %s)) " + f"ON CONFLICT (keyid, nonce) DO UPDATE " + f"SET expires_at = EXCLUDED.expires_at " + f"WHERE EXCLUDED.expires_at > {self._table}.expires_at" + ) + self._sql_at_capacity = ( + f"SELECT COUNT(*) >= %s FROM {self._table} " # noqa: S608 + f"WHERE keyid = %s AND expires_at > now()" + ) + self._sql_sweep = f"DELETE FROM {self._table} WHERE expires_at <= now()" # noqa: S608 + + # -- schema bootstrap -------------------------------------------- + + def create_schema(self) -> None: + """Create the replay table + indexes for this store's ``table_name``. + + Honors the ``table_name`` kwarg the store was constructed with — + integrators using per-tenant tables get the right DDL without + extra plumbing. Idempotent via ``CREATE ... IF NOT EXISTS``; + safe to call on every app boot. + + The equivalent raw DDL ships at + :file:`src/adcp/signing/pg/replay_store.sql` for integrators + using a migration tool (Alembic, Flyway, psql) — that file + uses the canonical ``adcp_replay`` name. 
+ """ + table = self._table # already validated at __init__ + ddl = ( + f"CREATE TABLE IF NOT EXISTS {table} (" # noqa: S608 — validated + f' keyid TEXT COLLATE "C" NOT NULL,' + f' nonce TEXT COLLATE "C" NOT NULL,' + f" expires_at TIMESTAMPTZ NOT NULL," + f" PRIMARY KEY (keyid, nonce)" + f");" + f"CREATE INDEX IF NOT EXISTS {table}_expires_idx " # noqa: S608 + f" ON {table} (expires_at);" + ) + with self._pool.connection() as conn, conn.cursor() as cur: + cur.execute(ddl) + + # -- ReplayStore Protocol ----------------------------------------- + + def seen(self, keyid: str, nonce: str) -> bool: + """Return True iff ``(keyid, nonce)`` has a live entry.""" + with self._pool.connection() as conn, conn.cursor() as cur: + cur.execute(self._sql_seen, (keyid, nonce)) + return cur.fetchone() is not None + + def remember(self, keyid: str, nonce: str, ttl_seconds: float) -> None: + """Record ``(keyid, nonce)`` with a TTL. + + ``ON CONFLICT ... DO UPDATE`` refreshes the expiry on a + legitimate retry of the same nonce in-window — matches + :class:`InMemoryReplayStore` behavior. + """ + with self._pool.connection() as conn, conn.cursor() as cur: + cur.execute(self._sql_remember, (keyid, nonce, ttl_seconds)) + + def at_capacity(self, keyid: str) -> bool: + """Return True iff the live row count for ``keyid`` meets the cap. + + Implementation note: ``COUNT(*) >= cap`` uses the PK for the + keyid filter and the expires index for the time predicate. + For the spec-recommended 1M cap, the expensive case is exactly + when a signer is misbehaving, so paying for accuracy is the + right trade. + + For deployments that need faster short-circuiting on a hot + keyid, an alternative shape is:: + + SELECT 1 FROM {table} + WHERE keyid = %s AND expires_at > now() + OFFSET %s LIMIT 1 + + with ``OFFSET`` bound to ``cap - 1``, which stops scanning once ``cap`` rows are seen. + Swap in if profiling identifies ``at_capacity`` as hot. 
+ """ + with self._pool.connection() as conn, conn.cursor() as cur: + cur.execute(self._sql_at_capacity, (self._per_keyid_cap, keyid)) + row = cur.fetchone() + return bool(row[0]) if row is not None else False + + # -- admin / cron ------------------------------------------------ + + def sweep_expired(self) -> int: + """Delete all rows whose ``expires_at`` is in the past. + + Returns the number of rows removed. Safe to call concurrently + with :meth:`seen` / :meth:`remember`. + + Call from a cron or admin endpoint. :meth:`seen` self-filters + so expired rows never cause false positives, but they do + accumulate and grow the table. + """ + with self._pool.connection() as conn, conn.cursor() as cur: + cur.execute(self._sql_sweep) + return cur.rowcount or 0 + + def live_count(self, keyid: str) -> int: + """Return the number of live (non-expired) rows for ``keyid``. + + Mostly useful for tests, monitoring, and admin tooling. Not on + the :class:`ReplayStore` Protocol — hot-path code should call + :meth:`at_capacity`, which compares the count to the cap inside + Postgres and returns only a boolean. + """ + sql = ( + f"SELECT COUNT(*) FROM {self._table} " # noqa: S608 + f"WHERE keyid = %s AND expires_at > now()" + ) + with self._pool.connection() as conn, conn.cursor() as cur: + cur.execute(sql, (keyid,)) + row = cur.fetchone() + return int(row[0]) if row is not None else 0 + + +def _is_safe_identifier(name: str) -> bool: + """Allow only byte-ASCII lowercase identifiers for the table-name kwarg. + + The table name is static-formatted into SQL at construction; this + validator is the sole guard against injection OR silent table-name + substitution via Unicode homoglyphs. Must stay ASCII-byte-exact + (see :data:`_SAFE_IDENTIFIER_RE`). + + Postgres's NAMEDATALEN default caps identifiers at 63 bytes. 
+ """ + return _SAFE_IDENTIFIER_RE.fullmatch(name) is not None + + +__all__ = ["PG_AVAILABLE", "DEFAULT_TABLE_NAME", "PgReplayStore"] diff --git a/src/adcp/signing/pg/replay_store.sql b/src/adcp/signing/pg/replay_store.sql new file mode 100644 index 000000000..235161f3f --- /dev/null +++ b/src/adcp/signing/pg/replay_store.sql @@ -0,0 +1,36 @@ +-- AdCP RFC 9421 replay-dedup store. +-- +-- Run this once per deployment. Tracked by the adcp-client-python +-- PgReplayStore; see src/adcp/signing/pg/replay_store.py for the +-- query shapes the Python code executes. +-- +-- COLLATE "C" on the identifier columns avoids locale-dependent case +-- folding — on some locales "Key-A" and "key-a" compare equal, which +-- would let an attacker collapse distinct kids / nonces into the same +-- slot. "C" is the byte-for-byte comparison we actually want. +-- +-- Run a periodic sweep (cron, every minute or so): +-- DELETE FROM adcp_replay WHERE expires_at <= now(); +-- The PgReplayStore.sweep_expired() method does exactly this and can +-- be called from an admin endpoint if you prefer an in-process sweep. + +CREATE TABLE IF NOT EXISTS adcp_replay ( + keyid TEXT COLLATE "C" NOT NULL, + nonce TEXT COLLATE "C" NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + PRIMARY KEY (keyid, nonce) +); + +-- Supports the sweep query and the at_capacity COUNT. Postgres will +-- use this for range predicates like ``expires_at > now()``, so +-- ``at_capacity`` for a busy keyid is an index-assisted scan rather +-- than a full table scan. +CREATE INDEX IF NOT EXISTS adcp_replay_expires_idx + ON adcp_replay (expires_at); + +-- A partial index on (keyid) WHERE expires_at > now() is NOT usable — +-- ``now()`` is STABLE, not IMMUTABLE, which Postgres forbids in index +-- predicates. If ``at_capacity`` for a specific keyid becomes hot in +-- profiling, the workable alternative is a composite +-- ``(keyid, expires_at)`` index; the existing PK + single-column +-- expires index already covers most patterns. 
diff --git a/tests/conformance/signing/test_pg_replay_store.py b/tests/conformance/signing/test_pg_replay_store.py new file mode 100644 index 000000000..eb0bf5ea6 --- /dev/null +++ b/tests/conformance/signing/test_pg_replay_store.py @@ -0,0 +1,300 @@ +"""Tests for :class:`adcp.signing.pg.PgReplayStore`. + +Requires a real PostgreSQL instance. To run locally:: + + docker run --rm -d -p 5432:5432 -e POSTGRES_PASSWORD=pg postgres:16 + export ADCP_PG_TEST_URL=postgresql://postgres:pg@localhost:5432/postgres + pytest tests/conformance/signing/test_pg_replay_store.py -v + +The entire module skips when ``ADCP_PG_TEST_URL`` is unset, so the +default test matrix stays green without a database dependency. + +Each test runs against its own table (``test_adcp_replay_`` + random hex) +so parallel test runs and rerun-after-crash scenarios don't collide. +""" + +from __future__ import annotations + +import os +import secrets +import threading +import time +from collections.abc import Iterator + +import pytest + +psycopg = pytest.importorskip("psycopg") +psycopg_pool = pytest.importorskip("psycopg_pool") + +TEST_URL = os.environ.get("ADCP_PG_TEST_URL") +if not TEST_URL: + pytest.skip( + "ADCP_PG_TEST_URL not set — skipping PgReplayStore tests", + allow_module_level=True, + ) + +from adcp.signing.pg import PgReplayStore # noqa: E402 + + +@pytest.fixture() +def isolated_pool() -> Iterator[psycopg_pool.ConnectionPool]: + """Connection pool plus a uniquely named per-test table. + + Creates a unique table name so tests running in parallel (or a + crashed-then-retry run) don't step on each other. Drops the table + on teardown. 
+ """ + table = f"test_adcp_replay_{secrets.token_hex(6)}" + with psycopg_pool.ConnectionPool(TEST_URL, min_size=2, max_size=8) as pool: + with pool.connection() as conn, conn.cursor() as cur: + cur.execute( + f""" + CREATE TABLE {table} ( + keyid TEXT COLLATE "C" NOT NULL, + nonce TEXT COLLATE "C" NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + PRIMARY KEY (keyid, nonce) + ) + """ + ) + cur.execute(f"CREATE INDEX {table}_expires_idx ON {table} (expires_at)") + try: + yield pool, table # type: ignore[misc] + finally: + with pool.connection() as conn, conn.cursor() as cur: + cur.execute(f"DROP TABLE IF EXISTS {table}") + + +def _store(fixture, **overrides) -> PgReplayStore: + pool, table = fixture + return PgReplayStore(pool=pool, table_name=table, **overrides) + + +# -- Protocol happy path ---------------------------------------------- + + +def test_seen_returns_false_for_unknown_nonce(isolated_pool) -> None: + store = _store(isolated_pool) + assert store.seen("k", "n") is False + + +def test_remember_then_seen_returns_true(isolated_pool) -> None: + store = _store(isolated_pool) + store.remember("k", "n", ttl_seconds=60) + assert store.seen("k", "n") is True + + +def test_remember_different_nonce_isolated(isolated_pool) -> None: + store = _store(isolated_pool) + store.remember("k", "n1", ttl_seconds=60) + assert store.seen("k", "n2") is False + + +def test_remember_different_keyid_isolated(isolated_pool) -> None: + store = _store(isolated_pool) + store.remember("k1", "n", ttl_seconds=60) + assert store.seen("k2", "n") is False + + +# -- TTL semantics ---------------------------------------------------- + + +def test_seen_returns_false_after_ttl_expiry(isolated_pool) -> None: + store = _store(isolated_pool) + store.remember("k", "n", ttl_seconds=1) + time.sleep(1.2) + assert store.seen("k", "n") is False + + +def test_remember_refreshes_ttl_on_repeat(isolated_pool) -> None: + """ON CONFLICT DO UPDATE keeps the most recent TTL — mirrors InMemoryReplayStore.""" + store 
= _store(isolated_pool) + store.remember("k", "n", ttl_seconds=1) + # Refresh well before expiry with a longer TTL. + store.remember("k", "n", ttl_seconds=60) + time.sleep(1.2) + # The second remember's 60s TTL wins — still seen. + assert store.seen("k", "n") is True + + +# -- at_capacity ------------------------------------------------------ + + +def test_at_capacity_false_when_empty(isolated_pool) -> None: + store = _store(isolated_pool, per_keyid_cap=3) + assert store.at_capacity("k") is False + + +def test_at_capacity_true_at_cap(isolated_pool) -> None: + store = _store(isolated_pool, per_keyid_cap=3) + for i in range(3): + store.remember("k", f"n{i}", ttl_seconds=60) + assert store.at_capacity("k") is True + + +def test_at_capacity_respects_ttl_expiry(isolated_pool) -> None: + store = _store(isolated_pool, per_keyid_cap=3) + for i in range(3): + store.remember("k", f"n{i}", ttl_seconds=1) + assert store.at_capacity("k") is True + time.sleep(1.2) + # All three rows expired → count drops back to zero. 
+ assert store.at_capacity("k") is False + + +def test_at_capacity_is_per_keyid(isolated_pool) -> None: + store = _store(isolated_pool, per_keyid_cap=2) + store.remember("k1", "a", ttl_seconds=60) + store.remember("k1", "b", ttl_seconds=60) + assert store.at_capacity("k1") is True + assert store.at_capacity("k2") is False + + +# -- sweep_expired --------------------------------------------------- + + +def test_sweep_expired_removes_stale_rows(isolated_pool) -> None: + store = _store(isolated_pool) + store.remember("k", "old", ttl_seconds=1) + store.remember("k", "fresh", ttl_seconds=60) + time.sleep(1.2) + + removed = store.sweep_expired() + assert removed == 1 + assert store.live_count("k") == 1 + assert store.seen("k", "fresh") is True + + +def test_sweep_expired_returns_zero_when_clean(isolated_pool) -> None: + store = _store(isolated_pool) + store.remember("k", "n", ttl_seconds=60) + assert store.sweep_expired() == 0 + + +# -- concurrency ----------------------------------------------------- + + +def test_concurrent_remember_same_nonce_is_idempotent(isolated_pool) -> None: + """Two workers racing on the same (keyid, nonce) MUST NOT error. + + ``ON CONFLICT ... DO UPDATE`` makes the second insert a no-op + (with refreshed TTL). Without it, the second worker would hit a + PK violation and blow up. 
+ """ + store = _store(isolated_pool) + errors: list[Exception] = [] + + def worker() -> None: + try: + store.remember("k", "shared", ttl_seconds=60) + except Exception as exc: # noqa: BLE001 + errors.append(exc) + + threads = [threading.Thread(target=worker) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert errors == [] + assert store.seen("k", "shared") is True + assert store.live_count("k") == 1 + + +def test_concurrent_at_capacity_safe(isolated_pool) -> None: + """at_capacity from many threads shouldn't throw; value should stabilize.""" + store = _store(isolated_pool, per_keyid_cap=5) + for i in range(5): + store.remember("k", f"n{i}", ttl_seconds=60) + + results: list[bool] = [] + + def worker() -> None: + results.append(store.at_capacity("k")) + + threads = [threading.Thread(target=worker) for _ in range(20)] + for t in threads: + t.start() + for t in threads: + t.join() + + assert all(results) + + +# -- validation ----------------------------------------------------- + + +def test_bad_table_name_rejected(isolated_pool) -> None: + pool, _ = isolated_pool + with pytest.raises(ValueError, match="table_name"): + PgReplayStore(pool=pool, table_name="has-dash") + with pytest.raises(ValueError, match="table_name"): + PgReplayStore(pool=pool, table_name="1leading_digit") + with pytest.raises(ValueError, match="table_name"): + PgReplayStore(pool=pool, table_name="") + + +def test_non_ascii_table_name_rejected(isolated_pool) -> None: + """Unicode letters that ``str.islower()`` / ``str.isalpha()`` accept + MUST be rejected — otherwise a homoglyph like fullwidth ``table`` + formats into SQL as a DIFFERENT table than operators configured. + """ + pool, _ = isolated_pool + # Fullwidth Latin letters lowercase to themselves, not ASCII. + with pytest.raises(ValueError, match="ASCII"): + PgReplayStore(pool=pool, table_name="table") + # Latin small letter with accents — passes .islower() but not ASCII. 
+ with pytest.raises(ValueError, match="ASCII"): + PgReplayStore(pool=pool, table_name="café_replay") + # Greek small letter mu — looks like "u" in some fonts. + with pytest.raises(ValueError, match="ASCII"): + PgReplayStore(pool=pool, table_name="µreplay") + + +def test_remember_twice_with_shorter_ttl_keeps_longer_expiry(isolated_pool) -> None: + """Conditional ``DO UPDATE WHERE EXCLUDED.expires_at > current`` means + a shorter-TTL repeat must NOT shrink the cached expiry. + """ + store = _store(isolated_pool) + store.remember("k", "n", ttl_seconds=60) + # Shorter TTL should be a no-op on the row. + store.remember("k", "n", ttl_seconds=1) + time.sleep(1.2) + # If the update had fired, the row would have expired by now. + assert store.seen("k", "n") is True + + +def test_create_schema_honors_table_name(isolated_pool) -> None: + """create_schema must create the table the store was built for — + not a hardcoded name. Integrators using per-tenant tables would + break silently if this regresses. + """ + pool, _existing_table = isolated_pool + custom_table = f"custom_{secrets.token_hex(4)}_replay" + store = PgReplayStore(pool=pool, table_name=custom_table) + try: + store.create_schema() + store.create_schema() # idempotent — second call must not error + # Actually use the store — proves the DDL + runtime queries + # target the same table. + store.remember("k", "n", ttl_seconds=60) + assert store.seen("k", "n") is True + finally: + with pool.connection() as conn, conn.cursor() as cur: + cur.execute(f"DROP TABLE IF EXISTS {custom_table}") + + +def test_collation_prevents_case_collapse(isolated_pool) -> None: + """With COLLATE "C", keyid "Key-A" and "key-a" are distinct slots. + + Would be a real problem on locales where default collation case-folds: + a buyer with kid "Key-A" and an attacker with kid "key-a" would share + a replay cache, opening cross-tenant nonce interference. 
+ """ + store = _store(isolated_pool) + store.remember("Key-A", "n", ttl_seconds=60) + # Same nonce, case-variant kid. With "C" collation these are distinct. + assert store.seen("key-a", "n") is False + # And at_capacity for the other case shouldn't see the first one either. + assert store.live_count("key-a") == 0 + assert store.live_count("Key-A") == 1 diff --git a/tests/conformance/signing/test_pg_replay_store_e2e.py b/tests/conformance/signing/test_pg_replay_store_e2e.py new file mode 100644 index 000000000..2f078b19f --- /dev/null +++ b/tests/conformance/signing/test_pg_replay_store_e2e.py @@ -0,0 +1,300 @@ +"""Full-wire end-to-end: signed HTTP request → Starlette verifier → PgReplayStore. + +This is the shape an actual integrator ships: a FastAPI/Starlette +server running ``verify_starlette_request`` with a ``PgReplayStore``, +receiving signed requests from a buyer-side client. Single-process +here (via ``httpx.ASGITransport``), but the wire-level contract is +identical to what a load-balanced multi-worker deployment sees — +including the cross-instance replay defense the Postgres store +exists to provide. + +Scenarios covered: + +1. **Happy path** — signed request with fresh nonce → 200. +2. **Replay** — same signed headers sent again → 401 with + ``WWW-Authenticate: Signature error="request_signature_replayed"``. +3. **Fresh nonce after replay** — different signed request → 200. +4. **Simulated second worker** — second ``PgReplayStore`` instance + on the same pool sees the first instance's ``remember``, rejects + a replay that landed on the "other" worker. + +Requires ``ADCP_PG_TEST_URL``; skipped otherwise (same gate as the +rest of the pg suite). 
+""" + +from __future__ import annotations + +import os +import secrets +import time +from typing import Any + +import httpx +import pytest + +psycopg = pytest.importorskip("psycopg") +psycopg_pool = pytest.importorskip("psycopg_pool") + +TEST_URL = os.environ.get("ADCP_PG_TEST_URL") +if not TEST_URL: + pytest.skip( + "ADCP_PG_TEST_URL not set — skipping PgReplayStore e2e tests", + allow_module_level=True, + ) + +from cryptography.hazmat.primitives.asymmetric import ed25519 # noqa: E402 +from starlette.applications import Starlette # noqa: E402 +from starlette.requests import Request # noqa: E402 +from starlette.responses import JSONResponse # noqa: E402 +from starlette.routing import Route # noqa: E402 + +from adcp.signing import ( # noqa: E402 + PgReplayStore, + SignatureVerificationError, + StaticJwksResolver, + VerifierCapability, + VerifyOptions, + b64url_encode, + sign_request, + unauthorized_response_headers, + verify_starlette_request, +) + +# -- fixtures --------------------------------------------------------- + + +@pytest.fixture() +def isolated_table() -> str: + """Unique per-test table so parallel runs and reruns don't collide.""" + table = f"test_e2e_replay_{secrets.token_hex(6)}" + with psycopg_pool.ConnectionPool(TEST_URL, min_size=1, max_size=2) as pool: + with pool.connection() as conn, conn.cursor() as cur: + cur.execute( + f""" + CREATE TABLE {table} ( + keyid TEXT COLLATE "C" NOT NULL, + nonce TEXT COLLATE "C" NOT NULL, + expires_at TIMESTAMPTZ NOT NULL, + PRIMARY KEY (keyid, nonce) + ) + """ + ) + yield table + with pool.connection() as conn, conn.cursor() as cur: + cur.execute(f"DROP TABLE IF EXISTS {table}") + + +@pytest.fixture() +def signing_keypair() -> tuple[ed25519.Ed25519PrivateKey, dict[str, Any]]: + private = ed25519.Ed25519PrivateKey.generate() + jwk = { + "kty": "OKP", + "crv": "Ed25519", + "alg": "EdDSA", + "use": "sig", + "key_ops": ["verify"], + "adcp_use": "request-signing", + "kid": "e2e-buyer", + "x": 
b64url_encode(private.public_key().public_bytes_raw()), + } + return private, jwk + + +# -- helpers ---------------------------------------------------------- + + +def _build_app(*, pool: psycopg_pool.ConnectionPool, table: str, jwk: dict[str, Any]) -> Starlette: + """Build a Starlette app that runs the verifier with PgReplayStore. + + Identical to the shape an integrator writes — no special-case code + beyond wiring ``replay_store=PgReplayStore(...)`` into + ``VerifyOptions``. + """ + replay_store = PgReplayStore(pool=pool, table_name=table) + jwks_resolver = StaticJwksResolver({"keys": [jwk]}) + + async def create_media_buy(request: Request) -> JSONResponse: + options = VerifyOptions( + now=float(int(time.time())), + capability=VerifierCapability( + covers_content_digest="either", + required_for=frozenset({"create_media_buy"}), + ), + operation="create_media_buy", + jwks_resolver=jwks_resolver, + replay_store=replay_store, + ) + try: + signer = await verify_starlette_request(request, options=options) + except SignatureVerificationError as exc: + return JSONResponse( + {"error": exc.code, "step": exc.step, "message": str(exc)}, + status_code=401, + headers=unauthorized_response_headers(exc), + ) + return JSONResponse({"verified_key_id": signer.key_id, "status": "accepted"}) + + return Starlette(routes=[Route("/adcp/create_media_buy", create_media_buy, methods=["POST"])]) + + +# -- tests ------------------------------------------------------------ + + +async def test_signed_request_verifies_end_to_end( + isolated_table: str, + signing_keypair: tuple[ed25519.Ed25519PrivateKey, dict[str, Any]], +) -> None: + """Happy path: sign → POST → Starlette verifies → 200.""" + private_key, jwk = signing_keypair + + with psycopg_pool.ConnectionPool(TEST_URL, min_size=1, max_size=4) as pool: + app = _build_app(pool=pool, table=isolated_table, jwk=jwk) + body = b'{"plan_id":"p1"}' + signed = sign_request( + method="POST", + url="http://test/adcp/create_media_buy", + 
headers={"Content-Type": "application/json"}, + body=body, + private_key=private_key, + key_id="e2e-buyer", + alg="ed25519", + ) + headers = {"Content-Type": "application/json", **signed.as_dict()} + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + resp = await client.post("/adcp/create_media_buy", content=body, headers=headers) + + assert resp.status_code == 200, resp.text + assert resp.json()["verified_key_id"] == "e2e-buyer" + + +async def test_replay_rejected_with_request_signature_replayed( + isolated_table: str, + signing_keypair: tuple[ed25519.Ed25519PrivateKey, dict[str, Any]], +) -> None: + """The load-bearing property: the second identical request → 401 replayed. + + Without the replay store this would succeed twice; with it the + second attempt must return the spec's ``request_signature_replayed`` + code and the WWW-Authenticate header. + """ + private_key, jwk = signing_keypair + + with psycopg_pool.ConnectionPool(TEST_URL, min_size=1, max_size=4) as pool: + app = _build_app(pool=pool, table=isolated_table, jwk=jwk) + body = b'{"plan_id":"p1"}' + signed = sign_request( + method="POST", + url="http://test/adcp/create_media_buy", + headers={"Content-Type": "application/json"}, + body=body, + private_key=private_key, + key_id="e2e-buyer", + alg="ed25519", + ) + headers = {"Content-Type": "application/json", **signed.as_dict()} + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + # First pass: accepted. + r1 = await client.post("/adcp/create_media_buy", content=body, headers=headers) + assert r1.status_code == 200, r1.text + + # Replay the same headers → the spec says reject with 401 + + # WWW-Authenticate: Signature error="request_signature_replayed". 
+ r2 = await client.post("/adcp/create_media_buy", content=body, headers=headers) + assert r2.status_code == 401, r2.text + assert r2.json()["error"] == "request_signature_replayed" + www_auth = r2.headers.get("www-authenticate", "") + assert 'Signature error="request_signature_replayed"' in www_auth + + +async def test_fresh_nonce_after_replay_accepted( + isolated_table: str, + signing_keypair: tuple[ed25519.Ed25519PrivateKey, dict[str, Any]], +) -> None: + """After a replay rejection, a newly-signed request MUST be accepted — + the replay store locks one (keyid, nonce), not the whole keyid. + """ + private_key, jwk = signing_keypair + + with psycopg_pool.ConnectionPool(TEST_URL, min_size=1, max_size=4) as pool: + app = _build_app(pool=pool, table=isolated_table, jwk=jwk) + body = b'{"plan_id":"p1"}' + + def _sign() -> dict[str, str]: + signed = sign_request( + method="POST", + url="http://test/adcp/create_media_buy", + headers={"Content-Type": "application/json"}, + body=body, + private_key=private_key, + key_id="e2e-buyer", + alg="ed25519", + ) + return {"Content-Type": "application/json", **signed.as_dict()} + + transport = httpx.ASGITransport(app=app) + async with httpx.AsyncClient(transport=transport, base_url="http://test") as client: + h1 = _sign() + r1 = await client.post("/adcp/create_media_buy", content=body, headers=h1) + assert r1.status_code == 200 + + # Replay of h1 rejected. + r2 = await client.post("/adcp/create_media_buy", content=body, headers=h1) + assert r2.status_code == 401 + + # Fresh signature (new nonce under the hood) accepted. 
+ h2 = _sign() + assert h2["Signature-Input"] != h1["Signature-Input"] # sanity + r3 = await client.post("/adcp/create_media_buy", content=body, headers=h2) + assert r3.status_code == 200, r3.text + + +async def test_cross_instance_replay_rejection( + isolated_table: str, + signing_keypair: tuple[ed25519.Ed25519PrivateKey, dict[str, Any]], +) -> None: + """Sim two workers sharing a pool: worker A accepts, worker B rejects the replay. + + This is the core reason Postgres exists in this module — the + in-memory store can't enforce this. Worker B holds a SEPARATE + ``PgReplayStore`` instance backed by the same pool, and still sees + worker A's ``remember`` via the shared table. + """ + private_key, jwk = signing_keypair + + with psycopg_pool.ConnectionPool(TEST_URL, min_size=2, max_size=6) as pool: + # Two independent Starlette apps, each with its own + # PgReplayStore instance but sharing the DB-side table. + app_a = _build_app(pool=pool, table=isolated_table, jwk=jwk) + app_b = _build_app(pool=pool, table=isolated_table, jwk=jwk) + + body = b'{"plan_id":"cross"}' + signed = sign_request( + method="POST", + url="http://test/adcp/create_media_buy", + headers={"Content-Type": "application/json"}, + body=body, + private_key=private_key, + key_id="e2e-buyer", + alg="ed25519", + ) + headers = {"Content-Type": "application/json", **signed.as_dict()} + + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app_a), base_url="http://test" + ) as client_a: + r_a = await client_a.post("/adcp/create_media_buy", content=body, headers=headers) + assert r_a.status_code == 200, r_a.text + + async with httpx.AsyncClient( + transport=httpx.ASGITransport(app=app_b), base_url="http://test" + ) as client_b: + # Worker B receives the replay. Its own PgReplayStore instance + # has never called remember(), but the DB-side row from + # worker A is visible, so seen() returns True → 401. 
+ r_b = await client_b.post("/adcp/create_media_buy", content=body, headers=headers) + assert r_b.status_code == 401, r_b.text + assert r_b.json()["error"] == "request_signature_replayed"