Merge pull request #42 from target/add-stats
Column Stats
Showing 18 changed files with 767 additions and 8 deletions.
@@ -0,0 +1,3 @@
package com.target.data_validator.stats

case class Bin(lowerBound: Double, upperBound: Double, count: Long)
src/main/scala/com/target/data_validator/stats/CompleteStats.scala (41 additions, 0 deletions)
@@ -0,0 +1,41 @@
package com.target.data_validator.stats

import io.circe._
import io.circe.generic.semiauto._

case class CompleteStats(
  name: String,
  column: String,
  count: Long,
  mean: Double,
  min: Double,
  max: Double,
  stdDev: Double,
  histogram: Histogram
)

object CompleteStats {
  implicit val binEncoder: Encoder[Bin] = deriveEncoder
  implicit val histogramEncoder: Encoder[Histogram] = deriveEncoder
  implicit val encoder: Encoder[CompleteStats] = deriveEncoder

  implicit val binDecoder: Decoder[Bin] = deriveDecoder
  implicit val histogramDecoder: Decoder[Histogram] = deriveDecoder
  implicit val decoder: Decoder[CompleteStats] = deriveDecoder

  def apply(
    name: String,
    column: String,
    firstPassStats: FirstPassStats,
    secondPassStats: SecondPassStats
  ): CompleteStats = CompleteStats(
    name = name,
    column = column,
    count = firstPassStats.count,
    mean = firstPassStats.mean,
    min = firstPassStats.min,
    max = firstPassStats.max,
    stdDev = secondPassStats.stdDev,
    histogram = secondPassStats.histogram
  )
}
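`CompleteStats.apply` stitches the two aggregation passes into one record, and the semiauto-derived circe codecs let that record be serialized to JSON. The snippet below is a minimal sketch of that round trip; the check name, column name, and numeric values are invented for illustration and are not part of this change.

```scala
import io.circe.syntax._
import com.target.data_validator.stats._

// Placeholder results standing in for the two aggregation passes.
val firstPass = FirstPassStats(count = 100L, mean = 5.0, min = 0.0, max = 10.0)
val secondPass = SecondPassStats(
  stdDev = 2.5,
  histogram = Histogram(Seq(Bin(0.0, 5.0, 60L), Bin(5.0, 10.0, 40L)))
)

// Combine both passes and render JSON via the derived Encoder in the companion object.
val stats = CompleteStats("priceStats", "price", firstPass, secondPass)
println(stats.asJson.noSpaces)
```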
src/main/scala/com/target/data_validator/stats/FirstPassStats.scala (29 additions, 0 deletions)
@@ -0,0 +1,29 @@
package com.target.data_validator.stats

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.DataType

case class FirstPassStats(count: Long, mean: Double, min: Double, max: Double)

object FirstPassStats {
  def dataType: DataType = ScalaReflection
    .schemaFor[FirstPassStats]
    .dataType

  /**
   * Convert from Spark SQL row format to case class [[FirstPassStats]] format.
   *
   * @param row a complex column of [[org.apache.spark.sql.types.StructType]] output of [[FirstPassStatsAggregator]]
   * @return struct format converted to [[FirstPassStats]]
   */
  def fromRowRepr(row: Row): FirstPassStats = {
    FirstPassStats(
      count = row.getLong(0),
      mean = row.getDouble(1),
      min = row.getDouble(2),
      max = row.getDouble(3)
    )
  }

}
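`FirstPassStats.dataType` uses `ScalaReflection` to derive the Spark `StructType` matching the case class, and `fromRowRepr` reads that same struct back by position. A minimal sketch of the mapping, using a hand-built `Row` purely for illustration:

```scala
import org.apache.spark.sql.Row
import com.target.data_validator.stats.FirstPassStats

// A Row laid out the way the aggregator's output struct is: count, mean, min, max.
val row = Row(100L, 5.0, 0.0, 10.0)
val stats = FirstPassStats.fromRowRepr(row)
// stats == FirstPassStats(100, 5.0, 0.0, 10.0)

// The reflected schema the UDAF advertises as its return type.
println(FirstPassStats.dataType)
```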
src/main/scala/com/target/data_validator/stats/FirstPassStatsAggregator.scala (83 additions, 0 deletions)
@@ -0,0 +1,83 @@
package com.target.data_validator.stats

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
 * Calculate the count, mean, min and maximum values of a numeric column.
 */
class FirstPassStatsAggregator extends UserDefinedAggregateFunction {

  /**
   * input is a single column of `DoubleType`
   */
  override def inputSchema: StructType = new StructType().add("value", DoubleType)

  /**
   * buffer keeps state for the count, sum, min and max
   */
  override def bufferSchema: StructType = new StructType()
    .add(StructField("count", LongType))
    .add(StructField("sum", DoubleType))
    .add(StructField("min", DoubleType))
    .add(StructField("max", DoubleType))

  private val count = bufferSchema.fieldIndex("count")
  private val sum = bufferSchema.fieldIndex("sum")
  private val min = bufferSchema.fieldIndex("min")
  private val max = bufferSchema.fieldIndex("max")

  /**
   * specifies the return type when using the UDAF
   */
  override def dataType: DataType = FirstPassStats.dataType

  /**
   * These calculations are deterministic
   */
  override def deterministic: Boolean = true

  /**
   * set the initial values for count, sum, min and max
   */
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(count) = 0L
    buffer(sum) = 0.0
    buffer(min) = Double.MaxValue
    buffer(max) = Double.MinValue
  }

  /**
   * update the count, sum, min and max buffer values
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(count) = buffer.getLong(count) + 1
    buffer(sum) = buffer.getDouble(sum) + input.getDouble(0)
    buffer(min) = math.min(input.getDouble(0), buffer.getDouble(min))
    buffer(max) = math.max(input.getDouble(0), buffer.getDouble(max))
  }

  /**
   * reduce the count, sum, min and max values of two buffers
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(count) = buffer1.getLong(count) + buffer2.getLong(count)
    buffer1(sum) = buffer1.getDouble(sum) + buffer2.getDouble(sum)
    buffer1(min) = math.min(buffer1.getDouble(min), buffer2.getDouble(min))
    buffer1(max) = math.max(buffer1.getDouble(max), buffer2.getDouble(max))
  }

  /**
   * evaluate the count, mean, min and max values of a column
   */
  override def evaluate(buffer: Row): Any = {
    FirstPassStats(
      buffer.getLong(count),
      buffer.getDouble(sum) / buffer.getLong(count),
      buffer.getDouble(min),
      buffer.getDouble(max)
    )
  }

}
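In use, the aggregator is instantiated and applied like any other aggregate expression, and the struct it returns converts back with `FirstPassStats.fromRowRepr`. The sketch below assumes a local `SparkSession` and a numeric column named `value`, neither of which comes from this diff.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col
import com.target.data_validator.stats.{FirstPassStats, FirstPassStatsAggregator}

val spark = SparkSession.builder().master("local[*]").appName("first-pass-stats").getOrCreate()
import spark.implicits._

val df = Seq(1.0, 2.0, 3.0, 4.0).toDF("value")

// Apply the UDAF to the column and pull the resulting struct out of the single result row.
val firstPassAgg = new FirstPassStatsAggregator
val resultRow = df.agg(firstPassAgg(col("value")).as("stats")).head()
val stats = FirstPassStats.fromRowRepr(resultRow.getStruct(0))
// stats == FirstPassStats(4, 2.5, 1.0, 4.0)
```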
src/main/scala/com/target/data_validator/stats/Histogram.scala (3 additions, 0 deletions)
@@ -0,0 +1,3 @@
package com.target.data_validator.stats

case class Histogram(bins: Seq[Bin])
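`Bin` and `Histogram` are plain data holders: each bin presumably covers the span between its `lowerBound` and `upperBound` and records how many values fell there, and a `Histogram` is just the ordered sequence of bins. A small sketch of building an equal-width histogram skeleton; the range and bin count are arbitrary choices for illustration, not something this change computes:

```scala
import com.target.data_validator.stats.{Bin, Histogram}

// Ten equal-width bins over an assumed range of [0.0, 100.0], with placeholder counts.
val (lo, hi, numBins) = (0.0, 100.0, 10)
val width = (hi - lo) / numBins
val bins = (0 until numBins).map { i =>
  Bin(lowerBound = lo + i * width, upperBound = lo + (i + 1) * width, count = 0L)
}
val histogram = Histogram(bins)
```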
src/main/scala/com/target/data_validator/stats/SecondPassStats.scala (35 additions, 0 deletions)
@@ -0,0 +1,35 @@
package com.target.data_validator.stats

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.DataType

case class SecondPassStats(stdDev: Double, histogram: Histogram)

object SecondPassStats {
  def dataType: DataType = ScalaReflection
    .schemaFor[SecondPassStats]
    .dataType

  /**
   * Convert from Spark SQL row format to case class [[SecondPassStats]] format.
   *
   * @param row a complex column of [[org.apache.spark.sql.types.StructType]] output of [[SecondPassStatsAggregator]]
   * @return struct format converted to [[SecondPassStats]]
   */
  def fromRowRepr(row: Row): SecondPassStats = {
    SecondPassStats(
      stdDev = row.getDouble(0),
      histogram = Histogram(
        row.getStruct(1).getSeq[Row](0) map {
          bin => Bin(
            lowerBound = bin.getDouble(0),
            upperBound = bin.getDouble(1),
            count = bin.getLong(2)
          )
        }
      )
    )
  }

}
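`SecondPassStats.fromRowRepr` expects the nested layout described by `SecondPassStats.dataType`: the standard deviation first, then a histogram struct whose first field is the sequence of bin structs. The sketch below hand-builds a `Row` that mimics that shape purely to illustrate the conversion; the values are made up.

```scala
import org.apache.spark.sql.Row
import com.target.data_validator.stats.SecondPassStats

// stdDev, then a nested histogram struct wrapping the (lowerBound, upperBound, count) bins.
val raw = Row(
  2.5,
  Row(Seq(
    Row(0.0, 5.0, 60L),
    Row(5.0, 10.0, 40L)
  ))
)

val stats = SecondPassStats.fromRowRepr(raw)
// stats.stdDev == 2.5
// stats.histogram.bins == Seq(Bin(0.0, 5.0, 60), Bin(5.0, 10.0, 40))
```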