+import re
import math
import json
import copy
@@ -36,22 +37,34 @@ def load_openai_pretrained_model(model, n_ctx, n_special, cfg, path='model'):
        n_transfer = 0
    else:
        n_transfer = 1 + n_transfer * 12
-    assert model.embed.weight.shape == init_params[0].shape
-    model.embed.weight = init_params[0]
+    init_params = [arr.squeeze() for arr in init_params]
+    try:
+        assert model.embed.weight.shape == init_params[0].shape
+    except AssertionError as e:
+        e.args += (model.embed.weight.shape, init_params[0].shape)
+        raise
+    model.embed.weight.data = torch.from_numpy(init_params[0])
    for name, ip in zip(names[1:n_transfer], init_params[1:n_transfer]):
        name = name[6:]  # skip "model/"
        assert name[-2:] == ":0"
        name = name[:-2]
        name = name.split('/')
        pointer = model
        for m_name in name:
-            l = re.split('(\d+)', m_name)
+            if re.fullmatch(r'[A-Za-z]+\d+', m_name):
+                l = re.split(r'(\d+)', m_name)
+            else:
+                l = [m_name]
            pointer = getattr(pointer, l[0])
-            if len(l) == 1:
+            if len(l) >= 2:
                num = int(l[1])
                pointer = pointer[num]
-        assert pointer.shape == ip.shape
-        pointer = ip
+        try:
+            assert pointer.shape == ip.shape
+        except AssertionError as e:
+            e.args += (pointer.shape, ip.shape)
+            raise
+        pointer.data = torch.from_numpy(ip)


class LayerNorm(nn.Module):
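Side note on the new loading pattern: copying the numpy array into the existing parameter through `.data` (rather than rebinding the attribute) keeps the tensor registered on the module, and appending both shapes to the `AssertionError` makes a size mismatch readable in the traceback. A minimal standalone sketch, with a throwaway `nn.Embedding` standing in for `model.embed` (sizes are made up):

import numpy as np
import torch
import torch.nn as nn

embed = nn.Embedding(10, 4)                              # stand-in for model.embed
pretrained = np.random.randn(10, 4).astype(np.float32)
try:
    assert embed.weight.shape == pretrained.shape
except AssertionError as e:
    e.args += (embed.weight.shape, pretrained.shape)     # both shapes show up when the error is printed
    raise
embed.weight.data = torch.from_numpy(pretrained)         # replace the underlying data; the Parameter stays registered on the module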
@@ -82,7 +95,7 @@ def __init__(self, nf, rf, nx):

    def forward(self, x):
        if self.rf == 1:
-            size_out = x.size()[:-1] + [self.nf]
+            size_out = x.size()[:-1] + (self.nf,)
            x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w)
            x = x.view(*size_out)
        else:
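The `size_out` change above works because `Tensor.size()` returns a `torch.Size`, which is a tuple subclass: concatenating it with a one-element tuple is fine, while the old list form raises a `TypeError`. A quick check with arbitrary shapes:

import torch

x = torch.zeros(2, 3, 768)
size_out = x.size()[:-1] + (768,)    # tuple + tuple concatenates
y = x.view(*size_out)                # shape (2, 3, 768)
# size_out = x.size()[:-1] + [768]   # old form: tuple + list raises TypeError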
@@ -93,38 +106,35 @@ def forward(self, x):
class Attention(nn.Module):
    def __init__(self, nx, cfg, scale=False):
        super(Attention, self).__init__()
-        n_state = nx  # in Attention: n_state=768 (nx=n_embed)
+        n_state = nx  # in Attention: n_state=768 (nx=n_embd)
        #[switch nx => n_state from Block to Attention to keep identical to TF implem]
        assert n_state % cfg.n_head == 0
+        mask_size = n_state // cfg.n_head
+        self.register_buffer('b', torch.tril(torch.ones(mask_size, mask_size)).view(1, 1, mask_size, mask_size))
        self.n_head = cfg.n_head
+        self.split_size = n_state
        self.scale = scale
        self.c_attn = Conv1D(n_state * 3, 1, nx)
        self.c_proj = Conv1D(n_state, 1, nx)
        self.attn_dropout = nn.Dropout(cfg.attn_pdrop)
        self.resid_dropout = nn.Dropout(cfg.resid_pdrop)

-    @staticmethod
-    def mask_attn_weights(w):
-        n = w.size(-1)
-        b = torch.tril(np.ones(n, n)).view(1, 1, n, n)
-        return w * b + -1e9 * (1 - b)
-
    def _attn(self, q, k, v):
        w = torch.matmul(q, k)
        if self.scale:
            w = w / math.sqrt(v.size(-1))
-        w = self.mask_attn_weights(w)
+        w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
        w = nn.Softmax()(w)
        w = self.attn_dropout(w)
        return torch.matmul(w, v)

    def merge_heads(self, x):
-        new_x_shape = x.size()[:-2] + [np.prod(x.size()[-2:])]
-        x = x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states
-        return x.permute(0, 2, 1, 3)
+        x = x.permute(0, 2, 1, 3).contiguous()
+        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
+        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    def split_heads(self, x, k=False):
-        new_x_shape = x.size()[:-1] + [self.n_head, x.size(-1)//self.n_head]
+        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1)//self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)
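Replacing `mask_attn_weights` with a precomputed `tril` buffer keeps the same masking arithmetic: positions above the diagonal are pushed to roughly -1e9, so the softmax gives them near-zero weight. A toy illustration with an arbitrary size (the buffer just has to match the trailing seq x seq dimensions of `w`):

import torch

n = 4                                              # toy sequence length
b = torch.tril(torch.ones(n, n)).view(1, 1, n, n)  # 1s on and below the diagonal
w = torch.zeros(1, 1, n, n)                        # stand-in attention scores
w = w * b + -1e9 * (1 - b)                         # future positions get ~-1e9
print(torch.softmax(w, dim=-1)[0, 0])              # each row only attends to itself and earlier positions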
@@ -133,7 +143,7 @@ def split_heads(self, x, k=False):

    def forward(self, x):
        x = self.c_attn(x)
-        query, key, value = x.split(3, dim=2)
+        query, key, value = x.split(self.split_size, dim=2)
        query = self.split_heads(query)
        key = self.split_heads(key, k=True)
        value = self.split_heads(value)
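The `x.split(...)` fix matters because `Tensor.split` takes a chunk size, not a number of chunks: with `self.split_size = n_state`, the `c_attn` output of width `3 * n_state` splits into exactly three tensors, whereas `split(3, dim=2)` would produce hundreds of slivers of width 3. A small check with the usual sizes:

import torch

n_state = 768
x = torch.zeros(8, 512, n_state * 3)     # what c_attn produces: 3 * n_state on the last dim
q, k, v = x.split(n_state, dim=2)        # three chunks of width n_state
print(q.shape, k.shape, v.shape)         # each: (8, 512, 768)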
@@ -145,11 +155,11 @@ def forward(self, x):


class MLP(nn.Module):
-    def __init__(self, n_state, cfg):  # in MLP: n_state=3072 (4 * n_embed)
+    def __init__(self, n_state, cfg):  # in MLP: n_state=3072 (4 * n_embd)
        super(MLP, self).__init__()
-        nx = cfg.n_embed
+        nx = cfg.n_embd
        self.c_fc = Conv1D(n_state, 1, nx)
-        self.c_proj = Conv1D(nx, 1, nx)
+        self.c_proj = Conv1D(nx, 1, n_state)
        self.act = ACT_FNS[cfg.afn]
        self.dropout = nn.Dropout(cfg.resid_pdrop)

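For the `c_proj` fix: per the `Conv1D.__init__(self, nf, rf, nx)` signature earlier in this diff, `nf` is the output width and `nx` the input width (with `rf == 1` it acts like a linear layer), so the projection back from the 3072-dim `c_fc` output must be built with `n_state` as its input size. A rough equivalent of the corrected MLP using plain `nn.Linear` layers (ReLU stands in for the model's configured activation):

import torch
import torch.nn as nn

n_embd, n_state = 768, 3072              # n_state = 4 * n_embd
c_fc = nn.Linear(n_embd, n_state)        # expand
c_proj = nn.Linear(n_state, n_embd)      # project back: input width is n_state, not n_embd
x = torch.randn(2, 10, n_embd)
y = c_proj(torch.relu(c_fc(x)))
print(y.shape)                           # (2, 10, 768)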
@@ -162,7 +172,7 @@ def forward(self, x):
class Block(nn.Module):
    def __init__(self, cfg, scale=False):
        super(Block, self).__init__()
-        nx = cfg.n_embed
+        nx = cfg.n_embd
        self.attn = Attention(nx, cfg, scale)
        self.ln_1 = LayerNorm(nx)
        self.mlp = MLP(4 * nx, cfg)
@@ -185,13 +195,12 @@ def __init__(self, vocab, cfg):
        self.drop = nn.Dropout(cfg.embd_pdrop)
        block = Block(cfg, scale=True)
        self.h = nn.ModuleList([copy.deepcopy(block) for _ in range(cfg.n_layer)])
-        self.decoder = nn.Linear(cfg.n_embed, vocab, bias=False)
+        self.decoder = nn.Linear(cfg.n_embd, vocab, bias=False)
        self.decoder.weight = self.embed.weight  # Tied weights
        self.clf_dropout = nn.Dropout2d(cfg.clf_pdrop)  # To reproduce the noise_shape parameter of TF implementation

-    def forward(self, x, m):
+    def forward(self, x):
        x = x.view(-1, x.size(2), x.size(3))
-        m = m.view(-1, m.size(2))
        e = self.embed(x)
        h = e.sum(dim=2)
        for block in self.h:
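In the new mask-free `forward`, the input carries both a token id and a position id in its last dimension, and `e.sum(dim=2)` adds the two embedding vectors looked up from the shared table. A sketch with made-up sizes:

import torch
import torch.nn as nn

n_embd = 768
embed = nn.Embedding(40478 + 512, n_embd)      # word and position ids index one shared table (illustrative sizes)
x = torch.randint(0, 40478, (2, 1, 5, 2))      # (batch, choices, tokens, [token_id, position_id])
x = x.view(-1, x.size(2), x.size(3))           # merge batch and choice dims -> (2, 5, 2)
e = embed(x)                                   # (2, 5, 2, n_embd)
h = e.sum(dim=2)                               # token embedding + position embedding -> (2, 5, n_embd)
print(h.shape)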
@@ -200,36 +209,37 @@ def forward(self, x, m):


class LMHead(nn.Module):
-    """ Language Model Head """
+    """ Language Model Head for the transformer """
    def __init__(self, model, cfg):
        super(LMHead, self).__init__()
-        self.n_embed = cfg.n_embed
-        self.decoder = nn.Linear(cfg.n_embed, model.vocab, bias=False)
+        self.n_embd = cfg.n_embd
+        self.decoder = nn.Linear(cfg.n_embd, model.vocab, bias=False)
        self.decoder.weight = model.embed.weight  # Tied weights

    def forward(self, h):
        # Truncated Language modeling logits
-        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embed)  # Shape: 252, 768
+        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)  # Shape: 252, 768
        lm_logits = self.decoder(h_trunc)
        return lm_logits


class ClfHead(nn.Module):
-    """ Classifier Head for the model """
-    def __init__(self, model, clf_token, cfg):
+    """ Classifier Head for the transformer """
+    def __init__(self, clf_token, cfg):
        super(ClfHead, self).__init__()
-        self.n_embed = cfg.n_embed
+        self.n_embd = cfg.n_embd
        self.clf_token = clf_token
        self.dropout = nn.Dropout2d(cfg.clf_pdrop)  # To reproduce the noise_shape parameter of TF implementation
-        self.linear = nn.Linear(cfg.n_embed, 1)
+        self.linear = nn.Linear(cfg.n_embd, 1)

    def forward(self, h, x):
        # Classification logits
-        clf_h = h.view(-1, self.n_embed)
-        pool_idx = torch.eq(x[:, :, 0].contiguous().view(-1), self.clf_token)
-        clf_h = clf_h[pool_idx, :]
-        clf_h = clf_h.view(-1, 2, self.n_embed, 1)
+        clf_h = h.view(-1, self.n_embd)
+        flat = x[:, :, :, 0].contiguous().view(-1)
+        #pool_idx = torch.eq(x[:, :, 0].contiguous().view(-1), self.clf_token)
+        clf_h = clf_h[flat == self.clf_token, :]  # .index_select(0, pool_idx)
+        clf_h = clf_h.view(-1, 2, self.n_embd, 1)
        clf_h = self.dropout(clf_h)
-        clf_h = clf_h.view(-1, self.n_embed)
+        clf_h = clf_h.view(-1, self.n_embd)
        clf_logits = self.linear(clf_h)
        return clf_logits.view(-1, 2)
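Roughly what the new indexing in `ClfHead.forward` does, with made-up sizes (2 examples, 2 choices each, 5 tokens): the hidden states are flattened to one row per token, and the rows whose token id equals `clf_token` (one per choice) are kept before the final linear layer.

import torch

n_embd, clf_token = 768, 40477                 # illustrative values
h = torch.randn(4, 5, n_embd)                  # hidden states: 2 examples x 2 choices, 5 tokens each
x = torch.randint(0, 100, (2, 2, 5, 2))        # [token_id, position_id] per token
x[:, :, -1, 0] = clf_token                     # last token of every choice is the classify token
clf_h = h.view(-1, n_embd)                     # (20, n_embd), one row per token
flat = x[:, :, :, 0].contiguous().view(-1)     # flattened token ids, aligned with clf_h rows
clf_h = clf_h[flat == clf_token, :]            # keep one row per choice -> (4, n_embd)
print(clf_h.shape)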