|
30 | 30 | //! ```
|
31 | 31 | //! use rand::{Rng, thread_rng};
|
32 | 32 | //! use rand::distributions::Uniform;
|
33 |
| -//! |
| 33 | +//! |
34 | 34 | //! let mut rng = thread_rng();
|
35 | 35 | //! let side = Uniform::new(-10.0, 10.0);
|
36 |
| -//! |
| 36 | +//! |
37 | 37 | //! // sample between 1 and 10 points
|
38 | 38 | //! for _ in 0..rng.gen_range(1, 11) {
|
39 | 39 | //!         // sample a point from the square with sides from -10 to 10 in two dimensions
|
@@ -482,6 +482,150 @@ uniform_int_impl! { usize, isize, usize, isize, usize }
|
482 | 482 | #[cfg(rust_1_26)]
|
483 | 483 | uniform_int_impl! { u128, u128, u128, i128, u128 }
|
484 | 484 |
|
// Generates `SampleUniform` + `UniformSampler` implementations for
// `packed_simd` integer vector types, using per-lane widening-multiply
// rejection sampling.
//
// Single-type arm: `($ty, $unsigned, $u_scalar)` — `$ty` is the vector being
// implemented, `$unsigned` the unsigned vector of the same shape used for the
// internal arithmetic, and `$u_scalar` its unsigned lane (scalar) type.
// Bulk arm: a list of `(unsigned, signed)` vector pairs sharing one
// `$u_scalar`, each pair expanded through the single-type arm (the signed
// vector reuses the unsigned vector for its arithmetic).
#[cfg(feature = "simd_support")]
macro_rules! uniform_simd_int_impl {
    ($ty:ident, $unsigned:ident, $u_scalar:ident) => {
        // The "pick the largest zone that can fit in an `u32`" optimization
        // is less useful here. Multiple lanes complicate things, we don't
        // know the PRNG's minimal output size, and casting to a larger vector
        // is generally a bad idea for SIMD performance. The user can still
        // implement it manually.

        // TODO: look into `Uniform::<u32x4>::new(0u32, 100)` functionality
        // perhaps `impl SampleUniform for $u_scalar`?
        impl SampleUniform for $ty {
            type Sampler = UniformInt<$ty>;
        }

        impl UniformSampler for UniformInt<$ty> {
            type X = $ty;

            #[inline] // if the range is constant, this helps LLVM to do the
                      // calculations at compile-time.
            fn new<B1, B2>(low_b: B1, high_b: B2) -> Self
                where B1: SampleBorrow<Self::X> + Sized,
                      B2: SampleBorrow<Self::X> + Sized
            {
                let low = *low_b.borrow();
                let high = *high_b.borrow();
                // Every lane of `low` must be strictly below the matching
                // lane of `high` (`.lt(..).all()` is the per-lane check).
                assert!(low.lt(high).all(), "Uniform::new called with `low >= high`");
                // Half-open range: delegate to the inclusive constructor with
                // `high - 1`; no lane can underflow because each `high` lane
                // is strictly greater than its `low` lane.
                UniformSampler::new_inclusive(low, high - 1)
            }

            #[inline] // if the range is constant, this helps LLVM to do the
                      // calculations at compile-time.
            fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Self
                where B1: SampleBorrow<Self::X> + Sized,
                      B2: SampleBorrow<Self::X> + Sized
            {
                let low = *low_b.borrow();
                let high = *high_b.borrow();
                assert!(low.le(high).all(),
                        "Uniform::new_inclusive called with `low > high`");
                let unsigned_max = ::core::$u_scalar::MAX;

                // NOTE: these may need to be replaced with explicitly
                // wrapping operations if `packed_simd` changes
                // A lane requesting the full unsigned range wraps to 0 here.
                let range: $unsigned = ((high - low) + 1).cast();
                // `% 0` will panic at runtime.
                let not_full_range = range.gt($unsigned::splat(0));
                // replacing 0 with `unsigned_max` allows a faster `select`
                // with bitwise OR
                let modulo = not_full_range.select(range, $unsigned::splat(unsigned_max));
                // wrapping addition
                // Per lane: the count of top-of-range values that must be
                // rejected to keep sampling unbiased (2^bits mod range);
                // 0 when the lane covers the full range.
                let ints_to_reject = (unsigned_max - range + 1) % modulo;
                // When `range` is 0, `lo` of `v.wmul(range)` will always be
                // zero which means only one sample is needed.
                let zone = unsigned_max - ints_to_reject;

                UniformInt {
                    low: low,
                    // These are really $unsigned values, but store as $ty:
                    range: range.cast(),
                    zone: zone.cast(),
                }
            }

            fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
                let range: $unsigned = self.range.cast();
                let zone: $unsigned = self.zone.cast();

                // This might seem very slow, generating a whole new
                // SIMD vector for every sample rejection. For most uses
                // though, the chance of rejection is small and provides good
                // general performance. With multiple lanes, that chance is
                // multiplied. To mitigate this, we replace only the lanes of
                // the vector which fail, iteratively reducing the chance of
                // rejection. The replacement method does however add a little
                // overhead. Benchmarking or calculating probabilities might
                // reveal contexts where this replacement method is slower.
                let mut v: $unsigned = rng.gen();
                loop {
                    // Widening multiply: `hi` is the scaled sample, `lo` the
                    // per-lane rejection test value.
                    let (hi, lo) = v.wmul(range);
                    let mask = lo.le(zone);
                    if mask.all() {
                        let hi: $ty = hi.cast();
                        // wrapping addition
                        let result = self.low + hi;
                        // `select` here compiles to a blend operation
                        // When `range.eq(0).none()` the compare and blend
                        // operations are avoided.
                        // Full-range lanes (range == 0) return the raw
                        // sample `v` unchanged.
                        let v: $ty = v.cast();
                        return range.gt($unsigned::splat(0)).select(result, v);
                    }
                    // Replace only the failing lanes
                    v = mask.select(v, rng.gen());
                }
            }
        }
    };

    // bulk implementation
    ($(($unsigned:ident, $signed:ident),)+ $u_scalar:ident) => {
        $(
            uniform_simd_int_impl!($unsigned, $unsigned, $u_scalar);
            uniform_simd_int_impl!($signed, $unsigned, $u_scalar);
        )+
    };
}
| 591 | + |
// Instantiate the SIMD uniform samplers for every `packed_simd` integer
// vector shape. Each `(unsigned, signed)` pair shares the unsigned vector of
// the same shape for its internal arithmetic (see the bulk arm of
// `uniform_simd_int_impl!`); the trailing identifier is the unsigned lane
// scalar type used for `MAX`/`splat`.
#[cfg(feature = "simd_support")]
uniform_simd_int_impl! {
    (u64x2, i64x2),
    (u64x4, i64x4),
    (u64x8, i64x8),
    u64
}

