sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -17,9 +17,8 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

- import java.nio.ByteBuffer
-
- import com.google.common.primitives.{Doubles, Ints, Longs}
+ import org.apache.datasketches.memory.Memory
+ import org.apache.datasketches.quantiles.{DoublesSketch, DoublesUnion, UpdateDoublesSketch}

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
@@ -31,10 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
import org.apache.spark.sql.catalyst.trees.TernaryLike
import org.apache.spark.sql.catalyst.types.PhysicalNumericType
import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
- import org.apache.spark.sql.catalyst.util.QuantileSummaries
- import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats}
import org.apache.spark.sql.types._
- import org.apache.spark.util.ArrayImplicits._

/**
* The ApproximatePercentile function returns the approximate percentile(s) of a column at the given
@@ -267,35 +263,40 @@ object ApproximatePercentile {
// The default relative error can be deduced by defaultError = 1.0 / DEFAULT_PERCENTILE_ACCURACY
val DEFAULT_PERCENTILE_ACCURACY: Int = 10000

+ def nextPowOf2(relativeError: Double): Int = {
+ val baseK = DoublesSketch.getKFromEpsilon(relativeError, true)
+ if (baseK == 1 || (baseK & (baseK - 1)) == 0) {
+ baseK
+ } else {
+ Integer.highestOneBit(baseK) * 2
+ }
+ }
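
Background note (not part of the diff): the classic DataSketches quantiles sketch only accepts K values that are powers of two, which is presumably why this helper rounds the K suggested by getKFromEpsilon up to the next power of two. A minimal, dependency-free sketch of the same rounding step, with hypothetical inputs:

    // Standalone sketch of the rounding logic above; `baseK` stands in for the
    // value returned by DoublesSketch.getKFromEpsilon(relativeError, true).
    def roundToPowOf2(baseK: Int): Int =
      if (baseK == 1 || (baseK & (baseK - 1)) == 0) baseK // already a power of two
      else Integer.highestOneBit(baseK) * 2 // otherwise round up to the next one

    assert(roundToPowOf2(128) == 128) // powers of two pass through unchanged
    assert(roundToPowOf2(129) == 256) // anything else rounds up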

/**
* PercentileDigest is a probabilistic data structure used for approximating percentiles
- * with limited memory. PercentileDigest is backed by [[QuantileSummaries]].
+ * with limited memory. PercentileDigest is backed by [[DoublesSketch]].
*
- * @param summaries underlying probabilistic data structure [[QuantileSummaries]].
+ * @param sketch underlying probabilistic data structure [[DoublesSketch]].
*/
- class PercentileDigest(private var summaries: QuantileSummaries) {
+ class PercentileDigest(private var sketch: UpdateDoublesSketch) {

def this(relativeError: Double) = {
- this(new QuantileSummaries(defaultCompressThreshold, relativeError, compressed = true))
+ this(DoublesSketch.builder().setK(ApproximatePercentile.nextPowOf2(relativeError)).build())
}

- private[sql] def isCompressed: Boolean = summaries.compressed
-
- /** Returns compressed object of [[QuantileSummaries]] */
- def quantileSummaries: QuantileSummaries = {
- if (!isCompressed) compress()
- summaries
- }
+ def sketchInfo: UpdateDoublesSketch = sketch

/** Insert an observation value into the PercentileDigest data structure. */
def add(value: Double): Unit = {
- summaries = summaries.insert(value)
+ sketch.update(value)
}

/** In-place merges in another PercentileDigest. */
def merge(other: PercentileDigest): Unit = {
- if (!isCompressed) compress()
- summaries = summaries.merge(other.quantileSummaries)
+ val doublesUnion = DoublesUnion.builder().setMaxK(sketch.getK).build()
+ doublesUnion.union(sketch)
+ doublesUnion.union(other.sketch)
+ sketch = doublesUnion.getResult
}
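
Background note (not part of the diff): merging goes through a DoublesUnion rather than mutating either sketch directly. A hedged sketch of the same pattern in isolation, using only DataSketches calls that appear in this diff (K = 128 is an arbitrary example value):

    import org.apache.datasketches.quantiles.{DoublesSketch, DoublesUnion}

    val a = DoublesSketch.builder().setK(128).build() // will hold 1..100
    val b = DoublesSketch.builder().setK(128).build() // will hold 101..200
    (1 to 100).foreach(i => a.update(i.toDouble))
    (101 to 200).foreach(i => b.update(i.toDouble))

    val union = DoublesUnion.builder().setMaxK(128).build()
    union.union(a) // feed both sketches into the union
    union.union(b)
    val merged = union.getResult // one sketch summarizing all 200 values
    assert(merged.getN == 200)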

/**
@@ -309,17 +310,12 @@
* }}}
*/
def getPercentiles(percentages: Array[Double]): Seq[Double] = {
- if (!isCompressed) compress()
- if (summaries.count == 0 || percentages.length == 0) {
- Array.emptyDoubleArray.toImmutableArraySeq
+ if (!sketch.isEmpty) {
+ sketch.getQuantiles(percentages).toSeq
} else {
- summaries.query(percentages.toImmutableArraySeq).get
+ Seq.empty[Double]
}
}

- private final def compress(): Unit = {
- summaries = summaries.compress()
- }
}
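
Background note (not part of the diff): a hypothetical end-to-end use of the reworked digest, using the default relative error of 1.0 / DEFAULT_PERCENTILE_ACCURACY mentioned above:

    val digest = new PercentileDigest(1.0 / ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)
    (1 to 1000).foreach(i => digest.add(i.toDouble))
    // Roughly Seq(500.0, 900.0); results are approximate within the relative error.
    val percentiles = digest.getPercentiles(Array(0.5, 0.9))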

/**
@@ -329,52 +325,14 @@
*/
class PercentileDigestSerializer {

- private final def length(summaries: QuantileSummaries): Int = {
- // summaries.compressThreshold, summary.relativeError, summary.count
- Ints.BYTES + Doubles.BYTES + Longs.BYTES +
- // length of summary.sampled
- Ints.BYTES +
- // summary.sampled, Array[Stat(value: Double, g: Long, delta: Long)]
- summaries.sampled.length * (Doubles.BYTES + Longs.BYTES + Longs.BYTES)
- }

final def serialize(obj: PercentileDigest): Array[Byte] = {
- val summary = obj.quantileSummaries
- val buffer = ByteBuffer.wrap(new Array(length(summary)))
- buffer.putInt(summary.compressThreshold)
- buffer.putDouble(summary.relativeError)
- buffer.putLong(summary.count)
- buffer.putInt(summary.sampled.length)
-
- var i = 0
- while (i < summary.sampled.length) {
- val stat = summary.sampled(i)
- buffer.putDouble(stat.value)
- buffer.putLong(stat.g)
- buffer.putLong(stat.delta)
- i += 1
- }
- buffer.array()
+ val sketch = obj.sketchInfo
+ sketch.toByteArray(false)
}

final def deserialize(bytes: Array[Byte]): PercentileDigest = {
- val buffer = ByteBuffer.wrap(bytes)
- val compressThreshold = buffer.getInt()
- val relativeError = buffer.getDouble()
- val count = buffer.getLong()
- val sampledLength = buffer.getInt()
- val sampled = new Array[Stats](sampledLength)
-
- var i = 0
- while (i < sampledLength) {
- val value = buffer.getDouble()
- val g = buffer.getLong()
- val delta = buffer.getLong()
- sampled(i) = Stats(value, g, delta)
- i += 1
- }
- val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count, true)
- new PercentileDigest(summary)
+ val sketch = DoublesSketch.heapify(Memory.wrap(bytes))
+ new PercentileDigest(sketch.asInstanceOf[UpdateDoublesSketch])
}
}
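
Background note (not part of the diff): serialize writes the sketch with toByteArray(false), i.e. in updatable rather than compact form, and deserialize depends on that through the asInstanceOf[UpdateDoublesSketch] cast; a compact image would presumably heapify to a read-only sketch and break the cast. A hedged round-trip sketch using the same calls:

    import org.apache.datasketches.memory.Memory
    import org.apache.datasketches.quantiles.{DoublesSketch, UpdateDoublesSketch}

    val sketch = DoublesSketch.builder().setK(128).build()
    (1 to 10).foreach(i => sketch.update(i.toDouble))

    val bytes = sketch.toByteArray(false) // same flag as serialize() above
    val restored = DoublesSketch.heapify(Memory.wrap(bytes)).asInstanceOf[UpdateDoublesSketch]
    assert(restored.getN == sketch.getN) // same item count after the round trip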

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
@@ -426,12 +426,18 @@ class ApproximatePercentileSuite extends SparkFunSuite {
}

private def compareEquals(left: PercentileDigest, right: PercentileDigest): Boolean = {
- val leftSummary = left.quantileSummaries
- val rightSummary = right.quantileSummaries
- leftSummary.compressThreshold == rightSummary.compressThreshold &&
- leftSummary.relativeError == rightSummary.relativeError &&
- leftSummary.count == rightSummary.count &&
- leftSummary.sampled.sameElements(rightSummary.sampled)
+ val leftSketch = left.sketchInfo
+ val rightSketch = right.sketchInfo
+ if (leftSketch.isEmpty && rightSketch.isEmpty) {
+ true
+ } else if (leftSketch.isEmpty || rightSketch.isEmpty) {
+ false
+ } else {
+ leftSketch.getK == rightSketch.getK &&
+ leftSketch.getMaxItem == rightSketch.getMaxItem &&
+ leftSketch.getMinItem == rightSketch.getMinItem &&
+ leftSketch.getN == rightSketch.getN
+ }
}

private def assertEqual[T](left: T, right: T): Unit = {
sql/core/src/test/scala/org/apache/spark/sql/ApproximatePercentileQuerySuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.sql
import java.sql.{Date, Timestamp}
import java.time.{Duration, LocalDateTime, Period}

- import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
import org.apache.spark.sql.catalyst.util.DateTimeUtils
@@ -291,18 +290,6 @@ class ApproximatePercentileQuerySuite extends QueryTest with SharedSparkSession
}
}

test("SPARK-24013: unneeded compress can cause performance issues with sorted input") {
val buffer = new PercentileDigest(1.0D / ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)
var compressCounts = 0
(1 to 10000000).foreach { i =>
buffer.add(i)
if (buffer.isCompressed) compressCounts += 1
}
assert(compressCounts > 0)
buffer.quantileSummaries
assert(buffer.isCompressed)
}

test("SPARK-32908: maximum target error in percentile_approx") {
withTempView(table) {
spark.read
sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -1182,8 +1182,8 @@ class DataFrameAggregateSuite extends QueryTest
approx_percentile(col("earnings"), lit(0.3), lit(1)),
approx_percentile(col("earnings"), array(lit(0.3), lit(0.6)), lit(1))
),
Row("Java", 20000.0, Seq(20000.0, 30000.0), 20000.0, Seq(20000.0, 20000.0)) ::
Row("dotNET", 5000.0, Seq(5000.0, 10000.0), 5000.0, Seq(5000.0, 5000.0)) :: Nil
Row("Java", 20000.0, Seq(20000.0, 30000.0), 20000.0, Seq(20000.0, 30000.0)) ::
Row("dotNET", 5000.0, Seq(5000.0, 10000.0), 5000.0, Seq(5000.0, 10000.0)) :: Nil
)
}
