From b94d430c3fa239bdf7621bbef8b090d35582b04e Mon Sep 17 00:00:00 2001 From: Maxim Zhiltsov Date: Fri, 31 Jan 2025 14:45:34 +0200 Subject: [PATCH] Refactor segment matching --- cvat/apps/quality_control/quality_reports.py | 62 +++++++++++++------- 1 file changed, 40 insertions(+), 22 deletions(-) diff --git a/cvat/apps/quality_control/quality_reports.py b/cvat/apps/quality_control/quality_reports.py index a34eb10c14f6..c4e6d7ea2649 100644 --- a/cvat/apps/quality_control/quality_reports.py +++ b/cvat/apps/quality_control/quality_reports.py @@ -10,7 +10,7 @@ from collections.abc import Hashable, Sequence from copy import deepcopy from functools import cached_property, partial -from typing import Any, Callable, Optional, Union, cast +from typing import Any, Callable, Optional, TypeVar, Union, cast import datumaro import datumaro as dm @@ -661,13 +661,31 @@ def _convert_shape(self, shape, *, index): return converted +_ShapeT1 = TypeVar("_ShapeT1") +_ShapeT2 = TypeVar("_ShapeT2") +ShapeSimilarityFunction = Callable[ + [_ShapeT1, _ShapeT2], float +] # (shape1, shape2) -> [0; 1], returns 0 for mismatches, 1 for matches +LabelEqualityFunction = Callable[[_ShapeT1, _ShapeT2], bool] +SegmentMatchingResult = tuple[ + list[tuple[_ShapeT1, _ShapeT2]], # matches + list[tuple[_ShapeT1, _ShapeT2]], # mismatches + list[_ShapeT1], # a unmatched + list[_ShapeT2], # b unmatched +] + + def match_segments( - a_segms, - b_segms, - distance=datumaro.util.annotation_util.segment_iou, - dist_thresh=1.0, - label_matcher=lambda a, b: a.label == b.label, -): + a_segms: Sequence[_ShapeT1], + b_segms: Sequence[_ShapeT2], + *, + distance: ShapeSimilarityFunction[_ShapeT1, _ShapeT2], + dist_thresh: float = 1.0, + label_matcher: LabelEqualityFunction[_ShapeT1, _ShapeT2] = lambda a, b: a.label == b.label, +) -> SegmentMatchingResult[_ShapeT1, _ShapeT2]: + # Comparing to the dm version, this one changes the algorithm to match shapes first + # label comparison is only used to distinguish between matches and mismatches + assert callable(distance), distance assert callable(label_matcher), label_matcher @@ -690,11 +708,11 @@ def match_segments( a_matches = [] b_matches = [] - # matches: boxes we succeeded to match completely - # mispred: boxes we succeeded to match, having label mismatch + # matches: segments we succeeded to match completely + # mispred: segments we succeeded to match, having label mismatch matches = [] mispred = [] - # *_umatched: boxes of (*) we failed to match + # *_umatched: segments of (*) we failed to match a_unmatched = [] b_unmatched = [] @@ -1021,12 +1039,12 @@ def to_polygon(bbox_ann: dm.Bbox): return dm.Polygon(points) @staticmethod - def _get_ann_type(t, item: dm.DatasetItem) -> Sequence[dm.Annotation]: + def _get_ann_type(t: dm.AnnotationType, item: dm.DatasetItem) -> Sequence[dm.Annotation]: return [ a for a in item.annotations if a.type == t and not a.attributes.get("outside", False) ] - def _match_ann_type(self, t, *args): + def _match_ann_type(self, t: dm.AnnotationType, *args): if t not in self.included_ann_types: return None @@ -1047,8 +1065,8 @@ def _match_ann_type(self, t, *args): else: return None - def match_labels(self, item_a, item_b): - def label_distance(a, b): + def match_labels(self, item_a: dm.DatasetItem, item_b: dm.DatasetItem): + def label_distance(a: dm.Label, b: dm.Label) -> float: if a is None or b is None: return 0 return 0.5 + (a.label == b.label) / 2 @@ -1064,15 +1082,15 @@ def label_distance(a, b): def match_segments( self, - t, - item_a, - item_b, + t: dm.AnnotationType, + item_a: dm.DatasetItem, + item_b: dm.DatasetItem, *, - distance: Callable = datumaro.util.annotation_util.segment_iou, - label_matcher: Callable = None, - a_objs: Optional[Sequence[dm.Annotation]] = None, - b_objs: Optional[Sequence[dm.Annotation]] = None, - dist_thresh: Optional[float] = None, + distance: ShapeSimilarityFunction[_ShapeT1, _ShapeT2], + label_matcher: LabelEqualityFunction[_ShapeT1, _ShapeT2] | None = None, + a_objs: Sequence[_ShapeT1] | None = None, + b_objs: Sequence[_ShapeT2] | None = None, + dist_thresh: float | None = None, ): if a_objs is None: a_objs = self._get_ann_type(t, item_a)