
Commit 8833a83

Generation from pretrained LM model
1 parent 5bf3607 commit 8833a83

File tree

1 file changed: +32 -3 lines


model_pytorch.py

Lines changed: 32 additions & 3 deletions
@@ -84,7 +84,11 @@ def _attn(self, q, k, v):
         w = torch.matmul(q, k)
         if self.scale:
             w = w / math.sqrt(v.size(-1))
-        w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
+        # w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
+        # XD: self.b may be larger than w, so we need to crop it
+        b = self.b[:, :, :w.size(-2), :w.size(-1)]
+        w = w * b + -1e9 * (1 - b)
+
         w = nn.Softmax(dim=-1)(w)
         w = self.attn_dropout(w)
         return torch.matmul(w, v)
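
Cropping the mask is what lets attention run on inputs shorter than n_ctx, which is exactly what decoding from a prompt produces. A minimal sketch of the behaviour, assuming self.b is the usual (1, 1, n_ctx, n_ctx) lower-triangular buffer registered by the attention block (the toy sizes below are made up):

import torch

n_ctx, seq_len = 512, 5
b = torch.tril(torch.ones(n_ctx, n_ctx)).view(1, 1, n_ctx, n_ctx)  # assumed causal mask buffer
w = torch.randn(1, 1, seq_len, seq_len)             # toy attention scores for a short input
b_crop = b[:, :, :w.size(-2), :w.size(-1)]          # crop the mask to the current length
w = w * b_crop + -1e9 * (1 - b_crop)                # future positions pushed to ~-1e9
print(torch.softmax(w, dim=-1)[0, 0])               # upper triangle is ~0 after softmax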
@@ -175,16 +179,18 @@ def forward(self, x):
 class LMHead(nn.Module):
     """ Language Model Head for the transformer """

-    def __init__(self, model, cfg):
+    def __init__(self, model, cfg, trunc_and_reshape=True):
         super(LMHead, self).__init__()
         self.n_embd = cfg.n_embd
         embed_shape = model.embed.weight.shape
         self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
         self.decoder.weight = model.embed.weight  # Tied weights
+        self.trunc_and_reshape = trunc_and_reshape  # XD

     def forward(self, h):
         # Truncated Language modeling logits (we remove the last token)
-        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd)
+        h_trunc = h[:, :-1].contiguous().view(-1, self.n_embd) \
+            if self.trunc_and_reshape else h  # XD
         lm_logits = self.decoder(h_trunc)
         return lm_logits
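
The new trunc_and_reshape flag only changes the shape of what the head returns. A toy sketch of the two modes (the sizes are made up, and the Linear below stands in for the tied decoder):

import torch
import torch.nn as nn

batch, T, n_embd, vocab = 2, 7, 16, 100             # made-up sizes
decoder = nn.Linear(n_embd, vocab, bias=False)      # stand-in for the tied decoder
h = torch.randn(batch, T, n_embd)                   # transformer hidden states

# trunc_and_reshape=True (training default): drop the last position and flatten
h_trunc = h[:, :-1].contiguous().view(-1, n_embd)
print(decoder(h_trunc).shape)                       # (batch * (T - 1), vocab) -> [12, 100]

# trunc_and_reshape=False (used by the new LMModel): keep logits for every position
print(decoder(h).shape)                             # (batch, T, vocab) -> [2, 7, 100]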

@@ -266,6 +272,29 @@ def forward(self, h, x):

         return sim_logits

+
+# XD
+class LMModel(nn.Module):
+    """ Transformer with language model head only """
+    def __init__(self, cfg, vocab=40990, n_ctx=512, return_probs=False):
+        super(LMModel, self).__init__()
+        self.transformer = TransformerModel(cfg, vocab=vocab, n_ctx=n_ctx)
+        self.lm_head = LMHead(self.transformer, cfg, trunc_and_reshape=False)
+        self.return_probs = return_probs
+        if self.return_probs:
+            pos_emb_mask = torch.zeros(1, 1, vocab)
+            pos_emb_mask[:, :, -n_ctx:] = -1e12
+            self.register_buffer('pos_emb_mask', pos_emb_mask)
+
+
+    def forward(self, x):
+        h = self.transformer(x)
+        lm_logits = self.lm_head(h)
+        if self.return_probs:
+            lm_logits = F.softmax(lm_logits + self.pos_emb_mask, dim=-1)
+        return lm_logits
+
+
 class DoubleHeadModel(nn.Module):
     """ Transformer with language model and task specific heads """
     def __init__(self, cfg, clf_token, task_head_type, vocab=40990, n_ctx=512):
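
One way the new LMModel could be driven for generation (the point of this commit). This is only a sketch: cfg and the prompt ids are placeholders, and the (token id, position id) input layout is assumed to match what TransformerModel in this repo expects, with the position slots occupying the last n_ctx vocabulary indices as the pos_emb_mask above implies:

import torch

n_vocab, n_ctx = 40990, 512
model = LMModel(cfg, vocab=n_vocab, n_ctx=n_ctx, return_probs=True)  # cfg is a placeholder
model.eval()

tokens = [250, 3, 17]                               # placeholder prompt token ids
with torch.no_grad():
    for _ in range(20):                             # greedy decoding, 20 steps
        T = len(tokens)
        positions = range(n_vocab - n_ctx, n_vocab - n_ctx + T)
        x = torch.tensor([list(zip(tokens, positions))])  # (1, T, 2): token + position ids
        probs = model(x)                            # (1, T, vocab); position slots masked out
        tokens.append(probs[0, -1].argmax().item())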
