Skip to content

Commit 7d8ab7c

Browse files
committed
leading_zeros: add support for leading_zeros and trailing_zeros, limited to u32/i32
1 parent 6e2c84d commit 7d8ab7c

File tree

4 files changed

+89
-62
lines changed

4 files changed

+89
-62
lines changed

crates/rustc_codegen_spirv/src/builder/ext_inst.rs

+3-30
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
use super::Builder;
22
use crate::builder_spirv::{SpirvValue, SpirvValueExt};
33
use crate::custom_insts;
4+
use rspirv::dr::Operand;
45
use rspirv::spirv::{GLOp, Word};
5-
use rspirv::{dr::Operand, spirv::Capability};
66

77
const GLSL_STD_450: &str = "GLSL.std.450";
88

@@ -13,7 +13,6 @@ pub struct ExtInst {
1313
custom: Option<Word>,
1414

1515
glsl: Option<Word>,
16-
integer_functions_2_intel: bool,
1716
}
1817

1918
impl ExtInst {
@@ -38,32 +37,11 @@ impl ExtInst {
3837
id
3938
}
4039
}
41-
42-
pub fn require_integer_functions_2_intel(&mut self, bx: &Builder<'_, '_>, to_zombie: Word) {
43-
if !self.integer_functions_2_intel {
44-
self.integer_functions_2_intel = true;
45-
if !bx
46-
.builder
47-
.has_capability(Capability::IntegerFunctions2INTEL)
48-
{
49-
bx.zombie(to_zombie, "capability IntegerFunctions2INTEL is required");
50-
}
51-
if !bx
52-
.builder
53-
.has_extension(bx.sym.spv_intel_shader_integer_functions2)
54-
{
55-
bx.zombie(
56-
to_zombie,
57-
"extension SPV_INTEL_shader_integer_functions2 is required",
58-
);
59-
}
60-
}
61-
}
6240
}
6341

