Simplify reducers & mixin #646

Open
wants to merge 2 commits into base: v3.0
2 changes: 2 additions & 0 deletions .gitignore
@@ -7,6 +7,8 @@ dist/
site/
venv/
.ipynb_checkpoints
**/.vscode
**/temp*_for_pytorch_metric_learning_test
examples/notebooks/dataset
examples/notebooks/CIFAR10_Dataset
examples/notebooks/CIFAR100_Dataset
30 changes: 15 additions & 15 deletions docs/losses.md
@@ -111,7 +111,7 @@ losses.ArcFaceLoss(num_classes, embedding_size, margin=28.6, scale=64, **kwargs)

**Other info**:

* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```weight_regularizer```, ```weight_reg_weight```, and ```weight_init_func``` as optional arguments.
* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```regularizer```, ```reg_weight```, and ```weight_init_func``` as optional arguments.
* This loss **requires an optimizer**. You need to create an optimizer and pass this loss's parameters to that optimizer. For example:
```python
loss_func = losses.ArcFaceLoss(...).to(torch.device('cuda'))
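# A hedged sketch of the usual continuation (the SGD choice and learning rate
# below are illustrative assumptions, not prescribed by the library):
loss_optimizer = torch.optim.SGD(loss_func.parameters(), lr=0.01)
# ...then call loss_optimizer.step() in the training loop after backward()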
@@ -141,17 +141,17 @@ All loss functions extend this class and therefore inherit its ```__init__``` parameters.
losses.BaseMetricLossFunction(collect_stats = False,
reducer = None,
distance = None,
embedding_regularizer = None,
embedding_reg_weight = 1)
regularizer = None,
reg_weight = 1)
```

**Parameters**:

* **collect_stats**: If True, will collect various statistics that may be useful to analyze during experiments. If False, these computations will be skipped. Want to make ```True``` the default? Set the global [COLLECT_STATS](common_functions.md#collect_stats) flag.
* **reducer**: A [reducer](reducers.md) object. If None, then the default reducer will be used.
* **distance**: A [distance](distances.md) object. If None, then the default distance will be used.
* **embedding_regularizer**: A [regularizer](regularizers.md) object that will be applied to embeddings. If None, then no embedding regularization will be used.
* **embedding_reg_weight**: If an embedding regularizer is used, then its loss will be multiplied by this amount before being added to the total loss.
* **regularizer**: A [regularizer](regularizers.md) object that will be applied to embeddings. If None, then no embedding regularization will be used.
* **reg_weight**: If an embedding regularizer is used, then its loss will be multiplied by this amount before being added to the total loss.
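
As a rough illustration of how these arguments fit together, here is a sketch that assumes the renamed ```regularizer```/```reg_weight``` arguments from this pull request; ```TripletMarginLoss```, ```ThresholdReducer```, and ```LpRegularizer``` are existing library classes used purely as examples:
```python
from pytorch_metric_learning import losses, reducers, regularizers

loss_func = losses.TripletMarginLoss(
    reducer=reducers.ThresholdReducer(low=0),   # replaces the default reducer
    regularizer=regularizers.LpRegularizer(),   # applied to the embeddings
    reg_weight=0.1,                             # scales the regularization term
    collect_stats=True,                         # record extra statistics
)
```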

**Default distance**:

@@ -273,7 +273,7 @@ losses.CosFaceLoss(num_classes, embedding_size, margin=0.35, scale=64, **kwargs)

**Other info**:

* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```weight_regularizer```, ```weight_reg_weight```, and ```weight_init_func``` as optional arguments.
* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```regularizer```, ```reg_weight```, and ```weight_init_func``` as optional arguments.
* This loss **requires an optimizer**. You need to create an optimizer and pass this loss's parameters to that optimizer. For example:
```python
loss_func = losses.CosFaceLoss(...).to(torch.device('cuda'))
@@ -491,7 +491,7 @@ where

**Other info**:

* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```weight_regularizer```, ```weight_reg_weight```, and ```weight_init_func``` as optional arguments.
* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```regularizer```, ```reg_weight```, and ```weight_init_func``` as optional arguments.
* This loss **requires an optimizer**. You need to create an optimizer and pass this loss's parameters to that optimizer. For example:
```python
loss_func = losses.LargeMarginSoftmaxLoss(...).to(torch.device('cuda'))
@@ -737,7 +737,7 @@ losses.NormalizedSoftmaxLoss(num_classes, embedding_size, temperature=0.05, **kwargs)

**Other info**

* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```weight_regularizer```, ```weight_reg_weight```, and ```weight_init_func``` as optional arguments.
* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```regularizer```, ```reg_weight```, and ```weight_init_func``` as optional arguments.
* This loss **requires an optimizer**. You need to create an optimizer and pass this loss's parameters to that optimizer. For example:
```python
loss_func = losses.NormalizedSoftmaxLoss(...).to(torch.device('cuda'))
@@ -870,7 +870,7 @@ losses.ProxyAnchorLoss(num_classes, embedding_size, margin = 0.1, alpha = 32, **kwargs)

**Other info**

* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```weight_regularizer```, ```weight_reg_weight```, and ```weight_init_func``` as optional arguments.
* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```regularizer```, ```reg_weight```, and ```weight_init_func``` as optional arguments.
* This loss **requires an optimizer**. You need to create an optimizer and pass this loss's parameters to that optimizer. For example:
```python
loss_func = losses.ProxyAnchorLoss(...).to(torch.device('cuda'))
@@ -907,7 +907,7 @@ losses.ProxyNCALoss(num_classes, embedding_size, softmax_scale=1, **kwargs)

**Other info**

* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```weight_regularizer```, ```weight_reg_weight```, and ```weight_init_func``` as optional arguments.
* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```regularizer```, ```reg_weight```, and ```weight_init_func``` as optional arguments.
* This loss **requires an optimizer**. You need to create an optimizer and pass this loss's parameters to that optimizer. For example:
```python
loss_func = losses.ProxyNCALoss(...).to(torch.device('cuda'))
@@ -1027,7 +1027,7 @@ where

**Other info**

* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```weight_regularizer```, ```weight_reg_weight```, and ```weight_init_func``` as optional arguments.
* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```regularizer```, ```reg_weight```, and ```weight_init_func``` as optional arguments.
* This loss **requires an optimizer**. You need to create an optimizer and pass this loss's parameters to that optimizer. For example:
```python
loss_func = losses.SoftTripleLoss(...).to(torch.device('cuda'))
@@ -1068,7 +1068,7 @@ See [LargeMarginSoftmaxLoss](losses.md#largemarginsoftmaxloss)

**Other info**

* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```weight_regularizer```, ```weight_reg_weight```, and ```weight_init_func``` as optional arguments.
* This also extends [WeightRegularizerMixin](losses.md#weightregularizermixin), so it accepts ```regularizer```, ```reg_weight```, and ```weight_init_func``` as optional arguments.
* This loss **requires an optimizer**. You need to create an optimizer and pass this loss's parameters to that optimizer. For example:
```python
loss_func = losses.SphereFaceLoss(...).to(torch.device('cuda'))
@@ -1230,14 +1230,14 @@ complete_loss = losses.MultipleLosses([main_loss, var_loss], weights=[1, 0.5])
## WeightRegularizerMixin
Losses can extend this class in addition to BaseMetricLossFunction. You should extend this class if your loss function contains a learnable weight matrix.
```python
losses.WeightRegularizerMixin(weight_init_func=None, weight_regularizer=None, weight_reg_weight=1, **kwargs)
losses.WeightRegularizerMixin(weight_init_func=None, regularizer=None, reg_weight=1, **kwargs)
```

**Parameters**:

* **weight_init_func**: A [TorchInitWrapper](common_functions.md#torchinitwrapper) object, which will be used to initialize the weights of the loss function.
* **weight_regularizer**: The [regularizer](regularizers.md) to apply to the loss's learned weights.
* **weight_reg_weight**: The amount the regularization loss will be multiplied by.
* **regularizer**: The [regularizer](regularizers.md) to apply to the loss's learned weights.
* **reg_weight**: The amount the regularization loss will be multiplied by.
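
For example, a weight regularizer might be passed to one of the losses listed under "Extended by:" below (a sketch that assumes this pull request's renamed arguments; ```RegularFaceRegularizer``` is an existing regularizer used purely for illustration):
```python
import torch
from pytorch_metric_learning import losses, regularizers
from pytorch_metric_learning.utils import common_functions as c_f

loss_func = losses.ArcFaceLoss(
    num_classes=100,
    embedding_size=128,
    regularizer=regularizers.RegularFaceRegularizer(),  # applied to the learned weight matrix
    reg_weight=0.1,
    weight_init_func=c_f.TorchInitWrapper(torch.nn.init.xavier_normal_),
)
```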

Extended by:

30 changes: 18 additions & 12 deletions src/pytorch_metric_learning/losses/base_metric_loss_function.py
@@ -1,13 +1,19 @@
import inspect
import re

from ..utils import common_functions as c_f
from ..utils.module_with_records_and_reducer import ModuleWithRecordsReducerAndDistance
from . import mixins
from .mixins import EmbeddingRegularizerMixin


class BaseMetricLossFunction(
EmbeddingRegularizerMixin, ModuleWithRecordsReducerAndDistance
):
class BaseMetricLossFunction(ModuleWithRecordsReducerAndDistance):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.emb_loss_regularizer = EmbeddingRegularizerMixin(
**kwargs
) # Composition avoids multiple-inheritance errors: a loss that also inherits from a RegularizerMixin subclass no longer affects the MRO

def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
"""
This has to be implemented and is what actually computes the loss.
@@ -34,7 +40,9 @@ def forward(
loss_dict = self.compute_loss(
embeddings, labels, indices_tuple, ref_emb, ref_labels
)
self.add_embedding_regularization_to_loss_dict(loss_dict, embeddings)
self.emb_loss_regularizer.add_embedding_regularization_to_loss_dict(
loss_dict, embeddings
)
return self.reducer(loss_dict, embeddings, labels)

def zero_loss(self):
@@ -50,12 +58,10 @@ def sub_loss_names(self):
return self._sub_loss_names() + self.all_regularization_loss_names()

def all_regularization_loss_names(self):
reg_names = []
reg_loss_names = []
for base_class in inspect.getmro(self.__class__):
base_class_name = base_class.__name__
mixin_keyword = "RegularizerMixin"
if base_class_name.endswith(mixin_keyword):
descriptor = base_class_name.replace(mixin_keyword, "").lower()
if getattr(self, "{}_regularizer".format(descriptor)):
reg_names.extend(base_class.regularization_loss_names(self))
return reg_names
if base_class.__module__ == mixins.__name__:
m = re.search(r"(\w+)RegularizerMixin", base_class.__name__)
if m is not None:
reg_loss_names.append(m.group(1).lower())
return reg_loss_names
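As a quick sanity check of the name extraction above, the regular expression maps mixin class names to the lowercase prefixes collected by `all_regularization_loss_names` (a standalone sketch with illustrative class names):
```python
import re

for name in ["EmbeddingRegularizerMixin", "WeightRegularizerMixin", "BaseMetricLossFunction"]:
    m = re.search(r"(\w+)RegularizerMixin", name)
    if m is not None:
        print(m.group(1).lower())  # prints "embedding", then "weight"; the last name does not match
```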
6 changes: 5 additions & 1 deletion src/pytorch_metric_learning/losses/margin_loss.py
@@ -36,7 +36,11 @@ def compute_loss(self, embeddings, labels, indices_tuple, ref_emb, ref_labels):
if len(anchor_idx) == 0:
return self.zero_losses()

beta = self.beta if len(self.beta) == 1 else self.beta[labels[anchor_idx]]
beta = (
self.beta
if len(self.beta) == 1
else self.beta[labels[anchor_idx].to("cpu")]
) # self.beta may live on the CPU, so indexing it with GPU labels would raise an error
beta = c_f.to_device(beta, device=embeddings.device, dtype=embeddings.dtype)

mat = self.distance(embeddings, ref_emb)
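A minimal illustration of the device mismatch that the `.to("cpu")` cast above works around (a standalone sketch; it assumes a CUDA device is available):
```python
import torch

beta = torch.nn.Parameter(torch.full((10,), 1.2))    # per-class margins typically live on the CPU
labels = torch.randint(0, 10, (8,), device="cuda")
# beta[labels]                 # raises: indices must be on the same device as the indexed tensor
per_anchor_beta = beta[labels.to("cpu")]              # indexing with CPU labels works
```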
103 changes: 49 additions & 54 deletions src/pytorch_metric_learning/losses/mixins.py
@@ -1,72 +1,67 @@
from typing import Dict

import torch

from ..utils import common_functions as c_f

SUPPORTED_REGULARIZATION_TYPES = ["custom", "weight", "embedding"]

class WeightMixin:
def __init__(self, weight_init_func=None, **kwargs):
super().__init__(**kwargs)
self.weight_init_func = weight_init_func
if self.weight_init_func is None:
self.weight_init_func = self.get_default_weight_init_func()

def get_default_weight_init_func(self):
return c_f.TorchInitWrapper(torch.nn.init.normal_)
class RegularizerMixin:
"""Base class for regularization losses.
regularizer: a callable or `nn.Module` that maps the input data to a single number or a single-element `torch.Tensor`
"""


class WeightRegularizerMixin(WeightMixin):
def __init__(self, weight_regularizer=None, weight_reg_weight=1, **kwargs):
self.weight_regularizer = (
weight_regularizer is not None
) # hack needed to know whether reg will be in sub-loss names
super().__init__(**kwargs)
self.weight_regularizer = weight_regularizer
self.weight_reg_weight = weight_reg_weight
if self.weight_regularizer is not None:
def __init__(self, regularizer=None, reg_weight=1, type="custom", **kwargs):
self.check_type(type)
self.regularizer = regularizer if regularizer is not None else (lambda data: 0)
self.reg_weight = reg_weight
if regularizer is not None:
self.add_to_recordable_attributes(
list_of_names=["weight_reg_weight"], is_stat=False
list_of_names=[f"{type}_reg_weight"], is_stat=False
)

def weight_regularization_loss(self, weights):
if self.weight_regularizer is None:
loss = 0
else:
loss = self.weight_regularizer(weights) * self.weight_reg_weight
return {"losses": loss, "indices": None, "reduction_type": "already_reduced"}
def regularization_loss(self, data):
loss = self.regularizer(data) * self.reg_weight
return loss

def add_regularization_to_loss_dict(self, loss_dict: Dict[str, Dict], data):
loss_dict[self.reg_loss_type] = {
"losses": self.regularization_loss(data),
"indices": None,
"reduction_type": "already_reduced",
}

def check_type(self, type: str):
if type not in SUPPORTED_REGULARIZATION_TYPES:
raise ValueError(
f"Type provided not supported. Supported types are {', '.join(SUPPORTED_REGULARIZATION_TYPES)}, given type is {type}."
)
self.reg_loss_type = f"{type}_reg_loss"

def add_weight_regularization_to_loss_dict(self, loss_dict, weights):
if self.weight_regularizer is not None:
loss_dict["weight_reg_loss"] = self.weight_regularization_loss(weights)

def regularization_loss_names(self):
return ["weight_reg_loss"]
def get_default_weight_init_func():
return c_f.TorchInitWrapper(torch.nn.init.normal_)


class EmbeddingRegularizerMixin:
def __init__(self, embedding_regularizer=None, embedding_reg_weight=1, **kwargs):
self.embedding_regularizer = (
embedding_regularizer is not None
) # hack needed to know whether reg will be in sub-loss names
class WeightRegularizerMixin(RegularizerMixin):
def __init__(self, weight_init_func=None, **kwargs):
kwargs["type"] = "weight"
super().__init__(**kwargs)
self.embedding_regularizer = embedding_regularizer
self.embedding_reg_weight = embedding_reg_weight
if self.embedding_regularizer is not None:
self.add_to_recordable_attributes(
list_of_names=["embedding_reg_weight"], is_stat=False
)
self.weight_init_func = (
weight_init_func
if weight_init_func is not None
else get_default_weight_init_func()
)

def embedding_regularization_loss(self, embeddings):
if self.embedding_regularizer is None:
loss = 0
else:
loss = self.embedding_regularizer(embeddings) * self.embedding_reg_weight
return {"losses": loss, "indices": None, "reduction_type": "already_reduced"}
def add_weight_regularization_to_loss_dict(self, loss_dict, weights):
self.add_regularization_to_loss_dict(loss_dict, weights)

def add_embedding_regularization_to_loss_dict(self, loss_dict, embeddings):
if self.embedding_regularizer is not None:
loss_dict["embedding_reg_loss"] = self.embedding_regularization_loss(
embeddings
)

def regularization_loss_names(self):
return ["embedding_reg_loss"]
class EmbeddingRegularizerMixin(RegularizerMixin):
def __init__(self, **kwargs):
kwargs["type"] = "embedding"
super().__init__(**kwargs)

def add_embedding_regularization_to_loss_dict(self, loss_dict, embeddings):
self.add_regularization_to_loss_dict(loss_dict, embeddings)
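For reference, the entry that `add_regularization_to_loss_dict` produces has the same shape as the library's other sub-losses. A self-contained sketch of a `weight_reg_loss` entry (the squared-Frobenius regularizer and the weight matrix below are made up for illustration):
```python
import torch

W = torch.randn(10, 64)                    # a hypothetical learned weight matrix
regularizer = lambda w: (w ** 2).sum()     # any callable returning a scalar works
reg_weight = 0.1

loss_dict = {
    "weight_reg_loss": {
        "losses": regularizer(W) * reg_weight,  # a single, already-reduced value
        "indices": None,
        "reduction_type": "already_reduced",
    }
}
```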
2 changes: 1 addition & 1 deletion src/pytorch_metric_learning/losses/vicreg_loss.py
@@ -11,7 +11,7 @@ def __init__(
):
if "distance" in kwargs:
raise ValueError("VICRegLoss cannot use a distance function")
if "embedding_regularizer" in kwargs:
if "regularizer" in kwargs:
raise ValueError("VICRegLoss cannot use a regularizer")
super().__init__(**kwargs)
"""
2 changes: 2 additions & 0 deletions src/pytorch_metric_learning/reducers/avg_non_zero_reducer.py
@@ -2,5 +2,7 @@


class AvgNonZeroReducer(ThresholdReducer):
"""Equivalent to ThresholdReducer with `low=0`"""

def __init__(self, **kwargs):
super().__init__(low=0, **kwargs)
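A small sketch of the equivalence stated in the docstring (both reducers average only the losses greater than 0):
```python
from pytorch_metric_learning import reducers

r1 = reducers.AvgNonZeroReducer()
r2 = reducers.ThresholdReducer(low=0)  # behaves the same as r1
```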
4 changes: 2 additions & 2 deletions src/pytorch_metric_learning/reducers/base_reducer.py
@@ -15,7 +15,7 @@ def forward(self, loss_dict, embeddings, labels):
loss_name = list(loss_dict.keys())[0]
loss_info = loss_dict[loss_name]
losses, loss_indices, reduction_type, kwargs = self.unpack_loss_info(loss_info)
loss_val = self.reduce_the_loss(
loss_val = self.reduce_loss( # Similar to compute_loss
losses, loss_indices, reduction_type, kwargs, embeddings, labels
)
return loss_val
Expand All @@ -28,7 +28,7 @@ def unpack_loss_info(self, loss_info):
{},
)

def reduce_the_loss(
def reduce_loss( # Similar to compute_loss
self, losses, loss_indices, reduction_type, kwargs, embeddings, labels
):
self.set_losses_size_stat(losses)
28 changes: 9 additions & 19 deletions src/pytorch_metric_learning/reducers/class_weighted_reducer.py
@@ -1,28 +1,18 @@
import torch

from ..utils import common_functions as c_f
from .base_reducer import BaseReducer
from .threshold_reducer import ThresholdReducer


class ClassWeightedReducer(ThresholdReducer):
"""It weights the losses with user-specified weights and then takes the average.

Subclass of ThresholdReducer, therefore it is possible to specify `low` and `high` hyperparameters.
"""

class ClassWeightedReducer(BaseReducer):
def __init__(self, weights, **kwargs):
super().__init__(**kwargs)
self.weights = weights

def element_reduction(self, losses, loss_indices, embeddings, labels):
return self.element_reduction_helper(losses, loss_indices, labels)

def pos_pair_reduction(self, losses, loss_indices, embeddings, labels):
return self.element_reduction_helper(losses, loss_indices[0], labels)

# based on anchor label
def neg_pair_reduction(self, losses, loss_indices, embeddings, labels):
return self.element_reduction_helper(losses, loss_indices[0], labels)

# based on anchor label
def triplet_reduction(self, losses, loss_indices, embeddings, labels):
return self.element_reduction_helper(losses, loss_indices[0], labels)

def element_reduction_helper(self, losses, indices, labels):
self.weights = c_f.to_device(self.weights, losses, dtype=losses.dtype)
return torch.mean(losses * self.weights[labels[indices]])
losses = losses * self.weights[labels[loss_indices]]
return super().element_reduction(losses, loss_indices, embeddings, labels)
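A hedged usage sketch of the simplified reducer (the weights tensor and the loss are illustrative; `weights` needs one entry per class label, and `low`/`high` are accepted here because of the `ThresholdReducer` subclassing in this diff):
```python
import torch
from pytorch_metric_learning import losses, reducers

class_weights = torch.tensor([1.0, 2.0, 0.5])  # one weight per class
reducer = reducers.ClassWeightedReducer(weights=class_weights, low=0)
loss_func = losses.ContrastiveLoss(reducer=reducer)
```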