diff --git a/dcn_v2.py b/dcn_v2.py index a2014e3..4e1ed63 100644 --- a/dcn_v2.py +++ b/dcn_v2.py @@ -112,18 +112,31 @@ def reset_parameters(self): def forward(self, input, offset, mask): assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == offset.shape[1] assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == mask.shape[1] - return dcn_v2_conv( - input, - offset, - mask, - self.weight, - self.bias, - self.stride, - self.padding, - self.dilation, - self.deformable_groups, - ) + if not torch.jit.is_scripting(): + if not torch.jit.is_tracing(): + return dcn_v2_conv( + input, + offset, + mask, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.deformable_groups, + ) + + kernel_size = self.weight.shape[2:4] + output = torch.ops.custom.dcn_v2(input, self.weight, self.bias, + offset, mask, + kernel_size[0], kernel_size[1], + self.stride[0], self.stride[1], + self.padding[0], self.padding[1], + self.dilation[0], self.dilation[1], + self.deformable_groups) + return output + class DCN(DCNv2): def __init__( @@ -158,19 +171,30 @@ def forward(self, input): o1, o2, mask = torch.chunk(out, 3, dim=1) offset = torch.cat((o1, o2), dim=1) mask = torch.sigmoid(mask) - return dcn_v2_conv( - input, - offset, - mask, - self.weight, - self.bias, - self.stride, - self.padding, - self.dilation, - self.deformable_groups, - ) - + if not torch.jit.is_scripting(): + if not torch.jit.is_tracing(): + return dcn_v2_conv( + input, + offset, + mask, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.deformable_groups, + ) + kernel_size = self.weight.shape[2:4] + output = torch.ops.custom.dcn_v2(input, self.weight, self.bias, + offset, mask, + kernel_size[0], kernel_size[1], + self.stride[0], self.stride[1], + self.padding[0], self.padding[1], + self.dilation[0], self.dilation[1], + self.deformable_groups) + return output + class _DCNv2Pooling(Function): @staticmethod def forward( diff --git a/src/vision.cpp b/src/vision.cpp index ff54233..4559947 100644 --- a/src/vision.cpp +++ b/src/vision.cpp @@ -1,9 +1,26 @@ #include "dcn_v2.h" +#include PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("dcn_v2_forward", &dcn_v2_forward, "dcn_v2_forward"); m.def("dcn_v2_backward", &dcn_v2_backward, "dcn_v2_backward"); - m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, "dcn_v2_psroi_pooling_forward"); - m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, "dcn_v2_psroi_pooling_backward"); + m.def("dcn_v2_psroi_pooling_forward", &dcn_v2_psroi_pooling_forward, + "dcn_v2_psroi_pooling_forward"); + m.def("dcn_v2_psroi_pooling_backward", &dcn_v2_psroi_pooling_backward, + "dcn_v2_psroi_pooling_backward"); } + +inline at::Tensor dcn_v2(const at::Tensor &input, const at::Tensor &weight, + const at::Tensor &bias, const at::Tensor &offset, + const at::Tensor &mask, const int64_t kernel_h, + const int64_t kernel_w, const int64_t stride_h, + const int64_t stride_w, const int64_t pad_h, + const int64_t pad_w, const int64_t dilation_h, + const int64_t dilation_w, + const int64_t deformable_group) { + return dcn_v2_forward(input, weight, bias, offset, mask, kernel_h, kernel_w, + stride_h, stride_w, pad_h, pad_w, dilation_h, + dilation_w, deformable_group); +} +static auto registry = torch::RegisterOperators().op("custom::dcn_v2", &dcn_v2); diff --git a/test/testcuda.py b/test/testcuda.py index b83a4aa..2776349 100644 --- a/test/testcuda.py +++ b/test/testcuda.py @@ -45,6 +45,7 @@ def check_zero_offset(): ).cuda() dcn_v2 = DCNv2(inC, outC, (kH, kW), stride=1, padding=1, dilation=1, deformable_groups=deformable_groups).cuda() + scripted_dcnv2 = torch.jit.script(dcn_v2) conv_offset.weight.data.zero_() conv_offset.bias.data.zero_() @@ -57,6 +58,10 @@ def check_zero_offset(): mask = conv_mask(input) mask = torch.sigmoid(mask) output = dcn_v2(input, offset, mask) + result = scripted_dcnv2(input, offset, mask) + print(scripted_dcnv2.code) + torch.testing.assert_allclose(output, result) + output *= 2 d = (input - output).abs().max() if d < 1e-10: @@ -194,13 +199,18 @@ def example_dconv(): input = torch.randn(2, 64, 128, 128).cuda() # wrap all things (offset and mask) in DCN dcn = DCN(64, 64, kernel_size=(3, 3), stride=1, padding=1, deformable_groups=2).cuda() + traced_dcn = torch.jit.script(dcn, input) # print(dcn.weight.shape, input.shape) output = dcn(input) + result = traced_dcn(input) + targert = output.new(*output.size()) targert.data.uniform_(-0.01, 0.01) error = (targert - output).mean() error.backward() print(output.shape) + print(traced_dcn.code) + torch.testing.assert_allclose(output, result) def example_dpooling():