Fix synchronization in allreduce8 kernel (#407)
Running the allreduce8 kernel across 64 vGPUs (in CPX mode) revealed a
synchronization bug. This PR fixes it by ensuring that each signal is
issued only after all threads in the block have issued their writes, which
guarantees correct ordering between data writes and signal writes.

---------

Co-authored-by: Changho Hwang <[email protected]>
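
The ordering requirement described above can be illustrated with a minimal, standalone CUDA sketch. It is not the allreduce8 kernel: the kernel name `signalAfterWrites`, the helper `runCheck`, and the plain global-memory `flag` standing in for a channel signal are all hypothetical. It also assumes that both thread blocks are resident on the GPU at the same time, which is typical for two small blocks but is not guaranteed by a regular launch.

```cuda
#include <cstdio>
#include <cuda_runtime.h>

// Hypothetical illustration: block 0 produces data and raises a flag,
// block 1 waits for the flag and validates the data.
__global__ void signalAfterWrites(int* data, volatile int* flag, int* mismatches, int n) {
  if (blockIdx.x == 0) {
    // Every thread of the producer block writes part of the payload.
    for (int i = threadIdx.x; i < n; i += blockDim.x) data[i] = i + 1;
    // Barrier: no thread continues until all threads of the block have
    // issued their writes. Without it, thread 0 could signal too early.
    __syncthreads();
    if (threadIdx.x == 0) {
      __threadfence();  // order the block's data writes before the flag write
      *flag = 1;        // the "signal" (a stand-in for a channel signal)
    }
  } else {
    if (threadIdx.x == 0) {
      while (*flag == 0) {  // the "wait": spin until the signal arrives
      }
      __threadfence();  // order the flag read before the data reads
    }
    __syncthreads();  // release the remaining consumer threads
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
      if (data[i] != i + 1) atomicAdd(mismatches, 1);
    }
  }
}

// Launches the sketch and returns the number of mismatches, or -1 on a CUDA error.
int runCheck() {
  const int n = 1 << 12;
  int* data = nullptr;
  int* flag = nullptr;
  int* mismatches = nullptr;
  if (cudaMalloc(&data, n * sizeof(int)) != cudaSuccess) return -1;
  if (cudaMalloc(&flag, sizeof(int)) != cudaSuccess) return -1;
  if (cudaMalloc(&mismatches, sizeof(int)) != cudaSuccess) return -1;
  cudaMemset(data, 0, n * sizeof(int));
  cudaMemset(flag, 0, sizeof(int));
  cudaMemset(mismatches, 0, sizeof(int));

  signalAfterWrites<<<2, 256>>>(data, flag, mismatches, n);

  int result = -1;
  if (cudaDeviceSynchronize() == cudaSuccess) {
    cudaMemcpy(&result, mismatches, sizeof(int), cudaMemcpyDeviceToHost);
  }
  cudaFree(data);
  cudaFree(flag);
  cudaFree(mismatches);
  return result;
}

int main() {
  int result = runCheck();
  printf("mismatches: %d\n", result);
  return result == 0 ? 0 : 1;
}
```

Removing the first `__syncthreads()` in the producer branch would recreate the kind of race this commit addresses: thread 0 could raise the flag while other threads of its block still have writes outstanding.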
dsidler and chhwang authored Dec 19, 2024
1 parent 774602d commit d8d0dfb
Showing 1 changed file with 10 additions and 0 deletions.
10 changes: 10 additions & 0 deletions apps/nccl/src/allreduce.hpp
@@ -378,6 +378,8 @@ __global__ void __launch_bounds__(512, 1)
   }

   /// Starts reduce-scatter
+  // Ensure that all writes of this block have been issued before issuing the signal
+  __syncthreads();
   if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
     outChannels[threadIdx.x].signal();
     outChannels[threadIdx.x].wait();
@@ -398,6 +400,8 @@ __global__ void __launch_bounds__(512, 1)
       }
     }
     offsetOfThisBlock += nInt4PerChunk;
+    // Ensure all threads have consumed data from scratch buffer before signaling re-use in next iteration
+    __syncthreads();
   }
   if (restNInt4 > 0) {
     if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
@@ -414,6 +418,8 @@ __global__ void __launch_bounds__(512, 1)
     }
   }

+  // Ensure that all writes of this block have been issued before issuing the signal
+  __syncthreads();
   if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
     outChannels[threadIdx.x].signal();
     outChannels[threadIdx.x].wait();
@@ -433,7 +439,11 @@ __global__ void __launch_bounds__(512, 1)
           data);
       }
     }
+    // Ensure all threads have issued writes to outChannel
+    __syncthreads();
   }
+  // Threads are already synchronized
+  // So all writes to outChannel have been issued before signal is being issued
   if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
     outChannels[threadIdx.x].signal();
     outChannels[threadIdx.x].wait();