
Enable load-compute-store interleaving for unrolled elementwise kernel. #1886

Draft: wants to merge 1 commit into base: release/2.5

124 changes: 124 additions & 0 deletions aten/src/ATen/native/cuda/CUDALoops.cuh
@@ -103,6 +103,29 @@ __global__ void unrolled_elementwise_kernel(
elementwise_kernel_helper(f, policy);
}

template <
typename func_t,
typename array_t,
typename inp_calc_t,
typename out_calc_t,
typename loader_t,
typename storer_t>
C10_LAUNCH_BOUNDS_1(num_threads())
__global__ void unrolled_templated_elementwise_kernel(
int N,
func_t f,
array_t data,
inp_calc_t ic,
out_calc_t oc,
loader_t l,
storer_t s) {
int remaining = N - block_work_size() * blockIdx.x;
auto policy = memory::policies::
unroll<array_t, inp_calc_t, out_calc_t, loader_t, storer_t>(
data, remaining, ic, oc, l, s);
unrolled_templated_elementwise_kernel_helper(f, policy);
}

// this function assumes trivial 1d and no dynamic casting
template <typename func_t, typename array_t>
static inline void launch_vectorized_kernel(
@@ -170,6 +193,30 @@ static inline void launch_unrolled_kernel(
C10_CUDA_KERNEL_LAUNCH_CHECK();
}

template <
typename func_t,
typename array_t,
typename inp_calc_t,
typename out_calc_t,
typename loader_t,
typename storer_t>
static inline void launch_unrolled_templated_kernel(
int64_t N,
const func_t& f,
array_t data,
inp_calc_t ic,
out_calc_t oc,
loader_t l,
storer_t s) {
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
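// Each block covers block_work_size() elements; the grid is rounded up
// (ceiling division) so the whole range [0, N) is covered.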
int64_t grid = (N + block_work_size() - 1) / block_work_size();
auto stream = at::cuda::getCurrentCUDAStream();
unrolled_templated_elementwise_kernel<func_t, array_t>
<<<grid, num_threads(), 0, stream>>>(N, f, data, ic, oc, l, s);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}


template <int nt, int vt, typename func_t>
C10_LAUNCH_BOUNDS_2(nt, 4)
__global__ void elementwise_kernel(int N, func_t f) {
@@ -425,6 +472,44 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
#endif
}

namespace {
template<typename TupleLike, size_t arity, size_t arg_num=0>
struct check_types {
constexpr static inline bool check() {
bool current = false;
if constexpr (arity != 2) return false;
if constexpr (arg_num == 0) {
using SelectedType = std::tuple_element_t<arg_num, TupleLike>;
if constexpr (std::is_same_v<float, SelectedType>)
return check_types<TupleLike, arity, arg_num+1>::check();
} else if constexpr (arg_num == 1) {
using SelectedType2 = std::tuple_element_t<arg_num, TupleLike>;
if constexpr (std::is_same_v<float, SelectedType2>)
return check_types<TupleLike, arity, arg_num+1>::check();
}
return false;
}
};

// Base case: if we got this far, the argument types matched, except
// when there are no arguments (arity == 0).
template<typename TupleLike, size_t arity>
struct check_types<TupleLike, arity, arity> {
constexpr static inline bool check() {
if constexpr (arity != 0)
return true;
return false;
}
};

template<typename TupleLike>
struct check_types<TupleLike, 0, 0> {
constexpr static inline bool check() {
return false;
}
};
} // namespace anonymous
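
// Evaluation sketch for check_types (follows directly from the specializations
// above; the argument tuples are hypothetical examples):
//   check_types<std::tuple<float, float>, 2>::check()  -> true
//   check_types<std::tuple<float, double>, 2>::check() -> false (second arg is not float)
//   check_types<std::tuple<float>, 1>::check()         -> false (arity != 2)
//   check_types<std::tuple<>, 0>::check()              -> false (no arguments)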

template <typename func_t>
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
if (!needs_dynamic_casting<func_t>::check(iter)) {
@@ -449,6 +534,45 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {

if (contiguous) {
#ifdef USE_ROCM
// Attempt to call the specialized unrolled elementwise kernel
// that enables load-compute-store interleaving.
using float_map = c10::CppTypeToScalarType<float>;
using bfloat16_map = c10::CppTypeToScalarType<BFloat16>;
int64_t grid = (numel + block_work_size() - 1) / block_work_size();
// The templated path is only taken when the element count is an exact multiple of
// block_work_size(), i.e. every block does a full tile of work: this removes bounds
// checking and lets the loop unroll without basic blocks that would prevent interleaving.
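// Illustrative arithmetic (values hypothetical): with block_work_size() == 1024,
// numel == 1048576 gives grid == 1024 and numel % (block_work_size() * grid) == 0,
// so the templated kernel is used; numel == 1048579 gives grid == 1025 and a
// nonzero remainder, so the generic unrolled path below is used instead.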
if (iter.ninputs() == 2 &&
iter.input_dtype(0) == float_map::value &&
iter.input_dtype(1) == bfloat16_map::value &&
!(numel%(block_work_size()*grid))) {
// `if constexpr` reduces the number of (empty) kernel instantiations generated for
// the unrolled templated elementwise kernel and limits, at compile time,
// which functors are actually paired with the templated load and store.
using func_tuple = typename traits::ArgsTuple;
if constexpr (std::is_same_v<float,arg0_t> &&
traits::arity == 2 &&
check_types<func_tuple, traits::arity, 0>::check()) {
// Templated load/store for specific data types remove the need for a runtime
// switch over the input tensor type. This, together with the absence of
// bounds checks, enables interleaving of memory instructions with
// compute.
auto loader = memory::TemplatedLoad<float, float, BFloat16>();
auto storer = memory::TemplatedStore<float, float>();
auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
auto output_offset_calculator = TrivialOffsetCalculator<1>();
launch_unrolled_templated_kernel(
numel,
f,
data,
input_offset_calculator,
output_offset_calculator,
loader,
storer);
return;
}
}

at::detail::Array<ScalarType, ntensors> dtypes;
auto inner_strides = iter.get_inner_strides();
at::detail::Array<int, ntensors> strides;
23 changes: 23 additions & 0 deletions aten/src/ATen/native/cuda/Loops.cuh
@@ -66,6 +66,29 @@ __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
policy.store(results, idx);
}

template<typename func_t, typename policy_t>
__device__ inline void unrolled_templated_elementwise_kernel_helper(func_t f, policy_t policy) {
using traits = function_traits<func_t>;
using return_t = typename traits::result_type;
using args_t = typename traits::ArgsTuple;

int idx = blockIdx.x;

return_t results[thread_work_size()];
args_t args[thread_work_size()];

// load
policy.templatedLoad(args, idx);

// compute (no bounds checks here; they are done by the caller).
#pragma unroll
for (int i = 0; i < thread_work_size(); i++) {
results[i] = c10::guts::apply(f, args[i]);
}

// store
policy.templatedStore(results, idx);
}
}} // namespace at::native

#include <ATen/native/cuda/CUDALoops.cuh>
62 changes: 62 additions & 0 deletions aten/src/ATen/native/cuda/MemoryAccess.cuh
@@ -78,6 +78,19 @@ struct unroll_load_helper {
}
};

template<int arg_index>
struct unroll_load_helper_templated {
template <typename args_t, typename policy_t, typename offset_t, typename loader_t>
static __device__ void apply(policy_t &self, args_t *args, offset_t offset, loader_t loader, int j, int num_outputs) {
// Type instantiation has already been done on the host based on the arguments' runtime types.
// Here, the argument index is enough to retrieve the type from the loader's variadic template arguments.
// using arg_t = std::tuple_element_t<arg_index, args_t>;
// `data` holds the data_ptr for tensors [output, input0, input1, ...], so we
// need a `+ num_outputs` offset to get the inputs.
std::get<arg_index>(args[j]) = loader.template load<arg_index>(self.data[arg_index + num_outputs], offset[arg_index]);
}
};
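
// Note: detail::static_unroll<unroll_load_helper_templated, arity> (used by the
// unroll policy below) expands this helper into apply<0>, apply<1>, ... at
// compile time, one instantiation per functor argument.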

template <int current>
struct multi_outputs_store_helper {
template<int ntensors, int num_outputs, typename ...Args>
@@ -155,6 +168,28 @@ struct StoreWithCast {
}
};

template<typename CastToT, typename... CastFromTs>
struct TemplatedLoad {
template<int arg_index>
__device__ CastToT load(char *base_ptr, uint32_t offset) {
// Extract the arg_index-th input tensor element type from the
// variadic template arguments.
using CastFromT = std::tuple_element_t<arg_index,
std::tuple<CastFromTs...>>;
void *ptr = base_ptr + sizeof(CastFromT) * offset;
return c10::convert<CastToT>(c10::load<CastFromT>(ptr));
}
};

// This only supports a single output tensor.
template<typename CastTo, typename CastFrom>
struct TemplatedStore {
__device__ void store(CastFrom value, char *base_ptr, uint32_t offset, int arg=0) {
void *ptr = base_ptr + sizeof(CastTo) * offset;
*(CastTo*)ptr = c10::convert<CastTo>(value);
}
};
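
// Illustrative usage (mirrors the float/BFloat16 dispatch in CUDALoops.cuh;
// the variable names here are hypothetical):
//   auto loader = TemplatedLoad<float, float, BFloat16>();
//   auto storer = TemplatedStore<float, float>();
//   float a = loader.load<0>(data[1], offset); // input 0 is float: no conversion
//   float b = loader.load<1>(data[2], offset); // input 1 is BFloat16: converted to float
//   storer.store(a + b, data[0], offset);      // output is float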

// aligned vector generates vectorized load/store on CUDA
template<typename scalar_t, int vec_size>
struct alignas(sizeof(scalar_t) * vec_size) aligned_vector {
@@ -230,6 +265,33 @@ struct unroll {
thread_idx += num_threads();
}
}

// Load and store used for interleaving: no bounds checks (they are moved to the
// callers) to avoid extra basic blocks; use the templated version of load/store.
template<typename args_t>
__device__ inline void templatedLoad(args_t *args, int idx) {
constexpr int arity = std::tuple_size<args_t>::value;
int thread_idx = threadIdx.x;
#pragma unroll
for (int i = 0; i < thread_work_size(); i++) {
int linear_idx = thread_idx + block_work_size() * idx;
auto offset = input_offset_calculator.get(linear_idx);
detail::static_unroll<detail::unroll_load_helper_templated, arity>::with_args(*this, args, offset, loader, i, num_outputs);
thread_idx += num_threads();
}
}

template<typename scalar_t>
__device__ inline void templatedStore(scalar_t *from, int idx) {
int thread_idx = threadIdx.x;
#pragma unroll
for (int i = 0; i < thread_work_size(); i++) {
int linear_idx = thread_idx + block_work_size() * idx;
int offset = output_offset_calculator.get(linear_idx)[0];
storer.store(from[i], data[0], offset);
thread_idx += num_threads();
}
}
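
// Access-pattern sketch (numbers hypothetical): with num_threads() == 256 and
// thread_work_size() == 4, thread t of block idx touches linear indices
// block_work_size()*idx + t + {0, 256, 512, 768}, i.e. the same coalesced
// strided pattern as load()/store() above, only without the `remaining`
// bounds check.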
};

// Assumption: