Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions native/core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,11 @@ use super::{serde, utils::SparkArrowConvert};
use crate::{
errors::{try_unwrap_or_throw, CometError, CometResult},
execution::{
metrics::utils::update_comet_metric, planner::PhysicalPlanner, serde::to_arrow_datatype,
shuffle::spark_unsafe::row::process_sorted_row_partition, sort::RdxSort,
metrics::utils::{build_metric_layout, update_comet_metric, MetricLayout},
planner::PhysicalPlanner,
serde::to_arrow_datatype,
shuffle::spark_unsafe::row::process_sorted_row_partition,
sort::RdxSort,
},
jvm_bridge::{jni_new_global_ref, JVMClasses},
};
Expand Down Expand Up @@ -173,6 +176,8 @@ struct ExecutionContext {
pub memory_pool_config: MemoryPoolConfig,
/// Whether to log memory usage on each call to execute_plan
pub tracing_enabled: bool,
/// Pre-computed metric layout for flat array metric updates
pub metric_layout: Option<MetricLayout>,
}

/// Accept serialized query plan and return the address of the native query plan.
Expand Down Expand Up @@ -320,6 +325,7 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
explain_native,
memory_pool_config,
tracing_enabled,
metric_layout: None,
});

Ok(Box::into_raw(exec_context) as i64)
Expand Down Expand Up @@ -544,6 +550,10 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
exec_context.root_op = Some(Arc::clone(&root_op));
exec_context.scans = scans;

// Build the flat metric layout for efficient metric updates
let metrics = exec_context.metrics.as_obj();
exec_context.metric_layout = Some(build_metric_layout(&mut env, metrics)?);

