Skip to content

Commit 8c35270

Browse files
authored
Fix regex cache on pattern, less alloc, hash less often (#13414)
* cache on pattern, less alloc, hash less often
* inline get_pattern
* reduce to one hash
* remove unnecessary lifetimes
1 parent 75a27a8 commit 8c35270

File tree

1 file changed

+40
-10
lines changed

1 file changed

+40
-10
lines changed

datafusion/functions/src/regex/regexpcount.rs

+40 -10
Original file line number | Diff line number | Diff line change
@@ -30,6 +30,7 @@ use datafusion_expr::{
3030
};
3131
use itertools::izip;
3232
use regex::Regex;
33+
use std::collections::hash_map::Entry;
3334
use std::collections::HashMap;
3435
use std::sync::{Arc, OnceLock};
3536

@@ -548,16 +549,22 @@ where
548549
}
549550
}
550551

551-
fn compile_and_cache_regex<'a>(
552-
regex: &'a str,
553-
flags: Option<&'a str>,
554-
regex_cache: &'a mut HashMap<String, Regex>,
555-
) -> Result<&'a Regex, ArrowError> {
556-
if !regex_cache.contains_key(regex) {
557-
let compiled = compile_regex(regex, flags)?;
558-
regex_cache.insert(regex.to_string(), compiled);
559-
}
560-
Ok(regex_cache.get(regex).unwrap())
552+
fn compile_and_cache_regex<'strings, 'cache>(
553+
regex: &'strings str,
554+
flags: Option<&'strings str>,
555+
regex_cache: &'cache mut HashMap<(&'strings str, Option<&'strings str>), Regex>,
556+
) -> Result<&'cache Regex, ArrowError>
557+
where
558+
'strings: 'cache,
559+
{
560+
let result = match regex_cache.entry((regex, flags)) {
561+
Entry::Occupied(occupied_entry) => occupied_entry.into_mut(),
562+
Entry::Vacant(vacant_entry) => {
563+
let compiled = compile_regex(regex, flags)?;
564+
vacant_entry.insert(compiled)
565+
}
566+
};
567+
Ok(result)
561568
}
562569

563570
fn compile_regex(regex: &str, flags: Option<&str>) -> Result<Regex, ArrowError> {
@@ -634,6 +641,8 @@ mod tests {
634641
test_case_sensitive_regexp_count_array_complex::<GenericStringArray<i32>>();
635642
test_case_sensitive_regexp_count_array_complex::<GenericStringArray<i64>>();
636643
test_case_sensitive_regexp_count_array_complex::<StringViewArray>();
644+
645+
test_case_regexp_count_cache_check::<GenericStringArray<i32>>();
637646
}
638647

639648
fn test_case_sensitive_regexp_count_scalar() {
@@ -977,4 +986,25 @@ mod tests {
977986
.unwrap();
978987
assert_eq!(re.as_ref(), &expected);
979988
}
989+
990+
fn test_case_regexp_count_cache_check<A>()
991+
where
992+
A: From<Vec<&'static str>> + Array + 'static,
993+
{
994+
let values = A::from(vec!["aaa", "Aaa", "aaa"]);
995+
let regex = A::from(vec!["aaa", "aaa", "aaa"]);
996+
let start = Int64Array::from(vec![1, 1, 1]);
997+
let flags = A::from(vec!["", "i", ""]);
998+
999+
let expected = Int64Array::from(vec![1, 1, 1]);
1000+
1001+
let re = regexp_count_func(&[
1002+
Arc::new(values),
1003+
Arc::new(regex),
1004+
Arc::new(start),
1005+
Arc::new(flags),
1006+
])
1007+
.unwrap();
1008+
assert_eq!(re.as_ref(), &expected);
1009+
}
9801010
}

0 commit comments

Comments (0)