Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 55 additions & 20 deletions native/core/src/execution/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,9 @@ use datafusion_comet_proto::{
use datafusion_comet_spark_expr::monotonically_increasing_id::MonotonicallyIncreasingId;
use datafusion_comet_spark_expr::{
ArrayInsert, Avg, AvgDecimal, Cast, CheckOverflow, Correlation, Covariance, CreateNamedStruct,
GetArrayStructFields, GetStructField, IfExpr, ListExtract, NormalizeNaNAndZero, RandExpr,
RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson, UnboundColumn, Variance,
DecimalRescaleCheckOverflow, GetArrayStructFields, GetStructField, IfExpr, ListExtract,
NormalizeNaNAndZero, RandExpr, RandnExpr, SparkCastOptions, Stddev, SumDecimal, ToJson,
UnboundColumn, Variance, WideDecimalBinaryExpr, WideDecimalOp,
};
use itertools::Itertools;
use jni::objects::GlobalRef;
Expand Down Expand Up @@ -376,10 +377,45 @@ impl PhysicalPlanner {
)))
}
ExprStruct::CheckOverflow(expr) => {
let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
let child =
self.create_expr(expr.child.as_ref().unwrap(), Arc::clone(&input_schema))?;
let data_type = to_arrow_datatype(expr.datatype.as_ref().unwrap());
let fail_on_error = expr.fail_on_error;

// WideDecimalBinaryExpr already handles overflow — skip redundant check
// but only if its output type matches CheckOverflow's declared type
if child
.as_any()
.downcast_ref::<WideDecimalBinaryExpr>()
.is_some()
{
let child_type = child.data_type(&input_schema)?;
if child_type == data_type {
return Ok(child);
}
}

// Fuse Cast(Decimal128→Decimal128) + CheckOverflow into single rescale+check
// Only fuse when the Cast target type matches the CheckOverflow output type
if let Some(cast) = child.as_any().downcast_ref::<Cast>() {
if let (
DataType::Decimal128(p_out, s_out),
Ok(DataType::Decimal128(_p_in, s_in)),
) = (&data_type, cast.child.data_type(&input_schema))
{
let cast_target = cast.data_type(&input_schema)?;
if cast_target == data_type {
return Ok(Arc::new(DecimalRescaleCheckOverflow::new(
Arc::clone(&cast.child),
s_in,
*p_out,
*s_out,
fail_on_error,
)));
}
}
}

Ok(Arc::new(CheckOverflow::new(
child,
data_type,
Expand Down Expand Up @@ -682,23 +718,22 @@ impl PhysicalPlanner {
|| (op == DataFusionOperator::Multiply && p1 + p2 >= DECIMAL128_MAX_PRECISION) =>
{
let data_type = return_type.map(to_arrow_datatype).unwrap();
// For some Decimal128 operations, we need wider internal digits.
// Cast left and right to Decimal256 and cast the result back to Decimal128
let left = Arc::new(Cast::new(
left,
DataType::Decimal256(p1, s1),
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
));
let right = Arc::new(Cast::new(
right,
DataType::Decimal256(p2, s2),
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
));
let child = Arc::new(BinaryExpr::new(left, op, right));
Ok(Arc::new(Cast::new(
child,
data_type,
SparkCastOptions::new_without_timezone(EvalMode::Legacy, false),
let (p_out, s_out) = match &data_type {
DataType::Decimal128(p, s) => (*p, *s),
dt => {
return Err(ExecutionError::GeneralError(format!(
"Expected Decimal128 return type, got {dt:?}"
)))
}
};
let wide_op = match op {
DataFusionOperator::Plus => WideDecimalOp::Add,
DataFusionOperator::Minus => WideDecimalOp::Subtract,
DataFusionOperator::Multiply => WideDecimalOp::Multiply,
_ => unreachable!(),
};
Ok(Arc::new(WideDecimalBinaryExpr::new(
left, right, wide_op, p_out, s_out, eval_mode,
)))
}
(
Expand Down
4 changes: 4 additions & 0 deletions native/spark-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,7 @@ path = "tests/spark_expr_reg.rs"
[[bench]]
name = "cast_from_boolean"
harness = false

[[bench]]
name = "wide_decimal"
harness = false
162 changes: 162 additions & 0 deletions native/spark-expr/benches/wide_decimal.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,162 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Benchmarks comparing the old Cast->BinaryExpr->Cast chain vs the fused WideDecimalBinaryExpr
//! for Decimal128 arithmetic that requires wider intermediate precision.

use arrow::array::builder::Decimal128Builder;
use arrow::array::RecordBatch;
use arrow::datatypes::{DataType, Field, Schema};
use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use datafusion::logical_expr::Operator;
use datafusion::physical_expr::expressions::{BinaryExpr, Column};
use datafusion::physical_expr::PhysicalExpr;
use datafusion_comet_spark_expr::{
Cast, EvalMode, SparkCastOptions, WideDecimalBinaryExpr, WideDecimalOp,
};
use std::sync::Arc;

const BATCH_SIZE: usize = 8192;

/// Construct a two-column, non-null Decimal128 [`RecordBatch`] of `BATCH_SIZE` rows.
///
/// Column "left" has type Decimal128(p1, s1), column "right" Decimal128(p2, s2).
/// Values are two fixed bases offset per row so neighbouring rows differ.
fn make_decimal_batch(p1: u8, s1: i8, p2: u8, s2: i8) -> RecordBatch {
    let mut lhs = Decimal128Builder::new();
    let mut rhs = Decimal128Builder::new();
    (0..BATCH_SIZE as i128).for_each(|i| {
        lhs.append_value(123456789012345_i128 + i * 1000);
        rhs.append_value(987654321098765_i128 - i * 1000);
    });
    let lhs_type = DataType::Decimal128(p1, s1);
    let rhs_type = DataType::Decimal128(p2, s2);
    let schema = Schema::new(vec![
        Field::new("left", lhs_type.clone(), false),
        Field::new("right", rhs_type.clone(), false),
    ]);
    let lhs = lhs.finish().with_data_type(lhs_type);
    let rhs = rhs.finish().with_data_type(rhs_type);
    RecordBatch::try_new(Arc::new(schema), vec![Arc::new(lhs), Arc::new(rhs)]).unwrap()
}

/// Legacy plan shape: widen both inputs with Cast(Decimal128 -> Decimal256),
/// evaluate the arrow `BinaryExpr` at the wider precision, then narrow the
/// Decimal256 result back to Decimal128 with a final Cast.
fn build_old_expr(
    p1: u8,
    s1: i8,
    p2: u8,
    s2: i8,
    op: Operator,
    out_type: DataType,
) -> Arc<dyn PhysicalExpr> {
    let opts = SparkCastOptions::new_without_timezone(EvalMode::Legacy, false);
    // Helper: wrap a column in a widening cast to Decimal256(p, s).
    let widen = |expr: Arc<dyn PhysicalExpr>, p: u8, s: i8| -> Arc<dyn PhysicalExpr> {
        Arc::new(Cast::new(expr, DataType::Decimal256(p, s), opts.clone()))
    };
    let lhs = widen(Arc::new(Column::new("left", 0)), p1, s1);
    let rhs = widen(Arc::new(Column::new("right", 1)), p2, s2);
    let wide = Arc::new(BinaryExpr::new(lhs, op, rhs));
    Arc::new(Cast::new(wide, out_type, opts))
}

/// Fused plan shape: a single `WideDecimalBinaryExpr` over the two input
/// columns, producing Decimal128(p_out, s_out) directly.
fn build_new_expr(op: WideDecimalOp, p_out: u8, s_out: i8) -> Arc<dyn PhysicalExpr> {
    let lhs: Arc<dyn PhysicalExpr> = Arc::new(Column::new("left", 0));
    let rhs: Arc<dyn PhysicalExpr> = Arc::new(Column::new("right", 1));
    Arc::new(WideDecimalBinaryExpr::new(
        lhs,
        rhs,
        op,
        p_out,
        s_out,
        EvalMode::Legacy,
    ))
}

/// Register one "old" and one "fused" benchmark for the same batch under `name`.
fn bench_case(
    group: &mut criterion::BenchmarkGroup<criterion::measurement::WallTime>,
    name: &str,
    batch: &RecordBatch,
    old_expr: &Arc<dyn PhysicalExpr>,
    new_expr: &Arc<dyn PhysicalExpr>,
) {
    // Same pair of benchmark ids as before, registered in the same order.
    for (variant, expr) in [("old", old_expr), ("fused", new_expr)] {
        group.bench_with_input(BenchmarkId::new(variant, name), batch, |b, input| {
            b.iter(|| expr.evaluate(input).unwrap());
        });
    }
}

/// Drive all wide-decimal cases through `bench_case`, data-driven.
///
/// Each tuple is (name, p1, s1, p2, s2, arrow operator, fused op, p_out, s_out).
/// All cases trigger the wide path:
/// - add_same_scale: Decimal128(38,10) + Decimal128(38,10) -> Decimal128(38,10);
///   max(s1,s2) + max(p1-s1, p2-s2) = 10 + 28 = 38 >= 38
/// - add_diff_scale: Decimal128(38,6) + Decimal128(38,4) -> Decimal128(38,6)
/// - multiply: Decimal128(20,10) * Decimal128(20,10) -> Decimal128(38,6);
///   p1 + p2 = 40 >= 38
/// - subtract: Decimal128(38,18) - Decimal128(38,18) -> Decimal128(38,18)
fn criterion_benchmark(c: &mut Criterion) {
    let mut group = c.benchmark_group("wide_decimal");

    let cases = [
        (
            "add_same_scale",
            38, 10, 38, 10,
            Operator::Plus,
            WideDecimalOp::Add,
            38, 10,
        ),
        (
            "add_diff_scale",
            38, 6, 38, 4,
            Operator::Plus,
            WideDecimalOp::Add,
            38, 6,
        ),
        (
            "multiply",
            20, 10, 20, 10,
            Operator::Multiply,
            WideDecimalOp::Multiply,
            38, 6,
        ),
        (
            "subtract",
            38, 18, 38, 18,
            Operator::Minus,
            WideDecimalOp::Subtract,
            38, 18,
        ),
    ];

    for (name, p1, s1, p2, s2, op, wide_op, p_out, s_out) in cases {
        let batch = make_decimal_batch(p1, s1, p2, s2);
        let old = build_old_expr(p1, s1, p2, s2, op, DataType::Decimal128(p_out, s_out));
        let fused = build_new_expr(wide_op, p_out, s_out);
        bench_case(&mut group, name, &batch, &old, &fused);
    }

    group.finish();
}

// Register the benchmark group and generate the binary's `main` entry point
// (this bench target sets `harness = false` in Cargo.toml, so criterion
// supplies the harness).
criterion_group!(benches, criterion_benchmark);
criterion_main!(benches);
3 changes: 2 additions & 1 deletion native/spark-expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,8 @@ pub use json_funcs::{FromJson, ToJson};
pub use math_funcs::{
create_modulo_expr, create_negate_expr, spark_ceil, spark_decimal_div,
spark_decimal_integral_div, spark_floor, spark_make_decimal, spark_round, spark_unhex,
spark_unscaled_value, CheckOverflow, NegativeExpr, NormalizeNaNAndZero,
spark_unscaled_value, CheckOverflow, DecimalRescaleCheckOverflow, NegativeExpr,
NormalizeNaNAndZero, WideDecimalBinaryExpr, WideDecimalOp,
};
pub use string_funcs::*;

Expand Down
Loading
Loading