Skip to content

Commit 3818653

Browse files
committed
Allow injecting custom data to custom execution context (#226)
1 parent c685d84 commit 3818653

File tree

2 files changed

+29
-15
lines changed

2 files changed

+29
-15
lines changed

Diff for: src/graphql/execution/execute.py

+14-14
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from asyncio import ensure_future, gather, shield, wait_for
66
from contextlib import suppress
7+
from copy import copy
78
from typing import (
89
Any,
910
AsyncGenerator,
@@ -219,6 +220,7 @@ def build(
219220
subscribe_field_resolver: GraphQLFieldResolver | None = None,
220221
middleware: Middleware | None = None,
221222
is_awaitable: Callable[[Any], bool] | None = None,
223+
**custom_args: Any,
222224
) -> list[GraphQLError] | ExecutionContext:
223225
"""Build an execution context
224226
@@ -292,24 +294,14 @@ def build(
292294
IncrementalPublisher(),
293295
middleware_manager,
294296
is_awaitable,
297+
**custom_args,
295298
)
296299

297300
def build_per_event_execution_context(self, payload: Any) -> ExecutionContext:
298301
"""Create a copy of the execution context for usage with subscribe events."""
299-
return self.__class__(
300-
self.schema,
301-
self.fragments,
302-
payload,
303-
self.context_value,
304-
self.operation,
305-
self.variable_values,
306-
self.field_resolver,
307-
self.type_resolver,
308-
self.subscribe_field_resolver,
309-
self.incremental_publisher,
310-
self.middleware_manager,
311-
self.is_awaitable,
312-
)
302+
context = copy(self)
303+
context.root_value = payload
304+
return context
313305

314306
def execute_operation(
315307
self, initial_result_record: InitialResultRecord
@@ -1709,6 +1701,7 @@ def execute(
17091701
middleware: Middleware | None = None,
17101702
execution_context_class: type[ExecutionContext] | None = None,
17111703
is_awaitable: Callable[[Any], bool] | None = None,
1704+
**custom_context_args: Any,
17121705
) -> AwaitableOrValue[ExecutionResult]:
17131706
"""Execute a GraphQL operation.
17141707
@@ -1741,6 +1734,7 @@ def execute(
17411734
middleware,
17421735
execution_context_class,
17431736
is_awaitable,
1737+
**custom_context_args,
17441738
)
17451739
if isinstance(result, ExecutionResult):
17461740
return result
@@ -1769,6 +1763,7 @@ def experimental_execute_incrementally(
17691763
middleware: Middleware | None = None,
17701764
execution_context_class: type[ExecutionContext] | None = None,
17711765
is_awaitable: Callable[[Any], bool] | None = None,
1766+
**custom_context_args: Any,
17721767
) -> AwaitableOrValue[ExecutionResult | ExperimentalIncrementalExecutionResults]:
17731768
"""Execute GraphQL operation incrementally (internal implementation).
17741769
@@ -1797,6 +1792,7 @@ def experimental_execute_incrementally(
17971792
subscribe_field_resolver,
17981793
middleware,
17991794
is_awaitable,
1795+
**custom_context_args,
18001796
)
18011797

18021798
# Return early errors if execution context failed.
@@ -2127,6 +2123,7 @@ def subscribe(
21272123
subscribe_field_resolver: GraphQLFieldResolver | None = None,
21282124
execution_context_class: type[ExecutionContext] | None = None,
21292125
middleware: MiddlewareManager | None = None,
2126+
**custom_context_args: Any,
21302127
) -> AwaitableOrValue[AsyncIterator[ExecutionResult] | ExecutionResult]:
21312128
"""Create a GraphQL subscription.
21322129
@@ -2167,6 +2164,7 @@ def subscribe(
21672164
type_resolver,
21682165
subscribe_field_resolver,
21692166
middleware=middleware,
2167+
**custom_context_args,
21702168
)
21712169

21722170
# Return early errors if execution context failed.
@@ -2202,6 +2200,7 @@ def create_source_event_stream(
22022200
type_resolver: GraphQLTypeResolver | None = None,
22032201
subscribe_field_resolver: GraphQLFieldResolver | None = None,
22042202
execution_context_class: type[ExecutionContext] | None = None,
2203+
**custom_context_args: Any,
22052204
) -> AwaitableOrValue[AsyncIterable[Any] | ExecutionResult]:
22062205
"""Create source event stream
22072206
@@ -2238,6 +2237,7 @@ def create_source_event_stream(
22382237
field_resolver,
22392238
type_resolver,
22402239
subscribe_field_resolver,
2240+
**custom_context_args,
22412241
)
22422242

22432243
# Return early errors if execution context failed.

Diff for: tests/execution/test_customize.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,10 @@ def uses_a_custom_execution_context_class():
4343
)
4444

4545
class TestExecutionContext(ExecutionContext):
46+
def __init__(self, *args, **kwargs):
47+
assert kwargs.pop("custom_arg", None) == "baz"
48+
super().__init__(*args, **kwargs)
49+
4650
def execute_field(
4751
self,
4852
parent_type,
@@ -62,7 +66,12 @@ def execute_field(
6266
)
6367
return result * 2 # type: ignore
6468

65-
assert execute(schema, query, execution_context_class=TestExecutionContext) == (
69+
assert execute(
70+
schema,
71+
query,
72+
execution_context_class=TestExecutionContext,
73+
custom_arg="baz",
74+
) == (
6675
{"foo": "barbar"},
6776
None,
6877
)
@@ -101,6 +110,10 @@ async def custom_foo():
101110
@pytest.mark.asyncio
102111
async def uses_a_custom_execution_context_class():
103112
class TestExecutionContext(ExecutionContext):
113+
def __init__(self, *args, **kwargs):
114+
assert kwargs.pop("custom_arg", None) == "baz"
115+
super().__init__(*args, **kwargs)
116+
104117
def build_resolve_info(self, *args, **kwargs):
105118
resolve_info = super().build_resolve_info(*args, **kwargs)
106119
resolve_info.context["foo"] = "bar"
@@ -132,6 +145,7 @@ def resolve_foo(message, _info):
132145
document,
133146
context_value={},
134147
execution_context_class=TestExecutionContext,
148+
custom_arg="baz",
135149
)
136150
assert isasyncgen(subscription)
137151

0 commit comments

Comments
 (0)