Skip to content

Commit

Permalink
cvat-ui: simplify code related to model label types
Browse files Browse the repository at this point in the history
Recent changes to the Enterprise version make the following simplifications
possible:

* A label type may no longer be `unknown`, which means that the type of
  `MLModelLabel.type` can be changed to `LabelType`, and all special
  handling for `unknown` can be removed.

* Roboflow and Hugging Face models now correctly report their label types,
  which makes the dedicated `return_type` field unnecessary. The
  `ModelReturnType` enum is unnecessary as well, as it's just an arbitrary
  subset of `LabelType`.
  • Loading branch information
SpecLad committed Feb 7, 2025
1 parent fec040d commit 38fbd4e
Show file tree
Hide file tree
Showing 7 changed files with 25 additions and 26 deletions.
5 changes: 2 additions & 3 deletions cvat-core/src/core-types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
// SPDX-License-Identifier: MIT

import {
ModelKind, ModelReturnType, RQStatus, ShapeType,
LabelType, ModelKind, RQStatus,
} from './enums';

export interface ModelAttribute {
Expand All @@ -28,7 +28,7 @@ export interface MLModelTip {

export interface MLModelLabel {
name: string;
type: ShapeType | 'unknown';
type: LabelType;
attributes: ModelAttribute[];
sublabels?: MLModelLabel[];
svg?: string,
Expand All @@ -42,7 +42,6 @@ export interface SerializedModel {
description?: string;
kind?: ModelKind;
type?: string;
return_type?: ModelReturnType;
owner?: any;
provider?: string;
url?: string;
Expand Down
7 changes: 0 additions & 7 deletions cvat-core/src/enums.ts
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,6 @@ export enum ModelProviders {
CVAT = 'cvat',
}

export enum ModelReturnType {
RECTANGLE = 'rectangle',
TAG = 'tag',
POLYGON = 'polygon',
MASK = 'mask',
}

export const colors = [
'#33ddff',
'#fa3253',
Expand Down
18 changes: 13 additions & 5 deletions cvat-core/src/ml-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import PluginRegistry from './plugins';
import {
ModelProviders, ModelKind, ModelReturnType,
LabelType, ModelProviders, ModelKind,
} from './enums';
import {
SerializedModel, ModelParams, MLModelTip, MLModelLabel,
Expand Down Expand Up @@ -44,8 +44,11 @@ export default class MLModel {

public get displayKind(): string {
if (this.kind === ModelKind.DETECTOR) {
if (this.returnType === ModelReturnType.TAG) return 'classifier';
if (this.returnType === ModelReturnType.MASK) return 'segmenter';
switch (this.returnType) {
case LabelType.TAG: return 'classifier';
case LabelType.MASK: return 'segmenter';
default: // fall back on the original kind
}
}
return this.kind;
}
Expand Down Expand Up @@ -94,8 +97,13 @@ export default class MLModel {
return this.serialized?.url;
}

public get returnType(): ModelReturnType | undefined {
return this.serialized?.return_type;
public get returnType(): LabelType {
const uniqueLabelTypes = new Set(this.labels.map((label) => label.type));

if (uniqueLabelTypes.size !== 1) return LabelType.ANY;

const [labelType] = uniqueLabelTypes;
return labelType;
}

public async preview(): Promise<string> {
Expand Down
10 changes: 5 additions & 5 deletions cvat-ui/src/components/labels-editor/pick-from-model.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import Select from 'antd/lib/select';
import Text from 'antd/lib/typography';
import { PlusCircleOutlined } from '@ant-design/icons';

import { CombinedState, ShapeType } from 'reducers';
import MLModel from 'cvat-core/src/ml-model';
import { CombinedState } from 'reducers';
import { LabelType, MLModel } from 'cvat-core-wrapper';
import { LabelOptColor } from './common';

interface Props {
Expand Down Expand Up @@ -83,18 +83,18 @@ function PickFromModelComponent(props: Props): JSX.Element {
if (!labelNames.includes(label.name)) {
const generatedLabel: LabelOptColor = {
name: label.name,
type: label.type === 'unknown' ? 'any' : label.type as ShapeType,
type: label.type,
attributes: label.attributes.map((attr) => ({
...attr,
mutable: false,
default_value: attr.values[0],
})),
};

if (generatedLabel.type === ShapeType.SKELETON && label.sublabels && label.svg) {
if (generatedLabel.type === LabelType.SKELETON && label.sublabels && label.svg) {
generatedLabel.sublabels = label.sublabels.map((sublabel) => ({
name: sublabel.name,
type: sublabel.type === 'unknown' ? 'any' : sublabel.type as ShapeType,
type: sublabel.type,
attributes: sublabel.attributes.map((attr) => ({
...attr,
mutable: false,
Expand Down
4 changes: 2 additions & 2 deletions cvat-ui/src/components/model-runner-modal/detector-runner.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ import { ArrowRightOutlined, QuestionCircleOutlined } from '@ant-design/icons';
import CVATTooltip from 'components/common/cvat-tooltip';
import { clamp } from 'utils/math';
import {
MLModel, ModelKind, ModelReturnType, DimensionType, Label,
MLModel, ModelKind, DimensionType, Label, LabelType,
} from 'cvat-core-wrapper';

import LabelsMapperComponent, { LabelInterface, FullMapping } from './labels-mapper';
Expand Down Expand Up @@ -80,7 +80,7 @@ function DetectorRunner(props: Props): JSX.Element {
const isDetector = model?.kind === ModelKind.DETECTOR;
const isReId = model?.kind === ModelKind.REID;
const convertMasks2PolygonVisible = isDetector &&
(!model.returnType || model.returnType === ModelReturnType.MASK);
[LabelType.ANY, LabelType.MASK].includes(model.returnType);

const buttonEnabled = model && (isReId || (isDetector && mapping.length));

Expand Down
4 changes: 2 additions & 2 deletions cvat-ui/src/components/model-runner-modal/labels-mapper.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ function labelsCompatible(modelLabel: LabelInterface, jobLabel: LabelInterface):
const { type: jobLabelType } = jobLabel;
const compatibleTypes = [[LabelType.MASK, LabelType.POLYGON]];
return modelLabelType === jobLabelType ||
(jobLabelType === 'any' && modelLabelType !== LabelType.SKELETON) ||
((modelLabelType === 'any' || modelLabelType === 'unknown') && jobLabelType !== LabelType.SKELETON) || // legacy support
(jobLabelType === LabelType.ANY && modelLabelType !== LabelType.SKELETON) ||
(modelLabelType === LabelType.ANY && jobLabelType !== LabelType.SKELETON) ||
compatibleTypes.some((compatible) => compatible.includes(jobLabelType) && compatible.includes(modelLabelType));
}

Expand Down
3 changes: 1 addition & 2 deletions cvat-ui/src/cvat-core-wrapper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import { FramesMetaData, FrameData } from 'cvat-core/src/frames';
import { ServerError, RequestError } from 'cvat-core/src/exceptions';
import {
ShapeType, ObjectType, LabelType, ModelKind, ModelProviders,
ModelReturnType, DimensionType, JobType, Source,
DimensionType, JobType, Source,
JobStage, JobState, RQStatus, StorageLocation,
} from 'cvat-core/src/enums';
import { Storage, StorageData } from 'cvat-core/src/storage';
Expand Down Expand Up @@ -85,7 +85,6 @@ export {
MLModel,
ModelKind,
ModelProviders,
ModelReturnType,
DimensionType,
Dumper,
JobType,
Expand Down

0 comments on commit 38fbd4e

Please sign in to comment.