Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ class Val;
f(SdpaFwdOp); \
f(SdpaBwdOp); \
f(EmbeddingFwdOp); \
f(CollectivePermute); \
f(Communication); \
f(P2PCommunication);
#define DISPATCH_FOR_ALL_KIR_EXPRS(f) \
Expand Down
32 changes: 32 additions & 0 deletions csrc/host_ir/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,38 @@ void HostIrEvaluator::handle(ShareMemHandles* share_mem_handles) {
ipc_handle_cache_.exchangeHandles(share_mem_handles->communications());
}

// Posts a CollectivePermute: this rank sends its input tensor to sendPeer()
// and receives its output tensor from recvPeer(). The peer indices are Vals,
// so they are evaluated at run time via expr_evaluator_ (e.g. they may depend
// on a host loop index). The resulting work handle is stored in works_ so a
// later Wait can block on it.
void HostIrEvaluator::handle(CollectivePermute* communication) {
  NVF_ERROR(
      communicator_ != nullptr && communicator_->is_available(),
      "A valid communicator must be provided");

  at::Tensor input_tensor = getKnownTensorOrUndefined(communication->input(0));
  at::Tensor output_tensor =
      getKnownTensorOrUndefined(communication->output(0));

#ifndef NDEBUG
  // Debug-only: catch mismatches between the concrete tensors and the IR's
  // expected sizes/strides.
  validateSizesAndStrides(
      {input_tensor, output_tensor},
      {communication->in(), communication->out()},
      expr_evaluator_);
#endif

  CommunicatorBackend backend_type = communication->backend();
  // CollectivePermute is only supported with NCCL backend because
  // UCC does not support coalescing. Surface that reason in the check so a
  // failure is actionable for the user.
  NVF_CHECK_EQ(
      backend_type,
      CommunicatorBackend::kNccl,
      "CollectivePermute is only supported with the NCCL backend because "
      "UCC does not support coalescing.");
  c10d::Backend* backend =
      communicator_->getBackendForTeam(communication->team(), backend_type);
  works_[communication] = postSingleCommunication(
      communication,
      communicator_->deviceId(),
      backend,
      input_tensor,
      output_tensor,
      expr_evaluator_.evaluate(communication->sendPeer()).as<int64_t>(),
      expr_evaluator_.evaluate(communication->recvPeer()).as<int64_t>());
}

void HostIrEvaluator::handle(Communication* communication) {
NVF_ERROR(
communicator_ != nullptr && communicator_->is_available(),
Expand Down
1 change: 1 addition & 0 deletions csrc/host_ir/evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch {
void handle(Synchronize*) override;
void handle(PostOnStream*) override;
void handle(LaunchKernel*) override;
void handle(CollectivePermute*) override;
void handle(Communication*) override;
void handle(P2PCommunication*) override;
void handle(MoeDispatch*) override;
Expand Down
10 changes: 7 additions & 3 deletions csrc/host_ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,9 +258,13 @@ Wait::Wait(IrBuilderPasskey passkey, Expr* expr)
this,
"must be registered in a HostIrContainer");
NVF_ERROR(
(expr->isOneOf<Communication, P2PCommunication, EndCoalescing>()),
expr,
" must be a Communication, a P2PCommunication, or a EndCoalescing");
(expr->isOneOf<
Communication,
CollectivePermute,
P2PCommunication,
EndCoalescing>()),
"Got: ",
expr);
}

NVFUSER_DEFINE_CLONE_AND_CREATE(Wait)
Expand Down
40 changes: 37 additions & 3 deletions csrc/host_ir/lower_to_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,33 @@ void lowerToAllToAll(
backend));
}

