Skip to content

Blocks (StorageBuffer and PushConstant) #289

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
20 changes: 20 additions & 0 deletions crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ impl<'tcx> ConvSpirvType<'tcx> for CastTarget {
field_types: args,
field_offsets,
field_names: None,
is_block: false,
}
.def(span, cx)
}
Expand Down Expand Up @@ -340,6 +341,7 @@ fn trans_type_impl<'tcx>(
field_types: Vec::new(),
field_offsets: Vec::new(),
field_names: None,
is_block: false,
}
.def(span, cx),
Abi::Scalar(ref scalar) => trans_scalar(cx, span, ty, scalar, None, is_immediate),
Expand All @@ -359,6 +361,7 @@ fn trans_type_impl<'tcx>(
field_types: vec![one_spirv, two_spirv],
field_offsets: vec![one_offset, two_offset],
field_names: None,
is_block: false,
}
.def(span, cx)
}
Expand Down Expand Up @@ -582,6 +585,20 @@ fn get_storage_class<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Optio
None
}

/// Handles `#[spirv(block)]`. Note this is only called in the scalar translation code, because this is only
/// used for spooky builtin stuff, and we pinky promise to never have more than one pointer field in one of these.
// TODO: Enforce this is only used in spirv-std.
Comment on lines +588 to +590
Copy link
Contributor

@khyperia khyperia Dec 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This comment isn't correct, right? It's the opposite of what's actually happening, it's only called in the ADT translation code. I believe this should be called at the entrypoint of trans_type_impl, in the same for loop that trans_image is called, and error when applied to a struct that results in a scalar or otherwise non-struct representation.

fn get_is_block_decorated<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> bool {
if let TyKind::Adt(adt, _substs) = ty.ty.kind() {
for attr in parse_attrs(cx, cx.tcx.get_attrs(adt.did)) {
if let SpirvAttribute::Block = attr {
return true;
}
}
}
false
}

fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
match ty.fields {
FieldsShape::Primitive => cx.tcx.sess.fatal(&format!(
Expand Down Expand Up @@ -618,6 +635,7 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>
field_types: Vec::new(),
field_offsets: Vec::new(),
field_names: None,
is_block: false,
}
.def(span, cx)
} else {
Expand Down Expand Up @@ -711,13 +729,15 @@ fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -
}
};
}
let is_block = get_is_block_decorated(cx, ty);
SpirvType::Adt {
name,
size,
align,
field_types,
field_offsets,
field_names: Some(field_names),
is_block,
}
.def(span, cx)
}
Expand Down
21 changes: 19 additions & 2 deletions crates/rustc_codegen_spirv/src/builder/builder_methods.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1225,8 +1225,11 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}

