Skip to content

Commit e694688

Browse files
authored
Merge pull request rust-lang#4026 from eduardosm/soft-sqrt
miri: implement square root without relying on host floats
2 parents 8d3d694 + 8a5c187 commit e694688

File tree

6 files changed

+219
-68
lines changed

6 files changed

+219
-68
lines changed

src/tools/miri/src/intrinsics/mod.rs

+20-22
Original file line numberDiff line numberDiff line change
@@ -218,20 +218,19 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
218218
=> {
219219
let [f] = check_arg_count(args)?;
220220
let f = this.read_scalar(f)?.to_f32()?;
221-
// Using host floats (but it's fine, these operations do not have guaranteed precision).
222-
let f_host = f.to_host();
221+
// Using host floats except for sqrt (but it's fine, these operations do not have
222+
// guaranteed precision).
223223
let res = match intrinsic_name {
224-
"sinf32" => f_host.sin(),
225-
"cosf32" => f_host.cos(),
226-
"sqrtf32" => f_host.sqrt(), // FIXME Using host floats, this should use full-precision soft-floats
227-
"expf32" => f_host.exp(),
228-
"exp2f32" => f_host.exp2(),
229-
"logf32" => f_host.ln(),
230-
"log10f32" => f_host.log10(),
231-
"log2f32" => f_host.log2(),
224+
"sinf32" => f.to_host().sin().to_soft(),
225+
"cosf32" => f.to_host().cos().to_soft(),
226+
"sqrtf32" => math::sqrt(f),
227+
"expf32" => f.to_host().exp().to_soft(),
228+
"exp2f32" => f.to_host().exp2().to_soft(),
229+
"logf32" => f.to_host().ln().to_soft(),
230+
"log10f32" => f.to_host().log10().to_soft(),
231+
"log2f32" => f.to_host().log2().to_soft(),
232232
_ => bug!(),
233233
};
234-
let res = res.to_soft();
235234
let res = this.adjust_nan(res, &[f]);
236235
this.write_scalar(res, dest)?;
237236
}
@@ -247,20 +246,19 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
247246
=> {
248247
let [f] = check_arg_count(args)?;
249248
let f = this.read_scalar(f)?.to_f64()?;
250-
// Using host floats (but it's fine, these operations do not have guaranteed precision).
251-
let f_host = f.to_host();
249+
// Using host floats except for sqrt (but it's fine, these operations do not have
250+
// guaranteed precision).
252251
let res = match intrinsic_name {
253-
"sinf64" => f_host.sin(),
254-
"cosf64" => f_host.cos(),
255-
"sqrtf64" => f_host.sqrt(), // FIXME Using host floats, this should use full-precision soft-floats
256-
"expf64" => f_host.exp(),
257-
"exp2f64" => f_host.exp2(),
258-
"logf64" => f_host.ln(),
259-
"log10f64" => f_host.log10(),
260-
"log2f64" => f_host.log2(),
252+
"sinf64" => f.to_host().sin().to_soft(),
253+
"cosf64" => f.to_host().cos().to_soft(),
254+
"sqrtf64" => math::sqrt(f),
255+
"expf64" => f.to_host().exp().to_soft(),
256+
"exp2f64" => f.to_host().exp2().to_soft(),
257+
"logf64" => f.to_host().ln().to_soft(),
258+
"log10f64" => f.to_host().log10().to_soft(),
259+
"log2f64" => f.to_host().log2().to_soft(),
261260
_ => bug!(),
262261
};
263-
let res = res.to_soft();
264262
let res = this.adjust_nan(res, &[f]);
265263
this.write_scalar(res, dest)?;
266264
}

src/tools/miri/src/intrinsics/simd.rs

