[SPARK-50977][CORE] Enhance availability of logic performing aggregation of accumulator results #49618

Closed
106 changes: 106 additions & 0 deletions core/src/main/scala/org/apache/spark/util/MetricUtils.scala
@@ -0,0 +1,106 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.util

import java.text.NumberFormat
import java.util.{Arrays, Locale}

import scala.concurrent.duration._

import org.apache.spark.SparkException
import org.apache.spark.util.Utils

object MetricUtils {

val SUM_METRIC: String = "sum"
val SIZE_METRIC: String = "size"
val TIMING_METRIC: String = "timing"
val NS_TIMING_METRIC: String = "nsTiming"
val AVERAGE_METRIC: String = "average"
private val baseForAvgMetric: Int = 10
private val METRICS_NAME_SUFFIX = "(min, med, max (stageId: taskId))"

private def toNumberFormat(value: Long): String = {
val numberFormat = NumberFormat.getNumberInstance(Locale.US)
numberFormat.format(value.toDouble / baseForAvgMetric)
}

def metricNeedsMax(metricsType: String): Boolean = {
metricsType != SUM_METRIC
}

/**
* A function that defines how we aggregate the final accumulator results among all tasks,
* and represents them as a string for a SQL physical operator.
*/
def stringValue(metricsType: String, values: Array[Long], maxMetrics: Array[Long]): String = {
// taskInfo = "(driver)" OR (stage ${stageId}.${attemptId}: task $taskId)
val taskInfo = if (maxMetrics.isEmpty) {
"(driver)"
} else {
s"(stage ${maxMetrics(1)}.${maxMetrics(2)}: task ${maxMetrics(3)})"
}
if (metricsType == SUM_METRIC) {
val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
numberFormat.format(values.sum)
} else if (metricsType == AVERAGE_METRIC) {
val validValues = values.filter(_ > 0)
// When there is only 1 metric value (or none), there is no need to display max/min/median. This
// is common for driver-side SQL metrics.
if (validValues.length <= 1) {
toNumberFormat(validValues.headOption.getOrElse(0))
} else {
val Seq(min, med, max) = {
Arrays.sort(validValues)
Seq(
toNumberFormat(validValues(0)),
toNumberFormat(validValues(validValues.length / 2)),
toNumberFormat(validValues(validValues.length - 1)))
}
s"$METRICS_NAME_SUFFIX:\n($min, $med, $max $taskInfo)"
}
} else {
val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
Utils.bytesToString
} else if (metricsType == TIMING_METRIC) {
Utils.msDurationToString
} else if (metricsType == NS_TIMING_METRIC) {
duration => Utils.msDurationToString(duration.nanos.toMillis)
} else {
throw SparkException.internalError(s"unexpected metrics type: $metricsType")
}

val validValues = values.filter(_ >= 0)
// When there is only 1 metric value (or none), there is no need to display max/min/median. This
// is common for driver-side SQL metrics.
if (validValues.length <= 1) {
strFormat(validValues.headOption.getOrElse(0))
} else {
val Seq(sum, min, med, max) = {
Arrays.sort(validValues)
Seq(
strFormat(validValues.sum),
strFormat(validValues(0)),
strFormat(validValues(validValues.length / 2)),
strFormat(validValues(validValues.length - 1)))
}
s"total $METRICS_NAME_SUFFIX\n$sum ($min, $med, $max $taskInfo)"
}
}
}
}
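For reference, a minimal usage sketch of the new helper (not part of the diff; the object name and values are illustrative, and the layout of maxMetrics as (max value, stageId, stage attemptId, taskId) is an assumption — the code above only reads indices 1 to 3 for the task info):

import org.apache.spark.util.MetricUtils

object MetricUtilsExample {
  def main(args: Array[String]): Unit = {
    // Driver-side sum metric: a single value rendered as a plain integer.
    println(MetricUtils.stringValue(MetricUtils.SUM_METRIC, Array(42L), Array.empty[Long]))

    // Task-side size metric: rendered as "total (min, med, max (stageId: taskId))" followed by
    // the formatted sum, min, median, and max, with task info like "(stage 3.0: task 17)".
    val taskValues = Array(1024L, 2048L, 4096L)
    val maxMetrics = Array(4096L, 3L, 0L, 17L) // assumed (max value, stageId, attemptId, taskId)
    println(MetricUtils.stringValue(MetricUtils.SIZE_METRIC, taskValues, maxMetrics))
  }
}

An empty maxMetrics array marks a driver-only metric, which is why the single-value branches above skip the min/med/max breakdown.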
@@ -17,8 +17,9 @@
package org.apache.spark.sql.execution.datasources.v2.python

import org.apache.spark.sql.connector.metric.{CustomMetric, CustomTaskMetric}
import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
import org.apache.spark.sql.execution.metric.SQLMetric
import org.apache.spark.sql.execution.python.PythonSQLMetrics
import org.apache.spark.util.MetricUtils


class PythonCustomMetric(
@@ -28,7 +29,7 @@ class PythonCustomMetric(
def this() = this(null, null)

override def aggregateTaskMetrics(taskMetrics: Array[Long]): String = {
SQLMetrics.stringValue("size", taskMetrics, Array.empty[Long])
MetricUtils.stringValue("size", taskMetrics, Array.empty[Long])
}
}
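The same delegation is open to any connector-defined metric, which appears to be the point of moving the aggregation into core. A hypothetical metric (not part of this PR; the class name, metric name, and description are made up) might look like:

import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.util.MetricUtils

// Hypothetical connector metric: reports bytes read per task and reuses the core
// aggregation to produce the sum/min/med/max string shown in the SQL UI.
class BytesReadMetric extends CustomMetric {
  override def name(): String = "bytesRead"
  override def description(): String = "bytes read by the data source"
  override def aggregateTaskMetrics(taskMetrics: Array[Long]): String =
    MetricUtils.stringValue(MetricUtils.SIZE_METRIC, taskMetrics, Array.empty[Long])
}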

@@ -17,19 +17,14 @@

package org.apache.spark.sql.execution.metric

import java.text.NumberFormat
import java.util.{Arrays, Locale}

import scala.concurrent.duration._

import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}

import org.apache.spark.{SparkContext, SparkException}
import org.apache.spark.SparkContext
import org.apache.spark.scheduler.AccumulableInfo
import org.apache.spark.sql.connector.metric.CustomMetric
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2, Utils}
import org.apache.spark.util.{AccumulatorContext, AccumulatorV2}
import org.apache.spark.util.AccumulatorContext.internOption

/**
@@ -72,7 +67,7 @@ class SQLMetric(

// This is used to filter out metrics. Metrics with value equal to initValue should
// be filtered out, since they are either invalid or safe to filter without changing
// the aggregation defined in [[SQLMetrics.stringValue]].
// the aggregation defined in [[MetricUtils.stringValue]].
// Note that we don't use 0 here since we may want to collect 0 metrics for
// calculating min, max, etc. See SPARK-11013.
override def isZero: Boolean = _value == initValue
@@ -106,8 +101,8 @@ class SQLMetric(
SQLMetrics.cachedSQLAccumIdentifier)
}

// We should provide the raw value which can be -1, so that `SQLMetrics.stringValue` can correctly
// filter out the invalid -1 values.
// We should provide the raw value which can be -1, so that `MetricUtils.stringValue` can
// correctly filter out the invalid -1 values.
override def toInfoUpdate: AccumulableInfo = {
AccumulableInfo(id, name, internOption(Some(_value)), None, true, true,
SQLMetrics.cachedSQLAccumIdentifier)
@@ -203,77 +198,6 @@ object SQLMetrics {
acc
}

private def toNumberFormat(value: Long): String = {
val numberFormat = NumberFormat.getNumberInstance(Locale.US)
numberFormat.format(value.toDouble / baseForAvgMetric)
}

def metricNeedsMax(metricsType: String): Boolean = {
metricsType != SUM_METRIC
}

private val METRICS_NAME_SUFFIX = "(min, med, max (stageId: taskId))"

/**
* A function that defines how we aggregate the final accumulator results among all tasks,
* and represents them as a string for a SQL physical operator.
*/
def stringValue(metricsType: String, values: Array[Long], maxMetrics: Array[Long]): String = {
// taskInfo = "(driver)" OR (stage ${stageId}.${attemptId}: task $taskId)
val taskInfo = if (maxMetrics.isEmpty) {
"(driver)"
} else {
s"(stage ${maxMetrics(1)}.${maxMetrics(2)}: task ${maxMetrics(3)})"
}
if (metricsType == SUM_METRIC) {
val numberFormat = NumberFormat.getIntegerInstance(Locale.US)
numberFormat.format(values.sum)
} else if (metricsType == AVERAGE_METRIC) {
val validValues = values.filter(_ > 0)
// When there is only 1 metric value (or none), there is no need to display max/min/median. This
// is common for driver-side SQL metrics.
if (validValues.length <= 1) {
toNumberFormat(validValues.headOption.getOrElse(0))
} else {
val Seq(min, med, max) = {
Arrays.sort(validValues)
Seq(
toNumberFormat(validValues(0)),
toNumberFormat(validValues(validValues.length / 2)),
toNumberFormat(validValues(validValues.length - 1)))
}
s"$METRICS_NAME_SUFFIX:\n($min, $med, $max $taskInfo)"
}
} else {
val strFormat: Long => String = if (metricsType == SIZE_METRIC) {
Utils.bytesToString
} else if (metricsType == TIMING_METRIC) {
Utils.msDurationToString
} else if (metricsType == NS_TIMING_METRIC) {
duration => Utils.msDurationToString(duration.nanos.toMillis)
} else {
throw SparkException.internalError(s"unexpected metrics type: $metricsType")
}

val validValues = values.filter(_ >= 0)
// When there is only 1 metric value (or none), there is no need to display max/min/median. This
// is common for driver-side SQL metrics.
if (validValues.length <= 1) {
strFormat(validValues.headOption.getOrElse(0))
} else {
val Seq(sum, min, med, max) = {
Arrays.sort(validValues)
Seq(
strFormat(validValues.sum),
strFormat(validValues(0)),
strFormat(validValues(validValues.length / 2)),
strFormat(validValues(validValues.length - 1)))
}
s"total $METRICS_NAME_SUFFIX\n$sum ($min, $med, $max $taskInfo)"
}
}
}

def postDriverMetricsUpdatedByValue(
sc: SparkContext,
executionId: String,
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.SQLExecution
import org.apache.spark.sql.execution.metric._
import org.apache.spark.sql.internal.StaticSQLConf._
import org.apache.spark.status.{ElementTrackingStore, KVUtils, LiveEntity}
import org.apache.spark.util.Utils
import org.apache.spark.util.{MetricUtils, Utils}
import org.apache.spark.util.collection.OpenHashMap

class SQLAppStatusListener(
@@ -235,7 +235,7 @@
}
}.getOrElse(
// Built-in SQLMetric
SQLMetrics.stringValue(m.metricType, _, _)
MetricUtils.stringValue(m.metricType, _, _)
)
(m.accumulatorId, metricAggMethod)
}.toMap
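For a built-in SQLMetric the fallback above partially applies MetricUtils.stringValue, yielding one function of type (Array[Long], Array[Long]) => String per accumulator id. A rough illustration (values made up) of what such a closure receives:

// Sketch: the fallback aggregation for a built-in timing metric.
val aggregate: (Array[Long], Array[Long]) => String =
  MetricUtils.stringValue(MetricUtils.TIMING_METRIC, _, _)

// Per-task millisecond values plus the (max value, stageId, attemptId, taskId) of the slowest task.
aggregate(Array(120L, 340L, 90L), Array(340L, 2L, 0L, 7L))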
@@ -554,7 +554,7 @@ private class LiveStageMetrics(
/**
* Task metrics values for the stage. Maps the metric ID to the metric values for each
* index. For each metric ID, there will be the same number of values as the number
* of indices. This relies on `SQLMetrics.stringValue` treating 0 as a neutral value,
* of indices. This relies on `MetricUtils.stringValue` treating 0 as a neutral value,
* independent of the actual metric type.
*/
private val taskMetrics = new ConcurrentHashMap[Long, Array[Long]]()
@@ -601,7 +601,7 @@ private class LiveStageMetrics(
val metricValues = taskMetrics.computeIfAbsent(acc.id, _ => new Array(numTasks))
metricValues(taskIdx) = value

if (SQLMetrics.metricNeedsMax(accumIdsToMetricType(acc.id))) {
if (MetricUtils.metricNeedsMax(accumIdsToMetricType(acc.id))) {
val maxMetricsTaskId = metricsIdToMaxTaskValue.computeIfAbsent(acc.id, _ => Array(value,
taskId))

@@ -57,7 +57,7 @@ import org.apache.spark.sql.test.SharedSparkSession
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
import org.apache.spark.status.{AppStatusStore, ElementTrackingStore}
import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator, SerializableConfiguration, Utils}
import org.apache.spark.util.{AccumulatorMetadata, JsonProtocol, LongAccumulator, MetricUtils, SerializableConfiguration, Utils}
import org.apache.spark.util.kvstore.InMemoryStore


@@ -597,9 +597,9 @@ abstract class SQLAppStatusListenerSuite extends SharedSparkSession with JsonTes
val metrics = statusStore.executionMetrics(execId)
val driverMetric = physicalPlan.metrics("dummy")
val driverMetric2 = physicalPlan.metrics("dummy2")
val expectedValue = SQLMetrics.stringValue(driverMetric.metricType,
val expectedValue = MetricUtils.stringValue(driverMetric.metricType,
Array(expectedAccumValue), Array.empty[Long])
val expectedValue2 = SQLMetrics.stringValue(driverMetric2.metricType,
val expectedValue2 = MetricUtils.stringValue(driverMetric2.metricType,
Array(expectedAccumValue2), Array.empty[Long])

assert(metrics.contains(driverMetric.id))