Skip to content

Commit 9703fc5

Browse files
[SYCL-MLIR] Add sycl.accessor.get_pointer (#8703)
This operation returns a pointer to the start of this accessor's memory. --------- Signed-off-by: Tsang, Whitney <[email protected]>
1 parent f236fa8 commit 9703fc5

File tree

7 files changed

+179
-12
lines changed

7 files changed

+179
-12
lines changed

mlir-sycl/include/mlir/Dialect/SYCL/IR/SYCLOps.td

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -668,6 +668,29 @@ def SYCLCallOp : SYCL_Op<"call", [CallOpInterface]> {
668668
}];
669669
}
670670

671+
////////////////////////////////////////////////////////////////////////////////
672+
// accessor.get_pointer OPERATION
673+
////////////////////////////////////////////////////////////////////////////////
674+
675+
def SYCLAccessorGetPointerOp
676+
: SYCLMethodOpInterfaceImpl<"accessor.get_pointer", "AccessorType",
677+
["get_pointer"]> {
678+
let summary = "Represents the accessor get_pointer operation";
679+
let description = [{
680+
Returns a pointer to the start of this accessor's memory.
681+
}];
682+
683+
let arguments = (ins Arg<AccessorMemRef, "The accessor", [MemRead]>:$Acc,
684+
TypeArrayAttr:$ArgumentTypes,
685+
FlatSymbolRefAttr:$FunctionName,
686+
OptionalAttr<FlatSymbolRefAttr>:$MangledFunctionName,
687+
FlatSymbolRefAttr:$TypeName);
688+
689+
let hasVerifier = 1;
690+
691+
let results = (outs AnyMemRef:$result);
692+
}
693+
671694
////////////////////////////////////////////////////////////////////////////////
672695
// accessor.size OPERATION
673696
////////////////////////////////////////////////////////////////////////////////

mlir-sycl/lib/Conversion/SYCLToLLVM/SYCLToLLVM.cpp

Lines changed: 74 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,11 @@ struct AccessorGetPtr : public OffsetTag {
133133
static constexpr std::array<int32_t, 2> indices{1, 0};
134134
};
135135

136+
/// Get the ID field from an accessor.
137+
struct AccessorGetID : public OffsetTag {
138+
static constexpr std::array<int32_t, 2> indices{0, 0};
139+
};
140+
136141
/// Get the MAccessRange field from an accessor.
137142
struct AccessorGetMAccessRange : public OffsetTag {
138143
static constexpr std::array<int32_t, 2> indices{0, 1};
@@ -1061,6 +1066,67 @@ class ConstructorPattern final
10611066
}
10621067
};
10631068

