Skip to content

Commit 66dd673

Browse files
cboss6mini-goel
authored and committed
[Refactor] Refactor code to avoid potential nested queue submit issue. (#2695)
1 parent 8decc58 commit 66dd673

File tree

1 file changed

+19
-16
lines changed

1 file changed

+19
-16
lines changed

itex/core/kernels/gpu/matmul_op.cc

+19-16
Original file line number | Diff line number | Diff line change
@@ -349,26 +349,29 @@ void LaunchBmmCustomKernel(OpKernelContext* ctx, const T* A, const T* B, T* C,
349349
sycl::range<3> local{1, BS_X, BS_Y};
350350
Tensor A_offset_tensor, B_offset_tensor;
351351

352+
if (src_dims > 3 && is_bcast_required) {
353+
const std::vector<int64_t>& x_batch_indices = bcast.x_batch_indices();
354+
const std::vector<int64_t>& y_batch_indices = bcast.y_batch_indices();
355+
OP_REQUIRES_OK(ctx,
356+
ctx->allocate_temp(DataTypeToEnum<int64_t>::value,
357+
TensorShape({bs}), &A_offset_tensor));
358+
OP_REQUIRES_OK(ctx,
359+
ctx->allocate_temp(DataTypeToEnum<int64_t>::value,
360+
TensorShape({bs}), &B_offset_tensor));
361+
stream
362+
->memcpy(GetTensorBuffer<int64_t>(&A_offset_tensor),
363+
x_batch_indices.data(), bs * sizeof(int64_t))
364+
.wait();
365+
stream
366+
->memcpy(GetTensorBuffer<int64_t>(&B_offset_tensor),
367+
y_batch_indices.data(), bs * sizeof(int64_t))
368+
.wait();
369+
}
370+
352371
stream->submit([&](sycl::handler& cgh) {
353372
LocalAcc<T> Asub(sycl::range<2>{c_M * BS_X, TILE_K}, cgh);
354373
LocalAcc<T> Bsub(sycl::range<2>{TILE_K, c_P * BS_Y}, cgh);
355374
if (src_dims > 3 && is_bcast_required) {
356-
const std::vector<int64_t>& x_batch_indices = bcast.x_batch_indices();
357-
const std::vector<int64_t>& y_batch_indices = bcast.y_batch_indices();
358-
OP_REQUIRES_OK(ctx,
359-
ctx->allocate_temp(DataTypeToEnum<int64_t>::value,
360-
TensorShape({bs}), &A_offset_tensor));
361-
OP_REQUIRES_OK(ctx,
362-
ctx->allocate_temp(DataTypeToEnum<int64_t>::value,
363-
TensorShape({bs}), &B_offset_tensor));
364-
stream
365-
->memcpy(GetTensorBuffer<int64_t>(&A_offset_tensor),
366-
x_batch_indices.data(), bs * sizeof(int64_t))
367-
.wait();
368-
stream
369-
->memcpy(GetTensorBuffer<int64_t>(&B_offset_tensor),
370-
y_batch_indices.data(), bs * sizeof(int64_t))
371-
.wait();
372375
BatchMatMulWithBcastKernel<T, c_M, c_P, BS_X, BS_Y, TILE_K, TILE_AB> task(
373376
A, B, C, bs, M, N, P, Asub, Bsub, adj_A, adj_B,
374377
static_cast<int64_t*>(GetTensorBuffer<int64_t>(&A_offset_tensor)),

0 commit comments

Comments
 (0)