Skip to content

Commit a309bd6

Browse files
authored
Move HitQueue in TopScoreDocCollector to a LongHeap (#14714)
1 parent d242a40 commit a309bd6

File tree

11 files changed

+224
-101
lines changed

11 files changed

+224
-101
lines changed

lucene/CHANGES.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ Optimizations
144144

145145
* GITHUB#14674: Optimize AbstractKnnVectorQuery#createBitSet with intoBitset. (Guo Feng)
146146

147+
* GITHUB#14714: Move HitQueue in TopScoreDocCollector to a LongHeap. (Guo Feng)
148+
147149
* GITHUB#14720: Cache high-order bits of hashcode to speed up BytesRefHash. (Pan Guixin)
148150

149151
* GITHUB#14753: Implement IndexedDISI#docIDRunEnd. (Ge Song)
Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.lucene.search;
19+
20+
import org.apache.lucene.util.NumericUtils;
21+
22+
/**
23+
* An encoder do encode (doc, score) pair as a long whose sort order is same as {@code (o1, o2) ->
24+
* Float.compare(o1.score, o2.score)).thenComparing(Comparator.comparingInt((ScoreDoc o) ->
25+
* o.doc).reversed())}
26+
*/
27+
class DocScoreEncoder {
28+
29+
static final long LEAST_COMPETITIVE_CODE = encode(Integer.MAX_VALUE, Float.NEGATIVE_INFINITY);
30+
31+
static long encode(int docId, float score) {
32+
return (((long) NumericUtils.floatToSortableInt(score)) << 32) | (Integer.MAX_VALUE - docId);
33+
}
34+
35+
static float toScore(long value) {
36+
return NumericUtils.sortableIntToFloat((int) (value >>> 32));
37+
}
38+
39+
static int docId(long value) {
40+
return Integer.MAX_VALUE - ((int) value);
41+
}
42+
}

lucene/core/src/java/org/apache/lucene/search/MaxScoreAccumulator.java

Lines changed: 4 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
/** Maintains the maximum score and its corresponding document id concurrently */
2323
final class MaxScoreAccumulator {
2424
// we use 2^10-1 to check the remainder with a bitwise operation
25-
static final int DEFAULT_INTERVAL = 0x3ff;
25+
private static final int DEFAULT_INTERVAL = 0x3ff;
2626

2727
// scores are always positive
28-
final LongAccumulator acc = new LongAccumulator(MaxScoreAccumulator::maxEncode, Long.MIN_VALUE);
28+
final LongAccumulator acc = new LongAccumulator(Math::max, Long.MIN_VALUE);
2929

3030
// non-final and visible for tests
3131
long modInterval;
@@ -34,35 +34,8 @@ final class MaxScoreAccumulator {
3434
this.modInterval = DEFAULT_INTERVAL;
3535
}
3636

37-
/**
38-
* Return the max encoded docId and score found in the two longs, following the encoding in {@link
39-
* #accumulate}.
40-
*/
41-
private static long maxEncode(long v1, long v2) {
42-
float score1 = Float.intBitsToFloat((int) (v1 >> 32));
43-
float score2 = Float.intBitsToFloat((int) (v2 >> 32));
44-
int cmp = Float.compare(score1, score2);
45-
if (cmp == 0) {
46-
// tie-break on the minimum doc base
47-
return (int) v1 < (int) v2 ? v1 : v2;
48-
} else if (cmp > 0) {
49-
return v1;
50-
}
51-
return v2;
52-
}
53-
54-
void accumulate(int docId, float score) {
55-
assert docId >= 0 && score >= 0;
56-
long encode = (((long) Float.floatToIntBits(score)) << 32) | docId;
57-
acc.accumulate(encode);
58-
}
59-
60-
public static float toScore(long value) {
61-
return Float.intBitsToFloat((int) (value >> 32));
62-
}
63-
64-
public static int docId(long value) {
65-
return (int) value;
37+
void accumulate(long code) {
38+
acc.accumulate(code);
6639
}
6740

6841
long getRaw() {

lucene/core/src/java/org/apache/lucene/search/TopDocsCollector.java

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -153,18 +153,25 @@ public TopDocs topDocs(int start, int howMany) {
153153
howMany = Math.min(size - start, howMany);
154154
ScoreDoc[] results = new ScoreDoc[howMany];
155155

156-
// pq's pop() returns the 'least' element in the queue, therefore need
157-
// to discard the first ones, until we reach the requested range.
156+
// Prune the least competitive hits until we reach the requested range.
158157
// Note that this loop will usually not be executed, since the common usage
159158
// should be that the caller asks for the last howMany results. However it's
160159
// needed here for completeness.
161-
for (int i = pq.size() - start - howMany; i > 0; i--) {
162-
pq.pop();
163-
}
160+
pruneLeastCompetitiveHitsTo(start + howMany);
164161

165162
// Get the requested results from pq.
166163
populateResults(results, howMany);
167164

168165
return newTopDocs(results, start);
169166
}
167+
168+
/**
169+
* Prune the least competitive hits until the number of candidates is less than or equal to {@code
170+
* keep}. This is typically called before {@link #populateResults} to ensure we are at right pos.
171+
*/
172+
protected void pruneLeastCompetitiveHitsTo(int keep) {
173+
for (int i = pq.size() - keep; i > 0; i--) {
174+
pq.pop();
175+
}
176+
}
170177
}

lucene/core/src/java/org/apache/lucene/search/TopFieldCollector.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -367,7 +367,7 @@ protected void updateGlobalMinCompetitiveScore(Scorable scorer) throws IOExcepti
367367
long maxMinScore = minScoreAcc.getRaw();
368368
float score;
369369
if (maxMinScore != Long.MIN_VALUE
370-
&& (score = MaxScoreAccumulator.toScore(maxMinScore)) > minCompetitiveScore) {
370+
&& (score = DocScoreEncoder.toScore(maxMinScore)) > minCompetitiveScore) {
371371
scorer.setMinCompetitiveScore(score);
372372
minCompetitiveScore = score;
373373
totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
@@ -384,7 +384,7 @@ protected void updateMinCompetitiveScore(Scorable scorer) throws IOException {
384384
minCompetitiveScore = minScore;
385385
totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
386386
if (minScoreAcc != null) {
387-
minScoreAcc.accumulate(docBase, minScore);
387+
minScoreAcc.accumulate(DocScoreEncoder.encode(docBase, minScore));
388388
}
389389
}
390390
}

lucene/core/src/java/org/apache/lucene/search/TopScoreDocCollector.java

Lines changed: 40 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import java.io.IOException;
2020
import org.apache.lucene.index.LeafReaderContext;
21+
import org.apache.lucene.util.LongHeap;
2122

2223
/**
2324
* A {@link Collector} implementation that collects the top-scoring hits, returning them as a {@link
@@ -32,31 +33,20 @@
3233
public class TopScoreDocCollector extends TopDocsCollector<ScoreDoc> {
3334

3435
private final ScoreDoc after;
36+
private final LongHeap heap;
3537
final int totalHitsThreshold;
3638
final MaxScoreAccumulator minScoreAcc;
3739

3840
// prevents instantiation
3941
TopScoreDocCollector(
4042
int numHits, ScoreDoc after, int totalHitsThreshold, MaxScoreAccumulator minScoreAcc) {
41-
super(new HitQueue(numHits, true));
43+
super(null);
44+
this.heap = new LongHeap(numHits, DocScoreEncoder.LEAST_COMPETITIVE_CODE);
4245
this.after = after;
4346
this.totalHitsThreshold = totalHitsThreshold;
4447
this.minScoreAcc = minScoreAcc;
4548
}
4649

47-
@Override
48-
protected int topDocsSize() {
49-
// Note: this relies on sentinel values having Integer.MAX_VALUE as a doc ID.
50-
int[] validTopHitCount = new int[1];
51-
pq.forEach(
52-
scoreDoc -> {
53-
if (scoreDoc.doc != Integer.MAX_VALUE) {
54-
validTopHitCount[0]++;
55-
}
56-
});
57-
return validTopHitCount[0];
58-
}
59-
6050
@Override
6151
protected TopDocs newTopDocs(ScoreDoc[] results, int start) {
6252
return results == null
@@ -86,9 +76,8 @@ public LeafCollector getLeafCollector(LeafReaderContext context) throws IOExcept
8676
return new LeafCollector() {
8777

8878
private Scorable scorer;
89-
// HitQueue implements getSentinelObject to return a ScoreDoc, so we know
90-
// that at this point top() is already initialized.
91-
private ScoreDoc pqTop = pq.top();
79+
private long topCode = heap.top();
80+
private float topScore = DocScoreEncoder.toScore(topCode);
9281
private float minCompetitiveScore;
9382

9483
@Override
@@ -121,7 +110,7 @@ public void collect(int doc) throws IOException {
121110
return;
122111
}
123112

124-
if (score <= pqTop.score) {
113+
if (score <= topScore) {
125114
// Note: for queries that match lots of hits, this is the common case: most hits are not
126115
// competitive.
127116
if (hitCountSoFar == totalHitsThreshold + 1) {
@@ -139,9 +128,9 @@ public void collect(int doc) throws IOException {
139128
}
140129

141130
private void collectCompetitiveHit(int doc, float score) throws IOException {
142-
pqTop.doc = doc + docBase;
143-
pqTop.score = score;
144-
pqTop = pq.updateTop();
131+
final long code = DocScoreEncoder.encode(doc + docBase, score);
132+
topCode = heap.updateTop(code);
133+
topScore = DocScoreEncoder.toScore(topCode);
145134
updateMinCompetitiveScore(scorer);
146135
}
147136

@@ -152,8 +141,8 @@ private void updateGlobalMinCompetitiveScore(Scorable scorer) throws IOException
152141
// since we tie-break on doc id and collect in doc id order we can require
153142
// the next float if the global minimum score is set on a document id that is
154143
// smaller than the ids in the current leaf
155-
float score = MaxScoreAccumulator.toScore(maxMinScore);
156-
score = docBase >= MaxScoreAccumulator.docId(maxMinScore) ? Math.nextUp(score) : score;
144+
float score = DocScoreEncoder.toScore(maxMinScore);
145+
score = docBase >= DocScoreEncoder.docId(maxMinScore) ? Math.nextUp(score) : score;
157146
if (score > minCompetitiveScore) {
158147
scorer.setMinCompetitiveScore(score);
159148
minCompetitiveScore = score;
@@ -168,19 +157,45 @@ private void updateMinCompetitiveScore(Scorable scorer) throws IOException {
168157
// pqTop is never null since TopScoreDocCollector fills the priority queue with sentinel
169158
// values if the top element is a sentinel value, its score will be -Infty and the below
170159
// logic is still valid
171-
float localMinScore = Math.nextUp(pqTop.score);
160+
float localMinScore = Math.nextUp(topScore);
172161
if (localMinScore > minCompetitiveScore) {
173162
scorer.setMinCompetitiveScore(localMinScore);
174163
totalHitsRelation = TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO;
175164
minCompetitiveScore = localMinScore;
176165
if (minScoreAcc != null) {
177166
// we don't use the next float but we register the document id so that other leaves or
178167
// leaf partitions can require it if they are after the current maximum
179-
minScoreAcc.accumulate(pqTop.doc, pqTop.score);
168+
minScoreAcc.accumulate(topCode);
180169
}
181170
}
182171
}
183172
}
184173
};
185174
}
175+
176+
@Override
177+
protected int topDocsSize() {
178+
int cnt = 0;
179+
for (int i = 1; i <= heap.size(); i++) {
180+
if (heap.get(i) != DocScoreEncoder.LEAST_COMPETITIVE_CODE) {
181+
cnt++;
182+
}
183+
}
184+
return cnt;
185+
}
186+
187+
@Override
188+
protected void populateResults(ScoreDoc[] results, int howMany) {
189+
for (int i = howMany - 1; i >= 0; i--) {
190+
long encode = heap.pop();
191+
results[i] = new ScoreDoc(DocScoreEncoder.docId(encode), DocScoreEncoder.toScore(encode));
192+
}
193+
}
194+
195+
@Override
196+
protected void pruneLeastCompetitiveHitsTo(int keep) {
197+
for (int i = heap.size() - keep; i > 0; i--) {
198+
heap.pop();
199+
}
200+
}
186201
}

lucene/core/src/java/org/apache/lucene/util/LongHeap.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
*/
1717
package org.apache.lucene.util;
1818

19+
import java.util.Arrays;
20+
1921
/**
2022
* A min heap that stores longs; a primitive priority queue that like all priority queues maintains
2123
* a partial ordering of its elements such that the least element can always be found in constant
@@ -33,6 +35,18 @@ public final class LongHeap {
3335
private long[] heap;
3436
private int size = 0;
3537

38+
/**
39+
* Constructs a heap with specified size and initializes all elements with the given value.
40+
*
41+
* @param size the number of elements to initialize in the heap.
42+
* @param initialValue the value to fill the heap with.
43+
*/
44+
public LongHeap(int size, long initialValue) {
45+
this(size);
46+
Arrays.fill(heap, 1, size + 1, initialValue);
47+
this.size = size;
48+
}
49+
3650
/**
3751
* Create an empty priority queue of the configured initial size.
3852
*
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.lucene.search;
19+
20+
import org.apache.lucene.tests.util.LuceneTestCase;
21+
22+
public class TestDocScoreEncoder extends LuceneTestCase {
23+
24+
public void testRandom() {
25+
for (int i = 0; i < 1000; i++) {
26+
doAssert(
27+
Float.intBitsToFloat(random().nextInt()),
28+
random().nextInt(Integer.MAX_VALUE),
29+
Float.intBitsToFloat(random().nextInt()),
30+
random().nextInt(Integer.MAX_VALUE));
31+
}
32+
}
33+
34+
public void testSameDoc() {
35+
for (int i = 0; i < 1000; i++) {
36+
doAssert(
37+
Float.intBitsToFloat(random().nextInt()), 1, Float.intBitsToFloat(random().nextInt()), 1);
38+
}
39+
}
40+
41+
public void testSameScore() {
42+
for (int i = 0; i < 1000; i++) {
43+
doAssert(1f, random().nextInt(Integer.MAX_VALUE), 1f, random().nextInt(Integer.MAX_VALUE));
44+
}
45+
}
46+
47+
private void doAssert(float score1, int doc1, float score2, int doc2) {
48+
if (Float.isNaN(score1) || Float.isNaN(score2)) {
49+
return;
50+
}
51+
52+
long code1 = DocScoreEncoder.encode(doc1, score1);
53+
long code2 = DocScoreEncoder.encode(doc2, score2);
54+
55+
assertEquals(doc1, DocScoreEncoder.docId(code1));
56+
assertEquals(doc2, DocScoreEncoder.docId(code2));
57+
assertEquals(score1, DocScoreEncoder.toScore(code1), 0f);
58+
assertEquals(score2, DocScoreEncoder.toScore(code2), 0f);
59+
60+
if (score1 < score2) {
61+
assertTrue(code1 < code2);
62+
} else if (score1 > score2) {
63+
assertTrue(code1 > code2);
64+
} else if (doc1 == doc2) {
65+
assertEquals(code1, code2);
66+
} else {
67+
assertEquals(code1 > code2, doc1 < doc2);
68+
}
69+
}
70+
}

0 commit comments

Comments
 (0)