Updates for PyTorch 2.7 #8429

Open · wants to merge 14 commits into base: dev

2 changes: 1 addition & 1 deletion .github/workflows/pythonapp-min.yml
@@ -124,7 +124,7 @@ jobs:
strategy:
fail-fast: false
matrix:
pytorch-version: ['2.4.1', '2.5.1', '2.6.0'] # FIXME: add 'latest' back once PyTorch 2.7 issues are resolved
pytorch-version: ['2.4.1', '2.5.1', '2.6.0', 'latest']
timeout-minutes: 40
steps:
- uses: actions/checkout@v4
2 changes: 1 addition & 1 deletion docs/requirements.txt
@@ -1,5 +1,5 @@
-f https://download.pytorch.org/whl/cpu/torch-2.3.0%2Bcpu-cp39-cp39-linux_x86_64.whl
torch>=2.4.1, <2.7.0
torch>=2.4.1
pytorch-ignite==0.4.11
numpy>=1.20
itk>=5.2
22 changes: 17 additions & 5 deletions monai/networks/layers/spatial_transforms.py
@@ -11,6 +11,7 @@

from __future__ import annotations

import sys
from collections.abc import Sequence

import torch
@@ -526,6 +527,11 @@ def forward(
ValueError: When affine and image batch dimension differ.

"""

# In some cases it's necessary to convert inputs to grid_sample from float64 to float32 to work around known
# issues with PyTorch, see https://github.com/Project-MONAI/MONAI/pull/8429
convert_f32 = sys.platform == "win32" and src.dtype == torch.float64 and src.device == torch.device("cpu")

# validate `theta`
if not isinstance(theta, torch.Tensor):
raise TypeError(f"theta must be torch.Tensor but is {type(theta).__name__}.")
@@ -582,11 +588,17 @@ def forward(
)

grid = nn.functional.affine_grid(theta=theta[:, :sr], size=list(dst_size), align_corners=self.align_corners)

_input = src.contiguous()
if convert_f32:
_input = _input.to(torch.float32)
grid = grid.to(torch.float32)

dst = nn.functional.grid_sample(
input=src.contiguous(),
grid=grid,
mode=self.mode,
padding_mode=self.padding_mode,
align_corners=self.align_corners,
input=_input, grid=grid, mode=self.mode, padding_mode=self.padding_mode, align_corners=self.align_corners
)

if convert_f32:
dst = dst.to(torch.float64)

return dst
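
Both this change and the matching one in monai/transforms/spatial/array.py below apply the same workaround: on Windows, float64 CPU inputs to grid_sample are cast to float32 for the call and the result is cast back to float64 afterwards. A rough standalone sketch of that pattern, assuming only torch (the helper name grid_sample_f32_workaround is illustrative and not part of MONAI):

import sys

import torch
from torch.nn import functional as F

def grid_sample_f32_workaround(src: torch.Tensor, grid: torch.Tensor, **kwargs) -> torch.Tensor:
    # Cast float64 CPU tensors to float32 on Windows to sidestep the PyTorch 2.7 grid_sample issue,
    # then restore the original dtype on the way out.
    convert_f32 = sys.platform == "win32" and src.dtype == torch.float64 and src.device == torch.device("cpu")
    _input = src.contiguous()
    if convert_f32:
        _input = _input.to(torch.float32)
        grid = grid.to(torch.float32)
    out = F.grid_sample(input=_input, grid=grid, **kwargs)  # mode/padding_mode/align_corners pass through
    if convert_f32:
        out = out.to(torch.float64)
    return out

On other platforms, or for float32 inputs, convert_f32 is False and the helper reduces to a plain grid_sample call.
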
20 changes: 19 additions & 1 deletion monai/transforms/spatial/array.py
@@ -14,6 +14,7 @@

from __future__ import annotations

import sys
import warnings
from collections.abc import Callable, Sequence
from copy import deepcopy
@@ -2106,13 +2107,30 @@ def __call__(
if self.norm_coords:
for i, dim in enumerate(img_t.shape[sr + 1 : 0 : -1]):
grid_t[0, ..., i] *= 2.0 / max(2, dim)

# In some cases it's necessary to convert inputs to grid_sample from float64 to float32 to work around known
# issues with PyTorch, see https://github.com/Project-MONAI/MONAI/pull/8429
convert_f32 = (
sys.platform == "win32" and img_t.dtype == torch.float64 and img_t.device == torch.device("cpu")
)

_img_t = img_t.unsqueeze(0)

if convert_f32:
_img_t = _img_t.to(torch.float32)
grid_t = grid_t.to(torch.float32)

out = torch.nn.functional.grid_sample(
img_t.unsqueeze(0),
_img_t,
grid_t,
mode=_interp_mode,
padding_mode=_padding_mode,
align_corners=None if _align_corners == TraceKeys.NONE else _align_corners, # type: ignore
)[0]

if convert_f32:
out = out.to(torch.float64)

out_val, *_ = convert_to_dst_type(out, dst=img, dtype=np.float32)
return out_val

2 changes: 1 addition & 1 deletion pyproject.toml
@@ -2,7 +2,7 @@
requires = [
"wheel",
"setuptools",
"torch>=2.4.1, <2.7.0",
"torch>=2.4.1",
"ninja",
"packaging"
]
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -20,7 +20,7 @@ pyflakes
black>=25.1.0
isort>=5.1, <6.0
ruff
pytype>=2020.6.1; platform_system != "Windows"
pytype>=2020.6.1, <=2024.4.11; platform_system != "Windows"
types-setuptools
mypy>=1.5.0, <1.12.0
ninja
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,2 +1,2 @@
torch>=2.4.1, <2.7.0
torch>=2.4.1
numpy>=1.24,<3.0
2 changes: 1 addition & 1 deletion setup.cfg
@@ -42,7 +42,7 @@ setup_requires =
ninja
packaging
install_requires =
torch>=2.4.1, <2.7.0
torch>=2.4.1
numpy>=1.24,<3.0

[options.extras_require]
9 changes: 6 additions & 3 deletions tests/integration/test_pad_collation.py
@@ -11,8 +11,10 @@

from __future__ import annotations

import os
import random
import unittest
from contextlib import redirect_stderr
from functools import wraps

import numpy as np
@@ -35,7 +37,7 @@
RandZoomd,
ToTensor,
)
from monai.utils import set_determinism
from monai.utils import first, set_determinism


@wraps(pad_list_data_collate)
@@ -97,8 +99,9 @@ def test_pad_collation(self, t_type, collate_method, transform):
# Default collation should raise an error
loader_fail = DataLoader(dataset, batch_size=10)
with self.assertRaises(RuntimeError):
for _ in loader_fail:
pass
# stifle PyTorch error reporting; the failure is expected, so its output is not needed
with open(os.devnull, "w") as f, redirect_stderr(f):
_ = first(loader_fail)

# Padded collation shouldn't
loader = DataLoader(dataset, batch_size=10, collate_fn=collate_method)
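
The rewritten failure check draws a single batch with monai.utils.first, which is enough to trigger the expected collation error while redirect_stderr discards the error output PyTorch would otherwise print. A small standalone illustration of first, separate from the test's own data:

from monai.utils import first

assert first(iter([1, 2, 3])) == 1  # returns the first element of an iterable
assert first([], default="empty") == "empty"  # or the given default when the iterable is empty
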
8 changes: 8 additions & 0 deletions tests/lazy_transforms_utils.py
@@ -11,6 +11,7 @@

from __future__ import annotations

import sys
from copy import deepcopy

from monai.data import MetaTensor, set_track_meta
@@ -62,6 +63,13 @@ def test_resampler_lazy(
resampler.set_random_state(seed=seed)
set_track_meta(True)
resampler.lazy = True

# FIXME: this is a workaround for https://github.com/Project-MONAI/MONAI/pull/8429; remove once PyTorch
# has fixed the underlying issue
if sys.platform == "win32":
atol = 1e-4
rtol = 1e-4

pending_output = resampler(**deepcopy(call_param))
if output_idx is not None:
expected_output, pending_output = (expected_output[output_idx], pending_output[output_idx])
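
For context, the loosened atol/rtol values feed the test's closeness check, which typically has the form abs(actual - expected) <= atol + rtol * abs(expected), as in torch.allclose. A minimal illustration with placeholder tensors, not the resampler outputs themselves:

import torch

expected = torch.randn(3, 4, dtype=torch.float64)
actual = expected + 1e-5  # stand-in for the small drift introduced by the float32 round trip on Windows
assert torch.allclose(actual, expected, atol=1e-4, rtol=1e-4)
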