Skip to content

Commit 5123d36

Browse files
authored
[mlir][amdgpu] Lower make_gather_dma_descriptor. (#172083)
* Makes `MakeDescriptorOp` a template for `make_dma_descriptor` and `make_gather_dma_descriptor`. * Makes verification and folder for `make_dma_descriptor` a template. * Adds custom verification and folder for `make_gather_dma_descriptor` based on the template. * Adds `make_gather_dma_descriptor` op. * Lowers `make_gather_dma_descriptor` to ROCDL.
1 parent 3e32735 commit 5123d36

File tree

7 files changed

+489
-121
lines changed

7 files changed

+489
-121
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 71 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1303,7 +1303,7 @@ def AMDGPU_MakeGatherDmaBaseOp : AMDGPU_DmaBaseOp<"make_gather_dma_base", AMDGPU
13031303
let hasVerifier = 1;
13041304

13051305
let extraClassDeclaration = [{
1306-
constexpr bool isGather() {
1306+
static constexpr bool isGather() {
13071307
return true;
13081308
}
13091309
}];
@@ -1354,16 +1354,17 @@ def AMDGPU_MakeDmaBaseOp : AMDGPU_DmaBaseOp<"make_dma_base", AMDGPU_TDMBaseType>
13541354
let hasVerifier = 1;
13551355

13561356
let extraClassDeclaration = [{
1357-
constexpr bool isGather() {
1357+
static constexpr bool isGather() {
13581358
return false;
13591359
}
13601360
}];
13611361
}
13621362

1363-
def AMDGPU_MakeDmaDescriptorOp :
1364-
AMDGPU_Op<"make_dma_descriptor", [Pure, AttrSizedOperandSegments]>,
1365-
Arguments<(ins
1366-
AMDGPU_TDMBaseType: $base,
1363+
class AMDGPU_MakeDescriptorOp<string mnemonic> :
1364+
AMDGPU_Op<mnemonic, [Pure, AttrSizedOperandSegments]>,
1365+
Results<(outs AMDGPU_TDMDescriptorType: $desc)> {
1366+
1367+
dag baseArgs = (ins
13671368
Variadic<Index>: $global_dynamic_sizes,
13681369
DenseI64ArrayAttr: $global_static_sizes,
13691370
Variadic<Index>: $global_dynamic_strides,
@@ -1378,9 +1379,66 @@ def AMDGPU_MakeDmaDescriptorOp :
13781379
Variadic<Index>: $atomic_barrier_indices,
13791380
Optional<Index>: $global_increment,
13801381
Optional<I32>: $lds_increment,
1381-
Optional<Index>: $iteration_count)>,
1382-
Results<(outs AMDGPU_TDMDescriptorType: $desc)> {
1382+
Optional<Index>: $iteration_count);
1383+
1384+
code extraClassDeclarationBase = [{
1385+
int64_t getRank() {
1386+
return getGlobalStaticSizes().size();
1387+
}
1388+
1389+
unsigned getElementTypeWidth() {
1390+
return getBase().getType().getElementType().getIntOrFloatBitWidth();
1391+
}
1392+
1393+
SmallVector<OpFoldResult> getMixedGlobalSizes() {
1394+
return getMixedValues(getGlobalStaticSizes(), getGlobalDynamicSizes(), getContext());
1395+
}
1396+
1397+
SmallVector<OpFoldResult> getMixedGlobalStrides() {
1398+
return getMixedValues(getGlobalStaticStrides(), getGlobalDynamicStrides(), getContext());
1399+
}
1400+
1401+
SmallVector<OpFoldResult> getMixedSharedSizes() {
1402+
return getMixedValues(getSharedStaticSizes(), getSharedDynamicSizes(), getContext());
1403+
}
1404+
1405+
}];
1406+
1407+
}
1408+
1409+
// Gather variant of the DMA descriptor op. It derives from
// AMDGPU_MakeDescriptorOp, reusing that class's `baseArgs` operand list and
// the helper methods in `extraClassDeclarationBase`, and adds a gather base
// operand plus a vector of indices.
def AMDGPU_MakeGatherDmaDescriptorOp : AMDGPU_MakeDescriptorOp<"make_gather_dma_descriptor"> {
1410+
// Gather-specific operands: `$base` is a TDM gather base and `$indices` is a
// vector of either 1 to 8 i32 elements or 1 to 16 i16 elements.
dag args = (ins AMDGPU_TDMGatherBaseType: $base,
1411+
AnyTypeOf<[VectorOfMinMaxLengthAndType<1, 8, [I32]>,
1412+
VectorOfMinMaxLengthAndType<1, 16, [I16]>]>: $indices);
1413+
// The gather-specific operands come first, followed by the shared `baseArgs`.
let arguments = !con(args, baseArgs);
1414+
let summary = "Make all descriptor groups needed by TensorLoadToLDS/TensorStoreFromLDS.";
1415+
1416+
// Textual form: `$base[$indices]`, then the size/stride lists and the
// optional clauses, then the base type, the indices type and the result type.
let assemblyFormat = [{
1417+
$base `[` $indices `]`
1418+
`globalSize` custom<DynamicIndexList>($global_dynamic_sizes, $global_static_sizes)
1419+
`globalStride` custom<DynamicIndexList>($global_dynamic_strides, $global_static_strides)
1420+
`sharedSize` custom<DynamicIndexList>($shared_dynamic_sizes, $shared_static_sizes)
1421+
( `padShared` `(` $pad_amount^ `every` $pad_interval `)` )?
1422+
( `workgroupMask` $workgroup_mask^ ( `earlyTimeout` $early_timeout^)?)?
1423+
( `atomicBarrier` `(` $atomic_barrier_address^ `[` $atomic_barrier_indices `]`
1424+
`:` type($atomic_barrier_address) `)`)?
1425+
( `iterate` $global_increment^ `,` $lds_increment `,` $iteration_count )?
1426+
attr-dict `:` qualified(type($base)) `,` type($indices) `->` type(results)
1427+
}];
1428+
1429+
// This op declares a custom verifier and a folder.
let hasVerifier = 1;
1430+
let hasFolder = 1;
1431+
1432+
// Append the gather marker to the shared helper declarations; `isGather()`
// returns true for this op.
let extraClassDeclaration = extraClassDeclarationBase # [{
1433+
static constexpr bool isGather() {
1434+
return true;
1435+
}
1436+
}];
1437+
}
13831438

1439+
def AMDGPU_MakeDmaDescriptorOp : AMDGPU_MakeDescriptorOp<"make_dma_descriptor"> {
1440+
dag args = (ins AMDGPU_TDMBaseType: $base);
1441+
let arguments = !con(args, baseArgs);
13841442
let summary = "Make all descriptor groups needed by TensorLoadToLDS/TensorStoreFromLDS.";
13851443
let description = [{
13861444
Make all descriptor groups needed by tensor memory operations.
@@ -1437,30 +1495,15 @@ def AMDGPU_MakeDmaDescriptorOp :
14371495
attr-dict `:` qualified(type($base)) `->` type(results)
14381496
}];
14391497

1440-
let extraClassDeclaration = [{
1441-
int64_t getRank() {
1442-
return getGlobalStaticSizes().size();
1443-
}
1444-
1445-
unsigned getElementTypeWidth() {
1446-
return getBase().getType().getElementType().getIntOrFloatBitWidth();
1447-
}
1448-
1449-
SmallVector<OpFoldResult> getMixedGlobalSizes() {
1450-
return getMixedValues(getGlobalStaticSizes(), getGlobalDynamicSizes(), getContext());
1451-
}
1452-
1453-
SmallVector<OpFoldResult> getMixedGlobalStrides() {
1454-
return getMixedValues(getGlobalStaticStrides(), getGlobalDynamicStrides(), getContext());
1455-
}
1498+
let hasVerifier = 1;
1499+
let hasFolder = 1;
14561500

1457-
SmallVector<OpFoldResult> getMixedSharedSizes() {
1458-
return getMixedValues(getSharedStaticSizes(), getSharedDynamicSizes(), getContext());
1501+
let extraClassDeclaration = extraClassDeclarationBase # [{
1502+
static constexpr bool isGather() {
1503+
return false;
14591504
}
14601505
}];
14611506

1462-
let hasVerifier = 1;
1463-
let hasFolder = 1;
14641507
}
14651508

14661509
#endif // AMDGPU

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,18 @@ class IsVectorOfLengthPred<list<int> allowedLengths> :
537537
== }]
538538
# allowedlength>)>]>;
539539

