Commit 261c66d

use quantize_ for moe
1 parent: a992026 · commit: 261c66d

File tree

2 files changed: +6 −3 lines

torchtitan/components/quantization/float8.py

Lines changed: 5 additions & 3 deletions
@@ -118,17 +118,19 @@ def convert(self, model: nn.Module):
         # Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor,
         # to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs.
         if self.moe_fqns:
+            from torchao.quantization.quant_api import quantize_
             from torchao.prototype.scaled_grouped_mm.conversion_utils import (
-                convert_moe_to_float8_training,
+                MoETrainingConfig,
             )
 
             def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
                 for target_fqn in self.moe_fqns:
                     if target_fqn in cur_fqn:
                         return True
                 return False
-
-            convert_moe_to_float8_training(model, module_filter_fn=moe_module_filter_fn)
+
+            config = MoETrainingConfig(module_filter_fn=moe_module_filter_fn)
+            quantize_(model, config=config)
             logger.info("Converted MoE to float8")
 
     def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
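Context for the change above: the old helper convert_moe_to_float8_training took the filter function directly, while the new path packs it into a MoETrainingConfig and hands that to torchao's generic quantize_ entry point. The filter itself is a plain FQN-substring match. A minimal sketch of that matching behavior, with hypothetical module names and an assumed moe_fqns value (torchao is not needed for the illustration):

import torch.nn as nn

moe_fqns = ["experts"]  # assumed value; torchtitan reads this from the job config

def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    # Same substring match as in convert() above: select a module when any
    # configured MoE FQN fragment appears in its fully qualified name.
    return any(target_fqn in cur_fqn for target_fqn in moe_fqns)

model = nn.ModuleDict({
    "attention": nn.Linear(8, 8),
    "experts": nn.ModuleList([nn.Linear(8, 8), nn.Linear(8, 8)]),
})

for fqn, mod in model.named_modules():
    if moe_module_filter_fn(mod, fqn):
        print("would quantize:", fqn)  # matches "experts", "experts.0", "experts.1"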

torchtitan/train.py

Lines changed: 1 addition & 0 deletions
@@ -228,6 +228,7 @@ def __init__(self, job_config: JobConfig):
             model.to_empty(device=init_device)
             with torch.no_grad():
                 model.init_weights(buffer_device=buffer_device)
+                model = model.to(torch.bfloat16)
             model.train()
 
         self.model_parts = [model]
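The train.py change casts the freshly initialized model to bfloat16 before training begins. A minimal sketch of what that one line does, using a standalone toy module rather than the torchtitan model:

import torch
import torch.nn as nn

model = nn.Linear(4, 4)
print(model.weight.dtype)          # torch.float32 (PyTorch's default init dtype)
model = model.to(torch.bfloat16)   # converts all parameters and buffers
print(model.weight.dtype)          # torch.bfloat16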
