
8185 test refactor 2 #8405


Open · garciadias wants to merge 57 commits into base: dev from 8185-test-refactor-2
Commits (57) · changes from all commits
57d7b09
Merge branch 'dev' into 8185-test-refactor-2
garciadias Feb 27, 2025
5342cba
Merge remote-tracking branch 'upstream/dev' into 8185-test-refactor-2
garciadias Mar 28, 2025
0139125
Refactor self-attention test cases to use dict_product for parameter …
garciadias Mar 28, 2025
836cf6e
Refactor PatchEmbeddingBlock test cases to use dict_product for param…
garciadias Mar 28, 2025
b2ccb26
Refactor RetinaNet test cases to use dict_product for parameter combi…
garciadias Mar 28, 2025
0735dfd
Refactor test_meta_tensor to use dict_product for parameter combinations
garciadias Mar 28, 2025
45622d4
Refactor test_box_transform to use dict_product for parameter combina…
garciadias Mar 28, 2025
129f778
Autofix
garciadias Mar 28, 2025
ebae4e3
Fix mypy error
garciadias Mar 29, 2025
482e5bf
Fix missing parameter
garciadias Apr 1, 2025
9a0d5b1
DCO Remediation Commit for R. Garcia-Dias <[email protected]>
garciadias Apr 1, 2025
c06cad1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 1, 2025
a5596d7
revert change to tests/apps/detection/test_box_transform.py
garciadias Apr 10, 2025
b46ccc0
redesign dict_product to make more readable
garciadias Apr 11, 2025
6077ccd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2025
cd5f790
add license to temporary test
garciadias Apr 11, 2025
84d85ea
fix test param name
garciadias Apr 11, 2025
336b287
Simplify with list comprehension
garciadias Apr 11, 2025
0cedbb5
autofix
garciadias Apr 11, 2025
a9b0fc9
N806 not catch locally, but in CI
garciadias Apr 11, 2025
d2259b9
dict_product function to accept Iterable[Any] for improved flexibility
garciadias Apr 11, 2025
5922e2a
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 11, 2025
ef2388e
autofix
garciadias Apr 11, 2025
016d3b0
Fix test mistakes
garciadias Apr 11, 2025
ce798d7
autofix
garciadias Apr 11, 2025
8fad281
fix mypy errora
garciadias Apr 11, 2025
6c3eeac
Refactor TEST_CASES_CAB to use dict_product for cleaner test case gen…
garciadias May 9, 2025
c3b42bb
Refactor TEST_CASE_MaskedAutoEncoderViT to use dict_product for clean…
garciadias May 9, 2025
e5b5895
Refactor test_crossattention to use dict_product for cleaner test cas…
garciadias May 9, 2025
c9b41af
Refactor TEST_CASE_RESBLOCK to use dict_product for cleaner test case…
garciadias May 9, 2025
83fc0a4
autofix
garciadias May 9, 2025
da9fd6c
Refactor TEST_CASE_TRANSFORMERBLOCK to use dict_product for cleaner t…
garciadias May 9, 2025
621f445
Refactor TEST_CASE_UNETR_BASIC_BLOCK and TEST_UP_BLOCK to use dict_pr…
garciadias May 9, 2025
c3ef211
Refactor test cases in test_dynunet_block to use dict_product for cle…
garciadias May 9, 2025
5eaf55d
autofix
garciadias May 9, 2025
89c43b8
Refactor test cases in test_dynunet.py to use dict_product for cleane…
garciadias May 9, 2025
827460e
revert self attention changes
garciadias May 9, 2025
5c7cc18
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] May 9, 2025
b4c81f1
Refactor test cases in test_mednext.py to use dict_product for cleane…
garciadias May 9, 2025
61d20d3
Refactor test cases in test_segresnet_ds.py to use dict_product for c…
garciadias May 9, 2025
1dc304e
Refactor test cases in test_segresnet.py to use dict_product for clea…
garciadias May 9, 2025
79bc4f2
Refactor test cases in test_swin_unetr.py to use dict_product for cle…
garciadias May 9, 2025
3049126
Refactor test cases in test_transchex.py to use dict_product for clea…
garciadias May 9, 2025
91ff49a
Refactor test cases in test_unetr.py to use dict_product for cleaner …
garciadias May 9, 2025
a683162
Merge branch '8185-test-refactor-2' of https://github.com/garciadias/…
garciadias May 9, 2025
fbf47cd
Refactor test cases in test_vit.py to use dict_product for cleaner ge…
garciadias May 9, 2025
e8c0270
Refactor test cases in test_vitautoenc.py to use dict_product for cle…
garciadias May 9, 2025
2a010b1
Refactor test cases in test_spatial_resampled.py to use dict_product …
garciadias May 9, 2025
0240249
Refactor test cases in test_splitdimd.py to use dict_product for clea…
garciadias May 9, 2025
ce33936
Refactor test cases in test_spacing.py to use dict_product for cleane…
garciadias May 9, 2025
51a94c2
Refactor test cases in test_spatial_resample.py to use dict_product f…
garciadias May 9, 2025
74741e7
Refactor test cases in test_pad_mode.py to use dict_product for clean…
garciadias May 9, 2025
60f7201
Refactor test cases in test_spacing.py to use dict_product for cleane…
garciadias May 9, 2025
9214fca
Remove test_utils.py
garciadias May 9, 2025
76033a6
Merge remote-tracking branch 'upstream/dev' into 8185-test-refactor-2
garciadias May 9, 2025
1e3339d
Apply suggestions from code review
garciadias May 9, 2025
e00df4d
autofix
garciadias May 10, 2025
Files changed
2 changes: 1 addition & 1 deletion monai/utils/jupyter_utils.py
@@ -234,7 +234,7 @@ def plot_engine_status(


def _get_loss_from_output(
-    output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor,
+    output: list[torch.Tensor | dict[str, torch.Tensor]] | dict[str, torch.Tensor] | torch.Tensor
) -> torch.Tensor:
    """Returns a single value from the network output, which is a dict or tensor."""

15 changes: 6 additions & 9 deletions tests/apps/detection/networks/test_retinanet.py
@@ -20,7 +20,7 @@
from monai.networks import eval_mode
from monai.networks.nets import resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200
from monai.utils import ensure_tuple, optional_import
-from tests.test_utils import SkipIfBeforePyTorchVersion, skip_if_quick, test_onnx_save, test_script_save
+from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product, skip_if_quick, test_onnx_save, test_script_save

_, has_torchvision = optional_import("torchvision")

@@ -86,15 +86,12 @@
    (2, 1, 32, 64),
]

-TEST_CASES = []
-for case in [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]:
-    for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
-        TEST_CASES.append([model, *case])
+# Create all test case combinations using dict_product
+CASE_LIST = [TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_2_A, TEST_CASE_3_A]
+MODEL_LIST = [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]

-TEST_CASES_TS = []
-for case in [TEST_CASE_1]:
-    for model in [resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200]:
-        TEST_CASES_TS.append([model, *case])
+TEST_CASES = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=CASE_LIST)]
+TEST_CASES_TS = [[params["model"], *params["case"]] for params in dict_product(model=MODEL_LIST, case=[TEST_CASE_1])]


@SkipIfBeforePyTorchVersion((1, 12))
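
Note: dict_product itself is not part of this diff (it lives in tests/test_utils.py, and one commit above reworks it to accept Iterable[Any]). As a reading aid, here is a minimal sketch of what it presumably does, inferred from the call sites in this PR; the real helper may differ in details:

from itertools import product
from typing import Any, Iterable


def dict_product(**kwargs: Iterable[Any]):
    """Yield one dict per element of the Cartesian product of the keyword iterables."""
    keys = list(kwargs)
    for values in product(*kwargs.values()):
        yield dict(zip(keys, values))

Under that assumption, dict_product(model=MODEL_LIST, case=CASE_LIST) yields dicts like {"model": resnet10, "case": TEST_CASE_1} for all 7 × 5 = 35 model/case pairs, so TEST_CASES above holds 35 entries and TEST_CASES_TS (a single case) holds 7.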
10 changes: 5 additions & 5 deletions tests/data/meta_tensor/test_meta_tensor.py
@@ -32,13 +32,13 @@
from monai.data.utils import decollate_batch, list_data_collate
from monai.transforms import BorderPadd, Compose, DivisiblePadd, FromMetaTensord, ToMetaTensord
from monai.utils.enums import PostFix
-from tests.test_utils import TEST_DEVICES, SkipIfBeforePyTorchVersion, assert_allclose, skip_if_no_cuda
+from tests.test_utils import TEST_DEVICES, SkipIfBeforePyTorchVersion, assert_allclose, dict_product, skip_if_no_cuda

DTYPES = [[torch.float32], [torch.float64], [torch.float16], [torch.int64], [torch.int32], [None]]
-TESTS = []
-for _device in TEST_DEVICES:
-    for _dtype in DTYPES:
-        TESTS.append((*_device, *_dtype))  # type: ignore
+# Replace nested loops with dict_product
+
+TESTS = [(*params["device"], *params["dtype"]) for params in dict_product(device=TEST_DEVICES, dtype=DTYPES)]


def rand_string(min_len=5, max_len=10):
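
For a concrete picture of the tuples this builds: each entry of TEST_DEVICES and DTYPES is itself a one-element list, so unpacking both flattens a combination into a single (device, dtype) test tuple. A small sketch under that assumption:

import torch

# One hypothetical combination, mirroring the comprehension above.
params = {"device": [torch.device("cpu")], "dtype": [torch.float32]}
case = (*params["device"], *params["dtype"])
assert case == (torch.device("cpu"), torch.float32)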
36 changes: 15 additions & 21 deletions tests/networks/blocks/test_CABlock.py
@@ -20,28 +20,24 @@
from monai.networks import eval_mode
from monai.networks.blocks.cablock import CABlock, FeedForward
from monai.utils import optional_import
-from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose
+from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, dict_product

einops, has_einops = optional_import("einops")


-TEST_CASES_CAB = []
-for spatial_dims in [2, 3]:
-    for dim in [32, 64, 128]:
-        for num_heads in [2, 4, 8]:
-            for bias in [True, False]:
-                test_case = [
-                    {
-                        "spatial_dims": spatial_dims,
-                        "dim": dim,
-                        "num_heads": num_heads,
-                        "bias": bias,
-                        "flash_attention": False,
-                    },
-                    (2, dim, *([16] * spatial_dims)),
-                    (2, dim, *([16] * spatial_dims)),
-                ]
-                TEST_CASES_CAB.append(test_case)
+TEST_CASES_CAB = [
+    [
+        {
+            "spatial_dims": params["spatial_dims"],
+            "dim": params["dim"],
+            "num_heads": params["num_heads"],
+            "bias": params["bias"],
+            "flash_attention": False,
+        },
+        (2, params["dim"], *([16] * params["spatial_dims"])),
+        (2, params["dim"], *([16] * params["spatial_dims"])),
+    ]
+    for params in dict_product(spatial_dims=[2, 3], dim=[32, 64, 128], num_heads=[2, 4, 8], bias=[True, False])
+]


TEST_CASES_FEEDFORWARD = [
@@ -53,7 +49,6 @@


class TestFeedForward(unittest.TestCase):
-
    @parameterized.expand(TEST_CASES_FEEDFORWARD)
    def test_shape(self, input_param, input_shape):
        net = FeedForward(**input_param)
@@ -69,7 +64,6 @@ def test_gating_mechanism(self):


class TestCABlock(unittest.TestCase):
-
    @parameterized.expand(TEST_CASES_CAB)
    @skipUnless(has_einops, "Requires einops")
    def test_shape(self, input_param, input_shape, expected_shape):
44 changes: 23 additions & 21 deletions tests/networks/blocks/test_crossattention.py
@@ -22,30 +22,32 @@
from monai.networks.blocks.crossattention import CrossAttentionBlock
from monai.networks.layers.factories import RelPosEmbedding
from monai.utils import optional_import
-from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose
+from tests.test_utils import SkipIfBeforePyTorchVersion, assert_allclose, dict_product

einops, has_einops = optional_import("einops")

-TEST_CASE_CABLOCK = []
-for dropout_rate in np.linspace(0, 1, 4):
-    for hidden_size in [360, 480, 600, 768]:
-        for num_heads in [4, 6, 8, 12]:
-            for rel_pos_embedding in [None, RelPosEmbedding.DECOMPOSED]:
-                for input_size in [(16, 32), (8, 8, 8)]:
-                    for flash_attn in [True, False]:
-                        test_case = [
-                            {
-                                "hidden_size": hidden_size,
-                                "num_heads": num_heads,
-                                "dropout_rate": dropout_rate,
-                                "rel_pos_embedding": rel_pos_embedding if not flash_attn else None,
-                                "input_size": input_size,
-                                "use_flash_attention": flash_attn,
-                            },
-                            (2, 512, hidden_size),
-                            (2, 512, hidden_size),
-                        ]
-                        TEST_CASE_CABLOCK.append(test_case)
+TEST_CASE_CABLOCK = [
+    [
+        {
+            "hidden_size": params["hidden_size"],
+            "num_heads": params["num_heads"],
+            "dropout_rate": params["dropout_rate"],
+            "rel_pos_embedding": params["rel_pos_embedding_val"] if not params["flash_attn"] else None,
+            "input_size": params["input_size"],
+            "use_flash_attention": params["flash_attn"],
+        },
+        (2, 512, params["hidden_size"]),
+        (2, 512, params["hidden_size"]),
+    ]
+    for params in dict_product(
+        dropout_rate=np.linspace(0, 1, 4),
+        hidden_size=[360, 480, 600, 768],
+        num_heads=[4, 6, 8, 12],
+        rel_pos_embedding_val=[None, RelPosEmbedding.DECOMPOSED],
+        input_size=[(16, 32), (8, 8, 8)],
+        flash_attn=[True, False],
+    )
+]


class TestResBlock(unittest.TestCase):
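
One detail of the comprehension above is worth spelling out: rel_pos_embedding_val is resolved after the product is taken, presumably because flash attention is not used together with relative positional embedding, so those combinations fall back to None rather than being excluded. A sketch of the effect, reusing the dict_product sketch from the RetinaNet section ("decomposed" is a hypothetical stand-in for RelPosEmbedding.DECOMPOSED):

# When flash_attn is True, both candidate embeddings collapse to None,
# so those generated cases differ only in their remaining fields.
for params in dict_product(rel_pos_embedding_val=[None, "decomposed"], flash_attn=[True, False]):
    rel_pos = params["rel_pos_embedding_val"] if not params["flash_attn"] else None
    print(params["flash_attn"], rel_pos)

The full grid is 4 dropout rates × 4 hidden sizes × 4 head counts × 2 embeddings × 2 input sizes × 2 flash-attention flags = 512 generated cases.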
111 changes: 64 additions & 47 deletions tests/networks/blocks/test_dynunet_block.py
@@ -18,58 +18,75 @@

from monai.networks import eval_mode
from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, UnetUpBlock, get_padding
-from tests.test_utils import test_script_save
+from tests.test_utils import dict_product, test_script_save

TEST_CASE_RES_BASIC_BLOCK = []
-for spatial_dims in range(2, 4):
-    for kernel_size in [1, 3]:
-        for stride in [1, 2]:
-            for norm_name in [("GROUP", {"num_groups": 16}), ("batch", {"track_running_stats": False}), "instance"]:
-                for in_size in [15, 16]:
-                    padding = get_padding(kernel_size, stride)
-                    if not isinstance(padding, int):
-                        padding = padding[0]
-                    out_size = int((in_size + 2 * padding - kernel_size) / stride) + 1
-                    test_case = [
-                        {
-                            "spatial_dims": spatial_dims,
-                            "in_channels": 16,
-                            "out_channels": 16,
-                            "kernel_size": kernel_size,
-                            "norm_name": norm_name,
-                            "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.1}),
-                            "stride": stride,
-                        },
-                        (1, 16, *([in_size] * spatial_dims)),
-                        (1, 16, *([out_size] * spatial_dims)),
-                    ]
-                    TEST_CASE_RES_BASIC_BLOCK.append(test_case)
+for params in dict_product(
+    spatial_dims=range(2, 4),
+    kernel_size=[1, 3],
+    stride=[1, 2],
+    norm_name=[("GROUP", {"num_groups": 16}), ("batch", {"track_running_stats": False}), "instance"],
+    in_size=[15, 16],
+):
+    spatial_dims = params["spatial_dims"]
+    kernel_size = params["kernel_size"]
+    stride = params["stride"]
+    norm_name = params["norm_name"]
+    in_size = params["in_size"]
+
+    padding = get_padding(kernel_size, stride)
+    if not isinstance(padding, int):
+        padding = padding[0]
+    out_size = int((in_size + 2 * padding - kernel_size) / stride) + 1
+    test_case = [
+        {
+            "spatial_dims": spatial_dims,
+            "in_channels": 16,
+            "out_channels": 16,
+            "kernel_size": kernel_size,
+            "norm_name": norm_name,
+            "act_name": ("leakyrelu", {"inplace": True, "negative_slope": 0.1}),
+            "stride": stride,
+        },
+        (1, 16, *([in_size] * spatial_dims)),
+        (1, 16, *([out_size] * spatial_dims)),
+    ]
+    TEST_CASE_RES_BASIC_BLOCK.append(test_case)

TEST_UP_BLOCK = []
in_channels, out_channels = 4, 2
-for spatial_dims in range(2, 4):
-    for kernel_size in [1, 3]:
-        for stride in [1, 2]:
-            for norm_name in ["batch", "instance"]:
-                for in_size in [15, 16]:
-                    for trans_bias in [True, False]:
-                        out_size = in_size * stride
-                        test_case = [
-                            {
-                                "spatial_dims": spatial_dims,
-                                "in_channels": in_channels,
-                                "out_channels": out_channels,
-                                "kernel_size": kernel_size,
-                                "norm_name": norm_name,
-                                "stride": stride,
-                                "upsample_kernel_size": stride,
-                                "trans_bias": trans_bias,
-                            },
-                            (1, in_channels, *([in_size] * spatial_dims)),
-                            (1, out_channels, *([out_size] * spatial_dims)),
-                            (1, out_channels, *([in_size * stride] * spatial_dims)),
-                        ]
-                        TEST_UP_BLOCK.append(test_case)
+for params in dict_product(
+    spatial_dims=range(2, 4),
+    kernel_size=[1, 3],
+    stride=[1, 2],
+    norm_name=["batch", "instance"],
+    in_size=[15, 16],
+    trans_bias=[True, False],
+):
+    spatial_dims = params["spatial_dims"]
+    kernel_size = params["kernel_size"]
+    stride = params["stride"]
+    norm_name = params["norm_name"]
+    in_size = params["in_size"]
+    trans_bias = params["trans_bias"]
+
+    out_size = in_size * stride
+    test_case = [
+        {
+            "spatial_dims": spatial_dims,
+            "in_channels": in_channels,
+            "out_channels": out_channels,
+            "kernel_size": kernel_size,
+            "norm_name": norm_name,
+            "stride": stride,
+            "upsample_kernel_size": stride,
+            "trans_bias": trans_bias,
+        },
+        (1, in_channels, *([in_size] * spatial_dims)),
+        (1, out_channels, *([out_size] * spatial_dims)),
+        (1, out_channels, *([in_size * stride] * spatial_dims)),
+    ]
+    TEST_UP_BLOCK.append(test_case)


class TestResBasicBlock(unittest.TestCase):
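
To make the derived sizes concrete, here is one TEST_CASE_RES_BASIC_BLOCK combination worked through by hand. This assumes get_padding returns the usual "same"-style padding (kernel_size - stride + 1) // 2 for these arguments, which matches the values the loop produces:

# kernel_size=3, stride=2, in_size=15 (one of the combinations above)
kernel_size, stride, in_size = 3, 2, 15
padding = (kernel_size - stride + 1) // 2  # assumed equivalent to get_padding(3, 2)
out_size = (in_size + 2 * padding - kernel_size) // stride + 1
assert (padding, out_size) == (1, 8)

The corresponding test case therefore checks that a (1, 16, 15, 15) input maps to a (1, 16, 8, 8) output in 2D.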
81 changes: 32 additions & 49 deletions tests/networks/blocks/test_patchembedding.py
@@ -21,58 +21,41 @@
from monai.networks import eval_mode
from monai.networks.blocks.patchembedding import PatchEmbed, PatchEmbeddingBlock
from monai.utils import optional_import
-from tests.test_utils import SkipIfBeforePyTorchVersion
+from tests.test_utils import SkipIfBeforePyTorchVersion, dict_product

einops, has_einops = optional_import("einops")

-TEST_CASE_PATCHEMBEDDINGBLOCK = []
-for dropout_rate in (0.5,):
-    for in_channels in [1, 4]:
-        for hidden_size in [96, 288]:
-            for img_size in [32, 64]:
-                for patch_size in [8, 16]:
-                    for num_heads in [8, 12]:
-                        for proj_type in ["conv", "perceptron"]:
-                            for pos_embed_type in ["none", "learnable", "sincos"]:
-                                # for classification in (False, True):  # TODO: add classification tests
-                                for nd in (2, 3):
-                                    test_case = [
-                                        {
-                                            "in_channels": in_channels,
-                                            "img_size": (img_size,) * nd,
-                                            "patch_size": (patch_size,) * nd,
-                                            "hidden_size": hidden_size,
-                                            "num_heads": num_heads,
-                                            "proj_type": proj_type,
-                                            "pos_embed_type": pos_embed_type,
-                                            "dropout_rate": dropout_rate,
-                                        },
-                                        (2, in_channels, *([img_size] * nd)),
-                                        (2, (img_size // patch_size) ** nd, hidden_size),
-                                    ]
-                                    if nd == 2:
-                                        test_case[0]["spatial_dims"] = 2  # type: ignore
-                                    TEST_CASE_PATCHEMBEDDINGBLOCK.append(test_case)
-
-TEST_CASE_PATCHEMBED = []
-for patch_size in [2]:
-    for in_chans in [1, 4]:
-        for img_size in [96]:
-            for embed_dim in [6, 12]:
-                for norm_layer in [nn.LayerNorm]:
-                    for nd in [2, 3]:
-                        test_case = [
-                            {
-                                "patch_size": (patch_size,) * nd,
-                                "in_chans": in_chans,
-                                "embed_dim": embed_dim,
-                                "norm_layer": norm_layer,
-                                "spatial_dims": nd,
-                            },
-                            (2, in_chans, *([img_size] * nd)),
-                            (2, embed_dim, *([img_size // patch_size] * nd)),
-                        ]
-                        TEST_CASE_PATCHEMBED.append(test_case)
+
+TEST_CASE_PATCHEMBEDDINGBLOCK = [
+    [
+        params,
+        (2, params["in_channels"], *([params["img_size"]] * params["spatial_dims"])),
+        (2, (params["img_size"] // params["patch_size"]) ** params["spatial_dims"], params["hidden_size"]),
+    ]
+    for params in dict_product(
+        dropout_rate=[0.5],
+        in_channels=[1, 4],
+        hidden_size=[96, 288],
+        img_size=[32, 64],
+        patch_size=[8, 16],
+        num_heads=[8, 12],
+        proj_type=["conv", "perceptron"],
+        pos_embed_type=["none", "learnable", "sincos"],
+        spatial_dims=[2, 3],
+    )
+]
+
+img_size = 96
+TEST_CASE_PATCHEMBED = [
+    [
+        params,
+        (2, params["in_chans"], *([img_size] * params["spatial_dims"])),
+        (2, params["embed_dim"], *([img_size // params["patch_size"]]) * params["spatial_dims"]),
+    ]
+    for params in dict_product(
+        patch_size=[2], in_chans=[1, 4], embed_dim=[6, 12], norm_layer=[nn.LayerNorm], spatial_dims=[2, 3]
+    )
+]


@SkipIfBeforePyTorchVersion((1, 11, 1))
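
Finally, a worked example for the patch-embedding cases above, to make the expected shapes concrete (one combination: img_size=32, patch_size=8, spatial_dims=2, in_channels=1, hidden_size=96):

img, patch, nd, hidden = 32, 8, 2, 96
num_patches = (img // patch) ** nd         # 4 ** 2 == 16 tokens
input_shape = (2, 1, img, img)             # (batch, in_channels, H, W)
expected_shape = (2, num_patches, hidden)  # (batch, num_patches, hidden_size)
assert expected_shape == (2, 16, 96)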