Skip to content

Commit 059704b

Browse files
authored
Merge pull request #4071 from RalfJung/simd_relaxed_fma
implement simd_relaxed_fma
2 parents efd1352 + 16ee60a commit 059704b

File tree

3 files changed

+97
-33
lines changed

3 files changed

+97
-33
lines changed

src/intrinsics/simd.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use either::Either;
2+
use rand::Rng;
23
use rustc_abi::{Endian, HasDataLayout};
34
use rustc_apfloat::{Float, Round};
45
use rustc_middle::ty::FloatTy;
@@ -286,7 +287,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
286287
this.write_scalar(val, &dest)?;
287288
}
288289
}
289-
"fma" => {
290+
"fma" | "relaxed_fma" => {
290291
let [a, b, c] = check_arg_count(args)?;
291292
let (a, a_len) = this.project_to_simd(a)?;
292293
let (b, b_len) = this.project_to_simd(b)?;
@@ -303,6 +304,8 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
303304
let c = this.read_scalar(&this.project_index(&c, i)?)?;
304305
let dest = this.project_index(&dest, i)?;
305306

307+
let fuse: bool = intrinsic_name == "fma" || this.machine.rng.get_mut().gen();
308+
306309
// Works for f32 and f64.
307310
// FIXME: using host floats to work around https://github.com/rust-lang/miri/issues/2468.
308311
let ty::Float(float_ty) = dest.layout.ty.kind() else {
@@ -314,15 +317,23 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
314317
let a = a.to_f32()?;
315318
let b = b.to_f32()?;
316319
let c = c.to_f32()?;
317-
let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft();
320+
let res = if fuse {
321+
a.to_host().mul_add(b.to_host(), c.to_host()).to_soft()
322+
} else {
323+
((a * b).value + c).value
324+
};
318325
let res = this.adjust_nan(res, &[a, b, c]);
319326
Scalar::from(res)
320327
}
321328
FloatTy::F64 => {
322329
let a = a.to_f64()?;
323330
let b = b.to_f64()?;
324331
let c = c.to_f64()?;
325-
let res = a.to_host().mul_add(b.to_host(), c.to_host()).to_soft();
332+
let res = if fuse {
333+
a.to_host().mul_add(b.to_host(), c.to_host()).to_soft()
334+
} else {
335+
((a * b).value + c).value
336+
};
326337
let res = this.adjust_nan(res, &[a, b, c]);
327338
Scalar::from(res)
328339
}
Lines changed: 61 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,75 @@
1-
#![feature(core_intrinsics)]
1+
#![feature(core_intrinsics, portable_simd)]
2+
use std::intrinsics::simd::simd_relaxed_fma;
23
use std::intrinsics::{fmuladdf32, fmuladdf64};
4+
use std::simd::prelude::*;
35

4-
fn main() {
5-
let mut saw_zero = false;
6-
let mut saw_nonzero = false;
6+
fn ensure_both_happen(f: impl Fn() -> bool) -> bool {
7+
let mut saw_true = false;
8+
let mut saw_false = false;
79
for _ in 0..50 {
8-
let a = std::hint::black_box(0.1_f64);
9-
let b = std::hint::black_box(0.2);
10-
let c = std::hint::black_box(-a * b);
11-
// It is unspecified whether the following operation is fused or not. The
12-
// following evaluates to 0.0 if unfused, and nonzero (-1.66e-18) if fused.
13-
let x = unsafe { fmuladdf64(a, b, c) };
14-
if x == 0.0 {
15-
saw_zero = true;
10+
let b = f();
11+
if b {
12+
saw_true = true;
1613
} else {
17-
saw_nonzero = true;
14+
saw_false = true;
15+
}
16+
if saw_true && saw_false {
17+
return true;
1818
}
1919
}
20+
false
21+
}
22+
23+
fn main() {
2024
assert!(
21-
saw_zero && saw_nonzero,
25+
ensure_both_happen(|| {
26+
let a = std::hint::black_box(0.1_f64);
27+
let b = std::hint::black_box(0.2);
28+
let c = std::hint::black_box(-a * b);
29+
// It is unspecified whether the following operation is fused or not. The
30+
// following evaluates to 0.0 if unfused, and nonzero (-1.66e-18) if fused.
31+
let x = unsafe { fmuladdf64(a, b, c) };
32+
x == 0.0
33+
}),
2234
"`fmuladdf64` failed to be evaluated as both fused and unfused"
2335
);
2436

25-
let mut saw_zero = false;
26-
let mut saw_nonzero = false;
27-
for _ in 0..50 {
28-
let a = std::hint::black_box(0.1_f32);
29-
let b = std::hint::black_box(0.2);
30-
let c = std::hint::black_box(-a * b);
31-
// It is unspecified whether the following operation is fused or not. The
32-
// following evaluates to 0.0 if unfused, and nonzero (-8.1956386e-10) if fused.
33-
let x = unsafe { fmuladdf32(a, b, c) };
34-
if x == 0.0 {
35-
saw_zero = true;
36-
} else {
37-
saw_nonzero = true;
38-
}
39-
}
4037
assert!(
41-
saw_zero && saw_nonzero,
38+
ensure_both_happen(|| {
39+
let a = std::hint::black_box(0.1_f32);
40+
let b = std::hint::black_box(0.2);
41+
let c = std::hint::black_box(-a * b);
42+
// It is unspecified whether the following operation is fused or not. The
43+
// following evaluates to 0.0 if unfused, and nonzero (-8.1956386e-10) if fused.
44+
let x = unsafe { fmuladdf32(a, b, c) };
45+
x == 0.0
46+
}),
4247
"`fmuladdf32` failed to be evaluated as both fused and unfused"
4348
);
49+
50+
assert!(
51+
ensure_both_happen(|| {
52+
let a = f32x4::splat(std::hint::black_box(0.1));
53+
let b = f32x4::splat(std::hint::black_box(0.2));
54+
let c = std::hint::black_box(-a * b);
55+
let x = unsafe { simd_relaxed_fma(a, b, c) };
56+
// Whether we fuse or not is a per-element decision, so sometimes these should be
57+
// the same and sometimes not.
58+
x[0] == x[1]
59+
}),
60+
"`simd_relaxed_fma` failed to be evaluated as both fused and unfused"
61+
);
62+
63+
assert!(
64+
ensure_both_happen(|| {
65+
let a = f64x4::splat(std::hint::black_box(0.1));
66+
let b = f64x4::splat(std::hint::black_box(0.2));
67+
let c = std::hint::black_box(-a * b);
68+
let x = unsafe { simd_relaxed_fma(a, b, c) };
69+
// Whether we fuse or not is a per-element decision, so sometimes these should be
70+
// the same and sometimes not.
71+
x[0] == x[1]
72+
}),
73+
"`simd_relaxed_fma` failed to be evaluated as both fused and unfused"
74+
);
4475
}

tests/pass/intrinsics/portable-simd.rs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,17 @@ fn simd_ops_f32() {
4040
f32x4::splat(-3.2).mul_add(b, f32x4::splat(f32::NEG_INFINITY)),
4141
f32x4::splat(f32::NEG_INFINITY)
4242
);
43+
44+
unsafe {
45+
assert_eq!(intrinsics::simd_relaxed_fma(a, b, a), (a * b) + a);
46+
assert_eq!(intrinsics::simd_relaxed_fma(b, b, a), (b * b) + a);
47+
assert_eq!(intrinsics::simd_relaxed_fma(a, b, b), (a * b) + b);
48+
assert_eq!(
49+
intrinsics::simd_relaxed_fma(f32x4::splat(-3.2), b, f32x4::splat(f32::NEG_INFINITY)),
50+
f32x4::splat(f32::NEG_INFINITY)
51+
);
52+
}
53+
4354
assert_eq!((a * a).sqrt(), a);
4455
assert_eq!((b * b).sqrt(), b.abs());
4556

@@ -94,6 +105,17 @@ fn simd_ops_f64() {
94105
f64x4::splat(-3.2).mul_add(b, f64x4::splat(f64::NEG_INFINITY)),
95106
f64x4::splat(f64::NEG_INFINITY)
96107
);
108+
109+
unsafe {
110+
assert_eq!(intrinsics::simd_relaxed_fma(a, b, a), (a * b) + a);
111+
assert_eq!(intrinsics::simd_relaxed_fma(b, b, a), (b * b) + a);
112+
assert_eq!(intrinsics::simd_relaxed_fma(a, b, b), (a * b) + b);
113+
assert_eq!(
114+
intrinsics::simd_relaxed_fma(f64x4::splat(-3.2), b, f64x4::splat(f64::NEG_INFINITY)),
115+
f64x4::splat(f64::NEG_INFINITY)
116+
);
117+
}
118+
97119
assert_eq!((a * a).sqrt(), a);
98120
assert_eq!((b * b).sqrt(), b.abs());
99121

0 commit comments

Comments
 (0)