Skip to content

Commit 92fdc16

Browse files
[feat] Add Ability to Generate Function Visibility to arrow-udf (#52)
This PR will add an additional meta parameter `visibility` to `arrow-udf`. I might want this to be added while working on apache/datafusion#11413. Sometimes it is better to reference the symbol directly instead of using the function registry. --------- Co-authored-by: Runji Wang <[email protected]>
1 parent e379764 commit 92fdc16

File tree

7 files changed

+239
-8
lines changed

7 files changed

+239
-8
lines changed

arrow-udf-macros/src/gen.rs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,17 @@ impl FunctionAttr {
108108
user_fn: &UserFunctionAttr,
109109
eval_fn_name: &Ident,
110110
) -> Result<TokenStream2> {
111+
let fn_with_visibility = if let Some(visiblity) = &self.visibility {
112+
// handle the scope of the visibility by parsing the visibility string
113+
match syn::parse_str::<syn::Visibility>(visiblity)? {
114+
syn::Visibility::Public(token) => quote! { #token fn },
115+
syn::Visibility::Restricted(vis_restricted) => quote! { #vis_restricted fn },
116+
syn::Visibility::Inherited => quote! { fn },
117+
}
118+
} else {
119+
quote! { fn }
120+
};
121+
111122
let variadic = matches!(self.args.last(), Some(t) if t == "...");
112123
let num_args = self.args.len() - if variadic { 1 } else { 0 };
113124
let user_fn_name = format_ident!("{}", user_fn.name);
@@ -420,7 +431,7 @@ impl FunctionAttr {
420431

421432
Ok(if self.is_table_function {
422433
quote! {
423-
fn #eval_fn_name<'a>(input: &'a ::arrow_udf::codegen::arrow_array::RecordBatch)
434+
#fn_with_visibility #eval_fn_name<'a>(input: &'a ::arrow_udf::codegen::arrow_array::RecordBatch)
424435
-> ::arrow_udf::Result<Box<dyn Iterator<Item = ::arrow_udf::codegen::arrow_array::RecordBatch> + 'a>>
425436
{
426437
const BATCH_SIZE: usize = 1024;
@@ -432,7 +443,7 @@ impl FunctionAttr {
432443
}
433444
} else {
434445
quote! {
435-
fn #eval_fn_name(input: &::arrow_udf::codegen::arrow_array::RecordBatch)
446+
#fn_with_visibility #eval_fn_name(input: &::arrow_udf::codegen::arrow_array::RecordBatch)
436447
-> ::arrow_udf::Result<::arrow_udf::codegen::arrow_array::RecordBatch>
437448
{
438449
#downcast_arrays

arrow-udf-macros/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,8 @@ struct FunctionAttr {
323323
/// Generated batch function name.
324324
/// If not specified, the macro will not generate batch function.
325325
output: Option<String>,
326+
/// Customized function visibility.
327+
visibility: Option<String>,
326328
}
327329

328330
/// Attributes from function signature `fn(..)`

arrow-udf-macros/src/parse.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ impl Parse for FunctionAttr {
8484
parsed.volatile = true;
8585
} else if meta.path().is_ident("append_only") {
8686
parsed.append_only = true;
87+
} else if meta.path().is_ident("visibility") {
88+
parsed.visibility = Some(get_value()?);
8789
} else {
8890
return Err(Error::new(
8991
meta.span(),

arrow-udf/tests/cases/mod.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// Copyright 2024 RisingWave Labs
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
pub mod visibility_tests;
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
// Copyright 2024 RisingWave Labs
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use std::sync::Arc;
16+
17+
use crate::common::check;
18+
use arrow_array::{Float32Array, Float64Array, Int32Array, RecordBatch, StringArray, UInt32Array};
19+
use arrow_schema::{DataType, Field, Schema};
20+
use arrow_udf::function;
21+
use expect_test::expect;
22+
23+
// test visibility
24+
#[function("maybe_visible(int) -> int", output = "maybe_visible_udf")]
25+
#[function(
26+
"maybe_visible(uint32) -> uint32",
27+
output = "maybe_visible_pub_udf",
28+
visibility = "pub"
29+
)]
30+
#[function(
31+
"maybe_visible(float32) -> float32",
32+
output = "maybe_visible_pub_crate_udf",
33+
visibility = "pub(crate)"
34+
)]
35+
#[function(
36+
"maybe_visible(float64) -> float64",
37+
output = "maybe_visible_pub_self_udf",
38+
visibility = "pub(self)"
39+
)]
40+
#[function(
41+
"maybe_visible(string) -> string",
42+
output = "maybe_visible_pub_super_udf",
43+
visibility = "pub(super)"
44+
)]
45+
fn maybe_visible<T>(x: T) -> T {
46+
x
47+
}
48+
49+
#[test]
50+
fn test_default() {
51+
let schema = Schema::new(vec![Field::new("int", DataType::Int32, true)]);
52+
let arg0 = Int32Array::from(vec![Some(1), None]);
53+
let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
54+
55+
let output = maybe_visible_udf(&input).unwrap();
56+
check(
57+
&[output],
58+
expect![[r#"
59+
+---------------+
60+
| maybe_visible |
61+
+---------------+
62+
| 1 |
63+
| |
64+
+---------------+"#]],
65+
);
66+
}
67+
68+
#[test]
69+
fn test_pub() {
70+
let schema = Schema::new(vec![Field::new("uint32", DataType::UInt32, true)]);
71+
let arg0 = UInt32Array::from(vec![Some(1), None]);
72+
let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
73+
74+
let output = maybe_visible_pub_udf(&input).unwrap();
75+
check(
76+
&[output],
77+
expect![[r#"
78+
+---------------+
79+
| maybe_visible |
80+
+---------------+
81+
| 1 |
82+
| |
83+
+---------------+"#]],
84+
);
85+
}
86+
87+
#[test]
88+
fn test_pub_crate() {
89+
let schema = Schema::new(vec![Field::new("float32", DataType::Float32, true)]);
90+
let arg0 = Float32Array::from(vec![Some(1.0), None]);
91+
let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
92+
93+
let output = maybe_visible_pub_crate_udf(&input).unwrap();
94+
check(
95+
&[output],
96+
expect![[r#"
97+
+---------------+
98+
| maybe_visible |
99+
+---------------+
100+
| 1.0 |
101+
| |
102+
+---------------+"#]],
103+
);
104+
}
105+
106+
#[test]
107+
fn test_pub_self() {
108+
let schema = Schema::new(vec![Field::new("float64", DataType::Float64, true)]);
109+
let arg0 = Float64Array::from(vec![Some(1.0), None]);
110+
let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
111+
112+
let output = maybe_visible_pub_self_udf(&input).unwrap();
113+
check(
114+
&[output],
115+
expect![[r#"
116+
+---------------+
117+
| maybe_visible |
118+
+---------------+
119+
| 1.0 |
120+
| |
121+
+---------------+"#]],
122+
);
123+
}
124+
125+
#[test]
126+
fn test_pub_super() {
127+
let schema = Schema::new(vec![Field::new("string", DataType::Utf8, true)]);
128+
let arg0 = StringArray::from(vec![Some("1.0"), None]);
129+
let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
130+
131+
let output = maybe_visible_pub_super_udf(&input).unwrap();
132+
check(
133+
&[output],
134+
expect![[r#"
135+
+---------------+
136+
| maybe_visible |
137+
+---------------+
138+
| 1.0 |
139+
| |
140+
+---------------+"#]],
141+
);
142+
}

arrow-udf/tests/common/mod.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Copyright 2024 RisingWave Labs
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use arrow_array::RecordBatch;
16+
use arrow_cast::pretty::pretty_format_batches;
17+
use expect_test::Expect;
18+
19+
/// Compare the actual output with the expected output.
20+
#[track_caller]
21+
pub fn check(actual: &[RecordBatch], expect: Expect) {
22+
expect.assert_eq(&pretty_format_batches(actual).unwrap().to_string());
23+
}

arrow-udf/tests/tests.rs

Lines changed: 42 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,15 @@ use arrow_array::cast::AsArray;
2020
use arrow_array::temporal_conversions::time_to_time64us;
2121
use arrow_array::types::{Date32Type, Int32Type};
2222
use arrow_array::*;
23-
use arrow_cast::pretty::pretty_format_batches;
2423
use arrow_schema::{DataType, Field, Schema, TimeUnit};
2524
use arrow_udf::function;
2625
use arrow_udf::types::*;
27-
use expect_test::{expect, Expect};
26+
use cases::visibility_tests::{maybe_visible_pub_crate_udf, maybe_visible_pub_udf};
27+
use common::check;
28+
use expect_test::expect;
29+
30+
mod cases;
31+
mod common;
2832

2933
// test no return value
3034
#[function("null()")]
@@ -670,10 +674,42 @@ fn test_json_array_elements() {
670674
);
671675
}
672676

673-
/// Compare the actual output with the expected output.
674-
#[track_caller]
675-
fn check(actual: &[RecordBatch], expect: Expect) {
676-
expect.assert_eq(&pretty_format_batches(actual).unwrap().to_string());
677+
#[test]
678+
fn test_pub() {
679+
let schema = Schema::new(vec![Field::new("uint32", DataType::UInt32, true)]);
680+
let arg0 = UInt32Array::from(vec![Some(1), None]);
681+
let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
682+
683+
let output = maybe_visible_pub_udf(&input).unwrap();
684+
check(
685+
&[output],
686+
expect![[r#"
687+
+---------------+
688+
| maybe_visible |
689+
+---------------+
690+
| 1 |
691+
| |
692+
+---------------+"#]],
693+
);
694+
}
695+
696+
#[test]
697+
fn test_pub_crate() {
698+
let schema = Schema::new(vec![Field::new("float32", DataType::Float32, true)]);
699+
let arg0 = Float32Array::from(vec![Some(1.0), None]);
700+
let input = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(arg0)]).unwrap();
701+
702+
let output = maybe_visible_pub_crate_udf(&input).unwrap();
703+
check(
704+
&[output],
705+
expect![[r#"
706+
+---------------+
707+
| maybe_visible |
708+
+---------------+
709+
| 1.0 |
710+
| |
711+
+---------------+"#]],
712+
);
677713
}
678714

679715
/// Returns a field with JSON type.

0 commit comments

Comments
 (0)