174 changes: 110 additions & 64 deletions src/main/knn/KnnGraphTester.java
@@ -18,12 +18,11 @@
package knn;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.lang.management.ManagementFactory;
import java.lang.management.ThreadMXBean;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;
import java.io.Serializable;
import java.nio.channels.FileChannel;
import java.nio.file.Files;
import java.nio.file.Path;
@@ -33,6 +32,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
@@ -83,6 +83,8 @@
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.ConstantScoreScorer;
import org.apache.lucene.search.ConstantScoreWeight;
import org.apache.lucene.search.DoubleValuesSourceRescorer;
import org.apache.lucene.search.FullPrecisionFloatVectorSimilarityValuesSource;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.KnnByteVectorQuery;
import org.apache.lucene.search.KnnFloatVectorQuery;
@@ -93,7 +95,6 @@
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.join.CheckJoinIndex;
import org.apache.lucene.store.Directory;
@@ -180,6 +181,8 @@ enum IndexType {
private IndexType indexType;
// oversampling: the multiple of topK to gather before checking recall
private float overSample;
// rerank using full precision vectors
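// (when enabled, the oversampled ANN hits are rescored with full-precision similarity and truncated back to topK before recall is computed)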
private boolean rerank;

private KnnGraphTester() {
// set defaults
@@ -203,6 +206,7 @@ private KnnGraphTester() {
queryStartIndex = 0;
indexType = IndexType.HNSW;
overSample = 1f;
rerank = false;
}

private static FileChannel getVectorFileChannel(Path path, int dim, VectorEncoding vectorEncoding, boolean noisy) throws IOException {
@@ -284,6 +288,9 @@ private void run(String... args) throws Exception {
throw new IllegalArgumentException("-overSample must be >= 1");
}
break;
case "-rerank":
rerank = true;
break;
case "-fanout":
if (iarg == args.length - 1) {
throw new IllegalArgumentException("-fanout requires a following number");
@@ -839,12 +846,12 @@ private void printHist(int[] hist, int max, int count, int nbuckets) {
}

@SuppressForbidden(reason = "Prints stuff")
private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Path outputPath, int[][] nn)
private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Path outputPath, ResultIds[][] nn)
throws IOException {
Result[] results = new Result[numQueryVectors];
int[][] resultIds = new int[numQueryVectors][];
ResultIds[][] resultIds = new ResultIds[numQueryVectors][];
long elapsedMS, totalCpuTimeMS, totalVisited = 0;
int topK = (overSample > 1) ? (int) (this.topK * overSample) : this.topK;
int annTopK = (overSample > 1) ? (int) (this.topK * overSample) : this.topK;
int fanout = (overSample > 1) ? (int) (this.fanout * overSample) : this.fanout;
ExecutorService executorService;
if (numSearchThread > 0) {
@@ -860,7 +867,7 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
if (targetReader instanceof VectorReaderByte b) {
targetReaderByte = b;
}
log("searching " + numQueryVectors + " query vectors; topK=" + topK + ", fanout=" + fanout + "\n");
log("searching " + numQueryVectors + " query vectors; ann-topK=" + annTopK + ", fanout=" + fanout + "\n");
long startNS;
try (MMapDirectory dir = new MMapDirectory(indexPath)) {
dir.setPreload((x, ctx) -> x.endsWith(".vec") || x.endsWith(".veq"));
@@ -874,10 +881,10 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
for (int i = 0; i < numQueryVectors; i++) {
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
byte[] target = targetReaderByte.nextBytes();
doKnnByteVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery);
doKnnByteVectorQuery(searcher, KNN_FIELD, target, annTopK, fanout, prefilter, filterQuery);
} else {
float[] target = targetReader.next();
doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery, parentJoin);
doKnnVectorQuery(searcher, KNN_FIELD, target, annTopK, fanout, prefilter, filterQuery, parentJoin);
}
}
targetReader.reset();
@@ -886,10 +893,10 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
for (int i = 0; i < numQueryVectors; i++) {
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
byte[] target = targetReaderByte.nextBytes();
results[i] = doKnnByteVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery);
results[i] = doKnnByteVectorQuery(searcher, KNN_FIELD, target, annTopK, fanout, prefilter, filterQuery);
} else {
float[] target = targetReader.next();
results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, topK, fanout, prefilter, filterQuery, parentJoin);
results[i] = doKnnVectorQuery(searcher, KNN_FIELD, target, annTopK, fanout, prefilter, filterQuery, parentJoin);
}
}
ThreadDetails endThreadDetails = new ThreadDetails();
@@ -930,18 +937,14 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
executorService.shutdown();
}
}
// Do we need to write nn here again? We already wrote it in getExactNN()
if (outputPath != null) {
ByteBuffer tmp =
ByteBuffer.allocate(resultIds[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN);
try (OutputStream out = Files.newOutputStream(outputPath)) {
for (int i = 0; i < numQueryVectors; i++) {
tmp.asIntBuffer().put(nn[i]);
out.write(tmp.array());
}
}
writeExactNN(nn, outputPath);
} else {
log("checking results\n");
float recall = checkResults(resultIds, nn);
float recall = checkRecall(resultIds, nn);
double ndcg10 = calculateNDCG(nn, resultIds, 10);
double ndcgK = calculateNDCG(nn, resultIds, topK);
totalVisited /= numQueryVectors;
String quantizeDesc;
if (quantize) {
Expand All @@ -952,8 +955,11 @@ private void testSearch(Path indexPath, Path queryPath, int queryStartIndex, Pat
double reindexSec = reindexTimeMsec / 1000.0;
System.out.printf(
Locale.ROOT,
"SUMMARY: %5.3f\t%5.3f\t%5.3f\t%5.3f\t%d\t%d\t%d\t%d\t%d\t%s\t%d\t%.2f\t%.2f\t%.2f\t%d\t%.2f\t%.2f\t%s\t%5.3f\t%5.3f\t%5.3f\t%s\n",
"SUMMARY: %5.3f\t%5.3f\t%5.3f\t%s\t%5.3f\t%5.3f\t%5.3f\t%d\t%d\t%d\t%d\t%d\t%s\t%d\t%.2f\t%.2f\t%.2f\t%d\t%.2f\t%.2f\t%s\t%5.3f\t%5.3f\t%5.3f\t%s\n",
Owner:

Uh oh -- nightly benchy will be angry -- could you fix it to expect these new inserted columns? Don't worry about testing it ... I can do that when we merge this.

Collaborator (Author):

How do I do that Mike? Do I need to declare these columns somewhere for nightly benchmarks?

recall,
ndcg10,
ndcgK,
rerank,
elapsedMS / (float) numQueryVectors,
totalCpuTimeMS / (float) numQueryVectors,
totalCpuTimeMS / (float) elapsedMS,
@@ -999,7 +1005,7 @@ private static Result doKnnByteVectorQuery(
return new Result(docs, profiledQuery.totalVectorCount(), 0);
}

private static Result doKnnVectorQuery(
private Result doKnnVectorQuery(
IndexSearcher searcher, String field, float[] vector, int k, int fanout, boolean prefilter, Query filter, boolean isParentJoinQuery)
throws IOException {
if (isParentJoinQuery) {
@@ -1013,35 +1019,78 @@ private static Result doKnnVectorQuery(
.add(filter, BooleanClause.Occur.FILTER)
.build();
TopDocs docs = searcher.search(query, k);
if (rerank) {
FullPrecisionFloatVectorSimilarityValuesSource valuesSource = new FullPrecisionFloatVectorSimilarityValuesSource(vector, field);
DoubleValuesSourceRescorer rescorer = new DoubleValuesSourceRescorer(valuesSource) {
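// combine() keeps the full-precision similarity when it is available, falling back to the first-pass (possibly quantized) score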
@Override
protected float combine(float firstPassScore, boolean valuePresent, double sourceValue) {
return valuePresent ? (float) sourceValue : firstPassScore;
}
};
TopDocs rerankedDocs = rescorer.rescore(searcher, docs, topK);
return new Result(rerankedDocs, profiledQuery.totalVectorCount(), 0);
}
return new Result(docs, profiledQuery.totalVectorCount(), 0);
}

record Result(TopDocs topDocs, long visitedCount, int reentryCount) {
}

private float checkResults(int[][] results, int[][] nn) {
/** Holds ids and scores for corpus docs in search results */
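// Serializable so rows can be cached to the nn file via Java serialization (see writeExactNN / readExactNN below)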
record ResultIds(int id, float score) implements Serializable {}

private float checkRecall(ResultIds[][] results, ResultIds[][] expected) {
int totalMatches = 0;
int totalResults = results.length * topK;
for (int i = 0; i < results.length; i++) {
int totalResults = expected.length * topK;
for (int i = 0; i < expected.length; i++) {
// System.out.println("compare " + Arrays.toString(nn[i]) + " to ");
// System.out.println(Arrays.toString(results[i]));
totalMatches += compareNN(nn[i], results[i]);
totalMatches += compareNN(expected[i], results[i]);
}
return totalMatches / (float) totalResults;
}

private int compareNN(int[] expected, int[] results) {
/**
* Calculates Normalized Discounted Cumulative Gain (NDCG) at K.
*
* <p>We use full precision vector similarity scores for relevance. Since the actual
* knn search results may hold quantized scores, we look up the score for the
* corresponding document id in the {@code ideal} search results. If a document is
* not present in the ideal results, it is considered irrelevant and assigned a score of 0f.
*/
private double calculateNDCG(ResultIds[][] ideal, ResultIds[][] actual, int k) {
double ndcg = 0;
for (int i = 0; i < ideal.length; i++) {
float[] exactResultsRelevance = new float[ideal[i].length];
HashMap<Integer, Float> idToRelevance = new HashMap<Integer, Float>(ideal[i].length);
for (int rank = 0; rank < ideal[i].length; rank++) {
exactResultsRelevance[rank] = ideal[i][rank].score();
idToRelevance.put(ideal[i][rank].id(), ideal[i][rank].score());
}
float[] actualResultsRelevance = new float[actual[i].length];
for (int rank = 0; rank < actual[i].length; rank++) {
actualResultsRelevance[rank] = idToRelevance.getOrDefault(actual[i][rank].id(), 0f);
}
double idealDCG = KnnTesterUtils.dcg(exactResultsRelevance, k);
double actualDCG = KnnTesterUtils.dcg(actualResultsRelevance, k);
ndcg += (actualDCG / idealDCG);
}
ndcg /= ideal.length;
return ndcg;
}
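
The dcg helper used above lives in KnnTesterUtils and is not shown in this diff; as context, here is a minimal sketch of the standard DCG formulation it presumably follows (an assumption, not the committed implementation):

// Hypothetical sketch of KnnTesterUtils.dcg -- the real implementation may differ.
// Standard DCG: each relevance value is discounted by log2(rank + 2), using 0-based ranks.
static double dcg(float[] relevance, int k) {
  double dcg = 0;
  int limit = Math.min(k, relevance.length);
  for (int rank = 0; rank < limit; rank++) {
    dcg += relevance[rank] / (Math.log(rank + 2) / Math.log(2));
  }
  return dcg;
}

With that formulation, calculateNDCG normalizes each query's DCG of the actual (possibly reranked) results by the DCG of the exact results, then averages across queries.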

private int compareNN(ResultIds[] expected, ResultIds[] results) {
int matched = 0;
Set<Integer> expectedSet = new HashSet<>();
Set<Integer> alreadySeen = new HashSet<>();
for (int i = 0; i < topK; i++) {
expectedSet.add(expected[i]);
expectedSet.add(expected[i].id);
}
for (int docId : results) {
if (alreadySeen.add(docId) == false) {
throw new IllegalStateException("duplicate docId=" + docId);
for (ResultIds r : results) {
if (alreadySeen.add(r.id) == false) {
throw new IllegalStateException("duplicate docId=" + r.id);
}
if (expectedSet.contains(docId)) {
if (expectedSet.contains(r.id)) {
++matched;
}
}
@@ -1053,7 +1102,7 @@ private int compareNN(int[] expected, int[] results) {
* The method runs "numQueryVectors" target queries and returns "topK" nearest neighbors
* for each of them. Nearest neighbors are computed by exact (brute-force) search.
*/
private int[][] getExactNN(Path docPath, Path indexPath, Path queryPath, int queryStartIndex) throws IOException, InterruptedException {
private ResultIds[][] getExactNN(Path docPath, Path indexPath, Path queryPath, int queryStartIndex) throws IOException, InterruptedException {
// look in working directory for cached nn file
String hash = Integer.toString(Objects.hash(docPath, indexPath, queryPath, numDocs, numQueryVectors, topK, similarityFunction.ordinal(), parentJoin, queryStartIndex, prefilter ? selectivity : 1f, prefilter ? randomSeed : 0f), 36);
String nnFileName = "nn-" + hash + ".bin";
@@ -1066,7 +1115,7 @@ private int[][] getExactNN(Path docPath, Path indexPath, Path queryPath, int que
long startNS = System.nanoTime();
// TODO: enable computing NN from high precision vectors when
// checking low-precision recall
int[][] nn;
ResultIds[][] nn;
if (vectorEncoding.equals(VectorEncoding.BYTE)) {
nn = computeExactNNByte(queryPath, queryStartIndex);
} else {
@@ -1089,35 +1138,32 @@ private boolean isNewer(Path path, Path... others) throws IOException {
return true;
}

private int[][] readExactNN(Path nnPath) throws IOException {
int[][] result = new int[numQueryVectors][];
try (FileChannel in = FileChannel.open(nnPath)) {
IntBuffer intBuffer =
in.map(FileChannel.MapMode.READ_ONLY, 0, numQueryVectors * topK * Integer.BYTES)
.order(ByteOrder.LITTLE_ENDIAN)
.asIntBuffer();
private ResultIds[][] readExactNN(Path nnPath) throws IOException {
log("reading true nearest neighbors from file \"" + nnPath + "\"\n");
ResultIds[][] nn = new ResultIds[numQueryVectors][];
try (InputStream in = Files.newInputStream(nnPath);
ObjectInputStream ois = new ObjectInputStream(in)) {
for (int i = 0; i < numQueryVectors; i++) {
result[i] = new int[topK];
intBuffer.get(result[i]);
nn[i] = (ResultIds[]) ois.readObject();
}
} catch (ClassNotFoundException e) {
throw new RuntimeException(e);
}
return result;
return nn;
}

private void writeExactNN(int[][] nn, Path nnPath) throws IOException {
log("writing true nearest neighbors to cache file \"" + nnPath + "\"\n");
ByteBuffer tmp =
ByteBuffer.allocate(nn[0].length * Integer.BYTES).order(ByteOrder.LITTLE_ENDIAN);
try (OutputStream out = Files.newOutputStream(nnPath)) {
private void writeExactNN(ResultIds[][] nn, Path nnPath) throws IOException {
log("\nwriting true nearest neighbors to cache file \"" + nnPath + "\"\n");
try (OutputStream fileOutputStream = Files.newOutputStream(nnPath);
ObjectOutputStream objectOutputStream = new ObjectOutputStream(fileOutputStream)) {
for (int i = 0; i < numQueryVectors; i++) {
tmp.asIntBuffer().put(nn[i]);
out.write(tmp.array());
objectOutputStream.writeObject(nn[i]);
}
}
}

private int[][] computeExactNNByte(Path queryPath, int queryStartIndex) throws IOException, InterruptedException {
int[][] result = new int[numQueryVectors][];
private ResultIds[][] computeExactNNByte(Path queryPath, int queryStartIndex) throws IOException, InterruptedException {
ResultIds[][] result = new ResultIds[numQueryVectors][];
log("computing true nearest neighbors of " + numQueryVectors + " target vectors\n");
List<ComputeNNByteTask> tasks = new ArrayList<>();
try (MMapDirectory dir = new MMapDirectory(indexPath)) {
@@ -1143,10 +1189,10 @@ class ComputeNNByteTask implements Callable<Void> {

private final int queryOrd;
private final byte[] query;
private final int[][] result;
private final ResultIds[][] result;
private final IndexReader reader;

ComputeNNByteTask(int queryOrd, byte[] query, int[][] result, IndexReader reader) {
ComputeNNByteTask(int queryOrd, byte[] query, ResultIds[][] result, IndexReader reader) {
this.queryOrd = queryOrd;
this.query = query;
this.result = result;
@@ -1176,9 +1222,9 @@ public Void call() {
}

/** Brute force computation of "true" nearest neighbors. */
private int[][] computeExactNN(Path queryPath, int queryStartIndex)
private ResultIds[][] computeExactNN(Path queryPath, int queryStartIndex)
throws IOException, InterruptedException {
int[][] result = new int[numQueryVectors][];
ResultIds[][] result = new ResultIds[numQueryVectors][];
log("computing true nearest neighbors of " + numQueryVectors + " target vectors\n");
log("parentJoin = %s\n", parentJoin);
try (MMapDirectory dir = new MMapDirectory(indexPath)) {
@@ -1216,10 +1262,10 @@ class ComputeNNFloatTask implements Callable<Void> {

private final int queryOrd;
private final float[] query;
private final int[][] result;
private final ResultIds[][] result;
private final IndexReader reader;

ComputeNNFloatTask(int queryOrd, float[] query, int[][] result, IndexReader reader) {
ComputeNNFloatTask(int queryOrd, float[] query, ResultIds[][] result, IndexReader reader) {
this.queryOrd = queryOrd;
this.query = query;
this.result = result;
@@ -1255,10 +1301,10 @@ class ComputeExactSearchNNFloatTask implements Callable<Void> {

private final int queryOrd;
private final float[] query;
private final int[][] result;
private final ResultIds[][] result;
private final IndexReader reader;

ComputeExactSearchNNFloatTask(int queryOrd, float[] query, int[][] result, IndexReader reader) {
ComputeExactSearchNNFloatTask(int queryOrd, float[] query, ResultIds[][] result, IndexReader reader) {
this.queryOrd = queryOrd;
this.query = query;
this.result = result;