Skip to content

[MLIR][Target/Cpp] Natural induction variable naming. #136102

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 84 additions & 26 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ struct CppEmitter {
/// Return the existing or a new name for a Value.
StringRef getOrCreateName(Value val);

/// Return the existing or a new name for a loop induction variable of an
/// emitc::ForOp.
StringRef getOrCreateInductionVarName(Value val);

// Returns the textual representation of a subscript operation.
std::string getSubscriptName(emitc::SubscriptOp op);

Expand All @@ -201,23 +205,39 @@ struct CppEmitter {
/// Whether to map an mlir integer to a unsigned integer in C++.
bool shouldMapToUnsigned(IntegerType::SignednessSemantics val);

/// RAII helper function to manage entering/exiting C++ scopes.
/// Abstract RAII helper function to manage entering/exiting C++ scopes.
struct Scope {
~Scope() { emitter.labelInScopeCount.pop(); }

private:
llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;

protected:
Scope(CppEmitter &emitter)
: valueMapperScope(emitter.valueMapper),
blockMapperScope(emitter.blockMapper), emitter(emitter) {
emitter.valueInScopeCount.push(emitter.valueInScopeCount.top());
emitter.labelInScopeCount.push(emitter.labelInScopeCount.top());
}
~Scope() {
emitter.valueInScopeCount.pop();
emitter.labelInScopeCount.pop();
CppEmitter &emitter;
};

/// RAII helper function to manage entering/exiting functions, while re-using
/// value names.
struct FunctionScope : Scope {
FunctionScope(CppEmitter &emitter) : Scope(emitter) {
// Re-use value names
emitter.resetValueCounter();
}
};

private:
llvm::ScopedHashTableScope<Value, std::string> valueMapperScope;
llvm::ScopedHashTableScope<Block *, std::string> blockMapperScope;
CppEmitter &emitter;
/// RAII helper function to manage entering/exiting emitc::forOp loops and
/// handle induction variable naming.
struct LoopScope : Scope {
LoopScope(CppEmitter &emitter) : Scope(emitter) {
emitter.increaseLoopNestingLevel();
}
~LoopScope() { emitter.decreaseLoopNestingLevel(); }
};

/// Returns wether the Value is assigned to a C++ variable in the scope.
Expand Down Expand Up @@ -253,6 +273,15 @@ struct CppEmitter {
return operandExpression == emittedExpression;
};

// Resets the value counter to 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Resets the value counter to 0
// Resets the value counter to 0.

void resetValueCounter();

// Increases the loop nesting level by 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Increases the loop nesting level by 1
// Increases the loop nesting level by 1.

void increaseLoopNestingLevel();

// Decreases the loop nesting level by 1
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// Decreases the loop nesting level by 1
// Decreases the loop nesting level by 1.

void decreaseLoopNestingLevel();

private:
using ValueMapper = llvm::ScopedHashTable<Value, std::string>;
using BlockMapper = llvm::ScopedHashTable<Block *, std::string>;
Expand All @@ -274,11 +303,19 @@ struct CppEmitter {
/// Map from block to name of C++ label.
BlockMapper blockMapper;

/// The number of values in the current scope. This is used to declare the
/// names of values in a scope.
std::stack<int64_t> valueInScopeCount;
/// Default values representing outermost scope
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Default values representing outermost scope
/// Default values representing outermost scope.

llvm::ScopedHashTableScope<Value, std::string> defaultValueMapperScope;
llvm::ScopedHashTableScope<Block *, std::string> defaultBlockMapperScope;

std::stack<int64_t> labelInScopeCount;

/// Keeps track of the amount of nested loops the emitter currently operates
/// in.
uint64_t loopNestingLevel{0};

/// Emitter-level count of created values to enable unique identifiers.
unsigned int valueCount{0};

/// State of the current expression being emitted.
ExpressionOp emittedExpression;
SmallVector<int> emittedExpressionPrecedence;
Expand Down Expand Up @@ -860,7 +897,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
}

static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {

raw_indented_ostream &os = emitter.ostream();

// Utility function to determine whether a value is an expression that will be
Expand All @@ -879,12 +915,12 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
emitter.emitType(forOp.getLoc(), forOp.getInductionVar().getType())))
return failure();
os << " ";
os << emitter.getOrCreateName(forOp.getInductionVar());
os << emitter.getOrCreateInductionVarName(forOp.getInductionVar());
os << " = ";
if (failed(emitter.emitOperand(forOp.getLowerBound())))
return failure();
os << "; ";
os << emitter.getOrCreateName(forOp.getInductionVar());
os << emitter.getOrCreateInductionVarName(forOp.getInductionVar());
os << " < ";
Value upperBound = forOp.getUpperBound();
bool upperBoundRequiresParentheses = requiresParentheses(upperBound);
Expand All @@ -895,13 +931,15 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::ForOp forOp) {
if (upperBoundRequiresParentheses)
os << ")";
os << "; ";
os << emitter.getOrCreateName(forOp.getInductionVar());
os << emitter.getOrCreateInductionVarName(forOp.getInductionVar());
os << " += ";
if (failed(emitter.emitOperand(forOp.getStep())))
return failure();
os << ") {\n";
os.indent();

CppEmitter::LoopScope lScope(emitter);

Region &forRegion = forOp.getRegion();
auto regionOps = forRegion.getOps();

Expand Down Expand Up @@ -988,8 +1026,6 @@ static LogicalResult printOperation(CppEmitter &emitter,
}

static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
CppEmitter::Scope scope(emitter);

for (Operation &op : moduleOp) {
if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
return failure();
Expand All @@ -1001,8 +1037,6 @@ static LogicalResult printOperation(CppEmitter &emitter, FileOp file) {
if (!emitter.shouldEmitFile(file))
return success();

CppEmitter::Scope scope(emitter);

for (Operation &op : file) {
if (failed(emitter.emitOperation(op, /*trailingSemicolon=*/false)))
return failure();
Expand Down Expand Up @@ -1118,7 +1152,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
return functionOp.emitOpError() << "cannot emit array type as result type";
}

CppEmitter::Scope scope(emitter);
CppEmitter::FunctionScope scope(emitter);
raw_indented_ostream &os = emitter.ostream();
if (failed(emitter.emitTypes(functionOp.getLoc(),
functionOp.getFunctionType().getResults())))
Expand Down Expand Up @@ -1146,7 +1180,7 @@ static LogicalResult printOperation(CppEmitter &emitter,
"with multiple blocks needs variables declared at top");
}

CppEmitter::Scope scope(emitter);
CppEmitter::FunctionScope scope(emitter);
raw_indented_ostream &os = emitter.ostream();
if (functionOp.getSpecifiers()) {
for (Attribute specifier : functionOp.getSpecifiersAttr()) {
Expand Down Expand Up @@ -1180,7 +1214,6 @@ static LogicalResult printOperation(CppEmitter &emitter,

static LogicalResult printOperation(CppEmitter &emitter,
DeclareFuncOp declareFuncOp) {
CppEmitter::Scope scope(emitter);
raw_indented_ostream &os = emitter.ostream();

auto functionOp = SymbolTable::lookupNearestSymbolFrom<emitc::FuncOp>(
Expand Down Expand Up @@ -1212,8 +1245,8 @@ static LogicalResult printOperation(CppEmitter &emitter,
CppEmitter::CppEmitter(raw_ostream &os, bool declareVariablesAtTop,
StringRef fileId)
: os(os), declareVariablesAtTop(declareVariablesAtTop),
fileId(fileId.str()) {
valueInScopeCount.push(0);
fileId(fileId.str()), defaultValueMapperScope(valueMapper),
defaultBlockMapperScope(blockMapper) {
labelInScopeCount.push(0);
}

Expand Down Expand Up @@ -1254,7 +1287,26 @@ StringRef CppEmitter::getOrCreateName(Value val) {
assert(!hasDeferredEmission(val.getDefiningOp()) &&
"cacheDeferredOpResult should have been called on this value, "
"update the emitOperation function.");
valueMapper.insert(val, formatv("v{0}", ++valueInScopeCount.top()));

valueMapper.insert(val, formatv("v{0}", ++valueCount));
}
return *valueMapper.begin(val);
}

/// Return the existing or a new name for a loop induction variable Value.
/// Loop induction variables follow natural naming: i, j, k,...
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe expand this a little on the actual naming that you're now introducing.

Suggested change
/// Loop induction variables follow natural naming: i, j, k,...
/// Loop induction variables follow natural naming: i, j, k, ..., t, u{X}

or

Suggested change
/// Loop induction variables follow natural naming: i, j, k,...
/// Loop induction variables follow natural naming: i, j, k, ..., t, uX

StringRef CppEmitter::getOrCreateInductionVarName(Value val) {
if (!valueMapper.count(val)) {

int64_t identifier = 'i' + loopNestingLevel;

if (identifier >= 'i' && identifier <= 't') {
valueMapper.insert(val,
formatv("{0}{1}", (char)identifier, ++valueCount));
} else {
// If running out of letters, continue with uX
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
// If running out of letters, continue with uX
// If running out of letters, continue with uX.

valueMapper.insert(val, formatv("u{0}", ++valueCount));
}
}
return *valueMapper.begin(val);
}
Expand Down Expand Up @@ -1793,6 +1845,12 @@ LogicalResult CppEmitter::emitTupleType(Location loc, ArrayRef<Type> types) {
return success();
}

void CppEmitter::resetValueCounter() { valueCount = 0; }

void CppEmitter::increaseLoopNestingLevel() { loopNestingLevel++; }

void CppEmitter::decreaseLoopNestingLevel() { loopNestingLevel--; }

LogicalResult emitc::translateToCpp(Operation *op, raw_ostream &os,
bool declareVariablesAtTop,
StringRef fileId) {
Expand Down
73 changes: 73 additions & 0 deletions mlir/test/Target/Cpp/for_loop_induction_vars.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s

// CHECK-LABEL: test_for_siblings
func.func @test_for_siblings() {
%start = emitc.literal "0" : index
%stop = emitc.literal "10" : index
%step = emitc.literal "1" : index

%var1 = "emitc.variable"() <{value = 0 : index}> : () -> !emitc.lvalue<index>
%var2 = "emitc.variable"() <{value = 0 : index}> : () -> !emitc.lvalue<index>

// CHECK: for (size_t [[ITER0:i[0-9]*]] = {{.*}}; [[ITER0]] < {{.*}}; [[ITER0]] += {{.*}}) {
emitc.for %i0 = %start to %stop step %step {
// CHECK: for (size_t [[ITER1:j[0-9]*]] = {{.*}}; [[ITER1]] < {{.*}}; [[ITER1]] += {{.*}}) {
emitc.for %i1 = %start to %stop step %step {
// CHECK: {{.*}} = [[ITER0]];
//"emitc.assign"(%var1,%i0) : (!emitc.lvalue<!emitc.size_t>, !emitc.size_t) -> ()
emitc.assign %i0 : index to %var1 : !emitc.lvalue<index>
// CHECK: {{.*}} = [[ITER1]];
//"emitc.assign"(%var2,%i1) : (!emitc.lvalue<!emitc.size_t>, !emitc.size_t) -> ()
emitc.assign %i1 : index to %var2 : !emitc.lvalue<index>
}
}
// CHECK: for (size_t [[ITER2:i[0-9]*]] = {{.*}}; [[ITER2]] < {{.*}}; [[ITER2]] += {{.*}})
emitc.for %ki2 = %start to %stop step %step {
// CHECK: for (size_t [[ITER3:j[0-9]*]] = {{.*}}; [[ITER3]] < {{.*}}; [[ITER3]] += {{.*}})
emitc.for %i3 = %start to %stop step %step {
%1 = emitc.call_opaque "f"() : () -> i32
}
}
return
}

// CHECK-LABEL: test_for_nesting
func.func @test_for_nesting() {
%start = emitc.literal "0" : index
%stop = emitc.literal "10" : index
%step = emitc.literal "1" : index

// CHECK-COUNT-12: for (size_t [[ITER:[i-t][0-9]*]] = {{.*}}; [[ITER]] < {{.*}}; [[ITER]] += {{.*}}) {
emitc.for %i0 = %start to %stop step %step {
emitc.for %i1 = %start to %stop step %step {
emitc.for %i2 = %start to %stop step %step {
emitc.for %i3 = %start to %stop step %step {
emitc.for %i4 = %start to %stop step %step {
emitc.for %i5 = %start to %stop step %step {
emitc.for %i6 = %start to %stop step %step {
emitc.for %i7 = %start to %stop step %step {
emitc.for %i8 = %start to %stop step %step {
emitc.for %i9 = %start to %stop step %step {
emitc.for %i10 = %start to %stop step %step {
emitc.for %i11 = %start to %stop step %step {
// CHECK: for (size_t [[ITERu0:u13]] = {{.*}}; [[ITERu0]] < {{.*}}; [[ITERu0]] += {{.*}}) {
emitc.for %i14 = %start to %stop step %step {
// CHECK: for (size_t [[ITERu1:u14]] = {{.*}}; [[ITERu1]] < {{.*}}; [[ITERu1]] += {{.*}}) {
emitc.for %i15 = %start to %stop step %step {
%0 = emitc.call_opaque "f"() : () -> i32
}
}
}
}
}
}
}
}
}
}
}
}
}
}
return
}
Loading