Skip to content

Commit a1082af

Browse files
Blizzarawiedld
authored andcommitted
feat: consume and produce Substrait type extensions (apache#11510)
* support reading type extensions in consumer * read extension for UDTs * support also type extensions in producer * produce extensions for MonthDayNano UDT * unify extensions between consumer and producer * fixes * add doc comments * add extension tests * fix * fix docs * fix test * fix clipppy
1 parent fa577d1 commit a1082af

File tree

6 files changed

+644
-430
lines changed

6 files changed

+644
-430
lines changed
Lines changed: 157 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,157 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion::common::{plan_err, DataFusionError};
19+
use std::collections::HashMap;
20+
use substrait::proto::extensions::simple_extension_declaration::{
21+
ExtensionFunction, ExtensionType, ExtensionTypeVariation, MappingType,
22+
};
23+
use substrait::proto::extensions::SimpleExtensionDeclaration;
24+
25+
/// Substrait uses [SimpleExtensions](https://substrait.io/extensions/#simple-extensions) to define
26+
/// behavior of plans in addition to what's supported directly by the protobuf definitions.
27+
/// That includes functions, but also provides support for custom types and variations for existing
28+
/// types. This structs facilitates the use of these extensions in DataFusion.
29+
/// TODO: DF doesn't yet use extensions for type variations <https://github.com/apache/datafusion/issues/11544>
30+
/// TODO: DF doesn't yet provide valid extensionUris <https://github.com/apache/datafusion/issues/11545>
31+
#[derive(Default, Debug, PartialEq)]
32+
pub struct Extensions {
33+
pub functions: HashMap<u32, String>, // anchor -> function name
34+
pub types: HashMap<u32, String>, // anchor -> type name
35+
pub type_variations: HashMap<u32, String>, // anchor -> type variation name
36+
}
37+
38+
impl Extensions {
39+
/// Registers a function and returns the anchor (reference) to it. If the function has already
40+
/// been registered, it returns the existing anchor.
41+
/// Function names are case-insensitive (converted to lowercase).
42+
pub fn register_function(&mut self, function_name: String) -> u32 {
43+
let function_name = function_name.to_lowercase();
44+
45+
// Some functions are named differently in Substrait default extensions than in DF
46+
// Rename those to match the Substrait extensions for interoperability
47+
let function_name = match function_name.as_str() {
48+
"substr" => "substring".to_string(),
49+
_ => function_name,
50+
};
51+
52+
match self.functions.iter().find(|(_, f)| *f == &function_name) {
53+
Some((function_anchor, _)) => *function_anchor, // Function has been registered
54+
None => {
55+
// Function has NOT been registered
56+
let function_anchor = self.functions.len() as u32;
57+
self.functions
58+
.insert(function_anchor, function_name.clone());
59+
function_anchor
60+
}
61+
}
62+
}
63+
64+
/// Registers a type and returns the anchor (reference) to it. If the type has already
65+
/// been registered, it returns the existing anchor.
66+
pub fn register_type(&mut self, type_name: String) -> u32 {
67+
let type_name = type_name.to_lowercase();
68+
match self.types.iter().find(|(_, t)| *t == &type_name) {
69+
Some((type_anchor, _)) => *type_anchor, // Type has been registered
70+
None => {
71+
// Type has NOT been registered
72+
let type_anchor = self.types.len() as u32;
73+
self.types.insert(type_anchor, type_name.clone());
74+
type_anchor
75+
}
76+
}
77+
}
78+
}
79+
80+
impl TryFrom<&Vec<SimpleExtensionDeclaration>> for Extensions {
81+
type Error = DataFusionError;
82+
83+
fn try_from(
84+
value: &Vec<SimpleExtensionDeclaration>,
85+
) -> datafusion::common::Result<Self> {
86+
let mut functions = HashMap::new();
87+
let mut types = HashMap::new();
88+
let mut type_variations = HashMap::new();
89+
90+
for ext in value {
91+
match &ext.mapping_type {
92+
Some(MappingType::ExtensionFunction(ext_f)) => {
93+
functions.insert(ext_f.function_anchor, ext_f.name.to_owned());
94+
}
95+
Some(MappingType::ExtensionType(ext_t)) => {
96+
types.insert(ext_t.type_anchor, ext_t.name.to_owned());
97+
}
98+
Some(MappingType::ExtensionTypeVariation(ext_v)) => {
99+
type_variations
100+
.insert(ext_v.type_variation_anchor, ext_v.name.to_owned());
101+
}
102+
None => return plan_err!("Cannot parse empty extension"),
103+
}
104+
}
105+
106+
Ok(Extensions {
107+
functions,
108+
types,
109+
type_variations,
110+
})
111+
}
112+
}
113+
114+
impl From<Extensions> for Vec<SimpleExtensionDeclaration> {
115+
fn from(val: Extensions) -> Vec<SimpleExtensionDeclaration> {
116+
let mut extensions = vec![];
117+
for (f_anchor, f_name) in val.functions {
118+
let function_extension = ExtensionFunction {
119+
extension_uri_reference: u32::MAX,
120+
function_anchor: f_anchor,
121+
name: f_name,
122+
};
123+
let simple_extension = SimpleExtensionDeclaration {
124+
mapping_type: Some(MappingType::ExtensionFunction(function_extension)),
125+
};
126+
extensions.push(simple_extension);
127+
}
128+
129+
for (t_anchor, t_name) in val.types {
130+
let type_extension = ExtensionType {
131+
extension_uri_reference: u32::MAX, // https://github.com/apache/datafusion/issues/11545
132+
type_anchor: t_anchor,
133+
name: t_name,
134+
};
135+
let simple_extension = SimpleExtensionDeclaration {
136+
mapping_type: Some(MappingType::ExtensionType(type_extension)),
137+
};
138+
extensions.push(simple_extension);
139+
}
140+
141+
for (tv_anchor, tv_name) in val.type_variations {
142+
let type_variation_extension = ExtensionTypeVariation {
143+
extension_uri_reference: u32::MAX, // We don't register proper extension URIs yet
144+
type_variation_anchor: tv_anchor,
145+
name: tv_name,
146+
};
147+
let simple_extension = SimpleExtensionDeclaration {
148+
mapping_type: Some(MappingType::ExtensionTypeVariation(
149+
type_variation_extension,
150+
)),
151+
};
152+
extensions.push(simple_extension);
153+
}
154+
155+
extensions
156+
}
157+
}

datafusion/substrait/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@
7272
//! # Ok(())
7373
//! # }
7474
//! ```
75+
pub mod extensions;
7576
pub mod logical_plan;
7677
pub mod physical_plan;
7778
pub mod serializer;

0 commit comments

Comments
 (0)