
Commit 3c8f92b

Merge pull request #1180 from rust-random/work2
fill_via_chunks: use safe code via chunks_exact_mut on BE
2 parents 6e6b4ce + 34a8f13 commit 3c8f92b

File tree: 1 file changed (+53 −30 lines)

rand_core/src/impls.rs

Lines changed: 53 additions & 30 deletions
@@ -52,36 +52,59 @@ pub fn fill_bytes_via_next<R: RngCore + ?Sized>(rng: &mut R, dest: &mut [u8]) {
     }
 }
 
-macro_rules! fill_via_chunks {
-    ($src:expr, $dst:expr, $ty:ty) => {{
-        const SIZE: usize = core::mem::size_of::<$ty>();
-        let chunk_size_u8 = min($src.len() * SIZE, $dst.len());
-        let chunk_size = (chunk_size_u8 + SIZE - 1) / SIZE;
-
-        // The following can be replaced with safe code, but unfortunately it's
-        // ca. 8% slower.
-        if cfg!(target_endian = "little") {
-            unsafe {
-                core::ptr::copy_nonoverlapping(
-                    $src.as_ptr() as *const u8,
-                    $dst.as_mut_ptr(),
-                    chunk_size_u8);
-            }
-        } else {
-            for (&n, chunk) in $src.iter().zip($dst.chunks_mut(SIZE)) {
-                let tmp = n.to_le();
-                let src_ptr = &tmp as *const $ty as *const u8;
-                unsafe {
-                    core::ptr::copy_nonoverlapping(
-                        src_ptr,
-                        chunk.as_mut_ptr(),
-                        chunk.len());
-                }
-            }
+trait Observable: Copy {
+    type Bytes: AsRef<[u8]>;
+    fn to_le_bytes(self) -> Self::Bytes;
+
+    // Contract: observing self is memory-safe (implies no uninitialised padding)
+    fn as_byte_slice(x: &[Self]) -> &[u8];
+}
+impl Observable for u32 {
+    type Bytes = [u8; 4];
+    fn to_le_bytes(self) -> Self::Bytes {
+        self.to_le_bytes()
+    }
+    fn as_byte_slice(x: &[Self]) -> &[u8] {
+        let ptr = x.as_ptr() as *const u8;
+        let len = x.len() * core::mem::size_of::<Self>();
+        unsafe { core::slice::from_raw_parts(ptr, len) }
+    }
+}
+impl Observable for u64 {
+    type Bytes = [u8; 8];
+    fn to_le_bytes(self) -> Self::Bytes {
+        self.to_le_bytes()
+    }
+    fn as_byte_slice(x: &[Self]) -> &[u8] {
+        let ptr = x.as_ptr() as *const u8;
+        let len = x.len() * core::mem::size_of::<Self>();
+        unsafe { core::slice::from_raw_parts(ptr, len) }
+    }
+}
+
+fn fill_via_chunks<T: Observable>(src: &[T], dest: &mut [u8]) -> (usize, usize) {
+    let size = core::mem::size_of::<T>();
+    let byte_len = min(src.len() * size, dest.len());
+    let num_chunks = (byte_len + size - 1) / size;
+
+    if cfg!(target_endian = "little") {
+        // On LE we can do a simple copy, which is 25-50% faster:
+        dest[..byte_len].copy_from_slice(&T::as_byte_slice(&src[..num_chunks])[..byte_len]);
+    } else {
+        // This code is valid on all arches, but slower than the above:
+        let mut i = 0;
+        let mut iter = dest[..byte_len].chunks_exact_mut(size);
+        while let Some(chunk) = iter.next() {
+            chunk.copy_from_slice(src[i].to_le_bytes().as_ref());
+            i += 1;
         }
+        let chunk = iter.into_remainder();
+        if !chunk.is_empty() {
+            chunk.copy_from_slice(&src[i].to_le_bytes().as_ref()[..chunk.len()]);
+        }
+    }
 
-        (chunk_size, chunk_size_u8)
-    }};
+    (num_chunks, byte_len)
 }
 
 /// Implement `fill_bytes` by reading chunks from the output buffer of a block
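
The safe-code ingredient named in the commit title is `chunks_exact_mut` together with `into_remainder`: the destination is split into full word-sized chunks plus one trailing partial chunk. A standalone sketch of that pattern (illustration only, not taken from the commit):

    // Split an 11-byte buffer into 4-byte chunks plus a 3-byte remainder.
    let mut dest = [0u8; 11];
    let mut iter = dest.chunks_exact_mut(4);
    for chunk in &mut iter {
        chunk.copy_from_slice(&[1, 2, 3, 4]); // full chunks: exactly 4 bytes each
    }
    let rem = iter.into_remainder(); // the 3 bytes that did not fill a chunk
    rem.copy_from_slice(&[9, 9, 9]);
    assert_eq!(dest, [1, 2, 3, 4, 1, 2, 3, 4, 9, 9, 9]);
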
@@ -115,7 +138,7 @@ macro_rules! fill_via_chunks {
 /// }
 /// ```
 pub fn fill_via_u32_chunks(src: &[u32], dest: &mut [u8]) -> (usize, usize) {
-    fill_via_chunks!(src, dest, u32)
+    fill_via_chunks(src, dest)
 }
 
 /// Implement `fill_bytes` by reading chunks from the output buffer of a block
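
As the new implementation makes explicit, the returned pair is (number of source words consumed, number of destination bytes written), and the output is the little-endian encoding of each word, truncated at the end of the buffer. A small check of that behaviour (illustrative, not part of the commit):

    use rand_core::impls::fill_via_u32_chunks;

    // Two 32-bit words but only five destination bytes: one full 4-byte chunk
    // plus a 1-byte remainder (the `into_remainder` path on BE targets, the
    // truncated byte copy on LE targets).
    let src = [0x0102_0304u32, 0x0506_0708];
    let mut dest = [0u8; 5];
    let (consumed_u32, filled_u8) = fill_via_u32_chunks(&src, &mut dest);

    assert_eq!(consumed_u32, 2); // both words were at least partially read
    assert_eq!(filled_u8, 5);    // every destination byte was written
    // Little-endian output regardless of the platform's endianness:
    assert_eq!(dest, [0x04, 0x03, 0x02, 0x01, 0x08]);
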
@@ -129,7 +152,7 @@ pub fn fill_via_u32_chunks(src: &[u32], dest: &mut [u8]) -> (usize, usize) {
 ///
 /// See `fill_via_u32_chunks` for an example.
 pub fn fill_via_u64_chunks(src: &[u64], dest: &mut [u8]) -> (usize, usize) {
-    fill_via_chunks!(src, dest, u64)
+    fill_via_chunks(src, dest)
 }
 
 /// Implement `next_u32` via `fill_bytes`, little-endian order.
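
The doc comments above refer to an example elsewhere in `impls.rs` that this diff does not touch. Roughly, a block-based RNG drives the helper from its `fill_bytes` in a loop like the sketch below; the `results` buffer, `index` cursor and `generate` method are hypothetical names used for illustration, not the crate's actual example:

    // Sketch: `fill_bytes` for a block RNG holding `results: [u32; N]`,
    // a cursor `index` into it, and a `generate()` that refills it.
    fn fill_bytes(&mut self, dest: &mut [u8]) {
        let mut read_len = 0;
        while read_len < dest.len() {
            if self.index >= self.results.len() {
                self.generate(); // produce the next block of output words
                self.index = 0;
            }
            let (consumed_u32, filled_u8) = fill_via_u32_chunks(
                &self.results[self.index..],
                &mut dest[read_len..],
            );
            self.index += consumed_u32; // words fully or partially used
            read_len += filled_u8;      // bytes written into `dest`
        }
    }
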