1069+
//===----------------------------------------------------------------------===//
1070+
// AccessorGetPointerPattern - Convert `sycl.accessor.get_pointer` to LLVM.
1071+
//===----------------------------------------------------------------------===//
1072+
1073+
class AccessorGetPointerPattern
1074+
: public ConvertOpToLLVMPattern<SYCLAccessorGetPointerOp>,
1075+
public GetMemberPattern<AccessorGetPtr>,
1076+
public GetMemberPattern<AccessorGetID, IDGetDim>,
1077+
public GetMemberPattern<AccessorGetMemRange, RangeGetDim> {
1078+
public:
1079+
using ConvertOpToLLVMPattern<
1080+
SYCLAccessorGetPointerOp>::ConvertOpToLLVMPattern;
1081+
1082+
private:
1083+
template <typename... Args> Value getID(Args &&...args) const {
1084+
return GetMemberPattern<AccessorGetID, IDGetDim>::loadValue(
1085+
std::forward<Args>(args)...);
1086+
}
1087+
template <typename... Args> Value getMemRange(Args &&...args) const {
1088+
return GetMemberPattern<AccessorGetMemRange, RangeGetDim>::loadValue(
1089+
std::forward<Args>(args)...);
1090+
}
1091+
1092+
Value getTotalOffset(OpBuilder &builder, Location loc, AccessorType accTy,
1093+
OpAdaptor opAdaptor) const {
1094+
const auto acc = opAdaptor.getAcc();
1095+
const auto resTy = builder.getI64Type();
1096+
Value res = builder.create<arith::ConstantIntOp>(loc, 0, resTy);
1097+
for (unsigned i = 0; i < accTy.getDimension(); ++i) {
1098+
// Res = Res * Mem[I] + Id[I]
1099+
const auto memI = getMemRange(builder, loc, resTy, acc, i);
1100+
const auto idI = getID(builder, loc, resTy, acc, i);
1101+
res = builder.create<arith::AddIOp>(
1102+
loc, builder.create<arith::MulIOp>(loc, res, memI), idI);
1103+
}
1104+
return res;
1105+
}
1106+
1107+
public:
1108+
LogicalResult
1109+
matchAndRewrite(SYCLAccessorGetPointerOp op, OpAdaptor opAdaptor,
1110+
ConversionPatternRewriter &rewriter) const final {
1111+
const auto loc = op.getLoc();
1112+
const Value zero = rewriter.create<arith::ConstantIntOp>(loc, 0, 64);
1113+
Value index = rewriter.create<arith::SubIOp>(
1114+
loc, zero,
1115+
getTotalOffset(
1116+
rewriter, loc,
1117+
op.getAcc().getType().getElementType().cast<AccessorType>(),
1118+
opAdaptor));
1119+
const auto ptrTy = getTypeConverter()
1120+
->convertType(op.getType())
1121+
.cast<LLVM::LLVMPointerType>();
1122+
Value ptr = GetMemberPattern<AccessorGetPtr>::loadValue(
1123+
rewriter, loc, ptrTy, opAdaptor.getAcc());
1124+
rewriter.replaceOpWithNewOp<LLVM::GEPOp>(op, ptrTy, ptr, index,
1125+
/*inbounds*/ true);
1126+
return success();
1127+
}
1128+
};
1129+
10641130
//===----------------------------------------------------------------------===//
10651131
// AccessorSizePattern - Convert `sycl.accessor.size` to LLVM.
10661132
//===----------------------------------------------------------------------===//
@@ -2123,13 +2189,14 @@ void mlir::populateSYCLToLLVMConversionPatterns(
21232189
patterns.add<CastPattern>(typeConverter);
21242190
patterns.add<BarePtrCastPattern>(typeConverter, /*benefit*/ 2);
21252191
patterns
2126-
.add<AccessorSizePattern, AddZeroArgPattern<SYCLIDGetOp>,
2127-
AddZeroArgPattern<SYCLItemGetIDOp>, AtomicSubscriptIDOffset,
2128-
BarePtrAddrSpaceCastPattern, GroupGetGroupIDPattern,
2129-
GroupGetGroupLinearRangePattern, GroupGetGroupRangeDimPattern,
2130-
GroupGetLocalIDPattern, GroupGetLocalLinearRangePattern,
2131-
GroupGetLocalRangeDimPattern, IDGetPattern, IDGetRefPattern,
2132-
ItemGetIDDimPattern, ItemGetRangeDimPattern, ItemGetRangePattern,
2192+
.add<AccessorGetPointerPattern, AccessorSizePattern,
2193+
AddZeroArgPattern<SYCLIDGetOp>, AddZeroArgPattern<SYCLItemGetIDOp>,
2194+
AtomicSubscriptIDOffset, BarePtrAddrSpaceCastPattern,
2195+
GroupGetGroupIDPattern, GroupGetGroupLinearRangePattern,
2196+
GroupGetGroupRangeDimPattern, GroupGetLocalIDPattern,
2197+
GroupGetLocalLinearRangePattern, GroupGetLocalRangeDimPattern,
2198+
IDGetPattern, IDGetRefPattern, ItemGetIDDimPattern,
2199+
ItemGetRangeDimPattern, ItemGetRangePattern,
21332200
NDItemGetGlobalIDDimPattern, NDItemGetGlobalIDPattern,
21342201
NDItemGetGroupPattern, NDItemGetGroupRangeDimPattern,
21352202
NDItemGetLocalIDDimPattern, NDItemGetLocalLinearIDPattern,

mlir-sycl/lib/Dialect/IR/SYCLOps.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,18 @@ bool SYCLAddrSpaceCastOp::areCastCompatible(TypeRange inputs,
8989
(outputMS == genericAddressSpace));
9090
}
9191