+18-21
Original file line numberDiff line numberDiff line change
@@ -104,42 +104,39 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
104104
let ty::Float(float_ty) = op.layout.ty.kind() else {
105105
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
106106
};
107-
// Using host floats (but it's fine, these operations do not have guaranteed precision).
107+
// Using host floats except for sqrt (but it's fine, these operations do not
108+
// have guaranteed precision).
108109
match float_ty {
109110
FloatTy::F16 => unimplemented!("f16_f128"),
110111
FloatTy::F32 => {
111112
let f = op.to_scalar().to_f32()?;
112-
let f_host = f.to_host();
113113
let res = match host_op {
114-
"fsqrt" => f_host.sqrt(), // FIXME Using host floats, this should use full-precision soft-floats
115-
"fsin" => f_host.sin(),
116-
"fcos" => f_host.cos(),
117-
"fexp" => f_host.exp(),
118-
"fexp2" => f_host.exp2(),
119-
"flog" => f_host.ln(),
120-
"flog2" => f_host.log2(),
121-
"flog10" => f_host.log10(),
114+
"fsqrt" => math::sqrt(f),
115+
"fsin" => f.to_host().sin().to_soft(),
116+
"fcos" => f.to_host().cos().to_soft(),
117+
"fexp" => f.to_host().exp().to_soft(),
118+
"fexp2" => f.to_host().exp2().to_soft(),
119+
"flog" => f.to_host().ln().to_soft(),
120+
"flog2" => f.to_host().log2().to_soft(),
121+
"flog10" => f.to_host().log10().to_soft(),
122122
_ => bug!(),
123123
};
124-
let res = res.to_soft();
125124
let res = this.adjust_nan(res, &[f]);
126125
Scalar::from(res)
127126
}
128127
FloatTy::F64 => {
129128
let f = op.to_scalar().to_f64()?;
130-
let f_host = f.to_host();
131129
let res = match host_op {
132-
"fsqrt" => f_host.sqrt(),
133-
"fsin" => f_host.sin(),
134-
"fcos" => f_host.cos(),
135-
"fexp" => f_host.exp(),
136-
"fexp2" => f_host.exp2(),
137-
"flog" => f_host.ln(),
138-
"flog2" => f_host.log2(),
139-
"flog10" => f_host.log10(),
130+
"fsqrt" => math::sqrt(f),
131+
"fsin" => f.to_host().sin().to_soft(),
132+
"fcos" => f.to_host().cos().to_soft(),
133+
"fexp" => f.to_host().exp().to_soft(),
134+
"fexp2" => f.to_host().exp2().to_soft(),
135+
"flog" => f.to_host().ln().to_soft(),
136+
"flog2" => f.to_host().log2().to_soft(),
137+
"flog10" => f.to_host().log10().to_soft(),
140138
_ => bug!(),
141139
};
142-
let res = res.to_soft();
143140
let res = this.adjust_nan(res, &[f]);
144141
Scalar::from(res)
145142
}

src/tools/miri/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ mod eval;
8383
mod helpers;
8484
mod intrinsics;
8585
mod machine;
86+
mod math;
8687
mod mono_hash_map;
8788
mod operator;
8889
mod provenance_gc;

src/tools/miri/src/math.rs

+164
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,164 @@
1+
use rand::Rng as _;
2+
use rand::distributions::Distribution as _;
3+
use rustc_apfloat::Float as _;
4+
use rustc_apfloat::ieee::IeeeFloat;
5+
6+
/// Disturbes a floating-point result by a relative error on the order of (-2^scale, 2^scale).
7+
pub(crate) fn apply_random_float_error<F: rustc_apfloat::Float>(
8+
ecx: &mut crate::MiriInterpCx<'_>,
9+
val: F,
10+
err_scale: i32,
11+
) -> F {
12+
let rng = ecx.machine.rng.get_mut();
13+
// Generate a random integer in the range [0, 2^PREC).
14+
let dist = rand::distributions::Uniform::new(0, 1 << F::PRECISION);
15+
let err = F::from_u128(dist.sample(rng))
16+
.value
17+
.scalbn(err_scale.strict_sub(F::PRECISION.try_into().unwrap()));
18+
// give it a random sign
19+
let err = if rng.gen::<bool>() { -err } else { err };
20+
// multiple the value with (1+err)
21+
(val * (F::from_u128(1).value + err).value).value
22+
}
23+
24+
pub(crate) fn sqrt<S: rustc_apfloat::ieee::Semantics>(x: IeeeFloat<S>) -> IeeeFloat<S> {
25+
match x.category() {
26+
// preserve zero sign
27+
rustc_apfloat::Category::Zero => x,
28+
// propagate NaN
29+
rustc_apfloat::Category::NaN => x,
30+
// sqrt of negative number is NaN
31+
_ if x.is_negative() => IeeeFloat::NAN,
32+
// sqrt(∞) = ∞
33+
rustc_apfloat::Category::Infinity => IeeeFloat::INFINITY,
34+
rustc_apfloat::Category::Normal => {
35+
// Floating point precision, excluding the integer bit
36+
let prec = i32::try_from(S::PRECISION).unwrap() - 1;
37+
38+
// x = 2^(exp - prec) * mant
39+
// where mant is an integer with prec+1 bits
40+
// mant is a u128, which should be large enough for the largest prec (112 for f128)
41+
let mut exp = x.ilogb();
42+
let mut mant = x.scalbn(prec - exp).to_u128(128).value;
43+
44+
if exp % 2 != 0 {
45+
// Make exponent even, so it can be divided by 2
46+
exp -= 1;
47+
mant <<= 1;
48+
}
49+
50+
// Bit-by-bit (base-2 digit-by-digit) sqrt of mant.
51+
// mant is treated here as a fixed point number with prec fractional bits.
52+
// mant will be shifted left by one bit to have an extra fractional bit, which
53+
// will be used to determine the rounding direction.
54+
55+
// res is the truncated sqrt of mant, where one bit is added at each iteration.
56+
let mut res = 0u128;
57+
// rem is the remainder with the current res
58+
// rem_i = 2^i * ((mant<<1) - res_i^2)
59+
// starting with res = 0, rem = mant<<1
60+
let mut rem = mant << 1;
61+
// s_i = 2*res_i
62+
let mut s = 0u128;
63+
// d is used to iterate over bits, from high to low (d_i = 2^(-i))
64+
let mut d = 1u128 << (prec + 1);
65+
66+
// For iteration j=i+1, we need to find largest b_j = 0 or 1 such that
67+
// (res_i + b_j * 2^(-j))^2 <= mant<<1
68+
// Expanding (a + b)^2 = a^2 + b^2 + 2*a*b:
69+
// res_i^2 + (b_j * 2^(-j))^2 + 2 * res_i * b_j * 2^(-j) <= mant<<1
70+
// And rearranging the terms:
71+
// b_j^2 * 2^(-j) + 2 * res_i * b_j <= 2^j * (mant<<1 - res_i^2)
72+
// b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i
73+
74+
while d != 0 {
75+
// Probe b_j^2 * 2^(-j) + 2 * res_i * b_j <= rem_i with b_j = 1:
76+
// t = 2*res_i + 2^(-j)
77+
let t = s + d;
78+
if rem >= t {
79+
// b_j should be 1, so make res_j = res_i + 2^(-j) and adjust rem
80+
res += d;
81+
s += d + d;
82+
rem -= t;
83+
}
84+
// Adjust rem for next iteration
85+
rem <<= 1;
86+
// Shift iterator
87+
d >>= 1;
88+
}
89+
90+
// Remove extra fractional bit from result, rounding to nearest.
91+
// If the last bit is 0, then the nearest neighbor is definitely the lower one.
92+
// If the last bit is 1, it sounds like this may either be a tie (if there's
93+
// infinitely many 0s after this 1), or the nearest neighbor is the upper one.
94+
// However, since square roots are either exact or irrational, and an exact root
95+
// would lead to the last "extra" bit being 0, we can exclude a tie in this case.
96+
// We therefore always round up if the last bit is 1. When the last bit is 0,
97+
// adding 1 will not do anything since the shift will discard it.
98+
res = (res + 1) >> 1;
99+
100+
// Build resulting value with res as mantissa and exp/2 as exponent
101+
IeeeFloat::from_u128(res).value.scalbn(exp / 2 - prec)
102+
}
103+
}
104+
}
105+
106+
#[cfg(test)]
107+
mod tests {
108+
use rustc_apfloat::ieee::{DoubleS, HalfS, IeeeFloat, QuadS, SingleS};
109+
110+
use super::sqrt;
111+
112+
#[test]
113+
fn test_sqrt() {
114+
#[track_caller]
115+
fn test<S: rustc_apfloat::ieee::Semantics>(x: &str, expected: &str) {
116+
let x: IeeeFloat<S> = x.parse().unwrap();
117+
let expected: IeeeFloat<S> = expected.parse().unwrap();
118+
let result = sqrt(x);
119+
assert_eq!(result, expected);
120+
}
121+
122+
fn exact_tests<S: rustc_apfloat::ieee::Semantics>() {
123+
test::<S>("0", "0");
124+
test::<S>("1", "1");
125+
test::<S>("1.5625", "1.25");
126+
test::<S>("2.25", "1.5");
127+
test::<S>("4", "2");
128+
test::<S>("5.0625", "2.25");
129+
test::<S>("9", "3");
130+
test::<S>("16", "4");
131+
test::<S>("25", "5");
132+
test::<S>("36", "6");
133+
test::<S>("49", "7");
134+
test::<S>("64", "8");
135+
test::<S>("81", "9");
136+
test::<S>("100", "10");
137+
138+
test::<S>("0.5625", "0.75");
139+
test::<S>("0.25", "0.5");
140+
test::<S>("0.0625", "0.25");
141+
test::<S>("0.00390625", "0.0625");
142+
}
143+
144+
exact_tests::<HalfS>();
145+
exact_tests::<SingleS>();
146+
exact_tests::<DoubleS>();
147+
exact_tests::<QuadS>();
148+
149+
test::<SingleS>("2", "1.4142135");
150+
test::<DoubleS>("2", "1.4142135623730951");
151+
152+
test::<SingleS>("1.1", "1.0488088");
153+
test::<DoubleS>("1.1", "1.0488088481701516");
154+
155+
test::<SingleS>("2.2", "1.4832398");
156+
test::<DoubleS>("2.2", "1.4832396974191326");
157+
158+
test::<SingleS>("1.22101e-40", "1.10499205e-20");
159+
test::<DoubleS>("1.22101e-310", "1.1049932126488395e-155");
160+
161+
test::<SingleS>("3.4028235e38", "1.8446743e19");
162+
test::<DoubleS>("1.7976931348623157e308", "1.3407807929942596e154");
163+
}
164+
}

src/tools/miri/src/shims/x86/mod.rs

+4-23
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
use rand::Rng as _;
21
use rustc_abi::{ExternAbi, Size};
32
use rustc_apfloat::Float;
43
use rustc_apfloat::ieee::Single;
@@ -408,38 +407,20 @@ fn unary_op_f32<'tcx>(
408407
let div = (Single::from_u128(1).value / op).value;
409408
// Apply a relative error with a magnitude on the order of 2^-12 to simulate the
410409
// inaccuracy of RCP.
411-
let res = apply_random_float_error(ecx, div, -12);
410+
let res = math::apply_random_float_error(ecx, div, -12);
412411
interp_ok(Scalar::from_f32(res))
413412
}
414413
FloatUnaryOp::Rsqrt => {
415-
let op = op.to_scalar().to_u32()?;
416-
// FIXME using host floats
417-
let sqrt = Single::from_bits(f32::from_bits(op).sqrt().to_bits().into());
418-
let rsqrt = (Single::from_u128(1).value / sqrt).value;
414+
let op = op.to_scalar().to_f32()?;
415+
let rsqrt = (Single::from_u128(1).value / math::sqrt(op)).value;
419416
// Apply a relative error with a magnitude on the order of 2^-12 to simulate the
420417
// inaccuracy of RSQRT.
421-
let res = apply_random_float_error(ecx, rsqrt, -12);
418+
let res = math::apply_random_float_error(ecx, rsqrt, -12);
422419
interp_ok(Scalar::from_f32(res))
423420
}
424421
}
425422
}
426423

