Skip to content

#[spirv(block)] for Block decorations on structs. #295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 1, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 41 additions & 2 deletions crates/rustc_codegen_spirv/src/abi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ impl<'tcx> ConvSpirvType<'tcx> for CastTarget {
field_types: args,
field_offsets,
field_names: None,
is_block: false,
}
.def(span, cx)
}
Expand Down Expand Up @@ -324,6 +325,35 @@ fn trans_type_impl<'tcx>(
) -> Word {
if let TyKind::Adt(adt, _) = *ty.ty.kind() {
for attr in parse_attrs(cx, cx.tcx.get_attrs(adt.did)) {
if matches!(attr, SpirvAttribute::Block) {
if !adt.is_struct() {
cx.tcx.sess.span_err(
span,
&format!(
"`#[spirv(block)]` can only be used on a `struct`, \
but `{}` is a `{}`",
ty.ty,
adt.descr(),
),
);
}

if !matches!(ty.abi, Abi::Aggregate { sized: true }) {
cx.tcx.sess.span_err(
span,
&format!(
"`#[spirv(block)]` can only be used for `Sized` aggregates, \
but `{}` has `Abi::{:?}`",
ty.ty, ty.abi,
),
);
}

assert!(matches!(ty.fields, FieldsShape::Arbitrary { .. }));

return trans_struct(cx, span, ty, true);
}

if let Some(image) = trans_image(cx, span, ty, attr) {
return image;
}
Expand All @@ -340,6 +370,7 @@ fn trans_type_impl<'tcx>(
field_types: Vec::new(),
field_offsets: Vec::new(),
field_names: None,
is_block: false,
}
.def(span, cx),
Abi::Scalar(ref scalar) => trans_scalar(cx, span, ty, scalar, None, is_immediate),
Expand All @@ -359,6 +390,7 @@ fn trans_type_impl<'tcx>(
field_types: vec![one_spirv, two_spirv],
field_offsets: vec![one_offset, two_offset],
field_names: None,
is_block: false,
}
.def(span, cx)
}
Expand Down Expand Up @@ -618,6 +650,7 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>
field_types: Vec::new(),
field_offsets: Vec::new(),
field_names: None,
is_block: false,
}
.def(span, cx)
} else {
Expand All @@ -638,7 +671,7 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>
FieldsShape::Arbitrary {
offsets: _,
memory_index: _,
} => trans_struct(cx, span, ty),
} => trans_struct(cx, span, ty, false),
}
}

Expand Down Expand Up @@ -668,7 +701,12 @@ pub fn auto_struct_layout<'tcx>(
}

// see struct_llfields in librustc_codegen_llvm for implementation hints
fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
fn trans_struct<'tcx>(
cx: &CodegenCx<'tcx>,
span: Span,
ty: TyAndLayout<'tcx>,
is_block: bool,
) -> Word {
let name = name_of_struct(ty);
if let TyKind::Foreign(_) = ty.ty.kind() {
// "An unsized FFI type that is opaque to Rust", `extern type A;` (currently unstable)
Expand Down Expand Up @@ -718,6 +756,7 @@ fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -
field_types,
field_offsets,
field_names: Some(field_names),
is_block,
}
.def(span, cx)
}
Expand Down
1 change: 1 addition & 0 deletions crates/rustc_codegen_spirv/src/codegen_cx/constant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
field_types,
field_offsets,
field_names: None,
is_block: false,
}
.def(DUMMY_SP, self);
self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)).collect())
Expand Down
1 change: 1 addition & 0 deletions crates/rustc_codegen_spirv/src/codegen_cx/type_.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> {
field_types: els.to_vec(),
field_offsets,
field_names: None,
is_block: false,
}
.def(DUMMY_SP, self)
}
Expand Down
1 change: 1 addition & 0 deletions crates/rustc_codegen_spirv/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ use rustc_session::Session;
use rustc_span::symbol::{sym, Symbol};
use rustc_target::spec::abi::Abi;
use rustc_target::spec::{LinkerFlavor, PanicStrategy, Target, TargetOptions, TargetTriple};
pub use spirv_tools;
use std::any::Any;
use std::env;
use std::fs::{create_dir_all, File};
Expand Down
15 changes: 13 additions & 2 deletions crates/rustc_codegen_spirv/src/spirv_type.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use rustc_span::Span;
use rustc_target::abi::{Align, Size};
use std::cell::RefCell;
use std::fmt;
use std::iter::once;
use std::iter;
use std::lazy::SyncLazy;
use std::sync::Mutex;

