Skip to content

Commit ddc622f

Browse files
committed
♻️(backend) refactor resource access viewset
The document viewset was overriding the get_queryset method from its own mixin. This was a sign that the mixin was not optimal anymore. In the next commit I will need to complexify it further so it's time to refactor the mixin.
1 parent f5fe4b6 commit ddc622f

File tree

3 files changed

+59
-92
lines changed

3 files changed

+59
-92
lines changed

src/backend/core/api/viewsets.py

Lines changed: 46 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -219,58 +219,17 @@ def get_me(self, request):
219219
class ResourceAccessViewsetMixin:
220220
"""Mixin with methods common to all access viewsets."""
221221

222-
def get_permissions(self):
223-
"""User only needs to be authenticated to list resource accesses"""
224-
if self.action == "list":
225-
permission_classes = [permissions.IsAuthenticated]
226-
else:
227-
return super().get_permissions()
228-
229-
return [permission() for permission in permission_classes]
222+
def filter_queryset(self, queryset):
223+
"""Override to filter on related resource."""
224+
queryset = super().filter_queryset(queryset)
225+
return queryset.filter(**{self.resource_field_name: self.kwargs["resource_id"]})
230226

231227
def get_serializer_context(self):
232228
"""Extra context provided to the serializer class."""
233229
context = super().get_serializer_context()
234230
context["resource_id"] = self.kwargs["resource_id"]
235231
return context
236232

237-
def get_queryset(self):
238-
"""Return the queryset according to the action."""
239-
queryset = super().get_queryset()
240-
queryset = queryset.filter(
241-
**{self.resource_field_name: self.kwargs["resource_id"]}
242-
)
243-
244-
if self.action == "list":
245-
user = self.request.user
246-
teams = user.teams
247-
user_roles_query = (
248-
queryset.filter(
249-
db.Q(user=user) | db.Q(team__in=teams),
250-
**{self.resource_field_name: self.kwargs["resource_id"]},
251-
)
252-
.values(self.resource_field_name)
253-
.annotate(roles_array=ArrayAgg("role"))
254-
.values("roles_array")
255-
)
256-
257-
# Limit to resource access instances related to a resource THAT also has
258-
# a resource access
259-
# instance for the logged-in user (we don't want to list only the resource
260-
# access instances pointing to the logged-in user)
261-
queryset = (
262-
queryset.filter(
263-
db.Q(**{f"{self.resource_field_name}__accesses__user": user})
264-
| db.Q(
265-
**{f"{self.resource_field_name}__accesses__team__in": teams}
266-
),
267-
**{self.resource_field_name: self.kwargs["resource_id"]},
268-
)
269-
.annotate(user_roles=db.Subquery(user_roles_query))
270-
.distinct()
271-
)
272-
return queryset
273-
274233
def destroy(self, request, *args, **kwargs):
275234
"""Forbid deleting the last owner access"""
276235
instance = self.get_object()
@@ -1373,7 +1332,11 @@ def cors_proxy(self, request, *args, **kwargs):
13731332

13741333
class DocumentAccessViewSet(
13751334
ResourceAccessViewsetMixin,
1376-
viewsets.ModelViewSet,
1335+
drf.mixins.CreateModelMixin,
1336+
drf.mixins.RetrieveModelMixin,
1337+
drf.mixins.UpdateModelMixin,
1338+
drf.mixins.DestroyModelMixin,
1339+
viewsets.GenericViewSet,
13771340
):
13781341
"""
13791342
API ViewSet for all interactions with document accesses.
@@ -1400,31 +1363,35 @@ class DocumentAccessViewSet(
14001363
"""
14011364

14021365
lookup_field = "pk"
1403-
pagination_class = Pagination
14041366
permission_classes = [permissions.IsAuthenticated, permissions.AccessPermission]
14051367
queryset = models.DocumentAccess.objects.select_related("user").all()
14061368
resource_field_name = "document"
14071369
serializer_class = serializers.DocumentAccessSerializer
14081370
is_current_user_owner_or_admin = False
14091371

1410-
def get_queryset(self):
1411-
"""Return the queryset according to the action."""
1412-
queryset = super().get_queryset()
1372+
def list(self, request, *args, **kwargs):
1373+
"""Return accesses for the current document with filters and annotations."""
1374+
user = self.request.user
1375+
queryset = self.filter_queryset(self.get_queryset())
14131376

1414-
if self.action == "list":
1415-
try:
1416-
document = models.Document.objects.get(pk=self.kwargs["resource_id"])
1417-
except models.Document.DoesNotExist:
1418-
return queryset.none()
1377+
try:
1378+
document = models.Document.objects.get(pk=self.kwargs["resource_id"])
1379+
except models.Document.DoesNotExist:
1380+
return drf.response.Response([])
14191381

1420-
roles = set(document.get_roles(self.request.user))
1421-
is_owner_or_admin = bool(roles.intersection(set(models.PRIVILEGED_ROLES)))
1422-
self.is_current_user_owner_or_admin = is_owner_or_admin
1423-
if not is_owner_or_admin:
1424-
# Return only the document owner access
1425-
queryset = queryset.filter(role__in=models.PRIVILEGED_ROLES)
1382+
roles = set(document.get_roles(user))
1383+
if not roles:
1384+
return drf.response.Response([])
14261385

1427-
return queryset
1386+
is_owner_or_admin = bool(roles.intersection(set(models.PRIVILEGED_ROLES)))
1387+
self.is_current_user_owner_or_admin = is_owner_or_admin
1388+
if not is_owner_or_admin:
1389+
# Return only the document's privileged accesses
1390+
queryset = queryset.filter(role__in=models.PRIVILEGED_ROLES)
1391+
1392+
queryset = queryset.distinct()
1393+
serializer = self.get_serializer(queryset, many=True)
1394+
return drf.response.Response(serializer.data)
14281395

