
Commit 7fd2871

Merge branch 'add_global_normlization' into 'master'
Add global normlization
See merge request deep-learning/tensornet!15

2 parents: d16ea64 + 8ac64c2; merge commit: 7fd2871

21 files changed: +1324 −3 lines

.bumpversion.cfg (+1 −1)

@@ -1,5 +1,5 @@
 [bumpversion]
-current_version = 0.1.3.post2
+current_version = 0.2.0.rc
 commit = False
 tag = False
 parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)(\.(?P<release>[a-z]+)(?P<build>\d*))?

core/BUILD (+25)

@@ -51,6 +51,30 @@ tf_gen_op_wrapper_py(
     cc_linkopts = ['-lrt'],
 )

+cc_library(
+    name = "bn_table_ops_kernels",
+    srcs = [
+        "kernels/bn_table_ops_dummy.cc",
+        "ops/bn_table_ops.cc",
+    ],
+    hdrs = [
+        "//core/utility:semaphore",
+    ],
+    linkstatic = 1,
+    deps = [
+        "@org_tensorflow//tensorflow/core:framework",
+        "@org_tensorflow//tensorflow/core:lib",
+        "@org_tensorflow//tensorflow/core:protos_all_cc",
+    ],
+    alwayslink = 0,
+)
+
+tf_gen_op_wrapper_py(
+    name = "bn_table_ops",
+    deps = [":bn_table_ops_kernels"],
+    cc_linkopts = ['-lrt', '-lssl']
+)
+
 cc_library(
     name = "balance_dataset_ops_kernels",
     srcs = [

@@ -127,6 +151,7 @@ cc_binary(
         "kernels/dense_table_ops.cc",
         "kernels/data/balance_dataset_ops.cc",
         "kernels/data/balance_dataset_ops.h",
+        "kernels/bn_table_ops.cc",
         "public/version.h",
         "kernels/resource_var_wrapper.h",
         "//core/utility:semaphore",
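The new bn_table_ops_kernels library pulls in //core/utility:semaphore through hdrs; the kernels added in core/kernels/bn_table_ops.cc below fan out one brpc call per remote shard and then block on that semaphore until every completion callback has fired. The actual header is not part of the hunks shown on this page, so the following is only a minimal sketch of a counting semaphore consistent with the three calls the kernels make (Semaphore(count), Notify(), WaitForSemaphore()); everything beyond those three names is an assumption.

// Hypothetical sketch only; NOT the repository's core/utility/semaphore.h.
#include <condition_variable>
#include <cstddef>
#include <mutex>

class Semaphore {
public:
    // Construct with the number of Notify() calls to wait for.
    explicit Semaphore(std::size_t count) : remaining_(count) {}

    // Called from each RPC completion callback.
    void Notify() {
        std::lock_guard<std::mutex> lock(mu_);
        if (remaining_ > 0 && --remaining_ == 0) {
            cv_.notify_all();
        }
    }

    // Blocks the caller until every expected Notify() has arrived.
    void WaitForSemaphore() {
        std::unique_lock<std::mutex> lock(mu_);
        cv_.wait(lock, [this] { return remaining_ == 0; });
    }

private:
    std::mutex mu_;
    std::condition_variable cv_;
    std::size_t remaining_;
};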

core/kernels/bn_table_ops.cc (new file, +309)

@@ -0,0 +1,309 @@
// Copyright (c) 2020, Qihoo, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "core/utility/semaphore.h"
#include "core/ps/table/bn_table.h"

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h"

#include "core/kernels/resource_var_wrapper.h"
#include "core/ps_interface/ps_raw_interface.h"


#include <brpc/controller.h>
#include <sstream>
#include <Eigen/Dense>
#include <iostream>
#include <mutex>

#include "core/ps/ps_server_interface.h"
#include "core/ps/ps_cluster.h"

using namespace tensornet;

namespace tensorflow {

static void NoOpDeleter(void *) {}

template <typename T, bool use_dynamic_cast>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);

const ResourceHandle& HandleFromInput(OpKernelContext* ctx, int input);

class BnStatisticsPushCall {
public:
    BnStatisticsPushCall(int table_handle, int shard_id)
        : shard_id_(shard_id) {
        req.set_req_shard_id(shard_id);
        req.set_table_handle(table_handle);
    }

    ~BnStatisticsPushCall() {}

    void AddRequestData(butil::IOBuf& k_buf) {
        butil::IOBuf &buf = cntl.request_attachment();
        buf.append(k_buf);
    }

    void Start(const tensornet::Callback& done) {
        const PsServerInterface* si =
            PsCluster::Instance()->GetServer(shard_id_);
        si->BnStatisticsPushAsync(&cntl, &req, &resp, done);
    }

public:
    brpc::Controller cntl;
    BnStatisticsPushRequest req;
    BnStatisticsPushResponse resp;

private:
    int shard_id_ = -1;
};


class BnStatisticsPushKernel : public AsyncOpKernel {
public:
    explicit BnStatisticsPushKernel(OpKernelConstruction* c)
        : AsyncOpKernel(c) {
        OP_REQUIRES_OK(c, c->GetAttr("table_handle", &table_handle_));
        OP_REQUIRES_OK(c, c->GetAttr("N", &N_));
        OP_REQUIRES_OK(c, c->GetAttr("synchronized", &synchronized_));
    }

    void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
        butil::IOBuf acc_buf;

        std::vector<double*> allocated_pointers;

        for (int i = 0; i < N_; i++) {
            const ResourceHandle& handle = HandleFromInput(c, i);

            Var* variable = nullptr;
            const auto status = LookupResource<Var, false>(c, handle, &variable);

            OP_REQUIRES_OK_ASYNC(c, status, done);
            CHECK(variable);

            Tensor *var_tensor = variable->tensor();

            int num_elements = var_tensor->NumElements();
            double* dynamic_double_data = new double[num_elements];
            const float* float_data = var_tensor->flat<float>().data();
            for (int i = 0; i < num_elements; ++i) {
                dynamic_double_data[i] = static_cast<double>(float_data[i]);
            }
            acc_buf.append_user_data(dynamic_double_data, num_elements * sizeof(double), NoOpDeleter);
            allocated_pointers.push_back(dynamic_double_data);
        }

        BnTable* table = BnTableRegistry::Instance()->Get(table_handle_);
        table->Append(acc_buf, true);

        for (auto ptr : allocated_pointers) {
            delete[] ptr;
        }
        allocated_pointers.clear();

        if(synchronized_){
            PsCluster* cluster = PsCluster::Instance();
            OP_REQUIRES_ASYNC( c, true == cluster->IsInitialized(),
                errors::InvalidArgument("cluster instance not initialized:"), done);

            butil::IOBuf inc_buf;
            table->GetIncStatistics(inc_buf);

            std::vector<BnStatisticsPushCall*> calls;

            for (size_t shard_id = 0; shard_id < cluster->RankNum(); shard_id++) {
                if(shard_id != cluster->Rank()){
                    auto* call = new BnStatisticsPushCall(table_handle_, shard_id);
                    call->AddRequestData(inc_buf);
                    calls.emplace_back(call);
                }
            }

            Semaphore semaphore(calls.size());

            for (auto& call : calls) {
                call->Start([this, call, &semaphore]() {
                    semaphore.Notify();
                    delete call;
                });
            }

            semaphore.WaitForSemaphore();
        }

        done();

        return;
    }

private:
    int table_handle_;
    int N_;
    bool synchronized_;
};

REGISTER_KERNEL_BUILDER(Name("BnStatisticsPush").Device(DEVICE_CPU),
                        BnStatisticsPushKernel);

class UpdateMomentsKernel : public OpKernel {
public:
    explicit UpdateMomentsKernel(OpKernelConstruction* c)
        : OpKernel(c) {
        OP_REQUIRES_OK(c, c->GetAttr("table_handle", &table_handle_));
        OP_REQUIRES_OK(c, c->GetAttr("N", &N_));
    }

    void Compute(OpKernelContext* c) override {
        std::vector<Var*> bn_vars;

        for (int i = 0; i < N_; i++) {
            const ResourceHandle &handle = HandleFromInput(c, i);

            Var *variable = nullptr;
            const auto status = LookupResource<Var, false>(c, handle, &variable);

            OP_REQUIRES_OK(c, status);
            CHECK(variable);
            bn_vars.emplace_back(variable);
        }

        BnTable* table = BnTableRegistry::Instance()->Get(table_handle_);

        std::tuple<Eigen::ArrayXf, Eigen::ArrayXf> moments_tuple = table->GetMoments();

        auto& global_mean_var = bn_vars[0];
        float* global_mean_flat = global_mean_var->tensor()->flat<float>().data();
        std::copy(std::get<0>(moments_tuple).data(), std::get<0>(moments_tuple).data() + std::get<0>(moments_tuple).size(), global_mean_flat);

        auto& global_var_var = bn_vars[1];
        float* global_var_flat = global_var_var->tensor()->flat<float>().data();
        std::copy(std::get<1>(moments_tuple).data(), std::get<1>(moments_tuple).data() + std::get<1>(moments_tuple).size(), global_var_flat);

        return;
    }

private:
    int table_handle_;
    int N_;
};


REGISTER_KERNEL_BUILDER(Name("UpdateMoments").Device(DEVICE_CPU),
                        UpdateMomentsKernel);

class BnStatisticsPullCall {
public:
    BnStatisticsPullCall(int table_handle, int shard_id)
        : shard_id_(shard_id) {
        req.set_req_shard_id(shard_id);
        req.set_table_handle(table_handle);
    }

    ~BnStatisticsPullCall() {}

    void Start(const tensornet::Callback& done) {
        const PsServerInterface* si =
            PsCluster::Instance()->GetServer(shard_id_);
        si->BnStatisticsPullAsync(&cntl, &req, &resp, done);
    }

public:
    brpc::Controller cntl;
    BnStatisticsPullRequest req;
    BnStatisticsPullResponse resp;

private:
    int shard_id_ = -1;
};


class BnStatisticsPullKernel : public AsyncOpKernel {
public:
    explicit BnStatisticsPullKernel(OpKernelConstruction* c)
        : AsyncOpKernel(c) {
        OP_REQUIRES_OK(c, c->GetAttr("table_handle", &table_handle_));
        OP_REQUIRES_OK(c, c->GetAttr("N", &N_));
    }

    void ComputeAsync(OpKernelContext* c, DoneCallback done) override {

        std::vector<Var*> bn_vars;

        for (int i = 0; i < N_; i++) {
            const ResourceHandle &handle = HandleFromInput(c, i);

            Var *variable = nullptr;
            const auto status = LookupResource<Var, false>(c, handle, &variable);

            OP_REQUIRES_OK(c, status);
            CHECK(variable);
            bn_vars.emplace_back(variable);
        }

        PsCluster* cluster = PsCluster::Instance();
        OP_REQUIRES_ASYNC(
            c, true == cluster->IsInitialized(),
            errors::InvalidArgument("cluster instance not initialized:"), done);

        BnTable *table = BnTableRegistry::Instance()->Get(table_handle_);
        std::vector<BnStatisticsPullCall*> calls;

        for (size_t shard_id = 0; shard_id < cluster->RankNum(); shard_id++) {
            if(shard_id != cluster->Rank()){
                calls.emplace_back(
                    new BnStatisticsPullCall(table_handle_, shard_id));
            }
        }

        Semaphore semaphore(calls.size());

        for (auto& call : calls) {
            call->Start([this, call, &table, &semaphore]() {
                table->Append(call->cntl.response_attachment(), false);
                semaphore.Notify();
                delete call;
            });
        }

        semaphore.WaitForSemaphore();
        std::tuple<Eigen::ArrayXf, Eigen::ArrayXf> moments_tuple = table->GetMoments();

        auto& global_mean_var = bn_vars[0];
        float* global_mean_flat = global_mean_var->tensor()->flat<float>().data();
        std::copy(std::get<0>(moments_tuple).data(), std::get<0>(moments_tuple).data() + std::get<0>(moments_tuple).size(), global_mean_flat);

        auto& global_var_var = bn_vars[1];
        float* global_var_flat = global_var_var->tensor()->flat<float>().data();
        std::copy(std::get<1>(moments_tuple).data(), std::get<1>(moments_tuple).data() + std::get<1>(moments_tuple).size(), global_var_flat);

        done();

        return;
    }

private:
    int table_handle_;
    int N_;
};

REGISTER_KERNEL_BUILDER(Name("BnStatisticsPull").Device(DEVICE_CPU),
                        BnStatisticsPullKernel);

};
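For reference, these kernels are matched by op registrations in core/ops/bn_table_ops.cc, which is one of the 21 changed files but is not reproduced on this page. The sketch below is not that file; it only illustrates what a TensorFlow op registration consistent with the attrs and inputs the kernels read (table_handle, N, synchronized, and N resource handles) could look like, with the attr constraints and shape function chosen as plausible assumptions.

// Hypothetical sketch only; the real registrations live in
// core/ops/bn_table_ops.cc, which this diff excerpt does not show.
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"

namespace tensorflow {

// Pushes the locally accumulated batch-norm statistics held in N resource
// variables into the BnTable identified by table_handle.
REGISTER_OP("BnStatisticsPush")
    .Input("vars: N * resource")   // N resource handles, read via HandleFromInput
    .Attr("table_handle: int")
    .Attr("N: int >= 1")
    .Attr("synchronized: bool")
    .SetIsStateful()
    .SetShapeFn(shape_inference::UnknownShape);

// BnStatisticsPull and UpdateMoments read the same table_handle / N attrs,
// so their registrations would follow the same pattern minus "synchronized".

}  // namespace tensorflow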
