
Commit d848a49

Merge pull request #35 from xiaoda99/master
Text generation with pretrained LM model
2 parents eafc28a + d515b8f commit d848a49

File tree

generate.py
model_pytorch.py

2 files changed: +154 -3 lines changed

generate.py

Lines changed: 122 additions & 0 deletions
@@ -0,0 +1,122 @@
import argparse
import os
import random

import numpy as np
import torch
import torch.nn as nn

from model_pytorch import LMModel, load_openai_pretrained_model
from text_utils import TextEncoder


def make_batch(X):
    X = np.array(X)
    assert X.ndim in [1, 2]
    if X.ndim == 1:
        X = np.expand_dims(X, axis=0)
    pos_enc = np.arange(n_vocab + n_special, n_vocab + n_special + X.shape[-1])
    pos_enc = np.expand_dims(pos_enc, axis=0)
    batch = np.stack([X, pos_enc], axis=-1)
    batch = torch.tensor(batch, dtype=torch.long).to(device)
    return batch

def append_batch(X, next_idx):
    next_pos = X[:, -1:, 1] + 1
    next_x = torch.cat((next_idx, next_pos), -1).unsqueeze(1)
    return torch.cat((X, next_x), 1)

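# Note on make_batch/append_batch: each model input step is a (token_id, position_id)
# pair stacked on the last axis. Position ids start at n_vocab + n_special because the
# pretrained transformer keeps its position embeddings in the same embedding matrix as
# the BPE tokens, in the rows that follow the token vocabulary.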
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--desc', type=str, help="Description")
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--log_dir', type=str, default='log/')
    parser.add_argument('--save_dir', type=str, default='save/')
    parser.add_argument('--data_dir', type=str, default='data/')
    parser.add_argument('--submission_dir', type=str, default='submission/')
    parser.add_argument('--submit', action='store_true')
    parser.add_argument('--analysis', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--n_iter', type=int, default=3)
    parser.add_argument('--n_batch', type=int, default=8)
    parser.add_argument('--max_grad_norm', type=int, default=1)
    parser.add_argument('--lr', type=float, default=6.25e-5)
    parser.add_argument('--lr_warmup', type=float, default=0.002)
    parser.add_argument('--n_ctx', type=int, default=512)
    parser.add_argument('--n_embd', type=int, default=768)
    parser.add_argument('--n_head', type=int, default=12)
    parser.add_argument('--n_layer', type=int, default=12)
    parser.add_argument('--embd_pdrop', type=float, default=0.1)
    parser.add_argument('--attn_pdrop', type=float, default=0.1)
    parser.add_argument('--resid_pdrop', type=float, default=0.1)
    parser.add_argument('--clf_pdrop', type=float, default=0.1)
    parser.add_argument('--l2', type=float, default=0.01)
    parser.add_argument('--vector_l2', action='store_true')
    parser.add_argument('--opt', type=str, default='adam')
    parser.add_argument('--afn', type=str, default='gelu')
    parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
    parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json')
    parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe')
    parser.add_argument('--n_transfer', type=int, default=12)
    parser.add_argument('--lm_coef', type=float, default=0.5)
    parser.add_argument('--b1', type=float, default=0.9)
    parser.add_argument('--b2', type=float, default=0.999)
    parser.add_argument('--e', type=float, default=1e-8)
    parser.add_argument('--n_valid', type=int, default=374)
    parser.add_argument('--gen_len', type=int, default=20)
    parser.add_argument('--topk', type=int, default=10)

    args = parser.parse_args()
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Constants
    submit = args.submit
    dataset = args.dataset
    n_ctx = args.n_ctx
    save_dir = args.save_dir
    desc = args.desc
    data_dir = args.data_dir
    log_dir = args.log_dir
    submission_dir = args.submission_dir

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print("device", device, "n_gpu", n_gpu)

    text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
    encoder = text_encoder.encoder
    n_vocab = len(text_encoder.encoder)

    n_special = 0  # XD: useless for language modeling task
    vocab = n_vocab + n_special + n_ctx

    lm_model = LMModel(args, vocab, n_ctx, return_probs=True)
    load_openai_pretrained_model(lm_model.transformer, n_ctx=n_ctx, n_special=n_special)
    lm_model.to(device)

    lm_model.eval()

    text = input('Input some beginning words:')
    while text != 'q':
        X = text_encoder.encode([text,])
        XMB = make_batch(X)

        for _ in range(args.gen_len):
            lm_probs = lm_model(XMB)
            if args.topk == 0:
                next_idx = torch.multinomial(lm_probs[:, -1, :], 1)
            else:
                values, indices = lm_probs[:, -1, :].topk(args.topk)
                next_idx = indices.gather(-1, torch.multinomial(values, 1))
            next_token = text_encoder.decoder[next_idx.item()].replace('</w>', '')
            print(next_token, end=' ')
            XMB = append_batch(XMB, next_idx)

        print()
        text = input('Input some beginning words:')
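The sampling step inside the generation loop packs three tensor operations into two lines. Below is a standalone sketch of the same top-k sampling idea on a dummy distribution; the topk_sample helper and the toy probabilities are illustrative, not part of this commit.

import torch

def topk_sample(probs, k):
    # probs: (batch, vocab) next-token probabilities.
    if k == 0:
        return torch.multinomial(probs, 1)      # sample from the full distribution
    values, indices = probs.topk(k)             # keep the k most probable tokens
    choice = torch.multinomial(values, 1)       # pick one slot among the top k
    return indices.gather(-1, choice)           # map the slot back to a vocabulary id

probs = torch.tensor([[0.05, 0.40, 0.30, 0.20, 0.05]])
print(topk_sample(probs, k=2))                  # tensor([[1]]) or tensor([[2]])

Assuming the pretrained weights and BPE files sit in the default model/ paths, the script would be invoked as something like "python generate.py --gen_len 40 --topk 10"; typing q at the "Input some beginning words:" prompt exits the loop.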

model_pytorch.py

Lines changed: 32 additions & 3 deletions
@@ -84,7 +84,11 @@ def _attn(self, q, k, v):
         w = torch.matmul(q, k)
         if self.scale:
             w = w / math.sqrt(v.size(-1))
-        w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights
+        # w = w * self.b + -1e9 * (1 - self.b) # TF implem method: mask_attn_weights
+        # XD: self.b may be larger than w, so we need to crop it
+        b = self.b[:, :, :w.size(-2), :w.size(-1)]
+        w = w * b + -1e9 * (1 - b)
+
         w = nn.Softmax(dim=-1)(w)
         w = self.attn_dropout(w)
         return torch.matmul(w, v)
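The change above stops assuming that the attention matrix is always n_ctx x n_ctx: during generation the sequence is shorter, so the registered causal mask is cropped to the current size before being applied. A minimal sketch of that masking arithmetic follows; the sizes and the local names b and w are illustrative.

import torch
import torch.nn as nn

n_ctx, seq_len = 512, 5
b = torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)  # causal mask buffer
w = torch.randn(1, 1, seq_len, seq_len)                            # raw attention scores

b = b[:, :, :w.size(-2), :w.size(-1)]   # crop the mask to the actual (seq_len, seq_len)
w = w * b + -1e9 * (1 - b)              # future positions are pushed to roughly -1e9 ...
w = nn.Softmax(dim=-1)(w)               # ... so they receive near-zero attention weight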
@@ -173,16 +177,18 @@ def forward(self, x):
 class LMHead(nn.Module):
     """ Language Model Head for the transformer """

-    def __init__(self, model, cfg):
+    def __init__(self, model, cfg, trunc_and_reshape=True):
         super(LMHead, self).__init__()
         self.n_embd = cfg.n_embd
         embed_shape = model.embed.weight.shape
         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.decoder.weight = model.embed.weight # Tied weights
+        self.trunc_and_reshape = trunc_and_reshape # XD

     def forward(self, h):
         # Truncated Language modeling logits (we remove the last token)
-        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
+        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) \
+            if self.trunc_and_reshape else h # XD
         lm_logits = self.decoder(h_trunc)
         return lm_logits
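The new trunc_and_reshape flag keeps LMHead usable for both paths: the original fine-tuning loss, which drops the last position and flattens the hidden states, and the new generation path, which needs per-position logits. A shape-only sketch with made-up sizes; the plain Linear layer here merely stands in for the tied embedding weights.

import torch

batch, seq, n_embd, vocab = 2, 7, 768, 40990
h = torch.randn(batch, seq, n_embd)                   # transformer hidden states
decoder = torch.nn.Linear(n_embd, vocab, bias=False)  # stand-in for the tied weights

# trunc_and_reshape=True: shift off the last position and flatten for the LM loss.
print(decoder(h[:, :-1].contiguous().view(-1, n_embd)).shape)  # (batch * (seq - 1), vocab)

# trunc_and_reshape=False: keep the full (batch, seq, vocab) logits for sampling.
print(decoder(h).shape)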

@@ -264,6 +270,29 @@ def forward(self, h, x):
 
         return sim_logits
 
+
+# XD
+class LMModel(nn.Module):
+    """ Transformer with language model head only """
+    def __init__(self, cfg, vocab=40990, n_ctx=512, return_probs=False):
+        super(LMModel, self).__init__()
+        self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
+        self.lm_head = LMHead(self.transformer, cfg, trunc_and_reshape=False)
+        self.return_probs = return_probs
+        if self.return_probs:
+            pos_emb_mask = torch.zeros(1, 1, vocab)
+            pos_emb_mask[:, :, -n_ctx:] = -1e12
+            self.register_buffer('pos_emb_mask', pos_emb_mask)
+
+
+    def forward(self, x):
+        h = self.transformer(x)
+        lm_logits = self.lm_head(h)
+        if self.return_probs:
+            lm_logits = F.softmax(lm_logits + self.pos_emb_mask, dim=-1)
+        return lm_logits
+
+
 class DoubleHeadModel(nn.Module):
     """ Transformer with language model and task specific heads """
     def __init__(self, cfg, clf_token, task_head_type, vocab=40990, n_ctx=512):
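Because the LM head is tied to an embedding matrix that also stores the n_ctx position embeddings, LMModel's logits span n_vocab + n_special + n_ctx slots; pos_emb_mask adds -1e12 to the last n_ctx of them before the softmax so a position id can never be sampled as a token. A toy-sized sketch of that masking, with made-up sizes:

import torch
import torch.nn.functional as F

n_vocab, n_ctx = 6, 3
vocab = n_vocab + n_ctx                 # token slots followed by position slots

pos_emb_mask = torch.zeros(1, 1, vocab)
pos_emb_mask[:, :, -n_ctx:] = -1e12     # mask the position-embedding slots

lm_logits = torch.randn(1, 1, vocab)
probs = F.softmax(lm_logits + pos_emb_mask, dim=-1)
print(probs[0, 0, -n_ctx:])             # effectively zero probability for every position slot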

0 commit comments
