Skip to content

Commit 18193e6

Browse files
dharanadjayzhan211
andauthored
chore: Add SessionState to MockContextProvider just like SessionContextProvider (#11940)
* refac: mock context provide to match public api * lower udaf names * cleanup * typos Co-authored-by: Jay Zhan <[email protected]> * more typos Co-authored-by: Jay Zhan <[email protected]> * typos * refactor func name --------- Co-authored-by: Jay Zhan <[email protected]>
1 parent e66636d commit 18193e6

File tree

3 files changed

+83
-61
lines changed

3 files changed

+83
-61
lines changed

datafusion/sql/tests/cases/plan_to_sql.rs

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use datafusion_functions::core::planner::CoreFunctionPlanner;
3333
use sqlparser::dialect::{Dialect, GenericDialect, MySqlDialect};
3434
use sqlparser::parser::Parser;
3535

36-
use crate::common::MockContextProvider;
36+
use crate::common::{MockContextProvider, MockSessionState};
3737

3838
#[test]
3939
fn roundtrip_expr() {
@@ -59,8 +59,8 @@ fn roundtrip_expr() {
5959
let roundtrip = |table, sql: &str| -> Result<String> {
6060
let dialect = GenericDialect {};
6161
let sql_expr = Parser::new(&dialect).try_with_sql(sql)?.parse_expr()?;
62-
63-
let context = MockContextProvider::default().with_udaf(sum_udaf());
62+
let state = MockSessionState::default().with_aggregate_function(sum_udaf());
63+
let context = MockContextProvider { state };
6464
let schema = context.get_table_source(table)?.schema();
6565
let df_schema = DFSchema::try_from(schema.as_ref().clone())?;
6666
let sql_to_rel = SqlToRel::new(&context);
@@ -156,11 +156,11 @@ fn roundtrip_statement() -> Result<()> {
156156
let statement = Parser::new(&dialect)
157157
.try_with_sql(query)?
158158
.parse_statement()?;
159-
160-
let context = MockContextProvider::default()
161-
.with_udaf(sum_udaf())
162-
.with_udaf(count_udaf())
159+
let state = MockSessionState::default()
160+
.with_aggregate_function(sum_udaf())
161+
.with_aggregate_function(count_udaf())
163162
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
163+
let context = MockContextProvider { state };
164164
let sql_to_rel = SqlToRel::new(&context);
165165
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
166166

@@ -189,8 +189,10 @@ fn roundtrip_crossjoin() -> Result<()> {
189189
.try_with_sql(query)?
190190
.parse_statement()?;
191191

192-
let context = MockContextProvider::default()
192+
let state = MockSessionState::default()
193193
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
194+
195+
let context = MockContextProvider { state };
194196
let sql_to_rel = SqlToRel::new(&context);
195197
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
196198

@@ -412,10 +414,12 @@ fn roundtrip_statement_with_dialect() -> Result<()> {
412414
.try_with_sql(query.sql)?
413415
.parse_statement()?;
414416

415-
let context = MockContextProvider::default()
416-
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()))
417-
.with_udaf(max_udaf())
418-
.with_udaf(min_udaf());
417+
let state = MockSessionState::default()
418+
.with_aggregate_function(max_udaf())
419+
.with_aggregate_function(min_udaf())
420+
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
421+
422+
let context = MockContextProvider { state };
419423
let sql_to_rel = SqlToRel::new(&context);
420424
let plan = sql_to_rel
421425
.sql_statement_to_plan(statement)
@@ -443,7 +447,9 @@ fn test_unnest_logical_plan() -> Result<()> {
443447
.try_with_sql(query)?
444448
.parse_statement()?;
445449

446-
let context = MockContextProvider::default();
450+
let context = MockContextProvider {
451+
state: MockSessionState::default(),
452+
};
447453
let sql_to_rel = SqlToRel::new(&context);
448454
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
449455

@@ -516,7 +522,9 @@ fn test_pretty_roundtrip() -> Result<()> {
516522

517523
let df_schema = DFSchema::try_from(schema)?;
518524

519-
let context = MockContextProvider::default();
525+
let context = MockContextProvider {
526+
state: MockSessionState::default(),
527+
};
520528
let sql_to_rel = SqlToRel::new(&context);
521529

522530
let unparser = Unparser::default().with_pretty(true);
@@ -589,7 +597,9 @@ fn sql_round_trip(query: &str, expect: &str) {
589597
.parse_statement()
590598
.unwrap();
591599

592-
let context = MockContextProvider::default();
600+
let context = MockContextProvider {
601+
state: MockSessionState::default(),
602+
};
593603
let sql_to_rel = SqlToRel::new(&context);
594604
let plan = sql_to_rel.sql_statement_to_plan(statement).unwrap();
595605

datafusion/sql/tests/common/mod.rs

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -50,36 +50,40 @@ impl Display for MockCsvType {
5050
}
5151

5252
#[derive(Default)]
53-
pub(crate) struct MockContextProvider {
54-
options: ConfigOptions,
55-
udfs: HashMap<String, Arc<ScalarUDF>>,
56-
udafs: HashMap<String, Arc<AggregateUDF>>,
53+
pub(crate) struct MockSessionState {
54+
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
55+
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
5756
expr_planners: Vec<Arc<dyn ExprPlanner>>,
57+
pub config_options: ConfigOptions,
5858
}
5959

60-
impl MockContextProvider {
61-
// Suppressing dead code warning, as this is used in integration test crates
62-
#[allow(dead_code)]
63-
pub(crate) fn options_mut(&mut self) -> &mut ConfigOptions {
64-
&mut self.options
60+
impl MockSessionState {
61+
pub fn with_expr_planner(mut self, expr_planner: Arc<dyn ExprPlanner>) -> Self {
62+
self.expr_planners.push(expr_planner);
63+
self
6564
}
6665

67-
#[allow(dead_code)]
68-
pub(crate) fn with_udf(mut self, udf: ScalarUDF) -> Self {
69-
self.udfs.insert(udf.name().to_string(), Arc::new(udf));
66+
pub fn with_scalar_function(mut self, scalar_function: Arc<ScalarUDF>) -> Self {
67+
self.scalar_functions
68+
.insert(scalar_function.name().to_string(), scalar_function);
7069
self
7170
}
7271

73-
pub(crate) fn with_udaf(mut self, udaf: Arc<AggregateUDF>) -> Self {
72+
pub fn with_aggregate_function(
73+
mut self,
74+
aggregate_function: Arc<AggregateUDF>,
75+
) -> Self {
7476
// TODO: change to to_string() if all the function name is converted to lowercase
75-
self.udafs.insert(udaf.name().to_lowercase(), udaf);
77+
self.aggregate_functions.insert(
78+
aggregate_function.name().to_string().to_lowercase(),
79+
aggregate_function,
80+
);
7681
self
7782
}
83+
}
7884

79-
pub(crate) fn with_expr_planner(mut self, planner: Arc<dyn ExprPlanner>) -> Self {
80-
self.expr_planners.push(planner);
81-
self
82-
}
85+
pub(crate) struct MockContextProvider {
86+
pub(crate) state: MockSessionState,
8387
}
8488

8589
impl ContextProvider for MockContextProvider {
@@ -202,11 +206,11 @@ impl ContextProvider for MockContextProvider {
202206
}
203207

204208
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
205-
self.udfs.get(name).cloned()
209+
self.state.scalar_functions.get(name).cloned()
206210
}
207211

208212
fn get_aggregate_meta(&self, name: &str) -> Option<Arc<AggregateUDF>> {
209-
self.udafs.get(name).cloned()
213+
self.state.aggregate_functions.get(name).cloned()
210214
}
211215

212216
fn get_variable_type(&self, _: &[String]) -> Option<DataType> {
@@ -218,7 +222,7 @@ impl ContextProvider for MockContextProvider {
218222
}
219223

220224
fn options(&self) -> &ConfigOptions {
221-
&self.options
225+
&self.state.config_options
222226
}
223227

224228
fn get_file_type(
@@ -237,19 +241,19 @@ impl ContextProvider for MockContextProvider {
237241
}
238242

239243
fn udf_names(&self) -> Vec<String> {
240-
self.udfs.keys().cloned().collect()
244+
self.state.scalar_functions.keys().cloned().collect()
241245
}
242246

243247
fn udaf_names(&self) -> Vec<String> {
244-
self.udafs.keys().cloned().collect()
248+
self.state.aggregate_functions.keys().cloned().collect()
245249
}
246250

247251
fn udwf_names(&self) -> Vec<String> {
248252
Vec::new()
249253
}
250254

251255
fn get_expr_planners(&self) -> &[Arc<dyn ExprPlanner>] {
252-
&self.expr_planners
256+
&self.state.expr_planners
253257
}
254258
}
255259

datafusion/sql/tests/sql_integration.rs

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ use datafusion_sql::{
4141
planner::{ParserOptions, SqlToRel},
4242
};
4343

44+
use crate::common::MockSessionState;
4445
use datafusion_functions::core::planner::CoreFunctionPlanner;
4546
use datafusion_functions_aggregate::{
4647
approx_median::approx_median_udaf, count::count_udaf, min_max::max_udaf,
@@ -1495,8 +1496,9 @@ fn recursive_ctes_disabled() {
14951496
select * from numbers;";
14961497

14971498
// manually setting up test here so that we can disable recursive ctes
1498-
let mut context = MockContextProvider::default();
1499-
context.options_mut().execution.enable_recursive_ctes = false;
1499+
let mut state = MockSessionState::default();
1500+
state.config_options.execution.enable_recursive_ctes = false;
1501+
let context = MockContextProvider { state };
15001502

15011503
let planner = SqlToRel::new_with_options(&context, ParserOptions::default());
15021504
let result = DFParser::parse_sql_with_dialect(sql, &GenericDialect {});
@@ -2727,7 +2729,8 @@ fn logical_plan_with_options(sql: &str, options: ParserOptions) -> Result<Logica
27272729
}
27282730

27292731
fn logical_plan_with_dialect(sql: &str, dialect: &dyn Dialect) -> Result<LogicalPlan> {
2730-
let context = MockContextProvider::default().with_udaf(sum_udaf());
2732+
let state = MockSessionState::default().with_aggregate_function(sum_udaf());
2733+
let context = MockContextProvider { state };
27312734
let planner = SqlToRel::new(&context);
27322735
let result = DFParser::parse_sql_with_dialect(sql, dialect);
27332736
let mut ast = result?;
@@ -2739,39 +2742,44 @@ fn logical_plan_with_dialect_and_options(
27392742
dialect: &dyn Dialect,
27402743
options: ParserOptions,
27412744
) -> Result<LogicalPlan> {
2742-
let context = MockContextProvider::default()
2743-
.with_udf(unicode::character_length().as_ref().clone())
2744-
.with_udf(string::concat().as_ref().clone())
2745-
.with_udf(make_udf(
2745+
let state = MockSessionState::default()
2746+
.with_scalar_function(Arc::new(unicode::character_length().as_ref().clone()))
2747+
.with_scalar_function(Arc::new(string::concat().as_ref().clone()))
2748+
.with_scalar_function(Arc::new(make_udf(
27462749
"nullif",
27472750
vec![DataType::Int32, DataType::Int32],
27482751
DataType::Int32,
2749-
))
2750-
.with_udf(make_udf(
2752+
)))
2753+
.with_scalar_function(Arc::new(make_udf(
27512754
"round",
27522755
vec![DataType::Float64, DataType::Int64],
27532756
DataType::Float32,
2754-
))
2755-
.with_udf(make_udf(
2757+
)))
2758+
.with_scalar_function(Arc::new(make_udf(
27562759
"arrow_cast",
27572760
vec![DataType::Int64, DataType::Utf8],
27582761
DataType::Float64,
2759-
))
2760-
.with_udf(make_udf(
2762+
)))
2763+
.with_scalar_function(Arc::new(make_udf(
27612764
"date_trunc",
27622765
vec![DataType::Utf8, DataType::Timestamp(Nanosecond, None)],
27632766
DataType::Int32,
2764-
))
2765-
.with_udf(make_udf("sqrt", vec![DataType::Int64], DataType::Int64))
2766-
.with_udaf(sum_udaf())
2767-
.with_udaf(approx_median_udaf())
2768-
.with_udaf(count_udaf())
2769-
.with_udaf(avg_udaf())
2770-
.with_udaf(min_udaf())
2771-
.with_udaf(max_udaf())
2772-
.with_udaf(grouping_udaf())
2767+
)))
2768+
.with_scalar_function(Arc::new(make_udf(
2769+
"sqrt",
2770+
vec![DataType::Int64],
2771+
DataType::Int64,
2772+
)))
2773+
.with_aggregate_function(sum_udaf())
2774+
.with_aggregate_function(approx_median_udaf())
2775+
.with_aggregate_function(count_udaf())
2776+
.with_aggregate_function(avg_udaf())
2777+
.with_aggregate_function(min_udaf())
2778+
.with_aggregate_function(max_udaf())
2779+
.with_aggregate_function(grouping_udaf())
27732780
.with_expr_planner(Arc::new(CoreFunctionPlanner::default()));
27742781

2782+
let context = MockContextProvider { state };
27752783
let planner = SqlToRel::new_with_options(&context, options);
27762784
let result = DFParser::parse_sql_with_dialect(sql, dialect);
27772785
let mut ast = result?;

0 commit comments

Comments
 (0)