Skip to content

Commit 63b180b

Browse files
amathews-amd authored and facebook-github-bot committed
ROCm MIOpen NHWC Convolution support (pytorch#63617)
Summary:
- Added 2D-Convolution NHWC support
  - on ROCm 4.3, with `PYTORCH_MIOPEN_SUGGEST_NHWC=1` flag
  - May need to force MIOpen to search for solutions (see examples below for flags)

**PYTORCH_MIOPEN_SUGGEST_NHWC Environment Flag**

MIOpen does not officially support NHWC yet, although convolution support has been added to tip-of-tree of MIOpen. This flag is intended to be a short-lived flag to explicitly turn on NHWC support until ROCm officially supports NHWC and performance is verified.

**Examples**

1. Example usage 1: Run test on ROCm 4.3

   `PYTORCH_TEST_WITH_ROCM=1 PYTORCH_MIOPEN_SUGGEST_NHWC=1 MIOPEN_FIND_ENFORCE=4 MIOPEN_DEBUG_CONV_GEMM=0 MIOPEN_FIND_MODE=1 pytest test_nn.py -v -k "test_conv_cudnn_nhwc"`

2. Example usage 2: Run the following with `PYTORCH_MIOPEN_SUGGEST_NHWC=1` on ROCm 4.3.

```
#!/usr/bin/env python3
import torch
model = torch.nn.Conv2d(8, 4, 3).cuda().half()
model = model.to(memory_format=torch.channels_last)
input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, requires_grad=True)
input = input.to(device="cuda", memory_format=torch.channels_last, dtype=torch.float16)

# should print True for is_contiguous(channels_last), and strides must match NHWC format
print(input.is_contiguous(memory_format=torch.channels_last), input.shape, input.stride() )

out = model(input)

# should print True for is_contiguous(channels_last), and strides must match NHWC format
print("Contiguous channel last :", out.is_contiguous(memory_format=torch.channels_last), " out shape :", out.shape, "out stride :", out.stride() )
```

See https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html for more examples.

cc jeffdaily sunway513 jithunnair-amd ROCmSupport

Pull Request resolved: pytorch#63617

Reviewed By: saketh-are

Differential Revision: D30730800

Pulled By: ezyang

fbshipit-source-id: 61906a0f30be8299e6547d312ae6ac91cc7c3238
1 parent 2a81e8b commit 63b180b

File tree

9 files changed

+287
-71
lines changed

9 files changed

+287
-71
lines changed

aten/src/ATen/miopen/Descriptors.cpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -90,17 +90,17 @@ std::ostream& operator<<(std::ostream & out, const TensorDescriptor& d) {
9090

9191
void TensorDescriptor::print() { std::cout << *this; }
9292

93-
void FilterDescriptor::set(const at::Tensor &t, int64_t pad) {
93+
void FilterDescriptor::set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad) {
9494
auto dim = t.ndimension();
9595
if (dim > MIOPEN_DIM_MAX || pad > MIOPEN_DIM_MAX)
9696
#define _STR(X) #X
9797
#define STR(X) _STR(X)
9898
throw std::runtime_error("MIOpen supports only up to " STR(MIOPEN_DIM_MAX) " dimensions");
9999
#undef _STR
100100
#undef STR
101-
if (!t.is_contiguous()) {
102-
throw std::runtime_error("MIOpen filters (a.k.a. weights) must be contiguous");
103-
}
101+
TORCH_CHECK(t.is_contiguous(memory_format),
102+
"MIOpen filters (a.k.a. weights) must be contiguous");
103+
104104
int size[MIOPEN_DIM_MAX];
105105
int stride[MIOPEN_DIM_MAX];
106106
for (int i = 0; i < dim; ++i) {
@@ -109,9 +109,15 @@ void FilterDescriptor::set(const at::Tensor &t, int64_t pad) {
109109
for (int i = dim; i < pad; ++i) {
110110
size[i] = (int) 1;
111111
}
112-
for (int i = dim - 1; i >=0; --i) {
113-
stride[i] = (i == dim - 1) ? 1 : stride[i+1] * size[i+1];
112+
113+
// Give every padded trailing dimension (indices dim .. pad-1, whose size was
// set to 1 above) a stride of 1. The loop must start at pad - 1: the descriptor
// uses max(dim, pad) dimensions, so index `pad` is one past the last valid slot,
// and when pad == MIOPEN_DIM_MAX (which the guard at the top of this function
// allows) writing stride[pad] would overflow the stride[MIOPEN_DIM_MAX] array.
for (int i = pad - 1; i >= dim; --i ) {
114+
stride[i] = 1;
114115
}
116+
for (int i = dim-1 ; i >=0; --i ) {
117+
// Pass-through
118+
stride[i] = t.stride(i);
119+
}
120+
115121
dim = std::max(dim, pad);
116122
set(getDataType(t), (int) dim, size, stride);
117123
}

aten/src/ATen/miopen/Descriptors.h

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,6 @@ inline int dataSize(miopenDataType_t dataType)
1818
}
1919
}
2020

21-
// This function modifies 'stride' in place so that the stride for
22-
// dim i is the product of the sizes of dims i+1 to the end.
23-
static inline void fixSizeOneDimStride(int dim, const int *size, int *stride) {
24-
int64_t z = 1;
25-
for(int d = dim-1; d >= 0; d--)
26-
{
27-
if (size[d] == 1) {
28-
stride[d] = z;
29-
} else {
30-
z *= size[d];
31-
}
32-
}
33-
}
34-
3521
template <typename T, miopenStatus_t (*dtor)(T*)>
3622
struct DescriptorDeleter {
3723
void operator()(T* x) {
@@ -96,7 +82,6 @@ class TensorDescriptor
9682

9783
private:
9884
void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
99-
fixSizeOneDimStride(dim, size, stride);
10085
MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
10186
}
10287
};
@@ -108,12 +93,15 @@ class FilterDescriptor
10893
&miopenCreateTensorDescriptor,
10994
&miopenDestroyTensorDescriptor>
11095
{
111-
public:
112-
void set(const at::Tensor &t, int64_t pad = 0);
96+
public:
97+
void set(const at::Tensor &t, int64_t pad = 0) {
98+
set(t, at::MemoryFormat::Contiguous, pad);
99+
}
100+
101+
void set(const at::Tensor &t, const at::MemoryFormat memory_format, int64_t pad = 0);
113102

114103
private:
115104
void set(miopenDataType_t dataType, int dim, int* size, int* stride) {
116-
fixSizeOneDimStride(dim, size, stride);
117105
MIOPEN_CHECK(miopenSetTensorDescriptor(mut_desc(), dataType, dim, size, stride));
118106
}
119107
};

aten/src/ATen/native/ConvUtils.h

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22
#include <ATen/detail/CUDAHooksInterface.h>
3+
#include <c10/util/env.h>
34

45
namespace at { namespace native {
56

@@ -106,4 +107,33 @@ static inline bool cudnn_conv_use_channels_last(const at::Tensor& input, const a
106107
return can_use_cudnn_channels_last_2d || can_use_cudnn_channels_last_3d;
107108
}
108109

110+
// Decides whether a convolution on `input`/`weight` should use the
// channels-last layout on the MIOpen backend.
//
// Returns false when the CUDA hooks report that the build was not compiled
// with MIOpen, or when either tensor has scalar type double. Otherwise the
// result is true only if the 2D channels-last condition below holds; the 3D
// flag is hard-coded to false, so this function never selects channels-last
// on the strength of a 3D layout.
static inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Tensor& weight) {
111+
112+
// disable NHWC for float64 input.
// (The same early exit also covers builds without MIOpen.)
113+
if (!at::detail::getCUDAHooks().compiledWithMIOpen() ||
114+
input.scalar_type() == at::kDouble ||
115+
weight.scalar_type() == at::kDouble) {
116+
return false;
117+
}
118+
119+
bool can_use_miopen_channels_last_2d = false;
120+
// The 2D path is compiled in only for ROCm builds with ROCM_VERSION >= 40300;
// on any other build the flag above stays false.
#if defined(USE_ROCM) && (ROCM_VERSION >= 40300)
121+
// TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
122+
// See #64427
// Function-local static: the environment lookup is performed once, when this
// statement is first executed, and the result is reused by later calls.
123+
static c10::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");
124+
125+
auto input_memory_format = input.suggest_memory_format();
126+
auto weight_memory_format = weight.suggest_memory_format();
127+
128+
// Opt-in only: the optional must hold a value, that value must be true, and
// at least one of the two tensors must suggest ChannelsLast.
can_use_miopen_channels_last_2d = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC && (
129+
( (input_memory_format == at::MemoryFormat::ChannelsLast) ||
130+
(weight_memory_format == at::MemoryFormat::ChannelsLast) )
131+
);
132+
#endif
133+
134+
// 3D channels-last is never enabled by this function; the flag is a constant false.
bool can_use_miopen_channels_last_3d = false;
135+
136+
return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
137+
}
138+
109139
}} // namespace at::native

aten/src/ATen/native/Convolution.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -838,9 +838,14 @@ at::Tensor _convolution(
838838
weight = view4d(weight);
839839
}
840840

841-
at::MemoryFormat cudnn_memory_format = at::MemoryFormat::Contiguous;
842-
if (cudnn_conv_use_channels_last(input, weight)) {
843-
cudnn_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
841+
at::MemoryFormat backend_memory_format = at::MemoryFormat::Contiguous;
842+
843+
if (detail::getCUDAHooks().compiledWithCuDNN() && cudnn_conv_use_channels_last(input, weight)) {
844+
backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
845+
}
846+
847+
if (detail::getCUDAHooks().compiledWithMIOpen() && miopen_conv_use_channels_last(input, weight)) {
848+
backend_memory_format = (k == 5) ? at::MemoryFormat::Contiguous /*at::MemoryFormat::ChannelsLast3d*/ : at::MemoryFormat::ChannelsLast;
844849
}
845850

846851
Tensor output;
@@ -853,15 +858,15 @@ at::Tensor _convolution(
853858
auto dilation = params.dilation;
854859
if (params.use_cudnn_depthwise(input, weight)) {
855860
output = at::cudnn_convolution(
856-
input.contiguous(cudnn_memory_format), weight,
861+
input.contiguous(backend_memory_format), weight,
857862
padding, stride, dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
858863
if (bias.defined()) {
859864
output.add_(reshape_bias(input.dim(), bias));
860865
}
861866

862867
} else if (params.use_miopen(input, weight, bias.defined())){
863868
output = at::miopen_depthwise_convolution(
864-
input.contiguous(), weight, bias,
869+
input.contiguous(backend_memory_format), weight, bias,
865870
padding, stride, dilation, params.groups, params.benchmark, params.deterministic);
866871
} else {
867872
if (input.ndimension() == 4) {
@@ -882,14 +887,14 @@ at::Tensor _convolution(
882887

883888
if (params.transposed) {
884889
output = at::cudnn_convolution_transpose(
885-
input.contiguous(cudnn_memory_format), weight,
890+
input.contiguous(backend_memory_format), weight,
886891
params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
887892
if (bias.defined()) {
888893
output.add_(reshape_bias(input.dim(), bias));
889894
}
890895
} else {
891896
output = at::cudnn_convolution(
892-
input.contiguous(cudnn_memory_format), weight,
897+
input.contiguous(backend_memory_format), weight,
893898
params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic, params.allow_tf32);
894899
if (bias.defined()) {
895900
output.add_(reshape_bias(input.dim(), bias));
@@ -905,11 +910,11 @@ at::Tensor _convolution(
905910

906911
if (params.transposed) {
907912
output = at::miopen_convolution_transpose(
908-
input.contiguous(), weight, bias,
913+
input.contiguous(backend_memory_format), weight, bias,
909914
params.padding, params.output_padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
910915
} else {
911916
output = at::miopen_convolution(
912-
input.contiguous(), weight, bias,
917+
input.contiguous(backend_memory_format), weight, bias,
913918
params.padding, params.stride, params.dilation, params.groups, params.benchmark, params.deterministic);
914919
}
915920
} else if (params.use_mkldnn(input, weight)) {

0 commit comments

Comments
 (0)