-
Notifications
You must be signed in to change notification settings - Fork 12
/
Copy pathutilities_mlp.py
70 lines (55 loc) · 2.58 KB
/
utilities_mlp.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import pickle
import gzip
import numpy as np
import torch
import utilities
class MLPDataset(torch.utils.data.Dataset):
    """Dataset of gzip-compressed pickled samples for the MLP model.

    Each sample file holds a dict with (at least) the keys ``'obss'`` and
    ``'max_depth'``. ``__getitem__`` returns, per sample:

    * ``v_feats``: rows of the first observation component ``v`` for the
      candidate entries, passed through ``utilities._preprocess`` (mode
      ``'min-max-2'``);
    * ``sample_action``: position of ``target`` inside the candidate index
      array;
    * ``cand_scores``: the scores of the candidate entries;
    * ``weight``: a per-sample weight derived from ``depth / max_depth``.

    Args:
        sample_files: Sequence of paths to the sample files.
        weighing_scheme: One of ``"linear_decay"``, ``"sigmoidal_decay"``,
            ``"exponential_decay"``, ``"quadratic_decay"`` or ``"constant"``.
            An empty string (or ``None``) is treated as ``"constant"``.

    Raises:
        ValueError: If ``weighing_scheme`` is not a known scheme.
    """

    # Map from scheme name to a function of the depth ratio r.
    # Every decay maps r=0 to 1 and r=1 to exp(-0.5).
    _DECAY_FUNCTIONS = {
        "linear_decay": lambda r: r * (np.exp(-0.5) - 1) + 1,
        "sigmoidal_decay": lambda r: (1 + np.exp(-0.5)) / (1 + np.exp(r - 0.5)),
        "exponential_decay": lambda r: np.exp(r * -0.5),
        "quadratic_decay": lambda r: (np.exp(-0.5) - 1) * r ** 2 + 1,
        "constant": lambda r: 1.0,
    }

    def __init__(self, sample_files, weighing_scheme="sigmoidal_decay"):
        self.sample_files = sample_files
        # An empty (or missing) scheme means "no weighting".
        self.weighing_scheme = weighing_scheme or "constant"
        # Fail fast on a typo instead of on the first __getitem__ call.
        if self.weighing_scheme not in self._DECAY_FUNCTIONS:
            raise ValueError(f"Unknown value for node weights: {self.weighing_scheme}")

    def __len__(self):
        return len(self.sample_files)

    @classmethod
    def _node_weight(cls, depth_ratio, weighing_scheme):
        """Return the sample weight for ``depth_ratio`` under ``weighing_scheme``.

        Raises:
            ValueError: If ``weighing_scheme`` is not a known scheme.
        """
        try:
            decay = cls._DECAY_FUNCTIONS[weighing_scheme]
        except KeyError:
            raise ValueError(f"Unknown value for node weights: {weighing_scheme}") from None
        return decay(depth_ratio)

    def __getitem__(self, index):
        sample_file = self.sample_files[index]
        # NOTE(review): pickle.load can execute arbitrary code; only load
        # sample files produced by a trusted pipeline.
        with gzip.open(sample_file, 'rb') as f:
            sample = pickle.load(f)

        obss, target, obss_feats, _ = sample['obss']
        v, _, _ = obss
        sample_cand_scores = obss_feats['scores']
        # Entries whose score equals -1 are excluded from the candidate set.
        sample_cands = np.where(sample_cand_scores != -1)[0]
        v_feats = utilities._preprocess(v[sample_cands], mode='min-max-2')
        cand_scores = sample_cand_scores[sample_cands]
        # Index of `target` within the candidate index array; raises
        # IndexError if `target` is not one of the candidates.
        sample_action = np.where(sample_cands == target)[0][0]

        # NOTE(review): a falsy max_depth (0/None) yields ratio 1.0, i.e. the
        # "deepest" weight exp(-0.5) for decay schemes -- confirm intended.
        depth_ratio = obss_feats['depth'] / sample['max_depth'] if sample['max_depth'] else 1.0
        weight = self._node_weight(depth_ratio, self.weighing_scheme)
        return v_feats, sample_action, cand_scores, weight
def load_batch(sample_batch):
    """Collate a list of ``(cand_features, action, cand_scores, weight)`` samples.

    Per-sample candidate features and scores are concatenated along axis 0
    into flat tensors. ``n_cands`` records how many rows of those flat
    tensors belong to each sample.

    Args:
        sample_batch: Non-empty sequence of 4-tuples as returned by
            ``MLPDataset.__getitem__``.

    Returns:
        Tuple ``(cand_features, n_cands, best_actions, cand_scores, weights)``
        with dtypes float32, int32, int64 (long), float32 and float32.

    Raises:
        ValueError: If ``sample_batch`` is empty.
    """
    if len(sample_batch) == 0:
        raise ValueError("load_batch received an empty batch")

    cand_featuress, sample_actions, cand_scoress, weights = zip(*sample_batch)
    # Candidate count of each sample (rows contributed to the flat tensors).
    n_cands = [feats.shape[0] for feats in cand_featuress]

    cand_featuress = torch.as_tensor(np.concatenate(cand_featuress, axis=0), dtype=torch.float32)
    cand_scoress = torch.as_tensor(np.concatenate(cand_scoress, axis=0), dtype=torch.float32)
    n_cands = torch.as_tensor(np.array(n_cands), dtype=torch.int32)
    best_actions = torch.as_tensor(np.array(sample_actions), dtype=torch.long)
    weights = torch.as_tensor(np.array(weights), dtype=torch.float32)
    return cand_featuress, n_cands, best_actions, cand_scoress, weights