Skip to content

Commit dabe80a

Browse files
committed
Auto merge of rust-lang#136457 - calder:master, r=<try>
Expose algebraic floating point intrinsics # Problem A stable Rust implementation of a simple dot product is 8x slower than C++ on modern x86-64 CPUs. The root cause is an inability to let the compiler reorder floating point operations for better vectorization. See https://github.com/calder/dot-bench for benchmarks. Measurements below were performed on a i7-10875H. ### C++: 10us ✅ With Clang 18.1.3 and `-O2 -march=haswell`: <table> <tr> <th>C++</th> <th>Assembly</th> </tr> <tr> <td> <pre lang="cc"> float dot(float *a, float *b, size_t len) { #pragma clang fp reassociate(on) float sum = 0.0; for (size_t i = 0; i < len; ++i) { sum += a[i] * b[i]; } return sum; } </pre> </td> <td> <img src="https://github.com/user-attachments/assets/739573c0-380a-4d84-9fd9-141343ce7e68" /> </td> </tr> </table> ### Nightly Rust: 10us ✅ With rustc 1.86.0-nightly (8239a37) and `-C opt-level=3 -C target-feature=+avx2,+fma`: <table> <tr> <th>Rust</th> <th>Assembly</th> </tr> <tr> <td> <pre lang="rust"> fn dot(a: &[f32], b: &[f32]) -> f32 { let mut sum = 0.0; for i in 0..a.len() { sum = fadd_algebraic(sum, fmul_algebraic(a[i], b[i])); } sum } </pre> </td> <td> <img src="https://github.com/user-attachments/assets/9dcf953a-2cd7-42f3-bc34-7117de4c5fb9" /> </td> </tr> </table> ### Stable Rust: 84us ❌ With rustc 1.84.1 (e71f9a9) and `-C opt-level=3 -C target-feature=+avx2,+fma`: <table> <tr> <th>Rust</th> <th>Assembly</th> </tr> <tr> <td> <pre lang="rust"> fn dot(a: &[f32], b: &[f32]) -> f32 { let mut sum = 0.0; for i in 0..a.len() { sum += a[i] * b[i]; } sum } </pre> </td> <td> <img src="https://github.com/user-attachments/assets/936a1f7e-33e4-4ff8-a732-c3cdfe068dca" /> </td> </tr> </table> # Proposed Change Add `core::intrinsics::f*_algebraic` wrappers to `f16`, `f32`, `f64`, and `f128` gated on a new `float_algebraic` feature. # Alternatives Considered rust-lang#21690 has a lot of good discussion of various options for supporting fast math in Rust, but is still open a decade later because any choice that opts in more than individual operations is ultimately contrary to Rust's design principles. In the mean time, processors have evolved and we're leaving major performance on the table by not supporting vectorization. We shouldn't make users choose between an unstable compiler and an 8x performance hit. # References * rust-lang#21690 * rust-lang/libs-team#532 * rust-lang#136469 * https://github.com/calder/dot-bench * https://www.felixcloutier.com/x86/vfmadd132ps:vfmadd213ps:vfmadd231ps try-job: x86_64-gnu-nopt try-job: x86_64-gnu-aux
2 parents f2d69d5 + 0151a01 commit dabe80a

File tree

14 files changed

+507
-19
lines changed

14 files changed

+507
-19
lines changed

library/core/src/intrinsics/mod.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -2324,35 +2324,35 @@ pub unsafe fn float_to_int_unchecked<Float: Copy, Int: Copy>(_value: Float) -> I
23242324

23252325
/// Float addition that allows optimizations based on algebraic rules.
23262326
///
2327-
/// This intrinsic does not have a stable counterpart.
2327+
/// Stabilized as [`f16::algebraic_add`], [`f32::algebraic_add`], [`f64::algebraic_add`] and [`f128::algebraic_add`].
23282328
#[rustc_nounwind]
23292329
#[rustc_intrinsic]
23302330
pub fn fadd_algebraic<T: Copy>(_a: T, _b: T) -> T;
23312331

