@@ -30,6 +30,7 @@ use datafusion_expr::{
30
30
} ;
31
31
use itertools:: izip;
32
32
use regex:: Regex ;
33
+ use std:: collections:: hash_map:: Entry ;
33
34
use std:: collections:: HashMap ;
34
35
use std:: sync:: { Arc , OnceLock } ;
35
36
@@ -548,16 +549,22 @@ where
548
549
}
549
550
}
550
551
551
- fn compile_and_cache_regex < ' a > (
552
- regex : & ' a str ,
553
- flags : Option < & ' a str > ,
554
- regex_cache : & ' a mut HashMap < String , Regex > ,
555
- ) -> Result < & ' a Regex , ArrowError > {
556
- if !regex_cache. contains_key ( regex) {
557
- let compiled = compile_regex ( regex, flags) ?;
558
- regex_cache. insert ( regex. to_string ( ) , compiled) ;
559
- }
560
- Ok ( regex_cache. get ( regex) . unwrap ( ) )
552
+ fn compile_and_cache_regex < ' strings , ' cache > (
553
+ regex : & ' strings str ,
554
+ flags : Option < & ' strings str > ,
555
+ regex_cache : & ' cache mut HashMap < ( & ' strings str , Option < & ' strings str > ) , Regex > ,
556
+ ) -> Result < & ' cache Regex , ArrowError >
557
+ where
558
+ ' strings : ' cache ,
559
+ {
560
+ let result = match regex_cache. entry ( ( regex, flags) ) {
561
+ Entry :: Occupied ( occupied_entry) => occupied_entry. into_mut ( ) ,
562
+ Entry :: Vacant ( vacant_entry) => {
563
+ let compiled = compile_regex ( regex, flags) ?;
564
+ vacant_entry. insert ( compiled)
565
+ }
566
+ } ;
567
+ Ok ( result)
561
568
}
562
569
563
570
fn compile_regex ( regex : & str , flags : Option < & str > ) -> Result < Regex , ArrowError > {
@@ -634,6 +641,8 @@ mod tests {
634
641
test_case_sensitive_regexp_count_array_complex :: < GenericStringArray < i32 > > ( ) ;
635
642
test_case_sensitive_regexp_count_array_complex :: < GenericStringArray < i64 > > ( ) ;
636
643
test_case_sensitive_regexp_count_array_complex :: < StringViewArray > ( ) ;
644
+
645
+ test_case_regexp_count_cache_check :: < GenericStringArray < i32 > > ( ) ;
637
646
}
638
647
639
648
fn test_case_sensitive_regexp_count_scalar ( ) {
@@ -977,4 +986,25 @@ mod tests {
977
986
. unwrap ( ) ;
978
987
assert_eq ! ( re. as_ref( ) , & expected) ;
979
988
}
989
+
990
+ fn test_case_regexp_count_cache_check < A > ( )
991
+ where
992
+ A : From < Vec < & ' static str > > + Array + ' static ,
993
+ {
994
+ let values = A :: from ( vec ! [ "aaa" , "Aaa" , "aaa" ] ) ;
995
+ let regex = A :: from ( vec ! [ "aaa" , "aaa" , "aaa" ] ) ;
996
+ let start = Int64Array :: from ( vec ! [ 1 , 1 , 1 ] ) ;
997
+ let flags = A :: from ( vec ! [ "" , "i" , "" ] ) ;
998
+
999
+ let expected = Int64Array :: from ( vec ! [ 1 , 1 , 1 ] ) ;
1000
+
1001
+ let re = regexp_count_func ( & [
1002
+ Arc :: new ( values) ,
1003
+ Arc :: new ( regex) ,
1004
+ Arc :: new ( start) ,
1005
+ Arc :: new ( flags) ,
1006
+ ] )
1007
+ . unwrap ( ) ;
1008
+ assert_eq ! ( re. as_ref( ) , & expected) ;
1009
+ }
980
1010
}
0 commit comments