from ..utils.import_utils import is_torch_npu_available, is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .lora import LoRALinearLayer
+ from shark_turbine.ops.iree import trace_tensor


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -816,6 +817,8 @@ def __call__(
value = attn.head_to_batch_dim(value)

attention_probs = attn.get_attention_scores(query, key, attention_mask)
+
+
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)

@@ -922,6 +925,7 @@ def __call__(
value = attn.head_to_batch_dim(value)

attention_probs = attn.get_attention_scores(query, key, attention_mask)
+
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)

@@ -1131,10 +1135,14 @@ def __call__(
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
+ # trace_tensor("query", query[0, 0, 0])
+ # trace_tensor("key", key[0, 0, 0])
+ # trace_tensor("value", value[0, 0, 0])
hidden_states = hidden_states = F.scaled_dot_product_attention(
    query, key, value, dropout_p=0.0, is_causal=False
)
+ trace_tensor("attn_out", hidden_states[0, 0, 0, 0])
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

@@ -1143,9 +1151,10 @@ def __call__(
    hidden_states[:, : residual.shape[1]],
    hidden_states[:, residual.shape[1] :],
)
-
+ hidden_states_cl = hidden_states.clone()
+ trace_tensor("attn_out", hidden_states_cl[0, 0, 0])
# linear proj
- hidden_states = attn.to_out[0](hidden_states)
+ hidden_states = attn.to_out[0](hidden_states_cl)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if not attn.context_pre_only:
@@ -1212,10 +1221,14 @@ def __call__(
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
+ trace_tensor("query", query)
+ trace_tensor("key", key)
+ trace_tensor("value", value)
hidden_states = hidden_states = F.scaled_dot_product_attention(
    query, key, value, dropout_p=0.0, is_causal=False
)
+ trace_tensor("attn_out", hidden_states[:, :, :50])
+
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

@@ -1584,7 +1597,10 @@ def __call__(
hidden_states = F.scaled_dot_product_attention(
    query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
-
+ trace_tensor("query", query)
+ trace_tensor("key", key)
+ trace_tensor("value", value)
+ trace_tensor("attn_out", hidden_states[:, :, :50])
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

@@ -1778,6 +1794,7 @@ def __call__(
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

+
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
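
Taken together, the hunks above apply one pattern: trace_tensor(name, tensor) from shark_turbine.ops.iree is called on the query/key/value tensors and on the scaled-dot-product-attention output so those intermediates can be observed when the attention processors are run through IREE. Below is a minimal standalone sketch of that pattern; the ToyAttention module, its shapes, and the __main__ driver are illustrative only, and the sole shark_turbine API assumed is the trace_tensor(name, tensor) call that the diff itself imports and uses.

# Minimal sketch of the tracing pattern used in the hunks above (illustrative,
# not the diffusers processors). Assumes shark_turbine is installed; the only
# shark_turbine API relied on is trace_tensor(name, tensor), as used in the diff.
import torch
import torch.nn.functional as F
from shark_turbine.ops.iree import trace_tensor


class ToyAttention(torch.nn.Module):
    # Illustrative attention block operating on (batch, heads, seq_len, head_dim).
    def forward(self, query, key, value):
        # Tag the inputs so they can be located in the trace output.
        trace_tensor("query", query)
        trace_tensor("key", key)
        trace_tensor("value", value)
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, dropout_p=0.0, is_causal=False
        )
        # The diff sometimes traces only a slice (e.g. hidden_states[:, :, :50])
        # to keep the traced payload small; the full tensor is traced here.
        trace_tensor("attn_out", hidden_states)
        return hidden_states


if __name__ == "__main__":
    q = k = v = torch.randn(1, 8, 64, 32)  # (batch, heads, seq_len, head_dim)
    out = ToyAttention()(q, k, v)
    print(out.shape)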