51 | 51 | '--batch_size', default=256, type=common.positive_int,
52 | 52 | help='Batch size used during evaluation, adapt based on your memory usage.')
53 | 53 |
| 54 | +parser.add_argument(
| 55 | +    '--use_market_ap', action='store_true', default=False,
| 56 | +    help='When this flag is provided, the average precision is computed exactly'
| 57 | +         ' as done by the Market-1501 evaluation script, rather than the '
| 58 | +         'default scikit-learn implementation that gives slightly different '
| 59 | +         'scores.')
| 60 | +
| 61 | +
| 62 | +def average_precision_score_market(y_true, y_score):
| 63 | +    """ Compute average precision (AP) from prediction scores.
| 64 | +
| 65 | +    This is a replacement for the scikit-learn version which, while likely more
| 66 | +    correct, does not follow the same protocol as used in the default Market-1501
| 67 | +    evaluation that first introduced this score to the person ReID field.
| 68 | +
| 69 | +    Args:
| 70 | +        y_true (array): The binary labels for all data points.
| 71 | +        y_score (array): The predicted scores for each of the data
| 72 | +            points.
| 73 | +
| 74 | +    Raises:
| 75 | +        ValueError: If the lengths of the labels and scores do not match.
| 76 | +
| 77 | +    Returns:
| 78 | +        A float representing the average precision given the predictions.
| 79 | +    """
| 80 | +
| 81 | +    if len(y_true) != len(y_score):
| 82 | +        raise ValueError('The lengths of the labels and predictions must match, '
| 83 | +                         'got lengths y_true:{} and y_score:{}'.format(
| 84 | +                             len(y_true), len(y_score)))
| 85 | +
| 86 | +    # Mergesort is used since it is a stable sorting algorithm. This is
| 87 | +    # important for computing consistent and correct scores.
| 88 | +    y_true_sorted = y_true[np.argsort(-y_score, kind='mergesort')]
| 89 | +
| 90 | +    tp = np.cumsum(y_true_sorted)
| 91 | +    total_true = np.sum(y_true_sorted)
| 92 | +    recall = tp / total_true
| 93 | +    recall = np.insert(recall, 0, 0.)
| 94 | +    precision = tp / np.arange(1, len(tp) + 1)
| 95 | +    precision = np.insert(precision, 0, 1.)
| 96 | +    ap = np.sum(np.diff(recall) * ((precision[1:] + precision[:-1]) / 2))
| 97 | +
| 98 | +    return ap
| 99 | +
54 | 100 |
55 | 101 | def main():
56 | 102 | # Verify that parameters are set correctly.
@@ -83,6 +129,12 @@ def main():
83 | 129 |
84 | 130 | batch_distances = loss.cdist(batch_embs, gallery_embs, metric=args.metric)
85 | 131 |
| 132 | +    # Check whether we should use the Market-1501-specific average precision computation.
| 133 | +    if args.use_market_ap:
| 134 | +        average_precision = average_precision_score_market
| 135 | +    else:
| 136 | +        average_precision = average_precision_score
| 137 | +
86 | 138 | # Loop over the query embeddings and compute their APs and the CMC curve.
87 | 139 | aps = []
88 | 140 | cmc = np.zeros(len(gallery_pids), dtype=np.int32)
@@ -117,7 +169,7 @@ def main():
117 | 169 | # it won't change anything.
118 | 170 | scores = 1 / (1 + distances)
119 | 171 | for i in range(len(distances)):
120 | | -            ap = average_precision_score(pid_matches[i], scores[i])
| 172 | +            ap = average_precision(pid_matches[i], scores[i])
121 | 173 |
122 | 174 | if np.isnan(ap):
123 | 175 | print()
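For reference, a minimal, self-contained sketch (not part of the patch; the toy labels and scores below are made up, and `market_ap` simply mirrors `average_precision_score_market` from the diff) of why the two implementations give slightly different numbers: the Market-1501-style computation integrates precision over recall with the trapezoidal rule, while scikit-learn's `average_precision_score` uses a step function.

```python
import numpy as np
from sklearn.metrics import average_precision_score

def market_ap(y_true, y_score):
    # Trapezoidal-rule AP, mirroring average_precision_score_market above.
    y_true_sorted = y_true[np.argsort(-y_score, kind='mergesort')]
    tp = np.cumsum(y_true_sorted)
    recall = np.insert(tp / np.sum(y_true_sorted), 0, 0.)
    precision = np.insert(tp / np.arange(1, len(tp) + 1), 0, 1.)
    return np.sum(np.diff(recall) * ((precision[1:] + precision[:-1]) / 2))

# Toy example: two relevant items ranked 1st and 3rd.
y_true = np.array([1, 0, 1])
y_score = np.array([0.9, 0.8, 0.7])
print(market_ap(y_true, y_score))                # ~0.7917 (trapezoidal rule)
print(average_precision_score(y_true, y_score))  # ~0.8333 (step-function rule)
```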