Skip to content

Commit

Permalink
feat: adds endpoint to list group memberships by learner
Browse files Browse the repository at this point in the history
  • Loading branch information
katrinan029 committed Feb 25, 2025
1 parent ff9cd5a commit b46b7ab
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 2 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ Unreleased
----------
* nothing unreleased

[5.7.0]
--------
* feat: adds endpoint to list group memberships for a learner

[5.6.12]
--------
* chore: Upgrade Python requirements
Expand Down
2 changes: 1 addition & 1 deletion enterprise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Your project description goes here.
"""

__version__ = "5.6.12"
__version__ = "5.7.0"
3 changes: 2 additions & 1 deletion enterprise/api/v1/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ class EnterpriseGroupMembershipSerializer(serializers.ModelSerializer):
)
enterprise_group_membership_uuid = serializers.UUIDField(source='uuid', allow_null=True, read_only=True)
activated_at = serializers.DateTimeField(required=False)

name = serializers.CharField(source='group.name')
member_details = serializers.SerializerMethodField()
recent_action = serializers.SerializerMethodField()
status = serializers.CharField(required=False)
Expand All @@ -678,6 +678,7 @@ class Meta:
'status',
'activated_at',
'enrollments',
'name',
)

def get_member_details(self, obj):
Expand Down
13 changes: 13 additions & 0 deletions enterprise/api/v1/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
enterprise_customer_sso_configuration,
enterprise_customer_support,
enterprise_customer_user,
enterprise_group_membership,
enterprise_group,
enterprise_subsidy_fulfillment,
notifications,
Expand Down Expand Up @@ -83,6 +84,11 @@
router.register(
"enterprise_group", enterprise_group.EnterpriseGroupViewSet, 'enterprise-group'
)
router.register(
"enterprise-group-membership",
enterprise_group_membership.EnterpriseGroupMembershipViewSet,
'enterprise-group-membership'
)
router.register(
"default-enterprise-enrollment-intentions",
default_enterprise_enrollments.DefaultEnterpriseEnrollmentIntentionViewSet,
Expand Down Expand Up @@ -189,6 +195,13 @@
),
name='enterprise-group-learners'
),
re_path(
r'^enterprise-group-membership/?$',
enterprise_group_membership.EnterpriseGroupMembershipViewSet.as_view(
{'get': 'get_memberships'}
),
name='enterprise-group-membership'
),
re_path(
r'^enterprise_group/(?P<group_uuid>[A-Za-z0-9-]+)/assign_learners/?$',
enterprise_group.EnterpriseGroupViewSet.as_view({'post': 'assign_learners'}),
Expand Down
64 changes: 64 additions & 0 deletions enterprise/api/v1/views/enterprise_group_membership.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
Views for the ``enterprise-group-membership`` API endpoint.
"""

from django_filters.rest_framework import DjangoFilterBackend
from rest_framework import filters, permissions
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.status import HTTP_400_BAD_REQUEST

from django.contrib import auth

from enterprise import models
from enterprise.api.v1 import serializers
from enterprise.api.v1.views.base_views import EnterpriseReadWriteModelViewSet
from enterprise.logging import getEnterpriseLogger

LOGGER = getEnterpriseLogger(__name__)

User = auth.get_user_model()


class EnterpriseGroupMembershipViewSet(EnterpriseReadWriteModelViewSet):
"""
API views for the ``enterprise-group-membership`` API endpoint.
"""
queryset = models.EnterpriseGroupMembership.all_objects.all()
permission_classes = (permissions.IsAuthenticated,)
filter_backends = (filters.OrderingFilter, DjangoFilterBackend,)
serializer_class = serializers.EnterpriseGroupMembershipSerializer

@action(detail=False, methods=['get'])
def get_memberships(self, request):
"""
Endpoint that filters by `lms_user_id` and `enterprise_uuid`.
Parameters:
- `lms_user_id` (str): Filter results by the LMS user ID.
- `enterprise_uuid` (str): Filter results by the Enterprise UUID.
Response:
- Returns a list of EnterpriseGroupMemberships matching the filters.
- Response format: JSON array of serialized `EnterpriseGroupMembership` objects.
"""
queryset = self.queryset

lms_user_id = request.query_params.get('lms_user_id')
enterprise_uuid = request.query_params.get('enterprise_uuid')

if not lms_user_id or not enterprise_uuid:
return Response(
{"error": "Both 'lms_user_id' and 'enterprise_uuid' are required parameters."},
status=HTTP_400_BAD_REQUEST
)

if lms_user_id:
queryset = queryset.filter(enterprise_customer_user__user_id=lms_user_id)
if enterprise_uuid:
queryset = queryset.filter(enterprise_customer_user__enterprise_customer__uuid=enterprise_uuid)

page = self.paginate_queryset(queryset)

serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
45 changes: 45 additions & 0 deletions tests/test_enterprise/api/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9882,3 +9882,48 @@ def test_list_users_filtered(self):

assert expected_json == response.json().get('results')
assert response.json().get('count') == 1


class EnterpriseGroupMembershipViewSetTests(BaseTestEnterpriseAPIViews):
"""Unit tests for EnterpriseGroupMembershipViewSet."""

def setUp(self):
"""Set up test data."""
super().setUp()
self.enterprise_customer = EnterpriseCustomerFactory()
self.enterprise_customer_user = EnterpriseCustomerUserFactory(
user_id=self.user.id, enterprise_customer=self.enterprise_customer
)
self.enterprise_customer_user = EnterpriseCustomerUserFactory(
user_id="123", enterprise_customer__uuid=FAKE_UUIDS[0])
self.membership1 = EnterpriseGroupMembershipFactory(enterprise_customer_user=self.enterprise_customer_user)
self.membership2 = EnterpriseGroupMembershipFactory()

self.url = reverse("enterprise-group-membership")

def test_missing_required_params(self):
"""Ensure API returns 400 Bad Request if required parameters are missing."""
response = self.client.get(self.url)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("error", response.json())

def test_filter_by_lms_user_id_and_enterprise_uuid(self):
"""Ensure filtering by lms_user_id and enterprise_uuid returns correct results."""
response = self.client.get(self.url, {"lms_user_id": "123", "enterprise_uuid": FAKE_UUIDS[0]})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.json()["results"]), 1)

def test_no_matching_results(self):
"""Ensure API returns empty results if no matching records exist."""
response = self.client.get(self.url, {"lms_user_id": "999", "enterprise_uuid": FAKE_UUIDS[0]})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.json()["results"]), 0)

def test_pagination_applied(self):
"""Ensure pagination applies correctly if multiple records exist."""
EnterpriseGroupMembershipFactory.create_batch(15, enterprise_customer_user=self.enterprise_customer_user)
response = self.client.get(self.url, {"lms_user_id": "123", "enterprise_uuid": FAKE_UUIDS[0]})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("count", response.json())
self.assertIn("next", response.json())
self.assertIn("results", response.json())

0 comments on commit b46b7ab

Please sign in to comment.