update broadcast algo #447

Closed: wants to merge 28 commits into main from binyli/amd-test

Commits (28)
76ede02
Lazily create the context stream
chhwang Nov 9, 2024
68db835
Use the default stream for CudaIpcConnection
chhwang Nov 9, 2024
eccb783
Broadcast implementation in progress for AMD GPUs.
SreevatsaAnantharamu Dec 7, 2024
203f61c
In-place bcast works fine!
SreevatsaAnantharamu Dec 7, 2024
8992d17
Works fine when we run only one single message size
SreevatsaAnantharamu Dec 7, 2024
5b88be0
Removed channelOutOffset from buff.
SreevatsaAnantharamu Dec 7, 2024
ffffc94
Moved smChans to shared memory.
SreevatsaAnantharamu Dec 11, 2024
e446fb1
Slightly improved performance by using more blocks.
SreevatsaAnantharamu Dec 11, 2024
6090022
Expose nccl bcast api, uses broadcast (#409)
pash-msft Dec 12, 2024
b98b977
Scratch buffer copy-based implementation of ncclBcast / ncclBroadcast…
SreevatsaAnantharamu Dec 17, 2024
e93add9
Fused copy of the one-to-all algorithm
SreevatsaAnantharamu Jan 7, 2025
909cab5
Basic version that gives correct results!
SreevatsaAnantharamu Jan 7, 2025
ce7d823
Remove root rank from sync. Need to enhance further.
SreevatsaAnantharamu Jan 7, 2025
3a83816
Heavy bcast sync. Still performs better than RCCL for 114688 bytes.
SreevatsaAnantharamu Jan 7, 2025
d5a17e1
Initial version.
SreevatsaAnantharamu Jan 8, 2025
9919c7f
Initial version.
SreevatsaAnantharamu Jan 8, 2025
98cd5fd
Setting unitBytesPerThread to 32 works but not 64.
SreevatsaAnantharamu Jan 8, 2025
1574891
Fixed the correctness issue
SreevatsaAnantharamu Jan 8, 2025
753d5e3
Fixed a bug for messages > 140M
SreevatsaAnantharamu Jan 8, 2025
23880ce
WIP
Binyang2014 Jan 8, 2025
34b0a18
update
Binyang2014 Jan 9, 2025
230f1af
WIP
Binyang2014 Jan 9, 2025
738bb57
WIP
Binyang2014 Jan 9, 2025
986f4ae
WIP
Binyang2014 Jan 9, 2025
c8e71f6
WIP
Binyang2014 Jan 9, 2025
5cc2b38
update
Binyang2014 Jan 9, 2025
5bb4568
merge main
Binyang2014 Jan 9, 2025
3efcb6f
Merge branch 'main' into binyli/amd-test
Binyang2014 Jan 15, 2025
apps/nccl/src/broadcast.hpp (206 changes: 115 additions & 91 deletions)
@@ -17,8 +17,10 @@ __global__ void __launch_bounds__(1024, 1)
     broadcast6(void* sendbuff, void* scratchbuff, void* recvbuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
                [[maybe_unused]] size_t channelOutOffset, size_t rank, [[maybe_unused]] size_t worldSize, size_t root,
                size_t nRanksPerNode, size_t nelemsPerGPU) {
+  const size_t bid = blockIdx.x;
+
   const size_t nThread = blockDim.x * gridDim.x;
-  const size_t nPeer = nRanksPerNode - 1;
+  const unsigned int nPeer = nRanksPerNode - 1;
   const size_t chanOffset = nPeer * blockIdx.x;
 
   __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> smChans[NRANKS_PER_NODE - 1];
@@ -29,119 +31,141 @@
   }
   __syncthreads();
 
-  const size_t peerRootIdx = (root == rank) ? nPeer : ((root < rank) ? root : (root - 1));
+  const unsigned int peerRootIdx = (root == rank) ? nPeer : ((root < rank) ? root : (root - 1));
 
   const size_t bytesPerGPU = nelemsPerGPU * sizeof(int);
   const size_t bytes = bytesPerGPU;
   size_t unitBytesPerThread;
-  if (bytes * nPeer >= nThread * 64) {
+  if (bytes >= nThread * 64 * nPeer) {
     unitBytesPerThread = 64;
   } else {
     unitBytesPerThread = 16;
   }
   const size_t unitBytesPerBlock = unitBytesPerThread * blockDim.x;
-  const size_t unitBytes = unitBytesPerBlock * gridDim.x;
+  const size_t unitBytes = unitBytesPerBlock * gridDim.x * nPeer;
   const size_t nLoop = bytes / unitBytes;
 
   const size_t maxScratchSizeToUse = (SCRATCH_SIZE - unitBytes);
   const size_t nLoopToSync = (maxScratchSizeToUse / unitBytes) + 1;
 
-  size_t scratchSub = 0;
-
-  // First loop will always fit the scratch size.
-  if (nLoop > 0) {
-    // First loop unrolling
-    const size_t offset = blockIdx.x * unitBytesPerBlock;
-    if (rank == root) {
-      char* send_ = reinterpret_cast<char*>(sendbuff);
-      for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
-        char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
-        smChans[peerIdx].copy<16, false>(dst + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
-        __syncthreads();
-        if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
-      }
-      if constexpr (IsOutOfPlace) {
-        char* recv_ = reinterpret_cast<char*>(recvbuff);
-        smChans[0].copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
-      }
-
-    } else {  // rank != root.
-      if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
-      __syncthreads();
-      char* recv_ = reinterpret_cast<char*>(recvbuff);
-      char* scratch_ = reinterpret_cast<char*>(scratchbuff);  // My scratchbuff.
-      smChans[peerRootIdx].copy<16, false>(recv_ + offset, scratch_ + offset, unitBytesPerBlock, threadIdx.x,
-                                           blockDim.x);
-    }
-  }
-
-  for (size_t i = 1; i < nLoop; ++i) {
-    const size_t offset = blockIdx.x * unitBytesPerBlock + i * unitBytes;
-    if (i % nLoopToSync == 0) {  // Sync to reuse scratch buff
-      scratchSub = -i * unitBytes;
+  size_t scratchOffset = 0;
+  for (size_t i = 0; i < nLoop; ++i) {
+    if (i % nLoopToSync == 0 && i > 0) {
+      scratchOffset -= nLoopToSync * unitBytes;
       deviceSyncer.sync(gridDim.x);
       if (threadIdx.x < nPeer) {
-        smChans[threadIdx.x].relaxedSignal();
+        smChans[threadIdx.x].signal();
         smChans[threadIdx.x].wait();
       }
     }
     if (rank == root) {
-      char* send_ = reinterpret_cast<char*>(sendbuff);
-      for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
-        char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
-        smChans[peerIdx].copy<16, false>(dst + offset + scratchSub, send_ + offset, unitBytesPerBlock, threadIdx.x,
-                                         blockDim.x);
-        __syncthreads();
-        if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
-      }
-      if constexpr (IsOutOfPlace) {
-        char* recv_ = reinterpret_cast<char*>(recvbuff);
-        smChans[0].copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
-      }
-    } else {  // rank != root.
-      if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
-      __syncthreads();
-      char* recv_ = reinterpret_cast<char*>(recvbuff);
-      char* scratch_ = reinterpret_cast<char*>(scratchbuff);  // My scratchbuff.
-      smChans[peerRootIdx].copy<16, false>(recv_ + offset, scratch_ + offset + scratchSub, unitBytesPerBlock,
-                                           threadIdx.x, blockDim.x);
-    }
-  }
-
-  // Remainder loop will also fit the scratch buff since we subtract unitBytes from SCRATCH_SIZE.
-  if (bytes % unitBytes > 0) {  // remainder.
-    const size_t offset = blockIdx.x * unitBytesPerBlock + nLoop * unitBytes;
-    const size_t remainBytes = (offset < bytes) ? (bytes - offset) : 0;
-    if (remainBytes > 0) {
-      if (rank == root) {
-        char* send_ = reinterpret_cast<char*>(sendbuff);
-        for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
-          char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
-          smChans[peerIdx].copy<16, true>(dst + offset + scratchSub, send_ + offset, remainBytes, threadIdx.x,
-                                          blockDim.x);
-          __syncthreads();
-          if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
-        }
-        if constexpr (IsOutOfPlace) {
-          char* recv_ = reinterpret_cast<char*>(recvbuff);
-          smChans[0].copy<16, true>(recv_ + offset, send_ + offset, remainBytes, threadIdx.x, blockDim.x);
-        }
-      } else {  // rank != root.
-        if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
-        __syncthreads();
-        char* recv_ = reinterpret_cast<char*>(recvbuff);
-        char* scratch_ = reinterpret_cast<char*>(scratchbuff);  // My scratchbuff.
-        smChans[peerRootIdx].copy<16, true>(recv_ + offset, scratch_ + offset + scratchSub, remainBytes, threadIdx.x,
-                                            blockDim.x);
-      }
-    }  // remainBytes > 0.
-  }
-
-  deviceSyncer.sync(gridDim.x);
-
-  if (threadIdx.x < nPeer) {
-    smChans[threadIdx.x].relaxedSignal();
-    smChans[threadIdx.x].wait();
-  }
-}
+      unsigned int peerIdx = bid % nPeer;
+      const size_t offset = blockIdx.x * unitBytesPerBlock * nPeer + i * unitBytes;
+      char* send = reinterpret_cast<char*>(sendbuff);
+      char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);
+
+      smChans[peerIdx].copy<16, false>(dst + offset + scratchOffset, send + offset, nPeer * unitBytesPerBlock,
+                                       threadIdx.x, blockDim.x);
+      __syncthreads();
+      if (threadIdx.x == peerIdx) smChans[threadIdx.x].signal();
+      if (IsOutOfPlace) {
+        char* recv = reinterpret_cast<char*>(recvbuff);
+        smChans[peerIdx].copy<16, false>(recv + offset, send + offset, nPeer * unitBytesPerBlock, threadIdx.x,
+                                         blockDim.x);
+      }
+    } else {  // rank != root.
+      int rankIndexInRoot = (rank < root) ? rank : (rank - 1);
+      if (bid % nPeer == rankIndexInRoot && threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
+      deviceSyncer.sync(gridDim.x);  // All blocks in the GPU wait.
+
+      // Step 2.
+      char* recv_ = reinterpret_cast<char*>(recvbuff);
+      char* scratch_ = reinterpret_cast<char*>(scratchbuff);
+
+      const int chunkId = bid % nPeer;
+      const int chunkGroundId = bid / nPeer;
+      const size_t offset = chunkId * unitBytesPerBlock +
+                            unitBytesPerBlock * nPeer * (chunkGroundId * nPeer + rankIndexInRoot) + i * unitBytes;
+      for (unsigned int j = 0; j < nPeer; ++j) {
+        unsigned int peerIdx = (bid + j) % nPeer;
+        if (peerIdx != peerRootIdx) {
+          char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
+          smChans[peerIdx].copy<16, false>(dst + offset + scratchOffset, scratch_ + offset + scratchOffset,
+                                           unitBytesPerBlock, threadIdx.x, blockDim.x);
+        }
+      }
+      __syncthreads();
+      if (threadIdx.x != peerRootIdx && threadIdx.x < nPeer) {
+        smChans[threadIdx.x].signal();
+        smChans[threadIdx.x].wait();
+      }
+      __syncthreads();
+      for (unsigned int peerId = 0; peerId < nPeer; ++peerId) {
+        const size_t offset =
+            chunkId * unitBytesPerBlock + (peerId + chunkGroundId * nPeer) * unitBytesPerBlock * nPeer + i * unitBytes;
+        smChans[0].copy<16, false>(recv_ + offset, scratch_ + offset + scratchOffset, unitBytesPerBlock, threadIdx.x,
+                                   blockDim.x);
+      }
+    }
+  }
+
+  deviceSyncer.sync(gridDim.x);
+  if (bytes % unitBytes > 0) {
+    if (rank == root) {
+      unsigned int peerIdx = bid % nPeer;
+      const size_t offset = blockIdx.x * unitBytesPerBlock * nPeer + nLoop * unitBytes;
+      const size_t remainBytes =
+          offset < bytes ? ((bytes - offset) > unitBytesPerBlock * nPeer ? unitBytesPerBlock * nPeer : (bytes - offset))
+                         : 0;
+      char* send = reinterpret_cast<char*>(sendbuff);
+      char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);
+
+      smChans[peerIdx].copy<16, true>(dst + offset + scratchOffset, send + offset, remainBytes, threadIdx.x, blockDim.x);
+      __syncthreads();
+      if (threadIdx.x == peerIdx) smChans[threadIdx.x].signal();
+      if constexpr (IsOutOfPlace) {
+        char* recv = reinterpret_cast<char*>(recvbuff);
+        smChans[peerIdx].copy<16, true>(recv + offset, send + offset, remainBytes, threadIdx.x, blockDim.x);
+      }
+
+    } else {
+      int rankIndexInRoot = (rank < root) ? rank : (rank - 1);
+      if (bid % nPeer == rankIndexInRoot && threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
+      deviceSyncer.sync(gridDim.x);
+
+      // Step 2.
+      char* recv_ = reinterpret_cast<char*>(recvbuff);
+      char* scratch_ = reinterpret_cast<char*>(scratchbuff);
+
+      const int chunkId = bid % nPeer;
+      const int chunkGroundId = bid / nPeer;
+      const size_t offset = chunkId * unitBytesPerBlock +
+                            unitBytesPerBlock * nPeer * (chunkGroundId * nPeer + rankIndexInRoot) + nLoop * unitBytes;
+      const size_t remainBytes =
+          (offset < bytes) ? ((bytes - offset) > unitBytesPerBlock ? unitBytesPerBlock : (bytes - offset)) : 0;
+
+      for (size_t j = 0; j < nPeer; ++j) {
+        unsigned peerIdx = (bid + j) % nPeer;
+        if (peerIdx != peerRootIdx) {
+          char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
+          smChans[peerIdx].copy<16, true>(dst + offset + scratchOffset, scratch_ + offset + scratchOffset, remainBytes,
+                                          threadIdx.x, blockDim.x);
+        }
+      }
+      __syncthreads();
+      if (threadIdx.x != peerRootIdx && threadIdx.x < nPeer) {
+        smChans[threadIdx.x].signal();
+        smChans[threadIdx.x].wait();
+      }
+      __syncthreads();
+      for (unsigned int peerId = 0; peerId < nPeer; ++peerId) {
+        const size_t offset = chunkId * unitBytesPerBlock +
+                              (peerId + chunkGroundId * nPeer) * unitBytesPerBlock * nPeer + nLoop * unitBytes;
+        const size_t remainBytes =
+            (offset < bytes) ? ((bytes - offset) > unitBytesPerBlock ? unitBytesPerBlock : (bytes - offset)) : 0;
+        smChans[0].copy<16, true>(recv_ + offset, scratch_ + offset + scratchOffset, remainBytes, threadIdx.x,
+                                  blockDim.x);
+      }
+    }
+  }
+}

@@ -150,13 +174,13 @@ cudaError_t broadcast(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
                       size_t channelOutOffset, int rank, int nRanksPerNode, int root, int worldSize, size_t nelems,
                       cudaStream_t stream) {
   int nBlocks = 7;
-  // if (nelems <= 4096) {
-  //   nBlocks = 7;
-  // } else if (nelems <= 32768) {
-  //   nBlocks = 14;
-  // } else if (nelems >= 2097152) {
-  //   nBlocks = 35;
-  // }
+  if (nelems <= 4096) {
+    nBlocks = 7;
+  } else if (nelems >= 32768) {
+    nBlocks = 14;
+  } else if (nelems >= 5242880) {
+    nBlocks = 28;
+  }
   broadcast6<IsOutOfPlace><<<nBlocks, 1024, 0, stream>>>((void*)buff, (void*)scratch, (void*)resultBuff, smChannels,
                                                          channelOutOffset, rank, worldSize, root, nRanksPerNode,
                                                          nelems * sizeof(T) / sizeof(int));
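The rewritten kernel broadcasts in two steps: the root scatters a distinct nPeer * unitBytesPerBlock slice of each iteration window into one peer's scratch buffer (block bid serves peer bid % nPeer), the peers then exchange their slices with one another, and every non-root rank assembles the full window from scratch into recvbuff. The following is a minimal host-side sketch, not part of this PR, that checks the assembly-loop offset arithmetic tiles one iteration window exactly once; the launch shape (7 blocks of 1024 threads, 8 ranks per node, 64 bytes per thread) is assumed purely for illustration.

// Host-side check of the non-root assembly offsets (illustrative only).
#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  const size_t nPeer = 7, gridDim = 7, blockDim = 1024;  // assumed launch shape
  const size_t unitBytesPerBlock = 64 * blockDim;
  const size_t unitBytes = unitBytesPerBlock * gridDim * nPeer;  // bytes per loop iteration

  // Final copy loop on a non-root rank: each block copies nPeer slices of
  // unitBytesPerBlock bytes from its scratch buffer into recvbuff.
  std::vector<int> hits(unitBytes / unitBytesPerBlock, 0);
  for (size_t bid = 0; bid < gridDim; ++bid) {
    const size_t chunkId = bid % nPeer;
    const size_t chunkGroundId = bid / nPeer;
    for (size_t peerId = 0; peerId < nPeer; ++peerId) {
      const size_t offset =
          chunkId * unitBytesPerBlock + (peerId + chunkGroundId * nPeer) * unitBytesPerBlock * nPeer;
      hits[offset / unitBytesPerBlock]++;  // record which slice this copy lands in
    }
  }
  for (size_t s = 0; s < hits.size(); ++s)
    if (hits[s] != 1) printf("slice %zu copied %d times\n", s, hits[s]);
  printf("checked %zu slices\n", hits.size());
  return 0;
}

The tiling only closes when gridDim.x is a multiple of nPeer, which appears consistent with the launcher above choosing nBlocks of 7, 14, or 28 (multiples of nPeer = 7 on an 8-GPU node).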
apps/nccl/src/nccl.cu (2 changes: 1 addition & 1 deletion)
@@ -861,5 +861,5 @@ ncclResult_t ncclMemFree(void* ptr) {
 
   // Pointer not found
   WARN("Pointer not found");
-  return ncclInvalidUsage;
+  return ncclSuccess;
 }
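This change relaxes ncclMemFree: a pointer the allocator does not track now only logs the "Pointer not found" warning and returns ncclSuccess instead of ncclInvalidUsage. A hypothetical caller-side sketch of the visible effect follows; the ncclMemAlloc/ncclMemFree signatures are the standard NCCL allocator API, and the double-free pattern is illustrative only, not from the PR.

#include <nccl.h>

int main() {
  void* p = nullptr;
  ncclMemAlloc(&p, 1 << 20);         // allocate through the NCCL allocator
  ncclMemFree(p);                    // tracked pointer: freed, returns ncclSuccess
  ncclResult_t rc = ncclMemFree(p);  // untracked now: logs "Pointer not found"
  return rc == ncclSuccess ? 0 : 1;  // exits 0 after this change
}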