diff --git a/label_studio/data_import/api.py b/label_studio/data_import/api.py index f5331c25dd5a..388d663f2522 100644 --- a/label_studio/data_import/api.py +++ b/label_studio/data_import/api.py @@ -28,6 +28,7 @@ from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.views import APIView +from tasks.functions import update_tasks_counters from tasks.models import Prediction, Task from users.models import User from webhooks.models import WebhookAction @@ -358,7 +359,7 @@ def create(self, request, *args, **kwargs): ) ) predictions_obj = Prediction.objects.bulk_create(predictions, batch_size=settings.BATCH_SIZE) - project.update_tasks_counters(Task.objects.filter(id__in=tasks_ids)) + start_job_async_or_sync(update_tasks_counters, Task.objects.filter(id__in=tasks_ids)) return Response({'created': len(predictions_obj)}, status=status.HTTP_201_CREATED) diff --git a/label_studio/data_manager/actions/basic.py b/label_studio/data_manager/actions/basic.py index bdc588841906..42676390c73a 100644 --- a/label_studio/data_manager/actions/basic.py +++ b/label_studio/data_manager/actions/basic.py @@ -9,6 +9,7 @@ from data_manager.functions import evaluate_predictions from django.conf import settings from projects.models import Project +from tasks.functions import update_tasks_counters from tasks.models import Annotation, AnnotationDraft, Prediction, Task from webhooks.models import WebhookAction from webhooks.utils import emit_webhooks_for_instance @@ -115,7 +116,7 @@ def delete_tasks_predictions(project, queryset, **kwargs): real_task_ids = set(list(predictions.values_list('task__id', flat=True))) count = predictions.count() predictions.delete() - project.update_tasks_counters(Task.objects.filter(id__in=real_task_ids)) + start_job_async_or_sync(update_tasks_counters, Task.objects.filter(id__in=real_task_ids)) return {'processed_items': count, 'detail': 'Deleted ' + str(count) + ' predictions'} diff --git a/label_studio/projects/mixins.py b/label_studio/projects/mixins.py index 9a9904f5b73f..0d6b29aaa1f7 100644 --- a/label_studio/projects/mixins.py +++ b/label_studio/projects/mixins.py @@ -11,14 +11,6 @@ def rearrange_overlap_cohort(self): """ start_job_async_or_sync(self._rearrange_overlap_cohort) - def update_tasks_counters(self, tasks_queryset, from_scratch=True): - """ - Async start updating tasks counters - :param tasks_queryset: Tasks to update queryset - :param from_scratch: Skip calculated tasks - """ - start_job_async_or_sync(self._update_tasks_counters, tasks_queryset, from_scratch=from_scratch) - def update_tasks_counters_and_is_labeled(self, tasks_queryset, from_scratch=True): """ Async start updating tasks counters and than is_labeled diff --git a/label_studio/projects/models.py b/label_studio/projects/models.py index 3a95d7b7a7ee..39f8666b5a66 100644 --- a/label_studio/projects/models.py +++ b/label_studio/projects/models.py @@ -5,7 +5,6 @@ from typing import Any, Mapping, Optional from annoying.fields import AutoOneToOneField -from core.bulk_update_utils import bulk_update from core.feature_flags import flag_set from core.label_config import ( check_control_in_config_by_regex, @@ -28,7 +27,6 @@ merge_labels_counters, ) from core.utils.exceptions import LabelStudioValidationErrorSentryIgnored -from data_manager.managers import TaskQuerySet from django.conf import settings from django.core.validators import MaxLengthValidator, MinLengthValidator from django.db import models, transaction @@ -919,56 +917,6 @@ def resolve_storage_uri(self, url: str) -> Optional[Mapping[str, Any]]: 'presign_ttl': storage.presign_ttl, } - def _update_tasks_counters(self, queryset, from_scratch=True): - """ - Update tasks counters - :param queryset: Tasks to update queryset - :param from_scratch: Skip calculated tasks - :return: Count of updated tasks - """ - objs = [] - - total_annotations = Count('annotations', distinct=True, filter=Q(annotations__was_cancelled=False)) - cancelled_annotations = Count('annotations', distinct=True, filter=Q(annotations__was_cancelled=True)) - total_predictions = Count('predictions', distinct=True) - # construct QuerySet in case of list of Tasks - if isinstance(queryset, list) and len(queryset) > 0 and isinstance(queryset[0], Task): - queryset = Task.objects.filter(id__in=[task.id for task in queryset]) - # construct QuerySet in case annotated queryset - if isinstance(queryset, TaskQuerySet) and queryset.exists() and isinstance(queryset[0], int): - queryset = Task.objects.filter(id__in=queryset) - - if not from_scratch: - queryset = queryset.exclude( - Q(total_annotations__gt=0) | Q(cancelled_annotations__gt=0) | Q(total_predictions__gt=0) - ) - - # filter our tasks with 0 annotations and 0 predictions and update them with 0 - queryset.filter(annotations__isnull=True, predictions__isnull=True).update( - total_annotations=0, cancelled_annotations=0, total_predictions=0 - ) - - # filter our tasks with 0 annotations and 0 predictions - queryset = queryset.filter(Q(annotations__isnull=False) | Q(predictions__isnull=False)) - queryset = queryset.annotate( - new_total_annotations=total_annotations, - new_cancelled_annotations=cancelled_annotations, - new_total_predictions=total_predictions, - ) - - for task in queryset.only('id', 'total_annotations', 'cancelled_annotations', 'total_predictions'): - task.total_annotations = task.new_total_annotations - task.cancelled_annotations = task.new_cancelled_annotations - task.total_predictions = task.new_total_predictions - objs.append(task) - with transaction.atomic(): - bulk_update( - objs, - update_fields=['total_annotations', 'cancelled_annotations', 'total_predictions'], - batch_size=settings.BATCH_SIZE, - ) - return len(objs) - def _update_tasks_counters_and_is_labeled(self, task_ids, from_scratch=True): """ Update tasks counters and is_labeled in batches of size settings.BATCH_SIZE. @@ -976,6 +924,8 @@ def _update_tasks_counters_and_is_labeled(self, task_ids, from_scratch=True): :param from_scratch: Skip calculated tasks :return: Count of updated tasks """ + from tasks.functions import update_tasks_counters + num_tasks_updated = 0 page_idx = 0 @@ -984,7 +934,7 @@ def _update_tasks_counters_and_is_labeled(self, task_ids, from_scratch=True): # If counters are updated, is_labeled must be updated as well. Hence, if either fails, we # will roll back. queryset = make_queryset_from_iterable(task_ids_slice) - num_tasks_updated += self._update_tasks_counters(queryset, from_scratch) + num_tasks_updated += update_tasks_counters(queryset, from_scratch) bulk_update_stats_project_tasks(queryset, self) page_idx += 1 return num_tasks_updated @@ -1004,8 +954,10 @@ def _update_tasks_counters_and_task_states( :param from_scratch: Skip calculated tasks :return: Count of updated tasks """ + from tasks.functions import update_tasks_counters + queryset = make_queryset_from_iterable(queryset) - objs = self._update_tasks_counters(queryset, from_scratch) + objs = update_tasks_counters(queryset, from_scratch) self._update_tasks_states(maximum_annotations_changed, overlap_cohort_percentage_changed, tasks_number_changed) if recalculate_all_stats and recalculate_stats_counts: diff --git a/label_studio/tasks/functions.py b/label_studio/tasks/functions.py index 2503002c699f..5e3ea9532299 100644 --- a/label_studio/tasks/functions.py +++ b/label_studio/tasks/functions.py @@ -3,13 +3,17 @@ import os import sys +from core.bulk_update_utils import bulk_update from core.models import AsyncMigrationStatus from core.redis import start_job_async_or_sync from core.utils.common import batch from data_export.mixins import ExportMixin from data_export.models import DataExport from data_export.serializers import ExportDataSerializer +from data_manager.managers import TaskQuerySet from django.conf import settings +from django.db import transaction +from django.db.models import Count, Q from organizations.models import Organization from projects.models import Project from tasks.models import Annotation, Prediction, Task @@ -57,25 +61,34 @@ def redis_job_for_calculation(org_id, from_scratch, migration_name='0018_manual_ handler.setFormatter(formatter) logger.addHandler(handler) - projects = Project.objects.filter(organization_id=org_id).order_by('-updated_at') - for project in projects: + project_dicts = ( + Project.objects.filter(organization_id=org_id) + .order_by('-updated_at') + .values( + 'id', + 'updated_at', + 'title', + ) + ) + for project_dict in project_dicts: migration = AsyncMigrationStatus.objects.create( - project=project, + project_id=project_dict['id'], name=migration_name, status=AsyncMigrationStatus.STATUS_STARTED, ) + project_tasks = Task.objects.filter(project_id=project_dict['id']) logger.debug( - f'Start processing stats project <{project.title}> ({project.id}) ' - f'with task count {project.tasks.count()} and updated_at {project.updated_at}' + f'Start processing stats project <{project_dict["title"]}> ({project_dict["id"]}) ' + f'with task count {project_tasks.count()} and updated_at {project_dict["updated_at"]}' ) - task_count = project.update_tasks_counters(project.tasks.all(), from_scratch=from_scratch) + task_count = update_tasks_counters(project_tasks, from_scratch=from_scratch) migration.status = AsyncMigrationStatus.STATUS_FINISHED - migration.meta = {'tasks_processed': task_count, 'total_project_tasks': project.tasks.count()} + migration.meta = {'tasks_processed': task_count, 'total_project_tasks': project_tasks.count()} migration.save() logger.debug( - f'End processing counters for project <{project.title}> ({project.id}), ' + f'End processing counters for project <{project_dict["title"]}> ({project_dict["id"]}), ' f'processed {str(task_count)} tasks' ) @@ -157,3 +170,54 @@ def fill_predictions_project(migration_name): logger.info('Start filling project field for Prediction model') start_job_async_or_sync(_fill_predictions_project, migration_name=migration_name) logger.info('Finished filling project field for Prediction model') + + +def update_tasks_counters(queryset, from_scratch=True): + """ + Update tasks counters for the passed queryset of Tasks + :param queryset: Tasks to update queryset + :param from_scratch: Skip calculated tasks + :return: Count of updated tasks + """ + objs = [] + + total_annotations = Count('annotations', distinct=True, filter=Q(annotations__was_cancelled=False)) + cancelled_annotations = Count('annotations', distinct=True, filter=Q(annotations__was_cancelled=True)) + total_predictions = Count('predictions', distinct=True) + # construct QuerySet in case of list of Tasks + if isinstance(queryset, list) and len(queryset) > 0 and isinstance(queryset[0], Task): + queryset = Task.objects.filter(id__in=[task.id for task in queryset]) + # construct QuerySet in case annotated queryset + if isinstance(queryset, TaskQuerySet) and queryset.exists() and isinstance(queryset[0], int): + queryset = Task.objects.filter(id__in=queryset) + + if not from_scratch: + queryset = queryset.exclude( + Q(total_annotations__gt=0) | Q(cancelled_annotations__gt=0) | Q(total_predictions__gt=0) + ) + + # filter our tasks with 0 annotations and 0 predictions and update them with 0 + queryset.filter(annotations__isnull=True, predictions__isnull=True).update( + total_annotations=0, cancelled_annotations=0, total_predictions=0 + ) + + # filter our tasks with 0 annotations and 0 predictions + queryset = queryset.filter(Q(annotations__isnull=False) | Q(predictions__isnull=False)) + queryset = queryset.annotate( + new_total_annotations=total_annotations, + new_cancelled_annotations=cancelled_annotations, + new_total_predictions=total_predictions, + ) + + for task in queryset.only('id', 'total_annotations', 'cancelled_annotations', 'total_predictions'): + task.total_annotations = task.new_total_annotations + task.cancelled_annotations = task.new_cancelled_annotations + task.total_predictions = task.new_total_predictions + objs.append(task) + with transaction.atomic(): + bulk_update( + objs, + update_fields=['total_annotations', 'cancelled_annotations', 'total_predictions'], + batch_size=settings.BATCH_SIZE, + ) + return len(objs) diff --git a/label_studio/tasks/management/commands/calculate_stats.py b/label_studio/tasks/management/commands/calculate_stats.py index 65220bc3df5e..43381a9c9c59 100644 --- a/label_studio/tasks/management/commands/calculate_stats.py +++ b/label_studio/tasks/management/commands/calculate_stats.py @@ -1,7 +1,9 @@ import logging +from core.redis import start_job_async_or_sync from django.core.management.base import BaseCommand from projects.models import Project +from tasks.functions import update_tasks_counters logger = logging.getLogger(__name__) @@ -18,7 +20,7 @@ def handle(self, *args, **options): for project in projects: logger.debug(f'Start processing project {project.id}.') - project.update_tasks_counters(project.tasks.all()) + start_job_async_or_sync(update_tasks_counters, project.tasks.all()) logger.debug(f'End processing project {project.id}.') logger.debug(f"Organization {options['organization']} stats were recalculated.")