
Commit 0df0cb5

Add more diagnostic messages
1 parent a2a1160

File tree

  • aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h

1 file changed: +12 −1 lines

aten/src/ATen/native/transformers/hip/flash_attn/flash_api.h

Lines changed: 12 additions & 1 deletion
@@ -248,6 +248,7 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
                       gen_,
                       dummy_attn_bias); // Not used in flash attention
   } else {
+    TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention forward...");
     return mha_fwd_aot(q,
                        k,
                        v,
@@ -263,7 +264,8 @@ mha_fwd(const at::Tensor &q, // batch_size x seqlen_q x num_heads x head
 
   }
 #else
-  return mha_fwd_aot(q,
+  TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention forward...");
+  return mha_fwd_aot(q,
                      k,
                      v,
                      out_,
@@ -301,6 +303,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
 #if defined(USE_CK_FLASH_ATTENTION)
   if (at::globalContext().getROCmFAPreferredBackend() ==
       at::ROCmFABackend::Ck) {
+    TORCH_WARN_ONCE("Using CK backend for Flash Attention varlen forward...");
     std::optional<at::Tensor> dummy_attn_bias = std::nullopt;
     return mha_varlen_fwd_ck(
                q,
@@ -322,6 +325,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                gen_,
                dummy_attn_bias); // Not used in flash attention
   } else {
+    TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention varlen forward...");
     return mha_varlen_fwd_aot(q,
                               k,
                               v,
@@ -343,6 +347,7 @@ mha_varlen_fwd(const at::Tensor &q, // total_q x num_heads x head_size, total_q
                               gen_);
   }
 #else
+  TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention varlen forward...");
   return mha_varlen_fwd_aot(q,
                             k,
                             v,
@@ -389,6 +394,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
 #if defined(USE_CK_FLASH_ATTENTION)
   if (at::globalContext().getROCmFAPreferredBackend() ==
       at::ROCmFABackend::Ck) {
+    TORCH_WARN_ONCE("Using CK backend for Flash Attention backward...");
     std::optional<at::Tensor> non_null_dbias = std::nullopt;
     auto[dQuery,
          dKey,
@@ -418,6 +424,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
     // for FA return [dQ, dV, dK, dSoftmax]
     return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
   } else {
+    TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention backward...");
     return mha_bwd_aot(dout,
                        q,
                        k,
@@ -442,6 +449,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
       at::ROCmFABackend::Ck) {
     TORCH_WARN_ONCE("Warning! You have opted to use CK flash attention backend in a build that was not compiled using USE_CK_FLASH_ATTENTION=1. Please set this variable and try again. Defaulting to use aotriton backend...");
   }
+  TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention backward...");
   return mha_bwd_aot(
       dout,
       q,
@@ -492,6 +500,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
 #if defined(USE_CK_FLASH_ATTENTION)
   if (at::globalContext().getROCmFAPreferredBackend() ==
       at::ROCmFABackend::Ck) {
+    TORCH_WARN_ONCE("Using CK backend for Flash Attention varlen backward...");
     std::optional<at::Tensor> non_null_dbias = std::nullopt;
     auto[dQuery,
          dKey,
@@ -526,6 +535,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
     // for FA return [dQ, dV, dK, dSoftmax]
     return std::make_tuple(std::move(dQuery), std::move(dKey), std::move(dValue), std::move(dSoftmax));
   } else {
+    TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention varlen backward...");
     return mha_varlen_bwd_aot(dout,
                               q,
                               k,
@@ -551,6 +561,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
                               philox_offset);
   }
 #else
+  TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention varlen backward...");
   return mha_varlen_bwd_aot(dout,
                             q,
                             k,
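Read as a whole, the commit adds one TORCH_WARN_ONCE per dispatch branch so that the chosen ROCm Flash Attention backend (composable_kernel vs. AOTriton) is visible in the logs for the forward, varlen forward, backward, and varlen backward entry points. Below is a condensed, self-contained sketch of the pattern being instrumented; the kernel stand-ins are hypothetical placeholders, while the context query, enum value, build flag, and warning text all come from the diff above.

#include <ATen/Context.h>
#include <c10/util/Exception.h>

// Hypothetical stand-ins for the real kernels (mha_varlen_fwd_ck /
// mha_varlen_fwd_aot take the full argument lists shown in the diff).
void run_ck_varlen_fwd()       { /* composable_kernel path */ }
void run_aotriton_varlen_fwd() { /* AOTriton path */ }

void dispatch_varlen_fwd() {
#if defined(USE_CK_FLASH_ATTENTION)
  if (at::globalContext().getROCmFAPreferredBackend() ==
      at::ROCmFABackend::Ck) {
    TORCH_WARN_ONCE("Using CK backend for Flash Attention varlen forward...");
    run_ck_varlen_fwd();
  } else {
    TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention varlen forward...");
    run_aotriton_varlen_fwd();
  }
#else
  // CK support was not compiled in, so AOTriton is the only available backend.
  TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention varlen forward...");
  run_aotriton_varlen_fwd();
#endif
}

Note that the #else branch of mha_bwd already warned when the CK backend was requested in a build without USE_CK_FLASH_ATTENTION=1; the new message adds an explicit confirmation that the AOTriton fallback is the path actually taken.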

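Because the messages use TORCH_WARN_ONCE rather than TORCH_WARN, each one is emitted only the first time its call site is reached in a process, so the backend announcement does not repeat on every attention call inside a training loop. A minimal, hypothetical illustration of that behavior (only the macro itself comes from the commit):

#include <c10/util/Exception.h>

// Hypothetical demo function: the warning fires once, not 1000 times.
void log_backend_choice() {
  TORCH_WARN_ONCE("Using AOTriton backend for Flash Attention forward...");
}

int main() {
  for (int step = 0; step < 1000; ++step) {
    log_backend_choice();  // message printed only on the first iteration
  }
  return 0;
}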