Skip to content

Commit

Permalink
Fixing the bug in allreduce1 (#220)
Browse files Browse the repository at this point in the history
  • Loading branch information
saeedmaleki authored Nov 18, 2023
1 parent 1d11997 commit 70eb6d7
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions python/benchmark/allreduce.cu
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,12 @@ __device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE*
TYPE val = smChans[peerIdx].read<TYPE>(idx);
tmp += val;
}
for (int index = 0; index < nPeer; ++index) {
int peerIdx = (index + rank);
if (peerIdx >= nPeer) peerIdx -= nPeer;
smChans[peerIdx].write<TYPE>(idx, tmp);
if (READ_ONLY == 0) {
for (int index = 0; index < nPeer; ++index) {
int peerIdx = (index + rank);
if (peerIdx >= nPeer) peerIdx -= nPeer;
smChans[peerIdx].write<TYPE>(idx, tmp);
}
}
buff[idx] = tmp;
}
Expand All @@ -198,6 +200,7 @@ __device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE*
}

if (READ_ONLY) {
deviceSyncer.sync(gridDim.x);
for (int i = 0; i < nPeer; ++i) {
int peerIdx = (i + rank);
if (peerIdx >= nPeer) peerIdx -= nPeer;
Expand Down

0 comments on commit 70eb6d7

Please sign in to comment.