[float8] add auto_filter_for_recipe to float8 #2410

Open · wants to merge 2 commits into main

80 changes: 78 additions & 2 deletions torchao/float8/float8_linear_utils.py
@@ -4,11 +4,12 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import logging
-from typing import Callable, Optional
+from functools import partial
+from typing import Callable, List, Optional

 import torch.nn as nn

-from torchao.float8.config import Float8LinearConfig
+from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
 from torchao.float8.float8_linear import Float8Linear

 log = logging.getLogger(__name__)
@@ -113,3 +114,78 @@ def convert_to_float8_training(
from_float,
module_filter_fn=module_filter_fn,
)


def auto_filter_for_recipe(
    recipe: Float8LinearRecipeName, filter_fqns: List[str]
) -> Callable[[nn.Module, str], bool]:
    """Returns a module filter function for use with ``convert_to_float8_training``,
    which excludes nn.Linear modules that meet at least one of the following criteria:

    1. Dims not divisible by 16 (hardware requirement for float8).
    2. Dim sizes below certain thresholds, which result in worse performance.

    Modules whose fqn contains any of the strings in ``filter_fqns`` are also excluded.

    NOTE: the thresholds are simple heuristics based on performance testing, and may not
    be optimal for your model. For the best performance, we recommend defining your own
    module_filter_fn customized for your model, using the performance tables for the
    given float8 recipe here:
    https://github.com/pytorch/ao/tree/main/torchao/float8#performance

    The design of this function may change in the future.
    """
    # Accept either the enum member or its string value (e.g. "rowwise"),
    # and compare against enum members directly.
    if isinstance(recipe, str):
        recipe = Float8LinearRecipeName(recipe)

    if recipe == Float8LinearRecipeName.TENSORWISE:
        return partial(_auto_filter_for_tensorwise, filter_fqns=filter_fqns)
    elif recipe == Float8LinearRecipeName.ROWWISE:
        return partial(_auto_filter_for_rowwise, filter_fqns=filter_fqns)
    elif recipe == Float8LinearRecipeName.ROWWISE_WITH_GW_HP:
        raise NotImplementedError(f"Unsupported recipe: {recipe}")
    else:
        raise ValueError(f"Invalid recipe: {recipe}")


def _auto_filter_for_rowwise(mod: nn.Module, fqn: str, filter_fqns: List[str]) -> bool:
if not isinstance(mod, nn.Linear):
return False

# If the fqn matches any filtered fqn, then we should not convert this module.
is_filtered_fqn = any(filter_fqn in fqn for filter_fqn in filter_fqns)
if is_filtered_fqn:
return False

    # All dims must be divisible by 16 due to float8 hardware requirements.
    # nn.Linear weight shape is (out_features, in_features), i.e. (N, K).
    N, K = mod.weight.shape
    dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0
if not dims_multiples_of_16:
return False

# Dims below these thresholds will result in worse performance
# (see https://github.com/pytorch/ao/tree/main/torchao/float8#rowwise-scaling)
if N <= 2048:
return False
elif K <= 1024:
return False
elif N <= 4096 and K <= 2048:
return False
return True


def _auto_filter_for_tensorwise(
mod: nn.Module, fqn: str, filter_fqns: List[str]
) -> bool:
if not isinstance(mod, nn.Linear):
return False

# If the fqn matches any filtered fqn, then we should not convert this module.
is_filtered_fqn = any(filter_fqn in fqn for filter_fqn in filter_fqns)
if is_filtered_fqn:
return False

    # All dims must be divisible by 16 due to float8 hardware requirements.
    # nn.Linear weight shape is (out_features, in_features), i.e. (N, K).
    N, K = mod.weight.shape
    dims_multiples_of_16 = K % 16 == 0 and N % 16 == 0
if not dims_multiples_of_16:
return False

# Dims below these thresholds will result in worse performance
# (see https://github.com/pytorch/ao/tree/main/torchao/float8#tensorwise-scaling)
if K <= 4096 and N <= 1024:
return False
return True
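

Example usage (editor's sketch, not part of this PR's diff): the returned filter is meant to be passed as module_filter_fn to convert_to_float8_training, which is defined in the same file. The ToyModel, its layer shapes, and the "lm_head" fqn substring are made up for illustration; Float8LinearConfig.from_recipe_name is the existing config helper documented in the torchao float8 README. Actually training in float8 additionally requires supported hardware (e.g. H100).

import torch.nn as nn

from torchao.float8.config import Float8LinearConfig, Float8LinearRecipeName
from torchao.float8.float8_linear_utils import (
    auto_filter_for_recipe,
    convert_to_float8_training,
)


class ToyModel(nn.Module):
    # Layer names and shapes are illustrative only.
    def __init__(self):
        super().__init__()
        self.ffn = nn.Linear(4096, 14336)       # large dims: converted
        self.proj = nn.Linear(14336, 4096)      # large dims: converted
        self.lm_head = nn.Linear(4096, 128256)  # excluded via filter_fqns below

    def forward(self, x):
        return self.lm_head(self.proj(self.ffn(x)))


model = ToyModel()

# Build the auto filter for the rowwise recipe, additionally excluding the output head by fqn.
module_filter_fn = auto_filter_for_recipe(
    Float8LinearRecipeName.ROWWISE, filter_fqns=["lm_head"]
)

config = Float8LinearConfig.from_recipe_name(Float8LinearRecipeName.ROWWISE)
convert_to_float8_training(model, config=config, module_filter_fn=module_filter_fn)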
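
To make the shape heuristic concrete, here is a second sketch (also not part of the diff; fqns and shapes are hypothetical) that applies the rowwise filter to individual nn.Linear modules. Per the thresholds in _auto_filter_for_rowwise above, only the first layer is converted: the second has an output dim not divisible by 16, the third falls under the N <= 2048 threshold, and the fourth matches filter_fqns.

import torch.nn as nn

from torchao.float8.config import Float8LinearRecipeName
from torchao.float8.float8_linear_utils import auto_filter_for_recipe

filter_fn = auto_filter_for_recipe(
    Float8LinearRecipeName.ROWWISE, filter_fqns=["lm_head"]
)

# (fqn, module) pairs chosen to exercise each branch of the filter.
cases = [
    ("layers.0.ffn.w1", nn.Linear(4096, 14336)),  # large dims -> True (convert)
    ("layers.0.attn.wq", nn.Linear(4096, 1000)),  # 1000 not divisible by 16 -> False
    ("layers.0.proj", nn.Linear(512, 2048)),      # out dim 2048 <= 2048 threshold -> False
    ("lm_head", nn.Linear(4096, 128256)),         # fqn matches filter_fqns -> False
]
for fqn, mod in cases:
    print(f"{fqn}: convert={filter_fn(mod, fqn)}")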