-
Notifications
You must be signed in to change notification settings - Fork 6k
Add Finegrained FP8 #11647
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add Finegrained FP8 #11647
Changes from all commits
d561f65
87c4dc4
83163bc
e97e2ef
33b2be7
0eb3989
e6e0c4a
243173f
c3cf41f
c7072fe
f54dcc0
94b2db5
2b71830
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
<!--Copyright 2025 The HuggingFace Team. All rights reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
the License. You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
specific language governing permissions and limitations under the License. | ||
--> | ||
|
||
# FinegrainedFP8 | ||
|
||
## Overview | ||
|
||
## Usage | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -96,6 +96,8 @@ | |
else: | ||
_import_structure["quantizers.quantization_config"].append("TorchAoConfig") | ||
|
||
_import_structure["quantizers.quantization_config"].append("FinegrainedFP8Config") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We haven't moved to setting torch and accelerate as required deps yet (will happen soon). So this would have to go under a |
||
|
||
try: | ||
if not is_torch_available() and not is_accelerate_available() and not is_optimum_quanto_available(): | ||
raise OptionalDependencyNotAvailable() | ||
|
@@ -725,6 +727,8 @@ | |
else: | ||
from .quantizers.quantization_config import QuantoConfig | ||
|
||
from .quantizers.quantization_config import FinegrainedFP8Config | ||
|
||
try: | ||
if not is_onnx_available(): | ||
raise OptionalDependencyNotAvailable() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .finegrained_fp8_quantizer import FinegrainedFP8Quantizer |
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -0,0 +1,205 @@ | ||||||||||||||||||||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional | ||||||||||||||||||||||
|
||||||||||||||||||||||
from ...utils import get_module_from_name, is_accelerate_available, is_torch_available, logging | ||||||||||||||||||||||
from ..base import DiffusersQuantizer | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
if is_torch_available(): | ||||||||||||||||||||||
import torch | ||||||||||||||||||||||
|
||||||||||||||||||||||
logger = logging.get_logger(__name__) | ||||||||||||||||||||||
|
||||||||||||||||||||||
if TYPE_CHECKING: | ||||||||||||||||||||||
from ...models.modeling_utils import ModelMixin | ||||||||||||||||||||||
|
||||||||||||||||||||||
|
||||||||||||||||||||||
class FinegrainedFP8Quantizer(DiffusersQuantizer): | ||||||||||||||||||||||
""" | ||||||||||||||||||||||
FP8 quantization implementation supporting both standard and MoE models. | ||||||||||||||||||||||
Supports both e4m3fn formats based on platform. | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we expand on this a bit? What are both |
||||||||||||||||||||||
""" | ||||||||||||||||||||||
|
||||||||||||||||||||||
requires_parameters_quantization = True | ||||||||||||||||||||||
requires_calibration = False | ||||||||||||||||||||||
required_packages = ["accelerate"] | ||||||||||||||||||||||
|
||||||||||||||||||||||
def __init__(self, quantization_config, **kwargs): | ||||||||||||||||||||||
super().__init__(quantization_config, **kwargs) | ||||||||||||||||||||||
self.quantization_config = quantization_config | ||||||||||||||||||||||
|
||||||||||||||||||||||
def validate_environment(self, *args, **kwargs): | ||||||||||||||||||||||
if not is_torch_available(): | ||||||||||||||||||||||
raise ImportError( | ||||||||||||||||||||||
"Using fp8 quantization requires torch >= 2.1.0" | ||||||||||||||||||||||
"Please install the latest version of torch ( pip install --upgrade torch )" | ||||||||||||||||||||||
) | ||||||||||||||||||||||
|
||||||||||||||||||||||
if not is_accelerate_available(): | ||||||||||||||||||||||
raise ImportError("Loading an FP8 quantized model requires accelerate (`pip install accelerate`)") | ||||||||||||||||||||||
|
||||||||||||||||||||||
if kwargs.get("from_tf", False) or kwargs.get("from_flax", False): | ||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||
"Converting into FP8 weights from tf/flax weights is currently not supported, " | ||||||||||||||||||||||
"please make sure the weights are in PyTorch format." | ||||||||||||||||||||||
) | ||||||||||||||||||||||
|
||||||||||||||||||||||
if torch.cuda.is_available(): | ||||||||||||||||||||||
compute_capability = torch.cuda.get_device_capability() | ||||||||||||||||||||||
major, minor = compute_capability | ||||||||||||||||||||||
if (major < 8) or (major == 8 and minor < 9): | ||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||
"FP8 quantized models is only supported on GPUs with compute capability >= 8.9 (e.g 4090/H100)" | ||||||||||||||||||||||
f", actual = `{major}.{minor}`" | ||||||||||||||||||||||
) | ||||||||||||||||||||||
|
||||||||||||||||||||||
device_map = kwargs.get("device_map", None) | ||||||||||||||||||||||
if device_map is None: | ||||||||||||||||||||||
logger.warning_once( | ||||||||||||||||||||||
"You have loaded an FP8 model on CPU and have a CUDA device available, make sure to set " | ||||||||||||||||||||||
"your model on a GPU device in order to run your model. To remove this warning, pass device_map = 'cuda'. " | ||||||||||||||||||||||
) | ||||||||||||||||||||||
elif device_map is not None: | ||||||||||||||||||||||
if ( | ||||||||||||||||||||||
not self.pre_quantized | ||||||||||||||||||||||
and isinstance(device_map, dict) | ||||||||||||||||||||||
and ("cpu" in device_map.values() or "disk" in device_map.values()) | ||||||||||||||||||||||
): | ||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||
"You are attempting to load an FP8 model with a device_map that contains a cpu/disk device." | ||||||||||||||||||||||
"This is not supported when the model is quantized on the fly. " | ||||||||||||||||||||||
"Please use a quantized checkpoint or remove the cpu/disk device from the device_map." | ||||||||||||||||||||||
) | ||||||||||||||||||||||
|
||||||||||||||||||||||
def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype": | ||||||||||||||||||||||
if torch_dtype is None: | ||||||||||||||||||||||
logger.info("Setting torch_dtype to torch.float32 as no torch_dtype was specified in from_pretrained") | ||||||||||||||||||||||
torch_dtype = torch.float32 | ||||||||||||||||||||||
return torch_dtype | ||||||||||||||||||||||
|
||||||||||||||||||||||
def create_quantized_param( | ||||||||||||||||||||||
self, | ||||||||||||||||||||||
model: "ModelMixin", | ||||||||||||||||||||||
param_value: "torch.Tensor", | ||||||||||||||||||||||
param_name: str, | ||||||||||||||||||||||
target_device: "torch.device", | ||||||||||||||||||||||
state_dict: Dict[str, Any], | ||||||||||||||||||||||
unexpected_keys: Optional[List[str]] = None, | ||||||||||||||||||||||
**kwargs, | ||||||||||||||||||||||
): | ||||||||||||||||||||||
""" | ||||||||||||||||||||||
Quantizes weights to FP8 format using Block-wise quantization | ||||||||||||||||||||||
""" | ||||||||||||||||||||||
# print("############ create quantized param ########") | ||||||||||||||||||||||
MekkCyber marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||
from accelerate.utils import set_module_tensor_to_device | ||||||||||||||||||||||
|
||||||||||||||||||||||
set_module_tensor_to_device(model, param_name, target_device, param_value) | ||||||||||||||||||||||
|
||||||||||||||||||||||
module, tensor_name = get_module_from_name(model, param_name) | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Get FP8 min/max values | ||||||||||||||||||||||
fp8_min = torch.finfo(torch.float8_e4m3fn).min | ||||||||||||||||||||||
fp8_max = torch.finfo(torch.float8_e4m3fn).max | ||||||||||||||||||||||
|
||||||||||||||||||||||
block_size_m, block_size_n = self.quantization_config.weight_block_size | ||||||||||||||||||||||
|
||||||||||||||||||||||
rows, cols = param_value.shape[-2:] | ||||||||||||||||||||||
|
||||||||||||||||||||||
if rows % block_size_m != 0 or cols % block_size_n != 0: | ||||||||||||||||||||||
raise ValueError( | ||||||||||||||||||||||
f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n}) for {param_name}" | ||||||||||||||||||||||
) | ||||||||||||||||||||||
param_value_orig_shape = param_value.shape | ||||||||||||||||||||||
|
||||||||||||||||||||||
param_value = param_value.reshape( | ||||||||||||||||||||||
rows // block_size_m, block_size_m, cols // block_size_n, block_size_n | ||||||||||||||||||||||
).permute(0, 2, 1, 3) | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Calculate scaling factor for each block | ||||||||||||||||||||||
max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2)) | ||||||||||||||||||||||
scale = fp8_max / max_abs | ||||||||||||||||||||||
scale_orig_shape = scale.shape | ||||||||||||||||||||||
scale = scale.unsqueeze(-1).unsqueeze(-1) | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Quantize the weights | ||||||||||||||||||||||
quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn) | ||||||||||||||||||||||
|
||||||||||||||||||||||
quantized_param = quantized_param.permute(0, 2, 1, 3) | ||||||||||||||||||||||
# Reshape back to matrix shape | ||||||||||||||||||||||
quantized_param = quantized_param.reshape(param_value_orig_shape) | ||||||||||||||||||||||
MekkCyber marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||
|
||||||||||||||||||||||
# Reshape scale to match the number of blocks | ||||||||||||||||||||||
scale = scale.reshape(scale_orig_shape).reciprocal() | ||||||||||||||||||||||
|
||||||||||||||||||||||
# Load into the model | ||||||||||||||||||||||
module._parameters[tensor_name] = quantized_param.to(target_device) | ||||||||||||||||||||||
module._parameters["weight_scale_inv"] = scale.to(target_device) | ||||||||||||||||||||||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||
|
||||||||||||||||||||||
def check_if_quantized_param( | ||||||||||||||||||||||
self, | ||||||||||||||||||||||
model: "ModelMixin", | ||||||||||||||||||||||
param_value: "torch.Tensor", | ||||||||||||||||||||||
param_name: str, | ||||||||||||||||||||||
state_dict: Dict[str, Any], | ||||||||||||||||||||||
**kwargs, | ||||||||||||||||||||||
): | ||||||||||||||||||||||
from .utils import FP8Linear | ||||||||||||||||||||||
|
||||||||||||||||||||||
module, tensor_name = get_module_from_name(model, param_name) | ||||||||||||||||||||||
if isinstance(module, FP8Linear): | ||||||||||||||||||||||
if self.pre_quantized or tensor_name == "bias": | ||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it's pre-quantized, shouldn't we return There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch ! it's the opposite of what we have in transformers There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wait i'm a bit confused here! does There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This. Example: diffusers/src/diffusers/models/model_loading_utils.py Lines 295 to 302 in 1bc6f3d
I think this is because during this step, we call diffusers/src/diffusers/models/modeling_utils.py Line 1228 in 1bc6f3d
which does:
(the job replacing the regular linears with bnb linears) LMK if you need more clarifications. |
||||||||||||||||||||||
if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn: | ||||||||||||||||||||||
raise ValueError("Expected quantized weights but got an unquantized weight") | ||||||||||||||||||||||
return False | ||||||||||||||||||||||
else: | ||||||||||||||||||||||
if tensor_name == "weight_scale_inv": | ||||||||||||||||||||||
raise ValueError("Expected unquantized weights but got a quantized weight_scale") | ||||||||||||||||||||||
return True | ||||||||||||||||||||||
return False | ||||||||||||||||||||||
|
||||||||||||||||||||||
def _process_model_before_weight_loading( | ||||||||||||||||||||||
self, | ||||||||||||||||||||||
model: "ModelMixin", | ||||||||||||||||||||||
keep_in_fp32_modules: Optional[List[str]] = None, | ||||||||||||||||||||||
**kwargs, | ||||||||||||||||||||||
): | ||||||||||||||||||||||
from .utils import replace_with_fp8_linear | ||||||||||||||||||||||
|
||||||||||||||||||||||
if self.quantization_config.modules_to_not_convert is not None: | ||||||||||||||||||||||
self.modules_to_not_convert.extend(self.quantization_config.modules_to_not_convert) | ||||||||||||||||||||||
|
||||||||||||||||||||||
model = replace_with_fp8_linear( | ||||||||||||||||||||||
model, | ||||||||||||||||||||||
modules_to_not_convert=self.modules_to_not_convert, | ||||||||||||||||||||||
quantization_config=self.quantization_config, | ||||||||||||||||||||||
) | ||||||||||||||||||||||
|
||||||||||||||||||||||
model.config.quantization_config = self.quantization_config | ||||||||||||||||||||||
|
||||||||||||||||||||||
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs): | ||||||||||||||||||||||
return model | ||||||||||||||||||||||
|
||||||||||||||||||||||
def update_missing_keys(self, model, missing_keys: List[str], prefix: str) -> List[str]: | ||||||||||||||||||||||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||
from .utils import FP8Linear | ||||||||||||||||||||||
|
||||||||||||||||||||||
not_missing_keys = [] | ||||||||||||||||||||||
for name, module in model.named_modules(): | ||||||||||||||||||||||
if isinstance(module, FP8Linear): | ||||||||||||||||||||||
for missing in missing_keys: | ||||||||||||||||||||||
if ( | ||||||||||||||||||||||
(name in missing or name in f"{prefix}.{missing}") | ||||||||||||||||||||||
and not missing.endswith(".weight") | ||||||||||||||||||||||
and not missing.endswith(".bias") | ||||||||||||||||||||||
): | ||||||||||||||||||||||
not_missing_keys.append(missing) | ||||||||||||||||||||||
return [k for k in missing_keys if k not in not_missing_keys] | ||||||||||||||||||||||
|
||||||||||||||||||||||
def is_serializable(self, safe_serialization=None): | ||||||||||||||||||||||
return True | ||||||||||||||||||||||
|
||||||||||||||||||||||
@property | ||||||||||||||||||||||
def is_trainable(self) -> bool: | ||||||||||||||||||||||
return False | ||||||||||||||||||||||
|
||||||||||||||||||||||
def get_cuda_warm_up_factor(self): | ||||||||||||||||||||||
MekkCyber marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||||||||||||
# Pre-processing is done cleanly, so we can allocate everything here | ||||||||||||||||||||||
return 2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
TODO.