Skip to content

Commit 929b227

Browse files
committed
builder_spirv: add SpirvConst::Scalar to replace SpirvConst float/integer variants.
1 parent 1ee244b commit 929b227

File tree

6 files changed

+210
-185
lines changed

6 files changed

+210
-185
lines changed

crates/rustc_codegen_spirv/src/builder/builder_methods.rs

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,7 @@ macro_rules! simple_op {
4949
let size = Size::from_bits(bits);
5050
let as_u128 = |const_val| {
5151
let x = match const_val {
52-
SpirvConst::U32(x) => x as u128,
53-
SpirvConst::U64(x) => x as u128,
52+
SpirvConst::Scalar(x) => x,
5453
_ => return None,
5554
};
5655
Some(if signed {
@@ -225,7 +224,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
225224
}
226225
SpirvType::Array { element, count } => {
227226
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
228-
let count = self.builder.lookup_const_u64(count).unwrap() as usize;
227+
let count = self.builder.lookup_const_scalar(count).unwrap() as usize;
229228
self.constant_composite(
230229
ty.def(self.span(), self),
231230
iter::repeat(elem_pat).take(count),
@@ -269,7 +268,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
269268
SpirvType::Adt { .. } => self.fatal("memset on structs not implemented yet"),
270269
SpirvType::Array { element, count } => {
271270
let elem_pat = self.memset_dynamic_pattern(&self.lookup_type(element), fill_var);
272-
let count = self.builder.lookup_const_u64(count).unwrap() as usize;
271+
let count = self.builder.lookup_const_scalar(count).unwrap() as usize;
273272
self.emit()
274273
.composite_construct(
275274
ty.def(self.span(), self),
@@ -327,7 +326,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
327326
.lookup_type(pat.ty)
328327
.sizeof(self)
329328
.expect("Unable to memset a dynamic sized object");
330-
let size_elem_const = self.constant_int(size_bytes.ty, size_elem.bytes());
329+
let size_elem_const = self.constant_int(size_bytes.ty, size_elem.bytes().into());
331330
let zero = self.constant_int(size_bytes.ty, 0);
332331
let one = self.constant_int(size_bytes.ty, 1);
333332
let zero_align = Align::from_bytes(0).unwrap();
@@ -595,8 +594,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
595594
// - dynamic indexing of a single array
596595
let const_ptr_offset = self
597596
.builder
598-
.lookup_const_u64(ptr_base_index)
599-
.and_then(|idx| Some(idx * self.lookup_type(ty).sizeof(self)?));
597+
.lookup_const_scalar(ptr_base_index)
598+
.and_then(|idx| Some(u64::try_from(idx).ok()? * self.lookup_type(ty).sizeof(self)?));
600599
if let Some(const_ptr_offset) = const_ptr_offset {
601600
if let Some((base_indices, base_pointee_ty)) = self.recover_access_chain_from_offset(
602601
original_pointee_ty,
@@ -707,7 +706,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
707706
let mut emit = self.emit();
708707

709708
let non_zero_ptr_base_index =
710-
ptr_base_index.filter(|&idx| self.builder.lookup_const_u64(idx) != Some(0));
709+
ptr_base_index.filter(|&idx| self.builder.lookup_const_scalar(idx) != Some(0));
711710
if let Some(ptr_base_index) = non_zero_ptr_base_index {
712711
let result = if is_inbounds {
713712
emit.in_bounds_ptr_access_chain(
@@ -1083,8 +1082,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
10831082
// HACK(eddyb) constant-fold branches early on, as the `core` library is
10841083
// starting to get a lot of `if cfg!(debug_assertions)` added to it.
10851084
match self.builder.lookup_const_by_id(cond) {
1086-
Some(SpirvConst::Bool(true)) => self.br(then_llbb),
1087-
Some(SpirvConst::Bool(false)) => self.br(else_llbb),
1085+
Some(SpirvConst::Scalar(1)) => self.br(then_llbb),
1086+
Some(SpirvConst::Scalar(0)) => self.br(else_llbb),
10881087
_ => {
10891088
self.emit()
10901089
.branch_conditional(cond, then_llbb, else_llbb, empty())
@@ -2232,7 +2231,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
22322231
"memcpy with mem flags is not supported yet: {flags:?}"
22332232
));
22342233
}
2235-
let const_size = self.builder.lookup_const_u64(size).map(Size::from_bytes);
2234+
let const_size = self
2235+
.builder
2236+
.lookup_const_scalar(size)
2237+
.and_then(|size| Some(Size::from_bytes(u64::try_from(size).ok()?)));
22362238
if const_size == Some(Size::ZERO) {
22372239
// Nothing to do!
22382240
return;
@@ -2306,6 +2308,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
23062308
"memset with mem flags is not supported yet: {flags:?}"
23072309
));
23082310
}
2311+
2312+
let const_size = self
2313+
.builder
2314+
.lookup_const_scalar(size)
2315+
.and_then(|size| Some(Size::from_bytes(u64::try_from(size).ok()?)));
2316+
23092317
let elem_ty = match self.lookup_type(ptr.ty) {
23102318
SpirvType::Pointer { pointee } => pointee,
23112319
_ => self.fatal(format!(
@@ -2314,13 +2322,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
23142322
)),
23152323
};
23162324
let elem_ty_spv = self.lookup_type(elem_ty);
2317-
let pat = match self.builder.lookup_const_u64(fill_byte) {
2325+
let pat = match self.builder.lookup_const_scalar(fill_byte) {
23182326
Some(fill_byte) => self.memset_const_pattern(&elem_ty_spv, fill_byte as u8),
23192327
None => self.memset_dynamic_pattern(&elem_ty_spv, fill_byte.def(self)),
23202328
}
23212329
.with_type(elem_ty);
2322-
match self.builder.lookup_const_u64(size) {
2323-
Some(size) => self.memset_constant_size(ptr, pat, size),
2330+
match const_size {
2331+
Some(size) => self.memset_constant_size(ptr, pat, size.bytes()),
23242332
None => self.memset_dynamic_size(ptr, pat, size),
23252333
}
23262334
}
@@ -2354,7 +2362,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
23542362
SpirvType::Vector { element, .. } => element,
23552363
other => self.fatal(format!("extract_element not implemented on type {other:?}")),
23562364
};
2357-
match self.builder.lookup_const_u64(idx) {
2365+
match self.builder.lookup_const_scalar(idx) {
23582366
Some(const_index) => self.emit().composite_extract(
23592367
result_type,
23602368
None,
@@ -2781,7 +2789,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
27812789
let mut decoded_format_args = DecodedFormatArgs::default();
27822790

27832791
let const_u32_as_usize = |ct_id| match self.builder.lookup_const_by_id(ct_id)? {
2784-
SpirvConst::U32(x) => Some(x as usize),
2792+
SpirvConst::Scalar(x) => Some(u32::try_from(x).ok()? as usize),
27852793
_ => None,
27862794
};
27872795
let const_slice_as_elem_ids = |slice_ptr_and_len_ids: &[Word]| {
@@ -2948,10 +2956,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
29482956
match (inst.class.opcode, inst.result_id, &id_operands[..]) {
29492957
(Op::Bitcast, Some(r), &[x]) => Inst::Bitcast(r, x),
29502958
(Op::InBoundsAccessChain, Some(r), &[p, i]) => {
2951-
if let Some(SpirvConst::U32(i)) =
2959+
if let Some(SpirvConst::Scalar(i)) =
29522960
self.builder.lookup_const_by_id(i)
29532961
{
2954-
Inst::InBoundsAccessChain(r, p, i)
2962+
Inst::InBoundsAccessChain(r, p, i as u32)
29552963
} else {
29562964
Inst::Unsupported(inst.class.opcode)
29572965
}

crates/rustc_codegen_spirv/src/builder/byte_addressable_buffer.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
111111
count,
112112
),
113113
SpirvType::Array { element, count } => {
114-
let count = match self.builder.lookup_const_u64(count) {
114+
let count = match self.builder.lookup_const_scalar(count) {
115115
Some(count) => count as u32,
116116
None => return self.load_err(original_type, result_type),
117117
};
@@ -301,7 +301,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
301301
count,
302302
),
303303
SpirvType::Array { element, count } => {
304-
let count = match self.builder.lookup_const_u64(count) {
304+
let count = match self.builder.lookup_const_scalar(count) {
305305
Some(count) => count as u32,
306306
None => return self.store_err(original_type, value),
307307
};

crates/rustc_codegen_spirv/src/builder/mod.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
113113
other.debug(shift.ty, self)
114114
)),
115115
};
116-
let int_size = self.constant_int(shift.ty, width as u64);
117-
let mask = self.constant_int(shift.ty, (width - 1) as u64);
116+
let int_size = self.constant_int(shift.ty, width.into());
117+
let mask = self.constant_int(shift.ty, (width - 1).into());
118118
let zero = self.constant_int(shift.ty, 0);
119119
let bool = SpirvType::Bool.def(self.span(), self);
120120
// https://stackoverflow.com/a/10134877

crates/rustc_codegen_spirv/src/builder_spirv.rs

Lines changed: 119 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ use crate::spirv_type::SpirvType;
44
use crate::symbols::Symbols;
55
use crate::target::SpirvTarget;
66
use crate::target_feature::TargetFeature;
7-
use rspirv::dr::{Block, Builder, Module, Operand};
7+
use rspirv::dr::{Block, Builder, Instruction, Module, Operand};
88
use rspirv::spirv::{
99
AddressingModel, Capability, MemoryModel, Op, SourceLanguage, StorageClass, Word,
1010
};
1111
use rspirv::{binary::Assemble, binary::Disassemble};
1212
use rustc_arena::DroplessArena;
13+
use rustc_codegen_ssa::traits::ConstMethods as _;
1314
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
1415
use rustc_data_structures::sync::Lrc;
1516
use rustc_middle::bug;
@@ -18,6 +19,7 @@ use rustc_middle::ty::TyCtxt;
1819
use rustc_span::source_map::SourceMap;
1920
use rustc_span::symbol::Symbol;
2021
use rustc_span::{FileName, FileNameDisplayPreference, SourceFile, Span, DUMMY_SP};
22+
use rustc_target::abi::Size;
2123
use std::assert_matches::assert_matches;
2224
use std::cell::{RefCell, RefMut};
2325
use std::hash::{Hash, Hasher};
@@ -221,13 +223,8 @@ impl SpirvValueExt for Word {
221223

222224
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
223225
pub enum SpirvConst<'a, 'tcx> {
224-
U32(u32),
225-
U64(u64),
226-
/// f32 isn't hash, so store bits
227-
F32(u32),
228-
/// f64 isn't hash, so store bits
229-
F64(u64),
230-
Bool(bool),
226+
/// Constants of boolean, integer or floating-point type (up to 128-bit).
227+
Scalar(u128),
231228

232229
Null,
233230
Undef,
@@ -273,11 +270,7 @@ impl<'tcx> SpirvConst<'_, 'tcx> {
273270

274271
match self {
275272
// FIXME(eddyb) these are all noop cases, could they be automated?
276-
SpirvConst::U32(v) => SpirvConst::U32(v),
277-
SpirvConst::U64(v) => SpirvConst::U64(v),
278-
SpirvConst::F32(v) => SpirvConst::F32(v),
279-
SpirvConst::F64(v) => SpirvConst::F64(v),
280-
SpirvConst::Bool(v) => SpirvConst::Bool(v),
273+
SpirvConst::Scalar(v) => SpirvConst::Scalar(v),
281274
SpirvConst::Null => SpirvConst::Null,
282275
SpirvConst::Undef => SpirvConst::Undef,
283276
SpirvConst::ZombieUndefForFnAddr => SpirvConst::ZombieUndefForFnAddr,
@@ -570,8 +563,26 @@ impl<'tcx> BuilderSpirv<'tcx> {
570563
val: SpirvConst<'_, 'tcx>,
571564
cx: &CodegenCx<'tcx>,
572565
) -> SpirvValue {
566+
let scalar_ty = match val {
567+
SpirvConst::Scalar(_) => Some(cx.lookup_type(ty)),
568+
_ => None,
569+
};
570+
571+
// HACK(eddyb) this is done so late (just before interning `val`) to
572+
// minimize any potential misuse from direct `def_constant` calls.
573+
let val = match (val, scalar_ty) {
574+
(SpirvConst::Scalar(val), Some(SpirvType::Integer(bits, signed))) => {
575+
let size = Size::from_bits(bits);
576+
SpirvConst::Scalar(if signed {
577+
size.sign_extend(val)
578+
} else {
579+
size.truncate(val)
580+
})
581+
}
582+
_ => val,
583+
};
584+
573585
let val_with_type = WithType { ty, val };
574-
let mut builder = self.builder(BuilderCursor::default());
575586
if let Some(entry) = self.const_to_id.borrow().get(&val_with_type) {
576587
// FIXME(eddyb) deduplicate this `if`-`else` and its other copies.
577588
let kind = if entry.legal.is_ok() {
@@ -582,16 +593,99 @@ impl<'tcx> BuilderSpirv<'tcx> {
582593
return SpirvValue { kind, ty };
583594
}
584595
let val = val_with_type.val;
596+
597+
// FIXME(eddyb) make this an extension method on `rspirv::dr::Builder`?
598+
let const_op = |builder: &mut Builder, op, lhs, maybe_rhs: Option<_>| {
599+
// HACK(eddyb) remove after `OpSpecConstantOp` support gets added to SPIR-T.
600+
let spirt_has_const_op = false;
601+
602+
if !spirt_has_const_op {
603+
let zombie = builder.undef(ty, None);
604+
cx.zombie_with_span(
605+
zombie,
606+
DUMMY_SP,
607+
&format!("unsupported constant of type `{}`", cx.debug_type(ty)),
608+
);
609+
return zombie;
610+
}
611+
612+
let id = builder.id();
613+
builder
614+
.module_mut()
615+
.types_global_values
616+
.push(Instruction::new(
617+
Op::SpecConstantOp,
618+
Some(ty),
619+
Some(id),
620+
[
621+
Operand::LiteralSpecConstantOpInteger(op),
622+
Operand::IdRef(lhs),
623+
]
624+
.into_iter()
625+
.chain(maybe_rhs.map(Operand::IdRef))
626+
.collect(),
627+
));
628+
id
629+
};
630+
631+
let mut builder = self.builder(BuilderCursor::default());
585632
let id = match val {
586-
SpirvConst::U32(v) | SpirvConst::F32(v) => builder.constant_bit32(ty, v),
587-
SpirvConst::U64(v) | SpirvConst::F64(v) => builder.constant_bit64(ty, v),
588-
SpirvConst::Bool(v) => {
589-
if v {
590-
builder.constant_true(ty)
591-
} else {
592-
builder.constant_false(ty)
633+
SpirvConst::Scalar(v) => match scalar_ty.unwrap() {
634+
SpirvType::Integer(..=32, _) | SpirvType::Float(..=32) => {
635+
builder.constant_bit32(ty, v as u32)
593636
}
594-
}
637+
SpirvType::Integer(64, _) | SpirvType::Float(64) => {
638+
builder.constant_bit64(ty, v as u64)
639+
}
640+
SpirvType::Integer(128, false) => {
641+
// HACK(eddyb) avoid borrow conflicts.
642+
drop(builder);
643+
644+
let const_64_u32_id = cx.const_u32(64).def_cx(cx);
645+
let [lo_id, hi_id] =
646+
[v as u64, (v >> 64) as u64].map(|half| cx.const_u64(half).def_cx(cx));
647+
648+
builder = self.builder(BuilderCursor::default());
649+
let mut const_op =
650+
|op, lhs, maybe_rhs| const_op(&mut builder, op, lhs, maybe_rhs);
651+
let [lo_u128_id, hi_shifted_u128_id] =
652+
[(lo_id, None), (hi_id, Some(const_64_u32_id))].map(
653+
|(half_u64_id, shift)| {
654+
let mut half_u128_id = const_op(Op::UConvert, half_u64_id, None);
655+
if let Some(shift_amount_id) = shift {
656+
half_u128_id = const_op(
657+
Op::ShiftLeftLogical,
658+
half_u128_id,
659+
Some(shift_amount_id),
660+
);
661+
}
662+
half_u128_id
663+
},
664+
);
665+
const_op(Op::BitwiseOr, lo_u128_id, Some(hi_shifted_u128_id))
666+
}
667+
SpirvType::Integer(128, true) | SpirvType::Float(128) => {
668+
// HACK(eddyb) avoid borrow conflicts.
669+
drop(builder);
670+
671+
let v_u128_id = cx.const_u128(v).def_cx(cx);
672+
673+
builder = self.builder(BuilderCursor::default());
674+
const_op(&mut builder, Op::Bitcast, v_u128_id, None)
675+
}
676+
SpirvType::Bool => match v {
677+
0 => builder.constant_false(ty),
678+
1 => builder.constant_true(ty),
679+
_ => cx
680+
.tcx
681+
.dcx()
682+
.fatal(format!("invalid constant value for bool: {v}")),
683+
},
684+
other => cx.tcx.dcx().fatal(format!(
685+
"SpirvConst::Scalar does not support type {}",
686+
other.debug(ty, cx)
687+
)),
688+
},
595689

596690
SpirvConst::Null => builder.constant_null(ty),
597691
SpirvConst::Undef
@@ -606,11 +700,7 @@ impl<'tcx> BuilderSpirv<'tcx> {
606700
};
607701
#[allow(clippy::match_same_arms)]
608702
let legal = match val {
609-
SpirvConst::U32(_)
610-
| SpirvConst::U64(_)
611-
| SpirvConst::F32(_)
612-
| SpirvConst::F64(_)
613-
| SpirvConst::Bool(_) => Ok(()),
703+
SpirvConst::Scalar(_) => Ok(()),
614704

615705
SpirvConst::Null => {
616706
// FIXME(eddyb) check that the type supports `OpConstantNull`.
@@ -712,10 +802,9 @@ impl<'tcx> BuilderSpirv<'tcx> {
712802
}
713803
}
714804

715-
pub fn lookup_const_u64(&self, def: SpirvValue) -> Option<u64> {
805+
pub fn lookup_const_scalar(&self, def: SpirvValue) -> Option<u128> {
716806
match self.lookup_const(def)? {
717-
SpirvConst::U32(v) => Some(v as u64),
718-
SpirvConst::U64(v) => Some(v),
807+
SpirvConst::Scalar(v) => Some(v),
719808
_ => None,
720809
}
721810
}

0 commit comments

Comments
 (0)