23322332
/// Float subtraction that allows optimizations based on algebraic rules.
23332333
///
2334-
/// This intrinsic does not have a stable counterpart.
2334+
/// Stabilized as [`f16::algebraic_sub`], [`f32::algebraic_sub`], [`f64::algebraic_sub`] and [`f128::algebraic_sub`].
23352335
#[rustc_nounwind]
23362336
#[rustc_intrinsic]
23372337
pub fn fsub_algebraic<T: Copy>(_a: T, _b: T) -> T;
23382338

23392339
/// Float multiplication that allows optimizations based on algebraic rules.
23402340
///
2341-
/// This intrinsic does not have a stable counterpart.
2341+
/// Stabilized as [`f16::algebraic_mul`], [`f32::algebraic_mul`], [`f64::algebraic_mul`] and [`f128::algebraic_mul`].
23422342
#[rustc_nounwind]
23432343
#[rustc_intrinsic]
23442344
pub fn fmul_algebraic<T: Copy>(_a: T, _b: T) -> T;
23452345

23462346
/// Float division that allows optimizations based on algebraic rules.
23472347
///
2348-
/// This intrinsic does not have a stable counterpart.
2348+
/// Stabilized as [`f16::algebraic_div`], [`f32::algebraic_div`], [`f64::algebraic_div`] and [`f128::algebraic_div`].
23492349
#[rustc_nounwind]
23502350
#[rustc_intrinsic]
23512351
pub fn fdiv_algebraic<T: Copy>(_a: T, _b: T) -> T;
23522352

23532353
/// Float remainder that allows optimizations based on algebraic rules.
23542354
///
2355-
/// This intrinsic does not have a stable counterpart.
2355+
/// Stabilized as [`f16::algebraic_rem`], [`f32::algebraic_rem`], [`f64::algebraic_rem`] and [`f128::algebraic_rem`].
23562356
#[rustc_nounwind]
23572357
#[rustc_intrinsic]
23582358
pub fn frem_algebraic<T: Copy>(_a: T, _b: T) -> T;

library/core/src/num/f128.rs

+50
Original file line numberDiff line numberDiff line change
@@ -1365,4 +1365,54 @@ impl f128 {
13651365
// SAFETY: this is actually a safe intrinsic
13661366
unsafe { intrinsics::copysignf128(self, sign) }
13671367
}
1368+
1369+
/// Float addition that allows optimizations based on algebraic rules.
1370+
///
1371+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1372+
#[must_use = "method returns a new number and does not mutate the original value"]
1373+
#[unstable(feature = "float_algebraic", issue = "136469")]
1374+
#[inline]
1375+
pub fn algebraic_add(self, rhs: f128) -> f128 {
1376+
intrinsics::fadd_algebraic(self, rhs)
1377+
}
1378+
1379+
/// Float subtraction that allows optimizations based on algebraic rules.
1380+
///
1381+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1382+
#[must_use = "method returns a new number and does not mutate the original value"]
1383+
#[unstable(feature = "float_algebraic", issue = "136469")]
1384+
#[inline]
1385+
pub fn algebraic_sub(self, rhs: f128) -> f128 {
1386+
intrinsics::fsub_algebraic(self, rhs)
1387+
}
1388+
1389+
/// Float multiplication that allows optimizations based on algebraic rules.
1390+
///
1391+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1392+
#[must_use = "method returns a new number and does not mutate the original value"]
1393+
#[unstable(feature = "float_algebraic", issue = "136469")]
1394+
#[inline]
1395+
pub fn algebraic_mul(self, rhs: f128) -> f128 {
1396+
intrinsics::fmul_algebraic(self, rhs)
1397+
}
1398+
1399+
/// Float division that allows optimizations based on algebraic rules.
1400+
///
1401+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1402+
#[must_use = "method returns a new number and does not mutate the original value"]
1403+
#[unstable(feature = "float_algebraic", issue = "136469")]
1404+
#[inline]
1405+
pub fn algebraic_div(self, rhs: f128) -> f128 {
1406+
intrinsics::fdiv_algebraic(self, rhs)
1407+
}
1408+
1409+
/// Float remainder that allows optimizations based on algebraic rules.
1410+
///
1411+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1412+
#[must_use = "method returns a new number and does not mutate the original value"]
1413+
#[unstable(feature = "float_algebraic", issue = "136469")]
1414+
#[inline]
1415+
pub fn algebraic_rem(self, rhs: f128) -> f128 {
1416+
intrinsics::frem_algebraic(self, rhs)
1417+
}
13681418
}

