Skip to content

Commit ecb0b79

Browse files
r-barnes authored and facebook-github-bot committed
Fix missing kernel guards (#455)
Summary:
Pull Request resolved: pytorch/nestedtensor#455
Fixes missing kernel guards as identified by D30072495.
Differential Revision: D31553158
fbshipit-source-id: 73e4a82d6c9164f8e792211a93c45bee849cc897
1 parent 50ecbfd commit ecb0b79

File tree

1 file changed

+13
-7
lines changed

1 file changed

+13
-7
lines changed

torchvision/csrc/ops/cuda/deform_conv2d_kernel.cu

Lines changed: 13 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -228,7 +228,9 @@ void deformable_im2col(
228228
int deformable_group,
229229
bool use_mask,
230230
at::Tensor data_col) {
231-
int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs;
231+
at::cuda::CUDAGuard device_guard(input.get_device());
232+
233+
const int64_t num_kernels = (int64_t)n_in_channels * out_h * out_w * parallel_imgs;
232234

233235
const unsigned int threads = GET_THREADS();
234236
const unsigned int blocks = GET_BLOCKS(threads, num_kernels);
@@ -408,12 +410,14 @@ void compute_grad_input(
408410
int n_offset_grps,
409411
bool use_mask,
410412
at::Tensor grad_im) {
411-
int out_h =
413+
at::cuda::CUDAGuard device_guard(columns.get_device());
414+
415+
const int out_h =
412416
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
413-
int out_w =
417+
const int out_w =
414418
(width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
415419

416-
int64_t num_kernels =
420+
const int64_t num_kernels =
417421
(int64_t)channels * weight_h * weight_w * out_h * out_w * parallel_imgs;
418422

419423
const unsigned int threads = GET_THREADS();
@@ -650,11 +654,13 @@ void compute_grad_offset_and_mask(
650654
bool use_mask,
651655
at::Tensor grad_offset,
652656
at::Tensor grad_mask) {
653-
int out_h =
657+
at::cuda::CUDAGuard device_guard(columns.get_device());
658+
659+
const int out_h =
654660
(height + 2 * pad_h - (dilation_h * (weight_h - 1) + 1)) / stride_h + 1;
655-
int out_w =
661+
const int out_w =
656662
(width + 2 * pad_w - (dilation_w * (weight_w - 1) + 1)) / stride_w + 1;
657-
int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w *
663+
const int64_t num_kernels = (int64_t)out_h * out_w * 2 * weight_h * weight_w *
658664
n_offset_grps * parallel_imgs;
659665

660666
const unsigned int threads = GET_THREADS();

0 commit comments

Comments
 (0)