Skip to content

Commit 5ab2397

Browse files
committed
fixed the bug pointed out in #2
1 parent 059c6a6 commit 5ab2397

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

model.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ def forward(self, queries, keys, time_mask, attn_mask, time_matrix_K, time_matri
6363
# key masking, -2^32 lead to leaking, inf lead to nan
6464
# 0 * inf = nan, then reduce_sum([nan,...]) = nan
6565

66-
time_mask = time_mask.unsqueeze(-1).expand(attn_weights.shape[0], -1, attn_weights.shape[-1])
66+
# fixed a bug pointed out in https://github.com/pmixer/TiSASRec.pytorch/issues/2
67+
# time_mask = time_mask.unsqueeze(-1).expand(attn_weights.shape[0], -1, attn_weights.shape[-1])
68+
time_mask = time_mask.unsqueeze(-1).repeat(self.head_num, 1, 1)
69+
time_mask = time_mask.expand(-1, -1, attn_weights.shape[-1])
6770
attn_mask = attn_mask.unsqueeze(0).expand(attn_weights.shape[0], -1, -1)
6871
paddings = torch.ones(attn_weights.shape) * (-2**32+1) # -1e23 # float('-inf')
6972
paddings = paddings.to(self.dev)

0 commit comments

Comments
 (0)