 
 
 def gelu(x):
-    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi)*(x + 0.044715 * torch.pow(x, 3))))
+    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+
 
 def swish(x):
-    return x * torch.sigmoid(x)
+    return x * torch.sigmoid(x)
+
 
 ACT_FNS = {
     'relu': nn.ReLU,
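
Note on the `gelu` line reformatted above: it is the tanh approximation of GELU. As a standalone check (illustrative only, not part of this change; `gelu_tanh` and `gelu_exact` are names made up for the comparison), it can be compared against the exact erf-based form:

    import math
    import torch

    def gelu_tanh(x):
        # tanh approximation, as in the diff above
        return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

    def gelu_exact(x):
        # exact GELU via the Gaussian CDF
        return 0.5 * x * (1 + torch.erf(x / math.sqrt(2)))

    x = torch.linspace(-3, 3, 7)
    print((gelu_tanh(x) - gelu_exact(x)).abs().max())  # prints a small value: the approximation tracks the exact form closely
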
@@ -25,6 +27,7 @@ def swish(x):
 
 class LayerNorm(nn.Module):
     "Construct a layernorm module in the OpenAI style (epsilon inside the square root)."
+
     def __init__(self, n_state, e=1e-5):
         super(LayerNorm, self).__init__()
         self.g = nn.Parameter(torch.ones(n_state))
@@ -43,12 +46,12 @@ def __init__(self, nf, rf, nx):
         super(Conv1D, self).__init__()
         self.rf = rf
         self.nf = nf
-        if rf == 1:  # faster 1x1 conv
+        if rf == 1:  # faster 1x1 conv
            w = torch.empty(nx, nf)
            nn.init.normal_(w, std=0.02)
            self.w = Parameter(w)
            self.b = Parameter(torch.zeros(nf))
-        else:  # was used to train LM
+        else:  # was used to train LM
            raise NotImplementedError
 
     def forward(self, x):
@@ -64,9 +67,9 @@ def forward(self, x):
 class Attention(nn.Module):
     def __init__(self, nx, n_ctx, cfg, scale=False):
         super(Attention, self).__init__()
-        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
-        #[switch nx => n_state from Block to Attention to keep identical to TF implem]
-        assert n_state % cfg.n_head == 0
+        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
+        # [switch nx => n_state from Block to Attention to keep identical to TF implem]
+        assert n_state % cfg.n_head == 0
         self.register_buffer('b', torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx))
         self.n_head = cfg.n_head
         self.split_size = n_state
@@ -80,19 +83,19 @@ def _attn(self, q, k, v):
         w = torch.matmul(q, k)
         if self.scale:
             w = w / math.sqrt(v.size(-1))
-        w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
+        w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
         w = nn.Softmax(dim=-1)(w)
         w = self.attn_dropout(w)
         return torch.matmul(w, v)
 
     def merge_heads(self, x):
         x = x.permute(0, 2, 1, 3).contiguous()
         new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
-        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states
+        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states
 
     def split_heads(self, x, k=False):
-        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1)//self.n_head)
-        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
+        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
+        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
         if k:
             return x.permute(0, 2, 3, 1)
         else:
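
For context on the `mask_attn_weights` line touched above: `self.b` is the lower-triangular buffer registered in `__init__`, so the expression keeps scores for current and past positions and pushes future positions to roughly -1e9 before the softmax. A minimal standalone illustration (assumed for clarity, not part of this diff):

    import torch

    n_ctx = 4
    b = torch.tril(torch.ones(n_ctx, n_ctx))   # causal mask: 1s on and below the diagonal
    w = torch.zeros(n_ctx, n_ctx)              # placeholder attention scores
    masked = w * b + -1e9 * (1 - b)            # future positions get a huge negative score
    print(torch.softmax(masked, dim=-1)[0])    # tensor([1., 0., 0., 0.]): position 0 attends only to itself
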
@@ -112,7 +115,7 @@ def forward(self, x):
 
 
 class MLP(nn.Module):
-    def __init__(self, n_state, cfg):  # in MLP: n_state=3072 (4 * n_embd)
+    def __init__(self, n_state, cfg):  # in MLP: n_state=3072 (4 * n_embd)
         super(MLP, self).__init__()
         nx = cfg.n_embd
         self.c_fc = Conv1D(n_state, 1, nx)
@@ -132,19 +135,20 @@ def __init__(self, n_ctx, cfg, scale=False):
         nx = cfg.n_embd
         self.attn = Attention(nx, n_ctx, cfg, scale)
         self.ln_1 = LayerNorm(nx)
-        self.mlp = MLP(4 * nx, cfg)
+        self.mlp = MLP(4 * nx, cfg)
         self.ln_2 = LayerNorm(nx)
 
     def forward(self, x):
         a = self.attn(x)
-        n = self.ln_1(x + a)
+        n = self.ln_1(x + a)
         m = self.mlp(n)
-        h = self.ln_2(n + m)
+        h = self.ln_2(n + m)
         return h
 
 
 class Model(nn.Module):
     """ Transformer model """
+
     def __init__(self, cfg, vocab=40990, n_ctx=512):
         super(Model, self).__init__()
         self.vocab = vocab
@@ -153,8 +157,8 @@ def __init__(self, cfg, vocab=40990, n_ctx=512):
         block = Block(n_ctx, cfg, scale=True)
         self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
         self.decoder = nn.Linear(cfg.n_embd, vocab, bias=False)
-        self.decoder.weight = self.embed.weight  # Tied weights
-        self.clf_dropout = nn.Dropout2d(cfg.clf_pdrop)  # To reproduce the noise_shape parameter of TF implementation
+        self.decoder.weight = self.embed.weight  # Tied weights
+        self.clf_dropout = nn.Dropout2d(cfg.clf_pdrop)  # To reproduce the noise_shape parameter of TF implementation
 
         nn.init.normal_(self.embed.weight, std=0.02)
 
@@ -169,25 +173,27 @@ def forward(self, x):
 
 class LMHead(nn.Module):
     """ Language Model Head for the transformer """
+
     def __init__(self, model, cfg):
         super(LMHead, self).__init__()
         self.n_embd = cfg.n_embd
-        self.decoder = lambda x: F.linear(x, model.embed.weight)  # Tied weights
+        self.decoder = lambda x: F.linear(x, model.embed.weight)  # Tied weights
 
     def forward(self, h):
         # Truncated Language modeling logits (we remove the last token)
-        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)  # Shape: 252, 768
+        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)  # Shape: 252, 768
         lm_logits = self.decoder(h_trunc)
         return lm_logits
 
 
 class ClfHead(nn.Module):
     """ Classifier Head for the transformer """
+
     def __init__(self, clf_token, cfg):
         super(ClfHead, self).__init__()
         self.n_embd = cfg.n_embd
         self.clf_token = clf_token
-        self.dropout = nn.Dropout2d(cfg.clf_pdrop)  # To reproduce the noise_shape parameter of TF implementation
+        self.dropout = nn.Dropout2d(cfg.clf_pdrop)  # To reproduce the noise_shape parameter of TF implementation
         self.linear = nn.Linear(cfg.n_embd, 1)
         nn.init.normal_(self.linear.weight, std=0.02)
         nn.init.normal_(self.linear.bias, 0)
@@ -196,17 +202,30 @@ def forward(self, h, x):
         # Classification logits
         clf_h = h.view(-1, self.n_embd)
         flat = x[:, :, :, 0].contiguous().view(-1)
-        #pool_idx = torch.eq(x[:, :, 0].contiguous().view(-1), self.clf_token)
-        clf_h = clf_h[flat == self.clf_token, :]  # .index_select(0, pool_idx)
+        # pool_idx = torch.eq(x[:, :, 0].contiguous().view(-1), self.clf_token)
+        clf_h = clf_h[flat == self.clf_token, :]  # .index_select(0, pool_idx)
         clf_h = clf_h.view(-1, 2, self.n_embd, 1)
         clf_h = self.dropout(clf_h)
         clf_h = clf_h.view(-1, self.n_embd)
         clf_logits = self.linear(clf_h)
         return clf_logits.view(-1, 2)
 
 
-def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/', path_names='./'):
+class DataParallelWithEmbed(torch.nn.DataParallel):
+    """DataParallel that proxies the embed property to the wrapped module"""
+
+    def __init__(self, model):
+        super(DataParallelWithEmbed, self).__init__(model)
+
+    @property
+    def embed(self):
+        return self.module.embed
+
+
+def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n_embd=768, path='./model/',
+                                 path_names='./'):
     # Load weights from TF model
+    print("Loading weights...")
     names = json.load(open(path_names + 'parameters_names.json'))
     shapes = json.load(open(path + 'params_shapes.json'))
     offsets = np.cumsum([np.prod(shape) for shape in shapes])
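
The new `DataParallelWithEmbed` exists because a plain `nn.DataParallel` does not forward arbitrary attributes of the wrapped module, so code that reads `model.embed.weight` breaks once the model is wrapped. A small standalone illustration of the gotcha (the `Toy` class is made up for this note, not part of the PR):

    import torch.nn as nn

    class Toy(nn.Module):
        def __init__(self):
            super(Toy, self).__init__()
            self.embed = nn.Embedding(10, 4)

    wrapped = nn.DataParallel(Toy())
    print(hasattr(wrapped, 'embed'))           # False: the wrapper hides .embed
    print(wrapped.module.embed.weight.shape)   # torch.Size([10, 4]); the subclass above re-exposes this as .embed
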
@@ -216,32 +235,40 @@ def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n
     if n_ctx > 0:
         init_params[0] = init_params[0][:n_ctx]
     if n_special > 0:
-        init_params[0] = np.concatenate([init_params[1],
-                                         (np.random.randn(n_special, n_embd)*0.02).astype(np.float32),
-                                         init_params[0]
-                                         ], 0)
+        init_params[0] = np.concatenate(
+            [init_params[1],
+             (np.random.randn(n_special, n_embd) * 0.02).astype(np.float32),
+             init_params[0]
+             ], 0)
     else:
-        init_params[0] = np.concatenate([init_params[1],
-                                         init_params[0]
-                                         ], 0)
+        init_params[0] = np.concatenate(
+            [init_params[1],
+             init_params[0]
+             ], 0)
     del init_params[1]
     if n_transfer == -1:
         n_transfer = 0
     else:
-        n_transfer = 1 + n_transfer * 12
+        n_transfer = 1 + n_transfer * 12
     init_params = [arr.squeeze() for arr in init_params]
+
     try:
         assert model.embed.weight.shape == init_params[0].shape
     except AssertionError as e:
         e.args += (model.embed.weight.shape, init_params[0].shape)
         raise
+
     model.embed.weight.data = torch.from_numpy(init_params[0])
+
+    # Load the weights into our torch module
+    module = model.module
+
     for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
-        name = name[6:]  # skip "model/"
+        name = name[6:]  # skip "model/"
         assert name[-2:] == ":0"
         name = name[:-2]
         name = name.split('/')
-        pointer = model
+        pointer = module
         for m_name in name:
             if re.fullmatch(r'[A-Za-z]+\d+', m_name):
                 l = re.split(r'(\d+)', m_name)
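
Taken together, the `module = model.module` and `pointer = module` changes mean the loader now expects to be handed the DataParallel wrapper: the embedding is reached through the new `embed` property, while the per-layer weights are attached to the underlying module. A usage sketch inferred from this diff alone (not taken from the repository's documentation; all names are the ones defined in this file):

    # inferred usage: wrap the model first, then load the TF weights into it
    model = Model(DEFAULT_CONFIG, vocab=40990, n_ctx=512)
    model = DataParallelWithEmbed(model)                    # .embed still resolves via the @property
    load_openai_pretrained_model(model, path='./model/', path_names='./')
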
@@ -258,12 +285,14 @@ def load_openai_pretrained_model(model, n_ctx=-1, n_special=-1, n_transfer=12, n
             raise
         pointer.data = torch.from_numpy(ip)
 
+
 class dotdict(dict):
     """dot.notation access to dictionary attributes"""
     __getattr__ = dict.get
     __setattr__ = dict.__setitem__
     __delattr__ = dict.__delitem__
 
+
 DEFAULT_CONFIG = dotdict({
     'n_embd': 768,
     'n_head': 12,