Skip to content

Commit 96a71cd

Browse files
add float8 moe prototype
1 parent 3381277 commit 96a71cd

File tree

3 files changed

+24
-0
lines changed

3 files changed

+24
-0
lines changed

torchtitan/components/quantization/float8.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
4949

5050
self.enabled = True
5151
self.filter_fqns = float8_config.filter_fqns
52+
self.moe_fqns = float8_config.moe_fqns
5253

5354
if float8_config.recipe_name is not None:
5455
assert (
@@ -109,6 +110,22 @@ def convert(self, model: nn.Module):
109110
f"{self.config.enable_fsdp_float8_all_gather}"
110111
)
111112

113+
# Mutates the model in-place, replacing instances of nn.Parameter with ScaledGroupedMMTensor,
114+
# to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs.
115+
if self.moe_fqns:
116+
from torchao.prototype.scaled_grouped_mm.conversion_utils import (
117+
convert_moe_to_float8_training,
118+
)
119+
120+
def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
121+
for target_fqn in self.moe_fqns:
122+
if target_fqn in cur_fqn:
123+
return True
124+
return False
125+
126+
convert_moe_to_float8_training(model, module_filter_fn=moe_module_filter_fn)
127+
logger.info("Converted MoE to float8")
128+
112129
def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
113130
if not self.enabled:
114131
return

torchtitan/config_manager.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,12 @@ class Float8:
450450
Example: --float8.filter_fqns "attention.wq,attention.wk,attention.wv,output"
451451
"""
452452

453+
moe_fqns: list[str] | str = field(default_factory=list)
454+
"""
455+
Comma-separated list of fully qualified names of MoE modules to apply float8 rowwise training to.
456+
Example: --float8.moe_fqns="experts"
457+
"""
458+
453459

454460
@dataclass
455461
class MX:

torchtitan/experiments/llama4/train_configs/debug_model.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,3 +69,4 @@ selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac bas
6969
enable_fsdp_float8_all_gather = false
7070
precompute_float8_dynamic_scale_for_fsdp = false
7171
filter_fqns = ["output", "router.gate"]
72+
moe_fqns = []

0 commit comments

Comments (0)