From c7d9e8af15e5019f17fefdccb4393e1206866be5 Mon Sep 17 00:00:00 2001 From: Phuc Nguyen Date: Wed, 20 Nov 2024 16:30:46 +0000 Subject: [PATCH] refactor NanotronParameter to support fp8 --- examples/config_fp8_llama.yaml | 2 +- examples/config_tiny_fp8_llama.yaml | 52 +++++++++---------- src/nanotron/config/config.py | 6 +++ src/nanotron/config/fp8_config.py | 2 + src/nanotron/config/utils_config.py | 1 + src/nanotron/constants.py | 2 + src/nanotron/fp8/functional.py | 4 +- src/nanotron/fp8/linear.py | 50 +++++++++++------- src/nanotron/models/base.py | 19 ++++++- src/nanotron/models/llama.py | 15 ++++-- src/nanotron/optim/gradient_accumulator.py | 16 ++++-- src/nanotron/parallel/parameters.py | 14 +++-- .../parallel/tensor_parallel/functional.py | 6 ++- src/nanotron/trainer.py | 39 +++++++++++++- 14 files changed, 168 insertions(+), 60 deletions(-) diff --git a/examples/config_fp8_llama.yaml b/examples/config_fp8_llama.yaml index fbbad98f..bee3bd5e 100644 --- a/examples/config_fp8_llama.yaml +++ b/examples/config_fp8_llama.yaml @@ -44,7 +44,7 @@ logging: log_level_replica: info model: ddp_bucket_cap_mb: 25 - dtype: int8 + dtype: float8 init_method: std: 0.025 make_vocab_size_divisible_by: 1 diff --git a/examples/config_tiny_fp8_llama.yaml b/examples/config_tiny_fp8_llama.yaml index 58645e2d..dd0f2182 100644 --- a/examples/config_tiny_fp8_llama.yaml +++ b/examples/config_tiny_fp8_llama.yaml @@ -1,5 +1,5 @@ checkpoints: - checkpoint_interval: 10 + checkpoint_interval: 10000 checkpoints_path: checkpoints checkpoints_path_is_shared_file_system: false resume_checkpoint_path: null @@ -10,25 +10,25 @@ data_stages: dataset_overwrite_cache: false dataset_processing_num_proc_per_process: 1 hf_dataset_config_name: null - hf_dataset_or_datasets: stas/openwebtext-10k + hf_dataset_or_datasets: roneneldan/TinyStories hf_dataset_splits: train text_column_name: text num_loading_workers: 1 seed: 42 name: Stable Training Stage start_training_step: 1 -- data: - dataset: - dataset_overwrite_cache: false - dataset_processing_num_proc_per_process: 1 - hf_dataset_config_name: null - hf_dataset_or_datasets: stas/openwebtext-10k - hf_dataset_splits: train - text_column_name: text - num_loading_workers: 1 - seed: 42 - name: Annealing Phase - start_training_step: 10 +# - data: +# dataset: +# dataset_overwrite_cache: false +# dataset_processing_num_proc_per_process: 1 +# hf_dataset_config_name: null +# hf_dataset_or_datasets: stas/openwebtext-10k +# hf_dataset_splits: train +# text_column_name: text +# num_loading_workers: 1 +# seed: 42 +# name: Annealing Phase +# start_training_step: 10 general: benchmark_csv_path: null consumed_train_samples: null @@ -44,7 +44,7 @@ logging: log_level_replica: info model: ddp_bucket_cap_mb: 25 - dtype: bfloat16 + dtype: float8 init_method: std: 0.025 make_vocab_size_divisible_by: 1 @@ -52,13 +52,13 @@ model: bos_token_id: 1 eos_token_id: 2 hidden_act: silu - hidden_size: 16 + hidden_size: 1024 initializer_range: 0.02 - intermediate_size: 64 + intermediate_size: 4096 is_llama_config: true - max_position_embeddings: 256 + max_position_embeddings: 1024 num_attention_heads: 4 - num_hidden_layers: 2 + num_hidden_layers: 6 num_key_value_heads: 4 pad_token_id: null pretraining_tp: 1 @@ -66,7 +66,7 @@ model: rope_scaling: null tie_word_embeddings: true use_cache: true - vocab_size: 256 + vocab_size: 1024 optimizer: accumulate_grad_in_fp32: true clip_grad: 1.0 @@ -87,13 +87,13 @@ optimizer: weight_decay: 0.01 zero_stage: 0 parallelism: - dp: 2 + dp: 1 expert_parallel_size: 1 - pp: 2 + pp: 1 
pp_engine: 1f1b tp: 2 - tp_linear_async_communication: true - tp_mode: REDUCE_SCATTER + tp_linear_async_communication: false + tp_mode: ALL_REDUCE profiler: null tokenizer: tokenizer_max_length: null @@ -104,6 +104,6 @@ tokens: limit_test_batches: 0 limit_val_batches: 0 micro_batch_size: 2 - sequence_length: 256 - train_steps: 15 + sequence_length: 1024 + train_steps: 1500 val_check_interval: -1 diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index c50334f6..784082fd 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -11,6 +11,7 @@ from datasets.download.streaming_download_manager import xPath from yaml.loader import SafeLoader +from nanotron.config.fp8_config import FP8Args from nanotron.config.lighteval_config import LightEvalConfig from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit, SpectralMupInit from nanotron.config.parallelism_config import ParallelismArgs @@ -351,6 +352,7 @@ class Config: profiler: Optional[ProfilerArgs] = None lighteval: Optional[LightEvalConfig] = None s3_upload: Optional[S3UploadArgs] = None + fp8: Optional[FP8Args] = None @classmethod def create_empty(cls): @@ -398,6 +400,10 @@ def __post_init__(self): # if self.checkpoints.lighteval is not None: # assert self.tokenizer.tokenizer_name_or_path is not None + if self.model.dtype == torch.int8: + if self.fp8 is None: + self.fp8 = FP8Args() + @property def global_batch_size(self): return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp diff --git a/src/nanotron/config/fp8_config.py b/src/nanotron/config/fp8_config.py index 2aa41b01..a0509491 100644 --- a/src/nanotron/config/fp8_config.py +++ b/src/nanotron/config/fp8_config.py @@ -16,6 +16,8 @@ def __post_init__(self): @dataclass class FP8Args: + # NOTE: this is the datatype of model initialization, before casting to fp8 + init_dtype: torch.dtype = torch.float32 # NOTE: this is the datatype for residual stream (aka: non-fp8 operation) resid_dtype: torch.dtype = torch.float32 # NOTE: the datatype for fp8 operation's accumulation diff --git a/src/nanotron/config/utils_config.py b/src/nanotron/config/utils_config.py index bf8407fc..6cc092d4 100644 --- a/src/nanotron/config/utils_config.py +++ b/src/nanotron/config/utils_config.py @@ -62,6 +62,7 @@ def serialize(data) -> dict: "bfloat16": torch.bfloat16, "uint8": torch.uint8, "int8": torch.int8, + "float8": torch.int8, "int16": torch.int16, "int32": torch.int32, "int64": torch.int64, diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index 8cce110f..0965e320 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -14,4 +14,6 @@ # TODO(xrsrke): remove this shit ITERATION_STEP = 1 +# TODO(xrsrke): refactor to training stage, +# keep it in the same class as iteration_step CONFIG = None diff --git a/src/nanotron/fp8/functional.py b/src/nanotron/fp8/functional.py index 118c9d98..22f61654 100644 --- a/src/nanotron/fp8/functional.py +++ b/src/nanotron/fp8/functional.py @@ -5,6 +5,7 @@ from nanotron.fp8.linear import FP8LinearMeta from nanotron.fp8.recipe import FP8LinearRecipe from nanotron.fp8.tensor import FP8Tensor +from nanotron.parallel.parameters import NanotronParameter def smooth_quant(input: torch.Tensor, weight: FP8Tensor, alpha: float) -> Tuple[torch.Tensor, FP8Tensor]: @@ -32,12 +33,13 @@ def smooth_quant(input: torch.Tensor, weight: FP8Tensor, alpha: float) -> Tuple[ def linear( input: torch.Tensor, - weight: FP8Tensor, + weight: 
NanotronParameter, bias: Optional[torch.Tensor] = None, metadatas: FP8LinearMeta = None, recipe: FP8LinearRecipe = None, name: Optional[str] = None, ): + assert isinstance(weight, NanotronParameter) from typing import cast from nanotron import constants diff --git a/src/nanotron/fp8/linear.py b/src/nanotron/fp8/linear.py index 4936e42c..2ff0d85c 100644 --- a/src/nanotron/fp8/linear.py +++ b/src/nanotron/fp8/linear.py @@ -12,6 +12,7 @@ from nanotron.fp8.parameter import FP8Parameter from nanotron.fp8.recipe import FP8LinearRecipe from nanotron.fp8.tensor import FP8Tensor +from nanotron.parallel.parameters import NanotronParameter @dataclass @@ -36,7 +37,7 @@ def __init__( bias: bool = True, device: Optional[torch.device] = None, # accum_qtype: DTypes = FP8LM_RECIPE.linear.accum_dtype, - recipe: FP8LinearRecipe = FP8LM_LINEAR_RECIPE, + # recipe: FP8LinearRecipe = FP8LM_LINEAR_RECIPE, # NOTE: placeholder for dtype in torch's nn.Linear # TODO(xrsrke): remove this shit **kwargs, @@ -50,19 +51,26 @@ def __init__( # TODO(xrsrke): take initialization dtype from recipe # NOTE: initialize in float32 super().__init__(in_features, out_features, bias, device, dtype=torch.float32) + self._quantize_weights() + + assert self.bias is None + # if self.bias is not None: + # self.bias = nn.Parameter(self.bias.to(recipe.accum_dtype)) + # assert self.bias.dtype == recipe.accum_dtype + + # self.metadatas = FP8LinearMeta() + # self.recipe = recipe + + def _quantize_weights(self, recipe: FP8LinearRecipe = FP8LM_LINEAR_RECIPE): quant_w = FP8Parameter(self.weight.data, dtype=recipe.weight.dtype, interval=recipe.weight.interval) - assert quant_w.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}" - self.weight = quant_w + # assert quant_w.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}" + # self.weight = quant_w + setattr(self.weight, "data", quant_w) if self.name == "model.decoder.0.attention.qkv_proj": assert 1 == 1 - assert self.weight.data.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}, name: {self.name}" - - if self.bias is not None: - self.bias = nn.Parameter(self.bias.to(recipe.accum_dtype)) - assert self.bias.dtype == recipe.accum_dtype - + # NOTE: assume each time we requantize the weights, we reset the metadata self.metadatas = FP8LinearMeta() self.recipe = recipe @@ -72,14 +80,16 @@ def forward(self, input: Union[FP8Tensor, torch.Tensor]) -> torch.Tensor: return F.linear( input=input, - weight=get_data_from_param(self.weight), - bias=None if self.bias is None else get_data_from_param(self.bias), + # weight=get_data_from_param(self.weight), + # bias=None if self.bias is None else get_data_from_param(self.bias), + weight=self.weight, + bias=None, metadatas=self.metadatas, recipe=self.recipe, ) - def __repr__(self) -> str: - return f"FP8{super().__repr__()}" + # def __repr__(self) -> str: + # return f"FP8{super().__repr__()}" class _FP8Matmul(torch.autograd.Function): @@ -88,7 +98,7 @@ class _FP8Matmul(torch.autograd.Function): def forward( ctx, input: Union[FP8Tensor, torch.Tensor], - weight: FP8Tensor, + weight: NanotronParameter, output: torch.Tensor, phony: torch.Tensor, metadatas: FP8LinearMeta, @@ -96,6 +106,7 @@ def forward( name, ) -> torch.Tensor: assert not isinstance(input, FP8Tensor) + assert isinstance(weight, NanotronParameter) from nanotron import constants from nanotron.config.fp8_config import FP8Args @@ -125,7 +136,8 @@ def forward( output = fp8_matmul_kernel( # NOTE: that works - mat_a=weight, + # mat_a=weight, # i used weight before 
removing get_data_from_param + mat_a=weight.data, mat_b=fp8_input, output=accum_output, use_split_accumulator=recipe.split_accumulator.output, @@ -147,6 +159,7 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[ from nanotron import constants from nanotron.config.fp8_config import FP8Args + # pydevd.settrace(suspend=False, trace_only_current_thread=True) if ( constants.CONFIG is not None and constants.CONFIG.fp8 is not None @@ -164,7 +177,8 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[ sync_amax_in_igrad = fp8_config.sync_amax_in_igrad sync_amax_in_wgrad = fp8_config.sync_amax_in_wgrad - fp8_input, fp8_weight = ctx.saved_tensors + fp8_input, fp8_weight_param = ctx.saved_tensors + fp8_weight = fp8_weight_param.data recipe = ctx.recipe recipe = cast(FP8LinearRecipe, recipe) @@ -261,9 +275,9 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[ grad_weight, ctx.metadatas.weight_grad, sync=sync_amax_in_wgrad ) - fp8_weight.grad = fp8_weight_grad + fp8_weight_param.grad = fp8_weight_grad # NOTE: sanity check - assert isinstance(fp8_weight.grad, FP8Tensor) + assert isinstance(fp8_weight_param.grad, FP8Tensor) return grad_input, None, None, None, None, None, None diff --git a/src/nanotron/models/base.py b/src/nanotron/models/base.py index 78fd3965..8dbb2e3e 100644 --- a/src/nanotron/models/base.py +++ b/src/nanotron/models/base.py @@ -369,6 +369,22 @@ def init_on_device_and_dtype( NOTE: in order to initialize an hybrid fp8 properly, you should use this context manager ``` """ + from typing import cast + from nanotron import constants + from nanotron.config.fp8_config import FP8Args + + # NOTE: the reason we do float32 init here because my educated guess is that + # if we initially initialize models weight in float32, we have a more accurate representation + # of the target normal distribution rather than int8 (didn't do ablation study on this) + + # NOTE: if the model's training dtype is float8, then we retrieve + # the initialization dtype from the fp8 config + if constants.CONFIG is not None and dtype is torch.int8: + training_config = cast(FP8Args, constants.CONFIG.fp8) + init_dtype = training_config.init_dtype + else: + init_dtype = dtype + from functools import wraps def method_partial(func, *args, **kwargs): @@ -403,7 +419,8 @@ def wrapper(*args, **kwargs): # NOTE: nanotron automatically sets the device and dtype of the tensor # but for FP8 training, we initializes with float16 first kwargs["device"] = device - kwargs["dtype"] = torch.float32 if dtype == torch.int8 else dtype + # kwargs["dtype"] = torch.float32 if dtype == torch.int8 else dtype + kwargs["dtype"] = init_dtype return fn(*args, **kwargs) return wrapper diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index 6ebbeb31..19b60993 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -229,7 +229,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=gate_up_contiguous_chunks, - # name=f"model.decoder.{layer_idx}.mlp.gate_up_proj", + name=f"model.decoder.{layer_idx}.mlp.gate_up_proj", # tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) self.down_proj = TensorParallelRowLinear( @@ -239,6 +239,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication and tp_mode is TensorParallelLinearMode.REDUCE_SCATTER, + name=f"model.decoder.{layer_idx}.mlp.down_proj", ) self.split_silu_mul = 
GLUActivation(config.hidden_act) @@ -392,7 +393,7 @@ def __init__( bias=False, async_communication=tp_linear_async_communication, contiguous_chunks=qkv_contiguous_chunks, - # name=f"model.decoder.{layer_idx}.attention.qkv_proj", + name=f"model.decoder.{layer_idx}.attention.qkv_proj", # tp_recompute_allgather=parallel_config.tp_recompute_allgather, ) # TODO(kunhao): We want to have only one version per device and not one version per layer. @@ -422,7 +423,7 @@ def __init__( mode=tp_mode, bias=False, async_communication=tp_linear_async_communication, - # name=f"model.decoder.{layer_idx}.attention.o_proj", + name=f"model.decoder.{layer_idx}.attention.o_proj", ) self.attention = CoreAttention( @@ -681,6 +682,13 @@ def forward( batch_size * kv_length, self.n_local_kv_heads, self.d_v ) # [batch_size * kv_length, self.n_heads, d_v] + # NOTE: even though in some cases, we accumulate fp8 gemm in bfloat16, + # but since the layer norm are in float32, the resulting output will be in float32 + # and flash attention don't support float32 qkv, so we have to cast it back to bfloat16 + query_states = query_states.to(torch.bfloat16) + key_states = key_states.to(torch.bfloat16) + value_states = value_states.to(torch.bfloat16) + attention_output = self.attention( query_states=query_states, key_states=key_states, @@ -877,6 +885,7 @@ def __init__( # TODO @thomasw21: refactor so that we store that default in a single place. "mode": self.tp_mode, "async_communication": tp_linear_async_communication, + "name": "model.lm_head", # "tp_recompute_allgather": parallel_config.tp_recompute_allgather, }, module_input_keys={"x"}, diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index 74165ca2..fbad0250 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -154,6 +154,11 @@ def sync_gradients_across_dp(self, dp_pg: dist.ProcessGroup, reduce_op: dist.Red else: dist.all_reduce(self._contiguous_fp32_grad_buffer, op=reduce_op, group=dp_pg) + @classmethod + def _is_accumulate_param(cls, param: NanotronParameter) -> bool: + from nanotron.fp8.tensor import FP8Tensor + return param.requires_grad or param.data.__class__ == FP8Tensor + @staticmethod def build_grad_buffers( named_parameters: Iterator[Tuple[str, NanotronParameter]], @@ -166,7 +171,7 @@ def build_grad_buffers( Note: In ZeRO-1, we need to accumulate grads for all parameters, because we need to allreduce all parameters' grads across DP at each sync step. 
""" - named_parameters = [(name, param) for name, param in named_parameters if param.requires_grad] + named_parameters = [(name, param) for name, param in named_parameters if FP32GradientAccumulator._is_accumulate_param(param)] needed_buffer_size = sum(param.numel() for _, param in named_parameters) # important to have grads zeroed initially (see `self._accumulate_grad`) @@ -178,16 +183,19 @@ def build_grad_buffers( fp32_grad_buffers = OrderedDict() # keeps order of insertion offset = 0 for name, param in named_parameters: - if not param.requires_grad: + # NOTE: because fp8 parameter by default has `requires_grad=False`, + # but we still need to accumulate grads for it + # if not param.requires_grad: + if FP32GradientAccumulator._is_accumulate_param(param) is False: continue - assert param.dtype != torch.float, f"Expected {name} not to be float" + # assert param.dtype != torch.float, f"Expected {name} not to be float" assert param.is_contiguous(), f"Expected {name} to be contiguous" next_offset = offset + param.numel() * element_size fp32_grad_buffer = tensor_from_untyped_storage( - untyped_storage=untyped_storage[offset:next_offset], dtype=torch.float + untyped_storage=untyped_storage[offset:next_offset], dtype=torch.float32 ) fp32_grad_buffers[name] = { diff --git a/src/nanotron/parallel/parameters.py b/src/nanotron/parallel/parameters.py index cb925131..da179468 100644 --- a/src/nanotron/parallel/parameters.py +++ b/src/nanotron/parallel/parameters.py @@ -227,6 +227,14 @@ def data(self): def data(self, data): self._data = data + @property + def grad(self): + return self._grad + + @grad.setter + def grad(self, grad): + self._grad = grad + # @property # def grad(self): # return self.data.grad if self.grad is None else self.grad @@ -269,9 +277,9 @@ def wrap(e): # data = args[0].data # data.requires_grad = args[0].requires_grad data = unwrapped_args[0] - if data.__class__ == FP8Parameter or data.__class__ == nn.Parameter: - data = data.data - # data.requires_grad = unwrapped_args[0].requires_grad + # if data.__class__ == FP8Parameter or data.__class__ == nn.Parameter: + # data = data.data + # # data.requires_grad = unwrapped_args[0].requires_grad return data else: diff --git a/src/nanotron/parallel/tensor_parallel/functional.py b/src/nanotron/parallel/tensor_parallel/functional.py index 6d56ed7f..61dbdd1d 100644 --- a/src/nanotron/parallel/tensor_parallel/functional.py +++ b/src/nanotron/parallel/tensor_parallel/functional.py @@ -456,7 +456,8 @@ def column_linear( input = differentiable_identity(input, group=group) - if isinstance(weight, FP8Tensor): + # if isinstance(weight, FP8Tensor): # i used weight before removing get_data_from_param + if isinstance(weight.data, FP8Tensor): assert recipe is not None, "recipe must be provided for column_linear" from nanotron import constants @@ -642,7 +643,8 @@ def row_linear( import nanotron.fp8.functional as fp8_functional from nanotron.fp8.tensor import FP8Tensor - if isinstance(weight, FP8Tensor): + # if isinstance(weight, FP8Tensor): # i used weight before removing get_data_from_param + if isinstance(weight.data, FP8Tensor): assert recipe is not None, "recipe must be provided for row_linear" out = fp8_functional.linear(input, weight, bias, metadatas=metadatas, recipe=recipe, name=name) else: diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 45d704ee..8a7a5304 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -68,7 +68,7 @@ ) from nanotron.parallel.pipeline_parallel.utils import get_pp_rank_of from 
nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode -from nanotron.parallel.tensor_parallel.nn import TensorParallelRowLinear +from nanotron.parallel.tensor_parallel.nn import TensorParallelRowLinear, TensorParallelColumnLinear from nanotron.parallel.tied_parameters import ( create_pg_for_tied_weights, get_tied_id_to_param, @@ -134,6 +134,8 @@ def __init__( self.config = get_config_from_file( config_or_config_file, config_class=config_class, model_config_class=model_config_class ) + from nanotron import constants + constants.CONFIG = self.config self.model_config = self.config.model.model_config if model_class is not None: CONFIG_TO_MODEL_CLASS[self.model_config.__class__.__name__] = model_class @@ -189,6 +191,41 @@ def __init__( optimizer_args=self.config.optimizer, parallel_context=self.parallel_context, ) + + from nanotron.fp8.utils import get_leaf_modules + def print_sanity_params(model): + for n, p in model.named_parameters(): + print(n, p.__class__.__name__, p.requires_grad, p.data.dtype) + + print("before quantize") + print_sanity_params(self.model) + + assert 1 == 1 + # NOTE: convert to FP8 + from nanotron.parallel.tensor_parallel.nn import FP8TensorParallelColumnLinear, FP8TensorParallelRowLinear + from nanotron.parallel.parameters import NanotronParameter + from nanotron.fp8.tensor import FP8Tensor + + TP_LINEAR_CLS_TO_FP8_LINEAR_CLS = { + TensorParallelColumnLinear: FP8TensorParallelColumnLinear, + TensorParallelRowLinear: FP8TensorParallelRowLinear, + } + if self.config.model.dtype is torch.int8: + for name, module in get_leaf_modules(self.model): + if isinstance(module, (TensorParallelColumnLinear, TensorParallelRowLinear)): + print(f"Converting {name} to FP8") + module.__class__ = TP_LINEAR_CLS_TO_FP8_LINEAR_CLS[module.__class__] + # TODO(xrsrke): retrieve custom recipe + module._quantize_weights() + + assert isinstance(module.weight, NanotronParameter) + assert isinstance(module.weight.data, FP8Tensor) + assert module.weight.data.dtype in [torch.uint8, torch.int8], f"got {module.weight.data.dtype}, name: {name}" + + print("after quantize") + print_sanity_params(self.model) + assert 1 == 1 + if self.init_checkpoint_path is not None: load_optimizer( optimizer=self.optimizer,
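
Note on the trainer.py hunk above: the FP8 conversion is a class swap over the leaf modules, followed by an in-place re-quantization of each weight, and it only runs when model.dtype is the float8 alias (torch.int8). A minimal sketch of that flow, using only the classes and helpers already imported in the hunks above (debug prints and sanity asserts omitted; illustrative, not the exact committed code):

    import torch

    from nanotron.fp8.utils import get_leaf_modules
    from nanotron.parallel.tensor_parallel.nn import (
        FP8TensorParallelColumnLinear,
        FP8TensorParallelRowLinear,
        TensorParallelColumnLinear,
        TensorParallelRowLinear,
    )

    # bf16/fp32 tensor-parallel linear classes -> their FP8 counterparts
    TP_LINEAR_CLS_TO_FP8_LINEAR_CLS = {
        TensorParallelColumnLinear: FP8TensorParallelColumnLinear,
        TensorParallelRowLinear: FP8TensorParallelRowLinear,
    }

    def convert_tp_linears_to_fp8(model: torch.nn.Module) -> torch.nn.Module:
        # Swap the class of every TP linear leaf module, then requantize its
        # weight in place (module.weight.data becomes an FP8Tensor).
        for _name, module in get_leaf_modules(model):
            if isinstance(module, (TensorParallelColumnLinear, TensorParallelRowLinear)):
                module.__class__ = TP_LINEAR_CLS_TO_FP8_LINEAR_CLS[type(module)]
                module._quantize_weights()
        return model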