Skip to content

Commit

Permalink
Merge pull request #89 from saritasa-nest/feature/add-ordering
Browse files Browse the repository at this point in the history
Add support for ordering in export
  • Loading branch information
TheSuperiorStanislav authored Jan 9, 2025
2 parents e7e8d26 + c637bc5 commit 64c5a0d
Show file tree
Hide file tree
Showing 9 changed files with 125 additions and 21 deletions.
4 changes: 4 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ UNRELEASED

* Add base import/export views that only allow users to work with their own jobs (`ImportJobForUserViewSet` and `ExportJobForUserViewSet`).
* Small actions definition refactor in `ExportJobViewSet/ExportJobViewSet` to allow easier overriding.
* Add support for ordering in `export`
* Add settings for DjangoFilterBackend and OrderingFilter in export api.
`DRF_EXPORT_DJANGO_FILTERS_BACKEND` with default `django_filters.rest_framework.DjangoFilterBackend` and
`DRF_EXPORT_ORDERING_BACKEND` with default `rest_framework.filters.OrderingFilter`.

1.2.0 (2024-12-26)
------------------
Expand Down
13 changes: 11 additions & 2 deletions docs/api_drf.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,19 @@
API (Rest Framework)
====================

.. autoclass:: import_export_extensions.api.views.ImportJobViewSet
.. autoclass:: import_export_extensions.api.ImportJobViewSet
:members:

.. autoclass:: import_export_extensions.api.views.ExportJobViewSet
.. autoclass:: import_export_extensions.api.ExportJobViewSet
:members:

.. autoclass:: import_export_extensions.api.ImportJobForUserViewSet
:members:

.. autoclass:: import_export_extensions.api.ExportJobForUserViewSet
:members:

.. autoclass:: import_export_extensions.api.LimitQuerySetToCurrentUserMixin
:members:

.. autoclass:: import_export_extensions.api.CreateExportJob
Expand Down
22 changes: 11 additions & 11 deletions import_export_extensions/api/serializers/export_job.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import collections.abc
import typing

from rest_framework import request, serializers

from celery import states
from django_filters.utils import translate_validation

from ... import models, resources
from .progress import ProgressSerializer
Expand Down Expand Up @@ -53,27 +53,26 @@ class CreateExportJob(serializers.Serializer):
def __init__(
self,
*args,
ordering: collections.abc.Sequence[str] | None = None,
filter_kwargs: dict[str, typing.Any] | None = None,
resource_kwargs: dict[str, typing.Any] | None = None,
**kwargs,
):
"""Set filter kwargs and current user."""
"""Set ordering, filter kwargs and current user."""
super().__init__(*args, **kwargs)
self._ordering = ordering
self._filter_kwargs = filter_kwargs
self._resource_kwargs = resource_kwargs or {}
self._request: request.Request = self.context.get("request")
self._user = getattr(self._request, "user", None)

def validate(self, attrs: dict[str, typing.Any]) -> dict[str, typing.Any]:
"""Check that filter kwargs are valid."""
if not self._filter_kwargs:
return attrs

filter_instance = self.resource_class.filterset_class(
data=self._filter_kwargs,
)
if not filter_instance.is_valid():
raise translate_validation(error_dict=filter_instance.errors)
"""Check that ordering and filter kwargs are valid."""
self.resource_class(
ordering=self._ordering,
filter_kwargs=self._filter_kwargs,
**self._resource_kwargs,
).get_queryset()
return attrs

