Skip to content

Commit 392c5c1

Browse files
authored
Support regular expression in the mapping arg of copy_model_state (#6917)
Part of #6552. ### Description After PR #6835, we have added `copy_model_args` in the `load` API which can help us update the state_dict flexibly. https://github.com/KumoLiu/MONAI/blob/93a149a611b66153cf804b31a7b36a939e2e593a/monai/bundle/scripts.py#L397 Given this [issue](#6552), we need to be able to filter the model's weights flexibly. In `copy_model_state`, we already have a "mapping" arg, the filter will be more flexible if we can support regular expression in the mapping. This PR mainly added the support for regular expression for "mapping" arg. In the [example](#6552 (comment)) in this [issue](#6552), after this PR, we can do something like: ``` exclude_vars = "encoder.mask_token|encoder.norm.weight|encoder.norm.bias|out.conv.conv.weight|out.conv.conv.bias" mapping={"encoder.layers(.*).0.0.": "swinViT.layers(.*).0."} dst_dict, updated_keys, unchanged_keys = copy_model_state( model, ssl_weights, exclude_vars=exclude_vars, mapping=mapping ) ``` Additionally, based on the comments of Eric [here](#6552 (comment)), I totally agree, we could add a handler to make the pipeline easier to implement, but perhaps this task is no need to set as a "BundleTodo" for MONAIv1.3 but as an enhancement for MONAI near future. What do you think? @ericspod @wyli @Nic-Ma ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu <[email protected]>
1 parent 66f42c1 commit 392c5c1

File tree

4 files changed

+96
-2
lines changed

4 files changed

+96
-2
lines changed

monai/networks/nets/swin_unetr.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,3 +1055,52 @@ def forward(self, x, normalize=True):
10551055
x4 = self.layers4[0](x3.contiguous())
10561056
x4_out = self.proj_out(x4, normalize)
10571057
return [x0_out, x1_out, x2_out, x3_out, x4_out]
1058+
1059+
1060+
def filter_swinunetr(key, value):
1061+
"""
1062+
A filter function used to filter the pretrained weights from [1], then the weights can be loaded into MONAI SwinUNETR Model.
1063+
This function is typically used with `monai.networks.copy_model_state`
1064+
[1] "Valanarasu JM et al., Disruptive Autoencoders: Leveraging Low-level features for 3D Medical Image Pre-training
1065+
<https://arxiv.org/abs/2307.16896>"
1066+
1067+
Args:
1068+
key: the key in the source state dict used for the update.
1069+
value: the value in the source state dict used for the update.
1070+
1071+
Examples::
1072+
1073+
import torch
1074+
from monai.apps import download_url
1075+
from monai.networks.utils import copy_model_state
1076+
from monai.networks.nets.swin_unetr import SwinUNETR, filter_swinunetr
1077+
1078+
model = SwinUNETR(img_size=(96, 96, 96), in_channels=1, out_channels=3, feature_size=48)
1079+
resource = (
1080+
"https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth"
1081+
)
1082+
ssl_weights_path = "./ssl_pretrained_weights.pth"
1083+
download_url(resource, ssl_weights_path)
1084+
ssl_weights = torch.load(ssl_weights_path)["model"]
1085+
1086+
dst_dict, loaded, not_loaded = copy_model_state(model, ssl_weights, filter_func=filter_swinunetr)
1087+
1088+
"""
1089+
if key in [
1090+
"encoder.mask_token",
1091+
"encoder.norm.weight",
1092+
"encoder.norm.bias",
1093+
"out.conv.conv.weight",
1094+
"out.conv.conv.bias",
1095+
]:
1096+
return None
1097+
1098+
if key[:8] == "encoder.":
1099+
if key[8:19] == "patch_embed":
1100+
new_key = "swinViT." + key[8:]
1101+
else:
1102+
new_key = "swinViT." + key[8:18] + key[20:]
1103+
1104+
return new_key, value
1105+
else:
1106+
return None

monai/networks/utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,7 @@ def copy_model_state(
478478
mapping=None,
479479
exclude_vars=None,
480480
inplace=True,
481+
filter_func=None,
481482
):
482483
"""
483484
Compute a module state_dict, of which the keys are the same as `dst`. The values of `dst` are overwritten
@@ -490,7 +491,7 @@ def copy_model_state(
490491
491492
Args:
492493
dst: a pytorch module or state dict to be updated.
493-
src: a pytorch module or state dist used to get the values used for the update.
494+
src: a pytorch module or state dict used to get the values used for the update.
494495
dst_prefix: `dst` key prefix, so that `dst[dst_prefix + src_key]`
495496
will be assigned to the value of `src[src_key]`.
496497
mapping: a `{"src_key": "dst_key"}` dict, indicating that `dst[dst_prefix + dst_key]`
@@ -499,6 +500,8 @@ def copy_model_state(
499500
so that their values are not overwritten by `src`.
500501
inplace: whether to set the `dst` module with the updated `state_dict` via `load_state_dict`.
501502
This option is only available when `dst` is a `torch.nn.Module`.
503+
filter_func: a filter function used to filter the weights to be loaded.
504+
See 'filter_swinunetr' in "monai.networks.nets.swin_unetr.py".
502505
503506
Examples:
504507
.. code-block:: python
@@ -536,6 +539,12 @@ def copy_model_state(
536539
warnings.warn(f"Param. shape changed from {dst_dict[dst_key].shape} to {src_dict[s].shape}.")
537540
dst_dict[dst_key] = src_dict[s]
538541
updated_keys.append(dst_key)
542+
if filter_func is not None:
543+
for key, value in src_dict.items():
544+
new_pair = filter_func(key, value)
545+
if new_pair is not None and new_pair[0] not in to_skip:
546+
dst_dict[new_pair[0]] = new_pair[1]
547+
updated_keys.append(new_pair[0])
539548

540549
updated_keys = sorted(set(updated_keys))
541550
unchanged_keys = sorted(set(all_keys).difference(updated_keys))

tests/test_swin_unetr.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,20 @@
1111

1212
from __future__ import annotations
1313

14+
import os
15+
import tempfile
1416
import unittest
1517
from unittest import skipUnless
1618

1719
import torch
1820
from parameterized import parameterized
1921

22+
from monai.apps import download_url
2023
from monai.networks import eval_mode
21-
from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR
24+
from monai.networks.nets.swin_unetr import PatchMerging, PatchMergingV2, SwinUNETR, filter_swinunetr
25+
from monai.networks.utils import copy_model_state
2226
from monai.utils import optional_import
27+
from tests.utils import assert_allclose, skip_if_downloading_fails, skip_if_quick, testing_data_config
2328

2429
einops, has_einops = optional_import("einops")
2530

@@ -51,6 +56,14 @@
5156
case_idx += 1
5257
TEST_CASE_SWIN_UNETR.append(test_case)
5358

59+
TEST_CASE_FILTER = [
60+
[
61+
{"img_size": (96, 96, 96), "in_channels": 1, "out_channels": 14, "feature_size": 48, "use_checkpoint": True},
62+
"swinViT.layers1.0.blocks.0.norm1.weight",
63+
torch.tensor([0.9473, 0.9343, 0.8566, 0.8487, 0.8065, 0.7779, 0.6333, 0.5555]),
64+
]
65+
]
66+
5467

5568
class TestSWINUNETR(unittest.TestCase):
5669
@parameterized.expand(TEST_CASE_SWIN_UNETR)
@@ -93,6 +106,24 @@ def test_patch_merging(self):
93106
t = PatchMerging(dim)(torch.zeros((1, 21, 20, 20, dim)))
94107
self.assertEqual(t.shape, torch.Size([1, 11, 10, 10, 20]))
95108

109+
@parameterized.expand(TEST_CASE_FILTER)
110+
@skip_if_quick
111+
def test_filter_swinunetr(self, input_param, key, value):
112+
with skip_if_downloading_fails():
113+
with tempfile.TemporaryDirectory() as tempdir:
114+
file_name = "ssl_pretrained_weights.pth"
115+
data_spec = testing_data_config("models", f"{file_name.split('.', 1)[0]}")
116+
weight_path = os.path.join(tempdir, file_name)
117+
download_url(
118+
data_spec["url"], weight_path, hash_val=data_spec["hash_val"], hash_type=data_spec["hash_type"]
119+
)
120+
121+
ssl_weight = torch.load(weight_path)["model"]
122+
net = SwinUNETR(**input_param)
123+
dst_dict, loaded, not_loaded = copy_model_state(net, ssl_weight, filter_func=filter_swinunetr)
124+
assert_allclose(dst_dict[key][:8], value, atol=1e-4, rtol=1e-4, type_test=False)
125+
self.assertTrue(len(loaded) == 157 and len(not_loaded) == 2)
126+
96127

97128
if __name__ == "__main__":
98129
unittest.main()

tests/testing_data/data_config.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,11 @@
128128
"url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/se_resnext50_32x4d-a260b3a4.pth",
129129
"hash_type": "sha256",
130130
"hash_val": "a260b3a40f82dfe37c58d26a612bcf7bef0d27c6fed096226b0e4e9fb364168e"
131+
},
132+
"ssl_pretrained_weights": {
133+
"url": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/ssl_pretrained_weights.pth",
134+
"hash_type": "sha256",
135+
"hash_val": "c3564f40a6a051d3753a6d8fae5cc8eaf21ce8d82a9a3baf80748d15664055e8"
131136
}
132137
},
133138
"configs": {

0 commit comments

Comments
 (0)