From 490610f32f0534b5f1075d8836a05b1e825f3a82 Mon Sep 17 00:00:00 2001 From: CharlesShang Date: Wed, 5 Dec 2018 13:36:21 +0800 Subject: [PATCH] DCN --- README.md | 20 ++ __init__.py | 0 build.py | 42 +++ build_double.py | 42 +++ dcn_v2.py | 68 ++++ dcn_v2_func.py | 72 +++++ mask.sh | 7 + src/cuda/dcn_v2_im2col_cuda.cu | 387 +++++++++++++++++++++++ src/cuda/dcn_v2_im2col_cuda.cu.o | Bin 0 -> 43528 bytes src/cuda/dcn_v2_im2col_cuda.h | 100 ++++++ src/cuda/dcn_v2_im2col_cuda_double.cu | 399 ++++++++++++++++++++++++ src/cuda/dcn_v2_im2col_cuda_double.cu.o | Bin 0 -> 47264 bytes src/cuda/dcn_v2_im2col_cuda_double.h | 100 ++++++ src/dcn_v2.c | 30 ++ src/dcn_v2.h | 20 ++ src/dcn_v2_cuda.c | 240 ++++++++++++++ src/dcn_v2_cuda.h | 35 +++ src/dcn_v2_cuda_double.c | 262 ++++++++++++++++ src/dcn_v2_cuda_double.h | 35 +++ src/dcn_v2_double.c | 30 ++ src/dcn_v2_double.h | 20 ++ test.py | 132 ++++++++ 22 files changed, 2041 insertions(+) create mode 100644 README.md create mode 100644 __init__.py create mode 100644 build.py create mode 100644 build_double.py create mode 100644 dcn_v2.py create mode 100644 dcn_v2_func.py create mode 100755 mask.sh create mode 100644 src/cuda/dcn_v2_im2col_cuda.cu create mode 100644 src/cuda/dcn_v2_im2col_cuda.cu.o create mode 100644 src/cuda/dcn_v2_im2col_cuda.h create mode 100644 src/cuda/dcn_v2_im2col_cuda_double.cu create mode 100644 src/cuda/dcn_v2_im2col_cuda_double.cu.o create mode 100644 src/cuda/dcn_v2_im2col_cuda_double.h create mode 100644 src/dcn_v2.c create mode 100644 src/dcn_v2.h create mode 100644 src/dcn_v2_cuda.c create mode 100644 src/dcn_v2_cuda.h create mode 100644 src/dcn_v2_cuda_double.c create mode 100644 src/dcn_v2_cuda_double.h create mode 100644 src/dcn_v2_double.c create mode 100644 src/dcn_v2_double.h create mode 100644 test.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..f53e989 --- /dev/null +++ b/README.md @@ -0,0 +1,20 @@ +## Deformable Convolutional Networks V2 with PyTorch +```bash + ./mask.sh # build + python test.py # run gradient check +``` +### Known issues: + +- [ ] Gradient check w.r.t. offset. +- [ ] Backward is not reentrant. + +This is an adaptation of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op). +I have run the gradient check many times with the DOUBLE type. Every tensor **except offset** passes. +However, when I set the offset to 0.5, it passes. I'm still wondering what causes this problem. Could it be due to some +non-differentiable points? + +Another issue is that backward raises `RuntimeError: Backward is not reentrant`. However, the numerical error is very small (`<1e-7`), +so it may not be a serious problem. + +Please open an issue or a PR if you have any comments. + \ No newline at end of file
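A minimal usage sketch of the two modules this patch adds: `DCNv2`, which takes externally computed `offset`/`mask` tensors, and `DCN`, which predicts them internally via `conv_offset_mask`. Constructor signatures, the channel layout (`deformable_groups * 3 * kh * kw`, split two-thirds offset, one-third mask), and the output-size formula follow `dcn_v2.py`/`dcn_v2_func.py` below; the tensor sizes and device placement are illustrative, not part of the patch.

```python
import torch
from dcn_v2 import DCN, DCNv2

# DCN predicts its own offsets and masks from the input.
dcn = DCN(64, 64, kernel_size=3, stride=1, padding=1, deformable_groups=2).cuda()
x = torch.randn(2, 64, 16, 16).cuda()
out = dcn(x)
# _infer_shape: H_out = (H + 2*pad - (dilation*(k-1)+1)) // stride + 1 = (16+2-3)//1+1 = 16
print(out.shape)  # (2, 64, 16, 16)

# DCNv2 takes offset and mask explicitly:
# offset carries 2 * deformable_groups * kh * kw channels (a y/x pair per sample point),
# mask carries deformable_groups * kh * kw channels.
dcnv2 = DCNv2(64, 64, kernel_size=3, stride=1, padding=1, deformable_groups=2).cuda()
offset = torch.zeros(2, 2 * 2 * 3 * 3, 16, 16).cuda()  # zero offsets: regular grid sampling
mask = torch.ones(2, 2 * 3 * 3, 16, 16).cuda()         # all-ones mask: no modulation
out2 = dcnv2(x, offset, mask)
```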
diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/build.py b/build.py new file mode 100644 index 0000000..d858da5 --- /dev/null +++ b/build.py @@ -0,0 +1,42 @@ +import os +import torch +from torch.utils.ffi import create_extension + + +sources = ['src/dcn_v2.c'] +headers = ['src/dcn_v2.h'] +defines = [] +with_cuda = False + +extra_objects = [] +if torch.cuda.is_available(): + print('Including CUDA code.') + sources += ['src/dcn_v2_cuda.c'] + headers += ['src/dcn_v2_cuda.h'] + defines += [('WITH_CUDA', None)] + extra_objects += ['src/cuda/dcn_v2_im2col_cuda.cu.o'] + with_cuda = True +else: + raise ValueError('CUDA is not available') + +extra_compile_args = ['-fopenmp', '-std=c99'] + +this_file = os.path.dirname(os.path.realpath(__file__)) +print(this_file) +sources = [os.path.join(this_file, fname) for fname in sources] +headers = [os.path.join(this_file, fname) for fname in headers] +extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] + +ffi = create_extension( + '_ext.dcn_v2', + headers=headers, + sources=sources, + define_macros=defines, + relative_to=__file__, + with_cuda=with_cuda, + extra_objects=extra_objects, + extra_compile_args=extra_compile_args +) + +if __name__ == '__main__': + ffi.build() diff --git a/build_double.py b/build_double.py new file mode 100644 index 0000000..a707f44 --- /dev/null +++ b/build_double.py @@ -0,0 +1,42 @@ +import os +import torch +from torch.utils.ffi import create_extension + + +sources = ['src/dcn_v2_double.c'] +headers = ['src/dcn_v2_double.h'] +defines = [] +with_cuda = False + +extra_objects = [] +if torch.cuda.is_available(): + print('Including CUDA code.') + sources += ['src/dcn_v2_cuda_double.c'] + headers += ['src/dcn_v2_cuda_double.h'] + defines += [('WITH_CUDA', None)] + extra_objects += ['src/cuda/dcn_v2_im2col_cuda_double.cu.o'] + with_cuda = True +else: + raise ValueError('CUDA is not available') + +extra_compile_args = ['-fopenmp', '-std=c99'] + +this_file = os.path.dirname(os.path.realpath(__file__)) +print(this_file) +sources = [os.path.join(this_file, fname) for fname in sources] +headers = [os.path.join(this_file, fname) for fname in headers] +extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] + +ffi = create_extension( + '_ext.dcn_v2_double', + headers=headers, + sources=sources, + define_macros=defines, + relative_to=__file__, + with_cuda=with_cuda, + extra_objects=extra_objects, + extra_compile_args=extra_compile_args +) + +if __name__ == '__main__': + ffi.build() diff --git a/dcn_v2.py b/dcn_v2.py new file mode 100644 index 0000000..11fcf51 --- /dev/null +++ b/dcn_v2.py @@ -0,0 +1,68 @@ +#!/usr/bin/env python +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import math +from torch import nn +from torch.nn.modules.utils import _pair + +from dcn_v2_func import DCNv2Function + +class DCNv2(nn.Module): + + def __init__(self, in_channels, out_channels, + kernel_size, stride, padding, dilation=1, deformable_groups=1): + super(DCNv2, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.deformable_groups = deformable_groups + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size)) + self.bias = nn.Parameter(torch.Tensor(out_channels)) + 
self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + self.bias.data.zero_() + + def forward(self, input, offset, mask): + func = DCNv2Function(self.stride, self.padding, self.dilation, self.deformable_groups) + return func(input, offset, mask, self.weight, self.bias) + + +class DCN(DCNv2): + + def __init__(self, in_channels, out_channels, + kernel_size, stride, padding, + dilation=1, deformable_groups=1): + super(DCN, self).__init__(in_channels, out_channels, + kernel_size, stride, padding, dilation, deformable_groups) + + self.conv_offset_mask = nn.Conv2d(self.in_channels, + self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=(self.stride, self.stride), + padding=(self.padding, self.padding), + bias=True) + self.init_offset() + + def init_offset(self): + # zero-init the offset/mask predictor so DCN starts as a plain convolution; + # overriding reset_parameters() here would crash (DCNv2.__init__ calls it + # before conv_offset_mask exists) and leave weight/bias uninitialized + self.conv_offset_mask.weight.data.zero_() + self.conv_offset_mask.bias.data.zero_() + + def forward(self, input): + out = self.conv_offset_mask(input) + dy, dx, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((dy, dx), dim=1) + func = DCNv2Function(self.stride, self.padding, self.dilation, self.deformable_groups) + return func(input, offset, mask, self.weight, self.bias) \ No newline at end of file diff --git a/dcn_v2_func.py b/dcn_v2_func.py new file mode 100644 index 0000000..f522dd6 --- /dev/null +++ b/dcn_v2_func.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +from torch.autograd import Function + +from _ext import dcn_v2 as _backend +# from _ext import dcn_v2_double as _backend + +class DCNv2Function(Function): + + def __init__(self, stride, padding, dilation=1, deformable_groups=1): + super(DCNv2Function, self).__init__() + self.stride = stride + self.padding = padding + self.dilation = dilation + self.deformable_groups = deformable_groups + + def forward(self, input, offset, mask, weight, bias): + if not input.is_cuda: + raise NotImplementedError + if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad: + self.save_for_backward(input, offset, mask, weight, bias) + output = input.new(*self._infer_shape(input, weight)) + self._bufs = [input.new(), input.new()] + _backend.dcn_v2_cuda_forward(input, weight, + bias, self._bufs[0], + offset, mask, + output, self._bufs[1], + weight.shape[2], weight.shape[3], + self.stride, self.stride, + self.padding, self.padding, + self.dilation, self.dilation, + self.deformable_groups) + return output + + def backward(self, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = self.saved_tensors + grad_input = input.new(*input.size()).zero_() + grad_offset = offset.new(*offset.size()).zero_() + grad_mask = mask.new(*mask.size()).zero_() + grad_weight = weight.new(*weight.size()).zero_() + grad_bias = bias.new(*bias.size()).zero_() + _backend.dcn_v2_cuda_backward(input, weight, + bias, self._bufs[0], + offset, mask, + self._bufs[1], + grad_input, grad_weight, + grad_bias, grad_offset, + grad_mask, grad_output, + weight.shape[2], weight.shape[3], + self.stride, self.stride, + self.padding, self.padding, + self.dilation, self.dilation, + self.deformable_groups) + + return grad_input, grad_offset, grad_mask, grad_weight, grad_bias + + def _infer_shape(self, input, weight): + n = 
input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = (height + 2 * self.padding - (self.dilation * (kernel_h - 1) + 1)) // self.stride + 1 + width_out = (width + 2 * self.padding - (self.dilation * (kernel_w - 1) + 1)) // self.stride + 1 + return (n, channels_out, height_out, width_out) + diff --git a/mask.sh b/mask.sh new file mode 100755 index 0000000..1f7ca1f --- /dev/null +++ b/mask.sh @@ -0,0 +1,7 @@ +#!/usr/bin/env bash +cd src/cuda +nvcc -c -o dcn_v2_im2col_cuda.cu.o dcn_v2_im2col_cuda.cu -x cu -Xcompiler -fPIC +nvcc -c -o dcn_v2_im2col_cuda_double.cu.o dcn_v2_im2col_cuda_double.cu -x cu -Xcompiler -fPIC +cd - +python build.py +python build_double.py diff --git a/src/cuda/dcn_v2_im2col_cuda.cu b/src/cuda/dcn_v2_im2col_cuda.cu new file mode 100644 index 0000000..ab22b1b --- /dev/null +++ b/src/cuda/dcn_v2_im2col_cuda.cu @@ -0,0 +1,387 @@ +#include "dcn_v2_im2col_cuda.h" +#include <cstdio> +#include <algorithm> +#include <cstring> + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N) +{ + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + + +__device__ float dmcn_im2col_bilinear(const float *bottom_data, const int data_width, + const int height, const int width, float h, float w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + float lh = h - h_low; + float lw = w - w_low; + float hh = 1 - lh, hw = 1 - lw; + + float v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + float v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + float v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + float v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +__device__ float dmcn_get_gradient_weight(float argmax_h, float argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + float weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +__device__ float dmcn_get_coordinate_weight(float argmax_h, float argmax_w, + const int height, const int width, const float *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + float weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 
&& argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const float *data_im, const float *data_offset, const float *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + float *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const float offset_h = data_offset_ptr[data_offset_h_ptr]; + const float offset_w = data_offset_ptr[data_offset_w_ptr]; + const float mask = data_mask_ptr[data_mask_hw_ptr]; + float val = static_cast<float>(0); + const float h_im = h_in + i * 
dilation_h + offset_h; + const float w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const float *data_col, const float *data_offset, const float *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + float *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const float offset_h = data_offset_ptr[data_offset_h_ptr]; + const float offset_w = data_offset_ptr[data_offset_w_ptr]; + const float mask = data_mask_ptr[data_mask_hw_ptr]; + const float cur_inv_h_data = h_in + i * dilation_h + offset_h; + const float cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const float cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + float weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const float *data_col, const float 
*data_im, + const float *data_offset, const float *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + float *grad_offset, float *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + float val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const float offset_h = data_offset_ptr[data_offset_h_ptr]; + const float offset_w = data_offset_ptr[data_offset_w_ptr]; + const float mask = data_mask_ptr[data_mask_hw_ptr]; + float inv_h = h_in + i * dilation_h + offset_h; + float inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const float weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + 
w] = mval; + } +} + +void modulated_deformable_im2col_cuda(cudaStream_t stream, + const float* data_im, const float* data_offset, const float* data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, float* data_col) { + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + modulated_deformable_im2col_gpu_kernel + <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>( + num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +void modulated_deformable_col2im_cuda(cudaStream_t stream, + const float* data_col, const float* data_offset, const float* data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, float* grad_im){ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + modulated_deformable_col2im_gpu_kernel + <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>( + num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, + const float* data_col, const float* data_im, const float* data_offset, const float* data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + float* grad_offset, float* grad_mask) { + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + modulated_deformable_col2im_coord_gpu_kernel + <<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>( + num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset, grad_mask); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + 
} +} \ No newline at end of file diff --git a/src/cuda/dcn_v2_im2col_cuda.cu.o b/src/cuda/dcn_v2_im2col_cuda.cu.o new file mode 100644 index 0000000000000000000000000000000000000000..b98618e2bc378e48fce85929d398603bc838a670 GIT binary patch literal 43528 [base85-encoded binary data omitted]
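The core of `dmcn_im2col_bilinear` above is standard bilinear interpolation with out-of-bounds neighbours treated as zero. The NumPy restatement below mirrors its weighting for reference only; it is not part of the patch, and `im`, `h`, `w` are illustrative.

```python
import numpy as np

def bilinear(im, h, w):
    """Mirror of dmcn_im2col_bilinear: sample im at fractional (h, w);
    neighbours that fall outside the map contribute zero."""
    H, W = im.shape
    h_low, w_low = int(np.floor(h)), int(np.floor(w))
    lh, lw = h - h_low, w - w_low              # fractional parts
    val = 0.0
    # Four neighbours weighted by (1-lh | lh) * (1-lw | lw), as in the kernel.
    for dy, wy in ((0, 1.0 - lh), (1, lh)):
        for dx, wx in ((0, 1.0 - lw), (1, lw)):
            y, x = h_low + dy, w_low + dx
            if 0 <= y < H and 0 <= x < W:
                val += wy * wx * im[y, x]
    return val

im = np.arange(16, dtype=np.float64).reshape(4, 4)   # im[y, x] = 4*y + x
print(bilinear(im, 1.5, 2.25))                       # 8.25 == 4*1.5 + 2.25
```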
zDfd=7HGQERZn4!;t228!Ysj5<+Rpk*3*prN*I}1jU;5Ux#ZFagob!97en=A5FbMM^ zKGrCR3(DuZ zP6NLGHw*auy_xkcnKAeL+sgN^AO2+zg6qbL=Wj?OuB)n5l5*`{1jt8@ttMg3#VvBK+z`7dc?+3H z)iBOCrWE-+|9QS)qP~V>l6=yIai5frO7eW{kzE%Bk^SPWkR{hih}A#FkN1u|e&s +#include +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 512; +inline int GET_BLOCKS(const int N) +{ + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 +#else +__device__ double atomicAdd(double* address, double val) +{ + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong(val + __longlong_as_double(assumed))); + } while (assumed != old); + return __longlong_as_double(old); +} +#endif + +__device__ double dmcn_im2col_bilinear(const double *bottom_data, const int data_width, + const int height, const int width, double h, double w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + double lh = h - h_low; + double lw = w - w_low; + double hh = 1 - lh, hw = 1 - lw; + + double v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + double v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + double v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + double v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + double w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + double val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +__device__ double dmcn_get_gradient_weight(double argmax_h, double argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + double weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +__device__ double dmcn_get_coordinate_weight(double argmax_h, double argmax_w, + const int height, const int width, const double *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + double weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if 
(argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const double *data_im, const double *data_offset, const double *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + double *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + double *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const double* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const double *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const double *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const double *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const double offset_h = data_offset_ptr[data_offset_h_ptr]; + const double offset_w = data_offset_ptr[data_offset_w_ptr]; + const double mask = data_mask_ptr[data_mask_hw_ptr]; + double val = static_cast<double>(0); + const double h_im = h_in + i * dilation_h + offset_h; + const double w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const double map_h = i * dilation_h + offset_h; + 
//const double map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const double *data_col, const double *data_offset, const double *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + double *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const double *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const double *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const double offset_h = data_offset_ptr[data_offset_h_ptr]; + const double offset_w = data_offset_ptr[data_offset_w_ptr]; + const double mask = data_mask_ptr[data_mask_hw_ptr]; + const double cur_inv_h_data = h_in + i * dilation_h + offset_h; + const double cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const double cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + double weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const double *data_col, const double *data_im, + const double *data_offset, const double *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int 
dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + double *grad_offset, double *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + double val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const double *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const double *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const double *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const double *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const double offset_h = data_offset_ptr[data_offset_h_ptr]; + const double offset_w = data_offset_ptr[data_offset_w_ptr]; + const double mask = data_mask_ptr[data_mask_hw_ptr]; + double inv_h = h_in + i * dilation_h + offset_h; + double inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const double weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda(cudaStream_t stream, + const double *data_im, const double *data_offset, const double *data_mask, + const int batch_size, const int channels, const int height_im, const int 
+void modulated_deformable_im2col_cuda(cudaStream_t stream,
+    const double *data_im, const double *data_offset, const double *data_mask,
+    const int batch_size, const int channels, const int height_im, const int width_im,
+    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int deformable_group, double *data_col)
+{
+  // num_axes should be smaller than block size
+  const int channel_per_deformable_group = channels / deformable_group;
+  const int num_kernels = channels * batch_size * height_col * width_col;
+  modulated_deformable_im2col_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
+      num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kernel_w,
+      pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group,
+      batch_size, channels, deformable_group, height_col, width_col, data_col);
+
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err));
+  }
+}
+
+void modulated_deformable_col2im_cuda(cudaStream_t stream,
+    const double *data_col, const double *data_offset, const double *data_mask,
+    const int batch_size, const int channels, const int height_im, const int width_im,
+    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int deformable_group, double *grad_im)
+{
+
+  const int channel_per_deformable_group = channels / deformable_group;
+  const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col;
+  modulated_deformable_col2im_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
+      num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im,
+      kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+      dilation_h, dilation_w, channel_per_deformable_group,
+      batch_size, deformable_group, height_col, width_col, grad_im);
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err));
+  }
+}
+
+void modulated_deformable_col2im_coord_cuda(cudaStream_t stream,
+    const double *data_col, const double *data_im, const double *data_offset, const double *data_mask,
+    const int batch_size, const int channels, const int height_im, const int width_im,
+    const int height_col, const int width_col, const int kernel_h, const int kernel_w,
+    const int pad_h, const int pad_w, const int stride_h, const int stride_w,
+    const int dilation_h, const int dilation_w,
+    const int deformable_group,
+    double *grad_offset, double *grad_mask)
+{
+  const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group;
+  const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group;
+  modulated_deformable_col2im_coord_gpu_kernel<<<GET_BLOCKS(num_kernels), CUDA_NUM_THREADS, 0, stream>>>(
+      num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im,
+      kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w,
+      dilation_h, dilation_w, channel_per_deformable_group,
+      batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col,
+      grad_offset, grad_mask);
+  cudaError_t err = cudaGetLastError();
+  if (err != cudaSuccess)
+  {
+    printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err));
+  }
+}
\ No newline at end of file
diff --git a/src/cuda/dcn_v2_im2col_cuda_double.cu.o b/src/cuda/dcn_v2_im2col_cuda_double.cu.o
new file mode 100644
index 0000000000000000000000000000000000000000..e5c3479fff88f0d272f1709dad3ed31ef45867ac
GIT binary patch
literal 47264
[47264 bytes of compiled object code omitted]
zr-y{uc@5Hc5uPI4OZ%q?f0*)kzCwBx@RdsDxbXJyKH$0R;=h2P|1;F(*?0|lZUj7+ zd@K)sOCB8G(#fRs<+vw2a`4`DY z_^&8L?7RZ_DZq32bAKNE3DR>b*hj+iAyNYj<+=290gnD$p*VB-9}=$95P|ImsUuH2 z*_{VJmj7hP7!u$%;5N4TI4W)H_p9!Og6K zjMHRrh7<6LL3CP=Bs-#sVDj*=9*toocwjJSLNK&LQGJ5!U~P?HbOrSw_W_Wu$V@a8 z*%_I|Dvt!bniggVc$FH}{;l^iqS-VUjt*`O4>>jXWOJhMu%p|fX&5nA6V|;im`a2u zwnmbPNNh`}yBV?WWO6b&xOJxn>+j>9A+H8?i4VG5ox>2gNH7kbek#|EMVs3AR7yL* z3|-B=tMRz&BhAp_vZLmO*CpWa2C)r?BhW_UL1^pHykhQX9Cq-$C#M5BLS#G;%J`92 zB<)V0$Hkf+3tmEx>W6nwv*`@RVxz&(0f?Z`RzHb2wiq`wj%hur0ikNHv^~ZeVLv-O zoxbk<`}c!1W!>3X7iC?}-tL`RN8do_)}bg!yPe@Uj1V+Dj7DS8L?j3YJq!&d9G0}+ zzJZSRK5d|ZV!rQyV(Xd(twDsPxzCcn#{s^eVU4(k9|DX{im?<;gbGNUDosK$UvoGma;ZeznQ znE&4odAz|)PUUJZpV*&IbvQEP$J`+pg7E(_f>B`jz|sTf0;3aUjW?O22|al>s-yaTo?@MCgF2lL<$=D{oI`VaM7C+MjX@U;RCE5#U@ z<@X4v7yp z*U2c~Ea0aET#UnA0xrhkoPdk+F9^682Nzujquo})Zi|46ao8r{VjOND9L-=H_f$ph116l`iQ|1C68eMk91jWlHv!Ju;mrcx zCg2XZFG0dsiv7h$IM&uGian@KU<}MR~ElNy5!`f074(NzfzK z_bmYz>l>x}R*lHy6mGU9tV&aSZ@KyuRy}{CN?_ z`8+epCW#ianCOz=jvb69W{ew#p~!d~<_8F6Q)R(qGI*E(w21}a2P{ItnB!{q`;O3L z+)xE8?<0=h{^9lP9-A~pCbY3+5N^F4A!i%v~hdfF(*`}`uv*TRNl z7!u#-E01Fol2{(EPkA}M&sTnsmClSUXdlM`UXJgp0mCi=)7x3JhjK5{xXGE+op@GS z3MtgpmPU-kt_O|%NvZ(aM>(-RC{YI+>ZHqdgcz;QH5dcOY#*;DP^Y=PlghUNPb6-i zV=hRr`cuE*I4_TUkxXK29LDlye1hzM(d4k%KKB1y_VHgJkoa>aBR9AHVPNF4KOax7 zK>Lhc`rG7=Gy|Mm<)^4TAJ@48v;QaZls`u07v0258hdm3yQ#eS_}zP_A&>1wBy;)G zz|7VDUVqdm#mnO!FI{ +#include +#include + +void dcn_v2_forward(THFloatTensor *input, THFloatTensor *weight, + THFloatTensor *bias, THFloatTensor *ones, + THFloatTensor *offset, THFloatTensor *mask, + THFloatTensor *output, THFloatTensor *columns, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group) +{ + printf("only implemented in GPU"); +} + void dcn_v2_backward(THFloatTensor *input, THFloatTensor *weight, + THFloatTensor *bias, THFloatTensor *ones, + THFloatTensor *offset, THFloatTensor *mask, + THFloatTensor *output, THFloatTensor *columns, + THFloatTensor *grad_input, THFloatTensor *grad_weight, + THFloatTensor *grad_bias, THFloatTensor *grad_offset, + THFloatTensor *grad_mask, THFloatTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group) +{ + printf("only implemented in GPU"); +} \ No newline at end of file diff --git a/src/dcn_v2.h b/src/dcn_v2.h new file mode 100644 index 0000000..1a97ff0 --- /dev/null +++ b/src/dcn_v2.h @@ -0,0 +1,20 @@ +void dcn_v2_forward(THFloatTensor *input, THFloatTensor *weight, + THFloatTensor *bias, THFloatTensor *ones, + THFloatTensor *offset, THFloatTensor *mask, + THFloatTensor *output, THFloatTensor *columns, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group); +void dcn_v2_backward(THFloatTensor *input, THFloatTensor *weight, + THFloatTensor *bias, THFloatTensor *ones, + THFloatTensor *offset, THFloatTensor *mask, + THFloatTensor *output, THFloatTensor *columns, + THFloatTensor *grad_input, THFloatTensor *grad_weight, + THFloatTensor *grad_bias, THFloatTensor *grad_offset, + THFloatTensor *grad_mask, THFloatTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group); \ No newline at end of file diff --git a/src/dcn_v2_cuda.c b/src/dcn_v2_cuda.c new file mode 100644 index 0000000..3ef9c4a --- /dev/null +++ 
diff --git a/src/dcn_v2_cuda.c b/src/dcn_v2_cuda.c
new file mode 100644
index 0000000..3ef9c4a
--- /dev/null
+++ b/src/dcn_v2_cuda.c
@@ -0,0 +1,240 @@
+#include <THC/THC.h>
+#include "cuda/dcn_v2_im2col_cuda.h"
+
+extern THCState *state;
+
+// author: Charles Shang
+// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
+
+void dcn_v2_cuda_forward(THCudaTensor *input, THCudaTensor *weight,
+                         THCudaTensor *bias, THCudaTensor *ones,
+                         THCudaTensor *offset, THCudaTensor *mask,
+                         THCudaTensor *output, THCudaTensor *columns,
+                         int kernel_h, int kernel_w,
+                         const int stride_h, const int stride_w,
+                         const int pad_h, const int pad_w,
+                         const int dilation_h, const int dilation_w,
+                         const int deformable_group)
+{
+    THCAssertSameGPU(THCudaTensor_checkGPU(state, 8, input, weight, bias, ones, offset, mask, output, columns));
+    THArgCheck(THCudaTensor_isContiguous(state, input), 1, "input tensor has to be contiguous");
+    THArgCheck(THCudaTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous");
+
+    const int batch = THCudaTensor_size(state, input, 0);
+    const int channels = THCudaTensor_size(state, input, 1);
+    const int height = THCudaTensor_size(state, input, 2);
+    const int width = THCudaTensor_size(state, input, 3);
+
+    const int channels_out = THCudaTensor_size(state, weight, 0);
+    const int channels_kernel = THCudaTensor_size(state, weight, 1);
+    const int kernel_h_ = THCudaTensor_size(state, weight, 2);
+    const int kernel_w_ = THCudaTensor_size(state, weight, 3);
+    if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+        THError("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+                kernel_h, kernel_w, kernel_h_, kernel_w_);
+    if (channels != channels_kernel)
+        THError("Input shape and kernel channels won't match: (%d vs %d).",
+                channels, channels_kernel);
+
+    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+    if (THCudaTensor__nDimension(state, ones) != 2 ||
+        THCudaTensor_size(state, ones, 0) * THCudaTensor_size(state, ones, 1) < height_out * width_out)
+    {
+        // Resize plane and fill with ones...
+        THCudaTensor_resize2d(state, ones, height_out, width_out);
+        THCudaTensor_fill(state, ones, 1);
+    }
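+
+    // The ones buffer is a height_out*width_out vector of ones; the forward
+    // pass below uses it in a rank-1 GEMM (ones * bias^T) to broadcast the
+    // bias over every spatial location before accumulating W * columns.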
+
+    // resize output
+    THCudaTensor_resize4d(state, output, batch, channels_out, height_out, width_out);
+    // resize temporary columns
+    THCudaTensor_resize2d(state, columns, channels * kernel_h * kernel_w, 1 * height_out * width_out);
+
+    THCudaTensor *input_n = THCudaTensor_new(state);
+    THCudaTensor *offset_n = THCudaTensor_new(state);
+    THCudaTensor *mask_n = THCudaTensor_new(state);
+    THCudaTensor *output_n = THCudaTensor_new(state);
+
+    for (int b = 0; b < batch; b++)
+    {
+        THCudaTensor_select(state, input_n, input, 0, b);
+        THCudaTensor_select(state, offset_n, offset, 0, b);
+        THCudaTensor_select(state, mask_n, mask, 0, b);
+        THCudaTensor_select(state, output_n, output, 0, b);
+
+        // Do Bias first:
+        // M,N,K are dims of matrix A and B
+        // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+        // (N x 1) (1 x M)
+        long m_ = channels_out;
+        long n_ = height_out * width_out;
+        long k_ = 1;
+        THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,
+                         THCudaTensor_data(state, ones), k_,
+                         THCudaTensor_data(state, bias), k_, 0.0f,
+                         THCudaTensor_data(state, output_n), n_);
+
+        modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
+                                         THCudaTensor_data(state, input_n), THCudaTensor_data(state, offset_n),
+                                         THCudaTensor_data(state, mask_n),
+                                         1, channels, height, width,
+                                         height_out, width_out, kernel_h, kernel_w,
+                                         pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+                                         deformable_group, THCudaTensor_data(state, columns));
+
+        //(k * m) x (m * n)
+        // Y = WC
+        long m = channels_out;
+        long n = height_out * width_out;
+        long k = channels * kernel_h * kernel_w;
+        THCudaBlas_Sgemm(state, 'n', 'n', n, m, k, 1.0f,
+                         THCudaTensor_data(state, columns), n,
+                         THCudaTensor_data(state, weight), k, 1.0f,
+                         THCudaTensor_data(state, output_n), n);
+    }
+    THCudaTensor_free(state, input_n);
+    THCudaTensor_free(state, offset_n);
+    THCudaTensor_free(state, mask_n);
+    THCudaTensor_free(state, output_n);
+}
+
+void dcn_v2_cuda_backward(THCudaTensor *input, THCudaTensor *weight,
+                          THCudaTensor *bias, THCudaTensor *ones,
+                          THCudaTensor *offset, THCudaTensor *mask,
+                          THCudaTensor *columns,
+                          THCudaTensor *grad_input, THCudaTensor *grad_weight,
+                          THCudaTensor *grad_bias, THCudaTensor *grad_offset,
+                          THCudaTensor *grad_mask, THCudaTensor *grad_output,
+                          int kernel_h, int kernel_w,
+                          int stride_h, int stride_w,
+                          int pad_h, int pad_w,
+                          int dilation_h, int dilation_w,
+                          int deformable_group)
+{
+    THCAssertSameGPU(THCudaTensor_checkGPU(state, 13, input, weight, bias, ones, offset, mask, columns,
+                                           grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output));
+    THArgCheck(THCudaTensor_isContiguous(state, input), 1, "input tensor has to be contiguous");
+    THArgCheck(THCudaTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous");
+
+    const int batch = THCudaTensor_size(state, input, 0);
+    const int channels = THCudaTensor_size(state, input, 1);
+    const int height = THCudaTensor_size(state, input, 2);
+    const int width = THCudaTensor_size(state, input, 3);
+
+    const int channels_out = THCudaTensor_size(state, weight, 0);
+    const int channels_kernel = THCudaTensor_size(state, weight, 1);
+    const int kernel_h_ = THCudaTensor_size(state, weight, 2);
+    const int kernel_w_ = THCudaTensor_size(state, weight, 3);
+    if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+        THError("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+                kernel_h, kernel_w, kernel_h_, kernel_w_);
+    if (channels != channels_kernel)
+        THError("Input shape and kernel channels won't match: (%d vs %d).",
+                channels, channels_kernel);
+
+    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+    if (THCudaTensor__nDimension(state, ones) != 2 ||
+        THCudaTensor_size(state, ones, 0) * THCudaTensor_size(state, ones, 1) < height_out * width_out)
+    {
+        // Resize plane and fill with ones...
+        THCudaTensor_resize2d(state, ones, height_out, width_out);
+        THCudaTensor_fill(state, ones, 1);
+    }
+
+    THCudaTensor_resize4d(state, grad_input, batch, channels, height, width);
+    THCudaTensor_resize2d(state, columns, channels * kernel_h * kernel_w, height_out * width_out);
+
+    THCudaTensor *input_n = THCudaTensor_new(state);
+    THCudaTensor *offset_n = THCudaTensor_new(state);
+    THCudaTensor *mask_n = THCudaTensor_new(state);
+
+    THCudaTensor *grad_output_n = THCudaTensor_new(state);
+    THCudaTensor *grad_input_n = THCudaTensor_new(state);
+    THCudaTensor *grad_offset_n = THCudaTensor_new(state);
+    THCudaTensor *grad_mask_n = THCudaTensor_new(state);
+
+    for (int b = 0; b < batch; b++)
+    {
+        THCudaTensor_select(state, input_n, input, 0, b);
+        THCudaTensor_select(state, offset_n, offset, 0, b);
+        THCudaTensor_select(state, mask_n, mask, 0, b);
+        THCudaTensor_select(state, grad_output_n, grad_output, 0, b);
+        THCudaTensor_select(state, grad_input_n, grad_input, 0, b);
+        THCudaTensor_select(state, grad_offset_n, grad_offset, 0, b);
+        THCudaTensor_select(state, grad_mask_n, grad_mask, 0, b);
+
+        long m = channels * kernel_h * kernel_w;
+        long n = height_out * width_out;
+        long k = channels_out;
+
+        THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f,
+                         THCudaTensor_data(state, grad_output_n), n,
+                         THCudaTensor_data(state, weight), m, 0.0f,
+                         THCudaTensor_data(state, columns), n);
+
+        // gradient w.r.t. input coordinate data
+        modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state),
+                                               THCudaTensor_data(state, columns),
+                                               THCudaTensor_data(state, input_n),
+                                               THCudaTensor_data(state, offset_n),
+                                               THCudaTensor_data(state, mask_n),
+                                               1, channels, height, width,
+                                               height_out, width_out, kernel_h, kernel_w,
+                                               pad_h, pad_w, stride_h, stride_w,
+                                               dilation_h, dilation_w, deformable_group,
+                                               THCudaTensor_data(state, grad_offset_n),
+                                               THCudaTensor_data(state, grad_mask_n));
+        // gradient w.r.t. input data
+        modulated_deformable_col2im_cuda(THCState_getCurrentStream(state),
+                                         THCudaTensor_data(state, columns),
+                                         THCudaTensor_data(state, offset_n),
+                                         THCudaTensor_data(state, mask_n),
+                                         1, channels, height, width,
+                                         height_out, width_out, kernel_h, kernel_w,
+                                         pad_h, pad_w, stride_h, stride_w,
+                                         dilation_h, dilation_w, deformable_group,
+                                         THCudaTensor_data(state, grad_input_n));
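+
+        // Note: the weight/bias GEMMs below use beta = 1.0, so grad_weight
+        // and grad_bias accumulate across this batch loop; the caller is
+        // expected to pass them in zero-initialized.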
+
+        // gradient w.r.t. weight, dWeight should accumulate across the batch and group
+        modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
+                                         THCudaTensor_data(state, input_n),
+                                         THCudaTensor_data(state, offset_n),
+                                         THCudaTensor_data(state, mask_n),
+                                         1, channels, height, width,
+                                         height_out, width_out, kernel_h, kernel_w,
+                                         pad_h, pad_w, stride_h, stride_w,
+                                         dilation_h, dilation_w, deformable_group,
+                                         THCudaTensor_data(state, columns));
+        long m_ = channels_out;
+        long n_ = channels * kernel_h * kernel_w;
+        long k_ = height_out * width_out;
+
+        THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f,
+                         THCudaTensor_data(state, columns), k_,
+                         THCudaTensor_data(state, grad_output_n), k_, 1.0f,
+                         THCudaTensor_data(state, grad_weight), n_);
+
+        // gradient w.r.t. bias
+        // long m_ = channels_out;
+        // long k__ = height_out * width_out;
+        THCudaBlas_Sgemv(state,
+                         't',
+                         k_, m_, 1.0f,
+                         THCudaTensor_data(state, grad_output_n), k_,
+                         THCudaTensor_data(state, ones), 1, 1.0f,
+                         THCudaTensor_data(state, grad_bias), 1);
+    }
+
+    THCudaTensor_free(state, input_n);
+    THCudaTensor_free(state, offset_n);
+    THCudaTensor_free(state, mask_n);
+
+    THCudaTensor_free(state, grad_output_n);
+    THCudaTensor_free(state, grad_input_n);
+    THCudaTensor_free(state, grad_offset_n);
+    THCudaTensor_free(state, grad_mask_n);
+}
\ No newline at end of file
diff --git a/src/dcn_v2_cuda.h b/src/dcn_v2_cuda.h
new file mode 100644
index 0000000..2c75dc6
--- /dev/null
+++ b/src/dcn_v2_cuda.h
@@ -0,0 +1,35 @@
+// #ifndef DCN_V2_CUDA
+// #define DCN_V2_CUDA

+// #ifdef __cplusplus
+// extern "C"
+// {
+// #endif
+
+void dcn_v2_cuda_forward(THCudaTensor *input, THCudaTensor *weight,
+                         THCudaTensor *bias, THCudaTensor *ones,
+                         THCudaTensor *offset, THCudaTensor *mask,
+                         THCudaTensor *output, THCudaTensor *columns,
+                         int kernel_h, int kernel_w,
+                         const int stride_h, const int stride_w,
+                         const int pad_h, const int pad_w,
+                         const int dilation_h, const int dilation_w,
+                         const int deformable_group);
+void dcn_v2_cuda_backward(THCudaTensor *input, THCudaTensor *weight,
+                          THCudaTensor *bias, THCudaTensor *ones,
+                          THCudaTensor *offset, THCudaTensor *mask,
+                          THCudaTensor *columns,
+                          THCudaTensor *grad_input, THCudaTensor *grad_weight,
+                          THCudaTensor *grad_bias, THCudaTensor *grad_offset,
+                          THCudaTensor *grad_mask, THCudaTensor *grad_output,
+                          int kernel_h, int kernel_w,
+                          int stride_h, int stride_w,
+                          int pad_h, int pad_w,
+                          int dilation_h, int dilation_w,
+                          int deformable_group);
+
+// #ifdef __cplusplus
+// }
+// #endif
+
+// #endif
\ No newline at end of file
diff --git a/src/dcn_v2_cuda_double.c b/src/dcn_v2_cuda_double.c
new file mode 100644
index 0000000..1ea0821
--- /dev/null
+++ b/src/dcn_v2_cuda_double.c
@@ -0,0 +1,262 @@
+#include <THC/THC.h>
+#include "cuda/dcn_v2_im2col_cuda_double.h"
+
+extern THCState *state;
+
+// author: Charles Shang
+// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu
+
+void dcn_v2_cuda_forward(THCudaDoubleTensor *input, THCudaDoubleTensor *weight,
+                         THCudaDoubleTensor *bias, THCudaDoubleTensor *ones,
+                         THCudaDoubleTensor *offset, THCudaDoubleTensor *mask,
+                         THCudaDoubleTensor *output, THCudaDoubleTensor *columns,
+                         int kernel_h, int kernel_w,
+                         const int stride_h, const int stride_w,
+                         const int pad_h, const int pad_w,
+                         const int dilation_h, const int dilation_w,
+                         const int deformable_group)
+{
+    THCAssertSameGPU(THCudaDoubleTensor_checkGPU(state, 8, input, weight, bias, ones, offset, mask, output, columns));
+    THArgCheck(THCudaDoubleTensor_isContiguous(state, input), 1, "input tensor has to be contiguous");
+    THArgCheck(THCudaDoubleTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous");
+
+    input = THCudaDoubleTensor_newContiguous(state, input);
+    offset = THCudaDoubleTensor_newContiguous(state, offset);
+    mask = THCudaDoubleTensor_newContiguous(state, mask);
+    weight = THCudaDoubleTensor_newContiguous(state, weight);
+
+    const int batch = THCudaDoubleTensor_size(state, input, 0);
+    const int channels = THCudaDoubleTensor_size(state, input, 1);
+    const int height = THCudaDoubleTensor_size(state, input, 2);
+    const int width = THCudaDoubleTensor_size(state, input, 3);
+
+    const int channels_out = THCudaDoubleTensor_size(state, weight, 0);
+    const int channels_kernel = THCudaDoubleTensor_size(state, weight, 1);
+    const int kernel_h_ = THCudaDoubleTensor_size(state, weight, 2);
+    const int kernel_w_ = THCudaDoubleTensor_size(state, weight, 3);
+    if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+        THError("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+                kernel_h, kernel_w, kernel_h_, kernel_w_);
+    if (channels != channels_kernel)
+        THError("Input shape and kernel channels won't match: (%d vs %d).",
+                channels, channels_kernel);
+
+    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+    if (THCudaDoubleTensor__nDimension(state, ones) != 2 ||
+        THCudaDoubleTensor_size(state, ones, 0) * THCudaDoubleTensor_size(state, ones, 1) < height_out * width_out)
+    {
+        // Resize plane and fill with ones...
+        THCudaDoubleTensor_resize2d(state, ones, height_out, width_out);
+        THCudaDoubleTensor_fill(state, ones, 1);
+    }
+
+    // resize output
+    THCudaDoubleTensor_resize4d(state, output, batch, channels_out, height_out, width_out);
+    // resize temporary columns
+    THCudaDoubleTensor_resize2d(state, columns, channels * kernel_h * kernel_w, 1 * height_out * width_out);
+
+    THCudaDoubleTensor *input_n = THCudaDoubleTensor_new(state);
+    THCudaDoubleTensor *offset_n = THCudaDoubleTensor_new(state);
+    THCudaDoubleTensor *mask_n = THCudaDoubleTensor_new(state);
+    THCudaDoubleTensor *output_n = THCudaDoubleTensor_new(state);
+
+    for (int b = 0; b < batch; b++)
+    {
+        THCudaDoubleTensor_select(state, input_n, input, 0, b);
+        THCudaDoubleTensor_select(state, offset_n, offset, 0, b);
+        THCudaDoubleTensor_select(state, mask_n, mask, 0, b);
+        THCudaDoubleTensor_select(state, output_n, output, 0, b);
+
+        // Do Bias first:
+        // M,N,K are dims of matrix A and B
+        // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
+        // (N x 1) (1 x M)
+        long m_ = channels_out;
+        long n_ = height_out * width_out;
+        long k_ = 1;
+        THCudaBlas_Dgemm(state, 't', 'n', n_, m_, k_, 1.0,
+                         THCudaDoubleTensor_data(state, ones), k_,
+                         THCudaDoubleTensor_data(state, bias), k_, 0.0,
+                         THCudaDoubleTensor_data(state, output_n), n_);
+
+        modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
+                                         THCudaDoubleTensor_data(state, input_n), THCudaDoubleTensor_data(state, offset_n),
+                                         THCudaDoubleTensor_data(state, mask_n),
+                                         1, channels, height, width,
+                                         height_out, width_out, kernel_h, kernel_w,
+                                         pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
+                                         deformable_group, THCudaDoubleTensor_data(state, columns));
+
+        //(k * m) x (m * n)
+        // Y = WC
+        long m = channels_out;
+        long n = height_out * width_out;
+        long k = channels * kernel_h * kernel_w;
+        THCudaBlas_Dgemm(state, 'n', 'n', n, m, k, 1.0,
+                         THCudaDoubleTensor_data(state, columns), n,
+                         THCudaDoubleTensor_data(state, weight), k, 1.0,
+                         THCudaDoubleTensor_data(state, output_n), n);
+    }
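+
+    // The contiguous copies made by newContiguous above are reference-counted
+    // temporaries; they are released below together with the per-sample views.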
+    THCudaDoubleTensor_free(state, input_n);
+    THCudaDoubleTensor_free(state, offset_n);
+    THCudaDoubleTensor_free(state, mask_n);
+    THCudaDoubleTensor_free(state, output_n);
+
+    THCudaDoubleTensor_free(state, input);
+    THCudaDoubleTensor_free(state, offset);
+    THCudaDoubleTensor_free(state, mask);
+    THCudaDoubleTensor_free(state, weight);
+}
+
+void dcn_v2_cuda_backward(THCudaDoubleTensor *input, THCudaDoubleTensor *weight,
+                          THCudaDoubleTensor *bias, THCudaDoubleTensor *ones,
+                          THCudaDoubleTensor *offset, THCudaDoubleTensor *mask,
+                          THCudaDoubleTensor *columns,
+                          THCudaDoubleTensor *grad_input, THCudaDoubleTensor *grad_weight,
+                          THCudaDoubleTensor *grad_bias, THCudaDoubleTensor *grad_offset,
+                          THCudaDoubleTensor *grad_mask, THCudaDoubleTensor *grad_output,
+                          int kernel_h, int kernel_w,
+                          int stride_h, int stride_w,
+                          int pad_h, int pad_w,
+                          int dilation_h, int dilation_w,
+                          int deformable_group)
+{
+    THCAssertSameGPU(THCudaDoubleTensor_checkGPU(state, 13, input, weight, bias, ones, offset, mask, columns,
+                                                 grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output));
+    THArgCheck(THCudaDoubleTensor_isContiguous(state, input), 1, "input tensor has to be contiguous");
+    THArgCheck(THCudaDoubleTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous");
+
+    input = THCudaDoubleTensor_newContiguous(state, input);
+    offset = THCudaDoubleTensor_newContiguous(state, offset);
+    mask = THCudaDoubleTensor_newContiguous(state, mask);
+    weight = THCudaDoubleTensor_newContiguous(state, weight);
+    grad_output = THCudaDoubleTensor_newContiguous(state, grad_output);
+
+    const int batch = THCudaDoubleTensor_size(state, input, 0);
+    const int channels = THCudaDoubleTensor_size(state, input, 1);
+    const int height = THCudaDoubleTensor_size(state, input, 2);
+    const int width = THCudaDoubleTensor_size(state, input, 3);
+
+    const int channels_out = THCudaDoubleTensor_size(state, weight, 0);
+    const int channels_kernel = THCudaDoubleTensor_size(state, weight, 1);
+    const int kernel_h_ = THCudaDoubleTensor_size(state, weight, 2);
+    const int kernel_w_ = THCudaDoubleTensor_size(state, weight, 3);
+    if (kernel_h_ != kernel_h || kernel_w_ != kernel_w)
+        THError("Input shape and kernel shape won't match: (%d x %d vs %d x %d).",
+                kernel_h, kernel_w, kernel_h_, kernel_w_);
+    if (channels != channels_kernel)
+        THError("Input shape and kernel channels won't match: (%d vs %d).",
+                channels, channels_kernel);
+
+    const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
+    const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
+
+    if (THCudaDoubleTensor__nDimension(state, ones) != 2 ||
+        THCudaDoubleTensor_size(state, ones, 0) * THCudaDoubleTensor_size(state, ones, 1) < height_out * width_out)
+    {
+        // Resize plane and fill with ones...
+        THCudaDoubleTensor_resize2d(state, ones, height_out, width_out);
+        THCudaDoubleTensor_fill(state, ones, 1);
+    }
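+
+    // As in the forward pass, ones serves as an all-ones vector; the bias
+    // gradient at the end of the batch loop reduces grad_output over the
+    // spatial dimensions with a GEMV against it.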
+
+    // THCudaDoubleTensor_resize4d(state, grad_input, batch, channels, height, width);
+    THCudaDoubleTensor_resize2d(state, columns, channels * kernel_h * kernel_w, height_out * width_out);
+
+    THCudaDoubleTensor *input_n = THCudaDoubleTensor_new(state);
+    THCudaDoubleTensor *offset_n = THCudaDoubleTensor_new(state);
+    THCudaDoubleTensor *mask_n = THCudaDoubleTensor_new(state);
+
+    THCudaDoubleTensor *grad_output_n = THCudaDoubleTensor_new(state);
+    THCudaDoubleTensor *grad_input_n = THCudaDoubleTensor_new(state);
+    THCudaDoubleTensor *grad_offset_n = THCudaDoubleTensor_new(state);
+    THCudaDoubleTensor *grad_mask_n = THCudaDoubleTensor_new(state);
+
+    for (int b = 0; b < batch; b++)
+    {
+        THCudaDoubleTensor_select(state, input_n, input, 0, b);
+        THCudaDoubleTensor_select(state, offset_n, offset, 0, b);
+        THCudaDoubleTensor_select(state, mask_n, mask, 0, b);
+        THCudaDoubleTensor_select(state, grad_output_n, grad_output, 0, b);
+        THCudaDoubleTensor_select(state, grad_input_n, grad_input, 0, b);
+        THCudaDoubleTensor_select(state, grad_offset_n, grad_offset, 0, b);
+        THCudaDoubleTensor_select(state, grad_mask_n, grad_mask, 0, b);
+
+        long m = channels * kernel_h * kernel_w;
+        long n = height_out * width_out;
+        long k = channels_out;
+
+        THCudaBlas_Dgemm(state, 'n', 't', n, m, k, 1.0,
+                         THCudaDoubleTensor_data(state, grad_output_n), n,
+                         THCudaDoubleTensor_data(state, weight), m, 0.0,
+                         THCudaDoubleTensor_data(state, columns), n);
+
+        // gradient w.r.t. input offset and mask data
+        modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state),
+                                               THCudaDoubleTensor_data(state, columns),
+                                               THCudaDoubleTensor_data(state, input_n),
+                                               THCudaDoubleTensor_data(state, offset_n),
+                                               THCudaDoubleTensor_data(state, mask_n),
+                                               1, channels, height, width,
+                                               height_out, width_out, kernel_h, kernel_w,
+                                               pad_h, pad_w, stride_h, stride_w,
+                                               dilation_h, dilation_w, deformable_group,
+                                               THCudaDoubleTensor_data(state, grad_offset_n),
+                                               THCudaDoubleTensor_data(state, grad_mask_n));
+        // gradient w.r.t. input data
+        modulated_deformable_col2im_cuda(THCState_getCurrentStream(state),
+                                         THCudaDoubleTensor_data(state, columns),
+                                         THCudaDoubleTensor_data(state, offset_n),
+                                         THCudaDoubleTensor_data(state, mask_n),
+                                         1, channels, height, width,
+                                         height_out, width_out, kernel_h, kernel_w,
+                                         pad_h, pad_w, stride_h, stride_w,
+                                         dilation_h, dilation_w, deformable_group,
+                                         THCudaDoubleTensor_data(state, grad_input_n));
+
+        // gradient w.r.t. weight, dWeight should accumulate across the batch and group
+        modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
+                                         THCudaDoubleTensor_data(state, input_n),
+                                         THCudaDoubleTensor_data(state, offset_n),
+                                         THCudaDoubleTensor_data(state, mask_n),
+                                         1, channels, height, width,
+                                         height_out, width_out, kernel_h, kernel_w,
+                                         pad_h, pad_w, stride_h, stride_w,
+                                         dilation_h, dilation_w, deformable_group,
+                                         THCudaDoubleTensor_data(state, columns));
+        long m_ = channels_out;
+        long n_ = channels * kernel_h * kernel_w;
+        long k_ = height_out * width_out;
+
+        THCudaBlas_Dgemm(state, 't', 'n', n_, m_, k_, 1.0,
+                         THCudaDoubleTensor_data(state, columns), k_,
+                         THCudaDoubleTensor_data(state, grad_output_n), k_, 1.0,
+                         THCudaDoubleTensor_data(state, grad_weight), n_);
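+
+        // grad_bias[c] = sum of grad_output[c, :] over all output positions,
+        // computed as a GEMV of grad_output^T against the ones vector.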
+
+        // gradient w.r.t. bias
+        // long m_ = channels_out;
+        // long k__ = height_out * width_out;
+        THCudaBlas_Dgemv(state,
+                         't',
+                         k_, m_, 1.0,
+                         THCudaDoubleTensor_data(state, grad_output_n), k_,
+                         THCudaDoubleTensor_data(state, ones), 1, 1.0,
+                         THCudaDoubleTensor_data(state, grad_bias), 1);
+    }
+
+    THCudaDoubleTensor_free(state, input_n);
+    THCudaDoubleTensor_free(state, offset_n);
+    THCudaDoubleTensor_free(state, mask_n);
+
+    THCudaDoubleTensor_free(state, grad_output_n);
+    THCudaDoubleTensor_free(state, grad_input_n);
+    THCudaDoubleTensor_free(state, grad_offset_n);
+    THCudaDoubleTensor_free(state, grad_mask_n);
+
+    THCudaDoubleTensor_free(state, input);
+    THCudaDoubleTensor_free(state, offset);
+    THCudaDoubleTensor_free(state, mask);
+    THCudaDoubleTensor_free(state, weight);
+    THCudaDoubleTensor_free(state, grad_output);
+}
\ No newline at end of file
diff --git a/src/dcn_v2_cuda_double.h b/src/dcn_v2_cuda_double.h
new file mode 100644
index 0000000..18eb45b
--- /dev/null
+++ b/src/dcn_v2_cuda_double.h
@@ -0,0 +1,35 @@
+// #ifndef DCN_V2_CUDA
+// #define DCN_V2_CUDA
+
+// #ifdef __cplusplus
+// extern "C"
+// {
+// #endif
+
+void dcn_v2_cuda_forward(THCudaDoubleTensor *input, THCudaDoubleTensor *weight,
+                         THCudaDoubleTensor *bias, THCudaDoubleTensor *ones,
+                         THCudaDoubleTensor *offset, THCudaDoubleTensor *mask,
+                         THCudaDoubleTensor *output, THCudaDoubleTensor *columns,
+                         int kernel_h, int kernel_w,
+                         const int stride_h, const int stride_w,
+                         const int pad_h, const int pad_w,
+                         const int dilation_h, const int dilation_w,
+                         const int deformable_group);
+void dcn_v2_cuda_backward(THCudaDoubleTensor *input, THCudaDoubleTensor *weight,
+                          THCudaDoubleTensor *bias, THCudaDoubleTensor *ones,
+                          THCudaDoubleTensor *offset, THCudaDoubleTensor *mask,
+                          THCudaDoubleTensor *columns,
+                          THCudaDoubleTensor *grad_input, THCudaDoubleTensor *grad_weight,
+                          THCudaDoubleTensor *grad_bias, THCudaDoubleTensor *grad_offset,
+                          THCudaDoubleTensor *grad_mask, THCudaDoubleTensor *grad_output,
+                          int kernel_h, int kernel_w,
+                          int stride_h, int stride_w,
+                          int pad_h, int pad_w,
+                          int dilation_h, int dilation_w,
+                          int deformable_group);
+
+// #ifdef __cplusplus
+// }
+// #endif
+
+// #endif
\ No newline at end of file
diff --git a/src/dcn_v2_double.c b/src/dcn_v2_double.c
new file mode 100644
index 0000000..2b86545
--- /dev/null
+++ b/src/dcn_v2_double.c
@@ -0,0 +1,30 @@
+#include <TH/TH.h>
+#include <stdio.h>
+#include <math.h>
+
+void dcn_v2_forward(THDoubleTensor *input, THDoubleTensor *weight,
+                    THDoubleTensor *bias, THDoubleTensor *ones,
+                    THDoubleTensor *offset, THDoubleTensor *mask,
+                    THDoubleTensor *output, THDoubleTensor *columns,
+                    const int pad_h, const int pad_w,
+                    const int stride_h, const int stride_w,
+                    const int dilation_h, const int dilation_w,
+                    const int deformable_group)
+{
+    printf("only implemented on GPU");
+}
+void dcn_v2_backward(THDoubleTensor *input, THDoubleTensor *weight,
+                     THDoubleTensor *bias, THDoubleTensor *ones,
+                     THDoubleTensor *offset, THDoubleTensor *mask,
+                     THDoubleTensor *output, THDoubleTensor *columns,
+                     THDoubleTensor *grad_input, THDoubleTensor *grad_weight,
+                     THDoubleTensor *grad_bias, THDoubleTensor *grad_offset,
+                     THDoubleTensor *grad_mask, THDoubleTensor *grad_output,
+                     int kernel_h, int kernel_w,
+                     int stride_h, int stride_w,
+                     int pad_h, int pad_w,
+                     int dilation_h, int dilation_w,
+                     int deformable_group)
+{
+    printf("only implemented on GPU");
+}
\ No newline at end of file
diff --git a/src/dcn_v2_double.h b/src/dcn_v2_double.h
new file mode 100644
index 0000000..eda1f4c
--- /dev/null
+++ b/src/dcn_v2_double.h
@@ -0,0 +1,20 @@
+void dcn_v2_forward(THDoubleTensor *input, THDoubleTensor *weight,
+                    THDoubleTensor *bias, THDoubleTensor *ones,
+                    THDoubleTensor *offset, THDoubleTensor *mask,
+                    THDoubleTensor *output, THDoubleTensor *columns,
+                    const int pad_h, const int pad_w,
+                    const int stride_h, const int stride_w,
+                    const int dilation_h, const int dilation_w,
+                    const int deformable_group);
+void dcn_v2_backward(THDoubleTensor *input, THDoubleTensor *weight,
+                     THDoubleTensor *bias, THDoubleTensor *ones,
+                     THDoubleTensor *offset, THDoubleTensor *mask,
+                     THDoubleTensor *output, THDoubleTensor *columns,
+                     THDoubleTensor *grad_input, THDoubleTensor *grad_weight,
+                     THDoubleTensor *grad_bias, THDoubleTensor *grad_offset,
+                     THDoubleTensor *grad_mask, THDoubleTensor *grad_output,
+                     int kernel_h, int kernel_w,
+                     int stride_h, int stride_w,
+                     int pad_h, int pad_w,
+                     int dilation_h, int dilation_w,
+                     int deformable_group);
\ No newline at end of file
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..bc71846
--- /dev/null
+++ b/test.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python
+from __future__ import absolute_import
+from __future__ import print_function
+from __future__ import division
+
+import time
+import torch
+import torch.nn as nn
+from torch.autograd import gradcheck
+
+from dcn_v2 import DCNv2
+from dcn_v2_func import DCNv2Function
+
+deformable_groups = 1
+N, inC, inH, inW = 2, 2, 4, 4
+outC = 2
+kH, kW = 3, 3
+
+def conv_identify(weight, bias):
+    weight.data.zero_()
+    bias.data.zero_()
+    o, i, h, w = weight.shape
+    y = h//2
+    x = w//2
+    for p in range(i):
+        for q in range(o):
+            if p == q:
+                weight.data[q, p, y, x] = 1.0
+
+def check_zero_offset():
+    conv_offset = nn.Conv2d(inC, deformable_groups * 2 * kH * kW,
+                            kernel_size=(kH, kW),
+                            stride=(1, 1),
+                            padding=(1, 1),
+                            bias=True).cuda()
+
+    conv_mask = nn.Conv2d(inC, deformable_groups * 1 * kH * kW,
+                          kernel_size=(kH, kW),
+                          stride=(1, 1),
+                          padding=(1, 1),
+                          bias=True).cuda()
+
+    dcn_v2 = DCNv2(inC, outC, (kH, kW),
+                   stride=1, padding=1, dilation=1,
+                   deformable_groups=deformable_groups).cuda()
+
+    conv_offset.weight.data.zero_()
+    conv_offset.bias.data.zero_()
+    conv_mask.weight.data.zero_()
+    conv_mask.bias.data.zero_()
+    conv_identify(dcn_v2.weight, dcn_v2.bias)
+
+    input = torch.randn(N, inC, inH, inW).cuda()
+    offset = conv_offset(input)
+    mask = conv_mask(input)
+    mask = torch.sigmoid(mask)
+    output = dcn_v2(input, offset, mask)
+    output *= 2
+    d = (input - output).abs().max()
+    if d < 1e-10:
+        print('Zero offset passed')
+    else:
+        print('Zero offset failed')
+
+def check_gradient_double():
+
+    input = torch.randn(N, inC, inH, inW, dtype=torch.float64).cuda()
+    input.requires_grad = True
+
+    offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW, dtype=torch.float64).cuda()
+    offset.data.zero_()
+    offset.data -= 0.5
+    offset.requires_grad = True
+
+    mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW, dtype=torch.float64).cuda()
+    # mask.data.zero_()
+    mask.requires_grad = True
+    mask = torch.sigmoid(mask)
+
+    weight = torch.randn(outC, inC, kH, kW, dtype=torch.float64).cuda()
+    weight.requires_grad = True
+
+    bias = torch.rand(outC, dtype=torch.float64).cuda()
+    bias.requires_grad = True
+
+    func = DCNv2Function(stride=1, padding=1, dilation=1, deformable_groups=deformable_groups)
+
+    print(gradcheck(func, (input, offset, mask, weight, bias), eps=1e-6, atol=1e-5, rtol=1e-3))
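+
+# Single-precision variant of the check above. float32 cannot satisfy
+# gradcheck's default thresholds, so the tolerances below (eps=1e-3,
+# atol=1e-3, rtol=1e-2) are deliberately looser than in the double version.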
+def check_gradient():
+
+    input = torch.randn(N, inC, inH, inW).cuda()
+    input.requires_grad = True
+
+    offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda()
+    offset.data.zero_()
+    offset.data -= 0.5
+    offset.requires_grad = True
+
+    mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda()
+    # mask.data.zero_()
+    mask.requires_grad = True
+    mask = torch.sigmoid(mask)
+
+    weight = torch.randn(outC, inC, kH, kW).cuda()
+    weight.requires_grad = True
+
+    bias = torch.rand(outC).cuda()
+    bias.requires_grad = True
+
+    func = DCNv2Function(stride=1, padding=1, dilation=1, deformable_groups=deformable_groups)
+
+    print(gradcheck(func, (input, offset, mask, weight, bias), eps=1e-3, atol=1e-3, rtol=1e-2))
+
+
+if __name__ == '__main__':
+    if inC == outC:
+        check_zero_offset()
+    try:
+        check_gradient_double()
+    except TypeError:
+        print('''You can switch to double precision in dcn_v2_func.py by (un)commenting these two lines:
+    from _ext import dcn_v2 as _backend
+    from _ext import dcn_v2_double as _backend''')
+        print('Your tensor may not be **double** type')
+        print('Switching to **float** type')
+
+        check_gradient()
+    finally:
+        print('Note: the "Backward is not reentrant" error may not be a serious problem, '
+              'since the max error is less than 1e-7\n'
+              'Still looking for what triggers this problem')