#include <ATen/core/Tensor.h>
#include <ATen/native/transformers/attention.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty_like.h>
#include <ATen/ops/linear.h>
#include <ATen/ops/matmul.h>
#include <ATen/ops/scaled_dot_product_attention.h>
#include <ATen/ops/softmax.h>
#endif

#include <ATen/native/cutlass/Attention.h>
#include <ATen/native/cutlass/sycl/AttentionKernels.h>

#include <comm/SYCLContext.h>

#include <cmath>
#include <iostream>
namespace at {
namespace native {
namespace cutlass_sycl {

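// Backward pass for scaled dot-product attention on the CUTLASS SYCL backend.
// This reference path recomputes the attention probabilities with ATen ops and
// forwards raw device pointers to the cutlass_sdpa_backward kernel, which
// fills grad_query, grad_key and grad_value. out, logsumexp, attn_mask and
// is_causal are accepted but not consumed here yet.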
void sdpa_backward(
    int batch_size,
    int num_head_q,
    int num_head_kv,
    int seq_len_q,
    int seq_len_kv,
    int head_dim_qk,
    int head_dim_v,
    const Tensor& grad_out,
    const Tensor& query,
    const Tensor& key,
    const Tensor& value,
    const Tensor& out,
    const Tensor& logsumexp,
    std::optional<at::Tensor> attn_mask,
    bool is_causal,
    double scale,
    Tensor& grad_query,
    Tensor& grad_key,
    Tensor& grad_value) {
  // Temporary debug trace.
  std::cout << " lfq: entering cutlass sdpa_backward" << std::endl;

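  // Recompute the attention probabilities P = softmax(Q * K^T / sqrt(scale)).
  // Note: the usual SDPA convention multiplies the logits by `scale` (which
  // defaults to 1/sqrt(head_dim)), so dividing by sqrt(scale) here assumes the
  // caller passes the head dimension in `scale`.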
  auto ps = at::matmul(query, key.transpose(-2, -1));
  ps = ps / std::sqrt(scale);
  ps = at::softmax(ps, -1).to(query.dtype());
  // dps holds the gradient w.r.t. the attention probabilities.
  auto dps = at::empty_like(ps);
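  // Dispatch to the CUTLASS SYCL backward kernel with raw device pointers.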
  cutlass_sdpa_backward(
      batch_size, num_head_q, num_head_kv, seq_len_q, seq_len_kv,
      head_dim_qk, head_dim_v,
      grad_out.data_ptr(),
      query.data_ptr(),
      key.data_ptr(),
      value.data_ptr(),
      ps.data_ptr(),
      nullptr,
      grad_query.data_ptr(),
      grad_key.data_ptr(),
      grad_value.data_ptr(),
      dps.data_ptr());
}

} // namespace cutlass_sycl
} // namespace native
} // namespace at