Skip to content

Feat: support bit_count function #1602

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
105 changes: 105 additions & 0 deletions native/spark-expr/src/bitwise_funcs/bitwise_count.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::{array::*, datatypes::DataType};
use datafusion::common::Result;
use datafusion::{error::DataFusionError, logical_expr::ColumnarValue};
use std::sync::Arc;

// Downcasts `$OPERAND` (a dynamically typed Arrow `ArrayRef`) to the concrete
// array type `$DT`, applies `bit_count` to every non-null element, and collects
// the results into a nullable `Int32Array`.
//
// Each element is widened to `i64` via `.into()` before counting, i.e. narrower
// signed integers are sign-extended to 64 bits. Nulls are preserved because the
// per-element closure runs inside `Option::map`.
macro_rules! compute_op {
    ($OPERAND:expr, $DT:ident) => {{
        // A failed downcast means the caller matched a `DataType` that does not
        // correspond to the concrete Arrow array type it named here.
        let operand = $OPERAND.as_any().downcast_ref::<$DT>().ok_or_else(|| {
            DataFusionError::Execution(format!(
                "compute_op failed to downcast array to: {:?}",
                stringify!($DT)
            ))
        })?;

        let result: Int32Array = operand
            .iter()
            .map(|x| x.map(|y| bit_count(y.into())))
            .collect();

        Ok(Arc::new(result))
    }};
}

pub fn spark_bit_count(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 1 {
return Err(DataFusionError::Internal(
"bit_count expects exactly one argument".to_string(),
));
}
match &args[0] {
ColumnarValue::Array(array) => {
let result: Result<ArrayRef> = match array.data_type() {
DataType::Int8 | DataType::Boolean => compute_op!(array, Int8Array),
DataType::Int16 => compute_op!(array, Int16Array),
DataType::Int32 => compute_op!(array, Int32Array),
DataType::Int64 => compute_op!(array, Int64Array),
_ => Err(DataFusionError::Execution(format!(
"Can't be evaluated because the expression's type is {:?}, not signed int",
array.data_type(),
))),
};
result.map(ColumnarValue::Array)
}
ColumnarValue::Scalar(_) => Err(DataFusionError::Internal(
"shouldn't go to bit_count scalar path".to_string(),
)),
}
}

/// Returns the number of set bits in the two's-complement representation of
/// `i`, mirroring Spark's `bit_count` for LongType.
///
/// Callers sign-extend narrower integers to `i64` first (see `compute_op!`),
/// so e.g. `bit_count(-1)` is 64 regardless of the source column width.
fn bit_count(i: i64) -> i32 {
    // `count_ones` lowers to a hardware popcount where available; it replaces
    // the previous hand-rolled SWAR bit-twiddling implementation. The result
    // is at most 64, so the cast to i32 cannot truncate.
    i.count_ones() as i32
}

#[cfg(test)]
mod tests {
    use datafusion::common::{cast::as_int32_array, Result};

    use super::*;

    /// Verifies bit counting over an Int32 column, including null propagation
    /// and a negative value. Int32 inputs are sign-extended to 64 bits before
    /// counting, which is why bit_count(-3456) is 54 (not the 32-bit 22).
    #[test]
    fn bitwise_count_op() -> Result<()> {
        let args = vec![ColumnarValue::Array(Arc::new(Int32Array::from(vec![
            Some(1),
            None,
            Some(12345),
            Some(89),
            Some(-3456),
        ])))];
        let expected = &Int32Array::from(vec![Some(1), None, Some(6), Some(4), Some(54)]);

        let ColumnarValue::Array(result) = spark_bit_count(&args)? else {
            unreachable!()
        };

        let result = as_int32_array(&result).expect("failed to downcast to Int32Array");
        assert_eq!(result, expected);

        Ok(())
    }
}
2 changes: 2 additions & 0 deletions native/spark-expr/src/bitwise_funcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.

mod bitwise_count;
mod bitwise_not;

pub use bitwise_count::spark_bit_count;
pub use bitwise_not::{bitwise_not, BitwiseNotExpr};
12 changes: 8 additions & 4 deletions native/spark-expr/src/comet_scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

