Fix AutoencoderKL docstrings. #8445

Open · wants to merge 6 commits into base: dev
16 changes: 8 additions & 8 deletions monai/networks/nets/autoencoderkl.py
```diff
@@ -153,9 +153,9 @@ class Encoder(nn.Module):
         channels: sequence of block output channels.
         out_channels: number of channels in the bottom layer (latent space) of the autoencoder.
         num_res_blocks: number of residual blocks (see _ResBlock) per level.
-        norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
+        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
-        attention_levels: indicate which level from num_channels contain an attention block.
+        attention_levels: indicate which level from channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
         include_fc: whether to include the final linear layer. Default to True.
         use_combined_linear: whether to use a single linear layer for qkv projection, default to False.
```
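
For context, the divisibility rule these docstrings describe is inherited from `torch.nn.GroupNorm`, which rejects any configuration where the channel count is not a multiple of the group count. A minimal sketch in plain PyTorch (not part of this diff) illustrating why every entry of `channels` must be divisible by `norm_num_groups`:

```python
import torch
from torch import nn

# GroupNorm raises a ValueError at construction time when
# num_channels is not a multiple of num_groups.
norm = nn.GroupNorm(num_groups=8, num_channels=32)  # valid: 32 % 8 == 0
out = norm(torch.randn(1, 32, 16, 16))              # (N, C, H, W) input

try:
    nn.GroupNorm(num_groups=8, num_channels=30)     # invalid: 30 % 8 != 0
except ValueError as err:
    print(err)  # "num_channels must be divisible by num_groups"
```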
```diff
@@ -299,9 +299,9 @@ class Decoder(nn.Module):
         in_channels: number of channels in the bottom layer (latent space) of the autoencoder.
         out_channels: number of output channels.
         num_res_blocks: number of residual blocks (see _ResBlock) per level.
-        norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
+        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
-        attention_levels: indicate which level from num_channels contain an attention block.
+        attention_levels: indicate which level from channels contain an attention block.
         with_nonlocal_attn: if True use non-local attention block.
         use_convtranspose: if True, use ConvTranspose to upsample feature maps in decoder.
         include_fc: whether to include the final linear layer. Default to True.
```
```diff
@@ -483,7 +483,7 @@ class AutoencoderKL(nn.Module):
         channels: number of output channels for each block.
         attention_levels: sequence of levels to add attention.
         latent_channels: latent embedding dimension.
-        norm_num_groups: number of groups for the GroupNorm layers, num_channels must be divisible by this number.
+        norm_num_groups: number of groups for the GroupNorm layers, channels must be divisible by this number.
         norm_eps: epsilon for the normalization.
         with_encoder_nonlocal_attn: if True use non-local attention block in the encoder.
         with_decoder_nonlocal_attn: if True use non-local attention block in the decoder.
```
```diff
@@ -518,18 +518,18 @@ def __init__(

         # All number of channels should be multiple of num_groups
         if any((out_channel % norm_num_groups) != 0 for out_channel in channels):
-            raise ValueError("AutoencoderKL expects all num_channels being multiple of norm_num_groups")
+            raise ValueError("AutoencoderKL expects all channels being multiple of norm_num_groups")

         if len(channels) != len(attention_levels):
-            raise ValueError("AutoencoderKL expects num_channels being same size of attention_levels")
+            raise ValueError("AutoencoderKL expects channels being same size of attention_levels")

         if isinstance(num_res_blocks, int):
             num_res_blocks = ensure_tuple_rep(num_res_blocks, len(channels))

         if len(num_res_blocks) != len(channels):
             raise ValueError(
                 "`num_res_blocks` should be a single integer or a tuple of integers with the same length as "
-                "`num_channels`."
+                "`channels`."
             )

         self.encoder: nn.Module = Encoder(
```
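
Taken together, the renamed error messages above enforce three constraints on a valid configuration: every entry of `channels` must be a multiple of `norm_num_groups`, `attention_levels` must have the same length as `channels`, and `num_res_blocks` must be an int or a tuple of the same length as `channels`. A minimal instantiation sketch follows; the argument names match the docstrings in this diff, but the full constructor signature and the forward return values are assumptions to confirm against the MONAI version in use:

```python
import torch
from monai.networks.nets import AutoencoderKL

# Satisfies the three checks in __init__ shown above:
#   1. 32 % 32 == 0 and 64 % 32 == 0, so every entry of channels
#      is a multiple of norm_num_groups
#   2. len(attention_levels) == len(channels) == 3
#   3. num_res_blocks is an int, expanded to (2, 2, 2) via ensure_tuple_rep
model = AutoencoderKL(
    spatial_dims=2,
    in_channels=1,
    out_channels=1,
    channels=(32, 64, 64),
    attention_levels=(False, False, True),
    num_res_blocks=2,
    latent_channels=3,
    norm_num_groups=32,
)

x = torch.randn(1, 1, 64, 64)
# Assumed forward signature: returns the reconstruction plus the latent
# Gaussian parameters (z_mu, z_sigma).
reconstruction, z_mu, z_sigma = model(x)
```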