Skip to content

Commit bf51d58

Browse files
authored
[CIR] Upstream ShuffleOp for VectorType (#142288)
This change adds support for the Shuffle op for VectorType. Issue #136487
1 parent 5e9527b commit bf51d58

File tree

9 files changed

+212
-5
lines changed

9 files changed

+212
-5
lines changed

clang/include/clang/CIR/Dialect/IR/CIRAttrConstraints.td

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
21
//===----------------------------------------------------------------------===//
32
//
43
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -39,4 +38,11 @@ def CIR_AnyIntOrFloatAttr : AnyAttrOf<[CIR_AnyIntAttr, CIR_AnyFPAttr],
3938
string cppType = "::mlir::TypedAttr";
4039
}
4140

42-
#endif // CLANG_CIR_DIALECT_IR_CIRATTRCONSTRAINTS_TD
41+
//===----------------------------------------------------------------------===//
42+
// ArrayAttr constraints
43+
//===----------------------------------------------------------------------===//
44+
45+
// Array attribute constraint whose elements are each constrained by
// CIR_AnyIntAttr.
def CIR_IntArrayAttr : TypedArrayAttrBase<CIR_AnyIntAttr,
46+
"integer array attribute">;
47+
48+
#endif // CLANG_CIR_DIALECT_IR_CIRATTRCONSTRAINTS_TD

clang/include/clang/CIR/Dialect/IR/CIROps.td

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
include "clang/CIR/Dialect/IR/CIRDialect.td"
1818
include "clang/CIR/Dialect/IR/CIRTypes.td"
1919
include "clang/CIR/Dialect/IR/CIRAttrs.td"
20+
include "clang/CIR/Dialect/IR/CIRAttrConstraints.td"
2021

2122
include "clang/CIR/Interfaces/CIROpInterfaces.td"
2223
include "clang/CIR/Interfaces/CIRLoopOpInterface.td"
@@ -2155,6 +2156,52 @@ def VecCmpOp : CIR_Op<"vec.cmp", [Pure, SameTypeOperands]> {
21552156
}];
21562157
}
21572158

2159+
//===----------------------------------------------------------------------===//
2160+
// VecShuffleOp
2161+
//===----------------------------------------------------------------------===//
2162+
2163+
// TODO: Create an interface that both VecShuffleOp and VecShuffleDynamicOp
2164+
// implement. This could be useful for passes that don't care how the vector
2165+
// shuffle was specified.
2166+
2167+
// Static vector shuffle: the shuffle indices are carried as an attribute
// rather than as an SSA operand. `AllTypesMatch` forces both input vectors to
// share one type; the result type is written separately in the assembly.
def VecShuffleOp : CIR_Op<"vec.shuffle",
2168+
[Pure, AllTypesMatch<["vec1", "vec2"]>]> {
2169+
let summary = "Combine two vectors using indices passed as constant integers";
2170+
let description = [{
2171+
The `cir.vec.shuffle` operation implements the documented form of Clang's
2172+
`__builtin_shufflevector`, where the indices of the shuffled result are
2173+
integer constants.
2174+
2175+
The two input vectors, which must have the same type, are concatenated.
2176+
Each of the integer constant arguments is interpreted as an index into that
2177+
concatenated vector, with a value of -1 meaning that the result value
2178+
doesn't matter. The result vector, which must have the same element type as
2179+
the input vectors and the same number of elements as the list of integer
2180+
constant indices, is constructed by taking the elements at the given
2181+
indices from the concatenated vector. The size of the result vector does
2182+
not have to match the size of the individual input vectors or of the
2183+
concatenated vector.
2184+
2185+
```mlir
2186+
%new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<2 x !s32i>)
2187+
[#cir.int<3> : !s64i, #cir.int<1> : !s64i] : !cir.vector<2 x !s32i>
2188+
```
2189+
}];
2190+
2191+
// Two vector operands plus the constant index list; every element of
// `indices` is constrained to be a CIR integer attribute.
let arguments = (ins
2192+
CIR_VectorType:$vec1,
2193+
CIR_VectorType:$vec2,
2194+
CIR_IntArrayAttr:$indices
2195+
);
2196+
2197+
let results = (outs CIR_VectorType:$result);
2198+
let assemblyFormat = [{
2199+
`(` $vec1 `,` $vec2 `:` qualified(type($vec1)) `)` $indices `:`
2200+
qualified(type($result)) attr-dict
2201+
}];
2202+
// The index-count and element-type rules from the description are enforced
// by a hand-written verifier (cir::VecShuffleOp::verify).
let hasVerifier = 1;
2203+
}
2204+
21582205
//===----------------------------------------------------------------------===//
21592206
// VecShuffleDynamicOp
21602207
//===----------------------------------------------------------------------===//

