File tree Expand file tree Collapse file tree 1 file changed +13
-7
lines changed
torchvision/csrc/ops/cuda Expand file tree Collapse file tree 1 file changed +13
-7
lines changed Original file line number Diff line number Diff line change @@ -228,7 +228,9 @@ void deformable_im2col(
228
228
int deformable_group,
229
229
bool use_mask,
230
230
at::Tensor data_col) {
231
- int64_t num_kernels = (int64_t )n_in_channels * out_h * out_w * parallel_imgs;
231
+ at::cuda::CUDAGuard device_guard (input.get_device ());
232
+
233
+ const int64_t num_kernels = (int64_t )n_in_channels * out_h * out_w * parallel_imgs;
232
234
233
235
const unsigned int threads = GET_THREADS ();
234
236
const unsigned int blocks = GET_BLOCKS (threads, num_kernels);
@@ -408,12 +410,14 @@ void compute_grad_input(
408
410
int n_offset_grps,
409
411
bool use_mask,
410
412
at::Tensor grad_im) {
411
- int out_h =
413
+ at::cuda::CUDAGuard device_guard (columns.get_device ());
414
+
415
+ const int out_h =
412
416
(height + 2 * pad_h - (dilation_h * (weight_h - 1 ) + 1 )) / stride_h + 1 ;
413
- int out_w =
417
+ const int out_w =
414
418
(width + 2 * pad_w - (dilation_w * (weight_w - 1 ) + 1 )) / stride_w + 1 ;
415
419
416
- int64_t num_kernels =
420
+ const int64_t num_kernels =
417
421
(int64_t )channels * weight_h * weight_w * out_h * out_w * parallel_imgs;
418
422
419
423
const unsigned int threads = GET_THREADS ();
@@ -650,11 +654,13 @@ void compute_grad_offset_and_mask(
650
654
bool use_mask,
651
655
at::Tensor grad_offset,
652
656
at::Tensor grad_mask) {
653
- int out_h =
657
+ at::cuda::CUDAGuard device_guard (columns.get_device ());
658
+
659
+ const int out_h =
654
660
(height + 2 * pad_h - (dilation_h * (weight_h - 1 ) + 1 )) / stride_h + 1 ;
655
- int out_w =
661
+ const int out_w =
656
662
(width + 2 * pad_w - (dilation_w * (weight_w - 1 ) + 1 )) / stride_w + 1 ;
657
- int64_t num_kernels = (int64_t )out_h * out_w * 2 * weight_h * weight_w *
663
+ const int64_t num_kernels = (int64_t )out_h * out_w * 2 * weight_h * weight_w *
658
664
n_offset_grps * parallel_imgs;
659
665
660
666
const unsigned int threads = GET_THREADS ();
You can’t perform that action at this time.
0 commit comments