
Commit

Merge branch 'develop' into axonn-cpu
Avuxon authored Jan 29, 2024
2 parents 6e38585 + a9d38c2 commit 7f64a37
Showing 4 changed files with 137 additions and 36 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/nvidia-rtx-3090-tests.yaml
@@ -31,7 +31,7 @@ jobs:
export G_data=$(( 2 / G_inter ))
export memopt=${{ matrix.memopt }}
echo "training with G_inter = ${G_inter}, G_data = $(( 2 / G_inter )) ${{ matrix.memopt }}"
mpirun -n 2 pytest --with-mpi ./axonn/tests/test_vit.py
mpirun -mca orte_allowed_exit_without_sync 1 -n 2 pytest --with-mpi ./axonn/tests/test_vit.py
- name: Uninstall AxoNN
run: |
pip uninstall --yes axonn
@@ -47,10 +47,10 @@ jobs:
pip install -r requirements.txt
- name: Run intra-layer FC unit tests
run: |
mpirun -n 2 pytest --with-mpi ./axonn/tests/test_intra_layer_fc.py
mpirun -mca orte_allowed_exit_without_sync 1 -n 2 pytest --with-mpi ./axonn/tests/test_intra_layer_fc.py
- name: Run intra-layer Conv unit tests
run: |
mpirun -n 2 pytest --with-mpi ./axonn/tests/test_intra_layer_conv.py
mpirun -mca orte_allowed_exit_without_sync 1 -n 2 pytest --with-mpi ./axonn/tests/test_intra_layer_conv.py
- name: Uninstall AxoNN
run: |
pip uninstall --yes axonn
14 changes: 13 additions & 1 deletion axonn/axonn.py
@@ -9,12 +9,21 @@
from .communication import communication_handle
from .optim import CPUAdam
import torch
from mpi4py import MPI
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from enum import Enum
import numpy as np
import types

try:
# from mpi4py import MPI
import mpi4py

MPI4PY = True
mpi4py.rc.initialize = False # do not initialize MPI automatically
from mpi4py import MPI
except ImportError:
MPI4PY = False

# True when init has been called
is_initialized = False
# Communication handle for point-to-point (MPI) and collective (NCCL) communication
@@ -577,6 +586,7 @@ def _recv(post_fw_recv=True, post_bw_recv=True, eval_mode=False) -> int:
Returns:
tag(int): the tag of the received message which is the microbatch number
"""
assert MPI4PY, "attempting to use inter-layer parallelism without mpi4py installed"
status = MPI.Status()
if (requests["bw"] is None) and (requests["fw"] is not None):
requests["fw"][1].Wait(status)
@@ -655,6 +665,8 @@ def _backward_pass(output_gradients, microbatch_no):


def _sync_scale(local_overflow):
assert MPI4PY, "attempting to use inter-layer parallelism without mpi4py installed"

global loss_scale, no_overflow_iters, max_scale
assert computation_dtype == torch.float16
overflow_np = np.array(int(local_overflow), "i")
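
Aside: the guarded import added above makes mpi4py optional and defers MPI initialization until point-to-point (inter-layer) communication is actually requested. A minimal standalone sketch of the same pattern, assuming mpi4py 3.x; the helper name get_world_comm is invented for illustration and is not part of AxoNN:

try:
    import mpi4py

    mpi4py.rc.initialize = False  # importing mpi4py must not call MPI_Init
    from mpi4py import MPI

    MPI4PY = True
except ImportError:
    MPI4PY = False


def get_world_comm():
    # Initialize MPI lazily, mirroring the MPI.Is_initialized() checks in
    # communication.py, so torch.distributed-only runs never touch MPI.
    assert MPI4PY, "inter-layer parallelism requires mpi4py"
    if not MPI.Is_initialized():
        MPI.Init()
    return MPI.COMM_WORLD
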
21 changes: 15 additions & 6 deletions axonn/communication.py
@@ -4,10 +4,13 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import sys
try:
from mpi4py import MPI
# from mpi4py import MPI
import mpi4py

MPI4PY = True
mpi4py.rc.initialize = False # do not initialize MPI automatically
from mpi4py import MPI
except ImportError:
MPI4PY = False
import torch
@@ -43,6 +46,8 @@ def __init__(
if not torch.distributed.is_initialized():
assert MPI4PY, ("either install mpi4py and launch via mpirun/srun "
"or initialize torch.distributed outside axonn")
if not MPI.Is_initialized():
MPI.Init()
self.world_rank = MPI.COMM_WORLD.Get_rank()
self.world_size = MPI.COMM_WORLD.Get_size()
else:
@@ -86,6 +91,8 @@ def __init__(
if G_inter > 1:
# this needs to be checked
if MPI4PY:
if not MPI.Is_initialized():
MPI.Init()
self.p2p_mpi_comm = MPI.COMM_WORLD.Split(colour)
assert self.p2p_mpi_comm.Get_size() == G_inter
else:
@@ -155,7 +162,7 @@ def __init__(
ranks_in_ith_jth_intra_layer_group[i, j, :]
)
group = torch.distributed.new_group(
ranks=group_members, backend="gloo"
ranks=group_members, backend="self.backend"
)
if self.world_rank in group_members:
self.inner_intra_layer_parallel_group = group
@@ -167,7 +174,7 @@ def __init__(
ranks_in_ith_jth_intra_layer_group[i, :, j]
)
group = torch.distributed.new_group(
ranks=group_members, backend="gloo"
ranks=group_members, backend="self.backend"
)
if self.world_rank in group_members:
self.outer_intra_layer_parallel_group = group
@@ -179,7 +186,7 @@ def __init__(
ranks_in_ith_jth_intra_layer_group[:, i, j]
)
group = torch.distributed.new_group(
ranks=group_members, backend="gloo"
ranks=group_members, backend="self.backend"
)
if self.world_rank in group_members:
self.depth_intra_layer_parallel_group = group
@@ -241,7 +248,7 @@ def recv(
self,
tensor: torch.Tensor,
send_rank: int,
tag: int = MPI.ANY_TAG,
tag: int = None,
async_op: bool = True,
):
"""Receive a PyTorch tensor from a particular rank using MPI
@@ -257,6 +264,8 @@ def recv(
mpi4py future object if async is true, else None - this object
can be queried to check for completion of communication
"""
if tag is None:
tag = MPI.ANY_TAG
mpi4py_compatible_array = self._torch_to_mpi(tensor)
if async_op:
mpi_future_object = self.p2p_mpi_comm.Irecv(
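
Aside: changing the default from tag: int = MPI.ANY_TAG to tag: int = None matters because default values are evaluated when the function is defined, i.e. at import time, which would fail whenever mpi4py is not installed. A minimal sketch of the pattern; recv_into is an invented name, not AxoNN's API:

try:
    from mpi4py import MPI  # guarded, as in communication.py
except ImportError:
    MPI = None


def recv_into(comm, buf, send_rank, tag=None):
    # Resolve the default here, at call time, instead of in the signature at
    # import time, so this module can be imported without mpi4py present.
    if tag is None:
        tag = MPI.ANY_TAG
    return comm.Irecv(buf, source=send_rank, tag=tag)
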
132 changes: 106 additions & 26 deletions axonn/intra_layer/fully_connected.py
@@ -6,8 +6,12 @@

from axonn import axonn as ax
import axonn
from .communication import Drop, Gather, ForwardGather_BackwardReduceScatter

from .communication import (
Drop,
Gather,
_gather,
_reduce_scatter,
)

def divide(a, b):
assert a % b == 0
@@ -57,11 +61,20 @@ def forward(
weight,
forward_all_reduce_group,
backward_all_reduce_group,
depth_parallel_group,
local_weight_shape,
cache_weights,
backward_comm_async,
forward_comm_async,
):
ctx.save_for_backward(input_, weight)
original_weight = weight
weight = _gather(
weight, dim=0, process_group=depth_parallel_group, cache=cache_weights
)
weight = weight.reshape(local_weight_shape)
ctx.save_for_backward(input_, weight, original_weight)
ctx.backward_all_reduce_group = backward_all_reduce_group
ctx.depth_parallel_group = depth_parallel_group
ctx.backward_comm_async = backward_comm_async
if not forward_comm_async:
output = input_.matmul(weight.t())
@@ -86,25 +99,59 @@ def forward(
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
input_, weight = ctx.saved_tensors
input_, weight, original_weight = ctx.saved_tensors
handle = None
if ctx.needs_input_grad[0]:
grad_input = grad_output.matmul(weight)
handle = dist.all_reduce(
grad_input,
group=ctx.backward_all_reduce_group,
async_op=ctx.backward_comm_async,
)
if ctx.needs_input_grad[1]:
grad_weight = (
grad_output.reshape(-1, grad_output.shape[-1])
.t()
.mm(input_.view(-1, input_.shape[-1]))
overlap_reduce_scatter = axonn.intra_layer.OVERLAP_REDUCE_SCATTER
if dist.get_world_size(ctx.backward_all_reduce_group) > 1 or (
not overlap_reduce_scatter
):
if ctx.needs_input_grad[0]:
grad_input = grad_output.matmul(weight)
handle = dist.all_reduce(
grad_input,
group=ctx.backward_all_reduce_group,
async_op=ctx.backward_comm_async,
)
if ctx.needs_input_grad[1]:
grad_weight = (
grad_output.reshape(-1, grad_output.shape[-1])
.t()
.mm(input_.view(-1, input_.shape[-1]))
)

grad_weight = grad_weight.reshape(-1)
grad_weight = _reduce_scatter(
grad_weight,
dim=0,
process_group=ctx.depth_parallel_group,
overlap_comm=overlap_reduce_scatter,
)
if handle and ctx.backward_comm_async:
handle.wait()
return grad_input, grad_weight, None, None, None, None

if handle and ctx.backward_comm_async:
handle.wait()
if overlap_reduce_scatter:
axonn.intra_layer.accumulate_later(original_weight, grad_weight)
grad_weight = None # weight gradients are not ready yet
return grad_input, grad_weight, None, None, None, None, None, None, None
else:
if ctx.needs_input_grad[1]:
grad_weight = (
grad_output.reshape(-1, grad_output.shape[-1])
.t()
.mm(input_.view(-1, input_.shape[-1]))
).reshape(-1)
grad_weight = _reduce_scatter(
grad_weight,
dim=0,
process_group=ctx.depth_parallel_group,
overlap_comm=True,
)
axonn.intra_layer.accumulate_later(original_weight, grad_weight)
grad_weight = None # weight gradients are not ready yet

if ctx.needs_input_grad[0]:
grad_input = grad_output.matmul(weight)
return grad_input, grad_weight, None, None, None, None, None, None, None
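
Aside: when the reduce-scatter is overlapped with computation, backward returns grad_weight = None and the sharded gradient is attached to the parameter only after the asynchronous collective completes. A simplified sketch of that deferral, assuming PyTorch 1.13+ for reduce_scatter_tensor; pending_grads and finish_grad_sync are invented names, not AxoNN's accumulate_later API:

import torch
import torch.distributed as dist

pending_grads = []  # (param, gradient shard, async work handle)


def accumulate_weight_grad_later(param, full_grad, group):
    # Launch an asynchronous reduce-scatter of the flattened weight gradient;
    # param holds only the local shard, so its gradient is 1/world of full_grad.
    world = dist.get_world_size(group)
    shard = torch.empty(full_grad.numel() // world, dtype=full_grad.dtype, device=full_grad.device)
    work = dist.reduce_scatter_tensor(shard, full_grad.contiguous().view(-1), group=group, async_op=True)
    pending_grads.append((param, shard, work))


def finish_grad_sync():
    # Called once per step, e.g. right before optimizer.step().
    for param, shard, work in pending_grads:
        work.wait()
        grad = shard.view_as(param)
        param.grad = grad if param.grad is None else param.grad + grad
    pending_grads.clear()
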

class Linear(torch.nn.Module):
def __init__(
@@ -170,6 +217,39 @@ def __init__(
ax.comm_handle.intra_layer_group,
)

if bias:
self.bias = torch.nn.Parameter(
torch.zeros(
self.local_out_features,
)
)
setattr(self.bias, "is_tensor_parallel", True)
setattr(self.bias, "needs_gradient_sync", True)
if not transpose:
setattr(
self.bias,
"process_group_for_norm_reduction",
ax.comm_handle.outer_intra_layer_parallel_group,
)
else:
setattr(
self.bias,
"process_group_for_norm_reduction",
ax.comm_handle.inner_intra_layer_parallel_group,
)
else:
self.bias = None

self.weight = torch.nn.Parameter(initial_params, requires_grad=True)

setattr(self.weight, "is_tensor_parallel", True)
setattr(self.weight, "needs_gradient_sync", False)
setattr(
self.weight,
"process_group_for_norm_reduction",
ax.comm_handle.intra_layer_group,
)

if bias:
self.bias = torch.nn.Parameter(
torch.zeros(
@@ -210,14 +290,8 @@ def forward(
):
# gather weights from depth parallel group
# reduce scatter in the backward pass
weight = ForwardGather_BackwardReduceScatter.apply(
self.weight,
self.depth_group,
0,
axonn.intra_layer.OVERLAP_REDUCE_SCATTER,
cache_weights_in_all_gather,
).reshape(self.local_out_features, self.local_in_features)

weight = self.weight
if not self.transpose:
if scatter_input:
x = Drop.apply(x, self.inner_group)
Expand All @@ -227,6 +301,9 @@ def forward(
weight,
self.inner_group,
self.outer_group,
self.depth_group,
(self.local_out_features, self.local_in_features),
cache_weights_in_all_gather,
axonn.intra_layer.OVERLAP_ALL_REDUCE,
False,
)
Expand All @@ -243,6 +320,9 @@ def forward(
weight,
self.outer_group,
self.inner_group,
self.depth_group,
(self.local_out_features, self.local_in_features),
cache_weights_in_all_gather,
axonn.intra_layer.OVERLAP_ALL_REDUCE,
False,
)
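
Aside: in the new forward path the depth-sharded, flattened weight is all-gathered and reshaped to its local two-dimensional shape inside the autograd function before the matmul. A simplified version of that forward-side gather, without the caching or communication overlap, assuming PyTorch 1.13+ for all_gather_into_tensor; gather_local_weight and sharded_linear_forward are invented helpers:

import torch
import torch.distributed as dist


def gather_local_weight(weight_shard, local_weight_shape, depth_group):
    # Each depth-parallel rank holds a flat slice of the layer's local weight;
    # all-gather the slices, then reshape to (local_out_features, local_in_features).
    world = dist.get_world_size(depth_group)
    full = torch.empty(world * weight_shard.numel(), dtype=weight_shard.dtype, device=weight_shard.device)
    dist.all_gather_into_tensor(full, weight_shard.contiguous(), group=depth_group)
    return full.reshape(local_weight_shape)


def sharded_linear_forward(x, weight_shard, local_weight_shape, depth_group):
    weight = gather_local_weight(weight_shard, local_weight_shape, depth_group)
    return x.matmul(weight.t())
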
