Skip to content

Commit 1e35a86

Browse files
committed
Implement UniformSampler for Wrapping<T>
1 parent 5953334 commit 1e35a86

File tree

1 file changed

+203
-8
lines changed

1 file changed

+203
-8
lines changed

src/distributions/uniform.rs

Lines changed: 203 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,11 @@
2020
//! and supports extension to user-defined types via a type-specific *back-end*
2121
//! implementation.
2222
//!
23-
//! The types [`UniformInt`], [`UniformFloat`] and [`UniformDuration`] are the
24-
//! back-ends supporting sampling from primitive integer and floating-point
25-
//! ranges as well as from `std::time::Duration`; these types do not normally
26-
//! need to be used directly (unless implementing a derived back-end).
23+
//! The types [`UniformInt`], [`UniformFloat`], [`UniformWrapping`] and
24+
//! [`UniformDuration`] are the back-ends supporting sampling from primitive
25+
//! integer and floating-point ranges as well as from `std::time::Duration`;
26+
//! these types do not normally need to be used directly (unless implementing a
27+
//! derived back-end).
2728
//!
2829
//! # Example usage
2930
//!
@@ -97,6 +98,7 @@
9798
//! [`UniformFloat`]: struct.UniformFloat.html
9899
//! [`UniformDuration`]: struct.UniformDuration.html
99100
101+
use core::num::Wrapping;
100102
#[cfg(feature = "std")]
101103
use std::time::Duration;
102104

@@ -492,6 +494,139 @@ wmul_impl_usize! { u64 }
492494

493495

494496

