Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/run-checks-all.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ jobs:
matrix:
os: [ ubuntu-latest ]
java: [ '21' ]
compiler: [ gcc ]

runs-on: ${{ matrix.os }}

Expand All @@ -38,6 +39,8 @@ jobs:
- uses: ./.github/actions/prepare-for-build

- name: Run gradle check (without tests)
env:
CC: ${{ matrix.compiler }}
run: ./gradlew check -x test -Ptask.times=true --max-workers 2


Expand All @@ -53,6 +56,7 @@ jobs:
# macos-latest: a tad slower than ubuntu and pretty much the same (?) so leaving out.
os: [ ubuntu-latest ]
java: [ '21' ]
compiler: [ gcc ]

runs-on: ${{ matrix.os }}

Expand All @@ -61,6 +65,8 @@ jobs:
- uses: ./.github/actions/prepare-for-build

- name: Run gradle tests
env:
CC: ${{ matrix.compiler }}
run: ./gradlew test "-Ptask.times=true" --max-workers 2

- name: List automatically-initialized gradle.properties
Expand Down
2 changes: 2 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.time.format.DateTimeFormatter
plugins {
id "base"
id "lucene.build-infra"
id "c"

alias(deps.plugins.dependencychecks)
alias(deps.plugins.spotless) apply false
Expand All @@ -34,6 +35,7 @@ plugins {
alias(deps.plugins.jacocolog) apply false
}


apply from: file('gradle/globals.gradle')

// General metadata.
Expand Down
3 changes: 3 additions & 0 deletions gradle/testing/randomization.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ allprojects {
[propName: 'tests.forceintegervectors',
value: { -> testsDefaultVectorizationRequested() ? false : (randomVectorSize != 'default') },
description: "Forces use of integer vectors even when slow."],
// Exercise the native dot-product only when the runtime Java version is one of the vector-incubator
// versions and randomized testing picked the 'default' vector size. Always off when the default
// vectorization setup is requested.
[propName: 'test.native.dotProduct',
 value: { -> testsDefaultVectorizationRequested() ? false : (randomVectorSize == 'default' && rootProject.vectorIncubatorJavaVersions.contains(rootProject.runtimeJavaVersion)) },
 description: "Enables the native dot-product implementation when running tests."],
[propName: 'tests.defaultvectorization', value: false,
description: "Uses defaults for running tests with correct JVM settings to test Panama vectorization (tests.jvmargs, tests.vectorsize, tests.forceintegervectors)."],
]
Expand Down
3 changes: 3 additions & 0 deletions gradle/testing/randomization/policies/tests.policy
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ grant {
// Needed for DirectIODirectory to retrieve block size
permission java.lang.RuntimePermission "getFileStoreAttributes";

// Needed to load native library containing optimized dot product implementation
permission java.lang.RuntimePermission "loadLibrary.dotProduct";

// TestLockFactoriesMultiJVM opens a random port on 127.0.0.1 (port 0 = ephemeral port range):
permission java.net.SocketPermission "127.0.0.1:0", "accept,listen,resolve";
// Replicator tests connect to ephemeral ports
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,13 @@
*/
package org.apache.lucene.benchmark.jmh;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.TimeUnit;
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider;
import org.apache.lucene.util.VectorUtil;
import org.openjdk.jmh.annotations.*;

Expand All @@ -33,6 +38,82 @@
value = 3,
jvmArgsAppend = {"-Xmx2g", "-Xms2g", "-XX:+AlwaysPreTouch"})
public class VectorUtilBenchmark {

/**
 * Resolves a static method {@code int <methodName>(MemorySegment a, MemorySegment b)} declared on
 * the class of the active {@link VectorUtilSupport} and returns it as a {@link MethodHandle}.
 *
 * <p>Per the design of this benchmark, {@code PanamaVectorUtilSupport.dotProduct(MemorySegment,
 * MemorySegment)} uses a native C implementation of dotProduct if it is enabled via {@link
 * org.apache.lucene.util.Constants#NATIVE_DOT_PRODUCT_ENABLED} AND both MemorySegment arguments
 * are backed by off-heap memory. A reflection based approach is necessary to avoid taking a direct
 * dependency on preview APIs in Panama which may be blocked at compile time.
 *
 * @param methodName name of the static method to look up on the support class
 * @return a handle of type {@code (Object, Object) -> int}, or {@code null} when the runtime is
 *     older than Java 21 or the active support class is not {@code PanamaVectorUtilSupport}
 * @throws RuntimeException wrapping the reflective failure when the class or method cannot be
 *     found or accessed
 */
private static MethodHandle nativeDotProductHandle(String methodName) {
  if (Runtime.version().feature() < 21) {
    return null;
  }
  try {
    final VectorUtilSupport vectorUtilSupport =
        VectorizationProvider.getInstance().getVectorUtilSupport();
    if (vectorUtilSupport.getClass().getName().endsWith("PanamaVectorUtilSupport") == false) {
      return null;
    }
    final MethodHandles.Lookup lookup = MethodHandles.lookup();
    // MemorySegment is resolved by name (once) so that this class never references it directly.
    final Class<?> memorySegmentClass = lookup.findClass("java.lang.foreign.MemorySegment");
    // A method type that computes dot-product between two off-heap vectors provided as native
    // MemorySegment and returns an int score.
    final MethodType segmentType =
        MethodType.methodType(int.class, memorySegmentClass, memorySegmentClass);
    final MethodHandle target =
        lookup.findStatic(vectorUtilSupport.getClass(), methodName, segmentType);
    // Erase both parameter types to Object so that mh.invokeExact(a, b), called with
    // Object-typed arguments, does not throw WrongMethodTypeException. This keeps the reflection
    // overhead minimal and close to the cost of a direct method invocation.
    return target.asType(
        target
            .type()
            .changeParameterType(0, Object.class)
            .changeParameterType(1, Object.class));
  } catch (ClassNotFoundException | IllegalAccessException | NoSuchMethodException e) {
    throw new RuntimeException(e);
  }
}

/**
 * Copies the given bytes off-heap by reflectively invoking the static method {@code
 * offHeapByteVector(byte[])} on the class of the active {@code VectorUtilSupport}.
 *
 * @param byteVector to be copied off-heap
 * @return the object returned by {@code offHeapByteVector} (looked up with a {@code
 *     MemorySegment} return type), or {@code null} when the active provider is not a {@code
 *     PanamaVectorizationProvider}
 */
private static Object getOffHeapByteVector(byte[] byteVector) {
  try {
    final VectorizationProvider provider = VectorizationProvider.getInstance();
    final String providerName = provider.getClass().getName();
    if (providerName.endsWith("PanamaVectorizationProvider") == false) {
      return null;
    }
    final MethodHandles.Lookup lookup = MethodHandles.lookup();
    // Signature of the copier: MemorySegment offHeapByteVector(byte[] byteVector)
    final Class<?> segmentClass = lookup.findClass("java.lang.foreign.MemorySegment");
    final MethodType copierType = MethodType.methodType(segmentClass, byte[].class);
    // The support class is expected to be "PanamaVectorUtilSupport", which declares the static
    // copier method.
    final Class<?> supportClass = provider.getVectorUtilSupport().getClass();
    final MethodHandle copier = lookup.findStatic(supportClass, "offHeapByteVector", copierType);
    return copier.invoke(byteVector);
  } catch (Throwable e) {
    throw new RuntimeException(e);
  }
}

// Handles resolved once at class initialization. Either one is null when
// nativeDotProductHandle finds a runtime older than Java 21 or a support class that is not
// PanamaVectorUtilSupport.
private static final MethodHandle NATIVE_DOT_PRODUCT = nativeDotProductHandle("dotProduct");
private static final MethodHandle SIMPLE_NATIVE_DOT_PRODUCT =
nativeDotProductHandle("simpleNativeDotProduct");

static void compressBytes(byte[] raw, byte[] compressed) {
for (int i = 0; i < compressed.length; ++i) {
int v = (raw[i] << 4) | raw[compressed.length + i];
Expand All @@ -49,6 +130,10 @@ static void compressBytes(byte[] raw, byte[] compressed) {
private float[] floatsB;
private int expectedhalfByteDotProduct;

// Results of getOffHeapByteVector(bytesA) / getOffHeapByteVector(bytesB); assigned in init() only
// when running on Java 21 or newer. Declared as Object so that this class does not reference
// MemorySegment directly.
private Object offHeapBytesA;
private Object offHeapBytesB;

/** Input size for the benchmark; JMH runs each benchmark once per value listed in {@code @Param}. */
@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
int size;

Expand Down Expand Up @@ -84,6 +169,32 @@ public void init() {
floatsA[i] = random.nextFloat();
floatsB[i] = random.nextFloat();
}
// Java 21+ specific initialization
final int runtimeVersion = Runtime.version().feature();
if (runtimeVersion >= 21) {
offHeapBytesA = getOffHeapByteVector(bytesA);
offHeapBytesB = getOffHeapByteVector(bytesB);
}
}

/**
 * Benchmarks the {@code dotProduct} method handle on the off-heap copies of the byte vectors.
 *
 * @return the int score returned by the handle
 * @throws IllegalStateException if the handle could not be resolved at class initialization
 * @throws RuntimeException wrapping anything thrown by the invocation itself
 */
@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int dot8s() {
  // The handle is null on Java < 21 or without the Panama support class; fail with a clear
  // message instead of an opaque NullPointerException.
  if (NATIVE_DOT_PRODUCT == null) {
    throw new IllegalStateException(
        "dotProduct handle is unavailable: requires Java 21+ and PanamaVectorUtilSupport");
  }
  try {
    // invokeExact: the argument and return types at this call site must match the handle's
    // (Object, Object) -> int type exactly.
    return (int) NATIVE_DOT_PRODUCT.invokeExact(offHeapBytesA, offHeapBytesB);
  } catch (Throwable e) {
    throw new RuntimeException(e);
  }
}

/**
 * Benchmarks the {@code simpleNativeDotProduct} method handle on the off-heap copies of the byte
 * vectors.
 *
 * @return the int score returned by the handle
 * @throws IllegalStateException if the handle could not be resolved at class initialization
 * @throws RuntimeException wrapping anything thrown by the invocation itself
 */
@Benchmark
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
public int simpleDot8s() {
  // The handle is null on Java < 21 or without the Panama support class; fail with a clear
  // message instead of an opaque NullPointerException.
  if (SIMPLE_NATIVE_DOT_PRODUCT == null) {
    throw new IllegalStateException(
        "simpleNativeDotProduct handle is unavailable: requires Java 21+ and"
            + " PanamaVectorUtilSupport");
  }
  try {
    // invokeExact: the argument and return types at this call site must match the handle's
    // (Object, Object) -> int type exactly.
    return (int) SIMPLE_NATIVE_DOT_PRODUCT.invokeExact(offHeapBytesA, offHeapBytesB);
  } catch (Throwable e) {
    throw new RuntimeException(e);
  }
}

@Benchmark
Expand Down
8 changes: 8 additions & 0 deletions lucene/core/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -23,3 +23,11 @@ dependencies {
moduleTestImplementation project(':lucene:codecs')
moduleTestImplementation project(':lucene:test-framework')
}

test {
  // Run the :lucene:misc task that produces the dotProduct shared library before the tests.
  dependsOn ':lucene:misc:dotProductSharedLibrary'

  // Point the test JVM's native library search path at the expected output directory of that
  // task. The directory is resolved through the layout API rather than by string concatenation
  // so that the platform's path separators are used.
  systemProperty(
      "java.library.path",
      project(":lucene:misc").layout.buildDirectory.dir("libs/dotProduct/shared").get().asFile.absolutePath
  )
}
2 changes: 2 additions & 0 deletions lucene/core/src/java/module-info.java
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@

exports org.apache.lucene.util.quantization;
exports org.apache.lucene.codecs.hnsw;
exports org.apache.lucene.internal.vectorization to
org.apache.lucene.benchmark.jmh;

provides org.apache.lucene.analysis.TokenizerFactory with
org.apache.lucene.analysis.standard.StandardTokenizerFactory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
import org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil;
import org.apache.lucene.codecs.hnsw.FlatVectorsFormat;
import org.apache.lucene.codecs.hnsw.FlatVectorsReader;
import org.apache.lucene.codecs.hnsw.FlatVectorsScorer;
import org.apache.lucene.codecs.hnsw.FlatVectorsWriter;
import org.apache.lucene.index.SegmentReadState;
import org.apache.lucene.index.SegmentWriteState;
import org.apache.lucene.util.Constants;

/**
* Format supporting vector quantization, storage, and retrieval
Expand Down Expand Up @@ -70,7 +72,8 @@ public class Lucene99ScalarQuantizedVectorsFormat extends FlatVectorsFormat {

final byte bits;
final boolean compress;
final Lucene99ScalarQuantizedVectorScorer flatVectorScorer;
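// Declared as the FlatVectorsScorer interface because the constructor may assign either a
// Lucene99ScalarQuantizedVectorScorer or the scorer returned by FlatVectorScorerUtil.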
final FlatVectorsScorer flatVectorScorer;

/** Constructs a format using default graph construction parameters */
public Lucene99ScalarQuantizedVectorsFormat() {
Expand Down Expand Up @@ -117,8 +120,16 @@ public Lucene99ScalarQuantizedVectorsFormat(
this.bits = (byte) bits;
this.confidenceInterval = confidenceInterval;
this.compress = compress;
this.flatVectorScorer =
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
if (Constants.NATIVE_DOT_PRODUCT_ENABLED == false) {
this.flatVectorScorer =
new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
} else {
FlatVectorsScorer scorer = FlatVectorScorerUtil.getLucene99FlatVectorsScorer();
if (scorer == DefaultFlatVectorScorer.INSTANCE) {
scorer = new Lucene99ScalarQuantizedVectorScorer(DefaultFlatVectorScorer.INSTANCE);
}
this.flatVectorScorer = scorer;
}
}

public static float calculateDefaultConfidenceInterval(int vectorDimension) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ private static Optional<Module> lookupVectorModule() {
// add all possible callers here as FQCN:
private static final Set<String> VALID_CALLERS =
Set.of(
"org.apache.lucene.benchmark.jmh.VectorUtilBenchmark",
"org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil",
"org.apache.lucene.util.VectorUtil",
"org.apache.lucene.codecs.lucene101.Lucene101PostingsReader",
Expand Down
9 changes: 9 additions & 0 deletions lucene/core/src/java/org/apache/lucene/util/Constants.java
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,15 @@ private static boolean is64Bit() {
/** true iff we know VFMA has faster throughput than separate vmul/vadd. */
public static final boolean HAS_FAST_VECTOR_FMA = hasFastVectorFMA();

/**
 * true iff the native dot-product was requested: either the {@code lucene.useNativeDotProduct}
 * system property is true on an {@code aarch64} architecture, or the {@code
 * test.native.dotProduct} system property is true (on any architecture).
 */
public static final boolean NATIVE_DOT_PRODUCT_ENABLED = enableNativeDotProduct();

/**
 * Computes the value of {@code NATIVE_DOT_PRODUCT_ENABLED} from the OS architecture and two
 * system properties.
 *
 * @return true when {@code test.native.dotProduct} is true, or when {@code
 *     lucene.useNativeDotProduct} is true and the architecture is {@code aarch64}
 */
private static boolean enableNativeDotProduct() {
  final boolean onAarch64 = OS_ARCH.equalsIgnoreCase("aarch64");
  final String explicitFlag = getSysProp("lucene.useNativeDotProduct", "false");
  final String testFlag = getSysProp("test.native.dotProduct", "false");
  // The test switch applies regardless of architecture.
  if (Boolean.parseBoolean(testFlag)) {
    return true;
  }
  return onAarch64 && Boolean.parseBoolean(explicitFlag);
}

/** true iff we know FMA has faster throughput than separate mul/add. */
public static final boolean HAS_FAST_SCALAR_FMA = hasFastScalarFMA();

Expand Down
16 changes: 7 additions & 9 deletions lucene/core/src/java/org/apache/lucene/util/VectorUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

package org.apache.lucene.util;

import java.util.stream.IntStream;
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
import org.apache.lucene.internal.vectorization.VectorizationProvider;

Expand Down Expand Up @@ -50,7 +49,7 @@ public final class VectorUtil {

private static final float EPSILON = 1e-4f;

private static final VectorUtilSupport IMPL =
public static final VectorUtilSupport IMPL =
VectorizationProvider.getInstance().getVectorUtilSupport();

private VectorUtil() {}
Expand Down Expand Up @@ -310,13 +309,12 @@ public static float[] checkFinite(float[] v) {
}

/**
* Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code to}
* exclusive, find the first array index whose value is greater than or equal to {@code target}.
* This index is guaranteed to be at least {@code from}. If there is no such array index, {@code
* to} is returned.
* Given an array {@code buffer} that is sorted between indexes {@code 0} inclusive and {@code
* length} exclusive, find the first array index whose value is greater than or equal to {@code
* target}. This index is guaranteed to be at least {@code from}. If there is no such array index,
* {@code length} is returned.
*/
public static int findNextGEQ(int[] buffer, int target, int from, int to) {
assert IntStream.range(0, to - 1).noneMatch(i -> buffer[i] > buffer[i + 1]);
return IMPL.findNextGEQ(buffer, target, from, to);
public static int findNextGEQ(int[] buffer, int length, int target, int from) {
// Pure delegation: the search is performed by the VectorUtilSupport instance held in IMPL.
return IMPL.findNextGEQ(buffer, length, target, from);
}
}
Loading