if exec_context.explain_native {
let formatted_plan_str =
DisplayableExecutionPlan::new(root_op.native_plan.as_ref()).indent(true);
Expand Down Expand Up @@ -675,9 +685,14 @@ pub extern "system" fn Java_org_apache_comet_Native_releasePlan(

/// Updates the metrics of the query plan.
fn update_metrics(env: &mut JNIEnv, exec_context: &mut ExecutionContext) -> CometResult<()> {
if let Some(native_query) = &exec_context.root_op {
if let Some(ref native_query) = exec_context.root_op {
let native_query = Arc::clone(native_query);
let metrics = exec_context.metrics.as_obj();
update_comet_metric(env, metrics, native_query)
if let Some(ref mut layout) = exec_context.metric_layout {
update_comet_metric(env, metrics, &native_query, layout)
} else {
Ok(())
}
} else {
Ok(())
}
Expand Down
150 changes: 114 additions & 36 deletions native/core/src/execution/metrics/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,94 @@
use crate::execution::spark_plan::SparkPlan;
use crate::{errors::CometError, jvm_bridge::jni_call};
use datafusion::physical_plan::metrics::MetricValue;
use datafusion_comet_proto::spark_metric::NativeMetricNode;
use jni::{objects::JObject, JNIEnv};
use prost::Message;
use jni::objects::{GlobalRef, JIntArray, JLongArray, JObject, JObjectArray};
use jni::JNIEnv;
use std::collections::HashMap;
use std::sync::Arc;

/// Updates the metrics of a CometMetricNode. This function is called recursively to
/// update the metrics of all the children nodes. The metrics are pulled from the
/// native execution plan and pushed to the Java side through JNI.
pub(crate) fn update_comet_metric(
/// Pre-computed layout mapping metric names to slots in one flat `i64` array.
///
/// Constructed by `build_metric_layout` and then passed by `&mut` to
/// `update_comet_metric` on every metric update, so the name -> index mapping
/// is computed once instead of on each update.
pub(crate) struct MetricLayout {
/// One map per plan node, in the order `fill_metric_values` visits the
/// `SparkPlan` tree (node first, then its children). Each map takes a metric
/// name to its index in `values`.
node_indices: Vec<HashMap<String, usize>>,
/// Flat buffer of metric values, one slot per JVM-reported metric name.
/// Every slot starts at -1 and is overwritten when a native metric with a
/// matching name is found; the whole buffer is copied to the JVM in one call.
values: Vec<i64>,
/// Global reference to the JVM `long[]` returned by `getValuesArray()`.
/// Held for as long as this layout is alive so the array can be written on
/// each update.
jarray: Arc<GlobalRef>,
}

/// Builds a MetricLayout by calling JNI methods on the CometMetricNode to retrieve
/// the flattened metric names, node offsets, and a reference to the pre-allocated long[].
///
/// The JVM object is queried three times: `getMetricNames()` (`String[]`),
/// `getNodeOffsets()` (`int[]`) and `getValuesArray()` (`long[]`). Node `i` owns
/// the metric names at indices `offsets[i]..offsets[i + 1]`, so the offsets
/// array is read as having one more entry than there are nodes.
///
/// # Errors
///
/// Any failing JNI call (method invocation, array access, string conversion,
/// global-ref creation) is propagated to the caller as a `CometError`.
pub(crate) fn build_metric_layout(
env: &mut JNIEnv,
metric_node: &JObject,
spark_plan: &Arc<SparkPlan>,
) -> Result<(), CometError> {
if metric_node.is_null() {
return Ok(());
) -> Result<MetricLayout, CometError> {
// Get metric names array (String[])
let names_obj: JObject =
unsafe { jni_call!(env, comet_metric_node(metric_node).get_metric_names() -> JObject) }?;
let names_array = JObjectArray::from(names_obj);
let num_metrics = env.get_array_length(&names_array)? as usize;

// Copy every Java string into an owned Rust `String`; the index of a name in
// `metric_names` is also its slot in the flat values array.
let mut metric_names = Vec::with_capacity(num_metrics);
for i in 0..num_metrics {
// NOTE(review): each iteration obtains a new JNI local reference that is not
// released inside the loop. For plans with many metrics this may exhaust the
// local-reference table -- consider a local frame / auto-local; TODO confirm.
let jstr = env.get_object_array_element(&names_array, i as i32)?;
let name: String = env.get_string((&jstr).into())?.into();
metric_names.push(name);
}

let native_metric = to_native_metric_node(spark_plan);
let jbytes = env.byte_array_from_slice(&native_metric?.encode_to_vec())?;
// Get node offsets array (int[])
let offsets_obj: JObject =
unsafe { jni_call!(env, comet_metric_node(metric_node).get_node_offsets() -> JObject) }?;
let offsets_array = JIntArray::from(offsets_obj);
let num_offsets = env.get_array_length(&offsets_array)? as usize;
let mut offsets = vec![0i32; num_offsets];
env.get_int_array_region(&offsets_array, 0, &mut offsets)?;

// Get values array reference (long[])
// Promoted to a global reference so it stays valid after this call returns.
let values_obj: JObject =
unsafe { jni_call!(env, comet_metric_node(metric_node).get_values_array() -> JObject) }?;
let jarray = Arc::new(env.new_global_ref(values_obj)?);

unsafe { jni_call!(env, comet_metric_node(metric_node).set_all_from_bytes(&jbytes) -> ()) }
// Build per-node index maps
// NOTE(review): `num_offsets - 1` underflows `usize` if the JVM returns an empty
// offsets array (panic in debug builds, wrap-around in release). Presumably the
// JVM always returns at least one entry -- confirm, or use `saturating_sub(1)`.
let num_nodes = num_offsets - 1;
let mut node_indices = Vec::with_capacity(num_nodes);
for node_idx in 0..num_nodes {
// NOTE(review): offsets are cast from `i32` to `usize` without validation. A
// negative or non-monotonic offset would make `end - start` underflow; this
// assumes the JVM side produces non-negative, non-decreasing offsets -- verify.
let start = offsets[node_idx] as usize;
let end = offsets[node_idx + 1] as usize;
let mut map = HashMap::with_capacity(end - start);
// Walk exactly the names in `start..end`; `i` is the global index into the
// flat values array, not an index local to this node.
for (i, name) in metric_names.iter().enumerate().take(end).skip(start) {
map.insert(name.clone(), i);
}
node_indices.push(map);
}

// Every slot starts at -1 until a native metric value is written into it.
Ok(MetricLayout {
node_indices,
values: vec![-1i64; num_metrics],
jarray,
})
}

pub(crate) fn to_native_metric_node(
/// Recursively fills the values array from DataFusion metrics on the SparkPlan tree.
fn fill_metric_values(
spark_plan: &Arc<SparkPlan>,
) -> Result<NativeMetricNode, CometError> {
let mut native_metric_node = NativeMetricNode {
metrics: HashMap::new(),
children: Vec::new(),
};
layout: &mut MetricLayout,
node_idx: &mut usize,
) {
let current_node = *node_idx;
*node_idx += 1;

if current_node >= layout.node_indices.len() {
// Skip if node index exceeds layout (shouldn't happen with correct setup)
for child in spark_plan.children() {
fill_metric_values(child, layout, node_idx);
}
return;
}

let indices = &layout.node_indices[current_node];

// Collect metrics from the native plan (and additional plans)
let node_metrics = if spark_plan.additional_native_plans.is_empty() {
spark_plan.native_plan.metrics()
} else {
Expand All @@ -59,7 +115,7 @@ pub(crate) fn to_native_metric_node(
for c in additional_metrics.iter() {
match c.value() {
MetricValue::OutputRows(_) => {
// we do not want to double count output rows
// do not double count output rows
}
_ => metrics.push(c.to_owned()),
}
Expand All @@ -68,21 +124,43 @@ pub(crate) fn to_native_metric_node(
Some(metrics.aggregate_by_name())
};

// add metrics
node_metrics
.unwrap_or_default()
.iter()
.map(|m| m.value())
.map(|m| (m.name(), m.as_usize() as i64))
.for_each(|(name, value)| {
native_metric_node.metrics.insert(name.to_string(), value);
});

// add children
for child_plan in spark_plan.children() {
let child_node = to_native_metric_node(child_plan)?;
native_metric_node.children.push(child_node);
// Write metric values into their pre-assigned slots
if let Some(metrics) = node_metrics {
for m in metrics.iter() {
let value = m.value();
let name = value.name();
if let Some(&idx) = indices.get(name) {
layout.values[idx] = value.as_usize() as i64;
}
}
}

// Recurse into children
for child in spark_plan.children() {
fill_metric_values(child, layout, node_idx);
}
}

/// Updates metrics by filling the flat values array and bulk-copying to JVM.
///
/// Steps: walk `spark_plan` with `fill_metric_values` to refresh `layout.values`,
/// copy the whole buffer into the JVM `long[]` referenced by `layout`, then call
/// `updateFromValues()` on `metric_node`.
///
/// Returns `Ok(())` without doing anything when `metric_node` is a null reference.
///
/// # Errors
///
/// Propagates any JNI failure from creating the local reference, writing the
/// array region, or invoking `updateFromValues()`.
pub(crate) fn update_comet_metric(
env: &mut JNIEnv,
metric_node: &JObject,
spark_plan: &Arc<SparkPlan>,
layout: &mut MetricLayout,
) -> Result<(), CometError> {
if metric_node.is_null() {
return Ok(());
}

Ok(native_metric_node)
// Fill values from native metrics
// The counter numbers plan nodes in visit order, starting at the root (0).
let mut node_idx = 0;
fill_metric_values(spark_plan, layout, &mut node_idx);

// Bulk copy values to JVM long[] via SetLongArrayRegion
// NOTE(review): this writes `layout.values.len()` elements starting at index 0,
// so it assumes the JVM array is at least as long as the metric-name array the
// layout was built from -- TODO confirm on the JVM side.
let local_ref = env.new_local_ref(layout.jarray.as_obj())?;
let jlong_array = JLongArray::from(local_ref);
env.set_long_array_region(&jlong_array, 0, &layout.values)?;

// Call updateFromValues() on the JVM side
unsafe { jni_call!(env, comet_metric_node(metric_node).update_from_values() -> ()) }
}
44 changes: 28 additions & 16 deletions native/core/src/jvm_bridge/comet_metric_node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ use jni::{
#[allow(dead_code)] // we need to keep references to Java items to prevent GC
pub struct CometMetricNode<'a> {
pub class: JClass<'a>,
pub method_get_child_node: JMethodID,
pub method_get_child_node_ret: ReturnType,
pub method_set: JMethodID,
pub method_set_ret: ReturnType,
pub method_set_all_from_bytes: JMethodID,
pub method_set_all_from_bytes_ret: ReturnType,
pub method_get_metric_names: JMethodID,
pub method_get_metric_names_ret: ReturnType,
pub method_get_node_offsets: JMethodID,
pub method_get_node_offsets_ret: ReturnType,
pub method_get_values_array: JMethodID,
pub method_get_values_array_ret: ReturnType,
pub method_update_from_values: JMethodID,
pub method_update_from_values_ret: ReturnType,
}

impl<'a> CometMetricNode<'a> {
Expand All @@ -41,20 +43,30 @@ impl<'a> CometMetricNode<'a> {
let class = env.find_class(Self::JVM_CLASS)?;

Ok(CometMetricNode {
method_get_child_node: env.get_method_id(
method_get_metric_names: env.get_method_id(
Self::JVM_CLASS,
"getChildNode",
format!("(I)L{:};", Self::JVM_CLASS).as_str(),
"getMetricNames",
"()[Ljava/lang/String;",
)?,
method_get_child_node_ret: ReturnType::Object,
method_set: env.get_method_id(Self::JVM_CLASS, "set", "(Ljava/lang/String;J)V")?,
method_set_ret: ReturnType::Primitive(Primitive::Void),
method_set_all_from_bytes: env.get_method_id(
method_get_metric_names_ret: ReturnType::Object,
method_get_node_offsets: env.get_method_id(
Self::JVM_CLASS,
"set_all_from_bytes",
"([B)V",
"getNodeOffsets",
"()[I",
)?,
method_set_all_from_bytes_ret: ReturnType::Primitive(Primitive::Void),
method_get_node_offsets_ret: ReturnType::Object,
method_get_values_array: env.get_method_id(
Self::JVM_CLASS,
"getValuesArray",
"()[J",
)?,
method_get_values_array_ret: ReturnType::Object,
method_update_from_values: env.get_method_id(
Self::JVM_CLASS,
"updateFromValues",
"()V",
)?,
method_update_from_values_ret: ReturnType::Primitive(Primitive::Void),
class,
})
}
Expand Down
10 changes: 8 additions & 2 deletions native/core/src/parquet/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ use jni::{

use self::util::jni::TypePromotionInfo;
use crate::execution::jni_api::get_runtime;
use crate::execution::metrics::utils::update_comet_metric;
use crate::execution::metrics::utils::{build_metric_layout, update_comet_metric, MetricLayout};
use crate::execution::operators::ExecutionError;
use crate::execution::planner::PhysicalPlanner;
use crate::execution::serde;
Expand Down Expand Up @@ -605,6 +605,7 @@ enum ParquetReaderState {
struct BatchContext {
native_plan: Arc<SparkPlan>,
metrics_node: Arc<GlobalRef>,
metric_layout: MetricLayout,
batch_stream: Option<SendableRecordBatchStream>,
current_batch: Option<RecordBatch>,
reader_state: ParquetReaderState,
Expand Down Expand Up @@ -780,9 +781,13 @@ pub unsafe extern "system" fn Java_org_apache_comet_parquet_Native_initRecordBat
let partition_index: usize = 0;
let batch_stream = scan.execute(partition_index, session_ctx.task_ctx())?;

let metrics_global_ref = Arc::new(jni_new_global_ref!(env, metrics_node)?);
let metric_layout = build_metric_layout(&mut env, metrics_global_ref.as_obj())?;

let ctx = BatchContext {
native_plan: Arc::new(SparkPlan::new(0, scan, vec![])),
metrics_node: Arc::new(jni_new_global_ref!(env, metrics_node)?),
metrics_node: metrics_global_ref,
metric_layout,
batch_stream: Some(batch_stream),
current_batch: None,
reader_state: ParquetReaderState::Init,
Expand Down Expand Up @@ -825,6 +830,7 @@ pub extern "system" fn Java_org_apache_comet_parquet_Native_readNextRecordBatch(
&mut env,
context.metrics_node.as_obj(),
&context.native_plan,
&mut context.metric_layout,
)?;

context.current_batch = None;
Expand Down
1 change: 0 additions & 1 deletion native/proto/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ fn main() -> Result<()> {
prost_build::Config::new().out_dir(out_dir).compile_protos(
&[
"src/proto/expr.proto",
"src/proto/metric.proto",
"src/proto/partitioning.proto",
"src/proto/operator.proto",
"src/proto/config.proto",
Expand Down
6 changes: 0 additions & 6 deletions native/proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,12 +39,6 @@ pub mod spark_operator {
include!(concat!("generated", "/spark.spark_operator.rs"));
}

// Include generated modules from .proto files.
#[allow(missing_docs)]
pub mod spark_metric {
include!(concat!("generated", "/spark.spark_metric.rs"));
}

// Include generated modules from .proto files.
#[allow(missing_docs)]
pub mod spark_config {
Expand Down
29 changes: 0 additions & 29 deletions native/proto/src/proto/metric.proto

This file was deleted.

Loading
Loading