Skip to content

Commit

Permalink
Rename and move LambdaType (#8418)
Browse files Browse the repository at this point in the history
This enum is used for a field named `kind` in a resource located at
`/api/lambda/function`, so it seems pretty clear that it should be named
`FunctionKind`. (Or perhaps `LambdaFunctionKind`, but I omitted the
"lambda" for consistency with views and serializers.)

In addition to renaming, move it to `models`, so that it can be used by
serializers. No such serializers currently exist, but I'd like to add
them later. Turn it into a Django choice enum as well, so that `__str__`
works out of the box.
  • Loading branch information
SpecLad authored Sep 9, 2024
1 parent cc5a016 commit 88ce0a4
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 19 deletions.
12 changes: 12 additions & 0 deletions cvat/apps/lambda_manager/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Copyright (C) 2022 Intel Corporation
# Copyright (C) 2022-2024 CVAT.ai Corporation
#
# SPDX-License-Identifier: MIT

import django.db.models as models

class FunctionKind(models.TextChoices):
DETECTOR = "detector"
INTERACTOR = "interactor"
REID = "reid"
TRACKER = "tracker"
29 changes: 10 additions & 19 deletions cvat/apps/lambda_manager/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
import textwrap
from copy import deepcopy
from datetime import timedelta
from enum import Enum
from functools import wraps
from typing import Any, Dict, Optional

Expand Down Expand Up @@ -39,6 +38,7 @@
)
from cvat.apps.engine.rq_job_handler import RQId, RQJobMetaField
from cvat.apps.engine.serializers import LabeledDataSerializer
from cvat.apps.lambda_manager.models import FunctionKind
from cvat.apps.lambda_manager.permissions import LambdaPermission
from cvat.apps.lambda_manager.serializers import (
FunctionCallRequestSerializer, FunctionCallSerializer
Expand All @@ -50,15 +50,6 @@

slogger = ServerLogManager(__name__)

class LambdaType(Enum):
DETECTOR = "detector"
INTERACTOR = "interactor"
REID = "reid"
TRACKER = "tracker"

def __str__(self):
return self.value

class LambdaGateway:
NUCLIO_ROOT_URL = '/api/functions'

Expand Down Expand Up @@ -152,7 +143,7 @@ def __init__(self, gateway, data):
meta_anno = data['metadata']['annotations']
kind = meta_anno.get('type')
try:
self.kind = LambdaType(kind)
self.kind = FunctionKind(kind)
except ValueError as e:
raise InvalidFunctionMetadataError(
f"{self.id} lambda function has unknown type: {kind!r}") from e
Expand Down Expand Up @@ -225,7 +216,7 @@ def to_dict(self):
'version': self.version
}

if self.kind is LambdaType.INTERACTOR:
if self.kind is FunctionKind.INTERACTOR:
response.update({
'min_pos_points': self.min_pos_points,
'min_neg_points': self.min_neg_points,
Expand Down Expand Up @@ -394,18 +385,18 @@ def validate_attributes_mapping(attributes_mapping, model_attributes, db_attribu
code=status.HTTP_400_BAD_REQUEST)


if self.kind == LambdaType.DETECTOR:
if self.kind == FunctionKind.DETECTOR:
payload.update({
"image": self._get_image(db_task, mandatory_arg("frame"), quality)
})
elif self.kind == LambdaType.INTERACTOR:
elif self.kind == FunctionKind.INTERACTOR:
payload.update({
"image": self._get_image(db_task, mandatory_arg("frame"), quality),
"pos_points": mandatory_arg("pos_points"),
"neg_points": mandatory_arg("neg_points"),
"obj_bbox": data.get("obj_bbox", None)
})
elif self.kind == LambdaType.REID:
elif self.kind == FunctionKind.REID:
payload.update({
"image0": self._get_image(db_task, mandatory_arg("frame0"), quality),
"image1": self._get_image(db_task, mandatory_arg("frame1"), quality),
Expand All @@ -417,7 +408,7 @@ def validate_attributes_mapping(attributes_mapping, model_attributes, db_attribu
payload.update({
"max_distance": max_distance
})
elif self.kind == LambdaType.TRACKER:
elif self.kind == FunctionKind.TRACKER:
payload.update({
"image": self._get_image(db_task, mandatory_arg("frame"), quality),
"shapes": data.get("shapes", []),
Expand Down Expand Up @@ -466,7 +457,7 @@ def transform_attributes(input_attributes, attr_mapping, db_attributes):
})
return attributes

if self.kind == LambdaType.DETECTOR:
if self.kind == FunctionKind.DETECTOR:
for item in response:
item_label = item['label']
if item_label not in mapping:
Expand Down Expand Up @@ -985,11 +976,11 @@ def convert_labels(db_labels):

labels = convert_labels(db_task.get_labels(prefetch=True))

if function.kind == LambdaType.DETECTOR:
if function.kind == FunctionKind.DETECTOR:
cls._call_detector(function, db_task, labels, quality,
kwargs.get("threshold"), kwargs.get("mapping"), kwargs.get("conv_mask_to_poly"),
db_job=db_job)
elif function.kind == LambdaType.REID:
elif function.kind == FunctionKind.REID:
cls._call_reid(function, db_task, quality,
kwargs.get("threshold"), kwargs.get("max_distance"), db_job=db_job)

Expand Down

0 comments on commit 88ce0a4

Please sign in to comment.