Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions forum/api/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def _get_thread_data_from_request_data(data: dict[str, Any]) -> dict[str, Any]:
"pinned",
"group_id",
"context",
"user_group_ids",
]
result = {field: data.get(field) for field in fields if data.get(field) is not None}

Expand Down Expand Up @@ -286,6 +287,7 @@ def create_thread(
thread_type: str = "discussion",
group_id: Optional[int] = None,
context: str = "course",
user_group_ids: Optional[list[int]] = None,
) -> dict[str, Any]:
"""
Create a new thread.
Expand Down Expand Up @@ -315,6 +317,7 @@ def create_thread(
"thread_type": thread_type,
"group_id": group_id,
"context": context,
"user_group_ids": user_group_ids,
}
thread_data: dict[str, Any] = _get_thread_data_from_request_data(data)

Expand Down Expand Up @@ -380,6 +383,7 @@ def get_user_threads(
"user_id": user_id,
"group_id": group_id,
"group_ids": group_ids,
"user_group_ids": kwargs.get("user_group_ids"),
}
params = {k: v for k, v in params.items() if v is not None}
backend.validate_params(params)
Expand Down
2 changes: 2 additions & 0 deletions forum/api/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ def get_user_active_threads(
page: Optional[int] = FORUM_DEFAULT_PAGE,
per_page: Optional[int] = FORUM_DEFAULT_PER_PAGE,
group_id: Optional[str] = None,
**kwargs,
) -> dict[str, Any]:
"""Get user active threads."""
backend = get_backend(course_id)()
Expand Down Expand Up @@ -237,6 +238,7 @@ def get_user_active_threads(
"user_id": user_id,
"course_id": course_id,
"group_ids": [int(group_id)] if group_id else [],
"user_group_ids": kwargs.get("user_group_ids"),
"author_id": author_id,
"thread_type": thread_type,
"filter_flagged": flagged,
Expand Down
13 changes: 12 additions & 1 deletion forum/backends/mysql/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@
from forum.constants import RETIRED_BODY, RETIRED_TITLE
from forum.utils import get_group_ids_from_params


class MySQLBackend(AbstractBackend):
"""MySQL backend api."""

Expand Down Expand Up @@ -606,6 +605,7 @@ def handle_threads_query(
per_page: int,
context: str = "course",
raw_query: bool = False,
**kwargs: Any, # We use kwargs for not modifying the function signature
) -> dict[str, Any]:
"""
Handles complex thread queries based on various filters and returns paginated results.
Expand Down Expand Up @@ -658,6 +658,13 @@ def handle_threads_query(
Q(group_id__in=group_ids) | Q(group_id__isnull=True)
)

# User group filtering
if kwargs.get("user_group_ids"):
user_groups_filter = Q(user_group_ids__isnull=True)
for group_id in kwargs.get("user_group_ids"):
user_groups_filter |= Q(user_group_ids__contains=group_id)
base_query = base_query.filter(user_groups_filter)

# Author filtering
if author_id:
base_query = base_query.filter(author__pk=author_id)
Expand Down Expand Up @@ -1018,6 +1025,7 @@ def validate_params(
"commentable_ids",
"group_id",
"group_ids",
"user_group_ids",
]
if not user_id:
valid_params.append("user_id")
Expand Down Expand Up @@ -1071,6 +1079,7 @@ def get_threads(
params.get("sort_key", ""),
int(params.get("page", 1)),
int(params.get("per_page", 100)),
user_group_ids=params.get("user_group_ids"),
)
context: dict[str, Any] = {
"count_flagged": count_flagged,
Expand Down Expand Up @@ -1753,6 +1762,8 @@ def create_thread(data: dict[str, Any]) -> str:
optional_args = {}
if group_id := data.get("group_id"):
optional_args["group_id"] = group_id
if user_group_ids := data.get("user_group_ids"):
optional_args["user_group_ids"] = user_group_ids
new_thread = CommentThread.objects.create(
title=data["title"],
body=data["body"],
Expand Down
4 changes: 4 additions & 0 deletions forum/backends/mysql/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ class Content(models.Model):
group_id: models.PositiveIntegerField[int, int] = models.PositiveIntegerField(
null=True
)
user_group_ids: models.JSONField[list[int], list[int]] = models.JSONField(
null=True,
)
created_at: models.DateTimeField[datetime, datetime] = models.DateTimeField(
auto_now_add=True
)
Expand Down Expand Up @@ -294,6 +297,7 @@ def to_dict(self) -> dict[str, Any]:
"last_activity_at": self.last_activity_at,
"edit_history": edit_history,
"group_id": self.group_id,
"user_group_ids": self.user_group_ids,
}

def doc_to_hash(self) -> dict[str, Any]:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Generated by Django 4.2.16 on 2025-06-18 23:58

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("forum", "0003_alter_commentthread_title"),
]

operations = [
migrations.AddField(
model_name="comment",
name="user_group_ids",
field=models.JSONField(null=True),
),
migrations.AddField(
model_name="commentthread",
name="user_group_ids",
field=models.JSONField(null=True),
),
]
1 change: 1 addition & 0 deletions forum/serializers/thread.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ class ThreadSerializer(ContentSerializer):
resp_total = serializers.SerializerMethodField(required=False)
resp_skip = serializers.IntegerField(required=False, default=0)
resp_limit = serializers.IntegerField(required=False, default=10)
user_group_ids = serializers.ListField(allow_null=True, default=None)

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""
Expand Down
Loading