Skip to content

Commit ae42a91

Browse files
authored
#[spirv(block)] for Block decorations on structs. (#295)
1 parent 340dfc4 commit ae42a91

File tree

9 files changed

+130
-5
lines changed

9 files changed

+130
-5
lines changed

crates/rustc_codegen_spirv/src/abi.rs

Lines changed: 41 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ impl<'tcx> ConvSpirvType<'tcx> for CastTarget {
229229
field_types: args,
230230
field_offsets,
231231
field_names: None,
232+
is_block: false,
232233
}
233234
.def(span, cx)
234235
}
@@ -324,6 +325,35 @@ fn trans_type_impl<'tcx>(
324325
) -> Word {
325326
if let TyKind::Adt(adt, _) = *ty.ty.kind() {
326327
for attr in parse_attrs(cx, cx.tcx.get_attrs(adt.did)) {
328+
if matches!(attr, SpirvAttribute::Block) {
329+
if !adt.is_struct() {
330+
cx.tcx.sess.span_err(
331+
span,
332+
&format!(
333+
"`#[spirv(block)]` can only be used on a `struct`, \
334+
but `{}` is a `{}`",
335+
ty.ty,
336+
adt.descr(),
337+
),
338+
);
339+
}
340+
341+
if !matches!(ty.abi, Abi::Aggregate { sized: true }) {
342+
cx.tcx.sess.span_err(
343+
span,
344+
&format!(
345+
"`#[spirv(block)]` can only be used for `Sized` aggregates, \
346+
but `{}` has `Abi::{:?}`",
347+
ty.ty, ty.abi,
348+
),
349+
);
350+
}
351+
352+
assert!(matches!(ty.fields, FieldsShape::Arbitrary { .. }));
353+
354+
return trans_struct(cx, span, ty, true);
355+
}
356+
327357
if let Some(image) = trans_image(cx, span, ty, attr) {
328358
return image;
329359
}
@@ -340,6 +370,7 @@ fn trans_type_impl<'tcx>(
340370
field_types: Vec::new(),
341371
field_offsets: Vec::new(),
342372
field_names: None,
373+
is_block: false,
343374
}
344375
.def(span, cx),
345376
Abi::Scalar(ref scalar) => trans_scalar(cx, span, ty, scalar, None, is_immediate),
@@ -359,6 +390,7 @@ fn trans_type_impl<'tcx>(
359390
field_types: vec![one_spirv, two_spirv],
360391
field_offsets: vec![one_offset, two_offset],
361392
field_names: None,
393+
is_block: false,
362394
}
363395
.def(span, cx)
364396
}
@@ -618,6 +650,7 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>
618650
field_types: Vec::new(),
619651
field_offsets: Vec::new(),
620652
field_names: None,
653+
is_block: false,
621654
}
622655
.def(span, cx)
623656
} else {
@@ -638,7 +671,7 @@ fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>
638671
FieldsShape::Arbitrary {
639672
offsets: _,
640673
memory_index: _,
641-
} => trans_struct(cx, span, ty),
674+
} => trans_struct(cx, span, ty, false),
642675
}
643676
}
644677

@@ -668,7 +701,12 @@ pub fn auto_struct_layout<'tcx>(
668701
}
669702

670703
// see struct_llfields in librustc_codegen_llvm for implementation hints
671-
fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -> Word {
704+
fn trans_struct<'tcx>(
705+
cx: &CodegenCx<'tcx>,
706+
span: Span,
707+
ty: TyAndLayout<'tcx>,
708+
is_block: bool,
709+
) -> Word {
672710
let name = name_of_struct(ty);
673711
if let TyKind::Foreign(_) = ty.ty.kind() {
674712
// "An unsized FFI type that is opaque to Rust", `extern type A;` (currently unstable)
@@ -718,6 +756,7 @@ fn trans_struct<'tcx>(cx: &CodegenCx<'tcx>, span: Span, ty: TyAndLayout<'tcx>) -
718756
field_types,
719757
field_offsets,
720758
field_names: Some(field_names),
759+
is_block,
721760
}
722761
.def(span, cx)
723762
}

