Commit 0a6ac86

Add BlockRng wrapper
1 parent 0eefffb commit 0a6ac86

File tree: 1 file changed

src/impls.rs: 180 additions & 3 deletions
@@ -20,14 +20,13 @@
 //! non-reproducible sources (e.g. `OsRng`) need not bother with it.
 
 // TODO: eventually these should be exported somehow
-#![allow(unused)]
 
 use core::intrinsics::transmute;
 use core::ptr::copy_nonoverlapping;
-use core::slice;
+use core::{fmt, slice};
 use core::cmp::min;
 use core::mem::size_of;
-use RngCore;
+use {RngCore, BlockRngCore, CryptoRng, SeedableRng, Error};
 
 /// Implement `next_u64` via `next_u32`, little-endian order.
 pub fn next_u64_via_u32<R: RngCore + ?Sized>(rng: &mut R) -> u64 {
@@ -167,4 +166,182 @@ pub fn next_u64_via_fill<R: RngCore + ?Sized>(rng: &mut R) -> u64 {
     impl_uint_from_fill!(rng, u64, 8)
 }
 
+/// Wrapper around PRNGs that implement [`BlockRngCore`] to keep a results
+/// buffer and offer the methods from [`RngCore`].
+///
+/// `BlockRng` has optimized methods to read from the output array that the
+/// algorithms of many cryptographic RNGs generate natively. It also handles
+/// the bookkeeping of when to generate a new batch of values.
+///
+/// `next_u32` simply indexes the array. `next_u64` reads two `u32` values at
+/// a time when possible, and handles edge cases such as when only one value
+/// is left. `try_fill_bytes` can even use the [`BlockRngCore`] implementation
+/// to write the results directly to the destination slice. No generated
+/// values are ever thrown away.
+///
+/// Although `BlockRngCore::generate` can return a `Result`, we assume all
+/// PRNGs to be infallible, and the `Result` to have only a signaling
+/// function. Therefore, the error is only reported by `try_fill_bytes`; all
+/// other functions squelch the error.
+///
+/// For easy initialization `BlockRng` also implements [`SeedableRng`].
+///
+/// [`BlockRngCore`]: ../BlockRngCore.t.html
+/// [`RngCore`]: ../RngCore.t.html
+/// [`SeedableRng`]: ../SeedableRng.t.html
+#[derive(Clone)]
+pub struct BlockRng<R: BlockRngCore<u32>> {
+    pub core: R,
+    pub results: R::Results,
+    pub index: usize,
+}
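To make the contract concrete, here is a minimal sketch of a toy core plugged into this wrapper. It assumes the `BlockRng` items from this diff and the crate's `BlockRngCore` and `Error` are in scope; the trait shape used (an associated `Results` buffer that is `Default + AsRef<[u32]>`, plus a `generate` method) is inferred from how `BlockRng` uses it, and `StepCore` with its 16-word block size is invented for illustration, not a real RNG:

// Toy core for illustration only: fills its block with consecutive
// integers so the wrapper's behaviour is easy to predict.
struct StepCore { counter: u32 }

impl BlockRngCore<u32> for StepCore {
    // One 16-word (64-byte) block per `generate` call.
    type Results = [u32; 16];

    fn generate(&mut self, results: &mut Self::Results) -> Result<(), Error> {
        for r in results.iter_mut() {
            *r = self.counter;
            self.counter = self.counter.wrapping_add(1);
        }
        Ok(())
    }
}

// Construct the wrapper the way `from_seed` below does: `index` equal to
// the buffer length marks the buffer as empty, so the first call generates.
let mut rng = BlockRng {
    core: StepCore { counter: 0 },
    results: [0u32; 16],
    index: 16,
};
assert_eq!(rng.next_u32(), 0); // triggers `generate`, reads results[0]
assert_eq!(rng.next_u32(), 1); // just indexes the existing buffer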
+
+// Custom Debug implementation that does not expose the contents of `results`.
+impl<R: BlockRngCore<u32> + fmt::Debug> fmt::Debug for BlockRng<R> {
+    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
+        fmt.debug_struct("BlockRng")
+            .field("core", &self.core)
+            .field("result_len", &self.results.as_ref().len())
+            .field("index", &self.index)
+            .finish()
+    }
+}
+
+impl<R: BlockRngCore<u32>> RngCore for BlockRng<R> {
+    #[inline(always)]
+    fn next_u32(&mut self) -> u32 {
+        if self.index >= self.results.as_ref().len() {
+            let _ = self.core.generate(&mut self.results).unwrap();
+            self.index = 0;
+        }
+
+        let value = self.results.as_ref()[self.index];
+        self.index += 1;
+        value
+    }
+
+    #[inline(always)]
+    fn next_u64(&mut self) -> u64 {
+        let len = self.results.as_ref().len();
+
+        let index = self.index;
+        if index < len-1 {
+            self.index += 2;
+            // Read a u64 from the current index
+            if cfg!(any(target_arch = "x86", target_arch = "x86_64")) {
+                unsafe { *(&self.results.as_ref()[index] as *const u32 as *const u64) }
+            } else {
+                let x = self.results.as_ref()[index] as u64;
+                let y = self.results.as_ref()[index + 1] as u64;
+                (y << 32) | x
+            }
+        } else if index >= len {
+            let _ = self.core.generate(&mut self.results);
+            self.index = 2;
+            let x = self.results.as_ref()[0] as u64;
+            let y = self.results.as_ref()[1] as u64;
+            (y << 32) | x
+        } else {
+            let x = self.results.as_ref()[len-1] as u64;
+            let _ = self.core.generate(&mut self.results);
+            self.index = 1;
+            let y = self.results.as_ref()[0] as u64;
+            (y << 32) | x
+        }
+    }
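A small worked check of the little-endian pairing used in all three branches above: the word at the current index becomes the low half, and its successor the high half. The buffer contents here are arbitrary example values:

let buf: [u32; 2] = [0xDEAD_BEEF, 0x1234_5678];
let x = buf[0] as u64; // low half
let y = buf[1] as u64; // high half
assert_eq!((y << 32) | x, 0x1234_5678_DEAD_BEEFu64);
// The unaligned 8-byte read in the x86/x86_64 branch yields the same
// bytes, because those targets are little-endian.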
+
+    fn fill_bytes(&mut self, dest: &mut [u8]) {
+        let _ = self.try_fill_bytes(dest);
+    }
+
+    // As an optimization we try to write directly into the output buffer.
+    // This is only enabled for platforms where unaligned writes are known to
+    // be safe and fast.
+    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
+    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
+        let mut filled = 0;
+        let mut res = Ok(());
+
+        // Continue filling from the current set of results
+        if self.index < self.results.as_ref().len() {
+            let (consumed_u32, filled_u8) =
+                fill_via_u32_chunks(&self.results.as_ref()[self.index..],
+                                    dest);
+
+            self.index += consumed_u32;
+            filled += filled_u8;
+        }
+
+        let len_remainder =
+            (dest.len() - filled) % (self.results.as_ref().len() * 4);
+        let len_direct = dest.len() - len_remainder;
+
+        while filled < len_direct {
+            let dest_u32: &mut R::Results = unsafe {
+                ::core::mem::transmute(dest[filled..].as_mut_ptr())
+            };
+            let res2 = self.core.generate(dest_u32);
+            if res2.is_err() && res.is_ok() { res = res2 };
+            filled += self.results.as_ref().len() * 4;
+        }
+        self.index = self.results.as_ref().len();
+
+        if len_remainder > 0 {
+            let res2 = self.core.generate(&mut self.results);
+            if res2.is_err() && res.is_ok() { res = res2 };
+
+            let (consumed_u32, _) =
+                fill_via_u32_chunks(self.results.as_ref(),
+                                    &mut dest[filled..]);
+
+            self.index = consumed_u32;
+        }
+        res
+    }
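The `len_direct`/`len_remainder` split above is easiest to see with concrete numbers. A sketch assuming a hypothetical 16-word (64-byte) results buffer, 40 bytes already copied out of the old buffer, and a 150-byte `dest`:

let block_bytes = 16 * 4;                               // 64 bytes per block
let (filled, dest_len) = (40usize, 150usize);
let len_remainder = (dest_len - filled) % block_bytes;  // 110 % 64 = 46
let len_direct = dest_len - len_remainder;              // 150 - 46 = 104
assert_eq!((len_direct, len_remainder), (104, 46));
// One whole block is generated straight into dest[40..104]; the final 46
// bytes come from one more block generated into `self.results`, whose
// unread tail stays buffered for later calls.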
+
+    #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
+    fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), Error> {
+        let mut res = Ok(());
+        let mut read_len = 0;
+        while read_len < dest.len() {
+            if self.index >= self.results.as_ref().len() {
+                let res2 = self.core.generate(&mut self.results);
+                if res2.is_err() && res.is_ok() { res = res2 };
+                self.index = 0;
+            }
+            let (consumed_u32, filled_u8) =
+                fill_via_u32_chunks(&self.results.as_ref()[self.index..],
+                                    &mut dest[read_len..]);
+
+            self.index += consumed_u32;
+            read_len += filled_u8;
+        }
+        res
+    }
+}
+
+impl<R: BlockRngCore<u32> + SeedableRng> SeedableRng for BlockRng<R> {
+    type Seed = R::Seed;
+
+    fn from_seed(seed: Self::Seed) -> Self {
+        let results_empty = R::Results::default();
+        Self {
+            core: R::from_seed(seed),
+            index: results_empty.as_ref().len(), // generate on first use
+            results: results_empty,
+        }
+    }
+
+    fn from_rng<RNG: RngCore>(rng: &mut RNG) -> Result<Self, Error> {
+        let results_empty = R::Results::default();
+        Ok(Self {
+            core: R::from_rng(rng)?,
+            index: results_empty.as_ref().len(), // generate on first use
+            results: results_empty,
+        })
+    }
+}
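Initialization is deliberately lazy: both constructors set `index` to the buffer length, so the first use, not construction, pays for generation. A hypothetical sketch, assuming the toy `StepCore` from the earlier example also implemented `SeedableRng` with a `[u8; 4]` seed:

// Hypothetical: `StepCore: SeedableRng` is an assumption for this sketch.
let mut rng = BlockRng::<StepCore>::from_seed([1, 2, 3, 4]);
assert_eq!(rng.index, rng.results.as_ref().len()); // buffer marked empty
let first = rng.next_u32(); // first use calls core.generate(...)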
+
+impl<R: BlockRngCore<u32> + CryptoRng> CryptoRng for BlockRng<R> {}
+
 // TODO: implement tests for the above

0 commit comments