540+
// Whether the number of elements of a vector is greater than
541+
// or equal to `minLength`.
542+
class IsVectorOfMinLengthPred<int minLength> :
543+
And<[IsVectorOfNonZeroRankTypePred,
544+
CPred<"::llvm::cast<::mlir::VectorType>($_self).getNumElements() >= " # minLength>]>;
545+
546+
// Whether the number of elements of a vector is less than
547+
// or equal to `maxLength`.
548+
class IsVectorOfMaxLengthPred<int maxLength> :
549+
And<[IsVectorOfNonZeroRankTypePred,
550+
CPred<"::llvm::cast<::mlir::VectorType>($_self).getNumElements() <= " # maxLength>]>;
551+
540552
// Whether the number of elements of a fixed-length vector is from the given
541553
// `allowedLengths` list
542554
class IsFixedVectorOfLengthPred<list<int> allowedLengths> :
@@ -600,6 +612,20 @@ class VectorOfLength<list<int> allowedLengths> : Type<
600612
" of length " # !interleave(allowedLengths, "/"),
601613
"::mlir::VectorType">;
602614

615+
// Any vector where the number of elements is greater than
616+
// or equal to `minLength`.
617+
class VectorOfMinLength<int minLength> : Type<
618+
IsVectorOfMinLengthPred<minLength>,
619+
" of at least length " # minLength,
620+
"::mlir::VectorType">;
621+
622+
// Any vector where the number of elements is less than
623+
// or equal to `maxLength`.
624+
class VectorOfMaxLength<int maxLength> : Type<
625+
IsVectorOfMaxLengthPred<maxLength>,
626+
" of at most length " # maxLength,
627+
"::mlir::VectorType">;
628+
603629
// Any fixed-length vector where the number of elements is from the given
604630
// `allowedLengths` list
605631
class FixedVectorOfLength<list<int> allowedLengths> : Type<
@@ -623,6 +649,14 @@ class VectorOfLengthAndType<list<int> allowedLengths,
623649
VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfLength<allowedLengths>.summary,
624650
"::mlir::VectorType">;
625651

652+
// Any vector of one of the `allowedTypes` where the number of elements is
653+
// between `minLength` and `maxLength` (inclusive).
654+
class VectorOfMinMaxLengthAndType<int minLength, int maxLength,
655+
list<Type> allowedTypes> : AllOfType<
656+
[VectorOfNonZeroRankOf<allowedTypes>, VectorOfMinLength<minLength>, VectorOfMaxLength<maxLength>],
657+
VectorOfNonZeroRankOf<allowedTypes>.summary # VectorOfMinLength<minLength>.summary # VectorOfMaxLength<maxLength>.summary,
658+
"::mlir::VectorType">;
659+
626660
class FixedVectorOfShapeAndType<list<int> shape, Type elType>: ShapedContainerType<
627661
[elType],
628662
And<[IsVectorOfShape<shape>, IsFixedVectorOfAnyRankTypePred]>,

0 commit comments

Comments
 (0)