Fully Sharded Data Parallel (FSDP) in PyTorch XLA is a utility for sharding Module parameters across data-parallel workers.
Example usage:
```python
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

model = FSDP(my_module)
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
output = model(x, y)
loss = output.sum()
loss.backward()
optim.step()
```
It is also possible to shard individual layers separately and have an outer wrapper handle any leftover parameters.
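For instance, a minimal sketch of such nested wrapping might look like the following (the two `nn.Linear` layers are purely illustrative); the `reshard_after_forward=True` argument on the inner wrappers corresponds to the ZeRO-3 behavior described in the notes below:

```python
import torch.nn as nn
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

# Wrap individual layers with inner FSDP; the outer wrapper handles any
# leftover (unwrapped) parameters such as those of nn.ReLU here.
inner1 = FSDP(nn.Linear(1024, 1024), reshard_after_forward=True)
inner2 = FSDP(nn.Linear(1024, 10), reshard_after_forward=True)
model = FSDP(nn.Sequential(inner1, nn.ReLU(), inner2))
```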
Notes:
- The `XlaFullyShardedDataParallel` class supports both the ZeRO-2 optimizer (sharding gradients and optimizer states) and the ZeRO-3 optimizer (sharding parameters, gradients, and optimizer states) in https://arxiv.org/abs/1910.02054.
- The ZeRO-3 optimizer should be implemented via nested FSDP with `reshard_after_forward=True`. See `test/test_train_mp_mnist_fsdp_with_ckpt.py` and `test/test_train_mp_imagenet_fsdp.py` for an example.
- For large models that cannot fit into a single TPU memory or the host CPU memory, one should interleave submodule construction with inner FSDP wrapping. See `FSDPViTModel` for an example.
- A simple wrapper `checkpoint_module` is provided (based on `torch_xla.utils.checkpoint.checkpoint` from pytorch#3524) to perform gradient checkpointing over a given `nn.Module` instance. See `test/test_train_mp_mnist_fsdp_with_ckpt.py` and `test/test_train_mp_imagenet_fsdp.py` for an example.
- Auto-wrapping submodules: instead of manually nested FSDP wrapping, one can also specify an `auto_wrap_policy` argument to automatically wrap the submodules with inner FSDP. `size_based_auto_wrap_policy` in `torch_xla.distributed.fsdp.wrap` is an example of an `auto_wrap_policy` callable; this policy wraps layers with more than 100M parameters. `transformer_auto_wrap_policy` in `torch_xla.distributed.fsdp.wrap` is an example of an `auto_wrap_policy` callable for transformer-like model architectures.
For example, to automatically wrap all `torch.nn.Conv2d` submodules with inner FSDP, one can use:

```python
from functools import partial
from torch_xla.distributed.fsdp.wrap import transformer_auto_wrap_policy

auto_wrap_policy = partial(transformer_auto_wrap_policy, transformer_layer_cls={torch.nn.Conv2d})
```
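Similarly, a minimal sketch of size-based auto-wrapping (assuming `size_based_auto_wrap_policy` takes a `min_num_params` keyword, as its upstream PyTorch counterpart of the same name does) could be:

```python
from functools import partial
from torch_xla.distributed.fsdp.wrap import size_based_auto_wrap_policy

# Wrap any submodule whose parameter count exceeds 100M; the min_num_params
# keyword is an assumption mirroring the upstream policy of the same name.
auto_wrap_policy = partial(size_based_auto_wrap_policy, min_num_params=int(1e8))
```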
Additionally, one can also specify an `auto_wrapper_callable` argument to use a custom callable wrapper for the submodules (the default wrapper is just the `XlaFullyShardedDataParallel` class itself). For example, one can use the following to apply gradient checkpointing (i.e. activation checkpointing/rematerialization) to each auto-wrapped submodule:

```python
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel, checkpoint_module

auto_wrapper_callable = lambda m, *args, **kwargs: XlaFullyShardedDataParallel(
    checkpoint_module(m), *args, **kwargs)
```
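Both arguments are then passed to the outer wrapper. A brief sketch, with `my_module` standing in for your own model as in the example at the top:

```python
# Pass the auto-wrap policy and the custom wrapper callable to the outer
# FSDP wrapper; my_module is a placeholder for your own nn.Module.
model = XlaFullyShardedDataParallel(
    my_module,
    auto_wrap_policy=auto_wrap_policy,
    auto_wrapper_callable=auto_wrapper_callable)
```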
- When stepping the optimizer, directly call `optimizer.step` and do not call `xm.optimizer_step`. The latter reduces the gradient across ranks, which is not needed for FSDP (where the parameters are already sharded).
- When saving model and optimizer checkpoints during training, each training process needs to save its own checkpoint of the (sharded) model and optimizer state dicts (use `master_only=False` and set different paths for each rank in `xm.save`). When resuming, it needs to load the checkpoint for the corresponding rank.
- Please also save `model.get_shard_metadata()` along with `model.state_dict()` as follows and use `consolidate_sharded_model_checkpoints` to stitch the sharded model checkpoints together into a full model state dict. See `test/test_train_mp_mnist_fsdp_with_ckpt.py` for an example.
```python
ckpt = {
    'model': model.state_dict(),
    'shard_metadata': model.get_shard_metadata(),
    'optimizer': optimizer.state_dict(),
}
ckpt_path = f'/tmp/rank-{xm.get_ordinal()}-of-{xm.xrt_world_size()}.pth'
xm.save(ckpt, ckpt_path, master_only=False)
```
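When resuming, a minimal sketch is to load the rank-local checkpoint and restore both state dicts (the path format below simply mirrors the saving snippet above):

```python
# Load the checkpoint saved by this rank and restore the sharded state dicts.
ckpt_path = f'/tmp/rank-{xm.get_ordinal()}-of-{xm.xrt_world_size()}.pth'
ckpt = torch.load(ckpt_path)
model.load_state_dict(ckpt['model'])
optimizer.load_state_dict(ckpt['optimizer'])
```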
- The checkpoint consolidation script can also be launched from the command line as follows.

```bash
# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /path/to/your_sharded_checkpoint_files \
  --ckpt_suffix "_rank-*-of-*.pth"
```
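Alternatively, a rough sketch of calling the consolidation function from Python; the keyword names below mirror the command-line flags above and should be treated as assumptions to verify against your installed version:

```python
from torch_xla.distributed.fsdp import consolidate_sharded_model_checkpoints

# Stitch the per-rank shards back into a full model state dict; keyword
# names mirror the CLI flags above and are assumptions, not a confirmed API.
consolidate_sharded_model_checkpoints(
    ckpt_prefix='/path/to/your_sharded_checkpoint_files',
    ckpt_suffix='_rank-*-of-*.pth')
```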
The implementation of this class is largely inspired by and mostly follows the structure of `fairscale.nn.FullyShardedDataParallel` in https://fairscale.readthedocs.io/en/stable/api/nn/fsdp.html. One of the biggest differences from `fairscale.nn.FullyShardedDataParallel` is that in XLA we don't have explicit parameter storage, so here we resort to a different approach to free full parameters for ZeRO-3.
Example training scripts:
- MNIST: `test/test_train_mp_mnist_fsdp_with_ckpt.py` (it also tests checkpoint consolidation)
- ImageNet: `test/test_train_mp_imagenet_fsdp.py`
FSDP is available in the PyTorch/XLA 1.12 release and newer nightly builds. Please refer to https://github.com/pytorch/xla#-available-images-and-wheels for the installation guide.
First, clone the PyTorch and PyTorch/XLA repos to get the example training scripts:

```bash
git clone --recursive https://github.com/pytorch/pytorch
cd pytorch/
git clone --recursive https://github.com/pytorch/xla.git
cd ~/
```
The MNIST script gets around 98.9% accuracy after 2 epochs:

```bash
python3 ~/pytorch/xla/test/test_train_mp_mnist_fsdp_with_ckpt.py \
  --batch_size 16 --drop_last --num_epochs 2 \
  --use_nested_fsdp --use_gradient_checkpointing
```
This script automatically tests checkpoint consolidation at the end. You can also manually consolidate the sharded checkpoints via:

```bash
# consolidate the saved checkpoints via command line tool
python3 -m torch_xla.distributed.fsdp.consolidate_sharded_ckpts \
  --ckpt_prefix /tmp/mnist-fsdp/final_ckpt \
  --ckpt_suffix "_rank-*-of-*.pth"
```
The ImageNet script (with ResNet-50) gets around 75.9% accuracy after 100 epochs; download ImageNet-1k to `/datasets/imagenet-1k`:

```bash
python3 ~/pytorch/xla/test/test_train_mp_imagenet_fsdp.py \
  --datadir /datasets/imagenet-1k --drop_last \
  --model resnet50 --test_set_batch_size 64 --eval_interval 10 \
  --lr 0.4 --batch_size 128 --num_warmup_epochs 5 \
  --lr_scheduler_divide_every_n_epochs 30 --lr_scheduler_divisor 10 \
  --num_epochs 100 \
  --use_nested_fsdp
```
You can also add `--use_gradient_checkpointing` (which needs to be used along with `--use_nested_fsdp` or `--auto_wrap_policy`) to apply gradient checkpointing on the residual blocks.
To train large models that cannot fit into a single TPU, one should apply auto-wrap or manually wrap the submodules with inner FSDP when building the entire model to implement the ZeRO-3 algorithm.
Please see https://github.com/ronghanghu/vit_10b_fsdp_example for an example of sharded training of a Vision Transformer (ViT) model using this XLA FSDP PR.
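As a rough illustration of the interleaved-construction pattern mentioned above (the layer class and sizes below are purely made up for the sketch), one can build and wrap each submodule immediately so that the full unsharded model never has to be materialized at once:

```python
import torch.nn as nn
from torch_xla.distributed.fsdp import XlaFullyShardedDataParallel as FSDP

def build_large_model(num_layers=48, d_model=1024, nhead=16):
  # Construct one block at a time and wrap it with inner FSDP right away,
  # so only a shard of each block's parameters stays resident.
  blocks = []
  for _ in range(num_layers):
    block = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead)
    blocks.append(FSDP(block, reshard_after_forward=True))
  # The outer wrapper handles any remaining unwrapped parameters (ZeRO-3).
  return FSDP(nn.Sequential(*blocks))
```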