Skip to content

Commit d2afadf

Browse files
JuanPedroGHMHoppemrfh92
authored
Large data count support for MPI Communication (#1765)
* trick to send large data * added tests * Fixes for allreduce * fixed large counts for allreduce, now trying to fix non-contiguous data types * Custom operations for allreduce * bench fixes * perun fix * correct inplace contiguous (sorry fabian) * remove print statements * benchmark fixes and debug output * Incorrect move to acc if not CUDA_AWARE_MPI * added tests for the Allreduce case * tests were too large * Update test_communication.py --------- Co-authored-by: Hoppe <[email protected]> Co-authored-by: Fabian Hoppe <[email protected]>
1 parent 1c970d1 commit d2afadf

File tree

4 files changed

+206
-50
lines changed

4 files changed

+206
-50
lines changed

.perun.ini

+21-2
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,35 @@
1+
[post-processing]
2+
power_overhead = 100
3+
pue = 1.05
4+
emissions_factor = 417.8
5+
price_factor = 0.3251
6+
price_unit = €
7+
8+
[monitor]
9+
sampling_period = 0.1
10+
include_backends =
11+
include_sensors =
12+
exclude_backends =
13+
exclude_sensors = CPU_FREQ_\d
14+
115
[output]
16+
app_name
17+
run_id
218
format = bench
319
data_out = ./bench_data
420

521
[benchmarking]
622
rounds = 10
723
warmup_rounds = 1
8-
metrics=runtime
9-
region_metrics=runtime
24+
metrics = runtime,energy
25+
region_metrics = runtime,power
1026

1127
[benchmarking.units]
1228
joule = k
1329
second =
1430
percent =
1531
watt =
1632
byte = G
33+
34+
[debug]
35+
log_lvl = WARNING

benchmarks/cb/main.py

+4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
ht.use_device(os.environ["HEAT_DEVICE"] if os.environ["HEAT_DEVICE"] else "cpu")
77
ht.random.seed(12345)
88

9+
world_size = ht.MPI_WORLD.size
10+
rank = ht.MPI_WORLD.rank
11+
print(f"{rank}/{world_size}: Working on {ht.get_device()}")
12+
913
from linalg import run_linalg_benchmarks
1014
from cluster import run_cluster_benchmarks
1115
from manipulations import run_manipulation_benchmarks

heat/core/communication.py

+144-48
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from __future__ import annotations
66

77
import numpy as np
8+
import math
9+
import ctypes
810
import os
911
import subprocess
1012
import torch
@@ -123,6 +125,8 @@ class MPICommunication(Communication):
123125
Handle for the mpi4py Communicator
124126
"""
125127

128+
COUNT_LIMIT = torch.iinfo(torch.int32).max
129+
126130
__mpi_type_mappings = {
127131
torch.bool: MPI.BOOL,
128132
torch.uint8: MPI.UNSIGNED_CHAR,
@@ -288,7 +292,33 @@ def mpi_type_and_elements_of(
288292

289293
if is_contiguous:
290294
if counts is None:
291-
return mpi_type, elements
295+
if elements > cls.COUNT_LIMIT:
296+
# Uses vector type to get around the MAX_INT limit on certain MPI implementations
297+
# This is at the moment only applied when sending contiguous data, as the construction of data types to get around non-contiguous data naturally aliviates the problem to a certain extent.
298+
# Thanks to: J. R. Hammond, A. Schäfer and R. Latham, "To INT_MAX... and Beyond! Exploring Large-Count Support in MPI," 2014 Workshop on Exascale MPI at Supercomputing Conference, New Orleans, LA, USA, 2014, pp. 1-8, doi: 10.1109/ExaMPI.2014.5. keywords: {Vectors;Standards;Libraries;Optimization;Context;Memory management;Open area test sites},
299+
300+
new_count = elements // cls.COUNT_LIMIT
301+
left_over = elements % cls.COUNT_LIMIT
302+
303+
if new_count > cls.COUNT_LIMIT:
304+
raise ValueError("Tensor is too large")
305+
vector_type = mpi_type.Create_vector(
306+
new_count, cls.COUNT_LIMIT, cls.COUNT_LIMIT
307+
)
308+
if left_over > 0:
309+
left_over_mpi_type = mpi_type.Create_contiguous(left_over).Commit()
310+
_, old_type_extent = mpi_type.Get_extent()
311+
disp = cls.COUNT_LIMIT * new_count * old_type_extent
312+
struct_type = mpi_type.Create_struct(
313+
[1, 1], [0, disp], [vector_type, left_over_mpi_type]
314+
).Commit()
315+
vector_type.Free()
316+
left_over_mpi_type.Free()
317+
return struct_type, 1
318+
else:
319+
return vector_type, 1
320+
else:
321+
return mpi_type, elements
292322
factor = np.prod(obj.shape[1:], dtype=np.int32)
293323
return (
294324
mpi_type,
@@ -317,7 +347,7 @@ def mpi_type_and_elements_of(
317347
return mpi_type, elements
318348

319349
@classmethod
320-
def as_mpi_memory(cls, obj) -> MPI.memory:
350+
def as_mpi_memory(cls, obj: torch.Tensor) -> MPI.memory:
321351
"""
322352
Converts the passed ``torch.Tensor`` into an MPI compatible memory view.
323353
@@ -327,7 +357,8 @@ def as_mpi_memory(cls, obj) -> MPI.memory:
327357
The tensor to be converted into a MPI memory view.
328358
"""
329359
# TODO: MPI.memory might be depraecated in future versions of mpi4py. The following code might need to be adapted and use MPI.buffer instead.
330-
return MPI.memory.fromaddress(obj.data_ptr(), 0)
360+
nbytes = obj.dtype.itemsize * obj.numel()
361+
return MPI.memory.fromaddress(obj.data_ptr(), nbytes)
331362

332363
@classmethod
333364
def as_buffer(
@@ -782,11 +813,71 @@ def Ibcast(self, buf: Union[DNDarray, torch.Tensor, Any], root: int = 0) -> MPIR
782813

783814
Ibcast.__doc__ = MPI.Comm.Ibcast.__doc__
784815

816+
def __derived_op(
817+
self, tensor: torch.Tensor, datatype: MPI.Datatype, operation: MPI.Op
818+
) -> Callable[[MPI.memory, MPI.memory, MPI.Datatype], None]:
819+
820+
# Based from this conversation on the internet: https://groups.google.com/g/mpi4py/c/UkDT_9pp4V4?pli=1
821+
shape = tensor.shape
822+
dtype = tensor.dtype
823+
stride = tensor.stride()
824+
offset = tensor.storage_offset()
825+
count = tensor.numel()
826+
827+
mpiOp2torch = {
828+
MPI.SUM.handle: torch.add,
829+
MPI.PROD.handle: torch.mul,
830+
MPI.MIN.handle: torch.min,
831+
MPI.MAX.handle: torch.max,
832+
MPI.LAND.handle: torch.logical_and,
833+
MPI.LOR.handle: torch.logical_or,
834+
MPI.LXOR.handle: torch.logical_xor,
835+
MPI.BAND.handle: torch.bitwise_and,
836+
MPI.BOR.handle: torch.bitwise_or,
837+
MPI.BXOR.handle: torch.bitwise_xor,
838+
# MPI.MINLOC.handle: torch.argmin, Not supported, seems to be an invalid inplace operation
839+
# MPI.MAXLOC.handle: torch.argmax
840+
}
841+
mpiDtype2Ctype = {
842+
torch.bool: ctypes.c_bool,
843+
torch.uint8: ctypes.c_uint8,
844+
torch.uint16: ctypes.c_uint16,
845+
torch.uint32: ctypes.c_uint32,
846+
torch.uint64: ctypes.c_uint64,
847+
torch.int8: ctypes.c_int8,
848+
torch.int16: ctypes.c_int16,
849+
torch.int32: ctypes.c_int32,
850+
torch.int64: ctypes.c_int64,
851+
torch.float32: ctypes.c_float,
852+
torch.float64: ctypes.c_double,
853+
torch.complex64: ctypes.c_double,
854+
torch.complex128: ctypes.c_longdouble,
855+
}
856+
ctype_size = mpiDtype2Ctype[dtype]
857+
torch_op = mpiOp2torch[operation.handle]
858+
859+
def op(sendbuf: MPI.memory, recvbuf: MPI.memory, datatype):
860+
send_arr = (ctype_size * (count + offset)).from_address(sendbuf.address)
861+
recv_arr = (ctype_size * (count + offset)).from_address(recvbuf.address)
862+
863+
send_tensor = torch.as_strided(
864+
torch.frombuffer(send_arr, dtype=dtype, count=count, offset=offset), shape, stride
865+
)
866+
recv_tensor = torch.as_strided(
867+
torch.frombuffer(recv_arr, dtype=dtype, count=count, offset=offset), shape, stride
868+
)
869+
torch_op(send_tensor, recv_tensor, out=recv_tensor)
870+
871+
op = MPI.Op.Create(op)
872+
873+
return op
874+
785875
def __reduce_like(
786876
self,
787877
func: Callable,
788878
sendbuf: Union[DNDarray, torch.Tensor, Any],
789879
recvbuf: Union[DNDarray, torch.Tensor, Any],
880+
op: MPI.Op,
790881
*args,
791882
**kwargs,
792883
) -> Tuple[Optional[DNDarray, torch.Tensor]]:
@@ -801,6 +892,8 @@ def __reduce_like(
801892
Buffer address of the send message
802893
recvbuf: Union[DNDarray, torch.Tensor, Any]
803894
Buffer address where to store the result of the reduction
895+
op: MPI.Op
896+
Operation to apply during the reduction.
804897
"""
805898
sbuf = None
806899
rbuf = None
@@ -815,56 +908,59 @@ def __reduce_like(
815908
# harmonize the input and output buffers
816909
# MPI requires send and receive buffers to be of same type and length. If the torch tensors are either not both
817910
# contiguous or differently strided, they have to be made matching (if possible) first.
818-
if isinstance(sendbuf, torch.Tensor):
819-
# convert the send buffer to a pointer, number of elements and type are identical to the receive buffer
820-
dummy = (
821-
sendbuf.contiguous()
822-
) # make a contiguous copy and reassign the storage, old will be collected
823-
# In PyTorch Version >= 2.0.0 we can use untyped_storage() instead of storage
824-
# to keep backward compatibility with earlier PyTorch versions (where no untyped_storage() exists) we use a try/except
825-
# (this applies to all places of Heat where untyped_storage() is used without further comment)
826-
try:
827-
sendbuf.set_(
828-
dummy.untyped_storage(),
829-
dummy.storage_offset(),
830-
size=dummy.shape,
831-
stride=dummy.stride(),
832-
)
833-
except AttributeError:
834-
sendbuf.set_(
835-
dummy.storage(),
836-
dummy.storage_offset(),
837-
size=dummy.shape,
838-
stride=dummy.stride(),
839-
)
840-
sbuf = sendbuf if CUDA_AWARE_MPI else sendbuf.cpu()
841-
sendbuf = self.as_buffer(sbuf)
911+
if sendbuf is not MPI.IN_PLACE:
912+
# Send and recv buffer need the same number of elements.
913+
if sendbuf.numel() != recvbuf.numel():
914+
raise ValueError("Send and recv buffers need the same number of elements.")
915+
916+
# Stride and offset should be the same to create the same datatype and operation. If they differ, they should be made contiguous (at the expense of memory)
917+
if (
918+
sendbuf.stride() != recvbuf.stride()
919+
or sendbuf.storage_offset() != recvbuf.storage_offset()
920+
):
921+
if not sendbuf.is_contiguous():
922+
tmp = sendbuf.contiguous()
923+
try:
924+
sendbuf.set_(
925+
tmp.untyped_storage(),
926+
tmp.storage_offset(),
927+
size=tmp.shape,
928+
stride=tmp.stride(),
929+
)
930+
except AttributeError:
931+
sendbuf.set_(
932+
tmp.storage(), tmp.storage_offset(), size=tmp.shape, stride=tmp.stride()
933+
)
934+
if not recvbuf.is_contiguous():
935+
tmp = recvbuf.contiguous()
936+
try:
937+
recvbuf.set_(
938+
tmp.untyped_storage(),
939+
tmp.storage_offset(),
940+
size=tmp.shape,
941+
stride=tmp.stride(),
942+
)
943+
except AttributeError:
944+
recvbuf.set_(
945+
tmp.storage(), tmp.storage_offset(), size=tmp.shape, stride=tmp.stride()
946+
)
947+
842948
if isinstance(recvbuf, torch.Tensor):
949+
# Datatype and count shall be derived from the recv buffer, and applied to both, as they should match after the last code block
843950
buf = recvbuf
844-
# nothing matches, the buffers have to be made contiguous
845-
dummy = recvbuf.contiguous()
846-
try:
847-
recvbuf.set_(
848-
dummy.untyped_storage(),
849-
dummy.storage_offset(),
850-
size=dummy.shape,
851-
stride=dummy.stride(),
852-
)
853-
except AttributeError:
854-
recvbuf.set_(
855-
dummy.storage(),
856-
dummy.storage_offset(),
857-
size=dummy.shape,
858-
stride=dummy.stride(),
859-
)
860951
rbuf = recvbuf if CUDA_AWARE_MPI else recvbuf.cpu()
861-
if sendbuf is MPI.IN_PLACE:
862-
recvbuf = self.as_buffer(rbuf)
863-
else:
864-
recvbuf = (self.as_mpi_memory(rbuf), sendbuf[1], sendbuf[2])
952+
recvbuf: Tuple[MPI.memory, int, MPI.Datatype] = self.as_buffer(rbuf, is_contiguous=True)
953+
if not recvbuf[2].is_predefined:
954+
# If using a derived datatype, we need to define the reduce operation to be able to handle the it.
955+
derived_op = self.__derived_op(rbuf, recvbuf[2], op)
956+
op = derived_op
957+
958+
if isinstance(sendbuf, torch.Tensor):
959+
sbuf = sendbuf if CUDA_AWARE_MPI else sendbuf.cpu()
960+
sendbuf = (self.as_mpi_memory(sbuf), recvbuf[1], recvbuf[2])
865961

866962
# perform the actual reduction operation
867-
return func(sendbuf, recvbuf, *args, **kwargs), sbuf, rbuf, buf
963+
return func(sendbuf, recvbuf, op, **kwargs), sbuf, rbuf, buf
868964

869965
def Allreduce(
870966
self,

heat/core/tests/test_communication.py

+37
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import numpy as np
22
import torch
33
import heat as ht
4+
import unittest
45

56
from .test_suites.basic_test import TestCase
67

@@ -2492,3 +2493,39 @@ def test_alltoallSorting(self):
24922493
test4.comm.Alltoallv(test4.larray, redistributed4, send_axis=2, recv_axis=2)
24932494
with self.assertRaises(NotImplementedError):
24942495
test4.comm.Alltoallv(test4.larray, redistributed4, send_axis=None)
2496+
2497+
# The following test is only for the bool data type to save memory
2498+
# memory requirement: ~16MB * number of processes
2499+
def test_largecount_workaround_IsendRecv(self):
2500+
shape = (2**15, 2**16)
2501+
data = (
2502+
torch.zeros(shape, dtype=torch.bool)
2503+
if ht.MPI_WORLD.rank % 2 == 0
2504+
else torch.ones(shape, dtype=torch.bool)
2505+
)
2506+
buf = torch.empty(shape, dtype=torch.bool)
2507+
req = ht.MPI_WORLD.Isend(
2508+
data, ht.MPI_WORLD.rank - 1 if ht.MPI_WORLD.rank > 0 else ht.MPI_WORLD.size - 1
2509+
)
2510+
ht.MPI_WORLD.Recv(
2511+
buf, ht.MPI_WORLD.rank + 1 if ht.MPI_WORLD.rank < ht.MPI_WORLD.size - 1 else 0
2512+
)
2513+
req.Wait()
2514+
self.assertTrue(
2515+
buf.all()
2516+
if (ht.MPI_WORLD.rank % 2 == 0 and ht.MPI_WORLD.rank != ht.MPI_WORLD.size - 1)
2517+
else not buf.all()
2518+
)
2519+
2520+
# the following test is only for two processes to save memory
2521+
# memory requirement: ~16MB * number of processes
2522+
@unittest.skipIf(ht.MPI_WORLD.size != 2, "Only for two processes")
2523+
def test_largecount_workaround_Allreduce(self):
2524+
shape = (2**10, 2**11, 2**10)
2525+
data = (
2526+
torch.zeros(shape, dtype=torch.bool)
2527+
if ht.MPI_WORLD.rank % 2 == 0
2528+
else torch.ones(shape, dtype=torch.bool)
2529+
)
2530+
ht.MPI_WORLD.Allreduce(ht.MPI.IN_PLACE, data, op=ht.MPI.SUM)
2531+
self.assertTrue(data.all())

0 commit comments

Comments
 (0)