Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 211 additions & 0 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
// clang-format on
#include <fusion.h>

#include <type.h>
#include <iterator>
#include <ranges>

Expand All @@ -19,7 +20,9 @@
#include <host_ir/container.h>
#include <instrumentation.h>
#include <ir/all_nodes.h>
#include <ir/builder.h>
#include <ir/cloner.h>
#include <ir/internal_nodes.h>
#include <ir/printer.h>
#include <ir/utils.h>
#include <iter_visitor.h>
Expand Down Expand Up @@ -137,6 +140,16 @@ void Fusion::swap(Fusion& a, Fusion& b) noexcept {
std::swap(a.outputs_, b.outputs_);

std::swap(a.io_alias_, b.io_alias_);

// Swap per-Fusion special values (Phase 2)
std::swap(a.zero_val_, b.zero_val_);
std::swap(a.one_val_, b.one_val_);
std::swap(a.true_val_, b.true_val_);
std::swap(a.false_val_, b.false_val_);
std::swap(a.magic_zero_val_, b.magic_zero_val_);

std::swap(a.axioms_, b.axioms_);
std::swap(a.metadata_, b.metadata_);
}

std::unique_ptr<SegmentedFusion> Fusion::segment(
Expand All @@ -150,6 +163,24 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {

auto ir_cloner = IrContainer::copy(from->ir_container(), to->ir_container());

// Remap cached special val pointers through the cloner
if (from->zero_val_) {
to->zero_val_ = ir_cloner.clone(from->zero_val_);
}
if (from->one_val_) {
to->one_val_ = ir_cloner.clone(from->one_val_);
}
if (from->true_val_) {
to->true_val_ = ir_cloner.clone(from->true_val_);
}
if (from->false_val_) {
to->false_val_ = ir_cloner.clone(from->false_val_);
}
if (from->magic_zero_val_) {
to->magic_zero_val_ =
ir_cloner.clone(from->magic_zero_val_)->as<NamedScalar>();
}

for (auto val : from->vals()) {
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
Expand Down Expand Up @@ -198,6 +229,19 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {

to->expected_dynamic_smem_bytes_ = from->expected_dynamic_smem_bytes_;

if (from->axioms_ != nullptr) {
to->axioms_ = std::make_unique<std::vector<Val*>>();
to->axioms_->reserve(from->axioms_->size());
for (auto pred : *from->axioms_) {
to->axioms_->push_back(ir_cloner.clone(pred));
}
}

for (auto& [key, val_expr] : from->metadata_) {
to->metadata_[ir_cloner.clone(key)] = std::make_pair(
ir_cloner.clone(val_expr.first), ir_cloner.clone(val_expr.second));
}

if (from->all_tvs_ptr_ != nullptr) {
to->all_tvs_ptr_ = std::make_unique<std::vector<TensorView*>>();
to->all_tvs_ptr_->reserve(from->all_tvs_ptr_->size());
Expand Down Expand Up @@ -264,6 +308,17 @@ void Fusion::clear() noexcept {
managed_data_.clear();
managed_named_data_.clear();

// Reset per-Fusion special value caches (the vals themselves are owned by
// ir_container and were already destroyed by ir_container()->clear() above).
zero_val_ = nullptr;
one_val_ = nullptr;
true_val_ = nullptr;
false_val_ = nullptr;
magic_zero_val_ = nullptr;

axioms_.reset();
metadata_.clear();

invalidateTvsAndUses();

is_during_update_uses_ = false;
Expand Down Expand Up @@ -297,6 +352,12 @@ void Fusion::removeExpr(Expr* expr) {
void Fusion::removeVal(Val* val) {
assertInContainer(val, "Cannot remove val ");

// Don't remove cached special vals — they are lazily created singletons
if (val == zero_val_ || val == one_val_ || val == true_val_ ||
val == false_val_ || val == magic_zero_val_) {
return;
}

NVF_CHECK(
!val->isFusionInput(),
"Cannot remove val as it is an input of the fusion.");
Expand Down Expand Up @@ -340,6 +401,57 @@ void Fusion::removeVal(Val* val) {
invalidateTvsAndUses();
}

void Fusion::removeStatementsCreatedAfter(
int64_t num_exprs_before,
int64_t num_vals_before) {
auto* c = ir_container();

NVF_ERROR(
c->exprs_up_.size() == c->exprs_.size(),
"exprs_up_ (size ",
c->exprs_up_.size(),
") and exprs_ (size ",
c->exprs_.size(),
") are out of sync.");
NVF_ERROR(
std::ssize(c->exprs_up_) >= num_exprs_before,
"exprs_up_ size (",
std::ssize(c->exprs_up_),
") is less than num_exprs_before (",
num_exprs_before,
").");

// Remove expressions before values because we need to change Val::uses_.
while (std::ssize(c->exprs_up_) > num_exprs_before) {
Expr* e = c->exprs_up_.back().get();
for (Val* in : e->inputs()) {
in->removeUse(e);
}
c->exprs_.erase(e);
c->exprs_up_.pop_back();
}

// Null out any special value caches that point to vals about to be destroyed.
// This prevents dangling pointers when special vals are lazily created inside
// a StatementGuard scope.
while (std::ssize(c->vals_up_) > num_vals_before) {
Val* v = c->vals_up_.back().get();
if (v == zero_val_) {
zero_val_ = nullptr;
} else if (v == one_val_) {
one_val_ = nullptr;
} else if (v == true_val_) {
true_val_ = nullptr;
} else if (v == false_val_) {
false_val_ = nullptr;
} else if (v == magic_zero_val_) {
magic_zero_val_ = nullptr;
}
c->vals_.erase(v);
c->vals_up_.pop_back();
}
}

void Fusion::addInput(Val* input) {
assertInContainer(input, "Cannot register input ");

Expand Down Expand Up @@ -689,6 +801,105 @@ void Fusion::printTransforms() {
t_exprs.handle(this);
}

// Returns the per-Fusion cached zero value. On the first call (or after the
// cache has been cleared) the Val is created in this Fusion from 0L with
// DataType::Index and remembered in zero_val_.
Val* Fusion::zeroVal() {
  if (zero_val_ != nullptr) {
    return zero_val_;
  }
  zero_val_ = IrBuilder::createInContainer<Val>(this, 0L, DataType::Index);
  return zero_val_;
}

// Returns the per-Fusion cached one value. On the first call (or after the
// cache has been cleared) the Val is created in this Fusion from 1L with
// DataType::Index and remembered in one_val_.
Val* Fusion::oneVal() {
  if (one_val_ != nullptr) {
    return one_val_;
  }
  one_val_ = IrBuilder::createInContainer<Val>(this, 1L, DataType::Index);
  return one_val_;
}

// Returns the per-Fusion cached `false` value. On the first call (or after the
// cache has been cleared) the Val is created in this Fusion from `false` with
// DataType::Bool and remembered in false_val_.
Val* Fusion::falseVal() {
  if (false_val_ != nullptr) {
    return false_val_;
  }
  false_val_ = IrBuilder::createInContainer<Val>(this, false, DataType::Bool);
  return false_val_;
}

// Returns the per-Fusion cached `true` value. On the first call (or after the
// cache has been cleared) the Val is created in this Fusion from `true` with
// DataType::Bool and remembered in true_val_.
Val* Fusion::trueVal() {
  if (true_val_ != nullptr) {
    return true_val_;
  }
  true_val_ = IrBuilder::createInContainer<Val>(this, true, DataType::Bool);
  return true_val_;
}

// Returns the per-Fusion cached NamedScalar named kMagicZeroName. On the first
// call (or after the cache has been cleared) it is created in this Fusion with
// DataType::Index and remembered in magic_zero_val_.
NamedScalar* Fusion::magicZeroVal() {
  if (magic_zero_val_ != nullptr) {
    return magic_zero_val_;
  }
  magic_zero_val_ = IrBuilder::createInContainer<NamedScalar>(
      this, kMagicZeroName, DataType::Index);
  return magic_zero_val_;
}

// Returns a zero value of the requested dtype.
//  - DataType::Index: the cached zeroVal().
//  - boolean types (per isBooleanType): the cached falseVal().
//  - anything else: a new Val created on every call.
Val* Fusion::zeroVal(DataType dtype) {
  if (dtype == DataType::Index) {
    return zeroVal();
  }
  if (isBooleanType(dtype)) {
    return falseVal();
  }
  // NOTE: this does not cache values
  return IrBuilder::createInContainer<Val>(this, 0L, dtype);
}

// Returns a one value of the requested dtype.
//  - DataType::Index: the cached oneVal().
//  - boolean types (per isBooleanType): the cached trueVal().
//  - anything else: a new Val created on every call.
Val* Fusion::oneVal(DataType dtype) {
  if (dtype == DataType::Index) {
    return oneVal();
  }
  if (isBooleanType(dtype)) {
    return trueVal();
  }
  // NOTE: this does not cache values
  return IrBuilder::createInContainer<Val>(this, 1L, dtype);
}

// Returns the metadata Val associated with `v`, creating it on first request.
// The first request creates a Val of type metaDataTypeOf(v) together with a
// GetMetaData expression built from that Val and `v`; both are stored in
// metadata_ so that later requests for the same `v` return the same Val.
Val* Fusion::metadataOf(Val* v) {
  if (auto cached = metadata_.find(v); cached != metadata_.end()) {
    return cached->second.first;
  }

  auto md_val = IrBuilder::createInContainer<Val>(this, metaDataTypeOf(v));
  auto md_expr = IrBuilder::createInContainer<GetMetaData>(this, md_val, v);
  metadata_[v] = std::make_pair(md_val, md_expr);
  return md_val;
}

const std::vector<Val*>& Fusion::axioms() {
if (!axioms_) {
axioms_ = std::make_unique<std::vector<Val*>>();
axioms_->reserve(kParallelTypeThreads.size() * 3);
auto zero = zeroVal();
for (auto p : kParallelTypeThreads) {
auto pidx = NamedScalar::getParallelIndex(p);
auto pdim = NamedScalar::getParallelDim(p);
axioms_->push_back(SimplifyingIrBuilder::geExpr(pidx, zero));
axioms_->push_back(SimplifyingIrBuilder::gtExpr(pdim, zero));
axioms_->push_back(SimplifyingIrBuilder::ltExpr(pidx, pdim));
}
}
return *axioms_;
}

// Appends the predicate IrBuilder::gtExpr(val, zeroVal()) to this Fusion's
// axioms. `val` must belong to this Fusion (checked with NVF_ERROR).
void Fusion::assumePositive(Val* val) {
  NVF_ERROR(inContainer(val));
  // axioms() is called only for its side effect: it allocates and seeds
  // axioms_ if that has not happened yet.
  axioms();
  auto is_positive = IrBuilder::gtExpr(val, zeroVal());
  axioms_->emplace_back(is_positive);
}

// Appends the predicate IrBuilder::geExpr(val, zeroVal()) to this Fusion's
// axioms. `val` must belong to this Fusion (checked with NVF_ERROR).
void Fusion::assumeNonNegative(Val* val) {
  NVF_ERROR(inContainer(val));
  // axioms() is called only for its side effect: it allocates and seeds
  // axioms_ if that has not happened yet.
  axioms();
  auto is_non_negative = IrBuilder::geExpr(val, zeroVal());
  axioms_->emplace_back(is_non_negative);
}

void Fusion::registerVal(Val* val) {
if (inContainer(val)) {
return;
Expand Down
71 changes: 25 additions & 46 deletions csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ namespace nvfuser {
//! checks.

class Fusion;
class NamedScalar;
class TensorView;

class SegmentCandidateFinder;
Expand Down Expand Up @@ -549,63 +550,31 @@ class NVF_API Fusion : public PolymorphicBase {
return ir_container()->numExprs();
}

int64_t numVals(bool include_shortcuts) const noexcept {
return ir_container()->numVals(include_shortcuts);
int64_t numVals() const noexcept {
return ir_container()->numVals();
}

// Shortcut values (frequently used constants)
Val* zeroVal() {
return ir_container()->zeroVal();
}

Val* oneVal() {
return ir_container()->oneVal();
}

Val* falseVal() {
return ir_container()->falseVal();
}

Val* trueVal() {
return ir_container()->trueVal();
}

NamedScalar* magicZeroVal() {
return ir_container()->magicZeroVal();
}

Val* zeroVal(DataType dtype) {
return ir_container()->zeroVal(dtype);
}
Val* zeroVal();
Val* oneVal();
Val* falseVal();
Val* trueVal();
NamedScalar* magicZeroVal();
Val* zeroVal(DataType dtype);
Val* oneVal(DataType dtype);

Val* oneVal(DataType dtype) {
return ir_container()->oneVal(dtype);
}

Val* metadataOf(Val* val) {
return ir_container()->metadataOf(val);
}
Val* metadataOf(Val* val);

// Axioms (CUDA programming assumptions)
const std::vector<Val*>& axioms() {
return ir_container()->axioms();
}
const std::vector<Val*>& axioms();

void assumePositive(Val* val) {
ir_container()->assumePositive(val);
}

void assumeNonNegative(Val* val) {
ir_container()->assumeNonNegative(val);
}
void assumePositive(Val* val);
void assumeNonNegative(Val* val);

// Statement removal
void removeStatementsCreatedAfter(
int64_t num_exprs_before,
int64_t num_vals_before) {
ir_container()->removeStatementsCreatedAfter(
num_exprs_before, num_vals_before);
}
int64_t num_vals_before);

protected:
friend SegmentCandidateFinder;
Expand Down Expand Up @@ -667,6 +636,16 @@ class NVF_API Fusion : public PolymorphicBase {

inline static const std::string exact_mappings_key = "exact_mappings";
std::unique_ptr<IrContainer> ir_container_;

Val* zero_val_ = nullptr;
Val* one_val_ = nullptr;
Val* true_val_ = nullptr;
Val* false_val_ = nullptr;
NamedScalar* magic_zero_val_ = nullptr;

std::unique_ptr<std::vector<Val*>> axioms_;

std::unordered_map<Val*, std::pair<Val*, Expr*>> metadata_;
};

// Template implementations for Fusion::manage<T>() that use IrCloner
Expand Down
Loading