@@ -23,7 +23,7 @@ def forward(self, inputs):
23
23
24
24
class TimeAwareMultiHeadAttention (torch .nn .Module ):
25
25
# required homebrewed mha layer for Ti/SASRec experiments
26
- def __init__ (self , hidden_size , head_num , dropout_rate ):
26
+ def __init__ (self , hidden_size , head_num , dropout_rate , dev ):
27
27
super (TimeAwareMultiHeadAttention , self ).__init__ ()
28
28
self .Q_w = torch .nn .Linear (hidden_size , hidden_size )
29
29
self .K_w = torch .nn .Linear (hidden_size , hidden_size )
@@ -36,6 +36,7 @@ def __init__(self, hidden_size, head_num, dropout_rate):
36
36
self .head_num = head_num
37
37
self .head_size = hidden_size // head_num
38
38
self .dropout_rate = dropout_rate
39
+ self .dev = dev
39
40
40
41
def forward (self , queries , keys , time_mask , attn_mask , time_matrix_K , time_matrix_V , abs_pos_K , abs_pos_V ):
41
42
Q , K , V = self .Q_w (queries ), self .K_w (keys ), self .V_w (keys )
@@ -63,7 +64,8 @@ def forward(self, queries, keys, time_mask, attn_mask, time_matrix_K, time_matri
63
64
64
65
time_mask = time_mask .unsqueeze (- 1 ).expand (attn_weights .shape [0 ], - 1 , attn_weights .shape [- 1 ])
65
66
attn_mask = attn_mask .unsqueeze (0 ).expand (attn_weights .shape [0 ], - 1 , - 1 )
66
- paddings = torch .ones (attn_weights .shape ) * FLOAT_MIN # float('-inf')
67
+ paddings = torch .ones (attn_weights .shape ) * - 1e23 # float('-inf')
68
+ paddings = paddings .to (self .dev )
67
69
attn_weights = torch .where (time_mask , paddings , attn_weights ) # True:pick padding
68
70
attn_weights = torch .where (attn_mask , paddings , attn_weights ) # enforcing causality
69
71
@@ -119,7 +121,8 @@ def __init__(self, user_num, item_num, time_num, args):
119
121
120
122
new_attn_layer = TimeAwareMultiHeadAttention (args .hidden_units ,
121
123
args .num_heads ,
122
- args .dropout_rate )
124
+ args .dropout_rate ,
125
+ args .device )
123
126
self .attention_layers .append (new_attn_layer )
124
127
125
128
new_fwd_layernorm = torch .nn .LayerNorm (args .hidden_units , eps = 1e-8 )
0 commit comments