import argparse
import os
import random

import numpy as np
import torch
import torch.nn as nn

from model_pytorch import LMModel, load_openai_pretrained_model
from text_utils import TextEncoder

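# Turn an encoded prompt into the model's input batch: each token id is paired
# with a position id. Positions start at n_vocab + n_special because the
# pretrained model keeps token, special and position embeddings in a single
# embedding matrix of size n_vocab + n_special + n_ctx (see `vocab` below).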
def make_batch(X):
    X = np.array(X)
    assert X.ndim in [1, 2]
    if X.ndim == 1:
        X = np.expand_dims(X, axis=0)
    pos_enc = np.arange(n_vocab + n_special, n_vocab + n_special + X.shape[-1])
    pos_enc = np.expand_dims(pos_enc, axis=0)
    batch = np.stack([X, pos_enc], axis=-1)
    batch = torch.tensor(batch, dtype=torch.long).to(device)
    return batch

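# Append the sampled token (and its next position id) to the running input so
# the next forward pass conditions on everything generated so far.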
def append_batch(X, next_idx):
    next_pos = X[:, -1:, 1] + 1
    next_x = torch.cat((next_idx, next_pos), -1).unsqueeze(1)
    return torch.cat((X, next_x), 1)

if __name__ == '__main__':
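    # Command-line hyperparameters; the model defaults (12 layers, 12 heads,
    # 768-dim embeddings, 512-token context) match the pretrained OpenAI GPT
    # weights loaded below.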
    parser = argparse.ArgumentParser()
    parser.add_argument('--desc', type=str, help="Description")
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--log_dir', type=str, default='log/')
    parser.add_argument('--save_dir', type=str, default='save/')
    parser.add_argument('--data_dir', type=str, default='data/')
    parser.add_argument('--submission_dir', type=str, default='submission/')
    parser.add_argument('--submit', action='store_true')
    parser.add_argument('--analysis', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--n_iter', type=int, default=3)
    parser.add_argument('--n_batch', type=int, default=8)
    parser.add_argument('--max_grad_norm', type=int, default=1)
    parser.add_argument('--lr', type=float, default=6.25e-5)
    parser.add_argument('--lr_warmup', type=float, default=0.002)
    parser.add_argument('--n_ctx', type=int, default=512)
    parser.add_argument('--n_embd', type=int, default=768)
    parser.add_argument('--n_head', type=int, default=12)
    parser.add_argument('--n_layer', type=int, default=12)
    parser.add_argument('--embd_pdrop', type=float, default=0.1)
    parser.add_argument('--attn_pdrop', type=float, default=0.1)
    parser.add_argument('--resid_pdrop', type=float, default=0.1)
    parser.add_argument('--clf_pdrop', type=float, default=0.1)
    parser.add_argument('--l2', type=float, default=0.01)
    parser.add_argument('--vector_l2', action='store_true')
    parser.add_argument('--opt', type=str, default='adam')
    parser.add_argument('--afn', type=str, default='gelu')
    parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
    parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json')
    parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe')
    parser.add_argument('--n_transfer', type=int, default=12)
    parser.add_argument('--lm_coef', type=float, default=0.5)
    parser.add_argument('--b1', type=float, default=0.9)
    parser.add_argument('--b2', type=float, default=0.999)
    parser.add_argument('--e', type=float, default=1e-8)
    parser.add_argument('--n_valid', type=int, default=374)
    parser.add_argument('--gen_len', type=int, default=20)
    parser.add_argument('--topk', type=int, default=10)

    args = parser.parse_args()
    print(args)

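    # Seed every RNG involved (Python, NumPy, PyTorch CPU and CUDA) so that
    # sampling is reproducible for a given --seed.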
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Constants
    submit = args.submit
    dataset = args.dataset
    n_ctx = args.n_ctx
    save_dir = args.save_dir
    desc = args.desc
    data_dir = args.data_dir
    log_dir = args.log_dir
    submission_dir = args.submission_dir

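    # Run on a GPU when one is available; make_batch() moves its tensors to
    # this device.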
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print("device", device, "n_gpu", n_gpu)

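    # Load the BPE vocabulary and merge rules used by the pretrained model;
    # prompts are tokenized with this encoder before generation.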
    text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
    encoder = text_encoder.encoder
    n_vocab = len(text_encoder.encoder)

    n_special = 0   # XD: useless for language modeling task
    vocab = n_vocab + n_special + n_ctx

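    # Build the transformer language model and load the published OpenAI
    # weights into it. return_probs=True makes the forward pass return softmax
    # probabilities over the vocabulary; eval() disables dropout for
    # generation.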
    lm_model = LMModel(args, vocab, n_ctx, return_probs=True)
    load_openai_pretrained_model(lm_model.transformer, n_ctx=n_ctx, n_special=n_special)
    lm_model.to(device)

    lm_model.eval()

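    # Interactive loop: read a prompt, generate args.gen_len BPE tokens, then
    # ask for the next prompt. Enter 'q' to quit.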
    text = input('Input some beginning words:')
    while text != 'q':
        X = text_encoder.encode([text,])
        XMB = make_batch(X)

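        # Sample one token at a time: with --topk 0 sample from the full
        # next-token distribution; otherwise restrict sampling to the
        # args.topk most probable tokens.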
        for _ in range(args.gen_len):
            lm_probs = lm_model(XMB)
            if args.topk == 0:
                next_idx = torch.multinomial(lm_probs[:, -1, :], 1)
            else:
                values, indices = lm_probs[:, -1, :].topk(args.topk)
                next_idx = indices.gather(-1, torch.multinomial(values, 1))
            next_token = text_encoder.decoder[next_idx.item()].replace('</w>', '')
            print(next_token, end=' ')
            XMB = append_batch(XMB, next_idx)

        print()
        text = input('Input some beginning words:')