Better fix to the tensor overwriting issue using less temp memory
ducksoup committed Jan 8, 2020
1 parent fd332f2 commit ff5439b
Showing 7 changed files with 121 additions and 100 deletions.
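
In short: the previous fix cloned `y_act` and `dy_act` on the Python side because the backend modified them in place, and the backend then allocated a separate `dx`. After this change `backward_reduce` takes its inputs as `const` and returns freshly allocated `xhat` and `dy`, while `backward` overwrites `dy` with `dx` in place, so the backward pass appears to need roughly one full-size temporary less than before.
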
2 changes: 1 addition & 1 deletion README.md
@@ -54,7 +54,7 @@ To install PyTorch, please refer to https://github.com/pytorch/pytorch#installat

To install the package containing the iABN layers:
```bash
-pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.9
+pip install git+https://github.com/mapillary/inplace_abn.git@v1.0.10
```
Note that some parts of InPlace-ABN have native C++/CUDA implementations, meaning that the command above will need to
compile them.
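
After installation, a minimal usage sketch might look like the following. The `InPlaceABN` module name and its default LeakyReLU activation are assumed from the package's public API and are not part of this commit:

```python
import torch
import torch.nn as nn
from inplace_abn import InPlaceABN  # assumed public entry point

# Conv -> fused BN + activation, computed in place to save activation memory
block = nn.Sequential(
    nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=False),
    InPlaceABN(128),  # defaults to the leaky_relu activation
)

x = torch.randn(8, 64, 32, 32)
print(block(x).shape)  # torch.Size([8, 128, 32, 32])
```
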
50 changes: 31 additions & 19 deletions include/inplace_abn.h
@@ -33,19 +33,19 @@ void forward_cuda(at::Tensor& x, const at::Tensor& mean, const at::Tensor& var,
const c10::optional<at::Tensor>& weight, const c10::optional<at::Tensor>& bias,
float eps, Activation activation, float activation_param);

-std::tuple<at::Tensor, at::Tensor> backward_reduce_cpu(
-at::Tensor& y_act, at::Tensor& dy_act, const c10::optional<at::Tensor>& weight,
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> backward_reduce_cpu(
+const at::Tensor& y_act, const at::Tensor& dy_act, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias, float eps, Activation activation, float activation_param);
-std::tuple<at::Tensor, at::Tensor> backward_reduce_cuda(
-at::Tensor& y_act, at::Tensor& dy_act, const c10::optional<at::Tensor>& weight,
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> backward_reduce_cuda(
+const at::Tensor& y_act, const at::Tensor& dy_act, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias, float eps, Activation activation, float activation_param);

-at::Tensor backward_cpu(const at::Tensor& xhat, const at::Tensor& dy, const at::Tensor& var, const at::Tensor& count,
-const at::Tensor& sum_dy, const at::Tensor& sum_xhat_dy,
-const c10::optional<at::Tensor>& weight, float eps);
-at::Tensor backward_cuda(const at::Tensor& xhat, const at::Tensor& dy, const at::Tensor& var, const at::Tensor& count,
-const at::Tensor& sum_dy, const at::Tensor& sum_xhat_dy,
-const c10::optional<at::Tensor>& weight, float eps);
+void backward_cpu(const at::Tensor& xhat, at::Tensor& dy, const at::Tensor& var, const at::Tensor& count,
+const at::Tensor& sum_dy, const at::Tensor& sum_xhat_dy,
+const c10::optional<at::Tensor>& weight, float eps);
+void backward_cuda(const at::Tensor& xhat, at::Tensor& dy, const at::Tensor& var, const at::Tensor& count,
+const at::Tensor& sum_dy, const at::Tensor& sum_xhat_dy,
+const c10::optional<at::Tensor>& weight, float eps);

/***********************************************************************************************************************
* Handling of activation functions
@@ -59,10 +59,14 @@ struct ActivationFn<scalar_t, Activation::LeakyReLU> {
x = (x >= 0) ? x : static_cast<scalar_t>(x * activation_param);
}

-static INLINE_HOST_DEVICE void backward(scalar_t& y_act, scalar_t& dy_act, float activation_param) {
-if (y_act < 0) {
-y_act /= static_cast<scalar_t>(activation_param);
-dy_act *= static_cast<scalar_t>(activation_param);
+static INLINE_HOST_DEVICE void backward(scalar_t y_act, scalar_t dy_act, float activation_param,
+scalar_t& y, scalar_t& dy) {
+if (y_act >= 0) {
+y = y_act;
+dy = dy_act;
+} else {
+y = static_cast<scalar_t>(y_act / activation_param);
+dy = static_cast<scalar_t>(dy_act * activation_param);
}
}
};
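
The reworked `backward` no longer mutates its inputs; it writes the inverted activation and the rescaled gradient into separate outputs. A standalone sketch of the same LeakyReLU inversion, for illustration only, with `slope` playing the role of `activation_param`:

```python
import torch
import torch.nn.functional as F

def leaky_relu_invert(y_act, dy_act, slope):
    """Recover the pre-activation value y and its gradient dy from
    y_act = leaky_relu(y, slope) and dy_act = dL/dy_act, without
    modifying either input."""
    neg = y_act < 0
    y = torch.where(neg, y_act / slope, y_act)      # invert the activation
    dy = torch.where(neg, dy_act * slope, dy_act)   # chain rule through the negative slope
    return y, dy

y = torch.randn(4, 3)
y_act = F.leaky_relu(y, negative_slope=0.01)
y_rec, _ = leaky_relu_invert(y_act, torch.randn_like(y_act), 0.01)
print(torch.allclose(y_rec, y, atol=1e-6))  # True
```
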
@@ -73,10 +77,14 @@ struct ActivationFn<scalar_t, Activation::ELU> {
x = (x >= 0) ? x : static_cast<scalar_t>(activation_param * (::exp(x) - 1));
}

-static INLINE_HOST_DEVICE void backward(scalar_t& y_act, scalar_t& dy_act, float activation_param) {
-if (y_act < 0) {
-dy_act *= y_act + static_cast<scalar_t>(activation_param);
-y_act = ::log1p(y_act / static_cast<scalar_t>(activation_param));
+static INLINE_HOST_DEVICE void backward(scalar_t y_act, scalar_t dy_act, float activation_param,
+scalar_t& y, scalar_t& dy) {
+if (y_act >= 0) {
+y = y_act;
+dy = dy_act;
+} else {
+y = ::log1p(static_cast<scalar_t>(y_act / activation_param));
+dy = static_cast<scalar_t>(dy_act * (y_act + activation_param));
}
}
};
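
The scaling of `dy` by `y_act + activation_param` in the negative branch is just the ELU derivative expressed through the activation's own output (writing alpha for `activation_param`):

```latex
y_{\mathrm{act}} = \alpha\,(e^{x} - 1) \;\text{ for } x < 0
\quad\Rightarrow\quad
\frac{\partial y_{\mathrm{act}}}{\partial x} = \alpha e^{x} = y_{\mathrm{act}} + \alpha,
\qquad
x = \log\bigl(1 + y_{\mathrm{act}}/\alpha\bigr).
```
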
@@ -85,5 +93,9 @@ template<typename scalar_t>
struct ActivationFn<scalar_t, Activation::Identity> {
static INLINE_HOST_DEVICE void forward(scalar_t& x, float activation_param) {}

-static INLINE_HOST_DEVICE void backward(scalar_t& y_act, scalar_t& dy_act, float activation_param) {}
+static INLINE_HOST_DEVICE void backward(scalar_t y_act, scalar_t dy_act, float activation_param,
+scalar_t& y, scalar_t& dy) {
+y = y_act;
+dy = dy_act;
+}
};
9 changes: 4 additions & 5 deletions inplace_abn/functions.py
@@ -106,12 +106,9 @@ def forward(ctx, x, weight, bias, running_mean, running_var,
def backward(ctx, dy_act):
y_act, var, count, weight, bias = ctx.saved_tensors

-# Create clones of y_act and dy_act, as the backend will modify them in-place
-y_act, dy_act = y_act.clone(), dy_act.clone()
-
# Call backward_reduce if we need to compute at least one of the gradients
if any(ctx.needs_input_grad):
-sum_dy_local, sum_xhat_dy_local = _backend.backward_reduce(
+xhat, dy, sum_dy_local, sum_xhat_dy_local = _backend.backward_reduce(
y_act, dy_act, weight, bias, ctx.eps, ctx.activation, ctx.activation_param)

if ctx.distributed:
@@ -125,7 +122,9 @@ def backward(ctx, dy_act):
# Gradient w.r.t. x
if ctx.needs_input_grad[0]:
if ctx.training:
-dx = _backend.backward(y_act, dy_act, var, count, sum_dy, sum_xhat_dy, weight, ctx.eps)
+# This overwrites dy with dx
+_backend.backward(xhat, dy, var, count, sum_dy, sum_xhat_dy, weight, ctx.eps)
+dx = dy
else:
dx = _backend.backward_test(dy_act, var, weight, ctx.eps)
else:
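
Taken together, the Python-side backward now follows roughly this sketch (distributed all-reduce and the weight/bias gradients are omitted for brevity; `_backend` is the compiled extension used above):

```python
def backward_sketch(ctx, dy_act, _backend):
    y_act, var, count, weight, bias = ctx.saved_tensors

    # Step 1: backward_reduce leaves y_act / dy_act untouched and returns
    # freshly allocated xhat and dy plus the per-channel reductions.
    xhat, dy, sum_dy, sum_xhat_dy = _backend.backward_reduce(
        y_act, dy_act, weight, bias, ctx.eps, ctx.activation, ctx.activation_param)

    # Step 2: backward overwrites dy with dx in place, so no extra dx buffer is needed.
    _backend.backward(xhat, dy, var, count, sum_dy, sum_xhat_dy, weight, ctx.eps)
    return dy  # this is dx
```
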
16 changes: 7 additions & 9 deletions src/inplace_abn.cpp
@@ -64,8 +64,8 @@ void forward(at::Tensor& x, const at::Tensor& mean, const at::Tensor& var,
CUDA_DISPATCH(x, forward, x, mean, var, weight, bias, eps, activation, activation_param)
}

-std::tuple<at::Tensor, at::Tensor> backward_reduce(
-at::Tensor& y_act, at::Tensor& dy_act, const c10::optional<at::Tensor>& weight,
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> backward_reduce(
+const at::Tensor& y_act, const at::Tensor& dy_act, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias, float eps, Activation activation, float activation_param) {
// Check dimensions and types
AT_CHECK(y_act.ndimension() >= 2, "y_act should have at least 2 dimensions");
@@ -86,9 +86,9 @@ std::tuple<at::Tensor, at::Tensor> backward_reduce(
CUDA_DISPATCH(y_act, backward_reduce, y_act, dy_act, weight, bias, eps, activation, activation_param)
}

-at::Tensor backward(const at::Tensor& xhat, const at::Tensor& dy, const at::Tensor& var, const at::Tensor& count,
-const at::Tensor& sum_dy, const at::Tensor& sum_xhat_dy, const c10::optional<at::Tensor>& weight,
-float eps) {
+void backward(const at::Tensor& xhat, at::Tensor& dy, const at::Tensor& var, const at::Tensor& count,
+const at::Tensor& sum_dy, const at::Tensor& sum_xhat_dy, const c10::optional<at::Tensor>& weight,
+float eps) {
// Check dimensions and types
AT_CHECK(xhat.ndimension() >= 2, "xhat should have at least 2 dimensions");
AT_CHECK(have_same_dims(xhat, dy), "xhat and dy should have the same size");
@@ -141,9 +141,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &forward, "iABN forward pass. This is an in-place operation w.r.t. x");

// Backward methods
m.def("backward_reduce", &backward_reduce,
"First step of the backward pass. This is an in-place operation w.r.t. y_act and dy_act, which are transformed "
"into xhat and dy, respectively.");
m.def("backward", &backward, "Second step of the backward pass");
m.def("backward_reduce", &backward_reduce, "First step of the backward pass");
m.def("backward", &backward, "Second step of the backward pass. This is an in-place operation w.r.t. dy");
m.def("backward_test", &backward_test, "Second step of the backward pass, test mode");
}
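
Note that both bound signatures change from the caller's point of view: `backward_reduce` now returns a 4-tuple instead of a 2-tuple, and `backward` returns nothing, its result being read back from the `dy` argument, so any external code calling the old bindings needs to be updated accordingly.
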
59 changes: 34 additions & 25 deletions src/inplace_abn_cpu.cpp
@@ -20,19 +20,29 @@ int32_t count_samples(const at::Tensor& x) {
**********************************************************************************************************************/

template<typename scalar_t, Activation activation>
-std::tuple<at::Tensor, at::Tensor> backward_reduce_impl(
-at::Tensor& y_act_, at::Tensor& dy_act_, const c10::optional<at::Tensor>& weight_,
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> backward_reduce_impl(
+const at::Tensor& y_act_, const at::Tensor& dy_act_, const c10::optional<at::Tensor>& weight_,
const c10::optional<at::Tensor>& bias_, float eps, float activation_param) {
-// Get dimensions
-int64_t num = y_act_.size(0), chn = y_act_.size(1), sp = y_act_.size(2);
-
// Initialize output tensors
-auto sum_dy_ = at::zeros({chn}, y_act_.options());
-auto sum_xhat_dy_ = at::zeros({chn}, y_act_.options());
+auto xhat_ = at::empty_like(y_act_);
+auto dy_ = at::empty_like(y_act_);
+auto sum_dy_ = at::zeros({y_act_.size(1)}, y_act_.options());
+auto sum_xhat_dy_ = at::zeros({y_act_.size(1)}, y_act_.options());

+// Normalize shapes
+auto y_act_norm_ = normalize_shape(y_act_);
+auto dy_act_norm_ = normalize_shape(dy_act_);
+auto xhat_norm_ = normalize_shape(xhat_);
+auto dy_norm_ = normalize_shape(dy_);
+
+// Get dimensions
+int64_t num = y_act_norm_.size(0), chn = y_act_norm_.size(1), sp = y_act_norm_.size(2);
+
// Make accessors
-auto y_act = y_act_.accessor<scalar_t, 3>();
-auto dy_act = dy_act_.accessor<scalar_t, 3>();
+auto y_act = y_act_norm_.accessor<scalar_t, 3>();
+auto dy_act = dy_act_norm_.accessor<scalar_t, 3>();
+auto xhat = xhat_norm_.accessor<scalar_t, 3>();
+auto dy = dy_norm_.accessor<scalar_t, 3>();
auto weight = accessor_or_dummy<scalar_t, 1>(weight_);
auto bias = accessor_or_dummy<scalar_t, 1>(bias_);
auto sum_dy = sum_dy_.accessor<scalar_t, 1>();
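
The implementation now reshapes everything to a 3-D `(num, channels, spatial)` view before building accessors. `normalize_shape` itself is not part of this diff; presumably it is a small reshaping helper along these lines (a hypothetical Python rendering, not the repository's actual C++ helper):

```python
import torch

def normalize_shape(x):
    """Collapse an (N, C, ...) tensor to (N, C, S); per-channel vectors become (1, C, 1)."""
    if x.dim() == 1:
        return x.view(1, -1, 1)
    return x.view(x.size(0), x.size(1), -1)

print(normalize_shape(torch.randn(2, 8, 4, 4)).shape)  # torch.Size([2, 8, 64])
print(normalize_shape(torch.randn(8)).shape)           # torch.Size([1, 8, 1])
```
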
@@ -46,22 +56,24 @@ std::tuple<at::Tensor, at::Tensor> backward_reduce_impl(
for (int64_t n = 0; n < num; ++n) {
auto y_act_nc = y_act[n][c];
auto dy_act_nc = dy_act[n][c];
+auto xhat_nc = xhat[n][c];
+auto dy_nc = dy[n][c];

for (int64_t s = 0; s < sp; ++s) {
// Invert activation
-ActivationFn<scalar_t, activation>::backward(y_act_nc[s], dy_act_nc[s], activation_param);
+ActivationFn<scalar_t, activation>::backward(y_act_nc[s], dy_act_nc[s], activation_param, xhat_nc[s], dy_nc[s]);

// Invert affine transformation
-y_act_nc[s] = (y_act_nc[s] - beta_c) * inv_gamma_c;
+xhat_nc[s] = (xhat_nc[s] - beta_c) * inv_gamma_c;

// Accumulate
-sum_dy[c] += dy_act_nc[s];
-sum_xhat_dy[c] += y_act_nc[s] * dy_act_nc[s];
+sum_dy[c] += dy_nc[s];
+sum_xhat_dy[c] += xhat_nc[s] * dy_nc[s];
}
}
}

-return std::make_tuple(sum_dy_, sum_xhat_dy_);
+return std::make_tuple(xhat_, dy_, sum_dy_, sum_xhat_dy_);
}
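
The loop accumulates, for each channel c, the two per-channel reductions consumed by the second step:

```latex
\mathrm{sum\_dy}_c = \sum_{n,s} dy_{n,c,s},
\qquad
\mathrm{sum\_xhat\_dy}_c = \sum_{n,s} \hat{x}_{n,c,s}\, dy_{n,c,s}.
```
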

/***********************************************************************************************************************
@@ -108,13 +120,10 @@ void forward_cpu(at::Tensor& x_, const at::Tensor& mean, const at::Tensor& var,
}
}

-std::tuple<at::Tensor, at::Tensor> backward_reduce_cpu(
-at::Tensor& y_act_, at::Tensor& dy_act_, const c10::optional<at::Tensor>& weight,
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> backward_reduce_cpu(
+const at::Tensor& y_act, const at::Tensor& dy_act, const c10::optional<at::Tensor>& weight,
const c10::optional<at::Tensor>& bias, float eps, Activation activation, float activation_param) {
-CHECK_NOT_HALF(y_act_);
-
-auto y_act = normalize_shape(y_act_);
-auto dy_act = normalize_shape(dy_act_);
+CHECK_NOT_HALF(y_act);

// Run templated implementation
return AT_DISPATCH_FLOATING_TYPES(y_act.scalar_type(), "backward_reduce_cpu", [&] {
@@ -130,9 +139,9 @@ std::tuple<at::Tensor, at::Tensor> backward_reduce_cpu(
});
}

-at::Tensor backward_cpu(const at::Tensor& xhat_, const at::Tensor& dy_, const at::Tensor& var, const at::Tensor& count,
-const at::Tensor& sum_dy, const at::Tensor& sum_xhat_dy,
-const c10::optional<at::Tensor>& weight, float eps) {
+void backward_cpu(const at::Tensor& xhat_, at::Tensor& dy_, const at::Tensor& var, const at::Tensor& count,
+const at::Tensor& sum_dy, const at::Tensor& sum_xhat_dy,
+const c10::optional<at::Tensor>& weight, float eps) {
CHECK_NOT_HALF(xhat_);

auto xhat = normalize_shape(xhat_);
@@ -141,7 +150,7 @@ at::Tensor backward_cpu(const at::Tensor& xhat_, const at::Tensor& dy_, const at
auto mean_xhat_dy = normalize_shape(sum_xhat_dy / count.to(sum_xhat_dy.options()));

auto mult = weight.has_value() ? (weight.value().abs() + eps) / (var + eps).sqrt() : 1 / (var + eps).sqrt();
-auto dx = normalize_shape(mult) * (dy - mean_dy - xhat * mean_xhat_dy);

-return dx.view(xhat_.sizes());
+// dy = (dy - mean_dy - xhat * mean_xhat_dy) * mult
+dy.sub_(mean_dy).sub_(xhat * mean_xhat_dy).mul_(normalize_shape(mult));
}
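
Written out, the in-place update computes the standard batch-norm input gradient, with `mult` folding the absolute-valued weight back in (or reducing to 1/sqrt(var + eps) when no weight is given); the result overwrites `dy`:

```latex
dx_{n,c,s} = \mathrm{mult}_c \bigl( dy_{n,c,s} - \overline{dy}_c - \hat{x}_{n,c,s}\, \overline{\hat{x}\,dy}_c \bigr),
\qquad
\mathrm{mult}_c = \frac{|\gamma_c| + \epsilon}{\sqrt{\sigma_c^{2} + \epsilon}},
\quad
\overline{dy}_c = \frac{\mathrm{sum\_dy}_c}{\mathrm{count}},
\quad
\overline{\hat{x}\,dy}_c = \frac{\mathrm{sum\_xhat\_dy}_c}{\mathrm{count}}.
```
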
(The diffs for the remaining two changed files are not shown above.)
