@@ -267,7 +267,7 @@ class SparseTablePullKernel : public AsyncOpKernel {
      float* w_matrix = var_tensor->matrix<float>().data();

      size_t emb_size = sizeof(float) * dim;
-      CHECK_EQ(emb_size, emb_buf.cutn(w_matrix + sign_index * dim, emb_size));
+      CHECK_EQ(emb_size, emb_buf.cutn(w_matrix + sign_index * dim , emb_size));
    }
  }

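A note on the unchanged pull path above: the CHECK_EQ treats cutn() as a "consume up to n bytes into dst, return bytes actually copied" primitive. The sketch below is a minimal stand-in with that assumed contract; the class name and members are illustrative, not the repository's actual buffer type.

// Assumed cutn() contract, inferred from the call site above: copy up to
// `len` bytes from the read cursor into `dst`, advance the cursor, and
// return the number of bytes copied (which the kernel CHECK_EQs against
// emb_size). Illustrative only.
#include <algorithm>
#include <cstddef>
#include <cstring>
#include <utility>
#include <vector>

class EmbBufSketch {
public:
    explicit EmbBufSketch(std::vector<char> data) : data_(std::move(data)) {}

    // Copy up to `len` bytes into `dst`; return how many bytes were copied.
    size_t cutn(void* dst, size_t len) {
        size_t n = std::min(len, data_.size() - pos_);
        std::memcpy(dst, data_.data() + pos_, n);
        pos_ += n;
        return n;
    }

private:
    std::vector<char> data_;
    size_t pos_ = 0;
};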
@@ -281,22 +281,26 @@ REGISTER_KERNEL_BUILDER(Name("SparseTablePull").Device(DEVICE_CPU),

struct SparsePushVarInfo {
public:
-    SparsePushVarInfo(const Tensor* t_value, const Tensor* t_grad)
+    SparsePushVarInfo(const Tensor* t_value, const Tensor* t_grad, const Tensor* t_labels)
        : value(t_value)
-        , grad(t_grad) {
+        , grad(t_grad)
+        , labels(t_labels) {

        const int64* feasign_vec = value->flat<int64>().data();
+        const int64* fea_label_vec = t_labels->flat<int64>().data();

        std::map<uint64, int> sign_id_mapping;
        for (int i = 0; i < value->NumElements(); ++i) {
            uint64 sign = (uint64)feasign_vec[i];
+            int label = static_cast<int>(fea_label_vec[i]);
            auto ret = sign_id_mapping.insert({sign, sign_id_mapping.size()});

            if (ret.second) {
-                virtual_sign_infos.emplace_back(sign, 1);
+                virtual_sign_infos.emplace_back(sign, 1, label);
            } else {
                auto iter = ret.first;
                virtual_sign_infos[iter->second].batch_show += 1;
+                virtual_sign_infos[iter->second].batch_click += label;
            }
        }
    }
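What the extended constructor computes, in isolation: every occurrence of a feasign in the batch counts as one show, and its 0/1 label accumulates into that sign's click count. A self-contained sketch of the same aggregation, using plain containers instead of SparsePushSignInfo:

// Standalone illustration of the per-sign show/click aggregation performed
// by the new constructor. Plain std containers stand in for the real types.
#include <cstdint>
#include <map>
#include <utility>
#include <vector>

// Returns sign -> (batch_show, batch_click).
std::map<uint64_t, std::pair<int, int>> AggregateShowClick(
        const std::vector<uint64_t>& signs, const std::vector<int>& labels) {
    std::map<uint64_t, std::pair<int, int>> stats;
    for (size_t i = 0; i < signs.size(); ++i) {
        auto& entry = stats[signs[i]];   // value-initialized to {0, 0} on first use
        entry.first += 1;                // batch_show: one per occurrence
        entry.second += labels[i];       // batch_click: sum of 0/1 labels
    }
    return stats;
}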
@@ -308,6 +312,7 @@ struct SparsePushVarInfo {
public:
    const Tensor* value;
    const Tensor* grad;
+    const Tensor* labels;

    std::vector<SparsePushSignInfo> virtual_sign_infos;
};
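SparsePushSignInfo itself is not part of this diff; from emplace_back(sign, 1, label) and the batch_show / batch_click updates above, it presumably has roughly the shape sketched below. Hypothetical only; the real definition lives elsewhere in the repo and may carry additional fields.

// Shape implied by the calls in this diff: a sign plus show/click counters
// and a matching three-argument constructor. Illustrative sketch only.
#include <cstdint>

struct SparsePushSignInfoSketch {
    uint64_t sign;
    int batch_show;
    int batch_click;

    SparsePushSignInfoSketch(uint64_t s, int show, int click)
        : sign(s), batch_show(show), batch_click(click) {}
};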
@@ -321,16 +326,17 @@ class SparseTablePushKernel : public AsyncOpKernel {
    }

    void ComputeAsync(OpKernelContext* c, DoneCallback done) override {
-        OP_REQUIRES_ASYNC(c, c->num_inputs() == N_ * 2,
+        OP_REQUIRES_ASYNC(c, c->num_inputs() == N_ * 3,
                          errors::InvalidArgument("SparseTable push num_inputs:",
                                                  c->num_inputs(),
-                                                 " not equal:", N_ * 2),
+                                                 " not equal:", N_ * 3),
                          done);
        std::vector<SparsePushVarInfo> var_infos;

        for (int i = 0; i < N_; i++) {
            const Tensor* value = &c->input(i);
            const Tensor* grad = &c->input(N_ + i);
+            const Tensor* labels = &c->input(2 * N_ + i);

            OP_REQUIRES_ASYNC(
                c, TensorShapeUtils::IsMatrix(grad->shape()),
@@ -339,7 +345,7 @@ class SparseTablePushKernel : public AsyncOpKernel {
                grad->shape().DebugString()),
                done);

-            var_infos.emplace_back(value, grad);
+            var_infos.emplace_back(value, grad, labels);
        }

        CHECK_GT(var_infos.size(), 0);
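The num_inputs check moves from N_ * 2 to N_ * 3 because the op now receives a third block of inputs. The assumed layout, and the index math the loop above relies on, are sketched below; the helper names are illustrative only.

// Assumed input layout for the updated push op with N_ sparse tables:
//   inputs [0, N_)       -> feasign value tensors
//   inputs [N_, 2*N_)    -> gradient tensors
//   inputs [2*N_, 3*N_)  -> label tensors (new in this change)
struct PushInputIndices {
    int value;
    int grad;
    int labels;
};

inline PushInputIndices PushIndicesForTable(int i, int num_tables) {
    return PushInputIndices{i, num_tables + i, 2 * num_tables + i};
}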