Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[LoopScheduleToCalyx] deduplicate groups within a ParOp. #8055

Merged
merged 13 commits into from
Jan 9, 2025
17 changes: 17 additions & 0 deletions include/circt/Dialect/Calyx/CalyxLoweringUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,23 @@ struct EliminateUnusedCombGroups : mlir::OpRewritePattern<calyx::CombGroupOp> {
PatternRewriter &rewriter) const override;
};

/// Rewrite pattern rooted at calyx::ParOp that removes duplicate EnableOps
/// from the matched parallel operation.
struct DeduplicateParallelOp : public mlir::OpRewritePattern<calyx::ParOp> {
  // Inherit the constructors of the base rewrite pattern.
  using mlir::OpRewritePattern<calyx::ParOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(calyx::ParOp parOp,
                                PatternRewriter &rewriter) const override;
};

/// Rewrite pattern rooted at calyx::StaticParOp that removes duplicate
/// EnableOps from the matched static parallel operation.
struct DeduplicateStaticParallelOp
    : public mlir::OpRewritePattern<calyx::StaticParOp> {
  // Inherit the constructors of the base rewrite pattern.
  using mlir::OpRewritePattern<calyx::StaticParOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(calyx::StaticParOp parOp,
                                PatternRewriter &rewriter) const override;
};

/// This pass recursively inlines use-def chains of combinational logic (from
/// non-stateful groups) into groups referenced in the control schedule.
class InlineCombGroups
Expand Down
57 changes: 38 additions & 19 deletions lib/Conversion/LoopScheduleToCalyx/LoopScheduleToCalyx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

#include <type_traits>
#include <variant>

