Skip to content

Commit 11db9de

Browse files
committed
simd_select_bitmask: support passing the mask as an array
1 parent 6318e9d commit 11db9de

File tree

2 files changed

+28
-3
lines changed

2 files changed

+28
-3
lines changed

src/tools/miri/src/shims/intrinsics/simd.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -386,16 +386,25 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
386386
let (dest, dest_len) = this.place_to_simd(dest)?;
387387
let bitmask_len = dest_len.max(8);
388388

389-
assert!(mask.layout.ty.is_integral());
390389
assert!(bitmask_len <= 64);
391390
assert_eq!(bitmask_len, mask.layout.size.bits());
392391
assert_eq!(dest_len, yes_len);
393392
assert_eq!(dest_len, no_len);
394393
let dest_len = u32::try_from(dest_len).unwrap();
395394
let bitmask_len = u32::try_from(bitmask_len).unwrap();
396395

397-
let mask: u64 =
398-
this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap();
396+
// The mask can be a single integer or an array.
397+
let mask: u64 = match mask.layout.ty.kind() {
398+
ty::Int(..) | ty::Uint(..) =>
399+
this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap(),
400+
ty::Array(elem, _) if matches!(elem.kind(), ty::Uint(ty::UintTy::U8)) => {
401+
let mask_ty = this.machine.layouts.uint(mask.layout.size).unwrap();
402+
let mask = mask.transmute(mask_ty, this)?;
403+
this.read_scalar(&mask)?.to_bits(mask_ty.size)?.try_into().unwrap()
404+
}
405+
_ => bug!("simd_select_bitmask: invalid mask type {}", mask.layout.ty),
406+
};
407+
399408
for i in 0..dest_len {
400409
let mask = mask
401410
& 1u64

src/tools/miri/tests/pass/portable-simd.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,22 @@ fn simd_mask() {
247247
assert_eq!(bitmask2, [0b0001]);
248248
}
249249
}
250+
251+
// This used to cause an ICE.
252+
let bitmask = u8x8::from_array([0b01000101, 0, 0, 0, 0, 0, 0, 0]);
253+
assert_eq!(
254+
mask32x8::from_bitmask_vector(bitmask),
255+
mask32x8::from_array([true, false, true, false, false, false, true, false]),
256+
);
257+
let bitmask =
258+
u8x16::from_array([0b01000101, 0b11110000, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]);
259+
assert_eq!(
260+
mask32x16::from_bitmask_vector(bitmask),
261+
mask32x16::from_array([
262+
true, false, true, false, false, false, true, false, false, false, false, false, true,
263+
true, true, true,
264+
]),
265+
);
250266
}
251267

252268
fn simd_cast() {

0 commit comments

Comments
 (0)