feat: improve handling of device switches between the metric device and the input tensors' device (#3043)
* refactor: remove outdated code and issue a warning if two tensors are on separate devices.
* feat: prioritize computation on GPU devices over CPUs
If either the metric device or the update input device
is a GPU, this commit moves the other one onto the GPU.
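The priority rule above can be sketched with a small hypothetical helper (`pick_common_device` is not part of ignite; it only illustrates the "GPU wins over CPU" choice described in this commit):

```python
import torch

def pick_common_device(metric_device: torch.device,
                       input_device: torch.device) -> torch.device:
    # Illustrative rule only: if either side is a CUDA device, compute
    # there; otherwise keep the metric's own (CPU) device.
    if metric_device.type == "cuda":
        return metric_device
    if input_device.type == "cuda":
        return input_device
    return metric_device
```

With two CPU devices this returns CPU; with one CUDA device on either side it returns the CUDA device, so the tensor on the slower device is the one that gets moved.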
* fix: use a temp var that will be moved with y_pred
The comparison with self._device was not possible because it
can be created as `torch.device("cuda")`, which is not equal
to `torch.device("cuda:0")`, the device reported by a tensor
created with `torch.device("cuda")`. This change incurs a
bigger performance hit when self._kernel is not on the same
device as y_pred, since the kernel must be moved to y_pred's
device every time update() is called.
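The device-equality quirk behind this fix can be reproduced directly (constructing a `torch.device` object does not require a GPU, so this runs anywhere):

```python
import torch

# An index-less CUDA device does not compare equal to an indexed one,
# even though a tensor created with device="cuda" actually lands on
# "cuda:0". Comparing whole device objects therefore reports a spurious
# mismatch; comparing .type (or tracking a temp variable moved together
# with y_pred, as this commit does) avoids it.
assert torch.device("cuda") != torch.device("cuda:0")
assert torch.device("cuda").type == torch.device("cuda:0").type
```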
* test: add metric and y_pred with different devices test
* feat: move self._kernel directly and issue a warning only when not all y_pred tensors are on the same device
* feat: adapt test to new behaviour
* feat: keep the accumulation on the same device as self._kernel
* feat: move accumulation alongside self._kernel
* feat: allow different channel number
* style: format using the run_code_style script
* style: add line break to conform to E501
* fix: use torch.empty to avoid type incompatibility between None and Tensor with mypy
* feat: only operate on self._kernel, keep the accumulation on user's selected device
* test: add variable channel test and factorize the code
* refactor: remove redundant line between init and reset
* refactor: elif comparison and replace RuntimeWarning with UserWarning
Co-authored-by: vfdev <[email protected]>
* refactor: set _kernel in __init__ and manually format to pass E501
* test: adapt test to new UserWarning
* test: remove skips
* refactor: use None instead of torch.empty
* style: reorder imports
* refactor: rename channel to nb_channel
* Fixed failing test_distrib_accumulator_device
---------
Co-authored-by: vfdev <[email protected]>