Skip to content

Commit 21cc903

Browse files
committed
Add predict-lstm-attend
1 parent 5e47832 commit 21cc903

File tree

5 files changed

+220
-24
lines changed

5 files changed

+220
-24
lines changed

Makefile

+6-2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ bin = \
1515
learn-residual \
1616
predict-residual \
1717
learn-lstm-attend \
18+
predict-lstm-attend \
1819
libnn.a
1920

2021
all: $(bin)
@@ -56,10 +57,13 @@ predict-lstm2d: lstm.o predict-lstm2d.o pred.o
5657
learn-residual: learn-residual.o residual.o pred.o nn.o
5758
$(CXX) $(CXXFLAGS) -o $@ $^ -lautodiff -lspeech -lopt -lla -lebt -lblas
5859

59-
predict-residual: residual.o predict-residual.o pred.o
60+
predict-residual: predict-residual.o residual.o pred.o
6061
$(CXX) $(CXXFLAGS) -o $@ $^ -lautodiff -lspeech -lopt -lla -lebt -lblas
6162

62-
learn-lstm-attend: lstm.o learn-lstm-attend.o pred.o nn.o
63+
learn-lstm-attend: learn-lstm-attend.o lstm.o attention.o pred.o nn.o
64+
$(CXX) $(CXXFLAGS) -o $@ $^ -lautodiff -lspeech -lopt -lla -lebt -lblas
65+
66+
predict-lstm-attend: predict-lstm-attend.o lstm.o attention.o pred.o nn.o
6367
$(CXX) $(CXXFLAGS) -o $@ $^ -lautodiff -lspeech -lopt -lla -lebt -lblas
6468

6569
nn.o: nn.h

attention.cc

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#include "nn/attention.h"
2+
3+
namespace attention {
4+
5+
attention_nn_t attend(
6+
std::shared_ptr<autodiff::op_t> const& hs,
7+
std::shared_ptr<autodiff::op_t> const& target)
8+
{
9+
attention_nn_t att;
10+
11+
att.attention = autodiff::softmax(autodiff::mul(hs, target));
12+
att.context = autodiff::lmul(att.attention, hs);
13+
14+
return att;
15+
}
16+
17+
}

attention.h

+19
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef ATTENTION_H
2+
#define ATTENTION_H
3+
4+
#include "autodiff/autodiff.h"
5+
6+
namespace attention {
7+
8+
struct attention_nn_t {
9+
std::shared_ptr<autodiff::op_t> attention;
10+
std::shared_ptr<autodiff::op_t> context;
11+
};
12+
13+
attention_nn_t attend(
14+
std::shared_ptr<autodiff::op_t> const& inputs,
15+
std::shared_ptr<autodiff::op_t> const& target);
16+
17+
}
18+
19+
#endif

learn-lstm-attend.cc

+3-22
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "nn/lstm.h"
99
#include "nn/pred.h"
1010
#include "nn/nn.h"
11+
#include "nn/attention.h"
1112
#include <random>
1213

