Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding CPU training support to AxoNN #39

Open
wants to merge 28 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e3bfa64
Rebased cpu changes onto latest axonn
Feb 2, 2024
472d355
Args should default to true, removed tensor_list operator for cpus
Feb 5, 2024
9c237b1
Made changes conditional for CPUs/GPUs
Feb 6, 2024
676640d
Gpu runs do not make Slurm calls
Feb 8, 2024
e84c64d
Revamped cpu flag with cpu hard set in comm.py
Feb 23, 2024
ae74e70
Set device before init
Mar 4, 2024
9514d65
Added initial pytest parameters for cpu tests
Mar 13, 2024
0a22f88
removed mpi4py dependency (#63)
S-Mahua Feb 20, 2024
52b0aef
adding parallelize context for opt (#65)
jwendlan Feb 27, 2024
7728f1a
Removing the drop and gathers in depth tensor parallelism for the eas…
siddharth9820 Feb 28, 2024
5087268
change parallelize context to use AutoConfig (#67)
siddharth9820 Feb 28, 2024
5faec5b
Bugfix: Initialize grad_input, grad_weight to None (#68)
adityaranjan Mar 6, 2024
e831180
Removed distributed communication class
Mar 24, 2024
0444e50
Merge branch 'develop' into axonn-cpu
siddharth9820 May 1, 2024
70a5a37
formatting
siddharth9820 May 1, 2024
e185214
fix CI
siddharth9820 May 1, 2024
e9f21ce
format
siddharth9820 May 1, 2024
f3a5dfc
Matched changes for intra_layer_conv tests
May 8, 2024
88176e9
Removed unused env variables and format
May 8, 2024
c63e489
Merge branch 'develop' into axonn-cpu
Avuxon May 8, 2024
f933edc
Fixed CI to return for depth > 1
May 8, 2024
20af63a
Merge branch 'axonn-cpu' of https://github.com/axonn-ai/axonn into ax…
May 8, 2024
8249add
Added depth check to fw pass in fc test
May 8, 2024
1a677ff
Merge develop into this branch
May 24, 2024
f39883f
Added xfail to convolution tests
May 24, 2024
d24fa4c
remove xfail
siddharth9820 Jun 10, 2024
ad779c3
Merge branch 'develop' into axonn-cpu
siddharth9820 Jun 18, 2024
1f9e54b
Merge branch 'develop' into axonn-cpu
siddharth9820 Jun 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions axonn/axonn.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def init(
mixed_precision=False,
fp16_allreduce=True,
cpu_offload=False,
device="cuda",
) -> None:
"""
Initialize AxoNN's 2D parallelism with G_inter-way inter-layer
Expand All @@ -135,8 +136,15 @@ def init(
global comm_handle, is_initialized, computation_dtype, _fp16_all_reduce
global _cpu_offload
comm_handle = communication_handle(
G_inter, G_data, G_intra_r, G_intra_c, G_intra_d, gpus_per_node=gpus_per_node
G_inter,
G_data,
G_intra_r,
G_intra_c,
G_intra_d,
gpus_per_node=gpus_per_node,
device=device,
)
config.device = device
config.G_inter = G_inter
config.G_data = G_data
config.G_intra = G_intra_r * G_intra_c * G_intra_d
Expand All @@ -152,6 +160,14 @@ def init(
comm_handle.intra_layer_column_parallel_rank
)
is_initialized = True
if device == "cuda" and not torch.cuda.is_available():
raise ValueError("CUDA is not available. Please choose a different device.")

if device == "cpu":
assert (
G_intra_d == 1
), "G_intra_d > 1: Intra_d uses reduce-scatters which gloo(cpu) doesn't support"

if mixed_precision:
computation_dtype = torch.float16
else:
Expand Down Expand Up @@ -542,7 +558,7 @@ def _post_fw_recv_requests():
if (requests["fw"] is None) and config.inter_layer_parallel_rank > 0:
tensor = torch.empty(
size=_fill_shape(model.get_input_shape()),
device="cuda",
device=config.device,
dtype=computation_dtype,
)
tensor.requires_grad = True
Expand All @@ -561,7 +577,7 @@ def _post_bw_recv_requests():
):
tensor = torch.empty(
size=_fill_shape(model.get_output_shape()),
device="cuda",
device=config.device,
dtype=computation_dtype,
)
requests["bw"] = [
Expand Down
26 changes: 17 additions & 9 deletions axonn/communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(
G_intra_c=1,
G_intra_d=1,
gpus_per_node=None,
device="cuda",
):
"""Constructor for the communication handle

Expand All @@ -44,6 +45,11 @@ def __init__(
G_intra_c (int): number of GPUs in the column intra-layer parallel dimension
G_intra_d (int): number of GPUs in the depth intra-layer parallel dimension
"""
if device == "cpu":
self.backend = "gloo"
else:
self.backend = "nccl"

if not torch.distributed.is_initialized():
assert MPI4PY, "either install mpi4py and launch via mpirun/srun"
"or initialize torch.distributed outside axonn"
Expand Down Expand Up @@ -71,8 +77,10 @@ def __init__(
self.gpus_per_node = (
gpus_per_node if gpus_per_node is not None else torch.cuda.device_count()
)
self.local_rank = self.world_rank % self.gpus_per_node
torch.cuda.set_device(self.local_rank)

if device == "cuda":
self.local_rank = self.world_rank % self.gpus_per_node
torch.cuda.set_device(self.local_rank)
self.intra_layer_parallel_rank = self.world_rank % G_intra
self.intra_layer_column_parallel_rank = (
self.intra_layer_parallel_rank % G_intra_c
Expand Down Expand Up @@ -115,7 +123,7 @@ def __init__(
master_port = os.getenv("MASTER_PORT", "6000")
init_method += master_ip + ":" + master_port
torch.distributed.init_process_group(
backend="nccl",
backend=self.backend,
world_size=self.world_size,
rank=self.world_rank,
init_method=init_method,
Expand All @@ -130,7 +138,7 @@ def __init__(
for k in range(self.G_data)
]
ith_jth_data_parallel_group = torch.distributed.new_group(
ranks=ranks_in_ith_jth_data_parallel_group, backend="nccl"
ranks=ranks_in_ith_jth_data_parallel_group, backend=self.backend
)
if self.world_rank in ranks_in_ith_jth_data_parallel_group:
self.coll_nccl_comm = ith_jth_data_parallel_group
Expand All @@ -142,7 +150,7 @@ def __init__(
i_ * G_inter * G_intra + j_ * G_intra + k for k in range(G_intra)
]
ith_jth_intra_layer_group = torch.distributed.new_group(
ranks=ranks_in_ith_jth_intra_layer_group, backend="nccl"
ranks=ranks_in_ith_jth_intra_layer_group, backend=self.backend
)
if self.world_rank in ranks_in_ith_jth_intra_layer_group:
self.intra_layer_group = ith_jth_intra_layer_group
Expand All @@ -165,7 +173,7 @@ def __init__(
ranks_in_ith_jth_intra_layer_group[i, j, :]
)
group = torch.distributed.new_group(
ranks=group_members, backend="nccl"
ranks=group_members, backend=self.backend
)
if self.world_rank in group_members:
self.inner_intra_layer_parallel_group = group
Expand All @@ -177,7 +185,7 @@ def __init__(
ranks_in_ith_jth_intra_layer_group[i, :, j]
)
group = torch.distributed.new_group(
ranks=group_members, backend="nccl"
ranks=group_members, backend=self.backend
)
if self.world_rank in group_members:
self.outer_intra_layer_parallel_group = group
Expand All @@ -189,7 +197,7 @@ def __init__(
ranks_in_ith_jth_intra_layer_group[:, i, j]
)
group = torch.distributed.new_group(
ranks=group_members, backend="nccl"
ranks=group_members, backend=self.backend
)
if self.world_rank in group_members:
self.depth_intra_layer_parallel_group = group
Expand All @@ -200,7 +208,7 @@ def __init__(
ranks_in_ith_jth_intra_layer_group[i, :, :].flatten()
)
group = torch.distributed.new_group(
ranks=group_members, backend="nccl"
ranks=group_members, backend=self.backend
)
if self.world_rank in group_members:
self.outer_inner_intra_layer_parallel_group = group
Expand Down
1 change: 1 addition & 0 deletions axonn/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,4 @@
G_data = 0
micro_batch_size = 0
batch_size = 0
device = "cuda"
29 changes: 22 additions & 7 deletions axonn/intra_layer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,13 +98,22 @@ def trigger_async_all_gathers(model):
handle = None
else:
assert weight.ndim == 1
output_shape = weight.shape[0] * world_size
all_gathered_weight = torch.empty(
output_shape, dtype=weight.dtype, device=weight.device
)
handle = dist.all_gather_into_tensor(
all_gathered_weight, weight, group=process_group, async_op=True
)

if torch.distributed.get_backend() == "nccl":
output_shape = weight.shape[0] * world_size
all_gathered_weight = torch.empty(
output_shape, dtype=weight.dtype, device=weight.device
)
handle = dist.all_gather_into_tensor(
all_gathered_weight,
weight,
group=process_group,
async_op=True,
)

elif torch.distributed.get_backend() == "gloo":
raise NotImplementedError

weights_cache[weight] = [all_gathered_weight, handle]
yield

Expand Down Expand Up @@ -173,6 +182,12 @@ def optimize_communication(
"for_overlapping_allgathers=model, ...)"
"if overlap_all_gather is True"
)

if torch.distributed.get_backend() == "gloo":
raise ValueError(
"overlap_all_gather does not work with gloo" "please set it to false"
)

ALL_GATHER_ITERATOR = trigger_async_all_gathers(
model_object_for_overlapping_allgathers
)
Expand Down
4 changes: 3 additions & 1 deletion axonn/intra_layer/communication.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import torch.distributed as dist
import torch
import axonn
from axonn import config


def _all_reduce(input_, process_group=None, overlap_comm=False):
Expand Down Expand Up @@ -48,7 +49,8 @@ def _gather(input_, dim, process_group=None, cache=False):
tensor_list = [
torch.empty_like(input_) for _ in range(dist.get_world_size(process_group))
]
tensor_list[rank] = input_
if config.device == "cuda":
tensor_list[rank] = input_
dist.all_gather(tensor_list, input_, group=process_group)

# Note: torch.cat already creates a contiguous tensor.
Expand Down
5 changes: 3 additions & 2 deletions axonn/intra_layer/fully_connected.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from axonn import axonn as ax
import axonn
from axonn import config
from .communication import (
Drop,
Gather,
Expand Down Expand Up @@ -38,9 +39,9 @@ def initialize_params(
in_features_group,
depth_group,
init_method,
init_device="cuda",
init_device=config.device,
):
params = torch.empty((out_features, in_features), device=init_device)
params = torch.empty((out_features, in_features), device=config.device)
init_method(params)
params = extract_local_params_from_full_params(
params, out_features_group, in_features_group, depth_group
Expand Down
47 changes: 37 additions & 10 deletions axonn/tests/test_intra_layer_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def norm_allclose(X, Y):
)
@pytest.mark.parametrize("easy_tp", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.skip(reason="torch.all_close does not work with conv")
def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
# These tests are in fp-32
Expand All @@ -54,15 +55,21 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
# This is required because TF32 cores only look at the first 10 bits of mantissa
torch.backends.cudnn.allow_tf32 = False

if device == "cpu" and G_intra_d > 1:
return # Gloo doesn't support reduce scatter

ax.init(
G_data=1,
G_inter=1,
G_intra_r=G_intra_r,
G_intra_c=G_intra_c,
G_intra_d=G_intra_d,
mixed_precision=False,
fp16_allreduce=False,
device=device,
)

X = torch.randn(B, C, H, W).cuda() * 0.01
X = torch.randn(B, C, H, W).to(device) * 0.01

inner_group = ax.comm_handle.inner_intra_layer_parallel_group
outer_group = ax.comm_handle.outer_intra_layer_parallel_group
Expand All @@ -78,7 +85,9 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
else:
X_local = X

layer = Conv2d(in_channels=C, out_channels=2 * C, kernel_size=5, bias=bias).cuda()
layer = Conv2d(in_channels=C, out_channels=2 * C, kernel_size=5, bias=bias).to(
device
)

with torch.no_grad():
# parallel FW pass
Expand All @@ -95,7 +104,7 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
out_channels=C * 2,
kernel_size=5,
bias=bias,
).cuda()
).to(device)
weight_sequential = _gather(
_gather(
_gather(layer.weight, 0, depth_group).reshape(
Expand Down Expand Up @@ -126,12 +135,25 @@ def test_fw_pass(G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias):
)
@pytest.mark.parametrize("easy_tp", [True, False])
@pytest.mark.parametrize("bias", [True, False])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
@pytest.mark.parametrize("comm_opt_level", [0, 3])
@pytest.mark.skip(reason="torch.all_close does not work with conv")
def test_bw_pass(
G_intra_r, G_intra_c, G_intra_d, B, H, W, C, easy_tp, bias, comm_opt_level
G_intra_r,
G_intra_c,
G_intra_d,
B,
H,
W,
C,
easy_tp,
bias,
comm_opt_level,
device,
):
# These tests are in fp-32
if device == "cpu" and G_intra_d > 1:
return # Gloo doesn't support reduce scatter
# Need to remove all non-determinism from convolutions
torch.manual_seed(42)
torch.cuda.manual_seed(42)
Expand All @@ -147,16 +169,21 @@ def test_bw_pass(
G_intra_r=G_intra_r,
G_intra_c=G_intra_c,
G_intra_d=G_intra_d,
mixed_precision=False,
fp16_allreduce=False,
device=device,
)
X = torch.randn(B, C, H, W).cuda() * 0.01
Y_grad = torch.randn(B, 2 * C, H - 4, W - 4).cuda() * 0.01
X = torch.randn(B, C, H, W).to(device) * 0.01
Y_grad = torch.randn(B, 2 * C, H - 4, W - 4).to(device) * 0.01

inner_group = ax.comm_handle.inner_intra_layer_parallel_group
outer_group = ax.comm_handle.outer_intra_layer_parallel_group
depth_group = ax.comm_handle.depth_intra_layer_parallel_group

# parallel backward pass
layer = Conv2d(in_channels=C, out_channels=2 * C, kernel_size=5, bias=bias).cuda()
layer = Conv2d(in_channels=C, out_channels=2 * C, kernel_size=5, bias=bias).to(
device
)

if not easy_tp:
X_local = (
Expand All @@ -176,9 +203,9 @@ def test_bw_pass(
Y_local_grad = Y_grad

with optimize_communication(
overlap_reduce_scatter=comm_opt_level >= 1,
overlap_reduce_scatter=comm_opt_level >= 1 and device != "cpu",
cache_weights=comm_opt_level >= 2,
overlap_all_gather=comm_opt_level == 3,
overlap_all_gather=comm_opt_level == 3 and device != "cpu",
model_object_for_overlapping_allgathers=layer,
):
Y_local = layer(X_local, scatter_input=easy_tp, gather_output=easy_tp)
Expand All @@ -195,7 +222,7 @@ def test_bw_pass(
out_channels=C * 2,
kernel_size=5,
bias=bias,
).cuda()
).to(device)
with torch.no_grad():
weight_sequential = _gather(
_gather(
Expand Down
Loading
Loading