Skip to content

Commit 995d74f

Browse files
authored
[CIR] Implement folder for VecTernaryOp (#142946)
This change adds a folder for the VecTernaryOp Issue #136487
1 parent bf51d58 commit 995d74f

File tree

4 files changed

+58
-3
lines changed

4 files changed

+58
-3
lines changed

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2270,7 +2270,9 @@ def VecTernaryOp : CIR_Op<"vec.ternary",
22702270
`(` $cond `,` $lhs`,` $rhs `)` `:` qualified(type($cond)) `,`
22712271
qualified(type($lhs)) attr-dict
22722272
}];
2273+
22732274
let hasVerifier = 1;
2275+
let hasFolder = 1;
22742276
}
22752277

22762278
#endif // CLANG_CIR_DIALECT_IR_CIROPS_TD

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1664,6 +1664,38 @@ LogicalResult cir::VecTernaryOp::verify() {
16641664
return success();
16651665
}
16661666

1667+
OpFoldResult cir::VecTernaryOp::fold(FoldAdaptor adaptor) {
1668+
mlir::Attribute cond = adaptor.getCond();
1669+
mlir::Attribute lhs = adaptor.getLhs();
1670+
mlir::Attribute rhs = adaptor.getRhs();
1671+
1672+
if (!mlir::isa_and_nonnull<cir::ConstVectorAttr>(cond) ||
1673+
!mlir::isa_and_nonnull<cir::ConstVectorAttr>(lhs) ||
1674+
!mlir::isa_and_nonnull<cir::ConstVectorAttr>(rhs))
1675+
return {};
1676+
auto condVec = mlir::cast<cir::ConstVectorAttr>(cond);
1677+
auto lhsVec = mlir::cast<cir::ConstVectorAttr>(lhs);
1678+
auto rhsVec = mlir::cast<cir::ConstVectorAttr>(rhs);
1679+
1680+
mlir::ArrayAttr condElts = condVec.getElts();
1681+
1682+
SmallVector<mlir::Attribute, 16> elements;
1683+
elements.reserve(condElts.size());
1684+
1685+
for (const auto &[idx, condAttr] :
1686+
llvm::enumerate(condElts.getAsRange<cir::IntAttr>())) {
1687+
if (condAttr.getSInt()) {
1688+
elements.push_back(lhsVec.getElts()[idx]);
1689+
} else {
1690+
elements.push_back(rhsVec.getElts()[idx]);
1691+
}
1692+
}
1693+
1694+
cir::VectorType vecTy = getLhs().getType();
1695+
return cir::ConstVectorAttr::get(
1696+
vecTy, mlir::ArrayAttr::get(getContext(), elements));
1697+
}
1698+
16671699
//===----------------------------------------------------------------------===//
16681700
// TableGen'd op method definitions
16691701
//===----------------------------------------------------------------------===//

clang/lib/CIR/Dialect/Transforms/CIRCanonicalize.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,10 +138,11 @@ void CIRCanonicalizePass::runOnOperation() {
138138
assert(!cir::MissingFeatures::complexRealOp());
139139
assert(!cir::MissingFeatures::complexImagOp());
140140
assert(!cir::MissingFeatures::callOp());
141-
// CastOp, UnaryOp, VecExtractOp and VecShuffleDynamicOp are here to perform
142-
// a manual `fold` in applyOpPatternsGreedily.
141+
142+
// Many operations are here to perform a manual `fold` in
143+
// applyOpPatternsGreedily.
143144
if (isa<BrOp, BrCondOp, CastOp, ScopeOp, SwitchOp, SelectOp, UnaryOp,
144-
VecExtractOp, VecShuffleDynamicOp>(op))
145+
VecExtractOp, VecShuffleDynamicOp, VecTernaryOp>(op))
145146
ops.push_back(op);
146147
});
147148

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
// RUN: cir-opt %s -cir-canonicalize -o - | FileCheck %s
2+
3+
!s32i = !cir.int<s, 32>
4+
5+
module {
6+
cir.func @vector_ternary_fold_test() -> !cir.vector<4 x !s32i> {
7+
%cond = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<0> : !s32i, #cir.int<1> : !s32i, #cir.int<0> : !s32i]> : !cir.vector<4 x !s32i>
8+
%lhs = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<2> : !s32i, #cir.int<3> : !s32i, #cir.int<4> : !s32i]> : !cir.vector<4 x !s32i>
9+
%rhs = cir.const #cir.const_vector<[#cir.int<5> : !s32i, #cir.int<6> : !s32i, #cir.int<7> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
10+
%res = cir.vec.ternary(%cond, %lhs, %rhs) : !cir.vector<4 x !s32i>, !cir.vector<4 x !s32i>
11+
cir.return %res : !cir.vector<4 x !s32i>
12+
}
13+
14+
// [1, 0, 1, 0] ? [1, 2, 3, 4] : [5, 6, 7, 8] Will be fold to [1, 6, 3, 8]
15+
// CHECK: cir.func @vector_ternary_fold_test() -> !cir.vector<4 x !s32i> {
16+
// CHECK-NEXT: %[[RES:.*]] = cir.const #cir.const_vector<[#cir.int<1> : !s32i, #cir.int<6> : !s32i, #cir.int<3> : !s32i, #cir.int<8> : !s32i]> : !cir.vector<4 x !s32i>
17+
// CHECK-NEXT: cir.return %[[RES]] : !cir.vector<4 x !s32i>
18+
}
19+
20+

0 commit comments

Comments
 (0)