Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.sql.SqlConstantValueAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.tools.RelBuilder;

import org.checkerframework.checker.nullness.qual.Nullable;
Expand All @@ -52,12 +54,19 @@
* arguments exist in the aggregate's group set or are deterministic
* expressions involving only group set columns and constants:
* <ul>
* <li>{@code MAX}</li>
* <li>{@code MIN}</li>
* <li>{@code AVG}</li>
* <li>{@code ANY_VALUE}</li>
* <li>{@code MAX} - the GROUP BY key value itself</li>
* <li>{@code MIN} - the GROUP BY key value itself</li>
* <li>{@code AVG} - the GROUP BY key value itself</li>
* <li>{@code ANY_VALUE} - the GROUP BY key value itself</li>
* <li>Functions implementing {@link SqlConstantValueAggFunction} such as
* {@code STDDEV_POP}, {@code STDDEV_SAMP}, {@code VAR_POP}, {@code VAR_SAMP}
* - return their constant result</li>
* </ul>
*
* <p>Aggregate functions that implement {@link SqlConstantValueAggFunction}
* declare what value to return when applied to constant (GROUP BY key) arguments.
* This enables the rule to optimize them without hard-coded type checks.
*
* <p>Note: This optimization preserves NULL semantics correctly. For aggregate
* functions like MAX, MIN, and ANY_VALUE, NULL values in the source columns or
* expressions are handled the same way before and after the transformation:
Expand Down Expand Up @@ -131,7 +140,7 @@
*
* @return the reduced expression, or null if cannot reduce
*/
private static @Nullable RexNode reduce(

Check failure on line 143 in core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsOnGroupKeysRule.java

View check run for this annotation

SonarQubeCloud / SonarCloud Code Analysis

Refactor this method to reduce its Cognitive Complexity from 32 to the 15 allowed.

See more on https://sonarcloud.io/project/issues?id=apache_calcite&issues=AZ6qMnurEfbHpVh1K-38&open=AZ6qMnurEfbHpVh1K-38&pullRequest=5003
Aggregate aggregate,
AggregateCall call,
RexBuilder rexBuilder) {
Expand All @@ -144,14 +153,18 @@
return null;
}
final SqlKind kind = call.getAggregation().getKind();
final boolean isConstantValueAgg =
call.getAggregation() instanceof SqlConstantValueAggFunction;
switch (kind) {
case AVG:
case MAX:
case MIN:
case ANY_VALUE:
break;
default:
return null;
if (!isConstantValueAgg) {
return null;
}
}
final List<Integer> argList = call.getArgList();
if (argList.size() != 1) {
Expand All @@ -163,6 +176,29 @@
if (aggregate.getGroupSet().get(arg)) {
final int groupIndex = aggregate.getGroupSet().asList().indexOf(arg);
RexNode ref = RexInputRef.of(groupIndex, aggregate.getRowType().getFieldList());

// For functions that return a constant value when applied to constant (GROUP BY key)
// arguments, delegate to the function's own implementation
if (isConstantValueAgg) {
final @Nullable RexNode constantResult =
((SqlConstantValueAggFunction) call.getAggregation())
.getConstantResult(rexBuilder, call.getType());
if (constantResult != null) {
// Handle NULL semantics: if the GROUP BY key is nullable and the constant value
// is non-null (e.g., 0 for STDDEV functions), wrap in CASE to return NULL when
// the key is NULL, since aggregate functions skip NULL inputs
if (ref.getType().isNullable()) {
return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, ref),
rexBuilder.makeNullLiteral(call.getType()),
constantResult);
}
return constantResult;
}
}

// For other aggregate functions (MAX, MIN, AVG, ANY_VALUE),
// the value of a constant is the constant itself
if (!ref.getType().equals(call.getType())) {
ref = rexBuilder.makeCast(call.getParserPosition(), call.getType(), ref);
}
Expand Down Expand Up @@ -192,6 +228,29 @@
if (translated == null) {
return null;
}

