4 changes: 3 additions & 1 deletion csrc/engine/rdma/memory_pool.cpp
@@ -22,7 +22,7 @@ int RDMAMemoryPool::register_memory_region(const std::string& mr_key, uintptr_t

SLIME_ASSERT(mr, " Failed to register memory " << data_ptr);

-    SLIME_LOG_DEBUG("Memory region: " << (void*)data_ptr << " -- " << (void*)(data_ptr + length)
+    SLIME_LOG_DEBUG("Memory region: " << mr_key << ", " << (void*)data_ptr << " -- " << (void*)(data_ptr + length)
<< ", Device name: " << pd_->context->device->dev_name << ", Length: " << length
<< " (" << length / 1024 / 1024 << " MB)"
<< ", Permission: " << access_rights << ", LKey: " << mr->lkey
@@ -47,6 +47,7 @@ int RDMAMemoryPool::register_remote_memory_region(const std::string& mr_key,
{
std::unique_lock<std::mutex> lock(remote_mrs_mutex_);
remote_mrs_[mr_key] = remote_mr_t(addr, length, rkey);
+    SLIME_LOG_DEBUG("Remote memory region registered: " << mr_key << ", " << addr << ", " << length << ", " << rkey << ".");
return 0;
}

@@ -55,6 +56,7 @@ int RDMAMemoryPool::register_remote_memory_region(const std::string& mr_key, con
std::unique_lock<std::mutex> lock(remote_mrs_mutex_);
remote_mrs_[mr_key] =
remote_mr_t(mr_info["addr"].get<uintptr_t>(), mr_info["length"].get<size_t>(), mr_info["rkey"].get<uint32_t>());
+    SLIME_LOG_DEBUG("Remote memory region registered: " << mr_key << ", " << mr_info << ".");
return 0;
}
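
For context, here is a minimal caller-side sketch of the JSON overload above. It assumes the nlohmann::json type implied by mr_info["addr"].get<uintptr_t>(); the pool instance, key, and values are hypothetical:

    #include <nlohmann/json.hpp>

    // Hypothetical usage; "kv_cache" and the numeric values are illustrative.
    nlohmann::json mr_info = {
        {"addr", uintptr_t{0x7f0000100000}},
        {"length", size_t{1} << 20},
        {"rkey", uint32_t{0x1234}}
    };
    pool.register_remote_memory_region("kv_cache", mr_info);
    // The new debug line then logs the key together with addr/length/rkey.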

38 changes: 18 additions & 20 deletions csrc/engine/rdma/rdma_buffer.cpp
@@ -5,27 +5,26 @@ namespace slime {

void RDMABuffer::send()
{
-    send_pending_ = true;
-    send_completed_ = false;
-
-    end_point_->addSendTask(data_info, [this]() {
-        std::unique_lock<std::mutex> lock(send_mutex_);
-        send_completed_ = true;
-        send_pending_ = false;
-        send_cv_.notify_all();
-    });
+    endpoint_->addSendTask(shared_from_this());
}

void RDMABuffer::recv()
{
-    recv_pending_ = true;
-    recv_completed_ = false;
-    end_point_->addRecvTask(data_info, [this]() {
-        std::unique_lock<std::mutex> lock(recv_mutex_);
-        recv_completed_ = true;
-        recv_pending_ = false;
-        recv_cv_.notify_all();
-    });
+    endpoint_->addRecvTask(shared_from_this());
}

+void RDMABuffer::send_done_callback()
+{
+    std::unique_lock<std::mutex> lock(send_mutex_);
+    ++send_completed_;
+    send_cv_.notify_all();
+}
+
+void RDMABuffer::recv_done_callback()
+{
+    std::unique_lock<std::mutex> lock(recv_mutex_);
+    ++recv_completed_;
+    recv_cv_.notify_all();
+}

bool RDMABuffer::waitSend()
@@ -35,10 +34,9 @@ bool RDMABuffer::waitSend()
if (send_completed_)
return send_completed_;

-    send_cv_.wait(lock, [this]() { return send_completed_; });
+    send_cv_.wait(lock, [this]() { return send_completed_ > 0; });
send_pending_ = false;
SLIME_LOG_INFO("complete to send the data.");

return send_completed_;
}

@@ -50,7 +48,7 @@ bool RDMABuffer::waitRecv()
return recv_completed_;

// waiting for the recv complete...
-    recv_cv_.wait(lock, [this]() { return recv_completed_; });
+    recv_cv_.wait(lock, [this]() { return recv_completed_ > 0; });
recv_pending_ = false;
SLIME_LOG_INFO("complete to recv the data.");

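The refactor above moves completion signaling out of per-call lambdas and into *_done_callback() hooks that the endpoint invokes, with waiters blocking until an atomic counter turns positive. A standalone sketch of that handshake pattern (simplified; no RDMA machinery):

    #include <atomic>
    #include <condition_variable>
    #include <mutex>

    // Minimal model of the send-side handshake; the recv side is symmetric.
    std::mutex send_mutex;
    std::condition_variable send_cv;
    std::atomic<int> send_completed{0};

    // The endpoint's completion path calls this once per finished transfer.
    void send_done_callback()
    {
        std::unique_lock<std::mutex> lock(send_mutex);
        ++send_completed;
        send_cv.notify_all();
    }

    // The buffer's owner blocks here until at least one completion arrived.
    bool wait_send()
    {
        std::unique_lock<std::mutex> lock(send_mutex);
        send_cv.wait(lock, [] { return send_completed > 0; });
        return send_completed > 0;
    }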
46 changes: 34 additions & 12 deletions csrc/engine/rdma/rdma_buffer.h
@@ -13,51 +13,73 @@
#include <unordered_map>
#include <vector>

#include "rdma_common.h"

namespace slime {

class RDMAEndpoint;

-class RDMABuffer : public std::enable_shared_from_this<RDMABuffer> {
+class RDMABuffer: public std::enable_shared_from_this<RDMABuffer> {
friend class RDMAEndpoint;

public:
-    RDMABuffer(std::shared_ptr<RDMAEndpoint> end_point,
+    RDMABuffer(std::shared_ptr<RDMAEndpoint> endpoint, storage_view_batch_t& batch):
+        endpoint_(endpoint), storage_view_batch_(std::move(batch))
+    {
+    }

+    RDMABuffer(std::shared_ptr<RDMAEndpoint> endpoint,
std::vector<uintptr_t> ptrs,
std::vector<size_t> offset,
std::vector<size_t> data_size)
{
batch_size_ = ptrs.size();
for (uint32_t i = 0; i < batch_size_; ++i) {
-        data_info.push_back(std::make_tuple(ptrs[i], data_size[i], offset[i]));
+        storage_view_t view{.data_ptr = ptrs[i], .storage_offset = offset[i], .length = data_size[i]};
+        storage_view_batch_.push_back(view);
}
-    end_point_ = end_point;
+    endpoint_ = endpoint;
}

~RDMABuffer() = default;

-    void send();
+    const size_t batchSize()
+    {
+        return storage_view_batch_.size();
+    }

+    const storage_view_batch_t& storageViewBatch()
+    {
+        return storage_view_batch_;
+    }
+
+    void send();
void recv();

bool waitSend();

bool waitRecv();

+    void send_done_callback();
+    void recv_done_callback();

private:
-    std::shared_ptr<RDMAEndpoint> end_point_;
+    std::shared_ptr<RDMAEndpoint> endpoint_;

// <tensor_ptrs_, tensor_size_, offset>
// tensor_ptrs: the pointer of the tensor
// tensor_size: the length of the tensor
// offset: the offset of the transmitted tensor
-    std::vector<std::tuple<uintptr_t, size_t, size_t>> data_info;
+    // std::vector<std::tuple<uintptr_t, size_t, size_t>> data_info;

+    storage_view_batch_t storage_view_batch_;

size_t batch_size_;

-    bool send_pending_{false};
-    bool recv_pending_{false};
+    std::atomic<int> send_pending_{0};
+    std::atomic<int> recv_pending_{0};

-    bool send_completed_{false};
-    bool recv_completed_{false};
+    std::atomic<int> send_completed_{0};
+    std::atomic<int> recv_completed_{0};

std::condition_variable send_cv_;
std::condition_variable recv_cv_;
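Putting the new header together, a hedged usage sketch of the revised interface (assumes an already-connected std::shared_ptr<RDMAEndpoint> named endpoint; addresses and sizes are illustrative):

    // Illustrative only; endpoint setup and memory registration are elided.
    std::vector<uintptr_t> ptrs{0x7f0000100000, 0x7f0000200000};
    std::vector<size_t> offsets{0, 0};
    std::vector<size_t> sizes{4096, 8192};

    auto buffer = std::make_shared<slime::RDMABuffer>(endpoint, ptrs, offsets, sizes);
    buffer->send();      // enqueues itself via endpoint_->addSendTask(shared_from_this())
    buffer->waitSend();  // blocks until the endpoint fires send_done_callback()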
17 changes: 17 additions & 0 deletions csrc/engine/rdma/rdma_common.h
@@ -0,0 +1,17 @@
+#pragma once
+
+#include <cstdint>
+#include <cstdlib>
+#include <vector>
+
+namespace slime {
+
+typedef struct StorageView {
+    uintptr_t data_ptr;
+    size_t storage_offset;
+    size_t length;
+} storage_view_t;
+
+using storage_view_batch_t = std::vector<storage_view_t>;
+
+}
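
A quick sketch of how the new view types compose; the designated initializers mirror the constructor in rdma_buffer.h above (addresses and lengths are made up):

    // Two views over already-registered allocations.
    slime::storage_view_batch_t batch;
    batch.push_back(slime::storage_view_t{.data_ptr = 0x7f0000100000,
                                          .storage_offset = 0,
                                          .length = 4096});
    batch.push_back(slime::storage_view_t{.data_ptr = 0x7f0000200000,
                                          .storage_offset = 512,
                                          .length = 8192});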
4 changes: 3 additions & 1 deletion csrc/engine/rdma/rdma_context.cpp
@@ -429,9 +429,10 @@ int64_t RDMAContext::post_send_batch(int qpi, RDMAAssignmentSharedPtr assign)
memset(&wr[i], 0, sizeof(ibv_send_wr));
wr[i].wr_id =
(i == batch_size - 1) ? (uintptr_t)(new callback_info_with_qpi_t{assign->callback_info_, qpi}) : 0;
-        wr[i].opcode = IBV_WR_SEND;
+        wr[i].opcode = ASSIGN_OP_2_IBV_WR_OP.at(assign->opcode_);
wr[i].sg_list = &sge[i];
wr[i].num_sge = 1;
+        wr[i].imm_data = (i == batch_size - 1) ? assign->imm_data_ : UNDEFINED_IMM_DATA;
wr[i].send_flags = (i == batch_size - 1) ? IBV_SEND_SIGNALED : 0;
wr[i].next = (i == batch_size - 1) ? nullptr : &wr[i + 1];
}
@@ -584,6 +585,7 @@ int64_t RDMAContext::cq_poll_handle()
callback_with_qpi->callback_info_->callback_(status_code, wc[i].imm_data);
break;
case OpCode::SEND:
+        case OpCode::SEND_WITH_IMM:
callback_with_qpi->callback_info_->callback_(status_code, wc[i].imm_data);
break;
case OpCode::RECV:
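ASSIGN_OP_2_IBV_WR_OP itself is outside this diff. Judging from the opcodes these hunks touch, a plausible shape would be the following; this is an assumption about the repo's table, not its actual definition:

    #include <infiniband/verbs.h>
    #include <map>

    // Hypothetical sketch; the real map lives elsewhere in the repo and may
    // cover additional opcodes (e.g. RDMA write variants).
    static const std::map<OpCode, ibv_wr_opcode> ASSIGN_OP_2_IBV_WR_OP = {
        {OpCode::SEND, IBV_WR_SEND},
        {OpCode::SEND_WITH_IMM, IBV_WR_SEND_WITH_IMM},
    };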
7 changes: 6 additions & 1 deletion csrc/engine/rdma/rdma_context.h
@@ -83,6 +83,11 @@ class RDMAContext {
SLIME_LOG_DEBUG("RDMAContext deconstructed")
}

+    struct ibv_mr* get_mr(const std::string& mr_key)
+    {
+        return memory_pool_->get_mr(mr_key);
+    }

/* Initialize */
int64_t init(const std::string& dev_name, uint8_t ib_port, const std::string& link_type);

@@ -204,7 +209,7 @@ class RDMAContext {
size_t qp_list_len_{1};
qp_management_t** qp_management_;

-    int last_qp_selection_ = -1;
+    int last_qp_selection_{-1};
int select_qpi()
{
// Simplest round robin, we could enrich it in the future
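The select_qpi() body is truncated in this view. A plausible round-robin completion, consistent with last_qp_selection_ and qp_list_len_ above (an assumption, not the repo's actual code):

    // Hypothetical body for the truncated function above.
    int select_qpi()
    {
        last_qp_selection_ = (last_qp_selection_ + 1) % static_cast<int>(qp_list_len_);
        return last_qp_selection_;
    }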