Expand All @@ -34,6 +34,7 @@ pub enum SpirvType {
field_types: Vec<Word>,
field_offsets: Vec<Size>,
field_names: Option<Vec<String>>,
is_block: bool,
},
Opaque {
name: String,
Expand Down Expand Up @@ -126,6 +127,7 @@ impl SpirvType {
ref field_types,
ref field_offsets,
ref field_names,
is_block,
} => {
let mut emit = cx.emit_global();
// Ensure a unique struct is emitted each time, due to possibly having different OpMemberDecorates
Expand All @@ -151,6 +153,9 @@ impl SpirvType {
emit.member_name(result, index as u32, field_name);
}
}
if is_block {
emit.decorate(result, Decoration::Block, iter::empty());
}
result
}
Self::Opaque { ref name } => cx.emit_global().type_opaque(name),
Expand All @@ -168,7 +173,7 @@ impl SpirvType {
cx.emit_global().decorate(
result,
Decoration::ArrayStride,
once(Operand::LiteralInt32(element_size as u32)),
iter::once(Operand::LiteralInt32(element_size as u32)),
);
}
result
Expand Down Expand Up @@ -344,6 +349,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
ref field_types,
ref field_offsets,
ref field_names,
is_block,
} => {
let fields = field_types
.iter()
Expand All @@ -357,6 +363,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
.field("field_types", &fields)
.field("field_offsets", field_offsets)
.field("field_names", field_names)
.field("is_block", &is_block)
.finish()
}
SpirvType::Opaque { ref name } => f
Expand Down Expand Up @@ -485,7 +492,11 @@ impl SpirvTypePrinter<'_, '_> {
ref field_types,
field_offsets: _,
ref field_names,
is_block,
} => {
if is_block {
write!(f, "#[spirv(block)] ")?;
}
write!(f, "struct {} {{ ", name)?;
for (index, &field) in field_types.iter().enumerate() {
let suffix = if index + 1 == field_types.len() {
Expand Down
2 changes: 2 additions & 0 deletions crates/rustc_codegen_spirv/src/symbols.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,7 @@ impl Symbols {
SpirvAttribute::ReallyUnsafeIgnoreBitcasts,
),
("sampler", SpirvAttribute::Sampler),
("block", SpirvAttribute::Block),
]
.iter()
.cloned();
Expand Down Expand Up @@ -451,6 +452,7 @@ pub enum SpirvAttribute {
access_qualifier: Option<AccessQualifier>,
},
Sampler,
Block,
}

// Note that we could mark the attr as used via cx.tcx.sess.mark_attr_used(attr), but unused
Expand Down
51 changes: 50 additions & 1 deletion crates/spirv-builder/src/test/basic.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::{dis_fn, val};
use super::{dis_fn, val, val_vulkan};

#[test]
fn hello_world() {
Expand Down Expand Up @@ -126,3 +126,52 @@ pub fn main() {
}
"#);
}

// NOTE(eddyb) this won't pass Vulkan validation (see `push_constant_vulkan`),
// but should still pass the basline SPIR-V validation.
#[test]
fn push_constant() {
val(r#"
#[derive(Copy, Clone)]
pub struct ShaderConstants {
pub width: u32,
pub height: u32,
pub time: f32,
}

#[allow(unused_attributes)]
#[spirv(fragment)]
pub fn main(
#[spirv(push_constant)] constants: PushConstant<ShaderConstants>,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe #[spirv(push_constant)] does nothing and was erroneously included in the example shader, which I'm assuming you copied this from. Doesn't really matter though.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh haha whoops.

) {
let _constants = constants.load();
}
"#);
}

// NOTE(eddyb) we specifically run Vulkan validation here, as the default
// validation rules are more lax and don't require a `Block` decoration
// (`#[spirv(block)]` here) on `struct ShaderConstants`.
#[test]
fn push_constant_vulkan() {
val_vulkan(
r#"
#[derive(Copy, Clone)]
#[allow(unused_attributes)]
#[spirv(block)]
pub struct ShaderConstants {
pub width: u32,
pub height: u32,
pub time: f32,
}

#[allow(unused_attributes)]
#[spirv(fragment)]
pub fn main(
#[spirv(push_constant)] constants: PushConstant<ShaderConstants>,
) {
let _constants = constants.load();
}
"#,
);
}
19 changes: 19 additions & 0 deletions crates/spirv-builder/src/test/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,25 @@ fn val(src: &str) {
build(src);
}

/// While `val` runs baseline SPIR-V validation, for some tests we want the
/// stricter Vulkan validation (`vulkan1.2` specifically), which may produce
/// additional errors (such as missing Vulkan-specific decorations).
fn val_vulkan(src: &str) {
use rustc_codegen_spirv::spirv_tools::{
util::to_binary,
val::{self, Validator},
TargetEnv,
};

let validator = val::create(Some(TargetEnv::Vulkan_1_2));

let _lock = global_lock();
let bytes = std::fs::read(build(src)).unwrap();
if let Err(e) = validator.validate(to_binary(&bytes).unwrap(), None) {
panic!("Vulkan validation failed:\n{}", e.to_string());
}
}

fn assert_str_eq(expected: &str, result: &str) {
let expected = expected
.split('\n')
Expand Down
2 changes: 2 additions & 0 deletions examples/shaders/shared/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ use spirv_std::glam::Vec3;
use spirv_std::num_traits::Float;

#[derive(Copy, Clone)]
#[allow(unused_attributes)]
#[spirv(block)]
pub struct ShaderConstants {
pub width: u32,
pub height: u32,
Expand Down