
Commit 15bbadb

fix: correct schema type checking in native_iceberg_compat (#1755)
1 parent 7cfeb8b commit 15bbadb

File tree

4 files changed: +102 −78 lines changed

common/src/main/java/org/apache/comet/parquet/NativeBatchReader.java

Lines changed: 98 additions & 11 deletions
@@ -25,11 +25,11 @@
 import java.lang.reflect.InvocationTargetException;
 import java.lang.reflect.Method;
 import java.net.URI;
-import java.net.URISyntaxException;
 import java.nio.channels.Channels;
 import java.util.*;
 
 import scala.Option;
+import scala.collection.JavaConverters;
 import scala.collection.Seq;
 import scala.collection.mutable.Buffer;
 
@@ -52,6 +52,7 @@
 import org.apache.parquet.column.ColumnDescriptor;
 import org.apache.parquet.hadoop.metadata.BlockMetaData;
 import org.apache.parquet.hadoop.metadata.ParquetMetadata;
+import org.apache.parquet.schema.GroupType;
 import org.apache.parquet.schema.MessageType;
 import org.apache.parquet.schema.Type;
 import org.apache.spark.TaskContext;
@@ -61,6 +62,7 @@
 import org.apache.spark.sql.comet.parquet.CometParquetReadSupport;
 import org.apache.spark.sql.comet.util.Utils$;
 import org.apache.spark.sql.execution.datasources.PartitionedFile;
+import org.apache.spark.sql.execution.datasources.parquet.ParquetColumn;
 import org.apache.spark.sql.execution.datasources.parquet.ParquetToSparkSchemaConverter;
 import org.apache.spark.sql.execution.metric.SQLMetric;
 import org.apache.spark.sql.types.DataType;
@@ -76,8 +78,6 @@
 import org.apache.comet.vector.CometVector;
 import org.apache.comet.vector.NativeUtil;
 
-import static org.apache.comet.parquet.TypeUtil.isEqual;
-
 /**
  * A vectorized Parquet reader that reads a Parquet file in a batched fashion.
  *
@@ -113,6 +113,7 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
 
   private StructType sparkSchema;
   private StructType dataSchema;
+  MessageType fileSchema;
   private MessageType requestedSchema;
   private CometVector[] vectors;
   private AbstractColumnReader[] columnReaders;
@@ -124,6 +125,8 @@ public class NativeBatchReader extends RecordReader<Void, ColumnarBatch> impleme
   private ParquetMetadata footer;
   private byte[] nativeFilter;
 
+  private ParquetColumn parquetColumn;
+
   /**
    * Whether the native scan should always return decimal represented by 128 bits, regardless of its
    * precision. Normally, this should be true if native execution is enabled, since Arrow compute
@@ -229,7 +232,13 @@ public NativeBatchReader(AbstractColumnReader[] columnReaders) {
    * Initialize this reader. The reason we don't do it in the constructor is that we want to close
    * any resource hold by this reader when error happens during the initialization.
    */
-  public void init() throws URISyntaxException, IOException {
+  public void init() throws Throwable {
+
+    conf.set("spark.sql.parquet.binaryAsString", "false");
+    conf.set("spark.sql.parquet.int96AsTimestamp", "false");
+    conf.set("spark.sql.caseSensitive", "false");
+    conf.set("spark.sql.parquet.inferTimestampNTZ.enabled", "true");
+    conf.set("spark.sql.legacy.parquet.nanosAsLong", "false");
 
     useDecimal128 =
         conf.getBoolean(
@@ -257,10 +266,11 @@ public void init() throws URISyntaxException, IOException {
         CometInputFile.fromPath(path, conf), footer, readOptions, cometReadOptions, metrics)) {
 
       requestedSchema = footer.getFileMetaData().getSchema();
-      MessageType fileSchema = requestedSchema;
+      fileSchema = requestedSchema;
+      ParquetToSparkSchemaConverter converter = new ParquetToSparkSchemaConverter(conf);
 
       if (sparkSchema == null) {
-        sparkSchema = new ParquetToSparkSchemaConverter(conf).convert(requestedSchema);
+        sparkSchema = converter.convert(requestedSchema);
       } else {
         requestedSchema =
             CometParquetReadSupport.clipParquetSchema(
@@ -269,9 +279,11 @@ public void init() throws URISyntaxException, IOException {
           throw new IllegalArgumentException(
               String.format(
                   "Spark schema has %d columns while " + "Parquet schema has %d columns",
-                  sparkSchema.size(), requestedSchema.getColumns().size()));
+                  sparkSchema.size(), requestedSchema.getFieldCount()));
         }
       }
+      this.parquetColumn =
+          converter.convertParquetColumn(requestedSchema, Option.apply(this.sparkSchema));
 
       String timeZoneId = conf.get("spark.sql.session.timeZone");
       // Native code uses "UTC" always as the timeZoneId when converting from spark to arrow schema.
@@ -283,6 +295,8 @@ public void init() throws URISyntaxException, IOException {
       // Create Column readers
       List<Type> fields = requestedSchema.getFields();
       List<Type> fileFields = fileSchema.getFields();
+      ParquetColumn[] parquetFields =
+          JavaConverters.seqAsJavaList(parquetColumn.children()).toArray(new ParquetColumn[0]);
      int numColumns = fields.size();
      if (partitionSchema != null) numColumns += partitionSchema.size();
      columnReaders = new AbstractColumnReader[numColumns];
@@ -332,9 +346,8 @@ public void init() throws URISyntaxException, IOException {
       } else if (optFileField.isPresent()) {
         // The column we are reading may be a complex type in which case we check if each field in
         // the requested type is in the file type (and the same data type)
-        if (!isEqual(field, optFileField.get())) {
-          throw new UnsupportedOperationException("Schema evolution is not supported");
-        }
+        // This makes the same check as Spark's VectorizedParquetReader
+        checkColumn(parquetFields[i]);
         missingColumns[i] = false;
       } else {
         if (field.getRepetition() == Type.Repetition.REQUIRED) {
@@ -407,6 +420,77 @@ public void init() throws URISyntaxException, IOException {
     isInitialized = true;
   }
 
+  private void checkParquetType(ParquetColumn column) throws IOException {
+    String[] path = JavaConverters.seqAsJavaList(column.path()).toArray(new String[0]);
+    if (containsPath(fileSchema, path)) {
+      if (column.isPrimitive()) {
+        ColumnDescriptor desc = column.descriptor().get();
+        ColumnDescriptor fd = fileSchema.getColumnDescription(desc.getPath());
+        TypeUtil.checkParquetType(fd, column.sparkType());
+      } else {
+        for (ParquetColumn childColumn : JavaConverters.seqAsJavaList(column.children())) {
+          checkColumn(childColumn);
+        }
+      }
+    } else { // A missing column which is either primitive or complex
+      if (column.required()) {
+        // Column is missing in data but the required data is non-nullable. This file is invalid.
+        throw new IOException(
+            "Required column is missing in data file. Col: " + Arrays.toString(path));
+      }
+    }
+  }
+
+  /**
+   * From Spark's VectorizedParquetRecordReader: check whether a column from the requested schema
+   * is missing from the file schema, or whether it conforms to the type of the file schema. The
+   * path check uses containsPath below, which, unlike {@link MessageType#containsPath(String[])},
+   * also accepts paths to non-leaf nodes.
+   */
+  private void checkColumn(ParquetColumn column) throws IOException {
+    String[] path = JavaConverters.seqAsJavaList(column.path()).toArray(new String[0]);
+    if (containsPath(fileSchema, path)) {
+      if (column.isPrimitive()) {
+        ColumnDescriptor desc = column.descriptor().get();
+        ColumnDescriptor fd = fileSchema.getColumnDescription(desc.getPath());
+        if (!fd.equals(desc)) {
+          throw new UnsupportedOperationException("Schema evolution not supported.");
+        }
+      } else {
+        for (ParquetColumn childColumn : JavaConverters.seqAsJavaList(column.children())) {
+          checkColumn(childColumn);
+        }
+      }
+    } else { // A missing column which is either primitive or complex
+      if (column.required()) {
+        // Column is missing in data but the required data is non-nullable. This file is invalid.
+        throw new IOException(
+            "Required column is missing in data file. Col: " + Arrays.toString(path));
+      }
+    }
+  }
+
+  /**
+   * Checks whether the given 'path' exists in 'parquetType'. The difference between this and {@link
+   * MessageType#containsPath(String[])} is that the latter only supports paths to leaf nodes, while
+   * this supports paths to both leaf and non-leaf nodes.
+   */
+  private boolean containsPath(Type parquetType, String[] path) {
+    return containsPath(parquetType, path, 0);
+  }
+
+  private boolean containsPath(Type parquetType, String[] path, int depth) {
+    if (path.length == depth) return true;
+    if (parquetType instanceof GroupType) {
+      String fieldName = path[depth];
+      GroupType parquetGroupType = (GroupType) parquetType;
+      if (parquetGroupType.containsField(fieldName)) {
+        return containsPath(parquetGroupType.getType(fieldName), path, depth + 1);
+      }
+    }
+    return false;
+  }
+
   public void setSparkSchema(StructType schema) {
     this.sparkSchema = schema;
   }
@@ -532,7 +616,10 @@ private int loadNextBatch() throws Throwable {
     if (importer != null) importer.close();
     importer = new CometSchemaImporter(ALLOCATOR);
 
-    List<ColumnDescriptor> columns = requestedSchema.getColumns();
+    for (ParquetColumn childColumn : JavaConverters.seqAsJavaList(parquetColumn.children())) {
+      checkParquetType(childColumn);
+    }
+
     List<Type> fields = requestedSchema.getFields();
     for (int i = 0; i < fields.size(); i++) {
       if (!missingColumns[i]) {

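Note: the new containsPath walk is what lets checkColumn accept paths to non-leaf (group) nodes, which MessageType#containsPath rejects. A minimal, self-contained sketch of that difference, using a hypothetical schema and a helper that mirrors the private method above:

import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.MessageTypeParser;
import org.apache.parquet.schema.Type;

public class ContainsPathSketch {
  // Same recursion as NativeBatchReader.containsPath: an exhausted path means every
  // segment matched; otherwise descend into the group field named by the next segment.
  static boolean containsPath(Type parquetType, String[] path, int depth) {
    if (path.length == depth) return true;
    if (parquetType instanceof GroupType) {
      GroupType group = (GroupType) parquetType;
      if (group.containsField(path[depth])) {
        return containsPath(group.getType(path[depth]), path, depth + 1);
      }
    }
    return false;
  }

  public static void main(String[] args) {
    MessageType schema =
        MessageTypeParser.parseMessageType(
            "message spark_schema { optional group s { optional int32 a; } }");
    System.out.println(containsPath(schema, new String[] {"s"}, 0)); // true: non-leaf path
    System.out.println(containsPath(schema, new String[] {"s", "a"}, 0)); // true: leaf path
    System.out.println(schema.containsPath(new String[] {"s"})); // false: leaf paths only
  }
}
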
common/src/main/java/org/apache/comet/parquet/TypeUtil.java

Lines changed: 0 additions & 58 deletions
@@ -20,8 +20,6 @@
 package org.apache.comet.parquet;
 
 import java.util.Arrays;
-import java.util.List;
-import java.util.Optional;
 
 import org.apache.parquet.column.ColumnDescriptor;
 import org.apache.parquet.schema.*;
@@ -319,60 +317,4 @@ private static boolean isUnsignedIntTypeMatched(
   private static boolean isSpark40Plus() {
     return package$.MODULE$.SPARK_VERSION().compareTo("4.0") >= 0;
   }
-
-  public static boolean isComplexType(Type t) {
-    return !t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED);
-  }
-
-  // From Parquet Type.java
-  public static boolean eqOrBothNull(Object o1, Object o2) {
-    return (o1 == null && o2 == null) || (o1 != null && o1.equals(o2));
-  }
-
-  // From Parquet Type.java
-  public static boolean equals(Type one, Type other) {
-    return one.getName().equals(other.getName())
-        && one.getRepetition() == other.getRepetition()
-        && eqOrBothNull(one.getRepetition(), other.getRepetition())
-        && eqOrBothNull(one.getId(), other.getId())
-        && eqOrBothNull(one.getLogicalTypeAnnotation(), other.getLogicalTypeAnnotation());
-  }
-
-  //
-  // Compare a field with another field and return true if they are the same. Unlike
-  // the equals method for Type (and derived classes), allows requested to have fields
-  // that are not in actual.
-  //
-  public static boolean isEqual(Type requested, Type actual) {
-    if (requested == null && actual == null) {
-      return true;
-    }
-    if (requested == null || actual == null) {
-      return false;
-    }
-    if (requested.isPrimitive() && actual.isPrimitive()) {
-      return requested.asPrimitiveType().equals(actual.asPrimitiveType());
-    } else if (!requested.isPrimitive() && !actual.isPrimitive()) {
-      if (equals(requested, actual)) {
-        // GroupType.equals also checks if LogicalTypeAnnotation is the same.
-        // But it really is not necessary here.
-        List<Type> requestedFields = requested.asGroupType().getFields();
-        List<Type> actualFields = requested.asGroupType().getFields();
-        for (Type field : requestedFields) {
-          Optional<Type> optActualField =
-              actualFields.stream().filter(f -> f.getName().equals(field.getName())).findFirst();
-          if (optActualField.isPresent()) {
-            if (!isEqual(field, optActualField.get())) {
-              return false;
-            }
-          }
-        }
-      } else {
-        return false;
-      }
-    } else {
-      return false; // one is a primitive type and the other is not.
-    }
-    return true;
-  }
 }

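Note for reviewers: the removed isEqual populated both requestedFields and actualFields from requested.asGroupType(), so nested fields were effectively compared against themselves and group-type mismatches could slip through. The reader now delegates per-column checks to TypeUtil.checkParquetType, as the call in the new checkParquetType method shows. A minimal sketch of that call, assuming a hypothetical single-column file schema:

import org.apache.comet.parquet.TypeUtil;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.schema.MessageType;
import org.apache.parquet.schema.MessageTypeParser;
import org.apache.spark.sql.types.DataTypes;

public class CheckParquetTypeSketch {
  public static void main(String[] args) {
    MessageType fileSchema =
        MessageTypeParser.parseMessageType("message spark_schema { optional int32 id; }");
    // Look the column up in the file schema by path, as checkParquetType does.
    ColumnDescriptor fd = fileSchema.getColumnDescription(new String[] {"id"});
    // No-op when the Spark read type is compatible with the Parquet column;
    // throws an exception describing the mismatch otherwise.
    TypeUtil.checkParquetType(fd, DataTypes.IntegerType);
  }
}
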
common/src/main/scala/org/apache/spark/sql/comet/parquet/CometParquetReadSupport.scala

Lines changed: 2 additions & 8 deletions
@@ -194,8 +194,6 @@ object CometParquetReadSupport {
           .addField(clipParquetType(repeatedGroup, elementType, caseSensitive, useFieldId))
           .named(parquetList.getName)
       } else {
-        // Otherwise, the repeated field's type is the element type with the repeated field's
-        // repetition.
         val newRepeatedGroup = Types
           .repeatedGroup()
           .addField(
@@ -208,15 +206,11 @@ object CometParquetReadSupport {
           newRepeatedGroup
         }
 
+      // Otherwise, the repeated field's type is the element type with the repeated field's
+      // repetition.
       Types
         .buildGroup(parquetList.getRepetition)
         .as(LogicalTypeAnnotation.listType())
-        .addField(
-          Types
-            .repeatedGroup()
-            .addField(
-              clipParquetType(repeatedGroup.getType(0), elementType, caseSensitive, useFieldId))
-            .named(repeatedGroup.getName))
         .addField(newElementType)
         .named(parquetList.getName)
     }

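Note: the clipping logic above rebuilds the standard 3-level Parquet LIST layout, and the fix drops a second, duplicated addField of the repeated group so the clipped list keeps exactly one repeated child (newElementType). For reference, a minimal sketch with hypothetical field names that builds that layout with the same parquet Types builder API used in the diff:

import org.apache.parquet.schema.GroupType;
import org.apache.parquet.schema.LogicalTypeAnnotation;
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName;
import org.apache.parquet.schema.Type;
import org.apache.parquet.schema.Types;

public class ListLayoutSketch {
  public static void main(String[] args) {
    // optional group ints (LIST) { repeated group list { optional int32 element; } }
    GroupType list =
        Types.buildGroup(Type.Repetition.OPTIONAL)
            .as(LogicalTypeAnnotation.listType())
            .addField(
                Types.repeatedGroup()
                    .addField(Types.optional(PrimitiveTypeName.INT32).named("element"))
                    .named("list"))
            .named("ints");
    System.out.println(list); // exactly one repeated group wrapping one element field
  }
}
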
spark/src/test/scala/org/apache/comet/parquet/ParquetReadSuite.scala

Lines changed: 2 additions & 1 deletion
@@ -1233,7 +1233,8 @@ abstract class ParquetReadSuite extends CometTestBase {
 
     withParquetDataFrame(data, schema = Some(readSchema)) { df =>
       // TODO: validate with Spark 3.x and 'usingDataFusionParquetExec=true'
-      if (enableSchemaEvolution || usingDataSourceExec(conf)) {
+      if (enableSchemaEvolution || CometConf.COMET_NATIVE_SCAN_IMPL
+          .get(conf) == CometConf.SCAN_NATIVE_DATAFUSION) {
         checkAnswer(df, data.map(Row.fromTuple))
       } else {
         assertThrows[SparkException](df.collect())

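Note: the test now gates on the scan implementation directly rather than on usingDataSourceExec. A minimal sketch of selecting the native_datafusion scan in a session; the config key and value strings here are assumptions, with CometConf.COMET_NATIVE_SCAN_IMPL and CometConf.SCAN_NATIVE_DATAFUSION being the authoritative definitions:

import org.apache.spark.sql.SparkSession;

public class ScanImplSketch {
  public static void main(String[] args) {
    SparkSession spark =
        SparkSession.builder()
            .master("local[1]")
            // Assumed key/value strings; prefer the CometConf constants in real code.
            .config("spark.comet.scan.impl", "native_datafusion")
            .getOrCreate();
    spark.stop();
  }
}
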