diff --git a/cpp/src/Spaces/Euclidean.h b/cpp/src/Spaces/Euclidean.h index 252e413d..cf5bbc5c 100644 --- a/cpp/src/Spaces/Euclidean.h +++ b/cpp/src/Spaces/Euclidean.h @@ -174,9 +174,51 @@ static float L2SqrSIMD16Ext(const float *pVect1, const float *pVect2, _mm_store_ps(TmpRes, sum); return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; } + +#elif defined(USE_NEON) + +static float L2SqrSIMD16Ext(const float *pVect1, const float *pVect2, + const size_t qty) { + size_t qty16 = qty >> 4; + const float *pEnd1 = pVect1 + (qty16 << 4); + + float32x4_t sum0 = vdupq_n_f32(0); + float32x4_t sum1 = vdupq_n_f32(0); + float32x4_t sum2 = vdupq_n_f32(0); + float32x4_t sum3 = vdupq_n_f32(0); + + while (pVect1 < pEnd1) { + float32x4_t v1_0 = vld1q_f32(pVect1); + float32x4_t v2_0 = vld1q_f32(pVect2); + float32x4_t diff0 = vsubq_f32(v1_0, v2_0); + sum0 = vfmaq_f32(sum0, diff0, diff0); + + float32x4_t v1_1 = vld1q_f32(pVect1 + 4); + float32x4_t v2_1 = vld1q_f32(pVect2 + 4); + float32x4_t diff1 = vsubq_f32(v1_1, v2_1); + sum1 = vfmaq_f32(sum1, diff1, diff1); + + float32x4_t v1_2 = vld1q_f32(pVect1 + 8); + float32x4_t v2_2 = vld1q_f32(pVect2 + 8); + float32x4_t diff2 = vsubq_f32(v1_2, v2_2); + sum2 = vfmaq_f32(sum2, diff2, diff2); + + float32x4_t v1_3 = vld1q_f32(pVect1 + 12); + float32x4_t v2_3 = vld1q_f32(pVect2 + 12); + float32x4_t diff3 = vsubq_f32(v1_3, v2_3); + sum3 = vfmaq_f32(sum3, diff3, diff3); + + pVect1 += 16; + pVect2 += 16; + } + + return vaddvq_f32(vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3))); +} + #endif -#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) +#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) || \ + defined(USE_NEON) static float L2SqrSIMD16ExtResiduals(const float *pVect1, const float *pVect2, const size_t qty) { size_t qty16 = qty >> 4 << 4; @@ -189,7 +231,7 @@ static float L2SqrSIMD16ExtResiduals(const float *pVect1, const float *pVect2, } #endif -#ifdef USE_SSE +#if defined(USE_SSE) static float L2SqrSIMD4Ext(const 
float *pVect1, const float *pVect2, const size_t qty) { float PORTABLE_ALIGN32 TmpRes[8]; @@ -212,6 +254,30 @@ static float L2SqrSIMD4Ext(const float *pVect1, const float *pVect2, return TmpRes[0] + TmpRes[1] + TmpRes[2] + TmpRes[3]; } +#elif defined(USE_NEON) + +static float L2SqrSIMD4Ext(const float *pVect1, const float *pVect2, + const size_t qty) { + size_t qty4 = qty >> 2; + const float *pEnd1 = pVect1 + (qty4 << 2); + + float32x4_t sum = vdupq_n_f32(0); + + while (pVect1 < pEnd1) { + float32x4_t v1 = vld1q_f32(pVect1); + pVect1 += 4; + float32x4_t v2 = vld1q_f32(pVect2); + pVect2 += 4; + float32x4_t diff = vsubq_f32(v1, v2); + sum = vfmaq_f32(sum, diff, diff); + } + + return vaddvq_f32(sum); +} + +#endif + +#if defined(USE_SSE) || defined(USE_NEON) static float L2SqrSIMD4ExtResiduals(const float *pVect1, const float *pVect2, const size_t qty) { size_t qty4 = qty >> 2 << 2; @@ -276,7 +342,8 @@ template <> EuclideanSpace::EuclideanSpace(size_t dim) : data_size_(dim * sizeof(float)), dim_(dim) { fstdistfunc_ = L2Sqr; -#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) +#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) || \ + defined(USE_NEON) if (dim % 16 == 0) fstdistfunc_ = L2SqrSIMD16Ext; else if (dim % 4 == 0) diff --git a/cpp/src/Spaces/InnerProduct.h b/cpp/src/Spaces/InnerProduct.h index 5e671f6e..d557e953 100644 --- a/cpp/src/Spaces/InnerProduct.h +++ b/cpp/src/Spaces/InnerProduct.h @@ -183,6 +183,56 @@ static float InnerProductSIMD4Ext(const float *pVect1, const float *pVect2, return 1.0f - sum; } +#elif defined(USE_NEON) + +static float InnerProductSIMD4Ext(const float *pVect1, const float *pVect2, + const size_t qty) { + size_t qty16 = qty / 16; + size_t qty4 = qty / 4; + + const float *pEnd1 = pVect1 + 16 * qty16; + const float *pEnd2 = pVect1 + 4 * qty4; + + float32x4_t sum0 = vdupq_n_f32(0); + float32x4_t sum1 = vdupq_n_f32(0); + float32x4_t sum2 = vdupq_n_f32(0); + float32x4_t sum3 = vdupq_n_f32(0); + + while (pVect1 < 
pEnd1) { + float32x4_t v1_0 = vld1q_f32(pVect1); + float32x4_t v2_0 = vld1q_f32(pVect2); + sum0 = vfmaq_f32(sum0, v1_0, v2_0); + + float32x4_t v1_1 = vld1q_f32(pVect1 + 4); + float32x4_t v2_1 = vld1q_f32(pVect2 + 4); + sum1 = vfmaq_f32(sum1, v1_1, v2_1); + + float32x4_t v1_2 = vld1q_f32(pVect1 + 8); + float32x4_t v2_2 = vld1q_f32(pVect2 + 8); + sum2 = vfmaq_f32(sum2, v1_2, v2_2); + + float32x4_t v1_3 = vld1q_f32(pVect1 + 12); + float32x4_t v2_3 = vld1q_f32(pVect2 + 12); + sum3 = vfmaq_f32(sum3, v1_3, v2_3); + + pVect1 += 16; + pVect2 += 16; + } + + float32x4_t sum_prod = + vaddq_f32(vaddq_f32(sum0, sum1), vaddq_f32(sum2, sum3)); + + while (pVect1 < pEnd2) { + float32x4_t v1 = vld1q_f32(pVect1); + pVect1 += 4; + float32x4_t v2 = vld1q_f32(pVect2); + pVect2 += 4; + sum_prod = vfmaq_f32(sum_prod, v1, v2); + } + + return 1.0f - vaddvq_f32(sum_prod); +} + #endif #if defined(USE_AVX512) @@ -294,9 +344,48 @@ static float InnerProductSIMD16Ext(const float *pVect1, const float *pVect2, return 1.0f - sum; } +#elif defined(USE_NEON) + +static float InnerProductSIMD16Ext(const float *pVect1, const float *pVect2, + const size_t qty) { + size_t qty16 = qty / 16; + const float *pEnd1 = pVect1 + 16 * qty16; + + float32x4_t sum0 = vdupq_n_f32(0); + float32x4_t sum1 = vdupq_n_f32(0); + float32x4_t sum2 = vdupq_n_f32(0); + float32x4_t sum3 = vdupq_n_f32(0); + + while (pVect1 < pEnd1) { + float32x4_t v1_0 = vld1q_f32(pVect1); + float32x4_t v2_0 = vld1q_f32(pVect2); + sum0 = vfmaq_f32(sum0, v1_0, v2_0); + + float32x4_t v1_1 = vld1q_f32(pVect1 + 4); + float32x4_t v2_1 = vld1q_f32(pVect2 + 4); + sum1 = vfmaq_f32(sum1, v1_1, v2_1); + + float32x4_t v1_2 = vld1q_f32(pVect1 + 8); + float32x4_t v2_2 = vld1q_f32(pVect2 + 8); + sum2 = vfmaq_f32(sum2, v1_2, v2_2); + + float32x4_t v1_3 = vld1q_f32(pVect1 + 12); + float32x4_t v2_3 = vld1q_f32(pVect2 + 12); + sum3 = vfmaq_f32(sum3, v1_3, v2_3); + + pVect1 += 16; + pVect2 += 16; + } + + float sum = + vaddvq_f32(vaddq_f32(vaddq_f32(sum0, sum1), 
vaddq_f32(sum2, sum3))); + return 1.0f - sum; +} + #endif -#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) +#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) || \ + defined(USE_NEON) static float InnerProductSIMD16ExtResiduals(const float *pVect1, const float *pVect2, const size_t qty) { @@ -374,7 +463,8 @@ template <> InnerProductSpace::InnerProductSpace(size_t dim) : data_size_(dim * sizeof(float)), dim_(dim) { fstdistfunc_ = InnerProduct; -#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) +#if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) || \ + defined(USE_NEON) if (dim % 16 == 0) fstdistfunc_ = InnerProductSIMD16Ext; else if (dim % 4 == 0) diff --git a/cpp/src/hnswlib.h b/cpp/src/hnswlib.h index 13e21552..5ac4bd33 100644 --- a/cpp/src/hnswlib.h +++ b/cpp/src/hnswlib.h @@ -32,6 +32,9 @@ #endif #endif #endif +#ifdef __ARM_NEON +#define USE_NEON +#endif #endif #if defined(USE_AVX) || defined(USE_SSE) @@ -55,6 +58,10 @@ #endif #endif +#if defined(USE_NEON) +#include <arm_neon.h> +#endif + #include "StreamUtils.h" #include "visited_list_pool.h" #include diff --git a/python/pyproject.toml b/python/pyproject.toml index ed15f41c..5fdea543 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["scikit-build-core >=0.4.3", "nanobind >=1.3.2"] +requires = ["scikit-build-core >=0.4.3", "nanobind >=1.3.2,<=2.9.2"] build-backend = "scikit_build_core.build" [project]