Skip to content

Commit

Permalink
Fix meta update
Browse files Browse the repository at this point in the history
  • Loading branch information
Marishka17 committed Feb 12, 2025
1 parent c9bbe48 commit 3b9aefc
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 21 deletions.
27 changes: 15 additions & 12 deletions cvat/apps/engine/rq_job_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@
optional_float_validator = attrs.validators.optional(attrs.validators.instance_of(float))


def _update_value(self: AbstractRQMeta, attribute: attrs.Attribute, value: Any):
self._job.meta[attribute.name] = value


@attrs.frozen(kw_only=True)
class UserInfo:
Expand All @@ -55,12 +52,15 @@ def to_dict(self) -> dict[str, Any]:
class AbstractRQMeta(metaclass=ABCMeta):
_job: RQJob | None = attrs.field(init=False, default=None)

def update_value(self, attribute: attrs.Attribute, value: Any):
self._job.meta[attribute.name] = value

def to_dict(self) -> dict:
return asdict(self, filter=lambda k, _: k.name != "_job")
return asdict(self, filter=lambda k, _: not k.name.startswith("_"))

@classmethod
def from_job(cls, rq_job: RQJob):
keys_to_keep = [k.name for k in attrs.fields(cls)]
keys_to_keep = [k.name for k in attrs.fields(cls) if not k.name.startswith("_")]
meta = cls(**{k: v for k, v in rq_job.meta.items() if k in keys_to_keep})
meta._job = rq_job

Expand All @@ -80,26 +80,29 @@ def reset_meta_on_retry(self) -> dict[RQJobMetaField, Any]:

return {k: v for k, v in self._job.meta.items() if k not in resettable_fields}

on_setattr = attrs.setters.pipe(attrs.setters.validate, AbstractRQMeta.update_value)

@attrs.define(kw_only=True)
class RQMetaWithFailureInfo(AbstractRQMeta):

# mutable && optional fields
formatted_exception: str | None = attrs.field(
validator=[optional_str_validator],
default=None,
on_setattr=_update_value,
on_setattr=on_setattr,
)
exc_type: type[Exception] | None = attrs.field(
default=None,
on_setattr=_update_value,
on_setattr=on_setattr,
)
@exc_type.validator
def _check_exc_type(self, attribute: attrs.Attribute, value: Any):
if value is not None and not issubclass(value, Exception):
raise ValueError("Wrong exception type")

exc_args: Iterable | None = attrs.field(default=None, on_setattr=_update_value)
exc_args: Iterable | None = attrs.field(
default=None,
on_setattr=on_setattr
)

@staticmethod
def _get_resettable_fields() -> list[RQJobMetaField]:
Expand Down Expand Up @@ -145,10 +148,10 @@ class BaseRQMeta(RQMetaWithFailureInfo):
progress: float | None = attrs.field(
validator=[optional_float_validator],
default=None,
on_setattr=_update_value,
on_setattr=on_setattr,
)
status: str = attrs.field(
validator=[str_validator], default="", on_setattr=_update_value
validator=[str_validator], default="", on_setattr=on_setattr
)

@staticmethod
Expand Down Expand Up @@ -227,7 +230,7 @@ class ImportRQMeta(BaseRQMeta):

# mutable fields
task_progress: float | None = attrs.field(
validator=[optional_float_validator], default=None, on_setattr=_update_value
validator=[optional_float_validator], default=None, on_setattr=on_setattr
) # used when importing project dataset

@staticmethod
Expand Down
34 changes: 27 additions & 7 deletions cvat/apps/lambda_manager/rq.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@
from __future__ import annotations

import attrs
from attrs import asdict
from django.db.models import Model
from rq.job import Job as RQJob

from cvat.apps.engine.middleware import PatchedRequest
from cvat.apps.engine.rq_job_handler import BaseRQMeta, RQJobMetaField
from cvat.apps.engine.rq_job_handler import BaseRQMeta, RQJobMetaField, on_setattr


@attrs.define(kw_only=True)
Expand All @@ -20,15 +20,34 @@ class LambdaRQMeta(BaseRQMeta):
)
lambda_: bool = attrs.field(
validator=[attrs.validators.instance_of(bool)],
init=False,
default=True,
default=False,
on_setattr=attrs.setters.frozen,
)

# FUTURE-FIXME: progress should be in [0, 1] range
progress: float | None = attrs.field(
validator=[attrs.validators.optional(attrs.validators.instance_of(int))],
default=None,
on_setattr=on_setattr,
)

@classmethod
def from_job(cls, rq_job: RQJob):
keys_to_keep = [k.name for k in attrs.fields(cls) if not k.name.startswith("_")]
params = {}
for k, v in rq_job.meta.items():
if k in keys_to_keep:
params[k] = v
elif k == RQJobMetaField.LAMBDA:
params[RQJobMetaField.LAMBDA + "_"] = v
meta = cls(**params)
meta._job = rq_job

return meta

def to_dict(self) -> dict:
d = asdict(self)
if v := d.pop(RQJobMetaField.LAMBDA + "_", None) is not None:
d[RQJobMetaField.LAMBDA] = v
d = super().to_dict()
d[RQJobMetaField.LAMBDA] = d.pop(RQJobMetaField.LAMBDA + "_")

return d

Expand All @@ -44,4 +63,5 @@ def build_for(
return cls(
**base_meta,
function_id=function_id,
lambda_=True,
).to_dict()
4 changes: 2 additions & 2 deletions cvat/apps/lambda_manager/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,7 +591,7 @@ def get_jobs(self):
)
jobs = queue.job_class.fetch_many(job_ids, queue.connection)

return [LambdaJob(job) for job in jobs if job and job.meta.get("lambda")]
return [LambdaJob(job) for job in jobs if job and LambdaRQMeta.from_job(job).lambda_]

def enqueue(
self,
Expand Down Expand Up @@ -702,7 +702,7 @@ def to_dict(self):
),
},
"status": self.job.get_status(),
"progress": self.job.meta.get("progress", 0),
"progress": LambdaRQMeta.from_job(self.job).progress,
"enqueued": self.job.enqueued_at,
"started": self.job.started_at,
"ended": self.job.ended_at,
Expand Down

0 comments on commit 3b9aefc

Please sign in to comment.