Sync 4 layer norms - bf16, fp32, optimizer states on restart #274

Open · wants to merge 40 commits into main
Changes from all commits (40 commits):
8d7a603
WIP
thomasw21 Mar 23, 2022
240f673
Wip
thomasw21 Mar 24, 2022
1cdcd7d
Woops
thomasw21 Mar 24, 2022
2937280
WIP
thomasw21 Mar 24, 2022
7fcff06
Woops
thomasw21 Mar 24, 2022
1f2f800
Woops
thomasw21 Mar 24, 2022
f152e48
Woops
thomasw21 Mar 24, 2022
ce02dd1
Test with alibi
thomasw21 Mar 24, 2022
02365d1
Still trying to reproduce
thomasw21 Mar 24, 2022
42d6b4e
Huh
thomasw21 Mar 24, 2022
c20c8ba
Have high LR to see weights actually change
thomasw21 Mar 24, 2022
7f2441e
Launch bf16
thomasw21 Mar 24, 2022
a4172bf
Woops
thomasw21 Mar 24, 2022
5fbe107
Make test to work with both bf16 and fp16 to see who fails
thomasw21 Mar 24, 2022
a0c0913
Woops
thomasw21 Mar 24, 2022
6b19339
Remove assert
thomasw21 Mar 24, 2022
a5e3295
Try to figure out how the divergence happens
thomasw21 Mar 24, 2022
7145f6d
I think bias starts to diverge first
thomasw21 Mar 24, 2022
311e531
Woops
thomasw21 Mar 24, 2022
39d4b8f
Woops
thomasw21 Mar 24, 2022
8ffb278
Woops
thomasw21 Mar 24, 2022
2389bfd
Add embed layer norm
thomasw21 Mar 24, 2022
0cf35ee
Woops
thomasw21 Mar 24, 2022
f0d6d17
Backward compatibility on torch
thomasw21 Mar 24, 2022
07ccb3d
Better
thomasw21 Mar 24, 2022
3c5e491
Merge remote-tracking branch 'origin/main' into thomas/test_different…
stas00 Mar 26, 2022
a5b5edc
fix
stas00 Mar 26, 2022
c7f2006
Sync lp/hp/optim for layer norms
tjruwase Mar 28, 2022
8f2ea60
fix requirements
stas00 Mar 28, 2022
fc8f813
dynamically discovered layer norm weights / refactor
stas00 Mar 29, 2022
4443e6d
fix regex
stas00 Mar 29, 2022
d2aa4f1
add the test script
stas00 Mar 29, 2022
d64a947
compare on cpu
stas00 Mar 29, 2022
bf7eeb3
add 2 more weights to sync
stas00 Mar 29, 2022
8482595
fp32 accessors
tjruwase Mar 30, 2022
86b726c
improve the doc, and comment out the demo
stas00 Mar 30, 2022
2ac141b
typo
stas00 Mar 30, 2022
d576775
Sync torch_rng_state (#277)
thomasw21 Apr 6, 2022
475f373
Fix device issue when using torch.broadcast
thomasw21 Apr 6, 2022
5b36884
Merge remote-tracking branch 'origin/main' into olruwase/sync_layer_n…
stas00 Jun 29, 2022
85 changes: 85 additions & 0 deletions compare_tp_weights.py
@@ -0,0 +1,85 @@

# usage:
# python compare_tp_weights.py input_layernorm.weight 40 2 .

# input_layernorm.weight
# input_layernorm.bias
# post_attention_layernorm.weight
# post_attention_layernorm.bias

# one liner for just 2 weights comparison
# python -c 'import torch, sys; k=sys.argv[1]; a,b = map(torch.load, sys.argv[2:4]); print("Exact match" if torch.testing.assert_close(a[k], b[k], rtol=0.0, atol=0.0, check_device=False) is None else "Mismatch")' input_layernorm.weight layer_03-model_00-model_states.pt layer_03-model_01-model_states.pt

# 13B
# cd /gpfsdsstore/projects/rech/six/commun/checkpoints/tr1-13B/tr1-13B-with-optim/global_step168000
# python ~/compare_tp_weights.py input_layernorm.weight 40 2 .

# 104B
# cd /gpfsssd/scratch/rech/six/commun/checkpoints/tr8b-104B/checkpoints/emb-norm/global_step16800
#
# python ~/compare_tp_weights.py input_layernorm.weight 64 4 . > ~/104B.input_layernorm.weight.txt
# python ~/compare_tp_weights.py post_attention_layernorm.weight 64 4 . > ~/104B.post_attention_layernorm.weight.txt
# python ~/compare_tp_weights.py input_layernorm.bias 64 4 . > ~/104B.input_layernorm.bias.txt
# python ~/compare_tp_weights.py post_attention_layernorm.bias 64 4 . > ~/104B.post_attention_layernorm.bias.txt

# other 104B checkpoints:

# cd /gpfsssd/scratch/rech/six/commun/checkpoints/tr8b-104B/to-back-up/tr8b-104B/checkpoints/cl-exp-02/global_step10500
# mismatched 68
#
# cd /gpfsssd/scratch/rech/six/commun/checkpoints/tr8-104B-wide/experiment11/global_step15660
# mismatched
#
# cd /gpfsssd/scratch/rech/six/commun/checkpoints/tr8-104B-wide/experiment06/global_step5100
# python ~/compare_tp_weights.py input_layernorm.weight 32 4
# **all matched**
#
# python ~/compare_tp_weights.py post_attention_layernorm.weight 32 4
# not matched



# # 104B/176B embed-norm check
# python -c 'import torch, sys; k=sys.argv[1]; a,b = map(torch.load, sys.argv[2:4]); print("Exact match" if torch.testing.assert_close(a[k], b[k], rtol=0.0, atol=0.0, check_device=False) is None else "Mismatch")' word_embeddings.norm.weight layer_01-model_00-model_states.pt layer_01-model_01-model_states.pt
# python -c 'import torch, sys; k=sys.argv[1]; a,b = map(torch.load, sys.argv[2:4]); print("Exact match" if torch.testing.assert_close(a[k], b[k], rtol=0.0, atol=0.0, check_device=False) is None else "Mismatch")' word_embeddings.norm.weight layer_01-model_01-model_states.pt layer_01-model_02-model_states.pt
# python -c 'import torch, sys; k=sys.argv[1]; a,b = map(torch.load, sys.argv[2:4]); print("Exact match" if torch.testing.assert_close(a[k], b[k], rtol=0.0, atol=0.0, check_device=False) is None else "Mismatch")' word_embeddings.norm.weight layer_01-model_02-model_states.pt layer_01-model_03-model_states.pt

# same on cpu
# python -c 'import torch, sys; k=sys.argv[1]; a=torch.load(sys.argv[2], map_location=torch.device("cpu"));b=torch.load(sys.argv[3], map_location=torch.device("cpu")); print("Exact match" if torch.testing.assert_close(a[k], b[k], rtol=0.0, atol=0.0, check_device=False) is None else "Mismatch")' word_embeddings.norm.weight layer_01-model_00-model_states.pt layer_01-model_01-model_states.pt

# # 176B
# cd /gpfsssd/scratch/rech/six/commun/checkpoints/tr11-176B-ml/checkpoints/main/global_step16400
# python ~/compare_tp_weights.py input_layernorm.weight 70 4 . > ~/176B.input_layernorm.weight.txt
# python ~/compare_tp_weights.py post_attention_layernorm.weight 70 4 . > ~/176B.post_attention_layernorm.weight.txt
# python ~/compare_tp_weights.py input_layernorm.bias 70 4 . > ~/176B.input_layernorm.bias.txt
# python ~/compare_tp_weights.py post_attention_layernorm.bias 70 4 . > ~/176B.post_attention_layernorm.bias.txt


import torch, sys



key, nlayers, tp_size, checkpoint_dir = sys.argv[1:5]

print(f"checking key={key}")
matched, mismatched = 0, 0
for layer_id in range(int(nlayers)):
    for tp in range(int(tp_size)-1):
        f1 = f"{checkpoint_dir}/layer_{3+layer_id:02d}-model_{tp:02d}-model_states.pt"
        f2 = f"{checkpoint_dir}/layer_{3+layer_id:02d}-model_{tp+1:02d}-model_states.pt"
        c1 = torch.load(f1)
        c2 = torch.load(f2)
        # print(f1)
        # print(f2)
        header = f"layer_id={layer_id}: {tp}-{tp+1}"
        try:
            torch.testing.assert_close(c1[key], c2[key], rtol=0.0, atol=0.0, check_device=False)
            print(f"✓ {header}")
            matched += 1
        except AssertionError:
            print(f"✗ {header}")
            mismatched += 1
            #raise

print(f"Matched : {matched}")
print(f"Mismatched: {mismatched}")
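For reference, a minimal standalone sketch that runs the same exact-match check over all four layer-norm keys in one pass; checkpoint_dir, nlayers and tp_size below are placeholders, and the layer_XX-model_YY-model_states.pt naming is assumed to follow the script above:

import torch

checkpoint_dir = "."          # placeholder: path to a global_stepNNNNN dir
nlayers, tp_size = 40, 2      # placeholders: transformer layer count and TP degree
keys = [
    "input_layernorm.weight", "input_layernorm.bias",
    "post_attention_layernorm.weight", "post_attention_layernorm.bias",
]

for key in keys:
    mismatched = 0
    for layer_id in range(nlayers):
        for tp in range(tp_size - 1):
            f1 = f"{checkpoint_dir}/layer_{3+layer_id:02d}-model_{tp:02d}-model_states.pt"
            f2 = f"{checkpoint_dir}/layer_{3+layer_id:02d}-model_{tp+1:02d}-model_states.pt"
            c1 = torch.load(f1, map_location="cpu")
            c2 = torch.load(f2, map_location="cpu")
            # exact, bitwise comparison of the two neighboring TP replicas
            if not torch.equal(c1[key], c2[key]):
                mismatched += 1
    print(f"{key}: {mismatched} mismatched TP pairs")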
42 changes: 42 additions & 0 deletions compare_tp_weights_cpu.py
@@ -0,0 +1,42 @@

# usage:
# python compare_tp_weights_cpu.py input_layernorm.weight 40 2 .


# 13B
# cd /gpfsdsstore/projects/rech/six/commun/checkpoints/tr1-13B/tr1-13B-with-optim/global_step168000
# python ~/compare_tp_weights.py input_layernorm.weight 40 2 .

# 104B
# cd /gpfsssd/scratch/rech/six/commun/checkpoints/tr8b-104B/checkpoints/emb-norm/global_step16800
# python ~/compare_tp_weights.py input_layernorm.weight 64 4 .


import torch, sys



key, nlayers, tp_size, checkpoint_dir = sys.argv[1:5]

print(f"checking key={key}")
matched, mismatched = 0, 0
for layer_id in range(int(nlayers)):
    for tp in range(int(tp_size)-1):
        f1 = f"{checkpoint_dir}/layer_{3+layer_id:02d}-model_{tp:02d}-model_states.pt"
        f2 = f"{checkpoint_dir}/layer_{3+layer_id:02d}-model_{tp+1:02d}-model_states.pt"
        c1 = torch.load(f1, map_location=torch.device('cpu'))
        c2 = torch.load(f2, map_location=torch.device('cpu'))
        # print(f1)
        # print(f2)
        header = f"layer_id={layer_id}: {tp}-{tp+1}"
        try:
            torch.testing.assert_close(c1[key], c2[key], rtol=0.0, atol=0.0, check_device=False)
            print(f"✓ {header}")
            matched += 1
        except AssertionError:
            print(f"✗ {header}")
            mismatched += 1
            #raise

print(f"Matched : {matched}")
print(f"Mismatched: {mismatched}")
3 changes: 2 additions & 1 deletion megatron/data/data_samplers.py
@@ -54,7 +54,8 @@ def build_pretraining_data_loader(dataset, consumed_samples, num_workers=None):
    # Torch dataloader.
    return torch.utils.data.DataLoader(dataset,
                                       batch_sampler=batch_sampler,
-                                      num_workers=num_workers,
+                                      num_workers=args.num_workers,
+                                      generator=torch.Generator().manual_seed(args.seed),
                                       pin_memory=True)

class MegatronPretrainingSampler:
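The generator added above seeds the DataLoader's internal RNG (for example the base seed handed to worker processes), so data-loading randomness becomes reproducible across restarts. A minimal standalone sketch of the underlying idea, with no Megatron dependencies:

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.arange(10))

def first_epoch_order(seed):
    # a DataLoader given a seeded generator shuffles identically on every run
    g = torch.Generator().manual_seed(seed)
    dl = DataLoader(ds, batch_size=2, shuffle=True, generator=g)
    return [batch[0].tolist() for batch in dl]

assert first_epoch_order(1234) == first_epoch_order(1234)
print(first_epoch_order(1234))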
24 changes: 23 additions & 1 deletion megatron/model/fused_layer_norm.py
@@ -21,6 +21,7 @@

from packaging import version
import torch
from megatron import mpu
from torch import nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F
@@ -37,7 +38,6 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, weight, bias, normalized_shape, eps):

        ctx.normalized_shape = normalized_shape
        ctx.eps = eps
        input_ = input.contiguous()
@@ -96,7 +96,29 @@ def reset_parameters(self):
        init.zeros_(self.bias)


    def forward_old(self, input):
        # weights = [torch.empty_like(self.weight) for tp in range(mpu.get_tensor_model_parallel_world_size())]
        # torch.distributed.all_gather(weights, self.weight, group=mpu.get_tensor_model_parallel_group())
        # biases = [torch.empty_like(self.bias) for tp in range(mpu.get_tensor_model_parallel_world_size())]
        # torch.distributed.all_gather(biases, self.bias, group=mpu.get_tensor_model_parallel_group())
        # if any(torch.any(weight != self.weight) for weight in weights):
        #     if mpu.get_tensor_model_parallel_rank() == 0:
        #         print("Weight sync failed")
        #         print(weights)
        # if any(torch.any(bias != self.bias) for bias in biases):
        #     if mpu.get_tensor_model_parallel_rank() == 0:
        #         print("Bias sync failed")
        #         print(biases)

        return FusedLayerNormAffineFunction.apply(
            input, self.weight, self.bias, self.normalized_shape, self.eps)


    def forward(self, input):

        torch.distributed.all_reduce(self.weight, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
        torch.distributed.all_reduce(self.bias, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())

        if self.use_meg_ds_fused_layer_norm:
            return FusedLayerNormAffineFunction.apply(
                input, self.weight, self.bias, self.normalized_shape, self.eps)
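For context, the ReduceOp.AVG used in forward above is mathematically a SUM all_reduce followed by division by the TP world size, and it requires a reasonably recent torch/NCCL (matching the torch>=1.11 bump in this PR). A minimal sketch of that equivalent fallback, as a hypothetical helper rather than anything in this diff:

import torch.distributed as dist

def tp_average_(tensor, group):
    # in-place average across the ranks of `group`; same result (up to
    # floating-point rounding) as dist.all_reduce(tensor, op=dist.ReduceOp.AVG, group=group)
    dist.all_reduce(tensor, op=dist.ReduceOp.SUM, group=group)
    tensor.div_(dist.get_world_size(group=group))
    return tensor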
2 changes: 2 additions & 0 deletions megatron/mpu/mappings.py
@@ -82,6 +82,7 @@ def symbolic(graph, input_):

    @staticmethod
    def forward(ctx, input_):
        # TODO: we need to assert that the input_ are all the same within a group
        return input_

    @staticmethod
@@ -102,6 +103,7 @@ def forward(ctx, input_):

    @staticmethod
    def backward(ctx, grad_output):
        # TODO: we need to assert that the grad_output are all the same within a group
        return grad_output

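A minimal sketch of the kind of check those two TODOs ask for (a hypothetical debug-only helper, not part of this diff; the extra all_gather makes it too expensive to leave on by default):

import torch
import torch.distributed as dist

def assert_same_across_group(tensor, group):
    # gather a copy of `tensor` from every rank in `group` and verify they all match
    world_size = dist.get_world_size(group=group)
    gathered = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(gathered, tensor.contiguous(), group=group)
    for other in gathered[1:]:
        assert torch.equal(other, gathered[0]), "tensor differs across ranks in the group"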
95 changes: 95 additions & 0 deletions megatron/training.py
@@ -367,6 +367,93 @@ def get_learning_rate_scheduler(optimizer):
    return lr_scheduler


def sync_layer_norm(n, p):

    rank = torch.distributed.get_rank()

    print(f'rank {rank} processing {n}')

    #return

    # # Here is how you can access fp32 version of the bf16 param and fp32 optim states
    # #
    # # Note that there is an all_reduce called on all dp ranks when `get_full_hp_param` is called -
    # # so it's not free
    # #
    # # a. fp32 param
    # fp32_param = p.get_full_hp_param()
    # torch.set_printoptions(sci_mode=False, precision=6)
    # print(f'rank {rank} bf16 = {p}')
    # print(f'rank {rank} fp32 = {fp32_param}')
    # torch.testing.assert_close(p, fp32_param, rtol=4e-3, atol=0, check_dtype=False)

    # # b. fp32 optim states
    # for key in ['exp_avg', 'exp_avg_sq']:
    #     full_optim_state = p.get_full_hp_param(optim_state_key=key)
    #     print(f'rank {rank} full optim state {key} = {full_optim_state}')

    # 1. bf16
    #print(f'rank {rank} before reduce p = {p}')
    torch.distributed.all_reduce(p, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
    #print(f'rank {rank} after reduce p = {p}')


    if p._hp_mapping is not None:
        #print(f'rank {rank} fixing hp for input_layernorm')
        #p._hp_mapping.update_hp()

        # 2. fp32
        hp = p._hp_mapping.hp_fragment
        torch.distributed.all_reduce(hp, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())

        # 3. optim states
        for key in ['exp_avg', 'exp_avg_sq']:
            optim_state_fragment = p._hp_mapping.get_optim_state_fragment(key)
            #print(f'rank {rank} before reduce optim state fragment {key} = {optim_state_fragment}')
            torch.distributed.all_reduce(optim_state_fragment, op=torch.distributed.ReduceOp.AVG, group=mpu.get_tensor_model_parallel_group())
            #print(f'rank {rank} after reduce optim state fragment {key} = {optim_state_fragment}')


def sync_all_layer_norms(model):
    # syncs weight+bias for each of the following layer norms (via averaging across TP ranks)
    # 1. word embedding front word_embeddings.norm
    # 2. transformer block input_layernorm x 70
    # 3. transformer block post_attention_layernorm x 70
    # 4. word embedding head - I think it's just weight + bias w/o a proper name in the last layer file layer_0X-model_0X-model_states.pt, see: https://github.com/bigscience-workshop/Megatron-DeepSpeed/blob/affff3d2927864c6948075700c672971782441f4/megatron/model/gpt_model.py#L267

    import re
    layer_norms_params_end_with = [
        "word_embeddings.norm.weight", "word_embeddings.norm.bias",
        "input_layernorm.weight", "input_layernorm.bias",
        "post_attention_layernorm.weight", "post_attention_layernorm.bias",
        "self_attention.dense.bias", "mlp.dense_4h_to_h.bias",
    ]

    for n,p in model.named_parameters():
        #print(n)
        # XXX: would be much simpler to re-do this logic to traverse children modules and act on isinstance of MixedFusedLayerNorm instead
        # 1. first easy to identify layer norm params as they have a unique prefix each
        for end in layer_norms_params_end_with:
            if n.endswith(end):
                sync_layer_norm(n, p)

        # 2. now the last layer norm that has no prefix
        # hack: (\d\d): MixedFusedLayerNorm() is hanging there w/o any prefix name, so need to match something like:
        # /^6.weight$/ or /^6.bias$/
        if mpu.is_pipeline_last_stage() and re.match(r'^\d+\.(weight|bias)$', n):
            sync_layer_norm(n, p)

def sync_all_torch_random_state():
    torch_rng_state = torch.get_rng_state().cuda()
    # We use the second rank of the TP group as the source of truth and broadcast its state to the other ranks
    torch.distributed.broadcast(
        torch_rng_state,
        src=mpu.get_tensor_model_parallel_src_rank() + 1,
        group=mpu.get_tensor_model_parallel_group()
    )
    torch.set_rng_state(torch_rng_state.cpu())


def setup_model_and_optimizer(model_provider_func):
    """Setup model and optimizer."""
    args = get_args()
@@ -416,9 +503,17 @@ def setup_model_and_optimizer(model_provider_func):
        torch.distributed.barrier()
        timers('load-checkpoint').stop()
        timers.log(['load-checkpoint'])
        print_rank_0(f'module = {model[0]}')

        # turn on to enable layer norm syncing
        if 1:
            sync_all_layer_norms(model[0].module)
            sync_all_torch_random_state()
    else:
        args.iteration = 0

    torch.distributed.barrier()

    # We only support local DDP with multiple micro-batches.
    if len(model) > 1 or mpu.get_pipeline_model_parallel_world_size() > 1:
        assert args.DDP_impl == 'local'
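To make the name matching in sync_all_layer_norms above easier to follow, a small standalone check; the parameter names below are made-up examples in the Megatron naming style, not taken from a real run:

import re

layer_norms_params_end_with = [
    "word_embeddings.norm.weight", "word_embeddings.norm.bias",
    "input_layernorm.weight", "input_layernorm.bias",
    "post_attention_layernorm.weight", "post_attention_layernorm.bias",
    "self_attention.dense.bias", "mlp.dense_4h_to_h.bias",
]

example_names = [
    "tied_modules.embed.word_embeddings.norm.weight",   # caught by the endswith list
    "3.input_layernorm.bias",                           # caught by the endswith list
    "6.weight",                                         # only caught by the last-stage regex
    "3.self_attention.query_key_value.weight",          # not a layer norm -> not synced
]

for n in example_names:
    by_suffix = any(n.endswith(end) for end in layer_norms_params_end_with)
    by_regex = bool(re.match(r'^\d+\.(weight|bias)$', n))
    print(f"{n:50} suffix={by_suffix} last-stage-regex={by_regex}")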
6 changes: 4 additions & 2 deletions requirements.txt
@@ -6,9 +6,11 @@ pybind11
regex
six
tensorboard
-torch>=1.7
+torch>=1.11
transformers
-DeepSpeed @ git+https://github.com/microsoft/DeepSpeed.git
+# for now using this branch for bf16 work
+DeepSpeed @ git+https://github.com/microsoft/DeepSpeed.git@olruwase/bf16-updates
+#DeepSpeed @ git+https://github.com/microsoft/DeepSpeed.git
# versions from HF transformers
black==21.4b0
isort>=5.5.4