library/core/src/num/f16.rs

+50
Original file line numberDiff line numberDiff line change
@@ -1341,4 +1341,54 @@ impl f16 {
13411341
// SAFETY: this is actually a safe intrinsic
13421342
unsafe { intrinsics::copysignf16(self, sign) }
13431343
}
1344+
1345+
/// Float addition that allows optimizations based on algebraic rules.
1346+
///
1347+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1348+
#[must_use = "method returns a new number and does not mutate the original value"]
1349+
#[unstable(feature = "float_algebraic", issue = "136469")]
1350+
#[inline]
1351+
pub fn algebraic_add(self, rhs: f16) -> f16 {
1352+
intrinsics::fadd_algebraic(self, rhs)
1353+
}
1354+
1355+
/// Float subtraction that allows optimizations based on algebraic rules.
1356+
///
1357+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1358+
#[must_use = "method returns a new number and does not mutate the original value"]
1359+
#[unstable(feature = "float_algebraic", issue = "136469")]
1360+
#[inline]
1361+
pub fn algebraic_sub(self, rhs: f16) -> f16 {
1362+
intrinsics::fsub_algebraic(self, rhs)
1363+
}
1364+
1365+
/// Float multiplication that allows optimizations based on algebraic rules.
1366+
///
1367+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1368+
#[must_use = "method returns a new number and does not mutate the original value"]
1369+
#[unstable(feature = "float_algebraic", issue = "136469")]
1370+
#[inline]
1371+
pub fn algebraic_mul(self, rhs: f16) -> f16 {
1372+
intrinsics::fmul_algebraic(self, rhs)
1373+
}
1374+
1375+
/// Float division that allows optimizations based on algebraic rules.
1376+
///
1377+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1378+
#[must_use = "method returns a new number and does not mutate the original value"]
1379+
#[unstable(feature = "float_algebraic", issue = "136469")]
1380+
#[inline]
1381+
pub fn algebraic_div(self, rhs: f16) -> f16 {
1382+
intrinsics::fdiv_algebraic(self, rhs)
1383+
}
1384+
1385+
/// Float remainder that allows optimizations based on algebraic rules.
1386+
///
1387+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1388+
#[must_use = "method returns a new number and does not mutate the original value"]
1389+
#[unstable(feature = "float_algebraic", issue = "136469")]
1390+
#[inline]
1391+
pub fn algebraic_rem(self, rhs: f16) -> f16 {
1392+
intrinsics::frem_algebraic(self, rhs)
1393+
}
13441394
}

library/core/src/num/f32.rs

