Commit 3a45981

Yaohua628 authored and cloud-fan committed
[SPARK-37896][SQL] Implement a ConstantColumnVector and improve performance of the hidden file metadata
### What changes were proposed in this pull request?
Implement a new column vector named `ConstantColumnVector`, which avoids copying the same data for all rows by storing only one copy of the data. Also, improve the performance of the hidden file metadata in FileScanRDD.

### Why are the changes needed?
Performance improvements.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
A new test suite.

Closes apache#35068 from Yaohua628/spark-37770.

Authored-by: yaohua <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
1 parent 817d1d7 commit 3a45981

File tree

4 files changed: +515 -33 lines changed

sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ConstantColumnVector.java

Lines changed: 292 additions & 0 deletions
@@ -0,0 +1,292 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.spark.sql.execution.vectorized;

import java.math.BigDecimal;
import java.math.BigInteger;

import org.apache.spark.sql.types.*;
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarArray;
import org.apache.spark.sql.vectorized.ColumnarMap;
import org.apache.spark.unsafe.types.UTF8String;

/**
 * This class adds constant support to ColumnVector.
 * It supports all the types and contains `set` APIs,
 * which set the exact same value for all rows.
 *
 * Capacity: the vector stores only one copy of the data.
 */
public class ConstantColumnVector extends ColumnVector {

  // The data stored in this ConstantColumnVector; the vector stores only one copy of the data.
  private byte nullData;
  private byte byteData;
  private short shortData;
  private int intData;
  private long longData;
  private float floatData;
  private double doubleData;
  private UTF8String stringData;
  private byte[] byteArrayData;
  private ConstantColumnVector[] childData;
  private ColumnarArray arrayData;
  private ColumnarMap mapData;

  private final int numRows;

  /**
   * @param numRows the number of rows for this ConstantColumnVector
   * @param type the data type of this ConstantColumnVector
   */
  public ConstantColumnVector(int numRows, DataType type) {
    super(type);
    this.numRows = numRows;

    if (type instanceof StructType) {
      this.childData = new ConstantColumnVector[((StructType) type).fields().length];
    } else if (type instanceof CalendarIntervalType) {
      // Three columns: months as int, days as int, microseconds as long.
      this.childData = new ConstantColumnVector[3];
    } else {
      this.childData = null;
    }
  }

  @Override
  public void close() {
    byteArrayData = null;
    // childData is only allocated for struct and calendar-interval types, so guard
    // against null before iterating (an unguarded loop would NPE for all other types)
    if (childData != null) {
      for (int i = 0; i < childData.length; i++) {
        if (childData[i] != null) {
          childData[i].close();
          childData[i] = null;
        }
      }
      childData = null;
    }
    arrayData = null;
    mapData = null;
  }

  @Override
  public boolean hasNull() {
    return nullData == 1;
  }

  @Override
  public int numNulls() {
    return hasNull() ? numRows : 0;
  }

  @Override
  public boolean isNullAt(int rowId) {
    return nullData == 1;
  }

  /**
   * Sets all rows as `null`
   */
  public void setNull() {
    nullData = (byte) 1;
  }

  /**
   * Sets all rows as not `null`
   */
  public void setNotNull() {
    nullData = (byte) 0;
  }

  @Override
  public boolean getBoolean(int rowId) {
    return byteData == 1;
  }

  /**
   * Sets the boolean `value` for all rows
   */
  public void setBoolean(boolean value) {
    byteData = (byte) (value ? 1 : 0);
  }

  @Override
  public byte getByte(int rowId) {
    return byteData;
  }

  /**
   * Sets the byte `value` for all rows
   */
  public void setByte(byte value) {
    byteData = value;
  }

  @Override
  public short getShort(int rowId) {
    return shortData;
  }

  /**
   * Sets the short `value` for all rows
   */
  public void setShort(short value) {
    shortData = value;
  }

  @Override
  public int getInt(int rowId) {
    return intData;
  }

  /**
   * Sets the int `value` for all rows
   */
  public void setInt(int value) {
    intData = value;
  }

  @Override
  public long getLong(int rowId) {
    return longData;
  }

  /**
   * Sets the long `value` for all rows
   */
  public void setLong(long value) {
    longData = value;
  }

  @Override
  public float getFloat(int rowId) {
    return floatData;
  }

  /**
   * Sets the float `value` for all rows
   */
  public void setFloat(float value) {
    floatData = value;
  }

  @Override
  public double getDouble(int rowId) {
    return doubleData;
  }

  /**
   * Sets the double `value` for all rows
   */
  public void setDouble(double value) {
    doubleData = value;
  }

  @Override
  public ColumnarArray getArray(int rowId) {
    return arrayData;
  }

  /**
   * Sets the `ColumnarArray` `value` for all rows
   */
  public void setArray(ColumnarArray value) {
    arrayData = value;
  }

  @Override
  public ColumnarMap getMap(int ordinal) {
    return mapData;
  }

  /**
   * Sets the `ColumnarMap` `value` for all rows
   */
  public void setMap(ColumnarMap value) {
    mapData = value;
  }

  @Override
  public Decimal getDecimal(int rowId, int precision, int scale) {
    // copied and modified from WritableColumnVector
    if (precision <= Decimal.MAX_INT_DIGITS()) {
      return Decimal.createUnsafe(getInt(rowId), precision, scale);
    } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
      return Decimal.createUnsafe(getLong(rowId), precision, scale);
    } else {
      byte[] bytes = getBinary(rowId);
      BigInteger bigInteger = new BigInteger(bytes);
      BigDecimal javaDecimal = new BigDecimal(bigInteger, scale);
      return Decimal.apply(javaDecimal, precision, scale);
    }
  }

  /**
   * Sets the `Decimal` `value` with the given precision for all rows
   */
  public void setDecimal(Decimal value, int precision) {
    // copied and modified from WritableColumnVector
    if (precision <= Decimal.MAX_INT_DIGITS()) {
      setInt((int) value.toUnscaledLong());
    } else if (precision <= Decimal.MAX_LONG_DIGITS()) {
      setLong(value.toUnscaledLong());
    } else {
      BigInteger bigInteger = value.toJavaBigDecimal().unscaledValue();
      setByteArray(bigInteger.toByteArray());
    }
  }

  @Override
  public UTF8String getUTF8String(int rowId) {
    return stringData;
  }

  /**
   * Sets the `UTF8String` `value` for all rows
   */
  public void setUtf8String(UTF8String value) {
    stringData = value;
  }

  /**
   * Sets the byte array `value` for all rows
   */
  private void setByteArray(byte[] value) {
    byteArrayData = value;
  }

  @Override
  public byte[] getBinary(int rowId) {
    return byteArrayData;
  }

  /**
   * Sets the binary `value` for all rows
   */
  public void setBinary(byte[] value) {
    setByteArray(value);
  }

  @Override
  public ColumnVector getChild(int ordinal) {
    return childData[ordinal];
  }

  /**
   * Sets the child `ConstantColumnVector` `value` at the given ordinal for all rows
   */
  public void setChild(int ordinal, ConstantColumnVector value) {
    childData[ordinal] = value;
  }
}
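
Editor's note: for context, a minimal usage sketch (not part of the commit) — the vector is sized for a batch, a value is set once, and every `rowId` read returns that same value. The batch size and file name below are illustrative.

import org.apache.spark.sql.execution.vectorized.ConstantColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.unsafe.types.UTF8String;

public class ConstantColumnVectorExample {
  public static void main(String[] args) {
    // One vector backing 1024 rows, but only a single UTF8String is stored.
    ConstantColumnVector fileName = new ConstantColumnVector(1024, DataTypes.StringType);
    fileName.setUtf8String(UTF8String.fromString("part-00000.parquet"));

    // Every rowId reads the same stored value; no per-row copies exist.
    assert fileName.getUTF8String(0).equals(fileName.getUTF8String(1023));

    fileName.close();
  }
}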

sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala

Lines changed: 3 additions & 3 deletions
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat => ParquetSource}
 import org.apache.spark.sql.execution.datasources.v2.PushedDownOperators
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
-import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
+import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.{BaseRelation, Filter}
 import org.apache.spark.sql.types.StructType
@@ -221,8 +221,8 @@ case class FileSourceScanExec(
       requiredSchema = requiredSchema,
       partitionSchema = relation.partitionSchema,
       relation.sparkSession.sessionState.conf).map { vectorTypes =>
-      // for column-based file format, append metadata struct column's vector type classes if any
-      vectorTypes ++ Seq.fill(metadataColumns.size)(classOf[OnHeapColumnVector].getName)
+      // for column-based file format, append metadata column's vector type classes if any
+      vectorTypes ++ Seq.fill(metadataColumns.size)(classOf[ConstantColumnVector].getName)
     }

   private lazy val driverMetrics: HashMap[String, Long] = HashMap.empty
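
Editor's note (not from the commit): `vectorTypes` is a codegen hint — the generated columnar scan casts each column to the concrete class named here rather than dispatching through the `ColumnVector` interface, which is why the entries must switch along with the vector implementation. A trivial Java check of what the appended entries resolve to:

import org.apache.spark.sql.execution.vectorized.ConstantColumnVector;

public class VectorTypeName {
  public static void main(String[] args) {
    // Each metadata column contributes one entry with this concrete class name,
    // which whole-stage codegen uses for its casts.
    System.out.println(ConstantColumnVector.class.getName());
    // org.apache.spark.sql.execution.vectorized.ConstantColumnVector
  }
}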

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala

Lines changed: 15 additions & 30 deletions
@@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{AttributeReference, GenericInternalRow, JoinedRow, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.errors.QueryExecutionErrors
 import org.apache.spark.sql.execution.datasources.FileFormat._
-import org.apache.spark.sql.execution.vectorized.{OnHeapColumnVector, WritableColumnVector}
+import org.apache.spark.sql.execution.vectorized.ConstantColumnVector
 import org.apache.spark.sql.types.{LongType, StringType, StructType}
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.NextIterator
@@ -133,49 +133,35 @@ class FileScanRDD(
    * For each partitioned file, the metadata columns are exactly the same for every record in the file.
    * Only update the metadata row when `currentFile` changes.
    */
-  private def updateMetadataRow(): Unit = {
+  private def updateMetadataRow(): Unit =
     if (metadataColumns.nonEmpty && currentFile != null) {
       updateMetadataInternalRow(metadataRow, metadataColumns.map(_.name),
         new Path(currentFile.filePath), currentFile.fileSize, currentFile.modificationTime)
     }
-  }

   /**
-   * Create a writable column vector containing all required metadata columns
+   * Create an array of constant column vectors containing all required metadata columns
    */
-  private def createMetadataColumnVector(c: ColumnarBatch): Array[WritableColumnVector] = {
+  private def createMetadataColumnVector(c: ColumnarBatch): Array[ConstantColumnVector] = {
     val path = new Path(currentFile.filePath)
-    val filePathBytes = path.toString.getBytes
-    val fileNameBytes = path.getName.getBytes
-    var rowId = 0
     metadataColumns.map(_.name).map {
       case FILE_PATH =>
-        val columnVector = new OnHeapColumnVector(c.numRows(), StringType)
-        rowId = 0
-        // use a tight-loop for better performance
-        while (rowId < c.numRows()) {
-          columnVector.putByteArray(rowId, filePathBytes)
-          rowId += 1
-        }
+        val columnVector = new ConstantColumnVector(c.numRows(), StringType)
+        columnVector.setUtf8String(UTF8String.fromString(path.toString))
         columnVector
       case FILE_NAME =>
-        val columnVector = new OnHeapColumnVector(c.numRows(), StringType)
-        rowId = 0
-        // use a tight-loop for better performance
-        while (rowId < c.numRows()) {
-          columnVector.putByteArray(rowId, fileNameBytes)
-          rowId += 1
-        }
+        val columnVector = new ConstantColumnVector(c.numRows(), StringType)
+        columnVector.setUtf8String(UTF8String.fromString(path.getName))
         columnVector
       case FILE_SIZE =>
-        val columnVector = new OnHeapColumnVector(c.numRows(), LongType)
-        columnVector.putLongs(0, c.numRows(), currentFile.fileSize)
+        val columnVector = new ConstantColumnVector(c.numRows(), LongType)
+        columnVector.setLong(currentFile.fileSize)
         columnVector
       case FILE_MODIFICATION_TIME =>
-        val columnVector = new OnHeapColumnVector(c.numRows(), LongType)
+        val columnVector = new ConstantColumnVector(c.numRows(), LongType)
         // the modificationTime from the file is in milliseconds,
         // while internally, the TimestampType is stored in microseconds
-        columnVector.putLongs(0, c.numRows(), currentFile.modificationTime * 1000L)
+        columnVector.setLong(currentFile.modificationTime * 1000L)
         columnVector
     }.toArray
   }
@@ -187,10 +173,9 @@ class FileScanRDD(
   private def addMetadataColumnsIfNeeded(nextElement: Object): Object = {
     if (metadataColumns.nonEmpty) {
       nextElement match {
-        case c: ColumnarBatch =>
-          new ColumnarBatch(
-            Array.tabulate(c.numCols())(c.column) ++ createMetadataColumnVector(c),
-            c.numRows())
+        case c: ColumnarBatch => new ColumnarBatch(
+          Array.tabulate(c.numCols())(c.column) ++ createMetadataColumnVector(c),
+          c.numRows())
         case u: UnsafeRow => projection.apply(new JoinedRow(u, metadataRow))
         case i: InternalRow => new JoinedRow(i, metadataRow)
       }
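
Editor's note: to make the win concrete, a hedged before/after sketch (not from the commit — the batch size and path are illustrative, and the `OnHeapColumnVector` calls mirror the removed Scala lines above). The old path wrote the same bytes once per row; the new path stores a single value regardless of batch size.

import org.apache.spark.sql.execution.vectorized.ConstantColumnVector;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.unsafe.types.UTF8String;

public class MetadataVectorComparison {
  public static void main(String[] args) {
    int numRows = 4096;
    byte[] filePathBytes = "s3://bucket/part-00000.parquet".getBytes();

    // Before: O(numRows) copies of the same bytes, one per row.
    OnHeapColumnVector before = new OnHeapColumnVector(numRows, DataTypes.StringType);
    for (int rowId = 0; rowId < numRows; rowId++) {
      before.putByteArray(rowId, filePathBytes);
    }

    // After: a single stored value, O(1) in both time and space.
    ConstantColumnVector after = new ConstantColumnVector(numRows, DataTypes.StringType);
    after.setUtf8String(UTF8String.fromString("s3://bucket/part-00000.parquet"));

    before.close();
    after.close();
  }
}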
