-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_embeddings.py
117 lines (106 loc) · 3.95 KB
/
train_embeddings.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import pickle
from os import makedirs
from torch.nn import Embedding
from pykeen.models import CompGCN
from pykeen.triples import TriplesFactory
from pykeen.pipeline import pipeline
# PARAMETERS
# Dataset name; selects the input/output directory under results/<dataset>/.
dataset = "movielens"
# Dimensionality of the entity/relation embeddings.
emb_dim = 100
# Number of CompGCN message-passing layers.
n_layers = 2
# Training epochs for the pykeen pipeline.
epochs = 15
# None if not using wiki2vec embeddings
wiki2vec_embeddings_file = "wiki2vec_embeddings.pkl"
# All artifacts for this run are written here.
output_path = f"results/{dataset}/{emb_dim}"
makedirs(output_path, exist_ok=True)
# Class to translate the dataset ids to pykeen ids
class TranslateId:
    """Translates original dataset entity ids into PyKEEN-internal ids.

    The translation is a two-step lookup:
    original id -> entity label (read from a TSV mapping file)
    -> PyKEEN id (from a ``TriplesFactory.entity_to_id`` dict).
    """

    def __init__(self, id2lab, lab2id):
        """
        :param id2lab: path to a TSV file whose rows are ``<id>\\t<label>``.
        :param lab2id: dict mapping entity label -> PyKEEN id.
        :raises ValueError: if a row of ``id2lab`` does not contain exactly
            two tab-separated fields.
        """
        # original id -> label
        self.id2lab_dict = {}
        with open(id2lab) as f:
            for line in f:
                # `ent_id` instead of `id` to avoid shadowing the builtin.
                ent_id, label = line.strip().split('\t')
                self.id2lab_dict[int(ent_id)] = label
        # label -> pykeen id
        self.lab2id_dict = lab2id

    def __call__(self, ids):
        """Translate a list of original ids to PyKEEN ids.

        :param ids: iterable of original (integer) dataset ids.
        :return: list of PyKEEN ids, in the same order.
        :raises KeyError: if an id or its label is unknown.
        """
        # translation: original id -> label -> pykeen id
        return [self.lab2id_dict[self.id2lab_dict[ent_id]] for ent_id in ids]
class Wiki2VecCompGCN(CompGCN):
    """CompGCN whose initial entity-embedding table is partially filled from
    pre-trained wiki2vec vectors; entities without a wiki2vec vector keep a
    random initialization.
    """

    def __init__(self, wiki2vec_emb, **kwargs):
        """
        :param wiki2vec_emb: mapping original dataset id -> embedding tensor
            (assumed 1-D of length ``embedding_dim`` — TODO confirm against
            the wiki2vec export).
        :param kwargs: forwarded unchanged to ``CompGCN.__init__``.
        """
        super().__init__(**kwargs)
        # NOTE(review): reaches into pykeen private internals (``_embeddings``);
        # this is version-sensitive — verify against the pinned pykeen release.
        num_embeddings = self.entity_representations[0].combined.entity_representations._embeddings.num_embeddings
        embedding_dim = self.entity_representations[0].combined.entity_representations._embeddings.embedding_dim
        # Start from random vectors, then overwrite the rows we have data for.
        emb = torch.randn((num_embeddings, embedding_dim))
        print(f"Total number of entities: {num_embeddings}")
        coverage = 0
        for id in wiki2vec_emb:
            try:
                # NOTE(review): relies on the module-level ``translate_id``
                # global, which is defined later in this script — the class
                # can only be instantiated after that assignment runs.
                emb[translate_id([id])[0],:] = wiki2vec_emb[id]
                coverage += 1
            except KeyError:
                # Item's id not present in the training data (hence in the pykeen id list)
                ...
        print(f"Number of covered entities: {coverage}")
        # Swap in the pre-loaded table; freeze=False keeps it trainable.
        self.entity_representations[0].combined.entity_representations._embeddings = Embedding.from_pretrained(emb, freeze=False)
# Training triples; inverse triples are added (presumably required by the
# CompGCN encoder — confirm against the pykeen docs).
emb_training = TriplesFactory.from_path(
    f"results/{dataset}/pykeen_train.tsv",
    create_inverse_triples=True)
# Test triples reuse the training id mappings so ids stay consistent.
emb_testing = TriplesFactory.from_path(
    f"results/{dataset}/pykeen_test.tsv",
    entity_to_id=emb_training.entity_to_id,
    relation_to_id=emb_training.relation_to_id,
    create_inverse_triples=True)
# Module-level translator used by Wiki2VecCompGCN:
# original dataset id -> label -> pykeen id.
translate_id = TranslateId(f"results/{dataset}/mapping_entities.tsv",
                           emb_training.entity_to_id)
# Train either the wiki2vec-initialized variant or a plain CompGCN,
# depending on whether a wiki2vec embeddings file was configured.
if wiki2vec_embeddings_file:
    # Prepare wiki2vec pre-trained embeddings
    # NOTE(review): torch.load unpickles the file — only load trusted files.
    wiki2vec_emb = torch.load(wiki2vec_embeddings_file)
    result = pipeline(
        training=emb_training,
        testing=emb_testing,
        model=Wiki2VecCompGCN,
        model_kwargs=dict(wiki2vec_emb=wiki2vec_emb,
                          embedding_dim=emb_dim,
                          encoder_kwargs={"num_layers": n_layers}),
        evaluation_fallback = True,
        random_seed=4316,
        training_kwargs=dict(
            num_epochs=epochs,
        )
    )
else:
    # Baseline: stock pykeen CompGCN with random initialization.
    result = pipeline(
        training=emb_training,
        testing=emb_testing,
        model="CompGCN",
        model_kwargs=dict(embedding_dim=emb_dim, encoder_kwargs={"num_layers": n_layers}),
        evaluation_fallback = True,
        random_seed=4316,
        training_kwargs=dict(
            num_epochs=epochs,
        ),
    )
# Trained model produced by the pipeline run above.
model = result.model
def get_all_embeddings(model):
    """Collect the entity and relation embeddings emitted by every CompGCN
    layer of *model*.

    :param model: a trained pykeen CompGCN model.
    :return: tuple ``(entities, relations)`` — two lists holding one tensor
        per layer, in layer order.
    """
    combined = model.entity_representations[0].combined
    # Layer-0 inputs: the base representations, detached from the autograd graph.
    ent = combined.entity_representations().detach()
    rel = combined.relation_representations().detach()
    per_layer_ent, per_layer_rel = [], []
    for gcn_layer in combined.layers:
        # Each layer consumes the previous layer's output.
        ent, rel = gcn_layer(ent, rel, combined.edge_index, combined.edge_type)
        per_layer_ent.append(ent)
        per_layer_rel.append(rel)
    return per_layer_ent, per_layer_rel
# Per-layer embeddings from the trained model.
all_emb_e, all_emb_r = get_all_embeddings(model)
# Experiment tag encodes layer count and whether wiki2vec init was used.
exp_name = f"{n_layers}_wiki2vec" if wiki2vec_embeddings_file else f"{n_layers}"
torch.save(model, f"{output_path}/{dataset}_model_{exp_name}.pt")
torch.save(all_emb_e, f"{output_path}/{dataset}_embeddings_{exp_name}.pkl")
# save entity2id
with open(f"results/{dataset}/{dataset}_ent2id.pkl", "wb") as f:
    pickle.dump(emb_training.entity_to_id, f)