use crate::hash_funcs::*;
use crate::{
spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
SparkChrFunc,
spark_array_repeat, spark_bit_count, spark_ceil, spark_date_add, spark_date_sub,
spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan,
spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
spark_unscaled_value, SparkChrFunc,
};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, Result as DataFusionResult};
Expand Down Expand Up @@ -145,6 +145,10 @@ pub fn create_comet_physical_fun(
let func = Arc::new(spark_array_repeat);
make_comet_scalar_udf!("array_repeat", func, without data_type)
}
"bit_count" => {
let func = Arc::new(spark_bit_count);
make_comet_scalar_udf!("bit_count", func, without data_type)
}
_ => registry.udf(fun_name).map_err(|e| {
DataFusionError::Execution(format!(
"Function {fun_name} not found in the registry: {e}",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1655,6 +1655,12 @@ object QueryPlanSerde extends Logging with CometExprShim {
binding,
(builder, binaryExpr) => builder.setBitwiseXor(binaryExpr))

case BitwiseCount(child) =>
val childProto = exprToProto(child, inputs, binding)
val bitCountScalarExpr =
scalarFunctionExprToProtoWithReturnType("bit_count", IntegerType, childProto)
optExprWithInfo(bitCountScalarExpr, expr, expr.children: _*)

case ShiftRight(left, right) =>
// DataFusion bitwise shift right expression requires
// same data type between left and right side
Expand Down
69 changes: 69 additions & 0 deletions spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
import org.apache.spark.sql.types.{Decimal, DecimalType}

import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
import org.apache.comet.testing.{DataGenOptions, ParquetGenerator}

class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
import testImplicits._
Expand Down Expand Up @@ -90,6 +91,74 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
}
}

test("bitwise_count - min/max values") {
Seq(false, true).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
val table = "bitwise_count_test"
withTable(table) {
sql(s"create table $table(col1 long, col2 int, col3 short, col4 byte) using parquet")
sql(s"insert into $table values(1111, 2222, 17, 7)")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mind adding random number cases?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep. Added tests with random data.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we also add a test with a Parquet file not created by Spark (see makeParquetFileAllTypes) and see if this reports correct results with unsigned int columns?

sql(
s"insert into $table values(${Long.MaxValue}, ${Int.MaxValue}, ${Short.MaxValue}, ${Byte.MaxValue})")
sql(
s"insert into $table values(${Long.MinValue}, ${Int.MinValue}, ${Short.MinValue}, ${Byte.MinValue})")

checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col1) FROM $table"))
checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col2) FROM $table"))
checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col3) FROM $table"))
checkSparkAnswerAndOperator(sql(s"SELECT bit_count(col4) FROM $table"))
checkSparkAnswerAndOperator(sql(s"SELECT bit_count(true) FROM $table"))
checkSparkAnswerAndOperator(sql(s"SELECT bit_count(false) FROM $table"))
}
}
}
}

test("bitwise_count - random values (spark gen)") {
  withTempDir { dir =>
    // Write a small random Parquet file with Comet disabled, so the data is
    // produced by vanilla Spark; then run bit_count over it with Comet on.
    val parquetPath = new Path(dir.toURI.toString, "test.parquet").toString
    val rand = new Random(42)
    withSQLConf(CometConf.COMET_ENABLED.key -> "false") {
      val options = DataGenOptions(
        allowNull = true,
        generateNegativeZero = true,
        generateArray = false,
        generateStruct = false,
        generateMap = false)
      ParquetGenerator.makeParquetFile(rand, spark, parquetPath, 10, options)
    }
    // Apply bit_count to each generated column and compare Comet vs Spark.
    val df = spark.read
      .parquet(parquetPath)
      .selectExpr("bit_count(c1)", "bit_count(c2)", "bit_count(c3)", "bit_count(c4)")
    checkSparkAnswerAndOperator(df)
  }
}

test("bitwise_count - random values (native parquet gen)") {
  // Exercise bit_count against a Parquet file written by the native test
  // writer rather than Spark, both with and without dictionary encoding.
  Seq(true, false).foreach { dictionaryEnabled =>
    withTempDir { tmpDir =>
      val file = new Path(tmpDir.toURI.toString, "test.parquet")
      makeParquetFileAllTypes(file, dictionaryEnabled, 0, 10000, nullEnabled = false)
      // Columns _2.._5 and _10/_11 cover the integral types under test.
      val projections =
        Seq("_2", "_3", "_4", "_5", "_10", "_11").map(c => s"bit_count($c)")
      val df = spark.read.parquet(file.toString).selectExpr(projections: _*)
      checkSparkAnswerAndOperator(df)
    }
  }
}

test("bitwise shift with different left/right types") {
Seq(false, true).foreach { dictionary =>
withSQLConf("parquet.enable.dictionary" -> dictionary.toString) {
Expand Down
Loading