92+
LogicalResult SYCLAccessorGetPointerOp::verify() {
93+
const auto accTy = cast<AccessorType>(
94+
cast<MemRefType>(getOperand().getType()).getElementType());
95+
const Type resTy = getResult().getType();
96+
const Type resElemTy = cast<MemRefType>(resTy).getElementType();
97+
return (resElemTy != accTy.getType())
98+
? emitOpError(
99+
"Expecting a reference to this accessor's value type (")
100+
<< accTy.getType() << "). Got " << resTy
101+
: success();
102+
}
103+
92104
LogicalResult SYCLAccessorSubscriptOp::verify() {
93105
// Available only when: (Dimensions > 0)
94106
// reference operator[](id<Dimensions> index) const;
@@ -98,11 +110,8 @@ LogicalResult SYCLAccessorSubscriptOp::verify() {
98110

99111
// Available only when: (AccessMode != access_mode::atomic && Dimensions == 1)
100112
// reference operator[](size_t index) const;
101-
const auto AccessorTy = getOperand(0)
102-
.getType()
103-
.cast<MemRefType>()
104-
.getElementType()
105-
.cast<AccessorType>();
113+
const auto AccessorTy = cast<AccessorType>(
114+
cast<MemRefType>(getOperand(0).getType()).getElementType());
106115

107116
const unsigned Dimensions = AccessorTy.getDimension();
108117
if (Dimensions == 0)

mlir-sycl/test/Conversion/SYCLToLLVM/sycl-methods-to-llvm.mlir

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,38 @@ func.func @test(%id: memref<?x!sycl_id_3_>, %idx: i32) -> memref<?xi64> {
155155

156156
// -----
157157

158+
//===-------------------------------------------------------------------------------------------------===//
159+
// sycl.accessor.get_pointer
160+
//===-------------------------------------------------------------------------------------------------===//
161+
162+
!sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64>)>)>
163+
!sycl_range_1_ = !sycl.range<[1], (!sycl.array<[1], (memref<1xi64>)>)>
164+
!sycl_accessor_impl_device_1_ = !sycl.accessor_impl_device<[1], (!sycl_id_1_, !sycl_range_1_, !sycl_range_1_)>
165+
!sycl_accessor_1_i32_rw_gb = !sycl.accessor<[1, i32, read_write, global_buffer], (!sycl_accessor_impl_device_1_, !llvm.struct<(ptr<i32, 1>)>)>
166+
167+
// CHECK-LABEL: llvm.func @test(
168+
// CHECK-SAME: %[[VAL_0:.*]]: !llvm.ptr<[[ACCESSOR1:.*]]>) -> !llvm.ptr<i32, 1> {
169+
// CHECK-NEXT: %0 = llvm.mlir.constant(0 : i64) : i64
170+
// CHECK-NEXT: %1 = llvm.mlir.constant(0 : i64) : i64
171+
// CHECK-NEXT: %2 = llvm.getelementptr inbounds %arg0[0, 0, 2, 0, 0, 0] : (!llvm.ptr<[[ACCESSOR1]]>) -> !llvm.ptr<i64>
172+
// CHECK-NEXT: %3 = llvm.load %2 : !llvm.ptr<i64>
173+
// CHECK-NEXT: %4 = llvm.getelementptr inbounds %arg0[0, 0, 0, 0, 0, 0] : (!llvm.ptr<[[ACCESSOR1]]>) -> !llvm.ptr<i64>
174+
// CHECK-NEXT: %5 = llvm.load %4 : !llvm.ptr<i64>
175+
// CHECK-NEXT: %6 = llvm.mul %1, %3 : i64
176+
// CHECK-NEXT: %7 = llvm.add %6, %5 : i64
177+
// CHECK-NEXT: %8 = llvm.sub %0, %7 : i64
178+
// CHECK-NEXT: %9 = llvm.getelementptr inbounds %arg0[0, 1, 0] : (!llvm.ptr<[[ACCESSOR1]]>) -> !llvm.ptr<ptr<i32, 1>>
179+
// CHECK-NEXT: %10 = llvm.load %9 : !llvm.ptr<ptr<i32, 1>>
180+
// CHECK-NEXT: %11 = llvm.getelementptr inbounds %10[%8] : (!llvm.ptr<i32, 1>, i64) -> !llvm.ptr<i32, 1>
181+
// CHECK-NEXT: llvm.return %11 : !llvm.ptr<i32, 1>
182+
// CHECK-NEXT: }
183+
func.func @test(%acc: memref<?x!sycl_accessor_1_i32_rw_gb>) -> memref<?xi32, 1> {
184+
%0 = sycl.accessor.get_pointer(%acc) { ArgumentTypes = [memref<?x!sycl_accessor_1_i32_rw_gb>], FunctionName = @"get_pointer", MangledFunctionName = @"get_pointer", TypeName = @"accessor" } : (memref<?x!sycl_accessor_1_i32_rw_gb>) -> memref<?xi32, 1>
185+
return %0 : memref<?xi32, 1>
186+
}
187+
188+
// -----
189+
158190
//===-------------------------------------------------------------------------------------------------===//
159191
// sycl.accessor.size
160192
//===-------------------------------------------------------------------------------------------------===//

mlir-sycl/test/Dialect/IR/SYCL/invalid.mlir

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,19 @@ func.func @test_work_group_id_dim() -> index {
260260

261261
// -----
262262

263+
!sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64>)>)>
264+
!sycl_range_1_ = !sycl.range<[1], (!sycl.array<[1], (memref<1xi64>)>)>
265+
!sycl_accessor_impl_device_1_ = !sycl.accessor_impl_device<[1], (!sycl_id_1_, !sycl_range_1_, !sycl_range_1_)>
266+
!sycl_accessor_1_i32_rw_gb = !sycl.accessor<[1, i32, read_write, global_buffer], (!sycl_accessor_impl_device_1_, !llvm.struct<(ptr<i32, 1>)>)>
267+
268+
func.func @test_accessor_get_pointer(%acc: memref<?x!sycl_accessor_1_i32_rw_gb>) -> memref<?xi64, 1> {
269+
// expected-error @+1 {{'sycl.accessor.get_pointer' op Expecting a reference to this accessor's value type}}
270+
%0 = sycl.accessor.get_pointer(%acc) { ArgumentTypes = [memref<?x!sycl_accessor_1_i32_rw_gb>], FunctionName = @"get_pointer", MangledFunctionName = @"get_pointer", TypeName = @"accessor" } : (memref<?x!sycl_accessor_1_i32_rw_gb>) -> memref<?xi64, 1>
271+
return %0 : memref<?xi64, 1>
272+
}
273+
274+
// -----
275+
263276
!sycl_id_1_ = !sycl.id<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)>
264277
!sycl_range_1_ = !sycl.range<[1], (!sycl.array<[1], (memref<1xi64, 4>)>)>
265278
!sycl_accessor_impl_device_1_ = !sycl.accessor_impl_device<[1], (!sycl_id_1_, !sycl_range_1_, !sycl_range_1_)>

