Skip to content

Commit 71bcad8

Browse files
committed
Using negative index to reshape the input tensor.
1 parent e339361 commit 71bcad8

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

model_py.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def __init__(self, cfg, vocab=40990, n_ctx=512):
157157
nn.init.normal_(self.embed.weight, std=0.02)
158158

159159
def forward(self, x):
160-
x = x.view(-1, x.size(2), x.size(3))
160+
x = x.view(-1, x.size(-2), x.size(-1))
161161
e = self.embed(x)
162162
h = e.sum(dim=2)
163163
for block in self.h:

0 commit comments

Comments
 (0)