Skip to content
Draft
4 changes: 3 additions & 1 deletion crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

use crate::attr::{AggregatedSpirvAttributes, IntrinsicType};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use crate::spirv_type::{SpirvType, StorageClassKind};
use itertools::Itertools;
use rspirv::spirv::{Dim, ImageFormat, StorageClass, Word};
use rustc_data_structures::fx::FxHashMap;
Expand Down Expand Up @@ -339,6 +339,7 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
PointeeDefState::Defining => {
let id = SpirvType::Pointer {
pointee: pointee_spv,
storage_class: StorageClassKind::Inferred, // TODO(jwollen): Do we need to cache by storage class?
}
.def(span, cx);
entry.insert(PointeeDefState::Defined(id));
Expand All @@ -350,6 +351,7 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
entry.insert(PointeeDefState::Defined(id));
SpirvType::Pointer {
pointee: pointee_spv,
storage_class: StorageClassKind::Inferred, // TODO(jwollen): Do we need to cache by storage class?
}
.def_with_id(cx, span, id)
}
Expand Down
5 changes: 3 additions & 2 deletions crates/rustc_codegen_spirv/src/attr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ impl CheckSpirvAttrVisitor<'_> {
"attribute is only valid on a parameter of an entry-point function",
);
} else {
// FIXME(eddyb) should we just remove all 5 of these storage class
// FIXME(eddyb) should we just remove all 6 of these storage class
// attributes, instead of disallowing them here?
if let SpirvAttribute::StorageClass(storage_class) = parsed_attr {
let valid = match storage_class {
Expand All @@ -347,7 +347,8 @@ impl CheckSpirvAttrVisitor<'_> {

StorageClass::Private
| StorageClass::Function
| StorageClass::Generic => {
| StorageClass::Generic
| StorageClass::PhysicalStorageBuffer => {
Err("can not be used as part of an entry's interface")
}

Expand Down
119 changes: 90 additions & 29 deletions crates/rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -360,11 +360,21 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
}

fn zombie_convert_ptr_to_u(&self, def: Word) {
self.zombie(def, "cannot convert pointers to integers");
if !self
.builder
.has_capability(Capability::PhysicalStorageBufferAddresses)
{
self.zombie(def, "cannot convert pointers to integers without OpCapability PhysicalStorageBufferAddresses");
}
}

fn zombie_convert_u_to_ptr(&self, def: Word) {
self.zombie(def, "cannot convert integers to pointers");
if !self
.builder
.has_capability(Capability::PhysicalStorageBufferAddresses)
{
self.zombie(def, "cannot convert integers to pointers without OpCapability PhysicalStorageBufferAddresses");
}
}

fn zombie_ptr_equal(&self, def: Word, inst: &str) {
Expand Down Expand Up @@ -407,14 +417,15 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
size: Size,
) -> Option<(SpirvValue, <Self as BackendTypes>::Type)> {
let ptr = ptr.strip_ptrcasts();
let mut leaf_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
let pointee_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!("non-pointer type: {other:?}")),
};

// FIXME(eddyb) this isn't efficient, `recover_access_chain_from_offset`
// could instead be doing all the extra digging itself.
let mut indices = SmallVec::<[_; 8]>::new();
let mut leaf_ty = pointee_ty;
while let Some((inner_indices, inner_ty)) = self.recover_access_chain_from_offset(
leaf_ty,
Size::ZERO,
Expand All @@ -429,7 +440,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
.then(|| self.type_ptr_to(leaf_ty))?;

let leaf_ptr = if indices.is_empty() {
assert_ty_eq!(self, ptr.ty, leaf_ptr_ty);
// Compare pointee types instead of pointer types as storage class might be different.
assert_ty_eq!(self, pointee_ty, leaf_ty);
ptr
} else {
let indices = indices
Expand Down Expand Up @@ -586,7 +598,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
let ptr = ptr.strip_ptrcasts();
let ptr_id = ptr.def(self);
let original_pointee_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!("gep called on non-pointer type: {other:?}")),
};

Expand Down Expand Up @@ -1461,11 +1473,17 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
self.fatal("dynamic alloca not supported yet")
}

fn load(&mut self, ty: Self::Type, ptr: Self::Value, _align: Align) -> Self::Value {
fn load(&mut self, ty: Self::Type, ptr: Self::Value, align: Align) -> Self::Value {
let (ptr, access_ty) = self.adjust_pointer_for_typed_access(ptr, ty);
let loaded_val = ptr.const_fold_load(self).unwrap_or_else(|| {
self.emit()
.load(access_ty, None, ptr.def(self), None, empty())
.load(
access_ty,
None,
ptr.def(self),
Some(MemoryAccess::ALIGNED),
std::iter::once(Operand::LiteralBit32(align.bytes() as _)),
)
.unwrap()
.with_type(access_ty)
});
Expand Down Expand Up @@ -1587,12 +1605,17 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
// ignore
}

fn store(&mut self, val: Self::Value, ptr: Self::Value, _align: Align) -> Self::Value {
fn store(&mut self, val: Self::Value, ptr: Self::Value, align: Align) -> Self::Value {
let (ptr, access_ty) = self.adjust_pointer_for_typed_access(ptr, val.ty);
let val = self.bitcast(val, access_ty);

self.emit()
.store(ptr.def(self), val.def(self), None, empty())
.store(
ptr.def(self),
val.def(self),
Some(MemoryAccess::ALIGNED),
std::iter::once(Operand::LiteralBit32(align.bytes() as _)),
)
.unwrap();
// FIXME(eddyb) this is meant to be a handle the store instruction itself.
val
Expand Down Expand Up @@ -1750,20 +1773,23 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}

fn inttoptr(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
match self.lookup_type(dest_ty) {
SpirvType::Pointer { .. } => (),
let result_ty = match self.lookup_type(dest_ty) {
SpirvType::Pointer { pointee, .. } => self.type_ptr_to_with_storage_class(
pointee,
StorageClassKind::Explicit(StorageClass::PhysicalStorageBuffer),
),
other => self.fatal(format!(
"inttoptr called on non-pointer dest type: {other:?}"
)),
}
if val.ty == dest_ty {
};
if val.ty == result_ty {
val
} else {
let result = self
.emit()
.convert_u_to_ptr(dest_ty, None, val.def(self))
.convert_u_to_ptr(result_ty, None, val.def(self))
.unwrap()
.with_type(dest_ty);
.with_type(result_ty);
self.zombie_convert_u_to_ptr(result.def(self));
result
}
Expand Down Expand Up @@ -1926,6 +1952,25 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
return ptr;
}

// No cast is needed if only the storage class mismatches.
let ptr_pointee = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!(
"pointercast called on non-pointer source type: {other:?}"
)),
};
let dest_pointee = match self.lookup_type(dest_ty) {
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!(
"pointercast called on non-pointer dest type: {other:?}"
)),
};

// FIXME(jwollen) Do we need to choose `dest_ty` if it has a fixed storage class and `ptr` has none?
if ptr_pointee == dest_pointee {
return ptr;
}

// Strip a previous `pointercast`, to reveal the original pointer type.
let ptr = ptr.strip_ptrcasts();

Expand All @@ -1934,17 +1979,16 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}