497+
/// The back-end implementing [`UniformSampler`] for `Wrapping<T>`.
498+
///
499+
/// Unless you are implementing [`UniformSampler`] for your own types, this type
500+
/// should not be used directly, use [`Uniform`] instead.
501+
///
502+
/// The method used is the same as for [`UniformInt`], with one exception: it is
503+
/// not required that `low <= high`. If `low` is greater than `high`, the range
504+
/// to sample from contains both `low..MAX` and `MIN..high`, i.e. the range
505+
/// wraps around.
506+
///
507+
/// [`UniformSampler`]: trait.UniformSampler.html
508+
/// [`Uniform`]: struct.Uniform.html
509+
/// [`UniformInt`]: struct.UniformInt.html
510+
#[derive(Clone, Copy, Debug)]
511+
pub struct UniformWrapping<X> {
512+
low: X,
513+
range: X,
514+
zone: X,
515+
}
516+
517+
macro_rules! uniform_wrapping_int_impl {
518+
($ty:ty, $signed:ty, $unsigned:ident,
519+
$i_large:ident, $u_large:ident) => {
520+
impl SampleUniform for Wrapping<$ty> {
521+
type Sampler = UniformWrapping<$ty>;
522+
}
523+
524+
impl UniformSampler for UniformWrapping<$ty> {
525+
// We play free and fast with unsigned vs signed here
526+
// (when $ty is signed), but that's fine, since the
527+
// contract of this macro is for $ty and $unsigned to be
528+
// "bit-equal", so casting between them is a no-op.
529+
530+
type X = Wrapping<$ty>;
531+
532+
#[inline] // if the range is constant, this helps LLVM to do the
533+
// calculations at compile-time.
534+
fn new(low: Self::X, high: Self::X) -> Self {
535+
UniformSampler::new_inclusive(low, high - Wrapping(1))
536+
}
537+
538+
#[inline] // if the range is constant, this helps LLVM to do the
539+
// calculations at compile-time.
540+
fn new_inclusive(low: Self::X, high: Self::X) -> Self {
541+
let unsigned_max = ::core::$unsigned::MAX;
542+
543+
let range =
544+
high.0.wrapping_sub(low.0).wrapping_add(1) as $unsigned;
545+
let ints_to_reject =
546+
if range > 0 {
547+
(unsigned_max - range + 1) % range
548+
} else {
549+
0
550+
};
551+
let zone = unsigned_max - ints_to_reject;
552+
553+
UniformWrapping {
554+
low: low.0,
555+
// These are really $unsigned values, but store as $ty:
556+
range: range as $ty,
557+
zone: zone as $ty
558+
}
559+
}
560+
561+
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
562+
let range = self.range as $unsigned as $u_large;
563+
if range > 0 {
564+
// Grow `zone` to fit a type of at least 32 bits, by
565+
// sign-extending it (the first bit is always 1, so are all
566+
// the preceding bits of the larger type).
567+
// For types that already have the right size, all the
568+
// casting is a no-op.
569+
let zone = self.zone as $signed as $i_large as $u_large;
570+
loop {
571+
let v: $u_large = rng.gen();
572+
let (hi, lo) = v.wmul(range);
573+
if lo <= zone {
574+
return Wrapping(self.low.wrapping_add(hi as $ty));
575+
}
576+
}
577+
} else {
578+
// Sample from the entire integer range.
579+
rng.gen()
580+
}
581+
}
582+
583+
fn sample_single<R: Rng + ?Sized>(low: Self::X,
584+
high: Self::X,
585+
rng: &mut R) -> Self::X
586+
{
587+
let range = high.0.wrapping_sub(low.0) as $unsigned as $u_large;
588+
let zone =
589+
if ::core::$unsigned::MAX <= ::core::u16::MAX as $unsigned {
590+
// Using a modulus is faster than the approximation for
591+
// i8 and i16. I suppose we trade the cost of one
592+
// modulus for near-perfect branch prediction.
593+
let unsigned_max: $u_large = ::core::$u_large::MAX;
594+
let ints_to_reject = (unsigned_max - range + 1) % range;
595+
unsigned_max - ints_to_reject
596+
} else {
597+
// conservative but fast approximation
598+
range << range.leading_zeros()
599+
};
600+
601+
loop {
602+
let v: $u_large = rng.gen();
603+
let (hi, lo) = v.wmul(range);
604+
if lo <= zone {
605+
return Wrapping(low.0.wrapping_add(hi as $ty));
606+
}
607+
}
608+
}
609+
}
610+
}
611+
}
612+
613+
uniform_wrapping_int_impl! { i8, i8, u8, i32, u32 }
614+
uniform_wrapping_int_impl! { i16, i16, u16, i32, u32 }
615+
uniform_wrapping_int_impl! { i32, i32, u32, i32, u32 }
616+
uniform_wrapping_int_impl! { i64, i64, u64, i64, u64 }
617+
#[cfg(feature = "i128_support")]
618+
uniform_wrapping_int_impl! { i128, i128, u128, u128, u128 }
619+
uniform_wrapping_int_impl! { isize, isize, usize, isize, usize }
620+
uniform_wrapping_int_impl! { u8, i8, u8, i32, u32 }
621+
uniform_wrapping_int_impl! { u16, i16, u16, i32, u32 }
622+
uniform_wrapping_int_impl! { u32, i32, u32, i32, u32 }
623+
uniform_wrapping_int_impl! { u64, i64, u64, i64, u64 }
624+
uniform_wrapping_int_impl! { usize, isize, usize, isize, usize }
625+
#[cfg(feature = "i128_support")]
626+
uniform_wrapping_int_impl! { u128, u128, u128, i128, u128 }
627+
628+
629+
495630
/// The back-end implementing [`UniformSampler`] for floating-point types.
496631
///
497632
/// Unless you are implementing [`UniformSampler`] for your own type, this type
@@ -683,6 +818,7 @@ impl UniformSampler for UniformDuration {
683818
mod tests {
684819
use Rng;
685820
use distributions::uniform::{Uniform, UniformSampler, UniformFloat, SampleUniform};
821+
use core::num::Wrapping;
686822

687823
#[should_panic]
688824
#[test]
@@ -732,10 +868,10 @@ mod tests {
732868
macro_rules! t {
733869
($($ty:ident),*) => {{
734870
$(
735-
let v: &[($ty, $ty)] = &[(0, 10),
736-
(10, 127),
737-
(::core::$ty::MIN, ::core::$ty::MAX)];
738-
for &(low, high) in v.iter() {
871+
let v: &[($ty, $ty)] = &[(0, 10),
872+
(10, 127),
873+
(::core::$ty::MIN, ::core::$ty::MAX)];
874+
for &(low, high) in v.iter() {
739875
let my_uniform = Uniform::new(low, high);
740876
for _ in 0..1000 {
741877
let v: $ty = rng.sample(my_uniform);
@@ -762,6 +898,65 @@ mod tests {
762898
t!(i128, u128)
763899
}
764900

901+
#[test]
902+
fn test_wrapping() {
903+
let mut rng = ::test::rng(251);
904+
macro_rules! t {
905+
($($ty:ident),*) => {{
906+
$(
907+
let v: &[(Wrapping<$ty>, Wrapping<$ty>)] =
908+
&[(Wrapping(0), Wrapping(10)),
909+
(Wrapping(10), Wrapping(127)),
910+
(Wrapping(::core::$ty::MIN), Wrapping(::core::$ty::MAX))];
911+
for &(low, high) in v.iter() {
912+
let my_uniform = Uniform::new(low, high);
913+
for _ in 0..1000 {
914+
let v: Wrapping<$ty> = rng.sample(my_uniform);
915+
assert!(low <= v && v < high);
916+
}
917+
918+
let my_uniform = Uniform::new_inclusive(low, high);
919+
for _ in 0..1000 {
920+
let v: Wrapping<$ty> = rng.sample(my_uniform);
921+
assert!(low <= v && v <= high);
922+
}
923+
924+
for _ in 0..1000 {
925+
let v: Wrapping<$ty> =
926+
Uniform::sample_single(low, high, &mut rng);
927+
assert!(low <= v && v < high);
928+
}
929+
}
930+
931+
// Switch the bounds to test wrapping around
932+
for &(low, high) in v.iter() {
933+
let my_uniform = Uniform::new(high, low);
934+
for _ in 0..1000 {
935+
let v: Wrapping<$ty> = rng.sample(my_uniform);
936+
assert!(v >= high || v < low);
937+
}
938+
939+
let my_uniform = Uniform::new_inclusive(high, low);
940+
for _ in 0..1000 {
941+
let v: Wrapping<$ty> = rng.sample(my_uniform);
942+
assert!(v >= high || v <= low);
943+
}
944+
945+
for _ in 0..1000 {
946+
let v: Wrapping<$ty> =
947+
Uniform::sample_single(high, low, &mut rng);
948+
assert!(v >= high || v < low);
949+
}
950+
}
951+
)*
952+
}}
953+
}
954+
t!(i8, i16, i32, i64, isize,
955+
u8, u16, u32, u64, usize);
956+
#[cfg(feature = "i128_support")]
957+
t!(i128, u128)
958+
}
959+
765960
#[test]
766961
fn test_floats() {
767962
let mut rng = ::test::rng(252);

0 commit comments

Comments
 (0)