
Commit 19eac0d

Merge branch 'use_pctr_bn_logic' into 'master'

Use pctr bn logic

See merge request deep-learning/tensornet!19

2 parents: a1191ba + 441b14e

6 files changed (+76, -21 lines)

core/main/py_wrapper.cc (+2, -2)

```diff
@@ -136,10 +136,10 @@ PYBIND11_MODULE(_pywrap_tn, m) {
 
             return table->GetHandle();
         })
-        .def("create_bn_table", [](std::string name, uint32_t bn_size, bool sync, float moment, uint64_t max_count) {
+        .def("create_bn_table", [](std::string name, uint32_t bn_size, bool sync, float moment, uint64_t max_count, bool use_pctr_dnn_bn) {
            PsCluster* cluster = PsCluster::Instance();
 
-           BnTable* table = CreateBnTable(name, cluster->RankNum(), cluster->Rank(), bn_size, sync, moment, max_count);
+           BnTable* table = CreateBnTable(name, cluster->RankNum(), cluster->Rank(), bn_size, sync, moment, max_count, use_pctr_dnn_bn);
 
            return table->GetHandle();
        })
```
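
For orientation, here is a minimal sketch of how the extended binding is called from the Python side. The argument values are hypothetical; the signature and the `tn.core` module name are taken from this diff and from `normalization_layer.py` below.

```python
import tensornet as tn  # assumes a built tensornet exposing _pywrap_tn as tn.core

# Hypothetical values for illustration; the final flag selects the
# pctr-dnn variance logic added by this commit.
handle = tn.core.create_bn_table(
    "bn1",    # name
    128,      # bn_size: number of normalized features
    False,    # sync
    0.99,     # moment
    100000,   # max_count
    True,     # use_pctr_dnn_bn
)
```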

core/ps/table/bn_table.cc (+21, -10)

```diff
@@ -29,14 +29,15 @@
 
 namespace tensornet {
 
-BnTable::BnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool synchronized, float moment, uint64_t max_count)
+BnTable::BnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool synchronized, float moment, uint64_t max_count, bool use_pctr_dnn_bn)
     : shard_num_(shard_num)
     , self_shard_id_(self_shard_id)
     , name_(name)
     , synchronized_(synchronized)
     , moment_(moment)
     , max_count_(max_count)
-    , bn_size_(bn_size) {
+    , bn_size_(bn_size)
+    , use_pctr_dnn_bn_(use_pctr_dnn_bn) {
     total_sum_.setZero(bn_size);
     total_sum_err_.setZero(bn_size);
     total_squared_sum_.setZero(bn_size);
@@ -104,9 +105,13 @@ void BnTable::TotalSumAcc(Eigen::ArrayXd acc){
 
 std::tuple<Eigen::ArrayXf,Eigen::ArrayXf> BnTable::GetMoments() {
     Eigen::ArrayXf global_mean = DivideNoNan(total_sum_, total_count_);
-    Eigen::ArrayXf global_squared_mean = DivideNoNan(total_squared_sum_, total_count_);
-    Eigen::ArrayXf global_var = (global_squared_mean - global_mean.square()).max(0.0);
-    return std::make_tuple(global_mean, global_var);
+    if (use_pctr_dnn_bn_) {
+        return std::make_tuple(global_mean, total_squared_sum_.cast<float>());
+    } else {
+        Eigen::ArrayXf global_squared_mean = DivideNoNan(total_squared_sum_, total_count_);
+        Eigen::ArrayXf global_var = (global_squared_mean - global_mean.square()).max(0.0);
+        return std::make_tuple(global_mean, global_var);
+    }
 }
 
 void BnTable::GetStatistics(const BnStatisticsPullRequest* req, butil::IOBuf& bn_statistics_buf, BnStatisticsPullResponse* resp) {
@@ -120,9 +125,9 @@ void BnTable::GetIncStatistics(butil::IOBuf& bn_statistics_buf) {
     bn_statistics_buf.append(inc_sum_.data(), inc_sum_.size() * sizeof(double));
     bn_statistics_buf.append(inc_squared_sum_.data(), inc_squared_sum_.size() * sizeof(double));
     bn_statistics_buf.append(inc_count_.data(), inc_count_.size() * sizeof(double));
-      inc_sum_.setZero();
-      inc_squared_sum_.setZero();
-      inc_count_.setZero();
+    inc_sum_.setZero();
+    inc_squared_sum_.setZero();
+    inc_count_.setZero();
 }
 
 
@@ -167,6 +172,11 @@ void BnTable::Load(const std::string& filepath) {
     in_stream.iword(SERIALIZE_FMT_ID) = SF_BIN;
 
     int bn_size = 0;
+    bool use_pctr_dnn_bn = false;
+
+    in_stream.read(reinterpret_cast<char*>(&use_pctr_dnn_bn), sizeof(use_pctr_dnn_bn));
+    CHECK_EQ(use_pctr_dnn_bn_, use_pctr_dnn_bn) << "bn calculation logic should be the same; checkpoint was saved with use_pctr_dnn_bn = " << use_pctr_dnn_bn;
+
     in_stream.read(reinterpret_cast<char*>(&bn_size), sizeof(bn_size));
 
     for( int i = 0; i < bn_size; i++) {
@@ -187,6 +197,7 @@ void BnTable::Save(const std::string& filepath) {
     boost::iostreams::stream<FileWriterSink> out_stream(writer_sink);
     out_stream.iword(SERIALIZE_FMT_ID) = SF_BIN;
 
+    out_stream.write(reinterpret_cast<const char*>(&use_pctr_dnn_bn_), sizeof(use_pctr_dnn_bn_));
    out_stream.write(reinterpret_cast<const char*>(&bn_size_), sizeof(bn_size_));
 
    for( int i = 0; i < bn_size_; i++) {
@@ -217,8 +228,8 @@ uint32_t BnTableRegistry::Register(BnTable* table) {
     return table_handle;
 }
 
-BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count) {
-    BnTable* table = new BnTable(name, shard_num, self_shard_id, bn_size, sync, moment, max_count);
+BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count, bool use_pctr_dnn_bn) {
+    BnTable* table = new BnTable(name, shard_num, self_shard_id, bn_size, sync, moment, max_count, use_pctr_dnn_bn);
 
     table->SetHandle(BnTableRegistry::Instance()->Register(table));
```
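
The behavioral core of the change is `GetMoments()`: in pctr-dnn mode the accumulated statistic is returned directly as the variance instead of being derived as E[x^2] - E[x]^2. A rough NumPy sketch of the two paths (illustration only, not the C++ code; `DivideNoNan` is approximated with a masked divide):

```python
import numpy as np

def get_moments(total_sum, total_squared_sum, total_count, use_pctr_dnn_bn):
    """Illustrative mirror of BnTable::GetMoments()."""
    # DivideNoNan: elementwise division that yields 0 where count is 0.
    mean = np.divide(total_sum, total_count,
                     out=np.zeros_like(total_sum), where=total_count != 0)
    if use_pctr_dnn_bn:
        # pctr-dnn path: the pushed statistic is already the accumulated
        # squared deviation sum((x - moving_mean)^2); return it as-is.
        return mean.astype(np.float32), total_squared_sum.astype(np.float32)
    # Classic path: var = E[x^2] - E[x]^2, clamped at 0 for numerical safety.
    squared_mean = np.divide(total_squared_sum, total_count,
                             out=np.zeros_like(total_squared_sum),
                             where=total_count != 0)
    var = np.maximum(squared_mean - mean ** 2, 0.0)
    return mean.astype(np.float32), var.astype(np.float32)
```

Note also that `Save()` now writes the flag as a checkpoint header and `Load()` verifies it with `CHECK_EQ`, so a table saved under one variance scheme cannot silently be restored under the other.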

core/ps/table/bn_table.h (+4, -3)

```diff
@@ -31,7 +31,7 @@ namespace tensornet {
 
 class BnTable {
 public:
-    BnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count);
+    BnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count, bool use_pctr_dnn_bn);
 
     ~BnTable() = default;
 
@@ -67,7 +67,8 @@ class BnTable {
     uint32_t handle_ = 0;
     std::string name_;
     uint32_t bn_size_ = 0;
-    bool synchronized_ = false;
+    bool synchronized_ = false;
+    bool use_pctr_dnn_bn_ = false;
     float moment_ = 0.0;
     uint64_t max_count_ = 0;
     Eigen::ArrayXd total_sum_;
@@ -103,7 +104,7 @@ class BnTableRegistry {
     std::vector<BnTable*> tables_;
 };
 
-BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count);
+BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count, bool use_pctr_dnn_bn);
 
 } // namespace tensornet
```

tensornet/layers/__init__.py (+2)

```diff
@@ -14,4 +14,6 @@
 
 from .embedding_features import EmbeddingFeatures
 from .sequence_embedding_features import SequenceEmbeddingFeatures
+from .normalization_layer import TNBatchNormalizationBase
 from .normalization_layer import TNBatchNormalization
+from .normalization_layer import PCTRDNNBatchNormalization
```

tensornet/layers/normalization_layer.py (+45, -4)

```diff
@@ -12,7 +12,7 @@
 from tensorflow.python.ops import variable_scope, array_ops
 
 
-class TNBatchNormalization(Layer):
+class TNBatchNormalizationBase(Layer):
     """
     Reference: https://github.com/keras-team/keras/blob/v3.5.0/keras/src/layers/normalization/batch_normalization.py
 
@@ -25,8 +25,10 @@ class TNBatchNormalization(Layer):
     sync_freq: frequency at which bn statistics are sent to other ranks (in batches). Should only be used when 'synchronized' is True.
     max_count: threshold to avoid bn statistics overflow. Note: this counts records, not batches; it is an empirical parameter that should be tuned to the size of the training data.
     """
+    _USE_PCTR_DNN_BN = False
+
     def __init__(self, center=True, scale=True, epsilon=1e-5, momentum=0.99, name=None, synchronized=False, sync_freq=1, max_count=100000, **kwargs):
-        super(TNBatchNormalization, self).__init__(**kwargs)
+        super(TNBatchNormalizationBase, self).__init__(**kwargs)
         self.center = center
         self.scale = scale
         self.epsilon = epsilon
@@ -97,8 +99,7 @@ def build(self, input_shape):
                                   initializer=self.local_squared_num_initializer,
                                   trainable=False)
 
-        self.bn_table_handle = tn.core.create_bn_table(self.name, self.apply_axis[0], self.synchronized, self.momentum, self.max_count)
-
+        self.bn_table_handle = tn.core.create_bn_table(self.name, self.apply_axis[0], self.synchronized, self.momentum, self.max_count, self._USE_PCTR_DNN_BN)
 
     def call(self, inputs, training=None):
 
@@ -147,3 +148,43 @@ def save_bn_table(self, filepath):
     def load_bn_table(self, filepath):
         return tn.core.load_bn_table(self.bn_table_handle, filepath)
 
+
+class TNBatchNormalization(TNBatchNormalizationBase):
+    """
+    Accumulates incremental count, sum, and squared sum;
+    uses squared_sum / count - (sum / count)^2 as the variance.
+    """
+
+
+class PCTRDNNBatchNormalization(TNBatchNormalizationBase):
+    """
+    Accumulates incremental count and sum; accumulates the incremental
+    squared deviation (data - mean)^2 as the variance statistic.
+    """
+    _USE_PCTR_DNN_BN = True
+
+    def call(self, inputs, training=None):
+
+        @tf.function
+        def _increment_and_check_count():
+            self.batch_counter.assign_add(1)
+            if tf.equal(self.batch_counter, self.sync_freq):
+                self.bn_statistics_push(True)
+                self.batch_counter.assign(0)
+            else:
+                self.bn_statistics_push(False)
+
+        self.update_moments()
+        mean = self.moving_mean
+        var = self.moving_variance
+
+        if training:
+            local_count_sample = tf.ones_like(inputs, name="count")
+            self.local_sum.assign(tf.reduce_sum(inputs, axis=self.moments_axes))
+            self.local_squared_sum.assign(tf.reduce_sum(tf.square(inputs - self.moving_mean), axis=self.moments_axes))
+            self.local_count.assign(tf.reduce_sum(local_count_sample, axis=self.moments_axes))
+            if self.synchronized:
+                _increment_and_check_count()
+            else:
+                self.bn_statistics_push(False)
+
+        outputs = tf.nn.batch_normalization(x=inputs, mean=mean, variance=var, offset=self.beta, scale=self.gamma, variance_epsilon=self.epsilon)
+
+        return outputs
```
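
A hedged usage sketch of the new layer: the model below is hypothetical, and `PCTRDNNBatchNormalization` takes the same constructor arguments as the existing `TNBatchNormalization`.

```python
import tensorflow as tf
import tensornet as tn

# Hypothetical two-layer model; only the normalization layer differs
# from existing TNBatchNormalization usage. The variance bookkeeping
# behind create_bn_table is selected by the class-level flag.
inputs = tf.keras.Input(shape=(128,))
x = tn.layers.PCTRDNNBatchNormalization(synchronized=True, sync_freq=4)(inputs)
outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)
model = tf.keras.Model(inputs, outputs)
```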

tensornet/model/Model.py (+2, -2)

```diff
@@ -177,7 +177,7 @@ def save_weights(self, filepath, overwrite=True, save_format=None, dt="", root=T
             layer.save_sparse_table(cp_dir, mode)
         elif isinstance(layer, tn.layers.SequenceEmbeddingFeatures):
             layer.save_sparse_table(cp_dir, mode)
-        elif isinstance(layer, tn.layers.TNBatchNormalization):
+        elif isinstance(layer, tn.layers.TNBatchNormalizationBase):
             if tn.core.self_shard_id() == 0:
                 layer.bn_statistics_pull()
                 layer.save_bn_table(cp_dir)
@@ -223,7 +223,7 @@ def load_weights(self, filepath, by_name=False, skip_mismatch=False, include_dt=
             layer.load_sparse_table(cp_dir, mode)
         elif isinstance(layer, tn.layers.SequenceEmbeddingFeatures):
             layer.load_sparse_table(cp_dir, mode)
-        elif isinstance(layer, tn.layers.TNBatchNormalization):
+        elif isinstance(layer, tn.layers.TNBatchNormalizationBase):
             layer.load_bn_table(cp_dir)
 
         # dense weight
```