// For functions that return a constant value when applied to constant expressions,
// delegate to the function's own implementation
if (isConstantValueAgg) {
final @Nullable RexNode constantResult =
((SqlConstantValueAggFunction) call.getAggregation())
.getConstantResult(rexBuilder, call.getType());
if (constantResult != null) {
// Handle NULL semantics: if the expression is nullable and the constant value
// is non-null (e.g., 0 for STDDEV functions), wrap in CASE to return NULL when
// the expression evaluates to NULL, since aggregate functions skip NULL inputs
if (translated.getType().isNullable()) {
return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
rexBuilder.makeCall(SqlStdOperatorTable.IS_NULL, translated),
rexBuilder.makeNullLiteral(call.getType()),
constantResult);
}
return constantResult;
}
}

// For other aggregate functions (MAX, MIN, AVG, ANY_VALUE),
// return the translated expression
if (!translated.getType().equals(call.getType())) {
return rexBuilder.makeCast(call.getParserPosition(), call.getType(), translated);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.sql;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;

import org.checkerframework.checker.nullness.qual.Nullable;

/**
* Aggregate function that returns a constant value when applied to constant
* (GROUP BY key) arguments.
*
* <p>For example, statistical functions like STDDEV_POP, STDDEV_SAMP, VAR_POP,
* VAR_SAMP are reported as returning 0 when applied to a constant value, since
* there is no variation in a set of identical values.
*
* <p>This interface allows optimization rules to identify and reduce such
* aggregate functions without hard-coded checks for specific function types.
* A rule is expected to test an aggregate function with {@code instanceof}
* and then ask it for the replacement expression via
* {@link #getConstantResult(RexBuilder, RelDataType)}.
*
* <p>NOTE(review): SQL sample statistics (STDDEV_SAMP, VAR_SAMP) over a group
* that contains a single row are normally NULL rather than 0; confirm that
* implementations and callers account for single-row groups.
*/
public interface SqlConstantValueAggFunction {
/**
* Generates the constant result expression when this aggregate function is
* applied to arguments that are constant within each group (i.e., GROUP BY keys
* or expressions derived only from GROUP BY keys).
*
* <p>For example:
* <ul>
* <li>{@code STDDEV_POP(constant)} returns {@code 0}</li>
* <li>{@code VAR_SAMP(constant)} returns {@code 0}</li>
* <li>{@code STDDEV_SAMP(constant)} returns {@code 0}</li>
* </ul>
*
* <p>The returned expression describes the result for a non-NULL argument
* only. Aggregate functions skip NULL inputs, so a caller whose argument is
* nullable is responsible for producing NULL when the argument is NULL
* (for instance by wrapping the result in
* {@code CASE WHEN arg IS NULL THEN NULL ELSE result END}).
*
* @param rexBuilder Rex builder for creating the result expression
* @param returnType The return type of the aggregate function; the result
* expression is expected to have this type
* @return An expression representing the constant result, or null if this function
* does not return a constant value for constant arguments
*/
@Nullable RexNode getConstantResult(RexBuilder rexBuilder, RelDataType returnType);
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,32 @@
package org.apache.calcite.sql.fun;

import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlConstantValueAggFunction;
import org.apache.calcite.sql.SqlFunctionCategory;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
import org.apache.calcite.util.Optionality;

import org.checkerframework.checker.nullness.qual.Nullable;

import static com.google.common.base.Preconditions.checkArgument;

/**
* <code>Avg</code> is an aggregator which returns the average of the values
* which go into it. It has precisely one argument of numeric type
* (<code>int</code>, <code>long</code>, <code>float</code>, <code>
* double</code>), and the result is the same type.
*
* <p>This class implements {@link SqlConstantValueAggFunction}. Only the
* statistical kinds (STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP) report a
* constant result; this lets rules simplify them when their argument is a
* GROUP BY key. For any other kind, such as AVG, no constant result is reported.
*/
public class SqlAvgAggFunction extends SqlAggFunction {
public class SqlAvgAggFunction extends SqlAggFunction
implements SqlConstantValueAggFunction {

//~ Constructors -----------------------------------------------------------

Expand Down Expand Up @@ -86,4 +96,19 @@ public enum Subtype {
VAR_POP,
VAR_SAMP
}

@Override public @Nullable RexNode getConstantResult(RexBuilder rexBuilder,
RelDataType returnType) {
// Only statistical functions (variance and standard deviation) return 0 for constant values.
// AVG and other functions should not be optimized through this interface.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this right when the aggregate argument is NULL?

@xuzifu666 xuzifu666 Jun 9, 2026

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing out this issue; it requires proper NULL-semantics handling. When the GROUP BY key/expression is NULL, we generate `CASE WHEN ... IS NULL THEN NULL ELSE 0 END` to preserve correct NULL semantics. I addressed the issue this way; PTAL at the last commit.

switch (kind) {
case STDDEV_POP:
case STDDEV_SAMP:
case VAR_POP:
case VAR_SAMP:
return rexBuilder.makeLiteral(0, returnType, true);
default:
return null;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,120 @@ private static RelOptFixture sql(String sql) {
sql(sql).withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS).checkUnchanged();
}

/** Test case for
 * <a href="https://issues.apache.org/jira/browse/CALCITE-7493">[CALCITE-7493]
 * Support constant-result aggregates (e.g., STDDEV_POP, STDDEV) over GROUP BY keys</a>. */
@Test void testStatisticalFunctionStddevSampOfGroupByKey() {
  // "sal" is a GROUP BY key, hence constant within every group; the plan
  // produced by the rule for STDDEV_SAMP(sal) is compared with the reference.
  // NOTE(review): sample statistics over a single-row group are normally NULL
  // rather than 0 in SQL; confirm the reference plan accounts for that.
  final String query = "select sal, stddev_samp(sal) as sd\n"
      + "from emp group by sal, deptno";
  sql(query)
      .withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS)
      .check();
}

@Test void testStatisticalFunctionStddevPopOfGroupByKey() {
  // "sal" is a GROUP BY key, hence constant within every group, so the
  // population standard deviation of it is 0.
  final String query = "select sal, stddev_pop(sal) as sdp\n"
      + "from emp group by sal, deptno";
  sql(query)
      .withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS)
      .check();
}

@Test void testStatisticalFunctionVarPopOfGroupByKey() {
  // "sal" is a GROUP BY key, hence constant within every group, so the
  // population variance of it is 0.
  final String query = "select sal, var_pop(sal) as vp\n"
      + "from emp group by sal, deptno";
  sql(query)
      .withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS)
      .check();
}

@Test void testStatisticalFunctionVarSampOfGroupByKey() {
  // "sal" is a GROUP BY key, hence all values aggregated by VAR_SAMP(sal)
  // within a group are identical; the resulting plan is compared with the
  // reference.
  // NOTE(review): sample variance over a single-row group is normally NULL
  // rather than 0 in SQL; confirm the reference plan accounts for that.
  final String query = "select sal, var_samp(sal) as vs\n"
      + "from emp group by sal, deptno";
  sql(query)
      .withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS)
      .check();
}

@Test void testMultipleStatisticalFunctions() {
  // All four statistical aggregates appear in one query, each applied to the
  // GROUP BY key "sal".
  final String query = "select sal, stddev_samp(sal) as sd, stddev_pop(sal) as sdp,\n"
      + "var_pop(sal) as vp, var_samp(sal) as vs\n"
      + "from emp group by sal, deptno";
  sql(query)
      .withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS)
      .check();
}

