Skip to content

Commit 55aa92e

Browse files
committed
Refactor
1 parent cce1f05 commit 55aa92e

File tree

1 file changed

+43
-16
lines changed

1 file changed

+43
-16
lines changed

parquet/benches/arrow_reader_clickbench.rs

Lines changed: 43 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
//! [ClickBench]: https://benchmark.clickhouse.com/
2828
2929
use arrow::compute::kernels::cmp::{eq, neq};
30-
use arrow::compute::{like, nlike};
30+
use arrow::compute::{like, nlike, or};
3131
use arrow_array::types::{Int16Type, Int32Type, Int64Type};
3232
use arrow_array::{ArrayRef, ArrowPrimitiveType, BooleanArray, PrimitiveArray, StringViewArray};
3333
use arrow_schema::{ArrowError, DataType, Schema};
@@ -96,9 +96,7 @@ struct Query {
9696
filter_columns: Vec<&'static str>,
9797
/// Which columns will be projected (decoded after filter)
9898
projection_columns: Vec<&'static str>,
99-
/// Returns a Vec of `RunPredicateFn` that filter the data. The
100-
/// `RecordBatch` passed to the fn has the columns specified in
101-
/// `filter_columns`
99+
/// Predicates to apply
102100
predicates: Vec<ClickBenchPredicate>,
103101
/// How many rows are expected to pass the predicate. This serves
104102
/// as a sanity check that the benchmark is working correctly.
@@ -420,8 +418,7 @@ fn all_queries() -> Vec<Query> {
420418
// ClickBenchPredicate::gt_eq_literal::<Int16Type>(1, str_to_i16_date("2013-07-01")),
421419
// ClickBenchPredicate::lt_eq_literal::<Int16Type>(1, str_to_i16_date("2013-07-31")),
422420
ClickBenchPredicate::eq_literal::<Int16Type>(2, 0),
423-
// TODO implement IN predicate
424-
ClickBenchPredicate::eq_literal::<Int16Type>(3, -1), // IN -1, 6
421+
ClickBenchPredicate::in_list::<Int16Type>(3, (-1, 6)), // IN -1, 6
425422
ClickBenchPredicate::eq_literal::<Int64Type>(4, 3594120000172545465),
426423
],
427424
expected_row_count: 24793,
@@ -512,6 +509,23 @@ impl ClickBenchPredicate {
512509
})
513510
}
514511

512+
/// Create predicate: col IN (literal_values.0, literal_values.1)
513+
fn in_list<T: ArrowPrimitiveType>(
514+
column_index: usize,
515+
literal_values: (T::Native, T::Native),
516+
) -> Self {
517+
Self::new(column_index, move || {
518+
let literal_1 = PrimitiveArray::<T>::new_scalar(literal_values.0);
519+
let literal_2 = PrimitiveArray::<T>::new_scalar(literal_values.1);
520+
Box::new(move |col| {
521+
// Evaluate `col IN (a, b)` as `(col = a) OR (col = b)`
522+
let match1 = eq(&col, &literal_1)?;
523+
let match2 = eq(&col, &literal_2)?;
524+
or(&match1, &match2)
525+
})
526+
})
527+
}
528+
515529
/// Create predicate: col != literal_value
516530
fn neq_literal<T: ArrowPrimitiveType>(column_index: usize, literal_value: T::Native) -> Self {
517531
Self::new(column_index, move || {
@@ -660,19 +674,33 @@ impl FilterIndices {
660674

661675
/// Encapsulates the test parameters for a single benchmark
662676
struct ReadTest {
677+
/// Human identifiable name
678+
name: &'static str,
679+
/// Metadata from Parquet file
663680
arrow_reader_metadata: ArrowReaderMetadata,
664-
// TODO keep only fields needed (inline Query field)
665-
query: Query,
666681
/// Which columns in the file should be projected (decoded after filter)
667682
projection_mask: ProjectionMask,
668683
/// Which columns in the file should be passed to the filter.
669684
filter_mask: ProjectionMask,
670685
/// Mapping from column selected in filter mask to Query::filter_columns
671686
filter_indices: FilterIndices,
687+
/// Predicates to apply
688+
predicates: Vec<ClickBenchPredicate>,
689+
/// How many rows are expected to pass the predicate. This serves
690+
/// as a sanity check that the benchmark is working correctly.
691+
expected_row_count: usize,
672692
}
673693

674694
impl ReadTest {
675695
fn new(query: Query) -> Self {
696+
let Query {
697+
name,
698+
filter_columns,
699+
projection_columns,
700+
predicates,
701+
expected_row_count,
702+
} = query;
703+
676704
let arrow_reader_metadata = load_metadata(hits_1());
677705
let schema_descr = arrow_reader_metadata
678706
.metadata()
@@ -685,27 +713,27 @@ impl ReadTest {
685713
// Determine the correct selection ("ProjectionMask")
686714
//ProjectionMask::columns(schema, projection_columns)
687715

688-
let projection_columns = &query.projection_columns;
689716
let projection_mask = if projection_columns.iter().any(|&name| name == "*") {
690717
// * means all columns
691718
ProjectionMask::all()
692719
} else {
693-
let projection_schema_indices = column_indices(schema_descr, &query.projection_columns);
720+
let projection_schema_indices = column_indices(schema_descr, &projection_columns);
694721
ProjectionMask::leaves(schema_descr, projection_schema_indices)
695722
};
696723

697-
let filter_columns = &query.filter_columns;
698-
let filter_schema_indices = column_indices(schema_descr, filter_columns);
724+
let filter_schema_indices = column_indices(schema_descr, &filter_columns);
699725
let filter_mask =
700726
ProjectionMask::leaves(schema_descr, filter_schema_indices.iter().cloned());
701727
let filter_indices = FilterIndices::new(schema_descr, filter_schema_indices);
702728

703729
Self {
730+
name,
704731
arrow_reader_metadata,
705-
query,
706732
projection_mask,
707733
filter_mask,
708734
filter_indices,
735+
predicates,
736+
expected_row_count,
709737
}
710738
}
711739

@@ -774,7 +802,6 @@ impl ReadTest {
774802
//let run_predicate_fns = (self.query.predicate)();
775803
// Convert the predicates to ArrowPredicateFn to conform to the RowFilter API
776804
let arrow_predicates: Vec<_> = self
777-
.query
778805
.predicates
779806
.iter()
780807
.map(|pred| {
@@ -795,11 +822,11 @@ impl ReadTest {
795822
}
796823

797824
fn check_row_count(&self, row_count: usize) {
798-
let expected_row_count = self.query.expected_row_count;
825+
let expected_row_count = self.expected_row_count;
799826
assert_eq!(
800827
row_count, expected_row_count,
801828
"Expected {} rows, but got {} in {}",
802-
expected_row_count, row_count, self.query,
829+
expected_row_count, row_count, self.name,
803830
);
804831
}
805832
}

0 commit comments

Comments
 (0)