@@ -133,6 +133,11 @@ struct AccessorGetPtr : public OffsetTag {
133
133
static constexpr std::array<int32_t , 2 > indices{1 , 0 };
134
134
};
135
135
136
+ // / Get the ID field from an accessor.
137
+ struct AccessorGetID : public OffsetTag {
138
+ static constexpr std::array<int32_t , 2 > indices{0 , 0 };
139
+ };
140
+
136
141
// / Get the MAccessRange field from an accessor.
137
142
struct AccessorGetMAccessRange : public OffsetTag {
138
143
static constexpr std::array<int32_t , 2 > indices{0 , 1 };
@@ -1061,6 +1066,67 @@ class ConstructorPattern final
1061
1066
}
1062
1067
};
1063
1068
1069
+ // ===----------------------------------------------------------------------===//
1070
+ // AccessorGetPointerPattern - Convert `sycl.accessor.get_pointer` to LLVM.
1071
+ // ===----------------------------------------------------------------------===//
1072
+
1073
+ class AccessorGetPointerPattern
1074
+ : public ConvertOpToLLVMPattern<SYCLAccessorGetPointerOp>,
1075
+ public GetMemberPattern<AccessorGetPtr>,
1076
+ public GetMemberPattern<AccessorGetID, IDGetDim>,
1077
+ public GetMemberPattern<AccessorGetMemRange, RangeGetDim> {
1078
+ public:
1079
+ using ConvertOpToLLVMPattern<
1080
+ SYCLAccessorGetPointerOp>::ConvertOpToLLVMPattern;
1081
+
1082
+ private:
1083
+ template <typename ... Args> Value getID (Args &&...args) const {
1084
+ return GetMemberPattern<AccessorGetID, IDGetDim>::loadValue (
1085
+ std::forward<Args>(args)...);
1086
+ }
1087
+ template <typename ... Args> Value getMemRange (Args &&...args) const {
1088
+ return GetMemberPattern<AccessorGetMemRange, RangeGetDim>::loadValue (
1089
+ std::forward<Args>(args)...);
1090
+ }
1091
+
1092
+ Value getTotalOffset (OpBuilder &builder, Location loc, AccessorType accTy,
1093
+ OpAdaptor opAdaptor) const {
1094
+ const auto acc = opAdaptor.getAcc ();
1095
+ const auto resTy = builder.getI64Type ();
1096
+ Value res = builder.create <arith::ConstantIntOp>(loc, 0 , resTy);
1097
+ for (unsigned i = 0 ; i < accTy.getDimension (); ++i) {
1098
+ // Res = Res * Mem[I] + Id[I]
1099
+ const auto memI = getMemRange (builder, loc, resTy, acc, i);
1100
+ const auto idI = getID (builder, loc, resTy, acc, i);
1101
+ res = builder.create <arith::AddIOp>(
1102
+ loc, builder.create <arith::MulIOp>(loc, res, memI), idI);
1103
+ }
1104
+ return res;
1105
+ }
1106
+
1107
+ public:
1108
+ LogicalResult
1109
+ matchAndRewrite (SYCLAccessorGetPointerOp op, OpAdaptor opAdaptor,
1110
+ ConversionPatternRewriter &rewriter) const final {
1111
+ const auto loc = op.getLoc ();
1112
+ const Value zero = rewriter.create <arith::ConstantIntOp>(loc, 0 , 64 );
1113
+ Value index = rewriter.create <arith::SubIOp>(
1114
+ loc, zero,
1115
+ getTotalOffset (
1116
+ rewriter, loc,
1117
+ op.getAcc ().getType ().getElementType ().cast <AccessorType>(),
1118
+ opAdaptor));
1119
+ const auto ptrTy = getTypeConverter ()
1120
+ ->convertType (op.getType ())
1121
+ .cast <LLVM::LLVMPointerType>();
1122
+ Value ptr = GetMemberPattern<AccessorGetPtr>::loadValue (
1123
+ rewriter, loc, ptrTy, opAdaptor.getAcc ());
1124
+ rewriter.replaceOpWithNewOp <LLVM::GEPOp>(op, ptrTy, ptr, index,
1125
+ /* inbounds*/ true );
1126
+ return success ();
1127
+ }
1128
+ };
1129
+
1064
1130
// ===----------------------------------------------------------------------===//
1065
1131
// AccessorSizePattern - Convert `sycl.accessor.size` to LLVM.
1066
1132
// ===----------------------------------------------------------------------===//
@@ -2123,13 +2189,14 @@ void mlir::populateSYCLToLLVMConversionPatterns(
2123
2189
patterns.add <CastPattern>(typeConverter);
2124
2190
patterns.add <BarePtrCastPattern>(typeConverter, /* benefit*/ 2 );
2125
2191
patterns
2126
- .add <AccessorSizePattern, AddZeroArgPattern<SYCLIDGetOp>,
2127
- AddZeroArgPattern<SYCLItemGetIDOp>, AtomicSubscriptIDOffset,
2128
- BarePtrAddrSpaceCastPattern, GroupGetGroupIDPattern,
2129
- GroupGetGroupLinearRangePattern, GroupGetGroupRangeDimPattern,
2130
- GroupGetLocalIDPattern, GroupGetLocalLinearRangePattern,
2131
- GroupGetLocalRangeDimPattern, IDGetPattern, IDGetRefPattern,
2132
- ItemGetIDDimPattern, ItemGetRangeDimPattern, ItemGetRangePattern,
2192
+ .add <AccessorGetPointerPattern, AccessorSizePattern,
2193
+ AddZeroArgPattern<SYCLIDGetOp>, AddZeroArgPattern<SYCLItemGetIDOp>,
2194
+ AtomicSubscriptIDOffset, BarePtrAddrSpaceCastPattern,
2195
+ GroupGetGroupIDPattern, GroupGetGroupLinearRangePattern,
2196
+ GroupGetGroupRangeDimPattern, GroupGetLocalIDPattern,
2197
+ GroupGetLocalLinearRangePattern, GroupGetLocalRangeDimPattern,
2198
+ IDGetPattern, IDGetRefPattern, ItemGetIDDimPattern,
2199
+ ItemGetRangeDimPattern, ItemGetRangePattern,
2133
2200
NDItemGetGlobalIDDimPattern, NDItemGetGlobalIDPattern,
2134
2201
NDItemGetGroupPattern, NDItemGetGroupRangeDimPattern,
2135
2202
NDItemGetLocalIDDimPattern, NDItemGetLocalLinearIDPattern,
0 commit comments