+50
Original file line numberDiff line numberDiff line change
@@ -1506,4 +1506,54 @@ impl f32 {
15061506
// SAFETY: this is actually a safe intrinsic
15071507
unsafe { intrinsics::copysignf32(self, sign) }
15081508
}
1509+
1510+
/// Float addition that allows optimizations based on algebraic rules.
1511+
///
1512+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1513+
#[must_use = "method returns a new number and does not mutate the original value"]
1514+
#[unstable(feature = "float_algebraic", issue = "136469")]
1515+
#[inline]
1516+
pub fn algebraic_add(self, rhs: f32) -> f32 {
1517+
intrinsics::fadd_algebraic(self, rhs)
1518+
}
1519+
1520+
/// Float subtraction that allows optimizations based on algebraic rules.
1521+
///
1522+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1523+
#[must_use = "method returns a new number and does not mutate the original value"]
1524+
#[unstable(feature = "float_algebraic", issue = "136469")]
1525+
#[inline]
1526+
pub fn algebraic_sub(self, rhs: f32) -> f32 {
1527+
intrinsics::fsub_algebraic(self, rhs)
1528+
}
1529+
1530+
/// Float multiplication that allows optimizations based on algebraic rules.
1531+
///
1532+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1533+
#[must_use = "method returns a new number and does not mutate the original value"]
1534+
#[unstable(feature = "float_algebraic", issue = "136469")]
1535+
#[inline]
1536+
pub fn algebraic_mul(self, rhs: f32) -> f32 {
1537+
intrinsics::fmul_algebraic(self, rhs)
1538+
}
1539+
1540+
/// Float division that allows optimizations based on algebraic rules.
1541+
///
1542+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1543+
#[must_use = "method returns a new number and does not mutate the original value"]
1544+
#[unstable(feature = "float_algebraic", issue = "136469")]
1545+
#[inline]
1546+
pub fn algebraic_div(self, rhs: f32) -> f32 {
1547+
intrinsics::fdiv_algebraic(self, rhs)
1548+
}
1549+
1550+
/// Float remainder that allows optimizations based on algebraic rules.
1551+
///
1552+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1553+
#[must_use = "method returns a new number and does not mutate the original value"]
1554+
#[unstable(feature = "float_algebraic", issue = "136469")]
1555+
#[inline]
1556+
pub fn algebraic_rem(self, rhs: f32) -> f32 {
1557+
intrinsics::frem_algebraic(self, rhs)
1558+
}
15091559
}

library/core/src/num/f64.rs

+50
Original file line numberDiff line numberDiff line change
@@ -1506,4 +1506,54 @@ impl f64 {
15061506
// SAFETY: this is actually a safe intrinsic
15071507
unsafe { intrinsics::copysignf64(self, sign) }
15081508
}
1509+
1510+
/// Float addition that allows optimizations based on algebraic rules.
1511+
///
1512+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1513+
#[must_use = "method returns a new number and does not mutate the original value"]
1514+
#[unstable(feature = "float_algebraic", issue = "136469")]
1515+
#[inline]
1516+
pub fn algebraic_add(self, rhs: f64) -> f64 {
1517+
intrinsics::fadd_algebraic(self, rhs)
1518+
}
1519+
1520+
/// Float subtraction that allows optimizations based on algebraic rules.
1521+
///
1522+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1523+
#[must_use = "method returns a new number and does not mutate the original value"]
1524+
#[unstable(feature = "float_algebraic", issue = "136469")]
1525+
#[inline]
1526+
pub fn algebraic_sub(self, rhs: f64) -> f64 {
1527+
intrinsics::fsub_algebraic(self, rhs)
1528+
}
1529+
1530+
/// Float multiplication that allows optimizations based on algebraic rules.
1531+
///
1532+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1533+
#[must_use = "method returns a new number and does not mutate the original value"]
1534+
#[unstable(feature = "float_algebraic", issue = "136469")]
1535+
#[inline]
1536+
pub fn algebraic_mul(self, rhs: f64) -> f64 {
1537+
intrinsics::fmul_algebraic(self, rhs)
1538+
}
1539+
1540+
/// Float division that allows optimizations based on algebraic rules.
1541+
///
1542+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1543+
#[must_use = "method returns a new number and does not mutate the original value"]
1544+
#[unstable(feature = "float_algebraic", issue = "136469")]
1545+
#[inline]
1546+
pub fn algebraic_div(self, rhs: f64) -> f64 {
1547+
intrinsics::fdiv_algebraic(self, rhs)
1548+
}
1549+
1550+
/// Float remainder that allows optimizations based on algebraic rules.
1551+
///
1552+
/// See [algebraic operators](primitive@f32#algebraic-operators) for more info.
1553+
#[must_use = "method returns a new number and does not mutate the original value"]
1554+
#[unstable(feature = "float_algebraic", issue = "136469")]
1555+
#[inline]
1556+
pub fn algebraic_rem(self, rhs: f64) -> f64 {
1557+
intrinsics::frem_algebraic(self, rhs)
1558+
}
15091559
}