clang/lib/CIR/CodeGen/CIRGenExprScalar.cpp

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,9 +189,24 @@ class ScalarExprEmitter : public StmtVisitor<ScalarExprEmitter, mlir::Value> {
189189
cgf.getLoc(e->getSourceRange()), inputVec, indexVec);
190190
}
191191

192-
cgf.getCIRGenModule().errorNYI(e->getSourceRange(),
193-
"ShuffleVectorExpr with indices");
194-
return {};
192+
mlir::Value vec1 = Visit(e->getExpr(0));
193+
mlir::Value vec2 = Visit(e->getExpr(1));
194+
195+
// The documented form of __builtin_shufflevector, where the indices are
196+
// a variable number of integer constants. The constants will be stored
197+
// in an ArrayAttr.
198+
SmallVector<mlir::Attribute, 8> indices;
199+
for (unsigned i = 2; i < e->getNumSubExprs(); ++i) {
200+
indices.push_back(
201+
cir::IntAttr::get(cgf.builder.getSInt64Ty(),
202+
e->getExpr(i)
203+
->EvaluateKnownConstInt(cgf.getContext())
204+
.getSExtValue()));
205+
}
206+
207+
return cgf.builder.create<cir::VecShuffleOp>(
208+
cgf.getLoc(e->getSourceRange()), cgf.convertType(e->getType()), vec1,
209+
vec2, cgf.builder.getArrayAttr(indices));
195210
}
196211

