@@ -349,26 +349,29 @@ void LaunchBmmCustomKernel(OpKernelContext* ctx, const T* A, const T* B, T* C,
349
349
sycl::range<3 > local{1 , BS_X, BS_Y};
350
350
Tensor A_offset_tensor, B_offset_tensor;
351
351
352
+ if (src_dims > 3 && is_bcast_required) {
353
+ const std::vector<int64_t >& x_batch_indices = bcast.x_batch_indices ();
354
+ const std::vector<int64_t >& y_batch_indices = bcast.y_batch_indices ();
355
+ OP_REQUIRES_OK (ctx,
356
+ ctx->allocate_temp (DataTypeToEnum<int64_t >::value,
357
+ TensorShape ({bs}), &A_offset_tensor));
358
+ OP_REQUIRES_OK (ctx,
359
+ ctx->allocate_temp (DataTypeToEnum<int64_t >::value,
360
+ TensorShape ({bs}), &B_offset_tensor));
361
+ stream
362
+ ->memcpy (GetTensorBuffer<int64_t >(&A_offset_tensor),
363
+ x_batch_indices.data (), bs * sizeof (int64_t ))
364
+ .wait ();
365
+ stream
366
+ ->memcpy (GetTensorBuffer<int64_t >(&B_offset_tensor),
367
+ y_batch_indices.data (), bs * sizeof (int64_t ))
368
+ .wait ();
369
+ }
370
+
352
371
stream->submit ([&](sycl::handler& cgh) {
353
372
LocalAcc<T> Asub (sycl::range<2 >{c_M * BS_X, TILE_K}, cgh);
354
373
LocalAcc<T> Bsub (sycl::range<2 >{TILE_K, c_P * BS_Y}, cgh);
355
374
if (src_dims > 3 && is_bcast_required) {
356
- const std::vector<int64_t >& x_batch_indices = bcast.x_batch_indices ();
357
- const std::vector<int64_t >& y_batch_indices = bcast.y_batch_indices ();
358
- OP_REQUIRES_OK (ctx,
359
- ctx->allocate_temp (DataTypeToEnum<int64_t >::value,
360
- TensorShape ({bs}), &A_offset_tensor));
361
- OP_REQUIRES_OK (ctx,
362
- ctx->allocate_temp (DataTypeToEnum<int64_t >::value,
363
- TensorShape ({bs}), &B_offset_tensor));
364
- stream
365
- ->memcpy (GetTensorBuffer<int64_t >(&A_offset_tensor),
366
- x_batch_indices.data (), bs * sizeof (int64_t ))
367
- .wait ();
368
- stream
369
- ->memcpy (GetTensorBuffer<int64_t >(&B_offset_tensor),
370
- y_batch_indices.data (), bs * sizeof (int64_t ))
371
- .wait ();
372
375
BatchMatMulWithBcastKernel<T, c_M, c_P, BS_X, BS_Y, TILE_K, TILE_AB> task (
373
376
A, B, C, bs, M, N, P, Asub, Bsub, adj_A, adj_B,
374
377
static_cast <int64_t *>(GetTensorBuffer<int64_t >(&A_offset_tensor)),
0 commit comments