
Commit 5a6ddbc

maajidkhann authored and divya2108 committed
Extending the PyTorch vec backend for SVE (ARM) (pytorch#119571)

**Motivation:** In PyTorch, ATen vectorization supports multiple platforms, including x86 and Arm, as well as multiple data types. It provides a generic implementation of a vector (Vec) type that allows the programmer to write code packing various primitives (such as floats) into 256-bit and 512-bit registers. It can be extended to support other ISAs easily by adding more VecISA subclasses.

**Reference Link:** https://github.com/pytorch/pytorch/tree/main/aten/src/ATen/cpu/vec

**This PR:**
* Our goal with this contribution is to add an SVE backend for Vec in the ATen CPU vectorization, which benefits any Arm CPU that supports SVE.
* More about the SVE ISA for Arm: [https://developer.arm.com/Architectures/Scalable Vector Extensions](https://developer.arm.com/Architectures/Scalable%20Vector%20Extensions)
* We use the ARM C Language Extensions for SVE (https://developer.arm.com/documentation/102699/0100/Optimizing-with-intrinsics) to accelerate various operators in the SVE backend for Vec.
* Currently we add support only for the SVE ISA with a vector length of 256 bits (SVE256). In the future, we plan to extend this SVE support to other vector lengths as well.

Pull Request resolved: pytorch#119571
Approved by: https://github.com/malfet, https://github.com/snadampal
Co-authored-by: Divya Kotadiya <[email protected]>
1 parent: bad6904 · commit: 5a6ddbc

29 files changed: +2554 −9 lines
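For context, the Vec abstraction this PR extends lets the same loop body run on AVX2, AVX512, NEON, or (with this change) SVE256. Below is a minimal sketch of how client code uses that API; the function name scale_buffer is illustrative only and is not part of this commit.

// Minimal sketch of the generic Vec API (not part of this commit).
#include <ATen/cpu/vec/vec.h>

void scale_buffer(float* data, int64_t n, float factor) {
  using Vec = at::vec::Vectorized<float>;
  const Vec vfactor(factor);           // broadcast the scalar into all lanes
  int64_t i = 0;
  // Vec::size() is 8 under SVE256 (256 bits / 32-bit floats).
  for (; i + Vec::size() <= n; i += Vec::size()) {
    Vec v = Vec::loadu(data + i);      // unaligned load into a vector register
    (v * vfactor).store(data + i);     // elementwise multiply, store back
  }
  for (; i < n; ++i)                   // scalar tail for the remainder
    data[i] *= factor;
}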

aten/src/ATen/CMakeLists.txt (+1 −1)

@@ -54,7 +54,7 @@ if(NOT BUILD_LITE_INTERPRETER)
   endif()
   EXCLUDE(ATen_CORE_SRCS "${ATen_CORE_SRCS}" ${ATen_CORE_TEST_SRCS})

-file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h")
+file(GLOB base_h "*.h" "detail/*.h" "cpu/*.h" "cpu/vec/vec512/*.h" "cpu/vec/vec256/*.h" "cpu/vec/vec256/vsx/*.h" "cpu/vec/vec256/zarch/*.h" "cpu/vec/sve/*.h" "cpu/vec/*.h" "quantized/*.h" "functorch/*.h")
 file(GLOB base_cpp "*.cpp" "detail/*.cpp" "cpu/*.cpp" "functorch/*.cpp")
 file(GLOB cuda_h "cuda/*.h" "cuda/detail/*.h" "cuda/*.cuh" "cuda/detail/*.cuh" "cuda/tunable/*.cuh" "cuda/tunable/*.h")
 file(GLOB cuda_cpp "cuda/*.cpp" "cuda/detail/*.cpp" "cuda/tunable/*.cpp")

aten/src/ATen/Version.cpp (+5 −0)

@@ -105,6 +105,11 @@ std::string get_cpu_capability() {
       return "DEFAULT";
     case native::CPUCapability::ZVECTOR:
       return "Z VECTOR";
+#elif defined(HAVE_SVE_CPU_DEFINITION)
+    case native::CPUCapability::DEFAULT:
+      return "DEFAULT";
+    case native::CPUCapability::SVE256:
+      return "SVE256";
 #else
     case native::CPUCapability::DEFAULT:
       return "NO AVX";

aten/src/ATen/cpu/vec/functional_base.h (+1 −1)

@@ -78,7 +78,7 @@ struct VecReduceAllSIMD<float, Op> {
 #endif // defined(CPU_CAPABILITY_AVX512)
 #endif // defined(__GNUC__) && (__GNUC__ > 5) && !defined(_MSC_VER) && !defined(C10_MOBILE)

-#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__)
+#if defined(__aarch64__) && !defined(C10_MOBILE) && !defined(__CUDACC__) && !defined(CPU_CAPABILITY_SVE)
 template <typename Op>
 struct VecReduceAllSIMD<float, Op> {
   static inline float apply(const Op& vec_fun, const Vectorized<float>& acc_vec) {

aten/src/ATen/cpu/vec/intrinsics.h (+8 −0)

@@ -5,6 +5,10 @@
 #elif defined(__clang__) && (defined(__ARM_NEON__) || defined(__aarch64__))
 /* Clang-compatible compiler, targeting arm neon */
 #include <arm_neon.h>
+#if defined(__ARM_FEATURE_SVE)
+/* CLANG-compatible compiler, targeting ARM with SVE */
+#include <arm_sve.h>
+#endif
 #elif defined(_MSC_VER)
 /* Microsoft C/C++-compatible compiler */
 #include <intrin.h>
@@ -17,6 +21,10 @@
 #elif defined(__GNUC__) && (defined(__ARM_NEON__) || defined(__aarch64__))
 /* GCC-compatible compiler, targeting ARM with NEON */
 #include <arm_neon.h>
+#if defined(__ARM_FEATURE_SVE)
+/* GCC-compatible compiler, targeting ARM with SVE */
+#include <arm_sve.h>
+#endif
 #if defined (MISSING_ARM_VLD1)
 #include <ATen/cpu/vec/vec256/missing_vld1_neon.h>
 #elif defined (MISSING_ARM_VST1)
aten/src/ATen/cpu/vec/sve/sve_helper.h (new file, +63)

#pragma once

#include <ATen/cpu/vec/intrinsics.h>

#include <ATen/cpu/vec/vec_base.h>

#if defined(CPU_CAPABILITY_SVE)

// Define the data types of VLS (vector-length specific) vectors.
typedef svbool_t vls_pred_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svint8_t vls_int8_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svint16_t vls_int16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svint32_t vls_int32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svint64_t vls_int64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint8_t vls_uint8_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint16_t vls_uint16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint32_t vls_uint32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svuint64_t vls_uint64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svfloat16_t vls_float16_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svfloat32_t vls_float32_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));
typedef svfloat64_t vls_float64_t __attribute__((arm_sve_vector_bits(VECTOR_WIDTH * 8)));

#define ptrue svptrue_b8()
#define ZERO_S8 svdup_n_s8(0)
#define ZERO_S16 svdup_n_s16(0)
#define ZERO_S32 svdup_n_s32(0)
#define ZERO_S64 svdup_n_s64(0)
#define ZERO_U8 svdup_n_u8(0)
#define ZERO_U16 svdup_n_u16(0)
#define ZERO_U32 svdup_n_u32(0)
#define ZERO_U64 svdup_n_u64(0)
#define ZERO_F16 svdup_n_f16(0.f)
#define ZERO_F32 svdup_n_f32(0.f)
#define ZERO_F64 svdup_n_f64(0.0)
#define ONE_S8 svdup_n_s8(1)
#define ONE_S16 svdup_n_s16(1)
#define ONE_S32 svdup_n_s32(1)
#define ONE_S64 svdup_n_s64(1)
#define ONE_U8 svdup_n_u8(1)
#define ONE_U16 svdup_n_u16(1)
#define ONE_U32 svdup_n_u32(1)
#define ONE_U64 svdup_n_u64(1)
#define ONE_F16 svdup_n_f16(1.f)
#define ONE_F32 svdup_n_f32(1.f)
#define ONE_F64 svdup_n_f64(1.0)
#define ALL_S8_TRUE_MASK svdup_n_s8(0xff)
#define ALL_S8_FALSE_MASK svdup_n_s8(0x0)
#define ALL_S16_TRUE_MASK svdup_n_s16(0xffff)
#define ALL_S16_FALSE_MASK svdup_n_s16(0x0)
#define ALL_S32_TRUE_MASK svdup_n_s32(0xffffffff)
#define ALL_S32_FALSE_MASK svdup_n_s32(0x0)
#define ALL_S64_TRUE_MASK svdup_n_s64(0xffffffffffffffff)
#define ALL_S64_FALSE_MASK svdup_n_s64(0x0)
#define ALL_U8_TRUE_MASK svdup_n_u8(0x01)
#define ALL_U8_FALSE_MASK svdup_n_u8(0x00)
#define ALL_F16_TRUE_MASK svreinterpret_f16_s16(ALL_S16_TRUE_MASK)
#define ALL_F16_FALSE_MASK svreinterpret_f16_s16(ALL_S16_FALSE_MASK)
#define ALL_F32_TRUE_MASK svreinterpret_f32_s32(ALL_S32_TRUE_MASK)
#define ALL_F32_FALSE_MASK svreinterpret_f32_s32(ALL_S32_FALSE_MASK)
#define ALL_F64_TRUE_MASK svreinterpret_f64_s64(ALL_S64_TRUE_MASK)
#define ALL_F64_FALSE_MASK svreinterpret_f64_s64(ALL_S64_FALSE_MASK)

#endif // defined(CPU_CAPABILITY_SVE)
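The arm_sve_vector_bits attribute above is what makes this backend workable: plain SVE types such as svfloat32_t are sizeless and cannot be stored as class members, while the VLS aliases are pinned to VECTOR_WIDTH * 8 bits (256 here) at compile time, so Vectorized<T> can hold an SVE register by value. A standalone sketch of the idea, assuming compilation with -march=armv8-a+sve -msve-vector-bits=256; the type and struct names are illustrative, not from the commit:

#include <arm_sve.h>

// Fixed-length alias: unlike raw svfloat32_t, this has a known sizeof.
typedef svfloat32_t vls_f32 __attribute__((arm_sve_vector_bits(256)));

struct Vec256f {
  vls_f32 values;  // legal as a member only because the length is pinned
};

// VLS types convert implicitly to and from the sizeless SVE types,
// so they can be passed straight into ACLE intrinsics.
inline Vec256f vec_add(Vec256f a, Vec256f b) {
  return {svadd_f32_x(svptrue_b32(), a.values, b.values)};
}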
aten/src/ATen/cpu/vec/sve/vec_common_sve.h (new file, +176)

#pragma once

// DO NOT DEFINE STATIC DATA IN THIS HEADER!
// See Note [Do not compile initializers with SVE]

#include <ATen/cpu/vec/intrinsics.h>

#include <ATen/cpu/vec/vec_base.h>
#include <ATen/cpu/vec/sve/sve_helper.h>

#if defined(CPU_CAPABILITY_SVE)
#include <ATen/cpu/vec/sve/vec_float.h>
#include <ATen/cpu/vec/sve/vec_double.h>
#include <ATen/cpu/vec/sve/vec_int.h>
#include <ATen/cpu/vec/sve/vec_qint.h>
#endif

namespace at {
namespace vec {
// Note [CPU_CAPABILITY namespace]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// This header, and all of its subheaders, will be compiled with
// different architecture flags for each supported set of vector
// intrinsics. So we need to make sure they aren't inadvertently
// linked together. We do this by declaring objects in an `inline
// namespace` which changes the name mangling, but can still be
// accessed as `at::vec`.
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_SVE)

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CAST ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<>
inline Vectorized<float> cast<float, double>(const Vectorized<double>& src) {
  return svreinterpret_f32_f64(src);
}

template<>
inline Vectorized<double> cast<double, float>(const Vectorized<float>& src) {
  return svreinterpret_f64_f32(src);
}

#define DEFINE_FLOAT_INT_CAST(int_t, int_bit, float_t, float_bit)               \
template<>                                                                      \
inline Vectorized<int_t> cast<int_t, float_t>(const Vectorized<float_t>& src) { \
  return svreinterpret_s##int_bit##_f##float_bit(src);                          \
}                                                                               \
template<>                                                                      \
inline Vectorized<float_t> cast<float_t, int_t>(const Vectorized<int_t>& src) { \
  return svreinterpret_f##float_bit##_s##int_bit(src);                          \
}

DEFINE_FLOAT_INT_CAST(int64_t, 64, double, 64)
DEFINE_FLOAT_INT_CAST(int32_t, 32, double, 64)
DEFINE_FLOAT_INT_CAST(int16_t, 16, double, 64)
DEFINE_FLOAT_INT_CAST(int64_t, 64, float, 32)
DEFINE_FLOAT_INT_CAST(int32_t, 32, float, 32)
DEFINE_FLOAT_INT_CAST(int16_t, 16, float, 32)

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline gather(const double* base_addr, const Vectorized<int64_t>& vindex_) {
  svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3);
  return svld1_gather_s64index_f64(ptrue, base_addr, vindex);
}

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline gather(const float* base_addr, const Vectorized<int32_t>& vindex_) {
  svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2);
  return svld1_gather_s32index_f32(ptrue, base_addr, vindex);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ MASK GATHER ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<double>>
inline mask_gather(const Vectorized<double>& src, const double* base_addr,
                   const Vectorized<int64_t>& vindex_, const Vectorized<double>& mask_) {
  svbool_t mask = svcmpeq_s64(ptrue, svreinterpret_s64_f64(mask_),
                              ALL_S64_TRUE_MASK);
  svint64_t vindex = svasrd_n_s64_x(ptrue, svmul_s64_x(ptrue, vindex_, svdup_n_s64(scale)), 3);
  return svsel_f64(mask, svld1_gather_s64index_f64(mask, base_addr, vindex), src);
}

template<int64_t scale = 1>
std::enable_if_t<scale == 1 || scale == 2 || scale == 4 || scale == 8, Vectorized<float>>
inline mask_gather(const Vectorized<float>& src, const float* base_addr,
                   const Vectorized<int32_t>& vindex_, const Vectorized<float>& mask_) {
  svbool_t mask = svcmpeq_s32(ptrue, svreinterpret_s32_f32(mask_),
                              ALL_S32_TRUE_MASK);
  svint32_t vindex = svasrd_n_s32_x(ptrue, svmul_s32_x(ptrue, vindex_, svdup_n_s32(scale)), 2);
  return svsel_f32(mask, svld1_gather_s32index_f32(mask, base_addr, vindex), src);
}
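// Note on the gathers above: vindex_ carries indices in units of `scale`
// bytes, matching the x86 gather semantics (address = base + index * scale).
// Multiplying by `scale` yields a byte offset, and the arithmetic shift
// right (by 3 for 8-byte doubles, by 2 for 4-byte floats) converts that
// byte offset back into an element index, which the svld1_gather_*index_*
// intrinsics then rescale by the element size internally.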
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ CONVERT ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

// Only works for inputs in the range: [-2^51, 2^51]
// From: https://stackoverflow.com/a/41148578
template<>
Vectorized<int64_t>
inline convert_to_int_of_same_size<double>(const Vectorized<double> &src) {
  svfloat64_t x = svadd_f64_x(ptrue, src, svdup_n_f64(0x0018000000000000));
  return svsub_s64_x(ptrue,
                     svreinterpret_s64_f64(x),
                     svreinterpret_s64_f64(svdup_n_f64(0x0018000000000000)));
}

template<>
Vectorized<int32_t>
inline convert_to_int_of_same_size<float>(const Vectorized<float> &src) {
  return svcvt_s32_f32_x(ptrue, src);
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ INTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline interleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3}
  //   b = {b0, b1, b2, b3}
  // group cols crossing lanes:
  //   return {a0, b0, a1, b1}
  //          {a2, b2, a3, b3}
  return std::make_pair(Vectorized<double>(svzip1_f64(a, b)),
                        Vectorized<double>(svzip2_f64(a, b)));
}

template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline interleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, a1, a2, a3, a4, a5, a6, a7}
  //   b = {b0, b1, b2, b3, b4, b5, b6, b7}
  // group cols crossing lanes:
  //   return {a0, b0, a1, b1, a2, b2, a3, b3}
  //          {a4, b4, a5, b5, a6, b6, a7, b7}
  return std::make_pair(Vectorized<float>(svzip1_f32(a, b)),
                        Vectorized<float>(svzip2_f32(a, b)));
}

// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ DEINTERLEAVE ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

template <>
std::pair<Vectorized<double>, Vectorized<double>>
inline deinterleave2<double>(const Vectorized<double>& a, const Vectorized<double>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1}
  //   b = {a2, b2, a3, b3}
  // swap lanes:
  //   return {a0, a1, a2, a3}
  //          {b0, b1, b2, b3}
  return std::make_pair(Vectorized<double>(svuzp1_f64(a, b)),
                        Vectorized<double>(svuzp2_f64(a, b)));
}

template <>
std::pair<Vectorized<float>, Vectorized<float>>
inline deinterleave2<float>(const Vectorized<float>& a, const Vectorized<float>& b) {
  // inputs:
  //   a = {a0, b0, a1, b1, a2, b2, a3, b3}
  //   b = {a4, b4, a5, b5, a6, b6, a7, b7}
  // swap lanes:
  //   return {a0, a1, a2, a3, a4, a5, a6, a7}
  //          {b0, b1, b2, b3, b4, b5, b6, b7}
  return std::make_pair(Vectorized<float>(svuzp1_f32(a, b)),
                        Vectorized<float>(svuzp2_f32(a, b)));
}

#endif // defined(CPU_CAPABILITY_SVE)

}}}
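To see the interleave/deinterleave contract above in action, here is a small hedged sketch that round-trips two float vectors; it assumes an SVE256 build (so Vectorized<float> has 8 lanes), and the demo function name is illustrative:

#include <ATen/cpu/vec/vec.h>
#include <utility>

void interleave_roundtrip_demo() {
  using Vec = at::vec::Vectorized<float>;
  alignas(64) float a[8] = {0, 1, 2, 3, 4, 5, 6, 7};
  alignas(64) float b[8] = {10, 11, 12, 13, 14, 15, 16, 17};

  auto [lo, hi] = at::vec::interleave2(Vec::loadu(a), Vec::loadu(b));
  // lo = {0, 10, 1, 11, 2, 12, 3, 13}, hi = {4, 14, 5, 15, 6, 16, 7, 17}

  auto [a2, b2] = at::vec::deinterleave2(lo, hi);
  // a2 and b2 recover the original contents of a and b.
  a2.store(a);
  b2.store(b);
}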