// Lowers a ring-style resharding from `input_tv` to `output_tv` into a single
// CollectivePermute appended to `comms`. The stream-parallel iter domain of
// `output_tv` must be defined by a Swizzle1D; the swizzle's parallel type,
// together with `root` (the host loop index, if any) and `my_device_idx`,
// determines which peers this rank sends to and receives from.
void lowerToCollectivePermute(
    TensorView* input_tv,
    TensorView* output_tv,
    const CommunicatorBackend backend,
    std::vector<Expr*>& comms,
    Val* root,
    DeviceIdxType my_device_idx) {
  NVF_ERROR_EQ(
      input_tv->getDeviceMesh(),
      output_tv->getDeviceMesh(),
      "CollectivePermute sender and receiver meshes must be the same. Given ",
      input_tv->getDeviceMesh(),
      " and ",
      output_tv->getDeviceMesh());

  IterDomain* stream_id =
      getShardedIterDomain(output_tv, ParallelType::Stream, DomainType::kLoop);
  // Guard the dereferences below: a missing stream id or a non-swizzle
  // definition would otherwise crash with no diagnostic.
  NVF_ERROR(
      stream_id != nullptr,
      "Expected a stream-parallelized iter domain in the loop domain of ",
      output_tv);
  NVF_ERROR(
      stream_id->definition() != nullptr &&
          stream_id->definition()->isA<Swizzle1D>(),
      "Expected the stream iter domain to be defined by a Swizzle1D. Got: ",
      stream_id);
  Swizzle1D* swizzle = stream_id->definition()->as<Swizzle1D>();
  ParallelType pt = swizzle->parallelType();

  // The swizzle's parallel type selects the ring pattern; dispatchSwizzle1D
  // maps (loop index, device index, mesh) to this rank's {recv, send} peers.
  const auto& [recv_peer, send_peer] =
      dispatchSwizzle1D(root, my_device_idx, pt, input_tv->getDeviceMesh());
  Team team = input_tv->getDeviceMesh().vector();
  comms.push_back(IrBuilder::create<CollectivePermute>(
      output_tv, input_tv, team, send_peer, recv_peer, backend));
}