library/core/src/primitive_docs.rs

+45
Original file line numberDiff line numberDiff line change
@@ -1315,6 +1315,51 @@ mod prim_f16 {}
13151315
/// | `wasm32`, `wasm64` | If all input NaNs are quiet with all-zero payload: None.<br> Otherwise: all possible payloads. |
13161316
///
13171317
/// For targets not in this table, all payloads are possible.
1318+
///
1319+
/// # Algebraic operators
1320+
///
1321+
/// Algebraic operators of the form `a.algebraic_*(b)` allow the compiler to optimize
1322+
/// floating point operations using all the usual algebraic properties of real numbers --
1323+
/// despite the fact that those properties do *not* hold on floating point numbers.
1324+
/// This can give a great performance boost since it may unlock vectorization.
1325+
///
1326+
/// The exact set of optimizations is unspecified but typically allows combining operations,
1327+
/// rearranging series of operations based on mathematical properties, converting between division
1328+
/// and reciprocal multiplication, and disregarding the sign of zero. This means that the results of
1329+
/// elementary operations may have undefined precision, and "non-mathematical" values
1330+
/// such as NaN, +/-Inf, or -0.0 may behave in unexpected ways, but these operations
1331+
/// will never cause undefined behavior.
1332+
///
1333+
/// Because of the unpredictable nature of compiler optimizations, the same inputs may produce
1334+
/// different results even within a single program run. **Unsafe code must not rely on any property
1335+
/// of the return value for soundness.** However, implementations will generally do their best to
1336+
/// pick a reasonable tradeoff between performance and accuracy of the result.
1337+
///
1338+
/// For example:
1339+
///
1340+
/// ```
1341+
/// # #![feature(float_algebraic)]
1342+
/// # #![allow(unused_assignments)]
1343+
/// # let mut x: f32 = 0.0;
1344+
/// # let a: f32 = 1.0;
1345+
/// # let b: f32 = 2.0;
1346+
/// # let c: f32 = 3.0;
1347+
/// # let d: f32 = 4.0;
1348+
/// x = a.algebraic_add(b).algebraic_add(c).algebraic_add(d);
1349+
/// ```
1350+
///
1351+
/// May be rewritten as:
1352+
///
1353+
/// ```
1354+
/// # #![allow(unused_assignments)]
1355+
/// # let mut x: f32 = 0.0;
1356+
/// # let a: f32 = 1.0;
1357+
/// # let b: f32 = 2.0;
1358+
/// # let c: f32 = 3.0;
1359+
/// # let d: f32 = 4.0;
1360+
/// x = a + b + c + d; // As written
1361+
/// x = (a + c) + (b + d); // Reordered to shorten critical path and enable vectorization
1362+
/// ```
13181363
13191364
#[stable(feature = "rust1", since = "1.0.0")]
13201365
mod prim_f32 {}

library/std/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,7 @@
338338
#![feature(exact_size_is_empty)]
339339
#![feature(exclusive_wrapper)]
340340
#![feature(extend_one)]
341+
#![feature(float_algebraic)]
341342
#![feature(float_gamma)]
342343
#![feature(float_minimum_maximum)]
343344
#![feature(fmt_internals)]

library/std/tests/floats/f128.rs

