Skip to content

Commit eec4280

Browse files
committed
Update bitmask API
1 parent da42aa5 commit eec4280

File tree

6 files changed

+196
-160
lines changed

6 files changed

+196
-160
lines changed

crates/core_simd/src/intrinsics.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ extern "platform-intrinsic" {
7676
pub(crate) fn simd_reduce_and<T, U>(x: T) -> U;
7777
pub(crate) fn simd_reduce_or<T, U>(x: T) -> U;
7878
pub(crate) fn simd_reduce_xor<T, U>(x: T) -> U;
79+
80+
// truncate integer vector to bitmask
81+
pub(crate) fn simd_bitmask<T, U>(x: T) -> U;
7982
}
8083

8184
#[cfg(feature = "std")]

crates/core_simd/src/lanes_at_most_32.rs

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,38 @@
11
/// Implemented for vectors that are supported by the implementation.
2-
pub trait LanesAtMost32 {}
2+
pub trait LanesAtMost32: sealed::Sealed {
3+
#[doc(hidden)]
4+
type BitMask: Into<u64>;
5+
}
6+
7+
mod sealed {
8+
pub trait Sealed {}
9+
}
310

411
macro_rules! impl_for {
512
{ $name:ident } => {
6-
impl LanesAtMost32 for $name<1> {}
7-
impl LanesAtMost32 for $name<2> {}
8-
impl LanesAtMost32 for $name<4> {}
9-
impl LanesAtMost32 for $name<8> {}
10-
impl LanesAtMost32 for $name<16> {}
11-
impl LanesAtMost32 for $name<32> {}
13+
impl<const LANES: usize> sealed::Sealed for $name<LANES>
14+
where
15+
$name<LANES>: LanesAtMost32,
16+
{}
17+
18+
impl LanesAtMost32 for $name<1> {
19+
type BitMask = u8;
20+
}
21+
impl LanesAtMost32 for $name<2> {
22+
type BitMask = u8;
23+
}
24+
impl LanesAtMost32 for $name<4> {
25+
type BitMask = u8;
26+
}
27+
impl LanesAtMost32 for $name<8> {
28+
type BitMask = u8;
29+
}
30+
impl LanesAtMost32 for $name<16> {
31+
type BitMask = u16;
32+
}
33+
impl LanesAtMost32 for $name<32> {
34+
type BitMask = u32;
35+
}
1236
}
1337
}
1438

crates/core_simd/src/masks/bitmask.rs

Lines changed: 48 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
1-
use crate::LanesAtMost32;
2-
31
/// A mask where each lane is represented by a single bit.
42
#[derive(Copy, Clone, Debug, PartialOrd, PartialEq, Ord, Eq, Hash)]
53
#[repr(transparent)]
6-
pub struct BitMask<const LANES: usize>(u64)
4+
pub struct BitMask<const LANES: usize>(u64);
75

86
impl<const LANES: usize> BitMask<LANES>
9-
where
10-
Self: LanesAtMost32,
117
{
128
#[inline]
139
pub fn splat(value: bool) -> Self {
@@ -25,13 +21,50 @@ where
2521

2622
#[inline]
2723
pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
28-
self.0 ^= ((value ^ self.test(lane)) as u64) << lane
24+
self.0 ^= ((value ^ self.test_unchecked(lane)) as u64) << lane
25+
}
26+
27+
#[inline]
28+
pub fn to_int<V, T>(self) -> V
29+
where
30+
V: Default + AsMut<[T; LANES]>,
31+
T: From<i8>,
32+
{
33+
// TODO this should be an intrinsic sign-extension
34+
let mut v = V::default();
35+
for i in 0..LANES {
36+
let lane = unsafe { self.test_unchecked(i) };
37+
v.as_mut()[i] = (-(lane as i8)).into();
38+
}
39+
v
40+
}
41+
42+
#[inline]
43+
pub unsafe fn from_int_unchecked<V>(value: V) -> Self
44+
where
45+
V: crate::LanesAtMost32,
46+
{
47+
let mask: V::BitMask = crate::intrinsics::simd_bitmask(value);
48+
Self(mask.into())
49+
}
50+
51+
#[inline]
52+
pub fn to_bitmask(self) -> u64 {
53+
self.0
54+
}
55+
56+
#[inline]
57+
pub fn any(self) -> bool {
58+
self != Self::splat(false)
59+
}
60+
61+
#[inline]
62+
pub fn all(self) -> bool {
63+
self == Self::splat(true)
2964
}
3065
}
3166