427-
/// Disturbes a floating-point result by a relative error on the order of (-2^scale, 2^scale).
428-
#[expect(clippy::arithmetic_side_effects)] // floating point arithmetic cannot panic
429-
fn apply_random_float_error<F: rustc_apfloat::Float>(
430-
ecx: &mut crate::MiriInterpCx<'_>,
431-
val: F,
432-
err_scale: i32,
433-
) -> F {
434-
let rng = ecx.machine.rng.get_mut();
435-
// generates rand(0, 2^64) * 2^(scale - 64) = rand(0, 1) * 2^scale
436-
let err = F::from_u128(rng.gen::<u64>().into()).value.scalbn(err_scale.strict_sub(64));
437-
// give it a random sign
438-
let err = if rng.gen::<bool>() { -err } else { err };
439-
// multiple the value with (1+err)
440-
(val * (F::from_u128(1).value + err).value).value
441-
}
442-
443424
/// Performs `which` operation on the first component of `op` and copies
444425
/// the other components. The result is stored in `dest`.
445426
fn unary_op_ss<'tcx>(

src/tools/miri/tests/pass/float.rs

+12-2
Original file line numberDiff line numberDiff line change
@@ -959,10 +959,20 @@ pub fn libm() {
959959
unsafe { ldexp(a, b) }
960960
}
961961