6442
impl<'a, 'tcx> Builder<'a, 'tcx> {
6543
pub fn custom_inst(
66-
&mut self,
44+
&self,
6745
result_type: Word,
6846
inst: custom_insts::CustomInst<Operand>,
6947
) -> SpirvValue {
@@ -80,12 +58,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
8058
.with_type(result_type)
8159
}
8260

83-
pub fn gl_op(
84-
&mut self,
85-
op: GLOp,
86-
result_type: Word,
87-
args: impl AsRef<[SpirvValue]>,
88-
) -> SpirvValue {
61+
pub fn gl_op(&self, op: GLOp, result_type: Word, args: impl AsRef<[SpirvValue]>) -> SpirvValue {
8962
let args = args.as_ref();
9063
let glsl = self.ext_inst.borrow_mut().import_glsl(self);
9164
self.emit()

crates/rustc_codegen_spirv/src/builder/intrinsics.rs

+49-28
Original file line numberDiff line numberDiff line change
@@ -211,35 +211,12 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
211211
self.rotate(val, shift, is_left)
212212
}
213213

214-
// TODO: Do we want to manually implement these instead of using intel instructions?
215-
sym::ctlz | sym::ctlz_nonzero => {
216-
let result = self
217-
.emit()
218-
.u_count_leading_zeros_intel(
219-
args[0].immediate().ty,
220-
None,
221-
args[0].immediate().def(self),
222-
)
223-
.unwrap();
224-
self.ext_inst
225-
.borrow_mut()
226-
.require_integer_functions_2_intel(self, result);
227-
result.with_type(args[0].immediate().ty)
228-
}
229-
sym::cttz | sym::cttz_nonzero => {
230-
let result = self
231-
.emit()
232-
.u_count_trailing_zeros_intel(
233-
args[0].immediate().ty,
234-
None,
235-
args[0].immediate().def(self),
236-
)
237-
.unwrap();
238-
self.ext_inst
239-
.borrow_mut()
240-
.require_integer_functions_2_intel(self, result);
241-
result.with_type(args[0].immediate().ty)
214+
sym::ctlz => self.count_leading_trailing_zeros(args[0].immediate(), false, false),
215+
sym::ctlz_nonzero => {
216+
self.count_leading_trailing_zeros(args[0].immediate(), false, true)
242217
}
218+
sym::cttz => self.count_leading_trailing_zeros(args[0].immediate(), true, false),
219+
sym::cttz_nonzero => self.count_leading_trailing_zeros(args[0].immediate(), true, true),
243220

244221
sym::ctpop => self
245222
.emit()
@@ -398,6 +375,50 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
398375
}
399376

400377
impl Builder<'_, '_> {
378+
pub fn count_leading_trailing_zeros(
379+
&self,
380+
arg: SpirvValue,
381+
trailing: bool,
382+
non_zero: bool,
383+
) -> SpirvValue {
384+
let ty = arg.ty;
385+
match self.cx.lookup_type(ty) {
386+
SpirvType::Integer(bits, _) => {
387+
let int_0 = self.constant_int(ty, 0);
388+
let int_bits = self.constant_int(ty, bits as u128).def(self);
389+
let bool = SpirvType::Bool.def(self.span(), self);
390+
391+
let gl_op = if trailing {
392+
// rust is always unsigned
393+
GLOp::FindILsb
394+
} else {
395+
GLOp::FindUMsb
396+
};
397+
398+
let glsl = self.ext_inst.borrow_mut().import_glsl(self);
399+
let find_xsb = self
400+
.emit()
401+
.ext_inst(ty, None, glsl, gl_op as u32, [Operand::IdRef(
402+
arg.def(self),
403+
)])
404+
.unwrap();
405+
if non_zero {
406+
find_xsb
407+
} else {
408+
let is_0 = self
409+
.emit()
410+
.i_equal(bool, None, arg.def(self), int_0.def(self))
411+
.unwrap();
412+
self.emit()
413+
.select(ty, None, is_0, int_bits, find_xsb)
414+
.unwrap()
415+
}
416+
.with_type(ty)
417+
}
418+
_ => self.fatal("counting leading / trailing zeros on a non-integer type"),
419+
}
420+
}
421+
401422
pub fn abort_with_kind_and_message_debug_printf(
402423
&mut self,
403424
kind: &str,

crates/rustc_codegen_spirv/src/symbols.rs

-4
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ pub struct Symbols {
2121
pub spirv: Symbol,
2222
pub libm: Symbol,
2323
pub entry_point_name: Symbol,
24-
pub spv_intel_shader_integer_functions2: Symbol,
2524
pub spv_khr_vulkan_memory_model: Symbol,
2625

2726
descriptor_set: Symbol,
@@ -411,9 +410,6 @@ impl Symbols {
411410
spirv: Symbol::intern("spirv"),
412411
libm: Symbol::intern("libm"),
413412
entry_point_name: Symbol::intern("entry_point_name"),
414-
spv_intel_shader_integer_functions2: Symbol::intern(
415-
"SPV_INTEL_shader_integer_functions2",
416-
),
417413
spv_khr_vulkan_memory_model: Symbol::intern("SPV_KHR_vulkan_memory_model"),
418414

419415
descriptor_set: Symbol::intern("descriptor_set"),
+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Test all trailing and leading zeros. No need to test ones, they just call the zero variant with !value
2+
3+
// build-pass
4+
5+
use spirv_std::spirv;
6+
7+
#[spirv(fragment)]
8+
pub fn leading_zeros_u32(
9+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buffer: &u32,
10+
out: &mut u32,
11+
) {
12+
*out = u32::leading_zeros(*buffer);
13+
}
14+
15+
#[spirv(fragment)]
16+
pub fn trailing_zeros_u32(
17+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buffer: &u32,
18+
out: &mut u32,
19+
) {
20+
*out = u32::trailing_zeros(*buffer);
21+
}
22+
23+
#[spirv(fragment)]
24+
pub fn leading_zeros_i32(
25+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buffer: &i32,
26+
out: &mut u32,
27+
) {
28+
*out = i32::leading_zeros(*buffer);
29+
}
30+
31+
#[spirv(fragment)]
32+
pub fn trailing_zeros_i32(
33+
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buffer: &i32,
34+
out: &mut u32,
35+
) {
36+
*out = i32::trailing_zeros(*buffer);
37+
}

0 commit comments

Comments
 (0)