14291396
def get_serializer_class(self):
14301397
if self.action == "list" and not self.is_current_user_owner_or_admin:
@@ -1542,7 +1509,6 @@ class TemplateAccessViewSet(
15421509
ResourceAccessViewsetMixin,
15431510
drf.mixins.CreateModelMixin,
15441511
drf.mixins.DestroyModelMixin,
1545-
drf.mixins.ListModelMixin,
15461512
drf.mixins.RetrieveModelMixin,
15471513
drf.mixins.UpdateModelMixin,
15481514
viewsets.GenericViewSet,
@@ -1572,12 +1538,28 @@ class TemplateAccessViewSet(
15721538
"""
15731539

15741540
lookup_field = "pk"
1575-
pagination_class = Pagination
15761541
permission_classes = [permissions.IsAuthenticated, permissions.AccessPermission]
15771542
queryset = models.TemplateAccess.objects.select_related("user").all()
15781543
resource_field_name = "template"
15791544
serializer_class = serializers.TemplateAccessSerializer
15801545

1546+
def list(self, request, *args, **kwargs):
1547+
"""Restrict templates returned by the list endpoint"""
1548+
user = self.request.user
1549+
teams = user.teams
1550+
queryset = self.filter_queryset(self.get_queryset())
1551+
1552+
# Limit to resource access instances related to a resource THAT also has
1553+
# a resource access instance for the logged-in user (we don't want to list
1554+
# only the resource access instances pointing to the logged-in user)
1555+
queryset = queryset.filter(
1556+
db.Q(template__accesses__user=user)
1557+
| db.Q(template__accesses__team__in=teams),
1558+
).distinct()
1559+
1560+
serializer = self.get_serializer(queryset, many=True)
1561+
return drf.response.Response(serializer.data)
1562+
15811563

15821564
class InvitationViewset(
15831565
drf.mixins.CreateModelMixin,

src/backend/core/tests/documents/test_api_document_accesses.py

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -51,12 +51,7 @@ def test_api_document_accesses_list_authenticated_unrelated():
5151
f"/api/v1.0/documents/{document.id!s}/accesses/",
5252
)
5353
assert response.status_code == 200
54-
assert response.json() == {
55-
"count": 0,
56-
"next": None,
57-
"previous": None,
58-
"results": [],
59-
}
54+
assert response.json() == []
6055

6156

6257
def test_api_document_accesses_list_unexisting_document():
@@ -70,12 +65,7 @@ def test_api_document_accesses_list_unexisting_document():
7065

7166
response = client.get(f"/api/v1.0/documents/{uuid4()!s}/accesses/")
7267
assert response.status_code == 200
73-
assert response.json() == {
74-
"count": 0,
75-
"next": None,
76-
"previous": None,
77-
"results": [],
78-
}
68+
assert response.json() == []
7969

8070

8171
@pytest.mark.parametrize("via", VIA)
@@ -129,14 +119,14 @@ def test_api_document_accesses_list_authenticated_related_non_privileged(
129119
f"/api/v1.0/documents/{document.id!s}/accesses/",
130120
)
131121

132-
# Return only owners
133-
owners_accesses = [
122+
# Return only privileged roles
123+
privileged_accesses = [
134124
access for access in accesses if access.role in models.PRIVILEGED_ROLES
135125
]
136126
assert response.status_code == 200
137127
content = response.json()
138-
assert content["count"] == len(owners_accesses)
139-
assert sorted(content["results"], key=lambda x: x["id"]) == sorted(
128+
assert len(content) == len(privileged_accesses)
129+
assert sorted(content, key=lambda x: x["id"]) == sorted(
140130
[
141131
{
142132
"id": str(access.id),
@@ -152,12 +142,12 @@ def test_api_document_accesses_list_authenticated_related_non_privileged(
152142
"role": access.role,
153143
"abilities": access.get_abilities(user),
154144
}
155-
for access in owners_accesses
145+
for access in privileged_accesses
156146
],
157147
key=lambda x: x["id"],
158148
)
159149

160-
for access in content["results"]:
150+
for access in content:
161151
assert access["role"] in models.PRIVILEGED_ROLES
162152

163153

@@ -216,8 +206,8 @@ def test_api_document_accesses_list_authenticated_related_privileged_roles(
216206

217207
assert response.status_code == 200
218208
content = response.json()
219-
assert len(content["results"]) == 4
220-
assert sorted(content["results"], key=lambda x: x["id"]) == sorted(
209+
assert len(content) == 4
210+
assert sorted(content, key=lambda x: x["id"]) == sorted(
221211
[
222212
{
223213
"id": str(user_access.id),

src/backend/core/tests/templates/test_api_template_accesses.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,7 @@ def test_api_template_accesses_list_authenticated_unrelated():
4848
f"/api/v1.0/templates/{template.id!s}/accesses/",
4949
)
5050
assert response.status_code == 200
51-
assert response.json() == {
52-
"count": 0,
53-
"next": None,
54-
"previous": None,
55-
"results": [],
56-
}
51+
assert response.json() == []
5752

5853

5954
@pytest.mark.parametrize("via", VIA)
@@ -96,8 +91,8 @@ def test_api_template_accesses_list_authenticated_related(via, mock_user_teams):
9691

9792
assert response.status_code == 200
9893
content = response.json()
99-
assert len(content["results"]) == 3
100-
assert sorted(content["results"], key=lambda x: x["id"]) == sorted(
94+
assert len(content) == 3
95+
assert sorted(content, key=lambda x: x["id"]) == sorted(
10196
[
10297
{
10398
"id": str(user_access.id),

0 commit comments

Comments
 (0)