diff --git a/docs/usage/extensions.rst b/docs/usage/extensions.rst index ec413656..72924711 100644 --- a/docs/usage/extensions.rst +++ b/docs/usage/extensions.rst @@ -3,6 +3,38 @@ Extensions ---------- +Request extensions +^^^^^^^^^^^^^^^^^^ + +The `GraphQL over HTTP spec `_ +defines an optional :code:`extensions` field on requests. This is sent as a +top-level key in the request payload alongside :code:`query`, :code:`variables`, +and :code:`operationName`. + +You can use this to pass protocol extensions such as +`trusted documents `_: + +.. code-block:: python + + from gql import Client, GraphQLRequest + from gql.transport.aiohttp import AIOHTTPTransport + + transport = AIOHTTPTransport(url="https://example.com/graphql") + + async with Client(transport=transport) as session: + + request = GraphQLRequest( + "query { viewer { name } }", + extensions={ + "document-id": "155d6e8f5545...", + }, + ) + + result = await session.execute(request) + +Response extensions +^^^^^^^^^^^^^^^^^^^ + When you execute (or subscribe) GraphQL requests, the server will send responses which may have 3 fields: diff --git a/gql/graphql_request.py b/gql/graphql_request.py index 5e6f3ee4..e8e4fb7d 100644 --- a/gql/graphql_request.py +++ b/gql/graphql_request.py @@ -13,6 +13,7 @@ def __init__( *, variable_values: Optional[Dict[str, Any]] = None, operation_name: Optional[str] = None, + extensions: Optional[Dict[str, Any]] = None, ): """Initialize a GraphQL request. @@ -21,6 +22,9 @@ def __init__( :param variable_values: Dictionary of input parameters (Default: None). :param operation_name: Name of the operation that shall be executed. Only required in multi-operation documents (Default: None). + :param extensions: Dictionary of protocol extensions (Default: None). + This is passed as the top-level "extensions" key in the request + payload, as defined in the GraphQL over HTTP spec. :return: a :class:`GraphQLRequest ` which can be later executed or subscribed by a :class:`Client `, by an @@ -42,9 +46,12 @@ def __init__( variable_values = request.variable_values if operation_name is None: operation_name = request.operation_name + if extensions is None: + extensions = request.extensions self.variable_values: Optional[Dict[str, Any]] = variable_values self.operation_name: Optional[str] = operation_name + self.extensions: Optional[Dict[str, Any]] = extensions def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest": @@ -61,6 +68,7 @@ def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest": operation_name=self.operation_name, ), operation_name=self.operation_name, + extensions=self.extensions, ) @property @@ -74,6 +82,9 @@ def payload(self) -> Dict[str, Any]: if self.variable_values: payload["variables"] = self.variable_values + if self.extensions: + payload["extensions"] = self.extensions + return payload def __str__(self): diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index 102fe3f2..00bd8a0f 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -6,7 +6,7 @@ import pytest -from gql import Client, FileVar, gql +from gql import Client, FileVar, GraphQLRequest, gql from gql.cli import get_parser, main from gql.transport.exceptions import ( TransportAlreadyConnected, @@ -87,6 +87,43 @@ async def handler(request): assert transport.response_headers["dummy"] == "test1234" +@pytest.mark.asyncio +async def test_aiohttp_request_extensions(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}} + + async def handler(request): + body = await request.json() + assert body["extensions"] == extensions + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + request = GraphQLRequest(query1_str, extensions=extensions) + + async with Client(transport=transport) as session: + + # execute + result = await session.execute(request) + assert result["continents"][0]["code"] == "AF" + + # subscribe + async for result in session.subscribe(request): + assert result["continents"][0]["code"] == "AF" + + @pytest.mark.asyncio async def test_aiohttp_ignore_backend_content_type(aiohttp_server): from aiohttp import web diff --git a/tests/test_aiohttp_batch.py b/tests/test_aiohttp_batch.py index ad9924a0..37d8f64e 100644 --- a/tests/test_aiohttp_batch.py +++ b/tests/test_aiohttp_batch.py @@ -89,6 +89,38 @@ async def handler(request): assert transport.response_headers["dummy"] == "test1234" +@pytest.mark.asyncio +async def test_aiohttp_batch_request_extensions(aiohttp_server): + from aiohttp import web + + from gql.transport.aiohttp import AIOHTTPTransport + + extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}} + + async def handler(request): + body = await request.json() + assert isinstance(body, list) + assert body[0]["extensions"] == extensions + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + transport = AIOHTTPTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(query1_str, extensions=extensions)] + results = await session.execute_batch(query) + assert results[0]["continents"][0]["code"] == "AF" + + @pytest.mark.asyncio async def test_aiohttp_batch_query_auto_batch_enabled(aiohttp_server, run_sync_test): from aiohttp import web diff --git a/tests/test_aiohttp_websocket_subscription.py b/tests/test_aiohttp_websocket_subscription.py index f06046df..b099a4a9 100644 --- a/tests/test_aiohttp_websocket_subscription.py +++ b/tests/test_aiohttp_websocket_subscription.py @@ -8,7 +8,7 @@ from graphql import ExecutionResult from parse import search -from gql import Client, gql +from gql import Client, GraphQLRequest, gql from gql.client import AsyncClientSession from gql.transport.exceptions import TransportConnectionFailed, TransportServerError @@ -460,6 +460,37 @@ async def test_aiohttp_websocket_subscription_with_operation_name( assert '"operationName": "CountdownSubscription"' in logged_messages[0] +@pytest.mark.asyncio +@pytest.mark.parametrize("server", [server_countdown], indirect=True) +@pytest.mark.parametrize("subscription_str", [countdown_subscription_str]) +async def test_aiohttp_websocket_subscription_with_extensions( + aiohttp_client_and_server, subscription_str +): + + session, server = aiohttp_client_and_server + + count = 10 + request = GraphQLRequest( + subscription_str.format(count=count), + extensions={"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}, + ) + + async for result in session.subscribe(request): + + number = result["number"] + print(f"Number received: {number}") + + assert number == count + count -= 1 + + assert count == -1 + + message = json.loads(logged_messages[0]) + assert message["payload"]["extensions"] == { + "persistedQuery": {"version": 1, "sha256Hash": "abc123"} + } + + WITH_KEEPALIVE = True diff --git a/tests/test_graphql_request.py b/tests/test_graphql_request.py index ea255c7d..d6ba30d2 100644 --- a/tests/test_graphql_request.py +++ b/tests/test_graphql_request.py @@ -236,3 +236,31 @@ def test_graphql_request_init_with_graphql_request(): assert request_1.variable_values["money"] == money_value_1 assert request_2.variable_values["money"] == money_value_1 assert request_3.variable_values["money"] == money_value_2 + + +def test_graphql_request_extensions(): + extensions_1 = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}} + extensions_2 = {"custom": "value"} + money_value = Money(10, "DM") + + assert "extensions" not in GraphQLRequest("{balance}").payload + + request_1 = GraphQLRequest("{balance}", extensions=extensions_1) + assert request_1.payload["extensions"] == extensions_1 + + # Copied from another GraphQLRequest + request_2 = GraphQLRequest(request_1) + assert request_2.extensions == extensions_1 + + # Explicit extensions override the copied value + request_3 = GraphQLRequest(request_1, extensions=extensions_2) + assert request_3.extensions == extensions_2 + + # Preserved through serialize_variable_values + request_4 = GraphQLRequest( + "query myquery($money: Money) {toEuros(money: $money)}", + variable_values={"money": money_value}, + extensions=extensions_1, + ) + serialized = request_4.serialize_variable_values(schema) + assert serialized.extensions == extensions_1 diff --git a/tests/test_httpx.py b/tests/test_httpx.py index 0411294b..aa25edc2 100644 --- a/tests/test_httpx.py +++ b/tests/test_httpx.py @@ -3,7 +3,7 @@ import pytest -from gql import Client, FileVar, gql +from gql import Client, FileVar, GraphQLRequest, gql from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -84,6 +84,40 @@ def test_code(): await run_sync_test(server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_request_extensions(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.httpx import HTTPXTransport + + extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}} + + async def handler(request): + body = await request.json() + assert body["extensions"] == extensions + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + def test_code(): + transport = HTTPXTransport(url=url) + + with Client(transport=transport) as session: + request = GraphQLRequest(query1_str, extensions=extensions) + result = session.execute(request) + assert result["continents"][0]["code"] == "AF" + + await run_sync_test(server, test_code) + + @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) diff --git a/tests/test_httpx_batch.py b/tests/test_httpx_batch.py index 63472dab..2f05df99 100644 --- a/tests/test_httpx_batch.py +++ b/tests/test_httpx_batch.py @@ -118,6 +118,39 @@ def test_code(): await run_sync_test(server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_httpx_batch_request_extensions(aiohttp_server): + from aiohttp import web + + from gql.transport.httpx import HTTPXAsyncTransport + + extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}} + + async def handler(request): + body = await request.json() + assert isinstance(body, list) + assert body[0]["extensions"] == extensions + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = str(server.make_url("/")) + + transport = HTTPXAsyncTransport(url=url, timeout=10) + + async with Client(transport=transport) as session: + + query = [GraphQLRequest(query1_str, extensions=extensions)] + results = await session.execute_batch(query) + assert results[0]["continents"][0]["code"] == "AF" + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_httpx_async_batch_query_without_session(aiohttp_server, run_sync_test): diff --git a/tests/test_requests.py b/tests/test_requests.py index fe57f5e3..31399f70 100644 --- a/tests/test_requests.py +++ b/tests/test_requests.py @@ -4,7 +4,7 @@ import pytest -from gql import Client, FileVar, gql +from gql import Client, FileVar, GraphQLRequest, gql from gql.transport.exceptions import ( TransportAlreadyConnected, TransportClosed, @@ -85,6 +85,40 @@ def test_code(): await run_sync_test(server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_request_extensions(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.requests import RequestsHTTPTransport + + extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}} + + async def handler(request): + body = await request.json() + assert body["extensions"] == extensions + return web.Response( + text=query1_server_answer, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with Client(transport=transport) as session: + request = GraphQLRequest(query1_str, extensions=extensions) + result = session.execute(request) + assert result["continents"][0]["code"] == "AF" + + await run_sync_test(server, test_code) + + @pytest.mark.aiohttp @pytest.mark.asyncio @pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"]) diff --git a/tests/test_requests_batch.py b/tests/test_requests_batch.py index 7131c2da..b9732639 100644 --- a/tests/test_requests_batch.py +++ b/tests/test_requests_batch.py @@ -90,6 +90,41 @@ def test_code(): await run_sync_test(server, test_code) +@pytest.mark.aiohttp +@pytest.mark.asyncio +async def test_requests_batch_request_extensions(aiohttp_server, run_sync_test): + from aiohttp import web + + from gql.transport.requests import RequestsHTTPTransport + + extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}} + + async def handler(request): + body = await request.json() + assert isinstance(body, list) + assert body[0]["extensions"] == extensions + return web.Response( + text=query1_server_answer_list, + content_type="application/json", + ) + + app = web.Application() + app.router.add_route("POST", "/", handler) + server = await aiohttp_server(app) + + url = server.make_url("/") + + def test_code(): + transport = RequestsHTTPTransport(url=url) + + with Client(transport=transport) as session: + query = [GraphQLRequest(query1_str, extensions=extensions)] + results = session.execute_batch(query) + assert results[0]["continents"][0]["code"] == "AF" + + await run_sync_test(server, test_code) + + @pytest.mark.aiohttp @pytest.mark.asyncio async def test_requests_query_auto_batch_enabled(aiohttp_server, run_sync_test):