Skip to content

Commit 3f64bd6

Browse files
author
Kevin Musgrave
committed
Revert to_device change (keep it for a future PR).
1 parent ab47660 commit 3f64bd6

15 files changed

+20
-51
lines changed

src/pytorch_metric_learning/losses/manifold_loss.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,8 +66,7 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
6666
meta_classes = torch.cat((torch.arange(self.K), meta_classes))
6767
meta_classes = meta_classes[torch.randperm(N)]
6868

69-
loss_int = torch.zeros(1)
70-
loss_int = c_f.to_device(loss_int, tensor=embeddings, dtype=embeddings.dtype)
69+
loss_int = torch.zeros(1, device=embeddings.device, dtype=embeddings.dtype)
7170
embs_and_proxies = torch.cat([embeddings, self.proxies], dim=0)
7271

7372
S = self.distance(embs_and_proxies, embs_and_proxies).clamp(0, np.inf)

src/pytorch_metric_learning/losses/p2s_grad_loss.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,8 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
4444

4545
self.weight.data = self.weight.data.renorm(2, 1, 1e-5).mul(1e5)
4646
dtype = embeddings.dtype
47-
self.weight.data, labels = c_f.to_device(
48-
(self.weight.data, labels), tensor=embeddings, dtype=(dtype, torch.long)
47+
self.weight.data = c_f.to_device(
48+
self.weight.data, tensor=embeddings, dtype=dtype
4949
)
5050

5151
rtol = 1e-2 if dtype == torch.float16 else 1e-5

src/pytorch_metric_learning/samplers/hierarchical_sampler.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,12 +74,9 @@ def __len__(
7474
def reshuffle(self):
7575
batches = []
7676
for combinations in self.super_pairs:
77-
7877
for b in range(self.batches_per_super_tuple):
79-
8078
batch = []
8179
for slb in combinations:
82-
8380
sub_batch = []
8481
all_classes = list(self.super_image_lists[slb].keys())
8582
c_f.NUMPY_RANDOM.shuffle(all_classes)

src/pytorch_metric_learning/utils/common_functions.py

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,10 @@
33
import logging
44
import os
55
import re
6-
from typing import List, Tuple, Union
76

87
import numpy as np
98
import scipy.stats
109
import torch
11-
from torch import nn
1210

1311
LOGGER_NAME = "PML"
1412
LOGGER = logging.getLogger(LOGGER_NAME)
@@ -461,34 +459,13 @@ def to_dtype(x, tensor=None, dtype=None):
461459
return x
462460

463461

464-
def to_device(
465-
x: Union[torch.Tensor, nn.Parameter, List, Tuple],
466-
tensor=None,
467-
device=None,
468-
dtype: Union[torch.dtype, List, Tuple] = None,
469-
):
462+
def to_device(x, tensor=None, device=None, dtype=None):
470463
dv = device if device is not None else tensor.device
471-
is_iterable = is_list_or_tuple(x)
472-
if not is_iterable:
473-
x = [x]
474-
475-
xd = x
476-
if is_list_or_tuple(dtype):
477-
if len(dtype) == len(x):
478-
xd = [
479-
to_dtype(x[i].to(dv), tensor=tensor, dtype=dtype[i])
480-
for i in range(len(x))
481-
]
482-
else:
483-
raise RuntimeError(
484-
f"The size of dtype was {len(dtype)}. It is only available 1 or the same of x"
485-
)
486-
elif dtype is not None:
487-
xd = [to_dtype(xt.to(dv), tensor=tensor, dtype=dtype) for xt in x]
488-
489-
if len(xd) == 1:
490-
xd = xd[0]
491-
return xd
464+
if x.device != dv:
465+
x = x.to(dv)
466+
if dtype is not None:
467+
x = to_dtype(x, dtype=dtype)
468+
return x
492469

493470

494471
def set_ref_emb(embeddings, labels, ref_emb, ref_labels):

tests/losses/test_manifold_loss.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,7 @@ def forward(self, fvec, fLvec, fvecs_add=None):
119119
)
120120

121121
for j in range(self.nb_proxy):
122-
123122
if fLvec[i] != j:
124-
125123
val1_context = (
126124
self.d(torch.unsqueeze(A[i], 0), torch.unsqueeze(A_p[j], 0))
127125
- dist_pos

tests/losses/test_multi_similarity_loss.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
class TestMultiSimilarityLoss(unittest.TestCase):
1313
def test_multi_similarity_loss(self):
1414
for dtype in TEST_DTYPES:
15-
1615
embedding_angles = [0, 20, 40, 60, 80]
1716
embeddings = torch.tensor(
1817
[angle_to_coord(a) for a in embedding_angles],

tests/losses/test_p2s_grad_loss.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,9 +137,8 @@ def forward(self, input_score, target):
137137
# index (batch_size, class_num)
138138
with torch.no_grad():
139139
index = torch.zeros_like(input_score)
140-
index, target = c_f.to_device(
141-
(index, target), tensor=input_score, dtype=torch.long
142-
)
140+
index = c_f.to_device(index, tensor=input_score, dtype=torch.long)
141+
target = c_f.to_device(target, tensor=input_score, dtype=torch.long)
143142
# index[i][target[i][j]] = 1
144143
index.scatter_(1, target.data.view(-1, 1), 1)
145144

@@ -199,7 +198,6 @@ def test_p2s_grad_loss_with_paper_formula(self):
199198
)
200199

201200
def test_p2s_grad_loss_with_trusted_implementation(self):
202-
203201
for dtype in TEST_DTYPES:
204202
embedding_angles = [0, 20, 40, 60, 80]
205203
embeddings = torch.tensor(

tests/losses/test_pnp_loss.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def __init__(self, b, alpha, anneal, variant, bs, classes):
3131
self.mask = mask.unsqueeze(dim=0).repeat(self.batch_size, 1, 1)
3232

3333
def forward(self, batch):
34-
3534
dtype, device = batch.dtype, batch.device
3635
self.mask = self.mask.type(dtype).to(device)
3736
# compute the relevance scores via cosine similarity of the CNN-produced embedding vectors

tests/losses/test_proxy_anchor_loss.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def test_proxyanchor_loss(self):
8989
margin = 0.5
9090

9191
for use_autocast in [True, False]:
92-
9392
if use_autocast:
9493
cm = torch.cuda.amp.autocast()
9594
else:

tests/losses/test_triplet_margin_loss.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,12 @@
1414
class TestTripletMarginLoss(unittest.TestCase):
1515
def test_triplet_margin_loss(self):
1616
for dtype in TEST_DTYPES:
17-
embeddings = torch.randn(5, 32, requires_grad=True, dtype=dtype,).to(
17+
embeddings = torch.randn(
18+
5,
19+
32,
20+
requires_grad=True,
21+
dtype=dtype,
22+
).to(
1823
TEST_DEVICE
1924
) # 2D embeddings
2025
embeddings = torch.nn.functional.normalize(embeddings)

tests/miners/test_batch_easy_hard_miner_labels.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def test_labels(self):
2424
DotProductSimilarity,
2525
SNRDistance,
2626
]:
27-
for (pos_strategy, neg_strategy) in [
27+
for pos_strategy, neg_strategy in [
2828
("easy", "easy"),
2929
("easy", "semihard"),
3030
("easy", "hard"),

tests/samplers/test_fixed_set_of_triplets.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def test_fixed_set_of_triplets_with_batch_size(self):
3636
dataset, batch_size=batch_size, sampler=sampler, drop_last=True
3737
)
3838
for _ in range(2):
39-
for (embeddings, curr_labels) in dataloader:
39+
for embeddings, curr_labels in dataloader:
4040
a, p, n = miner(batch_of_fake_embeddings, curr_labels)
4141
self.assertTrue(len(a) == batch_size // 3)
4242
self.assertTrue(torch.all(curr_labels[a] == curr_labels[p]))

tests/samplers/test_m_per_class_sampler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_mperclass_sampler_with_batch_size(self):
6868
drop_last=False,
6969
)
7070
for _ in range(2):
71-
for (_, curr_labels) in dataloader:
71+
for _, curr_labels in dataloader:
7272
unique_labels, counts = torch.unique(
7373
curr_labels, return_counts=True
7474
)

tests/trainers/test_key_checking.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ def test_metric_loss_only(self):
3131
TrainWithClassifier,
3232
TwoStreamMetricLoss,
3333
]:
34-
3534
model_dict = {"trunk": model}
3635
optimizer_dict = {"trunk_optimizer": None}
3736
loss_fn_dict = {"metric_loss": loss_fn}

tests/trainers/test_metric_loss_only.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424

2525
class TestMetricLossOnly(unittest.TestCase):
2626
def test_metric_loss_only(self):
27-
2827
cifar_resnet_folder = "temp_cifar_resnet_for_pytorch_metric_learning_test"
2928
dataset_folder = "temp_dataset_for_pytorch_metric_learning_test"
3029
model_folder = "temp_saved_models_for_pytorch_metric_learning_test"

0 commit comments

Comments
 (0)