Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ROO-70] [mlir] [arith] emulate wide int #3

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 122 additions & 1 deletion mlir/lib/Dialect/Arith/Transforms/EmulateWideInt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,126 @@ struct ConvertUIToFP final : OpConversionPattern<arith::UIToFPOp> {
}
};

//===----------------------------------------------------------------------===//
// ConvertFPToSI
//===----------------------------------------------------------------------===//

struct ConvertFPToSI final : OpConversionPattern<arith::FPToSIOp> {
  using OpConversionPattern::OpConversionPattern;

  // Emulates a wide arith.fptosi by reducing it to the (already emulated)
  // arith.fptoui: convert |fp| unsigned, then negate the integer result when
  // the input was negative. Whenever minSInt < fp < maxSInt, i.e. the float
  // is representable in the signed i2N result, this produces the correct
  // value; otherwise the result is UB, matching arith.fptosi semantics.
  LogicalResult
  matchAndRewrite(arith::FPToSIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // The incoming float operand and its (scalar or vector) element type.
    Value operand = adaptor.getIn();
    Type operandTy = operand.getType();
    Type operandElemTy = getElementTypeOrSelf(operandTy);

    Type resultTy = op.getType();
    unsigned wideBitWidth =
        getElementTypeOrSelf(resultTy).getIntOrFloatBitWidth();

    auto emulatedTy = getTypeConverter()->convertType<VectorType>(resultTy);
    if (!emulatedTy)
      return rewriter.notifyMatchFailure(
          loc, llvm::formatv("unsupported type: {0}", resultTy));

    // Floating-point zero used for the sign test (splatted for vectors).
    TypedAttr zeroAttr = rewriter.getFloatAttr(operandElemTy, 0.0);
    if (auto operandVecTy = dyn_cast<VectorType>(operandTy))
      zeroAttr = SplatElementsAttr::get(operandVecTy, zeroAttr);
    Value zero = rewriter.create<arith::ConstantOp>(loc, zeroAttr);

    // Integer constants for the two's-complement negation below.
    Value one = createScalarOrSplatConstant(rewriter, loc, resultTy, 1);
    Value allOnes = createScalarOrSplatConstant(
        rewriter, loc, resultTy, APInt::getAllOnes(wideBitWidth));

    // |fp|: select between fp and -fp based on the sign.
    Value isNegative = rewriter.create<arith::CmpFOp>(
        loc, arith::CmpFPredicate::OLT, operand, zero);
    Value negated = rewriter.create<arith::NegFOp>(loc, operand);
    Value absValue =
        rewriter.create<arith::SelectOp>(loc, isNegative, negated, operand);

    // Defer the conversion of the absolute value to the fptoui emulation.
    Value unsignedRes =
        rewriter.create<arith::FPToUIOp>(loc, resultTy, absValue);

    // Two's-complement negation of the unsigned result: -v == ~v + 1.
    Value inverted = rewriter.create<arith::XOrIOp>(loc, unsignedRes, allOnes);
    Value negatedRes = rewriter.create<arith::AddIOp>(loc, inverted, one);

    rewriter.replaceOpWithNewOp<arith::SelectOp>(op, isNegative, negatedRes,
                                                 unsignedRes);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertFPToUI
//===----------------------------------------------------------------------===//

struct ConvertFPToUI final : OpConversionPattern<arith::FPToUIOp> {
  using OpConversionPattern::OpConversionPattern;

  // Emulates a wide arith.fptoui by splitting the result into two halves:
  //   resHigh = fptoui(fp / 2^N)   and   resLow = fptoui(fp mod 2^N),
  // where N is the bitwidth of the narrow element type, so that the wide
  // result is 2^N * resHigh + resLow. The special cases of overflow,
  // +-inf, NaNs, and negative inputs are UB (same as arith.fptoui).
  LogicalResult
  matchAndRewrite(arith::FPToUIOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    // The incoming float operand and its (scalar or vector) type.
    Value inFp = adaptor.getIn();
    Type fpTy = inFp.getType();

    Type intTy = op.getType();
    auto newTy = getTypeConverter()->convertType<VectorType>(intTy);
    if (!newTy)
      return rewriter.notifyMatchFailure(
          loc, llvm::formatv("unsupported type: {0}", intTy));
    unsigned newBitWidth = newTy.getElementTypeBitWidth();
    // Narrow integer type holding one half of the result; matches the float
    // operand's shape for vectors.
    Type newHalfType = IntegerType::get(inFp.getContext(), newBitWidth);
    if (auto vecType = dyn_cast<VectorType>(fpTy))
      newHalfType = VectorType::get(vecType.getShape(), newHalfType);

    // 2^newBitWidth as a float constant. Computed as 2 * 2^(newBitWidth - 1)
    // because `uint64_t(1) << 64` is UB when the narrow element type is i64.
    double powBitwidth =
        2.0 * static_cast<double>(uint64_t(1) << (newBitWidth - 1));
    TypedAttr powBitwidthAttr =
        FloatAttr::get(getElementTypeOrSelf(fpTy), powBitwidth);
    if (auto vecType = dyn_cast<VectorType>(fpTy))
      powBitwidthAttr = SplatElementsAttr::get(vecType, powBitwidthAttr);
    Value powBitwidthFloatCst =
        rewriter.create<arith::ConstantOp>(loc, powBitwidthAttr);

    // resHigh = fptoui(fp / 2^N).
    Value fpDivPowBitwidth =
        rewriter.create<arith::DivFOp>(loc, inFp, powBitwidthFloatCst);
    Value resHigh =
        rewriter.create<arith::FPToUIOp>(loc, newHalfType, fpDivPowBitwidth);
    // resLow = fptoui(fp - resHigh * 2^N), i.e. the remainder of the division.
    Value remainder =
        rewriter.create<arith::RemFOp>(loc, inFp, powBitwidthFloatCst);
    Value resLow =
        rewriter.create<arith::FPToUIOp>(loc, newHalfType, remainder);

    // Pack the two halves into the emulated vector result, low half first.
    auto high = appendX1Dim(rewriter, loc, resHigh);
    auto low = appendX1Dim(rewriter, loc, resLow);

    auto resultVec = constructResultVector(rewriter, loc, newTy, {low, high});

    rewriter.replaceOp(op, resultVec);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// ConvertTruncI
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1150,5 +1270,6 @@ void arith::populateArithWideIntEmulationPatterns(
ConvertIndexCastIntToIndex<arith::IndexCastUIOp>,
ConvertIndexCastIndexToInt<arith::IndexCastOp, arith::ExtSIOp>,
ConvertIndexCastIndexToInt<arith::IndexCastUIOp, arith::ExtUIOp>,
ConvertSIToFP, ConvertUIToFP>(typeConverter, patterns.getContext());
ConvertSIToFP, ConvertUIToFP, ConvertFPToUI, ConvertFPToSI>(
typeConverter, patterns.getContext());
}
124 changes: 124 additions & 0 deletions mlir/test/Dialect/Arith/emulate-wide-int.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1007,3 +1007,127 @@ func.func @sitofp_i64_f64_vector(%a : vector<3xi64>) -> vector<3xf64> {
%r = arith.sitofp %a : vector<3xi64> to vector<3xf64>
return %r : vector<3xf64>
}

// Scalar wide fptoui: the high half comes from fptoui(x / 2^32) and the low
// half from fptoui(remf(x, 2^32)); the two i32 halves are inserted into a
// vector<2xi32> with the low half at index 0.
// CHECK-LABEL: func @fptoui_i64_f64
// CHECK-SAME:    ([[ARG:%.+]]: f64) -> vector<2xi32>
// CHECK-NEXT:    [[POW:%.+]]  = arith.constant 0x41F0000000000000 : f64
// CHECK-NEXT:    [[DIV:%.+]]  = arith.divf [[ARG]], [[POW]] : f64
// CHECK-NEXT:    [[HIGHHALF:%.+]] = arith.fptoui [[DIV]] : f64 to i32
// CHECK-NEXT:    [[REM:%.+]]  = arith.remf [[ARG]], [[POW]] : f64
// CHECK-NEXT:    [[LOWHALF:%.+]] = arith.fptoui [[REM]] : f64 to i32
// CHECK:         %{{.+}} = vector.insert [[LOWHALF]], %{{.+}} [0]
// CHECK-NEXT:    [[RESVEC:%.+]] = vector.insert [[HIGHHALF]], %{{.+}} [1]
// CHECK:         return [[RESVEC]] : vector<2xi32>
func.func @fptoui_i64_f64(%a : f64) -> i64 {
    %r = arith.fptoui %a : f64 to i64
    return %r : i64
}

// Vector variant of the wide fptoui lowering: same div/rem split as the
// scalar case, with the halves reshaped to vector<3x1xi32> and combined into
// the vector<3x2xi32> result via strided-slice inserts.
// CHECK-LABEL: func @fptoui_i64_f64_vector
// CHECK-SAME:    ([[ARG:%.+]]: vector<3xf64>) -> vector<3x2xi32>
// CHECK-NEXT:    [[POW:%.+]]  = arith.constant dense<0x41F0000000000000> : vector<3xf64>
// CHECK-NEXT:    [[DIV:%.+]]  = arith.divf [[ARG]], [[POW]] : vector<3xf64>
// CHECK-NEXT:    [[HIGHHALF:%.+]] = arith.fptoui [[DIV]] : vector<3xf64> to vector<3xi32>
// CHECK-NEXT:    [[REM:%.+]]  = arith.remf [[ARG]], [[POW]] : vector<3xf64>
// CHECK-NEXT:    [[LOWHALF:%.+]] = arith.fptoui [[REM]] : vector<3xf64> to vector<3xi32>
// CHECK-DAG:     [[HIGHHALFX1:%.+]] = vector.shape_cast [[HIGHHALF]] : vector<3xi32> to vector<3x1xi32>
// CHECK-DAG:     [[LOWHALFX1:%.+]]  = vector.shape_cast [[LOWHALF]] : vector<3xi32> to vector<3x1xi32>
// CHECK:         %{{.+}} = vector.insert_strided_slice [[LOWHALFX1]], %{{.+}} {offsets = [0, 0], strides = [1, 1]}
// CHECK-NEXT:    [[RESVEC:%.+]] = vector.insert_strided_slice [[HIGHHALFX1]], %{{.+}} {offsets = [0, 1], strides = [1, 1]}
// CHECK:         return [[RESVEC]] : vector<3x2xi32>
func.func @fptoui_i64_f64_vector(%a : vector<3xf64>) -> vector<3xi64> {
    %r = arith.fptoui %a : vector<3xf64> to vector<3xi64>
    return %r : vector<3xi64>
}

// This generates lines that are already verified by other patterns
// We do not re-verify these and just check for the wrapper around fptoui by following its low part
// (fptosi is emulated as: absolute value via cmpf/negf/select, the fptoui
// lowering verified above, then a two's-complement negate -- xori + addi --
// whose result is selected in when the input was negative).
// CHECK-LABEL: func @fptosi_i64_f64
// CHECK-SAME:    ([[ARG:%.+]]: f64) -> vector<2xi32>
// CHECK:         [[ZEROCST:%.+]] = arith.constant 0.000000e+00 : f64
// CHECK:         [[ONECST:%.+]] = arith.constant dense<[1, 0]> : vector<2xi32>
// CHECK:         [[ALLONECST:%.+]] = arith.constant dense<-1> : vector<2xi32>
// CHECK-NEXT:    [[ISNEGATIVE:%.+]] = arith.cmpf olt, [[ARG]], [[ZEROCST]] : f64
// CHECK-NEXT:    [[NEGATED:%.+]] = arith.negf [[ARG]] : f64
// CHECK-NEXT:    [[ABSVALUE:%.+]] = arith.select [[ISNEGATIVE]], [[NEGATED]], [[ARG]] : f64
// CHECK-NEXT:    [[POW:%.+]]  = arith.constant 0x41F0000000000000 : f64
// CHECK-NEXT:    [[DIV:%.+]]  = arith.divf [[ABSVALUE]], [[POW]] : f64
// CHECK-NEXT:    [[HIGHHALF:%.+]] = arith.fptoui [[DIV]] : f64 to i32
// CHECK-NEXT:    [[REM:%.+]]  = arith.remf [[ABSVALUE]], [[POW]] : f64
// CHECK-NEXT:    [[LOWHALF:%.+]] = arith.fptoui [[REM]] : f64 to i32
// CHECK:         vector.insert [[LOWHALF]], %{{.+}} [0] : i32 into vector<2xi32>
// CHECK-NEXT:    [[FPTOUIRESVEC:%.+]] = vector.insert [[HIGHHALF]]
// CHECK:         [[ALLONECSTHALF:%.+]] = vector.extract [[ALLONECST]][0] : i32 from vector<2xi32>
// CHECK:         [[XOR:%.+]] = arith.xori %{{.+}}, [[ALLONECSTHALF]] : i32
// CHECK-NEXT:    arith.xori
// CHECK:         vector.insert [[XOR]]
// CHECK-NEXT:    [[XORVEC:%.+]] = vector.insert
// CHECK:         [[XOR:%.+]] = vector.extract [[XORVEC]][0] : i32 from vector<2xi32>
// CHECK:         [[ONECSTHALF:%.+]] = vector.extract [[ONECST]][0] : i32 from vector<2xi32>
// CHECK:         [[SUM:%.+]], %{{.+}} = arith.addui_extended [[XOR]], [[ONECSTHALF]] : i32, i1
// CHECK-NEXT:    arith.extui
// CHECK-NEXT:    arith.addi
// CHECK-NEXT:    arith.addi
// CHECK:         vector.insert [[SUM]]
// CHECK-NEXT:    [[SUMVEC:%.+]] = vector.insert
// CHECK:         [[NEGATEDRES:%.+]] = vector.extract [[SUMVEC]][0] : i32 from vector<2xi32>
// CHECK:         [[LOWRES:%.+]] = vector.extract [[FPTOUIRESVEC]][0] : i32 from vector<2xi32>
// CHECK:         [[ABSRES:%.+]] = arith.select [[ISNEGATIVE]], [[NEGATEDRES]], [[LOWRES]] : i32
// CHECK-NEXT:    arith.select [[ISNEGATIVE]]
// CHECK:         vector.insert [[ABSRES]]
// CHECK-NEXT:    [[ABSRESVEC:%.+]] = vector.insert
// CHECK-NEXT:    return [[ABSRESVEC]] : vector<2xi32>
func.func @fptosi_i64_f64(%a : f64) -> i64 {
    %r = arith.fptosi %a : f64 to i64
    return %r : i64
}

// Same as the non-vector one, we don't re-verify
// (vector variant: identical abs / fptoui / negate structure, but the halves
// live in vector<3x1xi32> slices manipulated with shape_cast and
// strided-slice extract/insert ops).
// CHECK-LABEL: func @fptosi_i64_f64_vector
// CHECK-SAME:   ([[ARG:%.+]]: vector<3xf64>) -> vector<3x2xi32>
// CHECK-NEXT:   [[ZEROCST:%.+]] = arith.constant dense<0.000000e+00> : vector<3xf64>
// CHECK-NEXT:   [[ONECST:%.+]] = arith.constant
// CHECK-SAME{LITERAL}            dense<[[1, 0], [1, 0], [1, 0]]> : vector<3x2xi32>
// CHECK-NEXT:   [[ALLONECST:%.+]] = arith.constant dense<-1> : vector<3x2xi32>
// CHECK-NEXT:   [[ISNEGATIVE:%.+]] = arith.cmpf olt, [[ARG]], [[ZEROCST]] : vector<3xf64>
// CHECK-NEXT:   [[NEGATED:%.+]] = arith.negf [[ARG]] : vector<3xf64>
// CHECK-NEXT:   [[ABSVALUE:%.+]] = arith.select [[ISNEGATIVE]], [[NEGATED]], [[ARG]] : vector<3xi1>, vector<3xf64>
// CHECK-NEXT:   [[POW:%.+]]  = arith.constant dense<0x41F0000000000000> : vector<3xf64>
// CHECK-NEXT:   [[DIV:%.+]]  = arith.divf [[ABSVALUE]], [[POW]] : vector<3xf64>
// CHECK-NEXT:   [[HIGHHALF:%.+]] = arith.fptoui [[DIV]] : vector<3xf64> to vector<3xi32>
// CHECK-NEXT:   [[REM:%.+]]  = arith.remf [[ABSVALUE]], [[POW]] : vector<3xf64>
// CHECK-NEXT:   [[LOWHALF:%.+]] = arith.fptoui [[REM]] : vector<3xf64> to vector<3xi32>
// CHECK-NEXT:   [[HIGHHALFX1:%.+]] = vector.shape_cast [[HIGHHALF]] : vector<3xi32> to vector<3x1xi32>
// CHECK-NEXT:   [[LOWHALFX1:%.+]]  = vector.shape_cast [[LOWHALF]] : vector<3xi32> to vector<3x1xi32>
// CHECK:        vector.insert_strided_slice [[LOWHALFX1]], %{{.+}} {offsets = [0, 0], strides = [1, 1]} : vector<3x1xi32> into vector<3x2xi32>
// CHECK-NEXT:   [[FPTOUIRESVEC:%.+]] = vector.insert_strided_slice [[HIGHHALFX1]]
// CHECK:        [[ALLONECSTHALF:%.+]] = vector.extract_strided_slice [[ALLONECST]]
// CHECK-SAME:     {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
// CHECK:        [[XOR:%.+]] = arith.xori %{{.+}}, [[ALLONECSTHALF]] : vector<3x1xi32>
// CHECK-NEXT:   arith.xori
// CHECK:        vector.insert_strided_slice [[XOR]]
// CHECK-NEXT:   [[XORVEC:%.+]] = vector.insert_strided_slice
// CHECK:        [[XOR:%.+]] = vector.extract_strided_slice [[XORVEC]]
// CHECK-SAME:     {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
// CHECK:        [[ONECSTHALF:%.+]] = vector.extract_strided_slice [[ONECST]]
// CHECK-SAME:     {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
// CHECK:        [[SUM:%.+]], %{{.+}} = arith.addui_extended [[XOR]], [[ONECSTHALF]] : vector<3x1xi32>, vector<3x1xi1>
// CHECK-NEXT:   arith.extui
// CHECK-NEXT:   arith.addi
// CHECK-NEXT:   arith.addi
// CHECK:        vector.insert_strided_slice [[SUM]]
// CHECK-NEXT:   [[SUMVEC:%.+]] = vector.insert_strided_slice
// CHECK:        [[NEGATEDRES:%.+]] = vector.extract_strided_slice [[SUMVEC]]
// CHECK-SAME:     {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
// CHECK:        [[LOWRES:%.+]] = vector.extract_strided_slice [[FPTOUIRESVEC]]
// CHECK-SAME:     {offsets = [0, 0], sizes = [3, 1], strides = [1, 1]} : vector<3x2xi32> to vector<3x1xi32>
// CHECK:        [[ISNEGATIVEX1:%.+]] = vector.shape_cast [[ISNEGATIVE]] : vector<3xi1> to vector<3x1xi1>
// CHECK:        [[ABSRES:%.+]] = arith.select [[ISNEGATIVEX1]], [[NEGATEDRES]], [[LOWRES]] : vector<3x1xi1>, vector<3x1xi32>
// CHECK-NEXT:   arith.select [[ISNEGATIVEX1]]
// CHECK:        vector.insert_strided_slice [[ABSRES]]
// CHECK-NEXT:   [[ABSRESVEC:%.+]] = vector.insert_strided_slice
// CHECK-NEXT:   return [[ABSRESVEC]] : vector<3x2xi32>
func.func @fptosi_i64_f64_vector(%a : vector<3xf64>) -> vector<3xi64> {
    %r = arith.fptosi %a : vector<3xf64> to vector<3xi64>
    return %r : vector<3xi64>
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
// Check that the wide integer `arith.fptosi` emulation produces the same result as wide
// `arith.fptosi`. Emulate i64 ops with i32 ops.

// RUN: mlir-opt %s --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: --shared-libs=%mlir_c_runner_utils | \
// RUN: FileCheck %s --match-full-lines

// RUN: mlir-opt %s --test-arith-emulate-wide-int="widest-int-supported=32" \
// RUN: --convert-scf-to-cf --convert-cf-to-llvm --convert-vector-to-llvm \
// RUN: --convert-func-to-llvm --convert-arith-to-llvm | \
// RUN: mlir-cpu-runner -e entry -entry-point-result=void \
// RUN: --shared-libs=%mlir_c_runner_utils | \
// RUN: FileCheck %s --match-full-lines

// Ops in this function *only* will be emulated using i32 types.
// Converts the f64 argument to i64; under emulation the i64 result is
// represented as vector<2xi32> internally and converted back at the boundary.
func.func @emulate_fptosi(%arg: f64) -> i64 {
  %res = arith.fptosi %arg : f64 to i64
  return %res : i64
}

// Prints the result of the (possibly emulated) fptosi so that both RUN lines
// produce output FileCheck can compare against the same CHECK values.
func.func @check_fptosi(%arg : f64) -> () {
  %res = func.call @emulate_fptosi(%arg) : (f64) -> (i64)
  vector.print %res : i64
  return
}

func.func @entry() {
  // Small values around zero, including negative zero and a fractional value
  // (fptosi truncates toward zero, so -1.5 must print -1).
  %cst0 = arith.constant 0.0 : f64
  %cst_nzero = arith.constant 0x8000000000000000 : f64
  %cst1 = arith.constant 1.0 : f64
  %cst_n1 = arith.constant -1.0 : f64
  %cst_n1_5 = arith.constant -1.5 : f64

  // Mid-range values that fit entirely in the low i32 half.
  %cstpow20 = arith.constant 1048576.0 : f64
  %cstnpow20 = arith.constant -1048576.0 : f64

  // Values straddling the 2^32 boundary, exercising the high/low-half split.
  %cst_i32_max = arith.constant 4294967295.0 : f64
  %cst_i32_min = arith.constant -4294967296.0 : f64
  %cst_i32_overflow = arith.constant 4294967296.0 : f64
  %cst_i32_noverflow = arith.constant -4294967297.0 : f64


  // Large magnitudes with a nonzero high half (2^40 and 2^40 +/- 2^20).
  %cstpow40 = arith.constant 1099511627776.0 : f64
  %cstnpow40 = arith.constant -1099511627776.0 : f64
  %cst_pow40ppow20 = arith.constant 1099512676352.0 : f64
  %cst_npow40ppow20 = arith.constant -1099512676352.0 : f64

  // CHECK:         0
  func.call @check_fptosi(%cst0) : (f64) -> ()
  // CHECK-NEXT:    0
  func.call @check_fptosi(%cst_nzero) : (f64) -> ()
  // CHECK-NEXT:    1
  func.call @check_fptosi(%cst1) : (f64) -> ()
  // CHECK-NEXT:    -1
  func.call @check_fptosi(%cst_n1) : (f64) -> ()
  // CHECK-NEXT:    -1
  func.call @check_fptosi(%cst_n1_5) : (f64) -> ()
  // CHECK-NEXT:    1048576
  func.call @check_fptosi(%cstpow20) : (f64) -> ()
  // CHECK-NEXT:    -1048576
  func.call @check_fptosi(%cstnpow20) : (f64) -> ()
  // CHECK-NEXT:    4294967295
  func.call @check_fptosi(%cst_i32_max) : (f64) -> ()
  // CHECK-NEXT:    -4294967296
  func.call @check_fptosi(%cst_i32_min) : (f64) -> ()
  // CHECK-NEXT:    4294967296
  func.call @check_fptosi(%cst_i32_overflow) : (f64) -> ()
  // CHECK-NEXT:    -4294967297
  func.call @check_fptosi(%cst_i32_noverflow) : (f64) -> ()
  // CHECK-NEXT:    1099511627776
  func.call @check_fptosi(%cstpow40) : (f64) -> ()
  // CHECK-NEXT:    -1099511627776
  func.call @check_fptosi(%cstnpow40) : (f64) -> ()
  // CHECK-NEXT:    1099512676352
  func.call @check_fptosi(%cst_pow40ppow20) : (f64) -> ()
  // CHECK-NEXT:    -1099512676352
  func.call @check_fptosi(%cst_npow40ppow20) : (f64) -> ()

  return
}
Loading