Commit 5bf3607

Generation from pretrained LM model
1 parent a287899 commit 5bf3607

1 file changed

generate.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
import argparse
import os
import random

import numpy as np
import torch
import torch.nn as nn

from model_pytorch import LMModel, load_openai_pretrained_model
from text_utils import TextEncoder


def make_batch(X):
    # Pack token ids together with their position ids into a (batch, seq, 2) tensor.
    # Positions start at n_vocab + n_special because token and position ids are
    # looked up in the same embedding matrix (vocab = n_vocab + n_special + n_ctx below).
    X = np.array(X)
    assert X.ndim in [1, 2]
    if X.ndim == 1:
        X = np.expand_dims(X, axis=0)
    pos_enc = np.arange(n_vocab + n_special, n_vocab + n_special + X.shape[-1])
    pos_enc = np.expand_dims(pos_enc, axis=0)
    batch = np.stack([X, pos_enc], axis=-1)
    batch = torch.tensor(batch, dtype=torch.long).to(device)
    return batch

def append_batch(X, next_idx):
    # Append the sampled token id and its (previous position + 1) position id.
    next_pos = X[:, -1:, 1] + 1
    next_x = torch.cat((next_idx, next_pos), -1).unsqueeze(1)
    return torch.cat((X, next_x), 1)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--desc', type=str, help="Description")
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--log_dir', type=str, default='log/')
    parser.add_argument('--save_dir', type=str, default='save/')
    parser.add_argument('--data_dir', type=str, default='data/')
    parser.add_argument('--submission_dir', type=str, default='submission/')
    parser.add_argument('--submit', action='store_true')
    parser.add_argument('--analysis', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--n_iter', type=int, default=3)
    parser.add_argument('--n_batch', type=int, default=8)
    parser.add_argument('--max_grad_norm', type=int, default=1)
    parser.add_argument('--lr', type=float, default=6.25e-5)
    parser.add_argument('--lr_warmup', type=float, default=0.002)
    parser.add_argument('--n_ctx', type=int, default=512)
    parser.add_argument('--n_embd', type=int, default=768)
    parser.add_argument('--n_head', type=int, default=12)
    parser.add_argument('--n_layer', type=int, default=12)
    parser.add_argument('--embd_pdrop', type=float, default=0.1)
    parser.add_argument('--attn_pdrop', type=float, default=0.1)
    parser.add_argument('--resid_pdrop', type=float, default=0.1)
    parser.add_argument('--clf_pdrop', type=float, default=0.1)
    parser.add_argument('--l2', type=float, default=0.01)
    parser.add_argument('--vector_l2', action='store_true')
    parser.add_argument('--opt', type=str, default='adam')
    parser.add_argument('--afn', type=str, default='gelu')
    parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
    parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json')
    parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe')
    parser.add_argument('--n_transfer', type=int, default=12)
    parser.add_argument('--lm_coef', type=float, default=0.5)
    parser.add_argument('--b1', type=float, default=0.9)
    parser.add_argument('--b2', type=float, default=0.999)
    parser.add_argument('--e', type=float, default=1e-8)
    parser.add_argument('--n_valid', type=int, default=374)
    parser.add_argument('--gen_len', type=int, default=20)

    args = parser.parse_args()
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Constants
    submit = args.submit
    dataset = args.dataset
    n_ctx = args.n_ctx
    save_dir = args.save_dir
    desc = args.desc
    data_dir = args.data_dir
    log_dir = args.log_dir
    submission_dir = args.submission_dir

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print("device", device, "n_gpu", n_gpu)

    text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
    encoder = text_encoder.encoder
    n_vocab = len(text_encoder.encoder)

    n_special = 0  # XD: useless for language modeling task
    vocab = n_vocab + n_special + n_ctx

    # Build the language model and load the pretrained OpenAI weights into its transformer.
    lm_model = LMModel(args, vocab, n_ctx, return_probs=True)
    load_openai_pretrained_model(lm_model.transformer, n_ctx=n_ctx, n_special=n_special)
    lm_model.to(device)

    lm_model.eval()

    # Interactive loop: keep generating continuations until the user enters 'q'.
    text = input('Input some beginning words:')
    while text != 'q':
        X = text_encoder.encode([text,])
        XMB = make_batch(X)

        for _ in range(args.gen_len):
            lm_probs = lm_model(XMB)
            # Sample the next token from the model's distribution at the last position.
            next_idx = torch.multinomial(lm_probs[:, -1, :], 1)
            next_token = text_encoder.decoder[next_idx.item()].replace('</w>', '')
            print(next_token, end=' ')
            XMB = append_batch(XMB, next_idx)

        print()
        text = input('Input some beginning words:')
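
For readers skimming the diff, the only non-obvious piece is the batch layout that make_batch and append_batch maintain: token ids and position ids are stacked into a (batch, seq, 2) tensor, with positions offset by n_vocab + n_special so they index the position rows of the shared embedding table (hence vocab = n_vocab + n_special + n_ctx above). Below is a minimal, self-contained sketch of that bookkeeping; the vocabulary size and token ids are made-up illustrative values, not anything taken from this repository.

import numpy as np
import torch

n_vocab, n_special = 10, 0   # illustrative only; the script derives n_vocab from the BPE encoder
token_ids = [3, 7, 2]        # pretend BPE ids for a prompt

positions = np.arange(n_vocab + n_special, n_vocab + n_special + len(token_ids))
batch = torch.tensor(np.stack([np.expand_dims(token_ids, 0),
                               np.expand_dims(positions, 0)], axis=-1),
                     dtype=torch.long)

print(batch.shape)               # torch.Size([1, 3, 2])
print(batch[0, :, 0].tolist())   # [3, 7, 2]     token ids in channel 0
print(batch[0, :, 1].tolist())   # [10, 11, 12]  offset position ids in channel 1

# Appending one sampled token, as append_batch does: the new position id is the
# previous position plus one, so the offset stays consistent as generation proceeds.
next_idx = torch.tensor([[5]])   # pretend the model sampled id 5
next_x = torch.cat((next_idx, batch[:, -1:, 1] + 1), -1).unsqueeze(1)
batch = torch.cat((batch, next_x), 1)
print(batch[0, -1].tolist())     # [5, 13]

When the file is run as a script (python generate.py), this same layout is fed through lm_model and torch.multinomial draws the next id from the probabilities at the last position, one token per step, for gen_len steps.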

0 commit comments
