Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions csrc/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ class Val;
f(ShareMemHandles); \
f(HirAliasSelect); \
f(ShardByStream); \
f(Allocate); \
f(Deallocate); \
f(ForLoop); \
f(SymmetricContiguousView);
Expand Down
3 changes: 2 additions & 1 deletion csrc/host_ir/allocate_and_deallocate.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <unordered_set>
#include <vector>

#include "host_ir/ir.h"
#include "ir/builder.h"
#include "ir/utils.h"

Expand Down Expand Up @@ -170,7 +171,7 @@ void insertAllocations(hir::HostIrContainer& hic) {

if (needsOutputPreallocation(e)) {
auto* allocate =
IrBuilder::create<kir::Allocate>(out, out->getMemoryType());
IrBuilder::create<hir::Allocate>(out, out->getMemoryType());
node->scope()->insert(node->iterator(), allocate);
}

Expand Down
22 changes: 22 additions & 0 deletions csrc/host_ir/evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,28 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) {
expr_evaluator_.bind(tv, tensor);
}

// Materializes the backing CUDA buffer for the Allocate node's TensorView and
// binds it into the expression evaluator so downstream expressions can use it.
void HostIrEvaluator::handle(hir::Allocate* allocate) {
  FUSER_PERF_SCOPE("HostIrEvaluator::handle(Allocate)");
  auto* tv = allocate->in();

  // Resolve concrete sizes/strides for the tensor from the evaluator state.
  // PrimDataType::Int is used as the index type when evaluating extents.
  const GlobalBufferInfo buffer_info =
      getBufferInfos(expr_evaluator_, PrimDataType::Int, {tv}).at(0);

  // Prefer the communicator's device when one exists (multi-device runs);
  // otherwise fall back to the default CUDA device.
  const c10::Device device = communicator_ == nullptr
      ? at::Device("cuda:0")
      : communicator_->device();

  at::Tensor buffer = at::native::empty_strided_cuda(
      buffer_info.shape_info.logical_sizes,
      buffer_info.shape_info.logical_strides,
      buffer_info.type,
      /*layout=*/c10::nullopt,
      device,
      /*pin_memory=*/c10::nullopt);

  // empty_strided_cuda leaves memory uninitialized; honor the zero-init flag.
  if (allocate->zeroInit()) {
    buffer.zero_();
  }

  expr_evaluator_.bind(tv, buffer);
}

void HostIrEvaluator::handle(HirAliasSelect* hir_alias_select) {
auto indexed_id =
hir_alias_select->in()->getLogicalDomain().at(hir_alias_select->axis());
Expand Down
1 change: 1 addition & 0 deletions csrc/host_ir/evaluator.h
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ class NVF_API HostIrEvaluator final : public OptOutDispatch {
void handle(MatmulOp*) override;
void handle(LinearOp*) override;
void handle(kir::Allocate*) override;
void handle(Allocate*) override;
void handle(LoadStoreOp*) override;
void handle(BinaryOp*) override;
void handle(ReductionOp*) override;
Expand Down
35 changes: 35 additions & 0 deletions csrc/host_ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -504,4 +504,39 @@ std::string ForLoop::toInlineString(int indent_size) const {
index, iter_domain->start(), iter_domain->stop());
}

// Constructs a host IR Allocate expression for `in`.
//
// `in` is declared as Val* for IR-builder uniformity but must be a TensorView
// (checked below). `memory_type` and `zero_init` are stored as data
// attributes and exposed via the memoryType()/zeroInit() accessors.
Allocate::Allocate(
    IrBuilderPasskey passkey,
    Val* in,
    MemoryType memory_type,
    bool zero_init)
    : Expr(passkey) {
  // hir::Allocate is a host IR node and may only live in a HostIrContainer.
  NVF_ERROR(passkey.ir_container_ != nullptr);
  NVF_ERROR(passkey.ir_container_->isA<HostIrContainer>());
  NVF_ERROR(in->isA<TensorView>(), "hir::Allocate input must be a TensorView.");

  addInput(in);
  // Attribute 0: memory type; attribute 1: zero-init flag. The order must
  // match the indices used by memoryType() and zeroInit().
  addDataAttribute(memory_type);
  addDataAttribute(zero_init);
}

NVFUSER_DEFINE_CLONE_AND_CREATE(Allocate)

// Renders the node as "<tv> = ALLOCATE(mem_type=..., zero_init=...)" followed
// by a newline, indented by `indent_size`.
std::string Allocate::toString(int indent_size) const {
  std::stringstream os;
  indent(os, indent_size) << in()->toString() << " = ALLOCATE(";
  os << "mem_type=" << memoryType() << ", ";
  os << "zero_init=" << std::boolalpha << zeroInit() << ")" << std::endl;
  return os.str();
}

// Inline form of toString(): same content, no trailing newline, and the
// tensor is rendered with its inline representation.
std::string Allocate::toInlineString(int indent_size) const {
  std::stringstream os;
  indent(os, indent_size) << in()->toInlineString() << " = ALLOCATE(";
  os << "mem_type=" << memoryType() << ", ";
  os << "zero_init=" << std::boolalpha << zeroInit() << ")";
  return os.str();
}

} // namespace nvfuser::hir
36 changes: 36 additions & 0 deletions csrc/host_ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,42 @@ class LaunchKernel : public Expr {
CompiledKernel* compiled_kernel_ = nullptr;
};

