
[float8] add float8 rowwise MoE prototype #1245


Open · wants to merge 8 commits into main

Conversation

@danielvegamyhre (Contributor) commented May 30, 2025

Summary

  • Adds a --float8.moe_fqns_prototype="..." option to the float8 training API.
  • The option accepts a comma-separated list of FQNs to which MoE float8 training conversion is applied.
  • quantize_ with the MoETrainingConfig will recursively swap nn.Parameter data tensors to a tensor subclass that overrides grouped_mm with the dynamic quant + scaled grouped mm prototype (see the sketch after this list). Context: see the implementation of GroupedExperts here.
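To make the conversion flow concrete, here is a minimal sketch of how the FQN option could drive the conversion, assuming substring matching of module FQNs. quantize_ and MoETrainingConfig are the APIs named above; the import path for MoETrainingConfig and the names apply_moe_float8_training / module_filter_fn are assumptions for illustration, not the PR's actual implementation.

```python
# Sketch only: assumes substring-based FQN matching and an assumed import
# path for the torchao MoE training prototype config.
import torch.nn as nn
from torchao.quantization import quantize_
from torchao.prototype.moe_training.conversion_utils import MoETrainingConfig  # assumed path


def apply_moe_float8_training(model: nn.Module, moe_fqns: list[str]) -> None:
    """Convert MoE modules whose FQN contains one of the configured fragments."""

    def module_filter_fn(module: nn.Module, fqn: str) -> bool:
        # e.g. moe_fqns = ["experts"] matches "layers.0.moe.experts"
        return any(target in fqn for target in moe_fqns)

    # quantize_ recursively swaps the matching modules' nn.Parameter data
    # tensors to the tensor subclass that overrides grouped_mm with
    # dynamic quant + the scaled grouped mm prototype.
    quantize_(model, MoETrainingConfig(), filter_fn=module_filter_fn)
```

Presumably the torchtitan converter would parse the comma-separated moe_fqns_prototype string into a list before handing it to a helper like this.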

Testing

  • Tested manually with the torchao convert_moe_to_float8_training prototype (PR) and confirmed single-GPU training works as expected.

Limitations

  • Only supports single-GPU training so far.
  • Only applies the grouped_mm override to routed experts (see the condition here). For shared experts, I'll need to update the torchao prototype to support a 3D A tensor (see torchtitan here).

@facebook-github-bot added the CLA Signed label May 30, 2025
@danielvegamyhre danielvegamyhre marked this pull request as draft May 30, 2025 03:46
@danielvegamyhre (Contributor, Author) commented May 30, 2025

cc @tianyu-l @vkuzo: this is not ready to land yet, but I wanted to discuss the API proposed here and make sure we are aligned. Happy to rework it; this is just my initial idea of how it should look.

@tianyu-l (Contributor) commented:

Thanks! The UI makes sense to me.

@@ -465,6 +465,12 @@ class Float8:
Not compatible with torch.compile.
"""

moe_fqns: list[str] | str = field(default_factory=list)
Contributor:
can we add "prototype" to the field name and add a link to the README in the docstring
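For illustration, the requested change might look something like the sketch below; the docstring wording and the README pointer are placeholders rather than the PR's exact text.

```python
# Sketch of the renamed field with a docstring, as requested above.
from dataclasses import dataclass, field


@dataclass
class Float8:
    # ... existing float8 fields ...

    moe_fqns_prototype: list[str] | str = field(default_factory=list)
    """
    Comma-separated list of fully qualified names of MoE modules to apply
    float8 training conversion to. Prototype feature with limitations; see
    the torchao MoE training prototype README for details.
    """
```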

@danielvegamyhre danielvegamyhre marked this pull request as ready for review June 10, 2025 14:52
@@ -465,6 +465,13 @@ class Float8:
Not compatible with torch.compile.
"""

moe_fqns_prototype: list[str] | str = field(default_factory=list)
Contributor:
no need to add "prototype" to config name?

Suggested change:
- moe_fqns_prototype: list[str] | str = field(default_factory=list)
+ moe_fqns: list[str] | str = field(default_factory=list)

@danielvegamyhre (Contributor, Author) replied Jun 10, 2025:

@vkuzo requested "prototype" be in the field name here. Unless I misunderstood the suggestion?

Alternatively we could omit "prototype" from the field name and just make sure the docstring/help text is very clear it is a prototype feature with limitations.

For context, I don't plan to land this until at least FSDP is supported (ideally TP as well).

Contributor:

I'm OK either way then. Also, since this is in an experiments folder, everything here could be considered experimental.

@@ -69,3 +69,4 @@ selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac bas
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output", "router.gate"]
moe_fqns = []
Contributor:

let's put something in the list

Contributor (Author):

Added "experts" as the default value (this is what I've been testing with).

@tianyu-l (Contributor) left a comment:

thanks, had two more comments

@@ -69,3 +69,4 @@ selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac bas
enable_fsdp_float8_all_gather = false
precompute_float8_dynamic_scale_for_fsdp = false
filter_fqns = ["output", "router.gate"]
moe_fqns = ["experts"]
Contributor:

do you want to capture the shared expert? If so, you may need to use "expert" instead of "experts":
https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/llama4/model/moe.py#L204
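For context, a quick sketch of why the fragment choice matters, assuming the converter matches FQNs by substring (as the existing filter_fqns option does); the module FQNs below are illustrative, not taken from the torchtitan model.

```python
# Illustrative FQNs for a routed-experts module and a shared-expert module.
routed_fqn = "layers.0.moe.experts.w1"
shared_fqn = "layers.0.moe.shared_expert.w1"


def matches(fqn: str, targets: list[str]) -> bool:
    # Assumed substring-based matching, mirroring filter_fqns behavior.
    return any(t in fqn for t in targets)


print(matches(routed_fqn, ["experts"]))  # True
print(matches(shared_fqn, ["experts"]))  # False: "experts" misses the shared expert
print(matches(shared_fqn, ["expert"]))   # True: "expert" captures both
```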

Contributor:

If this is well-tested, let's put it into the other toml configs as well.

Contributor (Author):

Not yet; this is intentional. The routed experts work with FSDP and TP, but the shared expert only works with FSDP right now. I'm still debugging an issue related to shared expert + TP.
