import torch
import torch.nn.functional as F
from torch import nn
+ import numpy as np

from ..image_processor import IPAdapterMaskProcessor
from ..utils import deprecate, logging
from ..utils.import_utils import is_torch_npu_available, is_xformers_available
from ..utils.torch_utils import maybe_allow_in_graph
from .lora import LoRALinearLayer
- from shark_turbine.ops.iree import trace_tensor
+ # from shark_turbine.ops.iree import trace_tensor


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
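Note on the import change above: it drops the hard dependency on shark_turbine (the `trace_tensor` calls are commented out in the hunks below) and adds `numpy` for the offline dumps. If the trace hooks were meant to stay enabled without requiring shark_turbine, a guarded import with a no-op fallback would be one option; a minimal sketch, not part of this commit:

# Hypothetical alternative to commenting out the import:
# fall back to a no-op trace_tensor when shark_turbine is absent.
try:
    from shark_turbine.ops.iree import trace_tensor
except ImportError:
    def trace_tensor(name, tensor):
        # No-op stand-in so trace calls can remain in place.
        pass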
@@ -1116,10 +1117,13 @@ def __call__(
batch_size = encoder_hidden_states.shape[0]

# `sample` projections.
+ # trace_tensor("hidden_states", hidden_states[0,0,0])
query = attn.to_q(hidden_states)
key = attn.to_k(hidden_states)
value = attn.to_v(hidden_states)
-
+ # trace_tensor("query_pre_proj", query[0,0,0])
+ # trace_tensor("key_pre_proj", key[0,0,0])
+ # trace_tensor("value_pre_proj", value[0,0,0])
# `context` projections.
encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
@@ -1129,20 +1133,18 @@ def __call__(
query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
-
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
- # trace_tensor("query", query[0,0,0])
- # trace_tensor("key", key[0,0,0])
- # trace_tensor("value", value[0,0,0])
+ # np.save("q.npy", query.detach().cpu().numpy())
+ # np.save("k.npy", key.detach().cpu().numpy())
+ # np.save("v.npy", value.detach().cpu().numpy())

hidden_states = F.scaled_dot_product_attention(
    query, key, value, dropout_p=0.0, is_causal=False
)
# trace_tensor("attn_out", hidden_states[0,0,0,0])
-
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
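For reference, the reshape/attention pattern this hunk exercises, shown standalone with placeholder dimensions (the sizes are illustrative, not taken from the model):

import torch
import torch.nn.functional as F

# (B, S, H*D) -> (B, H, S, D), attention, then merge heads back,
# mirroring the view/transpose/SDPA/reshape sequence above.
batch, seq_len, heads, head_dim = 2, 16, 8, 64
q = torch.randn(batch, seq_len, heads * head_dim)
k = torch.randn(batch, seq_len, heads * head_dim)
v = torch.randn(batch, seq_len, heads * head_dim)

q = q.view(batch, -1, heads, head_dim).transpose(1, 2)
k = k.view(batch, -1, heads, head_dim).transpose(1, 2)
v = v.view(batch, -1, heads, head_dim).transpose(1, 2)

out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)
out = out.transpose(1, 2).reshape(batch, -1, heads * head_dim)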
@@ -1152,7 +1154,7 @@ def __call__(
    hidden_states[:, residual.shape[1] :],
)
hidden_states_cl = hidden_states.clone()
- trace_tensor("attn_out", hidden_states_cl[0,0,0])
+ # trace_tensor("attn_out", hidden_states_cl[0,0,0])
# linear proj
hidden_states = attn.to_out[0](hidden_states_cl)
# dropout
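The commented-out np.save calls take over from trace_tensor as the debugging mechanism: they dump the post-reshape query/key/value tensors to disk so two runs (e.g. eager PyTorch vs. the compiled path) can be compared offline. A minimal sketch of such a comparison, assuming the dumps from the two runs are stored under hypothetical `eager/` and `compiled/` directories:

import numpy as np

# Compare q/k/v dumps from two runs; the directory names are placeholders.
for name in ("q", "k", "v"):
    ref = np.load(f"eager/{name}.npy")
    test = np.load(f"compiled/{name}.npy")
    print(name,
          "max abs diff:", np.abs(ref - test).max(),
          "allclose:", np.allclose(ref, test, atol=1e-3, rtol=1e-3))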