Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions docs/usage/extensions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,38 @@
Extensions
----------

Request extensions
^^^^^^^^^^^^^^^^^^

The `GraphQL over HTTP spec <https://github.com/graphql/graphql-over-http>`_
defines an optional :code:`extensions` field on requests. This is sent as a
top-level key in the request payload alongside :code:`query`, :code:`variables`,
and :code:`operationName`.

You can use this to pass protocol extensions such as
`trusted documents <https://graphql.org/learn/security/#trusted-documents>`_:

.. code-block:: python

from gql import Client, GraphQLRequest
from gql.transport.aiohttp import AIOHTTPTransport

transport = AIOHTTPTransport(url="https://example.com/graphql")

async with Client(transport=transport) as session:

request = GraphQLRequest(
"query { viewer { name } }",
extensions={
"document-id": "155d6e8f5545...",
},
)

result = await session.execute(request)

Response extensions
^^^^^^^^^^^^^^^^^^^

When you execute (or subscribe) GraphQL requests, the server will send
responses which may have 3 fields:

Expand Down
11 changes: 11 additions & 0 deletions gql/graphql_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ def __init__(
*,
variable_values: Optional[Dict[str, Any]] = None,
operation_name: Optional[str] = None,
extensions: Optional[Dict[str, Any]] = None,
):
"""Initialize a GraphQL request.

Expand All @@ -21,6 +22,9 @@ def __init__(
:param variable_values: Dictionary of input parameters (Default: None).
:param operation_name: Name of the operation that shall be executed.
Only required in multi-operation documents (Default: None).
:param extensions: Dictionary of protocol extensions (Default: None).
This is passed as the top-level "extensions" key in the request
payload, as defined in the GraphQL over HTTP spec.
:return: a :class:`GraphQLRequest <gql.GraphQLRequest>`
which can be later executed or subscribed by a
:class:`Client <gql.client.Client>`, by an
Expand All @@ -42,9 +46,12 @@ def __init__(
variable_values = request.variable_values
if operation_name is None:
operation_name = request.operation_name
if extensions is None:
extensions = request.extensions

self.variable_values: Optional[Dict[str, Any]] = variable_values
self.operation_name: Optional[str] = operation_name
self.extensions: Optional[Dict[str, Any]] = extensions

def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest":

Expand All @@ -61,6 +68,7 @@ def serialize_variable_values(self, schema: GraphQLSchema) -> "GraphQLRequest":
operation_name=self.operation_name,
),
operation_name=self.operation_name,
extensions=self.extensions,
)

@property
Expand All @@ -74,6 +82,9 @@ def payload(self) -> Dict[str, Any]:
if self.variable_values:
payload["variables"] = self.variable_values

if self.extensions:
payload["extensions"] = self.extensions

return payload

def __str__(self):
Expand Down
39 changes: 38 additions & 1 deletion tests/test_aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import pytest

from gql import Client, FileVar, gql
from gql import Client, FileVar, GraphQLRequest, gql
from gql.cli import get_parser, main
from gql.transport.exceptions import (
TransportAlreadyConnected,
Expand Down Expand Up @@ -87,6 +87,43 @@ async def handler(request):
assert transport.response_headers["dummy"] == "test1234"


@pytest.mark.asyncio
async def test_aiohttp_request_extensions(aiohttp_server):
from aiohttp import web

from gql.transport.aiohttp import AIOHTTPTransport

extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}

async def handler(request):
body = await request.json()
assert body["extensions"] == extensions
return web.Response(
text=query1_server_answer,
content_type="application/json",
)

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")

transport = AIOHTTPTransport(url=url, timeout=10)

request = GraphQLRequest(query1_str, extensions=extensions)

async with Client(transport=transport) as session:

# execute
result = await session.execute(request)
assert result["continents"][0]["code"] == "AF"

# subscribe
async for result in session.subscribe(request):
assert result["continents"][0]["code"] == "AF"


@pytest.mark.asyncio
async def test_aiohttp_ignore_backend_content_type(aiohttp_server):
from aiohttp import web
Expand Down
32 changes: 32 additions & 0 deletions tests/test_aiohttp_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,38 @@ async def handler(request):
assert transport.response_headers["dummy"] == "test1234"


@pytest.mark.asyncio
async def test_aiohttp_batch_request_extensions(aiohttp_server):
from aiohttp import web

from gql.transport.aiohttp import AIOHTTPTransport

extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}

async def handler(request):
body = await request.json()
assert isinstance(body, list)
assert body[0]["extensions"] == extensions
return web.Response(
text=query1_server_answer_list,
content_type="application/json",
)

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")

transport = AIOHTTPTransport(url=url, timeout=10)

async with Client(transport=transport) as session:

query = [GraphQLRequest(query1_str, extensions=extensions)]
results = await session.execute_batch(query)
assert results[0]["continents"][0]["code"] == "AF"


@pytest.mark.asyncio
async def test_aiohttp_batch_query_auto_batch_enabled(aiohttp_server, run_sync_test):
from aiohttp import web
Expand Down
33 changes: 32 additions & 1 deletion tests/test_aiohttp_websocket_subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from graphql import ExecutionResult
from parse import search

from gql import Client, gql
from gql import Client, GraphQLRequest, gql
from gql.client import AsyncClientSession
from gql.transport.exceptions import TransportConnectionFailed, TransportServerError

Expand Down Expand Up @@ -460,6 +460,37 @@ async def test_aiohttp_websocket_subscription_with_operation_name(
assert '"operationName": "CountdownSubscription"' in logged_messages[0]


@pytest.mark.asyncio
@pytest.mark.parametrize("server", [server_countdown], indirect=True)
@pytest.mark.parametrize("subscription_str", [countdown_subscription_str])
async def test_aiohttp_websocket_subscription_with_extensions(
aiohttp_client_and_server, subscription_str
):

session, server = aiohttp_client_and_server

count = 10
request = GraphQLRequest(
subscription_str.format(count=count),
extensions={"persistedQuery": {"version": 1, "sha256Hash": "abc123"}},
)

async for result in session.subscribe(request):

number = result["number"]
print(f"Number received: {number}")

assert number == count
count -= 1

assert count == -1

message = json.loads(logged_messages[0])
assert message["payload"]["extensions"] == {
"persistedQuery": {"version": 1, "sha256Hash": "abc123"}
}


WITH_KEEPALIVE = True


Expand Down
28 changes: 28 additions & 0 deletions tests/test_graphql_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,3 +236,31 @@ def test_graphql_request_init_with_graphql_request():
assert request_1.variable_values["money"] == money_value_1
assert request_2.variable_values["money"] == money_value_1
assert request_3.variable_values["money"] == money_value_2


def test_graphql_request_extensions():
extensions_1 = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}
extensions_2 = {"custom": "value"}
money_value = Money(10, "DM")

assert "extensions" not in GraphQLRequest("{balance}").payload

request_1 = GraphQLRequest("{balance}", extensions=extensions_1)
assert request_1.payload["extensions"] == extensions_1

# Copied from another GraphQLRequest
request_2 = GraphQLRequest(request_1)
assert request_2.extensions == extensions_1

# Explicit extensions override the copied value
request_3 = GraphQLRequest(request_1, extensions=extensions_2)
assert request_3.extensions == extensions_2

# Preserved through serialize_variable_values
request_4 = GraphQLRequest(
"query myquery($money: Money) {toEuros(money: $money)}",
variable_values={"money": money_value},
extensions=extensions_1,
)
serialized = request_4.serialize_variable_values(schema)
assert serialized.extensions == extensions_1
36 changes: 35 additions & 1 deletion tests/test_httpx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import pytest

from gql import Client, FileVar, gql
from gql import Client, FileVar, GraphQLRequest, gql
from gql.transport.exceptions import (
TransportAlreadyConnected,
TransportClosed,
Expand Down Expand Up @@ -84,6 +84,40 @@ def test_code():
await run_sync_test(server, test_code)


@pytest.mark.aiohttp
@pytest.mark.asyncio
async def test_httpx_request_extensions(aiohttp_server, run_sync_test):
from aiohttp import web

from gql.transport.httpx import HTTPXTransport

extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}

async def handler(request):
body = await request.json()
assert body["extensions"] == extensions
return web.Response(
text=query1_server_answer,
content_type="application/json",
)

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = str(server.make_url("/"))

def test_code():
transport = HTTPXTransport(url=url)

with Client(transport=transport) as session:
request = GraphQLRequest(query1_str, extensions=extensions)
result = session.execute(request)
assert result["continents"][0]["code"] == "AF"

await run_sync_test(server, test_code)


@pytest.mark.aiohttp
@pytest.mark.asyncio
@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"])
Expand Down
33 changes: 33 additions & 0 deletions tests/test_httpx_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,39 @@ def test_code():
await run_sync_test(server, test_code)


@pytest.mark.aiohttp
@pytest.mark.asyncio
async def test_httpx_batch_request_extensions(aiohttp_server):
from aiohttp import web

from gql.transport.httpx import HTTPXAsyncTransport

extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}

async def handler(request):
body = await request.json()
assert isinstance(body, list)
assert body[0]["extensions"] == extensions
return web.Response(
text=query1_server_answer_list,
content_type="application/json",
)

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = str(server.make_url("/"))

transport = HTTPXAsyncTransport(url=url, timeout=10)

async with Client(transport=transport) as session:

query = [GraphQLRequest(query1_str, extensions=extensions)]
results = await session.execute_batch(query)
assert results[0]["continents"][0]["code"] == "AF"


@pytest.mark.aiohttp
@pytest.mark.asyncio
async def test_httpx_async_batch_query_without_session(aiohttp_server, run_sync_test):
Expand Down
36 changes: 35 additions & 1 deletion tests/test_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import pytest

from gql import Client, FileVar, gql
from gql import Client, FileVar, GraphQLRequest, gql
from gql.transport.exceptions import (
TransportAlreadyConnected,
TransportClosed,
Expand Down Expand Up @@ -85,6 +85,40 @@ def test_code():
await run_sync_test(server, test_code)


@pytest.mark.aiohttp
@pytest.mark.asyncio
async def test_requests_request_extensions(aiohttp_server, run_sync_test):
from aiohttp import web

from gql.transport.requests import RequestsHTTPTransport

extensions = {"persistedQuery": {"version": 1, "sha256Hash": "abc123"}}

async def handler(request):
body = await request.json()
assert body["extensions"] == extensions
return web.Response(
text=query1_server_answer,
content_type="application/json",
)

app = web.Application()
app.router.add_route("POST", "/", handler)
server = await aiohttp_server(app)

url = server.make_url("/")

def test_code():
transport = RequestsHTTPTransport(url=url)

with Client(transport=transport) as session:
request = GraphQLRequest(query1_str, extensions=extensions)
result = session.execute(request)
assert result["continents"][0]["code"] == "AF"

await run_sync_test(server, test_code)


@pytest.mark.aiohttp
@pytest.mark.asyncio
@pytest.mark.parametrize("verify_https", ["disabled", "cert_provided"])
Expand Down
Loading
Loading