Basic MegaBlock Support
Adding basic support for MegaBlocks MoE and dMoE layers.

---------

Co-authored-by: Deepak Narayanan <[email protected]>
tgale96 and deepakn94 authored Feb 22, 2023
1 parent 285068c commit fde1cbb
Showing 6 changed files with 140 additions and 20 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -3,3 +3,4 @@ __pycache__
build
.coverage_*
*.egg-info
*~
12 changes: 12 additions & 0 deletions README.md
@@ -42,6 +42,7 @@ The following table shows both model (MFU) and hardware (HFU) FLOPs utilization
* [Distributed Pretraining](#distributed-pretraining)
* [Activation Checkpointing and Recomputation](#activation-checkpointing-and-recomputation)
* [Distributed Optimizer](#distributed-optimizer)
* [Mixture-of-Experts](#mixture-of-experts)
* [GPT-3 Example](#gpt-3-example)
* [Evaluation and Tasks](#evaluation-and-tasks)
* [GPT Text Generation](#gpt-text-generation)
@@ -346,6 +347,17 @@ To install FlashAttention:
pip install flash-attn
```

## Mixture-of-Experts

Usage: `--moe-num-experts <number_of_experts>`. See the command-line arguments prefixed with `--moe-` for the additional mixture-of-experts (MoE) options. Compatible with GPT models.

MoEs are supported through [MegaBlocks](https://github.com/stanford-futuredata/megablocks), a lightweight library for MoE training. At its core, MegaBlocks provides efficient "dropless-MoE" layers ([paper](https://arxiv.org/abs/2211.15841)) as well as standard MoE layers.

To install MegaBlocks:
```sh
pip install megablocks
```
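
For illustration, a hypothetical launch command might combine the new `--moe-*` flags with the usual GPT pretraining arguments (the values below are placeholders, not recommended settings):

```sh
# Illustrative only: placeholder values, not tuned settings.
# "..." stands for the usual GPT data/model/training arguments.
python pretrain_gpt.py ... \
    --moe-num-experts 64 \
    --moe-top-k 1 \
    --moe-loss-weight 0.1 \
    --moe-capacity-factor 0   # 0 selects the dropless (dMoE) implementation
```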

## GPT-3 Example

In `examples/pretrain_gpt3_175B.sh` we have provided an example of how to configure Megatron to run [GPT-3](https://arxiv.org/abs/2005.14165) with 175 billion parameters on 1024 GPUs. The script is designed for [slurm](https://slurm.schedmd.com/documentation.html) with the [pyxis](https://github.com/NVIDIA/pyxis) plugin but can be easily adapted to any other scheduler. It uses 8-way tensor parallelism and 16-way pipeline parallelism. With the options `global-batch-size 1536` and `rampup-batch-size 16 16 5859375`, training starts with a global batch size of 16 and linearly increases it to 1536 over 5,859,375 samples in increments of 16. The training dataset can be either a single dataset or multiple datasets combined with a set of weights.
44 changes: 35 additions & 9 deletions megatron/arguments.py
@@ -29,6 +29,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
parser = _add_logging_args(parser)
parser = _add_inference_args(parser)
parser = _add_transformer_engine_args(parser)
parser = _add_moe_args(parser)

# Custom arguments.
if extra_args_provider is not None:
@@ -43,7 +44,7 @@ def parse_args(extra_args_provider=None, ignore_unknown_args=False):
# Args from environment
args.rank = int(os.getenv('RANK', '0'))
args.world_size = int(os.getenv("WORLD_SIZE", '1'))

return args

def validate_args(args, defaults={}):
@@ -333,7 +334,6 @@ def validate_args(args, defaults={}):
if args.sequence_parallel:
args.async_tensor_model_parallel_allreduce = False


if os.environ.get('CUDA_DEVICE_MAX_CONNECTIONS') != "1":
if args.sequence_parallel:
raise RuntimeError(
@@ -344,6 +344,10 @@ def validate_args(args, defaults={}):
"Using async gradient all reduce requires setting the environment "
"variable CUDA_DEVICE_MAX_CONNECTIONS to 1")

# Pipeline parallelism not supported with MoE.
if args.moe_num_experts is not None:
assert args.pipeline_model_parallel_size == 1, (
"Pipeline parallelism not yet support for MoEs.")

_print_args(args)
return args
@@ -403,15 +407,15 @@ def _add_inference_args(parser):
help='During inference, if batch-size times '
'sequence-length is smaller than this threshold '
'then we will not use pipelining, otherwise we will.')

group.add_argument('--max-tokens-to-oom',
type=int, default=12000,
help='Maximum number of tokens during inference'
'tokens here is # in prompt + # to generate'
'Allows us to throw an error before OOM crashes server')
return parser


def _add_network_size_args(parser):
group = parser.add_argument_group(title='network size')

@@ -455,8 +459,31 @@ def _add_network_size_args(parser):
group.add_argument('--bert-no-binary-head', action='store_false',
help='Disable BERT binary head.',
dest='bert_binary_head')
group.add_argument('--num-experts', type=int, default=None,
help='Number of Experts in Switch Transformer (None means no Switch)')
return parser


def _add_moe_args(parser):
group = parser.add_argument_group(title='moe')
group.add_argument('--moe-num-experts', type=int, default=None,
help='The number of experts in MoE layers. MoE '
'layers not used if set to None')
group.add_argument('--moe-capacity-factor', type=int, default=0,
help='Capacity factor for MoE layers. If zero, use '
'dropless MoE implementation.')
group.add_argument('--moe-top-k', type=int, default=1,
help='The number of experts each token is routed to '
'in MoE layers.')
group.add_argument('--moe-loss-weight', type=float, default=0.1,
help='The weight for the MoE auxiliary load balancing '
'loss.')
group.add_argument('--moe-lbl-in-fp32', type=bool, default=False,
help='Whether to compute the load balancing loss in '
'fp32.')
group.add_argument('--moe-jitter-eps', type=float, default=None,
help='Coefficient for MoE routing jitter. Jitter is '
'not used if set to None.')
group.add_argument('--moe-use-megatron-switch', type=bool, default=False,
help='Whether to use Megatron SwitchMLP for MoE layers.')
return parser
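
As a quick, hypothetical sanity check of the defaults added above (a standalone sketch, not part of the commit), attaching `_add_moe_args` to a bare parser shows that MoE is disabled unless `--moe-num-experts` is given, and that the default configuration is a dropless, top-1 MoE with a load-balancing loss weight of 0.1:

```python
import argparse

# Hypothetical standalone check; `_add_moe_args` is the function defined above.
parser = _add_moe_args(argparse.ArgumentParser())

default_args = parser.parse_args([])
assert default_args.moe_num_experts is None    # MoE layers disabled by default

moe_args = parser.parse_args(['--moe-num-experts', '8'])
assert moe_args.moe_capacity_factor == 0       # 0 => dropless (dMoE) path
assert moe_args.moe_top_k == 1
assert moe_args.moe_loss_weight == 0.1
```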


@@ -873,7 +900,6 @@ def _add_distributed_args(parser):
'affects the encoder embedding.)')
group.add_argument('--use-distributed-optimizer', action='store_true',
help='Use distributed optimizer.')

return parser


@@ -1078,14 +1104,14 @@ def _add_vision_args(parser):
group.add_argument('--swin-backbone-type', type=str, default='tiny',
choices=['tiny', 'base', 'h3'],
help='pretraining objectives')

# inpainting arguments
group.add_argument('--mask-type', type=str, default='random',
choices=['random', 'row'],
help='mask types')
group.add_argument('--mask-factor', type=float, default=1.0,
help='mask size scaling parameter')

# dino arguments
group.add_argument('--iter-per-epoch', type=int, default=1250,
help='iterations per epoch')
16 changes: 16 additions & 0 deletions megatron/model/megablocks_utils.py
@@ -0,0 +1,16 @@
"""Adapter to expose MegaBlocks package, if available."""
try:
import megablocks
except ImportError:
megablocks = None

def megablocks_is_available():
return megablocks is not None

def assert_megablocks_is_available():
assert megablocks_is_available(), (
'MegaBlocks not available. Please run `pip install megablocks`.')

moe = megablocks.layers.moe if megablocks_is_available() else None
dmoe = megablocks.layers.dmoe if megablocks_is_available() else None
arguments = megablocks.layers.arguments if megablocks_is_available() else None
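
As a rough sketch of how this adapter is meant to be consumed (the real wiring lives in `megatron/model/transformer.py` below; the helper name here is illustrative, not part of the commit), a caller can guard on availability and pick a MegaBlocks layer class from the parsed arguments:

```python
from megatron.model import megablocks_utils

def select_megablocks_layer_cls(args):
    # Mirrors the selection in transformer.py: a positive capacity factor
    # selects the standard (token-dropping) MoE, otherwise dropless dMoE.
    megablocks_utils.assert_megablocks_is_available()
    if args.moe_capacity_factor > 0:
        return megablocks_utils.moe.MoE
    return megablocks_utils.dmoe.dMoE
```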
46 changes: 39 additions & 7 deletions megatron/model/transformer.py
@@ -10,7 +10,7 @@
from .module import MegatronModule
from megatron.core import mpu, tensor_parallel
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType
from megatron.model import LayerNorm
from megatron.model import LayerNorm, megablocks_utils
from megatron.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.model.fused_bias_gelu import bias_gelu_impl
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
@@ -136,9 +136,9 @@ class SwitchMLP(MegatronModule):
def __init__(self, init_method, output_layer_init_method):
super(SwitchMLP, self).__init__()
args = get_args()
self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
self.router = torch.nn.Linear(args.hidden_size, args.moe_num_experts)
self.experts = torch.nn.ModuleList()
for i in range(args.num_experts):
for i in range(args.moe_num_experts):
self.experts.append(ParallelMLP(init_method, output_layer_init_method))

def forward(self, hidden_states):
@@ -177,6 +177,33 @@ def forward(self, hidden_states):

return output_total, output_bias_total

class _MegablocksAdapter(MegatronModule):

def __init__(self, layer_cls, init_method, output_layer_init_method):
super().__init__()
megablocks_utils.assert_megablocks_is_available()
args = megablocks_utils.arguments.from_megatron(get_args())
args.device = torch.cuda.current_device()
args.init_method = init_method
args.output_layer_init_method = output_layer_init_method
self.moe = layer_cls(args)

def forward(self, x):
return self.moe.forward(x)

class MoE(_MegablocksAdapter):

def __init__(self, init_method, output_layer_init_method):
megablocks_utils.assert_megablocks_is_available()
super().__init__(
megablocks_utils.moe.MoE, init_method, output_layer_init_method)

class dMoE(_MegablocksAdapter):

def __init__(self, init_method, output_layer_init_method):
megablocks_utils.assert_megablocks_is_available()
super().__init__(
megablocks_utils.dmoe.dMoE, init_method, output_layer_init_method)

class CoreAttention(MegatronModule):

@@ -673,10 +700,15 @@ def __init__(self, init_method, output_layer_init_method,
sequence_parallel=args.sequence_parallel)

# MLP
if args.num_experts is not None:
self.mlp = SwitchMLP(init_method, output_layer_init_method)
else:
self.mlp = ParallelMLP(init_method, output_layer_init_method)
mlp_cls = ParallelMLP
if args.moe_num_experts is not None:
if args.moe_use_megatron_switch:
mlp_cls = SwitchMLP
elif args.moe_capacity_factor > 0:
mlp_cls = MoE
else:
mlp_cls = dMoE
self.mlp = mlp_cls(init_method, output_layer_init_method)

# Set bias+dropout+add fusion grad_enable execution handler.
TORCH_MAJOR = int(torch.__version__.split('.')[0])
41 changes: 37 additions & 4 deletions pretrain_gpt.py
@@ -4,13 +4,15 @@

import torch
from functools import partial

from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import get_tokenizer
from megatron.core import tensor_parallel
from megatron.data.gpt_dataset import build_train_valid_test_datasets
from megatron.model import GPTModel, ModelType
from megatron.model import GPTModel, ModelType, megablocks_utils
from megatron.model.megablocks_utils import moe
from megatron.training import pretrain
from megatron.utils import get_ltor_masks_and_position_ids
from megatron.utils import average_losses_across_data_parallel_group
@@ -69,6 +71,37 @@ def loss_func(loss_mask, output_tensor):

return loss, {'lm loss': averaged_loss[0]}

def moe_loss_func(loss_mask, output_tensor):
loss, loss_dict = loss_func(loss_mask, output_tensor)
assert loss.numel() == 1

# NOTE: If recompute is enabled we will collect duplicate load
# balancing loss contributions. Prune these before calculating
# the load balancing loss.
args = get_args()
if args.recompute_granularity is not None:
        # Ignore load balancing loss contributions computed during
# the forward pass if recompute is turned on.
load_balancing_loss_data = moe.get_load_balancing_loss()
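        # If recompute re-ran the MoE forward passes, twice as many loss
        # entries were collected; keep only the second (recomputed) half.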
if args.num_layers * 2 == len(load_balancing_loss_data):
load_balancing_loss_data = (
load_balancing_loss_data[args.num_layers:])
moe.clear_load_balancing_loss()
moe.save_load_balancing_loss(load_balancing_loss_data)

# Compute the load balancing loss for all MoE layers.
megablocks_args = megablocks_utils.arguments.from_megatron(args)
lbl = moe.batched_load_balancing_loss(megablocks_args)
moe.clear_load_balancing_loss()

# Average the load balancing loss across data parallel
# replicas and save for logging.
averaged_lbl = average_losses_across_data_parallel_group([lbl])
loss_dict['load balancing loss'] = averaged_lbl[0]

# Compute the total loss, if necessary.
total_loss = loss + lbl if loss is not None else lbl
return total_loss, loss_dict

def forward_step(data_iterator, model):
"""Forward step."""
@@ -84,8 +117,9 @@ def forward_step(data_iterator, model):
output_tensor = model(tokens, position_ids, attention_mask,
labels=labels)

return output_tensor, partial(loss_func, loss_mask)

loss_fn = (
moe_loss_func if args.moe_num_experts is not None else loss_func)
return output_tensor, partial(loss_fn, loss_mask)

def train_valid_test_datasets_provider(train_val_test_num_samples):
"""Build train, valid, and test datasets."""
@@ -110,7 +144,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):


if __name__ == "__main__":

pretrain(train_valid_test_datasets_provider, model_provider,
ModelType.encoder_or_decoder,
forward_step,
