@@ -2013,9 +2013,36 @@ class DefaultAttrsIntrinsicFlags<list<LLVMType> ret_types,
                !foreach(i, !range(flags),
                    ImmArg<ArgIndex<!add(i, !size(param_types))>>))>;

-// Intrinsics for Tensor Copy using TMA
-// G2S -> From Global to Shared memory variants
-// S2G -> From Shared to Global memory variants
+// TMA Tensor Copy Intrinsics: S2G -> From Shared to Global memory variants
+foreach dim = 1...5 in {
+  defvar tensor_dim_args = !listsplat(llvm_i32_ty, dim);
+  foreach mode = !if(!ge(dim, 3), ["tile", "im2col"], ["tile"]) in {
+    def int_nvvm_cp_async_bulk_tensor_s2g_ # mode # _ # dim # d :
+      DefaultAttrsIntrinsicFlags<[],
+          !listconcat([llvm_shared_ptr_ty,  // src_smem_ptr
+                       llvm_ptr_ty],        // tensormap_ptr
+                      tensor_dim_args,      // actual tensor dims
+                      [llvm_i64_ty]),       // cache_hint
+          [llvm_i1_ty],                     // Flag for cache_hint
+          [IntrConvergent,
+           ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>,
+           NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>]>;
+
+    // Intrinsics for TMA Copy with reduction
+    foreach red_op = ["add", "min", "max", "inc", "dec", "and", "or", "xor"] in
+      def int_nvvm_cp_async_bulk_tensor_reduce_ # red_op # _ # mode # _ # dim # d :
+        DefaultAttrsIntrinsicFlags<[],
+            !listconcat([llvm_shared_ptr_ty,  // src_smem_ptr
+                         llvm_ptr_ty],        // tensormap_ptr
+                        tensor_dim_args,      // actual tensor dims
+                        [llvm_i64_ty]),       // cache_hint
+            [llvm_i1_ty],                     // Flag for cache_hint
+            [IntrConvergent, ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>,
+             NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>]>;
+  }
+}
+
+// TMA Tensor Copy Intrinsics: G2S -> From Global to Shared memory variants
foreach dim = 1...5 in {
  defvar tensor_dim_args = !listsplat(llvm_i32_ty, dim);

@@ -2045,17 +2072,6 @@ foreach dim = 1...5 in {
    def int_nvvm_cp_async_bulk_tensor_g2s_ # mode # _ # dim # d :
      DefaultAttrsIntrinsicFlags<[], g2s_params, g2s_flags, g2s_props>;

-    def int_nvvm_cp_async_bulk_tensor_s2g_ # mode # _ # dim # d :
-      DefaultAttrsIntrinsicFlags<[],
-          !listconcat([llvm_shared_ptr_ty,  // src_smem_ptr
-                       llvm_ptr_ty],        // tensormap_ptr
-                      tensor_dim_args,      // actual tensor dims
-                      [llvm_i64_ty]),       // cache_hint
-          [llvm_i1_ty],                     // Flag for cache_hint
-          [IntrConvergent,
-           ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>,
-           NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>]>;
-
    def int_nvvm_cp_async_bulk_tensor_prefetch_ # mode # _ # dim # d :
      DefaultAttrsIntrinsicFlags<[],
          !listconcat([llvm_ptr_ty],         // tensormap_ptr
@@ -2065,18 +2081,6 @@ foreach dim = 1...5 in {
          [llvm_i1_ty],                      // Flag for cache_hint
          [IntrConvergent,
           ReadOnly<ArgIndex<0>>, NoCapture<ArgIndex<0>>]>;
-
-    // Intrinsics for TMA Copy with reduction
-    foreach red_op = ["add", "min", "max", "inc", "dec", "and", "or", "xor"] in
-      def int_nvvm_cp_async_bulk_tensor_reduce_ # red_op # _ # mode # _ # dim # d :
-        DefaultAttrsIntrinsicFlags<[],
-            !listconcat([llvm_shared_ptr_ty,  // src_smem_ptr
-                         llvm_ptr_ty],        // tensormap_ptr
-                        tensor_dim_args,      // actual tensor dims
-                        [llvm_i64_ty]),       // cache_hint
-            [llvm_i1_ty],                     // Flag for cache_hint
-            [IntrConvergent, ReadOnly<ArgIndex<0>>, ReadOnly<ArgIndex<1>>,
-             NoCapture<ArgIndex<0>>, NoCapture<ArgIndex<1>>]>;
  }
}

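Note: once TableGen expands the relocated S2G loop, each record named int_nvvm_cp_async_bulk_tensor_s2g_* becomes an intrinsic whose IR name replaces underscores with dots. A minimal sketch of calling the 3-D tile variant from LLVM IR follows; the value names (%src, %tmap, %d0..%d2, %ch) and the choice of a 3-D shape are illustrative assumptions, not part of this change.

; Sketch only: assumes the usual TableGen name mangling of the record above.
declare void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.3d(
    ptr addrspace(3), ptr, i32, i32, i32, i64, i1)

  ; Source lives in shared memory (addrspace 3), the tensormap is a generic
  ; pointer, followed by three i32 tensor coordinates, an i64 cache hint, and
  ; an immediate i1 flag telling the backend whether the cache hint is valid.
  call void @llvm.nvvm.cp.async.bulk.tensor.s2g.tile.3d(
      ptr addrspace(3) %src, ptr %tmap,
      i32 %d0, i32 %d1, i32 %d2,
      i64 %ch, i1 1)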