Skip to content

Commit 62873d1

Browse files
ivanceaalbertzaharovits
authored and committed
ESQL: AVG aggregation tests and ignore complex surrogates (#110579)
Some work around aggregation tests, with AVG as an example: - Added tests and autogenerated docs for AVG. - As AVG uses "complex" surrogates (a combination of functions), we can't trivially execute them without a complete plan. As I'm not sure it's worth it for most aggregations, I'm skipping those cases for now, so as to avoid blocking other aggs tests. The bad side effect of skipping those tests is that most tests in AvgTests are actually ignored (74 of 100).
1 parent 0876454 commit 62873d1

File tree

14 files changed

+246
-54
lines changed

14 files changed

+246
-54
lines changed

docs/reference/esql/functions/aggregation-functions.asciidoc

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
The <<esql-stats-by>> command supports these aggregate functions:
99

1010
// tag::agg_list[]
11-
* <<esql-agg-avg>>
11+
* <<esql-avg>>
1212
* <<esql-agg-count>>
1313
* <<esql-agg-count-distinct>>
1414
* <<esql-agg-max>>
@@ -23,7 +23,6 @@ The <<esql-stats-by>> command supports these aggregate functions:
2323
* experimental:[] <<esql-agg-weighted-avg>>
2424
// end::agg_list[]
2525

26-
include::avg.asciidoc[]
2726
include::count.asciidoc[]
2827
include::count-distinct.asciidoc[]
2928
include::max.asciidoc[]
@@ -33,6 +32,7 @@ include::min.asciidoc[]
3332
include::percentile.asciidoc[]
3433
include::st_centroid_agg.asciidoc[]
3534
include::sum.asciidoc[]
35+
include::layout/avg.asciidoc[]
3636
include::layout/top.asciidoc[]
3737
include::values.asciidoc[]
3838
include::weighted-avg.asciidoc[]

docs/reference/esql/functions/avg.asciidoc

-47
This file was deleted.

docs/reference/esql/functions/description/avg.asciidoc

+5
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/esql/functions/examples/avg.asciidoc

+22
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/esql/functions/kibana/definition/avg.json

+48
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/esql/functions/kibana/docs/avg.md

+11
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/esql/functions/layout/avg.asciidoc

+15
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/esql/functions/parameters/avg.asciidoc

+6
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

docs/reference/esql/functions/signature/avg.svg

+1
Loading

docs/reference/esql/functions/types/avg.asciidoc

+11
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java

+15-1
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
import org.elasticsearch.xpack.esql.core.tree.Source;
1515
import org.elasticsearch.xpack.esql.core.type.DataType;
1616
import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
17+
import org.elasticsearch.xpack.esql.expression.function.Example;
1718
import org.elasticsearch.xpack.esql.expression.function.FunctionInfo;
1819
import org.elasticsearch.xpack.esql.expression.function.Param;
1920
import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAvg;
@@ -28,7 +29,20 @@
2829
public class Avg extends AggregateFunction implements SurrogateExpression {
2930
public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Avg", Avg::new);
3031

31-
@FunctionInfo(returnType = "double", description = "The average of a numeric field.", isAggregation = true)
32+
@FunctionInfo(
33+
returnType = "double",
34+
description = "The average of a numeric field.",
35+
isAggregation = true,
36+
examples = {
37+
@Example(file = "stats", tag = "avg"),
38+
@Example(
39+
description = "The expression can use inline functions. For example, to calculate the average "
40+
+ "over a multivalued column, first use `MV_AVG` to average the multiple values per row, "
41+
+ "and use the result with the `AVG` function",
42+
file = "stats",
43+
tag = "docsStatsAvgNestedExpression"
44+
) }
45+
)
3246
public Avg(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) {
3347
super(source, field);
3448
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java

+9
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import org.elasticsearch.compute.data.Page;
1616
import org.elasticsearch.core.Releasables;
1717
import org.elasticsearch.xpack.esql.core.expression.Expression;
18+
import org.elasticsearch.xpack.esql.core.expression.FieldAttribute;
19+
import org.elasticsearch.xpack.esql.core.expression.Literal;
1820
import org.elasticsearch.xpack.esql.core.type.DataType;
1921
import org.elasticsearch.xpack.esql.core.util.NumericUtils;
2022
import org.elasticsearch.xpack.esql.expression.SurrogateExpression;
@@ -251,6 +253,13 @@ private void resolveExpression(Expression expression, Consumer<Expression> onAgg
251253
expression = new FoldNull().rule(expression);
252254
assertThat(expression.dataType(), equalTo(testCase.expectedType()));
253255

256+
assumeTrue(
257+
"Surrogate expression with non-trivial children cannot be evaluated",
258+
expression.children()
259+
.stream()
260+
.allMatch(child -> child instanceof FieldAttribute || child instanceof DeepCopy || child instanceof Literal)
261+
);
262+
254263
if (expression instanceof AggregateFunction == false) {
255264
onEvaluableExpression.accept(expression);
256265
return;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
package org.elasticsearch.xpack.esql.expression.function.aggregate;
9+
10+
import com.carrotsearch.randomizedtesting.annotations.Name;
11+
import com.carrotsearch.randomizedtesting.annotations.ParametersFactory;
12+
13+
import org.elasticsearch.xpack.esql.core.expression.Expression;
14+
import org.elasticsearch.xpack.esql.core.tree.Source;
15+
import org.elasticsearch.xpack.esql.core.type.DataType;
16+
import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase;
17+
import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier;
18+
import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier;
19+
20+
import java.util.ArrayList;
21+
import java.util.List;
22+
import java.util.function.Supplier;
23+
import java.util.stream.Collectors;
24+
import java.util.stream.Stream;
25+
26+
import static org.hamcrest.Matchers.equalTo;
27+
28+
public class AvgTests extends AbstractAggregationTestCase {
29+
public AvgTests(@Name("TestCase") Supplier<TestCaseSupplier.TestCase> testCaseSupplier) {
30+
this.testCase = testCaseSupplier.get();
31+
}
32+
33+
@ParametersFactory
34+
public static Iterable<Object[]> parameters() {
35+
var suppliers = new ArrayList<TestCaseSupplier>();
36+
37+
Stream.of(
38+
MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true),
39+
MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true),
40+
MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true)
41+
).flatMap(List::stream).map(AvgTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers));
42+
43+
suppliers.add(
44+
// Folding
45+
new TestCaseSupplier(
46+
List.of(DataType.INTEGER),
47+
() -> new TestCaseSupplier.TestCase(
48+
List.of(TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field")),
49+
"Avg[field=Attribute[channel=0]]",
50+
DataType.DOUBLE,
51+
equalTo(200.)
52+
)
53+
)
54+
);
55+
56+
return parameterSuppliersFromTypedDataWithDefaultChecks(suppliers);
57+
}
58+
59+
@Override
60+
protected Expression build(Source source, List<Expression> args) {
61+
return new Avg(source, args.get(0));
62+
}
63+
64+
private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) {
65+
return new TestCaseSupplier(List.of(fieldSupplier.type()), () -> {
66+
var fieldTypedData = fieldSupplier.get();
67+
68+
Object expected = switch (fieldTypedData.type().widenSmallNumeric()) {
69+
case INTEGER -> fieldTypedData.multiRowData()
70+
.stream()
71+
.map(v -> (Integer) v)
72+
.collect(Collectors.summarizingInt(Integer::intValue))
73+
.getAverage();
74+
case LONG -> fieldTypedData.multiRowData()
75+
.stream()
76+
.map(v -> (Long) v)
77+
.collect(Collectors.summarizingLong(Long::longValue))
78+
.getAverage();
79+
case DOUBLE -> fieldTypedData.multiRowData()
80+
.stream()
81+
.map(v -> (Double) v)
82+
.collect(Collectors.summarizingDouble(Double::doubleValue))
83+
.getAverage();
84+
default -> throw new IllegalStateException("Unexpected value: " + fieldTypedData.type());
85+
};
86+
87+
return new TestCaseSupplier.TestCase(
88+
List.of(fieldTypedData),
89+
"Avg[field=Attribute[channel=0]]",
90+
DataType.DOUBLE,
91+
equalTo(expected)
92+
);
93+
});
94+
}
95+
}

x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java

+6-4
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.util.Comparator;
2323
import java.util.List;
2424
import java.util.function.Supplier;
25+
import java.util.stream.Collectors;
2526
import java.util.stream.Stream;
2627

2728
import static org.hamcrest.Matchers.equalTo;
@@ -37,14 +38,15 @@ public static Iterable<Object[]> parameters() {
3738

3839
for (var limitCaseSupplier : TestCaseSupplier.intCases(1, 1000, false)) {
3940
for (String order : List.of("asc", "desc")) {
40-
for (var fieldCaseSupplier : Stream.of(
41+
Stream.of(
4142
MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true),
4243
MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true),
4344
MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true),
4445
MultiRowTestCaseSupplier.dateCases(1, 1000)
45-
).flatMap(List::stream).toList()) {
46-
suppliers.add(TopTests.makeSupplier(fieldCaseSupplier, limitCaseSupplier, order));
47-
}
46+
)
47+
.flatMap(List::stream)
48+
.map(fieldCaseSupplier -> TopTests.makeSupplier(fieldCaseSupplier, limitCaseSupplier, order))
49+
.collect(Collectors.toCollection(() -> suppliers));
4850
}
4951
}
5052

0 commit comments

Comments
 (0)