Skip to content

Commit ad2e8b5

Browse files
author
KevinMusgrave
committed
Fixes bug where BaseDistance.initial_avg_query_norm was not actually being set
1 parent d94576c commit ad2e8b5

File tree

3 files changed

+24
-2
lines changed

3 files changed

+24
-2
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "2.1.0"
1+
__version__ = "2.1.1"

src/pytorch_metric_learning/distances/base_distance.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,9 @@ def set_default_stats(
7979
):
8080
if self.collect_stats:
8181
with torch.no_grad():
82-
self.initial_avg_query_norm: torch.mean(self.get_norm(query_emb)).item()
82+
self.initial_avg_query_norm = torch.mean(
83+
self.get_norm(query_emb)
84+
).item()
8385
self.initial_avg_ref_norm = torch.mean(self.get_norm(ref_emb)).item()
8486
self.final_avg_query_norm = torch.mean(
8587
self.get_norm(query_emb_normalized)
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import unittest
2+
3+
import torch
4+
5+
from pytorch_metric_learning.distances import LpDistance
6+
7+
from .. import WITH_COLLECT_STATS
8+
9+
10+
class TestCollectedStats(unittest.TestCase):
11+
@unittest.skipUnless(WITH_COLLECT_STATS, "WITH_COLLECT_STATS is false")
12+
def test_collected_stats(self):
13+
x = torch.randn(32, 128)
14+
d = LpDistance()
15+
d(x)
16+
17+
self.assertNotEqual(d.initial_avg_query_norm, 0)
18+
self.assertNotEqual(d.initial_avg_ref_norm, 0)
19+
self.assertNotEqual(d.final_avg_query_norm, 0)
20+
self.assertNotEqual(d.final_avg_ref_norm, 0)

0 commit comments

Comments
 (0)