Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions include/mscclpp/core.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,12 @@ class Context : public std::enable_shared_from_this<Context> {
/// @return The newly created endpoint.
Endpoint createEndpoint(EndpointConfig config);

/// Placeholder accessor for a named, type-erased value.
///
/// NOTE(review): the body is a stub — @p name is ignored and a null pointer is returned for every key.
///
/// @param name Key to look up (currently unused).
/// @return Always an empty `std::shared_ptr<void>`.
std::shared_ptr<void> get(std::string name) {
  (void)name;
  return std::shared_ptr<void>{};
}
/// Placeholder mutator for a named, type-erased value.
///
/// NOTE(review): the body is a stub — both arguments are discarded and nothing is stored.
///
/// @param name  Key to associate with the value (currently unused).
/// @param value Value to store (currently unused).
void set(std::string name, std::shared_ptr<void> value) {
  (void)name;
  (void)value;
}

private:
Context();

/// Establish a connection between two endpoints. While this method immediately returns a connection object, the
/// connection is only safe to use after the corresponding connection on the remote endpoint has been established.
/// This method must be called on both endpoints to establish a connection.
Expand All @@ -609,14 +615,12 @@ class Context : public std::enable_shared_from_this<Context> {
/// @return A shared pointer to the connection.
std::shared_ptr<Connection> connect(Endpoint localEndpoint, Endpoint remoteEndpoint);

private:
Context();

struct Impl;
std::unique_ptr<Impl> pimpl_;

friend class RegisteredMemory;
friend class Endpoint;
friend class Communicator;
};

/// SemaphoreStub object only used for constructing Semaphore, not for direct use by the user.
Expand Down Expand Up @@ -848,7 +852,8 @@ class Communicator {
/// @param tag The tag to use for identifying the send and receive.
/// @return A future of shared pointer to the connection.
///
std::shared_future<std::shared_ptr<Connection>> connect(EndpointConfig localConfig, int remoteRank, int tag = 0);
std::shared_future<std::shared_ptr<Connection>> connect(EndpointConfig localConfig, int remoteRank, int tag = 0,
std::string connName = "core");

[[deprecated("Use connect(localConfig, remoteRank, tag) instead. This will be removed in a future release.")]] std::
shared_future<std::shared_ptr<Connection>>
Expand Down
9 changes: 4 additions & 5 deletions python/mscclpp/core_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,7 @@ void register_core(nb::module_& m) {
return self->registerMemory((void*)ptr, size, transports);
},
nb::arg("ptr"), nb::arg("size"), nb::arg("transports"))
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"))
.def("connect", &Context::connect, nb::arg("local_endpoint"), nb::arg("remote_endpoint"));
.def("create_endpoint", &Context::createEndpoint, nb::arg("config"));

nb::class_<SemaphoreStub>(m, "SemaphoreStub")
.def(nb::init<std::shared_ptr<Connection>>(), nb::arg("connection"))
Expand Down Expand Up @@ -213,9 +212,9 @@ void register_core(nb::module_& m) {
.def("send_memory", &Communicator::sendMemory, nb::arg("memory"), nb::arg("remoteRank"), nb::arg("tag") = 0)
.def("recv_memory", &Communicator::recvMemory, nb::arg("remoteRank"), nb::arg("tag") = 0)
.def("connect",
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(EndpointConfig, int, int)>(
&Communicator::connect),
nb::arg("localConfig"), nb::arg("remoteRank"), nb::arg("tag") = 0)
static_cast<std::shared_future<std::shared_ptr<Connection>> (Communicator::*)(
EndpointConfig, int, int, std::string)>(&Communicator::connect),
nb::arg("localConfig"), nb::arg("remoteRank"), nb::arg("tag") = 0, nb::arg("connName") = "core")
.def(
"connect",
[](Communicator* self, int remoteRank, int tag, EndpointConfig localConfig) {
Expand Down
16 changes: 11 additions & 5 deletions src/communicator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include "communicator.hpp"

#include "api.h"
#include "connection.hpp"
#include "debug.h"

namespace mscclpp {
Expand All @@ -15,6 +16,10 @@ Communicator::Impl::Impl(std::shared_ptr<Bootstrap> bootstrap, std::shared_ptr<C
} else {
context_ = context;
}
ConnectionFactory::registerConnection(
"core", [](std::shared_ptr<Context> context, Endpoint localEndpoint, Endpoint remoteEndpoint) {
return context->connect(localEndpoint, remoteEndpoint);
});
}

void Communicator::Impl::setLastRecvItem(int remoteRank, int tag, std::shared_ptr<BaseRecvItem> item) {
Expand Down Expand Up @@ -100,7 +105,8 @@ MSCCLPP_API_CPP std::shared_future<RegisteredMemory> Communicator::recvMemory(in
}

MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::connect(EndpointConfig localConfig,
int remoteRank, int tag) {
int remoteRank, int tag,
std::string connName) {
auto localEndpoint = context()->createEndpoint(localConfig);

if (remoteRank == bootstrap()->getRank()) {
Expand All @@ -115,17 +121,17 @@ MSCCLPP_API_CPP std::shared_future<std::shared_ptr<Connection>> Communicator::co

bootstrap()->send(localEndpoint.serialize(), remoteRank, tag);

auto future =
std::async(std::launch::deferred, [this, remoteRank, tag, lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag),
localEndpoint = std::move(localEndpoint)]() mutable {
auto future = std::async(
std::launch::deferred, [this, remoteRank, tag, connName, lastRecvItem = pimpl_->getLastRecvItem(remoteRank, tag),
localEndpoint = std::move(localEndpoint)]() mutable {
if (lastRecvItem) {
// Recursive call to the previous receive items
lastRecvItem->wait();
}
std::vector<char> data;
bootstrap()->recv(data, remoteRank, tag);
auto remoteEndpoint = Endpoint::deserialize(data);
auto connection = context()->connect(localEndpoint, remoteEndpoint);
auto connection = ConnectionFactory::createConnection(connName, context(), localEndpoint, remoteEndpoint);
pimpl_->connectionInfos_[connection.get()] = {remoteRank, tag};
return connection;
});
Expand Down
25 changes: 25 additions & 0 deletions src/ext/connection/connection.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#include "ext/connection/connection.hpp"

#include "connection.hpp"

namespace mscclpp {


/// Validate the requested ranges and hand the transfer to the scheduler.
///
/// @param dst       Destination registered memory.
/// @param dstOffset Offset into @p dst at which the transfer starts.
/// @param src       Source registered memory.
/// @param srcOffset Offset into @p src at which the transfer starts.
/// @param size      Length of the transfer.
/// @throws Error with ErrorCode::InvalidUsage if `[offset, offset + size)` does not fit inside the
///         corresponding memory's size() for either @p dst or @p src.
void IndirectConnection::write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
                               uint64_t size) {
  // Compare against the remaining room rather than computing `offset + size`: that sum can wrap around in
  // uint64_t for large inputs and let an out-of-range request pass the check.
  const uint64_t dstSize = dst.size();
  const uint64_t srcSize = src.size();
  const bool dstInRange = size <= dstSize && dstOffset <= dstSize - size;
  const bool srcInRange = size <= srcSize && srcOffset <= srcSize - size;
  if (!dstInRange || !srcInRange) {
    throw Error("IndirectConnection::write out of bounds", ErrorCode::InvalidUsage);
  }
  scheduler_ptr_->sched(dst, dstOffset, src, srcOffset, size);
}

/// Synchronize via the scheduler's sync().
///
/// @param timeoutUsec Must be -1; any other value is rejected.
/// @throws std::runtime_error if @p timeoutUsec is not -1.
void IndirectConnection::flush(int64_t timeoutUsec) {
  constexpr int64_t kNoTimeout = -1;
  if (timeoutUsec == kNoTimeout) {
    scheduler_ptr_->sync();
    return;
  }
  throw std::runtime_error("IndirectConnection does not support timeout in flush");
}
/// @return Always Transport::CudaIpc.
Transport IndirectConnection::transport() const {
  const Transport local = Transport::CudaIpc;
  return local;
}
/// @return Always Transport::CudaIpc.
Transport IndirectConnection::remoteTransport() const {
  const Transport remote = Transport::CudaIpc;
  return remote;
}

} // namespace mscclpp
17 changes: 17 additions & 0 deletions src/ext/connection/example.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include "connection.hpp"
#include "ext/connection/connection.hpp"

void test() {
auto context = mscclpp::Context::create();
auto localEndpoint = context->createEndpoint({mscclpp::Transport::CudaIpc});
auto remoteEndpoint = context->createEndpoint({mscclpp::Transport::CudaIpc});
mscclpp::Device fwd(mscclpp::DeviceType::GPU, 2);
std::shared_ptr<mscclpp::ConnectionScheduler> scheduler = std::make_shared<mscclpp::DefaultConnectionScheduler>(context, fwd);
context->set("scheduler", scheduler);
mscclpp::ConnectionFactory::registerConnection(
"indirect", [context](std::shared_ptr<mscclpp::Context> ctx, mscclpp::Endpoint local, mscclpp::Endpoint remote) {
std::shared_ptr<mscclpp::ConnectionScheduler> scheduler = std::static_pointer_cast<mscclpp::ConnectionScheduler>(context->get("scheduler"));
return std::make_shared<mscclpp::IndirectConnection>(ctx, local, scheduler);
});
auto connection = mscclpp::ConnectionFactory::createConnection("indirect", context, localEndpoint, remoteEndpoint);
}
24 changes: 24 additions & 0 deletions src/include/connection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,30 @@

namespace mscclpp {

class ConnectionFactory {
private:
using ConnectionCreator = std::function<std::shared_ptr<Connection>(std::shared_ptr<Context>, Endpoint, Endpoint)>;
static std::unordered_map<std::string, ConnectionCreator>& getRegistry() {
static std::unordered_map<std::string, ConnectionCreator> registry;
return registry;
}

public:
static void registerConnection(const std::string& connName, ConnectionCreator creator) {
getRegistry()[connName] = creator;
}

static std::shared_ptr<Connection> createConnection(const std::string& connName, std::shared_ptr<Context> context,
Endpoint localEndpoint, Endpoint remoteEndpoint) {
auto& registry = getRegistry();
auto it = registry.find(connName);
if (it != registry.end()) {
return it->second(context, localEndpoint, remoteEndpoint);
}
throw std::runtime_error("Unknown connection type: " + connName);
}
};

class CudaIpcConnection : public Connection {
private:
std::shared_ptr<CudaIpcStream> stream_;
Expand Down
52 changes: 52 additions & 0 deletions src/include/ext/connection/connection.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#pragma once

#include <cstdint>
#include <memory>
#include <stdexcept>
#include <utility>

#include "mscclpp/core.hpp"
#include "mscclpp/gpu_utils.hpp"

namespace mscclpp {

/// Abstract interface used by IndirectConnection: write() forwards to sched() and flush() forwards to sync().
class ConnectionScheduler {
 public:
  /// Virtual destructor so that implementations can be destroyed safely through a base-class pointer.
  virtual ~ConnectionScheduler() = default;

  /// Schedule a transfer of length @p size from @p src (at @p srcOffset) to @p dst (at @p dstOffset).
  virtual void sched(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
                     uint64_t size) = 0;

  /// Synchronization hook; the exact semantics are defined by the implementation.
  virtual void sync() = 0;
};

/// ConnectionScheduler implementation that stores a Context and a Device.
///
/// NOTE(review): sched() and sync() are currently empty placeholders.
class DefaultConnectionScheduler : public ConnectionScheduler {
 public:
  /// @param context Context kept by this scheduler (shared ownership).
  /// @param device  Device kept by this scheduler.
  DefaultConnectionScheduler(std::shared_ptr<Context> context, Device device)
      // `context` is a by-value sink parameter: move it into the member to avoid an extra reference-count bump.
      : context_(std::move(context)), device_(device) {}

  void sched(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
             uint64_t size) override {
    // Implementation for scheduling tasks
  }

  void sync() override {
    // Implementation for synchronizing tasks
  }

 private:
  std::shared_ptr<Context> context_;
  Device device_;
};

class IndirectConnection : public Connection {
std::shared_ptr<ConnectionScheduler> scheduler_ptr_;

public:
IndirectConnection(std::shared_ptr<Context> context, Endpoint localEndpoint,
std::shared_ptr<ConnectionScheduler> scheduler)
: Connection(context, localEndpoint), scheduler_ptr_(scheduler) {
if (scheduler_ptr_ == nullptr) {
throw std::runtime_error("Scheduler not set in context");
}
}
void write(RegisteredMemory dst, uint64_t dstOffset, RegisteredMemory src, uint64_t srcOffset,
uint64_t size) override;
void flush(int64_t timeoutUsec = -1) override;
Transport transport() const override;
Transport remoteTransport() const override;

virtual void updateAndSync(RegisteredMemory dst, uint64_t dstOffset, uint64_t *src, uint64_t newValue) override {
throw std::runtime_error("IndirectConnection does not support updateAndSync");
}
};
} // namespace mscclpp
Loading