Skip to content
This repository was archived by the owner on Apr 28, 2025. It is now read-only.

Commit a7a1be1

Browse files
authored
Merge pull request #430 from hanna-kruppe/rint-arch
wasm32 and aarch64 intrinsics for rint and rintf
2 parents 1666f41 + 5562dd3 commit a7a1be1

File tree

6 files changed

+73
-1
lines changed

6 files changed

+73
-1
lines changed

libm/etc/function-definitions.json

+4
Original file line numberDiff line numberDiff line change
@@ -604,12 +604,16 @@
604604
"rint": {
605605
"sources": [
606606
"src/libm_helper.rs",
607+
"src/math/arch/aarch64.rs",
608+
"src/math/arch/wasm32.rs",
607609
"src/math/rint.rs"
608610
],
609611
"type": "f64"
610612
},
611613
"rintf": {
612614
"sources": [
615+
"src/math/arch/aarch64.rs",
616+
"src/math/arch/wasm32.rs",
613617
"src/math/rintf.rs"
614618
],
615619
"type": "f32"

libm/src/math/arch/aarch64.rs

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
use core::arch::aarch64::{
2+
float32x2_t, float64x1_t, vdup_n_f32, vdup_n_f64, vget_lane_f32, vget_lane_f64, vrndn_f32,
3+
vrndn_f64,
4+
};
5+
6+
/// Rounds `x` to an integral value using the NEON `vrndn_f64` intrinsic.
///
/// `core::arch` only offers the vector form here, so the scalar is broadcast into a one-lane
/// `float64x1_t`, rounded there, and lane 0 is read back as the result.
// NOTE(review): Arm documents `vrndn` (FRINTN) as round-to-nearest with ties to even,
// independent of the dynamic rounding mode — confirm this matches the generic `rint` fallback.
pub fn rint(x: f64) -> f64 {
7+
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
8+
let x_vec: float64x1_t = unsafe { vdup_n_f64(x) };
9+
10+
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
11+
let result_vec: float64x1_t = unsafe { vrndn_f64(x_vec) };
12+
13+
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
14+
let result: f64 = unsafe { vget_lane_f64::<0>(result_vec) };
15+
16+
result
17+
}
18+
19+
/// Rounds `x` to an integral value using the NEON `vrndn_f32` intrinsic.
///
/// The scalar is broadcast into both lanes of a `float32x2_t`, the vector is rounded, and only
/// lane 0 is returned; the second lane is computed but discarded.
// NOTE(review): Arm documents `vrndn` (FRINTN) as round-to-nearest with ties to even,
// independent of the dynamic rounding mode — confirm this matches the generic `rintf` fallback.
pub fn rintf(x: f32) -> f32 {
20+
// There's a scalar form of this instruction (FRINTN) but core::arch doesn't expose it, so we
21+
// have to use the vector form and drop the other lanes afterwards.
22+
23+
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
24+
let x_vec: float32x2_t = unsafe { vdup_n_f32(x) };
25+
26+
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
27+
let result_vec: float32x2_t = unsafe { vrndn_f32(x_vec) };
28+
29+
// SAFETY: only requires target_feature=neon, ensured by `cfg_if` in parent module.
30+
let result: f32 = unsafe { vget_lane_f32::<0>(result_vec) };
31+
32+
result
33+
}

libm/src/math/arch/mod.rs

+10-1
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,19 @@
1111
// Select at most one architecture-specific backend module. The conditions are written as an
// `if` / `else if` chain, so a target matching an earlier branch never reaches a later one.
// Each branch re-exports only the functions that its backend module implements.
cfg_if! {
1212
if #[cfg(all(target_arch = "wasm32", intrinsics_enabled))] {
1313
mod wasm32;
14-
pub use wasm32::{ceil, ceilf, fabs, fabsf, floor, floorf, sqrt, sqrtf, trunc, truncf};
14+
pub use wasm32::{
15+
ceil, ceilf, fabs, fabsf, floor, floorf, rint, rintf, sqrt, sqrtf, trunc, truncf,
16+
};
1517
} else if #[cfg(target_feature = "sse2")] {
1618
mod i686;
1719
pub use i686::{sqrt, sqrtf};
20+
} else if #[cfg(all(
21+
target_arch = "aarch64", // TODO: also arm64ec?
22+
target_feature = "neon",
23+
target_endian = "little", // see https://github.com/rust-lang/stdarch/issues/1484
24+
))] {
25+
mod aarch64;
26+
pub use aarch64::{rint, rintf};
1827
}
1928
}
2029

libm/src/math/arch/wasm32.rs

+8
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ pub fn floorf(x: f32) -> f32 {
2525
core::arch::wasm32::f32_floor(x)
2626
}
2727

28+
/// Rounds `x` to an integral value by delegating to `core::arch::wasm32::f64_nearest`.
// NOTE(review): the `f64.nearest` instruction is documented as rounding half-way cases to even —
// confirm this matches the generic `rint` fallback.
pub fn rint(x: f64) -> f64 {
29+
core::arch::wasm32::f64_nearest(x)
30+
}
31+
32+
/// Rounds `x` to an integral value by delegating to `core::arch::wasm32::f32_nearest`.
// NOTE(review): the `f32.nearest` instruction is documented as rounding half-way cases to even —
// confirm this matches the generic `rintf` fallback.
pub fn rintf(x: f32) -> f32 {
33+
core::arch::wasm32::f32_nearest(x)
34+
}
35+
2836
/// Computes the square root of `x` by delegating to `core::arch::wasm32::f64_sqrt`.
pub fn sqrt(x: f64) -> f64 {
2937
core::arch::wasm32::f64_sqrt(x)
3038
}

libm/src/math/rint.rs

+9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
22
pub fn rint(x: f64) -> f64 {
3+
select_implementation! {
4+
name: rint,
5+
use_arch: any(
6+
all(target_arch = "wasm32", intrinsics_enabled),
7+
all(target_arch = "aarch64", target_feature = "neon", target_endian = "little"),
8+
),
9+
args: x,
10+
}
11+
312
let one_over_e = 1.0 / f64::EPSILON;
413
let as_u64: u64 = x.to_bits();
514
let exponent: u64 = (as_u64 >> 52) & 0x7ff;

libm/src/math/rintf.rs

+9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
22
pub fn rintf(x: f32) -> f32 {
3+
select_implementation! {
4+
name: rintf,
5+
use_arch: any(
6+
all(target_arch = "wasm32", intrinsics_enabled),
7+
all(target_arch = "aarch64", target_feature = "neon", target_endian = "little"),
8+
),
9+
args: x,
10+
}
11+
312
let one_over_e = 1.0 / f32::EPSILON;
413
let as_u32: u32 = x.to_bits();
514
let exponent: u32 = (as_u32 >> 23) & 0xff;

0 commit comments

Comments
 (0)