let ptr_pointee = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => self.fatal(format!(
"pointercast called on non-pointer source type: {other:?}"
)),
};
let dest_pointee = match self.lookup_type(dest_ty) {
SpirvType::Pointer { pointee } => pointee,
other => self.fatal(format!(
"pointercast called on non-pointer dest type: {other:?}"
)),
};

if ptr_pointee == dest_pointee {
return ptr;
}

let dest_pointee_size = self.lookup_type(dest_pointee).sizeof(self);

if let Some((indices, _)) = self.recover_access_chain_from_offset(
Expand Down Expand Up @@ -2229,9 +2273,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
fn memcpy(
&mut self,
dst: Self::Value,
_dst_align: Align,
dst_align: Align,
src: Self::Value,
_src_align: Align,
src_align: Align,
size: Self::Value,
flags: MemFlags,
) {
Expand Down Expand Up @@ -2269,12 +2313,29 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}
});

// Pass all operands as `additional_params` since rspirv doesn't allow specifying
// extra operands ofter the first `MemoryAccess`
let mut ops: SmallVec<[_; 4]> = Default::default();
ops.push(Operand::MemoryAccess(MemoryAccess::ALIGNED));
if src_align != dst_align {
if self.emit().version().unwrap() > (1, 3) {
ops.push(Operand::LiteralBit32(dst_align.bytes() as _));
ops.push(Operand::MemoryAccess(MemoryAccess::ALIGNED));
ops.push(Operand::LiteralBit32(src_align.bytes() as _));
} else {
let align = dst_align.min(src_align);
ops.push(Operand::LiteralBit32(align.bytes() as _));
}
} else {
ops.push(Operand::LiteralBit32(dst_align.bytes() as _));
}

