Skip to content

When replacing a register with its reset value, attempt width coercion #8379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 138 additions & 3 deletions lib/Dialect/FIRRTL/FIRRTLFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,140 @@ static APInt getMaxSignedValue(unsigned bitWidth) {
return bitWidth > 0 ? APInt::getSignedMaxValue(bitWidth) : APInt();
}

// NOLINTNEXTLINE(misc-no-recursion)
static Value coerceSource(PatternRewriter &rewriter, Location &loc,
FIRRTLBaseType targetType, FIRRTLBaseType sourceType,
Value source) {
if (sourceType == targetType)
return source;

auto srcType = sourceType.getAnonymousType();
auto tgtType = targetType.getAnonymousType();
if (srcType == tgtType)
return source;

auto srcBundleType = dyn_cast<BundleType>(srcType);
auto tgtBundleType = dyn_cast<BundleType>(tgtType);
if (srcBundleType && tgtBundleType) {
auto n = tgtBundleType.getNumElements();
SmallVector<Value> elems;
elems.reserve(n);
for (unsigned i = 0; i < n; ++i) {
auto srcElemType = srcBundleType.getElementType(i);
auto tgtElemType = tgtBundleType.getElementType(i);
auto srcElem = rewriter.create<SubfieldOp>(loc, source, i);
auto elem =
coerceSource(rewriter, loc, tgtElemType, srcElemType, srcElem);
elems.push_back(elem);
}
return rewriter.create<BundleCreateOp>(loc, tgtBundleType, elems);
}

auto srcVectorType = dyn_cast<FVectorType>(srcType);
auto tgtVectorType = dyn_cast<FVectorType>(tgtType);
if (srcVectorType && tgtVectorType) {
auto srcElemType = srcVectorType.getElementType();
auto tgtElemType = tgtVectorType.getElementType();
auto n = tgtVectorType.getNumElements();
SmallVector<Value> elems;
elems.reserve(n);
for (unsigned i = 0; i < n; ++i) {
auto srcElem = rewriter.create<SubindexOp>(loc, source, i);
auto elem =
coerceSource(rewriter, loc, tgtElemType, srcElemType, srcElem);
elems.push_back(elem);
}
return rewriter.create<VectorCreateOp>(loc, tgtVectorType, elems);
}

auto srcIntType = dyn_cast<IntType>(srcType);
auto tgtIntType = dyn_cast<IntType>(tgtType);
if (srcIntType && tgtIntType) {
auto srcWidth = srcIntType.getBitWidthOrSentinel();
auto tgtWidth = tgtIntType.getBitWidthOrSentinel();
if (tgtWidth < srcWidth) {
auto delta = srcWidth - tgtWidth;
Value value = rewriter.create<TailPrimOp>(loc, source, delta);
if (tgtIntType.isSigned())
value = rewriter.create<AsSIntPrimOp>(loc, value);
return value;
}

if (tgtWidth > srcWidth)
source = rewriter.create<PadPrimOp>(loc, source, tgtWidth);
if (tgtIntType.isSigned() && !srcIntType.isSigned())
return rewriter.create<AsSIntPrimOp>(loc, source);
if (!tgtIntType.isSigned() && srcIntType.isSigned())
return rewriter.create<AsUIntPrimOp>(loc, source);
return source;
}

return nullptr;
}

