
distributed training bug fix, to maintain same order for mem queue fo… #642


Open · wants to merge 8 commits into base: dev
2 changes: 2 additions & 0 deletions CONTENTS.md
@@ -21,12 +21,14 @@
| [**IntraPairVarianceLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#intrapairvarianceloss) | [Deep Metric Learning with Tuplet Margin Loss](http://openaccess.thecvf.com/content_ICCV_2019/papers/Yu_Deep_Metric_Learning_With_Tuplet_Margin_Loss_ICCV_2019_paper.pdf)
| [**LargeMarginSoftmaxLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#largemarginsoftmaxloss) | [Large-Margin Softmax Loss for Convolutional Neural Networks](https://arxiv.org/pdf/1612.02295.pdf)
| [**LiftedStructureLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#liftedstructureloss) | [Deep Metric Learning via Lifted Structured Feature Embedding](https://arxiv.org/pdf/1511.06452.pdf)
| [**ManifoldLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss) | [Ensemble Deep Manifold Similarity Learning using Hard Proxies](https://openaccess.thecvf.com/content_CVPR_2019/papers/Aziere_Ensemble_Deep_Manifold_Similarity_Learning_Using_Hard_Proxies_CVPR_2019_paper.pdf)
| [**MarginLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#marginloss) | [Sampling Matters in Deep Embedding Learning](https://arxiv.org/pdf/1706.07567.pdf)
| [**MultiSimilarityLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#multisimilarityloss) | [Multi-Similarity Loss with General Pair Weighting for Deep Metric Learning](http://openaccess.thecvf.com/content_CVPR_2019/papers/Wang_Multi-Similarity_Loss_With_General_Pair_Weighting_for_Deep_Metric_Learning_CVPR_2019_paper.pdf)
| [**NCALoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#ncaloss) | [Neighbourhood Components Analysis](https://www.cs.toronto.edu/~hinton/absps/nca.pdf)
| [**NormalizedSoftmaxLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#normalizedsoftmaxloss) | - [NormFace: L2 Hypersphere Embedding for Face Verification](https://arxiv.org/pdf/1704.06369.pdf) <br/> - [Classification is a Strong Baseline for Deep Metric Learning](https://arxiv.org/pdf/1811.12649.pdf)
| [**NPairsLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#npairsloss) | [Improved Deep Metric Learning with Multi-class N-pair Loss Objective](http://www.nec-labs.com/uploads/images/Department-Images/MediaAnalytics/papers/nips16_npairmetriclearning.pdf)
| [**NTXentLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#ntxentloss) | - [Representation Learning with Contrastive Predictive Coding](https://arxiv.org/pdf/1807.03748.pdf) <br/> - [Momentum Contrast for Unsupervised Visual Representation Learning](https://arxiv.org/pdf/1911.05722.pdf) <br/> - [A Simple Framework for Contrastive Learning of Visual Representations](https://arxiv.org/abs/2002.05709)
| [**P2SGradLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss) | [P2SGrad: Refined Gradients for Optimizing Deep Face Models](https://arxiv.org/abs/1905.02479)
| [**PNPLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss) | [Rethinking the Optimization of Average Precision: Only Penalizing Negative Instances before Positive Ones is Enough](https://arxiv.org/pdf/2102.04640.pdf)
| [**ProxyAnchorLoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#proxyanchorloss) | [Proxy Anchor Loss for Deep Metric Learning](https://arxiv.org/pdf/2003.13911.pdf)
| [**ProxyNCALoss**](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#proxyncaloss) | [No Fuss Distance Metric Learning using Proxies](https://arxiv.org/pdf/1703.07464.pdf)
14 changes: 9 additions & 5 deletions README.md
@@ -18,13 +18,15 @@

## News

**June 18**: v2.2.0
- Added [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss) and [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss).
- Added a `symmetric` flag to [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss).
- See the [release notes](https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v2.2.0).
- Thank you [domenicoMuscill0](https://github.com/domenicoMuscill0).

**April 5**: v2.1.0
- Added [PNPLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#pnploss)
- Thanks to contributor [interestingzhuo](https://github.com/interestingzhuo).

**January 29**: v2.0.0
- Added SelfSupervisedLoss, plus various API improvements. See the [release notes](https://github.com/KevinMusgrave/pytorch-metric-learning/releases/tag/v2.0.0).
- Thanks to contributor [cwkeam](https://github.com/cwkeam).
- Thank you [interestingzhuo](https://github.com/interestingzhuo).


## Documentation
@@ -225,6 +227,7 @@ Thanks to the contributors who made pull requests!

| Contributor | Highlights |
| -- | -- |
|[domenicoMuscill0](https://github.com/domenicoMuscill0)| - [ManifoldLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#manifoldloss) <br/> - [P2SGradLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#p2sgradloss)
|[mlopezantequera](https://github.com/mlopezantequera) | - Made the [testers](https://kevinmusgrave.github.io/pytorch-metric-learning/testers) work on any combination of query and reference sets <br/> - Made [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/) work with arbitrary label comparisons |
|[cwkeam](https://github.com/cwkeam) | - [SelfSupervisedLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#selfsupervisedloss) <br/> - [VICRegLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#vicregloss) <br/> - Added mean reciprocal rank accuracy to [AccuracyCalculator](https://kevinmusgrave.github.io/pytorch-metric-learning/accuracy_calculation/) <br/> - BaseLossWrapper|
|[marijnl](https://github.com/marijnl)| - [BatchEasyHardMiner](https://kevinmusgrave.github.io/pytorch-metric-learning/miners/#batcheasyhardminer) <br/> - [TwoStreamMetricLoss](https://kevinmusgrave.github.io/pytorch-metric-learning/trainers/#twostreammetricloss) <br/> - [GlobalTwoStreamEmbeddingSpaceTester](https://kevinmusgrave.github.io/pytorch-metric-learning/testers/#globaltwostreamembeddingspacetester) <br/> - [Example using trainers.TwoStreamMetricLoss](https://github.com/KevinMusgrave/pytorch-metric-learning/blob/master/examples/notebooks/TwoStreamMetricLoss.ipynb) |
@@ -273,6 +276,7 @@ This library contains code that has been adapted and modified from the following
- https://github.com/ronekko/deep_metric_learning
- https://github.com/tjddus9597/Proxy-Anchor-CVPR2020
- http://kaizhao.net/regularface
- https://github.com/nii-yamagishilab/project-NN-Pytorch-scripts

### Logo
Thanks to [Jeff Musgrave](https://www.designgenius.ca/) for designing the logo.
101 changes: 100 additions & 1 deletion docs/losses.md
@@ -545,6 +545,57 @@ losses.LiftedStructureLoss(neg_margin=1, pos_margin=0, **kwargs):
* **loss**: The loss per positive pair in the batch. Reduction type is ```"pos_pair"```.


## ManifoldLoss

[Ensemble Deep Manifold Similarity Learning using Hard Proxies](https://openaccess.thecvf.com/content_CVPR_2019/papers/Aziere_Ensemble_Deep_Manifold_Similarity_Learning_Using_Hard_Proxies_CVPR_2019_paper.pdf)

```python
losses.ManifoldLoss(
l: int,
K: int = 50,
lambdaC: float = 1.0,
alpha: float = 0.8,
margin: float = 5e-4,
**kwargs
)
```

**Parameters**

- **l**: embedding size.

- **K**: number of proxies.

- **lambdaC**: regularization weight. Used in the formula `loss = intrinsic_loss + lambdaC*context_loss`.
If `lambdaC=0`, then it uses only the intrinsic loss. If `lambdaC=np.inf`, then it uses only the context loss.

- **alpha**: parameter of the Random Walk. Must be in the range `(0,1)`. It specifies the amount of similarity between neighboring nodes.

- **margin**: margin used in the calculation of the loss.


Example usage:
```python
loss_fn = ManifoldLoss(128)

# use random cluster centers
loss = loss_fn(embeddings)
# or specify indices of embeddings to use as cluster centers
loss = loss_fn(embeddings, indices_tuple=indices)
```

**Important notes**

`labels`, `ref_emb`, and `ref_labels` are not supported for this loss function.

In addition, `indices_tuple` is **not** for the output of miners. Instead, it is for a list of indices of embeddings to be used as cluster centers.
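
As a concrete illustration, here is a minimal sketch (not from the official docs) of passing explicit cluster centers: it assumes a batch of 256 embeddings of size 128, and it simply picks `K` random indices to act as centers — the random selection strategy is an assumption here, not something the library prescribes.

```python
import torch
from pytorch_metric_learning import losses

loss_fn = losses.ManifoldLoss(l=128, K=50)

embeddings = torch.randn(256, 128)                    # hypothetical batch of embeddings
center_indices = torch.randperm(256)[:50].tolist()    # 50 indices to use as cluster centers
loss = loss_fn(embeddings, indices_tuple=center_indices)
```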


**Default reducer**:

- This loss returns an **already reduced** loss.


## MarginLoss
[Sampling Matters in Deep Embedding Learning](https://arxiv.org/pdf/1706.07567.pdf){target=_blank}
```python
@@ -761,6 +812,37 @@ losses.NTXentLoss(temperature=0.07, **kwargs)
* **loss**: The loss per positive pair in the batch. Reduction type is ```"pos_pair"```.



## P2SGradLoss
[P2SGrad: Refined Gradients for Optimizing Deep Face Models](https://arxiv.org/abs/1905.02479)
```python
losses.P2SGradLoss(descriptors_dim, num_classes, **kwargs)
```

**Parameters**

- **descriptors_dim**: The embedding size.

- **num_classes**: The number of classes in your training dataset.


Example usage:
```python
loss_fn = P2SGradLoss(128, 10)
loss = loss_fn(embeddings, labels)
```
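
For context, a fuller training-step sketch might look like the following. The `trunk` model, optimizer setup, and data tensors are hypothetical placeholders, and including `loss_fn.parameters()` in the optimizer is an assumption made here in case the loss holds trainable class weights.

```python
import torch
from pytorch_metric_learning import losses

loss_fn = losses.P2SGradLoss(descriptors_dim=128, num_classes=10)
trunk = torch.nn.Linear(32, 128)  # stand-in embedding model
optimizer = torch.optim.Adam(
    list(trunk.parameters()) + list(loss_fn.parameters()), lr=1e-4
)

data, labels = torch.randn(64, 32), torch.randint(0, 10, (64,))
embeddings = trunk(data)          # (64, 128) descriptors
loss = loss_fn(embeddings, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
```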

**Important notes**

`indices_tuple`, `ref_emb`, and `ref_labels` are not supported for this loss function.


**Default reducer**:

- This loss returns an **already reduced** loss.



## PNPLoss
[Rethinking the Optimization of Average Precision: Only Penalizing Negative Instances before Positive Ones is Enough](https://arxiv.org/pdf/2102.04640.pdf){target=_blank}
```python
@@ -849,14 +931,31 @@ loss_optimizer.step()

## SelfSupervisedLoss

A common use case is to have `embeddings` and `ref_emb` be augmented versions of each other. For most losses, you have to create labels to indicate which `embeddings` correspond with which `ref_emb`.

`SelfSupervisedLoss` is a wrapper that takes care of this by creating labels internally. It assumes that:

- `ref_emb[i]` is an augmented version of `embeddings[i]`.
- `ref_emb[i]` is the only augmented version of `embeddings[i]` in the batch.

```python
losses.SelfSupervisedLoss(loss, symmetric=True, **kwargs)
```

**Parameters**:

* **loss**: The loss function to be wrapped.
* **symmetric**: If `True`, then the embeddings in both `embeddings` and `ref_emb` are used as anchors. If `False`, then only the embeddings in `embeddings` are used as anchors.

Example usage:

```python
loss_fn = losses.TripletMarginLoss()
loss_fn = losses.SelfSupervisedLoss(loss_fn)
loss = loss_fn(embeddings, ref_emb)
```
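
In a typical self-supervised pipeline, the two arguments come from two augmented views of the same batch. The sketch below illustrates this; `model` and `augment` are hypothetical stand-ins, and `NTXentLoss` is just one of the supported losses listed below.

```python
import torch
from pytorch_metric_learning import losses

loss_fn = losses.SelfSupervisedLoss(losses.NTXentLoss(temperature=0.07))

model = torch.nn.Linear(32, 128)                      # stand-in embedding model
augment = lambda x: x + 0.01 * torch.randn_like(x)    # stand-in augmentation

data = torch.randn(64, 32)
embeddings = model(augment(data))  # view 1
ref_emb = model(augment(data))     # view 2: ref_emb[i] augments the same sample as embeddings[i]
loss = loss_fn(embeddings, ref_emb)
```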


??? "Supported Loss Functions"
- [AngularLoss](losses.md#angularloss)
- [CircleLoss](losses.md#circleloss)
129 changes: 123 additions & 6 deletions src/pytorch_metric_learning/utils/distributed.py
@@ -34,7 +34,7 @@ def all_gather_embeddings_and_labels(emb, labels):
return ref_emb, ref_labels


def gather_bak(emb, labels):  # renamed from gather(); keeps the original implementation as a backup
device = emb.device
if labels is not None:
labels = c_f.to_device(labels, device=device)
@@ -45,6 +45,28 @@
)
return all_emb, all_labels, labels

def gather(emb, labels):
device = emb.device
if labels is not None:
labels = c_f.to_device(labels, device=device)
# Gather the embeddings from every replica.
emb = c_f.to_device(emb, device=device)
emb_list = [torch.ones_like(emb) for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(emb_list, emb)
# Gathered tensors carry no gradient, so overwrite the entry for the current replica with the locally computed embeddings, which do have gradients.
emb_list[torch.distributed.get_rank()] = emb
all_emb = torch.cat(emb_list, dim=0)

# Gather the labels from every replica.
if labels is not None:
labels_list = [torch.ones_like(labels) for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(labels_list, labels)
# Mirror the embeddings path: overwrite the gathered entry for the current replica with the local labels.
labels_list[torch.distributed.get_rank()] = labels
all_labels = torch.cat(labels_list, dim=0)
else:
all_labels = None
return all_emb, all_labels, labels

def gather_emb_and_ref(emb, labels, ref_emb=None, ref_labels=None):
all_emb, all_labels, labels = gather(emb, labels)
@@ -58,20 +80,44 @@ def gather_emb_and_ref(emb, labels, ref_emb=None, ref_labels=None):

def get_indices_tuple(labels, ref_labels, embeddings=None, ref_emb=None, miner=None):
device = labels.device
# curr_batch_idx must index this replica's local batch within ref_labels
# (the all-gathered global batch). The previous version,
#     curr_batch_idx = torch.arange(len(labels), device=device)
# was wrong in the distributed setting.
local_bs = len(ref_labels) // torch.distributed.get_world_size()
local_b_start_idx = torch.distributed.get_rank() * local_bs
curr_batch_idx = torch.arange(local_b_start_idx, local_b_start_idx + local_bs, device=device)

if miner:
indices_tuple = miner(embeddings, labels, ref_emb, ref_labels)
else:
indices_tuple = lmu.get_all_pairs_indices(labels, ref_labels)
return lmu.remove_self_comparisons(indices_tuple, curr_batch_idx, len(ref_labels))


def gather_enqueue_mask_bak(enqueue_mask, device):  # renamed from gather_enqueue_mask(); keeps the original implementation as a backup
if enqueue_mask is None:
return enqueue_mask
enqueue_mask = c_f.to_device(enqueue_mask, device=device)
return torch.cat([enqueue_mask, all_gather(enqueue_mask)], dim=0)

def gather_enqueue_mask(enqueue_mask, device):
if enqueue_mask is None:
return enqueue_mask
enqueue_mask = c_f.to_device(enqueue_mask, device=device)
# Gather the enqueue_mask from every replica.
enqueue_mask_list = [torch.ones_like(enqueue_mask) for _ in range(torch.distributed.get_world_size())]
torch.distributed.all_gather(enqueue_mask_list, enqueue_mask)

# Overwrite the gathered entry for the current replica with the local enqueue_mask (mirrors the embeddings path).
enqueue_mask_list[torch.distributed.get_rank()] = enqueue_mask

return torch.cat(enqueue_mask_list, dim=0)



def select_ref_or_regular(regular, ref):
return regular if ref is None else ref
@@ -123,12 +169,14 @@ def forward_regular_loss(
if indices_tuple is None:
indices_tuple = get_indices_tuple(labels, all_labels)
loss = self.loss(emb, labels, indices_tuple, all_emb, all_labels)

return loss
else:
loss = self.loss(
all_emb, all_labels, indices_tuple, all_ref_emb, all_ref_labels
)

return loss * world_size

def forward_cross_batch(
self,
@@ -152,9 +200,71 @@
emb, labels, ref_emb, ref_labels
)
enqueue_mask = gather_enqueue_mask(enqueue_mask, emb.device)
# Previously this computed self.loss(all_emb, all_labels, indices_tuple, enqueue_mask) and returned loss * world_size.

# print(f'all_gathered emb size: {all_emb.shape}')
# print(f'all_labels emb size: {all_labels.shape}')
# print(f'print enqueue_mask after all gather on {torch.distributed.get_rank()}: {enqueue_mask}')

# loss = self.loss(all_emb, all_labels, indices_tuple, enqueue_mask)
loss = self.forward_cross_batch_dist_helper(self.loss, all_emb, all_labels, indices_tuple, enqueue_mask)
return loss  # the unit test has confirmed that this is correct

def forward_cross_batch_dist_helper(self, loss_inst, embeddings, labels, indices_tuple=None, enqueue_mask=None):
if indices_tuple is not None and enqueue_mask is not None:
raise ValueError("indices_tuple and enqueue_mask are mutually exclusive")
if enqueue_mask is not None:
assert len(enqueue_mask) == len(embeddings)
else:
assert len(embeddings) <= len(loss_inst.embedding_memory)
loss_inst.reset_stats()
device = embeddings.device
labels = c_f.to_device(labels, device=device)
loss_inst.embedding_memory = c_f.to_device(
loss_inst.embedding_memory, device=device, dtype=embeddings.dtype
)
loss_inst.label_memory = c_f.to_device(
loss_inst.label_memory, device=device, dtype=labels.dtype
)

if enqueue_mask is not None:
emb_for_queue = embeddings[enqueue_mask]
labels_for_queue = labels[enqueue_mask]
embeddings = embeddings[~enqueue_mask]
labels = labels[~enqueue_mask]
do_remove_self_comparisons = False
else:
emb_for_queue = embeddings
labels_for_queue = labels
do_remove_self_comparisons = True

# ==== DDP specific ==== #
# Use only this replica's slice of the all-gathered batch as anchors: it is
# more efficient, and only the local embeddings carry gradients anyway.
# The embeddings destined for the memory queue (emb_for_queue) were selected
# above from the full gathered batch, so the queue is filled in the same
# order on every replica.
local_bs = len(embeddings) // torch.distributed.get_world_size()
local_b_start_idx = torch.distributed.get_rank() * local_bs
embeddings = embeddings[local_b_start_idx:(local_b_start_idx + local_bs), :]
labels = labels[local_b_start_idx:(local_b_start_idx + local_bs)]
# ==== end DDP specific ==== #

queue_batch_size = len(emb_for_queue)
loss_inst.add_to_memory(emb_for_queue, labels_for_queue, queue_batch_size)

if not loss_inst.has_been_filled:
E_mem = loss_inst.embedding_memory[: loss_inst.queue_idx]
L_mem = loss_inst.label_memory[: loss_inst.queue_idx]
else:
E_mem = loss_inst.embedding_memory
L_mem = loss_inst.label_memory

indices_tuple = loss_inst.create_indices_tuple(
embeddings,
labels,
E_mem,
L_mem,
indices_tuple,
do_remove_self_comparisons,
)
loss = loss_inst.loss(embeddings, labels, indices_tuple, E_mem, L_mem)
return loss

class DistributedMinerWrapper(torch.nn.Module):
def __init__(self, miner, efficient=False):
@@ -172,9 +282,16 @@ def forward(self, emb, labels, ref_emb=None, ref_labels=None):
all_emb, all_labels, all_ref_emb, all_ref_labels, labels = gather_emb_and_ref(
emb, labels, ref_emb, ref_labels
)

if self.efficient:
all_labels = select_ref_or_regular(all_labels, all_ref_labels)
all_emb = select_ref_or_regular(all_emb, all_ref_emb)

# print('in DistributedMinerWrapper: ')
# print(f'all_gathered emb size: {all_emb.shape}')
# print(f'all_labels emb size: {all_labels.shape}, all labels: {all_labels}')
# print(f'labels emb size: {labels.shape}, labels: {labels}')

return get_indices_tuple(labels, all_labels, emb, all_emb, self.miner)
else:
return self.miner(all_emb, all_labels, all_ref_emb, all_ref_labels)