29
29
30
30
namespace tensornet {
31
31
32
- BnTable::BnTable (const std::string& name, int shard_num, int self_shard_id, int bn_size, bool synchronized, float moment, uint64_t max_count)
32
+ BnTable::BnTable (const std::string& name, int shard_num, int self_shard_id, int bn_size, bool synchronized, float moment, uint64_t max_count, bool use_pctr_dnn_bn )
33
33
: shard_num_(shard_num)
34
34
, self_shard_id_(self_shard_id)
35
35
, name_(name)
36
36
, synchronized_(synchronized)
37
37
, moment_(moment)
38
38
, max_count_(max_count)
39
- , bn_size_(bn_size) {
39
+ , bn_size_(bn_size)
40
+ , use_pctr_dnn_bn_(use_pctr_dnn_bn){
40
41
total_sum_.setZero (bn_size);
41
42
total_sum_err_.setZero (bn_size);
42
43
total_squared_sum_.setZero (bn_size);
@@ -104,9 +105,13 @@ void BnTable::TotalSumAcc(Eigen::ArrayXd acc){
104
105
105
106
std::tuple<Eigen::ArrayXf,Eigen::ArrayXf> BnTable::GetMoments () {
106
107
Eigen::ArrayXf global_mean = DivideNoNan (total_sum_, total_count_);
107
- Eigen::ArrayXf global_squared_mean = DivideNoNan (total_squared_sum_, total_count_);
108
- Eigen::ArrayXf global_var = (global_squared_mean - global_mean.square ()).max (0.0 );
109
- return std::make_tuple (global_mean, global_var);
108
+ if (use_pctr_dnn_bn_){
109
+ return std::make_tuple (global_mean, total_squared_sum_.cast <float >());
110
+ } else {
111
+ Eigen::ArrayXf global_squared_mean = DivideNoNan (total_squared_sum_, total_count_);
112
+ Eigen::ArrayXf global_var = (global_squared_mean - global_mean.square ()).max (0.0 );
113
+ return std::make_tuple (global_mean, global_var);
114
+ }
110
115
}
111
116
112
117
void BnTable::GetStatistics (const BnStatisticsPullRequest* req, butil::IOBuf& bn_statistics_buf, BnStatisticsPullResponse* resp) {
@@ -120,9 +125,9 @@ void BnTable::GetIncStatistics(butil::IOBuf& bn_statistics_buf) {
120
125
bn_statistics_buf.append (inc_sum_.data (), inc_sum_.size () * sizeof (double ));
121
126
bn_statistics_buf.append (inc_squared_sum_.data (), inc_squared_sum_.size () * sizeof (double ));
122
127
bn_statistics_buf.append (inc_count_.data (), inc_count_.size () * sizeof (double ));
123
- inc_sum_.setZero ();
124
- inc_squared_sum_.setZero ();
125
- inc_count_.setZero ();
128
+ inc_sum_.setZero ();
129
+ inc_squared_sum_.setZero ();
130
+ inc_count_.setZero ();
126
131
}
127
132
128
133
@@ -167,6 +172,11 @@ void BnTable::Load(const std::string& filepath) {
167
172
in_stream.iword (SERIALIZE_FMT_ID) = SF_BIN;
168
173
169
174
int bn_size = 0 ;
175
+ bool use_pctr_dnn_bn = false ;
176
+
177
+ in_stream.read (reinterpret_cast <char *>(&use_pctr_dnn_bn), sizeof (use_pctr_dnn_bn));
178
+ CHECK_EQ (use_pctr_dnn_bn_, use_pctr_dnn_bn) << " bn calculate logic should be same, before use pctrdnn is " << use_pctr_dnn_bn;
179
+
170
180
in_stream.read (reinterpret_cast <char *>(&bn_size), sizeof (bn_size));
171
181
172
182
for ( int i = 0 ; i < bn_size; i++) {
@@ -187,6 +197,7 @@ void BnTable::Save(const std::string& filepath) {
187
197
boost::iostreams::stream<FileWriterSink> out_stream (writer_sink);
188
198
out_stream.iword (SERIALIZE_FMT_ID) = SF_BIN;
189
199
200
+ out_stream.write (reinterpret_cast <const char *>(&use_pctr_dnn_bn_), sizeof (use_pctr_dnn_bn_));
190
201
out_stream.write (reinterpret_cast <const char *>(&bn_size_), sizeof (bn_size_));
191
202
192
203
for ( int i = 0 ; i < bn_size_; i++) {
@@ -217,8 +228,8 @@ uint32_t BnTableRegistry::Register(BnTable* table) {
217
228
return table_handle;
218
229
}
219
230
220
- BnTable* CreateBnTable (const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count) {
221
- BnTable* table = new BnTable (name, shard_num, self_shard_id, bn_size, sync , moment, max_count);
231
+ BnTable* CreateBnTable (const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count, bool use_pctr_dnn_bn ) {
232
+ BnTable* table = new BnTable (name, shard_num, self_shard_id, bn_size, sync , moment, max_count, use_pctr_dnn_bn );
222
233
223
234
table->SetHandle (BnTableRegistry::Instance ()->Register (table));
224
235
0 commit comments