forked from CharlesShang/DCNv2
Commit 490610f (0 parents)
Showing 22 changed files with 2,041 additions and 0 deletions.
## Deformable Convolutional Networks V2 with Pytorch

```bash
./make.sh         # build
python test.py    # run gradient check
```

### Known issues:

- [ ] Gradient check w.r.t. offset.
- [ ] Backward is not reentrant.

This is an adaptation of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op).
I have run the gradient check many times with the DOUBLE type. Every tensor **except offset** passes.
However, when I set the offset to 0.5, it passes. I'm still wondering what causes this problem. Is it because of some
non-differentiable points?

Another issue is that it raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7`),
so it may not be a serious problem (?)

Please post an issue or PR if you have any comments.
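To make the gradient-check workflow above concrete, here is a minimal sketch of a double-precision `gradcheck` over the `DCNv2` module. It is not part of this commit; it assumes a CUDA device, a `torch.utils.ffi`-era PyTorch that can load the built extension, and the offset/mask channel layout used by `DCN.forward` below.

```python
# Hypothetical gradient-check sketch (not part of this commit). Assumes a CUDA
# device and a torch.utils.ffi-era PyTorch able to load the built extension.
import torch
from torch.autograd import gradcheck

from dcn_v2 import DCNv2

deformable_groups = 1
N, C_in, C_out, H, W = 2, 4, 4, 8, 8
k, stride, padding = 3, 1, 1

dcn = DCNv2(C_in, C_out, k, stride, padding,
            dilation=1, deformable_groups=deformable_groups).cuda().double()

inp = torch.randn(N, C_in, H, W).double().cuda()
# Per deformable group: 2*k*k offset channels (dy then dx) and k*k mask channels.
offset = torch.randn(N, deformable_groups * 2 * k * k, H, W).double().cuda()
mask = torch.rand(N, deformable_groups * k * k, H, W).double().cuda()
for t in (inp, offset, mask):
    t.requires_grad = True

# Compares analytic and numerical gradients in double precision; as noted above,
# the check w.r.t. offset can fail, presumably near bilinear-sampling kinks.
print(gradcheck(dcn, (inp, offset, mask)))
```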
```python
# Build script for the single-precision FFI extension (_ext.dcn_v2); invoked
# from make.sh after the CUDA object file has been compiled with nvcc.
import os
import torch
from torch.utils.ffi import create_extension

sources = ['src/dcn_v2.c']
headers = ['src/dcn_v2.h']
defines = []
with_cuda = False

extra_objects = []
if torch.cuda.is_available():
    print('Including CUDA code.')
    sources += ['src/dcn_v2_cuda.c']
    headers += ['src/dcn_v2_cuda.h']
    defines += [('WITH_CUDA', None)]
    extra_objects += ['src/cuda/dcn_v2_im2col_cuda.cu.o']
    with_cuda = True
else:
    raise ValueError('CUDA is not available')

extra_compile_args = ['-fopenmp', '-std=c99']

this_file = os.path.dirname(os.path.realpath(__file__))
print(this_file)
sources = [os.path.join(this_file, fname) for fname in sources]
headers = [os.path.join(this_file, fname) for fname in headers]
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]

ffi = create_extension(
    '_ext.dcn_v2',
    headers=headers,
    sources=sources,
    define_macros=defines,
    relative_to=__file__,
    with_cuda=with_cuda,
    extra_objects=extra_objects,
    extra_compile_args=extra_compile_args
)

if __name__ == '__main__':
    ffi.build()
```
```python
# Same as the build script above, but for the double-precision extension
# _ext.dcn_v2_double (see the commented-out backend import in DCNv2Function below).
import os
import torch
from torch.utils.ffi import create_extension

sources = ['src/dcn_v2_double.c']
headers = ['src/dcn_v2_double.h']
defines = []
with_cuda = False

extra_objects = []
if torch.cuda.is_available():
    print('Including CUDA code.')
    sources += ['src/dcn_v2_cuda_double.c']
    headers += ['src/dcn_v2_cuda_double.h']
    defines += [('WITH_CUDA', None)]
    extra_objects += ['src/cuda/dcn_v2_im2col_cuda_double.cu.o']
    with_cuda = True
else:
    raise ValueError('CUDA is not available')

extra_compile_args = ['-fopenmp', '-std=c99']

this_file = os.path.dirname(os.path.realpath(__file__))
print(this_file)
sources = [os.path.join(this_file, fname) for fname in sources]
headers = [os.path.join(this_file, fname) for fname in headers]
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]

ffi = create_extension(
    '_ext.dcn_v2_double',
    headers=headers,
    sources=sources,
    define_macros=defines,
    relative_to=__file__,
    with_cuda=with_cuda,
    extra_objects=extra_objects,
    extra_compile_args=extra_compile_args
)

if __name__ == '__main__':
    ffi.build()
```
```python
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
import math
from torch import nn
from torch.nn.modules.utils import _pair

from dcn_v2_func import DCNv2Function


class DCNv2(nn.Module):
    """Deformable convolution v2 that takes precomputed offsets and masks."""

    def __init__(self, in_channels, out_channels,
                 kernel_size, stride, padding, dilation=1, deformable_groups=1):
        super(DCNv2, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.deformable_groups = deformable_groups

        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))
        self.bias = nn.Parameter(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        self.bias.data.zero_()

    def forward(self, input, offset, mask):
        func = DCNv2Function(self.stride, self.padding, self.dilation, self.deformable_groups)
        return func(input, offset, mask, self.weight, self.bias)


class DCN(DCNv2):
    """Deformable convolution v2 that predicts its own offsets and mask."""

    def __init__(self, in_channels, out_channels,
                 kernel_size, stride, padding,
                 dilation=1, deformable_groups=1):
        super(DCN, self).__init__(in_channels, out_channels,
                                  kernel_size, stride, padding, dilation, deformable_groups)

        # Predicts dy, dx and the modulation mask: 3 * k_h * k_w channels per
        # deformable group, split apart in forward().
        self.conv_offset_mask = nn.Conv2d(self.in_channels,
                                          self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1],
                                          kernel_size=self.kernel_size,
                                          stride=(self.stride, self.stride),
                                          padding=(self.padding, self.padding),
                                          bias=True)
        self.reset_parameters()

    def reset_parameters(self):
        super(DCN, self).reset_parameters()
        # DCNv2.__init__ already calls reset_parameters() before
        # conv_offset_mask exists, so guard the first call.
        if hasattr(self, 'conv_offset_mask'):
            self.conv_offset_mask.weight.data.zero_()
            self.conv_offset_mask.bias.data.zero_()

    def forward(self, input):
        out = self.conv_offset_mask(input)
        dy, dx, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((dy, dx), dim=1)
        func = DCNv2Function(self.stride, self.padding, self.dilation, self.deformable_groups)
        return func(input, offset, mask, self.weight, self.bias)
```
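A hypothetical usage sketch for the two modules above (illustrative shapes only; a CUDA device is required since the underlying function is CUDA-only). `DCN` predicts its own offsets and mask, while `DCNv2` expects `deformable_groups * 2 * k_h * k_w` offset channels and `deformable_groups * k_h * k_w` mask channels:

```python
# Hypothetical usage sketch (not part of this commit); assumes a CUDA device.
import torch
from dcn_v2 import DCN, DCNv2

x = torch.randn(2, 64, 128, 128).cuda()

# DCN predicts offsets and mask internally via conv_offset_mask.
dcn = DCN(64, 64, kernel_size=3, stride=1, padding=1, deformable_groups=2).cuda()
y = dcn(x)  # shape (2, 64, 128, 128)

# DCNv2 takes offsets and mask explicitly:
#   offset: deformable_groups * 2 * k_h * k_w channels (dy then dx)
#   mask:   deformable_groups * k_h * k_w channels, typically squashed to [0, 1]
dcn_v2 = DCNv2(64, 64, kernel_size=3, stride=1, padding=1,
               dilation=1, deformable_groups=2).cuda()
offset = torch.randn(2, 2 * 2 * 3 * 3, 128, 128).cuda()
mask = torch.sigmoid(torch.randn(2, 2 * 3 * 3, 128, 128)).cuda()
y2 = dcn_v2(x, offset, mask)  # shape (2, 64, 128, 128)
```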
```python
#!/usr/bin/env python
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import torch
from torch.autograd import Function

from _ext import dcn_v2 as _backend
# from _ext import dcn_v2_double as _backend


class DCNv2Function(Function):
    """Legacy-style autograd Function wrapping the compiled CUDA backend."""

    def __init__(self, stride, padding, dilation=1, deformable_groups=1):
        super(DCNv2Function, self).__init__()
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.deformable_groups = deformable_groups

    def forward(self, input, offset, mask, weight, bias):
        if not input.is_cuda:
            raise NotImplementedError
        if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad:
            self.save_for_backward(input, offset, mask, weight, bias)
        output = input.new(*self._infer_shape(input, weight))
        # Scratch buffers passed to the CUDA kernels.
        self._bufs = [input.new(), input.new()]
        _backend.dcn_v2_cuda_forward(input, weight,
                                     bias, self._bufs[0],
                                     offset, mask,
                                     output, self._bufs[1],
                                     weight.shape[2], weight.shape[3],
                                     self.stride, self.stride,
                                     self.padding, self.padding,
                                     self.dilation, self.dilation,
                                     self.deformable_groups)
        return output

    def backward(self, grad_output):
        if not grad_output.is_cuda:
            raise NotImplementedError
        input, offset, mask, weight, bias = self.saved_tensors
        grad_input = input.new(*input.size()).zero_()
        grad_offset = offset.new(*offset.size()).zero_()
        grad_mask = mask.new(*mask.size()).zero_()
        grad_weight = weight.new(*weight.size()).zero_()
        grad_bias = bias.new(*bias.size()).zero_()
        _backend.dcn_v2_cuda_backward(input, weight,
                                      bias, self._bufs[0],
                                      offset, mask,
                                      self._bufs[1],
                                      grad_input, grad_weight,
                                      grad_bias, grad_offset,
                                      grad_mask, grad_output,
                                      weight.shape[2], weight.shape[3],
                                      self.stride, self.stride,
                                      self.padding, self.padding,
                                      self.dilation, self.dilation,
                                      self.deformable_groups)

        return grad_input, grad_offset, grad_mask, grad_weight, grad_bias

    def _infer_shape(self, input, weight):
        # Standard convolution output-size arithmetic.
        n = input.size(0)
        channels_out = weight.size(0)
        height, width = input.shape[2:4]
        kernel_h, kernel_w = weight.shape[2:4]
        height_out = (height + 2 * self.padding - (self.dilation * (kernel_h - 1) + 1)) // self.stride + 1
        width_out = (width + 2 * self.padding - (self.dilation * (kernel_w - 1) + 1)) // self.stride + 1
        return (n, channels_out, height_out, width_out)
```
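As a quick sanity check of `_infer_shape`, the formula above is the familiar convolution output-size arithmetic; a worked example with illustrative values:

```python
# Worked example of the output-size formula in _infer_shape (illustrative only).
height, padding, dilation, kernel_h = 128, 1, 1, 3
effective_k = dilation * (kernel_h - 1) + 1            # 3
print((height + 2 * padding - effective_k) // 1 + 1)   # stride 1 -> 128 (size preserved)
print((height + 2 * padding - effective_k) // 2 + 1)   # stride 2 -> 64
```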
```bash
#!/usr/bin/env bash
# Compile the CUDA im2col kernels, then build the float and double FFI extensions.
cd src/cuda
nvcc -c -o dcn_v2_im2col_cuda.cu.o dcn_v2_im2col_cuda.cu -x cu -Xcompiler -fPIC
nvcc -c -o dcn_v2_im2col_cuda_double.cu.o dcn_v2_im2col_cuda_double.cu -x cu -Xcompiler -fPIC
cd -
python build.py
python build_double.py
```