PyTorch 1.6-1.8 compatibility - CUDA 11 / 3090 ready #92

Open · wants to merge 3 commits into master

1 change: 1 addition & 0 deletions DCN/__init__.py
@@ -0,0 +1 @@
from .dcn_v2 import *
File renamed without changes.
118 changes: 57 additions & 61 deletions src/cpu/dcn_v2_cpu.cpp → DCN/src/cpu/dcn_v2_cpu.cpp
@@ -1,5 +1,6 @@
#include <vector>

Hello @MatthewHowe! PyTorch 1.8.0 + CUDA 11.1 was released on 2021-03-06, and I find that they moved THCBlas away, so THCudaBlas_Sgemm and THCudaBlas_SgemmBatched can't be used anymore. I found that THCudaBlas_Sgemm can be replaced by at::cuda::blas::gemm<scalar_t>, but I haven't found out how to replace THCudaBlas_SgemmBatched (pytorch/pytorch#49725).
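
For reference, a minimal sketch of what the non-batched replacement might look like inside this PR's forward loop (reusing the m, n, k, columns, weight and output_n variables from the patch below). The at::cuda::blas::gemm signature shown here is an assumption based on PyTorch 1.8's ATen/cuda/CUDABlas.h and should be checked against your installed version:

```cpp
// Hedged sketch only: assumes the PyTorch 1.8 signature
//   at::cuda::blas::gemm<Dtype>(transa, transb, m, n, k, alpha,
//                               a, lda, b, ldb, beta, c, ldc)
// with the same column-major convention as the removed THCudaBlas_Sgemm.
#include <ATen/cuda/CUDABlas.h>

// old call (per the comment above, no longer available):
//   THCudaBlas_Sgemm(state, 'n', 'n', n, m, k, 1.0f,
//                    columns.data<scalar_t>(), n,
//                    weight.data<scalar_t>(), k, 1.0f,
//                    output_n.data<scalar_t>(), n);

// assumed replacement: same arguments, minus the THCState handle
at::cuda::blas::gemm<float>('n', 'n', n, m, k, 1.0f,
                            columns.data_ptr<float>(), n,
                            weight.data_ptr<float>(), k, 1.0f,
                            output_n.data_ptr<float>(), n);
```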


You have to replace all of the at::cuda::blas calls with torch.matmul, mostly the same as in the CPU version.
You can check my fork https://github.com/tteepe/DCNv2; it works with PyTorch 1.8 on both GPU and CPU.
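
As a concrete illustration of that suggestion, and as one possible answer to the THCudaBlas_SgemmBatched question above: a batched sgemm can be expressed with ATen's broadcasting at::matmul (or at::bmm), so no BLAS handle is needed. This is a sketch only; the shapes and names below follow this PR's CPU path and are my assumptions about how a batched layout would look:

```cpp
// Hedged sketch: batched matrix multiply without THCudaBlas_SgemmBatched.
// Assumed shapes (following this PR's variable names):
//   weight:  {channels_out, channels * kernel_h * kernel_w}
//   columns: {batch, channels * kernel_h * kernel_w, height_out * width_out}
//   output:  {batch, channels_out, height_out, width_out}
auto weight_flat = weight.view({channels_out, channels * kernel_h * kernel_w});

// at::matmul broadcasts the 2-D weight over the batch dimension:
//   {C_out, K} x {B, K, HW} -> {B, C_out, HW}
auto out = at::matmul(weight_flat, columns);
output += out.view({batch, channels_out, height_out, width_out});
```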

@tteepe I'm testing your fork and the tests are failing: the zero-offset check fails, and there is also a runtime error from check_gradient_dconv(). Do you know any fixes?


Are you running on CPU or GPU? Could you share the errors?


@tteepe I'm running tests for GPU:

❯ python ../tests/test_cuda.py >> errors.txt
/home/ubuntu/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/autograd/gradcheck.py:379: UserWarning: Input #0 requires gradient and is not a double precision floating point or complex. This check will likely fail if all the inputs are not of double precision floating point or complex.
  warnings.warn(
/home/ubuntu/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/autograd/gradcheck.py:379: UserWarning: Input #2 requires gradient and is not a double precision floating point or complex. This check will likely fail if all the inputs are not of double precision floating point or complex.
  warnings.warn(
/home/ubuntu/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/autograd/gradcheck.py:379: UserWarning: Input #1 requires gradient and is not a double precision floating point or complex. This check will likely fail if all the inputs are not of double precision floating point or complex.
  warnings.warn(
/home/ubuntu/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/autograd/gradcheck.py:379: UserWarning: Input #3 requires gradient and is not a double precision floating point or complex. This check will likely fail if all the inputs are not of double precision floating point or complex.
  warnings.warn(
/home/ubuntu/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/autograd/gradcheck.py:379: UserWarning: Input #4 requires gradient and is not a double precision floating point or complex. This check will likely fail if all the inputs are not of double precision floating point or complex.
  warnings.warn(
Traceback (most recent call last):
  File "../tests/test_cuda.py", line 310, in <module>
    check_gradient_dconv()
  File "../tests/test_cuda.py", line 102, in check_gradient_dconv
    gradcheck(
  File "/home/ubuntu/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 468, in gradcheck
    checkIfNumericalAnalyticAreClose(a, n, j)
  File "/home/ubuntu/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 449, in checkIfNumericalAnalyticAreClose
    return fail_test(error_str + 'Jacobian mismatch for output %d with respect to input %d,\n'
  File "/home/ubuntu/.pyenv/versions/3.8.6/lib/python3.8/site-packages/torch/autograd/gradcheck.py", line 367, in fail_test
    raise RuntimeError(msg)
RuntimeError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[ 0.0482,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.1630,  0.4700,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.1120,  0.1633,  0.2644,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0073,  0.0899, -0.0317],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.3312,  0.0000, -0.0292],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')
analytical:tensor([[ 0.0484,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.1630,  0.4696,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [ 0.1107,  0.1641,  0.2640,  ...,  0.0000,  0.0000,  0.0000],
        ...,
        [ 0.0000,  0.0000,  0.0000,  ..., -0.0073,  0.0896, -0.0317],
        [ 0.0000,  0.0000,  0.0000,  ..., -0.3318,  0.0000, -0.0293],
        [ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000]],
       device='cuda:0')

The printouts for the tests are:

❯ cat errors.txt
torch.Size([2, 64, 128, 128])
torch.Size([20, 32, 7, 7])
torch.Size([20, 32, 7, 7])
torch.Size([20, 32, 7, 7])
0.971507, 1.943014
0.971507, 1.943014
Zero offset failed
tensor([[[[ 8.3014e-01, -5.9995e-01,  3.3635e-01,  9.7973e-01],
          [-3.9212e-01, -1.4908e+00, -7.4743e-01, -1.2212e+00],
          [-2.6852e-01,  6.1768e-01, -4.6663e-02, -8.0118e-01],
          [ 5.4885e-01, -6.4019e-01, -3.2594e-01,  5.0614e-01]],

         [[ 4.3672e-01, -1.2259e+00, -7.5590e-01,  1.1371e+00],
          [ 1.5166e-03, -4.0051e-01, -6.4815e-01,  7.1022e-01],
          [-1.1470e+00, -2.6621e+00, -9.3397e-02, -3.2880e-01],
          [-7.3536e-01, -5.0394e-01, -2.7695e-01, -7.5032e-01]]],


        [[[ 1.1741e+00, -2.4331e-01,  2.1694e-02,  1.1321e+00],
          [ 4.2356e-01,  9.0400e-01,  4.8025e-01,  8.6040e-01],
          [-9.0097e-02, -1.1146e+00,  2.7222e-01, -8.4943e-01],
          [ 1.8496e+00, -7.6862e-01,  7.1716e-01, -8.0398e-01]],

         [[-1.1821e+00, -1.2773e+00, -1.1843e-01, -1.3768e-01],
          [ 1.6133e+00, -1.2187e+00,  1.1415e-01, -1.1589e+00],
          [-8.9244e-02, -3.6727e-01, -1.8207e+00, -3.6395e-01],
          [ 1.7848e+00,  7.1422e-01,  1.4804e+00,  1.5632e+00]]]],
       device='cuda:0')
tensor([[[[ 8.3008e-01, -6.0010e-01,  3.3643e-01,  9.7949e-01],
          [-3.9209e-01, -1.4912e+00, -7.4756e-01, -1.2217e+00],
          [-2.6855e-01,  6.1768e-01, -4.6661e-02, -8.0127e-01],
          [ 5.4883e-01, -6.4014e-01, -3.2593e-01,  5.0635e-01]],

         [[ 4.3677e-01, -1.2256e+00, -7.5586e-01,  1.1367e+00],
          [ 1.5163e-03, -4.0039e-01, -6.4795e-01,  7.1045e-01],
          [-1.1475e+00, -2.6621e+00, -9.3384e-02, -3.2886e-01],
          [-7.3535e-01, -5.0391e-01, -2.7686e-01, -7.5049e-01]]],


        [[[ 1.1738e+00, -2.4329e-01,  2.1698e-02,  1.1318e+00],
          [ 4.2358e-01,  9.0381e-01,  4.8022e-01,  8.6035e-01],
          [-9.0088e-02, -1.1143e+00,  2.7222e-01, -8.4961e-01],
          [ 1.8496e+00, -7.6855e-01,  7.1729e-01, -8.0420e-01]],

         [[-1.1816e+00, -1.2773e+00, -1.1841e-01, -1.3770e-01],
          [ 1.6133e+00, -1.2188e+00,  1.1414e-01, -1.1592e+00],
          [-8.9233e-02, -3.6719e-01, -1.8203e+00, -3.6401e-01],
          [ 1.7852e+00,  7.1436e-01,  1.4805e+00,  1.5635e+00]]]],
       device='cuda:0', grad_fn=<MulBackward0>)
check_gradient_dpooling: True

haruishi43 commented on Mar 25, 2021

@rathaROG

tteepe's fork is actually a JIT-compiled version, so you don't have to run python setup.py.
What I did to test that his fork works is:

git clone https://github.com/tteepe/DCNv2.git
cd DCN
python ../tests/test_cuda.py


Hi @haruishi43, I still had one more problem:

[4/4] "C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\bin\amd64/link.exe" dcn_v2_cuda.o dcn_v2_cpu.o dcn_v2_im2col_cpu.o dcn_v2_psroi_pooling_cpu.o dcn_v2_cuda.cuda.o dcn_v2_im2col_cuda.cuda.o dcn_v2_psroi_pooling_cuda.cuda.o /nologo /DLL c10.lib c10_cuda.lib torch_cpu.lib torch_cuda_cu.lib -INCLUDE:?searchsorted_cuda@native@at@@YA?AVTensor@2@AEBV32@0_N1@Z torch_cuda_cpp.lib -INCLUDE:?warp_size@cuda@at@@YAHXZ torch.lib /LIBPATH:C:\dev\exc\Anaconda3\envs\DEFT\lib\site-packages\torch\lib torch_python.lib /LIBPATH:C:\dev\exc\Anaconda3\envs\DEFT\libs "/LIBPATH:C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1\lib/x64" cudart.lib /out:DCNv2_gpu.pyd
FAILED: DCNv2_gpu.pyd
"C:\Program Files (x86)\Microsoft Visual Studio 14.0\VC\bin\amd64/link.exe" dcn_v2_cuda.o dcn_v2_cpu.o dcn_v2_im2col_cpu.o dcn_v2_psroi_pooling_cpu.o dcn_v2_cuda.cuda.o dcn_v2_im2col_cuda.cuda.o dcn_v2_psroi_pooling_cuda.cuda.o /nologo /DLL c10.lib c10_cuda.lib torch_cpu.lib torch_cuda_cu.lib -INCLUDE:?searchsorted_cuda@native@at@@YA?AVTensor@2@AEBV32@0_N1@Z torch_cuda_cpp.lib -INCLUDE:?warp_size@cuda@at@@YAHXZ torch.lib /LIBPATH:C:\dev\exc\Anaconda3\envs\DEFT\lib\site-packages\torch\lib torch_python.lib /LIBPATH:C:\dev\exc\Anaconda3\envs\DEFT\libs "/LIBPATH:C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1\lib/x64" cudart.lib /out:DCNv2_gpu.pyd
   Creating library DCNv2_gpu.lib and object DCNv2_gpu.exp
MSVCRT.lib(loadcfg.obj) : error LNK2001: unresolved external symbol __enclave_config
DCNv2_gpu.pyd : fatal error LNK1120: 1 unresolved externals
ninja: build stopped: subcommand failed.



Thanks for your help so far. In case you're interested, a Windows-ready version is here:
https://github.com/rathaROG/DCNv2_Windows


@tteepe Hello, I successfully compiled your version, but when I run testcuda.py it raises an error:
No module named 'dcnv2_gpu'
Have you ever experienced this?


@rathaROG Have you run it successfully under Linux? My environment is:
Python 3.6.13
torch 1.8.0+cu111
In my test: No module named 'dCNV2_GPU'
I would appreciate it if you have a solution.

#include "cpu/dcn_v2_im2col_cpu.h"
#include <iostream>

#include <ATen/ATen.h>
//#include <ATen/cuda/CUDAContext.h>
@@ -12,8 +13,12 @@

// author: Charles Shang
// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu

// modified from the CUDA version for CPU use by Daniel K. Suhendro

// edit by: James Bockman and Matthew Howe
// modified for torch implementation to remove use of deprecated torch access to Blas

at::Tensor
dcn_v2_cpu_forward(const at::Tensor &input,
const at::Tensor &weight,
@@ -60,9 +65,10 @@ dcn_v2_cpu_forward(const at::Tensor &input,
const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

auto ones = at::ones({height_out, width_out}, input.options());
// auto ones = at::ones({height_out, width_out}, input.options());
auto ones = at::ones({bias.sizes()[0], height_out, width_out}, input.options());
auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());
auto output = at::zeros({batch, channels_out, height_out, width_out}, input.options());

using scalar_t = float;
for (int b = 0; b < batch; b++)
@@ -71,37 +77,35 @@ dcn_v2_cpu_forward(const at::Tensor &input,
auto offset_n = offset.select(0, b);
auto mask_n = mask.select(0, b);
auto output_n = output.select(0, b);
// std::cout << "output_n: " << output_n << "output.select(0,b): " << output.select(0,b) << "\n";

// Do Bias first:
// M,N,K are dims of matrix A and B
// (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm)
// (N x 1) (1 x M)
long m_ = channels_out;
long n_ = height_out * width_out;
long k_ = 1;
THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,
ones.contiguous().data<scalar_t>(), k_,
bias.contiguous().data<scalar_t>(), k_, 0.0f,
output_n.data<scalar_t>(), n_);

modulated_deformable_im2col_cpu(input_n.data<scalar_t>(),
offset_n.data<scalar_t>(),
mask_n.data<scalar_t>(),

// torch implementation
auto ones_T = at::transpose(ones.contiguous(), 2, 0);
ones_T = at::mul(ones_T, bias.contiguous());
ones_T = at::transpose(ones_T, 2, 0);
output_n = at::add(output_n, ones_T);

modulated_deformable_im2col_cpu(input_n.data_ptr<scalar_t>(),
offset_n.data_ptr<scalar_t>(),
mask_n.data_ptr<scalar_t>(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w,
deformable_group,
columns.data<scalar_t>());
columns.data_ptr<scalar_t>());

//(k * m) x (m * n)
// Y = WC
long m = channels_out;
long n = height_out * width_out;
long k = channels * kernel_h * kernel_w;
THFloatBlas_gemm('n', 'n', n, m, k, 1.0f,
columns.data<scalar_t>(), n,
weight.data<scalar_t>(), k, 1.0f,
output_n.data<scalar_t>(), n);

// torch implementation
auto weight_flat = weight.view({channels_out, channels * kernel_h * kernel_w});
auto product = at::matmul(weight_flat, columns);
output.select(0, b) = at::add(output_n, product.view({channels_out, height_out, width_out}));
}
return output;
}
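
A side note on the bias handling in the hunk above: the ones_T transpose/mul/transpose sequence builds a {channels_out, height_out, width_out} tensor whose channel c is filled with bias[c]. If that reads as opaque, an equivalent formulation (a sketch only, not what the patch does) is a plain broadcasted add:

```cpp
// Equivalent effect to the ones_T construction above: add the per-channel
// bias to every spatial position of output_n.
// output_n: {channels_out, height_out, width_out}, bias: {channels_out}
output_n = at::add(output_n, bias.view({channels_out, 1, 1}));
```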
@@ -148,7 +152,7 @@ std::vector<at::Tensor> dcn_v2_cpu_backward(const at::Tensor &input,
const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;

auto ones = at::ones({height_out, width_out}, input.options());
auto columns = at::empty({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
auto columns = at::zeros({channels * kernel_h * kernel_w, 1 * height_out * width_out}, input.options());
auto output = at::empty({batch, channels_out, height_out, width_out}, input.options());

auto grad_input = at::zeros_like(input);
@@ -169,65 +173,57 @@ std::vector<at::Tensor> dcn_v2_cpu_backward(const at::Tensor &input,
auto grad_offset_n = grad_offset.select(0, b);
auto grad_mask_n = grad_mask.select(0, b);

long m = channels * kernel_h * kernel_w;
long n = height_out * width_out;
long k = channels_out;

THFloatBlas_gemm('n', 't', n, m, k, 1.0f,
grad_output_n.data<scalar_t>(), n,
weight.data<scalar_t>(), m, 0.0f,
columns.data<scalar_t>(), n);

// Torch implementation
auto weight_flat = weight.view({channels_out, channels*kernel_h*kernel_w});
weight_flat = at::transpose(weight_flat, 1, 0);
auto grad_output_n_flat = grad_output_n.view({channels_out, height_out*width_out});
columns = at::matmul(weight_flat, grad_output_n_flat);

// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cpu(columns.data<scalar_t>(),
input_n.data<scalar_t>(),
offset_n.data<scalar_t>(),
mask_n.data<scalar_t>(),
modulated_deformable_col2im_coord_cpu(columns.data_ptr<scalar_t>(),
input_n.data_ptr<scalar_t>(),
offset_n.data_ptr<scalar_t>(),
mask_n.data_ptr<scalar_t>(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_offset_n.data<scalar_t>(),
grad_mask_n.data<scalar_t>());
grad_offset_n.data_ptr<scalar_t>(),
grad_mask_n.data_ptr<scalar_t>());
// gradient w.r.t. input data
modulated_deformable_col2im_cpu(columns.data<scalar_t>(),
offset_n.data<scalar_t>(),
mask_n.data<scalar_t>(),
modulated_deformable_col2im_cpu(columns.data_ptr<scalar_t>(),
offset_n.data_ptr<scalar_t>(),
mask_n.data_ptr<scalar_t>(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
grad_input_n.data<scalar_t>());
grad_input_n.data_ptr<scalar_t>());

// gradient w.r.t. weight, dWeight should accumulate across the batch and group
modulated_deformable_im2col_cpu(input_n.data<scalar_t>(),
offset_n.data<scalar_t>(),
mask_n.data<scalar_t>(),
modulated_deformable_im2col_cpu(input_n.data_ptr<scalar_t>(),
offset_n.data_ptr<scalar_t>(),
mask_n.data_ptr<scalar_t>(),
1, channels, height, width,
height_out, width_out, kernel_h, kernel_w,
pad_h, pad_w, stride_h, stride_w,
dilation_h, dilation_w, deformable_group,
columns.data<scalar_t>());

long m_ = channels_out;
long n_ = channels * kernel_h * kernel_w;
long k_ = height_out * width_out;

THFloatBlas_gemm('t', 'n', n_, m_, k_, 1.0f,
columns.data<scalar_t>(), k_,
grad_output_n.data<scalar_t>(), k_, 1.0f,
grad_weight.data<scalar_t>(), n_);

// gradient w.r.t. bias
// long m_ = channels_out;
// long k__ = height_out * width_out;
THFloatBlas_gemv('t', k_, m_, 1.0f,
grad_output_n.data<scalar_t>(), k_,
ones.data<scalar_t>(), 1, 1.0f,
grad_bias.data<scalar_t>(), 1);
columns.data_ptr<scalar_t>());

// Torch implementation
auto product = at::matmul(grad_output_n_flat, at::transpose(columns, 1, 0));
grad_weight = at::add(grad_weight, product.view({channels_out, channels, kernel_h, kernel_w}));


// Torch implementation
auto ones_flat = ones.view({height_out*width_out});
product = at::matmul(grad_output_n_flat, ones_flat);
grad_bias = at::add(grad_bias, product);
}

return {
grad_input, grad_offset, grad_mask, grad_weight, grad_bias
};
}
}
File renamed without changes.
File renamed without changes.
File renamed without changes.
30 changes: 18 additions & 12 deletions src/cuda/dcn_v2_cuda.cu → DCN/src/cuda/dcn_v2_cuda.cu
@@ -104,7 +104,7 @@ dcn_v2_cuda_forward(const at::Tensor &input,
const int block = 128;
const int grid = (batch + block - 1) / block;

createBatchGemmBuffer<<<grid, block, 0, THCState_getCurrentStream(state)>>>(
createBatchGemmBuffer<<<grid, block, 0, c10::cuda::getCurrentCUDAStream()>>>(
input_b, output_b,
columns_b, ones_b,
weight_b, bias_b,
@@ -136,7 +136,7 @@ dcn_v2_cuda_forward(const at::Tensor &input,
output_b, n_,
batch);

modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(),
input.data<scalar_t>(),
offset.data<scalar_t>(),
mask.data<scalar_t>(),
@@ -276,7 +276,7 @@ std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
columns.data<scalar_t>(), n);

// gradient w.r.t. input coordinate data
modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state),
modulated_deformable_col2im_coord_cuda(c10::cuda::getCurrentCUDAStream(),
columns.data<scalar_t>(),
input_n.data<scalar_t>(),
offset_n.data<scalar_t>(),
@@ -288,7 +288,7 @@ std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
grad_offset_n.data<scalar_t>(),
grad_mask_n.data<scalar_t>());
// gradient w.r.t. input data
modulated_deformable_col2im_cuda(THCState_getCurrentStream(state),
modulated_deformable_col2im_cuda(c10::cuda::getCurrentCUDAStream(),
columns.data<scalar_t>(),
offset_n.data<scalar_t>(),
mask_n.data<scalar_t>(),
@@ -299,7 +299,7 @@ std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
grad_input_n.data<scalar_t>());

// gradient w.r.t. weight, dWeight should accumulate across the batch and group
modulated_deformable_im2col_cuda(THCState_getCurrentStream(state),
modulated_deformable_im2col_cuda(c10::cuda::getCurrentCUDAStream(),
input_n.data<scalar_t>(),
offset_n.data<scalar_t>(),
mask_n.data<scalar_t>(),
@@ -321,15 +321,21 @@ std::vector<at::Tensor> dcn_v2_cuda_backward(const at::Tensor &input,
// gradient w.r.t. bias
// long m_ = channels_out;
// long k__ = height_out * width_out;
THCudaBlas_Sgemv(state,
't',
k_, m_, 1.0f,
grad_output_n.data<scalar_t>(), k_,
ones.data<scalar_t>(), 1, 1.0f,
grad_bias.data<scalar_t>(), 1);
// THCudaBlas_Sgemm(state,
// 't', 'n',
// k_, m_, 1, 1.0f,
// grad_output_n.data<scalar_t>(), k_,
// ones.data<scalar_t>(), 1, 1.0f,
// grad_bias.data<scalar_t>(), 1);
THCudaBlas_Sgemm(state,
'N', 'N', 1, m_, k_, 1.0f,
ones.data<scalar_t>(), 1,
grad_output_n.data<scalar_t>(), k_,
1.0f,
grad_bias.data<scalar_t>(), 1);
}

return {
grad_input, grad_offset, grad_mask, grad_weight, grad_bias
};
}
}
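
On the grad_bias change just above: both the removed Sgemv and the new Sgemm call (which multiplies a 1 x k_ row of ones against grad_output) compute the same thing, namely the sum of grad_output_n over all spatial positions for each output channel. A THC-free sketch of that reduction in plain ATen (assuming grad_output_n has shape {channels_out, height_out, width_out}) would be:

```cpp
// Hedged sketch: per-channel bias gradient as a spatial-sum reduction,
// with no cuBLAS/THC call at all.
grad_bias += grad_output_n.view({channels_out, -1}).sum(1);
```
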
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion testcuda.py → DCN/testcuda.py
@@ -132,7 +132,7 @@ def check_pooling_zero_offset():


def check_gradient_dpooling():
input = torch.randn(2, 3, 5, 5).cuda() * 0.01
input = torch.randn(2, 3, 5, 5).cuda().float() * 0.01
N = 4
batch_inds = torch.randint(2, (N, 1)).cuda().float()
x = torch.rand((N, 1)).cuda().float() * 15
Empty file removed __init__.py
Empty file.
4 changes: 2 additions & 2 deletions setup.py
@@ -17,7 +17,7 @@

def get_extensions():
this_dir = os.path.dirname(os.path.abspath(__file__))
extensions_dir = os.path.join(this_dir, "src")
extensions_dir = os.path.join(this_dir, "DCN", "src")

main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
source_cpu = glob.glob(os.path.join(extensions_dir, "cpu", "*.cpp"))
@@ -68,4 +68,4 @@ def get_extensions():
# install_requires=requirements,
ext_modules=get_extensions(),
cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
)
)