Skip to content

Commit 605476c

Browse files
authored
Portability fixes (#1469)
- Fix portability of `choose_multiple_array` - Fix portability of `rand::distributions::Slice`
1 parent f3aab23 commit 605476c

File tree

8 files changed

+67
-18
lines changed

8 files changed

+67
-18
lines changed

CHANGELOG.md

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.
1010

1111
## [Unreleased]
1212
- Add `rand::distributions::WeightedIndex::{weight, weights, total_weight}` (#1420)
13-
- Add `IndexedRandom::choose_multiple_array`, `index::sample_array` (#1453)
13+
- Add `IndexedRandom::choose_multiple_array`, `index::sample_array` (#1453, #1469)
1414
- Bump the MSRV to 1.61.0
1515
- Rename `Rng::gen` to `Rng::random` to avoid conflict with the new `gen` keyword in Rust 2024 (#1435)
1616
- Move all benchmarks to new `benches` crate (#1439)
1717
- Annotate panicking methods with `#[track_caller]` (#1442, #1447)
1818
- Enable feature `small_rng` by default (#1455)
1919
- Allow `UniformFloat::new` samples and `UniformFloat::sample_single` to yield `high` (#1462)
20+
- Fix portability of `rand::distributions::Slice` (#1469)
2021

2122
## [0.9.0-alpha.1] - 2024-03-18
2223
- Add the `Slice::num_choices` method to the Slice distribution (#1402)

rand_distr/src/dirichlet.rs

+1-8
Original file line numberDiff line numberDiff line change
@@ -333,20 +333,13 @@ where
333333
#[cfg(test)]
334334
mod test {
335335
use super::*;
336-
use alloc::vec::Vec;
337336

338337
#[test]
339338
fn test_dirichlet() {
340339
let d = Dirichlet::new([1.0, 2.0, 3.0]).unwrap();
341340
let mut rng = crate::test::rng(221);
342341
let samples = d.sample(&mut rng);
343-
let _: Vec<f64> = samples
344-
.into_iter()
345-
.map(|x| {
346-
assert!(x > 0.0);
347-
x
348-
})
349-
.collect();
342+
assert!(samples.into_iter().all(|x: f64| x > 0.0));
350343
}
351344

352345
#[test]

src/distributions/slice.rs

+46-2
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,39 @@
99
use core::num::NonZeroUsize;
1010

1111
use crate::distributions::{Distribution, Uniform};
12+
use crate::Rng;
1213
#[cfg(feature = "alloc")]
1314
use alloc::string::String;
1415

16+
#[cfg(not(any(target_pointer_width = "32", target_pointer_width = "64")))]
17+
compile_error!("unsupported pointer width");
18+
19+
#[derive(Debug, Clone, Copy)]
20+
enum UniformUsize {
21+
U32(Uniform<u32>),
22+
#[cfg(target_pointer_width = "64")]
23+
U64(Uniform<u64>),
24+
}
25+
26+
impl UniformUsize {
27+
pub fn new(ubound: usize) -> Result<Self, super::uniform::Error> {
28+
#[cfg(target_pointer_width = "64")]
29+
if ubound > (u32::MAX as usize) {
30+
return Uniform::new(0, ubound as u64).map(UniformUsize::U64);
31+
}
32+
33+
Uniform::new(0, ubound as u32).map(UniformUsize::U32)
34+
}
35+
36+
pub fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
37+
match self {
38+
UniformUsize::U32(uu) => uu.sample(rng) as usize,
39+
#[cfg(target_pointer_width = "64")]
40+
UniformUsize::U64(uu) => uu.sample(rng) as usize,
41+
}
42+
}
43+
}
44+
1545
/// A distribution to sample items uniformly from a slice.
1646
///
1747
/// [`Slice::new`] constructs a distribution referencing a slice and uniformly
@@ -68,7 +98,7 @@ use alloc::string::String;
6898
#[derive(Debug, Clone, Copy)]
6999
pub struct Slice<'a, T> {
70100
slice: &'a [T],
71-
range: Uniform<usize>,
101+
range: UniformUsize,
72102
num_choices: NonZeroUsize,
73103
}
74104

@@ -80,7 +110,7 @@ impl<'a, T> Slice<'a, T> {
80110

81111
Ok(Self {
82112
slice,
83-
range: Uniform::new(0, num_choices.get()).unwrap(),
113+
range: UniformUsize::new(num_choices.get()).unwrap(),
84114
num_choices,
85115
})
86116
}
@@ -161,3 +191,17 @@ impl<'a> super::DistString for Slice<'a, char> {
161191
}
162192
}
163193
}
194+
195+
#[cfg(test)]
196+
mod test {
197+
use super::*;
198+
use core::iter;
199+
200+
#[test]
201+
fn value_stability() {
202+
let rng = crate::test::rng(651);
203+
let slice = Slice::new(b"escaped emus explore extensively").unwrap();
204+
let expected = b"eaxee";
205+
assert!(iter::zip(slice.sample_iter(rng), expected).all(|(a, b)| a == b));
206+
}
207+
}

src/distributions/utils.rs

+2
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,9 @@ pub(crate) trait FloatSIMDScalarUtils: FloatSIMDUtils {
241241

242242
/// Implement functions on f32/f64 to give them APIs similar to SIMD types
243243
pub(crate) trait FloatAsSIMD: Sized {
244+
#[cfg(test)]
244245
const LEN: usize = 1;
246+
245247
#[inline(always)]
246248
fn splat(scalar: Self) -> Self {
247249
scalar

src/rngs/mod.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,10 @@ pub mod mock; // Public so we don't export `StepRng` directly, making it a bit
8686

8787
#[cfg(feature = "small_rng")]
8888
mod small;
89-
#[cfg(all(feature = "small_rng", not(target_pointer_width = "64")))]
89+
#[cfg(all(
90+
feature = "small_rng",
91+
any(target_pointer_width = "32", target_pointer_width = "16")
92+
))]
9093
mod xoshiro128plusplus;
9194
#[cfg(all(feature = "small_rng", target_pointer_width = "64"))]
9295
mod xoshiro256plusplus;

src/rngs/small.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@
1010
1111
use rand_core::{RngCore, SeedableRng};
1212

13+
#[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))]
14+
type Rng = super::xoshiro128plusplus::Xoshiro128PlusPlus;
1315
#[cfg(target_pointer_width = "64")]
1416
type Rng = super::xoshiro256plusplus::Xoshiro256PlusPlus;
15-
#[cfg(not(target_pointer_width = "64"))]
16-
type Rng = super::xoshiro128plusplus::Xoshiro128PlusPlus;
1717

1818
/// A small-state, fast, non-crypto, non-portable PRNG
1919
///

src/seq/index.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
// except according to those terms.
88

99
//! Low-level API for sampling indices
10+
use super::gen_index;
1011
#[cfg(feature = "alloc")]
1112
use alloc::vec::{self, Vec};
1213
use core::slice;
@@ -288,7 +289,7 @@ where
288289
// Floyd's algorithm
289290
let mut indices = [0; N];
290291
for (i, j) in (len - N..len).enumerate() {
291-
let t = rng.gen_range(0..=j);
292+
let t = gen_index(rng, j + 1);
292293
if let Some(pos) = indices[0..i].iter().position(|&x| x == t) {
293294
indices[pos] = j;
294295
}

src/seq/slice.rs

+8-3
Original file line numberDiff line numberDiff line change
@@ -495,19 +495,24 @@ mod test {
495495
assert_eq!(chars.choose(&mut r), Some(&'l'));
496496
assert_eq!(nums.choose_mut(&mut r), Some(&mut 3));
497497

498+
assert_eq!(
499+
&chars.choose_multiple_array(&mut r),
500+
&Some(['f', 'i', 'd', 'b', 'c', 'm', 'j', 'k'])
501+
);
502+
498503
#[cfg(feature = "alloc")]
499504
assert_eq!(
500505
&chars
501506
.choose_multiple(&mut r, 8)
502507
.cloned()
503508
.collect::<Vec<char>>(),
504-
&['f', 'i', 'd', 'b', 'c', 'm', 'j', 'k']
509+
&['h', 'm', 'd', 'b', 'c', 'e', 'n', 'f']
505510
);
506511

507512
#[cfg(feature = "alloc")]
508-
assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'l'));
513+
assert_eq!(chars.choose_weighted(&mut r, |_| 1), Ok(&'i'));
509514
#[cfg(feature = "alloc")]
510-
assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 8));
515+
assert_eq!(nums.choose_weighted_mut(&mut r, |_| 1), Ok(&mut 2));
511516

512517
let mut r = crate::test::rng(414);
513518
nums.shuffle(&mut r);

0 commit comments

Comments
 (0)