Commit 0c7ed2c

Add ncclBcast / ncclBroadcast support (#419)
A simple broadcast using a scratch buffer, with an option to use an executor.
1 parent d8d0dfb commit 0c7ed2c
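
For orientation, here is a minimal host-side sketch (not part of this commit) of how the two new entry points are called. It uses only the standard NCCL API; communicator and stream setup are assumed to happen elsewhere (e.g. via ncclCommInitRank), and error checking is elided.

#include <cuda_runtime.h>

#include "nccl.h"

// Hypothetical usage sketch: broadcast `count` floats from rank `root`.
void broadcastExample(ncclComm_t comm, cudaStream_t stream, int root, size_t count) {
  float *in = nullptr, *out = nullptr;
  cudaMalloc(&in, count * sizeof(float));
  cudaMalloc(&out, count * sizeof(float));

  // In-place: ncclBcast forwards to ncclBroadcast with sendbuff == recvbuff,
  // which selects the IsOutOfPlace=false kernel in the fallback path below.
  ncclBcast(in, count, ncclFloat32, root, comm, stream);

  // Out-of-place: distinct buffers select the IsOutOfPlace=true instantiation.
  ncclBroadcast(in, out, count, ncclFloat32, root, comm, stream);

  cudaStreamSynchronize(stream);
  cudaFree(in);
  cudaFree(out);
}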

File tree

2 files changed: +264 −6 lines

apps/nccl/src/broadcast.hpp (new file, +171 lines)
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#ifndef BROADCAST_HPP_
#define BROADCAST_HPP_

#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/gpu.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel_device.hpp>

#include "common.hpp"

template <bool IsOutOfPlace>
__global__ void __launch_bounds__(1024, 1)
    broadcast6(void* sendbuff, void* scratchbuff, void* recvbuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
               size_t channelOutOffset, size_t rank, [[maybe_unused]] size_t worldSize, size_t root,
               size_t nRanksPerNode, size_t nelemsPerGPU) {
  const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
  const size_t lid = tid % WARP_SIZE;
  const size_t wid = tid / WARP_SIZE;

  const size_t nThread = blockDim.x * gridDim.x;
  const size_t nWarp = nThread / WARP_SIZE;
  const size_t nPeer = nRanksPerNode - 1;
  const size_t chanOffset = nPeer * blockIdx.x;

  __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> smChans[NRANKS_PER_NODE - 1];
  if (threadIdx.x < nPeer) {
    smChans[threadIdx.x] = smChannels[chanOffset + threadIdx.x];
    smChans[threadIdx.x].relaxedSignal();
    smChans[threadIdx.x].wait();
  }
  __syncthreads();

  // Index of the root within this rank's peer-channel array (which skips self);
  // nPeer serves as a sentinel meaning "this rank is the root".
  const size_t peerRootIdx = (root == rank) ? nPeer : ((root < rank) ? root : (root - 1));

  const size_t bytesPerGPU = nelemsPerGPU * sizeof(int);
  const size_t bytes = bytesPerGPU;
  size_t unitBytesPerThread;
  if (bytes * nPeer >= nThread * 64) {
    unitBytesPerThread = 64;
  } else {
    unitBytesPerThread = 16;
  }
  const size_t unitBytesPerBlock = unitBytesPerThread * blockDim.x;
  const size_t unitBytes = unitBytesPerBlock * gridDim.x;
  const size_t nLoop = bytes / unitBytes;

  const size_t maxScratchSizeToUse = (SCRATCH_SIZE - unitBytes);
  const size_t nLoopToSync = (maxScratchSizeToUse / unitBytes) + 1;

  size_t scratchSub = 0;

  // First loop will always fit the scratch size.
  if (nLoop > 0) {
    // First loop unrolling
    const size_t offset = blockIdx.x * unitBytesPerBlock;
    if (rank == root) {
      char* send_ = reinterpret_cast<char*>(sendbuff);
      for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
        char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
        smChans[peerIdx].copy<16, false>(dst + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
        __syncthreads();
        if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
      }
      if constexpr (IsOutOfPlace) {
        char* recv_ = reinterpret_cast<char*>(recvbuff);
        smChans[0].copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
      }
    } else {  // rank != root.
      if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
      __syncthreads();
      char* recv_ = reinterpret_cast<char*>(recvbuff);
      char* scratch_ = reinterpret_cast<char*>(scratchbuff);  // My scratchbuff.
      smChans[peerRootIdx].copy<16, false>(recv_ + offset, scratch_ + offset, unitBytesPerBlock, threadIdx.x,
                                           blockDim.x);
    }
  }

  for (size_t i = 1; i < nLoop; ++i) {
    const size_t offset = blockIdx.x * unitBytesPerBlock + i * unitBytes;
    if (i % nLoopToSync == 0) {  // Sync to reuse scratch buff
      // Unsigned wraparound: adding scratchSub to an offset rewinds it to the
      // start of the scratch buffer before the scratch is overwritten again.
      scratchSub = -i * unitBytes;
      deviceSyncer.sync(gridDim.x);
      if (threadIdx.x < nPeer) {
        smChans[threadIdx.x].relaxedSignal();
        smChans[threadIdx.x].wait();
      }
    }
    if (rank == root) {
      char* send_ = reinterpret_cast<char*>(sendbuff);
      for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
        char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
        smChans[peerIdx].copy<16, false>(dst + offset + scratchSub, send_ + offset, unitBytesPerBlock, threadIdx.x,
                                         blockDim.x);
        __syncthreads();
        if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
      }
      if constexpr (IsOutOfPlace) {
        char* recv_ = reinterpret_cast<char*>(recvbuff);
        smChans[0].copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
      }
    } else {  // rank != root.
      if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
      __syncthreads();
      char* recv_ = reinterpret_cast<char*>(recvbuff);
      char* scratch_ = reinterpret_cast<char*>(scratchbuff);  // My scratchbuff.
      smChans[peerRootIdx].copy<16, false>(recv_ + offset, scratch_ + offset + scratchSub, unitBytesPerBlock,
                                           threadIdx.x, blockDim.x);
    }
  }

  // Remainder loop will also fit the scratch buff since we subtract unitBytes from SCRATCH_SIZE.
  if (bytes % unitBytes > 0) {  // remainder.
    const size_t offset = blockIdx.x * unitBytesPerBlock + nLoop * unitBytes;
    const size_t remainBytes = (offset < bytes) ? (bytes - offset) : 0;
    if (remainBytes > 0) {
      if (rank == root) {
        char* send_ = reinterpret_cast<char*>(sendbuff);
        for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
          char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
          smChans[peerIdx].copy<16, true>(dst + offset + scratchSub, send_ + offset, remainBytes, threadIdx.x,
                                          blockDim.x);
          __syncthreads();
          if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
        }
        if constexpr (IsOutOfPlace) {
          char* recv_ = reinterpret_cast<char*>(recvbuff);
          smChans[0].copy<16, true>(recv_ + offset, send_ + offset, remainBytes, threadIdx.x, blockDim.x);
        }
      } else {  // rank != root.
        if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
        __syncthreads();
        char* recv_ = reinterpret_cast<char*>(recvbuff);
        char* scratch_ = reinterpret_cast<char*>(scratchbuff);  // My scratchbuff.
        smChans[peerRootIdx].copy<16, true>(recv_ + offset, scratch_ + offset + scratchSub, remainBytes, threadIdx.x,
                                            blockDim.x);
      }
    }  // remainBytes > 0.
  }

  deviceSyncer.sync(gridDim.x);

  if (threadIdx.x < nPeer) {
    smChans[threadIdx.x].relaxedSignal();
    smChans[threadIdx.x].wait();
  }
}

template <bool IsOutOfPlace, typename T>
cudaError_t broadcast(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
                      size_t channelOutOffset, int rank, int nRanksPerNode, int root, int worldSize, size_t nelems,
                      cudaStream_t stream) {
  int nBlocks = 7;
  // if (nelems <= 4096) {
  //   nBlocks = 7;
  // } else if (nelems <= 32768) {
  //   nBlocks = 14;
  // } else if (nelems >= 2097152) {
  //   nBlocks = 35;
  // }
  broadcast6<IsOutOfPlace><<<nBlocks, 1024, 0, stream>>>((void*)buff, (void*)scratch, (void*)resultBuff, smChannels,
                                                         channelOutOffset, rank, worldSize, root, nRanksPerNode,
                                                         nelems * sizeof(T) / sizeof(int));
  return cudaGetLastError();
}

#endif  // BROADCAST_HPP_
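
To make the kernel's tiling concrete, the size arithmetic above can be reproduced as a small host-side sketch. The SCRATCH_SIZE value and launch shape below are illustrative stand-ins for the real constants defined in common.hpp, not values taken from this commit.

#include <cstddef>
#include <cstdio>

int main() {
  const size_t SCRATCH_SIZE = 2 * 1024 * 1024;  // assumed scratch capacity
  const size_t nThread = 7 * 1024;              // 7 blocks x 1024 threads, as launched
  const size_t nPeer = 7;                       // 8 ranks per node -> 7 peers
  const size_t bytes = 32 * 1024 * 1024;        // payload per GPU

  // Large messages move 64 B per thread per step, small ones 16 B.
  const size_t unitBytesPerThread = (bytes * nPeer >= nThread * 64) ? 64 : 16;
  const size_t unitBytes = unitBytesPerThread * nThread;  // bytes moved per grid step
  const size_t nLoop = bytes / unitBytes;

  // The scratch buffer holds nLoopToSync steps; after that, every rank must
  // sync and rewind its scratch offset (via scratchSub) before continuing.
  const size_t nLoopToSync = (SCRATCH_SIZE - unitBytes) / unitBytes + 1;
  std::printf("unitBytes=%zu nLoop=%zu nLoopToSync=%zu\n", unitBytes, nLoop, nLoopToSync);
}

With these illustrative numbers the main loop runs 73 steps and rewinds the scratch offset every 4th one.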

apps/nccl/src/nccl.cu (+93 −6 lines)
@@ -15,6 +15,7 @@
 
 #include "allgather.hpp"
 #include "allreduce.hpp"
+#include "broadcast.hpp"
 #include "nccl.h"
 
 #define NCCL_API extern "C" __attribute__((visibility("default")))
@@ -530,14 +531,100 @@ NCCL_API ncclResult_t ncclReduce(const void*, void*, size_t, ncclDataType_t, ncc
   return ncclInternalError;
 }
 
-NCCL_API ncclResult_t ncclBcast(void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) {
-  // TODO: implement this function
-  return ncclInternalError;
+NCCL_API ncclResult_t ncclBcast(void* buff, size_t count, ncclDataType_t datatype, int root, ncclComm_t comm,
+                                cudaStream_t stream) {
+  return ncclBroadcast(buff, buff, count, datatype, root, comm, stream);
 }
 
-NCCL_API ncclResult_t ncclBroadcast(const void*, void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) {
-  // TODO: implement this function
-  return ncclInternalError;
+NCCL_API ncclResult_t ncclBroadcastFallback(const void* sendbuff, void* recvbuff, size_t sendcount,
+                                            ncclDataType_t datatype, int root, ncclComm_t comm, cudaStream_t stream) {
+  size_t bytes = sendcount * ncclTypeSize(datatype);
+  if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;
+
+  // Declaring variables
+  size_t recvBytes;
+  CUdeviceptr recvBasePtr;
+  MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
+  // size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
+  size_t offsetOut = 0;
+  // channelKey recvKey{(void*)recvBasePtr, recvBytes};
+  channelKey recvKey{(void*)0x0, 0};  // Just create the channel once.
+  int rank = comm->comm->bootstrap()->getRank();
+  int nRank = comm->comm->bootstrap()->getNranks();
+  mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
+
+  auto it = comm->channelOutInfos.find(recvKey);
+  if (it == comm->channelOutInfos.end()) {
+    // std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(
+    //     comm->comm, rank, const_cast<void*>((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc);
+    // std::vector<mscclpp::SmChannel> channels =
+    //     setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
+    std::vector<mscclpp::SmChannel> channels =
+        setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)recvBasePtr));
+    std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
+    std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
+                   [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
+    ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
+    it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
+  }
+
+  smChannels = it->second.smChannelDeviceHandles.get();
+  if ((char*)sendbuff == (char*)recvbuff) {
+    CUDACHECK(broadcast<false>((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, offsetOut,
+                               rank, NRANKS_PER_NODE, root, nRank, bytes / sizeof(int), stream));
+  } else {
+    CUDACHECK(broadcast<true>((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, offsetOut,
+                              rank, NRANKS_PER_NODE, root, nRank, bytes / sizeof(int), stream));
+  }
+
+  return ncclSuccess;
+}
+
+NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
+                                    int root, ncclComm_t comm, cudaStream_t stream) {
+  size_t bytes = count * ncclTypeSize(datatype);
+  if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;
+
+  int rank = comm->comm->bootstrap()->getRank();
+  int nRank = comm->comm->bootstrap()->getNranks();
+
+  std::vector<executionPlanInstance>& plans = comm->executionPlans["broadcast"];
+  std::shared_ptr<mscclpp::ExecutionPlan> plan;
+  void* basePtr = (char*)sendbuff;
+  bool inPlace = basePtr == recvbuff;
+  const size_t totalBytes = bytes;
+  for (const auto& p : plans) {
+    if (totalBytes >= p.key.minMessageSize && totalBytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
+      plan = p.plan;
+      break;
+    }
+  }
+
+  if (plan == nullptr) return ncclBroadcastFallback(sendbuff, recvbuff, count, datatype, root, comm, stream);
+
+  switch (datatype) {
+    case ncclFloat16:
+      comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, *plan,
+                              stream);
+      break;
+    case ncclFloat32:
+      comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32,
+                              *plan, stream);
+      break;
+    case ncclBfloat16:
+      comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes,
+                              mscclpp::DataType::BFLOAT16, *plan, stream);
+      break;
+    case ncclInt32:
+    case ncclUint32:
+      comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, *plan,
+                              stream);
+      break;
+    default:
+      return ncclInvalidArgument;
+  }
+
+  return ncclSuccess;
 }
 
 NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
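
The new ncclBroadcast first looks for a pre-loaded execution plan and only runs the kernel fallback when none matches. A simplified sketch of the selection rule follows; the struct is a stand-in for the key of the real executionPlanInstance type, shown here only to isolate the predicate.

#include <cstddef>

// The first registered "broadcast" plan whose [minMessageSize, maxMessageSize)
// window contains the message and whose in-place flag matches is chosen;
// otherwise ncclBroadcastFallback runs the scratch-buffer kernel.
struct PlanKey {
  size_t minMessageSize, maxMessageSize;
  bool isInPlace;
};

bool planMatches(const PlanKey& key, size_t totalBytes, bool inPlace) {
  return totalBytes >= key.minMessageSize && totalBytes < key.maxMessageSize && inPlace == key.isInPlace;
}

Note also that the fallback caches its channels under the fixed key {0x0, 0}: since the transfers always target the peers' pre-registered scratch buffers (comm->remoteScratchRegMemories) rather than the user's receive buffer, one channel set built on first use can be reused for every subsequent broadcast.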