IterDomain* getLogicalFromLoopId(TensorView* tv, IterDomain* loop_id) {
std::vector<IterDomain*> logical_ids =
ir_utils::getReachableIds(tv->getLogicalDomain(), {loop_id});
Expand Down Expand Up @@ -399,15 +426,18 @@ std::optional<CommunicationInfo> getCommunicationInfoForParallelType(

if (p_loop_id && !c_loop_id) {
// Check if we are going from DID -> Stream, which is a ring allgather.
// This can be executed as a broadcast or send recvs, which is decided
// This can be executed as a broadcast or collective permute, which is decided
// by the presence of a swizzle in the stream id definition.
if (c_logical_stream_id == p2c.at(p_logical_id)) {
NVF_CHECK(
same_mesh,
"Broadcast based allgather in stream parallel requires same "
"mesh.")
CommunicationType type = c_stream_id->definition()->isA<Swizzle1D>()
? CommunicationType::CollectivePermute
: CommunicationType::Broadcast;
return CommunicationInfo{
.type = CommunicationType::Broadcast,
.type = type,
.p_sharded_id = p_logical_id,
.c_sharded_id = c_logical_stream_id};
}
Expand Down Expand Up @@ -525,7 +555,8 @@ Layout getCommunicationLayout(
type == CommunicationType::Allreduce ||
type == CommunicationType::Broadcast ||
type == CommunicationType::SendRecv ||
type == CommunicationType::AllToAll) {
type == CommunicationType::AllToAll ||
type == CommunicationType::CollectivePermute) {
return layout;
}

Expand Down Expand Up @@ -667,6 +698,9 @@ std::vector<Expr*> convertSingleOpToCommunication(
case CommunicationType::AllToAll:
lowerToAllToAll(input_tv, output_tv, backend, comms);
break;
case CommunicationType::CollectivePermute:
lowerToCollectivePermute(input_tv, output_tv, backend, comms, root, my_device_idx);
break;
}

return comms;
Expand Down
17 changes: 15 additions & 2 deletions csrc/host_ir/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,11 @@
#include "host_ir/ir.h"
#include "host_ir/lower_to_communication.h"
#include "host_ir/ops.h"
#include "ir/builder.h"
#include "ir/iostream.h"
#include "ir/utils.h"
#include "iter_visitor.h"
#include "kernel_ir.h"
#include "multidevice/propagation.h"
#include "multidevice/resharding.h"
#include "multidevice/utils.h"
Expand Down Expand Up @@ -231,10 +234,20 @@ void lowerSegment(
Val* root = loop_nest.empty() ? nullptr : innermost.loop->index();
for (Expr* c : convertSingleOpToCommunication(e, device_id, root)) {
NVF_ERROR(
c->isA<Communication>(),
"Exprs in a Communication group should be Communication: ",
c->isA<Communication>() || c->isA<CollectivePermute>(),
"Exprs in a Communication group should be Communication or CollectivePermute: ",
c);

if (auto* cp = dynamic_cast<CollectivePermute*>(c)) {
auto add_definition_chain = [&innermost_scope](Val* val) -> void {
for (Expr* expr : StmtSort::getExprsTo({val})) {
innermost_scope.pushBack(expr);
}
};
add_definition_chain(cp->sendPeer());
add_definition_chain(cp->recvPeer());
}

Expr* new_c = cloneWithNewOperands(c, replacement_map);
innermost_scope.pushBack(new_c);

Expand Down
45 changes: 45 additions & 0 deletions csrc/multidevice/communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@ std::ostream& operator<<(std::ostream& os, const CommunicationType& type) {
case CommunicationType::AllToAll:
os << "AllToAll";
break;
case CommunicationType::CollectivePermute:
os << "CollectivePermute";
break;
}
return os;
}
Expand All @@ -64,6 +67,7 @@ bool hasRoot(CommunicationType type) {
case CommunicationType::Allreduce:
case CommunicationType::ReduceScatter:
case CommunicationType::AllToAll:
case CommunicationType::CollectivePermute:
return false;
}
std::unreachable();
Expand Down Expand Up @@ -216,6 +220,47 @@ std::string P2PCommunication::toString(int indent_size) const {
return toInlineString(indent_size) + "\n";
}

// Builds a CollectivePermute IR node: `in` is sent to `send_peer` while `out`
// is received from `recv_peer`. Inputs are registered in the order
// {in, send_peer, recv_peer} and attributes in the order
// {CommunicationType, Team, CommunicatorBackend}; the accessors declared in
// the header (in(), sendPeer(), recvPeer(), type(), team(), backend()) rely
// on these exact positions.
CollectivePermute::CollectivePermute(
    IrBuilderPasskey passkey,
    TensorView* out,
    TensorView* in,
    Team team,
    Val* send_peer,
    Val* recv_peer,
    CommunicatorBackend backend)
    : Expr(passkey) {
  // Both tensors must be distributed over a non-empty device mesh.
  NVF_ERROR(
      in->getDeviceMesh().size() > 0,
      "The input mesh size must be greater than 0.");
  NVF_ERROR(
      out->getDeviceMesh().size() > 0,
      "The output mesh size must be greater than 0.");
  addInput(in);
  addInput(send_peer);
  addInput(recv_peer);
  addOutput(out);
  addDataAttribute(CommunicationType::CollectivePermute);
  addDataAttribute(team);
  addDataAttribute(backend);
}

NVFUSER_DEFINE_CLONE_AND_CREATE(CollectivePermute)

std::string CollectivePermute::toInlineString(const int indent_size) const {
std::stringstream ss;
indent(ss, indent_size) << "CollectivePermute " << name() << " ("
<< "team=(" << team() << ")"
<< ", send_peer=" << sendPeer()->toInlineString()
<< ", recv_peer=" << recvPeer()->toInlineString()
<< ", input=" << in() << ", output=" << out()
<< ", backend=" << backend() << ")";
return ss.str();
}

// The multi-line form is simply the inline form terminated by a newline.
std::string CollectivePermute::toString(int indent_size) const {
  std::string result = toInlineString(indent_size);
  result += "\n";
  return result;
}

MoeDispatch::MoeDispatch(
IrBuilderPasskey passkey,
TensorView* out_x,
Expand Down
58 changes: 57 additions & 1 deletion csrc/multidevice/communication.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ enum class CommunicationType {
ReduceScatter,
Broadcast,
SendRecv,
AllToAll
AllToAll,
CollectivePermute
};

std::ostream& operator<<(std::ostream& os, const CommunicationType& type);
Expand Down Expand Up @@ -122,6 +123,61 @@ class Communication : public Expr {
void validate();
};

// CollectivePermute: send to send_peer, recv from recv_peer. Separate from
// Communication (no root, no reduce op).
class CollectivePermute : public Expr {
 public:
  using Expr::Expr;

  // `in` is sent to `send_peer`; `out` is received from `recv_peer`. `team`
  // lists all participating devices. The peers are Vals so they can be
  // evaluated at run time (e.g. as a function of a host loop index).
  CollectivePermute(
      IrBuilderPasskey passkey,
      TensorView* out,
      TensorView* in,
      Team team,
      Val* send_peer,
      Val* recv_peer,
      CommunicatorBackend backend = CommunicatorBackend::kNccl);

  // Non-copyable and non-movable, like other IR nodes; use clone/create.
  CollectivePermute(const CollectivePermute& other) = delete;
  CollectivePermute& operator=(const CollectivePermute& other) = delete;
  CollectivePermute(CollectivePermute&& other) = delete;
  CollectivePermute& operator=(CollectivePermute&& other) = delete;

  NVFUSER_DECLARE_CLONE_AND_CREATE

  std::string toString(int indent_size = 0) const override;
  std::string toInlineString(int indent_size = 0) const override;
  const char* getOpString() const override {
    return "CollectivePermute";
  }

  // Always CommunicationType::CollectivePermute (stored by the constructor as
  // attribute 0).
  CommunicationType type() const {
    return attribute<CommunicationType>(0);
  }

  // Tensor sent to sendPeer().
  TensorView* in() const {
    return input(0)->as<TensorView>();
  }
  // Tensor received from recvPeer().
  TensorView* out() const {
    return output(0)->as<TensorView>();
  }
  // Device index this rank sends to (evaluated at run time).
  Val* sendPeer() const {
    return input(1);
  }
  // Device index this rank receives from (evaluated at run time).
  Val* recvPeer() const {
    return input(2);
  }
  // All devices participating in the permute.
  const Team& team() const {
    return attribute<Team>(1);
  }
  int64_t team_size() const {
    return static_cast<int64_t>(team().size());
  }
  CommunicatorBackend backend() const {
    return attribute<CommunicatorBackend>(2);
  }
};

enum class P2PCommunicationType { SEND, RECV };

std::ostream& operator<<(std::ostream& os, const P2PCommunicationType& type);
Expand Down
60 changes: 60 additions & 0 deletions csrc/multidevice/post_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,33 @@ c10::intrusive_ptr<c10d::Work> postRecv(
return backend->recv(packed_buffer, static_cast<int>(peer), /*tag=*/0);
}

// Posts the send/recv pair for one rank of a collective permute. Returns the
// coalesced work handle, or nullptr when the exchange is purely local.
c10::intrusive_ptr<c10d::Work> postCollectivePermute(
    CollectivePermute* communication,
    DeviceIdxType my_device_index,
    DeviceIdxType send_peer_index,
    DeviceIdxType recv_peer_index,
    c10d::Backend* backend,
    at::Tensor input_tensor,
    at::Tensor output_tensor) {
  const bool sends_to_self = my_device_index == send_peer_index;
  const bool recvs_from_self = my_device_index == recv_peer_index;
  // A self-exchange degenerates to a device-local copy; no backend traffic.
  if (sends_to_self && recvs_from_self) {
    doLocalCopy(output_tensor, input_tensor);
    return nullptr;
  }

  // Coalesce the send and the recv so both are posted as one grouped
  // operation on the backend.
  backend->startCoalescing();
  std::vector<at::Tensor> to_send{input_tensor};
  backend->send(to_send, send_peer_index, /*tag=*/0);
  std::vector<at::Tensor> to_recv{output_tensor};
  backend->recv(to_recv, recv_peer_index, /*tag=*/0);
  return backend->endCoalescing();
}

} // namespace

c10::intrusive_ptr<c10d::Work> postSingleCommunication(
Expand Down Expand Up @@ -561,4 +588,37 @@ c10::intrusive_ptr<c10d::Work> postSingleCommunication(
}
}

// Posts a CollectivePermute for this rank. Ranks outside the team post
// nothing and return nullptr; participating ranks delegate to
// postCollectivePermute with the run-time-evaluated peer indices.
c10::intrusive_ptr<c10d::Work> postSingleCommunication(
    CollectivePermute* communication,
    DeviceIdxType my_device_index,
    c10d::Backend* backend,
    at::Tensor input_tensor,
    at::Tensor output_tensor,
    DeviceIdxType send_peer_index,
    DeviceIdxType recv_peer_index) {
  const Team& team = communication->team();
  if (std::find(team.begin(), team.end(), my_device_index) == team.end()) {
    return nullptr;
  }
  // Give the check a message so a null backend fails with context instead of
  // just the stringified condition.
  NVF_CHECK(
      backend != nullptr,
      "A valid backend is required to post ",
      communication->toInlineString());

  // NOTE: only rank 0 prints, so the logged peers are rank 0's; other ranks'
  // peers differ in a permute.
  if (isDebugDumpEnabled(DebugDumpOption::Communication) &&
      my_device_index == 0) {
    debug() << "Posting " << communication->toInlineString()
            << " with input_tensor " << input_tensor.sizes()
            << " and output_tensor " << output_tensor.sizes()
            << " send_peer=" << send_peer_index
            << " recv_peer=" << recv_peer_index << std::endl;
  }

  return postCollectivePermute(
      communication,
      my_device_index,
      send_peer_index,
      recv_peer_index,
      backend,
      input_tensor,
      output_tensor);
}

} // namespace nvfuser
Loading
Loading