Skip to content

Commit d13c56b

Browse files
committed
feat: Add Spark-compatible monthname function to datafusion-spark
Implements `monthname(date_or_timestamp)` that returns the three-letter abbreviated month name (Jan, Feb, ..., Dec) from a date or timestamp, matching Apache Spark's behavior.
1 parent 244f891 commit d13c56b

File tree

2 files changed

+304
-0
lines changed

2 files changed

+304
-0
lines changed

datafusion/spark/src/function/datetime/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ pub mod from_utc_timestamp;
2626
pub mod last_day;
2727
pub mod make_dt_interval;
2828
pub mod make_interval;
29+
pub mod monthname;
2930
pub mod next_day;
3031
pub mod time_trunc;
3132
pub mod to_utc_timestamp;
@@ -52,6 +53,7 @@ make_udf_function!(extract::SparkSecond, second);
5253
make_udf_function!(last_day::SparkLastDay, last_day);
5354
make_udf_function!(make_dt_interval::SparkMakeDtInterval, make_dt_interval);
5455
make_udf_function!(make_interval::SparkMakeInterval, make_interval);
56+
make_udf_function!(monthname::SparkMonthName, monthname);
5557
make_udf_function!(next_day::SparkNextDay, next_day);
5658
make_udf_function!(time_trunc::SparkTimeTrunc, time_trunc);
5759
make_udf_function!(to_utc_timestamp::SparkToUtcTimestamp, to_utc_timestamp);
@@ -117,6 +119,11 @@ pub mod expr_fn {
117119
"Make interval from years, months, weeks, days, hours, mins and secs.",
118120
years months weeks days hours mins secs
119121
));
122+
export_functions!((
123+
monthname,
124+
"Returns the three-letter abbreviated month name from a date or timestamp.",
125+
arg1
126+
));
120127
// TODO: add once ANSI support is added:
121128
// "When both of the input parameters are not NULL and day_of_week is an invalid input, the function throws SparkIllegalArgumentException if spark.sql.ansi.enabled is set to true, otherwise NULL."
122129
export_functions!((
@@ -195,6 +202,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
195202
make_dt_interval(),
196203
make_interval(),
197204
minute(),
205+
monthname(),
198206
next_day(),
199207
second(),
200208
time_trunc(),
Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,296 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use arrow::array::{AsArray, StringArray};
21+
use arrow::compute::{DatePart, date_part};
22+
use arrow::datatypes::{DataType, Field, FieldRef};
23+
use datafusion_common::utils::take_function_args;
24+
use datafusion_common::{Result, ScalarValue, exec_err, internal_err};
25+
use datafusion_expr::{
26+
Coercion, ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl,
27+
Signature, TypeSignatureClass, Volatility,
28+
};
29+
30+
/// Three-letter English month abbreviations, indexed by `month - 1`.
const MONTH_NAMES: [&str; 12] = [
    "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
];

/// Maps a 1-based month number to its abbreviated English name.
///
/// Returns `None` for any value outside `1..=12` (including zero and
/// negative inputs).
fn month_number_to_name(month: i32) -> Option<&'static str> {
    if (1..=12).contains(&month) {
        Some(MONTH_NAMES[(month - 1) as usize])
    } else {
        None
    }
}
37+
38+
/// Spark-compatible `monthname` expression.
/// Returns the three-letter abbreviated month name from a date or timestamp.
///
/// NULL inputs produce NULL outputs; see `invoke_with_args` for the
/// scalar vs. array evaluation paths.
///
/// <https://spark.apache.org/docs/latest/api/sql/index.html#monthname>
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct SparkMonthName {
    // One argument coercible to a timestamp; constructed in `Self::new`.
    signature: Signature,
}
46+
47+
impl Default for SparkMonthName {
    /// Delegates to [`SparkMonthName::new`]; there is no other state to
    /// default.
    fn default() -> Self {
        Self::new()
    }
}
52+
53+
impl SparkMonthName {
    /// Creates the UDF with a single-argument signature that coerces the
    /// input toward a timestamp type.
    pub fn new() -> Self {
        Self {
            // NOTE(review): Spark's `monthname` accepts DATE as well as
            // TIMESTAMP inputs, and the tests below invoke this UDF with
            // Date32 values directly. This signature only declares
            // `TypeSignatureClass::Timestamp` — confirm that DataFusion's
            // coercion rules admit Date32/Date64 under this class, or
            // date arguments may be rejected at planning time.
            signature: Signature::coercible(
                vec![Coercion::new_exact(TypeSignatureClass::Timestamp)],
                Volatility::Immutable,
            ),
        }
    }
}
63+
64+
impl ScalarUDFImpl for SparkMonthName {
65+
fn name(&self) -> &str {
66+
"monthname"
67+
}
68+
69+
fn signature(&self) -> &Signature {
70+
&self.signature
71+
}
72+
73+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
74+
internal_err!("return_field_from_args should be used instead")
75+
}
76+
77+
fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<FieldRef> {
78+
let nullable = args.arg_fields.iter().any(|f| f.is_nullable());
79+
Ok(Arc::new(Field::new(self.name(), DataType::Utf8, nullable)))
80+
}
81+
82+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
83+
let [arg] = take_function_args(self.name(), args.args)?;
84+
match arg {
85+
ColumnarValue::Scalar(scalar) => {
86+
if scalar.is_null() {
87+
return Ok(ColumnarValue::Scalar(ScalarValue::Utf8(None)));
88+
}
89+
let arr = scalar.to_array_of_size(1)?;
90+
let month_arr = date_part(&arr, DatePart::Month)?;
91+
let month_val = month_arr
92+
.as_primitive::<arrow::datatypes::Int32Type>()
93+
.value(0);
94+
match month_number_to_name(month_val) {
95+
Some(name) => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(
96+
name.to_string(),
97+
)))),
98+
None => {
99+
exec_err!("Invalid month number: {month_val}")
100+
}
101+
}
102+
}
103+
ColumnarValue::Array(arr) => {
104+
let month_arr = date_part(&arr, DatePart::Month)?;
105+
let int_arr = month_arr.as_primitive::<arrow::datatypes::Int32Type>();
106+
107+
let result: StringArray = int_arr
108+
.iter()
109+
.map(|maybe_month| match maybe_month {
110+
Some(m) => Ok(month_number_to_name(m)),
111+
None => Ok(None),
112+
})
113+
.collect::<Result<StringArray>>()?;
114+
115+
Ok(ColumnarValue::Array(Arc::new(result)))
116+
}
117+
}
118+
}
119+
}
120+
121+
#[cfg(test)]
mod tests {
    use super::*;
    use arrow::array::{Array, ArrayRef, Date32Array};
    use arrow::datatypes::TimeUnit;
    use datafusion_common::config::ConfigOptions;

