@@ -84,7 +84,11 @@ def _attn(self, q, k, v):
         w = torch.matmul(q, k)
         if self.scale:
             w = w / math.sqrt(v.size(-1))
-        w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
+        # w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
+        # XD: self.b may be larger than w, so we need to crop it
+        b = self.b[:, :, :w.size(-2), :w.size(-1)]
+        w = w * b + -1e9 * (1 - b)
+
         w = nn.Softmax(dim=-1)(w)
         w = self.attn_dropout(w)
         return torch.matmul(w, v)
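The crop above is needed because the triangular mask buffer `self.b` is allocated once at the full context length, while `w` only spans the current sequence. A minimal shape sketch, assuming the buffer is built as in the upstream `Attention.__init__` (that construction is not part of this diff):

import torch

n_ctx = 512
# assumed buffer construction from the upstream code, not from this diff:
b = torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)   # (1, 1, 512, 512)
w = torch.randn(2, 12, 20, 20)                                      # scores for a 20-token batch, 12 heads

# broadcasting (2, 12, 20, 20) against (1, 1, 512, 512) fails, hence the crop:
b_crop = b[:, :, :w.size(-2), :w.size(-1)]                          # (1, 1, 20, 20)
w = w * b_crop + -1e9 * (1 - b_crop)                                # masked scores, same shape as w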
@@ -175,16 +179,18 @@ def forward(self, x):
 class LMHead(nn.Module):
     """ Language Model Head for the transformer """

-    def __init__(self, model, cfg):
+    def __init__(self, model, cfg, trunc_and_reshape=True):
         super(LMHead, self).__init__()
         self.n_embd = cfg.n_embd
         embed_shape = model.embed.weight.shape
         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.decoder.weight = model.embed.weight  # Tied weights
+        self.trunc_and_reshape = trunc_and_reshape  # XD

     def forward(self, h):
         # Truncated Language modeling logits (we remove the last token)
-        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
+        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) \
+            if self.trunc_and_reshape else h  # XD
         lm_logits = self.decoder(h_trunc)
         return lm_logits

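The new flag only changes what `LMHead.forward` returns: the default keeps the original truncate-and-flatten behaviour for the LM loss, while `trunc_and_reshape=False` preserves the per-position layout. A shape-only sketch (tensor sizes are illustrative, not taken from this diff):

import torch

batch, seq, n_embd, vocab = 2, 16, 768, 40990
h = torch.randn(batch, seq, n_embd)                    # stand-in for the transformer output
decoder = torch.nn.Linear(n_embd, vocab, bias=False)   # stand-in for the tied LMHead decoder

# trunc_and_reshape=True (original behaviour, feeds the LM training loss):
logits_train = decoder(h[:, :-1].contiguous().view(-1, n_embd))   # (batch*(seq-1), vocab)

# trunc_and_reshape=False (new path, used by LMModel below):
logits_full = decoder(h)                                          # (batch, seq, vocab)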
@@ -266,6 +272,29 @@ def forward(self, h, x):

         return sim_logits

+
+# XD
+class LMModel(nn.Module):
+    """ Transformer with language model head only """
+    def __init__(self, cfg, vocab=40990, n_ctx=512, return_probs=False):
+        super(LMModel, self).__init__()
+        self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
+        self.lm_head = LMHead(self.transformer, cfg, trunc_and_reshape=False)
+        self.return_probs = return_probs
+        if self.return_probs:
+            pos_emb_mask = torch.zeros(1, 1, vocab)
+            pos_emb_mask[:, :, -n_ctx:] = -1e12
+            self.register_buffer('pos_emb_mask', pos_emb_mask)
+
+
+    def forward(self, x):
+        h = self.transformer(x)
+        lm_logits = self.lm_head(h)
+        if self.return_probs:
+            lm_logits = F.softmax(lm_logits + self.pos_emb_mask, dim=-1)
+        return lm_logits
+
+
 class DoubleHeadModel(nn.Module):
     """ Transformer with language model and task specific heads """
     def __init__(self, cfg, clf_token, task_head_type, vocab=40990, n_ctx=512):
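A possible way to drive the new `LMModel` for greedy next-token prediction; `DEFAULT_CONFIG`, the token-id range, and the `(token, position)` input layout are assumptions about the rest of this repo rather than part of the diff:

import torch

# hypothetical usage; config name and id ranges assumed, see note above
model = LMModel(DEFAULT_CONFIG, vocab=40990, n_ctx=512, return_probs=True)
model.eval()

seq = 8
tokens = torch.randint(0, 40478, (1, seq))                   # assumed BPE token-id range
positions = torch.arange(40990 - 512, 40990 - 512 + seq)     # position ids sit at the top of the embedding
x = torch.stack([tokens, positions.unsqueeze(0)], dim=-1)    # (1, seq, 2): token ids stacked with position ids

with torch.no_grad():
    probs = model(x)                                         # (1, seq, vocab), thanks to trunc_and_reshape=False
next_token = probs[0, -1].argmax().item()                    # greedy pick for the position after the last token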