Commit a538696

Merge pull request #45 from Pandoro/fork/ap_computation
Fixes the "wrong" computation of the AP score for newer sklearn versions.
2 parents: 8d4dca0 + 047b989 · commit: a538696

File tree

2 files changed: +56 -1 lines changed


README.md (+3 lines)

@@ -273,6 +273,9 @@ The evaluation code in this repository simply uses the scikit-learn code, and th
 Unfortunately, almost no paper mentions which code-base they used and how they computed `mAP` scores, so comparison is difficult.
 Other frameworks have [the same problem](https://github.com/Cysu/open-reid/issues/50), but we expect many not to be aware of this.
 
+We provide evaluation code that computes the mAP as done by the Market-1501 MATLAB evaluation script, independent of the scikit-learn version.
+This can be used by providing the `--use_market_ap` flag when running `evaluate.py`.
+
 # Independent re-implementations
 
 These are the independent re-implementations of our paper that we are aware of,

evaluate.py (+53 -1 lines)

@@ -51,6 +51,52 @@
     '--batch_size', default=256, type=common.positive_int,
     help='Batch size used during evaluation, adapt based on your memory usage.')
 
+parser.add_argument(
+    '--use_market_ap', action='store_true', default=False,
+    help='When this flag is provided, the average precision is computed exactly'
+         ' as done by the Market-1501 evaluation script, rather than the '
+         'default scikit-learn implementation that gives slightly different '
+         'scores.')
+
+
+def average_precision_score_market(y_true, y_score):
+    """Compute average precision (AP) from prediction scores.
+
+    This is a replacement for the scikit-learn version which, while likely more
+    correct, does not follow the same protocol as used in the default Market-1501
+    evaluation that first introduced this score to the person ReID field.
+
+    Args:
+        y_true (array): The binary labels for all data points.
+        y_score (array): The predicted scores for each sample for all data
+            points.
+
+    Raises:
+        ValueError: If the lengths of the labels and scores do not match.
+
+    Returns:
+        A float representing the average precision given the predictions.
+    """
+
+    if len(y_true) != len(y_score):
+        raise ValueError('The lengths of the labels and predictions must match, '
+                         'got lengths y_true: {} and y_score: {}'.format(
+                             len(y_true), len(y_score)))
+
+    # Mergesort is used since it is a stable sorting algorithm. This is
+    # important to compute consistent and correct scores.
+    y_true_sorted = y_true[np.argsort(-y_score, kind='mergesort')]
+
+    tp = np.cumsum(y_true_sorted)
+    total_true = np.sum(y_true_sorted)
+    recall = tp / total_true
+    recall = np.insert(recall, 0, 0.)
+    precision = tp / np.arange(1, len(tp) + 1)
+    precision = np.insert(precision, 0, 1.)
+    ap = np.sum(np.diff(recall) * ((precision[1:] + precision[:-1]) / 2))
+
+    return ap
+
 
 def main():
     # Verify that parameters are set correctly.
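For context (not part of the commit): a minimal, self-contained sketch contrasting this trapezoidal AP with scikit-learn's `average_precision_score` on a made-up toy ranking. The helper name `market_style_ap` and the example labels and scores are purely illustrative.

```python
import numpy as np
from sklearn.metrics import average_precision_score


def market_style_ap(y_true, y_score):
    # Sort labels by descending score; mergesort is stable, so ties keep their order.
    order = np.argsort(-np.asarray(y_score, dtype=float), kind='mergesort')
    y_true_sorted = np.asarray(y_true)[order]
    tp = np.cumsum(y_true_sorted)
    recall = np.insert(tp / tp[-1], 0, 0.)
    precision = np.insert(tp / np.arange(1, len(tp) + 1), 0, 1.)
    # Trapezoidal integration of the precision-recall curve, as in the MATLAB script.
    return np.sum(np.diff(recall) * (precision[1:] + precision[:-1]) / 2)


# Toy example: three matches among six ranked gallery candidates.
y_true = np.array([1, 0, 1, 1, 0, 0])
y_score = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.4])
print('Market-style AP:', market_style_ap(y_true, y_score))
print('scikit-learn AP:', average_precision_score(y_true, y_score))
```

Newer scikit-learn versions compute AP as a step-wise sum over the precision-recall curve rather than with trapezoidal interpolation, so the two printed values generally differ slightly; that discrepancy is exactly what this commit works around.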
@@ -83,6 +129,12 @@ def main():
 
     batch_distances = loss.cdist(batch_embs, gallery_embs, metric=args.metric)
 
+    # Check if we should use Market-1501 specific average precision computation.
+    if args.use_market_ap:
+        average_precision = average_precision_score_market
+    else:
+        average_precision = average_precision_score
+
     # Loop over the query embeddings and compute their APs and the CMC curve.
     aps = []
     cmc = np.zeros(len(gallery_pids), dtype=np.int32)
@@ -117,7 +169,7 @@
             # it won't change anything.
             scores = 1 / (1 + distances)
             for i in range(len(distances)):
-                ap = average_precision_score(pid_matches[i], scores[i])
+                ap = average_precision(pid_matches[i], scores[i])
 
                 if np.isnan(ap):
                     print()
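As an aside (not part of the commit): the `scores = 1 / (1 + distances)` line and the "it won't change anything" comment rely on AP depending only on the ranking of the scores. A tiny check of that, with made-up distances and match labels:

```python
import numpy as np
from sklearn.metrics import average_precision_score

# Hypothetical distances of five gallery images to one query, and whether each matches its ID.
distances = np.array([0.2, 1.5, 0.7, 3.0, 0.1])
pid_matches = np.array([1, 0, 1, 0, 1])

# 1 / (1 + d) is strictly decreasing in d, so it induces the same ranking as -d.
ap_transformed = average_precision_score(pid_matches, 1 / (1 + distances))
ap_negated = average_precision_score(pid_matches, -distances)
assert np.isclose(ap_transformed, ap_negated)
```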
