@@ -71,38 +71,38 @@ def flash_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len)
             ext.prompt_flash_attention(single_out, single_q, single_k, single_v, None, mask, [], head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
         return out
 
-    def fused_context_attention(q, k, v, out, b_start_loc, b_seq_len, max_input_len, head, numKeyValueHeads, dim):
-        batch = b_start_loc.shape[0]
-        scale = 1 / math.sqrt(dim)
-        mask_key_str = str(batch) + ":" + str(max_input_len)
-        if mask_key_str not in mask_cache:
-            mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0).cuda()
-            mask = mask.repeat(batch, 1, 1)
-            mask = torch.logical_not(mask)
-            mask_cache[mask_key_str] = mask
-            print(f"cache mask in context attention, batch:seqLen={mask_key_str}")
+    # def fused_context_attention(out, q, k, v, mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim):
+    #     batch = b_start_loc.shape[0]
+    #     scale = 1 / math.sqrt(dim)
+    #     mask_key_str = str(batch) + ":" + str(max_input_len)
+    #     if mask_key_str not in mask_cache:
+    #         mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0).cuda()
+    #         mask = mask.repeat(batch, 1, 1)
+    #         mask = torch.logical_not(mask)
+    #         mask_cache[mask_key_str] = mask
+    #         print(f"cache mask in context attention, batch:seqLen={mask_key_str}")
 
-        mask = mask_cache[mask_key_str]
-        ext.prompt_flash_attention(out, q, k, v,
-                                   None, mask, b_seq_len, head, scale, 2147473647, 0, "BSH", numKeyValueHeads)
-        return out
-
-    context_attention_pack.context_attention_fwd = (
-        # flash_context_attention
-        fused_context_attention
-    )
+    #     mask = mask_cache[mask_key_str]
+    #     ext.prompt_flash_attention(out, q, k, v,
+    #                                mask, b_seq_len, max_input_len, head, numKeyValueHeads, dim)
+    #     return out
 
+    # context_attention_pack.context_attention_fwd = (
+    #     # flash_context_attention
+    #     fused_context_attention
+    # )
+    context_attention_pack.prompt_flash_attention = ext.prompt_flash_attention
 
 def patch_paged_token_attention_inference():
-    def paged_token_attention(q, k_cache, v_cache, out, q_head_num, kv_head_num, head_dim, b_seq_len, block_table:torch.Tensor, block_size):
-        ext.paged_attention(out, q, k_cache, v_cache, None, None,
-                            b_seq_len, block_table, q_head_num, kv_head_num,
-                            1.0 / math.sqrt(head_dim), "BSH", block_size, 0,
-                            None, None, None, None, None, None, None, None
-                            )
-        return out
+    # def paged_token_attention(q, k_cache, v_cache, out, q_head_num, kv_head_num, head_dim, b_seq_len, block_table:torch.Tensor, block_size):
+    #     ext.paged_attention(out, q, k_cache, v_cache, None, None,
+    #                         b_seq_len, block_table, q_head_num, kv_head_num,
+    #                         1.0 / math.sqrt(head_dim), "BSH", block_size, 0,
+    #                         None, None, None, None, None, None, None, None
+    #                         )
+    #     return out
 
-    token_attention_pack.paged_token_attention = (paged_token_attention)
+    token_attention_pack.paged_token_attention = ext.paged_attention
 
 
 def patch_token_attention_inference():
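
For reference, a minimal sketch of the causal mask that the now commented-out fused_context_attention built and cached per batch/max_input_len key. The sizes below are hypothetical and the .cuda() transfer is omitted; this is illustration only, not part of the patch.

import torch

batch, max_input_len = 2, 16  # hypothetical sizes
# Lower-triangular keep-mask, then inverted: True marks positions to mask out,
# matching the torch.logical_not(torch.tril(...)) convention in the removed wrapper.
mask = torch.tril(torch.ones(max_input_len, max_input_len, dtype=torch.bool), diagonal=0)
mask = torch.logical_not(mask).repeat(batch, 1, 1)  # shape: (batch, max_input_len, max_input_len)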