Skip to content

Commit 0e30b89

Browse files
authored
Merge pull request #12 from VisualComputingInstitute/fix-sqeuclid
Fix important bug
2 parents 23d314a + 31a3b08 commit 0e30b89

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

loss.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,9 @@ def cdist(a, b, metric='euclidean'):
4444
"""
4545
with tf.name_scope("cdist"):
4646
diffs = all_diffs(a, b)
47-
if metric == 'euclidean':
47+
if metric == 'sqeuclidean':
4848
return tf.reduce_sum(tf.square(diffs), axis=-1)
49-
elif metric == 'sqeuclidean':
49+
elif metric == 'euclidean':
5050
return tf.sqrt(tf.reduce_sum(tf.square(diffs), axis=-1) + 1e-12)
5151
elif metric == 'cityblock':
5252
return tf.reduce_sum(tf.abs(diffs), axis=-1)
@@ -82,10 +82,10 @@ def batch_hard(dists, pids, margin, batch_precision_at_k=None):
8282
"""
8383
with tf.name_scope("batch_hard"):
8484
same_identity_mask = tf.equal(tf.expand_dims(pids, axis=1),
85-
tf.expand_dims(pids, axis=0))
85+
tf.expand_dims(pids, axis=0))
8686
negative_mask = tf.logical_not(same_identity_mask)
8787
positive_mask = tf.logical_xor(same_identity_mask,
88-
tf.eye(tf.shape(pids)[0], dtype=tf.bool))
88+
tf.eye(tf.shape(pids)[0], dtype=tf.bool))
8989

9090
furthest_positive = tf.reduce_max(dists*tf.cast(positive_mask, tf.float32), axis=1)
9191
closest_negative = tf.map_fn(lambda x: tf.reduce_min(tf.boolean_mask(x[0], x[1])),

0 commit comments

Comments
 (0)