Skip to content

Commit 6d71e03

Browse files
author
Dima
committed
improve performance of regexp_count
1 parent cd69e37 commit 6d71e03

File tree

1 file changed

+26
-23
lines changed

1 file changed

+26
-23
lines changed

datafusion/functions/src/regex/regexpcount.rs

Lines changed: 26 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@ use datafusion_expr::{
3030
};
3131
use itertools::izip;
3232
use regex::Regex;
33-
use std::collections::hash_map::Entry;
3433
use std::collections::HashMap;
3534
use std::sync::{Arc, OnceLock};
3635

@@ -310,12 +309,13 @@ where
310309
Some(regex) => regex,
311310
};
312311

313-
let pattern = compile_regex(regex, flags_scalar)?;
312+
let pattern = get_pattern(regex, flags_scalar)?;
313+
let re = compile_regex(pattern)?;
314314

315315
Ok(Arc::new(Int64Array::from_iter_values(
316316
values
317317
.iter()
318-
.map(|value| count_matches(value, &pattern, start_scalar))
318+
.map(|value| count_matches(value, &re, start_scalar))
319319
.collect::<Result<Vec<i64>, ArrowError>>()?,
320320
)))
321321
}
@@ -356,15 +356,15 @@ where
356356
Some(regex) => regex,
357357
};
358358

359-
let pattern = compile_regex(regex, flags_scalar)?;
360-
359+
let pattern = get_pattern(regex, flags_scalar)?;
360+
let re = compile_regex(pattern)?;
361361
let start_array = start_array.unwrap();
362362

363363
Ok(Arc::new(Int64Array::from_iter_values(
364364
values
365365
.iter()
366366
.zip(start_array.iter())
367-
.map(|(value, start)| count_matches(value, &pattern, start))
367+
.map(|(value, start)| count_matches(value, &re, start))
368368
.collect::<Result<Vec<i64>, ArrowError>>()?,
369369
)))
370370
}
@@ -549,34 +549,37 @@ where
549549
}
550550
}
551551

552-
fn compile_and_cache_regex(
553-
regex: &str,
554-
flags: Option<&str>,
555-
regex_cache: &mut HashMap<String, Regex>,
556-
) -> Result<Regex, ArrowError> {
557-
match regex_cache.entry(regex.to_string()) {
558-
Entry::Vacant(entry) => {
559-
let compiled = compile_regex(regex, flags)?;
560-
entry.insert(compiled.clone());
561-
Ok(compiled)
562-
}
563-
Entry::Occupied(entry) => Ok(entry.get().to_owned()),
552+
fn compile_and_cache_regex<'a>(
553+
regex: &'a str,
554+
flags: Option<&'a str>,
555+
regex_cache: &'a mut HashMap<String, Regex>,
556+
) -> Result<&'a Regex, ArrowError> {
557+
let pattern = get_pattern(regex, flags)?;
558+
559+
if regex_cache.contains_key(&pattern) {
560+
return Ok(regex_cache.get(&pattern).unwrap());
564561
}
562+
563+
let re = compile_regex(pattern.clone())?;
564+
regex_cache.insert(pattern.clone(), re);
565+
Ok(regex_cache.get(&pattern).unwrap())
565566
}
566567

567-
fn compile_regex(regex: &str, flags: Option<&str>) -> Result<Regex, ArrowError> {
568-
let pattern = match flags {
569-
None | Some("") => regex.to_string(),
568+
fn get_pattern(regex: &str, flags: Option<&str>) -> Result<String, ArrowError> {
569+
match flags {
570+
None | Some("") => Ok(regex.to_string()),
570571
Some(flags) => {
571572
if flags.contains("g") {
572573
return Err(ArrowError::ComputeError(
573574
"regexp_count() does not support global flag".to_string(),
574575
));
575576
}
576-
format!("(?{}){}", flags, regex)
577+
Ok(format!("(?{}){}", flags, regex))
577578
}
578-
};
579+
}
580+
}
579581

582+
fn compile_regex(pattern: String) -> Result<Regex, ArrowError> {
580583
Regex::new(&pattern).map_err(|_| {
581584
ArrowError::ComputeError(format!(
582585
"Regular expression did not compile: {}",

0 commit comments

Comments
 (0)