Skip to content

Commit be407cd

Browse files
committed
Solving missing variable issue.
1 parent 2b7e97e commit be407cd

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

model_pytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ class LMHead(nn.Module):
177177
def __init__(self, model, cfg):
178178
super(LMHead, self).__init__()
179179
self.n_embd = cfg.n_embd
180+
embed_shape = model.embed.weight.shape
180181
self.decoder = nn.Linear(embed_shape[1], embed_shape[0], bias=False)
181182
self.decoder.weight = model.embed.weight # Tied weights
182183

0 commit comments

Comments
 (0)