|
2 | 2 | import json
|
3 | 3 | import re
|
4 | 4 |
|
| 5 | +from asyncio import gather, coroutines |
| 6 | + |
5 | 7 | from django.db import connection, transaction
|
6 | 8 | from django.http import HttpResponse, HttpResponseNotAllowed
|
7 | 9 | from django.http.response import HttpResponseBadRequest
|
8 | 10 | from django.shortcuts import render
|
9 | 11 | from django.utils.decorators import method_decorator
|
10 | 12 | from django.views.decorators.csrf import ensure_csrf_cookie
|
| 13 | +from django.utils.decorators import classonlymethod |
11 | 14 | from django.views.generic import View
|
12 | 15 | from graphql import OperationType, get_operation_ast, parse, validate
|
13 | 16 | from graphql.error import GraphQLError
|
@@ -396,3 +399,354 @@ def get_content_type(request):
|
396 | 399 | meta = request.META
|
397 | 400 | content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
|
398 | 401 | return content_type.split(";", 1)[0].lower()
|
| 402 | + |
| 403 | + |
| 404 | +class AsyncGraphQLView(GraphQLView): |
| 405 | + graphiql_template = "graphene/graphiql.html" |
| 406 | + |
| 407 | + # Polyfill for window.fetch. |
| 408 | + whatwg_fetch_version = "3.6.2" |
| 409 | + whatwg_fetch_sri = "sha256-+pQdxwAcHJdQ3e/9S4RK6g8ZkwdMgFQuHvLuN5uyk5c=" |
| 410 | + |
| 411 | + # React and ReactDOM. |
| 412 | + react_version = "17.0.2" |
| 413 | + react_sri = "sha256-Ipu/TQ50iCCVZBUsZyNJfxrDk0E2yhaEIz0vqI+kFG8=" |
| 414 | + react_dom_sri = "sha256-nbMykgB6tsOFJ7OdVmPpdqMFVk4ZsqWocT6issAPUF0=" |
| 415 | + |
| 416 | + # The GraphiQL React app. |
| 417 | + graphiql_version = "1.4.7" # "1.0.3" |
| 418 | + graphiql_sri = "sha256-cpZ8w9D/i6XdEbY/Eu7yAXeYzReVw0mxYd7OU3gUcsc=" # "sha256-VR4buIDY9ZXSyCNFHFNik6uSe0MhigCzgN4u7moCOTk=" |
| 419 | + graphiql_css_sri = "sha256-HADQowUuFum02+Ckkv5Yu5ygRoLllHZqg0TFZXY7NHI=" # "sha256-LwqxjyZgqXDYbpxQJ5zLQeNcf7WVNSJ+r8yp2rnWE/E=" |
| 420 | + |
| 421 | + # The websocket transport library for subscriptions. |
| 422 | + subscriptions_transport_ws_version = "0.9.18" |
| 423 | + subscriptions_transport_ws_sri = ( |
| 424 | + "sha256-i0hAXd4PdJ/cHX3/8tIy/Q/qKiWr5WSTxMFuL9tACkw=" |
| 425 | + ) |
| 426 | + |
| 427 | + schema = None |
| 428 | + graphiql = False |
| 429 | + middleware = None |
| 430 | + root_value = None |
| 431 | + pretty = False |
| 432 | + batch = False |
| 433 | + subscription_path = None |
| 434 | + execution_context_class = None |
| 435 | + |
| 436 | + def __init__( |
| 437 | + self, |
| 438 | + schema=None, |
| 439 | + middleware=None, |
| 440 | + root_value=None, |
| 441 | + graphiql=False, |
| 442 | + pretty=False, |
| 443 | + batch=False, |
| 444 | + subscription_path=None, |
| 445 | + execution_context_class=None, |
| 446 | + ): |
| 447 | + if not schema: |
| 448 | + schema = graphene_settings.SCHEMA |
| 449 | + |
| 450 | + if middleware is None: |
| 451 | + middleware = graphene_settings.MIDDLEWARE |
| 452 | + |
| 453 | + self.schema = self.schema or schema |
| 454 | + if middleware is not None: |
| 455 | + if isinstance(middleware, MiddlewareManager): |
| 456 | + self.middleware = middleware |
| 457 | + else: |
| 458 | + self.middleware = list(instantiate_middleware(middleware)) |
| 459 | + self.root_value = root_value |
| 460 | + self.pretty = self.pretty or pretty |
| 461 | + self.graphiql = self.graphiql or graphiql |
| 462 | + self.batch = self.batch or batch |
| 463 | + self.execution_context_class = execution_context_class |
| 464 | + if subscription_path is None: |
| 465 | + self.subscription_path = graphene_settings.SUBSCRIPTION_PATH |
| 466 | + |
| 467 | + assert isinstance( |
| 468 | + self.schema, Schema |
| 469 | + ), "A Schema is required to be provided to GraphQLView." |
| 470 | + assert not all((graphiql, batch)), "Use either graphiql or batch processing" |
| 471 | + |
| 472 | + # noinspection PyUnusedLocal |
| 473 | + def get_root_value(self, request): |
| 474 | + return self.root_value |
| 475 | + |
| 476 | + def get_middleware(self, request): |
| 477 | + return self.middleware |
| 478 | + |
| 479 | + def get_context(self, request): |
| 480 | + return request |
| 481 | + |
| 482 | + @classonlymethod |
| 483 | + def as_view(cls, **initkwargs): |
| 484 | + view = super().as_view(**initkwargs) |
| 485 | + view._is_coroutine = coroutines._is_coroutine |
| 486 | + return view |
| 487 | + |
| 488 | + @method_decorator(ensure_csrf_cookie) |
| 489 | + async def dispatch(self, request, *args, **kwargs): |
| 490 | + try: |
| 491 | + if request.method.lower() not in ("get", "post"): |
| 492 | + raise HttpError( |
| 493 | + HttpResponseNotAllowed( |
| 494 | + ["GET", "POST"], "GraphQL only supports GET and POST requests." |
| 495 | + ) |
| 496 | + ) |
| 497 | + |
| 498 | + data = self.parse_body(request) |
| 499 | + show_graphiql = self.graphiql and self.can_display_graphiql(request, data) |
| 500 | + |
| 501 | + if show_graphiql: |
| 502 | + return self.render_graphiql( |
| 503 | + request, |
| 504 | + # Dependency parameters. |
| 505 | + whatwg_fetch_version=self.whatwg_fetch_version, |
| 506 | + whatwg_fetch_sri=self.whatwg_fetch_sri, |
| 507 | + react_version=self.react_version, |
| 508 | + react_sri=self.react_sri, |
| 509 | + react_dom_sri=self.react_dom_sri, |
| 510 | + graphiql_version=self.graphiql_version, |
| 511 | + graphiql_sri=self.graphiql_sri, |
| 512 | + graphiql_css_sri=self.graphiql_css_sri, |
| 513 | + subscriptions_transport_ws_version=self.subscriptions_transport_ws_version, |
| 514 | + subscriptions_transport_ws_sri=self.subscriptions_transport_ws_sri, |
| 515 | + # The SUBSCRIPTION_PATH setting. |
| 516 | + subscription_path=self.subscription_path, |
| 517 | + # GraphiQL headers tab, |
| 518 | + graphiql_header_editor_enabled=graphene_settings.GRAPHIQL_HEADER_EDITOR_ENABLED, |
| 519 | + graphiql_should_persist_headers=graphene_settings.GRAPHIQL_SHOULD_PERSIST_HEADERS, |
| 520 | + ) |
| 521 | + |
| 522 | + if self.batch: |
| 523 | + responses = await gather(*[self.get_response(request, entry) for entry in data]) |
| 524 | + result = "[{}]".format( |
| 525 | + ",".join([response[0] for response in responses]) |
| 526 | + ) |
| 527 | + status_code = ( |
| 528 | + responses |
| 529 | + and max(responses, key=lambda response: response[1])[1] |
| 530 | + or 200 |
| 531 | + ) |
| 532 | + else: |
| 533 | + result, status_code = await self.get_response(request, data, show_graphiql) |
| 534 | + |
| 535 | + return HttpResponse( |
| 536 | + status=status_code, content=result, content_type="application/json" |
| 537 | + ) |
| 538 | + |
| 539 | + except HttpError as e: |
| 540 | + response = e.response |
| 541 | + response["Content-Type"] = "application/json" |
| 542 | + response.content = self.json_encode( |
| 543 | + request, {"errors": [self.format_error(e)]} |
| 544 | + ) |
| 545 | + return response |
| 546 | + |
| 547 | + async def get_response(self, request, data, show_graphiql=False): |
| 548 | + query, variables, operation_name, id = self.get_graphql_params(request, data) |
| 549 | + |
| 550 | + execution_result = await self.execute_graphql_request( |
| 551 | + request, data, query, variables, operation_name, show_graphiql |
| 552 | + ) |
| 553 | + |
| 554 | + if getattr(request, MUTATION_ERRORS_FLAG, False) is True: |
| 555 | + set_rollback() |
| 556 | + |
| 557 | + status_code = 200 |
| 558 | + if execution_result: |
| 559 | + response = {} |
| 560 | + |
| 561 | + if execution_result.errors: |
| 562 | + set_rollback() |
| 563 | + response["errors"] = [ |
| 564 | + self.format_error(e) for e in execution_result.errors |
| 565 | + ] |
| 566 | + |
| 567 | + if execution_result.errors and any( |
| 568 | + not getattr(e, "path", None) for e in execution_result.errors |
| 569 | + ): |
| 570 | + status_code = 400 |
| 571 | + else: |
| 572 | + response["data"] = execution_result.data |
| 573 | + |
| 574 | + if self.batch: |
| 575 | + response["id"] = id |
| 576 | + response["status"] = status_code |
| 577 | + |
| 578 | + result = self.json_encode(request, response, pretty=show_graphiql) |
| 579 | + else: |
| 580 | + result = None |
| 581 | + |
| 582 | + return result, status_code |
| 583 | + |
| 584 | + def render_graphiql(self, request, **data): |
| 585 | + return render(request, self.graphiql_template, data) |
| 586 | + |
| 587 | + def json_encode(self, request, d, pretty=False): |
| 588 | + if not (self.pretty or pretty) and not request.GET.get("pretty"): |
| 589 | + return json.dumps(d, separators=(",", ":")) |
| 590 | + |
| 591 | + return json.dumps(d, sort_keys=True, indent=2, separators=(",", ": ")) |
| 592 | + |
| 593 | + def parse_body(self, request): |
| 594 | + content_type = self.get_content_type(request) |
| 595 | + |
| 596 | + if content_type == "application/graphql": |
| 597 | + return {"query": request.body.decode()} |
| 598 | + |
| 599 | + elif content_type == "application/json": |
| 600 | + # noinspection PyBroadException |
| 601 | + try: |
| 602 | + body = request.body.decode("utf-8") |
| 603 | + except Exception as e: |
| 604 | + raise HttpError(HttpResponseBadRequest(str(e))) |
| 605 | + |
| 606 | + try: |
| 607 | + request_json = json.loads(body) |
| 608 | + if self.batch: |
| 609 | + assert isinstance(request_json, list), ( |
| 610 | + "Batch requests should receive a list, but received {}." |
| 611 | + ).format(repr(request_json)) |
| 612 | + assert ( |
| 613 | + len(request_json) > 0 |
| 614 | + ), "Received an empty list in the batch request." |
| 615 | + else: |
| 616 | + assert isinstance( |
| 617 | + request_json, dict |
| 618 | + ), "The received data is not a valid JSON query." |
| 619 | + return request_json |
| 620 | + except AssertionError as e: |
| 621 | + raise HttpError(HttpResponseBadRequest(str(e))) |
| 622 | + except (TypeError, ValueError): |
| 623 | + raise HttpError(HttpResponseBadRequest("POST body sent invalid JSON.")) |
| 624 | + |
| 625 | + elif content_type in [ |
| 626 | + "application/x-www-form-urlencoded", |
| 627 | + "multipart/form-data", |
| 628 | + ]: |
| 629 | + return request.POST |
| 630 | + |
| 631 | + return {} |
| 632 | + |
| 633 | + async def execute_graphql_request( |
| 634 | + self, request, data, query, variables, operation_name, show_graphiql=False |
| 635 | + ): |
| 636 | + if not query: |
| 637 | + if show_graphiql: |
| 638 | + return None |
| 639 | + raise HttpError(HttpResponseBadRequest("Must provide query string.")) |
| 640 | + |
| 641 | + try: |
| 642 | + document = parse(query) |
| 643 | + except Exception as e: |
| 644 | + return ExecutionResult(errors=[e]) |
| 645 | + |
| 646 | + if request.method.lower() == "get": |
| 647 | + operation_ast = get_operation_ast(document, operation_name) |
| 648 | + if operation_ast and operation_ast.operation != OperationType.QUERY: |
| 649 | + if show_graphiql: |
| 650 | + return None |
| 651 | + |
| 652 | + raise HttpError( |
| 653 | + HttpResponseNotAllowed( |
| 654 | + ["POST"], |
| 655 | + "Can only perform a {} operation from a POST request.".format( |
| 656 | + operation_ast.operation.value |
| 657 | + ), |
| 658 | + ) |
| 659 | + ) |
| 660 | + |
| 661 | + validation_errors = validate(self.schema.graphql_schema, document) |
| 662 | + if validation_errors: |
| 663 | + return ExecutionResult(data=None, errors=validation_errors) |
| 664 | + |
| 665 | + try: |
| 666 | + extra_options = {} |
| 667 | + if self.execution_context_class: |
| 668 | + extra_options["execution_context_class"] = self.execution_context_class |
| 669 | + |
| 670 | + options = { |
| 671 | + "source": query, |
| 672 | + "root_value": self.get_root_value(request), |
| 673 | + "variable_values": variables, |
| 674 | + "operation_name": operation_name, |
| 675 | + "context_value": self.get_context(request), |
| 676 | + "middleware": self.get_middleware(request), |
| 677 | + } |
| 678 | + options.update(extra_options) |
| 679 | + |
| 680 | + operation_ast = get_operation_ast(document, operation_name) |
| 681 | + if ( |
| 682 | + operation_ast |
| 683 | + and operation_ast.operation == OperationType.MUTATION |
| 684 | + and ( |
| 685 | + graphene_settings.ATOMIC_MUTATIONS is True |
| 686 | + or connection.settings_dict.get("ATOMIC_MUTATIONS", False) is True |
| 687 | + ) |
| 688 | + ): |
| 689 | + with transaction.atomic(): |
| 690 | + result = await self.schema.execute_async(**options) |
| 691 | + if getattr(request, MUTATION_ERRORS_FLAG, False) is True: |
| 692 | + transaction.set_rollback(True) |
| 693 | + return result |
| 694 | + |
| 695 | + return await self.schema.execute_async(**options) |
| 696 | + except Exception as e: |
| 697 | + return ExecutionResult(errors=[e]) |
| 698 | + |
| 699 | + @classmethod |
| 700 | + def can_display_graphiql(cls, request, data): |
| 701 | + raw = "raw" in request.GET or "raw" in data |
| 702 | + return not raw and cls.request_wants_html(request) |
| 703 | + |
| 704 | + @classmethod |
| 705 | + def request_wants_html(cls, request): |
| 706 | + accepted = get_accepted_content_types(request) |
| 707 | + accepted_length = len(accepted) |
| 708 | + # the list will be ordered in preferred first - so we have to make |
| 709 | + # sure the most preferred gets the highest number |
| 710 | + html_priority = ( |
| 711 | + accepted_length - accepted.index("text/html") |
| 712 | + if "text/html" in accepted |
| 713 | + else 0 |
| 714 | + ) |
| 715 | + json_priority = ( |
| 716 | + accepted_length - accepted.index("application/json") |
| 717 | + if "application/json" in accepted |
| 718 | + else 0 |
| 719 | + ) |
| 720 | + |
| 721 | + return html_priority > json_priority |
| 722 | + |
| 723 | + @staticmethod |
| 724 | + def get_graphql_params(request, data): |
| 725 | + query = request.GET.get("query") or data.get("query") |
| 726 | + variables = request.GET.get("variables") or data.get("variables") |
| 727 | + id = request.GET.get("id") or data.get("id") |
| 728 | + |
| 729 | + if variables and isinstance(variables, str): |
| 730 | + try: |
| 731 | + variables = json.loads(variables) |
| 732 | + except Exception: |
| 733 | + raise HttpError(HttpResponseBadRequest("Variables are invalid JSON.")) |
| 734 | + |
| 735 | + operation_name = request.GET.get("operationName") or data.get("operationName") |
| 736 | + if operation_name == "null": |
| 737 | + operation_name = None |
| 738 | + |
| 739 | + return query, variables, operation_name, id |
| 740 | + |
| 741 | + @staticmethod |
| 742 | + def format_error(error): |
| 743 | + if isinstance(error, GraphQLError): |
| 744 | + return error.formatted |
| 745 | + |
| 746 | + return {"message": str(error)} |
| 747 | + |
| 748 | + @staticmethod |
| 749 | + def get_content_type(request): |
| 750 | + meta = request.META |
| 751 | + content_type = meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", "")) |
| 752 | + return content_type.split(";", 1)[0].lower() |
0 commit comments