
[float8] add float8 rowwise MoE prototype #1245


Merged 8 commits on Jun 27, 2025
25 changes: 25 additions & 0 deletions torchtitan/components/quantization/float8.py
@@ -53,6 +53,7 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):

self.enabled = True
self.filter_fqns = float8_config.filter_fqns
self.moe_fqns = float8_config.moe_fqns_prototype

if float8_config.recipe_name is not None:
assert (
@@ -114,6 +115,30 @@ def convert(self, model: nn.Module):
f"{self.config.enable_fsdp_float8_all_gather}"
)

# Mutates the model inplace replacing instances of nn.Parameter with ScaledGroupedMMTensor,
# to perform dynamic float8 rowwise quantization + scaled grouped GEMMs for the target MoE FQNs.
if self.moe_fqns:
from torchao.quantization.quant_api import quantize_

try:
from torchao.prototype.moe_training.conversion_utils import (
MoETrainingConfig,
)
except ImportError as e:
raise ImportError(
"torchao installation does not have MoE training support. Please install torchao nightly build."
) from e

def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
for target_fqn in self.moe_fqns:
if target_fqn in cur_fqn:
return True
return False

config = MoETrainingConfig()
quantize_(model, config=config, filter_fn=moe_module_filter_fn)
logger.info("Converted MoE to float8")

def post_optimizer_hook(self, model: nn.Module | list[nn.Module]):
if not self.enabled:
return
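For reference, a minimal standalone sketch of the conversion flow added in convert() above, assuming a torchao nightly build that ships the prototype MoE training support; the model and target FQN list here are placeholders:

import torch.nn as nn

from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig
from torchao.quantization.quant_api import quantize_

moe_fqns = ["experts"]  # substrings matched against each module's fully qualified name

def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
    # Convert a module only if any target FQN substring appears in its full name.
    return any(target_fqn in cur_fqn for target_fqn in moe_fqns)

# model = ...  # hypothetical MoE model whose expert modules match the FQNs above
# quantize_(model, config=MoETrainingConfig(), filter_fn=moe_module_filter_fn)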
7 changes: 7 additions & 0 deletions torchtitan/config_manager.py
@@ -465,6 +465,13 @@ class Float8:
Not compatible with torch.compile.
"""

moe_fqns_prototype: list[str] | str = field(default_factory=list)
Contributor:

no need to add "prototype" to config name?

Suggested change:
- moe_fqns_prototype: list[str] | str = field(default_factory=list)
+ moe_fqns: list[str] | str = field(default_factory=list)

Contributor Author @danielvegamyhre, Jun 10, 2025:

@vkuzo requested "prototype" be in the field name here. Unless I misunderstood the suggestion?

Alternatively we could omit "prototype" from the field name and just make sure the docstring/help text is very clear it is a prototype feature with limitations.

For context, I don't plan to land this until at least FSDP is supported (ideally TP as well).

Contributor:

I'm OK either way then. Also, since this is an experiment folder, everything could be experimental.

"""
Comma-separated list of fully qualified names of MoE modules to apply float8 rowwise training to.
This is a prototype feature that requires the torchao nightly build.
Example: --float8.moe_fqns_prototype="experts"
"""


@dataclass
class MX:
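As the docstring above notes, the field accepts either a list or a comma-separated string. A hedged sketch of how such a list[str] | str value could be normalized before use; this helper is illustrative and not taken from torchtitan's config_manager:

def normalize_fqns(value: list[str] | str) -> list[str]:
    # Hypothetical helper: split the comma-separated string form, pass lists through.
    if isinstance(value, str):
        return [fqn.strip() for fqn in value.split(",") if fqn.strip()]
    return list(value)

assert normalize_fqns("experts,shared_expert") == ["experts", "shared_expert"]
assert normalize_fqns(["experts"]) == ["experts"]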
@@ -69,3 +69,4 @@ selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac bas
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output", "router.gate"]
moe_fqns = ["experts"]
Contributor:

do you want to capture the shared expert? If so may need to use "expert" instead of "experts"
https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/llama4/model/moe.py#L204

Contributor:

If this is well-tested, let's put it into the other toml configs as well.

Contributor Author:

Not yet, this is intentional - the routed experts work with FSDP and TP, but the shared expert only works with FSDP right now. Still debugging an issue related to shared expert + TP.
