Skip to content

Commit 2629741

Browse files
Update deep sort to have class name.
1 parent a7ec8b3 commit 2629741

File tree

3 files changed

+14
-4
lines changed

3 files changed

+14
-4
lines changed

Diff for: deep_sort/detection.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -21,17 +21,22 @@ class Detection(object):
2121
Bounding box in format `(top left x, top left y, width, height)`.
2222
confidence : ndarray
2323
Detector confidence score.
24+
class_name : ndarray
25+
Detector class.
2426
feature : ndarray | NoneType
2527
A feature vector that describes the object contained in this image.
2628
2729
"""
2830

29-
def __init__(self, tlwh, confidence, cls, feature):
31+
def __init__(self, tlwh, confidence, class_name, feature):
3032
self.tlwh = np.asarray(tlwh, dtype=np.float)
3133
self.confidence = float(confidence)
32-
self.cls = cls
34+
self.class_name = class_name
3335
self.feature = np.asarray(feature, dtype=np.float32)
3436

37+
def get_class(self):
38+
return self.class_name
39+
3540
def to_tlbr(self):
3641
"""Convert bounding box to format `(min x, min y, max x, max y)`, i.e.,
3742
`(top left, bottom right)`.

Diff for: deep_sort/track.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ class Track:
6464
"""
6565

6666
def __init__(self, mean, covariance, track_id, n_init, max_age,
67-
feature=None):
67+
feature=None, class_name=None):
6868
self.mean = mean
6969
self.covariance = covariance
7070
self.track_id = track_id
@@ -79,6 +79,7 @@ def __init__(self, mean, covariance, track_id, n_init, max_age,
7979

8080
self._n_init = n_init
8181
self._max_age = max_age
82+
self.class_name = class_name
8283

8384
def to_tlwh(self):
8485
"""Get current position in bounding box format `(top left x, top left y,
@@ -108,6 +109,9 @@ def to_tlbr(self):
108109
ret = self.to_tlwh()
109110
ret[2:] = ret[:2] + ret[2:]
110111
return ret
112+
113+
def get_class(self):
114+
return self.class_name
111115

112116
def predict(self, kf):
113117
"""Propagate the state distribution to the current time step using a

Diff for: deep_sort/tracker.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,8 @@ def gated_metric(tracks, dets, track_indices, detection_indices):
132132

133133
def _initiate_track(self, detection):
134134
mean, covariance = self.kf.initiate(detection.to_xyah())
135+
class_name = detection.get_class()
135136
self.tracks.append(Track(
136137
mean, covariance, self._next_id, self.n_init, self.max_age,
137-
detection.feature))
138+
detection.feature, class_name))
138139
self._next_id += 1

0 commit comments

Comments
 (0)