962-
assert_approx_eq!(64f32.sqrt(), 8f32);
963-
assert_approx_eq!(64f64.sqrt(), 8f64);
962+
assert_eq!(64_f32.sqrt(), 8_f32);
963+
assert_eq!(64_f64.sqrt(), 8_f64);
964+
assert_eq!(f32::INFINITY.sqrt(), f32::INFINITY);
965+
assert_eq!(f64::INFINITY.sqrt(), f64::INFINITY);
966+
assert_eq!(0.0_f32.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
967+
assert_eq!(0.0_f64.sqrt().total_cmp(&0.0), std::cmp::Ordering::Equal);
968+
assert_eq!((-0.0_f32).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
969+
assert_eq!((-0.0_f64).sqrt().total_cmp(&-0.0), std::cmp::Ordering::Equal);
964970
assert!((-5.0_f32).sqrt().is_nan());
965971
assert!((-5.0_f64).sqrt().is_nan());
972+
assert!(f32::NEG_INFINITY.sqrt().is_nan());
973+
assert!(f64::NEG_INFINITY.sqrt().is_nan());
974+
assert!(f32::NAN.sqrt().is_nan());
975+
assert!(f64::NAN.sqrt().is_nan());
966976

967977
assert_approx_eq!(25f32.powi(-2), 0.0016f32);
968978
assert_approx_eq!(23.2f64.powi(2), 538.24f64);

0 commit comments

Comments
 (0)