Commit aded2b0 (1 parent: 49ff9b5)

Clean up multi-processing logic -- Switch to PyTorch 0.4 style
5 files changed: +44 −53 lines

README.md

+1 −1

@@ -58,6 +58,6 @@ Finetuning the PyTorch model for 3 Epochs on ROCStories takes 10 minutes to run
 
 The single run test accuracy of this PyTorch version is 85.84%, while the authors reports a median accuracy with the TensorFlow code of 85.8% and the paper reports a best single run accuracy of 86.5%.
 
-The authors implementations uses 8 GPU and can thus accomodate a batch of 64 samples while the present implementation is single GPU and is in consequence limited to 20 instances on a K80 for memory reasons. In our test, increasing the batch size from 8 to 20 samples increased the test accuracy by 2.5 points. A better accuracy may be obtained by using a multi-GPU setting (on the TO-DO list).
+The authors implementations uses 8 GPU and can thus accomodate a batch of 64 samples while the present implementation is single GPU and is in consequence limited to 20 instances on a K80 for memory reasons. In our test, increasing the batch size from 8 to 20 samples increased the test accuracy by 2.5 points. A better accuracy may be obtained by using a multi-GPU setting (not tried yet).
 
 The previous SOTA on the ROCStories dataset is 77.6% ("Hidden Coherence Model" of Chaturvedi et al. published in "Story Comprehension for Predicting What Happens Next" EMNLP 2017, which is a very nice paper too!)

datasets.py

+1 −1

@@ -10,7 +10,7 @@
 seed = 3535999445
 
 def _rocstories(path):
-    with open(path) as f:
+    with open(path, encoding='utf_8') as f:
         f = csv.reader(f)
         st = []
         ct1 = []

model_pytorch.py

+19 −19

@@ -146,11 +146,11 @@ def forward(self, x):
         return h
 
 
-class Model(nn.Module):
+class TransformerModel(nn.Module):
     """ Transformer model """
 
     def __init__(self, cfg, vocab=40990, n_ctx=512):
-        super(Model, self).__init__()
+        super(TransformerModel, self).__init__()
         self.vocab = vocab
         self.embed = nn.Embedding(vocab, cfg.n_embd)
         self.drop = nn.Dropout(cfg.embd_pdrop)

@@ -181,7 +181,7 @@ def __init__(self, model, cfg):
 
     def forward(self, h):
         # Truncated Language modeling logits (we remove the last token)
-        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)  # Shape: 252, 768
+        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
         lm_logits = self.decoder(h_trunc)
         return lm_logits
 

@@ -202,24 +202,27 @@ def forward(self, h, x):
         # Classification logits
         clf_h = h.view(-1, self.n_embd)
         flat = x[:, :, :, 0].contiguous().view(-1)
-        # pool_idx = torch.eq(x[:, :, 0].contiguous().view(-1), self.clf_token)
-        clf_h = clf_h[flat == self.clf_token, :]  # .index_select(0, pool_idx)
-        clf_h = clf_h.view(-1, 2, self.n_embd, 1)
+        clf_h = clf_h[flat == self.clf_token, :]
+        clf_h = clf_h.view(-1, x.size(1), self.n_embd, 1)
         clf_h = self.dropout(clf_h)
         clf_h = clf_h.view(-1, self.n_embd)
         clf_logits = self.linear(clf_h)
-        return clf_logits.view(-1, 2)
+        return clf_logits.view(-1, x.size(1))
 
 
-class DataParallelWithEmbed(torch.nn.DataParallel):
-    """DataParallel that proxies the embed property to the wrapped module"""
+class DoubleHeadModel(nn.Module):
+    """ Transformer with language model and classification heads """
+    def __init__(self, cfg, clf_token, vocab=40990, n_ctx=512):
+        super(DoubleHeadModel, self).__init__()
+        self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
+        self.lm_head = LMHead(self.transformer, cfg)
+        self.clf_head = ClfHead(clf_token, cfg)
 
-    def __init__(self, model):
-        super(DataParallelWithEmbed, self).__init__(model)
-
-    @property
-    def embed(self):
-        return self.module.embed
+    def forward(self, x):
+        h = self.transformer(x)
+        lm_logits = self.lm_head(h)
+        clf_logits = self.clf_head(h, x)
+        return lm_logits, clf_logits
 
 
 def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/',

@@ -260,15 +263,12 @@ def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n
 
     model.embed.weight.data = torch.from_numpy(init_params[0])
 
-    # Load the weights into our torch module
-    module = model.module
-
     for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
         name = name[6:]  # skip "model/"
         assert name[-2:] == ":0"
         name = name[:-2]
         name = name.split('/')
-        pointer = module
+        pointer = model
         for m_name in name:
             if re.fullmatch(r'[A-Za-z]+\d+', m_name):
                 l = re.split(r'(\d+)', m_name)
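
The structural change above is that the language-model head and the classification head now live inside a single nn.Module, so one forward pass returns both sets of logits and a plain nn.DataParallel can wrap the whole thing without the custom DataParallelWithEmbed proxy. A minimal, self-contained sketch of that pattern follows; it is a toy stand-in, not the repository's actual Transformer (the GRU body, sizes, and names are illustrative only):

import torch
import torch.nn as nn


class ToyDoubleHead(nn.Module):
    """Toy mirror of DoubleHeadModel: one shared body, two heads, one forward pass."""

    def __init__(self, vocab=100, n_embd=16):
        super(ToyDoubleHead, self).__init__()
        self.embed = nn.Embedding(vocab, n_embd)
        self.body = nn.GRU(n_embd, n_embd, batch_first=True)  # stands in for the Transformer blocks
        self.lm_head = nn.Linear(n_embd, vocab, bias=False)   # language-modeling head
        self.clf_head = nn.Linear(n_embd, 1)                  # classification head (one score per choice)

    def forward(self, x):
        # x: [batch, n_choices, seq_len] token ids (the real model also packs position ids)
        b, c, t = x.size()
        h, _ = self.body(self.embed(x.view(b * c, t)))
        lm_logits = self.lm_head(h[:, :-1])                   # next-token logits, last position dropped
        clf_logits = self.clf_head(h[:, -1]).view(b, c)       # one score per choice, like x.size(1) above
        return lm_logits, clf_logits


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ToyDoubleHead().to(device)
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)                            # both heads run inside each replica

tokens = torch.randint(0, 100, (4, 2, 12), dtype=torch.long, device=device)
lm_logits, clf_logits = model(tokens)
print(lm_logits.shape, clf_logits.shape)                      # torch.Size([8, 11, 100]) torch.Size([4, 2])

In the repository itself, DoubleHeadModel(args, clf_token, vocab, n_ctx) plays this role, and train.py unpacks lm_logits, clf_logits = dh_model(XMB) in exactly the same way.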

text_utils.py

+1 −1

@@ -41,7 +41,7 @@ def __init__(self, encoder_path, bpe_path):
         self.nlp = spacy.load('en', disable=['parser', 'tagger', 'ner', 'textcat'])
         self.encoder = json.load(open(encoder_path))
         self.decoder = {v:k for k,v in self.encoder.items()}
-        merges = open(bpe_path).read().split('\n')[1:-1]
+        merges = open(bpe_path, encoding='utf-8').read().split('\n')[1:-1]
         merges = [tuple(merge.split()) for merge in merges]
         self.bpe_ranks = dict(zip(merges, range(len(merges))))
         self.cache = {}
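
The one-line changes in datasets.py and text_utils.py address the same issue: without an explicit encoding argument, open() falls back to the platform's default codec (e.g. cp1252 on Windows), and UTF-8 data files can then raise UnicodeDecodeError or decode to the wrong characters. A small self-contained illustration (the temporary file and its contents are made up for the example):

import csv
import tempfile

# Write a small UTF-8 CSV with non-ASCII characters, then read it back with an explicit codec.
with tempfile.NamedTemporaryFile("w", encoding="utf_8", suffix=".csv", delete=False) as tmp:
    tmp.write("story,ending\ncafé au lait,naïve\n")
    path = tmp.name

with open(path, encoding="utf_8") as f:   # explicit codec, independent of the OS locale
    rows = list(csv.reader(f))
print(rows)                               # [['story', 'ending'], ['café au lait', 'naïve']]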

train.py

+22 −31

@@ -10,7 +10,7 @@
 
 from analysis import rocstories as rocstories_analysis
 from datasets import rocstories
-from model_pytorch import Model, LMHead, ClfHead, load_openai_pretrained_model, DataParallelWithEmbed
+from model_pytorch import DoubleHeadModel, load_openai_pretrained_model
 from opt import OpenAIAdam
 from text_utils import TextEncoder
 from utils import (encode_dataset, iter_data,

@@ -75,14 +75,13 @@ def iter_apply(Xs, Ms, Ys):
     logits = []
     cost = 0
     with torch.no_grad():
-        model.eval()
+        dh_model.eval()
         for xmb, mmb, ymb in iter_data(Xs, Ms, Ys, n_batch=n_batch_train, truncate=False, verbose=True):
            n = len(xmb)
            XMB = torch.tensor(xmb, dtype=torch.long).to(device)
            YMB = torch.tensor(ymb, dtype=torch.long).to(device)
            MMB = torch.tensor(mmb).to(device)
-            h = model(XMB)
-            clf_logits = clf_head(h, XMB)
+            _, clf_logits = dh_model(XMB)
            clf_logits *= n
            clf_losses = compute_loss_fct(XMB, YMB, MMB, clf_logits, only_return_losses=True)
            clf_losses *= n

@@ -95,13 +94,12 @@ def iter_apply(Xs, Ms, Ys):
 def iter_predict(Xs, Ms):
     logits = []
     with torch.no_grad():
-        model.eval()
+        dh_model.eval()
         for xmb, mmb in iter_data(Xs, Ms, n_batch=n_batch_train, truncate=False, verbose=True):
            n = len(xmb)
            XMB = torch.tensor(xmb, dtype=torch.long).to(device)
            MMB = torch.tensor(mmb).to(device)
-            h = model(XMB)
-            clf_logits = clf_head(h, XMB)
+            _, clf_logits = dh_model(XMB)
            logits.append(clf_logits.to("cpu").numpy())
     logits = np.concatenate(logits, 0)
     return logits

@@ -123,7 +121,7 @@ def log(save_dir, desc):
     if score > best_score:
         best_score = score
         path = os.path.join(save_dir, desc, 'best_params')
-        torch.save(model.state_dict(), make_path(path))
+        torch.save(dh_model.state_dict(), make_path(path))
 
 
 def predict(dataset, submission_dir):

@@ -145,13 +143,11 @@ def run_epoch():
     for xmb, mmb, ymb in iter_data(*shuffle(trX, trM, trYt, random_state=np.random),
                                    n_batch=n_batch_train, truncate=True, verbose=True):
         global n_updates
-        model.train()
+        dh_model.train()
         XMB = torch.tensor(xmb, dtype=torch.long).to(device)
         YMB = torch.tensor(ymb, dtype=torch.long).to(device)
         MMB = torch.tensor(mmb).to(device)
-        h = model(XMB)
-        lm_logits = lm_head(h)
-        clf_logits = clf_head(h, XMB)
+        lm_logits, clf_logits = dh_model(XMB)
         compute_loss_fct(XMB, YMB, MMB, clf_logits, lm_logits)
         n_updates += 1
         if n_updates in [1000, 2000, 4000, 8000, 16000, 32000] and n_epochs == 0:

@@ -198,7 +194,7 @@ def run_epoch():
     parser.add_argument('--clf_pdrop', type=float, default=0.1)
     parser.add_argument('--l2', type=float, default=0.01)
     parser.add_argument('--vector_l2', action='store_true')
-    parser.add_argument('--n_gpu', type=int, default=1)  # 4)  # TODO add mutli-gpu training logic
+    parser.add_argument('--n_gpu', type=int, default=1)
     parser.add_argument('--opt', type=str, default='adam')
     parser.add_argument('--afn', type=str, default='gelu')
     parser.add_argument('--lr_schedule', type=str, default='warmup_linear')

@@ -213,7 +209,6 @@ def run_epoch():
 
     args = parser.parse_args()
     print(args)
-    # globals().update(args.__dict__)  # TODO maybe we want to remove these gobal variables to make it cleaner
 
     random.seed(args.seed)
     np.random.seed(args.seed)

@@ -238,9 +233,10 @@ def run_epoch():
     n_vocab = len(text_encoder.encoder)
 
     print("Encoding dataset...")
-    (trX1, trX2, trX3, trY), (vaX1, vaX2, vaX3, vaY), (teX1, teX2, teX3) = encode_dataset(
-        rocstories(data_dir, n_valid=args.n_valid), encoder=text_encoder)
-    n_y = 2
+    ((trX1, trX2, trX3, trY),
+     (vaX1, vaX2, vaX3, vaY),
+     (teX1, teX2, teX3)) = encode_dataset(rocstories(data_dir, n_valid=args.n_valid),
+                                          encoder=text_encoder)
     encoder['_start_'] = len(encoder)
     encoder['_delimiter_'] = len(encoder)
     encoder['_classify_'] = len(encoder)

@@ -254,7 +250,7 @@ def run_epoch():
                                  len(x3[:max_len])) for x1, x2, x3 in zip(vaX1, vaX2, vaX3)]
         + [len(x1[:max_len]) + max(len(x2[:max_len]),
                                    len(x3[:max_len])) for x1, x2, x3 in zip(teX1, teX2, teX3)]
-                     ) + 3, n_ctx)
+        ) + 3, n_ctx)
     vocab = n_vocab + n_special + n_ctx
     trX, trM = transform_roc(trX1, trX2, trX3)
     vaX, vaM = transform_roc(vaX1, vaX2, vaX3)

@@ -266,14 +262,10 @@ def run_epoch():
     n_batch_train = args.n_batch * args.n_gpu
     n_updates_total = (n_train // n_batch_train) * args.n_iter
 
-    model = Model(args, vocab, n_ctx)
-    model = DataParallelWithEmbed(model).cuda()
+    dh_model = DoubleHeadModel(args, clf_token, vocab, n_ctx)
 
-    lm_head = LMHead(model, args)
-    clf_head = ClfHead(clf_token, args)
-
-    criterion = nn.CrossEntropyLoss(reduce=False)  # TODO check loss functions
-    model_opt = OpenAIAdam(list(model.parameters()) + list(clf_head.parameters()) + list(lm_head.parameters()),
+    criterion = nn.CrossEntropyLoss(reduce=False)
+    model_opt = OpenAIAdam(dh_model.parameters(),
                            lr=args.lr,
                            schedule=args.lr_schedule,
                            warmup=args.lr_warmup,

@@ -288,19 +280,18 @@ def run_epoch():
                                    criterion,
                                    args.lm_coef,
                                    model_opt)
-    load_openai_pretrained_model(model, n_ctx=n_ctx, n_special=n_special)
+    load_openai_pretrained_model(dh_model.transformer, n_ctx=n_ctx, n_special=n_special)
 
-    model.to(device)
-    lm_head.to(device)
-    clf_head.to(device)
+    dh_model.to(device)
+    dh_model = nn.DataParallel(dh_model)
 
     n_updates = 0
     n_epochs = 0
     if dataset != 'stsb':
         trYt = trY
     if submit:
         path = os.path.join(save_dir, desc, 'best_params')
-        torch.save(model.state_dict(), make_path(path))
+        torch.save(dh_model.state_dict(), make_path(path))
     best_score = 0
     for i in range(args.n_iter):
         print("running epoch", i)

@@ -309,7 +300,7 @@ def run_epoch():
         log(save_dir, desc)
     if submit:
         path = os.path.join(save_dir, desc, 'best_params')
-        model.load_state_dict(torch.load(path))
+        dh_model.load_state_dict(torch.load(path))
         predict(dataset, args.submission_dir)
         if args.analysis:
             rocstories_analysis(data_dir, os.path.join(args.submission_dir, 'ROCStories.tsv'),
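
These train.py changes are what the commit title refers to: the pre-0.4 pattern of explicit .cuda() calls, separate head modules, and the custom DataParallelWithEmbed wrapper is replaced by the 0.4 idioms used above — a single torch.device, .to(device) on the module and on tensors built with torch.tensor(...), torch.no_grad() around evaluation, and a plain nn.DataParallel wrapper. A minimal stand-alone sketch of that pattern (nn.Linear is only a placeholder for dh_model):

import numpy as np
import torch
import torch.nn as nn

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = nn.Linear(8, 2)                # placeholder for DoubleHeadModel
model.to(device)                       # move parameters to the chosen device
model = nn.DataParallel(model)         # splits the batch across GPUs; passes through on CPU

xmb = np.random.randn(4, 8).astype(np.float32)
XMB = torch.tensor(xmb).to(device)     # replaces Variable(torch.from_numpy(xmb)).cuda()

with torch.no_grad():                  # replaces the old volatile=True evaluation mode
    model.eval()
    logits = model(XMB)
print(logits.shape)                    # torch.Size([4, 2])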
