Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adds endpoint to list group memberships by learner #2344

Merged
merged 3 commits into from
Feb 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ Unreleased
----------
* nothing unreleased

[5.7.0]
--------
* feat: adds endpoint to list group memberships for a learner

[5.6.12]
--------
* chore: Upgrade Python requirements
Expand Down
2 changes: 1 addition & 1 deletion enterprise/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
Your project description goes here.
"""

__version__ = "5.6.12"
__version__ = "5.7.0"
3 changes: 2 additions & 1 deletion enterprise/api/v1/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ class EnterpriseGroupMembershipSerializer(serializers.ModelSerializer):
)
enterprise_group_membership_uuid = serializers.UUIDField(source='uuid', allow_null=True, read_only=True)
activated_at = serializers.DateTimeField(required=False)

group_name = serializers.CharField(source='group.name')
member_details = serializers.SerializerMethodField()
recent_action = serializers.SerializerMethodField()
status = serializers.CharField(required=False)
Expand All @@ -678,6 +678,7 @@ class Meta:
'status',
'activated_at',
'enrollments',
'group_name',
)

def get_member_details(self, obj):
Expand Down
13 changes: 13 additions & 0 deletions enterprise/api/v1/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
enterprise_customer_support,
enterprise_customer_user,
enterprise_group,
enterprise_group_membership,
enterprise_subsidy_fulfillment,
notifications,
pending_enterprise_customer_admin_user,
Expand Down Expand Up @@ -83,6 +84,11 @@
router.register(
"enterprise_group", enterprise_group.EnterpriseGroupViewSet, 'enterprise-group'
)
router.register(
"enterprise-group-membership",
enterprise_group_membership.EnterpriseGroupMembershipViewSet,
'enterprise-group-membership'
)
router.register(
"default-enterprise-enrollment-intentions",
default_enterprise_enrollments.DefaultEnterpriseEnrollmentIntentionViewSet,
Expand Down Expand Up @@ -189,6 +195,13 @@
),
name='enterprise-group-learners'
),
re_path(
r'^enterprise-group-membership/?$',
enterprise_group_membership.EnterpriseGroupMembershipViewSet.as_view(
{'get': 'get_flex_group_memberships'}
),
name='enterprise-group-membership'
),
re_path(
r'^enterprise_group/(?P<group_uuid>[A-Za-z0-9-]+)/assign_learners/?$',
enterprise_group.EnterpriseGroupViewSet.as_view({'post': 'assign_learners'}),
Expand Down
65 changes: 65 additions & 0 deletions enterprise/api/v1/views/enterprise_group_membership.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
Views for the ``enterprise-group-membership`` API endpoint.
"""

from django_filters.rest_framework import DjangoFilterBackend
from rest_framework import filters, permissions
from rest_framework.decorators import action
from rest_framework.response import Response
from rest_framework.status import HTTP_400_BAD_REQUEST

from django.contrib import auth

from enterprise import models
from enterprise.api.v1 import serializers
from enterprise.api.v1.views.base_views import EnterpriseReadWriteModelViewSet
from enterprise.constants import GROUP_TYPE_FLEX
from enterprise.logging import getEnterpriseLogger

LOGGER = getEnterpriseLogger(__name__)

User = auth.get_user_model()


class EnterpriseGroupMembershipViewSet(EnterpriseReadWriteModelViewSet):
"""
API views for the ``enterprise-group-membership`` API endpoint.
"""
queryset = models.EnterpriseGroupMembership.all_objects.all()
permission_classes = (permissions.IsAuthenticated,)
filter_backends = (filters.OrderingFilter, DjangoFilterBackend,)
serializer_class = serializers.EnterpriseGroupMembershipSerializer

@action(detail=False, methods=['get'])
def get_flex_group_memberships(self, request):
"""
Endpoint that filters flex group memberships by `lms_user_id` and `enterprise_uuid`.

Parameters:
- `lms_user_id` (str): Filter results by the LMS user ID.
- `enterprise_uuid` (str): Filter results by the Enterprise UUID.

Response:
- Returns a list of EnterpriseGroupMemberships matching the filters.
- Response format: JSON array of serialized `EnterpriseGroupMembership` objects.
"""
queryset = self.queryset

lms_user_id = request.query_params.get('lms_user_id')
enterprise_uuid = request.query_params.get('enterprise_uuid')

if not lms_user_id or not enterprise_uuid:
return Response(
{"error": "Both 'lms_user_id' and 'enterprise_uuid' are required parameters."},
status=HTTP_400_BAD_REQUEST
)

queryset = self.queryset.filter(
enterprise_customer_user__user_id=lms_user_id,
enterprise_customer_user__enterprise_customer__uuid=enterprise_uuid,
group__group_type=GROUP_TYPE_FLEX
)
page = self.paginate_queryset(queryset)

serializer = self.get_serializer(page, many=True)
return self.get_paginated_response(serializer.data)
53 changes: 53 additions & 0 deletions tests/test_enterprise/api/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -9882,3 +9882,56 @@ def test_list_users_filtered(self):

assert expected_json == response.json().get('results')
assert response.json().get('count') == 1


class EnterpriseGroupMembershipViewSetTests(BaseTestEnterpriseAPIViews):
"""Unit tests for EnterpriseGroupMembershipViewSet."""

def setUp(self):
"""Set up test data."""
super().setUp()
self.enterprise_customer = EnterpriseCustomerFactory()
self.enterprise_customer_user = EnterpriseCustomerUserFactory(
user_id=self.user.id, enterprise_customer=self.enterprise_customer
)
self.group_1 = EnterpriseGroupFactory(enterprise_customer=self.enterprise_customer, group_type='flex')
self.group_2 = EnterpriseGroupFactory(enterprise_customer=self.enterprise_customer, group_type='budget')
self.enterprise_customer_user = EnterpriseCustomerUserFactory(
user_id="123", enterprise_customer__uuid=FAKE_UUIDS[0])
self.membership1 = EnterpriseGroupMembershipFactory(
enterprise_customer_user=self.enterprise_customer_user,
group=self.group_1,
)
self.membership2 = EnterpriseGroupMembershipFactory(
enterprise_customer_user=self.enterprise_customer_user,
group=self.group_2,
)

self.url = reverse("enterprise-group-membership")

def test_missing_required_params(self):
"""Ensure API returns 400 Bad Request if required parameters are missing."""
response = self.client.get(self.url)
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertIn("error", response.json())

def test_filter_by_lms_user_id_and_enterprise_uuid(self):
"""Ensure filtering by lms_user_id and enterprise_uuid returns correct results."""
response = self.client.get(self.url, {"lms_user_id": "123", "enterprise_uuid": FAKE_UUIDS[0]})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.json()["results"]), 1)

def test_no_matching_results(self):
"""Ensure API returns empty results if no matching records exist."""
response = self.client.get(self.url, {"lms_user_id": "999", "enterprise_uuid": FAKE_UUIDS[0]})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(len(response.json()["results"]), 0)

def test_pagination_applied(self):
"""Ensure pagination applies correctly if multiple records exist."""
EnterpriseGroupMembershipFactory.create_batch(15, enterprise_customer_user=self.enterprise_customer_user)
response = self.client.get(self.url, {"lms_user_id": "123", "enterprise_uuid": FAKE_UUIDS[0]})
self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertIn("count", response.json())
self.assertIn("next", response.json())
self.assertIn("results", response.json())