
Adding CPU training support to AxoNN #39

Open · wants to merge 28 commits into base: develop

Changes shown are from 17 of 28 commits.

Commits
e3bfa64
Rebased cpu changes onto latest axonn
Feb 2, 2024
472d355
Args should default to true, removed tensor_list operator for cpus
Feb 5, 2024
9c237b1
Made changes conditional for CPUs/GPUs
Feb 6, 2024
676640d
Gpu runs do not make Slurm calls
Feb 8, 2024
e84c64d
Revamped cpu flag with cpu hard set in comm.py
Feb 23, 2024
ae74e70
Set device before init
Mar 4, 2024
9514d65
Added initial pytest parameters for cpu tests
Mar 13, 2024
0a22f88
removed mpi4py dependency (#63)
S-Mahua Feb 20, 2024
52b0aef
adding parallelize context for opt (#65)
jwendlan Feb 27, 2024
7728f1a
Removing the drop and gathers in depth tensor parallelism for the eas…
siddharth9820 Feb 28, 2024
5087268
change parallelize context to use AutoConfig (#67)
siddharth9820 Feb 28, 2024
5faec5b
Bugfix: Initialize grad_input, grad_weight to None (#68)
adityaranjan Mar 6, 2024
e831180
Removed distributed communication class
Mar 24, 2024
0444e50
Merge branch 'develop' into axonn-cpu
siddharth9820 May 1, 2024
70a5a37
formatting
siddharth9820 May 1, 2024
e185214
fix CI
siddharth9820 May 1, 2024
e9f21ce
format
siddharth9820 May 1, 2024
f3a5dfc
Matched changes for intra_layer_conv tests
May 8, 2024
88176e9
Removed unused env variables and format
May 8, 2024
c63e489
Merge branch 'develop' into axonn-cpu
Avuxon May 8, 2024
f933edc
Fixed CI to return for depth > 1
May 8, 2024
20af63a
Merge branch 'axonn-cpu' of https://github.com/axonn-ai/axonn into ax…
May 8, 2024
8249add
Added depth check to fw pass in fc test
May 8, 2024
1a677ff
Merge develop into this branch
May 24, 2024
f39883f
Added xfail to convolution tests
May 24, 2024
d24fa4c
remove xfail
siddharth9820 Jun 10, 2024
ad779c3
Merge branch 'develop' into axonn-cpu
siddharth9820 Jun 18, 2024
1f9e54b
Merge branch 'develop' into axonn-cpu
siddharth9820 Jun 26, 2024
8 changes: 8 additions & 0 deletions .github/workflows/nvidia-rtx-3090-tests.yaml
@@ -17,6 +17,10 @@ jobs:
ginter: [ 1, 2 ]
memopt: [ '0', '1' ]

env:
SLURM_NTASKS: 0
SLURM_PROCID: 0

steps:
- uses: actions/checkout@v3
- name: Install AxoNN
@@ -40,6 +44,10 @@
intra-layer:
runs-on: [ nvidia ]

env:
SLURM_NTASKS: 0
SLURM_PROCID: 0

steps:
- uses: actions/checkout@v3
- name: Install AxoNN
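The CI now pins SLURM_NTASKS and SLURM_PROCID so rank and world-size discovery sees deterministic values outside a real Slurm allocation. A minimal sketch of how such variables are typically consumed; the exact handling below is illustrative, not AxoNN's internal logic:

```python
import os

# Hedged sketch: read Slurm-style environment variables with explicit defaults,
# matching the values exported by the CI jobs above.
world_size = int(os.getenv("SLURM_NTASKS", "0"))
rank = int(os.getenv("SLURM_PROCID", "0"))
print(f"rank {rank} of {world_size}")
```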
17 changes: 14 additions & 3 deletions axonn/axonn.py
@@ -111,6 +111,7 @@ def init(
mixed_precision=False,
fp16_allreduce=True,
cpu_offload=False,
device="cuda",
) -> None:
"""
Initialize AxoNN's 2D parallelism with G_inter-way inter-layer
@@ -135,8 +136,15 @@
global comm_handle, is_initialized, computation_dtype, _fp16_all_reduce
global _cpu_offload
comm_handle = communication_handle(
G_inter, G_data, G_intra_r, G_intra_c, G_intra_d, gpus_per_node=gpus_per_node
G_inter,
G_data,
G_intra_r,
G_intra_c,
G_intra_d,
gpus_per_node=gpus_per_node,
device=device,
)
config.device = device
config.G_inter = G_inter
config.G_data = G_data
config.G_intra = G_intra_r * G_intra_c * G_intra_d
@@ -152,6 +160,9 @@
comm_handle.intra_layer_column_parallel_rank
)
is_initialized = True
if device == "cuda" and not torch.cuda.is_available():
raise ValueError("CUDA is not available. Please choose a different device.")

if mixed_precision:
computation_dtype = torch.float16
else:
@@ -537,7 +548,7 @@ def _post_fw_recv_requests():
if (requests["fw"] is None) and config.inter_layer_parallel_rank > 0:
tensor = torch.empty(
size=_fill_shape(model.get_input_shape()),
device="cuda",
device=config.device,
dtype=computation_dtype,
)
tensor.requires_grad = True
@@ -556,7 +567,7 @@ def _post_bw_recv_requests():
):
tensor = torch.empty(
size=_fill_shape(model.get_output_shape()),
device="cuda",
device=config.device,
dtype=computation_dtype,
)
requests["bw"] = [
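Taken together, these changes let callers pick the compute device at initialization. A minimal usage sketch, assuming a single-process run; the parallelism degrees below are purely illustrative:

```python
from axonn import axonn as ax

# Hedged sketch: initialize AxoNN for a CPU run using the new `device` argument.
ax.init(
    G_inter=1,
    G_data=1,
    G_intra_r=1,
    G_intra_c=1,
    G_intra_d=1,
    mixed_precision=False,
    fp16_allreduce=False,
    device="cpu",  # defaults to "cuda" when omitted
)
```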
26 changes: 17 additions & 9 deletions axonn/communication.py
@@ -32,6 +32,7 @@ def __init__(
G_intra_c=1,
G_intra_d=1,
gpus_per_node=None,
device="cuda",
):
"""Constructor for the communication handle

@@ -44,6 +45,11 @@
G_intra_c (int): number of GPUs in the column intra-layer parallel dimension
G_intra_d (int): number of GPUs in the depth intra-layer parallel dimension
"""
if device == "cpu":
self.backend = "gloo"
else:
self.backend = "nccl"

if not torch.distributed.is_initialized():
assert MPI4PY, "either install mpi4py and launch via mpirun/srun"
"or initialize torch.distributed outside axonn"
@@ -71,8 +77,10 @@ def __init__(
self.gpus_per_node = (
gpus_per_node if gpus_per_node is not None else torch.cuda.device_count()
)
self.local_rank = self.world_rank % self.gpus_per_node
torch.cuda.set_device(self.local_rank)

if device == "cuda":
self.local_rank = self.world_rank % self.gpus_per_node
torch.cuda.set_device(self.local_rank)
self.intra_layer_parallel_rank = self.world_rank % G_intra
self.intra_layer_column_parallel_rank = (
self.intra_layer_parallel_rank % G_intra_c
@@ -115,7 +123,7 @@ def __init__(
master_port = os.getenv("MASTER_PORT", "6000")
init_method += master_ip + ":" + master_port
torch.distributed.init_process_group(
backend="nccl",
backend=self.backend,
world_size=self.world_size,
rank=self.world_rank,
init_method=init_method,
@@ -130,7 +138,7 @@ def __init__(
for k in range(self.G_data)
]
ith_jth_data_parallel_group = torch.distributed.new_group(
ranks=ranks_in_ith_jth_data_parallel_group, backend="nccl"
ranks=ranks_in_ith_jth_data_parallel_group, backend=self.backend
)
if self.world_rank in ranks_in_ith_jth_data_parallel_group:
self.coll_nccl_comm = ith_jth_data_parallel_group
@@ -142,7 +150,7 @@ def __init__(
i_ * G_inter * G_intra + j_ * G_intra + k for k in range(G_intra)
]
ith_jth_intra_layer_group = torch.distributed.new_group(
ranks=ranks_in_ith_jth_intra_layer_group, backend="nccl"
ranks=ranks_in_ith_jth_intra_layer_group, backend=self.backend
)
if self.world_rank in ranks_in_ith_jth_intra_layer_group:
self.intra_layer_group = ith_jth_intra_layer_group
@@ -165,7 +173,7 @@ def __init__(
ranks_in_ith_jth_intra_layer_group[i, j, :]
)
group = torch.distributed.new_group(
ranks=group_members, backend="nccl"
ranks=group_members, backend=self.backend
)
if self.world_rank in group_members:
self.inner_intra_layer_parallel_group = group
@@ -177,7 +185,7 @@ def __init__(
ranks_in_ith_jth_intra_layer_group[i, :, j]
)
group = torch.distributed.new_group(
ranks=group_members, backend="nccl"
ranks=group_members, backend=self.backend
)
if self.world_rank in group_members:
self.outer_intra_layer_parallel_group = group
@@ -189,7 +197,7 @@ def __init__(
ranks_in_ith_jth_intra_layer_group[:, i, j]
)
group = torch.distributed.new_group(
ranks=group_members, backend="nccl"
ranks=group_members, backend=self.backend
)
if self.world_rank in group_members:
self.depth_intra_layer_parallel_group = group
@@ -200,7 +208,7 @@ def __init__(
ranks_in_ith_jth_intra_layer_group[i, :, :].flatten()
)
group = torch.distributed.new_group(
ranks=group_members, backend="nccl"
ranks=group_members, backend=self.backend
)
if self.world_rank in group_members:
self.outer_inner_intra_layer_parallel_group = group
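The communication handle now derives the torch.distributed backend from the requested device: gloo for CPU runs, NCCL otherwise. A condensed sketch of that selection, assuming a TCP rendezvous via MASTER_ADDR/MASTER_PORT as in the diff; the helper name is illustrative:

```python
import os
import torch.distributed as dist

# Hedged sketch: pick the backend from the device and initialize the default
# process group the same way the communication handle does.
def init_distributed(device: str, world_size: int, rank: int) -> None:
    backend = "gloo" if device == "cpu" else "nccl"
    master_ip = os.getenv("MASTER_ADDR", "localhost")
    master_port = os.getenv("MASTER_PORT", "6000")
    dist.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        init_method=f"tcp://{master_ip}:{master_port}",
    )
```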
1 change: 1 addition & 0 deletions axonn/config.py
@@ -8,3 +8,4 @@
G_data = 0
micro_batch_size = 0
batch_size = 0
device = "cuda"
29 changes: 22 additions & 7 deletions axonn/intra_layer/__init__.py
@@ -98,13 +98,22 @@ def trigger_async_all_gathers(model):
handle = None
else:
assert weight.ndim == 1
output_shape = weight.shape[0] * world_size
all_gathered_weight = torch.empty(
output_shape, dtype=weight.dtype, device=weight.device
)
handle = dist.all_gather_into_tensor(
all_gathered_weight, weight, group=process_group, async_op=True
)

if torch.distributed.get_backend() == "nccl":
output_shape = weight.shape[0] * world_size
all_gathered_weight = torch.empty(
output_shape, dtype=weight.dtype, device=weight.device
)
handle = dist.all_gather_into_tensor(
all_gathered_weight,
weight,
group=process_group,
async_op=True,
)

elif torch.distributed.get_backend() == "gloo":
raise NotImplementedError

weights_cache[weight] = [all_gathered_weight, handle]
yield

@@ -173,6 +182,12 @@ def optimize_communication(
"for_overlapping_allgathers=model, ...)"
"if overlap_all_gather is True"
)

if torch.distributed.get_backend() == "gloo":
raise ValueError(
"overlap_all_gather does not work with gloo" "please set it to false"
)

ALL_GATHER_ITERATOR = trigger_async_all_gathers(
model_object_for_overlapping_allgathers
)
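Because the cached all-gather path is NCCL-only, CPU callers have to keep the overlap knobs off. A small guard sketch; the flag names mirror optimize_communication in this file, while the helper itself is illustrative:

```python
import torch.distributed as dist

# Hedged sketch: enable communication overlap only when the backend supports it.
def overlap_flags() -> dict:
    on_nccl = dist.is_initialized() and dist.get_backend() == "nccl"
    return {
        "overlap_all_reduce": on_nccl,
        "overlap_reduce_scatter": on_nccl,
        "cache_weights": on_nccl,
        "overlap_all_gather": on_nccl,
    }
```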
4 changes: 3 additions & 1 deletion axonn/intra_layer/communication.py
@@ -1,6 +1,7 @@
import torch.distributed as dist
import torch
import axonn
from axonn import config


def _all_reduce(input_, process_group=None, overlap_comm=False):
@@ -48,7 +49,8 @@ def _gather(input_, dim, process_group=None, cache=False):
tensor_list = [
torch.empty_like(input_) for _ in range(dist.get_world_size(process_group))
]
tensor_list[rank] = input_
if config.device == "cuda":
tensor_list[rank] = input_
dist.all_gather(tensor_list, input_, group=process_group)

# Note: torch.cat already creates a contiguous tensor.
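The change above skips pre-seeding tensor_list[rank] on CPU, since gloo's all_gather fills every slot itself. A standalone sketch of that pattern; the function name is illustrative:

```python
import torch
import torch.distributed as dist

# Hedged sketch of the gather pattern used above: pre-place the local shard
# only on CUDA/NCCL, and let gloo populate the list on CPU.
def gather_along_last_dim(input_, device, process_group=None):
    world_size = dist.get_world_size(process_group)
    rank = dist.get_rank(process_group)
    tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
    if device == "cuda":
        tensor_list[rank] = input_
    dist.all_gather(tensor_list, input_, group=process_group)
    return torch.cat(tensor_list, dim=-1)
```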
5 changes: 3 additions & 2 deletions axonn/intra_layer/fully_connected.py
@@ -6,6 +6,7 @@

from axonn import axonn as ax
import axonn
from axonn import config
from .communication import (
Drop,
Gather,
@@ -38,9 +39,9 @@ def initialize_params(
in_features_group,
depth_group,
init_method,
init_device="cuda",
init_device=config.device,
):
params = torch.empty((out_features, in_features), device=init_device)
params = torch.empty((out_features, in_features), device=config.device)
init_method(params)
params = extract_local_params_from_full_params(
params, out_features_group, in_features_group, depth_group
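Parameter initialization now follows the globally configured device instead of a hard-coded "cuda". A minimal sketch of the idea; the sharding step is omitted and the default init_method is an assumption:

```python
import torch
from axonn import config

# Hedged sketch: allocate and initialize the full weight on config.device
# before the local shard is extracted, as the diff above does.
def make_full_params(out_features, in_features,
                     init_method=torch.nn.init.xavier_uniform_):
    params = torch.empty((out_features, in_features), device=config.device)
    init_method(params)
    return params
```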
56 changes: 32 additions & 24 deletions axonn/tests/test_intra_layer_fc.py
@@ -18,18 +18,28 @@
)
@pytest.mark.parametrize("easy_tp", [False, True])
@pytest.mark.parametrize("bias", [False, True])
def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias):
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias, device):
# These tests are in fp-32
torch.manual_seed(42)

# CPU runs currently do not work with mixed_precision or fp16_allreduce
# if set_device == "cpu":
# bool set_mixed_precision = False
# bool set_fp16_allreduce = False

ax.init(
G_data=1,
G_inter=1,
G_intra_r=G_intra_r,
G_intra_c=G_intra_c,
G_intra_d=G_intra_d,
mixed_precision=False,
fp16_allreduce=False,
device=device,
)

X = torch.randn(B, H).cuda() * 0.01
X = torch.randn(B, H).to(device) * 0.01

inner_group = ax.comm_handle.inner_intra_layer_parallel_group
outer_group = ax.comm_handle.outer_intra_layer_parallel_group
@@ -44,8 +54,10 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias):
) # divide columns of X along the inner tensor group
# manually divide input

layer = Linear(in_features=H, out_features=H, bias=bias).cuda()
layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=bias).cuda()
layer = Linear(in_features=H, out_features=H, bias=bias).to(device)
layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=bias).to(
device
)

# test if load state dict works with a sequential checkpoint
layer.load_state_dict(layer_sequential.state_dict())
@@ -72,6 +84,7 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, easy_tp, bias):
@pytest.mark.parametrize("easy_tp", [False, True])
@pytest.mark.parametrize("clip_grad_norm", [-1, 1e-3])
@pytest.mark.parametrize("bias", [False, True])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_bw_pass(
G_intra_r,
G_intra_c,
@@ -82,18 +95,25 @@ def test_bw_pass(
easy_tp,
clip_grad_norm,
bias,
device,
):
# These tests are in fp-32
if device == "cpu" and G_intra_d > 1:
return  # Gloo doesn't support reduce-scatter

torch.manual_seed(42)
ax.init(
G_data=1,
G_inter=1,
G_intra_r=G_intra_r,
G_intra_c=G_intra_c,
G_intra_d=G_intra_d,
mixed_precision=False,
fp16_allreduce=False,
device=device,
)
X = torch.randn(B, H).cuda() * 0.01
Y_grad = torch.randn(B, H).cuda() * 0.01
X = torch.randn(B, H).to(device) * 0.01
Y_grad = torch.randn(B, H).to(device) * 0.01

inner_group = ax.comm_handle.inner_intra_layer_parallel_group
outer_group = ax.comm_handle.outer_intra_layer_parallel_group
@@ -104,8 +124,10 @@ def test_bw_pass(
in_features=H,
out_features=H,
bias=bias,
).cuda()
layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=bias).cuda()
).to(device)
layer_sequential = torch.nn.Linear(in_features=H, out_features=H, bias=bias).to(
device
)

# test if load state dict works with a sequential checkpoint
layer.load_state_dict(layer_sequential.state_dict())
Expand All @@ -128,9 +150,9 @@ def test_bw_pass(

with optimize_communication(
overlap_all_reduce=comm_opt_level >= 1,
overlap_reduce_scatter=comm_opt_level >= 2,
overlap_reduce_scatter=comm_opt_level >= 2 and device != "cpu",
cache_weights=comm_opt_level >= 3,
overlap_all_gather=comm_opt_level == 4,
overlap_all_gather=comm_opt_level == 4 and device != "cpu",
model_object_for_overlapping_allgathers=layer,
):
Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
@@ -175,17 +197,3 @@ def test_bw_pass(
assert torch.allclose(
bias_grad_parallel, layer_sequential.bias.grad
), "BW Pass - gradients of bias do not match"


if __name__ == "__main__":
test_bw_pass(
G_intra_r=1,
G_intra_c=1,
G_intra_d=2,
B=2,
H=256,
comm_opt_level=0,
easy_tp=False,
clip_grad_norm=-1,
bias=True,
)
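With the new device parametrization, the same tests can run on machines without GPUs. A hedged sketch of how the "cuda" variants might be skipped locally; the test body below is illustrative and not part of this PR:

```python
import pytest
import torch

# Hedged sketch: skip the "cuda" parametrization on GPU-less machines.
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_tensor_on_device(device):
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA not available")
    x = torch.randn(4, 4, device=device)
    assert x.device.type == device
```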