Skip to content

Commit 80abce5

Browse files
caiomcbr and chhwang authored
Flushing Proxy Channels at CPU side upon reaching the Inflight Request Limit (#415)
Co-authored-by: Changho Hwang <[email protected]>
1 parent 1989d4b commit 80abce5

File tree

6 files changed

+55
-13
lines changed

6 files changed

+55
-13
lines changed

include/mscclpp/core.hpp

Lines changed: 34 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,11 @@ class Endpoint {
388388
/// @return The transport used.
389389
Transport transport();
390390

391+
/// Get the maximum write queue size.
392+
///
393+
/// @return The maximum number of write requests that can be queued.
394+
int maxWriteQueueSize();
395+
391396
/// Serialize the Endpoint object to a vector of characters.
392397
///
393398
/// @return A vector of characters representing the serialized Endpoint object.
@@ -416,6 +421,10 @@ class Endpoint {
416421
/// Represents a connection between two processes.
417422
class Connection {
418423
public:
424+
/// Constructor.
425+
/// @param maxWriteQueueSize The maximum number of write requests that can be queued.
426+
Connection(int maxWriteQueueSize) : maxWriteQueueSize(maxWriteQueueSize){};
427+
419428
virtual ~Connection() = default;
420429

421430
/// Write data from a source @ref RegisteredMemory to a destination @ref RegisteredMemory.
@@ -454,10 +463,16 @@ class Connection {
454463
/// @return name of @ref transport() -> @ref remoteTransport()
455464
std::string getTransportName();
456465

466+
/// Get the maximum write queue size.
467+
///
468+
/// @return The maximum number of write requests that can be queued.
469+
int getMaxWriteQueueSize();
470+
457471
protected:
458472
// Internal methods for getting implementation pointers.
459473
static std::shared_ptr<RegisteredMemory::Impl> getImpl(RegisteredMemory& memory);
460474
static std::shared_ptr<Endpoint::Impl> getImpl(Endpoint& memory);
475+
int maxWriteQueueSize;
461476
};
462477

463478
/// Used to configure an endpoint.
@@ -468,18 +483,29 @@ struct EndpointConfig {
468483
static const int DefaultMaxWrPerSend = 64;
469484

470485
Transport transport;
471-
int ibMaxCqSize = DefaultMaxCqSize;
472-
int ibMaxCqPollNum = DefaultMaxCqPollNum;
473-
int ibMaxSendWr = DefaultMaxSendWr;
474-
int ibMaxWrPerSend = DefaultMaxWrPerSend;
475-
476-
/// Default constructor. Sets transport to Transport::Unknown.
477-
EndpointConfig() : transport(Transport::Unknown) {}
486+
int ibMaxCqSize;
487+
int ibMaxCqPollNum;
488+
int ibMaxSendWr;
489+
int ibMaxWrPerSend;
490+
int maxWriteQueueSize;
478491

479492
/// Constructor that takes a transport and optional settings; any setting that is not given takes its default value.
480493
///
481494
/// @param transport The transport to use.
482-
EndpointConfig(Transport transport) : transport(transport) {}
495+
/// @param ibMaxCqSize The maximum completion queue size.
496+
/// @param ibMaxCqPollNum The maximum completion queue poll number.
497+
/// @param ibMaxSendWr The maximum send work requests.
498+
/// @param ibMaxWrPerSend The maximum work requests per send.
499+
/// @param maxWriteQueueSize The maximum write queue size.
500+
EndpointConfig(Transport transport = Transport::Unknown, int ibMaxCqSize = DefaultMaxCqSize,
501+
int ibMaxCqPollNum = DefaultMaxCqPollNum, int ibMaxSendWr = DefaultMaxSendWr,
502+
int ibMaxWrPerSend = DefaultMaxWrPerSend, int maxWriteQueueSize = -1)
503+
: transport(transport),
504+
ibMaxCqSize(ibMaxCqSize),
505+
ibMaxCqPollNum(ibMaxCqPollNum),
506+
ibMaxSendWr(ibMaxSendWr),
507+
ibMaxWrPerSend(ibMaxWrPerSend),
508+
maxWriteQueueSize(maxWriteQueueSize) {}
483509
};
484510

485511
/// Represents a context for communication. This provides a low-level interface for forming connections in use-cases

include/mscclpp/proxy_channel.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class ProxyService : public BaseProxyService {
7272
std::vector<RegisteredMemory> memories_;
7373
std::shared_ptr<Proxy> proxy_;
7474
int deviceNumaNode;
75+
std::unordered_map<std::shared_ptr<Connection>, int> inflightRequests;
7576

7677
void bindThread();
7778

src/connection.cc

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,10 +36,12 @@ std::string Connection::getTransportName() {
3636
TransportNames[static_cast<int>(this->remoteTransport())];
3737
}
3838

39+
int Connection::getMaxWriteQueueSize() { return maxWriteQueueSize; }
40+
3941
// CudaIpcConnection
4042

4143
CudaIpcConnection::CudaIpcConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, cudaStream_t stream)
42-
: stream_(stream) {
44+
: Connection(localEndpoint.maxWriteQueueSize()), stream_(stream) {
4345
if (localEndpoint.transport() != Transport::CudaIpc) {
4446
throw mscclpp::Error("Cuda IPC connection can only be made from a Cuda IPC endpoint", ErrorCode::InvalidUsage);
4547
}
@@ -119,7 +121,9 @@ void CudaIpcConnection::flush(int64_t timeoutUsec) {
119121
// IBConnection
120122

121123
IBConnection::IBConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, Context& context)
122-
: transport_(localEndpoint.transport()),
124+
: Connection(localEndpoint.maxWriteQueueSize() != -1 ? localEndpoint.maxWriteQueueSize()
125+
: EndpointConfig::DefaultMaxCqSize),
126+
transport_(localEndpoint.transport()),
123127
remoteTransport_(remoteEndpoint.transport()),
124128
dummyAtomicSource_(std::make_unique<uint64_t>(0)) {
125129
qp = getImpl(localEndpoint)->ibQp_;
@@ -231,7 +235,10 @@ void IBConnection::flush(int64_t timeoutUsec) {
231235

232236
EthernetConnection::EthernetConnection(Endpoint localEndpoint, Endpoint remoteEndpoint, uint64_t sendBufferSize,
233237
uint64_t recvBufferSize)
234-
: abortFlag_(0), sendBufferSize_(sendBufferSize), recvBufferSize_(recvBufferSize) {
238+
: Connection(localEndpoint.maxWriteQueueSize()),
239+
abortFlag_(0),
240+
sendBufferSize_(sendBufferSize),
241+
recvBufferSize_(recvBufferSize) {
235242
// Validating Transport Protocol
236243
if (localEndpoint.transport() != Transport::Ethernet || remoteEndpoint.transport() != Transport::Ethernet) {
237244
throw mscclpp::Error("Ethernet connection can only be made from Ethernet endpoints", ErrorCode::InvalidUsage);

src/endpoint.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
namespace mscclpp {
1414

1515
Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
16-
: transport_(config.transport), hostHash_(getHostHash()) {
16+
: transport_(config.transport), hostHash_(getHostHash()), maxWriteQueueSize_(config.maxWriteQueueSize) {
1717
if (AllIBTransports.has(transport_)) {
1818
ibLocal_ = true;
1919
ibQp_ = contextImpl.getIbContext(transport_)
@@ -34,6 +34,8 @@ Endpoint::Impl::Impl(EndpointConfig config, Context::Impl& contextImpl)
3434

3535
MSCCLPP_API_CPP Transport Endpoint::transport() { return pimpl_->transport_; }
3636

37+
MSCCLPP_API_CPP int Endpoint::maxWriteQueueSize() { return pimpl_->maxWriteQueueSize_; }
38+
3739
MSCCLPP_API_CPP std::vector<char> Endpoint::serialize() {
3840
std::vector<char> data;
3941
std::copy_n(reinterpret_cast<char*>(&pimpl_->transport_), sizeof(pimpl_->transport_), std::back_inserter(data));

src/include/endpoint.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ struct Endpoint::Impl {
2020

2121
Transport transport_;
2222
uint64_t hostHash_;
23+
int maxWriteQueueSize_;
2324

2425
// The following are only used for IB and are undefined for other transports.
2526
bool ibLocal_;

src/proxy_channel.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,21 +70,26 @@ ProxyHandlerResult ProxyService::handleTrigger(ProxyTrigger triggerRaw) {
7070
std::shared_ptr<Host2DeviceSemaphore> semaphore = semaphores_[trigger->fields.chanId];
7171

7272
auto result = ProxyHandlerResult::Continue;
73+
int maxWriteQueueSize = semaphore->connection()->getMaxWriteQueueSize();
7374

7475
if (trigger->fields.type & TriggerData) {
7576
RegisteredMemory& dst = memories_[trigger->fields.dstMemoryId];
7677
RegisteredMemory& src = memories_[trigger->fields.srcMemoryId];
7778
semaphore->connection()->write(dst, trigger->fields.dstOffset, src, trigger->fields.srcOffset,
7879
trigger->fields.size);
80+
inflightRequests[semaphore->connection()]++;
7981
}
8082

8183
if (trigger->fields.type & TriggerFlag) {
8284
semaphore->signal();
85+
inflightRequests[semaphore->connection()]++;
8386
}
8487

85-
if (trigger->fields.type & TriggerSync) {
88+
if (trigger->fields.type & TriggerSync ||
89+
(maxWriteQueueSize != -1 && inflightRequests[semaphore->connection()] > maxWriteQueueSize)) {
8690
semaphore->connection()->flush();
8791
result = ProxyHandlerResult::FlushFifoTailAndContinue;
92+
inflightRequests[semaphore->connection()] = 0;
8893
}
8994

9095
return result;

0 commit comments

Comments
 (0)