1314
struct learning_env {
@@ -52,15 +53,6 @@ struct learning_env {
5253

5354
};
5455

55-
struct attention_nn_t {
56-
std::shared_ptr<autodiff::op_t> attention;
57-
std::shared_ptr<autodiff::op_t> context;
58-
};
59-
60-
attention_nn_t attend(
61-
std::shared_ptr<autodiff::op_t> const& inputs,
62-
std::shared_ptr<autodiff::op_t> const& target);
63-
6456
int main(int argc, char *argv[])
6557
{
6658
ebt::ArgumentSpec spec {
@@ -210,11 +202,11 @@ void learning_env::run()
210202
}
211203

212204
std::shared_ptr<autodiff::op_t> hs = autodiff::row_cat(nn.layer.back().output);
213-
std::vector<attention_nn_t> atts;
205+
std::vector<attention::attention_nn_t> atts;
214206
std::vector<std::shared_ptr<autodiff::op_t>> context;
215207

216208
for (int i = 0; i < nn.layer.back().output.size(); ++i) {
217-
atts.push_back(attend(hs, nn.layer.back().output[i]));
209+
atts.push_back(attention::attend(hs, nn.layer.back().output[i]));
218210
context.push_back(atts.back().context);
219211
}
220212

@@ -317,14 +309,3 @@ void learning_env::run()
317309
opt_data_ofs.close();
318310
}
319311

320-
attention_nn_t attend(
321-
std::shared_ptr<autodiff::op_t> const& hs,
322-
std::shared_ptr<autodiff::op_t> const& target)
323-
{
324-
attention_nn_t att;
325-
326-
att.attention = autodiff::softmax(autodiff::mul(hs, target));
327-
att.context = autodiff::lmul(att.attention, hs);
328-
329-
return att;
330-
}

predict-lstm-attend.cc

+175
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
#include "ebt/ebt.h"
2+
#include "speech/speech.h"
3+
#include "nn/lstm.h"
4+
#include "nn/pred.h"
5+
#include <fstream>
6+
#include "nn/attention.h"
7+
8+
struct prediction_env {
9+
10+
std::ifstream frame_batch;
11+
12+
lstm::dblstm_feat_param_t param;
13+
lstm::dblstm_feat_nn_t nn;
14+
nn::pred_param_t pred_param;
15+
rnn::pred_nn_t pred_nn;
16+
17+
std::vector<std::string> label;
18+
19+
double rnndrop_prob;
20+
int subsample_freq;
21+
int subsample_shift;
22+
23+
std::unordered_map<std::string, std::string> args;
24+
25+
prediction_env(std::unordered_map<std::string, std::string> args);
26+
27+
void run();
28+
29+
};
30+
31+
int main(int argc, char *argv[])
32+
{
33+
ebt::ArgumentSpec spec {
34+
"predict-lstm",
35+
"Predict frames with LSTM",
36+
{
37+
{"frame-batch", "", true},
38+
{"param", "", true},
39+
{"label", "", true},
40+
{"rnndrop-prob", "", false},
41+
{"logprob", "", false},
42+
{"subsample-freq", "", false},
43+
{"subsample-shift", "", false}
44+
}
45+
};
46+
47+
if (argc == 1) {
48+
ebt::usage(spec);
49+
exit(1);
50+
}
51+
52+
auto args = ebt::parse_args(argc, argv, spec);
53+
54+
std::cout << args << std::endl;
55+
56+
prediction_env env { args };
57+
58+
env.run();
59+
60+
return 0;
61+
}
62+
63+
prediction_env::prediction_env(std::unordered_map<std::string, std::string> args)
64+
: args(args)
65+
{
66+
frame_batch.open(args.at("frame-batch"));
67+
68+
std::ifstream param_ifs { args.at("param") };
69+
param = lstm::load_dblstm_feat_param(param_ifs);
70+
pred_param = nn::load_pred_param(param_ifs);
71+
param_ifs.close();
72+
73+
label = speech::load_label_set(args.at("label"));
74+
75+
if (ebt::in(std::string("rnndrop-prob"), args)) {
76+
rnndrop_prob = std::stod(args.at("rnndrop-prob"));
77+
}
78+
79+
subsample_freq = 1;
80+
if (ebt::in(std::string("subsample-freq"), args)) {
81+
subsample_freq = std::stoi(args.at("subsample-freq"));
82+
}
83+
84+
subsample_shift = 0;
85+
if (ebt::in(std::string("subsample-shift"), args)) {
86+
subsample_shift = std::stoi(args.at("subsample-shift"));
87+
}
88+
}
89+
90+
void prediction_env::run()
91+
{
92+
int i = 1;
93+
94+
while (1) {
95+
std::vector<std::vector<double>> frames;
96+
97+
frames = speech::load_frame_batch(frame_batch);
98+
99+
if (!frame_batch) {
100+
break;
101+
}
102+
103+
autodiff::computation_graph graph;
104+
std::vector<std::shared_ptr<autodiff::op_t>> inputs;
105+
106+
for (int i = 0; i < frames.size(); ++i) {
107+
inputs.push_back(graph.var(la::vector<double>(frames[i])));
108+
}
109+
110+
std::vector<std::shared_ptr<autodiff::op_t>> subsampled_inputs
111+
= rnn::subsample_input(inputs, subsample_freq, subsample_shift);
112+
113+
nn = lstm::make_dblstm_feat_nn(graph, param, subsampled_inputs);
114+
115+
if (ebt::in(std::string("rnndrop-prob"), args)) {
116+
lstm::apply_mask(nn, param, rnndrop_prob);
117+
}
118+
119+
std::shared_ptr<autodiff::op_t> hs = autodiff::row_cat(nn.layer.back().output);
120+
std::vector<attention::attention_nn_t> atts;
121+
std::vector<std::shared_ptr<autodiff::op_t>> context;
122+
123+
for (int i = 0; i < nn.layer.back().output.size(); ++i) {
124+
atts.push_back(attention::attend(hs, nn.layer.back().output[i]));
125+
context.push_back(atts.back().context);
126+
}
127+
128+
pred_nn = rnn::make_pred_nn(graph, pred_param, context);
129+
130+
std::vector<std::shared_ptr<autodiff::op_t>> upsampled_output
131+
= rnn::upsample_output(pred_nn.logprob, subsample_freq, subsample_shift, frames.size());
132+
133+
assert(upsampled_output.size() == frames.size());
134+
135+
auto topo_order = autodiff::topo_order(upsampled_output);
136+
autodiff::eval(topo_order, autodiff::eval_funcs);
137+
138+
std::cout << i << ".phn" << std::endl;
139+
140+
if (ebt::in(std::string("logprob"), args)) {
141+
for (int t = 0; t < upsampled_output.size(); ++t) {
142+
auto& pred = autodiff::get_output<la::vector<double>>(upsampled_output[t]);
143+
144+
std::cout << pred(0);
145+
146+
for (int j = 1; j < pred.size(); ++j) {
147+
std::cout << " " << pred(j);
148+
}
149+
150+
std::cout << std::endl;
151+
}
152+
} else {
153+
for (int t = 0; t < upsampled_output.size(); ++t) {
154+
auto& pred = autodiff::get_output<la::vector<double>>(upsampled_output[t]);
155+
156+
int argmax = -1;
157+
double max = -std::numeric_limits<double>::infinity();
158+
159+
for (int j = 0; j < pred.size(); ++j) {
160+
if (pred(j) > max) {
161+
max = pred(j);
162+
argmax = j;
163+
}
164+
}
165+
166+
std::cout << label[argmax] << std::endl;
167+
}
168+
}
169+
170+
std::cout << "." << std::endl;
171+
172+
++i;
173+
}
174+
}
175+

0 commit comments

Comments
 (0)