Skip to content

Commit

Permalink
DCN
Browse files Browse the repository at this point in the history
  • Loading branch information
CharlesShang committed Dec 5, 2018
0 parents commit 490610f
Show file tree
Hide file tree
Showing 22 changed files with 2,041 additions and 0 deletions.
20 changes: 20 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
## Deformable Convolutional Networks V2 with Pytorch
```bash
.\make # build
python test.py # run gradient check
```
### Known issues:

- [ ] Gradient check w.r.t. offset.
- [ ] Backward is not reentrant.

This is an adaptation of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op).
I have run the gradient check many times with the DOUBLE type. Every tensor **except offset** passes.
However, when I set the offset to 0.5, it passes. I'm still wondering what causes this problem. Is it because of some
non-differentiable points?

Another issue is that it raises `RuntimeError: Backward is not reentrant`. However, the error is very small `(<1e-7)`,
so it may not be a serious problem (?)

Please post an issue or PR if you have any comments.

Empty file added __init__.py
Empty file.
42 changes: 42 additions & 0 deletions build.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
import torch
from torch.utils.ffi import create_extension


# Build configuration for the `_ext.dcn_v2` FFI extension.
#
# The extension links against a pre-built CUDA object file
# (src/cuda/dcn_v2_im2col_cuda.cu.o), so there is no CPU-only variant:
# bail out before configuring anything when CUDA is missing.
if not torch.cuda.is_available():
    raise ValueError('CUDA is not available')

print('Including CUDA code.')
with_cuda = True
defines = [('WITH_CUDA', None)]
extra_compile_args = ['-fopenmp', '-std=c99']

# Every path handed to the builder is anchored at this script's directory.
this_file = os.path.dirname(os.path.realpath(__file__))
print(this_file)

sources = [os.path.join(this_file, rel)
           for rel in ('src/dcn_v2.c', 'src/dcn_v2_cuda.c')]
headers = [os.path.join(this_file, rel)
           for rel in ('src/dcn_v2.h', 'src/dcn_v2_cuda.h')]
extra_objects = [os.path.join(this_file, rel)
                 for rel in ('src/cuda/dcn_v2_im2col_cuda.cu.o',)]

ffi = create_extension(
    '_ext.dcn_v2',
    sources=sources,
    headers=headers,
    extra_objects=extra_objects,
    define_macros=defines,
    with_cuda=with_cuda,
    relative_to=__file__,
    extra_compile_args=extra_compile_args
)

# Only compile when executed as a script; importing just exposes `ffi`.
if __name__ == '__main__':
    ffi.build()
42 changes: 42 additions & 0 deletions build_double.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import os
import torch
from torch.utils.ffi import create_extension


# Build configuration for the `_ext.dcn_v2_double` FFI extension.
#
# The extension links against a pre-built CUDA object file
# (src/cuda/dcn_v2_im2col_cuda_double.cu.o), so there is no CPU-only
# variant: bail out before configuring anything when CUDA is missing.
if not torch.cuda.is_available():
    raise ValueError('CUDA is not available')

print('Including CUDA code.')
with_cuda = True
defines = [('WITH_CUDA', None)]
extra_compile_args = ['-fopenmp', '-std=c99']

# Every path handed to the builder is anchored at this script's directory.
this_file = os.path.dirname(os.path.realpath(__file__))
print(this_file)

sources = [os.path.join(this_file, rel)
           for rel in ('src/dcn_v2_double.c', 'src/dcn_v2_cuda_double.c')]
headers = [os.path.join(this_file, rel)
           for rel in ('src/dcn_v2_double.h', 'src/dcn_v2_cuda_double.h')]
extra_objects = [os.path.join(this_file, rel)
                 for rel in ('src/cuda/dcn_v2_im2col_cuda_double.cu.o',)]

ffi = create_extension(
    '_ext.dcn_v2_double',
    sources=sources,
    headers=headers,
    extra_objects=extra_objects,
    define_macros=defines,
    with_cuda=with_cuda,
    relative_to=__file__,
    extra_compile_args=extra_compile_args
)

# Only compile when executed as a script; importing just exposes `ffi`.
if __name__ == '__main__':
    ffi.build()
68 changes: 68 additions & 0 deletions dcn_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import math
from torch import nn
from torch.nn.modules.utils import _pair

from dcn_v2_func import DCNv2Function

