Add preserve_high_precision_init_val to fp8_model_init
kunlunl committed Aug 27, 2024
1 parent 3040785 commit 5f2b65b
Showing 2 changed files with 49 additions and 2 deletions.
29 changes: 28 additions & 1 deletion transformer_engine/pytorch/fp8.py
@@ -66,6 +66,7 @@ class FP8GlobalStateManager:
     FP8_RECIPE = None
     FP8_DISTRIBUTED_GROUP = None
     FP8_PARAMETERS = False
+    HIGH_PRECISION_INIT_VAL = False
     IS_FIRST_FP8_MODULE = False
     FP8_GRAPH_CAPTURING = False
     FP8_AUTOCAST_DEPTH = 0
@@ -89,6 +90,7 @@ def reset(cls) -> None:
         cls.FP8_RECIPE = None
         cls.FP8_DISTRIBUTED_GROUP = None
         cls.FP8_PARAMETERS = False
+        cls.HIGH_PRECISION_INIT_VAL = False
         cls.IS_FIRST_FP8_MODULE = False
         cls.FP8_GRAPH_CAPTURING = False
         cls.FP8_AUTOCAST_DEPTH = 0
@@ -251,6 +253,11 @@ def with_fp8_parameters(cls) -> bool:
         """Should the parameters be stored as FP8"""
         return cls.FP8_PARAMETERS
 
+    @classmethod
+    def with_high_precision_init_val(cls) -> bool:
+        """Should the high precision initial values be stored with FP8 parameters"""
+        return cls.HIGH_PRECISION_INIT_VAL
+
     @classmethod
     def fp8_graph_capturing(cls) -> bool:
         """Is CUDA graph capture under way?"""
@@ -477,7 +484,10 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
 
 
 @contextmanager
-def fp8_model_init(enabled: bool = True) -> None:
+def fp8_model_init(
+    enabled: bool = True,
+    preserve_high_precision_init_val: bool = False,
+) -> None:
     """
     Context manager for FP8 initialization of parameters.
 
@@ -488,6 +498,12 @@ def fp8_model_init(enabled: bool = True) -> None:
         with fp8_model_init(enabled=True):
             model = transformer_engine.pytorch.Linear(768, 768)
 
+        # Preserving high precision initial value to initialize master weight
+        with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
+            model = transformer_engine.pytorch.Linear(768, 768)
+            master_weight = model.weight.get_high_precision_init_val()
+            model.weight.clear_high_precision_init_val()
+
     Parameters
     ----------
     enabled: bool, default = `True`
@@ -501,15 +517,26 @@ def fp8_model_init(enabled: bool = True) -> None:
             precision copies of weights are already present in the optimizer.
           * inference, where only the FP8 copies of the parameters are used.
           * LoRA-like fine-tuning, where the main parameters of the model do not change.
+    preserve_high_precision_init_val: bool, default = `False`
+        when enabled, store the high precision tensor used to initialize FP8 parameters
+        in CPU memory, and add two function attributes named `get_high_precision_init_val()`
+        and `clear_high_precision_init_val()` to FP8 parameters to get/clear this high
+        precision tensor. The purpose is that users can use this high-precision copy
+        to initialize master weights, avoiding the loss of precision that can occur when
+        using FP8 parameters directly. Note that after the master weights are initialized,
+        users should call `clear_high_precision_init_val()` to release this CPU memory.
 
         This functionality is *EXPERIMENTAL*.
     """
     _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
     FP8GlobalStateManager.FP8_PARAMETERS = enabled
+    _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL
+    FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val
     try:
         yield
     finally:
         FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters
+        FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val
 
 
 @contextmanager
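
The docstring example stops at fetching the preserved tensor; the sketch below illustrates the master-weight workflow it describes. It is not part of the commit: it assumes a CUDA build of Transformer Engine, and the loop and variable names (`master_weights`, `getter`) are illustrative only.

```python
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import fp8_model_init

with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
    model = te.Linear(768, 768)

master_weights = []
for param in model.parameters():
    getter = getattr(param, "get_high_precision_init_val", None)
    if getter is not None and getter() is not None:
        # Seed the FP32 master weight from the preserved initialization values
        # instead of upcasting the already-quantized FP8 parameter.
        master_weights.append(getter().clone().float())
        param.clear_high_precision_init_val()  # release the CPU-side copy
    else:
        # Parameters that were never cast to FP8 (e.g. bias) fall back to a plain cast.
        master_weights.append(param.detach().clone().float())
```

Clearing right after the copy matters: otherwise each preserved tensor keeps a full high-precision replica of the weight alive in host memory.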
22 changes: 21 additions & 1 deletion transformer_engine/pytorch/module/base.py
@@ -13,6 +13,7 @@
 from abc import ABC, abstractmethod
 from typing import Dict, Generator, List, Optional, Tuple, Union
 from contextlib import contextmanager
+from types import MethodType
 
 import torch
 import torch.nn.functional as F
@@ -389,6 +390,7 @@ def __init__(self) -> None:
         self.sequence_parallel = False
         self.param_init_meta = {}
         self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
+        self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()
         self.fsdp_wrapped = False
         self.fsdp_group = None
         self._fp8_workspaces: Dict[str, Float8Tensor] = {}
@@ -865,6 +867,8 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
             # If primary weights are in fp8, wrap the parameter as Float8Tensor
             fp8_meta_index = self.param_init_meta[name].fp8_meta_index
             if self.primary_weights_in_fp8 and fp8_meta_index is not None:
+                if self.preserve_high_precision_init_val:
+                    high_precision_init_val = param.detach().cpu()
                 param = Float8Tensor.to_float8(
                     param,
                     fp8_meta=self.fp8_meta,
@@ -876,7 +880,23 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None:
             # NOTE: Currently this can only be broken when primary weights are in Fp8 but
             # re-applying the nn.Parameter() wrap is a no-op when the input is already
             # a parameter so we always re-apply it just for extra safety.
-            setattr(self, name, torch.nn.Parameter(param))
+            param = torch.nn.Parameter(param)
+            if self.primary_weights_in_fp8 and self.preserve_high_precision_init_val:
+                def get(self):
+                    if hasattr(self, "_high_precision_init_val"):
+                        return self._high_precision_init_val
+                    else:
+                        return None
+
+                def clear(self):
+                    if hasattr(self, "_high_precision_init_val"):
+                        del self._high_precision_init_val
+
+                param._high_precision_init_val = high_precision_init_val
+                param.get_high_precision_init_val = MethodType(get, param)
+                param.clear_high_precision_init_val = MethodType(clear, param)
+
+            setattr(self, name, param)
 
     @abstractmethod
     def forward(self):
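
The `get`/`clear` accessors above are attached to each parameter object rather than defined on a class. The standalone sketch below (plain PyTorch, no Transformer Engine required; the tensor shape is arbitrary) reproduces just that attachment pattern to show why `types.MethodType` is used: a bare function stored as an instance attribute would not receive `self` automatically, while `MethodType` pre-binds it to that specific parameter.

```python
from types import MethodType

import torch

param = torch.nn.Parameter(torch.randn(4, 4))
high_precision_init_val = param.detach().cpu()  # snapshot taken before any FP8 cast

def get(self):
    if hasattr(self, "_high_precision_init_val"):
        return self._high_precision_init_val
    else:
        return None

def clear(self):
    if hasattr(self, "_high_precision_init_val"):
        del self._high_precision_init_val

# Tensors accept arbitrary Python attributes, so the saved copy and its
# accessors travel with this particular parameter object.
param._high_precision_init_val = high_precision_init_val
param.get_high_precision_init_val = MethodType(get, param)
param.clear_high_precision_init_val = MethodType(clear, param)

assert torch.equal(param.get_high_precision_init_val(), high_precision_init_val)
param.clear_high_precision_init_val()
assert param.get_high_precision_init_val() is None
```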