@Test void testStatisticalFunctionWithBinaryExpression() {
  // The argument "sal + deptno" is a binary expression whose operands are both
  // GROUP BY keys, so its value is constant within each group and its
  // population variance is 0.
  final String query = "select sal, var_pop(sal + deptno) as vp\n"
      + "from emp group by sal, deptno";
  sql(query)
      .withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS)
      .check();
}

@Test void testStatisticalFunctionWithConstantExpression() {
  // The argument "2 * sal + 100" combines the GROUP BY key "sal" with
  // literals only, so its value is constant within each group and its
  // population standard deviation is 0.
  final String query = "select sal, stddev_pop(2 * sal + 100) as sdp\n"
      + "from emp group by sal, deptno";
  sql(query)
      .withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS)
      .check();
}

@Test void testStatisticalFunctionWithMultipleGroupByKeys() {
  // The argument "sal * deptno" multiplies two GROUP BY keys, so its value is
  // constant within each group.
  final String query = "select sal, var_samp(sal * deptno) as vs\n"
      + "from emp group by sal, deptno";
  sql(query)
      .withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS)
      .check();
}

@Test void testStatisticalFunctionWithNonGroupByColumnNoOptimization() {
  // Negative case: the argument "comm" is not a GROUP BY key, so it may vary
  // within a group and the rule must leave the plan unchanged.
  final String query = "select sal, var_pop(comm) as vp\n"
      + "from emp group by sal, deptno";
  sql(query)
      .withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS)
      .checkUnchanged();
}

