@@ -3,6 +3,7 @@ use core::ops::{Bound, RangeBounds};
3
3
4
4
pub trait Uniform : Sized + PartialEq + Eq + PartialOrd + Ord {
5
5
fn sample < F : FillBytes > ( fill : & mut F , min : Bound < & Self > , max : Bound < & Self > ) -> Option < Self > ;
6
+ fn sample_unbound < F : FillBytes > ( fill : & mut F ) -> Option < Self > ;
6
7
}
7
8
8
9
pub trait FillBytes {
@@ -19,8 +20,15 @@ pub trait FillBytes {
19
20
}
20
21
21
22
macro_rules! uniform_int {
22
- ( $ty: ident, $unsigned: ident $( , $smaller: ident) ? ) => {
23
+ ( $ty: ident, $unsigned: ident $( , $smaller: ident) * ) => {
23
24
impl Uniform for $ty {
25
+ #[ inline( always) ]
26
+ fn sample_unbound<F : FillBytes >( fill: & mut F ) -> Option <$ty> {
27
+ let mut bytes = [ 0u8 ; core:: mem:: size_of:: <$ty>( ) ] ;
28
+ fill. fill_bytes( & mut bytes) ?;
29
+ return Some ( <$ty>:: from_le_bytes( bytes) ) ;
30
+ }
31
+
24
32
#[ inline]
25
33
fn sample<F : FillBytes >( fill: & mut F , min: Bound <& $ty>, max: Bound <& $ty>) -> Option <$ty> {
26
34
match ( min, max) {
@@ -45,19 +53,11 @@ macro_rules! uniform_int {
45
53
| ( Bound :: Unbounded , Bound :: Included ( & $ty:: MAX ) )
46
54
| ( Bound :: Included ( & $ty:: MIN ) , Bound :: Unbounded )
47
55
| ( Bound :: Included ( & $ty:: MIN ) , Bound :: Included ( & $ty:: MAX ) ) => {
48
- let mut bytes = [ 0u8 ; core:: mem:: size_of:: <$ty>( ) ] ;
49
- fill. fill_bytes( & mut bytes) ?;
50
- return Some ( <$ty>:: from_le_bytes( bytes) ) ;
56
+ return Self :: sample_unbound( fill) ;
51
57
}
52
58
_ => { }
53
59
}
54
60
55
- // if we're in direct mode, just sample a value and check if it's within the provided range
56
- if fill. mode( ) == DriverMode :: Direct {
57
- return Self :: sample( fill, Bound :: Unbounded , Bound :: Unbounded )
58
- . filter( |value| ( min, max) . contains( value) ) ;
59
- }
60
-
61
61
let lower = match min {
62
62
Bound :: Included ( & v) => v,
63
63
Bound :: Excluded ( v) => v. saturating_add( 1 ) ,
@@ -90,16 +90,15 @@ macro_rules! uniform_int {
90
90
91
91
return Some ( value) ;
92
92
}
93
- } ) ?
93
+ } ) *
94
94
95
- let value: $unsigned = Uniform :: sample ( fill, Bound :: Unbounded , Bound :: Unbounded ) ?;
95
+ let value: $unsigned = Uniform :: sample_unbound ( fill) ?;
96
96
97
97
if cfg!( test) {
98
98
assert!( range_inclusive < $unsigned:: MAX , "range inclusive should always be less than the max value" ) ;
99
99
}
100
100
let range_exclusive = range_inclusive. wrapping_add( 1 ) ;
101
- // TODO make this less biased
102
- let value = value % range_exclusive;
101
+ let value = value. scale( range_exclusive) ;
103
102
let value = value as $ty;
104
103
let value = lower. wrapping_add( value) ;
105
104
@@ -118,23 +117,77 @@ uniform_int!(u8, u8);
118
117
uniform_int ! ( i8 , u8 ) ;
119
118
uniform_int ! ( u16 , u16 , u8 ) ;
120
119
uniform_int ! ( i16 , u16 , u8 ) ;
121
- uniform_int ! ( u32 , u32 , u16 ) ;
122
- uniform_int ! ( i32 , u32 , u16 ) ;
123
- uniform_int ! ( u64 , u64 , u32 ) ;
124
- uniform_int ! ( i64 , u64 , u32 ) ;
125
- uniform_int ! ( u128 , u128 , u64 ) ;
126
- uniform_int ! ( i128 , u128 , u64 ) ;
127
- uniform_int ! ( usize , usize , u64 ) ;
128
- uniform_int ! ( isize , usize , u64 ) ;
120
+ uniform_int ! ( u32 , u32 , u8 , u16 ) ;
121
+ uniform_int ! ( i32 , u32 , u8 , u16 ) ;
122
+ uniform_int ! ( u64 , u64 , u8 , u16 , u32 ) ;
123
+ uniform_int ! ( i64 , u64 , u8 , u16 , u32 ) ;
124
+ uniform_int ! ( usize , usize , u8 , u16 , u32 ) ;
125
+ uniform_int ! ( isize , usize , u8 , u16 , u32 ) ;
126
+ uniform_int ! ( u128 , u128 , u8 , u16 , u32 , u64 ) ;
127
+ uniform_int ! ( i128 , u128 , u8 , u16 , u32 , u64 ) ;
128
+
129
+ trait Scaled : Sized {
130
+ fn scale ( self , range : Self ) -> Self ;
131
+ }
132
+
133
+ macro_rules! scaled {
134
+ ( $s: ty, $upper: ty) => {
135
+ impl Scaled for $s {
136
+ #[ inline( always) ]
137
+ fn scale( self , range: Self ) -> Self {
138
+ // similar approach to Lemire random sampling
139
+ // see https://lemire.me/blog/2019/06/06/nearly-divisionless-random-integer-generation-on-various-systems/
140
+ let m = self as $upper * range as $upper;
141
+ ( m >> Self :: BITS ) as Self
142
+ }
143
+ }
144
+ } ;
145
+ }
146
+
147
+ scaled ! ( u8 , u16 ) ;
148
+ scaled ! ( u16 , u32 ) ;
149
+ scaled ! ( u32 , u64 ) ;
150
+ scaled ! ( u64 , u128 ) ;
151
+ scaled ! ( usize , u128 ) ;
152
+
153
+ impl Scaled for u128 {
154
+ #[ inline( always) ]
155
+ fn scale ( self , range : Self ) -> Self {
156
+ // adapted from mulddi3 https://github.com/llvm/llvm-project/blob/6a3982f8b7e37987659706cb3e6427c54c9bc7ce/compiler-rt/lib/builtins/multi3.c#L19
157
+ const BITS_IN_DWORD_2 : u32 = 64 ;
158
+ const LOWER_MASK : u128 = u128:: MAX >> BITS_IN_DWORD_2 ;
159
+
160
+ let a = self ;
161
+ let b = range;
162
+
163
+ let mut low = ( a & LOWER_MASK ) * ( b & LOWER_MASK ) ;
164
+ let mut t = low >> BITS_IN_DWORD_2 ;
165
+ low &= LOWER_MASK ;
166
+ t += ( a >> BITS_IN_DWORD_2 ) * ( b & LOWER_MASK ) ;
167
+ low += ( t & LOWER_MASK ) << BITS_IN_DWORD_2 ;
168
+ let mut high = t >> BITS_IN_DWORD_2 ;
169
+ t = low >> BITS_IN_DWORD_2 ;
170
+ low &= LOWER_MASK ;
171
+ t += ( b >> BITS_IN_DWORD_2 ) * ( a & LOWER_MASK ) ;
172
+ low += ( t & LOWER_MASK ) << BITS_IN_DWORD_2 ;
173
+ high += t >> BITS_IN_DWORD_2 ;
174
+ high += ( a >> BITS_IN_DWORD_2 ) * ( b >> BITS_IN_DWORD_2 ) ;
175
+
176
+ // discard the low bits
177
+ let _ = low;
178
+
179
+ high
180
+ }
181
+ }
129
182
130
183
impl Uniform for char {
184
+ #[ inline( always) ]
185
+ fn sample_unbound < F : FillBytes > ( fill : & mut F ) -> Option < Self > {
186
+ Self :: sample ( fill, Bound :: Unbounded , Bound :: Unbounded )
187
+ }
188
+
131
189
#[ inline]
132
190
fn sample < F : FillBytes > ( fill : & mut F , min : Bound < & Self > , max : Bound < & Self > ) -> Option < Self > {
133
- if fill. mode ( ) == DriverMode :: Direct {
134
- let value = u32:: sample ( fill, Bound :: Unbounded , Bound :: Unbounded ) ?;
135
- return char:: from_u32 ( value) ;
136
- }
137
-
138
191
const START : u32 = 0xD800 ;
139
192
const LEN : u32 = 0xE000 - START ;
140
193
@@ -174,6 +227,15 @@ impl Uniform for char {
174
227
#[ cfg( test) ]
175
228
mod tests {
176
229
use super :: * ;
230
+ use core:: fmt;
231
+
232
+ #[ test]
233
+ fn scaled_u128_test ( ) {
234
+ assert_eq ! ( 0u128 . scale( 3 ) , 0 ) ;
235
+ assert_eq ! ( u128 :: MAX . scale( 3 ) , 2 ) ;
236
+ assert_eq ! ( ( u128 :: MAX - 1 ) . scale( 3 ) , 2 ) ;
237
+ assert_eq ! ( ( u128 :: MAX / 2 ) . scale( 3 ) , 1 ) ;
238
+ }
177
239
178
240
#[ derive( Clone , Copy , Debug ) ]
179
241
struct Byte {
@@ -210,7 +272,7 @@ mod tests {
210
272
}
211
273
}
212
274
213
- #[ derive( Clone , Copy , Debug , PartialEq ) ]
275
+ #[ derive( Clone , Copy , PartialEq ) ]
214
276
struct Seen < T : SeenValue > ( [ bool ; 256 ] , core:: marker:: PhantomData < T > ) ;
215
277
216
278
impl < T : SeenValue > Default for Seen < T > {
@@ -219,18 +281,35 @@ mod tests {
219
281
}
220
282
}
221
283
284
+ impl < T : SeenValue > fmt:: Debug for Seen < T > {
285
+ fn fmt ( & self , f : & mut fmt:: Formatter ) -> fmt:: Result {
286
+ f. debug_list ( )
287
+ . entries (
288
+ self . 0
289
+ . iter ( )
290
+ . enumerate ( )
291
+ . filter_map ( |( idx, seen) | if * seen { Some ( idx) } else { None } ) ,
292
+ )
293
+ . finish ( )
294
+ }
295
+ }
296
+
222
297
impl < T : SeenValue > Seen < T > {
223
298
fn insert ( & mut self , v : T ) {
224
299
self . 0 [ v. index ( ) ] = true ;
225
300
}
226
301
}
227
302
228
303
trait SeenValue : Copy + Uniform + core:: fmt:: Debug {
304
+ const ENTRIES : usize ;
305
+
229
306
fn index ( self ) -> usize ;
230
307
fn fill_expected ( min : Bound < Self > , max : Bound < Self > , seen : & mut Seen < Self > ) ;
231
308
}
232
309
233
310
impl SeenValue for u8 {
311
+ const ENTRIES : usize = 256 ;
312
+
234
313
fn index ( self ) -> usize {
235
314
self as _
236
315
}
@@ -245,6 +324,8 @@ mod tests {
245
324
}
246
325
247
326
impl SeenValue for i8 {
327
+ const ENTRIES : usize = 256 ;
328
+
248
329
fn index ( self ) -> usize {
249
330
( self as isize + -( i8:: MIN as isize ) ) . try_into ( ) . unwrap ( )
250
331
}
0 commit comments