Skip to content

Commit 2c7e0fd

Browse files
committed
Auto merge of rust-lang#3237 - RalfJung:simd-loadstore, r=RalfJung
implement and test simd_masked_load and simd_masked_store; also extend the scatter/gather tests. Fixes rust-lang/miri#3235
2 parents 23efae0 + e8a4bd1 commit 2c7e0fd

File tree

2 files changed

+111
-6
lines changed

2 files changed

+111
-6
lines changed

src/tools/miri/src/shims/intrinsics/simd.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,54 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriInterpCxExt<'mir, 'tcx> {
656656
}
657657
}
658658
}
659+
"masked_load" => {
660+
let [mask, ptr, default] = check_arg_count(args)?;
661+
let (mask, mask_len) = this.operand_to_simd(mask)?;
662+
let ptr = this.read_pointer(ptr)?;
663+
let (default, default_len) = this.operand_to_simd(default)?;
664+
let (dest, dest_len) = this.place_to_simd(dest)?;
665+
666+
assert_eq!(dest_len, mask_len);
667+
assert_eq!(dest_len, default_len);
668+
669+
for i in 0..dest_len {
670+
let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
671+
let default = this.read_immediate(&this.project_index(&default, i)?)?;
672+
let dest = this.project_index(&dest, i)?;
673+
674+
let val = if simd_element_to_bool(mask)? {
675+
// Size * u64 is implemented as always checked
676+
#[allow(clippy::arithmetic_side_effects)]
677+
let ptr = ptr.wrapping_offset(dest.layout.size * i, this);
678+
let place = this.ptr_to_mplace(ptr, dest.layout);
679+
this.read_immediate(&place)?
680+
} else {
681+
default
682+
};
683+
this.write_immediate(*val, &dest)?;
684+
}
685+
}
686+
"masked_store" => {
687+
let [mask, ptr, vals] = check_arg_count(args)?;
688+
let (mask, mask_len) = this.operand_to_simd(mask)?;
689+
let ptr = this.read_pointer(ptr)?;
690+
let (vals, vals_len) = this.operand_to_simd(vals)?;
691+
692+
assert_eq!(mask_len, vals_len);
693+
694+
for i in 0..vals_len {
695+
let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
696+
let val = this.read_immediate(&this.project_index(&vals, i)?)?;
697+
698+
if simd_element_to_bool(mask)? {
699+
// Size * u64 is implemented as always checked
700+
#[allow(clippy::arithmetic_side_effects)]
701+
let ptr = ptr.wrapping_offset(val.layout.size * i, this);
702+
let place = this.ptr_to_mplace(ptr, val.layout);
703+
this.write_immediate(*val, &place)?
704+
};
705+
}
706+
}
659707

660708
name => throw_unsup_format!("unimplemented intrinsic: `simd_{name}`"),
661709
}

src/tools/miri/tests/pass/portable-simd.rs

Lines changed: 63 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
//@compile-flags: -Zmiri-strict-provenance
2-
#![feature(portable_simd, platform_intrinsics, adt_const_params, inline_const)]
2+
#![feature(portable_simd, platform_intrinsics, adt_const_params, inline_const, core_intrinsics)]
33
#![allow(incomplete_features, internal_features)]
4+
use std::intrinsics::simd as intrinsics;
5+
use std::ptr;
46
use std::simd::{prelude::*, StdFloat};
57

