
Commit e4683de

Merge branch 'master' into cocomap
2 parents e2ac8ee + 9b4bef8

6 files changed, +185 -52 lines changed


ignite/handlers/param_scheduler.py (+59 -5)

@@ -10,7 +10,7 @@
 from typing import Any, cast, Dict, List, Mapping, Optional, Sequence, Tuple, Type, Union
 
 import torch
-from torch.optim.lr_scheduler import ReduceLROnPlateau
+from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ReduceLROnPlateau
 from torch.optim.optimizer import Optimizer
 
 # https://github.com/pytorch/ignite/issues/2773
@@ -792,6 +792,57 @@ def simulate_values(  # type: ignore[override]
         return output
 
 
+class _CosineAnnealingWarmRestarts:
+    def __init__(self, lr_scheduler: CosineAnnealingWarmRestarts):
+        self._lr_scheduler = lr_scheduler
+
+    @property
+    def last_epoch(self) -> int:
+        return self._lr_scheduler.last_epoch
+
+    @last_epoch.setter
+    def last_epoch(self, value: int) -> None:
+        self._lr_scheduler.last_epoch = value
+
+    @property
+    def optimizer(self) -> torch.optim.Optimizer:
+        return self._lr_scheduler.optimizer
+
+    def get_lr(self, epoch: Optional[int] = None) -> List[float]:
+        T_mult = self._lr_scheduler.T_mult
+        eta_min = self._lr_scheduler.eta_min
+
+        if epoch is None and self.last_epoch < 0:
+            epoch = 0
+        if epoch is None:
+            epoch = self.last_epoch + 1
+            self._lr_scheduler.T_cur = self._lr_scheduler.T_cur + 1
+            if self._lr_scheduler.T_cur >= self._lr_scheduler.T_i:
+                self._lr_scheduler.T_cur = self._lr_scheduler.T_cur - self._lr_scheduler.T_i
+                self._lr_scheduler.T_i = self._lr_scheduler.T_i * T_mult
+        else:
+            if epoch < 0:
+                raise ValueError("Expected non-negative epoch, but got {}".format(epoch))
+            if epoch >= self._lr_scheduler.T_0:
+                if T_mult == 1:
+                    self._lr_scheduler.T_cur = epoch % self._lr_scheduler.T_0
+                else:
+                    n = int(math.log((epoch / self._lr_scheduler.T_0 * (T_mult - 1) + 1), T_mult))
+                    self._lr_scheduler.T_cur = epoch - self._lr_scheduler.T_0 * (T_mult**n - 1) / (T_mult - 1)
+                    self._lr_scheduler.T_i = self._lr_scheduler.T_0 * T_mult**n
+            else:
+                self._lr_scheduler.T_i = self._lr_scheduler.T_0
+                self._lr_scheduler.T_cur = epoch
+
+        self.last_epoch = math.floor(epoch)
+
+        return [
+            eta_min
+            + (base_lr - eta_min) * (1 + math.cos(math.pi * self._lr_scheduler.T_cur / self._lr_scheduler.T_i)) / 2
+            for base_lr in self._lr_scheduler.base_lrs
+        ]
+
+
 class LRScheduler(ParamScheduler):
     """A wrapper class to call `torch.optim.lr_scheduler` objects as `ignite` handlers.
 
@@ -853,7 +904,10 @@ def __init__(
                 f"but given {type(lr_scheduler)}"
             )
 
-        self.lr_scheduler = lr_scheduler
+        self.lr_scheduler: Union[PyTorchLRScheduler, _CosineAnnealingWarmRestarts] = lr_scheduler
+        if isinstance(lr_scheduler, CosineAnnealingWarmRestarts):
+            self.lr_scheduler = _CosineAnnealingWarmRestarts(lr_scheduler)
+
         super(LRScheduler, self).__init__(
             optimizer=self.lr_scheduler.optimizer,
             param_name="lr",
@@ -863,7 +917,7 @@ def __init__(
             warnings.warn(
                 "Please make sure to attach scheduler to Events.ITERATION_COMPLETED "
                 "instead of Events.ITERATION_STARTED to make sure to use "
-                "the first lr value from the optimizer, otherwise it is will be skipped"
+                "the first lr value from the optimizer, otherwise it will be skipped"
             )
             self.lr_scheduler.last_epoch += 1
 
@@ -876,9 +930,9 @@ def __call__(self, engine: Optional[Engine], name: Optional[str] = None) -> None
     def get_param(self) -> Union[float, List[float]]:
         """Method to get current optimizer's parameter value"""
         # Emulate context manager for pytorch>=1.4
-        self.lr_scheduler._get_lr_called_within_step = True  # type: ignore[attr-defined]
+        self.lr_scheduler._get_lr_called_within_step = True  # type: ignore[union-attr]
         lr_list = cast(List[float], self.lr_scheduler.get_lr())
-        self.lr_scheduler._get_lr_called_within_step = False  # type: ignore[attr-defined]
+        self.lr_scheduler._get_lr_called_within_step = False  # type: ignore[union-attr]
         if len(lr_list) == 1:
             return lr_list[0]
         else:
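Note: with the `_CosineAnnealingWarmRestarts` wrapper above, a `CosineAnnealingWarmRestarts` scheduler can now be passed directly to `LRScheduler`. A minimal sketch of that usage, stepping the handler manually the way the new test further below does (the optimizer, tensor and hyperparameter values are illustrative, not from this PR):

import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from ignite.handlers.param_scheduler import LRScheduler

param = torch.zeros([1], requires_grad=True)  # dummy parameter
optimizer = torch.optim.SGD([param], lr=0.2)
torch_scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=12, T_mult=2)

# LRScheduler now transparently wraps the scheduler in _CosineAnnealingWarmRestarts.
scheduler = LRScheduler(torch_scheduler)

lrs = []
for _ in range(48):
    scheduler(None)  # normally attached to an engine event instead of being called by hand
    lrs.append(optimizer.param_groups[0]["lr"])
# The collected lr values follow the cosine curve and restart after T_0, then T_0 * T_mult, ... calls.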

ignite/metrics/metric.py (+3 -3)

@@ -219,7 +219,7 @@ def __init__(
     @abstractmethod
     def reset(self) -> None:
         """
-        Resets the metric to it's initial state.
+        Resets the metric to its initial state.
 
         By default, this is called at the start of each epoch.
         """
@@ -240,7 +240,7 @@ def update(self, output: Any) -> None:
     @abstractmethod
     def compute(self) -> Any:
         """
-        Computes the metric based on it's accumulated state.
+        Computes the metric based on its accumulated state.
 
         By default, this is called at the end of each epoch.
 
@@ -273,7 +273,7 @@ def iteration_completed(self, engine: Engine) -> None:
 
         Note:
            ``engine.state.output`` is used to compute metric values.
-           The majority of implemented metrics accepts the following formats for ``engine.state.output``:
+           The majority of implemented metrics accept the following formats for ``engine.state.output``:
            ``(y_pred, y)`` or ``{'y_pred': y_pred, 'y': y}``. ``y_pred`` and ``y`` can be torch tensors or
            list of tensors/numbers if applicable.
 
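The docstring fixes above touch the core `Metric` contract: `reset` puts the metric back into its initial state, `update` consumes `engine.state.output` (typically `(y_pred, y)`), and `compute` turns the accumulated state into a value. A minimal, hypothetical custom metric just to make that contract concrete (the class name and logic are illustrative, not part of this PR):

from ignite.metrics import Metric


class CorrectCount(Metric):
    def reset(self) -> None:
        # Resets the metric to its initial state; called at the start of each epoch by default.
        self._num_correct = 0

    def update(self, output) -> None:
        # `output` is ``engine.state.output`` in the ``(y_pred, y)`` format described above.
        y_pred, y = output
        self._num_correct += (y_pred.argmax(dim=1) == y).sum().item()

    def compute(self) -> int:
        # Computes the metric based on its accumulated state; called at the end of each epoch by default.
        return self._num_correct


# Attached like any other metric: CorrectCount().attach(engine, "correct_count")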

ignite/metrics/precision.py (+69 -36)

@@ -61,8 +61,8 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens
             num_classes = 2 if self._type == "binary" else y_pred.size(1)
             if self._type == "multiclass" and y.max() + 1 > num_classes:
                 raise ValueError(
-                    f"y_pred contains less classes than y. Number of predicted classes is {num_classes}"
-                    f" and element in y has invalid class = {y.max().item() + 1}."
+                    f"y_pred contains fewer classes than y. Number of classes in the prediction is {num_classes}"
+                    f" and an element in y has invalid class = {y.max().item() + 1}."
                 )
             y = y.view(-1)
             if self._type == "binary" and self._average is False:
@@ -86,30 +86,32 @@ def _prepare_output(self, output: Sequence[torch.Tensor]) -> Sequence[torch.Tens
 
     @reinit__is_reduced
     def reset(self) -> None:
-        # `numerator`, `denominator` and `weight` are three variables chosen to be abstract
-        # representatives of the ones that are measured for cases with different `average` parameters.
-        # `weight` is only used when `average='weighted'`. Actual value of these three variables is
-        # as follows.
-        #
-        # average='samples':
-        #   numerator (torch.Tensor): sum of metric value for samples
-        #   denominator (int): number of samples
-        #
-        # average='weighted':
-        #   numerator (torch.Tensor): number of true positives per class/label
-        #   denominator (torch.Tensor): number of predicted(for precision) or actual(for recall)
-        #     positives per class/label
-        #   weight (torch.Tensor): number of actual positives per class
-        #
-        # average='micro':
-        #   numerator (torch.Tensor): sum of number of true positives for classes/labels
-        #   denominator (torch.Tensor): sum of number of predicted(for precision) or actual(for recall) positives
-        #     for classes/labels
-        #
-        # average='macro' or boolean or None:
-        #   numerator (torch.Tensor): number of true positives per class/label
-        #   denominator (torch.Tensor): number of predicted(for precision) or actual(for recall)
-        #     positives per class/label
+        """
+        `numerator`, `denominator` and `weight` are three variables chosen to be abstract
+        representatives of the ones that are measured for cases with different `average` parameters.
+        `weight` is only used when `average='weighted'`. Actual value of these three variables is
+        as follows.
+
+        average='samples':
+          numerator (torch.Tensor): sum of metric value for samples
+          denominator (int): number of samples
+
+        average='weighted':
+          numerator (torch.Tensor): number of true positives per class/label
+          denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) positives per
+            class/label.
+          weight (torch.Tensor): number of actual positives per class
+
+        average='micro':
+          numerator (torch.Tensor): sum of number of true positives for classes/labels
+          denominator (torch.Tensor): sum of number of predicted(for precision) or actual(for recall) positives for
+            classes/labels.
+
+        average='macro' or boolean or None:
+          numerator (torch.Tensor): number of true positives per class/label
+          denominator (torch.Tensor): number of predicted(for precision) or actual(for recall) positives per
+            class/label.
+        """
 
         self._numerator: Union[int, torch.Tensor] = 0
         self._denominator: Union[int, torch.Tensor] = 0
@@ -120,16 +122,20 @@ def reset(self) -> None:
 
     @sync_all_reduce("_numerator", "_denominator")
     def compute(self) -> Union[torch.Tensor, float]:
-        # Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows.
-        #
-        # .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight
-        #
-        # wherein `weight` is the internal variable `weight` for `'weighted'` option and :math:`1/C`
-        # for the `macro` one. :math:`C` is the number of classes/labels.
-        #
-        # Return value of the metric for `average` options `'micro'`, `'samples'`, `False` and None is as follows.
-        #
-        # .. math:: \text{Precision/Recall} = \frac{ numerator }{ denominator }
+        r"""
+        Return value of the metric for `average` options `'weighted'` and `'macro'` is computed as follows.
+
+        .. math::
+            \text{Precision/Recall} = \frac{ numerator }{ denominator } \cdot weight
+
+        wherein `weight` is the internal variable `_weight` for `'weighted'` option and :math:`1/C`
+        for the `macro` one. :math:`C` is the number of classes/labels.
+
+        Return value of the metric for `average` options `'micro'`, `'samples'`, `False` and None is as follows.
+
+        .. math::
+            \text{Precision/Recall} = \frac{ numerator }{ denominator }
+        """
 
         if not self._updated:
             raise NotComputableError(
@@ -367,6 +373,33 @@ def thresholded_output_transform(output):
 
     @reinit__is_reduced
     def update(self, output: Sequence[torch.Tensor]) -> None:
+        r"""
+        Update the metric state using prediction and target.
+
+        Args:
+            output: a binary tuple of tensors (y_pred, y) whose shapes follow the table below. N stands for the batch
+                dimension, `...` for possible additional dimensions and C for class dimension.
+
+                .. list-table::
+                    :widths: 20 10 10 10
+                    :header-rows: 1
+
+                    * - Output member\\Data type
+                      - Binary
+                      - Multiclass
+                      - Multilabel
+                    * - y_pred
+                      - (N, ...)
+                      - (N, C, ...)
+                      - (N, C, ...)
+                    * - y
+                      - (N, ...)
+                      - (N, ...)
+                      - (N, C, ...)
+
+            For binary and multilabel data, both y and y_pred should consist of 0's and 1's, but for multiclass
+            data, y_pred and y should consist of probabilities and integers respectively.
+        """
         self._check_shape(output)
         self._check_type(output)
         y_pred, y, correct = self._prepare_output(output)
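As a quick numeric check of the formula moved into the `compute` docstring above (values are made up for illustration, not from the PR): for `average='macro'`, `weight` is 1/C, so the result is simply the per-class ratio averaged over the C classes.

import torch

numerator = torch.tensor([3.0, 1.0, 2.0])    # true positives per class
denominator = torch.tensor([4.0, 2.0, 2.0])  # predicted positives per class (precision case)
C = numerator.numel()

macro_precision = ((numerator / denominator) * (1.0 / C)).sum()
print(macro_precision)  # tensor(0.7500), i.e. the mean of [0.75, 0.50, 1.00]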

ignite/metrics/running_average.py (+2)

@@ -143,6 +143,8 @@ def attach(self, engine: Engine, name: str, _usage: Union[str, MetricUsage] = Ep
         if self.epoch_bound:
             # restart average every epoch
             engine.add_event_handler(Events.EPOCH_STARTED, self.started)
+        else:
+            engine.add_event_handler(Events.STARTED, self.started)
         # compute metric
         engine.add_event_handler(Events.ITERATION_COMPLETED, self.iteration_completed)
         # apply running average
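The added `else` branch means a `RunningAverage` that is not epoch-bound is now reset on `Events.STARTED`, i.e. once per call to `engine.run`, rather than never. A hedged usage sketch (the engine, data and `epoch_bound=False` argument follow this diff; the concrete values are illustrative):

from ignite.engine import Engine
from ignite.metrics import RunningAverage

trainer = Engine(lambda engine, batch: batch)  # the step "output" is just the batch value

# Not epoch-bound: the running average now restarts when a run starts, not at every epoch.
RunningAverage(output_transform=lambda out: out, epoch_bound=False).attach(trainer, "running_loss")

trainer.run([1.0, 2.0, 3.0], max_epochs=2)
print(trainer.state.metrics["running_loss"])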

tests/ignite/handlers/test_param_scheduler.py (+44 -2)

@@ -3,7 +3,7 @@
 import numpy as np
 import pytest
 import torch
-from torch.optim.lr_scheduler import ExponentialLR, StepLR
+from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts, ExponentialLR, StepLR
 
 from ignite.engine import Engine, Events
 from ignite.handlers.param_scheduler import (
@@ -650,7 +650,7 @@ def test_lr_scheduler(torch_lr_scheduler_cls, kwargs):
     state_dict1 = scheduler1.state_dict()
 
     torch_lr_scheduler2 = torch_lr_scheduler_cls(optimizer=optimizer2, **kwargs)
-    with pytest.warns(UserWarning, match=r"the first lr value from the optimizer, otherwise it is will be skipped"):
+    with pytest.warns(UserWarning, match=r"the first lr value from the optimizer, otherwise it will be skipped"):
         scheduler2 = LRScheduler(torch_lr_scheduler2, use_legacy=True)
     state_dict2 = scheduler2.state_dict()
 
@@ -1362,3 +1362,45 @@ def test_reduce_lr_on_plateau_scheduler_asserts():
     with pytest.raises(ValueError, match=r"Length of argument metric_values should be equal to num_events."):
         metric_values = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
         ReduceLROnPlateauScheduler.simulate_values(5, metric_values, 0.01)
+
+
+@pytest.mark.parametrize("warmup_end_value", [0.23, None])
+@pytest.mark.parametrize("T_0", [1, 12])
+@pytest.mark.parametrize("T_mult", [1, 3])
+def test_create_lr_scheduler_with_warmup_cosine(warmup_end_value, T_0, T_mult):
+    lr = 0.2
+    steps = 200
+    warm_steps = 50
+    warm_start = 0.023
+
+    def get_optim():
+        t1 = torch.zeros([1], requires_grad=True)
+        return torch.optim.SGD([t1], lr=lr)
+
+    def get_cos_shed():
+        return CosineAnnealingWarmRestarts(optimizer, T_0=T_0, T_mult=T_mult)
+
+    optimizer = get_optim()
+    scheduler = get_cos_shed()
+    cosine_lrs = []
+    for i in range(steps):
+        cosine_lrs.append(optimizer.param_groups[0]["lr"])
+        scheduler.step()
+
+    optimizer = get_optim()
+    scheduler = create_lr_scheduler_with_warmup(
+        get_cos_shed(), warmup_start_value=warm_start, warmup_end_value=warmup_end_value, warmup_duration=warm_steps
+    )
+
+    warm_lrs = []
+    real_warm_steps = warm_steps if warmup_end_value is not None else (warm_steps - 1)
+    for epoch in range(real_warm_steps + steps):
+        scheduler(None)
+        warm_lrs.append(optimizer.param_groups[0]["lr"])
+
+    if warmup_end_value is not None:
+        np.testing.assert_allclose(np.linspace(warm_start, warmup_end_value, warm_steps), warm_lrs[:warm_steps])
+        assert warm_lrs[real_warm_steps:] == cosine_lrs
+    else:
+        np.testing.assert_allclose(np.linspace(warm_start, lr, warm_steps), warm_lrs[:warm_steps])
+        assert warm_lrs[real_warm_steps:] == cosine_lrs
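Outside the test, the same warmup-plus-cosine combination would normally be attached to a trainer instead of being called by hand. A sketch under that assumption (constants copied from the test; the engine, data and event choice are illustrative):

import torch
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts

from ignite.engine import Engine, Events
from ignite.handlers import create_lr_scheduler_with_warmup

param = torch.zeros([1], requires_grad=True)
optimizer = torch.optim.SGD([param], lr=0.2)

scheduler = create_lr_scheduler_with_warmup(
    CosineAnnealingWarmRestarts(optimizer, T_0=12, T_mult=3),
    warmup_start_value=0.023,
    warmup_end_value=0.23,
    warmup_duration=50,
)

trainer = Engine(lambda engine, batch: None)
# Linear warmup for the first 50 iterations, then the cosine-with-restarts schedule takes over.
trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)
trainer.run(range(10), max_epochs=30)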

tests/ignite/metrics/test_running_average.py (+8 -6)

@@ -125,9 +125,9 @@ def test_epoch_unbound():
     batch_size = 10
     n_classes = 10
     data = list(range(n_iters))
-    loss_values = iter(range(n_epochs * n_iters))
-    y_true_batch_values = iter(np.random.randint(0, n_classes, size=(n_epochs * n_iters, batch_size)))
-    y_pred_batch_values = iter(np.random.rand(n_epochs * n_iters, batch_size, n_classes))
+    loss_values = iter(range(2 * n_epochs * n_iters))
+    y_true_batch_values = iter(np.random.randint(0, n_classes, size=(2 * n_epochs * n_iters, batch_size)))
+    y_pred_batch_values = iter(np.random.rand(2 * n_epochs * n_iters, batch_size, n_classes))
 
     def update_fn(engine, batch):
         loss_value = next(loss_values)
@@ -146,9 +146,7 @@ def update_fn(engine, batch):
 
     running_avg_acc = [None]
 
-    @trainer.on(Events.STARTED)
-    def running_avg_output_init(engine):
-        engine.state.running_avg_output = None
+    trainer.state.running_avg_output = None
 
     @trainer.on(Events.ITERATION_COMPLETED, running_avg_acc)
     def manual_running_avg_acc(engine, running_avg_acc):
@@ -187,6 +185,10 @@ def assert_equal_running_avg_output_values(engine):
 
     trainer.run(data, max_epochs=3)
 
+    running_avg_acc[0] = None
+    trainer.state.running_avg_output = None
+    trainer.run(data, max_epochs=3)
+
 
 def test_multiple_attach():
     n_iters = 100
