Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 0 additions & 61 deletions csrc/ir/composite_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -488,67 +488,6 @@ std::vector<PolymorphicValue> SdpaFwdOp::evaluate(
return {output, log_sumexp, philox_seed, philox_offset};
}

std::string Scope::toString(int indent_size) const {
std::stringstream ss;
for (auto expr : exprs()) {
ss << expr->toString(indent_size);
}
return ss.str();
}

Scope::Iterator Scope::insert(Iterator pos, Expr* expr) {
return exprs_.insert(pos, expr);
}

Scope::Iterator Scope::insert_before(Expr* ref, Expr* expr) {
const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
NVF_ERROR(
it != exprs_.end(),
"Tried to insert ",
expr,
" before the reference: ",
ref,
" @ ",
(size_t)ref,
" however the reference was not found in this scope.");
return insert(it, expr);
}

Scope::Iterator Scope::insert_after(Expr* ref, Expr* expr) {
const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
NVF_ERROR(
it != exprs_.end(),
"Tried to insert ",
expr,
" after the reference: ",
ref,
" however the reference was not found in this scope.");
auto insert_pos = std::next(it);
return insert(insert_pos, expr);
}

void Scope::erase(Iterator pos) {
// Remove the scope of the expr if this is the scope
[[maybe_unused]] auto expr = *pos;
exprs_.erase(pos);
}

void Scope::erase(Expr* ref) {
const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
if (it != exprs_.end()) {
erase(it);
}
}

bool Scope::contains(Expr* expr) const {
const auto it = std::find(exprs_.begin(), exprs_.end(), expr);
return it != exprs_.end();
}

void Scope::clear() {
exprs_.clear();
}

SdpaBwdOp::SdpaBwdOp(
IrBuilderPasskey passkey,
TensorView* grad_query,
Expand Down
69 changes: 0 additions & 69 deletions csrc/ir/composite_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
// clang-format on
#pragma once

#include <list>

#include <exceptions.h>
#include <fusion.h>
#include <ir/base_nodes.h>
Expand Down Expand Up @@ -223,73 +221,6 @@ class SdpaFwdOp : public Expr {
const std::vector<PolymorphicValue>& inputs) const override;
};

class Scope {
public:
using ExprList = std::list<Expr*>;
using Iterator = ExprList::const_iterator;

explicit Scope(Expr* owner) : owner_(owner) {}

std::string toString(int indent_size = 0) const;

const ExprList& exprs() const {
return exprs_;
}

// Used only by MultiDeviceExecutor. Should generally be avoided in favor of
// other modifying methods.
ExprList& mutableExprs() {
return exprs_;
}

Expr* front() const {
NVF_ERROR(
!exprs_.empty(), "Attempting to access the front of an empty Scope");
return exprs_.front();
}

Expr* back() const {
NVF_ERROR(
!exprs_.empty(), "Attempting to access the back of an empty Scope");
return exprs_.back();
}

bool empty() const {
return exprs_.empty();
}

int64_t size() const {
return std::ssize(exprs_);
}

Iterator insert(Iterator pos, Expr* expr);

Iterator pushBack(Expr* e) {
return insert(exprs_.end(), e);
}

void clear();

Expr* owner() const {
return owner_;
}

// The following methods perform linear searches over exprs_. Use them only
// when necessary, as they do not scale well with large scopes.
Iterator insert_before(Expr* ref, Expr* expr);
Iterator insert_after(Expr* ref, Expr* expr);
void erase(Expr* ref);
bool contains(Expr* expr) const;

private:
void erase(Iterator pos);

ExprList exprs_;

//! Owner exprssion of this scope, e.g., IfThenElse
Expr* owner_ = nullptr;
};

// SDPA bwd node with same functionality
// at::_scaled_dot_product_flash_attention_backward
// grad_query = [N, H, L, E]
Expand Down
63 changes: 63 additions & 0 deletions csrc/ir/internal_nodes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,69 @@

namespace nvfuser {

std::string Scope::toString(int indent_size) const {
std::stringstream ss;
for (auto expr : exprs()) {
ss << expr->toString(indent_size);
}
return ss.str();
}

Scope::Iterator Scope::insert(Iterator pos, Expr* expr) {
return exprs_.insert(pos, expr);
}

Scope::Iterator Scope::insert_before(Expr* ref, Expr* expr) {
const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
NVF_ERROR(
it != exprs_.end(),
"Tried to insert ",
expr,
" before the reference: ",
ref,
" @ ",
(size_t)ref,
" however the reference was not found in this scope.");
return insert(it, expr);
}

Scope::Iterator Scope::insert_after(Expr* ref, Expr* expr) {
const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
NVF_ERROR(
it != exprs_.end(),
"Tried to insert ",
expr,
" after the reference: ",
ref,
" @ ",
(size_t)ref,
" however the reference was not found in this scope.");
auto insert_pos = std::next(it);
return insert(insert_pos, expr);
}

void Scope::erase(Iterator pos) {
// Remove the scope of the expr if this is the scope
[[maybe_unused]] auto expr = *pos;
exprs_.erase(pos);
}

void Scope::erase(Expr* ref) {
const auto it = std::find(exprs_.begin(), exprs_.end(), ref);
if (it != exprs_.end()) {
erase(it);
}
}

bool Scope::contains(Expr* expr) const {
const auto it = std::find(exprs_.begin(), exprs_.end(), expr);
return it != exprs_.end();
}

void Scope::clear() {
exprs_.clear();
}

FullOp::FullOp(IrBuilderPasskey passkey, Val* out, Val* fill_value)
: Expr(passkey) {
if (out->isA<TensorView>()) {
Expand Down
68 changes: 67 additions & 1 deletion csrc/ir/internal_nodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,76 @@
namespace nvfuser {

class ViewTransform;
class Scope;
class IrCloner;
struct AnalyzeViewResult;

class Scope {
public:
using ExprList = std::list<Expr*>;
using Iterator = ExprList::const_iterator;

explicit Scope(Expr* owner) : owner_(owner) {}

std::string toString(int indent_size = 0) const;

const ExprList& exprs() const {
return exprs_;
}

// Used only by MultiDeviceExecutor. Should generally be avoided in favor of
// other modifying methods.
ExprList& mutableExprs() {
return exprs_;
}

Expr* front() const {
NVF_ERROR(
!exprs_.empty(), "Attempting to access the front of an empty Scope");
return exprs_.front();
}

Expr* back() const {
NVF_ERROR(
!exprs_.empty(), "Attempting to access the back of an empty Scope");
return exprs_.back();
}

bool empty() const {
return exprs_.empty();
}

int64_t size() const {
return std::ssize(exprs_);
}

Iterator insert(Iterator pos, Expr* expr);

Iterator pushBack(Expr* e) {
return insert(exprs_.end(), e);
}

void clear();

Expr* owner() const {
return owner_;
}

// The following methods perform linear searches over exprs_. Use them only
// when necessary, as they do not scale well with large scopes.
Iterator insert_before(Expr* ref, Expr* expr);
Iterator insert_after(Expr* ref, Expr* expr);
void erase(Expr* ref);
bool contains(Expr* expr) const;

private:
void erase(Iterator pos);

ExprList exprs_;

//! Owner exprssion of this scope, e.g., IfThenElse
Expr* owner_ = nullptr;
};

class NVF_API FullOp : public Expr {
public:
using Expr::Expr;
Expand Down