Skip to content
Open
30 changes: 20 additions & 10 deletions src/ATen/native/xpu/sycl/GroupNormKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,22 +63,27 @@ template <typename T, int SIMD>
struct GNRowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
using T_ACC = acc_type_device<T, kXPU>;
using WelfordType = WelfordData<T_ACC, int64_t>;
using WelfordOp =
WelfordOpsXPU<T_ACC, T_ACC, int64_t, std::pair<T_ACC, T_ACC>>;
using WelfordOp = WelfordOps<T_ACC, T_ACC, int64_t, std::pair<T_ACC, T_ACC>>;

[[intel::reqd_sub_group_size(SIMD)]] void operator()(
sycl::nd_item<1> item) const {
const int64_t i = item.get_group(0);
WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false, item};
WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false};
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where is the nd_item object?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

WelfordOpXPU is replaced by WelfordOp, which does not need item as a pamameter.

WelfordType val(0, 0, 0, 0);
WelfordType identity_element(0, 0, 0, 0);
for (int64_t j = item.get_local_id(0); j < N_;
j += item.get_local_range(0)) {
const int64_t index = i * N_ + j;
val = welford_op.reduce(val, static_cast<T_ACC>(X_[index]), index);
}

val = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
item, val, welford_op, shared_);
if (item.get_local_range(0) <= SIMD) {
val = SubgroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
item, val, welford_op);
} else {
val = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
item, val, welford_op, identity_element, shared_);
}

if (item.get_local_id(0) == 0) {
T_ACC m1;
Expand Down Expand Up @@ -110,14 +115,14 @@ struct GNRowwiseMomentsVectorizedFunctor
: public __SYCL_KER_CONFIG_CONVENTION__ {
using T_ACC = acc_type_device<T, kXPU>;
using WelfordType = WelfordData<T_ACC, int64_t>;
using WelfordOp =
WelfordOpsXPU<T_ACC, T_ACC, int64_t, std::pair<T_ACC, T_ACC>>;
using WelfordOp = WelfordOps<T_ACC, T_ACC, int64_t, std::pair<T_ACC, T_ACC>>;
using vec_t = memory::aligned_vector<T, VEC_SIZE>;

[[intel::reqd_sub_group_size(SIMD)]] void operator()(
sycl::nd_item<1> item) const {
WelfordType val[VEC_SIZE];
WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false, item};
WelfordType identity_element(0, 0, 0, 0);
WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false};
auto group_start = item.get_group(0) * VEC_SIZE;

#pragma unroll
Expand All @@ -138,8 +143,13 @@ struct GNRowwiseMomentsVectorizedFunctor

#pragma unroll
for (int v = 0; v < VEC_SIZE; ++v) {
val[v] = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
item, val[v], welford_op, shared_);
if (item.get_local_range(0) <= SIMD) {
val[v] = SubgroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
item, val[v], welford_op);
} else {
val[v] = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
item, val[v], welford_op, identity_element, shared_);
}
}

if (item.get_local_id(0) == 0) {
Expand Down
7 changes: 6 additions & 1 deletion src/ATen/native/xpu/sycl/GroupReduceUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,10 @@ inline T& GroupReduceWithoutBroadcast(
sycl::nd_item<DIM>& item,
T& val,
const ReduceOp& op,
const T& identity_element,
shared_t shared) {
auto sg = item.get_sub_group();
int g_tid = item.get_local_linear_id();
int sg_tid = sg.get_local_linear_id();
int sg_id = sg.get_group_linear_id();
int n_sg = get_local_linear_range<DIM>(item) / SIMD;
Expand All @@ -151,10 +153,13 @@ inline T& GroupReduceWithoutBroadcast(
shared[sg_id] = val;
}
item.barrier(sycl_local_fence);
val = identity_element;

if (sg_id == 0) {
for (int i = 1; i < n_sg; i++) {
for (int i = sg_tid; i < n_sg; i += SIMD) {
val = op.combine(val, shared[i]);
}
val = SubgroupReduceWithoutBroadcast<T, ReduceOp, SIMD, DIM>(item, val, op);
}
return val;
}
Expand Down
3 changes: 2 additions & 1 deletion src/ATen/native/xpu/sycl/LayerNormKernels.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,14 +191,15 @@ struct RowwiseMomentsFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
const int64_t i = item_id.get_group(0);
WelfordOp welford_op = {/*correction=*/0, /*take_sqrt=*/false};
WelfordType val(0, 0, 0, 0);
WelfordType identity_element(0, 0, 0, 0);
for (int64_t j = item_id.get_local_id(0); j < N_;
j += item_id.get_local_range(0)) {
const int64_t index = i * N_ + j;
val = welford_op.reduce(val, static_cast<T_ACC>(X_[index]), index);
}

val = GroupReduceWithoutBroadcast<WelfordType, WelfordOp, SIMD>(
item_id, val, welford_op, shared_);
item_id, val, welford_op, identity_element, shared_);

if (item_id.get_local_id(0) == 0) {
T_ACC m1;
Expand Down
3 changes: 2 additions & 1 deletion src/ATen/native/xpu/sycl/TensorModeKernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ inline T reduceGroupWithNThreadLocalReductions(
T init) {
int offset = item.get_local_id(2) * N;
T local = offset < numVals ? threadVals[0] : init;
T identity_element = init;

#pragma unroll
for (int i = 1; i < N; ++i) {
Expand All @@ -226,7 +227,7 @@ inline T reduceGroupWithNThreadLocalReductions(
}

return GroupReduceWithoutBroadcast<T, ReduceOp, 32>(
item, local, reduceOp, smem);
item, local, reduceOp, identity_element, smem);
}

template <typename T, unsigned int Power2Size>
Expand Down