3267
impl<const LANES: usize> core::ops::BitAnd for BitMask<LANES>
33-
where
34-
Self: LanesAtMost32,
3568
{
3669
type Output = Self;
3770
#[inline]
@@ -41,8 +74,6 @@ where
4174
}
4275

4376
impl<const LANES: usize> core::ops::BitAnd<bool> for BitMask<LANES>
44-
where
45-
Self: LanesAtMost32,
4677
{
4778
type Output = Self;
4879
#[inline]
@@ -52,8 +83,6 @@ where
5283
}
5384

5485
impl<const LANES: usize> core::ops::BitAnd<BitMask<LANES>> for bool
55-
where
56-
BitMask<LANES>: LanesAtMost32,
5786
{
5887
type Output = BitMask<LANES>;
5988
#[inline]
@@ -63,8 +92,6 @@ where
6392
}
6493

6594
impl<const LANES: usize> core::ops::BitOr for BitMask<LANES>
66-
where
67-
Self: LanesAtMost32,
6895
{
6996
type Output = Self;
7097
#[inline]
@@ -73,31 +100,7 @@ where
73100
}
74101
}
75102

76-
impl<const LANES: usize> core::ops::BitOr<bool> for BitMask<LANES>
77-
where
78-
Self: LanesAtMost32,
79-
{
80-
type Output = Self;
81-
#[inline]
82-
fn bitor(self, rhs: bool) -> Self {
83-
self | Self::splat(rhs)
84-
}
85-
}
86-
87-
impl<const LANES: usize> core::ops::BitOr<BitMask<LANES>> for bool
88-
where
89-
BitMask<LANES>: LanesAtMost32,
90-
{
91-
type Output = BitMask<LANES>;
92-
#[inline]
93-
fn bitor(self, rhs: BitMask<LANES>) -> BitMask<LANES> {
94-
BitMask::<LANES>::splat(self) | rhs
95-
}
96-
}
97-
98103
impl<const LANES: usize> core::ops::BitXor for BitMask<LANES>
99-
where
100-
Self: LanesAtMost32,
101104
{
102105
type Output = Self;
103106
#[inline]
@@ -106,95 +109,42 @@ where
106109
}
107110
}
108111

109-
impl<const LANES: usize> core::ops::BitXor<bool> for BitMask<LANES>
110-
where
111-
Self: LanesAtMost32,
112-
{
113-
type Output = Self;
114-
#[inline]
115-
fn bitxor(self, rhs: bool) -> Self::Output {
116-
self ^ Self::splat(rhs)
117-
}
118-
}
119-
120-
impl<const LANES: usize> core::ops::BitXor<BitMask<LANES>> for bool
121-
where
122-
BitMask<LANES>: LanesAtMost32,
123-
{
124-
type Output = BitMask<LANES>;
125-
#[inline]
126-
fn bitxor(self, rhs: BitMask<LANES>) -> Self::Output {
127-
BitMask::<LANES>::splat(self) ^ rhs
128-
}
129-
}
130-
131112
impl<const LANES: usize> core::ops::Not for BitMask<LANES>
132-
where
133-
Self: LanesAtMost32,
134113
{
135114
type Output = BitMask<LANES>;
136115
#[inline]
137116
fn not(self) -> Self::Output {
138-
Self(!self.0)
117+
Self(!self.0) & Self::splat(true)
139118
}
140119
}
141120

142121
impl<const LANES: usize> core::ops::BitAndAssign for BitMask<LANES>
143-
where
144-
Self: LanesAtMost32,
145122
{
146123
#[inline]
147124
fn bitand_assign(&mut self, rhs: Self) {
148125
self.0 &= rhs.0;
149126
}
150127
}
151128

152-
impl<const LANES: usize> core::ops::BitAndAssign<bool> for BitMask<LANES>
153-
where
154-
Self: LanesAtMost32,
155-
{
156-
#[inline]
157-
fn bitand_assign(&mut self, rhs: bool) {
158-
*self &= Self::splat(rhs);
159-
}
160-
}
161-
162129
impl<const LANES: usize> core::ops::BitOrAssign for BitMask<LANES>
163-
where
164-
Self: LanesAtMost32,
165130
{
166131
#[inline]
167132
fn bitor_assign(&mut self, rhs: Self) {
168133
self.0 |= rhs.0;
169134
}
170135
}
171136