#[cfg(feature = "simd_support")]
uniform_simd_int_impl! {
    (u32x2, i32x2),
    (u32x4, i32x4),
    (u32x8, i32x8),
    (u32x16, i32x16),
    u32
}

#[cfg(feature = "simd_support")]
uniform_simd_int_impl! {
    (u16x2, i16x2),
    (u16x4, i16x4),
    (u16x8, i16x8),
    (u16x16, i16x16),
    (u16x32, i16x32),
    u16
}

#[cfg(feature = "simd_support")]
uniform_simd_int_impl! {
    (u8x2, i8x2),
    (u8x4, i8x4),
    (u8x8, i8x8),
    (u8x16, i8x16),
    (u8x32, i8x32),
    (u8x64, i8x64),
    u8
}
485 | 629 |
|
486 | 630 |
|
487 | 631 | /// The back-end implementing [`UniformSampler`] for floating-point types.
|
@@ -817,50 +961,86 @@ mod tests {
|
817 | 961 |
|
#[test]
fn test_integers() {
    // Bring the primitive modules into scope so `$ty::MIN`/`$ty::MAX` resolve
    // inside the `t!` macro below (pre-associated-const style).
    use core::{i8, i16, i32, i64, isize};
    use core::{u8, u16, u32, u64, usize};
    #[cfg(rust_1_26)]
    use core::{i128, u128};

    let mut rng = ::test::rng(251);
    // `t!` has three arms:
    //  1. base: one type + explicit (low, high) pairs + `<=`/`<` comparators,
    //     exercising every `Uniform` construction path plus `gen_range`;
    //  2. scalar bulk: expands to the base arm with scalar ranges and the
    //     ordinary `<=`/`<` operators;
    //  3. simd bulk: expands to the base arm with `splat` ranges and per-lane
    //     `.le(..).all()` / `.lt(..).all()` comparators.
    macro_rules! t {
        ($ty:ident, $v:expr, $le:expr, $lt:expr) => {{
            for &(low, high) in $v.iter() {
                // Half-open constructor: samples must satisfy low <= v < high.
                let my_uniform = Uniform::new(low, high);
                for _ in 0..1000 {
                    let v: $ty = rng.sample(my_uniform);
                    assert!($le(low, v) && $lt(v, high));
                }

                // Inclusive constructor: low <= v <= high.
                let my_uniform = Uniform::new_inclusive(low, high);
                for _ in 0..1000 {
                    let v: $ty = rng.sample(my_uniform);
                    assert!($le(low, v) && $le(v, high));
                }

                // Borrowed-bound overloads (SampleBorrow) must behave the same.
                let my_uniform = Uniform::new(&low, high);
                for _ in 0..1000 {
                    let v: $ty = rng.sample(my_uniform);
                    assert!($le(low, v) && $lt(v, high));
                }

                let my_uniform = Uniform::new_inclusive(&low, &high);
                for _ in 0..1000 {
                    let v: $ty = rng.sample(my_uniform);
                    assert!($le(low, v) && $le(v, high));
                }

                // One-shot convenience path.
                for _ in 0..1000 {
                    let v: $ty = rng.gen_range(low, high);
                    assert!($le(low, v) && $lt(v, high));
                }
            }
        }};

        // scalar bulk
        ($($ty:ident),*) => {{
            $(t!(
                $ty,
                [(0, 10), (10, 127), ($ty::MIN, $ty::MAX)],
                |x, y| x <= y,
                |x, y| x < y
            );)*
        }};

        // simd bulk
        ($($ty:ident),* => $scalar:ident) => {{
            $(t!(
                $ty,
                [
                    ($ty::splat(0), $ty::splat(10)),
                    ($ty::splat(10), $ty::splat(127)),
                    ($ty::splat($scalar::MIN), $ty::splat($scalar::MAX)),
                ],
                // Per-lane comparisons: every lane must be in range.
                |x: $ty, y| x.le(y).all(),
                |x: $ty, y| x.lt(y).all()
            );)*
        }};
    }
    t!(i8, i16, i32, i64, isize,
       u8, u16, u32, u64, usize);
    #[cfg(rust_1_26)]
    t!(i128, u128);

    #[cfg(feature = "simd_support")]
    {
        t!(u8x2, u8x4, u8x8, u8x16, u8x32, u8x64 => u8);
        t!(i8x2, i8x4, i8x8, i8x16, i8x32, i8x64 => i8);
        t!(u16x2, u16x4, u16x8, u16x16, u16x32 => u16);
        t!(i16x2, i16x4, i16x8, i16x16, i16x32 => i16);
        t!(u32x2, u32x4, u32x8, u32x16 => u32);
        t!(i32x2, i32x4, i32x8, i32x16 => i32);
        t!(u64x2, u64x4, u64x8 => u64);
        t!(i64x2, i64x4, i64x8 => i64);
    }
}
|
865 | 1045 |
|
866 | 1046 | #[test]
|
@@ -932,13 +1112,16 @@ mod tests {
|
932 | 1112 |
|
933 | 1113 | t!(f32, f32, 32 - 23);
|
934 | 1114 | t!(f64, f64, 64 - 52);
|
935 |
| - #[cfg(feature="simd_support")] t!(f32x2, f32, 32 - 23); |
936 |
| - #[cfg(feature="simd_support")] t!(f32x4, f32, 32 - 23); |
937 |
| - #[cfg(feature="simd_support")] t!(f32x8, f32, 32 - 23); |
938 |
| - #[cfg(feature="simd_support")] t!(f32x16, f32, 32 - 23); |
939 |
| - #[cfg(feature="simd_support")] t!(f64x2, f64, 64 - 52); |
940 |
| - #[cfg(feature="simd_support")] t!(f64x4, f64, 64 - 52); |
941 |
| - #[cfg(feature="simd_support")] t!(f64x8, f64, 64 - 52); |
| 1115 | + #[cfg(feature="simd_support")] |
| 1116 | + { |
| 1117 | + t!(f32x2, f32, 32 - 23); |
| 1118 | + t!(f32x4, f32, 32 - 23); |
| 1119 | + t!(f32x8, f32, 32 - 23); |
| 1120 | + t!(f32x16, f32, 32 - 23); |
| 1121 | + t!(f64x2, f64, 64 - 52); |
| 1122 | + t!(f64x4, f64, 64 - 52); |
| 1123 | + t!(f64x8, f64, 64 - 52); |
| 1124 | + } |
942 | 1125 | }
|
943 | 1126 |
|
944 | 1127 | #[test]
|
@@ -985,13 +1168,16 @@ mod tests {
|
985 | 1168 |
|
986 | 1169 | t!(f32, f32);
|
987 | 1170 | t!(f64, f64);
|
988 |
| - #[cfg(feature="simd_support")] t!(f32x2, f32); |
989 |
| - #[cfg(feature="simd_support")] t!(f32x4, f32); |
990 |
| - #[cfg(feature="simd_support")] t!(f32x8, f32); |
991 |
| - #[cfg(feature="simd_support")] t!(f32x16, f32); |
992 |
| - #[cfg(feature="simd_support")] t!(f64x2, f64); |
993 |
| - #[cfg(feature="simd_support")] t!(f64x4, f64); |
994 |
| - #[cfg(feature="simd_support")] t!(f64x8, f64); |
| 1171 | + #[cfg(feature="simd_support")] |
| 1172 | + { |
| 1173 | + t!(f32x2, f32); |
| 1174 | + t!(f32x4, f32); |
| 1175 | + t!(f32x8, f32); |
| 1176 | + t!(f32x16, f32); |
| 1177 | + t!(f64x2, f64); |
| 1178 | + t!(f64x4, f64); |
| 1179 | + t!(f64x8, f64); |
| 1180 | + } |
995 | 1181 | }
|
996 | 1182 |
|
997 | 1183 |
|
|
0 commit comments