Skip to content

Commit 3079c5d

Browse files
author
Ankur Goel
committed
Allocate offHeap memory in dotProduct(byte[], byte[]) for unit tests if native dot-product is enabled. Simplify JMH benchmark code that tests native dot product. Incorporate other review feedback.
1 parent 4579dea commit 3079c5d

File tree

13 files changed

+325
-264
lines changed

13 files changed

+325
-264
lines changed

lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java

Lines changed: 93 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
import java.lang.invoke.MethodType;
2222
import java.util.concurrent.ThreadLocalRandom;
2323
import java.util.concurrent.TimeUnit;
24+
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
25+
import org.apache.lucene.internal.vectorization.VectorizationProvider;
2426
import org.apache.lucene.util.VectorUtil;
2527
import org.openjdk.jmh.annotations.*;
2628

@@ -36,6 +38,81 @@
3638
value = 3,
3739
jvmArgsAppend = {"-Xmx2g", "-Xms2g", "-XX:+AlwaysPreTouch"})
3840
public class VectorUtilBenchmark {
41+
42+
/**
43+
* Used to get a MethodHandle of PanamaVectorUtilSupport.dotProduct(MemorySegment a, MemorySegment
44+
* b). The method above will use a native C implementation of dotProduct if it is enabled via
45+
* {@link org.apache.lucene.util.Constants#NATIVE_DOT_PRODUCT_ENABLED} AND both MemorySegment
46+
* arguments are backed by off-heap memory. A reflection based approach is necessary to avoid
47+
* taking a direct dependency on preview APIs in Panama which may be blocked at compile time.
48+
*
49+
* @return MethodHandle for the static PanamaVectorUtilSupport method named {@code methodName} taking (MemorySegment a, MemorySegment b), or null if unavailable
50+
*/
51+
private static MethodHandle nativeDotProductHandle(String methodName) {
52+
if (Runtime.version().feature() < 21) {
53+
return null;
54+
}
55+
try {
56+
final VectorUtilSupport vectorUtilSupport =
57+
VectorizationProvider.getInstance().getVectorUtilSupport();
58+
if (vectorUtilSupport.getClass().getName().endsWith("PanamaVectorUtilSupport")) {
59+
MethodHandles.Lookup lookup = MethodHandles.lookup();
60+
// A method type that computes dot-product between two off-heap vectors
61+
// provided as native MemorySegment and returns an int score.
62+
final var MemorySegment = "java.lang.foreign.MemorySegment";
63+
final var methodType =
64+
MethodType.methodType(
65+
int.class, lookup.findClass(MemorySegment), lookup.findClass(MemorySegment));
66+
var mh = lookup.findStatic(vectorUtilSupport.getClass(), methodName, methodType);
67+
// Erase both parameter types to Object so that mh.invokeExact(a, b) does not throw
68+
// WrongMethodTypeException.
69+
// Here 'a' and 'b' are off-heap vectors of type MemorySegment constructed via reflection
70+
// API.
71+
// This minimizes the reflection overhead and brings us very close to the performance of
72+
// direct method invocation.
73+
mh = mh.asType(mh.type().changeParameterType(0, Object.class));
74+
mh = mh.asType(mh.type().changeParameterType(1, Object.class));
75+
return mh;
76+
}
77+
} catch (Throwable e) {
78+
throw new RuntimeException(e);
79+
}
80+
return null;
81+
}
82+
83+
/**
84+
* Get randomly initialized byte-vectors of given size in off-heap MemorySegment
85+
*
86+
* @param size dimension of byte-vector
87+
* @return Object MemorySegment
88+
*/
89+
private static Object getOffHeapByteVector(int size) {
90+
try {
91+
VectorizationProvider vectorizationProvider = VectorizationProvider.getInstance();
92+
if (vectorizationProvider.getClass().getName().endsWith("PanamaVectorizationProvider")) {
93+
MethodHandles.Lookup lookup = MethodHandles.lookup();
94+
// A method type that accepts numBytes and returns an off-heap vector of size 'numBytes'
95+
// where each byte is randomly initialized
96+
final var methodType =
97+
MethodType.methodType(lookup.findClass("java.lang.foreign.MemorySegment"), int.class);
98+
// The class is expected to be "PanamaVectorUtilSupport" with a static method
99+
// "MemorySegment offHeapByteVector(int numBytes)" that returns the off-heap vector as a
100+
// MemorySegment
101+
Class<?> vectorUtilSupportClass = vectorizationProvider.getVectorUtilSupport().getClass();
102+
final MethodHandle offHeapByteVector =
103+
lookup.findStatic(vectorUtilSupportClass, "offHeapByteVector", methodType);
104+
return offHeapByteVector.invoke(size);
105+
}
106+
} catch (Throwable e) {
107+
throw new RuntimeException(e);
108+
}
109+
return null;
110+
}
111+
112+
private static final MethodHandle NATIVE_DOT_PRODUCT = nativeDotProductHandle("dotProduct");
113+
private static final MethodHandle SIMPLE_NATIVE_DOT_PRODUCT =
114+
nativeDotProductHandle("simpleNativeDotProduct");
115+
39116
static void compressBytes(byte[] raw, byte[] compressed) {
40117
for (int i = 0; i < compressed.length; ++i) {
41118
int v = (raw[i] << 4) | raw[compressed.length + i];
@@ -52,8 +129,8 @@ static void compressBytes(byte[] raw, byte[] compressed) {
52129
private float[] floatsB;
53130
private int expectedhalfByteDotProduct;
54131

55-
private Object nativeBytesA;
56-
private Object nativeBytesB;
132+
private Object offHeapBytesA;
133+
private Object offHeapBytesB;
57134

58135
/** private Object nativeBytesA; private Object nativeBytesB; */
59136
@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
@@ -94,70 +171,26 @@ public void init() {
94171
// Java 21+ specific initialization
95172
final int runtimeVersion = Runtime.version().feature();
96173
if (runtimeVersion >= 21) {
97-
// Reflection based code to eliminate the use of Preview classes in JMH benchmarks
98-
try {
99-
final Class<?> vectorUtilSupportClass = VectorUtil.getVectorUtilSupportClass();
100-
final var className = "org.apache.lucene.internal.vectorization.PanamaVectorUtilSupport";
101-
if (vectorUtilSupportClass.getName().equals(className) == false) {
102-
nativeBytesA = null;
103-
nativeBytesB = null;
104-
} else {
105-
MethodHandles.Lookup lookup = MethodHandles.lookup();
106-
final var MemorySegment = "java.lang.foreign.MemorySegment";
107-
final var methodType =
108-
MethodType.methodType(lookup.findClass(MemorySegment), byte[].class);
109-
MethodHandle nativeMemorySegment =
110-
lookup.findStatic(vectorUtilSupportClass, "nativeMemorySegment", methodType);
111-
byte[] a = new byte[size];
112-
byte[] b = new byte[size];
113-
for (int i = 0; i < size; ++i) {
114-
a[i] = (byte) random.nextInt(128);
115-
b[i] = (byte) random.nextInt(128);
116-
}
117-
nativeBytesA = nativeMemorySegment.invoke(a);
118-
nativeBytesB = nativeMemorySegment.invoke(b);
119-
}
120-
} catch (Throwable e) {
121-
throw new RuntimeException(e);
122-
}
123-
/*
124-
Arena offHeap = Arena.ofAuto();
125-
nativeBytesA = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
126-
nativeBytesB = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
127-
for (int i = 0; i < size; ++i) {
128-
nativeBytesA.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
129-
nativeBytesB.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
130-
}*/
174+
offHeapBytesA = getOffHeapByteVector(size);
175+
offHeapBytesB = getOffHeapByteVector(size);
176+
}
177+
}
178+
179+
@Benchmark
180+
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
181+
public int dot8s() {
182+
try {
183+
return (int) NATIVE_DOT_PRODUCT.invokeExact(offHeapBytesA, offHeapBytesB);
184+
} catch (Throwable e) {
185+
throw new RuntimeException(e);
131186
}
132187
}
133188

134-
/**
135-
* High overhead (lower score) from using NATIVE_DOT_PRODUCT.invoke(nativeBytesA, nativeBytesB).
136-
* Both nativeBytesA and nativeBytesB are offHeap MemorySegments created by invoking the method
137-
* PanamaVectorUtilSupport.nativeMemorySegment(byte[]) which allocated these segments and copies
138-
* bytes from the supplied byte[] to offHeap memory. The benchmark output below shows
139-
* significantly more overhead. <b>NOTE:</b> Return type of dots8s() was set to void for the
140-
* benchmark run to avoid boxing/unboxing overhead.
141-
*
142-
* <pre>
143-
* Benchmark (size) Mode Cnt Score Error Units
144-
* VectorUtilBenchmark.dot8s 768 thrpt 15 36.406 ± 0.496 ops/us
145-
* </pre>
146-
*
147-
* Much lower overhead was observed when preview APIs were used directly in JMH benchmarking code
148-
* and exact method invocation was made as shown below <b>return (int)
149-
* VectorUtil.NATIVE_DOT_PRODUCT.invokeExact(nativeBytesA, nativeBytesB);</b>
150-
*
151-
* <pre>
152-
* Benchmark (size) Mode Cnt Score Error Units
153-
* VectorUtilBenchmark.dot8s 768 thrpt 15 43.662 ± 0.818 ops/us
154-
* </pre>
155-
*/
156189
@Benchmark
157190
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
158-
public void dot8s() {
191+
public int simpleDot8s() {
159192
try {
160-
VectorUtil.NATIVE_DOT_PRODUCT.invoke(nativeBytesA, nativeBytesB);
193+
return (int) SIMPLE_NATIVE_DOT_PRODUCT.invokeExact(offHeapBytesA, offHeapBytesB);
161194
} catch (Throwable e) {
162195
throw new RuntimeException(e);
163196
}

lucene/core/src/java/module-info.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@
6969

7070
exports org.apache.lucene.util.quantization;
7171
exports org.apache.lucene.codecs.hnsw;
72+
exports org.apache.lucene.internal.vectorization to
73+
org.apache.lucene.benchmark.jmh;
7274

7375
provides org.apache.lucene.analysis.TokenizerFactory with
7476
org.apache.lucene.analysis.standard.StandardTokenizerFactory;

lucene/core/src/java/org/apache/lucene/codecs/lucene99/OffHeapQuantizedByteVectorValues.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,6 @@ public float getScoreCorrectionConstant(int targetOrd) throws IOException {
146146
}
147147
slice.seek(((long) targetOrd * byteSize) + numBytes);
148148
slice.readFloats(scoreCorrectionConstant, 0, 1);
149-
lastOrd = targetOrd;
150149
return scoreCorrectionConstant[0];
151150
}
152151

lucene/core/src/java/org/apache/lucene/internal/vectorization/VectorizationProvider.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,7 @@ private static Optional<Module> lookupVectorModule() {
188188
// add all possible callers here as FQCN:
189189
private static final Set<String> VALID_CALLERS =
190190
Set.of(
191+
"org.apache.lucene.benchmark.jmh.VectorUtilBenchmark",
191192
"org.apache.lucene.codecs.hnsw.FlatVectorScorerUtil",
192193
"org.apache.lucene.util.VectorUtil",
193194
"org.apache.lucene.codecs.lucene912.Lucene912PostingsReader",

lucene/core/src/java/org/apache/lucene/util/Constants.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,19 @@ private static boolean is64Bit() {
104104
// "False")
105105
public static final boolean NATIVE_DOT_PRODUCT_ENABLED = OS_ARCH.equalsIgnoreCase("aarch64");
106106

107+
public static final boolean TEST_NATIVE_DOT_PRODUCT;
108+
109+
static {
110+
String v = System.getProperty("test.native.dotProduct", "false");
111+
v = v.trim();
112+
if (v.isEmpty() == false) {
113+
TEST_NATIVE_DOT_PRODUCT = Boolean.parseBoolean(v);
114+
} else {
115+
throw new IllegalArgumentException(
116+
"Boolean value expected for property - test.native.dotProduct");
117+
}
118+
}
119+
107120
/** true iff we know FMA has faster throughput than separate mul/add. */
108121
public static final boolean HAS_FAST_SCALAR_FMA = hasFastScalarFMA();
109122

lucene/core/src/java/org/apache/lucene/util/VectorUtil.java

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,6 @@
1717

1818
package org.apache.lucene.util;
1919

20-
import java.lang.invoke.MethodHandle;
21-
import java.lang.invoke.MethodHandles;
22-
import java.lang.invoke.MethodType;
2320
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
2421
import org.apache.lucene.internal.vectorization.VectorizationProvider;
2522

@@ -55,39 +52,8 @@ public final class VectorUtil {
5552
public static final VectorUtilSupport IMPL =
5653
VectorizationProvider.getInstance().getVectorUtilSupport();
5754

58-
// TODO: Harden this implementation and may be find a new home for this
59-
private static MethodHandle nativeDotProduct() {
60-
try {
61-
final var PanamaVectorUtilSupport =
62-
"org.apache.lucene.internal.vectorization.PanamaVectorUtilSupport";
63-
if (!IMPL.getClass().getName().equals(PanamaVectorUtilSupport)) {
64-
return null;
65-
}
66-
MethodHandles.Lookup lookup = MethodHandles.lookup();
67-
final var MemorySegment = "java.lang.foreign.MemorySegment";
68-
final var methodType =
69-
MethodType.methodType(
70-
int.class, lookup.findClass(MemorySegment), lookup.findClass(MemorySegment));
71-
return lookup.findStatic(IMPL.getClass(), "nativeDotProduct", methodType);
72-
} catch (Exception e) {
73-
throw new RuntimeException(e);
74-
}
75-
}
76-
77-
// For use in JMH benchmark
78-
public static final MethodHandle NATIVE_DOT_PRODUCT = nativeDotProduct();
79-
8055
private VectorUtil() {}
8156

82-
/*
83-
Used in o.a.l.benchmark.jmh.VectorUtilBenchmark to create test vectors
84-
in off-heap MemorySegments IF VectorUtilSupport instance supports
85-
Panama APIs.
86-
*/
87-
public static Class<?> getVectorUtilSupportClass() {
88-
return IMPL.getClass();
89-
}
90-
9157
/**
9258
* Returns the vector dot product of the two vectors.
9359
*

0 commit comments

Comments
 (0)