Skip to content

Commit fea585d

Browse files
committed
Greatly sped up checked_isqrt and isqrt methods
* Uses a lookup table for 8-bit integers and then the Karatsuba square root algorithm for larger integers. * Includes optimization hints that give the compiler the exact numeric range of results.
1 parent d180572 commit fea585d

File tree

5 files changed

+237
-35
lines changed

5 files changed

+237
-35
lines changed

library/core/src/num/int_macros.rs

+27-9
Original file line numberDiff line numberDiff line change
@@ -1580,7 +1580,18 @@ macro_rules! int_impl {
15801580
if self < 0 {
15811581
None
15821582
} else {
1583-
Some((self as $UnsignedT).isqrt() as Self)
1583+
let result = crate::num::int_sqrt::$ActualT(self as $ActualT) as $SelfT;
1584+
1585+
// SAFETY: Inform the optimizer that square roots of
1586+
// nonnegative integers are nonnegative and what the maximum
1587+
// result is.
1588+
unsafe {
1589+
crate::hint::assert_unchecked(result >= 0);
1590+
const MAX_RESULT: $SelfT = crate::num::int_sqrt::$ActualT($ActualT::MAX) as $SelfT;
1591+
crate::hint::assert_unchecked(result <= MAX_RESULT);
1592+
}
1593+
1594+
Some(result)
15841595
}
15851596
}
15861597

