Commit f3db233

fix: prob attention shape error while bs>1 (#50)
1 parent 679e349 commit f3db233

2 files changed: +10 -10 lines changed


Diff for: tests/test_models/test_informer.py

+9 -8

@@ -129,20 +129,21 @@ def test_train(self):
         predict_length = 10
         n_encoder_feature = 2
         n_decoder_feature = 3
+        batch_size = 1
 
         x_train = (
-            np.random.rand(1, train_length, 1),
-            np.random.rand(1, train_length, n_encoder_feature),
-            np.random.rand(1, predict_length, n_decoder_feature),
+            np.random.rand(batch_size, train_length, 1),
+            np.random.rand(batch_size, train_length, n_encoder_feature),
+            np.random.rand(batch_size, predict_length, n_decoder_feature),
         )
-        y_train = np.random.rand(1, predict_length, 1)  # target: (batch, predict_length, 1)
+        y_train = np.random.rand(batch_size, predict_length, 1)  # target: (batch, predict_length, 1)
 
         x_valid = (
-            np.random.rand(1, train_length, 1),
-            np.random.rand(1, train_length, n_encoder_feature),
-            np.random.rand(1, predict_length, n_decoder_feature),
+            np.random.rand(batch_size, train_length, 1),
+            np.random.rand(batch_size, train_length, n_encoder_feature),
+            np.random.rand(batch_size, predict_length, n_decoder_feature),
         )
-        y_valid = np.random.rand(1, predict_length, 1)
+        y_valid = np.random.rand(batch_size, predict_length, 1)
 
         model = AutoModel("Informer", predict_length=predict_length, custom_model_params=custom_params)
         trainer = KerasTrainer(model)
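
The test change above replaces every hard-coded batch dimension of 1 with a single `batch_size` variable, so the same shapes can be exercised with larger batches. A minimal sketch of those input shapes with a batch size above 1, using assumed values for `train_length` and the feature counts (the real test defines them earlier in `test_train`):

```python
import numpy as np

# Assumed sizes for illustration; the real test defines these earlier.
batch_size, train_length, predict_length = 4, 20, 10
n_encoder_feature, n_decoder_feature = 2, 3

# Informer inputs keep the (batch, length, features) layout for any batch size.
x_train = (
    np.random.rand(batch_size, train_length, 1),
    np.random.rand(batch_size, train_length, n_encoder_feature),
    np.random.rand(batch_size, predict_length, n_decoder_feature),
)
y_train = np.random.rand(batch_size, predict_length, 1)  # target: (batch, predict_length, 1)
print([a.shape for a in x_train], y_train.shape)
```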

Diff for: tfts/layers/attention_layer.py

+1 -2

@@ -155,10 +155,9 @@ def _prob_qk(self, q, k, sample_k, top_n):
         K_sample = tf.gather(K_sample, indx_q_seq, axis=2)
         K_sample = tf.gather(K_sample, indx_k_seq, axis=3)
 
-        Q_K_sample = tf.squeeze(tf.matmul(tf.expand_dims(q, -2), tf.einsum("...ij->...ji", K_sample)))
+        Q_K_sample = tf.squeeze(tf.matmul(tf.expand_dims(q, -2), tf.einsum("...ij->...ji", K_sample)), axis=3)
         M = tf.math.reduce_max(Q_K_sample, axis=-1) - tf.raw_ops.Div(x=tf.reduce_sum(Q_K_sample, axis=-1), y=L)
         m_top = tf.math.top_k(M, top_n, sorted=False)[1]
-        m_top = m_top[tf.newaxis, tf.newaxis] if B == 1 else m_top
 
         batch_indexes = tf.tile(tf.range(B)[:, tf.newaxis, tf.newaxis], (1, H, top_n))
         head_indexes = tf.tile(tf.range(H)[tf.newaxis, :, tf.newaxis], (B, 1, top_n))
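
Why the axis matters, as the diff reads: the matmul of `tf.expand_dims(q, -2)` with the transposed `K_sample` yields scores of shape `(batch, heads, length, 1, sample_k)`, and a bare `tf.squeeze` drops every size-1 dimension it finds, so the rank of `Q_K_sample` depended on the runtime batch and head sizes; the old `B == 1` re-expansion of `m_top` papered over one such case. Squeezing only `axis=3` removes just the singleton introduced by `expand_dims`, keeping a rank-4 result for any batch size. A standalone sketch with assumed shapes (not the tfts code itself):

```python
import tensorflow as tf

def sampled_scores(q, k_sample, squeeze_axis=None):
    # q: (batch, heads, length, dim); k_sample: (batch, heads, length, sample_k, dim)
    scores = tf.matmul(tf.expand_dims(q, -2), tf.einsum("...ij->...ji", k_sample))
    # scores: (batch, heads, length, 1, sample_k)
    return tf.squeeze(scores) if squeeze_axis is None else tf.squeeze(scores, axis=squeeze_axis)

B, H, L_q, D, sample_k = 1, 4, 32, 16, 8  # assumed sizes for illustration
q = tf.random.normal((B, H, L_q, D))
k_sample = tf.random.normal((B, H, L_q, sample_k, D))

print(sampled_scores(q, k_sample).shape)                  # (4, 32, 8): batch axis silently dropped when B == 1
print(sampled_scores(q, k_sample, squeeze_axis=3).shape)  # (1, 4, 32, 8): rank 4 for any batch size
```

With the axis pinned, `M` and the subsequent `top_k` and gather indexing always see a `(batch, heads, length)` layout, so the `B == 1` special case for `m_top` can be dropped.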