/// Emit a coercion from a value to a target type. Returns nullptr if the
/// coercion is not possible. The resulting value is a non-aliasing source
/// value. As such, we can only emit coercions for passive types.
static Value coerceSource(PatternRewriter &rewriter, Location loc,
Type targetType, Value source) {
Type sourceType = source.getType();

// If the types are syntactically equal, no action is needed.
if (sourceType == targetType)
return source;

// If either of the types are not FIRRTL base types, we cannot coerce.
auto sourceFType = type_cast<FIRRTLBaseType>(sourceType);
auto targetFType = type_cast<FIRRTLBaseType>(targetType);
if (!sourceFType || !targetFType)
return nullptr;

// After type_cast resolves type-aliases, the underlying types may be the
// same. If they are, no action is needed.
if (sourceFType == targetFType)
return source;

// One last shot at avoiding coercion: recursively unfold type-aliases and
// check again for syntactic equality. If they are, no action is needed.
if (sourceFType.getAnonymousType() == targetFType.getAnonymousType())
return source;

// OK, some coercion is necessary. Check if it's possible.

// Give up if either side contains const. Eventually, const will be removed
// from the compiler.
if (sourceFType.containsConst() || targetFType.containsConst())
return nullptr;

// We can only coerce when all the involved widths are known. We can usually
// truncate or extend the source value to match the destination, but if either
// src or dst has an uninferred width, we don't know which way to go.
if (sourceFType.hasUninferredWidth() || targetFType.hasUninferredWidth())
return nullptr;

// Similar story for resets...
if (sourceFType.hasUninferredReset() || targetFType.hasUninferredReset())
return nullptr;

// Give up if the target is not passive. If we have to coerce the source
// value, the coercion ops will produce a nonaliasing source value, which
// prevents us from properly coercing to a correct non-passive value.
if (!targetFType.isPassive() || targetFType.containsAnalog())
return nullptr;

// After the earlier recursive checks, we can defer to equivalence checking.
if (!areTypesEquivalent(targetFType, sourceFType))
return nullptr;

auto result = coerceSource(rewriter, loc, targetFType, sourceFType, source);

// Final sanity check: ensure the result will make matchingconnect happy.
if (result)
assert(areAnonymousTypesEquivalent(targetType, result.getType()));

return result;
}

//===----------------------------------------------------------------------===//
// Fold Hooks
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2269,14 +2403,15 @@ canonicalizeRegResetWithOneReset(RegResetOp reg, PatternRewriter &rewriter) {
if (!isDefinedByOneConstantOp(reg.getResetSignal()))
return failure();

auto resetValue = reg.getResetValue();
if (reg.getType(0) != resetValue.getType())
auto value =
coerceSource(rewriter, reg.getLoc(), reg.getType(0), reg.getResetValue());
if (!value)
return failure();

// Ignore 'passthrough'.
(void)dropWrite(rewriter, reg->getResult(0), {});
replaceOpWithNewOpAndCopyName<NodeOp>(
rewriter, reg, resetValue, reg.getNameAttr(), reg.getNameKind(),
rewriter, reg, value, reg.getNameAttr(), reg.getNameKind(),
reg.getAnnotationsAttr(), reg.getInnerSymAttr(), reg.getForceable());
return success();
}
Expand Down
68 changes: 58 additions & 10 deletions test/Dialect/FIRRTL/canonicalization.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2294,17 +2294,65 @@ firrtl.module @ForceableRegResetToNode(in %clock: !firrtl.clock, in %dummy : !fi
}

