Skip to content

Commit 0d66529

Browse files
committed
drop in an async view
1 parent 2a2e5e7 commit 0d66529

File tree

1 file changed

+354
-0
lines changed

1 file changed

+354
-0
lines changed

graphene_django/views.py

Lines changed: 354 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22
import json
33
import re
44

5+
from asyncio import gather, coroutines
6+
57
from django.db import connection, transaction
68
from django.http import HttpResponse, HttpResponseNotAllowed
79
from django.http.response import HttpResponseBadRequest
810
from django.shortcuts import render
911
from django.utils.decorators import method_decorator
1012
from django.views.decorators.csrf import ensure_csrf_cookie
13+
from django.utils.decorators import classonlymethod
1114
from django.views.generic import View
1215
from graphql import OperationType, get_operation_ast, parse, validate
1316
from graphql.error import GraphQLError
@@ -396,3 +399,354 @@ def get_content_type(request):
396399
meta = request.META
397400
content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
398401
return content_type.split(";", 1)[0].lower()
402+
403+
404+
class AsyncGraphQLView(GraphQLView):
405+
graphiql_template = "graphene/graphiql.html"
406+
407+
# Polyfill for window.fetch.
408+
whatwg_fetch_version = "3.6.2"
409+
whatwg_fetch_sri = "sha256-+pQdxwAcHJdQ3e/9S4RK6g8ZkwdMgFQuHvLuN5uyk5c="
410+
411+
# React and ReactDOM.
412+
react_version = "17.0.2"
413+
react_sri = "sha256-Ipu/TQ50iCCVZBUsZyNJfxrDk0E2yhaEIz0vqI+kFG8="
414+
react_dom_sri = "sha256-nbMykgB6tsOFJ7OdVmPpdqMFVk4ZsqWocT6issAPUF0="
415+
416+
# The GraphiQL React app.
417+
graphiql_version = "1.4.7" # "1.0.3"
418+
graphiql_sri = "sha256-cpZ8w9D/i6XdEbY/Eu7yAXeYzReVw0mxYd7OU3gUcsc=" # "sha256-VR4buIDY9ZXSyCNFHFNik6uSe0MhigCzgN4u7moCOTk="
419+
graphiql_css_sri = "sha256-HADQowUuFum02+Ckkv5Yu5ygRoLllHZqg0TFZXY7NHI=" # "sha256-LwqxjyZgqXDYbpxQJ5zLQeNcf7WVNSJ+r8yp2rnWE/E="
420+
421+
# The websocket transport library for subscriptions.
422+
subscriptions_transport_ws_version = "0.9.18"
423+
subscriptions_transport_ws_sri = (
424+
"sha256-i0hAXd4PdJ/cHX3/8tIy/Q/qKiWr5WSTxMFuL9tACkw="
425+
)
426+
427+
schema = None
428+
graphiql = False
429+
middleware = None
430+
root_value = None
431+
pretty = False
432+
batch = False
433+
subscription_path = None
434+
execution_context_class = None
435+
436+
def __init__(
437+
self,
438+
schema=None,
439+
middleware=None,
440+
root_value=None,
441+
graphiql=False,
442+
pretty=False,
443+
batch=False,
444+
subscription_path=None,
445+
execution_context_class=None,
446+
):
447+
if not schema:
448+
schema = graphene_settings.SCHEMA
449+
450+
if middleware is None:
451+
middleware = graphene_settings.MIDDLEWARE
452+
453+
self.schema = self.schema or schema
454+
if middleware is not None:
455+
if isinstance(middleware, MiddlewareManager):
456+
self.middleware = middleware
457+
else:
458+
self.middleware = list(instantiate_middleware(middleware))
459+
self.root_value = root_value
460+
self.pretty = self.pretty or pretty
461+
self.graphiql = self.graphiql or graphiql
462+
self.batch = self.batch or batch
463+
self.execution_context_class = execution_context_class
464+
if subscription_path is None:
465+
self.subscription_path = graphene_settings.SUBSCRIPTION_PATH
466+
467+
assert isinstance(
468+
self.schema, Schema
469+
), "A Schema is required to be provided to GraphQLView."
470+
assert not all((graphiql, batch)), "Use either graphiql or batch processing"
471+
472+
# noinspection PyUnusedLocal
473+
def get_root_value(self, request):
474+
return self.root_value
475+
476+
def get_middleware(self, request):
477+
return self.middleware
478+
479+
def get_context(self, request):
480+
return request
481+
482+
@classonlymethod
483+
def as_view(cls, **initkwargs):
484+
view = super().as_view(**initkwargs)
485+
view._is_coroutine = coroutines._is_coroutine
486+
return view
487+
488+
@method_decorator(ensure_csrf_cookie)
489+
async def dispatch(self, request, *args, **kwargs):
490+
try:
491+
if request.method.lower() not in ("get", "post"):
492+
raise HttpError(
493+
HttpResponseNotAllowed(
494+
["GET", "POST"], "GraphQL only supports GET and POST requests."
495+
)
496+
)
497+
498+
data = self.parse_body(request)
499+
show_graphiql = self.graphiql and self.can_display_graphiql(request, data)
500+
501+
if show_graphiql:
502+
return self.render_graphiql(
503+
request,
504+
# Dependency parameters.
505+
whatwg_fetch_version=self.whatwg_fetch_version,
506+
whatwg_fetch_sri=self.whatwg_fetch_sri,
507+
react_version=self.react_version,
508+
react_sri=self.react_sri,
509+
react_dom_sri=self.react_dom_sri,
510+
graphiql_version=self.graphiql_version,
511+
graphiql_sri=self.graphiql_sri,
512+
graphiql_css_sri=self.graphiql_css_sri,
513+
subscriptions_transport_ws_version=self.subscriptions_transport_ws_version,
514+
subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri,
515+
# The SUBSCRIPTION_PATH setting.
516+
subscription_path=self.subscription_path,
517+
# GraphiQL headers tab,
518+
graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED,
519+
graphiql_should_persist_headers=graphene_settings.GRAPHIQL_SHOULD_PERSIST_HEADERS,
520+
)
521+
522+
if self.batch:
523+
responses = await gather(*[self.get_response(request, entry) for entry in data])
524+
result = "[{}]".format(
525+
",".join([response[0] for response in responses])
526+
)
527+
status_code = (
528+
responses
529+
and max(responses, key=lambda response: response[1])[1]
530+
or 200
531+
)
532+
else:
533+
result, status_code = await self.get_response(request, data, show_graphiql)
534+
535+
return HttpResponse(
536+
status=status_code, content=result, content_type="application/json"
537+
)
538+
539+
except HttpError as e:
540+
response = e.response
541+
response["Content-Type"] = "application/json"
542+
response.content = self.json_encode(
543+
request, {"errors": [self.format_error(e)]}
544+
)
545+
return response
546+
547+
async def get_response(self, request, data, show_graphiql=False):
548+
query, variables, operation_name, id = self.get_graphql_params(request, data)
549+
550+
execution_result = await self.execute_graphql_request(
551+
request, data, query, variables, operation_name, show_graphiql
552+
)
553+
554+
if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
555+
set_rollback()
556+
557+
status_code = 200
558+
if execution_result:
559+
response = {}
560+
561+
if execution_result.errors:
562+
set_rollback()
563+
response["errors"] = [
564+
self.format_error(e) for e in execution_result.errors
565+
]
566+
567+
if execution_result.errors and any(
568+
not getattr(e, "path", None) for e in execution_result.errors
569+
):
570+
status_code = 400
571+
else:
572+
response["data"] = execution_result.data
573+
574+
if self.batch:
575+
response["id"] = id
576+
response["status"] = status_code
577+
578+
result = self.json_encode(request, response, pretty=show_graphiql)
579+
else:
580+
result = None
581+
582+
return result, status_code
583+
584+
def render_graphiql(self, request, **data):
585+
return render(request, self.graphiql_template, data)
586+
587+
def json_encode(self, request, d, pretty=False):
588+
if not (self.pretty or pretty) and not request.GET.get("pretty"):
589+
return json.dumps(d, separators=(",", ":"))
590+
591+
return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": "))
592+
593+
def parse_body(self, request):
594+
content_type = self.get_content_type(request)
595+
596+
if content_type == "application/graphql":
597+
return {"query": request.body.decode()}
598+
599+
elif content_type == "application/json":
600+
# noinspection PyBroadException
601+
try:
602+
body = request.body.decode("utf-8")
603+
except Exception as e:
604+
raise HttpError(HttpResponseBadRequest(str(e)))
605+
606+
try:
607+
request_json = json.loads(body)
608+
if self.batch:
609+
assert isinstance(request_json, list), (
610+
"Batch requests should receive a list, but received {}."
611+
).format(repr(request_json))
612+
assert (
613+
len(request_json) > 0
614+
), "Received an empty list in the batch request."
615+
else:
616+
assert isinstance(
617+
request_json, dict
618+
), "The received data is not a valid JSON query."
619+
return request_json
620+
except AssertionError as e:
621+
raise HttpError(HttpResponseBadRequest(str(e)))
622+
except (TypeError, ValueError):
623+
raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON."))
624+
625+
elif content_type in [
626+
"application/x-www-form-urlencoded",
627+
"multipart/form-data",
628+
]:
629+
return request.POST
630+
631+
return {}
632+
633+
async def execute_graphql_request(
634+
self, request, data, query, variables, operation_name, show_graphiql=False
635+
):
636+
if not query:
637+
if show_graphiql:
638+
return None
639+
raise HttpError(HttpResponseBadRequest("Must provide query string."))
640+
641+
try:
642+
document = parse(query)
643+
except Exception as e:
644+
return ExecutionResult(errors=[e])
645+
646+
if request.method.lower() == "get":
647+
operation_ast = get_operation_ast(document, operation_name)
648+
if operation_ast and operation_ast.operation != OperationType.QUERY:
649+
if show_graphiql:
650+
return None
651+
652+
raise HttpError(
653+
HttpResponseNotAllowed(
654+
["POST"],
655+
"Can only perform a {} operation from a POST request.".format(
656+
operation_ast.operation.value
657+
),
658+
)
659+
)
660+
661+
validation_errors = validate(self.schema.graphql_schema, document)
662+
if validation_errors:
663+
return ExecutionResult(data=None, errors=validation_errors)
664+
665+
try:
666+
extra_options = {}
667+
if self.execution_context_class:
668+
extra_options["execution_context_class"] = self.execution_context_class
669+
670+
options = {
671+
"source": query,
672+
"root_value": self.get_root_value(request),
673+
"variable_values": variables,
674+
"operation_name": operation_name,
675+
"context_value": self.get_context(request),
676+
"middleware": self.get_middleware(request),
677+
}
678+
options.update(extra_options)
679+
680+
operation_ast = get_operation_ast(document, operation_name)
681+
if (
682+
operation_ast
683+
and operation_ast.operation == OperationType.MUTATION
684+
and (
685+
graphene_settings.ATOMIC_MUTATIONS is True
686+
or connection.settings_dict.get("ATOMIC_MUTATIONS", False) is True
687+
)
688+
):
689+
with transaction.atomic():
690+
result = await self.schema.execute_async(**options)
691+
if getattr(request, MUTATION_ERRORS_FLAG, False) is True:
692+
transaction.set_rollback(True)
693+
return result
694+
695+
return await self.schema.execute_async(**options)
696+
except Exception as e:
697+
return ExecutionResult(errors=[e])
698+
699+
@classmethod
700+
def can_display_graphiql(cls, request, data):
701+
raw = "raw" in request.GET or "raw" in data
702+
return not raw and cls.request_wants_html(request)
703+
704+
@classmethod
705+
def request_wants_html(cls, request):
706+
accepted = get_accepted_content_types(request)
707+
accepted_length = len(accepted)
708+
# the list will be ordered in preferred first - so we have to make
709+
# sure the most preferred gets the highest number
710+
html_priority = (
711+
accepted_length - accepted.index("text/html")
712+
if "text/html" in accepted
713+
else 0
714+
)
715+
json_priority = (
716+
accepted_length - accepted.index("application/json")
717+
if "application/json" in accepted
718+
else 0
719+
)
720+
721+
return html_priority > json_priority
722+
723+
@staticmethod
724+
def get_graphql_params(request, data):
725+
query = request.GET.get("query") or data.get("query")
726+
variables = request.GET.get("variables") or data.get("variables")
727+
id = request.GET.get("id") or data.get("id")
728+
729+
if variables and isinstance(variables, str):
730+
try:
731+
variables = json.loads(variables)
732+
except Exception:
733+
raise HttpError(HttpResponseBadRequest("Variables are invalid JSON."))
734+
735+
operation_name = request.GET.get("operationName") or data.get("operationName")
736+
if operation_name == "null":
737+
operation_name = None
738+
739+
return query, variables, operation_name, id
740+
741+
@staticmethod
742+
def format_error(error):
743+
if isinstance(error, GraphQLError):
744+
return error.formatted
745+
746+
return {"message": str(error)}
747+
748+
@staticmethod
749+
def get_content_type(request):
750+
meta = request.META
751+
content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
752+
return content_type.split(";", 1)[0].lower()

0 commit comments

Comments
 (0)