Skip to content

Commit

Permalink
Merge pull request #146 from openedx/alisalman/fix-thread-api-cohort-…
Browse files Browse the repository at this point in the history
…users

fix: add group_id and group_ids params in the get_user_threads
  • Loading branch information
Ali-Salman29 authored Jan 16, 2025
2 parents 35ebf78 + a6ec30b commit 342ad3b
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 3 deletions.
4 changes: 4 additions & 0 deletions forum/api/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,8 @@ def get_user_threads(
request_id: Optional[str] = None,
commentable_ids: Optional[str] = None,
user_id: Optional[str] = None,
group_id: Optional[int] = None,
group_ids: Optional[int] = None,
) -> dict[str, Any]:
"""
Get the threads for the given thread_ids.
Expand All @@ -377,6 +379,8 @@ def get_user_threads(
"request_id": request_id,
"commentable_ids": commentable_ids,
"user_id": user_id,
"group_id": group_id,
"group_ids": group_ids,
}
params = {k: v for k, v in params.items() if v is not None}
backend.validate_params(params)
Expand Down
2 changes: 2 additions & 0 deletions forum/backends/mongodb/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,8 @@ def validate_params(params: dict[str, Any], user_id: Optional[str] = None) -> No
"per_page",
"request_id",
"commentable_ids",
"group_id",
"group_ids",
]
if not user_id:
valid_params.append("user_id")
Expand Down
2 changes: 2 additions & 0 deletions forum/backends/mysql/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1010,6 +1010,8 @@ def validate_params(
"per_page",
"request_id",
"commentable_ids",
"group_id",
"group_ids",
]
if not user_id:
valid_params.append("user_id")
Expand Down
1 change: 1 addition & 0 deletions forum/backends/mysql/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ def to_dict(self) -> dict[str, Any]:
"created_at": self.created_at,
"last_activity_at": self.last_activity_at,
"edit_history": edit_history,
"group_id": self.group_id,
}

def doc_to_hash(self) -> dict[str, Any]:
Expand Down
4 changes: 2 additions & 2 deletions forum/serializers/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class ThreadSerializer(ContentSerializer):
last_activity_at (datetime): The timestamp of the last activity in the thread.
closed_by (str or None): The user who closed the thread, if any.
tags (list): A list of tags associated with the thread.
group_id (str or None): The ID of the group associated with the thread, if any.
group_id (int or None): The ID of the group associated with the thread, if any.
pinned (bool): Whether the thread is pinned at the top of the list.
comment_count (int): The number of comments on the thread.
Expand All @@ -53,7 +53,7 @@ class ThreadSerializer(ContentSerializer):
closed_by = serializers.SerializerMethodField()
close_reason_code = serializers.CharField(allow_null=True, default=None)
tags = serializers.ListField(default=[])
group_id = serializers.CharField(allow_null=True, default=None)
group_id = serializers.IntegerField(allow_null=True, default=None)
pinned = serializers.BooleanField(default=False)
comments_count = serializers.IntegerField(required=False, source="comment_count")
read = serializers.SerializerMethodField()
Expand Down
77 changes: 76 additions & 1 deletion tests/test_views/test_threads.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Test threads api endpoints."""

import time
from typing import Any, Optional

import pytest

from forum.backends.mongodb.api import MongoBackend
Expand Down Expand Up @@ -901,3 +901,78 @@ def is_thread_id_exists_in_user_read_state(user_id: str, thread_id: str) -> bool
if thread_id in read_state.get("last_read_times", {}):
return True
return False


def test_filter_by_group_id(api_client: APIClient, patched_get_backend: Any) -> None:
"""
Filter threads by their group_id. This should return:
- Threads with the specified group ID.
- Threads that do not belong to any group (i.e., group_id=None).
"""
backend = patched_get_backend
setup_models(backend=backend)

for i in range(2, 5):
time.sleep(0.001)
backend.create_thread(
{
"title": f"Thread {i}",
"body": f"Thread {i}",
"course_id": "course1",
"commentable_id": "CommentThread",
"author_id": "1",
"author_username": "user1",
"abuse_flaggers": [],
"historical_abuse_flaggers": [],
"context": "course",
"group_id": i,
}
)
params = {"course_id": "course1", "group_id": "2"}
response = api_client.get_json("/api/v2/threads", params)
assert response.status_code == 200
results = response.json().get("collection", [])

# The result includes one thread from setup_models with group_id=None
# and another thread: Thread 2 with group_id=2.
assert len(results) == 2
assert results[0]["group_id"] == 2
assert results[1]["group_id"] is None


def test_filter_by_group_ids(api_client: APIClient, patched_get_backend: Any) -> None:
"""
Filter threads by their group IDs. This should return:
- Threads with the specified group IDs.
- Threads that do not belong to any group (i.e., group_id=None).
"""
backend = patched_get_backend
setup_models(backend=backend)
for i in range(2, 5):
time.sleep(0.001)
backend.create_thread(
{
"title": f"Thread {i}",
"body": f"Thread {i}",
"course_id": "course1",
"commentable_id": "CommentThread",
"author_id": "1",
"author_username": "user1",
"abuse_flaggers": [],
"historical_abuse_flaggers": [],
"context": "course",
"group_id": i,
}
)

params = {"course_id": "course1", "group_ids": "2,3"}
response = api_client.get_json("/api/v2/threads", params)
assert response.status_code == 200
results = response.json().get("collection", [])

# The result includes one thread from setup_models with group_id=None
# and two threads: Thread 2 and Thread 3 with group_id=2 and group_id=3, respectively.
assert len(results) == 3
assert results[0]["group_id"] == 3
assert results[1]["group_id"] == 2
assert results[2]["group_id"] is None

0 comments on commit 342ad3b

Please sign in to comment.