29
29
30
30
namespace tensornet {
31
31
32
- BnTable::BnTable (const std::string& name, int shard_num, int self_shard_id, int bn_size, bool synchronized, float moment, int max_count)
32
+ BnTable::BnTable (const std::string& name, int shard_num, int self_shard_id, int bn_size, bool synchronized, float moment, uint64_t max_count)
33
33
: shard_num_(shard_num)
34
34
, self_shard_id_(self_shard_id)
35
35
, name_(name)
@@ -38,7 +38,9 @@ BnTable::BnTable(const std::string& name, int shard_num, int self_shard_id, int
38
38
, max_count_(max_count)
39
39
, bn_size_(bn_size) {
40
40
total_sum_.setZero (bn_size);
41
+ total_sum_err_.setZero (bn_size);
41
42
total_squared_sum_.setZero (bn_size);
43
+ total_squared_sum_err_.setZero (bn_size);
42
44
total_count_.setZero (bn_size);
43
45
inc_sum_.setZero (bn_size);
44
46
inc_squared_sum_.setZero (bn_size);
@@ -54,35 +56,54 @@ void BnTable::SetHandle(uint32_t handle) {
54
56
55
57
void BnTable::Append (butil::IOBuf& bn_statistics_buf, bool isLocal) {
56
58
const std::lock_guard<std::mutex> lock (*mu_);
57
- Eigen::ArrayXf acc_sum = Eigen::ArrayXf ::Zero (bn_size_);
58
- Eigen::ArrayXf acc_squared_sum = Eigen::ArrayXf ::Zero (bn_size_);
59
- Eigen::ArrayXf acc_count = Eigen::ArrayXf ::Zero (bn_size_);
59
+ Eigen::ArrayXd acc_sum = Eigen::ArrayXd ::Zero (bn_size_);
60
+ Eigen::ArrayXd acc_squared_sum = Eigen::ArrayXd ::Zero (bn_size_);
61
+ Eigen::ArrayXd acc_count = Eigen::ArrayXd ::Zero (bn_size_);
60
62
61
- bn_statistics_buf.cutn (acc_sum.data (), acc_sum.size () * sizeof (float ));
62
- bn_statistics_buf.cutn (acc_squared_sum.data (), acc_squared_sum.size () * sizeof (float ));
63
- bn_statistics_buf.cutn (acc_count.data (), acc_count.size () * sizeof (float ));
63
+ bn_statistics_buf.cutn (acc_sum.data (), acc_sum.size () * sizeof (double ));
64
+ bn_statistics_buf.cutn (acc_squared_sum.data (), acc_squared_sum.size () * sizeof (double ));
65
+ bn_statistics_buf.cutn (acc_count.data (), acc_count.size () * sizeof (double ));
66
+ CHECK_EQ (bn_statistics_buf.size (), 0 );
64
67
65
- if (synchronized_ && isLocal){
68
+ if (isLocal){
66
69
inc_sum_ += acc_sum;
67
70
inc_squared_sum_ += acc_squared_sum;
68
71
inc_count_ += acc_count;
69
72
}
70
73
71
- int cur_count = static_cast <int >(total_count_.maxCoeff ());
72
- if (cur_count > max_count_) {
73
- int acc_count_num = static_cast <int >(acc_count.maxCoeff ());
74
- float ratio = (float ) acc_count_num / cur_count;
75
- total_sum_ = total_sum_ * (1 - (1 - moment_) * ratio) + (1 - moment_) * ratio * acc_sum;
76
- total_squared_sum_ = total_squared_sum_ * (1 - (1 - moment_) * ratio) + (1 - moment_) * ratio * acc_squared_sum;
77
-
74
+ uint64_t cur_count = static_cast <uint64_t >(total_count_.maxCoeff ());
75
+
76
+ // std::cout << "cur_count is : " << cur_count << std::endl;
77
+ // PrintDetail();
78
+ // std::cout << "acc_count is : " << acc_count(0) << std::endl;
79
+ if (max_count_ > 0 && cur_count > max_count_) {
80
+ uint64_t acc_count_num = static_cast <uint64_t >(acc_count.maxCoeff ());
81
+ double ratio = (double ) acc_count_num / cur_count;
82
+ total_sum_ *= (1 - (1 - moment_) * ratio);
83
+ TotalSumAcc ((1 - moment_) * ratio * acc_sum);
84
+ total_squared_sum_ *= (1 - (1 - moment_) * ratio);
85
+ TotalSquareSumAcc ((1 - moment_) * ratio * acc_squared_sum);
78
86
} else {
79
-
80
- total_sum_ += acc_sum;
81
- total_squared_sum_ += acc_squared_sum;
82
- total_count_ += acc_count;
87
+ TotalSumAcc (acc_sum);
88
+ TotalSquareSumAcc (acc_squared_sum);
89
+ total_count_ += acc_count;
83
90
}
84
91
}
85
92
93
+ void BnTable::TotalSquareSumAcc (Eigen::ArrayXd acc){
94
+ Eigen::ArrayXd y = acc - total_squared_sum_err_;
95
+ Eigen::ArrayXd t = total_squared_sum_ + y;
96
+ total_squared_sum_err_ = (t - total_squared_sum_) - y;
97
+ total_squared_sum_ = t;
98
+ }
99
+
100
+ void BnTable::TotalSumAcc (Eigen::ArrayXd acc){
101
+ Eigen::ArrayXd y = acc - total_sum_err_;
102
+ Eigen::ArrayXd t = total_sum_ + y;
103
+ total_sum_err_ = (t - total_sum_) - y;
104
+ total_sum_ = t;
105
+ }
106
+
86
107
87
108
std::tuple<Eigen::ArrayXf,Eigen::ArrayXf> BnTable::GetMoments () {
88
109
Eigen::ArrayXf global_mean = DivideNoNan (total_sum_, total_count_);
@@ -93,15 +114,15 @@ std::tuple<Eigen::ArrayXf,Eigen::ArrayXf> BnTable::GetMoments() {
93
114
94
115
void BnTable::GetStatistics (const BnStatisticsPullRequest* req, butil::IOBuf& bn_statistics_buf, BnStatisticsPullResponse* resp) {
95
116
resp->set_table_handle (req->table_handle ());
96
- bn_statistics_buf.append (total_sum_.data (), total_sum_.size () * sizeof (float ));
97
- bn_statistics_buf.append (total_squared_sum_.data (), total_squared_sum_.size () * sizeof (float ));
98
- bn_statistics_buf.append (total_count_.data (), total_count_.size () * sizeof (float ));
117
+ bn_statistics_buf.append (total_sum_.data (), total_sum_.size () * sizeof (double ));
118
+ bn_statistics_buf.append (total_squared_sum_.data (), total_squared_sum_.size () * sizeof (double ));
119
+ bn_statistics_buf.append (total_count_.data (), total_count_.size () * sizeof (double ));
99
120
}
100
121
101
122
void BnTable::GetIncStatistics (butil::IOBuf& bn_statistics_buf) {
102
- bn_statistics_buf.append (inc_sum_.data (), inc_sum_.size () * sizeof (float ));
103
- bn_statistics_buf.append (inc_squared_sum_.data (), inc_squared_sum_.size () * sizeof (float ));
104
- bn_statistics_buf.append (inc_count_.data (), inc_count_.size () * sizeof (float ));
123
+ bn_statistics_buf.append (inc_sum_.data (), inc_sum_.size () * sizeof (double ));
124
+ bn_statistics_buf.append (inc_squared_sum_.data (), inc_squared_sum_.size () * sizeof (double ));
125
+ bn_statistics_buf.append (inc_count_.data (), inc_count_.size () * sizeof (double ));
105
126
inc_sum_.setZero ();
106
127
inc_squared_sum_.setZero ();
107
128
inc_count_.setZero ();
@@ -119,16 +140,16 @@ void BnTable::Refresh() {
119
140
}
120
141
121
142
122
- Eigen::ArrayXf BnTable::DivideNoNan (const Eigen::ArrayXf & numerator, const Eigen::ArrayXf & denominator) {
123
- Eigen::ArrayXf result = numerator;
143
+ Eigen::ArrayXf BnTable::DivideNoNan (const Eigen::ArrayXd & numerator, const Eigen::ArrayXd & denominator) {
144
+ Eigen::ArrayXd result = numerator;
124
145
for (int i = 0 ; i < numerator.size (); ++i) {
125
146
if (!std::isnan (denominator (i)) && denominator (i) != 0.0 ) {
126
147
result (i) = numerator (i) / denominator (i);
127
148
} else {
128
149
result (i) = 0.0 ;
129
150
}
130
151
}
131
- return result;
152
+ return result. cast < float >() ;
132
153
}
133
154
134
155
void BnTable::PrintDetail (){
@@ -199,7 +220,7 @@ uint32_t BnTableRegistry::Register(BnTable* table) {
199
220
return table_handle;
200
221
}
201
222
202
- BnTable* CreateBnTable (const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, int max_count) {
223
+ BnTable* CreateBnTable (const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count) {
203
224
BnTable* table = new BnTable (name, shard_num, self_shard_id, bn_size, sync , moment, max_count);
204
225
205
226
table->SetHandle (BnTableRegistry::Instance ()->Register (table));
0 commit comments