crates/rustc_codegen_spirv/src/codegen_cx/constant.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
187187
field_types,
188188
field_offsets,
189189
field_names: None,
190+
is_block: false,
190191
}
191192
.def(DUMMY_SP, self);
192193
self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)).collect())

crates/rustc_codegen_spirv/src/codegen_cx/type_.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,7 @@ impl<'tcx> BaseTypeMethods<'tcx> for CodegenCx<'tcx> {
150150
field_types: els.to_vec(),
151151
field_offsets,
152152
field_names: None,
153+
is_block: false,
153154
}
154155
.def(DUMMY_SP, self)
155156
}

crates/rustc_codegen_spirv/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,7 @@ use rustc_session::Session;
118118
use rustc_span::symbol::{sym, Symbol};
119119
use rustc_target::spec::abi::Abi;
120120
use rustc_target::spec::{LinkerFlavor, PanicStrategy, Target, TargetOptions, TargetTriple};
121+
pub use spirv_tools;
121122
use std::any::Any;
122123
use std::env;
123124
use std::fs::{create_dir_all, File};

crates/rustc_codegen_spirv/src/spirv_type.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use rustc_span::Span;
1010
use rustc_target::abi::{Align, Size};
1111
use std::cell::RefCell;
1212
use std::fmt;
13-
use std::iter::once;
13+
use std::iter;
1414
use std::lazy::SyncLazy;
1515
use std::sync::Mutex;
1616

