Revert "[ROCm] Improvements to non-vectorized elementwise kernels (#1… #1944

Open
wants to merge 1 commit into base: release/2.5

143 changes: 1 addition & 142 deletions aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -185,44 +185,6 @@ __global__ void elementwise_kernel(int N, func_t f) {
  }
}

#ifdef USE_ROCM
template <int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void elementwise_kernel_manual_unroll(int N, func_t f) {
  int tid = threadIdx.x;
  int nv = nt * vt;
  int idx = nv * blockIdx.x + tid;
  if ((idx + nt*(vt-1)) < N) {
    f(idx, true);
  } else {
#pragma unroll
    for (int i = 0; i < vt; i++) {
      if (idx < N) {
        f(idx, false);
        idx += nt;
      }
    }
  }
}

template <int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void elementwise_kernel_strided(int N, func_t f) {
  int tid = threadIdx.x;
  int idx = nt * vt * blockIdx.x + tid;
  int step = nt * vt * gridDim.x;
  while (idx < N) {
#pragma unroll
    for (int i = 0; i < vt; i++) {
      if ((idx + nt * i) < N) {
        f(idx + nt * i);
      }
    }
    idx += step;
  }
}
#endif

template <int nt, int vt, typename func_t>
static void launch_legacy_kernel(int64_t N, const func_t& f) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
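For orientation (not part of the diff): the two ROCm-only kernels deleted in the hunk above follow familiar launch patterns. elementwise_kernel_manual_unroll takes a fast path when an entire tile of nt*vt elements is in bounds, and elementwise_kernel_strided covers the range with a fixed grid plus a grid-stride loop. Below is a minimal standalone sketch of that grid-stride pattern; the kernel and functor names are hypothetical, not taken from this file.

```cpp
// Sketch only: a generic grid-stride elementwise kernel (hypothetical name,
// not the kernel from this diff). Each thread starts at its global index and
// advances by the total thread count, so a fixed-size grid covers any N.
template <typename func_t>
__global__ void grid_stride_elementwise_sketch(int N, func_t f) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  int step = gridDim.x * blockDim.x;
  for (; idx < N; idx += step) {
    f(idx);  // apply the elementwise functor to one flat index
  }
}
```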
@@ -236,37 +198,6 @@ static void launch_legacy_kernel(int64_t N, const func_t& f) {
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

#ifdef USE_ROCM
template <int nt, int vt, typename func_t>
static void launch_legacy_kernel_manual_unroll(int64_t N, const func_t& f) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N == 0) {
    return;
  }
  dim3 block(nt);
  dim3 grid((N + block.x * vt - 1) / (block.x * vt));
  auto stream = at::cuda::getCurrentCUDAStream();
  elementwise_kernel_manual_unroll<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <int nt, int vt, typename func_t>
static void launch_legacy_kernel_strided(int64_t N, const func_t& f) {
  TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
  if (N == 0) {
    return;
  }
  dim3 block(nt);
  dim3 grid(8192);
  auto stream = at::cuda::getCurrentCUDAStream();
  int ub_idx = nt * vt;
  ub_idx = ub_idx * (grid.x - 1) + (block.x - 1);
  ub_idx = ub_idx + nt*vt;
  elementwise_kernel_strided<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
  C10_CUDA_KERNEL_LAUNCH_CHECK();
}
#endif

template <typename traits, typename func_t, typename index_t, size_t... INDEX>
C10_HOST_DEVICE typename traits::result_type invoke_impl(
    const func_t& f,
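The removed launchers above differ mainly in grid sizing: launch_legacy_kernel_manual_unroll uses ceil(N / (nt * vt)) blocks so every tile is covered exactly once, while launch_legacy_kernel_strided fixes the grid at 8192 blocks and lets the kernel's grid-stride loop absorb the remainder (its ub_idx computation is never passed to the kernel). A small host-side sketch of the ceiling-division grid math follows; the values of nt, vt and N are illustrative only.

```cpp
// Sketch of the ceil(N / (nt * vt)) grid computation used by the removed
// manual-unroll launcher; nt, vt and N below are example values, not taken
// from the diff.
#include <cstdio>

int main() {
  const long long nt = 128;       // threads per block
  const long long vt = 4;         // elements processed per thread
  const long long N  = 1000000;   // total elements
  const long long tile = nt * vt;                 // 512 elements per block
  const long long grid = (N + tile - 1) / tile;   // ceil(1000000 / 512) = 1954
  std::printf("tile=%lld grid=%lld blocks\n", tile, grid);
  return 0;
}
```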
@@ -345,84 +276,12 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
    return launch_vectorized_kernel(numel, f, data);
  }
  auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
#ifndef USE_ROCM
  constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 2 : 4;
  launch_legacy_kernel<128, unroll_factor>(numel, [=] GPU_LAMBDA(int idx) {
    auto offsets = offset_calc.get(idx);
    arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
    *out = invoke(f, &data.data[1], &offsets.data[1], 1);
  });
#else
  constexpr int unroll_factor = sizeof(arg0_t) >= 4 ? 4 : 8;
  constexpr int grp_sz = 128;
  launch_legacy_kernel_manual_unroll<grp_sz, unroll_factor>(numel, [=] GPU_LAMBDA(int idx, bool unrl4x) {
    if constexpr (unroll_factor == 4) {
      if (unrl4x) {
        auto offsets0 = offset_calc.get(idx);
        auto offsets1 = offset_calc.get(idx+grp_sz);
        auto offsets2 = offset_calc.get(idx+grp_sz*2);
        auto offsets3 = offset_calc.get(idx+grp_sz*3);
        arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]);
        arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]);
        arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]);
        arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]);
        auto tmp0 = invoke(f, &data.data[1], &offsets0.data[1], 1);
        auto tmp1 = invoke(f, &data.data[1], &offsets1.data[1], 1);
        auto tmp2 = invoke(f, &data.data[1], &offsets2.data[1], 1);
        auto tmp3 = invoke(f, &data.data[1], &offsets3.data[1], 1);
        *out0 = tmp0;
        *out1 = tmp1;
        *out2 = tmp2;
        *out3 = tmp3;
      }
      else {
        auto offsets = offset_calc.get(idx);
        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
        *out = invoke(f, &data.data[1], &offsets.data[1], 1);
      }
    } else {
      if (unrl4x) {
        auto offsets0 = offset_calc.get(idx);
        auto offsets1 = offset_calc.get(idx+grp_sz);
        auto offsets2 = offset_calc.get(idx+grp_sz*2);
        auto offsets3 = offset_calc.get(idx+grp_sz*3);
        auto offsets4 = offset_calc.get(idx+grp_sz*4);
        auto offsets5 = offset_calc.get(idx+grp_sz*5);
        auto offsets6 = offset_calc.get(idx+grp_sz*6);
        auto offsets7 = offset_calc.get(idx+grp_sz*7);
        arg0_t* out0 = (arg0_t*)(data[0] + offsets0[0]);
        arg0_t* out1 = (arg0_t*)(data[0] + offsets1[0]);
        arg0_t* out2 = (arg0_t*)(data[0] + offsets2[0]);
        arg0_t* out3 = (arg0_t*)(data[0] + offsets3[0]);
        arg0_t* out4 = (arg0_t*)(data[0] + offsets4[0]);
        arg0_t* out5 = (arg0_t*)(data[0] + offsets5[0]);
        arg0_t* out6 = (arg0_t*)(data[0] + offsets6[0]);
        arg0_t* out7 = (arg0_t*)(data[0] + offsets7[0]);
        auto tmp0 = invoke(f, &data.data[1], &offsets0.data[1], 1);
        auto tmp1 = invoke(f, &data.data[1], &offsets1.data[1], 1);
        auto tmp2 = invoke(f, &data.data[1], &offsets2.data[1], 1);
        auto tmp3 = invoke(f, &data.data[1], &offsets3.data[1], 1);
        auto tmp4 = invoke(f, &data.data[1], &offsets4.data[1], 1);
        auto tmp5 = invoke(f, &data.data[1], &offsets5.data[1], 1);
        auto tmp6 = invoke(f, &data.data[1], &offsets6.data[1], 1);
        auto tmp7 = invoke(f, &data.data[1], &offsets7.data[1], 1);
        *out0 = tmp0;
        *out1 = tmp1;
        *out2 = tmp2;
        *out3 = tmp3;
        *out4 = tmp4;
        *out5 = tmp5;
        *out6 = tmp6;
        *out7 = tmp7;
      }
      else {
        auto offsets = offset_calc.get(idx);
        arg0_t* out = (arg0_t*)(data[0] + offsets[0]);
        *out = invoke(f, &data.data[1], &offsets.data[1], 1);
      }
    }
  });
#endif
}

template <typename func_t>
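After the revert, the preprocessor guards in the hunk above disappear and the nocast fallback is the same on CUDA and ROCm: one offset_calc lookup and one invoke per index. The removed ROCm fast path instead batched the offset computations for a whole tile, at indices idx, idx+grp_sz, idx+grp_sz*2, and so on. The sketch below illustrates that index pattern on the host; it is a simplified, hypothetical illustration, not code from this file.

```cpp
// Sketch of the index pattern the removed ROCm fast path used (hypothetical,
// simplified): on the fast path a thread handles vt indices spaced grp_sz
// apart, so offsets can be computed up front and stores issued back to back.
#include <cstdio>

constexpr int kThreads = 128;   // grp_sz in the removed code
constexpr int kUnroll  = 4;     // unroll_factor for >= 4-byte outputs

// Illustrative: the kUnroll flat indices one thread handles on the fast path.
void tile_indices(int base_idx, int (&out)[kUnroll]) {
  for (int i = 0; i < kUnroll; ++i) {
    out[i] = base_idx + kThreads * i;   // idx, idx+grp_sz, idx+2*grp_sz, ...
  }
}

int main() {
  int idx[kUnroll];
  tile_indices(7, idx);                      // e.g. thread 7 of block 0
  for (int i : idx) std::printf("%d ", i);   // prints: 7 135 263 391
  std::printf("\n");
  return 0;
}
```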
@@ -456,7 +315,7 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
      dtypes[i] = iter.dtype(i);
      strides[i] = inner_strides[i];
    }
    launch_legacy_kernel_strided<512, 4>(numel, [=]GPU_LAMBDA(int idx) {
    launch_legacy_kernel<512, 1>(numel, [=]GPU_LAMBDA(int idx) {
      void* out = data[0] + strides[0] * idx;
      arg0_t result = invoke(f, &data.data[1], &strides.data[1], &dtypes.data[1], idx);
      c10::cast_and_store<arg0_t>(dtypes[0], out, result);
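The single addition in the diff is in the last hunk: the dynamic-cast contiguous path goes from launch_legacy_kernel_strided<512, 4> (a fixed grid of 8192 blocks, four elements per thread, remainder absorbed by the grid-stride loop) back to launch_legacy_kernel<512, 1>, which, with vt = 1, sizes the grid to roughly ceil(numel / 512) blocks with one element per thread. A rough comparison of the resulting grid sizes, using an example element count rather than anything from the diff:

```cpp
// Illustrative comparison of the two launch configurations in the last hunk.
// N is an example value; 8192 is the fixed grid of the removed strided
// launcher, and ceil(N / 512) approximates a one-element-per-thread launch
// with 512-thread blocks.
#include <cstdio>

int main() {
  const long long N = 10000000;                        // example element count
  const long long strided_grid = 8192;                 // fixed grid + stride loop
  const long long legacy_grid  = (N + 512 - 1) / 512;  // ceil(10000000 / 512) = 19532
  std::printf("strided=%lld legacy=%lld blocks\n", strided_grid, legacy_grid);
  return 0;
}
```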