From 88ce0a440e865f7367eeaeef12a42014e3ff5432 Mon Sep 17 00:00:00 2001 From: Roman Donchenko Date: Mon, 9 Sep 2024 15:30:03 +0300 Subject: [PATCH] Rename and move `LambdaType` (#8418) This enum is used for a field named `kind` in a resource located at `/api/lambda/function`, so it seems pretty clear that it should be named `FunctionKind`. (Or perhaps `LambdaFunctionKind`, but I omitted the "lambda" for consistency with views and serializers.) In addition to renaming, move it to `models`, so that it can be used by serializers. No such serializers currently exist, but I'd like to add them later. Turn it into a Django choice enum as well, so that `__str__` works out of the box. --- cvat/apps/lambda_manager/models.py | 12 ++++++++++++ cvat/apps/lambda_manager/views.py | 29 ++++++++++------------------- 2 files changed, 22 insertions(+), 19 deletions(-) create mode 100644 cvat/apps/lambda_manager/models.py diff --git a/cvat/apps/lambda_manager/models.py b/cvat/apps/lambda_manager/models.py new file mode 100644 index 000000000000..47d732c41dd1 --- /dev/null +++ b/cvat/apps/lambda_manager/models.py @@ -0,0 +1,12 @@ +# Copyright (C) 2022 Intel Corporation +# Copyright (C) 2022-2024 CVAT.ai Corporation +# +# SPDX-License-Identifier: MIT + +import django.db.models as models + +class FunctionKind(models.TextChoices): + DETECTOR = "detector" + INTERACTOR = "interactor" + REID = "reid" + TRACKER = "tracker" diff --git a/cvat/apps/lambda_manager/views.py b/cvat/apps/lambda_manager/views.py index 9336f33ee5ed..286b8b4cc985 100644 --- a/cvat/apps/lambda_manager/views.py +++ b/cvat/apps/lambda_manager/views.py @@ -11,7 +11,6 @@ import textwrap from copy import deepcopy from datetime import timedelta -from enum import Enum from functools import wraps from typing import Any, Dict, Optional @@ -39,6 +38,7 @@ ) from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField from cvat.apps.engine.serializers import LabeledDataSerializer +from cvat.apps.lambda_manager.models import FunctionKind from cvat.apps.lambda_manager.permissions import LambdaPermission from cvat.apps.lambda_manager.serializers import ( FunctionCallRequestSerializer, FunctionCallSerializer @@ -50,15 +50,6 @@ slogger = ServerLogManager(__name__) -class LambdaType(Enum): - DETECTOR = "detector" - INTERACTOR = "interactor" - REID = "reid" - TRACKER = "tracker" - - def __str__(self): - return self.value - class LambdaGateway: NUCLIO_ROOT_URL = '/api/functions' @@ -152,7 +143,7 @@ def __init__(self, gateway, data): meta_anno = data['metadata']['annotations'] kind = meta_anno.get('type') try: - self.kind = LambdaType(kind) + self.kind = FunctionKind(kind) except ValueError as e: raise InvalidFunctionMetadataError( f"{self.id} lambda function has unknown type: {kind!r}") from e @@ -225,7 +216,7 @@ def to_dict(self): 'version': self.version } - if self.kind is LambdaType.INTERACTOR: + if self.kind is FunctionKind.INTERACTOR: response.update({ 'min_pos_points': self.min_pos_points, 'min_neg_points': self.min_neg_points, @@ -394,18 +385,18 @@ def validate_attributes_mapping(attributes_mapping, model_attributes, db_attribu code=status.HTTP_400_BAD_REQUEST) - if self.kind == LambdaType.DETECTOR: + if self.kind == FunctionKind.DETECTOR: payload.update({ "image": self._get_image(db_task, mandatory_arg("frame"), quality) }) - elif self.kind == LambdaType.INTERACTOR: + elif self.kind == FunctionKind.INTERACTOR: payload.update({ "image": self._get_image(db_task, mandatory_arg("frame"), quality), "pos_points": mandatory_arg("pos_points"), "neg_points": mandatory_arg("neg_points"), "obj_bbox": data.get("obj_bbox", None) }) - elif self.kind == LambdaType.REID: + elif self.kind == FunctionKind.REID: payload.update({ "image0": self._get_image(db_task, mandatory_arg("frame0"), quality), "image1": self._get_image(db_task, mandatory_arg("frame1"), quality), @@ -417,7 +408,7 @@ def validate_attributes_mapping(attributes_mapping, model_attributes, db_attribu payload.update({ "max_distance": max_distance }) - elif self.kind == LambdaType.TRACKER: + elif self.kind == FunctionKind.TRACKER: payload.update({ "image": self._get_image(db_task, mandatory_arg("frame"), quality), "shapes": data.get("shapes", []), @@ -466,7 +457,7 @@ def transform_attributes(input_attributes, attr_mapping, db_attributes): }) return attributes - if self.kind == LambdaType.DETECTOR: + if self.kind == FunctionKind.DETECTOR: for item in response: item_label = item['label'] if item_label not in mapping: @@ -985,11 +976,11 @@ def convert_labels(db_labels): labels = convert_labels(db_task.get_labels(prefetch=True)) - if function.kind == LambdaType.DETECTOR: + if function.kind == FunctionKind.DETECTOR: cls._call_detector(function, db_task, labels, quality, kwargs.get("threshold"), kwargs.get("mapping"), kwargs.get("conv_mask_to_poly"), db_job=db_job) - elif function.kind == LambdaType.REID: + elif function.kind == FunctionKind.REID: cls._call_reid(function, db_task, quality, kwargs.get("threshold"), kwargs.get("max_distance"), db_job=db_job)