1
1
import argparse
2
- from typing import Any , Dict , Tuple , Union
3
2
import math
3
+ from typing import Any , Dict , Tuple , Union
4
+
4
5
import torch
5
6
import torch .nn as nn
6
7
import torch .nn .functional as F
@@ -122,7 +123,7 @@ def forward(self, imgs, formulas):
122
123
if logits :
123
124
tgt = torch .argmax (torch .log (logits [- 1 ]), dim = 1 , keepdim = True )
124
125
# ont step decoding
125
- dec_states , O_t , logit = self .step_decoding (dec_states , o_t , encoded_imgs , tgt )
126
+ dec_states , _ , logit = self .step_decoding (dec_states , o_t , encoded_imgs , tgt )
126
127
logits .append (logit )
127
128
logits = torch .stack (logits , dim = 1 ) # [B, MAX_LEN, VOCAB_SIZE]
128
129
return logits .permute (0 , 2 , 1 ) # (B, C, Sy)
@@ -148,7 +149,7 @@ def step_decoding(
148
149
c_t = self .dropout (c_t )
149
150
150
151
# context_t : [B, C]
151
- context_t , attn_scores = self ._get_attn (enc_out , h_t )
152
+ context_t , _ = self ._get_attn (enc_out , h_t )
152
153
153
154
# [B, dec_rnn_h]
154
155
o_t = self .W_3 (torch .cat ([h_t , context_t ], dim = 1 )).tanh ()
0 commit comments