Skip to content

Commit

Permalink
fix a bug in graph encoder
Browse files · Browse the repository at this point in the history
  • Loading branch information
yerfor committed May 25, 2022
1 parent 7bbc7c0 commit 64dc2f2
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions modules/tts/syntaspeech/syntactic_graph_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def word_forward(self, graph_lst, word_encoding, etypes_lst):
gcc2_out = self.ggc_2(batched_graph, gcc2_out, batched_etypes)
if self.skip_connect:
assert self.in_dim == self.hid_dim and self.hid_dim == self.out_dim
gcc2_out = inp + gcc1_out + gcc1_out
gcc2_out = inp + gcc1_out + gcc2_out

word_len = torch.tensor([g.num_nodes() for g in graph_lst]).reshape([-1])
max_len = max(word_len)
Expand Down Expand Up @@ -150,7 +150,7 @@ def forward(self, graph_lst, ph_encoding, ph2word, etypes_lst, return_word_encod
gcc1_out = self.ggc_1(batched_graph, inp, batched_etypes)
gcc2_out = self.ggc_2(batched_graph, gcc1_out, batched_etypes) # [num_nodes_in_batch, hin]
# skip connection
gcc2_out = inp + gcc1_out + gcc1_out # [n_nodes, hid]
gcc2_out = inp + gcc1_out + gcc2_out # [n_nodes, hid]

output = torch.zeros([bs * t_w, hid]).to(gcc2_out.device)
output[has_word_row_idx] = gcc2_out
Expand Down

0 comments on commit 64dc2f2

Please sign in to comment.