@@ -125,42 +125,31 @@ Status PrepareQkv_Attention(contrib::AttentionParameters& parameters,
   bool use_fused_kernel = (nullptr != fused_runner && !parameters.is_unidirectional);
   bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
 
-  if (data.bias == nullptr) {
-    assert(nullptr == fused_runner);
-    // For quantized attention, bias has been added so only need transpose here.
-    // gemm_buffer should be BxSx3xNxH => qkv: 3xBxNxSxH
-    assert(qk_head_size == v_head_size);
-    int matrix_to_trans = (past_present_share_buffer ? 1 : 3);
-    ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, matrix_to_trans, sequence_length, batch_size, qk_head_size, num_heads,
-                                       max_threads_per_block, false, data.gemm_buffer, qkv, 3));
-    data.qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
-  } else {
-    // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
-    // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3)
-    // For unfused kernel, transpose to 3xBxNxSxH (format 1)
-    // For fused causal kernel, use format 1 since we need have K and V to update present state,
-    // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
-    const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1));
-    data.qkv_format = use_fused_kernel
-                          ? AttentionQkvFormat::QKV_BSN3H
-                          : (use_flash_or_efficient_attention
-                                 ? AttentionQkvFormat::Q_K_V_BSNH
-                                 : (use_fused_causal
-                                        ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH
-                                        : AttentionQkvFormat::Q_K_V_BNSH));
-
-    // For fused causal, we will update gemm_buffer with bias directly.
-    T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;
-
-    int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3);
-    // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v
-    // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H)
-    LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
-                           batch_size, sequence_length, num_heads, qk_head_size,
-                           data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
-                           3, parameters.do_rotary, parameters.rotary_embedding,
-                           parameters.past_sequence_length);
-  }
+  // For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
+  // For flash or memory efficient attention, transpose to 3xBxSxNxH (format 3)
+  // For unfused kernel, transpose to 3xBxNxSxH (format 1)
+  // For fused causal kernel, use format 1 since we need have K and V to update present state,
+  // at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
+  const int format = (use_fused_kernel ? 2 : (use_flash_or_efficient_attention ? 3 : 1));
+  data.qkv_format = use_fused_kernel
+                        ? AttentionQkvFormat::QKV_BSN3H
+                        : (use_flash_or_efficient_attention
+                               ? AttentionQkvFormat::Q_K_V_BSNH
+                               : (use_fused_causal
+                                      ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH
+                                      : AttentionQkvFormat::Q_K_V_BNSH));
+
+  // For fused causal, we will update gemm_buffer with bias directly.
+  T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;
+
+  int matrix_to_transpose = ((format == AttentionQkvFormat::Q_K_V_BNSH && past_present_share_buffer) ? 1 : 3);
+  // format 1: BxSx(NH + NH + NH_v) => BxNxSxH + BxNxSxH + BxNxSxH_v
+  // format 2: BxSx(NH + NH + NH) => BxSxNx(H + H + H)
+  LaunchAddBiasTranspose(stream, matrix_to_transpose, format, max_threads_per_block,
+                         batch_size, sequence_length, num_heads, qk_head_size,
+                         data.gemm_buffer, data.bias, qkv, true, v_head_size, qkv_add_bias,
+                         3, parameters.do_rotary, parameters.rotary_embedding,
+                         parameters.past_sequence_length);
 
   return Status::OK();
 }
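To make the layout comments in the diff concrete, here is a minimal standalone sketch (not ONNX Runtime code and not part of this change) of how the three target QKV layouts map a (Q/K/V, batch, sequence, head, element) coordinate to a flat offset. All names below are illustrative assumptions, and the sketch assumes qk_head_size == v_head_size for simplicity.

// Standalone illustration of the three QKV layouts referenced in the diff comments.
// m: 0 = Q, 1 = K, 2 = V; b: batch; s: sequence position; n: head; h: element within head.
#include <cstdint>
#include <cstdio>

// Format 1 (Q_K_V_BNSH): three tensors, each BxNxSxH, stored back to back (3xBxNxSxH).
int64_t OffsetFormat1(int m, int b, int s, int n, int h, int B, int S, int N, int H) {
  return (((int64_t(m) * B + b) * N + n) * S + s) * H + h;
}

// Format 2 (QKV_BSN3H): one interleaved tensor BxSxNx3xH, the layout fused TRT attention expects.
int64_t OffsetFormat2(int m, int b, int s, int n, int h, int B, int S, int N, int H) {
  return (((int64_t(b) * S + s) * N + n) * 3 + m) * H + h;
}

// Format 3 (Q_K_V_BSNH): three tensors, each BxSxNxH, stored back to back (3xBxSxNxH),
// the layout flash / memory-efficient attention expects.
int64_t OffsetFormat3(int m, int b, int s, int n, int h, int B, int S, int N, int H) {
  return (((int64_t(m) * B + b) * S + s) * N + n) * H + h;
}

int main() {
  // Example: where element (b=1, s=2, n=0, h=3) of K (m=1) lands in each layout
  // for B=2, S=4, N=8, H=64.
  const int B = 2, S = 4, N = 8, H = 64;
  std::printf("format 1: %lld\n", (long long)OffsetFormat1(1, 1, 2, 0, 3, B, S, N, H));
  std::printf("format 2: %lld\n", (long long)OffsetFormat2(1, 1, 2, 0, 3, B, S, N, H));
  std::printf("format 3: %lld\n", (long long)OffsetFormat3(1, 1, 2, 0, 3, B, S, N, H));
  return 0;
}

Formats 1 and 3 keep Q, K and V as separate contiguous blocks (which is why format 1 can carry a different v_head_size, per the "NH_v" comment), while format 2 interleaves Q, K and V per token and head for the fused TRT kernel.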