Skip to content

Commit

Permalink
Add torchscript support
Browse files Browse the repository at this point in the history
  • Loading branch information
tianyou.gty committed May 17, 2022
1 parent 711fe28 commit ff5efb8
Show file tree
Hide file tree
Showing 3 changed files with 76 additions and 25 deletions.
70 changes: 47 additions & 23 deletions dcn_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,31 @@ def reset_parameters(self):
def forward(self, input, offset, mask):
assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == offset.shape[1]
assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == mask.shape[1]
return dcn_v2_conv(
input,
offset,
mask,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.deformable_groups,
)

if not torch.jit.is_scripting():
if not torch.jit.is_tracing():
return dcn_v2_conv(
input,
offset,
mask,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.deformable_groups,
)

kernel_size = self.weight.shape[2:4]
output = torch.ops.custom.dcn_v2(input, self.weight, self.bias,
offset, mask,
kernel_size[0], kernel_size[1],
self.stride[0], self.stride[1],
self.padding[0], self.padding[1],
self.dilation[0], self.dilation[1],
self.deformable_groups)
return output


class DCN(DCNv2):
def __init__(
Expand Down Expand Up @@ -158,19 +171,30 @@ def forward(self, input):
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
return dcn_v2_conv(
input,
offset,
mask,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.deformable_groups,
)

if not torch.jit.is_scripting():
if not torch.jit.is_tracing():
return dcn_v2_conv(
input,
offset,
mask,
self.weight,
self.bias,
self.stride,
self.padding,
self.dilation,
self.deformable_groups,
)

kernel_size = self.weight.shape[2:4]
output = torch.ops.custom.dcn_v2(input, self.weight, self.bias,
offset, mask,
kernel_size[0], kernel_size[1],
self.stride[0], self.stride[1],
self.padding[0], self.padding[1],
self.dilation[0], self.dilation[1],
self.deformable_groups)
return output

class _DCNv2Pooling(Function):
@staticmethod
def forward(
Expand Down
21 changes: 19 additions & 2 deletions src/vision.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,26 @@

#include "dcn_v2.h"
#include <torch/script.h>

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward");
m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward");
m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward");
m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward");
m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward,
"dcn_v2_psroi_pooling_forward");
m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward,
"dcn_v2_psroi_pooling_backward");
}

inline at::Tensor dcn_v2(const at::Tensor &input, const at::Tensor &weight,
const at::Tensor &bias, const at::Tensor &offset,
const at::Tensor &mask, const int64_t kernel_h,
const int64_t kernel_w, const int64_t stride_h,
const int64_t stride_w, const int64_t pad_h,
const int64_t pad_w, const int64_t dilation_h,
const int64_t dilation_w,
const int64_t deformable_group) {
return dcn_v2_forward(input, weight, bias, offset, mask, kernel_h, kernel_w,
stride_h, stride_w, pad_h, pad_w, dilation_h,
dilation_w, deformable_group);
}
static auto registry = torch::RegisterOperators().op("custom::dcn_v2", &dcn_v2);
10 changes: 10 additions & 0 deletions test/testcuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def check_zero_offset():
).cuda()

dcn_v2 = DCNv2(inC, outC, (kH, kW), stride=1, padding=1, dilation=1, deformable_groups=deformable_groups).cuda()
scripted_dcnv2 = torch.jit.script(dcn_v2)

conv_offset.weight.data.zero_()
conv_offset.bias.data.zero_()
Expand All @@ -57,6 +58,10 @@ def check_zero_offset():
mask = conv_mask(input)
mask = torch.sigmoid(mask)
output = dcn_v2(input, offset, mask)
result = scripted_dcnv2(input, offset, mask)
print(scripted_dcnv2.code)
torch.testing.assert_allclose(output, result)

output *= 2
d = (input - output).abs().max()
if d < 1e-10:
Expand Down Expand Up @@ -194,13 +199,18 @@ def example_dconv():
input = torch.randn(2, 64, 128, 128).cuda()
# wrap all things (offset and mask) in DCN
dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, padding=1, deformable_groups=2).cuda()
traced_dcn = torch.jit.script(dcn, input)
# print(dcn.weight.shape, input.shape)
output = dcn(input)
result = traced_dcn(input)

targert = output.new(*output.size())
targert.data.uniform_(-0.01, 0.01)
error = (targert - output).mean()
error.backward()
print(output.shape)
print(traced_dcn.code)
torch.testing.assert_allclose(output, result)


def example_dpooling():
Expand Down

0 comments on commit ff5efb8

Please sign in to comment.