Skip to content

Commit c495117

Browse files
committed
Format sort and refactor
1 parent 4a25fd1 commit c495117

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

Diff for: im2latex/models/cnn_lstm.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import argparse
2-
from typing import Any, Dict, Tuple, Union
32
import math
3+
from typing import Any, Dict, Tuple, Union
4+
45
import torch
56
import torch.nn as nn
67
import torch.nn.functional as F
@@ -122,7 +123,7 @@ def forward(self, imgs, formulas):
122123
if logits:
123124
tgt = torch.argmax(torch.log(logits[-1]), dim=1, keepdim=True)
124125
# ont step decoding
125-
dec_states, O_t, logit = self.step_decoding(dec_states, o_t, encoded_imgs, tgt)
126+
dec_states, _, logit = self.step_decoding(dec_states, o_t, encoded_imgs, tgt)
126127
logits.append(logit)
127128
logits = torch.stack(logits, dim=1) # [B, MAX_LEN, VOCAB_SIZE]
128129
return logits.permute(0, 2, 1) # (B, C, Sy)
@@ -148,7 +149,7 @@ def step_decoding(
148149
c_t = self.dropout(c_t)
149150

150151
# context_t : [B, C]
151-
context_t, attn_scores = self._get_attn(enc_out, h_t)
152+
context_t, _ = self._get_attn(enc_out, h_t)
152153

153154
# [B, dec_rnn_h]
154155
o_t = self.W_3(torch.cat([h_t, context_t], dim=1)).tanh()

0 commit comments

Comments
 (0)