Skip to content

Commit ca270f3

Browse files
committed
AliasMethod weighted index: use u32 internally
Primarily for value stability, also slight performance boost.
1 parent 4a375f6 commit ca270f3

File tree

2 files changed

+58
-48
lines changed

2 files changed

+58
-48
lines changed

src/distributions/weighted/alias_method.rs

Lines changed: 54 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,10 @@ use Rng;
2525
/// Given that `n` is the number of items in the vector used to create an
2626
/// [`WeightedIndex<W>`], [`WeightedIndex<W>`] will require `O(n)` amount of
2727
/// memory. More specifically it takes up some constant amount of memory plus
28-
/// the vector used to create it and a [`Vec<usize>`] with capacity `n`.
28+
/// the vector used to create it and a [`Vec<u32>`] with capacity `n`.
2929
///
3030
/// Time complexity for the creation of a [`WeightedIndex<W>`] is `O(n)`.
31-
/// Sampling is `O(1)`, it makes a call to [`Uniform<usize>::sample`] and a call
31+
/// Sampling is `O(1)`: it makes a call to [`Uniform<u32>::sample`] and a call
3232
/// to [`Uniform<W>::sample`].
3333
///
3434
/// # Example
@@ -56,13 +56,13 @@ use Rng;
5656
///
5757
/// [`WeightedIndex<W>`]: crate::distributions::weighted::alias_method::WeightedIndex
5858
/// [`Weight`]: crate::distributions::weighted::alias_method::Weight
59-
/// [`Vec<usize>`]: Vec
60-
/// [`Uniform<usize>::sample`]: Distribution::sample
59+
/// [`Vec<u32>`]: Vec
60+
/// [`Uniform<u32>::sample`]: Distribution::sample
6161
/// [`Uniform<W>::sample`]: Distribution::sample
6262
pub struct WeightedIndex<W: Weight> {
63-
aliases: Vec<usize>,
63+
aliases: Vec<u32>,
6464
no_alias_odds: Vec<W>,
65-
uniform_index: Uniform<usize>,
65+
uniform_index: Uniform<u32>,
6666
uniform_within_weight_sum: Uniform<W>,
6767
}
6868

