Merge pull request #42 from target/add-stats
Column Stats
phpisciuneri authored Jun 10, 2020
2 parents 9805204 + b05b915 commit 9a8aa9c
Showing 18 changed files with 767 additions and 8 deletions.
19 changes: 19 additions & 0 deletions README.md
@@ -326,6 +326,25 @@ This check sums a column in all rows. If the sum applied to the `column` doesn't

**Note:** If bounds are non-inclusive, and the actual sum is equal to one of the bounds, the relative error percentage will be undefined.

#### `colstats`

This check computes summary statistics for the specified column.

| Arg | Type | Description |
|-------------|-------------|--------------------------------------------|
| `column` | String | The column on which to collect statistics. |

These values will appear in the check's JSON summary when using the JSON report output mode:

| Field       | Type        | Description                                                                                                               |
|-------------|-------------|-------------------------------------------------------------------------------------------------------------------------|
| `count` | Integer | Count of non-null entries in the `column`. |
| `mean` | Double | Mean/Average of the values in the `column`. |
| `min` | Double | Smallest value in the `column`. |
| `max` | Double | Largest value in the `column`. |
| `stdDev` | Double | Standard deviation of the values in the `column`. |
| `histogram` | Complex     | Equi-width histogram summary: counts of values falling into 10 equally sized buckets spanning the range `[min, max]` (see the sketch below). |
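
The bucket boundaries follow directly from `min` and `max`: each bucket spans one tenth of `max - min`. Below is a minimal sketch of that bucketing rule (illustrative only; `bucketIndex` and the example values are not part of the validator's API):

```scala
// Illustrative sketch: equi-width bucketing over [min, max] with 10 buckets.
object EquiWidthBucketingSketch {
  def bucketIndex(value: Double, min: Double, max: Double, buckets: Int = 10): Int = {
    val width = (max - min) / buckets
    // Clamp so that value == max falls into the last bucket instead of an 11th one.
    math.min(((value - min) / width).toInt, buckets - 1)
  }

  def main(args: Array[String]): Unit = {
    // With min = 0.0 and max = 100.0 each bucket is 10 wide, so 37.5 lands in bucket 3, i.e. [30, 40).
    println(bucketIndex(37.5, 0.0, 100.0)) // 3
  }
}
```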

## Example Config

```yaml
10 changes: 10 additions & 0 deletions src/main/scala/com/target/data_validator/JsonEncoders.scala
@@ -19,6 +19,7 @@ object JsonEncoders extends LazyLogging {
Json.fromString(a.toString)
}

// scalastyle:off cyclomatic.complexity
implicit val eventEncoder: Encoder[ValidatorEvent] = new Encoder[ValidatorEvent] {
override def apply(a: ValidatorEvent): Json = a match {
case vc: ValidatorCounter => Json.obj(
@@ -69,8 +70,10 @@
("src", Json.fromString(vs.src)),
("dest", vs.dest)
)
case vj: JsonEvent => vj.json
}
}
// scalastyle:on cyclomatic.complexity

implicit val baseEncoder: Encoder[ValidatorBase] = new Encoder[ValidatorBase] {
final def apply(a: ValidatorBase): Json = a.toJson
@@ -98,6 +101,13 @@
("keyColumns", vp.keyColumns.asJson),
("checks", vp.checks.asJson),
("events", vp.getEvents.asJson))
case vdf: ValidatorDataFrame => Json.obj(
("dfLabel", vdf.label.asJson),
("failed", vdf.failed.asJson),
("keyColumns", vdf.keyColumns.asJson),
("checks", vdf.checks.asJson),
("events", vdf.getEvents.asJson)
)
}
}

6 changes: 6 additions & 0 deletions src/main/scala/com/target/data_validator/ValidatorEvent.scala
@@ -99,3 +99,9 @@ case class VarSubJsonEvent(src: String, dest: Json) extends ValidatorEvent {
override def toString: String = s"VarSub src: $src dest: ${dest.noSpaces}"
override def toHTML: Text.all.Tag = div(cls:="subEvent")(toString)
}

case class JsonEvent(json: Json) extends ValidatorEvent {
override def failed: Boolean = false
override def toString: String = s"JsonEvent: json:${json.noSpaces}"
override def toHTML: Text.all.Tag = div(cls := "jsonEvent")(toString)
}
38 changes: 31 additions & 7 deletions src/main/scala/com/target/data_validator/ValidatorTable.scala
@@ -1,12 +1,12 @@
package com.target.data_validator

import com.target.data_validator.validator.{CheapCheck, ColumnBased, CostlyCheck, RowBased, ValidatorBase}
import org.apache.spark.sql.{Column, DataFrame, Row, SparkSession}
import com.target.data_validator.validator._
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression}
import org.apache.spark.sql.catalyst.expressions.aggregate.{Count, Sum}

import scala.collection.mutable.ListBuffer
import scala.util.{Failure, Success, Try}
import scala.util._
import scalatags.Text.all._

abstract class ValidatorTable(
@@ -73,16 +73,37 @@
ret
}

private def performFirstPass(df: DataFrame, checks: List[TwoPassCheapCheck]): Unit = {
if (checks.nonEmpty) {
val firstPassTimer = new ValidatorTimer(s"$label: pre-processing stage")

addEvent(firstPassTimer)

firstPassTimer.time {
val cols = checks.map { _.firstPassSelect() }
val row = df.select(cols: _*).head

checks foreach { _ sinkFirstPassRow row }
}
}
}

private def cheapExpression(dataFrame: DataFrame, dict: VarSubstitution): PartialFunction[CheapCheck, Expression] = {
case tp: TwoPassCheapCheck => tp.select(dataFrame.schema, dict)
case colChk: ColumnBased => colChk.select(dataFrame.schema, dict)
case chk: RowBased => Sum(chk.select(dataFrame.schema, dict)).toAggregateExpression()
}

def quickChecks(session: SparkSession, dict: VarSubstitution)(implicit vc: ValidatorConfig): Boolean = {
val dataFrame = open(session).get

performFirstPass(dataFrame, checks.collect { case tp: TwoPassCheapCheck => tp })

val qc: List[CheapCheck] = checks.flatMap {
case cc: CheapCheck => Some(cc)
case _ => None
}
val checkSelects: Seq[Expression] = qc.map {
case colChk: ColumnBased => colChk.select(dataFrame.schema, dict)
case chk: RowBased => Sum(chk.select(dataFrame.schema, dict)).toAggregateExpression()
}
val checkSelects = qc.map(cheapExpression(dataFrame, dict))

val cols: Seq[Column] = createCountSelect() ++ checkSelects.zipWithIndex.map {
case (chkSelect: Expression, idx: Int) => new Column(Alias(chkSelect, s"qc$idx")())
@@ -278,6 +299,9 @@ case class ValidatorDataFrame(
checks,
"DataFrame" + condition.map(x => s" with condition($x)").getOrElse("")
) {

final def label: String = "DataFrame" + condition.map(x => s" with condition($x)").getOrElse("")

override def getDF(session: SparkSession): Try[DataFrame] = Success(df)

override def substituteVariables(dict: VarSubstitution): ValidatorTable = {
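
The pre-processing pass above exists because a check like `colstats` needs one aggregation result before it can build its second select: the histogram buckets are sized from the first pass's `min` and `max`. Below is a hedged sketch of the shape that `performFirstPass` and `cheapExpression` appear to rely on (inferred from this diff; the actual `TwoPassCheapCheck` trait lives in `com.target.data_validator.validator` and may differ):

```scala
import org.apache.spark.sql.{Column, Row}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.types.StructType

import com.target.data_validator.VarSubstitution

// Inferred shape only, based on how ValidatorTable calls these methods.
trait TwoPassCheapCheckSketch {
  // Aggregate column evaluated once during the pre-processing ("first") pass.
  def firstPassSelect(): Column

  // Receives the single Row produced by the first pass,
  // e.g. the min/max needed to size histogram buckets.
  def sinkFirstPassRow(row: Row): Unit

  // Second-pass expression, consumed by cheapExpression alongside the other CheapChecks.
  def select(schema: StructType, dict: VarSubstitution): Expression
}
```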
3 changes: 3 additions & 0 deletions src/main/scala/com/target/data_validator/stats/Bin.scala
@@ -0,0 +1,3 @@
package com.target.data_validator.stats

case class Bin(lowerBound: Double, upperBound: Double, count: Long)
41 changes: 41 additions & 0 deletions src/main/scala/com/target/data_validator/stats/CompleteStats.scala
@@ -0,0 +1,41 @@
package com.target.data_validator.stats

import io.circe._
import io.circe.generic.semiauto._

case class CompleteStats(
name: String,
column: String,
count: Long,
mean: Double,
min: Double,
max: Double,
stdDev: Double,
histogram: Histogram
)

object CompleteStats {
implicit val binEncoder: Encoder[Bin] = deriveEncoder
implicit val histogramEncoder: Encoder[Histogram] = deriveEncoder
implicit val encoder: Encoder[CompleteStats] = deriveEncoder

implicit val binDecoder: Decoder[Bin] = deriveDecoder
implicit val histogramDecoder: Decoder[Histogram] = deriveDecoder
implicit val decoder: Decoder[CompleteStats] = deriveDecoder

def apply(
name: String,
column: String,
firstPassStats: FirstPassStats,
secondPassStats: SecondPassStats
): CompleteStats = CompleteStats(
name = name,
column = column,
count = firstPassStats.count,
mean = firstPassStats.mean,
min = firstPassStats.min,
max = firstPassStats.max,
stdDev = secondPassStats.stdDev,
histogram = secondPassStats.histogram
)
}
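
Presumably these derived encoders are what render the `colstats` summary in the JSON report. A small hedged sketch of serializing a `CompleteStats` built from the two passes (the check name and values are made up for illustration):

```scala
import io.circe.syntax._

import com.target.data_validator.stats._

object CompleteStatsJsonSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical values; in the validator they come from the two aggregation passes.
    val stats = CompleteStats(
      name = "colstats",
      column = "price",
      firstPassStats = FirstPassStats(count = 4L, mean = 2.5, min = 1.0, max = 4.0),
      secondPassStats = SecondPassStats(
        stdDev = 1.29,
        histogram = Histogram(Seq(Bin(1.0, 2.5, 2L), Bin(2.5, 4.0, 2L)))
      )
    )
    // Encoder[CompleteStats] is picked up from the companion object's implicit scope.
    println(stats.asJson.noSpaces)
  }
}
```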
29 changes: 29 additions & 0 deletions src/main/scala/com/target/data_validator/stats/FirstPassStats.scala
@@ -0,0 +1,29 @@
package com.target.data_validator.stats

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.DataType

case class FirstPassStats(count: Long, mean: Double, min: Double, max: Double)

object FirstPassStats {
def dataType: DataType = ScalaReflection
.schemaFor[FirstPassStats]
.dataType

/**
* Convert from Spark SQL row format to case class [[FirstPassStats]] format.
*
* @param row a complex column of [[org.apache.spark.sql.types.StructType]] output of [[FirstPassStatsAggregator]]
* @return struct format converted to [[FirstPassStats]]
*/
def fromRowRepr(row: Row): FirstPassStats = {
FirstPassStats(
count = row.getLong(0),
mean = row.getDouble(1),
min = row.getDouble(2),
max = row.getDouble(3)
)
}

}
83 changes: 83 additions & 0 deletions src/main/scala/com/target/data_validator/stats/FirstPassStatsAggregator.scala
@@ -0,0 +1,83 @@
package com.target.data_validator.stats

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

/**
* Calculate the count, mean, min and maximum values of a numeric column.
*/
class FirstPassStatsAggregator extends UserDefinedAggregateFunction {

/**
* input is a single column of `DoubleType`
*/
override def inputSchema: StructType = new StructType().add("value", DoubleType)

/**
* buffer keeps state for the count, sum, min and max
*/
override def bufferSchema: StructType = new StructType()
.add(StructField("count", LongType))
.add(StructField("sum", DoubleType))
.add(StructField("min", DoubleType))
.add(StructField("max", DoubleType))

private val count = bufferSchema.fieldIndex("count")
private val sum = bufferSchema.fieldIndex("sum")
private val min = bufferSchema.fieldIndex("min")
private val max = bufferSchema.fieldIndex("max")

/**
* specifies the return type when using the UDAF
*/
override def dataType: DataType = FirstPassStats.dataType

/**
* These calculations are deterministic
*/
override def deterministic: Boolean = true

/**
* set the initial values for count, sum, min and max
*/
override def initialize(buffer: MutableAggregationBuffer): Unit = {
buffer(count) = 0L
buffer(sum) = 0.0
buffer(min) = Double.MaxValue
buffer(max) = Double.MinValue
}

/**
* update the count, sum, min and max buffer values
*/
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
buffer(count) = buffer.getLong(count) + 1
buffer(sum) = buffer.getDouble(sum) + input.getDouble(0)
buffer(min) = math.min(input.getDouble(0), buffer.getDouble(min))
buffer(max) = math.max(input.getDouble(0), buffer.getDouble(max))
}

/**
* reduce the count, sum, min and max values of two buffers
*/
override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
buffer1(count) = buffer1.getLong(count) + buffer2.getLong(count)
buffer1(sum) = buffer1.getDouble(sum) + buffer2.getDouble(sum)
buffer1(min) = math.min(buffer1.getDouble(min), buffer2.getDouble(min))
buffer1(max) = math.max(buffer1.getDouble(max), buffer2.getDouble(max))
}

/**
* evaluate the count, mean, min and max values of a column
*/
override def evaluate(buffer: Row): Any = {
FirstPassStats(
buffer.getLong(count),
buffer.getDouble(sum) / buffer.getLong(count),
buffer.getDouble(min),
buffer.getDouble(max)
)
}

}
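
A minimal hedged usage sketch of the aggregator (the local SparkSession and toy column are illustrative, not taken from this PR), converting the struct result back with `FirstPassStats.fromRowRepr`:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

import com.target.data_validator.stats.{FirstPassStats, FirstPassStatsAggregator}

object FirstPassStatsUsageSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.master("local[*]").appName("first-pass-stats").getOrCreate()
    import spark.implicits._

    val df = Seq(1.0, 2.0, 3.0, 4.0).toDF("value")
    val agg = new FirstPassStatsAggregator

    // The UDAF yields a single struct (count, sum/count, min, max) per evaluate() above.
    val row = df.select(agg(col("value"))).head
    val stats = FirstPassStats.fromRowRepr(row.getStruct(0))

    println(stats) // FirstPassStats(4, 2.5, 1.0, 4.0)
    spark.stop()
  }
}
```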
3 changes: 3 additions & 0 deletions src/main/scala/com/target/data_validator/stats/Histogram.scala
@@ -0,0 +1,3 @@
package com.target.data_validator.stats

case class Histogram(bins: Seq[Bin])
35 changes: 35 additions & 0 deletions src/main/scala/com/target/data_validator/stats/SecondPassStats.scala
@@ -0,0 +1,35 @@
package com.target.data_validator.stats

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.types.DataType

case class SecondPassStats(stdDev: Double, histogram: Histogram)

object SecondPassStats {
def dataType: DataType = ScalaReflection
.schemaFor[SecondPassStats]
.dataType

/**
* Convert from Spark SQL row format to case class [[SecondPassStats]] format.
*
* @param row a complex column of [[org.apache.spark.sql.types.StructType]] output of [[SecondPassStatsAggregator]]
* @return struct format converted to [[SecondPassStats]]
*/
def fromRowRepr(row: Row): SecondPassStats = {
SecondPassStats(
stdDev = row.getDouble(0),
histogram = Histogram(
row.getStruct(1).getSeq[Row](0) map {
bin => Bin(
lowerBound = bin.getDouble(0),
upperBound = bin.getDouble(1),
count = bin.getLong(2)
)
}
)
)
}

}