[float8] add _auto_filter_for_recipe for float8 training #1319

Open · wants to merge 10 commits into base: main
2 changes: 2 additions & 0 deletions docs/float8.md
@@ -17,6 +17,8 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ./run_trai
* `--float8.enable_fsdp_float8_all_gather`: cast `Float8Linear.weight` from high precision to float8 before FSDP all-gather so we can communicate in float8 to save bandwidth.
* `--float8.precompute_float8_dynamic_scale_for_fsdp` (optional): communicate AMAX/scales efficiently in a single all-reduce for all parameters instead of many small all-reduces, one per parameter.
* `--float8.force_recompute_fp8_weight_in_bwd` (optional): force recomputation of fp8 weights during backward pass, preventing unsharded fp8 weights from being saved for backward.
* `--float8.filter_fqns="..."` (optional): a comma-separated list of fully qualified names of modules not to convert to float8 training. Example: `--float8.filter_fqns="attention.wk,attention.wv"`. You can determine which layers to convert by looking at the microbenchmarks in the [performance section](https://github.com/pytorch/ao/tree/main/torchao/float8#performance) of the torchao documentation for the float8 recipe you're using.
* **Auto-filter**: add `"auto_filter_small_kn"` as one of the `--float8.filter_fqns` (for example, `--float8.filter_fqns="auto_filter_small_kn"`) to enable automatic module filtering, which skips converting linear layers whose K and N dimensions are too small to benefit from float8 training; see the illustrative sketch after this list. The conversion thresholds are based on microbenchmarks measured on NVIDIA H100 GPUs. For best performance, you should still manually filter out layers that are too small to benefit from float8 training.
* `--training.compile` (required for competitive performance): use `torch.compile` to fuse the float8 scaling/casting kernels
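
To make the auto-filter behavior concrete, below is a minimal, illustrative sketch of a dimension-based filter written against torchao's `module_filter_fn(module, fqn) -> bool` convention. The `small_kn_filter` name and the cutoff values are placeholders for this sketch only; the real `_auto_filter_for_recipe` derives its per-recipe thresholds from microbenchmarks on H100 GPUs.

```python
import torch.nn as nn


def small_kn_filter(module: nn.Module, fqn: str) -> bool:
    """Return True if `module` should be converted to float8 (sketch only)."""
    # Placeholder cutoffs, not torchao's H100-calibrated thresholds.
    min_k, min_n = 2048, 1024
    if not isinstance(module, nn.Linear):
        return False
    k, n = module.in_features, module.out_features
    # Skip layers whose GEMM dimensions are too small to benefit from float8.
    return k >= min_k and n >= min_n
```

A function with this signature can be passed directly as `module_filter_fn` to `convert_to_float8_training`, which is how the converter below wires in the filter returned by `_auto_filter_for_recipe`.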

For float8 with rowwise scaling, launch the training job with the following command (or alternatively set the configs in toml files)
68 changes: 57 additions & 11 deletions torchtitan/components/quantization/float8.py
@@ -3,7 +3,6 @@
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import torch
@@ -20,6 +19,8 @@

from .utils import module_filter_fn

AUTO_FILTER_SMALL_KN_FLAG = "auto_filter_small_kn"


class Float8Converter(ModelConverter):
def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
@@ -52,15 +53,18 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
return

self.enabled = True
self.filter_fqns = float8_config.filter_fqns

if float8_config.recipe_name is not None:
assert not float8_config.enable_fsdp_float8_all_gather, (
"using `float8_config.enable_fsdp_float8_all_gather` together "
"with `float8_config.recipe_name` is not supported"
)

assert not float8_config.force_recompute_fp8_weight_in_bwd, (
"using `float8_config.force_recompute_fp8_weight_in_bwd` together "
"with `float8_config.recipe_name` is not supported"
)

self.config = Float8LinearConfig.from_recipe_name(float8_config.recipe_name)
self.precompute_scale = False
logger.info(
@@ -73,7 +77,6 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
logger.debug(
"Set torch._inductor.config.emulate_precision_casts to True"
)

else:
# Mutates the model inplace replacing instances of nn.Linear with Float8Linear
enable_fsdp_float8_all_gather = (
@@ -92,6 +95,50 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
)
logger.info("Float8 tensorwise scaled training active")

# configure the module filter function
self.filter_fn = self._init_filter_fn(float8_config)

def _init_filter_fn(self, float8_config: Float8):
# Use the auto filter if the AUTO_FILTER_SMALL_KN_FLAG sentinel is among the given filter_fqns.
use_auto_filter = AUTO_FILTER_SMALL_KN_FLAG in float8_config.filter_fqns
if use_auto_filter:
try:
from torchao.float8 import _auto_filter_for_recipe

logger.info(
"Using automatic module filter for float8 model conversion."
)

recipe_name = (
float8_config.recipe_name
if float8_config.recipe_name
else "tensorwise"
)

# remove auto filter flag from filter_fqns before passing to _auto_filter_for_recipe
fqns = [
fqn
for fqn in float8_config.filter_fqns
if fqn != AUTO_FILTER_SMALL_KN_FLAG
]

filter_fn = _auto_filter_for_recipe(
recipe_name,
filter_fqns=fqns,
)
return filter_fn
except ImportError:
logger.warning(
(
"Using default module_filter_fn for float8 model conversion. "
"To use _auto_filter_for_recipe, please install torchao nightly build."
)
)

# Otherwise, fall back to the default name-based filter function.
filter_fn = partial(module_filter_fn, filter_fqns=float8_config.filter_fqns)
return filter_fn
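
For reference, here is a minimal sketch of the kind of name-based fallback that the `partial(module_filter_fn, filter_fqns=...)` call above relies on. The body below is an assumption for illustration (under a hypothetical `name_based_filter_fn` name), not the actual `module_filter_fn` implementation in `.utils`:

```python
import torch.nn as nn


def name_based_filter_fn(mod: nn.Module, fqn: str, filter_fqns: list[str]) -> bool:
    """Return True if `mod` should be swapped to Float8Linear (sketch only)."""
    if not isinstance(mod, nn.Linear):
        return False
    # Exclude any module whose fully qualified name contains a filtered substring.
    return not any(filter_fqn in fqn for filter_fqn in filter_fqns)
```

Binding `filter_fqns` with `functools.partial`, as the fallback branch does, yields the two-argument `(module, fqn)` callable that `convert_to_float8_training` expects.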

def convert(self, model: nn.Module):
"""
This function converts the linear layers of `model` to `Float8Linear`.
@@ -103,11 +150,10 @@ def convert(self, model: nn.Module):

from torchao.float8 import convert_to_float8_training

# Mutates the model inplace replacing instances of nn.Linear with Float8Linear
convert_to_float8_training(
model,
config=self.config,
module_filter_fn=self.filter_fn,
)
logger.info(
"Swapped to Float8Linear layers with enable_fsdp_float8_all_gather="