Commit c5c80e2 (parent: 58dee73)

feat: rand expression support
File tree: 7 files changed (+296 -2 lines)

native/core/src/execution/jni_api.rs

Lines changed: 1 addition & 1 deletion
@@ -317,7 +317,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
     // query plan, we need to defer stream initialization to first time execution.
     if exec_context.root_op.is_none() {
         let start = Instant::now();
-        let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx))
+        let planner = PhysicalPlanner::new(Arc::clone(&exec_context.session_ctx), partition)
             .with_exec_id(exec_context_id);
         let (scans, root_op) = planner.create_plan(
             &exec_context.spark_plan,
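This change threads the executing partition's index from the JNI layer into the planner. That matters for `rand`: Spark seeds each partition's generator with the user seed plus the partition index, so the native side must know which partition it is planning for. A minimal sketch of that seeding rule (the helper name `partition_seed` is hypothetical; the real shift is applied with the same `wrapping_add` in `RandExpr::evaluate_batch`, shown later in this commit):

// Hypothetical helper: Spark-style per-partition seeding offsets the
// user-provided seed by the partition index, wrapping on overflow.
fn partition_seed(user_seed: i64, partition: i32) -> i64 {
    user_seed.wrapping_add(partition as i64)
}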

native/core/src/execution/planner.rs

Lines changed: 10 additions & 1 deletion
@@ -82,6 +82,7 @@ use datafusion_comet_proto::{
     },
     spark_partitioning::{partitioning::PartitioningStruct, Partitioning as SparkPartitioning},
 };
+use datafusion_comet_spark_expr::rand::RandExpr;
 use datafusion_comet_spark_expr::{
     ArrayInsert, Avg, AvgDecimal, BitwiseNotExpr, Cast, CheckOverflow, Contains, Correlation,
     Covariance, CreateNamedStruct, DateTruncExpr, EndsWith, GetArrayStructFields, GetStructField,
@@ -127,6 +128,7 @@ pub const TEST_EXEC_CONTEXT_ID: i64 = -1;
 pub struct PhysicalPlanner {
     // The execution context id of this planner.
     exec_context_id: i64,
+    partition: i32,
     execution_props: ExecutionProps,
     session_ctx: Arc<SessionContext>,
 }
@@ -137,17 +139,19 @@ impl Default for PhysicalPlanner {
         let execution_props = ExecutionProps::new();
         Self {
             exec_context_id: TEST_EXEC_CONTEXT_ID,
+            partition: 0,
             execution_props,
             session_ctx,
         }
     }
 }
 
 impl PhysicalPlanner {
-    pub fn new(session_ctx: Arc<SessionContext>) -> Self {
+    pub fn new(session_ctx: Arc<SessionContext>, partition: i32) -> Self {
         let execution_props = ExecutionProps::new();
         Self {
             exec_context_id: TEST_EXEC_CONTEXT_ID,
+            partition,
             execution_props,
             session_ctx,
         }
@@ -156,6 +160,7 @@ impl PhysicalPlanner {
     pub fn with_exec_id(self, exec_context_id: i64) -> Self {
         Self {
             exec_context_id,
+            partition: self.partition,
             execution_props: self.execution_props,
             session_ctx: Arc::clone(&self.session_ctx),
         }
@@ -720,6 +725,10 @@ impl PhysicalPlanner {
                 expr.legacy_negative_index,
             )))
         }
+        ExprStruct::Rand(expr) => {
+            let child = self.create_expr(expr.child.as_ref().unwrap(), input_schema)?;
+            Ok(Arc::new(RandExpr::new(child, self.partition)))
+        }
         expr => Err(ExecutionError::GeneralError(format!(
             "Not implemented: {:?}",
             expr
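The planner now stores the partition index and hands it to `RandExpr` as the seed shift when lowering the protobuf `Rand` node. A short usage sketch under stated assumptions (it would live inside this planner module, where `PhysicalPlanner` is defined; `SessionContext` comes from DataFusion; values are illustrative):

use datafusion::prelude::SessionContext;
use std::sync::Arc;

fn build_partition_planners() {
    let session_ctx = Arc::new(SessionContext::new());
    // The partition index passed here becomes RandExpr's seed shift
    // when ExprStruct::Rand is lowered in create_expr, so planners for
    // different partitions yield partition-distinct random streams.
    let _planner_p0 = PhysicalPlanner::new(Arc::clone(&session_ctx), 0);
    let _planner_p3 = PhysicalPlanner::new(Arc::clone(&session_ctx), 3).with_exec_id(42);
}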

native/proto/src/proto/expr.proto

Lines changed: 1 addition & 0 deletions
@@ -84,6 +84,7 @@ message Expr {
     GetArrayStructFields get_array_struct_fields = 57;
     BinaryExpr array_append = 58;
     ArrayInsert array_insert = 59;
+    UnaryExpr rand = 60;
   }
 }

native/spark-expr/src/lib.rs

Lines changed: 2 additions & 0 deletions
@@ -66,6 +66,8 @@ pub use normalize_nan::NormalizeNaNAndZero;
 mod variance;
 pub use variance::Variance;
 mod comet_scalar_funcs;
+pub mod rand;
+
 pub use cast::{spark_cast, Cast, SparkCastOptions};
 pub use comet_scalar_funcs::create_comet_physical_fun;
 pub use error::{SparkError, SparkResult};
native/spark-expr/src/rand.rs

Lines changed: 261 additions & 0 deletions
@@ -0,0 +1,261 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use crate::spark_hash::spark_compatible_murmur3_hash;
use arrow_array::builder::Float64Builder;
use arrow_array::{Float64Array, RecordBatch};
use arrow_schema::{DataType, Schema};
use datafusion::logical_expr::ColumnarValue;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_expr_common::physical_expr::down_cast_any_ref;
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use std::any::Any;
use std::fmt::Display;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex};

// 2^-53: scales a 53-bit integer into the unit interval [0, 1).
const DOUBLE_UNIT: f64 = 1.1102230246251565e-16;
// Seed constant used by Spark's XORShiftRandom when hashing the init seed.
const SPARK_MURMUR_ARRAY_SEED: u32 = 0x3c074a61;

#[derive(Debug, Clone)]
struct XorShiftRandom {
    seed: i64,
}

impl XorShiftRandom {
    fn from_init_seed(init_seed: i64) -> Self {
        XorShiftRandom {
            seed: Self::init_seed(init_seed),
        }
    }

    fn from_stored_seed(stored_seed: i64) -> Self {
        XorShiftRandom { seed: stored_seed }
    }

    // One xorshift step; returns the low `bits` bits of the updated seed.
    fn next(&mut self, bits: u8) -> i32 {
        let mut next_seed = self.seed ^ (self.seed << 21);
        next_seed ^= ((next_seed as u64) >> 35) as i64;
        next_seed ^= next_seed << 4;
        self.seed = next_seed;
        (next_seed & ((1i64 << bits) - 1)) as i32
    }

    pub fn next_f64(&mut self) -> f64 {
        let a = self.next(26) as i64;
        let b = self.next(27) as i64;
        ((a << 27) + b) as f64 * DOUBLE_UNIT
    }

    // Spread the entropy of the initial seed via murmur3, matching Spark.
    fn init_seed(init: i64) -> i64 {
        let bytes_repr = init.to_be_bytes();
        let low_bits = spark_compatible_murmur3_hash(&bytes_repr, SPARK_MURMUR_ARRAY_SEED);
        let high_bits = spark_compatible_murmur3_hash(&bytes_repr, low_bits);
        ((high_bits as i64) << 32) | (low_bits as i64 & 0xFFFFFFFFi64)
    }
}
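`next_f64` mirrors `java.util.Random::nextDouble`: 26 high bits and 27 low bits are concatenated into a 53-bit integer (a double's full mantissa width) and scaled by `DOUBLE_UNIT`. A standalone sanity check that `DOUBLE_UNIT` is exactly 2^-53 (nothing Comet-specific assumed):

// Sanity check: DOUBLE_UNIT == 1/2^53, so next_f64 stays in [0, 1).
fn main() {
    let double_unit: f64 = 1.1102230246251565e-16;
    // 1 / 2^53, computed exactly from the integer (a power of two).
    assert_eq!(double_unit, 1.0 / (1u64 << 53) as f64);
    // The largest 53-bit value scaled by DOUBLE_UNIT is still below 1.0.
    assert!(((1u64 << 53) - 1) as f64 * double_unit < 1.0);
}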
#[derive(Debug)]
pub struct RandExpr {
    seed: Arc<dyn PhysicalExpr>,
    init_seed_shift: i32,
    // Carries the generator state across batches within a partition.
    state_holder: Arc<Mutex<Option<i64>>>,
}

impl RandExpr {
    pub fn new(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Self {
        Self {
            seed,
            init_seed_shift,
            state_holder: Arc::new(Mutex::new(None::<i64>)),
        }
    }

    fn extract_init_state(seed: ScalarValue) -> Result<i64> {
        if let ScalarValue::Int64(Some(init_seed)) = seed.cast_to(&DataType::Int64)? {
            Ok(init_seed)
        } else {
            Err(DataFusionError::Internal(
                "unexpected execution branch".to_string(),
            ))
        }
    }

    fn evaluate_batch(&self, seed: ScalarValue, num_rows: usize) -> Result<ColumnarValue> {
        let mut seed_state = self.state_holder.lock().unwrap();
        // First batch: derive the initial seed from the literal plus the
        // partition shift. Later batches: resume from the stored state.
        let mut rnd = if seed_state.is_none() {
            let init_seed = RandExpr::extract_init_state(seed)?;
            let init_seed = init_seed.wrapping_add(self.init_seed_shift as i64);
            *seed_state = Some(init_seed);
            XorShiftRandom::from_init_seed(init_seed)
        } else {
            let stored_seed = seed_state.unwrap();
            XorShiftRandom::from_stored_seed(stored_seed)
        };

        let mut arr_builder = Float64Builder::with_capacity(num_rows);
        std::iter::repeat_with(|| rnd.next_f64())
            .take(num_rows)
            .for_each(|v| arr_builder.append_value(v));
        let array_ref = Arc::new(Float64Array::from(arr_builder.finish()));
        *seed_state = Some(rnd.seed);
        Ok(ColumnarValue::Array(array_ref))
    }
}
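Because `evaluate_batch` writes `rnd.seed` back into the holder after every batch, a partition fed many small batches emits the same stream as one large batch; the multi-batch test below exercises this end to end. The same invariant at the generator level, as a sketch that would have to live inside this module (`XorShiftRandom` and its `seed` field are private):

// Five straight draws equal 2 + 3 draws with a store/restore in between.
fn chunked_draws_match() {
    let mut whole = XorShiftRandom::from_init_seed(42);
    let expected: Vec<f64> = (0..5).map(|_| whole.next_f64()).collect();

    let mut first = XorShiftRandom::from_init_seed(42);
    let mut actual: Vec<f64> = (0..2).map(|_| first.next_f64()).collect();
    let mut second = XorShiftRandom::from_stored_seed(first.seed);
    actual.extend((0..3).map(|_| second.next_f64()));

    assert_eq!(expected, actual);
}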
impl Display for RandExpr {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(f, "RAND({})", self.seed)
    }
}

impl PartialEq<dyn Any> for RandExpr {
    fn eq(&self, other: &dyn Any) -> bool {
        down_cast_any_ref(other)
            .downcast_ref::<Self>()
            .map(|x| self.seed.eq(&x.seed))
            .unwrap_or(false)
    }
}

impl PhysicalExpr for RandExpr {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn data_type(&self, _input_schema: &Schema) -> Result<DataType> {
        Ok(DataType::Float64)
    }

    fn nullable(&self, _input_schema: &Schema) -> Result<bool> {
        Ok(false)
    }

    fn evaluate(&self, batch: &RecordBatch) -> Result<ColumnarValue> {
        match self.seed.evaluate(batch)? {
            ColumnarValue::Scalar(seed) => self.evaluate_batch(seed, batch.num_rows()),
            ColumnarValue::Array(_arr) => Err(DataFusionError::NotImplemented(format!(
                "Only literal seeds are supported for {}",
                self
            ))),
        }
    }

    fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
        vec![&self.seed]
    }

    fn with_new_children(
        self: Arc<Self>,
        children: Vec<Arc<dyn PhysicalExpr>>,
    ) -> Result<Arc<dyn PhysicalExpr>> {
        // Note: rebuilding with new children resets the stored RNG state.
        Ok(Arc::new(RandExpr::new(
            Arc::clone(&children[0]),
            self.init_seed_shift,
        )))
    }

    fn dyn_hash(&self, state: &mut dyn Hasher) {
        let mut s = state;
        self.children().hash(&mut s);
    }
}

pub fn rand(seed: Arc<dyn PhysicalExpr>, init_seed_shift: i32) -> Result<Arc<dyn PhysicalExpr>> {
    Ok(Arc::new(RandExpr::new(seed, init_seed_shift)))
}
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::{array::StringArray, compute::concat, datatypes::*};
    use arrow_array::{Array, BooleanArray, Float64Array, Int64Array};
    use datafusion_common::cast::as_float64_array;
    use datafusion_physical_expr::expressions::lit;

    // First five doubles produced by Spark's rand() with seed 42, partition 0.
    const SPARK_SEED_42_FIRST_5: [f64; 5] = [
        0.619189370225301,
        0.5096018842446481,
        0.8325259388871524,
        0.26322809041172357,
        0.6702867696264135,
    ];

    #[test]
    fn test_rand_single_batch() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]);
        let data = StringArray::from(vec![Some("foo"), None, None, Some("bar"), None]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
        let rand_expr = rand(lit(42), 0)?;
        let result = rand_expr.evaluate(&batch)?.into_array(batch.num_rows())?;
        let result = as_float64_array(&result)?;
        let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5));
        assert_eq!(result, expected);
        Ok(())
    }

    #[test]
    fn test_rand_multi_batch() -> Result<()> {
        let first_batch_schema = Schema::new(vec![Field::new("a", DataType::Int64, true)]);
        let first_batch_data = Int64Array::from(vec![Some(42), None]);
        let second_batch_schema = first_batch_schema.clone();
        let second_batch_data = Int64Array::from(vec![None, Some(-42), None]);
        let rand_expr = rand(lit(42), 0)?;
        let first_batch = RecordBatch::try_new(
            Arc::new(first_batch_schema),
            vec![Arc::new(first_batch_data)],
        )?;
        let first_batch_result = rand_expr
            .evaluate(&first_batch)?
            .into_array(first_batch.num_rows())?;
        let second_batch = RecordBatch::try_new(
            Arc::new(second_batch_schema),
            vec![Arc::new(second_batch_data)],
        )?;
        let second_batch_result = rand_expr
            .evaluate(&second_batch)?
            .into_array(second_batch.num_rows())?;
        let result_arrays: Vec<&dyn Array> = vec![
            as_float64_array(&first_batch_result)?,
            as_float64_array(&second_batch_result)?,
        ];
        let result_arrays = &concat(&result_arrays)?;
        let final_result = as_float64_array(result_arrays)?;
        let expected = &Float64Array::from(Vec::from(SPARK_SEED_42_FIRST_5));
        assert_eq!(final_result, expected);
        Ok(())
    }

    #[test]
    fn test_overflow_shift_seed() -> Result<()> {
        let schema = Schema::new(vec![Field::new("a", DataType::Boolean, false)]);
        let data = BooleanArray::from(vec![Some(true), Some(false)]);
        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(data)])?;
        // i64::MAX + shift 1 wraps to i64::MIN, so both exprs share a stream.
        let max_seed_and_shift_expr = rand(lit(i64::MAX), 1)?;
        let min_seed_no_shift_expr = rand(lit(i64::MIN), 0)?;
        let first_expr_result = max_seed_and_shift_expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())?;
        let first_expr_result = as_float64_array(&first_expr_result)?;
        let second_expr_result = min_seed_no_shift_expr
            .evaluate(&batch)?
            .into_array(batch.num_rows())?;
        let second_expr_result = as_float64_array(&second_expr_result)?;
        assert_eq!(first_expr_result, second_expr_result);
        Ok(())
    }
}

spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala

Lines changed: 4 additions & 0 deletions
@@ -2272,6 +2272,10 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           expr.children(1),
           inputs,
           (builder, binaryExpr) => builder.setArrayAppend(binaryExpr))
+
+      case Rand(child, _) =>
+        createUnaryExpr(child, inputs, (builder, unaryExpr) => builder.setRand(unaryExpr))
+
       case _ =>
         withInfo(expr, s"${expr.prettyName} is not supported", expr.children: _*)
         None

spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala

Lines changed: 17 additions & 0 deletions
@@ -2517,4 +2517,21 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
       checkSparkAnswer(df.select("arrUnsupportedArgs"))
     }
   }
+
+  test("rand expression with random parameters") {
+    val partitionsNumber = Random.nextInt(10) + 1
+    val rowsNumber = Random.nextInt(500)
+    val seed = Random.nextLong()
+    // use this value to have both single-batch and multi-batch partitions
+    val cometBatchSize = math.max(1, math.ceil(rowsNumber.toDouble / partitionsNumber).toInt)
+    withSQLConf("spark.comet.batchSize" -> cometBatchSize.toString) {
+      withParquetDataFrame((0 until rowsNumber).map(Tuple1.apply)) { df =>
+        val dfWithRandParameters = df.repartition(partitionsNumber).withColumn("rnd", rand(seed))
+        checkSparkAnswer(dfWithRandParameters)
+        val dfWithOverflowSeed =
+          df.repartition(partitionsNumber).withColumn("rnd", rand(Long.MaxValue))
+        checkSparkAnswer(dfWithOverflowSeed)
+      }
+    }
+  }
 }
