Skip to content
163 changes: 112 additions & 51 deletions csrc/fusion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,51 +105,88 @@ bool Fusion::sameDefinition(const Fusion& other) const {
return true;
}

void Fusion::swap(Fusion& a, Fusion& b) noexcept {
void Fusion::swap(Fusion& a, Fusion& b) {
FUSER_PERF_SCOPE("Fusion swap");

// We need to be careful to call IrContainer swap not unique_ptr swap, which
// will only swap the ptrs NOT the contents.
IrContainer::swap(*(a.ir_container()), *(b.ir_container()));
if (&a == &b) {
return;
}

// After swapping container contents, update Statement::ir_container_
// pointers so each Statement points to the Fusion whose container now
// holds it. Also fix per-Fusion tracking keys since a's container had
// b's entries and vice versa.
a.ir_container()->transferStatementOwnership(&b, &a);
b.ir_container()->transferStatementOwnership(&a, &b);
NVF_ERROR(a.ir_container_ != nullptr, "Fusion::swap: a has null ir_container_");
NVF_ERROR(b.ir_container_ != nullptr, "Fusion::swap: b has null ir_container_");

if (a.ir_container_) {
for (auto val : a.vals()) {
val->ir_container_ = &a;
}
for (auto expr : a.deterministic_exprs()) {
expr->ir_container_ = &a;
}
}
if (b.ir_container_) {
for (auto val : b.vals()) {
val->ir_container_ = &b;
}
for (auto expr : b.deterministic_exprs()) {
expr->ir_container_ = &b;
}
// Collect statements owned by each Fusion BEFORE swap so we can update
// Statement::ir_container_ pointers afterward.
std::vector<Val*> a_owned_vals, b_owned_vals;
std::vector<Expr*> a_owned_exprs, b_owned_exprs;

const auto& av = a.ir_container_->valsOwnedBy(&a);
const auto& ae = a.ir_container_->exprsOwnedBy(&a);
a_owned_vals.assign(av.begin(), av.end());
a_owned_exprs.assign(ae.begin(), ae.end());

const auto& bv = b.ir_container_->valsOwnedBy(&b);
const auto& be = b.ir_container_->exprsOwnedBy(&b);
b_owned_vals.assign(bv.begin(), bv.end());
b_owned_exprs.assign(be.begin(), be.end());

// Transfer Fusion registrations between containers before pointer swap.
// After swap, a will own b's container and b will own a's container.
if (a.ir_container_.get() != b.ir_container_.get()) {
a.ir_container_->transferFusion(&a, &b);
b.ir_container_->transferFusion(&b, &a);
}

// Swap container pointers
std::swap(a.ir_container_, b.ir_container_);

// Swap all Fusion-level members
std::swap(a.inputs_, b.inputs_);
std::swap(a.outputs_, b.outputs_);

std::swap(a.io_alias_, b.io_alias_);

// Swap remaining per-Fusion state: flags, managed data, cached pointers,
// and the per-Fusion special values (Phase 2)
std::swap(a.all_tv_uses_valid_, b.all_tv_uses_valid_);
std::swap(a.is_during_update_uses_, b.is_during_update_uses_);
std::swap(a.managed_data_, b.managed_data_);
std::swap(a.managed_named_data_, b.managed_named_data_);
std::swap(a.expected_dynamic_smem_bytes_, b.expected_dynamic_smem_bytes_);
std::swap(a.all_tvs_ptr_, b.all_tvs_ptr_);
std::swap(a.zero_val_, b.zero_val_);
std::swap(a.one_val_, b.one_val_);
std::swap(a.true_val_, b.true_val_);
std::swap(a.false_val_, b.false_val_);
std::swap(a.magic_zero_val_, b.magic_zero_val_);

std::swap(a.axioms_, b.axioms_);
std::swap(a.metadata_, b.metadata_);
std::swap(a.val_type_name_map_, b.val_type_name_map_);
std::swap(a.expr_name_counter_, b.expr_name_counter_);

// Update Statement::ir_container_ pointers: a's old statements now belong
// to b, and b's old statements now belong to a
for (auto* val : a_owned_vals) {
val->ir_container_ = &b;
}
for (auto* expr : a_owned_exprs) {
expr->ir_container_ = &b;
}
for (auto* val : b_owned_vals) {
val->ir_container_ = &a;
}
for (auto* expr : b_owned_exprs) {
expr->ir_container_ = &a;
}

// Update per-Fusion tracking keys in containers. At this point, both
// a and b are guaranteed to have non-null ir_container_ (verified above).
if (a.ir_container_.get() == b.ir_container_.get()) {
// Same container: directly swap per-Fusion tracking entries
auto* c = a.ir_container_.get();
std::swap(c->per_fusion_vals_[&a], c->per_fusion_vals_[&b]);
std::swap(c->per_fusion_exprs_[&a], c->per_fusion_exprs_[&b]);
} else {
// Different containers: rename tracking keys to match new owners
a.ir_container_->transferStatementOwnership(&b, &a);
b.ir_container_->transferStatementOwnership(&a, &b);
}
}

std::unique_ptr<SegmentedFusion> Fusion::segment(
Expand All @@ -161,10 +198,33 @@ std::unique_ptr<SegmentedFusion> Fusion::segment(
IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
to->clear();

auto ir_cloner =
IrContainer::copy(from->ir_container(), to->ir_container(), to);
IrCloner ir_cloner(to);

// Clone from's vals in insertion order
for (auto val : from->deterministic_vals()) {
ir_cloner.clone(val);
}

// Wire up definitions and uses on cloned vals in deterministic order
// to ensure exprs are inserted into exprs_up_ deterministically
for (auto val : from->deterministic_vals()) {
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
}

// Remap cached special val pointers through the cloner
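// Sync per-Fusion name counters from source to dest.
// Must come AFTER the val/expr cloning loops above so that the
// registerVal/registerExpr increments performed while cloning do not inflate
// the final counter values. NOTE(review): the ir_cloner.clone calls further
// down are presumably lookups of already-cloned statements -- confirm they
// never register new ones, otherwise the counters drift from the source's.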
// During cloning, registerVal increments the dest Fusion's counter for each
// val, then IrBuilder::clone overrides the name with setName(src->name()).
// If source names are non-sequential (e.g., {0..10, 22..27} from segmenter
// creating intermediate TVs), the dest counter ends up at N (number of vals)
// instead of max(name)+1. Copying the source's counter state ensures new
// vals created post-copy won't collide with existing names.
to->val_type_name_map_ = from->val_type_name_map_;
to->expr_name_counter_ = from->expr_name_counter_;

// Remap cached special val pointers
if (from->zero_val_) {
to->zero_val_ = ir_cloner.clone(from->zero_val_);
}
Expand All @@ -182,11 +242,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
ir_cloner.clone(from->magic_zero_val_)->as<NamedScalar>();
}

for (auto val : from->vals()) {
ir_cloner.clone(val)->setDefinition(ir_cloner.clone(val->definition_));
ir_cloner.clone(val)->setUses(ir_cloner.clone(val->uses_));
}

to->inputs_ = ir_cloner.clone(from->inputs_);
to->outputs_ = ir_cloner.clone(from->outputs_);
for (auto inp : to->inputs_) {
Expand All @@ -196,7 +251,6 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
out->setIsFusionOutput(true);
}

// TODO: put this into ir_cloner instead
for (Val* out : from->outputs_) {
const AliasInfo& alias = from->io_alias_.get(out);
if (alias.type == AllocationType::New) {
Expand All @@ -209,14 +263,12 @@ IrCloner Fusion::copy(const Fusion* from, Fusion* to) {
}

to->all_tv_uses_valid_ = from->all_tv_uses_valid_;
// This should never be true on copy, but copying for completeness.
to->is_during_update_uses_ = from->is_during_update_uses_;

for (const auto& i : from->managed_data_) {
if (i.first.has_value()) {
to->managed_data_.emplace_back(i.second(ir_cloner, i.first), i.second);
} else {
// Don't clone managed data if it has been reset
to->managed_data_.emplace_back(i.first, i.second);
}
}
Expand Down Expand Up @@ -259,30 +311,36 @@ Fusion::Fusion() : ir_container_(std::make_shared<IrContainer>()) {
ir_container_->addFusion(this);
}

// Copy constructor
Fusion::Fusion(const Fusion& other) : Fusion() {
// Copy constructor -- shares the source's container
Fusion::Fusion(const Fusion& other) : ir_container_(other.ir_container_) {
FUSER_PERF_SCOPE("Fusion copy");
ir_container_->addFusion(this);
Fusion::copy(&other, this);
}

// Move constructor
Fusion::Fusion(Fusion&& other) noexcept : Fusion() {
Fusion::Fusion(Fusion&& other) : Fusion() {
FUSER_PERF_SCOPE("Fusion move");
swap(*this, other);
}

// Copy Assignment -- shares the source's container
Fusion& Fusion::operator=(const Fusion& other) {
FUSER_PERF_SCOPE("Fusion copy assign");
Fusion copy(other);
clear();
swap(*this, copy);
if (this != &other) {
Fusion copy(other);
clear();
swap(*this, copy);
}
return *this;
}

Fusion& Fusion::operator=(Fusion&& other) noexcept {
Fusion& Fusion::operator=(Fusion&& other) {
FUSER_PERF_SCOPE("Fusion move assign");
clear();
swap(*this, other);
if (this != &other) {
clear();
swap(*this, other);
}
return *this;
}

Expand Down Expand Up @@ -320,6 +378,9 @@ void Fusion::clear() noexcept {
axioms_.reset();
metadata_.clear();

val_type_name_map_.clear();
expr_name_counter_ = 0;

invalidateTvsAndUses();

is_during_update_uses_ = false;
Expand Down Expand Up @@ -925,7 +986,7 @@ void Fusion::registerVal(Val* val) {
c->vals_up_.emplace_back(val);
c->vals_.insert(val);
c->per_fusion_vals_[this].insert(val);
val->setName(IrContainerPasskey(), c->getValName(val->vtype()));
val->setName(IrContainerPasskey(), getValName(val->vtype()));
}

void Fusion::registerExpr(Expr* expr) {
Expand All @@ -942,7 +1003,7 @@ void Fusion::registerExpr(Expr* expr) {
c->exprs_up_.emplace_back(expr);
c->exprs_.insert(expr);
c->per_fusion_exprs_[this].insert(expr);
expr->setName(IrContainerPasskey(), c->getExprName());
expr->setName(IrContainerPasskey(), getExprName());

for (Val* input : expr->inputs()) {
assertInContainer(input, "Input to expr is invalid, ");
Expand Down
21 changes: 18 additions & 3 deletions csrc/fusion.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,14 +180,14 @@ class NVF_API Fusion : public PolymorphicBase {
Fusion();

Fusion(const Fusion& other);
Fusion(Fusion&& other) noexcept;
Fusion(Fusion&& other);

Fusion& operator=(const Fusion& other);
Fusion& operator=(Fusion&& other) noexcept;
Fusion& operator=(Fusion&& other);

~Fusion() override;

static void swap(Fusion& a, Fusion& b) noexcept;
static void swap(Fusion& a, Fusion& b);

void clear() noexcept;

Expand Down Expand Up @@ -661,6 +661,21 @@ class NVF_API Fusion : public PolymorphicBase {
std::unique_ptr<std::vector<Val*>> axioms_;

std::unordered_map<Val*, std::pair<Val*, Expr*>> metadata_;

// Per-Fusion name counters. Each Fusion independently tracks name assignment
// so that cloned Fusions get matching names (T0→T0) regardless of whether
// they share an IrContainer. This is required by downstream consumers that
// use tv->name() as a map key (alias_memory, GreedyParams, etc.).
std::unordered_map<ValType, StmtNameType> val_type_name_map_;
StmtNameType expr_name_counter_ = 0;

// Hands out the next name for a Val of type `vtype` from this Fusion's own
// per-type counter, then advances that counter.
StmtNameType getValName(ValType vtype) {
  // operator[] inserts a value-initialized counter the first time a given
  // ValType is requested, so no explicit seeding is needed.
  StmtNameType& next_name = val_type_name_map_[vtype];
  return next_name++;
}

// Hands out the next Expr name from this Fusion's own counter, then advances
// the counter so the following call receives the next value.
StmtNameType getExprName() {
  const StmtNameType assigned = expr_name_counter_;
  ++expr_name_counter_;
  return assigned;
}
};

// Template implementations for Fusion::manage<T>() that use IrCloner
Expand Down
Loading