Skip to content

Commit 0759f28

Browse files
committed
count_ones: fix count_ones, must be u32-only in vulkan
1 parent 04e131a commit 0759f28

File tree

1 file changed

+49
-7
lines changed

1 file changed

+49
-7
lines changed

crates/rustc_codegen_spirv/src/builder/intrinsics.rs

+49-7
Original file line numberDiff line numberDiff line change
@@ -218,13 +218,7 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
218218
sym::cttz => self.count_leading_trailing_zeros(args[0].immediate(), true, false),
219219
sym::cttz_nonzero => self.count_leading_trailing_zeros(args[0].immediate(), true, true),
220220

221-
sym::ctpop => {
222-
let u32 = SpirvType::Integer(32, false).def(self.span(), self);
223-
self.emit()
224-
.bit_count(u32, None, args[0].immediate().def(self))
225-
.unwrap()
226-
.with_type(u32)
227-
}
221+
sym::ctpop => self.count_ones(args[0].immediate()),
228222
sym::bitreverse => self
229223
.emit()
230224
.bit_reverse(args[0].immediate().ty, None, args[0].immediate().def(self))
@@ -377,6 +371,54 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
377371
}
378372

379373
impl Builder<'_, '_> {
374+
pub fn count_ones(&self, arg: SpirvValue) -> SpirvValue {
375+
let ty = arg.ty;
376+
match self.cx.lookup_type(ty) {
377+
SpirvType::Integer(bits, signed) => {
378+
let u32 = SpirvType::Integer(32, false).def(self.span(), self);
379+
380+
match bits {
381+
8 | 16 => {
382+
let arg = arg.def(self);
383+
let arg = if signed {
384+
let unsigned =
385+
SpirvType::Integer(bits, false).def(self.span(), self);
386+
self.emit().bitcast(unsigned, None, arg).unwrap()
387+
} else {
388+
arg
389+
};
390+
let arg = self.emit().u_convert(u32, None, arg).unwrap();
391+
self.emit().bit_count(u32, None, arg).unwrap()
392+
}
393+
32 => self.emit().bit_count(u32, None, arg.def(self)).unwrap(),
394+
64 => {
395+
let u32_32 = self.constant_u32(self.span(), 32).def(self);
396+
let arg = arg.def(self);
397+
let lower = self.emit().s_convert(u32, None, arg).unwrap();
398+
let higher = self
399+
.emit()
400+
.shift_left_logical(ty, None, arg, u32_32)
401+
.unwrap();
402+
let higher = self.emit().s_convert(u32, None, higher).unwrap();
403+
404+
let lower_bits = self.emit().bit_count(u32, None, lower).unwrap();
405+
let higher_bits = self.emit().bit_count(u32, None, higher).unwrap();
406+
self.emit().i_add(u32, None, lower_bits, higher_bits).unwrap()
407+
}
408+
_ => {
409+
let undef = self.undef(ty).def(self);
410+
self.zombie(undef, &format!(
411+
"counting leading / trailing zeros on unsupported {ty:?} bit integer type"
412+
));
413+
undef
414+
}
415+
}
416+
.with_type(u32)
417+
}
418+
_ => self.fatal("count_ones on a non-integer type"),
419+
}
420+
}
421+
380422
pub fn count_leading_trailing_zeros(
381423
&self,
382424
arg: SpirvValue,

0 commit comments

Comments
 (0)