Skip to content

Support FP8 accelerate config #39370

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open

Conversation

djsaunde
Copy link

What does this PR do?

Adds config options for mixed precision and fp8 config, which are supported in accelerate's Accelerator object.

Also parses the config field for the torchao ("AO") fp8 backend from dictionary values into the required Float8LinearConfig object.

This requires also a simple gating change to accelerate, which is actually covered by an existing PR: https://github.com/huggingface/accelerate/pull/3677/files#diff-2d7515874eaecac2687c7fc1a9c720be53f802bf14b4c3dcebe14ad443d075dcR501-R505.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

TODO

  • Add docs (?)
  • Add tests (?)
  • Add quick benchmarks

Note

I've tested this downstream in axolotl and find that it works. Performance benefits can be had for certain models when setting torch_compile: true; it's not yet clear to me which models + hyperparameter settings. I'll add some quick numbers here to demonstrate this.

@@ -1256,6 +1256,14 @@ class AcceleratorConfig:
Whether or not to use a pre-configured `AcceleratorState` or `PartialState` defined
before calling `TrainingArguments`. If `True`, an `Accelerator` or `PartialState`
must be initialized. May lead to issues using sweeps or hyperparameter tuning.
mixed_precision (`str`, *optional*, defaults to `"no"`):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can clarify how is it different from the args bf16 and fp16?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@qgallouedec is this because in training args, that bf16 and fp16 are implicitly mixed precision, vs float16 and bfloat16 in the args?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like bf16 and fp16 are ultimately used to set ACCELERATE_MIXED_PRECISION, which could more cleanly be passed as mixed_precision, since the AcceleratorConfig exposes this. Digging into the training arguments code, we have:

os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype
.

        # if training args is specified, it will override the one specified in the accelerate config
        if self.half_precision_backend != "apex":
            mixed_precision_dtype = os.environ.get("ACCELERATE_MIXED_PRECISION", "no")
            if self.fp16:
                mixed_precision_dtype = "fp16"
            elif self.bf16:
                mixed_precision_dtype = "bf16"
            os.environ["ACCELERATE_MIXED_PRECISION"] = mixed_precision_dtype

We set the fp16, bf16 from the training args, or from the ACCELERATE_MIXED_PRECISION if set, in that priority order. But, we could just use / pass the mixed_precision config as exposed by accelerate instead. fp8 is not set by the existing logic.

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have this old PR that cleans a bunch of things, let me know if this will clean a bunch of things for you. If so, I can try to take over this PR myself and first merge this PR

Comment on lines +5207 to +5230
if "fp8_config" in accelerator_config and accelerator_config["fp8_config"] is not None:
if "backend" in accelerator_config["fp8_config"]:
recipe_kwargs = MP_BACKEND_TO_KWARGS[accelerator_config["fp8_config"]["backend"]]
fp8_config = accelerator_config["fp8_config"].copy()

if fp8_config["backend"] == "AO":
from torchao.float8 import Float8LinearConfig

if "recipe_name" in fp8_config:
recipe_name = fp8_config["recipe_name"]
fp8_config["config"] = (
Float8LinearConfig.from_recipe_name(recipe_name=recipe_name)
)
fp8_config.pop("recipe_name")
elif "config" in accelerator_config["fp8_config"]:
config = fp8_config["config"]
kwargs = {k: v for k, v in config.items() if v is not None}
fp8_config["config"] = Float8LinearConfig(
**kwargs
)

fp8_config.pop("backend")
kwargs_handlers = [recipe_kwargs(**fp8_config)]
args["kwargs_handlers"] = kwargs_handlers
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it will be easier to allow users to pass directly kwargs_handler in accelerate_config ?

@djsaunde
Copy link
Author

We have this old PR that cleans a bunch of things, let me know if this will clean a bunch of things for you. If so, I can try to take over this PR myself and first merge this PR

That would be great! For the moment we are patching things downstream in axolotl just to get things working. Having this eventually in an upstream release would be nice.

@SunMarc
Copy link
Member

SunMarc commented Jul 16, 2025

That would be great! For the moment we are patching things downstream in axolotl just to get things working. Having this eventually in an upstream release would be nice.

I'll spend some time very soon on that PR to clean a bit trainer then ! I'll ping you when this is done to have it reviewed !

@SunMarc
Copy link
Member

SunMarc commented Jul 17, 2025

Sorry, I forgot to add the link of the PR I was talking about. Here you go #37259

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants