1
+ import argparse
2
+ import os
3
+ import random
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch .nn as nn
8
+
9
+ from model_pytorch import LMModel , load_openai_pretrained_model
10
+ from text_utils import TextEncoder
11
+
12
+
13
def make_batch(X):
    """Pack token ids and their position ids into one long tensor.

    Reads the module-level globals ``n_vocab``, ``n_special`` and ``device``,
    which the ``__main__`` section sets before calling this function.

    Args:
        X: token ids, either a single sequence (1-D) or a rectangular batch
            of equally long sequences (2-D). Anything ``np.array`` accepts.

    Returns:
        A ``torch.long`` tensor of shape ``(batch, seq_len, 2)`` on ``device``.
        ``[..., 0]`` holds the token ids and ``[..., 1]`` the position ids,
        which start at ``n_vocab + n_special`` and increase by one per step.
        (The script sizes the model vocabulary as
        ``n_vocab + n_special + n_ctx``, so positions presumably share the
        token embedding table -- confirm against ``model_pytorch``.)

    Raises:
        AssertionError: if ``X`` is neither 1-D nor 2-D.
    """
    X = np.array(X)
    assert X.ndim in [1, 2]
    if X.ndim == 1:
        # Promote a single sequence to a batch of one.
        X = np.expand_dims(X, axis=0)

    pos_start = n_vocab + n_special
    pos_enc = np.arange(pos_start, pos_start + X.shape[-1])
    # np.stack needs identically shaped inputs, so repeat the position row
    # for every sequence in the batch (previously only batch size 1 worked).
    pos_enc = np.broadcast_to(pos_enc, X.shape)

    batch = np.stack([X, pos_enc], axis=-1)
    batch = torch.tensor(batch, dtype=torch.long).to(device)
    return batch
23
+
24
def append_batch(X, next_idx):
    """Extend a ``(batch, seq_len, 2)`` batch by one time step.

    Args:
        X: tensor whose last axis is ``(token id, position id)``.
        next_idx: ``(batch, 1)`` tensor with the token id to append per row.

    Returns:
        A new ``(batch, seq_len + 1, 2)`` tensor; the appended step pairs
        ``next_idx`` with the previous last position id plus one.
    """
    # Position of the new step = position of the current last step + 1.
    following_pos = X[:, -1:, 1].add(1)
    # (batch, 2) -> (batch, 1, 2) so it lines up with X along the time axis.
    new_step = torch.cat([next_idx, following_pos], dim=-1)[:, None, :]
    return torch.cat([X, new_step], dim=1)
28
+
29
+
30
if __name__ == '__main__':
    # Interactive text generation: read a prompt, sample ``--gen_len`` tokens
    # from the pretrained language model, print them, repeat until 'q'.
    parser = argparse.ArgumentParser()
    parser.add_argument('--desc', type=str, help="Description")
    parser.add_argument('--dataset', type=str)
    parser.add_argument('--log_dir', type=str, default='log/')
    parser.add_argument('--save_dir', type=str, default='save/')
    parser.add_argument('--data_dir', type=str, default='data/')
    parser.add_argument('--submission_dir', type=str, default='submission/')
    parser.add_argument('--submit', action='store_true')
    parser.add_argument('--analysis', action='store_true')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--n_iter', type=int, default=3)
    parser.add_argument('--n_batch', type=int, default=8)
    parser.add_argument('--max_grad_norm', type=int, default=1)
    parser.add_argument('--lr', type=float, default=6.25e-5)
    parser.add_argument('--lr_warmup', type=float, default=0.002)
    parser.add_argument('--n_ctx', type=int, default=512)
    parser.add_argument('--n_embd', type=int, default=768)
    parser.add_argument('--n_head', type=int, default=12)
    parser.add_argument('--n_layer', type=int, default=12)
    parser.add_argument('--embd_pdrop', type=float, default=0.1)
    parser.add_argument('--attn_pdrop', type=float, default=0.1)
    parser.add_argument('--resid_pdrop', type=float, default=0.1)
    parser.add_argument('--clf_pdrop', type=float, default=0.1)
    parser.add_argument('--l2', type=float, default=0.01)
    parser.add_argument('--vector_l2', action='store_true')
    parser.add_argument('--opt', type=str, default='adam')
    parser.add_argument('--afn', type=str, default='gelu')
    parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
    parser.add_argument('--encoder_path', type=str, default='model/encoder_bpe_40000.json')
    parser.add_argument('--bpe_path', type=str, default='model/vocab_40000.bpe')
    parser.add_argument('--n_transfer', type=int, default=12)
    parser.add_argument('--lm_coef', type=float, default=0.5)
    parser.add_argument('--b1', type=float, default=0.9)
    parser.add_argument('--b2', type=float, default=0.999)
    parser.add_argument('--e', type=float, default=1e-8)
    parser.add_argument('--n_valid', type=int, default=374)
    parser.add_argument('--gen_len', type=int, default=20)

    args = parser.parse_args()
    print(args)

    # Seed every random number generator in use so sampling is repeatable.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    # Constants
    submit = args.submit
    dataset = args.dataset
    n_ctx = args.n_ctx
    save_dir = args.save_dir
    desc = args.desc
    data_dir = args.data_dir
    log_dir = args.log_dir
    submission_dir = args.submission_dir

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    print("device", device, "n_gpu", n_gpu)

    text_encoder = TextEncoder(args.encoder_path, args.bpe_path)
    encoder = text_encoder.encoder
    n_vocab = len(text_encoder.encoder)

    n_special = 0   # XD: useless for language modeling task
    # Token ids, special ids and the n_ctx position ids share one id space;
    # make_batch() numbers positions starting at n_vocab + n_special.
    vocab = n_vocab + n_special + n_ctx

    lm_model = LMModel(args, vocab, n_ctx, return_probs=True)
    load_openai_pretrained_model(lm_model.transformer, n_ctx=n_ctx, n_special=n_special)
    lm_model.to(device)

    lm_model.eval()

    while True:
        text = input('Input some beginning words:')
        if text == 'q':
            break

        # Keep at most the last n_ctx prompt tokens: a longer sequence would
        # yield position ids >= the vocabulary size passed to LMModel.
        # (assumes encode() returns one id sequence per input text -- the
        # original passed its result straight to make_batch)
        X = np.array(text_encoder.encode([text, ]))[..., -n_ctx:]
        if X.size == 0:
            # Nothing to condition on; indexing the last step of an empty
            # sequence below would fail, so just ask again.
            print()
            continue
        XMB = make_batch(X)

        # eval() only switches layers such as dropout to inference mode; it
        # does not stop autograd. no_grad() avoids building and retaining a
        # graph for every generated token.
        with torch.no_grad():
            for _ in range(args.gen_len):
                if XMB.size(1) > n_ctx:
                    # Context is full: no valid position id is left.
                    break
                lm_probs = lm_model(XMB)
                # Sample the next token from the distribution at the last step.
                next_idx = torch.multinomial(lm_probs[:, -1, :], 1)
                next_token = text_encoder.decoder[next_idx.item()].replace('</w>', '')
                print(next_token, end=' ')
                XMB = append_batch(XMB, next_idx)

        print()
0 commit comments