2 files changed: +6 -3 lines

@@ -118,17 +118,19 @@ def convert(self, model: nn.Module):
         # Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor,
         # to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs.
         if self.moe_fqns:
+            from torchao.quantization.quant_api import quantize_
             from torchao.prototype.scaled_grouped_mm.conversion_utils import (
-                convert_moe_to_float8_training,
+                MoETrainingConfig,
             )

             def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
                 for target_fqn in self.moe_fqns:
                     if target_fqn in cur_fqn:
                         return True
                 return False
-
-            convert_moe_to_float8_training(model, module_filter_fn=moe_module_filter_fn)
+
+            config = MoETrainingConfig(module_filter_fn=moe_module_filter_fn)
+            quantize_(model, config=config)
             logger.info("Converted MoE to float8")

     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
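For reference, here is a minimal standalone sketch of the new call pattern. The `ToyMoE` class and the `"experts"` FQN fragment are hypothetical stand-ins for a real MoE model, and whether the prototype config actually converts a plain `nn.Linear` depends on torchao internals; only the `MoETrainingConfig` + `quantize_` shape is taken from this diff:

```python
import torch.nn as nn

from torchao.quantization.quant_api import quantize_
from torchao.prototype.scaled_grouped_mm.conversion_utils import (
    MoETrainingConfig,
)


# Hypothetical stand-in for a real MoE transformer; only the FQN layout matters here.
class ToyMoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.experts = nn.Linear(64, 64)  # matched by the "experts" fragment below
        self.head = nn.Linear(64, 8)  # left unconverted

    def forward(self, x):
        return self.head(self.experts(x))


moe_fqns = ["experts"]  # target FQN fragments, mirroring self.moe_fqns in the diff


def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    # Convert only modules whose fully qualified name contains a target fragment.
    return any(target_fqn in cur_fqn for target_fqn in moe_fqns)


model = ToyMoE()
# New entry point: wrap the filter in a config and hand it to the generic
# quantize_ API, replacing the old convert_moe_to_float8_training call.
config = MoETrainingConfig(module_filter_fn=moe_module_filter_fn)
quantize_(model, config=config)
```

Routing the conversion through the generic `quantize_` entry point keeps MoE float8 training on the same config-driven path as torchao's other quantization workflows, which appears to be the point of this change.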
@@ -228,6 +228,7 @@ def __init__(self, job_config: JobConfig):
             model.to_empty(device=init_device)
             with torch.no_grad():
                 model.init_weights(buffer_device=buffer_device)
+            model = model.to(torch.bfloat16)
             model.train()

         self.model_parts = [model]
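The second hunk casts the freshly initialized model to bfloat16 before training begins. A rough sketch of that init sequence, using a plain `nn.Sequential` as a hypothetical stand-in (the real model defines `init_weights`, which plain modules lack, so the sketch re-runs each layer's default init instead):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the trainer's model; the real one defines init_weights().
with torch.device("meta"):
    model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 8))

init_device = "cpu"  # assumption; the trainer derives this from the job config
model.to_empty(device=init_device)  # allocate real storage for the meta tensors

with torch.no_grad():
    # Stand-in for model.init_weights(buffer_device=...): re-run each layer's
    # default initialization now that real storage exists.
    for module in model.modules():
        if isinstance(module, nn.Linear):
            module.reset_parameters()

# The line this diff adds: cast parameters and buffers to bfloat16 before training.
model = model.to(torch.bfloat16)
model.train()
```

Note that `to_empty` only allocates storage for the meta tensors; the values are garbage until initialization runs, which is why the cast to bfloat16 comes after `init_weights`.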