Why is non-decreasing dimension order on block ptr not supported? #207
-
In #140 a check against the `order` field of block pointers was introduced. With the following kernel snippet:

```python
pid_n = tl.program_id(axis=0)  # x
pid_m = tl.program_id(axis=1)  # y
offset_m = pid_m * BLOCK_SIZE_M
offset_n = pid_n * BLOCK_SIZE_N
in_block_ptr = tl.make_block_ptr(
    base=in_ptr,
    shape=(m, n),
    strides=(in_stride_m, in_stride_n),
    offsets=(offset_m, offset_n),
    block_shape=(BLOCK_SIZE_M, BLOCK_SIZE_N),
    order=(0, 1),  # increasing.
)
```

the check fails. I wonder what's the purpose of this check, doesn't
-
I am also interested, since this check does not allow running FlashAttention v2 from here:
-
Thanks for the question! This check was introduced because our lowering does not yet take the `order` field into account; it prevents us from producing incorrect code, since the conversion to `memref.reinterpret_cast` currently assumes a row-major layout. Using increasing order on a row-major tensor would therefore behave as if there were an implicit transpose. We are currently reworking some of the passes to support arbitrary pointer patterns; once that is done, we would appreciate any help updating the structured-to-memref pass to take the `order` field into account.
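To make the "implicit transpose" point concrete, here is a plain-Python sketch (not Triton code; the variable names and sizes are illustrative) of what the `order` field declares and why ignoring it changes which element each `(i, j)` index refers to. `order=(1, 0)` means the last dimension is fastest-varying in memory (row-major); `order=(0, 1)` means the first dimension is fastest-varying (column-major):

```python
# One flat buffer of m * n elements, viewed under two different orders.
m, n = 4, 3
buf = list(range(m * n))

# order=(1, 0): row-major, element (i, j) lives at offset i * n + j.
row_major = [[buf[i * n + j] for j in range(n)] for i in range(m)]

# order=(0, 1): column-major, element (i, j) lives at offset i + j * m.
col_major = [[buf[i + j * m] for j in range(n)] for i in range(m)]

# A lowering that assumes row-major layout would read the column-major
# view as if it were a row-major (n, m) view of the same bytes,
# transposed -- i.e., an implicit transpose sneaks in:
row_major_nm = [[buf[j * m + i] for i in range(m)] for j in range(n)]   # shape (n, m)
transpose_of_nm = [[row_major_nm[j][i] for j in range(n)] for i in range(m)]
assert col_major == transpose_of_nm
assert col_major != row_major
```

This is why the check rejects increasing order for now: silently reading the wrong elements would be worse than failing loudly.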