fn pointercast(&mut self, val: Self::Value, dest_ty: Self::Type) -> Self::Value {
let val_pointee = match self.lookup_type(val.ty) {
SpirvType::Pointer { pointee, .. } => pointee,
let (storage_class, val_pointee) = match self.lookup_type(val.ty) {
SpirvType::Pointer {
storage_class,
pointee,
} => (storage_class, pointee),
other => self.fatal(&format!(
"pointercast called on non-pointer source type: {:?}",
other
Expand All @@ -1242,6 +1245,20 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
if val.ty == dest_ty {
val
} else if let Some(indices) = self.try_pointercast_via_gep(val_pointee, dest_pointee) {
let dest_ty = if self
.really_unsafe_ignore_bitcasts
.borrow()
.contains(&self.current_fn)
{
SpirvType::Pointer {
storage_class,
pointee: dest_pointee,
}
// TODO: Get actual span here
.def(Span::default(), self)
} else {
dest_ty
};
let indices = indices
.into_iter()
.map(|idx| self.constant_u32(self.span(), idx).def(self))
Expand Down
1 change: 1 addition & 0 deletions crates/rustc_codegen_spirv/src/codegen_cx/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
field_types,
field_offsets,
field_names: None,
is_block: false,
}
.def(DUMMY_SP, self);
self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)).collect())
Expand Down
1 change: 1 addition & 0 deletions crates/rustc_codegen_spirv/src/codegen_cx/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> {
field_types: els.to_vec(),
field_offsets,
field_names: None,
is_block: false,
}
.def(DUMMY_SP, self)
}
Expand Down
8 changes: 8 additions & 0 deletions crates/rustc_codegen_spirv/src/spirv_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ pub enum SpirvType {
field_types: Vec<Word>,
field_offsets: Vec<Size>,
field_names: Option<Vec<String>>,
is_block: bool,
},
Opaque {
name: String,
Expand Down Expand Up @@ -126,6 +127,7 @@ impl SpirvType {
ref field_types,
ref field_offsets,
ref field_names,
is_block,
} => {
let mut emit = cx.emit_global();
// Ensure a unique struct is emitted each time, due to possibly having different OpMemberDecorates
Expand All @@ -146,6 +148,9 @@ impl SpirvType {
);
}
}
if is_block {
emit.decorate(id, Decoration::Block, None);
}
if let Some(field_names) = field_names {
for (index, field_name) in field_names.iter().enumerate() {
emit.member_name(result, index as u32, field_name);
Expand Down Expand Up @@ -344,6 +349,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
ref field_types,
ref field_offsets,
ref field_names,
is_block,
} => {
let fields = field_types
.iter()
Expand All @@ -357,6 +363,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
.field("field_types", &fields)
.field("field_offsets", field_offsets)
.field("field_names", field_names)
.field("is_block", &is_block)
.finish()
}
SpirvType::Opaque { ref name } => f
Expand Down Expand Up @@ -485,6 +492,7 @@ impl SpirvTypePrinter<'_, '_> {
ref field_types,
field_offsets: _,
ref field_names,
is_block: _,
} => {
write!(f, "struct {} {{ ", name)?;
for (index, &field) in field_types.iter().enumerate() {
Expand Down
2 changes: 2 additions & 0 deletions crates/rustc_codegen_spirv/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ impl Symbols {
SpirvAttribute::ReallyUnsafeIgnoreBitcasts,
),
("sampler", SpirvAttribute::Sampler),
("block", SpirvAttribute::Block),
]
.iter()
.cloned();
Expand Down Expand Up @@ -437,6 +438,7 @@ impl From<ExecutionModel> for Entry {
pub enum SpirvAttribute {
Builtin(BuiltIn),
StorageClass(StorageClass),
Block,
Entry(Entry),
DescriptorSet(u32),
Binding(u32),
Expand Down
52 changes: 50 additions & 2 deletions crates/spirv-std/src/storage_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,54 @@ macro_rules! storage_class {
}
};

// Interior Block
Copy link
Contributor

@khyperia khyperia Dec 1, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While the #[spirv(block)] attribute matches our plans in this area for a hack quick fix, we were planning on having the user apply the attribute themselves to types used in block contexts. As you pointed out, there are many problems with this type-wrapping approach, surfacing most obviously with the storage class mismatch issues you're hitting.

(Do keep in mind that this is intended to be a quick hack fix before we're able to fully implement the bindings RFC, as that will replace this whole system with something much more elegant)

($(#[$($meta:meta)+])* block $block:ident storage_class $name:ident ; $($tt:tt)*) => {

#[spirv(block)]
#[allow(unused_attributes)]
pub struct $block <T> {
value: T
}

$(#[$($meta)+])*
#[allow(unused_attributes)]
pub struct $name<'block, T> {
block: &'block mut $block <T>,
}

impl<T: Copy> $name<'_, T> {
/// Load the value into memory.
#[inline]
#[allow(unused_attributes)]
#[spirv(really_unsafe_ignore_bitcasts)]
pub fn load(&self) -> T {
self.block.value
}
}

storage_class!($($tt)*);
};

// Methods available on writeable storage classes.
($(#[$($meta:meta)+])* writeable block $block:ident storage_class $name:ident $($tt:tt)+) => {
storage_class!($(#[$($meta)+])* block $block storage_class $name $($tt)+);

impl <T: Copy> $name<'_, T> {
/// Store the value in storage.
#[inline]
#[allow(unused_attributes)]
#[spirv(really_unsafe_ignore_bitcasts)]
pub fn store(&mut self, v: T) {
self.block.value = v;
}

/// A convenience function to load a value into memory and store it.
pub fn then(&mut self, then: impl FnOnce(T) -> T) {
self.store((then)(self.load()));
}
}
};

(;) => {};
() => {};
}
Expand Down Expand Up @@ -112,7 +160,7 @@ storage_class! {
/// Intended to contain a small bank of values pushed from the client API.
/// Variables declared with this storage class are read-only, and must not
/// have initializers.
#[spirv(push_constant)] storage_class PushConstant;
#[spirv(push_constant)] block PushConstantBlock storage_class PushConstant;

/// Atomic counter-specific memory.
///
Expand All @@ -131,7 +179,7 @@ storage_class! {
///
/// Shared externally, readable and writable, visible across all functions
/// in all invocations in all work groups.
#[spirv(storage_buffer)] writeable storage_class StorageBuffer;
#[spirv(storage_buffer)] writeable block StorageBufferBlock storage_class StorageBuffer;

/// Used for storing arbitrary data associated with a ray to pass
/// to callables. (Requires `SPV_KHR_ray_tracing` extension)
Expand Down