python/paddle/nn/functional/flash_attention.py (+61 −1: 61 additions, 1 deletion)
@@ -1575,6 +1575,7 @@ def flashmask_attention(
     training: bool = True,
     name: str | None = None,
     softmax_scale: float | None = None,
+    block_mask_indices: Tensor | None = None,
 ):
     r"""
     FlashMask: Official Implementation
@@ -1635,6 +1636,26 @@ def flashmask_attention(
         training (bool): Whether the module is in training mode. Default is True.
         name (str, optional): Name of the operation. Default is None. Normally, users do not need to set this property.
             For more information, refer to :ref:`api_guide_Name` .
+        block_mask_indices (Tensor, optional):
+            A 4-D integer mask tensor indicating whether each block of the attention matrix is kept or masked.
+            Must be used together with the flashmask (``startend_row_indices``).
+            The shape must be [batch_size, num_heads, blocklen_q, blocklen_k], where:
+
+            - blocklen_q = ceil(seqlen_q / 128), i.e. block_mask_indices.shape[2] must be (seqlen_q + 127) // 128
+            - blocklen_k = ceil(seqlen_k / 128), i.e. block_mask_indices.shape[3] must be (seqlen_k + 127) // 128
+            - block_mask_indices.shape[1] (number of heads) must match the num_heads dimension of the flashmask
+            - both seqlen_q and seqlen_k must be less than or equal to 128 * 1024
+
+            The dtype must be int32, and each element must be either 0 or 1.
+            A value of 1 keeps the corresponding block (not masked), while 0 masks it.
+
+            Usage notes:
+
+            - Only supported when blockdim_q = blockdim_k = 128 for now.
+            - Only supported when headdim = 128 for now.
+            - This argument must be provided together with the flashmask.
+            - The mask is applied at the block level: each [i, j] position in block_mask_indices controls whether the corresponding [128 x 128] block of the attention matrix is masked.
+            - Any mismatch in the expected shape or head dimension raises an error.
+

     Returns
         Tensor. The computed attention result with the same shape as the input `query`.
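A minimal usage sketch of the new argument may help; it is based only on the constraints documented in the docstring above. The query/key/value layout, the `causal` flag, and the "full" `startend_row_indices` trick are assumptions for illustration, not taken from this diff.

```python
import paddle
import paddle.nn.functional as F

# Illustrative sizes; the docstring above requires headdim = 128 and a fixed
# 128 x 128 block size for block_mask_indices.
bsz, num_heads, seqlen, headdim = 2, 8, 1024, 128

q = paddle.randn([bsz, seqlen, num_heads, headdim]).astype("float16")
k = paddle.randn([bsz, seqlen, num_heads, headdim]).astype("float16")
v = paddle.randn([bsz, seqlen, num_heads, headdim]).astype("float16")

# block_mask_indices must be provided together with startend_row_indices.
# Assumption: a "full" causal flashmask (start row = seqlen for every column),
# so the extra sparsity comes only from the block mask; this mirrors the open
# question in the diff's comment rather than documented behaviour.
startend_row_indices = paddle.full([bsz, num_heads, seqlen, 1], seqlen, dtype="int32")

# One int32 entry per 128 x 128 block: 1 keeps the block, 0 masks it.
blocklen_q = (seqlen + 127) // 128  # must equal block_mask_indices.shape[2]
blocklen_k = (seqlen + 127) // 128  # must equal block_mask_indices.shape[3]
block_mask_indices = paddle.ones([bsz, num_heads, blocklen_q, blocklen_k], dtype="int32")
block_mask_indices[:, :, :, blocklen_k // 2:] = 0  # e.g. mask the right half of the key blocks

out = F.flashmask_attention(
    q,
    k,
    v,
    startend_row_indices=startend_row_indices,
    causal=True,
    block_mask_indices=block_mask_indices,
)
print(out.shape)  # same shape as q: [bsz, seqlen, num_heads, headdim]
```

The relationships the sketch relies on are exactly those the docstring states: one int32 entry per 128 x 128 tile of the seqlen_q x seqlen_k attention matrix, a head count matching the flashmask, and both sequence lengths capped at 128 * 1024.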
@@ -2207,6 +2228,12 @@ def flashmask_attention(
             startend_row_indices, min=0, max=sq
         ).repeat_interleave(bsz, 0)

+    if block_mask_indices is not None:
+        # xhy: can set a full startend_row_indices for block_mask_attn when using block_mask_attn?
+        assert startend_row_indices is not None, (
+            "must provide startend_row_indices when using block_mask_attn"
+        )
+
     if startend_row_indices is None:
         (
             out,
@@ -2248,6 +2275,33 @@ def flashmask_attention(
                 "startend_row_indices head_num must be equal to 1(broadcast) or head_num_k."
             )

+        if block_mask_indices is not None:
+            assert block_mask_indices.dtype == paddle.int32, (
+                f"block_mask_indices.dtype must be paddle.int32, but got {block_mask_indices.dtype}"