@@ -34,6 +34,7 @@ pub enum SpirvType {
3434
field_types: Vec<Word>,
3535
field_offsets: Vec<Size>,
3636
field_names: Option<Vec<String>>,
37+
is_block: bool,
3738
},
3839
Opaque {
3940
name: String,
@@ -126,6 +127,7 @@ impl SpirvType {
126127
ref field_types,
127128
ref field_offsets,
128129
ref field_names,
130+
is_block,
129131
} => {
130132
let mut emit = cx.emit_global();
131133
// Ensure a unique struct is emitted each time, due to possibly having different OpMemberDecorates
@@ -151,6 +153,9 @@ impl SpirvType {
151153
emit.member_name(result, index as u32, field_name);
152154
}
153155
}
156+
if is_block {
157+
emit.decorate(result, Decoration::Block, iter::empty());
158+
}
154159
result
155160
}
156161
Self::Opaque { ref name } => cx.emit_global().type_opaque(name),
@@ -168,7 +173,7 @@ impl SpirvType {
168173
cx.emit_global().decorate(
169174
result,
170175
Decoration::ArrayStride,
171-
once(Operand::LiteralInt32(element_size as u32)),
176+
iter::once(Operand::LiteralInt32(element_size as u32)),
172177
);
173178
}
174179
result
@@ -344,6 +349,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
344349
ref field_types,
345350
ref field_offsets,
346351
ref field_names,
352+
is_block,
347353
} => {
348354
let fields = field_types
349355
.iter()
@@ -357,6 +363,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
357363
.field("field_types", &fields)
358364
.field("field_offsets", field_offsets)
359365
.field("field_names", field_names)
366+
.field("is_block", &is_block)
360367
.finish()
361368
}
362369
SpirvType::Opaque { ref name } => f
@@ -485,7 +492,11 @@ impl SpirvTypePrinter<'_, '_> {
485492
ref field_types,
486493
field_offsets: _,
487494
ref field_names,
495+
is_block,
488496
} => {
497+
if is_block {
498+
write!(f, "#[spirv(block)] ")?;
499+
}
489500
write!(f, "struct {} {{ ", name)?;
490501
for (index, &field) in field_types.iter().enumerate() {
491502
let suffix = if index + 1 == field_types.len() {

crates/rustc_codegen_spirv/src/symbols.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -336,6 +336,7 @@ impl Symbols {
336336
SpirvAttribute::ReallyUnsafeIgnoreBitcasts,
337337
),
338338
("sampler", SpirvAttribute::Sampler),
339+
("block", SpirvAttribute::Block),
339340
]
340341
.iter()
341342
.cloned();
@@ -451,6 +452,7 @@ pub enum SpirvAttribute {
451452
access_qualifier: Option<AccessQualifier>,
452453
},
453454
Sampler,
455+
Block,
454456
}
455457

456458
// Note that we could mark the attr as used via cx.tcx.sess.mark_attr_used(attr), but unused

crates/spirv-builder/src/test/basic.rs

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use super::{dis_fn, val};
1+
use super::{dis_fn, val, val_vulkan};
22

33
#[test]
44
fn hello_world() {
@@ -126,3 +126,52 @@ pub fn main() {
126126
}
127127
"#);
128128
}
129+
130+
// NOTE(eddyb) this won't pass Vulkan validation (see `push_constant_vulkan`),
131+
// but should still pass the basline SPIR-V validation.
132+
#[test]
133+
fn push_constant() {
134+
val(r#"
135+
#[derive(Copy, Clone)]
136+
pub struct ShaderConstants {
137+
pub width: u32,
138+
pub height: u32,
139+
pub time: f32,
140+
}
141+
142+
#[allow(unused_attributes)]
143+
#[spirv(fragment)]
144+
pub fn main(
145+
#[spirv(push_constant)] constants: PushConstant<ShaderConstants>,
146+
) {
147+
let _constants = constants.load();
148+
}
149+
"#);
150+
}
151+
152+
// NOTE(eddyb) we specifically run Vulkan validation here, as the default
153+
// validation rules are more lax and don't require a `Block` decoration
154+
// (`#[spirv(block)]` here) on `struct ShaderConstants`.
155+
#[test]
156+
fn push_constant_vulkan() {
157+
val_vulkan(
158+
r#"
159+
#[derive(Copy, Clone)]
160+
#[allow(unused_attributes)]
161+
#[spirv(block)]
162+
pub struct ShaderConstants {
163+
pub width: u32,
164+
pub height: u32,
165+
pub time: f32,
166+
}
167+
168+
#[allow(unused_attributes)]
169+
#[spirv(fragment)]
170+
pub fn main(
171+
#[spirv(push_constant)] constants: PushConstant<ShaderConstants>,
172+
) {
173+
let _constants = constants.load();
174+
}
175+
"#,
176+
);
177+
}

crates/spirv-builder/src/test/mod.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,25 @@ fn val(src: &str) {
9090
build(src);
9191
}
9292

93+
/// While `val` runs baseline SPIR-V validation, for some tests we want the
94+
/// stricter Vulkan validation (`vulkan1.2` specifically), which may produce
95+
/// additional errors (such as missing Vulkan-specific decorations).
96+
fn val_vulkan(src: &str) {
97+
use rustc_codegen_spirv::spirv_tools::{
98+
util::to_binary,
99+
val::{self, Validator},
100+
TargetEnv,
101+
};
102+
103+
let validator = val::create(Some(TargetEnv::Vulkan_1_2));
104+
105+
let _lock = global_lock();
106+
let bytes = std::fs::read(build(src)).unwrap();
107+
if let Err(e) = validator.validate(to_binary(&bytes).unwrap(), None) {
108+
panic!("Vulkan validation failed:\n{}", e.to_string());
109+
}
110+
}
111+
93112
fn assert_str_eq(expected: &str, result: &str) {
94113
let expected = expected
95114
.split('\n')

examples/shaders/shared/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ use spirv_std::glam::Vec3;
1414
use spirv_std::num_traits::Float;
1515

1616
#[derive(Copy, Clone)]
17+
#[allow(unused_attributes)]
18+
#[spirv(block)]
1719
pub struct ShaderConstants {
1820
pub width: u32,
1921
pub height: u32,

0 commit comments

Comments
 (0)