We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2b7e97e, commit be407cd. Copy full SHA for be407cd
model_pytorch.py
@@ -177,6 +177,7 @@ class LMHead(nn.Module):
177
def __init__(self, model, cfg):
178
super(LMHead, self).__init__()
179
self.n_embd = cfg.n_embd
180
+ embed_shape = model.embed.weight.shape
181
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
182
self.decoder.weight = model.embed.weight # Tied weights
183
0 commit comments