14
14
import org .apache .lucene .search .LeafCollector ;
15
15
import org .apache .lucene .search .Query ;
16
16
import org .apache .lucene .search .ScoreDoc ;
17
- import org .apache .lucene .search .ScoreMode ;
18
17
import org .apache .lucene .search .Sort ;
19
18
import org .apache .lucene .search .SortField ;
20
19
import org .apache .lucene .search .TopDocsCollector ;
21
20
import org .apache .lucene .search .TopFieldCollectorManager ;
22
21
import org .apache .lucene .search .TopScoreDocCollectorManager ;
22
+ import org .apache .lucene .search .Weight ;
23
23
import org .elasticsearch .common .Strings ;
24
24
import org .elasticsearch .compute .data .BlockFactory ;
25
25
import org .elasticsearch .compute .data .DocBlock ;
36
36
import org .elasticsearch .search .sort .SortBuilder ;
37
37
38
38
import java .io .IOException ;
39
+ import java .io .UncheckedIOException ;
39
40
import java .util .ArrayList ;
40
41
import java .util .Arrays ;
41
42
import java .util .List ;
42
43
import java .util .Optional ;
43
44
import java .util .function .Function ;
44
45
import java .util .stream .Collectors ;
45
46
46
- import static org .apache .lucene .search .ScoreMode .COMPLETE ;
47
- import static org .apache .lucene .search .ScoreMode .TOP_DOCS ;
48
-
49
47
/**
50
48
* Source operator that builds Pages out of the output of a TopFieldCollector (aka TopN)
51
49
*/
@@ -62,16 +60,16 @@ public Factory(
62
60
int maxPageSize ,
63
61
int limit ,
64
62
List <SortBuilder <?>> sorts ,
65
- boolean scoring
63
+ boolean needsScore
66
64
) {
67
- super (contexts , queryFunction , dataPartitioning , taskConcurrency , limit , scoring ? COMPLETE : TOP_DOCS );
65
+ super (contexts , weightFunction ( queryFunction , sorts , needsScore ), dataPartitioning , taskConcurrency , limit , needsScore );
68
66
this .maxPageSize = maxPageSize ;
69
67
this .sorts = sorts ;
70
68
}
71
69
72
70
@ Override
73
71
public SourceOperator get (DriverContext driverContext ) {
74
- return new LuceneTopNSourceOperator (driverContext .blockFactory (), maxPageSize , sorts , limit , sliceQueue , scoreMode );
72
+ return new LuceneTopNSourceOperator (driverContext .blockFactory (), maxPageSize , sorts , limit , sliceQueue , needsScore );
75
73
}
76
74
77
75
public int maxPageSize () {
@@ -87,8 +85,8 @@ public String describe() {
87
85
+ maxPageSize
88
86
+ ", limit = "
89
87
+ limit
90
- + ", scoreMode = "
91
- + scoreMode
88
+ + ", needsScore = "
89
+ + needsScore
92
90
+ ", sorts = ["
93
91
+ notPrettySorts
94
92
+ "]]" ;
@@ -107,20 +105,20 @@ public String describe() {
107
105
private PerShardCollector perShardCollector ;
108
106
private final List <SortBuilder <?>> sorts ;
109
107
private final int limit ;
110
- private final ScoreMode scoreMode ;
108
+ private final boolean needsScore ;
111
109
112
110
public LuceneTopNSourceOperator (
113
111
BlockFactory blockFactory ,
114
112
int maxPageSize ,
115
113
List <SortBuilder <?>> sorts ,
116
114
int limit ,
117
115
LuceneSliceQueue sliceQueue ,
118
- ScoreMode scoreMode
116
+ boolean needsScore
119
117
) {
120
118
super (blockFactory , maxPageSize , sliceQueue );
121
119
this .sorts = sorts ;
122
120
this .limit = limit ;
123
- this .scoreMode = scoreMode ;
121
+ this .needsScore = needsScore ;
124
122
}
125
123
126
124
@ Override
@@ -162,7 +160,7 @@ private Page collect() throws IOException {
162
160
try {
163
161
if (perShardCollector == null || perShardCollector .shardContext .index () != scorer .shardContext ().index ()) {
164
162
// TODO: share the bottom between shardCollectors
165
- perShardCollector = newPerShardCollector (scorer .shardContext (), sorts , limit );
163
+ perShardCollector = newPerShardCollector (scorer .shardContext (), sorts , needsScore , limit );
166
164
}
167
165
var leafCollector = perShardCollector .getLeafCollector (scorer .leafReaderContext ());
168
166
scorer .scoreNextRange (leafCollector , scorer .leafReaderContext ().reader ().getLiveDocs (), maxPageSize );
@@ -260,7 +258,7 @@ private float getScore(ScoreDoc scoreDoc) {
260
258
}
261
259
262
260
private DoubleVector .Builder scoreVectorOrNull (int size ) {
263
- if (scoreMode . needsScores () ) {
261
+ if (needsScore ) {
264
262
return blockFactory .newDoubleVectorFixedBuilder (size );
265
263
} else {
266
264
return null ;
@@ -270,43 +268,11 @@ private DoubleVector.Builder scoreVectorOrNull(int size) {
270
268
@ Override
271
269
protected void describe (StringBuilder sb ) {
272
270
sb .append (", limit = " ).append (limit );
273
- sb .append (", scoreMode = " ).append (scoreMode );
271
+ sb .append (", needsScore = " ).append (needsScore );
274
272
String notPrettySorts = sorts .stream ().map (Strings ::toString ).collect (Collectors .joining ("," ));
275
273
sb .append (", sorts = [" ).append (notPrettySorts ).append ("]" );
276
274
}
277
275
278
- PerShardCollector newPerShardCollector (ShardContext shardContext , List <SortBuilder <?>> sorts , int limit ) throws IOException {
279
- Optional <SortAndFormats > sortAndFormats = shardContext .buildSort (sorts );
280
- if (sortAndFormats .isEmpty ()) {
281
- throw new IllegalStateException ("sorts must not be disabled in TopN" );
282
- }
283
- if (scoreMode .needsScores () == false ) {
284
- return new NonScoringPerShardCollector (shardContext , sortAndFormats .get ().sort , limit );
285
- } else {
286
- SortField [] sortFields = sortAndFormats .get ().sort .getSort ();
287
- if (sortFields != null && sortFields .length == 1 && sortFields [0 ].needsScores () && sortFields [0 ].getReverse () == false ) {
288
- // SORT _score DESC
289
- return new ScoringPerShardCollector (
290
- shardContext ,
291
- new TopScoreDocCollectorManager (limit , null , limit , false ).newCollector ()
292
- );
293
- } else {
294
- // SORT ..., _score, ...
295
- var sort = new Sort ();
296
- if (sortFields != null ) {
297
- var l = new ArrayList <>(Arrays .asList (sortFields ));
298
- l .add (SortField .FIELD_DOC );
299
- l .add (SortField .FIELD_SCORE );
300
- sort = new Sort (l .toArray (SortField []::new ));
301
- }
302
- return new ScoringPerShardCollector (
303
- shardContext ,
304
- new TopFieldCollectorManager (sort , limit , null , limit , false ).newCollector ()
305
- );
306
- }
307
- }
308
- }
309
-
310
276
abstract static class PerShardCollector {
311
277
private final ShardContext shardContext ;
312
278
private final TopDocsCollector <?> collector ;
@@ -341,4 +307,45 @@ static final class ScoringPerShardCollector extends PerShardCollector {
341
307
super (shardContext , topDocsCollector );
342
308
}
343
309
}
310
+
311
+ private static Function <ShardContext , Weight > weightFunction (
312
+ Function <ShardContext , Query > queryFunction ,
313
+ List <SortBuilder <?>> sorts ,
314
+ boolean needsScore
315
+ ) {
316
+ return ctx -> {
317
+ final var query = queryFunction .apply (ctx );
318
+ final var searcher = ctx .searcher ();
319
+ try {
320
+ // we create a collector with a limit of 1 to determine the appropriate score mode to use.
321
+ var scoreMode = newPerShardCollector (ctx , sorts , needsScore , 1 ).collector .scoreMode ();
322
+ return searcher .createWeight (searcher .rewrite (query ), scoreMode , 1 );
323
+ } catch (IOException e ) {
324
+ throw new UncheckedIOException (e );
325
+ }
326
+ };
327
+ }
328
+
329
+ private static PerShardCollector newPerShardCollector (ShardContext context , List <SortBuilder <?>> sorts , boolean needsScore , int limit )
330
+ throws IOException {
331
+ Optional <SortAndFormats > sortAndFormats = context .buildSort (sorts );
332
+ if (sortAndFormats .isEmpty ()) {
333
+ throw new IllegalStateException ("sorts must not be disabled in TopN" );
334
+ }
335
+ if (needsScore == false ) {
336
+ return new NonScoringPerShardCollector (context , sortAndFormats .get ().sort , limit );
337
+ }
338
+ Sort sort = sortAndFormats .get ().sort ;
339
+ if (Sort .RELEVANCE .equals (sort )) {
340
+ // SORT _score DESC
341
+ return new ScoringPerShardCollector (context , new TopScoreDocCollectorManager (limit , null , 0 ).newCollector ());
342
+ }
343
+
344
+ // SORT ..., _score, ...
345
+ var l = new ArrayList <>(Arrays .asList (sort .getSort ()));
346
+ l .add (SortField .FIELD_DOC );
347
+ l .add (SortField .FIELD_SCORE );
348
+ sort = new Sort (l .toArray (SortField []::new ));
349
+ return new ScoringPerShardCollector (context , new TopFieldCollectorManager (sort , limit , null , 0 ).newCollector ());
350
+ }
344
351
}
0 commit comments