Skip to content

Commit 2144ae5

Browse files
committed
UniformUsize: inline sub-impls
1 parent d9debb7 commit 2144ae5

File tree

1 file changed

+70
-29
lines changed

1 file changed

+70
-29
lines changed

src/distr/uniform_int.rs

+70-29
Original file line numberDiff line numberDiff line change
@@ -395,21 +395,19 @@ uniform_simd_int_impl! { (u8, i8), (u16, i16), (u32, i32), (u64, i64) }
395395
/// this implementation will use 32-bit sampling when possible.
396396
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
397397
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
398-
pub struct UniformUsize(UniformUsizeImpl);
398+
pub struct UniformUsize {
399+
low: usize,
400+
range: usize,
401+
thresh: usize,
402+
#[cfg(target_pointer_width = "64")]
403+
mode64: bool,
404+
}
399405

400406
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
401407
impl SampleUniform for usize {
402408
type Sampler = UniformUsize;
403409
}
404410

405-
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
406-
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
407-
pub enum UniformUsizeImpl {
408-
U32(UniformInt<u32>),
409-
#[cfg(target_pointer_width = "64")]
410-
U64(UniformInt<u64>),
411-
}
412-
413411
#[cfg(any(target_pointer_width = "32", target_pointer_width = "64"))]
414412
impl UniformSampler for UniformUsize {
415413
type X = usize;
@@ -427,13 +425,7 @@ impl UniformSampler for UniformUsize {
427425
return Err(Error::EmptyRange);
428426
}
429427

430-
#[cfg(target_pointer_width = "64")]
431-
if high > (u32::MAX as usize) {
432-
return UniformInt::new(low as u64, high as u64)
433-
.map(|ui| UniformUsize(UniformUsizeImpl::U64(ui)));
434-
}
435-
436-
UniformInt::new(low as u32, high as u32).map(|ui| UniformUsize(UniformUsizeImpl::U32(ui)))
428+
UniformSampler::new_inclusive(low, high - 1)
437429
}
438430

439431
#[inline] // if the range is constant, this helps LLVM to do the
@@ -450,21 +442,72 @@ impl UniformSampler for UniformUsize {
450442
}
451443

452444
#[cfg(target_pointer_width = "64")]
453-
if high > (u32::MAX as usize) {
454-
return UniformInt::new_inclusive(low as u64, high as u64)
455-
.map(|ui| UniformUsize(UniformUsizeImpl::U64(ui)));
445+
let mode64 = high > (u32::MAX as usize);
446+
#[cfg(target_pointer_width = "32")]
447+
let mode64 = false;
448+
449+
let (range, thresh);
450+
if cfg!(target_pointer_width = "64") && !mode64 {
451+
let range32 = (high as u32).wrapping_sub(low as u32).wrapping_add(1);
452+
range = range32 as usize;
453+
thresh = if range32 > 0 {
454+
(range32.wrapping_neg() % range32) as usize
455+
} else {
456+
0
457+
};
458+
} else {
459+
range = high.wrapping_sub(low).wrapping_add(1);
460+
thresh = if range > 0 {
461+
range.wrapping_neg() % range
462+
} else {
463+
0
464+
};
456465
}
457466

458-
UniformInt::new_inclusive(low as u32, high as u32)
459-
.map(|ui| UniformUsize(UniformUsizeImpl::U32(ui)))
467+
Ok(UniformUsize {
468+
low,
469+
range,
470+
thresh,
471+
#[cfg(target_pointer_width = "64")]
472+
mode64,
473+
})
460474
}
461475

462476
#[inline]
463477
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
464-
match self.0 {
465-
UniformUsizeImpl::U32(uu) => uu.sample(rng) as usize,
466-
#[cfg(target_pointer_width = "64")]
467-
UniformUsizeImpl::U64(uu) => uu.sample(rng) as usize,
478+
#[cfg(target_pointer_width = "32")]
479+
let mode32 = true;
480+
#[cfg(target_pointer_width = "64")]
481+
let mode32 = !self.mode64;
482+
483+
if mode32 {
484+
let range = self.range as u32;
485+
if range == 0 {
486+
return rng.random::<u32>() as usize;
487+
}
488+
489+
let thresh = self.thresh as u32;
490+
let hi = loop {
491+
let (hi, lo) = rng.random::<u32>().wmul(range);
492+
if lo >= thresh {
493+
break hi;
494+
}
495+
};
496+
self.low.wrapping_add(hi as usize)
497+
} else {
498+
let range = self.range as u64;
499+
if range == 0 {
500+
return rng.random::<u64>() as usize;
501+
}
502+
503+
let thresh = self.thresh as u64;
504+
let hi = loop {
505+
let (hi, lo) = rng.random::<u64>().wmul(range);
506+
if lo >= thresh {
507+
break hi;
508+
}
509+
};
510+
self.low.wrapping_add(hi as usize)
468511
}
469512
}
470513

@@ -484,8 +527,7 @@ impl UniformSampler for UniformUsize {
484527
return Err(Error::EmptyRange);
485528
}
486529

487-
#[cfg(target_pointer_width = "64")]
488-
if high > (u32::MAX as usize) {
530+
if cfg!(target_pointer_width = "64") && high > (u32::MAX as usize) {
489531
return UniformInt::<u64>::sample_single(low as u64, high as u64, rng)
490532
.map(|x| x as usize);
491533
}
@@ -509,8 +551,7 @@ impl UniformSampler for UniformUsize {
509551
return Err(Error::EmptyRange);
510552
}
511553

512-
#[cfg(target_pointer_width = "64")]
513-
if high > (u32::MAX as usize) {
554+
if cfg!(target_pointer_width = "64") && high > (u32::MAX as usize) {
514555
return UniformInt::<u64>::sample_single_inclusive(low as u64, high as u64, rng)
515556
.map(|x| x as usize);
516557
}

0 commit comments

Comments
 (0)