MaisiVAE: Auto-cast GroupNorm, deprecate norm_float16 #8326

Status: Open · wants to merge 4 commits into base: dev
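
The gist of the change, as a minimal sketch (sizes and values are hypothetical; the half-precision paths are only exercised on CUDA, matching the tests below):

    import torch
    from monai.apps.generation.maisi.networks.autoencoderkl_maisi import MaisiGroupNorm3D

    # New auto-cast path: norm_float16=None keeps the output in the input's dtype,
    # so the layer can be cast together with the rest of the model.
    norm = MaisiGroupNorm3D(num_groups=4, num_channels=8, norm_float16=None)
    norm = norm.to(device="cuda", dtype=torch.bfloat16)

    x = torch.randn(2, 8, 16, 16, 16, device="cuda", dtype=torch.bfloat16)  # 5D: N, C, D, H, W
    y = norm(x)
    assert y.dtype == torch.bfloat16  # previously the output was forced to float32 or float16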
37 changes: 24 additions & 13 deletions monai/apps/generation/maisi/networks/autoencoderkl_maisi.py
@@ -43,7 +43,8 @@ class MaisiGroupNorm3D(nn.GroupNorm):
         num_channels: Number of channels for the group norm.
         eps: Epsilon value for numerical stability.
         affine: Whether to use learnable affine parameters, default to `True`.
-        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+        norm_float16: If True, convert the output of MaisiGroupNorm3D to float16; if False, convert it to float32.
+            If None, convert to the dtype of the input. Defaults to `False`.
         print_info: Whether to print information, default to `False`.
         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
     """
@@ -54,7 +55,7 @@ def __init__(
         num_channels: int,
         eps: float = 1e-5,
         affine: bool = True,
-        norm_float16: bool = False,
+        norm_float16: bool | None = False,
         print_info: bool = False,
         save_mem: bool = True,
     ):
@@ -67,6 +68,8 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         if self.print_info:
             logger.info(f"MaisiGroupNorm3D with input size: {input.size()}")

+        target_dtype = input.dtype
+
Review comment (Contributor):

What if the input is float32 but users want to convert the output to float16?

Reply from the author (Contributor):
This change only affects the group norm and makes the behavior in line with the rest of the model.
If I understand you correctly, to achieve what you want, the common pattern would be:

    model = model.to(dtype=torch.bfloat16)
    prediction = model(x.to(dtype=torch.bfloat16))

Or are you referring to something else?

Follow-up from the author (@johnzielke, Apr 21, 2025):
Without this change, the parameter needs to be manually adjusted to produce a tensor that's compatible with the rest of the model. In addition, it could only be float32 or float16. Is there a reason one would want to have the GroupNorm in a different datatype than the rest of the model?
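
For completeness, a sketch of both behaviors under the new code (constructor values are hypothetical; the half-precision path assumes a CUDA device, as in the tests below):

    import torch
    from monai.apps.generation.maisi.networks.autoencoderkl_maisi import MaisiGroupNorm3D

    norm = MaisiGroupNorm3D(num_groups=4, num_channels=8, norm_float16=None)

    # float32 module + float32 input -> float32 output, unchanged from before.
    x32 = torch.randn(1, 8, 8, 8, 8)
    assert norm(x32).dtype == torch.float32

    # For float16 outputs, cast the module and the input together,
    # i.e. the pattern from the reply above.
    norm16 = norm.to(device="cuda", dtype=torch.float16)
    assert norm16(x32.to(device="cuda", dtype=torch.float16)).dtype == torch.float16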


         if len(input.shape) != 5:
             raise ValueError("Expected a 5D tensor")

@@ -75,13 +78,17 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:

         inputs = []
         for i in range(input.size(1)):
-            array = input[:, i : i + 1, ...].to(dtype=torch.float32)
+            array = input[:, i : i + 1, ...]
+            if self.norm_float16 is not None:
+                array = array.to(dtype=torch.float32)
             mean = array.mean([2, 3, 4, 5], keepdim=True)
             std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(self.eps).sqrt_()
-            if self.norm_float16:
+            if self.norm_float16 is None:
+                inputs.append(((array - mean) / std).to(dtype=target_dtype))
+            elif self.norm_float16:
                 inputs.append(((array - mean) / std).to(dtype=torch.float16))
             else:
-                inputs.append((array - mean) / std)
+                inputs.append(((array - mean) / std).to(dtype=torch.float32))

         del input
         _empty_cuda_cache(self.save_mem)
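
In other words: with norm_float16=True or False the statistics are still computed in float32 and the result is cast on the way out, while norm_float16=None keeps both the statistics and the output in the input's dtype. A simplified standalone mirror of the branch above (no affine step, eps as a parameter; not part of the diff):

    from __future__ import annotations

    import torch

    def normalize_group(array: torch.Tensor, norm_float16: bool | None, eps: float = 1e-5) -> torch.Tensor:
        # Mirrors the per-group branch of MaisiGroupNorm3D.forward after this PR.
        # `array` is one group slice of shape (N, 1, C//G, D, H, W).
        target_dtype = array.dtype
        if norm_float16 is not None:
            array = array.to(dtype=torch.float32)  # True/False: normalize in float32
        mean = array.mean([2, 3, 4, 5], keepdim=True)
        std = array.var([2, 3, 4, 5], unbiased=False, keepdim=True).add_(eps).sqrt_()
        out = (array - mean) / std
        if norm_float16 is None:
            return out.to(dtype=target_dtype)  # new: follow the input dtype
        return out.to(dtype=torch.float16 if norm_float16 else torch.float32)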
@@ -393,7 +400,8 @@ class MaisiResBlock(nn.Module):
         out_channels: Number of output channels.
         num_splits: Number of splits for the input tensor.
         dim_split: Dimension of splitting for the input tensor.
-        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+        norm_float16: If True, convert the output of MaisiGroupNorm3D to float16; if False, convert it to float32.
+            If None, convert to the dtype of the input. Defaults to `False`.
         print_info: Whether to print information, default to `False`.
         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
     """
@@ -407,7 +415,7 @@ def __init__(
         out_channels: int,
         num_splits: int,
         dim_split: int,
-        norm_float16: bool = False,
+        norm_float16: bool | None = False,
         print_info: bool = False,
         save_mem: bool = True,
     ) -> None:
@@ -524,7 +532,8 @@ class MaisiEncoder(nn.Module):
         use_flash_attention: If True, use flash attention for a memory efficient attention mechanism.
         num_splits: Number of splits for the input tensor.
         dim_split: Dimension of splitting for the input tensor.
-        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+        norm_float16: If True, convert the output of MaisiGroupNorm3D to float16; if False, convert it to float32.
+            If None, convert to the dtype of the input. Defaults to `False`.
         print_info: Whether to print information, default to `False`.
         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
     """
@@ -541,7 +550,7 @@ def __init__(
         attention_levels: Sequence[bool],
         num_splits: int,
         dim_split: int,
-        norm_float16: bool = False,
+        norm_float16: bool | None = False,
         print_info: bool = False,
         save_mem: bool = True,
         with_nonlocal_attn: bool = True,
@@ -714,7 +723,8 @@ class MaisiDecoder(nn.Module):
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
         num_splits: Number of splits for the input tensor.
         dim_split: Dimension of splitting for the input tensor.
-        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+        norm_float16: If True, convert the output of MaisiGroupNorm3D to float16; if False, convert it to float32.
+            If None, convert to the dtype of the input. Defaults to `False`.
         print_info: Whether to print information, default to `False`.
         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
     """
@@ -731,7 +741,7 @@ def __init__(
         attention_levels: Sequence[bool],
         num_splits: int,
         dim_split: int,
-        norm_float16: bool = False,
+        norm_float16: bool | None = False,
         print_info: bool = False,
         save_mem: bool = True,
         with_nonlocal_attn: bool = True,
@@ -905,7 +915,8 @@ class AutoencoderKlMaisi(AutoencoderKL):
         use_convtranspose: If True, use ConvTranspose to upsample feature maps in decoder.
         num_splits: Number of splits for the input tensor.
         dim_split: Dimension of splitting for the input tensor.
-        norm_float16: If True, convert output of MaisiGroupNorm3D to float16 format, default to `False`.
+        norm_float16: If True, convert the output of MaisiGroupNorm3D to float16; if False, convert it to float32.
+            If None, convert to the dtype of the input. Defaults to `False`.
         print_info: Whether to print information, default to `False`.
         save_mem: Whether to clean CUDA cache in order to save GPU memory, default to `True`.
     """
@@ -930,7 +941,7 @@ def __init__(
         use_convtranspose: bool = False,
         num_splits: int = 16,
         dim_split: int = 0,
-        norm_float16: bool = False,
+        norm_float16: bool | None = False,
         print_info: bool = False,
         save_mem: bool = True,
     ) -> None:
32 changes: 24 additions & 8 deletions tests/apps/maisi/networks/test_autoencoderkl_maisi.py
@@ -75,27 +75,43 @@
 else:
     CASES = CASES_NO_ATTENTION

+test_dtypes = [torch.float32]
+if device.type == "cuda":
+    test_dtypes.append(torch.bfloat16)
+    test_dtypes.append(torch.float16)
+
+DTYPE_CASES = []
+for dtype in test_dtypes:
+    for case in CASES:
+        for norm_float in [False, None]:
+            if dtype != torch.float32 and norm_float is not None:
+                continue
+            new_case = [{**case[0], "norm_float16": norm_float}, case[1], case[2], case[3]]  # type: ignore[dict-item]
+            DTYPE_CASES.append(new_case + [dtype])
+

 class TestAutoencoderKlMaisi(unittest.TestCase):
-    @parameterized.expand(CASES)
-    def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape):
-        net = AutoencoderKlMaisi(**input_param).to(device)
+
+    @parameterized.expand(DTYPE_CASES)
+    def test_shape(self, input_param, input_shape, expected_shape, expected_latent_shape, dtype):
+        net = AutoencoderKlMaisi(**input_param).to(device=device, dtype=dtype)
+        print(input_param)
         with eval_mode(net):
-            result = net.forward(torch.randn(input_shape).to(device))
+            result = net.forward(torch.randn(input_shape).to(device=device, dtype=dtype))
         self.assertEqual(result[0].shape, expected_shape)
         self.assertEqual(result[1].shape, expected_latent_shape)
         self.assertEqual(result[2].shape, expected_latent_shape)

-    @parameterized.expand(CASES)
+    @parameterized.expand(DTYPE_CASES)
     @SkipIfBeforePyTorchVersion((1, 11))
     def test_shape_with_convtranspose_and_checkpointing(
-        self, input_param, input_shape, expected_shape, expected_latent_shape
+        self, input_param, input_shape, expected_shape, expected_latent_shape, dtype
     ):
         input_param = input_param.copy()
         input_param.update({"use_checkpointing": True, "use_convtranspose": True})
-        net = AutoencoderKlMaisi(**input_param).to(device)
+        net = AutoencoderKlMaisi(**input_param).to(device=device, dtype=dtype)
         with eval_mode(net):
-            result = net.forward(torch.randn(input_shape).to(device))
+            result = net.forward(torch.randn(input_shape).to(device=device, dtype=dtype))
         self.assertEqual(result[0].shape, expected_shape)
         self.assertEqual(result[1].shape, expected_latent_shape)
         self.assertEqual(result[2].shape, expected_latent_shape)
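
Net effect of the parameterization above, spelled out (an illustration, not part of the diff):

    # dtype           norm_float16 values exercised
    # torch.float32   False, None
    # torch.bfloat16  None  (CUDA only)
    # torch.float16   None  (CUDA only)
    #
    # The forced-cast flag is only re-tested with float32 models; float16 and
    # bfloat16 models rely on the new auto-cast (None) path.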