Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for CPX mode on MI300X #446

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 115 additions & 1 deletion apps/nccl/src/allreduce.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -495,13 +495,122 @@ __global__ void __launch_bounds__(512, 1)
}
}

template <typename T>
__global__ void __launch_bounds__(512, 1)
allreduce8Read(T* buff, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelOutDataOffset,
int rank, int nRanksPerNode, int worldSize, size_t nelems) {
const int nPeer = nRanksPerNode - 1;
const size_t chanOffset = nPeer * blockIdx.x;
// assume (nelems * sizeof(T)) is divisible by (16 * worldSize)
const size_t nInt4 = nelems * sizeof(T) / sizeof(int4);
const size_t nInt4PerRank = nInt4 / worldSize;
auto smChans = smChannels + chanOffset;
auto smOutChans = smOutChannels + chanOffset;

int4* buff4 = reinterpret_cast<int4*>(buff);
int4* resultBuff4 = reinterpret_cast<int4*>(resultBuff);

// Distribute `nInt4PerRank` across all blocks with the unit size `unitNInt4`
constexpr size_t unitNInt4 = 512;
const size_t maxNInt4PerBlock =
(((nInt4PerRank + gridDim.x - 1) / gridDim.x) + unitNInt4 - 1) / unitNInt4 * unitNInt4;
size_t offsetOfThisBlock = maxNInt4PerBlock * blockIdx.x;
size_t nInt4OfThisBlock = maxNInt4PerBlock;
size_t nNeededBlocks = (nInt4PerRank + maxNInt4PerBlock - 1) / maxNInt4PerBlock;
constexpr size_t nInt4PerChunk = 1024 * 256 / sizeof(int4); // 256KB
if (blockIdx.x >= nNeededBlocks) {
nInt4OfThisBlock = 0;
} else if (blockIdx.x == nNeededBlocks - 1) {
nInt4OfThisBlock = nInt4PerRank - maxNInt4PerBlock * (nNeededBlocks - 1);
}

const size_t nItrs = nInt4OfThisBlock / nInt4PerChunk;
const size_t restNInt4 = nInt4OfThisBlock % nInt4PerChunk;

__shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> channels[NRANKS_PER_NODE - 1];
__shared__ mscclpp::DeviceHandle<mscclpp::SmChannel> outChannels[NRANKS_PER_NODE - 1];
const int lid = threadIdx.x % WARP_SIZE;

if (lid < nPeer) {
channels[lid] = smChans[lid];
outChannels[lid] = smOutChans[lid];
}
__syncwarp();

for (size_t itr = 0; itr < nItrs; itr++) {
if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
channels[threadIdx.x].signal();
channels[threadIdx.x].wait();
}
__syncthreads();

for (size_t idx = threadIdx.x; idx < nInt4PerChunk; idx += blockDim.x) {
int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
int4 val = channels[peerIdx].read<int4>(nInt4PerRank * rank + offsetOfThisBlock + idx);;
data = add_vectors<T>(val, data);
}
resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
data);
}
}
__syncthreads();

offsetOfThisBlock += nInt4PerChunk;
}

if (restNInt4 > 0) {
if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
outChannels[threadIdx.x].signal();
outChannels[threadIdx.x].wait();
}
__syncthreads();

for (size_t idx = threadIdx.x; idx < restNInt4; idx += blockDim.x) {
int4 data = buff4[nInt4PerRank * rank + idx + offsetOfThisBlock];
for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
int4 val = channels[peerIdx].read<int4>(nInt4PerRank * rank + offsetOfThisBlock + idx);;
data = add_vectors<T>(val, data);
}
resultBuff4[nInt4PerRank * rank + idx + offsetOfThisBlock] = data;
for (int peerIdx = 0; peerIdx < nPeer; peerIdx++) {
outChannels[peerIdx].write(nInt4PerRank * rank + idx + offsetOfThisBlock + channelOutDataOffset / sizeof(int4),
data);
}
}
__syncthreads();
}

if (threadIdx.x < static_cast<uint32_t>(nPeer)) {
outChannels[threadIdx.x].signal();
outChannels[threadIdx.x].wait();
}
__syncthreads();

}

