|
17 | 17 |
|
18 | 18 | package org.apache.lucene.util;
|
19 | 19 |
|
| 20 | +import static java.lang.foreign.ValueLayout.JAVA_BYTE; |
| 21 | +import static java.lang.foreign.ValueLayout.JAVA_INT; |
| 22 | + |
| 23 | +import java.lang.foreign.*; |
| 24 | +import java.lang.invoke.MethodHandle; |
20 | 25 | import org.apache.lucene.internal.vectorization.VectorUtilSupport;
|
21 | 26 | import org.apache.lucene.internal.vectorization.VectorizationProvider;
|
22 | 27 |
|
@@ -168,6 +173,69 @@ public static void add(float[] u, float[] v) {
|
168 | 173 | }
|
169 | 174 | }
|
170 | 175 |
|
| 176 | + /** Ankur: Hacky code start */ |
| 177 | + public static final AddressLayout POINTER = |
| 178 | + ValueLayout.ADDRESS.withTargetLayout(MemoryLayout.sequenceLayout(JAVA_BYTE)); |
| 179 | + |
| 180 | + private static final Linker LINKER = Linker.nativeLinker(); |
| 181 | + private static final SymbolLookup SYMBOL_LOOKUP; |
| 182 | + |
| 183 | + static { |
| 184 | + System.loadLibrary("dotProduct"); |
| 185 | + SymbolLookup loaderLookup = SymbolLookup.loaderLookup(); |
| 186 | + SYMBOL_LOOKUP = name -> loaderLookup.find(name).or(() -> LINKER.defaultLookup().find(name)); |
| 187 | + } |
| 188 | + |
| 189 | + static final FunctionDescriptor vdot8sDesc = |
| 190 | + FunctionDescriptor.of(JAVA_INT, POINTER, POINTER, JAVA_INT); |
| 191 | + |
| 192 | + static final FunctionDescriptor dot8sDesc = |
| 193 | + FunctionDescriptor.of(JAVA_INT, POINTER, POINTER, JAVA_INT); |
| 194 | + |
| 195 | + static final MethodHandle vdot8sMH = |
| 196 | + SYMBOL_LOOKUP |
| 197 | + .find("vdot8s") |
| 198 | + .map(addr -> LINKER.downcallHandle(addr, vdot8sDesc)) |
| 199 | + .orElse(null); |
| 200 | + |
| 201 | + static final MethodHandle dot8sMH = |
| 202 | + SYMBOL_LOOKUP.find("dot8s").map(addr -> LINKER.downcallHandle(addr, dot8sDesc)).orElse(null); |
| 203 | + |
| 204 | + static final MethodHandle vdot8s$MH() { |
| 205 | + return requireNonNull(vdot8sMH, "vdot8s"); |
| 206 | + } |
| 207 | + |
| 208 | + static final MethodHandle dot8s$MH() { |
| 209 | + return requireNonNull(dot8sMH, "dot8s"); |
| 210 | + } |
| 211 | + |
| 212 | + static <T> T requireNonNull(T obj, String symbolName) { |
| 213 | + if (obj == null) { |
| 214 | + throw new UnsatisfiedLinkError("unresolved symbol: " + symbolName); |
| 215 | + } |
| 216 | + return obj; |
| 217 | + } |
| 218 | + |
| 219 | + public static int vdot8s(MemorySegment vec1, MemorySegment vec2, int limit) { |
| 220 | + var mh$ = vdot8s$MH(); |
| 221 | + try { |
| 222 | + return (int) mh$.invokeExact(vec1, vec2, limit); |
| 223 | + } catch (Throwable ex$) { |
| 224 | + throw new AssertionError("should not reach here", ex$); |
| 225 | + } |
| 226 | + } |
| 227 | + |
| 228 | + public static int dot8s(MemorySegment vec1, MemorySegment vec2, int limit) { |
| 229 | + var mh$ = dot8s$MH(); |
| 230 | + try { |
| 231 | + return (int) mh$.invokeExact(vec1, vec2, limit); |
| 232 | + } catch (Throwable ex$) { |
| 233 | + throw new AssertionError("should not reach here", ex$); |
| 234 | + } |
| 235 | + } |
| 236 | + |
| 237 | + /** Ankur: Hacky code end * */ |
| 238 | + |
171 | 239 | /**
|
172 | 240 | * Dot product computed over signed bytes.
|
173 | 241 | *
|
|
0 commit comments