Commit 23a3cef

nariaki3551 authored and pytorchmergebot committed
[c10d] Add _allgather_base, reduce_scatter, and _reduce_scatter_base into ProcessGroupMPI to enable FSDP with MPI backend (pytorch#150162)
This PR implements _allgather_base, reduce_scatter, and _reduce_scatter_base in the MPI backend (ProcessGroupMPI), enabling support for Fully Sharded Data Parallel (FSDP) in environments that use MPI for distributed communication.

### Context

As noted in pytorch#85628, FSDP currently supports only the NCCL backend. Due to this limitation, FSDP cannot run on legacy HPC environments or clusters that rely on MPI. Implementing just these three collective operations is enough to enable FSDP with the MPI backend. The collectives are implemented in a similar manner to existing operations such as allgather.

### Testing

We validated this PR with pytorch/build/bin/ProcessGroupMPITest under OpenMPI, and all tests passed.

Pull Request resolved: pytorch#150162
Approved by: https://github.com/H-Huang
1 parent 7deed19 commit 23a3cef
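For orientation before the diff: the three new collectives expose the same Work-based calling pattern as the existing MPI collectives, which is the pattern FSDP relies on for its flat-parameter all-gather and gradient reduce-scatter. The snippet below is a minimal standalone sketch, not part of this commit; it assumes a USE_MPI build of PyTorch and an MPI launcher such as mpirun -np 4, and mirrors the shapes used in ProcessGroupMPITest.cpp.

```cpp
// Illustrative sketch only (not from the commit): drive the new MPI-backend
// collectives directly, the way FSDP's flat-parameter path would.
#include <ATen/ATen.h>
#include <torch/csrc/distributed/c10d/ProcessGroupMPI.hpp>

int main() {
  auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI();
  const auto worldSize = pg->getSize();

  // _allgather_base: every rank contributes one shard; the flat output holds
  // worldSize shards, so output.numel() == input.numel() * worldSize.
  auto shard = at::ones({16, 16}) * pg->getRank();
  auto gathered = at::zeros({worldSize, 16, 16});
  pg->_allgather_base(gathered, shard)->wait();

  // _reduce_scatter_base: the inverse shape contract; the flat input holds
  // worldSize blocks and each rank receives the element-wise sum of its block.
  auto grads = at::ones({worldSize, 16, 16});
  auto reduced = at::zeros({16, 16});
  pg->_reduce_scatter_base(reduced, grads)->wait();

  return 0;
}
```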

File tree

3 files changed: +220 -5 lines changed


test/cpp/c10d/ProcessGroupMPITest.cpp

Lines changed: 110 additions & 0 deletions
@@ -185,6 +185,113 @@ void testAllgather(int iter = 10000) {
   }
 }
 
+void testAllgatherBase(int iter = 10000) {
+  auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI();
+  std::vector<c10::intrusive_ptr<::c10d::Work>> works;
+
+  // Get the world size
+  auto worldSize = pg->getSize();
+  auto rank = pg->getRank();
+
+  // Generate inputs
+  for (const auto i : c10::irange(iter)) {
+    auto tensor = at::ones({16, 16}) * i * rank;
+    auto output = at::zeros({worldSize, 16, 16});
+
+    // Queue the work.
+    c10::intrusive_ptr<::c10d::Work> work = pg->_allgather_base(output, tensor);
+    works.push_back(std::move(work));
+  }
+
+  auto outputTensors = waitFuture(pg, works);
+
+  // Verify outputs
+  for (const auto i : c10::irange(iter)) {
+    for (const auto j : c10::irange(worldSize)) {
+      const auto expected = i * j;
+      auto data = outputTensors[i][0][j].data_ptr<float>();
+      for (auto k = 0; k < outputTensors[i][0][j].numel(); ++k) {
+        if (data[k] != static_cast<float>(expected)) {
+          TORCH_CHECK(false, "BOOM!");
+        }
+      }
+    }
+  }
+}
+
+void testReduceScatter(int iter = 10000) {
+  auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI();
+  std::vector<c10::intrusive_ptr<::c10d::Work>> works;
+
+  // Get the world size
+  auto worldSize = pg->getSize();
+  auto rank = pg->getRank();
+
+  // Generate inputs
+  int count = 2;
+  for (const auto i : c10::irange(iter)) {
+    auto tensors = std::vector<std::vector<at::Tensor>>(1);
+    tensors[0].resize(worldSize);
+    for (const auto j : c10::irange(worldSize)) {
+      tensors[0][j] = at::ones({count, count}) * i * rank;
+    }
+    auto output = at::zeros({count, count});
+    auto outputs = std::vector<at::Tensor>({output});
+
+    // Queue the work.
+    c10::intrusive_ptr<::c10d::Work> work =
+        pg->reduce_scatter(outputs, tensors);
+    works.push_back(std::move(work));
+  }
+
+  auto outputTensors = waitFuture(pg, works);
+
+  // Verify outputs
+  for (const auto i : c10::irange(iter)) {
+    const auto expected = i * (worldSize * (worldSize - 1)) / 2.0;
+    auto data = outputTensors[i][0].data_ptr<float>();
+    for (auto j = 0; j < outputTensors[i][0].numel(); ++j) {
+      if (data[j] != static_cast<float>(expected)) {
+        TORCH_CHECK(false, "BOOM!");
+      }
+    }
+  }
+}
+
+void testReduceScatterBase(int iter = 10000) {
+  auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI();
+  std::vector<c10::intrusive_ptr<::c10d::Work>> works;
+
+  // Get the world size
+  auto worldSize = pg->getSize();
+  auto rank = pg->getRank();
+
+  // Generate inputs
+  for (const auto i : c10::irange(iter)) {
+    auto tensor = at::ones({worldSize, 16, 16}) * i * rank;
+    auto output = at::zeros({16, 16});
+    auto outputs = std::vector<at::Tensor>({output});
+
+    // Queue the work.
+    c10::intrusive_ptr<::c10d::Work> work =
+        pg->_reduce_scatter_base(output, tensor);
+    works.push_back(std::move(work));
+  }
+
+  auto outputTensors = waitFuture(pg, works);
+
+  // Verify outputs
+  for (const auto i : c10::irange(iter)) {
+    const auto expected = i * (worldSize * (worldSize - 1)) / 2.0;
+    auto data = outputTensors[i][0].data_ptr<float>();
+    for (auto j = 0; j < outputTensors[i][0].numel(); ++j) {
+      if (data[j] != static_cast<float>(expected)) {
+        TORCH_CHECK(false, "BOOM!");
+      }
+    }
+  }
+}
+
 void testGather(int iter = 10000) {
   auto pg = c10d::ProcessGroupMPI::createProcessGroupMPI();
   std::vector<c10::intrusive_ptr<::c10d::Work>> works;
@@ -355,6 +462,9 @@ int main(int argc, char** argv) {
   testBroadcast();
   testReduce();
   testAllgather();
+  testAllgatherBase();
+  testReduceScatter();
+  testReduceScatterBase();
   testGather();
   testScatter();
   testSendRecv(false);
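Why the two reduce-scatter tests expect i * worldSize * (worldSize - 1) / 2: every rank contributes tensors filled with i * rank, and the default reduce op is SUM, so each element of the scattered result is

    i * (0 + 1 + ... + (worldSize - 1)) = i * worldSize * (worldSize - 1) / 2

which is the `expected` value the verification loops compare against.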

torch/csrc/distributed/c10d/ProcessGroupMPI.cpp

Lines changed: 105 additions & 5 deletions
@@ -695,7 +695,47 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::reduce_scatter(
     std::vector<at::Tensor>& outputTensors,
     std::vector<std::vector<at::Tensor>>& inputTensors,
     const ReduceScatterOptions& opts) {
-  TORCH_CHECK(false, "ProcessGroupMPI does not support reduce_scatter");
+  checkSingleTensor(outputTensors);
+  if (inputTensors.size() != 1) {
+    TORCH_CHECK(
+        false,
+        "MPI process group only supports a single "
+        "tensor op");
+  }
+  if (static_cast<size_t>(size_) != inputTensors[0].size()) {
+    TORCH_CHECK(
+        false,
+        "Reduce scatter: number of input tensors should equal "
+        "to the world size");
+  }
+  checkSameSizeAndType(outputTensors[0], inputTensors[0]);
+
+  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
+      [opts, this](std::unique_ptr<WorkEntry>& entry) {
+        auto data = (entry->dst)[0];
+        auto flatInputTensor = newLikeFlat(entry->src);
+        for (const auto i : c10::irange(entry->src.size())) {
+          flatInputTensor[static_cast<int64_t>(i)].copy_(entry->src[i]);
+        }
+        int recvcount = flatInputTensor.numel() / size_;
+
+        c10::DeviceGuard guard(data.device());
+        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
+        MPI_CHECK(MPI_Reduce_scatter_block(
+            flatInputTensor.data_ptr(),
+            data.data_ptr(),
+            recvcount,
+            mpiDatatype.at(data.scalar_type()),
+            mpiOp.at(opts.reduceOp),
+            pgComm_));
+      };
+
+  auto entry = std::make_unique<WorkEntry>(
+      &inputTensors[0], &outputTensors, std::move(runFunc));
+  return enqueue(
+      std::move(entry),
+      "mpi:reduce_scatter",
+      std::optional<std::vector<at::Tensor>>(inputTensors[0]));
 }
 
 c10::intrusive_ptr<Work> ProcessGroupMPI::alltoall_base(
@@ -941,10 +981,70 @@ c10::intrusive_ptr<Work> ProcessGroupMPI::barrier(const BarrierOptions& opts) {
 }
 
 c10::intrusive_ptr<Work> ProcessGroupMPI::_allgather_base(
-    at::Tensor& /*unused */,
-    at::Tensor& /*unused */,
-    const AllgatherOptions& /*unused */) {
-  TORCH_CHECK(false, "no support for _allgather_base in MPI process group");
+    at::Tensor& outputTensor,
+    at::Tensor& inputTensor,
+    const AllgatherOptions& opts) {
+  TORCH_CHECK(
+      outputTensor.numel() == inputTensor.numel() * size_,
+      "All gather: output tensor size must be equal to input tensor size times the world size");
+
+  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
+      [this](std::unique_ptr<WorkEntry>& entry) {
+        auto dstdata = (entry->dst)[0];
+        auto srcdata = (entry->src)[0];
+        c10::DeviceGuard guard(srcdata.device());
+        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
+        MPI_CHECK(MPI_Allgather(
+            srcdata.data_ptr(),
+            srcdata.numel(),
+            mpiDatatype.at(srcdata.scalar_type()),
+            dstdata.data_ptr(),
+            srcdata.numel(),
+            mpiDatatype.at(dstdata.scalar_type()),
+            pgComm_));
+      };
+
+  auto inputTensors = std::vector<at::Tensor>({inputTensor});
+  auto outputTensors = std::vector<at::Tensor>({outputTensor});
+  auto entry = std::make_unique<WorkEntry>(
+      &inputTensors, &outputTensors, std::move(runFunc));
+  return enqueue(
+      std::move(entry),
+      "mpi:_allgather_base",
+      std::optional<std::vector<at::Tensor>>(inputTensors));
+}
+
+c10::intrusive_ptr<Work> ProcessGroupMPI::_reduce_scatter_base(
+    at::Tensor& outputTensor,
+    at::Tensor& inputTensor,
+    const ReduceScatterOptions& opts) {
+  TORCH_CHECK(
+      outputTensor.numel() * size_ == inputTensor.numel(),
+      "Reduce scatter: input tensor size must be equal to output tensor size times the world size");
+
+  std::function<void(std::unique_ptr<WorkEntry>&)> runFunc =
+      [opts, this](std::unique_ptr<WorkEntry>& entry) {
+        auto dstdata = (entry->dst)[0];
+        auto srcdata = (entry->src)[0];
+        c10::DeviceGuard guard(srcdata.device());
+        std::unique_lock<std::mutex> globalLock(pgGlobalMutex_);
+        MPI_CHECK(MPI_Reduce_scatter_block(
+            srcdata.data_ptr(),
+            dstdata.data_ptr(),
+            dstdata.numel(),
+            mpiDatatype.at(srcdata.scalar_type()),
+            mpiOp.at(opts.reduceOp),
+            pgComm_));
+      };
+
+  auto inputTensors = std::vector<at::Tensor>({inputTensor});
+  auto outputTensors = std::vector<at::Tensor>({outputTensor});
+  auto entry = std::make_unique<WorkEntry>(
+      &inputTensors, &outputTensors, std::move(runFunc));
+  return enqueue(
+      std::move(entry),
+      "mpi:_reduce_scatter_base",
+      std::optional<std::vector<at::Tensor>>(inputTensors));
 }
 
 } // namespace c10d
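Both new reduction paths bottom out in MPI_Reduce_scatter_block, which requires equal-sized per-rank blocks; that is why the vector reduce_scatter above first flattens its worldSize input tensors into one contiguous buffer and passes recvcount = flatInputTensor.numel() / size_. For readers unfamiliar with that MPI primitive, here is a standalone plain-MPI sketch (illustrative only, not from this commit) with shapes chosen to mirror the diff.

```cpp
// Plain-MPI illustration of MPI_Reduce_scatter_block (not from the commit):
// each rank supplies worldSize blocks of `recvcount` floats; the blocks are
// summed element-wise across ranks and block r is delivered to rank r.
#include <mpi.h>
#include <vector>

int main(int argc, char** argv) {
  MPI_Init(&argc, &argv);
  int rank = 0, worldSize = 0;
  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &worldSize);

  const int recvcount = 16 * 16;  // elements each rank receives
  std::vector<float> sendbuf(static_cast<size_t>(worldSize) * recvcount,
                             static_cast<float>(rank));
  std::vector<float> recvbuf(recvcount, 0.0f);

  MPI_Reduce_scatter_block(
      sendbuf.data(), recvbuf.data(), recvcount, MPI_FLOAT, MPI_SUM,
      MPI_COMM_WORLD);
  // Now recvbuf[k] == 0 + 1 + ... + (worldSize - 1) for every k.

  MPI_Finalize();
  return 0;
}
```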

torch/csrc/distributed/c10d/ProcessGroupMPI.hpp

Lines changed: 5 additions & 0 deletions
@@ -197,6 +197,11 @@ class TORCH_API ProcessGroupMPI : public Backend {
       std::vector<std::vector<at::Tensor>>& inputTensors,
       const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
 
+  c10::intrusive_ptr<Work> _reduce_scatter_base(
+      at::Tensor& outputTensor,
+      at::Tensor& inputTensor,
+      const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
+
   c10::intrusive_ptr<Work> alltoall_base(
       at::Tensor& outputTensor,
       at::Tensor& inputTensor,
