From 5f2b65bd941648938ebb551952c51f8fd3f7695a Mon Sep 17 00:00:00 2001
From: kunlunl
Date: Tue, 27 Aug 2024 03:49:35 -0700
Subject: [PATCH] Add preserve_high_precision_init_val to fp8_model_init

---
 transformer_engine/pytorch/fp8.py         | 29 ++++++++++++++++++++++-
 transformer_engine/pytorch/module/base.py | 22 ++++++++++++++++++++++-
 2 files changed, 49 insertions(+), 2 deletions(-)

diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index e15268b998..f54cda6429 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -66,6 +66,7 @@ class FP8GlobalStateManager:
     FP8_RECIPE = None
     FP8_DISTRIBUTED_GROUP = None
     FP8_PARAMETERS = False
+    HIGH_PRECISION_INIT_VAL = False
     IS_FIRST_FP8_MODULE = False
     FP8_GRAPH_CAPTURING = False
     FP8_AUTOCAST_DEPTH = 0
@@ -89,6 +90,7 @@ def reset(cls) -> None:
         cls.FP8_RECIPE = None
         cls.FP8_DISTRIBUTED_GROUP = None
         cls.FP8_PARAMETERS = False
+        cls.HIGH_PRECISION_INIT_VAL = False
         cls.IS_FIRST_FP8_MODULE = False
         cls.FP8_GRAPH_CAPTURING = False
         cls.FP8_AUTOCAST_DEPTH = 0
@@ -251,6 +253,11 @@ def with_fp8_parameters(cls) -> bool:
         """Should the parameters be stored as FP8"""
         return cls.FP8_PARAMETERS
 
+    @classmethod
+    def with_high_precision_init_val(cls) -> bool:
+        """Should the high precision initial values be stored with FP8 parameters"""
+        return cls.HIGH_PRECISION_INIT_VAL
+
     @classmethod
     def fp8_graph_capturing(cls) -> bool:
         """Is CUDA graph capture under way?"""
@@ -477,7 +484,10 @@ def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
 
 
 @contextmanager
-def fp8_model_init(enabled: bool = True) -> None:
+def fp8_model_init(
+    enabled: bool = True,
+    preserve_high_precision_init_val: bool = False,
+) -> None:
     """
     Context manager for FP8 initialization of parameters.
 
@@ -488,6 +498,12 @@ def fp8_model_init(enabled: bool = True) -> None:
         with fp8_model_init(enabled=True):
             model = transformer_engine.pytorch.Linear(768, 768)
 
+        # Preserving high precision initial value to initialize master weight
+        with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
+            model = transformer_engine.pytorch.Linear(768, 768)
+            master_weight = model.weight.get_high_precision_init_val()
+            model.weight.clear_high_precision_init_val()
+
     Parameters
     ----------
     enabled: bool, default = `True`
@@ -501,15 +517,26 @@ def fp8_model_init(enabled: bool = True) -> None:
           precision copies of weights are already present in the optimizer.
         * inference, where only the FP8 copies of the parameters are used.
         * LoRA-like fine-tuning, where the main parameters of the model do not change.
+    preserve_high_precision_init_val: bool, default = `False`
+        when enabled, store the high precision tensor used to initialize FP8 parameters
+        in CPU memory, and add two function attributes named `get_high_precision_init_val()`
+        and `clear_high_precision_init_val()` to FP8 parameters to get/clear this high
+        precision tensor. The purpose is that users can use this high-precision copy
+        to initialize master weights, avoiding the loss of precision that can occur when
+        using FP8 parameters directly. Note that after the master weights are initialized,
+        users should call `clear_high_precision_init_val()` to release this CPU memory.
 
     This functionality is *EXPERIMENTAL*.
""" _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS FP8GlobalStateManager.FP8_PARAMETERS = enabled + _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL + FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val try: yield finally: FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters + FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val @contextmanager diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3613e1fa5e..27e7469434 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -13,6 +13,7 @@ from abc import ABC, abstractmethod from typing import Dict, Generator, List, Optional, Tuple, Union from contextlib import contextmanager +from types import MethodType import torch import torch.nn.functional as F @@ -389,6 +390,7 @@ def __init__(self) -> None: self.sequence_parallel = False self.param_init_meta = {} self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() + self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val() self.fsdp_wrapped = False self.fsdp_group = None self._fp8_workspaces: Dict[str, Float8Tensor] = {} @@ -865,6 +867,8 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # If primary weights are in fp8, wrap the parameter as Float8Tensor fp8_meta_index = self.param_init_meta[name].fp8_meta_index if self.primary_weights_in_fp8 and fp8_meta_index is not None: + if self.preserve_high_precision_init_val: + high_precision_init_val = param.detach().cpu() param = Float8Tensor.to_float8( param, fp8_meta=self.fp8_meta, @@ -876,7 +880,23 @@ def reset_parameters(self, defer_init: Optional[bool] = False) -> None: # NOTE: Currently this can only be broken when primary weights are in Fp8 but # re-applying the nn.Parameter() wrap is a no-op when the input is already # a parameter so we always re-apply it just for extra safety. - setattr(self, name, torch.nn.Parameter(param)) + param = torch.nn.Parameter(param) + if self.primary_weights_in_fp8 and self.preserve_high_precision_init_val: + def get(self): + if hasattr(self, "_high_precision_init_val"): + return self._high_precision_init_val + else: + return None + + def clear(self): + if hasattr(self, "_high_precision_init_val"): + del self._high_precision_init_val + + param._high_precision_init_val = high_precision_init_val + param.get_high_precision_init_val = MethodType(get, param) + param.clear_high_precision_init_val = MethodType(clear, param) + + setattr(self, name, param) @abstractmethod def forward(self):