class DCNv2(nn.Module):
    """Deformable convolution (v2) layer with externally supplied offset/mask.

    Owns a learnable ``weight`` of shape
    ``(out_channels, in_channels, *kernel_size)`` and a ``bias`` of shape
    ``(out_channels,)``; the actual computation is delegated to
    ``DCNv2Function``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel extent; normalised with ``_pair``.
        stride: stored as given and forwarded to ``DCNv2Function``.
        padding: stored as given and forwarded to ``DCNv2Function``.
        dilation: stored as given and forwarded to ``DCNv2Function``.
        deformable_groups: stored as given and forwarded to
            ``DCNv2Function``.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, stride, padding, dilation=1, deformable_groups=1):
        super(DCNv2, self).__init__()
        self.in_channels, self.out_channels = in_channels, out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride, self.padding, self.dilation = stride, padding, dilation
        self.deformable_groups = deformable_groups

        # Allocated uninitialised here; reset_parameters() fills them in.
        weight_shape = (out_channels, in_channels) + tuple(self.kernel_size)
        self.weight = nn.Parameter(torch.Tensor(*weight_shape))
        self.bias = nn.Parameter(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw ``weight`` from U(-b, b) with b = 1/sqrt(fan_in); zero ``bias``.

        ``fan_in`` is ``in_channels`` multiplied by every kernel extent.
        """
        fan_in = self.in_channels
        for extent in self.kernel_size:
            fan_in = fan_in * extent
        bound = 1. / math.sqrt(fan_in)
        self.bias.data.zero_()
        self.weight.data.uniform_(-bound, bound)

    def forward(self, input, offset, mask):
        """Run ``DCNv2Function`` on ``input`` with the given offset and mask.

        A fresh function object is built on every call from the layer's
        stride, padding, dilation and deformable_groups.
        """
        op = DCNv2Function(self.stride, self.padding,
                           self.dilation, self.deformable_groups)
        return op(input, offset, mask, self.weight, self.bias)


