Skip to content

[ROCm][WIP] Improve performance of casted elementwise add operations #1805

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion aten/src/ATen/native/cuda/CUDALoops.cuh
Original file line number Diff line number Diff line change
@@ -231,6 +231,36 @@ static void launch_legacy_kernel(int64_t N, const func_t& f) {
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void elementwise_kernel_strided(int N, func_t f) {
int tid = threadIdx.x;
int idx = nt * vt * blockIdx.x + tid;
int step = nt * vt * gridDim.x;
while (idx < N) {
#pragma unroll
for (int i = 0; i < vt; i++) {
if ((idx + nt * i) < N) {
f(idx + nt * i);
}
}
idx += step;
}
}

template <int nt, int vt, typename func_t>
static void launch_legacy_kernel_strided(int64_t N, const func_t& f) {
TORCH_INTERNAL_ASSERT(N >= 0 && N <= std::numeric_limits<int32_t>::max());
if (N == 0) {
return;
}
dim3 block(nt);
dim3 grid(8192);
auto stream = at::cuda::getCurrentCUDAStream();
elementwise_kernel_strided<nt, vt, func_t><<<grid, block, 0, stream>>>(N, f);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <typename traits, typename func_t, typename index_t, size_t... INDEX>
C10_HOST_DEVICE typename traits::result_type invoke_impl(
const func_t& f,
@@ -348,7 +378,7 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
dtypes[i] = iter.dtype(i);
strides[i] = inner_strides[i];
}
launch_legacy_kernel<512, 1>(numel, [=]GPU_LAMBDA(int idx) {
launch_legacy_kernel_strided<512, 4>(numel, [=]GPU_LAMBDA(int idx) {
void* out = data[0] + strides[0] * idx;
arg0_t result = invoke(f, &data[1], &strides[1], &dtypes[1], idx);
c10::cast_and_store<arg0_t>(dtypes[0], out, result);