Skip to content

Commit 8f4a8a5

Browse files
authored
fix: stddev_pop should not directly return 0.0 when count is 1.0 (#1184)
* add test * fix * fix * fix
1 parent e297d23 commit 8f4a8a5

File tree

2 files changed

+19
-4
lines changed

2 files changed

+19
-4
lines changed

native/spark-expr/src/variance.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -245,10 +245,8 @@ impl Accumulator for VarianceAccumulator {
245245

246246
Ok(ScalarValue::Float64(match self.count {
247247
count if count == 0.0 => None,
248-
count if count == 1.0 => {
249-
if let StatsType::Population = self.stats_type {
250-
Some(0.0)
251-
} else if self.null_on_divide_by_zero {
248+
count if count == 1.0 && StatsType::Sample == self.stats_type => {
249+
if self.null_on_divide_by_zero {
252250
None
253251
} else {
254252
Some(f64::NAN)

spark/src/test/scala/org/apache/comet/exec/CometAggregateSuite.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,23 @@ import org.apache.comet.CometSparkSessionExtensions.isSpark34Plus
3838
class CometAggregateSuite extends CometTestBase with AdaptiveSparkPlanHelper {
3939
import testImplicits._
4040

41+
test("stddev_pop should return NaN for some cases") {
42+
withSQLConf(
43+
CometConf.COMET_EXEC_SHUFFLE_ENABLED.key -> "true",
44+
CometConf.COMET_EXPR_STDDEV_ENABLED.key -> "true") {
45+
Seq(true, false).foreach { nullOnDivideByZero =>
46+
withSQLConf("spark.sql.legacy.statisticalAggregate" -> nullOnDivideByZero.toString) {
47+
48+
val data: Seq[(Float, Int)] = Seq((Float.PositiveInfinity, 1))
49+
withParquetTable(data, "tbl", false) {
50+
val df = sql("SELECT stddev_pop(_1), stddev_pop(_2) FROM tbl")
51+
checkSparkAnswer(df)
52+
}
53+
}
54+
}
55+
}
56+
}
57+
4158
test("count with aggregation filter") {
4259
withSQLConf(
4360
CometConf.COMET_ENABLED.key -> "true",

0 commit comments

Comments
 (0)