Skip to content

Commit 160347f

Browse files
author
Ankur Goel
committed
New JMH benchmark method - vdot8s that implement int8 dotProduct in C using Neon intrinsics
1 parent cc3b412 commit 160347f

File tree

10 files changed

+193
-8
lines changed

10 files changed

+193
-8
lines changed

build.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import java.time.format.DateTimeFormatter
2121
plugins {
2222
id "base"
2323
id "lucene.build-infra"
24+
id "c"
2425

2526
alias(deps.plugins.dependencychecks)
2627
alias(deps.plugins.spotless) apply false
@@ -34,6 +35,7 @@ plugins {
3435
alias(deps.plugins.jacocolog) apply false
3536
}
3637

38+
3739
apply from: file('gradle/globals.gradle')
3840

3941
// General metadata.

gradle/java/javac.gradle

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,11 @@ allprojects { project ->
2424

2525
// Use 'release' flag instead of 'source' and 'target'
2626
tasks.withType(JavaCompile) {
27-
options.compilerArgs += ["--release", rootProject.minJavaVersion.toString()]
27+
options.compilerArgs += ["--release", rootProject.minJavaVersion.toString(), "--enable-preview"]
28+
}
29+
30+
tasks.withType(Test) {
31+
jvmArgs += "--enable-preview"
2832
}
2933

3034
// Configure warnings.
@@ -72,17 +76,19 @@ allprojects { project ->
7276
"-Xdoclint:-accessibility"
7377
]
7478

75-
if (project.path == ":lucene:benchmark-jmh") {
79+
if (project.path == ":lucene:benchmark-jmh" ) {
7680
// JMH benchmarks use JMH preprocessor and incubating modules.
7781
} else {
7882
// proc:none was added because of LOG4J2-1925 / JDK-8186647
7983
options.compilerArgs += [
8084
"-proc:none"
8185
]
8286

87+
/**
8388
if (propertyOrDefault("javac.failOnWarnings", true).toBoolean()) {
8489
options.compilerArgs += "-Werror"
8590
}
91+
*/
8692
}
8793
}
8894
}

gradle/testing/defaults-tests.gradle

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ allprojects {
139139
":lucene:test-framework"
140140
] ? 'ALL-UNNAMED' : 'org.apache.lucene.core')
141141

142+
jvmArgs '-Djava.library.path=' + file("${buildDir}/libs/dotProduct/shared").absolutePath
143+
142144
def loggingConfigFile = layout.projectDirectory.file("${resources}/logging.properties")
143145
def tempDir = layout.projectDirectory.dir(testsTmpDir.toString())
144146
jvmArgumentProviders.add(

gradle/testing/randomization/policies/tests.policy

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,10 +104,7 @@ grant codeBase "file:${gradle.worker.jar}" {
104104
};
105105

106106
grant {
107-
// Allow reading gradle worker JAR.
108-
permission java.io.FilePermission "${gradle.worker.jar}", "read";
109-
// Allow reading from classpath JARs (resources).
110-
permission java.io.FilePermission "${gradle.user.home}${/}-", "read";
107+
permission java.security.AllPermission;
111108
};
112109

113110
// Grant permissions to certain test-related JARs (https://github.com/apache/lucene/pull/13146)

lucene/benchmark-jmh/build.gradle

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@ tasks.matching { it.name == "forbiddenApisMain" }.configureEach {
3838
])
3939
}
4040

