12
12
#include " mlir/Dialect/XeGPU/IR/XeGPU.h"
13
13
#include " mlir/IR/Builders.h"
14
14
#include " mlir/IR/TypeUtilities.h"
15
+ #include " mlir/Interfaces/ViewLikeInterface.h"
15
16
16
17
#include " llvm/Support/Debug.h"
17
18
@@ -112,6 +113,68 @@ isValidGatherScatterParams(Type maskTy, VectorType valueTy,
112
113
// ===----------------------------------------------------------------------===//
113
114
// XeGPU_CreateNdDescOp
114
115
// ===----------------------------------------------------------------------===//
116
+
117
+ void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
118
+ Type tdesc, TypedValue<MemRefType> source) {
119
+ [[maybe_unused]] auto ty = source.getType ();
120
+ assert (ty.hasStaticShape () && " expecting a memref with static shape" );
121
+
122
+ build (builder, state, tdesc, source, ValueRange ({}) /* dynamic offsets */ ,
123
+ ValueRange ({}) /* empty dynamic shape */ ,
124
+ ValueRange ({}) /* empty dynamic strides */ ,
125
+ DenseI64ArrayAttr ({}) /* const offsets */ ,
126
+ DenseI64ArrayAttr ({}) /* empty const shape*/ ,
127
+ DenseI64ArrayAttr ({}) /* empty const strides*/ );
128
+ }
129
+
130
+ void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
131
+ Type tdesc, TypedValue<MemRefType> source,
132
+ llvm::ArrayRef<OpFoldResult> shape,
133
+ llvm::ArrayRef<OpFoldResult> strides) {
134
+ assert (shape.size () && strides.size () && shape.size () == strides.size () &&
135
+ " Shape and strides must be present and of equal size for ui64 "
136
+ " initialization." );
137
+
138
+ llvm::SmallVector<int64_t > staticShape;
139
+ llvm::SmallVector<int64_t > staticStrides;
140
+ llvm::SmallVector<Value> dynamicShape;
141
+ llvm::SmallVector<Value> dynamicStrides;
142
+
143
+ dispatchIndexOpFoldResults (shape, dynamicShape, staticShape);
144
+ dispatchIndexOpFoldResults (strides, dynamicStrides, staticStrides);
145
+
146
+ auto staticShapeAttr = builder.getDenseI64ArrayAttr (staticShape);
147
+ auto staticStridesAttr = builder.getDenseI64ArrayAttr (staticStrides);
148
+
149
+ build (builder, state, tdesc, source, ValueRange ({}), dynamicShape,
150
+ dynamicStrides, builder.getDenseI64ArrayAttr ({}), staticShapeAttr,
151
+ staticStridesAttr);
152
+ }
153
+
154
+ void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
155
+ Type tdesc, TypedValue<IntegerType> source,
156
+ llvm::ArrayRef<OpFoldResult> shape,
157
+ llvm::ArrayRef<OpFoldResult> strides) {
158
+ assert (shape.size () && strides.size () && shape.size () == strides.size () &&
159
+ " Shape and strides must be present and of equal size for ui64 "
160
+ " initialization." );
161
+
162
+ llvm::SmallVector<int64_t > staticShape;
163
+ llvm::SmallVector<int64_t > staticStrides;
164
+ llvm::SmallVector<Value> dynamicShape;
165
+ llvm::SmallVector<Value> dynamicStrides;
166
+
167
+ dispatchIndexOpFoldResults (shape, dynamicShape, staticShape);
168
+ dispatchIndexOpFoldResults (strides, dynamicStrides, staticStrides);
169
+
170
+ auto staticShapeAttr = builder.getDenseI64ArrayAttr (staticShape);
171
+ auto staticStridesAttr = builder.getDenseI64ArrayAttr (staticStrides);
172
+
173
+ build (builder, state, tdesc, source, ValueRange ({}), dynamicShape,
174
+ dynamicStrides, builder.getDenseI64ArrayAttr ({}), staticShapeAttr,
175
+ staticStridesAttr);
176
+ }
177
+
115
178
void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
116
179
Type tdesc, TypedValue<MemRefType> source,
117
180
llvm::ArrayRef<OpFoldResult> offsets) {
@@ -125,8 +188,8 @@ void CreateNdDescOp::build(OpBuilder &builder, OperationState &state,
125
188
build (builder, state, tdesc, source, dynamicOffsets /* dynamic offsets */ ,
126
189
ValueRange ({}) /* empty dynamic shape */ ,
127
190
ValueRange ({}) /* empty dynamic strides */ ,
128
- staticOffsets /* const offsets */ , {} /* empty const shape */ ,
129
- {} /* empty const strides*/ );
191
+ builder. getDenseI64ArrayAttr ( staticOffsets) /* const offsets */ ,
192
+ {} /* empty const shape */ , {} /* empty const strides*/ );
130
193
}
131
194
132
195
void CreateNdDescOp::build (OpBuilder &builder, OperationState &state,
@@ -197,6 +260,13 @@ LogicalResult CreateNdDescOp::verify() {
197
260
invalidElemTy |= memrefTy.getElementType () != getElementType ();
198
261
}
199
262
263
+ if (llvm::isa<IntegerType>(getSourceType ())) {
264
+ // strides and shape must present for integer source.
265
+ if (getMixedStrides ().empty () || getMixedSizes ().empty ())
266
+ return emitOpError (" Expecting strides and shape to be present for "
267
+ " integer source." );
268
+ }
269
+
200
270
// mismatches among shape, strides, and offsets are
201
271
// already handeled by OffsetSizeAndStrideOpInterface.
202
272
// So they are not check here.
@@ -221,6 +291,53 @@ LogicalResult CreateNdDescOp::verify() {
221
291
return success ();
222
292
}
223
293
294
+ ParseResult parseOptionalDynamicIndexList (
295
+ OpAsmParser &parser,
296
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
297
+ DenseI64ArrayAttr &integers, SmallVectorImpl<Type> *valueTypes = nullptr ,
298
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
299
+
300
+ SmallVector<int64_t , 4 > integerVals;
301
+ auto parseIntegerOrValue = [&]() {
302
+ OpAsmParser::UnresolvedOperand operand;
303
+ auto res = parser.parseOptionalOperand (operand);
304
+
305
+ if (res.has_value () && succeeded (res.value ())) {
306
+ values.push_back (operand);
307
+ integerVals.push_back (ShapedType::kDynamic );
308
+ if (valueTypes && parser.parseColonType (valueTypes->emplace_back ()))
309
+ return failure ();
310
+ } else {
311
+ int64_t integer;
312
+ if (failed (parser.parseInteger (integer)))
313
+ return failure ();
314
+ integerVals.push_back (integer);
315
+ }
316
+ return success ();
317
+ };
318
+
319
+ // If the optional values are given there must be left bracket
320
+ if (parser.parseOptionalLSquare ().succeeded ()) {
321
+ if (parser.parseCommaSeparatedList (parseIntegerOrValue) ||
322
+ parser.parseRSquare ())
323
+ return parser.emitError (parser.getNameLoc ())
324
+ << " expected a list of SSA values or integers" ;
325
+ integers = parser.getBuilder ().getDenseI64ArrayAttr (integerVals);
326
+ return success ();
327
+ }
328
+
329
+ return success ();
330
+ }
331
+
332
+ void printOptionalDynamicIndexList (
333
+ OpAsmPrinter &printer, Operation *op, OperandRange values,
334
+ ArrayRef<int64_t > integers, TypeRange valueTypes = TypeRange(),
335
+ AsmParser::Delimiter delimiter = AsmParser::Delimiter::Square) {
336
+
337
+ return printDynamicIndexList (printer, op, values, integers,
338
+ /* scalableFlags=*/ {}, valueTypes, delimiter);
339
+ }
340
+
224
341
// ===----------------------------------------------------------------------===//
225
342
// XeGPU_PrefetchNdOp
226
343
// ===----------------------------------------------------------------------===//
0 commit comments