 # SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 # SPDX-License-Identifier: LicenseRef-NvidiaProprietary
-#
+#
 # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
 # property and proprietary rights in and to this material, related
 # documentation and any modifications thereto. Any use, reproduction,
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
-from contextlib import nullcontext
-from enum import Enum
 from typing import Callable, Dict, Optional, Type
-import logging
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
-from .cast_utils import CastToFloat, CastToFloatAll
+from .cast_utils import CastToFloat
+
 
 class LinearWithBiasSkip(nn.Module):
     def __init__(self, weight, bias, skip_bias_add):
@@ -45,7 +43,10 @@ def forward(self, x):
             return F.linear(x, self.weight), self.bias
         return F.linear(x, self.weight, self.bias), None
 
-def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01):
+
+def run_ts_and_compare(
+    ts_model, ts_input_list, ts_input_dict, output_example, check_tolerance=0.01
+):
     # Verify the model can be read, and is valid
     ts_out = ts_model(*ts_input_list, **ts_input_dict)
 
@@ -54,16 +55,20 @@ def run_ts_and_compare(ts_model, ts_input_list, ts_input_dict, output_example, c
         expected = output_example[i]
 
         if torch.is_tensor(expected):
-            tout = out.to('cpu')
+            tout = out.to("cpu")
             print(f"Checking output {i}, shape: {expected.shape}:\n")
             this_good = True
             try:
-                if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance):
+                if not torch.allclose(
+                    tout, expected.cpu(), rtol=check_tolerance, atol=check_tolerance
+                ):
                     this_good = False
             except Exception:  # there may be a size mismatch and it may be OK
                 this_good = False
             if not this_good:
-                print(f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}")
+                print(
+                    f"Results mismatch! PyTorch(expected):\n{expected}\nTorchScript:\n{tout}"
+                )
                 all_good = False
     return all_good
 
@@ -80,12 +85,19 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
             print(f"Checking output {i}, shape: {expected.shape}:\n")
             this_good = True
             try:
-                if not torch.allclose(tout, expected.cpu(), rtol=check_tolerance, atol=100 * check_tolerance):
+                if not torch.allclose(
+                    tout,
+                    expected.cpu(),
+                    rtol=check_tolerance,
+                    atol=100 * check_tolerance,
+                ):
                     this_good = False
             except Exception:  # there may be a size mismatch and it may be OK
                 this_good = False
             if not this_good:
-                print(f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}")
+                print(
+                    f"onnxruntime results mismatch! PyTorch(expected):\n{expected}\nONNXruntime:\n{tout}"
+                )
                 all_good = False
     return all_good
 
@@ -96,7 +108,10 @@ def run_ort_and_compare(sess, ort_input, output_example, check_tolerance=0.01):
     from apex.contrib.layer_norm.layer_norm import FastLayerNorm
     from apex.normalization.fused_layer_norm import FusedLayerNorm, MixedFusedLayerNorm
     from apex.transformer.functional.fused_softmax import FusedScaleMaskSoftmax
-    from apex.transformer.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
+    from apex.transformer.tensor_parallel.layers import (
+        ColumnParallelLinear,
+        RowParallelLinear,
+    )
 
     def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
         """
@@ -115,7 +130,9 @@ def replace_FusedLayerNorm(n: nn.Module) -> Optional[nn.LayerNorm]:
         else:
             return None
 
-        mod = nn.LayerNorm(shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype)
+        mod = nn.LayerNorm(
+            shape, eps=eps, elementwise_affine=affine, device=p.device, dtype=p.dtype
+        )
         n_state = n.state_dict()
         mod.load_state_dict(n_state)
         return mod
@@ -129,7 +146,9 @@ def replace_RowParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
             Equivalent LayerNorm module
         """
         if not isinstance(n, RowParallelLinear):
-            raise ValueError("This function can only change the RowParallelLinear module.")
+            raise ValueError(
+                "This function can only change the RowParallelLinear module."
+            )
 
         dev = next(n.parameters()).device
         mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(device=dev)
@@ -146,8 +165,12 @@ def replace_ParallelLinear(n: nn.Module) -> Optional[nn.Linear]:
         Returns:
             Equivalent Linear module
         """
-        if not (isinstance(n, ColumnParallelLinear) or isinstance(n, RowParallelLinear)):
-            raise ValueError("This function can only change the ColumnParallelLinear or RowParallelLinear module.")
+        if not (
+            isinstance(n, ColumnParallelLinear) or isinstance(n, RowParallelLinear)
+        ):
+            raise ValueError(
+                "This function can only change the ColumnParallelLinear or RowParallelLinear module."
+            )
 
         dev = next(n.parameters()).device
         mod = LinearWithBiasSkip(n.weight, n.bias, n.skip_bias_add).to(dev)
@@ -165,11 +188,19 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
             Equivalent LayerNorm module
         """
         if not isinstance(n, FusedScaleMaskSoftmax):
-            raise ValueError("This function can only change the FusedScaleMaskSoftmax module.")
+            raise ValueError(
+                "This function can only change the FusedScaleMaskSoftmax module."
+            )
 
         # disable the fusion only
         mod = FusedScaleMaskSoftmax(
-            n.input_in_fp16, n.input_in_bf16, n.attn_mask_type, False, n.mask_func, n.softmax_in_fp32, n.scale
+            n.input_in_fp16,
+            n.input_in_bf16,
+            n.attn_mask_type,
+            False,
+            n.mask_func,
+            n.softmax_in_fp32,
+            n.scale,
         )
 
         return mod
@@ -178,18 +209,20 @@ def replace_FusedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
         "FusedLayerNorm": replace_FusedLayerNorm,
         "MixedFusedLayerNorm": replace_FusedLayerNorm,
         "FastLayerNorm": replace_FusedLayerNorm,
-        "ESM1bLayerNorm": replace_FusedLayerNorm,
+        "ESM1bLayerNorm": replace_FusedLayerNorm,
         "RowParallelLinear": replace_ParallelLinear,
         "ColumnParallelLinear": replace_ParallelLinear,
         "FusedScaleMaskSoftmax": replace_FusedScaleMaskSoftmax,
     }
 
-except Exception as e:
+except Exception:
     default_Apex_replacements = {}
     apex_available = False
 
 
-def simple_replace(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
+def simple_replace(
+    BaseT: Type[nn.Module], DestT: Type[nn.Module]
+) -> Callable[[nn.Module], Optional[nn.Module]]:
     """
     Generic function generator to replace BaseT module with DestT. BaseT and DestT should have the same attributes. No weights are copied.
     Args:
@@ -218,18 +251,28 @@ def replace_MatchedScaleMaskSoftmax(n: nn.Module) -> Optional[nn.Linear]:
         exportable module
     """
     # including the import here to avoid circular imports
-    from nemo.collections.nlp.modules.common.megatron.fused_softmax import MatchedScaleMaskSoftmax
+    from nemo.collections.nlp.modules.common.megatron.fused_softmax import (
+        MatchedScaleMaskSoftmax,
+    )
 
     # disabling fusion for the MatchedScaleMaskSoftmax
     mod = MatchedScaleMaskSoftmax(
-        n.input_in_fp16, n.input_in_bf16, n.attn_mask_type, False, n.mask_func, n.softmax_in_fp32, n.scale
+        n.input_in_fp16,
+        n.input_in_bf16,
+        n.attn_mask_type,
+        False,
+        n.mask_func,
+        n.softmax_in_fp32,
+        n.scale,
     )
     return mod
 
 
-def wrap_module(BaseT: Type[nn.Module], DestT: Type[nn.Module]) -> Callable[[nn.Module], Optional[nn.Module]]:
+def wrap_module(
+    BaseT: Type[nn.Module], DestT: Type[nn.Module]
+) -> Callable[[nn.Module], Optional[nn.Module]]:
     """
-    Generic function generator to replace BaseT module with DestT wrapper. 
+    Generic function generator to replace BaseT module with DestT wrapper.
     Args:
         BaseT : module type to replace
         DestT : destination module type
@@ -256,14 +299,15 @@ def swap_modules(model: nn.Module, mapping: Dict[str, nn.Module]):
         expanded_path = path.split(".")
         parent_mod = model
         for sub_path in expanded_path[:-1]:
-            parent_mod = parent_mod._modules[sub_path]  # noqa
-        parent_mod._modules[expanded_path[-1]] = new_mod  # noqa
+            parent_mod = parent_mod._modules[sub_path]
+        parent_mod._modules[expanded_path[-1]] = new_mod
 
     return model
 
 
 def replace_modules(
-    model: nn.Module, expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]] = None
+    model: nn.Module,
+    expansions: Dict[str, Callable[[nn.Module], Optional[nn.Module]]] = None,
 ) -> nn.Module:
     """
     Top-level function to replace modules in model, specified by class name with a desired replacement.
@@ -308,7 +352,7 @@ def replace_for_export(model: nn.Module, do_cast: bool = False) -> nn.Module:
     if apex_available:
         print("Replacing Apex layers ...")
         replace_modules(model, default_Apex_replacements)
-
+
     if do_cast:
         print("Adding casts around norms...")
         cast_replacements = {
@@ -319,6 +363,6 @@ def replace_for_export(model: nn.Module, do_cast: bool = False) -> nn.Module:
             "InstanceNorm3d": wrap_module(nn.InstanceNorm3d, CastToFloat),
         }
         replace_modules(model, cast_replacements)
-
+
     # This one has to be the last
     replace_modules(model, script_replacements)
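
For context, and not part of the diff above: a minimal usage sketch of how the helpers touched in this commit are typically combined. `replace_modules`, `wrap_module`, and `CastToFloat` are the names defined in this file; the toy model below is hypothetical and only illustrates the call pattern.

# Illustrative sketch; assumes replace_modules, wrap_module, and CastToFloat
# are importable from this module.
import torch.nn as nn

# Hypothetical toy model used only for illustration.
model = nn.Sequential(nn.Linear(16, 16), nn.BatchNorm1d(16))

# Swap every BatchNorm1d for a CastToFloat-wrapped copy, mirroring the
# cast_replacements mapping built inside replace_for_export().
replace_modules(model, {"BatchNorm1d": wrap_module(nn.BatchNorm1d, CastToFloat)})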