diff --git a/README.md b/README.md new file mode 100644 index 0000000..f53e989 --- /dev/null +++ b/README.md @@ -0,0 +1,20 @@ +## Deformable Convolutional Networks V2 with PyTorch +```bash + ./mask.sh # build + python test.py # run gradient check +``` +### Known issues: + +- [ ] Gradient check w.r.t. offset. +- [ ] Backward is not reentrant. + +This is an adaptation of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op). +I have run the gradient check many times with the DOUBLE type. Every tensor **except offset** passes. +However, when I set the offset to 0.5, the check passes. I'm still wondering what causes this problem. Is it due to some +non-differentiable points? + +Another issue is that the check raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7`), +so it may not be a serious problem. + +Please open an issue or PR if you have any comments. + \ No newline at end of file diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build.py b/build.py new file mode 100644 index 0000000..d858da5 --- /dev/null +++ b/build.py @@ -0,0 +1,42 @@ +import os +import torch +from torch.utils.ffi import create_extension + + +sources = ['src/dcn_v2.c'] +headers = ['src/dcn_v2.h'] +defines = [] +with_cuda = False + +extra_objects = [] +if torch.cuda.is_available(): + print('Including CUDA code.') + sources += ['src/dcn_v2_cuda.c'] + headers += ['src/dcn_v2_cuda.h'] + defines += [('WITH_CUDA', None)] + extra_objects += ['src/cuda/dcn_v2_im2col_cuda.cu.o'] + with_cuda = True +else: + raise ValueError('CUDA is not available') + +extra_compile_args = ['-fopenmp', '-std=c99'] + +this_file = os.path.dirname(os.path.realpath(__file__)) +print(this_file) +sources = [os.path.join(this_file, fname) for fname in sources] +headers = [os.path.join(this_file, fname) for fname in headers] +extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] + +ffi = create_extension( + '_ext.dcn_v2', + headers=headers, + sources=sources, + define_macros=defines, + relative_to=__file__, + with_cuda=with_cuda, + extra_objects=extra_objects, + extra_compile_args=extra_compile_args +) + +if __name__ == '__main__': + ffi.build() diff --git a/build_double.py b/build_double.py new file mode 100644 index 0000000..a707f44 --- /dev/null +++ b/build_double.py @@ -0,0 +1,42 @@ +import os +import torch +from torch.utils.ffi import create_extension + + +sources = ['src/dcn_v2_double.c'] +headers = ['src/dcn_v2_double.h'] +defines = [] +with_cuda = False + +extra_objects = [] +if torch.cuda.is_available(): + print('Including CUDA code.') + sources += ['src/dcn_v2_cuda_double.c'] + headers += ['src/dcn_v2_cuda_double.h'] + defines += [('WITH_CUDA', None)] + extra_objects += ['src/cuda/dcn_v2_im2col_cuda_double.cu.o'] + with_cuda = True +else: + raise ValueError('CUDA is not available') + +extra_compile_args = ['-fopenmp', '-std=c99'] + +this_file = os.path.dirname(os.path.realpath(__file__)) +print(this_file) +sources = [os.path.join(this_file, fname) for fname in sources] +headers = [os.path.join(this_file, fname) for fname in headers] +extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] + +ffi = create_extension( + '_ext.dcn_v2_double', + headers=headers, + sources=sources, + define_macros=defines, + relative_to=__file__, + with_cuda=with_cuda, + extra_objects=extra_objects, + extra_compile_args=extra_compile_args +) + +if __name__ == '__main__': + ffi.build() diff --git a/dcn_v2.py
b/dcn_v2.py new file mode 100644 index 0000000..11fcf51 --- /dev/null +++ b/dcn_v2.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import math +from torch import nn +from torch.nn.modules.utils import _pair + +from dcn_v2_func import DCNv2Function + +class DCNv2(nn.Module): + + def __init__(self, in_channels, out_channels, + kernel_size, stride, padding, dilation=1, deformable_groups=1): + super(DCNv2, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.deformable_groups = deformable_groups + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size)) + self.bias = nn.Parameter(torch.Tensor(out_channels)) + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + self.bias.data.zero_() + + def forward(self, input, offset, mask): + func = DCNv2Function(self.stride, self.padding, self.dilation, self.deformable_groups) + return func(input, offset, mask, self.weight, self.bias) + + +class DCN(DCNv2): + + def __init__(self, in_channels, out_channels, + kernel_size, stride, padding, + dilation=1, deformable_groups=1): + super(DCN, self).__init__(in_channels, out_channels, + kernel_size, stride, padding, dilation, deformable_groups) + + self.conv_offset_mask = nn.Conv2d(self.in_channels, + self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=(self.stride, self.stride), + padding=(self.padding, self.padding), + bias=True) + self.reset_parameters() + + def reset_parameters(self): + self.conv_offset_mask.weight.data.zero_() + self.conv_offset_mask.bias.data.zero_() + + def forward(self, input): + out = self.conv_offset_mask(input) + dy, dx, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((dy, dx), dim=1) + func = DCNv2Function(self.stride, self.padding, self.dilation, self.deformable_groups) + return func(input, offset, mask, self.weight, self.bias) \ No newline at end of file diff --git a/dcn_v2_func.py b/dcn_v2_func.py new file mode 100644 index 0000000..f522dd6 --- /dev/null +++ b/dcn_v2_func.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +from torch.autograd import Function + +from _ext import dcn_v2 as _backend +# from _ext import dcn_v2_double as _backend + +class DCNv2Function(Function): + + def __init__(self, stride, padding, dilation=1, deformable_groups=1): + super(DCNv2Function, self).__init__() + self.stride = stride + self.padding = padding + self.dilation = dilation + self.deformable_groups = deformable_groups + + def forward(self, input, offset, mask, weight, bias): + if not input.is_cuda: + raise NotImplementedError + if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad: + self.save_for_backward(input, offset, mask, weight, bias) + output = input.new(*self._infer_shape(input, weight)) + self._bufs = [input.new(), input.new()] + _backend.dcn_v2_cuda_forward(input, weight, + bias, self._bufs[0], + offset, mask, + output, self._bufs[1], + weight.shape[2], weight.shape[3], + self.stride, self.stride, + self.padding, self.padding, + 
self.dilation, self.dilation, + self.deformable_groups) + return output + + def backward(self, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = self.saved_tensors + grad_input = input.new(*input.size()).zero_() + grad_offset = offset.new(*offset.size()).zero_() + grad_mask = mask.new(*mask.size()).zero_() + grad_weight = weight.new(*weight.size()).zero_() + grad_bias = bias.new(*bias.size()).zero_() + _backend.dcn_v2_cuda_backward(input, weight, + bias, self._bufs[0], + offset, mask, + self._bufs[1], + grad_input, grad_weight, + grad_bias, grad_offset, + grad_mask, grad_output, + weight.shape[2], weight.shape[3], + self.stride, self.stride, + self.padding, self.padding, + self.dilation, self.dilation, + self.deformable_groups) + + + return grad_input, grad_offset, grad_mask, grad_weight, grad_bias + + def _infer_shape(self, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = (height + 2 * self.padding - (self.dilation * (kernel_h - 1) + 1)) // self.stride + 1 + width_out = (width + 2 * self.padding - (self.dilation * (kernel_w - 1) + 1)) // self.stride + 1 + return (n, channels_out, height_out, width_out) + diff --git a/mask.sh b/mask.sh new file mode 100755 index 0000000..1f7ca1f --- /dev/null +++ b/mask.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +cd src/cuda +nvcc -c -o dcn_v2_im2col_cuda.cu.o dcn_v2_im2col_cuda.cu -x cu -Xcompiler -fPIC +nvcc -c -o dcn_v2_im2col_cuda_double.cu.o dcn_v2_im2col_cuda_double.cu -x cu -Xcompiler -fPIC +cd - +python build.py +python build_double.py diff --git a/src/cuda/dcn_v2_im2col_cuda.cu b/src/cuda/dcn_v2_im2col_cuda.cu new file mode 100644 index 0000000..ab22b1b --- /dev/null +++ b/src/cuda/dcn_v2_im2col_cuda.cu @@ -0,0 +1,387 @@ +#include "dcn_v2_im2col_cuda.h" +#include +#include +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N) +{ + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + + +__device__ float dmcn_im2col_bilinear(const float *bottom_data, const int data_width, + const int height, const int width, float h, float w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + float lh = h - h_low; + float lw = w - w_low; + float hh = 1 - lh, hw = 1 - lw; + + float v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + float v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + float v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + float v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +__device__ float dmcn_get_gradient_weight(float argmax_h, float argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + float weight = 0; + if (h == argmax_h_low 
&& w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +__device__ float dmcn_get_coordinate_weight(float argmax_h, float argmax_w, + const int height, const int width, const float *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + float weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const float *data_im, const float *data_offset, const float *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + float *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const 
float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const float offset_h = data_offset_ptr[data_offset_h_ptr]; + const float offset_w = data_offset_ptr[data_offset_w_ptr]; + const float mask = data_mask_ptr[data_mask_hw_ptr]; + float val = static_cast(0); + const float h_im = h_in + i * dilation_h + offset_h; + const float w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const float *data_col, const float *data_offset, const float *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + float *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const float offset_h = data_offset_ptr[data_offset_h_ptr]; + const float offset_w = data_offset_ptr[data_offset_w_ptr]; + const float mask = data_mask_ptr[data_mask_hw_ptr]; + const float cur_inv_h_data = 
h_in + i * dilation_h + offset_h; + const float cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const float cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + float weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const float *data_col, const float *data_im, + const float *data_offset, const float *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + float *grad_offset, float *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + float val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const float offset_h = data_offset_ptr[data_offset_h_ptr]; + const float offset_w = data_offset_ptr[data_offset_w_ptr]; + const float mask = data_mask_ptr[data_mask_hw_ptr]; + float inv_h = h_in + i * dilation_h + offset_h; + 
float inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const float weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda(cudaStream_t stream, + const float* data_im, const float* data_offset, const float* data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, float* data_col) { + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + modulated_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +void modulated_deformable_col2im_cuda(cudaStream_t stream, + const float* data_col, const float* data_offset, const float* data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, float* grad_im){ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + modulated_deformable_col2im_gpu_kernel + <<>>( + num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, + const float* data_col, const float* data_im, const float* data_offset, const float* data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, 
const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + float* grad_offset, float* grad_mask) { + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + modulated_deformable_col2im_coord_gpu_kernel + <<>>( + num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset, grad_mask); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} \ No newline at end of file diff --git a/src/cuda/dcn_v2_im2col_cuda.cu.o b/src/cuda/dcn_v2_im2col_cuda.cu.o new file mode 100644 index 0000000..b98618e Binary files /dev/null and b/src/cuda/dcn_v2_im2col_cuda.cu.o differ diff --git a/src/cuda/dcn_v2_im2col_cuda.h b/src/cuda/dcn_v2_im2col_cuda.h new file mode 100644 index 0000000..3457e96 --- /dev/null +++ b/src/cuda/dcn_v2_im2col_cuda.h @@ -0,0 +1,100 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.h + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1811.11168 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu + */ + +/***************** Adapted by Charles Shang *********************/ + +#ifndef DCN_V2_IM2COL_CUDA +#define DCN_V2_IM2COL_CUDA + +#ifdef __cplusplus +extern "C" +{ +#endif + + void modulated_deformable_im2col_cuda(cudaStream_t stream, + const float *data_im, const float *data_offset, const float *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, float *data_col); + + void modulated_deformable_col2im_cuda(cudaStream_t stream, + const float *data_col, const float *data_offset, const float *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, float *grad_im); + + void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, + const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + float *grad_offset, float *grad_mask); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/src/cuda/dcn_v2_im2col_cuda_double.cu b/src/cuda/dcn_v2_im2col_cuda_double.cu new file mode 100644 index 0000000..29cb048 --- /dev/null +++ b/src/cuda/dcn_v2_im2col_cuda_double.cu @@ -0,0 +1,399 @@ +#include "dcn_v2_im2col_cuda_double.h" +#include +#include +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 512; +inline int GET_BLOCKS(const int N) +{ + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 +#else +__device__ double atomicAdd(double* address, double val) +{ + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong(val + __longlong_as_double(assumed))); + } while (assumed != old); + return __longlong_as_double(old); +} +#endif + +__device__ double dmcn_im2col_bilinear(const double 
*bottom_data, const int data_width, + const int height, const int width, double h, double w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + double lh = h - h_low; + double lw = w - w_low; + double hh = 1 - lh, hw = 1 - lw; + + double v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + double v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + double v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + double v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + double w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + double val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +__device__ double dmcn_get_gradient_weight(double argmax_h, double argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + double weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +__device__ double dmcn_get_coordinate_weight(double argmax_h, double argmax_w, + const int height, const int width, const double *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + double weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +__global__ void 
modulated_deformable_im2col_gpu_kernel(const int n, + const double *data_im, const double *data_offset, const double *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + double *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + double *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const double* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const double *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const double *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const double *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const double offset_h = data_offset_ptr[data_offset_h_ptr]; + const double offset_w = data_offset_ptr[data_offset_w_ptr]; + const double mask = data_mask_ptr[data_mask_hw_ptr]; + double val = static_cast(0); + const double h_im = h_in + i * dilation_h + offset_h; + const double w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const double map_h = i * dilation_h + offset_h; + //const double map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const double *data_col, const double *data_offset, const double *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + double *grad_im) +{ + 
CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const double *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const double *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const double offset_h = data_offset_ptr[data_offset_h_ptr]; + const double offset_w = data_offset_ptr[data_offset_w_ptr]; + const double mask = data_mask_ptr[data_mask_hw_ptr]; + const double cur_inv_h_data = h_in + i * dilation_h + offset_h; + const double cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const double cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + double weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const double *data_col, const double *data_im, + const double *data_offset, const double *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + double *grad_offset, double *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + double val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const double *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const double *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const 
double *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const double *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const double offset_h = data_offset_ptr[data_offset_h_ptr]; + const double offset_w = data_offset_ptr[data_offset_w_ptr]; + const double mask = data_mask_ptr[data_mask_hw_ptr]; + double inv_h = h_in + i * dilation_h + offset_h; + double inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const double weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda(cudaStream_t stream, + const double *data_im, const double *data_offset, const double *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, double *data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + modulated_deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda(cudaStream_t 
stream, + const double *data_col, const double *data_offset, const double *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, double *grad_im) +{ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + modulated_deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, + const double *data_col, const double *data_im, const double *data_offset, const double *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + double *grad_offset, double *grad_mask) +{ + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + modulated_deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset, grad_mask); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} \ No newline at end of file diff --git a/src/cuda/dcn_v2_im2col_cuda_double.cu.o b/src/cuda/dcn_v2_im2col_cuda_double.cu.o new file mode 100644 index 0000000..e5c3479 Binary files /dev/null and b/src/cuda/dcn_v2_im2col_cuda_double.cu.o differ diff --git a/src/cuda/dcn_v2_im2col_cuda_double.h b/src/cuda/dcn_v2_im2col_cuda_double.h new file mode 100644 index 0000000..a461692 --- /dev/null +++ b/src/cuda/dcn_v2_im2col_cuda_double.h @@ -0,0 +1,100 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. 
If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.h + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. 
+ * \ref: https://arxiv.org/abs/1811.11168 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu + */ + +/***************** Adapted by Charles Shang *********************/ + +#ifndef DCN_V2_IM2COL_CUDA_DOUBLE +#define DCN_V2_IM2COL_CUDA_DOUBLE + +#ifdef __cplusplus +extern "C" +{ +#endif + + void modulated_deformable_im2col_cuda(cudaStream_t stream, + const double *data_im, const double *data_offset, const double *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, double *data_col); + + void modulated_deformable_col2im_cuda(cudaStream_t stream, + const double *data_col, const double *data_offset, const double *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, double *grad_im); + + void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, + const double *data_col, const double *data_im, const double *data_offset, const double *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + double *grad_offset, double *grad_mask); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/src/dcn_v2.c b/src/dcn_v2.c new file mode 100644 index 0000000..b440d3f --- /dev/null +++ b/src/dcn_v2.c @@ -0,0 +1,30 @@ +#include +#include +#include + +void dcn_v2_forward(THFloatTensor *input, THFloatTensor *weight, + THFloatTensor *bias, THFloatTensor *ones, + THFloatTensor *offset, THFloatTensor *mask, + THFloatTensor *output, THFloatTensor *columns, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group) +{ + printf("only implemented in GPU"); +} + void dcn_v2_backward(THFloatTensor *input, THFloatTensor *weight, + THFloatTensor *bias, THFloatTensor *ones, + THFloatTensor *offset, THFloatTensor *mask, + THFloatTensor *output, THFloatTensor *columns, + THFloatTensor *grad_input, THFloatTensor *grad_weight, + THFloatTensor *grad_bias, THFloatTensor *grad_offset, + THFloatTensor *grad_mask, THFloatTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group) +{ + printf("only implemented in GPU"); +} \ No newline at end of file diff --git a/src/dcn_v2.h b/src/dcn_v2.h new file mode 100644 index 0000000..1a97ff0 --- /dev/null +++ b/src/dcn_v2.h @@ -0,0 +1,20 @@ +void dcn_v2_forward(THFloatTensor *input, THFloatTensor *weight, + THFloatTensor *bias, THFloatTensor *ones, + THFloatTensor *offset, THFloatTensor *mask, + THFloatTensor *output, THFloatTensor *columns, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group); +void 
dcn_v2_backward(THFloatTensor *input, THFloatTensor *weight, + THFloatTensor *bias, THFloatTensor *ones, + THFloatTensor *offset, THFloatTensor *mask, + THFloatTensor *output, THFloatTensor *columns, + THFloatTensor *grad_input, THFloatTensor *grad_weight, + THFloatTensor *grad_bias, THFloatTensor *grad_offset, + THFloatTensor *grad_mask, THFloatTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group); \ No newline at end of file diff --git a/src/dcn_v2_cuda.c b/src/dcn_v2_cuda.c new file mode 100644 index 0000000..3ef9c4a --- /dev/null +++ b/src/dcn_v2_cuda.c @@ -0,0 +1,240 @@ +#include +#include "cuda/dcn_v2_im2col_cuda.h" + +extern THCState *state; + +// author: Charles Shang +// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu + +void dcn_v2_cuda_forward(THCudaTensor *input, THCudaTensor *weight, + THCudaTensor *bias, THCudaTensor *ones, + THCudaTensor *offset, THCudaTensor *mask, + THCudaTensor *output, THCudaTensor *columns, + int kernel_h, int kernel_w, + const int stride_h, const int stride_w, + const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int deformable_group) +{ + THCAssertSameGPU(THCudaTensor_checkGPU(state, 8, input, weight, bias, ones, offset, mask, output, columns)); + THArgCheck(THCudaTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); + THArgCheck(THCudaTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous"); + + const int batch = THCudaTensor_size(state, input, 0); + const int channels = THCudaTensor_size(state, input, 1); + const int height = THCudaTensor_size(state, input, 2); + const int width = THCudaTensor_size(state, input, 3); + + const int channels_out = THCudaTensor_size(state, weight, 0); + const int channels_kernel = THCudaTensor_size(state, weight, 1); + const int kernel_h_ = THCudaTensor_size(state, weight, 2); + const int kernel_w_ = THCudaTensor_size(state, weight, 3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + THError("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel) + THError("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel); + + const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (THCudaTensor__nDimension(state, ones) != 2 || + THCudaTensor_size(state, ones, 0) * THCudaTensor_size(state, ones, 1) < height_out * width_out) + { + // Resize plane and fill with ones... 
+ THCudaTensor_resize2d(state, ones, height_out, width_out); + THCudaTensor_fill(state, ones, 1); + } + + // resize output + THCudaTensor_resize4d(state, output, batch, channels_out, height_out, width_out); + // resize temporary columns + THCudaTensor_resize2d(state, columns, channels * kernel_h * kernel_w, 1 * height_out * width_out); + + THCudaTensor *input_n = THCudaTensor_new(state); + THCudaTensor *offset_n = THCudaTensor_new(state); + THCudaTensor *mask_n = THCudaTensor_new(state); + THCudaTensor *output_n = THCudaTensor_new(state); + + for (int b = 0; b < batch; b++) + { + THCudaTensor_select(state, input_n, input, 0, b); + THCudaTensor_select(state, offset_n, offset, 0, b); + THCudaTensor_select(state, mask_n, mask, 0, b); + THCudaTensor_select(state, output_n, output, 0, b); + + // Do Bias first: + // M,N,K are dims of matrix A and B + // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) + // (N x 1) (1 x M) + long m_ = channels_out; + long n_ = height_out * width_out; + long k_ = 1; + THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, + THCudaTensor_data(state, ones), k_, + THCudaTensor_data(state, bias), k_, 0.0f, + THCudaTensor_data(state, output_n), n_); + + modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), + THCudaTensor_data(state, input_n), THCudaTensor_data(state, offset_n), + THCudaTensor_data(state, mask_n), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + deformable_group, THCudaTensor_data(state, columns)); + + //(k * m) x (m * n) + // Y = WC + long m = channels_out; + long n = height_out * width_out; + long k = channels * kernel_h * kernel_w; + THCudaBlas_Sgemm(state, 'n', 'n', n, m, k, 1.0f, + THCudaTensor_data(state, columns), n, + THCudaTensor_data(state, weight), k, 1.0f, + THCudaTensor_data(state, output_n), n); + } + THCudaTensor_free(state, input_n); + THCudaTensor_free(state, offset_n); + THCudaTensor_free(state, mask_n); + THCudaTensor_free(state, output_n); +} + +void dcn_v2_cuda_backward(THCudaTensor *input, THCudaTensor *weight, + THCudaTensor *bias, THCudaTensor *ones, + THCudaTensor *offset, THCudaTensor *mask, + THCudaTensor *columns, + THCudaTensor *grad_input, THCudaTensor *grad_weight, + THCudaTensor *grad_bias, THCudaTensor *grad_offset, + THCudaTensor *grad_mask, THCudaTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group) +{ + THCAssertSameGPU(THCudaTensor_checkGPU(state, 13, input, weight, bias, ones, offset, mask, columns, + grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output)); + THArgCheck(THCudaTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); + THArgCheck(THCudaTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous"); + + const int batch = THCudaTensor_size(state, input, 0); + const int channels = THCudaTensor_size(state, input, 1); + const int height = THCudaTensor_size(state, input, 2); + const int width = THCudaTensor_size(state, input, 3); + + const int channels_out = THCudaTensor_size(state, weight, 0); + const int channels_kernel = THCudaTensor_size(state, weight, 1); + const int kernel_h_ = THCudaTensor_size(state, weight, 2); + const int kernel_w_ = THCudaTensor_size(state, weight, 3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + THError("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, 
kernel_w_); + if (channels != channels_kernel) + THError("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel); + + const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (THCudaTensor__nDimension(state, ones) != 2 || + THCudaTensor_size(state, ones, 0) * THCudaTensor_size(state, ones, 1) < height_out * width_out) + { + // Resize plane and fill with ones... + THCudaTensor_resize2d(state, ones, height_out, width_out); + THCudaTensor_fill(state, ones, 1); + } + + THCudaTensor_resize4d(state, grad_input, batch, channels, height, width); + THCudaTensor_resize2d(state, columns, channels * kernel_h * kernel_w, height_out * width_out); + + THCudaTensor *input_n = THCudaTensor_new(state); + THCudaTensor *offset_n = THCudaTensor_new(state); + THCudaTensor *mask_n = THCudaTensor_new(state); + + THCudaTensor *grad_output_n = THCudaTensor_new(state); + THCudaTensor *grad_input_n = THCudaTensor_new(state); + THCudaTensor *grad_offset_n = THCudaTensor_new(state); + THCudaTensor *grad_mask_n = THCudaTensor_new(state); + + for (int b = 0; b < batch; b++) + { + THCudaTensor_select(state, input_n, input, 0, b); + THCudaTensor_select(state, offset_n, offset, 0, b); + THCudaTensor_select(state, mask_n, mask, 0, b); + THCudaTensor_select(state, grad_output_n, grad_output, 0, b); + THCudaTensor_select(state, grad_input_n, grad_input, 0, b); + THCudaTensor_select(state, grad_offset_n, grad_offset, 0, b); + THCudaTensor_select(state, grad_mask_n, grad_mask, 0, b); + + long m = channels * kernel_h * kernel_w; + long n = height_out * width_out; + long k = channels_out; + + THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f, + THCudaTensor_data(state, grad_output_n), n, + THCudaTensor_data(state, weight), m, 0.0f, + THCudaTensor_data(state, columns), n); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state), + THCudaTensor_data(state, columns), + THCudaTensor_data(state, input_n), + THCudaTensor_data(state, offset_n), + THCudaTensor_data(state, mask_n), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, + THCudaTensor_data(state, grad_offset_n), + THCudaTensor_data(state, grad_mask_n)); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda(THCState_getCurrentStream(state), + THCudaTensor_data(state, columns), + THCudaTensor_data(state, offset_n), + THCudaTensor_data(state, mask_n), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, + THCudaTensor_data(state, grad_input_n)); + + // gradient w.r.t. 
weight, dWeight should accumulate across the batch and group + modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), + THCudaTensor_data(state, input_n), + THCudaTensor_data(state, offset_n), + THCudaTensor_data(state, mask_n), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, + THCudaTensor_data(state, columns)); + long m_ = channels_out; + long n_ = channels * kernel_h * kernel_w; + long k_ = height_out * width_out; + + THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, + THCudaTensor_data(state, columns), k_, + THCudaTensor_data(state, grad_output_n), k_, 1.0f, + THCudaTensor_data(state, grad_weight), n_); + + // gradient w.r.t. bias + // long m_ = channels_out; + // long k__ = height_out * width_out; + THCudaBlas_Sgemv(state, + 't', + k_, m_, 1.0f, + THCudaTensor_data(state, grad_output_n), k_, + THCudaTensor_data(state, ones), 1, 1.0f, + THCudaTensor_data(state, grad_bias), 1); + } + + THCudaTensor_free(state, input_n); + THCudaTensor_free(state, offset_n); + THCudaTensor_free(state, mask_n); + + THCudaTensor_free(state, grad_output_n); + THCudaTensor_free(state, grad_input_n); + THCudaTensor_free(state, grad_offset_n); + THCudaTensor_free(state, grad_mask_n); +} \ No newline at end of file diff --git a/src/dcn_v2_cuda.h b/src/dcn_v2_cuda.h new file mode 100644 index 0000000..2c75dc6 --- /dev/null +++ b/src/dcn_v2_cuda.h @@ -0,0 +1,35 @@ +// #ifndef DCN_V2_CUDA +// #define DCN_V2_CUDA + +// #ifdef __cplusplus +// extern "C" +// { +// #endif + +void dcn_v2_cuda_forward(THCudaTensor *input, THCudaTensor *weight, + THCudaTensor *bias, THCudaTensor *ones, + THCudaTensor *offset, THCudaTensor *mask, + THCudaTensor *output, THCudaTensor *columns, + int kernel_h, int kernel_w, + const int stride_h, const int stride_w, + const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int deformable_group); +void dcn_v2_cuda_backward(THCudaTensor *input, THCudaTensor *weight, + THCudaTensor *bias, THCudaTensor *ones, + THCudaTensor *offset, THCudaTensor *mask, + THCudaTensor *columns, + THCudaTensor *grad_input, THCudaTensor *grad_weight, + THCudaTensor *grad_bias, THCudaTensor *grad_offset, + THCudaTensor *grad_mask, THCudaTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group); + +// #ifdef __cplusplus +// } +// #endif + +// #endif \ No newline at end of file diff --git a/src/dcn_v2_cuda_double.c b/src/dcn_v2_cuda_double.c new file mode 100644 index 0000000..1ea0821 --- /dev/null +++ b/src/dcn_v2_cuda_double.c @@ -0,0 +1,262 @@ +#include +#include "cuda/dcn_v2_im2col_cuda_double.h" + +extern THCState *state; + +// author: Charles Shang +// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu + +void dcn_v2_cuda_forward(THCudaDoubleTensor *input, THCudaDoubleTensor *weight, + THCudaDoubleTensor *bias, THCudaDoubleTensor *ones, + THCudaDoubleTensor *offset, THCudaDoubleTensor *mask, + THCudaDoubleTensor *output, THCudaDoubleTensor *columns, + int kernel_h, int kernel_w, + const int stride_h, const int stride_w, + const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int deformable_group) +{ + THCAssertSameGPU(THCudaDoubleTensor_checkGPU(state, 8, input, weight, bias, ones, offset, mask, output, columns)); + THArgCheck(THCudaDoubleTensor_isContiguous(state, input), 1, "input tensor 
has to be contiguous"); + THArgCheck(THCudaDoubleTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous"); + + input = THCudaDoubleTensor_newContiguous(state, input); + offset = THCudaDoubleTensor_newContiguous(state, offset); + mask = THCudaDoubleTensor_newContiguous(state, mask); + weight = THCudaDoubleTensor_newContiguous(state, weight); + + const int batch = THCudaDoubleTensor_size(state, input, 0); + const int channels = THCudaDoubleTensor_size(state, input, 1); + const int height = THCudaDoubleTensor_size(state, input, 2); + const int width = THCudaDoubleTensor_size(state, input, 3); + + const int channels_out = THCudaDoubleTensor_size(state, weight, 0); + const int channels_kernel = THCudaDoubleTensor_size(state, weight, 1); + const int kernel_h_ = THCudaDoubleTensor_size(state, weight, 2); + const int kernel_w_ = THCudaDoubleTensor_size(state, weight, 3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + THError("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel) + THError("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel); + + const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (THCudaDoubleTensor__nDimension(state, ones) != 2 || + THCudaDoubleTensor_size(state, ones, 0) * THCudaDoubleTensor_size(state, ones, 1) < height_out * width_out) + { + // Resize plane and fill with ones... + THCudaDoubleTensor_resize2d(state, ones, height_out, width_out); + THCudaDoubleTensor_fill(state, ones, 1); + } + + // resize output + THCudaDoubleTensor_resize4d(state, output, batch, channels_out, height_out, width_out); + // resize temporary columns + THCudaDoubleTensor_resize2d(state, columns, channels * kernel_h * kernel_w, 1 * height_out * width_out); + + THCudaDoubleTensor *input_n = THCudaDoubleTensor_new(state); + THCudaDoubleTensor *offset_n = THCudaDoubleTensor_new(state); + THCudaDoubleTensor *mask_n = THCudaDoubleTensor_new(state); + THCudaDoubleTensor *output_n = THCudaDoubleTensor_new(state); + + for (int b = 0; b < batch; b++) + { + THCudaDoubleTensor_select(state, input_n, input, 0, b); + THCudaDoubleTensor_select(state, offset_n, offset, 0, b); + THCudaDoubleTensor_select(state, mask_n, mask, 0, b); + THCudaDoubleTensor_select(state, output_n, output, 0, b); + + // Do Bias first: + // M,N,K are dims of matrix A and B + // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) + // (N x 1) (1 x M) + long m_ = channels_out; + long n_ = height_out * width_out; + long k_ = 1; + THCudaBlas_Dgemm(state, 't', 'n', n_, m_, k_, 1.0, + THCudaDoubleTensor_data(state, ones), k_, + THCudaDoubleTensor_data(state, bias), k_, 0.0, + THCudaDoubleTensor_data(state, output_n), n_); + + modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), + THCudaDoubleTensor_data(state, input_n), THCudaDoubleTensor_data(state, offset_n), + THCudaDoubleTensor_data(state, mask_n), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + deformable_group, THCudaDoubleTensor_data(state, columns)); + + //(k * m) x (m * n) + // Y = WC + long m = channels_out; + long n = height_out * width_out; + long k = channels * kernel_h * kernel_w; + THCudaBlas_Dgemm(state, 'n', 'n', n, m, k, 1.0f, + THCudaDoubleTensor_data(state, columns), 
n, + THCudaDoubleTensor_data(state, weight), k, 1.0f, + THCudaDoubleTensor_data(state, output_n), n); + } + THCudaDoubleTensor_free(state, input_n); + THCudaDoubleTensor_free(state, offset_n); + THCudaDoubleTensor_free(state, mask_n); + THCudaDoubleTensor_free(state, output_n); + + THCudaDoubleTensor_free(state, input); + THCudaDoubleTensor_free(state, offset); + THCudaDoubleTensor_free(state, mask); + THCudaDoubleTensor_free(state, weight); +} + +void dcn_v2_cuda_backward(THCudaDoubleTensor *input, THCudaDoubleTensor *weight, + THCudaDoubleTensor *bias, THCudaDoubleTensor *ones, + THCudaDoubleTensor *offset, THCudaDoubleTensor *mask, + THCudaDoubleTensor *columns, + THCudaDoubleTensor *grad_input, THCudaDoubleTensor *grad_weight, + THCudaDoubleTensor *grad_bias, THCudaDoubleTensor *grad_offset, + THCudaDoubleTensor *grad_mask, THCudaDoubleTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group) +{ + THCAssertSameGPU(THCudaDoubleTensor_checkGPU(state, 13, input, weight, bias, ones, offset, mask, columns, + grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output)); + THArgCheck(THCudaDoubleTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); + THArgCheck(THCudaDoubleTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous"); + + input = THCudaDoubleTensor_newContiguous(state, input); + offset = THCudaDoubleTensor_newContiguous(state, offset); + mask = THCudaDoubleTensor_newContiguous(state, mask); + weight = THCudaDoubleTensor_newContiguous(state, weight); + grad_output = THCudaDoubleTensor_newContiguous(state, grad_output); + + const int batch = THCudaDoubleTensor_size(state, input, 0); + const int channels = THCudaDoubleTensor_size(state, input, 1); + const int height = THCudaDoubleTensor_size(state, input, 2); + const int width = THCudaDoubleTensor_size(state, input, 3); + + const int channels_out = THCudaDoubleTensor_size(state, weight, 0); + const int channels_kernel = THCudaDoubleTensor_size(state, weight, 1); + const int kernel_h_ = THCudaDoubleTensor_size(state, weight, 2); + const int kernel_w_ = THCudaDoubleTensor_size(state, weight, 3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + THError("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel) + THError("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel); + + const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (THCudaDoubleTensor__nDimension(state, ones) != 2 || + THCudaDoubleTensor_size(state, ones, 0) * THCudaDoubleTensor_size(state, ones, 1) < height_out * width_out) + { + // Resize plane and fill with ones... 
+ THCudaDoubleTensor_resize2d(state, ones, height_out, width_out); + THCudaDoubleTensor_fill(state, ones, 1); + } + + // THCudaDoubleTensor_resize4d(state, grad_input, batch, channels, height, width); + THCudaDoubleTensor_resize2d(state, columns, channels * kernel_h * kernel_w, height_out * width_out); + + THCudaDoubleTensor *input_n = THCudaDoubleTensor_new(state); + THCudaDoubleTensor *offset_n = THCudaDoubleTensor_new(state); + THCudaDoubleTensor *mask_n = THCudaDoubleTensor_new(state); + + THCudaDoubleTensor *grad_output_n = THCudaDoubleTensor_new(state); + THCudaDoubleTensor *grad_input_n = THCudaDoubleTensor_new(state); + THCudaDoubleTensor *grad_offset_n = THCudaDoubleTensor_new(state); + THCudaDoubleTensor *grad_mask_n = THCudaDoubleTensor_new(state); + + for (int b = 0; b < batch; b++) + { + THCudaDoubleTensor_select(state, input_n, input, 0, b); + THCudaDoubleTensor_select(state, offset_n, offset, 0, b); + THCudaDoubleTensor_select(state, mask_n, mask, 0, b); + THCudaDoubleTensor_select(state, grad_output_n, grad_output, 0, b); + THCudaDoubleTensor_select(state, grad_input_n, grad_input, 0, b); + THCudaDoubleTensor_select(state, grad_offset_n, grad_offset, 0, b); + THCudaDoubleTensor_select(state, grad_mask_n, grad_mask, 0, b); + + long m = channels * kernel_h * kernel_w; + long n = height_out * width_out; + long k = channels_out; + + THCudaBlas_Dgemm(state, 'n', 't', n, m, k, 1.0, + THCudaDoubleTensor_data(state, grad_output_n), n, + THCudaDoubleTensor_data(state, weight), m, 0.0, + THCudaDoubleTensor_data(state, columns), n); + + // gradient w.r.t. input offset and mask data + modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state), + THCudaDoubleTensor_data(state, columns), + THCudaDoubleTensor_data(state, input_n), + THCudaDoubleTensor_data(state, offset_n), + THCudaDoubleTensor_data(state, mask_n), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, + THCudaDoubleTensor_data(state, grad_offset_n), + THCudaDoubleTensor_data(state, grad_mask_n)); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda(THCState_getCurrentStream(state), + THCudaDoubleTensor_data(state, columns), + THCudaDoubleTensor_data(state, offset_n), + THCudaDoubleTensor_data(state, mask_n), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, + THCudaDoubleTensor_data(state, grad_input_n)); + + // gradient w.r.t. weight, dWeight should accumulate across the batch and group + modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), + THCudaDoubleTensor_data(state, input_n), + THCudaDoubleTensor_data(state, offset_n), + THCudaDoubleTensor_data(state, mask_n), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, + THCudaDoubleTensor_data(state, columns)); + long m_ = channels_out; + long n_ = channels * kernel_h * kernel_w; + long k_ = height_out * width_out; + + THCudaBlas_Dgemm(state, 't', 'n', n_, m_, k_, 1.0, + THCudaDoubleTensor_data(state, columns), k_, + THCudaDoubleTensor_data(state, grad_output_n), k_, 1.0, + THCudaDoubleTensor_data(state, grad_weight), n_); + + // gradient w.r.t. 
bias + // long m_ = channels_out; + // long k__ = height_out * width_out; + THCudaBlas_Dgemv(state, + 't', + k_, m_, 1.0, + THCudaDoubleTensor_data(state, grad_output_n), k_, + THCudaDoubleTensor_data(state, ones), 1, 1.0, + THCudaDoubleTensor_data(state, grad_bias), 1); + } + + THCudaDoubleTensor_free(state, input_n); + THCudaDoubleTensor_free(state, offset_n); + THCudaDoubleTensor_free(state, mask_n); + + THCudaDoubleTensor_free(state, grad_output_n); + THCudaDoubleTensor_free(state, grad_input_n); + THCudaDoubleTensor_free(state, grad_offset_n); + THCudaDoubleTensor_free(state, grad_mask_n); + + THCudaDoubleTensor_free(state, input); + THCudaDoubleTensor_free(state, offset); + THCudaDoubleTensor_free(state, mask); + THCudaDoubleTensor_free(state, weight); + THCudaDoubleTensor_free(state, grad_output); +} \ No newline at end of file diff --git a/src/dcn_v2_cuda_double.h b/src/dcn_v2_cuda_double.h new file mode 100644 index 0000000..18eb45b --- /dev/null +++ b/src/dcn_v2_cuda_double.h @@ -0,0 +1,35 @@ +// #ifndef DCN_V2_CUDA +// #define DCN_V2_CUDA + +// #ifdef __cplusplus +// extern "C" +// { +// #endif + +void dcn_v2_cuda_forward(THCudaDoubleTensor *input, THCudaDoubleTensor *weight, + THCudaDoubleTensor *bias, THCudaDoubleTensor *ones, + THCudaDoubleTensor *offset, THCudaDoubleTensor *mask, + THCudaDoubleTensor *output, THCudaDoubleTensor *columns, + int kernel_h, int kernel_w, + const int stride_h, const int stride_w, + const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int deformable_group); +void dcn_v2_cuda_backward(THCudaDoubleTensor *input, THCudaDoubleTensor *weight, + THCudaDoubleTensor *bias, THCudaDoubleTensor *ones, + THCudaDoubleTensor *offset, THCudaDoubleTensor *mask, + THCudaDoubleTensor *columns, + THCudaDoubleTensor *grad_input, THCudaDoubleTensor *grad_weight, + THCudaDoubleTensor *grad_bias, THCudaDoubleTensor *grad_offset, + THCudaDoubleTensor *grad_mask, THCudaDoubleTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group); + +// #ifdef __cplusplus +// } +// #endif + +// #endif \ No newline at end of file diff --git a/src/dcn_v2_double.c b/src/dcn_v2_double.c new file mode 100644 index 0000000..2b86545 --- /dev/null +++ b/src/dcn_v2_double.c @@ -0,0 +1,30 @@ +#include <TH/TH.h> +#include <stdio.h> +#include <math.h> + +void dcn_v2_forward(THDoubleTensor *input, THDoubleTensor *weight, + THDoubleTensor *bias, THDoubleTensor *ones, + THDoubleTensor *offset, THDoubleTensor *mask, + THDoubleTensor *output, THDoubleTensor *columns, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group) +{ + printf("only implemented on GPU"); +} +void dcn_v2_backward(THDoubleTensor *input, THDoubleTensor *weight, + THDoubleTensor *bias, THDoubleTensor *ones, + THDoubleTensor *offset, THDoubleTensor *mask, + THDoubleTensor *output, THDoubleTensor *columns, + THDoubleTensor *grad_input, THDoubleTensor *grad_weight, + THDoubleTensor *grad_bias, THDoubleTensor *grad_offset, + THDoubleTensor *grad_mask, THDoubleTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group) +{ + printf("only implemented on GPU"); +} \ No newline at end of file diff --git a/src/dcn_v2_double.h b/src/dcn_v2_double.h new file mode 100644 index 0000000..eda1f4c --- /dev/null +++ b/src/dcn_v2_double.h @@
-0,0 +1,20 @@ +void dcn_v2_forward(THDoubleTensor *input, THDoubleTensor *weight, + THDoubleTensor *bias, THDoubleTensor *ones, + THDoubleTensor *offset, THDoubleTensor *mask, + THDoubleTensor *output, THDoubleTensor *columns, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group); +void dcn_v2_backward(THDoubleTensor *input, THDoubleTensor *weight, + THDoubleTensor *bias, THDoubleTensor *ones, + THDoubleTensor *offset, THDoubleTensor *mask, + THDoubleTensor *output, THDoubleTensor *columns, + THDoubleTensor *grad_input, THDoubleTensor *grad_weight, + THDoubleTensor *grad_bias, THDoubleTensor *grad_offset, + THDoubleTensor *grad_mask, THDoubleTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group); \ No newline at end of file diff --git a/test.py b/test.py new file mode 100644 index 0000000..bc71846 --- /dev/null +++ b/test.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import time +import torch +import torch.nn as nn +from torch.autograd import gradcheck + +from dcn_v2 import DCNv2 +from dcn_v2_func import DCNv2Function + +deformable_groups = 1 +N, inC, inH, inW = 2, 2, 4, 4 +outC = 2 +kH, kW = 3, 3 + +def conv_identify(weight, bias): + weight.data.zero_() + bias.data.zero_() + o, i, h, w = weight.shape + y = h//2 + x = w//2 + for p in range(i): + for q in range(o): + if p == q: + weight.data[q, p, y, x] = 1.0 + +def check_zero_offset(): + conv_offset = nn.Conv2d(inC, deformable_groups * 2 * kH * kW, + kernel_size=(kH, kW), + stride=(1, 1), + padding=(1, 1), + bias=True).cuda() + + conv_mask = nn.Conv2d(inC, deformable_groups * 1 * kH * kW, + kernel_size=(kH, kW), + stride=(1, 1), + padding=(1, 1), + bias=True).cuda() + + dcn_v2 = DCNv2(inC, outC, (kH, kW), + stride=1, padding=1, dilation=1, + deformable_groups=deformable_groups).cuda() + + conv_offset.weight.data.zero_() + conv_offset.bias.data.zero_() + conv_mask.weight.data.zero_() + conv_mask.bias.data.zero_() + conv_identify(dcn_v2.weight, dcn_v2.bias) + + input = torch.randn(N, inC, inH, inW).cuda() + offset = conv_offset(input) + mask = conv_mask(input) + mask = torch.sigmoid(mask) + output = dcn_v2(input, offset, mask) + output *= 2 + d = (input - output).abs().max() + if d < 1e-10: + print('Zero offset passed') + else: + print('Zero offset failed') + +def check_gradient_double(): + + input = torch.randn(N, inC, inH, inW, dtype=torch.float64).cuda() + input.requires_grad = True + + offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW, dtype=torch.float64).cuda() + offset.data.zero_() + offset.data -= 0.5 + offset.requires_grad = True + + mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW, dtype=torch.float64).cuda() + # mask.data.zero_() + mask.requires_grad = True + mask = torch.sigmoid(mask) + + weight = torch.randn(outC, inC, kH, kW, dtype=torch.float64).cuda() + weight.requires_grad = True + + bias = torch.rand(outC, dtype=torch.float64).cuda() + bias.requires_grad = True + + func = DCNv2Function(stride=1, padding=1, dilation=1, deformable_groups=deformable_groups) + + print(gradcheck(func, (input, offset, mask, weight, bias), eps=1e-6, atol=1e-5, rtol=1e-3)) + +def check_gradient(): + + input = torch.randn(N, inC, inH, inW).cuda() + input.requires_grad = True + + offset = torch.randn(N, 
deformable_groups * 2 * kW * kH, inH, inW).cuda() + offset.data.zero_() + offset.data -= 0.5 + offset.requires_grad = True + + mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda() + # mask.data.zero_() + mask.requires_grad = True + mask = torch.sigmoid(mask) + + weight = torch.randn(outC, inC, kH, kW).cuda() + weight.requires_grad = True + + bias = torch.rand(outC).cuda() + bias.requires_grad = True + + func = DCNv2Function(stride=1, padding=1, dilation=1, deformable_groups=deformable_groups) + + print(gradcheck(func, (input, offset, mask, weight, bias), eps=1e-3, atol=1e-3, rtol=1e-2)) + + +if __name__ == '__main__': + if inC == outC: + check_zero_offset() + try: + check_gradient_double() + except TypeError: + print('''You can switch to double precision in dcn_v2_func.py by (un)commenting these two lines: + from _ext import dcn_v2 as _backend + from _ext import dcn_v2_double as _backend''') + print('Your tensors may not be of **double** type') + print('Switching to **float** type') + + check_gradient() + finally: + print('Note: the "Backward is not reentrant" error may not be a serious problem, ' + 'since the max error is less than 1e-7.\n' + 'Still looking into what triggers this problem.')
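For reference, both CUDA implementations above size their `output` and temporary `columns` buffers from the standard convolution output-size formula and then compute each sample as an im2col followed by a GEMM. Below is a minimal sketch of that arithmetic in plain Python; the helper name `conv_out_size` is ours, and the numbers match the configuration used in `test.py`.

```python
# Sketch of the buffer-size arithmetic used in dcn_v2_cuda.c / dcn_v2_cuda_double.c.
def conv_out_size(size, kernel, pad, stride, dilation=1):
    # Same formula as height_out / width_out in dcn_v2_cuda_forward:
    # (size + 2*pad - (dilation*(kernel-1) + 1)) / stride + 1
    return (size + 2 * pad - (dilation * (kernel - 1) + 1)) // stride + 1

# test.py configuration: 4x4 input, 3x3 kernel, stride 1, padding 1, dilation 1.
height_out = conv_out_size(4, 3, pad=1, stride=1)
width_out = conv_out_size(4, 3, pad=1, stride=1)
assert (height_out, width_out) == (4, 4)  # spatial size is preserved

# Per-sample GEMM shapes inside the batch loop:
#   columns: (channels * kH * kW) x (height_out * width_out)
#   weight : channels_out x (channels * kH * kW)
#   output : channels_out x (height_out * width_out)
channels, channels_out, kH, kW = 2, 2, 3, 3
columns_shape = (channels * kH * kW, height_out * width_out)
output_shape = (channels_out, height_out * width_out)
print(columns_shape, output_shape)  # (18, 16) (2, 16)
```

Because a 3x3 kernel with stride 1, padding 1 and dilation 1 preserves the spatial size, `check_zero_offset` can compare `input` and `output` element-wise after undoing the 0.5 scaling introduced by the sigmoid of the zero mask.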