+12
Original file line numberDiff line numberDiff line change
@@ -983,3 +983,15 @@ fn test_total_cmp() {
983983
assert_eq!(Ordering::Less, (-s_nan()).total_cmp(&f128::INFINITY));
984984
assert_eq!(Ordering::Less, (-s_nan()).total_cmp(&s_nan()));
985985
}
986+
987+
#[test]
988+
fn test_algebraic() {
989+
let a: f128 = 123.0;
990+
let b: f128 = 456.0;
991+
992+
assert_approx_eq!(a.algebraic_add(b), a + b);
993+
assert_approx_eq!(a.algebraic_sub(b), a - b);
994+
assert_approx_eq!(a.algebraic_mul(b), a * b);
995+
assert_approx_eq!(a.algebraic_div(b), a / b);
996+
assert_approx_eq!(a.algebraic_rem(b), a % b);
997+
}

library/std/tests/floats/f16.rs

+12
Original file line numberDiff line numberDiff line change
@@ -955,3 +955,15 @@ fn test_total_cmp() {
955955
assert_eq!(Ordering::Less, (-s_nan()).total_cmp(&f16::INFINITY));
956956
assert_eq!(Ordering::Less, (-s_nan()).total_cmp(&s_nan()));
957957
}
958+
959+
#[test]
960+
fn test_algebraic() {
961+
let a: f16 = 123.0;
962+
let b: f16 = 456.0;
963+
964+
assert_approx_eq!(a.algebraic_add(b), a + b, 1e1);
965+
assert_approx_eq!(a.algebraic_sub(b), a - b, 1e1);
966+
assert_approx_eq!(a.algebraic_mul(b), a * b, 1e3);
967+
assert_approx_eq!(a.algebraic_div(b), a / b, 1e-2);
968+
assert_approx_eq!(a.algebraic_rem(b), a % b, 1e1);
969+
}

library/std/tests/floats/f32.rs

+12
Original file line numberDiff line numberDiff line change
@@ -915,3 +915,15 @@ fn test_total_cmp() {
915915
assert_eq!(Ordering::Less, (-s_nan()).total_cmp(&f32::INFINITY));
916916
assert_eq!(Ordering::Less, (-s_nan()).total_cmp(&s_nan()));
917917
}
918+
919+
#[test]
920+
fn test_algebraic() {
921+
let a: f32 = 123.0;
922+
let b: f32 = 456.0;
923+
924+
assert_approx_eq!(a.algebraic_add(b), a + b, 1e-3);
925+
assert_approx_eq!(a.algebraic_sub(b), a - b, 1e-3);
926+
assert_approx_eq!(a.algebraic_mul(b), a * b, 1e-1);
927+
assert_approx_eq!(a.algebraic_div(b), a / b, 1e-6);
928+
assert_approx_eq!(a.algebraic_rem(b), a % b, 1e-3);
929+
}

library/std/tests/floats/f64.rs

+12
Original file line numberDiff line numberDiff line change
@@ -894,3 +894,15 @@ fn test_total_cmp() {
894894
assert_eq!(Ordering::Less, (-s_nan()).total_cmp(&f64::INFINITY));
895895
assert_eq!(Ordering::Less, (-s_nan()).total_cmp(&s_nan()));
896896
}
897+
898+
#[test]
899+
fn test_algebraic() {
900+
let a: f64 = 123.0;
901+
let b: f64 = 456.0;
902+
903+
assert_approx_eq!(a.algebraic_add(b), a + b);
904+
assert_approx_eq!(a.algebraic_sub(b), a - b);
905+
assert_approx_eq!(a.algebraic_mul(b), a * b);
906+
assert_approx_eq!(a.algebraic_div(b), a / b);
907+
assert_approx_eq!(a.algebraic_rem(b), a % b);
908+
}

library/std/tests/floats/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#![feature(f16, f128, float_gamma, float_minimum_maximum)]
1+
#![feature(f16, f128, float_algebraic, float_gamma, float_minimum_maximum)]
22

33
use std::fmt;
44
use std::ops::{Add, Div, Mul, Rem, Sub};

0 commit comments

Comments
 (0)