Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions src/main/knn/KnnGraphTester.java
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.FilterDocIdSetIterator;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
Expand Down Expand Up @@ -1426,18 +1427,29 @@ long totalVectorCount() {

private static class BitSetQuery extends Query {
private final BitSet[] segmentDocs;
private final int[] cardinalities;
private final int hash;

BitSetQuery(BitSet[] segmentDocs) {
this.segmentDocs = segmentDocs;
this.cardinalities = new int[segmentDocs.length];
for (int i = 0; i < segmentDocs.length; i++) {
cardinalities[i] = segmentDocs[i].cardinality();
}
this.hash = Arrays.hashCode(segmentDocs);
}

@Override
public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
return new ConstantScoreWeight(this, boost) {
public ScorerSupplier scorerSupplier(LeafReaderContext context) throws IOException {
var bitSet = segmentDocs[context.ord];
var cardinality = bitSet.cardinality();
var scorer = new ConstantScoreScorer(score(), scoreMode, new BitSetIterator(bitSet, cardinality));
var cardinality = cardinalities[context.ord];
var scorer = new ConstantScoreScorer(
score(),
scoreMode,
// wrap it to simulate a more realistic query that must iterate its docs
new FilterDocIdSetIterator(new BitSetIterator(bitSet, cardinality)));
return new ScorerSupplier() {
@Override
public Scorer get(long leadCost) throws IOException {
Expand Down Expand Up @@ -1474,7 +1486,7 @@ public boolean equals(Object other) {

@Override
public int hashCode() {
return 31 * classHash() + Arrays.hashCode(segmentDocs);
return 31 * classHash() + hash;
}
}
}