mlir-sycl/test/Dialect/IR/SYCL/ops.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,12 @@ func.func @test_work_group_id_const() -> index {
260260
return %0 : index
261261
}
262262

263+
// CHECL-LABEL: test_accessor_get_pointer
264+
func.func @test_accessor_get_pointer(%acc: memref<?x!sycl_accessor_1_i32_w_gb>) -> memref<?xi32, 1> {
265+
%0 = sycl.accessor.get_pointer(%acc) { ArgumentTypes = [memref<?x!sycl_accessor_1_i32_w_gb>], FunctionName = @"get_pointer", MangledFunctionName = @"get_pointer", TypeName = @"accessor" } : (memref<?x!sycl_accessor_1_i32_w_gb>) -> memref<?xi32, 1>
266+
return %0 : memref<?xi32, 1>
267+
}
268+
263269
// CHECK-LABEL: test_accessor_subscript_atomic
264270
func.func @test_accessor_subscript_atomic(
265271
%acc: memref<?x!sycl_accessor_1_i32_ato_gb>,

polygeist/tools/cgeist/Test/Verification/sycl/functions.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,23 @@
4040

4141
template <typename T> SYCL_EXTERNAL void keep(T);
4242

43+
// COM: Commenting out the checks below, this is the code that should be
44+
// generated. Currently the DPC++ SYCL RT implementation of
45+
// accessor::get_pointer is non-conforming. Once that problem is fixed we
46+
// enable the checks below (with the MangledFunctionName fixed).
47+
48+
// COM-MLIR-LABEL: func.func @_Z20accessor_get_pointerN4sycl3_V18accessorIiLi2ELNS0_6access4modeE1026ELNS2_6targetE2014ELNS2_11placeholderE0ENS0_3ext6oneapi22accessor_property_listIJEEEEE(
49+
// COM-MLIR: %{{.*}}: memref<?x!sycl_accessor_2_i32_rw_gb> {llvm.align = 8 : i64, llvm.byval = !sycl_accessor_2_i32_rw_gb, llvm.noundef})
50+
// COM-MLIR: %{{.*}} = sycl.accessor.get_pointer(%{{.*}}) {ArgumentTypes = [memref<?x!sycl_accessor_2_i32_rw_gb, 4>], FunctionName = @get_pointer, MangledFunctionName = @_ZNK4sycl3_V18accessorIiLi2ELNS0_6access4modeE1026ELNS2_6targetE2014ELNS2_11placeholderE0ENS0_3ext6oneapi22accessor_property_listIJEEEE11get_pointerILS4_2014EvEENS0_9IiLNS2_13address_spaceE1ELNS2_9decoratedE2EEEv, TypeName = @accessor} : (memref<?x!sycl_accessor_2_i32_rw_gb>) -> memref<?xi32, 1>
51+
52+
// COM-LLVM-LABEL: define spir_func void @_Z20accessor_get_pointerN4sycl3_V18accessorIiLi2ELNS0_6access4modeE1026ELNS2_6targetE2014ELNS2_11placeholderE0ENS0_3ext6oneapi22accessor_property_listIJEEEEE(
53+
// COM-LLVM: %"class.sycl::_V1::accessor.2"* noundef byval(%"class.sycl::_V1::accessor.2") align 8 %0) #[[FUNCATTRS:[0-9]+]] {
54+
// COM-LLVM: %{{.*}} = call spir_func i32 addrspace(1)* @_ZNK4sycl3_V18accessorIiLi2ELNS0_6access4modeE1026ELNS2_6targetE2014ELNS2_11placeholderE0ENS0_3ext6oneapi22accessor_property_listIJEEEE11get_pointerILS4_2014EvEENS0_9IiLNS2_13address_spaceE1ELNS2_9decoratedE2EEEv(%"class.sycl::_V1::accessor.2" addrspace(4)* %{{.*}})
55+
56+
SYCL_EXTERNAL void accessor_get_pointer(sycl::accessor<sycl::cl_int, 2> acc) {
57+
keep(acc.get_pointer());
58+
}
59+
4360
// CHECK-MLIR-LABEL: func.func @_Z13accessor_sizeN4sycl3_V18accessorIiLi2ELNS0_6access4modeE1026ELNS2_6targetE2014ELNS2_11placeholderE0ENS0_3ext6oneapi22accessor_property_listIJEEEEE(
4461
// CHECK-MLIR: %{{.*}}: memref<?x!sycl_accessor_2_i32_rw_gb> {llvm.align = 8 : i64, llvm.byval = !sycl_accessor_2_i32_rw_gb, llvm.noundef})
4562
// CHECK-MLIR: %{{.*}} = sycl.accessor.size(%arg0) {ArgumentTypes = [memref<?x!sycl_accessor_2_i32_rw_gb, 4>], FunctionName = @size, MangledFunctionName = @_ZNK4sycl3_V18accessorIiLi2ELNS0_6access4modeE1026ELNS2_6targetE2014ELNS2_11placeholderE0ENS0_3ext6oneapi22accessor_property_listIJEEEE4sizeEv, TypeName = @accessor} : (memref<?x!sycl_accessor_2_i32_rw_gb>) -> i64

0 commit comments

Comments
 (0)