Skip to content

Commit 2238680

Browse files
authored
Allow extensions_options to accept Option field (apache#14664)
* implement ConfigField for extension option * fix compile * fix doc test
1 parent 40bb75f commit 2238680

File tree

2 files changed

+95
-23
lines changed

2 files changed

+95
-23
lines changed

datafusion/common/src/config.rs

+60-23
Original file line numberDiff line numberDiff line change
@@ -1241,35 +1241,72 @@ macro_rules! extensions_options {
12411241
Box::new(self.clone())
12421242
}
12431243

1244-
fn set(&mut self, key: &str, value: &str) -> $crate::Result<()> {
1245-
match key {
1246-
$(
1247-
stringify!($field_name) => {
1248-
self.$field_name = value.parse().map_err(|e| {
1249-
$crate::DataFusionError::Context(
1250-
format!(concat!("Error parsing {} as ", stringify!($t),), value),
1251-
Box::new($crate::DataFusionError::External(Box::new(e))),
1252-
)
1253-
})?;
1254-
Ok(())
1255-
}
1256-
)*
1257-
_ => Err($crate::DataFusionError::Configuration(
1258-
format!(concat!("Config value \"{}\" not found on ", stringify!($struct_name)), key)
1259-
))
1260-
}
1244+
fn set(&mut self, key: &str, value: &str) -> $crate::error::Result<()> {
1245+
$crate::config::ConfigField::set(self, key, value)
12611246
}
12621247

12631248
fn entries(&self) -> Vec<$crate::config::ConfigEntry> {
1264-
vec![
1249+
struct Visitor(Vec<$crate::config::ConfigEntry>);
1250+
1251+
impl $crate::config::Visit for Visitor {
1252+
fn some<V: std::fmt::Display>(
1253+
&mut self,
1254+
key: &str,
1255+
value: V,
1256+
description: &'static str,
1257+
) {
1258+
self.0.push($crate::config::ConfigEntry {
1259+
key: key.to_string(),
1260+
value: Some(value.to_string()),
1261+
description,
1262+
})
1263+
}
1264+
1265+
fn none(&mut self, key: &str, description: &'static str) {
1266+
self.0.push($crate::config::ConfigEntry {
1267+
key: key.to_string(),
1268+
value: None,
1269+
description,
1270+
})
1271+
}
1272+
}
1273+
1274+
let mut v = Visitor(vec![]);
1275+
// The prefix is not used for extensions.
1276+
// The description is generated in ConfigField::visit.
1277+
// We can just pass empty strings here.
1278+
$crate::config::ConfigField::visit(self, &mut v, "", "");
1279+
v.0
1280+
}
1281+
}
1282+
1283+
impl $crate::config::ConfigField for $struct_name {
1284+
fn set(&mut self, key: &str, value: &str) -> $crate::error::Result<()> {
1285+
let (key, rem) = key.split_once('.').unwrap_or((key, ""));
1286+
match key {
12651287
$(
1266-
$crate::config::ConfigEntry {
1267-
key: stringify!($field_name).to_owned(),
1268-
value: (self.$field_name != $default).then(|| self.$field_name.to_string()),
1269-
description: concat!($($d),*).trim(),
1288+
stringify!($field_name) => {
1289+
// Safely apply deprecated attribute if present
1290+
// $(#[allow(deprecated)])?
1291+
{
1292+
#[allow(deprecated)]
1293+
self.$field_name.set(rem, value.as_ref())
1294+
}
12701295
},
12711296
)*
1272-
]
1297+
_ => return $crate::error::_config_err!(
1298+
"Config value \"{}\" not found on {}", key, stringify!($struct_name)
1299+
)
1300+
}
1301+
}
1302+
1303+
fn visit<V: $crate::config::Visit>(&self, v: &mut V, _key_prefix: &str, _description: &'static str) {
1304+
$(
1305+
let key = stringify!($field_name).to_string();
1306+
let desc = concat!($($d),*).trim();
1307+
#[allow(deprecated)]
1308+
self.$field_name.visit(v, key.as_str(), desc);
1309+
)*
12731310
}
12741311
}
12751312
}

datafusion/execution/src/task.rs

+35
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ mod tests {
214214
extensions_options! {
215215
struct TestExtension {
216216
value: usize, default = 42
217+
option_value: Option<usize>, default = None
217218
}
218219
}
219220

@@ -229,6 +230,7 @@ mod tests {
229230

230231
let mut config = ConfigOptions::new().with_extensions(extensions);
231232
config.set("test.value", "24")?;
233+
config.set("test.option_value", "42")?;
232234
let session_config = SessionConfig::from(config);
233235

234236
let task_context = TaskContext::new(
@@ -249,6 +251,39 @@ mod tests {
249251
assert!(test.is_some());
250252

251253
assert_eq!(test.unwrap().value, 24);
254+
assert_eq!(test.unwrap().option_value, Some(42));
255+
256+
Ok(())
257+
}
258+
259+
#[test]
260+
fn task_context_extensions_default() -> Result<()> {
261+
let runtime = Arc::new(RuntimeEnv::default());
262+
let mut extensions = Extensions::new();
263+
extensions.insert(TestExtension::default());
264+
265+
let config = ConfigOptions::new().with_extensions(extensions);
266+
let session_config = SessionConfig::from(config);
267+
268+
let task_context = TaskContext::new(
269+
Some("task_id".to_string()),
270+
"session_id".to_string(),
271+
session_config,
272+
HashMap::default(),
273+
HashMap::default(),
274+
HashMap::default(),
275+
runtime,
276+
);
277+
278+
let test = task_context
279+
.session_config()
280+
.options()
281+
.extensions
282+
.get::<TestExtension>();
283+
assert!(test.is_some());
284+
285+
assert_eq!(test.unwrap().value, 42);
286+
assert_eq!(test.unwrap().option_value, None);
252287

253288
Ok(())
254289
}

0 commit comments

Comments
 (0)