python/paddle/nn/functional/flash_attention.py
31 additions & 34 deletions
@@ -1575,7 +1575,7 @@ def flashmask_attention(
     training: bool=True,
     name: str|None=None,
     softmax_scale: float|None=None,
-    block_mask_indices: Tensor|None=None,
+    block_mask: Tensor|None=None,
 ):
     r"""
     FlashMask: Official Implementation
@@ -1636,25 +1636,24 @@ def flashmask_attention(
         training (bool): Whether the module is in training mode. Default is True.
         name (str, optional): Name of the operation. Default is None. Normally, users do not need to set this property.
             For more information, refer to :ref:`api_guide_Name` .
-        block_mask_indices (tensor, optional):
-        block_mask_indices (Tensor, optional):
-            A 4-D integer mask tensor indicating whether each block in the attention matrix should be kept or masked. Must be used together with flashmask.
-            The shape should be [batch_size, num_heads, blocklen_q, blocklen_k], where:
+        block_mask (tensor, optional):
+            A 4-D integer mask tensor indicating whether each block in the attention matrix should be kept or masked. Must be used together with flashmask.
+            The shape should be [batch_size, num_heads, blocklen_q, blocklen_k], where:
 
-            blocklen_q = ceil(seqlen_q / 128), i.e., block_mask_indices.shape[2] must be (seqlen_q + 127) // 128
-            blocklen_k = ceil(seqlen_k / 128), i.e., block_mask_indices.shape[3] must be (seqlen_k + 127) // 128
-            block_mask_indices.shape[1] (number of heads) must match the num_heads dimension of the flashmask
-            Both seqlen_q and seqlen_k must be less than or equal to 128 * 1024
-            The dtype should be int32, and each element should be either 0 or 1.
-            A value of 1 indicates that the corresponding block is kept (not masked), while 0 means the block is masked.
+            blocklen_q = ceil(seqlen_q / 128), i.e., block_mask.shape[2] must be (seqlen_q + 127) // 128
+            blocklen_k = ceil(seqlen_k / 128), i.e., block_mask.shape[3] must be (seqlen_k + 127) // 128
+            block_mask.shape[1] (number of heads) must match the num_heads dimension of the flashmask
+            Both seqlen_q and seqlen_k must be less than or equal to 128 * 1024
+            The dtype should be int32, and each element should be either 0 or 1.
+            A value of 1 indicates that the corresponding block is kept (not masked), while 0 means the block is masked.
 
-            Usage Notes:
+        Usage Notes:
 
-            Only supported when blockdim_q = blockdim_k = 128 now.
-            Only supported when headdim = 128 now.
-            This argument must be provided together with flashmask.
-            The mask will be applied at the block level: each [i, j] position in block_mask_indices controls whether the corresponding [128 x 128] block in the attention matrix is masked.
-            Any mismatch in expected shape or head dimension will raise an error.
+            Only supported when blockdim_q = blockdim_k = 128 now.
+            Only supported when headdim = 128 now.
+            This argument must be provided together with flashmask.
+            The mask will be applied at the block level: each [i, j] position in block_mask controls whether the corresponding [128 x 128] block in the attention matrix is masked.
+            Any mismatch in expected shape or head dimension will raise an error.
 
 
     Returns
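
As context for reviewers, a minimal sketch of how a caller might drive the renamed argument. The keyword block_mask, the int32 0/1 encoding, and the 128 x 128 block granularity come from the docstring above; the tensor shapes, the all-causal startend_row_indices encoding, and causal=True are illustrative assumptions, not something this diff specifies.

import paddle
import paddle.nn.functional as F

# Illustrative shapes; headdim must be 128 per the Usage Notes above.
bsz, num_heads, seqlen, headdim = 2, 8, 1024, 128
q = paddle.randn([bsz, seqlen, num_heads, headdim]).astype("bfloat16")
k = paddle.randn([bsz, seqlen, num_heads, headdim]).astype("bfloat16")
v = paddle.randn([bsz, seqlen, num_heads, headdim]).astype("bfloat16")

# One int32 entry per 128 x 128 tile of the attention matrix:
# 1 keeps a block, 0 masks it. Here the lower-triangular blocks are kept.
blocklen = (seqlen + 127) // 128  # ceil(seqlen / 128)
block_mask = (
    paddle.tril(paddle.ones([blocklen, blocklen], dtype="int32"))
    .unsqueeze(0)
    .unsqueeze(0)
    .expand([bsz, num_heads, blocklen, blocklen])
)

# block_mask must be provided together with a flashmask; an all-causal
# startend_row_indices (every LT start row set to seqlen) is assumed here.
startend_row_indices = paddle.full(
    [bsz, num_heads, seqlen, 1], seqlen, dtype="int32"
)

out = F.flashmask_attention(
    q,
    k,
    v,
    startend_row_indices=startend_row_indices,
    causal=True,
    block_mask=block_mask,  # renamed from block_mask_indices by this change
)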
@@ -2228,7 +2227,7 @@ def flashmask_attention(
             startend_row_indices, min=0, max=sq
         ).repeat_interleave(bsz, 0)
 
-        if block_mask_indices is not None:
+        if block_mask is not None:
             # xhy: can set a full startend_row_indices for block_mask_attn when using block_mask_attn?
             assert startend_row_indices is not None, (
                 "must provide startend_row_indices when using block_mask_attn"
@@ -2275,26 +2274,24 @@ def flashmask_attention(
             "startend_row_indices head_num must be equal to 1(broadcast) or head_num_k."
         )
 
-        if block_mask_indices is not None:
-            assert block_mask_indices.dtype == paddle.int32, (
-                f"block_mask_indices.dtype must be paddle.int32, but got {block_mask_indices.dtype}"
+        if block_mask is not None:
+            assert block_mask.dtype == paddle.int32, (
+                f"block_mask.dtype must be paddle.int32, but got {block_mask.dtype}"
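
The hunk above, as captured, ends at the dtype assertion. For reference, a hypothetical standalone checker that restates the documented constraints in one place; check_block_mask and its signature are invented for illustration and do not appear in this diff.

import paddle

def check_block_mask(block_mask, seqlen_q, seqlen_k, num_heads):
    # Hypothetical helper; restates the docstring constraints, not diff code.
    assert block_mask.dtype == paddle.int32, (
        f"block_mask.dtype must be paddle.int32, but got {block_mask.dtype}"
    )
    # Both sequence lengths are capped at 128 * 1024.
    assert seqlen_q <= 128 * 1024 and seqlen_k <= 128 * 1024
    # One entry per 128 x 128 block of the attention matrix.
    expected_q = (seqlen_q + 127) // 128  # blocklen_q = ceil(seqlen_q / 128)
    expected_k = (seqlen_k + 127) // 128  # blocklen_k = ceil(seqlen_k / 128)
    assert block_mask.shape[2] == expected_q and block_mask.shape[3] == expected_k
    # The head dimension must match the flashmask's num_heads dimension.
    assert block_mask.shape[1] == num_heads
    # Every element must be 0 (masked) or 1 (kept).
    assert bool(((block_mask == 0) | (block_mask == 1)).all())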