197212
mlir::Value VisitConvertVectorExpr(ConvertVectorExpr *e) {

clang/lib/CIR/Dialect/IR/CIRDialect.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1579,6 +1579,29 @@ OpFoldResult cir::VecExtractOp::fold(FoldAdaptor adaptor) {
15791579
return elements[index];
15801580
}
15811581

1582+
//===----------------------------------------------------------------------===//
1583+
// VecShuffle
1584+
//===----------------------------------------------------------------------===//
1585+
1586+
/// Verifies the invariants of `cir.vec.shuffle` that the ODS constraints
/// cannot express:
///   1. the index list has exactly one entry per result element;
///   2. the result element type equals the input element type (the two
///      inputs are already forced to share a type by the op definition);
///   3. every index is either -1 ("don't care") or selects an element of the
///      concatenation of the two input vectors.
///
/// Returns success() when all three hold, otherwise emits an op error and
/// returns failure. The checks run in the order listed, so an op violating
/// several rules reports the first one.
LogicalResult cir::VecShuffleOp::verify() {
  const auto inputTy = getVec1().getType();
  const auto resultTy = getResult().getType();

  // The number of elements in the indices array must match the number of
  // elements in the result type.
  if (getIndices().size() != resultTy.getSize()) {
    return emitOpError() << ": the number of elements in " << getIndices()
                         << " and " << getResult().getType() << " don't match";
  }

  // The element types of the two input vectors and of the result type must
  // match.
  if (inputTy.getElementType() != resultTy.getElementType()) {
    return emitOpError() << ": element types of " << getVec1().getType()
                         << " and " << getResult().getType() << " don't match";
  }

  // Each index must be -1 or address an element of the concatenated input
  // vectors. Without this check an out-of-range constant would only be
  // noticed (if at all) after lowering to the LLVM dialect.
  const uint64_t numInputElts =
      inputTy.getSize() + getVec2().getType().getSize();
  for (mlir::Attribute indexAttr : getIndices()) {
    // The attribute constraint on `indices` guarantees every element is a
    // cir::IntAttr, so the cast cannot fail here.
    const int64_t index =
        mlir::cast<cir::IntAttr>(indexAttr).getValue().getSExtValue();
    if (index == -1)
      continue;
    if (index < 0 || static_cast<uint64_t>(index) >= numInputElts) {
      return emitOpError() << ": index " << index
                           << " is out of range; expected -1 or a value in [0, "
                           << numInputElts << ")";
    }
  }

  return success();
}
1604+
15821605
//===----------------------------------------------------------------------===//
15831606
// VecShuffleDynamicOp
15841607
//===----------------------------------------------------------------------===//

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1770,6 +1770,7 @@ void ConvertCIRToLLVMPass::runOnOperation() {
17701770
CIRToLLVMVecExtractOpLowering,
17711771
CIRToLLVMVecInsertOpLowering,
17721772
CIRToLLVMVecCmpOpLowering,
1773+
CIRToLLVMVecShuffleOpLowering,
17731774
CIRToLLVMVecShuffleDynamicOpLowering,
17741775
CIRToLLVMVecTernaryOpLowering
17751776
// clang-format on
@@ -1922,6 +1923,23 @@ mlir::LogicalResult CIRToLLVMVecCmpOpLowering::matchAndRewrite(
19221923
return mlir::success();
19231924
}
19241925

1926+
/// Rewrites a `cir.vec.shuffle` into an LLVM dialect ShuffleVectorOp that
/// takes the already-converted input vectors and the same index list.
mlir::LogicalResult CIRToLLVMVecShuffleOpLowering::matchAndRewrite(
    cir::VecShuffleOp op, OpAdaptor adaptor,
    mlir::ConversionPatternRewriter &rewriter) const {
  // LLVM::ShuffleVectorOp expects its mask as a flat list of ints, whereas
  // ClangIR stores it as an ArrayAttr of IntAttr constants; unpack each
  // constant into a SmallVector<int>.
  SmallVector<int, 8> mask;
  for (mlir::Attribute element : op.getIndices()) {
    mask.push_back(
        mlir::cast<cir::IntAttr>(element).getValue().getSExtValue());
  }

  rewriter.replaceOpWithNewOp<mlir::LLVM::ShuffleVectorOp>(
      op, adaptor.getVec1(), adaptor.getVec2(), mask);
  return mlir::success();
}
1942+
19251943
mlir::LogicalResult CIRToLLVMVecShuffleDynamicOpLowering::matchAndRewrite(
19261944
cir::VecShuffleDynamicOp op, OpAdaptor adaptor,
19271945
mlir::ConversionPatternRewriter &rewriter) const {

clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,16 @@ class CIRToLLVMVecCmpOpLowering
357357
mlir::ConversionPatternRewriter &) const override;
358358
};
359359

360+
class CIRToLLVMVecShuffleOpLowering
361+
: public mlir::OpConversionPattern<cir::VecShuffleOp> {
362+
public:
363+
using mlir::OpConversionPattern<cir::VecShuffleOp>::OpConversionPattern;
364+
365+
mlir::LogicalResult
366+
matchAndRewrite(cir::VecShuffleOp op, OpAdaptor,
367+
mlir::ConversionPatternRewriter &) const override;
368+
};
369+
360370
class CIRToLLVMVecShuffleDynamicOpLowering
361371
: public mlir::OpConversionPattern<cir::VecShuffleDynamicOp> {
362372
public:

clang/test/CIR/CodeGen/vector-ext.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1091,3 +1091,28 @@ void foo17() {
10911091
// OGCG: %[[VEC_A:.*]] = alloca <2 x double>, align 16
10921092
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
10931093
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
1094+
1095+
// Exercises __builtin_shufflevector with constant indices: 7, 5, 3 and 1
// select elements from the concatenation of `a` and `b`. The check lines
// that follow expect a cir.vec.shuffle op and an llvm shufflevector with the
// same index list.
void foo19() {
1096+
vi4 a;
1097+
vi4 b;
1098+
vi4 u = __builtin_shufflevector(a, b, 7, 5, 3, 1);
1099+
}
1100+
1101+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
1102+
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b"]
1103+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
1104+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
1105+
// CIR: %[[SHUF:.*]] = cir.vec.shuffle(%[[TMP_A]], %[[TMP_B]] : !cir.vector<4 x !s32i>) [#cir.int<7> :
1106+
// CIR-SAME: !s64i, #cir.int<5> : !s64i, #cir.int<3> : !s64i, #cir.int<1> : !s64i] : !cir.vector<4 x !s32i>
1107+
1108+
// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
1109+
// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
1110+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
1111+
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
1112+
// LLVM: %[[SHUF:.*]] = shufflevector <4 x i32> %[[TMP_A]], <4 x i32> %[[TMP_B]], <4 x i32> <i32 7, i32 5, i32 3, i32 1>
1113+
1114+
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
1115+
// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
1116+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
1117+
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
1118+
// OGCG: %[[SHUF:.*]] = shufflevector <4 x i32> %[[TMP_A]], <4 x i32> %[[TMP_B]], <4 x i32> <i32 7, i32 5, i32 3, i32 1>

clang/test/CIR/CodeGen/vector.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1071,6 +1071,31 @@ void foo17() {
10711071
// OGCG: %[[TMP:.*]] = load <2 x double>, ptr %[[VEC_A]], align 16
10721072
// OGCG: %[[RES:.*]]= fptoui <2 x double> %[[TMP]] to <2 x i16>
10731073

1074+
// Exercises __builtin_shufflevector with constant indices: 7, 5, 3 and 1
// select elements from the concatenation of `a` and `b`. The check lines
// that follow expect a cir.vec.shuffle op and an llvm shufflevector with the
// same index list.
void foo19() {
1075+
vi4 a;
1076+
vi4 b;
1077+
vi4 u = __builtin_shufflevector(a, b, 7, 5, 3, 1);
1078+
}
1079+
1080+
// CIR: %[[VEC_A:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["a"]
1081+
// CIR: %[[VEC_B:.*]] = cir.alloca !cir.vector<4 x !s32i>, !cir.ptr<!cir.vector<4 x !s32i>>, ["b"]
1082+
// CIR: %[[TMP_A:.*]] = cir.load{{.*}} %[[VEC_A]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
1083+
// CIR: %[[TMP_B:.*]] = cir.load{{.*}} %[[VEC_B]] : !cir.ptr<!cir.vector<4 x !s32i>>, !cir.vector<4 x !s32i>
1084+
// CIR: %[[SHUF:.*]] = cir.vec.shuffle(%[[TMP_A]], %[[TMP_B]] : !cir.vector<4 x !s32i>) [#cir.int<7> :
1085+
// CIR-SAME: !s64i, #cir.int<5> : !s64i, #cir.int<3> : !s64i, #cir.int<1> : !s64i] : !cir.vector<4 x !s32i>
1086+
1087+
// LLVM: %[[VEC_A:.*]] = alloca <4 x i32>, i64 1, align 16
1088+
// LLVM: %[[VEC_B:.*]] = alloca <4 x i32>, i64 1, align 16
1089+
// LLVM: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
1090+
// LLVM: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
1091+
// LLVM: %[[SHUF:.*]] = shufflevector <4 x i32> %[[TMP_A]], <4 x i32> %[[TMP_B]], <4 x i32> <i32 7, i32 5, i32 3, i32 1>
1092+
1093+
// OGCG: %[[VEC_A:.*]] = alloca <4 x i32>, align 16
1094+
// OGCG: %[[VEC_B:.*]] = alloca <4 x i32>, align 16
1095+
// OGCG: %[[TMP_A:.*]] = load <4 x i32>, ptr %[[VEC_A]], align 16
1096+
// OGCG: %[[TMP_B:.*]] = load <4 x i32>, ptr %[[VEC_B]], align 16
1097+
// OGCG: %[[SHUF:.*]] = shufflevector <4 x i32> %[[TMP_A]], <4 x i32> %[[TMP_B]], <4 x i32> <i32 7, i32 5, i32 3, i32 1>
1098+
10741099
void foo20() {
10751100
vi4 a;
10761101
vi4 b;

clang/test/CIR/IR/invalid-vector.cir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,41 @@ module {
88
cir.global external @vec_b = #cir.zero : !cir.vector<4 x !cir.array<!s32i x 10>>
99

1010
}
11+
12+
// -----
13+
14+
!s32i = !cir.int<s, 32>
15+
!s64i = !cir.int<s, 64>
16+
17+
// Negative test: the shuffle result uses !s64i elements while both input
// vectors use !s32i elements, so the op verifier must reject it with the
// element-type diagnostic annotated below.
module {
18+
cir.func @invalid_vector_shuffle() {
19+
%1 = cir.const #cir.int<1> : !s32i
20+
%2 = cir.const #cir.int<2> : !s32i
21+
%3 = cir.const #cir.int<3> : !s32i
22+
%4 = cir.const #cir.int<4> : !s32i
23+
%vec_1 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
24+
%vec_2 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
25+
// expected-error @below {{element types of '!cir.vector<4 x !cir.int<s, 32>>' and '!cir.vector<4 x !cir.int<s, 64>>' don't match}}
26+
%new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<4 x !s32i>) [#cir.int<7> : !s64i, #cir.int<5> : !s64i, #cir.int<3> : !s64i, #cir.int<1> : !s64i] : !cir.vector<4 x !s64i>
27+
cir.return
28+
}
29+
}
30+
31+
// -----
32+
33+
!s32i = !cir.int<s, 32>
34+
!s64i = !cir.int<s, 64>
35+
36+
// Negative test: only three indices are supplied for a four-element result
// vector, so the op verifier must reject it with the element-count
// diagnostic annotated below. The result element type (!s64i) also differs
// from the inputs' (!s32i); the annotated diagnostic is the count mismatch.
module {
37+
cir.func @invalid_vector_shuffle() {
38+
%1 = cir.const #cir.int<1> : !s32i
39+
%2 = cir.const #cir.int<2> : !s32i
40+
%3 = cir.const #cir.int<3> : !s32i
41+
%4 = cir.const #cir.int<4> : !s32i
42+
%vec_1 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
43+
%vec_2 = cir.vec.create(%1, %2, %3, %4 : !s32i, !s32i, !s32i, !s32i) : !cir.vector<4 x !s32i>
44+
// expected-error @below {{the number of elements in [#cir.int<7> : !cir.int<s, 64>, #cir.int<5> : !cir.int<s, 64>, #cir.int<3> : !cir.int<s, 64>] and '!cir.vector<4 x !cir.int<s, 64>>' don't match}}
45+
%new_vec = cir.vec.shuffle(%vec_1, %vec_2 : !cir.vector<4 x !s32i>) [#cir.int<7> : !s64i, #cir.int<5> : !s64i, #cir.int<3> : !s64i] : !cir.vector<4 x !s64i>
46+
cir.return
47+
}
48+
}

0 commit comments

Comments
 (0)