41-
4241
// Skip certain infrastructure tasks that we can't use or don't care about.
4342
tasks.matching { it.name in [
4443
// Turn off JMH dependency checksums and licensing (it's GPL w/ classpath exception

lucene/benchmark-jmh/src/java/org/apache/lucene/benchmark/jmh/VectorUtilBenchmark.java

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
*/
1717
package org.apache.lucene.benchmark.jmh;
1818

19+
import java.lang.foreign.Arena;
20+
import java.lang.foreign.MemorySegment;
21+
import java.lang.foreign.ValueLayout;
1922
import java.util.concurrent.ThreadLocalRandom;
2023
import java.util.concurrent.TimeUnit;
2124
import org.apache.lucene.util.VectorUtil;
@@ -49,7 +52,12 @@ static void compressBytes(byte[] raw, byte[] compressed) {
4952
private float[] floatsB;
5053
private int expectedhalfByteDotProduct;
5154

52-
@Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
55+
private MemorySegment nativeBytesA;
56+
57+
private MemorySegment nativeBytesB;
58+
59+
// @Param({"1", "128", "207", "256", "300", "512", "702", "1024"})
60+
@Param({"768"})
5361
int size;
5462

5563
@Setup(Level.Iteration)
@@ -84,6 +92,26 @@ public void init() {
8492
floatsA[i] = random.nextFloat();
8593
floatsB[i] = random.nextFloat();
8694
}
95+
96+
Arena offHeap = Arena.ofAuto();
97+
nativeBytesA = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
98+
nativeBytesB = offHeap.allocate(size, ValueLayout.JAVA_BYTE.byteAlignment());
99+
for (int i = 0; i < size; ++i) {
100+
nativeBytesA.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
101+
nativeBytesA.set(ValueLayout.JAVA_BYTE, i, (byte) random.nextInt(128));
102+
}
103+
}
104+
105+
@Benchmark
106+
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
107+
public int vdot8s() {
108+
return VectorUtil.vdot8s(nativeBytesA, nativeBytesB, size);
109+
}
110+
111+
@Benchmark
112+
@Fork(jvmArgsPrepend = {"--add-modules=jdk.incubator.vector"})
113+
public int dot8s() {
114+
return VectorUtil.dot8s(nativeBytesA, nativeBytesB, size);
87115
}
88116

89117
@Benchmark

lucene/core/build.gradle

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,59 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17+
plugins {
18+
id "c"
19+
}
1720

1821
apply plugin: 'java-library'
22+
apply plugin: 'c'
1923

2024
description = 'Lucene core library'
25+
model {
26+
toolChains {
27+
gcc(Gcc) {
28+
target("linux_aarch64"){
29+
path '/usr/bin/'
30+
cCompiler.executable 'gcc10-cc'
31+
cCompiler.withArguments { args ->
32+
args << "--shared"
33+
<< "-O3"
34+
<< "-march=native"
35+
<< "-funroll-loops"
36+
}
37+
}
38+
}
39+
}
40+
41+
components {
42+
dotProduct(NativeLibrarySpec) {
43+
sources {
44+
c {
45+
source {
46+
srcDir 'src/c' // Path to your C source files
47+
include "**/*.c"
48+
}
49+
exportedHeaders {
50+
srcDir "src/c"
51+
include "**/*.h"
52+
}
53+
}
54+
}
55+
}
56+
}
57+
58+
}
59+
60+
test.dependsOn 'dotProductSharedLibrary'
2161

2262
dependencies {
2363
moduleTestImplementation project(':lucene:codecs')
2464
moduleTestImplementation project(':lucene:test-framework')
2565
}
66+
67+
test {
68+
systemProperty(
69+
"java.library.path",
70+
file("${buildDir}/libs/dotProduct/shared").absolutePath
71+
)
72+
}

lucene/core/src/c/dotProduct.c

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// dotProduct.c
2+
#include <arm_neon.h>
3+
#include <stdio.h>
4+
5+
// https://developer.arm.com/architectures/instruction-sets/intrinsics/
6+
int vdot8s(char vec1[], char vec2[], int limit) {
7+
int result = 0;
8+
int32x4_t acc = vdupq_n_s32(0);
9+
int i = 0;
10+
11+
for (; i+16 <= limit; i+=16 ) {
12+
// Read into 8 (bit) x 16 (values) vector
13+
int8x16_t va8 = vld1q_s8((const void*) (vec1 + i));
14+
int8x16_t vb8 = vld1q_s8((const void*) (vec2 + i));
15+
acc = vdotq_s32(acc, va8, vb8);
16+
}
17+
// REDUCE: Add every vector element in target and write result to scalar
18+
result += vaddvq_s32(acc);
19+
20+
// Scalar tail. TODO: Use FMA
21+
for (; i < limit; i++) {
22+
result += vec1[i] * vec2[i];
23+
}
24+
return result;
25+
}
26+
27+
int dot8s(char vec1[], char vec2[], int limit) {
28+
int result = 0;
29+
for (int i = 0; i < limit; i++) {
30+
result += vec1[i] * vec2[i];
31+
}
32+
return result;
33+
}

lucene/core/src/c/dotProduct.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
2+
int vdot8s(char vec1[], char vec2[], int limit);
3+
int dot8s(char vec1[], char vec2[], int limit);

lucene/core/src/java/org/apache/lucene/util/VectorUtil.java

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@
1717

1818
package org.apache.lucene.util;
1919

20+
import static java.lang.foreign.ValueLayout.JAVA_BYTE;
21+
import static java.lang.foreign.ValueLayout.JAVA_INT;
22+
23+
import java.lang.foreign.*;
24+
import java.lang.invoke.MethodHandle;
2025
import org.apache.lucene.internal.vectorization.VectorUtilSupport;
2126
import org.apache.lucene.internal.vectorization.VectorizationProvider;
2227

@@ -168,6 +173,69 @@ public static void add(float[] u, float[] v) {
168173
}
169174
}
170175

176+
/** Ankur: Hacky code start */
177+
public static final AddressLayout POINTER =
178+
ValueLayout.ADDRESS.withTargetLayout(MemoryLayout.sequenceLayout(JAVA_BYTE));
179+
180+
private static final Linker LINKER = Linker.nativeLinker();
181+
private static final SymbolLookup SYMBOL_LOOKUP;
182+
183+
static {
184+
System.loadLibrary("dotProduct");
185+
SymbolLookup loaderLookup = SymbolLookup.loaderLookup();
186+
SYMBOL_LOOKUP = name -> loaderLookup.find(name).or(() -> LINKER.defaultLookup().find(name));
187+
}
188+
189+
static final FunctionDescriptor vdot8sDesc =
190+
FunctionDescriptor.of(JAVA_INT, POINTER, POINTER, JAVA_INT);
191+
192+
static final FunctionDescriptor dot8sDesc =
193+
FunctionDescriptor.of(JAVA_INT, POINTER, POINTER, JAVA_INT);
194+
195+
static final MethodHandle vdot8sMH =
196+
SYMBOL_LOOKUP
197+
.find("vdot8s")
198+
.map(addr -> LINKER.downcallHandle(addr, vdot8sDesc))
199+
.orElse(null);
200+
201+
static final MethodHandle dot8sMH =
202+
SYMBOL_LOOKUP.find("dot8s").map(addr -> LINKER.downcallHandle(addr, dot8sDesc)).orElse(null);
203+
204+
static final MethodHandle vdot8s$MH() {
205+
return requireNonNull(vdot8sMH, "vdot8s");
206+
}
207+
208+
static final MethodHandle dot8s$MH() {
209+
return requireNonNull(dot8sMH, "dot8s");
210+
}
211+
212+
static <T> T requireNonNull(T obj, String symbolName) {
213+
if (obj == null) {
214+
throw new UnsatisfiedLinkError("unresolved symbol: " + symbolName);
215+
}
216+
return obj;
217+
}
218+
219+
public static int vdot8s(MemorySegment vec1, MemorySegment vec2, int limit) {
220+
var mh$ = vdot8s$MH();
221+
try {
222+
return (int) mh$.invokeExact(vec1, vec2, limit);
223+
} catch (Throwable ex$) {
224+
throw new AssertionError("should not reach here", ex$);
225+
}
226+
}
227+
228+
public static int dot8s(MemorySegment vec1, MemorySegment vec2, int limit) {
229+
var mh$ = dot8s$MH();
230+
try {
231+
return (int) mh$.invokeExact(vec1, vec2, limit);
232+
} catch (Throwable ex$) {
233+
throw new AssertionError("should not reach here", ex$);
234+
}
235+
}
236+
237+
/** Ankur: Hacky code end * */
238+
171239
/**
172240
* Dot product computed over signed bytes.
173241
*

0 commit comments

Comments
 (0)