|
17 | 17 |
|
18 | 18 | //! Utils for working with packed bit masks
|
19 | 19 |
|
20 |
| -use crate::bit_chunk_iterator::BitChunks; |
21 |
| -use crate::bit_util::{ceil, get_bit, set_bit}; |
| 20 | +use crate::bit_util::ceil; |
22 | 21 |
|
23 | 22 | /// Sets all bits on `write_data` in the range `[offset_write..offset_write+len]` to be equal to the
|
24 | 23 | /// bits in `data` in the range `[offset_read..offset_read+len]`
|
25 | 24 | /// returns the number of `0` bits `data[offset_read..offset_read+len]`
|
| 25 | +/// `offset_write`, `offset_read`, and `len` are in terms of bits |
26 | 26 | pub fn set_bits(
|
27 | 27 | write_data: &mut [u8],
|
28 | 28 | data: &[u8],
|
29 | 29 | offset_write: usize,
|
30 | 30 | offset_read: usize,
|
31 | 31 | len: usize,
|
32 | 32 | ) -> usize {
|
| 33 | + assert!(offset_write + len <= write_data.len() * 8); |
| 34 | + assert!(offset_read + len <= data.len() * 8); |
33 | 35 | let mut null_count = 0;
|
34 |
| - |
35 |
| - let mut bits_to_align = offset_write % 8; |
36 |
| - if bits_to_align > 0 { |
37 |
| - bits_to_align = std::cmp::min(len, 8 - bits_to_align); |
| 36 | + let mut acc = 0; |
| 37 | + while len > acc { |
| 38 | + // SAFETY: the arguments to `set_upto_64bits` are within the valid range because |
| 39 | + // (offset_write + acc) + (len - acc) == offset_write + len <= write_data.len() * 8 |
| 40 | + // (offset_read + acc) + (len - acc) == offset_read + len <= data.len() * 8 |
| 41 | + let (n, len_set) = unsafe { |
| 42 | + set_upto_64bits( |
| 43 | + write_data, |
| 44 | + data, |
| 45 | + offset_write + acc, |
| 46 | + offset_read + acc, |
| 47 | + len - acc, |
| 48 | + ) |
| 49 | + }; |
| 50 | + null_count += n; |
| 51 | + acc += len_set; |
38 | 52 | }
|
39 |
| - let mut write_byte_index = ceil(offset_write + bits_to_align, 8); |
40 |
| - |
41 |
| - // Set full bytes provided by bit chunk iterator (which iterates in 64 bits at a time) |
42 |
| - let chunks = BitChunks::new(data, offset_read + bits_to_align, len - bits_to_align); |
43 |
| - chunks.iter().for_each(|chunk| { |
44 |
| - null_count += chunk.count_zeros(); |
45 |
| - write_data[write_byte_index..write_byte_index + 8].copy_from_slice(&chunk.to_le_bytes()); |
46 |
| - write_byte_index += 8; |
47 |
| - }); |
48 |
| - |
49 |
| - // Set individual bits both to align write_data to a byte offset and the remainder bits not covered by the bit chunk iterator |
50 |
| - let remainder_offset = len - chunks.remainder_len(); |
51 |
| - (0..bits_to_align) |
52 |
| - .chain(remainder_offset..len) |
53 |
| - .for_each(|i| { |
54 |
| - if get_bit(data, offset_read + i) { |
55 |
| - set_bit(write_data, offset_write + i); |
| 53 | + |
| 54 | + null_count |
| 55 | +} |
| 56 | + |
| 57 | +/// Similar to `set_bits`, but sets at most 64 bits; the actual number of bits set may vary. |
| 58 | +/// Returns a pair of the number of `0` bits and the number of bits set |
| 59 | +/// |
| 60 | +/// # Safety |
| 61 | +/// The caller must ensure all arguments are within the valid range. |
| 62 | +#[inline] |
| 63 | +unsafe fn set_upto_64bits( |
| 64 | + write_data: &mut [u8], |
| 65 | + data: &[u8], |
| 66 | + offset_write: usize, |
| 67 | + offset_read: usize, |
| 68 | + len: usize, |
| 69 | +) -> (usize, usize) { |
| 70 | + let read_byte = offset_read / 8; |
| 71 | + let read_shift = offset_read % 8; |
| 72 | + let write_byte = offset_write / 8; |
| 73 | + let write_shift = offset_write % 8; |
| 74 | + |
| 75 | + if len >= 64 { |
| 76 | + let chunk = unsafe { (data.as_ptr().add(read_byte) as *const u64).read_unaligned() }; |
| 77 | + if read_shift == 0 { |
| 78 | + if write_shift == 0 { |
| 79 | + // no shifting necessary |
| 80 | + let len = 64; |
| 81 | + let null_count = chunk.count_zeros() as usize; |
| 82 | + unsafe { write_u64_bytes(write_data, write_byte, chunk) }; |
| 83 | + (null_count, len) |
56 | 84 | } else {
|
57 |
| - null_count += 1; |
| 85 | + // only write shifting necessary |
| 86 | + let len = 64 - write_shift; |
| 87 | + let chunk = chunk << write_shift; |
| 88 | + let null_count = len - chunk.count_ones() as usize; |
| 89 | + unsafe { or_write_u64_bytes(write_data, write_byte, chunk) }; |
| 90 | + (null_count, len) |
58 | 91 | }
|
59 |
| - }); |
| 92 | + } else if write_shift == 0 { |
| 93 | + // only read shifting necessary |
| 94 | + let len = 64 - 8; // 56 bits so the next set_upto_64bits call will see write_shift == 0 |
| 95 | + let chunk = (chunk >> read_shift) & 0x00FFFFFFFFFFFFFF; // 56 bits mask |
| 96 | + let null_count = len - chunk.count_ones() as usize; |
| 97 | + unsafe { write_u64_bytes(write_data, write_byte, chunk) }; |
| 98 | + (null_count, len) |
| 99 | + } else { |
| 100 | + let len = 64 - std::cmp::max(read_shift, write_shift); |
| 101 | + let chunk = (chunk >> read_shift) << write_shift; |
| 102 | + let null_count = len - chunk.count_ones() as usize; |
| 103 | + unsafe { or_write_u64_bytes(write_data, write_byte, chunk) }; |
| 104 | + (null_count, len) |
| 105 | + } |
| 106 | + } else if len == 1 { |
| 107 | + let byte_chunk = (unsafe { data.get_unchecked(read_byte) } >> read_shift) & 1; |
| 108 | + unsafe { *write_data.get_unchecked_mut(write_byte) |= byte_chunk << write_shift }; |
| 109 | + ((byte_chunk ^ 1) as usize, 1) |
| 110 | + } else { |
| 111 | + let len = std::cmp::min(len, 64 - std::cmp::max(read_shift, write_shift)); |
| 112 | + let bytes = ceil(len + read_shift, 8); |
| 113 | + // SAFETY: the args of `read_bytes_to_u64` are valid as read_byte + bytes <= data.len() |
| 114 | + let chunk = unsafe { read_bytes_to_u64(data, read_byte, bytes) }; |
| 115 | + let mask = u64::MAX >> (64 - len); |
| 116 | + let chunk = (chunk >> read_shift) & mask; // masking to read `len` bits only |
| 117 | + let chunk = chunk << write_shift; // shifting back to align with `write_data` |
| 118 | + let null_count = len - chunk.count_ones() as usize; |
| 119 | + let bytes = ceil(len + write_shift, 8); |
| 120 | + for (i, c) in chunk.to_le_bytes().iter().enumerate().take(bytes) { |
| 121 | + unsafe { *write_data.get_unchecked_mut(write_byte + i) |= c }; |
| 122 | + } |
| 123 | + (null_count, len) |
| 124 | + } |
| 125 | +} |
60 | 126 |
|
61 |
| - null_count as usize |
| 127 | +/// # Safety |
| 128 | +/// The caller must ensure all arguments are within the valid range. |
| 129 | +#[inline] |
| 130 | +unsafe fn read_bytes_to_u64(data: &[u8], offset: usize, count: usize) -> u64 { |
| 131 | + debug_assert!(count <= 8); |
| 132 | + let mut tmp = std::mem::MaybeUninit::<u64>::new(0); |
| 133 | + let src = data.as_ptr().add(offset); |
| 134 | + unsafe { |
| 135 | + std::ptr::copy_nonoverlapping(src, tmp.as_mut_ptr() as *mut u8, count); |
| 136 | + tmp.assume_init() |
| 137 | + } |
| 138 | +} |
| 139 | + |
| 140 | +/// # Safety |
| 141 | +/// The caller must ensure `data` has `offset..(offset + 8)` range |
| 142 | +#[inline] |
| 143 | +unsafe fn write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) { |
| 144 | + let ptr = data.as_mut_ptr().add(offset) as *mut u64; |
| 145 | + ptr.write_unaligned(chunk); |
| 146 | +} |
| 147 | + |
| 148 | +/// Similar to `write_u64_bytes`, but this method ORs the existing byte at `offset` in `data` |
| 149 | +/// into `chunk` before writing, instead of overwriting it |
| 150 | +/// |
| 151 | +/// # Safety |
| 152 | +/// The caller must ensure `data` has `offset..(offset + 8)` range |
| 153 | +#[inline] |
| 154 | +unsafe fn or_write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) { |
| 155 | + let ptr = data.as_mut_ptr().add(offset); |
| 156 | + let chunk = chunk | (*ptr) as u64; |
| 157 | + (ptr as *mut u64).write_unaligned(chunk); |
62 | 158 | }
|
63 | 159 |
|
64 | 160 | #[cfg(test)]
|
@@ -185,4 +281,40 @@ mod tests {
|
185 | 281 | assert_eq!(destination, expected_data);
|
186 | 282 | assert_eq!(result, expected_null_count);
|
187 | 283 | }
|
| 284 | + |
| 285 | + #[test] |
| 286 | + fn test_set_upto_64bits() { |
| 287 | + // len >= 64 |
| 288 | + let write_data: &mut [u8] = &mut [0; 9]; |
| 289 | + let data: &[u8] = &[ |
| 290 | + 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, |
| 291 | + 0b00000001, 0b00000001, |
| 292 | + ]; |
| 293 | + let offset_write = 1; |
| 294 | + let offset_read = 0; |
| 295 | + let len = 65; |
| 296 | + let (n, len_set) = |
| 297 | + unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) }; |
| 298 | + assert_eq!(n, 55); |
| 299 | + assert_eq!(len_set, 63); |
| 300 | + assert_eq!( |
| 301 | + write_data, |
| 302 | + &[ |
| 303 | + 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, |
| 304 | + 0b00000010, 0b00000000 |
| 305 | + ] |
| 306 | + ); |
| 307 | + |
| 308 | + // len = 1 |
| 309 | + let write_data: &mut [u8] = &mut [0b00000000]; |
| 310 | + let data: &[u8] = &[0b00000001]; |
| 311 | + let offset_write = 1; |
| 312 | + let offset_read = 0; |
| 313 | + let len = 1; |
| 314 | + let (n, len_set) = |
| 315 | + unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) }; |
| 316 | + assert_eq!(n, 0); |
| 317 | + assert_eq!(len_set, 1); |
| 318 | + assert_eq!(write_data, &[0b00000010]); |
| 319 | + } |
188 | 320 | }
|
0 commit comments