class DCN(DCNv2):
    """DCNv2 layer that predicts its own offsets and mask from the input.

    A regular ``nn.Conv2d`` (``conv_offset_mask``) produces
    ``deformable_groups * 3 * kh * kw`` channels, which ``forward`` splits
    into three equal chunks along the channel axis (dy, dx, mask) before
    calling ``DCNv2Function``.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size: kernel extent; normalised to a pair by the base class.
        stride: scalar stride, applied to both spatial dimensions.
        padding: scalar padding, applied to both spatial dimensions.
        dilation: scalar dilation, applied to both spatial dimensions.
        deformable_groups: forwarded to ``DCNv2Function`` and used to size
            the offset/mask predictor.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size, stride, padding,
                 dilation=1, deformable_groups=1):
        # The base constructor calls self.reset_parameters(); the override
        # below tolerates conv_offset_mask not existing yet at that point.
        super(DCN, self).__init__(in_channels, out_channels,
                                  kernel_size, stride, padding, dilation, deformable_groups)

        # The predictor must produce one offset/mask vector per output
        # location, so it uses the same stride, padding and dilation as the
        # deformable convolution itself (without dilation the two output
        # grids differ whenever dilation != 1).
        self.conv_offset_mask = nn.Conv2d(self.in_channels,
                                          self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
                                          kernel_size=self.kernel_size,
                                          stride=(self.stride, self.stride),
                                          padding=(self.padding, self.padding),
                                          dilation=(self.dilation, self.dilation),
                                          bias=True)
        self.init_offset()

    def init_offset(self):
        """Zero the offset/mask predictor so it initially outputs all zeros."""
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def reset_parameters(self):
        """Re-initialise ``weight``/``bias`` and zero the offset/mask predictor.

        The base constructor invokes this before ``conv_offset_mask`` has
        been created, so the predictor is only touched once it exists.
        Previously this override skipped the base initialisation (leaving
        ``weight``/``bias`` uninitialised) and raised ``AttributeError``
        during construction.
        """
        super(DCN, self).reset_parameters()
        if hasattr(self, 'conv_offset_mask'):
            self.init_offset()

    def forward(self, input):
        """Predict offset and mask from ``input`` and apply ``DCNv2Function``."""
        out = self.conv_offset_mask(input)
        dy, dx, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((dy, dx), dim=1)
        # NOTE(review): the mask is passed on as the raw conv output, which
        # is all zeros right after initialisation. Confirm whether it should
        # be squashed (e.g. with a sigmoid) before use.
        func = DCNv2Function(self.stride, self.padding, self.dilation, self.deformable_groups)
        return func(input, offset, mask, self.weight, self.bias)
72 changes: 72 additions & 0 deletions dcn_v2_func.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
from torch.autograd import Function

from _ext import dcn_v2 as _backend
# from _ext import dcn_v2_double as _backend

class DCNv2Function(Function):
    """Autograd function that dispatches to the ``_backend`` CUDA routines.

    Instances carry the convolution hyper-parameters; each scalar is passed
    to the backend twice (once per spatial dimension).

    Args:
        stride: scalar stride.
        padding: scalar padding.
        dilation: scalar dilation.
        deformable_groups: forwarded unchanged to the backend.
    """

    def __init__(self, stride, padding, dilation=1, deformable_groups=1):
        super(DCNv2Function, self).__init__()
        self.stride, self.padding = stride, padding
        self.dilation, self.deformable_groups = dilation, deformable_groups

    def forward(self, input, offset, mask, weight, bias):
        """Compute the output via ``_backend.dcn_v2_cuda_forward``.

        Raises:
            NotImplementedError: if ``input`` is not a CUDA tensor.
        """
        if not input.is_cuda:
            raise NotImplementedError
        # Keep the operands for backward only when a gradient is wanted.
        if any(t.requires_grad for t in (weight, mask, offset, input)):
            self.save_for_backward(input, offset, mask, weight, bias)

        out_shape = self._infer_shape(input, weight)
        output = input.new(*out_shape)
        # Two empty scratch tensors handed to the backend; they are kept on
        # the instance because backward passes the same objects again.
        self._bufs = [input.new() for _ in range(2)]
        kernel_h, kernel_w = weight.shape[2], weight.shape[3]
        _backend.dcn_v2_cuda_forward(
            input, weight, bias, self._bufs[0],
            offset, mask,
            output, self._bufs[1],
            kernel_h, kernel_w,
            self.stride, self.stride,
            self.padding, self.padding,
            self.dilation, self.dilation,
            self.deformable_groups)
        return output

    def backward(self, grad_output):
        """Compute gradients via ``_backend.dcn_v2_cuda_backward``.

        Returns:
            Gradients for ``(input, offset, mask, weight, bias)``, in the
            order of ``forward``'s arguments.

        Raises:
            NotImplementedError: if ``grad_output`` is not a CUDA tensor.
        """
        if not grad_output.is_cuda:
            raise NotImplementedError
        input, offset, mask, weight, bias = self.saved_tensors
        # Zero-filled gradient buffers, one per saved tensor, same sizes.
        grad_input, grad_offset, grad_mask, grad_weight, grad_bias = [
            saved.new(*saved.size()).zero_()
            for saved in (input, offset, mask, weight, bias)
        ]
        kernel_h, kernel_w = weight.shape[2], weight.shape[3]
        _backend.dcn_v2_cuda_backward(
            input, weight, bias, self._bufs[0],
            offset, mask,
            self._bufs[1],
            grad_input, grad_weight,
            grad_bias, grad_offset,
            grad_mask, grad_output,
            kernel_h, kernel_w,
            self.stride, self.stride,
            self.padding, self.padding,
            self.dilation, self.dilation,
            self.deformable_groups)
        return grad_input, grad_offset, grad_mask, grad_weight, grad_bias

    def _infer_shape(self, input, weight):
        """Return the output shape ``(n, channels_out, height_out, width_out)``.

        Each spatial extent follows the usual convolution formula
        ``(size + 2*padding - (dilation*(kernel-1) + 1)) // stride + 1``.
        """
        batch, channels_out = input.size(0), weight.size(0)
        spatial = []
        for extent, kernel in zip(input.shape[2:4], weight.shape[2:4]):
            span = self.dilation * (kernel - 1) + 1
            spatial.append((extent + 2 * self.padding - span) // self.stride + 1)
        height_out, width_out = spatial
        return (batch, channels_out, height_out, width_out)

7 changes: 7 additions & 0 deletions mask.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#!/usr/bin/env bash
# Build both DCNv2 extensions: compile the CUDA kernels to object files,
# then run the Python FFI build scripts that link against them.
# Must be run from the repository root (paths below are relative).

# Stop at the first failing step so a failed `cd` or nvcc compile does not
# let the Python builds run against missing or stale object files.
set -e

cd src/cuda
nvcc -c -o dcn_v2_im2col_cuda.cu.o dcn_v2_im2col_cuda.cu -x cu -Xcompiler -fPIC
nvcc -c -o dcn_v2_im2col_cuda_double.cu.o dcn_v2_im2col_cuda_double.cu -x cu -Xcompiler -fPIC
cd -
python build.py
python build_double.py
Loading

0 comments on commit 490610f

Please sign in to comment.