template <typename T>
cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels,
mscclpp::DeviceHandle<mscclpp::SmChannel>* smScrChannels,
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels, size_t channelInOffset,
size_t channelOutOffset, size_t channelScratchOffset, int rank, int nRanksPerNode, int worldSize,
size_t nelems, cudaStream_t stream) {
static uint32_t flag = 1;

int readAllred = 0;
char* envValue = nullptr;

envValue = std::getenv("MSCCLPP_READ_ALLRED");

if (envValue != nullptr) {
if (atoi(envValue) == 1) {
readAllred = 1;
}
}

if (sizeof(T) * nelems < worldSize * sizeof(int)) {
int nBlocks = 7;
int nThreadsPerBlock = 32;
Expand All @@ -528,9 +637,14 @@ cudaError_t allreduce(T* buff, T* scratch, T* resultBuff, mscclpp::DeviceHandle<
} else {
int nBlocks = 35;
int nThreadsPerBlock = 512;
allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smChannels, smOutChannels,
if (!readAllred) {
allreduce8<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, scratch, resultBuff, smScrChannels, smOutChannels,
channelOutOffset, channelScratchOffset, rank, nRanksPerNode,
worldSize, nelems);
} else {
allreduce8Read<<<nBlocks, nThreadsPerBlock, 0, stream>>>(buff, resultBuff, smChannels, smOutChannels,
channelOutOffset, rank, nRanksPerNode, worldSize, nelems);
}
}

return cudaGetLastError();
Expand Down
29 changes: 19 additions & 10 deletions apps/nccl/src/nccl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ struct hash<channelKey> {

struct ChannelInfo {
std::vector<mscclpp::SmChannel> smChannels;
std::vector<mscclpp::SmChannel> smChannels1;
std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
std::shared_ptr<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles1;
};

struct ncclComm {
Expand Down Expand Up @@ -212,26 +214,32 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
channelKey recvKey{(void*)recvBasePtr, recvBytes};
mscclpp::DeviceHandle<mscclpp::SmChannel>* smChannels = nullptr;
mscclpp::DeviceHandle<mscclpp::SmChannel>* smOutChannels = nullptr;
mscclpp::DeviceHandle<mscclpp::SmChannel>* smScrChannels = nullptr;

// Creating the channels
if (count * ncclTypeSize(datatype) <= (1 << 20)) {
auto sendIt = comm->channelScratchInfos.find(sendKey);
if (sendIt == comm->channelScratchInfos.end()) {
std::vector<mscclpp::SmChannel> channels =
setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)};
sendIt = comm->channelScratchInfos.emplace(sendKey, channelInfo).first;
}

smChannels = sendIt->second.smChannelDeviceHandles.get();
} else {
std::vector<mscclpp::RegisteredMemory> remoteMemories;
std::vector<mscclpp::RegisteredMemory> remoteMemories1;

auto sendIt = comm->channelInInfos.find(sendKey);
if (sendIt == comm->channelInInfos.end()) {
std::vector<mscclpp::SmChannel> channels =
setupSmChannels(comm, comm->remoteScratchRegMemories, const_cast<void*>((void*)sendBasePtr));
ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
remoteMemories1 =
setupRemoteMemories(comm->comm, rank, (void*)sendBasePtr, sendBytes, mscclpp::Transport::CudaIpc);
std::vector<mscclpp::SmChannel> channels1 =
setupSmChannels(comm, remoteMemories1, const_cast<void*>((void*)sendBasePtr));
ChannelInfo channelInfo{channels, channels1, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels1)};
sendIt = comm->channelInInfos.emplace(sendKey, channelInfo).first;
}

Expand All @@ -241,33 +249,34 @@ static ncclResult_t ncclAllReduceFallback(const void* sendbuff, void* recvbuff,
setupRemoteMemories(comm->comm, rank, (void*)recvBasePtr, recvBytes, mscclpp::Transport::CudaIpc);
std::vector<mscclpp::SmChannel> outChannels =
setupSmChannels(comm, remoteMemories, const_cast<void*>((void*)recvBasePtr));
ChannelInfo channelInfo{outChannels, setupSmChannelDeviceHandles(outChannels)};
ChannelInfo channelInfo{outChannels, outChannels, setupSmChannelDeviceHandles(outChannels), setupSmChannelDeviceHandles(outChannels)};
recvIt = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
}

smChannels = sendIt->second.smChannelDeviceHandles.get();
smChannels = sendIt->second.smChannelDeviceHandles1.get();
smOutChannels = recvIt->second.smChannelDeviceHandles.get();
smScrChannels = sendIt->second.smChannelDeviceHandles.get();
}

switch (datatype) {
case ncclFloat16:
CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smOutChannels,
CUDACHECK(allreduce((half*)sendbuff, (half*)comm->scratchBuff.get(), (half*)recvbuff, smChannels, smScrChannels, smOutChannels,
offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
comm->comm->bootstrap()->getNranks(), count, stream));
break;
case ncclFloat32:
CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels,
CUDACHECK(allreduce((float*)sendbuff, (float*)comm->scratchBuff.get(), (float*)recvbuff, smChannels, smScrChannels,
smOutChannels, offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(),
NRANKS_PER_NODE, comm->comm->bootstrap()->getNranks(), count, stream));
break;
case ncclBfloat16:
CUDACHECK(allreduce((__bfloat16*)sendbuff, (__bfloat16*)comm->scratchBuff.get(), (__bfloat16*)recvbuff,
smChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
smChannels, smScrChannels, smOutChannels, offsetIn, offsetOut, offsetScratch, rank, NRANKS_PER_NODE,
comm->comm->bootstrap()->getNranks(), count, stream));
break;
case ncclInt32:
case ncclUint32:
CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smOutChannels,
CUDACHECK(allreduce((int*)sendbuff, (int*)comm->scratchBuff.get(), (int*)recvbuff, smChannels, smScrChannels, smOutChannels,
offsetIn, offsetOut, offsetScratch, comm->comm->bootstrap()->getRank(), NRANKS_PER_NODE,
comm->comm->bootstrap()->getNranks(), count, stream));
break;
Expand Down Expand Up @@ -314,7 +323,7 @@ static ncclResult_t ncclAllGatherFallback(const void* sendbuff, void* recvbuff,
std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
[](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)};
it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
}

Expand Down Expand Up @@ -597,7 +606,7 @@ NCCL_API ncclResult_t ncclBroadcastFallback(const void* sendbuff, void* recvbuff
std::vector<mscclpp::DeviceHandle<mscclpp::SmChannel>> smChannelDeviceHandles;
std::transform(channels.begin(), channels.end(), std::back_inserter(smChannelDeviceHandles),
[](const mscclpp::SmChannel& smChannel) { return mscclpp::deviceHandle(smChannel); });
ChannelInfo channelInfo{channels, setupSmChannelDeviceHandles(channels)};
ChannelInfo channelInfo{channels, channels, setupSmChannelDeviceHandles(channels), setupSmChannelDeviceHandles(channels)};
it = comm->channelOutInfos.emplace(recvKey, channelInfo).first;
}

Expand Down