refactor NanotronParameter to support fp8
xrsrke committed Nov 20, 2024
1 parent 478984a commit c7d9e8a
Showing 14 changed files with 168 additions and 60 deletions.
2 changes: 1 addition & 1 deletion examples/config_fp8_llama.yaml
@@ -44,7 +44,7 @@ logging:
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: int8
dtype: float8
init_method:
std: 0.025
make_vocab_size_divisible_by: 1
52 changes: 26 additions & 26 deletions examples/config_tiny_fp8_llama.yaml
@@ -1,5 +1,5 @@
checkpoints:
checkpoint_interval: 10
checkpoint_interval: 10000
checkpoints_path: checkpoints
checkpoints_path_is_shared_file_system: false
resume_checkpoint_path: null
@@ -10,25 +10,25 @@ data_stages:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: stas/openwebtext-10k
hf_dataset_or_datasets: roneneldan/TinyStories
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Stable Training Stage
start_training_step: 1
- data:
dataset:
dataset_overwrite_cache: false
dataset_processing_num_proc_per_process: 1
hf_dataset_config_name: null
hf_dataset_or_datasets: stas/openwebtext-10k
hf_dataset_splits: train
text_column_name: text
num_loading_workers: 1
seed: 42
name: Annealing Phase
start_training_step: 10
# - data:
# dataset:
# dataset_overwrite_cache: false
# dataset_processing_num_proc_per_process: 1
# hf_dataset_config_name: null
# hf_dataset_or_datasets: stas/openwebtext-10k
# hf_dataset_splits: train
# text_column_name: text
# num_loading_workers: 1
# seed: 42
# name: Annealing Phase
# start_training_step: 10
general:
benchmark_csv_path: null
consumed_train_samples: null
@@ -44,29 +44,29 @@ logging:
log_level_replica: info
model:
ddp_bucket_cap_mb: 25
dtype: bfloat16
dtype: float8
init_method:
std: 0.025
make_vocab_size_divisible_by: 1
model_config:
bos_token_id: 1
eos_token_id: 2
hidden_act: silu
hidden_size: 16
hidden_size: 1024
initializer_range: 0.02
intermediate_size: 64
intermediate_size: 4096
is_llama_config: true
max_position_embeddings: 256
max_position_embeddings: 1024
num_attention_heads: 4
num_hidden_layers: 2
num_hidden_layers: 6
num_key_value_heads: 4
pad_token_id: null
pretraining_tp: 1
rms_norm_eps: 1.0e-05
rope_scaling: null
tie_word_embeddings: true
use_cache: true
vocab_size: 256
vocab_size: 1024
optimizer:
accumulate_grad_in_fp32: true
clip_grad: 1.0
@@ -87,13 +87,13 @@ optimizer:
weight_decay: 0.01
zero_stage: 0
parallelism:
dp: 2
dp: 1
expert_parallel_size: 1
pp: 2
pp: 1
pp_engine: 1f1b
tp: 2
tp_linear_async_communication: true
tp_mode: REDUCE_SCATTER
tp_linear_async_communication: false
tp_mode: ALL_REDUCE
profiler: null
tokenizer:
tokenizer_max_length: null
@@ -104,6 +104,6 @@ tokens:
limit_test_batches: 0
limit_val_batches: 0
micro_batch_size: 2
sequence_length: 256
train_steps: 15
sequence_length: 1024
train_steps: 1500
val_check_interval: -1
6 changes: 6 additions & 0 deletions src/nanotron/config/config.py
@@ -11,6 +11,7 @@
from datasets.download.streaming_download_manager import xPath
from yaml.loader import SafeLoader

from nanotron.config.fp8_config import FP8Args
from nanotron.config.lighteval_config import LightEvalConfig
from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit, SpectralMupInit
from nanotron.config.parallelism_config import ParallelismArgs
@@ -351,6 +352,7 @@ class Config:
profiler: Optional[ProfilerArgs] = None
lighteval: Optional[LightEvalConfig] = None
s3_upload: Optional[S3UploadArgs] = None
fp8: Optional[FP8Args] = None

@classmethod
def create_empty(cls):
@@ -398,6 +400,10 @@ def __post_init__(self):
# if self.checkpoints.lighteval is not None:
# assert self.tokenizer.tokenizer_name_or_path is not None

if self.model.dtype == torch.int8:
if self.fp8 is None:
self.fp8 = FP8Args()

@property
def global_batch_size(self):
return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp
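For context, the new `__post_init__` logic only guarantees that an `fp8` section exists whenever the model dtype resolves to the int8 stand-in for float8. A minimal, self-contained sketch of that defaulting behavior, using simplified stand-ins (`SimpleFP8Args`, `SimpleConfig`) rather than nanotron's actual `FP8Args`/`Config` classes:

```python
# Illustrative sketch only: SimpleFP8Args and SimpleConfig are stand-ins,
# not nanotron's FP8Args/Config.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class SimpleFP8Args:
    init_dtype: torch.dtype = torch.float32   # dtype used before casting to fp8
    resid_dtype: torch.dtype = torch.float32  # dtype of the residual stream


@dataclass
class SimpleConfig:
    model_dtype: torch.dtype = torch.bfloat16
    fp8: Optional[SimpleFP8Args] = None

    def __post_init__(self):
        # "float8" is serialized as torch.int8, so int8 is the sentinel checked here.
        if self.model_dtype == torch.int8 and self.fp8 is None:
            self.fp8 = SimpleFP8Args()


cfg = SimpleConfig(model_dtype=torch.int8)
assert cfg.fp8 is not None and cfg.fp8.init_dtype == torch.float32
```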
2 changes: 2 additions & 0 deletions src/nanotron/config/fp8_config.py
@@ -16,6 +16,8 @@ def __post_init__(self):

@dataclass
class FP8Args:
# NOTE: this is the datatype of model initialization, before casting to fp8
init_dtype: torch.dtype = torch.float32
# NOTE: this is the datatype for residual stream (aka: non-fp8 operation)
resid_dtype: torch.dtype = torch.float32
# NOTE: the datatype for fp8 operation's accumulation
1 change: 1 addition & 0 deletions src/nanotron/config/utils_config.py
@@ -62,6 +62,7 @@ def serialize(data) -> dict:
"bfloat16": torch.bfloat16,
"uint8": torch.uint8,
"int8": torch.int8,
"float8": torch.int8,
"int16": torch.int16,
"int32": torch.int32,
"int64": torch.int64,
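The new `"float8"` entry maps the YAML dtype string onto `torch.int8`, since the FP8 payload lives in an 8-bit integer container and int8 acts as the stand-in dtype downstream. A tiny sketch of that lookup, assuming a mapping shaped like the one in `utils_config.py` (the dict name here is illustrative, not the module's actual API):

```python
# Sketch of the string -> torch.dtype lookup; DTYPE_FROM_STRING is illustrative.
import torch

DTYPE_FROM_STRING = {
    "float32": torch.float32,
    "bfloat16": torch.bfloat16,
    "uint8": torch.uint8,
    "int8": torch.int8,
    # "float8" is mapped onto torch.int8: the fp8 bytes are stored in an
    # 8-bit integer container, so int8 stands in for float8 everywhere else.
    "float8": torch.int8,
}

assert DTYPE_FROM_STRING["float8"] is torch.int8
```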
2 changes: 2 additions & 0 deletions src/nanotron/constants.py
@@ -14,4 +14,6 @@

# TODO(xrsrke): remove this shit
ITERATION_STEP = 1
# TODO(xrsrke): refactor to training stage,
# keep it in the same class as iteration_step
CONFIG = None
4 changes: 3 additions & 1 deletion src/nanotron/fp8/functional.py
@@ -5,6 +5,7 @@
from nanotron.fp8.linear import FP8LinearMeta
from nanotron.fp8.recipe import FP8LinearRecipe
from nanotron.fp8.tensor import FP8Tensor
from nanotron.parallel.parameters import NanotronParameter


def smooth_quant(input: torch.Tensor, weight: FP8Tensor, alpha: float) -> Tuple[torch.Tensor, FP8Tensor]:
@@ -32,12 +33,13 @@ def smooth_quant(input: torch.Tensor, weight: FP8Tensor, alpha: float) -> Tuple[

def linear(
input: torch.Tensor,
weight: FP8Tensor,
weight: NanotronParameter,
bias: Optional[torch.Tensor] = None,
metadatas: FP8LinearMeta = None,
recipe: FP8LinearRecipe = None,
name: Optional[str] = None,
):
assert isinstance(weight, NanotronParameter)
from typing import cast

from nanotron import constants
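The functional `linear` now receives the `NanotronParameter` wrapper instead of a raw `FP8Tensor`, and asserts that before doing anything else; the quantized payload is then reached through the parameter. A toy, self-contained illustration of that wrapper pattern (plain torch, `ToyParameter`/`toy_linear` are not nanotron classes):

```python
# Toy illustration of the wrapper pattern: the caller passes a parameter
# object, and the kernel-facing code reads the quantized payload via .data.
import torch


class ToyParameter:
    """Stand-in for a parameter that carries int8 data plus a scale."""

    def __init__(self, data: torch.Tensor, scale: float):
        self.data = data    # int8 payload
        self.scale = scale  # dequantization scale
        self.grad = None


def toy_linear(input: torch.Tensor, weight: ToyParameter) -> torch.Tensor:
    assert isinstance(weight, ToyParameter)  # mirrors the new isinstance assert
    w = weight.data.float() * weight.scale   # dequantize for the matmul
    return input @ w.t()


w_fp32 = torch.randn(8, 4)
scale = w_fp32.abs().max().item() / 127.0
w_int8 = torch.clamp(torch.round(w_fp32 / scale), -127, 127).to(torch.int8)
out = toy_linear(torch.randn(2, 4), ToyParameter(w_int8, scale))
assert out.shape == (2, 8)
```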
50 changes: 32 additions & 18 deletions src/nanotron/fp8/linear.py
@@ -12,6 +12,7 @@
from nanotron.fp8.parameter import FP8Parameter
from nanotron.fp8.recipe import FP8LinearRecipe
from nanotron.fp8.tensor import FP8Tensor
from nanotron.parallel.parameters import NanotronParameter


@dataclass
@@ -36,7 +37,7 @@ def __init__(
bias: bool = True,
device: Optional[torch.device] = None,
# accum_qtype: DTypes = FP8LM_RECIPE.linear.accum_dtype,
recipe: FP8LinearRecipe = FP8LM_LINEAR_RECIPE,
# recipe: FP8LinearRecipe = FP8LM_LINEAR_RECIPE,
# NOTE: placeholder for dtype in torch's nn.Linear
# TODO(xrsrke): remove this shit
**kwargs,
@@ -50,19 +51,26 @@
# TODO(xrsrke): take initialization dtype from recipe
# NOTE: initialize in float32
super().__init__(in_features, out_features, bias, device, dtype=torch.float32)
self._quantize_weights()

assert self.bias is None
# if self.bias is not None:
# self.bias = nn.Parameter(self.bias.to(recipe.accum_dtype))
# assert self.bias.dtype == recipe.accum_dtype

# self.metadatas = FP8LinearMeta()
# self.recipe = recipe

def _quantize_weights(self, recipe: FP8LinearRecipe = FP8LM_LINEAR_RECIPE):
quant_w = FP8Parameter(self.weight.data, dtype=recipe.weight.dtype, interval=recipe.weight.interval)
assert quant_w.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}"
self.weight = quant_w
# assert quant_w.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}"
# self.weight = quant_w
setattr(self.weight, "data", quant_w)

if self.name == "model.decoder.0.attention.qkv_proj":
assert 1 == 1

assert self.weight.data.dtype in [torch.uint8, torch.int8], f"got {self.weight.data.dtype}, name: {self.name}"

if self.bias is not None:
self.bias = nn.Parameter(self.bias.to(recipe.accum_dtype))
assert self.bias.dtype == recipe.accum_dtype

# NOTE: assume each time we requantize the weights, we reset the metadata
self.metadatas = FP8LinearMeta()
self.recipe = recipe

@@ -72,14 +80,16 @@ def forward(self, input: Union[FP8Tensor, torch.Tensor]) -> torch.Tensor:

return F.linear(
input=input,
weight=get_data_from_param(self.weight),
bias=None if self.bias is None else get_data_from_param(self.bias),
# weight=get_data_from_param(self.weight),
# bias=None if self.bias is None else get_data_from_param(self.bias),
weight=self.weight,
bias=None,
metadatas=self.metadatas,
recipe=self.recipe,
)

def __repr__(self) -> str:
return f"FP8{super().__repr__()}"
# def __repr__(self) -> str:
# return f"FP8{super().__repr__()}"


class _FP8Matmul(torch.autograd.Function):
@@ -88,14 +98,15 @@ class _FP8Matmul(torch.autograd.Function):
def forward(
ctx,
input: Union[FP8Tensor, torch.Tensor],
weight: FP8Tensor,
weight: NanotronParameter,
output: torch.Tensor,
phony: torch.Tensor,
metadatas: FP8LinearMeta,
recipe: FP8LinearRecipe,
name,
) -> torch.Tensor:
assert not isinstance(input, FP8Tensor)
assert isinstance(weight, NanotronParameter)

from nanotron import constants
from nanotron.config.fp8_config import FP8Args
@@ -125,7 +136,8 @@ def forward(

output = fp8_matmul_kernel(
# NOTE: that works
mat_a=weight,
# mat_a=weight, # i used weight before removing get_data_from_param
mat_a=weight.data,
mat_b=fp8_input,
output=accum_output,
use_split_accumulator=recipe.split_accumulator.output,
@@ -147,6 +159,7 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
from nanotron import constants
from nanotron.config.fp8_config import FP8Args

# pydevd.settrace(suspend=False, trace_only_current_thread=True)
if (
constants.CONFIG is not None
and constants.CONFIG.fp8 is not None
@@ -164,7 +177,8 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
sync_amax_in_igrad = fp8_config.sync_amax_in_igrad
sync_amax_in_wgrad = fp8_config.sync_amax_in_wgrad

fp8_input, fp8_weight = ctx.saved_tensors
fp8_input, fp8_weight_param = ctx.saved_tensors
fp8_weight = fp8_weight_param.data
recipe = ctx.recipe
recipe = cast(FP8LinearRecipe, recipe)

@@ -261,9 +275,9 @@ def backward(ctx, grad_output: torch.Tensor, grad_phony: torch.Tensor) -> Tuple[
grad_weight, ctx.metadatas.weight_grad, sync=sync_amax_in_wgrad
)

fp8_weight.grad = fp8_weight_grad
fp8_weight_param.grad = fp8_weight_grad

# NOTE: sanity check
assert isinstance(fp8_weight.grad, FP8Tensor)
assert isinstance(fp8_weight_param.grad, FP8Tensor)

return grad_input, None, None, None, None, None, None
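The core of the refactor in this file: the module keeps its original parameter object, swaps its `.data` for the quantized tensor (`setattr(self.weight, "data", quant_w)`), passes the parameter itself into the autograd function, reads `weight.data` for the matmul, and hangs the FP8 weight gradient directly on the parameter in `backward`. A minimal runnable sketch of that pattern, with a toy wrapper and matmul that are illustrative only, not `NanotronParameter` or `_FP8Matmul`:

```python
# Minimal sketch of the pattern: the autograd function receives the parameter
# wrapper, computes from its .data, and attaches the weight gradient to the
# wrapper itself in backward(). ToyParam/ToyMatmul are illustrative stand-ins.
import torch


class ToyParam:
    def __init__(self, data: torch.Tensor):
        self.data = data   # would be the quantized FP8 payload in nanotron
        self.grad = None


class ToyMatmul(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input: torch.Tensor, weight: ToyParam) -> torch.Tensor:
        ctx.weight = weight                 # keep the wrapper, not just its data
        ctx.save_for_backward(input)
        return input @ weight.data.t()      # the kernel reads weight.data

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (input,) = ctx.saved_tensors
        grad_input = grad_output @ ctx.weight.data
        ctx.weight.grad = grad_output.t() @ input  # grad attached to the wrapper
        return grad_input, None                    # no autograd grad for the wrapper


x = torch.randn(2, 4, requires_grad=True)
w = ToyParam(torch.randn(8, 4))
ToyMatmul.apply(x, w).sum().backward()
assert w.grad.shape == w.data.shape and x.grad is not None
```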
19 changes: 18 additions & 1 deletion src/nanotron/models/base.py
@@ -369,6 +369,22 @@ def init_on_device_and_dtype(
NOTE: in order to initialize an hybrid fp8 properly, you should use this context manager
```
"""
from typing import cast
from nanotron import constants
from nanotron.config.fp8_config import FP8Args

# NOTE: the reason we do float32 init here because my educated guess is that
# if we initially initialize models weight in float32, we have a more accurate representation
# of the target normal distribution rather than int8 (didn't do ablation study on this)

# NOTE: if the model's training dtype is float8, then we retrieve
# the initialization dtype from the fp8 config
if constants.CONFIG is not None and dtype is torch.int8:
training_config = cast(FP8Args, constants.CONFIG.fp8)
init_dtype = training_config.init_dtype
else:
init_dtype = dtype

from functools import wraps

def method_partial(func, *args, **kwargs):
@@ -403,7 +419,8 @@ def wrapper(*args, **kwargs):
# NOTE: nanotron automatically sets the device and dtype of the tensor
# but for FP8 training, we initializes with float16 first
kwargs["device"] = device
kwargs["dtype"] = torch.float32 if dtype == torch.int8 else dtype
# kwargs["dtype"] = torch.float32 if dtype == torch.int8 else dtype
kwargs["dtype"] = init_dtype
return fn(*args, **kwargs)

return wrapper
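Instead of hard-coding float32, the context manager now takes the initialization dtype from the fp8 config when the training dtype is the int8 stand-in for float8, and quantizes afterwards. A small sketch of that selection, under the assumption of an fp8 config exposing `init_dtype` as added in `fp8_config.py` (the function name is illustrative):

```python
# Sketch of the init-dtype selection; resolve_init_dtype is an illustrative name.
from typing import Optional

import torch


def resolve_init_dtype(training_dtype: torch.dtype,
                       fp8_init_dtype: Optional[torch.dtype]) -> torch.dtype:
    # torch.int8 is the stand-in for float8 training: weights are first
    # materialized in a higher-precision dtype, then quantized to fp8.
    if training_dtype == torch.int8 and fp8_init_dtype is not None:
        return fp8_init_dtype
    return training_dtype


assert resolve_init_dtype(torch.int8, torch.float32) is torch.float32
assert resolve_init_dtype(torch.bfloat16, torch.float32) is torch.bfloat16
```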
