@@ -1141,7 +1141,7 @@ def __call__(
1141
1141
hidden_states = hidden_states = F .scaled_dot_product_attention (
1142
1142
query , key , value , dropout_p = 0.0 , is_causal = False
1143
1143
)
1144
- trace_tensor ("attn_out" , hidden_states [0 ,0 ,0 ,0 ])
1144
+ # trace_tensor("attn_out", hidden_states[0,0,0,0])
1145
1145
1146
1146
hidden_states = hidden_states .transpose (1 , 2 ).reshape (batch_size , - 1 , attn .heads * head_dim )
1147
1147
hidden_states = hidden_states .to (query .dtype )
@@ -1152,7 +1152,7 @@ def __call__(
1152
1152
hidden_states [:, residual .shape [1 ] :],
1153
1153
)
1154
1154
hidden_states_cl = hidden_states .clone ()
1155
- trace_tensor ("attn_out" , hidden_states_cl [0 ,0 ,0 ])
1155
+ # trace_tensor("attn_out", hidden_states_cl[0,0,0])
1156
1156
# linear proj
1157
1157
hidden_states = attn .to_out [0 ](hidden_states_cl )
1158
1158
# dropout
@@ -1221,13 +1221,9 @@ def __call__(
1221
1221
query = query .view (batch_size , - 1 , attn .heads , head_dim ).transpose (1 , 2 )
1222
1222
key = key .view (batch_size , - 1 , attn .heads , head_dim ).transpose (1 , 2 )
1223
1223
value = value .view (batch_size , - 1 , attn .heads , head_dim ).transpose (1 , 2 )
1224
- trace_tensor ("query" , query )
1225
- trace_tensor ("key" , key )
1226
- trace_tensor ("value" , value )
1227
1224
hidden_states = hidden_states = F .scaled_dot_product_attention (
1228
1225
query , key , value , dropout_p = 0.0 , is_causal = False
1229
1226
)
1230
- trace_tensor ("attn_out" , hidden_states [:,:,:50 ])
1231
1227
1232
1228
hidden_states = hidden_states .transpose (1 , 2 ).reshape (batch_size , - 1 , attn .heads * head_dim )
1233
1229
hidden_states = hidden_states .to (query .dtype )
@@ -1597,10 +1593,6 @@ def __call__(
1597
1593
hidden_states = F .scaled_dot_product_attention (
1598
1594
query , key , value , attn_mask = attention_mask , dropout_p = 0.0 , is_causal = False
1599
1595
)
1600
- trace_tensor ("query" , query )
1601
- trace_tensor ("key" , key )
1602
- trace_tensor ("value" , value )
1603
- trace_tensor ("attn_out" , hidden_states [:,:,:50 ])
1604
1596
hidden_states = hidden_states .transpose (1 , 2 ).reshape (batch_size , - 1 , attn .heads * head_dim )
1605
1597
hidden_states = hidden_states .to (query .dtype )
1606
1598
0 commit comments