@@ -25,77 +25,45 @@ namespace lstm_frame {
25
25
return std::make_shared<tensor_tree::vertex>(result);
26
26
}
27
27
28
- std::shared_ptr<tensor_tree::vertex> make_hypercolumn_tensor_tree (int layer)
29
- {
30
- tensor_tree::vertex result { " nil" };
31
-
32
- lstm::multilayer_lstm_tensor_tree_factory factory {
33
- std::make_shared<lstm::bi_lstm_tensor_tree_factory>(
34
- lstm::bi_lstm_tensor_tree_factory {
35
- std::make_shared<lstm::dyer_lstm_tensor_tree_factory>(
36
- lstm::dyer_lstm_tensor_tree_factory{})
37
- }),
38
- layer
39
- };
40
-
41
- tensor_tree::vertex hypercolumn { " nil" };
42
-
43
- hypercolumn.children .push_back (factory ());
44
-
45
- for (int i = 0 ; i < layer + 1 ; ++i) {
46
- hypercolumn.children .push_back (tensor_tree::make_tensor (" hypercolumn weight" ));
47
- }
48
- hypercolumn.children .push_back (tensor_tree::make_tensor (" hypercolumn bias" ));
49
-
50
- result.children .push_back (std::make_shared<tensor_tree::vertex>(hypercolumn));
51
-
52
- result.children .push_back (tensor_tree::make_tensor (" softmax weight" ));
53
- result.children .push_back (tensor_tree::make_tensor (" softmax bias" ));
54
-
55
- return std::make_shared<tensor_tree::vertex>(result);
56
- }
57
-
58
28
std::shared_ptr<lstm::transcriber>
59
29
make_transcriber (
60
30
int layer,
61
31
double dropout,
62
32
std::default_random_engine *gen)
63
33
{
64
- std::shared_ptr<lstm::step_transcriber> step;
65
-
66
- if (dropout != 0.0 ) {
67
- assert (gen != nullptr );
68
-
69
- step = std::make_shared<lstm::input_dropout_transcriber>(
70
- lstm::input_dropout_transcriber {
71
- *gen, dropout,
72
- std::make_shared<lstm::dyer_lstm_step_transcriber>(
73
- lstm::dyer_lstm_step_transcriber{})
74
- });
75
- } else {
76
- step = std::make_shared<lstm::dyer_lstm_step_transcriber>(
34
+ std::shared_ptr<lstm::step_transcriber> step
35
+ = std::make_shared<lstm::dyer_lstm_step_transcriber>(
77
36
lstm::dyer_lstm_step_transcriber{});
78
- }
79
37
80
38
lstm::layered_transcriber result;
81
39
82
40
for (int i = 0 ; i < layer; ++i) {
83
- std::shared_ptr<lstm::transcriber> trans;
41
+ std::shared_ptr<lstm::transcriber> f_trans;
42
+ std::shared_ptr<lstm::transcriber> b_trans;
84
43
85
44
if (dropout != 0.0 ) {
86
- trans = std::make_shared<lstm::lstm_transcriber>(
87
- lstm::lstm_transcriber {
88
- std::make_shared<lstm::output_dropout_transcriber>(
89
- lstm::output_dropout_transcriber {
90
- *gen, dropout, step })
91
- });
45
+ f_trans = std::make_shared<lstm::lstm_transcriber>(
46
+ lstm::lstm_transcriber { step });
47
+ f_trans = std::make_shared<lstm::input_dropout_transcriber>(
48
+ lstm::input_dropout_transcriber { f_trans, dropout, *gen });
49
+ f_trans = std::make_shared<lstm::output_dropout_transcriber>(
50
+ lstm::output_dropout_transcriber { f_trans, dropout, *gen });
51
+
52
+ b_trans = std::make_shared<lstm::lstm_transcriber>(
53
+ lstm::lstm_transcriber { step, true });
54
+ b_trans = std::make_shared<lstm::input_dropout_transcriber>(
55
+ lstm::input_dropout_transcriber { b_trans, dropout, *gen });
56
+ b_trans = std::make_shared<lstm::output_dropout_transcriber>(
57
+ lstm::output_dropout_transcriber { b_trans, dropout, *gen });
92
58
} else {
93
- trans = std::make_shared<lstm::lstm_transcriber>(
59
+ f_trans = std::make_shared<lstm::lstm_transcriber>(
94
60
lstm::lstm_transcriber { step });
61
+ b_trans = std::make_shared<lstm::lstm_transcriber>(
62
+ lstm::lstm_transcriber { step, true });
95
63
}
96
64
97
- trans = std::make_shared<lstm::bi_transcriber>(
98
- lstm::bi_transcriber { trans });
65
+ std::shared_ptr<lstm::transcriber> trans = std::make_shared<lstm::bi_transcriber>(
66
+ lstm::bi_transcriber { f_trans, b_trans });
99
67
100
68
result.layer .push_back (trans);
101
69
}
@@ -109,41 +77,39 @@ namespace lstm_frame {
109
77
double dropout,
110
78
std::default_random_engine *gen)
111
79
{
112
- std::shared_ptr<lstm::step_transcriber> step;
113
-
114
- if (dropout != 0.0 ) {
115
- assert (gen != nullptr );
116
-
117
- step = std::make_shared<lstm::input_dropout_transcriber>(
118
- lstm::input_dropout_transcriber {
119
- *gen, dropout,
120
- std::make_shared<lstm::dyer_lstm_step_transcriber>(
121
- lstm::dyer_lstm_step_transcriber{})
122
- });
123
- } else {
124
- step = std::make_shared<lstm::dyer_lstm_step_transcriber>(
80
+ std::shared_ptr<lstm::step_transcriber> step
81
+ = std::make_shared<lstm::dyer_lstm_step_transcriber>(
125
82
lstm::dyer_lstm_step_transcriber{});
126
- }
127
83
128
84
lstm::layered_transcriber result;
129
85
130
86
for (int i = 0 ; i < layer; ++i) {
131
- std::shared_ptr<lstm::transcriber> trans;
87
+ std::shared_ptr<lstm::transcriber> f_trans;
88
+ std::shared_ptr<lstm::transcriber> b_trans;
132
89
133
90
if (dropout != 0.0 ) {
134
- trans = std::make_shared<lstm::lstm_transcriber>(
135
- lstm::lstm_transcriber {
136
- std::make_shared<lstm::output_dropout_transcriber>(
137
- lstm::output_dropout_transcriber {
138
- *gen, dropout, step })
139
- });
91
+ f_trans = std::make_shared<lstm::lstm_transcriber>(
92
+ lstm::lstm_transcriber { step });
93
+ f_trans = std::make_shared<lstm::input_dropout_transcriber>(
94
+ lstm::input_dropout_transcriber { f_trans, dropout, *gen });
95
+ f_trans = std::make_shared<lstm::output_dropout_transcriber>(
96
+ lstm::output_dropout_transcriber { f_trans, dropout, *gen });
97
+
98
+ b_trans = std::make_shared<lstm::lstm_transcriber>(
99
+ lstm::lstm_transcriber { step, true });
100
+ b_trans = std::make_shared<lstm::input_dropout_transcriber>(
101
+ lstm::input_dropout_transcriber { b_trans, dropout, *gen });
102
+ b_trans = std::make_shared<lstm::output_dropout_transcriber>(
103
+ lstm::output_dropout_transcriber { b_trans, dropout, *gen });
140
104
} else {
141
- trans = std::make_shared<lstm::lstm_transcriber>(
105
+ f_trans = std::make_shared<lstm::lstm_transcriber>(
142
106
lstm::lstm_transcriber { step });
107
+ b_trans = std::make_shared<lstm::lstm_transcriber>(
108
+ lstm::lstm_transcriber { step, true });
143
109
}
144
110
145
- trans = std::make_shared<lstm::bi_transcriber>(
146
- lstm::bi_transcriber { trans });
111
+ std::shared_ptr<lstm::transcriber> trans = std::make_shared<lstm::bi_transcriber>(
112
+ lstm::bi_transcriber { f_trans, b_trans });
147
113
148
114
if (i != layer - 1 ) {
149
115
trans = std::make_shared<lstm::subsampled_transcriber>(
0 commit comments