diff --git a/.gitignore b/.gitignore index b1e9421..f1202fa 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ *.so *.o *pyc -_ext \ No newline at end of file +_ext +build +DCNv2.egg-info \ No newline at end of file diff --git a/README.md b/README.md index 0ddcf18..e7d23d7 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -## Deformable Convolutional Networks V2 with Pytorch +## Deformable Convolutional Networks V2 with Pytorch 1.0 ### Build ```bash diff --git a/build.py b/build.py deleted file mode 100644 index b93f2a9..0000000 --- a/build.py +++ /dev/null @@ -1,43 +0,0 @@ -import os -import torch -from torch.utils.ffi import create_extension - - -sources = ['src/dcn_v2.c'] -headers = ['src/dcn_v2.h'] -defines = [] -with_cuda = False - -extra_objects = [] -if torch.cuda.is_available(): - print('Including CUDA code.') - sources += ['src/dcn_v2_cuda.c'] - headers += ['src/dcn_v2_cuda.h'] - defines += [('WITH_CUDA', None)] - extra_objects += ['src/cuda/dcn_v2_im2col_cuda.cu.o'] - extra_objects += ['src/cuda/dcn_v2_psroi_pooling_cuda.cu.o'] - with_cuda = True -else: - raise ValueError('CUDA is not available') - -extra_compile_args = ['-fopenmp', '-std=c99'] - -this_file = os.path.dirname(os.path.realpath(__file__)) -print(this_file) -sources = [os.path.join(this_file, fname) for fname in sources] -headers = [os.path.join(this_file, fname) for fname in headers] -extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] - -ffi = create_extension( - '_ext.dcn_v2', - headers=headers, - sources=sources, - define_macros=defines, - relative_to=__file__, - with_cuda=with_cuda, - extra_objects=extra_objects, - extra_compile_args=extra_compile_args -) - -if __name__ == '__main__': - ffi.build() diff --git a/build_double.py b/build_double.py deleted file mode 100644 index 02f3912..0000000 --- a/build_double.py +++ /dev/null @@ -1,43 +0,0 @@ -import os -import torch -from torch.utils.ffi import create_extension - - -sources = ['src/dcn_v2_double.c'] -headers = ['src/dcn_v2_double.h'] -defines = [] -with_cuda = False - -extra_objects = [] -if torch.cuda.is_available(): - print('Including CUDA code.') - sources += ['src/dcn_v2_cuda_double.c'] - headers += ['src/dcn_v2_cuda_double.h'] - defines += [('WITH_CUDA', None)] - extra_objects += ['src/cuda/dcn_v2_im2col_cuda_double.cu.o'] - extra_objects += ['src/cuda/dcn_v2_psroi_pooling_cuda_double.cu.o'] - with_cuda = True -else: - raise ValueError('CUDA is not available') - -extra_compile_args = ['-fopenmp', '-std=c99'] - -this_file = os.path.dirname(os.path.realpath(__file__)) -print(this_file) -sources = [os.path.join(this_file, fname) for fname in sources] -headers = [os.path.join(this_file, fname) for fname in headers] -extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] - -ffi = create_extension( - '_ext.dcn_v2_double', - headers=headers, - sources=sources, - define_macros=defines, - relative_to=__file__, - with_cuda=with_cuda, - extra_objects=extra_objects, - extra_compile_args=extra_compile_args -) - -if __name__ == '__main__': - ffi.build() diff --git a/dcn_v2.py b/dcn_v2.py index a409a2b..982bef5 100644 --- a/dcn_v2.py +++ b/dcn_v2.py @@ -3,13 +3,56 @@ from __future__ import print_function from __future__ import division -import torch import math +import torch from torch import nn +from torch.autograd import Function from torch.nn.modules.utils import _pair +from torch.autograd.function import once_differentiable + +import _ext as _backend + + +class _DCNv2(Function): + @staticmethod + def 
forward(ctx, input, offset, mask, weight, bias, + stride, padding, dilation, deformable_groups): + ctx.stride = _pair(stride) + ctx.padding = _pair(padding) + ctx.dilation = _pair(dilation) + ctx.kernel_size = _pair(weight.shape[2:4]) + ctx.deformable_groups = deformable_groups + output = _backend.dcn_v2_forward(input, weight, bias, + offset, mask, + ctx.kernel_size[0], ctx.kernel_size[1], + ctx.stride[0], ctx.stride[1], + ctx.padding[0], ctx.padding[1], + ctx.dilation[0], ctx.dilation[1], + ctx.deformable_groups) + ctx.save_for_backward(input, offset, mask, weight, bias) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, offset, mask, weight, bias = ctx.saved_tensors + grad_input, grad_offset, grad_mask, grad_weight, grad_bias = \ + _backend.dcn_v2_backward(input, weight, + bias, + offset, mask, + grad_output, + ctx.kernel_size[0], ctx.kernel_size[1], + ctx.stride[0], ctx.stride[1], + ctx.padding[0], ctx.padding[1], + ctx.dilation[0], ctx.dilation[1], + ctx.deformable_groups) + + return grad_input, grad_offset, grad_mask, grad_weight, grad_bias,\ + None, None, None, None, + + +dcn_v2_conv = _DCNv2.apply -from dcn_v2_func import DCNv2Function -from dcn_v2_func import DCNv2PoolingFunction class DCNv2(nn.Module): @@ -19,12 +62,13 @@ def __init__(self, in_channels, out_channels, self.in_channels = in_channels self.out_channels = out_channels self.kernel_size = _pair(kernel_size) - self.stride = stride - self.padding = padding - self.dilation = dilation + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) self.deformable_groups = deformable_groups - self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size)) + self.weight = nn.Parameter(torch.Tensor( + out_channels, in_channels, *self.kernel_size)) self.bias = nn.Parameter(torch.Tensor(out_channels)) self.reset_parameters() @@ -37,8 +81,17 @@ def reset_parameters(self): self.bias.data.zero_() def forward(self, input, offset, mask): - func = DCNv2Function(self.stride, self.padding, self.dilation, self.deformable_groups) - return func(input, offset, mask, self.weight, self.bias) + assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \ + offset.shape[1] + assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \ + mask.shape[1] + return dcn_v2_conv(input, offset, mask, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.deformable_groups) class DCN(DCNv2): @@ -49,11 +102,12 @@ def __init__(self, in_channels, out_channels, super(DCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, deformable_groups) + channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1] self.conv_offset_mask = nn.Conv2d(self.in_channels, - self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + channels_, kernel_size=self.kernel_size, - stride=(self.stride, self.stride), - padding=(self.padding, self.padding), + stride=self.stride, + padding=self.padding, bias=True) self.init_offset() @@ -66,8 +120,68 @@ def forward(self, input): o1, o2, mask = torch.chunk(out, 3, dim=1) offset = torch.cat((o1, o2), dim=1) mask = torch.sigmoid(mask) - func = DCNv2Function(self.stride, self.padding, self.dilation, self.deformable_groups) - return func(input, offset, mask, self.weight, self.bias) + return dcn_v2_conv(input, offset, mask, + self.weight, self.bias, + self.stride, + self.padding, + self.dilation, + 
self.deformable_groups) + + + +class _DCNv2Pooling(Function): + @staticmethod + def forward(ctx, input, rois, offset, + spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size=1, + part_size=None, + sample_per_part=4, + trans_std=.0): + ctx.spatial_scale = spatial_scale + ctx.no_trans = int(no_trans) + ctx.output_dim = output_dim + ctx.group_size = group_size + ctx.pooled_size = pooled_size + ctx.part_size = pooled_size if part_size is None else part_size + ctx.sample_per_part = sample_per_part + ctx.trans_std = trans_std + + output, output_count = \ + _backend.dcn_v2_psroi_pooling_forward(input, rois, offset, + ctx.no_trans, ctx.spatial_scale, + ctx.output_dim, ctx.group_size, + ctx.pooled_size, ctx.part_size, + ctx.sample_per_part, ctx.trans_std) + ctx.save_for_backward(input, rois, offset, output_count) + return output + + @staticmethod + @once_differentiable + def backward(ctx, grad_output): + input, rois, offset, output_count = ctx.saved_tensors + grad_input, grad_offset = \ + _backend.dcn_v2_psroi_pooling_backward(grad_output, + input, + rois, + offset, + output_count, + ctx.no_trans, + ctx.spatial_scale, + ctx.output_dim, + ctx.group_size, + ctx.pooled_size, + ctx.part_size, + ctx.sample_per_part, + ctx.trans_std) + + return grad_input, None, grad_offset, \ + None, None, None, None, None, None, None, None + + +dcn_v2_pooling = _DCNv2Pooling.apply class DCNv2Pooling(nn.Module): @@ -90,20 +204,21 @@ def __init__(self, self.part_size = pooled_size if part_size is None else part_size self.sample_per_part = sample_per_part self.trans_std = trans_std - self.func = DCNv2PoolingFunction(self.spatial_scale, - self.pooled_size, - self.output_dim, - self.no_trans, - self.group_size, - self.part_size, - self.sample_per_part, - self.trans_std) - - def forward(self, data, rois, offset): + def forward(self, input, rois, offset): + assert input.shape[1] == self.output_dim if self.no_trans: - offset = data.new() - return self.func(data, rois, offset) + offset = input.new() + return dcn_v2_pooling(input, rois, offset, + self.spatial_scale, + self.pooled_size, + self.output_dim, + self.no_trans, + self.group_size, + self.part_size, + self.sample_per_part, + self.trans_std) + class DCNPooling(DCNv2Pooling): @@ -129,43 +244,60 @@ def __init__(self, self.deform_fc_dim = deform_fc_dim if not no_trans: - self.func_offset = DCNv2PoolingFunction(self.spatial_scale, - self.pooled_size, - self.output_dim, - True, - self.group_size, - self.part_size, - self.sample_per_part, - self.trans_std) - self.offset_fc = nn.Sequential( - nn.Linear(self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim), + self.offset_mask_fc = nn.Sequential( + nn.Linear(self.pooled_size * self.pooled_size * + self.output_dim, self.deform_fc_dim), nn.ReLU(inplace=True), nn.Linear(self.deform_fc_dim, self.deform_fc_dim), nn.ReLU(inplace=True), - nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 2) + nn.Linear(self.deform_fc_dim, self.pooled_size * + self.pooled_size * 3) ) - self.offset_fc[4].weight.data.zero_() - self.offset_fc[4].bias.data.zero_() - self.mask_fc = nn.Sequential( - nn.Linear(self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim), - nn.ReLU(inplace=True), - nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 1), - nn.Sigmoid() - ) - self.mask_fc[2].weight.data.zero_() - self.mask_fc[2].bias.data.zero_() + self.offset_mask_fc[4].weight.data.zero_() + self.offset_mask_fc[4].bias.data.zero_() - def forward(self, data, rois): - if 
self.no_trans: - offset = data.new() - else: + def forward(self, input, rois): + offset = input.new() + + if not self.no_trans: + + # do roi_align first n = rois.shape[0] - offset = data.new() - x = self.func_offset(data, rois, offset) - offset = self.offset_fc(x.view(n, -1)) - offset = offset.view(n, 2, self.pooled_size, self.pooled_size) - mask = self.mask_fc(x.view(n, -1)) - mask = mask.view(n, 1, self.pooled_size, self.pooled_size) - feat = self.func(data, rois, offset) * mask - return feat - return self.func(data, rois, offset) \ No newline at end of file + roi = dcn_v2_pooling(input, rois, offset, + self.spatial_scale, + self.pooled_size, + self.output_dim, + True, # no trans + self.group_size, + self.part_size, + self.sample_per_part, + self.trans_std) + + # build mask and offset + offset_mask = self.offset_mask_fc(roi.view(n, -1)) + offset_mask = offset_mask.view( + n, 3, self.pooled_size, self.pooled_size) + o1, o2, mask = torch.chunk(offset_mask, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + + # do pooling with offset and mask + return dcn_v2_pooling(input, rois, offset, + self.spatial_scale, + self.pooled_size, + self.output_dim, + self.no_trans, + self.group_size, + self.part_size, + self.sample_per_part, + self.trans_std) * mask + # only roi_align + return dcn_v2_pooling(input, rois, offset, + self.spatial_scale, + self.pooled_size, + self.output_dim, + self.no_trans, + self.group_size, + self.part_size, + self.sample_per_part, + self.trans_std) diff --git a/dcn_v2_func.py b/dcn_v2_func.py deleted file mode 100644 index 84d1313..0000000 --- a/dcn_v2_func.py +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env python -from __future__ import absolute_import -from __future__ import print_function -from __future__ import division - -import torch -from torch.autograd import Function - -from _ext import dcn_v2 as _backend -# from _ext import dcn_v2_double as _backend - - -class DCNv2Function(Function): - - def __init__(self, stride, padding, dilation=1, deformable_groups=1): - super(DCNv2Function, self).__init__() - self.stride = stride - self.padding = padding - self.dilation = dilation - self.deformable_groups = deformable_groups - - def forward(self, input, offset, mask, weight, bias): - if not input.is_cuda: - raise NotImplementedError - if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad: - self.save_for_backward(input, offset, mask, weight, bias) - output = input.new(*self._infer_shape(input, weight)) - self._bufs = [input.new(), input.new()] - _backend.dcn_v2_cuda_forward(input, weight, - bias, self._bufs[0], - offset, mask, - output, self._bufs[1], - weight.shape[2], weight.shape[3], - self.stride, self.stride, - self.padding, self.padding, - self.dilation, self.dilation, - self.deformable_groups) - return output - - def backward(self, grad_output): - if not grad_output.is_cuda: - raise NotImplementedError - input, offset, mask, weight, bias = self.saved_tensors - grad_input = input.new(*input.size()).zero_() - grad_offset = offset.new(*offset.size()).zero_() - grad_mask = mask.new(*mask.size()).zero_() - grad_weight = weight.new(*weight.size()).zero_() - grad_bias = bias.new(*bias.size()).zero_() - _backend.dcn_v2_cuda_backward(input, weight, - bias, self._bufs[0], - offset, mask, - self._bufs[1], - grad_input, grad_weight, - grad_bias, grad_offset, - grad_mask, grad_output, - weight.shape[2], weight.shape[3], - self.stride, self.stride, - self.padding, self.padding, - self.dilation, self.dilation, - 
self.deformable_groups) - - return grad_input, grad_offset, grad_mask, grad_weight, grad_bias - - def _infer_shape(self, input, weight): - n = input.size(0) - channels_out = weight.size(0) - height, width = input.shape[2:4] - kernel_h, kernel_w = weight.shape[2:4] - height_out = (height + 2 * self.padding - - (self.dilation * (kernel_h - 1) + 1)) // self.stride + 1 - width_out = (width + 2 * self.padding - (self.dilation * - (kernel_w - 1) + 1)) // self.stride + 1 - return (n, channels_out, height_out, width_out) - - -class DCNv2PoolingFunction(Function): - - def __init__(self, - spatial_scale, - pooled_size, - output_dim, - no_trans, - group_size=1, - part_size=None, - sample_per_part=4, - trans_std=.0): - super(DCNv2PoolingFunction, self).__init__() - self.spatial_scale = spatial_scale - self.pooled_size = pooled_size - self.output_dim = output_dim - self.no_trans = no_trans - self.group_size = group_size - self.part_size = pooled_size if part_size is None else part_size - self.sample_per_part = sample_per_part - self.trans_std = trans_std - - assert self.trans_std >= 0.0 and self.trans_std <= 1.0 - - def forward(self, data, rois, offset): - if not data.is_cuda: - raise NotImplementedError - - output = data.new(*self._infer_shape(data, rois)) - output_count = data.new(*self._infer_shape(data, rois)) - _backend.dcn_v2_psroi_pooling_cuda_forward(data, rois, offset, - output, output_count, - self.no_trans, self.spatial_scale, - self.output_dim, self.group_size, - self.pooled_size, self.part_size, - self.sample_per_part, self.trans_std) - - if data.requires_grad or rois.requires_grad or offset.requires_grad: - self.save_for_backward(data, rois, offset, output_count) - - return output - - def backward(self, grad_output): - if not grad_output.is_cuda: - raise NotImplementedError - - data, rois, offset, output_count = self.saved_tensors - grad_input = data.new(*data.size()).zero_() - grad_offset = offset.new(*offset.size()).zero_() - - _backend.dcn_v2_psroi_pooling_cuda_backward(grad_output, - data, - rois, - offset, - output_count, - grad_input, - grad_offset, - self.no_trans, - self.spatial_scale, - self.output_dim, - self.group_size, - self.pooled_size, - self.part_size, - self.sample_per_part, - self.trans_std) - return grad_input, None, grad_offset - - def _infer_shape(self, data, rois): - # _, c, h, w = data.shape[:4] - c = data.shape[1] - n = rois.shape[0] - return (n, self.output_dim, self.pooled_size, self.pooled_size) diff --git a/dist/DCNv2-0.1-py2.7-linux-x86_64.egg b/dist/DCNv2-0.1-py2.7-linux-x86_64.egg new file mode 100644 index 0000000..933df10 Binary files /dev/null and b/dist/DCNv2-0.1-py2.7-linux-x86_64.egg differ diff --git a/make.sh b/make.sh index d489f7c..f1f15c0 100755 --- a/make.sh +++ b/make.sh @@ -1,14 +1,2 @@ #!/usr/bin/env bash -cd src/cuda - -# compile dcn -nvcc -c -o dcn_v2_im2col_cuda.cu.o dcn_v2_im2col_cuda.cu -x cu -Xcompiler -fPIC -nvcc -c -o dcn_v2_im2col_cuda_double.cu.o dcn_v2_im2col_cuda_double.cu -x cu -Xcompiler -fPIC - -# compile dcn-roi-pooling -nvcc -c -o dcn_v2_psroi_pooling_cuda.cu.o dcn_v2_psroi_pooling_cuda.cu -x cu -Xcompiler -fPIC -nvcc -c -o dcn_v2_psroi_pooling_cuda_double.cu.o dcn_v2_psroi_pooling_cuda_double.cu -x cu -Xcompiler -fPIC - -cd - -python build.py -python build_double.py +python setup.py build develop diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..1082494 --- /dev/null +++ b/setup.py @@ -0,0 +1,66 @@ +#!/usr/bin/env python + +import os +import glob + +import torch + +from torch.utils.cpp_extension import 
CUDA_HOME
+from torch.utils.cpp_extension import CppExtension
+from torch.utils.cpp_extension import CUDAExtension
+
+from setuptools import find_packages
+from setuptools import setup
+
+requirements = ["torch", "torchvision"]
+
+def get_extensions():
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    extensions_dir = os.path.join(this_dir, "src")
+
+    main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
+    source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
+    source_cuda = glob.glob(os.path.join(extensions_dir, "cuda", "*.cu"))
+
+    sources = main_file + source_cpu
+    extension = CppExtension
+    extra_compile_args = {"cxx": []}
+    define_macros = []
+
+    if torch.cuda.is_available() and CUDA_HOME is not None:
+        extension = CUDAExtension
+        sources += source_cuda
+        define_macros += [("WITH_CUDA", None)]
+        extra_compile_args["nvcc"] = [
+            "-DCUDA_HAS_FP16=1",
+            "-D__CUDA_NO_HALF_OPERATORS__",
+            "-D__CUDA_NO_HALF_CONVERSIONS__",
+            "-D__CUDA_NO_HALF2_OPERATORS__",
+        ]
+    else:
+        raise NotImplementedError('CUDA is not available')
+
+    sources = [os.path.join(extensions_dir, s) for s in sources]
+    include_dirs = [extensions_dir]
+    ext_modules = [
+        extension(
+            "_ext",
+            sources,
+            include_dirs=include_dirs,
+            define_macros=define_macros,
+            extra_compile_args=extra_compile_args,
+        )
+    ]
+    return ext_modules
+
+setup(
+    name="DCNv2",
+    version="0.1",
+    author="charlesshang",
+    url="https://github.com/charlesshang/DCNv2",
+    description="deformable convolutional networks",
+    packages=find_packages(exclude=("configs", "tests",)),
+    # install_requires=requirements,
+    ext_modules=get_extensions(),
+    cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+)
\ No newline at end of file
diff --git a/src/cpu/dcn_v2_cpu.cpp b/src/cpu/dcn_v2_cpu.cpp
new file mode 100644
index 0000000..a68ccef
--- /dev/null
+++ b/src/cpu/dcn_v2_cpu.cpp
@@ -0,0 +1,74 @@
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+
+at::Tensor
+dcn_v2_cpu_forward(const at::Tensor &input,
+                   const at::Tensor &weight,
+                   const at::Tensor &bias,
+                   const at::Tensor &offset,
+                   const at::Tensor &mask,
+                   const int kernel_h,
+                   const int kernel_w,
+                   const int stride_h,
+                   const int stride_w,
+                   const int pad_h,
+                   const int pad_w,
+                   const int dilation_h,
+                   const int dilation_w,
+                   const int deformable_group)
+{
+    AT_ERROR("Not implemented on the CPU");
+}
+
+std::vector<at::Tensor>
+dcn_v2_cpu_backward(const at::Tensor &input,
+                    const at::Tensor &weight,
+                    const at::Tensor &bias,
+                    const at::Tensor &offset,
+                    const at::Tensor &mask,
+                    const at::Tensor &grad_output,
+                    int kernel_h, int kernel_w,
+                    int stride_h, int stride_w,
+                    int pad_h, int pad_w,
+                    int dilation_h, int dilation_w,
+                    int deformable_group)
+{
+    AT_ERROR("Not implemented on the CPU");
+}
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
+                                 const at::Tensor &bbox,
+                                 const at::Tensor &trans,
+                                 const int no_trans,
+                                 const float spatial_scale,
+                                 const int output_dim,
+                                 const int group_size,
+                                 const int pooled_size,
+                                 const int part_size,
+                                 const int sample_per_part,
+                                 const float trans_std)
+{
+    AT_ERROR("Not implemented on the CPU");
+}
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
+                                  const at::Tensor &input,
+                                  const at::Tensor &bbox,
+                                  const at::Tensor &trans,
+                                  const at::Tensor &top_count,
+                                  const int no_trans,
+                                  const float spatial_scale,
+                                  const int output_dim,
+                                  const int group_size,
+                                  const int pooled_size,
+                                  const int part_size,
+                                  const int sample_per_part,
+                                  const float trans_std)
+{
+    AT_ERROR("Not implemented on the CPU");
+}
\ No newline at end of file
diff --git a/src/cpu/vision.h b/src/cpu/vision.h
new file mode 100644
index 0000000..d5fbf1f
--- /dev/null
+++ b/src/cpu/vision.h
@@ -0,0 +1,60 @@
+#pragma once
+#include <torch/extension.h>
+
+at::Tensor
+dcn_v2_cpu_forward(const at::Tensor &input,
+                   const at::Tensor &weight,
+                   const at::Tensor &bias,
+                   const at::Tensor &offset,
+                   const at::Tensor &mask,
+                   const int kernel_h,
+                   const int kernel_w,
+                   const int stride_h,
+                   const int stride_w,
+                   const int pad_h,
+                   const int pad_w,
+                   const int dilation_h,
+                   const int dilation_w,
+                   const int deformable_group);
+
+std::vector<at::Tensor>
+dcn_v2_cpu_backward(const at::Tensor &input,
+                    const at::Tensor &weight,
+                    const at::Tensor &bias,
+                    const at::Tensor &offset,
+                    const at::Tensor &mask,
+                    const at::Tensor &grad_output,
+                    int kernel_h, int kernel_w,
+                    int stride_h, int stride_w,
+                    int pad_h, int pad_w,
+                    int dilation_h, int dilation_w,
+                    int deformable_group);
+
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_cpu_forward(const at::Tensor &input,
+                                 const at::Tensor &bbox,
+                                 const at::Tensor &trans,
+                                 const int no_trans,
+                                 const float spatial_scale,
+                                 const int output_dim,
+                                 const int group_size,
+                                 const int pooled_size,
+                                 const int part_size,
+                                 const int sample_per_part,
+                                 const float trans_std);
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_cpu_backward(const at::Tensor &out_grad,
+                                  const at::Tensor &input,
+                                  const at::Tensor &bbox,
+                                  const at::Tensor &trans,
+                                  const at::Tensor &top_count,
+                                  const int no_trans,
+                                  const float spatial_scale,
+                                  const int output_dim,
+                                  const int group_size,
+                                  const int pooled_size,
+                                  const int part_size,
+                                  const int sample_per_part,
+                                  const float trans_std);
\ No newline at end of file
diff --git a/src/cuda/dcn_v2_cuda.cu b/src/cuda/dcn_v2_cuda.cu
new file mode 100644
index 0000000..d33cc0f
--- /dev/null
+++ b/src/cuda/dcn_v2_cuda.cu
@@ -0,0 +1,238 @@
+#include <vector>
+#include "cuda/dcn_v2_im2col_cuda.h"
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCAtomics.cuh>
+#include <THC/THCDeviceUtils.cuh>
+
+extern THCState *state;
+
+// author: Charles Shang
+// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
+
+at::Tensor
+dcn_v2_cuda_forward(const at::Tensor &input,
+                    const at::Tensor &weight,
+                    const at::Tensor &bias,
+                    const at::Tensor &offset,
+                    const at::Tensor &mask,
+                    const int kernel_h,
+                    const int kernel_w,
+                    const int stride_h,
+                    const int stride_w,
+                    const int pad_h,
+                    const int pad_w,
+                    const int dilation_h,
+                    const int dilation_w,
+                    const int deformable_group)
+{
+    // THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, weight, bias, offset, mask));
+    AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
+    AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
+    AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
+    AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
+    AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");
+
+    const int batch = input.size(0);
+    const int channels = input.size(1);
+    const int height = input.size(2);
+    const int width = input.size(3);
+
+    const int channels_out = weight.size(0);
+    const int channels_kernel = weight.size(1);
+    const int kernel_h_ = weight.size(2);
+    const int kernel_w_ = weight.size(3);
+
+    // printf("Kernels: %d %d %d %d\n", kernel_h_, kernel_w_, kernel_w, kernel_h);
+    // printf("Channels: %d %d\n", channels, channels_kernel);
+    // printf("Channels: %d %d\n", channels_out, channels_kernel);
+
+    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
+               "Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h, kernel_w, kernel_h_, kernel_w_);
+
+    AT_ASSERTM(channels == channels_kernel,
+               "Input shape and kernel channels won't match: (%d vs %d).", channels, channels_kernel);
+
+    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+    auto ones = at::ones({height_out, width_out}, input.options());
+    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
+    auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
+
+    using scalar_t = float;
+    for (int b = 0; b < batch; b++)
+    {
+        auto input_n = input.select(0, b);
+        auto offset_n = offset.select(0, b);
+        auto mask_n = mask.select(0, b);
+        auto output_n = output.select(0, b);
+
+        // Do Bias first:
+        // M,N,K are dims of matrix A and B
+        // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+        // (N x 1) (1 x M)
+        long m_ = channels_out;
+        long n_ = height_out * width_out;
+        long k_ = 1;
+        THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,
+                         ones.contiguous().data<scalar_t>(), k_,
+                         bias.contiguous().data<scalar_t>(), k_, 0.0f,
+                         output_n.data<scalar_t>(), n_);
+
+        modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
+                                         input_n.data<scalar_t>(),
+                                         offset_n.data<scalar_t>(),
+                                         mask_n.data<scalar_t>(),
+                                         1, channels, height, width,
+                                         height_out, width_out, kernel_h, kernel_w,
+                                         pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+                                         deformable_group,
+                                         columns.data<scalar_t>());
+
+        //(k * m) x (m * n)
+        // Y = WC
+        long m = channels_out;
+        long n = height_out * width_out;
+        long k = channels * kernel_h * kernel_w;
+        THCudaBlas_Sgemm(state, 'n', 'n', n, m, k, 1.0f,
+                         columns.data<scalar_t>(), n,
+                         weight.data<scalar_t>(), k, 1.0f,
+                         output_n.data<scalar_t>(), n);
+    }
+    return output;
+}
+
+std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
+                                             const at::Tensor &weight,
+                                             const at::Tensor &bias,
+                                             const at::Tensor &offset,
+                                             const at::Tensor &mask,
+                                             const at::Tensor &grad_output,
+                                             int kernel_h, int kernel_w,
+                                             int stride_h, int stride_w,
+                                             int pad_h, int pad_w,
+                                             int dilation_h, int dilation_w,
+                                             int deformable_group)
+{
+
+    THArgCheck(input.is_contiguous(), 1, "input tensor has to be contiguous");
+    THArgCheck(weight.is_contiguous(), 2, "weight tensor has to be contiguous");
+
+    AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
+    AT_ASSERTM(weight.type().is_cuda(), "weight must be a CUDA tensor");
+    AT_ASSERTM(bias.type().is_cuda(), "bias must be a CUDA tensor");
+    AT_ASSERTM(offset.type().is_cuda(), "offset must be a CUDA tensor");
+    AT_ASSERTM(mask.type().is_cuda(), "mask must be a CUDA tensor");
+
+    const int batch = input.size(0);
+    const int channels = input.size(1);
+    const int height = input.size(2);
+    const int width = input.size(3);
+
+    const int channels_out = weight.size(0);
+    const int channels_kernel = weight.size(1);
+    const int kernel_h_ = weight.size(2);
+    const int kernel_w_ = weight.size(3);
+
+    AT_ASSERTM(kernel_h_ == kernel_h && kernel_w_ == kernel_w,
+               "Input shape and kernel shape won't match: (%d x %d vs %d x %d).", kernel_h, kernel_w, kernel_h_, kernel_w_);
+
+    AT_ASSERTM(channels == channels_kernel,
+               "Input shape and kernel channels won't match: (%d vs %d).", channels, channels_kernel);
+
+    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+    auto ones = at::ones({height_out, width_out}, input.options());
+    auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
+    auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
+
+    auto grad_input = at::zeros_like(input);
+    auto grad_weight = at::zeros_like(weight);
+    auto grad_bias = at::zeros_like(bias);
+    auto grad_offset = at::zeros_like(offset);
+    auto grad_mask = at::zeros_like(mask);
+
+    using scalar_t = float;
+
+    for (int b = 0; b < batch; b++)
+    {
+        auto input_n = input.select(0, b);
+        auto offset_n = offset.select(0, b);
+        auto mask_n = mask.select(0, b);
+        auto grad_output_n = grad_output.select(0, b);
+        auto grad_input_n = grad_input.select(0, b);
+        auto grad_offset_n = grad_offset.select(0, b);
+        auto grad_mask_n = grad_mask.select(0, b);
+
+        long m = channels * kernel_h * kernel_w;
+        long n = height_out * width_out;
+        long k = channels_out;
+
+        THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f,
+                         grad_output_n.data<scalar_t>(), n,
+                         weight.data<scalar_t>(), m, 0.0f,
+                         columns.data<scalar_t>(), n);
+
+        // gradient w.r.t. input coordinate data
+        modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state),
+                                               columns.data<scalar_t>(),
+                                               input_n.data<scalar_t>(),
+                                               offset_n.data<scalar_t>(),
+                                               mask_n.data<scalar_t>(),
+                                               1, channels, height, width,
+                                               height_out, width_out, kernel_h, kernel_w,
+                                               pad_h, pad_w, stride_h, stride_w,
+                                               dilation_h, dilation_w, deformable_group,
+                                               grad_offset_n.data<scalar_t>(),
+                                               grad_mask_n.data<scalar_t>());
+        // gradient w.r.t. input data
+        modulated_deformable_col2im_cuda(THCState_getCurrentStream(state),
+                                         columns.data<scalar_t>(),
+                                         offset_n.data<scalar_t>(),
+                                         mask_n.data<scalar_t>(),
+                                         1, channels, height, width,
+                                         height_out, width_out, kernel_h, kernel_w,
+                                         pad_h, pad_w, stride_h, stride_w,
+                                         dilation_h, dilation_w, deformable_group,
+                                         grad_input_n.data<scalar_t>());
+
+        // gradient w.r.t. weight, dWeight should accumulate across the batch and group
+        modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
+                                         input_n.data<scalar_t>(),
+                                         offset_n.data<scalar_t>(),
+                                         mask_n.data<scalar_t>(),
+                                         1, channels, height, width,
+                                         height_out, width_out, kernel_h, kernel_w,
+                                         pad_h, pad_w, stride_h, stride_w,
+                                         dilation_h, dilation_w, deformable_group,
+                                         columns.data<scalar_t>());
+
+        long m_ = channels_out;
+        long n_ = channels * kernel_h * kernel_w;
+        long k_ = height_out * width_out;
+
+        THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,
+                         columns.data<scalar_t>(), k_,
+                         grad_output_n.data<scalar_t>(), k_, 1.0f,
+                         grad_weight.data<scalar_t>(), n_);
+
+        // gradient w.r.t. bias
+        // long m_ = channels_out;
+        // long k__ = height_out * width_out;
+        THCudaBlas_Sgemv(state,
+                         't',
+                         k_, m_, 1.0f,
+                         grad_output_n.data<scalar_t>(), k_,
+                         ones.data<scalar_t>(), 1, 1.0f,
+                         grad_bias.data<scalar_t>(), 1);
+    }
+
+    return {
+        grad_input, grad_offset, grad_mask, grad_weight, grad_bias
+    };
+}
\ No newline at end of file
diff --git a/src/cuda/dcn_v2_im2col_cuda.cu b/src/cuda/dcn_v2_im2col_cuda.cu
index ab22b1b..06f6028 100644
--- a/src/cuda/dcn_v2_im2col_cuda.cu
+++ b/src/cuda/dcn_v2_im2col_cuda.cu
@@ -3,6 +3,13 @@
 #include <algorithm>
 #include <cstring>
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCAtomics.cuh>
+#include <THC/THCDeviceUtils.cuh>
+
 #define CUDA_KERNEL_LOOP(i, n)                          \
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
          i < (n);                                       \
diff --git a/src/cuda/dcn_v2_im2col_cuda.h b/src/cuda/dcn_v2_im2col_cuda.h
index 3457e96..c856831 100644
--- a/src/cuda/dcn_v2_im2col_cuda.h
+++ b/src/cuda/dcn_v2_im2col_cuda.h
@@ -1,3 +1,4 @@
+
 /*!
******************* BEGIN Caffe Copyright Notice and Disclaimer **************** * diff --git a/src/cuda/dcn_v2_im2col_cuda_double.cu b/src/cuda/dcn_v2_im2col_cuda_double.cu deleted file mode 100644 index 29cb048..0000000 --- a/src/cuda/dcn_v2_im2col_cuda_double.cu +++ /dev/null @@ -1,399 +0,0 @@ -#include "dcn_v2_im2col_cuda_double.h" -#include -#include -#include - -#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ - i < (n); \ - i += blockDim.x * gridDim.x) - -const int CUDA_NUM_THREADS = 512; -inline int GET_BLOCKS(const int N) -{ - return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; -} - -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 -#else -__device__ double atomicAdd(double* address, double val) -{ - unsigned long long int* address_as_ull = (unsigned long long int*)address; - unsigned long long int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, - __double_as_longlong(val + __longlong_as_double(assumed))); - } while (assumed != old); - return __longlong_as_double(old); -} -#endif - -__device__ double dmcn_im2col_bilinear(const double *bottom_data, const int data_width, - const int height, const int width, double h, double w) -{ - int h_low = floor(h); - int w_low = floor(w); - int h_high = h_low + 1; - int w_high = w_low + 1; - - double lh = h - h_low; - double lw = w - w_low; - double hh = 1 - lh, hw = 1 - lw; - - double v1 = 0; - if (h_low >= 0 && w_low >= 0) - v1 = bottom_data[h_low * data_width + w_low]; - double v2 = 0; - if (h_low >= 0 && w_high <= width - 1) - v2 = bottom_data[h_low * data_width + w_high]; - double v3 = 0; - if (h_high <= height - 1 && w_low >= 0) - v3 = bottom_data[h_high * data_width + w_low]; - double v4 = 0; - if (h_high <= height - 1 && w_high <= width - 1) - v4 = bottom_data[h_high * data_width + w_high]; - - double w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; - - double val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); - return val; -} - -__device__ double dmcn_get_gradient_weight(double argmax_h, double argmax_w, - const int h, const int w, const int height, const int width) -{ - if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) - { - //empty - return 0; - } - - int argmax_h_low = floor(argmax_h); - int argmax_w_low = floor(argmax_w); - int argmax_h_high = argmax_h_low + 1; - int argmax_w_high = argmax_w_low + 1; - - double weight = 0; - if (h == argmax_h_low && w == argmax_w_low) - weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); - if (h == argmax_h_low && w == argmax_w_high) - weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); - if (h == argmax_h_high && w == argmax_w_low) - weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); - if (h == argmax_h_high && w == argmax_w_high) - weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); - return weight; -} - -__device__ double dmcn_get_coordinate_weight(double argmax_h, double argmax_w, - const int height, const int width, const double *im_data, - const int data_width, const int bp_dir) -{ - if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) - { - //empty - return 0; - } - - int argmax_h_low = floor(argmax_h); - int argmax_w_low = floor(argmax_w); - int argmax_h_high = argmax_h_low + 1; - int argmax_w_high = argmax_w_low + 1; - - double weight = 0; - - if (bp_dir == 0) - { - if (argmax_h_low >= 0 && argmax_w_low >= 0) - weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; - if (argmax_h_low >= 0 && 
argmax_w_high <= width - 1) - weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; - if (argmax_h_high <= height - 1 && argmax_w_low >= 0) - weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; - if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) - weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; - } - else if (bp_dir == 1) - { - if (argmax_h_low >= 0 && argmax_w_low >= 0) - weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; - if (argmax_h_low >= 0 && argmax_w_high <= width - 1) - weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; - if (argmax_h_high <= height - 1 && argmax_w_low >= 0) - weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; - if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) - weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; - } - - return weight; -} - -__global__ void modulated_deformable_im2col_gpu_kernel(const int n, - const double *data_im, const double *data_offset, const double *data_mask, - const int height, const int width, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, const int num_channels, const int deformable_group, - const int height_col, const int width_col, - double *data_col) -{ - CUDA_KERNEL_LOOP(index, n) - { - // index index of output matrix - const int w_col = index % width_col; - const int h_col = (index / width_col) % height_col; - const int b_col = (index / width_col / height_col) % batch_size; - const int c_im = (index / width_col / height_col) / batch_size; - const int c_col = c_im * kernel_h * kernel_w; - - // compute deformable group index - const int deformable_group_index = c_im / channel_per_deformable_group; - - const int h_in = h_col * stride_h - pad_h; - const int w_in = w_col * stride_w - pad_w; - - double *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; - //const double* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; - const double *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; - const double *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; - - const double *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; - - for (int i = 0; i < kernel_h; ++i) - { - for (int j = 0; j < kernel_w; ++j) - { - const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; - const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; - const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; - const double offset_h = data_offset_ptr[data_offset_h_ptr]; - const double offset_w = data_offset_ptr[data_offset_w_ptr]; - const double mask = data_mask_ptr[data_mask_hw_ptr]; - double val = static_cast(0); - const double h_im = h_in + i * dilation_h + offset_h; - const double w_im = w_in + j * dilation_w + offset_w; - //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) 
{ - if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) - { - //const double map_h = i * dilation_h + offset_h; - //const double map_w = j * dilation_w + offset_w; - //const int cur_height = height - h_in; - //const int cur_width = width - w_in; - //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); - val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); - } - *data_col_ptr = val * mask; - data_col_ptr += batch_size * height_col * width_col; - //data_col_ptr += height_col * width_col; - } - } - } -} - -__global__ void modulated_deformable_col2im_gpu_kernel(const int n, - const double *data_col, const double *data_offset, const double *data_mask, - const int channels, const int height, const int width, - const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, const int deformable_group, - const int height_col, const int width_col, - double *grad_im) -{ - CUDA_KERNEL_LOOP(index, n) - { - const int j = (index / width_col / height_col / batch_size) % kernel_w; - const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; - const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; - // compute the start and end of the output - - const int deformable_group_index = c / channel_per_deformable_group; - - int w_out = index % width_col; - int h_out = (index / width_col) % height_col; - int b = (index / width_col / height_col) % batch_size; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - - const double *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; - const double *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; - const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; - const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; - const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; - const double offset_h = data_offset_ptr[data_offset_h_ptr]; - const double offset_w = data_offset_ptr[data_offset_w_ptr]; - const double mask = data_mask_ptr[data_mask_hw_ptr]; - const double cur_inv_h_data = h_in + i * dilation_h + offset_h; - const double cur_inv_w_data = w_in + j * dilation_w + offset_w; - - const double cur_top_grad = data_col[index] * mask; - const int cur_h = (int)cur_inv_h_data; - const int cur_w = (int)cur_inv_w_data; - for (int dy = -2; dy <= 2; dy++) - { - for (int dx = -2; dx <= 2; dx++) - { - if (cur_h + dy >= 0 && cur_h + dy < height && - cur_w + dx >= 0 && cur_w + dx < width && - abs(cur_inv_h_data - (cur_h + dy)) < 1 && - abs(cur_inv_w_data - (cur_w + dx)) < 1) - { - int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; - double weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); - atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); - } - } - } - } -} - -__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, - const double *data_col, const double *data_im, - const double *data_offset, const double *data_mask, - const int channels, const int height, const int width, - const int 
kernel_h, const int kernel_w, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int channel_per_deformable_group, - const int batch_size, const int offset_channels, const int deformable_group, - const int height_col, const int width_col, - double *grad_offset, double *grad_mask) -{ - CUDA_KERNEL_LOOP(index, n) - { - double val = 0, mval = 0; - int w = index % width_col; - int h = (index / width_col) % height_col; - int c = (index / width_col / height_col) % offset_channels; - int b = (index / width_col / height_col) / offset_channels; - // compute the start and end of the output - - const int deformable_group_index = c / (2 * kernel_h * kernel_w); - const int col_step = kernel_h * kernel_w; - int cnt = 0; - const double *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; - const double *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; - const double *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; - const double *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; - - const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; - - for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) - { - const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; - const int bp_dir = offset_c % 2; - - int j = (col_pos / width_col / height_col / batch_size) % kernel_w; - int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; - int w_out = col_pos % width_col; - int h_out = (col_pos / width_col) % height_col; - int w_in = w_out * stride_w - pad_w; - int h_in = h_out * stride_h - pad_h; - const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); - const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); - const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); - const double offset_h = data_offset_ptr[data_offset_h_ptr]; - const double offset_w = data_offset_ptr[data_offset_w_ptr]; - const double mask = data_mask_ptr[data_mask_hw_ptr]; - double inv_h = h_in + i * dilation_h + offset_h; - double inv_w = w_in + j * dilation_w + offset_w; - if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) - { - inv_h = inv_w = -2; - } - else - { - mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); - } - const double weight = dmcn_get_coordinate_weight( - inv_h, inv_w, - height, width, data_im_ptr + cnt * height * width, width, bp_dir); - val += weight * data_col_ptr[col_pos] * mask; - cnt += 1; - } - // KERNEL_ASSIGN(grad_offset[index], offset_req, val); - grad_offset[index] = val; - if (offset_c % 2 == 0) - // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); - grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; - } -} - -void modulated_deformable_im2col_cuda(cudaStream_t stream, - const double *data_im, const double 
*data_offset, const double *data_mask, - const int batch_size, const int channels, const int height_im, const int width_im, - const int height_col, const int width_col, const int kernel_h, const int kenerl_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int deformable_group, double *data_col) -{ - // num_axes should be smaller than block size - const int channel_per_deformable_group = channels / deformable_group; - const int num_kernels = channels * batch_size * height_col * width_col; - modulated_deformable_im2col_gpu_kernel<<>>( - num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kenerl_w, - pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, - batch_size, channels, deformable_group, height_col, width_col, data_col); - - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { - printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); - } -} - -void modulated_deformable_col2im_cuda(cudaStream_t stream, - const double *data_col, const double *data_offset, const double *data_mask, - const int batch_size, const int channels, const int height_im, const int width_im, - const int height_col, const int width_col, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int deformable_group, double *grad_im) -{ - - const int channel_per_deformable_group = channels / deformable_group; - const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; - modulated_deformable_col2im_gpu_kernel<<>>( - num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im, - kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w, - dilation_h, dilation_w, channel_per_deformable_group, - batch_size, deformable_group, height_col, width_col, grad_im); - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { - printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); - } -} - -void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, - const double *data_col, const double *data_im, const double *data_offset, const double *data_mask, - const int batch_size, const int channels, const int height_im, const int width_im, - const int height_col, const int width_col, const int kernel_h, const int kernel_w, - const int pad_h, const int pad_w, const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int deformable_group, - double *grad_offset, double *grad_mask) -{ - const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; - const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; - modulated_deformable_col2im_coord_gpu_kernel<<>>( - num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im, - kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, channel_per_deformable_group, - batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, - grad_offset, grad_mask); - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { - printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); - } -} \ No newline at end of file diff --git a/src/cuda/dcn_v2_im2col_cuda_double.h 
b/src/cuda/dcn_v2_im2col_cuda_double.h deleted file mode 100644 index a461692..0000000 --- a/src/cuda/dcn_v2_im2col_cuda_double.h +++ /dev/null @@ -1,100 +0,0 @@ -/*! - ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** - * - * COPYRIGHT - * - * All contributions by the University of California: - * Copyright (c) 2014-2017 The Regents of the University of California (Regents) - * All rights reserved. - * - * All other contributions: - * Copyright (c) 2014-2017, the respective contributors - * All rights reserved. - * - * Caffe uses a shared copyright model: each contributor holds copyright over - * their contributions to Caffe. The project versioning records all such - * contribution and copyright details. If a contributor wants to further mark - * their specific copyright on a particular contribution, they should indicate - * their copyright solely in the commit message of the change when it is - * committed. - * - * LICENSE - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR - * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - * CONTRIBUTION AGREEMENT - * - * By contributing to the BVLC/caffe repository through pull-request, comment, - * or otherwise, the contributor releases their content to the - * license and copyright terms herein. - * - ***************** END Caffe Copyright Notice and Disclaimer ******************** - * - * Copyright (c) 2018 Microsoft - * Licensed under The MIT License [see LICENSE for details] - * \file modulated_deformable_im2col.h - * \brief Function definitions of converting an image to - * column matrix based on kernel, padding, dilation, and offset. - * These functions are mainly used in deformable convolution operators. 
- * \ref: https://arxiv.org/abs/1811.11168
- * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu
- */
-
-/***************** Adapted by Charles Shang *********************/
-
-#ifndef DCN_V2_IM2COL_CUDA_DOUBLE
-#define DCN_V2_IM2COL_CUDA_DOUBLE
-
-#ifdef __cplusplus
-extern "C"
-{
-#endif
-
-    void modulated_deformable_im2col_cuda(cudaStream_t stream,
-                                          const double *data_im, const double *data_offset, const double *data_mask,
-                                          const int batch_size, const int channels, const int height_im, const int width_im,
-                                          const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
-                                          const int pad_h, const int pad_w, const int stride_h, const int stride_w,
-                                          const int dilation_h, const int dilation_w,
-                                          const int deformable_group, double *data_col);
-
-    void modulated_deformable_col2im_cuda(cudaStream_t stream,
-                                          const double *data_col, const double *data_offset, const double *data_mask,
-                                          const int batch_size, const int channels, const int height_im, const int width_im,
-                                          const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
-                                          const int pad_h, const int pad_w, const int stride_h, const int stride_w,
-                                          const int dilation_h, const int dilation_w,
-                                          const int deformable_group, double *grad_im);
-
-    void modulated_deformable_col2im_coord_cuda(cudaStream_t stream,
-                                                const double *data_col, const double *data_im, const double *data_offset, const double *data_mask,
-                                                const int batch_size, const int channels, const int height_im, const int width_im,
-                                                const int height_col, const int width_col, const int kernel_h, const int kenerl_w,
-                                                const int pad_h, const int pad_w, const int stride_h, const int stride_w,
-                                                const int dilation_h, const int dilation_w,
-                                                const int deformable_group,
-                                                double *grad_offset, double *grad_mask);
-
-#ifdef __cplusplus
-}
-#endif
-
-#endif
\ No newline at end of file
diff --git a/src/cuda/dcn_v2_psroi_pooling_cuda.cu b/src/cuda/dcn_v2_psroi_pooling_cuda.cu
index 295657c..07b438e 100644
--- a/src/cuda/dcn_v2_psroi_pooling_cuda.cu
+++ b/src/cuda/dcn_v2_psroi_pooling_cuda.cu
@@ -6,10 +6,18 @@
  * \author Yi Li, Guodong Zhang, Jifeng Dai
 */
 /***************** Adapted by Charles Shang *********************/
-#include "dcn_v2_psroi_pooling_cuda.h"
+
 #include <cstdio>
 #include <algorithm>
 #include <cstring>
+#include <vector>
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+
+#include <THC/THC.h>
+#include <THC/THCAtomics.cuh>
+#include <THC/THCDeviceUtils.cuh>
 
 #define CUDA_KERNEL_LOOP(i, n)                          \
     for (int i = blockIdx.x * blockDim.x + threadIdx.x; \
         i < (n);                                        \
@@ -22,10 +30,11 @@ inline int GET_BLOCKS(const int N)
 {
     return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
 }
-__device__ float bilinear_interp(
-    const float *data,
-    const float x,
-    const float y,
+template <typename T>
+__device__ T bilinear_interp(
+    const T *data,
+    const T x,
+    const T y,
     const int width,
     const int height)
 {
@@ -33,34 +42,38 @@ __device__ float bilinear_interp(
     int x2 = ceil(x);
     int y1 = floor(y);
     int y2 = ceil(y);
-    float dist_x = (float)(x - x1);
-    float dist_y = (float)(y - y1);
-    float value11 = data[y1 * width + x1];
-    float value12 = data[y2 * width + x1];
-    float value21 = data[y1 * width + x2];
-    float value22 = data[y2 * width + x2];
-    float value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22;
+    T dist_x = static_cast<T>(x - x1);
+    T dist_y = static_cast<T>(y - y1);
+    T value11 = data[y1 * width + x1];
+    T value12 = data[y2 * width + x1];
+    T value21 = data[y1 * width + x2];
+    T value22 = data[y2 * width + x2];
+    T value = (1 - dist_x) * (1 - dist_y) * value11 +
+              (1 - dist_x) * dist_y * value12 +
+              dist_x * (1 - dist_y) * value21 +
+              dist_x * dist_y * value22;
     return value;
 }
 
+template <typename T>
 __global__ void DeformablePSROIPoolForwardKernel(
     const int count,
-    const float *bottom_data,
-    const float spatial_scale,
+    const T *bottom_data,
+    const T spatial_scale,
     const int channels,
     const int height, const int width,
     const int pooled_height, const int pooled_width,
-    const float *bottom_rois, const float *bottom_trans,
+    const T *bottom_rois, const T *bottom_trans,
     const int no_trans,
-    const float trans_std,
+    const T trans_std,
     const int sample_per_part,
    const int output_dim,
     const int group_size,
     const int part_size,
     const int num_classes,
     const int channels_each_class,
-    float *top_data,
-    float *top_count)
+    T *top_data,
+    T *top_count)
 {
     CUDA_KERNEL_LOOP(index, count)
     {
@@ -71,49 +84,49 @@ __global__ void DeformablePSROIPoolForwardKernel(
         int n = index / pooled_width / pooled_height / output_dim;
 
         // [start, end) interval for spatial sampling
-        const float *offset_bottom_rois = bottom_rois + n * 5;
+        const T *offset_bottom_rois = bottom_rois + n * 5;
         int roi_batch_ind = offset_bottom_rois[0];
-        float roi_start_w = (float)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
-        float roi_start_h = (float)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
-        float roi_end_w = (float)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
-        float roi_end_h = (float)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
+        T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
+        T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
+        T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
+        T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
 
         // Force too small ROIs to be 1x1
-        float roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
-        float roi_height = max(roi_end_h - roi_start_h, 0.1);
+        T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
+        T roi_height = max(roi_end_h - roi_start_h, 0.1);
 
         // Compute w and h at bottom
-        float bin_size_h = roi_height / (float)(pooled_height);
-        float bin_size_w = roi_width / (float)(pooled_width);
+        T bin_size_h = roi_height / static_cast<T>(pooled_height);
+        T bin_size_w = roi_width / static_cast<T>(pooled_width);
 
-        float sub_bin_size_h = bin_size_h / (float)(sample_per_part);
-        float sub_bin_size_w = bin_size_w / (float)(sample_per_part);
+        T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
+        T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
 
-        int part_h = floor((float)(ph) / pooled_height * part_size);
-        int part_w = floor((float)(pw) / pooled_width * part_size);
+        int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
+        int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
         int class_id = ctop / channels_each_class;
-        float trans_x = no_trans ? (float)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
-        float trans_y = no_trans ? (float)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
+        T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
+        T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
 
-        float wstart = (float)(pw)*bin_size_w + roi_start_w;
+        T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
         wstart += trans_x * roi_width;
-        float hstart = (float)(ph)*bin_size_h + roi_start_h;
+        T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
         hstart += trans_y * roi_height;
 
-        float sum = 0;
+        T sum = 0;
         int count = 0;
-        int gw = floor((float)(pw)*group_size / pooled_width);
-        int gh = floor((float)(ph)*group_size / pooled_height);
+        int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
+        int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
         gw = min(max(gw, 0), group_size - 1);
         gh = min(max(gh, 0), group_size - 1);
 
-        const float *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
+        const T *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width;
         for (int ih = 0; ih < sample_per_part; ih++)
         {
             for (int iw = 0; iw < sample_per_part; iw++)
             {
-                float w = wstart + iw * sub_bin_size_w;
-                float h = hstart + ih * sub_bin_size_h;
+                T w = wstart + iw * sub_bin_size_w;
+                T h = hstart + ih * sub_bin_size_h;
                 // bilinear interpolation
                 if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
                 {
@@ -122,32 +135,33 @@ __global__ void DeformablePSROIPoolForwardKernel(
                 w = min(max(w, 0.), width - 1.);
                 h = min(max(h, 0.), height - 1.);
                 int c = (ctop * group_size + gh) * group_size + gw;
-                float val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height);
+                T val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height);
                 sum += val;
                 count++;
             }
         }
-        top_data[index] = count == 0 ? (float)(0) : sum / count;
+        top_data[index] = count == 0 ? static_cast<T>(0) : sum / count;
         top_count[index] = count;
     }
 }
 
+template <typename T>
 __global__ void DeformablePSROIPoolBackwardAccKernel(
     const int count,
-    const float *top_diff,
-    const float *top_count,
+    const T *top_diff,
+    const T *top_count,
     const int num_rois,
-    const float spatial_scale,
+    const T spatial_scale,
     const int channels,
     const int height, const int width,
     const int pooled_height, const int pooled_width,
     const int output_dim,
-    float *bottom_data_diff, float *bottom_trans_diff,
-    const float *bottom_data,
-    const float *bottom_rois,
-    const float *bottom_trans,
+    T *bottom_data_diff, T *bottom_trans_diff,
+    const T *bottom_data,
+    const T *bottom_rois,
+    const T *bottom_trans,
     const int no_trans,
-    const float trans_std,
+    const T trans_std,
     const int sample_per_part,
     const int group_size,
     const int part_size,
@@ -163,44 +177,44 @@ __global__ void DeformablePSROIPoolBackwardAccKernel(
         int n = index / pooled_width / pooled_height / output_dim;
 
         // [start, end) interval for spatial sampling
-        const float *offset_bottom_rois = bottom_rois + n * 5;
+        const T *offset_bottom_rois = bottom_rois + n * 5;
         int roi_batch_ind = offset_bottom_rois[0];
-        float roi_start_w = (float)(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
-        float roi_start_h = (float)(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
-        float roi_end_w = (float)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
-        float roi_end_h = (float)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
+        T roi_start_w = static_cast<T>(round(offset_bottom_rois[1])) * spatial_scale - 0.5;
+        T roi_start_h = static_cast<T>(round(offset_bottom_rois[2])) * spatial_scale - 0.5;
+        T roi_end_w = static_cast<T>(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5;
+        T roi_end_h = static_cast<T>(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5;
 
         // Force too small ROIs to be 1x1
-        float roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
-        float roi_height = max(roi_end_h - roi_start_h, 0.1);
+        T roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0
+        T roi_height = max(roi_end_h - roi_start_h, 0.1);
 
         // Compute w and h at bottom
-        float bin_size_h = roi_height / (float)(pooled_height);
-        float bin_size_w = roi_width / (float)(pooled_width);
+        T bin_size_h = roi_height / static_cast<T>(pooled_height);
+        T bin_size_w = roi_width / static_cast<T>(pooled_width);
 
-        float sub_bin_size_h = bin_size_h / (float)(sample_per_part);
-        float sub_bin_size_w = bin_size_w / (float)(sample_per_part);
+        T sub_bin_size_h = bin_size_h / static_cast<T>(sample_per_part);
+        T sub_bin_size_w = bin_size_w / static_cast<T>(sample_per_part);
 
-        int part_h = floor((float)(ph) / pooled_height * part_size);
-        int part_w = floor((float)(pw) / pooled_width * part_size);
+        int part_h = floor(static_cast<T>(ph) / pooled_height * part_size);
+        int part_w = floor(static_cast<T>(pw) / pooled_width * part_size);
         int class_id = ctop / channels_each_class;
-        float trans_x = no_trans ? (float)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
-        float trans_y = no_trans ? (float)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
+        T trans_x = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std;
+        T trans_y = no_trans ? static_cast<T>(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std;
 
-        float wstart = (float)(pw)*bin_size_w + roi_start_w;
+        T wstart = static_cast<T>(pw) * bin_size_w + roi_start_w;
         wstart += trans_x * roi_width;
-        float hstart = (float)(ph)*bin_size_h + roi_start_h;
+        T hstart = static_cast<T>(ph) * bin_size_h + roi_start_h;
         hstart += trans_y * roi_height;
 
         if (top_count[index] <= 0)
         {
             continue;
         }
-        float diff_val = top_diff[index] / top_count[index];
-        const float *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
-        float *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
-        int gw = floor((float)(pw)*group_size / pooled_width);
-        int gh = floor((float)(ph)*group_size / pooled_height);
+        T diff_val = top_diff[index] / top_count[index];
+        const T *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width;
+        T *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width;
+        int gw = floor(static_cast<T>(pw) * group_size / pooled_width);
+        int gh = floor(static_cast<T>(ph) * group_size / pooled_height);
         gw = min(max(gw, 0), group_size - 1);
         gh = min(max(gh, 0), group_size - 1);
 
@@ -208,8 +222,8 @@ __global__ void DeformablePSROIPoolBackwardAccKernel(
         {
             for (int iw = 0; iw < sample_per_part; iw++)
             {
-                float w = wstart + iw * sub_bin_size_w;
-                float h = hstart + ih * sub_bin_size_h;
+                T w = wstart + iw * sub_bin_size_w;
+                T h = hstart + ih * sub_bin_size_h;
                 // bilinear interpolation
                 if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5)
                 {
@@ -223,11 +237,11 @@ __global__ void DeformablePSROIPoolBackwardAccKernel(
                 int x1 = ceil(w);
                 int y0 = floor(h);
                 int y1 = ceil(h);
-                float dist_x = w - x0, dist_y = h - y0;
-                float q00 = (1 - dist_x) * (1 - dist_y);
-                float q01 = (1 - dist_x) * dist_y;
-                float
q10 = dist_x * (1 - dist_y);
-                    float q11 = dist_x * dist_y;
+                    T dist_x = w - x0, dist_y = h - y0;
+                    T q00 = (1 - dist_x) * (1 - dist_y);
+                    T q01 = (1 - dist_x) * dist_y;
+                    T q10 = dist_x * (1 - dist_y);
+                    T q11 = dist_x * dist_y;
                     int bottom_index_base = c * height * width;
                     atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val);
                     atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val);
                     atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val);
                     atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val);
@@ -238,13 +252,13 @@
                     {
                         continue;
                     }
-                    float U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
-                    float U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
-                    float U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
-                    float U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
-                    float diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
+                    T U00 = offset_bottom_data[bottom_index_base + y0 * width + x0];
+                    T U01 = offset_bottom_data[bottom_index_base + y1 * width + x0];
+                    T U10 = offset_bottom_data[bottom_index_base + y0 * width + x1];
+                    T U11 = offset_bottom_data[bottom_index_base + y1 * width + x1];
+                    T diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val;
                     diff_x *= roi_width;
-                    float diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
+                    T diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val;
                     diff_y *= roi_height;
 
                     atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x);
@@ -254,100 +268,152 @@
     }
 }
 
-void DeformablePSROIPoolForward(cudaStream_t stream,
-                                const float *data,
-                                const float *bbox,
-                                const float *trans,
-                                float *out,
-                                float *top_count,
-                                const int batch,
-                                const int channels,
-                                const int height,
-                                const int width,
-                                const int num_bbox,
-                                const int channels_trans,
-                                const int no_trans,
-                                const float spatial_scale,
-                                const int output_dim,
-                                const int group_size,
-                                const int pooled_size,
-                                const int part_size,
-                                const int sample_per_part,
-                                const float trans_std)
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input,
+                                  const at::Tensor &bbox,
+                                  const at::Tensor &trans,
+                                  const int no_trans,
+                                  const float spatial_scale,
+                                  const int output_dim,
+                                  const int group_size,
+                                  const int pooled_size,
+                                  const int part_size,
+                                  const int sample_per_part,
+                                  const float trans_std)
 {
+    AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
+    AT_ASSERTM(bbox.type().is_cuda(), "rois must be a CUDA tensor");
+    AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");
+
+    const int batch = input.size(0);
+    const int channels = input.size(1);
+    const int height = input.size(2);
+    const int width = input.size(3);
+    const int channels_trans = no_trans ? 2 : trans.size(1);
+    const int num_bbox = bbox.size(0);
 
-    const float *bottom_data = data;
-    const float *bottom_rois = bbox;
-    const float *bottom_trans = no_trans ? NULL : trans;
-    float *top_data = out;
-    float *top_count_data = top_count;
+    AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");
+    auto pooled_height = pooled_size;
+    auto pooled_width = pooled_size;
+
+    auto out = at::empty({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
+    long out_size = num_bbox * output_dim * pooled_height * pooled_width;
+    auto top_count = at::zeros({num_bbox, output_dim, pooled_height, pooled_width}, input.options());
 
-    const int pooled_height = pooled_size;
-    const int pooled_width = pooled_size;
-    const int count = num_bbox * output_dim * pooled_height * pooled_width;
     const int num_classes = no_trans ? 1 : channels_trans / 2;
     const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
 
-    DeformablePSROIPoolForwardKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(
-        count, bottom_data, spatial_scale, channels, height, width, pooled_height, pooled_width,
-        bottom_rois, bottom_trans, no_trans, trans_std, sample_per_part, output_dim,
-        group_size, part_size, num_classes, channels_each_class, top_data, top_count_data);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
-    cudaError_t err = cudaGetLastError();
-    if (err != cudaSuccess)
+    if (out.numel() == 0)
     {
-        printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
+        THCudaCheck(cudaGetLastError());
+        return std::make_tuple(out, top_count);
     }
+
+    dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
+    dim3 block(512);
+
+    AT_DISPATCH_FLOATING_TYPES(input.type(), "dcn_v2_psroi_pooling_cuda_forward", [&] {
+        DeformablePSROIPoolForwardKernel<scalar_t><<<grid, block, 0, stream>>>(
+            out_size,
+            input.contiguous().data<scalar_t>(),
+            spatial_scale,
+            channels,
+            height, width,
+            pooled_height,
+            pooled_width,
+            bbox.contiguous().data<scalar_t>(),
+            trans.contiguous().data<scalar_t>(),
+            no_trans,
+            trans_std,
+            sample_per_part,
+            output_dim,
+            group_size,
+            part_size,
+            num_classes,
+            channels_each_class,
+            out.data<scalar_t>(),
+            top_count.data<scalar_t>());
+    });
+    THCudaCheck(cudaGetLastError());
+    return std::make_tuple(out, top_count);
 }
 
-void DeformablePSROIPoolBackwardAcc(cudaStream_t stream,
-                                    const float *out_grad,
-                                    const float *data,
-                                    const float *bbox,
-                                    const float *trans,
-                                    const float *top_count,
-                                    float *in_grad,
-                                    float *trans_grad,
-                                    const int batch,
-                                    const int channels,
-                                    const int height,
-                                    const int width,
-                                    const int num_bbox,
-                                    const int channels_trans,
-                                    const int no_trans,
-                                    const float spatial_scale,
-                                    const int output_dim,
-                                    const int group_size,
-                                    const int pooled_size,
-                                    const int part_size,
-                                    const int sample_per_part,
-                                    const float trans_std)
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad,
+                                   const at::Tensor &input,
+                                   const at::Tensor &bbox,
+                                   const at::Tensor &trans,
+                                   const at::Tensor &top_count,
+                                   const int no_trans,
+                                   const float spatial_scale,
+                                   const int output_dim,
+                                   const int group_size,
+                                   const int pooled_size,
+                                   const int part_size,
+                                   const int sample_per_part,
+                                   const float trans_std)
 {
-    // LOG(INFO) << "DeformablePSROIPoolBackward";
-    const float *top_diff = out_grad;
-    const float *bottom_data = data;
-    const float *bottom_rois = bbox;
-    const float *bottom_trans = no_trans ? NULL : trans;
-    float *bottom_data_diff = in_grad;
-    float *bottom_trans_diff = no_trans ? NULL : trans_grad;
-    const float *top_count_data = top_count;
-
-    const int num_rois = num_bbox;
-    const int pooled_height = pooled_size;
-    const int pooled_width = pooled_size;
-    const int count = num_bbox * output_dim * pooled_height * pooled_width;
+    AT_ASSERTM(out_grad.type().is_cuda(), "out_grad must be a CUDA tensor");
+    AT_ASSERTM(input.type().is_cuda(), "input must be a CUDA tensor");
+    AT_ASSERTM(bbox.type().is_cuda(), "bbox must be a CUDA tensor");
+    AT_ASSERTM(trans.type().is_cuda(), "trans must be a CUDA tensor");
+    AT_ASSERTM(top_count.type().is_cuda(), "top_count must be a CUDA tensor");
+
+    const int batch = input.size(0);
+    const int channels = input.size(1);
+    const int height = input.size(2);
+    const int width = input.size(3);
+    const int channels_trans = no_trans ? 2 : trans.size(1);
+    const int num_bbox = bbox.size(0);
+
+    AT_ASSERTM(channels == output_dim, "input channels and output channels must equal");
+    auto pooled_height = pooled_size;
+    auto pooled_width = pooled_size;
+    long out_size = num_bbox * output_dim * pooled_height * pooled_width;
     const int num_classes = no_trans ? 1 : channels_trans / 2;
     const int channels_each_class = no_trans ? output_dim : output_dim / num_classes;
 
-    DeformablePSROIPoolBackwardAccKernel<<<GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(
-        count, top_diff, top_count_data, num_rois, spatial_scale, channels, height, width,
-        pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff,
-        bottom_data, bottom_rois, bottom_trans, no_trans, trans_std, sample_per_part,
-        group_size, part_size, num_classes, channels_each_class);
+    auto input_grad = at::zeros({batch, channels, height, width}, out_grad.options());
+    auto trans_grad = at::zeros_like(trans);
 
-    cudaError_t err = cudaGetLastError();
-    if (err != cudaSuccess)
+    if (input_grad.numel() == 0)
     {
-        printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err));
+        THCudaCheck(cudaGetLastError());
+        return std::make_tuple(input_grad, trans_grad);
     }
+
+    dim3 grid(std::min(THCCeilDiv(out_size, 512L), 4096L));
+    dim3 block(512);
+    cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+    AT_DISPATCH_FLOATING_TYPES(out_grad.type(), "dcn_v2_psroi_pooling_cuda_backward", [&] {
+        DeformablePSROIPoolBackwardAccKernel<scalar_t><<<grid, block, 0, stream>>>(
+            out_size,
+            out_grad.contiguous().data<scalar_t>(),
+            top_count.contiguous().data<scalar_t>(),
+            num_bbox,
+            spatial_scale,
+            channels,
+            height,
+            width,
+            pooled_height,
+            pooled_width,
+            output_dim,
+            input_grad.contiguous().data<scalar_t>(),
+            trans_grad.contiguous().data<scalar_t>(),
+            input.contiguous().data<scalar_t>(),
+            bbox.contiguous().data<scalar_t>(),
+            trans.contiguous().data<scalar_t>(),
+            no_trans,
+            trans_std,
+            sample_per_part,
+            group_size,
+            part_size,
+            num_classes,
+            channels_each_class);
+    });
+    THCudaCheck(cudaGetLastError());
+    return std::make_tuple(input_grad, trans_grad);
 }
\ No newline at end of file
diff --git a/src/cuda/dcn_v2_psroi_pooling_cuda.h b/src/cuda/dcn_v2_psroi_pooling_cuda.h
deleted file mode 100644
index 5fa2c6c..0000000
--- a/src/cuda/dcn_v2_psroi_pooling_cuda.h
+++ /dev/null
@@ -1,66 +0,0 @@
-/*!
- * Copyright (c) 2017 Microsoft - * Licensed under The MIT License [see LICENSE for details] - * \file deformable_psroi_pooling.cu - * \brief - * \author Yi Li, Guodong Zhang, Jifeng Dai -*/ -/***************** Adapted by Charles Shang *********************/ - -#ifndef DCN_V2_PSROI_POOLING_CUDA -#define DCN_V2_PSROI_POOLING_CUDA - -#ifdef __cplusplus -extern "C" -{ -#endif - - void DeformablePSROIPoolForward(cudaStream_t stream, - const float *data, - const float *bbox, - const float *trans, - float *out, - float *top_count, - const int batch, - const int channels, - const int height, - const int width, - const int num_bbox, - const int channels_trans, - const int no_trans, - const float spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const float trans_std); - - void DeformablePSROIPoolBackwardAcc(cudaStream_t stream, - const float *out_grad, - const float *data, - const float *bbox, - const float *trans, - const float *top_count, - float *in_grad, - float *trans_grad, - const int batch, - const int channels, - const int height, - const int width, - const int num_bbox, - const int channels_trans, - const int no_trans, - const float spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const float trans_std); - -#ifdef __cplusplus -} -#endif - -#endif \ No newline at end of file diff --git a/src/cuda/dcn_v2_psroi_pooling_cuda_double.cu b/src/cuda/dcn_v2_psroi_pooling_cuda_double.cu deleted file mode 100644 index ce05cc9..0000000 --- a/src/cuda/dcn_v2_psroi_pooling_cuda_double.cu +++ /dev/null @@ -1,368 +0,0 @@ -/*! - * Copyright (c) 2017 Microsoft - * Licensed under The MIT License [see LICENSE for details] - * \file deformable_psroi_pooling.cu - * \brief - * \author Yi Li, Guodong Zhang, Jifeng Dai -*/ -/***************** Adapted by Charles Shang *********************/ -#include "dcn_v2_psroi_pooling_cuda_double.h" -#include -#include -#include - -#define CUDA_KERNEL_LOOP(i, n) \ - for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ - i < (n); \ - i += blockDim.x * gridDim.x) - -const int CUDA_NUM_THREADS = 1024; -inline int GET_BLOCKS(const int N) -{ - return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; -} - -#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 -#else -__device__ double atomicAdd(double* address, double val) -{ - unsigned long long int* address_as_ull = (unsigned long long int*)address; - unsigned long long int old = *address_as_ull, assumed; - do { - assumed = old; - old = atomicCAS(address_as_ull, assumed, - __double_as_longlong(val + __longlong_as_double(assumed))); - } while (assumed != old); - return __longlong_as_double(old); -} -#endif - -__device__ double bilinear_interp( - const double *data, - const double x, - const double y, - const int width, - const int height) -{ - int x1 = floor(x); - int x2 = ceil(x); - int y1 = floor(y); - int y2 = ceil(y); - double dist_x = (double)(x - x1); - double dist_y = (double)(y - y1); - double value11 = data[y1 * width + x1]; - double value12 = data[y2 * width + x1]; - double value21 = data[y1 * width + x2]; - double value22 = data[y2 * width + x2]; - double value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22; - return value; -} - -__global__ void DeformablePSROIPoolForwardKernel( - const int count, - const double *bottom_data, - const double spatial_scale, 
- const int channels, - const int height, const int width, - const int pooled_height, const int pooled_width, - const double *bottom_rois, const double *bottom_trans, - const int no_trans, - const double trans_std, - const int sample_per_part, - const int output_dim, - const int group_size, - const int part_size, - const int num_classes, - const int channels_each_class, - double *top_data, - double *top_count) -{ - CUDA_KERNEL_LOOP(index, count) - { - // The output is in order (n, ctop, ph, pw) - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int ctop = (index / pooled_width / pooled_height) % output_dim; - int n = index / pooled_width / pooled_height / output_dim; - - // [start, end) interval for spatial sampling - const double *offset_bottom_rois = bottom_rois + n * 5; - int roi_batch_ind = offset_bottom_rois[0]; - double roi_start_w = (double)(round(offset_bottom_rois[1])) * spatial_scale - 0.5; - double roi_start_h = (double)(round(offset_bottom_rois[2])) * spatial_scale - 0.5; - double roi_end_w = (double)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; - double roi_end_h = (double)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; - - // Force too small ROIs to be 1x1 - double roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 - double roi_height = max(roi_end_h - roi_start_h, 0.1); - - // Compute w and h at bottom - double bin_size_h = roi_height / (double)(pooled_height); - double bin_size_w = roi_width / (double)(pooled_width); - - double sub_bin_size_h = bin_size_h / (double)(sample_per_part); - double sub_bin_size_w = bin_size_w / (double)(sample_per_part); - - int part_h = floor((double)(ph) / pooled_height * part_size); - int part_w = floor((double)(pw) / pooled_width * part_size); - int class_id = ctop / channels_each_class; - double trans_x = no_trans ? (double)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; - double trans_y = no_trans ? (double)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; - - double wstart = (double)(pw)*bin_size_w + roi_start_w; - wstart += trans_x * roi_width; - double hstart = (double)(ph)*bin_size_h + roi_start_h; - hstart += trans_y * roi_height; - - double sum = 0; - int count = 0; - int gw = floor((double)(pw)*group_size / pooled_width); - int gh = floor((double)(ph)*group_size / pooled_height); - gw = min(max(gw, 0), group_size - 1); - gh = min(max(gh, 0), group_size - 1); - - const double *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; - for (int ih = 0; ih < sample_per_part; ih++) - { - for (int iw = 0; iw < sample_per_part; iw++) - { - double w = wstart + iw * sub_bin_size_w; - double h = hstart + ih * sub_bin_size_h; - // bilinear interpolation - if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) - { - continue; - } - w = min(max(w, 0.), width - 1.); - h = min(max(h, 0.), height - 1.); - int c = (ctop * group_size + gh) * group_size + gw; - double val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height); - sum += val; - count++; - } - } - top_data[index] = count == 0 ? 
(double)(0) : sum / count; - top_count[index] = count; - } -} - -__global__ void DeformablePSROIPoolBackwardAccKernel( - const int count, - const double *top_diff, - const double *top_count, - const int num_rois, - const double spatial_scale, - const int channels, - const int height, const int width, - const int pooled_height, const int pooled_width, - const int output_dim, - double *bottom_data_diff, double *bottom_trans_diff, - const double *bottom_data, - const double *bottom_rois, - const double *bottom_trans, - const int no_trans, - const double trans_std, - const int sample_per_part, - const int group_size, - const int part_size, - const int num_classes, - const int channels_each_class) -{ - CUDA_KERNEL_LOOP(index, count) - { - // The output is in order (n, ctop, ph, pw) - int pw = index % pooled_width; - int ph = (index / pooled_width) % pooled_height; - int ctop = (index / pooled_width / pooled_height) % output_dim; - int n = index / pooled_width / pooled_height / output_dim; - - // [start, end) interval for spatial sampling - const double *offset_bottom_rois = bottom_rois + n * 5; - int roi_batch_ind = offset_bottom_rois[0]; - double roi_start_w = (double)(round(offset_bottom_rois[1])) * spatial_scale - 0.5; - double roi_start_h = (double)(round(offset_bottom_rois[2])) * spatial_scale - 0.5; - double roi_end_w = (double)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; - double roi_end_h = (double)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; - - // Force too small ROIs to be 1x1 - double roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 - double roi_height = max(roi_end_h - roi_start_h, 0.1); - - // Compute w and h at bottom - double bin_size_h = roi_height / (double)(pooled_height); - double bin_size_w = roi_width / (double)(pooled_width); - - double sub_bin_size_h = bin_size_h / (double)(sample_per_part); - double sub_bin_size_w = bin_size_w / (double)(sample_per_part); - - int part_h = floor((double)(ph) / pooled_height * part_size); - int part_w = floor((double)(pw) / pooled_width * part_size); - int class_id = ctop / channels_each_class; - double trans_x = no_trans ? (double)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; - double trans_y = no_trans ? 
(double)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; - - double wstart = (double)(pw)*bin_size_w + roi_start_w; - wstart += trans_x * roi_width; - double hstart = (double)(ph)*bin_size_h + roi_start_h; - hstart += trans_y * roi_height; - - if (top_count[index] <= 0) - { - continue; - } - double diff_val = top_diff[index] / top_count[index]; - const double *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; - double *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; - int gw = floor((double)(pw)*group_size / pooled_width); - int gh = floor((double)(ph)*group_size / pooled_height); - gw = min(max(gw, 0), group_size - 1); - gh = min(max(gh, 0), group_size - 1); - - for (int ih = 0; ih < sample_per_part; ih++) - { - for (int iw = 0; iw < sample_per_part; iw++) - { - double w = wstart + iw * sub_bin_size_w; - double h = hstart + ih * sub_bin_size_h; - // bilinear interpolation - if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) - { - continue; - } - w = min(max(w, 0.), width - 1.); - h = min(max(h, 0.), height - 1.); - int c = (ctop * group_size + gh) * group_size + gw; - // backward on feature - int x0 = floor(w); - int x1 = ceil(w); - int y0 = floor(h); - int y1 = ceil(h); - double dist_x = w - x0, dist_y = h - y0; - double q00 = (1 - dist_x) * (1 - dist_y); - double q01 = (1 - dist_x) * dist_y; - double q10 = dist_x * (1 - dist_y); - double q11 = dist_x * dist_y; - int bottom_index_base = c * height * width; - atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); - atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); - atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); - atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val); - - if (no_trans) - { - continue; - } - double U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; - double U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; - double U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; - double U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; - double diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val; - diff_x *= roi_width; - double diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val; - diff_y *= roi_height; - - atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x); - atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y); - } - } - } -} - -void DeformablePSROIPoolForward(cudaStream_t stream, - const double *data, - const double *bbox, - const double *trans, - double *out, - double *top_count, - const int batch, - const int channels, - const int height, - const int width, - const int num_bbox, - const int channels_trans, - const int no_trans, - const double spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const double trans_std) -{ - - const double *bottom_data = data; - const double *bottom_rois = bbox; - const double *bottom_trans = no_trans ? 
NULL : trans; - double *top_data = out; - double *top_count_data = top_count; - - const int pooled_height = pooled_size; - const int pooled_width = pooled_size; - const int count = num_bbox * output_dim * pooled_height * pooled_width; - const int num_classes = no_trans ? 1 : channels_trans / 2; - const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; - - DeformablePSROIPoolForwardKernel<<>>( - count, bottom_data, spatial_scale, channels, height, width, pooled_height, pooled_width, - bottom_rois, bottom_trans, no_trans, trans_std, sample_per_part, output_dim, - group_size, part_size, num_classes, channels_each_class, top_data, top_count_data); - - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { - printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err)); - } -} - -void DeformablePSROIPoolBackwardAcc(cudaStream_t stream, - const double *out_grad, - const double *data, - const double *bbox, - const double *trans, - const double *top_count, - double *in_grad, - double *trans_grad, - const int batch, - const int channels, - const int height, - const int width, - const int num_bbox, - const int channels_trans, - const int no_trans, - const double spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const double trans_std) -{ - // LOG(INFO) << "DeformablePSROIPoolBackward"; - const double *top_diff = out_grad; - const double *bottom_data = data; - const double *bottom_rois = bbox; - const double *bottom_trans = no_trans ? NULL : trans; - double *bottom_data_diff = in_grad; - double *bottom_trans_diff = no_trans ? NULL : trans_grad; - const double *top_count_data = top_count; - - const int num_rois = num_bbox; - const int pooled_height = pooled_size; - const int pooled_width = pooled_size; - const int count = num_bbox * output_dim * pooled_height * pooled_width; - const int num_classes = no_trans ? 1 : channels_trans / 2; - const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; - - DeformablePSROIPoolBackwardAccKernel<<>>( - count, top_diff, top_count_data, num_rois, spatial_scale, channels, height, width, - pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff, - bottom_data, bottom_rois, bottom_trans, no_trans, trans_std, sample_per_part, - group_size, part_size, num_classes, channels_each_class); - - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) - { - printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err)); - } -} \ No newline at end of file diff --git a/src/cuda/dcn_v2_psroi_pooling_cuda_double.h b/src/cuda/dcn_v2_psroi_pooling_cuda_double.h deleted file mode 100644 index 8a16f72..0000000 --- a/src/cuda/dcn_v2_psroi_pooling_cuda_double.h +++ /dev/null @@ -1,66 +0,0 @@ -/*! 
- * Copyright (c) 2017 Microsoft - * Licensed under The MIT License [see LICENSE for details] - * \file deformable_psroi_pooling.cu - * \brief - * \author Yi Li, Guodong Zhang, Jifeng Dai -*/ -/***************** Adapted by Charles Shang *********************/ - -#ifndef DCN_V2_PSROI_POOLING_CUDA_DOUBLE -#define DCN_V2_PSROI_POOLING_CUDA_DOUBLE - -#ifdef __cplusplus -extern "C" -{ -#endif - - void DeformablePSROIPoolForward(cudaStream_t stream, - const double *data, - const double *bbox, - const double *trans, - double *out, - double *top_count, - const int batch, - const int channels, - const int height, - const int width, - const int num_bbox, - const int channels_trans, - const int no_trans, - const double spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const double trans_std); - - void DeformablePSROIPoolBackwardAcc(cudaStream_t stream, - const double *out_grad, - const double *data, - const double *bbox, - const double *trans, - const double *top_count, - double *in_grad, - double *trans_grad, - const int batch, - const int channels, - const int height, - const int width, - const int num_bbox, - const int channels_trans, - const int no_trans, - const double spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const double trans_std); - -#ifdef __cplusplus -} -#endif - -#endif \ No newline at end of file diff --git a/src/cuda/vision.h b/src/cuda/vision.h new file mode 100644 index 0000000..e42a2a7 --- /dev/null +++ b/src/cuda/vision.h @@ -0,0 +1,60 @@ +#pragma once +#include + +at::Tensor +dcn_v2_cuda_forward(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + const at::Tensor &offset, + const at::Tensor &mask, + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int dilation_h, + const int dilation_w, + const int deformable_group); + +std::vector +dcn_v2_cuda_backward(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + const at::Tensor &offset, + const at::Tensor &mask, + const at::Tensor &grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group); + + +std::tuple +dcn_v2_psroi_pooling_cuda_forward(const at::Tensor &input, + const at::Tensor &bbox, + const at::Tensor &trans, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std); + +std::tuple +dcn_v2_psroi_pooling_cuda_backward(const at::Tensor &out_grad, + const at::Tensor &input, + const at::Tensor &bbox, + const at::Tensor &trans, + const at::Tensor &top_count, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std); \ No newline at end of file diff --git a/src/dcn_v2.c b/src/dcn_v2.c deleted file mode 100644 index b440d3f..0000000 --- a/src/dcn_v2.c +++ /dev/null @@ -1,30 +0,0 @@ -#include -#include -#include - -void dcn_v2_forward(THFloatTensor *input, THFloatTensor *weight, - THFloatTensor *bias, THFloatTensor *ones, - THFloatTensor *offset, THFloatTensor *mask, - THFloatTensor *output, THFloatTensor *columns, - const int pad_h, 
const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int deformable_group) -{ - printf("only implemented in GPU"); -} - void dcn_v2_backward(THFloatTensor *input, THFloatTensor *weight, - THFloatTensor *bias, THFloatTensor *ones, - THFloatTensor *offset, THFloatTensor *mask, - THFloatTensor *output, THFloatTensor *columns, - THFloatTensor *grad_input, THFloatTensor *grad_weight, - THFloatTensor *grad_bias, THFloatTensor *grad_offset, - THFloatTensor *grad_mask, THFloatTensor *grad_output, - int kernel_h, int kernel_w, - int stride_h, int stride_w, - int pad_h, int pad_w, - int dilation_h, int dilation_w, - int deformable_group) -{ - printf("only implemented in GPU"); -} \ No newline at end of file diff --git a/src/dcn_v2.h b/src/dcn_v2.h index 1a97ff0..23f5caf 100644 --- a/src/dcn_v2.h +++ b/src/dcn_v2.h @@ -1,20 +1,145 @@ -void dcn_v2_forward(THFloatTensor *input, THFloatTensor *weight, - THFloatTensor *bias, THFloatTensor *ones, - THFloatTensor *offset, THFloatTensor *mask, - THFloatTensor *output, THFloatTensor *columns, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int deformable_group); -void dcn_v2_backward(THFloatTensor *input, THFloatTensor *weight, - THFloatTensor *bias, THFloatTensor *ones, - THFloatTensor *offset, THFloatTensor *mask, - THFloatTensor *output, THFloatTensor *columns, - THFloatTensor *grad_input, THFloatTensor *grad_weight, - THFloatTensor *grad_bias, THFloatTensor *grad_offset, - THFloatTensor *grad_mask, THFloatTensor *grad_output, - int kernel_h, int kernel_w, - int stride_h, int stride_w, - int pad_h, int pad_w, - int dilation_h, int dilation_w, - int deformable_group); \ No newline at end of file +#pragma once + +#include "cpu/vision.h" + +#ifdef WITH_CUDA +#include "cuda/vision.h" +#endif + +at::Tensor +dcn_v2_forward(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + const at::Tensor &offset, + const at::Tensor &mask, + const int kernel_h, + const int kernel_w, + const int stride_h, + const int stride_w, + const int pad_h, + const int pad_w, + const int dilation_h, + const int dilation_w, + const int deformable_group) +{ + if (input.type().is_cuda()) + { +#ifdef WITH_CUDA + return dcn_v2_cuda_forward(input, weight, bias, offset, mask, + kernel_h, kernel_w, + stride_h, stride_w, + pad_h, pad_w, + dilation_h, dilation_w, + deformable_group); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::vector +dcn_v2_backward(const at::Tensor &input, + const at::Tensor &weight, + const at::Tensor &bias, + const at::Tensor &offset, + const at::Tensor &mask, + const at::Tensor &grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group) +{ + if (input.type().is_cuda()) + { +#ifdef WITH_CUDA + return dcn_v2_cuda_backward(input, + weight, + bias, + offset, + mask, + grad_output, + kernel_h, kernel_w, + stride_h, stride_w, + pad_h, pad_w, + dilation_h, dilation_w, + deformable_group); +#else + AT_ERROR("Not compiled with GPU support"); +#endif + } + AT_ERROR("Not implemented on the CPU"); +} + +std::tuple +dcn_v2_psroi_pooling_forward(const at::Tensor &input, + const at::Tensor &bbox, + const at::Tensor &trans, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int 
part_size,
+                             const int sample_per_part,
+                             const float trans_std)
+{
+    if (input.type().is_cuda())
+    {
+#ifdef WITH_CUDA
+        return dcn_v2_psroi_pooling_cuda_forward(input,
+                                                 bbox,
+                                                 trans,
+                                                 no_trans,
+                                                 spatial_scale,
+                                                 output_dim,
+                                                 group_size,
+                                                 pooled_size,
+                                                 part_size,
+                                                 sample_per_part,
+                                                 trans_std);
+#else
+        AT_ERROR("Not compiled with GPU support");
+#endif
+    }
+    AT_ERROR("Not implemented on the CPU");
+}
+
+std::tuple<at::Tensor, at::Tensor>
+dcn_v2_psroi_pooling_backward(const at::Tensor &out_grad,
+                              const at::Tensor &input,
+                              const at::Tensor &bbox,
+                              const at::Tensor &trans,
+                              const at::Tensor &top_count,
+                              const int no_trans,
+                              const float spatial_scale,
+                              const int output_dim,
+                              const int group_size,
+                              const int pooled_size,
+                              const int part_size,
+                              const int sample_per_part,
+                              const float trans_std)
+{
+    if (input.type().is_cuda())
+    {
+#ifdef WITH_CUDA
+        return dcn_v2_psroi_pooling_cuda_backward(out_grad,
+                                                  input,
+                                                  bbox,
+                                                  trans,
+                                                  top_count,
+                                                  no_trans,
+                                                  spatial_scale,
+                                                  output_dim,
+                                                  group_size,
+                                                  pooled_size,
+                                                  part_size,
+                                                  sample_per_part,
+                                                  trans_std);
+#else
+        AT_ERROR("Not compiled with GPU support");
+#endif
+    }
+    AT_ERROR("Not implemented on the CPU");
+}
\ No newline at end of file
diff --git a/src/dcn_v2_cuda.c b/src/dcn_v2_cuda.c
deleted file mode 100644
index 1503b5d..0000000
--- a/src/dcn_v2_cuda.c
+++ /dev/null
@@ -1,335 +0,0 @@
-#include <THC/THC.h>
-#include "cuda/dcn_v2_im2col_cuda.h"
-#include "cuda/dcn_v2_psroi_pooling_cuda.h"
-
-extern THCState *state;
-
-// author: Charles Shang
-// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
-
-void dcn_v2_cuda_forward(THCudaTensor *input, THCudaTensor *weight,
-                         THCudaTensor *bias, THCudaTensor *ones,
-                         THCudaTensor *offset, THCudaTensor *mask,
-                         THCudaTensor *output, THCudaTensor *columns,
-                         int kernel_h, int kernel_w,
-                         const int stride_h, const int stride_w,
-                         const int pad_h, const int pad_w,
-                         const int dilation_h, const int dilation_w,
-                         const int deformable_group)
-{
-    THCAssertSameGPU(THCudaTensor_checkGPU(state, 8, input, weight, bias, ones, offset, mask, output, columns));
-    THArgCheck(THCudaTensor_isContiguous(state, input), 1, "input tensor has to be contiguous");
-    THArgCheck(THCudaTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous");
-
-    const int batch = THCudaTensor_size(state, input, 0);
-    const int channels = THCudaTensor_size(state, input, 1);
-    const int height = THCudaTensor_size(state, input, 2);
-    const int width = THCudaTensor_size(state, input, 3);
-
-    const int channels_out = THCudaTensor_size(state, weight, 0);
-    const int channels_kernel = THCudaTensor_size(state, weight, 1);
-    const int kernel_h_ = THCudaTensor_size(state, weight, 2);
-    const int kernel_w_ = THCudaTensor_size(state, weight, 3);
-    if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
-        THError("Input shape and kernel shape wont match: (%d x %d vs %d x %d).",
-                kernel_h_, kernel_w, kernel_h_, kernel_w_);
-    if (channels != channels_kernel)
-        THError("Input shape and kernel channels wont match: (%d vs %d).",
-                channels, channels_kernel);
-
-    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
-    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
-
-    if (THCudaTensor_nDimension(state, ones) != 2 ||
-        THCudaTensor_size(state, ones, 0) * THCudaTensor_size(state, ones, 1) < height_out * width_out)
-    {
-        // Resize plane and fill with ones...
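A note on the THC path being deleted here, since the new ATen forward keeps the same structure: the `ones` plane resized just below is the left operand of a rank-1 GEMM that broadcasts the bias. With the bias viewed as a (channels_out x 1) matrix and the ones plane as a (1 x height_out*width_out) matrix, their product fills every spatial position of each output channel with that channel's bias value (beta = 0 overwrites the output), and the later weight-times-columns GEMM accumulates on top of it with beta = 1. Below is a minimal CPU sketch of that identity; the `gemm` helper is a hypothetical stand-in for `THCudaBlas_Sgemm` and appears in neither the old nor the new code.

```cpp
#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical row-major GEMM helper standing in for THCudaBlas_Sgemm:
// C(m x n) = alpha * A(m x k) * B(k x n) + beta * C.
static void gemm(int m, int n, int k, float alpha,
                 const float *A, const float *B, float beta, float *C)
{
    for (int i = 0; i < m; ++i)
        for (int j = 0; j < n; ++j) {
            float acc = 0.f;
            for (int p = 0; p < k; ++p)
                acc += A[i * k + p] * B[p * n + j];
            C[i * n + j] = alpha * acc + beta * C[i * n + j];
        }
}

int main()
{
    const int channels_out = 3, height_out = 2, width_out = 2;
    const int hw = height_out * width_out;
    std::vector<float> bias = {0.5f, -1.0f, 2.0f};   // (channels_out x 1)
    std::vector<float> ones(hw, 1.0f);               // (1 x hw), the "ones" plane
    std::vector<float> out(channels_out * hw, 0.f);  // (channels_out x hw) output map

    // Rank-1 update (k = 1): out[c][p] = bias[c] * 1 for every position p,
    // i.e. the bias is broadcast across the whole spatial map in one GEMM call.
    gemm(channels_out, hw, 1, 1.0f, bias.data(), ones.data(), 0.0f, out.data());

    for (int c = 0; c < channels_out; ++c)
        for (int p = 0; p < hw; ++p)
            assert(out[c * hw + p] == bias[c]);
    std::printf("bias broadcast OK\n");
    return 0;
}
```

Expressing the bias as a GEMM keeps the whole forward pass on BLAS primitives, so no separate broadcast kernel is needed; the subsequent im2col-based convolution GEMM simply accumulates into the pre-filled output.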
- THCudaTensor_resize2d(state, ones, height_out, width_out); - THCudaTensor_fill(state, ones, 1); - } - - // resize output - THCudaTensor_resize4d(state, output, batch, channels_out, height_out, width_out); - // resize temporary columns - THCudaTensor_resize2d(state, columns, channels * kernel_h * kernel_w, 1 * height_out * width_out); - - THCudaTensor *input_n = THCudaTensor_new(state); - THCudaTensor *offset_n = THCudaTensor_new(state); - THCudaTensor *mask_n = THCudaTensor_new(state); - THCudaTensor *output_n = THCudaTensor_new(state); - - for (int b = 0; b < batch; b++) - { - THCudaTensor_select(state, input_n, input, 0, b); - THCudaTensor_select(state, offset_n, offset, 0, b); - THCudaTensor_select(state, mask_n, mask, 0, b); - THCudaTensor_select(state, output_n, output, 0, b); - - // Do Bias first: - // M,N,K are dims of matrix A and B - // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) - // (N x 1) (1 x M) - long m_ = channels_out; - long n_ = height_out * width_out; - long k_ = 1; - THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, - THCudaTensor_data(state, ones), k_, - THCudaTensor_data(state, bias), k_, 0.0f, - THCudaTensor_data(state, output_n), n_); - - modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), - THCudaTensor_data(state, input_n), THCudaTensor_data(state, offset_n), - THCudaTensor_data(state, mask_n), - 1, channels, height, width, - height_out, width_out, kernel_h, kernel_w, - pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, - deformable_group, THCudaTensor_data(state, columns)); - - //(k * m) x (m * n) - // Y = WC - long m = channels_out; - long n = height_out * width_out; - long k = channels * kernel_h * kernel_w; - THCudaBlas_Sgemm(state, 'n', 'n', n, m, k, 1.0f, - THCudaTensor_data(state, columns), n, - THCudaTensor_data(state, weight), k, 1.0f, - THCudaTensor_data(state, output_n), n); - } - THCudaTensor_free(state, input_n); - THCudaTensor_free(state, offset_n); - THCudaTensor_free(state, mask_n); - THCudaTensor_free(state, output_n); -} - -void dcn_v2_cuda_backward(THCudaTensor *input, THCudaTensor *weight, - THCudaTensor *bias, THCudaTensor *ones, - THCudaTensor *offset, THCudaTensor *mask, - THCudaTensor *columns, - THCudaTensor *grad_input, THCudaTensor *grad_weight, - THCudaTensor *grad_bias, THCudaTensor *grad_offset, - THCudaTensor *grad_mask, THCudaTensor *grad_output, - int kernel_h, int kernel_w, - int stride_h, int stride_w, - int pad_h, int pad_w, - int dilation_h, int dilation_w, - int deformable_group) -{ - THCAssertSameGPU(THCudaTensor_checkGPU(state, 13, input, weight, bias, ones, offset, mask, columns, - grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output)); - THArgCheck(THCudaTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); - THArgCheck(THCudaTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous"); - - const int batch = THCudaTensor_size(state, input, 0); - const int channels = THCudaTensor_size(state, input, 1); - const int height = THCudaTensor_size(state, input, 2); - const int width = THCudaTensor_size(state, input, 3); - - const int channels_out = THCudaTensor_size(state, weight, 0); - const int channels_kernel = THCudaTensor_size(state, weight, 1); - const int kernel_h_ = THCudaTensor_size(state, weight, 2); - const int kernel_w_ = THCudaTensor_size(state, weight, 3); - if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) - THError("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", - kernel_h_, kernel_w, kernel_h_, 
kernel_w_); - if (channels != channels_kernel) - THError("Input shape and kernel channels wont match: (%d vs %d).", - channels, channels_kernel); - - const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; - const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; - - if (THCudaTensor_nDimension(state, ones) != 2 || - THCudaTensor_size(state, ones, 0) * THCudaTensor_size(state, ones, 1) < height_out * width_out) - { - // Resize plane and fill with ones... - THCudaTensor_resize2d(state, ones, height_out, width_out); - THCudaTensor_fill(state, ones, 1.0f); - } - - THCudaTensor_resize4d(state, grad_input, batch, channels, height, width); - THCudaTensor_resize2d(state, columns, channels * kernel_h * kernel_w, height_out * width_out); - - THCudaTensor *input_n = THCudaTensor_new(state); - THCudaTensor *offset_n = THCudaTensor_new(state); - THCudaTensor *mask_n = THCudaTensor_new(state); - - THCudaTensor *grad_output_n = THCudaTensor_new(state); - THCudaTensor *grad_input_n = THCudaTensor_new(state); - THCudaTensor *grad_offset_n = THCudaTensor_new(state); - THCudaTensor *grad_mask_n = THCudaTensor_new(state); - - for (int b = 0; b < batch; b++) - { - THCudaTensor_select(state, input_n, input, 0, b); - THCudaTensor_select(state, offset_n, offset, 0, b); - THCudaTensor_select(state, mask_n, mask, 0, b); - THCudaTensor_select(state, grad_output_n, grad_output, 0, b); - THCudaTensor_select(state, grad_input_n, grad_input, 0, b); - THCudaTensor_select(state, grad_offset_n, grad_offset, 0, b); - THCudaTensor_select(state, grad_mask_n, grad_mask, 0, b); - - long m = channels * kernel_h * kernel_w; - long n = height_out * width_out; - long k = channels_out; - - THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f, - THCudaTensor_data(state, grad_output_n), n, - THCudaTensor_data(state, weight), m, 0.0f, - THCudaTensor_data(state, columns), n); - - // gradient w.r.t. input coordinate data - modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state), - THCudaTensor_data(state, columns), - THCudaTensor_data(state, input_n), - THCudaTensor_data(state, offset_n), - THCudaTensor_data(state, mask_n), - 1, channels, height, width, - height_out, width_out, kernel_h, kernel_w, - pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, deformable_group, - THCudaTensor_data(state, grad_offset_n), - THCudaTensor_data(state, grad_mask_n)); - // gradient w.r.t. input data - modulated_deformable_col2im_cuda(THCState_getCurrentStream(state), - THCudaTensor_data(state, columns), - THCudaTensor_data(state, offset_n), - THCudaTensor_data(state, mask_n), - 1, channels, height, width, - height_out, width_out, kernel_h, kernel_w, - pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, deformable_group, - THCudaTensor_data(state, grad_input_n)); - - // gradient w.r.t. 
weight, dWeight should accumulate across the batch and group - modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), - THCudaTensor_data(state, input_n), - THCudaTensor_data(state, offset_n), - THCudaTensor_data(state, mask_n), - 1, channels, height, width, - height_out, width_out, kernel_h, kernel_w, - pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, deformable_group, - THCudaTensor_data(state, columns)); - long m_ = channels_out; - long n_ = channels * kernel_h * kernel_w; - long k_ = height_out * width_out; - - THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, - THCudaTensor_data(state, columns), k_, - THCudaTensor_data(state, grad_output_n), k_, 1.0f, - THCudaTensor_data(state, grad_weight), n_); - - // gradient w.r.t. bias - // long m_ = channels_out; - // long k__ = height_out * width_out; - THCudaBlas_Sgemv(state, - 't', - k_, m_, 1.0f, - THCudaTensor_data(state, grad_output_n), k_, - THCudaTensor_data(state, ones), 1, 1.0f, - THCudaTensor_data(state, grad_bias), 1); - } - - THCudaTensor_free(state, input_n); - THCudaTensor_free(state, offset_n); - THCudaTensor_free(state, mask_n); - - THCudaTensor_free(state, grad_output_n); - THCudaTensor_free(state, grad_input_n); - THCudaTensor_free(state, grad_offset_n); - THCudaTensor_free(state, grad_mask_n); -} - -void dcn_v2_psroi_pooling_cuda_forward(THCudaTensor * input, THCudaTensor * bbox, - THCudaTensor * trans, - THCudaTensor * out, THCudaTensor * top_count, - const int no_trans, - const float spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const float trans_std) -{ - THArgCheck(THCudaTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); - THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, bbox, trans, out, top_count)); - - const int batch = THCudaTensor_size(state, input, 0); - const int channels = THCudaTensor_size(state, input, 1); - const int height = THCudaTensor_size(state, input, 2); - const int width = THCudaTensor_size(state, input, 3); - const int channels_trans = no_trans? 
2 : THCudaTensor_size(state, trans, 1); - - const int num_bbox = THCudaTensor_size(state, bbox, 0); - if (num_bbox != THCudaTensor_size(state, out, 0)) - THError("Output shape and bbox number wont match: (%d vs %d).", - THCudaTensor_size(state, out, 0), num_bbox); - - DeformablePSROIPoolForward(THCState_getCurrentStream(state), - THCudaTensor_data(state, input), - THCudaTensor_data(state, bbox), - THCudaTensor_data(state, trans), - THCudaTensor_data(state, out), - THCudaTensor_data(state, top_count), - batch, channels, height, width, - num_bbox, - channels_trans, - no_trans, - spatial_scale, - output_dim, - group_size, - pooled_size, - part_size, - sample_per_part, - trans_std); -} - -void dcn_v2_psroi_pooling_cuda_backward(THCudaTensor * out_grad, - THCudaTensor * input, THCudaTensor * bbox, - THCudaTensor * trans, THCudaTensor * top_count, - THCudaTensor * input_grad, THCudaTensor * trans_grad, - const int no_trans, - const float spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const float trans_std) -{ - THArgCheck(THCudaTensor_isContiguous(state, out_grad), 0, "out_grad tensor has to be contiguous"); - THArgCheck(THCudaTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); - THCAssertSameGPU(THCudaTensor_checkGPU(state, 7, input, bbox, trans, out_grad, top_count, - input_grad, trans_grad)); - - const int batch = THCudaTensor_size(state, input, 0); - const int channels = THCudaTensor_size(state, input, 1); - const int height = THCudaTensor_size(state, input, 2); - const int width = THCudaTensor_size(state, input, 3); - const int channels_trans = no_trans? 2 : THCudaTensor_size(state, trans, 1); - - const int num_bbox = THCudaTensor_size(state, bbox, 0); - if (num_bbox != THCudaTensor_size(state, out_grad, 0)) - THError("Output shape and bbox number wont match: (%d vs %d).", - THCudaTensor_size(state, out_grad, 0), num_bbox); - - DeformablePSROIPoolBackwardAcc(THCState_getCurrentStream(state), - THCudaTensor_data(state, out_grad), - THCudaTensor_data(state, input), - THCudaTensor_data(state, bbox), - THCudaTensor_data(state, trans), - THCudaTensor_data(state, top_count), - THCudaTensor_data(state, input_grad), - THCudaTensor_data(state, trans_grad), - batch, channels, height, width, num_bbox, - channels_trans, - no_trans, - spatial_scale, - output_dim, - group_size, - pooled_size, - part_size, - sample_per_part, - trans_std); -} \ No newline at end of file diff --git a/src/dcn_v2_cuda.h b/src/dcn_v2_cuda.h deleted file mode 100644 index 70a27a8..0000000 --- a/src/dcn_v2_cuda.h +++ /dev/null @@ -1,60 +0,0 @@ -// #ifndef DCN_V2_CUDA -// #define DCN_V2_CUDA - -// #ifdef __cplusplus -// extern "C" -// { -// #endif - -void dcn_v2_cuda_forward(THCudaTensor *input, THCudaTensor *weight, - THCudaTensor *bias, THCudaTensor *ones, - THCudaTensor *offset, THCudaTensor *mask, - THCudaTensor *output, THCudaTensor *columns, - int kernel_h, int kernel_w, - const int stride_h, const int stride_w, - const int pad_h, const int pad_w, - const int dilation_h, const int dilation_w, - const int deformable_group); -void dcn_v2_cuda_backward(THCudaTensor *input, THCudaTensor *weight, - THCudaTensor *bias, THCudaTensor *ones, - THCudaTensor *offset, THCudaTensor *mask, - THCudaTensor *columns, - THCudaTensor *grad_input, THCudaTensor *grad_weight, - THCudaTensor *grad_bias, THCudaTensor *grad_offset, - THCudaTensor *grad_mask, THCudaTensor *grad_output, - int kernel_h, int kernel_w, - int stride_h, 
int stride_w, - int pad_h, int pad_w, - int dilation_h, int dilation_w, - int deformable_group); - -void dcn_v2_psroi_pooling_cuda_forward(THCudaTensor * input, THCudaTensor * bbox, - THCudaTensor * trans, - THCudaTensor * out, THCudaTensor * top_count, - const int no_trans, - const float spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const float trans_std); - -void dcn_v2_psroi_pooling_cuda_backward(THCudaTensor * out_grad, - THCudaTensor * input, THCudaTensor * bbox, - THCudaTensor * trans, THCudaTensor * top_count, - THCudaTensor * input_grad, THCudaTensor * trans_grad, - const int no_trans, - const float spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const float trans_std); - -// #ifdef __cplusplus -// } -// #endif - -// #endif \ No newline at end of file diff --git a/src/dcn_v2_cuda_double.c b/src/dcn_v2_cuda_double.c deleted file mode 100644 index 021ef12..0000000 --- a/src/dcn_v2_cuda_double.c +++ /dev/null @@ -1,358 +0,0 @@ -#include -#include "cuda/dcn_v2_im2col_cuda_double.h" -#include "cuda/dcn_v2_psroi_pooling_cuda_double.h" - -extern THCState *state; - -// author: Charles Shang -// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu - -void dcn_v2_cuda_forward(THCudaDoubleTensor *input, THCudaDoubleTensor *weight, - THCudaDoubleTensor *bias, THCudaDoubleTensor *ones, - THCudaDoubleTensor *offset, THCudaDoubleTensor *mask, - THCudaDoubleTensor *output, THCudaDoubleTensor *columns, - int kernel_h, int kernel_w, - const int stride_h, const int stride_w, - const int pad_h, const int pad_w, - const int dilation_h, const int dilation_w, - const int deformable_group) -{ - THCAssertSameGPU(THCudaDoubleTensor_checkGPU(state, 8, input, weight, bias, ones, offset, mask, output, columns)); - THArgCheck(THCudaDoubleTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); - THArgCheck(THCudaDoubleTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous"); - - input = THCudaDoubleTensor_newContiguous(state, input); - offset = THCudaDoubleTensor_newContiguous(state, offset); - mask = THCudaDoubleTensor_newContiguous(state, mask); - weight = THCudaDoubleTensor_newContiguous(state, weight); - - const int batch = THCudaDoubleTensor_size(state, input, 0); - const int channels = THCudaDoubleTensor_size(state, input, 1); - const int height = THCudaDoubleTensor_size(state, input, 2); - const int width = THCudaDoubleTensor_size(state, input, 3); - - const int channels_out = THCudaDoubleTensor_size(state, weight, 0); - const int channels_kernel = THCudaDoubleTensor_size(state, weight, 1); - const int kernel_h_ = THCudaDoubleTensor_size(state, weight, 2); - const int kernel_w_ = THCudaDoubleTensor_size(state, weight, 3); - if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) - THError("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", - kernel_h_, kernel_w, kernel_h_, kernel_w_); - if (channels != channels_kernel) - THError("Input shape and kernel channels wont match: (%d vs %d).", - channels, channels_kernel); - - const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; - const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; - - if (THCudaDoubleTensor_nDimension(state, ones) != 2 || - THCudaDoubleTensor_size(state, ones, 0) * 
THCudaDoubleTensor_size(state, ones, 1) < height_out * width_out) - { - // Resize plane and fill with ones... - THCudaDoubleTensor_resize2d(state, ones, height_out, width_out); - THCudaDoubleTensor_fill(state, ones, 1); - } - - // resize output - THCudaDoubleTensor_resize4d(state, output, batch, channels_out, height_out, width_out); - // resize temporary columns - THCudaDoubleTensor_resize2d(state, columns, channels * kernel_h * kernel_w, 1 * height_out * width_out); - - THCudaDoubleTensor *input_n = THCudaDoubleTensor_new(state); - THCudaDoubleTensor *offset_n = THCudaDoubleTensor_new(state); - THCudaDoubleTensor *mask_n = THCudaDoubleTensor_new(state); - THCudaDoubleTensor *output_n = THCudaDoubleTensor_new(state); - - for (int b = 0; b < batch; b++) - { - THCudaDoubleTensor_select(state, input_n, input, 0, b); - THCudaDoubleTensor_select(state, offset_n, offset, 0, b); - THCudaDoubleTensor_select(state, mask_n, mask, 0, b); - THCudaDoubleTensor_select(state, output_n, output, 0, b); - - // Do Bias first: - // M,N,K are dims of matrix A and B - // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) - // (N x 1) (1 x M) - long m_ = channels_out; - long n_ = height_out * width_out; - long k_ = 1; - THCudaBlas_Dgemm(state, 't', 'n', n_, m_, k_, 1.0, - THCudaDoubleTensor_data(state, ones), k_, - THCudaDoubleTensor_data(state, bias), k_, 0.0, - THCudaDoubleTensor_data(state, output_n), n_); - - modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), - THCudaDoubleTensor_data(state, input_n), THCudaDoubleTensor_data(state, offset_n), - THCudaDoubleTensor_data(state, mask_n), - 1, channels, height, width, - height_out, width_out, kernel_h, kernel_w, - pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, - deformable_group, THCudaDoubleTensor_data(state, columns)); - - //(k * m) x (m * n) - // Y = WC - long m = channels_out; - long n = height_out * width_out; - long k = channels * kernel_h * kernel_w; - THCudaBlas_Dgemm(state, 'n', 'n', n, m, k, 1.0f, - THCudaDoubleTensor_data(state, columns), n, - THCudaDoubleTensor_data(state, weight), k, 1.0f, - THCudaDoubleTensor_data(state, output_n), n); - } - THCudaDoubleTensor_free(state, input_n); - THCudaDoubleTensor_free(state, offset_n); - THCudaDoubleTensor_free(state, mask_n); - THCudaDoubleTensor_free(state, output_n); - - THCudaDoubleTensor_free(state, input); - THCudaDoubleTensor_free(state, offset); - THCudaDoubleTensor_free(state, mask); - THCudaDoubleTensor_free(state, weight); -} - -void dcn_v2_cuda_backward(THCudaDoubleTensor *input, THCudaDoubleTensor *weight, - THCudaDoubleTensor *bias, THCudaDoubleTensor *ones, - THCudaDoubleTensor *offset, THCudaDoubleTensor *mask, - THCudaDoubleTensor *columns, - THCudaDoubleTensor *grad_input, THCudaDoubleTensor *grad_weight, - THCudaDoubleTensor *grad_bias, THCudaDoubleTensor *grad_offset, - THCudaDoubleTensor *grad_mask, THCudaDoubleTensor *grad_output, - int kernel_h, int kernel_w, - int stride_h, int stride_w, - int pad_h, int pad_w, - int dilation_h, int dilation_w, - int deformable_group) -{ - THCAssertSameGPU(THCudaDoubleTensor_checkGPU(state, 13, input, weight, bias, ones, offset, mask, columns, - grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output)); - THArgCheck(THCudaDoubleTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); - THArgCheck(THCudaDoubleTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous"); - - input = THCudaDoubleTensor_newContiguous(state, input); - offset = 
THCudaDoubleTensor_newContiguous(state, offset); - mask = THCudaDoubleTensor_newContiguous(state, mask); - weight = THCudaDoubleTensor_newContiguous(state, weight); - grad_output = THCudaDoubleTensor_newContiguous(state, grad_output); - - const int batch = THCudaDoubleTensor_size(state, input, 0); - const int channels = THCudaDoubleTensor_size(state, input, 1); - const int height = THCudaDoubleTensor_size(state, input, 2); - const int width = THCudaDoubleTensor_size(state, input, 3); - - const int channels_out = THCudaDoubleTensor_size(state, weight, 0); - const int channels_kernel = THCudaDoubleTensor_size(state, weight, 1); - const int kernel_h_ = THCudaDoubleTensor_size(state, weight, 2); - const int kernel_w_ = THCudaDoubleTensor_size(state, weight, 3); - if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) - THError("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", - kernel_h_, kernel_w, kernel_h_, kernel_w_); - if (channels != channels_kernel) - THError("Input shape and kernel channels wont match: (%d vs %d).", - channels, channels_kernel); - - const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; - const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; - - if (THCudaDoubleTensor_nDimension(state, ones) != 2 || - THCudaDoubleTensor_size(state, ones, 0) * THCudaDoubleTensor_size(state, ones, 1) < height_out * width_out) - { - // Resize plane and fill with ones... - THCudaDoubleTensor_resize2d(state, ones, height_out, width_out); - THCudaDoubleTensor_fill(state, ones, 1); - } - - // THCudaDoubleTensor_resize4d(state, grad_input, batch, channels, height, width); - THCudaDoubleTensor_resize2d(state, columns, channels * kernel_h * kernel_w, height_out * width_out); - - THCudaDoubleTensor *input_n = THCudaDoubleTensor_new(state); - THCudaDoubleTensor *offset_n = THCudaDoubleTensor_new(state); - THCudaDoubleTensor *mask_n = THCudaDoubleTensor_new(state); - - THCudaDoubleTensor *grad_output_n = THCudaDoubleTensor_new(state); - THCudaDoubleTensor *grad_input_n = THCudaDoubleTensor_new(state); - THCudaDoubleTensor *grad_offset_n = THCudaDoubleTensor_new(state); - THCudaDoubleTensor *grad_mask_n = THCudaDoubleTensor_new(state); - - for (int b = 0; b < batch; b++) - { - THCudaDoubleTensor_select(state, input_n, input, 0, b); - THCudaDoubleTensor_select(state, offset_n, offset, 0, b); - THCudaDoubleTensor_select(state, mask_n, mask, 0, b); - THCudaDoubleTensor_select(state, grad_output_n, grad_output, 0, b); - THCudaDoubleTensor_select(state, grad_input_n, grad_input, 0, b); - THCudaDoubleTensor_select(state, grad_offset_n, grad_offset, 0, b); - THCudaDoubleTensor_select(state, grad_mask_n, grad_mask, 0, b); - - long m = channels * kernel_h * kernel_w; - long n = height_out * width_out; - long k = channels_out; - - THCudaBlas_Dgemm(state, 'n', 't', n, m, k, 1.0, - THCudaDoubleTensor_data(state, grad_output_n), n, - THCudaDoubleTensor_data(state, weight), m, 0.0, - THCudaDoubleTensor_data(state, columns), n); - - // gradient w.r.t. 
input offset and mask data - modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state), - THCudaDoubleTensor_data(state, columns), - THCudaDoubleTensor_data(state, input_n), - THCudaDoubleTensor_data(state, offset_n), - THCudaDoubleTensor_data(state, mask_n), - 1, channels, height, width, - height_out, width_out, kernel_h, kernel_w, - pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, deformable_group, - THCudaDoubleTensor_data(state, grad_offset_n), - THCudaDoubleTensor_data(state, grad_mask_n)); - // gradient w.r.t. input data - modulated_deformable_col2im_cuda(THCState_getCurrentStream(state), - THCudaDoubleTensor_data(state, columns), - THCudaDoubleTensor_data(state, offset_n), - THCudaDoubleTensor_data(state, mask_n), - 1, channels, height, width, - height_out, width_out, kernel_h, kernel_w, - pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, deformable_group, - THCudaDoubleTensor_data(state, grad_input_n)); - - // gradient w.r.t. weight, dWeight should accumulate across the batch and group - modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), - THCudaDoubleTensor_data(state, input_n), - THCudaDoubleTensor_data(state, offset_n), - THCudaDoubleTensor_data(state, mask_n), - 1, channels, height, width, - height_out, width_out, kernel_h, kernel_w, - pad_h, pad_w, stride_h, stride_w, - dilation_h, dilation_w, deformable_group, - THCudaDoubleTensor_data(state, columns)); - long m_ = channels_out; - long n_ = channels * kernel_h * kernel_w; - long k_ = height_out * width_out; - - THCudaBlas_Dgemm(state, 't', 'n', n_, m_, k_, 1.0, - THCudaDoubleTensor_data(state, columns), k_, - THCudaDoubleTensor_data(state, grad_output_n), k_, 1.0, - THCudaDoubleTensor_data(state, grad_weight), n_); - - // gradient w.r.t. bias - // long m_ = channels_out; - // long k__ = height_out * width_out; - THCudaBlas_Dgemv(state, - 't', - k_, m_, 1.0, - THCudaDoubleTensor_data(state, grad_output_n), k_, - THCudaDoubleTensor_data(state, ones), 1, 1.0, - THCudaDoubleTensor_data(state, grad_bias), 1); - } - - THCudaDoubleTensor_free(state, input_n); - THCudaDoubleTensor_free(state, offset_n); - THCudaDoubleTensor_free(state, mask_n); - - THCudaDoubleTensor_free(state, grad_output_n); - THCudaDoubleTensor_free(state, grad_input_n); - THCudaDoubleTensor_free(state, grad_offset_n); - THCudaDoubleTensor_free(state, grad_mask_n); - - THCudaDoubleTensor_free(state, input); - THCudaDoubleTensor_free(state, offset); - THCudaDoubleTensor_free(state, mask); - THCudaDoubleTensor_free(state, weight); - THCudaDoubleTensor_free(state, grad_output); -} - - -void dcn_v2_psroi_pooling_cuda_forward(THCudaDoubleTensor * input, THCudaDoubleTensor * bbox, - THCudaDoubleTensor * trans, - THCudaDoubleTensor * out, THCudaDoubleTensor * top_count, - const int no_trans, - const double spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const double trans_std) -{ - THArgCheck(THCudaDoubleTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); - THCAssertSameGPU(THCudaDoubleTensor_checkGPU(state, 5, input, bbox, trans, out, top_count)); - - const int batch = THCudaDoubleTensor_size(state, input, 0); - const int channels = THCudaDoubleTensor_size(state, input, 1); - const int height = THCudaDoubleTensor_size(state, input, 2); - const int width = THCudaDoubleTensor_size(state, input, 3); - const int channels_trans = no_trans? 
2 : THCudaDoubleTensor_size(state, trans, 1); - - const int num_bbox = THCudaDoubleTensor_size(state, bbox, 0); - if (num_bbox != THCudaDoubleTensor_size(state, out, 0)) - THError("Output shape and bbox number wont match: (%d vs %d).", - THCudaDoubleTensor_size(state, out, 0), num_bbox); - - DeformablePSROIPoolForward(THCState_getCurrentStream(state), - THCudaDoubleTensor_data(state, input), - THCudaDoubleTensor_data(state, bbox), - THCudaDoubleTensor_data(state, trans), - THCudaDoubleTensor_data(state, out), - THCudaDoubleTensor_data(state, top_count), - batch, channels, height, width, - num_bbox, - channels_trans, - no_trans, - spatial_scale, - output_dim, - group_size, - pooled_size, - part_size, - sample_per_part, - trans_std); -} - -void dcn_v2_psroi_pooling_cuda_backward(THCudaDoubleTensor * out_grad, - THCudaDoubleTensor * input, THCudaDoubleTensor * bbox, - THCudaDoubleTensor * trans, THCudaDoubleTensor * top_count, - THCudaDoubleTensor * input_grad, THCudaDoubleTensor * trans_grad, - const int no_trans, - const double spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const double trans_std) -{ - THArgCheck(THCudaDoubleTensor_isContiguous(state, out_grad), 0, "out_grad tensor has to be contiguous"); - THArgCheck(THCudaDoubleTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); - THCAssertSameGPU(THCudaDoubleTensor_checkGPU(state, 7, input, bbox, trans, out_grad, top_count, - input_grad, trans_grad)); - - const int batch = THCudaDoubleTensor_size(state, input, 0); - const int channels = THCudaDoubleTensor_size(state, input, 1); - const int height = THCudaDoubleTensor_size(state, input, 2); - const int width = THCudaDoubleTensor_size(state, input, 3); - const int channels_trans = no_trans? 
2 : THCudaDoubleTensor_size(state, trans, 1); - - const int num_bbox = THCudaDoubleTensor_size(state, bbox, 0); - if (num_bbox != THCudaDoubleTensor_size(state, out_grad, 0)) - THError("Output shape and bbox number wont match: (%d vs %d).", - THCudaDoubleTensor_size(state, out_grad, 0), num_bbox); - - DeformablePSROIPoolBackwardAcc(THCState_getCurrentStream(state), - THCudaDoubleTensor_data(state, out_grad), - THCudaDoubleTensor_data(state, input), - THCudaDoubleTensor_data(state, bbox), - THCudaDoubleTensor_data(state, trans), - THCudaDoubleTensor_data(state, top_count), - THCudaDoubleTensor_data(state, input_grad), - THCudaDoubleTensor_data(state, trans_grad), - batch, channels, height, width, num_bbox, - channels_trans, - no_trans, - spatial_scale, - output_dim, - group_size, - pooled_size, - part_size, - sample_per_part, - trans_std); -} \ No newline at end of file diff --git a/src/dcn_v2_cuda_double.h b/src/dcn_v2_cuda_double.h deleted file mode 100644 index 826cb2b..0000000 --- a/src/dcn_v2_cuda_double.h +++ /dev/null @@ -1,61 +0,0 @@ -// #ifndef DCN_V2_CUDA -// #define DCN_V2_CUDA - -// #ifdef __cplusplus -// extern "C" -// { -// #endif - -void dcn_v2_cuda_forward(THCudaDoubleTensor *input, THCudaDoubleTensor *weight, - THCudaDoubleTensor *bias, THCudaDoubleTensor *ones, - THCudaDoubleTensor *offset, THCudaDoubleTensor *mask, - THCudaDoubleTensor *output, THCudaDoubleTensor *columns, - int kernel_h, int kernel_w, - const int stride_h, const int stride_w, - const int pad_h, const int pad_w, - const int dilation_h, const int dilation_w, - const int deformable_group); -void dcn_v2_cuda_backward(THCudaDoubleTensor *input, THCudaDoubleTensor *weight, - THCudaDoubleTensor *bias, THCudaDoubleTensor *ones, - THCudaDoubleTensor *offset, THCudaDoubleTensor *mask, - THCudaDoubleTensor *columns, - THCudaDoubleTensor *grad_input, THCudaDoubleTensor *grad_weight, - THCudaDoubleTensor *grad_bias, THCudaDoubleTensor *grad_offset, - THCudaDoubleTensor *grad_mask, THCudaDoubleTensor *grad_output, - int kernel_h, int kernel_w, - int stride_h, int stride_w, - int pad_h, int pad_w, - int dilation_h, int dilation_w, - int deformable_group); - -void dcn_v2_psroi_pooling_cuda_forward(THCudaDoubleTensor * input, THCudaDoubleTensor * bbox, - THCudaDoubleTensor * trans, - THCudaDoubleTensor * out, THCudaDoubleTensor * top_count, - const int no_trans, - const double spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const double trans_std); - -void dcn_v2_psroi_pooling_cuda_backward(THCudaDoubleTensor * out_grad, - THCudaDoubleTensor * input, THCudaDoubleTensor * bbox, - THCudaDoubleTensor * trans, THCudaDoubleTensor * top_count, - THCudaDoubleTensor * input_grad, THCudaDoubleTensor * trans_grad, - const int no_trans, - const double spatial_scale, - const int output_dim, - const int group_size, - const int pooled_size, - const int part_size, - const int sample_per_part, - const double trans_std); - - -// #ifdef __cplusplus -// } -// #endif - -// #endif \ No newline at end of file diff --git a/src/dcn_v2_double.c b/src/dcn_v2_double.c deleted file mode 100644 index 2b86545..0000000 --- a/src/dcn_v2_double.c +++ /dev/null @@ -1,30 +0,0 @@ -#include -#include -#include - -void dcn_v2_forward(THDoubleTensor *input, THDoubleTensor *weight, - THDoubleTensor *bias, THDoubleTensor *ones, - THDoubleTensor *offset, THDoubleTensor *mask, - THDoubleTensor *output, THDoubleTensor *columns, - const int pad_h, const int pad_w, - const 
int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int deformable_group) -{ - printf("only implemented in GPU"); -} -void dcn_v2_backward(THDoubleTensor *input, THDoubleTensor *weight, - THDoubleTensor *bias, THDoubleTensor *ones, - THDoubleTensor *offset, THDoubleTensor *mask, - THDoubleTensor *output, THDoubleTensor *columns, - THDoubleTensor *grad_input, THDoubleTensor *grad_weight, - THDoubleTensor *grad_bias, THDoubleTensor *grad_offset, - THDoubleTensor *grad_mask, THDoubleTensor *grad_output, - int kernel_h, int kernel_w, - int stride_h, int stride_w, - int pad_h, int pad_w, - int dilation_h, int dilation_w, - int deformable_group) -{ - printf("only implemented in GPU"); -} \ No newline at end of file diff --git a/src/dcn_v2_double.h b/src/dcn_v2_double.h deleted file mode 100644 index eda1f4c..0000000 --- a/src/dcn_v2_double.h +++ /dev/null @@ -1,20 +0,0 @@ -void dcn_v2_forward(THDoubleTensor *input, THDoubleTensor *weight, - THDoubleTensor *bias, THDoubleTensor *ones, - THDoubleTensor *offset, THDoubleTensor *mask, - THDoubleTensor *output, THDoubleTensor *columns, - const int pad_h, const int pad_w, - const int stride_h, const int stride_w, - const int dilation_h, const int dilation_w, - const int deformable_group); -void dcn_v2_backward(THDoubleTensor *input, THDoubleTensor *weight, - THDoubleTensor *bias, THDoubleTensor *ones, - THDoubleTensor *offset, THDoubleTensor *mask, - THDoubleTensor *output, THDoubleTensor *columns, - THDoubleTensor *grad_input, THDoubleTensor *grad_weight, - THDoubleTensor *grad_bias, THDoubleTensor *grad_offset, - THDoubleTensor *grad_mask, THDoubleTensor *grad_output, - int kernel_h, int kernel_w, - int stride_h, int stride_w, - int pad_h, int pad_w, - int dilation_h, int dilation_w, - int deformable_group); \ No newline at end of file diff --git a/src/vision.cpp b/src/vision.cpp new file mode 100644 index 0000000..ff54233 --- /dev/null +++ b/src/vision.cpp @@ -0,0 +1,9 @@ + +#include "dcn_v2.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward"); + m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward"); + m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward"); + m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward"); +} diff --git a/test.py b/test.py index 3a8b2e4..74af217 100644 --- a/test.py +++ b/test.py @@ -8,16 +8,15 @@ import torch.nn as nn from torch.autograd import gradcheck -from dcn_v2 import DCNv2 -from dcn_v2_func import DCNv2Function -from dcn_v2 import DCNv2Pooling -from dcn_v2_func import DCNv2PoolingFunction +from dcn_v2 import dcn_v2_conv, DCNv2, DCN +from dcn_v2 import dcn_v2_pooling, DCNv2Pooling, DCNPooling deformable_groups = 1 N, inC, inH, inW = 2, 2, 4, 4 outC = 2 kH, kW = 3, 3 + def conv_identify(weight, bias): weight.data.zero_() bias.data.zero_() @@ -29,18 +28,19 @@ def conv_identify(weight, bias): if p == q: weight.data[q, p, y, x] = 1.0 + def check_zero_offset(): conv_offset = nn.Conv2d(inC, deformable_groups * 2 * kH * kW, - kernel_size=(kH, kW), - stride=(1, 1), - padding=(1, 1), - bias=True).cuda() + kernel_size=(kH, kW), + stride=(1, 1), + padding=(1, 1), + bias=True).cuda() conv_mask = nn.Conv2d(inC, deformable_groups * 1 * kH * kW, - kernel_size=(kH, kW), - stride=(1, 1), - padding=(1, 1), - bias=True).cuda() + kernel_size=(kH, kW), + stride=(1, 1), + padding=(1, 1), + bias=True).cuda() dcn_v2 = DCNv2(inC, outC, (kH, kW), 
stride=1, padding=1, dilation=1, @@ -64,37 +64,12 @@ def check_zero_offset(): else: print('Zero offset failed') -def check_gradient_dconv_double(): - - input = torch.randn(N, inC, inH, inW, dtype=torch.float64).cuda() - input.requires_grad = True - - offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW, dtype=torch.float64).cuda() - # offset.data.zero_() - # offset.data -= 0.00001 - offset.requires_grad = True - - mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW, dtype=torch.float64).cuda() - # mask.data.zero_() - mask.requires_grad = True - mask = torch.sigmoid(mask) - - weight = torch.randn(outC, inC, kH, kW, dtype=torch.float64).cuda() - weight.requires_grad = True - - bias = torch.rand(outC, dtype=torch.float64).cuda() - bias.requires_grad = True - - func = DCNv2Function(stride=1, padding=1, dilation=1, deformable_groups=deformable_groups) - - print(gradcheck(func, (input, offset, mask, weight, bias), eps=1e-6, atol=1e-5, rtol=1e-3)) - def check_gradient_dconv(): - input = torch.randn(N, inC, inH, inW).cuda() + input = torch.rand(N, inC, inH, inW).cuda() * 0.01 input.requires_grad = True - offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() + offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() * 2 # offset.data.zero_() # offset.data -= 0.5 offset.requires_grad = True @@ -110,12 +85,18 @@ def check_gradient_dconv(): bias = torch.rand(outC).cuda() bias.requires_grad = True - func = DCNv2Function(stride=1, padding=1, dilation=1, deformable_groups=deformable_groups) + stride = 1 + padding = 1 + dilation = 1 + + print('check_gradient_dconv: ', + gradcheck(dcn_v2_conv, (input, offset, mask, weight, bias, + stride, padding, dilation, deformable_groups), + eps=1e-3, atol=1e-4, rtol=1e-2)) - print(gradcheck(func, (input, offset, mask, weight, bias), eps=1e-3, atol=1e-3, rtol=1e-2)) def check_pooling_zero_offset(): - from dcn_v2 import DCNv2Pooling + input = torch.randn(2, 16, 64, 64).cuda().zero_() input[0, :, 16:26, 16:26] = 1. input[1, :, 10:20, 20:30] = 2. 
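
The loosened `gradcheck` settings above (`eps=1e-3`, `atol=1e-4`, `rtol=1e-2`, with the input scaled by `0.01`) reflect that the rewritten check runs in float32: finite differences at the float64 defaults (`eps=1e-6`, `atol=1e-5`) used by the deleted `check_gradient_dconv_double` fail spuriously in single precision. A minimal sketch of the same pattern, assuming the compiled `_ext` backend is importable and a CUDA device is available:

```python
import torch
from torch.autograd import gradcheck

from dcn_v2 import dcn_v2_conv  # the new autograd Function's .apply

deformable_groups = 1
N, inC, inH, inW, outC, kH, kW = 2, 2, 4, 4, 2, 3, 3

# Small-magnitude float32 inputs keep finite-difference noise manageable.
input = (torch.rand(N, inC, inH, inW).cuda() * 0.01).requires_grad_(True)
offset = (torch.randn(N, deformable_groups * 2 * kH * kW, inH, inW).cuda() * 2).requires_grad_(True)
mask = torch.sigmoid(torch.rand(N, deformable_groups * kH * kW, inH, inW).cuda()).requires_grad_(True)
weight = torch.randn(outC, inC, kH, kW).cuda().requires_grad_(True)
bias = torch.rand(outC).cuda().requires_grad_(True)

# stride, padding, dilation = 1, 1, 1; tolerances widened for float32.
print(gradcheck(dcn_v2_conv,
                (input, offset, mask, weight, bias,
                 1, 1, 1, deformable_groups),
                eps=1e-3, atol=1e-4, rtol=1e-2))
```
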
@@ -128,10 +109,11 @@ def check_pooling_zero_offset(): output_dim=16, no_trans=True, group_size=1, - trans_std=0.1).cuda() + trans_std=0.0).cuda() out = pooling(input, rois, input.new()) - s = ', '.join(['%f' % out[i, :, :, :].mean().item() for i in range(rois.shape[0])]) + s = ', '.join(['%f' % out[i, :, :, :].mean().item() + for i in range(rois.shape[0])]) print(s) dpooling = DCNv2Pooling(spatial_scale=1.0 / 4, @@ -139,12 +121,14 @@ def check_pooling_zero_offset(): output_dim=16, no_trans=False, group_size=1, - trans_std=0.1).cuda() + trans_std=0.0).cuda() offset = torch.randn(20, 2, 7, 7).cuda().zero_() dout = dpooling(input, rois, offset) - s = ', '.join(['%f' % dout[i, :, :, :].mean().item() for i in range(rois.shape[0])]) + s = ', '.join(['%f' % dout[i, :, :, :].mean().item() + for i in range(rois.shape[0])]) print(s) + def check_gradient_dpooling(): input = torch.randn(2, 3, 5, 5).cuda() * 0.01 N = 4 @@ -155,22 +139,37 @@ def check_gradient_dpooling(): h = torch.rand((N, 1)).cuda().float() * 10 rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) offset = torch.randn(N, 2, 3, 3).cuda() - dpooling = DCNv2Pooling(spatial_scale=1.0 / 4, - pooled_size=3, - output_dim=3, - no_trans=False, - group_size=1, - trans_std=0.0).cuda() input.requires_grad = True offset.requires_grad = True - print('check_gradient_dpooling', gradcheck(dpooling, (input, rois, offset), eps=1e-4)) + + spatial_scale = 1.0 / 4 + pooled_size = 3 + output_dim = 3 + no_trans = 0 + group_size = 1 + trans_std = 0.0 + sample_per_part = 4 + part_size = pooled_size + + print('check_gradient_dpooling:', + gradcheck(dcn_v2_pooling, (input, rois, offset, + spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size, + part_size, + sample_per_part, + trans_std), + eps=1e-4)) def example_dconv(): - from dcn_v2 import DCN input = torch.randn(2, 64, 128, 128).cuda() # wrap all things (offset and mask) in DCN - dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda() + dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, + padding=1, deformable_groups=2).cuda() + # print(dcn.weight.shape, input.shape) output = dcn(input) targert = output.new(*output.size()) targert.data.uniform_(-0.01, 0.01) @@ -178,8 +177,8 @@ def example_dconv(): error.backward() print(output.shape) + def example_dpooling(): - from dcn_v2 import DCNv2Pooling input = torch.randn(2, 32, 64, 64).cuda() batch_inds = torch.randint(2, (20, 1)).cuda().float() x = torch.randint(256, (20, 1)).cuda().float() @@ -221,8 +220,8 @@ def example_dpooling(): e = (target_dout - dout).mean() e.backward() + def example_mdpooling(): - from dcn_v2 import DCNPooling input = torch.randn(2, 32, 64, 64).cuda() input.requires_grad = True batch_inds = torch.randint(2, (20, 1)).cuda().float() @@ -234,11 +233,12 @@ def example_mdpooling(): # mdformable pooling (V2) dpooling = DCNPooling(spatial_scale=1.0 / 4, - pooled_size=7, - output_dim=32, - no_trans=False, - group_size=1, - trans_std=0.1).cuda() + pooled_size=7, + output_dim=32, + no_trans=False, + group_size=1, + trans_std=0.1, + deform_fc_dim=1024).cuda() dout = dpooling(input, rois) target = dout.new(*dout.size()) @@ -247,6 +247,7 @@ def example_mdpooling(): error.backward() print(dout.shape) + if __name__ == '__main__': example_dconv() @@ -258,20 +259,10 @@ def example_mdpooling(): if inC == outC: check_zero_offset() - check_gradient_dpooling() - - # # gradient check - # try: - # check_gradient_double() - # except TypeError: - # print('''****** You can swith to double precision in dcn_v2_func.py by 
(un)commenting these two lines: - # ****** from _ext import dcn_v2 as _backend - # ****** from _ext import dcn_v2_double as _backend''') - # print('****** Your tensor may not be **double** type') - # print('****** Switching to **float** type') - # - # check_gradient() - # finally: - # print('****** Note: backward is not reentrant error may not be a serious problem, ' - # '****** since the max error is less than 1e-7\n' - # '****** Still looking for what trigger this problem') \ No newline at end of file + # check_gradient_dpooling() + # check_gradient_dconv() + # """ + # ****** Note: the 'backward is not reentrant' error may not be a serious problem, + # ****** since the max error is less than 1e-7. + # ****** Still looking for what triggers this problem. + # """
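
Taken together, the deleted FFI sources and the new `src/vision.cpp` reduce the Python/CUDA boundary to a single pybind11 module exposing `dcn_v2_forward`, `dcn_v2_backward`, and the two PSROI pooling entry points. A hedged end-to-end sketch of exercising that path through the high-level module, assuming the extension builds cleanly against PyTorch 1.0 with CUDA:

```python
import torch

from dcn_v2 import DCN  # produces its own offsets and masks internally

# 3x3 modulated deformable conv; with stride=1, padding=1, dilation=1 the
# spatial size is preserved: (128 + 2*1 - (1*(3-1)+1)) / 1 + 1 = 128.
dcn = DCN(64, 64, kernel_size=(3, 3), stride=1,
          padding=1, deformable_groups=2).cuda()
x = torch.randn(2, 64, 128, 128).cuda()

out = dcn(x)           # dispatches to _backend.dcn_v2_forward via pybind11
out.mean().backward()  # dispatches to _backend.dcn_v2_backward
print(out.shape)       # torch.Size([2, 64, 128, 128])
```
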