
Commit 54aa68f

validated, runnable on gpu, loss, NDCG and hit rate looks okay

Parent: 442bdf4
3 files changed: +9, -6 lines

main.py (+1, -1)

@@ -67,7 +67,7 @@ def str2bool(s):
         print('failed loading state_dicts, pls check file path: ', end="")
         print(args.state_dict_path)
 
-    if args.inference_only or True:
+    if args.inference_only:
         model.eval()
         t_test = evaluate(model, dataset, args)
         print('test (NDCG@10: %.4f, HR@10: %.4f)' % (t_test[0], t_test[1]))
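Dropping the stray "or True" removes a debug override that forced inference-only mode on every run; evaluation is again gated by the --inference_only flag alone. A minimal sketch of the flag pattern this relies on (str2bool appears in the hunk header; the parser wiring below is illustrative, not necessarily main.py's exact code):

import argparse

def str2bool(s):
    # argparse hands flag values over as strings; map "true"/"false" to bool
    if s not in {'false', 'true'}:
        raise ValueError('not a valid boolean string')
    return s == 'true'

parser = argparse.ArgumentParser()
parser.add_argument('--inference_only', default=False, type=str2bool)

args = parser.parse_args(['--inference_only', 'true'])
assert args.inference_only is True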

model.py (+6, -3)

@@ -23,7 +23,7 @@ def forward(self, inputs):
 
 class TimeAwareMultiHeadAttention(torch.nn.Module):
     # required homebrewed mha layer for Ti/SASRec experiments
-    def __init__(self, hidden_size, head_num, dropout_rate):
+    def __init__(self, hidden_size, head_num, dropout_rate, dev):
         super(TimeAwareMultiHeadAttention, self).__init__()
         self.Q_w = torch.nn.Linear(hidden_size, hidden_size)
         self.K_w = torch.nn.Linear(hidden_size, hidden_size)
@@ -36,6 +36,7 @@ def __init__(self, hidden_size, head_num, dropout_rate):
         self.head_num = head_num
         self.head_size = hidden_size // head_num
         self.dropout_rate = dropout_rate
+        self.dev = dev
 
     def forward(self, queries, keys, time_mask, attn_mask, time_matrix_K, time_matrix_V, abs_pos_K, abs_pos_V):
         Q, K, V = self.Q_w(queries), self.K_w(keys), self.V_w(keys)
@@ -63,7 +64,8 @@ def forward(self, queries, keys, time_mask, attn_mask, time_matrix_K, time_matri
 
         time_mask = time_mask.unsqueeze(-1).expand(attn_weights.shape[0], -1, attn_weights.shape[-1])
         attn_mask = attn_mask.unsqueeze(0).expand(attn_weights.shape[0], -1, -1)
-        paddings = torch.ones(attn_weights.shape) * FLOAT_MIN # float('-inf')
+        paddings = torch.ones(attn_weights.shape) * -1e23 # float('-inf')
+        paddings = paddings.to(self.dev)
         attn_weights = torch.where(time_mask, paddings, attn_weights) # True:pick padding
         attn_weights = torch.where(attn_mask, paddings, attn_weights) # enforcing causality
 
@@ -119,7 +121,8 @@ def __init__(self, user_num, item_num, time_num, args):
 
             new_attn_layer = TimeAwareMultiHeadAttention(args.hidden_units,
                                                          args.num_heads,
-                                                         args.dropout_rate)
+                                                         args.dropout_rate,
+                                                         args.device)
             self.attention_layers.append(new_attn_layer)
 
             new_fwd_layernorm = torch.nn.LayerNorm(args.hidden_units, eps=1e-8)
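This is the substantive GPU fix: torch.ones() allocates on the CPU by default, so once the model runs on CUDA, torch.where() would mix a CPU paddings tensor with GPU attention weights and raise a device-mismatch RuntimeError. The constructor now threads the target device through so the mask fill can be moved onto it. A self-contained sketch of the pattern (shapes and names are illustrative, not the model's):

import torch

dev = 'cuda' if torch.cuda.is_available() else 'cpu'

# toy attention logits and a causal (upper-triangular) mask on the target device
attn_weights = torch.randn(2, 4, 4, device=dev)
attn_mask = torch.triu(torch.ones(4, 4, dtype=torch.bool, device=dev), diagonal=1)
attn_mask = attn_mask.unsqueeze(0).expand(attn_weights.shape[0], -1, -1)

paddings = torch.ones(attn_weights.shape) * -1e23  # created on CPU by default
paddings = paddings.to(dev)                        # the fix: match attn_weights' device

attn_weights = torch.where(attn_mask, paddings, attn_weights)  # True: take padding
print(attn_weights.softmax(dim=-1)[0])  # masked positions get ~zero attention

The fill value also moves from FLOAT_MIN to a finite -1e23, which acts like float('-inf') under softmax but leaves a fully masked row uniform instead of NaN; that reading of the change is an inference, not stated in the commit.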

utils.py (+2, -2)

@@ -241,7 +241,7 @@ def evaluate(model, dataset, args):
         predictions = -model.predict(*[np.array(l) for l in [[u], [seq], [time_matrix],item_idx]])
         predictions = predictions[0]
 
-        rank = predictions.argsort().argsort()[0]
+        rank = predictions.argsort().argsort()[0].item()
 
         valid_user += 1
 
@@ -290,7 +290,7 @@ def evaluate_valid(model, dataset, args):
         predictions = -model.predict(*[np.array(l) for l in [[u], [seq], [time_matrix],item_idx]])
         predictions = predictions[0]
 
-        rank = predictions.argsort().argsort()[0]
+        rank = predictions.argsort().argsort()[0].item()
 
         valid_user += 1
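The double argsort maps each score to its rank, so index 0 (where the ground-truth item sits in item_idx) yields the target's rank among all candidates. Appending .item() unwraps the 0-dim tensor, which may now live on the GPU, into a plain Python int before the downstream cutoff check and NDCG arithmetic. A small sketch of the trick, with the metric update as typically done in SASRec-style evaluate() (the repo's exact accumulation may differ):

import numpy as np
import torch

scores = torch.tensor([0.9, 0.1, 0.5, 0.7])      # ground-truth item at index 0
predictions = -scores                             # negate so ascending sort = descending score
rank = predictions.argsort().argsort()[0].item()  # 0: the target has the highest score

NDCG, HT = 0.0, 0.0
if rank < 10:                    # top-10 cutoff
    NDCG += 1 / np.log2(rank + 2)
    HT += 1
print(rank, NDCG, HT)            # 0 1.0 1.0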