    // Helper: assembles `ScalarFunctionArgs` with a fixed `Utf8` return
    // field so each test only supplies the input values and fields.
    fn make_args(
        args: Vec<ColumnarValue>,
        arg_fields: Vec<FieldRef>,
        number_rows: usize,
    ) -> ScalarFunctionArgs {
        ScalarFunctionArgs {
            args,
            arg_fields,
            number_rows,
            return_field: Arc::new(Field::new("monthname", DataType::Utf8, true)),
            config_options: Arc::new(ConfigOptions::default()),
        }
    }

    #[test]
    fn test_monthname_scalar_date() {
        let func = SparkMonthName::new();
        // 2024-03-15 = 19797 days since epoch
        let result = func
            .invoke_with_args(make_args(
                vec![ColumnarValue::Scalar(ScalarValue::Date32(Some(19797)))],
                vec![Arc::new(Field::new("d", DataType::Date32, true))],
                1,
            ))
            .unwrap();
        match result {
            ColumnarValue::Scalar(ScalarValue::Utf8(Some(name))) => {
                assert_eq!(name, "Mar");
            }
            other => panic!("Expected scalar Utf8, got {other:?}"),
        }
    }

    #[test]
    fn test_monthname_array_dates() {
        let func = SparkMonthName::new();
        // Mixed valid dates and a NULL to check null propagation.
        let date_array: ArrayRef = Arc::new(Date32Array::from(vec![
            Some(19723), // 2024-01-01 => Jan
            Some(19797), // 2024-03-15 => Mar
            Some(20088), // 2024-12-31 => Dec
            None,
        ]));

        let result = func
            .invoke_with_args(make_args(
                vec![ColumnarValue::Array(date_array)],
                vec![Arc::new(Field::new("d", DataType::Date32, true))],
                4,
            ))
            .unwrap();

        match result {
            ColumnarValue::Array(arr) => {
                let str_arr = arr.as_any().downcast_ref::<StringArray>().unwrap();
                assert_eq!(str_arr.value(0), "Jan");
                assert_eq!(str_arr.value(1), "Mar");
                assert_eq!(str_arr.value(2), "Dec");
                // The NULL input row must stay NULL in the output.
                assert!(str_arr.is_null(3));
            }
            other => panic!("Expected array, got {other:?}"),
        }
    }

