
Commit a142af3

Author: Kevin Musgrave
Merge pull request #449 from KevinMusgrave/dev
v1.2.1
2 parents 97076e3 + 6106678, commit a142af3

File tree: 5 files changed, +108 −76 lines
+1 −1

@@ -1 +1 @@
-__version__ = "1.2.0"
+__version__ = "1.2.1"

src/pytorch_metric_learning/utils/distributed.py

+9 −3
@@ -1,6 +1,6 @@
 import torch
 
-from ..losses import BaseMetricLossFunction
+from ..losses import BaseMetricLossFunction, CrossBatchMemory
 from ..miners import BaseMiner
 from ..utils import common_functions as c_f
 from ..utils import loss_and_miner_utils as lmu
@@ -58,8 +58,14 @@ def get_indices_tuple(
 class DistributedLossWrapper(torch.nn.Module):
     def __init__(self, loss, efficient=False):
         super().__init__()
-        if not isinstance(loss, BaseMetricLossFunction):
-            raise TypeError("The input loss must extend BaseMetricLossFunction")
+        if not isinstance(loss, (BaseMetricLossFunction, CrossBatchMemory)):
+            raise TypeError(
+                "The input loss must extend BaseMetricLossFunction or CrossBatchMemory"
+            )
+        if isinstance(loss, CrossBatchMemory) and efficient:
+            raise ValueError(
+                "CrossBatchMemory with efficient=True is not currently supported"
+            )
         self.loss = loss
         self.efficient = efficient

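With this release, DistributedLossWrapper accepts CrossBatchMemory losses, except in efficient mode. A minimal usage sketch, assuming a standard DDP training setup elsewhere; the wrapped loss and embedding_size here are illustrative, not taken from this commit:

from pytorch_metric_learning.losses import ContrastiveLoss, CrossBatchMemory
from pytorch_metric_learning.utils import distributed

# CrossBatchMemory is now a valid input to DistributedLossWrapper.
loss = CrossBatchMemory(ContrastiveLoss(), embedding_size=128)
loss_fn = distributed.DistributedLossWrapper(loss)

# Combining CrossBatchMemory with efficient=True is rejected:
# distributed.DistributedLossWrapper(loss, efficient=True)  # raises ValueError
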
src/pytorch_metric_learning/utils/inference.py

+4 −0
@@ -182,6 +182,10 @@ def __call__(
         c_f.LOGGER.info("embedding dimensionality is %d" % d)
         if self.reset_before:
             self.index = self.index_init_fn(d)
+        if self.index is None:
+            raise ValueError(
+                "self.index is None. It needs to be initialized before being used."
+            )
         distances, indices = try_gpu(
             self.index,
             query,
tests/utils/test_distributed.py

+85 −65
@@ -8,7 +8,8 @@
 import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-from pytorch_metric_learning import losses, miners
+from pytorch_metric_learning.losses import ContrastiveLoss, CrossBatchMemory
+from pytorch_metric_learning.miners import MultiSimilarityMiner
 from pytorch_metric_learning.utils import distributed
 
 from .. import TEST_DEVICE, TEST_DTYPES
@@ -52,6 +53,7 @@ def single_process_function(
     rank,
     world_size,
     lr,
+    iterations,
     model,
     inputs,
     labels,
@@ -83,20 +85,23 @@ def single_process_function(
     )
 
     optimizer = optim.SGD(ddp_mp_model.parameters(), lr=lr)
-    optimizer.zero_grad()
-    outputs = ddp_mp_model(inputs[rank].to(device))
-    indices_tuple = None
-    if miner_fn:
-        indices_tuple = miner_fn(outputs, labels[rank])
-    loss = loss_fn(outputs, labels[rank], indices_tuple)
-
-    dist.barrier()
-    loss.backward()
 
     original_model = original_model.to(device)
     assert not parameters_are_equal(original_model, ddp_mp_model.module)
-    dist.barrier()
-    optimizer.step()
+
+    for i in range(iterations):
+        optimizer.zero_grad()
+        outputs = ddp_mp_model(inputs[i][rank].to(device))
+        indices_tuple = None
+        if miner_fn:
+            indices_tuple = miner_fn(outputs, labels[i][rank])
+        loss = loss_fn(outputs, labels[i][rank], indices_tuple)
+
+        dist.barrier()
+        loss.backward()
+        dist.barrier()
+        optimizer.step()
+
     dist.barrier()
     assert parameters_are_equal(original_model, ddp_mp_model.module)
     dist.barrier()
@@ -113,7 +118,7 @@ def create_efficient_batch(x, i, batch_size):
 
 
 class TestDistributedLossWrapper(unittest.TestCase):
-    def loss_and_miner_tester(self, loss_class, miner_class, efficient):
+    def loss_and_miner_tester(self, loss_class, miner_class, efficient, xbm):
         torch.manual_seed(75210)
         if TEST_DEVICE == torch.device("cpu"):
             return
@@ -129,13 +134,7 @@ def loss_and_miner_tester(self, loss_class, miner_class, efficient):
             for world_size in range(2, max_world_size + 1):
                 batch_size = 20
                 lr = 1
-                inputs = [
-                    torch.randn(batch_size, 10).type(dtype) for _ in range(world_size)
-                ]
-                labels = [
-                    torch.randint(low=0, high=2, size=(batch_size,))
-                    for _ in range(world_size)
-                ]
+                iterations = 10
                 original_model = ToyMpModel().type(dtype)
                 model = ToyMpModel().type(dtype)
                 model.load_state_dict(original_model.state_dict())
@@ -144,6 +143,11 @@ def loss_and_miner_tester(self, loss_class, miner_class, efficient):
                 original_model = original_model.to(TEST_DEVICE)
                 original_loss_fn = loss_class()
                 loss_fn = loss_class()
+                if xbm:
+                    original_loss_fn = CrossBatchMemory(
+                        original_loss_fn, embedding_size=5
+                    )
+                    loss_fn = CrossBatchMemory(loss_fn, embedding_size=5)
 
                 if miner_class:
                     original_miner_fn = miner_class()
@@ -153,54 +157,68 @@ def loss_and_miner_tester(self, loss_class, miner_class, efficient):
                     miner_fn = None
 
                 optimizer = optim.SGD(original_model.parameters(), lr=lr)
-                optimizer.zero_grad()
-                all_inputs = torch.cat(inputs, dim=0).to(TEST_DEVICE)
-                all_labels = torch.cat(labels, dim=0).to(TEST_DEVICE)
-                all_outputs = original_model(all_inputs)
-                indices_tuple = None
-                if efficient:
-                    losses = []
-                    for i in range(len(inputs)):
-                        curr_emb, other_emb = create_efficient_batch(
-                            all_outputs, i, batch_size
-                        )
-                        curr_labels, other_labels = create_efficient_batch(
-                            all_labels, i, batch_size
-                        )
-                        if original_miner_fn:
-                            indices_tuple = distributed.get_indices_tuple(
-                                curr_labels,
-                                other_labels,
-                                TEST_DEVICE,
-                                embeddings=curr_emb,
-                                ref_emb=other_emb,
-                                miner=original_miner_fn,
-                            )
-                        else:
-                            indices_tuple = distributed.get_indices_tuple(
-                                curr_labels, other_labels, TEST_DEVICE
-                            )
-                        loss = original_loss_fn(
-                            curr_emb,
-                            curr_labels,
-                            indices_tuple,
-                            other_emb,
-                            other_labels,
-                        )
-                        losses.append(loss)
-                    loss = sum(losses)
-                else:
-                    if original_miner_fn:
-                        indices_tuple = original_miner_fn(all_outputs, all_labels)
-                    loss = original_loss_fn(all_outputs, all_labels, indices_tuple)
-                loss.backward()
-                optimizer.step()
+                inputs = [
+                    [torch.randn(batch_size, 10).type(dtype) for _ in range(world_size)]
+                    for _ in range(iterations)
+                ]
+                labels = [
+                    [
+                        torch.randint(low=0, high=2, size=(batch_size,))
+                        for _ in range(world_size)
+                    ]
+                    for _ in range(iterations)
+                ]
+
+                for aaa in range(iterations):
+                    optimizer.zero_grad()
+                    all_inputs = torch.cat(inputs[aaa], dim=0).to(TEST_DEVICE)
+                    all_labels = torch.cat(labels[aaa], dim=0).to(TEST_DEVICE)
+                    all_outputs = original_model(all_inputs)
+                    indices_tuple = None
+                    if efficient:
+                        losses = []
+                        for i in range(len(inputs[aaa])):
+                            curr_emb, other_emb = create_efficient_batch(
+                                all_outputs, i, batch_size
+                            )
+                            curr_labels, other_labels = create_efficient_batch(
+                                all_labels, i, batch_size
+                            )
+                            if original_miner_fn:
+                                indices_tuple = distributed.get_indices_tuple(
+                                    curr_labels,
+                                    other_labels,
+                                    TEST_DEVICE,
+                                    embeddings=curr_emb,
+                                    ref_emb=other_emb,
+                                    miner=original_miner_fn,
+                                )
+                            else:
+                                indices_tuple = distributed.get_indices_tuple(
+                                    curr_labels, other_labels, TEST_DEVICE
+                                )
+                            loss = original_loss_fn(
+                                curr_emb,
+                                curr_labels,
+                                indices_tuple,
+                                other_emb,
+                                other_labels,
+                            )
+                            losses.append(loss)
+                        loss = sum(losses)
+                    else:
+                        if original_miner_fn:
+                            indices_tuple = original_miner_fn(all_outputs, all_labels)
+                        loss = original_loss_fn(all_outputs, all_labels, indices_tuple)
+                    loss.backward()
+                    optimizer.step()
 
                 mp.spawn(
                     single_process_function,
                     args=(
                         world_size,
                         lr,
+                        iterations,
                         model,
                         inputs,
                         labels,
@@ -215,19 +233,21 @@ def loss_and_miner_tester(self, loss_class, miner_class, efficient):
                 )
 
     def test_distributed_tuple_loss(self):
-        self.loss_and_miner_tester(losses.ContrastiveLoss, None, False)
+        for xbm in [False, True]:
+            self.loss_and_miner_tester(ContrastiveLoss, None, False, xbm)
 
     def test_distributed_tuple_loss_and_miner(self):
-        self.loss_and_miner_tester(
-            losses.ContrastiveLoss, miners.MultiSimilarityMiner, False
-        )
+        for xbm in [False, True]:
+            self.loss_and_miner_tester(
+                ContrastiveLoss, MultiSimilarityMiner, False, xbm
+            )
 
     def test_distributed_tuple_loss_efficient(self):
-        self.loss_and_miner_tester(losses.ContrastiveLoss, None, True)
+        self.loss_and_miner_tester(ContrastiveLoss, None, True, xbm=False)
 
     def test_distributed_tuple_loss_and_miner_efficient(self):
         self.loss_and_miner_tester(
-            losses.ContrastiveLoss, miners.MultiSimilarityMiner, True
+            ContrastiveLoss, MultiSimilarityMiner, True, xbm=False
         )

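The tester now runs several optimizer steps per configuration instead of one, which also exercises CrossBatchMemory's queue of past embeddings when xbm=True. A minimal sketch of that accumulation behavior; the batch size and embedding size mirror the test, and the iteration count is illustrative:

import torch
from pytorch_metric_learning.losses import ContrastiveLoss, CrossBatchMemory

loss_fn = CrossBatchMemory(ContrastiveLoss(), embedding_size=5)
for _ in range(10):
    embeddings = torch.randn(20, 5)  # one batch of 5-dim embeddings
    labels = torch.randint(low=0, high=2, size=(20,))
    # Each call compares the batch against the stored memory bank,
    # then enqueues the batch's embeddings for future iterations.
    loss = loss_fn(embeddings, labels)
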
tests/utils/test_inference.py

+9 −7
@@ -34,7 +34,7 @@ class TestInference(unittest.TestCase):
     def setUpClass(cls):
         trunk = torchvision.models.resnet18(pretrained=True)
         trunk.fc = common_functions.Identity()
-        trunk = torch.nn.DataParallel(trunk.to(TEST_DEVICE))
+        trunk = trunk.to(TEST_DEVICE)
 
         cls.model = trunk
 
@@ -59,15 +59,17 @@ def tearDown(self):
         torch.cuda.empty_cache()
 
     def test_untrained_indexer(self):
-        inference_model = InferenceModel(trunk=self.model)
-        with self.assertRaises(RuntimeError):
-            inference_model.get_nearest_neighbors(self.dataset[0][0], k=10)
+        inference_model = InferenceModel(trunk=self.model, data_device=TEST_DEVICE)
+        with self.assertRaises(ValueError):
+            inference_model.get_nearest_neighbors(self.dataset[0][0].unsqueeze(0), k=10)
 
     def test_get_nearest_neighbors(self):
         test_filename = "test_inference.index"
         for indexer_input in [self.train_vectors, self.dataset]:
             for load_from_file in [False, True]:
-                inference_model = InferenceModel(trunk=self.model)
+                inference_model = InferenceModel(
+                    trunk=self.model, data_device=TEST_DEVICE
+                )
                 if load_from_file:
                     inference_model.load_knn_func(test_filename)
                 else:
@@ -79,15 +81,15 @@ def test_get_nearest_neighbors(self):
             os.remove(test_filename)
 
     def test_add_to_indexer(self):
-        inference_model = InferenceModel(trunk=self.model)
+        inference_model = InferenceModel(trunk=self.model, data_device=TEST_DEVICE)
         inference_model.knn_func.index = faiss.IndexFlatL2(512)
         inference_model.add_to_knn(self.dataset)
         self.helper_assertions(inference_model)
 
     def test_list_of_text(self):
         model = TextModel()
         dataset = TextDataset()
-        inference_model = InferenceModel(trunk=model)
+        inference_model = InferenceModel(trunk=model, data_device=TEST_DEVICE)
         inference_model.train_knn(dataset)
         inference_model.add_to_knn([["test1", "test2"], ["test3", "test4"]])
         result = inference_model.get_nearest_neighbors(["dog", "cat"], k=10)
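
These tests now pass data_device explicitly instead of wrapping the trunk in torch.nn.DataParallel. A minimal sketch of the equivalent single-device setup; the resnet18 trunk mirrors the test fixture above:

import torch
import torchvision
from pytorch_metric_learning.utils import common_functions
from pytorch_metric_learning.utils.inference import InferenceModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
trunk = torchvision.models.resnet18(pretrained=True)
trunk.fc = common_functions.Identity()  # strip the classifier head to expose embeddings
inference_model = InferenceModel(trunk=trunk.to(device), data_device=device)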
