Skip to content

Commit 70eb6d7

Browse files
authored
Fixing the bug in allreduce1 (#220)
1 parent 1d11997 commit 70eb6d7

File tree

1 file changed

+7
-4
lines changed

1 file changed

+7
-4
lines changed

Diff for: python/benchmark/allreduce.cu

+7 -4
Original file line number | Diff line number | Diff line change
@@ -176,10 +176,12 @@ __device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE*
176 176
TYPE val = smChans[peerIdx].read<TYPE>(idx);
177 177
tmp += val;
178 178
}
179-
for (int index = 0; index < nPeer; ++index) {
180-
int peerIdx = (index + rank);
181-
if (peerIdx >= nPeer) peerIdx -= nPeer;
182-
smChans[peerIdx].write<TYPE>(idx, tmp);
179+
if (READ_ONLY == 0) {
180+
for (int index = 0; index < nPeer; ++index) {
181+
int peerIdx = (index + rank);
182+
if (peerIdx >= nPeer) peerIdx -= nPeer;
183+
smChans[peerIdx].write<TYPE>(idx, tmp);
184+
}
183 185
}
184 186
buff[idx] = tmp;
185 187
}
@@ -198,6 +200,7 @@ __device__ void allreduce1_helper(mscclpp::SmChannelDeviceHandle* smChans, TYPE*
198 200
}
199 201

200 202
if (READ_ONLY) {
203+
deviceSyncer.sync(gridDim.x);
201 204
for (int i = 0; i < nPeer; ++i) {
202 205
int peerIdx = (i + rank);
203 206
if (peerIdx >= nPeer) peerIdx -= nPeer;

0 commit comments

Comments
 (0)