@Test void testStatisticalFunctionWithMixedColumnsNoOptimization() {
  // Negative case: "sal + comm" mixes the GROUP BY key "sal" with "comm",
  // which is not a GROUP BY key; the sum may vary within a group, so the rule
  // must leave the plan unchanged.
  final String query = "select sal, stddev_pop(sal + comm) as sdp\n"
      + "from emp group by sal, deptno";
  sql(query)
      .withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS)
      .checkUnchanged();
}

@Test void testStatisticalFunctionWithPartialExpressionNoOptimization() {
  // Negative case: "sal * empno" multiplies the GROUP BY key "sal" by "empno",
  // which is not a GROUP BY key; the product may vary within a group, so the
  // rule must leave the plan unchanged.
  final String query = "select sal, var_samp(sal * empno) as vs\n"
      + "from emp group by sal, deptno";
  sql(query)
      .withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS)
      .checkUnchanged();
}

@Test void testStatisticalFunctionWithComplexMixNoOptimization() {
  // Negative case: "sal * 2 + comm" scales the GROUP BY key "sal" by a literal
  // and then adds "comm", which is not a GROUP BY key; the result may vary
  // within a group, so the rule must leave the plan unchanged.
  final String query = "select sal, stddev_samp(sal * 2 + comm) as sd\n"
      + "from emp group by sal, deptno";
  sql(query)
      .withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS)
      .checkUnchanged();
}

@Test void testStatisticalFunctionNullableGroupKey() {
  // NULL semantics for a nullable GROUP BY key: the group whose key is NULL
  // must still yield NULL, so the constant result is expected to be guarded as
  // CASE WHEN key IS NULL THEN NULL ELSE 0 END
  final String query = "select comm, stddev_pop(comm) as sdp\n"
      + "from empnullables group by comm";
  sql(query)
      .withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS)
      .check();
}

@Test void testStatisticalFunctionNullableExpression() {
  // NULL semantics for a nullable expression over a GROUP BY key: when
  // "comm + 1" evaluates to NULL the aggregate must yield NULL, so the
  // constant result is expected to be guarded by a CASE on the expression.
  final String query = "select comm, stddev_samp(comm + 1) as sd\n"
      + "from empnullables group by comm";
  sql(query)
      .withRule(AGGREGATE_REDUCE_FUNCTIONS_ON_GROUP_KEYS)
      .check();
}

/** Class-level teardown hook (runs via {@code @AfterAll}) that delegates to
 * the {@code diffRepos} object of the fixture; presumably this compares the
 * actual plans recorded by the tests with the reference file - verify against
 * the fixture's implementation. */
@AfterAll static void checkActualAndReferenceFiles() {
fixture().diffRepos.checkActualAndReferenceFiles();
}
Expand Down
Loading
Loading