@@ -2766,14 +2777,21 @@ macro_rules! int_impl {
27662777
without modifying the original"]
27672778
#[inline]
27682779
pub const fn isqrt(self) -> Self {
2769-
// I would like to implement it as
2770-
// ```
2771-
// self.checked_isqrt().expect("argument of integer square root must be non-negative")
2772-
// ```
2773-
// but `expect` is not yet stable as a `const fn`.
2774-
match self.checked_isqrt() {
2775-
Some(sqrt) => sqrt,
2776-
None => panic!("argument of integer square root must be non-negative"),
2780+
if self < 0 {
2781+
crate::num::int_sqrt::panic_for_negative_argument();
2782+
} else {
2783+
let result = crate::num::int_sqrt::$ActualT(self as $ActualT) as $SelfT;
2784+
2785+
// SAFETY: Inform the optimizer that square roots of
2786+
// nonnegative integers are nonnegative and what the maximum
2787+
// result is.
2788+
unsafe {
2789+
crate::hint::assert_unchecked(result >= 0);
2790+
const MAX_RESULT: $SelfT = crate::num::int_sqrt::$ActualT($ActualT::MAX) as $SelfT;
2791+
crate::hint::assert_unchecked(result <= MAX_RESULT);
2792+
}
2793+
2794+
result
27772795
}
27782796
}
27792797

library/core/src/num/int_sqrt.rs

+190
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
/// These functions compute the integer square root of their type, assuming
2+
/// that someone has already checked that the value is nonnegative.
3+
4+
const ISQRT_AND_REMAINDER_8_BIT: [(u8, u8); 256] = {
5+
let mut result = [(0, 0); 256];
6+
7+
let mut sqrt = 0;
8+
let mut i = 0;
9+
'outer: loop {
10+
let mut remaining = 2 * sqrt + 1;
11+
while remaining > 0 {
12+
result[i as usize] = (sqrt, 2 * sqrt + 1 - remaining);
13+
i += 1;
14+
if i >= result.len() {
15+
break 'outer;
16+
}
17+
remaining -= 1;
18+
}
19+
sqrt += 1;
20+
}
21+
22+
result
23+
};
24+
25+
// `#[inline(always)]` because the programmer-accessible functions will use
26+
// this internally and the contents of this should be inlined there.
27+
#[inline(always)]
28+
pub const fn u8(n: u8) -> u8 {
29+
ISQRT_AND_REMAINDER_8_BIT[n as usize].0
30+
}
31+
32+
#[inline(always)]
33+
const fn intermediate_u8(n: u8) -> (u8, u8) {
34+
ISQRT_AND_REMAINDER_8_BIT[n as usize]
35+
}
36+
37+
macro_rules! karatsuba_isqrt {
38+
($FullBitsT:ty, $fn:ident, $intermediate_fn:ident, $HalfBitsT:ty, $half_fn:ident, $intermediate_half_fn:ident) => {
39+
// `#[inline(always)]` because the programmer-accessible functions will
40+
// use this internally and the contents of this should be inlined
41+
// there.
42+
#[inline(always)]
43+
pub const fn $fn(mut n: $FullBitsT) -> $FullBitsT {
44+
// Performs a Karatsuba square root.
45+
// https://web.archive.org/web/20230511212802/https://inria.hal.science/inria-00072854v1/file/RR-3805.pdf
46+
47+
const HALF_BITS: u32 = <$FullBitsT>::BITS >> 1;
48+
const QUARTER_BITS: u32 = <$FullBitsT>::BITS >> 2;
49+
50+
let leading_zeros = n.leading_zeros();
51+
let result = if leading_zeros >= HALF_BITS {
52+
$half_fn(n as $HalfBitsT) as $FullBitsT
53+
} else {
54+
// Either the most-significant bit or its neighbor must be a one, so we shift left to make that happen.
55+
let precondition_shift = leading_zeros & (HALF_BITS - 2);
56+
n <<= precondition_shift;
57+
58+
let hi = (n >> HALF_BITS) as $HalfBitsT;
59+
let lo = n & (<$HalfBitsT>::MAX as $FullBitsT);
60+
61+
let (s_prime, r_prime) = $intermediate_half_fn(hi);
62+
63+
let numerator = ((r_prime as $FullBitsT) << QUARTER_BITS) | (lo >> QUARTER_BITS);
64+
let denominator = (s_prime as $FullBitsT) << 1;
65+
66+
let q = numerator / denominator;
67+
let u = numerator % denominator;
68+
69+
let mut s = (s_prime << QUARTER_BITS) as $FullBitsT + q;
70+
if ((u << QUARTER_BITS) | (lo & ((1 << QUARTER_BITS) - 1))) < q * q {
71+
s -= 1;
72+
}
73+
s >> (precondition_shift >> 1)
74+
};
75+
76+
result
77+
}
78+
79+
const fn $intermediate_fn(mut n: $FullBitsT) -> ($FullBitsT, $FullBitsT) {
80+
// Performs a Karatsuba square root.
81+
// https://web.archive.org/web/20230511212802/https://inria.hal.science/inria-00072854v1/file/RR-3805.pdf
82+
83+
const HALF_BITS: u32 = <$FullBitsT>::BITS >> 1;
84+
const QUARTER_BITS: u32 = <$FullBitsT>::BITS >> 2;
85+
86+
let leading_zeros = n.leading_zeros();
87+
let result = if leading_zeros >= HALF_BITS {
88+
let (s, r) = $intermediate_half_fn(n as $HalfBitsT);
89+
(s as $FullBitsT, r as $FullBitsT)
90+
} else {
91+
// Either the most-significant bit or its neighbor must be a one, so we shift left to make that happen.
92+
let precondition_shift = leading_zeros & (HALF_BITS - 2);
93+
n <<= precondition_shift;
94+
95+
let hi = (n >> HALF_BITS) as $HalfBitsT;
96+
let lo = n & (<$HalfBitsT>::MAX as $FullBitsT);
97+
98+
let (s_prime, r_prime) = $intermediate_half_fn(hi);
99+
100+
let numerator = ((r_prime as $FullBitsT) << QUARTER_BITS) | (lo >> QUARTER_BITS);
101+
let denominator = (s_prime as $FullBitsT) << 1;
102+
103+
let q = numerator / denominator;
104+
let u = numerator % denominator;
105+
106+
let mut s = (s_prime << QUARTER_BITS) as $FullBitsT + q;
107+
let (mut r, overflow) =
108+
((u << QUARTER_BITS) | (lo & ((1 << QUARTER_BITS) - 1))).overflowing_sub(q * q);
109+
if overflow {
110+
r = r.wrapping_add((s << 1) - 1);
111+
s -= 1;
112+
}
113+
(s >> (precondition_shift >> 1), r >> (precondition_shift >> 1))
114+
};
115+
116+
result
117+
}
118+
};
119+
}
120+
121+
karatsuba_isqrt!(u16, u16, intermediate_u16, u8, u8, intermediate_u8);
122+
karatsuba_isqrt!(u32, u32, intermediate_u32, u16, u16, intermediate_u16);
123+
karatsuba_isqrt!(u64, u64, intermediate_u64, u32, u32, intermediate_u32);
124+
karatsuba_isqrt!(u128, u128, _intermediate_u128, u64, u64, intermediate_u64);
125+
126+
#[cfg(target_pointer_width = "16")]
127+
#[inline(always)]
128+
pub const fn usize(n: usize) -> usize {
129+
u16(n as u16) as usize
130+
}
131+
132+
#[cfg(target_pointer_width = "32")]
133+
#[inline(always)]
134+
pub const fn usize(n: usize) -> usize {
135+
u32(n as u32) as usize
136+
}
137+
138+
#[cfg(target_pointer_width = "64")]
139+
#[inline(always)]
140+
pub const fn usize(n: usize) -> usize {
141+
u64(n as u64) as usize
142+
}
143+
144+
// 0 <= val <= i8::MAX
145+
#[inline(always)]
146+
pub const fn i8(n: i8) -> i8 {
147+
u8(n as u8) as i8
148+
}
149+
150+
// 0 <= val <= i16::MAX
151+
#[inline(always)]
152+
pub const fn i16(n: i16) -> i16 {
153+
u16(n as u16) as i16
154+
}
155+
156+
// 0 <= val <= i32::MAX
157+
#[inline(always)]
158+
pub const fn i32(n: i32) -> i32 {
159+
u32(n as u32) as i32
160+
}
161+
162+
// 0 <= val <= i64::MAX
163+
#[inline(always)]
164+
pub const fn i64(n: i64) -> i64 {
165+
u64(n as u64) as i64
166+
}
167+
168+
// 0 <= val <= i128::MAX
169+
#[inline(always)]
170+
pub const fn i128(n: i128) -> i128 {
171+
u128(n as u128) as i128
172+
}
173+
174+
/*
175+
This function is not used.
176+
177+
// 0 <= val <= isize::MAX
178+
#[inline(always)]
179+
pub const fn isize(n: isize) -> isize {
180+
usize(n as usize) as isize
181+
}
182+
*/
183+
184+
/// Instantiate this panic logic once, rather than for all the ilog methods
185+
/// on every single primitive type.
186+
#[cold]
187+
#[track_caller]
188+
pub const fn panic_for_negative_argument() -> ! {
189+
panic!("argument of integer square root cannot be negative")
190+
}

library/core/src/num/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ mod uint_macros; // import uint_impl!
4343

4444
mod error;
4545
mod int_log10;
46+
mod int_sqrt;
4647
mod nonzero;
4748
mod overflow_panic;
4849
mod saturating;

library/core/src/num/nonzero.rs

+11-23
Original file line numberDiff line numberDiff line change
@@ -1247,31 +1247,19 @@ macro_rules! nonzero_integer_signedness_dependent_methods {
12471247
without modifying the original"]
12481248
#[inline]
12491249
pub const fn isqrt(self) -> Self {
1250-
// The algorithm is based on the one presented in
1251-
// <https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)>
1252-
// which cites as source the following C code:
1253-
// <https://web.archive.org/web/20120306040058/http://medialab.freaknet.org/martin/src/sqrt/sqrt.c>.
1254-
1255-
let mut op = self.get();
1256-
let mut res = 0;
1257-
let mut one = 1 << (self.ilog2() & !1);
1258-
1259-
while one != 0 {
1260-
if op >= res + one {
1261-
op -= res + one;
1262-
res = (res >> 1) + one;
1263-
} else {
1264-
res >>= 1;
1265-
}
1266-
one >>= 2;
1250+
let result = super::int_sqrt::$Int(self.get());
1251+
1252+
// SAFETY: Inform the optimizer that square roots of positive
1253+
// integers are positive and what the maximum result is.
1254+
unsafe {
1255+
hint::assert_unchecked(result > 0);
1256+
const MAX_RESULT: $Int = super::int_sqrt::$Int($Int::MAX);
1257+
hint::assert_unchecked(result <= MAX_RESULT);
12671258
}
12681259

1269-
// SAFETY: The result fits in an integer with half as many bits.
1270-
// Inform the optimizer about it.
1271-
unsafe { hint::assert_unchecked(res < 1 << (Self::BITS / 2)) };
1272-
1273-
// SAFETY: The square root of an integer >= 1 is always >= 1.
1274-
unsafe { Self::new_unchecked(res) }
1260+
// SAFETY: The square root of a positive integer is always
1261+
// positive.
1262+
unsafe { Self::new_unchecked(result) }
12751263
}
12761264
};
12771265

library/core/src/num/uint_macros.rs

+8-3
Original file line numberDiff line numberDiff line change
@@ -2588,10 +2588,15 @@ macro_rules! uint_impl {
25882588
without modifying the original"]
25892589
#[inline]
25902590
pub const fn isqrt(self) -> Self {
2591-
match NonZero::new(self) {
2592-
Some(x) => x.isqrt().get(),
2593-
None => 0,
2591+
let result = crate::num::int_sqrt::$ActualT(self as $ActualT) as $SelfT;
2592+
2593+
// SAFETY: Inform the optimizer of what the maximum result is.
2594+
unsafe {
2595+
const MAX_RESULT: $SelfT = crate::num::int_sqrt::$ActualT($ActualT::MAX) as $SelfT;
2596+
crate::hint::assert_unchecked(result <= MAX_RESULT);
25942597
}
2598+
2599+
result
25952600
}
25962601

25972602
/// Performs Euclidean division.

0 commit comments

Comments
 (0)