
Commit 11a1fba

MarcBresson and vfdev-5 authored
feat: improve how device switch is handled between the metric device and the input tensors device (#3043)
* refactor: remove outdated code and issue a warning if two tensors are on separate devices.
* feat: prioritize computation on GPU devices over CPUs. If either the metric device or the update input device is a GPU, put the other one on the GPU as well.
* fix: use a temp var that will be moved with y_pred. The comparison with self._device was not possible because it can be created with `torch.device("cuda")`, which is not equal to `torch.device("cuda:0")`, the device of a tensor created with `torch.device("cuda")`. This change has a bigger performance hit when self._kernel is not on the same device as y_pred, as it needs to be moved onto y_pred's device every time update() is called.
* test: add a test with the metric and y_pred on different devices
* feat: move self._kernel directly and issue a warning only when not all y_pred tensors are on the same device
* feat: adapt test to new behaviour
* feat: keep the accumulation on the same device as self._kernel
* feat: move the accumulation alongside self._kernel
* feat: allow different channel numbers
* style: format using the run_code_style script
* style: add a line break to conform to E501
* fix: use torch.empty to avoid type incompatibility between None and Tensor with mypy
* feat: only operate on self._kernel, keep the accumulation on the user's selected device
* test: add a variable-channel test and factorize the code
* refactor: remove redundant line between init and reset
* refactor: elif comparison and replace RuntimeWarning with UserWarning

Co-authored-by: vfdev <[email protected]>

* refactor: set _kernel in __init__ and manually format to pass E501
* test: adapt test to new UserWarning
* test: remove skips
* refactor: use None instead of torch.empty
* style: reorder imports
* refactor: rename channel to nb_channel
* Fixed failing test_distrib_accumulator_device

---------

Co-authored-by: vfdev <[email protected]>
1 parent 86c2a1d · commit 11a1fba
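The `torch.device("cuda")` vs `torch.device("cuda:0")` pitfall described in the fix bullet above is easy to reproduce. A minimal sketch (illustrative, not part of the commit):

import torch

# An index-less "cuda" device does not compare equal to the fully
# indexed device that tensors actually report.
assert torch.device("cuda") != torch.device("cuda:0")

if torch.cuda.is_available():
    # Yet a tensor created with the index-less device lands on cuda:0,
    # so comparing tensor.device against self._device could never match.
    t = torch.empty(1, device=torch.device("cuda"))
    assert t.device == torch.device("cuda:0")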

File tree

ignite/metrics/ssim.py
tests/ignite/metrics/test_ssim.py

2 files changed: +119 −18 lines changed

ignite/metrics/ssim.py

Lines changed: 22 additions & 7 deletions
@@ -1,4 +1,5 @@
-from typing import Callable, Sequence, Union
+import warnings
+from typing import Callable, Optional, Sequence, Union
 
 import torch
 import torch.nn.functional as F
@@ -102,7 +103,8 @@ def __init__(
         self.c2 = (k2 * data_range) ** 2
         self.pad_h = (self.kernel_size[0] - 1) // 2
         self.pad_w = (self.kernel_size[1] - 1) // 2
-        self._kernel = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma)
+        self._kernel_2d = self._gaussian_or_uniform_kernel(kernel_size=self.kernel_size, sigma=self.sigma)
+        self._kernel: Optional[torch.Tensor] = None
 
     @reinit__is_reduced
     def reset(self) -> None:
@@ -155,9 +157,22 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
                 f"Expected y_pred and y to have BxCxHxW shape. Got y_pred: {y_pred.shape} and y: {y.shape}."
             )
 
-        channel = y_pred.size(1)
-        if len(self._kernel.shape) < 4:
-            self._kernel = self._kernel.expand(channel, 1, -1, -1).to(device=y_pred.device)
+        nb_channel = y_pred.size(1)
+        if self._kernel is None or self._kernel.shape[0] != nb_channel:
+            self._kernel = self._kernel_2d.expand(nb_channel, 1, -1, -1)
+
+        if y_pred.device != self._kernel.device:
+            if self._kernel.device == torch.device("cpu"):
+                self._kernel = self._kernel.to(device=y_pred.device)
+
+            elif y_pred.device == torch.device("cpu"):
+                warnings.warn(
+                    "y_pred tensor is on cpu device but previous computation was on another device: "
+                    f"{self._kernel.device}. To avoid having a performance hit, please ensure that all "
+                    "y and y_pred tensors are on the same device.",
+                )
+                y_pred = y_pred.to(device=self._kernel.device)
+                y = y.to(device=self._kernel.device)
 
         y_pred = F.pad(y_pred, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect")
        y = F.pad(y, [self.pad_w, self.pad_w, self.pad_h, self.pad_h], mode="reflect")
@@ -166,7 +181,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
             self._kernel = self._kernel.to(dtype=y_pred.dtype)
 
         input_list = [y_pred, y, y_pred * y_pred, y * y, y_pred * y]
-        outputs = F.conv2d(torch.cat(input_list), self._kernel, groups=channel)
+        outputs = F.conv2d(torch.cat(input_list), self._kernel, groups=nb_channel)
         batch_size = y_pred.size(0)
         output_list = [outputs[x * batch_size : (x + 1) * batch_size] for x in range(len(input_list))]
 
@@ -184,7 +199,7 @@ def update(self, output: Sequence[torch.Tensor]) -> None:
         b2 = sigma_pred_sq + sigma_target_sq + self.c2
 
         ssim_idx = (a1 * a2) / (b1 * b2)
-        self._sum_of_ssim += torch.mean(ssim_idx, (1, 2, 3), dtype=torch.float64).sum().to(self._device)
+        self._sum_of_ssim += torch.mean(ssim_idx, (1, 2, 3), dtype=torch.float64).sum().to(device=self._device)
 
         self._num_examples += y.shape[0]
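Taken together, the new update() logic moves the kernel to the GPU when either side is on a GPU, and only warns (and moves the inputs) when a CPU batch arrives after GPU computation. A minimal sketch of the resulting behaviour, assuming a CUDA device is available (shapes and data_range below are illustrative, not from the commit):

import torch
from ignite.metrics import SSIM

ssim = SSIM(data_range=1.0)  # metric device defaults to CPU

y_pred = torch.rand(4, 3, 64, 64, device="cuda")
y = y_pred * 0.8

# First update: the per-channel kernel is built lazily from _kernel_2d
# and, being on CPU, follows y_pred onto the GPU.
ssim.update((y_pred, y))
assert ssim._kernel.device == torch.device("cuda:0")

# A later CPU batch emits a UserWarning and is moved to the kernel's
# GPU device instead of dragging the computation back to CPU.
ssim.update((y_pred.cpu(), y.cpu()))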
tests/ignite/metrics/test_ssim.py

Lines changed: 97 additions & 11 deletions
@@ -1,3 +1,5 @@
+from typing import Sequence, Union
+
 import numpy as np
 import pytest
 import torch
@@ -70,25 +72,49 @@ def test_invalid_ssim():
     "shape, kernel_size, gaussian, use_sample_covariance",
     [[(8, 3, 224, 224), 7, False, True], [(12, 3, 28, 28), 11, True, False]],
 )
-def test_ssim(
-    available_device, shape, kernel_size, gaussian, use_sample_covariance, dtype=torch.float32, precision=7e-5
-):
-    y_pred = torch.rand(shape, device=available_device, dtype=dtype)
+def test_ssim(available_device, shape, kernel_size, gaussian, use_sample_covariance):
+    y_pred = torch.rand(shape, device=available_device)
     y = y_pred * 0.8
 
+    compare_ssim_ignite_skiimg(
+        y_pred,
+        y,
+        available_device,
+        kernel_size=kernel_size,
+        gaussian=gaussian,
+        use_sample_covariance=use_sample_covariance,
+    )
+
+
+def compare_ssim_ignite_skiimg(
+    y_pred: torch.Tensor,
+    y: torch.Tensor,
+    device: torch.device,
+    precision: float = 2e-5,  # default to float32 expected precision
+    *,
+    skimg_y_pred: Union[np.ndarray, None] = None,
+    skimg_y: Union[np.ndarray, None] = None,
+    data_range: float = 1.0,
+    kernel_size: Union[int, Sequence[int]] = 11,
+    gaussian: bool = True,
+    use_sample_covariance: bool = False,
+):
     sigma = 1.5
-    data_range = 1.0
-    ssim = SSIM(data_range=data_range, sigma=sigma, device=available_device)
+
+    ssim = SSIM(data_range=data_range, sigma=sigma, device=device)
     ssim.update((y_pred, y))
     ignite_ssim = ssim.compute()
 
     if y_pred.dtype == torch.bfloat16:
         y_pred = y_pred.to(dtype=torch.float16)
 
-    skimg_pred = y_pred.cpu().numpy()
-    skimg_y = skimg_pred * 0.8
+    if skimg_y_pred is None:
+        skimg_y_pred = y_pred.cpu().numpy()
+    if skimg_y is None:
+        skimg_y = skimg_y_pred * 0.8
+
     skimg_ssim = ski_ssim(
-        skimg_pred,
+        skimg_y_pred,
         skimg_y,
         win_size=kernel_size,
         sigma=sigma,
@@ -102,6 +128,43 @@ def test_ssim(
     assert np.allclose(ignite_ssim, skimg_ssim, atol=precision)
 
 
+@pytest.mark.parametrize(
+    "metric_device, y_pred_device",
+    [
+        [torch.device("cpu"), torch.device("cpu")],
+        [torch.device("cpu"), torch.device("cuda")],
+        [torch.device("cuda"), torch.device("cpu")],
+        [torch.device("cuda"), torch.device("cuda")],
+    ],
+)
+def test_ssim_device(available_device, metric_device, y_pred_device):
+    if available_device == "cpu":
+        pytest.skip("This test requires a cuda device.")
+
+    data_range = 1.0
+    sigma = 1.5
+    shape = (12, 5, 256, 256)
+
+    ssim = SSIM(data_range=data_range, sigma=sigma, device=metric_device)
+
+    y_pred = torch.rand(shape, device=y_pred_device)
+    y = y_pred * 0.8
+
+    if metric_device == torch.device("cuda") and y_pred_device == torch.device("cpu"):
+        with pytest.warns(UserWarning):
+            ssim.update((y_pred, y))
+    else:
+        ssim.update((y_pred, y))
+
+    if metric_device == torch.device("cuda") or y_pred_device == torch.device("cuda"):
+        # A tensor will always have the device index set
+        excepted_device = torch.device("cuda:0")
+    else:
+        excepted_device = torch.device("cpu")
+
+    assert ssim._kernel.device == excepted_device
+
+
 def test_ssim_variable_batchsize(available_device):
     # Checks https://github.com/pytorch/ignite/issues/2532
     sigma = 1.5
@@ -128,6 +191,21 @@ def test_ssim_variable_batchsize(available_device):
     assert np.allclose(out, expected)
 
 
+def test_ssim_variable_channel(available_device):
+    y_preds = [
+        torch.rand(12, 5, 28, 28, device=available_device),
+        torch.rand(12, 4, 28, 28, device=available_device),
+        torch.rand(12, 7, 28, 28, device=available_device),
+        torch.rand(12, 3, 28, 28, device=available_device),
+        torch.rand(12, 11, 28, 28, device=available_device),
+        torch.rand(12, 6, 28, 28, device=available_device),
+    ]
+    y_true = [v * 0.8 for v in y_preds]
+
+    for y_pred, y in zip(y_preds, y_true):
+        compare_ssim_ignite_skiimg(y_pred, y, available_device)
+
+
 @pytest.mark.parametrize(
     "dtype, precision", [(torch.bfloat16, 2e-3), (torch.float16, 4e-4), (torch.float32, 2e-5), (torch.float64, 2e-5)]
 )
@@ -136,7 +214,12 @@ def test_cuda_ssim_dtypes(available_device, dtype, precision):
     if available_device == "cpu" and dtype in [torch.float16, torch.bfloat16]:
         pytest.skip(reason=f"Unsupported dtype {dtype} on CPU device")
 
-    test_ssim(available_device, (12, 3, 28, 28), 11, True, False, dtype=dtype, precision=precision)
+    shape = (12, 3, 28, 28)
+
+    y_pred = torch.rand(shape, device=available_device, dtype=dtype)
+    y = y_pred * 0.8
+
+    compare_ssim_ignite_skiimg(y_pred, y, available_device, precision)
 
 
 @pytest.mark.parametrize("metric_device", ["cpu", "process_device"])
@@ -213,7 +296,10 @@ def test_distrib_accumulator_device(distributed, metric_device):
 
     ssim = SSIM(data_range=1.0, device=metric_device)
 
-    for dev in [ssim._device, ssim._kernel.device]:
+    assert ssim._kernel is None
+    assert isinstance(ssim._kernel_2d, torch.Tensor)
+
+    for dev in [ssim._device, ssim._kernel_2d.device]:
         assert dev == metric_device, f"{type(dev)}:{dev} vs {type(metric_device)}:{metric_device}"
 
     y_pred = torch.rand(2, 3, 28, 28, dtype=torch.float, device=device)
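The lazy-kernel invariant checked above in test_distrib_accumulator_device also holds in a plain single-process run; a short sketch (illustrative, not from the commit):

import torch
from ignite.metrics import SSIM

ssim = SSIM(data_range=1.0, device="cpu")

# Before the first update() only the precomputed 2D kernel exists; the
# per-channel kernel is deferred until the channel count is known.
assert ssim._kernel is None
assert isinstance(ssim._kernel_2d, torch.Tensor)

# One update materializes the kernel with one group per input channel.
y_pred = torch.rand(2, 3, 28, 28)
ssim.update((y_pred, y_pred * 0.8))
assert ssim._kernel.shape[0] == y_pred.size(1)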