    #[test]
    fn test_monthname_null_scalar() {
        let func = SparkMonthName::new();
        // A NULL scalar must short-circuit to a NULL Utf8 scalar.
        let result = func
            .invoke_with_args(make_args(
                vec![ColumnarValue::Scalar(ScalarValue::Date32(None))],
                vec![Arc::new(Field::new("d", DataType::Date32, true))],
                1,
            ))
            .unwrap();
        match result {
            ColumnarValue::Scalar(ScalarValue::Utf8(None)) => {}
            other => panic!("Expected Utf8(None), got {other:?}"),
        }
    }

    #[test]
    fn test_monthname_timestamp_micros() {
        let func = SparkMonthName::new();
        // 2024-07-15 10:10:00 UTC in microseconds
        // (1721038200 s = 19919 days + 36600 s past midnight)
        let result = func
            .invoke_with_args(make_args(
                vec![ColumnarValue::Scalar(ScalarValue::TimestampMicrosecond(
                    Some(1721038200000000),
                    None,
                ))],
                vec![Arc::new(Field::new(
                    "ts",
                    DataType::Timestamp(TimeUnit::Microsecond, None),
                    true,
                ))],
                1,
            ))
            .unwrap();
        match result {
            ColumnarValue::Scalar(ScalarValue::Utf8(Some(name))) => {
                assert_eq!(name, "Jul");
            }
            other => panic!("Expected scalar Utf8, got {other:?}"),
        }
    }

    #[test]
    fn test_monthname_all_months() {
        let func = SparkMonthName::new();
        // The 15th of every month in 2024, covering all twelve outputs.
        let dates: Vec<Option<i32>> = vec![
            Some(19737), // 2024-01-15
            Some(19768), // 2024-02-15
            Some(19797), // 2024-03-15
            Some(19828), // 2024-04-15
            Some(19858), // 2024-05-15
            Some(19889), // 2024-06-15
            Some(19919), // 2024-07-15
            Some(19950), // 2024-08-15
            Some(19981), // 2024-09-15
            Some(20011), // 2024-10-15
            Some(20042), // 2024-11-15
            Some(20072), // 2024-12-15
        ];
        let date_array: ArrayRef = Arc::new(Date32Array::from(dates));

        let result = func
            .invoke_with_args(make_args(
                vec![ColumnarValue::Array(date_array)],
                vec![Arc::new(Field::new("d", DataType::Date32, true))],
                12,
            ))
            .unwrap();

        let expected = [
            "Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov",
            "Dec",
        ];
        match result {
            ColumnarValue::Array(arr) => {
                let str_arr = arr.as_any().downcast_ref::<StringArray>().unwrap();
                for (i, exp) in expected.iter().enumerate() {
                    assert_eq!(str_arr.value(i), *exp, "Month {} mismatch", i + 1);
                }
            }
            other => panic!("Expected array, got {other:?}"),
        }
    }

    #[test]
    fn test_monthname_return_field_nullable() {
        let func = SparkMonthName::new();

        // Nullable input field => nullable Utf8 output field.
        let nullable = func
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::new(Field::new("d", DataType::Date32, true))],
                scalar_arguments: &[None],
            })
            .unwrap();
        assert!(nullable.is_nullable());
        assert_eq!(nullable.data_type(), &DataType::Utf8);

        // Non-nullable input field => non-nullable output field.
        let non_nullable = func
            .return_field_from_args(ReturnFieldArgs {
                arg_fields: &[Arc::new(Field::new("d", DataType::Date32, false))],
                scalar_arguments: &[None],
            })
            .unwrap();
        assert!(!non_nullable.is_nullable());
    }
}

0 commit comments

Comments
 (0)