Skip to content

Commit d9a3ad3

Browse files
committed
Fix bug occurred when compiling torchscript
1 parent 885da87 commit d9a3ad3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

im2latex/models/resnet_transformer.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ def predict(self, x: torch.Tensor) -> torch.Tensor:
176176
y = output_tokens[:, :Sy] # (B, Sy)
177177
output = self.decode(x, y) # (Sy, B, C)
178178
output = torch.argmax(output, dim=-1) # (Sy, B)
179-
output_tokens[:, Sy] = output[-1:] # Set the last output token
179+
output_tokens[:, Sy] = output[-1] # Set the last output token
180180

181181
# Early stopping of prediction loop to speed up prediction
182182
if ((output_tokens[:, Sy] == self.end_token) | (output_tokens[:, Sy] == self.padding_token)).all():

0 commit comments

Comments
 (0)