namespace circt {
Expand Down Expand Up @@ -126,6 +127,19 @@ class PipelineScheduler : public calyx::SchedulerInterface<Scheduleable> {
return pipelineRegs[stage];
}

/// Returns the pipeline register for this value if its defining operation is
/// a stage, and std::nullopt otherwise. Values without a defining operation
/// (e.g. block arguments) and stage results that have no register recorded
/// for their result number also yield std::nullopt.
std::optional<calyx::RegisterOp> getPipelineRegister(Value value) {
  // Value::getDefiningOp<OpTy>() tolerates a missing defining operation,
  // whereas dyn_cast on the raw (possibly null) Operation* does not.
  auto opStage = value.getDefiningOp<LoopSchedulePipelineStageOp>();
  if (!opStage)
    return std::nullopt;
  // The stage's registers are looked up by the result number of the value.
  unsigned opNumber = cast<OpResult>(value).getResultNumber();
  auto &stageRegisters = getPipelineRegs(opStage);
  auto registerIt = stageRegisters.find(opNumber);
  // Never dereference the end iterator when no register was recorded.
  if (registerIt == stageRegisters.end())
    return std::nullopt;
  return registerIt->second;
}

/// Add a stage's groups to the pipeline prologue.
void addPipelinePrologue(Operation *op, SmallVector<StringAttr> groupNames) {
pipelinePrologue[op].push_back(groupNames);
Expand Down Expand Up @@ -306,9 +320,14 @@ class BuildOpGroups : public calyx::FuncOpPartialLoweringPattern {
/// Create assignments to the inputs of the library op.
auto group = createGroupForOp<TGroupOp>(rewriter, op);
rewriter.setInsertionPointToEnd(group.getBodyBlock());
for (auto dstOp : enumerate(opInputPorts))
rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(),
op->getOperand(dstOp.index()));
for (auto dstOp : enumerate(opInputPorts)) {
Value srcOp = op->getOperand(dstOp.index());
std::optional<calyx::RegisterOp> pipelineRegister =
getState<ComponentLoweringState>().getPipelineRegister(srcOp);
if (pipelineRegister.has_value())
srcOp = pipelineRegister->getOut();
rewriter.create<calyx::AssignOp>(op.getLoc(), dstOp.value(), srcOp);
}

/// Replace the result values of the source operator with the new operator.
for (auto res : enumerate(opOutputPorts)) {
Expand Down Expand Up @@ -1055,22 +1074,17 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {
Value value = operand.get();

// Get the pipeline register for that result.
auto pipelineRegister = pipelineRegisters[i];
calyx::RegisterOp pipelineRegister = pipelineRegisters[i];
if (std::optional<calyx::RegisterOp> pr =
state.getPipelineRegister(value)) {
value = pr->getOut();
}

calyx::GroupOp group;
// Get the evaluating group for that value.
std::optional<calyx::GroupInterface> evaluatingGroup =
state.findEvaluatingGroup(value);
if (!evaluatingGroup.has_value()) {
if (auto opStage =
dyn_cast<LoopSchedulePipelineStageOp>(value.getDefiningOp())) {
// The pipeline register for this input value needs to be discovered.
auto opResult = cast<OpResult>(value);
unsigned int opNumber = opResult.getResultNumber();
auto &stageRegisters = state.getPipelineRegs(opStage);
calyx::RegisterOp opRegister = stageRegisters.find(opNumber)->second;
value = opRegister.getOut(); // Pass the `out` wire of this register.
}
if (value.getDefiningOp<calyx::RegisterOp>() == nullptr) {
// We add this for any unhandled cases.
llvm::errs() << "unexpected: input value: " << value << ", in stage "
Expand Down Expand Up @@ -1166,8 +1180,9 @@ class BuildPipelineGroups : public calyx::FuncOpPartialLoweringPattern {
}
doneOp.getSrcMutable().assign(pipelineRegister.getDone());

// Remove the old register completely.
rewriter.eraseOp(tempReg);
// Remove the old register if it has no more uses.
if (tempReg->use_empty())
rewriter.eraseOp(tempReg);

return group;
}
Expand Down Expand Up @@ -1534,10 +1549,11 @@ class LoopScheduleToCalyxPass
if (runOnce)
config.maxIterations = 1;

/// Can't return applyPatternsGreedily. Root isn't
/// Can't return applyPatternsAndFoldGreedily. Root isn't
/// necessarily erased so it will always return failed(). Instead,
/// forward the 'succeeded' value from PartialLoweringPatternBase.
(void)applyPatternsGreedily(getOperation(), std::move(pattern), config);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(pattern),
config);
return partialPatternRes;
}

Expand Down Expand Up @@ -1628,6 +1644,9 @@ void LoopScheduleToCalyxPass::runOnOperation() {
addOncePattern<calyx::InlineCombGroups>(loweringPatterns, patternState,
*loweringState);

addGreedyPattern<calyx::DeduplicateParallelOp>(loweringPatterns);
addGreedyPattern<calyx::DeduplicateStaticParallelOp>(loweringPatterns);

/// This pattern performs various SSA replacements that must be done
/// after control generation.
addOncePattern<LateSSAReplacement>(loweringPatterns, patternState, funcMap,
Expand Down Expand Up @@ -1665,8 +1684,8 @@ void LoopScheduleToCalyxPass::runOnOperation() {
RewritePatternSet cleanupPatterns(&getContext());
cleanupPatterns.add<calyx::MultipleGroupDonePattern,
calyx::NonTerminatingGroupDonePattern>(&getContext());
if (failed(
applyPatternsGreedily(getOperation(), std::move(cleanupPatterns)))) {
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(cleanupPatterns)))) {
signalPassFailure();
return;
}
Expand Down
38 changes: 38 additions & 0 deletions lib/Dialect/Calyx/Transforms/CalyxLoweringUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,28 @@ using namespace mlir::arith;
namespace circt {
namespace calyx {

/// Erases every EnableOp in the body of `parOp` whose group name was already
/// enabled by an earlier EnableOp in the same body. Operations other than
/// EnableOp are left untouched. Returns success only if at least one
/// duplicate was erased.
template <typename OpTy>
static LogicalResult deduplicateParallelOperation(OpTy parOp,
                                                  PatternRewriter &rewriter) {
  auto *body = parOp.getBodyBlock();
  // A body with fewer than two operations cannot contain a duplicate.
  if (body->getOperations().size() < 2)
    return failure();

  bool erasedAny = false;
  SetVector<StringRef> seenGroups;
  // Early-increment iteration keeps the traversal valid while the current
  // operation is being erased.
  for (auto &op : make_early_inc_range(*body)) {
    if (auto enableOp = dyn_cast<EnableOp>(&op)) {
      // First enable of this group: remember it and keep the operation.
      if (seenGroups.insert(enableOp.getGroupName()))
        continue;
      rewriter.eraseOp(enableOp);
      erasedAny = true;
    }
  }
  return success(erasedAny);
}

void appendPortsForExternalMemref(PatternRewriter &rewriter, StringRef memName,
Value memref, unsigned memoryID,
SmallVectorImpl<calyx::PortInfo> &inPorts,
Expand Down Expand Up @@ -609,6 +631,22 @@ EliminateUnusedCombGroups::matchAndRewrite(calyx::CombGroupOp combGroupOp,
return success();
}

//===----------------------------------------------------------------------===//
// DeduplicateParallelOperations
//===----------------------------------------------------------------------===//

/// Forwards the matched calyx::ParOp to the shared deduplication helper and
/// returns its result unchanged.
LogicalResult
DeduplicateParallelOp::matchAndRewrite(calyx::ParOp parOp,
                                       PatternRewriter &rewriter) const {
  // The helper's operation type is deduced from `parOp`.
  return deduplicateParallelOperation(parOp, rewriter);
}

/// Forwards the matched calyx::StaticParOp to the shared deduplication helper
/// and returns its result unchanged.
LogicalResult
DeduplicateStaticParallelOp::matchAndRewrite(calyx::StaticParOp parOp,
                                             PatternRewriter &rewriter) const {
  // The helper's operation type is deduced from `parOp`.
  return deduplicateParallelOperation(parOp, rewriter);
}

//===----------------------------------------------------------------------===//
// InlineCombGroups
//===----------------------------------------------------------------------===//
Expand Down
62 changes: 62 additions & 0 deletions test/Conversion/LoopScheduleToCalyx/pipeline_register_pass.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// RUN: circt-opt %s -lower-loopschedule-to-calyx -canonicalize -split-input-file | FileCheck %s

// This will introduce duplicate groups; these should be subsequently removed during canonicalization.

// CHECK: calyx.while %std_lt_0.out with @bb0_0 {
// CHECK-NEXT: calyx.par {
// CHECK-NEXT: calyx.enable @bb0_1
// CHECK-NEXT: }
// CHECK-NEXT: }
// Test input: a two-stage pipeline in which the second stage performs no
// computation of its own and only registers the result of the first stage.
module {
func.func @foo() attributes {} {
%const = arith.constant 1 : index
loopschedule.pipeline II = 1 trip_count = 20 iter_args(%counter = %const) : (index) -> () {
%latch = arith.cmpi ult, %counter, %const : index
loopschedule.register %latch : i1
} do {
// Stage 0 increments the iteration argument and registers the sum.
%S0 = loopschedule.pipeline.stage start = 0 {
%op = arith.addi %counter, %const : index
loopschedule.register %op : index
} : index
// Stage 1 registers the stage 0 result directly, with no operation in between.
%S1 = loopschedule.pipeline.stage start = 1 {
loopschedule.register %S0: index
} : index
loopschedule.terminator iter_args(%S0), results() : (index) -> ()
}
return
}
}

// -----

// Stage pipeline registers passed directly to the next stage
// should also be updated when used in computations.

// CHECK: calyx.group @bb0_2 {
// CHECK-NEXT: calyx.assign %std_add_1.left = %while_0_arg0_reg.out : i32
// CHECK-NEXT: calyx.assign %std_add_1.right = %c1_i32 : i32
// CHECK-NEXT: calyx.assign %stage_1_register_0_reg.in = %std_add_1.out : i32
// CHECK-NEXT: calyx.assign %stage_1_register_0_reg.write_en = %true : i1
// CHECK-NEXT: calyx.group_done %stage_1_register_0_reg.done : i1
// CHECK-NEXT: }
// Test input: a two-stage pipeline in which the second stage uses the result
// of the first stage as an operand of an addition.
module {
func.func @foo() attributes {} {
%const = arith.constant 1 : index
loopschedule.pipeline II = 1 trip_count = 20 iter_args(%counter = %const) : (index) -> () {
%latch = arith.cmpi ult, %counter, %const : index
loopschedule.register %latch : i1
} do {
// Stage 0 increments the iteration argument and registers the sum.
%S0 = loopschedule.pipeline.stage start = 0 {
%op = arith.addi %counter, %const : index
loopschedule.register %op : index
} : index
// Stage 1 adds the constant to the stage 0 result and registers the sum.
%S1 = loopschedule.pipeline.stage start = 1 {
%math = arith.addi %S0, %const : index
loopschedule.register %math : index
} : index
loopschedule.terminator iter_args(%S0), results() : (index) -> ()
}
return
}
}

Loading