Skip to content

Commit 2131455

Browse files
jiayisuse authored and facebook-github-bot committed
Add alltoall alltoallv collectives (#258)
Summary: Pull Request resolved: #258 Add alltoall and alltoallv to Gloo Reviewed By: osalpekar Differential Revision: D21873282 fbshipit-source-id: 75d82b2c3699279c777e62d9d79ac7a9202bcdb4
1 parent 2b91c6b commit 2131455

File tree

9 files changed

+560
-0
lines changed

9 files changed

+560
-0
lines changed

gloo/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ list(APPEND GLOO_SRCS
1111
"${CMAKE_CURRENT_SOURCE_DIR}/allgatherv.cc"
1212
"${CMAKE_CURRENT_SOURCE_DIR}/allreduce.cc"
1313
"${CMAKE_CURRENT_SOURCE_DIR}/allreduce_local.cc"
14+
"${CMAKE_CURRENT_SOURCE_DIR}/alltoall.cc"
15+
"${CMAKE_CURRENT_SOURCE_DIR}/alltoallv.cc"
1416
"${CMAKE_CURRENT_SOURCE_DIR}/barrier.cc"
1517
"${CMAKE_CURRENT_SOURCE_DIR}/broadcast.cc"
1618
"${CMAKE_CURRENT_SOURCE_DIR}/context.cc"
@@ -32,6 +34,8 @@ list(APPEND GLOO_HDRS
3234
"${CMAKE_CURRENT_SOURCE_DIR}/allreduce_local.h"
3335
"${CMAKE_CURRENT_SOURCE_DIR}/allreduce_ring.h"
3436
"${CMAKE_CURRENT_SOURCE_DIR}/allreduce_ring_chunked.h"
37+
"${CMAKE_CURRENT_SOURCE_DIR}/alltoall.h"
38+
"${CMAKE_CURRENT_SOURCE_DIR}/alltoallv.h"
3539
"${CMAKE_CURRENT_SOURCE_DIR}/barrier.h"
3640
"${CMAKE_CURRENT_SOURCE_DIR}/barrier_all_to_all.h"
3741
"${CMAKE_CURRENT_SOURCE_DIR}/barrier_all_to_one.h"

gloo/alltoall.cc

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
/**
2+
* Copyright (c) 2018-present, Facebook, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "gloo/alltoall.h"
10+
11+
#include <cstring>
12+
13+
#include "gloo/common/logging.h"
14+
#include "gloo/types.h"
15+
16+
namespace gloo {
17+
18+
// Exchanges equally sized chunks between all ranks of the context.
//
// The input buffer is treated as `context->size` consecutive chunks of equal
// byte length. Chunk r is sent to rank r, and the chunk received from rank r
// is written to chunk r of the output buffer. The chunk addressed to the
// local rank is copied directly from input to output.
//
// Fails an enforce check if the element size is unset, if either buffer is
// missing, if the input size is not divisible by the number of ranks, or if
// the input and output sizes differ.
void alltoall(AlltoallOptions& opts) {
  const auto& context = opts.context;
  transport::UnboundBuffer* in = opts.in.get();
  transport::UnboundBuffer* out = opts.out.get();
  const auto slot = Slot::build(kAlltoallSlotPrefix, opts.tag);

  // Preconditions. Both buffers must split evenly across the ranks.
  GLOO_ENFORCE(opts.elementSize > 0);
  GLOO_ENFORCE(in != nullptr);
  GLOO_ENFORCE(out != nullptr);
  GLOO_ENFORCE(in->size % context->size == 0);
  GLOO_ENFORCE(in->size == out->size);

  const int self = context->rank;
  const int numRanks = context->size;
  const size_t bytesPerRank = in->size / numRanks;

  // The chunk addressed to ourselves never touches the transport.
  const size_t selfOffset = self * bytesPerRank;
  char* const selfDst = static_cast<char*>(out->ptr) + selfOffset;
  const char* const selfSrc = static_cast<char*>(in->ptr) + selfOffset;
  memcpy(selfDst, selfSrc, bytesPerRank);

  // Post every send/recv pair. At step s we send to the rank s positions
  // ahead of us and receive from the rank s positions behind us.
  for (int step = 1; step < numRanks; ++step) {
    const int dstRank = (self + step) % numRanks;
    const int srcRank = (self + numRanks - step) % numRanks;
    in->send(dstRank, slot, dstRank * bytesPerRank, bytesPerRank);
    out->recv(srcRank, slot, srcRank * bytesPerRank, bytesPerRank);
  }

  // One send and one recv were posted per remote rank; wait for each of them.
  for (int pending = numRanks - 1; pending > 0; --pending) {
    in->waitSend(opts.timeout);
    out->waitRecv(opts.timeout);
  }
}
55+
56+
} // namespace gloo

gloo/alltoall.h

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
/**
2+
* Copyright (c) 2018-present, Facebook, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include "gloo/common/logging.h"
12+
#include "gloo/context.h"
13+
#include "gloo/transport/unbound_buffer.h"
14+
15+
namespace gloo {
16+
17+
class AlltoallOptions {
18+
public:
19+
explicit AlltoallOptions(const std::shared_ptr<Context>& context)
20+
: context(context), timeout(context->getTimeout()) {}
21+
22+
template <typename T>
23+
void setInput(std::unique_ptr<transport::UnboundBuffer> buf) {
24+
elementSize = sizeof(T);
25+
in = std::move(buf);
26+
}
27+
28+
template <typename T>
29+
void setInput(T* ptr, size_t elements) {
30+
elementSize = sizeof(T);
31+
in = context->createUnboundBuffer(ptr, elements * sizeof(T));
32+
}
33+
34+
template <typename T>
35+
void setOutput(std::unique_ptr<transport::UnboundBuffer> buf) {
36+
elementSize = sizeof(T);
37+
out = std::move(buf);
38+
}
39+
40+
template <typename T>
41+
void setOutput(T* ptr, size_t elements) {
42+
elementSize = sizeof(T);
43+
out = context->createUnboundBuffer(ptr, elements * sizeof(T));
44+
}
45+
46+
void setTag(uint32_t tag) {
47+
this->tag = tag;
48+
}
49+
50+
void setTimeout(std::chrono::milliseconds timeout) {
51+
GLOO_ENFORCE(timeout.count() > 0);
52+
this->timeout = timeout;
53+
}
54+
55+
protected:
56+
std::shared_ptr<Context> context;
57+
std::unique_ptr<transport::UnboundBuffer> in;
58+
std::unique_ptr<transport::UnboundBuffer> out;
59+
60+
// Number of bytes per element.
61+
size_t elementSize = 0;
62+
63+
// Tag for this operation.
64+
// Must be unique across operations executing in parallel.
65+
uint32_t tag = 0;
66+
67+
// End-to-end timeout for this operation.
68+
std::chrono::milliseconds timeout;
69+
70+
friend void alltoall(AlltoallOptions&);
71+
};
72+
73+
// Runs the alltoall collective described by `opts`: the input buffer is split
// into one equally sized chunk per rank, chunk r is sent to rank r, and the
// chunk received from rank r is stored as chunk r of the output buffer.
void alltoall(AlltoallOptions& opts);
74+
75+
} // namespace gloo

gloo/alltoallv.cc

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
/**
2+
* Copyright (c) 2018-present, Facebook, Inc.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include "gloo/alltoallv.h"
10+
11+
#include <cstring>
12+
#include <numeric>
13+
14+
#include "gloo/common/logging.h"
15+
#include "gloo/types.h"
16+
17+
namespace gloo {
18+
19+
// Computes the byte layout of a buffer that is partitioned per rank.
//
// For every entry of `elementsPerRank` (an element count), stores the byte
// offset of that rank's chunk in `offsets` and its byte length in `lengths`.
// Chunks are laid out back to back, starting at offset 0.
//
// Previous contents of `offsets` and `lengths` are discarded, so the result
// always has exactly one entry per rank. Without this, computing a layout
// twice into the same vectors would leave the first layout at the low
// indices, which is where consumers look entries up by rank.
//
// Counts are converted to size_t as-is; negative counts are not detected
// here.
static void splitOffsetsAndLengths(
    const std::vector<int64_t>& elementsPerRank,
    size_t elementSize,
    std::vector<size_t>& offsets,
    std::vector<size_t>& lengths) {
  offsets.clear();
  lengths.clear();
  offsets.reserve(elementsPerRank.size());
  lengths.reserve(elementsPerRank.size());

  size_t offset = 0;
  for (const int64_t elements : elementsPerRank) {
    const size_t length = static_cast<size_t>(elements) * elementSize;
    offsets.push_back(offset);
    lengths.push_back(length);
    offset += length;
  }
}
32+
33+
// Records the element size the first time it is called (while the stored
// value is still 0). On later calls it only verifies that the supplied size
// equals the stored one, so input and output must use same-sized elements.
void AlltoallvOptions::setElementSize(size_t elementSize) {
  if (this->elementSize != 0) {
    GLOO_ENFORCE_EQ(
        elementSize,
        this->elementSize,
        "Element size does not match existing value. ",
        "Please double check that the input and output types match.");
    return;
  }
  this->elementSize = elementSize;
}
44+
45+
void AlltoallvOptions::setInput(
46+
std::unique_ptr<transport::UnboundBuffer> buf,
47+
std::vector<int64_t> elementsPerRank,
48+
size_t elementSize) {
49+
const auto totalElements = std::accumulate(
50+
elementsPerRank.begin(), elementsPerRank.end(), size_t(0));
51+
this->setElementSize(elementSize);
52+
GLOO_ENFORCE_EQ(elementsPerRank.size(), context->size);
53+
this->inOffsetPerRank.reserve(elementsPerRank.size());
54+
this->inLengthPerRank.reserve(elementsPerRank.size());
55+
splitOffsetsAndLengths(
56+
elementsPerRank,
57+
elementSize,
58+
this->inOffsetPerRank,
59+
this->inLengthPerRank);
60+
GLOO_ENFORCE_EQ(totalElements * elementSize, buf->size);
61+
this->in = std::move(buf);
62+
}
63+
64+
void AlltoallvOptions::setInput(
65+
void* ptr,
66+
std::vector<int64_t> elementsPerRank,
67+
size_t elementSize) {
68+
const auto totalElements = std::accumulate(
69+
elementsPerRank.begin(), elementsPerRank.end(), size_t(0));
70+
this->setElementSize(elementSize);
71+
GLOO_ENFORCE_EQ(elementsPerRank.size(), context->size);
72+
this->inOffsetPerRank.reserve(elementsPerRank.size());
73+
this->inLengthPerRank.reserve(elementsPerRank.size());
74+
splitOffsetsAndLengths(
75+
elementsPerRank,
76+
elementSize,
77+
this->inOffsetPerRank,
78+
this->inLengthPerRank);
79+
this->in = context->createUnboundBuffer(ptr, totalElements * elementSize);
80+
}
81+
82+
void AlltoallvOptions::setOutput(
83+
std::unique_ptr<transport::UnboundBuffer> buf,
84+
std::vector<int64_t> elementsPerRank,
85+
size_t elementSize) {
86+
const auto totalElements = std::accumulate(
87+
elementsPerRank.begin(), elementsPerRank.end(), size_t(0));
88+
this->setElementSize(elementSize);
89+
GLOO_ENFORCE_EQ(elementsPerRank.size(), context->size);
90+
this->outOffsetPerRank.reserve(elementsPerRank.size());
91+
this->outLengthPerRank.reserve(elementsPerRank.size());
92+
splitOffsetsAndLengths(
93+
elementsPerRank,
94+
elementSize,
95+
this->outOffsetPerRank,
96+
this->outLengthPerRank);
97+
GLOO_ENFORCE_EQ(totalElements * elementSize, buf->size);
98+
this->out = std::move(buf);
99+
}
100+
101+
void AlltoallvOptions::setOutput(
102+
void* ptr,
103+
std::vector<int64_t> elementsPerRank,
104+
size_t elementSize) {
105+
const auto totalElements = std::accumulate(
106+
elementsPerRank.begin(), elementsPerRank.end(), size_t(0));
107+
this->setElementSize(elementSize);
108+
GLOO_ENFORCE_EQ(elementsPerRank.size(), context->size);
109+
this->outOffsetPerRank.reserve(elementsPerRank.size());
110+
this->outLengthPerRank.reserve(elementsPerRank.size());
111+
splitOffsetsAndLengths(
112+
elementsPerRank,
113+
elementSize,
114+
this->outOffsetPerRank,
115+
this->outLengthPerRank);
116+
this->out = context->createUnboundBuffer(ptr, totalElements * elementSize);
117+
}
118+
119+
// Exchanges variably sized chunks between all ranks of the context.
//
// For every rank r, the bytes at [inOffsetPerRank[r], +inLengthPerRank[r]) of
// the input buffer are sent to rank r, and the data received from rank r is
// written at [outOffsetPerRank[r], +outLengthPerRank[r]) of the output
// buffer. The chunk addressed to the local rank is copied directly.
//
// Fails an enforce check if the element size is unset, if either buffer is
// missing, or if the local rank's input and output lengths differ.
void alltoallv(AlltoallvOptions& opts) {
  const auto& context = opts.context;
  transport::UnboundBuffer* in = opts.in.get();
  transport::UnboundBuffer* out = opts.out.get();
  const auto& inOffsetPerRank = opts.inOffsetPerRank;
  const auto& inLengthPerRank = opts.inLengthPerRank;
  const auto& outOffsetPerRank = opts.outOffsetPerRank;
  const auto& outLengthPerRank = opts.outLengthPerRank;
  const auto slot = Slot::build(kAlltoallSlotPrefix, opts.tag);

  // Preconditions.
  GLOO_ENFORCE(opts.elementSize > 0);
  GLOO_ENFORCE(in != nullptr);
  GLOO_ENFORCE(out != nullptr);

  const int myRank = context->rank;
  const int numRanks = context->size;

  // The chunk addressed to ourselves never touches the transport; what we
  // "send" to ourselves must be exactly as long as what we expect to receive.
  GLOO_ENFORCE(inLengthPerRank[myRank] == outLengthPerRank[myRank]);
  const size_t selfBytes = inLengthPerRank[myRank];
  char* const selfDst =
      static_cast<char*>(out->ptr) + outOffsetPerRank[myRank];
  const char* const selfSrc =
      static_cast<char*>(in->ptr) + inOffsetPerRank[myRank];
  memcpy(selfDst, selfSrc, selfBytes);

  // Post every send/recv pair. At step s we send to the rank s positions
  // ahead of us and receive from the rank s positions behind us.
  for (int step = 1; step < numRanks; ++step) {
    const int dstRank = (myRank + step) % numRanks;
    const int srcRank = (myRank + numRanks - step) % numRanks;
    in->send(dstRank, slot, inOffsetPerRank[dstRank], inLengthPerRank[dstRank]);
    out->recv(
        srcRank, slot, outOffsetPerRank[srcRank], outLengthPerRank[srcRank]);
  }

  // One send and one recv were posted per remote rank; wait for each of them.
  for (int pending = numRanks - 1; pending > 0; --pending) {
    in->waitSend(opts.timeout);
    out->waitRecv(opts.timeout);
  }
}
162+
163+
} // namespace gloo

0 commit comments

Comments
 (0)