// https://github.com/llvm/circt/issues/8348
// CHECK-LABEL: firrtl.module @RegResetInvalidResetValueType
// We cannot replace a regreset with its reset value, when the reset value's type does not match.
firrtl.module @RegResetInvalidResetValueType(in %c : !firrtl.clock, out %out : !firrtl.uint<2>) {
%c0_ui1 = firrtl.constant 0 : !firrtl.uint<1>
%c0_ui2 = firrtl.constant 0 : !firrtl.uint<2>
// CHECK-LABEL: firrtl.module @RegResetCoerceIntResetValue
// When we replace a regreset with its reset value, we must ensure the reset-value is the correct type.
firrtl.module @RegResetCoerceIntResetValue(in %c : !firrtl.clock,
out %out_si1 : !firrtl.sint<1>, out %out_si2 : !firrtl.sint<2>,
out %out_ui1 : !firrtl.uint<1>, out %out_ui2 : !firrtl.uint<2>
) {
%c1_asyncreset = firrtl.specialconstant 1 : !firrtl.asyncreset

%c1_si1 = firrtl.constant 1 : !firrtl.sint<1>
%c1_si2 = firrtl.constant 1 : !firrtl.sint<2>

%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
%c1_ui2 = firrtl.constant 1 : !firrtl.uint<2>

// SInt Extension.
// CHECK: firrtl.matchingconnect %out_si2, %c-1_si2 : !firrtl.sint<2>
%reg_si2 = firrtl.regreset %c, %c1_asyncreset, %c1_si1 : !firrtl.clock, !firrtl.asyncreset, !firrtl.sint<1>, !firrtl.sint<2>
firrtl.matchingconnect %out_si2, %reg_si2 : !firrtl.sint<2>

// SInt Truncation.
// CHECK: firrtl.matchingconnect %out_si1, %c-1_si1 : !firrtl.sint<1>
%reg_si1 = firrtl.regreset %c, %c1_asyncreset, %c1_si2 : !firrtl.clock, !firrtl.asyncreset, !firrtl.sint<2>, !firrtl.sint<1>
firrtl.matchingconnect %out_si1, %reg_si1 : !firrtl.sint<1>

// UInt Extension.
// CHECK: firrtl.matchingconnect %out_ui2, %c1_ui2 : !firrtl.uint<2>
%reg_ui2 = firrtl.regreset %c, %c1_asyncreset, %c1_ui1 : !firrtl.clock, !firrtl.asyncreset, !firrtl.uint<1>, !firrtl.uint<2>
firrtl.matchingconnect %out_ui2, %reg_ui2 : !firrtl.uint<2>

// UInt Truncation.
// CHECK: firrtl.matchingconnect %out_ui1, %c1_ui1 : !firrtl.uint<1>
%reg_ui1 = firrtl.regreset %c, %c1_asyncreset, %c1_ui2 : !firrtl.clock, !firrtl.asyncreset, !firrtl.uint<2>, !firrtl.uint<1>
firrtl.matchingconnect %out_ui1, %reg_ui1 : !firrtl.uint<1>
}

// CHECK-LABEL: firrtl.module @RegResetCoerceBundleResetValue
firrtl.module @RegResetCoerceBundleResetValue(in %c : !firrtl.clock, out %out : !firrtl.bundle<a: sint<2>, b: sint<1>>) {
// CHECK: %0 = firrtl.aggregateconstant [-1 : si2, -1 : si1] : !firrtl.bundle<a: sint<2>, b: sint<1>>
// CHECK: firrtl.matchingconnect %out, %0 : !firrtl.bundle<a: sint<2>, b: sint<1>>
%c1_asyncreset = firrtl.specialconstant 1 : !firrtl.asyncreset
%v = firrtl.aggregateconstant [-1 : si1, 1 : si2] : !firrtl.bundle<a: sint<1>, b: sint<2>>
%r = firrtl.regreset %c, %c1_asyncreset, %v :
!firrtl.clock, !firrtl.asyncreset,
!firrtl.bundle<a: sint<1>, b: sint<2>>,
!firrtl.bundle<a: sint<2>, b: sint<1>>
firrtl.matchingconnect %out, %r : !firrtl.bundle<a: sint<2>, b: sint<1>>
}

// CHECK-LABEL: firrtl.module @RegResetCoerceVectorResetValue
firrtl.module @RegResetCoerceVectorResetValue(in %c : !firrtl.clock, out %out : !firrtl.vector<sint<2>, 1>) {
// CHECK: %0 = firrtl.aggregateconstant [-1 : si2] : !firrtl.vector<sint<2>, 1>
// CHECK: firrtl.matchingconnect %out, %0 : !firrtl.vector<sint<2>, 1>
%c1_asyncreset = firrtl.specialconstant 1 : !firrtl.asyncreset
// CHECK: %reg = firrtl.regreset
%reg = firrtl.regreset %c, %c1_asyncreset, %c0_ui1 : !firrtl.clock, !firrtl.asyncreset, !firrtl.uint<1>, !firrtl.uint<2>
// CHECK: firrtl.matchingconnect %out, %reg : !firrtl.uint<2>
firrtl.matchingconnect %out, %reg : !firrtl.uint<2>
firrtl.matchingconnect %reg, %c0_ui2 : !firrtl.uint<2>
%v = firrtl.aggregateconstant [-1 : si1] : !firrtl.vector<sint<1>, 1>
%r = firrtl.regreset %c, %c1_asyncreset, %v :
!firrtl.clock, !firrtl.asyncreset,
!firrtl.vector<sint<1>, 1>,
!firrtl.vector<sint<2>, 1>
firrtl.matchingconnect %out, %r : !firrtl.vector<sint<2>, 1>
}

// https://github.com/llvm/circt/issues/929
Expand Down