
Commit aea2d53

[MLIR][XeGPU] make offsets optional for create_nd_tdesc (#148335)
1 parent 867ff30 commit aea2d53

File tree: 9 files changed, +233 −27 lines changed
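In short, this patch lets xegpu.create_nd_tdesc be written without an offset list, and introduces explicit `shape :` / `strides :` keywords for sources whose layout is not carried by the type. A before/after sketch distilled from the tests in this commit (the descriptor shapes are illustrative):

// Before this change: offsets were mandatory, even when all zero.
%0 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>

// After this change: the offset list may be omitted for a statically shaped memref.
%1 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>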

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUOps.td

Lines changed: 41 additions & 8 deletions
@@ -110,23 +110,34 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     Variadic<Index>: $offsets,
     Variadic<Index>: $shape,
     Variadic<Index>: $strides,
-    DenseI64ArrayAttr: $const_offsets,
+    OptionalAttr<DenseI64ArrayAttr>: $const_offsets,
     OptionalAttr<DenseI64ArrayAttr>: $const_shape,
     OptionalAttr<DenseI64ArrayAttr>: $const_strides
   );
-  let results = (outs XeGPU_TensorDesc: $TensorDesc);

   let assemblyFormat = [{
     $source ``
-    custom<DynamicIndexList>($offsets, $const_offsets)
-    (`,` custom<DynamicIndexList>($shape, $const_shape)^
-    `,` custom<DynamicIndexList>($strides, $const_strides))?
+    custom<OptionalDynamicIndexList>($offsets, $const_offsets)
+    (`,` `shape` `:` custom<DynamicIndexList>($shape, $const_shape)^
+    `,` `strides` `:` custom<DynamicIndexList>($strides, $const_strides))?
     attr-dict `:` type($source) `->` qualified(type($TensorDesc))
   }];

+  let results = (outs XeGPU_TensorDesc: $TensorDesc);
+
   let hasVerifier = 1;

   let builders = [
+    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source)>,
+
+    OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
+               "llvm::ArrayRef<OpFoldResult>": $shape,
+               "llvm::ArrayRef<OpFoldResult>": $strides)>,
+
+    OpBuilder<(ins "Type": $tdesc, "TypedValue<IntegerType>": $source,
+               "llvm::ArrayRef<OpFoldResult>": $shape,
+               "llvm::ArrayRef<OpFoldResult>": $strides)>,
+
     OpBuilder<(ins "Type": $tdesc, "TypedValue<MemRefType>": $source,
                "llvm::ArrayRef<OpFoldResult>": $offsets)>,
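Under the revised assemblyFormat, the optional shape and strides lists are introduced by `shape :` and `strides :` keywords rather than bare bracketed lists, and the offsets go through the new OptionalDynamicIndexList parser. A hedged sketch of the resulting syntax (%dyn, %h, %w are illustrative, not from this commit):

// Statically shaped source: offsets optional, shape/strides inferred from the memref type.
%0 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>

// Dynamically shaped source: shape and strides are spelled out with their keywords.
%1 = xegpu.create_nd_tdesc %dyn[0, 0], shape : [%h, %w], strides : [%w, 1]
    : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>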

@@ -163,7 +174,17 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     }

     ArrayRef<int64_t> getStaticOffsets(){
-      return getConstOffsets();
+      auto attr = getConstOffsetsAttr();
+
+      if (attr)
+        return attr;
+
+      int64_t rank = getMixedSizes().size();
+
+      setConstOffsets(llvm::SmallVector<int64_t, 4>(rank, 0));
+
+      attr = getConstOffsetsAttr();
+      return attr;
     }

     /// wrapper for matching with OffsetSizeAndStrideOpInterface
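getStaticOffsets now lazily materializes an all-zero const_offsets attribute when none was provided, so omitting the offset list should be equivalent to writing explicit zeros. A sketch of the assumed equivalence for a rank-2 source:

// Assumed equivalent after this change: absent offsets default to zeros.
%a = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
%b = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>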
@@ -172,10 +193,16 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     /// and `const_shape` will be used to represent the shape of
     /// source operand. They override the static shape from the source memref type.
     ArrayRef<int64_t> getStaticSizes() {
+      /// To be compatible with OffsetSizeAndStrideOpInterface, which expects a valid return value and performs checks
+      static llvm::SmallVector<int64_t, 4> emptyShape;
+
       auto attr = getConstShapeAttr();
-      if (llvm::isa<IntegerType>(getSourceType()) || attr)
+      if (attr)
         return attr;

+      if (llvm::isa<IntegerType>(getSourceType()))
+        return emptyShape;
+
       auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
       assert(memrefType && "Incorrect use of getStaticSizes");
       return memrefType.getShape();

@@ -187,9 +214,15 @@ def XeGPU_CreateNdDescOp: XeGPU_Op<"create_nd_tdesc", [Pure, ViewLikeOpInterface
     /// and `const_strides` will be used to represent the strides of
     /// source operand. They override the static strides from the source memref type.
     ArrayRef<int64_t> getStaticStrides() {
+      /// To be compatible with OffsetSizeAndStrideOpInterface, which expects a valid return value and performs checks
+      static llvm::SmallVector<int64_t, 4> emptyStrides;
+
       auto attr = getConstStridesAttr();
-      if (llvm::isa<IntegerType>(getSourceType()) || attr)
+      if (attr)
         return attr;
+
+      if (llvm::isa<IntegerType>(getSourceType()))
+        return emptyStrides;

       auto memrefType = llvm::dyn_cast<MemRefType>(getSourceType());
       assert(memrefType && "Incorrect use of getStaticStrides");
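Since an integer source carries no type-level shape, getStaticSizes and getStaticStrides return empty arrays unless const_shape/const_strides are set; the layout must therefore arrive through the shape/strides clause. A sketch under that assumption (%ptr and the 128x128 sizes are illustrative):

// ui64 source: the descriptor layout comes entirely from the shape/strides clause.
%t = xegpu.create_nd_tdesc %ptr, shape : [128, 128], strides : [128, 1]
    : ui64 -> !xegpu.tensor_desc<128x128xf32>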

mlir/lib/Dialect/XeGPU/IR/XeGPUOps.cpp

Lines changed: 119 additions & 2 deletions
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/XeGPU/IR/XeGPU.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/TypeUtilities.h"
+#include "mlir/Interfaces/ViewLikeInterface.h"

 #include "llvm/Support/Debug.h"

@@ -112,6 +113,68 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
 //===----------------------------------------------------------------------===//
 // XeGPU_CreateNdDescOp
 //===----------------------------------------------------------------------===//
+
+void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
+                           Type tdesc, TypedValue<MemRefType> source) {
+  [[maybe_unused]] auto ty = source.getType();
+  assert(ty.hasStaticShape() && "expecting a memref with static shape");
+
+  build(builder, state, tdesc, source, ValueRange({}) /* dynamic offsets */,
+        ValueRange({}) /* empty dynamic shape */,
+        ValueRange({}) /* empty dynamic strides */,
+        DenseI64ArrayAttr({}) /* const offsets */,
+        DenseI64ArrayAttr({}) /* empty const shape */,
+        DenseI64ArrayAttr({}) /* empty const strides */);
+}
+
+void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
+                           Type tdesc, TypedValue<MemRefType> source,
+                           llvm::ArrayRef<OpFoldResult> shape,
+                           llvm::ArrayRef<OpFoldResult> strides) {
+  assert(shape.size() && strides.size() && shape.size() == strides.size() &&
+         "Shape and strides must be present and of equal size for ui64 "
+         "initialization.");
+
+  llvm::SmallVector<int64_t> staticShape;
+  llvm::SmallVector<int64_t> staticStrides;
+  llvm::SmallVector<Value> dynamicShape;
+  llvm::SmallVector<Value> dynamicStrides;
+
+  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+
+  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
+  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
+
+  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
+        dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
+        staticStridesAttr);
+}
+
+void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
+                           Type tdesc, TypedValue<IntegerType> source,
+                           llvm::ArrayRef<OpFoldResult> shape,
+                           llvm::ArrayRef<OpFoldResult> strides) {
+  assert(shape.size() && strides.size() && shape.size() == strides.size() &&
+         "Shape and strides must be present and of equal size for ui64 "
+         "initialization.");
+
+  llvm::SmallVector<int64_t> staticShape;
+  llvm::SmallVector<int64_t> staticStrides;
+  llvm::SmallVector<Value> dynamicShape;
+  llvm::SmallVector<Value> dynamicStrides;
+
+  dispatchIndexOpFoldResults(shape, dynamicShape, staticShape);
+  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
+
+  auto staticShapeAttr = builder.getDenseI64ArrayAttr(staticShape);
+  auto staticStridesAttr = builder.getDenseI64ArrayAttr(staticStrides);
+
+  build(builder, state, tdesc, source, ValueRange({}), dynamicShape,
+        dynamicStrides, builder.getDenseI64ArrayAttr({}), staticShapeAttr,
+        staticStridesAttr);
+}
+
 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
                            Type tdesc, TypedValue<MemRefType> source,
                            llvm::ArrayRef<OpFoldResult> offsets) {

@@ -125,8 +188,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
   build(builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */,
         ValueRange({}) /* empty dynamic shape */,
         ValueRange({}) /* empty dynamic strides */,
-        staticOffsets /* const offsets */, {} /* empty const shape */,
-        {} /* empty const strides */);
+        builder.getDenseI64ArrayAttr(staticOffsets) /* const offsets */,
+        {} /* empty const shape */, {} /* empty const strides */);
 }

 void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,

@@ -197,6 +260,13 @@ LogicalResult CreateNdDescOp::verify() {
     invalidElemTy |= memrefTy.getElementType() != getElementType();
   }

+  if (llvm::isa<IntegerType>(getSourceType())) {
+    // Strides and shape must be present for an integer source.
+    if (getMixedStrides().empty() || getMixedSizes().empty())
+      return emitOpError("Expecting strides and shape to be present for "
+                         "integer source.");
+  }
+
   // Mismatches among shape, strides, and offsets are
   // already handled by OffsetSizeAndStrideOpInterface,
   // so they are not checked here.

@@ -221,6 +291,53 @@ LogicalResult CreateNdDescOp::verify() {
   return success();
 }

+ParseResult parseOptionalDynamicIndexList(
+    OpAsmParser &parser,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr,
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+
+  SmallVector<int64_t, 4> integerVals;
+  auto parseIntegerOrValue = [&]() {
+    OpAsmParser::UnresolvedOperand operand;
+    auto res = parser.parseOptionalOperand(operand);
+
+    if (res.has_value() && succeeded(res.value())) {
+      values.push_back(operand);
+      integerVals.push_back(ShapedType::kDynamic);
+      if (valueTypes && parser.parseColonType(valueTypes->emplace_back()))
+        return failure();
+    } else {
+      int64_t integer;
+      if (failed(parser.parseInteger(integer)))
+        return failure();
+      integerVals.push_back(integer);
+    }
+    return success();
+  };
+
+  // If the optional values are given, there must be a left bracket.
+  if (parser.parseOptionalLSquare().succeeded()) {
+    if (parser.parseCommaSeparatedList(parseIntegerOrValue) ||
+        parser.parseRSquare())
+      return parser.emitError(parser.getNameLoc())
+             << "expected a list of SSA values or integers";
+    integers = parser.getBuilder().getDenseI64ArrayAttr(integerVals);
+    return success();
+  }
+
+  return success();
+}
+
+void printOptionalDynamicIndexList(
+    OpAsmPrinter &printer, Operation *op, OperandRange values,
+    ArrayRef<int64_t> integers, TypeRange valueTypes = TypeRange(),
+    AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
+
+  return printDynamicIndexList(printer, op, values, integers,
+                               /*scalableFlags=*/{}, valueTypes, delimiter);
+}
+
 //===----------------------------------------------------------------------===//
 // XeGPU_PrefetchNdOp
 //===----------------------------------------------------------------------===//
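parseOptionalDynamicIndexList accepts either an SSA value or an integer literal in each bracketed slot, and simply succeeds without setting the attribute when the bracket is absent. So mixed offset lists and fully omitted ones should both round-trip; a small sketch (names illustrative):

%c8 = arith.constant 8 : index
// Mixed dynamic/static offsets in one list ...
%t0 = xegpu.create_nd_tdesc %src[%c8, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
// ... or no offset list at all.
%t1 = xegpu.create_nd_tdesc %src : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>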

mlir/test/Conversion/VectorToXeGPU/load-to-xegpu.mlir

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:  [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // CHECK-SAME:  memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
 // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
 // CHECK: return %[[VEC]]

mlir/test/Conversion/VectorToXeGPU/store-to-xegpu.mlir

Lines changed: 1 addition & 1 deletion
@@ -56,7 +56,7 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:  [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // CHECK-SAME:  memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32,
 // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>

mlir/test/Conversion/VectorToXeGPU/transfer-read-to-xegpu.mlir

Lines changed: 1 addition & 1 deletion
@@ -96,7 +96,7 @@ func.func @load_dynamic_source(%source: memref<?x?x?xf32>,
 // CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:  [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // CHECK-SAME:  memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
 // CHECK: %[[VEC:.+]] = xegpu.load_nd %[[DESC]]{{.*}}-> vector<8x16xf32>
 // CHECK: return %[[VEC]]

mlir/test/Conversion/VectorToXeGPU/transfer-write-to-xegpu.mlir

Lines changed: 1 addition & 1 deletion
@@ -60,7 +60,7 @@ func.func @store_dynamic_source(%vec: vector<8x16xf32>,
 // CHECK-DAG: %[[DIM_2:.+]] = memref.dim %[[SRC]], %[[C2]]
 // CHECK: %[[DIM_0_STRIDE:.+]] = arith.muli %[[DIM_2]], %[[DIM_1]]
 // CHECK: %[[DESC:.+]] = xegpu.create_nd_tdesc %[[SRC]][%[[OFFSET]], %[[OFFSET]], %[[OFFSET]]]
-// CHECK-SAME:  [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
+// CHECK-SAME:  , shape : [%[[DIM_0]], %[[DIM_1]], %[[DIM_2]]], strides : [%[[DIM_0_STRIDE]], %[[DIM_2]], 1]
 // CHECK-SAME:  memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32
 // CHECK: xegpu.store_nd %[[VEC]], %[[DESC]] : vector<8x16xf32>
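For reference, the full printed form these updated CHECK-SAME fragments match looks roughly like the sketch below (FileCheck captures replaced by plain SSA names for readability):

%desc = xegpu.create_nd_tdesc %src[%off, %off, %off], shape : [%d0, %d1, %d2],
    strides : [%stride0, %d2, 1] : memref<?x?x?xf32> -> !xegpu.tensor_desc<8x16xf32>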

mlir/test/Dialect/XeGPU/invalid.mlir

Lines changed: 22 additions & 7 deletions
@@ -1,55 +1,70 @@
 // RUN: mlir-opt %s -split-input-file -verify-diagnostics

 // -----
-func.func @create_nd_tdesc_vc_1(%src: memref<24xf32>) {
+func.func @create_nd_tdesc_1(%src: memref<24xf32>) {
   // expected-error@+1 {{Expecting the TensorDesc rank is not greater than the ranks of shape, strides, offsets or the memref source}}
   %1 = xegpu.create_nd_tdesc %src[0] : memref<24xf32> -> !xegpu.tensor_desc<8x16xf32>
   return
 }

 // -----
-func.func @create_nd_tdesc_vc_2(%src: memref<24x32xf32>) {
+func.func @create_nd_tdesc_2(%src: memref<24x32xf32>) {
   // expected-error@+1 {{TensorDesc should have the same element type with the source if it is a memref}}
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf16>
   return
 }

 // -----
-func.func @create_nd_tdesc_vc_3(%src: memref<2x24x32xf32, 3>) {
+func.func @create_nd_tdesc_3(%src: memref<2x24x32xf32, 3>) {
   // expected-error@+1 {{SLM is only supported for 1D block tensor}}
   %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<8x16xf32, #xegpu.block_tdesc_attr<memory_space = slm>>
   return
 }

 // -----
-func.func @create_nd_tdesc_vc_4(%src: memref<2x24x32xf32, 3>) {
+func.func @create_nd_tdesc_4(%src: memref<2x24x32xf32, 3>) {
   // expected-error@+1 {{Memory space mismatch}}
   %1 = xegpu.create_nd_tdesc %src[0, 0, 0] : memref<2x24x32xf32, 3> -> !xegpu.tensor_desc<16xf32>
   return
 }

 // -----
-func.func @create_nd_tdesc_subgroup_1(%src: memref<128x128xf32>) {
+func.func @create_nd_tdesc_5(%src: memref<128x128xf32>) {
   // expected-error@+1 {{cannot distribute [128, 128] using #xegpu.layout<sg_layout = [4, 2], sg_data = [24, 48]>}}
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [24, 48]>>
   return
 }

 // -----
-func.func @create_nd_tdesc_subgroup_1(%src: memref<128x128xf32>) {
+func.func @create_nd_tdesc_6(%src: memref<128x128xf32>) {
   // expected-error@+1 {{cannot distribute [128, 128] using #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [24, 48]>}}
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [24, 48]>>
   return
 }

 // -----
-func.func @create_nd_tdesc_subgroup_1(%src: memref<128x128xf32>) {
+func.func @create_nd_tdesc_7(%src: memref<128x128xf32>) {
   // expected-error@+1 {{cannot distribute [128, 128] using #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [64, 32]>}}
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<128x128xf32> -> !xegpu.tensor_desc<128x128xf32, #xegpu.layout<sg_layout = [4, 2], sg_data = [32, 64], inst_data = [64, 32]>>
   return
 }

+// -----
+func.func @create_nd_tdesc_8(%src: ui64) {
+  // expected-error@+1 {{'xegpu.create_nd_tdesc' op Expecting strides and shape to be present for integer source}}
+  %1 = xegpu.create_nd_tdesc %src : ui64 -> !xegpu.tensor_desc<128x128xf32>
+  return
+}
+
+// -----
+func.func @create_nd_tdesc_9(%src: ui64) {
+  // expected-error@+1 {{expected mixed offsets rank to match mixed sizes rank}}
+  %1 = xegpu.create_nd_tdesc %src[0, 0] : ui64 -> !xegpu.tensor_desc<128x128xf32>
+  return
+}
+
 // -----
 func.func @prefetch_nd_vc_1(%src: memref<24x32xf16>) {
   %1 = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf16> -> !xegpu.tensor_desc<8x16xf16>
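For contrast with the two new negative ui64 tests, a form the verifier should accept supplies the shape/strides clause and an offset list whose rank matches the sizes (a hedged sketch, not part of this commit's tests):

func.func @create_nd_tdesc_ok(%src: ui64) {
  // Offsets rank matches the sizes rank, and shape/strides are present.
  %1 = xegpu.create_nd_tdesc %src[0, 0], shape : [128, 128], strides : [128, 1] : ui64 -> !xegpu.tensor_desc<128x128xf32>
  return
}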
