File tree 3 files changed +24
-2
lines changed
src/pytorch_metric_learning 3 files changed +24
-2
lines changed Original file line number Diff line number Diff line change 1
- __version__ = "2.1.0 "
1
+ __version__ = "2.1.1 "
Original file line number Diff line number Diff line change @@ -79,7 +79,9 @@ def set_default_stats(
79
79
):
80
80
if self .collect_stats :
81
81
with torch .no_grad ():
82
- self .initial_avg_query_norm : torch .mean (self .get_norm (query_emb )).item ()
82
+ self .initial_avg_query_norm = torch .mean (
83
+ self .get_norm (query_emb )
84
+ ).item ()
83
85
self .initial_avg_ref_norm = torch .mean (self .get_norm (ref_emb )).item ()
84
86
self .final_avg_query_norm = torch .mean (
85
87
self .get_norm (query_emb_normalized )
Original file line number Diff line number Diff line change
1
+ import unittest
2
+
3
+ import torch
4
+
5
+ from pytorch_metric_learning .distances import LpDistance
6
+
7
+ from .. import WITH_COLLECT_STATS
8
+
9
+
10
+ class TestCollectedStats (unittest .TestCase ):
11
+ @unittest .skipUnless (WITH_COLLECT_STATS , "WITH_COLLECT_STATS is false" )
12
+ def test_collected_stats (self ):
13
+ x = torch .randn (32 , 128 )
14
+ d = LpDistance ()
15
+ d (x )
16
+
17
+ self .assertNotEqual (d .initial_avg_query_norm , 0 )
18
+ self .assertNotEqual (d .initial_avg_ref_norm , 0 )
19
+ self .assertNotEqual (d .final_avg_query_norm , 0 )
20
+ self .assertNotEqual (d .final_avg_ref_norm , 0 )
You can’t perform that action at this time.
0 commit comments