Skip to content

Commit 4d44aa1

Browse files
author
jiangxinglei
committed
update bn
1 parent 77fbd94 commit 4d44aa1

File tree

5 files changed

+91
-43
lines changed

5 files changed

+91
-43
lines changed

core/kernels/bn_table_ops.cc

+19-1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ class BnStatisticsPushKernel : public AsyncOpKernel {
8787
void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
8888
butil::IOBuf acc_buf;
8989

90+
std::vector<double*> allocated_pointers;
91+
9092
for (int i = 0; i < N_; i++) {
9193
const ResourceHandle& handle = HandleFromInput(c, i);
9294

@@ -97,12 +99,28 @@ class BnStatisticsPushKernel : public AsyncOpKernel {
9799
CHECK(variable);
98100

99101
Tensor *var_tensor = variable->tensor();
100-
acc_buf.append_user_data(var_tensor->flat<float>().data(), var_tensor->NumElements() * sizeof(float), NoOpDeleter);
102+
103+
int num_elements = var_tensor->NumElements();
104+
double* dynamic_double_data = new double[num_elements];
105+
const float* float_data = var_tensor->flat<float>().data();
106+
for (int i = 0; i < num_elements; ++i) {
107+
// std::cout << "float data is: " << float_data[i] << std::endl;
108+
dynamic_double_data[i] = static_cast<double>(float_data[i]);
109+
// std::cout << "double data is: " << dynamic_double_data[i] << std::endl;
110+
}
111+
acc_buf.append_user_data(dynamic_double_data, num_elements * sizeof(double), NoOpDeleter);
112+
// acc_buf.append_user_data(var_tensor->flat<float>().data(), var_tensor->NumElements() * sizeof(float), NoOpDeleter);
113+
allocated_pointers.push_back(dynamic_double_data);
101114
}
102115

103116
BnTable* table = BnTableRegistry::Instance()->Get(table_handle_);
104117
table->Append(acc_buf, true);
105118

119+
for (auto ptr : allocated_pointers) {
120+
delete[] ptr;
121+
}
122+
allocated_pointers.clear();
123+
106124
if(synchronized_){
107125
PsCluster* cluster = PsCluster::Instance();
108126
OP_REQUIRES_ASYNC( c, true == cluster->IsInitialized(),

core/main/py_wrapper.cc

+6-2
Original file line numberDiff line numberDiff line change
@@ -114,10 +114,14 @@ PYBIND11_MODULE(_pywrap_tn, m) {
114114

115115
return py::reinterpret_steal<py::object>(obj);
116116
})
117-
.def("create_sparse_table", [](py::object obj, std::string name, int dimension) {
117+
.def("create_sparse_table", [](py::object obj, std::string name, int dimension, bool use_cvm) {
118118
OptimizerBase* opt =
119119
static_cast<OptimizerBase*>(PyCapsule_GetPointer(obj.ptr(), nullptr));
120120

121+
opt->SetUseCvm(use_cvm);
122+
123+
std::cout << "Cvm plugin is: " << opt->ShouldUseCvm() << std::endl;
124+
121125
PsCluster* cluster = PsCluster::Instance();
122126

123127
SparseTable* table = CreateSparseTable(opt, name, dimension, cluster->RankNum(), cluster->Rank());
@@ -134,7 +138,7 @@ PYBIND11_MODULE(_pywrap_tn, m) {
134138

135139
return table->GetHandle();
136140
})
137-
.def("create_bn_table", [](std::string name, uint32_t bn_size, bool sync, float moment, int max_count) {
141+
.def("create_bn_table", [](std::string name, uint32_t bn_size, bool sync, float moment, uint64_t max_count) {
138142
PsCluster* cluster = PsCluster::Instance();
139143

140144
BnTable* table = CreateBnTable(name, cluster->RankNum(), cluster->Rank(), bn_size, sync, moment, max_count);

core/ps/ps_local_server.cc

+2-1
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,9 @@ void PsLocalServer::BnStatisticsPullAsync(brpc::Controller *cntl,
104104
Callback done) const {
105105
BnTable *table = BnTableRegistry::Instance()->Get(request->table_handle());
106106
CHECK(nullptr != table);
107+
response->set_table_handle(request->table_handle());
107108
butil::IOBuf& bn_statistics_buf = cntl->response_attachment();
108-
table->GetStatistics(request, bn_statistics_buf, response);
109+
table->GetIncStatistics(bn_statistics_buf);
109110

110111
done();
111112
}

core/ps/table/bn_table.cc

+50-29
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929

3030
namespace tensornet {
3131

32-
BnTable::BnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool synchronized, float moment, int max_count)
32+
BnTable::BnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool synchronized, float moment, uint64_t max_count)
3333
: shard_num_(shard_num)
3434
, self_shard_id_(self_shard_id)
3535
, name_(name)
@@ -38,7 +38,9 @@ BnTable::BnTable(const std::string& name, int shard_num, int self_shard_id, int
3838
, max_count_(max_count)
3939
, bn_size_(bn_size) {
4040
total_sum_.setZero(bn_size);
41+
total_sum_err_.setZero(bn_size);
4142
total_squared_sum_.setZero(bn_size);
43+
total_squared_sum_err_.setZero(bn_size);
4244
total_count_.setZero(bn_size);
4345
inc_sum_.setZero(bn_size);
4446
inc_squared_sum_.setZero(bn_size);
@@ -54,35 +56,54 @@ void BnTable::SetHandle(uint32_t handle) {
5456

5557
void BnTable::Append(butil::IOBuf& bn_statistics_buf, bool isLocal) {
5658
const std::lock_guard<std::mutex> lock(*mu_);
57-
Eigen::ArrayXf acc_sum = Eigen::ArrayXf::Zero(bn_size_);
58-
Eigen::ArrayXf acc_squared_sum = Eigen::ArrayXf::Zero(bn_size_);
59-
Eigen::ArrayXf acc_count = Eigen::ArrayXf::Zero(bn_size_);
59+
Eigen::ArrayXd acc_sum = Eigen::ArrayXd::Zero(bn_size_);
60+
Eigen::ArrayXd acc_squared_sum = Eigen::ArrayXd::Zero(bn_size_);
61+
Eigen::ArrayXd acc_count = Eigen::ArrayXd::Zero(bn_size_);
6062

61-
bn_statistics_buf.cutn(acc_sum.data(), acc_sum.size() * sizeof(float));
62-
bn_statistics_buf.cutn(acc_squared_sum.data(), acc_squared_sum.size() * sizeof(float));
63-
bn_statistics_buf.cutn(acc_count.data(), acc_count.size() * sizeof(float));
63+
bn_statistics_buf.cutn(acc_sum.data(), acc_sum.size() * sizeof(double));
64+
bn_statistics_buf.cutn(acc_squared_sum.data(), acc_squared_sum.size() * sizeof(double));
65+
bn_statistics_buf.cutn(acc_count.data(), acc_count.size() * sizeof(double));
66+
CHECK_EQ(bn_statistics_buf.size(), 0);
6467

65-
if(synchronized_ && isLocal){
68+
if(isLocal){
6669
inc_sum_ += acc_sum;
6770
inc_squared_sum_ += acc_squared_sum;
6871
inc_count_ += acc_count;
6972
}
7073

71-
int cur_count = static_cast<int>(total_count_.maxCoeff());
72-
if(cur_count > max_count_) {
73-
int acc_count_num = static_cast<int>(acc_count.maxCoeff());
74-
float ratio = (float) acc_count_num / cur_count;
75-
total_sum_ = total_sum_ * (1 - (1 - moment_) * ratio) + (1 - moment_) * ratio * acc_sum;
76-
total_squared_sum_ = total_squared_sum_ * (1 - (1 - moment_) * ratio) + (1 - moment_) * ratio * acc_squared_sum;
77-
74+
uint64_t cur_count = static_cast<uint64_t>(total_count_.maxCoeff());
75+
76+
// std::cout << "cur_count is : " << cur_count << std::endl;
77+
// PrintDetail();
78+
// std::cout << "acc_count is : " << acc_count(0) << std::endl;
79+
if(max_count_ > 0 && cur_count > max_count_) {
80+
uint64_t acc_count_num = static_cast<uint64_t>(acc_count.maxCoeff());
81+
double ratio = (double) acc_count_num / cur_count;
82+
total_sum_ *= (1 - (1 - moment_) * ratio);
83+
TotalSumAcc((1 - moment_) * ratio * acc_sum);
84+
total_squared_sum_ *= (1 - (1 - moment_) * ratio);
85+
TotalSquareSumAcc((1 - moment_) * ratio * acc_squared_sum);
7886
} else {
79-
80-
total_sum_ += acc_sum;
81-
total_squared_sum_ += acc_squared_sum;
82-
total_count_ += acc_count;
87+
TotalSumAcc(acc_sum);
88+
TotalSquareSumAcc(acc_squared_sum);
89+
total_count_ += acc_count;
8390
}
8491
}
8592

93+
void BnTable::TotalSquareSumAcc(Eigen::ArrayXd acc){
94+
Eigen::ArrayXd y = acc - total_squared_sum_err_;
95+
Eigen::ArrayXd t = total_squared_sum_ + y;
96+
total_squared_sum_err_ = (t - total_squared_sum_) - y;
97+
total_squared_sum_ = t;
98+
}
99+
100+
void BnTable::TotalSumAcc(Eigen::ArrayXd acc){
101+
Eigen::ArrayXd y = acc - total_sum_err_;
102+
Eigen::ArrayXd t = total_sum_ + y;
103+
total_sum_err_ = (t - total_sum_) - y;
104+
total_sum_ = t;
105+
}
106+
86107

87108
std::tuple<Eigen::ArrayXf,Eigen::ArrayXf> BnTable::GetMoments() {
88109
Eigen::ArrayXf global_mean = DivideNoNan(total_sum_, total_count_);
@@ -93,15 +114,15 @@ std::tuple<Eigen::ArrayXf,Eigen::ArrayXf> BnTable::GetMoments() {
93114

94115
void BnTable::GetStatistics(const BnStatisticsPullRequest* req, butil::IOBuf& bn_statistics_buf, BnStatisticsPullResponse* resp) {
95116
resp->set_table_handle(req->table_handle());
96-
bn_statistics_buf.append(total_sum_.data(), total_sum_.size() * sizeof(float));
97-
bn_statistics_buf.append(total_squared_sum_.data(), total_squared_sum_.size() * sizeof(float));
98-
bn_statistics_buf.append(total_count_.data(), total_count_.size() * sizeof(float));
117+
bn_statistics_buf.append(total_sum_.data(), total_sum_.size() * sizeof(double));
118+
bn_statistics_buf.append(total_squared_sum_.data(), total_squared_sum_.size() * sizeof(double));
119+
bn_statistics_buf.append(total_count_.data(), total_count_.size() * sizeof(double));
99120
}
100121

101122
void BnTable::GetIncStatistics(butil::IOBuf& bn_statistics_buf) {
102-
bn_statistics_buf.append(inc_sum_.data(), inc_sum_.size() * sizeof(float));
103-
bn_statistics_buf.append(inc_squared_sum_.data(), inc_squared_sum_.size() * sizeof(float));
104-
bn_statistics_buf.append(inc_count_.data(), inc_count_.size() * sizeof(float));
123+
bn_statistics_buf.append(inc_sum_.data(), inc_sum_.size() * sizeof(double));
124+
bn_statistics_buf.append(inc_squared_sum_.data(), inc_squared_sum_.size() * sizeof(double));
125+
bn_statistics_buf.append(inc_count_.data(), inc_count_.size() * sizeof(double));
105126
inc_sum_.setZero();
106127
inc_squared_sum_.setZero();
107128
inc_count_.setZero();
@@ -119,16 +140,16 @@ void BnTable::Refresh() {
119140
}
120141

121142

122-
Eigen::ArrayXf BnTable::DivideNoNan(const Eigen::ArrayXf& numerator, const Eigen::ArrayXf& denominator) {
123-
Eigen::ArrayXf result = numerator;
143+
Eigen::ArrayXf BnTable::DivideNoNan(const Eigen::ArrayXd& numerator, const Eigen::ArrayXd& denominator) {
144+
Eigen::ArrayXd result = numerator;
124145
for (int i = 0; i < numerator.size(); ++i) {
125146
if (!std::isnan(denominator(i)) && denominator(i) != 0.0) {
126147
result(i) = numerator(i) / denominator(i);
127148
} else {
128149
result(i) = 0.0;
129150
}
130151
}
131-
return result;
152+
return result.cast<float>();
132153
}
133154

134155
void BnTable::PrintDetail(){
@@ -199,7 +220,7 @@ uint32_t BnTableRegistry::Register(BnTable* table) {
199220
return table_handle;
200221
}
201222

202-
BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, int max_count) {
223+
BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count) {
203224
BnTable* table = new BnTable(name, shard_num, self_shard_id, bn_size, sync, moment, max_count);
204225

205226
table->SetHandle(BnTableRegistry::Instance()->Register(table));

core/ps/table/bn_table.h

+14-10
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ namespace tensornet {
3131

3232
class BnTable {
3333
public:
34-
BnTable(const std::string& name,int shard_num, int self_shard_id, int bn_size, bool sync, float moment, int max_count);
34+
BnTable(const std::string& name,int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count);
3535

3636
~BnTable() = default;
3737

@@ -44,8 +44,10 @@ class BnTable {
4444
std::tuple<Eigen::ArrayXf,Eigen::ArrayXf> GetMoments();
4545
std::tuple<Eigen::ArrayXf,Eigen::ArrayXf> GetIncMoments();
4646

47-
Eigen::ArrayXf DivideNoNan(const Eigen::ArrayXf& numerator, const Eigen::ArrayXf& denominator);
47+
Eigen::ArrayXf DivideNoNan(const Eigen::ArrayXd& numerator, const Eigen::ArrayXd& denominator);
4848

49+
void TotalSumAcc(Eigen::ArrayXd acc_sum);
50+
void TotalSquareSumAcc(Eigen::ArrayXd acc_square_sum);
4951
void Save(const std::string& filepath);
5052
void Load(const std::string& filepath);
5153

@@ -67,13 +69,15 @@ class BnTable {
6769
uint32_t bn_size_ = 0;
6870
bool synchronized_ = false;
6971
float moment_ = 0.0;
70-
int max_count_ = 0;
71-
Eigen::ArrayXf total_sum_;
72-
Eigen::ArrayXf total_squared_sum_;
73-
Eigen::ArrayXf total_count_;
74-
Eigen::ArrayXf inc_sum_;
75-
Eigen::ArrayXf inc_squared_sum_;
76-
Eigen::ArrayXf inc_count_;
72+
uint64_t max_count_ = 0;
73+
Eigen::ArrayXd total_sum_;
74+
Eigen::ArrayXd total_sum_err_;
75+
Eigen::ArrayXd total_squared_sum_;
76+
Eigen::ArrayXd total_squared_sum_err_;
77+
Eigen::ArrayXd total_count_;
78+
Eigen::ArrayXd inc_sum_;
79+
Eigen::ArrayXd inc_squared_sum_;
80+
Eigen::ArrayXd inc_count_;
7781
std::unique_ptr<std::mutex> mu_;
7882

7983
};
@@ -100,7 +104,7 @@ class BnTableRegistry {
100104
std::vector<BnTable*> tables_;
101105
};
102106

103-
BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, int max_count);
107+
BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count);
104108

105109
} // namespace tensornet
106110

0 commit comments

Comments
 (0)