Skip to content

Commit b4de692

Browse files
Improve performance of set_bits by avoiding setting individual bits (#6288)
* bench * fix: Optimize set_bits * clippy * clippyj * miri * fix: Optimize set_bits * fix: Optimize set_bits * fix: Optimize set_bits * fix: Optimize set_bits * fix: Optimize set_bits * fix: Optimize set_bits * fix: Optimize set_bits * fix: Optimize set_bits * miri * miri * miri * miri * miri * miri * miri * miri * miri * miri * miri * address review comments * address review comments * address review comments * Revert "address review comments" This reverts commit ef2864f. * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * Revert "address review comments" This reverts commit a15db14. * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments * address review comments
1 parent ba85fa3 commit b4de692

File tree

1 file changed

+158
-26
lines changed

1 file changed

+158
-26
lines changed

arrow-buffer/src/util/bit_mask.rs

Lines changed: 158 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -17,48 +17,144 @@
1717

1818
//! Utils for working with packed bit masks
1919
20-
use crate::bit_chunk_iterator::BitChunks;
21-
use crate::bit_util::{ceil, get_bit, set_bit};
20+
use crate::bit_util::ceil;
2221

2322
/// Sets all bits on `write_data` in the range `[offset_write..offset_write+len]` to be equal to the
2423
/// bits in `data` in the range `[offset_read..offset_read+len]`
2524
/// returns the number of `0` bits `data[offset_read..offset_read+len]`
25+
/// `offset_write`, `offset_read`, and `len` are in terms of bits
2626
pub fn set_bits(
2727
write_data: &mut [u8],
2828
data: &[u8],
2929
offset_write: usize,
3030
offset_read: usize,
3131
len: usize,
3232
) -> usize {
33+
assert!(offset_write + len <= write_data.len() * 8);
34+
assert!(offset_read + len <= data.len() * 8);
3335
let mut null_count = 0;
34-
35-
let mut bits_to_align = offset_write % 8;
36-
if bits_to_align > 0 {
37-
bits_to_align = std::cmp::min(len, 8 - bits_to_align);
36+
let mut acc = 0;
37+
while len > acc {
38+
// SAFETY: the arguments to `set_upto_64bits` are within the valid range because
39+
// (offset_write + acc) + (len - acc) == offset_write + len <= write_data.len() * 8
40+
// (offset_read + acc) + (len - acc) == offset_read + len <= data.len() * 8
41+
let (n, len_set) = unsafe {
42+
set_upto_64bits(
43+
write_data,
44+
data,
45+
offset_write + acc,
46+
offset_read + acc,
47+
len - acc,
48+
)
49+
};
50+
null_count += n;
51+
acc += len_set;
3852
}
39-
let mut write_byte_index = ceil(offset_write + bits_to_align, 8);
40-
41-
// Set full bytes provided by bit chunk iterator (which iterates in 64 bits at a time)
42-
let chunks = BitChunks::new(data, offset_read + bits_to_align, len - bits_to_align);
43-
chunks.iter().for_each(|chunk| {
44-
null_count += chunk.count_zeros();
45-
write_data[write_byte_index..write_byte_index + 8].copy_from_slice(&chunk.to_le_bytes());
46-
write_byte_index += 8;
47-
});
48-
49-
// Set individual bits both to align write_data to a byte offset and the remainder bits not covered by the bit chunk iterator
50-
let remainder_offset = len - chunks.remainder_len();
51-
(0..bits_to_align)
52-
.chain(remainder_offset..len)
53-
.for_each(|i| {
54-
if get_bit(data, offset_read + i) {
55-
set_bit(write_data, offset_write + i);
53+
54+
null_count
55+
}
56+
57+
/// Similar to `set_bits` but sets only up to 64 bits; the actual number of bits set may vary.
/// Returns a pair of the number of `0` bits and the number of bits set.
///
/// NOTE(review): the shifted fast paths below treat `u64` loads/stores as if the
/// lowest-addressed byte held the least-significant bits (little-endian layout) —
/// confirm behavior on big-endian targets.
///
/// # Safety
/// The caller must ensure all arguments are within the valid range.
#[inline]
unsafe fn set_upto_64bits(
    write_data: &mut [u8],
    data: &[u8],
    offset_write: usize,
    offset_read: usize,
    len: usize,
) -> (usize, usize) {
    // Split each bit offset into a byte index and a bit shift within that byte.
    let read_byte = offset_read / 8;
    let read_shift = offset_read % 8;
    let write_byte = offset_write / 8;
    let write_shift = offset_write % 8;

    if len >= 64 {
        // Fast path: at least 64 bits remain, so an 8-byte unaligned load
        // starting at `read_byte` is in bounds.
        let chunk = unsafe { (data.as_ptr().add(read_byte) as *const u64).read_unaligned() };
        if read_shift == 0 {
            if write_shift == 0 {
                // no shifting necessary — straight 64-bit copy
                let len = 64;
                let null_count = chunk.count_zeros() as usize;
                unsafe { write_u64_bytes(write_data, write_byte, chunk) };
                (null_count, len)
            } else {
                // only write shifting necessary; the low `write_shift` bits already
                // present in the destination byte are kept by the OR-ing store
                let len = 64 - write_shift;
                let chunk = chunk << write_shift;
                let null_count = len - chunk.count_ones() as usize;
                unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
                (null_count, len)
            }
        } else if write_shift == 0 {
            // only read shifting necessary
            let len = 64 - 8; // 56 bits so the next set_upto_64bits call will see write_shift == 0
            let chunk = (chunk >> read_shift) & 0x00FFFFFFFFFFFFFF; // 56 bits mask
            let null_count = len - chunk.count_ones() as usize;
            unsafe { write_u64_bytes(write_data, write_byte, chunk) };
            (null_count, len)
        } else {
            // both offsets misaligned: copy as many bits as fit after aligning
            // the chunk from the read shift to the write shift
            let len = 64 - std::cmp::max(read_shift, write_shift);
            let chunk = (chunk >> read_shift) << write_shift;
            let null_count = len - chunk.count_ones() as usize;
            unsafe { or_write_u64_bytes(write_data, write_byte, chunk) };
            (null_count, len)
        }
    } else if len == 1 {
        // Single-bit tail: move one bit and report whether it was zero.
        let byte_chunk = (unsafe { data.get_unchecked(read_byte) } >> read_shift) & 1;
        unsafe { *write_data.get_unchecked_mut(write_byte) |= byte_chunk << write_shift };
        ((byte_chunk ^ 1) as usize, 1)
    } else {
        // Short tail (2..=63 bits): gather just the needed source bytes into a u64,
        // mask to exactly `len` bits, then OR the result byte-by-byte into `write_data`.
        let len = std::cmp::min(len, 64 - std::cmp::max(read_shift, write_shift));
        let bytes = ceil(len + read_shift, 8);
        // SAFETY: the args of `read_bytes_to_u64` are valid as read_byte + bytes <= data.len()
        let chunk = unsafe { read_bytes_to_u64(data, read_byte, bytes) };
        let mask = u64::MAX >> (64 - len); // len < 64 in this branch, so the shift is in range
        let chunk = (chunk >> read_shift) & mask; // masking to read `len` bits only
        let chunk = chunk << write_shift; // shifting back to align with `write_data`
        let null_count = len - chunk.count_ones() as usize;
        let bytes = ceil(len + write_shift, 8);
        for (i, c) in chunk.to_le_bytes().iter().enumerate().take(bytes) {
            unsafe { *write_data.get_unchecked_mut(write_byte + i) |= c };
        }
        (null_count, len)
    }
}
60126

61-
null_count as usize
127+
/// Reads `count` bytes (at most 8) from `data` starting at byte `offset` into the
/// low-address bytes of a `u64` (a plain byte copy in native byte order); bytes
/// beyond `count` are zero.
///
/// # Safety
/// The caller must ensure that `count <= 8` and `offset + count <= data.len()`.
#[inline]
unsafe fn read_bytes_to_u64(data: &[u8], offset: usize, count: usize) -> u64 {
    debug_assert!(count <= 8);
    // Debug-check the bounds half of the safety contract as well.
    debug_assert!(offset + count <= data.len());
    // A zero-initialized local replaces the original `MaybeUninit::<u64>::new(0)`,
    // which was needless complexity: the value was always fully initialized.
    let mut tmp: u64 = 0;
    let src = data.as_ptr().add(offset);
    unsafe {
        std::ptr::copy_nonoverlapping(src, &mut tmp as *mut u64 as *mut u8, count);
    }
    tmp
}
139+
140+
/// Writes the native-endian byte representation of `chunk` into
/// `data[offset..offset + 8]`, overwriting whatever was there.
///
/// # Safety
/// The caller must ensure `data` has `offset..(offset + 8)` range
#[inline]
unsafe fn write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
    // Byte-wise copy of the native representation — equivalent to an
    // unaligned u64 store.
    let bytes = chunk.to_ne_bytes();
    let dst = data.as_mut_ptr().add(offset);
    std::ptr::copy_nonoverlapping(bytes.as_ptr(), dst, 8);
}
147+
148+
/// Similar to `write_u64_bytes`, but this method ORs the offset addressed `data` and `chunk`
/// instead of overwriting
///
/// Note that only the single byte at `offset` is actually merged with OR; the other
/// 7 bytes of the stored `u64` overwrite `data[offset + 1..offset + 8]`. Callers rely
/// on those trailing bytes not holding meaningful bits yet.
/// NOTE(review): this assumes the low 8 bits of `chunk` map back to the byte at
/// `offset` (little-endian store) — confirm on big-endian targets.
///
/// # Safety
/// The caller must ensure `data` has `offset..(offset + 8)` range
#[inline]
unsafe fn or_write_u64_bytes(data: &mut [u8], offset: usize, chunk: u64) {
    let ptr = data.as_mut_ptr().add(offset);
    // Fold the existing first byte into the low bits before the 8-byte store.
    let chunk = chunk | (*ptr) as u64;
    (ptr as *mut u64).write_unaligned(chunk);
}
63159

64160
#[cfg(test)]
@@ -185,4 +281,40 @@ mod tests {
185281
assert_eq!(destination, expected_data);
186282
assert_eq!(result, expected_null_count);
187283
}
284+
285+
#[test]
286+
fn test_set_upto_64bits() {
287+
// len >= 64
288+
let write_data: &mut [u8] = &mut [0; 9];
289+
let data: &[u8] = &[
290+
0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001, 0b00000001,
291+
0b00000001, 0b00000001,
292+
];
293+
let offset_write = 1;
294+
let offset_read = 0;
295+
let len = 65;
296+
let (n, len_set) =
297+
unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
298+
assert_eq!(n, 55);
299+
assert_eq!(len_set, 63);
300+
assert_eq!(
301+
write_data,
302+
&[
303+
0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010, 0b00000010,
304+
0b00000010, 0b00000000
305+
]
306+
);
307+
308+
// len = 1
309+
let write_data: &mut [u8] = &mut [0b00000000];
310+
let data: &[u8] = &[0b00000001];
311+
let offset_write = 1;
312+
let offset_read = 0;
313+
let len = 1;
314+
let (n, len_set) =
315+
unsafe { set_upto_64bits(write_data, data, offset_write, offset_read, len) };
316+
assert_eq!(n, 0);
317+
assert_eq!(len_set, 1);
318+
assert_eq!(write_data, &[0b00000010]);
319+
}
188320
}

0 commit comments

Comments
 (0)