From b46b7abc86230e056a5eb31cbb7b0bfb4ffd0889 Mon Sep 17 00:00:00 2001 From: Katrina Nguyen Date: Tue, 25 Feb 2025 04:13:14 +0000 Subject: [PATCH] feat: adds endpoint to list group memberships by learner --- CHANGELOG.rst | 4 ++ enterprise/__init__.py | 2 +- enterprise/api/v1/serializers.py | 3 +- enterprise/api/v1/urls.py | 13 ++++ .../v1/views/enterprise_group_membership.py | 64 +++++++++++++++++++ tests/test_enterprise/api/test_views.py | 45 +++++++++++++ 6 files changed, 129 insertions(+), 2 deletions(-) create mode 100644 enterprise/api/v1/views/enterprise_group_membership.py diff --git a/CHANGELOG.rst b/CHANGELOG.rst index ee638b468..e5bff6b87 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -17,6 +17,10 @@ Unreleased ---------- * nothing unreleased +[5.7.0] +-------- +* feat: adds endpoint to list group memberships for a learner + [5.6.12] -------- * chore: Upgrade Python requirements diff --git a/enterprise/__init__.py b/enterprise/__init__.py index bb57f7aaf..679154a9f 100644 --- a/enterprise/__init__.py +++ b/enterprise/__init__.py @@ -2,4 +2,4 @@ Your project description goes here. """ -__version__ = "5.6.12" +__version__ = "5.7.0" diff --git a/enterprise/api/v1/serializers.py b/enterprise/api/v1/serializers.py index d6ef22c11..84f16b39c 100644 --- a/enterprise/api/v1/serializers.py +++ b/enterprise/api/v1/serializers.py @@ -660,7 +660,7 @@ class EnterpriseGroupMembershipSerializer(serializers.ModelSerializer): ) enterprise_group_membership_uuid = serializers.UUIDField(source='uuid', allow_null=True, read_only=True) activated_at = serializers.DateTimeField(required=False) - + name = serializers.CharField(source='group.name') member_details = serializers.SerializerMethodField() recent_action = serializers.SerializerMethodField() status = serializers.CharField(required=False) @@ -678,6 +678,7 @@ class Meta: 'status', 'activated_at', 'enrollments', + 'name', ) def get_member_details(self, obj): diff --git a/enterprise/api/v1/urls.py b/enterprise/api/v1/urls.py index f46d5b55d..ffbe29298 100644 --- a/enterprise/api/v1/urls.py +++ b/enterprise/api/v1/urls.py @@ -22,6 +22,7 @@ enterprise_customer_sso_configuration, enterprise_customer_support, enterprise_customer_user, + enterprise_group_membership, enterprise_group, enterprise_subsidy_fulfillment, notifications, @@ -83,6 +84,11 @@ router.register( "enterprise_group", enterprise_group.EnterpriseGroupViewSet, 'enterprise-group' ) +router.register( + "enterprise-group-membership", + enterprise_group_membership.EnterpriseGroupMembershipViewSet, + 'enterprise-group-membership' +) router.register( "default-enterprise-enrollment-intentions", default_enterprise_enrollments.DefaultEnterpriseEnrollmentIntentionViewSet, @@ -189,6 +195,13 @@ ), name='enterprise-group-learners' ), + re_path( + r'^enterprise-group-membership/?$', + enterprise_group_membership.EnterpriseGroupMembershipViewSet.as_view( + {'get': 'get_memberships'} + ), + name='enterprise-group-membership' + ), re_path( r'^enterprise_group/(?P[A-Za-z0-9-]+)/assign_learners/?$', enterprise_group.EnterpriseGroupViewSet.as_view({'post': 'assign_learners'}), diff --git a/enterprise/api/v1/views/enterprise_group_membership.py b/enterprise/api/v1/views/enterprise_group_membership.py new file mode 100644 index 000000000..75d24d029 --- /dev/null +++ b/enterprise/api/v1/views/enterprise_group_membership.py @@ -0,0 +1,64 @@ +""" +Views for the ``enterprise-group-membership`` API endpoint. +""" + +from django_filters.rest_framework import DjangoFilterBackend +from rest_framework import filters, permissions +from rest_framework.decorators import action +from rest_framework.response import Response +from rest_framework.status import HTTP_400_BAD_REQUEST + +from django.contrib import auth + +from enterprise import models +from enterprise.api.v1 import serializers +from enterprise.api.v1.views.base_views import EnterpriseReadWriteModelViewSet +from enterprise.logging import getEnterpriseLogger + +LOGGER = getEnterpriseLogger(__name__) + +User = auth.get_user_model() + + +class EnterpriseGroupMembershipViewSet(EnterpriseReadWriteModelViewSet): + """ + API views for the ``enterprise-group-membership`` API endpoint. + """ + queryset = models.EnterpriseGroupMembership.all_objects.all() + permission_classes = (permissions.IsAuthenticated,) + filter_backends = (filters.OrderingFilter, DjangoFilterBackend,) + serializer_class = serializers.EnterpriseGroupMembershipSerializer + + @action(detail=False, methods=['get']) + def get_memberships(self, request): + """ + Endpoint that filters by `lms_user_id` and `enterprise_uuid`. + + Parameters: + - `lms_user_id` (str): Filter results by the LMS user ID. + - `enterprise_uuid` (str): Filter results by the Enterprise UUID. + + Response: + - Returns a list of EnterpriseGroupMemberships matching the filters. + - Response format: JSON array of serialized `EnterpriseGroupMembership` objects. + """ + queryset = self.queryset + + lms_user_id = request.query_params.get('lms_user_id') + enterprise_uuid = request.query_params.get('enterprise_uuid') + + if not lms_user_id or not enterprise_uuid: + return Response( + {"error": "Both 'lms_user_id' and 'enterprise_uuid' are required parameters."}, + status=HTTP_400_BAD_REQUEST + ) + + if lms_user_id: + queryset = queryset.filter(enterprise_customer_user__user_id=lms_user_id) + if enterprise_uuid: + queryset = queryset.filter(enterprise_customer_user__enterprise_customer__uuid=enterprise_uuid) + + page = self.paginate_queryset(queryset) + + serializer = self.get_serializer(page, many=True) + return self.get_paginated_response(serializer.data) diff --git a/tests/test_enterprise/api/test_views.py b/tests/test_enterprise/api/test_views.py index 6ca07bd1c..6bd2b2f8c 100644 --- a/tests/test_enterprise/api/test_views.py +++ b/tests/test_enterprise/api/test_views.py @@ -9882,3 +9882,48 @@ def test_list_users_filtered(self): assert expected_json == response.json().get('results') assert response.json().get('count') == 1 + + +class EnterpriseGroupMembershipViewSetTests(BaseTestEnterpriseAPIViews): + """Unit tests for EnterpriseGroupMembershipViewSet.""" + + def setUp(self): + """Set up test data.""" + super().setUp() + self.enterprise_customer = EnterpriseCustomerFactory() + self.enterprise_customer_user = EnterpriseCustomerUserFactory( + user_id=self.user.id, enterprise_customer=self.enterprise_customer + ) + self.enterprise_customer_user = EnterpriseCustomerUserFactory( + user_id="123", enterprise_customer__uuid=FAKE_UUIDS[0]) + self.membership1 = EnterpriseGroupMembershipFactory(enterprise_customer_user=self.enterprise_customer_user) + self.membership2 = EnterpriseGroupMembershipFactory() + + self.url = reverse("enterprise-group-membership") + + def test_missing_required_params(self): + """Ensure API returns 400 Bad Request if required parameters are missing.""" + response = self.client.get(self.url) + self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST) + self.assertIn("error", response.json()) + + def test_filter_by_lms_user_id_and_enterprise_uuid(self): + """Ensure filtering by lms_user_id and enterprise_uuid returns correct results.""" + response = self.client.get(self.url, {"lms_user_id": "123", "enterprise_uuid": FAKE_UUIDS[0]}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.json()["results"]), 1) + + def test_no_matching_results(self): + """Ensure API returns empty results if no matching records exist.""" + response = self.client.get(self.url, {"lms_user_id": "999", "enterprise_uuid": FAKE_UUIDS[0]}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertEqual(len(response.json()["results"]), 0) + + def test_pagination_applied(self): + """Ensure pagination applies correctly if multiple records exist.""" + EnterpriseGroupMembershipFactory.create_batch(15, enterprise_customer_user=self.enterprise_customer_user) + response = self.client.get(self.url, {"lms_user_id": "123", "enterprise_uuid": FAKE_UUIDS[0]}) + self.assertEqual(response.status_code, status.HTTP_200_OK) + self.assertIn("count", response.json()) + self.assertIn("next", response.json()) + self.assertIn("results", response.json())