68
fn simd_ops_f32() {
@@ -421,6 +423,40 @@ fn simd_gather_scatter() {
421423
let idxs = Simd::from_array([9, 3, 0, 0]);
422424
Simd::from_array([-27, 82, -41, 124]).scatter(&mut vec, idxs);
423425
assert_eq!(vec, vec![124, 11, 12, 82, 14, 15, 16, 17, 18]);
426+
427+
// We call the intrinsics directly to experiment with dangling pointers and masks.
428+
let val = 42u8;
429+
let ptrs: Simd<*const u8, 4> =
430+
Simd::from_array([ptr::null(), ptr::addr_of!(val), ptr::addr_of!(val), ptr::addr_of!(val)]);
431+
let default = u8x4::splat(0);
432+
let mask = i8x4::from_array([0, !0, 0, !0]);
433+
let vals = unsafe { intrinsics::simd_gather(default, ptrs, mask) };
434+
assert_eq!(vals, u8x4::from_array([0, 42, 0, 42]),);
435+
436+
let mut val1 = 0u8;
437+
let mut val2 = 0u8;
438+
let ptrs: Simd<*mut u8, 4> = Simd::from_array([
439+
ptr::null_mut(),
440+
ptr::addr_of_mut!(val1),
441+
ptr::addr_of_mut!(val1),
442+
ptr::addr_of_mut!(val2),
443+
]);
444+
let vals = u8x4::from_array([1, 2, 3, 4]);
445+
unsafe { intrinsics::simd_scatter(vals, ptrs, mask) };
446+
assert_eq!(val1, 2);
447+
assert_eq!(val2, 4);
448+
449+
// Also check what happens when `scatter` has multiple overlapping pointers.
450+
let mut val = 0u8;
451+
let ptrs: Simd<*mut u8, 4> = Simd::from_array([
452+
ptr::addr_of_mut!(val),
453+
ptr::addr_of_mut!(val),
454+
ptr::addr_of_mut!(val),
455+
ptr::addr_of_mut!(val),
456+
]);
457+
let vals = u8x4::from_array([1, 2, 3, 4]);
458+
unsafe { intrinsics::simd_scatter(vals, ptrs, mask) };
459+
assert_eq!(val, 4);
424460
}
425461

426462
fn simd_round() {
@@ -460,14 +496,11 @@ fn simd_round() {
460496
}
461497

462498
fn simd_intrinsics() {
499+
use intrinsics::*;
463500
extern "platform-intrinsic" {
464-
fn simd_eq<T, U>(x: T, y: T) -> U;
465-
fn simd_reduce_any<T>(x: T) -> bool;
466-
fn simd_reduce_all<T>(x: T) -> bool;
467-
fn simd_select<M, T>(m: M, yes: T, no: T) -> T;
468501
fn simd_shuffle_generic<T, U, const IDX: &'static [u32]>(x: T, y: T) -> U;
469-
fn simd_shuffle<T, IDX, U>(x: T, y: T, idx: IDX) -> U;
470502
}
503+
471504
unsafe {
472505
// Make sure simd_eq returns all-1 for `true`
473506
let a = i32x4::splat(10);
@@ -503,6 +536,29 @@ fn simd_intrinsics() {
503536
}
504537
}
505538

539+
fn simd_masked_loadstore() {
540+
// The buffer is deliberately too short, so reading the last element would be UB.
541+
let buf = [3i32; 3];
542+
let default = i32x4::splat(0);
543+
let mask = i32x4::from_array([!0, !0, !0, 0]);
544+
let vals = unsafe { intrinsics::simd_masked_load(mask, buf.as_ptr(), default) };
545+
assert_eq!(vals, i32x4::from_array([3, 3, 3, 0]));
546+
// Also read in a way that the *first* element is OOB.
547+
let mask2 = i32x4::from_array([0, !0, !0, !0]);
548+
let vals =
549+
unsafe { intrinsics::simd_masked_load(mask2, buf.as_ptr().wrapping_sub(1), default) };
550+
assert_eq!(vals, i32x4::from_array([0, 3, 3, 3]));
551+
552+
// The buffer is deliberately too short, so writing the last element would be UB.
553+
let mut buf = [42i32; 3];
554+
let vals = i32x4::from_array([1, 2, 3, 4]);
555+
unsafe { intrinsics::simd_masked_store(mask, buf.as_mut_ptr(), vals) };
556+
assert_eq!(buf, [1, 2, 3]);
557+
// Also write in a way that the *first* element is OOB.
558+
unsafe { intrinsics::simd_masked_store(mask2, buf.as_mut_ptr().wrapping_sub(1), vals) };
559+
assert_eq!(buf, [2, 3, 4]);
560+
}
561+
506562
fn main() {
507563
simd_mask();
508564
simd_ops_f32();
@@ -513,4 +569,5 @@ fn main() {
513569
simd_gather_scatter();
514570
simd_round();
515571
simd_intrinsics();
572+
simd_masked_loadstore();
516573
}

0 commit comments

Comments
 (0)