if let Some((dst, src)) = typed_copy_dst_src {
if let Some(const_value) = src.const_fold_load(self) {
self.store(const_value, dst, Align::from_bytes(0).unwrap());
} else {
self.emit()
.copy_memory(dst.def(self), src.def(self), None, None, empty())
.copy_memory(dst.def(self), src.def(self), None, None, ops)
.unwrap();
}
} else {
Expand All @@ -2285,7 +2346,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
size.def(self),
None,
None,
empty(),
ops,
)
.unwrap();
self.zombie(dst.def(self), "cannot memcpy dynamically sized data");
Expand Down Expand Up @@ -2324,7 +2385,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
.and_then(|size| Some(Size::from_bytes(u64::try_from(size).ok()?)));

let elem_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
_ => self.fatal(format!(
"memset called on non-pointer type: {}",
self.debug_type(ptr.ty)
Expand Down Expand Up @@ -2696,7 +2757,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
(callee.def(self), return_type, arguments)
}

SpirvType::Pointer { pointee } => match self.lookup_type(pointee) {
SpirvType::Pointer { pointee, .. } => match self.lookup_type(pointee) {
SpirvType::Function {
return_type,
arguments,
Expand Down
22 changes: 19 additions & 3 deletions crates/rustc_codegen_spirv/src/builder/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
use crate::abi::ConvSpirvType;
use crate::builder_spirv::{BuilderCursor, SpirvValue, SpirvValueExt};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use rspirv::spirv::Word;
use crate::spirv_type::{SpirvType, StorageClassKind};
use rspirv::spirv::{StorageClass, Word};
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::mir::place::PlaceRef;
use rustc_codegen_ssa::traits::{
Expand Down Expand Up @@ -104,7 +104,23 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {

// HACK(eddyb) like the `CodegenCx` method but with `self.span()` awareness.
pub fn type_ptr_to(&self, ty: Word) -> Word {
SpirvType::Pointer { pointee: ty }.def(self.span(), self)
SpirvType::Pointer {
pointee: ty,
storage_class: StorageClassKind::Inferred,
}
.def(self.span(), self)
}

pub fn type_ptr_to_with_storage_class(
&self,
ty: Word,
storage_class: StorageClassKind,
) -> Word {
SpirvType::Pointer {
pointee: ty,
storage_class,
}
.def(self.span(), self)
}

// TODO: Definitely add tests to make sure this impl is right.
Expand Down
24 changes: 10 additions & 14 deletions crates/rustc_codegen_spirv/src/builder/spirv_asm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::maybe_pqp_cg_ssa as rustc_codegen_ssa;
use super::Builder;
use crate::builder_spirv::{BuilderCursor, SpirvValue};
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use crate::spirv_type::{SpirvType, StorageClassKind};
use rspirv::dr;
use rspirv::grammar::{LogicalOperand, OperandKind, OperandQuantifier, reflect};
use rspirv::spirv::{
Expand Down Expand Up @@ -307,19 +307,14 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
}
.def(self.span(), self),
Op::TypePointer => {
let storage_class = inst.operands[0].unwrap_storage_class();
if storage_class != StorageClass::Generic {
self.struct_err("TypePointer in asm! requires `Generic` storage class")
.with_note(format!(
"`{storage_class:?}` storage class was specified"
))
.with_help(format!(
"the storage class will be inferred automatically (e.g. to `{storage_class:?}`)"
))
.emit();
}
// The storage class can be specified explicitly or inferred later by using StorageClass::Generic.
let storage_class = match inst.operands[0].unwrap_storage_class() {
StorageClass::Generic => StorageClassKind::Inferred,
storage_class => StorageClassKind::Explicit(storage_class),
};
SpirvType::Pointer {
pointee: inst.operands[1].unwrap_id_ref(),
storage_class,
}
.def(self.span(), self)
}
Expand Down Expand Up @@ -678,6 +673,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {

TyPat::Pointer(_, pat) => SpirvType::Pointer {
pointee: subst_ty_pat(cx, pat, ty_vars, leftover_operands)?,
storage_class: StorageClassKind::Inferred,
}
.def(DUMMY_SP, cx),

Expand Down Expand Up @@ -931,7 +927,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
Some(match kind {
TypeofKind::Plain => ty,
TypeofKind::Dereference => match self.lookup_type(ty) {
SpirvType::Pointer { pointee } => pointee,
SpirvType::Pointer { pointee, .. } => pointee,
other => {
self.tcx.dcx().span_err(
span,
Expand All @@ -953,7 +949,7 @@ impl<'cx, 'tcx> Builder<'cx, 'tcx> {
self.check_reg(span, reg);
if let Some(place) = place {
match self.lookup_type(place.val.llval.ty) {
SpirvType::Pointer { pointee } => Some(pointee),
SpirvType::Pointer { pointee, .. } => Some(pointee),
other => {
self.tcx.dcx().span_err(
span,
Expand Down
Loading
Loading