diff --git a/cvat/apps/engine/background.py b/cvat/apps/engine/background.py index 02628354b2d..e3683ec51d8 100644 --- a/cvat/apps/engine/background.py +++ b/cvat/apps/engine/background.py @@ -477,7 +477,7 @@ def setup_background_job( result_url = self.make_result_url() with get_rq_lock_by_user(queue, user_id): - meta = ExportRQMeta.build( + meta = ExportRQMeta.build_for( request=self.request, db_obj=self.db_instance, result_url=result_url, @@ -758,7 +758,7 @@ def setup_background_job( user_id = self.request.user.id with get_rq_lock_by_user(queue, user_id): - meta = ExportRQMeta.build( + meta = ExportRQMeta.build_for( request=self.request, db_obj=self.db_instance, result_url=result_url, diff --git a/cvat/apps/engine/backup.py b/cvat/apps/engine/backup.py index 8304c6d287d..dff061f0c15 100644 --- a/cvat/apps/engine/backup.py +++ b/cvat/apps/engine/backup.py @@ -1197,7 +1197,7 @@ def _import(importer, request: PatchedRequest, queue, rq_id, Serializer, file_fi user_id = request.user.id with get_rq_lock_by_user(queue, user_id): - meta = ImportRQMeta.build( + meta = ImportRQMeta.build_for( request=request, db_obj=None, tmp_file=filename, diff --git a/cvat/apps/engine/rq_job_handler.py b/cvat/apps/engine/rq_job_handler.py index 0d6b9c3ac46..1046385f7ab 100644 --- a/cvat/apps/engine/rq_job_handler.py +++ b/cvat/apps/engine/rq_job_handler.py @@ -83,7 +83,7 @@ def reset_meta_on_retry(self) -> dict[RQJobMetaField, Any]: @attrs.define(kw_only=True) class RQMetaWithFailureInfo(AbstractRQMeta): - # immutable and optional fields + # mutable && optional fields formatted_exception: str | None = attrs.field( validator=[optional_str_validator], default=None, @@ -98,7 +98,6 @@ class RQMetaWithFailureInfo(AbstractRQMeta): @staticmethod def _get_resettable_fields() -> list[RQJobMetaField]: - """Return a list of fields that must be reset on retry""" return [ RQJobMetaField.FORMATTED_EXCEPTION, RQJobMetaField.EXCEPTION_TYPE, @@ -111,12 +110,12 @@ class BaseRQMeta(RQMetaWithFailureInfo): # immutable and required fields user: UserInfo = attrs.field( validator=[attrs.validators.instance_of(UserInfo)], - converter=lambda d: UserInfo(**d), + converter=lambda x: x if isinstance(x, UserInfo) else UserInfo(**x), on_setattr=attrs.setters.frozen, ) request: RequestInfo = attrs.field( validator=[attrs.validators.instance_of(RequestInfo)], - converter=lambda d: RequestInfo(**d), + converter=lambda x: x if isinstance(x, RequestInfo) else RequestInfo(**x), on_setattr=attrs.setters.frozen, ) @@ -137,17 +136,19 @@ class BaseRQMeta(RQMetaWithFailureInfo): validator=[optional_int_validator], default=None, on_setattr=attrs.setters.frozen ) - # import && lambda + # mutable fields progress: float | None = attrs.field( validator=[optional_float_validator], default=None, on_setattr=_update_value, ) + status: str = attrs.field( + validator=[str_validator], default="", on_setattr=_update_value + ) @staticmethod def _get_resettable_fields() -> list[RQJobMetaField]: - """Return a list of fields that must be reset on retry""" - return RQMetaWithFailureInfo._get_resettable_fields() + [RQJobMetaField.PROGRESS] + return RQMetaWithFailureInfo._get_resettable_fields() + [RQJobMetaField.PROGRESS, RQJobMetaField.STATUS] @classmethod def build( @@ -189,16 +190,15 @@ def build( @attrs.define(kw_only=True) class ExportRQMeta(BaseRQMeta): # will be changed to ExportResultInfo in the next PR - result_url: str | None = attrs.field(validator=[optional_str_validator]) + result_url: str | None = attrs.field(validator=[optional_str_validator], default=None) @staticmethod def _get_resettable_fields() -> list[RQJobMetaField]: - """Return a list of fields that must be reset on retry""" base_fields = BaseRQMeta._get_resettable_fields() return base_fields + [RQJobMetaField.RESULT] @classmethod - def build( + def build_for( cls, *, request: PatchedRequest, @@ -221,27 +221,18 @@ class ImportRQMeta(BaseRQMeta): ) # mutable fields - # TODO: move into base? - status: str = attrs.field( - validator=[optional_str_validator], default="", on_setattr=_update_value - ) task_progress: float | None = attrs.field( validator=[optional_float_validator], default=None, on_setattr=_update_value - ) + ) # used when importing project dataset @staticmethod def _get_resettable_fields() -> list[RQJobMetaField]: - """Return a list of fields that must be reset on retry""" base_fields = BaseRQMeta._get_resettable_fields() - return base_fields + [ - RQJobMetaField.PROGRESS, - RQJobMetaField.TASK_PROGRESS, - RQJobMetaField.STATUS, - ] + return base_fields + [RQJobMetaField.TASK_PROGRESS] @classmethod - def build( + def build_for( cls, *, request: PatchedRequest, @@ -255,42 +246,6 @@ def build( tmp_file=tmp_file, ).to_dict() - -@attrs.define(kw_only=True) -class LambdaRQMeta(BaseRQMeta): - # immutable fields - function_id: int | None = attrs.field( - validator=[optional_int_validator], default=None, on_setattr=attrs.setters.frozen - ) - lambda_: bool | None = attrs.field( - validator=[optional_bool_validator], - init=False, - default=True, - on_setattr=attrs.setters.frozen, - ) - - def to_dict(self) -> dict: - d = asdict(self) - if v := d.pop(RQJobMetaField.LAMBDA + "_", None) is not None: - d[RQJobMetaField.LAMBDA] = v - - return d - - @classmethod - def build( - cls, - *, - request: PatchedRequest, - db_obj: Model, - function_id: int, - ): - base_meta = BaseRQMeta.build(request=request, db_obj=db_obj) - return cls( - **base_meta, - function_id=function_id, - ).to_dict() - - class RQJobMetaField: # common fields FORMATTED_EXCEPTION = "formatted_exception" diff --git a/cvat/apps/engine/serializers.py b/cvat/apps/engine/serializers.py index 056c794088c..fc12af4a9d2 100644 --- a/cvat/apps/engine/serializers.py +++ b/cvat/apps/engine/serializers.py @@ -44,7 +44,6 @@ BaseRQMeta, ExportRQMeta, ImportRQMeta, - LambdaRQMeta, RequestAction, RQId, ) @@ -60,6 +59,7 @@ reverse, take_by, ) +from cvat.apps.lambda_manager.rq import LambdaRQMeta from utils.dataset_manifest import ImageManifestManager slogger = ServerLogManager(__name__) @@ -3553,17 +3553,19 @@ class RequestSerializer(serializers.Serializer): result_url = serializers.URLField(required=False, allow_null=True) result_id = serializers.IntegerField(required=False, allow_null=True) + def __init__(self, *args, **kwargs): + self._base_rq_job_meta: BaseRQMeta | None = None + super().__init__(*args, **kwargs) + @extend_schema_field(UserIdentifiersSerializer()) def get_owner(self, rq_job: RQJob) -> dict[str, Any]: - # TODO: define parsed meta once - rq_job_meta = BaseRQMeta.from_job(rq_job) - return UserIdentifiersSerializer(rq_job_meta.user.to_dict()).data + assert self._base_rq_job_meta + return UserIdentifiersSerializer(self._base_rq_job_meta.user.to_dict()).data @extend_schema_field( serializers.FloatField(min_value=0, max_value=1, required=False, allow_null=True) ) def get_progress(self, rq_job: RQJob) -> Decimal: - # TODO: define parsed meta once rq_job_meta = ImportRQMeta.from_job(rq_job) # progress of task creation is stored in "task_progress" field # progress of project import is stored in "progress" field @@ -3585,19 +3587,19 @@ def get_expiry_date(self, rq_job: RQJob) -> Optional[str]: @extend_schema_field(serializers.CharField(allow_blank=True)) def get_message(self, rq_job: RQJob) -> str: - # TODO: define parsed meta once - rq_job_meta = ImportRQMeta.from_job(rq_job) + assert self._base_rq_job_meta rq_job_status = rq_job.get_status() message = '' if RQJobStatus.STARTED == rq_job_status: - message = rq_job_meta.status + message = self._base_rq_job_meta.status elif RQJobStatus.FAILED == rq_job_status: - message = rq_job_meta.formatted_exception or parse_exception_message(str(rq_job.exc_info or "Unknown error")) + message = self._base_rq_job_meta.formatted_exception or parse_exception_message(str(rq_job.exc_info or "Unknown error")) return message def to_representation(self, rq_job: RQJob) -> dict[str, Any]: + self._base_rq_job_meta = BaseRQMeta.from_job(rq_job) representation = super().to_representation(rq_job) # FUTURE-TODO: support such statuses on UI diff --git a/cvat/apps/engine/task.py b/cvat/apps/engine/task.py index 9533807ac1a..d5835639a0b 100644 --- a/cvat/apps/engine/task.py +++ b/cvat/apps/engine/task.py @@ -83,7 +83,7 @@ def create( func=_create_thread, args=(db_task.pk, data), job_id=rq_id, - meta=ImportRQMeta.build(request=request, db_obj=db_task), + meta=ImportRQMeta.build_for(request=request, db_obj=db_task), depends_on=define_dependent_job(q, user_id), failure_ttl=settings.IMPORT_CACHE_FAILED_TTL.total_seconds(), ) diff --git a/cvat/apps/engine/views.py b/cvat/apps/engine/views.py index 8bcc2e8d522..086749e9db0 100644 --- a/cvat/apps/engine/views.py +++ b/cvat/apps/engine/views.py @@ -3446,7 +3446,7 @@ def _import_annotations(request, rq_id_factory, rq_func, db_obj, format_name, user_id = request.user.id with get_rq_lock_by_user(queue, user_id): - meta = ImportRQMeta.build(request=request, db_obj=db_obj, tmp_file=filename) + meta = ImportRQMeta.build_for(request=request, db_obj=db_obj, tmp_file=filename) rq_job = queue.enqueue_call( func=func, args=func_args, @@ -3548,7 +3548,7 @@ def _import_project_dataset( user_id = request.user.id with get_rq_lock_by_user(queue, user_id): - meta = ImportRQMeta.build(request=request, db_obj=db_obj, tmp_file=filename) + meta = ImportRQMeta.build_for(request=request, db_obj=db_obj, tmp_file=filename) rq_job = queue.enqueue_call( func=func, args=func_args, diff --git a/cvat/apps/lambda_manager/rq.py b/cvat/apps/lambda_manager/rq.py new file mode 100644 index 00000000000..5d6cd6f3f52 --- /dev/null +++ b/cvat/apps/lambda_manager/rq.py @@ -0,0 +1,47 @@ +# Copyright (C) CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import attrs +from attrs import asdict +from django.db.models import Model + +from cvat.apps.engine.middleware import PatchedRequest +from cvat.apps.engine.rq_job_handler import BaseRQMeta, RQJobMetaField + + +@attrs.define(kw_only=True) +class LambdaRQMeta(BaseRQMeta): + # immutable fields + function_id: int = attrs.field( + validator=[attrs.validators.instance_of(int)], default=None, on_setattr=attrs.setters.frozen + ) + lambda_: bool = attrs.field( + validator=[attrs.validators.instance_of(bool)], + init=False, + default=True, + on_setattr=attrs.setters.frozen, + ) + + def to_dict(self) -> dict: + d = asdict(self) + if v := d.pop(RQJobMetaField.LAMBDA + "_", None) is not None: + d[RQJobMetaField.LAMBDA] = v + + return d + + @classmethod + def build_for( + cls, + *, + request: PatchedRequest, + db_obj: Model, + function_id: int, + ): + base_meta = BaseRQMeta.build(request=request, db_obj=db_obj) + return cls( + **base_meta, + function_id=function_id, + ).to_dict() diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index bf2308173db..54d4574e0bd 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -46,13 +46,14 @@ SourceType, Task, ) -from cvat.apps.engine.rq_job_handler import LambdaRQMeta, RQId +from cvat.apps.engine.rq_job_handler import RQId from cvat.apps.engine.serializers import LabeledDataSerializer from cvat.apps.engine.utils import define_dependent_job, get_rq_lock_by_user from cvat.apps.events.handlers import handle_function_call from cvat.apps.iam.filters import ORGANIZATION_OPEN_API_PARAMETERS from cvat.apps.lambda_manager.models import FunctionKind from cvat.apps.lambda_manager.permissions import LambdaPermission +from cvat.apps.lambda_manager.rq import LambdaRQMeta from cvat.apps.lambda_manager.serializers import ( FunctionCallRequestSerializer, FunctionCallSerializer, @@ -640,7 +641,7 @@ def enqueue( user_id = request.user.id with get_rq_lock_by_user(queue, user_id): - meta = LambdaRQMeta.build( + meta = LambdaRQMeta.build_for( request=request, db_obj=Job.objects.get(pk=job) if job else Task.objects.get(pk=task), function_id=lambda_func.id,