// Host IR expression that allocates the backing buffer for a TensorView at
// host-program evaluation time. Counterpart of Deallocate; handled by
// HostIrEvaluator and the host IR JIT.
class Allocate : public Expr {
 public:
  using Expr::Expr;

  // `in` must be a TensorView (the constructor asserts this).
  // NOTE(review): a reviewer suggested typing the parameter as TensorView*
  // directly — TODO confirm and update declaration and definition together.
  // When `zero_init` is true, the buffer is zero-filled after allocation.
  explicit Allocate(
      IrBuilderPasskey passkey,
      Val* in,
      MemoryType memory_type,
      bool zero_init = false);

  // IR nodes are owned by their container; copying/moving is disabled.
  Allocate(const Allocate& other) = delete;
  Allocate& operator=(const Allocate& other) = delete;
  Allocate(Allocate&& other) = delete;
  Allocate& operator=(Allocate&& other) = delete;

  NVFUSER_DECLARE_CLONE_AND_CREATE

  std::string toString(int indent_size = 0) const override;
  std::string toInlineString(int indent_size = 0) const override;
  const char* getOpString() const override {
    return "hir::Allocate";
  }

  // The tensor whose buffer this expression allocates (stored as input 0).
  TensorView* in() const {
    return inputs().at(0)->as<TensorView>();
  }

  // Data attribute 0, set by the constructor.
  MemoryType memoryType() const {
    return attribute<MemoryType>(0);
  }

  // Data attribute 1; whether the buffer is zero-filled after allocation.
  bool zeroInit() const {
    return attribute<bool>(1);
  }
};

class Deallocate : public Expr {
public:
using Expr::Expr;
Expand Down
19 changes: 7 additions & 12 deletions csrc/host_ir/jit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,7 +756,7 @@ class HostIrCompileDispatcher : public OptInDispatch {
smem});
}

void handle(kir::Allocate* allocate) final {
void handle(hir::Allocate* allocate) final {
llvm::LLVMContext& context = builder_.getContext();
llvm::Module* module = builder_.GetInsertBlock()->getParent()->getParent();

Expand All @@ -765,14 +765,10 @@ class HostIrCompileDispatcher : public OptInDispatch {
llvm::SmallVector<llvm::Value*, kMaxTensorDim> tensor_sizes;
llvm::SmallVector<llvm::Value*, kMaxTensorDim> tensor_strides;
inferTensorShapesAndStrides(
allocate->buffer()->as<TensorView>(),
val_to_value_,
builder_,
tensor_sizes,
tensor_strides);
allocate->in(), val_to_value_, builder_, tensor_sizes, tensor_strides);

const std::vector<IterDomain*>& logical_domain = TensorDomain::noReductions(
allocate->buffer()->as<TensorView>()->getLogicalDomain());
const std::vector<IterDomain*>& logical_domain =
TensorDomain::noReductions(allocate->in()->getLogicalDomain());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider | TensorDomain::kNoReductions


NVF_ERROR_EQ(tensor_sizes.size(), logical_domain.size());

Expand Down Expand Up @@ -811,9 +807,8 @@ class HostIrCompileDispatcher : public OptInDispatch {

// Create constants for type and device from params
at::ScalarType data_type = data_type_to_aten(
allocate->buffer()->dtype() == DataType::Index
? PrimDataType::Int
: allocate->buffer()->dtype());
allocate->in()->dtype() == DataType::Index ? PrimDataType::Int
: allocate->in()->dtype());
llvm::Value* dtype_constant =
builder_.getInt32(static_cast<int32_t>(data_type));
llvm::Value* device_index_constant =
Expand All @@ -833,7 +828,7 @@ class HostIrCompileDispatcher : public OptInDispatch {
dtype_constant,
device_index_constant,
out_tensor});
val_to_value_[allocate->buffer()] = out_tensor;
val_to_value_[allocate->in()] = out_tensor;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

zeroInit() flag not handled in JIT path - tensor may contain uninitialized data when allocate->zeroInit() is true

}

