diff --git a/modules/tts/syntaspeech/syntactic_graph_encoder.py b/modules/tts/syntaspeech/syntactic_graph_encoder.py
index 0260b31..134d803 100644
--- a/modules/tts/syntaspeech/syntactic_graph_encoder.py
+++ b/modules/tts/syntaspeech/syntactic_graph_encoder.py
@@ -117,7 +117,7 @@ def word_forward(self, graph_lst, word_encoding, etypes_lst):
         gcc2_out = self.ggc_2(batched_graph, gcc2_out, batched_etypes)
         if self.skip_connect:
             assert self.in_dim == self.hid_dim and self.hid_dim == self.out_dim
-            gcc2_out = inp + gcc1_out + gcc1_out
+            gcc2_out = inp + gcc1_out + gcc2_out
 
         word_len = torch.tensor([g.num_nodes() for g in graph_lst]).reshape([-1])
         max_len = max(word_len)
@@ -150,7 +150,7 @@ def forward(self, graph_lst, ph_encoding, ph2word, etypes_lst, return_word_encod
         gcc1_out = self.ggc_1(batched_graph, inp, batched_etypes)
         gcc2_out = self.ggc_2(batched_graph, gcc1_out, batched_etypes) # [num_nodes_in_batch, hin]
         # skip connection
-        gcc2_out = inp + gcc1_out + gcc1_out # [n_nodes, hid]
+        gcc2_out = inp + gcc1_out + gcc2_out # [n_nodes, hid]
 
         output = torch.zeros([bs * t_w, hid]).to(gcc2_out.device)
         output[has_word_row_idx] = gcc2_out
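
Both hunks fix the same bug: the skip connection summed `gcc1_out` twice and never added `gcc2_out`, so the second gated graph conv layer's output was silently discarded from the residual. A minimal, self-contained sketch of the corrected pattern is below; `DummyConv` is a hypothetical stand-in for the encoder's GGC layers so the snippet runs without DGL or batched graphs.

```python
import torch
import torch.nn as nn

class DummyConv(nn.Module):
    """Hypothetical stand-in for a gated graph conv layer (ggc_1/ggc_2)."""
    def __init__(self, dim):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        return torch.relu(self.proj(x))

dim = 8
ggc_1, ggc_2 = DummyConv(dim), DummyConv(dim)
inp = torch.randn(5, dim)  # [n_nodes, hid]; in/hid/out dims must match,
                           # as the assert in word_forward enforces

gcc1_out = ggc_1(inp)
gcc2_out = ggc_2(gcc1_out)

# Before the fix: double-counts layer 1, drops layer 2's output entirely.
buggy = inp + gcc1_out + gcc1_out
# After the fix: the residual sums the input and both layers' outputs.
fixed = inp + gcc1_out + gcc2_out
```

With the buggy sum, downstream consumers of `gcc2_out` saw a value that did not depend on `ggc_2` at all, so the second layer contributed nothing to the encoder's output through this path.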