@@ -71,16 +71,20 @@ impl<W: Weight> WeightedIndex<W> {
7171
///
7272
/// Returns an error if:
7373
/// - The vector is empty.
74+
/// - The vector is longer than `u32::MAX`.
7475
/// - For any weight `w`: `w < 0` or `w > max` where `max = W::MAX /
7576
/// weights.len()`.
7677
/// - The sum of weights is zero.
7778
pub fn new(weights: Vec<W>) -> Result<Self, WeightedError> {
7879
let n = weights.len();
7980
if n == 0 {
8081
return Err(WeightedError::NoItem);
82+
} else if n > ::core::u32::MAX as usize {
83+
return Err(WeightedError::TooMany);
8184
}
85+
let n = n as u32;
8286

83-
let max_weight_size = W::try_from_usize_lossy(n)
87+
let max_weight_size = W::try_from_u32_lossy(n)
8488
.map(|n| W::MAX / n)
8589
.unwrap_or(W::ZERO);
8690
if !weights
@@ -103,7 +107,7 @@ impl<W: Weight> WeightedIndex<W> {
103107
}
104108

105109
// `weight_sum` would have been zero if `try_from_u32_lossy` causes an error here.
106-
let n_converted = W::try_from_usize_lossy(n).unwrap();
110+
let n_converted = W::try_from_u32_lossy(n).unwrap();
107111

108112
let mut no_alias_odds = weights;
109113
for odds in no_alias_odds.iter_mut() {
@@ -119,52 +123,52 @@ impl<W: Weight> WeightedIndex<W> {
119123
/// be ensured that a single index is only ever in one of them at the
120124
/// same time.
121125
struct Aliases {
122-
aliases: Vec<usize>,
123-
smalls_head: usize,
124-
bigs_head: usize,
126+
aliases: Vec<u32>,
127+
smalls_head: u32,
128+
bigs_head: u32,
125129
}
126130

127131
impl Aliases {
128-
fn new(size: usize) -> Self {
132+
fn new(size: u32) -> Self {
129133
Aliases {
130-
aliases: vec![0; size],
131-
smalls_head: ::core::usize::MAX,
132-
bigs_head: ::core::usize::MAX,
134+
aliases: vec![0; size as usize],
135+
smalls_head: ::core::u32::MAX,
136+
bigs_head: ::core::u32::MAX,
133137
}
134138
}
135139

136-
fn push_small(&mut self, idx: usize) {
137-
self.aliases[idx] = self.smalls_head;
140+
fn push_small(&mut self, idx: u32) {
141+
self.aliases[idx as usize] = self.smalls_head;
138142
self.smalls_head = idx;
139143
}
140144

141-
fn push_big(&mut self, idx: usize) {
142-
self.aliases[idx] = self.bigs_head;
145+
fn push_big(&mut self, idx: u32) {
146+
self.aliases[idx as usize] = self.bigs_head;
143147
self.bigs_head = idx;
144148
}
145149

146-
fn pop_small(&mut self) -> usize {
150+
fn pop_small(&mut self) -> u32 {
147151
let popped = self.smalls_head;
148-
self.smalls_head = self.aliases[popped];
152+
self.smalls_head = self.aliases[popped as usize];
149153
popped
150154
}
151155

152-
fn pop_big(&mut self) -> usize {
156+
fn pop_big(&mut self) -> u32 {
153157
let popped = self.bigs_head;
154-
self.bigs_head = self.aliases[popped];
158+
self.bigs_head = self.aliases[popped as usize];
155159
popped
156160
}
157161

158162
fn smalls_is_empty(&self) -> bool {
159-
self.smalls_head == ::core::usize::MAX
163+
self.smalls_head == ::core::u32::MAX
160164
}
161165

162166
fn bigs_is_empty(&self) -> bool {
163-
self.bigs_head == ::core::usize::MAX
167+
self.bigs_head == ::core::u32::MAX
164168
}
165169

166-
fn set_alias(&mut self, idx: usize, alias: usize) {
167-
self.aliases[idx] = alias;
170+
fn set_alias(&mut self, idx: u32, alias: u32) {
171+
self.aliases[idx as usize] = alias;
168172
}
169173
}
170174

@@ -173,9 +177,9 @@ impl<W: Weight> WeightedIndex<W> {
173177
// Split indices into those with small weights and those with big weights.
174178
for (index, &odds) in no_alias_odds.iter().enumerate() {
175179
if odds < weight_sum {
176-
aliases.push_small(index);
180+
aliases.push_small(index as u32);
177181
} else {
178-
aliases.push_big(index);
182+
aliases.push_big(index as u32);
179183
}
180184
}
181185

@@ -186,9 +190,11 @@ impl<W: Weight> WeightedIndex<W> {
186190
let b = aliases.pop_big();
187191

188192
aliases.set_alias(s, b);
189-
no_alias_odds[b] = no_alias_odds[b] - weight_sum + no_alias_odds[s];
193+
no_alias_odds[b as usize] = no_alias_odds[b as usize]
194+
- weight_sum
195+
+ no_alias_odds[s as usize];
190196

191-
if no_alias_odds[b] < weight_sum {
197+
if no_alias_odds[b as usize] < weight_sum {
192198
aliases.push_small(b);
193199
} else {
194200
aliases.push_big(b);
@@ -198,10 +204,10 @@ impl<W: Weight> WeightedIndex<W> {
198204
// The remaining indices should have no-alias odds of about 100%. They deviate only due to
199205
// limited numeric accuracy; otherwise they would be exactly 100%.
200206
while !aliases.smalls_is_empty() {
201-
no_alias_odds[aliases.pop_small()] = weight_sum;
207+
no_alias_odds[aliases.pop_small() as usize] = weight_sum;
202208
}
203209
while !aliases.bigs_is_empty() {
204-
no_alias_odds[aliases.pop_big()] = weight_sum;
210+
no_alias_odds[aliases.pop_big() as usize] = weight_sum;
205211
}
206212

207213
// Prepare distributions for sampling. Creating them beforehand improves
@@ -221,10 +227,10 @@ impl<W: Weight> WeightedIndex<W> {
221227
impl<W: Weight> Distribution<usize> for WeightedIndex<W> {
222228
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
223229
let candidate = rng.sample(self.uniform_index);
224-
if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate] {
225-
candidate
230+
if rng.sample(&self.uniform_within_weight_sum) < self.no_alias_odds[candidate as usize] {
231+
candidate as usize
226232
} else {
227-
self.aliases[candidate]
233+
self.aliases[candidate as usize] as usize
228234
}
229235
}
230236
}
@@ -282,10 +288,10 @@ pub trait Weight:
282288
/// Element of `Self` equivalent to 0.
283289
const ZERO: Self;
284290

285-
/// Produce an instance of `Self` from a `usize` value, or return `None` if
291+
/// Produce an instance of `Self` from a `u32` value, or return `None` if
286292
/// out of range. Loss of precision (where `Self` is a floating point type)
287293
/// is acceptable.
288-
fn try_from_usize_lossy(n: usize) -> Option<Self>;
294+
fn try_from_u32_lossy(n: u32) -> Option<Self>;
289295

290296
/// Sums all values in slice `values`.
291297
fn sum(values: &[Self]) -> Self {
@@ -299,7 +305,7 @@ macro_rules! impl_weight_for_float {
299305
const MAX: Self = ::core::$T::MAX;
300306
const ZERO: Self = 0.0;
301307

302-
fn try_from_usize_lossy(n: usize) -> Option<Self> {
308+
fn try_from_u32_lossy(n: u32) -> Option<Self> {
303309
Some(n as $T)
304310
}
305311

@@ -328,9 +334,9 @@ macro_rules! impl_weight_for_int {
328334
const MAX: Self = ::core::$T::MAX;
329335
const ZERO: Self = 0;
330336

331-
fn try_from_usize_lossy(n: usize) -> Option<Self> {
337+
fn try_from_u32_lossy(n: u32) -> Option<Self> {
332338
let n_converted = n as Self;
333-
if n_converted >= Self::ZERO && n_converted as usize == n {
339+
if n_converted >= Self::ZERO && n_converted as u32 == n {
334340
Some(n_converted)
335341
} else {
336342
None
@@ -439,21 +445,21 @@ mod test {
439445
where
440446
WeightedIndex<W>: fmt::Debug,
441447
{
442-
const NUM_WEIGHTS: usize = 10;
443-
const ZERO_WEIGHT_INDEX: usize = 3;
448+
const NUM_WEIGHTS: u32 = 10;
449+
const ZERO_WEIGHT_INDEX: u32 = 3;
444450
const NUM_SAMPLES: u32 = 15000;
445451
let mut rng = ::test::rng(0x9c9fa0b0580a7031);
446452

447453
let weights = {
448-
let mut weights = Vec::with_capacity(NUM_WEIGHTS);
454+
let mut weights = Vec::with_capacity(NUM_WEIGHTS as usize);
449455
let random_weight_distribution = ::distributions::Uniform::new_inclusive(
450456
W::ZERO,
451-
W::MAX / W::try_from_usize_lossy(NUM_WEIGHTS).unwrap(),
457+
W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(),
452458
);
453459
for _ in 0..NUM_WEIGHTS {
454460
weights.push(rng.sample(&random_weight_distribution));
455461
}
456-
weights[ZERO_WEIGHT_INDEX] = W::ZERO;
462+
weights[ZERO_WEIGHT_INDEX as usize] = W::ZERO;
457463
weights
458464
};
459465
let weight_sum = weights.iter().map(|w| *w).sum::<W>();
@@ -463,12 +469,12 @@ mod test {
463469
.collect::<Vec<f64>>();
464470
let weight_distribution = WeightedIndex::new(weights).unwrap();
465471

466-
let mut counts = vec![0_usize; NUM_WEIGHTS];
472+
let mut counts = vec![0; NUM_WEIGHTS as usize];
467473
for _ in 0..NUM_SAMPLES {
468474
counts[rng.sample(&weight_distribution)] += 1;
469475
}
470476

471-
assert_eq!(counts[ZERO_WEIGHT_INDEX], 0);
477+
assert_eq!(counts[ZERO_WEIGHT_INDEX as usize], 0);
472478
for (count, expected_count) in counts.into_iter().zip(expected_counts) {
473479
let difference = (count as f64 - expected_count).abs();
474480
let max_allowed_difference = NUM_SAMPLES as f64 / NUM_WEIGHTS as f64 * 0.1;

src/distributions/weighted/mod.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,9 @@ pub enum WeightedError {
208208

209209
/// All items in the provided weight collection are zero.
210210
AllWeightsZero,
211+
212+
/// Too many weights are provided (length greater than `u32::MAX`).
213+
TooMany,
211214
}
212215

213216
impl WeightedError {
@@ -216,6 +219,7 @@ impl WeightedError {
216219
WeightedError::NoItem => "No weights provided.",
217220
WeightedError::InvalidWeight => "A weight is invalid.",
218221
WeightedError::AllWeightsZero => "All weights are zero.",
222+
WeightedError::TooMany => "Too many weights (hit u32::MAX)",
219223
}
220224
}
221225
}

0 commit comments

Comments
 (0)