senteval.py
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#
from __future__ import absolute_import, division, unicode_literals
import sys
import logging
import argparse
import torch
import torchtext
from SNLIClassifier import SNLIClassifier
DEFAULT_SENTEVAL_PATH = '../SentEval'
DEFAULT_DATA_PATH = '../SentEval/data'
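# SentEval (https://github.com/facebookresearch/SentEval) is assumed to be cloned next to this
# repository, with its evaluation data downloaded into DEFAULT_DATA_PATH.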
# import SentEval
sys.path.insert(0, DEFAULT_SENTEVAL_PATH)
import senteval
# SentEval prepare function
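# SentEval calls prepare(params, samples) once per task with all of that task's sentences;
# the samples are not used here: we only load the trained encoder and the GloVe vectors
# and store them in params for the batcher to use.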
def prepare(params, samples):
    # Set device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Load the vocabulary (map_location lets a GPU-trained checkpoint load on a CPU-only machine)
    checkpoint = torch.load(args.checkpoint_path, map_location=device)
    vocab = checkpoint['text_vocab']
    vocab_size = len(vocab)
    # Load the vectors and set the unknown vector to the mean of all GloVe vectors
    glove_vectors = torchtext.vocab.Vectors(name=args.vector_file)
    unk_vector = torch.mean(glove_vectors.vectors, dim=0)
    # Define the model and load the trained weights
    model = SNLIClassifier(encoder=args.model_type,
                           vocab_size=vocab_size,
                           embedding_dim=300,
                           hidden_dim=2048,
                           fc_dim=512,
                           num_classes=3,
                           pretrained_vectors=None).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()  # evaluation mode (disables dropout, if any)
    # Set params
    params['device'] = device
    params['encoder'] = model.encoder
    params['vectors'] = glove_vectors
    params['unk_vector'] = unk_vector
# SentEval batcher function
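# For each mini-batch of tokenized sentences, batcher(params, batch) must return a numpy array
# with one fixed-size sentence embedding per row. Here each token is mapped to its GloVe vector
# and the resulting sequence is encoded with the trained sentence encoder.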
def batcher(params, batch):
    # Replace empty sentences with a single period so every sentence has at least one token
    batch = [sent if sent != [] else ['.'] for sent in batch]
    batch_size = len(batch)
    sent_lengths = [len(sentence) for sentence in batch]
    sent_lengths = torch.LongTensor(sent_lengths).to(params['device'])
    longest_length = torch.max(sent_lengths)
    # Word embeddings of shape (max_sent_len, batch_size, embedding_dim); padded positions keep the fill value
    word_embeddings = torch.ones((longest_length, batch_size, params['vectors'].dim)).to(params['device'])
    for sent_id, sent in enumerate(batch):
        for word_id, word in enumerate(sent):
            # SentEval may deliver tokens as str or bytes
            if isinstance(word, bytes):
                word = word.decode('UTF-8')
            if word in params['vectors'].stoi:
                word_embeddings[word_id, sent_id, :] = params['vectors'].vectors[params['vectors'].stoi[word]]
            else:
                # Out-of-vocabulary tokens fall back to the mean GloVe vector
                word_embeddings[word_id, sent_id, :] = params['unk_vector']
    with torch.no_grad():
        sent_embeddings = params['encoder'](word_embeddings, sent_lengths)
    return sent_embeddings.cpu().numpy()
# Set params for SentEval
params_senteval = {'task_path': DEFAULT_DATA_PATH, 'usepytorch': True, 'kfold': 10}
params_senteval['classifier'] = {'nhid': 0, 'optim': 'adam', 'batch_size': 64,
                                 'tenacity': 5, 'epoch_size': 4}
# Set up logger
logging.basicConfig(format='%(asctime)s : %(message)s', level=logging.DEBUG)
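# Example invocation (paths are illustrative):
#   python senteval.py biLSTMmaxpool checkpoints/biLSTMmaxpool.pt glove.840B.300d.txt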
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('model_type', choices={'average', 'uniLSTM', 'biLSTM', 'biLSTMmaxpool'},
                        help='Type of encoder for the sentences')
    parser.add_argument('checkpoint_path', type=str,
                        help='Path to load the model checkpoint')
    parser.add_argument('vector_file', type=str,
                        help='File in which vectors are saved')
    parser.add_argument('--senteval_path', type=str, default=DEFAULT_SENTEVAL_PATH,
                        help='Path to SentEval repository')
    parser.add_argument('--data_path', type=str, default=DEFAULT_DATA_PATH,
                        help='Path to SentEval data')
    args = parser.parse_args()
    # Point SentEval at the requested data directory (DEFAULT_DATA_PATH is only the argparse default)
    params_senteval['task_path'] = args.data_path
    se = senteval.engine.SE(params_senteval, batcher, prepare)
    transfer_tasks = ['MR', 'CR', 'SUBJ', 'MPQA', 'SST2', 'TREC', 'MRPC', 'SICKRelatedness',
                      'SICKEntailment', 'STS14', 'ImageCaptionRetrieval']
    results = se.eval(transfer_tasks)
    print(results)