-- common.lua
local ffi = require('ffi')
local fio = require('fio')
local fun = require('fun')
local log = require('log')
local json = require('json')
local yaml = require('yaml')
local clock = require('clock')
local errno = require('errno')
local fiber = require('fiber')
local digest = require('digest')
local pmaster = require('pregel.master')
local pworker = require('pregel.worker')
local avro = require('pregel.avro')
local xpcall_tb = require('pregel.utils').xpcall_tb
local deepcopy = require('pregel.utils.copy').deep
local GradientDescent = require('math.gd').GradientDescent
local PercentileCounter = require('math.pc').PercentileCounter
local TaskReport_new = require('report').new
local avro_loaders = require('avro_loaders')
local constants = require('constants')
local utils = require('utils')
local fdup = fun.duplicate
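-- The role and port offset are parsed from the script's own file name:
-- an instance started as 'worker-<N>' becomes a worker on port 3301 + N;
-- any other name runs as the master on port 3301.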
local worker, port_offset = arg[0]:match('(%a+)-(%d+)')
port_offset = tonumber(port_offset) or 0
if worker == 'worker' then
box.cfg{
wal_mode = 'none',
slab_alloc_arena = 3,
listen = '0.0.0.0:' .. tostring(3301 + port_offset),
background = true,
logger_nonblock = true
}
else
box.cfg{
slab_alloc_arena = 0.1,
wal_mode = 'none',
listen = '0.0.0.0:' .. tostring(3301 + port_offset),
logger_nonblock = true
}
end
box.schema.user.grant('guest', 'read,write,execute', 'universe', nil, {
if_not_exists = true
})
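-- Full access for 'guest' lets the master and workers call each other
-- without credentials.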
--[[------------------------------------------------------------------------]]--
--[[--------------------------------- Utils --------------------------------]]--
--[[------------------------------------------------------------------------]]--
local math_round = utils.math_round
local log_features = utils.log_features
local NULL = json.NULL
--[[------------------------------------------------------------------------]]--
--[[--------------------------- Job configuration --------------------------]]--
--[[------------------------------------------------------------------------]]--
local vertex_type = constants.vertex_type
local message_command = constants.message_command
local node_status = constants.node_status
local task_phase = constants.task_phase
-- keys of dataSet
local dataSetKeys = constants.dataSetKeys
-- config keys
local FEATURES_LIST = constants.FEATURES_LIST
local DATASET_PATH = constants.DATASET_PATH
-- other
local SUFFIX_TRAIN = constants.SUFFIX_TRAIN
local SUFFIX_TEST = constants.SUFFIX_TEST
local DISTRIBUTED_GD_GROUP = constants.DISTRIBUTED_GD_GROUP
local MISSING_USERS_COUNT = constants.MISSING_USERS_COUNT
local MASTER_VERTEX_TYPE = constants.MASTER_VERTEX_TYPE
local TASK_VERTEX_TYPE = constants.TASK_VERTEX_TYPE
-- Parameters of the gradient descent algorithm
local GDParams = constants.GDParams
--[[------------------------------------------------------------------------]]--
--[[----------------------------- Worker Context ---------------------------]]--
--[[------------------------------------------------------------------------]]--
local wc = nil
do
local featureList = {}
local featureMap = {}
local randomVertexIds = {} -- List of vertices
local taskPhases = {} -- taskName -> phase
local taskDeploymentConfigs = {} -- taskName:uid_type -> config
local taskDataSet = {} -- taskName -> data_set_path
local taskReport = {} -- taskName -> report
local predictionReportSamplingProb = GDParams['p.report.prediction']
local calibrationBucketPercents = GDParams['calibration.bucket.percents']
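    -- features.txt presumably contains one feature name per line; the
    -- line order defines each feature's index in the weight vectors
    -- used throughout the computation.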
    -- Open and parse the feature list file.
    local fd = io.open(fio.pathjoin(DATASET_PATH, 'features.txt'))
    assert(fd, "Can't open features.txt for reading")
    local line, input = fd:read('*a'), nil
    assert(line, "Bad input in features.txt")
local n = 1
while true do
if line:match('\n') ~= nil then
input, line = line:match('([^\n]+)\n(.*)')
elseif #line > 0 then
input = line
line = ''
else
break
end
table.insert(featureList, input)
featureMap[input] = n
n = n + 1
end
log.info("<worker_context> Found %d features", #featureList)
fd:close()
    local fd = io.open(fio.pathjoin(DATASET_PATH, 'prediction_config.json'))
    assert(fd, "Can't open prediction_config.json for reading")
    local input = json.decode(fd:read('*a'))
    assert(input, "Bad input in prediction_config.json")
fd:close()
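    -- prediction_config.json is a JSON array of task descriptions.
    -- Judging by the fields read below, an entry looks roughly like
    -- this (illustrative values, not the real config):
    --   {
    --       "name": "some_task",
    --       "input": "some_task_dataset.jsonl",
    --       "deployment": [{"user_id_type": "email", ...}]
    --   }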
local function check_file(fname)
local file = fio.open(fname, {'O_RDONLY'})
if file == nil then
local errstr = "Can't open file '%s' for reading: %s [errno %d]"
errstr = string.format(errstr, fname, errno.strerror(), errno())
error(errstr)
end
file:close()
return fname
end
for _, task_config in ipairs(input) do
local name = task_config['name']
for _, deployment in ipairs(task_config['deployment']) do
local key = ('%s:%s'):format(name, deployment['user_id_type'])
taskDeploymentConfigs[key] = deployment
end
taskPhases[name] = task_phase.SELECTION
taskReport[name] = TaskReport_new(name)
log.info("<worker_context> Added config for '%s'", name)
end
local dataSetKeysTypes = {'email', 'okid', 'vkid'}
wc = setmetatable({
featureList = featureList,
featureMap = featureMap,
randomVertexIds = randomVertexIds,
taskReport = taskReport,
taskPhases = taskPhases,
taskDeploymentConfigs = taskDeploymentConfigs,
taskDataSet = taskDataSet,
calibrationBucketPercents = calibrationBucketPercents,
predictionReportSamplingProb = predictionReportSamplingProb,
}, {
__index = {
addRandomVertex = function(self, vertexId)
table.insert(self.randomVertexIds, vertexId)
end,
iterateRandomVertexIds = function(self)
return ipairs(self.randomVertexIds)
end,
setTaskPhase = function(self, name, phase)
self.taskPhases[name] = phase
end,
getTaskPhase = function(self, name)
return self.taskPhases[name] or task_phase.SELECTION
end,
iterateDataSet = function(self, name)
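                -- Each dataset line is a JSON object in an Avro-style
                -- union encoding: optional 'email'/'okid'/'vkid' keys
                -- wrapped as {"string": ...}, an optional 'category'
                -- wrapped as {"int": ...}, plus an optional plain-string
                -- 'vid'. An illustrative line:
                --   {"email": {"string": "a@b.c"}, "category": {"int": 1}}
                -- expands into one {id_type, id, category} tuple per key
                -- present.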
local function processDataSet(input)
local category = input.category and input.category.int
                    local output = fun.iter(dataSetKeysTypes):filter(function(fname)
                        return input[fname] ~= nil
                    end):map(function(fname)
                        return {fname, input[fname].string, category}
                    end):totable()
local vid = input['vid']
if vid and #vid > 0 then
table.insert(output, {'vid', vid, category})
end
return output
end
                local fname = self.taskDataSet[name][1]
                local file = fio.open(fname, {'O_RDONLY'})
                if file == nil then
                    local errstr = "Can't open file '%s' for reading: %s [errno %d]"
                    error(errstr:format(fname, errno.strerror(), errno()))
                end
                local line = ''
local function iterator()
while true do
local input
if line:find('\n') == nil then
                            local rv = file:read(65536)
                            if rv == nil or #rv == 0 then
file:close()
return nil
else
line = line .. rv
end
else
input, line = line:match('([^\n]*)\n(.*)')
input = json.decode(input)
return processDataSet(input)
end
end
end
return iterator, nil
end,
            iterateDataSetWrap = function(self, name)
                local iter_func = self:iterateDataSet(name)
                local last_item = {}
                -- Flatten the per-line tuple batches into a stream of
                -- single tuples; loop past empty batches instead of
                -- terminating the iteration early.
                local iterator = function()
                    while last_item[1] == nil do
                        last_item = iter_func()
                        if last_item == nil then
                            return nil
                        end
                    end
                    return table.remove(last_item)
                end
                return iterator, nil
            end,
storeDataSet = function(self, name)
for tuple in self:iterateDataSetWrap(name) do
self.taskDataSet[name][2]:replace(tuple)
end
end,
addAggregators = function(self, instance)
log.info("<worker_context> Adding aggregators")
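                -- One aggregator per task: the master publishes a
                -- {command, target, features} record to all vertices,
                -- and merge keeps the record with the highest command
                -- value (commands are assumed ordered by priority).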
for taskName, _ in pairs(self.taskPhases) do
instance:add_aggregator(taskName, {
default = {
name = nil,
command = message_command.NONE,
target = 0.0,
features = {}
},
merge = function(old, new)
if new ~= nil and
(old == nil or new.command > old.command) then
return deepcopy(new)
end
return old
end
})
end
return instance
end,
getTaskReport = function(self, name)
return self.taskReport[name]
end
}
})
if worker == 'worker' then
for _, task_config in ipairs(input) do
local name = task_config['name']
local input = task_config['input']
local fname = check_file(fio.pathjoin(DATASET_PATH, input))
local sname = ('wc_%s_ds'):format(name)
local fspace = box.space[sname]
taskDataSet[name] = {fname, fspace}
if box.space[sname] == nil then
local space = box.schema.create_space(sname, {
format = {
[1] = {name = 'id_type', type = 'str'},
[2] = {name = 'id', type = 'str'},
[3] = {name = 'category', type = 'num'}
}
})
space:create_index('primary', {
type = 'TREE',
parts = {1, 'STR', 2, 'STR'}
})
taskDataSet[name][2] = space
log.info("<worker_context> Begin preloading data for '%s'", name)
wc:storeDataSet(name)
log.info("<worker_context> Data stored for '%s'", name)
end
log.info("<worker_context> Done loading dataSet for '%s'", name)
end
end
log.info('<worker_context> Initialized:')
log.info('<worker_context> taskDeploymentConfigs:')
fun.iter(taskDeploymentConfigs):each(function(name, config)
log.info('<worker_context> %s -> %s', name, json.encode(config))
end)
end
--[[------------------------------------------------------------------------]]--
--[[------------------------ Configuration of Runner -----------------------]]--
--[[------------------------------------------------------------------------]]--
local vertex_mt = nil
local node_master_mt = require('node_master').mt
local node_task_mt = require('node_task').mt
local node_data_mt = require('node_data').mt
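-- Dispatch per-vertex computation by node type: the vertex's metatable
-- is temporarily swapped for the matching node implementation, its
-- compute_new() runs, and the generic vertex metatable is restored.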
local function computeGradientDescent(vertex)
if vertex_mt == nil then
vertex_mt = getmetatable(vertex)
end
local vtype = vertex:get_value().vtype
if vtype == vertex_type.MASTER then
setmetatable(vertex, node_master_mt)
elseif vtype == vertex_type.TASK then
if vertex:get_superstep() == 0 then
return
end
setmetatable(vertex, node_task_mt)
else
setmetatable(vertex, node_data_mt)
end
vertex:compute_new()
setmetatable(vertex, vertex_mt)
end
local function generate_worker_uri(hosts, count)
    count = count or 8
    -- Build one iterator of '<host>:<port>' URIs per host (ports
    -- 3302 .. 3301 + count, matching the per-instance offset parsed
    -- from arg[0] above), then chain them into a single flat list.
    local iterators = fun.iter(hosts):map(function(host)
        return fun.range(count):map(function(port)
            return ('%s:%d'):format(host, 3301 + port)
        end)
    end):totable()
    return fun.chain(unpack(iterators)):totable()
end
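-- For example (hypothetical hosts), generate_worker_uri({'h1', 'h2'}, 2)
-- yields {'h1:3302', 'h1:3303', 'h2:3302', 'h2:3303'}.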
local common_cfg = {
master = constants.MASTER_URI,
workers = generate_worker_uri(
constants.HOSTS_LIST,
constants.INSTANCE_COUNT
),
compute = computeGradientDescent,
combiner = nil,
master_preload = avro_loaders.master,
worker_preload = avro_loaders.worker_additional,
preload_args = {
path = DATASET_PATH,
feature_count = 300,
-- vertex_count = 17600000,
},
squash_only = false,
pool_size = 250,
delayed_push = false,
obtain_name = utils.obtain_name,
worker_context = wc
}
if worker == 'worker' then
worker = pworker.new('test', common_cfg)
wc:addAggregators(worker)
else
xpcall_tb(function()
local master = pmaster.new('test', common_cfg)
wc:addAggregators(master)
master:wait_up()
if arg[1] == 'load' then
-- master:preload()
master:preload_on_workers()
master:save_snapshot()
end
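        -- Seed the global MASTER vertex with a zeroed weight vector
        -- before starting the computation.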
master.mpool:by_id('MASTER:'):put('vertex.store', {
key = {vid = 'MASTER', category = 0},
features = fdup(0.0):take(#wc.featureList):totable(),
vtype = constants.vertex_type.MASTER,
status = constants.node_status.NEW
})
master.mpool:flush()
master:start()
end)
os.exit(0)
end