 import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel as DDP

-from pytorch_metric_learning import losses, miners
+from pytorch_metric_learning.losses import ContrastiveLoss, CrossBatchMemory
+from pytorch_metric_learning.miners import MultiSimilarityMiner
 from pytorch_metric_learning.utils import distributed

 from .. import TEST_DEVICE, TEST_DTYPES
@@ -52,6 +53,7 @@ def single_process_function(
    rank,
    world_size,
    lr,
+    iterations,
    model,
    inputs,
    labels,
@@ -83,20 +85,23 @@ def single_process_function(
    )

    optimizer = optim.SGD(ddp_mp_model.parameters(), lr=lr)
-    optimizer.zero_grad()
-    outputs = ddp_mp_model(inputs[rank].to(device))
-    indices_tuple = None
-    if miner_fn:
-        indices_tuple = miner_fn(outputs, labels[rank])
-    loss = loss_fn(outputs, labels[rank], indices_tuple)
-
-    dist.barrier()
-    loss.backward()

    original_model = original_model.to(device)
    assert not parameters_are_equal(original_model, ddp_mp_model.module)
-    dist.barrier()
-    optimizer.step()
+
+    for i in range(iterations):
+        optimizer.zero_grad()
+        outputs = ddp_mp_model(inputs[i][rank].to(device))
+        indices_tuple = None
+        if miner_fn:
+            indices_tuple = miner_fn(outputs, labels[i][rank])
+        loss = loss_fn(outputs, labels[i][rank], indices_tuple)
+
+        dist.barrier()
+        loss.backward()
+        dist.barrier()
+        optimizer.step()
+
    dist.barrier()
    assert parameters_are_equal(original_model, ddp_mp_model.module)
    dist.barrier()
@@ -113,7 +118,7 @@ def create_efficient_batch(x, i, batch_size):


class TestDistributedLossWrapper(unittest.TestCase):
-    def loss_and_miner_tester(self, loss_class, miner_class, efficient):
+    def loss_and_miner_tester(self, loss_class, miner_class, efficient, xbm):
        torch.manual_seed(75210)
        if TEST_DEVICE == torch.device("cpu"):
            return
@@ -129,13 +134,7 @@ def loss_and_miner_tester(self, loss_class, miner_class, efficient):
        for world_size in range(2, max_world_size + 1):
            batch_size = 20
            lr = 1
-            inputs = [
-                torch.randn(batch_size, 10).type(dtype) for _ in range(world_size)
-            ]
-            labels = [
-                torch.randint(low=0, high=2, size=(batch_size,))
-                for _ in range(world_size)
-            ]
+            iterations = 10
            original_model = ToyMpModel().type(dtype)
            model = ToyMpModel().type(dtype)
            model.load_state_dict(original_model.state_dict())
@@ -144,6 +143,11 @@ def loss_and_miner_tester(self, loss_class, miner_class, efficient):
            original_model = original_model.to(TEST_DEVICE)
            original_loss_fn = loss_class()
            loss_fn = loss_class()
+            if xbm:
+                original_loss_fn = CrossBatchMemory(
+                    original_loss_fn, embedding_size=5
+                )
+                loss_fn = CrossBatchMemory(loss_fn, embedding_size=5)

            if miner_class:
                original_miner_fn = miner_class()
@@ -153,54 +157,68 @@ def loss_and_miner_tester(self, loss_class, miner_class, efficient):
                miner_fn = None

            optimizer = optim.SGD(original_model.parameters(), lr=lr)
-            optimizer.zero_grad()
-            all_inputs = torch.cat(inputs, dim=0).to(TEST_DEVICE)
-            all_labels = torch.cat(labels, dim=0).to(TEST_DEVICE)
-            all_outputs = original_model(all_inputs)
-            indices_tuple = None
-            if efficient:
-                losses = []
-                for i in range(len(inputs)):
-                    curr_emb, other_emb = create_efficient_batch(
-                        all_outputs, i, batch_size
-                    )
-                    curr_labels, other_labels = create_efficient_batch(
-                        all_labels, i, batch_size
-                    )
-                    if original_miner_fn:
-                        indices_tuple = distributed.get_indices_tuple(
+            inputs = [
+                [torch.randn(batch_size, 10).type(dtype) for _ in range(world_size)]
+                for _ in range(iterations)
+            ]
+            labels = [
+                [
+                    torch.randint(low=0, high=2, size=(batch_size,))
+                    for _ in range(world_size)
+                ]
+                for _ in range(iterations)
+            ]
+
+            for aaa in range(iterations):
+                optimizer.zero_grad()
+                all_inputs = torch.cat(inputs[aaa], dim=0).to(TEST_DEVICE)
+                all_labels = torch.cat(labels[aaa], dim=0).to(TEST_DEVICE)
+                all_outputs = original_model(all_inputs)
+                indices_tuple = None
+                if efficient:
+                    losses = []
+                    for i in range(len(inputs[aaa])):
+                        curr_emb, other_emb = create_efficient_batch(
+                            all_outputs, i, batch_size
+                        )
+                        curr_labels, other_labels = create_efficient_batch(
+                            all_labels, i, batch_size
+                        )
+                        if original_miner_fn:
+                            indices_tuple = distributed.get_indices_tuple(
+                                curr_labels,
+                                other_labels,
+                                TEST_DEVICE,
+                                embeddings=curr_emb,
+                                ref_emb=other_emb,
+                                miner=original_miner_fn,
+                            )
+                        else:
+                            indices_tuple = distributed.get_indices_tuple(
+                                curr_labels, other_labels, TEST_DEVICE
+                            )
+                        loss = original_loss_fn(
+                            curr_emb,
                            curr_labels,
+                            indices_tuple,
+                            other_emb,
                            other_labels,
-                            TEST_DEVICE,
-                            embeddings=curr_emb,
-                            ref_emb=other_emb,
-                            miner=original_miner_fn,
                        )
-                    else:
-                        indices_tuple = distributed.get_indices_tuple(
-                            curr_labels, other_labels, TEST_DEVICE
-                        )
-                    loss = original_loss_fn(
-                        curr_emb,
-                        curr_labels,
-                        indices_tuple,
-                        other_emb,
-                        other_labels,
-                    )
-                    losses.append(loss)
-                loss = sum(losses)
-            else:
-                if original_miner_fn:
-                    indices_tuple = original_miner_fn(all_outputs, all_labels)
-                loss = original_loss_fn(all_outputs, all_labels, indices_tuple)
-            loss.backward()
-            optimizer.step()
+                        losses.append(loss)
+                    loss = sum(losses)
+                else:
+                    if original_miner_fn:
+                        indices_tuple = original_miner_fn(all_outputs, all_labels)
+                    loss = original_loss_fn(all_outputs, all_labels, indices_tuple)
+                loss.backward()
+                optimizer.step()

            mp.spawn(
                single_process_function,
                args=(
                    world_size,
                    lr,
+                    iterations,
                    model,
                    inputs,
                    labels,
@@ -215,19 +233,21 @@ def loss_and_miner_tester(self, loss_class, miner_class, efficient):
            )

    def test_distributed_tuple_loss(self):
-        self.loss_and_miner_tester(losses.ContrastiveLoss, None, False)
+        for xbm in [False, True]:
+            self.loss_and_miner_tester(ContrastiveLoss, None, False, xbm)

    def test_distributed_tuple_loss_and_miner(self):
-        self.loss_and_miner_tester(
-            losses.ContrastiveLoss, miners.MultiSimilarityMiner, False
-        )
+        for xbm in [False, True]:
+            self.loss_and_miner_tester(
+                ContrastiveLoss, MultiSimilarityMiner, False, xbm
+            )

    def test_distributed_tuple_loss_efficient(self):
-        self.loss_and_miner_tester(losses.ContrastiveLoss, None, True)
+        self.loss_and_miner_tester(ContrastiveLoss, None, True, xbm=False)

    def test_distributed_tuple_loss_and_miner_efficient(self):
        self.loss_and_miner_tester(
-            losses.ContrastiveLoss, miners.MultiSimilarityMiner, True
+            ContrastiveLoss, MultiSimilarityMiner, True, xbm=False
        )
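For context, a minimal sketch (not part of this commit) of how a loss and miner are typically wrapped for DDP with this library's distributed utilities, which is what `single_process_function` exercises against the single-process reference above. The wrapper names `DistributedLossWrapper` and `DistributedMinerWrapper` come from `pytorch_metric_learning.utils.distributed`; the exact call pattern below is an assumption, not the test's literal code.

```python
# Hedged sketch: wrapping a metric-learning loss and miner for use inside a
# DDP worker. Assumes DistributedLossWrapper / DistributedMinerWrapper from
# pytorch_metric_learning.utils.distributed.
from pytorch_metric_learning.losses import ContrastiveLoss
from pytorch_metric_learning.miners import MultiSimilarityMiner
from pytorch_metric_learning.utils import distributed

loss_fn = distributed.DistributedLossWrapper(ContrastiveLoss())
miner_fn = distributed.DistributedMinerWrapper(MultiSimilarityMiner())

def train_step(ddp_model, optimizer, data, labels):
    # The wrappers gather embeddings/labels across ranks, so the loss and
    # miner effectively see the global batch while gradients stay local.
    optimizer.zero_grad()
    embeddings = ddp_model(data)
    indices_tuple = miner_fn(embeddings, labels)
    loss = loss_fn(embeddings, labels, indices_tuple)
    loss.backward()
    optimizer.step()
```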