88 changes: 87 additions & 1 deletion native/spark-expr/src/conversion_funcs/cast.rs
@@ -19,8 +19,9 @@ use crate::utils::array_with_timezone;
use crate::{timezone, BinaryOutputStyle};
use crate::{EvalMode, SparkError, SparkResult};
use arrow::array::builder::StringBuilder;
use arrow::array::{DictionaryArray, GenericByteArray, StringArray, StructArray};
use arrow::array::{DictionaryArray, GenericByteArray, ListArray, StringArray, StructArray};
use arrow::compute::can_cast_types;
use arrow::datatypes::DataType::Utf8;
use arrow::datatypes::{
ArrowDictionaryKeyType, ArrowNativeType, DataType, GenericBinaryType, Schema,
};
@@ -1023,6 +1024,7 @@ fn cast_array(
to_type,
cast_options,
)?),
(List(_), Utf8) => Ok(cast_array_to_string(array.as_list(), cast_options)?),
(List(_), List(_)) if can_cast_types(from_type, to_type) => {
Ok(cast_with_options(&array, to_type, &CAST_OPTIONS)?)
}
@@ -1239,6 +1241,39 @@ fn cast_struct_to_struct(
}
}

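/// Spark-compatible cast from a `ListArray` to a `StringArray`: each row is rendered
/// as `[elem1, elem2, ...]`, with elements produced by recursively casting the row's
/// values to Utf8 and nulls rendered with the configured null string.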
fn cast_array_to_string(
array: &ListArray,
spark_cast_options: &SparkCastOptions,
) -> DataFusionResult<ArrayRef> {
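// 16 bytes per row is a rough initial-capacity guess; both buffers grow as needed.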
let mut builder = StringBuilder::with_capacity(array.len(), array.len() * 16);
let mut str = String::with_capacity(array.len() * 16);
for row_index in 0..array.len() {
if array.is_null(row_index) {
builder.append_null();
} else {
str.clear();
let value_ref = array.value(row_index);
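// Recursively cast this row's element slice to Utf8 so elements are formatted
// by the same rules as a scalar cast.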
let native_cast_result = cast_array(value_ref, &Utf8, spark_cast_options)?;
let string_array = native_cast_result
.as_any()
.downcast_ref::<StringArray>()
.expect("cast to Utf8 must produce a StringArray");
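// Assemble the Spark-style rendering: '[' + ", "-separated elements + ']'.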
let mut any_fields_written = false;
str.push('[');
for s in string_array.iter() {
if any_fields_written {
str.push_str(", ");
}
str.push_str(s.unwrap_or(&spark_cast_options.null_string));
any_fields_written = true;
}
str.push(']');
builder.append_value(&str);
}
}
Ok(Arc::new(builder.finish()))
}

fn casts_struct_to_string(
array: &StructArray,
spark_cast_options: &SparkCastOptions,
@@ -2825,4 +2860,55 @@ mod tests {
assert!(casted.is_null(8));
assert!(casted.is_null(9));
}

#[test]
fn test_cast_string_array_to_string() {
use arrow::array::ListArray;
use arrow::buffer::OffsetBuffer;
let values_array =
StringArray::from(vec![Some("a"), Some("b"), Some("c"), Some("a"), None, None]);
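// Offsets [0, 3, 5, 6, 6] slice the values into rows: [a, b, c], [a, null], [null], [].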
let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
let item_field = Arc::new(Field::new("item", DataType::Utf8, true));
let list_array = Arc::new(ListArray::new(
item_field,
offsets_buffer,
Arc::new(values_array),
None,
));
let string_array = cast_array_to_string(
&list_array,
&SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
)
.unwrap();
let string_array = string_array.as_string::<i32>();
assert_eq!(r#"[a, b, c]"#, string_array.value(0));
assert_eq!(r#"[a, null]"#, string_array.value(1));
assert_eq!(r#"[null]"#, string_array.value(2));
assert_eq!(r#"[]"#, string_array.value(3));
}

#[test]
fn test_cast_i32_array_to_string() {
use arrow::array::ListArray;
use arrow::buffer::OffsetBuffer;
let values_array = Int32Array::from(vec![Some(1), Some(2), Some(3), Some(1), None, None]);
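// Offsets [0, 3, 5, 6, 6] slice the values into rows: [1, 2, 3], [1, null], [null], [].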
let offsets_buffer = OffsetBuffer::<i32>::new(vec![0, 3, 5, 6, 6].into());
let item_field = Arc::new(Field::new("item", DataType::Int32, true));
let list_array = Arc::new(ListArray::new(
item_field,
offsets_buffer,
Arc::new(values_array),
None,
));
let string_array = cast_array_to_string(
&list_array,
&SparkCastOptions::new(EvalMode::Legacy, "UTC", false),
)
.unwrap();
let string_array = string_array.as_string::<i32>();
assert_eq!(r#"[1, 2, 3]"#, string_array.value(0));
assert_eq!(r#"[1, null]"#, string_array.value(1));
assert_eq!(r#"[null]"#, string_array.value(2));
assert_eq!(r#"[]"#, string_array.value(3));
}
}
@@ -105,6 +105,8 @@ object CometCast extends CometExpressionSerde[Cast] with CometExprShim {

(fromType, toType) match {
case (dt: ArrayType, _: ArrayType) if dt.elementType == NullType => Compatible()
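// An array-to-string cast is supported exactly when its element type can be cast to string.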
case (dt: ArrayType, DataTypes.StringType) =>
isSupported(dt.elementType, DataTypes.StringType, timeZoneId, evalMode)
case (dt: ArrayType, dt1: ArrayType) =>
isSupported(dt.elementType, dt1.elementType, timeZoneId, evalMode)
case (dt: DataType, _) if dt.typeName == "timestamp_ntz" =>
36 changes: 35 additions & 1 deletion spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -21,6 +21,7 @@ package org.apache.comet

import java.io.File

import scala.collection.mutable.ListBuffer
import scala.util.Random
import scala.util.matching.Regex

@@ -30,10 +31,11 @@ import org.apache.spark.sql.catalyst.expressions.Cast
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{DataType, DataTypes, DecimalType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, BinaryType, BooleanType, ByteType, DataType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType, StructField, StructType}

import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.rules.CometScanTypeChecker
import org.apache.comet.serde.Compatible

class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
@@ -1034,6 +1036,32 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
castTest(generateDecimalsPrecision10Scale2(), DataTypes.createDecimalType(10, 0))
}

test("cast ArrayType to StringType") {

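// Resolve the active scan implementation: the CI environment variable takes
// precedence over the session config.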
val scanImpl = sys.env
.getOrElse(
"COMET_PARQUET_SCAN_IMPL",
conf.getConfString(CometConf.COMET_NATIVE_SCAN_IMPL.key, "native_comet"))
val cometScanTypeChecker = CometScanTypeChecker(scanImpl)
val hasIncompatibleType =
(dt: DataType) => !cometScanTypeChecker.isTypeSupported(dt, scanImpl, ListBuffer.empty)
Seq(
BooleanType,
StringType,
ByteType,
IntegerType,
LongType,
ShortType,
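// FloatType and DoubleType are left out below, presumably until native
// float-to-string formatting matches Spark's.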
// FloatType,
// DoubleType,
DecimalType(10, 2),
DecimalType(38, 18),
BinaryType).foreach { dt =>
val input = generateArrays(100, dt)
castTest(input, StringType, hasIncompatibleType(input.schema))
}
}

private def generateFloats(): DataFrame = {
withNulls(gen.generateFloats(dataSize)).toDF("a")
}
@@ -1062,6 +1090,12 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
withNulls(gen.generateLongs(dataSize)).toDF("a")
}

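// Builds a single-column DataFrame ("a": ArrayType(elementType)) of fuzz-generated rows.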
private def generateArrays(rowSize: Int, elementType: DataType): DataFrame = {
import scala.collection.JavaConverters._
val schema = StructType(Seq(StructField("a", ArrayType(elementType), true)))
spark.createDataFrame(gen.generateRows(rowSize, schema).asJava, schema)
}

// https://github.com/apache/datafusion-comet/issues/2038
test("test implicit cast to dictionary with case when and dictionary type") {
withSQLConf("parquet.enable.dictionary" -> "true") {
@@ -19,10 +19,12 @@

package org.apache.spark.sql

import org.apache.comet.CometFuzzTestBase
import scala.collection.mutable.ListBuffer

import org.apache.comet.{CometConf, CometFuzzTestBase}
import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.rules.CometScanTypeChecker
import org.apache.comet.serde.Compatible

import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Alias, ToPrettyString}
@@ -32,6 +34,8 @@ import org.apache.spark.sql.types.DataTypes
class CometToPrettyStringSuite extends CometFuzzTestBase {

test("ToPrettyString") {
val scanImpl = conf.getConfString(CometConf.COMET_NATIVE_SCAN_IMPL.key)
val cometScanTypeChecker = CometScanTypeChecker(scanImpl)
val df = spark.read.parquet(filename)
df.createOrReplaceTempView("t1")
val table = spark.sessionState.catalog.lookupRelation(TableIdentifier("t1"))
@@ -43,7 +47,8 @@ class CometToPrettyStringSuite extends CometFuzzTestBase {
val analyzed = spark.sessionState.analyzer.execute(plan)
val result: DataFrame = Dataset.ofRows(spark, analyzed)
CometCast.isSupported(field.dataType, DataTypes.StringType, Some(spark.sessionState.conf.sessionLocalTimeZone), CometEvalMode.TRY) match {
case _: Compatible => checkSparkAnswerAndOperator(result)
case _: Compatible if cometScanTypeChecker.isTypeSupported(field.dataType, scanImpl, ListBuffer.empty) =>
checkSparkAnswerAndOperator(result)
case _ => checkSparkAnswer(result)
}
}
@@ -19,8 +19,11 @@

package org.apache.spark.sql

import org.apache.comet.CometFuzzTestBase
import scala.collection.mutable.ListBuffer

import org.apache.comet.{CometConf, CometFuzzTestBase}
import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.rules.CometScanTypeChecker
import org.apache.comet.serde.Compatible
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
@@ -34,6 +37,8 @@ import org.apache.spark.sql.types.DataTypes
class CometToPrettyStringSuite extends CometFuzzTestBase {

test("ToPrettyString") {
val scanImpl = conf.getConfString(CometConf.COMET_NATIVE_SCAN_IMPL.key)
val cometScanTypeChecker = CometScanTypeChecker(scanImpl)
val style = List(
BinaryOutputStyle.UTF8,
BinaryOutputStyle.BASIC,
@@ -54,7 +59,8 @@ class CometToPrettyStringSuite extends CometFuzzTestBase {
val analyzed = spark.sessionState.analyzer.execute(plan)
val result: DataFrame = Dataset.ofRows(spark, analyzed)
CometCast.isSupported(field.dataType, DataTypes.StringType, Some(spark.sessionState.conf.sessionLocalTimeZone), CometEvalMode.TRY) match {
case _: Compatible => checkSparkAnswerAndOperator(result)
case _: Compatible if cometScanTypeChecker.isTypeSupported(field.dataType, scanImpl, ListBuffer.empty) =>
checkSparkAnswerAndOperator(result)
case _ => checkSparkAnswer(result)
}
}