void handle(hir::Deallocate* deallocate) final {
Expand Down
6 changes: 3 additions & 3 deletions csrc/host_ir/lowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ void lowerSegment(

// Allocate the recv buffers of communications
auto* allocate =
IrBuilder::create<kir::Allocate>(out, out->getMemoryType());
IrBuilder::create<hir::Allocate>(out, out->getMemoryType());
if (getShardedIterDomain(
out, ParallelType::Stream, DomainType::kLoop) != nullptr &&
getShardedIterDomain(
Expand Down Expand Up @@ -311,7 +311,7 @@ void lowerSegment(
out, ParallelType::Stream, DomainType::kAllocation) ==
nullptr) {
auto* allocate =
IrBuilder::create<kir::Allocate>(out, out->getMemoryType());
IrBuilder::create<hir::Allocate>(out, out->getMemoryType());
innermost.parent_scope->insert(
innermost.parent_insertion_point, allocate);
// Loop is stream parallelized but allocation is not. Therefore,
Expand Down Expand Up @@ -348,7 +348,7 @@ void lowerSegment(
alias);

auto* allocate =
IrBuilder::create<kir::Allocate>(out_tv, out_tv->getMemoryType());
IrBuilder::create<hir::Allocate>(out_tv, out_tv->getMemoryType());
innermost_scope.pushBack(allocate);
}

Expand Down
6 changes: 3 additions & 3 deletions tests/cpp/test_host_ir_evaluator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ TEST_F(HostIrEvaluatorTest, LaunchKernel) {
Val* in = ir_cloner.clone(fusion.inputs().at(0));
Val* out = ir_cloner.clone(fusion.outputs().at(0));

auto allocate = IrBuilder::create<kir::Allocate>(out, MemoryType::Global);
auto allocate = IrBuilder::create<hir::Allocate>(out, MemoryType::Global);
auto* cache_id =
IrBuilder::create<NamedScalar>("cacheId", DataType::UInt64);
auto launch_kernel = IrBuilder::create<LaunchKernel>(
Expand Down Expand Up @@ -182,8 +182,8 @@ TEST_F(HostIrEvaluatorTest, AddInLoop) {
hic->addInput(in);
hic->addOutput(out);

auto* allocate_out = IrBuilder::create<kir::Allocate>(
out, MemoryType::Global, std::vector<Val*>({}), /*zero_init=*/true);
auto* allocate_out = IrBuilder::create<hir::Allocate>(
out, MemoryType::Global, /*zero_init=*/true);

auto* stream_index = IrBuilder::create<Val>(DataType::Index);
auto* for_loop = IrBuilder::create<ForLoop>(
Expand Down
4 changes: 2 additions & 2 deletions tests/cpp/test_multidevice_host_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -532,8 +532,8 @@ TEST_F(MultiDeviceTest, SwizzleWithParallelType) {
tv->axis(0)->parallelize(ParallelType::Stream);
}

auto* allocate_out = IrBuilder::create<kir::Allocate>(
out_tv, MemoryType::Global, std::vector<Val*>({}), /*zero_init=*/true);
auto* allocate_out = IrBuilder::create<hir::Allocate>(
out_tv, MemoryType::Global, /*zero_init=*/true);
auto* stream_index = IrBuilder::create<Val>(DataType::Index);
auto* for_loop = IrBuilder::create<ForLoop>(
stream_index,
Expand Down
Loading