|
| 1 | +#pragma once |
| 2 | + |
| 3 | +// DO NOT DEFINE STATIC DATA IN THIS HEADER! |
| 4 | +// See Note [Do not compile initializers with SVE] |
| 5 | + |
| 6 | +#include <ATen/cpu/vec/intrinsics.h> |
| 7 | + |
| 8 | +#include <ATen/cpu/vec/vec_base.h> |
| 9 | +#include <ATen/cpu/vec/sve/sve_helper.h> |
| 10 | + |
| 11 | +#if defined(CPU_CAPABILITY_SVE) |
| 12 | +#include <ATen/cpu/vec/sve/vec_float.h> |
| 13 | +#include <ATen/cpu/vec/sve/vec_double.h> |
| 14 | +#include <ATen/cpu/vec/sve/vec_int.h> |
| 15 | +#include <ATen/cpu/vec/sve/vec_qint.h> |
| 16 | +#endif |
| 17 | + |
| 18 | +namespace at { |
| 19 | +namespace vec { |
| 20 | +// Note [CPU_CAPABILITY namespace] |
| 21 | +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 22 | +// This header, and all of its subheaders, will be compiled with |
| 23 | +// different architecture flags for each supported set of vector |
| 24 | +// intrinsics. So we need to make sure they aren't inadvertently |
| 25 | +// linked together. We do this by declaring objects in an `inline |
| 26 | +// namespace` which changes the name mangling, but can still be |
| 27 | +// accessed as `at::vec`. |
| 28 | +inline namespace CPU_CAPABILITY { |
| 29 | + |
| 30 | +#if defined(CPU_CAPABILITY_SVE) |
| 31 | + |
| 32 | +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 33 | + |
| 34 | +template<> |
| 35 | +inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) { |
| 36 | + return svreinterpret_f32_f64(src); |
| 37 | +} |
| 38 | + |
| 39 | +template<> |
| 40 | +inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) { |
| 41 | + return svreinterpret_f64_f32(src); |
| 42 | +} |
| 43 | + |
| 44 | +#define DEFINE_FLOAT_INT_CAST(int_t, int_bit, float_t, float_bit) \ |
| 45 | +template<> \ |
| 46 | +inline Vectorized<int_t> cast<int_t, float_t>(const Vectorized<float_t>& src) { \ |
| 47 | + return svreinterpret_s##int_bit##_f##float_bit(src); \ |
| 48 | +} \ |
| 49 | +template<> \ |
| 50 | +inline Vectorized<float_t> cast<float_t, int_t>(const Vectorized<int_t>& src) { \ |
| 51 | + return svreinterpret_f##float_bit##_s##int_bit(src); \ |
| 52 | +} |
| 53 | + |
| 54 | +DEFINE_FLOAT_INT_CAST(int64_t, 64, double, 64) |
| 55 | +DEFINE_FLOAT_INT_CAST(int32_t, 32, double, 64) |
| 56 | +DEFINE_FLOAT_INT_CAST(int16_t, 16, double, 64) |
| 57 | +DEFINE_FLOAT_INT_CAST(int64_t, 64, float, 32) |
| 58 | +DEFINE_FLOAT_INT_CAST(int32_t, 32, float, 32) |
| 59 | +DEFINE_FLOAT_INT_CAST(int16_t, 16, float, 32) |
| 60 | + |
| 61 | +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 62 | + |
| 63 | +template<int64_t scale = 1> |
| 64 | +std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>> |
| 65 | +inline gather(const double* base_addr, const Vectorized<int64_t>& vindex_) { |
| 66 | + svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); |
| 67 | + return svld1_gather_s64index_f64(ptrue, base_addr, vindex); |
| 68 | +} |
| 69 | + |
| 70 | +template<int64_t scale = 1> |
| 71 | +std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>> |
| 72 | +inline gather(const float* base_addr, const Vectorized<int32_t>& vindex_) { |
| 73 | + svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); |
| 74 | + return svld1_gather_s32index_f32(ptrue, base_addr, vindex); |
| 75 | +} |
| 76 | + |
| 77 | +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 78 | + |
| 79 | +template<int64_t scale = 1> |
| 80 | +std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>> |
| 81 | +inline mask_gather(const Vectorized<double>& src, const double* base_addr, |
| 82 | + const Vectorized<int64_t>& vindex_, const Vectorized<double>& mask_) { |
| 83 | + svbool_t mask = svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_), |
| 84 | + ALL_S64_TRUE_MASK); |
| 85 | + svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3); |
| 86 | + return svsel_f64(mask, svld1_gather_s64index_f64(mask, base_addr, vindex), src); |
| 87 | +} |
| 88 | + |
| 89 | +template<int64_t scale = 1> |
| 90 | +std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>> |
| 91 | +inline mask_gather(const Vectorized<float>& src, const float* base_addr, |
| 92 | + const Vectorized<int32_t>& vindex_, const Vectorized<float>& mask_) { |
| 93 | + svbool_t mask = svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_), |
| 94 | + ALL_S32_TRUE_MASK); |
| 95 | + svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2); |
| 96 | + return svsel_f32(mask, svld1_gather_s32index_f32(mask, base_addr, vindex), src); |
| 97 | +} |
| 98 | + |
| 99 | +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 100 | + |
| 101 | +// Only works for inputs in the range: [-2^51, 2^51] |
| 102 | +// From: https://stackoverflow.com/a/41148578 |
| 103 | +template<> |
| 104 | +Vectorized<int64_t> |
| 105 | +inline convert_to_int_of_same_size<double>(const Vectorized<double> &src) { |
| 106 | + svfloat64_t x = svadd_f64_x(ptrue, src, svdup_n_f64(0x0018000000000000)); |
| 107 | + return svsub_s64_x(ptrue, |
| 108 | + svreinterpret_s64_f64(x), |
| 109 | + svreinterpret_s64_f64(svdup_n_f64(0x0018000000000000))); |
| 110 | +} |
| 111 | + |
| 112 | +template<> |
| 113 | +Vectorized<int32_t> |
| 114 | +inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) { |
| 115 | + return svcvt_s32_f32_x(ptrue, src); |
| 116 | +} |
| 117 | + |
| 118 | +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 119 | + |
| 120 | +template <> |
| 121 | +std::pair<Vectorized<double>, Vectorized<double>> |
| 122 | +inline interleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) { |
| 123 | + // inputs: |
| 124 | + // a = {a0, a1, a3, a3} |
| 125 | + // b = {b0, b1, b2, b3} |
| 126 | + // group cols crossing lanes: |
| 127 | + // return {a0, b0, a1, b1} |
| 128 | + // {a2, b2, a3, b3} |
| 129 | + return std::make_pair(Vectorized<double>(svzip1_f64(a, b)), |
| 130 | + Vectorized<double>(svzip2_f64(a, b))); |
| 131 | +} |
| 132 | + |
| 133 | +template <> |
| 134 | +std::pair<Vectorized<float>, Vectorized<float>> |
| 135 | +inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) { |
| 136 | + // inputs: |
| 137 | + // a = {a0, a1, a2, a3, a4, a5, a6, a7} |
| 138 | + // b = {b0, b1, b2, b3, b4, b5, b6, b7} |
| 139 | + // group cols crossing lanes: |
| 140 | + // return {a0, b0, a1, b1, a2, b2, a3, b3} |
| 141 | + // {a4, b4, a5, b5, a6, b6, a7, b7} |
| 142 | + return std::make_pair(Vectorized<float>(svzip1_f32(a, b)), |
| 143 | + Vectorized<float>(svzip2_f32(a, b))); |
| 144 | +} |
| 145 | + |
| 146 | +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |
| 147 | + |
| 148 | +template <> |
| 149 | +std::pair<Vectorized<double>, Vectorized<double>> |
| 150 | +inline deinterleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) { |
| 151 | + // inputs: |
| 152 | + // a = {a0, b0, a1, b1} |
| 153 | + // b = {a2, b2, a3, b3} |
| 154 | + // swap lanes: |
| 155 | + // return {a0, a1, a2, a3} |
| 156 | + // {b0, b1, b2, b3} |
| 157 | + return std::make_pair(Vectorized<double>(svuzp1_f64(a, b)), |
| 158 | + Vectorized<double>(svuzp2_f64(a, b))); |
| 159 | +} |
| 160 | + |
| 161 | +template <> |
| 162 | +std::pair<Vectorized<float>, Vectorized<float>> |
| 163 | +inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) { |
| 164 | + // inputs: |
| 165 | + // a = {a0, b0, a1, b1, a2, b2, a3, b3} |
| 166 | + // b = {a4, b4, a5, b5, a6, b6, a7, b7} |
| 167 | + // swap lanes: |
| 168 | + // return {a0, a1, a2, a3, a4, a5, a6, a7} |
| 169 | + // {b0, b1, b2, b3, b4, b5, b6, b7} |
| 170 | + return std::make_pair(Vectorized<float>(svuzp1_f32(a, b)), |
| 171 | + Vectorized<float>(svuzp2_f32(a, b))); |
| 172 | +} |
| 173 | + |
| 174 | +#endif // defined(CPU_CAPABILITY_SVE) |
| 175 | + |
| 176 | +}}} |
0 commit comments