Skip to content

Feat: support bit_get function #1713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 281 additions & 0 deletions native/spark-expr/src/bitwise_funcs/bitwise_get.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,281 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::{array::*, datatypes::DataType};
use datafusion::common::{Result, ScalarValue};
use datafusion::{error::DataFusionError, logical_expr::ColumnarValue};
use std::sync::Arc;

macro_rules! bit_get_scalar_position {
($args:expr, $array_type:ty, $pos:expr, $bit_size:expr) => {{
if let Some(pos) = $pos {
check_position(*pos, $bit_size as i32)?;
}
let args = $args
.as_any()
.downcast_ref::<$array_type>()
.expect("bit_get_scalar_position failed to downcast array");

let result: Int8Array = args
.iter()
.map(|x| x.and_then(|x| $pos.map(|pos| bit_get(x.into(), pos))))
.collect();

Ok(Arc::new(result))
}};
}

macro_rules! bit_get_array_positions {
($args:expr, $array_type:ty, $positions:expr, $bit_size:expr) => {{
let args = $args
.as_any()
.downcast_ref::<$array_type>()
.expect("bit_get_array_positions failed to downcast args array");

let positions = $positions
.as_any()
.downcast_ref::<Int32Array>()
.expect("bit_get_array_positions failed to downcast positions array");

for pos in positions.iter().flatten() {
check_position(pos, $bit_size as i32)?
}

let result: Int8Array = args
.iter()
.zip(positions.iter())
.map(|(i, p)| i.and_then(|i| p.map(|p| bit_get(i.into(), p))))
.collect();

Ok(Arc::new(result))
}};
}

pub fn spark_bit_get(args: &[ColumnarValue]) -> Result<ColumnarValue> {
if args.len() != 2 {
return Err(DataFusionError::Internal(
"bit_get expects exactly two arguments".to_string(),
));
}
match (&args[0], &args[1]) {
(ColumnarValue::Array(args), ColumnarValue::Array(positions)) => {
if args.len() != positions.len() {
return Err(DataFusionError::Execution(format!(
"Input arrays must have equal length. Positions array has {} elements, but arguments array has {} elements",
positions.len(), args.len()
)));
}
if !matches!(positions.data_type(), DataType::Int32) {
return Err(DataFusionError::Execution(format!(
"Invalid data type for positions array: expected `Int32`, found `{}`",
positions.data_type()
)));
}
let result: Result<ArrayRef> = match args.data_type() {
DataType::Int8 => bit_get_array_positions!(args, Int8Array, positions, i8::BITS),
DataType::Int16 => bit_get_array_positions!(args, Int16Array, positions, i16::BITS),
DataType::Int32 => bit_get_array_positions!(args, Int32Array, positions, i32::BITS),
DataType::Int64 => bit_get_array_positions!(args, Int64Array, positions, i64::BITS),
_ => Err(DataFusionError::Execution(format!(
"Can't be evaluated because the expression's type is {:?}, not signed int",
args.data_type(),
))),
};
result.map(ColumnarValue::Array)
}
(ColumnarValue::Array(args), ColumnarValue::Scalar(ScalarValue::Int32(pos))) => {
let result: Result<ArrayRef> = match args.data_type() {
DataType::Int8 => {
bit_get_scalar_position!(args, Int8Array, pos, i8::BITS)
}
DataType::Int16 => {
bit_get_scalar_position!(args, Int16Array, pos, i16::BITS)
}
DataType::Int32 => {
bit_get_scalar_position!(args, Int32Array, pos, i32::BITS)
}
DataType::Int64 => {
bit_get_scalar_position!(args, Int64Array, pos, i64::BITS)
}
_ => Err(DataFusionError::Execution(format!(
"Can't be evaluated because the expression's type is {:?}, not signed int",
args.data_type(),
))),
};
result.map(ColumnarValue::Array)
}
_ => Err(DataFusionError::Execution(
"Invalid input to function bit_get. Expected (IntegralType array, Int32Scalar) or \
(IntegralType array, Int32Array)"
.to_string(),
)),
}
}

fn bit_get(arg: i64, pos: i32) -> i8 {
((arg >> pos) & 1) as i8
}

fn check_position(pos: i32, bit_size: i32) -> Result<()> {
if pos < 0 {
return Err(DataFusionError::Execution(format!(
"Invalid bit position: {:?} is less than zero",
pos
)));
}
if bit_size <= pos {
return Err(DataFusionError::Execution(format!(
"Invalid bit position: {:?} exceeds the bit upper limit: {:?}",
pos, bit_size
)));
}
Ok(())
}

