Skip to content

Improve benchmark for ltrim #12513

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
223 changes: 206 additions & 17 deletions datafusion/functions/benches/ltrim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,32 +17,221 @@

extern crate criterion;

use arrow::array::{ArrayRef, StringArray};
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use arrow::array::{ArrayRef, LargeStringArray, StringArray, StringViewArray};
use criterion::{
black_box, criterion_group, criterion_main, measurement::Measurement, BenchmarkGroup,
Criterion, SamplingMode,
};
use datafusion_common::ScalarValue;
use datafusion_expr::ColumnarValue;
use datafusion_expr::{ColumnarValue, ScalarUDF};
use datafusion_functions::string;
use std::sync::Arc;
use rand::{distributions::Alphanumeric, rngs::StdRng, Rng, SeedableRng};
use std::{fmt, sync::Arc};

fn create_args(size: usize, characters: &str) -> Vec<ColumnarValue> {
let iter =
std::iter::repeat(format!("{}datafusion{}", characters, characters)).take(size);
let array = Arc::new(StringArray::from_iter_values(iter)) as ArrayRef;
pub fn seedable_rng() -> StdRng {
StdRng::seed_from_u64(42)
}

#[derive(Clone, Copy)]
pub enum StringArrayType {
Utf8View,
Utf8,
LargeUtf8,
}

impl fmt::Display for StringArrayType {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
StringArrayType::Utf8View => f.write_str("string_view"),
StringArrayType::Utf8 => f.write_str("string"),
StringArrayType::LargeUtf8 => f.write_str("large_string"),
}
}
}

/// returns an array of strings, and `characters` as a ScalarValue
pub fn create_string_array_and_characters(
size: usize,
characters: &str,
trimmed: &str,
remaining_len: usize,
string_array_type: StringArrayType,
) -> (ArrayRef, ScalarValue) {
let rng = &mut seedable_rng();

// Create `size` rows:
// - 10% rows will be `None`
// - Other 90% will be strings with same `remaining_len` lengths
// We will build the string array on it later.
let string_iter = (0..size).map(|_| {
if rng.gen::<f32>() < 0.1 {
None
} else {
let mut value = trimmed.as_bytes().to_vec();
let generated = rng.sample_iter(&Alphanumeric).take(remaining_len);
value.extend(generated);
Some(String::from_utf8(value).unwrap())
}
});

// Build the target `string array` and `characters` according to `string_array_type`
match string_array_type {
StringArrayType::Utf8View => (
Arc::new(string_iter.collect::<StringViewArray>()),
ScalarValue::Utf8View(Some(characters.to_string())),
),
StringArrayType::Utf8 => (
Arc::new(string_iter.collect::<StringArray>()),
ScalarValue::Utf8(Some(characters.to_string())),
),
StringArrayType::LargeUtf8 => (
Arc::new(string_iter.collect::<LargeStringArray>()),
ScalarValue::LargeUtf8(Some(characters.to_string())),
),
}
}

/// Create args for the ltrim benchmark
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

❤️

/// Inputs:
/// - size: rows num of the test array
/// - characters: the characters we need to trim
/// - trimmed: the part in the testing string that will be trimmed
/// - remaining_len: the len of the remaining part of testing string after trimming
/// - string_array_type: the method used to store the testing strings
///
/// Outputs:
/// - testing string array
/// - trimmed characters
///
fn create_args(
size: usize,
characters: &str,
trimmed: &str,
remaining_len: usize,
string_array_type: StringArrayType,
) -> Vec<ColumnarValue> {
let (string_array, pattern) = create_string_array_and_characters(
size,
characters,
trimmed,
remaining_len,
string_array_type,
);
vec![
ColumnarValue::Array(array),
ColumnarValue::Scalar(ScalarValue::Utf8(Some(characters.to_string()))),
ColumnarValue::Array(string_array),
ColumnarValue::Scalar(pattern),
]
}

#[allow(clippy::too_many_arguments)]
fn run_with_string_type<M: Measurement>(
group: &mut BenchmarkGroup<'_, M>,
ltrim: &ScalarUDF,
size: usize,
len: usize,
characters: &str,
trimmed: &str,
remaining_len: usize,
string_type: StringArrayType,
) {
let args = create_args(size, characters, trimmed, remaining_len, string_type);
group.bench_function(
format!(
"{string_type} [size={size}, len_before={len}, len_after={remaining_len}]",
),
|b| b.iter(|| black_box(ltrim.invoke(&args))),
);
}

#[allow(clippy::too_many_arguments)]
fn run_one_group(
c: &mut Criterion,
group_name: &str,
ltrim: &ScalarUDF,
string_types: &[StringArrayType],
size: usize,
len: usize,
characters: &str,
trimmed: &str,
remaining_len: usize,
) {
let mut group = c.benchmark_group(group_name);
group.sampling_mode(SamplingMode::Flat);
group.sample_size(10);

for string_type in string_types {
run_with_string_type(
&mut group,
ltrim,
size,
len,
characters,
trimmed,
remaining_len,
*string_type,
);
}

group.finish();
}

fn criterion_benchmark(c: &mut Criterion) {
let ltrim = string::ltrim();
for char in ["\"", "Header:"] {
for size in [1024, 4096, 8192] {
let args = create_args(size, char);
c.bench_function(&format!("ltrim {}: {}", char, size), |b| {
b.iter(|| black_box(ltrim.invoke(&args)))
});
}
let characters = ",!()";

let string_types = [
StringArrayType::Utf8View,
StringArrayType::Utf8,
StringArrayType::LargeUtf8,
];
for size in [1024, 4096, 8192] {
// len=12, trimmed_len=4, len_after_ltrim=8
let len = 12;
let trimmed = characters;
let remaining_len = len - trimmed.len();
run_one_group(
c,
"INPUT LEN <= 12",
&ltrim,
&string_types,
size,
len,
characters,
trimmed,
remaining_len,
);

// len=64, trimmed_len=4, len_after_ltrim=60
let len = 64;
let trimmed = characters;
let remaining_len = len - trimmed.len();
run_one_group(
c,
"INPUT LEN > 12, OUTPUT LEN > 12",
&ltrim,
&string_types,
size,
len,
characters,
trimmed,
remaining_len,
);

// len=64, trimmed_len=56, len_after_ltrim=8
let len = 64;
let trimmed = characters.repeat(15);
let remaining_len = len - trimmed.len();
run_one_group(
c,
"INPUT LEN > 12, OUTPUT LEN <= 12",
&ltrim,
&string_types,
size,
len,
characters,
&trimmed,
remaining_len,
);
}
}

Expand Down