@@ -44,9 +44,9 @@ def cdist(a, b, metric='euclidean'):
44
44
"""
45
45
with tf .name_scope ("cdist" ):
46
46
diffs = all_diffs (a , b )
47
- if metric == 'euclidean ' :
47
+ if metric == 'sqeuclidean ' :
48
48
return tf .reduce_sum (tf .square (diffs ), axis = - 1 )
49
- elif metric == 'sqeuclidean ' :
49
+ elif metric == 'euclidean ' :
50
50
return tf .sqrt (tf .reduce_sum (tf .square (diffs ), axis = - 1 ) + 1e-12 )
51
51
elif metric == 'cityblock' :
52
52
return tf .reduce_sum (tf .abs (diffs ), axis = - 1 )
@@ -82,10 +82,10 @@ def batch_hard(dists, pids, margin, batch_precision_at_k=None):
82
82
"""
83
83
with tf .name_scope ("batch_hard" ):
84
84
same_identity_mask = tf .equal (tf .expand_dims (pids , axis = 1 ),
85
- tf .expand_dims (pids , axis = 0 ))
85
+ tf .expand_dims (pids , axis = 0 ))
86
86
negative_mask = tf .logical_not (same_identity_mask )
87
87
positive_mask = tf .logical_xor (same_identity_mask ,
88
- tf .eye (tf .shape (pids )[0 ], dtype = tf .bool ))
88
+ tf .eye (tf .shape (pids )[0 ], dtype = tf .bool ))
89
89
90
90
furthest_positive = tf .reduce_max (dists * tf .cast (positive_mask , tf .float32 ), axis = 1 )
91
91
closest_negative = tf .map_fn (lambda x : tf .reduce_min (tf .boolean_mask (x [0 ], x [1 ])),
0 commit comments