@@ -218,13 +218,7 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
218
218
sym:: cttz => self . count_leading_trailing_zeros ( args[ 0 ] . immediate ( ) , true , false ) ,
219
219
sym:: cttz_nonzero => self . count_leading_trailing_zeros ( args[ 0 ] . immediate ( ) , true , true ) ,
220
220
221
- sym:: ctpop => {
222
- let u32 = SpirvType :: Integer ( 32 , false ) . def ( self . span ( ) , self ) ;
223
- self . emit ( )
224
- . bit_count ( u32, None , args[ 0 ] . immediate ( ) . def ( self ) )
225
- . unwrap ( )
226
- . with_type ( u32)
227
- }
221
+ sym:: ctpop => self . count_ones ( args[ 0 ] . immediate ( ) ) ,
228
222
sym:: bitreverse => self
229
223
. emit ( )
230
224
. bit_reverse ( args[ 0 ] . immediate ( ) . ty , None , args[ 0 ] . immediate ( ) . def ( self ) )
@@ -377,6 +371,54 @@ impl<'a, 'tcx> IntrinsicCallBuilderMethods<'tcx> for Builder<'a, 'tcx> {
377
371
}
378
372
379
373
impl Builder < ' _ , ' _ > {
374
+ pub fn count_ones ( & self , arg : SpirvValue ) -> SpirvValue {
375
+ let ty = arg. ty ;
376
+ match self . cx . lookup_type ( ty) {
377
+ SpirvType :: Integer ( bits, signed) => {
378
+ let u32 = SpirvType :: Integer ( 32 , false ) . def ( self . span ( ) , self ) ;
379
+
380
+ match bits {
381
+ 8 | 16 => {
382
+ let arg = arg. def ( self ) ;
383
+ let arg = if signed {
384
+ let unsigned =
385
+ SpirvType :: Integer ( bits, false ) . def ( self . span ( ) , self ) ;
386
+ self . emit ( ) . bitcast ( unsigned, None , arg) . unwrap ( )
387
+ } else {
388
+ arg
389
+ } ;
390
+ let arg = self . emit ( ) . u_convert ( u32, None , arg) . unwrap ( ) ;
391
+ self . emit ( ) . bit_count ( u32, None , arg) . unwrap ( )
392
+ }
393
+ 32 => self . emit ( ) . bit_count ( u32, None , arg. def ( self ) ) . unwrap ( ) ,
394
+ 64 => {
395
+ let u32_32 = self . constant_u32 ( self . span ( ) , 32 ) . def ( self ) ;
396
+ let arg = arg. def ( self ) ;
397
+ let lower = self . emit ( ) . s_convert ( u32, None , arg) . unwrap ( ) ;
398
+ let higher = self
399
+ . emit ( )
400
+ . shift_left_logical ( ty, None , arg, u32_32)
401
+ . unwrap ( ) ;
402
+ let higher = self . emit ( ) . s_convert ( u32, None , higher) . unwrap ( ) ;
403
+
404
+ let lower_bits = self . emit ( ) . bit_count ( u32, None , lower) . unwrap ( ) ;
405
+ let higher_bits = self . emit ( ) . bit_count ( u32, None , higher) . unwrap ( ) ;
406
+ self . emit ( ) . i_add ( u32, None , lower_bits, higher_bits) . unwrap ( )
407
+ }
408
+ _ => {
409
+ let undef = self . undef ( ty) . def ( self ) ;
410
+ self . zombie ( undef, & format ! (
411
+ "counting leading / trailing zeros on unsupported {ty:?} bit integer type"
412
+ ) ) ;
413
+ undef
414
+ }
415
+ }
416
+ . with_type ( u32)
417
+ }
418
+ _ => self . fatal ( "count_ones on a non-integer type" ) ,
419
+ }
420
+ }
421
+
380
422
pub fn count_leading_trailing_zeros (
381
423
& self ,
382
424
arg : SpirvValue ,
0 commit comments