Skip to content

Commit 46a8a59

Browse files
authored
feat: add remaining Float16 specializations for simd calculation (#78)
This PR adds the remaining `Float16` specializations for `CosineSimilarityImpl`, resulting in SIMD ops being used instead of the generic reference implementation in some of the test cases.
1 parent a88c8e6 commit 46a8a59

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

include/svs/core/distance/cosine.h

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -306,9 +306,26 @@ template <size_t N> struct CosineSimilarityImpl<N, float, int8_t> {
306306
template <size_t N> struct CosineSimilarityImpl<N, float, Float16> {
307307
SVS_NOINLINE static float
308308
compute(const float* a, const Float16* b, float a_norm, lib::MaybeStatic<N> length) {
309-
auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>(), a, b, length);
309+
auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>{}, a, b, length);
310+
return sum / (std::sqrt(norm) * a_norm);
311+
}
312+
};
313+
314+
template <size_t N> struct CosineSimilarityImpl<N, Float16, float> {
315+
SVS_NOINLINE static float
316+
compute(const Float16* a, const float* b, float a_norm, lib::MaybeStatic<N> length) {
317+
auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>{}, a, b, length);
310318
return sum / (std::sqrt(norm) * a_norm);
311319
}
312320
};
321+
322+
template <size_t N> struct CosineSimilarityImpl<N, Float16, Float16> {
323+
SVS_NOINLINE static float
324+
compute(const Float16* a, const Float16* b, float a_norm, lib::MaybeStatic<N> length) {
325+
auto [sum, norm] = simd::generic_simd_op(CosineFloatOp<16>{}, a, b, length);
326+
return sum / (std::sqrt(norm) * a_norm);
327+
}
328+
};
329+
313330
#endif
314331
} // namespace svs::distance

include/svs/index/inverted/common.h

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,12 @@ template <typename T> inline T bound_with(T nearest, T epsilon, svs::DistanceL2)
4040
}
4141

4242
template <typename T> inline T bound_with(T nearest, T epsilon, svs::DistanceIP) {
43-
// TODO: What do we do if the best match is simply bad?
43+
assert(nearest > 0.0f);
44+
return nearest / (1 + epsilon);
45+
}
46+
47+
template <typename T>
48+
inline T bound_with(T nearest, T epsilon, svs::DistanceCosineSimilarity) {
4449
assert(nearest > 0.0f);
4550
return nearest / (1 + epsilon);
4651
}

0 commit comments

Comments
 (0)