Skip to content
This repository was archived by the owner on Apr 28, 2025. It is now read-only.

Commit f851cb5

Browse files
committed
Add assembly version of simple operations on aarch64
Replace the `core::arch` versions of the following with handwritten assembly, which avoids recursion issues (cg_gcc using `rint` as a fallback) as well as problems with `aarch64be`:

* `rint`
* `rintf`

Additionally, add assembly versions of the following:

* `fma`
* `fmaf`
* `sqrt`
* `sqrtf`

If the `fp16` target feature is available, which implies `neon`, also include the following:

* `rintf16`
* `sqrtf16`

`sqrt` is added to match the implementation for `x86`. `fma` is included since it is used by many other routines.

There are a handful of other operations that have assembly implementations. They are omitted here because we should have basic float math routines available in `core` in the near future, which will allow us to defer to LLVM for assembly lowering rather than implementing these ourselves.
1 parent c9672e5 commit f851cb5

File tree

9 files changed

+135
-28
lines changed

9 files changed

+135
-28
lines changed

etc/function-definitions.json

+6
Original file line numberDiff line numberDiff line change
@@ -342,12 +342,14 @@
342342
},
343343
"fma": {
344344
"sources": [
345+
"src/math/arch/aarch64.rs",
345346
"src/math/fma.rs"
346347
],
347348
"type": "f64"
348349
},
349350
"fmaf": {
350351
"sources": [
352+
"src/math/arch/aarch64.rs",
351353
"src/math/fma_wide.rs"
352354
],
353355
"type": "f32"
@@ -806,6 +808,7 @@
806808
},
807809
"rintf16": {
808810
"sources": [
811+
"src/math/arch/aarch64.rs",
809812
"src/math/rint.rs"
810813
],
811814
"type": "f16"
@@ -928,6 +931,7 @@
928931
},
929932
"sqrt": {
930933
"sources": [
934+
"src/math/arch/aarch64.rs",
931935
"src/math/arch/i686.rs",
932936
"src/math/arch/wasm32.rs",
933937
"src/math/generic/sqrt.rs",
@@ -937,6 +941,7 @@
937941
},
938942
"sqrtf": {
939943
"sources": [
944+
"src/math/arch/aarch64.rs",
940945
"src/math/arch/i686.rs",
941946
"src/math/arch/wasm32.rs",
942947
"src/math/generic/sqrt.rs",
@@ -953,6 +958,7 @@
953958
},
954959
"sqrtf16": {
955960
"sources": [
961+
"src/math/arch/aarch64.rs",
956962
"src/math/generic/sqrt.rs",
957963
"src/math/sqrtf16.rs"
958964
],

src/math/arch/aarch64.rs

+86-22
Original file line numberDiff line numberDiff line change
@@ -1,33 +1,97 @@
1-
use core::arch::aarch64::{
2-
float32x2_t, float64x1_t, vdup_n_f32, vdup_n_f64, vget_lane_f32, vget_lane_f64, vrndn_f32,
3-
vrndn_f64,
4-
};
1+
//! Architecture-specific support for aarch64 with neon.
52
6-
pub fn rint(x: f64) -> f64 {
7-
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
8-
let x_vec: float64x1_t = unsafe { vdup_n_f64(x) };
3+
use core::arch::asm;
94

10-
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
11-
let result_vec: float64x1_t = unsafe { vrndn_f64(x_vec) };
5+
pub fn fma(mut x: f64, y: f64, z: f64) -> f64 {
6+
unsafe {
7+
asm!(
8+
"fmadd {x:d}, {x:d}, {y:d}, {z:d}",
9+
x = inout(vreg) x,
10+
y = in(vreg) y,
11+
z = in(vreg) z,
12+
options(nomem, nostack, pure)
13+
);
14+
}
15+
x
16+
}
1217

13-
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
14-
let result: f64 = unsafe { vget_lane_f64::<0>(result_vec) };
18+
pub fn fmaf(mut x: f32, y: f32, z: f32) -> f32 {
19+
unsafe {
20+
asm!(
21+
"fmadd {x:s}, {x:s}, {y:s}, {z:s}",
22+
x = inout(vreg) x,
23+
y = in(vreg) y,
24+
z = in(vreg) z,
25+
options(nomem, nostack, pure)
26+
);
27+
}
28+
x
29+
}
1530

16-
result
31+
pub fn rint(mut x: f64) -> f64 {
32+
unsafe {
33+
asm!(
34+
"frinti {x:d}, {x:d}",
35+
x = inout(vreg) x,
36+
options(nomem, nostack, pure)
37+
);
38+
}
39+
x
1740
}
1841

19-
pub fn rintf(x: f32) -> f32 {
20-
// There's a scalar form of this instruction (FRINTN) but core::arch doesn't expose it, so we
21-
// have to use the vector form and drop the other lanes afterwards.
42+
pub fn rintf(mut x: f32) -> f32 {
43+
unsafe {
44+
asm!(
45+
"frinti {x:s}, {x:s}",
46+
x = inout(vreg) x,
47+
options(nomem, nostack, pure)
48+
);
49+
}
50+
x
51+
}
2252

23-
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
24-
let x_vec: float32x2_t = unsafe { vdup_n_f32(x) };
53+
/// Rounds `x` to an integral `f16` using round-to-nearest.
///
/// Only available when the `fp16` target feature is enabled. Uses the scalar `frintn`
/// instruction so that the rounding mode is fixed by the instruction itself, matching the
/// `Round::Nearest` mode used by the portable fallback instead of following the dynamic
/// rounding mode (as `frinti` would).
#[cfg(all(f16_enabled, target_feature = "fp16"))]
pub fn rintf16(mut x: f16) -> f16 {
    // SAFETY: `frintn` reads and writes only the single floating-point register named below;
    // it does not access memory or the stack, which is what `nomem` and `nostack` promise.
    // The half-precision form is gated on `target_feature = "fp16"` by the `cfg` above.
    unsafe {
        asm!(
            "frintn {x:h}, {x:h}",
            x = inout(vreg) x,
            options(nomem, nostack, pure)
        );
    }
    x
}
2564

26-
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
27-
let result_vec: float32x2_t = unsafe { vrndn_f32(x_vec) };
65+
pub fn sqrt(mut x: f64) -> f64 {
66+
unsafe {
67+
asm!(
68+
"fsqrt {x:d}, {x:d}",
69+
x = inout(vreg) x,
70+
options(nomem, nostack, pure)
71+
);
72+
}
73+
x
74+
}
2875

29-
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
30-
let result: f32 = unsafe { vget_lane_f32::<0>(result_vec) };
76+
pub fn sqrtf(mut x: f32) -> f32 {
77+
unsafe {
78+
asm!(
79+
"fsqrt {x:s}, {x:s}",
80+
x = inout(vreg) x,
81+
options(nomem, nostack, pure)
82+
);
83+
}
84+
x
85+
}
3186

32-
result
87+
/// Square root of an `f16`, computed with the half-precision scalar `fsqrt` instruction.
///
/// Only available when the `fp16` target feature is enabled.
#[cfg(all(f16_enabled, target_feature = "fp16"))]
pub fn sqrtf16(x: f16) -> f16 {
    let root: f16;

    // SAFETY: `fsqrt` operates purely on the floating-point registers named below; it does
    // not access memory or the stack, which is what `nomem` and `nostack` promise. The
    // half-precision form is gated on `target_feature = "fp16"` by the `cfg` above.
    unsafe {
        asm!(
            "fsqrt {out:h}, {arg:h}",
            out = lateout(vreg) root,
            arg = in(vreg) x,
            options(pure, nomem, nostack),
        );
    }

    root
}

src/math/arch/mod.rs

+17-4
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,25 @@ cfg_if! {
1818
mod i686;
1919
pub use i686::{sqrt, sqrtf};
2020
} else if #[cfg(all(
21-
target_arch = "aarch64", // TODO: also arm64ec?
22-
target_feature = "neon",
23-
target_endian = "little", // see https://github.com/rust-lang/stdarch/issues/1484
21+
any(target_arch = "aarch64", target_arch = "arm64ec"),
22+
target_feature = "neon"
2423
))] {
2524
mod aarch64;
26-
pub use aarch64::{rint, rintf};
25+
26+
pub use aarch64::{
27+
fma,
28+
fmaf,
29+
rint,
30+
rintf,
31+
sqrt,
32+
sqrtf,
33+
};
34+
35+
#[cfg(all(f16_enabled, target_feature = "fp16"))]
36+
pub use aarch64::{
37+
rintf16,
38+
sqrtf16,
39+
};
2740
}
2841
}
2942

src/math/fma.rs

+5
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ use super::{CastFrom, CastInto, Float, Int, MinInt};
99
/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
1010
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
1111
pub fn fma(x: f64, y: f64, z: f64) -> f64 {
12+
select_implementation! {
13+
name: fma,
14+
use_arch: all(target_arch = "aarch64", target_feature = "neon"),
15+
args: x, y, z,
16+
}
1217
fma_round(x, y, z, Round::Nearest).val
1318
}
1419

src/math/fma_wide.rs

+5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ pub(crate) fn fmaf16(_x: f16, _y: f16, _z: f16) -> f16 {
1717
/// Computes `(x*y)+z`, rounded as one ternary operation (i.e. calculated with infinite precision).
1818
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
1919
pub fn fmaf(x: f32, y: f32, z: f32) -> f32 {
20+
select_implementation! {
21+
name: fmaf,
22+
use_arch: all(target_arch = "aarch64", target_feature = "neon"),
23+
args: x, y, z,
24+
}
2025
fma_wide_round(x, y, z, Round::Nearest).val
2126
}
2227

src/math/rint.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,12 @@ use super::support::Round;
44
/// Rounds an `f16` to an integral value using the `Round::Nearest` mode.
#[cfg(f16_enabled)]
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
pub fn rintf16(x: f16) -> f16 {
    // NOTE(review): presumably hands off to the architecture-specific `rintf16` when the
    // `use_arch` condition holds — confirm against the `select_implementation!` definition.
    select_implementation! {
        name: rintf16,
        use_arch: all(target_arch = "aarch64", target_feature = "fp16"),
        args: x,
    }

    // Portable path shared with the other float widths.
    let rounded = super::generic::rint_round(x, Round::Nearest);
    rounded.val
}
915

@@ -13,8 +19,8 @@ pub fn rintf(x: f32) -> f32 {
1319
select_implementation! {
1420
name: rintf,
1521
use_arch: any(
22+
all(target_arch = "aarch64", target_feature = "neon"),
1623
all(target_arch = "wasm32", intrinsics_enabled),
17-
all(target_arch = "aarch64", target_feature = "neon", target_endian = "little"),
1824
),
1925
args: x,
2026
}
@@ -28,8 +34,8 @@ pub fn rint(x: f64) -> f64 {
2834
select_implementation! {
2935
name: rint,
3036
use_arch: any(
37+
all(target_arch = "aarch64", target_feature = "neon"),
3138
all(target_arch = "wasm32", intrinsics_enabled),
32-
all(target_arch = "aarch64", target_feature = "neon", target_endian = "little"),
3339
),
3440
args: x,
3541
}

src/math/sqrt.rs

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ pub fn sqrt(x: f64) -> f64 {
44
select_implementation! {
55
name: sqrt,
66
use_arch: any(
7+
all(target_arch = "aarch64", target_feature = "neon"),
78
all(target_arch = "wasm32", intrinsics_enabled),
89
target_feature = "sse2"
910
),

src/math/sqrtf.rs

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ pub fn sqrtf(x: f32) -> f32 {
44
select_implementation! {
55
name: sqrtf,
66
use_arch: any(
7+
all(target_arch = "aarch64", target_feature = "neon"),
78
all(target_arch = "wasm32", intrinsics_enabled),
89
target_feature = "sse2"
910
),

src/math/sqrtf16.rs

+6
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
/// The square root of `x` (f16).
22
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
33
pub fn sqrtf16(x: f16) -> f16 {
4+
select_implementation! {
5+
name: sqrtf16,
6+
use_arch: all(target_arch = "aarch64", target_feature = "fp16"),
7+
args: x,
8+
}
9+
410
return super::generic::sqrt(x);
511
}

0 commit comments

Comments
 (0)