Skip to content

Commit

Permalink
Lazily create streams for CudaIpcConnection (#449)
Browse files Browse the repository at this point in the history
  • Loading branch information
chhwang authored Jan 15, 2025
1 parent 869cdba commit d12247b
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 31 deletions.
3 changes: 3 additions & 0 deletions include/mscclpp/gpu_utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@ struct AvoidCudaGraphCaptureGuard {

/// A RAII wrapper around cudaStream_t that will call cudaStreamDestroy on destruction.
struct CudaStreamWithFlags {
CudaStreamWithFlags() : stream_(nullptr) {}
CudaStreamWithFlags(unsigned int flags);
~CudaStreamWithFlags();
void set(unsigned int flags);
bool empty() const;
operator cudaStream_t() const { return stream_; }
cudaStream_t stream_;
};
Expand Down
17 changes: 13 additions & 4 deletions src/connection.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <mscclpp/npkit/npkit.hpp>
#endif

#include <mscclpp/env.hpp>
#include <mscclpp/utils.hpp>
#include <sstream>
#include <thread>
Expand Down Expand Up @@ -40,7 +41,8 @@ int Connection::getMaxWriteQueueSize() { return maxWriteQueueSize; }

// CudaIpcConnection

CudaIpcConnection::CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, cudaStream_t stream)
CudaIpcConnection::CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint,
std::shared_ptr<CudaStreamWithFlags> stream)
: Connection(localEndpoint.maxWriteQueueSize()), stream_(stream) {
if (localEndpoint.transport() != Transport::CudaIpc) {
throw mscclpp::Error("Cuda IPC connection can only be made from a Cuda IPC endpoint", ErrorCode::InvalidUsage);
Expand Down Expand Up @@ -74,7 +76,9 @@ void CudaIpcConnection::write(RegisteredMemory dst, uint64_t dstOffset, Register
char* dstPtr = (char*)dst.data();
char* srcPtr = (char*)src.data();

MSCCLPP_CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, stream_));
if (!env().cudaIpcUseDefaultStream && stream_->empty()) stream_->set(cudaStreamNonBlocking);

MSCCLPP_CUDATHROW(cudaMemcpyAsync(dstPtr + dstOffset, srcPtr + srcOffset, size, cudaMemcpyDeviceToDevice, *stream_));
INFO(MSCCLPP_P2P, "CudaIpcConnection write: from %p to %p, size %lu", srcPtr + srcOffset, dstPtr + dstOffset, size);

#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_CONN_CUDA_IPC_WRITE_EXIT)
Expand All @@ -92,7 +96,9 @@ void CudaIpcConnection::updateAndSync(RegisteredMemory dst, uint64_t dstOffset,
*src = newValue;
uint64_t* dstPtr = reinterpret_cast<uint64_t*>(reinterpret_cast<char*>(dst.data()) + dstOffset);

MSCCLPP_CUDATHROW(cudaMemcpyAsync(dstPtr, src, sizeof(uint64_t), cudaMemcpyHostToDevice, stream_));
if (!env().cudaIpcUseDefaultStream && stream_->empty()) stream_->set(cudaStreamNonBlocking);

MSCCLPP_CUDATHROW(cudaMemcpyAsync(dstPtr, src, sizeof(uint64_t), cudaMemcpyHostToDevice, *stream_));
INFO(MSCCLPP_P2P, "CudaIpcConnection atomic write: from %p to %p, %lu -> %lu", src, dstPtr + dstOffset, oldValue,
newValue);

Expand All @@ -109,8 +115,11 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) {
if (timeoutUsec >= 0) {
INFO(MSCCLPP_P2P, "CudaIpcConnection flush: timeout is not supported, ignored");
}

if (!env().cudaIpcUseDefaultStream && stream_->empty()) stream_->set(cudaStreamNonBlocking);

AvoidCudaGraphCaptureGuard guard;
MSCCLPP_CUDATHROW(cudaStreamSynchronize(stream_));
MSCCLPP_CUDATHROW(cudaStreamSynchronize(*stream_));
INFO(MSCCLPP_P2P, "CudaIpcConnection flushing connection");

#if defined(ENABLE_NPKIT) && defined(ENABLE_NPKIT_EVENT_CONN_CUDA_IPC_FLUSH_EXIT)
Expand Down
7 changes: 2 additions & 5 deletions src/context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

namespace mscclpp {

Context::Impl::Impl() {}
Context::Impl::Impl() : ipcStream_(std::make_shared<CudaStreamWithFlags>()) {}

IbCtx* Context::Impl::getIbContext(Transport ibTransport) {
// Find IB context or create it
Expand Down Expand Up @@ -43,10 +43,7 @@ MSCCLPP_API_CPP std::shared_ptr<Connection> Context::connect(Endpoint localEndpo
if (remoteEndpoint.transport() != Transport::CudaIpc) {
throw mscclpp::Error("Local transport is CudaIpc but remote is not", ErrorCode::InvalidUsage);
}
if (!(pimpl_->ipcStream_)) {
pimpl_->ipcStream_ = std::make_shared<CudaStreamWithFlags>(cudaStreamNonBlocking);
}
conn = std::make_shared<CudaIpcConnection>(localEndpoint, remoteEndpoint, cudaStream_t(*(pimpl_->ipcStream_)));
conn = std::make_shared<CudaIpcConnection>(localEndpoint, remoteEndpoint, pimpl_->ipcStream_);
} else if (AllIBTransports.has(localEndpoint.transport())) {
if (!AllIBTransports.has(remoteEndpoint.transport())) {
throw mscclpp::Error("Local transport is IB but remote is not", ErrorCode::InvalidUsage);
Expand Down
19 changes: 0 additions & 19 deletions src/cuda_utils.cc

This file was deleted.

21 changes: 21 additions & 0 deletions src/gpu_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,27 @@

namespace mscclpp {

AvoidCudaGraphCaptureGuard::AvoidCudaGraphCaptureGuard() : mode_(cudaStreamCaptureModeRelaxed) {
MSCCLPP_CUDATHROW(cudaThreadExchangeStreamCaptureMode(&mode_));
}

AvoidCudaGraphCaptureGuard::~AvoidCudaGraphCaptureGuard() { (void)cudaThreadExchangeStreamCaptureMode(&mode_); }

CudaStreamWithFlags::CudaStreamWithFlags(unsigned int flags) {
MSCCLPP_CUDATHROW(cudaStreamCreateWithFlags(&stream_, flags));
}

CudaStreamWithFlags::~CudaStreamWithFlags() {
if (!empty()) (void)cudaStreamDestroy(stream_);
}

void CudaStreamWithFlags::set(unsigned int flags) {
if (!empty()) throw Error("CudaStreamWithFlags already set", ErrorCode::InternalError);
MSCCLPP_CUDATHROW(cudaStreamCreateWithFlags(&stream_, flags));
}

bool CudaStreamWithFlags::empty() const { return stream_ == nullptr; }

namespace detail {

/// set memory access permission to read-write
Expand Down
6 changes: 3 additions & 3 deletions src/include/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#define MSCCLPP_CONNECTION_HPP_

#include <mscclpp/core.hpp>
#include <mscclpp/gpu.hpp>
#include <mscclpp/gpu_utils.hpp>

#include "communicator.hpp"
#include "context.hpp"
Expand All @@ -16,10 +16,10 @@
namespace mscclpp {

class CudaIpcConnection : public Connection {
cudaStream_t stream_;
std::shared_ptr<CudaStreamWithFlags> stream_;

public:
CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, cudaStream_t stream);
CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, std::shared_ptr<CudaStreamWithFlags> stream);

Transport transport() override;

Expand Down

0 comments on commit d12247b

Please sign in to comment.