diff --git a/setup.py b/setup.py
index 887cce2..c20977a 100644
--- a/setup.py
+++ b/setup.py
@@ -2,6 +2,7 @@
 
 import glob
 import os
+import sys
 
 import torch
 from setuptools import find_packages, setup
@@ -37,6 +38,8 @@ def get_extensions():
     else:
         # raise NotImplementedError('Cuda is not available')
         pass
+
+    extra_compile_args['cxx'].append('-fopenmp')
 
     sources = [os.path.join(extensions_dir, s) for s in sources]
     include_dirs = [extensions_dir]
diff --git a/src/cpu/dcn_v2_cpu.cpp b/src/cpu/dcn_v2_cpu.cpp
index 8d76c28..25b6e9c 100644
--- a/src/cpu/dcn_v2_cpu.cpp
+++ b/src/cpu/dcn_v2_cpu.cpp
@@ -36,11 +36,11 @@ dcn_v2_cpu_forward(const at::Tensor &input,
                    const int deformable_group)
 {
     // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));
-    /*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
-    AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
-    AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
-    AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
-    AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");*/
+    /*AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
+    AT_ASSERTM(weight.is_cuda(), "weight must be a CUDA tensor");
+    AT_ASSERTM(bias.is_cuda(), "bias must be a CUDA tensor");
+    AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor");
+    AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor");*/
 
     const int batch = input.size(0);
     const int channels = input.size(1);
@@ -126,11 +126,11 @@ std::vector<at::Tensor> dcn_v2_cpu_backward(const at::Tensor &input,
     THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
     THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");
 
-    /*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
-    AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
-    AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
-    AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
-    AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");*/
+    /*AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
+    AT_ASSERTM(weight.is_cuda(), "weight must be a CUDA tensor");
+    AT_ASSERTM(bias.is_cuda(), "bias must be a CUDA tensor");
+    AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor");
+    AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor");*/
 
     const int batch = input.size(0);
     const int channels = input.size(1);
diff --git a/src/cpu/dcn_v2_psroi_pooling_cpu.cpp b/src/cpu/dcn_v2_psroi_pooling_cpu.cpp
index 6e41aae..553cb35 100644
--- a/src/cpu/dcn_v2_psroi_pooling_cpu.cpp
+++ b/src/cpu/dcn_v2_psroi_pooling_cpu.cpp
@@ -288,11 +288,11 @@ dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
                                  const int sample_per_part,
                                  const float trans_std)
 {
-    /*AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
-    AT_ASSERTM(bbox.type().is_cuda(), "rois must be a CUDA tensor");
-    AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");*/
+    /*AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
+    AT_ASSERTM(bbox.is_cuda(), "rois must be a CUDA tensor");
+    AT_ASSERTM(trans.is_cuda(), "trans must be a CUDA tensor");*/
 
-    const int batch = input.size(0);
+    // const int batch = input.size(0);
     const int channels = input.size(1);
     const int height = input.size(2);
     const int width = input.size(3);
@@ -321,17 +321,17 @@ dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
     /*dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
     dim3 block(512);*/
 
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cpu_forward", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "dcn_v2_psroi_pooling_cpu_forward", [&] {
         DeformablePSROIPoolForwardKernelCpu<scalar_t>(
             out_size,
-            input.contiguous().data<scalar_t>(),
+            input.contiguous().data_ptr<scalar_t>(),
             spatial_scale,
             channels,
             height, width,
             pooled_height,
             pooled_width,
-            bbox.contiguous().data<scalar_t>(),
-            trans.contiguous().data<scalar_t>(),
+            bbox.contiguous().data_ptr<scalar_t>(),
+            trans.contiguous().data_ptr<scalar_t>(),
             no_trans,
             trans_std,
             sample_per_part,
@@ -340,8 +340,8 @@ dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
             part_size,
             num_classes,
             channels_each_class,
-            out.data<scalar_t>(),
-            top_count.data<scalar_t>());
+            out.data_ptr<scalar_t>(),
+            top_count.data_ptr<scalar_t>());
     });
     //THCudaCheck(cudaGetLastError());
     return std::make_tuple(out, top_count);
@@ -362,11 +362,11 @@ dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
                                   const int sample_per_part,
                                   const float trans_std)
 {
-    /*AT_ASSERTM(out_grad.type().is_cuda(), "out_grad must be a CUDA tensor");
-    AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
-    AT_ASSERTM(bbox.type().is_cuda(), "bbox must be a CUDA tensor");
-    AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");
-    AT_ASSERTM(top_count.type().is_cuda(), "top_count must be a CUDA tensor");*/
+    /*AT_ASSERTM(out_grad.is_cuda(), "out_grad must be a CUDA tensor");
+    AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
+    AT_ASSERTM(bbox.is_cuda(), "bbox must be a CUDA tensor");
+    AT_ASSERTM(trans.is_cuda(), "trans must be a CUDA tensor");
+    AT_ASSERTM(top_count.is_cuda(), "top_count must be a CUDA tensor");*/
 
     const int batch = input.size(0);
     const int channels = input.size(1);
@@ -395,11 +395,11 @@ dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
     dim3 block(512);
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();*/
 
-    AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cpu_backward", [&] {
+    AT_DISPATCH_FLOATING_TYPES(out_grad.scalar_type(), "dcn_v2_psroi_pooling_cpu_backward", [&] {
         DeformablePSROIPoolBackwardAccKernelCpu<scalar_t>(
             out_size,
-            out_grad.contiguous().data<scalar_t>(),
-            top_count.contiguous().data<scalar_t>(),
+            out_grad.contiguous().data_ptr<scalar_t>(),
+            top_count.contiguous().data_ptr<scalar_t>(),
             num_bbox,
             spatial_scale,
             channels,
@@ -408,11 +408,11 @@ dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
             pooled_height,
             pooled_width,
             output_dim,
-            input_grad.contiguous().data<scalar_t>(),
-            trans_grad.contiguous().data<scalar_t>(),
-            input.contiguous().data<scalar_t>(),
-            bbox.contiguous().data<scalar_t>(),
-            trans.contiguous().data<scalar_t>(),
+            input_grad.contiguous().data_ptr<scalar_t>(),
+            trans_grad.contiguous().data_ptr<scalar_t>(),
+            input.contiguous().data_ptr<scalar_t>(),
+            bbox.contiguous().data_ptr<scalar_t>(),
+            trans.contiguous().data_ptr<scalar_t>(),
             no_trans,
             trans_std,
             sample_per_part,
diff --git a/src/cuda/dcn_v2_cuda.cu b/src/cuda/dcn_v2_cuda.cu
index a23180e..220f398 100644
--- a/src/cuda/dcn_v2_cuda.cu
+++ b/src/cuda/dcn_v2_cuda.cu
@@ -57,11 +57,11 @@ dcn_v2_cuda_forward(const at::Tensor &input,
 {
     using scalar_t = float;
     // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));
-    AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
-    AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
-    AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
-    AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
-    AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");
+    AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
+    AT_ASSERTM(weight.is_cuda(), "weight must be a CUDA tensor");
+    AT_ASSERTM(bias.is_cuda(), "bias must be a CUDA tensor");
+    AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor");
+    AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor");
 
     const int batch = input.size(0);
     const int channels = input.size(1);
@@ -108,12 +108,12 @@ dcn_v2_cuda_forward(const at::Tensor &input,
                           input_b, output_b,
                           columns_b, ones_b,
                           weight_b, bias_b,
-                          input.data<scalar_t>(),
-                          output.data<scalar_t>(),
-                          columns.data<scalar_t>(),
-                          ones.data<scalar_t>(),
-                          weight.data<scalar_t>(),
-                          bias.data<scalar_t>(),
+                          input.data_ptr<scalar_t>(),
+                          output.data_ptr<scalar_t>(),
+                          columns.data_ptr<scalar_t>(),
+                          ones.data_ptr<scalar_t>(),
+                          weight.data_ptr<scalar_t>(),
+                          bias.data_ptr<scalar_t>(),
                           channels * width * height,
                           channels_out * width_out * height_out,
                           channels * kernel_h * kernel_w * height_out * width_out,
@@ -137,14 +137,14 @@ dcn_v2_cuda_forward(const at::Tensor &input,
                          batch);
 
     modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(),
-                                     input.data<scalar_t>(),
-                                     offset.data<scalar_t>(),
-                                     mask.data<scalar_t>(),
+                                     input.data_ptr<scalar_t>(),
+                                     offset.data_ptr<scalar_t>(),
+                                     mask.data_ptr<scalar_t>(),
                                      batch, channels, height, width,
                                      height_out, width_out, kernel_h, kernel_w,
                                      pad_h, pad_w, stride_h, stride_w,
                                      dilation_h, dilation_w, deformable_group,
-                                     columns.data<scalar_t>());
+                                     columns.data_ptr<scalar_t>());
 
     long m = channels_out;
     long n = height_out * width_out;
@@ -219,11 +219,11 @@ std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
     THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
     THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");
 
-    AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
-    AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
-    AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
-    AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
-    AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");
+    AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
+    AT_ASSERTM(weight.is_cuda(), "weight must be a CUDA tensor");
+    AT_ASSERTM(bias.is_cuda(), "bias must be a CUDA tensor");
+    AT_ASSERTM(offset.is_cuda(), "offset must be a CUDA tensor");
+    AT_ASSERTM(mask.is_cuda(), "mask must be a CUDA tensor");
 
     const int batch = input.size(0);
     const int channels = input.size(1);
@@ -271,52 +271,52 @@ std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
         long k = channels_out;
 
         THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f,
-                         grad_output_n.data<scalar_t>(), n,
-                         weight.data<scalar_t>(), m, 0.0f,
-                         columns.data<scalar_t>(), n);
+                         grad_output_n.data_ptr<scalar_t>(), n,
+                         weight.data_ptr<scalar_t>(), m, 0.0f,
+                         columns.data_ptr<scalar_t>(), n);
 
         // gradient w.r.t. input coordinate data
         modulated_deformable_col2im_coord_cuda(c10::cuda::getCurrentCUDAStream(),
-                                               columns.data<scalar_t>(),
-                                               input_n.data<scalar_t>(),
-                                               offset_n.data<scalar_t>(),
-                                               mask_n.data<scalar_t>(),
+                                               columns.data_ptr<scalar_t>(),
+                                               input_n.data_ptr<scalar_t>(),
+                                               offset_n.data_ptr<scalar_t>(),
+                                               mask_n.data_ptr<scalar_t>(),
                                                1, channels, height, width,
                                                height_out, width_out, kernel_h, kernel_w,
                                                pad_h, pad_w, stride_h, stride_w,
                                                dilation_h, dilation_w, deformable_group,
-                                               grad_offset_n.data<scalar_t>(),
-                                               grad_mask_n.data<scalar_t>());
+                                               grad_offset_n.data_ptr<scalar_t>(),
+                                               grad_mask_n.data_ptr<scalar_t>());
 
         // gradient w.r.t. input data
         modulated_deformable_col2im_cuda(c10::cuda::getCurrentCUDAStream(),
-                                         columns.data<scalar_t>(),
-                                         offset_n.data<scalar_t>(),
-                                         mask_n.data<scalar_t>(),
+                                         columns.data_ptr<scalar_t>(),
+                                         offset_n.data_ptr<scalar_t>(),
+                                         mask_n.data_ptr<scalar_t>(),
                                          1, channels, height, width,
                                          height_out, width_out, kernel_h, kernel_w,
                                          pad_h, pad_w, stride_h, stride_w,
                                          dilation_h, dilation_w, deformable_group,
-                                         grad_input_n.data<scalar_t>());
+                                         grad_input_n.data_ptr<scalar_t>());
 
         // gradient w.r.t. weight, dWeight should accumulate across the batch and group
         modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(),
-                                         input_n.data<scalar_t>(),
-                                         offset_n.data<scalar_t>(),
-                                         mask_n.data<scalar_t>(),
+                                         input_n.data_ptr<scalar_t>(),
+                                         offset_n.data_ptr<scalar_t>(),
+                                         mask_n.data_ptr<scalar_t>(),
                                          1, channels, height, width,
                                          height_out, width_out, kernel_h, kernel_w,
                                          pad_h, pad_w, stride_h, stride_w,
                                          dilation_h, dilation_w, deformable_group,
-                                         columns.data<scalar_t>());
+                                         columns.data_ptr<scalar_t>());
 
         long m_ = channels_out;
         long n_ = channels * kernel_h * kernel_w;
         long k_ = height_out * width_out;
         THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,
-                         columns.data<scalar_t>(), k_,
-                         grad_output_n.data<scalar_t>(), k_, 1.0f,
-                         grad_weight.data<scalar_t>(), n_);
+                         columns.data_ptr<scalar_t>(), k_,
+                         grad_output_n.data_ptr<scalar_t>(), k_, 1.0f,
+                         grad_weight.data_ptr<scalar_t>(), n_);
 
         // gradient w.r.t. bias
         // long m_ = channels_out;
@@ -324,15 +324,15 @@ std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
         // THCudaBlas_Sgemm(state,
         //                  't', 'n',
         //                  k_, m_, 1, 1.0f,
-        //                  grad_output_n.data<scalar_t>(), k_,
-        //                  ones.data<scalar_t>(), 1, 1.0f,
-        //                  grad_bias.data<scalar_t>(), 1);
+        //                  grad_output_n.data_ptr<scalar_t>(), k_,
+        //                  ones.data_ptr<scalar_t>(), 1, 1.0f,
+        //                  grad_bias.data_ptr<scalar_t>(), 1);
         THCudaBlas_Sgemm(state,
                          'N', 'N', 1, m_, k_, 1.0f,
-                         ones.data<scalar_t>(), 1,
-                         grad_output_n.data<scalar_t>(), k_,
+                         ones.data_ptr<scalar_t>(), 1,
+                         grad_output_n.data_ptr<scalar_t>(), k_,
                          1.0f,
-                         grad_bias.data<scalar_t>(), 1);
+                         grad_bias.data_ptr<scalar_t>(), 1);
     }
 
     return {
diff --git a/src/cuda/dcn_v2_psroi_pooling_cuda.cu b/src/cuda/dcn_v2_psroi_pooling_cuda.cu
index 8f08f6a..7039868 100644
--- a/src/cuda/dcn_v2_psroi_pooling_cuda.cu
+++ b/src/cuda/dcn_v2_psroi_pooling_cuda.cu
@@ -281,9 +281,9 @@ dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,
                                   const int sample_per_part,
                                   const float trans_std)
 {
-    AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
-    AT_ASSERTM(bbox.type().is_cuda(), "rois must be a CUDA tensor");
-    AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");
+    AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
+    AT_ASSERTM(bbox.is_cuda(), "rois must be a CUDA tensor");
+    AT_ASSERTM(trans.is_cuda(), "trans must be a CUDA tensor");
 
     const int batch = input.size(0);
     const int channels = input.size(1);
@@ -314,17 +314,17 @@ dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,
     dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
     dim3 block(512);
 
-    AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cuda_forward", [&] {
+    AT_DISPATCH_FLOATING_TYPES(input.scalar_type(), "dcn_v2_psroi_pooling_cuda_forward", [&] {
         DeformablePSROIPoolForwardKernelCuda<scalar_t><<<grid, block, 0, stream>>>(
             out_size,
-            input.contiguous().data<scalar_t>(),
+            input.contiguous().data_ptr<scalar_t>(),
             spatial_scale,
             channels,
             height, width,
             pooled_height,
             pooled_width,
-            bbox.contiguous().data<scalar_t>(),
-            trans.contiguous().data<scalar_t>(),
+            bbox.contiguous().data_ptr<scalar_t>(),
+            trans.contiguous().data_ptr<scalar_t>(),
             no_trans,
             trans_std,
             sample_per_part,
@@ -333,8 +333,8 @@ dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,
             part_size,
             num_classes,
             channels_each_class,
-            out.data<scalar_t>(),
-            top_count.data<scalar_t>());
+            out.data_ptr<scalar_t>(),
+            top_count.data_ptr<scalar_t>());
     });
     THCudaCheck(cudaGetLastError());
     return std::make_tuple(out, top_count);
@@ -355,11 +355,11 @@ dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,
                                    const int sample_per_part,
                                    const float trans_std)
 {
-    AT_ASSERTM(out_grad.type().is_cuda(), "out_grad must be a CUDA tensor");
-    AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
-    AT_ASSERTM(bbox.type().is_cuda(), "bbox must be a CUDA tensor");
-    AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");
-    AT_ASSERTM(top_count.type().is_cuda(), "top_count must be a CUDA tensor");
+    AT_ASSERTM(out_grad.is_cuda(), "out_grad must be a CUDA tensor");
+    AT_ASSERTM(input.is_cuda(), "input must be a CUDA tensor");
+    AT_ASSERTM(bbox.is_cuda(), "bbox must be a CUDA tensor");
+    AT_ASSERTM(trans.is_cuda(), "trans must be a CUDA tensor");
+    AT_ASSERTM(top_count.is_cuda(), "top_count must be a CUDA tensor");
 
     const int batch = input.size(0);
     const int channels = input.size(1);
@@ -388,11 +388,11 @@ dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,
     dim3 block(512);
     cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-    AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cuda_backward", [&] {
+    AT_DISPATCH_FLOATING_TYPES(out_grad.scalar_type(), "dcn_v2_psroi_pooling_cuda_backward", [&] {
         DeformablePSROIPoolBackwardAccKernelCuda<scalar_t><<<grid, block, 0, stream>>>(
             out_size,
-            out_grad.contiguous().data<scalar_t>(),
-            top_count.contiguous().data<scalar_t>(),
+            out_grad.contiguous().data_ptr<scalar_t>(),
+            top_count.contiguous().data_ptr<scalar_t>(),
             num_bbox,
             spatial_scale,
             channels,
@@ -401,11 +401,11 @@ dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,
             pooled_height,
             pooled_width,
             output_dim,
-            input_grad.contiguous().data<scalar_t>(),
-            trans_grad.contiguous().data<scalar_t>(),
-            input.contiguous().data<scalar_t>(),
-            bbox.contiguous().data<scalar_t>(),
-            trans.contiguous().data<scalar_t>(),
+            input_grad.contiguous().data_ptr<scalar_t>(),
+            trans_grad.contiguous().data_ptr<scalar_t>(),
+            input.contiguous().data_ptr<scalar_t>(),
+            bbox.contiguous().data_ptr<scalar_t>(),
+            trans.contiguous().data_ptr<scalar_t>(),
             no_trans,
             trans_std,
             sample_per_part,
diff --git a/src/dcn_v2.h b/src/dcn_v2.h
index de670bf..cffa4e9 100644
--- a/src/dcn_v2.h
+++ b/src/dcn_v2.h
@@ -22,7 +22,7 @@ dcn_v2_forward(const at::Tensor &input,
                const int dilation_w,
                const int deformable_group)
 {
-    if (input.type().is_cuda())
+    if (input.is_cuda())
     {
 #ifdef WITH_CUDA
         return dcn_v2_cuda_forward(input, weight, bias, offset, mask,
@@ -58,7 +58,7 @@ dcn_v2_backward(const at::Tensor &input,
                 int dilation_h,
                 int dilation_w,
                 int deformable_group)
 {
-    if (input.type().is_cuda())
+    if (input.is_cuda())
     {
 #ifdef WITH_CUDA
         return dcn_v2_cuda_backward(input,
@@ -104,7 +104,7 @@ dcn_v2_psroi_pooling_forward(const at::Tensor &input,
                              const int sample_per_part,
                              const float trans_std)
 {
-    if (input.type().is_cuda())
+    if (input.is_cuda())
     {
 #ifdef WITH_CUDA
         return dcn_v2_psroi_pooling_cuda_forward(input,
@@ -152,7 +152,7 @@ dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad,
                               const int sample_per_part,
                               const float trans_std)
 {
-    if (input.type().is_cuda())
+    if (input.is_cuda())
     {
 #ifdef WITH_CUDA
         return dcn_v2_psroi_pooling_cuda_backward(out_grad,
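
Note: every hunk above applies the same three migrations for newer PyTorch C++ APIs: Tensor::type().is_cuda() becomes Tensor::is_cuda(), AT_DISPATCH_FLOATING_TYPES receives Tensor::scalar_type() instead of the deprecated Tensor::type(), and typed pointers come from Tensor::data_ptr<scalar_t>() rather than the removed Tensor::data<scalar_t>(). A minimal self-contained sketch of the modernized idiom follows; scale_cpu is a hypothetical toy function for illustration, not code from this patch:

    #include <ATen/ATen.h>
    #include <ATen/Dispatch.h>

    // Elementwise scale of a CPU tensor, written against the post-deprecation API:
    // is_cuda() queried on the Tensor itself, scalar_type() used for dispatch,
    // and data_ptr<scalar_t>() used for typed raw pointers.
    at::Tensor scale_cpu(const at::Tensor &input, double factor)
    {
        AT_ASSERTM(!input.is_cuda(), "input must be a CPU tensor"); // not input.type().is_cuda()
        at::Tensor in = input.contiguous();
        at::Tensor out = at::empty_like(in);
        const int64_t n = in.numel();

        AT_DISPATCH_FLOATING_TYPES(in.scalar_type(), "scale_cpu", [&] {
            const scalar_t *src = in.data_ptr<scalar_t>(); // was: in.data<scalar_t>()
            scalar_t *dst = out.data_ptr<scalar_t>();
            for (int64_t i = 0; i < n; ++i)
                dst[i] = src[i] * static_cast<scalar_t>(factor);
        });
        return out;
    }

The -fopenmp flag added in setup.py is appended only to the 'cxx' compile args (not the nvcc flags), so the CPU extension sources can use OpenMP pragmas such as #pragma omp parallel for when built with GCC or Clang.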