
Commit 01ed8ab

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 818a548 commit 01ed8ab
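
The changes below are formatting-only fixes applied automatically by the pre-commit hooks. The repository's `.pre-commit-config.yaml` is not part of this diff, so the exact hook is an assumption, but the rewrapping and spacing match what a Black-style formatter produces. A minimal sketch of reproducing one of the fixes locally:

# Hedged sketch: assumes the formatting hook is Black (the actual
# .pre-commit-config.yaml is not shown in this commit).
import black

snippet = "time00=time.time()\n"
# Black normalizes spacing around "=", matching the scripts/export.py change below.
print(black.format_str(snippet, mode=black.Mode()))  # -> "time00 = time.time()"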

File tree

7 files changed: +322 −224 lines

data/README.md

Lines changed: 1 addition & 1 deletion
@@ -81,7 +81,7 @@ The output of this step is multiple JSON files, each file corresponds
 to one dataset.
 
 ##### 2. Add label_dict.json and label_mapping.json
-Add new class indexes to `label_dict.json` and the local to global mapping to `label_mapping.json`. 
+Add new class indexes to `label_dict.json` and the local to global mapping to `label_mapping.json`.
 
 ## SupverVoxel Generation
 1. Download the segment anything repo and download the ViT-H weights

scripts/export.py

Lines changed: 14 additions & 10 deletions
@@ -12,6 +12,7 @@
 import logging
 import os
 import sys
+import time
 from functools import partial
 
 import monai
@@ -32,7 +33,6 @@
 from .train import CONFIG
 from .utils.trans_utils import VistaPostTransform, get_largest_connected_component_point
 from .utils.trt_utils import ExportWrapper, TRTWrapper
-import time
 
 rearrange, _ = optional_import("einops", name="rearrange")
 sys.path.insert(0, os.path.abspath(os.path.dirname(__file__)))
@@ -131,16 +131,20 @@ def __init__(self, config_file="./configs/infer.yaml", **override):
         self.prev_mask = None
         self.batch_data = None
 
-        en_wrapper = ExportWrapper.wrap(self.model.image_encoder.encoder,
-                                        input_names = ['x'], output_names = ['x_out'])
+        en_wrapper = ExportWrapper.wrap(
+            self.model.image_encoder.encoder, input_names=["x"], output_names=["x_out"]
+        )
         self.model.image_encoder.encoder = TRTWrapper("Encoder", en_wrapper)
         self.model.image_encoder.encoder.load_engine()
 
-        cls_wrapper = ExportWrapper.wrap(self.model.class_head,
-                                         input_names = ['src', 'class_vector'], output_names = ['masks', 'class_embedding'])
+        cls_wrapper = ExportWrapper.wrap(
+            self.model.class_head,
+            input_names=["src", "class_vector"],
+            output_names=["masks", "class_embedding"],
+        )
         self.model.class_head = TRTWrapper("ClassHead", cls_wrapper)
         self.model.class_head.load_engine()
-
+
         return
 
     def clear_cache(self):
@@ -174,7 +178,7 @@ def infer(
             used together with prev_mask. If prev_mask is generated by N points, point_start should be N+1 to save
             time and avoid repeated inference. This is by default disabled.
         """
-        time00=time.time()
+        time00 = time.time()
         self.model.eval()
         if not isinstance(image_file, dict):
             image_file = {"image": image_file}
@@ -277,7 +281,7 @@ def infer(
 
     @torch.no_grad()
     def infer_everything(self, image_file, label_prompt=EVERYTHING_PROMPT, rank=0):
-        time00=time.time()
+        time00 = time.time()
         self.model.eval()
         device = f"cuda:{rank}"
         if not isinstance(image_file, dict):
@@ -344,8 +348,8 @@ def batch_infer_everything(self, datalist=str, basedir=str):
 
 if __name__ == "__main__":
     try:
-        #import torch_onnx
-        #torch_onnx.patch_torch(error_report=True)
+        # import torch_onnx
+        # torch_onnx.patch_torch(error_report=True)
         print("patch succeeded")
     except Exception:
         pass

scripts/utils/cast_utils.py

Lines changed: 8 additions & 3 deletions
@@ -1,6 +1,6 @@
 # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
-# 
+#
 # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 # property and proprietary rights in and to this material, related
 # documentation and any modifications thereto. Any use, reproduction,
@@ -26,6 +26,7 @@
 
 import torch
 
+
 def avoid_bfloat16_autocast_context():
     """
     If the current autocast context is bfloat16,
@@ -70,7 +71,9 @@ def cast_all(x, from_dtype=torch.float16, to_dtype=torch.float32):
             new_dict[k] = cast_all(x[k], from_dtype=from_dtype, to_dtype=to_dtype)
         return new_dict
     elif isinstance(x, tuple):
-        return tuple(cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x)
+        return tuple(
+            cast_all(y, from_dtype=from_dtype, to_dtype=to_dtype) for y in x
+        )
 
 
 class CastToFloat(torch.nn.Module):
@@ -92,5 +95,7 @@ def __init__(self, mod):
     def forward(self, *args):
         from_dtype = args[0].dtype
         with torch.cuda.amp.autocast(enabled=False):
-            ret = self.mod.forward(*cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32))
+            ret = self.mod.forward(
+                *cast_all(args, from_dtype=from_dtype, to_dtype=torch.float32)
+            )
         return cast_all(ret, from_dtype=torch.float32, to_dtype=from_dtype)
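
For reference, the forward pass reformatted above implements a cast-around-autocast pattern: cast the inputs to float32, run the wrapped module with autocast disabled, then cast the result back to the caller's dtype. A minimal self-contained sketch of that idea in plain PyTorch (FloatWrapper is a hypothetical stand-in, not the repository's CastToFloat):

import torch
import torch.nn as nn


class FloatWrapper(nn.Module):
    """Hypothetical stand-in illustrating the cast-around-autocast idea."""

    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, x):
        from_dtype = x.dtype
        # Run the wrapped module in float32 with autocast disabled,
        # then cast the output back to the caller's dtype.
        with torch.cuda.amp.autocast(enabled=False):
            out = self.mod(x.to(torch.float32))
        return out.to(from_dtype)


# Example: keep an InstanceNorm3d in float32 even when the surrounding
# model runs under float16 autocast.
norm = FloatWrapper(nn.InstanceNorm3d(8))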

scripts/utils/export_utils.py

Lines changed: 75 additions & 31 deletions
@@ -1,6 +1,6 @@
 # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
-# 
+#
 # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 # property and proprietary rights in and to this material, related
 # documentation and any modifications thereto. Any use, reproduction,
@@ -22,16 +22,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
-from contextlib import nullcontext
-from enum import Enum
 from typing import Callable, Dict, Optional, Type
-import logging
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .cast_utils import CastToFloat, CastToFloatAll
+from .cast_utils import CastToFloat
+
 
 class LinearWithBiasSkip(nn.Module):
     def __init__(self, weight, bias, skip_bias_add):
@@ -45,7 +43,10 @@ def forward(self, x):
             return F.linear(x, self.weight), self.bias
         return F.linear(x, self.weight, self.bias), None
 
-def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01):
+
+def run_ts_and_compare(
+    ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01
+):
     # Verify the model can be read, and is valid
     ts_out = ts_model(*ts_input_list, **ts_input_dict)
 
@@ -54,16 +55,20 @@ def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, c
         expected = output_example[i]
 
         if torch.is_tensor(expected):
-            tout = out.to('cpu')
+            tout = out.to("cpu")
             print(f"Checking output {i}, shape: {expected.shape}:\n")
             this_good = True
             try:
-                if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
+                if not torch.allclose(
+                    tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance
+                ):
                     this_good = False
             except Exception:  # there may ne size mismatch and it may be OK
                 this_good = False
             if not this_good:
-                print(f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}")
+                print(
+                    f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}"
+                )
                 all_good = False
     return all_good
 
@@ -80,12 +85,19 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
             print(f"Checking output {i}, shape: {expected.shape}:\n")
             this_good = True
             try:
-                if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
+                if not torch.allclose(
+                    tout,
+                    expected.cpu(),
+                    rtol=check_tolerance,
+                    atol=100 * check_tolerance,
+                ):
                     this_good = False
             except Exception:  # there may ne size mismatch and it may be OK
                 this_good = False
             if not this_good:
-                print(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}")
+                print(
+                    f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}"
+                )
                 all_good = False
     return all_good
 
@@ -96,7 +108,10 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
     from apex.contrib.layer_norm.layer_norm import FastLayerNorm
     from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
     from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
-    from apex.transformer.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
+    from apex.transformer.tensor_parallel.layers import (
+        ColumnParallelLinear,
+        RowParallelLinear,
+    )
 
     def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
         """
@@ -115,7 +130,9 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
         else:
             return None
 
-        mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
+        mod = nn.LayerNorm(
+            shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype
+        )
         n_state = n.state_dict()
         mod.load_state_dict(n_state)
         return mod
@@ -129,7 +146,9 @@ def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
             Equivalent LayerNorm module
         """
         if not isinstance(n, RowParallelLinear):
-            raise ValueError("This function can only change the RowParallelLinear module.")
+            raise ValueError(
+                "This function can only change the RowParallelLinear module."
+            )
 
         dev = next(n.parameters()).device
         mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(device=dev)
@@ -146,8 +165,12 @@ def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
         Returns:
             Equivalent Linear module
         """
-        if not (isinstance(n, ColumnParallelLinear) or isinstance(n, RowParallelLinear)):
-            raise ValueError("This function can only change the ColumnParallelLinear or RowParallelLinear module.")
+        if not (
+            isinstance(n, ColumnParallelLinear) or isinstance(n, RowParallelLinear)
+        ):
+            raise ValueError(
+                "This function can only change the ColumnParallelLinear or RowParallelLinear module."
+            )
 
         dev = next(n.parameters()).device
         mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev)
@@ -165,11 +188,19 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
             Equivalent LayerNorm module
         """
        if not isinstance(n, FusedScaleMaskSoftmax):
-            raise ValueError("This function can only change the FusedScaleMaskSoftmax module.")
+            raise ValueError(
+                "This function can only change the FusedScaleMaskSoftmax module."
+            )
 
         # disable the fusion only
         mod = FusedScaleMaskSoftmax(
-            n.input_in_fp16, n.input_in_bf16, n.attn_mask_type, False, n.mask_func, n.softmax_in_fp32, n.scale
+            n.input_in_fp16,
+            n.input_in_bf16,
+            n.attn_mask_type,
+            False,
+            n.mask_func,
+            n.softmax_in_fp32,
+            n.scale,
         )
 
         return mod
@@ -178,18 +209,20 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
         "FusedLayerNorm": replace_FusedLayerNorm,
         "MixedFusedLayerNorm": replace_FusedLayerNorm,
         "FastLayerNorm": replace_FusedLayerNorm,
-        "ESM1bLayerNorm" : replace_FusedLayerNorm,
+        "ESM1bLayerNorm": replace_FusedLayerNorm,
         "RowParallelLinear": replace_ParallelLinear,
         "ColumnParallelLinear": replace_ParallelLinear,
         "FusedScaleMaskSoftmax": replace_FusedScaleMaskSoftmax,
     }
 
-except Exception as e:
+except Exception:
     default_Apex_replacements = {}
     apex_available = False
 
 
-def simple_replace(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
+def simple_replace(
+    BaseT: Type[nn.Module], DestT: Type[nn.Module]
+) -> Callable[[nn.Module], Optional[nn.Module]]:
     """
     Generic function generator to replace BaseT module with DestT. BaseT and DestT should have same atrributes. No weights are copied.
     Args:
@@ -218,18 +251,28 @@ def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
         exportable module
     """
     # including the import here to avoid circular imports
-    from nemo.collections.nlp.modules.common.megatron.fused_softmax import MatchedScaleMaskSoftmax
+    from nemo.collections.nlp.modules.common.megatron.fused_softmax import (
+        MatchedScaleMaskSoftmax,
+    )
 
     # disabling fusion for the MatchedScaleMaskSoftmax
     mod = MatchedScaleMaskSoftmax(
-        n.input_in_fp16, n.input_in_bf16, n.attn_mask_type, False, n.mask_func, n.softmax_in_fp32, n.scale
+        n.input_in_fp16,
+        n.input_in_bf16,
+        n.attn_mask_type,
+        False,
+        n.mask_func,
+        n.softmax_in_fp32,
+        n.scale,
     )
     return mod
 
 
-def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
+def wrap_module(
+    BaseT: Type[nn.Module], DestT: Type[nn.Module]
+) -> Callable[[nn.Module], Optional[nn.Module]]:
     """
-    Generic function generator to replace BaseT module with DestT wrapper. 
+    Generic function generator to replace BaseT module with DestT wrapper.
     Args:
         BaseT : module type to replace
         DestT : destination module type
@@ -256,14 +299,15 @@ def swap_modules(model: nn.Module, mapping: Dict[str, nn.Module]):
         expanded_path = path.split(".")
         parent_mod = model
         for sub_path in expanded_path[:-1]:
-            parent_mod = parent_mod._modules[sub_path]  # noqa
-        parent_mod._modules[expanded_path[-1]] = new_mod  # noqa
+            parent_mod = parent_mod._modules[sub_path]
+        parent_mod._modules[expanded_path[-1]] = new_mod
 
     return model
 
 
 def replace_modules(
-    model: nn.Module, expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]] = None
+    model: nn.Module,
+    expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]] = None,
 ) -> nn.Module:
     """
     Top-level function to replace modules in model, specified by class name with a desired replacement.
@@ -308,7 +352,7 @@ def replace_for_export(model: nn.Module, do_cast: bool = False) -> nn.Module:
     if apex_available:
         print("Replacing Apex layers ...")
         replace_modules(model, default_Apex_replacements)
-
+
     if do_cast:
         print("Adding casts around norms...")
         cast_replacements = {
@@ -319,6 +363,6 @@ def replace_for_export(model: nn.Module, do_cast: bool = False) -> nn.Module:
             "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat),
         }
         replace_modules(model, cast_replacements)
-
+
     # This one has to be the last
     replace_modules(model, script_replacements)
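
The swap_modules/replace_modules/wrap_module functions reformatted above swap submodules by class name through model._modules, which is how replace_for_export installs the CastToFloat wrappers around norm layers. A self-contained sketch of that mechanism (swap_by_class_name is a hypothetical stand-in, not the repository's function):

import torch.nn as nn


def swap_by_class_name(model, expansions):
    # Snapshot the named modules, then swap matching leaves via _modules,
    # mirroring the swap_modules/replace_modules pair above.
    for name, module in list(model.named_modules()):
        if name == "":
            continue  # skip the root module itself
        fn = expansions.get(type(module).__name__)
        if fn is None:
            continue
        new_mod = fn(module)
        if new_mod is None:
            continue
        parent = model
        *path, leaf = name.split(".")
        for part in path:
            parent = parent._modules[part]
        parent._modules[leaf] = new_mod
    return model


# Example: replace every InstanceNorm3d (matched by class name) with a GroupNorm.
model = nn.Sequential(nn.Conv3d(1, 8, 3), nn.InstanceNorm3d(8), nn.ReLU())
swap_by_class_name(model, {"InstanceNorm3d": lambda m: nn.GroupNorm(1, m.num_features)})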