def create(
Expand All @@ -88,6 +87,7 @@ def create(
resource_path=self.resource_class.class_path,
file_format_path=f"{file_format_class.__module__}.{file_format_class.__name__}",
resource_kwargs=dict(
ordering=self._ordering,
filter_kwargs=self._filter_kwargs,
**self._resource_kwargs,
),
Expand Down
31 changes: 23 additions & 8 deletions import_export_extensions/api/views/export_job.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import collections.abc
import contextlib
import typing

from django.conf import settings
from django.utils import module_loading

from rest_framework import (
decorators,
exceptions,
Expand Down Expand Up @@ -41,17 +45,23 @@ def __new__(cls, name, bases, attrs, **kwargs):
# Skip if it is has no resource_class specified
if not hasattr(viewset, "resource_class"):
return viewset

filter_backends = [
module_loading.import_string(settings.DRF_EXPORT_DJANGO_FILTERS_BACKEND),
]
if viewset.export_ordering_fields:
filter_backends.append(
module_loading.import_string(settings.DRF_EXPORT_ORDERING_BACKEND),
)
decorators.action(
methods=["POST"],
detail=False,
queryset=viewset.resource_class.get_model_queryset(),
filterset_class=getattr(
viewset.resource_class, "filterset_class", None,
),
filter_backends=[
django_filters.rest_framework.DjangoFilterBackend,
],
filter_backends=filter_backends,
ordering=viewset.export_ordering,
ordering_fields=viewset.export_ordering_fields,
)(viewset.start)
decorators.action(
methods=["POST"],
Expand Down Expand Up @@ -101,15 +111,17 @@ class ExportJobViewSet(
serializer_class = serializers.ExportJobSerializer
resource_class: type[resources.CeleryModelResource]
filterset_class: django_filters.rest_framework.FilterSet | None = None
search_fields = ("id",)
ordering = (
search_fields: collections.abc.Sequence[str] = ("id",)
ordering: collections.abc.Sequence[str] = (
"id",
)
ordering_fields = (
ordering_fields: collections.abc.Sequence[str] = (
"id",
"created",
"modified",
)
export_ordering: collections.abc.Sequence[str] = ()
export_ordering_fields: collections.abc.Sequence[str] = ()

def get_queryset(self):
"""Filter export jobs by resource used in viewset."""
Expand Down Expand Up @@ -145,9 +157,12 @@ def get_export_create_serializer_class(self):

def start(self, request: Request):
"""Validate request data and start ExportJob."""
query_params = dict(request.query_params)
ordering = query_params.pop("ordering", self.ordering)
serializer = self.get_serializer(
data=request.data,
filter_kwargs=request.query_params,
ordering=ordering,
filter_kwargs=query_params,
)
serializer.is_valid(raise_exception=True)
export_job = serializer.save()
Expand Down
15 changes: 15 additions & 0 deletions import_export_extensions/apps.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
DEFAULT_MAX_DATASET_ROWS = 100000
# After how many imported/exported rows celery task status will be updated
DEFAULT_STATUS_UPDATE_ROW_COUNT = 100
# Default filter class backends for export api
DEFAULT_DRF_EXPORT_DJANGO_FILTERS_BACKEND = (
"django_filters.rest_framework.DjangoFilterBackend"
)
DEFAULT_DRF_EXPORT_ORDERING_BACKEND = "rest_framework.filters.OrderingFilter"


class CeleryImportExport(AppConfig):
Expand All @@ -31,3 +36,13 @@ def ready(self):
"STATUS_UPDATE_ROW_COUNT",
DEFAULT_STATUS_UPDATE_ROW_COUNT,
)
settings.DRF_EXPORT_DJANGO_FILTERS_BACKEND = getattr(
settings,
"DRF_EXPORT_DJANGO_FILTERS_BACKEND",
DEFAULT_DRF_EXPORT_DJANGO_FILTERS_BACKEND,
)
settings.DRF_EXPORT_ORDERING_BACKEND = getattr(
settings,
"DRF_EXPORT_ORDERING_BACKEND",
DEFAULT_DRF_EXPORT_ORDERING_BACKEND,
)
12 changes: 12 additions & 0 deletions import_export_extensions/resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import typing

from django.conf import settings
from django.core.exceptions import FieldError, ValidationError
from django.db.models import QuerySet
from django.utils import timezone
from django.utils.functional import classproperty
Expand Down Expand Up @@ -38,10 +39,12 @@ class CeleryResourceMixin:
def __init__(
self,
filter_kwargs: dict[str, typing.Any] | None = None,
ordering: collections.abc.Sequence[str] | None = None,
**kwargs,
):
"""Remember init kwargs."""
self._filter_kwargs = filter_kwargs
self._ordering = ordering
self.resource_init_kwargs: dict[str, typing.Any] = kwargs
self.total_objects_count = 0
self.current_object_number = 0
Expand All @@ -59,6 +62,15 @@ def status_update_row_count(self):
def get_queryset(self):
"""Filter export queryset via filterset class."""
queryset = super().get_queryset()
try:
queryset = queryset.order_by(*(self._ordering or ()))
except FieldError as error:
raise ValidationError(
{
# Split error text not to expose all fields to api clients.
"ordering": str(error).split("Choices are:")[0].strip(),
},
) from error
if not self._filter_kwargs:
return queryset
filter_instance = self.filterset_class(
Expand Down
4 changes: 4 additions & 0 deletions test_project/fake_app/api/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ class ArtistExportViewSet(views.ExportJobForUserViewSet):
"""Simple ViewSet for exporting Artist model."""

resource_class = SimpleArtistResource
export_ordering_fields = (
"id",
"name",
)

class ArtistImportViewSet(views.ImportJobForUserViewSet):
"""Simple ViewSet for importing Artist model."""
Expand Down
18 changes: 18 additions & 0 deletions test_project/tests/integration_tests/test_api/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,24 @@ def test_export_api_create_export_job_with_invalid_filter_kwargs(
assert str(response.data["id"][0]) == "Enter a number."


@pytest.mark.django_db(transaction=True)
def test_export_api_create_export_job_with_invalid_ordering(
admin_api_client: test.APIClient,
):
"""Ensure export start API with invalid ordering return an error."""
response = admin_api_client.post(
path=f"{reverse('export-artist-start')}?ordering=invalid_id",
data={
"file_format": "csv",
},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST, response.data
assert (
str(response.data["ordering"][0])
== "Cannot resolve keyword 'invalid_id' into field."
), response.data


@pytest.mark.django_db(transaction=True)
def test_export_api_detail(
admin_api_client: test.APIClient,
Expand Down
27 changes: 27 additions & 0 deletions test_project/tests/test_resources.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
import re

from django.core.exceptions import ValidationError as DjangoValidationError

from rest_framework.exceptions import ValidationError

import pytest
Expand Down Expand Up @@ -39,6 +43,29 @@ def test_resource_with_invalid_filter_kwargs():
).get_queryset()


def test_resource_with_ordering():
"""Check that `get_queryset` with ordering will apply correct ordering."""
artists = [ArtistFactory(name=str(num)) for num in range(5)]
resource_queryset = SimpleArtistResource(
ordering=("-name",),
).get_queryset()
assert resource_queryset.last() == artists[0]
assert resource_queryset.first() == artists[-1]


def test_resource_with_invalid_ordering():
"""Check that `get_queryset` raise error if ordering is invalid."""
with pytest.raises(
DjangoValidationError,
match=(
re.escape(
"{'ordering': [\"Cannot resolve keyword 'invalid_id' into field.\"]}", # noqa: E501
)
),
):
SimpleArtistResource(ordering=("invalid_id",)).get_queryset()


def test_resource_get_error_class():
"""Ensure that CeleryResource overrides error class."""
error_class = SimpleArtistResource().get_error_result_class()
Expand Down

0 comments on commit 64c5a0d

Please sign in to comment.