#[cfg(test)]
mod tests {
use super::*;
use datafusion::common::cast::as_int8_array;

#[test]
fn bitwise_get_scalar_position() -> Result<()> {
let args = vec![
ColumnarValue::Array(Arc::new(Int32Array::from(vec![
Some(1),
None,
Some(1234553454),
]))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(1))),
];

let expected = &Int8Array::from(vec![Some(0), None, Some(1)]);

let ColumnarValue::Array(result) = spark_bit_get(&args)? else {
unreachable!()
};

let result = as_int8_array(&result).expect("failed to downcast to Int8Array");

assert_eq!(result, expected);

Ok(())
}

#[test]
fn bitwise_get_scalar_negative_position() -> Result<()> {
let args = vec![
ColumnarValue::Array(Arc::new(Int32Array::from(vec![
Some(1),
None,
Some(1234553454),
]))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(-1))),
];

let expected = String::from("Execution error: Invalid bit position: -1 is less than zero");
let result = spark_bit_get(&args).err().unwrap().to_string();

assert_eq!(result, expected);

Ok(())
}

#[test]
fn bitwise_get_scalar_overflow_position() -> Result<()> {
let args = vec![
ColumnarValue::Array(Arc::new(Int32Array::from(vec![
Some(1),
None,
Some(1234553454),
]))),
ColumnarValue::Scalar(ScalarValue::Int32(Some(33))),
];

let expected = String::from(
"Execution error: Invalid bit position: 33 exceeds the bit upper limit: 32",
);
let result = spark_bit_get(&args).err().unwrap().to_string();

assert_eq!(result, expected);

Ok(())
}

#[test]
fn bitwise_get_array_positions() -> Result<()> {
let args = vec![
ColumnarValue::Array(Arc::new(Int32Array::from(vec![
Some(1),
None,
Some(1234553454),
]))),
ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(1), None, Some(1)]))),
];

let expected = &Int8Array::from(vec![Some(0), None, Some(1)]);

let ColumnarValue::Array(result) = spark_bit_get(&args)? else {
unreachable!()
};

let result = as_int8_array(&result).expect("failed to downcast to Int8Array");

assert_eq!(result, expected);

Ok(())
}

#[test]
fn bitwise_get_array_positions_contains_negative() -> Result<()> {
let args = vec![
ColumnarValue::Array(Arc::new(Int32Array::from(vec![
Some(1),
None,
Some(1234553454),
]))),
ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(-1), None, Some(1)]))),
];

let expected = String::from("Execution error: Invalid bit position: -1 is less than zero");
let result = spark_bit_get(&args).err().unwrap().to_string();

assert_eq!(result, expected);

Ok(())
}

#[test]
fn bitwise_get_array_positions_contains_overflow() -> Result<()> {
let args = vec![
ColumnarValue::Array(Arc::new(Int32Array::from(vec![
Some(1),
None,
Some(1234553454),
]))),
ColumnarValue::Array(Arc::new(Int32Array::from(vec![Some(33), None, Some(1)]))),
];

let expected = String::from(
"Execution error: Invalid bit position: 33 exceeds the bit upper limit: 32",
);
let result = spark_bit_get(&args).err().unwrap().to_string();

assert_eq!(result, expected);

Ok(())
}
}
2 changes: 2 additions & 0 deletions native/spark-expr/src/bitwise_funcs/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
// specific language governing permissions and limitations
// under the License.

mod bitwise_get;
mod bitwise_not;

pub use bitwise_get::spark_bit_get;
pub use bitwise_not::{bitwise_not, BitwiseNotExpr};
12 changes: 8 additions & 4 deletions native/spark-expr/src/comet_scalar_funcs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@

use crate::hash_funcs::*;
use crate::{
spark_array_repeat, spark_ceil, spark_date_add, spark_date_sub, spark_decimal_div,
spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan, spark_make_decimal,
spark_read_side_padding, spark_round, spark_rpad, spark_unhex, spark_unscaled_value,
SparkChrFunc,
spark_array_repeat, spark_bit_get, spark_ceil, spark_date_add, spark_date_sub,
spark_decimal_div, spark_decimal_integral_div, spark_floor, spark_hex, spark_isnan,
spark_make_decimal, spark_read_side_padding, spark_round, spark_rpad, spark_unhex,
spark_unscaled_value, SparkChrFunc,
};
use arrow::datatypes::DataType;
use datafusion::common::{DataFusionError, Result as DataFusionResult};
Expand Down Expand Up @@ -145,6 +145,10 @@ pub fn create_comet_physical_fun(
let func = Arc::new(spark_array_repeat);
make_comet_scalar_udf!("array_repeat", func, without data_type)
}
"bit_get" => {
let func = Arc::new(spark_bit_get);
make_comet_scalar_udf!("bit_get", func, without data_type)
}
_ => registry.udf(fun_name).map_err(|e| {
DataFusionError::Execution(format!(
"Function {fun_name} not found in the registry: {e}",
Expand Down
Loading
Loading