172-
impl<const LANES: usize> core::ops::BitOrAssign<bool> for BitMask<LANES>
173-
where
174-
Self: LanesAtMost32,
175-
{
176-
#[inline]
177-
fn bitor_assign(&mut self, rhs: bool) {
178-
*self |= Self::splat(rhs);
179-
}
180-
}
181-
182137
impl<const LANES: usize> core::ops::BitXorAssign for BitMask<LANES>
183-
where
184-
Self: LanesAtMost32,
185138
{
186139
#[inline]
187140
fn bitxor_assign(&mut self, rhs: Self) {
188141
self.0 ^= rhs.0;
189142
}
190143
}
191144

192-
impl<const LANES: usize> core::ops::BitXorAssign<bool> for BitMask<LANES>
193-
where
194-
Self: LanesAtMost32,
195-
{
196-
#[inline]
197-
fn bitxor_assign(&mut self, rhs: bool) {
198-
*self ^= Self::splat(rhs);
199-
}
200-
}
145+
pub type Mask8<const LANES: usize> = BitMask<LANES>;
146+
pub type Mask16<const LANES: usize> = BitMask<LANES>;
147+
pub type Mask32<const LANES: usize> = BitMask<LANES>;
148+
pub type Mask64<const LANES: usize> = BitMask<LANES>;
149+
pub type Mask128<const LANES: usize> = BitMask<LANES>;
150+
pub type MaskSize<const LANES: usize> = BitMask<LANES>;

crates/core_simd/src/masks/full_masks.rs

Lines changed: 8 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,12 @@ macro_rules! define_mask {
4646
}
4747

4848
#[inline]
49-
pub fn test(&self, lane: usize) -> bool {
50-
assert!(lane < LANES, "lane index out of range");
49+
pub unsafe fn test_unchecked(&self, lane: usize) -> bool {
5150
self.0[lane] == -1
5251
}
5352

5453
#[inline]
55-
pub fn set(&mut self, lane: usize, value: bool) {
56-
assert!(lane < LANES, "lane index out of range");
54+
pub unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
5755
self.0[lane] = if value {
5856
-1
5957
} else {
@@ -70,6 +68,12 @@ macro_rules! define_mask {
7068
pub unsafe fn from_int_unchecked(value: crate::$type<LANES>) -> Self {
7169
Self(value)
7270
}
71+
72+
#[inline]
73+
pub fn to_bitmask(self) -> u64 {
74+
let mask: <crate::$type<LANES> as crate::LanesAtMost32>::BitMask = unsafe { crate::intrinsics::simd_bitmask(self.0) };
75+
mask.into()
76+
}
7377
}
7478

7579
impl<const LANES: usize> core::convert::From<$name<LANES>> for crate::$type<LANES>
@@ -81,53 +85,6 @@ macro_rules! define_mask {
8185
}
8286
}
8387

84-
impl<const LANES: usize> core::fmt::Debug for $name<LANES>
85-
where
86-
crate::$type<LANES>: crate::LanesAtMost32,
87-
{
88-
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
89-
f.debug_list()
90-
.entries((0..LANES).map(|lane| self.test(lane)))
91-
.finish()
92-
}
93-
}
94-
95-
impl<const LANES: usize> core::fmt::Binary for $name<LANES>
96-
where
97-
crate::$type<LANES>: crate::LanesAtMost32,
98-
{
99-
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
100-
core::fmt::Binary::fmt(&self.0, f)
101-
}
102-
}
103-
104-
impl<const LANES: usize> core::fmt::Octal for $name<LANES>
105-
where
106-
crate::$type<LANES>: crate::LanesAtMost32,
107-
{
108-
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
109-
core::fmt::Octal::fmt(&self.0, f)
110-
}
111-
}
112-
113-
impl<const LANES: usize> core::fmt::LowerHex for $name<LANES>
114-
where
115-
crate::$type<LANES>: crate::LanesAtMost32,
116-
{
117-
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
118-
core::fmt::LowerHex::fmt(&self.0, f)
119-
}
120-
}
121-
122-
impl<const LANES: usize> core::fmt::UpperHex for $name<LANES>
123-
where
124-
crate::$type<LANES>: crate::LanesAtMost32,
125-
{
126-
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
127-
core::fmt::UpperHex::fmt(&self.0, f)
128-
}
129-
}
130-
13188
impl<const LANES: usize> core::ops::BitAnd for $name<LANES>
13289
where
13390
crate::$type<LANES>: crate::LanesAtMost32,

0 commit comments

Comments
 (0)