
Commit 8ecac02

[8.x] Infer the score mode to use from the Lucene collector (#125930) (#126032)
* [ES|QL] Infer the score mode to use from the Lucene collector (#125930)

  This change uses the Lucene collector to infer which score mode to use when the topN collector is used.

* fix typo
* fix tests
1 parent 2f01329 commit 8ecac02
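In short, instead of hard-coding the ScoreMode passed to IndexSearcher#createWeight, the top-N path now builds the per-shard collector first and asks it which score mode it needs. A minimal standalone sketch of the idea in plain Lucene (the class and method names here are illustrative, not the ES|QL types):

// Sketch only: infer the ScoreMode from a throwaway collector instead of hard-coding it.
// Plain Lucene; "searcher", "query" and "sort" are illustrative inputs, not ES|QL types.
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.TopFieldCollectorManager;
import org.apache.lucene.search.Weight;

import java.io.IOException;

final class InferScoreMode {
    static Weight topNWeight(IndexSearcher searcher, Query query, Sort sort) throws IOException {
        // A collector with numHits = 1 is enough to learn the score mode the real collector will ask for,
        // e.g. TOP_DOCS for a plain field sort, or a scoring mode when the sort relies on _score.
        ScoreMode scoreMode = new TopFieldCollectorManager(sort, 1, null, 0).newCollector().scoreMode();
        return searcher.createWeight(searcher.rewrite(query), scoreMode, 1);
    }
}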

File tree

13 files changed: +122 -83 lines

docs/changelog/125930.yaml

+5
@@ -0,0 +1,5 @@
+pr: 125930
+summary: Infer the score mode to use from the Lucene collector
+area: "ES|QL"
+type: enhancement
+issues: []

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneCountOperator.java

+1-1
@@ -49,7 +49,7 @@ public Factory(
         int taskConcurrency,
         int limit
     ) {
-        super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
+        super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
     }

     @Override

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMaxFactory.java

+3-1
@@ -23,6 +23,8 @@
 import java.util.List;
 import java.util.function.Function;

+import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;
+
 /**
  * Factory that generates an operator that finds the max value of a field using the {@link LuceneMinMaxOperator}.
  */
@@ -121,7 +123,7 @@ public LuceneMaxFactory(
         NumberType numberType,
         int limit
     ) {
-        super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
+        super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
         this.fieldName = fieldName;
         this.numberType = numberType;
     }

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneMinFactory.java

+3-1
@@ -23,6 +23,8 @@
 import java.util.List;
 import java.util.function.Function;

+import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;
+
 /**
  * Factory that generates an operator that finds the min value of a field using the {@link LuceneMinMaxOperator}.
  */
@@ -121,7 +123,7 @@ public LuceneMinFactory(
         NumberType numberType,
         int limit
     ) {
-        super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
+        super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), dataPartitioning, taskConcurrency, limit, false);
         this.fieldName = fieldName;
         this.numberType = numberType;
     }

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneOperator.java

+5-7
@@ -11,7 +11,6 @@
 import org.apache.lucene.search.BulkScorer;
 import org.apache.lucene.search.ConstantScoreQuery;
 import org.apache.lucene.search.DocIdSetIterator;
-import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.LeafCollector;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreMode;
@@ -83,28 +82,27 @@ public abstract static class Factory implements SourceOperator.SourceOperatorFac
     protected final DataPartitioning dataPartitioning;
     protected final int taskConcurrency;
     protected final int limit;
-    protected final ScoreMode scoreMode;
+    protected final boolean needsScore;
     protected final LuceneSliceQueue sliceQueue;

     /**
      * Build the factory.
      *
-     * @param scoreMode the {@link ScoreMode} passed to {@link IndexSearcher#createWeight}
+     * @param needsScore Whether the score is needed.
      */
     protected Factory(
         List<? extends ShardContext> contexts,
-        Function<ShardContext, Query> queryFunction,
+        Function<ShardContext, Weight> weightFunction,
         DataPartitioning dataPartitioning,
         int taskConcurrency,
         int limit,
-        ScoreMode scoreMode
+        boolean needsScore
     ) {
         this.limit = limit;
-        this.scoreMode = scoreMode;
         this.dataPartitioning = dataPartitioning;
-        var weightFunction = weightFunction(queryFunction, scoreMode);
         this.sliceQueue = LuceneSliceQueue.create(contexts, weightFunction, dataPartitioning, taskConcurrency);
         this.taskConcurrency = Math.min(sliceQueue.totalSlices(), taskConcurrency);
+        this.needsScore = needsScore;
     }

     public final int taskConcurrency() {
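Several of the factories above now import a static LuceneOperator.weightFunction helper whose body is not part of this diff. Judging from the old inline call and the per-shard lambda added in LuceneTopNSourceOperator, it presumably looks roughly like this (an assumption, not the committed code):

// Assumed shape of LuceneOperator.weightFunction(queryFunction, scoreMode); the real body is not shown in this diff.
static Function<ShardContext, Weight> weightFunction(Function<ShardContext, Query> queryFunction, ScoreMode scoreMode) {
    return ctx -> {
        final var query = queryFunction.apply(ctx);
        final var searcher = ctx.searcher();
        try {
            // Rewrite once per shard and create the Weight with the caller-chosen, fixed score mode.
            return searcher.createWeight(searcher.rewrite(query), scoreMode, 1);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    };
}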

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneSourceOperator.java

+14-8
@@ -11,7 +11,6 @@
 import org.apache.lucene.search.LeafCollector;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.Scorable;
-import org.apache.lucene.search.ScoreMode;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.DocBlock;
 import org.elasticsearch.compute.data.DocVector;
@@ -56,17 +55,24 @@ public Factory(
         int taskConcurrency,
         int maxPageSize,
         int limit,
-        boolean scoring
+        boolean needsScore
     ) {
-        super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? COMPLETE : COMPLETE_NO_SCORES);
+        super(
+            contexts,
+            weightFunction(queryFunction, needsScore ? COMPLETE : COMPLETE_NO_SCORES),
+            dataPartitioning,
+            taskConcurrency,
+            limit,
+            needsScore
+        );
         this.maxPageSize = maxPageSize;
         // TODO: use a single limiter for multiple stage execution
         this.limiter = limit == NO_LIMIT ? Limiter.NO_LIMIT : new Limiter(limit);
     }

     @Override
     public SourceOperator get(DriverContext driverContext) {
-        return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, limiter, scoreMode);
+        return new LuceneSourceOperator(driverContext.blockFactory(), maxPageSize, sliceQueue, limit, limiter, needsScore);
     }

     public int maxPageSize() {
@@ -81,8 +87,8 @@ public String describe() {
             + maxPageSize
             + ", limit = "
             + limit
-            + ", scoreMode = "
-            + scoreMode
+            + ", needsScore = "
+            + needsScore
             + "]";
     }
 }
@@ -94,7 +100,7 @@ public LuceneSourceOperator(
         LuceneSliceQueue sliceQueue,
         int limit,
         Limiter limiter,
-        ScoreMode scoreMode
+        boolean needsScore
     ) {
         super(blockFactory, maxPageSize, sliceQueue);
         this.minPageSize = Math.max(1, maxPageSize / 2);
@@ -104,7 +110,7 @@ public LuceneSourceOperator(
         boolean success = false;
         try {
             this.docsBuilder = blockFactory.newIntVectorBuilder(estimatedSize);
-            if (scoreMode.needsScores()) {
+            if (needsScore) {
                 scoreBuilder = blockFactory.newDoubleVectorBuilder(estimatedSize);
                 this.leafCollector = new ScoringCollector();
             } else {

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperator.java

+54-47
@@ -14,12 +14,12 @@
 import org.apache.lucene.search.LeafCollector;
 import org.apache.lucene.search.Query;
 import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.Sort;
 import org.apache.lucene.search.SortField;
 import org.apache.lucene.search.TopDocsCollector;
 import org.apache.lucene.search.TopFieldCollectorManager;
 import org.apache.lucene.search.TopScoreDocCollectorManager;
+import org.apache.lucene.search.Weight;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.compute.data.BlockFactory;
 import org.elasticsearch.compute.data.DocBlock;
@@ -36,16 +36,14 @@
 import org.elasticsearch.search.sort.SortBuilder;

 import java.io.IOException;
+import java.io.UncheckedIOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.List;
 import java.util.Optional;
 import java.util.function.Function;
 import java.util.stream.Collectors;

-import static org.apache.lucene.search.ScoreMode.COMPLETE;
-import static org.apache.lucene.search.ScoreMode.TOP_DOCS;
-
 /**
  * Source operator that builds Pages out of the output of a TopFieldCollector (aka TopN)
  */
@@ -62,16 +60,16 @@ public Factory(
         int maxPageSize,
         int limit,
         List<SortBuilder<?>> sorts,
-        boolean scoring
+        boolean needsScore
     ) {
-        super(contexts, queryFunction, dataPartitioning, taskConcurrency, limit, scoring ? COMPLETE : TOP_DOCS);
+        super(contexts, weightFunction(queryFunction, sorts, needsScore), dataPartitioning, taskConcurrency, limit, needsScore);
         this.maxPageSize = maxPageSize;
         this.sorts = sorts;
     }

     @Override
     public SourceOperator get(DriverContext driverContext) {
-        return new LuceneTopNSourceOperator(driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, scoreMode);
+        return new LuceneTopNSourceOperator(driverContext.blockFactory(), maxPageSize, sorts, limit, sliceQueue, needsScore);
     }

     public int maxPageSize() {
@@ -87,8 +85,8 @@ public String describe() {
             + maxPageSize
             + ", limit = "
             + limit
-            + ", scoreMode = "
-            + scoreMode
+            + ", needsScore = "
+            + needsScore
             + ", sorts = ["
             + notPrettySorts
             + "]]";
@@ -107,20 +105,20 @@ public String describe() {
     private PerShardCollector perShardCollector;
     private final List<SortBuilder<?>> sorts;
     private final int limit;
-    private final ScoreMode scoreMode;
+    private final boolean needsScore;

     public LuceneTopNSourceOperator(
         BlockFactory blockFactory,
         int maxPageSize,
         List<SortBuilder<?>> sorts,
         int limit,
         LuceneSliceQueue sliceQueue,
-        ScoreMode scoreMode
+        boolean needsScore
     ) {
         super(blockFactory, maxPageSize, sliceQueue);
         this.sorts = sorts;
         this.limit = limit;
-        this.scoreMode = scoreMode;
+        this.needsScore = needsScore;
     }

     @Override
@@ -162,7 +160,7 @@ private Page collect() throws IOException {
         try {
             if (perShardCollector == null || perShardCollector.shardContext.index() != scorer.shardContext().index()) {
                 // TODO: share the bottom between shardCollectors
-                perShardCollector = newPerShardCollector(scorer.shardContext(), sorts, limit);
+                perShardCollector = newPerShardCollector(scorer.shardContext(), sorts, needsScore, limit);
             }
             var leafCollector = perShardCollector.getLeafCollector(scorer.leafReaderContext());
             scorer.scoreNextRange(leafCollector, scorer.leafReaderContext().reader().getLiveDocs(), maxPageSize);
@@ -260,7 +258,7 @@ private float getScore(ScoreDoc scoreDoc) {
     }

     private DoubleVector.Builder scoreVectorOrNull(int size) {
-        if (scoreMode.needsScores()) {
+        if (needsScore) {
             return blockFactory.newDoubleVectorFixedBuilder(size);
         } else {
             return null;
@@ -270,43 +268,11 @@ private DoubleVector.Builder scoreVectorOrNull(int size) {
     @Override
     protected void describe(StringBuilder sb) {
         sb.append(", limit = ").append(limit);
-        sb.append(", scoreMode = ").append(scoreMode);
+        sb.append(", needsScore = ").append(needsScore);
         String notPrettySorts = sorts.stream().map(Strings::toString).collect(Collectors.joining(","));
         sb.append(", sorts = [").append(notPrettySorts).append("]");
     }

-    PerShardCollector newPerShardCollector(ShardContext shardContext, List<SortBuilder<?>> sorts, int limit) throws IOException {
-        Optional<SortAndFormats> sortAndFormats = shardContext.buildSort(sorts);
-        if (sortAndFormats.isEmpty()) {
-            throw new IllegalStateException("sorts must not be disabled in TopN");
-        }
-        if (scoreMode.needsScores() == false) {
-            return new NonScoringPerShardCollector(shardContext, sortAndFormats.get().sort, limit);
-        } else {
-            SortField[] sortFields = sortAndFormats.get().sort.getSort();
-            if (sortFields != null && sortFields.length == 1 && sortFields[0].needsScores() && sortFields[0].getReverse() == false) {
-                // SORT _score DESC
-                return new ScoringPerShardCollector(
-                    shardContext,
-                    new TopScoreDocCollectorManager(limit, null, limit, false).newCollector()
-                );
-            } else {
-                // SORT ..., _score, ...
-                var sort = new Sort();
-                if (sortFields != null) {
-                    var l = new ArrayList<>(Arrays.asList(sortFields));
-                    l.add(SortField.FIELD_DOC);
-                    l.add(SortField.FIELD_SCORE);
-                    sort = new Sort(l.toArray(SortField[]::new));
-                }
-                return new ScoringPerShardCollector(
-                    shardContext,
-                    new TopFieldCollectorManager(sort, limit, null, limit, false).newCollector()
-                );
-            }
-        }
-    }
-
     abstract static class PerShardCollector {
         private final ShardContext shardContext;
         private final TopDocsCollector<?> collector;
@@ -341,4 +307,45 @@ static final class ScoringPerShardCollector extends PerShardCollector {
             super(shardContext, topDocsCollector);
         }
     }
+
+    private static Function<ShardContext, Weight> weightFunction(
+        Function<ShardContext, Query> queryFunction,
+        List<SortBuilder<?>> sorts,
+        boolean needsScore
+    ) {
+        return ctx -> {
+            final var query = queryFunction.apply(ctx);
+            final var searcher = ctx.searcher();
+            try {
+                // we create a collector with a limit of 1 to determine the appropriate score mode to use.
+                var scoreMode = newPerShardCollector(ctx, sorts, needsScore, 1).collector.scoreMode();
+                return searcher.createWeight(searcher.rewrite(query), scoreMode, 1);
+            } catch (IOException e) {
+                throw new UncheckedIOException(e);
+            }
+        };
+    }
+
+    private static PerShardCollector newPerShardCollector(ShardContext context, List<SortBuilder<?>> sorts, boolean needsScore, int limit)
+        throws IOException {
+        Optional<SortAndFormats> sortAndFormats = context.buildSort(sorts);
+        if (sortAndFormats.isEmpty()) {
+            throw new IllegalStateException("sorts must not be disabled in TopN");
+        }
+        if (needsScore == false) {
+            return new NonScoringPerShardCollector(context, sortAndFormats.get().sort, limit);
+        }
+        Sort sort = sortAndFormats.get().sort;
+        if (Sort.RELEVANCE.equals(sort)) {
+            // SORT _score DESC
+            return new ScoringPerShardCollector(context, new TopScoreDocCollectorManager(limit, null, 0).newCollector());
+        }
+
+        // SORT ..., _score, ...
+        var l = new ArrayList<>(Arrays.asList(sort.getSort()));
+        l.add(SortField.FIELD_DOC);
+        l.add(SortField.FIELD_SCORE);
+        sort = new Sort(l.toArray(SortField[]::new));
+        return new ScoringPerShardCollector(context, new TopFieldCollectorManager(sort, limit, null, 0).newCollector());
+    }
 }
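The net effect is that the ScoreMode handed to createWeight always matches what the per-shard collector will request, rather than the previous fixed scoring ? COMPLETE : TOP_DOCS. A rough probe of the modes the two collector types used above report (illustrative field name; exact values depend on the Lucene version and the totalHitsThreshold):

// Illustration only: print the score modes reported by the two collector types used above.
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopFieldCollectorManager;
import org.apache.lucene.search.TopScoreDocCollectorManager;

public class ScoreModeProbe {
    public static void main(String[] args) throws Exception {
        // Sorting on a plain field: the field collector typically reports TOP_DOCS (no scores needed).
        Sort byField = new Sort(new SortField("timestamp", SortField.Type.LONG));
        System.out.println(new TopFieldCollectorManager(byField, 10, null, 0).newCollector().scoreMode());

        // Sorting by _score descending: the score collector typically reports TOP_SCORES.
        System.out.println(new TopScoreDocCollectorManager(10, null, 0).newCollector().scoreMode());
    }
}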

x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/lucene/TimeSeriesSortedSourceOperatorFactory.java

+3-1
@@ -33,6 +33,8 @@
 import java.util.List;
 import java.util.function.Function;

+import static org.elasticsearch.compute.lucene.LuceneOperator.weightFunction;
+
 /**
  * Creates a source operator that takes advantage of the natural sorting of segments in a tsdb index.
  * <p>
@@ -56,7 +58,7 @@ private TimeSeriesSortedSourceOperatorFactory(
         int maxPageSize,
         int limit
     ) {
-        super(contexts, queryFunction, DataPartitioning.SHARD, taskConcurrency, limit, ScoreMode.COMPLETE_NO_SCORES);
+        super(contexts, weightFunction(queryFunction, ScoreMode.COMPLETE_NO_SCORES), DataPartitioning.SHARD, taskConcurrency, limit, false);
         this.maxPageSize = maxPageSize;
     }

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneSourceOperatorTests.java

+1-1
@@ -119,7 +119,7 @@ protected Matcher<String> expectedToStringOfSimple() {
     protected Matcher<String> expectedDescriptionOfSimple() {
         return matchesRegex(
             "LuceneSourceOperator"
-                + "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, scoreMode = (COMPLETE|COMPLETE_NO_SCORES)]"
+                + "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, needsScore = (true|false)]"
         );
     }

x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/LuceneTopNSourceOperatorScoringTests.java

+2-2
@@ -109,14 +109,14 @@ public Optional<SortAndFormats> buildSort(List<SortBuilder<?>> sorts) {

     @Override
     protected Matcher<String> expectedToStringOfSimple() {
-        return matchesRegex("LuceneTopNSourceOperator\\[maxPageSize = \\d+, limit = 100, scoreMode = COMPLETE, sorts = \\[\\{.+}]]");
+        return matchesRegex("LuceneTopNSourceOperator\\[maxPageSize = \\d+, limit = 100, needsScore = true, sorts = \\[\\{.+}]]");
     }

     @Override
     protected Matcher<String> expectedDescriptionOfSimple() {
         return matchesRegex(
             "LuceneTopNSourceOperator"
-                + "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, scoreMode = COMPLETE, sorts = \\[\\{.+}]]"
+                + "\\[dataPartitioning = (DOC|SHARD|SEGMENT), maxPageSize = \\d+, limit = 100, needsScore = true, sorts = \\[\\{.+}]]"
         );
     }
