Commit: Add ncclBcast / ncclBroadcast support (#419)

A simple broadcast using a scratch buffer, with an option to use an executor.
SreevatsaAnantharamu authored Dec 19, 2024
1 parent d8d0dfb commit 0c7ed2c
Showing 2 changed files with 264 additions and 6 deletions.
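
The new entry points keep the standard NCCL broadcast signatures, so existing call sites need no changes. Below is a minimal caller sketch, not part of this commit; it assumes a communicator and stream were already created elsewhere (e.g. via ncclCommInitRank), and the helper name broadcastExample and the rootRank variable are illustrative only.

// Hypothetical caller sketch, not part of this commit. Assumes `comm` and `stream` were created
// elsewhere (e.g. MPI rank discovery + ncclCommInitRank) and `rootRank` is the broadcasting rank.
#include <cuda_runtime.h>
#include <nccl.h>

void broadcastExample(ncclComm_t comm, cudaStream_t stream, int rootRank) {
  const size_t count = 1 << 20;  // number of floats to broadcast
  float *sendbuff = nullptr, *recvbuff = nullptr;
  cudaMalloc(&sendbuff, count * sizeof(float));
  cudaMalloc(&recvbuff, count * sizeof(float));

  // Out-of-place: the root's sendbuff is replicated into every rank's recvbuff.
  ncclBroadcast(sendbuff, recvbuff, count, ncclFloat32, rootRank, comm, stream);

  // In-place: ncclBcast now forwards to ncclBroadcast with sendbuff == recvbuff.
  ncclBcast(recvbuff, count, ncclFloat32, rootRank, comm, stream);

  cudaStreamSynchronize(stream);
  cudaFree(sendbuff);
  cudaFree(recvbuff);
}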
171 changes: 171 additions & 0 deletions apps/nccl/src/broadcast.hpp
@@ -0,0 +1,171 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

#ifndef BROADCAST_HPP_
#define BROADCAST_HPP_

#include <mscclpp/concurrency_device.hpp>
#include <mscclpp/core.hpp>
#include <mscclpp/gpu.hpp>
#include <mscclpp/sm_channel.hpp>
#include <mscclpp/sm_channel_device.hpp>

#include "common.hpp"

template <bool IsOutOfPlace>
__global__ void __launch_bounds__(1024, 1)
    broadcast6(void* sendbuff, void* scratchbuff, void* recvbuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
               size_t channelOutOffset, size_t rank, [[maybe_unused]] size_t worldSize, size_t root,
               size_t nRanksPerNode, size_t nelemsPerGPU) {
  const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
  const size_t lid = tid % WARP_SIZE;
  const size_t wid = tid / WARP_SIZE;

  const size_t nThread = blockDim.x * gridDim.x;
  const size_t nWarp = nThread / WARP_SIZE;
  const size_t nPeer = nRanksPerNode - 1;
  const size_t chanOffset = nPeer * blockIdx.x;

  __shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> smChans[NRANKS_PER_NODE - 1];
  if (threadIdx.x < nPeer) {
    smChans[threadIdx.x] = smChannels[chanOffset + threadIdx.x];
    smChans[threadIdx.x].relaxedSignal();
    smChans[threadIdx.x].wait();
  }
  __syncthreads();

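  // Index of the root inside this rank's peer-channel array: every rank keeps one channel per
  // remote rank (self excluded), so on non-root ranks the root maps to `root` or `root - 1`;
  // on the root itself the value nPeer is an unused sentinel.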
  const size_t peerRootIdx = (root == rank) ? nPeer : ((root < rank) ? root : (root - 1));

  const size_t bytesPerGPU = nelemsPerGPU * sizeof(int);
  const size_t bytes = bytesPerGPU;
  size_t unitBytesPerThread;
  if (bytes * nPeer >= nThread * 64) {
    unitBytesPerThread = 64;
  } else {
    unitBytesPerThread = 16;
  }
  const size_t unitBytesPerBlock = unitBytesPerThread * blockDim.x;
  const size_t unitBytes = unitBytesPerBlock * gridDim.x;
  const size_t nLoop = bytes / unitBytes;

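  // Reserve one unit of headroom so the remainder pass below always fits in the scratch buffer,
  // and derive how many iterations may run before the scratch region has to be reused.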
  const size_t maxScratchSizeToUse = (SCRATCH_SIZE - unitBytes);
  const size_t nLoopToSync = (maxScratchSizeToUse / unitBytes) + 1;

  size_t scratchSub = 0;

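  // Per-chunk protocol: the root copies a chunk into every peer's scratch buffer and signals each
  // peer; every non-root rank waits for the root's signal and then copies the chunk from its own
  // scratch buffer into recvbuff.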
  // First loop will always fit the scratch size.
  if (nLoop > 0) {
    // First loop unrolling
    const size_t offset = blockIdx.x * unitBytesPerBlock;
    if (rank == root) {
      char* send_ = reinterpret_cast<char*>(sendbuff);
      for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
        char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
        smChans[peerIdx].copy<16, false>(dst + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
        __syncthreads();
        if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
      }
      if constexpr (IsOutOfPlace) {
        char* recv_ = reinterpret_cast<char*>(recvbuff);
        smChans[0].copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
      }

    } else {  // rank != root.
      if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
      __syncthreads();
      char* recv_ = reinterpret_cast<char*>(recvbuff);
      char* scratch_ = reinterpret_cast<char*>(scratchbuff);  // My scratchbuff.
      smChans[peerRootIdx].copy<16, false>(recv_ + offset, scratch_ + offset, unitBytesPerBlock, threadIdx.x,
                                           blockDim.x);
    }
  }

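  // Pipelined main loop: every nLoopToSync iterations all ranks synchronize so the scratch region
  // can be rewound and reused. scratchSub relies on unsigned wrap-around: adding it to `offset`
  // effectively subtracts the bytes already consumed, rebasing scratch offsets to the last sync point.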
  for (size_t i = 1; i < nLoop; ++i) {
    const size_t offset = blockIdx.x * unitBytesPerBlock + i * unitBytes;
    if (i % nLoopToSync == 0) {  // Sync to reuse scratch buff
      scratchSub = -i * unitBytes;
      deviceSyncer.sync(gridDim.x);
      if (threadIdx.x < nPeer) {
        smChans[threadIdx.x].relaxedSignal();
        smChans[threadIdx.x].wait();
      }
    }
    if (rank == root) {
      char* send_ = reinterpret_cast<char*>(sendbuff);
      for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
        char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
        smChans[peerIdx].copy<16, false>(dst + offset + scratchSub, send_ + offset, unitBytesPerBlock, threadIdx.x,
                                         blockDim.x);
        __syncthreads();
        if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
      }
      if constexpr (IsOutOfPlace) {
        char* recv_ = reinterpret_cast<char*>(recvbuff);
        smChans[0].copy<16, false>(recv_ + offset, send_ + offset, unitBytesPerBlock, threadIdx.x, blockDim.x);
      }
    } else {  // rank != root.
      if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
      __syncthreads();
      char* recv_ = reinterpret_cast<char*>(recvbuff);
      char* scratch_ = reinterpret_cast<char*>(scratchbuff);  // My scratchbuff.
      smChans[peerRootIdx].copy<16, false>(recv_ + offset, scratch_ + offset + scratchSub, unitBytesPerBlock,
                                           threadIdx.x, blockDim.x);
    }
  }

  // Remainder loop will also fit the scratch buff since we subtract unitBytes from SCRATCH_SIZE.
  if (bytes % unitBytes > 0) {  // remainder.
    const size_t offset = blockIdx.x * unitBytesPerBlock + nLoop * unitBytes;
    const size_t remainBytes = (offset < bytes) ? (bytes - offset) : 0;
    if (remainBytes > 0) {
      if (rank == root) {
        char* send_ = reinterpret_cast<char*>(sendbuff);
        for (size_t peerIdx = 0; peerIdx < nPeer; peerIdx++) {
          char* dst = reinterpret_cast<char*>(smChans[peerIdx].dst_);  // Peer's scratchbuff.
          smChans[peerIdx].copy<16, true>(dst + offset + scratchSub, send_ + offset, remainBytes, threadIdx.x,
                                          blockDim.x);
          __syncthreads();
          if (threadIdx.x == peerIdx) smChans[peerIdx].signal();
        }
        if constexpr (IsOutOfPlace) {
          char* recv_ = reinterpret_cast<char*>(recvbuff);
          smChans[0].copy<16, true>(recv_ + offset, send_ + offset, remainBytes, threadIdx.x, blockDim.x);
        }
      } else {  // rank != root.
        if (threadIdx.x == peerRootIdx) smChans[peerRootIdx].wait();
        __syncthreads();
        char* recv_ = reinterpret_cast<char*>(recvbuff);
        char* scratch_ = reinterpret_cast<char*>(scratchbuff);  // My scratchbuff.
        smChans[peerRootIdx].copy<16, true>(recv_ + offset, scratch_ + offset + scratchSub, remainBytes, threadIdx.x,
                                            blockDim.x);
      }
    }  // remainBytes > 0.
  }

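  // Final device-wide barrier plus a signal/wait handshake with every peer, so no rank returns
  // while another rank may still be reading data this rank produced.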
  deviceSyncer.sync(gridDim.x);

  if (threadIdx.x < nPeer) {
    smChans[threadIdx.x].relaxedSignal();
    smChans[threadIdx.x].wait();
  }
}

template <bool IsOutOfPlace, typename T>
cudaError_t broadcast(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
                      size_t channelOutOffset, int rank, int nRanksPerNode, int root, int worldSize, size_t nelems,
                      cudaStream_t stream) {
  int nBlocks = 7;
  // if (nelems <= 4096) {
  //   nBlocks = 7;
  // } else if (nelems <= 32768) {
  //   nBlocks = 14;
  // } else if (nelems >= 2097152) {
  //   nBlocks = 35;
  // }
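  // The kernel copies in 4-byte words, so the element count is converted to int-sized units below.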
  broadcast6<IsOutOfPlace><<<nBlocks, 1024, 0, stream>>>((void*)buff, (void*)scratch, (void*)resultBuff, smChannels,
                                                         channelOutOffset, rank, worldSize, root, nRanksPerNode,
                                                         nelems * sizeof(T) / sizeof(int));
  return cudaGetLastError();
}

#endif // BROADCAST_HPP_
99 changes: 93 additions & 6 deletions apps/nccl/src/nccl.cu
@@ -15,6 +15,7 @@

#include "allgather.hpp"
#include "allreduce.hpp"
#include "broadcast.hpp"
#include "nccl.h"

#define NCCL_API extern "C" __attribute__((visibility("default")))
@@ -530,14 +531,100 @@ NCCL_API ncclResult_t ncclReduce(const void*, void*, size_t, ncclDataType_t, ncc
  return ncclInternalError;
}

-NCCL_API ncclResult_t ncclBcast(void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) {
-  // TODO: implement this function
-  return ncclInternalError;

NCCL_API ncclResult_t ncclBcast(void* buff, size_t count, ncclDataType_t datatype, int root, ncclComm_t comm,
                                cudaStream_t stream) {
  return ncclBroadcast(buff, buff, count, datatype, root, comm, stream);
}

-NCCL_API ncclResult_t ncclBroadcast(const void*, void*, size_t, ncclDataType_t, int, ncclComm_t, cudaStream_t) {
-  // TODO: implement this function
-  return ncclInternalError;

NCCL_API ncclResult_t ncclBroadcastFallback(const void* sendbuff, void* recvbuff, size_t sendcount,
                                            ncclDataType_t datatype, int root, ncclComm_t comm, cudaStream_t stream) {
  size_t bytes = sendcount * ncclTypeSize(datatype);
  if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;

  // Declaring variables
  size_t recvBytes;
  CUdeviceptr recvBasePtr;
  MSCCLPP_CUTHROW(cuMemGetAddressRange(&recvBasePtr, &recvBytes, (CUdeviceptr)recvbuff));
  // size_t offsetOut = (char*)recvbuff - (char*)recvBasePtr;
  size_t offsetOut = 0;
  // channelKey recvKey{(void*)recvBasePtr, recvBytes};
  channelKey recvKey{(void*)0x0, 0};  // Just create the channel once.
  int rank = comm->comm->bootstrap()->getRank();
  int nRank = comm->comm->bootstrap()->getNranks();
  mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;

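  // The SM channels are created once (under a dummy {nullptr, 0} key) and point at each peer's
  // pre-registered scratch buffer, which is what the kernel writes into on the root.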
  auto it = comm->channelOutInfos.find(recvKey);
  if (it == comm->channelOutInfos.end()) {
    // std::vector<mscclpp::RegisteredMemory> remoteMemories = setupRemoteMemories(
    //     comm->comm, rank, const_cast<void*>((void*)recvBasePtr), recvBytes, mscclpp::Transport::CudaIpc);
    // std::vector<mscclpp::SmChannel> channels =
    //     setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
    std::vector<mscclpp::SmChannel> channels =
        setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)recvBasePtr));
    std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
    std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
                   [](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
    ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
    it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
  }

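  // In-place calls (sendbuff == recvbuff) launch the kernel with IsOutOfPlace = false, which skips
  // the root's local copy into recvbuff; out-of-place calls also replicate sendbuff into the root's recvbuff.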
  smChannels = it->second.smChannelDeviceHandles.get();
  if ((char*)sendbuff == (char*)recvbuff) {
    CUDACHECK(broadcast<false>((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, offsetOut,
                               rank, NRANKS_PER_NODE, root, nRank, bytes / sizeof(int), stream));
  } else {
    CUDACHECK(broadcast<true>((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, offsetOut,
                              rank, NRANKS_PER_NODE, root, nRank, bytes / sizeof(int), stream));
  }

  return ncclSuccess;
}

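// Prefer an MSCCL++ executor plan when one is registered for "broadcast" and matches the message
// size and in-place/out-of-place layout; otherwise fall back to the scratch-buffer kernel above.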
NCCL_API ncclResult_t ncclBroadcast(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,
                                    int root, ncclComm_t comm, cudaStream_t stream) {
  size_t bytes = count * ncclTypeSize(datatype);
  if (sendbuff == nullptr || recvbuff == nullptr || bytes == 0 || comm == nullptr) return ncclInvalidArgument;

  int rank = comm->comm->bootstrap()->getRank();
  int nRank = comm->comm->bootstrap()->getNranks();

  std::vector<executionPlanInstance>& plans = comm->executionPlans["broadcast"];
  std::shared_ptr<mscclpp::ExecutionPlan> plan;
  void* basePtr = (char*)sendbuff;
  bool inPlace = basePtr == recvbuff;
  const size_t totalBytes = bytes;
  for (const auto& p : plans) {
    if (totalBytes >= p.key.minMessageSize && totalBytes < p.key.maxMessageSize && inPlace == p.key.isInPlace) {
      plan = p.plan;
      break;
    }
  }

  if (plan == nullptr) return ncclBroadcastFallback(sendbuff, recvbuff, count, datatype, root, comm, stream);

  switch (datatype) {
    case ncclFloat16:
      comm->executor->execute(rank, (half*)sendbuff, (half*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT16, *plan,
                              stream);
      break;
    case ncclFloat32:
      comm->executor->execute(rank, (float*)sendbuff, (float*)recvbuff, bytes, bytes, mscclpp::DataType::FLOAT32, *plan,
                              stream);
      break;
    case ncclBfloat16:
      comm->executor->execute(rank, (__bfloat16*)sendbuff, (__bfloat16*)recvbuff, bytes, bytes,
                              mscclpp::DataType::BFLOAT16, *plan, stream);
      break;
    case ncclInt32:
    case ncclUint32:
      comm->executor->execute(rank, (int*)sendbuff, (int*)recvbuff, bytes, bytes, mscclpp::DataType::UINT32, *plan,
                              stream);
      break;
    default:
      return ncclInvalidArgument;
  }

  return ncclSuccess;
}

NCCL_API ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count, ncclDataType_t datatype,