Skip to content
This repository was archived by the owner on Feb 8, 2025. It is now read-only.

Make ZipfDistribution generic over T and impl Distribution<T> #6

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ maintenance = { status = "passively-maintained" }

[dependencies]
rand = { version = "0.8.0", default-features = false }
num = "0.4"

[dev-dependencies]
rand = "0.8.0"
55 changes: 26 additions & 29 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@
//! use rand::distributions::Distribution;
//!
//! let mut rng = rand::thread_rng();
//! let mut zipf = zipf::ZipfDistribution::new(1000, 1.03).unwrap();
//! let sample = zipf.sample(&mut rng);
//! let mut zipf = zipf::ZipfDistribution::<usize>::new(1000, 1.03).unwrap();
//! let sample: usize = zipf.sample(&mut rng);
//! ```
//!
//! This implementation is effectively a direct port of Apache Commons'
Expand All @@ -31,12 +31,13 @@

#![warn(rust_2018_idioms)]

use num::FromPrimitive;
use rand::Rng;

/// Random number generator that generates Zipf-distributed random numbers using rejection
/// inversion.
#[derive(Clone, Copy)]
pub struct ZipfDistribution {
pub struct ZipfDistribution<T> {
/// Number of elements
num_elements: f64,
/// Exponent parameter of the distribution
Expand All @@ -47,9 +48,10 @@ pub struct ZipfDistribution {
h_integral_num_elements: f64,
/// `2 - hIntegralInverse(hIntegral(2.5) - h(2))`
s: f64,
marker: PhantomData<T>,
}

impl ZipfDistribution {
impl<T> ZipfDistribution<T> {
/// Creates a new [Zipf-distributed](https://en.wikipedia.org/wiki/Zipf's_law)
/// random number generator.
///
Expand All @@ -65,17 +67,14 @@ impl ZipfDistribution {
let z = ZipfDistribution {
num_elements: num_elements as f64,
exponent,
h_integral_x1: ZipfDistribution::h_integral(1.5, exponent) - 1f64,
h_integral_num_elements: ZipfDistribution::h_integral(
num_elements as f64 + 0.5,
exponent,
),
h_integral_x1: Self::h_integral(1.5, exponent) - 1f64,
h_integral_num_elements: Self::h_integral(num_elements as f64 + 0.5, exponent),
s: 2f64
- ZipfDistribution::h_integral_inv(
ZipfDistribution::h_integral(2.5, exponent)
- ZipfDistribution::h(2f64, exponent),
- Self::h_integral_inv(
Self::h_integral(2.5, exponent) - Self::h(2f64, exponent),
exponent,
),
marker: PhantomData,
};

// populate cache
Expand All @@ -84,8 +83,8 @@ impl ZipfDistribution {
}
}

impl ZipfDistribution {
fn next<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
impl<T: FromPrimitive> ZipfDistribution<T> {
fn next<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
// The paper describes an algorithm for exponents larger than 1 (Algorithm ZRI).
//
// The original method uses
Expand All @@ -108,17 +107,16 @@ impl ZipfDistribution {
let hnum = self.h_integral_num_elements;

loop {
use std::cmp;
let u: f64 = hnum + rng.gen::<f64>() * (self.h_integral_x1 - hnum);
// u is uniformly distributed in (h_integral_x1, h_integral_num_elements]

let x: f64 = ZipfDistribution::h_integral_inv(u, self.exponent);
let x: f64 = Self::h_integral_inv(u, self.exponent);

// Limit k to the range [1, num_elements] if it would be outside
// due to numerical inaccuracies.
let k64 = x.max(1.0).min(self.num_elements);
// float -> integer rounds towards zero
let k = cmp::max(1, k64 as usize);
let k = T::from_f64(k64.max(1.0)).unwrap();

// Here, the distribution of k is given by:
//
Expand All @@ -127,8 +125,7 @@ impl ZipfDistribution {
//
// where C = 1 / (h_integral_num_elements - h_integral_x1)
if k64 - x <= self.s
|| u >= ZipfDistribution::h_integral(k64 + 0.5, self.exponent)
- ZipfDistribution::h(k64, self.exponent)
|| u >= Self::h_integral(k64 + 0.5, self.exponent) - Self::h(k64, self.exponent)
{
// Case k = 1:
//
Expand Down Expand Up @@ -173,14 +170,14 @@ impl ZipfDistribution {
}
}

impl rand::distributions::Distribution<usize> for ZipfDistribution {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
impl<T: FromPrimitive> rand::distributions::Distribution<T> for ZipfDistribution<T> {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
self.next(rng)
}
}

use std::fmt;
impl fmt::Debug for ZipfDistribution {
use std::{fmt, marker::PhantomData};
impl<T> fmt::Debug for ZipfDistribution<T> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> {
f.debug_struct("ZipfDistribution")
.field("e", &self.exponent)
Expand All @@ -189,7 +186,7 @@ impl fmt::Debug for ZipfDistribution {
}
}

impl ZipfDistribution {
impl<T> ZipfDistribution<T> {
/// Computes `H(x)`, defined as
///
/// - `(x^(1 - exponent) - 1) / (1 - exponent)`, if `exponent != 1`
Expand Down Expand Up @@ -262,7 +259,7 @@ mod test {
// sample a bunch
let mut buckets = vec![0; N];
for _ in 0..samples {
let sample = zipf.sample(&mut rng);
let sample: usize = zipf.sample(&mut rng);
buckets[sample - 1] += 1;
}

Expand Down Expand Up @@ -323,13 +320,13 @@ mod test {

#[test]
fn debug() {
eprintln!("{:?}", ZipfDistribution::new(100, 1.0).unwrap());
eprintln!("{:?}", ZipfDistribution::<usize>::new(100, 1.0).unwrap());
}

#[test]
fn errs() {
ZipfDistribution::new(0, 1.0).unwrap_err();
ZipfDistribution::new(100, 0.0).unwrap_err();
ZipfDistribution::new(100, -1.0).unwrap_err();
ZipfDistribution::<usize>::new(0, 1.0).unwrap_err();
ZipfDistribution::<usize>::new(100, 0.0).unwrap_err();
ZipfDistribution::<usize>::new(100, -1.0).unwrap_err();
}
}