Skip to content

Commit

Permalink
Refactor segment matching
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiltsov-max committed Jan 31, 2025
1 parent 08d8421 commit b94d430
Showing 1 changed file with 40 additions and 22 deletions.
62 changes: 40 additions & 22 deletions cvat/apps/quality_control/quality_reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from collections.abc import Hashable, Sequence
from copy import deepcopy
from functools import cached_property, partial
from typing import Any, Callable, Optional, Union, cast
from typing import Any, Callable, Optional, TypeVar, Union, cast

import datumaro
import datumaro as dm
Expand Down Expand Up @@ -661,13 +661,31 @@ def _convert_shape(self, shape, *, index):
return converted


# Type variables for the two collections of shapes being matched; the two
# sides may have different types (e.g. ground-truth vs. predicted annotations).
_ShapeT1 = TypeVar("_ShapeT1")
_ShapeT2 = TypeVar("_ShapeT2")
ShapeSimilarityFunction = Callable[
    [_ShapeT1, _ShapeT2], float
]  # (shape1, shape2) -> [0; 1], returns 0 for mismatches, 1 for matches
# (shape1, shape2) -> True when the shapes' labels are considered equal
LabelEqualityFunction = Callable[[_ShapeT1, _ShapeT2], bool]
# The 4 output groups of segment matching:
SegmentMatchingResult = tuple[
    list[tuple[_ShapeT1, _ShapeT2]],  # matches
    list[tuple[_ShapeT1, _ShapeT2]],  # mismatches
    list[_ShapeT1],  # unmatched items from a
    list[_ShapeT2],  # unmatched items from b
]


def match_segments(
a_segms,
b_segms,
distance=datumaro.util.annotation_util.segment_iou,
dist_thresh=1.0,
label_matcher=lambda a, b: a.label == b.label,
):
a_segms: Sequence[_ShapeT1],
b_segms: Sequence[_ShapeT2],
*,
distance: ShapeSimilarityFunction[_ShapeT1, _ShapeT2],
dist_thresh: float = 1.0,
label_matcher: LabelEqualityFunction[_ShapeT1, _ShapeT2] = lambda a, b: a.label == b.label,
) -> SegmentMatchingResult[_ShapeT1, _ShapeT2]:
# Comparing to the dm version, this one changes the algorithm to match shapes first
# label comparison is only used to distinguish between matches and mismatches

assert callable(distance), distance
assert callable(label_matcher), label_matcher

Expand All @@ -690,11 +708,11 @@ def match_segments(
a_matches = []
b_matches = []

# matches: boxes we succeeded to match completely
# mispred: boxes we succeeded to match, having label mismatch
# matches: segments we succeeded to match completely
# mispred: segments we succeeded to match, having label mismatch
matches = []
mispred = []
# *_umatched: boxes of (*) we failed to match
# *_umatched: segments of (*) we failed to match
a_unmatched = []
b_unmatched = []

Expand Down Expand Up @@ -1021,12 +1039,12 @@ def to_polygon(bbox_ann: dm.Bbox):
return dm.Polygon(points)

@staticmethod
def _get_ann_type(t, item: dm.DatasetItem) -> Sequence[dm.Annotation]:
def _get_ann_type(t: dm.AnnotationType, item: dm.DatasetItem) -> Sequence[dm.Annotation]:
return [
a for a in item.annotations if a.type == t and not a.attributes.get("outside", False)
]

def _match_ann_type(self, t, *args):
def _match_ann_type(self, t: dm.AnnotationType, *args):
if t not in self.included_ann_types:
return None

Expand All @@ -1047,8 +1065,8 @@ def _match_ann_type(self, t, *args):
else:
return None

def match_labels(self, item_a, item_b):
def label_distance(a, b):
def match_labels(self, item_a: dm.DatasetItem, item_b: dm.DatasetItem):
def label_distance(a: dm.Label, b: dm.Label) -> float:
if a is None or b is None:
return 0
return 0.5 + (a.label == b.label) / 2
Expand All @@ -1064,15 +1082,15 @@ def label_distance(a, b):

def match_segments(
self,
t,
item_a,
item_b,
t: dm.AnnotationType,
item_a: dm.DatasetItem,
item_b: dm.DatasetItem,
*,
distance: Callable = datumaro.util.annotation_util.segment_iou,
label_matcher: Callable = None,
a_objs: Optional[Sequence[dm.Annotation]] = None,
b_objs: Optional[Sequence[dm.Annotation]] = None,
dist_thresh: Optional[float] = None,
distance: ShapeSimilarityFunction[_ShapeT1, _ShapeT2],
label_matcher: LabelEqualityFunction[_ShapeT1, _ShapeT2] | None = None,
a_objs: Sequence[_ShapeT1] | None = None,
b_objs: Sequence[_ShapeT2] | None = None,
dist_thresh: float | None = None,
):
if a_objs is None:
a_objs = self._get_ann_type(t, item_a)
Expand Down

0 comments on commit b94d430

Please sign in to comment.