7
7
import h5py
8
8
import json
9
9
import numpy as np
10
+ from sklearn .metrics import average_precision_score
10
11
import tensorflow as tf
11
12
12
13
import common
50
51
'--batch_size' , default = 256 , type = common .positive_int ,
51
52
help = 'Batch size used during evaluation, adapt based on your memory usage.' )
52
53
54
+ parser .add_argument (
55
+ '--use_market_ap' , action = 'store_true' , default = False ,
56
+ help = 'When this flag is provided, the average precision is computed exactly'
57
+ ' as done by the Market-1501 evaluation script, rather than the '
58
+ 'default scikit-learn implementation that gives slightly different'
59
+ 'scores.' )
60
+
53
61
54
- def average_precision_score (y_true , y_score ):
62
+ def average_precision_score_market (y_true , y_score ):
55
63
""" Compute average precision (AP) from prediction scores.
56
64
57
65
This is a replacement for the scikit-learn version which, while likely more
@@ -75,6 +83,8 @@ def average_precision_score(y_true, y_score):
75
83
'got lengths y_true:{} and y_score:{}' .format (
76
84
len (y_true ), len (y_score )))
77
85
86
+ # Mergesort is used since it is a stable sorting algorithm. This is
87
+ # important to compute consistent and correct scores.
78
88
y_true_sorted = y_true [np .argsort (- y_score , kind = 'mergesort' )]
79
89
80
90
tp = np .cumsum (y_true_sorted )
@@ -119,6 +129,12 @@ def main():
119
129
120
130
batch_distances = loss .cdist (batch_embs , gallery_embs , metric = args .metric )
121
131
132
+ # Check if we should use Market-1501 specific average precision computation.
133
+ if args .use_market_ap :
134
+ average_precision = average_precision_score_market
135
+ else :
136
+ average_precision = average_precision_score
137
+
122
138
# Loop over the query embeddings and compute their APs and the CMC curve.
123
139
aps = []
124
140
cmc = np .zeros (len (gallery_pids ), dtype = np .int32 )
@@ -153,7 +169,7 @@ def main():
153
169
# it won't change anything.
154
170
scores = 1 / (1 + distances )
155
171
for i in range (len (distances )):
156
- ap = average_precision_score (pid_matches [i ], scores [i ])
172
+ ap = average_precision (pid_matches [i ], scores [i ])
157
173
158
174
if np .isnan (ap ):
159
175
print ()
0 commit comments